diff --git a/.gitignore b/.gitignore index 828bbe9bd3363853ae3f58f54a8d5f60cefad837..5afe375f46f07b3b557ae23f75740b337517d3bd 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ __pycache__ cmake_build/ .idea/** /build/ +[Bb]uild/ /tensorflow/core/util/version_info.cc /tensorflow/python/framework/fast_tensor_util.cpp Pods @@ -28,6 +29,7 @@ Podfile.lock /tensorflow/contrib/lite/examples/ios/simple/data/*.tflite xcuserdata/** /api_init_files_list.txt +/estimator_api_init_files_list.txt # Android .gradle diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index db4b1581ae671b1e676e215c9a80dfaab832fa21..f598999f351c10f8bd01dfbd3ad8897f19d570e8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -107,7 +107,7 @@ diff /tmp/my_cc_file.cc #### Python coding style Changes to TensorFlow Python code should conform to -[Google Python Style Guide](https://google.github.io/styleguide/pyguide.html) +[Google Python Style Guide](https://github.com/google/styleguide/blob/gh-pages/pyguide.md) Use `pylint` to check your Python changes. To install `pylint` and retrieve TensorFlow's custom style definition: diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index 2f3df7cda9cec29ed0c2266629022f0a22b37df9..52faed9297cfcaf8c93bb9c79686c9258a53c560 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -15,9 +15,10 @@ If you open a GitHub issue, here is our policy: ### System information - **Have I written custom code (as opposed to using a stock example script provided in TensorFlow)**: - **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**: +- **Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device**: - **TensorFlow installed from (source or binary)**: - **TensorFlow version (use command below)**: -- **Python version**: +- **Python version**: - **Bazel version (if compiling from source)**: - **GCC/Compiler version (if compiling from source)**: - **CUDA/cuDNN version**: diff --git a/README.md b/README.md index 63853137cfd30b396f8c7d204811f3e4a1794c07..669ff5b711c62455f48038743ca1e089fa23d9e6 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ data flow graphs. The graph nodes represent mathematical operations, while the graph edges represent the multidimensional data arrays (tensors) that flow between them. This flexible architecture enables you to deploy computation to one or more CPUs or GPUs in a desktop, server, or mobile device without rewriting -code. TensorFlow also includes [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard), a data visualization toolkit. +code. TensorFlow also includes [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard), a data visualization toolkit. TensorFlow was originally developed by researchers and engineers working on the Google Brain team within Google's Machine Intelligence Research @@ -22,6 +22,8 @@ organization for the purposes of conducting machine learning and deep neural networks research. The system is general enough to be applicable in a wide variety of other domains, as well. +TensorFlow provides stable Python API and C APIs as well as without API backwards compatibility guarantee like C++, Go, Java, JavaScript and Swift. + Keep up to date with release announcements and security updates by subscribing to [announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce). @@ -81,13 +83,13 @@ The TensorFlow project strives to abide by generally accepted best practices in | Build Type | Status | Artifacts | | --- | --- | --- | -| **Linux CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | -| **Linux GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-cc.png) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | -| **Linux XLA** | TBA | TBA | -| **MacOS** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | -| **Windows CPU** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [pypi](https://pypi.org/project/tf-nightly/) | -| **Windows GPU** | [![Status](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/badge/icon)](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | -| **Android** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) [build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/) | +| **Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | +| **Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA | +| **MacOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | +| **Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) | ### Community Supported Builds @@ -96,16 +98,21 @@ The TensorFlow project strives to abide by generally accepted best practices in | --- | --- | --- | | **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA | | **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | +| **IBM ppc64le GPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA | +| **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | +| **Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6| ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)|[1.9.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp27-cp27mu-linux_x86_64.whl)
[1.9.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp35-cp35m-linux_x86_64.whl)
[1.9.0 py3.6](https://storage.cloud.google.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp36-cp36m-linux_x86_64.whl) | ## For more information - +* [Tensorflow Blog](https://medium.com/tensorflow) +* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) +* [TensorFlow Model Zoo](https://github.com/tensorflow/models) +* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730) +* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap) +* [Tensorflow Twitter](https://twitter.com/tensorflow) * [TensorFlow Website](https://www.tensorflow.org) * [TensorFlow White Papers](https://www.tensorflow.org/about/bib) * [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) -* [TensorFlow Model Zoo](https://github.com/tensorflow/models) -* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730) -* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate. diff --git a/RELEASE.md b/RELEASE.md index e09e9c6190f57adec67c2ae1d85848dabfd9c2a7..763ef3b279dde209ed387534032deae40a33a9e4 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,31 +1,120 @@ +# Release 1.10.0 + +## Major Features And Improvements + +* The `tf.lite` runtime now supports `complex64`. +* Initial [Google Cloud Bigtable integration](https://github.com/tensorflow/tensorflow/tree/r1.10/tensorflow/contrib/bigtable) for `tf.data`. +* Improved local run behavior in `tf.estimator.train_and_evaluate` which does not reload checkpoints for evaluation. +* `RunConfig` now sets device_filters to restrict how workers and PS can communicate. This can speed up training and ensure clean shutdowns in some situations. But if you have jobs that require communication between workers, you will have to set custom session_options in your `RunConfig`. +* Moved Distributions and Bijectors from `tf.contrib.distributions` to [Tensorflow Probability (TFP)](https://github.com/tensorflow/probability). `tf.contrib.distributions` is now deprecated and will be removed by the end of 2018. +* Adding new endpoints for existing tensorflow symbols. These endpoints are going to be the preferred endpoints going forward and may replace some of the existing endpoints in the future. See below for the complete list. New symbols have been added to the following modules: [`tf.debugging`](https://www.tensorflow.org/versions/master/api_docs/python/tf/debugging), [`tf.dtypes`](https://www.tensorflow.org/versions/master/api_docs/python/tf/dtypes), [`tf.image`](https://www.tensorflow.org/versions/master/api_docs/python/tf/image), [`tf.io`](https://www.tensorflow.org/versions/master/api_docs/python/tf/io), [`tf.linalg`](https://www.tensorflow.org/versions/master/api_docs/python/tf/linalg), [`tf.manip`](https://www.tensorflow.org/versions/master/api_docs/python/tf/manip), [`tf.math`](https://www.tensorflow.org/versions/master/api_docs/python/tf/math), [`tf.quantization`](https://www.tensorflow.org/versions/master/api_docs/python/tf/quantization), [`tf.strings`](https://www.tensorflow.org/versions/master/api_docs/python/tf/strings) + +## Breaking Changes + +* Prebuilt binaries are now (as of TensorFlow 1.10) built against NCCL 2.2 and no longer include NCCL in the binary install. TensorFlow usage with multiple GPUs and NCCL requires upgrade to [NCCL 2.2](https://developer.nvidia.com/nccl). See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support) and [Install TensorFlow from Sources](https://www.tensorflow.org/install/install_sources#optional_install_tensorflow_for_gpu_prerequisites). +* Starting from TensorFlow 1.11, Windows builds will use Bazel. Therefore, we will drop official support for cmake. + +## Bug Fixes and Other Changes + +* `tf.data`: + * `tf.contrib.data.group_by_reducer()` is now available via the public API. + * `tf.contrib.data.choose_from_datasets()` is now available via the public API. + * Adding `drop_remainder` argument to `tf.data.Dataset.batch()` and `tf.data.Dataset.padded_batch()`, deprecating `tf.contrib.data.batch_and_drop_remainder()` and `tf.contrib.data.padded_batch_and_drop_remainder()`. +* `tf.estimator`: + * `Estimator`s now use custom savers included in `EstimatorSpec` scaffolds for saving SavedModels during export. + * `EstimatorSpec` will now add a default prediction output for export if no `export_output` is provided, eliminating the need to explicitly include a `PredictOutput` object in the `model_fn` for simple use-cases. + * Support sparse_combiner in canned Linear Estimators. + * Added batch normalization to `DNNClassifier`, `DNNRegressor`, and `DNNEstimator`. + * Adding ranking support for boosted trees. + * Adding center bias option for boosted trees. +* Add `synchronization` and `aggregation` args to get_variable(). These args will be used for distributed variables. +* Add `synchronization` and `aggregation` args to the layer `add_weight()` API. These args will be used for distributed variables. +* `tf.losses.*` do not add to the global collection when executing eagerly (to avoid leaking memory). +* Support different summary and checkpoint directories in `tf.train.MonitoredTrainingSession()`. +* Added IndRNN, IndyGRU, and IndyLSTM cells to `tf.contrib.rnn`. +* Add safe static factory functions for SparseTensor and convert all CHECKs to DCHECKs. Using the constructor directly is unsafe and deprecated. +* Make the Bigtable client connection pool configurable & increase the default # of connections for performance. +* Added derivative of `tf.random_gamma` with respect to the alpha parameter. +* Added derivative of `tf.igamma(a, x)` and `tf.igammac(a, x)` with respect to a. +* Modified Bessel functions of order zero and one. +* Add FillTriangular Bijector to create triangular matrices. +* Added support for Type III DCT, and `tf.spectral.idct(type=2|3)`. +* Correctly handle CuDNN RNN weight loaded when nest in `TimeDistributed`. +* Adding per-element weight support for `WALSComputePartialLhsAndRhsOp`. +* ZerosLike and OnesLike ops treated as constants by Graph Transform Tool. +* Gamma distribution and the derived distributions (Beta, Dirichlet, Student's t, inverse Gamma) now fully reparameterized. +* Java: Experimental wrapper classes to make graph generation easier. Thanks @karllessard and @kbsriram +* Build & link in secure gRPC components (switch from the insecure grpc dependency to secure grpc dependency). +* Adding new endpoints for existing tensorflow symbols. These endpoints are going to be the preferred endpoints going forward and may replace some of the existing endpoints in the future. List of new endpoints: + * New endpoints in `tf.image` namespace: `tf.image.extract_image_patches` + * New endpoints in `tf.debugging` namespace: `tf.debugging.check_numerics`, `tf.debugging.is_finite`, `tf.debugging.is_inf`, `tf.debugging.is_nan`. + * New endpoints in `tf.dtypes` namespace: `tf.dtypes.as_string`. + * New endpoints in `tf.io` namespace: `tf.io.decode_base64`, `tf.io.decode_compressed`, `tf.io.decode_json_example`, `tf.io.decode_raw`, `tf.io.encode_base64`, `tf.io.matching_files`, `tf.io.parse_tensor`, `tf.io.read_file, `tf.io.write_file`. + * New endpoints in tf.linalg namespace: `tf.linalg.cross`, `tf.linalg.tensor_diag` (corresponds to `tf.diag`), `tf.linalg.tensor_diag_part` (corresponds to `tf.diag_part`). + * New endpoints in tf.manip namespace: `tf.manip.batch_to_space_nd`, `tf.manip.gather_nd`, `tf.manip.reshape`, `tf.manip.reverse`, `tf.manip.scatter_nd`, `tf.manip.space_to_batch_nd`, `tf.manip.tile` + * New endpoints in tf.math namespace: `tf.math.acos`, `tf.math.acosh`, `tf.math.add`, `tf.math.asin`, `tf.math.asinh`, `tf.math.atan`, `tf.math.atan2`, `tf.math.atanh`, `tf.math.betainc`, `tf.math.ceil`, `tf.math.cos`, `tf.math.cosh`, `tf.math.digamma`, `tf.math.equal`, `tf.math.erfc`, `tf.math.exp`, `tf.math.expm1`, `tf.math.floor`, `tf.math.greater`, `tf.math.greater_equal`, `tf.math.igamma`, `tf.math.igammac`, `tf.math.invert_permutation`, `tf.math.less`, `tf.math.less_equal`, `tf.math.lgamma`, `tf.math.log`, `tf.math.log1p`, `tf.math.logical_and`, `tf.math.logical_not`, `tf.math.logical_or`, `tf.math.maximum`, `tf.math.minimum`, `tf.math.not_equal`, `tf.math.polygamma`, `tf.math.reciprocal`, `tf.math.rint`, `tf.math.rsqrt`, `tf.math.segment_max`, `tf.math.segment_mean`, `tf.math.segment_min`, `tf.math.segment_prod`, `tf.math.segment_sum`, `tf.math.sin`, `tf.math.sinh`, `tf.math.softplus`, `tf.math.softsign`, `tf.math.squared_difference`, `tf.math.tan`, `tf.math.unsorted_segment_max`, `tf.math.unsorted_segment_min`, `tf.math.unsorted_segment_prod`, `tf.math.unsorted_segment_sum`, `tf.math.zeta`. + * New endpoints in `tf.quantization` namespace: `tf.quantization.dequantize`, `tf.quantization.fake_quant_with_min_max_args`, `tf.quantization.fake_quant_with_min_max_args_gradient`, `tf.quantization.fake_quant_with_min_max_vars`, `tf.quantization.fake_quant_with_min_max_vars_gradient`, `tf.quantization.fake_quant_with_min_max_vars_per_channel`, `tf.quantization.fake_quant_with_min_max_vars_per_channel_gradient`. + * New endpoints in tf.strings namespace: `tf.strings.join` (corresponds to `tf.string_join`), `tf.strings.regex_replace`, `tf.strings.to_number` (corresponds to `tf.string_to_number`), `tf.strings.strip` (corresponds to `tf.string_strip`), `tf.strings.substr`, `tf.strings.to_hash_bucket` (corresponds to `tf.string_to_hash_bucket`), `tf.strings.to_hash_bucket_fast` (corresponds to `tf.string_to_hash_bucket_fast`), `tf.strings.to_hash_bucket_strong` (corresponds to `tf.string_to_hash_bucket_strong`). + + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, Andrei Nigmatulin, Andrew Ginns, BjøRn Moholt, Brett Koonce, Chengzhi Chen, Chinmay Das, Christian Ertler, Christoph Boeddeker, Clayne Robison, Courtial Florian, ctiijima, Dan Douthit, Dan J, Dan Ringwalt, EFanZh, Emanuele Ballarin, eqy, Evgeniy Zheltonozhskiy, Freedom" Koan-Sin Tan, FréDéRic Branchaud-Charron, G K, gracehoney, Guillaume Klein, Guozhong Zhuang, Hsien-Yang Li, hsm207, ImSheridan, Jayaram Bobba, Jiandong Ruan, Jie, Joel Shor, Jonas Rauber, Jongmin Baek, jsawruk, Karan Kaw, Karl Lessard, karl@kubx.ca, Kb Sriram, KinmanLam, leiiwang, Li, Yiqiang, Loo Rong Jie, Mahmoud Abuzaina, Mahmoud Aslan, ManHyuk, Martin Patz, Martin Zeitler, mktozk, Mohammad Ashraf Bhuiyan, mrTsjolder, Naman Bhalla, Nick Felt, Nicolas Lopez, Niranjan Hasabnis, Nishidha Panpaliya, Nitish, nrstott, Nutti, Parag Jain, PeterLee, Philipp Jund, Rach L, Rafal Wojdyla, Roland Zimmermann, Sergei Lebedev, SneakyFish5, Soila Kavulya, Sriram Veturi, Steven Schmatz, Taehoon Lee, Tang, Wenyi, Taras Sereda, Ted Chang, Tim Zaman, Tristan Rice, tucan, vchigrin, Vikram Tiwari, Vincent, WeberXie, William D. Irons, Yan Facai (颜发才), Yong Tang, Yu Yi, Yuxin Wu, Zé ViníCius + # Release 1.9.0 ## Major Features And Improvements -* Update tf.keras to the Keras 2.1.6 API. +* Updated docs for `tf.keras`: New Keras-based [get started](http://tensorflow.org/versions/r1.9/get_started), + and [programmers guide page](http://tensorflow.org/versions/r1.9/programmers_guide/keras). +* Update `tf.keras` to the Keras 2.1.6 API. +* Added [`tf.keras.layers.CuDNNGRU`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNGRU) and [`tf.keras.layers.CuDNNLSTM`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNLSTM) layers. [Try it](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb?linkId=53292082). +* Adding support of core [feature columns](https://www.tensorflow.org/get_started/feature_columns) and [losses](https://www.tensorflow.org/api_docs/python/tf/losses) to [gradient boosted trees estimators](https://github.com/tensorflow/models/tree/master/official/boosted_trees). +* The [python interface](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/lite) + for the [TFLite Optimizing Converter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/README.md) + has been expanded, and the command line interface (AKA: `toco`, `tflite_convert`) is once again + included in the standard `pip` installation. +* Improved data-loading and text processing with: + * [`tf.decode_compressed`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/decode_compressed) + * [`tf.string_strip`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/string_strip) + * [`tf.strings.regex_full_match`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/strings/regex_full_match) +* Added experimental support for new pre-made Estimators: + * [`tf.contrib.estimator.BaselineEstimator`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/estimator/BaselineEstimator) + * [`tf.contrib.estimator.RNNClassifier`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/estimator/RNNEstimator) + * [`tf.contrib.estimator.RNNEstimator`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/estimator/RNNClassifier) +* The [distributions.Bijector](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/distributions/bijectors/Bijector) + API supports broadcasting for Bijectors with new API changes. + +## Breaking Changes + * If you're opening empty variable scopes; replace `variable_scope('', ...)` by + `variable_scope(tf.get_variable_scope(), ...)`. + * Headers used for building custom ops have been moved from site-packages/external into site-packages/tensorflow/include/external. + +## Bug Fixes and Other Changes + * `tfe.Network` is deprecated. Please inherit from `tf.keras.Model`. -* Adding support of core feature columns and losses to gradient boosted trees estimators. -* The distributions.Bijector API supports broadcasting for Bijectors with new API changes. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/distributions/bijectors/Bijector) for more details. * Layered variable names have changed in the following conditions: * Using `tf.keras.layers` with custom variable scopes. - * Using `tf.layers` in a subclassed `tf.keras.Model` class. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details - -## Breaking Chances - * If you're opening empty variable scopes; replace `variable_scope`('', ...) by `variable_scope`(`tf.get_variable_scope()`, ...). - -## Bug Fixes and Other Changes + * Using `tf.layers` in a subclassed `tf.keras.Model` class. See + [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details * `tf.data`: - * The `DatasetBase::DebugString()` method is now `const`. - * Added the `tf.contrib.data.sample_from_datasets()` API for randomly sampling from multiple datasets. + * `Dataset.from_generator()` now accepts an `args` list, in order to create nested generators. + * `Dataset.list_files()` now produces determinstic results when `shuffle=False` or a `seed` is passed. + * `tf.contrib.data.sample_from_datasets()` and `tf.contrib.data.choose_from_datasets()` make it easier to sample or deterministically choose elements from multiple datasets. + * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings, and two infrequently used arguments removed. + * (C++) `DatasetBase::DebugString()` is now `const`. + * (C++) `DatasetBase::MakeIterator()` has been renamed to `DatasetBase::MakeIteratorInternal()`. + * (C++) `IteratorBase::Initialize()` method was added to support raising errors during iterator construction. * Eager Execution: + * Added the ability to pause recording operations for gradient computation via `tf.GradientTape.stop_recording`. + * Updated documentation, introductory notebooks. * `tf.keras`: * Move Keras code out of _impl folder and remove API files. * `tf.keras.Model.save_weights` now saves in TensorFlow format by default. * Enable dataset iterators to be passed to `tf.keras.Model` training/eval methods. -* Accelerated Linear Algebra (XLA): -* TensorFlow Debugger (tfdbg): fix an issue in which the TensorBoard Debugger Plugin could not handle total source file size exceeding gRPC message size limit (4 MB). +* TensorFlow Debugger (tfdbg) CLI: fix an issue in which the TensorBoard Debugger Plugin could not handle total source file size exceeding gRPC message size limit (4 MB). * `tf.contrib`: - * Add `tf.contrib.data.choose_from_datasets()`. - * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings. Two arguments were removed from `make_csv_dataset`. * `tf.contrib.framework.zero_initializer` supports ResourceVariable. * Adding "constrained_optimization" to tensorflow/contrib. * Other: @@ -35,7 +124,6 @@ * More consistent GcsFileSystem behavior for certain reads past EOF. * Update benchmark for tf.scan to match ranges across eager and graph modes. * Fixed bug in `tf.reduce_prod gradient` for complex dtypes. - * Add optional `args` argument to `Dataset.from_generator()`. * Allow the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), which would previously raise an error. This will correspond to an attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only be accessed indirectly (e.g. through getattr and setattr). To set this up the user will first need to explicitly add the variable to the hparam object (e.g. "hparams.add_hparam(name='a.b', value=0.0)"). * Benchmark for tf.scan in graph and eager modes. * Added complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D. @@ -45,7 +133,6 @@ * LinearOperator[1D,2D,3D]Circulant added to `tensorflow.linalg`. * Conv3D, Conv3DBackpropInput, Conv3DBackpropFilter now supports arbitrary. * Added `tf.train.Checkpoint` for reading/writing object-based checkpoints. - * `Dataset.list_files()` now produces determinstic results when `shuffle=False` or a `seed` is passed. * Added LinearOperatorKronecker, a dense-free implementation of the Kronecker Product. * Allow LinearOperator to broadcast. * SavedModelBuilder will now deduplicate asset names that point to files with the same basename and the same contents. Note that this may result in new asset files included in SavedModels in cases where assets with the same name but different contents were previously overwriting each other. @@ -465,7 +552,7 @@ answered questions, and were part of inspiring discussions. ## Major Features And Improvements * `tf.keras` is now part of the core TensorFlow API. -* [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of +* [`tf.data`](http://tensorflow.org/guide/datasets) is now part of the core TensorFlow API. * The API is now subject to backwards compatibility guarantees. * For a guide to migrating from the `tf.contrib.data` API, see the @@ -485,7 +572,7 @@ answered questions, and were part of inspiring discussions. * TensorFlow Debugger (tfdbg): * Add `eval` command to allow evaluation of arbitrary Python/numpy expressions in tfdbg command-line interface. See - [Debugging TensorFlow Programs](https://www.tensorflow.org/programmers_guide/debugger) + [Debugging TensorFlow Programs](https://www.tensorflow.org/guide/debugger) for more details. * Usability improvement: The frequently used tensor filter `has_inf_or_nan` is now added to `Session` wrappers and hooks by default. So there is no need @@ -772,7 +859,7 @@ answered questions, and were part of inspiring discussions. * Support client-provided ClusterSpec's and propagate them to all workers to enable the creation of dynamic TensorFlow clusters. * TensorFlow C library now available for Windows. * We released a new open-source version of TensorBoard. -* [`SavedModel CLI`](https://www.tensorflow.org/versions/master/programmers_guide/saved_model_cli) tool available to inspect and execute MetaGraph in SavedModel +* [`SavedModel CLI`](https://www.tensorflow.org/versions/master/guide/saved_model_cli) tool available to inspect and execute MetaGraph in SavedModel * Android releases of TensorFlow are now pushed to jcenter for easier integration into apps. See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/android/README.md diff --git a/SECURITY.md b/SECURITY.md index e2f6ff353a3c04a6ec6b8ccbaeb75db59fa22d54..0b52fdc7ab84b7bd5bce5d247ede81b40699005c 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -245,4 +245,4 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc= ### Known Vulnerabilities For a list of known vulnerabilities and security advisories for TensorFlow, -(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md)[click here]. +[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md). diff --git a/WORKSPACE b/WORKSPACE index fd7570a80ae2ee0087f7d2fd771fcce5b9690028..17961829a605c2d1f2d2ba86a7c30c47618c139b 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -18,7 +18,7 @@ closure_repositories() # files, in case the parsing of those build files depends on the bazel # version we require here. load("//tensorflow:version_check.bzl", "check_bazel_version_at_least") -check_bazel_version_at_least("0.10.0") +check_bazel_version_at_least("0.15.0") load("//tensorflow:workspace.bzl", "tf_workspace") diff --git a/configure.py b/configure.py index ada342a50ab5104509156d3e44e6435a308255a3..bf570a9fa394f8fb7ef98f57007b656afd0c466c 100644 --- a/configure.py +++ b/configure.py @@ -35,8 +35,8 @@ except ImportError: _DEFAULT_CUDA_VERSION = '9.0' _DEFAULT_CUDNN_VERSION = '7' -_DEFAULT_NCCL_VERSION = '1.3' -_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' +_DEFAULT_NCCL_VERSION = '2.2' +_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0' _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' @@ -680,7 +680,7 @@ def create_android_sdk_rule(environ_cp): if is_windows() or is_cygwin(): default_sdk_path = cygpath('%s/Android/Sdk' % environ_cp['APPDATA']) elif is_macos(): - default_sdk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME'] + default_sdk_path = '%s/library/Android/Sdk' % environ_cp['HOME'] else: default_sdk_path = '%s/Android/Sdk' % environ_cp['HOME'] @@ -835,17 +835,20 @@ def set_tf_cuda_version(environ_cp): '[Default is %s]: ') % (tf_cuda_version, default_cuda_path) cuda_toolkit_path = get_from_env_or_user_or_default( environ_cp, 'CUDA_TOOLKIT_PATH', ask_cuda_path, default_cuda_path) + if is_windows() or is_cygwin(): + cuda_toolkit_path = cygpath(cuda_toolkit_path) if is_windows(): - cuda_rt_lib_path = 'lib/x64/cudart.lib' + cuda_rt_lib_paths = ['lib/x64/cudart.lib'] elif is_linux(): - cuda_rt_lib_path = 'lib64/libcudart.so.%s' % tf_cuda_version + cuda_rt_lib_paths = ['%s/libcudart.so.%s' % (x, tf_cuda_version) + for x in ['lib64', 'lib/x86_64-linux-gnu']] elif is_macos(): - cuda_rt_lib_path = 'lib/libcudart.%s.dylib' % tf_cuda_version + cuda_rt_lib_paths = ['lib/libcudart.%s.dylib' % tf_cuda_version] - cuda_toolkit_path_full = os.path.join(cuda_toolkit_path, cuda_rt_lib_path) - if os.path.exists(cuda_toolkit_path_full): - break + cuda_toolkit_paths_full = [os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths] + if any([os.path.exists(x) for x in cuda_toolkit_paths_full]): + break # Reset and retry print('Invalid path to CUDA %s toolkit. %s cannot be found' % @@ -880,7 +883,7 @@ def set_tf_cudnn_version(environ_cp): default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH') ask_cudnn_path = (r'Please specify the location where cuDNN %s library is ' 'installed. Refer to README.md for more details. [Default' - ' is %s]:') % (tf_cudnn_version, default_cudnn_path) + ' is %s]: ') % (tf_cudnn_version, default_cudnn_path) cudnn_install_path = get_from_env_or_user_or_default( environ_cp, 'CUDNN_INSTALL_PATH', ask_cudnn_path, default_cudnn_path) @@ -943,6 +946,35 @@ def set_tf_cudnn_version(environ_cp): write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version) +def is_cuda_compatible(lib, cuda_ver, cudnn_ver): + """Check compatibility between given library and cudnn/cudart libraries.""" + ldd_bin = which('ldd') or '/usr/bin/ldd' + ldd_out = run_shell([ldd_bin, lib], True) + ldd_out = ldd_out.split(os.linesep) + cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$') + cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$') + cudnn = None + cudart = None + cudnn_ok = True # assume no cudnn dependency by default + cuda_ok = True # assume no cuda dependency by default + for line in ldd_out: + if 'libcudnn.so' in line: + cudnn = cudnn_pattern.search(line) + cudnn_ok = False + elif 'libcudart.so' in line: + cudart = cuda_pattern.search(line) + cuda_ok = False + if cudnn and len(cudnn.group(1)): + cudnn = convert_version_to_int(cudnn.group(1)) + if cudart and len(cudart.group(1)): + cudart = convert_version_to_int(cudart.group(1)) + if cudnn is not None: + cudnn_ok = (cudnn == cudnn_ver) + if cudart is not None: + cuda_ok = (cudart == cuda_ver) + return cudnn_ok and cuda_ok + + def set_tf_tensorrt_install_path(environ_cp): """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION. @@ -959,8 +991,8 @@ def set_tf_tensorrt_install_path(environ_cp): raise ValueError('Currently TensorRT is only supported on Linux platform.') # Ask user whether to add TensorRT support. - if str(int(get_var( - environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False))) != '1': + if str(int(get_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', + False))) != '1': return for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): @@ -973,47 +1005,29 @@ def set_tf_tensorrt_install_path(environ_cp): # Result returned from "read" will be used unexpanded. That make "~" # unusable. Going through one more level of expansion to handle that. - trt_install_path = os.path.realpath( - os.path.expanduser(trt_install_path)) + trt_install_path = os.path.realpath(os.path.expanduser(trt_install_path)) def find_libs(search_path): """Search for libnvinfer.so in "search_path".""" fl = set() if os.path.exists(search_path) and os.path.isdir(search_path): - fl.update([os.path.realpath(os.path.join(search_path, x)) - for x in os.listdir(search_path) if 'libnvinfer.so' in x]) + fl.update([ + os.path.realpath(os.path.join(search_path, x)) + for x in os.listdir(search_path) + if 'libnvinfer.so' in x + ]) return fl possible_files = find_libs(trt_install_path) possible_files.update(find_libs(os.path.join(trt_install_path, 'lib'))) possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64'))) - - def is_compatible(tensorrt_lib, cuda_ver, cudnn_ver): - """Check the compatibility between tensorrt and cudnn/cudart libraries.""" - ldd_bin = which('ldd') or '/usr/bin/ldd' - ldd_out = run_shell([ldd_bin, tensorrt_lib]).split(os.linesep) - cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$') - cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$') - cudnn = None - cudart = None - for line in ldd_out: - if 'libcudnn.so' in line: - cudnn = cudnn_pattern.search(line) - elif 'libcudart.so' in line: - cudart = cuda_pattern.search(line) - if cudnn and len(cudnn.group(1)): - cudnn = convert_version_to_int(cudnn.group(1)) - if cudart and len(cudart.group(1)): - cudart = convert_version_to_int(cudart.group(1)) - return (cudnn == cudnn_ver) and (cudart == cuda_ver) - cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION']) cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION']) nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$') highest_ver = [0, None, None] for lib_file in possible_files: - if is_compatible(lib_file, cuda_ver, cudnn_ver): + if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver): matches = nvinfer_pattern.search(lib_file) if len(matches.groups()) == 0: continue @@ -1029,12 +1043,13 @@ def set_tf_tensorrt_install_path(environ_cp): # Try another alternative from ldconfig. ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' ldconfig_output = run_shell([ldconfig_bin, '-p']) - search_result = re.search( - '.*libnvinfer.so\\.?([0-9.]*).* => (.*)', ldconfig_output) + search_result = re.search('.*libnvinfer.so\\.?([0-9.]*).* => (.*)', + ldconfig_output) if search_result: libnvinfer_path_from_ldconfig = search_result.group(2) if os.path.exists(libnvinfer_path_from_ldconfig): - if is_compatible(libnvinfer_path_from_ldconfig, cuda_ver, cudnn_ver): + if is_cuda_compatible(libnvinfer_path_from_ldconfig, cuda_ver, + cudnn_ver): trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig) tf_tensorrt_version = search_result.group(1) break @@ -1083,8 +1098,10 @@ def set_tf_nccl_install_path(environ_cp): raise ValueError('Currently NCCL is only supported on Linux platforms.') ask_nccl_version = ( - 'Please specify the NCCL version you want to use. ' - '[Leave empty to default to NCCL %s]: ') % _DEFAULT_NCCL_VERSION + 'Please specify the NCCL version you want to use. If NCCL %s is not ' + 'installed, then you can use version 1.3 that can be fetched ' + 'automatically but it may have worse performance with multiple GPUs. ' + '[Default is %s]: ') % (_DEFAULT_NCCL_VERSION, _DEFAULT_NCCL_VERSION) for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): tf_nccl_version = get_from_env_or_user_or_default( @@ -1185,7 +1202,7 @@ def set_tf_cuda_compute_capabilities(environ_cp): 'https://developer.nvidia.com/cuda-gpus.\nPlease' ' note that each additional compute ' 'capability significantly increases your ' - 'build time and binary size. [Default is: %s]' % + 'build time and binary size. [Default is: %s]: ' % default_cuda_compute_capabilities) tf_cuda_compute_capabilities = get_from_env_or_user_or_default( environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES', @@ -1220,28 +1237,13 @@ def set_tf_cuda_compute_capabilities(environ_cp): def set_other_cuda_vars(environ_cp): """Set other CUDA related variables.""" - if is_windows(): - # The following three variables are needed for MSVC toolchain configuration - # in Bazel - environ_cp['CUDA_PATH'] = environ_cp.get('CUDA_TOOLKIT_PATH') - environ_cp['CUDA_COMPUTE_CAPABILITIES'] = environ_cp.get( - 'TF_CUDA_COMPUTE_CAPABILITIES') - environ_cp['NO_WHOLE_ARCHIVE_OPTION'] = 1 - write_action_env_to_bazelrc('CUDA_PATH', environ_cp.get('CUDA_PATH')) - write_action_env_to_bazelrc('CUDA_COMPUTE_CAPABILITIE', - environ_cp.get('CUDA_COMPUTE_CAPABILITIE')) - write_action_env_to_bazelrc('NO_WHOLE_ARCHIVE_OPTION', - environ_cp.get('NO_WHOLE_ARCHIVE_OPTION')) - write_to_bazelrc('build --config=win-cuda') - write_to_bazelrc('test --config=win-cuda') + # If CUDA is enabled, always use GPU during build and test. + if environ_cp.get('TF_CUDA_CLANG') == '1': + write_to_bazelrc('build --config=cuda_clang') + write_to_bazelrc('test --config=cuda_clang') else: - # If CUDA is enabled, always use GPU during build and test. - if environ_cp.get('TF_CUDA_CLANG') == '1': - write_to_bazelrc('build --config=cuda_clang') - write_to_bazelrc('test --config=cuda_clang') - else: - write_to_bazelrc('build --config=cuda') - write_to_bazelrc('test --config=cuda') + write_to_bazelrc('build --config=cuda') + write_to_bazelrc('test --config=cuda') def set_host_cxx_compiler(environ_cp): @@ -1401,14 +1403,36 @@ def set_build_strip_flag(): write_to_bazelrc('build --strip=always') -def set_windows_build_flags(): - if is_windows(): - # The non-monolithic build is not supported yet - write_to_bazelrc('build --config monolithic') - # Suppress warning messages - write_to_bazelrc('build --copt=-w --host_copt=-w') - # Output more verbose information when something goes wrong - write_to_bazelrc('build --verbose_failures') +def set_windows_build_flags(environ_cp): + """Set Windows specific build options.""" + # The non-monolithic build is not supported yet + write_to_bazelrc('build --config monolithic') + # Suppress warning messages + write_to_bazelrc('build --copt=-w --host_copt=-w') + # Output more verbose information when something goes wrong + write_to_bazelrc('build --verbose_failures') + # The host and target platforms are the same in Windows build. So we don't + # have to distinct them. This avoids building the same targets twice. + write_to_bazelrc('build --distinct_host_configuration=false') + # Enable short object file path to avoid long path issue on Windows. + # TODO(pcloudy): Remove this flag when upgrading Bazel to 0.16.0 + # Short object file path will be enabled by default. + write_to_bazelrc('build --experimental_shortened_obj_file_path=true') + + if get_var( + environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline', + True, + ('Would you like to override eigen strong inline for some C++ ' + 'compilation to reduce the compilation time?'), + 'Eigen strong inline overridden.', + 'Not overriding eigen strong inline, ' + 'some compilations could take more than 20 mins.'): + # Due to a known MSVC compiler issue + # https://github.com/tensorflow/tensorflow/issues/10521 + # Overriding eigen strong inline speeds up the compiling of + # conv_grad_ops_3d.cc and conv_ops_3d.cc by 20 minutes, + # but this also hurts the performance. Let users decide what they want. + write_to_bazelrc('build --define=override_eigen_strong_inline=true') def config_info_line(name, help_text): @@ -1428,14 +1452,14 @@ def main(): # environment variables. environ_cp = dict(os.environ) - check_bazel_version('0.10.0') + check_bazel_version('0.15.0') reset_tf_configure_bazelrc(args.workspace) cleanup_makefile() setup_python(environ_cp) if is_windows(): - environ_cp['TF_NEED_S3'] = '0' + environ_cp['TF_NEED_AWS'] = '0' environ_cp['TF_NEED_GCP'] = '0' environ_cp['TF_NEED_HDFS'] = '0' environ_cp['TF_NEED_JEMALLOC'] = '0' @@ -1448,19 +1472,31 @@ def main(): # TODO(ibiryukov): Investigate using clang as a cpu or cuda compiler on # Windows. environ_cp['TF_DOWNLOAD_CLANG'] = '0' + environ_cp['TF_ENABLE_XLA'] = '0' + environ_cp['TF_NEED_GDR'] = '0' + environ_cp['TF_NEED_VERBS'] = '0' + environ_cp['TF_NEED_MPI'] = '0' + environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0' if is_macos(): environ_cp['TF_NEED_JEMALLOC'] = '0' environ_cp['TF_NEED_TENSORRT'] = '0' + # The numpy package on ppc64le uses OpenBLAS which has multi-threading + # issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at + # runtime to allow the Tensorflow testcases which compare numpy + # results to Tensorflow results to succeed. + if is_ppc64le(): + write_action_env_to_bazelrc("OMP_NUM_THREADS", 1) + set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc', 'with_jemalloc', True) set_build_var(environ_cp, 'TF_NEED_GCP', 'Google Cloud Platform', 'with_gcp_support', True, 'gcp') set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System', 'with_hdfs_support', True, 'hdfs') - set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System', - 'with_s3_support', True, 's3') + set_build_var(environ_cp, 'TF_NEED_AWS', 'Amazon AWS Platform', + 'with_aws_support', True, 'aws') set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform', 'with_kafka_support', True, 'kafka') set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', @@ -1524,7 +1560,8 @@ def main(): set_grpc_build_flags() set_cc_opt_flags(environ_cp) set_build_strip_flag() - set_windows_build_flags() + if is_windows(): + set_windows_build_flags(environ_cp) if get_var( environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', @@ -1536,11 +1573,15 @@ def main(): create_android_ndk_rule(environ_cp) create_android_sdk_rule(environ_cp) - print('Preconfigured Bazel build configs. You can use any of the below by ' - 'adding "--config=<>" to your build command. See tools/bazel.rc for ' - 'more details.') - config_info_line('mkl', 'Build with MKL support.') - config_info_line('monolithic', 'Config for mostly static monolithic build.') + # On Windows, we don't have MKL support and the build is always monolithic. + # So no need to print the following message. + # TODO(pcloudy): remove the following if check when they make sense on Windows + if not is_windows(): + print('Preconfigured Bazel build configs. You can use any of the below by ' + 'adding "--config=<>" to your build command. See tools/bazel.rc for ' + 'more details.') + config_info_line('mkl', 'Build with MKL support.') + config_info_line('monolithic', 'Config for mostly static monolithic build.') if __name__ == '__main__': main() diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 6d134dbb80cb8c3dcf15b2ba20783870a67e9a62..f1000c1bffd07d090315f6ebbbd3ea504e710df9 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -20,10 +20,18 @@ load( "tf_additional_binary_deps", ) load( - "//tensorflow/tools/api/generator:api_gen.bzl", + "//tensorflow/python/tools/api/generator:api_gen.bzl", "gen_api_init_files", # @unused ) +# Config setting used when building for products +# which requires restricted licenses to be avoided. +config_setting( + name = "no_lgpl_deps", + values = {"define": "__TENSORFLOW_NO_LGPL_DEPS__=1"}, + visibility = ["//visibility:public"], +) + # Config setting for determining if we are building for Android. config_setting( name = "android", @@ -115,12 +123,6 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "windows_msvc", - values = {"cpu": "x64_windows_msvc"}, - visibility = ["//visibility:public"], -) - config_setting( name = "no_tensorflow_py_deps", define_values = {"no_tensorflow_py_deps": "true"}, @@ -154,6 +156,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "linux_s390x", + values = {"cpu": "s390x"}, + visibility = ["//visibility:public"], +) + config_setting( name = "debug", values = { @@ -210,8 +218,8 @@ config_setting( ) config_setting( - name = "with_s3_support", - define_values = {"with_s3_support": "true"}, + name = "with_aws_support", + define_values = {"with_aws_support": "true"}, visibility = ["//visibility:public"], ) @@ -238,8 +246,8 @@ config_setting( ) config_setting( - name = "with_s3_support_windows_override", - define_values = {"with_s3_support": "true"}, + name = "with_aws_support_windows_override", + define_values = {"with_aws_support": "true"}, values = {"cpu": "x64_windows"}, visibility = ["//visibility:public"], ) @@ -251,6 +259,13 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "with_cuda_support_windows_override", + define_values = {"using_cuda_nvcc": "true"}, + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + config_setting( name = "with_gcp_support_android_override", define_values = {"with_gcp_support": "true"}, @@ -266,8 +281,8 @@ config_setting( ) config_setting( - name = "with_s3_support_android_override", - define_values = {"with_s3_support": "true"}, + name = "with_aws_support_android_override", + define_values = {"with_aws_support": "true"}, values = {"crosstool_top": "//external:android/crosstool"}, visibility = ["//visibility:public"], ) @@ -287,8 +302,8 @@ config_setting( ) config_setting( - name = "with_s3_support_ios_override", - define_values = {"with_s3_support": "true"}, + name = "with_aws_support_ios_override", + define_values = {"with_aws_support": "true"}, values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, visibility = ["//visibility:public"], ) @@ -360,6 +375,15 @@ config_setting( }, ) +# Setting to use when loading kernels dynamically +config_setting( + name = "dynamic_loaded_kernels", + define_values = { + "dynamic_loaded_kernels": "true", + }, + visibility = ["//visibility:public"], +) + config_setting( name = "using_cuda_nvcc", define_values = { @@ -387,17 +411,10 @@ config_setting( visibility = ["//visibility:public"], ) -# TODO(laigd): consider removing this option and make TensorRT enabled -# automatically when CUDA is enabled. -config_setting( - name = "with_tensorrt_support", - values = {"define": "with_tensorrt_support=true"}, - visibility = ["//visibility:public"], -) - package_group( name = "internal", packages = [ + "-//third_party/tensorflow/python/estimator", "//learning/meta_rank/...", "//tensorflow/...", "//tensorflow_fold/llgtm/...", @@ -407,21 +424,32 @@ package_group( load( "//third_party/mkl:build_defs.bzl", - "if_mkl", + "if_mkl_ml", ) filegroup( name = "intel_binary_blob", - data = if_mkl( + data = if_mkl_ml( [ - "//third_party/mkl:intel_binary_blob", + "//third_party/intel_mkl_ml", ], ), ) -filegroup( - name = "docs_src", - data = glob(["docs_src/**/*.md"]), +cc_library( + name = "grpc", + deps = select({ + ":linux_s390x": ["@grpc//:grpc_unsecure"], + "//conditions:default": ["@grpc"], + }), +) + +cc_library( + name = "grpc++", + deps = select({ + ":linux_s390x": ["@grpc//:grpc++_unsecure"], + "//conditions:default": ["@grpc//:grpc++"], + }), ) # A shared object which includes registration mechanisms for ops and @@ -451,6 +479,14 @@ filegroup( tf_cc_shared_object( name = "libtensorflow_framework.so", framework_so = [], + linkopts = select({ + "//tensorflow:darwin": [], + "//tensorflow:windows": [], + "//conditions:default": [ + "-Wl,--version-script", # This line must be directly followed by the version_script.lds file + "$(location //tensorflow:tf_framework_version_script.lds)", + ], + }), linkstatic = 1, visibility = ["//visibility:public"], deps = [ @@ -460,6 +496,7 @@ tf_cc_shared_object( "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl", "//tensorflow/core:lib_internal_impl", "//tensorflow/stream_executor:stream_executor_impl", + "//tensorflow:tf_framework_version_script.lds", ] + tf_additional_binary_deps(), ) @@ -486,7 +523,6 @@ tf_cc_shared_object( "-Wl,-install_name,@rpath/libtensorflow.so", ], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file @@ -511,7 +547,6 @@ tf_cc_shared_object( "$(location //tensorflow:tf_exported_symbols.lds)", ], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file @@ -541,16 +576,26 @@ exports_files( gen_api_init_files( name = "tensorflow_python_api_gen", srcs = ["api_template.__init__.py"], + api_version = 1, root_init_template = "api_template.__init__.py", ) py_library( name = "tensorflow_py", - srcs = [ - ":tensorflow_python_api_gen", - "//tensorflow/python/estimator/api:estimator_python_api_gen", + srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":tensorflow_py_no_contrib", + "//tensorflow/contrib:contrib_py", + "//tensorflow/python/estimator:estimator_py", ], +) + +py_library( + name = "tensorflow_py_no_contrib", + srcs = [":tensorflow_python_api_gen"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = ["//tensorflow/python"], + deps = ["//tensorflow/python:no_contrib"], ) diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 9662d7b478ba61c69edc20b0d47293f9939e7881..779f65d5b17c350833f67f07985b00e8eb561e72 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -20,7 +20,6 @@ from __future__ import print_function # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import -# API IMPORTS PLACEHOLDER try: import os # pylint: disable=g-import-not-at-top @@ -37,6 +36,8 @@ try: except (ImportError, AttributeError): print('tf.estimator package not installed.') +# API IMPORTS PLACEHOLDER + from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') del LazyLoader diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index cb0b093ad260e000dcef9d1123e967a77cf1a041..19ccb6e71d2f3021c1ce5c8905d8a72059c1cfcb 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/eval_const_tensor.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -45,11 +46,13 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" @@ -326,6 +329,7 @@ TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) { } void TF_DeleteBuffer(TF_Buffer* buffer) { + if (buffer == nullptr) return; if (buffer->data_deallocator != nullptr) { (*buffer->data_deallocator)(const_cast(buffer->data), buffer->length); @@ -355,6 +359,7 @@ void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) { void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) { status->status = Status::OK(); + if (s == nullptr) return; delete s->session; delete s; } @@ -390,64 +395,6 @@ void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, status->status = Reset(opt->options, container_names); } -// This traverses the specified nodes in topological order to verify there are -// no cycles. Starting with inputless nodes, it visits nodes whose inputs have -// all been visited, and counts the total number of visited nodes. If there is a -// cycle, nodes in the cycle will never be visited, and the visited count will -// be less than the total node count. -Status ValidateNoCycles(const Graph& g) { - // TODO(nolivia): check this on a subset of the graph instead of all of it. - // A node is ready when all of its inputs have been visited. - std::vector ready; - std::vector pending_count(g.num_node_ids(), 0); - - for (int i = 0; i < g.num_node_ids(); ++i) { - const Node* n = g.FindNodeId(i); - if (n == nullptr) continue; - pending_count[i] = n->in_edges().size(); - if (n->IsMerge()) { - // While-loop cycles are legal cycles so we manually adjust the - // pending_count to make sure that the loop is visited. - for (const Edge* e : n->in_edges()) { - if (!e->IsControlEdge() && e->src()->IsNextIteration()) { - pending_count[i]--; - } - } - } - if (pending_count[i] == 0) { - ready.push_back(n); - } - } - - int processed = 0; - while (!ready.empty()) { - const Node* node = ready.back(); - ready.pop_back(); - ++processed; - - for (const Edge* out : node->out_edges()) { - const int output_id = out->dst()->id(); - pending_count[output_id]--; - if (pending_count[output_id] == 0) { - ready.push_back(out->dst()); - } - } - } - - if (processed < g.num_nodes()) { - std::vector nodes_in_cycle; - for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3; - ++i) { - if (pending_count[i] != 0) { - nodes_in_cycle.push_back(g.FindNodeId(i)->name()); - } - } - return errors::InvalidArgument( - "Graph is invalid, contains a cycle with ", g.num_nodes() - processed, - " nodes, including: ", str_util::Join(nodes_in_cycle, ", ")); - } - return Status::OK(); -} } // namespace } // namespace tensorflow @@ -746,7 +693,9 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { const auto num_nodes = graph.num_node_ids(); if (session->last_num_graph_nodes < num_nodes) { - status->status = tensorflow::ValidateNoCycles(session->graph->graph); + // TODO(nolivia): check this on a subset of the graph instead of all of + // it. + status->status = graph::ValidateGraphHasNoCycle(session->graph->graph); if (!status->status.ok()) { session->graph->mu.unlock(); return false; @@ -961,6 +910,7 @@ TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; } void TF_DeleteLibraryHandle(TF_Library* lib_handle) { + if (lib_handle == nullptr) return; tensorflow::port::Free(const_cast(lib_handle->op_list.data)); delete lib_handle; } @@ -1018,6 +968,7 @@ TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr); TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(), nullptr); TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1); +TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0); #undef TF_DEVICELIST_METHOD @@ -1907,6 +1858,7 @@ TF_Graph::TF_Graph() TF_Graph* TF_NewGraph() { return new TF_Graph; } void TF_DeleteGraph(TF_Graph* g) { + if (g == nullptr) return; g->mu.lock(); g->delete_requested = true; const bool del = g->sessions.empty(); @@ -2123,7 +2075,8 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( TF_Graph* graph, const TF_Buffer* graph_def, const TF_ImportGraphDefOptions* options, TF_Status* status) { GraphDef def; - if (!def.ParseFromArray(graph_def->data, graph_def->length)) { + if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, + graph_def->length)) { status->status = InvalidArgument("Invalid GraphDef"); return nullptr; } @@ -2153,7 +2106,8 @@ void TF_GraphImportGraphDefWithReturnOutputs( return; } GraphDef def; - if (!def.ParseFromArray(graph_def->data, graph_def->length)) { + if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, + graph_def->length)) { status->status = InvalidArgument("Invalid GraphDef"); return; } @@ -2436,6 +2390,12 @@ void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); } void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, TF_Output* dx, TF_Status* status, TF_Output* dy) { + TF_AddGradientsWithPrefix(g, nullptr, y, ny, x, nx, dx, status, dy); +} + +void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y, + int ny, TF_Output* x, int nx, TF_Output* dx, + TF_Status* status, TF_Output* dy) { #ifdef __ANDROID__ status->status = tensorflow::errors::Unimplemented( "Adding gradients is not supported in Android. File a bug at " @@ -2452,9 +2412,29 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, const int first_new_node_id = g->graph.num_node_ids(); + string prefix_cmp; + const char* child_scope_name; + if (prefix == nullptr) { + child_scope_name = "gradients"; + } else { + prefix_cmp = string(prefix) + "/"; + // The operation should fail if the provided name prefix has already been + // used in this graph + for (const auto& pair : g->name_map) { + const string& name = pair.first; + if (name.compare(prefix) == 0 || + tensorflow::str_util::StartsWith(name, prefix_cmp)) { + status->status = InvalidArgument( + "prefix [", prefix, + "] conflicts with existing node in the graph named [", name, "]"); + return; + } + } + child_scope_name = prefix; + } tensorflow::Scope scope = NewInternalScope(&g->graph, &status->status, &g->refiner) - .NewSubScope("gradients"); + .NewSubScope(child_scope_name); if (dx != nullptr) { std::vector dx_arg = OutputsFromTFOutputs(dx, ny); @@ -2469,7 +2449,30 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) { Node* n = g->graph.FindNodeId(i); if (n == nullptr) continue; - g->name_map[n->name()] = n; + + // Adding the gradients to the graph can alter the prefix to prevent + // name collisions only if this prefix has not been provided explicitly + // by the user. If it was provided, assert that it remained intact. + if (prefix != nullptr && + !tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) { + status->status = tensorflow::errors::Internal( + "BUG: The gradients prefix have been unexpectedly altered when " + "adding the nodes to the graph. This is a bug. Please file an " + "issue at https://github.com/tensorflow/tensorflow/issues."); + return; + } + // We have a convoluted scheme here: Using the C++ graph construction API + // to add potentially many nodes to the graph without running the checks + // (such as uniqueness of the names of nodes) we run with other functions + // that add a node to the graph (like TF_FinishOperation). + if (!g->name_map.insert(std::make_pair(n->name(), n)).second) { + status->status = tensorflow::errors::Internal( + "BUG: The API allowed construction of a graph with duplicate node " + "names (", + n->name(), + "). This is a bug. Please file an issue at " + "https://github.com/tensorflow/tensorflow/issues."); + } } } @@ -2569,6 +2572,7 @@ void TF_CloseSession(TF_Session* s, TF_Status* status) { void TF_DeleteSession(TF_Session* s, TF_Status* status) { status->status = Status::OK(); + if (s == nullptr) return; TF_Graph* const graph = s->graph; if (graph != nullptr) { graph->mu.lock(); @@ -2767,7 +2771,34 @@ TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name, TF_Buffer* ret = TF_NewBuffer(); status->status = MessageToBuffer(*api_def, ret); + if (!status->status.ok()) { + TF_DeleteBuffer(ret); + return nullptr; + } return ret; #endif // __ANDROID__ } + +TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) { + tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels(); + TF_Buffer* ret = TF_NewBuffer(); + status->status = MessageToBuffer(kernel_list, ret); + if (!status->status.ok()) { + TF_DeleteBuffer(ret); + return nullptr; + } + return ret; +} + +TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) { + tensorflow::KernelList kernel_list = + tensorflow::GetRegisteredKernelsForOp(name); + TF_Buffer* ret = TF_NewBuffer(); + status->status = MessageToBuffer(kernel_list, ret); + if (!status->status.ok()) { + TF_DeleteBuffer(ret); + return nullptr; + } + return ret; +} } // end extern "C" diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index c8594347451dffd465d7fa926cc53818dc9e38d4..850f6ecd637d768bca99720e0add07680829e17a 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -44,6 +44,7 @@ limitations under the License. // * size_t is used to represent byte sizes of objects that are // materialized in the address space of the calling process. // * int is used as an index into arrays. +// * Deletion functions are safe to call on nullptr. // // Questions left to address: // * Might at some point need a way for callers to provide their own Env. @@ -894,7 +895,8 @@ TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( TF_ImportGraphDefOptions* opts); // Set the prefix to be prepended to the names of nodes in `graph_def` that will -// be imported into `graph`. +// be imported into `graph`. `prefix` is copied and has no lifetime +// requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix( TF_ImportGraphDefOptions* opts, const char* prefix); @@ -915,6 +917,7 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyPrefix( // Set any imported nodes with input `src_name:src_index` to have that input // replaced with `dst`. `src_name` refers to a node in the graph to be imported, // `dst` references a node already existing in the graph being imported into. +// `src_name` is copied and has no lifetime requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping( TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, TF_Output dst); @@ -922,7 +925,7 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping( // Set any imported nodes with control input `src_name` to have that input // replaced with `dst`. `src_name` refers to a node in the graph to be imported, // `dst` references an operation already existing in the graph being imported -// into. +// into. `src_name` is copied and has no lifetime requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsRemapControlDependency( TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst); @@ -934,6 +937,7 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddControlDependency( // Add an output in `graph_def` to be returned via the `return_outputs` output // parameter of TF_GraphImportGraphDef(). If the output is remapped via an input // mapping, the corresponding existing tensor in `graph` will be returned. +// `oper_name` is copied and has no lifetime requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOutput( TF_ImportGraphDefOptions* opts, const char* oper_name, int index); @@ -943,7 +947,8 @@ TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOutputs( const TF_ImportGraphDefOptions* opts); // Add an operation in `graph_def` to be returned via the `return_opers` output -// parameter of TF_GraphImportGraphDef(). +// parameter of TF_GraphImportGraphDef(). `oper_name` is copied and has no +// lifetime requirements. TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOperation( TF_ImportGraphDefOptions* opts, const char* oper_name); @@ -1126,6 +1131,7 @@ TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params); // Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s, // i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2... +// // `dx` are used as initial gradients (which represent the symbolic partial // derivatives of some loss function `L` w.r.t. `y`). // `dx` must be nullptr or have size `ny`. @@ -1134,6 +1140,12 @@ TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params); // The partial derivatives are returned in `dy`. `dy` should be allocated to // size `nx`. // +// Gradient nodes are automatically named under the "gradients/" prefix. To +// guarantee name uniqueness, subsequent calls to the same graph will +// append an incremental tag to the prefix: "gradients_1/", "gradients_2/", ... +// See TF_AddGradientsWithPrefix, which provides a means to specify a custom +// name prefix for operations added to a graph to compute the gradients. +// // WARNING: This function does not yet support all the gradients that python // supports. See // https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md @@ -1142,6 +1154,33 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, TF_Output* dx, TF_Status* status, TF_Output* dy); +// Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s, +// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2... +// This is a variant of TF_AddGradients that allows to caller to pass a custom +// name prefix to the operations added to a graph to compute the gradients. +// +// `dx` are used as initial gradients (which represent the symbolic partial +// derivatives of some loss function `L` w.r.t. `y`). +// `dx` must be nullptr or have size `ny`. +// If `dx` is nullptr, the implementation will use dx of `OnesLike` for all +// shapes in `y`. +// The partial derivatives are returned in `dy`. `dy` should be allocated to +// size `nx`. +// `prefix` names the scope into which all gradients operations are being added. +// `prefix` must be unique within the provided graph otherwise this operation +// will fail. If `prefix` is nullptr, the default prefixing behaviour takes +// place, see TF_AddGradients for more details. +// +// WARNING: This function does not yet support all the gradients that python +// supports. See +// https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md +// for instructions on how to add C++ more gradients. +TF_CAPI_EXPORT void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, + TF_Output* y, int ny, + TF_Output* x, int nx, + TF_Output* dx, TF_Status* status, + TF_Output* dy); + // Create a TF_Function from a TF_Graph // // Params: @@ -1231,6 +1270,11 @@ TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction( int noutputs, const TF_Output* outputs, const char* const* output_names, const TF_FunctionOptions* opts, const char* description, TF_Status* status); +// Returns the name of the graph function. +// The return value points to memory that is only usable until the next +// mutation to *func. +TF_CAPI_EXPORT extern const char* TF_FunctionName(TF_Function* func); + // Write out a serialized representation of `func` (as a FunctionDef protocol // message) to `output_func_def` (allocated by TF_NewBuffer()). // `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer() @@ -1517,6 +1561,13 @@ TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list, TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes( const TF_DeviceList* list, int index, TF_Status* status); +// Retrieve the incarnation number of a given device. +// +// If index is out of bounds, an error code will be set in the status object, +// and 0 will be returned. +TF_CAPI_EXPORT extern uint64_t TF_DeviceListIncarnation( + const TF_DeviceList* list, int index, TF_Status* status); + // -------------------------------------------------------------------------- // Load plugins containing custom ops and kernels @@ -1599,6 +1650,18 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, size_t name_len, TF_Status* status); +// -------------------------------------------------------------------------- +// Kernel definition information. + +// Returns a serialized KernelList protocol buffer containing KernelDefs for all +// registered kernels. +TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status); + +// Returns a serialized KernelList protocol buffer containing KernelDefs for all +// kernels registered for the operation named `name`. +TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp( + const char* name, TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 95b04f9058afdfaadbc24f0238860279fcd3e800..69b3ffe2a1f620e346405607ecf742fb863aa644 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -57,6 +57,45 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { } } +TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, + unsigned char gpu_memory_allow_growth) { + tensorflow::ConfigProto config; + auto* optimizer_options = + config.mutable_graph_options()->mutable_optimizer_options(); + if (enable_xla_compilation) { + optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1); + + // These XLA flags are needed to trigger XLA properly from C (more generally + // non-Python) clients. If this API is called again with `enable` set to + // false, it is safe to keep these flag values as is. + tensorflow::legacy_flags::MarkForCompilationPassFlags* flags = + tensorflow::legacy_flags::GetMarkForCompilationPassFlags(); + flags->tf_xla_cpu_global_jit = true; + flags->tf_xla_min_cluster_size = 1; + } else { + optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF); + } + + auto* gpu_options = config.mutable_gpu_options(); + gpu_options->set_allow_growth(gpu_memory_allow_growth); + + TF_Buffer* ret = TF_NewBuffer(); + TF_CHECK_OK(MessageToBuffer(config, ret)); + return ret; +} + +TF_Buffer* TF_CreateRunOptions(unsigned char enable_full_trace) { + tensorflow::RunOptions options; + if (enable_full_trace) { + options.set_trace_level(tensorflow::RunOptions::FULL_TRACE); + } else { + options.set_trace_level(tensorflow::RunOptions::NO_TRACE); + } + TF_Buffer* ret = TF_NewBuffer(); + TF_CHECK_OK(MessageToBuffer(options, ret)); + return ret; +} + const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) { tensorflow::mutex_lock c(graph->mu); const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString(); diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 20bdace40f1272ded06e710034053a7610326e7f..6617c5a572e90e78369f73d714f39942f213040f 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -55,11 +55,27 @@ extern "C" { // set XLA flag values to prepare for XLA compilation. Otherwise set // global_jit_level to OFF. // -// This API is syntax sugar over TF_SetConfig(), and is used by clients that -// cannot read/write the tensorflow.ConfigProto proto. +// This and the next API are syntax sugar over TF_SetConfig(), and is used by +// clients that cannot read/write the tensorflow.ConfigProto proto. +// TODO: Migrate to TF_CreateConfig() below. TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable); +// Create a serialized tensorflow.ConfigProto proto, where: +// +// a) ConfigProto.optimizer_options.global_jit_level is set to to ON_1 if +// `enable_xla_compilation` is non-zero, and OFF otherwise. +// b) ConfigProto.gpu_options.allow_growth is set to `gpu_memory_allow_growth`. +TF_CAPI_EXPORT extern TF_Buffer* TF_CreateConfig( + unsigned char enable_xla_compilation, + unsigned char gpu_memory_allow_growth); + +// Create a serialized tensorflow.RunOptions proto, where RunOptions.trace_level +// is set to FULL_TRACE if `enable_full_trace` is non-zero, and NO_TRACE +// otherwise. +TF_CAPI_EXPORT extern TF_Buffer* TF_CreateRunOptions( + unsigned char enable_full_trace); + // Returns the graph content in a human-readable format, with length set in // `len`. The format is subject to change in the future. // The returned string is heap-allocated, and caller should call free() on it. diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 384e6c8cb97022264c5327da5ca5861057608fbe..a2c5a42c11361779de61b515e0f08dcc45e609b9 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -536,6 +536,10 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, return tf_function; } +const char* TF_FunctionName(TF_Function* func) { + return func->fdef.signature().name().c_str(); +} + void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func, const TF_Function* grad, TF_Status* status) { if (func == nullptr) { diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 610274696f5940c063e68f2310cfd9cc1e0bd964..73fe73769bc1219ce865149d67d333c53371ccc5 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -193,6 +193,7 @@ class CApiFunctionTest : public ::testing::Test { ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); ASSERT_NE(func_, nullptr); + ASSERT_EQ(std::string(func_name_), std::string(TF_FunctionName(func_))); TF_GraphCopyFunction(host_graph_, func_, nullptr, s_); ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); } @@ -1516,7 +1517,8 @@ void DefineStatefulFunction(const char* name, TF_Function** func) { TF_Output inputs[] = {}; TF_Output outputs[] = {{random, 0}}; - *func = TF_GraphToFunction(func_graph.get(), name, /*append_hash=*/false, -1, + *func = TF_GraphToFunction(func_graph.get(), name, + /*append_hash_to_fn_name=*/false, -1, /*opers=*/nullptr, 0, inputs, 1, outputs, /*output_names=*/nullptr, /*opts=*/nullptr, "", s.get()); @@ -1617,5 +1619,66 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) { TF_DeleteFunction(func1); } +// This test only works when the TF build includes XLA compiler. One way to set +// this up is via bazel build option "--define with_xla_support=true". +// +// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to +// something like TENSORFLOW_CAPI_USE_XLA. +#ifdef TENSORFLOW_EAGER_USE_XLA +TEST_F(CApiFunctionTest, StatelessIf_XLA) { + TF_Function* func; + const std::string funcName = "BranchFunc"; + DefineFunction(funcName.c_str(), &func); + TF_GraphCopyFunction(host_graph_, func, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* feed = Placeholder(host_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* true_cond = ScalarConst(true, host_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_OperationDescription* desc = + TF_NewOperation(host_graph_, "StatelessIf", "IfNode"); + TF_AddInput(desc, {true_cond, 0}); + TF_Output inputs[] = {{feed, 0}}; + TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs)); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_SetAttrType(desc, "Tcond", TF_BOOL); + TF_DataType inputType = TF_INT32; + TF_SetAttrTypeList(desc, "Tin", &inputType, 1); + TF_SetAttrTypeList(desc, "Tout", &inputType, 1); + TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size()); + TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size()); + TF_SetDevice(desc, "/device:XLA_CPU:0"); + auto op = TF_FinishOperation(desc, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + ASSERT_NE(op, nullptr); + + // Create a session for this graph. + CSession csession(host_graph_, s_, /*use_XLA*/ true); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Run the graph. + csession.SetInputs({{feed, Int32Tensor(17)}}); + csession.SetOutputs({op}); + csession.Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Tensor* out = csession.output_tensor(0); + ASSERT_TRUE(out != nullptr); + EXPECT_EQ(TF_INT32, TF_TensorType(out)); + EXPECT_EQ(0, TF_NumDims(out)); // scalar + ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); + int32* output_contents = static_cast(TF_TensorData(out)); + EXPECT_EQ(-17, *output_contents); + + // Clean up + csession.CloseAndDelete(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_DeleteFunction(func); +} +#endif // TENSORFLOW_EAGER_USE_XLA + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 577f10c5e69ea9ecbe8ce821c6bd5167e98bef25..aa2a537f03be31ae45ff3d6f7815b449d661cf9c 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -29,9 +29,11 @@ limitations under the License. #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/graph.pb_text.h" +#include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def.pb_text.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -1160,7 +1162,7 @@ TEST(CAPI, GetOpDef) { } void StringVectorToArrays(const std::vector& v, - std::unique_ptr* ptrs, + std::unique_ptr* ptrs, std::unique_ptr* lens) { ptrs->reset(new const void*[v.size()]); lens->reset(new size_t[v.size()]); @@ -1196,7 +1198,7 @@ class CApiColocationTest : public ::testing::Test { void SetViaStringList(TF_OperationDescription* desc, const std::vector& list) { - std::unique_ptr list_ptrs; + std::unique_ptr list_ptrs; std::unique_ptr list_lens; StringVectorToArrays(list, &list_ptrs, &list_lens); TF_SetAttrStringList(desc, tensorflow::kColocationAttrName, list_ptrs.get(), @@ -1424,6 +1426,29 @@ TEST(CAPI, SavedModelNullArgsAreValid) { TF_DeleteStatus(s); } +TEST(CAPI, DeletingNullPointerIsSafe) { + TF_Status* status = TF_NewStatus(); + + TF_DeleteStatus(nullptr); + TF_DeleteBuffer(nullptr); + TF_DeleteTensor(nullptr); + TF_DeleteSessionOptions(nullptr); + TF_DeleteGraph(nullptr); + TF_DeleteImportGraphDefOptions(nullptr); + TF_DeleteImportGraphDefResults(nullptr); + TF_DeleteFunction(nullptr); + TF_DeleteSession(nullptr, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeletePRunHandle(nullptr); + TF_DeleteDeprecatedSession(nullptr, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteDeviceList(nullptr); + TF_DeleteLibraryHandle(nullptr); + TF_DeleteApiDefMap(nullptr); + + TF_DeleteStatus(status); +} + REGISTER_OP("TestOpWithNoGradient") .Input("x: T") .Output("y: T") @@ -1458,8 +1483,8 @@ class CApiGradientsTest : public ::testing::Test { BuildSuccessGraph(inputs, outputs); BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs); - AddGradients(grad_inputs_provided, inputs, 2, outputs, 1, grad_outputs); - + AddGradients(grad_inputs_provided, nullptr, inputs, 2, outputs, 1, + grad_outputs); EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); // Compare that the graphs match. @@ -1480,7 +1505,8 @@ class CApiGradientsTest : public ::testing::Test { BuildErrorGraph(inputs, outputs); - AddGradients(grad_inputs_provided, inputs, 1, outputs, 1, grad_outputs); + AddGradients(grad_inputs_provided, nullptr, inputs, 1, outputs, 1, + grad_outputs); string expected_msg = "No gradient defined for op: TestOpWithNoGradient. Please see " @@ -1524,19 +1550,20 @@ class CApiGradientsTest : public ::testing::Test { EXPECT_EQ(*a_data, *b_data); } - void AddGradients(bool grad_inputs_provided, TF_Output* inputs, int ninputs, - TF_Output* outputs, int noutputs, TF_Output* grad_outputs) { + void AddGradients(bool grad_inputs_provided, const char* prefix, + TF_Output* inputs, int ninputs, TF_Output* outputs, + int noutputs, TF_Output* grad_outputs) { if (grad_inputs_provided) { TF_Output grad_inputs[1]; const float grad_inputs_val[] = {1.0, 1.0, 1.0, 1.0}; TF_Operation* grad_inputs_op = FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs"); grad_inputs[0] = TF_Output{grad_inputs_op, 0}; - TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, grad_inputs, - s_, grad_outputs); + TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs, + ninputs, grad_inputs, s_, grad_outputs); } else { - TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, nullptr, s_, - grad_outputs); + TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs, + ninputs, nullptr, s_, grad_outputs); } } @@ -1681,6 +1708,20 @@ class CApiGradientsTest : public ::testing::Test { return op; } + void BuildGraphAndAddGradientsWithPrefixes(const char* prefix1, + const char* prefix2 = nullptr) { + TF_Output inputs[2]; + TF_Output outputs[1]; + TF_Output grad_outputs[2]; + + BuildSuccessGraph(inputs, outputs); + + AddGradients(false, prefix1, inputs, 2, outputs, 1, grad_outputs); + if (prefix2 != nullptr) { + AddGradients(false, prefix2, inputs, 2, outputs, 1, grad_outputs); + } + } + TF_Status* s_; TF_Graph* graph_; TF_Graph* expected_graph_; @@ -1700,6 +1741,111 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) { TestGradientsError(false); } +TEST_F(CApiGradientsTest, GradientsPrefix_PrefixIsOk) { + BuildGraphAndAddGradientsWithPrefixes("gradients"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithDistinctPrefixes) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients_1"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInSameScope) { + BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope/gradients_1"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInDifferentScopes) { + BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope_1/gradients"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsSubScopeOf1st) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/sub"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_PrefixMatchesExistingNodeName) { + BuildGraphAndAddGradientsWithPrefixes("Const_0"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithIdenticalPrefixes) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsMatchingNodeOf1st) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/MatMul"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_1stGradientsMatchingNodeOf2nd) { + BuildGraphAndAddGradientsWithPrefixes("gradients/MatMul", "gradients"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsParentScopeOf1st) { + BuildGraphAndAddGradientsWithPrefixes("gradients/sub", "gradients"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +void ScalarFloatFromTensor(const TF_Tensor* t, float* f) { + ASSERT_TRUE(t != nullptr); + ASSERT_EQ(TF_FLOAT, TF_TensorType(t)); + ASSERT_EQ(0, TF_NumDims(t)); + ASSERT_EQ(4, TF_TensorByteSize(t)); + float* p = static_cast(TF_TensorData(t)); + *f = *p; +} + +TEST_F(CApiGradientsTest, MultipleCallsToAddGradients) { + const float X = 3.0f, Y = 7.0f; + TF_Operation* x = Placeholder(graph_, s_, "x", TF_FLOAT); + TF_Operation* y = Placeholder(graph_, s_, "y", TF_FLOAT); + TF_Operation* xy = Mul(x, y, graph_, s_, "xy"); + TF_Output dxy_dx, dxy_dy; + + TF_Output outputs[1] = {{xy, 0}}; + TF_Output inputs[1] = {{x, 0}}; + TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dx); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + inputs[0] = {y, 0}; + TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dy); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_SessionOptions* opts = TF_NewSessionOptions(); + TF_Session* sess = TF_NewSession(graph_, opts, s_); + TF_DeleteSessionOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Output feeds[] = {{x, 0}, {y, 0}}; + TF_Tensor* feedValues[] = {FloatTensor(X), FloatTensor(Y)}; + TF_Output fetches[] = {dxy_dx, dxy_dy}; + TF_Tensor* fetchValues[] = {nullptr, nullptr}; + + TF_SessionRun(sess, nullptr /* run_options */, feeds, feedValues, 2, fetches, + fetchValues, 2, nullptr /* target_opers */, 0, + nullptr /* run_metadata */, s_); + TF_DeleteTensor(feedValues[0]); + TF_DeleteTensor(feedValues[1]); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_DeleteSession(sess, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + float dxy_dxValue = 0.0f, dxy_dyValue = 0.0f; + ScalarFloatFromTensor(fetchValues[0], &dxy_dxValue); + EXPECT_EQ(Y, dxy_dxValue); + + ScalarFloatFromTensor(fetchValues[1], &dxy_dyValue); + EXPECT_EQ(X, dxy_dyValue); + + TF_DeleteTensor(fetchValues[0]); + TF_DeleteTensor(fetchValues[1]); +} + // REGISTER_OP for CApiAttributesTest test cases. // Registers two ops, each with a single attribute called 'v'. // The attribute in one op will have a type 'type', the other @@ -1784,7 +1930,7 @@ TEST_F(CApiAttributesTest, String) { TEST_F(CApiAttributesTest, StringList) { std::vector list = {"bugs", "bunny", "duck"}; - std::unique_ptr list_ptrs; + std::unique_ptr list_ptrs; std::unique_ptr list_lens; StringVectorToArrays(list, &list_ptrs, &list_lens); int list_total_size = 0; @@ -1800,7 +1946,7 @@ TEST_F(CApiAttributesTest, StringList) { ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); EXPECT_TF_META("v", list.size(), TF_ATTR_STRING, list_total_size); - std::unique_ptr values(new void*[list.size()]); + std::unique_ptr values(new void*[list.size()]); std::unique_ptr lens(new size_t[list.size()]); std::unique_ptr storage(new char[list_total_size]); TF_OperationGetAttrStringList(oper, "v", values.get(), lens.get(), @@ -2025,7 +2171,7 @@ TEST_F(CApiAttributesTest, TensorShapeProtoList) { tensorflow::PartialTensorShape(pts2).AsProto(&proto); proto.SerializeToString(&bytes2); - std::unique_ptr list_ptrs; + std::unique_ptr list_ptrs; std::unique_ptr list_lens; const std::vector list = {bytes1, bytes2}; StringVectorToArrays(list, &list_ptrs, &list_lens); @@ -2257,6 +2403,57 @@ TEST(TestApiDef, TestCreateApiDefWithOverwrites) { TF_DeleteLibraryHandle(lib); } +class DummyKernel : public tensorflow::OpKernel { + public: + explicit DummyKernel(tensorflow::OpKernelConstruction* context) + : OpKernel(context) {} + void Compute(tensorflow::OpKernelContext* context) override {} +}; + +// Test we can query kernels +REGISTER_OP("TestOpWithSingleKernel") + .Input("a: float") + .Input("b: float") + .Output("o: float"); +REGISTER_KERNEL_BUILDER( + Name("TestOpWithSingleKernel").Device(tensorflow::DEVICE_CPU), DummyKernel); + +TEST(TestKernel, TestGetAllRegisteredKernels) { + TF_Status* status = TF_NewStatus(); + TF_Buffer* kernel_list_buf = TF_GetAllRegisteredKernels(status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + KernelList kernel_list; + kernel_list.ParseFromArray(kernel_list_buf->data, kernel_list_buf->length); + ASSERT_GT(kernel_list.kernel_size(), 0); + TF_DeleteBuffer(kernel_list_buf); + TF_DeleteStatus(status); +} + +TEST(TestKernel, TestGetRegisteredKernelsForOp) { + TF_Status* status = TF_NewStatus(); + TF_Buffer* kernel_list_buf = + TF_GetRegisteredKernelsForOp("TestOpWithSingleKernel", status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + KernelList kernel_list; + kernel_list.ParseFromArray(kernel_list_buf->data, kernel_list_buf->length); + ASSERT_EQ(kernel_list.kernel_size(), 1); + EXPECT_EQ(kernel_list.kernel(0).op(), "TestOpWithSingleKernel"); + EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU"); + TF_DeleteBuffer(kernel_list_buf); + TF_DeleteStatus(status); +} + +TEST(TestKernel, TestGetRegisteredKernelsForOpNoKernels) { + TF_Status* status = TF_NewStatus(); + TF_Buffer* kernel_list_buf = TF_GetRegisteredKernelsForOp("Unknown", status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + KernelList kernel_list; + kernel_list.ParseFromArray(kernel_list_buf->data, kernel_list_buf->length); + ASSERT_EQ(kernel_list.kernel_size(), 0); + TF_DeleteBuffer(kernel_list_buf); + TF_DeleteStatus(status); +} + #undef EXPECT_TF_META } // namespace diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index f3b28c1708129d39e451d927a89c0d10e2193b63..f15d9ee20adb31a0b76e2cd0d1e67f17a9deff05 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -26,6 +26,10 @@ limitations under the License. using tensorflow::GraphDef; using tensorflow::NodeDef; +static void BoolDeallocator(void* data, size_t, void* arg) { + delete[] static_cast(data); +} + static void Int32Deallocator(void* data, size_t, void* arg) { delete[] static_cast(data); } @@ -38,6 +42,14 @@ static void FloatDeallocator(void* data, size_t, void* arg) { delete[] static_cast(data); } +TF_Tensor* BoolTensor(bool v) { + const int num_bytes = sizeof(bool); + bool* values = new bool[1]; + values[0] = v; + return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator, + nullptr); +} + TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { int64_t num_values = 1; for (int i = 0; i < num_dims; ++i) { @@ -131,6 +143,12 @@ TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, return op; } +TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s, + const char* name) { + unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor); + return Const(tensor.get(), graph, s, name); +} + TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, const char* name) { unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); @@ -216,6 +234,13 @@ TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, return MinWithDevice(l, r, graph, /*op_device=*/"", s, name); } +TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name) { + TF_Operation* op; + BinaryOpHelper("Mul", l, r, graph, s, name, &op, "", true); + return op; +} + TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, const char* name) { TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index c16aba666ee6974fed5351c2d9ac291dcbcdecab..7eeb1ee5e17ad7e5644f8bc8a18ca967b108475d 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -31,6 +31,8 @@ using ::tensorflow::string; typedef std::unique_ptr unique_tensor_ptr; +TF_Tensor* BoolTensor(int32_t v); + // Create a tensor with values of type TF_INT8 provided by `values`. TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values); @@ -55,6 +57,9 @@ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name = "const"); +TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar"); + TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, const char* name = "scalar"); @@ -80,6 +85,9 @@ TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name = "min"); +TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name = "mul"); + // If `op_device` is non-empty, set the created op on that device. TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph, const string& op_device, TF_Status* s, diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index f265da2c2c89c0e9caf14f2213c606fcb69997e0..37be52f57d865c1e59611540d5dab04b59e89444 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -54,7 +54,6 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", - "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", @@ -93,10 +92,10 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/eager:remote_tensor_handle", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", - "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", ], ) @@ -122,6 +121,7 @@ tf_cuda_library( tf_cuda_cc_test( name = "c_api_test", + size = "small", srcs = [ "c_api_debug_test.cc", "c_api_test.cc", @@ -139,7 +139,7 @@ tf_cuda_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 81221c4078bec9820ee187efdf0314da378be62b..dfb1c9a37644c726e1eabab775593596d5b556b9 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -36,9 +36,9 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/node_def_util.h" @@ -46,10 +46,12 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -107,7 +109,8 @@ tensorflow::Status GetAllRemoteDevices( } tensorflow::Status CreateRemoteContexts( - const std::vector& remote_workers, + const std::vector& remote_workers, int64 rendezvous_id, + int keep_alive_secs, const tensorflow::ServerDef& server_def, tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, tensorflow::gtl::FlatMap* remote_contexts) { for (int i = 0; i < remote_workers.size(); i++) { @@ -115,15 +118,18 @@ tensorflow::Status CreateRemoteContexts( tensorflow::eager::CreateContextRequest request; tensorflow::eager::CreateContextResponse response; + request.set_rendezvous_id(rendezvous_id); tensorflow::DeviceNameUtils::ParsedName parsed_name; if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker, &parsed_name)) { return tensorflow::errors::InvalidArgument( "Unable to parse ", remote_worker, " as a device name"); } + *request.mutable_server_def() = server_def; request.mutable_server_def()->set_job_name(parsed_name.job); request.mutable_server_def()->set_task_index(parsed_name.task); request.set_async(async); + request.set_keep_alive_secs(keep_alive_secs); auto* eager_client = remote_eager_workers->GetClient(remote_worker); if (eager_client == nullptr) { return tensorflow::errors::Internal( @@ -145,48 +151,86 @@ tensorflow::Status CreateRemoteContexts( return tensorflow::Status::OK(); } -tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, - TFE_Context** ctx) { - string worker_name = tensorflow::strings::StrCat( - "/job:", opts->server_def.job_name(), - "/replica:0/task:", opts->server_def.task_index()); - std::unique_ptr server; - TF_RETURN_IF_ERROR( - tensorflow::eager::EagerGrpcServer::Create(opts->server_def, &server)); +tensorflow::Status UpdateTFE_ContextWithServerDef( + int keep_alive_secs, const tensorflow::ServerDef& server_def, + TFE_Context* ctx) { + // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the + // server object (which currently CHECK-fails) and we miss the error, instead, + // we log the error, and then return to allow the user to see the error + // message. +#define LOG_AND_RETURN_IF_ERROR(...) \ + do { \ + const ::tensorflow::Status _status = (__VA_ARGS__); \ + if (TF_PREDICT_FALSE(!_status.ok())) { \ + LOG(ERROR) << _status.error_message(); \ + return _status; \ + } \ + } while (0); + + string worker_name = + tensorflow::strings::StrCat("/job:", server_def.job_name(), + "/replica:0/task:", server_def.task_index()); + + std::unique_ptr server; + LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &server)); + + tensorflow::GrpcServer* grpc_server = + dynamic_cast(server.get()); + if (grpc_server == nullptr) { + LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal( + "Currently, TFE_NewContext only supports tensorflow::GrpcServer.")); + } - TF_RETURN_IF_ERROR(server->Start()); + LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); + + int64 rendezvous_id = tensorflow::random::New64(); std::vector remote_workers; - server->master_env()->worker_cache->ListWorkers(&remote_workers); + grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers); remote_workers.erase( std::remove(remote_workers.begin(), remote_workers.end(), worker_name), remote_workers.end()); std::unique_ptr remote_device_mgr; - TF_RETURN_IF_ERROR(GetAllRemoteDevices( - remote_workers, server->master_env()->worker_cache, &remote_device_mgr)); + LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices( + remote_workers, grpc_server->master_env()->worker_cache, + &remote_device_mgr)); std::shared_ptr channel_cache = - server->channel_cache(); + grpc_server->channel_cache(); std::unique_ptr remote_eager_workers( tensorflow::eager::NewGrpcEagerClientCache(channel_cache)); // Initialize remote eager workers. tensorflow::gtl::FlatMap remote_contexts; - TF_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers, - remote_eager_workers.get(), - opts->async, &remote_contexts)); + LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( + remote_workers, rendezvous_id, keep_alive_secs, server_def, + remote_eager_workers.get(), ctx->context.Async(), &remote_contexts)); tensorflow::RemoteRendezvous* r = - server->worker_env()->rendezvous_mgr->Find(0); + grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); + + auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id); + TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession( + session_name, server_def, true)); + + std::shared_ptr worker_session; + TF_RETURN_IF_ERROR( + grpc_server->worker_env()->session_mgr->WorkerSessionForSession( + session_name, &worker_session)); - auto* device_mgr = server->worker_env()->device_mgr; - *ctx = new TFE_Context(opts->session_options.options, opts->policy, - opts->async, device_mgr, r, std::move(server), - std::move(remote_eager_workers), - std::move(remote_device_mgr), remote_contexts); + // Initialize remote tensor communication based on worker session. + TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); + + auto* device_mgr = grpc_server->worker_env()->device_mgr; + + ctx->context.InitializeRemote(std::move(server), + std::move(remote_eager_workers), + std::move(remote_device_mgr), remote_contexts, + r, device_mgr, keep_alive_secs); return tensorflow::Status::OK(); +#undef LOG_AND_RETURN_IF_ERROR } } // namespace @@ -208,15 +252,6 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( options->policy = policy; } -TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( - TFE_ContextOptions* options, const void* proto, size_t proto_len, - TF_Status* status) { - if (!options->server_def.ParseFromArray(proto, proto_len)) { - status->status = tensorflow::errors::InvalidArgument( - "Invalid tensorflow.ServerDef protocol buffer"); - } -} - TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, unsigned char async, TF_Status* status) { @@ -226,12 +261,6 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { - if (!opts->server_def.job_name().empty()) { - TFE_Context* ctx = nullptr; - status->status = NewRemoteAwareTFE_Context(opts, &ctx); - return ctx; - } - std::vector devices; status->status = tensorflow::DeviceFactory::AddDevices( opts->session_options.options, "/job:localhost/replica:0/task:0", @@ -247,7 +276,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { opts->async, std::move(device_mgr), r); } -void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { delete ctx; } +void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; } TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { TF_DeviceList* list = new TF_DeviceList; @@ -260,6 +289,22 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); } +// Set server_def on the context, possibly updating it. +TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, + int keep_alive_secs, + const void* proto, + size_t proto_len, + TF_Status* status) { + tensorflow::ServerDef server_def; + if (!server_def.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Invalid tensorflow.ServerDef protocol buffer"); + return; + } + status->status = + UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx); +} + void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { ctx->context.SetThreadLocalDevicePlacementPolicy( @@ -295,7 +340,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { } void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { - DCHECK(h); + if (h == nullptr) return; if (h->handle) { h->handle->Unref(); } @@ -307,19 +352,34 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { } int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { - const tensorflow::Tensor* t = nullptr; - status->status = h->handle->Tensor(&t); - return t == nullptr ? 0 : t->dims(); + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return -1; + } + int result; + status->status = h->handle->NumDims(&result); + return result; } int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { - const tensorflow::Tensor* t = nullptr; - status->status = h->handle->Tensor(&t); - return t == nullptr ? 0 : t->dim_size(dim_index); + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return -1; + } + tensorflow::int64 result; + status->status = h->handle->Dim(dim_index, &result); + return result; } const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } tensorflow::Device* d = nullptr; status->status = h->handle->OpDevice(&d); return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" @@ -327,6 +387,11 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { } TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } // TODO(agarwal): move this implementation inside TFE_TensorHandle. tensorflow::Device* d = nullptr; tensorflow::Device* op_device = nullptr; @@ -421,8 +486,11 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx, return ret; } -void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) { - op->operation.MutableAttrs()->Set(attr_name, value); +void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value, + size_t length) { + op->operation.MutableAttrs()->Set( + attr_name, + tensorflow::StringPiece(static_cast(value), length)); } void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) { @@ -473,16 +541,22 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, op->operation.MutableAttrs()->Set(attr_name, attr_value); } -#define TFE_OP_SET_ATTR_LIST(fn, type) \ - void fn(TFE_Op* op, const char* attr_name, const type* values, \ - int num_values) { \ - op->operation.MutableAttrs()->Set( \ - attr_name, \ - tensorflow::gtl::ArraySlice(values, num_values)); \ +void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, + const void* const* values, const size_t* lengths, + int num_values) { + std::vector v(num_values); + for (int i = 0; i < num_values; ++i) { + v[i] = tensorflow::StringPiece(static_cast(values[i]), + lengths[i]); } -TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*) -TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float) -#undef TFE_OP_SET_ATTR_LIST + op->operation.MutableAttrs()->Set(attr_name, v); +} + +void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name, + const float* values, int num_values) { + op->operation.MutableAttrs()->Set( + attr_name, tensorflow::gtl::ArraySlice(values, num_values)); +} void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, const int64_t* values, int num_values) { @@ -614,17 +688,17 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) { const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( TFE_TensorHandle* h, TF_Status* status) { - tensorflow::Device* d = nullptr; - tensorflow::Device* op_device = nullptr; - const tensorflow::Tensor* t = nullptr; - status->status = h->handle->TensorAndDevice(&t, &d, &op_device); - if (!status->status.ok()) return nullptr; - if (d != nullptr) { + if (!h->handle->OnHostCPU()) { status->status = tensorflow::errors::FailedPrecondition( "TFE_TensorHandle is placed in device (not host) memory. Cannot return " "a tensorflow::Tensor"); return nullptr; } + tensorflow::Device* d = nullptr; + tensorflow::Device* op_device = nullptr; + const tensorflow::Tensor* t = nullptr; + status->status = h->handle->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) return nullptr; return t; } @@ -650,14 +724,20 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, } } // namespace +void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); } + +void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); } + namespace tensorflow { void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, const char* attr_name, TF_Status* status) { switch (default_value.value_case()) { - case tensorflow::AttrValue::kS: - TFE_OpSetAttrString(op, attr_name, default_value.s().data()); + case tensorflow::AttrValue::kS: { + const string& v = default_value.s(); + TFE_OpSetAttrString(op, attr_name, v.data(), v.size()); break; + } case tensorflow::AttrValue::kI: TFE_OpSetAttrInt(op, attr_name, static_cast(default_value.i())); break; diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 1862af3ce2f505a6e83b4805417eaf335ed07bc0..a0ebc6fa0a22ed61be91c2974352c2988fb4cd92 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -81,16 +81,6 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); -// A tensorflow.ServerDef specifies remote workers (in addition to the current -// workers name). Operations created on this context can then be executed on -// any of these remote workers by setting an appropriate device. -// -// If the following is set, all servers identified by the -// ServerDef must be up when the context is created. -TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( - TFE_ContextOptions* options, const void* proto, size_t proto_len, - TF_Status* status); - // Destroy an options object. TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*); @@ -102,8 +92,7 @@ typedef struct TFE_Context TFE_Context; TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext( const TFE_ContextOptions* opts, TF_Status* status); -TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, - TF_Status* status); +TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx); TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status); @@ -128,6 +117,18 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*, unsigned char async, TF_Status* status); +// A tensorflow.ServerDef specifies remote workers (in addition to the current +// workers name). Operations created on this context can then be executed on +// any of these remote workers by setting an appropriate device. +// +// If the following is set, all servers identified by the +// ServerDef must be up when the context is created. +TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, + int keep_alive_secs, + const void* proto, + size_t proto_len, + TF_Status* status); + // Causes the calling thread to block till all ops dispatched in async mode // have been executed. Note that "execution" here refers to kernel execution / // scheduling of copies, etc. Similar to sync execution, it doesn't guarantee @@ -278,7 +279,8 @@ TF_CAPI_EXPORT extern TF_AttrType TFE_OpNameGetAttrType( TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, - const char* value); + const void* value, + size_t length); TF_CAPI_EXPORT extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value); TF_CAPI_EXPORT extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, @@ -305,7 +307,8 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op, TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, - const char** value, + const void* const* values, + const size_t* lengths, int num_values); TF_CAPI_EXPORT extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, @@ -378,6 +381,16 @@ TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status); +// Some TF ops need a step container to be set to limit the lifetime of some +// resources (mostly TensorArray and Stack, used in while loop gradients in +// graph mode). Calling this on a context tells it to start a step. +TF_CAPI_EXPORT extern void TFE_ContextStartStep(TFE_Context* ctx); + +// Ends a step. When there is no active step (that is, every started step has +// been ended) step containers will be cleared. Note: it is not safe to call +// TFE_ContextEndStep while ops which rely on the step container may be running. +TF_CAPI_EXPORT extern void TFE_ContextEndStep(TFE_Context* ctx); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 04a6efc47c5177c82b7e88168b67cc584587de7c..a5c0681e2e4eddae08954d9d0178ca96a3f8f29a 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -39,7 +39,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/distributed_runtime/remote_device.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" @@ -59,7 +59,6 @@ struct TFE_ContextOptions { // true if async execution is enabled. bool async = false; TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT}; - tensorflow::ServerDef server_def; }; struct TFE_Context { @@ -73,23 +72,6 @@ struct TFE_Context { default_policy), async, std::move(device_mgr), rendezvous) {} - explicit TFE_Context( - const tensorflow::SessionOptions& opts, - TFE_ContextDevicePlacementPolicy default_policy, bool async, - tensorflow::DeviceMgr* local_device_mgr, - tensorflow::Rendezvous* rendezvous, - std::unique_ptr server, - std::unique_ptr remote_eager_workers, - std::unique_ptr remote_device_mgr, - const tensorflow::gtl::FlatMap& - remote_contexts) - : context(opts, - static_cast( - default_policy), - async, local_device_mgr, rendezvous, std::move(server), - std::move(remote_eager_workers), std::move(remote_device_mgr), - remote_contexts) {} - tensorflow::EagerContext context; }; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 992d1afd5fcb0641794bb2abbe5ab20a287d3b62..71d5f3613c89762633113b4e1dfb82b8199a1cd1 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "tensorflow/c/eager/c_api_test_util.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -49,7 +49,7 @@ void BM_InitOp(int iters) { } tensorflow::testing::StopTiming(); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -80,7 +80,7 @@ void BM_Execute(int iters, int async) { tensorflow::testing::StopTiming(); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -95,7 +95,7 @@ TEST(CAPI, Context) { TF_DeviceList* devices = TFE_ContextListDevices(ctx, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); const int num_devices = TF_DeviceListCount(devices); @@ -108,14 +108,14 @@ TEST(CAPI, Context) { TF_DeleteStatus(status); } -tensorflow::ServerDef GetServerDef(int num_tasks) { +tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) { tensorflow::ServerDef server_def; server_def.set_protocol("grpc"); - server_def.set_job_name("localhost"); + server_def.set_job_name(job_name); server_def.set_task_index(0); tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); tensorflow::JobDef* job_def = cluster_def->add_job(); - job_def->set_name("localhost"); + job_def->set_name(job_name); for (int i = 0; i < num_tasks; i++) { int port = tensorflow::testing::PickUnusedPortOrDie(); job_def->mutable_tasks()->insert( @@ -124,6 +124,10 @@ tensorflow::ServerDef GetServerDef(int num_tasks) { return server_def; } +tensorflow::ServerDef GetServerDef(int num_tasks) { + return GetServerDef("localhost", num_tasks); +} + void TestRemoteExecute(bool async) { tensorflow::ServerDef server_def = GetServerDef(2); @@ -132,24 +136,24 @@ void TestRemoteExecute(bool async) { server_def.set_task_index(1); - std::unique_ptr worker_server; - ASSERT_TRUE( - tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server) - .ok()); + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); ASSERT_TRUE(worker_server->Start().ok()); TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), - status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_ContextOptionsSetAsync(opts, static_cast(1)); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_EXPLICIT); TFE_Context* ctx = TFE_NewContext(opts, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); const char remote_device_name[] = @@ -195,8 +199,8 @@ void TestRemoteExecute(bool async) { TFE_DeleteOp(matmul); TFE_ContextAsyncWait(ctx, status); - TFE_DeleteContext(ctx, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); @@ -208,38 +212,48 @@ TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); } TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); } void TestRemoteExecuteSilentCopies(bool async) { - tensorflow::ServerDef server_def = GetServerDef(2); + tensorflow::ServerDef server_def = GetServerDef(3); // This server def has the task index set to 0. string serialized = server_def.SerializeAsString(); server_def.set_task_index(1); - - std::unique_ptr worker_server; - ASSERT_TRUE( - tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server) - .ok()); - ASSERT_TRUE(worker_server->Start().ok()); + std::unique_ptr worker_server1; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server1) + .ok()); + ASSERT_TRUE(worker_server1->Start().ok()); + + server_def.set_task_index(2); + std::unique_ptr worker_server2; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server2) + .ok()); + ASSERT_TRUE(worker_server2->Start().ok()); TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), - status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_ContextOptionsSetAsync(opts, static_cast(1)); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); TFE_Context* ctx = TFE_NewContext(opts, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); - const char remote_device_name[] = - "/job:localhost/replica:0/task:1/device:CPU:0"; + const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; + const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; - // Handles are on task0, but op is on remote (task1). - TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task0); - TFE_OpSetDevice(matmul, remote_device_name, status); + auto* h1_task2 = + TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Handles are on task0 (local), and task2, but op is on task1. + TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2); + TFE_OpSetDevice(matmul, task1_name, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* retvals[1]; @@ -265,18 +279,20 @@ void TestRemoteExecuteSilentCopies(bool async) { TFE_DeleteTensorHandle(h0_task0); TFE_DeleteTensorHandle(h1_task0); + TFE_DeleteTensorHandle(h1_task2); TFE_DeleteTensorHandle(retvals[0]); TFE_DeleteOp(matmul); TFE_ContextAsyncWait(ctx, status); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); // TODO(nareshmodi): Figure out how to correctly shut the server down. - worker_server.release(); + worker_server1.release(); + worker_server2.release(); } TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); } @@ -284,6 +300,147 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) { TestRemoteExecuteSilentCopies(true); } +void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle, + const std::vector& expected_values) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + std::unique_ptr actual_values(new float[expected_values.size()]); + EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t)); + memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + + for (int i = 0; i < expected_values.size(); i++) { + EXPECT_EQ(expected_values[i], actual_values[i]) + << "Mismatch in expected values at (zero-based) index " << i; + } +} + +void CheckRemoteMatMulExecutesOK(TFE_Context* ctx, + const char* remote_device_name, + const char* local_device_name) { + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + + TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + auto* retval_task0 = + TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22}); + + TFE_DeleteTensorHandle(retval_task0); + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); +} + +void TestRemoteExecuteChangeServerDef(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + const char local_device_name[] = + "/job:localhost/replica:0/task:0/device:CPU:0"; + CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name); + + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); + + // Update the server def with a new set of names (worker instead of + // localhost). + tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2); + serialized = updated_server_def.SerializeAsString(); + + updated_server_def.set_task_index(1); + tensorflow::Status s = tensorflow::GrpcServer::Create( + updated_server_def, tensorflow::Env::Default(), &worker_server); + ASSERT_TRUE(s.ok()) << s.error_message(); + ASSERT_TRUE(worker_server->Start().ok()); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Create a new tensor_handle. + TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(); + + // Check that copying it to the old remote device (named localhost) fails. + TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status); + EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Copying and executing on the new remote device works. + const char new_remote_device_name[] = + "/job:worker/replica:0/task:1/device:CPU:0"; + const char new_local_device_name[] = + "/job:worker/replica:0/task:0/device:CPU:0"; + + auto* h0_task1_new = TFE_TensorHandleCopyToDevice( + h0_task0_new, ctx, new_remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_DeleteTensorHandle(h0_task0_new); + TFE_DeleteTensorHandle(h0_task1_new); + + CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name, + new_local_device_name); + + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + + TFE_DeleteContext(ctx); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecuteChangeServerDef) { + TestRemoteExecuteChangeServerDef(false); +} +TEST(CAPI, RemoteExecuteChangeServerDefAsync) { + TestRemoteExecuteChangeServerDef(true); +} + TEST(CAPI, TensorHandle) { TFE_TensorHandle* h = TestMatrixTensorHandle(); EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); @@ -368,8 +525,7 @@ void TensorHandleCopyBetweenDevices(bool async) { TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); - TFE_DeleteContext(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleCopyBetweenDevices) { @@ -406,7 +562,7 @@ void TensorHandleCopyBetweenDevicesError(bool async) { TFE_DeleteTensorHandle(hcopy); TFE_DeleteTensorHandle(hcpu); if (hdevice != nullptr) TFE_DeleteTensorHandle(hdevice); - TFE_DeleteContext(ctx, status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleCopyBetweenDevicesError) { @@ -439,7 +595,7 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) { TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); - TFE_DeleteContext(ctx, status.get()); + TFE_DeleteContext(ctx); return; } const string gpu_1_name(TF_DeviceListName(devices, 1, status.get())); @@ -472,8 +628,7 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) { TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); - TFE_DeleteContext(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { @@ -521,8 +676,7 @@ void TensorHandleSilentCopy(bool async) { TFE_DeleteTensorHandle(hcpu); TFE_ContextAsyncWait(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TFE_DeleteContext(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false); } @@ -568,8 +722,7 @@ void TensorHandleSilentCopyLocal(bool async) { TFE_DeleteTensorHandle(hcpu); TFE_ContextAsyncWait(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TFE_DeleteContext(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); } TEST(CAPI, TensorHandleSilentCopyLocalAsync) { @@ -602,11 +755,47 @@ void SetAndGetOpDevices(bool async) { TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } +TEST(CAPI, TensorHandleNullptr) { + TFE_TensorHandle* h = nullptr; + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + TF_Tensor* t = TFE_TensorHandleResolve(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(t, nullptr); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + + const char* device_name = TFE_TensorHandleDeviceName(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(device_name, nullptr); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + + int num_dims = TFE_TensorHandleNumDims(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(num_dims, -1); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + + int dim = TFE_TensorHandleDim(h, 0, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(dim, -1); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); +} + void Execute_MatMul_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -628,7 +817,7 @@ void Execute_MatMul_CPU(bool async) { TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteTensorHandle(retvals[0]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); @@ -700,7 +889,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) { TFE_DeleteTensorHandle(m1); TFE_DeleteTensorHandle(m2); TFE_DeleteTensorHandle(retvals[0]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } TEST(CAPI, Execute_MatMul_CPU_Runtime_Error) { @@ -731,7 +920,7 @@ void Execute_MatMul_CPU_Type_Error(bool async) { if (retvals[0] != nullptr) { TFE_DeleteTensorHandle(retvals[0]); } - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } @@ -769,7 +958,7 @@ TEST(CAPI, Execute_Min_CPU) { TF_DeleteTensor(t); EXPECT_EQ(1, output[0]); EXPECT_EQ(3, output[1]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -811,7 +1000,7 @@ void Execute_MatMul_XLA_CPU(bool async) { EXPECT_EQ(10, product[1]); EXPECT_EQ(15, product[2]); EXPECT_EQ(22, product[3]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); } @@ -850,7 +1039,7 @@ void Execute_Min_XLA_CPU(bool async) { TF_DeleteTensor(t); EXPECT_EQ(1, output[0]); EXPECT_EQ(3, output[1]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); } @@ -886,7 +1075,7 @@ void ExecuteWithTracing(bool async) { TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); TFE_DeleteTensorHandle(retvals[0]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); @@ -962,7 +1151,7 @@ TEST(CAPI, Function_ident_CPU) { TF_DeleteTensor(r); TFE_DeleteTensorHandle(result[0]); } - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); } @@ -1032,7 +1221,7 @@ TEST(CAPI, Function_ident_XLA_CPU) { TF_DeleteTensor(r); TFE_DeleteTensorHandle(result[0]); } - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); } @@ -1108,7 +1297,7 @@ void FunctionDefAndExecute(bool async) { EXPECT_EQ(10, product[1]); EXPECT_EQ(15, product[2]); EXPECT_EQ(22, product[3]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -1149,7 +1338,7 @@ void BM_ExecuteFunction(int iters, int async) { tensorflow::testing::StopTiming(); TFE_DeleteTensorHandle(m); TFE_DeleteTensorHandle(retval[0]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -1162,8 +1351,8 @@ TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, if (TF_GetCode(status) != TF_OK) return nullptr; TFE_OpSetAttrType(op, "dtype", TF_FLOAT); TFE_OpSetAttrShape(op, "shape", {}, 0, status); - TFE_OpSetAttrString(op, "container", ""); - TFE_OpSetAttrString(op, "shared_name", ""); + TFE_OpSetAttrString(op, "container", "", 0); + TFE_OpSetAttrString(op, "shared_name", "", 0); if (TF_GetCode(status) != TF_OK) return nullptr; TFE_TensorHandle* var_handle = nullptr; int num_retvals = 1; @@ -1237,7 +1426,7 @@ TEST(CAPI, Variables) { TFE_DeleteTensorHandle(var_handle); TFE_DeleteTensorHandle(value_handle); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -1276,7 +1465,7 @@ void BM_ReadVariable(int iters) { TFE_DeleteOp(op); TFE_DeleteTensorHandle(var_handle); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 734e712daa39c03f0177eb199b1acb1b19e5d845..1adb0458c35193117b5fa5cfe9ceffbaaf699af7 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -520,7 +520,12 @@ Status GradientTape::ComputeGradient( } } else { any_gradient_nonzero = true; - auto new_gradients = vspace.AggregateGradients(grad_it->second); + Gradient* new_gradients = nullptr; + if (grad_it->second.size() == 1) { + new_gradients = grad_it->second.at(0); + } else { + new_gradients = vspace.AggregateGradients(grad_it->second); + } if (sources_set.find(grad_it->first) == sources_set.end()) { gradients.erase(grad_it); } else { diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index e18fdf6c57bd3f432d8cb73536fb816df90b3963..8486b585c8587e18e8eea18a893fac0a40ff4a27 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -155,7 +155,7 @@ void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, tensorflow::shape_inference::ShapeHandle shape; status->status = ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); - if (status->status.ok()) return; + if (!status->status.ok()) return; shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype()); } ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 079e063d3e3fbdaf833e9031f5f9438853c14099..f56521dac0374849081fe94f16feb08e55647b56 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -121,6 +121,7 @@ cc_library( deps = [ ":array_grad", ":data_flow_grad", + ":image_grad", ":math_grad", ":nn_grad", ], @@ -331,6 +332,36 @@ tf_cc_test( ], ) +cc_library( + name = "image_grad", + srcs = ["gradients/image_grad.cc"], + deps = [ + ":cc_ops", + ":cc_ops_internal", + ":grad_op_registry", + ":gradients", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "gradients_image_grad_test", + srcs = ["gradients/image_grad_test.cc"], + deps = [ + ":cc_ops", + ":client_session", + ":grad_op_registry", + ":grad_testutil", + ":gradient_checker", + ":image_grad", + ":testutil", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_library( name = "math_grad", srcs = ["gradients/math_grad.cc"], @@ -348,9 +379,11 @@ tf_cc_test( srcs = ["gradients/math_grad_test.cc"], deps = [ ":cc_ops", + ":client_session", ":grad_op_registry", ":grad_testutil", ":gradient_checker", + ":gradients", ":math_grad", ":testutil", "//tensorflow/core:lib_internal", @@ -530,7 +563,7 @@ cc_library_with_android_deps( "//tensorflow/core/api_def:base_api_def", ], deps = [ - "//tensorflow/core:framework", + "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:op_gen_lib", @@ -595,7 +628,6 @@ tf_cc_binary( copts = tf_copts(), linkopts = select({ "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//tensorflow:darwin": [ "-lm", "-lpthread", diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc index ba056a8f3a84910aebf5079573cb64c19f41469d..0e61089a5950ee894ad5489317757cff8a85e966 100644 --- a/tensorflow/cc/client/client_session.cc +++ b/tensorflow/cc/client/client_session.cc @@ -127,4 +127,22 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs, target_node_names, outputs, run_metadata); } +Status ClientSession::MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) { + TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph()); + return impl()->session_->MakeCallable(callable_options, out_handle); +} + +Status ClientSession::RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata) { + return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors, + run_metadata); +} + +Status ClientSession::ReleaseCallable(CallableHandle handle) { + return impl()->session_->ReleaseCallable(handle); +} + } // end namespace tensorflow diff --git a/tensorflow/cc/client/client_session.h b/tensorflow/cc/client/client_session.h index 5fb4109f7d15d5997f745acd913e60a02855fd73..7dd653eec4ec729b652cb779d06e820bfb437b3c 100644 --- a/tensorflow/cc/client/client_session.h +++ b/tensorflow/cc/client/client_session.h @@ -87,7 +87,33 @@ class ClientSession { const std::vector& run_outputs, std::vector* outputs, RunMetadata* run_metadata) const; - // TODO(keveman): Add support for partial run. + /// \brief A handle to a subgraph, created with + /// `ClientSession::MakeCallable()`. + typedef int64 CallableHandle; + + /// \brief Creates a `handle` for invoking the subgraph defined by + /// `callable_options`. + /// NOTE: This API is still experimental and may change. + Status MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle); + + /// \brief Invokes the subgraph named by `handle` with the given options and + /// input tensors. + /// + /// The order of tensors in `feed_tensors` must match the order of names in + /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will + /// match the order of names in `CallableOptions::fetch()` when this subgraph + /// was created. + /// NOTE: This API is still experimental and may change. + Status RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata); + + /// \brief Releases resources associated with the given `handle` in this + /// session. + /// NOTE: This API is still experimental and may change. + Status ReleaseCallable(CallableHandle handle); private: class Impl; diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc index ea5cf5a1f12be316cc6e0d0a02cd3caf4d177400..559ffea7e817526e7f1396cd0e8187d01364f23b 100644 --- a/tensorflow/cc/client/client_session_test.cc +++ b/tensorflow/cc/client/client_session_test.cc @@ -95,5 +95,26 @@ TEST(ClientSessionTest, MultiThreaded) { test::ExpectTensorEqual(outputs[0], test::AsTensor({-1, 2}, {2})); } +TEST(ClientSessionTest, Callable) { + Scope root = Scope::NewRootScope(); + auto a = Placeholder(root, DT_INT32); + auto b = Placeholder(root, DT_INT32); + auto c = Add(root, a, b); + ClientSession session(root); + std::vector outputs; + + CallableOptions options; + options.add_feed(a.node()->name()); + options.add_feed(b.node()->name()); + options.add_fetch(c.node()->name()); + ClientSession::CallableHandle callable; + TF_CHECK_OK(session.MakeCallable(options, &callable)); + TF_EXPECT_OK(session.RunCallable( + callable, {test::AsTensor({1}, {}), test::AsTensor({41}, {})}, + &outputs, nullptr)); + test::ExpectTensorEqual(outputs[0], test::AsTensor({42}, {})); + TF_EXPECT_OK(session.ReleaseCallable(callable)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc index de2645cb440bda1f35e764af9197ca97bb760c08..e9f9c59e3aa0e8a9dc5d5e658540e9da73adaca5 100644 --- a/tensorflow/cc/framework/gradient_checker.cc +++ b/tensorflow/cc/framework/gradient_checker.cc @@ -247,7 +247,7 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs, auto y_pos_flat = y_pos[y_idx].flat(); auto y_neg_flat = y_neg[y_idx].flat(); const int64 y_size = y_shapes[y_idx].num_elements(); - const Y_T scale = Y_T{2 * delta}; + const Y_T scale = 2 * delta; auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix(); for (int c = 0; c < y_size; ++c) { SetJacobian(&jacobian, r * x_stride + unit_dimension, @@ -351,7 +351,14 @@ Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs, auto jac_n = jacobian_ns[i].matrix(); for (int r = 0; r < jacobian_ts[i].dim_size(0); ++r) { for (int c = 0; c < jacobian_ts[i].dim_size(1); ++c) { - *max_error = std::max(*max_error, std::fabs(jac_t(r, c) - jac_n(r, c))); + auto cur_error = std::fabs(jac_t(r, c) - jac_n(r, c)); + // Treat any NaN as max_error and immediately return. + // (Note that std::max may ignore NaN arguments.) + if (std::isnan(cur_error)) { + *max_error = cur_error; + return Status::OK(); + } + *max_error = std::max(*max_error, cur_error); } } } @@ -409,6 +416,7 @@ Status ComputeGradientError(const Scope& scope, const Output& x, const Output& y, const TensorShape& y_shape, JAC_T* max_error); INSTANTIATE_GRAD_ERR_TYPE(float, float, float); +INSTANTIATE_GRAD_ERR_TYPE(double, float, double); INSTANTIATE_GRAD_ERR_TYPE(double, double, double); INSTANTIATE_GRAD_ERR_TYPE(complex64, float, float); INSTANTIATE_GRAD_ERR_TYPE(float, complex64, float); diff --git a/tensorflow/cc/framework/gradient_checker_test.cc b/tensorflow/cc/framework/gradient_checker_test.cc index d4f0a7f5ab3716be41e22c02a21aca028f76fb88..8dd762c282eff287bddd49ea6f38b2b8060949b0 100644 --- a/tensorflow/cc/framework/gradient_checker_test.cc +++ b/tensorflow/cc/framework/gradient_checker_test.cc @@ -28,12 +28,14 @@ namespace { using ops::Complex; using ops::Const; +using ops::Div; using ops::MatMul; using ops::Placeholder; using ops::Real; using ops::Split; using ops::Square; using ops::Stack; +using ops::Sub; using ops::Unstack; TEST(GradientCheckerTest, BasicFloat) { @@ -104,6 +106,20 @@ TEST(GradientCheckerTest, Complex64ToFloat) { EXPECT_LT(max_error, 1e-4); } +// When calculating gradients that are undefined, test we get NaN +// as the computed error rather than 0. +TEST(GradientCheckerTest, BasicNan) { + Scope scope = Scope::NewRootScope(); + TensorShape shape({2, 4, 3}); + auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(shape)); + // y = x/(x-x) should always return NaN + auto y = Div(scope, x, Sub(scope, x, x)); + float max_error; + TF_ASSERT_OK((ComputeGradientError( + scope, {x}, {shape}, {y}, {shape}, &max_error))); + EXPECT_TRUE(std::isnan(max_error)); +} + TEST(GradientCheckerTest, MatMulGrad) { Scope scope = Scope::NewRootScope(); diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 62a889181e787f2e181135ab0563c45e1bab8812..8c886f31711eb014fb9e9d600c9c78cf22073f71 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -37,6 +37,11 @@ Scope& Scope::operator=(const Scope& other) { return *this; } +namespace { +const char kScopeSeparator[] = "/"; +const char kSuffixSeparator[] = "_"; +} // namespace + Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner, bool disable_shape_inference) : graph_(graph), @@ -308,19 +313,23 @@ string Scope::Impl::GetUniqueName(const string& prefix, return prefix; } auto entry = name_map_->find(prefix); - string unique_name = prefix; if (entry == name_map_->end()) { name_map_->insert({prefix, 0}); - } else { - unique_name = strings::StrCat(unique_name, "_", ++entry->second); + return prefix; } + string unique_name; + do { + unique_name = strings::StrCat(prefix, kSuffixSeparator, ++entry->second); + } while (name_map_->find(unique_name) != name_map_->end()); + name_map_->insert({unique_name, 0}); return unique_name; } string Scope::Impl::GetNameForOp(const string& default_name) const { const string unique_name = GetUniqueName(default_name, true /* check_single_use */); - const string sep = name_.empty() || unique_name.empty() ? "" : "/"; + const string sep = + name_.empty() || unique_name.empty() ? "" : kScopeSeparator; return strings::StrCat(name_, sep, unique_name); } @@ -345,7 +354,8 @@ Scope Scope::NewSubScope(const string& child_scope_name) const { } const string unique_name = impl()->GetUniqueName(child_scope_name, false /* check_single_use */); - const string sep = impl()->name_.empty() || unique_name.empty() ? "" : "/"; + const string sep = + impl()->name_.empty() || unique_name.empty() ? "" : kScopeSeparator; return Scope(new Impl(*this, Impl::Tags::ScopeName(), strings::StrCat(impl()->name_, sep, unique_name), false /* copy_names */)); @@ -412,7 +422,7 @@ CompositeOpScopes Scope::GetCompositeOpScopes( if (!impl()->single_use_scope()) { Scope child = NewSubScope(impl()->op_name_.empty() ? composite_op_name : impl()->op_name_); - const string child_op_sep = impl()->name_.empty() ? "" : "_"; + const string child_op_sep = impl()->name_.empty() ? "" : kSuffixSeparator; const string child_name = strings::StrCat(impl()->name_, child_op_sep, child.impl()->name_); return {child, @@ -435,7 +445,13 @@ class InternalScope { static Scope NewScope(Graph* graph, Status* status, ShapeRefiner* refiner) { Scope::Impl::NameMap* name_map = new Scope::Impl::NameMap; for (const Node* node : graph->nodes()) { - (*name_map)[node->name()] = 0; + const string& name = node->name(); + (*name_map)[name] = 0; + // Add all name prefixes ('/' separated). + size_t idx = -1; + while ((idx = name.find(kScopeSeparator, idx + 1)) != string::npos) { + (*name_map)[name.substr(0, idx)] = 0; + } } // We provide null destructors for these shared ptrs (except for name_map) // since the caller owns them and doesn't want the scope to destroy them. diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h index 8efcfed20d0b86d86d8c20a3d8630c7c6bc909c3..58adaef2e942a7fa6b0ce8d5534ac3e2fd380580 100644 --- a/tensorflow/cc/framework/scope_internal.h +++ b/tensorflow/cc/framework/scope_internal.h @@ -34,8 +34,7 @@ class Scope::Impl { // name that has not been used so far in a scope will get no suffix. Later // uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes // can share the same NameMap. For instance, a new scope created using - // WithControlDependencies() should would share the same NameMap with the - // parent. + // WithControlDependencies() would share the same NameMap with the parent. typedef std::unordered_map NameMap; Impl(const std::shared_ptr& graph, diff --git a/tensorflow/cc/framework/scope_test.cc b/tensorflow/cc/framework/scope_test.cc index 9eca9d3face34319413e1acbc2f5ac0b2ba85374..b40b345eb84237c34ea593021bea022ad28095f7 100644 --- a/tensorflow/cc/framework/scope_test.cc +++ b/tensorflow/cc/framework/scope_test.cc @@ -26,6 +26,16 @@ TEST(ScopeTest, BasicNames) { EXPECT_EQ(root.GetUniqueNameForOp("mul"), "mul"); } +TEST(ScopeTest, OpAndScopeNameCollision) { + Scope root = Scope::NewRootScope(); + EXPECT_EQ(root.GetUniqueNameForOp("foo"), "foo"); + EXPECT_EQ(root.GetUniqueNameForOp("foo"), "foo_1"); + EXPECT_EQ(root.GetUniqueNameForOp("foo_1"), "foo_1_1"); + EXPECT_EQ(root.GetUniqueNameForOp("foo_2"), "foo_2"); + EXPECT_EQ(root.GetUniqueNameForOp("foo"), "foo_3"); + EXPECT_EQ(root.GetUniqueNameForOp("foo_2"), "foo_2_1"); +} + TEST(ScopeTest, HierarchicalNames) { Scope root = Scope::NewRootScope(); Scope child = root.NewSubScope("child"); diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index ff348fadb24e29a83bd6c8853aa67931f6df4182..e9173227aadbf86eab666e6c17bacacb92888572 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -120,6 +120,24 @@ Status SplitGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Split", SplitGrad); +Status FillGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // y = fill(fill_shape, x) + // No gradient returned for the fill_shape argument. + grad_outputs->push_back(NoGradient()); + // The gradient for x (which must be a scalar) is just the sum of + // all the gradients from the shape it fills. + // We use ReduceSum to implement this, which needs an argument providing + // the indices of all the dimensions of the incoming gradient. + // grad(x) = reduce_sum(grad(y), [0..rank(grad(y))]) + auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]), + Const(scope, 1)); + grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims)); + return scope.status(); +} +REGISTER_GRADIENT_OP("Fill", FillGrad); + Status DiagGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { @@ -421,6 +439,58 @@ Status StridedSliceGradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper); +Status SliceGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // Propagate the incoming gradient along all the selected values, + // and zero everywhere else. Use the Pad operator for this. + // + // First create an Nx2 padding where N is the number of input + // dimensions. The first column is the number of prepended zeros + // for each dimension, and the second column is the number of + // appended zeros. + // + // The first column is just the begin vector. + // The second column is the shape of the input element-wise + // subtracted by begin+size + + // Running example: + // input.shape = [3, 5, 3] + // begin = [1, 2, 1], size = [1, 3, 2] + Input input = op.input(0); + Input begin = op.input(1); + // input_rank = 3 + auto input_rank = Rank(scope, input); + // slice_size = [1, 3, 2] + auto slice_size = Shape(scope, op.output(0)); + // padding_shape = [3, 1] + auto padding_shape = Stack(scope, {input_rank, 1}); + // before_padding = [[1] + // [2] + // [1]] + Input before_padding = Reshape(scope, begin, padding_shape); + // after_padding_sizes = shape(input) - slice_size - begin + // = [3, 5, 3] - [1, 3, 2] - [1, 2, 1] + // = [1, 0, 0] + auto after_padding_sizes = + Sub(scope, Sub(scope, Shape(scope, input), slice_size), begin); + // after_padding = [[1] + // [0] + // [0]] + Input after_padding = Reshape(scope, after_padding_sizes, padding_shape); + // paddings = [[1 1] + // [2 0] + // [1 0]] + auto paddings = + Concat(scope, {before_padding, after_padding}, Const(scope, 1)); + grad_outputs->push_back(Pad(scope, grad_inputs[0], paddings)); + // Nothing propagated for "begin" and "size" inputs + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("Slice", SliceGrad); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index de3bd0fc9e2493f8ff76163f5be6bd4327c58c5a..f41de3dc2098df55fbbb616557f264a4e70db6b6 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -108,6 +108,14 @@ TEST_F(ArrayGradTest, SplitGrad) { RunTest({x}, {x_shape}, y.output, {y_shape, y_shape}); } +TEST_F(ArrayGradTest, FillGrad) { + TensorShape x_shape({}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + TensorShape y_shape({2, 5, 3}); + auto y = Fill(scope_, {2, 5, 3}, x); + RunTest(x, x_shape, y, y_shape); +} + TEST_F(ArrayGradTest, DiagGrad) { TensorShape x_shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); @@ -378,5 +386,12 @@ TEST_F(ArrayGradTest, StridedSliceGrad) { RunTest(x, x_shape, y, {1, 2, 2, 2}); } +TEST_F(ArrayGradTest, SliceGrad) { + TensorShape x_shape({3, 5, 3}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = Slice(scope_, x, {1, 2, 1}, {1, 3, 2}); + RunTest(x, x_shape, y, {1, 3, 2}); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/gradients/image_grad.cc b/tensorflow/cc/gradients/image_grad.cc new file mode 100644 index 0000000000000000000000000000000000000000..882709e1e2817431a32c453fe0f35f2b2e6c69b0 --- /dev/null +++ b/tensorflow/cc/gradients/image_grad.cc @@ -0,0 +1,74 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "tensorflow/cc/framework/grad_op_registry.h" +#include "tensorflow/cc/framework/gradients.h" +#include "tensorflow/cc/ops/image_ops_internal.h" +#include "tensorflow/cc/ops/standard_ops.h" + +namespace tensorflow { +namespace ops { +namespace { + +Status ResizeNearestNeighborGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool align_corners; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners)); + // The internal gradient implementation needs the shape of the input image. + // x_shape = shape(x)[1:3] + // = slice(shape(x), {1}, {3 - 1}) + auto x_shape = Slice(scope, Shape(scope, op.input(0)), {1}, {2}); + grad_outputs->push_back(internal::ResizeNearestNeighborGrad( + scope, grad_inputs[0], x_shape, + internal::ResizeNearestNeighborGrad::AlignCorners(align_corners))); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("ResizeNearestNeighbor", ResizeNearestNeighborGradHelper); + +Status ResizeBilinearGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool align_corners; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners)); + grad_outputs->push_back(internal::ResizeBilinearGrad( + scope, grad_inputs[0], op.input(0), + internal::ResizeBilinearGrad::AlignCorners(align_corners))); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("ResizeBilinear", ResizeBilinearGradHelper); + +Status ResizeBicubicGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool align_corners; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners)); + grad_outputs->push_back(internal::ResizeBicubicGrad( + scope, grad_inputs[0], op.input(0), + internal::ResizeBicubicGrad::AlignCorners(align_corners))); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("ResizeBicubic", ResizeBicubicGradHelper); + +} // anonymous namespace +} // namespace ops +} // namespace tensorflow diff --git a/tensorflow/cc/gradients/image_grad_test.cc b/tensorflow/cc/gradients/image_grad_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2e55c7561b030c50bd67bd53fd0d55710085c5d2 --- /dev/null +++ b/tensorflow/cc/gradients/image_grad_test.cc @@ -0,0 +1,157 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/framework/grad_op_registry.h" +#include "tensorflow/cc/framework/gradient_checker.h" +#include "tensorflow/cc/framework/testutil.h" +#include "tensorflow/cc/gradients/grad_testutil.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +using ops::Const; +using ops::ResizeBicubic; +using ops::ResizeBilinear; +using ops::ResizeNearestNeighbor; + +class ImageGradTest : public ::testing::Test { + protected: + ImageGradTest() : scope_(Scope::NewRootScope()) {} + + enum OpType { RESIZE_NEAREST, RESIZE_BILINEAR, RESIZE_BICUBIC }; + + template + Tensor MakeData(const TensorShape& data_shape) { + DataType data_type = DataTypeToEnum::v(); + Tensor data(data_type, data_shape); + auto data_flat = data.flat(); + for (int i = 0; i < data_flat.size(); ++i) { + data_flat(i) = T(i); + } + return data; + } + + template + void MakeOp(const OpType op_type, const Tensor& x_data, const Input& y_shape, + const bool align_corners, Output* x, Output* y) { + *x = Const(scope_, x_data); + switch (op_type) { + case RESIZE_NEAREST: + *y = ResizeNearestNeighbor( + scope_, *x, y_shape, + ResizeNearestNeighbor::AlignCorners(align_corners)); + return; + case RESIZE_BILINEAR: + *y = ResizeBilinear(scope_, *x, y_shape, + ResizeBilinear::AlignCorners(align_corners)); + return; + case RESIZE_BICUBIC: + *y = ResizeBicubic(scope_, *x, y_shape, + ResizeBicubic::AlignCorners(align_corners)); + return; + } + assert(false); + } + + template + void TestResizedShapeForType(const OpType op_type, const bool align_corners) { + TensorShape x_shape({1, 2, 2, 1}); + Tensor x_data = MakeData(x_shape); + Output x, y; + MakeOp(op_type, x_data, {4, 6}, align_corners, &x, &y); + + ClientSession session(scope_); + std::vector outputs; + TF_ASSERT_OK(session.Run({y}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + EXPECT_EQ(outputs[0].shape(), TensorShape({1, 4, 6, 1})); + } + + void TestResizedShape(OpType op_type) { + for (const bool align_corners : {true, false}) { + TestResizedShapeForType(op_type, align_corners); + TestResizedShapeForType(op_type, align_corners); + TestResizedShapeForType(op_type, align_corners); + } + } + + template + void TestResizeToSmallerAndAlign(const OpType op_type, + const bool align_corners) { + TensorShape x_shape({1, 4, 6, 1}); + Tensor x_data = MakeData(x_shape); + Output x, y; + MakeOp(op_type, x_data, {2, 3}, align_corners, &x, &y); + JAC_T max_error; + TF_ASSERT_OK((ComputeGradientError( + scope_, x, x_data, y, {1, 2, 3, 1}, &max_error))); + EXPECT_LT(max_error, 1e-3); + } + + template + void TestResizeToLargerAndAlign(const OpType op_type, + const bool align_corners) { + TensorShape x_shape({1, 2, 3, 1}); + Tensor x_data = MakeData(x_shape); + Output x, y; + MakeOp(op_type, x_data, {4, 6}, align_corners, &x, &y); + JAC_T max_error; + TF_ASSERT_OK((ComputeGradientError( + scope_, x, x_data, y, {1, 4, 6, 1}, &max_error))); + EXPECT_LT(max_error, 1e-3); + } + + template + void TestResize(OpType op_type) { + for (const bool align_corners : {true, false}) { + TestResizeToSmallerAndAlign(op_type, align_corners); + TestResizeToLargerAndAlign(op_type, align_corners); + } + } + + Scope scope_; +}; + +TEST_F(ImageGradTest, TestNearestNeighbor) { + TestResizedShape(RESIZE_NEAREST); + TestResize(RESIZE_NEAREST); + TestResize(RESIZE_NEAREST); +} + +TEST_F(ImageGradTest, TestBilinear) { + TestResizedShape(RESIZE_BILINEAR); + TestResize(RESIZE_BILINEAR); + // Note that Y_T is always float for this op. We choose + // double for the jacobian to capture the higher precision + // between X_T and Y_T. + TestResize(RESIZE_BILINEAR); +} + +TEST_F(ImageGradTest, TestBicubic) { + TestResizedShape(RESIZE_BICUBIC); + TestResize(RESIZE_BICUBIC); + // Note that Y_T is always float for this op. We choose + // double for the jacobian to capture the higher precision + // between X_T and Y_T. + TestResize(RESIZE_BICUBIC); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index 35a01e0341cb08c9b314908b6dcd76fd99c1e68b..5dcf00857df0eabd4e99f2782c1910515a9be265 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -441,6 +441,22 @@ Status RealDivGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("RealDiv", RealDivGrad); +Status UnsafeDivGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto x_1 = ConjugateHelper(scope, op.input(0)); + auto x_2 = ConjugateHelper(scope, op.input(1)); + // y = x_1 / x_2 + // dy/dx_1 = 1/x_2 + // dy/dx_2 = -x_1/x_2^2 + auto gx_1 = UnsafeDiv(scope, grad_inputs[0], x_2); + auto gx_2 = + Mul(scope, grad_inputs[0], + UnsafeDiv(scope, UnsafeDiv(scope, Neg(scope, x_1), x_2), x_2)); + return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2); +} +REGISTER_GRADIENT_OP("UnsafeDiv", UnsafeDivGrad); + Status SquaredDifferenceGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { @@ -1007,6 +1023,26 @@ Status ProdGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Prod", ProdGrad); +Status SegmentSumGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // The SegmentSum operation sums segments of the Tensor that have the same + // index in the segment_ids parameter. + // i.e z = [2, 3, 4, 5], segment_ids [0, 0, 0, 1] + // will produce [2 + 3 + 4, 5] = [9, 5] + // The gradient that will flow back to the gather operation will look like + // [x1, x2], it will have the same shape as the output of the SegmentSum + // operation. The differentiation step of the SegmentSum operation just + // broadcast the gradient in order to retrieve the z's shape. + // dy/dz = [x1, x1, x1, x2] + grad_outputs->push_back(Gather(scope, grad_inputs[0], op.input(1))); + + // stop propagation along segment_ids + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("SegmentSum", SegmentSumGrad); + // MatMulGrad helper function used to compute two MatMul operations // based on input matrix transposition combinations. Status MatMulGradHelper(const Scope& scope, const bool is_batch, diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index fd7b6fe6625f27bda92e2f56f60908658cdecd7e..88aef1fab410e11aa17a9e44578f5db95ed6e52b 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/cc/client/client_session.h" #include "tensorflow/cc/framework/grad_op_registry.h" #include "tensorflow/cc/framework/gradient_checker.h" +#include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/gradients/grad_testutil.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -42,9 +44,11 @@ using ops::Placeholder; using ops::Pow; using ops::Prod; using ops::RealDiv; +using ops::SegmentSum; using ops::SquaredDifference; using ops::Sub; using ops::Sum; +using ops::UnsafeDiv; // TODO(andydavis) Test gradient function against numeric gradients output. // TODO(andydavis) As more gradients are added move common test functions @@ -475,11 +479,7 @@ TEST_F(CWiseUnaryGradTest, Tan_Complex) { auto x_fn = [this](const int i) { return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); }; - // TODO(kbsriram) - // Enable when tan kernel supports complex inputs - if (false) { - TestCWiseGrad(TAN, x_fn); - } + TestCWiseGrad(TAN, x_fn); } TEST_F(CWiseUnaryGradTest, Atan) { @@ -854,6 +854,36 @@ TEST_F(NaryGradTest, RealDiv) { RunTest({x}, {x_shape}, {y}, {x_shape}); } +TEST_F(NaryGradTest, UnsafeDiv) { + { + TensorShape x_shape({3, 2, 5}); + const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large + // division errors in the numeric estimator used by the gradient checker. + const auto y = UnsafeDiv( + scope_, x, Add(scope_, Const(scope_, 1), Abs(scope_, x))); + RunTest({x}, {x_shape}, {y}, {x_shape}); + } + { + // Return 0 gradient (rather than NaN) for division by zero. + const auto x = Placeholder(scope_, DT_FLOAT); + const auto zero = Const(scope_, 0.0); + const auto y = UnsafeDiv(scope_, x, zero); + + std::vector grad_outputs; + TF_EXPECT_OK(AddSymbolicGradients(scope_, {y}, {x}, &grad_outputs)); + ClientSession session(scope_); + std::vector grad_result; + TF_EXPECT_OK( + session.Run({{x, {-3.0f, 0.0f, 3.0f}}}, grad_outputs, &grad_result)); + EXPECT_EQ(grad_result.size(), 1); + EXPECT_EQ(grad_result[0].NumElements(), 3); + EXPECT_EQ(grad_result[0].flat()(0), 0.0f); + EXPECT_EQ(grad_result[0].flat()(1), 0.0f); + EXPECT_EQ(grad_result[0].flat()(2), 0.0f); + } +} + TEST_F(NaryGradTest, SquaredDifference) { TensorShape x1_shape({3, 2, 5}); TensorShape x2_shape({2, 5}); @@ -902,5 +932,14 @@ TEST_F(NaryGradTest, Prod) { RunTest({x}, {x_shape}, {y}, {y_shape}); } +TEST_F(NaryGradTest, SegmentSum) { + TensorShape x_shape({3, 4}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = SegmentSum(scope_, x, {0, 0, 1}); + // the sum is always on the first dimension + TensorShape y_shape({2, 4}); + RunTest({x}, {x_shape}, {y}, {y_shape}); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index c73482d5f4d13ade0dc0412941251d1651371b6e..588e96cb196189780037f66266484962ba0385e4 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -47,6 +47,72 @@ Status SoftmaxGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad); +bool IsZero(const Scope& scope, const Output& grad) { + string op_type_name = grad.op().node()->type_string(); + if (op_type_name == "ZerosLike" || op_type_name == "Zeros") { + return true; + } + // The Operation we were provided is not named something obvious so + // we need to actually look at its contents. + // The original python code did this by calling a utility function called + // tensor_util.constant_value. + // There is no C++ equivalent to tensor_util.constant_value so we do nothing + // for the moment. + return false; +} + +// Multiply after broadcasting vec to match dimensions of mat. +// Args: +// vec: A 1-D tensor of dimension [D0] +// mat: A 2-D tensor of dimesnion [D0, D1] +// +// Returns: +// A tensor of dimension [D0, D1], the result fo vec * mat. +Output BroadcastMul(const Scope& scope, const Output& vec, const Output& mat) { + auto reshaped = ExpandDims(scope, vec, -1); + return Multiply(scope, reshaped, mat); +} + +Status SoftmaxCrossEntropyWithLogitsGrad(const Scope& scope, + const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // Softmax gradient with cross entropy logits function. + // We multiply the backprop for cost with the gradients - op.output[1]. + // There is no gradient for labels. + + // The outputs of the network are at input index 0. + auto logits = op.input(0); + // The "truth" labels are at index 1. + auto softmax_grad = op.output(1); + + // The loss is the output at index 0, and backprop is the output at index 1. + auto grad_loss = grad_inputs[0]; + auto grad_grad = grad_inputs[1]; + + auto grad = BroadcastMul(scope, grad_loss, softmax_grad); + if (!IsZero(scope, grad_grad)) { + std::vector axis; + auto logits_softmax = Softmax(scope, logits); + + auto grad_grad_expand = ExpandDims(scope, grad_grad, 1); + auto logits_softmax_expand = ExpandDims(scope, logits_softmax, 2); + auto matmul_result = + BatchMatMul(scope, grad_grad_expand, logits_softmax_expand); + axis.push_back(1); + auto squeeze_result = Squeeze(scope, matmul_result, Squeeze::Axis(axis)); + auto subtraction_result = Subtract(scope, grad_grad, squeeze_result); + auto multiply_result = Multiply(scope, subtraction_result, logits_softmax); + grad = Add(scope, grad, multiply_result); + } + auto minus_log_softmax = Multiply(scope, LogSoftmax(scope, logits), -1.0f); + grad_outputs->push_back(grad); + grad_outputs->push_back(BroadcastMul(scope, grad_loss, minus_log_softmax)); + return scope.status(); +} +REGISTER_GRADIENT_OP("SoftmaxCrossEntropyWithLogits", + SoftmaxCrossEntropyWithLogitsGrad); + Status LogSoftmaxGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { @@ -195,9 +261,9 @@ Status MaxPool3DGradHelper(const Scope& scope, const Operation& op, TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); MaxPool3DGrad::Attrs grad_attrs; - auto dx = MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0], - ksize, strides, padding, - grad_attrs.DataFormat(data_format)); + auto dx = + MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0], ksize, + strides, padding, grad_attrs.DataFormat(data_format)); grad_outputs->push_back(dx); return scope.status(); } @@ -216,10 +282,9 @@ Status AvgPoolGradHelper(const Scope& scope, const Operation& op, TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); internal::AvgPoolGrad::Attrs grad_attrs; - auto dx = - internal::AvgPoolGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], - ksize, strides, padding, - grad_attrs.DataFormat(data_format)); + auto dx = internal::AvgPoolGrad(scope, Shape(scope, op.input(0)), + grad_inputs[0], ksize, strides, padding, + grad_attrs.DataFormat(data_format)); grad_outputs->push_back(dx); return scope.status(); } @@ -238,9 +303,9 @@ Status AvgPool3DGradHelper(const Scope& scope, const Operation& op, TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); AvgPool3DGrad::Attrs grad_attrs; - auto dx = AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], - ksize, strides, padding, - grad_attrs.DataFormat(data_format)); + auto dx = + AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], ksize, + strides, padding, grad_attrs.DataFormat(data_format)); grad_outputs->push_back(dx); return scope.status(); } diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index b4d457a9d14eb79232cda9412fa0050f6a9968cc..aa72cf7ba2a958f54d50b59f0edaefb27edf0e86 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -25,6 +25,8 @@ limitations under the License. namespace tensorflow { namespace { +using ops::AvgPool; +using ops::AvgPool3D; using ops::BiasAdd; using ops::Conv2D; using ops::Elu; @@ -33,11 +35,9 @@ using ops::FractionalMaxPool; using ops::L2Loss; using ops::LogSoftmax; using ops::LRN; -using ops::AvgPool; -using ops::AvgPool3D; using ops::MaxPool; -using ops::MaxPoolV2; using ops::MaxPool3D; +using ops::MaxPoolV2; using ops::Placeholder; using ops::Relu; using ops::Relu6; @@ -111,6 +111,20 @@ TEST_F(NNGradTest, SoftmaxGrad) { RunTest(x, shape, y, shape); } +TEST_F(NNGradTest, SoftmaxCrossEntropyWithLogitsGrad) { + TensorShape logits_shape({5, 3}); + TensorShape loss_shape({5}); + + auto logits = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(logits_shape)); + auto labels = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(logits_shape)); + auto y = + tensorflow::ops::SoftmaxCrossEntropyWithLogits(scope_, logits, labels); + // Note the reversal of the backprop and loss orders. Issue #18734 has been + // opened for this. + RunTest({logits, labels}, {logits_shape, logits_shape}, {y.backprop, y.loss}, + {logits_shape, loss_shape}); +} + TEST_F(NNGradTest, LogSoftmaxGrad) { TensorShape shape({5, 3}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); @@ -253,7 +267,7 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) { RunTest(x, x_shape, y, y_shape); } -TEST_F(NNGradTest, LRN){ +TEST_F(NNGradTest, LRN) { TensorShape x_shape({1, 1, 2, 1}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); auto y = LRN(scope_, x); diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 06a3be18e08f611d3ecf9804908d791d15fdab13..3d3895c8fa82c3c0e2974228e9cad767d0e00df4 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -33,6 +33,46 @@ cc_library( hdrs = ["tag_constants.h"], ) +cc_library( + name = "reader", + srcs = ["reader.cc"], + hdrs = ["reader.h"], + deps = [ + ":constants", + ] + if_not_mobile([ + # TODO(b/111634734): :lib and :protos_all contain dependencies that + # cannot be built on mobile platforms. Instead, include the appropriate + # tf_lib depending on the build platform. + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ]) + if_mobile([ + # Mobile-friendly SavedModel proto. See go/portable-proto for more info. + "//tensorflow/core:saved_model_portable_proto", + ]) + if_android([ + "//tensorflow/core:android_tensorflow_lib", + ]) + if_ios([ + "//tensorflow/core:ios_tensorflow_lib", + ]), +) + +tf_cc_test( + name = "reader_test", + srcs = ["reader_test.cc"], + data = [ + ":saved_model_half_plus_two", + ], + linkstatic = 1, + deps = [ + ":constants", + ":reader", + ":tag_constants", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_library( name = "loader", hdrs = ["loader.h"], @@ -54,6 +94,7 @@ cc_library( hdrs = ["loader.h"], deps = [ ":constants", + ":reader", ] + if_not_mobile([ "//tensorflow/core:core_cpu", "//tensorflow/core:framework", diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index faa1e378d07ea94ad08ee084d18bf6a113f054af..3830416159158cca8bfb8422c2959b49fa42406d 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -18,8 +18,10 @@ limitations under the License. #include #include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/reader.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf_internal.h" @@ -43,56 +45,6 @@ auto* load_latency = monitoring::Counter<1>::New( constexpr char kLoadAttemptFail[] = "fail"; constexpr char kLoadAttemptSuccess[] = "success"; -Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) { - const string saved_model_pb_path = - io::JoinPath(export_dir, kSavedModelFilenamePb); - if (Env::Default()->FileExists(saved_model_pb_path).ok()) { - return ReadBinaryProto(Env::Default(), saved_model_pb_path, - saved_model_proto); - } - const string saved_model_pbtxt_path = - io::JoinPath(export_dir, kSavedModelFilenamePbTxt); - if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) { - return ReadTextProto(Env::Default(), saved_model_pbtxt_path, - saved_model_proto); - } - return Status(error::Code::NOT_FOUND, - "Could not find SavedModel .pb or .pbtxt at supplied export " - "directory path: " + - export_dir); -} - -string GetTagsAsString(const std::unordered_set& tags) { - string tags_as_string = "{ "; - for (const string& tag : tags) { - tags_as_string = strings::StrCat(tags_as_string, tag, " "); - } - tags_as_string = strings::StrCat(tags_as_string, "}"); - return tags_as_string; -} - -Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, - const std::unordered_set& tags, - MetaGraphDef* meta_graph_def_to_load) { - for (const MetaGraphDef& meta_graph_def : saved_model_proto.meta_graphs()) { - // Get tags from the meta_graph_def. - std::unordered_set graph_tags; - for (const string& tag : meta_graph_def.meta_info_def().tags()) { - graph_tags.insert(tag); - } - // Match with the set of tags provided. - if (graph_tags == tags) { - *meta_graph_def_to_load = meta_graph_def; - return Status::OK(); - } - } - return Status(error::Code::NOT_FOUND, - "Could not find meta graph def matching supplied tags: " + - GetTagsAsString(tags) + - ". To inspect available tag-sets in the SavedModel, please " - "use the SavedModel CLI: `saved_model_cli`"); -} - Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, const SessionOptions& session_options, std::unique_ptr* session) { @@ -122,6 +74,54 @@ void AddAssetsTensorsToInputs(const StringPiece export_dir, } } +// Like Session::Run(), but uses the Make/Run/ReleaseCallable() API to avoid +// leaving behind non-GC'ed state. +// +// Detailed motivation behind this approach, from ashankar@: +// +// Each call to Session::Run() that identifies a new subgraph (based on feeds +// and fetches) creates some datastructures that live as long as the session +// (the partitioned graph, associated executors etc.). +// +// A pathological case of this would be if say the initialization op +// (main_op/legacy_init_op) involves the use of a large constant. Then we +// allocate memory for that large constant that will just stick around till the +// session dies. With this Callable mechanism, that memory will be released +// right after ReleaseCallable returns. +// +// However, the resource manager state remains. +Status RunOnce(const RunOptions& run_options, + const std::vector>& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs, RunMetadata* run_metadata, + Session* session) { + CallableOptions callable_options; + std::vector feed_tensors; + *callable_options.mutable_run_options() = run_options; + for (const auto& input : inputs) { + const string& name = input.first; + const Tensor& tensor = input.second; + callable_options.add_feed(name); + feed_tensors.push_back(tensor); + } + for (const string& output_tensor_name : output_tensor_names) { + callable_options.add_fetch(output_tensor_name); + } + for (const string& target_node_name : target_node_names) { + callable_options.add_target(target_node_name); + } + + Session::CallableHandle callable_handle; + TF_RETURN_IF_ERROR(session->MakeCallable(callable_options, &callable_handle)); + const Status run_status = session->RunCallable(callable_handle, feed_tensors, + outputs, run_metadata); + // Be sure to call ReleaseCallable() regardless of the outcome of + // RunCallable(). + session->ReleaseCallable(callable_handle).IgnoreError(); + return run_status; +} + bool HasMainOp(const MetaGraphDef& meta_graph_def) { const auto& collection_def_map = meta_graph_def.collection_def(); if (collection_def_map.find(kSavedModelMainOpKey) != @@ -134,10 +134,11 @@ bool HasMainOp(const MetaGraphDef& meta_graph_def) { Status RunMainOp(const RunOptions& run_options, const string& export_dir, const MetaGraphDef& meta_graph_def, const std::vector& asset_file_defs, - Session* session) { - LOG(INFO) << "Running MainOp on SavedModel bundle."; + Session* session, const string& main_op_key) { + LOG(INFO) << "Running MainOp with key " << main_op_key + << " on SavedModel bundle."; const auto& collection_def_map = meta_graph_def.collection_def(); - const auto main_op_it = collection_def_map.find(kSavedModelMainOpKey); + const auto main_op_it = collection_def_map.find(main_op_key); if (main_op_it != collection_def_map.end()) { if (main_op_it->second.node_list().value_size() != 1) { return errors::FailedPrecondition( @@ -147,8 +148,8 @@ Status RunMainOp(const RunOptions& run_options, const string& export_dir, AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; const StringPiece main_op_name = main_op_it->second.node_list().value(0); - return session->Run(run_options, inputs, {}, {main_op_name.ToString()}, - nullptr /* outputs */, &run_metadata); + return RunOnce(run_options, inputs, {}, {main_op_name.ToString()}, + nullptr /* outputs */, &run_metadata, session); } return Status::OK(); } @@ -169,7 +170,8 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, variables_directory, MetaFilename(kSavedModelVariablesFilename)); if (!Env::Default()->FileExists(variables_index_path).ok()) { LOG(INFO) << "The specified SavedModel has no variables; no checkpoints " - "were restored."; + "were restored. File does not exist: " + << variables_index_path; return Status::OK(); } const string variables_path = @@ -185,32 +187,8 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; - return session->Run(run_options, inputs, {}, {restore_op_name.ToString()}, - nullptr /* outputs */, &run_metadata); -} - -Status RunLegacyInitOp(const RunOptions& run_options, const string& export_dir, - const MetaGraphDef& meta_graph_def, - const std::vector& asset_file_defs, - Session* session) { - LOG(INFO) << "Running LegacyInitOp on SavedModel bundle."; - const auto& collection_def_map = meta_graph_def.collection_def(); - const auto init_op_it = collection_def_map.find(kSavedModelLegacyInitOpKey); - if (init_op_it != collection_def_map.end()) { - if (init_op_it->second.node_list().value_size() != 1) { - return errors::FailedPrecondition(strings::StrCat( - "Expected exactly one serving init op in : ", export_dir)); - } - std::vector> inputs; - AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); - RunMetadata run_metadata; - const StringPiece legacy_init_op_name = - init_op_it->second.node_list().value(0); - return session->Run(run_options, inputs, {}, - {legacy_init_op_name.ToString()}, nullptr /* outputs */, - &run_metadata); - } - return Status::OK(); + return RunOnce(run_options, inputs, {}, {restore_op_name.ToString()}, + nullptr /* outputs */, &run_metadata, session); } Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, @@ -235,18 +213,8 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, const string& export_dir, const std::unordered_set& tags, SavedModelBundle* const bundle) { - if (!MaybeSavedModelDirectory(export_dir)) { - return Status(error::Code::NOT_FOUND, - "SavedModel not found in export directory: " + export_dir); - } - LOG(INFO) << "Loading SavedModel with tags: " << GetTagsAsString(tags) - << "; from: " << export_dir; - - SavedModel saved_model_proto; - TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto)); - - TF_RETURN_IF_ERROR( - FindMetaGraphDefToLoad(saved_model_proto, tags, &bundle->meta_graph_def)); + TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags, + &bundle->meta_graph_def)); TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession( bundle->meta_graph_def, session_options, &bundle->session)); @@ -262,11 +230,11 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, if (HasMainOp(bundle->meta_graph_def)) { TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir, bundle->meta_graph_def, asset_file_defs, - bundle->session.get())); + bundle->session.get(), kSavedModelMainOpKey)); } else { - TF_RETURN_IF_ERROR(RunLegacyInitOp(run_options, export_dir, - bundle->meta_graph_def, asset_file_defs, - bundle->session.get())); + TF_RETURN_IF_ERROR(RunMainOp( + run_options, export_dir, bundle->meta_graph_def, asset_file_defs, + bundle->session.get(), kSavedModelLegacyInitOpKey)); } return Status::OK(); } @@ -288,8 +256,8 @@ Status LoadSavedModel(const SessionOptions& session_options, return end_microseconds - start_microseconds; }(); auto log_and_count = [&](const string& status_str) { - LOG(INFO) << "SavedModel load for tags " << GetTagsAsString(tags) - << "; Status: " << status_str << ". Took " + LOG(INFO) << "SavedModel load for tags { " << str_util::Join(tags, " ") + << " }; Status: " << status_str << ". Took " << load_latency_microsecs << " microseconds."; load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1); }; diff --git a/tensorflow/cc/saved_model/reader.cc b/tensorflow/cc/saved_model/reader.cc new file mode 100644 index 0000000000000000000000000000000000000000..2146c8a19745fa9ea2484c4bb4a2104a38d85144 --- /dev/null +++ b/tensorflow/cc/saved_model/reader.cc @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/saved_model/reader.h" + +#include + +#include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/protobuf/saved_model.pb.h" + +namespace tensorflow { +namespace { + +Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) { + LOG(INFO) << "Reading SavedModel from: " << export_dir; + + const string saved_model_pb_path = + io::JoinPath(export_dir, kSavedModelFilenamePb); + if (Env::Default()->FileExists(saved_model_pb_path).ok()) { + return ReadBinaryProto(Env::Default(), saved_model_pb_path, + saved_model_proto); + } + const string saved_model_pbtxt_path = + io::JoinPath(export_dir, kSavedModelFilenamePbTxt); + if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) { + return ReadTextProto(Env::Default(), saved_model_pbtxt_path, + saved_model_proto); + } + return Status(error::Code::NOT_FOUND, + "Could not find SavedModel .pb or .pbtxt at supplied export " + "directory path: " + + export_dir); +} + +Status FindMetaGraphDef(const SavedModel& saved_model_proto, + const std::unordered_set& tags, + MetaGraphDef* meta_graph_def) { + LOG(INFO) << "Reading meta graph with tags { " << str_util::Join(tags, " ") + << " }"; + for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) { + // Get tags from the graph_def. + std::unordered_set graph_tags; + for (const string& tag : graph_def.meta_info_def().tags()) { + graph_tags.insert(tag); + } + // Match with the set of tags provided. + if (graph_tags == tags) { + *meta_graph_def = graph_def; + return Status::OK(); + } + } + return Status( + error::Code::NOT_FOUND, + strings::StrCat( + "Could not find meta graph def matching supplied tags: { ", + str_util::Join(tags, " "), + " }. To inspect available tag-sets in the SavedModel, please " + "use the SavedModel CLI: `saved_model_cli`")); +} + +} // namespace + +Status ReadMetaGraphDefFromSavedModel(const string& export_dir, + const std::unordered_set& tags, + MetaGraphDef* const meta_graph_def) { + SavedModel saved_model_proto; + TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto)); + TF_RETURN_IF_ERROR(FindMetaGraphDef(saved_model_proto, tags, meta_graph_def)); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/cc/saved_model/reader.h b/tensorflow/cc/saved_model/reader.h new file mode 100644 index 0000000000000000000000000000000000000000..5815108df2a1883b6618e801f30c1915cde8c895 --- /dev/null +++ b/tensorflow/cc/saved_model/reader.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/// Functions to read the SavedModel proto, or parts of it. + +#ifndef TENSORFLOW_CC_SAVED_MODEL_READER_H_ +#define TENSORFLOW_CC_SAVED_MODEL_READER_H_ + +#include +#include + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { + +// Reads the SavedModel proto from saved_model.pb(txt) in the given directory, +// finds the MetaGraphDef that matches the given set of tags and writes it to +// the `meta_graph_def` parameter. Returns a failure status when the SavedModel +// file does not exist or no MetaGraphDef matches the tags. +Status ReadMetaGraphDefFromSavedModel(const string& export_dir, + const std::unordered_set& tags, + MetaGraphDef* const meta_graph_def); + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_READER_H_ diff --git a/tensorflow/cc/saved_model/reader_test.cc b/tensorflow/cc/saved_model/reader_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..620e9c2eece886c9600a8c93cede3b132ccbccaa --- /dev/null +++ b/tensorflow/cc/saved_model/reader_test.cc @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/saved_model/reader.h" + +#include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/tag_constants.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +constexpr char kTestDataPbTxt[] = + "cc/saved_model/testdata/half_plus_two_pbtxt/00000123"; +constexpr char kTestDataSharded[] = + "cc/saved_model/testdata/half_plus_two/00000123"; + +class ReaderTest : public ::testing::Test { + protected: + ReaderTest() {} + + void CheckMetaGraphDef(const MetaGraphDef& meta_graph_def) { + const auto& tags = meta_graph_def.meta_info_def().tags(); + EXPECT_TRUE(std::find(tags.begin(), tags.end(), kSavedModelTagServe) != + tags.end()); + EXPECT_NE(meta_graph_def.meta_info_def().tensorflow_version(), ""); + EXPECT_EQ( + meta_graph_def.signature_def().at("serving_default").method_name(), + "tensorflow/serving/predict"); + } +}; + +TEST_F(ReaderTest, TagMatch) { + MetaGraphDef meta_graph_def; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe}, + &meta_graph_def)); + CheckMetaGraphDef(meta_graph_def); +} + +TEST_F(ReaderTest, NoTagMatch) { + MetaGraphDef meta_graph_def; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"}, + &meta_graph_def); + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(str_util::StrContains( + st.error_message(), + "Could not find meta graph def matching supplied tags: { missing-tag }")) + << st.error_message(); +} + +TEST_F(ReaderTest, NoTagMatchMultiple) { + MetaGraphDef meta_graph_def; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + Status st = ReadMetaGraphDefFromSavedModel( + export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def); + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(str_util::StrContains( + st.error_message(), + "Could not find meta graph def matching supplied tags: ")) + << st.error_message(); +} + +TEST_F(ReaderTest, PbtxtFormat) { + MetaGraphDef meta_graph_def; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt); + TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe}, + &meta_graph_def)); + CheckMetaGraphDef(meta_graph_def); +} + +TEST_F(ReaderTest, InvalidExportPath) { + MetaGraphDef meta_graph_def; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path"); + Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe}, + &meta_graph_def); + EXPECT_FALSE(st.ok()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 2119c8ec47f941a76e81346ae5d20da78eae11a3..1899a32e4dc5487875f091fece6acf0c44c9243f 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -8,28 +8,6 @@ load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -# Optional runtime utilities for use by code generated by tfcompile. -cc_library( - name = "runtime", - srcs = ["runtime.cc"], - hdrs = ["runtime.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework_lite", - ], -) - -tf_cc_test( - name = "runtime_test", - srcs = ["runtime_test.cc"], - deps = [ - ":runtime", - "//tensorflow/core:framework", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - # Don't depend on this directly; this is only used for the benchmark test # generated by tf_library. cc_library( @@ -53,9 +31,9 @@ cc_library( ], deps = [ ":embedded_protocol_buffers", - ":runtime", # needed by codegen to print aligned_buffer_bytes "//tensorflow/compiler/tf2xla", "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/tf2xla:tf2xla_proto", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -68,7 +46,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:compile_only_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service/cpu:buffer_info_util", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework_internal", @@ -237,7 +217,6 @@ test_suite( tests = [ ":benchmark_test", ":codegen_test", - ":runtime_test", ":test_graph_tfadd_test", ":test_graph_tfunknownop2_test", ":test_graph_tfunknownop3_test", diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 0025842aead53973befc794378a26fa8db2ae1cb..89fefdad54fabcc953e72c6aa7a2361468b61259 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -20,10 +20,11 @@ limitations under the License. #include #include "tensorflow/compiler/aot/embedded_protocol_buffers.h" -#include "tensorflow/compiler/aot/runtime.h" +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/compiler/tf2xla/str_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" @@ -36,6 +37,8 @@ namespace tfcompile { namespace { +using BufferInfo = cpu_function_runtime::BufferInfo; + bool IsAlpha(char c) { return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z'); } @@ -85,27 +88,36 @@ Status XLATypeToCpp(xla::PrimitiveType type, string* str) { return Status::OK(); } -// total_buffer_bytes returns the sum of each size in `sizes`, skipping -1 -// values. There are `n` entries in `sizes`. -size_t total_buffer_bytes(const intptr_t* sizes, size_t n) { - size_t total = 0; - for (size_t i = 0; i < n; ++i) { - if (sizes[i] != -1) { - total += sizes[i]; - } - } - return total; +// Returns the sum of the size of each buffer in `buffer_infos`. +size_t TotalBufferBytes(const std::vector& buffer_infos) { + return std::accumulate(buffer_infos.begin(), buffer_infos.end(), size_t{0}, + [](size_t size, const BufferInfo& buffer_info) { + return size + buffer_info.size(); + }); } -// Fills in arg_sizes with the byte size of each positional arg. -Status ComputeArgSizes(const CompileResult& compile_result, - std::vector* arg_sizes) { - const xla::ProgramShape& ps = compile_result.program_shape; - for (int i = 0; i < ps.parameters_size(); ++i) { - arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf( - ps.parameters(i), compile_result.pointer_size)); - } - return Status::OK(); +// Returns a vector of BufferInfo instances in `buffer_infos` that are entry +// parameter buffers. +std::vector ExtractEntryParamBufferInfos( + const std::vector& buffer_infos) { + std::vector result; + std::copy_if(buffer_infos.begin(), buffer_infos.end(), + std::back_inserter(result), [](const BufferInfo& buffer_info) { + return buffer_info.is_entry_parameter(); + }); + return result; +} + +// Returns a vector of BufferInfo instances in `buffer_infos` that are temp +// buffers. +std::vector ExtractTempBufferInfos( + const std::vector& buffer_infos) { + std::vector result; + std::copy_if(buffer_infos.begin(), buffer_infos.end(), + std::back_inserter(result), [](const BufferInfo& buffer_info) { + return buffer_info.is_temp_buffer(); + }); + return result; } // Add (from,to) rewrite pairs based on the given shape. These rewrite pairs @@ -278,6 +290,25 @@ Status ValidateFeedFetchCppNames(const tf2xla::Config& config) { return Status::OK(); } +// Returns a list of C++ expressions that, when executed, will construct the +// BufferInfo instances in `buffer_infos`. +std::vector BufferInfosToCppExpression( + const std::vector& buffer_infos) { + std::vector buffer_infos_as_strings; + std::transform(buffer_infos.begin(), buffer_infos.end(), + std::back_inserter(buffer_infos_as_strings), + [](const BufferInfo& buffer_info) { + std::pair encoded = buffer_info.Encode(); + string encoded_second_as_str = + encoded.second == ~0ULL + ? "~0ULL" + : strings::StrCat(encoded.second, "ULL"); + return strings::StrCat( + "::tensorflow::cpu_function_runtime::BufferInfo({", + encoded.first, "ULL, ", encoded_second_as_str, "})"); + }); + return buffer_infos_as_strings; +} } // namespace Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, @@ -286,29 +317,35 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, TF_RETURN_IF_ERROR(ValidateConfig(config)); TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config)); const int64 result_index = compile_result.aot->result_buffer_index(); - const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes(); - if (result_index < 0 || result_index > temp_sizes.size()) { + const std::vector& buffer_infos = + compile_result.aot->buffer_infos(); + const std::vector arg_index_table = + ::xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos); + std::vector buffer_infos_as_strings = + BufferInfosToCppExpression(buffer_infos); + if (result_index < 0 || result_index >= buffer_infos.size()) { return errors::InvalidArgument("result index: ", result_index, " is outside the range of temp sizes: [0,", - temp_sizes.size(), ")"); + buffer_infos.size(), ")"); } // Compute sizes and generate methods. - std::vector arg_sizes; - TF_RETURN_IF_ERROR(ComputeArgSizes(compile_result, &arg_sizes)); + std::vector buffer_infos_for_args = + ExtractEntryParamBufferInfos(buffer_infos); + std::vector buffer_infos_for_temps = + ExtractTempBufferInfos(buffer_infos); const xla::ProgramShape& ps = compile_result.program_shape; string methods_arg, methods_result; TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg)); TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result)); - const std::vector iarg(arg_sizes.begin(), arg_sizes.end()); - const std::vector itemp(temp_sizes.begin(), temp_sizes.end()); - const size_t arg_bytes_aligned = - runtime::aligned_buffer_bytes(iarg.data(), iarg.size()); - const size_t arg_bytes_total = total_buffer_bytes(iarg.data(), iarg.size()); - const size_t temp_bytes_aligned = - runtime::aligned_buffer_bytes(itemp.data(), itemp.size()); - const size_t temp_bytes_total = - total_buffer_bytes(itemp.data(), itemp.size()); + const size_t arg_bytes_aligned = cpu_function_runtime::AlignedBufferBytes( + buffer_infos_for_args.data(), buffer_infos_for_args.size(), + /*allocate_entry_params=*/true); + const size_t arg_bytes_total = TotalBufferBytes(buffer_infos_for_args); + const size_t temp_bytes_aligned = cpu_function_runtime::AlignedBufferBytes( + buffer_infos_for_temps.data(), buffer_infos_for_temps.size(), + /*allocate_entry_params=*/true); + const size_t temp_bytes_total = TotalBufferBytes(buffer_infos_for_temps); // Create rewrite strings for namespace start and end. string ns_start; @@ -343,8 +380,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, // calling HloProfilePrinter::profile_counters_size. const string assign_profile_counters_size = opts.gen_hlo_profile_printer_data - ? "data->profile_counters_size = " - "data->hlo_profile_printer_data->profile_counters_size();" + ? "data->set_profile_counters_size(" + "data->hlo_profile_printer_data()->profile_counters_size());" : ""; // Use a poor-man's text templating mechanism; first populate the full header @@ -414,9 +451,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { static constexpr size_t kNumArgs = {{ARG_NUM}}; // Byte size of each argument buffer. There are kNumArgs entries. - static const intptr_t* ArgSizes() { - static constexpr intptr_t kArgSizes[kNumArgs] = {{{ARG_SIZES}}}; - return kArgSizes; + static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) { + return BufferInfos()[ArgIndexToBufferIndex()[index]].size(); } // Returns static data used to create an XlaCompiledCpuFunction. @@ -424,17 +460,17 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ XlaCompiledCpuFunction::StaticData* data = new XlaCompiledCpuFunction::StaticData; - data->raw_function = {{ENTRY}}; - data->arg_sizes = ArgSizes(); - data->num_args = kNumArgs; - data->temp_sizes = TempSizes(); - data->num_temps = kNumTemps; - data->result_index = kResultIndex; - data->arg_names = StaticArgNames(); - data->result_names = StaticResultNames(); - data->program_shape = StaticProgramShape(); - data->hlo_profile_printer_data = StaticHloProfilePrinterData(); - {{ASSIGN_PROFILE_COUNTERS_SIZE}} + data->set_raw_function({{ENTRY}}); + data->set_buffer_infos(BufferInfos()); + data->set_num_buffers(kNumBuffers); + data->set_arg_index_table(ArgIndexToBufferIndex()); + data->set_num_args(kNumArgs); + data->set_result_index(kResultIndex); + data->set_arg_names(StaticArgNames()); + data->set_result_names(StaticResultNames()); + data->set_program_shape(StaticProgramShape()); + data->set_hlo_profile_printer_data(StaticHloProfilePrinterData()); +{{ASSIGN_PROFILE_COUNTERS_SIZE}} return data; }(); return *kStaticData; @@ -482,17 +518,27 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {{METHODS_RESULT}} private: - // Number of result and temporary buffers for the compiled computation. - static constexpr size_t kNumTemps = {{TEMP_NUM}}; - // The 0-based index of the result tuple in the temporary buffers. - static constexpr size_t kResultIndex = {{RESULT_INDEX}}; + // Number of buffers for the compiled computation. + static constexpr size_t kNumBuffers = {{NUM_BUFFERS}}; - // Byte size of each result / temporary buffer. There are kNumTemps entries. - static const intptr_t* TempSizes() { - static constexpr intptr_t kTempSizes[kNumTemps] = {{{TEMP_SIZES}}}; - return kTempSizes; + static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() { + static const ::tensorflow::cpu_function_runtime::BufferInfo + kBufferInfos[kNumBuffers] = { +{{BUFFER_INFOS_AS_STRING}} + }; + return kBufferInfos; } + static const ::tensorflow::int32* ArgIndexToBufferIndex() { + static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = { +{{ARG_INDEX_TABLE}} + }; + return kArgIndexToBufferIndex; + } + + // The 0-based index of the result tuple in the temporary buffers. + static constexpr size_t kResultIndex = {{RESULT_INDEX}}; + // Array of names of each positional argument, terminated by nullptr. static const char** StaticArgNames() {{ARG_NAMES_CODE}} @@ -523,8 +569,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)}, {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)}, {"{{ARG_NAMES_CODE}}", arg_names_code}, - {"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())}, - {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")}, + {"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())}, + {"{{ARG_INDEX_TABLE}}", str_util::Join(arg_index_table, ", ")}, {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{CLASS}}", opts.class_name}, {"{{DECLS_FROM_OBJ_FILE}}", @@ -546,8 +592,9 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{RESULT_NAMES_CODE}}", result_names_code}, {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)}, {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, - {"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())}, - {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}}; + {"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())}, + {"{{BUFFER_INFOS_AS_STRING}}", + str_util::Join(buffer_infos_as_strings, ",\n")}}; str_util::ReplaceAllPairs(header, rewrites); return Status::OK(); } diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 29bc9c13b889c86c2ba8776c7b067c54cb05bc43..60d59ae996e8f7ec490c98aeab05182626e61976 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -32,6 +32,8 @@ namespace tensorflow { namespace tfcompile { namespace { +using ::tensorflow::cpu_function_runtime::BufferInfo; + void ExpectErrorContains(const Status& status, StringPiece str) { EXPECT_NE(Status::OK(), status); EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) @@ -171,8 +173,14 @@ TEST(CodegenTest, Golden) { fetch->mutable_id()->set_node_name("fetch0"); fetch->set_name("myfetch"); CompileResult compile_result; - compile_result.aot.reset( - new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5, {})); + compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult( + {}, + {BufferInfo::MakeTempBuffer(1), + BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0), + BufferInfo::MakeTempBuffer(2), + BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1), + BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)}, + 5, {})); compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( { xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 6641d45e83020f4144616a6a2837c844330298f5..e4d8a02877c75fa72c5747650ab9c7ac229955b3 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -65,9 +65,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { static constexpr size_t kNumArgs = 2; // Byte size of each argument buffer. There are kNumArgs entries. - static const intptr_t* ArgSizes() { - static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96}; - return kArgSizes; + static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) { + return BufferInfos()[ArgIndexToBufferIndex()[index]].size(); } // Returns static data used to create an XlaCompiledCpuFunction. @@ -75,17 +74,17 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ XlaCompiledCpuFunction::StaticData* data = new XlaCompiledCpuFunction::StaticData; - data->raw_function = entry_point; - data->arg_sizes = ArgSizes(); - data->num_args = kNumArgs; - data->temp_sizes = TempSizes(); - data->num_temps = kNumTemps; - data->result_index = kResultIndex; - data->arg_names = StaticArgNames(); - data->result_names = StaticResultNames(); - data->program_shape = StaticProgramShape(); - data->hlo_profile_printer_data = StaticHloProfilePrinterData(); - + data->set_raw_function(entry_point); + data->set_buffer_infos(BufferInfos()); + data->set_num_buffers(kNumBuffers); + data->set_arg_index_table(ArgIndexToBufferIndex()); + data->set_num_args(kNumArgs); + data->set_result_index(kResultIndex); + data->set_arg_names(StaticArgNames()); + data->set_result_names(StaticResultNames()); + data->set_program_shape(StaticProgramShape()); + data->set_hlo_profile_printer_data(StaticHloProfilePrinterData()); + return data; }(); return *kStaticData; @@ -215,17 +214,32 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { } private: - // Number of result and temporary buffers for the compiled computation. - static constexpr size_t kNumTemps = 6; - // The 0-based index of the result tuple in the temporary buffers. - static constexpr size_t kResultIndex = 5; + // Number of buffers for the compiled computation. + static constexpr size_t kNumBuffers = 6; + + static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() { + static const ::tensorflow::cpu_function_runtime::BufferInfo + kBufferInfos[kNumBuffers] = { +::tensorflow::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}), +::tensorflow::cpu_function_runtime::BufferInfo({34ULL, 0ULL}), +::tensorflow::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}), +::tensorflow::cpu_function_runtime::BufferInfo({386ULL, 1ULL}), +::tensorflow::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}), +::tensorflow::cpu_function_runtime::BufferInfo({481ULL, ~0ULL}) + }; + return kBufferInfos; + } - // Byte size of each result / temporary buffer. There are kNumTemps entries. - static const intptr_t* TempSizes() { - static constexpr intptr_t kTempSizes[kNumTemps] = {1, -1, 2, -1, 3, 120}; - return kTempSizes; + static const ::tensorflow::int32* ArgIndexToBufferIndex() { + static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = { +1, 3 + }; + return kArgIndexToBufferIndex; } + // The 0-based index of the result tuple in the temporary buffers. + static constexpr size_t kResultIndex = 5; + // Array of names of each positional argument, terminated by nullptr. static const char** StaticArgNames() { static const char* kNames[] = {"myfeed", nullptr}; diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index bbc35da2ef6d14ff0d3570ef2d5cf6743456c674..2b5f97b34cd928d32eb220536342c715d91d45bb 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/compile_only_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/aot/runtime.h b/tensorflow/compiler/aot/runtime.h deleted file mode 100644 index d1a669ceb17b9fd71d26e978035283f8824b0376..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/aot/runtime.h +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file contains utilities to make it easier to invoke functions generated -// by tfcompile. Usage of these utilities is optional. - -#ifndef TENSORFLOW_COMPILER_AOT_RUNTIME_H_ -#define TENSORFLOW_COMPILER_AOT_RUNTIME_H_ - -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace tfcompile { -namespace runtime { - -// Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. -static constexpr size_t kAlign = 64; - -// aligned_buffer_bytes returns the sum of each size in `sizes`, skipping -1 -// values. There are `n` entries in `sizes`. Each buffer is aligned to kAlign -// byte boundaries. -size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n); - -// MallocContiguousBuffers allocates buffers for use by the entry point -// generated by tfcompile. `sizes` is an array of byte sizes for each buffer, -// where -1 causes the buffer pointer to be nullptr. There are `n` entries in -// `sizes`. If `annotate_initialized` is set, the allocated memory will be -// annotated as having been initialized - this is useful when allocating -// temporary buffers. -// -// A single contiguous block of memory is allocated, and portions of it are -// parceled out into `bufs`, which must have space for `n` entries. Returns the -// head of the allocated contiguous block, which should be passed to -// FreeContiguous when the buffers are no longer in use. -void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, - bool annotate_initialized); - -// FreeContiguous frees the contiguous block of memory allocated by -// MallocContiguousBuffers. -void FreeContiguous(void* contiguous); - -} // namespace runtime -} // namespace tfcompile -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_AOT_RUNTIME_H_ diff --git a/tensorflow/compiler/aot/test.cc b/tensorflow/compiler/aot/test.cc index 6b098049cbd7539a2b2e2696b13139a8a6b28e0f..5deb47d12310d24dce847227bd119249210ffb8d 100644 --- a/tensorflow/compiler/aot/test.cc +++ b/tensorflow/compiler/aot/test.cc @@ -51,11 +51,9 @@ namespace tensorflow { namespace tfcompile { namespace { -void zero_buffers(void** bufs, const intptr_t* sizes, size_t n) { - for (int i = 0; i < n; ++i) { - if (sizes[i] != -1) { - memset(bufs[i], 0, sizes[i]); - } +void zero_buffers(XlaCompiledCpuFunction* computation) { + for (int i = 0; i < computation->num_args(); ++i) { + memset(computation->arg_data(i), 0, computation->arg_size(i)); } } @@ -66,7 +64,7 @@ TEST(TEST_NAME, NoCrash) { CPP_CLASS computation; computation.set_thread_pool(&device); - zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs); + zero_buffers(&computation); EXPECT_TRUE(computation.Run()); } @@ -80,7 +78,7 @@ void BM_NAME(int iters) { CPP_CLASS computation; computation.set_thread_pool(&device); - zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs); + zero_buffers(&computation); testing::StartTiming(); while (--iters) { diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index fee46280e9a0e7ba2cf7c3ed46469ae8cc0841d4..0c0c676ece78565e03578d3e33633c7e23b77669 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -44,8 +44,8 @@ using ::testing::IsSupersetOf; TEST(TFCompileTest, Add) { AddComp add; - EXPECT_EQ(add.arg0_data(), add.args()[0]); - EXPECT_EQ(add.arg1_data(), add.args()[1]); + EXPECT_EQ(add.arg0_data(), add.arg_data(0)); + EXPECT_EQ(add.arg1_data(), add.arg_data(1)); add.arg0() = 1; add.arg1() = 2; @@ -67,10 +67,10 @@ TEST(TFCompileTest, Add) { EXPECT_EQ(add_const.error_msg(), ""); EXPECT_EQ(add_const.arg0(), 123); EXPECT_EQ(add_const.arg0_data()[0], 123); - EXPECT_EQ(add_const.arg0_data(), add.args()[0]); + EXPECT_EQ(add_const.arg0_data(), add.arg_data(0)); EXPECT_EQ(add_const.arg1(), 456); EXPECT_EQ(add_const.arg1_data()[0], 456); - EXPECT_EQ(add_const.arg1_data(), add.args()[1]); + EXPECT_EQ(add_const.arg1_data(), add.arg_data(1)); EXPECT_EQ(add_const.result0(), 579); EXPECT_EQ(add_const.result0_data()[0], 579); EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); @@ -85,8 +85,8 @@ TEST(TFCompileTest, Add_SetArg) { int32 arg_y = 32; add.set_arg0_data(&arg_x); add.set_arg1_data(&arg_y); - EXPECT_EQ(add.arg0_data(), add.args()[0]); - EXPECT_EQ(add.arg1_data(), add.args()[1]); + EXPECT_EQ(add.arg0_data(), add.arg_data(0)); + EXPECT_EQ(add.arg1_data(), add.arg_data(1)); EXPECT_TRUE(add.Run()); EXPECT_EQ(add.error_msg(), ""); @@ -97,7 +97,7 @@ TEST(TFCompileTest, Add_SetArg) { TEST(TFCompileTest, AddWithCkpt) { AddWithCkptComp add; - EXPECT_EQ(add.arg0_data(), add.args()[0]); + EXPECT_EQ(add.arg0_data(), add.arg_data(0)); add.arg0() = 1; EXPECT_TRUE(add.Run()); @@ -117,7 +117,7 @@ TEST(TFCompileTest, AddWithCkpt) { EXPECT_EQ(add_const.error_msg(), ""); EXPECT_EQ(add_const.arg0(), 111); EXPECT_EQ(add_const.arg0_data()[0], 111); - EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]); + EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0)); EXPECT_EQ(add_const.result0(), 153); EXPECT_EQ(add_const.result0_data()[0], 153); EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); @@ -125,7 +125,7 @@ TEST(TFCompileTest, AddWithCkpt) { TEST(TFCompileTest, AddWithCkptSaver) { AddWithCkptSaverComp add; - EXPECT_EQ(add.arg0_data(), add.args()[0]); + EXPECT_EQ(add.arg0_data(), add.arg_data(0)); add.arg0() = 1; EXPECT_TRUE(add.Run()); @@ -145,7 +145,7 @@ TEST(TFCompileTest, AddWithCkptSaver) { EXPECT_EQ(add_const.error_msg(), ""); EXPECT_EQ(add_const.arg0(), 111); EXPECT_EQ(add_const.arg0_data()[0], 111); - EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]); + EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0)); EXPECT_EQ(add_const.result0(), 153); EXPECT_EQ(add_const.result0_data()[0], 153); EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); @@ -153,9 +153,9 @@ TEST(TFCompileTest, AddWithCkptSaver) { TEST(TFCompileTest, Cond) { CondComp cond; - EXPECT_EQ(cond.arg0_data(), cond.args()[0]); - EXPECT_EQ(cond.arg1_data(), cond.args()[1]); - EXPECT_EQ(cond.arg2_data(), cond.args()[2]); + EXPECT_EQ(cond.arg0_data(), cond.arg_data(0)); + EXPECT_EQ(cond.arg1_data(), cond.arg_data(1)); + EXPECT_EQ(cond.arg2_data(), cond.arg_data(2)); cond.arg1() = 10; cond.arg2() = 20; { @@ -178,8 +178,8 @@ TEST(TFCompileTest, Cond) { TEST(TFCompileTest, Gather) { GatherComp gather; - EXPECT_EQ(gather.arg0_data(), gather.args()[0]); - EXPECT_EQ(gather.arg1_data(), gather.args()[1]); + EXPECT_EQ(gather.arg0_data(), gather.arg_data(0)); + EXPECT_EQ(gather.arg1_data(), gather.arg_data(1)); // Successful gather. { @@ -202,12 +202,12 @@ TEST(TFCompileTest, Gather) { EXPECT_EQ(gather_const.arg0(i), params[i]); EXPECT_EQ(gather_const.arg0_data()[i], params[i]); } - EXPECT_EQ(gather_const.arg0_data(), gather_const.args()[0]); + EXPECT_EQ(gather_const.arg0_data(), gather_const.arg_data(0)); for (int i = 0; i < 2; ++i) { EXPECT_EQ(gather_const.arg1(i), indices[i]); EXPECT_EQ(gather_const.arg1_data()[i], indices[i]); } - EXPECT_EQ(gather_const.arg1_data(), gather_const.args()[1]); + EXPECT_EQ(gather_const.arg1_data(), gather_const.arg_data(1)); for (int i = 0; i < 2; ++i) { EXPECT_EQ(gather_const.result0(i), results[i]); EXPECT_EQ(gather_const.result0_data()[i], results[i]); @@ -222,8 +222,8 @@ TEST(TFCompileTest, MatMul2) { foo::bar::MatMulComp matmul; matmul.set_thread_pool(&device); - EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]); - EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]); + EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0)); + EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1)); // Test using the argN() methods. { @@ -271,12 +271,12 @@ TEST(TFCompileTest, MatMul2) { EXPECT_EQ(matmul_const.arg0(i / 3, i % 3), args[i]); EXPECT_EQ(matmul_const.arg0_data()[i], args[i]); } - EXPECT_EQ(matmul_const.arg0_data(), matmul.args()[0]); + EXPECT_EQ(matmul_const.arg0_data(), matmul.arg_data(0)); for (int i = 0; i < 6; ++i) { EXPECT_EQ(matmul_const.arg1(i / 2, i % 2), args[i + 6]); EXPECT_EQ(matmul_const.arg1_data()[i], args[i + 6]); } - EXPECT_EQ(matmul_const.arg1_data(), matmul.args()[1]); + EXPECT_EQ(matmul_const.arg1_data(), matmul.arg_data(1)); for (int i = 0; i < 4; ++i) { EXPECT_EQ(matmul_const.result0(i / 2, i % 2), results[i]); EXPECT_EQ(matmul_const.result0_data()[i], results[i]); @@ -300,8 +300,8 @@ TEST(TFCompileTest, MatMul2_SetArg) { float arg1[3][2] = {{7, 8}, {9, 10}, {11, 12}}; matmul.set_arg0_data(&arg0); matmul.set_arg1_data(&arg1); - EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]); - EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]); + EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0)); + EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1)); EXPECT_TRUE(matmul.Run()); EXPECT_EQ(matmul.error_msg(), ""); @@ -319,8 +319,8 @@ TEST(TFCompileTest, MatMulAndAdd1) { MatMulAndAddComp muladd; muladd.set_thread_pool(&device); - EXPECT_EQ(muladd.arg0_data(), muladd.args()[0]); - EXPECT_EQ(muladd.arg1_data(), muladd.args()[1]); + EXPECT_EQ(muladd.arg0_data(), muladd.arg_data(0)); + EXPECT_EQ(muladd.arg1_data(), muladd.arg_data(1)); // Test methods with positional args and results. { @@ -346,12 +346,12 @@ TEST(TFCompileTest, MatMulAndAdd1) { EXPECT_EQ(muladd_const.arg0(i / 2, i % 2), args[i]); EXPECT_EQ(muladd_const.arg0_data()[i], args[i]); } - EXPECT_EQ(muladd_const.arg0_data(), muladd.args()[0]); + EXPECT_EQ(muladd_const.arg0_data(), muladd.arg_data(0)); for (int i = 0; i < 4; ++i) { EXPECT_EQ(muladd_const.arg1(i / 2, i % 2), args[i + 4]); EXPECT_EQ(muladd_const.arg1_data()[i], args[i + 4]); } - EXPECT_EQ(muladd_const.arg1_data(), muladd.args()[1]); + EXPECT_EQ(muladd_const.arg1_data(), muladd.arg_data(1)); for (int i = 0; i < 4; ++i) { EXPECT_EQ(muladd_const.result0(i / 2, i % 2), results0[i]); EXPECT_EQ(muladd_const.result0_data()[i], results0[i]); @@ -387,12 +387,12 @@ TEST(TFCompileTest, MatMulAndAdd1) { EXPECT_EQ(muladd_const.arg_x(i / 2, i % 2), args[i]); EXPECT_EQ(muladd_const.arg_x_data()[i], args[i]); } - EXPECT_EQ(muladd_const.arg_x_data(), muladd.args()[0]); + EXPECT_EQ(muladd_const.arg_x_data(), muladd.arg_data(0)); for (int i = 0; i < 4; ++i) { EXPECT_EQ(muladd_const.arg_y(i / 2, i % 2), args[i + 4]); EXPECT_EQ(muladd_const.arg_y_data()[i], args[i + 4]); } - EXPECT_EQ(muladd_const.arg_y_data(), muladd.args()[1]); + EXPECT_EQ(muladd_const.arg_y_data(), muladd.arg_data(1)); for (int i = 0; i < 4; ++i) { EXPECT_EQ(muladd_const.result_x_y_prod(i / 2, i % 2), results0[i]); EXPECT_EQ(muladd_const.result_x_y_prod_data()[i], results0[i]); @@ -407,8 +407,8 @@ TEST(TFCompileTest, MatMulAndAdd1) { TEST(TFCompileTest, Function) { // The function is equivalent to an addition FunctionComp add_fn; - EXPECT_EQ(add_fn.arg0_data(), add_fn.args()[0]); - EXPECT_EQ(add_fn.arg1_data(), add_fn.args()[1]); + EXPECT_EQ(add_fn.arg0_data(), add_fn.arg_data(0)); + EXPECT_EQ(add_fn.arg1_data(), add_fn.arg_data(1)); add_fn.arg0() = 1; add_fn.arg1() = 2; @@ -451,8 +451,8 @@ TEST(TFCompileTest, AssertEqAndReturnDiff) { // Assert is converted into a no-op in XLA, so there is no failure even if the // two args are different. AssertComp assert; - EXPECT_EQ(assert.arg0_data(), assert.args()[0]); - EXPECT_EQ(assert.arg1_data(), assert.args()[1]); + EXPECT_EQ(assert.arg0_data(), assert.arg_data(0)); + EXPECT_EQ(assert.arg1_data(), assert.arg_data(1)); assert.arg0() = 2; assert.arg1() = 1; diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 5c57fee326ca743dcb8aaae354d261ed4d7f44be..326f73b975aec3a7a6bc7cdc9a92f540ad545ad6 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -16,339 +16,365 @@ tf_library( ) """ -load("//tensorflow:tensorflow.bzl", - "if_android", "tf_cc_test", "tf_copts") - -def tf_library(name, graph, config, - freeze_checkpoint=None, freeze_saver=None, - cpp_class=None, gen_test=True, gen_benchmark=True, - visibility=None, testonly=None, - tfcompile_flags=None, - tfcompile_tool="//tensorflow/compiler/aot:tfcompile", - include_standard_runtime_deps=True, - enable_xla_hlo_profiling=False, deps=None, tags=None): - """Runs tfcompile to compile a TensorFlow graph into executable code. - - Given an invocation of tf_library(name="foo", ...), generates the following - build targets: - foo: A cc_library containing the generated header and computation. - foo_test: A cc_test with simple tests and benchmarks. Only created if - gen_test=True. - foo_benchmark: A cc_binary that runs a minimal-dependency benchmark, useful - for mobile devices or other platforms that can't compile the - full test libraries. Only created if gen_benchmark=True. - - Args: - name: The name of the build rule. - graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it - is expected to be in the human-readable proto text format, otherwise it is - expected to be in the proto binary format. - config: File containing tensorflow.tf2xla.Config proto. If the file ends - in '.pbtxt' it is expected to be in the human-readable proto text format, - otherwise it is expected to be in the proto binary format. - freeze_checkpoint: If provided, run freeze_graph with this checkpoint to - convert variables into constants. - freeze_saver: If provided, run freeze_graph with this saver, in SaverDef - binary form, to convert variables into constants. - cpp_class: The name of the generated C++ class, wrapping the generated - function. The syntax of this flag is - [[::],...]. This mirrors the C++ syntax - for referring to a class, where multiple namespaces may precede the class - name, separated by double-colons. The class will be generated in the - given namespace(s), or if no namespaces are given, within the global - namespace. - gen_test: If True, also generate a cc_test rule that builds a simple - test and benchmark. - gen_benchmark: If True, also generate a binary with a simple benchmark. - Unlike the output of gen_test, this benchmark can be run on android. - visibility: Bazel build visibility. - testonly: Bazel testonly attribute. - tfcompile_flags: Extra flags to pass to tfcompile to control compilation. - tfcompile_tool: The tfcompile binary. A non-default can be passed to - use a tfcompile built with extra dependencies. - include_standard_runtime_deps: If True, the standard list of kernel/runtime - deps is added to deps. If False, deps must contain the full set of deps - needed by the generated library. - enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated program, - and emit metadata that lets us pretty-print the gathered profile counters. - deps: a list of deps to include on the build rules for the generated - library, added to the standard deps if standard_runtime_deps is True. - tags: tags to apply to subsidiary build rules. - - The output header is called .h. - """ - if not cpp_class: - fail("cpp_class must be specified") - - tfcompile_graph = graph - if freeze_checkpoint or freeze_saver: - if not freeze_checkpoint: - fail("freeze_checkpoint must be specified when freeze_saver is specified") +load( + "//tensorflow:tensorflow.bzl", + "if_android", + "tf_cc_test", + "tf_copts", +) - freeze_name = "freeze_" + name - freeze_file = freeze_name + ".pb" +def tf_library( + name, + graph, + config, + freeze_checkpoint = None, + freeze_saver = None, + cpp_class = None, + gen_test = True, + gen_benchmark = True, + visibility = None, + testonly = None, + tfcompile_flags = None, + tfcompile_tool = "//tensorflow/compiler/aot:tfcompile", + include_standard_runtime_deps = True, + enable_xla_hlo_profiling = False, + deps = None, + tags = None): + """Runs tfcompile to compile a TensorFlow graph into executable code. - # First run tfcompile to generate the list of out_nodes. - out_nodes_file = "out_nodes_" + freeze_name - native.genrule( - name=("gen_" + out_nodes_file), - srcs=[config], - outs=[out_nodes_file], - cmd=("$(location " + tfcompile_tool + ")" + - " --config=$(location " + config + ")" + - " --dump_fetch_nodes > $@"), - tools=[tfcompile_tool], - # Run tfcompile on the build host, rather than forge, since it's - # typically way faster on the local machine. - local=1, - tags=tags, - ) + Given an invocation of tf_library(name="foo", ...), generates the following + build targets: + foo: A cc_library containing the generated header and + computation. + foo_test: A cc_test with simple tests and benchmarks. Only created if + gen_test=True. + foo_benchmark: A cc_binary that runs a minimal-dependency benchmark, + useful for mobile devices or other platforms that can't + compile the full test libraries. Only created if + gen_benchmark=True. + The output header is called .h. - # Now run freeze_graph to convert variables into constants. - freeze_args = (" --input_graph=$(location " + graph + ")" + - " --checkpoint_version=1" + - " --input_binary=" + str(not graph.endswith(".pbtxt")) + - " --input_checkpoint=$(location " + freeze_checkpoint + ")" + - " --output_graph=$(location " + freeze_file + ")" + - " --output_node_names=$$(<$(location " + out_nodes_file + - "))") - freeze_saver_srcs = [] - if freeze_saver: - freeze_args += " --input_saver=$(location " + freeze_saver + ")" - freeze_saver_srcs += [freeze_saver] - native.genrule( - name=freeze_name, - srcs=[ - graph, - freeze_checkpoint, - out_nodes_file, - ] + freeze_saver_srcs, - outs=[freeze_file], - cmd=("$(location //tensorflow/python/tools:freeze_graph)" + - freeze_args), - tools=["//tensorflow/python/tools:freeze_graph"], - tags=tags, - ) - tfcompile_graph = freeze_file + Args: + name: The name of the build rule. + graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' + it is expected to be in the human-readable proto text format, otherwise + it is expected to be in the proto binary format. + config: File containing tensorflow.tf2xla.Config proto. If the file ends + in '.pbtxt' it is expected to be in the human-readable proto text + format, otherwise it is expected to be in the proto binary format. + freeze_checkpoint: If provided, run freeze_graph with this checkpoint to + convert variables into constants. + freeze_saver: If provided, run freeze_graph with this saver, in SaverDef + binary form, to convert variables into constants. + cpp_class: The name of the generated C++ class, wrapping the generated + function. The syntax of this flag is + [[::],...]. This mirrors the C++ syntax + for referring to a class, where multiple namespaces may precede the + class name, separated by double-colons. The class will be generated in + the given namespace(s), or if no namespaces are given, within the global + namespace. + gen_test: If True, also generate a cc_test rule that builds a simple + test and benchmark. + gen_benchmark: If True, also generate a binary with a simple benchmark. + Unlike the output of gen_test, this benchmark can be run on android. + visibility: Bazel build visibility. + testonly: Bazel testonly attribute. + tfcompile_flags: Extra flags to pass to tfcompile to control compilation. + tfcompile_tool: The tfcompile binary. A non-default can be passed to + use a tfcompile built with extra dependencies. + include_standard_runtime_deps: If True, the standard list of + kernel/runtime deps is added to deps. If False, deps must contain the + full set of deps needed by the generated library. + enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated + program, and emit metadata that lets us pretty-print the gathered + profile counters. + deps: a list of deps to include on the build rules for the generated + library, added to the standard deps if standard_runtime_deps is True. + tags: tags to apply to subsidiary build rules. + """ + if not cpp_class: + fail("cpp_class must be specified") - # Rule that runs tfcompile to produce the header and object file. - header_file = name + ".h" - metadata_object_file = name + "_tfcompile_metadata.o" - function_object_file = name + "_tfcompile_function.o" - ep = ("__" + native.package_name() + "__" + name).replace("/", "_") - if type(tfcompile_flags) == type(""): - flags = tfcompile_flags - else: - flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])]) - if enable_xla_hlo_profiling: - profiling_flag = "--xla_hlo_profile" - else: - profiling_flag = "" - native.genrule( - name=("gen_" + name), - srcs=[ - tfcompile_graph, - config, - ], - outs=[ - header_file, - metadata_object_file, - function_object_file, - ], - cmd=("$(location " + tfcompile_tool + ")" + - " --graph=$(location " + tfcompile_graph + ")" + - " --config=$(location " + config + ")" + - " --entry_point=" + ep + - " --cpp_class=" + cpp_class + - " --target_triple=" + target_llvm_triple() + - " --out_header=$(@D)/" + header_file + - " --out_metadata_object=$(@D)/" + metadata_object_file + - " --out_function_object=$(@D)/" + function_object_file + - " " + flags + " " + profiling_flag), - tools=[tfcompile_tool], - visibility=visibility, - testonly=testonly, - # Run tfcompile on the build host since it's typically faster on the local - # machine. - # - # Note that setting the local=1 attribute on a *test target* causes the - # test infrastructure to skip that test. However this is a genrule, not a - # test target, and runs with --genrule_strategy=forced_forge, meaning the - # local=1 attribute is ignored, and the genrule is still run. - # - # https://www.bazel.io/versions/master/docs/be/general.html#genrule - local=1, - tags=tags, - ) + tfcompile_graph = graph + if freeze_checkpoint or freeze_saver: + if not freeze_checkpoint: + fail("freeze_checkpoint must be specified when freeze_saver is " + + "specified") - # Rule that runs tfcompile to produce the SessionModule proto, useful for - # debugging. TODO(b/64813587): Once the SessionModule proto is - # deterministic, move this into the main rule above. - session_module_pb = name + "_session_module.pb" - native.genrule( - name=(name + "_session_module"), - srcs=[ - tfcompile_graph, - config, - ], - outs=[ - session_module_pb, - ], - cmd=("$(location " + tfcompile_tool + ")" + - " --graph=$(location " + tfcompile_graph + ")" + - " --config=$(location " + config + ")" + - " --entry_point=" + ep + - " --cpp_class=" + cpp_class + - " --target_triple=" + target_llvm_triple() + - " --out_session_module=$(@D)/" + session_module_pb + - " " + flags), - tools=[tfcompile_tool], - visibility=visibility, - testonly=testonly, - local=1, - tags=tags, - ) + freeze_name = "freeze_" + name + freeze_file = freeze_name + ".pb" - # The cc_library rule packaging up the header and object file, and needed - # kernel implementations. - need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1) - native.cc_library( - name=name, - srcs=[function_object_file, metadata_object_file], - hdrs=[header_file], - visibility=visibility, - testonly=testonly, - deps = [ - # These deps are required by all tf_library targets even if - # include_standard_runtime_deps is False. Without them, the - # generated code will fail to compile. - "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", - "//tensorflow/core:framework_lite", - ] + (need_xla_data_proto and [ - # If we're generating the program shape, we must depend on the proto. - "//tensorflow/compiler/xla:xla_data_proto", - ] or []) + (enable_xla_hlo_profiling and [ - "//tensorflow/compiler/xla/service:hlo_profile_printer_data" - ] or []) + (include_standard_runtime_deps and [ - # TODO(cwhipkey): only depend on kernel code that the model actually needed. - "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", - "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", - "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", - "//tensorflow/compiler/xla/service/cpu:runtime_matmul", - "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", - "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", - "//third_party/eigen3", - ] or []) + (deps or []), - tags=tags, - ) + # First run tfcompile to generate the list of out_nodes. + out_nodes_file = "out_nodes_" + freeze_name + native.genrule( + name = ("gen_" + out_nodes_file), + srcs = [config], + outs = [out_nodes_file], + cmd = ("$(location " + tfcompile_tool + ")" + + " --config=$(location " + config + ")" + + " --dump_fetch_nodes > $@"), + tools = [tfcompile_tool], + # Run tfcompile on the build host, rather than forge, since it's + # typically way faster on the local machine. + local = 1, + tags = tags, + ) - # Variables used for gen_test and gen_benchmark. - no_ns_name = "" - cpp_class_split = cpp_class.rsplit("::", maxsplit=2) - if len(cpp_class_split) == 1: - no_ns_name = cpp_class_split[0] - else: - no_ns_name = cpp_class_split[1] - sed_replace = ( - "-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " + - "-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " + - "-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" ") + # Now run freeze_graph to convert variables into constants. + freeze_args = ( + " --input_graph=$(location " + graph + ")" + + " --checkpoint_version=1" + + " --input_binary=" + str(not graph.endswith(".pbtxt")) + + " --input_checkpoint=$(location " + freeze_checkpoint + ")" + + " --output_graph=$(location " + freeze_file + ")" + + " --output_node_names=$$(<$(location " + out_nodes_file + + "))" + ) + freeze_saver_srcs = [] + if freeze_saver: + freeze_args += " --input_saver=$(location " + freeze_saver + ")" + freeze_saver_srcs += [freeze_saver] + native.genrule( + name = freeze_name, + srcs = [ + graph, + freeze_checkpoint, + out_nodes_file, + ] + freeze_saver_srcs, + outs = [freeze_file], + cmd = ("$(location " + + "//tensorflow/python/tools:freeze_graph)" + + freeze_args), + tools = ["//tensorflow/python/tools:freeze_graph"], + tags = tags, + ) + tfcompile_graph = freeze_file - if gen_test: - test_name = name + "_test" - test_file = test_name + ".cc" - # Rule to rewrite test.cc to produce the test_file. + # Rule that runs tfcompile to produce the header and object file. + header_file = name + ".h" + metadata_object_file = name + "_tfcompile_metadata.o" + function_object_file = name + "_tfcompile_function.o" + ep = ("__" + native.package_name() + "__" + name).replace("/", "_") + if type(tfcompile_flags) == type(""): + flags = tfcompile_flags + else: + flags = " ".join([ + "'" + arg.replace("'", "'\\''") + "'" + for arg in (tfcompile_flags or []) + ]) + if enable_xla_hlo_profiling: + profiling_flag = "--xla_hlo_profile" + else: + profiling_flag = "" native.genrule( - name=("gen_" + test_name), - testonly=1, - srcs=[ - "//tensorflow/compiler/aot:test.cc", + name = ("gen_" + name), + srcs = [ + tfcompile_graph, + config, + ], + outs = [ header_file, + metadata_object_file, + function_object_file, ], - outs=[test_file], - cmd=("sed " + sed_replace + - " $(location //tensorflow/compiler/aot:test.cc) " + - "> $(OUTS)"), - tags=tags, - ) - - # The cc_test rule for the generated code. To ensure that this works - # reliably across build configurations, we must use tf_cc_test instead of - # native.cc_test. This is related to how we build - # //tensorflow/core:lib -- see the note in tensorflow/core/BUILD - # for more details. - tf_cc_test( - name=test_name, - srcs=[test_file], - deps=[ - ":" + name, - "//tensorflow/compiler/aot:runtime", - "//tensorflow/compiler/aot:tf_library_test_main", - "//tensorflow/compiler/xla:executable_run_options", - "//third_party/eigen3", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], - tags=tags, + cmd = ("$(location " + tfcompile_tool + ")" + + " --graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_header=$(@D)/" + header_file + + " --out_metadata_object=$(@D)/" + metadata_object_file + + " --out_function_object=$(@D)/" + function_object_file + + " " + flags + " " + profiling_flag), + tools = [tfcompile_tool], + visibility = visibility, + testonly = testonly, + # Run tfcompile on the build host since it's typically faster on the + # local machine. + # + # Note that setting the local=1 attribute on a *test target* causes the + # test infrastructure to skip that test. However this is a genrule, not + # a test target, and runs with --genrule_strategy=forced_forge, meaning + # the local=1 attribute is ignored, and the genrule is still run. + # + # https://www.bazel.io/versions/master/docs/be/general.html#genrule + local = 1, + tags = tags, ) - if gen_benchmark: - benchmark_name = name + "_benchmark" - benchmark_file = benchmark_name + ".cc" - benchmark_main = ("//tensorflow/compiler/aot:" + - "benchmark_main.template") - - # Rule to rewrite benchmark.cc to produce the benchmark_file. + # Rule that runs tfcompile to produce the SessionModule proto, useful for + # debugging. TODO(b/64813587): Once the SessionModule proto is + # deterministic, move this into the main rule above. + session_module_pb = name + "_session_module.pb" native.genrule( - name=("gen_" + benchmark_name), - srcs=[ - benchmark_main, - header_file, + name = (name + "_session_module"), + srcs = [ + tfcompile_graph, + config, ], + outs = [ + session_module_pb, + ], + cmd = ("$(location " + tfcompile_tool + ")" + + " --graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_session_module=$(@D)/" + session_module_pb + + " " + flags), + tools = [tfcompile_tool], + visibility = visibility, testonly = testonly, - outs=[benchmark_file], - cmd=("sed " + sed_replace + - " $(location " + benchmark_main + ") " + - "> $(OUTS)"), - tags=tags, + local = 1, + tags = tags, ) - # The cc_benchmark rule for the generated code. This does not need the - # tf_cc_binary since we (by deliberate design) do not depend on - # //tensorflow/core:lib. - # - # Note: to get smaller size on android for comparison, compile with: - # --copt=-fvisibility=hidden - # --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN - # --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN - native.cc_binary( - name=benchmark_name, - srcs=[benchmark_file], + # The cc_library rule packaging up the header and object file, and needed + # kernel implementations. + need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1) + native.cc_library( + name = name, + srcs = [function_object_file, metadata_object_file], + hdrs = [header_file], + visibility = visibility, testonly = testonly, - copts = tf_copts(), - linkopts = if_android(["-pie", "-s"]), - deps=[ - ":" + name, - "//tensorflow/compiler/aot:benchmark", - "//tensorflow/compiler/aot:runtime", - "//tensorflow/compiler/xla:executable_run_options", + deps = [ + # These deps are required by all tf_library targets even if + # include_standard_runtime_deps is False. Without them, the + # generated code will fail to compile. + "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", + "//tensorflow/core:framework_lite", + ] + (need_xla_data_proto and [ + # If we're generating the program shape, we must depend on the + # proto. + "//tensorflow/compiler/xla:xla_data_proto", + ] or []) + (enable_xla_hlo_profiling and [ + "//tensorflow/compiler/xla/service:hlo_profile_printer_data", + ] or []) + (include_standard_runtime_deps and [ + # TODO(cwhipkey): only depend on kernel code that the model actually + # needed. + "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", + "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", + "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_matmul", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//third_party/eigen3", - ] + if_android([ - "//tensorflow/compiler/aot:benchmark_extra_android", - ]), - tags=tags, + ] or []) + (deps or []), + tags = tags, + ) + + # Variables used for gen_test and gen_benchmark. + cpp_class_split = cpp_class.rsplit("::", maxsplit = 2) + if len(cpp_class_split) == 1: + no_ns_name = cpp_class_split[0] + else: + no_ns_name = cpp_class_split[1] + sed_replace = ( + "-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " + + "-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " + + "-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" " ) + if gen_test: + test_name = name + "_test" + test_file = test_name + ".cc" + + # Rule to rewrite test.cc to produce the test_file. + native.genrule( + name = ("gen_" + test_name), + testonly = 1, + srcs = [ + "//tensorflow/compiler/aot:test.cc", + header_file, + ], + outs = [test_file], + cmd = ( + "sed " + sed_replace + + " $(location //tensorflow/compiler/aot:test.cc) " + + "> $(OUTS)" + ), + tags = tags, + ) + + # The cc_test rule for the generated code. To ensure that this works + # reliably across build configurations, we must use tf_cc_test instead + # of native.cc_test. This is related to how we build + # //tensorflow/core:lib -- see the note in + # tensorflow/core/BUILD for more details. + tf_cc_test( + name = test_name, + srcs = [test_file], + deps = [ + ":" + name, + "//tensorflow/compiler/aot:tf_library_test_main", + "//tensorflow/compiler/xla:executable_run_options", + "//third_party/eigen3", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], + tags = tags, + ) + + if gen_benchmark: + benchmark_name = name + "_benchmark" + benchmark_file = benchmark_name + ".cc" + benchmark_main = ("//tensorflow/compiler/aot:" + + "benchmark_main.template") + + # Rule to rewrite benchmark.cc to produce the benchmark_file. + native.genrule( + name = ("gen_" + benchmark_name), + srcs = [ + benchmark_main, + header_file, + ], + testonly = testonly, + outs = [benchmark_file], + cmd = ("sed " + sed_replace + + " $(location " + benchmark_main + ") " + + "> $(OUTS)"), + tags = tags, + ) + + # The cc_benchmark rule for the generated code. This does not need the + # tf_cc_binary since we (by deliberate design) do not depend on + # //tensorflow/core:lib. + # + # Note: to get smaller size on android for comparison, compile with: + # --copt=-fvisibility=hidden + # --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN + # --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN + native.cc_binary( + name = benchmark_name, + srcs = [benchmark_file], + testonly = testonly, + copts = tf_copts(), + linkopts = if_android(["-pie", "-s"]), + deps = [ + ":" + name, + "//tensorflow/compiler/aot:benchmark", + "//tensorflow/compiler/xla:executable_run_options", + "//third_party/eigen3", + ] + if_android([ + "//tensorflow/compiler/aot:benchmark_extra_android", + ]), + tags = tags, + ) + def target_llvm_triple(): - """Returns the target LLVM triple to be used for compiling the target.""" - # TODO(toddw): Add target_triple for other targets. For details see: - # http://llvm.org/docs/doxygen/html/Triple_8h_source.html - return select({ - "//tensorflow:android_armeabi": "armv5-none-android", - "//tensorflow:android_arm": "armv7-none-android", - "//tensorflow:android_arm64": "aarch64-none-android", - "//tensorflow:android_x86": "i686-none-android", - "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", - "//tensorflow:darwin": "x86_64-none-darwin", - "//conditions:default": "x86_64-pc-linux", - }) + """Returns the target LLVM triple to be used for compiling the target.""" + + # TODO(toddw): Add target_triple for other targets. For details see: + # http://llvm.org/docs/doxygen/html/Triple_8h_source.html + return select({ + "//tensorflow:android_armeabi": "armv5-none-android", + "//tensorflow:android_arm": "armv7-none-android", + "//tensorflow:android_arm64": "aarch64-none-android", + "//tensorflow:android_x86": "i686-none-android", + "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", + "//tensorflow:darwin": "x86_64-none-darwin", + "//conditions:default": "x86_64-pc-linux", + }) diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 8c74014614789758192691ee065f92759a113a7a..e059f77563ba46d9df247c897ece7a1c14f7801a 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -160,12 +160,14 @@ cc_library( "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -176,13 +178,19 @@ cc_library( "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:fifo_queue", + "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:no_op", + "//tensorflow/core/kernels:queue_op", "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:shape_ops", "//tensorflow/core/kernels:variable_ops", + "//tensorflow/core/kernels/data:generator_dataset_op", + "//tensorflow/core/kernels/data:iterator_ops", + "//tensorflow/core/kernels/data:prefetch_dataset_op", ], ) @@ -302,13 +310,20 @@ cc_library( name = "compilation_passes", srcs = [ "build_xla_launch_ops_pass.cc", + "deadness_analysis.cc", + "deadness_analysis_internal.h", "encapsulate_subgraphs_pass.cc", "mark_for_compilation_pass.cc", + "mark_for_compilation_pass_test_helper.cc", + "partially_decluster_pass.cc", ], hdrs = [ "build_xla_launch_ops_pass.h", + "deadness_analysis.h", "encapsulate_subgraphs_pass.h", "mark_for_compilation_pass.h", + "mark_for_compilation_pass_test_helper.h", + "partially_decluster_pass.h", ], deps = [ ":common", @@ -323,6 +338,7 @@ cc_library( "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -342,6 +358,7 @@ cc_library( "//tensorflow/compiler/jit/graphcycles", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", ], @@ -371,16 +388,74 @@ tf_cc_test( ], ) +tf_cc_test( + name = "deadness_analysis_test", + size = "small", + srcs = [ + "deadness_analysis_internal.h", + "deadness_analysis_test.cc", + ], + deps = [ + ":common", + ":compilation_passes", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_cc_test( name = "compilation_passes_test", size = "small", srcs = [ "encapsulate_subgraphs_pass_test.cc", "mark_for_compilation_pass_test.cc", + "partially_decluster_pass_test.cc", ], deps = [ ":common", ":compilation_passes", + ":xla_cluster_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "xla_cluster_util_test", + size = "small", + srcs = [ + "xla_cluster_util_test.cc", + ], + deps = [ + ":common", + ":xla_cluster_util", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", @@ -430,6 +505,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":common", + ":compilation_passes", ":union_find", ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index 731b8ebfdc6262500940274c94a03ae7c0376096..a2e6285339f9ed0bde8d72f5b4752b1ecc22f426 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -66,8 +66,28 @@ class SinglePassSearch { Status CompilationRequested(const FunctionLibraryRuntime& flr, const NodeDef& node_def) { + const FunctionDef* function_def = + flr.GetFunctionLibraryDefinition()->Find(node_def.name()); + if (function_def == nullptr) { + // The node def is not calling a function. Individual ops can be + // run directly using on-demand mode, no need to create XlaLaunch + // kernel for them. + // TODO(b/110359382): Make custom kernel creation return a bool instead of + // status. + // We don't set error messages here to avoid unnecessary string copy. + // Similarly below. + return Status(error::INVALID_ARGUMENT, ""); + } + + // If kXlaCompileAttr is set on the node_def, use its value. + const auto& it = node_def.attr().find(kXlaCompileAttr); + if (it != node_def.attr().end()) { + return it->second.b() ? Status::OK() : Status(error::INVALID_ARGUMENT, ""); + } + + // kXlaCompileAttr is not set on node_def, check if it is set on + // FunctionDef. bool xla_compile = false; - // Check if op is marked _XlaCompile=true. Status status = flr.GetFunctionLibraryDefinition()->GetAttr( node_def, kXlaCompileAttr, &xla_compile); if (!status.ok() || !xla_compile) { diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc new file mode 100644 index 0000000000000000000000000000000000000000..62007e6115d3fb81def844fcfa462094e223f565 --- /dev/null +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -0,0 +1,594 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/deadness_analysis.h" +#include "tensorflow/compiler/jit/deadness_analysis_internal.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/hash/hash.h" + +// ALGORITHM OVERVIEW +// +// We map every output produced by each node in the TensorFlow graph (including +// control dependence) into an instance of the Predicate class. Instances of +// Predicate denote logical formulas and mapping a node `n` to a predicate +// `pred` implies that `n` is executed whenver `pred` is true. Then we can +// deduce mismatching liveness in the inputs to node by comparing the predicate +// those inputs are mapped to. +// +// Loops are handled pessimistically -- we map Merge nodes with backedges to +// uninterpreted symbols (the same kind we use to represent Switch and _Recv). +// Predicate equality has to hold over all possible assignments to these +// uninterpreted symbols. + +namespace tensorflow { + +namespace { + +// Represents a logical predicate, used as described in the algorithm overview +// above. +class Predicate { + public: + enum class Kind { kAnd, kOr, kNot, kSymbol }; + + virtual string ToString() const = 0; + int64 hash() const { return hash_; } + virtual gtl::ArraySlice GetOperands() const = 0; + + virtual Kind kind() const = 0; + virtual ~Predicate() {} + + protected: + explicit Predicate(int64 hash) : hash_(hash) {} + + private: + const int64 hash_; + + TF_DISALLOW_COPY_AND_ASSIGN(Predicate); +}; + +int64 HashPredicateSequence(Predicate::Kind kind, + gtl::ArraySlice preds) { + int64 hash = ::tensorflow::hash()(kind); + for (Predicate* pred : preds) { + hash = Hash64Combine(hash, pred->hash()); + } + return hash; +} + +// Represents a logical conjunction of a set of predicates. +class AndPredicate : public Predicate { + public: + explicit AndPredicate(std::vector operands) + : Predicate(HashPredicateSequence(Kind::kAnd, operands)), + operands_(std::move(operands)) {} + + string ToString() const override { + if (operands().empty()) { + return "#true"; + } + + std::vector operands_str; + std::transform(operands().begin(), operands().end(), + std::back_inserter(operands_str), + [](Predicate* pred) { return pred->ToString(); }); + + return strings::StrCat("(", str_util::Join(operands_str, " & "), ")"); + } + + Kind kind() const override { return Kind::kAnd; } + + gtl::ArraySlice GetOperands() const override { return operands_; } + gtl::ArraySlice operands() const { return operands_; } + + private: + std::vector operands_; +}; + +// Represents a logical disjunction of a set of predicates. +class OrPredicate : public Predicate { + public: + explicit OrPredicate(std::vector operands) + : Predicate(HashPredicateSequence(Kind::kOr, operands)), + operands_(std::move(operands)) {} + + string ToString() const override { + if (operands().empty()) { + return "#false"; + } + + std::vector operands_str; + std::transform(operands().begin(), operands().end(), + std::back_inserter(operands_str), + [](Predicate* pred) { return pred->ToString(); }); + + return strings::StrCat("(", str_util::Join(operands_str, " | "), ")"); + } + + Kind kind() const override { return Kind::kOr; } + gtl::ArraySlice GetOperands() const override { return operands_; } + gtl::ArraySlice operands() const { return operands_; } + + private: + std::vector operands_; +}; + +// Represents a logical negation of a set of predicates. +class NotPredicate : public Predicate { + public: + explicit NotPredicate(Predicate* operand) + : Predicate(HashPredicateSequence(Kind::kNot, {operand})), + operands_({operand}) {} + + string ToString() const override { + return strings::StrCat("~", operand()->ToString()); + } + + Kind kind() const override { return Kind::kNot; } + Predicate* operand() const { return operands_[0]; } + gtl::ArraySlice GetOperands() const override { return operands_; } + + private: + std::array operands_; +}; + +// Represents an uninterpreted symbol in a logical predicate. +// +// Two predicates are equivalent iff they are equivalent for all assignments to +// the symbols contained in them. +class SymbolPredicate : public Predicate { + public: + explicit SymbolPredicate(TensorId tensor_id, bool must_be_true) + : Predicate(Hash(tensor_id, must_be_true)), + tensor_id_(std::move(tensor_id)), + must_be_true_(must_be_true) {} + + string ToString() const override { + return must_be_true() ? strings::StrCat("*", tensor_id_.ToString()) + : tensor_id_.ToString(); + } + + Kind kind() const override { return Kind::kSymbol; } + gtl::ArraySlice GetOperands() const override { return {}; } + + // If `must_be_true()` is true this SymbolPredicate represents the proposition + // "tensor_id() is live and evaluates to true". + // + // If `must_be_true()` is false then this SymbolPredicate represents the + // proposition "tensor_id() is live (and may evalutate to any value)" + TensorId tensor_id() const { return tensor_id_; } + bool must_be_true() const { return must_be_true_; } + + private: + TensorId tensor_id_; + bool must_be_true_; + + static int64 Hash(const TensorId tensor_id, bool must_be_true) { + return Hash64Combine( + ::tensorflow::hash()(must_be_true), + Hash64Combine(::tensorflow::hash()(Kind::kSymbol), + TensorId::Hasher{}(tensor_id))); + } +}; + +// Creates and owns Predicate instances. Simplifies predicates as it creates +// them. +class PredicateFactory { + public: + Predicate* MakeAndPredicate(gtl::ArraySlice operands) { + return MakeAndOrImpl(operands, /*is_and=*/true); + } + + Predicate* MakeOrPredicate(gtl::ArraySlice operands) { + return MakeAndOrImpl(operands, /*is_and=*/false); + } + + Predicate* MakeNotPredicate(Predicate* pred) { + SignatureForNot signature = pred; + auto it = interned_not_instances_.find(signature); + if (it == interned_not_instances_.end()) { + std::unique_ptr new_pred = Make(pred); + Predicate* new_pred_ptr = new_pred.get(); + interned_not_instances_.emplace(signature, std::move(new_pred)); + return new_pred_ptr; + } else { + return it->second.get(); + } + } + + Predicate* MakeSymbolPredicate(TensorId tensor_id, bool must_be_true) { + SignatureForSymbol signature = {tensor_id, must_be_true}; + auto it = interned_symbol_instances_.find(signature); + if (it == interned_symbol_instances_.end()) { + std::unique_ptr new_pred = + Make(tensor_id, must_be_true); + Predicate* new_pred_ptr = new_pred.get(); + interned_symbol_instances_.emplace(std::move(signature), + std::move(new_pred)); + return new_pred_ptr; + } else { + return it->second.get(); + } + } + + Predicate* MakeTrue() { return MakeAndPredicate({}); } + Predicate* MakeFalse() { return MakeOrPredicate({}); } + + private: + template + std::unique_ptr Make(Args&&... args) { + return std::unique_ptr( + new PredicateT(std::forward(args)...)); + } + + Predicate* MakeAndOrImpl(gtl::ArraySlice operands, bool is_and); + + // Predicate instances are interned, meaning that there is only a single + // instance of a Predicate object with a given content. This makes checking + // for structural equality super-cheap -- we can just compare pointers. + // + // We intern predicates by maintaining a map from the content of a Predicate + // to the only instance of said predicate we allow to exist in the + // interned_and_or_instances_, interned_not_instances_ and + // interned_symbol_instances_ fields. These maps also double up as storage + // for the owning pointers to predicate instances. + + using SignatureForAndOr = + std::pair>; + using SignatureForNot = Predicate*; + using SignatureForSymbol = std::pair; + + struct HashSignatureForAndOr { + size_t operator()(const SignatureForAndOr& signature) const { + size_t hash = ::tensorflow::hash()(signature.first); + for (Predicate* p : signature.second) { + hash = Hash64Combine(hash, ::tensorflow::hash()(p)); + } + return hash; + } + }; + + struct HashSignatureForSymbol { + size_t operator()(const SignatureForSymbol& signature) const { + return Hash64Combine(SafeTensorId::Hasher()(signature.first), + ::tensorflow::hash()(signature.second)); + } + }; + + gtl::FlatMap, + HashSignatureForAndOr> + interned_and_or_instances_; + gtl::FlatMap> + interned_not_instances_; + gtl::FlatMap, + HashSignatureForSymbol> + interned_symbol_instances_; +}; + +// Common code to create AndPredicate or OrPredicate instances. +Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice operands, + bool is_and) { + Predicate::Kind pred_kind = + is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr; + gtl::FlatSet simplified_ops_set; + std::vector simplified_ops; + for (Predicate* op : operands) { + // Simplify A&A => A and A|A => A. + if (!simplified_ops_set.insert(op).second) { + continue; + } + + if (op->kind() == pred_kind) { + // "Inline" the operands of an inner And/Or into the parent And/Or. + for (Predicate* subop : op->GetOperands()) { + if (simplified_ops_set.insert(subop).second) { + simplified_ops.push_back(subop); + } + } + } else { + simplified_ops.push_back(op); + } + } + + if (simplified_ops.size() == 1) { + return simplified_ops[0]; + } + + // Simplify "A&~A=>False" and "A|~A=>True". + gtl::FlatSet negated_ops; + for (Predicate* op : simplified_ops) { + if (op->kind() == Predicate::Kind::kNot) { + negated_ops.insert(dynamic_cast(*op).operand()); + } + } + + for (Predicate* op : simplified_ops) { + if (negated_ops.count(op)) { + return is_and ? MakeFalse() : MakeTrue(); + } + } + + std::stable_sort( + simplified_ops.begin(), simplified_ops.end(), + [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); }); + + auto it = interned_and_or_instances_.find({pred_kind, simplified_ops}); + if (it == interned_and_or_instances_.end()) { + simplified_ops.shrink_to_fit(); + // NB! Because we'll use a non-owning reference to simplified_ops in the + // key for interned_and_or_instances_ we need to be careful to std::move() + // it all the way through. + gtl::ArraySlice operands_slice = simplified_ops; + std::unique_ptr new_pred = + is_and ? Make(std::move(simplified_ops)) + : Make(std::move(simplified_ops)); + + Predicate* new_pred_ptr = new_pred.get(); + CHECK(interned_and_or_instances_ + .emplace(SignatureForAndOr(pred_kind, operands_slice), + std::move(new_pred)) + .second); + return new_pred_ptr; + } else { + return it->second.get(); + } +} + +class DeadnessAnalysisImpl : public DeadnessAnalysis { + public: + explicit DeadnessAnalysisImpl(const Graph* graph) + : graph_(*graph), vlog_(VLOG_IS_ON(2)) {} + + Status Populate(); + bool HasInputsWithMismatchingDeadness(const Node& node) override; + void Print() const override; + gtl::FlatMap PredicateMapAsString() const; + + private: + enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly }; + + std::vector GetIncomingPreds(Node* n, EdgeKind edge_kind); + void SetPred(Node* n, int output_idx, Predicate* pred) { + CHECK( + predicate_map_.insert({TensorId(n->name(), output_idx), pred}).second); + } + void SetPred(Node* n, gtl::ArraySlice output_idxs, Predicate* pred) { + for (int output_idx : output_idxs) { + SetPred(n, output_idx, pred); + } + } + + Status HandleSwitch(Node* n); + Status HandleMerge(Node* n); + Status HandleRecv(Node* n); + Status HandleGeneric(Node* n); + + const Graph& graph_; + gtl::FlatMap predicate_map_; + PredicateFactory predicate_factory_; + bool vlog_; +}; + +TensorId InputEdgeToTensorId(const Edge* e) { + return TensorId(e->src()->name(), e->src_output()); +} + +std::vector DeadnessAnalysisImpl::GetIncomingPreds( + Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind) { + std::vector incoming_preds; + for (const Edge* in_edge : n->in_edges()) { + bool should_process = + edge_kind == EdgeKind::kDataAndControl || + (in_edge->IsControlEdge() && edge_kind == EdgeKind::kControlOnly) || + (!in_edge->IsControlEdge() && edge_kind == EdgeKind::kDataOnly); + + if (should_process) { + auto it = predicate_map_.find(InputEdgeToTensorId(in_edge)); + CHECK(it != predicate_map_.end()); + incoming_preds.push_back(it->second); + } + } + return incoming_preds; +} + +Status DeadnessAnalysisImpl::HandleSwitch(Node* n) { + std::vector input_preds = + GetIncomingPreds(n, EdgeKind::kDataAndControl); + const Edge* pred_edge; + TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge)); + Predicate* true_switch = predicate_factory_.MakeSymbolPredicate( + TensorId(pred_edge->src()->name(), pred_edge->src_output()), + /*must_be_true=*/true); + Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch); + + // Output 0 is alive iff all inputs are alive and the condition is false. + input_preds.push_back(false_switch); + SetPred(n, 0, predicate_factory_.MakeAndPredicate(input_preds)); + input_preds.pop_back(); + + // Output 1 is alive iff all inputs are alive and the condition is true. + input_preds.push_back(true_switch); + SetPred(n, 1, predicate_factory_.MakeAndPredicate(input_preds)); + input_preds.pop_back(); + + // Control is alive iff any inputs are alive. + SetPred(n, Graph::kControlSlot, + predicate_factory_.MakeAndPredicate(input_preds)); + + return Status::OK(); +} + +Status DeadnessAnalysisImpl::HandleMerge(Node* n) { + // Merge ignores deadness of its control inputs. A merge that isn't the + // target of a backedge has is alive iff any of its data inputs are. We treat + // the liveness of a merge that is the target of a backedge symbolically. + + bool has_backedge = std::any_of( + n->in_edges().begin(), n->in_edges().end(), [](const Edge* e) { + return !e->IsControlEdge() && e->src()->IsNextIteration(); + }); + + Predicate* input_data_pred = + has_backedge ? predicate_factory_.MakeSymbolPredicate( + TensorId(n->name(), 0), /*must_be_true=*/false) + : predicate_factory_.MakeOrPredicate( + GetIncomingPreds(n, EdgeKind::kDataOnly)); + + SetPred(n, {0, 1, Graph::kControlSlot}, input_data_pred); + return Status::OK(); +} + +Status DeadnessAnalysisImpl::HandleRecv(Node* n) { + // In addition to being alive or dead based on the inputs, a _Recv can also + // acquire a dead signal from a _Send. + std::vector input_preds = + GetIncomingPreds(n, EdgeKind::kDataAndControl); + input_preds.push_back(predicate_factory_.MakeSymbolPredicate( + TensorId(n->name(), 0), /*must_be_true=*/false)); + SetPred(n, {0, Graph::kControlSlot}, + predicate_factory_.MakeAndPredicate(input_preds)); + return Status::OK(); +} + +Status DeadnessAnalysisImpl::HandleGeneric(Node* n) { + // Generally nodes are alive iff all their inputs are alive. + Predicate* pred = predicate_factory_.MakeAndPredicate( + GetIncomingPreds(n, EdgeKind::kDataAndControl)); + for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) { + SetPred(n, output_idx, pred); + } + SetPred(n, Graph::kControlSlot, pred); + return Status::OK(); +} + +Status DeadnessAnalysisImpl::Populate() { + std::vector rpo; + GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/{}, + /*edge_filter=*/[](const Edge& edge) { + return !edge.src()->IsNextIteration(); + }); + + // This an abstract interpretation over the deadness propagation semantics of + // the graph executor. + for (Node* n : rpo) { + if (n->IsSwitch()) { + TF_RETURN_IF_ERROR(HandleSwitch(n)); + } else if (n->IsMerge()) { + TF_RETURN_IF_ERROR(HandleMerge(n)); + } else if (n->IsControlTrigger()) { + SetPred(n, Graph::kControlSlot, predicate_factory_.MakeTrue()); + } else if (n->IsRecv() || n->IsHostRecv()) { + TF_RETURN_IF_ERROR(HandleRecv(n)); + } else { + TF_RETURN_IF_ERROR(HandleGeneric(n)); + } + } + + return Status::OK(); +} + +bool DeadnessAnalysisImpl::HasInputsWithMismatchingDeadness(const Node& node) { + CHECK(!node.IsMerge()); + + if (vlog_) { + VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() << ")"; + } + + Predicate* pred = nullptr; + for (const Edge* edge : node.in_edges()) { + auto it = predicate_map_.find(InputEdgeToTensorId(edge)); + CHECK(it != predicate_map_.end()); + if (vlog_) { + VLOG(2) << " " << InputEdgeToTensorId(edge).ToString() << ": " + << it->second->ToString(); + } + + // Today we just compare the predicates for equality (with some + // canonicalization/simplification happening before) but we could be more + // sophisticated here if need be. Comparing pointers is sufficient because + // we intern Predicate instances by their content. + if (pred != nullptr && pred != it->second) { + if (vlog_) { + VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() + << ") -> true"; + } + return true; + } + pred = it->second; + } + + if (vlog_) { + VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() + << ") -> false"; + } + + return false; +} + +void DeadnessAnalysisImpl::Print() const { + std::vector tensor_ids; + for (const auto& kv_pair : predicate_map_) { + tensor_ids.push_back(kv_pair.first); + } + + std::sort(tensor_ids.begin(), tensor_ids.end()); + + for (TensorId tensor_id : tensor_ids) { + auto it = predicate_map_.find(tensor_id); + CHECK(it != predicate_map_.end()) << tensor_id.ToString(); + VLOG(2) << tensor_id.ToString() << " -> " << it->second->ToString(); + } +} + +} // namespace + +DeadnessAnalysis::~DeadnessAnalysis() {} + +/*static*/ Status DeadnessAnalysis::Run( + const Graph& graph, std::unique_ptr* result) { + std::unique_ptr analysis( + new DeadnessAnalysisImpl(&graph)); + TF_RETURN_IF_ERROR(analysis->Populate()); + + if (VLOG_IS_ON(2)) { + analysis->Print(); + } + + *result = std::move(analysis); + return Status::OK(); +} + +gtl::FlatMap +DeadnessAnalysisImpl::PredicateMapAsString() const { + gtl::FlatMap result; + std::vector tensor_ids; + for (const auto& kv_pair : predicate_map_) { + CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second); + } + return result; +} + +namespace deadness_analysis_internal { +Status ComputePredicates(const Graph& graph, + PredicateMapTy* out_predicate_map) { + DeadnessAnalysisImpl impl(&graph); + TF_RETURN_IF_ERROR(impl.Populate()); + *out_predicate_map = impl.PredicateMapAsString(); + return Status::OK(); +} +} // namespace deadness_analysis_internal + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis.h b/tensorflow/compiler/jit/deadness_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..6e7ab411619ba08060aa4925e91dce06299d1d23 --- /dev/null +++ b/tensorflow/compiler/jit/deadness_analysis.h @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_ + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// This analyzes a TensorFlow graph to identify nodes which may have partially +// dead inputs (i.e. these nodes may have some dead inputs and some alive +// inputs). +// +// For example, the ADD node in the following graph +// +// V0 PRED0 V1 PRED1 +// | | | | +// v v v v +// SWITCH SWITCH +// | | +// +---+ + ---+ +// | | +// v v +// ADD +// +// can have its inputs independently dead or alive based on the runtime values +// of PRED0 and PRED1. +// +// It is tempting to call this a liveness analysis but I avoided that because +// "liveness" already has other connotations. +class DeadnessAnalysis { + public: + // Returns true if `node` may have some live inputs and some dead inputs. + // + // This is a conservatively correct routine -- if it returns false then `node` + // is guaranteed to not have inputs with mismatching liveness, but not the + // converse. + // + // REQUIRES: node is not a Merge operation. + virtual bool HasInputsWithMismatchingDeadness(const Node& node) = 0; + + // Prints out the internal state of this instance. For debugging purposes + // only. + virtual void Print() const = 0; + virtual ~DeadnessAnalysis(); + + // Run the deadness analysis over `graph` and returns an error or a populated + // instance of DeadnessAnalysis in `result`. + static Status Run(const Graph& graph, + std::unique_ptr* result); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_ diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..cdef4051108fdc5d063ab592676c7644989155bf --- /dev/null +++ b/tensorflow/compiler/jit/deadness_analysis_internal.h @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ +#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ + +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace tensorflow { +namespace deadness_analysis_internal { + +// Returns a map describing the predicate each Tensor was mapped to. For +// testing purposes only. +using PredicateMapTy = gtl::FlatMap; +Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map); +} // namespace deadness_analysis_internal +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6881095b51758d2e0b06c60021bc8c2860ac566e --- /dev/null +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -0,0 +1,467 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/deadness_analysis.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/deadness_analysis_internal.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +Status AnalyzeDeadness(Graph* graph, + std::unique_ptr* result) { + FixupSourceAndSinkEdges(graph); + return DeadnessAnalysis::Run(*graph, result); +} + +ops::Switch CreateSwitch(const Scope& root, const string& prefix) { + Output value = ops::Placeholder(root.WithOpName(prefix + "/value"), DT_FLOAT); + Output predicate = + ops::Placeholder(root.WithOpName(prefix + "/pred"), DT_BOOL); + return ops::Switch(root.WithOpName(prefix + "/switch"), value, predicate); +} + +Output CreateInductionVariable(const Scope& root, const string& prefix, + const string& frame_name, int32 init) { + Output initial_value = ops::Const(root.WithOpName(prefix + "/init"), init); + Output enter_initial_value = ops::internal::Enter( + root.WithOpName(prefix + "/enter"), initial_value, frame_name); + + ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_initial_value}); + Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1); + Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10); + Output loop_cond_expr = + ops::Less(root.WithOpName(prefix + "/less"), iv.output, final_value); + Output loop_cond = + ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr); + ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); + ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); + Output iv_next = + ops::Add(root.WithOpName(prefix + "/ivnext"), iv.output, increment_by); + Output next_iteration = + ops::NextIteration(root.WithOpName(prefix + "next_iteration"), iv_next); + + root.graph()->AddEdge(next_iteration.node(), 0, iv.output.node(), 1); + root.graph()->AddControlEdge(iv.output.node(), increment_by.node()); + root.graph()->AddControlEdge(iv.output.node(), final_value.node()); + + return iv.output; +} + +TEST(DeadnessAnalysisTest, BasicPositive) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw = CreateSwitch(root, "0"); + Output add = + ops::Add(root.WithOpName("add"), sw.output_true, sw.output_false); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, BasicNegative) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT); + Output b = ops::Placeholder(root.WithOpName("b"), DT_FLOAT); + Output add = ops::Add(root.WithOpName("add"), a, b); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, AndIsCommutative) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + + Output a0 = + ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false); + Output a1 = + ops::Add(root.WithOpName("a1"), sw_1.output_false, sw_0.output_false); + + Output b0 = + ops::Add(root.WithOpName("b0"), sw_0.output_false, sw_1.output_true); + Output b1 = + ops::Add(root.WithOpName("b1"), sw_1.output_true, sw_0.output_false); + + Output live0 = ops::Add(root.WithOpName("live0"), a0, a1); + Output live1 = ops::Add(root.WithOpName("live1"), b0, b1); + + Output halfdead0 = ops::Add(root.WithOpName("halfdead0"), a0, b0); + Output halfdead1 = ops::Add(root.WithOpName("halfdead1"), a1, b1); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node())); + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node())); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node())); + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node())); +} + +TEST(DeadnessAnalysisTest, AndIsAssociative) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + ops::Switch sw_2 = CreateSwitch(root, "2"); + + Output a0 = + ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false); + Output a1 = ops::Add(root.WithOpName("a1"), a0, sw_2.output_false); + + Output b0 = + ops::Add(root.WithOpName("b0"), sw_1.output_false, sw_2.output_false); + Output b1 = ops::Add(root.WithOpName("b1"), sw_0.output_false, b0); + + Output add = ops::Add(root.WithOpName("add"), a1, b1); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, OrIsCommutative) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + + ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false}); + ops::Merge m1(root.WithOpName("m1"), {sw_1.output_false, sw_0.output_false}); + ops::Merge m2(root.WithOpName("m2"), {sw_0.output_false, sw_1.output_true}); + ops::Merge m3(root.WithOpName("m3"), {sw_1.output_true, sw_0.output_false}); + + Output live0 = ops::Add(root.WithOpName("live0"), m0.output, m1.output); + Output live1 = ops::Add(root.WithOpName("live1"), m2.output, m3.output); + + Output halfdead0 = + ops::Add(root.WithOpName("halfdead0"), m0.output, m2.output); + Output halfdead1 = + ops::Add(root.WithOpName("halfdead1"), m1.output, m3.output); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node())); + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node())); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node())); + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node())); +} + +TEST(DeadnessAnalysisTest, OrIsAssociative) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + ops::Switch sw_2 = CreateSwitch(root, "2"); + + ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false}); + ops::Merge m1(root.WithOpName("m1"), {m0.output, sw_2.output_false}); + ops::Merge m2(root.WithOpName("m2"), {sw_1.output_false, sw_2.output_false}); + ops::Merge m3(root.WithOpName("m3"), {sw_0.output_false, m2.output}); + + Output add = ops::Add(root.WithOpName("add"), m1.output, m3.output); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, AndOfOr) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + ops::Switch sw_2 = CreateSwitch(root, "2"); + ops::Switch sw_3 = CreateSwitch(root, "3"); + + ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false}); + ops::Merge m1(root.WithOpName("m1"), {sw_2.output_false, sw_3.output_false}); + + Output add0 = ops::Add(root.WithOpName("add0"), m0.output, m1.output); + Output add1 = ops::Add(root.WithOpName("add1"), m0.output, m1.output); + + Output add2 = ops::Add(root.WithOpName("add2"), add0, add1); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node())); +} + +TEST(DeadnessAnalysisTest, OrOfAnd) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + ops::Switch sw_2 = CreateSwitch(root, "2"); + ops::Switch sw_3 = CreateSwitch(root, "3"); + + Output add0 = + ops::Add(root.WithOpName("add0"), sw_0.output_false, sw_1.output_false); + Output add1 = + ops::Add(root.WithOpName("add1"), sw_2.output_false, sw_3.output_false); + + ops::Merge m0(root.WithOpName("m0"), {add0, add1}); + ops::Merge m1(root.WithOpName("m1"), {add0, add1}); + + Output add2 = ops::Add(root.WithOpName("add2"), m0.output, m1.output); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node())); +} + +TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) { + // This demonstrates one of the weaknesses in the current approach -- since we + // only do some basic simplifications we can't see that "(A|B)&C" == + // "(A&C)|(B&C)". + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + ops::Switch sw_2 = CreateSwitch(root, "2"); + + ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false}); + Output add0 = ops::Add(root.WithOpName("add0"), m0.output, sw_2.output_false); + + Output add1 = + ops::Add(root.WithOpName("add1"), sw_0.output_false, sw_2.output_false); + Output add2 = + ops::Add(root.WithOpName("add2"), sw_1.output_false, sw_2.output_false); + ops::Merge m1(root.WithOpName("m1"), {add1, add2}); + + Output add3 = ops::Add(root.WithOpName("add3"), add0, m1.output); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add2.node())); +} + +TEST(DeadnessAnalysisTest, Ternary) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output predicate = ops::Placeholder(root.WithOpName("predicate"), DT_BOOL); + Output true_value = ops::Placeholder(root.WithOpName("true_value"), DT_FLOAT); + Output false_value = + ops::Placeholder(root.WithOpName("false_value"), DT_FLOAT); + + ops::Switch predicated_true(root.WithOpName("predicated_true"), true_value, + predicate); + + ops::Switch predicated_false(root.WithOpName("predicated_false"), true_value, + predicate); + ops::Merge merge(root.WithOpName("ternary"), {predicated_true.output_true, + predicated_false.output_false}); + Output addend = ops::Placeholder(root.WithOpName("addend"), DT_FLOAT); + Output add = ops::Add(root.WithOpName("add"), merge.output, addend); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, Recv) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_FLOAT, "tensor_a", + "sender", 0, "receiver"); + Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_FLOAT, "tensor_b", + "sender", 0, "receiver"); + Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, HostRecv) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output recv_a = ops::_HostRecv(root.WithOpName("recv_a"), DT_FLOAT, + "tensor_a", "sender", 0, "receiver"); + Output recv_b = ops::_HostRecv(root.WithOpName("recv_b"), DT_FLOAT, + "tensor_b", "sender", 0, "receiver"); + Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, Loop) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0); + Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0); + Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1); + Output add0 = ops::Add(root.WithOpName("add0"), iv0, iv1); + Output add1 = ops::Add(root.WithOpName("add1"), iv1, iv2); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + // NB! iv0 and iv1 are equivalent and a smarter deadness analysis would have + // noticed that. Today we are pessimistic here because we assign an + // uninterpreted symbol to merges with backedges. + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node())); + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node())); +} + +TEST(DeadnessAnalysisTest, ControlInputs) { + Scope root = Scope::NewRootScope().ExitOnError(); + ops::Switch sw = CreateSwitch(root, "0"); + + Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false); + Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true); + + Output const0 = ops::Const(root.WithOpName("const0"), 1); + Output const1 = ops::Const(root.WithOpName("const1"), 2); + + Output add = ops::Add(root.WithOpName("add"), const0, const1); + + root.graph()->AddControlEdge(id0.node(), const0.node()); + root.graph()->AddControlEdge(id1.node(), const1.node()); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, ControlTrigger) { + Scope root = Scope::NewRootScope().ExitOnError(); + ops::Switch sw = CreateSwitch(root, "0"); + + Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false); + Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true); + + ops::ControlTrigger ctrl_trigger0(root.WithOpName("ctrl_trigger0")); + ops::ControlTrigger ctrl_trigger1(root.WithOpName("ctrl_trigger1")); + + Output const0 = ops::Const(root.WithOpName("const0"), 1); + Output const1 = ops::Const(root.WithOpName("const1"), 2); + + Output add = ops::Add(root.WithOpName("add"), const0, const1); + + root.graph()->AddControlEdge(id0.node(), ctrl_trigger0.operation.node()); + root.graph()->AddControlEdge(ctrl_trigger0.operation.node(), const0.node()); + + root.graph()->AddControlEdge(id1.node(), ctrl_trigger1.operation.node()); + root.graph()->AddControlEdge(ctrl_trigger1.operation.node(), const1.node()); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, ControlInputsToMerge) { + Scope root = Scope::NewRootScope().ExitOnError(); + ops::Switch sw = CreateSwitch(root, "0"); + + Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false); + Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true); + + Output constant = ops::Const(root.WithOpName("constant"), 5); + ops::Merge m0(root.WithOpName("m0"), {constant}); + ops::Merge m1(root.WithOpName("m0"), {constant}); + Output add = ops::Add(root.WithOpName("add"), m0.output, m1.output); + + root.graph()->AddControlEdge(id0.node(), m0.output.node()); + root.graph()->AddControlEdge(id1.node(), m1.output.node()); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, RecvVsSwitch) { + // Demonstrates why we need the must_be_true bit on SymbolP. + Scope root = Scope::NewRootScope().ExitOnError(); + + Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender", + 0, "receiver"); + Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL); + ops::Switch sw(root.WithOpName("switch"), value, recv); + Output logical_and = + ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*logical_and.node())); +} + +TEST(DeadnessAnalysisTest, RecvVsSwitchText) { + // Demonstrates why we need the must_be_true bit on SymbolP. + Scope root = Scope::NewRootScope().ExitOnError(); + + Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender", + 0, "receiver"); + Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL); + ops::Switch sw(root.WithOpName("switch"), value, recv); + Output logical_and = + ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + deadness_analysis_internal::PredicateMapTy predicate_map; + TF_ASSERT_OK(deadness_analysis_internal::ComputePredicates(*root.graph(), + &predicate_map)); + + TensorId logical_and_output_0 = {logical_and.node()->name(), + Graph::kControlSlot}; + EXPECT_EQ(predicate_map[logical_and_output_0], "(recv:0 & *recv:0)"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index ea90d714c8fe2c2959a92a83ec105dbf2c098f7a..f150bf1819d407e1c6a279673a89de4307b5426b 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -60,9 +60,9 @@ const char* const kXlaHostTransferSequencerAttr = namespace { -bool AreAllParentsConst(const Node& n, - const gtl::FlatSet& runtime_const_nodes) { - if (n.type_string() == "GuaranteeConst" || n.type_string() == "Const") { +bool AreAllParentsGuaranteedConst( + const Node& n, const gtl::FlatSet& runtime_const_nodes) { + if (n.type_string() == "GuaranteeConst") { // If the current node is itself a cast-to-const, no need // to look at the incoming edges. return true; @@ -93,7 +93,8 @@ void MarkGuaranteedConstants( ReverseDFSFrom(graph, srcs, /*enter=*/nullptr, /*leave=*/[&guaranteed_const_nodes](const Node* n) { // TODO(vinuraja): Doesn't work in the presence of loops. - if (AreAllParentsConst(*n, guaranteed_const_nodes)) { + if (AreAllParentsGuaranteedConst(*n, + guaranteed_const_nodes)) { guaranteed_const_nodes.insert(n); } }); @@ -106,41 +107,11 @@ void MarkGuaranteedConstants( } } -// A node/slot pair. -// TODO(phawkins): is there a common definition of this? -struct NodeSlot { - NodeSlot() : node(nullptr), slot(-1), dtype(DT_INVALID) {} - NodeSlot(const Node* node, int slot) - : node(node), slot(slot), dtype(DT_INVALID) {} - NodeSlot(const Node* node, int slot, DataType dtype) - : node(node), slot(slot), dtype(dtype) {} - - const Node* node; - int slot; - - // Optional: used to record the destination type of a source NodeSlot in case - // the source output is a Ref type that is cast to a Tensor at the - // destination. - DataType dtype; - - bool operator==(const NodeSlot& other) const { - return node == other.node && slot == other.slot && dtype == other.dtype; - } - - // Leave dtype out of the hash since there are never two NodeSlots with the - // same node and slot and different dtypes. - struct Hasher { - uint64 operator()(NodeSlot const& s) const { - return Hash64Combine(std::hash()(s.node), - std::hash()(s.slot)); - } - }; - - struct PairHasher { - uint64 operator()(std::pair const& s) const { - return Hash64Combine(Hasher()(s.first), Hasher()(s.second)); - } - }; +struct OutputInputTensorPairHasher { + uint64 operator()(std::pair const& s) const { + return Hash64Combine(OutputTensor::Hash()(s.first), + InputTensor::Hash()(s.second)); + } }; // TODO(phawkins) add a canonical copy of these operator names and refactor @@ -167,7 +138,7 @@ class Encapsulator { // Find subgraphs marked with 'group_attribute', and build a new // subgraph, one for each value of 'group_attribute'. - Status SplitIntoSubgraphs(); + Status SplitIntoSubgraphs(FunctionLibraryDefinition* library); // Build a FunctionDef for each subgraph, and add it 'library'. The values of // the 'group_attribute' annotations become the function names. @@ -376,7 +347,7 @@ class Encapsulator { // Map from source (producer node/slot) tensors in the original graph to // input index (slot number in the HostCompute/RecvAtHost nodes that will // be created) for the outside_compilation subgraph. - std::unordered_map inputs; + std::unordered_map inputs; // Set of nodes in the original graph that are the source of control edges // that cross from the containing compiled subgraph into the @@ -392,8 +363,15 @@ class Encapsulator { // node/slot) tensors in the original graph to output index (slot number // in the SendFromHost/HostCompute nodes that will be created) for the // outside_compilation subgraph. - std::unordered_map outputs_by_src; - std::unordered_map outputs_by_dst; + struct ArgNumAndType { + int index; + DataType dtype; + + ArgNumAndType(int i, DataType t) : index(i), dtype(t) {} + }; + std::unordered_map + outputs_by_src; + std::unordered_map outputs_by_dst; // Set of nodes in the original graph that are the destination of control // edges that cross from the outside_compilation subgraph into the @@ -479,14 +457,14 @@ class Encapsulator { // (consumer node/slot) tensors in the input graph to _Arg numbers in // the subgraph. The source map is one-to-one, whereas the dest map may be // many-to-one. - std::unordered_map args_by_src_; - std::unordered_map args_by_dst_; + std::unordered_map args_by_src_; + std::unordered_map args_by_dst_; - // The _Arg nodes in the subgraph, in order by argument number. + // The arguments to the subgraph, in order. std::vector args_; // Map from source tensor in the input graph to result #. - std::unordered_map results_; + std::unordered_map results_; // The outside_compilation clusters in this subgraph. std::unordered_map @@ -583,8 +561,8 @@ class Encapsulator { const string& dst_outside_compilation_id, const std::unordered_map& node_images, Graph* graph_out, - std::unordered_set, NodeSlot::PairHasher>* - edges_added); + std::unordered_set, + OutputInputTensorPairHasher>* edges_added); // Adds control dependencies between subgraph call nodes that have // dependencies via outside_compilation edges. @@ -716,11 +694,11 @@ void TopologicalClusterSort( Node* Encapsulator::Subgraph::GetCallNode() const { return call_node_; } int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const { - return args_by_dst_.at(NodeSlot(edge->dst(), edge->dst_input())); + return args_by_dst_.at(InputTensor(edge->dst(), edge->dst_input())); } int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const { - return results_.at(NodeSlot(edge->src(), edge->src_output())); + return results_.at(OutputTensor(edge->src(), edge->src_output())); } Node* Encapsulator::Subgraph::GetRecvAtHostNode( @@ -732,7 +710,7 @@ Node* Encapsulator::Subgraph::GetRecvAtHostNode( int Encapsulator::Subgraph::GetRecvAtHostSlot( const string& outside_compilation_subgraph_name, const Edge* edge) const { return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) - .inputs.at(NodeSlot(edge->src(), edge->src_output())); + .inputs.at(OutputTensor(edge->src(), edge->src_output())); } Node* Encapsulator::Subgraph::GetSendFromHostNode( @@ -744,7 +722,7 @@ Node* Encapsulator::Subgraph::GetSendFromHostNode( int Encapsulator::Subgraph::GetSendFromHostSlot( const string& outside_compilation_subgraph_name, const Edge* edge) const { return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) - .outputs_by_dst.at(NodeSlot(edge->dst(), edge->dst_input())); + .outputs_by_dst.at(InputTensor(edge->dst(), edge->dst_input())); } Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { @@ -769,10 +747,10 @@ Status Encapsulator::Subgraph::RecordArg( std::vector>* src_arg_pairs) { Node* src_node = edge->src(); int src_slot = edge->src_output(); - std::unordered_map::iterator iter; + std::unordered_map::iterator iter; bool inserted; - std::tie(iter, inserted) = - args_by_src_.emplace(NodeSlot(src_node, src_slot), args_by_src_.size()); + std::tie(iter, inserted) = args_by_src_.emplace( + OutputTensor(src_node, src_slot), args_by_src_.size()); int arg_index = iter->second; if (inserted) { NodeDef arg_def; @@ -793,7 +771,7 @@ Status Encapsulator::Subgraph::RecordArg( Node* dst_node = edge->dst(); Node* dst_image = node_images.at(dst_node); int dst_slot = edge->dst_input(); - args_by_dst_[NodeSlot(dst_node, dst_slot)] = arg_index; + args_by_dst_[InputTensor(dst_node, dst_slot)] = arg_index; graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot); return Status::OK(); } @@ -804,10 +782,10 @@ Status Encapsulator::Subgraph::RecordResult( Node* src_node = edge->src(); Node* src_image = node_images.at(src_node); int src_slot = edge->src_output(); - std::unordered_map::iterator iter; + std::unordered_map::iterator iter; bool inserted; std::tie(iter, inserted) = - results_.emplace(NodeSlot(src_node, src_slot), results_.size()); + results_.emplace(OutputTensor(src_node, src_slot), results_.size()); int ret_index = iter->second; if (inserted) { NodeDef ret_def; @@ -845,8 +823,8 @@ void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl( outside_subgraph->control_inputs.insert(edge->src()); } else { int input_index = outside_subgraph->inputs.size(); - outside_subgraph->inputs.emplace(NodeSlot(edge->src(), edge->src_output()), - input_index); + outside_subgraph->inputs.emplace( + OutputTensor(edge->src(), edge->src_output()), input_index); } } @@ -860,11 +838,13 @@ void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl( DataType dtype = edge->dst()->input_type(edge->dst_input()); auto output_iter = outside_subgraph->outputs_by_src - .emplace(NodeSlot(edge->src(), edge->src_output(), dtype), - outside_subgraph->outputs_by_src.size()) + .emplace(OutputTensor(edge->src(), edge->src_output()), + OutsideCompilationSubgraph::ArgNumAndType( + outside_subgraph->outputs_by_src.size(), dtype)) .first; - int output_index = output_iter->second; - outside_subgraph->outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] = + const int output_index = output_iter->second.index; + outside_subgraph + ->outputs_by_dst[InputTensor(edge->dst(), edge->dst_input())] = output_index; } } @@ -946,7 +926,7 @@ Status Encapsulator::Subgraph::AddHostComputes( for (const auto& input_src : oc_subgraph.inputs) { const Node* src_node = input_src.first.node; Node* src_image = node_images.at(src_node); - int src_slot = input_src.first.slot; + int src_slot = input_src.first.index; int input_index = input_src.second; DataType dtype = src_node->output_type(src_slot); @@ -954,8 +934,8 @@ Status Encapsulator::Subgraph::AddHostComputes( input_dtypes[input_index] = dtype; } for (const auto& output : oc_subgraph.outputs_by_src) { - DataType dtype = output.first.dtype; - int output_index = output.second; + DataType dtype = output.second.dtype; + int output_index = output.second.index; output_dtypes[output_index] = dtype; } @@ -993,7 +973,7 @@ Status Encapsulator::Subgraph::AddHostComputes( for (auto& input_src : oc_subgraph.inputs) { const Node* src_node = input_src.first.node; Node* src_image = node_images.at(src_node); - int src_slot = input_src.first.slot; + int src_slot = input_src.first.index; int input_index = input_src.second; graph_->AddEdge(src_image, src_slot, host_compute, input_index); } @@ -1015,7 +995,7 @@ Status Encapsulator::Subgraph::AddHostComputes( for (const auto& output : oc_subgraph.outputs_by_dst) { const Node* dst_node = output.first.node; Node* dst_image = node_images.at(dst_node); - int dst_slot = output.first.slot; + int dst_slot = output.first.index; int output_index = output.second; graph_->AddEdge(host_compute, output_index, dst_image, dst_slot); @@ -1068,14 +1048,19 @@ Status Encapsulator::Subgraph::BuildFunctionDef( call_node_def_.set_device(device_); if (rewrite_subgraph_fn) { + std::vector arg_source_tensors(args_by_src_.size()); + for (const auto& arg : args_by_src_) { + arg_source_tensors.at(arg.second) = arg.first; + } // Initialize the input and output permutations to the identity. std::vector input_permutation(args_by_src_.size()); std::iota(input_permutation.begin(), input_permutation.end(), 0); std::vector output_permutation(results_.size()); std::iota(output_permutation.begin(), output_permutation.end(), 0); - TF_RETURN_IF_ERROR(rewrite_subgraph_fn( - &graph_, &input_permutation, &output_permutation, &call_node_def_)); + TF_RETURN_IF_ERROR( + rewrite_subgraph_fn(arg_source_tensors, &graph_, &input_permutation, + &output_permutation, &call_node_def_)); // Apply the input/output permutations to the 'args_by_...' and 'results_' // mappings, so when we build edges in BuildOutputGraph() we @@ -1152,7 +1137,10 @@ Status Encapsulator::Subgraph::AddShapeInferenceInfo( GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef)); host_compute->AddAttr("shape_inference_graph", inference_graph_name); host_compute->AddAttr("shapes", std::vector()); - TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + // TODO(sibyl-Aix6ihai): Understand why there are multiple calls to Encapsulator. + if (library->Find(inference_graph_name) == nullptr) { + TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + } } return Status::OK(); } @@ -1173,8 +1161,7 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef( strings::StrCat("replace_encapsulate_fdef_", name), fdef); } - TF_RETURN_IF_ERROR(library->RemoveFunction(name)); - TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef)); return Status::OK(); } @@ -1226,7 +1213,7 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( for (const auto& input : oc_subgraph->inputs) { const Node* src_node = input.first.node; - int src_slot = input.first.slot; + int src_slot = input.first.index; int input_index = input.second; DataType dtype = src_node->output_type(src_slot); @@ -1280,8 +1267,8 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( for (const auto& output : oc_subgraph->outputs_by_src) { const Node* src_node = output.first.node; Node* src_image = node_images.at(src_node); - int src_slot = output.first.slot; - int output_index = output.second; + int src_slot = output.first.index; + int output_index = output.second.index; DataType dtype = src_node->output_type(src_slot); dtypes[output_index] = dtype; @@ -1490,7 +1477,7 @@ Status Encapsulator::CopySubgraphEdges( return Status::OK(); } -Status Encapsulator::SplitIntoSubgraphs() { +Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) { Status s; // Map from input graph nodes to subgraph nodes. @@ -1520,6 +1507,18 @@ Status Encapsulator::SplitIntoSubgraphs() { for (auto& entry : subgraphs_) { Subgraph& subgraph = entry.second; FixupSourceAndSinkEdges(subgraph.GetGraph()); + // Verify that the graph has well-formed control flow structure. + std::vector dummy; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(subgraph.GetGraph(), &dummy)); + } + + if (VLOG_IS_ON(1)) { + // Dump subgraphs. + for (auto& entry : subgraphs_) { + dump_graph::DumpGraphToFile( + strings::StrCat("encapsulate_subgraphs_subgraph_", entry.first), + *entry.second.GetGraph(), library); + } } return s; @@ -1680,8 +1679,8 @@ Status Encapsulator::CopyEdgeToOutputGraph( const string& src_outside_compilation_id, const string& dst_func_id, const string& dst_outside_compilation_id, const std::unordered_map& node_images, Graph* graph_out, - std::unordered_set, NodeSlot::PairHasher>* - edges_added) { + std::unordered_set, + OutputInputTensorPairHasher>* edges_added) { Node* src_image; TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc( src_func_id, src_outside_compilation_id, dst_func_id, @@ -1696,7 +1695,8 @@ Status Encapsulator::CopyEdgeToOutputGraph( if (edge->IsControlEdge()) { // Add the control edge, if we have not already added it, using the images // determined above (potentially call operators or RecvAtHost/SendFromHost). - if (edges_added->emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1)) + if (edges_added + ->emplace(OutputTensor(src_image, -1), InputTensor(dst_image, -1)) .second) { graph_out->AddControlEdge(src_image, dst_image); } @@ -1714,8 +1714,8 @@ Status Encapsulator::CopyEdgeToOutputGraph( // Add the edge, if we have not already added it. if (edges_added - ->emplace(NodeSlot(src_image, src_output), - NodeSlot(dst_image, dst_input)) + ->emplace(OutputTensor(src_image, src_output), + InputTensor(dst_image, dst_input)) .second) { graph_out->AddEdge(src_image, src_output, dst_image, dst_input); } @@ -1739,7 +1739,8 @@ Status Encapsulator::AddEdgesToOutputGraph( // Set of edges already added to the output graph, represented as (src, dst) // pairs. We use the set to deduplicate edges; multiple edges in the input // graph may map to one edge in the output graph. - std::unordered_set, NodeSlot::PairHasher> + std::unordered_set, + OutputInputTensorPairHasher> edges_added; for (const Edge* edge : graph_in_->edges()) { @@ -1943,6 +1944,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( // continue. TensorShapeProto proto; context->ShapeHandleToProto(shape, &proto); + VLOG(2) << "Node " << src_node->name() + << " has known shape: " << proto.DebugString(); if (dummy_node_images.find(src_node) == dummy_node_images.end()) { dummy_node_images[src_node] = AddDummyShapedNode(src_node, src_port, control_flow_info, @@ -1960,6 +1963,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( if (VLOG_IS_ON(2)) { TensorShapeProto proto; context->ShapeHandleToProto(shape, &proto); + VLOG(2) << "Node " << src_node->name() + << " has unknown shape: " << proto.DebugString(); } stack.push_back({src_node, false}); } @@ -2202,6 +2207,23 @@ Status Encapsulator::FindClusterDependencies() { } } } + if (VLOG_IS_ON(2)) { + // Print debug information. + VLOG(2) << "node_ancestors_map:"; + for (const auto& node_iter : node_ancestors_map) { + VLOG(2) << "\t" << node_iter.first->name() << ": subgraph = '" + << node_iter.second.subgraph + << "', outside_compilation_cluster = '" + << node_iter.second.outside_compilation_cluster + << "', ancestor_clusters: " + << (node_iter.second.ancestor_clusters.empty() ? "(empty)" : ""); + for (const auto& cluster_iter : node_iter.second.ancestor_clusters) { + VLOG(2) << "\t\tsubgraph = '" << cluster_iter.subgraph + << "', outside_compilation_cluster = '" + << cluster_iter.outside_compilation_cluster << "'"; + } + } + } return Status::OK(); } @@ -2409,7 +2431,7 @@ Status EncapsulateSubgraphsInFunctions( std::move(outside_compilation_attribute), &graph_in); TF_RETURN_IF_ERROR(encapsulator.FindClusterDependencies()); - TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs()); + TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs(library)); TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs( rewrite_subgraph_fn, reuse_existing_functions, library)); @@ -2458,7 +2480,7 @@ Status EncapsulateSubgraphsPass::Run( const GraphOptimizationPassOptions& options) { VLOG(1) << "EncapsulateSubgraphsPass::Run"; if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph, + dump_graph::DumpGraphToFile("encapsulate_subgraphs_before", **options.graph, options.flib_def); } @@ -2472,72 +2494,76 @@ Status EncapsulateSubgraphsPass::Run( FunctionLibraryRuntime* flr = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); - auto rewrite_subgraph = [flr](std::unique_ptr* subgraph, - std::vector* input_permutation, - std::vector* output_permutation, - NodeDef* node) { - // Optimize the subgraph. - OptimizeGraph(flr, subgraph); - - const int num_args = input_permutation->size(); - std::vector const_args(num_args); - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args)); - - DataTypeVector arg_types(num_args); - TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types)); - - // Compute a permutation of the arguments such that the constant arguments - // are first. - const int num_consts = - std::count(const_args.begin(), const_args.end(), true); - - const int num_resources = - std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE); - const int num_nonconsts = num_args - num_resources - num_consts; - if (num_nonconsts < 0) { - return errors::Internal("num_nonconsts should be >= 0, was ", - num_nonconsts); - } + auto rewrite_subgraph = + [flr](const std::vector& arg_source_tensors, + std::unique_ptr* subgraph, + std::vector* input_permutation, + std::vector* output_permutation, NodeDef* node) { + // Optimize the subgraph. + OptimizeGraph(flr, subgraph); + + const int num_args = input_permutation->size(); + std::vector const_args(num_args); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args)); + + DataTypeVector arg_types(num_args); + TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types)); + + // Compute a permutation of the arguments such that the constant + // arguments are first. + const int num_consts = + std::count(const_args.begin(), const_args.end(), true); + + const int num_resources = + std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE); + const int num_nonconsts = num_args - num_resources - num_consts; + if (num_nonconsts < 0) { + return errors::Internal("num_nonconsts should be >= 0, was ", + num_nonconsts); + } - int const_pos = 0; - int arg_pos = num_consts; - int resource_pos = num_consts + num_nonconsts; - for (int i = 0; i < num_args; ++i) { - if (const_args[i]) { - if (arg_types[i] == DT_RESOURCE) { - return errors::Internal( - "Resource arguments cannot be constant (argument ", i, ")"); + int const_pos = 0; + int arg_pos = num_consts; + int resource_pos = num_consts + num_nonconsts; + for (int i = 0; i < num_args; ++i) { + if (const_args[i]) { + if (arg_types[i] == DT_RESOURCE) { + return errors::Internal( + "Resource arguments cannot be constant (argument ", i, ")"); + } + (*input_permutation)[i] = const_pos; + ++const_pos; + } else if (arg_types[i] == DT_RESOURCE) { + (*input_permutation)[i] = resource_pos; + ++resource_pos; + } else { + (*input_permutation)[i] = arg_pos; + ++arg_pos; + } } - (*input_permutation)[i] = const_pos; - ++const_pos; - } else if (arg_types[i] == DT_RESOURCE) { - (*input_permutation)[i] = resource_pos; - ++resource_pos; - } else { - (*input_permutation)[i] = arg_pos; - ++arg_pos; - } - } - // Renumber argument nodes in the graph. - TF_RETURN_IF_ERROR(RenumberArguments(subgraph->get(), *input_permutation)); + // Renumber argument nodes in the graph. + TF_RETURN_IF_ERROR( + RenumberArguments(subgraph->get(), *input_permutation)); - // TODO(phawkins): add a forward is-constant analysis, similarly split - // outputs into host-memory constants and device-memory non-constants. + // TODO(phawkins): add a forward is-constant analysis, similarly split + // outputs into host-memory constants and device-memory non-constants. - AddNodeAttr(kXlaCompiledKernelAttr, true, node); - AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node); - AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node); - return Status::OK(); - }; + AddNodeAttr(kXlaCompiledKernelAttr, true, node); + AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node); + AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node); + return Status::OK(); + }; - TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( - kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph, - rewrite_subgraph, - /*reuse_existing_functions=*/false, &graph_out, library)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + EncapsulateSubgraphsInFunctions( + kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph, + rewrite_subgraph, /*reuse_existing_functions=*/false, &graph_out, + library), + "EncapsulateSubgraphsPass failed"); if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out, + dump_graph::DumpGraphToFile("encapsulate_subgraphs_after", *graph_out, options.flib_def); } diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index e5dab7c657c79afa861b0443314d0c7801e4b66d..926589546fec72048485d30966f31b24e44b1245 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -28,6 +28,9 @@ limitations under the License. namespace tensorflow { // A rewriting function to apply to each subgraph during encapsulation. +// 'arg_source_tensors' are the tensors corresponding to the arguments in the +// original source graph (*not* 'graph'). +// // 'graph' is the subgraph. The rewriting may renumber the inputs and outputs; // 'input_permutation' is a mapping from old argument numbers to new argument // numbers, whereas 'output_permutation' is the same for outputs. Both @@ -37,6 +40,7 @@ namespace tensorflow { // The rewrite may also change the NodeDef's operator name, and that // name will be used as the name of the generated function. typedef std::function& arg_source_tensors, std::unique_ptr* graph, std::vector* input_permutation, std::vector* output_permutation, NodeDef* node_def)> RewriteSubgraphFn; diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 6a7cd932e53cc0428850ab048cc325e84ef1fce6..c0543a00792235c5dd090e81930d8c219dc7f1a3 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -742,10 +742,13 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { Scope root = Scope::NewRootScope().ExitOnError().WithDevice( "/job:localhost/replica:0/task:0/cpu:0"); auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); - auto const_x2 = ops::Const(root.WithOpName("const_x2"), 10.0f); + auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT); + auto const_guarantee_x2 = + ops::GuaranteeConst(root.WithOpName("const_guarantee_x2"), x2); auto const_guarantee_x1 = ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1); - auto add1 = ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_x2); + auto add1 = + ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_guarantee_x2); add1.node()->AddAttr("_encapsulate", "encapsulate1"); Graph graph_before(OpRegistry::Global()); @@ -757,7 +760,8 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_encapsulate", "_outside", graph_before, /*rewrite_subgraph_fn=*/ - [&guaranteed_consts](std::unique_ptr* graph_ptr, + [&guaranteed_consts](const std::vector& arg_source_tensors, + std::unique_ptr* graph_ptr, std::vector* input_permutation, std::vector* output_permutation, NodeDef* call_def) { @@ -801,7 +805,8 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_encapsulate", "_outside", graph_before, /*rewrite_subgraph_fn=*/ - [&guaranteed_consts](std::unique_ptr* graph_ptr, + [&guaranteed_consts](const std::vector& arg_source_tensors, + std::unique_ptr* graph_ptr, std::vector* input_permutation, std::vector* output_permutation, NodeDef* call_def) { diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 4d49a14b24d53bbcb434560d59b8c97a17e18f86..c37b6112cc8a92047d495d057f59e2281710e678 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { @@ -23,15 +24,18 @@ namespace tensorflow { REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, + PartiallyDeclusterPass); + // The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We // also need to run it after the graph been rewritten to have _Send nodes added // for fetches. Before the _Send nodes are added, fetch nodes are identified by // name, and encapsulation might remove that node from the graph. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, EncapsulateSubgraphsPass); // Must run after EncapsulateSubgraphsPass. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40, BuildXlaLaunchOpsPass); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 00a6f4075f9a18efc3895b033eb6d08e36088a53..8f78c110cb15f3cbc0344d102764241996b0d7de 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -16,6 +16,7 @@ cc_library( "//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/jit:xla_launch_util", "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 902fe27acdec1cb323217e6409fbd02f62177612..7f4370b5b07b249bc9cf1f2ecf4086de359be68c 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -51,7 +52,11 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, if (device_type_ == DeviceType(DEVICE_CPU)) { platform_id_ = se::host::kHostPlatformId; } else if (device_type_ == DeviceType(DEVICE_GPU)) { - platform_id_ = se::cuda::kCudaPlatformId; + platform_id_ = ctx->device() + ->tensorflow_gpu_device_info() + ->stream->parent() + ->platform() + ->id(); } else { platform_id_ = nullptr; } @@ -115,6 +120,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { const XlaDevice::Metadata* metadata = nullptr; Status s = XlaDevice::GetMetadata(ctx, &metadata); bool allocate_xla_tensors = s.ok(); + bool use_multiple_streams = s.ok() && metadata->UseMultipleStreams(); // Get the platform_id_ for XLA_* devices. if (platform_id_ == nullptr) { @@ -148,6 +154,10 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { XlaCompiler::Options options; options.client = client; + if (ctx->op_device_context() != nullptr) { + options.device_ordinal = + ctx->op_device_context()->stream()->parent()->device_ordinal(); + } options.device_type = cache->device_type(); options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); @@ -166,14 +176,22 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { } XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; + // Optimization: don't resolve constants. If we resolve constants we never + // emit them on the device, meaning that if they are needed by a following + // computation the host has to transfer them. + compile_options.resolve_compile_time_constants = false; + // Optimization: where possible, have the computation return a naked array + // rather than a one-element tuple. + compile_options.always_return_tuple = false; + OP_REQUIRES_OK( ctx, cache->Compile(options, function_, constant_args, variables, ctx, &kernel, &executable, &compile_options)); VLOG(1) << "Executing XLA Computation..."; - XlaComputationLaunchContext launch_context(client, xla_allocator, - allocate_xla_tensors); + XlaComputationLaunchContext launch_context( + client, xla_allocator, allocate_xla_tensors, use_multiple_streams); launch_context.PopulateInputs(ctx, kernel, variables); // Execute the computation. @@ -182,7 +200,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { run_options.set_stream(stream); run_options.set_allocator(xla_allocator); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); - run_options.set_rng_seed(ctx->step_id()); + run_options.set_rng_seed(GetXLARandomSeed()); Env* env = Env::Default(); auto start_time = env->NowMicros(); @@ -192,7 +210,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; - launch_context.PopulateOutputs(ctx, kernel, run_result.ConsumeValueOrDie()); + OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs( + ctx, kernel, run_result.ConsumeValueOrDie())); VLOG(1) << "Done"; } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 8c3882116dd4f048ea3e32c037bf4139c67a3eb9..90d5d56998c75d6e2c7a64e8516591153a26f82d 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" @@ -28,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/memory_types.h" @@ -63,6 +65,7 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave // such nodes out of XLA clusters. if (HasForwardedRefInput(node)) { + VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast."; return false; } @@ -82,14 +85,13 @@ bool IsCompilableCall(const NodeDef& call_def, bool IsCompilableWhile(const Node& while_node, const DeviceType& jit_device_type, int depth, FunctionLibraryRuntime* lib_runtime) { - VLOG(2) << "Loop marking: " << while_node.type_string(); - const NameAttrList* name_attr; NodeDef call; Status status; status = GetNodeAttr(while_node.attrs(), "cond", &name_attr); if (!status.ok()) { - VLOG(2) << "Missing 'cond' attribute on While node."; + VLOG(2) << "Rejecting While " << while_node.name() + << ": missing 'cond' attribute on While node."; return false; } const string cond_func = name_attr->name(); @@ -97,12 +99,14 @@ bool IsCompilableWhile(const Node& while_node, call.set_op(cond_func); *call.mutable_attr() = name_attr->attr(); if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) { - VLOG(2) << "Can't compile loop condition: " << cond_func; + VLOG(2) << "Rejecting While " << while_node.name() + << ": can't compile loop condition: " << cond_func; return false; } status = GetNodeAttr(while_node.attrs(), "body", &name_attr); if (!status.ok()) { - VLOG(2) << "Missing 'body' attribute on While node."; + VLOG(2) << "Rejecting While " << while_node.name() + << ": missing 'body' attribute on While node."; return false; } const string body_func = name_attr->name(); @@ -110,10 +114,10 @@ bool IsCompilableWhile(const Node& while_node, call.set_op(body_func); *call.mutable_attr() = name_attr->attr(); if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) { - VLOG(2) << "Can't compile loop body: " << body_func; + VLOG(2) << "Rejecting While " << while_node.name() + << ": can't compile loop body: " << body_func; return false; } - VLOG(2) << "Loop is compilable."; return true; } @@ -123,10 +127,9 @@ bool IsCompilableWhile(const Node& while_node, bool IsCompilableCall(const NodeDef& call_def, const DeviceType& jit_device_type, int depth, FunctionLibraryRuntime* lib_runtime) { - VLOG(2) << "Function marking: " << call_def.op(); - if (depth > kMaxRecursionDepth) { - VLOG(2) << "Function depth limit exceeded"; + VLOG(2) << "Rejecting " << call_def.op() + << ": function depth limit exceeded."; return false; } @@ -134,7 +137,8 @@ bool IsCompilableCall(const NodeDef& call_def, Status status = lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle); if (!status.ok()) { - VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status; + VLOG(2) << "Rejecting " << call_def.op() + << ": could not instantiate: " << status; return false; } const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle); @@ -148,7 +152,8 @@ bool IsCompilableCall(const NodeDef& call_def, // tf2xla to translate the TF graph into XLA. So we avoid this for now. // // TODO(b/36139787): Create a mechanism to set inlining hints. - VLOG(2) << "Can't compile noinline function: " << fdef.DebugString(); + VLOG(2) << "Rejecting " << call_def.op() + << ": can't compile noinline function."; return false; } @@ -162,23 +167,14 @@ bool IsCompilableCall(const NodeDef& call_def, if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, depth + 1, lib_runtime)) { - VLOG(2) << "Function marking failed: unsupported op " << node->name() - << ": " << node->def().ShortDebugString(); + VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op " + << node->name() << ": " << node->def().ShortDebugString(); return false; } } - VLOG(2) << "Function is compilable: " << call_def.op(); return true; } -// Tests whether `node` has a DT_RESOURCE typed input or output. -bool HasResourceInputOrOutput(const Node& node) { - return std::find(node.input_types().begin(), node.input_types().end(), - DT_RESOURCE) != node.input_types().end() || - std::find(node.output_types().begin(), node.output_types().end(), - DT_RESOURCE) != node.output_types().end(); -} - // Returns true if the op can be decomposed into XLA ops for which // there are fusable elemental implementations. // @@ -355,24 +351,27 @@ Status FindCompilationCandidates( } std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID()); + if (fuel >= std::numeric_limits::max() / 2) { + // The assumption is that if fuel started out as INT64_MAX, it will forever + // stay greater than INT64_MAX / 2. + VLOG(2) << "Starting fuel: infinity"; + } else { + VLOG(2) << "Starting fuel: " << fuel; + } + for (Node* node : sorted_nodes) { - VLOG(2) << "Fuel: " << fuel; if (fuel <= 0) { - VLOG(2) + VLOG(1) << "Hit fuel limit; not marking any remaining ops as clusterable."; break; } - VLOG(2) << "FindCompilationCandidates(): Processing " - << node->DebugString(); - DeviceType device_type(""); TF_RETURN_IF_ERROR( DeviceToDeviceType(node->assigned_device_name(), &device_type)); if (is_compilable_fn && !is_compilable_fn(node, device_type)) { - VLOG(2) << "Compilation rejected node: not compilable " << node->name() - << ": " << node->type_string(); + // is_compilable_fn has already logged the reason if it returned false. continue; } @@ -382,14 +381,14 @@ Status FindCompilationCandidates( DeviceType jit_device_type(registration->compilation_device_name); if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) { - VLOG(2) << "Compilation rejected node: unsupported op " << node->name() - << ": " << node->type_string(); + VLOG(2) << "Rejecting " << node->name() << ": unsupported op " + << node->type_string(); continue; } if (!registration->compile_resource_ops && HasResourceInputOrOutput(*node)) { - VLOG(2) << "Compilation rejected node: resource input/output " - << node->name() << ": " << node->type_string(); + VLOG(2) << "Rejecting: " << node->name() << ": resource input/output " + << node->type_string(); continue; } if (node->type_string() == "While" && @@ -399,15 +398,11 @@ Status FindCompilationCandidates( // _Arg nodes in a top-level function represent feeds. // Do not compile them. if (node->type_string() == "_Arg") { - VLOG(2) << "Skipping jit compilation for '_Arg'-typed node " - << node->DebugString(); continue; } // _Retval nodes in a top-level function represent fetches. // Do not compile them. if (node->type_string() == "_Retval") { - VLOG(2) << "Compilation rejected node: return value " << node->name() - << ": " << node->type_string(); continue; } candidates->insert(node); @@ -460,33 +455,59 @@ Status MarkForCompilationPass::Run( VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; + VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit; const FunctionLibraryDefinition* fld = options.flib_def; - auto is_compilable = [global_jit_level, cpu_global_jit, fusion_only, fld]( - const Node* node, const DeviceType& device_type) { + std::unique_ptr deadness; + { + XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1); + TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(**options.graph, &deadness)); + } + + auto is_compilable = [&](const Node* node, const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { + VLOG(2) << "Rejecting " << node->name() << ": could not find JIT device."; return false; } - // Don't compile control trigger nodes. We won't preserve their deadness - // semantics correctly, so it's safest not to compile them. - if (node->IsControlTrigger()) return false; - // If this device requires a JIT, we must say yes. if (registration->requires_compilation) return true; // If there is a _XlaCompile annotation, use its value. bool compile = false; Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); - if (status.ok()) return compile; + if (status.ok()) { + if (!compile) { + VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" + << kXlaCompileAttr << ") is false."; + } + return compile; + } status = fld->GetAttr(*node, kXlaCompileAttr, &compile); - if (status.ok()) return compile; + if (status.ok()) { + if (!compile) { + VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" + << kXlaCompileAttr << ") on callee is false."; + } + return compile; + } + + // If inputs to `node` can have conflicting deadness (i.e. some are alive + // and some are dead) then don't compile it. XLA cannot represent the + // deadness semantics of these nodes correctly and auto-clustering these + // nodes can cause deadness to propagate to nodes that should be live. + if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) { + VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness."; + return false; + } // Check for fusable ops only if requested. if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) { + VLOG(2) << "Rejecting " << node->name() + << ": not fusable op but fusion_only enabled."; return false; } @@ -494,8 +515,17 @@ Status MarkForCompilationPass::Run( // Ignore enable_jit_by_default if global jit compilation for CPU // is explicitly requested via tf_xla_cpu_global_jit flag bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU; - return (ignore_registration || registration->enable_jit_by_default) && - global_jit_level > 0; + bool should_compile = + (ignore_registration || registration->enable_jit_by_default) && + global_jit_level > 0; + if (!should_compile) { + if (global_jit_level <= 0) { + VLOG(2) << "Rejecting " << node->name() << ": global jit disabled."; + } else { + VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled."; + } + } + return should_compile; }; return RunImpl(options, is_compilable); } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index e9acbfb19e42cb43cb0b986c438a569de29b2ebc..f1137af3c1e8539fda318d88d2c5b5187953ccab 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -40,20 +40,18 @@ class MarkForCompilationPass : public GraphOptimizationPass { Status Run(const GraphOptimizationPassOptions& options) override; - // Run() just calls RunImpl() if --tf_xla_auto_jit is enabled. To run the pass - // unconditionally, call RunImpl() directly. - // is_compilable_fn, if set, is a predicate that must be true for a node to - // be compiled. + private: Status RunImpl(const GraphOptimizationPassOptions& options, const std::function& is_compilable_fn = {}); + + friend class MarkForCompilationPassTestHelper; }; // Returns true iff 'ndef' is a call to a function that is compilable. A // function is compilable iff every operator in the function body is // compilable. bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef); - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 772c92d369e67f431b5d030d1d5cdc5ae2700d39..a780d4a936a3b757495c26d337f19c80a67f343a 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -38,27 +39,6 @@ namespace { REGISTER_OP("UncompilableNullary").Output("o: float"); REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float"); -Status MarkForCompilation(std::unique_ptr* graph, - FunctionLibraryDefinition* flib_def) { - // Assign all nodes to the CPU device. - static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; - for (Node* n : (*graph)->nodes()) { - n->set_assigned_device_name(kCpuDevice); - } - - GraphOptimizationPassOptions opt_options; - opt_options.graph = graph; - opt_options.flib_def = flib_def; - MarkForCompilationPass pass; - return pass.RunImpl(opt_options); -} - -Status MarkForCompilation(std::unique_ptr* graph) { - FunctionDefLibrary flib; - FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib); - return MarkForCompilation(graph, &flib_def); -} - std::unordered_map GetClusters(const Graph& graph) { std::unordered_map ids; for (Node* node : graph.nodes()) { @@ -87,7 +67,7 @@ TEST(XlaCompilationTest, Chains) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(4, clusters.size()); EXPECT_EQ(clusters["B"], clusters["C"]); @@ -112,7 +92,7 @@ TEST(XlaCompilationTest, UncompilableCycles) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); @@ -132,7 +112,7 @@ TEST(XlaCompilationTest, CompilableCycles) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(3, clusters.size()); @@ -155,7 +135,7 @@ TEST(XlaCompilationTest, Complex128Unsupported) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); } @@ -176,7 +156,7 @@ TEST(XlaCompilationTest, HalfSupported) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_FALSE(clusters.empty()); } @@ -205,7 +185,7 @@ TEST(XlaCompilationTest, ConcatWithConstArg) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(3, clusters.size()); // Everything should be compiled. } @@ -240,7 +220,8 @@ TEST(XlaCompilationTest, FunctionCalls) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph, &flib_def)); + TF_ASSERT_OK( + MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def)); auto clusters = GetClusters(*graph); EXPECT_EQ(2, clusters.size()); @@ -271,7 +252,7 @@ TEST(XlaCompilationTest, MetadataOpsDontStartClusters) { ops::UnaryOp("Shape", d, builder.opts().WithName("E")); TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(0, clusters.size()); // Nothing should be compiled. } @@ -358,7 +339,7 @@ TEST(XlaCompilationTest, SymbolicGradients) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(2, clusters.size()); @@ -383,7 +364,7 @@ TEST(XlaCompilationTest, Loops) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // Nothing should be compiled. In particular, 'd' and 'c' must not be @@ -410,7 +391,7 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // The computation is: C = A + relu(A) @@ -441,7 +422,7 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // The computation is: D = relu(A) + (A @ relu(A)) @@ -471,7 +452,7 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // The computation is: C = A @ relu(A) @@ -511,7 +492,7 @@ TEST(XlaCompilationTest, Resources) { ops::UnaryOp("Relu", d, builder.opts().WithName("E")); TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(0, clusters.size()); // Nothing should be compiled. } @@ -541,7 +522,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { TF_EXPECT_OK(root.ToGraph(graph.get())); - Status status = MarkForCompilation(&graph); + Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph); EXPECT_FALSE(status.ok()); EXPECT_TRUE(str_util::StrContains(status.ToString(), "Edge from c to a would create a cycle.\n" @@ -569,7 +550,7 @@ TEST(XlaCompilationTest, Retval) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(2, clusters.size()); @@ -587,7 +568,7 @@ TEST(XlaCompilationTest, DontCountIdentityOps) { auto r = ops::_Retval(root.WithOpName("R"), c, 0); } TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); @@ -603,7 +584,7 @@ TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) { auto r = ops::_Retval(root.WithOpName("R"), b, 0); } TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); @@ -617,7 +598,7 @@ TEST(XlaCompilationTest, ConstOp) { auto c = ops::Const(root.WithOpName("const"), 0.5f); c.node()->AddAttr(kXlaCompileAttr, true); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); EXPECT_EQ(1, GetClusters(*graph).size()); } @@ -628,7 +609,7 @@ TEST(XlaCompilationTest, ConstOp) { auto c = ops::Const(root.WithOpName("const"), string("string")); c.node()->AddAttr(kXlaCompileAttr, true); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); EXPECT_TRUE(GetClusters(*graph).empty()); } } @@ -643,7 +624,7 @@ TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map clusters = GetClusters(*graph); @@ -666,7 +647,7 @@ TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map clusters = GetClusters(*graph); @@ -680,5 +661,37 @@ TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) { EXPECT_EQ(clusters, expected_clusters); } +TEST(XlaCompilationTest, ClusterControlTrigger) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_BOOL, "tensor_a", + "sender", 0, "receiver"); + Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_BOOL, "tensor_b", + "sender", 0, "receiver"); + Output const_a = ops::Const(root.WithOpName("const_a"), 42); + + ops::ControlTrigger ctrl_trigger_a(root.WithOpName("ctrl_trigger_a")); + ops::ControlTrigger ctrl_trigger_b(root.WithOpName("ctrl_trigger_b")); + root.graph()->AddControlEdge(recv_a.node(), ctrl_trigger_a.operation.node()); + root.graph()->AddControlEdge(recv_b.node(), ctrl_trigger_a.operation.node()); + root.graph()->AddControlEdge(ctrl_trigger_b.operation.node(), const_a.node()); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + ASSERT_FALSE(clusters.empty()); + string cluster_name = clusters.begin()->second; + + // ctrl_trigger_a has inputs with mismatching deadness so it won't be + // clustered. ctrl_trigger_b is okay to cluster. + std::unordered_map expected_clusters( + {{"const_a", cluster_name}, {"ctrl_trigger_b", cluster_name}}); + EXPECT_EQ(clusters, expected_clusters); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc new file mode 100644 index 0000000000000000000000000000000000000000..a84b82e47923b2e7eec0e7eb848bd4377befbd07 --- /dev/null +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" + +namespace tensorflow { +/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def) { + // Assign all nodes to the CPU device. + static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; + for (Node* n : (*graph)->nodes()) { + n->set_assigned_device_name(kCpuDevice); + } + + GraphOptimizationPassOptions opt_options; + opt_options.graph = graph; + opt_options.flib_def = flib_def; + MarkForCompilationPass pass; + return pass.RunImpl(opt_options); +} + +/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( + std::unique_ptr* graph) { + FunctionDefLibrary flib; + FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib); + return MarkForCompilation(graph, &flib_def); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..b9a0531cb0e431a98d57a6d9a2e3e41b51e7b743 --- /dev/null +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_ +#define TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_ + +#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" + +namespace tensorflow { +class MarkForCompilationPassTestHelper { + public: + // Runs the MarkForCompilation pass on `graph` after assigning all nodes in + // `graph` to the CPU device. To make testing easier, ignores device + // registration, _XlaCompile attributes, input deadness and global jit level. + static Status MarkForCompilation(std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def); + + // Like `MarkForCompilation` but creates `flib_def` from the op registry. + static Status MarkForCompilation(std::unique_ptr* graph); +}; +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_ diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..68ead39424c35c1ef0bcc92e57af7931c0c57462 --- /dev/null +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -0,0 +1,177 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/partially_decluster_pass.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/framework/memory_types.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace tensorflow { +namespace { +Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, + gtl::ArraySlice post_order) { + // Find nodes that have at least one user outside their cluster that expects + // hostmem output. These nodes should be cloned to outside the cluster to + // avoid the device-host copy we'd otherwise need. + + MemoryTypeVector input_mtypes, output_mtypes; + + for (Node* n : post_order) { + gtl::optional from_cluster = GetXlaClusterForNode(*n); + if (!from_cluster) { + continue; + } + + // We assume the only XLA-auto-clusterable operations with side effects are + // resource variable updates. We can't execute these twice. + if (HasResourceInputOrOutput(*n)) { + continue; + } + + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(n->assigned_device_name(), &device_type)); + TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type, + n->def(), &input_mtypes, + &output_mtypes)); + for (const Edge* e : n->out_edges()) { + Node* dst = e->dst(); + + if (e->IsControlEdge()) { + continue; + } + + bool edge_incurs_extra_device_to_host_copy; + if (output_mtypes[e->src_output()] == DEVICE_MEMORY) { + // If the output of the *TensorFlow* operation is in DEVICE_MEMORY then + // keep the node clustered -- XLA will also produce the output in device + // memory and we will get some benefit from clustering. + edge_incurs_extra_device_to_host_copy = false; + } else { + MemoryTypeVector dst_input_mtypes, dst_output_mtypes; + DeviceType dst_device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(dst->assigned_device_name(), &dst_device_type)); + TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type, + dst->def(), &dst_input_mtypes, + &dst_output_mtypes)); + edge_incurs_extra_device_to_host_copy = + dst_input_mtypes[e->dst_input()] == HOST_MEMORY; + } + + if (!edge_incurs_extra_device_to_host_copy) { + continue; + } + + // Check if `dst` is in a different cluster, unclustered, or about to be + // partially declustered (here we rely on the post-order traversal order). + // If yes, decluster `n` to avoid the device-to-host memcpy. + gtl::optional dst_cluster = + result->count(dst) ? gtl::nullopt : GetXlaClusterForNode(*dst); + if (from_cluster != dst_cluster) { + CHECK(result->insert(n).second); + break; + } + } + } + return Status::OK(); +} + +Status PartiallyDeclusterNode(Graph* graph, Node* n) { + StringPiece cluster_name = *GetXlaClusterForNode(*n); + gtl::InlinedVector out_edges_to_clone; + for (const Edge* out_edge : n->out_edges()) { + if (out_edge->IsControlEdge()) { + continue; + } + + Node* dst = out_edge->dst(); + gtl::optional dst_cluster_name = GetXlaClusterForNode(*dst); + if (dst_cluster_name != cluster_name) { + out_edges_to_clone.push_back(out_edge); + } + } + + CHECK(!out_edges_to_clone.empty()) << n->DebugString(); + + NodeDef ndef = n->def(); + ndef.set_name(strings::StrCat(n->name(), "/declustered")); + RemoveFromXlaCluster(&ndef); + Status s; + Node* cloned_node = graph->AddNode(ndef, &s); + cloned_node->set_assigned_device_name(n->assigned_device_name()); + TF_RETURN_IF_ERROR(s); + + for (const Edge* in_edge : n->in_edges()) { + graph->AddEdge(in_edge->src(), in_edge->src_output(), cloned_node, + in_edge->dst_input()); + } + + for (const Edge* out_edge_to_clone : out_edges_to_clone) { + graph->AddEdge(cloned_node, out_edge_to_clone->src_output(), + out_edge_to_clone->dst(), out_edge_to_clone->dst_input()); + graph->RemoveEdge(out_edge_to_clone); + } + + return Status::OK(); +} +} // namespace + +Status PartiallyDeclusterPass::Run( + const GraphOptimizationPassOptions& options) { + // NB! In this pass we assume the only XLA-auto-clusterable operations that + // may have side effects are resource variable operations so we don't cluster + // those. The pass will have to be updated if this assumption becomes + // invalid. + + Graph* graph = options.graph->get(); + + // When deciding whether to decluster a particular node, we base our decision + // on if we've decided that some of its consumers have to be declustered too. + // Iterating the graph in post-order guarantees that consumers have been + // visited before producers. + std::vector post_order; + GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(), + /*edge_filter=*/[](const Edge& edge) { + return !edge.src()->IsNextIteration(); + }); + + gtl::FlatSet nodes_to_partially_decluster; + TF_RETURN_IF_ERROR(FindNodesToDecluster( + **options.graph, &nodes_to_partially_decluster, post_order)); + + if (VLOG_IS_ON(3)) { + for (Node* n : post_order) { + if (nodes_to_partially_decluster.count(n)) { + VLOG(3) << n->DebugString(); + } + } + } + + for (Node* n : post_order) { + if (nodes_to_partially_decluster.count(n)) { + TF_RETURN_IF_ERROR(PartiallyDeclusterNode(graph, n)); + } + } + + nodes_to_partially_decluster.clear(); + TF_RETURN_IF_ERROR(FindNodesToDecluster( + **options.graph, &nodes_to_partially_decluster, post_order)); + CHECK(nodes_to_partially_decluster.empty()); + + return Status::OK(); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.h b/tensorflow/compiler/jit/partially_decluster_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..6949b5028ee55e182b27589f9a9711dad7839e86 --- /dev/null +++ b/tensorflow/compiler/jit/partially_decluster_pass.h @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +// Clones nodes from within a cluster to outside the cluster if profitable. +// +// Today this only clones to avoid device-to-host copies, but in the future we +// may consider other reasons to clone. For instance, we convert this: +// +// ..... +// | +// v +// A_Clustered ====> C_Unclustered +// | +// v +// B_Clustered +// +// to: +// +// ..... +// | | +// | +-------------+ +// | | +// v v +// A_Clustered A_Unclustered ====> C_Unclustered +// | +// v +// B_Clustered +// +// where the ===> arrow has a hostmem source and destination and would entail a +// device to host copy if the source and destination were not in the same XLA +// cluster. +class PartiallyDeclusterPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_ diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..08a956e4c6478ff76a0fe8f1f60d94824daf535c --- /dev/null +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -0,0 +1,284 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/partially_decluster_pass.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +REGISTER_OP("FakeNullary").Output("out: float"); + +REGISTER_OP("FakeBinary") + .Input("host_in: float") + .Input("device_in: float") + .Output("host_out: float") + .Output("device_out: float"); + +REGISTER_OP("FakeResourceVar").Output("out: resource"); + +REGISTER_OP("FakeResourceUpdate") + .Input("in: resource") + .Output("out: resource") + .Output("something_else: float"); + +class FakeBinaryOp : public OpKernel { + public: + explicit FakeBinaryOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { CHECK(false); } +}; + +class FakeResourceVarUpdateOp : public OpKernel { + public: + explicit FakeResourceVarUpdateOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { CHECK(false); } +}; + +REGISTER_KERNEL_BUILDER(Name("FakeBinary") + .Device(DEVICE_CPU) + .HostMemory("host_in") + .HostMemory("host_out"), + FakeBinaryOp); + +REGISTER_KERNEL_BUILDER(Name("FakeResourceVarUpdate") + .Device(DEVICE_CPU) + .HostMemory("something_else"), + FakeResourceVarUpdateOp); + +Status PartiallyDecluster(std::unique_ptr* graph) { + FixupSourceAndSinkEdges(graph->get()); + // Assign all nodes to the CPU device. + static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; + for (Node* n : (*graph)->nodes()) { + n->set_assigned_device_name(kCpuDevice); + } + + GraphOptimizationPassOptions opt_options; + opt_options.graph = graph; + PartiallyDeclusterPass pass; + return pass.Run(opt_options); +} + +const Node* FindNodeByName(const Graph& graph, const string& name) { + for (const Node* node : graph.nodes()) { + if (node->name() == name) { + return node; + } + } + return nullptr; +} + +bool GetInputsForNode(const Graph& graph, const string& node_name, + std::vector* inputs) { + const Node* node = FindNodeByName(graph, node_name); + if (node == nullptr) { + return false; + } + for (const Edge* e : node->in_edges()) { + inputs->push_back(e->src()); + } + std::sort(inputs->begin(), inputs->end(), NodeComparatorName()); + return true; +} + +TEST(PartiallyDeclusterPassTest, ClusteredAndUnclustered) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer")); + ops::BinaryOp("FakeBinary", clustered_producer, input, + builder.opts().WithName("UnclusteredConsumer")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector unclustered_consumer_inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer", + &unclustered_consumer_inputs)); + ASSERT_EQ(unclustered_consumer_inputs.size(), 2); + EXPECT_EQ(unclustered_consumer_inputs[0]->name(), + "ClusteredProducer/declustered"); + EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input"); + + std::vector clustered_consumer_inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredConsumer", + &clustered_consumer_inputs)); + ASSERT_EQ(clustered_consumer_inputs.size(), 2); + EXPECT_EQ(clustered_consumer_inputs[0]->name(), "ClusteredProducer"); + EXPECT_EQ(clustered_consumer_inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DifferentClusters) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer")); + Node* consumer_in_different_cluster = + ops::BinaryOp("FakeBinary", clustered_producer, input, + builder.opts().WithName("ConsumerInDifferentCluster")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", input, {clustered_producer, 1}, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs)); + ASSERT_EQ(inputs.size(), 2); + EXPECT_EQ(inputs[0]->name(), "ClusteredProducer/declustered"); + EXPECT_EQ(inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DontDeclusterIfUserIsDeviceMem) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer")); + // The first input is hostmem and the second input is devicemem. + Node* consumer_in_different_cluster = + ops::BinaryOp("FakeBinary", input, clustered_producer, + builder.opts().WithName("ConsumerInDifferentCluster")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", input, {clustered_producer, 1}, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs)); + ASSERT_EQ(inputs.size(), 2); + EXPECT_EQ(inputs[0]->name(), "ClusteredProducer"); + EXPECT_EQ(inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DontDuplicateResourceVarOps) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* resource_var = ops::SourceOp("FakeResourceVar", + builder.opts().WithName("ResourceVar")); + Node* clustered_producer = + ops::UnaryOp("FakeResourceUpdate", resource_var, + builder.opts().WithName("ClusteredProducer")); + Node* consumer_in_different_cluster = + ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input, + builder.opts().WithName("ConsumerInDifferentCluster")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", input, {clustered_producer, 1}, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs)); + ASSERT_EQ(inputs.size(), 2); + EXPECT_EQ(inputs[0]->name(), "ClusteredProducer"); + EXPECT_EQ(inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer_0 = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer0")); + Node* clustered_producer_1 = + ops::BinaryOp("FakeBinary", clustered_producer_0, input, + builder.opts().WithName("ClusteredProducer1")); + ops::BinaryOp("FakeBinary", clustered_producer_1, input, + builder.opts().WithName("UnclusteredConsumer")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", {clustered_producer_1, 1}, input, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer_0->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_producer_1->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector unclustered_consumer_inputs, declustered_producer_1_inputs; + + ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer", + &unclustered_consumer_inputs)); + ASSERT_EQ(unclustered_consumer_inputs.size(), 2); + EXPECT_EQ(unclustered_consumer_inputs[0]->name(), + "ClusteredProducer1/declustered"); + EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input"); + + ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredProducer1/declustered", + &declustered_producer_1_inputs)); + ASSERT_EQ(declustered_producer_1_inputs.size(), 2); + EXPECT_EQ(declustered_producer_1_inputs[0]->name(), + "ClusteredProducer0/declustered"); + EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input"); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index 05b7821b8865d0f210ca9af92370e177d6043e80..0a025a1fc0b268963069a8c1a3be700040be3f8e 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -139,27 +139,32 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { }; for (Edge const* edge : graph->edges()) { - if (edge->dst()->IsEnter()) { - // Lift edges to an "Enter" node to the corresponding frame node. - const string& frame_name = - control_flow_info[edge->dst()->id()].frame_name; - int dst = GetOrAddFrameNodeId(frame_name); - if (!cycles->InsertEdge(edge->src()->id(), dst)) { - return errors::Internal( - "Cycle detected when adding enter->frame edge: ", - DescribeCycle(cycles, *graph, edge->src()->id(), dst)); + if (edge->dst()->IsEnter() || edge->src()->IsExit()) { + const char* src_type = "pre-enter"; + const char* dst_type = "post-exit"; + int src = edge->src()->id(); + int dst = edge->dst()->id(); + + if (edge->dst()->IsEnter()) { + // Lift edges to an "Enter" node to the corresponding frame node. + const string& frame_name = + control_flow_info[edge->dst()->id()].frame_name; + dst = GetOrAddFrameNodeId(frame_name); + dst_type = "frame"; } - continue; - } - if (edge->src()->IsExit()) { - // Lift edges from an "Exit" node to the corresponding frame node. - const string& frame_name = - control_flow_info[edge->src()->id()].frame_name; - int src = GetOrAddFrameNodeId(frame_name); - if (!cycles->InsertEdge(src, edge->dst()->id())) { + + if (edge->src()->IsExit()) { + // Lift edges from an "Exit" node to the corresponding frame node. + const string& frame_name = + control_flow_info[edge->src()->id()].frame_name; + src = GetOrAddFrameNodeId(frame_name); + src_type = "frame"; + } + + if (!cycles->InsertEdge(src, dst)) { return errors::Internal( - "Cycle detected when adding frame->exit edge: ", - DescribeCycle(cycles, *graph, src, edge->dst()->id())); + "Cycle detected when adding ", src_type, "->", dst_type, + " edge: ", DescribeCycle(cycles, *graph, src, dst)); } // Drop the original edge. continue; @@ -180,4 +185,26 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { return Status::OK(); } +gtl::optional GetXlaClusterForNode(const Node& node) { + const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr); + if (attr_value == nullptr) { + return gtl::nullopt; + } + Status s = AttrValueHasType(*attr_value, "string"); + if (!s.ok()) { + return gtl::nullopt; + } + return attr_value->s(); +} + +bool HasResourceInputOrOutput(const Node& node) { + return std::find(node.input_types().begin(), node.input_types().end(), + DT_RESOURCE) != node.input_types().end() || + std::find(node.output_types().begin(), node.output_types().end(), + DT_RESOURCE) != node.output_types().end(); +} + +void RemoveFromXlaCluster(NodeDef* node_def) { + node_def->mutable_attr()->erase(kXlaClusterAttr); +} } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index bcce082aaf6044ff0654efa4d78c0f493a350d00..bff76da6f9bcb06170e5aeb111da8545a6d291f8 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/lib/gtl/optional.h" namespace tensorflow { @@ -44,6 +45,16 @@ bool HasForwardedRefInput(const Node& node); // the enclosing graph. Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); +// Returns the XLA cluster in which `node` is placed if it is in an XLA cluster, +// otherwise returns nullopt. +gtl::optional GetXlaClusterForNode(const Node& node); + +// Removes `node_def` its XLA cluster (by clearing its _XlaCluster attribute). +void RemoveFromXlaCluster(NodeDef* node_def); + +// Returns true if `node` has a DT_RESOURCE typed input or output. +bool HasResourceInputOrOutput(const Node& node); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2cb351e1ecdb4523a8652886af156540e4736b18 --- /dev/null +++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_cluster_util.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0)); + Output enter = + ops::internal::Enter(root.WithOpName("enter"), a, "only_frame"); + Output exit = ops::internal::Exit(root.WithOpName("exit"), enter); + Output b = ops::Add(root.WithOpName("b"), a, exit); + + FixupSourceAndSinkEdges(root.graph()); + + GraphCycles cycles; + TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles)); + EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id())); +} + +TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0)); + Output enter_0 = + ops::internal::Enter(root.WithOpName("enter_0"), a, "frame_0"); + Output exit_0 = ops::internal::Exit(root.WithOpName("exit_0"), enter_0); + Output enter_1 = + ops::internal::Enter(root.WithOpName("enter_1"), a, "frame_1"); + Output exit_1 = ops::internal::Exit(root.WithOpName("exit_1"), enter_1); + Output b = ops::Add(root.WithOpName("b"), a, exit_1); + + FixupSourceAndSinkEdges(root.graph()); + + GraphCycles cycles; + TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles)); + EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id())); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 7ed609c43748062656b631243c01d790519c54fd..7140d47a9421ec73d0144e855b490f89569e6ae9 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -40,7 +40,23 @@ namespace tensorflow { XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client, DeviceType device_type) : client_(client), device_type_(std::move(device_type)) {} -XlaCompilationCache::~XlaCompilationCache() = default; +XlaCompilationCache::~XlaCompilationCache() { + // Ensure any use of our programs have completed by waiting for all stream + // executors to complete. + for (auto* executor : client_->backend().stream_executors()) { + bool ok = executor->SynchronizeAllActivity(); + if (!ok) { + LOG(ERROR) << "Error synchronizing activity while waiting for all " + "programs to complete"; + } + } + // TODO(b/110813685): Think about the program ownership model. Programs are + // currently owned by the compilation cache which means we must wait for + // program completion in the destructor. There are multiple compilation caches + // around, which complicates things a little. Perhaps having programs be + // shared_ptrs (an invasive change) would make the model easier to reason + // about? +} string XlaCompilationCache::DebugString() { return "XLA JIT compilation cache"; @@ -193,7 +209,9 @@ Status XlaCompilationCache::BuildExecutable( argument_layouts[i] = &result.xla_input_shapes[i]; } xla::ExecutableBuildOptions build_options; - build_options.set_device_ordinal(client_->default_device_ordinal()); + build_options.set_device_ordinal(options.device_ordinal != -1 + ? options.device_ordinal + : client_->default_device_ordinal()); build_options.set_result_layout(result.xla_output_shape); build_options.set_device_allocator(options.device_allocator); @@ -240,6 +258,7 @@ Status XlaCompilationCache::CompileImpl( xla::LocalExecutable** executable, const XlaCompiler::CompileOptions* compile_options, bool compile_single_op) { + CHECK_NE(executable, nullptr); VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { @@ -277,7 +296,7 @@ Status XlaCompilationCache::CompileImpl( // protect the contents of the cache entry. Entry* entry; { - mutex_lock lock(mu_); + mutex_lock lock(compile_cache_mu_); // Find or create a cache entry. std::unique_ptr& e = cache_[signature]; if (!e) { @@ -293,6 +312,8 @@ Status XlaCompilationCache::CompileImpl( if (!entry->compiled) { VLOG(1) << "Compilation cache miss for signature: " << SignatureDebugString(signature); + tensorflow::Env* env = tensorflow::Env::Default(); + const uint64 compile_start_us = env->NowMicros(); // Do the actual JIT compilation without holding the lock (it can take // a long time.) std::vector args; @@ -311,18 +332,35 @@ Status XlaCompilationCache::CompileImpl( compile_options ? *compile_options : XlaCompiler::CompileOptions(), function, args, &entry->compilation_result); } - } - *compilation_result = &entry->compilation_result; - if (entry->compilation_status.ok() && executable) { - if (entry->executable == nullptr) { - entry->compilation_status = BuildExecutable( - options, entry->compilation_result, &entry->executable); + TF_RETURN_IF_ERROR(entry->compilation_status); + CHECK_EQ(entry->executable.get(), nullptr); + entry->compilation_status = + BuildExecutable(options, entry->compilation_result, &entry->executable); + + const uint64 compile_end_us = env->NowMicros(); + const uint64 compile_time_us = compile_end_us - compile_start_us; + { + mutex_lock lock(compile_stats_mu_); + auto it = compile_stats_.emplace(function.name(), CompileStats{}).first; + it->second.compile_count++; + it->second.cumulative_compile_time_us += compile_time_us; + VLOG(1) << "compiled " << function.name() << " " + << it->second.compile_count + << " times, compile time: " << compile_time_us + << " us, cumulative: " << it->second.cumulative_compile_time_us + << " us (" + << tensorflow::strings::HumanReadableElapsedTime(compile_time_us / + 1.0e6) + << " / " + << tensorflow::strings::HumanReadableElapsedTime( + it->second.cumulative_compile_time_us / 1.0e6) + << ")"; } - *executable = entry->executable.get(); } - - Status status = entry->compilation_status; - return status; + TF_RETURN_IF_ERROR(entry->compilation_status); + *compilation_result = &entry->compilation_result; + *executable = entry->executable.get(); + return Status::OK(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index be1043d8c3fc0573922837e541615114a6d7a1a5..fc5f008f4f52c32d97e680784082d0e7bcb7d8eb 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -150,9 +151,22 @@ class XlaCompilationCache : public ResourceBase { std::unique_ptr executable GUARDED_BY(mu); }; - mutex mu_; - std::unordered_map, Signature::Hash> cache_ - GUARDED_BY(mu_); + mutex compile_cache_mu_; + gtl::FlatMap, Signature::Hash> cache_ + GUARDED_BY(compile_cache_mu_); + + struct CompileStats { + // Number of times the cluster has been (re-)compiled. + int64 compile_count = 0; + + // Cumulative time spent compiling the cluster. + int64 cumulative_compile_time_us = 0; + }; + mutex compile_stats_mu_; + + // Maps cluster names to compilation statistics for said cluster. + gtl::FlatMap compile_stats_ + GUARDED_BY(compile_stats_mu_); TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); }; diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index b1943d3e1a7e321b5a3796a0c6e4f2b5d9ac7018..dd84fb34c171f8d2174444ddd3b3b476e7142718 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -53,7 +54,9 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, // Builds an XLA allocator for the device. XlaComputationLaunchContext launch_context( - client, client->backend().memory_allocator(), true); + client, client->backend().memory_allocator(), + /*allocate_xla_tensors=*/true, + /*use_multiple_streams=*/metadata.UseMultipleStreams()); launch_context.PopulateInputs(ctx, result, variables); @@ -61,17 +64,22 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; TF_RET_CHECK(stream); - VLOG(2) << "Executing computation."; + VLOG(2) << "Executing computation: " << name(); + for (const xla::ShapedBuffer* arg : launch_context.arguments()) { + VLOG(2) << name() << ": " << *arg; + } xla::ExecutableRunOptions run_options; run_options.set_stream(stream); run_options.set_allocator(client->backend().memory_allocator()); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); - run_options.set_rng_seed(ctx->step_id()); + run_options.set_rng_seed(GetXLARandomSeed()); - auto run_result = executable->Run(launch_context.arguments(), run_options); + xla::StatusOr run_result = + executable->Run(launch_context.arguments(), run_options); TF_RETURN_IF_ERROR(run_result.status()); - launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie()); + TF_RETURN_IF_ERROR(launch_context.PopulateOutputs( + ctx, result, run_result.ConsumeValueOrDie())); return Status::OK(); } @@ -159,6 +167,13 @@ Status XlaCompileOnDemandOp::Compile( XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; + // Optimization: don't resolve constants. If we resolve constants we never + // emit them on the device, meaning that if they are needed by a following + // computation the host has to transfer them. + compile_options.resolve_compile_time_constants = false; + // Optimization: where possible, have the computation return a naked array + // rather than a one-element tuple. + compile_options.always_return_tuple = false; std::map variable_args = GetVariables(ctx); return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 43648402f65c656b6b4eb2e83e61ce45f1c73669..7e159e3171113b0d53f03bb676ac9c21db7fe77a 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -54,6 +54,7 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, DEVICE_CPU_XLA_JIT, options, name_prefix, registration, /*transfer_as_literal=*/false, + /*use_multiple_streams=*/false, /*shape_representation_fn=*/{}, /*padded_shape_fn=*/{}, &device)); devices->push_back(device.release()); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index ed007d603ea1b3d27dd25f00726261cdd029c20c..2a2691a6a404520da4df451293ec0cb6028a165d 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/dma_helper.h" @@ -130,7 +131,7 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { const string& jit_device_name, const SessionOptions& options, const string& name_prefix, const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, + bool transfer_as_literal, bool use_multiple_streams, const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, const PaddedShapeFn& padded_shape_fn, std::unique_ptr* device) { VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" @@ -151,22 +152,24 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), strings::StrCat("device: ", device_name, " device")); - device->reset(new XlaDevice( - options, attrs, device_ordinal, DeviceType(jit_device_name), - platform.ValueOrDie(), transfer_as_literal, shape_representation_fn, - padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn)); + device->reset( + new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name), + platform.ValueOrDie(), transfer_as_literal, + use_multiple_streams, shape_representation_fn, + padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn)); return Status::OK(); } XlaDevice::Metadata::Metadata( int device_ordinal, se::Platform* platform, const DeviceType& device_type, XlaCompiler::ShapeRepresentationFn shape_representation_fn, - PaddedShapeFn padded_shape_fn) + PaddedShapeFn padded_shape_fn, bool use_multiple_streams) : device_ordinal_(device_ordinal), device_type_(device_type), platform_(platform), shape_representation_fn_(std::move(shape_representation_fn)), - padded_shape_fn_(std::move(padded_shape_fn)) {} + padded_shape_fn_(std::move(padded_shape_fn)), + use_multiple_streams_(use_multiple_streams) {} int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; } @@ -200,24 +203,29 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { XlaDevice::XlaDevice( const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, - se::Platform* platform, bool transfer_as_literal, + se::Platform* platform, bool transfer_as_literal, bool use_multiple_streams, const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, const PaddedShapeFn& padded_shape_fn) : LocalDevice(options, attrs), xla_metadata_(device_ordinal, platform, jit_device_name, - shape_representation_fn, padded_shape_fn), + shape_representation_fn, padded_shape_fn, + use_multiple_streams), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), - xla_allocator_(nullptr), platform_(platform), + use_multiple_streams_(use_multiple_streams), transfer_as_literal_(transfer_as_literal), shape_representation_fn_(shape_representation_fn) { - VLOG(1) << "Created XLA device " << jit_device_name; + VLOG(1) << "Created XLA device " << jit_device_name << " " << this; + thread_pool_.reset(new thread::ThreadPool(options.env, "xla_device", + /*num_threads=*/1)); } XlaDevice::~XlaDevice() { - if (gpu_device_info_ != nullptr) { - gpu_device_info_->default_context->Unref(); + VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this; + mutex_lock lock(mu_); + if (device_context_) { + device_context_->Unref(); } } @@ -233,6 +241,11 @@ xla::LocalClient* XlaDevice::client() const { } Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) { + mutex_lock lock(mu_); + return GetAllocatorLocked(attr); +} + +Allocator* XlaDevice::GetAllocatorLocked(AllocatorAttributes attr) { if (attr.on_host()) { return cpu_allocator(); } @@ -245,52 +258,111 @@ Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) { return xla_allocator_; } -xla::StatusOr XlaDevice::GetStream() { - if (!stream_) { - xla::Backend* backend = client()->mutable_backend(); - TF_ASSIGN_OR_RETURN(stream_, backend->BorrowStream(device_ordinal_)); +Status XlaDevice::EnsureDeviceContextOk() { + mutex_lock lock(mu_); + return GetDeviceContextLocked().status(); +} + +Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend, + const string& name, + std::shared_ptr* stream, + bool* stream_was_changed) { + if (!(*stream) || !(*stream)->ok()) { + xla::StreamPool::Ptr ptr; + TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_)); + *stream = std::shared_ptr(std::move(ptr)); + VLOG(1) << "XlaDevice " << this << " new " << name << " " + << (*stream)->DebugStreamPointers(); + *stream_was_changed = true; } - return stream_.get(); + return Status::OK(); } -Status XlaDevice::CreateAndSetGpuDeviceInfo() { - if (gpu_device_info_ == nullptr) { - TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - // Call GetAllocator for the side-effect of ensuring the allocator - // is created. - GetAllocator({}); - // XlaDevice owns both gpu_device_info_ and - // gpu_device_info_->default_context. - gpu_device_info_ = MakeUnique(); - gpu_device_info_->stream = stream; - gpu_device_info_->default_context = new XlaDeviceContext( - stream, client(), transfer_as_literal_, shape_representation_fn_); - set_tensorflow_gpu_device_info(gpu_device_info_.get()); +xla::StatusOr XlaDevice::GetDeviceContextLocked() { + xla::Backend* backend = client()->mutable_backend(); + + // Ensure all our streams are valid, borrowing new streams if necessary. + bool need_new_device_context = !device_context_; + TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_, + &need_new_device_context)); + + std::shared_ptr host_to_device_stream = stream_; + std::shared_ptr device_to_host_stream = stream_; + if (use_multiple_streams_) { + TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream", + &host_to_device_stream_, + &need_new_device_context)); + TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream", + &device_to_host_stream_, + &need_new_device_context)); + host_to_device_stream = host_to_device_stream_; + device_to_host_stream = device_to_host_stream_; } - return Status::OK(); + if (!need_new_device_context) { + return device_context_; + } + + // At this point we know we need a new device context. + // Call GetAllocator for the side-effect of ensuring the allocator is created. + GetAllocatorLocked({}); + if (device_context_) { + device_context_->Unref(); + } + // The XlaDeviceContext keeps a reference count to the streams, and the + // XlaDeviceContext remains live for the duration of a Executor run. This + // ensures that the streams remain live for the duration of a run, even if + // an error is encountered and the streams are replaced with new ones. + device_context_ = new XlaDeviceContext( + stream_, host_to_device_stream, device_to_host_stream, client(), + transfer_as_literal_, shape_representation_fn_, thread_pool_.get()); + VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext " + << device_context_; + + // Create and set a new GpuDeviceInfo, if necessary. + // + // TODO(b/78232898): This isn't thread-safe; there is a race between the call + // to set_tensorflow_gpu_device_info() with ops that call the getter + // tensorflow_gpu_device_info(). This isn't trivially fixed by adding locking + // to those methods; see the bug for details. Our only saving grace at the + // moment is that this race doesn't seem to occur in practice. + if (use_gpu_device_info_) { + auto gpu_device_info = MakeUnique(); + gpu_device_info->stream = stream_.get(); + gpu_device_info->default_context = device_context_; + set_tensorflow_gpu_device_info(gpu_device_info.get()); + gpu_device_info_ = std::move(gpu_device_info); + VLOG(1) << "XlaDevice " << this << " new GpuDeviceInfo " + << gpu_device_info_.get(); + } + + return device_context_; +} + +Status XlaDevice::UseGpuDeviceInfo() { + mutex_lock lock(mu_); + use_gpu_device_info_ = true; + return GetDeviceContextLocked().status(); } Status XlaDevice::FillContextMap(const Graph* graph, DeviceContextMap* device_context_map) { VLOG(1) << "XlaDevice::FillContextMap"; + mutex_lock lock(mu_); + TF_ASSIGN_OR_RETURN(XlaDeviceContext * device_context, + GetDeviceContextLocked()); + device_context_map->resize(graph->num_node_ids()); - TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - // Call GetAllocator for the side-effect of ensuring the allocator is created. - GetAllocator({}); - auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_, - shape_representation_fn_); for (Node* n : graph->nodes()) { VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); - ctx->Ref(); - (*device_context_map)[n->id()] = ctx; + device_context->Ref(); + (*device_context_map)[n->id()] = device_context; } - ctx->Unref(); return Status::OK(); } void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { - VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":" + VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":" << op_kernel->type_string(); // When Xprof profiling is off (which is the default), constructing the // activity is simple enough that its overhead is negligible. @@ -301,13 +373,29 @@ void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) { - VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" + VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" << op_kernel->type_string(); tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(), op_kernel->IsExpensive()); op_kernel->ComputeAsync(context, done); } +Status XlaDevice::Sync() { + VLOG(1) << "XlaDevice::Sync"; + std::shared_ptr stream; + { + mutex_lock lock(mu_); + stream = stream_; + } + if (!stream) return Status::OK(); + + if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) { + return errors::Internal("XlaDevice::Sync() failed."); + } + VLOG(1) << "XlaDevice::Sync completed"; + return Status::OK(); +} + Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor) { @@ -323,16 +411,17 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, if (alloc_attrs.on_host()) { *tensor = parsed; } else { - Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); + mutex_lock lock(mu_); + TF_ASSIGN_OR_RETURN(XlaDeviceContext * device_context, + GetDeviceContextLocked()); + Allocator* allocator = GetAllocatorLocked(alloc_attrs); + Tensor copy(allocator, parsed.dtype(), parsed.shape()); Notification n; - TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - XlaTransferManager manager(stream, client(), transfer_as_literal_, - shape_representation_fn_); - manager.CopyCPUTensorToDevice(&parsed, this, ©, - [&n, &status](const Status& s) { - status = s; - n.Notify(); - }); + device_context->CopyCPUTensorToDevice(&parsed, this, ©, + [&n, &status](const Status& s) { + status = s; + n.Notify(); + }); n.WaitForNotification(); *tensor = copy; } diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 02e88ee6793e984a7b782790f8011cbcbc5a5026..dbf35f349f84268ebac0f73a86c9ca0704e90835 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -25,6 +25,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ +#include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -39,6 +40,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace tensorflow { @@ -57,7 +59,7 @@ class XlaDevice : public LocalDevice { Metadata(int device_ordinal, se::Platform* platform, const DeviceType& device_type, XlaCompiler::ShapeRepresentationFn shape_representation_fn, - PaddedShapeFn padded_shape_fn); + PaddedShapeFn padded_shape_fn, bool use_multiple_streams); // The index of the device on this host. int device_ordinal() const; @@ -70,12 +72,15 @@ class XlaDevice : public LocalDevice { } const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; } + bool UseMultipleStreams() const { return use_multiple_streams_; } + private: const int device_ordinal_; const DeviceType device_type_; se::Platform* platform_; // Not owned. XlaCompiler::ShapeRepresentationFn shape_representation_fn_; PaddedShapeFn padded_shape_fn_; + const bool use_multiple_streams_; TF_DISALLOW_COPY_AND_ASSIGN(Metadata); }; @@ -89,6 +94,8 @@ class XlaDevice : public LocalDevice { // 'transfer_as_literal' is true if device<->host transfers must be done using // XLA's TransferLiteral{To,From}Device interface. If false, we can use // ThenMemcpy instead. + // If 'use_multiple_streams' is true, we create separate streams for + // host-to-device and device-to-host communication. // If padded_shape_fn is empty, a default implementation that returns // the on-host shape is used. static Status Create( @@ -96,7 +103,7 @@ class XlaDevice : public LocalDevice { int device_ordinal, const string& jit_device_name, const SessionOptions& options, const string& name_prefix, const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, + bool transfer_as_literal, bool use_multiple_streams, const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, const PaddedShapeFn& padded_shape_fn, std::unique_ptr* device); @@ -106,54 +113,93 @@ class XlaDevice : public LocalDevice { XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, se::Platform* platform, bool transfer_as_literal, + bool use_multiple_streams, const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, const PaddedShapeFn& padded_shape_fn); ~XlaDevice() override; - Allocator* GetAllocator(AllocatorAttributes attr) override; + Allocator* GetAllocator(AllocatorAttributes attr) override + LOCKS_EXCLUDED(mu_); void Compute(OpKernel* op_kernel, OpKernelContext* context) override; void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; - Status Sync() override { return Status::OK(); } + Status Sync() override; Status FillContextMap(const Graph* graph, - DeviceContextMap* device_context_map) override; + DeviceContextMap* device_context_map) override + LOCKS_EXCLUDED(mu_); Status MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, - Tensor* tensor) override; + Tensor* tensor) override LOCKS_EXCLUDED(mu_); - xla::LocalClient* client() const; const Metadata& metadata() { return xla_metadata_; } - xla::StatusOr GetStream(); - // If not already set, create and set GpuDeviceInfo. - // Not thread-safe - Status CreateAndSetGpuDeviceInfo(); + // Ensures the DeviceContext associated with this XlaDevice is created and + // valid (i.e. all streams are ok). If any state is not valid, a new + // DeviceContext will be created. + // + // TODO(b/111859745): The Eager context needs to call this method to recover + // from failures. + Status EnsureDeviceContextOk() LOCKS_EXCLUDED(mu_); + + // Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra + // information for GPU and TPU devices. + Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_); private: + xla::LocalClient* client() const; + Allocator* GetAllocatorLocked(AllocatorAttributes attr) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + Status EnsureStreamOkLocked(xla::Backend* backend, const string& name, + std::shared_ptr* stream, + bool* stream_was_changed) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + xla::StatusOr GetDeviceContextLocked() + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + mutex mu_; // The metadata of this XlaDevice. const Metadata xla_metadata_; // Which hardware device in the client's platform this XlaDevice controls. const int device_ordinal_; // The name of the device that is used to compile Ops for this XlaDevice. - DeviceType jit_device_name_; + const DeviceType jit_device_name_; + // The platform for this device. + se::Platform* const platform_; // Not owned. // Memory allocator associated with this device. - Allocator* xla_allocator_; // Not owned. - se::Platform* platform_; // Not owned. + Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned. // Stream associated with this device. Operations enqueued on this // stream are executed on the device. Operations include data // copying back and forth between CPU and the device, and // computations enqueued by XLA. - xla::Backend::StreamPtr stream_; + std::shared_ptr stream_ GUARDED_BY(mu_); + // If false, only stream_ is valid and all computation and transfers use + // stream_. If true, computation is performed by stream_ and transfers are + // performed by host_to_device/device_to_host_stream. + const bool use_multiple_streams_; + // If use_multiple_streams_, host to device transfers are performed using this + // stream. + std::shared_ptr host_to_device_stream_ GUARDED_BY(mu_); + // If use_multiple_streams_, device to host transfers are performed using this + // stream. + std::shared_ptr device_to_host_stream_ GUARDED_BY(mu_); // Must we use XLA's transfer manager for correct host<->device transfers? if // false, we can use ThenMemcpy() instead. - bool transfer_as_literal_; - XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + const bool transfer_as_literal_; + const XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + + // The device context accessed by all users of the XlaDevice, set by calls to + // EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is + // also filled in to that struct. XlaDeviceContext is a ref-counted object. + XlaDeviceContext* device_context_ GUARDED_BY(mu_) = nullptr; + + // Holds extra information for GPU and TPU devices, e.g. the device context. + bool use_gpu_device_info_ GUARDED_BY(mu_) = false; + std::unique_ptr gpu_device_info_ GUARDED_BY(mu_); - // If set, holds default device context (that we must Unref) - // and its stream. - std::unique_ptr gpu_device_info_; + // Thread pool used for running closures + std::unique_ptr thread_pool_; }; // Builds OpKernel registrations on 'device' for the JIT operators diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 71e63b110b3b132a57fc291e53a165954c72a03c..0a0c0892411e8ebcd5624a29f3bd020fe6483944 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device_context.h" +#include + +#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -48,17 +51,27 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } XlaTransferManager::XlaTransferManager( - se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn) - : stream_(stream), + std::shared_ptr compute_stream, + std::shared_ptr host_to_device_stream, + std::shared_ptr device_to_host_stream, xla::LocalClient* client, + bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool) + : stream_(std::move(compute_stream)), + host_to_device_stream_(std::move(host_to_device_stream)), + device_to_host_stream_(std::move(device_to_host_stream)), client_(client), transfer_manager_(client->backend().transfer_manager()), transfer_as_literal_(transfer_as_literal), - shape_representation_fn_(std::move(shape_representation_fn)) { + shape_representation_fn_(std::move(shape_representation_fn)), + thread_pool_(thread_pool) { + CHECK(host_to_device_stream_ != nullptr); + CHECK(device_to_host_stream_ != nullptr); + CHECK(stream_ != nullptr); if (!shape_representation_fn_) { - shape_representation_fn_ = [](const TensorShape& shape, DataType dtype) { - return shape; - }; + shape_representation_fn_ = + [](const TensorShape& shape, + DataType dtype) -> xla::StatusOr { return shape; }; } } @@ -67,99 +80,136 @@ Status XlaTransferManager::TransferLiteralToDevice( xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), host_tensor.shape(), &xla_shape)); - xla::BorrowingLiteral literal( + // Create a reference to hold onto host_tensor until after the literal has + // been transferred. Also make sure the literal exists until the function + // asynchronously completes, as it will be wrapped in an xla::LiteralSlice. + TensorReference ref(host_tensor); + auto literal = std::make_shared( static_cast(DMAHelper::base(&host_tensor)), xla_shape); - const xla::ShapedBuffer& shaped_buffer = - XlaTensor::FromTensor(device_tensor)->shaped_buffer(); - VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " " + XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); + const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer(); + VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " " << shaped_buffer.ToString(); - return transfer_manager_->TransferLiteralToDevice(stream_->parent(), literal, - shaped_buffer); + if (UseMultipleStreams()) { + // Initially wait for the compute stream so that memory allocations are + // synchronized. + host_to_device_stream_->ThenWaitFor(stream_.get()); + } + TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync( + host_to_device_stream_.get(), *literal, shaped_buffer)); + if (UseMultipleStreams()) { + auto event = std::make_shared(stream_->parent()); + TF_RET_CHECK(event->Init()) << "Event failed to initialize!"; + host_to_device_stream_->ThenRecordEvent(event.get()); + xla_tensor->SetDefinedOn(host_to_device_stream_.get(), std::move(event)); + } + // Unref the host tensor, and capture the literal shared_ptr too so it goes + // out of scope when the lambda completes. + host_to_device_stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); }); + + return Status::OK(); } -Status XlaTransferManager::TransferLiteralFromDevice( - Tensor* host_tensor, const Tensor& device_tensor) const { +void XlaTransferManager::TransferLiteralFromDevice( + Tensor* host_tensor, const Tensor& device_tensor, + const StatusCallback& done) const { + xla::MutableBorrowingLiteral literal; + TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(host_tensor, &literal)); + const xla::ShapedBuffer& shaped_buffer = XlaTensor::FromTensor(&device_tensor)->shaped_buffer(); - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - transfer_manager_->TransferLiteralFromDevice( - stream_->parent(), shaped_buffer)); - VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " " - << shaped_buffer.ToString(); - Tensor tensor; - TF_RETURN_IF_ERROR( - LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); - // Reshape the tensor back to its declared shape. - if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { - return errors::Internal( - "Tensor::CopyFrom failed when copying from XLA device to CPU"); - } - return Status::OK(); + TensorReference ref(device_tensor); + transfer_manager_->TransferLiteralFromDevice( + device_to_host_stream_.get(), shaped_buffer, literal, + [=, &shaped_buffer, &literal](xla::Status status) { + ref.Unref(); + done([&]() -> Status { + VLOG(1) << "Transfer from device as literal: " << literal.ToString() + << " " << shaped_buffer.ToString(); + return status; + }()); + }); } void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const { - if (cpu_tensor->NumElements() > 0) { - VLOG(2) << "CopyCPUTensorToDevice " - << reinterpret_cast(cpu_tensor->tensor_data().data()) - << " " - << reinterpret_cast( - device_tensor->tensor_data().data()) - << " " << cpu_tensor->NumElements() << " " - << cpu_tensor->shape().DebugString() << " " - << device_tensor->shape().DebugString(); - - void* src_ptr = const_cast(DMAHelper::base(cpu_tensor)); - const int64 total_bytes = cpu_tensor->TotalBytes(); - - XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); - CHECK(xla_tensor); - - TensorShape shape = shape_representation_fn_(device_tensor->shape(), - device_tensor->dtype()); - if (!xla_tensor->has_shaped_buffer()) { - Status s = xla_tensor->AllocateShapedBuffer( - device_tensor->dtype(), shape, client_, - stream_->parent()->device_ordinal()); - if (!s.ok()) { - done(s); - return; - } - } + if (cpu_tensor->NumElements() == 0) { + VLOG(2) << "CopyCPUTensorToDevice empty tensor"; + done(Status::OK()); + return; + } - Status status; - if (transfer_as_literal_) { - Tensor reshaped_cpu_tensor; - if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) { - done(errors::Internal( - "Tensor::CopyFrom failed when copying from CPU to XLA device")); - return; - } - status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); - } else { - se::DeviceMemoryBase dev_dst_ptr = - XlaTensor::DeviceMemoryFromTensor(*device_tensor); - stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); - // TODO(hpucha): Make this asynchronous. - Status block_status = stream_->BlockHostUntilDone(); - if (!block_status.ok()) { - status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", stream_, - block_status.error_message().c_str()); - } - } - xla_tensor->set_host_tensor(*cpu_tensor); + VLOG(2) << "CopyCPUTensorToDevice " + << reinterpret_cast(cpu_tensor->tensor_data().data()) + << " " + << reinterpret_cast(device_tensor->tensor_data().data()) + << " " << cpu_tensor->NumElements() << " " + << cpu_tensor->shape().DebugString() << " " + << device_tensor->shape().DebugString(); + + void* src_ptr = const_cast(DMAHelper::base(cpu_tensor)); + const int64 total_bytes = cpu_tensor->TotalBytes(); + + XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); + CHECK(xla_tensor); - done(status); + xla::StatusOr shape_or_status = + shape_representation_fn_(device_tensor->shape(), device_tensor->dtype()); + if (!shape_or_status.ok()) { + done(shape_or_status.status()); return; } + TensorShape shape = shape_or_status.ValueOrDie(); + if (!xla_tensor->has_shaped_buffer()) { + Status s = + xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_, + stream_->parent()->device_ordinal()); + if (!s.ok()) { + done(s); + return; + } + } - VLOG(2) << "CopyCPUTensorToDevice empty tensor"; - done(Status::OK()); + Status status; + if (transfer_as_literal_) { + Tensor reshaped_cpu_tensor; + if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) { + done(errors::Internal( + "Tensor::CopyFrom failed when copying from CPU to XLA device")); + return; + } + status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); + if (status.ok()) { + xla_tensor->set_host_tensor(*cpu_tensor); + host_to_device_stream_->ThenDoHostCallback([this, done]() { + // We must not call the done closure directly from DoHostCallback + // to avoid a deadlock. If done() is the callback that ends an + // Executor's run, the Executor may call XlaDevice::Sync() inside the + // callback. This deadlocks, because XlaDevice::Sync() waits for all + // stream activity to complete. + thread_pool_->Schedule([done]() { done(Status::OK()); }); + }); + return; + } + } else { + se::DeviceMemoryBase dev_dst_ptr = + XlaTensor::DeviceMemoryFromTensor(*device_tensor); + host_to_device_stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); + // TODO(hpucha): Make this asynchronous. + Status block_status = host_to_device_stream_->BlockHostUntilDone(); + if (!block_status.ok()) { + status = xla::InternalError( + "Failed to complete data transfer on stream %p: %s", + host_to_device_stream_.get(), block_status.error_message().c_str()); + } + } + xla_tensor->set_host_tensor(*cpu_tensor); + + done(status); } void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, @@ -167,84 +217,129 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, Device* device, Tensor* cpu_tensor, StatusCallback done) { - if (device_tensor->NumElements() > 0) { - VLOG(2) << "CopyDeviceTensorToCPU " - << reinterpret_cast( - device_tensor->tensor_data().data()) - << " " - << reinterpret_cast(cpu_tensor->tensor_data().data()) - << " " << device_tensor->NumElements() << " " - << cpu_tensor->shape().DebugString() << " " - << device_tensor->shape().DebugString(); - - const int64 total_bytes = cpu_tensor->TotalBytes(); - se::DeviceMemoryBase dev_src_ptr = - XlaTensor::DeviceMemoryFromTensor(*device_tensor); - void* dst_ptr = DMAHelper::base(cpu_tensor); - - Status status; - if (transfer_as_literal_) { - status = TransferLiteralFromDevice(cpu_tensor, *device_tensor); - } else { - stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); - // TODO(hpucha): Make this asynchronous. - Status block_status = stream_->BlockHostUntilDone(); - if (!block_status.ok()) { - status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", stream_, - block_status.error_message().c_str()); - } - } + if (device_tensor->NumElements() == 0) { + VLOG(2) << "CopyDeviceTensorToCPU empty tensor"; + done(Status::OK()); + return; + } + VLOG(2) << "CopyDeviceTensorToCPU " + << reinterpret_cast(device_tensor->tensor_data().data()) + << " " + << reinterpret_cast(cpu_tensor->tensor_data().data()) + << " " << device_tensor->NumElements() << " " + << cpu_tensor->shape().DebugString() << " " + << device_tensor->shape().DebugString(); + + const int64 total_bytes = cpu_tensor->TotalBytes(); + se::DeviceMemoryBase dev_src_ptr = + XlaTensor::DeviceMemoryFromTensor(*device_tensor); + void* dst_ptr = DMAHelper::base(cpu_tensor); + XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); + + if (se::Event* event = + xla_tensor->GetDefinitionEvent(device_to_host_stream_.get())) { + device_to_host_stream_->ThenWaitFor(event); + xla_tensor->SetDefinedOn(device_to_host_stream_.get()); + } - done(status); + Status status; + if (transfer_as_literal_) { + TransferLiteralFromDevice(cpu_tensor, *device_tensor, done); return; + } else { + device_to_host_stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); + // TODO(hpucha): Make this asynchronous. + Status block_status = device_to_host_stream_->BlockHostUntilDone(); + if (!block_status.ok()) { + status = xla::InternalError( + "Failed to complete data transfer on stream %p: %s", stream_.get(), + block_status.error_message().c_str()); + } } - VLOG(2) << "CopyDeviceTensorToCPU empty tensor"; - done(Status::OK()); + done(status); } void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done) { - // TODO(phawkins): replace this code with an asynchronous implementation. - auto body = [&]() { + VLOG(2) << "CopyDeviceTensorToDevice " + << reinterpret_cast(src_tensor.tensor_data().data()) + << " " + << reinterpret_cast(dst_tensor->tensor_data().data()); + // Perform memory allocation now, and enqueue the device-to-device transfer. + Status status = [&]() -> Status { if (src_tensor.NumElements() == 0) { return Status::OK(); } + // TODO(jmolloy): We co-opt the device_to_host stream for device to device + // transfers; perhaps we should have a dedicated device to device stream? or + // one per device? + auto device_to_device_stream = stream_; XlaTensor* xla_src = XlaTensor::FromTensor(&src_tensor); XlaTensor* xla_dst = XlaTensor::FromTensor(dst_tensor); CHECK(xla_src && xla_dst) << "Missing destination tensor for device-to-device copy"; if (!xla_dst->has_shaped_buffer()) { - TensorShape shape = - shape_representation_fn_(src_tensor.shape(), src_tensor.dtype()); + TF_ASSIGN_OR_RETURN( + TensorShape shape, + shape_representation_fn_(src_tensor.shape(), src_tensor.dtype())); TF_RETURN_IF_ERROR( xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_, stream_->parent()->device_ordinal())); + if (stream_ != device_to_device_stream) { + // Initially wait for the compute stream so that memory allocations are + // synchronized. + device_to_device_stream->ThenWaitFor(stream_.get()); + } + } + + if (se::Event* event = + xla_src->GetDefinitionEvent(device_to_device_stream.get())) { + device_to_device_stream->ThenWaitFor(event); + xla_src->SetDefinedOn(device_to_device_stream.get()); + } + + auto from_iter = xla_src->shaped_buffer().buffers().begin(); + auto to_iter = xla_dst->shaped_buffer().buffers().begin(); + for (auto end_iter = xla_src->shaped_buffer().buffers().end(); + from_iter != end_iter; ++from_iter, ++to_iter) { + device_to_device_stream->ThenMemcpyD2D( + &to_iter->second, from_iter->second, to_iter->second.size()); + } + + if (UseMultipleStreams()) { + auto event = std::make_shared(stream_->parent()); + TF_RET_CHECK(event->Init()) << "Event failed to initialize"; + device_to_device_stream->ThenRecordEvent(event.get()); + xla_dst->SetDefinedOn(device_to_device_stream.get(), std::move(event)); } - TF_RETURN_IF_ERROR( - xla_dst->shaped_buffer().buffers().ForEachMutableElementWithStatus( - [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { - const se::DeviceMemoryBase& from_buffer = - xla_src->shaped_buffer().buffers().element(index); - CHECK_EQ(buffer->size(), from_buffer.size()); - if (!stream_->parent()->SynchronousMemcpy(buffer, from_buffer, - buffer->size())) { - return errors::Internal("Device to device memcpy failed"); - } - return Status::OK(); - })); return Status::OK(); - }; - done(body()); + }(); + if (!status.ok()) { + return done(status); + } else { + stream_->ThenDoHostCallback([this, done]() { + // We must not call the done closure directly from DoHostCallback to avoid + // a deadlock. If done() is the callback that ends an Executor's run, the + // Executor may call XlaDevice::Sync() inside the callback. This + // deadlocks, because XlaDevice::Sync() waits for all stream activity to + // complete. + thread_pool_->Schedule([done]() { done(Status::OK()); }); + }); + } } XlaDeviceContext::XlaDeviceContext( - se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn) - : manager_(stream, client, transfer_as_literal, - std::move(shape_representation_fn)) {} + std::shared_ptr compute_stream, + std::shared_ptr host_to_device_stream, + std::shared_ptr device_to_host_stream, xla::LocalClient* client, + bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool) + : manager_(std::move(compute_stream), std::move(host_to_device_stream), + std::move(device_to_host_stream), client, transfer_as_literal, + std::move(shape_representation_fn), thread_pool) {} void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index ee346e5653bbf9f393df202572c2150b4989506f..2e7445340cbaf788bfd06260f4376596895231c1 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -47,8 +47,12 @@ class XlaDeviceAllocator : public Allocator { class XlaTransferManager { public: explicit XlaTransferManager( - se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn); + std::shared_ptr compute_stream, + std::shared_ptr host_to_device_stream, + std::shared_ptr device_to_host_stream, + xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; @@ -59,17 +63,25 @@ class XlaTransferManager { void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done); - se::Stream* stream() const { return stream_; } + se::Stream* stream() const { return stream_.get(); } private: Status TransferLiteralToDevice(const Tensor& host_tensor, Tensor* device_tensor) const; - Status TransferLiteralFromDevice(Tensor* host_tensor, - const Tensor& device_tensor) const; - - // Stream obtained from a Device, used to transfer tensors between - // CPU and device. - se::Stream* stream_; + void TransferLiteralFromDevice(Tensor* host_tensor, + const Tensor& device_tensor, + const StatusCallback& done) const; + bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; } + + // The main compute stream of the device, used to synchronize the transfer + // streams if they are set. + std::shared_ptr stream_; + // The stream to use for transferring data from host to device. Can be + // idential to stream_, but must not be nullptr. + std::shared_ptr host_to_device_stream_; + // The stream to use for transferring data from device to host. Can be + // idential to stream_, but must not be nullptr. + std::shared_ptr device_to_host_stream_; // For the underlying memory allocator and XLA's TransferManager. xla::LocalClient* client_; // Transfer manager, for marshalling data to and from the device. @@ -77,6 +89,9 @@ class XlaTransferManager { // True if we must use XLA's TransferManager for correct device transfers. const bool transfer_as_literal_; XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + + // Thread pool used for running closures + thread::ThreadPool* thread_pool_; }; // DeviceContext for operators assigned to XlaDevice devices. The @@ -85,8 +100,12 @@ class XlaTransferManager { class XlaDeviceContext : public DeviceContext { public: explicit XlaDeviceContext( - se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn); + std::shared_ptr compute_stream, + std::shared_ptr host_to_device_stream, + std::shared_ptr device_to_host_stream, + xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 11e45d2823da2b623bd3cd45f7147686b05fdb2f..da3e329247e825d4a33a53dc310899d6ba6ce9cf 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -23,9 +23,15 @@ limitations under the License. #include "tensorflow/core/kernels/cast_op.h" #include "tensorflow/core/kernels/constant_op.h" #include "tensorflow/core/kernels/control_flow_ops.h" +#include "tensorflow/core/kernels/data/generator_dataset_op.h" +#include "tensorflow/core/kernels/data/iterator_ops.h" +#include "tensorflow/core/kernels/data/prefetch_dataset_op.h" +#include "tensorflow/core/kernels/fifo_queue.h" +#include "tensorflow/core/kernels/function_ops.h" #include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_op.h" #include "tensorflow/core/kernels/no_op.h" +#include "tensorflow/core/kernels/queue_op.h" #include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/sendrecv_ops.h" #include "tensorflow/core/kernels/shape_ops.h" @@ -75,9 +81,7 @@ class XlaAssignVariableOp : public AsyncOpKernel { ConstantOp); \ REGISTER_KERNEL_BUILDER( \ Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("IdentityN").Device(DEVICE).TypeConstraint("T", TYPES), \ - IdentityNOp); \ + REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE), IdentityNOp); \ REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp); \ REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \ PlaceholderOp); \ @@ -88,6 +92,9 @@ class XlaAssignVariableOp : public AsyncOpKernel { REGISTER_KERNEL_BUILDER( \ Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \ ReadVariableOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("DestroyResourceOp").Device(DEVICE).HostMemory("resource"), \ + DestroyResourceOp); \ REGISTER_KERNEL_BUILDER(Name("Shape") \ .Device(DEVICE) \ .HostMemory("output") \ @@ -145,7 +152,94 @@ class XlaAssignVariableOp : public AsyncOpKernel { .Device(DEVICE) \ .HostMemory("input") \ .HostMemory("output"), \ - LoopCondOp); + LoopCondOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueEnqueueV2").Device(DEVICE).HostMemory("handle"), EnqueueOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueDequeueV2").Device(DEVICE).HostMemory("handle"), DequeueOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueCloseV2").Device(DEVICE).HostMemory("handle"), QueueCloseOp); \ + REGISTER_KERNEL_BUILDER(Name("QueueSizeV2") \ + .Device(DEVICE) \ + .HostMemory("size") \ + .HostMemory("handle"), \ + QueueSizeOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueIsClosedV2").Device(DEVICE).HostMemory("handle"), \ + QueueIsClosedOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name(kArgOp).Device(DEVICE).HostMemory("output").TypeConstraint("T", \ + TYPES), \ + ArgOp); \ + REGISTER_KERNEL_BUILDER(Name(kArgOp) \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ArgOp); \ + \ + REGISTER_KERNEL_BUILDER(Name(kRetOp) \ + .Device(DEVICE) \ + .TypeConstraint("T", TYPES) \ + .HostMemory("input"), \ + RetvalOp); \ + REGISTER_KERNEL_BUILDER(Name(kRetOp) \ + .Device(DEVICE) \ + .TypeConstraint("T") \ + .HostMemory("input"), \ + RetvalOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("RemoteCall").Device(DEVICE).HostMemory("target"), RemoteCallOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"), \ + GeneratorDatasetOp); \ + REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") \ + .Device(DEVICE) \ + .HostMemory("buffer_size") \ + .HostMemory("input_dataset") \ + .HostMemory("handle"), \ + PrefetchDatasetOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE), \ + IteratorHandleOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("MakeIterator").Device(DEVICE).HostMemory("dataset"), \ + MakeIteratorOp); \ + REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \ + AnonymousIteratorHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \ + IteratorGetNextOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \ + .Device(DEVICE) \ + .HostMemory("string_handle"), \ + IteratorToStringHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2") \ + .Device(DEVICE) \ + .HostMemory("string_handle"), \ + IteratorFromStringHandleOp); \ + REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ArgOp); \ + REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kRetOp) \ + .Device(DEVICE) \ + .TypeConstraint("T") \ + .HostMemory("input"), \ + RetvalOp); + +// TODO(phawkins): currently we do not register the QueueEnqueueMany, +// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read +// and write the tensors they access in order to concatenate them into a batch. +// We would need either to call out to an XLA computation to perform the +// concatenation, or we would need to refactor those kernels so the splitting +// or merging is done in a separate operator that can be compiled. } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc index 74257b09a808a39454eace3b1a9bf57a2e071360..4b499b161371ecece14447b29fbf809b6e8857db 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/union_find.h" @@ -146,6 +147,9 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, TF_RETURN_IF_ERROR( ImportGraphDef(options, item.graph, &graph, &shape_refiner)); + std::unique_ptr deadness; + TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness)); + // Collect nodes that can be fused via XLA, while ignoring those that // explicitly ask for XLA: (*) nodes that are marked to be compiled // explicitly. (*) nodes assigned to XLA device. @@ -185,6 +189,14 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, continue; } + // If inputs to `node` can have conflicting deadness (i.e. some are alive + // and some are dead) then don't compile it. XLA cannot represent the + // deadness semantics of these nodes correctly and auto-clustering these + // nodes can cause deadness to propagate to nodes that should be live. + if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) { + continue; + } + compilation_candidates.insert(node); } diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index c0d86a28c7698c302e28bab972bb2f847cc00ca4..ef4466f0056ea98adc1ae6774105466af0d14293 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -49,6 +49,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix, registration, /*transfer_as_literal=*/false, + /*use_multiple_streams=*/false, /*shape_representation_fn=*/{}, /*padded_shape_fn=*/{}, &device); if (!status.ok()) { @@ -58,7 +59,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, } // TODO(b/78468222): Uncomment after fixing this bug - // status = device->CreateAndSetGpuDeviceInfo(); + // status = device->UseGpuDeviceInfo(); // if (!status.ok()) { // errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT, // " device"); diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 661187f4a873b03b8d013aa74cb6b6315bb4e2eb..45745596749207189c60ee1e3dcf19b6ecb7eb5b 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -52,6 +52,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices( DEVICE_INTERPRETER_XLA_JIT, options, name_prefix, registration, /*transfer_as_literal=*/false, + /*use_multiple_streams=*/false, /*shape_representation_fn=*/{}, /*padded_shape_fn=*/{}, &device)); devices->push_back(device.release()); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index d0c7a9365125708b2af43f87c7617d8d84050a61..4efbb2d5d7cf09d9cf1e35c8cf5403e7e0dfe733 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_launch_util.h" +#include + #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -64,11 +66,13 @@ xla::StatusOr XlaAllocator::Allocate( int device_ordinal, uint64 size, bool retry_on_failure) { AllocationAttributes attrs; attrs.no_retry_on_failure = !retry_on_failure; - void* data = - wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs); - if (data == nullptr) { - return errors::ResourceExhausted("Out of memory while trying to allocate ", - size, " bytes."); + void* data = nullptr; + if (size != 0) { + data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs); + if (data == nullptr) { + return errors::ResourceExhausted( + "Out of memory while trying to allocate ", size, " bytes."); + } } return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size), device_ordinal, this); @@ -115,14 +119,22 @@ using internal::ExtractSubShapedBuffer; XlaComputationLaunchContext::XlaComputationLaunchContext( xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, - bool allocate_xla_tensors) + bool allocate_xla_tensors, bool use_multiple_streams) : client_(client), xla_allocator_(xla_allocator), - allocate_xla_tensors_(allocate_xla_tensors) {} + allocate_xla_tensors_(allocate_xla_tensors), + use_multiple_streams_(use_multiple_streams) { + if (use_multiple_streams_) { + CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must " + "be allocating XLA tensors!"; + } +} void XlaComputationLaunchContext::PopulateInputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, const std::map& variables) { + se::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; // Build ShapedBuffers that point directly to the Tensor buffers. arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1); arg_buffers_.resize(kernel->xla_input_shapes.size()); @@ -140,6 +152,16 @@ void XlaComputationLaunchContext::PopulateInputs( t = &(ctx->input(arg_num)); } + if (use_multiple_streams_) { + CHECK(stream) << "Must have a stream available when using XLA tensors!"; + XlaTensor* xla_tensor = XlaTensor::FromTensor(t); + CHECK(xla_tensor); + if (se::Event* event = xla_tensor->GetDefinitionEvent(stream)) { + stream->ThenWaitFor(event); + xla_tensor->SetDefinedOn(stream); + } + } + const xla::Shape on_device_shape = client_->backend().transfer_manager()->HostShapeToDeviceShape(shape); if (xla::ShapeUtil::IsTuple(on_device_shape)) { @@ -162,7 +184,7 @@ void XlaComputationLaunchContext::PopulateInputs( } } -void XlaComputationLaunchContext::PopulateOutputs( +Status XlaComputationLaunchContext::PopulateOutputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, ScopedShapedBuffer output) { se::Stream* stream = @@ -176,6 +198,30 @@ void XlaComputationLaunchContext::PopulateOutputs( } CHECK_EQ(ctx->num_outputs(), kernel->outputs.size()); + // If the on-host-shape isn't a tuple, create a new single-element tuple + // buffer with a nullptr root index table. This allows the code below to treat + // output as a tuple unconditionally. + if (!xla::ShapeUtil::IsTuple(output.on_host_shape())) { + ShapedBuffer nontuple_buffer = output.release(); + ShapedBuffer buffer( + xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}), + xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}), + output.platform(), output.device_ordinal()); + buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(), + /*source_base_index=*/{}, + /*target_base_index=*/{0}); + output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator()); + } + + std::shared_ptr definition_event; + if (use_multiple_streams_) { + definition_event = std::make_shared(stream->parent()); + if (!definition_event->Init()) { + return errors::Internal("Failed to initialize tensor definition event."); + } + stream->ThenRecordEvent(definition_event.get()); + } + // Copy XLA results to the OpOutputList. int output_num = 0; for (int i = 0; i < ctx->num_outputs(); ++i) { @@ -193,12 +239,13 @@ void XlaComputationLaunchContext::PopulateOutputs( // reallocate the device buffer later. VLOG(1) << "Constant output tensor on device"; - OP_REQUIRES_OK( - ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); + TF_RETURN_IF_ERROR( + ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); Device* device = dynamic_cast(ctx->device()); - OP_REQUIRES(ctx, device != nullptr, - errors::Internal("DeviceBase was not a Device.")); + if (device == nullptr) { + return errors::Internal("DeviceBase was not a Device."); + } ctx->op_device_context()->CopyCPUTensorToDevice( &const_tensor, device, output_tensor, [&](Status status) { TF_CHECK_OK(status); }); @@ -228,11 +275,19 @@ void XlaComputationLaunchContext::PopulateOutputs( se::DeviceMemoryBase buffer = output.buffer({output_num}); if (allocate_xla_tensors_) { Tensor* output_tensor; - OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor)); + TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); - CHECK(xla_tensor); - xla_tensor->set_shaped_buffer(ScopedShapedBuffer( - ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + if (xla_tensor) { + xla_tensor->set_shaped_buffer(ScopedShapedBuffer( + ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + if (use_multiple_streams_) { + xla_tensor->SetDefinedOn(stream, definition_event); + } + } else { + // xla_tensor wasn't valid, which must mean this is a zero-element + // tensor. + CHECK_EQ(output_tensor->TotalBytes(), 0); + } } else { Tensor output_tensor = XlaTensorBuffer::MakeTensor( ctx->expected_output_dtype(i), shape, buffer, allocator); @@ -252,36 +307,40 @@ void XlaComputationLaunchContext::PopulateOutputs( for (int i = 0; i < kernel->resource_updates.size(); ++i) { Allocator* allocator = ctx->device()->GetAllocator({}); const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; - OP_REQUIRES(ctx, - write.input_index >= 0 && write.input_index < ctx->num_inputs(), - errors::Internal("Invalid input index for variable write.")); + if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) { + return errors::Internal("Invalid input index for variable write."); + } se::DeviceMemoryBase buffer = output.buffer({output_num}); Var* variable = nullptr; // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, // not a Tensor. - OP_REQUIRES_OK(ctx, LookupOrCreateResource( - ctx, HandleFromInput(ctx, write.input_index), - &variable, [this, ctx, &write](Var** ptr) { - *ptr = new Var(write.type); - return Status::OK(); - })); + TF_RETURN_IF_ERROR(LookupOrCreateResource( + ctx, HandleFromInput(ctx, write.input_index), &variable, + [&write](Var** ptr) { + *ptr = new Var(write.type); + return Status::OK(); + })); core::ScopedUnref s(variable); mutex_lock ml(*variable->mu()); - OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type, - errors::Internal("Mismatched type in variable write")); + if (variable->tensor()->dtype() != write.type) { + return errors::Internal("Mismatched type in variable write"); + } if (allocate_xla_tensors_) { Tensor output_tensor; - OP_REQUIRES_OK( - ctx, ctx->allocate_temp(write.type, write.shape, &output_tensor)); + TF_RETURN_IF_ERROR( + ctx->allocate_temp(write.type, write.shape, &output_tensor)); XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); CHECK(xla_tensor); xla_tensor->set_shaped_buffer( ExtractSubShapedBuffer(&output, output_num, xla_allocator_)); + if (use_multiple_streams_) { + xla_tensor->SetDefinedOn(stream, definition_event); + } *variable->tensor() = output_tensor; } else { Tensor output_tensor = XlaTensorBuffer::MakeTensor( @@ -291,6 +350,7 @@ void XlaComputationLaunchContext::PopulateOutputs( } ++output_num; } + return Status::OK(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 4390701ccbd0bc3971413ddcd917c11019990087..4232f514b3b48681bf510ee568f916f5f4ebe882 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -76,9 +76,15 @@ class XlaComputationLaunchContext { // Create a new launch context. 'allocate_xla_tensors' is true if allocated // output tensors and variables are always XlaTensors. If false they are // assumed to be "normal" device pointers. + // If 'use_multiple_streams' is true, tensors may be defined and used on + // multiple streams and so se::Events must be defined and waited for. If + // 'use_multiple_streams' is true, 'allocate_xla_tensors' must also be true + // because we track inter-stream dependencies through events inside XlaTensor + // objects. XlaComputationLaunchContext(xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, - bool allocate_xla_tensors); + bool allocate_xla_tensors, + bool use_multiple_streams); // Add all inputs within `ctx` as XLA arguments (returned by arguments()). // `variables` is a map from TensorFlow argument number to resource variable. @@ -87,9 +93,9 @@ class XlaComputationLaunchContext { const std::map& variables); // Given the XLA output in `output`, populate all outputs of `ctx`. - void PopulateOutputs(OpKernelContext* ctx, - const XlaCompiler::CompilationResult* kernel, - xla::ScopedShapedBuffer output); + Status PopulateOutputs(OpKernelContext* ctx, + const XlaCompiler::CompilationResult* kernel, + xla::ScopedShapedBuffer output); // Return the argument list. Only valid after PopulateInputs() has been // called. @@ -99,6 +105,7 @@ class XlaComputationLaunchContext { xla::LocalClient* client_; xla::DeviceMemoryAllocator* xla_allocator_; bool allocate_xla_tensors_; + bool use_multiple_streams_; std::vector> arg_buffers_; std::vector arg_ptrs_; }; @@ -115,7 +122,11 @@ class XlaTensorBuffer : public TensorBuffer { data_ = const_cast(ptr); } - ~XlaTensorBuffer() override { allocator_->DeallocateRaw(data_); } + ~XlaTensorBuffer() override { + if (data_) { + allocator_->DeallocateRaw(data_); + } + } void* data() const override { return data_; } size_t size() const override { return expected_size_; } diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index 3c44c4ae6df7f3e2d60d8933561c0c71888e8c3f..92ba7de1b7d32fcf693cd12a380d7a1e0d861d71 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -73,6 +73,35 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, return Status::OK(); } +se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) { + mutex_lock lock(mu_); + if (!definition_event_) { + return nullptr; + } + + // The set of defined streams is expected to be very small indeed (usually + // 1-2), so a simple linear scan should be fast enough. + if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(), + stream) != streams_defined_on_.end()) { + // stream is in streams_defined_on_; it doesn't need to be waited on. + return nullptr; + } + + return definition_event_.get(); +} + +void XlaTensor::SetDefinedOn(se::Stream* stream, + std::shared_ptr event) { + mutex_lock lock(mu_); + definition_event_ = std::move(event); + streams_defined_on_ = {stream}; +} + +void XlaTensor::SetDefinedOn(se::Stream* stream) { + mutex_lock lock(mu_); + streams_defined_on_.push_back(stream); +} + // The pointer tag, OR-ed into the XlaTensor's address to distinguish it from // device-side tensors, which are either CPU or GPU memory pointers. This works // because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits. diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index c54001a999998f45c0cdacd752ca4036f0792857..8d36d0fa0a8230bcd1b16cc67de104e09358144f 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ #define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ +#include + #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/core/framework/allocator.h" @@ -85,6 +87,24 @@ class XlaTensor { host_tensor_.reset(new Tensor(tensor)); } + // If the tensor's content is not yet defined on 'stream', and there exists an + // se::Event declaring when the tensor's content is defined, return it. + // Otherwise, return nullptr. If this function returns nullptr then the + // tensor's content can be read on 'stream' without additional + // synchronization. + se::Event* GetDefinitionEvent(se::Stream* stream); + + // Assert that the tensor's content is defined on 'stream' by the time 'event' + // triggers. + void SetDefinedOn(se::Stream* stream, std::shared_ptr event); + + // Assert that the tensor's content is defined on 'stream'. This version does + // not provide an event, and must be called *after* SetDefinedOn(Stream, + // Event). This call can be read as an assertion that the definition event has + // been waited on by 'stream', so further calls to GetDefinitionEvent(stream) + // do not need to also wait on the event. + void SetDefinedOn(se::Stream* stream); + // Convert from a raw pointer to an XlaTensor, removing the pointer tag. static XlaTensor* FromOpaquePointer(void* ptr); // Convert to a raw pointer from an XlaTensor, adding the pointer tag. @@ -95,6 +115,14 @@ class XlaTensor { std::unique_ptr shaped_buffer_; // An optional host tensor value. std::unique_ptr host_tensor_; + // An optional event that is triggered when the tensor's content has been + // defined. If this event is nullptr, it is assumed that the tensor's content + // is always defined. + std::shared_ptr definition_event_; + // A list of all streams for which the tensor's content is defined for any + // newly enqueued command. + gtl::InlinedVector streams_defined_on_ GUARDED_BY(mu_); + mutex mu_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index e6c92f9720e1285617280f60d1c5fea443c5ebef..ae98b3f0f9d5dac66b9716ad84a9f0371511e9b6 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -51,6 +51,38 @@ py_library( ], ) +py_library( + name = "test_utils", + testonly = 1, + srcs = ["test_utils.py"], + srcs_version = "PY2AND3", + deps = [ + "//third_party/py/numpy", + ], +) + +py_test( + name = "xla_test_test", + size = "small", + srcs = ["xla_test_test.py"], + deps = [ + ":xla_test", + ], +) + +tf_xla_py_test( + name = "adadelta_test", + size = "medium", + srcs = ["adadelta_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "adagrad_test", size = "small", @@ -65,6 +97,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "adagrad_da_test", + size = "small", + srcs = ["adagrad_da_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "adam_test", size = "small", @@ -79,6 +124,48 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "adamax_test", + size = "small", + srcs = ["adamax_test.py"], + deps = [ + ":xla_test", + "//tensorflow/contrib/opt:opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:training", + ], +) + +tf_xla_py_test( + name = "addsign_test", + size = "small", + srcs = ["addsign_test.py"], + deps = [ + ":xla_test", + "//tensorflow/contrib/opt:opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:training", + ], +) + +tf_xla_py_test( + name = "powersign_test", + size = "small", + srcs = ["powersign_test.py"], + deps = [ + ":xla_test", + "//tensorflow/contrib/opt:opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "argminmax_test", size = "small", @@ -148,7 +235,7 @@ tf_xla_py_test( tf_xla_py_test( name = "cholesky_op_test", - size = "small", + size = "medium", srcs = ["cholesky_op_test.py"], tags = ["optonly"], deps = [ @@ -238,6 +325,7 @@ tf_xla_py_test( srcs = ["conv2d_test.py"], shard_count = 10, deps = [ + ":test_utils", ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:framework", @@ -245,6 +333,7 @@ tf_xla_py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", + "@absl_py//absl/testing:parameterized", ], ) @@ -329,7 +418,7 @@ tf_xla_py_test( tf_xla_py_test( name = "eager_test", - size = "small", + size = "large", srcs = ["eager_test.py"], disabled_backends = [ # TODO(b/78199195) Support XLA CPU devices in eager runtime @@ -350,6 +439,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "fifo_queue_test", + size = "medium", + srcs = ["fifo_queue_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "fft_test", size = "medium", @@ -535,13 +638,60 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "proximal_adagrad_test", + size = "medium", + srcs = ["proximal_adagrad_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:training", + ], +) + +tf_xla_py_test( + name = "proximal_gradient_descent_test", + size = "medium", + srcs = ["proximal_gradient_descent_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:training", + ], +) + +tf_xla_py_test( + name = "qr_op_test", + size = "medium", + srcs = ["qr_op_test.py"], + disabled_backends = [ + # Test is very slow on CPU. + "cpu", + "cpu_ondemand", + ], + shard_count = 5, + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + "@absl_py//absl/testing:parameterized", + ], +) + tf_xla_py_test( name = "random_ops_test", size = "small", srcs = ["random_ops_test.py"], - # TODO(b/31361304): enable RNG ops on GPU when parallelized. disabled_backends = [ - "gpu", + "cpu_ondemand", ], deps = [ ":xla_test", @@ -664,6 +814,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "sparse_to_dense_op_test", + size = "small", + srcs = ["sparse_to_dense_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "//tensorflow/python:sparse_ops", + ], +) + tf_xla_py_test( name = "stack_ops_test", size = "small", @@ -743,9 +906,10 @@ tf_xla_py_test( tf_xla_py_test( name = "fused_batchnorm_test", - size = "small", + size = "medium", srcs = ["fused_batchnorm_test.py"], deps = [ + ":test_utils", ":xla_test", "//tensorflow/python:framework", "//tensorflow/python:math_ops", @@ -755,6 +919,7 @@ tf_xla_py_test( "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", "//tensorflow/python:training", + "@absl_py//absl/testing:parameterized", ], ) @@ -830,6 +995,21 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "sort_ops_test", + size = "medium", + srcs = ["sort_ops_test.py"], + shard_count = 5, + # Times out in fastbuild mode. + tags = ["optonly"], + deps = [ + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + ], +) + tf_xla_py_test( name = "xla_device_test", size = "small", diff --git a/tensorflow/compiler/tests/adadelta_test.py b/tensorflow/compiler/tests/adadelta_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3e3c09c66e72c4de141b64cea3c4693fabb7b2a2 --- /dev/null +++ b/tensorflow/compiler/tests/adadelta_test.py @@ -0,0 +1,134 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Adadelta Optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import adadelta + + +class AdadeltaOptimizerTest(xla_test.XLATestCase): + + def testBasic(self): + num_updates = 4 # number of ADADELTA steps to perform + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + for grad in [0.2, 0.1, 0.01]: + for lr in [1.0, 0.5, 0.1]: + var0_init = [1.0, 2.0] + var1_init = [3.0, 4.0] + var0 = resource_variable_ops.ResourceVariable( + var0_init, dtype=dtype) + var1 = resource_variable_ops.ResourceVariable( + var1_init, dtype=dtype) + + grads = constant_op.constant([grad, grad], dtype=dtype) + + accum = 0.0 + accum_update = 0.0 + + # ADADELTA gradient optimizer + rho = 0.95 + epsilon = 1e-8 + adadelta_opt = adadelta.AdadeltaOptimizer( + learning_rate=lr, rho=rho, epsilon=epsilon) + adadelta_update = adadelta_opt.apply_gradients( + zip([grads, grads], [var0, var1])) + self.evaluate(variables.global_variables_initializer()) + opt_vars = adadelta_opt.variables() + self.assertStartsWith(opt_vars[0].name, var0._shared_name) + self.assertStartsWith(opt_vars[1].name, var0._shared_name) + self.assertStartsWith(opt_vars[2].name, var1._shared_name) + self.assertStartsWith(opt_vars[3].name, var1._shared_name) + self.assertEqual(4, len(opt_vars)) + # Assign slots + slot = [None] * 2 + slot_update = [None] * 2 + self.assertEqual(["accum", "accum_update"], + adadelta_opt.get_slot_names()) + slot[0] = adadelta_opt.get_slot(var0, "accum") + self.assertEquals(slot[0].get_shape(), var0.get_shape()) + self.assertFalse(slot[0] in variables.trainable_variables()) + + slot_update[0] = adadelta_opt.get_slot(var0, "accum_update") + self.assertEquals(slot_update[0].get_shape(), var0.get_shape()) + self.assertFalse(slot_update[0] in variables.trainable_variables()) + + slot[1] = adadelta_opt.get_slot(var1, "accum") + self.assertEquals(slot[1].get_shape(), var1.get_shape()) + self.assertFalse(slot[1] in variables.trainable_variables()) + + slot_update[1] = adadelta_opt.get_slot(var1, "accum_update") + self.assertEquals(slot_update[1].get_shape(), var1.get_shape()) + self.assertFalse(slot_update[1] in variables.trainable_variables()) + + # Fetch params to validate initial values + self.assertAllClose(var0_init, self.evaluate(var0)) + self.assertAllClose(var1_init, self.evaluate(var1)) + + update = [None] * num_updates + tot_update = 0 + for step in range(num_updates): + # Run adadelta update for comparison + self.evaluate(adadelta_update) + + # Perform initial update without previous accum values + accum = accum * rho + (grad**2) * (1 - rho) + update[step] = ( + np.sqrt(accum_update + epsilon) * + (1. / np.sqrt(accum + epsilon)) * grad) + accum_update = ( + accum_update * rho + (update[step]**2) * (1.0 - rho)) + tot_update += update[step] * lr + + # Check that the accumulators have been updated + for slot_idx in range(2): + self.assertAllCloseAccordingToType( + np.array([accum, accum], dtype=dtype), + self.evaluate(slot[slot_idx]), + rtol=1e-5) + + self.assertAllCloseAccordingToType( + np.array([accum_update, accum_update], dtype=dtype), + self.evaluate(slot_update[slot_idx]), + rtol=1e-5) + + # Check that the parameters have been updated + self.assertAllCloseAccordingToType( + np.array( + [var0_init[0] - tot_update, var0_init[1] - tot_update], + dtype=dtype), + self.evaluate(var0), + rtol=1e-5) + + self.assertAllCloseAccordingToType( + np.array( + [var1_init[0] - tot_update, var1_init[1] - tot_update], + dtype=dtype), + self.evaluate(var1), + rtol=1e-5) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/adagrad_da_test.py b/tensorflow/compiler/tests/adagrad_da_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dc1625793aa44b96d3b96e175237caf96e7d7e74 --- /dev/null +++ b/tensorflow/compiler/tests/adagrad_da_test.py @@ -0,0 +1,165 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for AdagradDA optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import adagrad_da + + +class AdagradDAOptimizerTest(xla_test.XLATestCase): + + def testAdagradDAWithoutRegularizationBasic1(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + global_step = resource_variable_ops.ResourceVariable( + 0, dtype=dtypes.int64) + var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.02], dtype=dtype) + opt = adagrad_da.AdagradDAOptimizer( + 3.0, + global_step, + initial_gradient_squared_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0) + update = opt.apply_gradients( + zip([grads0, grads1], [var0, var1]), global_step=global_step) + variables.global_variables_initializer().run() + + self.assertAllClose([0.0, 0.0], var0.eval()) + self.assertAllClose([0.0, 0.0], var1.eval()) + + # Run a step of AdagradDA + update.run() + + # Let g to be gradient accumulator, gg to be gradient squared + # accumulator, T be the global step, lr is the learning rate, and k the + # initial gradient squared accumulator value. + # w = \dfrac{sign(-g)*lr*|g - l1*T|_{+}}{l2*T*lr + \sqrt{k+gg})} + # For -0.1*3.0*(0.1 - 0)/(0 + sqrt(0.1 + 0.1*0.1)) = -0.904534 + # similarly for others. + self.assertAllCloseAccordingToType( + np.array([-0.904534, -1.603567]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([-0.094821, -0.189358]), var1.eval()) + + def testAdagradDAwithoutRegularizationBasic2(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + global_step = resource_variable_ops.ResourceVariable( + 0, dtype=dtypes.int64) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.02], dtype=dtype) + + opt = adagrad_da.AdagradDAOptimizer( + 3.0, + global_step, + initial_gradient_squared_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0) + update = opt.apply_gradients( + zip([grads0, grads1], [var0, var1]), global_step=global_step) + variables.global_variables_initializer().run() + + self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) + self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + + # Run a step of AdagradDA + update.run() + + self.assertAllCloseAccordingToType( + np.array([-0.904534, -1.603567]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([-0.094821, -0.189358]), var1.eval()) + + def testAdagradDAWithL1(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + global_step = resource_variable_ops.ResourceVariable( + 0, dtype=dtypes.int64) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.02], dtype=dtype) + + opt = adagrad_da.AdagradDAOptimizer( + 3.0, + global_step, + initial_gradient_squared_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=0.0) + update = opt.apply_gradients( + zip([grads0, grads1], [var0, var1]), global_step=global_step) + variables.global_variables_initializer().run() + + self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) + self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + + # Run a step of AdagradDA + update.run() + + self.assertAllCloseAccordingToType( + np.array([-0.895489, -1.59555]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([-0.085339, -0.17989]), var1.eval()) + + def testAdagradDAWithL1_L2(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + global_step = resource_variable_ops.ResourceVariable( + 0, dtype=dtypes.int64) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.02], dtype=dtype) + + opt = adagrad_da.AdagradDAOptimizer( + 3.0, + global_step, + initial_gradient_squared_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=2.0) + update = opt.apply_gradients( + zip([grads0, grads1], [var0, var1]), global_step=global_step) + variables.global_variables_initializer().run() + + self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) + self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + + # Run a step of AdagradDA + update.run() + + self.assertAllCloseAccordingToType( + np.array([-0.046907, -0.093659]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([-0.004275, -0.009023]), var1.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py index 9a93b3216404d8ed21fd6c57757bec1730c119b4..d775850a80e9f83f7b2c9f1cf8997dd50e229635 100644 --- a/tensorflow/compiler/tests/adagrad_test.py +++ b/tensorflow/compiler/tests/adagrad_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -28,7 +28,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import adagrad -class AdagradOptimizerTest(XLATestCase): +class AdagradOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index 3215dc36e5b2d517aa951db1b0d41188185ef93a..0d2e4d029636577adc74784d9a8b3494b94dc67d 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops @@ -48,10 +48,13 @@ def adam_update_numpy(param, return param_t, m_t, v_t -class AdamOptimizerTest(XLATestCase): +class AdamOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: + # TODO: test fails for float16 due to excessive precision requirements. + if dtype == np.float16: + continue with self.test_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) @@ -91,6 +94,9 @@ class AdamOptimizerTest(XLATestCase): def testTensorLearningRate(self): for dtype in self.float_types: + # TODO: test fails for float16 due to excessive precision requirements. + if dtype == np.float16: + continue with self.test_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) @@ -130,6 +136,9 @@ class AdamOptimizerTest(XLATestCase): def testSharing(self): for dtype in self.float_types: + # TODO: test fails for float16 due to excessive precision requirements. + if dtype == np.float16: + continue with self.test_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) diff --git a/tensorflow/compiler/tests/adamax_test.py b/tensorflow/compiler/tests/adamax_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c4fdbc5974319db9243eb2c323746cbaaea795f6 --- /dev/null +++ b/tensorflow/compiler/tests/adamax_test.py @@ -0,0 +1,139 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for AdaMax optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.contrib.opt.python.training import adamax +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def adamax_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + m_t = beta1 * m + (1 - beta1) * g_t + v_t = np.maximum(beta2 * v, np.abs(g_t)) + param_t = param - (alpha / (1 - beta1**t)) * (m_t / (v_t + epsilon)) + return param_t, m_t, v_t + + +class AdaMaxOptimizerTest(xla_test.XLATestCase): + + def testBasic(self): + for i, dtype in enumerate(self.float_types): + with self.test_session(), self.test_scope(): + variable_scope.get_variable_scope().set_use_resource(True) + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = adamax.AdaMaxOptimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + opt_variables = opt.variables() + beta1_power = opt._get_beta_accumulators() + self.assertTrue(beta1_power is not None) + self.assertIn(beta1_power, opt_variables) + + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power = opt._get_beta_accumulators() + + # Run 3 steps of AdaMax + for t in range(1, 4): + update.run() + + self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval()) + + var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval(), rtol=1e-2) + self.assertAllCloseAccordingToType(var1_np, var1.eval(), rtol=1e-2) + self.assertEqual("var0_%d/AdaMax:0" % (i,), + opt.get_slot(var=var0, name="m").name) + + def testTensorLearningRate(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + variable_scope.get_variable_scope().set_use_resource(True) + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = adamax.AdaMaxOptimizer(constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power = opt._get_beta_accumulators() + + # Run 3 steps of AdaMax + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + update.run() + + var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/addsign_test.py b/tensorflow/compiler/tests/addsign_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9ec5a964cbb4dd98d2ef2d0b684872292118800f --- /dev/null +++ b/tensorflow/compiler/tests/addsign_test.py @@ -0,0 +1,142 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for AddSign.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.contrib.opt.python.training import addsign +from tensorflow.contrib.opt.python.training import sign_decay +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def py_linear_decay_fn(decay_steps): + def linear_decay(step): + step = min(step, decay_steps) + return float(decay_steps - step) / decay_steps + return linear_decay + + +def addsign_update_numpy(params, + g_t, + m, + lr, + alpha=1.0, + beta=0.9, + py_sign_decay_fn=None, + t=None): + m_t = beta * m + (1 - beta) * g_t + if py_sign_decay_fn is None: + sign_decayed = 1.0 + else: + sign_decayed = py_sign_decay_fn(t-1) + multiplier = alpha + sign_decayed * np.sign(g_t) * np.sign(m_t) + params_t = params - lr * multiplier * g_t + return params_t, m_t + + +class AddSignTest(xla_test.XLATestCase): + + def _testDense(self, + learning_rate=0.1, + sign_decay_fn=None, + py_sign_decay_fn=None, + alpha=1.0, + beta=0.9): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + # Initialize variables for numpy implementation. + m0, m1 = 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + global_step = resource_variable_ops.ResourceVariable(0, trainable=False) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = addsign.AddSignOptimizer( + learning_rate=learning_rate, + alpha=alpha, + beta=beta, + sign_decay_fn=sign_decay_fn, + ) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 7 steps of AddSign + # first 4 steps with positive gradient + # last 3 steps with negative gradient (sign(gm) should be -1) + for t in range(1, 8): + if t < 5: + update.run() + else: + neg_update.run() + + var0_np, m0 = addsign_update_numpy( + var0_np, + grads0_np if t < 5 else -grads0_np, + m0, + learning_rate, + alpha=alpha, + beta=beta, + py_sign_decay_fn=py_sign_decay_fn, + t=t, + ) + var1_np, m1 = addsign_update_numpy( + var1_np, + grads1_np if t < 5 else -grads1_np, + m1, + learning_rate, + alpha=alpha, + beta=beta, + py_sign_decay_fn=py_sign_decay_fn, + t=t, + ) + + # Validate updated params + self.assertAllCloseAccordingToType( + var0_np, var0.eval(), half_rtol=1e-2) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testDense(self): + decay_steps = 10 + sign_decay_fn = sign_decay.get_linear_decay_fn(decay_steps) + py_sign_decay_fn = py_linear_decay_fn(decay_steps) + self._testDense() + self._testDense(learning_rate=0.01, alpha=0.1, beta=0.8) + self._testDense( + sign_decay_fn=sign_decay_fn, py_sign_decay_fn=py_sign_decay_fn) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 1e4dd32916c3a40282735fb8f75670b0e9ef0dc9..0aafda7fb4d710f154157ee352d6616e5aa8935f 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops @@ -32,7 +32,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest -class BinaryOpsTest(XLATestCase): +class BinaryOpsTest(xla_test.XLATestCase): """Test cases for binary operators.""" def _testBinary(self, op, a, b, expected, equality_test=None): @@ -226,6 +226,11 @@ class BinaryOpsTest(XLATestCase): np.array([0b1, 0b101, 0b1000], dtype=dtype), np.array([0b0, 0b101, 0b1001], dtype=dtype), expected=np.array([0b1, 0b101, 0b1001], dtype=dtype)) + self._testSymmetricBinary( + bitwise_ops.bitwise_xor, + np.array([0b1, 0b111, 0b1100], dtype=dtype), + np.array([0b0, 0b101, 0b1001], dtype=dtype), + expected=np.array([0b1, 0b010, 0b0101], dtype=dtype)) lhs = np.array([0, 5, 3, 14], dtype=dtype) rhs = np.array([5, 0, 7, 11], dtype=dtype) @@ -686,11 +691,13 @@ class BinaryOpsTest(XLATestCase): np.array([[10], [7], [2]], dtype=np.float32), np.float32(7), expected=np.array([[False], [False], [True]], dtype=np.bool)) - self._testBinary( - less_op, - np.array([[10], [7], [2], [-1]], dtype=np.int64), - np.int64(7), - expected=np.array([[False], [False], [True], [True]], dtype=np.bool)) + if np.int64 in self.numeric_types: + self._testBinary( + less_op, + np.array([[10], [7], [2], [-1]], dtype=np.int64), + np.int64(7), + expected=np.array( + [[False], [False], [True], [True]], dtype=np.bool)) for less_equal_op in [math_ops.less_equal, (lambda x, y: x <= y)]: self._testBinary( @@ -1216,6 +1223,24 @@ class BinaryOpsTest(XLATestCase): np.array([1, 0], dtype=np.int32), expected=np.array([[1, 3], [2, 4]], dtype=dtype)) + def testConjugateTranspose(self): + for dtype in self.complex_types: + self._testBinary( + array_ops.conjugate_transpose, + np.zeros(shape=[1, 0, 4], dtype=dtype), + np.array([1, 2, 0], dtype=np.int32), + expected=np.zeros(shape=[0, 4, 1], dtype=dtype)) + self._testBinary( + array_ops.conjugate_transpose, + np.array([[1 - 1j, 2 + 2j], [3 - 3j, 4 + 4j]], dtype=dtype), + np.array([0, 1], dtype=np.int32), + expected=np.array([[1 + 1j, 2 - 2j], [3 + 3j, 4 - 4j]], dtype=dtype)) + self._testBinary( + array_ops.conjugate_transpose, + np.array([[1 - 1j, 2 + 2j], [3 - 3j, 4 + 4j]], dtype=dtype), + np.array([1, 0], dtype=np.int32), + expected=np.array([[1 + 1j, 3 + 3j], [2 - 2j, 4 - 4j]], dtype=dtype)) + def testCross(self): for dtype in self.float_types: self._testBinary( diff --git a/tensorflow/compiler/tests/bucketize_op_test.py b/tensorflow/compiler/tests/bucketize_op_test.py index fde9759a1c209844caac99d5f303cd3e406e5370..ef4d5f6322b7ae79b051795b5af7e6f7f1e55550 100644 --- a/tensorflow/compiler/tests/bucketize_op_test.py +++ b/tensorflow/compiler/tests/bucketize_op_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.ops import array_ops @@ -26,7 +26,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class BucketizationOpTest(XLATestCase): +class BucketizationOpTest(xla_test.XLATestCase): def testInt(self): with self.test_session() as sess: diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index 035cdea1786d39f3d21bb63be5c8ccffe1608bdf..a4e7f75081dfd07fd4b5c94c33908aab8e7d8aa9 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -22,7 +22,7 @@ import collections import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops @@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest # TODO(srvasude): Merge this with # third_party/tensorflow/python/kernel_tests/random/multinomial_op_test.py. -class CategoricalTest(XLATestCase): +class CategoricalTest(xla_test.XLATestCase): """Test cases for random-number generating operators.""" def output_dtypes(self): diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py index 1a8989d7c2f617525c301f30fd899a01362310bf..ed532db0ee5553a275192e6cc3ebf394075fa0e1 100644 --- a/tensorflow/compiler/tests/cholesky_op_test.py +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -18,12 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import unittest - import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -32,7 +30,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class CholeskyOpTest(XLATestCase): +class CholeskyOpTest(xla_test.XLATestCase): # Cholesky defined for float64, float32, complex64, complex128 # (https://www.tensorflow.org/api_docs/python/tf/cholesky) @@ -103,9 +101,8 @@ class CholeskyOpTest(XLATestCase): with self.assertRaises(ValueError): linalg_ops.cholesky(tensor3) - @unittest.skip("Test is slow") - def testLarge(self): - n = 200 + def testLarge2000x2000(self): + n = 2000 shape = (n, n) data = np.ones(shape).astype(np.float32) / (2.0 * n) + np.diag( np.ones(n).astype(np.float32)) @@ -128,6 +125,5 @@ class CholeskyOpTest(XLATestCase): matrix = np.dot(np.dot(w, np.diag(v)), w.T).astype(dtype) self._verifyCholesky(matrix, atol=1e-4) - if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py index 574f82fc717818334ac5d72ebef2191f1c18e669..e42ebf8f9e01dab13cde15979ffc42b7c0fbc57b 100644 --- a/tensorflow/compiler/tests/clustering_test.py +++ b/tensorflow/compiler/tests/clustering_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" -class ClusteringTest(XLATestCase): +class ClusteringTest(xla_test.XLATestCase): def testAdd(self): val1 = np.array([4, 3, 2, 1], dtype=np.float32) diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index f10973e19f1945515b776cf86349445ed7334629..d9ad4281477e87f79f2ecb52989ae86a5030d0cc 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -30,7 +30,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class ConcatTest(XLATestCase): +class ConcatTest(xla_test.XLATestCase): def testHStack(self): with self.test_session(): @@ -292,7 +292,7 @@ class ConcatTest(XLATestCase): array_ops.concat([scalar, scalar, scalar], dim) -class ConcatOffsetTest(XLATestCase): +class ConcatOffsetTest(xla_test.XLATestCase): def testBasic(self): with self.test_session() as sess: @@ -306,7 +306,7 @@ class ConcatOffsetTest(XLATestCase): self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]]) -class PackTest(XLATestCase): +class PackTest(xla_test.XLATestCase): def testBasic(self): with self.test_session() as sess: diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py index 62577b70ce96e220d79978f01614b2d9a3647680..f9db103f6d0f9ea0e393a0971593552ec5c14079 100644 --- a/tensorflow/compiler/tests/conv2d_test.py +++ b/tensorflow/compiler/tests/conv2d_test.py @@ -22,17 +22,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import test_utils +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest +DATA_FORMATS = ( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), +) -class Conv2DTest(XLATestCase): + +class Conv2DTest(xla_test.XLATestCase, parameterized.TestCase): def _VerifyValues(self, input_sizes=None, @@ -40,6 +47,8 @@ class Conv2DTest(XLATestCase): strides=None, dilations=None, padding=None, + data_format_src="NHWC", + data_format_dst="NHWC", expected=None): """Tests that tf.nn.conv2d produces the expected value. @@ -51,8 +60,12 @@ class Conv2DTest(XLATestCase): strides: Strides. dilations: RHS dilations. padding: Padding type. + data_format_src: Data format input is in. + data_format_dst: Data format verification will run and input is converted + to. expected: Expected output. """ + total_size_1 = np.prod(input_sizes) total_size_2 = np.prod(filter_sizes) x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(input_sizes) @@ -62,6 +75,18 @@ class Conv2DTest(XLATestCase): dilations = [1, 1] dilations = [1] + dilations + [1] + # Convert between data formats. + expected = test_utils.ConvertBetweenDataFormats(expected, data_format_src, + data_format_dst) + x1 = test_utils.ConvertBetweenDataFormats(x1, data_format_src, + data_format_dst) + input_sizes = test_utils.PermuteDimsBetweenDataFormats( + input_sizes, data_format_src, data_format_dst) + strides = test_utils.PermuteDimsBetweenDataFormats(strides, data_format_src, + data_format_dst) + dilations = test_utils.PermuteDimsBetweenDataFormats( + dilations, data_format_src, data_format_dst) + with self.test_session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) @@ -71,12 +96,14 @@ class Conv2DTest(XLATestCase): t2, strides=strides, padding=padding, - data_format="NHWC", + data_format=data_format_dst, dilations=dilations) + value = sess.run(out, {t1: x1, t2: x2}) self.assertAllClose(expected, value, 1e-3) - def testConv2D1x1Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x1Filter(self, data_format): expected_output = np.reshape([ 30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0, 204.0, 174.0, 216.0, 258.0, 210.0, 261.0, 312.0 @@ -86,9 +113,12 @@ class Conv2DTest(XLATestCase): filter_sizes=[1, 1, 3, 3], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Filter(self, data_format): expected_output = np.reshape( [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0], [1, 1, 2, 3]) self._VerifyValues( @@ -96,9 +126,12 @@ class Conv2DTest(XLATestCase): filter_sizes=[2, 2, 3, 3], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2Filter2x1Dilation(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Filter2x1Dilation(self, data_format): expected_output = np.array([[[[72], [82], [92]], [[112], [122], [132]]]]) self._VerifyValues( input_sizes=[1, 4, 4, 1], @@ -106,9 +139,12 @@ class Conv2DTest(XLATestCase): strides=[1, 1], dilations=[2, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2Filter(self, data_format): expected_output = np.reshape([ 231.0, 252.0, 273.0, 384.0, 423.0, 462.0, 690.0, 765.0, 840.0, 843.0, 936.0, 1029.0 @@ -118,18 +154,24 @@ class Conv2DTest(XLATestCase): filter_sizes=[1, 2, 3, 3], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterStride2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterStride2(self, data_format): expected_output = np.reshape([2271.0, 2367.0, 2463.0], [1, 1, 1, 3]) self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], strides=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterStride2Same(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterStride2Same(self, data_format): expected_output = np.reshape( [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0], [1, 1, 2, 3]) self._VerifyValues( @@ -137,47 +179,61 @@ class Conv2DTest(XLATestCase): filter_sizes=[2, 2, 3, 3], strides=[2, 2], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2DEmptyDilation(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2DEmptyDilation(self, data_format): self._VerifyValues( input_sizes=[0, 2, 3, 3], filter_sizes=[1, 1, 3, 3], strides=[1, 1], dilations=[2, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=np.zeros([0, 2, 3, 3])) - def testConv2D2x2FilterDilation(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterDilation(self, data_format): self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], strides=[1, 1], dilations=[1, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=np.reshape([2667, 2781, 2895], [1, 1, 1, 3])) - def testConv2D1x2FilterDilation(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterDilation(self, data_format): self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[1, 2, 3, 3], strides=[1, 1], dilations=[2, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=np.array([[[[231, 252, 273], [384, 423, 462]], [[690, 765, 840], [843, 936, 1029]]]])) - def testConv2DKernelSizeMatchesInputSizeDilation(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2DKernelSizeMatchesInputSizeDilation(self, data_format): self._VerifyValues( input_sizes=[1, 3, 3, 1], filter_sizes=[2, 2, 1, 2], strides=[1, 1], dilations=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=np.reshape([108, 128], [1, 1, 1, 2])) -class Conv2DBackpropInputTest(XLATestCase): +class Conv2DBackpropInputTest(xla_test.XLATestCase, parameterized.TestCase): def _VerifyValues(self, input_sizes=None, @@ -186,6 +242,8 @@ class Conv2DBackpropInputTest(XLATestCase): strides=None, dilations=None, padding=None, + data_format_src="NHWC", + data_format_dst="NHWC", expected=None): """Tests that gen_nn_ops.conv2d_backprop_input produces the expected output. @@ -198,8 +256,12 @@ class Conv2DBackpropInputTest(XLATestCase): strides: Strides. dilations: Dilations. padding: Padding type. + data_format_src: Data format input is in. + data_format_dst: Data format verification will run and input is converted + to. expected: Expected output. """ + total_size_1 = np.prod(filter_sizes) total_size_2 = np.prod(out_backprop_sizes) x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(filter_sizes) @@ -209,6 +271,23 @@ class Conv2DBackpropInputTest(XLATestCase): if dilations is not None: dilations = [1] + dilations + [1] + expected = np.reshape(expected, input_sizes) + + # Convert between data formats. + expected = test_utils.ConvertBetweenDataFormats(expected, data_format_src, + data_format_dst) + x2 = test_utils.ConvertBetweenDataFormats(x2, data_format_src, + data_format_dst) + input_sizes = test_utils.PermuteDimsBetweenDataFormats( + input_sizes, data_format_src, data_format_dst) + out_backprop_sizes = test_utils.PermuteDimsBetweenDataFormats( + out_backprop_sizes, data_format_src, data_format_dst) + strides = test_utils.PermuteDimsBetweenDataFormats(strides, data_format_src, + data_format_dst) + if dilations is not None: + dilations = test_utils.PermuteDimsBetweenDataFormats( + dilations, data_format_src, data_format_dst) + with self.test_session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) @@ -220,12 +299,14 @@ class Conv2DBackpropInputTest(XLATestCase): strides=strides, dilations=dilations, padding=padding, - data_format="NHWC") + data_format=data_format_dst) + value = sess.run(out, {t1: x1, t2: x2}) self.assertAllEqual(input_sizes, value.shape) - self.assertAllClose(expected, np.ravel(value), 1e-3) + self.assertAllClose(expected, value, 1e-3) - def testConv2D1x1Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x1Filter(self, data_format): expected_output = [ 5, 11, 17, 11, 25, 39, 17, 39, 61, 23, 53, 83, 29, 67, 105, 35, 81, 127, 41, 95, 149, 47, 109, 171, 53, 123, 193, 59, 137, 215, 65, 151, 237, 71, @@ -237,9 +318,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 4, 4, 2], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterStride3Width5(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterStride3Width5(self, data_format): expected_output = [1, 2, 0, 2, 4] self._VerifyValues( input_sizes=[1, 1, 5, 1], @@ -247,9 +331,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[3, 3], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterStride3Width6(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterStride3Width6(self, data_format): expected_output = [1, 2, 0, 2, 4, 0] self._VerifyValues( input_sizes=[1, 1, 6, 1], @@ -257,9 +344,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[3, 3], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterStride3Width7(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterStride3Width7(self, data_format): expected_output = [1, 2, 0, 2, 4, 0, 0] self._VerifyValues( input_sizes=[1, 1, 7, 1], @@ -267,9 +357,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[3, 3], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterC1Same(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterC1Same(self, data_format): expected_output = [1, 4, 7, 7, 23, 33] self._VerifyValues( input_sizes=[1, 2, 3, 1], @@ -277,9 +370,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 2, 3, 1], strides=[1, 1], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Filter(self, data_format): expected_output = [ 14, 32, 50, 100, 163, 226, 167, 212, 257, 122, 140, 158, 478, 541, 604, 437, 482, 527 @@ -290,9 +386,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 3], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterSame(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterSame(self, data_format): expected_output = [ 14, 32, 50, 100, 163, 226, 217, 334, 451, 190, 307, 424, 929, 1217, 1505, 1487, 1883, 2279 @@ -303,9 +402,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 2, 3, 3], strides=[1, 1], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2Filter(self, data_format): expected_output = [1, 4, 4, 3, 10, 8, 5, 16, 12] self._VerifyValues( input_sizes=[1, 3, 3, 1], @@ -313,9 +415,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 3, 2, 1], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterSame(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterSame(self, data_format): expected_output = [1, 4, 7, 4, 13, 16, 7, 22, 25] self._VerifyValues( input_sizes=[1, 3, 3, 1], @@ -323,9 +428,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 3, 3, 1], strides=[1, 1], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterStride2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterStride2(self, data_format): expected_output = [1, 2, 5, 4, 6, 0, 0, 0, 0, 0, 3, 6, 13, 8, 12] self._VerifyValues( input_sizes=[1, 3, 5, 1], @@ -333,9 +441,12 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 2, 2, 1], strides=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterStride2Same(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterStride2Same(self, data_format): expected_output = [1, 2, 2, 3, 4, 6] self._VerifyValues( input_sizes=[1, 2, 3, 1], @@ -343,9 +454,13 @@ class Conv2DBackpropInputTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[2, 2], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1( + self, data_format): self._VerifyValues( input_sizes=[1, 3, 6, 1], filter_sizes=[2, 2, 1, 1], @@ -353,9 +468,12 @@ class Conv2DBackpropInputTest(XLATestCase): strides=[1, 1], dilations=[2, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[1, 4, 7, 10, 13, 10, 0, 0, 0, 0, 0, 0, 3, 10, 17, 24, 31, 20]) - def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self, data_format): self._VerifyValues( input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 1], @@ -363,9 +481,12 @@ class Conv2DBackpropInputTest(XLATestCase): strides=[1, 1], dilations=[1, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[1, 0, 2, 3, 0, 4]) - def testConv2DEmptyBackpropInputDilation1x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2DEmptyBackpropInputDilation1x2(self, data_format): self._VerifyValues( input_sizes=[0, 2, 3, 1], filter_sizes=[2, 2, 1, 1], @@ -373,9 +494,12 @@ class Conv2DBackpropInputTest(XLATestCase): strides=[1, 1], dilations=[1, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=np.zeros([0])) - def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self, data_format): # The GPU version of this test is not very stable. So adjusting the # error threshold to 1e-4. self._VerifyValues( @@ -385,12 +509,16 @@ class Conv2DBackpropInputTest(XLATestCase): strides=[1, 1], dilations=[2, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[ 14, 32, 50, 68, 86, 104, 0, 0, 0, 0, 0, 0, 122, 140, 158, 176, 194, 212 ]) - def testConv2DKernelSizeMatchesInputSizeBackpropInputDilation2x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2DKernelSizeMatchesInputSizeBackpropInputDilation2x2( + self, data_format): self._VerifyValues( input_sizes=[1, 3, 3, 1], filter_sizes=[2, 2, 1, 2], @@ -398,10 +526,12 @@ class Conv2DBackpropInputTest(XLATestCase): strides=[1, 1], dilations=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[5, 0, 11, 0, 0, 0, 17, 0, 23]) -class Conv2DBackpropFilterTest(XLATestCase): +class Conv2DBackpropFilterTest(xla_test.XLATestCase, parameterized.TestCase): def _VerifyValues(self, input_sizes=None, @@ -410,6 +540,8 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=None, dilations=None, padding=None, + data_format_src="NHWC", + data_format_dst="NHWC", expected=None): """Tests that gen_nn_ops.conv2d_backprop_filter produces the right output. @@ -422,6 +554,9 @@ class Conv2DBackpropFilterTest(XLATestCase): strides: Stride. dilations: Dilations. padding: Padding type. + data_format_src: Data format input is in. + data_format_dst: Data format verification will run and input is converted + to. expected: Expected output. """ @@ -434,6 +569,23 @@ class Conv2DBackpropFilterTest(XLATestCase): if dilations is not None: dilations = [1] + dilations + [1] + expected = np.reshape(expected, filter_sizes) + + # Convert between data formats. + x1 = test_utils.ConvertBetweenDataFormats(x1, data_format_src, + data_format_dst) + x2 = test_utils.ConvertBetweenDataFormats(x2, data_format_src, + data_format_dst) + input_sizes = test_utils.PermuteDimsBetweenDataFormats( + input_sizes, data_format_src, data_format_dst) + out_backprop_sizes = test_utils.PermuteDimsBetweenDataFormats( + out_backprop_sizes, data_format_src, data_format_dst) + strides = test_utils.PermuteDimsBetweenDataFormats(strides, data_format_src, + data_format_dst) + if dilations is not None: + dilations = test_utils.PermuteDimsBetweenDataFormats( + dilations, data_format_src, data_format_dst) + with self.test_session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) @@ -445,13 +597,14 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=strides, dilations=dilations, padding=padding, - data_format="NHWC") + data_format=data_format_dst) value = sess.run(tensor, {t1: x1, t2: x2}) self.assertAllEqual(filter_sizes, value.shape) - self.assertAllClose(expected, np.ravel(value), 1e-3) + self.assertAllClose(expected, value, 1e-3) - def testConv2D1x1Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x1Filter(self, data_format): expected_output = [8056, 8432, 8312, 8704, 8568, 8976] self._VerifyValues( input_sizes=[1, 4, 4, 3], @@ -459,9 +612,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 4, 4, 2], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2Filter(self, data_format): expected_output = [120, 141] self._VerifyValues( input_sizes=[1, 3, 3, 1], @@ -469,9 +625,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 3, 2, 1], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterDepth1(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterDepth1(self, data_format): expected_output = [5, 8, 14, 17] self._VerifyValues( input_sizes=[1, 2, 3, 1], @@ -479,9 +638,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Filter(self, data_format): expected_output = [ 17, 22, 27, 22, 29, 36, 27, 36, 45, 32, 43, 54, 37, 50, 63, 42, 57, 72, 62, 85, 108, 67, 92, 117, 72, 99, 126, 77, 106, 135, 82, 113, 144, 87, @@ -493,9 +655,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 3], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterStride3Width5(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterStride3Width5(self, data_format): expected_output = [9, 12] self._VerifyValues( input_sizes=[1, 1, 5, 1], @@ -503,9 +668,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[3, 3], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterStride3Width6(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterStride3Width6(self, data_format): expected_output = [9, 12] self._VerifyValues( input_sizes=[1, 1, 6, 1], @@ -513,9 +681,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[3, 3], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x2FilterStride3Width7(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x2FilterStride3Width7(self, data_format): expected_output = [9, 12] self._VerifyValues( input_sizes=[1, 1, 7, 1], @@ -523,9 +694,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[3, 3], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x3Filter(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x3Filter(self, data_format): expected_output = [5, 8, 11] self._VerifyValues( input_sizes=[1, 1, 4, 1], @@ -533,9 +707,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[1, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x3FilterSame(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x3FilterSame(self, data_format): expected_output = [20, 30, 20] self._VerifyValues( input_sizes=[1, 1, 4, 1], @@ -543,9 +720,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 4, 1], strides=[1, 1], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D1x3FilterSameOutbackprop2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D1x3FilterSameOutbackprop2(self, data_format): expected_output = [7, 10, 3] self._VerifyValues( input_sizes=[1, 1, 4, 1], @@ -553,9 +733,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[2, 2], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterC1Same(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterC1Same(self, data_format): expected_output = [91, 58, 32, 17] self._VerifyValues( input_sizes=[1, 2, 3, 1], @@ -563,9 +746,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 2, 3, 1], strides=[1, 1], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterStride2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterStride2(self, data_format): expected_output = [92, 102, 112] self._VerifyValues( input_sizes=[1, 3, 5, 1], @@ -573,9 +759,12 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 2, 2, 1], strides=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2FilterStride2Same(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2FilterStride2Same(self, data_format): expected_output = [7, 2, 16, 5] self._VerifyValues( input_sizes=[1, 2, 3, 1], @@ -583,9 +772,13 @@ class Conv2DBackpropFilterTest(XLATestCase): out_backprop_sizes=[1, 1, 2, 1], strides=[2, 2], padding="SAME", + data_format_src="NHWC", + data_format_dst=data_format, expected=expected_output) - def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1( + self, data_format): self._VerifyValues( input_sizes=[1, 3, 6, 1], filter_sizes=[2, 2, 1, 1], @@ -593,9 +786,12 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=[1, 1], dilations=[2, 1], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[55, 70, 235, 250]) - def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self, data_format): self._VerifyValues( input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 1], @@ -603,9 +799,12 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=[1, 1], dilations=[1, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[1, 3, 4, 6]) - def testConv2DEmptyBackpropFilterDilation1x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2DEmptyBackpropFilterDilation1x2(self, data_format): self._VerifyValues( input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 0], @@ -613,9 +812,12 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=[1, 1], dilations=[1, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=np.zeros([0])) - def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self, data_format): self._VerifyValues( input_sizes=[1, 3, 4, 3], filter_sizes=[2, 2, 3, 3], @@ -623,13 +825,17 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=[1, 1], dilations=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[ 17, 22, 27, 22, 29, 36, 27, 36, 45, 47, 64, 81, 52, 71, 90, 57, 78, 99, 137, 190, 243, 142, 197, 252, 147, 204, 261, 167, 232, 297, 172, 239, 306, 177, 246, 315 ]) - def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2(self): + @parameterized.named_parameters(*DATA_FORMATS) + def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2( + self, data_format): self._VerifyValues( input_sizes=[1, 3, 3, 1], filter_sizes=[2, 2, 1, 2], @@ -637,6 +843,8 @@ class Conv2DBackpropFilterTest(XLATestCase): strides=[1, 1], dilations=[2, 2], padding="VALID", + data_format_src="NHWC", + data_format_dst=data_format, expected=[1, 2, 3, 6, 7, 14, 9, 18]) diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py index 3bebf46511cbc471d3fbbbe92d28511fcc717387..31ee41f04f27d387415e9fa2c4fa70b33cab7b04 100644 --- a/tensorflow/compiler/tests/conv3d_test.py +++ b/tensorflow/compiler/tests/conv3d_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -33,7 +33,7 @@ from tensorflow.python.platform import googletest # Test cloned from # tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py -class Conv3DBackpropFilterV2GradTest(XLATestCase): +class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase): def testGradient(self): with self.test_session(), self.test_scope(): @@ -66,7 +66,7 @@ class Conv3DBackpropFilterV2GradTest(XLATestCase): # Test cloned from tensorflow/python/kernel_tests/conv3d_transpose_test.py -class Conv3DTransposeTest(XLATestCase): +class Conv3DTransposeTest(xla_test.XLATestCase): def testConv3DTransposeSingleStride(self): with self.test_session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 03d96a2cd8ab22a472a67f092e36224820405fa8..98dc73e189f99b7b811487756659d89dacb97d8a 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -114,7 +114,7 @@ def CheckGradConfigsToTest(): yield i, f, o, s, p -class DepthwiseConv2DTest(XLATestCase): +class DepthwiseConv2DTest(xla_test.XLATestCase): # This is testing that depthwise_conv2d and depthwise_conv2d_native # produce the same results. It also tests that NCHW and NWHC diff --git a/tensorflow/compiler/tests/dynamic_slice_ops_test.py b/tensorflow/compiler/tests/dynamic_slice_ops_test.py index 6a46d2ec3e7aee3a4ecfbf1ab9f622d8eb659e3c..154e36b10e6da409606ae6022aaf53e34c8e37cc 100644 --- a/tensorflow/compiler/tests/dynamic_slice_ops_test.py +++ b/tensorflow/compiler/tests/dynamic_slice_ops_test.py @@ -20,14 +20,14 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class DynamicUpdateSliceOpsTest(XLATestCase): +class DynamicUpdateSliceOpsTest(xla_test.XLATestCase): def _assertOpOutputMatchesExpected(self, op, args, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py index c109c27abe2f145685f83251e1d21ec8ddad563a..edd78153b56bb5bf1c268936fb82a60581389733 100644 --- a/tensorflow/compiler/tests/dynamic_stitch_test.py +++ b/tensorflow/compiler/tests/dynamic_stitch_test.py @@ -20,14 +20,14 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.platform import googletest -class DynamicStitchTest(XLATestCase): +class DynamicStitchTest(xla_test.XLATestCase): def _AssertDynamicStitchResultIs(self, indices, data, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index a4154ad1e846f8241a2ab6598da36ccb6b3b653e..ff097f80f1f2586bd483a54d532750c90b2a8b03 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context @@ -32,6 +32,7 @@ from tensorflow.python.layers import convolutional from tensorflow.python.layers import pooling from tensorflow.python.ops import array_ops from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import gen_random_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -40,7 +41,7 @@ from tensorflow.python.platform import googletest from tensorflow.python.training import adam -class EagerTest(XLATestCase): +class EagerTest(xla_test.XLATestCase): def testBasic(self): with self.test_scope(): @@ -49,6 +50,21 @@ class EagerTest(XLATestCase): product = three * five self.assertAllEqual(15, product) + def testGradientTape(self): + with self.test_scope(): + + x = constant_op.constant(1.0) + y = constant_op.constant(10.0) + with backprop.GradientTape(persistent=True) as tape: + tape.watch(x) + tape.watch(y) + a = x + y + x * y + da_dx = tape.gradient(a, x) + da_dy = tape.gradient(a, y) + + self.assertEqual(11.0, da_dx.numpy()) + self.assertEqual(2.0, da_dy.numpy()) + def testExecuteListOutputLen0(self): with self.test_scope(): empty = constant_op.constant([], dtype=dtypes.float32) @@ -107,6 +123,14 @@ class EagerTest(XLATestCase): with self.test_scope(): self.assertAllEqual(2, array_ops.identity(2)) + def testRandomOps(self): + with self.test_scope(): + tensor = gen_random_ops.random_uniform((2, 2), dtypes.float32) + row0 = tensor[0].numpy() + row1 = tensor[1].numpy() + # It should be very unlikely to rng to generate two equal rows. + self.assertFalse((row0 == row1).all()) + def testIdentityOnVariable(self): with self.test_scope(): v = resource_variable_ops.ResourceVariable(True) @@ -271,11 +295,11 @@ class EagerTest(XLATestCase): [2.0, 2.0]], embedding_matrix.numpy()) -class EagerFunctionTest(XLATestCase): +class EagerFunctionTest(xla_test.XLATestCase): def testBasic(self): with self.test_scope(): - matmul = function.defun(math_ops.matmul, compiled=True) + matmul = function.defun(math_ops.matmul) t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) sq = matmul(t, t, transpose_a=True) self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20]) @@ -297,7 +321,7 @@ class EagerFunctionTest(XLATestCase): def model(x): x = conv(x) return pool(x) - model = function.defun(model, compiled=True) + model = function.defun(model) x = array_ops.ones([1, 4, 4, 1]) y = model(x) @@ -307,7 +331,7 @@ class EagerFunctionTest(XLATestCase): with self.test_scope(): v = resource_variable_ops.ResourceVariable(1.0) - @function.defun(compiled=True) + @function.defun def f(): return v.read_value() @@ -322,7 +346,7 @@ class EagerFunctionTest(XLATestCase): v.assign_add(1.0) return v - f = function.defun(f, compiled=True) + f = function.defun(f) var = f(v) self.assertEqual(2.0, var.numpy()) @@ -350,7 +374,7 @@ class EagerFunctionTest(XLATestCase): d = r2 * v2 return a, b, c, d - foo = function.defun(foo, compiled=True) + foo = function.defun(foo) c1 = [0, 0] c2 = array_ops.ones([2], dtype=dtypes.int32) @@ -372,7 +396,7 @@ class EagerFunctionTest(XLATestCase): with self.test_scope(): v0 = resource_variable_ops.ResourceVariable(5.0) - @function.defun(compiled=True) + @function.defun def f(x): x = v0 * v0 * x return x @@ -385,8 +409,57 @@ class EagerFunctionTest(XLATestCase): self.assertEqual(75, y.numpy()) self.assertEqual(30, dy.numpy()) + def testGradientTapeInDefun(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + + @function.defun + def f(): + x = constant_op.constant(1.0) + with backprop.GradientTape() as tape: + y = v0 * x + dy = tape.gradient(y, v0) + return dy + + dy = f() + self.assertEqual(1.0, dy.numpy()) + + def testSliceInDefun(self): + with self.test_scope(): + + @function.defun + def f(x, y): + return x[0::2, y:, ...] + + x = array_ops.ones([2, 3, 4]) + y = array_ops.ones([], dtype=dtypes.int32) + with backprop.GradientTape() as tape: + tape.watch(x) + tape.watch(y) + z = f(x, y) + dz = tape.gradient(z, x) + + self.assertAllEqual(np.ones([1, 2, 4]), z.numpy()) + self.assertAllEqual((2, 3, 4), dz.shape.as_list()) -class ExcessivePaddingTest(XLATestCase): + def testNestedDefun(self): + self.skipTest('Nested defuns do not work on TPU at the moment') + with self.test_scope(): + + @function.defun + def times_two(x): + return 2 * x + + @function.defun + def two_x_plus_1(x): + return times_two(x) + 1 + + x = constant_op.constant([2, 3, 4]) + y = two_x_plus_1(x) + self.assertAllEqual([5, 7, 9], y.numpy()) + + +class ExcessivePaddingTest(xla_test.XLATestCase): """Test that eager execution works with TPU flattened tensors. Tensors that would normally be excessively padded when written @@ -417,7 +490,7 @@ class ExcessivePaddingTest(XLATestCase): def testAsFunctionInput(self): with self.test_scope(): - @function.defun(compiled=True) + @function.defun def f(x): return math_ops.reduce_sum(x, axis=2) @@ -428,7 +501,7 @@ class ExcessivePaddingTest(XLATestCase): def testAsFunctionOutput(self): with self.test_scope(): - @function.defun(compiled=True) + @function.defun def f(x): return x * constant_op.constant(100 * [[[10.0, 2.0]]]) @@ -437,6 +510,36 @@ class ExcessivePaddingTest(XLATestCase): self.assertAllEqual(100 * [[36.0]], reduced) +def multiple_tpus(): + devices = context.context().devices() + return len([d for d in devices if 'device:TPU:' in d]) > 1 + + +class MultiDeviceTest(xla_test.XLATestCase): + """Test running TPU computation on more than one core.""" + + def testBasic(self): + if not multiple_tpus(): + self.skipTest('MultiDeviceTest requires multiple TPU devices.') + + # Compute 10 on TPU core 0 + with ops.device('device:TPU:0'): + two = constant_op.constant(2) + five = constant_op.constant(5) + ten = two * five + self.assertAllEqual(10, ten) + + # Compute 6 on TPU core 1 + with ops.device('device:TPU:1'): + two = constant_op.constant(2) + three = constant_op.constant(3) + six = two * three + self.assertAllEqual(6, six) + + # Copy 10 and 6 to CPU and sum them + self.assertAllEqual(16, ten + six) + + if __name__ == '__main__': ops.enable_eager_execution( config=config_pb2.ConfigProto(log_device_placement=True)) diff --git a/tensorflow/compiler/tests/extract_image_patches_op_test.py b/tensorflow/compiler/tests/extract_image_patches_op_test.py index 0361702e7af778176daed941d64e61198090daf2..5529fdbb090315e1d7f47589777d8a538c90db2b 100644 --- a/tensorflow/compiler/tests/extract_image_patches_op_test.py +++ b/tensorflow/compiler/tests/extract_image_patches_op_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ExtractImagePatches(XLATestCase): +class ExtractImagePatches(xla_test.XLATestCase): """Functional tests for ExtractImagePatches op.""" def _VerifyValues(self, image, ksizes, strides, rates, padding, patches): diff --git a/tensorflow/compiler/tests/fake_quant_ops_test.py b/tensorflow/compiler/tests/fake_quant_ops_test.py index dfe9400ef0f55ca011d4e23ba5d735899ca2e054..c48ab178bf53558084fb500b2811c6f0b77a7943 100644 --- a/tensorflow/compiler/tests/fake_quant_ops_test.py +++ b/tensorflow/compiler/tests/fake_quant_ops_test.py @@ -17,14 +17,14 @@ from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.platform import googletest -class FakeQuantWithMinMaxArgsTest(XLATestCase): +class FakeQuantWithMinMaxArgsTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxArgs operation.""" # 8 bits, wide range. @@ -122,7 +122,7 @@ class FakeQuantWithMinMaxArgsTest(XLATestCase): result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03) -class FakeQuantWithMinMaxArgsGradientTest(XLATestCase): +class FakeQuantWithMinMaxArgsGradientTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxArgsGradient operation.""" # 8 bits, wide range. @@ -223,7 +223,7 @@ class FakeQuantWithMinMaxArgsGradientTest(XLATestCase): bfloat16_rtol=0.03) -class FakeQuantWithMinMaxVarsTest(XLATestCase): +class FakeQuantWithMinMaxVarsTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxVars operation.""" # 8 bits, wide range. @@ -328,7 +328,7 @@ class FakeQuantWithMinMaxVarsTest(XLATestCase): result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03) -class FakeQuantWithMinMaxVarsGradientTest(XLATestCase): +class FakeQuantWithMinMaxVarsGradientTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxVarsGradient operation.""" # 8 bits, wide range. diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py index afb5fa4bb4fefe5bc2ecded826143ffc83c2b559..c64ea249ecb97991952a960a6d16e1bb3be35b17 100644 --- a/tensorflow/compiler/tests/fft_test.py +++ b/tensorflow/compiler/tests/fft_test.py @@ -23,10 +23,11 @@ import itertools import numpy as np import scipy.signal as sps -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.contrib.signal.python.ops import spectral_ops as signal from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import spectral_ops from tensorflow.python.platform import googletest @@ -57,7 +58,7 @@ INNER_DIMS_2D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2)) INNER_DIMS_3D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2, POWS_OF_2)) -class FFTTest(XLATestCase): +class FFTTest(xla_test.XLATestCase): def _VerifyFftMethod(self, inner_dims, complex_to_input, input_to_expected, tf_method): @@ -97,8 +98,11 @@ class FFTTest(XLATestCase): ph = array_ops.placeholder( dtypes.as_dtype(data.dtype), shape=data.shape) out = signal.stft(ph, ws, hs) + grad = gradients_impl.gradients(out, ph, + grad_ys=array_ops.ones_like(out)) - value = sess.run(out, {ph: data}) + # For gradients, we simply verify that they compile & execute. + value, _ = sess.run([out, grad], {ph: data}) self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL) def testFFT(self): diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0f64cc87cde77fbbef6c4e570879e992bc34bafa --- /dev/null +++ b/tensorflow/compiler/tests/fifo_queue_test.py @@ -0,0 +1,201 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.data_flow_ops.FIFOQueue.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.platform import test + + +class FIFOQueueTest(xla_test.XLATestCase): + + def testEnqueue(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + enqueue_op = q.enqueue((10.0,)) + enqueue_op.run() + + def testEnqueueWithShape(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2)) + enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],)) + enqueue_correct_op.run() + with self.assertRaises(ValueError): + q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],)) + self.assertEqual(1, q.size().eval()) + + def testMultipleDequeues(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q.enqueue([1])) + self.evaluate(q.enqueue([2])) + self.evaluate(q.enqueue([3])) + a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()]) + self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) + + def testQueuesDontShare(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q.enqueue(1)) + q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q2.enqueue(2)) + self.assertAllEqual(self.evaluate(q2.dequeue()), 2) + self.assertAllEqual(self.evaluate(q.dequeue()), 1) + + def testEnqueueDictWithoutNames(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + with self.assertRaisesRegexp(ValueError, "must have names"): + q.enqueue({"a": 12.0}) + + def testParallelEnqueue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + # Run one producer thread for each element in elems. + def enqueue(enqueue_op): + sess.run(enqueue_op) + + threads = [ + self.checkedThread(target=enqueue, args=(e,)) for e in enqueue_ops + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # Dequeue every element using a single thread. + results = [] + for _ in xrange(len(elems)): + results.append(dequeued_t.eval()) + self.assertItemsEqual(elems, results) + + def testParallelDequeue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + # Enqueue every element using a single thread. + for enqueue_op in enqueue_ops: + enqueue_op.run() + + # Run one consumer thread for each element in elems. + results = [] + + def dequeue(): + results.append(sess.run(dequeued_t)) + + threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + self.assertItemsEqual(elems, results) + + def testDequeue(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + for enqueue_op in enqueue_ops: + enqueue_op.run() + + for i in xrange(len(elems)): + vals = dequeued_t.eval() + self.assertEqual([elems[i]], vals) + + def testEnqueueAndBlockingDequeue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + def enqueue(): + # The enqueue_ops should run after the dequeue op has blocked. + # TODO(mrry): Figure out how to do this without sleeping. + time.sleep(0.1) + for enqueue_op in enqueue_ops: + sess.run(enqueue_op) + + results = [] + + def dequeue(): + for _ in xrange(len(elems)): + results.append(sess.run(dequeued_t)) + + enqueue_thread = self.checkedThread(target=enqueue) + dequeue_thread = self.checkedThread(target=dequeue) + enqueue_thread.start() + dequeue_thread.start() + enqueue_thread.join() + dequeue_thread.join() + + for elem, result in zip(elems, results): + self.assertEqual([elem], result) + + def testMultiEnqueueAndDequeue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32)) + elems = [(5, 10.0), (10, 20.0), (15, 30.0)] + enqueue_ops = [q.enqueue((x, y)) for x, y in elems] + dequeued_t = q.dequeue() + + for enqueue_op in enqueue_ops: + enqueue_op.run() + + for i in xrange(len(elems)): + x_val, y_val = sess.run(dequeued_t) + x, y = elems[i] + self.assertEqual([x], x_val) + self.assertEqual([y], y_val) + + def testQueueSizeEmpty(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + self.assertEqual([0], q.size().eval()) + + def testQueueSizeAfterEnqueueAndDequeue(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + enqueue_op = q.enqueue((10.0,)) + dequeued_t = q.dequeue() + size = q.size() + self.assertEqual([], size.get_shape()) + + enqueue_op.run() + self.assertEqual(1, size.eval()) + dequeued_t.op.run() + self.assertEqual(0, size.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index 8e6407dffdac3adbcda8cbca2109ef9196defa8c..1da97fd51217a0f28d4b3ba2ccfae3f6b094e65b 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -30,7 +30,7 @@ from tensorflow.python.training import ftrl from tensorflow.python.training import gradient_descent -class FtrlOptimizerTest(XLATestCase): +class FtrlOptimizerTest(xla_test.XLATestCase): def initVariableAndGradient(self, dtype): var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index 8a3f4b0bdc7a61d6cfa2ba7474ce8579e293a5c7..04fba444460e714ce96205361ac02ed492206b04 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class FunctionTest(XLATestCase): +class FunctionTest(xla_test.XLATestCase): def testFunction(self): """Executes a simple TensorFlow function.""" diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index a80d69fa5f5099b8a8b67df0da9c92b957e9d194..132e42ac7a28d0769b0de12ea0cee6eae752b245 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -18,9 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import test_utils +from tensorflow.compiler.tests import xla_test from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradient_checker @@ -28,7 +30,7 @@ from tensorflow.python.ops import nn from tensorflow.python.platform import test -class FusedBatchNormTest(XLATestCase): +class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): def _reference_training(self, x, scale, offset, epsilon, data_format): if data_format != "NHWC": @@ -63,24 +65,36 @@ class FusedBatchNormTest(XLATestCase): grad_offset = np.sum(grad_y, axis=(0, 1, 2)) return grad_x, grad_scale, grad_offset - def testInference(self): + @parameterized.named_parameters( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), + ("_data_format_HWNC", "HWNC"), + ("_data_format_HWCN", "HWCN"), + ) + def testInference(self, data_format): channel = 3 x_shape = [2, 2, 6, channel] scale_shape = [channel] x_val = np.random.random_sample(x_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32) - offset_val = np.random.random_sample(scale_shape).astype(np.float32) - data_format = "NHWC" + epsilon = 0.001 + data_format_src = "NHWC" + y_ref, mean_ref, var_ref = self._reference_training( + x_val, scale_val, offset_val, epsilon, data_format_src) + with self.test_session() as sess, self.test_scope(): # To avoid constant folding - t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") + x_val_converted = test_utils.ConvertBetweenDataFormats( + x_val, data_format_src, data_format) + y_ref_converted = test_utils.ConvertBetweenDataFormats( + y_ref, data_format_src, data_format) + + t_val = array_ops.placeholder( + np.float32, shape=x_val_converted.shape, name="x") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") offset = array_ops.placeholder( np.float32, shape=scale_shape, name="offset") - epsilon = 0.001 - y_ref, mean_ref, var_ref = self._reference_training( - x_val, scale_val, offset_val, epsilon, data_format) y, mean, variance = nn.fused_batch_norm( t_val, scale, @@ -91,31 +105,39 @@ class FusedBatchNormTest(XLATestCase): data_format=data_format, is_training=False) - y_val, _, _ = sess.run( - [y, mean, - variance], {t_val: x_val, - scale: scale_val, - offset: offset_val}) - self.assertAllClose(y_val, y_ref, atol=1e-3) + y_val, _, _ = sess.run([y, mean, variance], { + t_val: x_val_converted, + scale: scale_val, + offset: offset_val + }) + self.assertAllClose(y_val, y_ref_converted, atol=1e-3) - def _testLearning(self, use_gradient_checker): + def _testLearning(self, use_gradient_checker, data_format): channel = 3 x_shape = [2, 2, 6, channel] scale_shape = [channel] x_val = np.random.random_sample(x_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32) - offset_val = np.random.random_sample(scale_shape).astype(np.float32) mean_val = np.random.random_sample(scale_shape).astype(np.float32) var_val = np.random.random_sample(scale_shape).astype(np.float32) - data_format = "NHWC" + epsilon = 0.001 + data_format_src = "NHWC" + y_ref, mean_ref, var_ref = self._reference_training( + x_val, scale_val, offset_val, epsilon, data_format_src) + with self.test_session() as sess, self.test_scope(): # To avoid constant folding - t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") + x_val_converted = test_utils.ConvertBetweenDataFormats( + x_val, data_format_src, data_format) + y_ref_converted = test_utils.ConvertBetweenDataFormats( + y_ref, data_format_src, data_format) + + t_val = array_ops.placeholder( + np.float32, shape=x_val_converted.shape, name="x") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") offset = array_ops.placeholder( np.float32, shape=scale_shape, name="offset") - epsilon = 0.001 y, mean, var = nn.fused_batch_norm( t_val, scale, @@ -129,33 +151,50 @@ class FusedBatchNormTest(XLATestCase): if use_gradient_checker: err = gradient_checker.compute_gradient_error( t_val, - x_shape, + x_val_converted.shape, y, - x_shape, + x_val_converted.shape, extra_feed_dict={ - t_val: x_val, + t_val: x_val_converted, scale: scale_val, offset: offset_val }) self.assertLess(err, 1e-3) - y_val, mean_val, var_val = sess.run( - [y, mean, var], {t_val: x_val, - scale: scale_val, - offset: offset_val}) - y_ref, mean_ref, var_ref = self._reference_training( - x_val, scale_val, offset_val, epsilon, data_format) + y_val, mean_val, var_val = sess.run([y, mean, var], { + t_val: x_val_converted, + scale: scale_val, + offset: offset_val + }) self.assertAllClose(mean_val, mean_ref, atol=1e-3) - self.assertAllClose(y_val, y_ref, atol=1e-3) + self.assertAllClose(y_val, y_ref_converted, atol=1e-3) self.assertAllClose(var_val, var_ref, atol=1e-3) - def testLearning(self): - self._testLearning(False) + @parameterized.named_parameters( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), + ("_data_format_HWNC", "HWNC"), + ("_data_format_HWCN", "HWCN"), + ) + def testLearning(self, data_format): + self._testLearning(False, data_format) - def testLearningWithGradientChecker(self): - self._testLearning(True) + @parameterized.named_parameters( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), + ("_data_format_HWNC", "HWNC"), + ("_data_format_HWCN", "HWCN"), + ) + def testLearningWithGradientChecker(self, data_format): + self._testLearning(True, data_format) - def testGradientTraining(self): + @parameterized.named_parameters( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), + ("_data_format_HWNC", "HWNC"), + ("_data_format_HWCN", "HWCN"), + ) + def testGradientTraining(self, data_format): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. channel = 3 @@ -167,33 +206,48 @@ class FusedBatchNormTest(XLATestCase): mean_val = np.random.random_sample(scale_shape).astype(np.float32) var_val = np.random.random_sample(scale_shape).astype(np.float32) epsilon = 0.001 + data_format_src = "NHWC" + grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad( + x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src) with self.test_session() as sess, self.test_scope(): - grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad") - x = array_ops.placeholder(np.float32, shape=x_shape, name="x") + grad_val_converted = test_utils.ConvertBetweenDataFormats( + grad_val, data_format_src, data_format) + x_val_converted = test_utils.ConvertBetweenDataFormats( + x_val, data_format_src, data_format) + grad_x_ref_converted = test_utils.ConvertBetweenDataFormats( + grad_x_ref, data_format_src, data_format) + + grad = array_ops.placeholder( + np.float32, shape=x_val_converted.shape, name="grad") + x = array_ops.placeholder( + np.float32, shape=x_val_converted.shape, name="x") mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean") var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( - grad, x, scale, mean, var, data_format="NHWC", is_training=True) + grad, x, scale, mean, var, data_format=data_format, is_training=True) grad_x_val, grad_scale_val, grad_offset_val = sess.run( [grad_x, grad_scale, grad_offset], { - grad: grad_val, - x: x_val, + grad: grad_val_converted, + x: x_val_converted, mean: mean_val, var: var_val, scale: scale_val }) - grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad( - x_val, grad_val, scale_val, mean_val, var_val, epsilon, "NHWC") - - self.assertAllClose(grad_x_val, grad_x_ref, atol=1e-2) + self.assertAllClose(grad_x_val, grad_x_ref_converted, atol=1e-2) self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) - def testGradientInference(self): + @parameterized.named_parameters( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), + ("_data_format_HWNC", "HWNC"), + ("_data_format_HWCN", "HWCN"), + ) + def testGradientInference(self, data_format): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. channel = 3 @@ -204,33 +258,47 @@ class FusedBatchNormTest(XLATestCase): scale_val = np.random.random_sample(scale_shape).astype(np.float32) mean_val = np.random.random_sample(scale_shape).astype(np.float32) var_val = np.random.random_sample(scale_shape).astype(np.float32) + data_format_src = "NHWC" with self.test_session() as sess, self.test_scope(): - grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad") - x = array_ops.placeholder(np.float32, shape=x_shape, name="x") + grad_val_converted = test_utils.ConvertBetweenDataFormats( + grad_val, data_format_src, data_format) + x_val_converted = test_utils.ConvertBetweenDataFormats( + x_val, data_format_src, data_format) + + grad = array_ops.placeholder( + np.float32, shape=x_val_converted.shape, name="grad") + x = array_ops.placeholder( + np.float32, shape=x_val_converted.shape, name="x") mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean") var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") with self.test_scope(): out = gen_nn_ops.fused_batch_norm_grad( - grad, x, scale, mean, var, data_format="NHWC", is_training=False) + grad, + x, + scale, + mean, + var, + data_format=data_format, + is_training=False) grad_x, grad_scale, grad_offset, _, _ = out ref_x, ref_scale, ref_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( - grad, x, scale, mean, var, data_format="NHWC", is_training=False) + grad, x, scale, mean, var, data_format=data_format, is_training=False) grad_x_val, grad_scale_val, grad_offset_val, = sess.run( [grad_x, grad_scale, grad_offset], { - grad: grad_val, - x: x_val, + grad: grad_val_converted, + x: x_val_converted, mean: mean_val, var: var_val, scale: scale_val }) grad_x_ref, grad_scale_ref, grad_offset_ref, = sess.run( [ref_x, ref_scale, ref_offset], { - grad: grad_val, - x: x_val, + grad: grad_val_converted, + x: x_val_converted, mean: mean_val, var: var_val, scale: scale_val diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py index 9378b1db7245c0da3e8298e7dcd972491616b0cd..23b0aed34fb460f50c241e5a920cb4f6f613b947 100644 --- a/tensorflow/compiler/tests/gather_nd_op_test.py +++ b/tensorflow/compiler/tests/gather_nd_op_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class GatherNdTest(XLATestCase): +class GatherNdTest(xla_test.XLATestCase): def _runGather(self, params, indices): with self.test_session(): diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index 1a8c4519118f69ce51ca9a5eb95a9d706c7766cc..e9c8ef7c91a728b7dfc948fd9b315e6c9102f6a3 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -136,6 +136,20 @@ class GatherTest(xla_test.XLATestCase): self.assertAllEqual( [[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]})) + def testGatherPrecision(self): + with self.test_session() as session, self.test_scope(): + data = np.array([[0, 0, 0, 0], [0, 2 * (1 + np.exp2(-8)), 0, 0], + [0, 0, 0, 0], [0.015789, 0.0985, 0.55789, 0.3842]]) + indices = np.array([1, 2, 3, 1]) + dtype = dtypes.float32 + params_np = self._buildParams(data, dtype) + params = array_ops.placeholder(dtype=dtype) + indices_tf = constant_op.constant(indices) + gather_t = array_ops.gather(params, indices_tf) + gather_val = session.run(gather_t, feed_dict={params: params_np}) + np_val = params_np[indices] + self.assertAllEqual(np_val, gather_val) + class GatherBenchmark(test.Benchmark): """Microbenchmarks for the gather op.""" diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 7cf953ef25ef5daf8a6d4fc9985ed8dbfb2081e5..bf986ade06b11358552ee92df3169f965ce3f534 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -25,7 +25,8 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test +from tensorflow.python.compat import compat from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -41,7 +42,7 @@ def GenerateNumpyRandomRGB(shape): return np.random.randint(0, 256, shape) / 256. -class RGBToHSVTest(XLATestCase): +class RGBToHSVTest(xla_test.XLATestCase): def testBatch(self): # Build an arbitrary RGB image @@ -104,7 +105,7 @@ class RGBToHSVTest(XLATestCase): self.assertAllCloseAccordingToType(hsv_tf, hsv_np) -class AdjustContrastTest(XLATestCase): +class AdjustContrastTest(xla_test.XLATestCase): def _testContrast(self, x_np, y_np, contrast_factor): with self.test_session(): @@ -168,7 +169,7 @@ class AdjustContrastTest(XLATestCase): self.assertAllClose(y_tf, y_np, rtol=1e-5, atol=1e-5) -class AdjustHueTest(XLATestCase): +class AdjustHueTest(xla_test.XLATestCase): def testAdjustNegativeHue(self): x_shape = [2, 2, 3] @@ -303,7 +304,7 @@ class AdjustHueTest(XLATestCase): self._adjustHueTf(x_np, delta_h) -class AdjustSaturationTest(XLATestCase): +class AdjustSaturationTest(xla_test.XLATestCase): def _adjust_saturation(self, image, saturation_factor): image = ops.convert_to_tensor(image, name="image") @@ -403,7 +404,7 @@ class AdjustSaturationTest(XLATestCase): self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5) -class ResizeBilinearTest(XLATestCase): +class ResizeBilinearTest(xla_test.XLATestCase): def _assertForwardOpMatchesExpected(self, image_np, @@ -579,5 +580,140 @@ class ResizeBilinearTest(XLATestCase): large_tolerance=True) +class NonMaxSuppressionTest(xla_test.XLATestCase): + + def testNMS128From1024(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + with compat.forward_compatibility_horizon(2018, 8, 8): + num_boxes = 1024 + boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") + scores_np = np.random.normal(0.5, 0.1, (num_boxes,)).astype("f4") + + max_output_size = 128 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.0, dtype=np.float32) + + with self.test_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + score_threshold: score_threshold_np, + iou_threshold: iou_threshold_np + } + (indices_tf, _) = sess.run(selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + + def testNMS3From6Boxes(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + with compat.forward_compatibility_horizon(2018, 8, 8): + # Three boxes are selected based on IOU. + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.0, dtype=np.float32) + + with self.test_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + score_threshold: score_threshold_np, + iou_threshold: iou_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 3) + self.assertAllClose(indices_tf[:num_valid], [3, 0, 5]) + + def testNMS3Then2WithScoreThresh(self): + # Three boxes are selected based on IOU. + # One is filtered out by score threshold. + + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + with compat.forward_compatibility_horizon(2018, 8, 8): + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.4, dtype=np.float32) + + with self.test_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + iou_threshold: iou_threshold_np, + score_threshold: score_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 2) + self.assertAllClose(indices_tf[:num_valid], [3, 0]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py index 69bd8f7230d4394c45764d02a88fb0ec097c5756..253b45902fba2df64e5234f135b373cd2a0a7e2a 100644 --- a/tensorflow/compiler/tests/lrn_ops_test.py +++ b/tensorflow/compiler/tests/lrn_ops_test.py @@ -22,7 +22,7 @@ import copy import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -36,7 +36,7 @@ CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" # Local response normalization tests. The forward tests are copied from # tensorflow/python/kernel_tests/lrn_op_test.py -class LRNTest(XLATestCase): +class LRNTest(xla_test.XLATestCase): def _LRN(self, input_image, lrn_depth_radius=5, bias=1.0, alpha=1.0, beta=0.5): diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py index 29394f9ea5139b30f88f53de0469b27e37d79195..0d9f99f8a6803ecae5f9233518a1768109161ac0 100644 --- a/tensorflow/compiler/tests/matrix_band_part_test.py +++ b/tensorflow/compiler/tests/matrix_band_part_test.py @@ -19,14 +19,14 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class MatrixBandPartTest(XLATestCase): +class MatrixBandPartTest(xla_test.XLATestCase): def _testMatrixBandPart(self, dtype, shape): with self.test_session(): diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index 5819b2bf2b55b9213a039c0ba82dd0bf1c738b00..2bb8a97bdaf5836a05501ab9754433e29ae34675 100644 --- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -22,7 +22,7 @@ import itertools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -35,7 +35,7 @@ def MakePlaceholder(x): return array_ops.placeholder(dtypes.as_dtype(x.dtype), shape=x.shape) -class MatrixTriangularSolveOpTest(XLATestCase): +class MatrixTriangularSolveOpTest(xla_test.XLATestCase): # MatrixTriangularSolve defined for float64, float32, complex64, complex128 # (https://www.tensorflow.org/api_docs/python/tf/matrix_triangular_solve) diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py index af9394e7d7dc9cf7dd009420ff9c845aec8785bd..c2592c54cf83d41f0e3bdbc1f4dc9ff276ddb078 100644 --- a/tensorflow/compiler/tests/momentum_test.py +++ b/tensorflow/compiler/tests/momentum_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -30,7 +30,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import momentum as momentum_lib -class MomentumOptimizerTest(XLATestCase): +class MomentumOptimizerTest(xla_test.XLATestCase): def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum): var += accum * lr * momentum diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py index e4843b169b943b63346b783ddc50039030988ca5..da08225e9fc0d5a8ec21ee9961c4758fa38628b4 100644 --- a/tensorflow/compiler/tests/nary_ops_test.py +++ b/tensorflow/compiler/tests/nary_ops_test.py @@ -22,14 +22,14 @@ import unittest import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class NAryOpsTest(XLATestCase): +class NAryOpsTest(xla_test.XLATestCase): def _testNAry(self, op, args, expected, equality_fn=None): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py index 6f588d8ab562cb24f33c4c2987df22264aede027..2f9122645d3c5ccabc8130ac30a3f09cf4bc2de7 100644 --- a/tensorflow/compiler/tests/nullary_ops_test.py +++ b/tensorflow/compiler/tests/nullary_ops_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import googletest -class NullaryOpsTest(XLATestCase): +class NullaryOpsTest(xla_test.XLATestCase): def _testNullary(self, op, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py index 5e6d1313bd0336eba71fcf3658d949bd3342ae11..a75d99189b5b673261c9e48f1c5998ea0c575594 100644 --- a/tensorflow/compiler/tests/placeholder_test.py +++ b/tensorflow/compiler/tests/placeholder_test.py @@ -18,14 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest -class PlaceholderTest(XLATestCase): +class PlaceholderTest(xla_test.XLATestCase): def test_placeholder_with_default_default(self): with self.test_session() as sess, self.test_scope(): diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py index 4eed903963a34a253ea5c409782d9a89a97a4fdf..17f860db61aeda98326a6820771d67ee948b6dda 100644 --- a/tensorflow/compiler/tests/pooling_ops_3d_test.py +++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -41,7 +41,7 @@ def _AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding): padding=padding) -class Pooling3DTest(XLATestCase): +class Pooling3DTest(xla_test.XLATestCase): def _VerifyValues(self, pool_func, input_sizes, window, strides, padding, expected): @@ -187,8 +187,14 @@ class Pooling3DTest(XLATestCase): padding="VALID", expected=[29.5, 32.5, 50.5, 53.5, 176.5, 179.5, 197.5, 200.5]) - def _VerifyGradient(self, pool_func, pool_grad_func, input_sizes, ksize, - strides, padding): + def _VerifyGradient(self, + pool_func, + pool_grad_func, + input_sizes, + ksize, + strides, + padding, + pool_grad_grad_func=None): """Verifies the output values of the pooling gradient function. Args: @@ -198,6 +204,7 @@ class Pooling3DTest(XLATestCase): ksize: The kernel size dimensions strides: The stride dimensions padding: Padding type. + pool_grad_grad_func: Second-order gradient function, if available. """ ksize = [1] + ksize + [1] strides = [1] + strides + [1] @@ -218,6 +225,8 @@ class Pooling3DTest(XLATestCase): output_gradient_vals = np.arange( 1, output_vals.size + 1, dtype=np.float32) output_gradient_vals = output_gradient_vals.reshape(output_vals.shape) + output_grad_grad_vals = np.arange(1, x.size + 1, dtype=np.float32) + output_grad_grad_vals = output_grad_grad_vals.reshape(x.shape) # Use the Tensorflow CPU pooling gradient to compute the expected input # gradients. @@ -236,6 +245,22 @@ class Pooling3DTest(XLATestCase): {inputs: x, output_gradients: output_gradient_vals}) + output_grad_gradients = array_ops.placeholder( + dtypes.float32, shape=expected_input_gradient_vals.shape) + if pool_grad_grad_func is not None: + expected_grad_gradients = pool_grad_grad_func( + inputs, + outputs, + output_grad_gradients, + ksize=ksize, + strides=strides, + padding=padding, + data_format="NDHWC") + expected_grad_gradients_vals = sess.run(expected_grad_gradients, { + inputs: x, + output_grad_gradients: output_grad_grad_vals + }) + # Run the gradient op on the XLA device with self.test_scope(): outputs = array_ops.placeholder(dtypes.float32, shape=output_vals.shape) @@ -246,6 +271,16 @@ class Pooling3DTest(XLATestCase): ksize=ksize, strides=strides, padding=padding) + if pool_grad_grad_func is not None: + actual_grad_gradients = pool_grad_grad_func( + inputs, + outputs, + output_grad_gradients, + ksize=ksize, + strides=strides, + padding=padding, + data_format="NDHWC") + actual = sess.run(actual_input_gradients, { inputs: x, outputs: output_vals, @@ -260,6 +295,22 @@ class Pooling3DTest(XLATestCase): atol=1e-6) self.assertShapeEqual(actual, inputs) + if pool_grad_grad_func is not None: + actual_grad_gradients_vals = sess.run( + actual_grad_gradients, { + inputs: x, + outputs: output_vals, + output_grad_gradients: output_grad_grad_vals + }) + + # Compare the Tensorflow and XLA results. + self.assertAllClose( + expected_grad_gradients_vals, + actual_grad_gradients_vals, + rtol=1e-4, + atol=1e-6) + self.assertShapeEqual(actual_grad_gradients_vals, outputs) + def testMaxPoolGradValidPadding1_1_3d(self): self._VerifyGradient( nn_ops.max_pool3d, @@ -267,7 +318,8 @@ class Pooling3DTest(XLATestCase): input_sizes=[1, 3, 3, 3, 1], ksize=[1, 1, 1], strides=[1, 1, 1], - padding="VALID") + padding="VALID", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testMaxPoolGradValidPadding2_1_6_3d(self): self._VerifyGradient( @@ -276,9 +328,13 @@ class Pooling3DTest(XLATestCase): input_sizes=[2, 3, 3, 6, 3], ksize=[2, 2, 2], strides=[1, 1, 1], - padding="VALID") + padding="VALID", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testMaxPoolGradValidPadding2_1_7_3d(self): + # TODO(b/73062247): the bfloat16 implementation of MaxPool3DGradGrad does + # not have enough precision for this test case to pass if + # pool_grad_grad_func is passed. self._VerifyGradient( nn_ops.max_pool3d, gen_nn_ops.max_pool3d_grad, @@ -294,7 +350,8 @@ class Pooling3DTest(XLATestCase): input_sizes=[2, 2, 2, 2, 3], ksize=[2, 2, 2], strides=[2, 2, 2], - padding="VALID") + padding="VALID", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testMaxPoolGradSamePadding1_1_3d(self): self._VerifyGradient( @@ -303,7 +360,8 @@ class Pooling3DTest(XLATestCase): input_sizes=[2, 3, 2, 4, 1], ksize=[1, 1, 1], strides=[1, 1, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testMaxPoolGradSamePadding2_1_3d(self): self._VerifyGradient( @@ -312,7 +370,8 @@ class Pooling3DTest(XLATestCase): input_sizes=[2, 3, 2, 4, 1], ksize=[2, 2, 2], strides=[1, 1, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testMaxPoolGradSamePadding2_2_3d(self): self._VerifyGradient( @@ -321,7 +380,8 @@ class Pooling3DTest(XLATestCase): input_sizes=[2, 5, 2, 4, 3], ksize=[2, 2, 2], strides=[2, 2, 2], - padding="SAME") + padding="SAME", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testMaxPoolGradSamePadding3_1_3d(self): self._VerifyGradient( @@ -330,7 +390,8 @@ class Pooling3DTest(XLATestCase): input_sizes=[1, 3, 3, 7, 1], ksize=[3, 3, 3], strides=[1, 1, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad) def testAvgPoolGradValidPadding1_1_3d(self): self._VerifyGradient( diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py index fe270af3d636c0824621f36360ce9e7d14d8fc91..9fc94752ea660f7fb8b2c792180f01485ad04419 100644 --- a/tensorflow/compiler/tests/pooling_ops_test.py +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -69,7 +69,7 @@ def GetTestConfigs(): return test_configs -class PoolingTest(XLATestCase): +class PoolingTest(xla_test.XLATestCase): def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding, data_format, expected): @@ -288,7 +288,7 @@ class PoolingTest(XLATestCase): expected=expected_output) -class PoolGradTest(XLATestCase): +class PoolGradTest(xla_test.XLATestCase): CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" diff --git a/tensorflow/compiler/tests/powersign_test.py b/tensorflow/compiler/tests/powersign_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5fa7706d7294f2cffb7d24a56851be02d759335a --- /dev/null +++ b/tensorflow/compiler/tests/powersign_test.py @@ -0,0 +1,142 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for PowerSign.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.contrib.opt.python.training import powersign +from tensorflow.contrib.opt.python.training import sign_decay +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def py_linear_decay_fn(decay_steps): + def linear_decay(step): + step = min(step, decay_steps) + return float(decay_steps - step) / decay_steps + return linear_decay + + +def powersign_update_numpy(params, + g_t, + m, + lr, + base=math.e, + beta=0.9, + py_sign_decay_fn=None, + t=None): + m_t = beta * m + (1 - beta) * g_t + if py_sign_decay_fn is None: + sign_decayed = 1.0 + else: + sign_decayed = py_sign_decay_fn(t-1) + multiplier = base ** (sign_decayed * np.sign(g_t) * np.sign(m_t)) + params_t = params - lr * multiplier * g_t + return params_t, m_t + + +class PowerSignTest(xla_test.XLATestCase): + + def _testDense(self, + learning_rate=0.1, + sign_decay_fn=None, + py_sign_decay_fn=None, + base=math.e, + beta=0.9): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + # Initialize variables for numpy implementation. + m0, m1 = 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + global_step = resource_variable_ops.ResourceVariable(0, trainable=False) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = powersign.PowerSignOptimizer( + learning_rate=learning_rate, + base=base, + beta=beta, + sign_decay_fn=sign_decay_fn, + ) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]), + global_step=global_step) + + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 7 steps of powersign + # first 4 steps with positive gradient + # last 3 steps with negative gradient (sign(gm) should be -1) + for t in range(1, 8): + if t < 5: + update.run() + else: + neg_update.run() + + var0_np, m0 = powersign_update_numpy( + var0_np, + grads0_np if t < 5 else -grads0_np, + m0, + learning_rate, + base=base, + beta=beta, + py_sign_decay_fn=py_sign_decay_fn, + t=t, + ) + var1_np, m1 = powersign_update_numpy( + var1_np, + grads1_np if t < 5 else -grads1_np, + m1, + learning_rate, + base=base, + beta=beta, + py_sign_decay_fn=py_sign_decay_fn, + t=t, + ) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testDense(self): + decay_steps = 10 + sign_decay_fn = sign_decay.get_linear_decay_fn(decay_steps) + py_sign_decay_fn = py_linear_decay_fn(decay_steps) + self._testDense() + self._testDense(learning_rate=0.1, base=10.0, beta=0.8) + self._testDense( + sign_decay_fn=sign_decay_fn, py_sign_decay_fn=py_sign_decay_fn) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/tests/proximal_adagrad_test.py b/tensorflow/compiler/tests/proximal_adagrad_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cde87db63dbfd7c8d823c6fd0e41eee8b23735bb --- /dev/null +++ b/tensorflow/compiler/tests/proximal_adagrad_test.py @@ -0,0 +1,172 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Proximal Adagrad optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import adagrad +from tensorflow.python.training import proximal_adagrad + + +class ProximalAdagradOptimizerTest(xla_test.XLATestCase): + + def testResourceProximalAdagradwithoutRegularization(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([0.0, 0.0]) + var1 = resource_variable_ops.ResourceVariable([0.0, 0.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + opt = proximal_adagrad.ProximalAdagradOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([0.0, 0.0], var0.eval()) + self.assertAllClose([0.0, 0.0], var1.eval()) + + # Run 3 steps Proximal Adagrad. + for _ in range(3): + update.run() + + self.assertAllClose(np.array([-2.60260963, -4.29698515]), var0.eval()) + self.assertAllClose(np.array([-0.28432083, -0.56694895]), var1.eval()) + opt_vars = opt.variables() + self.assertStartsWith(opt_vars[0].name, var0._shared_name) + self.assertStartsWith(opt_vars[1].name, var1._shared_name) + self.assertEqual(2, len(opt_vars)) + + def testProximalAdagradwithoutRegularization2(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + opt = proximal_adagrad.ProximalAdagradOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 3 steps Proximal Adagrad. + for _ in range(3): + update.run() + self.assertAllClose(np.array([-1.60261, -2.296985]), var0.eval()) + self.assertAllClose(np.array([3.715679, 2.433051]), var1.eval()) + + def testProximalAdagradWithL1(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + opt = proximal_adagrad.ProximalAdagradOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=0.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 10 steps Proximal Adagrad + for _ in range(10): + update.run() + self.assertAllClose(np.array([-6.663634, -9.190331]), var0.eval()) + self.assertAllClose(np.array([2.959304, 1.029232]), var1.eval()) + + def testProximalAdagradWithL1_L2(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + opt = proximal_adagrad.ProximalAdagradOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=2.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 10 steps Proximal Adagrad. + for _ in range(10): + update.run() + + self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval()) + self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval()) + + def applyOptimizer(self, opt, steps=5): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run ProximalAdagrad for a few steps + for _ in range(steps): + update.run() + + return var0.eval(), var1.eval() + + def testEquivAdagradwithoutRegularization(self): + with self.test_session(), self.test_scope(): + val0, val1 = self.applyOptimizer( + proximal_adagrad.ProximalAdagradOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0)) + + with self.test_session(), self.test_scope(): + val2, val3 = self.applyOptimizer( + adagrad.AdagradOptimizer( + 3.0, initial_accumulator_value=0.1)) + + self.assertAllClose(val0, val2) + self.assertAllClose(val1, val3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/proximal_gradient_descent_test.py b/tensorflow/compiler/tests/proximal_gradient_descent_test.py new file mode 100644 index 0000000000000000000000000000000000000000..11eb76871133eba8fcd24621afb03e16614fb005 --- /dev/null +++ b/tensorflow/compiler/tests/proximal_gradient_descent_test.py @@ -0,0 +1,156 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Proximal Gradient Descent optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import proximal_gradient_descent + + +class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): + + def testResourceProximalGradientDescentwithoutRegularization(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([0.0, 0.0]) + var1 = resource_variable_ops.ResourceVariable([0.0, 0.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + opt = proximal_gradient_descent.ProximalGradientDescentOptimizer( + 3.0, l1_regularization_strength=0.0, l2_regularization_strength=0.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([0.0, 0.0], var0.eval()) + self.assertAllClose([0.0, 0.0], var1.eval()) + + # Run 3 steps Proximal Gradient Descent. + for _ in range(3): + update.run() + + self.assertAllClose(np.array([-0.9, -1.8]), var0.eval()) + self.assertAllClose(np.array([-0.09, -0.18]), var1.eval()) + + def testProximalGradientDescentwithoutRegularization2(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + opt = proximal_gradient_descent.ProximalGradientDescentOptimizer( + 3.0, l1_regularization_strength=0.0, l2_regularization_strength=0.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 3 steps Proximal Gradient Descent + for _ in range(3): + update.run() + + self.assertAllClose(np.array([0.1, 0.2]), var0.eval()) + self.assertAllClose(np.array([3.91, 2.82]), var1.eval()) + + def testProximalGradientDescentWithL1(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + opt = proximal_gradient_descent.ProximalGradientDescentOptimizer( + 3.0, l1_regularization_strength=0.001, l2_regularization_strength=0.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 10 steps proximal gradient descent. + for _ in range(10): + update.run() + + self.assertAllClose(np.array([-1.988, -3.988001]), var0.eval()) + self.assertAllClose(np.array([3.67, 2.37]), var1.eval()) + + def testProximalGradientDescentWithL1_L2(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + opt = proximal_gradient_descent.ProximalGradientDescentOptimizer( + 3.0, l1_regularization_strength=0.001, l2_regularization_strength=2.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 10 steps Proximal Gradient Descent + for _ in range(10): + update.run() + + self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval()) + self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval()) + + def applyOptimizer(self, opt, steps=5): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run ProximalAdagrad for a few steps + for _ in range(steps): + update.run() + + return var0.eval(), var1.eval() + + def testEquivGradientDescentwithoutRegularization(self): + with self.test_session(), self.test_scope(): + val0, val1 = self.applyOptimizer( + proximal_gradient_descent.ProximalGradientDescentOptimizer( + 3.0, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0)) + + with self.test_session(), self.test_scope(): + val2, val3 = self.applyOptimizer( + gradient_descent.GradientDescentOptimizer(3.0)) + + self.assertAllClose(val0, val2) + self.assertAllClose(val1, val3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1b969ee2b3886fca6ec9951d1621ca5af6a673d8 --- /dev/null +++ b/tensorflow/compiler/tests/qr_op_test.py @@ -0,0 +1,115 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.math_ops.matrix_inverse.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from absl.testing import parameterized +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): + + def AdjustedNorm(self, x): + """Computes the norm of matrices in 'x', adjusted for dimension and type.""" + norm = np.linalg.norm(x, axis=(-2, -1)) + return norm / (max(x.shape[-2:]) * np.finfo(x.dtype).eps) + + def CompareOrthogonal(self, x, y, rank): + # We only compare the first 'rank' orthogonal vectors since the + # remainder form an arbitrary orthonormal basis for the + # (row- or column-) null space, whose exact value depends on + # implementation details. Notice that since we check that the + # matrices of singular vectors are unitary elsewhere, we do + # implicitly test that the trailing vectors of x and y span the + # same space. + x = x[..., 0:rank] + y = y[..., 0:rank] + # Q is only unique up to sign (complex phase factor for complex matrices), + # so we normalize the sign first. + sum_of_ratios = np.sum(np.divide(y, x), -2, keepdims=True) + phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios)) + x *= phases + self.assertTrue(np.all(self.AdjustedNorm(x - y) < 30.0)) + + def CheckApproximation(self, a, q, r): + # Tests that a ~= q*r. + precision = self.AdjustedNorm(a - np.matmul(q, r)) + self.assertTrue(np.all(precision < 10.0)) + + def CheckUnitary(self, x): + # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity. + xx = math_ops.matmul(x, x, adjoint_a=True) + identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0) + precision = self.AdjustedNorm(xx.eval() - identity.eval()) + self.assertTrue(np.all(precision < 5.0)) + + def _test(self, dtype, shape, full_matrices): + np.random.seed(1) + x_np = np.random.uniform( + low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype) + + with self.test_session() as sess: + x_tf = array_ops.placeholder(dtype) + with self.test_scope(): + q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices) + q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np}) + + q_dims = q_tf_val.shape + np_q = np.ndarray(q_dims, dtype) + np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1])) + new_first_dim = np_q_reshape.shape[0] + + x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1])) + for i in range(new_first_dim): + if full_matrices: + np_q_reshape[i, :, :], _ = np.linalg.qr( + x_reshape[i, :, :], mode="complete") + else: + np_q_reshape[i, :, :], _ = np.linalg.qr( + x_reshape[i, :, :], mode="reduced") + np_q = np.reshape(np_q_reshape, q_dims) + self.CompareOrthogonal(np_q, q_tf_val, min(shape[-2:])) + self.CheckApproximation(x_np, q_tf_val, r_tf_val) + self.CheckUnitary(q_tf_val) + + SIZES = [1, 2, 5, 10, 32, 100, 300] + DTYPES = [np.float32] + PARAMS = itertools.product(SIZES, SIZES, DTYPES) + + @parameterized.parameters(*PARAMS) + def testQR(self, rows, cols, dtype): + # TODO(b/111317468): implement full_matrices=False, test other types. + for full_matrices in [True]: + # Only tests the (3, 2) case for small numbers of rows/columns. + for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10): + self._test(dtype, batch_dims + (rows, cols), full_matrices) + + def testLarge2000x2000(self): + self._test(np.float32, (2000, 2000), full_matrices=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index f13dff96203b5480480c2a2fc9ac38ca78b7f78a..8c4e16e4e075726d741f6ff8cdfb6b1aad6cd33e 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -18,17 +18,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import math + import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import special_math from tensorflow.python.platform import googletest -class RandomOpsTest(XLATestCase): +class RandomOpsTest(xla_test.XLATestCase): """Test cases for random-number generating operators.""" def _random_types(self): @@ -54,7 +57,8 @@ class RandomOpsTest(XLATestCase): def testRandomUniformIsNotConstant(self): def rng(dtype): - return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000) + dtype = dtypes.as_dtype(dtype) + return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=dtype.max) for dtype in self._random_types(): self._testRngIsNotConstant(rng, dtype) @@ -70,6 +74,11 @@ class RandomOpsTest(XLATestCase): def testRandomUniformIsInRange(self): for dtype in self._random_types(): + # TODO (b/112272078): enable bfloat16 for CPU and GPU when the bug is + # fixed. + if (self.device in ["XLA_GPU", "XLA_CPU" + ]) and (dtype in [dtypes.bfloat16, dtypes.half]): + continue with self.test_session() as sess: with self.test_scope(): x = random_ops.random_uniform( @@ -87,23 +96,63 @@ class RandomOpsTest(XLATestCase): self._testRngIsNotConstant(rng, dtypes.float32) def testTruncatedNormalIsInRange(self): - count = 10000 + count = 10000000 # TODO(b/34339814): implement inverse erf support for non-F32 types. for dtype in [dtypes.float32]: with self.test_session() as sess: with self.test_scope(): - x = random_ops.truncated_normal(shape=[count], dtype=dtype, seed=42) + x = random_ops.truncated_normal(shape=[count], dtype=dtype) y = sess.run(x) - self.assertTrue((y >= -2).sum() == count) - self.assertTrue((y <= 2).sum() == count) + + def normal_cdf(x): + return .5 * math.erfc(-x / math.sqrt(2)) + + def normal_pdf(x): + return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) + + def probit(x, sess=sess): + return sess.run(special_math.ndtri(x)) + + a = -2. + b = 2. + mu = 0. + sigma = 1. + + alpha = (a - mu) / sigma + beta = (b - mu) / sigma + z = normal_cdf(beta) - normal_cdf(alpha) + + self.assertTrue((y >= a).sum() == count) + self.assertTrue((y <= b).sum() == count) + + # For more information on these calculations, see: + # Burkardt, John. "The Truncated Normal Distribution". + # Department of Scientific Computing website. Florida State University. + expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma + actual_mean = np.mean(y) + self.assertAllClose(actual_mean, expected_mean, atol=2e-3) + + expected_median = mu + probit( + (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma + actual_median = np.median(y) + self.assertAllClose(actual_median, expected_median, atol=1e-2) + + expected_variance = sigma**2 * (1 + ( + (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - ( + (normal_pdf(alpha) - normal_pdf(beta)) / z)**2) + actual_variance = np.var(y) + self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3) def testShuffle1d(self): + # TODO(b/26783907): this test requires the CPU backend to implement sort. + if self.device in ["XLA_CPU"]: + return with self.test_session() as sess: with self.test_scope(): - x = math_ops.range(20) + x = math_ops.range(1 << 16) shuffle = random_ops.random_shuffle(x) result = sess.run(shuffle) - expected = range(20) + expected = range(1 << 16) # Compare sets to avoid randomness behavior changes but make sure still # have all the values. self.assertAllEqual(set(result), set(expected)) diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 16f293891d56d78885dd515bb7b9899faf0690f7..c0ea242044540b1cef44186880ba3cd92b8849d6 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -62,6 +62,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" @@ -101,6 +102,9 @@ class OpTestBuilder { OpTestBuilder& RandomInput(DataType type); OpTestBuilder& RandomInput(DataType type, std::vector dims); + // As RandomInput but the values are unique. + OpTestBuilder& RandomUniqueInput(DataType type, std::vector dims); + // Sets an attribute. template OpTestBuilder& Attr(StringPiece attr_name, T&& value); @@ -126,6 +130,7 @@ class OpTestBuilder { DataType type = DT_INVALID; bool has_dims = false; + bool needs_unique_values = false; std::vector dims; }; @@ -167,6 +172,18 @@ OpTestBuilder& OpTestBuilder::RandomInput(DataType type, return *this; } +OpTestBuilder& OpTestBuilder::RandomUniqueInput(DataType type, + std::vector dims) { + VLOG(1) << "Adding input: " << type << " " << TensorShape(dims).DebugString(); + InputDescription input; + input.type = type; + input.has_dims = true; + input.needs_unique_values = true; + input.dims = std::move(dims); + inputs_.push_back(input); + return *this; +} + template OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, T&& value) { AddNodeAttr(attr_name, std::forward(value), &node_def_); @@ -289,7 +306,8 @@ class OpTest : public ::testing::Test { // Returns a tensor filled with random but "reasonable" values from the middle // of the type's range. If the shape is omitted, a random shape is used. // TODO(phawkins): generalize this code to a caller-supplied distribution. - Tensor RandomTensor(DataType dtype, gtl::ArraySlice shape); + Tensor RandomTensor(DataType dtype, bool needs_unique_values, + gtl::ArraySlice shape); Tensor RandomTensor(DataType dtype); // Like RandomTensor, but uses values >= 0. @@ -432,49 +450,90 @@ std::vector OpTest::RandomDims(int min_rank, int max_rank, return dims; } -Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice shape) { +Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, + gtl::ArraySlice shape) { Tensor tensor(dtype, TensorShape(shape)); switch (dtype) { case DT_FLOAT: { + gtl::FlatSet already_generated; std::uniform_real_distribution distribution(-1.0f, 1.0f); - test::FillFn(&tensor, [this, &distribution](int i) -> float { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> float { + float generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } case DT_DOUBLE: { + gtl::FlatSet already_generated; std::uniform_real_distribution distribution(-1.0, 1.0); - test::FillFn(&tensor, [this, &distribution](int i) -> double { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> double { + double generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } case DT_COMPLEX64: { + gtl::FlatSet> already_generated; std::uniform_real_distribution distribution(-1.0f, 1.0f); - test::FillFn(&tensor, [this, &distribution](int i) { - return complex64(distribution(generator()), distribution(generator())); + test::FillFn(&tensor, [&](int i) { + complex64 generated; + do { + generated = + complex64(distribution(generator()), distribution(generator())); + } while ( + needs_unique_values && + !already_generated + .insert(std::make_pair(generated.real(), generated.imag())) + .second); + return generated; }); break; } case DT_INT32: { + gtl::FlatSet already_generated; std::uniform_int_distribution distribution(-(1 << 20), 1 << 20); - test::FillFn(&tensor, [this, &distribution](int i) -> int32 { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> int32 { + int32 generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } case DT_INT64: { + gtl::FlatSet already_generated; std::uniform_int_distribution distribution(-(1LL << 40), 1LL << 40); - test::FillFn(&tensor, [this, &distribution](int i) -> int64 { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> int64 { + int64 generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } case DT_BOOL: { + gtl::FlatSet already_generated; std::bernoulli_distribution distribution; - test::FillFn(&tensor, [this, &distribution](int i) -> bool { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> bool { + bool generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } @@ -485,7 +544,7 @@ Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice shape) { } Tensor OpTest::RandomTensor(DataType dtype) { - return RandomTensor(dtype, RandomDims()); + return RandomTensor(dtype, /*needs_unique_values=*/false, RandomDims()); } Tensor OpTest::RandomNonNegativeTensor(DataType dtype, @@ -761,7 +820,8 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( VLOG(1) << "Ignoring oversize dims."; return kInvalid; } - input_tensors.push_back(RandomTensor(input.type, dims)); + input_tensors.push_back( + RandomTensor(input.type, input.needs_unique_values, dims)); } VLOG(1) << "Input: " << input_tensors.back().DebugString(); } @@ -960,7 +1020,7 @@ TEST_F(OpTest, ArgMax) { std::uniform_int_distribution(-num_dims, num_dims)(generator()); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("ArgMax") - .RandomInput(DT_FLOAT, dims) + .RandomUniqueInput(DT_FLOAT, dims) .Input(test::AsScalar(reduce_dim)) .Attr("T", DT_FLOAT) .Attr("Tidx", DT_INT32) @@ -976,7 +1036,7 @@ TEST_F(OpTest, ArgMin) { std::uniform_int_distribution(-num_dims, num_dims)(generator()); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("ArgMin") - .RandomInput(DT_FLOAT, dims) + .RandomUniqueInput(DT_FLOAT, dims) .Input(test::AsScalar(reduce_dim)) .Attr("T", DT_FLOAT) .Attr("Tidx", DT_INT32) diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index 7420724bdbeab63b39542ada59328621febad895..cea2ec816f85e88b11e6e80c91c14fca9015f45c 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -22,7 +22,7 @@ import functools import itertools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.ops import array_ops @@ -30,7 +30,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class ReduceOpsTest(XLATestCase): +class ReduceOpsTest(xla_test.XLATestCase): def _testReduction(self, tf_reduce_fn, @@ -156,7 +156,7 @@ class ReduceOpsTest(XLATestCase): self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA) -class ReduceOpPrecisionTest(XLATestCase): +class ReduceOpPrecisionTest(xla_test.XLATestCase): def _testReduceSum(self, expected_result, diff --git a/tensorflow/compiler/tests/reduce_window_test.py b/tensorflow/compiler/tests/reduce_window_test.py index e78a63465b80644d8810d9fa7433653bc4639fed..c69b6837b0f88ced844faf3713a29a1c14c8790d 100644 --- a/tensorflow/compiler/tests/reduce_window_test.py +++ b/tensorflow/compiler/tests/reduce_window_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class ReduceWindowTest(XLATestCase): +class ReduceWindowTest(xla_test.XLATestCase): """Test cases for xla.reduce_window.""" def _reduce_window(self, operand, init, reducer, **kwargs): diff --git a/tensorflow/compiler/tests/reverse_ops_test.py b/tensorflow/compiler/tests/reverse_ops_test.py index 18fabca28c9817fc8517595fa1694a18399f54b0..d01c676e7c2fe705344f26818350c46c30451c67 100644 --- a/tensorflow/compiler/tests/reverse_ops_test.py +++ b/tensorflow/compiler/tests/reverse_ops_test.py @@ -21,14 +21,14 @@ from __future__ import print_function import itertools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class ReverseOpsTest(XLATestCase): +class ReverseOpsTest(xla_test.XLATestCase): def testReverseOneDim(self): shape = (7, 5, 9, 11) diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py index 1a5d05094e53cfecd9476d7d87f023e8a02d7458..ccfa63001653537c4d1b7140e3d745c126f9034b 100644 --- a/tensorflow/compiler/tests/reverse_sequence_op_test.py +++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ReverseSequenceTest(XLATestCase): +class ReverseSequenceTest(xla_test.XLATestCase): def _testReverseSequence(self, x, diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py index ecdce4f052bbe3eeae8697c02c891105103f4f69..ff8bbac911abe73f946464663984ff1626302882 100644 --- a/tensorflow/compiler/tests/rmsprop_test.py +++ b/tensorflow/compiler/tests/rmsprop_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -28,33 +28,104 @@ from tensorflow.python.platform import test from tensorflow.python.training import rmsprop -class RmspropTest(XLATestCase): +class RmspropTest(xla_test.XLATestCase): + + def _rmsprop_update_numpy(self, + var, + g, + mg, + rms, + mom, + lr, + decay=0.9, + momentum=0.0, + epsilon=1e-10, + centered=False): + rms_t = rms * decay + (1 - decay) * g * g + denom_t = rms_t + epsilon + if centered: + mg_t = mg * decay + (1 - decay) * g + denom_t -= mg_t * mg_t + else: + mg_t = mg + mom_t = momentum * mom + lr * g / np.sqrt(denom_t, dtype=denom_t.dtype) + var_t = var - mom_t + return var_t, mg_t, rms_t, mom_t def testBasic(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - rms_opt = rmsprop.RMSPropOptimizer(3.0) - rms_update = rms_opt.apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - - # Run 3 steps of RMSProp - for _ in range(3): - rms_update.run() - - # Validate updated params - self.assertAllCloseAccordingToType( - np.array([2.91705132e-04, 1.00029182e+00]), var0.eval()) - self.assertAllCloseAccordingToType( - np.array([2.89990854, 3.89990854]), var1.eval()) + for centered in [False, True]: + with self.test_session(), self.test_scope(): + # Initialize variables for numpy implementation. + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + mg0_np = np.array([0.0, 0.0], dtype=dtype) + mg1_np = np.array([0.0, 0.0], dtype=dtype) + rms0_np = np.array([1.0, 1.0], dtype=dtype) + rms1_np = np.array([1.0, 1.0], dtype=dtype) + mom0_np = np.array([0.0, 0.0], dtype=dtype) + mom1_np = np.array([0.0, 0.0], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + learning_rate = 3.0 + rms_opt = rmsprop.RMSPropOptimizer(learning_rate, centered=centered) + rms_update = rms_opt.apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + mg0 = rms_opt.get_slot(var0, "mg") + self.assertEqual(mg0 is not None, centered) + mg1 = rms_opt.get_slot(var1, "mg") + self.assertEqual(mg1 is not None, centered) + rms0 = rms_opt.get_slot(var0, "rms") + self.assertTrue(rms0 is not None) + rms1 = rms_opt.get_slot(var1, "rms") + self.assertTrue(rms1 is not None) + mom0 = rms_opt.get_slot(var0, "momentum") + self.assertTrue(mom0 is not None) + mom1 = rms_opt.get_slot(var1, "momentum") + self.assertTrue(mom1 is not None) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 3 steps of RMSProp + for _ in range(3): + rms_update.run() + + var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy( + var0_np, + grads0_np, + mg0_np, + rms0_np, + mom0_np, + learning_rate, + centered=centered) + var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy( + var1_np, + grads1_np, + mg1_np, + rms1_np, + mom1_np, + learning_rate, + centered=centered) + + # Validate updated params + if centered: + self.assertAllCloseAccordingToType(mg0_np, mg0.eval()) + self.assertAllCloseAccordingToType(mg1_np, mg1.eval()) + self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) + self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) + self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) + self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py index 3260e63b23226d736a7ddc0f21a94a8c791e0442..4292352e76ebcef7dbf41df7b857d2604a468117 100644 --- a/tensorflow/compiler/tests/scan_ops_test.py +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops @@ -69,7 +69,7 @@ def handle_options(func, x, axis, exclusive, reverse): return x -class CumsumTest(XLATestCase): +class CumsumTest(xla_test.XLATestCase): valid_dtypes = [np.float32] @@ -147,7 +147,7 @@ class CumsumTest(XLATestCase): math_ops.cumsum(input_tensor, [0]).eval() -class CumprodTest(XLATestCase): +class CumprodTest(xla_test.XLATestCase): valid_dtypes = [np.float32] diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py index 638946e234daf28dc4a34e6c33fc0f78b8e8699b..f606f88545d0b6f0b52cee9b93083a6bd91169bc 100644 --- a/tensorflow/compiler/tests/scatter_nd_op_test.py +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -22,7 +22,7 @@ import functools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -68,7 +68,7 @@ def _NumpyUpdate(indices, updates, shape): return _NumpyScatterNd(ref, indices, updates, lambda p, u: u) -class ScatterNdTest(XLATestCase): +class ScatterNdTest(xla_test.XLATestCase): def _VariableRankTest(self, np_scatter, diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py index 4a9c0e7471f9cdb2a47b54705495d2dda9748890..772c20fd424577c3e06eeae409f424b77b52aa8a 100644 --- a/tensorflow/compiler/tests/segment_reduction_ops_test.py +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -21,26 +21,40 @@ from __future__ import print_function import functools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class SegmentReductionOpsTest(XLATestCase): +class SegmentReductionOpsTest(xla_test.XLATestCase): """Test cases for segment reduction ops.""" - def UnsortedSegmentSum(self, data, indices, num_segments): + def _segmentReduction(self, op, data, indices, num_segments): with self.test_session() as sess, self.test_scope(): d = array_ops.placeholder(data.dtype, shape=data.shape) if isinstance(indices, int): i = array_ops.placeholder(np.int32, shape=[]) else: i = array_ops.placeholder(indices.dtype, shape=indices.shape) - return sess.run( - math_ops.unsorted_segment_sum(d, i, num_segments), - {d: data, - i: indices}) + return sess.run(op(d, i, num_segments), {d: data, i: indices}) + + def _unsortedSegmentSum(self, data, indices, num_segments): + return self._segmentReduction(math_ops.unsorted_segment_sum, data, indices, + num_segments) + + def _unsortedSegmentProd(self, data, indices, num_segments): + return self._segmentReduction(math_ops.unsorted_segment_prod, data, indices, + num_segments) + + def _unsortedSegmentMin(self, data, indices, num_segments): + return self._segmentReduction(math_ops.unsorted_segment_min, data, indices, + num_segments) + + def _unsortedSegmentMax(self, data, indices, num_segments): + return self._segmentReduction(math_ops.unsorted_segment_max, data, indices, + num_segments) def testUnsortedSegmentSum0DIndices1DData(self): for dtype in self.numeric_types: @@ -49,14 +63,14 @@ class SegmentReductionOpsTest(XLATestCase): [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 1, 2, 3, 4, 5], [0, 0, 0, 0, 0, 0]], dtype=dtype), - self.UnsortedSegmentSum( + self._unsortedSegmentSum( np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 2, 4)) def testUnsortedSegmentSum1DIndices1DData(self): for dtype in self.numeric_types: self.assertAllClose( np.array([1, 3, 2, 9], dtype=dtype), - self.UnsortedSegmentSum( + self._unsortedSegmentSum( np.array([0, 1, 2, 3, 4, 5], dtype=dtype), np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4)) @@ -64,7 +78,7 @@ class SegmentReductionOpsTest(XLATestCase): for dtype in self.numeric_types: self.assertAllClose( np.array([6, 3, 0, 6], dtype=dtype), - self.UnsortedSegmentSum( + self._unsortedSegmentSum( np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) @@ -76,7 +90,7 @@ class SegmentReductionOpsTest(XLATestCase): dtype=dtype) indices = np.array([8, 1, 0, 3, 7], dtype=np.int32) num_segments = 10 - y = self.UnsortedSegmentSum(data, indices, num_segments) + y = self._unsortedSegmentSum(data, indices, num_segments) self.assertAllClose( np.array( [[30, 31, 32, 33], [20, 21, 22, 23], [0, 0, 0, 0], @@ -92,7 +106,7 @@ class SegmentReductionOpsTest(XLATestCase): dtype=dtype) indices = np.array([0, 1, 2, 0, 1], dtype=np.int32) num_segments = 4 - y = self.UnsortedSegmentSum(data, indices, num_segments) + y = self._unsortedSegmentSum(data, indices, num_segments) self.assertAllClose( np.array( [[40, 42, 44, 46], [70, 72, 74, 76], [30, 31, 32, 33], @@ -102,30 +116,30 @@ class SegmentReductionOpsTest(XLATestCase): def testUnsortedSegmentSum2DIndices3DData(self): for dtype in self.numeric_types: data = np.array( - [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], - [[200, 201, 202], [210, 211, 212]], [[300, 301, 302], - [310, 311, 312]]], + [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], [[ + 200, 201, 202 + ], [210, 211, 212]], [[300, 301, 302], [310, 311, 312]]], dtype=dtype) indices = np.array([[3, 5], [3, 1], [5, 0], [6, 2]], dtype=np.int32) num_segments = 8 - y = self.UnsortedSegmentSum(data, indices, num_segments) + y = self._unsortedSegmentSum(data, indices, num_segments) self.assertAllClose( np.array( - [[210, 211, 212], [110, 111, 112], [310, 311, 312], - [100, 102, 104], [0, 0, 0.], [210, 212, 214], [300, 301, - 302], [0, 0, 0]], + [[210, 211, 212], [110, 111, 112], [310, 311, 312], [ + 100, 102, 104 + ], [0, 0, 0.], [210, 212, 214], [300, 301, 302], [0, 0, 0]], dtype=dtype), y) def testUnsortedSegmentSum1DIndices3DData(self): for dtype in self.numeric_types: data = np.array( - [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], - [[200, 201, 202], [210, 211, 212]], [[300, 301, 302], - [310, 311, 312]]], + [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], [[ + 200, 201, 202 + ], [210, 211, 212]], [[300, 301, 302], [310, 311, 312]]], dtype=dtype) indices = np.array([3, 0, 2, 5], dtype=np.int32) num_segments = 6 - y = self.UnsortedSegmentSum(data, indices, num_segments) + y = self._unsortedSegmentSum(data, indices, num_segments) self.assertAllClose( np.array( [[[100, 101, 102.], [110, 111, 112]], [[0, 0, 0], [0, 0, 0]], @@ -138,10 +152,40 @@ class SegmentReductionOpsTest(XLATestCase): data = np.ones((4, 8, 7), dtype=dtype) indices = np.ones((3, 2), dtype=np.int32) num_segments = 4 - self.assertRaises(ValueError, - functools.partial(self.UnsortedSegmentSum, data, - indices, num_segments)) + self.assertRaises( + ValueError, + functools.partial(self._segmentReduction, + math_ops.unsorted_segment_sum, data, indices, + num_segments)) + + def testUnsortedSegmentOps1DIndices1DDataNegativeIndices(self): + """Tests for min, max, and prod ops. + + These share most of their implementation with sum, so we only test basic + functionality. + """ + for dtype in self.numeric_types: + self.assertAllClose( + np.array([8, 3, 1, 0], dtype=dtype), + self._unsortedSegmentProd( + np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) + + for dtype in self.int_types | self.float_types: + minval = dtypes.as_dtype(dtype).min + maxval = dtypes.as_dtype(dtype).max + + self.assertAllClose( + np.array([2, 3, maxval, 0], dtype=dtype), + self._unsortedSegmentMin( + np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) + self.assertAllClose( + np.array([4, 3, minval, 6], dtype=dtype), + self._unsortedSegmentMax( + np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) -if __name__ == '__main__': +if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py index 305ca0c6b78d3ef985deb38816f9388e7983906b..6c4890565d2083a9493abc59bd563c4dd9fdb186 100644 --- a/tensorflow/compiler/tests/slice_ops_test.py +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -18,14 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class SliceTest(XLATestCase): +class SliceTest(xla_test.XLATestCase): def test1D(self): for dtype in self.numeric_types: @@ -110,7 +110,7 @@ class SliceTest(XLATestCase): self.assertAllEqual([[[1, 1, 1, 1], [6, 5, 4, 3]]], result) -class StridedSliceTest(XLATestCase): +class StridedSliceTest(xla_test.XLATestCase): def test1D(self): for dtype in self.numeric_types: diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7ff01be3cb4848d6bb85b8ab96b3ee1db6889791 --- /dev/null +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -0,0 +1,172 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sorting operators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + + +class XlaSortOpTest(xla_test.XLATestCase): + + def _assertOpOutputMatchesExpected(self, op, args, expected): + with self.test_session() as session: + with self.test_scope(): + placeholders = [ + array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) + for arg in args + ] + feeds = {placeholders[i]: args[i] for i in range(0, len(args))} + output = op(*placeholders) + if isinstance(output, ops.Tensor): + output = [output] + + results = session.run(output, feeds) + for result, v in zip(results, expected): + self.assertAllClose(v, result, rtol=1e-3) + + def testSort(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32]) + for dtype in supported_types.intersection(self.numeric_types): + x = np.arange(101, dtype=dtype) + np.random.shuffle(x) + self._assertOpOutputMatchesExpected( + xla.sort, [x], expected=[np.arange(101, dtype=dtype)]) + + def testTopK(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) + for dtype in supported_types.intersection(self.numeric_types): + # Use small input size for bfloat16. Otherwise, we'll get duplicate values + # after conversion to bfloat16, so the possible resulting index array is + # no longer unique. + if dtype == dtypes.bfloat16.as_numpy_dtype: + array_size = 20 + k_options = [0, 1, 2, 10, 20] + else: + array_size = 200 * 1000 + k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000] + for x in [np.arange(array_size)]: + np.random.shuffle(x) + for k in k_options: + indices = x.argsort()[::-1][:k] + + def topk(v, k=k): + return nn_ops.top_k(v, k=k, sorted=True) + + self._assertOpOutputMatchesExpected( + topk, [x.astype(dtype)], + expected=[x[indices].astype(dtype), indices]) + + def testTopK2D(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) + for dtype in supported_types.intersection(self.numeric_types): + # Use small input size for bfloat16. Otherwise, we'll get duplicate values + # after conversion to bfloat16, so the possible resulting index array is + # no longer unique. + if dtype == dtypes.bfloat16.as_numpy_dtype: + array_size = 10 + k_options = [0, 1, 2, 10] + else: + array_size = 200 * 1000 + k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000] + batch = 16 + for x in [np.arange(batch * array_size)]: + np.random.shuffle(x) + x = np.reshape(x, [batch, array_size]) + for k in k_options: + indices = x.argsort(axis=1)[::, -1:-k - 1:-1] + expected = np.sort(x, axis=1)[::, -1:-k - 1:-1] + + def topk(v, k=k): + return nn_ops.top_k(v, k=k, sorted=True) + + self._assertOpOutputMatchesExpected( + topk, [x.astype(dtype)], + expected=[expected.astype(dtype), indices]) + + def testTopKZeros(self): + """Tests that positive and negative zeros sort correctly.""" + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + # Only bfloat16 is implemented. + bfloat16 = dtypes.bfloat16.as_numpy_dtype + if bfloat16 not in self.numeric_types: + return + + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.bfloat16) + with self.test_scope(): + topk = nn_ops.top_k(p, k=4) + results = sess.run( + topk, + {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)}) + self.assertAllEqual( + np.array([3., 0., 0., 0.], dtype=bfloat16), results[0]) + self.assertEqual(list([3, 0, 2, 6]), list(results[1])) + + def testTopKInfinities(self): + """Tests that positive and negative infinity sort correctly.""" + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + # Only bfloat16 is implemented. + bfloat16 = dtypes.bfloat16.as_numpy_dtype + if bfloat16 not in self.numeric_types: + return + + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.bfloat16) + with self.test_scope(): + topk = nn_ops.top_k(p, k=6) + results = sess.run(topk, { + p: np.array( + [1, 2, float("inf"), -float("inf"), -1, -2], dtype=bfloat16) + }) + self.assertAllEqual( + np.array( + [float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")], + dtype=bfloat16), results[0]) + self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1])) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py index f37c34156f96761632247be4bc1b62fca54f666e..c685bc548f9f6f8f7723c6f94dfd45f5420b4a67 100644 --- a/tensorflow/compiler/tests/spacetobatch_op_test.py +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops @@ -68,7 +68,7 @@ def space_to_batch_direct(input_array, block_shape, paddings): return permuted_reshaped_padded.reshape(output_shape) -class SpaceToBatchTest(XLATestCase): +class SpaceToBatchTest(xla_test.XLATestCase): """Tests input-output pairs for the SpaceToBatch and BatchToSpace ops.""" def _testPad(self, inputs, paddings, block_size, outputs): @@ -149,7 +149,7 @@ class SpaceToBatchTest(XLATestCase): self._testOne(x_np, block_size, x_out) -class SpaceToBatchNDTest(XLATestCase): +class SpaceToBatchNDTest(xla_test.XLATestCase): """Tests input-output pairs for the SpaceToBatchND and BatchToSpaceND ops.""" def _testPad(self, inputs, block_shape, paddings, outputs): diff --git a/tensorflow/compiler/tests/sparse_to_dense_op_test.py b/tensorflow/compiler/tests/sparse_to_dense_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3db8101c4bfbb1b53c7318a36519612984d6f179 --- /dev/null +++ b/tensorflow/compiler/tests/sparse_to_dense_op_test.py @@ -0,0 +1,118 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.kernels.sparse_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +def _SparseToDense(sparse_indices, + output_size, + sparse_values, + default_value, + validate_indices=True): + feed_sparse_indices = array_ops.placeholder(dtypes.int32) + feed_dict = {feed_sparse_indices: sparse_indices} + return sparse_ops.sparse_to_dense( + feed_sparse_indices, + output_size, + sparse_values, + default_value=default_value, + validate_indices=validate_indices).eval(feed_dict=feed_dict) + + +class SparseToDenseTest(xla_test.XLATestCase): + + def testInt(self): + with self.test_session(), self.test_scope(): + tf_ans = _SparseToDense([1, 3], [5], 1, 0) + np_ans = np.array([0, 1, 0, 1, 0]).astype(np.int32) + self.assertAllClose(np_ans, tf_ans) + + def testFloat(self): + with self.test_session(), self.test_scope(): + tf_ans = _SparseToDense([1, 3], [5], 1.0, 0.0) + np_ans = np.array([0, 1, 0, 1, 0]).astype(np.float32) + self.assertAllClose(np_ans, tf_ans) + + def testSetValue(self): + with self.test_session(), self.test_scope(): + tf_ans = _SparseToDense([1, 3], [5], [1, 2], -1) + np_ans = np.array([-1, 1, -1, 2, -1]).astype(np.int32) + self.assertAllClose(np_ans, tf_ans) + + def testSetSingleValue(self): + with self.test_session(), self.test_scope(): + tf_ans = _SparseToDense([1, 3], [5], 1, -1) + np_ans = np.array([-1, 1, -1, 1, -1]).astype(np.int32) + self.assertAllClose(np_ans, tf_ans) + + def test2d(self): + # pylint: disable=bad-whitespace + with self.test_session(), self.test_scope(): + tf_ans = _SparseToDense([[1, 3], [2, 0]], [3, 4], 1, -1) + np_ans = np.array([[-1, -1, -1, -1], + [-1, -1, -1, 1], + [ 1, -1, -1, -1]]).astype(np.int32) + self.assertAllClose(np_ans, tf_ans) + + def testZeroDefault(self): + with self.test_session(): + x = sparse_ops.sparse_to_dense(2, [4], 7).eval() + self.assertAllEqual(x, [0, 0, 7, 0]) + + def test3d(self): + with self.test_session(), self.test_scope(): + tf_ans = _SparseToDense([[1, 3, 0], [2, 0, 1]], [3, 4, 2], 1, -1) + np_ans = np.ones((3, 4, 2), dtype=np.int32) * -1 + np_ans[1, 3, 0] = 1 + np_ans[2, 0, 1] = 1 + self.assertAllClose(np_ans, tf_ans) + + def testBadShape(self): + with self.test_session(), self.test_scope(): + with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"): + _SparseToDense([1, 3], [[5], [3]], 1, -1) + + def testBadValue(self): + with self.test_session(), self.test_scope(): + with self.assertRaisesOpError( + r"sparse_values has incorrect shape \[2,1\], " + r"should be \[\] or \[2\]"): + _SparseToDense([1, 3], [5], [[5], [3]], -1) + + def testBadNumValues(self): + with self.test_session(), self.test_scope(): + with self.assertRaisesOpError( + r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"): + _SparseToDense([1, 3], [5], [1, 2, 3], -1) + + def testBadDefault(self): + with self.test_session(), self.test_scope(): + with self.assertRaisesOpError("default_value should be a scalar"): + _SparseToDense([1, 3], [5], [1, 2], [0]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/stack_ops_test.py b/tensorflow/compiler/tests/stack_ops_test.py index 94342f9567ca71274609e63b0482d55637c98d51..b7dd787feff2b22a9cfb5d43a4ba6ceb6eb0b301 100644 --- a/tensorflow/compiler/tests/stack_ops_test.py +++ b/tensorflow/compiler/tests/stack_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -28,7 +28,7 @@ from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.platform import test -class StackOpTest(XLATestCase): +class StackOpTest(xla_test.XLATestCase): def testStackPushPop(self): with self.test_session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index b6f8390a45d43bf7666b90e14cc6ff2f3f61947e..d162675ef840131485128414b4a29e3cd89c8761 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -22,14 +22,15 @@ import math import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.contrib import stateless from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions import special_math from tensorflow.python.platform import test -class StatelessRandomOpsTest(XLATestCase): +class StatelessRandomOpsTest(xla_test.XLATestCase): """Test cases for stateless random-number generator operators.""" def _random_types(self): @@ -122,6 +123,56 @@ class StatelessRandomOpsTest(XLATestCase): # so to avoid flakiness the seed is fixed. self.assertTrue(self._anderson_darling(y) < 2.492) + def testTruncatedNormalIsInRange(self): + # TODO(b/34339814): implement inverse erf support for non-F32 types. + for dtype in [dtypes.float32]: + with self.test_session() as sess, self.test_scope(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + n = 10000000 + x = stateless.stateless_truncated_normal( + shape=[n], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + + def normal_cdf(x): + return .5 * math.erfc(-x / math.sqrt(2)) + + def normal_pdf(x): + return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) + + def probit(x, sess=sess): + return sess.run(special_math.ndtri(x)) + + a = -2. + b = 2. + mu = 0. + sigma = 1. + + alpha = (a - mu) / sigma + beta = (b - mu) / sigma + z = normal_cdf(beta) - normal_cdf(alpha) + + self.assertTrue((y >= a).sum() == n) + self.assertTrue((y <= b).sum() == n) + + # For more information on these calculations, see: + # Burkardt, John. "The Truncated Normal Distribution". + # Department of Scientific Computing website. Florida State University. + expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma + actual_mean = np.mean(y) + self.assertAllClose(actual_mean, expected_mean, atol=2e-4) + + expected_median = mu + probit( + (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma + actual_median = np.median(y) + self.assertAllClose(actual_median, expected_median, atol=8e-4) + + expected_variance = sigma**2 * (1 + ( + (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - ( + (normal_pdf(alpha) - normal_pdf(beta)) / z)**2) + actual_variance = np.var(y) + self.assertAllClose(actual_variance, expected_variance, rtol=1e-3) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index ef047005b60bd156a677050368ef67ae030d6c3a..effa5a59fee7dda543b2c409dfaa27a972a55808 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_math_ops @@ -28,7 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class TernaryOpsTest(XLATestCase): +class TernaryOpsTest(xla_test.XLATestCase): def _testTernary(self, op, a, b, c, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/test_utils.py b/tensorflow/compiler/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6abde18ea91f16d153a154b94effab037a911c6c --- /dev/null +++ b/tensorflow/compiler/tests/test_utils.py @@ -0,0 +1,63 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for helping test ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + + +def ConvertBetweenDataFormats(x, data_format_src, data_format_dst): + """Converts 4D tensor between data formats.""" + + valid_data_formats = ["NHWC", "NCHW", "HWNC", "HWCN"] + if data_format_src not in valid_data_formats: + raise ValueError("data_format_src must be of %s, got %s." % + (valid_data_formats, data_format_src)) + if data_format_dst not in valid_data_formats: + raise ValueError("data_format_dst must be of %s, got %s." % + (valid_data_formats, data_format_dst)) + if len(x.shape) != 4: + raise ValueError("x must be 4D, got shape %s." % x.shape) + + if data_format_src == data_format_dst: + return x + + dim_map = {d: i for i, d in enumerate(data_format_src)} + transpose_dims = [dim_map[d] for d in data_format_dst] + return np.transpose(x, transpose_dims) + + +def PermuteDimsBetweenDataFormats(dims, data_format_src, data_format_dst): + """Get new shape for converting between data formats.""" + + valid_data_formats = ["NHWC", "NCHW", "HWNC", "HWCN"] + if data_format_src not in valid_data_formats: + raise ValueError("data_format_src must be of %s, got %s." % + (valid_data_formats, data_format_src)) + if data_format_dst not in valid_data_formats: + raise ValueError("data_format_dst must be of %s, got %s." % + (valid_data_formats, data_format_dst)) + if len(dims) != 4: + raise ValueError("dims must be of length 4, got %s." % dims) + + if data_format_src == data_format_dst: + return dims + + dim_map = {d: i for i, d in enumerate(data_format_src)} + permuted_dims = [dims[dim_map[d]] for d in data_format_dst] + return permuted_dims diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 689a4a1f4e02f5dd48f64dc94afd0fcb50df8b5b..73adb0d243b3b27e6c6ba669b2fd134a5976a2ec 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -23,7 +23,7 @@ import unittest import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import bitwise_ops @@ -44,11 +44,16 @@ def nhwc_to_format(x, data_format): raise ValueError("Unknown format {}".format(data_format)) -class UnaryOpsTest(XLATestCase): +class UnaryOpsTest(xla_test.XLATestCase): """Test cases for unary operators.""" - def _assertOpOutputMatchesExpected(self, op, inp, expected, - equality_test=None, rtol=1e-3, atol=1e-5): + def _assertOpOutputMatchesExpected(self, + op, + inp, + expected, + equality_test=None, + rtol=1e-3, + atol=1e-5): """Verifies that 'op' produces 'expected' when fed input 'inp' . Args: @@ -81,10 +86,10 @@ class UnaryOpsTest(XLATestCase): def testAllTypeOps(self): for dtype in self.numeric_types: self._assertOpOutputMatchesExpected( - array_ops.diag, - np.array([1, 2, 3, 4], dtype=dtype), - np.array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], - dtype=dtype)) + array_ops.diag, np.array([1, 2, 3, 4], dtype=dtype), + np.array( + [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], + dtype=dtype)) self._assertOpOutputMatchesExpected( array_ops.diag_part, np.arange(36).reshape([2, 3, 2, 3]).astype(dtype), @@ -102,8 +107,7 @@ class UnaryOpsTest(XLATestCase): expected=np.array([[-1, 1]], dtype=dtype)) self._assertOpOutputMatchesExpected( - array_ops.matrix_diag, - np.array([[1, 2], [3, 4]], dtype=dtype), + array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype), np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype)) self._assertOpOutputMatchesExpected( array_ops.matrix_diag, np.array([1, 2, 3, 4], dtype=dtype), @@ -115,10 +119,10 @@ class UnaryOpsTest(XLATestCase): np.array( [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=dtype), np.array( - [[[[1, 0, 0], [0, 2, 0], [0, 0, 3]], - [[4, 0, 0], [0, 5, 0], [0, 0, 6]]], - [[[7, 0, 0], [0, 8, 0], [0, 0, 9]], - [[10, 0, 0], [0, 11, 0], [0, 0, 12]]]], + [[[[1, 0, 0], [0, 2, 0], [0, 0, 3]], [[4, 0, 0], [0, 5, 0], [ + 0, 0, 6 + ]]], [[[7, 0, 0], [0, 8, 0], [0, 0, 9]], [[10, 0, 0], [0, 11, 0], + [0, 0, 12]]]], dtype=dtype)) self._assertOpOutputMatchesExpected( array_ops.matrix_diag_part, @@ -159,36 +163,30 @@ class UnaryOpsTest(XLATestCase): continue x = np.arange(-0.90, 0.90, 0.25) self._assertOpOutputMatchesExpected( - math_ops.acos, - x.astype(dtype), - expected=np.arccos(x).astype(dtype)) + math_ops.acos, x.astype(dtype), expected=np.arccos(x).astype(dtype)) self._assertOpOutputMatchesExpected( - math_ops.asin, - x.astype(dtype), - expected=np.arcsin(x).astype(dtype)) + math_ops.asin, x.astype(dtype), expected=np.arcsin(x).astype(dtype)) x = np.arange(-3, 3).reshape(1, 3, 2) self._assertOpOutputMatchesExpected( - math_ops.atan, - x.astype(dtype), - expected=np.arctan(x).astype(dtype)) + math_ops.atan, x.astype(dtype), expected=np.arctan(x).astype(dtype)) self._assertOpOutputMatchesExpected( math_ops.acosh, np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array([0, 1.3169579, 1.76274717, 2.06343707], - dtype=dtype)) + expected=np.array( + [0, 1.3169579, 1.76274717, 2.06343707], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.asinh, np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array([0.88137359, 1.44363548, 1.81844646, 2.09471255], - dtype=dtype)) + expected=np.array( + [0.88137359, 1.44363548, 1.81844646, 2.09471255], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.atanh, np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype), - expected=np.array([0.10033535, 0.20273255, 0.3095196, 0.42364893], - dtype=dtype)) + expected=np.array( + [0.10033535, 0.20273255, 0.3095196, 0.42364893], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.ceil, @@ -198,8 +196,18 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.cosh, np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array([1.54308063, 3.76219569, 10.067662, 27.30823284], - dtype=dtype)) + expected=np.array( + [1.54308063, 3.76219569, 10.067662, 27.30823284], dtype=dtype)) + + # Disable float16 testing for now + if dtype != np.float16: + x = np.arange(-10, 10, 1).astype(dtype) + with self.test_session() as session: + erf_x = session.run(math_ops.erf(x)) + erfc_x = session.run(math_ops.erfc(x)) + + self._assertOpOutputMatchesExpected(math_ops.erf, x, expected=erf_x) + self._assertOpOutputMatchesExpected(math_ops.erfc, x, expected=erfc_x) self._assertOpOutputMatchesExpected( math_ops.exp, @@ -219,8 +227,8 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.is_finite, - np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], - dtype=dtype), + np.array( + [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype), expected=np.array([[0, 1, 1, 1, 1, 1, 1, 0, 0]], dtype=np.bool)) # Tests for tf.nn ops. @@ -261,16 +269,20 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.rint, - np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5], - [0.5, 1.5, 2.5, 3.5]], dtype=dtype), - expected=np.array([[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], - dtype=dtype)) + np.array( + [[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5], + [0.5, 1.5, 2.5, 3.5]], + dtype=dtype), + expected=np.array( + [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.round, - np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5], - [0.5, 1.5, 2.5, 3.5]], dtype=dtype), - expected=np.array([[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], - dtype=dtype)) + np.array( + [[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5], + [0.5, 1.5, 2.5, 3.5]], + dtype=dtype), + expected=np.array( + [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.rsqrt, @@ -279,10 +291,7 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.sigmoid, - np.array( - [[1, 1, 1, 1], - [1, 2, 3, 4]], - dtype=dtype), + np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), expected=np.array( [[0.7310586, 0.7310586, 0.7310586, 0.7310586], [0.7310586, 0.880797, 0.95257413, 0.98201376]], @@ -296,8 +305,8 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.sinh, np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array([1.17520119, 3.62686041, 10.01787493, 27.2899172], - dtype=dtype)) + expected=np.array( + [1.17520119, 3.62686041, 10.01787493, 27.2899172], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.sqrt, @@ -307,15 +316,12 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.tan, np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array([1.55740772, -2.18503986, -0.14254654, 1.15782128], - dtype=dtype)) + expected=np.array( + [1.55740772, -2.18503986, -0.14254654, 1.15782128], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.tanh, - np.array( - [[1, 1, 1, 1], - [1, 2, 3, 4]], - dtype=dtype), + np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), expected=np.array( [[0.76159418, 0.76159418, 0.76159418, 0.76159418], [0.76159418, 0.96402758, 0.99505478, 0.99932933]], @@ -323,10 +329,7 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( nn_ops.log_softmax, - np.array( - [[1, 1, 1, 1], - [1, 2, 3, 4]], - dtype=dtype), + np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), expected=np.array( [[-1.3862944, -1.3862944, -1.3862944, -1.3862944], [-3.4401896, -2.4401896, -1.4401897, -0.44018969]], @@ -360,20 +363,31 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( nn_ops.softmax, - np.array( - [[1, 1, 1, 1], - [1, 2, 3, 4]], - dtype=dtype), + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array([0.032058604, 0.087144323, 0.23688284, 0.64391428], + dtype=dtype)) + + self._assertOpOutputMatchesExpected( + nn_ops.softmax, + np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), expected=np.array( [[0.25, 0.25, 0.25, 0.25], [0.032058604, 0.087144323, 0.23688284, 0.64391428]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + nn_ops.softmax, + np.array([[[1, 1], [1, 1]], [[1, 2], [3, 4]]], dtype=dtype), + expected=np.array( + [[[0.5, 0.5], [0.5, 0.5]], + [[0.26894142, 0.73105858], [0.26894142, 0.73105858]]], + dtype=dtype)) + self._assertOpOutputMatchesExpected( nn_ops.softsign, np.array([[-2, -1, 0, 1, 2]], dtype=dtype), - expected=np.array([[-0.66666669, -0.5, 0, 0.5, 0.66666669]], - dtype=dtype)) + expected=np.array( + [[-0.66666669, -0.5, 0, 0.5, 0.66666669]], dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.is_finite, @@ -383,9 +397,78 @@ class UnaryOpsTest(XLATestCase): [[True, False, True], [False, True, True]], dtype=np.bool)) self._assertOpOutputMatchesExpected( - lambda x: array_ops.quantize_and_dequantize_v2(x, -127, 127, True, 8), + math_ops.lgamma, + np.array( + [[1, 2, 3], [4, 5, 6], [1 / 2, 3 / 2, 5 / 2], + [-3 / 2, -7 / 2, -11 / 2]], + dtype=dtype), + expected=np.array( + [ + [0, 0, np.log(2.0)], + [np.log(6.0), np.log(24.0), + np.log(120)], + [ + np.log(np.pi) / 2, + np.log(np.pi) / 2 - np.log(2), + np.log(np.pi) / 2 - np.log(4) + np.log(3) + ], + [ + np.log(np.pi) / 2 - np.log(3) + np.log(4), + np.log(np.pi) / 2 - np.log(105) + np.log(16), + np.log(np.pi) / 2 - np.log(10395) + np.log(64), + ], + ], + dtype=dtype)) + + self._assertOpOutputMatchesExpected( + math_ops.digamma, + np.array( + [[1.0, 0.5, 1 / 3.0], [0.25, 1 / 6.0, 0.125], [2.0, 3.0, 4.0], + [6.0, 8.0, 9.0]], + dtype=dtype), + expected=np.array( + [ + [ + -np.euler_gamma, -2 * np.log(2) - np.euler_gamma, + -np.pi / 2 / np.sqrt(3) - 3 * np.log(3) / 2 - + np.euler_gamma + ], + [ + -np.pi / 2 - 3 * np.log(2) - np.euler_gamma, + -np.pi * np.sqrt(3) / 2 - 2 * np.log(2) - + 3 * np.log(3) / 2 - np.euler_gamma, + -np.pi / 2 - 4 * np.log(2) - + (np.pi + np.log(2 + np.sqrt(2)) - np.log(2 - np.sqrt(2))) + / np.sqrt(2) - np.euler_gamma + ], + [ + 1 - np.euler_gamma, 1.5 - np.euler_gamma, + 11 / 6.0 - np.euler_gamma + ], + [ + 137 / 60.0 - np.euler_gamma, 363 / 140.0 - np.euler_gamma, + 761 / 280.0 - np.euler_gamma + ], + ], + dtype=dtype)) + + def quantize_and_dequantize_v2(x): + return array_ops.quantize_and_dequantize_v2( + x, -127, 127, signed_input=True, num_bits=8) + + self._assertOpOutputMatchesExpected( + quantize_and_dequantize_v2, np.array([-1, -0.5, 0, 0.3], dtype=dtype), - expected=np.array([-1, -64.0 / 127, 0, 38.0 / 127], dtype=dtype)) + expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype)) + + def quantize_and_dequantize_v3(x): + return array_ops.quantize_and_dequantize_v3( + x, -127, 127, num_bits=8, signed_input=True, range_given=False) + + self._assertOpOutputMatchesExpected( + quantize_and_dequantize_v3, + np.array([-1, -0.5, 0, 0.3], dtype=dtype), + expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype)) def testComplexOps(self): for dtype in self.complex_types: @@ -566,13 +649,13 @@ class UnaryOpsTest(XLATestCase): for dtype in self.float_types: self._assertOpOutputMatchesExpected( math_ops.is_inf, - np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], - dtype=dtype), + np.array( + [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype), expected=np.array([[1, 0, 0, 0, 0, 0, 0, 1, 0]], dtype=np.bool)) self._assertOpOutputMatchesExpected( math_ops.is_nan, - np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], - dtype=dtype), + np.array( + [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype), expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.bool)) def testLogicalOps(self): @@ -589,14 +672,15 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( lambda x: gen_nn_ops.bias_add_grad(x, data_format="NCHW"), - np.array([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], - dtype=np.float32), + np.array( + [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32), expected=np.array([10., 26.], dtype=np.float32)) def testCast(self): shapes = [[], [4], [2, 3], [2, 0, 4]] - types = (set([dtypes.bool, dtypes.int32, dtypes.float32]) | - self.complex_tf_types) + types = ( + set([dtypes.bool, dtypes.int32, dtypes.float32]) + | self.complex_tf_types) for shape in shapes: for src_type in types: for dst_type in types: @@ -638,14 +722,11 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( rank_op, dtype(7), expected=np.int32(0)) self._assertOpOutputMatchesExpected( - rank_op, np.array( - [[], []], dtype=dtype), expected=np.int32(2)) + rank_op, np.array([[], []], dtype=dtype), expected=np.int32(2)) self._assertOpOutputMatchesExpected( - rank_op, np.array( - [-1, 1], dtype=dtype), expected=np.int32(1)) + rank_op, np.array([-1, 1], dtype=dtype), expected=np.int32(1)) self._assertOpOutputMatchesExpected( - rank_op, np.array( - [[-1, 1]], dtype=dtype), expected=np.int32(2)) + rank_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2)) self._assertOpOutputMatchesExpected( rank_op, np.array([[-1], [1], [4]], dtype=dtype), @@ -710,97 +791,97 @@ class UnaryOpsTest(XLATestCase): equality_test=self.ListsAreClose) def testDepthToSpace(self): + def make_op(data_format): + def op(x): - return array_ops.depth_to_space(x, block_size=2, - data_format=data_format) + return array_ops.depth_to_space( + x, block_size=2, data_format=data_format) + return op for dtype in self.numeric_types: for data_format in ["NCHW", "NHWC"]: self._assertOpOutputMatchesExpected( make_op(data_format), - nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype), - data_format), - expected=nhwc_to_format(np.array([[[[1], [2]], - [[3], [4]]]], dtype=dtype), - data_format)) + nhwc_to_format( + np.array([[[[1, 2, 3, 4]]]], dtype=dtype), data_format), + expected=nhwc_to_format( + np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype), data_format)) self._assertOpOutputMatchesExpected( make_op(data_format), nhwc_to_format( - np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], - dtype=dtype), + np.array( + [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype), data_format), expected=nhwc_to_format( - np.array([[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]], - dtype=dtype), - data_format)) + np.array( + [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]], + dtype=dtype), data_format)) self._assertOpOutputMatchesExpected( make_op(data_format), nhwc_to_format( - np.array([[[[1, 2, 3, 4], - [5, 6, 7, 8]], - [[9, 10, 11, 12], - [13, 14, 15, 16]]]], dtype=dtype), - data_format), + np.array( + [[[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], + [13, 14, 15, 16]]]], + dtype=dtype), data_format), expected=nhwc_to_format( - np.array([[[[1], [2], [5], [6]], - [[3], [4], [7], [8]], - [[9], [10], [13], [14]], - [[11], [12], [15], [16]]]], dtype=dtype), - data_format)) + np.array( + [[[[1], [2], [5], [6]], [[3], [4], [7], [8]], + [[9], [10], [13], [14]], [[11], [12], [15], [16]]]], + dtype=dtype), data_format)) def testSpaceToDepth(self): + def make_op(data_format): + def op(x): - return array_ops.space_to_depth(x, block_size=2, - data_format=data_format) + return array_ops.space_to_depth( + x, block_size=2, data_format=data_format) + return op for dtype in self.numeric_types: for data_format in ["NCHW", "NHWC"]: self._assertOpOutputMatchesExpected( make_op(data_format), - nhwc_to_format(np.array([[[[1], [2]], - [[3], [4]]]], dtype=dtype), - data_format), - expected=nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype), - data_format)) + nhwc_to_format( + np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype), data_format), + expected=nhwc_to_format( + np.array([[[[1, 2, 3, 4]]]], dtype=dtype), data_format)) self._assertOpOutputMatchesExpected( make_op(data_format), - nhwc_to_format(np.array([[[[1, 2, 3], [4, 5, 6]], - [[7, 8, 9], [10, 11, 12]]]], dtype=dtype), - data_format), + nhwc_to_format( + np.array( + [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]], + dtype=dtype), data_format), expected=nhwc_to_format( - np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], - dtype=dtype), + np.array( + [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype), data_format)) self._assertOpOutputMatchesExpected( make_op(data_format), - nhwc_to_format(np.array([[[[1], [2], [5], [6]], - [[3], [4], [7], [8]], - [[9], [10], [13], [14]], - [[11], [12], [15], [16]]]], dtype=dtype), - data_format), + nhwc_to_format( + np.array( + [[[[1], [2], [5], [6]], [[3], [4], [7], [8]], + [[9], [10], [13], [14]], [[11], [12], [15], [16]]]], + dtype=dtype), data_format), expected=nhwc_to_format( - np.array([[[[1, 2, 3, 4], - [5, 6, 7, 8]], - [[9, 10, 11, 12], - [13, 14, 15, 16]]]], dtype=dtype), - data_format)) + np.array( + [[[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], + [13, 14, 15, 16]]]], + dtype=dtype), data_format)) def _assertSoftplusMatchesExpected(self, features, dtype): features = np.array(features, dtype=dtype) zero = np.asarray(0).astype(dtype) expected = np.logaddexp(zero, features) self._assertOpOutputMatchesExpected( - nn_ops.softplus, features, expected=expected, - rtol=1e-6, - atol=9.1e-6) + nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6) def testSoftplus(self): for dtype in self.float_types: @@ -814,9 +895,10 @@ class UnaryOpsTest(XLATestCase): one = dtype(1) ten = dtype(10) self._assertSoftplusMatchesExpected([ - log_eps, log_eps - one, log_eps + one, log_eps - ten, - log_eps + ten, -log_eps, -log_eps - one, -log_eps + one, - -log_eps - ten, -log_eps + ten], dtype) + log_eps, log_eps - one, log_eps + one, log_eps - ten, log_eps + ten, + -log_eps, -log_eps - one, -log_eps + one, -log_eps - ten, + -log_eps + ten + ], dtype) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index 2c09b03d5a35cde2c42d8a145781270c0c908587..dd2c252d383bca9c59033ac07e442b487e4975a6 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -20,12 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_state_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -36,7 +37,7 @@ from tensorflow.python.platform import googletest from tensorflow.python.training.gradient_descent import GradientDescentOptimizer -class VariableOpsTest(XLATestCase): +class VariableOpsTest(xla_test.XLATestCase): """Test cases for resource variable operators.""" def testOneWriteOneOutput(self): @@ -52,9 +53,7 @@ class VariableOpsTest(XLATestCase): with ops.control_dependencies([x]): y = v.read_value() self.assertAllClose( - np.array([[2, 1 + 2j], [4, 5]]).astype(dtype), sess.run(y, { - p: 1 - })) + np.array([[2, 1 + 2j], [4, 5]]).astype(dtype), sess.run(y, {p: 1})) def testSparseRead0DIndices(self): for dtype in self.numeric_types: @@ -103,9 +102,9 @@ class VariableOpsTest(XLATestCase): x = v.sparse_read([[2, 1], [3, 0]]) self.assertAllClose( np.array( - [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]], - [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]], - ).astype(dtype), sess.run(x)) + [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]] + ], [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]] + ],).astype(dtype), sess.run(x)) def testShape(self): for dtype in self.numeric_types: @@ -206,6 +205,206 @@ class VariableOpsTest(XLATestCase): self.assertAllClose(update, result[1]) self.assertAllClose(update, result[2]) + def testScatterAdd(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[2, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[1], [7]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_add( + handle, [0], constant_op.constant([[2]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertAllEqual(sess.run(read), [[3], [7]]) + + def testScatterSub(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[2, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[4], [1]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_sub( + handle, [1], constant_op.constant([[2]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertAllEqual(sess.run(read), [[4], [-1]]) + + def testScatterMul(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[1]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_mul( + handle, [0], constant_op.constant([[5]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[5]]) + + def testScatterDiv(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_div( + handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertAllEqual(sess.run(read), [[2]]) + + def testScatterMin(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_min( + handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[3]]) + + def testScatterMax(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_max( + handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[6]]) + + def testScatterUpdate(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_update( + handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[3]]) + + def testScatterAddScalar(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[1]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_add( + handle, [0], constant_op.constant(2, dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[3]]) + + def testScatterSubScalar(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[1]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_sub( + handle, [0], constant_op.constant(2, dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[-1]]) + + def testScatterMulScalar(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[1]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_mul( + handle, [0], constant_op.constant(5, dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[5]]) + + def testScatterDivScalar(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_div( + handle, [0], constant_op.constant(3, dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[2]]) + + def testScatterMinScalar(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_min( + handle, [0], constant_op.constant(3, dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[3]]) + + def testScatterMaxScalar(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[6]], dtype=dtypes.int32))) + sess.run( + resource_variable_ops.resource_scatter_max( + handle, [0], constant_op.constant(3, dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(sess.run(read), [[6]]) + + def testScatterNdAddOps(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.float32, shape=[8]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([1] * 8, dtype=dtypes.float32))) + indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32) + updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32) + expected = np.array([1, 12, 1, 11, 10, 1, 1, 13]) + sess.run(gen_state_ops.resource_scatter_nd_add(handle, indices, updates)) + read = resource_variable_ops.read_variable_op( + handle, dtype=dtypes.float32) + self.assertAllClose(expected, sess.run(read)) + + def testScatterNdUpdateAddOps(self): + with self.test_session() as sess, self.test_scope(): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.float32, shape=[8]) + sess.run( + resource_variable_ops.assign_variable_op( + handle, constant_op.constant([1] * 8, dtype=dtypes.float32))) + indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32) + updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32) + expected = np.array([1, 11, 1, 10, 9, 1, 1, 12]) + sess.run( + gen_state_ops.resource_scatter_nd_update(handle, indices, updates)) + read = resource_variable_ops.read_variable_op( + handle, dtype=dtypes.float32) + self.assertAllClose(expected, sess.run(read)) + class StridedSliceAssignChecker(object): """Compares the results of a slice assignment using Tensorflow and numpy.""" @@ -236,12 +435,12 @@ class StridedSliceAssignChecker(object): self.test.assertAllEqual(val, valnp) -class SliceAssignTest(XLATestCase): +class SliceAssignTest(xla_test.XLATestCase): def testSliceAssign(self): for dtype in self.numeric_types: - checker = StridedSliceAssignChecker(self, [[1, 2, 3], [4, 5, 6]], - dtype=dtype) + checker = StridedSliceAssignChecker( + self, [[1, 2, 3], [4, 5, 6]], dtype=dtype) # No-op assignment checker[:] = [[10, 20, 30], [40, 50, 60]] # Checks trivial (1,1) shape tensor diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py index f79eb27435cc954cebde4357c1d946a320f4ed75..b637cf31cfc303ebe84ce8307ef4ad8b0b5cd720 100644 --- a/tensorflow/compiler/tests/while_test.py +++ b/tensorflow/compiler/tests/while_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -29,7 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class WhileTest(XLATestCase): +class WhileTest(xla_test.XLATestCase): def testSingletonLoopHandrolled(self): # Define a function for the loop body diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index f0b010fa67f2ffb3f81fd14d4d89585f716b4890..85084bb1240cf05f6eabfbea772df113cabe613c 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -20,14 +20,16 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.platform import test -class XlaDeviceTest(XLATestCase): +class XlaDeviceTest(xla_test.XLATestCase): def testCopies(self): """Tests that copies onto and off XLA devices work.""" @@ -47,6 +49,34 @@ class XlaDeviceTest(XLATestCase): result = sess.run(z, {x: inputs}) self.assertAllCloseAccordingToType(result, inputs + inputs) + def testCopiesOfUnsupportedTypesFailGracefully(self): + """Tests that copies of unsupported types don't crash.""" + test_types = set([ + np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16, np.int32, + np.int64, np.float16, np.float32, np.float16, + dtypes.bfloat16.as_numpy_dtype + ]) + shape = (10, 10) + for unsupported_dtype in test_types - self.all_types: + with self.test_session() as sess: + with ops.device("CPU"): + x = array_ops.placeholder(unsupported_dtype, shape) + with self.test_scope(): + y, = array_ops.identity_n([x]) + with ops.device("CPU"): + z = array_ops.identity(y) + + inputs = np.random.randint(-100, 100, shape) + inputs = inputs.astype(unsupported_dtype) + # Execution should either succeed or raise an InvalidArgumentError, + # but not crash. Even "unsupported types" may succeed here since some + # backends (e.g., the CPU backend) are happy to handle buffers of + # unsupported types, even if they cannot compute with them. + try: + sess.run(z, {x: inputs}) + except errors.InvalidArgumentError: + pass + def testControlTrigger(self): with self.test_session() as sess: with self.test_scope(): diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index e924fe1e61454aefda622a5a46a0e483d26db5c1..88827cb53bee7bb809d0163d6badcef17e59aa78 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -49,6 +49,32 @@ flags.DEFINE_string('tf_xla_flags', None, 'Value to set the TF_XLA_FLAGS environment variable to') +def parse_disabled_manifest(manifest_content): + comments_re = re.compile('#.*$') + disabled_tests = [] + disabled_method_types = [] + for l in manifest_content.splitlines(): + stripped = comments_re.sub('', l).strip() + if not stripped: + continue + entry = stripped.split(' ') + if len(entry) == 1: + disabled_tests.append(entry[0]) + elif len(entry) == 2: + disabled_method_types.append((entry[0], entry[1].strip().split(','))) + else: + raise ValueError('Bad entry in manifest file.') + + disabled_regex = '|'.join(disabled_tests) + method_types_filter = dict() + for method, types in disabled_method_types: + method_types_filter[method] = set([ + dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype + for name in types + ]) + return disabled_regex, method_types_filter + + class XLATestCase(test.TestCase): """XLA test cases are parameterized test cases.""" @@ -85,38 +111,21 @@ class XLATestCase(test.TestCase): # Parse the manifest file, if any, into a regex identifying tests to # disable - self.disabled_regex = None - self._method_types_filter = dict() # TODO(xpan): Make it text proto if it doesn't scale. # Each line of the manifest file specifies an entry. The entry can be # 1) TestNameRegex // E.g. CumprodTest.* Or # 2) TestName TypeName // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16 # The 1) disables the entire test. While 2) only filter some numeric types # so that they are not used in those tests. + self.disabled_regex = None + self._method_types_filter = {} if FLAGS.disabled_manifest is not None: - comments_re = re.compile('#.*$') - manifest_file = open(FLAGS.disabled_manifest, 'r') - disabled_tests = [] - disabled_method_types = [] - for l in manifest_file.read().splitlines(): - if not l: - continue - entry = comments_re.sub('', l).strip().split(' ') - if len(entry) == 1: - disabled_tests.append(entry[0]) - elif len(entry) == 2: - disabled_method_types.append( - (entry[0], entry[1].strip().split(','))) - else: - raise ValueError('Bad entry in manifest file.') - - self.disabled_regex = re.compile('|'.join(disabled_tests)) - for method, types in disabled_method_types: - self._method_types_filter[method] = set([ - dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype - for name in types]) - manifest_file.close() + with open(FLAGS.disabled_manifest, 'r') as manifest_file: + disabled_regex, self._method_types_filter = ( + parse_disabled_manifest(manifest_file.read())) + if disabled_regex: + self.disabled_regex = re.compile(disabled_regex) if FLAGS.tf_xla_flags is not None: os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags diff --git a/tensorflow/compiler/tests/xla_test_test.py b/tensorflow/compiler/tests/xla_test_test.py new file mode 100644 index 0000000000000000000000000000000000000000..24664451579445edaadb335c30d253ee55f003da --- /dev/null +++ b/tensorflow/compiler/tests/xla_test_test.py @@ -0,0 +1,44 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the XLATestCase test fixture base class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.platform import test + + +class XlaTestCaseTestCase(test.TestCase): + + def testManifestEmptyLineDoesNotCatchAll(self): + manifest = """ +testCaseOne +""" + disabled_regex, _ = xla_test.parse_disabled_manifest(manifest) + self.assertEqual(disabled_regex, "testCaseOne") + + def testManifestWholeLineCommentDoesNotCatchAll(self): + manifest = """# I am a comment +testCaseOne +testCaseTwo +""" + disabled_regex, _ = xla_test.parse_disabled_manifest(manifest) + self.assertEqual(disabled_regex, "testCaseOne|testCaseTwo") + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index cd57452302fcbde37d79ce760a80615a76d7ad8c..fda32c8a1c9491e0dadceec0d7265e1002d41528 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -81,7 +81,7 @@ cc_library( "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -91,6 +91,22 @@ cc_library( ], ) +cc_library( + name = "cpu_function_runtime", + srcs = ["cpu_function_runtime.cc"], + hdrs = ["cpu_function_runtime.h"], + visibility = [ + "//tensorflow/compiler/aot:__pkg__", + "//tensorflow/compiler/xla/service/cpu:__pkg__", + ], + deps = [ + # Keep dependencies to a minimum here; this library is used in every AOT + # binary produced by tfcompile. + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:framework_lite", + ], +) + cc_library( name = "xla_compiled_cpu_function", srcs = ["xla_compiled_cpu_function.cc"], @@ -99,12 +115,23 @@ cc_library( deps = [ # Keep dependencies to a minimum here; this library is used in every AOT # binary produced by tfcompile. - "//tensorflow/compiler/aot:runtime", + ":cpu_function_runtime", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", ], ) +tf_cc_test( + name = "cpu_function_runtime_test", + srcs = ["cpu_function_runtime_test.cc"], + deps = [ + ":cpu_function_runtime", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "xla_jit_compiled_cpu_function", srcs = ["xla_jit_compiled_cpu_function.cc"], @@ -119,7 +146,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service/cpu:buffer_info_util", "//tensorflow/compiler/xla/service/cpu:cpu_executable", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -162,15 +191,19 @@ cc_library( ":sharding_util", ":tf2xla_util", "//tensorflow/compiler/tf2xla/lib:util", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -198,7 +231,7 @@ cc_library( ], visibility = [":friends"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:core_cpu_internal", @@ -281,10 +314,12 @@ tf_cc_test( deps = [ ":tf2xla", ":tf2xla_proto", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -323,7 +358,7 @@ tf_cc_test( "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:client_library", @@ -360,6 +395,7 @@ tf_cc_test( ], deps = [ ":common", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/core:framework", "//tensorflow/core:test", @@ -462,3 +498,13 @@ cc_library( "//tensorflow/core:protos_all_cc", ], ) + +tf_cc_test( + name = "xla_op_registry_test", + srcs = ["xla_op_registry_test.cc"], + deps = [ + ":xla_compiler", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/aot/runtime.cc b/tensorflow/compiler/tf2xla/cpu_function_runtime.cc similarity index 70% rename from tensorflow/compiler/aot/runtime.cc rename to tensorflow/compiler/tf2xla/cpu_function_runtime.cc index 5e74079fc158379b8977ada6412141e39142c3d3..fcc4095e39673b786544984a41988c3e9c5b0efb 100644 --- a/tensorflow/compiler/aot/runtime.cc +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,22 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/aot/runtime.h" - -#include +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/core/platform/dynamic_annotations.h" namespace tensorflow { -namespace tfcompile { -namespace runtime { - namespace { - // Inline memory allocation routines here, because depending on '//base' brings // in libraries which use c++ streams, which adds considerable code size on // android. -inline void* aligned_malloc(size_t size, int minimum_alignment) { +void* aligned_malloc(size_t size, int minimum_alignment) { #if defined(__ANDROID__) || defined(OS_ANDROID) || defined(OS_CYGWIN) return memalign(minimum_alignment, size); #elif defined(_WIN32) @@ -47,7 +41,7 @@ inline void* aligned_malloc(size_t size, int minimum_alignment) { #endif } -inline void aligned_free(void* aligned_memory) { +void aligned_free(void* aligned_memory) { #if defined(_WIN32) _aligned_free(aligned_memory); #else @@ -58,22 +52,29 @@ inline void aligned_free(void* aligned_memory) { size_t align_to(size_t n, size_t align) { return (((n - 1) / align) + 1) * align; } - } // namespace -size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n) { +namespace cpu_function_runtime { +size_t AlignedBufferBytes(const BufferInfo* buffer_infos, size_t n, + bool allocate_entry_params) { size_t total = 0; for (size_t i = 0; i < n; ++i) { - if (sizes[i] != -1) { - total += align_to(sizes[i], kAlign); + bool should_allocate = + buffer_infos[i].is_temp_buffer() || + (buffer_infos[i].is_entry_parameter() && allocate_entry_params); + + if (should_allocate) { + total += align_to(buffer_infos[i].size(), kAlign); } } return total; } -void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, +void* MallocContiguousBuffers(const BufferInfo* buffer_infos, size_t n, + bool allocate_entry_params, void** bufs, bool annotate_initialized) { - const size_t total = aligned_buffer_bytes(sizes, n); + const size_t total = + AlignedBufferBytes(buffer_infos, n, allocate_entry_params); void* contiguous = nullptr; if (total > 0) { contiguous = aligned_malloc(total, kAlign); @@ -85,11 +86,14 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, } uintptr_t pos = reinterpret_cast(contiguous); for (size_t i = 0; i < n; ++i) { - if (sizes[i] == -1) { - bufs[i] = nullptr; - } else { + bool should_allocate = + buffer_infos[i].is_temp_buffer() || + (buffer_infos[i].is_entry_parameter() && allocate_entry_params); + if (should_allocate) { bufs[i] = reinterpret_cast(pos); - pos += align_to(sizes[i], kAlign); + pos += align_to(buffer_infos[i].size(), kAlign); + } else { + bufs[i] = nullptr; } } return contiguous; @@ -100,7 +104,5 @@ void FreeContiguous(void* contiguous) { aligned_free(contiguous); } } - -} // namespace runtime -} // namespace tfcompile +} // namespace cpu_function_runtime } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/cpu_function_runtime.h b/tensorflow/compiler/tf2xla/cpu_function_runtime.h new file mode 100644 index 0000000000000000000000000000000000000000..dfc1e8b8aebcf3142e9f61f60171c6b58634c71d --- /dev/null +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime.h @@ -0,0 +1,165 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_ +#define TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_ + +#include "tensorflow/core/platform/types.h" + +#include + +namespace tensorflow { +namespace cpu_function_runtime { +// Stores information about one buffer used by an XLA:CPU compiled function. +// These buffers are used for holding inputs to the computation, outputs from +// the computation and as temporary scratch space. +class BufferInfo { + public: + // Creates a BufferInfo from a serialized encoding generated by `Encode`. + explicit BufferInfo(std::pair encoding) + : entry_param_number_(encoding.second) { + Kind kind; + uint64 size; + Unpack(encoding.first, &kind, &size); + kind_ = kind; + size_ = size; + } + + // Returns true if this buffer stores a constant. These never need to be + // allocated by the runtime. + bool is_constant() const { return kind() == Kind::kConstant; } + + // Returns true if this buffer stores an entry parameter. These may or may + // not need to be allocated by the runtime, depending on + // XlaCompiledCpuFunction::AllocMode. + bool is_entry_parameter() const { return kind() == Kind::kEntryParameter; } + + // Returns the entry parameter number of this buffer. + uint64 entry_parameter_number() const { + assert(is_entry_parameter()); + return entry_param_number_; + } + + // Returns true if this buffer is temporary scratch space required by the XLA + // computations. These are always allocated by the runtime. + bool is_temp_buffer() const { return kind() == Kind::kTempBuffer; } + + // Returns true if this buffer is allocated on the C stack or into registers. + // These buffers are never allocated by the runtime. + bool is_on_stack_buffer() const { return kind() == Kind::kOnStackBuffer; } + + // Returns the size for this buffer. + uint64 size() const { return size_; } + + // Encodes this BufferInfo into two 64 bit integers that can be used to + // reconstruct the BufferInfo later using the constructor. We need this + // because we use BufferInfo in places where using protocol buffers would + // negatively impact binary size. + std::pair Encode() const { + static_assert(sizeof(*this) == 16, ""); + uint64 upper = Pack(kind(), size_); + uint64 lower = entry_param_number_; + return {upper, lower}; + } + + bool operator==(const BufferInfo& buffer_info) const { + if (kind() != buffer_info.kind() || size() != buffer_info.size()) { + return false; + } + return !is_entry_parameter() || + entry_parameter_number() == buffer_info.entry_parameter_number(); + } + + // Factory methods: + + static BufferInfo MakeTempBuffer(uint64 size) { + return BufferInfo(Kind::kTempBuffer, /*size=*/size, + /*entry_param_number=*/-1); + } + static BufferInfo MakeConstant(uint64 size) { + return BufferInfo(Kind::kConstant, /*size=*/size, + /*entry_param_number=*/-1); + } + static BufferInfo MakeEntryParameter(uint64 size, uint64 param_number) { + return BufferInfo(Kind::kEntryParameter, /*size=*/size, + /*entry_param_number=*/param_number); + } + static BufferInfo MakeOnStackBuffer(uint64 size) { + return BufferInfo(Kind::kOnStackBuffer, /*size=*/size, + /*entry_param_number=*/-1); + } + + private: + BufferInfo() = default; + + enum class Kind : unsigned { + kConstant, + kTempBuffer, + kEntryParameter, + kOnStackBuffer + }; + + Kind kind() const { return static_cast(kind_); } + + explicit BufferInfo(Kind kind, uint64 size, uint64 entry_param_number) + : kind_(kind), size_(size), entry_param_number_(entry_param_number) {} + + static uint64 Pack(Kind kind, uint64 size) { + return (static_cast(size) << 2) | static_cast(kind); + } + + static void Unpack(uint64 packed, Kind* kind, uint64* size) { + *size = packed >> 2; + *kind = static_cast((packed << 62) >> 62); + } + + Kind kind_ : 2; + uint64 size_ : 62; + int64 entry_param_number_; +}; + +// Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. +constexpr size_t kAlign = 64; + +// AlignedBufferBytes returns the sum of the size of each buffer in +// `buffer_infos`, skipping constants, on-stack buffers and, if +// allocate_entry_params is false, entry parameters. There are `n` entries in +// `buffer_infos`. Each buffer is aligned to kAlign byte boundaries. +size_t AlignedBufferBytes(const BufferInfo* buffer_infos, size_t n, + bool allocate_entry_params); + +// MallocContiguousBuffers allocates buffers for use by the entry point +// generated by tfcompile. There are `n` entries in `buffer_infos`. If +// `annotate_initialized` is set, the allocated memory will be annotated as +// having been initialized - this is useful when allocating temporary buffers. +// If allocate_entry_params is true then allocates temp buffers and entry +// parameters, otherwise allocated only temp buffers. Slots in `bufs` +// corresponding to unallocated buffers are set to nullptr. +// +// A single contiguous block of memory is allocated, and portions of it are +// parceled out into `bufs`, which must have space for `n` entries. Returns +// the head of the allocated contiguous block, which should be passed to +// FreeContiguous when the buffers are no longer in use. +void* MallocContiguousBuffers(const BufferInfo* buffer_infos, size_t n, + bool allocate_entry_params, void** bufs, + bool annotate_initialized); + +// FreeContiguous frees the contiguous block of memory allocated by +// MallocContiguousBuffers. +void FreeContiguous(void* contiguous); +} // namespace cpu_function_runtime +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_ diff --git a/tensorflow/compiler/aot/runtime_test.cc b/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc similarity index 50% rename from tensorflow/compiler/aot/runtime_test.cc rename to tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc index 06ec623eb2dce5f8dc7156fb7e7b9ad57d90c8ee..8ca628c4eb6700d7184899bc1753dd6c6aa392b0 100644 --- a/tensorflow/compiler/aot/runtime_test.cc +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc @@ -13,39 +13,70 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/aot/runtime.h" +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { -namespace tfcompile { -namespace runtime { namespace { -TEST(Runtime, AlignmentValue) { +using cpu_function_runtime::BufferInfo; + +TEST(XlaCompiledCpuFunctionTest, AlignmentValue) { // We've chosen 64 byte alignment for the tfcompile runtime to mimic the // regular tensorflow allocator, which was chosen to play nicely with Eigen. // The tfcompile runtime also has a requirement that comes from the xla // generated code, on the relation: buffer_size >= 16 ? 2 * sizeof(void*) : 8 // So any value that we choose must abide by that constraint as well. - EXPECT_EQ(kAlign, Allocator::kAllocatorAlignment); + EXPECT_EQ(cpu_function_runtime::kAlign, Allocator::kAllocatorAlignment); +} + +std::vector SizesToBufferInfos(const intptr_t* sizes, size_t n) { + std::vector buffer_infos; + std::transform(sizes, sizes + n, std::back_inserter(buffer_infos), + [&](intptr_t size) { + if (size == -1) { + // Use a dummy on-stack buffer allocation to indicat the + // the current slot does not need an allocation. + int64 on_stack_buffer_size = 4; + return BufferInfo::MakeOnStackBuffer(on_stack_buffer_size); + } + return BufferInfo::MakeTempBuffer(size); + }); + return buffer_infos; +} + +// Simple wrappers to make writing tests more ergonomic. + +size_t AlignedBufferBytesFromSizes(const intptr_t* sizes, size_t n) { + std::vector buffer_infos = SizesToBufferInfos(sizes, n); + return AlignedBufferBytes(buffer_infos.data(), n, + /*allocate_entry_params=*/false); } -TEST(Runtime, AlignedBufferBytes) { - EXPECT_EQ(aligned_buffer_bytes(nullptr, 0), 0); +void* MallocContiguousBuffersFromSizes(const intptr_t* sizes, size_t n, + void** bufs, bool annotate_initialized) { + std::vector buffer_infos = SizesToBufferInfos(sizes, n); + return MallocContiguousBuffers(buffer_infos.data(), n, + /*allocate_entry_params=*/false, bufs, + annotate_initialized); +} + +TEST(XlaCompiledCpuFunctionTest, AlignedBufferBytes) { + EXPECT_EQ(AlignedBufferBytesFromSizes(nullptr, 0), 0); static constexpr intptr_t sizesA[1] = {-1}; - EXPECT_EQ(aligned_buffer_bytes(sizesA, 1), 0); + EXPECT_EQ(AlignedBufferBytesFromSizes(sizesA, 1), 0); static constexpr intptr_t sizesB[1] = {3}; - EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 64); + EXPECT_EQ(AlignedBufferBytesFromSizes(sizesB, 1), 64); static constexpr intptr_t sizesC[1] = {32}; - EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 64); + EXPECT_EQ(AlignedBufferBytesFromSizes(sizesC, 1), 64); static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; - EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 320); + EXPECT_EQ(AlignedBufferBytesFromSizes(sizesD, 7), 320); } void* add_ptr(void* base, uintptr_t delta) { @@ -56,48 +87,48 @@ void* add_ptr(void* base, uintptr_t delta) { // expected nullptrs, and write to each byte of allocated memory. We rely on // the leak checker to tell us if there's an inconsistency between malloc and // free. We also check the contiguous property. -TEST(Runtime, MallocFreeContiguousBuffers) { +TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) { // Test empty sizes. - void* base = MallocContiguousBuffers(nullptr, 0, nullptr, false); + void* base = MallocContiguousBuffersFromSizes(nullptr, 0, nullptr, false); EXPECT_EQ(base, nullptr); - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); // Test non-empty sizes with 0 sum. static constexpr intptr_t sizesA[1] = {-1}; void* bufA[1]; - base = MallocContiguousBuffers(sizesA, 1, bufA, false); + base = MallocContiguousBuffersFromSizes(sizesA, 1, bufA, false); EXPECT_EQ(base, nullptr); EXPECT_EQ(bufA[0], nullptr); - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); // Test non-empty sizes with non-0 sum. static constexpr intptr_t sizesB[1] = {3}; void* bufB[1]; - base = MallocContiguousBuffers(sizesB, 1, bufB, false); + base = MallocContiguousBuffersFromSizes(sizesB, 1, bufB, false); EXPECT_NE(base, nullptr); EXPECT_EQ(bufB[0], add_ptr(base, 0)); char* bufB0_bytes = static_cast(bufB[0]); bufB0_bytes[0] = 'A'; bufB0_bytes[1] = 'B'; bufB0_bytes[2] = 'C'; - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); // Test non-empty sizes with non-0 sum, and annotate_initialized. static constexpr intptr_t sizesC[1] = {3}; void* bufC[1]; - base = MallocContiguousBuffers(sizesC, 1, bufC, true); + base = MallocContiguousBuffersFromSizes(sizesC, 1, bufC, true); EXPECT_NE(base, nullptr); EXPECT_EQ(bufC[0], add_ptr(base, 0)); char* bufC0_bytes = static_cast(bufC[0]); bufC0_bytes[0] = 'A'; bufC0_bytes[1] = 'B'; bufC0_bytes[2] = 'C'; - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); // Test mixed sizes. static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; void* bufD[7]; - base = MallocContiguousBuffers(sizesD, 7, bufD, false); + base = MallocContiguousBuffersFromSizes(sizesD, 7, bufD, false); EXPECT_NE(base, nullptr); EXPECT_EQ(bufD[0], add_ptr(base, 0)); EXPECT_EQ(bufD[1], nullptr); @@ -115,10 +146,26 @@ TEST(Runtime, MallocFreeContiguousBuffers) { } } } - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); +} + +void CheckRoundTripIsOk(const BufferInfo& buffer_info) { + BufferInfo round_trip(buffer_info.Encode()); + ASSERT_EQ(round_trip, buffer_info); +} + +TEST(XlaCompiledCpuFunctionTest, BufferInfoTest) { + CheckRoundTripIsOk(BufferInfo::MakeTempBuffer(0)); + CheckRoundTripIsOk(BufferInfo::MakeTempBuffer(4)); + CheckRoundTripIsOk(BufferInfo::MakeOnStackBuffer(0)); + CheckRoundTripIsOk(BufferInfo::MakeOnStackBuffer(4)); + CheckRoundTripIsOk(BufferInfo::MakeConstant(0)); + CheckRoundTripIsOk(BufferInfo::MakeConstant(4)); + CheckRoundTripIsOk( + BufferInfo::MakeEntryParameter(/*size=*/0, /*param_number=*/4)); + CheckRoundTripIsOk( + BufferInfo::MakeEntryParameter(/*size=*/4, /*param_number=*/0)); } } // namespace -} // namespace runtime -} // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index 03603ee9baefd1d20d220faf63c9c1c427ebdf31..24616c01c7e54b2e8662457ca6af23a0bc563e08 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -33,7 +33,7 @@ struct NameCounts { std::unordered_map counts; }; -string MakeUniquePath(string name) { +string MakeUniqueFilename(string name) { static NameCounts& instance = *new NameCounts; // Remove illegal characters from `name`. @@ -50,26 +50,41 @@ string MakeUniquePath(string name) { count = instance.counts[name]++; } - legacy_flags::DumpGraphFlags* flags = legacy_flags::GetDumpGraphFlags(); - string path = strings::StrCat(flags->tf_dump_graph_prefix, "/", name); + string filename = name; if (count > 0) { - strings::StrAppend(&path, "_", count); + strings::StrAppend(&filename, "_", count); } - strings::StrAppend(&path, ".pbtxt"); - return path; + strings::StrAppend(&filename, ".pbtxt"); + return filename; +} + +string WriteTextProtoToUniqueFile( + Env* env, const string& name, const char* proto_type, + const ::tensorflow::protobuf::Message& proto) { + const string& dirname = + legacy_flags::GetDumpGraphFlags()->tf_dump_graph_prefix; + Status status = env->RecursivelyCreateDir(dirname); + if (!status.ok()) { + LOG(WARNING) << "Failed to create " << dirname << " for dumping " + << proto_type << ": " << status; + return "(unavailable)"; + } + string filepath = strings::StrCat(dirname, "/", MakeUniqueFilename(name)); + status = WriteTextProto(Env::Default(), filepath, proto); + if (!status.ok()) { + LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath + << " : " << status; + return "(unavailable)"; + } + LOG(INFO) << "Dumped " << proto_type << " to " << filepath; + return filepath; } } // anonymous namespace string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { - string path = MakeUniquePath(name); - Status status = WriteTextProto(Env::Default(), path, graph_def); - if (!status.ok()) { - VLOG(1) << "Failed to dump GraphDef to file: " << path << " : " << status; - path.clear(); - path = "(unavailable)"; - } - return path; + return WriteTextProtoToUniqueFile(Env::Default(), name, "GraphDef", + graph_def); } string DumpGraphToFile(const string& name, Graph const& graph, @@ -83,15 +98,7 @@ string DumpGraphToFile(const string& name, Graph const& graph, } string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { - string path = MakeUniquePath(name); - Status status = WriteTextProto(Env::Default(), path, fdef); - if (!status.ok()) { - VLOG(1) << "Failed to dump FunctionDef to file: " << path << " : " - << status; - path.clear(); - path = "(unavailable)"; - } - return path; + return WriteTextProtoToUniqueFile(Env::Default(), name, "FunctionDef", fdef); } } // namespace dump_graph diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 1438f6b48c4913e60b0c0a9f5c3d67fe595cbfe8..0904778f97c95628c81054cd4bc2ff32ff440a33 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -166,6 +166,27 @@ StatusOr AddNode(const NodeDef& node_def, Graph* graph) { return inserted_node; } +// Check that the graph has no cycle containing the given node. +Status CheckNoCycleContains(const Node* node, const int num_nodes) { + std::vector ready; + ready.push_back(node); + std::vector visited(num_nodes); + while (!ready.empty()) { + const Node* current_node = ready.back(); + ready.pop_back(); + visited[current_node->id()] = true; + for (const Edge* out : current_node->out_edges()) { + if (out->dst() == node) { + return errors::Internal("Detected a cycle: ", FormatNodeForError(*node), + "(", node->def().op(), ") feeds into itself."); + } else if (!visited[out->dst()->id()]) { + ready.push_back(out->dst()); + } + } + } + return Status::OK(); +} + StatusOr BuildArgNode(Graph* graph, DataType type, int index) { NodeDef arg_def; NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp); @@ -303,7 +324,7 @@ Status AddMissingFunctionDef(const FunctionDef& fdef, if (library->Find(node.op())) { continue; } - // The function refered by 'SymbolicGradient' node is specified in its + // The function referred by 'SymbolicGradient' node is specified in its // attribute 'f'. if (node.op() == FunctionLibraryDefinition::kGradientOp) { const AttrValue* attr = @@ -416,22 +437,24 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, continue; } if (enter_merge != nullptr) { - return errors::Internal( - "Enter node for loop-varying argument ", arg.enter->name(), - " has multiple successors: ", enter_merge->dst()->name(), " and ", - e->dst()->name()); + return errors::Internal("Enter node for loop-varying argument ", + FormatNodeForError(*arg.enter), + " has multiple successors: ", + FormatNodeForError(*enter_merge->dst()), + " and ", FormatNodeForError(*e->dst())); } enter_merge = e; } if (enter_merge == nullptr) { return errors::Internal("Enter node for loop-varying argument ", - arg.enter->name(), " has zero successors"); + FormatNodeForError(*arg.enter), + " has zero successors"); } arg.merge = enter_merge->dst(); if (!IsMerge(arg.merge)) { return errors::InvalidArgument( "Successor of Enter node for loop-varying argument ", - arg.merge->name(), + FormatNodeForError(*arg.merge), " is not a Merge node; got: ", arg.merge->type_string()); } @@ -441,7 +464,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, return errors::InvalidArgument( "Unexpected number of inputs to Merge node for loop-varying " "argument ", - arg.merge->name(), "; expected 2, got ", + FormatNodeForError(*arg.merge), "; expected 2, got ", arg.merge->input_types().size()); } TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(), @@ -449,7 +472,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, if (!IsNextIteration(arg.next_iteration)) { return errors::InvalidArgument( "Expected NextIteration node as input to Merge node; got node ", - arg.next_iteration->name(), " with kind ", + FormatNodeForError(*arg.next_iteration), " with kind ", arg.next_iteration->type_string()); } @@ -460,14 +483,14 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, switches.find(edge->dst()) != switches.end()) { if (arg.switch_node != nullptr) { return errors::InvalidArgument("Duplicate Switch successors to ", - arg.merge->name()); + FormatNodeForError(*arg.merge)); } arg.switch_node = edge->dst(); } } if (arg.switch_node == nullptr) { return errors::InvalidArgument("Missing Switch successor to ", - arg.merge->name()); + FormatNodeForError(*arg.merge)); } // Update the device on the Identity outputs of the switch to match their @@ -495,14 +518,15 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, possible_exit.pop_front(); if (IsExit(edge->dst())) { if (arg.exit != nullptr) { - return errors::InvalidArgument("Duplicate Exit successors to ", - arg.switch_node->name()); + return errors::InvalidArgument( + "Duplicate Exit successors to ", + FormatNodeForError(*arg.switch_node)); } arg.exit = edge->dst(); } else { if (!IsIdentity(edge->dst())) { return errors::Unimplemented("General graph between switch (", - arg.switch_node->name(), + FormatNodeForError(*arg.switch_node), ") and exit node of frame ", frame->name, " not supported yet."); } @@ -1407,6 +1431,10 @@ StatusOr FunctionalizeCond::ConvertToXlaIf( TF_RETURN_IF_ERROR( AddInputEdges(cond_arg_nodes, switch_cluster.predicate_edge, if_node)); TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node)); + // Check that the if_node doesn't feed into itself. + TF_RETURN_WITH_CONTEXT_IF_ERROR( + CheckNoCycleContains(if_node, graph_->num_node_ids()), + "ConvertToXlaIf failed."); return if_node; } @@ -1439,11 +1467,13 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, // invariant. std::vector cf_info; std::vector unreachable_nodes; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes), + "FunctionalizeControlFlow failed"); if (!unreachable_nodes.empty()) { return errors::InvalidArgument( "The following nodes are unreachable from the source in the graph: ", - tensorflow::str_util::Join(unreachable_nodes, ", ")); + errors::FormatNodeNamesForError(unreachable_nodes)); } // Builds Frames, indexed by name. @@ -1464,10 +1494,6 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, frame.parent = parent; frame.name = cf.frame_name; ++parent->num_children; - } else if (frame.parent != parent) { - return errors::InvalidArgument("Mismatched parent frames for ", - cf.frame->id(), ": ", parent->name, " vs ", - frame.parent->name); } if (IsEnter(node)) { @@ -1477,12 +1503,6 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, &arg.is_loop_invariant)); frame.args.push_back(arg); } else if (IsLoopCond(node)) { - if (frame.loop_cond) { - return errors::InvalidArgument( - "Loop ", cf.frame_name, - " has more than one LoopCond node: ", node->name(), " and ", - frame.loop_cond->name()); - } frame.loop_cond = node; } frame.nodes.insert(node); @@ -1514,6 +1534,16 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, worklist.push_back(frame->parent); } } + // There should be no cycle at this point, since while loops have been removed + // from graph. + // Check that the newly added XlaWhile nodes don't feed into themselves. + for (const Node* node : graph->op_nodes()) { + if (node->def().op() == "XlaWhile") { + TF_RETURN_WITH_CONTEXT_IF_ERROR( + CheckNoCycleContains(node, graph->num_node_ids()), + "FunctionalizeLoop failed."); + } + } // FunctionalizeControlFlow is invoked for every function, so the loops's // bodies and conditionals that were extracted into functions will be handled diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 14977a908ae2b0ff7e13b634c41b6d331b4b8a36..ccf249b35d66861888ad5e5e904b5f63b8ac50a1 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/equal_graph_def.h" @@ -1012,5 +1013,63 @@ TEST(FunctionalizeControlFlow, Complex) { } } +TEST(FunctionalizeControlFlow, Cycle) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + // ----------------------------------------------------- + // | | + // | v + // less -> switch_1 --> add -> merge_1 -> identity -> switch_2 + // | ^ | + // | | v + // --------> one -------------------------> add_2 ---> merge_2 + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); + auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); + auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), x, less); + auto two = + ops::Const(scope.WithOpName("cond/two") + .WithControlDependencies(switch_1.output_true), + 2); + auto mul = ops::Multiply(scope.WithOpName("cond/true/mul"), + switch_1.output_true, two); + auto one = + ops::Const(scope.WithOpName("cond/one") + .WithControlDependencies(switch_1.output_false), + 1); + auto add = ops::Add(scope.WithOpName("cond/false/add"), + switch_1.output_false, one); + + auto merge_1 = ops::Merge(scope.WithOpName("cond/Merge"), + std::initializer_list{add, mul}); + auto identity = + ops::Identity(scope.WithOpName("cond/Merge/identity"), merge_1.output); + auto switch_2 = + ops::Switch(scope.WithOpName("grad/cond/Switch"), identity, less); + auto add_2 = ops::Add(scope.WithOpName("cond_2/false/add"), + switch_2.output_false, one); + auto mul_2 = ops::Multiply(scope.WithOpName("cond_2/true/mul"), + switch_2.output_true, two); + auto merge_2 = ops::Merge(scope.WithOpName("cond_2/Merge"), + std::initializer_list{add_2, mul_2}); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + } + // No cycle before functionalize control flow. + TF_EXPECT_OK(graph::ValidateGraphHasNoCycle(*graph)); + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + // switch_1 and switch_2 have the same switch depth. They are replaced by a + // single XlaIf node during FunctionalizeControlFlow, resulting in a cycle: + // less -> XlaIf <--> identity. + Status status = FunctionalizeControlFlow(graph.get(), &library); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), "Detected a cycle")) + << status.error_message(); + EXPECT_TRUE( + str_util::StrContains(status.error_message(), "{{node cond/Less_5_If}}")) + << status.error_message(); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 212f6f3966149ca0b2d2e012b19300e1f488f996..e4fdf0a6186eb69a2e3413838c91616b992ef2d6 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" @@ -39,6 +40,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -87,6 +89,8 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, } } // namespace Status GraphCompiler::Compile() { + // Check that the graph has no illegal cycles. + TF_RETURN_IF_ERROR(graph::ValidateGraphHasNoCycle(*graph_)); // Maintain a mapping from node id to node outputs. using NodeOutputs = std::vector; std::vector output_registry(graph_->num_node_ids()); @@ -157,9 +161,8 @@ Status GraphCompiler::Compile() { outputs.resize(n->num_outputs()); for (int o = 0; o < n->num_outputs(); ++o) { outputs[o] = op_context.release_output(o); - if (*op_context.is_output_dead() || outputs[o].tensor == nullptr) { + if (outputs[o].tensor == nullptr) { return errors::Internal("Missing xla_context ", o, "-th output from ", - (*op_context.is_output_dead() ? "(dead)" : ""), SummarizeNode(*n)); } } @@ -227,7 +230,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, XlaContext& context = XlaContext::Get(op_context); auto* b = context.builder(); - auto output_handle = b->Call(*result.computation, handles); + auto output_handle = xla::Call(b, *result.computation, handles); // The output handle of `Call` computation is a tuple type. Unzip it so // that it can fit into future computations. int computation_output = 0; @@ -236,7 +239,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value); } else { xla_op_context.SetOutput( - i, b->GetTupleElement(output_handle, computation_output)); + i, xla::GetTupleElement(output_handle, computation_output)); ++computation_output; } } diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index edd2ab6301ee891c433639ce300cde0c72929cea..b1366e9e31e28406c5bf1a808b9c5670558ed9c7 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -6,6 +6,10 @@ package( load("//tensorflow:tensorflow.bzl", "tf_copts") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load( + "//third_party/mkl:build_defs.bzl", + "if_mkl", +) tf_kernel_library( name = "xla_ops", @@ -58,6 +62,7 @@ tf_kernel_library( "pack_op.cc", "pad_op.cc", "pooling_ops.cc", + "qr_op.cc", "quantize_and_dequantize_op.cc", "random_ops.cc", "reduce_window_op.cc", @@ -79,14 +84,17 @@ tf_kernel_library( "shape_util.cc", "slice_op.cc", "softmax_op.cc", + "sort_ops.cc", "spacetobatch_op.cc", "spacetodepth_op.cc", + "sparse_to_dense_op.cc", "split_op.cc", "stack_ops.cc", "stateless_random_ops.cc", "strided_slice_op.cc", "tensor_array_ops.cc", "tile_ops.cc", + "topk_op.cc", "training_ops.cc", "transpose_op.cc", "unary_ops.cc", @@ -104,20 +112,30 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", "//tensorflow/compiler/tf2xla/lib:cholesky", + "//tensorflow/compiler/tf2xla/lib:qr", + "//tensorflow/compiler/tf2xla/lib:random", "//tensorflow/compiler/tf2xla/lib:scatter", "//tensorflow/compiler/tf2xla/lib:triangular_solve", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/tf2xla/lib:while_loop", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/compiler/xla/client/lib:pooling", + "//tensorflow/compiler/xla/client/lib:prng", + "//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/core:framework", "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", @@ -140,8 +158,14 @@ tf_kernel_library( "//tensorflow/core/kernels:sparse_to_dense_op", "//tensorflow/core/kernels:stack_ops", "//tensorflow/core/kernels:training_ops", - "//tensorflow/core/kernels:transpose_op", - ], + ] + if_mkl( + [ + "//tensorflow/core/kernels:mkl_transpose_op", + ], + [ + "//tensorflow/core/kernels:transpose_op", + ], + ), ) tf_kernel_library( @@ -152,8 +176,9 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -168,8 +193,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -203,10 +228,11 @@ tf_kernel_library( ":index_ops_kernel_argmax_float_2d", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/kernels:argmax_op", diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc index 1e59868621475cf72f4cc8b14dafec2dd8cd5c95..41a453da80dec6b6f57a4d222e2c33ef6b786a10 100644 --- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { @@ -31,7 +32,7 @@ class AddNOp : public XlaOpKernel { xla::XlaOp sum = ctx->Input(0); for (int i = 1; i < ctx->num_inputs(); ++i) { - sum = ctx->builder()->Add(sum, ctx->Input(i)); + sum = xla::Add(sum, ctx->Input(i)); } ctx->SetOutput(0, sum); diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 26fc1620a4f032b3af28de6e3a5af0e965e82341..276d744c096f8996c774964204feaa3762bdb844 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -65,6 +65,6 @@ class XlaArgOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(XlaArgOp); }; -REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes(), XlaArgOp); +REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes().CompilationOnly(), XlaArgOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index b0ba25b9983c3a9af26728ce4b1c263c844327db..4cfe946b2e6146f034867c06e996ffae42b90705 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -28,11 +28,10 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1), + auto result = BatchDot(ctx->Input(0), ctx->Input(1), /*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_, /*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_); - OP_REQUIRES_OK(ctx, result.status()); - ctx->SetOutput(0, result.ValueOrDie()); + ctx->SetOutput(0, result); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 15e1815a4cf07ff50dd1431b6790d14781da590f..b3ad0aea84eef601de08909f760699b8700d28f4 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -34,10 +35,11 @@ class FusedBatchNormOp : public XlaOpKernel { ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); OP_REQUIRES(ctx, - (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW), + (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW || + data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN), errors::InvalidArgument( "Unsupported data format ", ToString(data_format_), - "; supported formats are NHWC and NCHW")); + "; supported formats are NHWC, NCHW, HWNC and HWCN")); } void Compile(XlaOpKernelContext* ctx) override { @@ -48,8 +50,6 @@ class FusedBatchNormOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1), &scale_type)); - xla::XlaBuilder* builder = ctx->builder(); - xla::XlaOp input = ctx->Input(0); TensorShape input_shape = ctx->InputShape(0); @@ -59,30 +59,30 @@ class FusedBatchNormOp : public XlaOpKernel { // TODO(b/69928690): support mixed precision in the XLA batch normalization // operators. As a workaround, cast everything to the statistics type (which // may be more precise than the input type). - input = builder->ConvertElementType(input, scale_type); + input = xla::ConvertElementType(input, scale_type); if (is_training_) { - xla::XlaOp output = builder->BatchNormTraining( + xla::XlaOp output = xla::BatchNormTraining( input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index); // In training mode, outputs the normalized value as well as the // calculated mean and variance. - ctx->SetOutput(0, builder->ConvertElementType( - builder->GetTupleElement(output, 0), input_type)); - ctx->SetOutput(1, builder->GetTupleElement(output, 1)); - ctx->SetOutput(2, builder->GetTupleElement(output, 2)); + ctx->SetOutput(0, xla::ConvertElementType(xla::GetTupleElement(output, 0), + input_type)); + ctx->SetOutput(1, xla::GetTupleElement(output, 1)); + ctx->SetOutput(2, xla::GetTupleElement(output, 2)); // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved // space 1 & 2". They are used to pass the per-batch mean and // variance to the gradient. Here we maintain the same behavior by setting // them to the mean and variance calculated by BatchNormTraining. - ctx->SetOutput(3, builder->GetTupleElement(output, 1)); - ctx->SetOutput(4, builder->GetTupleElement(output, 2)); + ctx->SetOutput(3, xla::GetTupleElement(output, 1)); + ctx->SetOutput(4, xla::GetTupleElement(output, 2)); } else { - xla::XlaOp output = builder->BatchNormInference( + xla::XlaOp output = xla::BatchNormInference( input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4), epsilon_, feature_index); - ctx->SetOutput(0, builder->ConvertElementType(output, input_type)); + ctx->SetOutput(0, xla::ConvertElementType(output, input_type)); // Directly send input to output as mean and variance in inference mode. ctx->SetOutput(1, ctx->Input(3)); ctx->SetOutput(2, ctx->Input(4)); @@ -111,10 +111,11 @@ class FusedBatchNormGradOp : public XlaOpKernel { ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); OP_REQUIRES(ctx, - (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW), + (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW || + data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN), errors::InvalidArgument( "Unsupported data format ", ToString(data_format_), - "; supported formats are NHWC and NCHW")); + "; supported formats are NHWC, NCHW, HWNC and HWCN")); } void Compile(XlaOpKernelContext* ctx) override { @@ -142,12 +143,12 @@ class FusedBatchNormGradOp : public XlaOpKernel { xla::XlaOp offset_backprop; if (is_training_) { xla::XlaOp output = - b->BatchNormGrad(activations, scale, mean, var, grad_backprop, - epsilon_, feature_index); + xla::BatchNormGrad(activations, scale, mean, var, grad_backprop, + epsilon_, feature_index); - x_backprop = b->GetTupleElement(output, 0); - scale_backprop = b->GetTupleElement(output, 1); - offset_backprop = b->GetTupleElement(output, 2); + x_backprop = xla::GetTupleElement(output, 0); + scale_backprop = xla::GetTupleElement(output, 1); + offset_backprop = xla::GetTupleElement(output, 2); } else { // Reduce over all dimensions except the feature dim. std::vector reduction_dims(input_dims - 1); @@ -164,35 +165,35 @@ class FusedBatchNormGradOp : public XlaOpKernel { auto converted = XlaHelpers::ConvertElementType(b, grad_backprop, accumulation_type); auto reduce = - b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); + xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); offset_backprop = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); // scratch1 = rsqrt(pop_var + epsilon) auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); - auto scratch1 = - b->Pow(b->Add(var, b->ConstantR0(epsilon_)), neg_half); + auto scratch1 = xla::Pow( + xla::Add(var, xla::ConstantR0(b, epsilon_)), neg_half); // scratch2 = sum(y_backprop * (x - mean)) auto mul = - b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index})); + xla::Mul(grad_backprop, xla::Sub(activations, mean, {feature_index})); converted = XlaHelpers::ConvertElementType(b, mul, accumulation_type); reduce = - b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); + xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); auto scratch2 = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); x_backprop = - b->Mul(grad_backprop, b->Mul(scratch1, scale), {feature_index}); - scale_backprop = b->Mul(scratch1, scratch2); + xla::Mul(grad_backprop, xla::Mul(scratch1, scale), {feature_index}); + scale_backprop = xla::Mul(scratch1, scratch2); } ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, x_backprop, input_dtype)); ctx->SetOutput(1, scale_backprop); ctx->SetOutput(2, offset_backprop); - ctx->SetConstantOutput(3, Tensor(scale_dtype, {})); - ctx->SetConstantOutput(4, Tensor(scale_dtype, {})); + ctx->SetConstantOutput(3, Tensor()); + ctx->SetConstantOutput(4, Tensor()); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 642278ab994bf3cc84396f093ed56b009a1435c1..48f2a005ab16651fe29d0f6f9d881f95693da461 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { @@ -45,7 +46,6 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, ", 2] instead of ", xla::ShapeUtil::HumanString(crops.shape()))); - xla::XlaBuilder* b = ctx->builder(); const int64 batch_size = input_shape[0]; // Compute the product of the block_shape values. @@ -72,7 +72,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, reshaped_shape[block_rank] = batch_size / block_num_elems; std::copy(input_shape.begin() + 1, input_shape.end(), reshaped_shape.begin() + block_rank + 1); - xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce `permuted` of shape // [batch / prod(block_shape), @@ -90,7 +90,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, } std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), 1 + block_rank * 2); - xla::XlaOp permuted = b->Transpose(reshaped, permutation); + xla::XlaOp permuted = xla::Transpose(reshaped, permutation); // 3. Reshape `permuted` to produce `reshaped_permuted` of shape // [batch / prod(block_shape), @@ -110,7 +110,8 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, std::copy(remainder_shape.begin(), remainder_shape.end(), reshaped_permuted_shape.begin() + 1 + block_rank); - xla::XlaOp reshaped_permuted = b->Reshape(permuted, reshaped_permuted_shape); + xla::XlaOp reshaped_permuted = + xla::Reshape(permuted, reshaped_permuted_shape); // 4. Crop the start and end of dimensions `[1, ..., M]` of // `reshaped_permuted` according to `crops` to produce the output of shape: @@ -138,7 +139,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, " end: ", crop_end, " size ", reshaped_permuted_shape[1 + i])); } xla::XlaOp output = - b->Slice(reshaped_permuted, start_indices, end_indices, strides); + xla::Slice(reshaped_permuted, start_indices, end_indices, strides); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index ee2c920453c3bbaef2c145df743fddf999167c39..ba3b1c9dab79a387c48e8e25e4804917f328f8a0 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/bcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index 9d677f426650ea17a49e5ab1401078f04623fe97..41f540506ba41fbe7f91393e7b8e26a89e72ef0a 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/util/tensor_format.h" @@ -60,8 +61,7 @@ class BiasOp : public XlaOpKernel { "of the input tensor: ", bias_shape.DebugString(), " vs. ", input_shape.DebugString())); - xla::XlaOp result = - ctx->builder()->Add(ctx->Input(0), ctx->Input(1), {feature_dim}); + xla::XlaOp result = xla::Add(ctx->Input(0), ctx->Input(1), {feature_dim}); ctx->SetOutput(0, result); } @@ -109,8 +109,8 @@ class BiasAddGradOp : public XlaOpKernel { auto converted = XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); auto reduce = - b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), reduce_dims); + xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduce_dims); ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, reduce, input_type(0))); } diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index f04cde878e98002d9442e0f3ec251c5197ef7969..2c328102e0bd84709707f102272691b6aec9a577 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -41,18 +41,19 @@ namespace { const BCast& broadcast_helper, \ const std::vector& extend_dimensions) override { \ xla::XlaBuilder* b = ctx->builder(); \ + (void)b; \ return HLO; \ } \ }; \ REGISTER_XLA_OP(Name(#NAME), NAME##Op) -XLA_MAKE_BINARY(Add, b->Add(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Sub, b->Sub(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Mul, b->Mul(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Div, b->Div(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Sub, xla::Sub(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Mul, xla::Mul(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Atan2, b->Atan2(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Complex, b->Complex(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions)); // Implementation of FloorDiv. Pseudo-code: // if ((x < 0) != (y < 0)) { @@ -67,13 +68,13 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto one = XlaHelpers::One(b, dtype); - auto different_sign = b->Ne(b->Lt(x, zero), b->Lt(y, zero)); - auto abs_x = b->Abs(x); - auto abs_y = b->Abs(y); - auto t = b->Neg(b->Sub(b->Add(abs_x, abs_y), one)); - auto result = b->Select(different_sign, b->Div(t, abs_y), b->Div(x, y)); + auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero)); + auto abs_x = xla::Abs(x); + auto abs_y = xla::Abs(y); + auto t = xla::Neg(xla::Sub(xla::Add(abs_x, abs_y), one)); + auto result = xla::Select(different_sign, xla::Div(t, abs_y), xla::Div(x, y)); if (DataTypeIsFloating(dtype)) { - result = b->Floor(result); + result = xla::Floor(result); } return result; } @@ -87,75 +88,78 @@ static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); - auto same_sign = b->Eq(b->Lt(x, zero), b->Lt(y, zero)); - auto trunc_mod = b->Rem(x, y); - return b->Select(same_sign, trunc_mod, b->Rem(b->Add(trunc_mod, y), y)); + auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero)); + auto trunc_mod = xla::Rem(x, y); + return xla::Select(same_sign, trunc_mod, xla::Rem(xla::Add(trunc_mod, y), y)); } XLA_MAKE_BINARY(FloorMod, FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper)); -XLA_MAKE_BINARY(BitwiseAnd, b->And(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(BitwiseOr, b->Or(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseAnd, xla::And(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseOr, xla::Or(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseXor, xla::Xor(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(LeftShift, b->ShiftLeft(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LeftShift, xla::ShiftLeft(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(RightShift, (DataTypeIsUnsigned(ctx->input_type(0)) - ? b->ShiftRightLogical(lhs, rhs, extend_dimensions) - : b->ShiftRightArithmetic(lhs, rhs, extend_dimensions))); - -XLA_MAKE_BINARY(LogicalAnd, b->And(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(LogicalOr, b->Or(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Mod, b->Rem(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Maximum, b->Max(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Minimum, b->Min(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(RealDiv, b->Div(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(ReciprocalGrad, b->Neg(b->Mul(rhs, b->Mul(lhs, lhs)))); + ? xla::ShiftRightLogical(lhs, rhs, extend_dimensions) + : xla::ShiftRightArithmetic(lhs, rhs, extend_dimensions))); + +XLA_MAKE_BINARY(LogicalAnd, xla::And(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LogicalOr, xla::Or(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Mod, xla::Rem(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Maximum, xla::Max(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Minimum, xla::Min(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(RealDiv, xla::Div(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(ReciprocalGrad, xla::Neg(xla::Mul(rhs, xla::Mul(lhs, lhs)))); XLA_MAKE_BINARY( RsqrtGrad, - b->Mul(b->Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)), - b->Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)), - extend_dimensions)); -XLA_MAKE_BINARY(SqrtGrad, - b->Div(b->Mul(rhs, - XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), - lhs, extend_dimensions)); + xla::Mul(xla::Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)), + xla::Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)), + extend_dimensions)); +XLA_MAKE_BINARY( + SqrtGrad, + xla::Div(xla::Mul(rhs, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), + lhs, extend_dimensions)); static xla::XlaOp Square(xla::XlaBuilder* builder, const xla::XlaOp& x) { - return builder->Mul(x, x); + return xla::Mul(x, x); } XLA_MAKE_BINARY(SquaredDifference, - Square(b, b->Sub(lhs, rhs, extend_dimensions))); + Square(b, xla::Sub(lhs, rhs, extend_dimensions))); -XLA_MAKE_BINARY(TruncateDiv, b->Div(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(TruncateMod, b->Rem(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(TruncateDiv, xla::Div(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(TruncateMod, xla::Rem(lhs, rhs, extend_dimensions)); // Comparison ops -XLA_MAKE_BINARY(Equal, b->Eq(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(NotEqual, b->Ne(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Greater, b->Gt(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(GreaterEqual, b->Ge(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(Less, b->Lt(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(LessEqual, b->Le(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Equal, xla::Eq(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(NotEqual, xla::Ne(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Greater, xla::Gt(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(GreaterEqual, xla::Ge(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Less, xla::Lt(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LessEqual, xla::Le(lhs, rhs, extend_dimensions)); // Non-linear ops XLA_MAKE_BINARY(SigmoidGrad, - b->Mul(b->Mul(rhs, lhs), - b->Sub(XlaHelpers::One(b, input_type(0)), lhs))); + xla::Mul(xla::Mul(rhs, lhs), + xla::Sub(XlaHelpers::One(b, input_type(0)), lhs))); XLA_MAKE_BINARY(SoftplusGrad, - b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)), - XlaHelpers::One(b, input_type(1))))); + xla::Div(lhs, xla::Add(xla::Exp(xla::Neg(rhs)), + XlaHelpers::One(b, input_type(1))))); // softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2 XLA_MAKE_BINARY(SoftsignGrad, - b->Div(lhs, Square(b, b->Add(XlaHelpers::One(b, input_type(0)), - b->Abs(rhs))))); + xla::Div(lhs, + Square(b, xla::Add(XlaHelpers::One(b, input_type(0)), + xla::Abs(rhs))))); -XLA_MAKE_BINARY(TanhGrad, b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)), - b->Mul(lhs, lhs)))); +XLA_MAKE_BINARY(TanhGrad, + xla::Mul(rhs, xla::Sub(XlaHelpers::One(b, input_type(0)), + xla::Mul(lhs, lhs)))); -XLA_MAKE_BINARY(Pow, b->Pow(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions)); #undef XLA_MAKE_BINARY @@ -168,12 +172,13 @@ class ApproximateEqualOp : public XlaOpKernel { // Computes the max of the scalar input x and 0. void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); - auto abs = b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))); + auto abs = xla::Abs(xla::Sub(ctx->Input(0), ctx->Input(1))); auto abs_shape = b->GetShape(abs); OP_REQUIRES_OK(ctx, abs_shape.status()); auto abs_type = abs_shape.ValueOrDie().element_type(); - auto result = b->Lt( - abs, b->ConvertElementType(b->ConstantR0(tolerance_), abs_type)); + auto result = + xla::Lt(abs, xla::ConvertElementType( + xla::ConstantR0(b, tolerance_), abs_type)); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc index ca9a6b40688d1e8496d1b823e20d273d519f65e8..5078f8662bd397eaa51274ec816c130b8ced92cc 100644 --- a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -36,22 +37,22 @@ class BucketizeOp : public XlaOpKernel { const DataType dtype = context->input_type(0); xla::XlaOp input = context->Input(0); - xla::XlaOp boundaries = builder->ConstantR1(boundaries_); + xla::XlaOp boundaries = xla::ConstantR1(builder, boundaries_); // TODO(phawkins): the following behavior matches the behavior of the core // Bucketize kernel. However, comparing an int32 or int64 against float may // lead to inaccurate bucketing due to rounding. if (dtype == DT_DOUBLE) { - input = builder->ConvertElementType(input, xla::F64); - boundaries = builder->ConvertElementType(boundaries, xla::F64); + input = xla::ConvertElementType(input, xla::F64); + boundaries = xla::ConvertElementType(boundaries, xla::F64); } else { - input = builder->ConvertElementType(input, xla::F32); + input = xla::ConvertElementType(input, xla::F32); } - xla::XlaOp comparison = builder->ConvertElementType( - builder->Ge(builder->Broadcast(input, {1}), boundaries, - /*broadcast_dimensions=*/{0}), - xla::S32); - xla::XlaOp buckets = builder->Reduce( - comparison, /*init_value=*/builder->ConstantR0(0), + xla::XlaOp comparison = + xla::ConvertElementType(xla::Ge(xla::Broadcast(input, {1}), boundaries, + /*broadcast_dimensions=*/{0}), + xla::S32); + xla::XlaOp buckets = xla::Reduce( + comparison, /*init_value=*/xla::ConstantR0(builder, 0), /*computation=*/xla::CreateScalarAddComputation(xla::S32, builder), /*dimensions_to_reduce=*/{0}); context->SetOutput(0, buckets); diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index e9d98c768572c52825fa5192ecec834889f040fe..8cc2479dd555380da7500abe6b2aca380110333b 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -40,14 +41,14 @@ class CastOp : public XlaOpKernel { if (src_dtype_ == dst_dtype_) { output = input; } else if (dst_dtype_ == DT_BOOL) { - output = builder->Ne(input, XlaHelpers::Zero(builder, src_dtype_)); + output = xla::Ne(input, XlaHelpers::Zero(builder, src_dtype_)); } else if (xla::primitive_util::IsComplexType(src_type_) && !xla::primitive_util::IsComplexType(dst_type_)) { // As in cast_op.h, we replicate the numpy behavior of truncating the // imaginary part. - output = builder->ConvertElementType(builder->Real(input), dst_type_); + output = xla::ConvertElementType(xla::Real(input), dst_type_); } else { - output = builder->ConvertElementType(input, dst_type_); + output = xla::ConvertElementType(input, dst_type_); } ctx->SetOutput(0, output); @@ -72,7 +73,6 @@ class BitcastOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); xla::XlaOp input = ctx->Input(0); xla::XlaOp output; @@ -92,7 +92,7 @@ class BitcastOp : public XlaOpKernel { xla::primitive_util::BitWidth(dst_type_), errors::Unimplemented( "Only bitcasts between equally sized types supported.")); - output = builder->BitcastConvertType(input, dst_type_); + output = xla::BitcastConvertType(input, dst_type_); } ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 835a7f568945f0bee86fe2b39491c3326726e1aa..e7fef77edcba0ea5a521956a704225ac4f7fcb22 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -65,24 +66,22 @@ class CategoricalOp : public XlaOpKernel { DataTypeToPrimitiveType(input_type(0), &uniform_xla_type)); xla::Shape uniform_shape = xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array); - auto uniforms = builder->RngUniform( - XlaHelpers::Zero(builder, input_type(0)), - XlaHelpers::One(builder, input_type(0)), uniform_shape); + auto uniforms = + xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)), + XlaHelpers::One(builder, input_type(0)), uniform_shape); // Use Gumbel softmax trick to generate categorical samples. // See: // https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/ // TODO(b/68769470): Switch to using a cumulative sum approach. - auto softmax_entries = - builder->Sub(logits, builder->Log(builder->Neg(builder->Log(uniforms))), - /*broadcast_dimensions=*/{0, 2}); - - TensorShape softmax_shape(uniform_shape_array); - xla::XlaOp argmax; - OP_REQUIRES_OK( - ctx, - XlaHelpers::ArgMax(builder, ctx, softmax_entries, softmax_shape, - input_type(0), output_type(0), /*axis=*/2, &argmax)); + auto softmax_entries = xla::Sub(logits, xla::Log(-xla::Log(uniforms)), + /*broadcast_dimensions=*/{0, 2}); + + xla::PrimitiveType xla_output_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(output_type(0), &xla_output_type)); + xla::XlaOp argmax = + XlaHelpers::ArgMax(softmax_entries, xla_output_type, /*axis=*/2); ctx->SetOutput(0, argmax); } diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index fe6651793dc763d13f4a4b0ac294ec3ecf64af8f..9fcbc86adc0967cbb7fb73da8bdabc58b60953da 100644 --- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -24,12 +24,7 @@ class CholeskyOp : public XlaOpKernel { public: explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - auto result = Cholesky(ctx->builder(), ctx->Input(0)); - if (!result.ok()) { - ctx->SetStatus(result.status()); - return; - } - ctx->SetOutput(0, result.ValueOrDie()); + ctx->SetOutput(0, Cholesky(ctx->Input(0))); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc index a00bc912f9f40052565446c6bf9390629af9a4cd..547fe48046e8c934e3bc14d02c8448e107c1a406 100644 --- a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -29,7 +30,6 @@ class ClipByValueOp : public XlaOpKernel { const TensorShape min_shape = ctx->InputShape(1); const TensorShape max_shape = ctx->InputShape(2); - xla::XlaBuilder* builder = ctx->builder(); auto input = ctx->Input(0); auto min = ctx->Input(1); auto max = ctx->Input(2); @@ -45,13 +45,13 @@ class ClipByValueOp : public XlaOpKernel { if (shape != min_shape) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(min_shape), shape_error()); - min = builder->Broadcast(min, shape.dim_sizes()); + min = xla::Broadcast(min, shape.dim_sizes()); } if (shape != max_shape) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(max_shape), shape_error()); - max = builder->Broadcast(max, shape.dim_sizes()); + max = xla::Broadcast(max, shape.dim_sizes()); } - ctx->SetOutput(0, builder->Clamp(min, input, max)); + ctx->SetOutput(0, xla::Clamp(min, input, max)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index 78285affa1c399ae107a9172fb85cf257457c368..f4106051043859a6786705009d76b02a64cd3ff1 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -88,7 +89,7 @@ class ConcatBaseOp : public XlaOpKernel { "] = ", in_shape.DebugString())); if (in_shape.dims() == 0) { // Inputs that come in as scalars must be reshaped to 1-vectors. - input_data.push_back(ctx->builder()->Reshape(handle, {1})); + input_data.push_back(xla::Reshape(handle, {1})); } else { input_data.push_back(handle); } @@ -96,7 +97,7 @@ class ConcatBaseOp : public XlaOpKernel { } VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis; - ctx->SetOutput(0, ctx->builder()->ConcatInDim(input_data, axis)); + ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 59d06c654de18c9003fe0bdc706d0c2443de6d7b..da8cf3fc6fa694f592280f8c249d317827d9cd09 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -53,41 +54,41 @@ class ConstOp : public XlaOpKernel { switch (proto_.dtype()) { case DT_BOOL: if (proto_.bool_val_size() == 1) { - ctx->SetOutput(0, - b->Broadcast(b->ConstantR0(proto_.bool_val(0)), - shape.dim_sizes())); + ctx->SetOutput( + 0, xla::Broadcast(xla::ConstantR0(b, proto_.bool_val(0)), + shape.dim_sizes())); return; } break; case DT_FLOAT: if (proto_.float_val_size() == 1) { - ctx->SetOutput( - 0, b->Broadcast(b->ConstantR0(proto_.float_val(0)), - shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(xla::ConstantR0( + b, proto_.float_val(0)), + shape.dim_sizes())); return; } break; case DT_DOUBLE: if (proto_.double_val_size() == 1) { - ctx->SetOutput( - 0, b->Broadcast(b->ConstantR0(proto_.double_val(0)), - shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(xla::ConstantR0( + b, proto_.double_val(0)), + shape.dim_sizes())); return; } break; case DT_INT32: if (proto_.int_val_size() == 1) { - ctx->SetOutput(0, - b->Broadcast(b->ConstantR0(proto_.int_val(0)), - shape.dim_sizes())); + ctx->SetOutput( + 0, xla::Broadcast(xla::ConstantR0(b, proto_.int_val(0)), + shape.dim_sizes())); return; } break; case DT_INT64: if (proto_.int64_val_size() == 1) { - ctx->SetOutput( - 0, b->Broadcast(b->ConstantR0(proto_.int64_val(0)), - shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(xla::ConstantR0( + b, proto_.int64_val(0)), + shape.dim_sizes())); return; } break; diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 627bad12f33c82e91bc3c6f3323f562bc8174056..5da7972397b32fb4a2f216913e065c04131a3773 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -51,8 +53,8 @@ xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype, xla::XlaBuilder* builder) { TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - return builder->Broadcast(XlaHelpers::Zero(builder, dtype), - expanded_filter_shape.dim_sizes()); + return xla::Broadcast(XlaHelpers::Zero(builder, dtype), + expanded_filter_shape.dim_sizes()); } // Create a mask for depthwise convolution that will make a normal convolution @@ -95,32 +97,27 @@ xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape, // Create a M sized linspace and an M*N sized linspace that will be // broadcasted into perpendicular dimensions and compared. - xla::XlaOp input_feature_iota; - // DT_INT32 Iota will always return status::OK(). - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature, - &input_feature_iota)); - xla::XlaOp expanded_feature_iota; - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, - input_feature * depthwise_multiplier, - &expanded_feature_iota)); + xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature); + xla::XlaOp expanded_feature_iota = + xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier); // Divide the M*N sized linspace by the depthwise_multiplier to create // [0 0 1 1 2 2] in the example in the function comment. expanded_feature_iota = - builder->Div(expanded_feature_iota, - XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, - depthwise_multiplier)); + xla::Div(expanded_feature_iota, + XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, + depthwise_multiplier)); // Broadcast the N*M linspace to [H, W, ..., M, M*N]. auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes(); expanded_feature_broadcast_dims.pop_back(); - auto broadcasted_expanded_feature_iota = builder->Broadcast( - expanded_feature_iota, expanded_feature_broadcast_dims); + auto broadcasted_expanded_feature_iota = + xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims); // Compare the broadcasted linspace to the input feature linspace in the // input feature dimension to create a diagonal predicate. - return builder->Eq(broadcasted_expanded_feature_iota, input_feature_iota, - {expanded_filter_shape.dims() - 2}); + return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota, + {expanded_filter_shape.dims() - 2}); } // Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding @@ -142,16 +139,16 @@ xla::XlaOp ExpandFilterForDepthwiseConvolution(const TensorShape& filter_shape, implicit_broadcast_filter_shape.dims() - 1, depthwise_multiplier * input_feature); auto implicit_broadcast_filter = - builder->Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); + xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); // Broadcast the filter to [H, W, ..., M, M*N]. auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder); - auto expanded_filter = builder->Add(implicit_broadcast_filter, expanded_zero); + auto expanded_filter = xla::Add(implicit_broadcast_filter, expanded_zero); // If the filter mask is set, choose the broadcasted filter, othwerwise, // choose zero. - return builder->Select(CreateExpandedFilterMask(filter_shape, builder), - expanded_filter, expanded_zero); + return xla::Select(CreateExpandedFilterMask(filter_shape, builder), + expanded_filter, expanded_zero); } // Inverse of ExpandFilterForDepthwiseConvolution. @@ -162,17 +159,17 @@ xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx, xla::XlaBuilder* builder) { TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - auto masked_expanded_filter = builder->Select( + auto masked_expanded_filter = xla::Select( CreateExpandedFilterMask(filter_shape, builder), filter_backprop, CreateExpandedZero(filter_shape, dtype, builder)); - return builder->Reshape( + return xla::Reshape( // This reduce does not need inputs to be converted with // XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with // ExpandedZero guarantees that only one element is non zero, so there // cannot be accumulated precision error. - builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), - *ctx->GetOrCreateAdd(dtype), - {expanded_filter_shape.dims() - 2}), + xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), + *ctx->GetOrCreateAdd(dtype), + {expanded_filter_shape.dims() - 2}), filter_shape.dim_sizes()); } @@ -289,8 +286,8 @@ class ConvOp : public XlaOpKernel { } xla::XlaOp conv = - b->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, - lhs_dilation, rhs_dilation, dims); + xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, + lhs_dilation, rhs_dilation, dims); ctx->SetOutput(0, conv); } @@ -435,11 +432,11 @@ class ConvBackpropInputOp : public XlaOpKernel { } // Mirror the filter in the spatial dimensions. - xla::XlaOp mirrored_weights = b->Rev(filter, kernel_spatial_dims); + xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims); // activation gradients // = gradients (with padding and dilation) mirrored_weights - xla::XlaOp in_backprop = b->ConvGeneralDilated( + xla::XlaOp in_backprop = xla::ConvGeneralDilated( out_backprop, mirrored_weights, /*window_strides=*/ones, padding, lhs_dilation, rhs_dilation, dnums); @@ -638,8 +635,8 @@ class ConvBackpropFilterOp : public XlaOpKernel { // This is done by specifying the window dilation factors in the // convolution HLO below. auto filter_backprop = - b->ConvGeneralDilated(activations, gradients, window_strides, padding, - /*lhs_dilation=*/ones, rhs_dilation, dnums); + xla::ConvGeneralDilated(activations, gradients, window_strides, padding, + /*lhs_dilation=*/ones, rhs_dilation, dnums); if (depthwise_) { filter_backprop = ContractFilterForDepthwiseBackprop( diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc index 7fcd4170fb79a574663c1abffe873d4b53f471d3..db579a5b35d69deb3dca578e31c1b54fada76342 100644 --- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { @@ -58,21 +59,21 @@ class CrossOp : public XlaOpKernel { auto in1 = ctx->Input(1); starts.back() = 0; limits.back() = 1; - auto u1 = b->Slice(in0, starts, limits, strides); - auto v1 = b->Slice(in1, starts, limits, strides); + auto u1 = xla::Slice(in0, starts, limits, strides); + auto v1 = xla::Slice(in1, starts, limits, strides); starts.back() = 1; limits.back() = 2; - auto u2 = b->Slice(in0, starts, limits, strides); - auto v2 = b->Slice(in1, starts, limits, strides); + auto u2 = xla::Slice(in0, starts, limits, strides); + auto v2 = xla::Slice(in1, starts, limits, strides); starts.back() = 2; limits.back() = 3; - auto u3 = b->Slice(in0, starts, limits, strides); - auto v3 = b->Slice(in1, starts, limits, strides); + auto u3 = xla::Slice(in0, starts, limits, strides); + auto v3 = xla::Slice(in1, starts, limits, strides); - auto s1 = b->Sub(b->Mul(u2, v3), b->Mul(u3, v2)); - auto s2 = b->Sub(b->Mul(u3, v1), b->Mul(u1, v3)); - auto s3 = b->Sub(b->Mul(u1, v2), b->Mul(u2, v1)); - auto output = b->ConcatInDim({s1, s2, s3}, in0_shape.dims() - 1); + auto s1 = xla::Sub(xla::Mul(u2, v3), xla::Mul(u3, v2)); + auto s2 = xla::Sub(xla::Mul(u3, v1), xla::Mul(u1, v3)); + auto s3 = xla::Sub(xla::Mul(u1, v2), xla::Mul(u2, v1)); + auto output = xla::ConcatInDim(b, {s1, s2, s3}, in0_shape.dims() - 1); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index 01aa1a83e7967921f1583b3ef18ec57e452dcfea..ef1015552d181a183d412f9c269dd5ec608b388f 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/util/bcast.h" @@ -96,18 +96,16 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { // First reshape the inputs, which should be a metadata-only // operation since we are flattening the dimensions in order. - auto lhs_shaped = builder->Reshape(lhs, broadcast_helper.x_reshape()); - auto rhs_shaped = builder->Reshape(rhs, broadcast_helper.y_reshape()); + auto lhs_shaped = xla::Reshape(lhs, broadcast_helper.x_reshape()); + auto rhs_shaped = xla::Reshape(rhs, broadcast_helper.y_reshape()); // Next broadcast the necessary input dimensions. We rely on the // XLA optimizer to be smart about the fact that we are asking // it to broadcast size 1 on some of these dimensions, to avoid // adding complexity to this code. - auto lhs_broadcast = - builder->Broadcast(lhs_shaped, broadcast_helper.x_bcast()); + auto lhs_broadcast = xla::Broadcast(lhs_shaped, broadcast_helper.x_bcast()); int lhs_size = broadcast_helper.x_bcast().size(); - auto rhs_broadcast = - builder->Broadcast(rhs_shaped, broadcast_helper.y_bcast()); + auto rhs_broadcast = xla::Broadcast(rhs_shaped, broadcast_helper.y_bcast()); int rhs_size = broadcast_helper.y_bcast().size(); // Now reshape them to the correct output shape. After the @@ -122,15 +120,15 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { lhs_reorder.push_back(i); lhs_reorder.push_back(i + lhs_size); } - auto lhs_output = builder->Reshape(lhs_broadcast, lhs_reorder, - broadcast_helper.output_shape()); + auto lhs_output = + xla::Reshape(lhs_broadcast, lhs_reorder, broadcast_helper.output_shape()); std::vector rhs_reorder; for (int i = 0; i < rhs_size; ++i) { rhs_reorder.push_back(i); rhs_reorder.push_back(i + rhs_size); } - auto rhs_output = builder->Reshape(rhs_broadcast, rhs_reorder, - broadcast_helper.output_shape()); + auto rhs_output = + xla::Reshape(rhs_broadcast, rhs_reorder, broadcast_helper.output_shape()); return {lhs_output, rhs_output}; } diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index 4f92dbc8740b697322424058530b8477c35d809a..a5b870f8dbf70bcee331992345d63fd5d986bdca 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/bcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index 23243f62462c6315e359d9621823b19fc98c6218..12b0e38288e8f222ed506a75ec2575f27141c859 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -50,7 +51,6 @@ class DepthToSpaceOp : public XlaOpKernel { const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); - xla::XlaBuilder* b = ctx->builder(); xla::XlaOp input = ctx->Input(0); int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); @@ -130,7 +130,7 @@ class DepthToSpaceOp : public XlaOpKernel { ") is not divisible by square of the block size (", block_size_, ")")); - xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -141,7 +141,7 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2], // block_size_, // depth / (block_size_ * block_size_)] - xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order); + xla::XlaOp permuted_reshaped = xla::Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -151,7 +151,7 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2] * block_size_, // depth / (block_size_ * block_size_)] // - xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape); + xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 931705ba837153e1175cd9a209876ef5ec93f0fc..ed44ad218b6dc073583ec339da082b6881ad672d 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -18,6 +18,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -25,10 +28,10 @@ namespace tensorflow { namespace { // Create a diagonal / batch diagonal matrix with 'input' on the diagonal. -xla::StatusOr CreateDiagonal( - const xla::XlaOp& input, int64 last_dim_size, - tensorflow::gtl::ArraySlice other_dims, XlaOpKernelContext* ctx, - xla::XlaBuilder* builder) { +xla::XlaOp CreateDiagonal(xla::XlaOp input, int64 last_dim_size, + gtl::ArraySlice other_dims, + xla::PrimitiveType element_type) { + xla::XlaBuilder* builder = input.builder(); // Create two matrices that have the following forms, and compare them: // // [[0, 0, 0, 0] [[0, 1, 2, 3] @@ -38,16 +41,14 @@ xla::StatusOr CreateDiagonal( // // This produces a predicate matrix of the right size, with "true" on the // diagonal. - xla::XlaOp iota; - TF_RETURN_IF_ERROR( - XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota)); - xla::XlaOp iota_broadcast = builder->Broadcast(iota, {last_dim_size}); - xla::XlaOp mask = builder->Eq(iota_broadcast, iota, {0}); + xla::XlaOp iota = xla::Iota(builder, xla::S32, last_dim_size); + xla::XlaOp iota_broadcast = xla::Broadcast(iota, {last_dim_size}); + xla::XlaOp mask = xla::Eq(iota_broadcast, iota, {0}); // If this is a batched diagonal, broadcast the mask across the other // dimensions. if (!other_dims.empty()) { - mask = builder->Broadcast(mask, other_dims); + mask = xla::Broadcast(mask, other_dims); } // Broadcast the input, and then use the mask computed above to select the @@ -64,18 +65,15 @@ xla::StatusOr CreateDiagonal( std::vector broadcast_dims(other_dims.begin(), other_dims.end()); broadcast_dims.push_back(1LL); broadcast_dims.push_back(last_dim_size); - xla::XlaOp input_broadcast = builder->Reshape(input, broadcast_dims); + xla::XlaOp input_broadcast = xla::Reshape(input, broadcast_dims); broadcast_dims[broadcast_dims.size() - 2] = last_dim_size; - xla::PrimitiveType element_type; - TF_RETURN_IF_ERROR( - DataTypeToPrimitiveType(ctx->input_type(0), &element_type)); auto broadcast_shape = xla::ShapeUtil::MakeShape(element_type, broadcast_dims); - xla::XlaOp zeros = Zeros(builder, broadcast_shape); + xla::XlaOp zeros = xla::Zeros(builder, broadcast_shape); - input_broadcast = builder->Add(input_broadcast, zeros); - return builder->Select(mask, input_broadcast, zeros); + input_broadcast = xla::Add(input_broadcast, zeros); + return xla::Select(mask, input_broadcast, zeros); } class DiagOp : public XlaOpKernel { @@ -83,8 +81,6 @@ class DiagOp : public XlaOpKernel { explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); - OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("Diag op must have at an input")); const TensorShape input_shape = ctx->InputShape(0); @@ -104,19 +100,17 @@ class DiagOp : public XlaOpKernel { // Flattens the input to 1D. int64 size = input_shape.num_elements(); - input = builder->Reshape(input, {size}); + input = xla::Reshape(input, {size}); // Create an R2 with the R1 diagonal. - auto diag_or_status = - CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder); - OP_REQUIRES_OK(ctx, diag_or_status.status()); - xla::XlaOp diag = diag_or_status.ValueOrDie(); + xla::XlaOp diag = + CreateDiagonal(input, size, /*other_dims=*/{}, ctx->input_xla_type(0)); // Reshapes to the final shape. std::vector new_dims(dims.size() * 2); std::copy(dims.begin(), dims.end(), new_dims.begin()); std::copy(dims.begin(), dims.end(), new_dims.begin() + dims.size()); - diag = builder->Reshape(diag, new_dims); + diag = xla::Reshape(diag, new_dims); ctx->SetOutput(0, diag); } @@ -129,8 +123,6 @@ class DiagPartOp : public XlaOpKernel { explicit DiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); - const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -156,37 +148,13 @@ class DiagPartOp : public XlaOpKernel { new_dims.push_back(dims[i]); } - xla::XlaOp diag = ctx->Input(0); - - // TODO(b/30878775): use Slice with strides when supported, in place of - // the Pad -> Reshape -> Slice. - - // Picture: - // [[1, 0, 0, 0] pad and reshape to [[1, 0, 0, 0, 0], - // [0, 2, 0, 0] =================> [2, 0, 0, 0, 0], - // [0, 0, 3, 0] [3, 0, 0, 0, 0], - // [0, 0, 0, 4]] [4, 0, 0, 0, 0]] - // and then slice out the first column. - - // Flattens the input to 1D. - int64 size = input_shape.num_elements(); - diag = builder->Reshape(diag, {size}); - - // Adds padding after the last element of 'new_size'. - xla::PaddingConfig config; - auto* dim = config.add_dimensions(); - dim->set_edge_padding_high(new_size); - auto zero = XlaHelpers::Zero(builder, input_type(0)); - diag = builder->Pad(diag, zero, config); - - // Reshapes so the diagonal is now in the first column. - diag = builder->Reshape(diag, {new_size, new_size + 1}); + xla::XlaOp input = ctx->Input(0); - // Slices out the first column and reshapes to the final shape. - diag = builder->Slice(diag, {0, 0}, {new_size, 1}, {1, 1}); - diag = builder->Reshape(diag, new_dims); + xla::XlaOp output = xla::Reshape( + xla::GetMatrixDiagonal(xla::Reshape(input, {new_size, new_size})), + new_dims); - ctx->SetOutput(0, diag); + ctx->SetOutput(0, output); } }; @@ -197,8 +165,6 @@ class MatrixDiagOp : public XlaOpKernel { explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); - OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("MatrixDiag op must have at an input")); const TensorShape input_shape = ctx->InputShape(0); @@ -208,17 +174,15 @@ class MatrixDiagOp : public XlaOpKernel { errors::InvalidArgument("Expected 1 <= dims, got shape ", input_shape.DebugString())); - xla::XlaOp diag = ctx->Input(0); int last_dim = dims.size() - 1; int64 last_dim_size = input_shape.dim_size(last_dim); tensorflow::gtl::ArraySlice other_dims(dims); other_dims.pop_back(); - auto diag_or_status = - CreateDiagonal(diag, last_dim_size, other_dims, ctx, builder); - OP_REQUIRES_OK(ctx, diag_or_status.status()); - diag = diag_or_status.ValueOrDie(); + xla::XlaOp input = ctx->Input(0); + xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims, + ctx->input_xla_type(0)); ctx->SetOutput(0, diag); } }; @@ -230,8 +194,6 @@ class MatrixDiagPartOp : public XlaOpKernel { explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); - const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -239,71 +201,8 @@ class MatrixDiagPartOp : public XlaOpKernel { errors::InvalidArgument("Expected 2 <= dims, got shape ", input_shape.DebugString())); - xla::XlaOp diag = ctx->Input(0); - - int last_dim = dims.size() - 1; - int64 last_dim_size = dims[last_dim]; - - // The smaller of the last two dimension sizes. - int64 smaller_dim_size = std::min(dims[last_dim - 1], dims[last_dim]); - - // TODO(b/30878775): use Slice with strides when supported, in place of - // the Pad -> Reshape -> Slice. - - // Picture: for each 2D matrix in the tensor's last two dimensions: - // [[1, 0, 0, 0] pad and reshape to [[1, 0, 0, 0, 0], - // [0, 2, 0, 0] =================> [2, 0, 0, 0, 0], - // [0, 0, 3, 0]] [3, 0, 0, 0, 0], - // and then slice out the first column. - // - // Another example, with tall and narrow input. - // [[1, 0] pad and reshape to [[1, 0, 0], - // [0, 2] =================> [2, 0, 0]] - // [0, 0] - // [0, 0]] - - // Collapses the last two dimensions. - std::vector flattened_dims(dims.begin(), dims.end() - 1); - flattened_dims.back() *= dims.back(); - diag = builder->Reshape(diag, flattened_dims); - - // Slices or pads the last dimension to 'target_size'. - int64 actual_size = flattened_dims.back(); - int64 target_size = smaller_dim_size * (last_dim_size + 1); - if (actual_size < target_size) { - xla::PaddingConfig config = - xla::MakeNoPaddingConfig(flattened_dims.size()); - auto* dim = config.mutable_dimensions(flattened_dims.size() - 1); - dim->set_edge_padding_high(target_size - actual_size); - auto zero = XlaHelpers::Zero(builder, input_type(0)); - diag = builder->Pad(diag, zero, config); - } else if (actual_size > target_size) { - std::vector start(flattened_dims.size(), 0); - std::vector limits(flattened_dims.begin(), flattened_dims.end()); - std::vector strides(flattened_dims.size(), 1); - limits[flattened_dims.size() - 1] = target_size; - diag = builder->Slice(diag, start, limits, strides); - } - - // Reshape so the target values are in the first position of the last - // dimension. - std::vector unflattened_dims(dims.begin(), dims.end()); - dims[last_dim - 1] = smaller_dim_size; - dims[last_dim] = last_dim_size + 1; - diag = builder->Reshape(diag, dims); - - // Slices out the first column and reshapes to the final shape. - std::vector start(dims.size(), 0); - std::vector limits(dims.begin(), dims.end()); - std::vector strides(dims.size(), 1); - limits[last_dim] = 1; - diag = builder->Slice(diag, start, limits, strides); - - // Collapses away the last dimension. - dims.pop_back(); - diag = builder->Reshape(diag, dims); - - ctx->SetOutput(0, diag); + xla::XlaOp input = ctx->Input(0); + ctx->SetOutput(0, xla::GetMatrixDiagonal(input)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index 0419de78b2ee83fd395e8bf23444fde84f30bba2..a3389d5b905bf3ee15744ab4fcee193d312e2ae0 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -57,8 +57,8 @@ class DynamicUpdateSliceOp : public XlaOpKernel { input_shape.DebugString(), "; update shape is ", update_shape.DebugString())); - xla::XlaOp result = ctx->builder()->DynamicUpdateSlice( - ctx->Input(0), ctx->Input(1), ctx->Input(2)); + xla::XlaOp result = + xla::DynamicUpdateSlice(ctx->Input(0), ctx->Input(1), ctx->Input(2)); ctx->SetOutput(0, result); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index dd4a16908779508380b36f43ce2306ff2f5fb8c4..cb73053666d4c32bc0a2ef19b174aee1a29f101e 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -150,8 +151,7 @@ class DynamicStitchOp : public XlaOpKernel { if (new_shape == data_shapes[input_num]) { input[input_num] = handle; } else { - input[input_num] = - ctx->builder()->Reshape(handle, new_shape.dim_sizes()); + input[input_num] = xla::Reshape(handle, new_shape.dim_sizes()); } } @@ -175,10 +175,10 @@ class DynamicStitchOp : public XlaOpKernel { // And place it in the concat list in the place indicated by // the index. to_concat[index_num] = - ctx->builder()->Slice(expression, slice_start, slice_limit, stride); + xla::Slice(expression, slice_start, slice_limit, stride); } - ctx->SetOutput(0, ctx->builder()->ConcatInDim(to_concat, 0)); + ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), to_concat, 0)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc index 493781a1e68b8906f1a7e018e5710130e2eb08b5..5fdb1d972c55efb876972d3f472b53a1f7cde1c2 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" @@ -34,9 +34,9 @@ class EluOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); - const auto pred = b->Gt(ctx->Input(0), zero); - const auto expm1 = b->Expm1(ctx->Input(0)); - ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1)); + const auto pred = xla::Gt(ctx->Input(0), zero); + const auto expm1 = xla::Expm1(ctx->Input(0)); + ctx->SetOutput(0, xla::Select(pred, ctx->Input(0), expm1)); } }; @@ -51,9 +51,9 @@ class EluGradOp : public XlaOpKernel { const auto one = XlaHelpers::One(b, input_type(0)); const auto grad = ctx->Input(0); const auto activation = ctx->Input(1); - const auto exp_grad = b->Mul(grad, b->Add(activation, one)); - const auto pred = b->Gt(activation, zero); - ctx->SetOutput(0, b->Select(pred, grad, exp_grad)); + const auto exp_grad = xla::Mul(grad, xla::Add(activation, one)); + const auto pred = xla::Gt(activation, zero); + ctx->SetOutput(0, xla::Select(pred, grad, exp_grad)); } }; @@ -71,10 +71,10 @@ class SeluOp : public XlaOpKernel { 1.0507009873554804934193349852946); const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0), 1.7580993408473768599402175208123); - const auto pred = b->Gt(ctx->Input(0), zero); - const auto expm1 = b->Expm1(ctx->Input(0)); - ctx->SetOutput(0, b->Select(pred, b->Mul(scale, ctx->Input(0)), - b->Mul(scale_alpha, expm1))); + const auto pred = xla::Gt(ctx->Input(0), zero); + const auto expm1 = xla::Expm1(ctx->Input(0)); + ctx->SetOutput(0, xla::Select(pred, xla::Mul(scale, ctx->Input(0)), + xla::Mul(scale_alpha, expm1))); } }; @@ -92,10 +92,10 @@ class SeluGradOp : public XlaOpKernel { 1.7580993408473768599402175208123); const auto grad = ctx->Input(0); const auto activation = ctx->Input(1); - const auto lin_grad = b->Mul(grad, scale); - const auto exp_grad = b->Mul(grad, b->Add(activation, scale_alpha)); - const auto pred = b->Gt(activation, zero); - ctx->SetOutput(0, b->Select(pred, lin_grad, exp_grad)); + const auto lin_grad = xla::Mul(grad, scale); + const auto exp_grad = xla::Mul(grad, xla::Add(activation, scale_alpha)); + const auto pred = xla::Gt(activation, zero); + ctx->SetOutput(0, xla::Select(pred, lin_grad, exp_grad)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index 6df01cabbf1d98c0299bfd808bcc6db6223c4777..c68b0bfd7961892294c2931e5c4c44de534a7740 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -17,6 +17,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -110,13 +112,11 @@ class ExtractImagePatchesOp : public XlaOpKernel { // Builds an identity matrix as a broadcast equality of iotas. // iota = np.arange(np.prod(ksize), depth) // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32) - xla::XlaOp iota; - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, - kernel_size * depth, &iota)); + xla::XlaOp iota = xla::Iota(builder, xla::S32, kernel_size * depth); - auto lhs = builder->Reshape(iota, lhs_shape); - auto filter = builder->ConvertElementType( - builder->Eq(lhs, iota, {num_spatial_dims + 1}), type); + auto lhs = xla::Reshape(iota, lhs_shape); + auto filter = xla::ConvertElementType( + xla::Eq(lhs, iota, {num_spatial_dims + 1}), type); xla::ConvolutionDimensionNumbers dims; std::vector window_strides(num_spatial_dims); @@ -148,8 +148,8 @@ class ExtractImagePatchesOp : public XlaOpKernel { } xla::XlaOp conv = - builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides, - padding, lhs_dilation, rhs_dilation, dims); + xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, + lhs_dilation, rhs_dilation, dims); ctx->SetOutput(0, conv); } diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 8f0de0a524c908b598c1a2165a462275346ad137..cdba6680dee3fade5bdf0c453ed672b653072b0d 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -49,20 +50,20 @@ void XlaNudge(xla::XlaBuilder* b, const DataType data_type, const float quant_min_value, const float quant_max_value, xla::XlaOp* nudged_min, xla::XlaOp* nudged_max, xla::XlaOp* scale) { - *scale = b->Div(b->Sub(max, min), - XlaHelpers::FloatLiteral(b, data_type, - quant_max_value - quant_min_value)); + *scale = xla::Div(xla::Sub(max, min), + XlaHelpers::FloatLiteral( + b, data_type, quant_max_value - quant_min_value)); xla::XlaOp quant_min = XlaHelpers::FloatLiteral(b, data_type, quant_min_value); - xla::XlaOp zero_point_from_min = b->Sub(quant_min, b->Div(min, *scale)); + xla::XlaOp zero_point_from_min = xla::Sub(quant_min, xla::Div(min, *scale)); xla::XlaOp quant_max = XlaHelpers::FloatLiteral(b, data_type, quant_max_value); xla::XlaOp nudged_zero_point = - b->Select(b->Le(zero_point_from_min, quant_min), quant_min, - b->Select(b->Ge(zero_point_from_min, quant_max), quant_max, - b->Round(zero_point_from_min))); - *nudged_min = b->Mul(b->Sub(quant_min, nudged_zero_point), *scale); - *nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale); + xla::Select(xla::Le(zero_point_from_min, quant_min), quant_min, + xla::Select(xla::Ge(zero_point_from_min, quant_max), + quant_max, xla::Round(zero_point_from_min))); + *nudged_min = xla::Mul(xla::Sub(quant_min, nudged_zero_point), *scale); + *nudged_max = xla::Mul(xla::Sub(quant_max, nudged_zero_point), *scale); } xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input, @@ -71,14 +72,14 @@ xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input, const xla::XlaOp& nudged_input_max, const xla::XlaOp& input_scale) { xla::XlaOp one = XlaHelpers::FloatLiteral(b, data_type, 1.0f); - xla::XlaOp inv_scale = b->Div(one, input_scale); + xla::XlaOp inv_scale = xla::Div(one, input_scale); xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5f); - xla::XlaOp clamped = b->Clamp(nudged_input_min, input, nudged_input_max); - xla::XlaOp clamped_shifted = b->Sub(clamped, nudged_input_min); + xla::XlaOp clamped = xla::Clamp(nudged_input_min, input, nudged_input_max); + xla::XlaOp clamped_shifted = xla::Sub(clamped, nudged_input_min); xla::XlaOp rounded = - b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half)); - return b->Add(b->Mul(rounded, input_scale), nudged_input_min); + xla::Floor(xla::Add(xla::Mul(clamped_shifted, inv_scale), half)); + return xla::Add(xla::Mul(rounded, input_scale), nudged_input_min); } class FakeQuantWithMinMaxArgsOp : public XlaOpKernel { @@ -163,11 +164,11 @@ class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel { xla::XlaOp nudged_input_max = XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_); - xla::XlaOp between_nudged_min_max = - b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); - xla::XlaOp zeroes = b->Broadcast(XlaHelpers::Zero(b, data_type), - gradient_shape.dim_sizes()); - xla::XlaOp output = b->Select(between_nudged_min_max, gradient, zeroes); + xla::XlaOp between_nudged_min_max = xla::And( + xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max)); + xla::XlaOp zeroes = xla::Broadcast(XlaHelpers::Zero(b, data_type), + gradient_shape.dim_sizes()); + xla::XlaOp output = xla::Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output); } @@ -249,25 +250,25 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_, &nudged_input_min, &nudged_input_max, &input_scale); - xla::XlaOp between_nudged_min_max = - b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); + xla::XlaOp between_nudged_min_max = xla::And( + xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max)); xla::XlaOp zero = XlaHelpers::Zero(b, data_type); - xla::XlaOp zeroes = b->Broadcast(zero, gradient_shape.dim_sizes()); - xla::XlaOp output0 = b->Select(between_nudged_min_max, gradient, zeroes); + xla::XlaOp zeroes = xla::Broadcast(zero, gradient_shape.dim_sizes()); + xla::XlaOp output0 = xla::Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output0); - xla::XlaOp below_min = b->Lt(input, nudged_input_min); - xla::XlaOp select1 = b->Select(below_min, gradient, zeroes); - xla::XlaOp reduce1 = b->ReduceAll( + xla::XlaOp below_min = xla::Lt(input, nudged_input_min); + xla::XlaOp select1 = xla::Select(below_min, gradient, zeroes); + xla::XlaOp reduce1 = xla::ReduceAll( XlaHelpers::ConvertElementType(b, select1, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type)); xla::XlaOp output1 = XlaHelpers::ConvertElementType(b, reduce1, data_type); ctx->SetOutput(1, output1); - xla::XlaOp above_max = b->Gt(input, nudged_input_max); - xla::XlaOp select2 = b->Select(above_max, gradient, zeroes); - xla::XlaOp reduce2 = b->ReduceAll( + xla::XlaOp above_max = xla::Gt(input, nudged_input_max); + xla::XlaOp select2 = xla::Select(above_max, gradient, zeroes); + xla::XlaOp reduce2 = xla::ReduceAll( XlaHelpers::ConvertElementType(b, select2, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type)); diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index 933924cad1c7cac2879bd4720cb21ffc33c23f50..80bcef966360ec9a1ca63a02741108ce41b31846 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -62,8 +63,7 @@ class GenericFftOp : public XlaOpKernel { } } - xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp fft = b->Fft(ctx->Input(0), fft_type_, fft_length); + xla::XlaOp fft = xla::Fft(ctx->Input(0), fft_type_, fft_length); ctx->SetOutput(0, fft); } diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index e4467a0fb138ed7919af62ed032c0f5abee3e4f6..54b21a278229024e3e54e9135548be6b69b077e1 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" @@ -59,11 +60,11 @@ class FillOp : public XlaOpKernel { xla::XlaOp data = ctx->Input(1); if (value_shape.dims() > 0) { CHECK_EQ(value_shape.dims(), 1); - data = ctx->builder()->Reshape(data, {}); + data = xla::Reshape(data, {}); } // Emit the actual computation, which broadcasts the scalar to the // desired shape. - auto result = ctx->builder()->Broadcast(data, broadcast); + auto result = xla::Broadcast(data, broadcast); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index d13e25bcddae16d0cd630403219657121b80868d..35de96e0aab847fa39ef26d5f3052c392062fd7d 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -75,8 +76,8 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, out_shape.AppendShape(indices_shape_no_index_vectors); out_shape.AppendShape(input_shape_post_axis); - *gather_output = builder->Broadcast(XlaHelpers::Zero(builder, dtype), - out_shape.dim_sizes()); + *gather_output = + xla::Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes()); return Status::OK(); } @@ -142,7 +143,7 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, dim_numbers.add_gather_dims_to_operand_dims(i); } - *gather_output = builder->Gather(input, indices, dim_numbers, window_bounds); + *gather_output = xla::Gather(input, indices, dim_numbers, window_bounds); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h index d898e43b858bac706d524c7c271f48b1b5fa258f..92346283c31dfe1d638526ac4b26ef762cd7fd14 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/bcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index d48c6eea754f75a8879d3938f233a6a591d26d0d..6a7eb8d90c45ab119096eaa259e05c6ca768c5aa 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { @@ -199,25 +200,35 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { } } - xla::XlaOp outputs = - b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation, - b->Tuple(inputs), *else_result.computation); + bool resource_variable_seen = false; + for (int i = 0; i < ctx->num_inputs(); ++i) { + if (ctx->input_type(i) == DT_RESOURCE) { + resource_variable_seen = true; + } else { + OP_REQUIRES( + ctx, !resource_variable_seen, + errors::FailedPrecondition( + "Resource variables and regular inputs cannot be interleaved.")); + } + } + + xla::XlaOp outputs = xla::Conditional( + ctx->Input(0), xla::Tuple(b, inputs), *then_result.computation, + xla::Tuple(b, inputs), *else_result.computation); // Sets non-variable outputs. for (int i = 0; i < output_types_.size(); ++i) { - if (ctx->input_type(i) != DT_RESOURCE) { - xla::XlaOp output_handle = b->GetTupleElement(outputs, i); - if (VLOG_IS_ON(2)) { - LOG(INFO) << "Setting output " << i; - auto shape_or = b->GetShape(output_handle); - if (shape_or.ok()) { - LOG(INFO) << "Shape for output " << i << ": " - << xla::ShapeUtil::HumanString(shape_or.ValueOrDie()); - } else { - LOG(INFO) << "Shape unknown for output " << i; - } + xla::XlaOp output_handle = xla::GetTupleElement(outputs, i); + if (VLOG_IS_ON(2)) { + LOG(INFO) << "Setting output " << i; + auto shape_or = b->GetShape(output_handle); + if (shape_or.ok()) { + LOG(INFO) << "Shape for output " << i << ": " + << xla::ShapeUtil::HumanString(shape_or.ValueOrDie()); + } else { + LOG(INFO) << "Shape unknown for output " << i; } - ctx->SetOutput(i, output_handle); } + ctx->SetOutput(i, output_handle); } // Updates the values of any resource variables modified by the conditional @@ -233,7 +244,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, resource->SetFromPack( arguments[update.input_index].tensor_array_gradients, - b->GetTupleElement(outputs, pos), b)); + xla::GetTupleElement(outputs, pos), b)); } VLOG(2) << "If variable: pos: " << update.input_index << " name: " << resource->name() @@ -245,6 +256,8 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Done building If"; } +REGISTER_XLA_OP(Name("If").AllowResourceTypes(), XlaIfOp); +REGISTER_XLA_OP(Name("StatelessIf").AllowResourceTypes(), XlaIfOp); REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 1568b33679963c1a6630525f60560180d40b8d53..33a73fe5fdf403e513be085dd7bcea3255277b4a 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -17,6 +17,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace { @@ -32,23 +38,26 @@ std::array RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b, auto red = rgb[0]; auto green = rgb[1]; auto blue = rgb[2]; - auto value = b->Max(b->Max(red, green), blue); - auto minimum = b->Min(b->Min(red, green), blue); - auto range = b->Sub(value, minimum); - - auto zeros = b->Broadcast(zero, shape.dim_sizes()); - auto saturation = b->Select(b->Gt(value, zero), b->Div(range, value), zeros); - - auto norm = b->Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range); - - auto hue = b->Select(b->Eq(green, value), - b->Add(b->Mul(norm, b->Sub(blue, red)), - XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)), - b->Add(b->Mul(norm, b->Sub(red, green)), - XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0))); - hue = b->Select(b->Eq(red, value), b->Mul(norm, b->Sub(green, blue)), hue); - hue = b->Select(b->Gt(range, zero), hue, zeros); - hue = b->Select(b->Lt(hue, zero), b->Add(hue, one), hue); + auto value = xla::Max(xla::Max(red, green), blue); + auto minimum = xla::Min(xla::Min(red, green), blue); + auto range = xla::Sub(value, minimum); + + auto zeros = xla::Broadcast(zero, shape.dim_sizes()); + auto saturation = + xla::Select(xla::Gt(value, zero), xla::Div(range, value), zeros); + + auto norm = xla::Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range); + + auto hue = + xla::Select(xla::Eq(green, value), + xla::Add(xla::Mul(norm, xla::Sub(blue, red)), + XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)), + xla::Add(xla::Mul(norm, xla::Sub(red, green)), + XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0))); + hue = xla::Select(xla::Eq(red, value), xla::Mul(norm, xla::Sub(green, blue)), + hue); + hue = xla::Select(xla::Gt(range, zero), hue, zeros); + hue = xla::Select(xla::Lt(hue, zero), xla::Add(hue, one), hue); return {hue, saturation, value}; } @@ -66,15 +75,15 @@ std::array HSVToRGB(xla::XlaBuilder* b, auto four = XlaHelpers::FloatLiteral(b, dtype, 4.0); auto six = XlaHelpers::FloatLiteral(b, dtype, 6.0); - auto dh = b->Mul(hue, six); - auto dr = b->Clamp(zero, b->Sub(b->Abs(b->Sub(dh, three)), one), one); - auto dg = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, two))), one); - auto db = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, four))), one); - auto one_minus_s = b->Sub(one, saturation); + auto dh = xla::Mul(hue, six); + auto dr = xla::Clamp(zero, xla::Sub(xla::Abs(xla::Sub(dh, three)), one), one); + auto dg = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, two))), one); + auto db = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, four))), one); + auto one_minus_s = xla::Sub(one, saturation); - auto red = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dr)), value); - auto green = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dg)), value); - auto blue = b->Mul(b->Add(one_minus_s, b->Mul(saturation, db)), value); + auto red = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dr)), value); + auto green = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dg)), value); + auto blue = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, db)), value); return {red, green, blue}; } @@ -97,21 +106,21 @@ class RGBToHSVOp : public XlaOpKernel { xla::XlaBuilder* b = context->builder(); xla::XlaOp input = context->Input(0); - xla::XlaOp red = - b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp green = - b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp blue = - b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, - /*dimno=*/channel_dim); + xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0, + /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1, + /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2, + /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; channel_shape.set_dim(channel_dim, 1); auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), channel_shape); - context->SetOutput(0, b->ConcatInDim(hsv, channel_dim)); + context->SetOutput(0, xla::ConcatInDim(b, hsv, channel_dim)); } }; REGISTER_XLA_OP(Name("RGBToHSV"), RGBToHSVOp); @@ -134,20 +143,20 @@ class HSVToRGBOp : public XlaOpKernel { xla::XlaBuilder* b = context->builder(); xla::XlaOp input = context->Input(0); - xla::XlaOp hue = - b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp saturation = - b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp value = - b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, - /*dimno=*/channel_dim); + xla::XlaOp hue = xla::SliceInDim(input, /*start_index=*/0, + /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp saturation = xla::SliceInDim(input, /*start_index=*/1, + /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp value = xla::SliceInDim(input, /*start_index=*/2, + /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); auto rgb = HSVToRGB(context->builder(), {hue, saturation, value}, context->input_type(0)); - context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); + context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); } }; REGISTER_XLA_OP(Name("HSVToRGB"), HSVToRGBOp); @@ -182,18 +191,20 @@ class AdjustContrastOpV2 : public XlaOpKernel { const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); auto converted = XlaHelpers::ConvertElementType(b, input, accumulation_type); - auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *context->GetOrCreateAdd(accumulation_type), - {height_dim, width_dim}); + auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *context->GetOrCreateAdd(accumulation_type), + {height_dim, width_dim}); auto output = XlaHelpers::ConvertElementType(b, reduce, type); - output = b->Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); + output = + xla::Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); std::vector broadcast_dims(input_shape.dims() - 2); std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); broadcast_dims.back() = channel_dim; - output = b->Add(b->Mul(input, factor), - b->Mul(output, b->Sub(XlaHelpers::One(b, type), factor)), - broadcast_dims); + output = + xla::Add(xla::Mul(input, factor), + xla::Mul(output, xla::Sub(XlaHelpers::One(b, type), factor)), + broadcast_dims); context->SetOutput(0, output); } }; @@ -226,26 +237,26 @@ class AdjustSaturationOp : public XlaOpKernel { DataType type = context->input_type(0); - xla::XlaOp red = - b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp green = - b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp blue = - b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, - /*dimno=*/channel_dim); + xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0, + /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1, + /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2, + /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; channel_shape.set_dim(channel_dim, 1); auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), channel_shape); - hsv[1] = b->Clamp(XlaHelpers::Zero(b, type), b->Mul(hsv[1], scale), - XlaHelpers::One(b, type)); + hsv[1] = xla::Clamp(XlaHelpers::Zero(b, type), xla::Mul(hsv[1], scale), + XlaHelpers::One(b, type)); auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); - context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); + context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); } }; REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp); @@ -276,15 +287,15 @@ class AdjustHueOp : public XlaOpKernel { DataType type = context->input_type(0); - xla::XlaOp red = - b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp green = - b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, - /*dimno=*/channel_dim); - xla::XlaOp blue = - b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, - /*dimno=*/channel_dim); + xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0, + /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1, + /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2, + /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; channel_shape.set_dim(channel_dim, 1); auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), @@ -294,15 +305,161 @@ class AdjustHueOp : public XlaOpKernel { auto one = XlaHelpers::One(b, type); auto& hue = hsv[0]; - hue = b->Rem(b->Add(hsv[0], delta), one); - hue = b->Select(b->Lt(hue, zero), b->Rem(b->Add(one, hue), one), hue); + hue = xla::Rem(xla::Add(hsv[0], delta), one); + hue = + xla::Select(xla::Lt(hue, zero), xla::Rem(xla::Add(one, hue), one), hue); auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); - context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); + context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim)); } }; REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp); +class NonMaxSuppressionOp : public XlaOpKernel { + public: + explicit NonMaxSuppressionOp(OpKernelConstruction* context) + : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size", + &pad_to_max_output_size_)); + } + + void Compile(XlaOpKernelContext* context) override { + // TODO(b/111646731): Improve scalability of this op, using blocking. + int num_boxes_dim = 0; + int coords_dim = 1; + const TensorShape& boxes_shape = context->InputShape("boxes"); + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(boxes_shape), + errors::InvalidArgument("boxes must be 2-D, currently: ", + boxes_shape.DebugString())); + const int64 num_boxes = boxes_shape.dim_size(num_boxes_dim); + OP_REQUIRES(context, boxes_shape.dim_size(coords_dim) == 4, + errors::InvalidArgument("boxes must have 4 columns", + boxes_shape.DebugString())); + const TensorShape& scores_shape = context->InputShape("scores"); + OP_REQUIRES(context, TensorShapeUtils::IsVector(scores_shape), + errors::InvalidArgument("scores must be 1-D, currently: ", + scores_shape.DebugString())); + OP_REQUIRES( + context, scores_shape.dim_size(0) == num_boxes, + errors::InvalidArgument("scores size must equal number of boxes", + scores_shape.DebugString())); + OP_REQUIRES(context, pad_to_max_output_size_, + errors::InvalidArgument( + "XLA compilation requires pad_to_max_output_size == True")); + + xla::XlaOp boxes = context->Input("boxes"); + xla::XlaOp scores = context->Input("scores"); + int64 output_size; + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &output_size)); + OP_REQUIRES( + context, output_size >= 0, + errors::InvalidArgument("Need output_size >= 0, got ", output_size)); + xla::XlaOp score_thresh = context->Input("score_threshold"); + xla::XlaOp iou_thresh = context->Input("iou_threshold"); + + xla::XlaBuilder* const builder = context->builder(); + + // Choose a more convenient layout. + xla::XlaOp boxes_t = xla::Transpose(boxes, {1, 0}); + coords_dim = 0; + num_boxes_dim = 1; + + // Shapes are henceforth [1, num_boxes]. + xla::XlaOp coord_y0 = xla::SliceInDim(boxes_t, + /*start_index=*/0, + /*limit_index=*/1, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp coord_x0 = xla::SliceInDim(boxes_t, + /*start_index=*/1, + /*limit_index=*/2, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp coord_y1 = xla::SliceInDim(boxes_t, + /*start_index=*/2, + /*limit_index=*/3, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp coord_x1 = xla::SliceInDim(boxes_t, + /*start_index=*/3, + /*limit_index=*/4, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp y1 = + xla::Select(xla::Le(coord_y0, coord_y1), coord_y0, coord_y1); + xla::XlaOp y2 = + xla::Select(xla::Le(coord_y0, coord_y1), coord_y1, coord_y0); + xla::XlaOp x1 = + xla::Select(xla::Le(coord_x0, coord_x1), coord_x0, coord_x1); + xla::XlaOp x2 = + xla::Select(xla::Le(coord_x0, coord_x1), coord_x1, coord_x0); + xla::XlaOp area = (y2 - y1) * (x2 - x1); + + // Transpose the 1xN tensors, instead of the NxN tensors. + xla::XlaOp y1_t = xla::Transpose(y1, {1, 0}); + xla::XlaOp y2_t = xla::Transpose(y2, {1, 0}); + xla::XlaOp x1_t = xla::Transpose(x1, {1, 0}); + xla::XlaOp x2_t = xla::Transpose(x2, {1, 0}); + xla::XlaOp area_t = xla::Transpose(area, {1, 0}); + + // Shapes are henceforth [num_boxes, num_boxes]. + xla::XlaOp i_xmin = xla::Max(x1, x1_t); + xla::XlaOp i_ymin = xla::Max(y1, y1_t); + xla::XlaOp i_xmax = xla::Min(x2, x2_t); + xla::XlaOp i_ymax = xla::Min(y2, y2_t); + auto square_zero = xla::ZerosLike(i_xmin); + + xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) * + xla::Max(i_ymax - i_ymin, square_zero); + xla::XlaOp u_area = area + area_t - i_area; + xla::XlaOp iou = i_area / u_area; + + xla::XlaOp iou_thresh_mask = xla::Gt(iou, iou_thresh + square_zero); + xla::XlaOp scores_2d = xla::Reshape(scores, {num_boxes, 1}); + xla::XlaOp score_cmp_mask = + xla::Gt(scores_2d, xla::Transpose(scores_2d, {1, 0})); + xla::XlaOp suppress = xla::And(iou_thresh_mask, score_cmp_mask); + + // Shapes are [num_boxes] after the reduce. + xla::XlaOp included_iou = xla::Not(xla::Reduce( + suppress, + /*init_value=*/xla::ConstantR0(builder, false), + /*computation=*/CreateScalarOrComputation(xla::PRED, builder), + /*dimensions_to_reduce=*/{0})); + xla::XlaOp included_score = + xla::Gt(scores, xla::Broadcast(score_thresh, {num_boxes})); + xla::XlaOp included = xla::And(included_iou, included_score); + xla::XlaOp neg_inf = + xla::Broadcast(xla::MinValue(builder, xla::F32), {num_boxes}); + xla::XlaOp scores_included = xla::Select(included, scores, neg_inf); + + xla::XlaOp ones_included = xla::Select( + included, + xla::Broadcast(xla::ConstantR0(builder, 1), {num_boxes}), + xla::Broadcast(xla::ConstantR0(builder, 0), {num_boxes})); + + // num_valid is scalar. + xla::XlaOp num_valid = xla::Reduce( + ones_included, + /*init_value=*/xla::ConstantR0(builder, 0), + /*computation=*/CreateScalarAddComputation(xla::S32, builder), + /*dimensions_to_reduce=*/{0}); + + xla::XlaOp output_tuple = TopK(scores_included, output_size); + xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1); + + context->SetOutput(0, selected_indices); + context->SetOutput(1, num_valid); + } + + private: + bool pad_to_max_output_size_; +}; + +REGISTER_XLA_OP( + Name("NonMaxSuppressionV4").CompileTimeConstInput("max_output_size"), + NonMaxSuppressionOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 79d3a6979cec4c6bda92a71dcff4ddd2151367d5..8d75624e74028ea083c3facc4f9578ec14c50e6d 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/lib/math/math_util.h" @@ -127,48 +129,41 @@ const int64 kMax2DKernelSize = 16; xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, gtl::ArraySlice kernel_size, int64 channels) { - xla::XlaOp channels_iota; - // DT_INT32 Iota will always return status::OK(). - TF_CHECK_OK( - XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); - - auto diag = builder->ConvertElementType( - builder->Eq( - builder->Broadcast(channels_iota, {2 * kernel_size[0] - 1, + xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); + + auto diag = xla::ConvertElementType( + xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1, 2 * kernel_size[1] - 1, channels}), - channels_iota, /*broadcast_dimensions=*/{2}), + channels_iota, /*broadcast_dimensions=*/{2}), xla::PrimitiveType::F32); - return builder->Mul( - builder->Mul(diag, - builder->ConstantR1(Make1DKernel(kernel_size[1])), - /*broadcast_dimensions=*/{1}), - builder->ConstantR1(Make1DKernel(kernel_size[0])), + return xla::Mul( + xla::Mul(diag, + xla::ConstantR1(builder, Make1DKernel(kernel_size[1])), + /*broadcast_dimensions=*/{1}), + xla::ConstantR1(builder, Make1DKernel(kernel_size[0])), /*broadcast_dimensions=*/{0}); } xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, gtl::ArraySlice kernel_size, int64 channels, int64 dim) { - xla::XlaOp channels_iota; - // DT_INT32 Iota will always return status::OK(). - TF_CHECK_OK( - XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); - - auto diag = builder->ConvertElementType( - builder->Eq(builder->Broadcast( - channels_iota, - {dim == 0 ? (2 * kernel_size[0] - 1) : 1, - dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}), - channels_iota, /*broadcast_dimensions=*/{2}), + xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); + + auto diag = xla::ConvertElementType( + xla::Eq( + xla::Broadcast(channels_iota, + {dim == 0 ? (2 * kernel_size[0] - 1) : 1, + dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}), + channels_iota, /*broadcast_dimensions=*/{2}), xla::PrimitiveType::F32); if (dim == 1) { - return builder->Mul( - diag, builder->ConstantR1(Make1DKernel(kernel_size[1])), + return xla::Mul( + diag, xla::ConstantR1(builder, Make1DKernel(kernel_size[1])), /*broadcast_dimensions=*/{1}); } - return builder->Mul(diag, - builder->ConstantR1(Make1DKernel(kernel_size[0])), - /*broadcast_dimensions=*/{0}); + return xla::Mul(diag, + xla::ConstantR1(builder, Make1DKernel(kernel_size[0])), + /*broadcast_dimensions=*/{0}); } xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, @@ -208,7 +203,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { xla::XlaOp kernel = MakeBilinearResizeKernel(builder, dims.kernel_size, channels); - output = builder->ConvGeneralDilated( + output = xla::ConvGeneralDilated( input, kernel, dims.stride, /*padding=*/ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, @@ -218,7 +213,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, } else { xla::XlaOp kernel0 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); - output = builder->ConvGeneralDilated( + output = xla::ConvGeneralDilated( input, kernel0, {dims.stride[0], 1}, /*padding=*/ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, @@ -226,7 +221,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, /*rhs_dilation=*/{1, 1}, dimension_numbers); xla::XlaOp kernel1 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); - output = builder->ConvGeneralDilated( + output = xla::ConvGeneralDilated( output, kernel1, {1, dims.stride[1]}, /*padding=*/ {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, @@ -238,8 +233,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, // size > 1 dimension. for (int i = 0; i < num_spatial_dims; ++i) { if (in_size[i] == 1 && out_size[i] > 1) { - output = builder->Add(output, builder->ConstantR1(out_size[i], 0), - /*broadcast_dimensions=*/{1 + i}); + output = xla::Add(output, xla::ConstantR1(builder, out_size[i], 0), + /*broadcast_dimensions=*/{1 + i}); } } return output; @@ -279,12 +274,12 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, for (int i = 0; i < num_spatial_dims; ++i) { if (in_size[i] == 1 && grad_size[i] > 1) { kernel = - builder->Add(kernel, builder->ConstantR1(grad_size[i], 0), - /*broadcast_dimensions=*/{i}); + xla::Add(kernel, xla::ConstantR1(builder, grad_size[i], 0), + /*broadcast_dimensions=*/{i}); } } - output = builder->ConvGeneralDilated( + output = xla::ConvGeneralDilated( grad, kernel, /*window_strides=*/dims.kernel_size, /*padding=*/ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, @@ -302,23 +297,23 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, // gradient contributions in that dimension. if (in_size[0] == 1 && grad_size[0] > 1) { kernel0 = - builder->Add(kernel0, builder->ConstantR1(grad_size[0], 0), - /*broadcast_dimensions=*/{0}); + xla::Add(kernel0, xla::ConstantR1(builder, grad_size[0], 0), + /*broadcast_dimensions=*/{0}); } if (in_size[1] == 1 && grad_size[1] > 1) { kernel1 = - builder->Add(kernel0, builder->ConstantR1(grad_size[1], 0), - /*broadcast_dimensions=*/{1}); + xla::Add(kernel0, xla::ConstantR1(builder, grad_size[1], 0), + /*broadcast_dimensions=*/{1}); } - output = builder->ConvGeneralDilated( + output = xla::ConvGeneralDilated( grad, kernel0, /*window_strides=*/{dims.kernel_size[0], 1}, /*padding=*/ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, /*lhs_dilation=*/{dims.stride[0], 1}, /*rhs_dilation=*/{1, 1}, dimension_numbers); - output = builder->ConvGeneralDilated( + output = xla::ConvGeneralDilated( output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]}, /*padding=*/ {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, @@ -337,7 +332,7 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, } } if (pad_output) { - output = builder->Pad(output, builder->ConstantR0(0.0f), padding); + output = xla::Pad(output, xla::ConstantR0(builder, 0.0f), padding); } return output; } @@ -393,13 +388,13 @@ class ResizeBilinearOp : public XlaOpKernel { } } if (slice_input) { - input = b->Slice(input, {0, 0, 0, 0}, - {batch, slice_size[0], slice_size[1], channels}, - {1, 1, 1, 1}); + input = xla::Slice(input, {0, 0, 0, 0}, + {batch, slice_size[0], slice_size[1], channels}, + {1, 1, 1, 1}); } // Output is always type float. - input = b->ConvertElementType(input, xla::F32); + input = xla::ConvertElementType(input, xla::F32); // Special Case: // Instead of doing a ResizeUsingDilationAndConvolution directly, @@ -529,7 +524,7 @@ class ResizeBilinearGradOp : public XlaOpKernel { } } - output = b->ConvertElementType(output, output_type_); + output = xla::ConvertElementType(output, output_type_); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index 36eb4c75454ed82804c40b82e5dbaec2eef0a719..f3964748587c1b31cf8b1b76643ff19a9044bf44 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -60,19 +60,15 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { input_shape.DebugString())); DataType index_type = output_type(0); + xla::PrimitiveType index_xla_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(index_type, &index_xla_type)); - xla::XlaBuilder* b = ctx->builder(); xla::XlaOp input = ctx->Input(0); - xla::XlaOp output; if (is_min_) { - OP_REQUIRES_OK(ctx, - XlaHelpers::ArgMin(b, ctx, input, input_shape, input_type(0), - index_type, axis, &output)); + output = XlaHelpers::ArgMin(input, index_xla_type, axis); } else { - OP_REQUIRES_OK(ctx, - XlaHelpers::ArgMax(b, ctx, input, input_shape, input_type(0), - index_type, axis, &output)); + output = XlaHelpers::ArgMax(input, index_xla_type, axis); } ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 2c2d88486fda99d2380382a3e2f633f5bdc7478c..22a45b2a11e8ecb688f8e773ef4b286eafe68f4f 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -76,14 +77,15 @@ class ArgMaxCustomCallOp : public XlaOpKernel { // XLA passes to the function, so it is not included here. std::vector args; args.push_back(ctx->Input(0)); - args.push_back(b.ConstantLiteral( - *xla::Literal::CreateR1(input_shape.dim_sizes()))); + args.push_back(xla::ConstantLiteral( + &b, *xla::LiteralUtil::CreateR1(input_shape.dim_sizes()))); if (input_shape.dims() > 1) { // Don't bother passing the output shape and dim for the 1d case, since // the shape is always a scalar and the dim is always 0. - args.push_back(b.ConstantLiteral( - *xla::Literal::CreateR1(output_shape.dim_sizes()))); - args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0(dim))); + args.push_back(xla::ConstantLiteral( + &b, *xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); + args.push_back( + xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0(dim))); } xla::Shape xla_shape = @@ -94,10 +96,12 @@ class ArgMaxCustomCallOp : public XlaOpKernel { xla::XlaOp output; switch (input_shape.dims()) { case 1: - output = b.CustomCall("argmax_float_1d_xla_impl", args, xla_shape); + output = + xla::CustomCall(&b, "argmax_float_1d_xla_impl", args, xla_shape); break; case 2: - output = b.CustomCall("argmax_float_2d_xla_impl", args, xla_shape); + output = + xla::CustomCall(&b, "argmax_float_2d_xla_impl", args, xla_shape); break; default: OP_REQUIRES(ctx, false, diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index 1decf7d72d72bb697477e7f841ced2a1a0d5fbe9..f028e361bccd51de0bd69a1d2227c7afaed53455 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" @@ -39,12 +39,12 @@ class L2LossOp : public XlaOpKernel { const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); auto t = XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); - auto square = b->Mul(t, t); - auto reduce = b->Reduce(square, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), dims); + auto square = xla::Mul(t, t); + auto reduce = xla::Reduce(square, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), dims); auto deconverted = XlaHelpers::ConvertElementType(b, reduce, dtype); auto two = XlaHelpers::IntegerLiteral(b, dtype, 2); - ctx->SetOutput(0, b->Div(deconverted, two)); + ctx->SetOutput(0, xla::Div(deconverted, two)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc index 0388b4c830702ea00ec69fc42c6468326c88cf38..a11bbe918f7f8eb050aaa40d4344f9cc9e9a10a4 100644 --- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/lib/core/errors.h" @@ -90,8 +91,10 @@ class ListDiffOp : public XlaOpKernel { idx_output.push_back(i); } - context->SetOutput(0, context->builder()->ConstantR1(val_output)); - context->SetOutput(1, context->builder()->ConstantR1(idx_output)); + context->SetOutput(0, + xla::ConstantR1(context->builder(), val_output)); + context->SetOutput(1, + xla::ConstantR1(context->builder(), idx_output)); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index 39fbf98a6274918840e9e351470f04c2d80c5d01..87ee2d3aede50eb24e65570f106d49030e1d4236 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -50,8 +51,8 @@ class LRNOp : public XlaOpKernel { auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); auto converted = XlaHelpers::ConvertElementType(builder, input, accumulation_type); - auto squared = builder->Mul(converted, converted); - auto reduce = builder->ReduceWindow( + auto squared = xla::Mul(converted, converted); + auto reduce = xla::ReduceWindow( squared, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, @@ -59,12 +60,12 @@ class LRNOp : public XlaOpKernel { auto sqr_sum = XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); - auto scale = builder->Pow( - builder->Add(builder->ConstantR0(bias_), - builder->Mul(builder->ConstantR0(alpha_), sqr_sum)), - builder->ConstantR0(-beta_)); + auto scale = xla::Pow( + xla::Add(xla::ConstantR0(builder, bias_), + xla::Mul(xla::ConstantR0(builder, alpha_), sqr_sum)), + xla::ConstantR0(builder, -beta_)); - ctx->SetOutput(0, builder->Mul(input, scale)); + ctx->SetOutput(0, xla::Mul(input, scale)); } private: @@ -138,8 +139,8 @@ class LRNGradOp : public XlaOpKernel { auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); auto converted = XlaHelpers::ConvertElementType(builder, in_image, accumulation_type); - auto squared = builder->Mul(converted, converted); - auto reduce = builder->ReduceWindow( + auto squared = xla::Mul(converted, converted); + auto reduce = xla::ReduceWindow( squared, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, @@ -148,17 +149,17 @@ class LRNGradOp : public XlaOpKernel { XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); auto norm = - builder->Add(builder->ConstantR0(bias_), - builder->Mul(builder->ConstantR0(alpha_), sqr_sum)); + xla::Add(xla::ConstantR0(builder, bias_), + xla::Mul(xla::ConstantR0(builder, alpha_), sqr_sum)); - auto dy = builder->Mul( - builder->Mul(builder->ConstantR0(-2.0f * alpha_ * beta_), - builder->Div(out_image, norm)), + auto dy = xla::Mul( + xla::Mul(xla::ConstantR0(builder, -2.0f * alpha_ * beta_), + xla::Div(out_image, norm)), in_grads); auto converted_dy = XlaHelpers::ConvertElementType(builder, dy, accumulation_type); - auto dy_reduce = builder->ReduceWindow( + auto dy_reduce = xla::ReduceWindow( converted_dy, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, @@ -166,10 +167,10 @@ class LRNGradOp : public XlaOpKernel { auto dy_reduced = XlaHelpers::ConvertElementType(builder, dy_reduce, input_type(0)); - xla::XlaOp gradients = builder->Add( - builder->Mul(in_image, dy_reduced), - builder->Mul(in_grads, - builder->Pow(norm, builder->ConstantR0(-beta_)))); + xla::XlaOp gradients = xla::Add( + xla::Mul(in_image, dy_reduced), + xla::Mul(in_grads, + xla::Pow(norm, xla::ConstantR0(builder, -beta_)))); ctx->SetOutput(0, gradients); } diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index 6949b296f4b9afe4a0c9152c763a9ad233b9f595..6440770c29894c951f010f6c1deb929f4fe79bbf 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -53,10 +54,14 @@ class MatMulOp : public XlaOpKernel { const TensorShape b_shape = ctx->InputShape(1); // Check that the dimensions of the two matrices are valid. - OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a_shape), - errors::InvalidArgument("In[0] is not a matrix")); - OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b_shape), - errors::InvalidArgument("In[1] is not a matrix")); + OP_REQUIRES( + ctx, TensorShapeUtils::IsMatrix(a_shape), + errors::InvalidArgument("In[0] is not a matrix. Instead it has shape ", + a_shape.DebugString())); + OP_REQUIRES( + ctx, TensorShapeUtils::IsMatrix(b_shape), + errors::InvalidArgument("In[1] is not a matrix. Instead it has shape ", + b_shape.DebugString())); int first_index = transpose_a_ ? 0 : 1; int second_index = transpose_b_ ? 1 : 0; @@ -70,15 +75,15 @@ class MatMulOp : public XlaOpKernel { xla::XlaOp b = ctx->Input(1); if (is_sparse_) { if (a_type_ == DT_BFLOAT16) { - a = ctx->builder()->ConvertElementType(a, xla::F32); + a = xla::ConvertElementType(a, xla::F32); } if (b_type_ == DT_BFLOAT16) { - b = ctx->builder()->ConvertElementType(b, xla::F32); + b = xla::ConvertElementType(b, xla::F32); } } - auto lhs = (transpose_a_) ? ctx->builder()->Transpose(a, {1, 0}) : a; - auto rhs = (transpose_b_) ? ctx->builder()->Transpose(b, {1, 0}) : b; - ctx->SetOutput(0, ctx->builder()->Dot(lhs, rhs)); + auto lhs = (transpose_a_) ? xla::Transpose(a, {1, 0}) : a; + auto rhs = (transpose_b_) ? xla::Transpose(b, {1, 0}) : b; + ctx->SetOutput(0, xla::Dot(lhs, rhs)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc index fbd5dc0fdad4483aadbe9bc263cc1f7a034cee09..8dfd7de591c4a3c4768dd60b41e03d294ad49397 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -50,6 +52,7 @@ class MatrixBandPartOp : public XlaOpKernel { xla::XlaOp num_upper = context->Input(2); DataType input_type = context->input_type(0); DataType index_type = context->input_type(1); + xla::PrimitiveType index_xla_type = context->input_xla_type(1); TensorShape batch_shape = input_shape; batch_shape.RemoveLastDims(2); @@ -58,33 +61,29 @@ class MatrixBandPartOp : public XlaOpKernel { // Compute 'offset', which is how many diagonals we are above/below the // diagonal. - xla::XlaOp iota_m; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m)); + xla::XlaOp iota_m = xla::Iota(builder, index_xla_type, m); + xla::XlaOp iota_n = xla::Iota(builder, index_xla_type, n); - xla::XlaOp iota_n; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n)); - - auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m, - /*broadcast_dimensions=*/{0}); + auto offset = xla::Sub(xla::Broadcast(iota_n, {m}), iota_m, + /*broadcast_dimensions=*/{0}); // If num_lower or num_upper are negative, include all lower/upper // diagonals. auto zero_index = XlaHelpers::Zero(builder, index_type); - num_lower = builder->Select( - builder->Lt(num_lower, zero_index), - XlaHelpers::IntegerLiteral(builder, index_type, m), num_lower); - num_upper = builder->Select( - builder->Lt(num_upper, zero_index), - XlaHelpers::IntegerLiteral(builder, index_type, n), num_upper); + num_lower = xla::Select(xla::Lt(num_lower, zero_index), + XlaHelpers::IntegerLiteral(builder, index_type, m), + num_lower); + num_upper = xla::Select(xla::Lt(num_upper, zero_index), + XlaHelpers::IntegerLiteral(builder, index_type, n), + num_upper); - auto indicator = builder->And(builder->Le(builder->Neg(num_lower), offset), - builder->Le(offset, num_upper)); - indicator = builder->Broadcast(indicator, batch_shape.dim_sizes()); + auto indicator = xla::And(xla::Le(xla::Neg(num_lower), offset), + xla::Le(offset, num_upper)); + indicator = xla::Broadcast(indicator, batch_shape.dim_sizes()); auto zero_input = XlaHelpers::Zero(builder, input_type); - auto output = builder->Select( - indicator, input, - builder->Broadcast(zero_input, input_shape.dim_sizes())); + auto output = xla::Select( + indicator, input, xla::Broadcast(zero_input, input_shape.dim_sizes())); context->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc index db53f6fef8d6bf901c8281f50791ca6766c46efd..c0ca881ff82cee04e0c5e35f9a2d5732fabdd8a6 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { @@ -61,14 +63,11 @@ class MatrixSetDiagOp : public XlaOpKernel { auto zero = XlaHelpers::Zero(builder, context->input_type(0)); // Create an indicator tensor that is true only on the diagonal. - xla::XlaOp iota_m; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m)); - xla::XlaOp iota_n; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n)); - auto indicator = builder->Eq(iota_m, - builder->Broadcast(iota_n, {m}), - /*broadcast_dimensions=*/{0}); - indicator = builder->Broadcast(indicator, batch_shape.dim_sizes()); + xla::XlaOp iota_m = xla::Iota(builder, xla::S32, m); + xla::XlaOp iota_n = xla::Iota(builder, xla::S32, n); + auto indicator = xla::Eq(iota_m, xla::Broadcast(iota_n, {m}), + /*broadcast_dimensions=*/{0}); + indicator = xla::Broadcast(indicator, batch_shape.dim_sizes()); // Broadcast diag up to the input shape. Use an implicit broadcast (Add) // because we need to broadcast on the right. @@ -77,10 +76,10 @@ class MatrixSetDiagOp : public XlaOpKernel { if (min_dim != m) { diag_broadcast_dims.back() = rank - 1; } - diag = builder->Add(diag, builder->Broadcast(zero, input_shape.dim_sizes()), - /*broadcast_dimensions=*/diag_broadcast_dims); + diag = xla::Add(diag, xla::Broadcast(zero, input_shape.dim_sizes()), + /*broadcast_dimensions=*/diag_broadcast_dims); - auto output = builder->Select(indicator, diag, input); + auto output = xla::Select(indicator, diag, input); context->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index eaed93146460de5a6e8328432302cc75bf36a534..f4def11d08c31513aec5aad15187016a7294c2fd 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -30,13 +30,9 @@ class MatrixTriangularSolveOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { auto result = TriangularSolve( - ctx->builder(), ctx->Input(0), ctx->Input(1), /*left_side=*/true, + ctx->Input(0), ctx->Input(1), /*left_side=*/true, /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_); - if (!result.ok()) { - ctx->SetStatus(result.status()); - return; - } - ctx->SetOutput(0, result.ValueOrDie()); + ctx->SetOutput(0, result); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index 7e9de3ef9b245c113cc143128fe58e7e017a361c..eedfc3c9140d7b1ccc1944611de98c1d49fbdaf2 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/mirror_pad_mode.h" namespace tensorflow { @@ -27,21 +28,21 @@ class MirrorPadOp : public XlaOpKernel { xla::StatusOr DoMirrorPad(const xla::XlaOp& t, const xla::Shape& original_shape, - const xla::Literal& pad_literal, + const xla::LiteralSlice& pad_literal, xla::XlaBuilder* b) { xla::XlaOp accum = t; for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; --dimno) { - auto t_rev = b->Rev(accum, {dimno}); + auto t_rev = xla::Rev(accum, {dimno}); TF_ASSIGN_OR_RETURN(int64 lhs_padding, pad_literal.GetIntegralAsS64({dimno, 0})); TF_ASSIGN_OR_RETURN(int64 rhs_padding, pad_literal.GetIntegralAsS64({dimno, 1})); int64 dim_size = original_shape.dimensions(dimno); - auto lhs_pad = b->SliceInDim(t_rev, dim_size - 1 - lhs_padding, - dim_size - 1, 1, dimno); - auto rhs_pad = b->SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno); - accum = b->ConcatInDim({lhs_pad, accum, rhs_pad}, dimno); + auto lhs_pad = xla::SliceInDim(t_rev, dim_size - 1 - lhs_padding, + dim_size - 1, 1, dimno); + auto rhs_pad = xla::SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno); + accum = xla::ConcatInDim(b, {lhs_pad, accum, rhs_pad}, dimno); } return accum; } diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index aecaabb6dcf46bdd6ae3da929448d6370acb989b..a9b519d8928cc2807831fd6b4f12e60b7d58ea55 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -76,11 +77,10 @@ class PackOp : public XlaOpKernel { for (int i = 0; i < num; ++i) { // Reshape the inputs to have an extra dimension of size 1. - reshaped_inputs[i] = - ctx->builder()->Reshape(values[i], child_shape.dim_sizes()); + reshaped_inputs[i] = xla::Reshape(values[i], child_shape.dim_sizes()); } - ctx->SetOutput(0, ctx->builder()->ConcatInDim(reshaped_inputs, axis)); + ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), reshaped_inputs, axis)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 7c95475e7b1f02183e44f73f116a4aeb25f05c09..e5937b56c17d01892928b073da09f38941ea1bbb 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" @@ -63,8 +64,8 @@ class PadOp : public XlaOpKernel { int before = pad_literal.Get({i, 0}); int after = pad_literal.Get({i, 1}); OP_REQUIRES(ctx, before >= 0 && after >= 0, - errors::InvalidArgument("Paddings must be non-negative: ", - before, " ", after)); + errors::InvalidArgument( + "Paddings must be non-negative: ", before, " ", after)); dim->set_edge_padding_low(before); dim->set_edge_padding_high(after); } @@ -74,11 +75,10 @@ class PadOp : public XlaOpKernel { if (ctx->num_inputs() == 3) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)), errors::InvalidArgument("constant_values must be a scalar.")); - ctx->SetOutput(0, - ctx->builder()->Pad(ctx->Input(0), ctx->Input(2), config)); + ctx->SetOutput(0, xla::Pad(ctx->Input(0), ctx->Input(2), config)); } else { auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); - ctx->SetOutput(0, ctx->builder()->Pad(ctx->Input(0), zero, config)); + ctx->SetOutput(0, xla::Pad(ctx->Input(0), zero, config)); } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index f8e7b48a0fd94835964aea033ad33523150067b4..d4d180aff806f12875f0e43f111ee090f6607ef6 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -20,7 +20,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/pooling.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -61,63 +65,60 @@ class PoolingOp : public XlaOpKernel { Padding padding; OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding)); padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame; + + OP_REQUIRES_OK( + ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_)); } int num_dims() const { return num_spatial_dims_ + 2; } - // Method that builds an initial value to use in reductions. - virtual xla::XlaOp InitValue(xla::XlaBuilder* b) = 0; - - // The reduction operation to apply to each window. - virtual const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) = 0; - - // A post-processing operation to apply on the outputs of the ReduceWindow. - virtual xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, - const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape) = 0; - - void Compile(XlaOpKernelContext* ctx) override { - std::vector ksize = ksize_; - std::vector stride = stride_; - if (ctx->num_inputs() != 1) { - const TensorShape ksize_shape = ctx->InputShape(1); - // Validate input sizes. - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), - errors::InvalidArgument("ksize must be a vector, not shape ", - ksize_shape.DebugString())); - OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(), - errors::InvalidArgument("Sliding window ksize field must " - "specify ", - num_dims(), " dimensions")); - ksize.clear(); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize)); - - const TensorShape stride_shape = ctx->InputShape(2); - // Validate input sizes. - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), - errors::InvalidArgument("stride must be a vector, not shape ", - stride_shape.DebugString())); - OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(), - errors::InvalidArgument("Sliding window stride field must " - "specify ", - num_dims(), " dimensions")); - stride.clear(); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride)); + protected: + xla::StatusOr> GetKernelSize(XlaOpKernelContext* ctx) { + if (ctx->num_inputs() == 1) { + return ksize_; } - const TensorShape input_shape = ctx->InputShape(0); - OP_REQUIRES(ctx, input_shape.dims() == num_dims(), - errors::InvalidArgument("Input to ", type_string(), - " operator must have ", num_dims(), - " dimensions")); + const TensorShape ksize_shape = ctx->InputShape(1); + // Validate input sizes. + if (!TensorShapeUtils::IsVector(ksize_shape)) { + return errors::InvalidArgument("ksize must be a vector, not shape ", + ksize_shape.DebugString()); + } + if (ksize_shape.num_elements() != num_dims()) { + return errors::InvalidArgument( + "Sliding window ksize field must " + "specify ", + num_dims(), " dimensions"); + } + std::vector ksize; + auto status = ctx->ConstantInputAsIntVector(1, &ksize); + if (!status.ok()) { + return status; + } + return ksize; + } - xla::XlaBuilder* const b = ctx->builder(); - auto input = - XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_); - auto reduce = ctx->builder()->ReduceWindow( - input, InitValue(b), *Reduction(ctx), ksize, stride, padding_); - auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); - ctx->SetOutput(0, - PostProcessOutput(ctx, pooled, input_type(0), input_shape)); + xla::StatusOr> GetStride(XlaOpKernelContext* ctx) { + if (ctx->num_inputs() == 1) { + return stride_; + } + const TensorShape stride_shape = ctx->InputShape(2); + // Validate input sizes. + if (!TensorShapeUtils::IsVector(stride_shape)) { + return errors::InvalidArgument("stride must be a vector, not shape ", + stride_shape.DebugString()); + } + if (stride_shape.num_elements() != num_dims()) { + return errors::InvalidArgument( + "Sliding window stride field must " + "specify ", + num_dims(), " dimensions"); + } + std::vector stride; + auto status = ctx->ConstantInputAsIntVector(2, &stride); + if (!status.ok()) { + return status; + } + return stride; } protected: @@ -127,26 +128,51 @@ class PoolingOp : public XlaOpKernel { xla::Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; DataType reduction_type_; + xla::PrimitiveType xla_reduction_type_; }; +// Converts the tensor data format to the one required by the XLA pooling +// library. +xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format, + int num_spatial_dims) { + int num_dims = num_spatial_dims + 2; + int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format); + int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format); + gtl::InlinedVector spatial_dimensions(num_spatial_dims); + for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) { + spatial_dimensions[spatial_dim] = + GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim); + } + return xla::TensorFormat(/*batch_dimension=*/batch_dimension, + /*feature_dimension=*/feature_dimension, + /*spatial_dimensions=*/spatial_dimensions); +} + class MaxPoolOp : public PoolingOp { public: MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, /*reduction_type=*/ctx->input_type(0)) {} - xla::XlaOp InitValue(xla::XlaBuilder* b) override { - return XlaHelpers::MinValue(b, reduction_type_); - } + void Compile(XlaOpKernelContext* ctx) override { + auto ksize_or_error = GetKernelSize(ctx); + OP_REQUIRES_OK(ctx, ksize_or_error.status()); + std::vector ksize = ksize_or_error.ValueOrDie(); - const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { - return ctx->GetOrCreateMax(reduction_type_); - } + auto stride_or_error = GetStride(ctx); + OP_REQUIRES_OK(ctx, stride_or_error.status()); + std::vector stride = stride_or_error.ValueOrDie(); + + const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == num_dims(), + errors::InvalidArgument("Input to ", type_string(), + " operator must have ", num_dims(), + " dimensions")); - xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, - const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape) override { - return output; + auto pooling = + xla::MaxPool(ctx->Input(0), ksize, stride, padding_, + XlaTensorFormat(data_format_, input_shape.dims() - 2)); + ctx->SetOutput(0, pooling); } }; @@ -173,9 +199,8 @@ class MaxPool3DOp : public MaxPoolOp { }; REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp); -// Common computation shared between AvgPool and AvgPoolGrad. Divide each -// element of an image by the count of elements that contributed to that -// element during pooling. +// Divide each element of an image by the count of elements that contributed to +// that element during pooling. static xla::XlaOp AvgPoolDivideByCount( XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype, const TensorShape& input_shape, xla::Padding padding, @@ -190,7 +215,7 @@ static xla::XlaOp AvgPoolDivideByCount( auto divisor = XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size); - return ctx->builder()->Div(output, divisor); + return xla::Div(output, divisor); } else { // For SAME padding, the padding shouldn't be included in the // counts. We use another ReduceWindow to find the right counts. @@ -212,18 +237,18 @@ static xla::XlaOp AvgPoolDivideByCount( // Build a matrix of all 1s, with the same width/height as the input. const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto ones = ctx->builder()->Broadcast( + auto ones = xla::Broadcast( XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes); // Perform a ReduceWindow with the same window size, strides, and padding // to count the number of contributions to each result element. - auto reduce = ctx->builder()->ReduceWindow( + auto reduce = xla::ReduceWindow( ones, XlaHelpers::Zero(ctx->builder(), accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride, xla::Padding::kSame); auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype); - return ctx->builder()->Div(output, counts, window_dims); + return xla::Div(output, counts, window_dims); } } @@ -234,20 +259,34 @@ class AvgPoolOp : public PoolingOp { /*reduction_type=*/ XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::XlaOp InitValue(xla::XlaBuilder* b) override { - return XlaHelpers::Zero(b, reduction_type_); - } + void Compile(XlaOpKernelContext* ctx) override { + auto ksize_or_error = GetKernelSize(ctx); + OP_REQUIRES_OK(ctx, ksize_or_error.status()); + std::vector ksize = ksize_or_error.ValueOrDie(); - const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { - return ctx->GetOrCreateAdd(reduction_type_); - } + auto stride_or_error = GetStride(ctx); + OP_REQUIRES_OK(ctx, stride_or_error.status()); + std::vector stride = stride_or_error.ValueOrDie(); - xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, - const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape) override { - return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_, - ksize_, stride_, num_spatial_dims_, - data_format_); + const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == num_dims(), + errors::InvalidArgument("Input to ", type_string(), + " operator must have ", num_dims(), + " dimensions")); + + auto xla_data_format = + XlaTensorFormat(data_format_, input_shape.dims() - 2); + auto spatial_padding = MakeSpatialPadding( + input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format); + + // Convert the input to the reduction type. + auto converted_input = + ConvertElementType(ctx->Input(0), xla_reduction_type_); + auto pooling = + xla::AvgPool(converted_input, ksize, stride, spatial_padding, + xla_data_format, padding_ == xla::Padding::kValid); + // Convert the pooling result back to the input type before returning it. + ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0))); } }; @@ -347,9 +386,9 @@ class MaxPoolGradOp : public XlaOpKernel { xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2)); auto select = CreateScalarGeComputation(element_type, ctx->builder()); auto scatter = CreateScalarAddComputation(element_type, ctx->builder()); - xla::XlaOp gradients = ctx->builder()->SelectAndScatter( - input, select, ksize_, stride_, xla_padding, out_backprop, init_value, - scatter); + xla::XlaOp gradients = + xla::SelectAndScatter(input, select, ksize_, stride_, xla_padding, + out_backprop, init_value, scatter); ctx->SetOutput(0, gradients); } @@ -485,12 +524,12 @@ class AvgPoolGradOp : public XlaOpKernel { } auto zero = XlaHelpers::Zero(b, dtype); - auto padded_gradients = b->Pad(out_backprop_div, zero, padding_config); + auto padded_gradients = xla::Pad(out_backprop_div, zero, padding_config); // in_backprop = padded_gradients ones std::vector ones(num_dims(), 1LL); auto accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto in_backprop = b->ReduceWindow( + auto in_backprop = xla::ReduceWindow( XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), ksize_, @@ -614,58 +653,61 @@ class MaxPoolGradGradOp : public XlaOpKernel { auto b = ctx->builder(); - auto sixteen = b->ConstantR0(16); + auto sixteen = xla::ConstantR0(b, 16); // in (f32) -> round to bf16 -> f32 for correct bitwidth -> 16-high-bit u32 - auto in_hi = b->BitcastConvertType( - b->ConvertElementType(b->ConvertElementType(input, xla::BF16), - xla::F32), + auto in_hi = xla::BitcastConvertType( + xla::ConvertElementType(xla::ConvertElementType(input, xla::BF16), + xla::F32), xla::U32); - auto bp_int = b->BitcastConvertType(out_backprop, xla::U32); - auto bp_hi = b->ShiftRightLogical(bp_int, sixteen); - auto bp_lo = b->ShiftRightLogical(b->ShiftLeft(bp_int, sixteen), sixteen); - auto in_hi_bp_hi = b->Add(in_hi, bp_hi); // Want an unsigned add. - auto in_hi_bp_lo = b->Add(in_hi, bp_lo); // Want an unsigned add. - - auto init_value = XlaHelpers::MinValue(b, DT_FLOAT); + auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32); + auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen); + auto bp_lo = + xla::ShiftRightLogical(xla::ShiftLeft(bp_int, sixteen), sixteen); + auto in_hi_bp_hi = xla::Add(in_hi, bp_hi); // Want an unsigned add. + auto in_hi_bp_lo = xla::Add(in_hi, bp_lo); // Want an unsigned add. + + auto init_value = xla::MinValue(b, xla::F32); // We will reduce by taking the maximal value up to 16 bits (ignoring the lo // 16 bits of packed-in hi/lo backprop value). auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits"); { // F32 parameters to satisfy lowering type restriction for reduce opcode. const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {}); - auto lhs = rb->Parameter(0, scalar, "lhs"); - auto rhs = rb->Parameter(1, scalar, "rhs"); - auto sixteen = rb->ConstantR0(16); - auto lhs_criteria = rb->ShiftLeft( - rb->ShiftRightLogical(rb->BitcastConvertType(lhs, xla::S32), sixteen), - sixteen); - auto rhs_criteria = rb->ShiftLeft( - rb->ShiftRightLogical(rb->BitcastConvertType(rhs, xla::S32), sixteen), - sixteen); + auto lhs = xla::Parameter(rb.get(), 0, scalar, "lhs"); + auto rhs = xla::Parameter(rb.get(), 1, scalar, "rhs"); + auto sixteen = xla::ConstantR0(rb.get(), 16); + auto lhs_criteria = + xla::ShiftLeft(xla::ShiftRightLogical( + xla::BitcastConvertType(lhs, xla::S32), sixteen), + sixteen); + auto rhs_criteria = + xla::ShiftLeft(xla::ShiftRightLogical( + xla::BitcastConvertType(rhs, xla::S32), sixteen), + sixteen); // Must use a F32 comparison, because S32 would not work for negatives. - rb->Select(rb->Ge(rb->BitcastConvertType(lhs_criteria, xla::F32), - rb->BitcastConvertType(rhs_criteria, xla::F32)), - lhs, rhs); + xla::Select(xla::Ge(xla::BitcastConvertType(lhs_criteria, xla::F32), + xla::BitcastConvertType(rhs_criteria, xla::F32)), + lhs, rhs); } auto reduce = rb->BuildAndNoteError(); xla::Padding xla_padding = (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; auto pooled_hi = - b->ReduceWindow(b->BitcastConvertType(in_hi_bp_hi, xla::F32), - init_value, reduce, ksize_, stride_, xla_padding); + xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_hi, xla::F32), + init_value, reduce, ksize_, stride_, xla_padding); auto pooled_lo = - b->ReduceWindow(b->BitcastConvertType(in_hi_bp_lo, xla::F32), - init_value, reduce, ksize_, stride_, xla_padding); + xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_lo, xla::F32), + init_value, reduce, ksize_, stride_, xla_padding); auto grads_hi = - b->ShiftLeft(b->BitcastConvertType(pooled_hi, xla::U32), sixteen); - auto grads_lo = b->ShiftRightLogical( - b->ShiftLeft(b->BitcastConvertType(pooled_lo, xla::U32), sixteen), + xla::ShiftLeft(xla::BitcastConvertType(pooled_hi, xla::U32), sixteen); + auto grads_lo = xla::ShiftRightLogical( + xla::ShiftLeft(xla::BitcastConvertType(pooled_lo, xla::U32), sixteen), sixteen); - auto grads = b->Add(grads_hi, grads_lo); // Want an unsigned add. + auto grads = xla::Add(grads_hi, grads_lo); // Want an unsigned add. xla::PrimitiveType element_type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type)); - ctx->SetOutput(0, b->BitcastConvertType(grads, element_type)); + ctx->SetOutput(0, xla::BitcastConvertType(grads, element_type)); } protected: @@ -694,5 +736,18 @@ REGISTER_XLA_OP(Name("MaxPoolGradGradV2") .CompileTimeConstInput("strides"), MaxPool2DGradGradOp); +class MaxPool3DGradGradOp : public MaxPoolGradGradOp { + public: + explicit MaxPool3DGradGradOp(OpKernelConstruction* ctx) + : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/3) { + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + } +}; +REGISTER_XLA_OP(Name("MaxPool3DGradGrad").TypeConstraint("T", DT_FLOAT), + MaxPool3DGradGradOp); + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..de9068a640dc03b141b6954eaa1629dd6c8c1f3a --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/qr.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +class QROp : public XlaOpKernel { + public: + explicit QROp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + bool full_matrices; + OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices)); + OP_REQUIRES( + ctx, full_matrices, + errors::Unimplemented("full_matrices=False case of QR decomposition is " + "not implemented in TF/XLA")); + } + void Compile(XlaOpKernelContext* ctx) override { + auto result = QRDecomposition(ctx->Input(0)); + if (!result.ok()) { + ctx->SetStatus(result.status()); + return; + } + ctx->SetOutput(0, result.ValueOrDie().q); + ctx->SetOutput(1, result.ValueOrDie().r); + } +}; + +REGISTER_XLA_OP(Name("Qr").TypeConstraint("T", kFloatTypes), QROp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index 661cd5923e1023eaf89a6bc4f56fcc362c8bcfb6..6f4ed496a1774dde68dd9d5fbd37995d615b678c 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -13,10 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -28,82 +32,115 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_)); - OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63), - errors::InvalidArgument("num_bits is out of range: ", num_bits_, - " with signed_input_ ", signed_input_)); } void Compile(XlaOpKernelContext* ctx) override { xla::XlaOp input = ctx->Input(0); const DataType data_type = ctx->input_type(0); - // Comments taken from semantics description at - // https://www.tensorflow.org/versions/r1.0/api_docs/cc/class/tensorflow/ops/quantize-and-dequantize - // - // ... we find m such that - // - // m = max(abs(input_min), abs(input_max)) if range_given is true, - // m = max(abs(min_elem(input)), - // abs(max_elem(input))) otherwise. + xla::PrimitiveType xla_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(data_type, &xla_type)); + xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp input_min, input_max; + + // The implementation follows + // tensorflow/core/kernels/quantize_and_dequantize_op.h closely. + xla::XlaOp min_range, max_range; if (range_given_) { - double input_min_value, input_max_value; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(1, &input_min_value)); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(2, &input_max_value)); - input_min = XlaHelpers::FloatLiteral(b, data_type, input_min_value); - input_max = XlaHelpers::FloatLiteral(b, data_type, input_max_value); + min_range = ctx->Input(1); + max_range = ctx->Input(2); } else { const xla::XlaComputation* fmax = ctx->GetOrCreateMax(data_type); const xla::XlaComputation* fmin = ctx->GetOrCreateMin(data_type); - input_min = - b->ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin); - input_max = - b->ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax); + min_range = ReduceAll(input, xla::MaxValue(b, xla_type), *fmin); + max_range = ReduceAll(input, xla::MinValue(b, xla_type), *fmax); } - xla::XlaOp m = b->Max(b->Abs(input_min), b->Abs(input_max)); - - // Next, we choose our fixed-point quantization buckets, [min_fixed, - // max_fixed]. If signed_input is true, this is - // - // [min_fixed, max_fixed ] = [-((1 << (num_bits - 1)) - 1), - // (1 << (num_bits - 1)) - 1]. - // - // Otherwise, if signed_input is false, the fixed-point range is - // - // [min_fixed, max_fixed] = [0, (1 << num_bits) - 1]. - int64 min_fixed, max_fixed; + + xla::XlaOp num_bits; + if (num_bits_ < 0) { + OP_REQUIRES( + ctx, ctx->num_inputs() == 4, + errors::Internal("Expected 4 inputs to QuantizeAndDequantize")); + num_bits = ctx->Input(3); + } else { + num_bits = xla::ConstantR0(b, num_bits_); + } + + const xla::XlaOp zero = XlaHelpers::Zero(b, data_type); + const xla::XlaOp one = XlaHelpers::One(b, data_type); + const xla::XlaOp two = XlaHelpers::FloatLiteral(b, data_type, 2.0); + const xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5); + + // Calculate the range for the simulated integer quantization: + // e.g. [-128,127] for signed = true, num_bits = 8, + // or [0, 255] for signed = false, num_bits = 8. + // We do this in floating point for hardware that does not have 64-bit + // integer support. + xla::XlaOp min_quantized, max_quantized; if (signed_input_) { - min_fixed = -((1LL << (num_bits_ - 1)) - 1); - max_fixed = (1LL << (num_bits_ - 1)) - 1; + min_quantized = + -Pow(two, ConvertElementType(num_bits - xla::ConstantR0(b, 1), + xla_type)); + max_quantized = + Pow(two, ConvertElementType(num_bits - xla::ConstantR0(b, 1), + xla_type)) - + one; } else { - min_fixed = 0; - max_fixed = (1LL << num_bits_) - 1; + min_quantized = zero; + max_quantized = Pow(two, ConvertElementType(num_bits, xla_type)) - one; } - // From this we compute our scaling factor, s: - // - // s = (max_fixed - min_fixed) / (2 * m). - xla::XlaOp s = - b->Div(XlaHelpers::FloatLiteral(b, data_type, max_fixed - min_fixed), - b->Mul(XlaHelpers::FloatLiteral(b, data_type, 2.0), m)); + // Determine the maximum scaling factor that would scale + // [min_range, max_range] to not exceed [min_quantized, max_quantized], + // while keeping 0 unchanged. + xla::XlaOp scale_from_min_side = + Select(Gt(min_quantized * min_range, zero), min_quantized / min_range, + xla::MaxFiniteValue(b, xla_type)); + xla::XlaOp scale_from_max_side = + Select(Gt(max_quantized * max_range, zero), max_quantized / max_range, + xla::MaxFiniteValue(b, xla_type)); - // Now we can quantize and dequantize the elements of our tensor. An element - // e is transformed into e': - // - // e' = (e * s).round_to_nearest() / s. - xla::XlaOp result = b->Div(b->Round(b->Mul(input, s)), s); + // Note: Avoids changing the side of the range that determines scale. + xla::XlaOp cond = Lt(scale_from_min_side, scale_from_max_side); + xla::XlaOp scale = Select(cond, scale_from_min_side, scale_from_max_side); + xla::XlaOp inverse_scale = + Select(cond, min_range / min_quantized, max_range / max_quantized); + min_range = Select(cond, min_range, min_quantized * inverse_scale); + max_range = Select(cond, max_quantized * inverse_scale, max_range); + if (range_given_) { + // Note: The clamping here is to avoid overflow in the quantized type. + // The semantics of the op does not guarantee to clamp to the specified + // min_range and max_range - because we may have changed either min_range + // or max_range. + // No need to clamp to min_range and max_range if range_given_ == false as + // in that case they were measured from the tensor. + input = Clamp(min_range, input, max_range); + } + xla::XlaOp result = + Floor((input - min_range) * scale + half) * inverse_scale + min_range; ctx->SetOutput(0, result); } - int64 num_bits_; + protected: + int64 num_bits_ = -1; bool signed_input_; bool range_given_; }; -REGISTER_XLA_OP(Name("QuantizeAndDequantizeV2"), QuantizeAndDequantizeOp); +class QuantizeAndDequantizeV2Op : public QuantizeAndDequantizeOp { + public: + explicit QuantizeAndDequantizeV2Op(OpKernelConstruction* ctx) + : QuantizeAndDequantizeOp(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_)); + OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63), + errors::InvalidArgument("num_bits is out of range: ", num_bits_, + " with signed_input_ ", signed_input_)); + } +}; + +REGISTER_XLA_OP(Name("QuantizeAndDequantizeV2"), QuantizeAndDequantizeV2Op); +REGISTER_XLA_OP(Name("QuantizeAndDequantizeV3"), QuantizeAndDequantizeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 105be38fe26b6667e8b4ce6da92a3969cdc0c187..2da9340625db08b14b78340c471f096baf15689d 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -18,6 +18,7 @@ limitations under the License. // TODO(misard,phawkins): add tests. #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -25,6 +26,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -45,8 +48,8 @@ class RandomUniformOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp result = b->RngUniform(XlaHelpers::Zero(b, dtype), - XlaHelpers::One(b, dtype), xla_shape); + xla::XlaOp result = xla::RngUniform(XlaHelpers::Zero(b, dtype), + XlaHelpers::One(b, dtype), xla_shape); ctx->SetOutput(0, result); } @@ -71,57 +74,121 @@ class RandomShuffleOp : public XlaOpKernel { for (tensorflow::TensorShapeDim dimension : input_shape) { num_elements *= dimension.size; } + if (num_elements <= 1 || n <= 1) { // No shuffling is required, so copy input directly to output ctx->SetOutput(0, input); - } else { - // Generate the random swaps for the indices. - auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n}); - auto swaps = - builder->RngUniform(builder->ConstantR0(0), - builder->ConstantR0(n), swaps_shape); - - // Generate range(n) as the initial value for the indices to be swapped. - xla::XlaOp indices; - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, n, &indices)); - - // Swap the indices at i and swaps[i]. - auto swap_body_fn = [&](xla::XlaOp i, - gtl::ArraySlice loop_vars, - xla::XlaBuilder* builder) - -> xla::StatusOr> { - auto swaps = loop_vars[0]; - auto indices = loop_vars[1]; - i = builder->Reshape(i, {1}); - // temp = indices[i] - auto temp = builder->DynamicSlice(indices, i, {1}); - // swap_index = swaps[i] - auto swap_index = builder->DynamicSlice(swaps, i, {1}); - // swap_value = indices[swaps[i]] - auto swap_value = builder->DynamicSlice(indices, swap_index, {1}); - // indices[i] = indices[swaps[i]] - indices = builder->DynamicUpdateSlice(indices, swap_value, i); - // indices[swaps[i]] = temp - indices = builder->DynamicUpdateSlice(indices, temp, swap_index); - return std::vector{swaps, indices}; - }; - // for i in range(n): - auto swap_loop_result = - XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, - "indices_swap_loop", builder) - .ValueOrDie(); - auto swapped_indices = swap_loop_result[1]; - - // Gather the data using the swapped indices as the shuffled order. - auto indices_tensor_shape = TensorShape({n}); - DataType type = ctx->expected_output_dtype(0); - xla::XlaOp gather; - OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices, - indices_tensor_shape, - /*axis=*/0, /*indices_are_nd=*/false, type, - DT_INT32, builder, &gather)); - ctx->SetOutput(0, gather); + return; + } + + if (input_shape.dims() == 1) { + // For R1s, shuffle values by sorting instead of the obvious Fisher-Yates + // algorithm. Fisher-Yates is simple to implement and correct, but not + // easily parallelizable. For a sufficiently parallel architecture, it is + // faster to sort many times, than Fisher-Yates shuffle once. + + // Shuffle values by assigning each value a random key and sorting the + // keys. Keys can collide causing detectable patterns in the shuffled + // output. Collisions translates into more ascending sub-sequences in the + // shuffled output than would be expected by chance. To avoid collisions, + // the number of possible key values must be sufficiently large. + + // How are more than 2^32 keys created? In each loop iteration, the + // algorithm sorts by random keys. Conceptually, the earlier iterations + // are sorting on the lower-order bits of larger keys that are never + // actually assembled. + + // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is + // the number of possible keys and n is the number of values. If d = n^2, + // then the limit as n goes to infinity is 1/2. If d = n^3, then the limit + // as n goes to infinity is zero. + + // This implementation ensures that the key-space is greater than or equal + // to the cube of the number of values. The risk of collisions can be + // further reduced by increasing Exponent at the expense of + // performance. + + // For Exponent = 2, the expected number of collisions per shuffle is + // maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is + // about 1/2. + + // For Exponent = 3, the expected number of collisions per shuffle is + // maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is + // about 1/3255. + + // For Exponent = 4, the expected number of collisions per shuffle is + // maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is + // about 1/132622. + constexpr int Exponent = 3; + const int rounds = static_cast( + std::ceil(Exponent * std::log(num_elements) / std::log(kuint32max))); + + const xla::Shape key_shape = + xla::ShapeUtil::MakeShape(xla::U32, {num_elements}); + xla::XlaOp zero = xla::ConstantR0(builder, 0U); + + // Unfortunately, xla::RngUniform gives values in the half open interval + // rather than the closed interval, so instead of 2^32 possible keys there + // are only 2^32 - 1 (kuint32max). + xla::XlaOp max_value = xla::ConstantR0(builder, kuint32max); + + xla::XlaOp curr = input; + for (int i = 0; i < rounds; ++i) { + xla::XlaOp keys = xla::RngUniform(zero, max_value, key_shape); + xla::XlaOp sorted = xla::Sort(keys, curr); + curr = xla::GetTupleElement(sorted, 1); + } + + ctx->SetOutput(0, curr); + return; } + + // The Fisher-Yates algorithm. + + // Generate the random swaps for the indices. + auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n}); + auto swaps = + xla::RngUniform(xla::ConstantR0(builder, 0), + xla::ConstantR0(builder, n), swaps_shape); + + // Generate range(n) as the initial value for the indices to be swapped. + xla::XlaOp indices = xla::Iota(builder, xla::S32, n); + + // Swap the indices at i and swaps[i]. + auto swap_body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + xla::XlaBuilder* builder) + -> xla::StatusOr> { + auto swaps = loop_vars[0]; + auto indices = loop_vars[1]; + i = xla::Reshape(i, {1}); + // temp = indices[i] + auto temp = xla::DynamicSlice(indices, i, {1}); + // swap_index = swaps[i] + auto swap_index = xla::DynamicSlice(swaps, i, {1}); + // swap_value = indices[swaps[i]] + auto swap_value = xla::DynamicSlice(indices, swap_index, {1}); + // indices[i] = indices[swaps[i]] + indices = xla::DynamicUpdateSlice(indices, swap_value, i); + // indices[swaps[i]] = temp + indices = xla::DynamicUpdateSlice(indices, temp, swap_index); + return std::vector{swaps, indices}; + }; + // for i in range(n): + auto swap_loop_result = + XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, + "indices_swap_loop", builder) + .ValueOrDie(); + auto swapped_indices = swap_loop_result[1]; + + // Gather the data using the swapped indices as the shuffled order. + auto indices_tensor_shape = TensorShape({n}); + DataType type = ctx->expected_output_dtype(0); + xla::XlaOp gather; + OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices, + indices_tensor_shape, + /*axis=*/0, /*indices_are_nd=*/false, type, + DT_INT32, builder, &gather)); + ctx->SetOutput(0, gather); } private: @@ -152,7 +219,7 @@ class RandomUniformIntOp : public XlaOpKernel { auto minval = ctx->Input(1); auto maxval = ctx->Input(2); - ctx->SetOutput(0, ctx->builder()->RngUniform(minval, maxval, xla_shape)); + ctx->SetOutput(0, xla::RngUniform(minval, maxval, xla_shape)); } private: @@ -178,8 +245,8 @@ class RandomStandardNormalOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); // Normal distribution with a mean of 0 and a standard deviation of 1: - xla::XlaOp result = b->RngNormal(XlaHelpers::Zero(b, dtype), - XlaHelpers::One(b, dtype), xla_shape); + xla::XlaOp result = xla::RngNormal(XlaHelpers::Zero(b, dtype), + XlaHelpers::One(b, dtype), xla_shape); ctx->SetOutput(0, result); } @@ -205,59 +272,18 @@ class TruncatedNormalOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); - auto two_sd = [dtype](bool negate, xla::XlaBuilder* b) { - return XlaHelpers::FloatLiteral(b, dtype, negate ? -2.0 : 2.0); - }; - auto out_of_range_mask = [two_sd](xla::XlaOp candidate, - xla::XlaBuilder* b) { - xla::XlaOp too_large = b->Gt(candidate, two_sd(false, b)); - xla::XlaOp too_small = b->Lt(candidate, two_sd(true, b)); - return b->Or(too_large, too_small); - }; - - // The algorithm we're using is roughly: - // - // while (any(candidate < mean-2*sd || candidate > mean+2*sd)) { - // out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd - // candidate = select(out_of_range_mask, rng_normal(), candidate) - // } - std::vector initial_values = { - // The current candidate. - b->Broadcast(XlaHelpers::Zero(b, dtype), shape.dim_sizes()), - // The to_resample mask, where 'true' identifies a location in the - // current candidate that is out of range and must be regenerated. - b->Broadcast(b->ConstantR0(true), shape.dim_sizes()), - // Is any element in the mask true? - b->ConstantR0(true)}; - auto condition = [&](gtl::ArraySlice values, - xla::XlaBuilder* b) -> xla::StatusOr { - // Continue while any element in the mask is true. - return values[2]; - }; - auto body = - [&](gtl::ArraySlice values, - xla::XlaBuilder* b) -> xla::StatusOr> { - xla::XlaOp candidate = values[0]; - xla::XlaOp to_resample = values[1]; - xla::XlaOp mean = XlaHelpers::Zero(b, dtype); - xla::XlaOp stddev = XlaHelpers::One(b, dtype); - candidate = b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), - candidate); - // Compute a new to_resample mask, and determine whether any value is - // still out of range. - to_resample = out_of_range_mask(candidate, b); - TF_ASSIGN_OR_RETURN(xla::XlaOp done, Any(to_resample, b)); - return std::vector{candidate, to_resample, done}; - }; - auto result = - XlaWhileLoop(condition, body, initial_values, "truncated_normal", b); - OP_REQUIRES_OK(ctx, result.status()); - ctx->SetOutput(0, result.ValueOrDie()[0]); + xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype, 1.0); + xla::XlaOp min_positive = + XlaHelpers::FloatLiteral(b, dtype, std::numeric_limits::min()); + auto uniform = xla::RngUniform(min_positive, one, xla_shape); + ctx->SetOutput(0, TruncatedNormal(uniform)); } }; -REGISTER_XLA_OP(Name("TruncatedNormal").CompileTimeConstInput("shape"), +REGISTER_XLA_OP(Name("TruncatedNormal") + .CompileTimeConstInput("shape") + .TypeConstraint("dtype", DT_FLOAT), TruncatedNormalOp); -} // anonymous namespace +} // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index 08894489ac77bbbe4ddb067c06a6d031a537697d..b11a4ce36da9907ce8fe377c075023a4540797fa 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -19,6 +19,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" @@ -98,10 +100,10 @@ class ReduceWindowOp : public XlaOpKernel { { std::unique_ptr cb = builder->CreateSubBuilder("wrapper"); - auto x = cb->Parameter(0, scalar_shape, "x"); - auto y = cb->Parameter(1, scalar_shape, "y"); - auto outputs = cb->Call(*reducer.computation, {x, y}); - cb->GetTupleElement(outputs, 0); + auto x = xla::Parameter(cb.get(), 0, scalar_shape, "x"); + auto y = xla::Parameter(cb.get(), 1, scalar_shape, "y"); + auto outputs = xla::Call(cb.get(), *reducer.computation, {x, y}); + xla::GetTupleElement(outputs, 0); xla::StatusOr result = cb->Build(); OP_REQUIRES_OK(context, result.status()); wrapper = std::move(result.ValueOrDie()); @@ -112,7 +114,7 @@ class ReduceWindowOp : public XlaOpKernel { padding[i] = {padding_low_[i], padding_high_[i]}; } - xla::XlaOp output = builder->ReduceWindowWithGeneralPadding( + xla::XlaOp output = xla::ReduceWindowWithGeneralPadding( context->Input(0), context->Input(1), wrapper, window_dimensions_, window_strides_, padding); context->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 0f425637795e9633a8e36f921000ee2f5e25813a..0d260fa8fcaa513d7854c1e9215952404d555c70 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -19,7 +19,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -31,11 +33,11 @@ class SumOp : public XlaReductionOp { : XlaReductionOp(ctx, XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::Zero(builder, reduction_type_); + return xla::Zero(builder, xla_reduction_type_); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Add(scalar_lhs, scalar_rhs); + xla::Add(scalar_lhs, scalar_rhs); } }; @@ -48,12 +50,12 @@ class ProdOp : public XlaReductionOp { XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::One(builder, reduction_type_); + return xla::One(builder, xla_reduction_type_); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Mul(scalar_lhs, scalar_rhs); + xla::Mul(scalar_lhs, scalar_rhs); } }; @@ -66,12 +68,12 @@ class MinOp : public XlaReductionOp { : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::MaxValue(builder, reduction_type_); + return xla::MaxValue(builder, xla_reduction_type_); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Min(scalar_lhs, scalar_rhs); + xla::Min(scalar_lhs, scalar_rhs); } }; @@ -83,12 +85,12 @@ class MaxOp : public XlaReductionOp { : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::MinValue(builder, reduction_type_); + return xla::MinValue(builder, xla_reduction_type_); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Max(scalar_lhs, scalar_rhs); + xla::Max(scalar_lhs, scalar_rhs); } }; @@ -101,11 +103,11 @@ class MeanOp : public XlaReductionOp { XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::Zero(builder, reduction_type_); + return xla::Zero(builder, xla_reduction_type_); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Add(scalar_lhs, scalar_rhs); + xla::Add(scalar_lhs, scalar_rhs); } xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder, @@ -113,7 +115,7 @@ class MeanOp : public XlaReductionOp { int64 num_elements_reduced) override { auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0), num_elements_reduced); - return builder->Div(reduce_output, divisor); + return reduce_output / divisor; } }; @@ -126,12 +128,12 @@ class AllOp : public XlaReductionOp { : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return builder->ConstantR0(true); + return xla::ConstantR0(builder, true); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->And(scalar_lhs, scalar_rhs); + xla::And(scalar_lhs, scalar_rhs); } }; @@ -143,12 +145,12 @@ class AnyOp : public XlaReductionOp { : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return builder->ConstantR0(false); + return xla::ConstantR0(builder, false); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { - builder->Or(scalar_lhs, scalar_rhs); + xla::Or(scalar_lhs, scalar_rhs); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index 2ecfb854a1c8625524d4f1199af3927edd204926..466e79828d111ee7cadcf713703e8f252c63e62c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -19,7 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -64,6 +64,7 @@ class XlaReductionOp : public XlaOpKernel { protected: DataType reduction_type_; + xla::PrimitiveType xla_reduction_type_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 4fd5bfd03999a7f8b7bb081cc4b03aa1434d4c3d..b52f0a0ab6290f2019bb58120be5c2364ec15bb6 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -19,7 +19,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -31,6 +33,8 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx, OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt})); OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_)); + OP_REQUIRES_OK( + ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_)); } // Unless BuildFinalizer is overridden the reduction has no @@ -56,9 +60,9 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { // Evaluate the constant, reshaping to a 1-vector if it is a scalar. xla::Literal axes_literal; - OP_REQUIRES_OK(ctx, - ctx->ConstantInputReshaped( - 1, {axes_tensor_shape.num_elements()}, &axes_literal)); + OP_REQUIRES_OK( + ctx, ctx->ConstantInputReshaped(1, {axes_tensor_shape.num_elements()}, + &axes_literal)); VLOG(1) << "data shape: " << data_shape.DebugString(); VLOG(1) << "axes : " << axes_literal.ToString(); @@ -101,20 +105,20 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type)); - auto data = b->ConvertElementType(ctx->Input(0), type); + auto data = xla::ConvertElementType(ctx->Input(0), type); // Call virtual method to get the initial value. - auto initial = b->ConvertElementType(InitialValue(b), type); + auto initial = xla::ConvertElementType(InitialValue(b), type); // Make two scalar parameters of the desired type for the lambda. - auto rx = r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x"); - auto ry = r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y"); + auto rx = xla::Parameter(&r, 0, xla::ShapeUtil::MakeShape(type, {}), "x"); + auto ry = xla::Parameter(&r, 1, xla::ShapeUtil::MakeShape(type, {}), "y"); // Call virtual method to build the reduction lambda. BuildReducer(&r, rx, ry); xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie(); - auto reduce = b->Reduce(data, initial, reduction_computation, xla_axes); + auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes); auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); auto finalized = BuildFinalizer(b, deconverted, num_elements_reduced); - auto result = keep_dims_ ? b->Reshape(finalized, final_shape) : finalized; + auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized; ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index ba7d484d53d7258edaa5bc42fa116cf16e94835b..d35777ccb1271ec6a7c9972c714d06b2415d9c34 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" @@ -34,7 +34,7 @@ class ReluOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* builder = ctx->builder(); auto zero = XlaHelpers::Zero(builder, input_type(0)); - ctx->SetOutput(0, builder->Max(zero, ctx->Input(0))); + ctx->SetOutput(0, xla::Max(zero, ctx->Input(0))); } }; @@ -46,7 +46,7 @@ class Relu6Op : public XlaOpKernel { xla::XlaBuilder* builder = ctx->builder(); auto zero = XlaHelpers::Zero(builder, input_type(0)); auto six = XlaHelpers::IntegerLiteral(builder, input_type(0), 6); - ctx->SetOutput(0, builder->Clamp(zero, ctx->Input(0), six)); + ctx->SetOutput(0, xla::Clamp(zero, ctx->Input(0), six)); } }; @@ -59,9 +59,9 @@ class ReluGradOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); const TensorShape shape = ctx->InputShape(0); const auto zero = - b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); - const auto pred = b->Gt(ctx->Input(1), zero); - ctx->SetOutput(0, b->Select(pred, ctx->Input(0), zero)); + xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); + const auto pred = xla::Gt(ctx->Input(1), zero); + ctx->SetOutput(0, xla::Select(pred, ctx->Input(0), zero)); } }; @@ -74,12 +74,12 @@ class Relu6GradOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); const TensorShape shape = ctx->InputShape(0); const auto zero = - b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); - const auto six = b->Broadcast( + xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); + const auto six = xla::Broadcast( XlaHelpers::IntegerLiteral(b, input_type(0), 6), shape.dim_sizes()); - auto out = - b->Select(b->And(b->Lt(ctx->Input(1), six), b->Gt(ctx->Input(1), zero)), - ctx->Input(0), zero); + auto out = xla::Select( + xla::And(xla::Lt(ctx->Input(1), six), xla::Gt(ctx->Input(1), zero)), + ctx->Input(0), zero); ctx->SetOutput(0, out); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index af4d64b159c09ed7e01017f25a2b23e58542dc3c..121750a82a8c5cbe940068555ad273b7e0d22dfc 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -19,7 +19,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -90,8 +91,7 @@ class ReshapeOp : public XlaOpKernel { VLOG(1) << "Reshape " << input_shape.DebugString() << " " << shape.DebugString(); - ctx->SetOutput(0, - ctx->builder()->Reshape(ctx->Input(0), shape.dim_sizes())); + ctx->SetOutput(0, xla::Reshape(ctx->Input(0), shape.dim_sizes())); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index a711278638444be01fb865561957702368b75114..64900e4709fd3e16d21096b0cfff8922906cb0d4 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -16,7 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -62,15 +63,24 @@ class RetvalOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); } else { TensorShape shape = ctx->InputShape(0); - TensorShape representation_shape = - tc.is_entry_computation() - ? tc.RepresentationShape(shape, ctx->input_type(0)) - : shape; + ctx->SetStatus(is_constant.status()); + TensorShape representation_shape; + if (tc.is_entry_computation()) { + xla::StatusOr shape_or_status = + tc.RepresentationShape(shape, ctx->input_type(0)); + if (!shape_or_status.ok()) { + ctx->SetStatus(shape_or_status.status()); + return; + } else { + representation_shape = shape_or_status.ValueOrDie(); + } + } else { + representation_shape = shape; + } xla::XlaOp output = input; if (tc.is_entry_computation()) { - output = - ctx->builder()->Reshape(input, representation_shape.dim_sizes()); + output = xla::Reshape(input, representation_shape.dim_sizes()); } else { // The core from which a return value is returned depends on the // device assignment of the input to the retval. Since we can't change @@ -78,8 +88,8 @@ class RetvalOp : public XlaOpKernel { // introduce an operator here, even if the shape does not change. // TODO(b/76097077): propagate device assignments onto arguments and // return values of functions, and then reshape unconditionally. - output = ctx->builder()->GetTupleElement( - ctx->builder()->Tuple({output}), 0); + output = + xla::GetTupleElement(xla::Tuple(ctx->builder(), {output}), 0); } tc.AddRetval(index_, dtype_, shape, output); } @@ -94,7 +104,7 @@ class RetvalOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); }; -REGISTER_XLA_OP(Name("_Retval"), RetvalOp); +REGISTER_XLA_OP(Name("_Retval").CompilationOnly(), RetvalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index 2872a3c4d49d0d269aa3d216887a5c32cd51f1c3..d962ef4a5f53470838643541f8a1e693d2f4011c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -19,7 +19,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -62,7 +63,7 @@ class ReverseOp : public XlaOpKernel { } } - ctx->SetOutput(0, ctx->builder()->Rev(ctx->Input(0), dimensions)); + ctx->SetOutput(0, xla::Rev(ctx->Input(0), dimensions)); } }; @@ -100,7 +101,7 @@ class ReverseV2Op : public XlaOpKernel { x_shape.dims(), ").")); } - ctx->SetOutput(0, ctx->builder()->Rev(ctx->Input(0), axes)); + ctx->SetOutput(0, xla::Rev(ctx->Input(0), axes)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 5d1c05268493f4f6404c40a4092a71f1e5b3f3b9..03a50ef8a059e5a005c4cc2e5e98acedfea8619a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -17,6 +17,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -85,103 +87,96 @@ class ReverseSequenceOp : public XlaOpKernel { auto condition_builder = builder->CreateSubBuilder("reverse_sequence_condition"); { - auto param = condition_builder->Parameter(0, tuple_shape, "param"); - auto i = condition_builder->GetTupleElement(param, 0); - condition_builder->Lt( - i, XlaHelpers::IntegerLiteral(condition_builder.get(), seq_lens_type, - batch_size)); + auto param = + xla::Parameter(condition_builder.get(), 0, tuple_shape, "param"); + auto i = xla::GetTupleElement(param, 0); + xla::Lt(i, XlaHelpers::IntegerLiteral(condition_builder.get(), + seq_lens_type, batch_size)); } auto condition = condition_builder->Build(); OP_REQUIRES_OK(context, condition.status()); auto body_builder = builder->CreateSubBuilder("reverse_sequence_body"); { - auto param = body_builder->Parameter(0, tuple_shape, "param"); - auto i = body_builder->GetTupleElement(param, 0); - auto seq_lens = body_builder->GetTupleElement(param, 1); - auto output = body_builder->GetTupleElement(param, 2); + auto param = xla::Parameter(body_builder.get(), 0, tuple_shape, "param"); + auto i = xla::GetTupleElement(param, 0); + auto seq_lens = xla::GetTupleElement(param, 1); + auto output = xla::GetTupleElement(param, 2); // seq_len is the sequence length of the current batch element (rank 1) - auto seq_len = body_builder->DynamicSlice( - seq_lens, body_builder->Reshape(i, {1}), {1}); + auto seq_len = xla::DynamicSlice(seq_lens, xla::Reshape(i, {1}), {1}); // Indices is the offset of the batch element in the input. - auto batch_element_indices = body_builder->Broadcast( - XlaHelpers::Zero(body_builder.get(), seq_lens_type), - {input_shape.dims()}); - batch_element_indices = body_builder->DynamicUpdateSlice( - batch_element_indices, body_builder->Reshape(i, {1}), - body_builder->Reshape( - XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, - batch_dim_), - {1})); + auto batch_element_indices = + xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type), + {input_shape.dims()}); + batch_element_indices = xla::DynamicUpdateSlice( + batch_element_indices, xla::Reshape(i, {1}), + xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(), + seq_lens_type, batch_dim_), + {1})); // Slice out the current batch element and pad it out in the sequence // dimension. TensorShape slice_shape = input_shape; slice_shape.set_dim(batch_dim_, 1); slice_shape.set_dim(seq_dim_, max_seq_len); - auto slice = body_builder->DynamicSlice(output, batch_element_indices, - slice_shape.dim_sizes()); + auto slice = xla::DynamicSlice(output, batch_element_indices, + slice_shape.dim_sizes()); auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims()); padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high( slice_shape.dim_size(seq_dim_)); - slice = body_builder->Pad( - slice, XlaHelpers::Zero(body_builder.get(), input_type), - padding_config); + slice = xla::Pad(slice, XlaHelpers::Zero(body_builder.get(), input_type), + padding_config); // Now slice out the reversed sequence from its actual start. // sequence_start_indices is the offset of the start of the reversed // sequence in the input. The slice will go into the padding, however, we // will mask off these elements and replace them with elements from the // original input so their values do not matter. - auto sequence_start_indices = body_builder->Broadcast( - XlaHelpers::Zero(body_builder.get(), seq_lens_type), - {slice_shape.dims()}); - sequence_start_indices = body_builder->DynamicUpdateSlice( + auto sequence_start_indices = + xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type), + {slice_shape.dims()}); + sequence_start_indices = xla::DynamicUpdateSlice( sequence_start_indices, - body_builder->Sub(XlaHelpers::IntegerLiteral( - body_builder.get(), seq_lens_type, max_seq_len), - seq_len), - body_builder->Reshape( - XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, - seq_dim_), - {1})); - slice = body_builder->DynamicSlice(slice, sequence_start_indices, - slice_shape.dim_sizes()); + xla::Sub(XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, + max_seq_len), + seq_len), + xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(), + seq_lens_type, seq_dim_), + {1})); + slice = xla::DynamicSlice(slice, sequence_start_indices, + slice_shape.dim_sizes()); // Shift the reversed sequence to the left. - output = body_builder->DynamicUpdateSlice(output, slice, - batch_element_indices); + output = xla::DynamicUpdateSlice(output, slice, batch_element_indices); - body_builder->Tuple( - {body_builder->Add( - i, XlaHelpers::One(body_builder.get(), seq_lens_type)), + xla::Tuple( + body_builder.get(), + {xla::Add(i, XlaHelpers::One(body_builder.get(), seq_lens_type)), seq_lens, output}); } auto body = body_builder->Build(); OP_REQUIRES_OK(context, body.status()); - auto loop_output = builder->While( + auto loop_output = xla::While( condition.ValueOrDie(), body.ValueOrDie(), - builder->Tuple({XlaHelpers::Zero(builder, seq_lens_type), seq_lens, - builder->Rev(input, {seq_dim_})})); - auto output = builder->GetTupleElement(loop_output, 2); + xla::Tuple(builder, {XlaHelpers::Zero(builder, seq_lens_type), seq_lens, + xla::Rev(input, {seq_dim_})})); + auto output = xla::GetTupleElement(loop_output, 2); // Mask out elements after the sequence length. - xla::XlaOp iota; - OP_REQUIRES_OK( - context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota)); + xla::XlaOp iota = + xla::Iota(builder, seq_lens_xla_shape.element_type(), max_seq_len); std::vector dims(input_shape.dims(), 1); dims[batch_dim_] = batch_size; - auto mask = builder->Lt(iota, builder->Reshape(seq_lens, dims), {seq_dim_}); + auto mask = xla::Lt(iota, xla::Reshape(seq_lens, dims), {seq_dim_}); // Broadcast the mask up to the input shape. - mask = - builder->Or(mask, builder->Broadcast(builder->ConstantR0(false), - input_shape.dim_sizes())); + mask = xla::Or(mask, xla::Broadcast(xla::ConstantR0(builder, false), + input_shape.dim_sizes())); - output = builder->Select(mask, output, input); + output = xla::Select(mask, output, input); context->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 1819fb543317eed15b2fe0518d74aba5c564697d..ab094d7dd1ce9856a3c2854fd2776827d6c4b76f 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -20,7 +20,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" @@ -100,7 +102,7 @@ class ScanOp : public XlaOpKernel { init = XlaHelpers::One(builder, dtype); reducer = ctx->GetOrCreateMul(dtype); } - auto output = builder->ReduceWindowWithGeneralPadding( + auto output = xla::ReduceWindowWithGeneralPadding( XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init, *reducer, window_dims, window_strides, padding); output = @@ -110,12 +112,12 @@ class ScanOp : public XlaOpKernel { // of all the input elements. Slice off this extra "last" element. if (exclusive_) { if (reverse_) { - output = builder->SliceInDim(output, 1, input_shape.dim_size(axis) + 1, - 1, axis); + output = + xla::SliceInDim(output, 1, input_shape.dim_size(axis) + 1, 1, axis); } else { output = - builder->SliceInDim(output, 0, input_shape.dim_size(axis), 1, axis); + xla::SliceInDim(output, 0, input_shape.dim_size(axis), 1, axis); } } ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index f2c63b4f9083ad3c7dd7cf318dc22def1e99fa9f..f1f32699fee5f03f603f830722fe65622dee5d3e 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -103,8 +104,8 @@ class ScatterNdOp : public XlaOpKernel { updates_shape)); xla::XlaBuilder* builder = context->builder(); - auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype), - buffer_shape.dim_sizes()); + auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype), + buffer_shape.dim_sizes()); auto indices = context->Input(0); auto updates = context->Input(1); auto result = diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 664078ca16c6d5d4b57c4a8c661ad0848f30dd7d..b22ecb7c6dbb42a33a4f4d90b18b20816df16a50 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -14,20 +14,30 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/lib/scatter.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { -class UnsortedSegmentSum : public XlaOpKernel { +class UnsortedSegmentReduce : public XlaOpKernel { public: - explicit UnsortedSegmentSum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + explicit UnsortedSegmentReduce(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + DataType dtype; + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &type_)); } + // The initial value to initialize elements of the output to. + virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0; + + // A function to combine two scalars with the same index (e.g., sum). + virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) = 0; + void Compile(XlaOpKernelContext* ctx) override { // output = unsorted_segment_sum(data, indices, num_segments) // Compute a tensor such that: @@ -50,28 +60,28 @@ class UnsortedSegmentSum : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments)); OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(), - errors::InvalidArgument( - "UnsortedSegmentSum requires that indices' rank be" - " less than or equal to data's rank.")); + errors::InvalidArgument(type_string(), + " requires that indices' rank be" + " less than or equal to data's rank.")); // Validate that indices.shape is a prefix of data.shape. for (int d = 0; d < indices_shape.dims(); ++d) { - OP_REQUIRES(ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)), - errors::InvalidArgument( - "UnsortedSegmentSum requires indices shape to be prefix" - " of data_shape, but dimension ", - d, " differs ", data_shape.dim_size(d), " vs. ", - indices_shape.dim_size(d))); + OP_REQUIRES( + ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)), + errors::InvalidArgument(type_string(), + " requires indices shape to be prefix" + " of data_shape, but dimension ", + d, " differs ", data_shape.dim_size(d), + " vs. ", indices_shape.dim_size(d))); } xla::XlaBuilder* builder = ctx->builder(); TensorShape buffer_shape = data_shape; buffer_shape.RemoveDimRange(0, indices_shape.dims()); buffer_shape.InsertDim(0, num_segments); - auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype_), - buffer_shape.dim_sizes()); + auto buffer = + xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes()); - auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) { - return builder->Add(a, b); - }; + auto combiner = [this](xla::XlaOp a, xla::XlaOp b, + xla::XlaBuilder* builder) { return Combine(a, b); }; auto result = XlaScatter(buffer, /*updates=*/data, indices, /*indices_are_vectors=*/false, combiner, builder); @@ -79,13 +89,73 @@ class UnsortedSegmentSum : public XlaOpKernel { ctx->SetOutput(0, result.ValueOrDie()); } - private: - DataType dtype_; + protected: + xla::PrimitiveType type_; +}; + +class UnsortedSegmentSum : public UnsortedSegmentReduce { + public: + explicit UnsortedSegmentSum(OpKernelConstruction* ctx) + : UnsortedSegmentReduce(ctx) {} + + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return xla::Zero(builder, type_); + }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a + b; }; }; REGISTER_XLA_OP( Name("UnsortedSegmentSum").CompileTimeConstInput("num_segments"), UnsortedSegmentSum); +class UnsortedSegmentProd : public UnsortedSegmentReduce { + public: + explicit UnsortedSegmentProd(OpKernelConstruction* ctx) + : UnsortedSegmentReduce(ctx) {} + + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return xla::One(builder, type_); + }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a * b; }; +}; + +REGISTER_XLA_OP( + Name("UnsortedSegmentProd").CompileTimeConstInput("num_segments"), + UnsortedSegmentProd); + +class UnsortedSegmentMin : public UnsortedSegmentReduce { + public: + explicit UnsortedSegmentMin(OpKernelConstruction* ctx) + : UnsortedSegmentReduce(ctx) {} + + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return xla::MaxFiniteValue(builder, type_); + }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { + return xla::Min(a, b); + }; +}; + +REGISTER_XLA_OP( + Name("UnsortedSegmentMin").CompileTimeConstInput("num_segments"), + UnsortedSegmentMin); + +class UnsortedSegmentMax : public UnsortedSegmentReduce { + public: + explicit UnsortedSegmentMax(OpKernelConstruction* ctx) + : UnsortedSegmentReduce(ctx) {} + + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { + return xla::MinFiniteValue(builder, type_); + }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { + return xla::Max(a, b); + }; +}; + +REGISTER_XLA_OP( + Name("UnsortedSegmentMax").CompileTimeConstInput("num_segments"), + UnsortedSegmentMax); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index f9f48164d63492b057d4950abfc2ca6153e44870..6ce50efb4aa6e3434a7c6009cf9f52f6cff9cc9f 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -40,8 +41,6 @@ class SelectOp : public XlaOpKernel { "'then' and 'else' must have the same size. but received: ", then_shape.DebugString(), " vs. ", else_shape.DebugString())); - xla::XlaBuilder* builder = ctx->builder(); - auto cond_handle = ctx->Input(0); auto then_handle = ctx->Input(1); auto else_handle = ctx->Input(2); @@ -69,14 +68,14 @@ class SelectOp : public XlaOpKernel { const auto dim_sizes = then_shape.dim_sizes(); gtl::ArraySlice bdims = dim_sizes; bdims.pop_front(); - cond_handle = builder->Broadcast(cond_handle, bdims); + cond_handle = xla::Broadcast(cond_handle, bdims); std::vector dim_order(then_shape.dims()); dim_order[0] = then_shape.dims() - 1; std::iota(dim_order.begin() + 1, dim_order.end(), 0); - cond_handle = builder->Transpose(cond_handle, dim_order); + cond_handle = xla::Transpose(cond_handle, dim_order); } - ctx->SetOutput(0, builder->Select(cond_handle, then_handle, else_handle)); + ctx->SetOutput(0, xla::Select(cond_handle, then_handle, else_handle)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index 9ce01d0d44509bbcbea18afdb4210a675834bb6d..a7f5a8f1698b9d02560de427d356e9e6be5caa7c 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" @@ -45,7 +45,7 @@ void SendOp::Compile(XlaOpKernelContext* ctx) { XlaCompiler* compiler = XlaContext::Get(ctx).compiler(); xla::ChannelHandle channel; OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel)); - ctx->builder()->Send(ctx->Input(0), channel); + xla::Send(ctx->Input(0), channel); } REGISTER_XLA_OP(Name("XlaSend"), SendOp); @@ -76,7 +76,7 @@ void RecvOp::Compile(XlaOpKernelContext* ctx) { XlaCompiler* compiler = XlaContext::Get(ctx).compiler(); xla::ChannelHandle channel; OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel)); - ctx->SetOutput(0, ctx->builder()->Recv(shape_, channel)); + ctx->SetOutput(0, xla::Recv(ctx->builder(), shape_, channel)); } REGISTER_XLA_OP(Name("XlaRecv"), RecvOp); diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 2c31f8d90891924f6f86a54ccf548de4df87f3bd..25a5bcbe1dd27d741ce3b74125ba9ce425ee78f3 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -55,9 +55,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { // The type-specific part of the implementation of Range. template -Status CreateRangeTensor(const xla::Literal& start_literal, - const xla::Literal& limit_literal, - const xla::Literal& delta_literal, Tensor* output) { +Status CreateRangeTensor(const xla::LiteralSlice& start_literal, + const xla::LiteralSlice& limit_literal, + const xla::LiteralSlice& delta_literal, + Tensor* output) { T start = start_literal.Get({}); T limit = limit_literal.Get({}); T delta = delta_literal.Get({}); @@ -67,13 +68,13 @@ Status CreateRangeTensor(const xla::Literal& start_literal, } if (delta > 0) { if (start > limit) { - return errors::InvalidArgument("Requires start <= limit when delta > 0: ", - start, "/", limit); + return errors::InvalidArgument( + "Requires start <= limit when delta > 0: ", start, "/", limit); } } else { if (start < limit) { - return errors::InvalidArgument("Requires start >= limit when delta < 0: ", - start, "/", limit); + return errors::InvalidArgument( + "Requires start >= limit when delta < 0: ", start, "/", limit); } } int64 size = diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index d59720bef742c7441ee01a954247013559bb909c..4e0cf99d8e7ff45ed9145981b5e2e637ce4d4e4b 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -147,7 +148,7 @@ class ExpandDimsOp : public XlaOpKernel { dim = std::min(dim, existing_dims_size); new_shape.emplace(new_shape.begin() + dim, 1); - ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape)); + ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape)); } }; REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstInput("dim"), ExpandDimsOp); @@ -204,7 +205,7 @@ class SqueezeOp : public XlaOpKernel { } } - ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape)); + ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape)); } private: @@ -221,7 +222,7 @@ class ZerosLikeOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); - ctx->SetOutput(0, ctx->builder()->Broadcast(zero, input_shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes())); } }; @@ -235,7 +236,7 @@ class OnesLikeOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); auto one = XlaHelpers::One(ctx->builder(), input_type(0)); - ctx->SetOutput(0, ctx->builder()->Broadcast(one, input_shape.dim_sizes())); + ctx->SetOutput(0, xla::Broadcast(one, input_shape.dim_sizes())); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index be1e97bf26fa4cde1b741c8d0b843a85ce33a59c..6adc3c58de63ee70c26bed47eebef955893df4a5 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -92,8 +93,7 @@ class SliceOp : public XlaOpKernel { limits.push_back(begin[i] + size[i]); } std::vector strides(begin.size(), 1); - ctx->SetOutput( - 0, ctx->builder()->Slice(ctx->Input(0), begin, limits, strides)); + ctx->SetOutput(0, xla::Slice(ctx->Input(0), begin, limits, strides)); } else { // `begin` is not a compile-time constant. for (int i = 0; i < input_dims; ++i) { @@ -106,8 +106,7 @@ class SliceOp : public XlaOpKernel { input_shape.dim_size(i), "], but ", "got ", size[i])); } - ctx->SetOutput( - 0, ctx->builder()->DynamicSlice(ctx->Input(0), ctx->Input(1), size)); + ctx->SetOutput(0, xla::DynamicSlice(ctx->Input(0), ctx->Input(1), size)); } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index bbf5ee8b12186a582666121b1df5d8b7d881863e..025ba827410f1a9f993a8a1855558a2daa86609b 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -15,9 +15,13 @@ limitations under the License. // XLA-specific Ops for softmax. +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -34,13 +38,18 @@ class SoftmaxOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape logits_shape = ctx->InputShape(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape), - errors::InvalidArgument("logits must be 2-dimensional")); + OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(logits_shape), + errors::InvalidArgument("logits must have >= 1 dimension, got ", + logits_shape.DebugString())); - const int kBatchDim = 0; - const int kClassDim = 1; + // Major dimensions are batch dimensions, minor dimension is the class + // dimension. + std::vector batch_dims(logits_shape.dims() - 1); + std::iota(batch_dims.begin(), batch_dims.end(), 0); + const int kClassDim = logits_shape.dims() - 1; const DataType type = input_type(0); + const xla::PrimitiveType xla_type = ctx->input_xla_type(0); auto logits = ctx->Input(0); xla::XlaBuilder* const b = ctx->builder(); @@ -48,24 +57,27 @@ class SoftmaxOp : public XlaOpKernel { // Find the max in each batch, resulting in a tensor of shape [batch] auto logits_max = - b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); + xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim}); // Subtract the max in batch b from every element in batch b. Broadcasts // along the batch dimension. - auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim}); - auto exp_shifted = b->Exp(shifted_logits); + auto shifted_logits = xla::Sub(logits, logits_max, batch_dims); + auto exp_shifted = xla::Exp(shifted_logits); const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); + xla::PrimitiveType xla_accumulation_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(accumulation_type, + &xla_accumulation_type)); auto converted = - XlaHelpers::ConvertElementType(b, exp_shifted, accumulation_type); + xla::ConvertElementType(exp_shifted, xla_accumulation_type); auto reduce = - b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + xla::Reduce(converted, xla::Zero(b, xla_accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); auto sum = XlaHelpers::ConvertElementType(b, reduce, type); auto softmax = log_ // softmax = shifted_logits - log(sum(exp(shifted_logits))) - ? b->Sub(shifted_logits, b->Log(sum), {kBatchDim}) + ? xla::Sub(shifted_logits, xla::Log(sum), batch_dims) // softmax = exp(shifted_logits) / sum(exp(shifted_logits)) - : b->Div(exp_shifted, sum, {kBatchDim}); + : xla::Div(exp_shifted, sum, batch_dims); ctx->SetOutput(0, softmax); } @@ -77,8 +89,8 @@ REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp); REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp); std::pair CrossEntropyWithLogits( - XlaOpKernelContext* ctx, DataType type, const xla::XlaOp& logits, - const xla::XlaOp& labels) { + XlaOpKernelContext* ctx, DataType type, xla::PrimitiveType xla_type, + xla::XlaOp logits, xla::XlaOp labels) { const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type); const int kBatchDim = 0; @@ -87,43 +99,44 @@ std::pair CrossEntropyWithLogits( xla::XlaBuilder* b = ctx->builder(); // Find the max in each batch, resulting in a tensor of shape [batch] auto logits_max = - b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); + xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim}); // Subtract the max in batch b from every element in batch b. // Broadcasts along the batch dimension. - auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim}); + auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim}); // exp(logits - max_logits) - auto exp_shifted_logits = b->Exp(shifted_logits); + auto exp_shifted_logits = xla::Exp(shifted_logits); // sum_{class} (exp(logits - max_logits)) const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); auto converted = XlaHelpers::ConvertElementType(b, exp_shifted_logits, accumulation_type); - auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto reduce = + xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); auto sum_exp = XlaHelpers::ConvertElementType(b, reduce, type); // log(sum(exp(logits - max_logits))) - auto log_sum_exp = b->Log(sum_exp); + auto log_sum_exp = xla::Log(sum_exp); // sum(-labels * // ((logits - max_logits) - log(sum(exp(logits - max_logits))))) // along classes // (The subtraction broadcasts along the batch dimension.) - auto sub = b->Sub(shifted_logits, log_sum_exp, {kBatchDim}); - auto mul = b->Mul(b->Neg(labels), sub); + auto sub = xla::Sub(shifted_logits, log_sum_exp, {kBatchDim}); + auto mul = xla::Mul(xla::Neg(labels), sub); auto sum = - b->Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type), - XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + xla::Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); auto loss = XlaHelpers::ConvertElementType(b, sum, type); // backprop: prob - labels, where // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) // (where the division broadcasts along the batch dimension) xla::XlaOp backprop = - b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels); + xla::Sub(xla::Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels); return {loss, backprop}; } @@ -146,12 +159,13 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel { // check that "labels" is a matrix too. const DataType type = input_type(0); + const xla::PrimitiveType xla_type = ctx->input_xla_type(0); auto logits = ctx->Input(0); auto labels = ctx->Input(1); xla::XlaOp loss, backprop; std::tie(loss, backprop) = - CrossEntropyWithLogits(ctx, type, logits, labels); + CrossEntropyWithLogits(ctx, type, xla_type, logits, labels); ctx->SetOutput(0, loss); ctx->SetOutput(1, backprop); } @@ -187,8 +201,9 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { int64 batch_size = logits_shape.dim_size(0); int64 depth = logits_shape.dim_size(1); - DataType logits_type = input_type(0); - DataType indices_type = input_type(1); + const DataType logits_type = input_type(0); + const xla::PrimitiveType xla_logits_type = ctx->input_xla_type(0); + const DataType indices_type = input_type(1); xla::XlaOp indices = ctx->Input(1); @@ -206,20 +221,18 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { // Builds a vector of {batch_size} that is 0 if the index is in range, or // NaN otherwise; then add that vector to the labels to force out-of-range // values to NaNs. - xla::XlaOp nan_or_zero = builder->Select( - builder->And( - builder->Le(XlaHelpers::Zero(builder, indices_type), indices), - builder->Lt(indices, XlaHelpers::IntegerLiteral( - builder, indices_type, depth))), - builder->Broadcast(XlaHelpers::Zero(builder, logits_type), - {batch_size}), - builder->Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN), - {batch_size})); - labels = builder->Add(labels, nan_or_zero, {0}); + xla::XlaOp nan_or_zero = xla::Select( + xla::And(xla::Le(XlaHelpers::Zero(builder, indices_type), indices), + xla::Lt(indices, XlaHelpers::IntegerLiteral( + builder, indices_type, depth))), + xla::Broadcast(XlaHelpers::Zero(builder, logits_type), {batch_size}), + xla::Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN), + {batch_size})); + labels = xla::Add(labels, nan_or_zero, {0}); xla::XlaOp loss, backprop; - std::tie(loss, backprop) = - CrossEntropyWithLogits(ctx, logits_type, ctx->Input(0), labels); + std::tie(loss, backprop) = CrossEntropyWithLogits( + ctx, logits_type, xla_logits_type, ctx->Input(0), labels); ctx->SetOutput(0, loss); ctx->SetOutput(1, backprop); } diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..aaeeae01ccb303091a6d37d1aeb4b2a3377dc638 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" + +namespace tensorflow { +namespace { + +class XlaSortOp : public XlaOpKernel { + public: + explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + context->SetOutput(0, xla::Sort(context->Input(0))); + } +}; + +REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index ec077924b5b5af4a573c86c8d9aeb8623bd7f801..7327258c31f21f45ff7ffffbc9db7a2a70b4a14c 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { @@ -73,7 +74,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, "The product of the block dimensions must be positive")); xla::XlaOp padded = - b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config); + xla::Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config); // 2. Reshape `padded` to `reshaped_padded` of shape: // @@ -100,7 +101,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, std::copy(remainder_shape.begin(), remainder_shape.end(), reshaped_padded_shape.begin() + 1 + 2 * block_rank); - xla::XlaOp reshaped_padded = b->Reshape(padded, reshaped_padded_shape); + xla::XlaOp reshaped_padded = xla::Reshape(padded, reshaped_padded_shape); // 3. Permute dimensions of `reshaped_padded` to produce // `permuted_reshaped_padded` of shape: @@ -120,7 +121,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), 1 + block_rank * 2); xla::XlaOp permuted_reshaped_padded = - b->Transpose(reshaped_padded, permutation); + xla::Transpose(reshaped_padded, permutation); // 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -140,7 +141,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, std::copy(remainder_shape.begin(), remainder_shape.end(), output_shape.begin() + 1 + block_rank); - xla::XlaOp output = b->Reshape(permuted_reshaped_padded, output_shape); + xla::XlaOp output = xla::Reshape(permuted_reshaped_padded, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 4c5886ee2a0f63d609f79fc690f457d93e284e3e..4493539fe34f0ce635fdc58660d4ff90af9c9379 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -50,7 +51,6 @@ class SpaceToDepthOp : public XlaOpKernel { const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); - xla::XlaBuilder* b = ctx->builder(); xla::XlaOp input = ctx->Input(0); int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); @@ -135,7 +135,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[1] / block_size_, block_size_, // input_shape[2] / block_size_, block_size_, // depth] - xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -145,7 +145,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_, block_size_, // depth] - xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order); + xla::XlaOp permuted_reshaped = xla::Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -155,7 +155,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_ * block_size_ * depth] // - xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape); + xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e831dc30a9d3c27ec3b1494e7d8a6de836ff2a11 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc @@ -0,0 +1,88 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/scatter.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +// Operator to convert sparse representations to dense. +class SparseToDenseOp : public XlaOpKernel { + public: + explicit SparseToDenseOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + // sparse_indices + const TensorShape indices_shape = context->InputShape(0); + OP_REQUIRES(context, indices_shape.dims() <= 2, + errors::InvalidArgument( + "sparse_indices should be a scalar, vector, or matrix, " + "got shape ", + indices_shape.DebugString())); + const int64 num_elems = + indices_shape.dims() > 0 ? indices_shape.dim_size(0) : 1; + const int64 num_dims = + indices_shape.dims() > 1 ? indices_shape.dim_size(1) : 1; + + // output_shape + TensorShape output_shape; + OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape)); + OP_REQUIRES(context, output_shape.dims() == num_dims, + errors::InvalidArgument( + "output_shape has incorrect number of elements: ", + output_shape.num_elements(), " should be: ", num_dims)); + + // sparse_values + const TensorShape sparse_values_shape = context->InputShape(2); + const int64 num_values = sparse_values_shape.num_elements(); + OP_REQUIRES( + context, + sparse_values_shape.dims() == 0 || + (sparse_values_shape.dims() == 1 && num_values == num_elems), + errors::InvalidArgument("sparse_values has incorrect shape ", + sparse_values_shape.DebugString(), + ", should be [] or [", num_elems, "]")); + + // default_value + const TensorShape default_value_shape = context->InputShape(3); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(default_value_shape), + errors::InvalidArgument("default_value should be a scalar.")); + + xla::XlaOp indices = context->Input(0); + xla::XlaOp sparse_values = context->Input(2); + xla::XlaOp default_value = context->Input(3); + + if (sparse_values_shape.dims() == 0 && num_elems != 1) { + sparse_values = Broadcast(sparse_values, {num_elems}); + } + xla::XlaBuilder* builder = context->builder(); + auto buffer = Broadcast(default_value, output_shape.dim_sizes()); + + auto result = XlaScatter(buffer, sparse_values, indices, + /*indices_are_vectors=*/num_dims > 1, + /*combiner=*/{}, builder); + context->SetOutput(0, builder->ReportErrorOrReturn(result)); + } +}; + +REGISTER_XLA_OP(Name("SparseToDense").CompileTimeConstInput("output_shape"), + SparseToDenseOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 8958b2e7701e62d802e37a895c14b662ecf9786a..93fc14e9efca868e84444dd0e07d7f0dfa84c042 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -19,7 +19,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -98,7 +99,7 @@ class SplitOp : public XlaOpKernel { // Slice out the ith split from the split dimension. begin[split_dim] = i * slice_size; limits[split_dim] = (i + 1) * slice_size; - ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides)); + ctx->SetOutput(i, xla::Slice(input, begin, limits, strides)); } } }; @@ -134,7 +135,7 @@ class SplitVOp : public XlaOpKernel { errors::InvalidArgument( "Number of ways to split should be > 0, but got ", num_split)); - // check that sizes are correct + // Check that sizes are correct. int total_split_size = 0; int neg_one_dim = -1; std::vector split_sizes_vec(num_split, -1); @@ -148,7 +149,7 @@ class SplitVOp : public XlaOpKernel { " number of elements as the output. Got ", split_size_shape.dims(), "-D and ", split_size_shape.num_elements(), " elements")); - // get the dimension of this split + // Get the dimension of this split. xla::Literal split_size_literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal)); @@ -199,7 +200,7 @@ class SplitVOp : public XlaOpKernel { // Slice out the ith split from the split dimension. limits[split_dim] = begin[split_dim] + slice_size; - ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides)); + ctx->SetOutput(i, xla::Slice(input, begin, limits, strides)); begin[split_dim] = limits[split_dim]; } } diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 0fb05a2be7b1034d6c2e864643b69647d622ede7..df91900570107609c0f1c2281faaab8a5e65b98b 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" @@ -144,24 +144,25 @@ class StackPushOp : public XlaOpKernel { // Initializes the Stack, if the element shape was not already known. OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape)); - xla::XlaOp ta = b->GetTupleElement(resource->value(), 0); - xla::XlaOp index = b->GetTupleElement(resource->value(), 1); + xla::XlaOp ta = xla::GetTupleElement(resource->value(), 0); + xla::XlaOp index = xla::GetTupleElement(resource->value(), 1); xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); - auto update = b->Reshape(value, slice_shape.dim_sizes()); + auto update = xla::Reshape(value, slice_shape.dim_sizes()); // TODO(phawkins): We don't check the index is in bounds --- there is no // error mechanism in XLA. - OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple( - {b->DynamicUpdateSlice(ta, update, start_indices), - b->Add(index, b->ConstantR0(1))}))); + OP_REQUIRES_OK(ctx, + resource->SetValue(xla::Tuple( + b, {xla::DynamicUpdateSlice(ta, update, start_indices), + xla::Add(index, xla::ConstantR0(b, 1))}))); ctx->SetOutput(0, value); } @@ -197,27 +198,27 @@ class StackPopOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape)); xla::XlaOp state = resource->value(); - xla::XlaOp ta = b->GetTupleElement(state, 0); - xla::XlaOp index = b->GetTupleElement(state, 1); + xla::XlaOp ta = xla::GetTupleElement(state, 0); + xla::XlaOp index = xla::GetTupleElement(state, 1); - index = b->Sub(index, b->ConstantR0(1)); - OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index}))); + index = Sub(index, xla::ConstantR0(b, 1)); + OP_REQUIRES_OK(ctx, resource->SetValue(xla::Tuple(b, {ta, index}))); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, stack_shape.dims() - 1}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, stack_shape.dims() - 1}})); auto slice_shape = stack_shape.dim_sizes(); slice_shape[0] = 1LL; // TODO(phawkins): We don't check the index is in bounds --- there is no // error mechanism in XLA. - xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape); + xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); - ctx->SetOutput(0, b->Reshape(read, value_shape)); + ctx->SetOutput(0, xla::Reshape(read, value_shape)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index a99d4ddc7c4956f7144512a9bdf6f4c2eb0f944f..5412e135478361d08965e4621ec52cfb4a792f1d 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -15,11 +15,16 @@ limitations under the License. #include +#include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/prng.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -29,187 +34,6 @@ limitations under the License. namespace tensorflow { namespace { -// Rotates a 32-bit integer 'v' left by 'distance' bits. -xla::XlaOp RotateLeftS32(xla::XlaBuilder* builder, const xla::XlaOp& v, - int distance) { - return builder->Or( - builder->ShiftLeft(v, builder->ConstantR0(distance)), - builder->ShiftRightLogical(v, builder->ConstantR0(32 - distance))); -} - -// TODO(b/65209188): add a primitive XOR to XLA and call it here, rather than -// building XOR out of other bitwise operators. -xla::XlaOp BitwiseXor(xla::XlaBuilder* builder, const xla::XlaOp& x, - const xla::XlaOp& y) { - return builder->Or(builder->And(x, builder->Not(y)), - builder->And(builder->Not(x), y)); -} - -using ThreeFry2x32State = std::array; - -// Implements the ThreeFry counter-based PRNG algorithm. -// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. -// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder, - ThreeFry2x32State input, ThreeFry2x32State key) { - // Rotation distances specified by the Threefry2x32 algorithm. - constexpr std::array rotations = {13, 15, 26, 6, 17, 29, 16, 24}; - ThreeFry2x32State x; - - std::array ks; - // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. - ks[2] = builder->ConstantR0(0x1BD11BDA); - for (int i = 0; i < 2; ++i) { - ks[i] = key[i]; - x[i] = input[i]; - ks[2] = BitwiseXor(builder, ks[2], key[i]); - } - - x[0] = builder->Add(x[0], ks[0]); - x[1] = builder->Add(x[1], ks[1]); - - // Performs a single round of the Threefry2x32 algorithm, with a rotation - // amount 'rotation'. - auto round = [builder](ThreeFry2x32State v, int rotation) { - v[0] = builder->Add(v[0], v[1]); - v[1] = RotateLeftS32(builder, v[1], rotation); - v[1] = BitwiseXor(builder, v[0], v[1]); - return v; - }; - - // There are no known statistical flaws with 13 rounds of Threefry2x32. - // We are conservative and use 20 rounds. - x = round(x, rotations[0]); - x = round(x, rotations[1]); - x = round(x, rotations[2]); - x = round(x, rotations[3]); - x[0] = builder->Add(x[0], ks[1]); - x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0(1)); - - x = round(x, rotations[4]); - x = round(x, rotations[5]); - x = round(x, rotations[6]); - x = round(x, rotations[7]); - x[0] = builder->Add(x[0], ks[2]); - x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0(2)); - - x = round(x, rotations[0]); - x = round(x, rotations[1]); - x = round(x, rotations[2]); - x = round(x, rotations[3]); - x[0] = builder->Add(x[0], ks[0]); - x[1] = builder->Add(builder->Add(x[1], ks[1]), builder->ConstantR0(3)); - - x = round(x, rotations[4]); - x = round(x, rotations[5]); - x = round(x, rotations[6]); - x = round(x, rotations[7]); - x[0] = builder->Add(x[0], ks[1]); - x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0(4)); - - x = round(x, rotations[0]); - x = round(x, rotations[1]); - x = round(x, rotations[2]); - x = round(x, rotations[3]); - x[0] = builder->Add(x[0], ks[2]); - x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0(5)); - - return x; -} - -// Returns a tensor of 'shape' random values uniformly distributed in the range -// [minval, maxval) -xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed, - const TensorShape& shape, double minval, - double maxval) { - // Split the seed into two 32-bit scalars to form a key. - auto seed0 = builder->Reshape(builder->Slice(seed, {0}, {1}, {1}), {}); - auto seed1 = builder->Reshape(builder->Slice(seed, {1}, {2}, {1}), {}); - ThreeFry2x32State key = {seed0, seed1}; - const int64 size = shape.num_elements(); - - const int64 half_size = MathUtil::CeilOfRatio(size, 2); - const bool size_is_odd = (half_size * 2 != size); - - // Fill the generator inputs with unique counter values. - ThreeFry2x32State inputs; - TF_CHECK_OK(XlaHelpers::Iota(builder, DT_INT32, half_size, &inputs[0])); - inputs[1] = builder->Add(inputs[0], builder->ConstantR0(half_size)); - ThreeFry2x32State outputs = ThreeFry2x32(builder, inputs, key); - - if (size_is_odd) { - outputs[1] = builder->Slice(outputs[1], {0}, {half_size - 1}, {1}); - } - - auto bits = - builder->Reshape(builder->ConcatInDim(outputs, 0), shape.dim_sizes()); - - // Form 22 random mantissa bits, with a leading 1 bit. The leading 1 bit - // forces the random bits into the mantissa. - constexpr int kFloatBits = 32; - constexpr int kMantissaBits = 23; - bits = builder->Or( - builder->ShiftRightLogical( - bits, builder->ConstantR0(kFloatBits - kMantissaBits)), - builder->ConstantR0(bit_cast(1.0f))); - auto floats = builder->BitcastConvertType(bits, xla::F32); - - // We have a floating point number in the range [1.0, 2.0). - // Subtract 1.0f to shift to the range [0.0, 1.0) - floats = builder->Sub(floats, builder->ConstantR0(1.0f)); - // Multiply and add to shift to the range [minval, maxval). - floats = builder->Mul(floats, builder->ConstantR0(maxval - minval)); - floats = builder->Add(floats, builder->ConstantR0(minval)); - return floats; -} - -// Approximation for the inverse error function from -// Giles, M., "Approximating the erfinv function". -// The approximation has the form: -// w = -log((1 - x) * (1 + x)) -// if ( w < 5 ) { -// w = w - 2.5 -// p = sum_{i=1}^n lq[i]*w^i -// } else { -// w = sqrt(w) - 3 -// p = sum_{i=1}^n gq[i]*w^i -// } -// return p*x -xla::XlaOp ErfInvF32(xla::XlaBuilder* b, const xla::XlaOp& x, - const TensorShape& shape) { - constexpr int kDegree = 9; - constexpr std::array w_less_than_5_constants = { - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - constexpr std::array w_greater_than_5_constants = { - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; - - auto one = b->ConstantR0(1.0); - auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x)))); - - auto lt = b->Lt(w, b->ConstantR0(5.0)); - auto coefficient = [&](int i) { - return b->Select( - lt, - b->Broadcast(b->ConstantR0(w_less_than_5_constants[i]), - shape.dim_sizes()), - b->Broadcast(b->ConstantR0(w_greater_than_5_constants[i]), - shape.dim_sizes())); - }; - w = b->Select(lt, b->Sub(w, b->ConstantR0(2.5f)), - b->Sub(b->SqrtF32(w), b->ConstantR0(3.0f))); - auto p = coefficient(0); - for (int i = 1; i < kDegree; ++i) { - p = b->Add(coefficient(i), b->Mul(p, w)); - } - return b->Mul(p, x); -} - -} // namespace - class StatelessRandomUniformOp : public XlaOpKernel { public: explicit StatelessRandomUniformOp(OpKernelConstruction* ctx) @@ -226,7 +50,17 @@ class StatelessRandomUniformOp : public XlaOpKernel { errors::InvalidArgument("seed must have shape [2], not ", seed_shape.DebugString())); xla::XlaOp seed = ctx->Input(1); - ctx->SetOutput(0, RandomUniform(builder, seed, shape, 0.0, 1.0)); + + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); + + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + + auto uniform = xla::StatelessRngUniform( + {seed0, seed1}, xla_shape, xla::ConstantR0(builder, 0.0), + xla::ConstantR0(builder, 1.0)); + ctx->SetOutput(0, uniform); } private: @@ -255,12 +89,20 @@ class StatelessRandomNormalOp : public XlaOpKernel { seed_shape.DebugString())); xla::XlaOp seed = ctx->Input(1); xla::XlaBuilder* builder = ctx->builder(); - auto uniform = - RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0); + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); + + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + + auto uniform = xla::StatelessRngUniform( + {seed0, seed1}, xla_shape, + xla::ConstantR0(builder, std::nextafter(-1.0f, 0.0f)), + xla::ConstantR0(builder, 1.0)); // Convert uniform distribution to normal distribution by computing // sqrt(2) * erfinv(x) - auto normal = builder->Mul(builder->ConstantR0(std::sqrt(2.0)), - ErfInvF32(builder, uniform, shape)); + auto normal = + xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform); ctx->SetOutput(0, normal); } @@ -275,4 +117,44 @@ REGISTER_XLA_OP(Name("StatelessRandomNormal") .TypeConstraint("Tseed", DT_INT32), StatelessRandomNormalOp); +class StatelessTruncatedNormalOp : public XlaOpKernel { + public: + explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + TensorShape seed_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, seed_shape == TensorShape({2}), + errors::InvalidArgument("seed must have shape [2], not ", + seed_shape.DebugString())); + xla::XlaOp seed = ctx->Input(1); + xla::XlaBuilder* builder = ctx->builder(); + + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); + auto uniform = xla::StatelessRngUniform( + {seed0, seed1}, xla_shape, + xla::ConstantR0(builder, std::numeric_limits::min()), + xla::ConstantR0(builder, 1.0)); + + ctx->SetOutput(0, TruncatedNormal(uniform)); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp); +}; + +REGISTER_XLA_OP(Name("StatelessTruncatedNormal") + .CompileTimeConstInput("shape") + .TypeConstraint("dtype", DT_FLOAT) + .TypeConstraint("Tseed", DT_INT32), + StatelessTruncatedNormalOp); + +} // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 55254c746e5ebaf6b468c24ab59b968bf0d6260b..1062399d91bd9a9bf8c3820c5ecac534c110746d 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -92,12 +93,12 @@ class StridedSliceOp : public XlaOpKernel { xla::XlaOp slice = ctx->Input(0); if (!dimensions_to_reverse.empty()) { - slice = ctx->builder()->Rev(slice, dimensions_to_reverse); + slice = xla::Rev(slice, dimensions_to_reverse); } - slice = ctx->builder()->Slice(slice, slice_begin, slice_end, slice_strides); + slice = xla::Slice(slice, slice_begin, slice_end, slice_strides); - slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes()); + slice = xla::Reshape(slice, final_shape.dim_sizes()); ctx->SetOutput(0, slice); } @@ -171,7 +172,7 @@ class StridedSliceGradOp : public XlaOpKernel { xla::XlaOp grad = ctx->Input(4); // Undo any new/shrink axes. - grad = ctx->builder()->Reshape(grad, processing_shape.dim_sizes()); + grad = xla::Reshape(grad, processing_shape.dim_sizes()); // Pad the input gradients. gtl::InlinedVector dimensions_to_reverse; @@ -204,9 +205,9 @@ class StridedSliceGradOp : public XlaOpKernel { } } if (!dimensions_to_reverse.empty()) { - grad = ctx->builder()->Rev(grad, dimensions_to_reverse); + grad = xla::Rev(grad, dimensions_to_reverse); } - grad = ctx->builder()->Pad(grad, zero, padding_config); + grad = xla::Pad(grad, zero, padding_config); ctx->SetOutput(0, grad); } @@ -306,17 +307,17 @@ class StridedSliceAssignOp : public XlaOpKernel { } if (!dimensions_to_reverse.empty()) { - rhs = ctx->builder()->Rev(rhs, dimensions_to_reverse); + rhs = xla::Rev(rhs, dimensions_to_reverse); } - rhs = ctx->builder()->Reshape(rhs, slice_dims); + rhs = xla::Reshape(rhs, slice_dims); if (lhs_shape.dims() == 0) { // TODO(b/38323843): DynamicUpdateSlice crashes on rank 0 inputs. Fix // and remove this workaround. lhs = rhs; } else { - lhs = ctx->builder()->DynamicUpdateSlice( - lhs, rhs, ctx->builder()->ConstantR1(slice_begin)); + lhs = xla::DynamicUpdateSlice( + lhs, rhs, xla::ConstantR1(ctx->builder(), slice_begin)); } OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs)); diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 9adee78a1fd1fb9a12afae83197425c328b5fe7e..be1814d8e3ae2c0ddad0134b9288e0ea084aa81b 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -25,7 +25,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" @@ -123,10 +124,9 @@ xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, const xla::XlaOp& update, const gtl::ArraySlice& update_dims, const xla::XlaOp& start_indices) { - xla::XlaOp current = - builder->DynamicSlice(operand, start_indices, update_dims); - xla::XlaOp sum = builder->Add(current, update); - return builder->DynamicUpdateSlice(operand, sum, start_indices); + xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims); + xla::XlaOp sum = xla::Add(current, update); + return xla::DynamicUpdateSlice(operand, sum, start_indices); } class TensorArrayOp : public XlaOpKernel { @@ -162,7 +162,7 @@ class TensorArrayOp : public XlaOpKernel { ta_shape.AddDim(size); ta_shape.AppendShape(shape); xla::XlaOp zero = XlaHelpers::Zero(b, dtype_); - value = b->Broadcast(zero, ta_shape.dim_sizes()); + value = xla::Broadcast(zero, ta_shape.dim_sizes()); } XlaContext& xc = XlaContext::Get(ctx); @@ -215,12 +215,12 @@ class TensorArrayWriteOp : public XlaOpKernel { // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); TensorShape slice_shape = elem_shape; slice_shape.InsertDim(0, 1LL); - auto update = b->Reshape(value, slice_shape.dim_sizes()); + auto update = xla::Reshape(value, slice_shape.dim_sizes()); xla::XlaOp written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); @@ -259,17 +259,17 @@ class TensorArrayReadOp : public XlaOpKernel { // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}})); auto slice_shape = ta_shape.dim_sizes(); slice_shape[0] = 1LL; - xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape); + xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); - ctx->SetOutput(0, b->Reshape(read, value_shape)); + ctx->SetOutput(0, xla::Reshape(read, value_shape)); } private: @@ -326,7 +326,7 @@ class TensorArrayGatherOp : public XlaOpKernel { for (auto i = 1; i < ta_shape.dims(); i++) { end[i] = ta_shape.dim_size(i); } - ctx->SetOutput(0, b->Slice(ta, begin, end, strides)); + ctx->SetOutput(0, xla::Slice(ta, begin, end, strides)); return; } } @@ -391,7 +391,7 @@ class TensorArrayScatterOp : public XlaOpKernel { } if (scatter_all_elements_in_order) { - ta = b->Add(ta, value); + ta = xla::Add(ta, value); } else { auto slice_dims = value_shape.dim_sizes(); slice_dims[0] = 1LL; @@ -407,13 +407,13 @@ class TensorArrayScatterOp : public XlaOpKernel { // Slice out part of the value. value_starts[0] = i; value_ends[0] = i + 1; - auto slice = b->Slice(value, value_starts, value_ends, value_strides); + auto slice = xla::Slice(value, value_starts, value_ends, value_strides); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto index = b->Slice(indices, {i}, {i + 1}, {1}); + auto index = xla::Slice(indices, {i}, {i + 1}, {1}); auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); } } @@ -452,7 +452,7 @@ class TensorArrayConcatOp : public XlaOpKernel { auto ta_dims = ta_shape.dim_sizes(); std::vector shape(ta_dims.begin() + 1, ta_dims.end()); shape[0] *= ta_shape.dim_size(0); - ctx->SetOutput(0, b->Reshape(ta, shape)); + ctx->SetOutput(0, xla::Reshape(ta, shape)); Tensor lengths(DT_INT64, {ta_dims[0]}); auto lengths_vec = lengths.vec(); @@ -522,8 +522,8 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. ", ta_shape.DebugString())); - OP_REQUIRES_OK(ctx, resource->SetValue(b->Add( - ta, b->Reshape(value, ta_shape.dim_sizes())))); + OP_REQUIRES_OK(ctx, resource->SetValue(xla::Add( + ta, xla::Reshape(value, ta_shape.dim_sizes())))); ctx->SetOutput(0, flow); } diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index e91075196bd8414939888e22b5483ad637487af6..1233a37565d3a40c6dd2882b3139dedbf690a7b6 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -93,9 +94,9 @@ class TileOp : public XlaOpKernel { if (one_dimension_is_broadcasted_without_multiple) { // Create a constant Zero the size of the output shape to leverage binary // operation broadcast semantics. - auto broadcasted_zero = ctx->builder()->Broadcast( + auto broadcasted_zero = xla::Broadcast( XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), output_shape); - ctx->SetOutput(0, ctx->builder()->Add(broadcasted_zero, input)); + ctx->SetOutput(0, xla::Add(broadcasted_zero, input)); return; } @@ -103,7 +104,7 @@ class TileOp : public XlaOpKernel { // dimension. This prepends the broadcasted dimensions, so an // input of shape [2,3,1] broadcast with multiples [5,4,3] will // end up with shape [5,4,3,2,3,1]. - auto broadcasted = ctx->builder()->Broadcast(input, multiples_array); + auto broadcasted = xla::Broadcast(input, multiples_array); // Now flatten and reshape. The broadcasted dimensions are // paired with the original dimensions so in the above example // we flatten [0,3,1,4,2,5] then reshape to [10,12,3]. @@ -112,8 +113,7 @@ class TileOp : public XlaOpKernel { flattened.push_back(i); flattened.push_back(i + output_shape.size()); } - xla::XlaOp output = - ctx->builder()->Reshape(broadcasted, flattened, output_shape); + xla::XlaOp output = xla::Reshape(broadcasted, flattened, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..183879c7602ccbbd74fca6cb9fa3fc94c066c37d --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/no_op.h" + +namespace tensorflow { +namespace { + +class TopKOp : public XlaOpKernel { + public: + explicit TopKOp(OpKernelConstruction* context) : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_)); + } + + void Compile(XlaOpKernelContext* context) override { + int64 k; + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(1, &k)); + OP_REQUIRES(context, k >= 0, + errors::InvalidArgument("Need k >= 0, got ", k)); + const TensorShape input_shape = context->InputShape(0); + OP_REQUIRES(context, input_shape.dims() >= 1, + errors::InvalidArgument("input must be >= 1-D, got shape ", + input_shape.DebugString())); + int last_dim = input_shape.dims() - 1; + int last_dim_size = input_shape.dim_size(last_dim); + OP_REQUIRES( + context, last_dim_size >= k, + errors::InvalidArgument("input must have at least k columns. Had ", + last_dim_size, ", needed ", k)); + if (last_dim_size < k) { + k = last_dim_size; + } + xla::XlaOp output_tuple = TopK(context->Input(0), k); + context->SetOutput(0, xla::GetTupleElement(output_tuple, 0)); + context->SetOutput(1, xla::GetTupleElement(output_tuple, 1)); + } + + private: + bool sorted_; +}; + +REGISTER_XLA_OP(Name("TopKV2").CompileTimeConstInput("k").TypeConstraint( + "T", {DT_UINT32, DT_INT32, DT_FLOAT, DT_BFLOAT16}), + TopKOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 34caefa050c0d58f5f7bad557286b6ed64b996ad..be5e91138656716daddcc3c7a68dbb78ecb69103 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -16,8 +16,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" @@ -31,7 +33,6 @@ class ResourceApplyGradientDescent : public XlaOpKernel { : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { xla::XlaOp handle; - xla::XlaBuilder* b = ctx->builder(); DataType type = ctx->input_type(1); TensorShape var_shape; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle)); @@ -48,7 +49,7 @@ class ResourceApplyGradientDescent : public XlaOpKernel { var_shape.DebugString(), " vs ", delta_shape.DebugString())); - handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2))); + handle = handle - ctx->Input(1) * ctx->Input(2); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; @@ -56,6 +57,64 @@ REGISTER_XLA_OP( Name("ResourceApplyGradientDescent").TypeConstraint("T", kFloatTypes), ResourceApplyGradientDescent); +xla::XlaOp ProximalGradientDescentUpdate(xla::XlaOp var, xla::XlaOp lr, + xla::XlaOp l1, xla::XlaOp l2, + xla::XlaOp grad) { + xla::XlaOp one = xla::ScalarLike(lr, 1.0); + xla::XlaOp zero = xla::ScalarLike(lr, 0.0); + xla::XlaOp prox_var = var - grad * lr; + xla::XlaOp l1_gt_zero = xla::Sign(prox_var) * + xla::Max(xla::Abs(prox_var) - lr * l1, zero) / + (one + lr * l2); + xla::XlaOp l1_le_zero = prox_var / (one + lr * l2); + return xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero); +} + +class ResourceApplyProximalGradientDescent : public XlaOpKernel { + public: + explicit ResourceApplyProximalGradientDescent(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp var; + TensorShape var_shape; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + + TensorShape alpha_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape), + errors::InvalidArgument("alpha is not a scalar: ", + alpha_shape.DebugString())); + TensorShape l1_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape), + errors::InvalidArgument("l1 is not a scalar: ", + l1_shape.DebugString())); + TensorShape l2_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape), + errors::InvalidArgument("l2 is not a scalar: ", + l2_shape.DebugString())); + TensorShape delta_shape = ctx->InputShape(4); + OP_REQUIRES( + ctx, var_shape.IsSameSize(delta_shape), + errors::InvalidArgument("var and delta do not have the same shape: ", + var_shape.DebugString(), " vs ", + delta_shape.DebugString())); + xla::XlaOp alpha = ctx->Input(1); + xla::XlaOp l1 = ctx->Input(2); + xla::XlaOp l2 = ctx->Input(3); + xla::XlaOp delta = ctx->Input(4); + var = ProximalGradientDescentUpdate(var, alpha, l1, l2, delta); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + } + + private: + DataType dtype_; +}; +REGISTER_XLA_OP(Name("ResourceApplyProximalGradientDescent") + .TypeConstraint("T", kFloatTypes), + ResourceApplyProximalGradientDescent); + class ResourceApplyMomentum : public XlaOpKernel { public: explicit ResourceApplyMomentum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -63,8 +122,6 @@ class ResourceApplyMomentum : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); - DataType type = ctx->input_type(2); TensorShape var_shape, accum_shape; @@ -97,14 +154,13 @@ class ResourceApplyMomentum : public XlaOpKernel { xla::XlaOp grad = ctx->Input(3); xla::XlaOp momentum = ctx->Input(4); - accum = b->Add(b->Mul(accum, momentum), grad); + accum = accum * momentum + grad; if (use_nesterov_) { // See https://github.com/tensorflow/tensorflow/pull/2798 for an // explanation of the reparameterization used here. - var = b->Sub( - var, b->Add(b->Mul(grad, lr), b->Mul(b->Mul(accum, momentum), lr))); + var = var - (grad * lr + accum * momentum * lr); } else { - var = b->Sub(var, b->Mul(accum, lr)); + var = var - accum * lr; } OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); @@ -121,8 +177,6 @@ class ResourceApplyAdagrad : public XlaOpKernel { explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); - DataType type = ctx->input_type(2); TensorShape var_shape, accum_shape; @@ -149,10 +203,8 @@ class ResourceApplyAdagrad : public XlaOpKernel { xla::XlaOp lr = ctx->Input(2); xla::XlaOp grad = ctx->Input(3); - accum = b->Add(accum, b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0))); - var = b->Sub( - var, b->Mul(b->Mul(grad, lr), - b->Pow(accum, XlaHelpers::FloatLiteral(b, type, -0.5)))); + accum = accum + xla::Square(grad); + var = var - grad * lr * xla::Rsqrt(accum); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); } @@ -160,6 +212,139 @@ class ResourceApplyAdagrad : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes), ResourceApplyAdagrad); +class ResourceApplyProximalAdagrad : public XlaOpKernel { + public: + explicit ResourceApplyProximalAdagrad(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape var_shape, accum_shape; + xla::XlaOp var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum)); + + OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + + TensorShape lr_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + TensorShape l1_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape), + errors::InvalidArgument("l1 is not a scalar: ", + l1_shape.DebugString())); + TensorShape l2_shape = ctx->InputShape(4); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape), + errors::InvalidArgument("l2 is not a scalar: ", + l2_shape.DebugString())); + TensorShape grad_shape = ctx->InputShape(5); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape: ", + var_shape.DebugString(), " vs ", grad_shape.DebugString())); + + xla::XlaOp lr = ctx->Input(2); + xla::XlaOp l1 = ctx->Input(3); + xla::XlaOp l2 = ctx->Input(4); + xla::XlaOp grad = ctx->Input(5); + accum = accum + xla::Square(grad); + // Adagrad learning rate. + xla::XlaOp adagrad_lr = lr * xla::Rsqrt(accum); + var = ProximalGradientDescentUpdate(var, adagrad_lr, l1, l2, grad); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum)); + } + + private: + DataType dtype_; +}; +REGISTER_XLA_OP( + Name("ResourceApplyProximalAdagrad").TypeConstraint("T", kFloatTypes), + ResourceApplyProximalAdagrad); + +class ResourceApplyAdagradDA : public XlaOpKernel { + public: + explicit ResourceApplyAdagradDA(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape var_shape, accum_shape, squared_accum_shape; + xla::XlaOp var, accum, squared_accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &squared_accum_shape, + &squared_accum)); + OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + OP_REQUIRES( + ctx, var_shape.IsSameSize(squared_accum_shape), + errors::InvalidArgument( + "var and squared accum do not have the same shape", + var_shape.DebugString(), " ", squared_accum_shape.DebugString())); + + TensorShape grad_shape = ctx->InputShape(3); + TensorShape lr_shape = ctx->InputShape(4); + TensorShape l1_shape = ctx->InputShape(5); + TensorShape l2_shape = ctx->InputShape(6); + TensorShape global_step_shape = ctx->InputShape(7); + + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape), + errors::InvalidArgument("l1 is not a scalar: ", + l1_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape), + errors::InvalidArgument("l2 is not a scalar: ", + l2_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step_shape), + errors::InvalidArgument("global step is not a scalar: ", + global_step_shape.DebugString())); + + xla::XlaOp grad = ctx->Input(3); + xla::XlaOp lr = ctx->Input(4); + xla::XlaOp l1 = ctx->Input(5); + xla::XlaOp l2 = ctx->Input(6); + xla::XlaBuilder* const b = ctx->builder(); + xla::XlaOp global_step = + XlaHelpers::ConvertElementType(b, ctx->Input(7), dtype_); + + accum = accum + grad; + squared_accum = squared_accum + xla::Square(grad); + xla::XlaOp zero = xla::ScalarLike(lr, 0.0); + xla::XlaOp denominator = global_step * lr * l2 + xla::Sqrt(squared_accum); + xla::XlaOp l1_le_zero = -lr * accum / denominator; + xla::XlaOp l1_gt_zero = -lr * xla::Sign(accum) * + xla::Max(xla::Abs(accum) - global_step * l1, zero) / + denominator; + + var = xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, squared_accum)); + } + + private: + DataType dtype_; +}; +REGISTER_XLA_OP(Name("ResourceApplyAdagradDA").TypeConstraint("T", kFloatTypes), + ResourceApplyAdagradDA); + class ResourceApplyAdam : public XlaOpKernel { public: explicit ResourceApplyAdam(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -227,17 +412,12 @@ class ResourceApplyAdam : public XlaOpKernel { // variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon) xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); - xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); - xla::XlaOp alpha = - b->Div(b->Mul(lr, b->Pow(b->Sub(one, beta2_power), half)), - b->Sub(one, beta1_power)); - m = b->Add(m, b->Mul(b->Sub(grad, m), b->Sub(one, beta1))); - v = b->Add(v, b->Mul(b->Sub(b->Pow(grad, two), v), b->Sub(one, beta2))); - var = - b->Sub(var, b->Div(b->Mul(m, alpha), b->Add(b->Pow(v, half), epsilon))); + xla::XlaOp alpha = lr * xla::Sqrt(one - beta2_power) / (one - beta1_power); + m = m + (grad - m) * (one - beta1); + v = v + (xla::Square(grad) - v) * (one - beta2); + var = var - m * alpha / (xla::Sqrt(v) + epsilon); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m)); @@ -250,38 +430,112 @@ class ResourceApplyAdam : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceApplyAdam").TypeConstraint("T", kFloatTypes), ResourceApplyAdam); -class ResourceApplyRMSProp : public XlaOpKernel { +class ResourceApplyAdaMax : public XlaOpKernel { public: - explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ResourceApplyAdaMax(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); + TensorShape var_shape, m_shape, v_shape; + xla::XlaOp var, m, v; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v)); - DataType type = ctx->input_type(3); + TensorShape beta1_power_shape = ctx->InputShape(3); + TensorShape lr_shape = ctx->InputShape(4); + TensorShape beta1_shape = ctx->InputShape(5); + TensorShape beta2_shape = ctx->InputShape(6); + TensorShape epsilon_shape = ctx->InputShape(7); + TensorShape grad_shape = ctx->InputShape(8); - TensorShape var_shape, ms_shape, mom_shape; - xla::XlaOp var, ms, mom; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom)); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_shape), + errors::InvalidArgument("beta1_power is not a scalar: ", + beta1_power_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar : ", + lr_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_shape), + errors::InvalidArgument("beta1 is not a scalar: ", + beta1_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_shape), + errors::InvalidArgument("beta2 is not a scalar: ", + beta2_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape), + errors::InvalidArgument("epsilon is not a scalar: ", + epsilon_shape.DebugString())); + OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape), + errors::InvalidArgument("var and m do not have the same shape", + var_shape.DebugString(), " ", + m_shape.DebugString())); + OP_REQUIRES(ctx, var_shape.IsSameSize(v_shape), + errors::InvalidArgument("var and v do not have the same shape", + var_shape.DebugString(), " ", + v_shape.DebugString())); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); - TensorShape lr_shape = ctx->InputShape(3); + xla::XlaOp beta1_power = ctx->Input(3); + xla::XlaOp lr = ctx->Input(4); + xla::XlaOp beta1 = ctx->Input(5); + xla::XlaOp beta2 = ctx->Input(6); + xla::XlaOp epsilon = ctx->Input(7); + xla::XlaOp grad = ctx->Input(8); + + xla::XlaOp one = xla::ScalarLike(lr, 1.0); + m = beta1 * m + (one - beta1) * grad; + v = xla::Max(beta2 * v, xla::Abs(grad)); + var = var - lr / (one - beta1_power) * (m / (v + epsilon)); + + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, v)); + } + + private: + DataType dtype_; +}; +REGISTER_XLA_OP(Name("ResourceApplyAdaMax").TypeConstraint("T", kFloatTypes), + ResourceApplyAdaMax); + +class ResourceApplyRMSProp : public XlaOpKernel { + public: + explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape var_shape, ms_shape, mom_shape, mg_shape; + xla::XlaOp var, ms, mom, mg; + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput("var", dtype_, &var_shape, &var)); + if (centered_) { + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("mg", dtype_, &mg_shape, &mg)); + } + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("ms", dtype_, &ms_shape, &ms)); + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput("mom", dtype_, &mom_shape, &mom)); + + TensorShape lr_shape = ctx->InputShape("lr"); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), errors::InvalidArgument("lr is not a scalar: ", lr_shape.DebugString())); - TensorShape rho_shape = ctx->InputShape(4); + TensorShape rho_shape = ctx->InputShape("rho"); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape), errors::InvalidArgument("rho is not a scalar: ", rho_shape.DebugString())); - TensorShape momentum_shape = ctx->InputShape(5); + TensorShape momentum_shape = ctx->InputShape("momentum"); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape), errors::InvalidArgument("momentum is not a scalar: ", momentum_shape.DebugString())); - TensorShape epsilon_shape = ctx->InputShape(6); + TensorShape epsilon_shape = ctx->InputShape("epsilon"); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape), errors::InvalidArgument("epsilon is not a scalar: ", epsilon_shape.DebugString())); - TensorShape grad_shape = ctx->InputShape(7); + TensorShape grad_shape = ctx->InputShape("grad"); // var should be the same shape as mom and ms. OP_REQUIRES(ctx, var_shape.IsSameSize(ms_shape), @@ -297,11 +551,11 @@ class ResourceApplyRMSProp : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::XlaOp lr = ctx->Input(3); - xla::XlaOp rho = ctx->Input(4); - xla::XlaOp momentum = ctx->Input(5); - xla::XlaOp epsilon = ctx->Input(6); - xla::XlaOp grad = ctx->Input(7); + xla::XlaOp lr = ctx->Input("lr"); + xla::XlaOp rho = ctx->Input("rho"); + xla::XlaOp momentum = ctx->Input("momentum"); + xla::XlaOp epsilon = ctx->Input("epsilon"); + xla::XlaOp grad = ctx->Input("grad"); // ms <- rho * ms_{t-1} + (1-rho) * grad * grad // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) @@ -320,25 +574,46 @@ class ResourceApplyRMSProp : public XlaOpKernel { // ms <- grad**2 (1 - rho) + ms * rho // // Which is the equation listed above. - xla::XlaOp new_ms = b->Add( - ms, - b->Mul(b->Sub(b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)), ms), - b->Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho))); - xla::XlaOp new_mom = - b->Add(b->Mul(mom, momentum), - b->Mul(b->Mul(grad, lr), - b->Pow(b->Add(new_ms, epsilon), - XlaHelpers::FloatLiteral(b, type, -0.5)))); - xla::XlaOp new_var = b->Sub(var, new_mom); - - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var)); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms)); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, type, new_mom)); + xla::XlaOp one = xla::ScalarLike(ms, 1.0); + xla::XlaOp new_ms = xla::Square(grad) * (one - rho) + ms * rho; + xla::XlaOp denominator; + if (centered_) { + mg = grad * (one - rho) + mg * rho; + denominator = new_ms - xla::Square(mg) + epsilon; + } else { + denominator = new_ms + epsilon; + } + xla::XlaOp new_mom = mom * momentum + grad * lr * xla::Rsqrt(denominator); + xla::XlaOp new_var = var - new_mom; + + OP_REQUIRES_OK(ctx, ctx->AssignVariable("var", dtype_, new_var)); + if (centered_) { + OP_REQUIRES_OK(ctx, ctx->AssignVariable("mg", dtype_, mg)); + } + OP_REQUIRES_OK(ctx, ctx->AssignVariable("ms", dtype_, new_ms)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable("mom", dtype_, new_mom)); } + + protected: + bool centered_ = false; + + private: + DataType dtype_; }; REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes), ResourceApplyRMSProp); +class ResourceApplyCenteredRMSProp : public ResourceApplyRMSProp { + public: + explicit ResourceApplyCenteredRMSProp(OpKernelConstruction* ctx) + : ResourceApplyRMSProp(ctx) { + centered_ = true; + } +}; +REGISTER_XLA_OP( + Name("ResourceApplyCenteredRMSProp").TypeConstraint("T", kFloatTypes), + ResourceApplyCenteredRMSProp); + void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, bool has_l2_shrinkage) { xla::XlaBuilder* b = ctx->builder(); @@ -424,21 +699,18 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0); xla::XlaOp grad_to_use; if (has_l2_shrinkage) { - grad_to_use = b->Add(grad, b->Mul(two, b->Mul(l2_shrinkage, var))); + grad_to_use = grad + two * l2_shrinkage * var; } else { grad_to_use = grad; } - xla::XlaOp new_accum = b->Add(accum, b->Pow(grad_to_use, two)); - xla::XlaOp new_accum_lr_pow = b->Pow(new_accum, b->Neg(lr_power)); - xla::XlaOp accum_lr_pow = b->Pow(accum, b->Neg(lr_power)); - linear = b->Add( - linear, - b->Sub(grad_to_use, - b->Mul(b->Div(b->Sub(new_accum_lr_pow, accum_lr_pow), lr), var))); - xla::XlaOp linear_clipped = b->Clamp(b->Neg(l1), linear, l1); - xla::XlaOp quadratic = b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2)); - var = b->Div(b->Sub(linear_clipped, linear), quadratic); + xla::XlaOp new_accum = accum + xla::Square(grad_to_use); + xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power); + xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power); + linear = linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var; + xla::XlaOp linear_clipped = xla::Clamp(-l1, linear, l1); + xla::XlaOp quadratic = new_accum_lr_pow / lr + two * l2; + var = (linear_clipped - linear) / quadratic; accum = new_accum; OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype, var)); @@ -478,5 +750,176 @@ class ResourceApplyFtrlV2 : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes), ResourceApplyFtrlV2); +class ResourceApplyAdadelta : public XlaOpKernel { + public: + explicit ResourceApplyAdadelta(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape var_shape, accum_shape, accum_update_shape; + xla::XlaOp var, accum, accum_update; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &accum_update_shape, + &accum_update)); + + TensorShape lr_shape = ctx->InputShape(3); + TensorShape rho_shape = ctx->InputShape(4); + TensorShape epsilon_shape = ctx->InputShape(5); + TensorShape grad_shape = ctx->InputShape(6); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape), + errors::InvalidArgument("rho is not a scalar: ", + rho_shape.DebugString())); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape), + errors::InvalidArgument("epsilon is not a scalar: ", + epsilon_shape.DebugString())); + + OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + + xla::XlaOp lr = ctx->Input(3); + xla::XlaOp rho = ctx->Input(4); + xla::XlaOp epsilon = ctx->Input(5); + xla::XlaOp grad = ctx->Input(6); + + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp neg_half = XlaHelpers::FloatLiteral(b, dtype_, -0.5); + xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); + xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); + xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); + + accum = rho * accum + (one - rho) * xla::Pow(grad, two); + xla::XlaOp update = xla::Pow(accum_update + epsilon, half) * + xla::Pow(accum + epsilon, neg_half) * grad; + accum_update = rho * accum_update + (one - rho) * xla::Pow(update, two); + var = var - update * lr; + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, accum_update)); + } + + private: + DataType dtype_; +}; +REGISTER_XLA_OP(Name("ResourceApplyAdadelta").TypeConstraint("T", kFloatTypes), + ResourceApplyAdadelta); + +class ResourceApplySignBase : public XlaOpKernel { + public: + explicit ResourceApplySignBase(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape var_shape, m_shape; + xla::XlaOp var, m; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m)); + OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape), + errors::InvalidArgument("var and m do not have the same shape", + var_shape.DebugString(), " ", + m_shape.DebugString())); + TensorShape grad_shape = ctx->InputShape(6); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + CheckScalarParams(ctx); + + xla::XlaOp lr = ctx->Input(2); + xla::XlaOp alpha = ctx->Input(3); + xla::XlaOp sign_decay = ctx->Input(4); + xla::XlaOp beta = ctx->Input(5); + xla::XlaOp grad = ctx->Input(6); + + m = m * beta + grad * (xla::ScalarLike(beta, 1.0) - beta); + xla::XlaOp decay = xla::Sign(grad) * xla::Sign(m) * sign_decay; + + xla::XlaOp grad_scale = ComputeGradientScale(alpha, decay); + var = var - lr * grad_scale * grad; + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m)); + } + + virtual void CheckScalarParams(XlaOpKernelContext* ctx) { + TensorShape lr_shape = ctx->InputShape(2); + TensorShape sign_decay_shape = ctx->InputShape(4); + TensorShape beta_shape = ctx->InputShape(5); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sign_decay_shape), + errors::InvalidArgument("sign_decay is not a scalar: ", + sign_decay_shape.DebugString())); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta_shape), + errors::InvalidArgument("beta is not a scalar: ", + beta_shape.DebugString())); + } + + virtual xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, + xla::XlaOp decay) = 0; + + private: + DataType dtype_; +}; + +class ResourceApplyAddSign : public ResourceApplySignBase { + public: + explicit ResourceApplyAddSign(OpKernelConstruction* ctx) + : ResourceApplySignBase(ctx) {} + + void CheckScalarParams(XlaOpKernelContext* ctx) override { + ResourceApplySignBase::CheckScalarParams(ctx); + TensorShape alpha_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape), + errors::InvalidArgument("alpha is not a scalar: ", + alpha_shape.DebugString())); + } + + xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override { + return alpha + decay; + } +}; +REGISTER_XLA_OP(Name("ResourceApplyAddSign").TypeConstraint("T", kFloatTypes), + ResourceApplyAddSign); + +class ResourceApplyPowerSign : public ResourceApplySignBase { + public: + explicit ResourceApplyPowerSign(OpKernelConstruction* ctx) + : ResourceApplySignBase(ctx) {} + + void CheckScalarParams(XlaOpKernelContext* ctx) override { + ResourceApplySignBase::CheckScalarParams(ctx); + TensorShape logbase_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(logbase_shape), + errors::InvalidArgument("logbase is not a scalar: ", + logbase_shape.DebugString())); + } + + xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override { + return xla::Exp(alpha * decay); + } +}; +REGISTER_XLA_OP(Name("ResourceApplyPowerSign").TypeConstraint("T", kFloatTypes), + ResourceApplyPowerSign); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index c167642174b328a968d7f7ce1f0ad6e0ab8a7a68..f9148b394212777271f9eba51313ee17b19819af 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -32,7 +33,8 @@ namespace { class TransposeOp : public XlaOpKernel { public: - explicit TransposeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit TransposeOp(OpKernelConstruction* ctx, bool conjugate = false) + : XlaOpKernel(ctx), conjugate_(conjugate) {} void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_shape = ctx->InputShape(0); @@ -78,19 +80,37 @@ class TransposeOp : public XlaOpKernel { errors::InvalidArgument(i, " is missing from 'perm' argument.")); } + xla::XlaOp transposed; // 0-D, 1-D, and identity transposes do nothing. if (dims <= 1 || is_identity) { - ctx->SetOutput(0, ctx->Input(0)); - return; + transposed = ctx->Input(0); + } else { + transposed = xla::Transpose(ctx->Input(0), transposed_order); } - ctx->SetOutput(0, - ctx->builder()->Transpose(ctx->Input(0), transposed_order)); + // Conjugate the transposed result if this is ConjugateTransposeOp. + if (conjugate_) { + ctx->SetOutput(0, xla::Conj(transposed)); + } else { + ctx->SetOutput(0, transposed); + } } + + private: + const bool conjugate_; +}; + +class ConjugateTransposeOp : public TransposeOp { + public: + explicit ConjugateTransposeOp(OpKernelConstruction* ctx) + : TransposeOp(ctx, /*conjugate=*/true) {} }; REGISTER_XLA_OP(Name("Transpose").CompileTimeConstInput("perm"), TransposeOp); +REGISTER_XLA_OP(Name("ConjugateTranspose").CompileTimeConstInput("perm"), + ConjugateTransposeOp); + // InvertPermutation frequently forms part of the gradient of Transpose. // // inv = InvertPermutationOp(T p) takes a permutation of @@ -127,7 +147,7 @@ class InvertPermutationOp : public XlaOpKernel { output[d] = i; } - ctx->SetOutput(0, ctx->builder()->ConstantR1(output)); + ctx->SetOutput(0, xla::ConstantR1(ctx->builder(), output)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 71a9fd051bfc8db09738a4bfe8ddde447895ecf0..0bdfc05726105e2d18362a691cbe2aab00bf77f3 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -16,24 +16,26 @@ limitations under the License. // Native XLA implementations of simple unary Ops #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { namespace { -// A subclass of a TlaUnaryOp must build the lambda computation that -// describes the scalar->scalar function to apply to each element of -// the input. #define XLAJIT_MAKE_UNARY(NAME, COMPUTATION) \ class NAME##Op : public XlaOpKernel { \ public: \ explicit NAME##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} \ void Compile(XlaOpKernelContext* ctx) { \ xla::XlaBuilder* b = ctx->builder(); \ + (void)b; \ xla::XlaOp x = ctx->Input(0); \ xla::XlaOp y = COMPUTATION; \ ctx->SetOutput(0, y); \ @@ -41,122 +43,74 @@ namespace { }; \ REGISTER_XLA_OP(Name(#NAME), NAME##Op); -XLAJIT_MAKE_UNARY(ComplexAbs, b->Abs(x)); +XLAJIT_MAKE_UNARY(ComplexAbs, xla::Abs(x)); -XLAJIT_MAKE_UNARY(Angle, b->Atan2(b->Imag(x), b->Real(x))); +XLAJIT_MAKE_UNARY(Angle, xla::Atan2(xla::Imag(x), xla::Real(x))); -XLAJIT_MAKE_UNARY(Conj, b->Conj(x)); +XLAJIT_MAKE_UNARY(Conj, xla::Conj(x)); // Return x if x>0, otherwise -x. -XLAJIT_MAKE_UNARY(Abs, b->Abs(x)); - -// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) -XLAJIT_MAKE_UNARY( - Acos, - b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), - b->Atan2(b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)), - b->Mul(x, x)), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), - b->Add(XlaHelpers::One(b, input_type(0)), x)))); - -// acosh(x) = log(x + sqrt(x^2 - 1)) -// = log(x + sqrt((x+1)*(x-1))) -XLAJIT_MAKE_UNARY( - Acosh, - b->Log(b->Add(x, - b->Pow(b->Mul(b->Add(x, XlaHelpers::One(b, input_type(0))), - b->Sub(x, XlaHelpers::One(b, input_type(0)))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); - -// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) -XLAJIT_MAKE_UNARY( - Asin, - b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), - b->Atan2(x, b->Add(XlaHelpers::One(b, input_type(0)), - b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)), - b->Mul(x, x)), - XlaHelpers::FloatLiteral(b, input_type(0), - 0.5)))))); - -// asinh(x) = log(x + sqrt(x^2 + 1)) +XLAJIT_MAKE_UNARY(Abs, xla::Abs(x)); +XLAJIT_MAKE_UNARY(Acos, xla::Acos(x)); +XLAJIT_MAKE_UNARY(Acosh, xla::Acosh(x)); +XLAJIT_MAKE_UNARY(Asin, xla::Asin(x)) +XLAJIT_MAKE_UNARY(Asinh, xla::Asinh(x)); +XLAJIT_MAKE_UNARY(Atan, xla::Atan(x)); +XLAJIT_MAKE_UNARY(Atanh, xla::Atanh(x)); +XLAJIT_MAKE_UNARY(Ceil, xla::Ceil(x)); +XLAJIT_MAKE_UNARY(Cos, xla::Cos(x)); +XLAJIT_MAKE_UNARY(Cosh, xla::Cosh(x)); +XLAJIT_MAKE_UNARY(Sin, xla::Sin(x)); +XLAJIT_MAKE_UNARY(Exp, xla::Exp(x)); +XLAJIT_MAKE_UNARY(Expm1, xla::Expm1(x)); +XLAJIT_MAKE_UNARY(Floor, xla::Floor(x)); +XLAJIT_MAKE_UNARY(IsFinite, xla::IsFinite(x)); XLAJIT_MAKE_UNARY( - Asinh, - b->Log(b->Add(x, b->Pow(b->Add(b->Mul(x, x), - XlaHelpers::One(b, input_type(0))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); - -XLAJIT_MAKE_UNARY(Atan, b->Atan2(x, XlaHelpers::One(b, input_type(0)))); - -// atanh(x) = 0.5 * log((1 + x) / (1 - x)) -XLAJIT_MAKE_UNARY( - Atanh, b->Mul(b->Log(b->Div(b->Add(XlaHelpers::One(b, input_type(0)), x), - b->Sub(XlaHelpers::One(b, input_type(0)), x))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -XLAJIT_MAKE_UNARY(Ceil, b->Ceil(x)); -XLAJIT_MAKE_UNARY(Cos, b->Cos(x)); -XLAJIT_MAKE_UNARY(Cosh, - b->Mul(b->Add(b->Exp(x), b->Exp(b->Neg(x))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -XLAJIT_MAKE_UNARY(Sin, b->Sin(x)); -XLAJIT_MAKE_UNARY(Exp, b->Exp(x)); - -XLAJIT_MAKE_UNARY(Expm1, b->Expm1(x)); - -XLAJIT_MAKE_UNARY(Floor, b->Floor(x)); -XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x)); -XLAJIT_MAKE_UNARY(IsInf, b->Eq(b->Abs(x), - XlaHelpers::FloatLiteral( - b, input_type(0), - std::numeric_limits::infinity()))); -XLAJIT_MAKE_UNARY(IsNan, b->Ne(x, x)); + IsInf, + xla::Eq(xla::Abs(x), + xla::ScalarLike(x, std::numeric_limits::infinity()))); +XLAJIT_MAKE_UNARY(IsNan, xla::Ne(x, x)); // Return 1/x -XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x)); -XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x)); -XLAJIT_MAKE_UNARY(Log, b->Log(x)); - -XLAJIT_MAKE_UNARY(Log1p, b->Log1p(x)); +XLAJIT_MAKE_UNARY(Inv, xla::ScalarLike(x, 1.0) / x); +XLAJIT_MAKE_UNARY(Reciprocal, xla::ScalarLike(x, 1.0) / x); +XLAJIT_MAKE_UNARY(Log, xla::Log(x)); +XLAJIT_MAKE_UNARY(Log1p, xla::Log1p(x)); -XLAJIT_MAKE_UNARY(Invert, b->Not(x)); -XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x)); -XLAJIT_MAKE_UNARY(Neg, b->Neg(x)); +XLAJIT_MAKE_UNARY(Invert, xla::Not(x)); +XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x)); +XLAJIT_MAKE_UNARY(Neg, -x); // Implements Banker's rounding: numbers that are equidistant between two // integers are rounded towards even. -static xla::XlaOp Round(xla::XlaBuilder* b, DataType dtype, - const xla::XlaOp& x) { - auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); - auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0); - auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0); - - auto round_val = b->Floor(x); - auto fraction = b->Sub(x, round_val); - auto nearest_even_int = - b->Sub(round_val, b->Mul(two, b->Floor(b->Mul(half, x)))); - auto is_odd = b->Eq(nearest_even_int, one); - return b->Select( - b->Or(b->Gt(fraction, half), b->And(b->Eq(fraction, half), is_odd)), - b->Add(round_val, one), round_val); +xla::XlaOp RoundToEven(xla::XlaOp x) { + auto half = xla::ScalarLike(x, 0.5); + auto one = xla::ScalarLike(x, 1.0); + auto two = xla::ScalarLike(x, 2.0); + + auto round_val = xla::Floor(x); + auto fraction = x - round_val; + auto nearest_even_int = round_val - two * xla::Floor(half * x); + auto is_odd = xla::Eq(nearest_even_int, one); + return xla::Select(xla::Or(xla::Gt(fraction, half), + xla::And(xla::Eq(fraction, half), is_odd)), + round_val + one, round_val); } -XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x)); -XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x)); +XLAJIT_MAKE_UNARY(Rint, RoundToEven(x)); +XLAJIT_MAKE_UNARY(Round, RoundToEven(x)); -XLAJIT_MAKE_UNARY(Rsqrt, - b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5))); +XLAJIT_MAKE_UNARY(Rsqrt, xla::Rsqrt(x)); // Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2. -static xla::XlaOp Sigmoid(xla::XlaBuilder* b, DataType dtype, - const xla::XlaOp& x) { - auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); - return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x)))); +xla::XlaOp Sigmoid(xla::XlaOp x) { + auto half = xla::ScalarLike(x, 0.5); + return half + half * xla::Tanh(half * x); } -XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x)); +XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(x)); // Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. -XLAJIT_MAKE_UNARY(Sign, b->Sign(x)); -XLAJIT_MAKE_UNARY(Sinh, - b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); +XLAJIT_MAKE_UNARY(Sign, xla::Sign(x)); +XLAJIT_MAKE_UNARY(Sinh, xla::Sinh(x)); // softplus(x) = log(1 + exp(x)) // @@ -166,24 +120,94 @@ XLAJIT_MAKE_UNARY(Sinh, // // This is equivalent to: // max(x, 0) + log1p(exp(-abs(x))) -XLAJIT_MAKE_UNARY(Softplus, - b->Add(b->Max(x, XlaHelpers::Zero(b, input_type(0))), - b->Log1p(b->Exp(b->Neg(b->Abs(x)))))); +XLAJIT_MAKE_UNARY(Softplus, xla::Max(x, xla::ScalarLike(x, 0.0)) + + xla::Log1p(xla::Exp(-xla::Abs(x)))); // softsign(x) = x / (abs(x) + 1) -XLAJIT_MAKE_UNARY(Softsign, - b->Div(x, - b->Add(b->Abs(x), XlaHelpers::One(b, input_type(0))))); -XLAJIT_MAKE_UNARY(Sqrt, - b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -XLAJIT_MAKE_UNARY(Square, b->Mul(x, x)); -XLAJIT_MAKE_UNARY(Tan, b->Div(b->Sin(x), b->Cos(x))); -XLAJIT_MAKE_UNARY(Tanh, b->Tanh(x)); - -XLAJIT_MAKE_UNARY(Real, b->Real(x)); -XLAJIT_MAKE_UNARY(Imag, b->Imag(x)); +XLAJIT_MAKE_UNARY(Softsign, x / (xla::Abs(x) + xla::ScalarLike(x, 1.0))); +XLAJIT_MAKE_UNARY(Sqrt, xla::Sqrt(x)); +XLAJIT_MAKE_UNARY(Square, x* x); +XLAJIT_MAKE_UNARY(Tan, xla::Tan(x)); +XLAJIT_MAKE_UNARY(Tanh, xla::Tanh(x)); + +XLAJIT_MAKE_UNARY(Real, xla::Real(x)); +XLAJIT_MAKE_UNARY(Imag, xla::Imag(x)); #undef XLAJIT_MAKE_UNARY +// Erf/Erfc. For x in (-1, 1), the erf approximation is used; erfc polynomial +// is used outside of this range. +class ErfOp : public XlaOpKernel { + public: + explicit ErfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp x = ctx->Input(0); + xla::XlaOp one = xla::ScalarLike(x, 1.0); + auto y = + xla::Select(xla::Gt(xla::Abs(x), one), one - xla::Erfc(x), xla::Erf(x)); + ctx->SetOutput(0, y); + } +}; +REGISTER_XLA_OP(Name("Erf"), ErfOp); + +class ErfcOp : public XlaOpKernel { + public: + explicit ErfcOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp x = ctx->Input(0); + xla::XlaOp one = xla::ScalarLike(x, 1.0); + auto y = + xla::Select(xla::Lt(xla::Abs(x), one), one - xla::Erf(x), xla::Erfc(x)); + ctx->SetOutput(0, y); + } +}; +REGISTER_XLA_OP(Name("Erfc"), ErfcOp); + +class LgammaOp : public XlaOpKernel { + public: + explicit LgammaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Calculate lgamma using the Lanczos approximation + // (https://en.wikipedia.org/wiki/Lanczos_approximation). + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp input = ctx->Input(0); + xla::PrimitiveType input_type = ctx->input_xla_type(0); + + if (input_type == xla::F16 || input_type == xla::BF16) { + // The approximation works better with at least 32-bits of accuracy. + xla::XlaOp input_f32 = xla::ConvertElementType(input, xla::F32); + xla::XlaOp result_f32 = xla::Lgamma(input_f32); + xla::XlaOp result_x16 = xla::ConvertElementType(result_f32, input_type); + ctx->SetOutput(0, result_x16); + } else { + xla::XlaOp result = xla::Lgamma(input); + ctx->SetOutput(0, result); + } + } +}; // namespace +REGISTER_XLA_OP(Name("Lgamma"), LgammaOp); + +class DigammaOp : public XlaOpKernel { + public: + explicit DigammaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Calculate lgamma using the Lanczos approximation + // (https://en.wikipedia.org/wiki/Lanczos_approximation). + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp input = ctx->Input(0); + xla::PrimitiveType input_type = ctx->input_xla_type(0); + + if (input_type == xla::F16 || input_type == xla::BF16) { + // The approximation works better with at least 32-bits of accuracy. + xla::XlaOp input_f32 = xla::ConvertElementType(input, xla::F32); + xla::XlaOp result_f32 = xla::Digamma(input_f32); + xla::XlaOp result_x16 = xla::ConvertElementType(result_f32, input_type); + ctx->SetOutput(0, result_x16); + } else { + xla::XlaOp result = xla::Digamma(input); + ctx->SetOutput(0, result); + } + } +}; // namespace +REGISTER_XLA_OP(Name("Digamma"), DigammaOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index f87586ba578a6138e7fb921032e1a71f8c9ac80c..8671632976023fded04c26a9780c1a67638b0916 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -22,7 +22,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -74,10 +75,9 @@ class UnpackOp : public XlaOpKernel { for (int i = 0; i < num; ++i) { start_indices[axis] = i; limit_indices[axis] = i + 1; - auto slice = ctx->builder()->Slice(input, start_indices, limit_indices, - strides); + auto slice = xla::Slice(input, start_indices, limit_indices, strides); // Reshape to drop the 'axis' dimension. - auto result = ctx->builder()->Reshape(slice, output_shape.dim_sizes()); + auto result = xla::Reshape(slice, output_shape.dim_sizes()); ctx->SetOutput(i, result); } } diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index a163fa0a5b34675e46d0d7c5f4e0ccb1e3fb18eb..2c92a585f5679242d672d0402e617ff199b94f17 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -13,18 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" +#include "tensorflow/compiler/tf2xla/lib/scatter.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/kernels/no_op.h" namespace tensorflow { namespace { @@ -35,12 +33,33 @@ class VarIsInitializedOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { XlaResource* variable; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable)); - ctx->SetOutput(0, - ctx->builder()->ConstantR0(variable->initialized())); + ctx->SetOutput( + 0, xla::ConstantR0(ctx->builder(), variable->initialized())); } }; REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp); +class VariableShapeOp : public XlaOpKernel { + public: + explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType variable_dtype; + TensorShape shape; + OP_REQUIRES_OK(ctx, + ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape)); + Tensor shape_constant(out_dtype_, TensorShape({shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant)); + ctx->SetConstantOutput(0, shape_constant); + } + + private: + DataType out_dtype_; +}; +REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp); + class ReadVariableOp : public XlaOpKernel { public: explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -77,7 +96,7 @@ class AssignAddVariableOp : public XlaOpKernel { xla::XlaOp handle; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); - handle = ctx->builder()->Add(handle, ctx->Input(1)); + handle = xla::Add(handle, ctx->Input(1)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; @@ -93,7 +112,7 @@ class AssignSubVariableOp : public XlaOpKernel { xla::XlaOp handle; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); - handle = ctx->builder()->Sub(handle, ctx->Input(1)); + handle = xla::Sub(handle, ctx->Input(1)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; @@ -125,29 +144,152 @@ class ResourceGatherOp : public XlaOpKernel { ctx->SetOutput(0, gather); } }; -REGISTER_XLA_OP(Name("ResourceGather").TypeConstraint("dtype", kNumericTypes), - ResourceGatherOp); +REGISTER_XLA_OP(Name("ResourceGather"), ResourceGatherOp); -class VariableShapeOp : public XlaOpKernel { +class ResourceScatterOp : public XlaOpKernel { public: - explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + explicit ResourceScatterOp( + OpKernelConstruction* context, bool indices_are_vectors, + std::function + combiner) + : XlaOpKernel(context), + indices_are_vectors_(indices_are_vectors), + combiner_(std::move(combiner)) {} + + void Compile(XlaOpKernelContext* context) override { + xla::XlaBuilder* builder = context->builder(); + + DataType dtype = context->input_type(2); + TensorShape var_shape; + xla::XlaOp var_value; + OP_REQUIRES_OK( + context, context->ReadVariableInput(0, dtype, &var_shape, &var_value)); + + const xla::XlaOp indices = context->Input(1); + const xla::XlaOp updates = context->Input(2); + + auto result = XlaScatter(var_value, updates, indices, indices_are_vectors_, + combiner_, builder); + OP_REQUIRES_OK(context, result.status()); + OP_REQUIRES_OK(context, + context->AssignVariable(0, dtype, result.ValueOrDie())); } - void Compile(XlaOpKernelContext* ctx) override { - DataType variable_dtype; - TensorShape shape; - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape)); - Tensor shape_constant(out_dtype_, TensorShape({shape.dims()})); - OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant)); - ctx->SetConstantOutput(0, shape_constant); + private: + const bool indices_are_vectors_; + const std::function + combiner_; +}; + +class ResourceScatterAddOp : public ResourceScatterOp { + public: + explicit ResourceScatterAddOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Add(x, y); } +}; +REGISTER_XLA_OP(Name("ResourceScatterAdd"), ResourceScatterAddOp); + +class ResourceScatterSubOp : public ResourceScatterOp { + public: + explicit ResourceScatterSubOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} private: - DataType out_dtype_; + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Sub(x, y); + } }; +REGISTER_XLA_OP(Name("ResourceScatterSub"), ResourceScatterSubOp); + +class ResourceScatterMulOp : public ResourceScatterOp { + public: + explicit ResourceScatterMulOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Mul(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterMul"), ResourceScatterMulOp); + +class ResourceScatterDivOp : public ResourceScatterOp { + public: + explicit ResourceScatterDivOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Div(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterDiv"), ResourceScatterDivOp); + +class ResourceScatterMinOp : public ResourceScatterOp { + public: + explicit ResourceScatterMinOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Min(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterMin"), ResourceScatterMinOp); + +class ResourceScatterMaxOp : public ResourceScatterOp { + public: + explicit ResourceScatterMaxOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Max(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterMax"), ResourceScatterMaxOp); + +class ResourceScatterUpdateOp : public ResourceScatterOp { + public: + explicit ResourceScatterUpdateOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/false, + /*combiner=*/{}) {} +}; +REGISTER_XLA_OP(Name("ResourceScatterUpdate"), ResourceScatterUpdateOp); + +class ResourceScatterNdUpdateOp : public ResourceScatterOp { + public: + explicit ResourceScatterNdUpdateOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/true, + /*combiner=*/{}) {} +}; +REGISTER_XLA_OP(Name("ResourceScatterNdUpdate"), ResourceScatterNdUpdateOp); + +class ResourceScatterNdAddOp : public ResourceScatterOp { + public: + explicit ResourceScatterNdAddOp(OpKernelConstruction* context) + : ResourceScatterOp(context, /*indices_are_vectors=*/true, + /*combiner=*/Combine) {} + + private: + static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, + xla::XlaBuilder* builder) { + return xla::Add(x, y); + } +}; +REGISTER_XLA_OP(Name("ResourceScatterNdAdd"), ResourceScatterNdAddOp); -REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 5467c5d9946846ff9f14ce9c5aac9e2be4b9d6ab..296518229ebf0ba46717afc4f26d5ae1551c2862 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -21,8 +21,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" @@ -246,7 +247,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { } } - xla::XlaOp init = builder->Tuple(inputs); + xla::XlaOp init = xla::Tuple(builder, inputs); VLOG(1) << "Building while loop"; @@ -255,22 +256,21 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { { std::unique_ptr cb = builder->CreateSubBuilder("cond_wrapper"); - auto inputs = cb->Parameter(0, cond_input_shape, "inputs"); - auto outputs = cb->Call(*cond.computation, {inputs}); - cb->GetTupleElement(outputs, 0); + auto inputs = xla::Parameter(cb.get(), 0, cond_input_shape, "inputs"); + auto outputs = xla::Call(cb.get(), *cond.computation, {inputs}); + xla::GetTupleElement(outputs, 0); xla::StatusOr result = cb->Build(); OP_REQUIRES_OK(ctx, result.status()); cond_wrapper = std::move(result.ValueOrDie()); } - xla::XlaOp while_result = - builder->While(cond_wrapper, *body.computation, init); + xla::XlaOp while_result = xla::While(cond_wrapper, *body.computation, init); // Sets non-variable outputs. for (int i = 0; i < ctx->num_outputs(); ++i) { if (ctx->input_type(i) != DT_RESOURCE) { ctx->SetOutput(body.input_mapping[i], - builder->GetTupleElement(while_result, i)); + xla::GetTupleElement(while_result, i)); } } @@ -284,7 +284,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, resource->SetFromPack( arguments[update.input_index].tensor_array_gradients, - builder->GetTupleElement(while_result, pos), builder)); + xla::GetTupleElement(while_result, pos), builder)); } VLOG(2) << "Loop-carried variable: pos: " << update.input_index << " name: " << resource->name() << " modified: " << update.modified @@ -300,6 +300,8 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Done building while loop"; } +REGISTER_XLA_OP(Name("While").AllowResourceTypes(), XlaWhileOp); +REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes(), XlaWhileOp); REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index ee7f5d510ab7a3ce7d3bbe843c5fefd362f79b7b..cb7a40e23d539f758d963791f1c2b4d37374ade5 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -25,8 +25,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", ], ) @@ -40,12 +40,50 @@ cc_library( ":triangular_solve", ":util", ":while_loop", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "random", + srcs = ["random.cc"], + hdrs = ["random.h"], + deps = [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_library( + name = "qr", + srcs = ["qr.cc"], + hdrs = ["qr.h"], + deps = [ + ":batch_dot", + ":util", + ":while_loop", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/core:lib", ], ) @@ -57,14 +95,14 @@ cc_library( deps = [ ":util", ":while_loop", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -76,13 +114,15 @@ cc_library( deps = [ ":batch_dot", ":util", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/core:lib", ], ) @@ -94,7 +134,7 @@ xla_test( deps = [ ":triangular_solve", "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -102,7 +142,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -116,13 +156,14 @@ cc_library( srcs = ["util.cc"], hdrs = ["util.h"], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", ], ) @@ -134,7 +175,7 @@ xla_test( ":batch_dot", ":util", "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -159,8 +200,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index 526694d5a0c7124e1696f34b516f3b202462bc19..f666d22ea44216beef74608bb4d9f33fb2fe82c6 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -25,91 +26,94 @@ limitations under the License. namespace tensorflow { -xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, - xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x, - bool conjugate_y) { - TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); - TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); - - // Check that both tensors have the same number of dimensions. There must be - // at least two (the batch dimensions can be empty). - if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { - return errors::InvalidArgument( - "Arguments to BatchedDot have different ranks: ", - xla::ShapeUtil::HumanString(x_shape), " vs. ", - xla::ShapeUtil::HumanString(y_shape)); - } - const int ndims = xla::ShapeUtil::Rank(x_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to BatchedDot must have rank >= 2: ", ndims); - } - - // The batch dimensions must be equal and the matrix dimensions must be - // valid. - std::vector batch_dimension_numbers; - for (int i = 0; i < ndims - 2; ++i) { - if (x_shape.dimensions(i) != y_shape.dimensions(i)) { +xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, + bool transpose_y, bool conjugate_x, bool conjugate_y) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); + + // Check that both tensors have the same number of dimensions. There must be + // at least two (the batch dimensions can be empty). + if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { return errors::InvalidArgument( - "Dimension ", i, " of inputs to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " vs ", + "Arguments to BatchedDot have different ranks: ", + xla::ShapeUtil::HumanString(x_shape), " vs. ", xla::ShapeUtil::HumanString(y_shape)); } - batch_dimension_numbers.push_back(i); - } - - int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); - int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); - if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { - return errors::InvalidArgument( - "Dimensions ", x_inner_dim, " and ", y_inner_dim, - " of arguments to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, - " vs. ", xla::ShapeUtil::HumanString(y_shape), - " transpose: ", transpose_y); - } - - // Check for zero lhs/rhs dim size. - if (xla::ShapeUtil::HasZeroElements(x_shape) || - xla::ShapeUtil::HasZeroElements(y_shape)) { - std::vector dimensions(batch_dimension_numbers.size()); - for (int i = 0; i < batch_dimension_numbers.size(); ++i) { - dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); + const int ndims = xla::ShapeUtil::Rank(x_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to BatchedDot must have rank >= 2: ", ndims); + } + + // The batch dimensions must be equal and the matrix dimensions must be + // valid. + std::vector batch_dimension_numbers; + for (int i = 0; i < ndims - 2; ++i) { + if (x_shape.dimensions(i) != y_shape.dimensions(i)) { + return errors::InvalidArgument( + "Dimension ", i, " of inputs to BatchedDot must be equal: ", + xla::ShapeUtil::HumanString(x_shape), " vs ", + xla::ShapeUtil::HumanString(y_shape)); + } + batch_dimension_numbers.push_back(i); + } + + int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); + int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); + if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { + return errors::InvalidArgument( + "Dimensions ", x_inner_dim, " and ", y_inner_dim, + " of arguments to BatchedDot must be equal: ", + xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, + " vs. ", xla::ShapeUtil::HumanString(y_shape), + " transpose: ", transpose_y); + } + + // Check for zero lhs/rhs dim size. + if (xla::ShapeUtil::IsZeroElementArray(x_shape) || + xla::ShapeUtil::IsZeroElementArray(y_shape)) { + std::vector dimensions(batch_dimension_numbers.size()); + for (int i = 0; i < batch_dimension_numbers.size(); ++i) { + dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); + } + int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); + int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); + dimensions.push_back(x_shape.dimensions(x_outer_dim)); + dimensions.push_back(y_shape.dimensions(y_outer_dim)); + return xla::Broadcast( + xla::ConstantLiteral(builder, + xla::LiteralUtil::Zero(x_shape.element_type())), + dimensions); + } + + if (x_shape.element_type() == xla::C64 && conjugate_x) { + x = xla::Conj(x); + } + if (y_shape.element_type() == xla::C64 && conjugate_y) { + y = xla::Conj(y); + } + + // If there are no batch dimensions, use a regular Dot. + // TODO(b/69062148) Remove this code when Dot emitters can be passed + // dimensions to transpose directly (i.e. without requiring a Transpose + // HLO). + if (batch_dimension_numbers.empty()) { + auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x; + auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y; + return xla::Dot(lhs, rhs); + } + + xla::DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); + dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); + for (auto batch_dimension_number : batch_dimension_numbers) { + dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); + dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); } - int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); - int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape.dimensions(x_outer_dim)); - dimensions.push_back(y_shape.dimensions(y_outer_dim)); - return builder->Broadcast( - builder->ConstantLiteral(xla::Literal::Zero(x_shape.element_type())), - dimensions); - } - - if (x_shape.element_type() == xla::C64 && conjugate_x) { - x = builder->Conj(x); - } - if (y_shape.element_type() == xla::C64 && conjugate_y) { - y = builder->Conj(y); - } - - // If there are no batch dimensions, use a regular Dot. - // TODO(b/69062148) Remove this code when Dot emitters can be passed - // dimensions to transpose directly (i.e. without requiring a Transpose HLO). - if (batch_dimension_numbers.empty()) { - auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x; - auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y; - return builder->Dot(lhs, rhs); - } - - xla::DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); - dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); - for (auto batch_dimension_number : batch_dimension_numbers) { - dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); - dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); - } - return builder->DotGeneral(x, y, dot_dnums); + return xla::DotGeneral(x, y, dot_dnums); + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index 1acc72033b05e73b0f5f88907df20cde5cfffbf0..8757b16a1ca6a8cec5e3c801c885e7bbbb2f2c76 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" namespace tensorflow { @@ -43,10 +43,9 @@ namespace tensorflow { // It is computed as: // // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, - xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x = false, - bool conjugate_y = false); +xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, + bool transpose_y = false, bool conjugate_x = false, + bool conjugate_y = false); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 20925118bf598a6436c43bd727ce40e3abafc46c..87d73eb3f07ebd7dfa4fef50ebe76cad0c4ed117 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -22,7 +22,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -47,178 +49,163 @@ namespace { // l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / // l[..., j, j] // return l -xla::StatusOr CholeskyUnblocked(xla::XlaBuilder* builder, - const xla::XlaOp& a) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int n_dims = xla::ShapeUtil::Rank(a_shape); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - gtl::ArraySlice major_dims(xla::AsInt64Slice(a_shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - 2); - - xla::XlaOp l = Zeros(builder, a_shape); - - // Construct the for loop body to iterate over rows. - auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, - xla::XlaBuilder* body_builder) - -> xla::StatusOr> { - xla::Shape col_shape; - xla::Shape row_shape; - for (int64 d : major_dims) { - row_shape.add_dimensions(d); - col_shape.add_dimensions(d); - } - row_shape.add_dimensions(1); - row_shape.add_dimensions(n); - row_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_row = Zeros(body_builder, row_shape); - - col_shape.add_dimensions(n); - col_shape.add_dimensions(1); - col_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_col = Zeros(body_builder, col_shape); - - std::vector mask_vector(n); - std::iota(mask_vector.begin(), mask_vector.end(), 0); - auto mask_range = body_builder->ConstantR1(mask_vector); - auto mask_range_row = body_builder->Broadcast( - body_builder->Reshape(mask_range, {0}, {1, n}), major_dims); - auto mask_range_col = body_builder->Broadcast( - body_builder->Reshape(mask_range, {0}, {n, 1}), major_dims); - auto body_a = loop_vars[0]; - auto body_l = loop_vars[1]; - - // row = l[..., i, :i] - // select the whole i-th row, then mask out all columns past i-1 - auto zero = body_builder->ConstantR0(0); - TF_ASSIGN_OR_RETURN(auto l_i, DynamicSliceInMinorDims(body_builder, body_l, - {i, zero}, {1, n})); - auto row = body_builder->Select(body_builder->Ge(mask_range_row, i), - mask_zeros_row, l_i); - // a[..., i, i] - TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a, - {i, i}, {1, 1})); - // np.dot(row, np.swapaxes(row, -1, -2)) - xla::XlaOp diag_dot; - TF_ASSIGN_OR_RETURN(diag_dot, BatchDot(body_builder, row, row, - /*transpose_x=*/false, - /*transpose_y=*/true)); - // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, - // np.swapaxes(row, -1, -2))) - auto l_ii = body_builder->Pow( - body_builder->Sub(a_ii, diag_dot), - FloatLiteral(body_builder, a_shape.element_type(), 0.5)); - - // a[..., i+1:, i] - // select the whole i-th column, then mask out all rows above i+1 +xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int n_dims = xla::ShapeUtil::Rank(a_shape); + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + gtl::ArraySlice major_dims(xla::AsInt64Slice(a_shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - 2); + + xla::XlaOp l = xla::ZerosLike(a); + + // Construct the for loop body to iterate over rows. + auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + xla::XlaBuilder* body_builder) + -> xla::StatusOr> { + xla::Shape col_shape; + xla::Shape row_shape; + for (int64 d : major_dims) { + row_shape.add_dimensions(d); + col_shape.add_dimensions(d); + } + row_shape.add_dimensions(1); + row_shape.add_dimensions(n); + row_shape.set_element_type(a_shape.element_type()); + auto mask_zeros_row = xla::Zeros(body_builder, row_shape); + + col_shape.add_dimensions(n); + col_shape.add_dimensions(1); + col_shape.set_element_type(a_shape.element_type()); + auto mask_zeros_col = xla::Zeros(body_builder, col_shape); + + std::vector mask_vector(n); + std::iota(mask_vector.begin(), mask_vector.end(), 0); + auto mask_range = xla::ConstantR1(body_builder, mask_vector); + auto mask_range_row = + xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims); + auto mask_range_col = + xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims); + auto body_a = loop_vars[0]; + auto body_l = loop_vars[1]; + + // row = l[..., i, :i] + // select the whole i-th row, then mask out all columns past i-1 + auto zero = xla::ConstantR0(body_builder, 0); + auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n}); + auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i); + // a[..., i, i] + auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); + // np.dot(row, np.swapaxes(row, -1, -2)) + auto diag_dot = BatchDot(row, row, + /*transpose_x=*/false, + /*transpose_y=*/true); + // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, + // np.swapaxes(row, -1, -2))) + auto l_ii = + xla::Pow(a_ii - diag_dot, + FloatLiteral(body_builder, a_shape.element_type(), 0.5)); + + // a[..., i+1:, i] + // select the whole i-th column, then mask out all rows above i+1 + auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1}); + auto a_ip1i = + xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i); + + // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / + // l[..., i, i] + // The columns in [i, n] are zeroed out in `row`, so we just have to + // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i], + // r.T) + auto dot = BatchDot(body_l, row, + /*transpose_x=*/false, + /*transpose_y=*/true); + // np.dot(l[..., i+1:, :i], r.T) + auto dot_ip1 = + xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); + + body_l = + DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i}); + // Assign the diagonal after the rest of the column because otherwise the + // column assign will wrap around and overwrite the diagonal assign. + body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i}); + + return std::vector{body_a, body_l}; + }; + TF_ASSIGN_OR_RETURN( - auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1})); - auto a_ip1i = body_builder->Select(body_builder->Le(mask_range_col, i), - mask_zeros_col, a_0i); - - // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / - // l[..., i, i] - // The columns in [i, n] are zeroed out in `row`, so we just have to - // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i], - // r.T) - TF_ASSIGN_OR_RETURN(auto dot, BatchDot(body_builder, body_l, row, - /*transpose_x=*/false, - /*transpose_y=*/true)); - // np.dot(l[..., i+1:, :i], r.T) - auto dot_ip1 = body_builder->Select(body_builder->Le(mask_range_col, i), - mask_zeros_col, dot); - - auto col_update = - body_builder->Div(body_builder->Sub(a_ip1i, dot_ip1), l_ii); - TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims( - body_builder, body_l, col_update, {i})); - // Assign the diagonal after the rest of the column because otherwise the - // column assign will wrap around and overwrite the diagonal assign. - TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims( - body_builder, body_l, l_ii, {i, i})); - - return std::vector{body_a, body_l}; - }; - - TF_ASSIGN_OR_RETURN( - auto cholesky_while, - XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); - - return cholesky_while[1]; + auto cholesky_while, + XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); + + return cholesky_while[1]; + }); } } // namespace -xla::StatusOr Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, - int64 block_size) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int ndims = xla::ShapeUtil::Rank(a_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to Cholesky must have rank >= 2: ", ndims); - } - - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "Arguments to Cholesky must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); - } - - if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to Cholesky must be >= 1; got ", block_size); - } - - // Blocked left-looking Cholesky factorization. - // Algorithm 1 from - // Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only - // execution." Proceedings of General Purpose GPUs. ACM, 2017. - xla::XlaOp l = Zeros(builder, a_shape); - for (int64 i = 0; i < n; i += block_size) { - int64 k = std::min(block_size, n - i); - if (i > 0) { - // TODO(phawkins): consider implementing SYRK for the diagonal part of - // the panel. - // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) - TF_ASSIGN_OR_RETURN(auto lhs, - SliceInMinorDims(builder, l, {i, 0}, {n, i})); - TF_ASSIGN_OR_RETURN(auto rhs, - SliceInMinorDims(builder, l, {i, 0}, {i + k, i})); - TF_ASSIGN_OR_RETURN(auto delta, - BatchDot(builder, lhs, rhs, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN(auto before, - SliceInMinorDims(builder, a, {i, i}, {n, i + k})); - TF_ASSIGN_OR_RETURN( - a, UpdateSliceInMinorDims(builder, a, builder->Sub(before, delta), - {i, i})); +xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int ndims = xla::ShapeUtil::Rank(a_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to Cholesky must have rank >= 2: ", ndims); + } + + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { + return errors::InvalidArgument( + "Arguments to Cholesky must be square matrices: ", + xla::ShapeUtil::HumanString(a_shape)); + } + + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to Cholesky must be >= 1; got ", block_size); } - // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) - TF_ASSIGN_OR_RETURN(auto x, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto factorized, CholeskyUnblocked(builder, x)); - TF_ASSIGN_OR_RETURN(l, - UpdateSliceInMinorDims(builder, l, factorized, {i, i})); - - if (i + k < n) { - // l[i+k:, i:i+k] = trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) - TF_ASSIGN_OR_RETURN(auto panel, - SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); - TF_ASSIGN_OR_RETURN(auto update, - TriangularSolve(builder, factorized, panel, - /*left_side=*/false, - /*lower=*/true, - /*transpose_a=*/true, - /*conjugate_a=*/false, - /*block_size=*/block_size)); - TF_ASSIGN_OR_RETURN( - l, UpdateSliceInMinorDims(builder, l, update, {i + k, i})); + // Blocked left-looking Cholesky factorization. + // Algorithm 1 from + // Haidar, Azzam, et al. "High-performance Cholesky factorization for + // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017. + xla::XlaOp l = xla::ZerosLike(a); + for (int64 i = 0; i < n; i += block_size) { + int64 k = std::min(block_size, n - i); + if (i > 0) { + // TODO(phawkins): consider implementing SYRK for the diagonal part of + // the panel. + // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) + auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); + auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); + auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false, + /*transpose_y=*/true); + auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); + a = UpdateSliceInMinorDims(a, before - delta, {i, i}); + } + + // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) + auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto factorized = CholeskyUnblocked(x); + l = UpdateSliceInMinorDims(l, factorized, {i, i}); + + if (i + k < n) { + // l[i+k:, i:i+k] = + // trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) + auto panel = SliceInMinorDims(a, {i + k, i}, {n, i + k}); + auto update = TriangularSolve(factorized, panel, + /*left_side=*/false, + /*lower=*/true, + /*transpose_a=*/true, + /*conjugate_a=*/false, + /*block_size=*/block_size); + l = UpdateSliceInMinorDims(l, update, {i + k, i}); + } } - } - return l; + return l; + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index 20fca7969ece2729a44933fd3ef3f87230ab6cad..1bef9bb166c576ec665bb48265b4da200ddca2a0 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" namespace tensorflow { @@ -30,8 +30,7 @@ namespace tensorflow { // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. // TODO(znado): handle the complex Hermitian case -xla::StatusOr Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, - int64 block_size = 256); +xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc new file mode 100644 index 0000000000000000000000000000000000000000..fc0c1ee838190b1f1a7ca5b901c97e0a35232a97 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -0,0 +1,387 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/qr.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +namespace { + +// Computes a Householder reflection of the form: +// H = I - tau v v.T. +// such that +// H . ( x1 ) = ( x1 ) +// ( x2 ) = ( x2 ) +// ( ... ) = ( ... ) +// ( xk ) = ( beta ) +// ( ... ) ( 0 ) +// ( ... ) ( 0 ) +// Unlike the usual formulation, we allow the caller to supply 'k' rather than +// only providing the relevant part of 'x' to maintain XLA's static shape +// invariant. In addition, the implementation supports batching. +// Pseudo-code, without batching: +// alpha = x[k] +// x_copy = np.copy(x) +// x_copy[:k+1] = 0 +// xnorm = norm2(x_copy) +// if xnorm == 0: +// beta = alpha +// tau = 0 +// v = np.zeros_like(x) +// else: +// beta = - np.sign(alpha) * dlapy2(alpha, xnorm) +// tau = (beta - alpha) / beta +// v = x / (alpha - beta) +// v[k] = 1 +// return (v, tau, beta) +// TODO(phawkins): LAPACK's xLARFG implementation has code for handling +// overflows in the norm/beta calculations. Perhaps do the same here. +xla::Status House(xla::XlaOp x, xla::XlaOp k, gtl::ArraySlice batch_dims, + const int64 m, xla::XlaOp* v, xla::XlaOp* tau, + xla::XlaOp* beta) { + xla::XlaBuilder* const builder = x.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); + const xla::PrimitiveType type = x_shape.element_type(); + + std::vector batch_dim_ids(batch_dims.size()); + std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0); + const int64 minor_dim = batch_dims.size(); + + xla::XlaOp zero = xla::ScalarLike(x, 0.0); + xla::XlaOp one = xla::ScalarLike(x, 1.0); + + // alpha = x[k] + xla::XlaOp alpha = + xla::Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims); + + // Compute x[k+1:] (padded with zeros in elements 0..k) + xla::XlaOp iota = xla::Iota(builder, xla::S32, m); + xla::XlaOp x_after_k = + xla::Mul(x, xla::ConvertElementType(xla::Gt(iota, k), type), + /*broadcast_dimensions=*/{minor_dim}); + + // sigma = np.dot(x[k+1:], x[k+1:]) + auto sigma = + xla::Reduce(x_after_k * x_after_k, zero, + xla::CreateScalarAddComputation(type, builder), {minor_dim}); + // mu = np.sqrt(x[k]*x[k] + sigma) + auto mu = xla::Sqrt(xla::Square(alpha) + sigma); + + auto sigma_is_zero = xla::Eq(sigma, zero); + + *beta = xla::Select(sigma_is_zero, alpha, -xla::Sign(alpha) * mu); + *tau = xla::Select(sigma_is_zero, xla::Broadcast(zero, batch_dims), + (*beta - alpha) / *beta); + auto divisor = xla::Select(sigma_is_zero, xla::Broadcast(one, batch_dims), + alpha - *beta); + + auto e_k = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, k), type), + std::vector(batch_dims.size(), 1)); + + // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor + // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor. + *v = e_k + + xla::Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids); + return Status::OK(); +} + +// Householder QR decomposition. Algorithm 5.2.1 from Golub and Van +// Loan "Matrix Computations", 4th Edition. This is an unblocked implementation +// used as an inner routine of the blocked implementation. +// Algorithm is adapted slightly so the shapes inside the loop are static, at +// the cost of some redundant computation. Since this is used as an inner block +// kernel, accumulates the Householder transformations (vs, taus) rather than +// the matrix q. +// Equivalent Python code, without batching: +// def qr(a): +// m = a.shape[0] +// n = a.shape[1] +// vs = np.zeros([m, n]) +// taus = np.zeros([n]) +// for j in xrange(min(m, n)): +// v, tau, beta = house(a[:, j], j) +// # Unusually, we apply the Householder transformation to the entirety of +// # a, wasting FLOPs to maintain the static shape invariant that XLA +// # requires. For columns that precede j this has no effect. +// a[:, :] -= tau * np.dot(v[:, np.newaxis], +// np.dot(v[np.newaxis, :], a[:, :])) +// # Form column j explicitly rather than relying on the precision of the +// # Householder update. +// a[j, j] = beta +// a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype) +// vs[:, j] = v +// taus[j] = tau +// return (q, vs, taus) +struct QRBlockResult { + // The factored R value + xla::XlaOp r; + + // Representation of the Householder matrices I - beta v v.T + xla::XlaOp taus; // Shape: [..., n] + xla::XlaOp vs; // Shape: [..., m, n] +}; +xla::StatusOr QRBlock(xla::XlaOp a) { + xla::XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int num_dims = xla::ShapeUtil::Rank(a_shape); + if (num_dims < 2) { + return errors::InvalidArgument("Arguments to QR must have rank >= 2: ", + num_dims); + } + xla::PrimitiveType type = a_shape.element_type(); + + const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + + const int64 num_batch_dims = num_dims - 2; + std::vector batch_dims(num_batch_dims); + for (int i = 0; i < num_batch_dims; ++i) { + batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i); + } + + std::vector batch_dim_indices(num_batch_dims); + std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); + + auto qr_body_fn = + [&](xla::XlaOp j, gtl::ArraySlice values, + xla::XlaBuilder* builder) -> xla::StatusOr> { + auto a = values[0]; + auto vs = values[1]; + auto taus = values[2]; + + // v, beta = house(a[:, j], j) + auto x = DynamicSliceInMinorDims(a, {j}, {1}); + xla::XlaOp v, tau, beta; + TF_RETURN_IF_ERROR(House(xla::Collapse(x, {num_dims - 2, num_dims - 1}), j, + batch_dims, m, &v, &tau, &beta)); + + std::vector shape = batch_dims; + shape.push_back(1); + shape.push_back(m); + auto v_broadcast = xla::Reshape(v, shape); + // a[:, :] -= tau * np.dot(v[:, np.newaxis], + // np.dot(v[np.newaxis, :], a[:, :])) + auto vva = BatchDot(v_broadcast, a); + vva = BatchDot(v_broadcast, vva, /*transpose_x=*/true); + a = a - xla::Mul(tau, vva, + /*broadcast_dimensions=*/batch_dim_indices); + + // It is more precise to populate column 'k' explicitly, rather than + // computing it implicitly by applying the Householder transformation. + // a[k,k] = beta + // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype) + auto iota = xla::Reshape(xla::Iota(a.builder(), xla::S32, m), {m, 1}); + auto predecessor_mask = xla::ConvertElementType(xla::Lt(iota, j), type); + auto mask = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, j), type), + std::vector(batch_dims.size(), 1)); + auto new_x = + xla::Mul(x, predecessor_mask, + /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) + + xla::Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices); + a = DynamicUpdateSliceInMinorDims(a, new_x, {j}); + + // vs[:, j] = v + vs = DynamicUpdateSliceInMinorDims( + vs, xla::Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j}); + // taus[j] = tau + taus = DynamicUpdateSliceInMinorDims( + taus, xla::Reshape(tau, ConcatVectors(batch_dims, {1})), {j}); + return std::vector{a, vs, taus}; + }; + + auto vs = xla::Zeros(builder, xla::ShapeUtil::MakeShape( + type, ConcatVectors(batch_dims, {m, n}))); + auto taus = xla::Zeros( + builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n}))); + + TF_ASSIGN_OR_RETURN(auto values, + XlaForEachIndex(std::min(m, n), xla::S32, qr_body_fn, + {a, vs, taus}, "qr", builder)); + + QRBlockResult result; + result.r = values[0]; + result.vs = values[1]; + result.taus = values[2]; + return result; +} + +// Computes W and Y such that I-WY is equivalent to the sequence of Householder +// transformations given by vs and taus. +// Golub and van Loan, "Matrix Computations", algorithm 5.1.2. +// Y = np.zeros([m, n]) +// W = np.zeros([m, n]) +// Y[:, 0] = vs[:, 0] +// W[:, 0] = -taus[0] * vs[:, 0] +// for j in xrange(1, n): +// v = vs[:, j] +// z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v)) +// W[:, j] = z +// Y[:, j] = v +// return W +// There is no need to return Y since at termination of the loop it is equal to +// vs. +xla::StatusOr ComputeWYRepresentation( + xla::PrimitiveType type, gtl::ArraySlice batch_dims, xla::XlaOp vs, + xla::XlaOp taus, int64 m, int64 n) { + std::vector batch_dim_indices(batch_dims.size()); + std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); + int64 n_index = batch_dims.size() + 1; + + auto body_fn = + [&](xla::XlaOp j, gtl::ArraySlice values, + xla::XlaBuilder* builder) -> xla::StatusOr> { + auto w = values[0]; + auto y = values[1]; + const auto vs = values[2]; + const auto taus = values[3]; + + // Want j values in range [1, ... n). + j = j + xla::ConstantR0(builder, 1); + // vs has shape [..., m, 1] + auto v = DynamicSliceInMinorDims(vs, {j}, {1}); + // beta has shape [..., 1] + auto beta = DynamicSliceInMinorDims(taus, {j}, {1}); + + // yv has shape [..., n, 1] + auto yv = BatchDot(y, v, /*transpose_x=*/true); + // wyv has shape [..., m, 1] + auto wyv = BatchDot(w, yv); + + auto z = xla::Mul( + -beta, v + wyv, + /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); + + w = DynamicUpdateSliceInMinorDims(w, z, {j}); + y = DynamicUpdateSliceInMinorDims(y, v, {j}); + + return std::vector{w, y, vs, taus}; + }; + + xla::XlaBuilder* builder = vs.builder(); + auto w = xla::Zeros(builder, xla::ShapeUtil::MakeShape( + type, ConcatVectors(batch_dims, {m, n}))); + auto y = w; + auto v = SliceInMinorDims(vs, {0}, {1}); + auto beta = SliceInMinorDims(taus, {0}, {1}); + y = UpdateSliceInMinorDims(y, v, {0}); + auto bv = xla::Mul( + -beta, v, + /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); + w = UpdateSliceInMinorDims(w, bv, {0}); + + TF_ASSIGN_OR_RETURN( + auto values, XlaForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus}, + "wy", builder)); + return values[0]; +} + +} // namespace + +// Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van Loan. +// def qr_blocked(a, block_size): +// m = a.shape[0] +// n = a.shape[1] +// q = np.eye(m) +// for i in xrange(0, min(m, n), block_size): +// k = min(block_size, min(m, n) - s) +// (a, vs, taus) = qr(a[i:, i:i+k]) +// y = vs +// w = ComputeWYRepresentation(vs, taus, m-i, k) +// a[i:, i+r:] += np.dot(y, np.dot(w.T, a[i:, i+k:])) +// q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T)) +// return (q, a) +// TODO(phawkins): consider using UT transformations (in the form I - V U V') +// rather than WY transformations. +xla::StatusOr QRDecomposition(xla::XlaOp a, + int64 block_size) { + xla::XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int num_dims = xla::ShapeUtil::Rank(a_shape); + if (num_dims < 2) { + return errors::InvalidArgument("Arguments to QR must have rank >= 2: ", + num_dims); + } + xla::PrimitiveType type = a_shape.element_type(); + + const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + const int64 p = std::min(m, n); + + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to QR must be >= 1; got ", block_size); + } + + const int64 num_batch_dims = num_dims - 2; + std::vector batch_dims(num_batch_dims); + for (int i = 0; i < num_batch_dims; ++i) { + batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i); + } + + auto q = xla::Broadcast(xla::IdentityMatrix(builder, type, m, m), batch_dims); + for (int64 i = 0; i < p; i += block_size) { + int64 k = std::min(block_size, p - i); + + auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k}); + TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block)); + + a = UpdateSliceInMinorDims(a, qr_block.r, {i, i}); + + // Compute the I-WY block representation of a product of Householder + // matrices. + TF_ASSIGN_OR_RETURN(auto w, + ComputeWYRepresentation(type, batch_dims, qr_block.vs, + qr_block.taus, m - i, k)); + auto y = qr_block.vs; + + // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:])) + auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n}); + auto a_update = BatchDot(w, a_panel, /*transpose_x=*/true); + a_update = BatchDot(y, a_update); + a_panel = a_panel + a_update; + a = UpdateSliceInMinorDims(a, a_panel, {i, i + k}); + + // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)) + auto q_panel = SliceInMinorDims(q, {0, i}, {m, m}); + auto q_update = BatchDot(q_panel, w); + q_update = + BatchDot(q_update, y, /*transpose_x=*/false, /*transpose_y=*/true); + q_panel = q_panel + q_update; + q = UpdateSliceInMinorDims(q, q_panel, {0, i}); + } + QRDecompositionResult result; + result.q = q; + result.r = a; + return result; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h new file mode 100644 index 0000000000000000000000000000000000000000..abd2316ac961f583dd29f90f43cf6209de30bd6a --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/qr.h @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" + +namespace tensorflow { + +// Computes the QR decompositions of a batch of matrices. That is, +// given a (batched) matrix a, computes an orthonormal matrix Q and an +// upper-triangular matrix R such that a = QR. +// `a` must be a (batched) matrix of size [..., m, n]. +// The algorithm implements a blocked QR decomposition; `block_size` is +// the block size to use. +// TODO(phawkins): handle the complex case. +struct QRDecompositionResult { + xla::XlaOp q; + xla::XlaOp r; +}; + +xla::StatusOr QRDecomposition(xla::XlaOp a, + int64 block_size = 128); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ diff --git a/tensorflow/compiler/tf2xla/lib/random.cc b/tensorflow/compiler/tf2xla/lib/random.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e7cf00ee5e063aef36a9531ff87d8fe6928ca1f --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/random.cc @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/random.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace tensorflow { + +xla::XlaOp TruncatedNormal(xla::XlaOp uniform) { + auto normal_cdf = [](double x) { + return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0; + }; + + const double kA = -2.0; + const double kB = 2.0; + const double kMu = 0.0; + const double kSigma = 1.0; + const double kAlpha = (kA - kMu) / kSigma; + const double kBeta = (kB - kMu) / kSigma; + const double kAlphaNormalCdf = normal_cdf(kAlpha); + const double kBetaNormalCdf = normal_cdf(kBeta); + const double kZ = kBetaNormalCdf - kAlphaNormalCdf; + + xla::XlaOp one = xla::ScalarLike(uniform, 1.0); + xla::XlaOp two = xla::ScalarLike(uniform, 2.0); + xla::XlaOp sqrt_2 = xla::ScalarLike(uniform, std::sqrt(2.0)); + xla::XlaOp z = xla::ScalarLike(uniform, kZ); + xla::XlaOp alpha_normal_cdf = xla::ScalarLike(uniform, kAlphaNormalCdf); + + auto p = alpha_normal_cdf + z * uniform; + // probit(p) = sqrt(2) * erfinv(2*p-1) + return sqrt_2 * xla::ErfInv(two * p - one); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/random.h b/tensorflow/compiler/tf2xla/lib/random.h new file mode 100644 index 0000000000000000000000000000000000000000..59fc5d0433a51328bc78006ab1c3495d908b44ac --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/random.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace tensorflow { + +// Builds an array filled with values sampled from a truncated normal +// distribution such that no values are greater than two or less than negative +// two. +// +// The "uniform" parameter must be an array of random numbers distributed in +// (0,1). +xla::XlaOp TruncatedNormal(xla::XlaOp uniform); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index d5a27abb2585f699ae2719cb8a6b9a829263389e..ba22eff73abab11abeb57283c63318b2e50a9ca1 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -21,7 +21,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -97,8 +98,8 @@ xla::StatusOr XlaScatter( buffer_shape_post_axes.end()); // Construct the initial values of the loop-carried Tensors. - auto flat_indices = builder->Reshape(indices, flat_indices_shape); - auto flat_updates = builder->Reshape(updates, flat_updates_shape); + auto flat_indices = xla::Reshape(indices, flat_indices_shape); + auto flat_updates = xla::Reshape(updates, flat_updates_shape); auto init = {flat_indices, flat_updates, buffer}; // Constructs the loop body. The implementation of scatter is essentially: @@ -112,46 +113,44 @@ xla::StatusOr XlaScatter( auto updates = loop_vars[1]; auto buffer = loop_vars[2]; - auto zero_index = body_builder->ConstantLiteral( - xla::Literal::Zero(indices_shape.element_type())); + auto zero_index = xla::ConstantLiteral( + body_builder, xla::LiteralUtil::Zero(indices_shape.element_type())); // Slice the i-th index from the indices array. xla::XlaOp index; - auto indices_offset = body_builder->Reshape(i, {1}); + auto indices_offset = xla::Reshape(i, {1}); if (indices_are_vectors) { - indices_offset = body_builder->Pad(indices_offset, zero_index, - xla::MakeEdgePaddingConfig({{0, 1}})); + indices_offset = xla::Pad(indices_offset, zero_index, + xla::MakeEdgePaddingConfig({{0, 1}})); - index = body_builder->DynamicSlice(indices, indices_offset, - {1, num_index_dims}); - index = body_builder->Collapse(index, {0, 1}); + index = xla::DynamicSlice(indices, indices_offset, {1, num_index_dims}); + index = xla::Collapse(index, {0, 1}); } else { - index = body_builder->DynamicSlice(indices, indices_offset, {1}); + index = xla::DynamicSlice(indices, indices_offset, {1}); } // Discard updates with negative indices, since some users expect this. - auto index_in_range = - body_builder->ReduceAll(body_builder->Le(zero_index, index), - body_builder->ConstantR0(true), - xla::CreateScalarAndComputation(body_builder)); + auto index_in_range = xla::ReduceAll( + xla::Le(zero_index, index), xla::ConstantR0(body_builder, true), + xla::CreateScalarAndComputation(xla::PRED, body_builder)); // Make the index in bounds to prevent implementation defined behavior. - index = body_builder->Max(index, zero_index); - index = body_builder->Pad( + index = xla::Max(index, zero_index); + index = xla::Pad( index, zero_index, xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); // Slice the i-th index from the updates array. - auto updates_offset = body_builder->Reshape(i, {1}); - updates_offset = body_builder->Pad( + auto updates_offset = xla::Reshape(i, {1}); + updates_offset = xla::Pad( updates_offset, zero_index, xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); std::vector flat_updates_slice_shape({1}); flat_updates_slice_shape.insert(flat_updates_slice_shape.end(), buffer_shape_post_axes.begin(), buffer_shape_post_axes.end()); - auto update = body_builder->DynamicSlice(updates, updates_offset, - flat_updates_slice_shape); + auto update = + xla::DynamicSlice(updates, updates_offset, flat_updates_slice_shape); // Unflatten the major (iteration) dimensions of the slice to their // original shape. @@ -159,20 +158,19 @@ xla::StatusOr XlaScatter( updates_slice_shape.insert(updates_slice_shape.end(), buffer_shape_post_axes.begin(), buffer_shape_post_axes.end()); - update = body_builder->Reshape(update, updates_slice_shape); + update = xla::Reshape(update, updates_slice_shape); // Apply the update to the buffer. If there is a combiner, use it to merge // the current values with the update. - auto current_value = - body_builder->DynamicSlice(buffer, index, updates_slice_shape); + auto current_value = xla::DynamicSlice(buffer, index, updates_slice_shape); if (combiner) { update = combiner(current_value, update, body_builder); } // Use the current value instead of the update if the index is out of // bounds. - update = body_builder->Select(index_in_range, update, current_value); + update = xla::Select(index_in_range, update, current_value); // Apply the update. - buffer = body_builder->DynamicUpdateSlice(buffer, update, index); + buffer = xla::DynamicUpdateSlice(buffer, update, index); return std::vector{indices, updates, buffer}; }; diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h index 87309e10ede320a81d173cd0a64492f88a2c7376..13a5f1b850a612bddeeac39bef431c19925351ca 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.h +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index b4503601f94baa5a595a64c9fc81bc92d9980ac6..04fa10108cef66f429392951eea70e59643a2d29 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -20,628 +20,383 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/math/math_util.h" namespace tensorflow { -xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, - const xla::XlaOp& a, xla::XlaOp b, - bool left_side, bool lower, - bool transpose_a, bool conjugate_a, - int64 block_size) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have different ranks: ", - xla::ShapeUtil::HumanString(a_shape), " vs. ", - xla::ShapeUtil::HumanString(b_shape)); - } - const int ndims = xla::ShapeUtil::Rank(a_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to TriangularSolve must have rank >= 2: ", ndims); - } - // The batch dimensions must be equal. - std::vector batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - int64 b_size = b_shape.dimensions(i); - if (a_size != b_size) { - return errors::InvalidArgument( - "Batch dimensions of arguments to TriangularSolve must be equal: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); +// Get the diagonal blocks of the coefficient matrix +xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(a)); + int ndims = xla::ShapeUtil::Rank(shape); + int64 n = xla::ShapeUtil::GetDimension(shape, -1); + int64 num_blocks = n / block_size; + + xla::XlaOp diag_blocks; + + // If the coefficient matrix is exactly the block size, we just add a + // singleton dimension i.e. [..., n, n] -> [..., 1, n, n] + if (n == block_size) { + std::vector permutation(ndims); + std::iota(permutation.begin(), permutation.end(), 1); + permutation.insert(permutation.end() - 2, 0); + return Transpose(Broadcast(a, /*broadcast_sizes=*/{1}), permutation); } - batch_dimensions.push_back(a_size); - } - - if (xla::ShapeUtil::GetDimension(a_shape, -1) != - xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "The 'a' arguments to TriangularSolve must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); - } - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have incompatible matrix shapes: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); - } - - if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to TriangularSolve must be >= 1; got ", - block_size); - } - - std::map base_computations; - auto get_base_triangular_solve = - [&](int k) -> xla::StatusOr { - xla::XlaComputation& computation = base_computations[k]; - if (computation.IsNull()) { - std::unique_ptr sub = builder->CreateSubBuilder( - tensorflow::strings::StrCat("trsm_base_", k)); - - auto a_param = sub->Parameter( - 0, - xla::ShapeUtil::MakeShape( - b_shape.element_type(), - PrependMajorDims(sub.get(), batch_dimensions, {k, k})), - "a"); - - std::array b_lastd; - if (left_side) { - b_lastd = {k, n}; - } else { - b_lastd = {m, k}; - } - auto b_param = sub->Parameter( - 1, - xla::ShapeUtil::MakeShape( - b_shape.element_type(), - PrependMajorDims(sub.get(), batch_dimensions, b_lastd)), - "b"); - - // We use a left-looking or right-looking subroutine on the block diagonal - // in the lower=true cases, while falling back to a recursive call in - // others. The left-looking and right-looking subroutines are written with - // a While loop and so yields much faster compile times. Moreover, they - // can give higher performance on smaller (sub)problems. - if (left_side && lower) { - TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param, - b_param, transpose_a, - conjugate_a) - .status()); - } else if (!left_side && lower) { - TF_RETURN_IF_ERROR(TriangularSolveRightLooking(sub.get(), a_param, - b_param, transpose_a, - conjugate_a) - .status()); - } else { - TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, - left_side, lower, transpose_a, - conjugate_a, - /*block_size=*/1) - .status()); - } - TF_ASSIGN_OR_RETURN(computation, sub->Build()); + // We can grab entire blocks using gather + if (n > block_size) { + // Construct the starting indices of the diagonal blocks + auto gather_indices = + Transpose(Broadcast(Mul(Iota(builder, xla::S32, num_blocks), + xla::ConstantR0(builder, block_size)), + /*broadcast_sizes=*/{2}), + /*permutation=*/{1, 0}); + + // Gather the diagonal blocks + xla::GatherDimensionNumbers dim_numbers; + dim_numbers.add_output_window_dims(ndims - 1); + dim_numbers.add_output_window_dims(ndims); + dim_numbers.add_gather_dims_to_operand_dims(ndims - 2); + dim_numbers.add_gather_dims_to_operand_dims(ndims - 1); + dim_numbers.set_index_vector_dim(1); + diag_blocks = Gather(a, gather_indices, dim_numbers, + /*window_bounds=*/{block_size, block_size}); } - return &computation; - }; - - xla::XlaOp output = Zeros(builder, b_shape); - - // Right-looking blocked triangular solve. - // For an explanation of the algorithm, see the TRSM discussion in: - // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation - // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 - // (2008): 4. - - // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if - // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if - // conjugate_a is True. - - if (!left_side && lower == transpose_a) { - // for i in range(0, a.shape[-1], block_size): - for (int64 i = 0; i < n; i += block_size) { - int64 k = std::min(block_size, n - i); - - // output[..., :, i:i+k] = triangular_solve( - // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = builder->Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {0, i})); - - // if i + k < a.shape[-1]: - // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2) - if (i + k < n) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); - } else { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, n})); - } - TF_ASSIGN_OR_RETURN(auto b_update, - BatchDot(builder, update, a_slice_2, - /*transpose_x=*/false, - /*transpose_y=*/transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/conjugate_a)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, i + k}, {m, n})); - b_update = builder->Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k})); + // The last block might be smaller than the block size, + // so we will need to pad it + if (n % block_size != 0) { + // Pad with zeros + auto last_blocks = + SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n}); + xla::PaddingConfig config = xla::MakeNoPaddingConfig(ndims); + int64 padding = block_size - n % block_size; + config.mutable_dimensions(ndims - 1)->set_edge_padding_high(padding); + config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding); + last_blocks = + Pad(last_blocks, Zero(builder, shape.element_type()), config); + + // Add a singleton dimension + // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size] + TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, + builder->GetShape(last_blocks)); + auto shape_dims = xla::AsInt64Slice(blocks_shape.dimensions()); + auto last_blocks_dims = std::vector(ndims); + std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin()); + last_blocks_dims.insert(last_blocks_dims.end() - 2, 1); + last_blocks = Reshape(last_blocks, last_blocks_dims); + + // Concatenate with the other blocks if necessary + if (n > block_size) { + diag_blocks = + xla::ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2); + } else { + diag_blocks = last_blocks; } } - } else if (left_side && lower != transpose_a) { - // for i in range(0, a.shape[-1], block_size): - for (int64 i = 0; i < m; i += block_size) { - int64 k = std::min(block_size, m - i); - - // output[..., i:i+k, :] = triangular_solve( - // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = builder->Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); - - // if i + k < a.shape[-1]: - // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :]) - if (i + k < m) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k})); - } else { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, m})); - } + return diag_blocks; + }); +} - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update, - /*transpose_x=*/transpose_a, - /*transpose_y=*/false, - /*conjugate_x=*/conjugate_a, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {i + k, 0}, {m, n})); - b_update = builder->Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0})); - } +xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, + bool transpose_a, bool conjugate_a) { + xla::XlaBuilder* builder = diag_blocks.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + // Input is a batch of square lower triangular square matrices. Its shape is + // (..., size, size). We resize this to (num_blocks, size, size). + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(diag_blocks)); + int64 block_size = xla::ShapeUtil::GetDimension(shape, -1); + int64 num_blocks = xla::ShapeUtil::ElementsIn(shape) / + tensorflow::MathUtil::IPow(block_size, 2); + diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size}); + + // The input must be triangular because we rely on that when doing + // multiplications later on + diag_blocks = Triangle(diag_blocks, /*lower=*/lower); + + // Rescale blocks to be unit triangular, but avoid dividing by + // zero (which can happen if the last block was padded) otherwise it will + // introduce nans which will propagate + auto diags = GetMatrixDiagonal(diag_blocks); + TF_ASSIGN_OR_RETURN(xla::Shape diags_shape, builder->GetShape(diags)); + auto one = ScalarLike(diags, 1); + auto ones = Broadcast(one, xla::AsInt64Slice(diags_shape.dimensions())); + diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags); + auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); + + // We can now use the fact that for an upper triangular matrix + // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have + // L22' = -L22' * L21 * L11'. In our case, L21 is a vector and our blocks + // have been rescaled to be unit triangular, so L22 = L22' = 1. + + // Initialize the output matrix with -1s on the diagonal. We use -1 instead + // of 1 because we cannot do matrix-vector multiplies with variable shapes + // inside of a loop, or do irregularly shaped in-place updates. Hence, + // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the + // entire row i.e. we calculate + // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I]) + // which means [L21 L22 0] <- [-L21 * L11', L22, 0]. + auto identity = + IdentityMatrix(builder, shape.element_type(), block_size, block_size); + auto neg_identity = -identity; + + // The first or last diagonal element should be set to 1 instead of -1 + // though, since we never update it + auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1}); + auto start_index = (lower) ? 0 : block_size - 1; + auto output_block = DynamicUpdateSlice( + neg_identity, pos_one, + /*start_indices=*/xla::ConstantR1(builder, 2, start_index)); + + // Broadcast diag([1, -1, -1, ...]) to every block + xla::XlaOp output = Broadcast(output_block, + /*broadcast_sizes=*/{num_blocks}); + + // Now we construct a loop that performs matrix-vector multiplications + // inverting the blocks one row at a time + std::vector tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + xla::ShapeUtil::MakeShape(xla::S32, {}), + // The output has the shape of A, with one row updated each iteration. + xla::ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size}), + // The input is a loop invariant. + xla::ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size})}; + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + + auto init_i = One(builder, xla::S32); + auto init = xla::Tuple(builder, {init_i, output, scaled_diag_blocks}); + + // Construct the loop condition function. + std::unique_ptr condb = + builder->CreateSubBuilder("InvertDiagCond"); + { + auto i = GetTupleElement( + Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0); + Lt(i, xla::ConstantR0(condb.get(), block_size)); } - } else if (!left_side && lower != transpose_a) { - // for i in reversed(range(0, a.shape[-1], block_size)): - const int64 last_blk_ix = xla::RoundUpToNearest(n, block_size) - block_size; - for (int64 i = last_blk_ix; i >= 0; i -= block_size) { - int64 k = std::min(block_size, n - i); - - // output[..., :, i:i+k] triangular_solve( - // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = builder->Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {0, i})); - - // if i - k >= 0: - // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2) - if (i - k >= 0) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); - } else { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {0, i}, {i, i + k})); - } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); + + // Construct the loop body function. + std::unique_ptr bodyb = + builder->CreateSubBuilder("InvertDiagBody"); + { + auto input_tuple = + Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple"); + + auto i = GetTupleElement(input_tuple, 0); + auto body_out = GetTupleElement(input_tuple, 1); + auto body_input = GetTupleElement(input_tuple, 2); + + auto zero = xla::ConstantR1(bodyb.get(), 1, 0); + auto j = (lower) ? i : ScalarLike(i, block_size - 1) - i; + auto start_indices = + xla::ConcatInDim(bodyb.get(), {zero, Reshape(j, {1}), zero}, 0); + auto input_row = + DynamicSlice(body_input, start_indices, + /*slice_sizes=*/{num_blocks, 1, block_size}); + + // We want -L21 L11^{-1} + xla::DotDimensionNumbers dnums; + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(1); + auto update = -DotGeneral(input_row, body_out, dnums); + + body_out = DynamicUpdateSlice(body_out, update, start_indices); + + auto next_i = i + ScalarLike(i, 1); + xla::Tuple(bodyb.get(), {next_i, body_out, body_input}); + } + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto invert_while = While(cond, body, init); + auto inv_diag_blocks = GetTupleElement(invert_while, 1); + + // Undo the scaling + inv_diag_blocks = Div(inv_diag_blocks, diags, + /*broadcast_dimensions=*/{0, 1}); + + // Reshape back to original batch major dimensions + return Reshape(inv_diag_blocks, xla::AsInt64Slice(shape.dimensions())); + }); +} - TF_ASSIGN_OR_RETURN(auto b_update, - BatchDot(builder, update, a_slice_2, - /*transpose_x=*/false, - /*transpose_y=*/transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/conjugate_a)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, 0}, {m, i})); - b_update = builder->Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); +xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, + xla::XlaOp inv_diag_blocks, + bool left_side, bool lower, + bool transpose_a, bool conjugate_a) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, + builder->GetShape(inv_diag_blocks)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + int64 block_size = xla::ShapeUtil::GetDimension(blocks_shape, -1); + + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + int64 ndims = xla::ShapeUtil::Rank(a_shape); + int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + int64 num_blocks = n / block_size + (n % block_size != 0); + int64 m_dim = (left_side) ? -1 : -2; + int64 m = xla::ShapeUtil::GetDimension(b_shape, m_dim); + + // Initialize the solution + auto x = ZerosLike(b); + + // This loop is unrolled for performance reasons, but it could be expressed + // rolled as well since the matrices are of the same size each iteration + for (int i = 0; i < num_blocks; i++) { + // High-level intuition: We have B[i] = L[i] @ X. Since L is upper + // triangular this means B[i] = L[i, :i + 1] @ X[:i + 1]. We can split + // this into two parts: B[i] = L[i, :i] @ X[:i] + L[i, i] @ X[i] which + // can be solved for X[i] as X[i] = inv(L[i, i]) @ B[i] - L[i, :i] @ X[:i] + + // Decide whether we go from first block to last or vice versa + auto j = (left_side ^ lower ^ transpose_a) ? num_blocks - 1 - i : i; + + // Get the size of the inverse blocks (the last one might be smaller) + int64 block = (n % block_size != 0 && j + 1 == num_blocks) + ? n % block_size + : block_size; + auto inv_block = + MaybeConjugate(Collapse(SliceInMinorDims(inv_diag_blocks, {j, 0, 0}, + {j + 1, block, block}), + /*dimensions=*/{ndims - 2, ndims - 1}), + conjugate_a); + + // Get the corresponding row of B + int64 k = std::min((j + 1) * block_size, n); + std::vector start = {j * block_size, 0}; + std::vector end = {k, m}; + if (!left_side) { + std::swap(start[0], start[1]); + std::swap(end[0], end[1]); } - } - } else { // left_side && lower == transpose_a - // for i in reversed(range(0, a.shape[-1], block_size)): - const int64 last_blk_ix = xla::RoundUpToNearest(m, block_size) - block_size; - for (int64 i = last_blk_ix; i >= 0; i -= block_size) { - int64 k = std::min(block_size, m - i); - - // output[..., i:i+k, :] triangular_solve( - // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); + auto b_row = SliceInMinorDims(b, start, end); + + xla::XlaOp remainder; + if (i == 0) { + remainder = b_row; } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = builder->Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); - - // if i - k >= 0: - // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :]) - if (i - k >= 0) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); + // This matrix multiply involves a lot of multiplying with zero (namely, + // X[i * block_size:] = 0), but this is faster than slicing... + end = {k, n}; + if (!left_side) { + std::swap(end[0], end[1]); + } + if (transpose_a) { + std::swap(start[0], start[1]); + std::swap(end[0], end[1]); + } + auto a_row = + MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a); + if (left_side) { + remainder = b_row - BatchDot(a_row, x, transpose_a, false); } else { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {0, i}, {i, i + k})); + remainder = b_row - BatchDot(x, a_row, false, transpose_a); } + } - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update, - /*transpose_x=*/transpose_a, - /*transpose_y=*/false, - /*conjugate_x=*/conjugate_a, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, 0}, {i, n})); - b_update = builder->Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); + xla::XlaOp x_update; + auto zero = Zero(builder, xla::S32); + auto start_index = + xla::ConstantR0WithType(builder, xla::S32, j * block_size); + std::vector update_starts = {start_index, zero}; + if (left_side) { + x_update = BatchDot(inv_block, remainder, transpose_a, false); + } else { + x_update = BatchDot(remainder, inv_block, false, transpose_a); + std::swap(update_starts[0], update_starts[1]); } + x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts); } - } - return output; + return x; + }); } -xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - const int64 ndims = xla::ShapeUtil::Rank(a_shape); - - std::vector batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - batch_dimensions.push_back(a_size); - } - - // The main computation is performed in a While loop. - - // Allocate the output and set its first or last row, - // output = np.zeros_like(b) - // if transpose_a: - // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:] - // else: - // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1] - xla::XlaOp output = Zeros(builder, b_shape); - { - auto i = transpose_a ? m - 1 : 0; - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {i, 0}, {i + 1, n})); - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - auto update = builder->Div(b_slice, a_slice_conj); - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); - } - - // Construct the initial loop carry tuple, - // if transpose_a: - // init = (m-2, output, a, b) - // else: - // init = (1, output, a, b) - std::vector tuple_shapes = { - // The loop iteration counter is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), - // The output has the shape of b, with one row updated each iteration. - b_shape, - // The coefficient matrix a is a loop invariant. - a_shape, - // The right-hand-side matrix b is a loop invariant. - b_shape}; - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); - auto init_i = builder->ConstantR0(transpose_a ? m - 2 : 1); - auto init = builder->Tuple({init_i, output, a, b}); - - // Construct the loop condition function, - // def cond_fun(loop_carry): - // i, output, a, b = loop_carry - // return i >= 0 if transpose_a else i < m - std::unique_ptr condb = - builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond"); - { - auto i = condb->GetTupleElement( - condb->Parameter(0, tuple_shape, - "TriangularSolveLeftLookingWhileTuple"), - 0); - if (transpose_a) { - condb->Ge(i, condb->ConstantR0(0)); - } else { - condb->Lt(i, condb->ConstantR0(m)); +xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, + bool lower, bool transpose_a, bool conjugate_a, + int64 block_size) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { + return errors::InvalidArgument( + "Arguments to TriangularSolve have different ranks: ", + xla::ShapeUtil::HumanString(a_shape), " vs. ", + xla::ShapeUtil::HumanString(b_shape)); } - } - TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - - // Construct the loop body function, - // def body_fun(loop_carry): - // i, output, a, b = loop_carry - // if transpose_a: - // a_row = np.swapaxes(a[..., i+1:, i:i+1], -1 -2) - // else: - // a_row = a[..., i:i+1, :i] - // result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :]) - // output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] - // if transpose_a: - // return (i - 1, output, a, b) - // else: - // return (i + 1, output, a, b) - // We have to do some extra FLOPs propagating zeros in the matrix multiply - // because we can't have the size of its arguments depend on the loop counter. - std::unique_ptr bodyb = - builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody"); - { - auto input_tuple = bodyb->Parameter(0, tuple_shape, - "TriangularSolveLeftLookingWhileTuple"); - - // i, output, a, b = loop_carry - auto i = bodyb->GetTupleElement(input_tuple, 0); - auto body_out = bodyb->GetTupleElement(input_tuple, 1); - auto body_a = bodyb->GetTupleElement(input_tuple, 2); - auto body_b = bodyb->GetTupleElement(input_tuple, 3); - auto zero = bodyb->ConstantR0(0); - - // We'd like to implement this: - // if transpose_a: - // a_row = T(a[..., i+1:, i:i+1]) - // result_row = (b[..., i:i+1, :] - // - np.matmul(a_row, body_out[..., i+1:, :])) - // else: - // result_row = (b[..., i:i+1, :] - // - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :])) - // But since we can't have intermediate array sizes depend on the loop - // counter, we instead exploit the fact that we initialized the output to - // all zeros and use that as zero-padding (doing unnecessary FLOPs). - xla::XlaOp a_row; - if (transpose_a) { - TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, - {zero, i}, {m, 1})); - } else { - TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, - {i, zero}, {1, m})); + const int64 ndims = xla::ShapeUtil::Rank(a_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to TriangularSolve must have rank >= 2: ", ndims); + } + // The batch dimensions must be equal. + std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape.dimensions(i); + int64 b_size = b_shape.dimensions(i); + if (a_size != b_size) { + return errors::InvalidArgument( + "Batch dimensions of arguments to TriangularSolve must be equal: ", + xla::ShapeUtil::HumanString(a_shape), " vs ", + xla::ShapeUtil::HumanString(b_shape)); + } + batch_dimensions.push_back(a_size); } - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out, - /*transpose_x=*/transpose_a, - /*transpose_y=*/false, - /*conjugate_x=*/conjugate_a, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN( - auto result_row_slice, - DynamicSliceInMinorDims(bodyb.get(), body_b, {i, zero}, {1, n})); - auto result_row = bodyb->Sub(result_row_slice, b_update); - - // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] - TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a, - {i, i}, {1, 1})); - TF_ASSIGN_OR_RETURN(auto a_elt_conj, - MaybeConjugate(bodyb.get(), a_elt, conjugate_a)); - auto div_result = bodyb->Div(result_row, a_elt_conj); - TF_ASSIGN_OR_RETURN(body_out, - DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, - div_result, {i, zero})); - - // if transpose_a: - // return (i - 1, body_out, a, b) - // else: - // return (i + 1, body_out, a, b) - auto next_i = bodyb->Add(i, bodyb->ConstantR0(transpose_a ? -1 : 1)); - bodyb->Tuple({next_i, body_out, body_a, body_b}); - } - TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); - - // Construct the While loop and return the result, - // return while_loop(cond_fun, body_fun, init)[1] - auto triangular_solve_left_looking_while = builder->While(cond, body, init); - return builder->GetTupleElement(triangular_solve_left_looking_while, 1); -} -xla::StatusOr TriangularSolveRightLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - const int64 ndims = xla::ShapeUtil::Rank(a_shape); - - std::vector batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - batch_dimensions.push_back(a_size); - } - - // The main computation is performed in a While loop. - xla::XlaOp output = Zeros(builder, b_shape); - - // Construct the initial loop carry tuple, - // if transpose_a: - // init = (0, output, a, b) - // else: - // init = (n-1, output, a, b) - std::vector tuple_shapes = { - // The loop iteration counter is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), - // The output has the shape of b, with one row updated each iteration. - b_shape, - // The coefficient matrix a is a loop invariant. - a_shape, - // The right-hand-side matrix b is a loop invariant. - b_shape}; - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); - auto init_i = builder->ConstantR0(transpose_a ? 0 : n - 1); - auto init = builder->Tuple({init_i, output, a, b}); - - // Construct the loop condition function, - // def cond_fun(loop_carry): - // i, output, a, b = loop_carry - // return i < n if transpose_a else i >= 0 - std::unique_ptr condb = - builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond"); - { - auto i = condb->GetTupleElement( - condb->Parameter(0, tuple_shape, - "TriangularSolveRightLookingWhileTuple"), - 0); - if (transpose_a) { - condb->Lt(i, condb->ConstantR0(n)); - } else { - condb->Ge(i, condb->ConstantR0(0)); + if (xla::ShapeUtil::GetDimension(a_shape, -1) != + xla::ShapeUtil::GetDimension(a_shape, -2)) { + return errors::InvalidArgument( + "The 'a' arguments to TriangularSolve must be square matrices: ", + xla::ShapeUtil::HumanString(a_shape)); } - } - TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - - // Construct the loop body function, - // def body_fun(loop_carry): - // i, output, a, b = loop_carry - // if transpose_a: - // a_row = np.swapaxes(a[..., :, i:i+1], -1 -2) - // else: - // a_row = a[..., :, i:i+1] - // result_row = b[..., :, i:i+1] - np.matmul(output, a_row) - // output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] - // if transpose_a: - // return (i - 1, output, a, b) - // else: - // return (i + 1, output, a, b) - // We have to do some extra FLOPs propagating zeros in the matrix multiply - // because we can't have the size of its arguments depend on the loop counter. - std::unique_ptr bodyb = - builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody"); - { - auto input_tuple = bodyb->Parameter( - 0, tuple_shape, "TriangularSolveRightLookingWhileTuple"); - - // i, output, a, b = loop_carry - auto i = bodyb->GetTupleElement(input_tuple, 0); - auto body_out = bodyb->GetTupleElement(input_tuple, 1); - auto body_a = bodyb->GetTupleElement(input_tuple, 2); - auto body_b = bodyb->GetTupleElement(input_tuple, 3); - auto zero = bodyb->ConstantR0(0); - - // We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :, - // i:i+1]) But since we can't have intermediate array sizes depend on the - // loop counter, we instead exploit the fact that we initialized the output - // to all zeros and use that as zero-padding (doing unnecessary FLOPs). - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), body_out, body_a, - /*transpose_x=*/false, - /*transpose_y=*/transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/conjugate_a)); - // result = b - np.matmul(output, a) - auto result = bodyb->Sub(body_b, b_update); - // result_row = result[..., :, i:i+1] - TF_ASSIGN_OR_RETURN( - auto result_row, - DynamicSliceInMinorDims(bodyb.get(), result, {zero, i}, {m, 1})); - - // body_out[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] - TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(bodyb.get(), body_a, - {i, i}, {1, 1})); - TF_ASSIGN_OR_RETURN(auto a_ii_conj, - MaybeConjugate(bodyb.get(), a_ii, conjugate_a)); - auto div_result = bodyb->Div(result_row, a_ii_conj); - TF_ASSIGN_OR_RETURN(body_out, - DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, - div_result, {zero, i})); - - // if transpose_a: - // return (i + 1, body_out, a, b) - // else: - // return (i - 1, body_out, a, b) - auto next_i = bodyb->Add(i, bodyb->ConstantR0(transpose_a ? 1 : -1)); - bodyb->Tuple({next_i, body_out, body_a, body_b}); - } - TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); - - // Construct the While loop and return the result, - // return while_loop(cond_fun, body_fun, init)[1] - auto triangular_solve_left_looking_while = builder->While(cond, body, init); - return builder->GetTupleElement(triangular_solve_left_looking_while, 1); + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { + return errors::InvalidArgument( + "Arguments to TriangularSolve have incompatible matrix shapes: ", + xla::ShapeUtil::HumanString(a_shape), " vs ", + xla::ShapeUtil::HumanString(b_shape)); + } + + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to TriangularSolve must be >= 1; got ", + block_size); + } + + // We find the diagonal blocks of the coefficient matrix + auto diag_blocks = DiagonalBlocks(a, block_size); + + // We invert these blocks in parallel using batched matrix-vector products + auto inv_diag_blocks = + InvertDiagonalBlocks(diag_blocks, lower, transpose_a, conjugate_a); + + // We now find the solution using GEMMs + auto x = SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, + lower, transpose_a, conjugate_a); + + return x; + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index 540c26b2473df9e7885f4e549b3e516a3d8a0d43..555760b7efabddfb25c9135b109a1c48b487415e 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" namespace tensorflow { @@ -57,23 +57,9 @@ namespace tensorflow { // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. -xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, - const xla::XlaOp& a, xla::XlaOp b, - bool left_side, bool lower, - bool transpose_a, bool conjugate_a, - int64 block_size = 256); - -xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a); - -xla::StatusOr TriangularSolveRightLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a); +xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, + bool lower, bool transpose_a, bool conjugate_a, + int64 block_size = 128); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc index 87ea4763f7c2357ae179b68ade3715b24c46432f..aeebf16028d40189203cdfd815f06a339ee72902 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -85,11 +85,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, @@ -107,11 +106,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, @@ -129,11 +127,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, @@ -151,11 +148,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/false, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/false, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, @@ -173,11 +169,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, @@ -196,11 +191,32 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); + + xla::Array2D expected({ + {0.5, 1.0, 1.5}, + {0.41666667, 0.33333333, 0.25}, + {0.23148148, 0.18518519, 0.13888889}, + {0.16835017, 0.13468013, 0.1010101}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/3); xla::Array2D expected({ {0.5, 1.0, 1.5}, @@ -219,11 +235,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {0.5, 1.0, 1.5}, @@ -242,11 +257,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, @@ -267,11 +281,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { CreateR2Parameter(AValsLowerComplex(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRightComplex(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/true, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/true, + /*block_size=*/2); xla::Array2D expected({ {0.5, complex64(0.08333333, 0.08333333), @@ -295,11 +308,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { CreateR2Parameter(AValsUpperComplex(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeftComplex(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {0.5, 1., 1.5}, @@ -317,49 +329,5 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { xla::ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, b; - auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); - auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolveLeftLooking(&builder, a, b, - /*transpose_a=*/false, - /*conjugate_a=*/false); - TF_ASSERT_OK(result.status()); - - xla::Array2D expected({ - {0.5, 1.0, 1.5}, - {0.41666667, 0.33333333, 0.25}, - {0.23148148, 0.18518519, 0.13888889}, - {0.16835017, 0.13468013, 0.1010101}, - }); - - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); -} - -XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, b; - auto a_data = CreateR2Parameter(AValsFull(), 0, "a", &builder, &a); - auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolveLeftLooking(&builder, a, b, - /*transpose_a=*/false, - /*conjugate_a=*/false); - TF_ASSERT_OK(result.status()); - - xla::Array2D expected({ - {0.5, 1.0, 1.5}, - {0.41666667, 0.33333333, 0.25}, - {0.23148148, 0.18518519, 0.13888889}, - {0.16835017, 0.13468013, 0.1010101}, - }); - - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); -} - } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index d9ff7e6259f3fbab8957394bff5c5670a67dd0eb..8b5beba383cda45d36e2ee27ca5e3b3c5988b6b7 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -28,8 +30,9 @@ limitations under the License. namespace tensorflow { xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) { - return builder->Broadcast( - builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())), + return xla::Broadcast( + xla::ConstantLiteral(builder, + xla::LiteralUtil::Zero(shape.element_type())), xla::AsInt64Slice(shape.dimensions())); } @@ -37,19 +40,19 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, double value) { switch (type) { case xla::F16: - return builder->ConstantR0(static_cast(value)); + return xla::ConstantR0(builder, static_cast(value)); break; case xla::BF16: - return builder->ConstantR0(static_cast(value)); + return xla::ConstantR0(builder, static_cast(value)); break; case xla::F32: - return builder->ConstantR0(static_cast(value)); + return xla::ConstantR0(builder, static_cast(value)); break; case xla::F64: - return builder->ConstantR0(value); + return xla::ConstantR0(builder, value); break; case xla::C64: - return builder->ConstantR0(value); + return xla::ConstantR0(builder, value); break; default: LOG(FATAL) << "unhandled element type " << type; @@ -61,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, xla::Literal literal; switch (type) { case xla::U8: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::U32: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::U64: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::S8: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::S32: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::S64: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::F32: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::F64: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::C64: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; @@ -94,11 +97,11 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, LOG(FATAL) << "u16/s16 literals not yet implemented"; case xla::BF16: literal = std::move( - *xla::Literal::CreateR0(static_cast(value))); + *xla::LiteralUtil::CreateR0(static_cast(value))); break; case xla::F16: - literal = std::move( - *xla::Literal::CreateR0(static_cast(value))); + literal = std::move(*xla::LiteralUtil::CreateR0( + static_cast(value))); break; case xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; @@ -107,134 +110,140 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, default: LOG(FATAL) << "unhandled element type " << type; } - return builder->ConstantLiteral(literal); + return xla::ConstantLiteral(builder, literal); } -xla::StatusOr SliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - gtl::ArraySlice start, - gtl::ArraySlice end) { - TF_RET_CHECK(start.size() == end.size()); - int64 n_minor_dims = start.size(); - - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - n_minor_dims); - - // Prepends 0s in the major dim - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + major_dims.size()); - - // Prepends the shape of the major dims. - std::vector padded_end(n_dims); - std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); - std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); - - std::vector strides(n_dims, 1); - return builder->Slice(x, padded_start, padded_end, strides); +xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, + gtl::ArraySlice end) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_RET_CHECK(start.size() == end.size()); + int64 n_minor_dims = start.size(); + + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + + const int64 n_dims = xla::ShapeUtil::Rank(shape); + TF_RET_CHECK(n_minor_dims <= n_dims); + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - n_minor_dims); + + // Prepends 0s in the major dim + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + major_dims.size()); + + // Prepends the shape of the major dims. + std::vector padded_end(n_dims); + std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); + std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); + + std::vector strides(n_dims, 1); + return xla::Slice(x, padded_start, padded_end, strides); + }); } -std::vector PrependMajorDims(xla::XlaBuilder* builder, - const gtl::ArraySlice& major_dims, - const gtl::ArraySlice& indices) { - std::vector output(indices.size() + major_dims.size()); - std::copy(major_dims.begin(), major_dims.end(), output.begin()); - std::copy(indices.begin(), indices.end(), output.begin() + major_dims.size()); +std::vector ConcatVectors(gtl::ArraySlice xs, + gtl::ArraySlice ys) { + std::vector output(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), output.begin()); + std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); return output; } -xla::StatusOr DynamicSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector& starts, - const gtl::ArraySlice& sizes) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - int64 n_minor_dims = starts.size(); - TF_RET_CHECK(n_minor_dims == sizes.size()); - TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - sizes.size()); - TF_ASSIGN_OR_RETURN(auto padded_starts, - PrependZerosInMajorDims(builder, x, starts)); - auto padded_sizes = PrependMajorDims(builder, major_dims, sizes); - return builder->DynamicSlice(x, padded_starts, padded_sizes); +xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, + gtl::ArraySlice starts, + gtl::ArraySlice sizes) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + int64 n_minor_dims = starts.size(); + TF_RET_CHECK(n_minor_dims == sizes.size()); + TF_RET_CHECK(n_minor_dims <= n_dims); + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - sizes.size()); + auto padded_starts = PrependZerosInMajorDims(x, starts); + auto padded_sizes = ConcatVectors(major_dims, sizes); + return xla::DynamicSlice(x, padded_starts, padded_sizes); + }); } -xla::StatusOr UpdateSlice(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice start) { - // TODO(phawkins): make int64 work on all backends, remove the int32 cast. - std::vector start_as_int32(start.begin(), start.end()); - auto start_constant = builder->ConstantR1(start_as_int32); - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, - builder->GetShape(start_constant)); - const int64 start_length = - xla::ShapeUtil::GetDimension(start_constant_shape, -1); - TF_RET_CHECK(start_length == n_dims); - return builder->DynamicUpdateSlice(x, update, start_constant); +xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice start) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + // TODO(phawkins): make int64 work on all backends, remove the int32 cast. + std::vector start_as_int32(start.begin(), start.end()); + auto start_constant = xla::ConstantR1(builder, start_as_int32); + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, + builder->GetShape(start_constant)); + const int64 start_length = + xla::ShapeUtil::GetDimension(start_constant_shape, -1); + TF_RET_CHECK(start_length == n_dims); + return xla::DynamicUpdateSlice(x, update, start_constant); + }); } -xla::StatusOr UpdateSliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice start) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - const int64 n_minor_dims = start.size(); - TF_RET_CHECK(n_minor_dims <= n_dims); - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + (n_dims - n_minor_dims)); - return UpdateSlice(builder, x, update, padded_start); +xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice start) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + const int64 n_minor_dims = start.size(); + TF_RET_CHECK(n_minor_dims <= n_dims); + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + (n_dims - n_minor_dims)); + return UpdateSlice(x, update, padded_start); + }); } -xla::StatusOr DynamicUpdateSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update, - const std::vector& starts) { - TF_ASSIGN_OR_RETURN(auto padded_starts, - PrependZerosInMajorDims(builder, x, starts)); - return builder->DynamicUpdateSlice(x, update, padded_starts); +xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice starts) { + auto padded_starts = PrependZerosInMajorDims(x, starts); + return xla::DynamicUpdateSlice(x, update, padded_starts); } -xla::StatusOr PrependZerosInMajorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector& starts) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - auto zero = builder->Reshape(builder->ConstantR0(0), {1}); - std::vector padded_starts(n_dims, zero); - for (int i = 0; i < starts.size(); ++i) { - padded_starts[n_dims - starts.size() + i] = - builder->Reshape(starts[i], {1}); - } - return builder->ConcatInDim(padded_starts, 0); +xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, + gtl::ArraySlice starts) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + auto zero = xla::Reshape(xla::ConstantR0(builder, 0), {1}); + std::vector padded_starts(n_dims, zero); + for (int i = 0; i < starts.size(); ++i) { + padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1}); + } + return xla::ConcatInDim(builder, padded_starts, 0); + }); } -xla::StatusOr TransposeInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - std::vector permutation(n_dims); - std::iota(permutation.begin(), permutation.end(), 0); - std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); - return builder->Transpose(x, permutation); +xla::XlaOp TransposeInMinorDims(xla::XlaOp x) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + std::vector permutation(n_dims); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); + return xla::Transpose(x, permutation); + }); } -xla::StatusOr MaybeConjugate(xla::XlaBuilder* builder, - const xla::XlaOp& x, bool conjugate) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - auto perform_conj = shape.element_type() == xla::C64 && conjugate; - return perform_conj ? builder->Conj(x) : x; +xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + auto perform_conj = shape.element_type() == xla::C64 && conjugate; + return perform_conj ? xla::Conj(x) : x; + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 3c120a2548576d6ad46870583ca65beea63507a3..b4905c952820a45371e090aa98466654e2db9661 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -16,16 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { -// Returns a zero-filled tensor with shape `shape`. -xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape); - // Returns a floating point scalar constant of 'type' with 'value'. // If 'type' is complex, returns a real value with zero imaginary component. xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, @@ -33,7 +30,7 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, // Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros // prepended until the array is length n_dims. -xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder, +xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, gtl::ArraySlice starts); // Returns a integer scalar constant of 'type' with 'value'. @@ -41,54 +38,43 @@ xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder, xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, int64 value); -// Builds a vector of zeros of length rank(x) with the last two values being +// Builds a vector of zeros of length rank(x) with the last values being // those in `starts`. -xla::StatusOr PrependZerosInMajorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector& starts); +xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, + gtl::ArraySlice starts); // Performs a slice in the minor dimensions of a Tensor. -xla::StatusOr SliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - gtl::ArraySlice start, - gtl::ArraySlice end); +xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, + gtl::ArraySlice end); -// Builds a 1-d vector out of a concatenation of `major_dims` and `starts`. -std::vector PrependMajorDims(xla::XlaBuilder* builder, - const gtl::ArraySlice& major_dims, - const gtl::ArraySlice& indices); +// Returns the concatenation of `xs` and `ys`. +std::vector ConcatVectors(gtl::ArraySlice xs, + gtl::ArraySlice ys); // Performs a dynamic slice in the minor dimensions of a Tensor. -xla::StatusOr DynamicSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector& starts, const gtl::ArraySlice& sizes); +xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, + gtl::ArraySlice starts, + gtl::ArraySlice sizes); // Updates a slice of 'x', i.e., // x[start[0], ..., start[n]] = update -xla::StatusOr UpdateSlice(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice start); +xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice start); // Updates a slice of 'x', where 'start' contains a list of minor dimensions: // x[..., start[0], ..., start[n]] = update -xla::StatusOr UpdateSliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice start); +xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice start); -xla::StatusOr DynamicUpdateSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update, - const std::vector& starts); +xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice starts); // Transposes a stack of matrices `x` by swapping the last two dimensions. -xla::StatusOr TransposeInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x); +xla::XlaOp TransposeInMinorDims(xla::XlaOp x); // Applies a complex conjugation operation if `a` is complex and `conjugate_a` // is true, otherwise returns its argument. -xla::StatusOr MaybeConjugate(xla::XlaBuilder* builder, - const xla::XlaOp& x, bool conjugate); +xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc index 265b39402c832f8c810a74f281563b05afdf2b1b..442fe92c34ca26cb1a854cc90da8dc034bca79bb 100644 --- a/tensorflow/compiler/tf2xla/lib/util_test.cc +++ b/tensorflow/compiler/tf2xla/lib/util_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -70,8 +70,7 @@ XLA_TEST_F(UtilTest, Simple2dLookup) { auto a_data = CreateR2Parameter(BValsRight(), 0, "a", &builder, &a); auto x_data = CreateR0Parameter(2, 1, "x", &builder, &x); auto y_data = CreateR0Parameter(1, 2, "y", &builder, &y); - auto result = DynamicSliceInMinorDims(&builder, a, {x, y}, {1, 1}); - TF_ASSERT_OK(result.status()); + DynamicSliceInMinorDims(a, {x, y}, {1, 1}); ComputeAndCompareR2(&builder, {{10}}, {a_data.get(), x_data.get(), y_data.get()}, @@ -86,10 +85,8 @@ XLA_TEST_F(UtilTest, Simple3dLookup) { CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); auto index_data = CreateR0Parameter(1, 1, "index", &builder, &index); - TF_ASSERT_OK_AND_ASSIGN( - auto l_index, - DynamicSliceInMinorDims(&builder, a, - {index, builder.ConstantR0(0)}, {1, 4})); + DynamicSliceInMinorDims(a, {index, xla::ConstantR0(&builder, 0)}, + {1, 4}); ComputeAndCompareR3(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}}, {a_data.get(), index_data.get()}); @@ -104,8 +101,7 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) { auto x_data = CreateR0Parameter(2, 2, "x", &builder, &x); auto y_data = CreateR0Parameter(1, 3, "y", &builder, &y); - auto result = DynamicUpdateSliceInMinorDims(&builder, a, b, {x, y}); - TF_ASSERT_OK(result.status()); + DynamicUpdateSliceInMinorDims(a, b, {x, y}); xla::Array2D expected( {{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 9, 1, -10}, {5, 8, 10, 11}}}); @@ -128,13 +124,9 @@ XLA_TEST_F(UtilTest, RowBatchDot) { // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); - TF_ASSERT_OK_AND_ASSIGN( - auto l_index, - DynamicSliceInMinorDims(&builder, a, - {index, builder.ConstantR0(0)}, {1, n})); - TF_ASSERT_OK_AND_ASSIGN( - auto dot, BatchDot(&builder, l_index, row, - /*transpose_x=*/false, /*transpose_y=*/true)); + auto l_index = DynamicSliceInMinorDims( + a, {index, xla::ConstantR0(&builder, 0)}, {1, n}); + BatchDot(l_index, row, /*transpose_x=*/false, /*transpose_y=*/true); ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, {a_data.get(), row_data.get(), index_data.get()}); diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index 09ce594930efc0af47306590d76b322ac730f80f..d64394f1401d7ceea004a59c991ef6f4a1c58b41 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -39,7 +40,7 @@ xla::StatusOr> XlaWhileLoop( xla::XlaBuilder* builder) { std::vector elements(arity); for (int i = 0; i < arity; ++i) { - elements[i] = builder->GetTupleElement(tuple, i); + elements[i] = xla::GetTupleElement(tuple, i); } return elements; }; @@ -48,7 +49,8 @@ xla::StatusOr> XlaWhileLoop( std::unique_ptr cond_builder = builder->CreateSubBuilder(strings::StrCat(name, "_condition")); { - auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter"); + auto parameter = + xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); TF_RETURN_IF_ERROR( condition_function(unpack_tuple(parameter, arity, cond_builder.get()), @@ -61,7 +63,8 @@ xla::StatusOr> XlaWhileLoop( std::unique_ptr body_builder = builder->CreateSubBuilder(strings::StrCat(name, "_body")); { - auto parameter = body_builder->Parameter(0, tuple_shape, "parameter"); + auto parameter = + xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter"); TF_ASSIGN_OR_RETURN( auto result, @@ -69,11 +72,11 @@ xla::StatusOr> XlaWhileLoop( body_builder.get())); TF_RET_CHECK(result.size() == initial_values.size()); - body_builder->Tuple(result); + xla::Tuple(body_builder.get(), result); } TF_ASSIGN_OR_RETURN(auto body, body_builder->Build()); - auto outputs = builder->While(cond, body, builder->Tuple(initial_values)); + auto outputs = xla::While(cond, body, xla::Tuple(builder, initial_values)); return unpack_tuple(outputs, arity, builder); } @@ -86,9 +89,8 @@ xla::StatusOr> XlaForEachIndex( auto while_cond_fn = [&](gtl::ArraySlice values, xla::XlaBuilder* cond_builder) -> xla::StatusOr { - return cond_builder->Lt( - values[0], - IntegerLiteral(cond_builder, num_iterations_type, num_iterations)); + return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type, + num_iterations)); }; auto while_body_fn = [&](gtl::ArraySlice values, xla::XlaBuilder* body_builder) @@ -97,9 +99,10 @@ xla::StatusOr> XlaForEachIndex( std::vector updated_values; updated_values.reserve(values.size()); - updated_values.push_back(body_builder->Add( + updated_values.push_back(xla::Add( iteration, - body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type)))); + xla::ConstantLiteral(body_builder, + xla::LiteralUtil::One(num_iterations_type)))); values.remove_prefix(1); TF_ASSIGN_OR_RETURN(std::vector body_outputs, @@ -111,8 +114,8 @@ xla::StatusOr> XlaForEachIndex( std::vector values; values.reserve(initial_values.size() + 1); - values.push_back( - builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type))); + values.push_back(xla::ConstantLiteral( + builder, xla::LiteralUtil::Zero(num_iterations_type))); values.insert(values.end(), initial_values.begin(), initial_values.end()); TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values, diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h index 5b6684c995889efbb1378c7ac4903548891d090a..9493b1f109be0725f7f733b9f9da664264275a69 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/tf2xla/lib/while_loop.h @@ -19,8 +19,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index db56b128375ce8ff2faf12c5d7ea256bdfab0f63..77da1bf29ced60e490f07abad41cf8ce96232982 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -17,29 +17,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/common_runtime/dma_helper.h" namespace tensorflow { -Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { - xla::Shape literal_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape( - host_tensor.dtype(), host_tensor.shape(), &literal_shape)); - - *literal = xla::Literal(literal_shape); - - // memcpy over the payload ... - // TODO(phawkins): handle string types. - size_t total_bytes = host_tensor.TotalBytes(); - if (total_bytes > 0) { - void* dst_ptr = literal->untyped_data(); - const void* src_ptr = DMAHelper::base(&host_tensor); - memcpy(dst_ptr, src_ptr, total_bytes); - } - return Status::OK(); -} - Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, xla::BorrowingLiteral* literal) { xla::Shape xla_shape; @@ -50,6 +32,23 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, return Status::OK(); } +Status HostTensorToMutableBorrowingLiteral( + Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor->dtype(), + host_tensor->shape(), &xla_shape)); + return HostTensorToMutableBorrowingLiteral(xla_shape, host_tensor, literal); +} + +Status HostTensorToMutableBorrowingLiteral( + const xla::Shape& xla_shape, Tensor* host_tensor, + xla::MutableBorrowingLiteral* literal) { + *literal = xla::MutableBorrowingLiteral( + static_cast(DMAHelper::base(host_tensor)), xla_shape); + + return Status::OK(); +} + Status HostTensorsToBorrowingLiteralTuple( tensorflow::gtl::ArraySlice host_tensors, xla::BorrowingLiteral* literal) { diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 74685025c1780c5c0ba56205a98786582e9191e9..09d6fa811669b422532673540e4da47f47e6be4e 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -18,7 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -26,14 +26,20 @@ limitations under the License. namespace tensorflow { -// Copies 'host_tensor' to an XLA Literal. Fails if host_tensor is of an -// unsupported type. -Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); - // Returns a BorrowingLiteral that utilizes the same underlying buffer owned by // 'host_tensor'. Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, xla::BorrowingLiteral* literal); +// Returns a MutableBorrowingLiteral that utilizes the same underlying buffer +// owned by 'host_tensor', but is mutable via the xla::Literal methods. +Status HostTensorToMutableBorrowingLiteral( + Tensor* host_tensor, xla::MutableBorrowingLiteral* literal); +// Similar as above, except the literal shape is explicitly provided and used +// instead of obtaining it from the 'host_tensor'. The provided literal shape +// 'xla_shape' must be compatible with the shape of 'host_tensor'. +Status HostTensorToMutableBorrowingLiteral( + const xla::Shape& xla_shape, Tensor* host_tensor, + xla::MutableBorrowingLiteral* literal); // Returns a BorrowingLiteral tuple that utilizes the same underlying buffers // owned by 'host_tensors'. diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index f3d6787daaa1165b28ce63dfd501533fa0963edd..a3404c2b3df7bf25011359d1f5f5b88c29a3f83b 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -27,7 +28,7 @@ TEST(LiteralUtil, LiteralToHostTensor) { { std::vector int64_values = {1, 2, 3}; std::unique_ptr int64_values_literal = - xla::Literal::CreateR1(gtl::ArraySlice(int64_values)); + xla::LiteralUtil::CreateR1(gtl::ArraySlice(int64_values)); Tensor host_tensor; EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor) @@ -48,7 +49,7 @@ TEST(LiteralUtil, LiteralToHostTensor) { Tensor host_tensor; std::vector int32_values = {10, 11}; std::unique_ptr int32_values_literal = - xla::Literal::CreateR1(gtl::ArraySlice(int32_values)); + xla::LiteralUtil::CreateR1(gtl::ArraySlice(int32_values)); EXPECT_TRUE( LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor) .ok()); diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index bb9168fa358154f3db9dab87bacc9bf28dd16406..ace6fd1d8eeaf439509a7b75d8d986997c392e73 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -8,12 +8,7 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") cc_library( name = "xla_ops", - srcs = [ - "dynamic_slice_ops.cc", - "functional_ops.cc", - "reduce_window_op.cc", - "sendrecv_ops.cc", - ], + srcs = ["xla_ops.cc"], deps = [ "//tensorflow/core:framework", ], diff --git a/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc deleted file mode 100644 index d6c0edbb889b1751ac9d9d47d0c9534b543196ff..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" - -namespace tensorflow { - -REGISTER_OP("XlaDynamicUpdateSlice") - .Input("input: T") - .Input("update: T") - .Input("indices: Tindices") - .Output("output: T") - .Attr("T: type") - .Attr("Tindices: {int32, int64}") - .SetShapeFn(shape_inference::UnchangedShape) - .Doc(R"doc( -Wraps the XLA DynamicUpdateSlice operator, documented at - https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice -. - -XlaDynamicUpdateSlice generates a result which is the value of the `input` -operand, with a slice update overwritten at `indices`. The shape of `update` -determines the shape of the sub-array of the result which is updated. The shape -of indices must be rank == 1, with dimension size equal to the rank of `input`. - -Handling of out-of-bounds slice indices is implementation-defined. - -input: A `Tensor` of type T. -indices: A vector of indices into `input`. Must have length equal to the rank of - `input`. -update: A `Tensor` of type T. Same rank as `input`. -output: A `Tensor` of type T. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/functional_ops.cc b/tensorflow/compiler/tf2xla/ops/functional_ops.cc deleted file mode 100644 index 4a669f8e6eaf644f119f3c0a66f29d9f2c9a9d16..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/ops/functional_ops.cc +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -// TODO(b/37549631) setting the While Op to always be stateful is too -// conservative. -REGISTER_OP("XlaWhile") - .Input("input: T") - .Output("output: T") - .Attr("T: list(type) >= 0") - .Attr("cond: func") - .Attr("body: func") - .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -output = input; While (Cond(output)) { output = Body(output) } - -input: A list of input tensors whose types are T. -output: A list of output tensors whose types are T. -cond: A function takes 'input' and returns a tensor. If the tensor is - a scalar of non-boolean, the scalar is converted to a boolean - according to the following rule: if the scalar is a numerical - value, non-zero means True and zero means False; if the scalar is - a string, non-empty means True and empty means False. If the - tensor is not a scalar, non-emptiness means True and False - otherwise. -body: A function that takes a list of tensors and returns another - list of tensors. Both lists have the same types as specified by T. -)doc"); - -// TODO(b/37549631) setting the If Op to always be stateful is too -// conservative. -REGISTER_OP("XlaIf") - .Input("cond: Tcond") - .Input("inputs: Tin") - .Output("output: Tout") - .Attr("Tcond: type") - .Attr("then_branch: func") - .Attr("else_branch: func") - .Attr("Tin: list(type) >= 0") - .Attr("Tout: list(type) >= 0") - .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -output = cond ? then_branch(inputs) : else_branch(inputs). - -cond: A boolean scalar. -inputs: A list of input tensors. -output: A list of tensors returned by either then_branch(inputs) or - else_branch(inputs). The input shapes of the then_branch and - else_branch must match. -then_branch: A function takes 'inputs' and returns a list of tensors, - whose types are the same as what else_branch returns. -else_branch: A function takes 'inputs' and returns a list of tensors. - whose types are the same as what then_branch returns. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc b/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc deleted file mode 100644 index d9af982adc090ea78c711fd4656ba429c53b18c9..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("XlaReduceWindow") - .Input("input: T") - .Input("init_value: T") - .Attr("T: numbertype") - .Attr("computation: func") - .Attr("window_dimensions: list(int)") - .Attr("window_strides: list(int)") - .Attr("padding_low: list(int)") - .Attr("padding_high: list(int)") - .Output("output: T") - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -Wraps the XLA ReduceWindow operator, documented at - https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . - -input: the input tensor -init_value: a scalar representing the initial value for the reduction -computation: a reducer function to apply -window_dimensions: the shape of the window -window_strides: the inter-window strides -padding_low: the padding to apply at the start of each input dimensions -padding_high: the padding to apply at the end of each input dimension. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc deleted file mode 100644 index 7ec7b50e905a6cbdecea4543dcb87322b5a7e844..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("XlaSend") - .Input("tensor: T") - .Attr("T: type") - .Attr("tensor_name: string") - .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -Sends the named tensor to another XLA computation. Wraps the XLA Send operator -documented at - https://www.tensorflow.org/performance/xla/operation_semantics#send . - -tensor: The tensor to send. -tensor_name: A string key that identifies the channel. -)doc"); - -REGISTER_OP("XlaRecv") - .Output("tensor: dtype") - .Attr("dtype: type") - .Attr("tensor_name: string") - .Attr("shape: shape") - .SetIsStateful() - .SetShapeFn([](shape_inference::InferenceContext* c) { - TensorShape shape_attr; - TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr)); - shape_inference::ShapeHandle s; - TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); - c->set_output(0, s); - return Status::OK(); - }) - .Doc(R"doc( -Receives the named tensor from another XLA computation. Wraps the XLA Recv -operator documented at - https://www.tensorflow.org/performance/xla/operation_semantics#recv . - -tensor: The tensor to receive. -dtype: The type of the tensor. -tensor_name: A string key that identifies the channel. -shape: The shape of the tensor. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..a59c77f5c3a309abe8f6fbab1e48455d54e8fae5 --- /dev/null +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -0,0 +1,182 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("XlaDynamicUpdateSlice") + .Input("input: T") + .Input("update: T") + .Input("indices: Tindices") + .Output("output: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Wraps the XLA DynamicUpdateSlice operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice +. + +XlaDynamicUpdateSlice generates a result which is the value of the `input` +operand, with a slice update overwritten at `indices`. The shape of `update` +determines the shape of the sub-array of the result which is updated. The shape +of indices must be rank == 1, with dimension size equal to the rank of `input`. + +Handling of out-of-bounds slice indices is implementation-defined. + +input: A `Tensor` of type T. +indices: A vector of indices into `input`. Must have length equal to the rank of + `input`. +update: A `Tensor` of type T. Same rank as `input`. +output: A `Tensor` of type T. +)doc"); + +// TODO(b/37549631) setting the If Op to always be stateful is too +// conservative. +REGISTER_OP("XlaIf") + .Input("cond: Tcond") + .Input("inputs: Tin") + .Output("output: Tout") + .Attr("Tcond: type") + .Attr("then_branch: func") + .Attr("else_branch: func") + .Attr("Tin: list(type) >= 0") + .Attr("Tout: list(type) >= 0") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +output = cond ? then_branch(inputs) : else_branch(inputs). + +cond: A boolean scalar. +inputs: A list of input tensors. +output: A list of tensors returned by either then_branch(inputs) or + else_branch(inputs). The input shapes of the then_branch and + else_branch must match. +then_branch: A function takes 'inputs' and returns a list of tensors, + whose types are the same as what else_branch returns. +else_branch: A function takes 'inputs' and returns a list of tensors. + whose types are the same as what then_branch returns. +)doc"); + +REGISTER_OP("XlaRecv") + .Output("tensor: dtype") + .Attr("dtype: type") + .Attr("tensor_name: string") + .Attr("shape: shape") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + TensorShape shape_attr; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); + c->set_output(0, s); + return Status::OK(); + }) + .Doc(R"doc( +Receives the named tensor from another XLA computation. Wraps the XLA Recv +operator documented at + https://www.tensorflow.org/performance/xla/operation_semantics#recv . + +tensor: The tensor to receive. +dtype: The type of the tensor. +tensor_name: A string key that identifies the channel. +shape: The shape of the tensor. +)doc"); + +REGISTER_OP("XlaReduceWindow") + .Input("input: T") + .Input("init_value: T") + .Attr("T: numbertype") + .Attr("computation: func") + .Attr("window_dimensions: list(int)") + .Attr("window_strides: list(int)") + .Attr("padding_low: list(int)") + .Attr("padding_high: list(int)") + .Output("output: T") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Wraps the XLA ReduceWindow operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . + +input: the input tensor +init_value: a scalar representing the initial value for the reduction +computation: a reducer function to apply +window_dimensions: the shape of the window +window_strides: the inter-window strides +padding_low: the padding to apply at the start of each input dimensions +padding_high: the padding to apply at the end of each input dimension. +)doc"); + +REGISTER_OP("XlaSend") + .Input("tensor: T") + .Attr("T: type") + .Attr("tensor_name: string") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Sends the named tensor to another XLA computation. Wraps the XLA Send operator +documented at + https://www.tensorflow.org/performance/xla/operation_semantics#send . + +tensor: The tensor to send. +tensor_name: A string key that identifies the channel. +)doc"); + +REGISTER_OP("XlaSort") + .Input("input: T") + .Output("output: T") + .Attr("T: type") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Wraps the XLA Sort operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#sort +. + +Sorts a tensor. Currently only rank 1 sorts in ascending order are supported. + +input: A `Tensor` of type T. +output: A `Tensor` of type T. +)doc"); + +// TODO(b/37549631) setting the While Op to always be stateful is too +// conservative. +REGISTER_OP("XlaWhile") + .Input("input: T") + .Output("output: T") + .Attr("T: list(type) >= 0") + .Attr("cond: func") + .Attr("body: func") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +output = input; While (Cond(output)) { output = Body(output) } + +input: A list of input tensors whose types are T. +output: A list of output tensors whose types are T. +cond: A function takes 'input' and returns a tensor. If the tensor is + a scalar of non-boolean, the scalar is converted to a boolean + according to the following rule: if the scalar is a numerical + value, non-zero means True and zero means False; if the scalar is + a string, non-empty means True and empty means False. If the + tensor is not a scalar, non-emptiness means True and False + otherwise. +body: A function that takes a list of tensors and returns another + list of tensors. Both lists have the same types as specified by T. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index e5ce65bec950fdfd38c3ca5bc62ac745ef8ca4a7..2fc47dffb8f5f16f24e3beb1ff75aeed3e857c58 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -77,4 +77,6 @@ def reduce_window(operand, recv = gen_xla_ops.xla_recv send = gen_xla_ops.xla_send +sort = gen_xla_ops.xla_sort + while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index ac768b206e2a8d163a4253432a1911152f89ce86..48568c825b7a0f13011d3d6e8e62ec5db026760f 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h index d02fc56c5b8f58f0e4cfe1779ad34fe3b79324c7..432a12a51622b56ae74a677420da321c58960ee6 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.h +++ b/tensorflow/compiler/tf2xla/tf2xla.h @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/framework/graph.pb.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 84c133ffabe20dbdaa4d5a64e035efb5e4c4c44b..56f7045a98201ed398244f9e3f5ff23788135b75 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -73,8 +75,8 @@ TEST(ConvertGraphDefToXla, Sum) { TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); // Set up arguments. - auto x_literal = xla::Literal::CreateR0(10); - auto y_literal = xla::Literal::CreateR0(32); + auto x_literal = xla::LiteralUtil::CreateR0(10); + auto y_literal = xla::LiteralUtil::CreateR0(32); auto x_global_or = client->TransferToServer(*x_literal); auto y_global_or = client->TransferToServer(*y_literal); TF_EXPECT_OK(x_global_or.status()); diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 9203e8d9e607e99ad738350a1c3f2b9e900df179..0e07485d1861aa40b14e527b14947c6f8bab647e 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include +#include #include #include @@ -297,4 +298,29 @@ void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype, } } +namespace { +uint32 InitialRandomSeed() { + // Support plumbing the TF seed through to XLA is being worked on. + // If a user wants deterministic behavior, their best option + // is to start with a known checkpoint. This also handles issues when + // multiple random calls can be invoked in any order by TF executor. + // Another option is to use stateless random ops. They have much cleaner + // semantics. + // If a user really wants to set a deterministic seed for XLA-based + // devices, this is the place to do it. + std::random_device rd; + // Make the starting value odd. + return rd() | 1; +} +} // namespace + +uint32 GetXLARandomSeed() { + // We initialize counter with an odd number and increment it by two + // everytime. This ensures that it will never be zero, even + // after an overflow. When seeded with zero, some XLA backends + // can return all zeros instead of random numbers. + static std::atomic counter(InitialRandomSeed()); + return counter.fetch_add(2); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index 745beb39c1d917cd0d1cd219536ee26a96253ec9..33620ef810bd4fe897f384474e661e341a448b93 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -56,6 +56,9 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges); void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype, KernelDef* kdef); +// Returns the next random seed to use for seeding xla rng. +uint32 GetXLARandomSeed(); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index fe7ec633eca2504faf6cbb2f5fd7f59780ab7976..e89f4733281194f0263ae8cc4907caa0ad781165 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/platform/mem.h" diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index d0b9e34e162f3412cd6662a2e2bbfe3df213c4c2..a6e78825334fec748be5fee80669649df699d2fb 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index 672e19bd93449ccc31f4af5ded23257b197a3c39..1f0f240135dfcd0c540cc39a42514c67ce979ee0 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -16,45 +16,47 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include -#include "tensorflow/compiler/aot/runtime.h" namespace tensorflow { XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, AllocMode alloc_mode) - : raw_function_(static_data.raw_function), - result_index_(static_data.result_index), - args_(new void*[static_data.num_args]), - temps_(new void*[static_data.num_temps]), - arg_names_(static_data.arg_names), - result_names_(static_data.result_names), - program_shape_(static_data.program_shape), - hlo_profile_printer_data_(static_data.hlo_profile_printer_data) { + : raw_function_(static_data.raw_function_), + result_index_(static_data.result_index_), + buffer_table_(new void*[static_data.num_buffers_]), + buffer_infos_(static_data.buffer_infos_), + arg_index_table_(static_data.arg_index_table_), + num_args_(static_data.num_args_), + arg_names_(static_data.arg_names_), + result_names_(static_data.result_names_), + program_shape_(static_data.program_shape_), + hlo_profile_printer_data_(static_data.hlo_profile_printer_data_) { + bool allocate_entry_params = + alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS; // Allocate arg and temp buffers. - if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) { - alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( - static_data.arg_sizes, static_data.num_args, args_, - /*annotate_initialized=*/false); - } - alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( - static_data.temp_sizes, static_data.num_temps, temps_, + alloc_buffer_table_ = cpu_function_runtime::MallocContiguousBuffers( + static_data.buffer_infos_, static_data.num_buffers_, + /*allocate_entry_params=*/allocate_entry_params, buffer_table_, /*annotate_initialized=*/true); - // If Hlo profiling is enabled the generated code expects an appropriately // sized buffer to be passed in as the last argument. If Hlo profiling is // disabled the last function argument is still present in the function // signature, but it is ignored by the generated code and we pass in null for // it. if (hlo_profiling_enabled()) { - profile_counters_ = new int64[static_data.profile_counters_size](); + profile_counters_ = new int64[static_data.profile_counters_size_](); } } +bool XlaCompiledCpuFunction::Run() { + raw_function_(buffer_table_[result_index_], &run_options_, nullptr, + buffer_table_, profile_counters_); + return true; +} + XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { - tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_); - tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); - delete[] args_; - delete[] temps_; + cpu_function_runtime::FreeContiguous(alloc_buffer_table_); + delete[] buffer_table_; delete[] profile_counters_; } diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 48a8c083cacf2f6ecf9dc1817b6174c01385d035..425e769346ffcbc548495d93cb7adc779f860110 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/core/platform/types.h" @@ -56,36 +57,85 @@ class XlaCompiledCpuFunction { // StaticData represents the state necessary to run an XLA-compiled // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for // AOT this is backed by data compiled into the object file. - struct StaticData { + // + // The contents of StaticData are XLA-internal implementation details and + // should not be relied on by clients. + // + // TODO(sanjoy): Come up with a cleaner way to express the contraint we want + // here: generated XlaCompiledCpuFunction subclasses should be able to create + // instances of StaticData but only XlaCompiledCpuFunction should be able to + // read from StaticData instances. + class StaticData { + public: + void set_raw_function(RawFunction raw_function) { + raw_function_ = raw_function; + } + void set_buffer_infos( + const cpu_function_runtime::BufferInfo* buffer_infos) { + buffer_infos_ = buffer_infos; + } + void set_num_buffers(size_t num_buffers) { num_buffers_ = num_buffers; } + void set_arg_index_table(const int32* arg_index_table) { + arg_index_table_ = arg_index_table; + } + void set_num_args(int64 num_args) { num_args_ = num_args; } + void set_result_index(size_t result_index) { result_index_ = result_index; } + void set_arg_names(const char** arg_names) { arg_names_ = arg_names; } + void set_result_names(const char** result_names) { + result_names_ = result_names; + } + void set_program_shape(const xla::ProgramShape* program_shape) { + program_shape_ = program_shape; + } + const xla::HloProfilePrinterData* hlo_profile_printer_data() const { + return hlo_profile_printer_data_; + } + void set_hlo_profile_printer_data( + const xla::HloProfilePrinterData* hlo_profile_printer_data) { + hlo_profile_printer_data_ = hlo_profile_printer_data; + } + void set_profile_counters_size(int64 profile_counters_size) { + profile_counters_size_ = profile_counters_size; + } + + private: // The raw function to call. - RawFunction raw_function; + RawFunction raw_function_; + + // Contains information about the buffers used by the XLA computation. + const cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr; + size_t num_buffers_ = 0; + + // Entry parameter i is described by + // buffer_infos[arg_index_table[i]]. + const int32* arg_index_table_ = nullptr; - // Cardinality and sizes of arg and temp buffers. - const intptr_t* arg_sizes = nullptr; - size_t num_args = 0; - const intptr_t* temp_sizes = nullptr; - size_t num_temps = 0; + // There are num_args entry parameters. + int64 num_args_ = 0; // The 0-based index of the result tuple, in the temp buffers. - size_t result_index = 0; + size_t result_index_ = 0; // [Optional] Arrays of arg and result names. These are arrays of C-style // strings, where the array is terminated by nullptr. - const char** arg_names = nullptr; - const char** result_names = nullptr; + const char** arg_names_ = nullptr; + const char** result_names_ = nullptr; // [Optional] Arg and result shapes. - const xla::ProgramShape* program_shape = nullptr; + const xla::ProgramShape* program_shape_ = nullptr; // [Optional] Profile printer data. Null if profiling is disabled. - const xla::HloProfilePrinterData* hlo_profile_printer_data = nullptr; + const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; // [Optional] The number of profile counters expected in the profile counter // buffer by the generated code and hlo_profile_printer. 0 if profiling is // disabled. This information is already present in // hlo_profile_printer_data but xla::HloProfilePrinterData is forward // declared so we don't have access to that information here. - int64 profile_counters_size = 0; + int64 profile_counters_size_ = 0; + + // Only XlaCompiledCpuFunction is allowed to read the above fields. + friend class XlaCompiledCpuFunction; }; // AllocMode controls the buffer allocation mode. @@ -113,11 +163,7 @@ class XlaCompiledCpuFunction { // Runs the computation, with inputs read from arg buffers, and outputs // written to result buffers. Returns true on success and false on failure. - bool Run() { - raw_function_(temps_[result_index_], &run_options_, - const_cast(args_), temps_, profile_counters_); - return true; - } + bool Run(); // Returns the error message from the previous failed Run call. // @@ -129,14 +175,25 @@ class XlaCompiledCpuFunction { // ------------------------------ // Arg methods for managing input buffers. Buffers are in row-major order. - // Returns the underlying array of argument buffers, where args()[I] is the - // buffer for the positional argument at index I. - void** args() { return args_; } - const void* const* args() const { return args_; } - // Returns the buffer for the positional argument at the given `index`. - void* arg_data(size_t index) { return args_[index]; } - const void* arg_data(size_t index) const { return args_[index]; } + void* arg_data(size_t index) { + return buffer_table_[arg_index_table_[index]]; + } + const void* arg_data(size_t index) const { + return buffer_table_[arg_index_table_[index]]; + } + + int num_args() const { return num_args_; } + + // Returns the size of entry parameter `idx`. + // + // There is a static version of this method on tfcompile generated subclasses + // of XlaCompiledCpuFunction, but try to prefer this when possible since it + // works both for XlaJitCompiledCpuFunction and AOT compiled subclasses. + int arg_size(int idx) const { + assert(idx < num_args()); + return buffer_infos_[arg_index_table_[idx]].size(); + } // Sets the buffer for the positional argument at the given `index` to `data`. // Must be called before Run to have an effect. May be called under any @@ -149,7 +206,9 @@ class XlaCompiledCpuFunction { // // Aliasing of argument and result buffers is not allowed, and results in // undefined behavior. - void set_arg_data(size_t index, void* data) { args_[index] = data; } + void set_arg_data(size_t index, void* data) { + buffer_table_[arg_index_table_[index]] = data; + } // ------------------------------ // Result methods for managing output buffers. Buffers are in row-major order. @@ -159,9 +218,9 @@ class XlaCompiledCpuFunction { // Returns the underlying array of result buffers, where results()[I] is the // buffer for the positional result at index I. - void** results() { return static_cast(temps_[result_index_]); } + void** results() { return static_cast(buffer_table_[result_index_]); } const void* const* results() const { - return static_cast(temps_[result_index_]); + return static_cast(buffer_table_[result_index_]); } // Profile counters for this XLA computation. @@ -219,14 +278,28 @@ class XlaCompiledCpuFunction { const RawFunction raw_function_; const size_t result_index_; - // Arrays of argument and temp buffers; entries in args_ may be overwritten by - // the user. - void** args_ = nullptr; - void** temps_ = nullptr; + // Array containing pointers to argument and temp buffers (slots corresponding + // to constant and on-stack buffers are null). + void** const buffer_table_; - // Backing memory for individual arg and temp buffers. - void* alloc_args_ = nullptr; - void* alloc_temps_ = nullptr; + // Describes the buffers used by the XLA computation. + const cpu_function_runtime::BufferInfo* const buffer_infos_; + + // Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]] + // for XLA generated code to be able to find it. + // + // For now we need to keep around the args_ array because there is code that + // depends on args() returning a void**. However, in the future we may remove + // args_ in favor of using buffer_table_ as the sole storage for the + // arguments. + const int32* const arg_index_table_; + + // The number of incoming arguments. + const int32 num_args_; + + // Backing memory for buffer_table_ and args_, the latter depending on + // AllocMode. + void* alloc_buffer_table_ = nullptr; // Backing memory for profiling counters. int64* profile_counters_ = nullptr; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 9c8e56a17e07348d3cfaaca0b5eb335295af05c3..226c89bcf1e66b5afb43cddb03db39b931ca55a8 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -28,11 +28,14 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" @@ -230,10 +233,13 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, case XlaCompiler::Argument::kConstant: LOG(FATAL) << "Unreachable case"; case XlaCompiler::Argument::kParameter: { - TensorShape shape = - is_entry_computation - ? options_.shape_representation_fn(arg.shape, arg.type) - : arg.shape; + TensorShape shape; + if (is_entry_computation) { + TF_ASSIGN_OR_RETURN( + shape, options_.shape_representation_fn(arg.shape, arg.type)); + } else { + shape = arg.shape; + } return TensorShapeToXLAShape(arg.type, shape, xla_shape); } case XlaCompiler::Argument::kResource: { @@ -241,8 +247,9 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, switch (arg.resource_kind) { case XlaResource::kVariable: { - TensorShape representation_shape = - options_.shape_representation_fn(arg.shape, arg.type); + TF_ASSIGN_OR_RETURN( + TensorShape representation_shape, + options_.shape_representation_fn(arg.shape, arg.type)); return TensorShapeToXLAShape(arg.type, representation_shape, xla_shape); } @@ -338,9 +345,9 @@ Status BuildComputation( const std::vector& arg_cores, const std::vector& retvals, const std::vector>& resources, - bool return_updated_values_for_all_resources, xla::XlaBuilder* builder, - xla::XlaComputation* computation, int* num_computation_outputs, - int* num_nonconst_outputs, + bool return_updated_values_for_all_resources, bool always_return_tuple, + xla::XlaBuilder* builder, xla::XlaComputation* computation, + int* num_computation_outputs, int* num_nonconst_outputs, std::vector* outputs, std::vector* resource_updates) { std::vector elems; @@ -384,13 +391,14 @@ Status BuildComputation( const XlaCompiler::Argument& arg = args[resource->arg_num()]; const int core = arg_cores[resource->arg_num()]; DCHECK_LT(resource->arg_num(), arg_cores.size()); - bool modified = resource->value() != resource->initial_value(); + bool modified = !resource->value().IsIdenticalTo(resource->initial_value()); // TensorArray gradients were modified if their values changed or there are // any newly created gradients. for (const auto& grad : resource->tensor_array_gradients()) { - modified = modified || - grad.second->value() != grad.second->initial_value() || - arg.tensor_array_gradients.count(grad.first) == 0; + modified = + modified || + !grad.second->value().IsIdenticalTo(grad.second->initial_value()) || + arg.tensor_array_gradients.count(grad.first) == 0; } if (return_updated_values_for_all_resources || modified) { resource_updates->emplace_back(); @@ -415,16 +423,20 @@ Status BuildComputation( // create a tuple/get-tuple-element combination so that sharding // assignment will be placed on this value, which will cause the resource // update to be returned from the same device that provided the resource. - handle = builder->GetTupleElement(builder->Tuple({handle}), 0); - + handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0); elems.push_back(handle); } } *num_computation_outputs = elems.size(); - // Builds the XLA computation. - builder->Tuple(elems); + // Builds the XLA computation. We *always* form a tuple here to ensure that + // the output value is the last thing added into the XLA computation, even + // if there is only one output value. + auto tuple = xla::Tuple(builder, elems); + if (!always_return_tuple && elems.size() == 1) { + xla::GetTupleElement(tuple, 0); + } builder->ClearOpMetadata(); xla::StatusOr computation_status = builder->Build(); @@ -551,16 +563,16 @@ Status XlaCompiler::BuildArguments( } xla::XlaScopedShardingAssignment assign_tuple_sharding(builder, tuple_sharding); - tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); + tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } else { - tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); + tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const int core = (*arg_cores)[input_mapping->at(i)]; xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); - arg_handles[i] = builder->GetTupleElement(tuple, i); + arg_handles[i] = xla::GetTupleElement(tuple, i); } } else { for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { @@ -568,8 +580,8 @@ Status XlaCompiler::BuildArguments( xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); - arg_handles[i] = - builder->Parameter(i, (*input_shapes)[i], strings::StrCat("arg", i)); + arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i], + strings::StrCat("arg", i)); } } @@ -600,7 +612,7 @@ Status XlaCompiler::BuildArguments( // return values of functions, and then reshape unconditionally. if (is_entry_computation) { arg_expression.set_handle( - builder->Reshape(arg_handles[i], arg.shape.dim_sizes())); + xla::Reshape(arg_handles[i], arg.shape.dim_sizes())); } else { arg_expression.set_handle(arg_handles[i]); } @@ -660,20 +672,17 @@ Status XlaCompiler::CompileSingleOp( namespace { // Check that the ops of all non-functional nodes have been registered. -string ValidateFunctionDef(const FunctionDef* fdef, +Status ValidateFunctionDef(const FunctionDef* fdef, const FunctionLibraryDefinition& flib_def) { - std::vector invalid_ops; for (const NodeDef& node : fdef->node_def()) { const string& op = node.op(); if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) { continue; } const OpDef* op_def; - if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) { - invalid_ops.push_back(op); - } + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def)); } - return tensorflow::str_util::Join(invalid_ops, ", "); + return Status::OK(); } // Check that the graph doesn't have any invalid nodes (e.g. incompatible with @@ -681,35 +690,33 @@ string ValidateFunctionDef(const FunctionDef* fdef, Status ValidateGraph(const Graph* graph, const FunctionLibraryDefinition& flib_def, const DeviceType& device_type, const string& name) { - std::vector invalid_ops; + auto maybe_error = [&](const Node* node, const Status& s) -> Status { + if (!s.ok()) { + return errors::InvalidArgument(strings::StrCat( + "Detected unsupported operations when trying to compile graph ", name, + " on ", device_type.type_string(), ": ", node->def().op(), " (", + s.error_message(), ")", FormatNodeForError(*node))); + } + return Status::OK(); + }; + for (const Node* node : graph->nodes()) { if (node->type_string() == FunctionLibraryDefinition::kGradientOp) { continue; } const FunctionDef* fdef = flib_def.Find(node->def().op()); + Status s; if (fdef) { - string error_msg = ValidateFunctionDef(fdef, flib_def); - if (!error_msg.empty()) { - invalid_ops.push_back( - strings::StrCat(node->def().op(), ":{", error_msg, "}")); - } + s = ValidateFunctionDef(fdef, flib_def); + TF_RETURN_IF_ERROR(maybe_error(node, s)); continue; } const OpDef* op_def; - if (!OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def).ok()) { - invalid_ops.push_back(node->def().op()); - continue; - } + s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def); + TF_RETURN_IF_ERROR(maybe_error(node, s)); TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def)); - if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) { - invalid_ops.push_back(node->def().op()); - } - } - if (!invalid_ops.empty()) { - return errors::InvalidArgument(strings::StrCat( - "Detected unsupported operations when trying to compile graph ", name, - " on ", device_type.type_string(), ":", - tensorflow::str_util::Join(invalid_ops, ", "))); + s = FindKernelDef(device_type, node->def(), nullptr, nullptr); + TF_RETURN_IF_ERROR(maybe_error(node, s)); } return Status::OK(); } @@ -767,9 +774,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, result->outputs.resize(context->retvals().size()); TF_RETURN_IF_ERROR(BuildComputation( args, arg_cores, context->retvals(), context->resources(), - options.return_updated_values_for_all_resources, &builder, - result->computation.get(), &num_computation_outputs, - &num_nonconst_outputs, &result->outputs, &result->resource_updates)); + options.return_updated_values_for_all_resources, + options.always_return_tuple, &builder, result->computation.get(), + &num_computation_outputs, &num_nonconst_outputs, &result->outputs, + &result->resource_updates)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index c93850ce270502ea1df1f6469963e96e86994fa2..25332c8d8e3210a0217a1ba3f5767115fe6b1d93 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" @@ -52,13 +54,7 @@ class XlaContext; // (kind kResource). // // Only kParameter and initialized kResource arguments become runtime parameters -// to the generated XLA computation. The XLA computation will have run-time -// parameters in the following order: -// +---------------------+-----------------------------------------+ -// | kParameter values | Initial values of kResource arguments | -// +---------------------+-----------------------------------------+ -// Within each block, the arguments are arranged by the _Arg index from which -// they were derived. +// to the generated XLA computation. // // The run-time outputs of the XLA computation are arranged in the following // order: @@ -77,10 +73,10 @@ class XlaContext; // tensors with a different shape to their representation inside the XLA // computation. // -// In both inputs and outputs, kResource values are placed the end. When +// In computation outputs, updated kResource values are placed the end. When // emitting While loop bodies, we must ensure that the loop body has -// identical input and output signatures. By moving variable values -// to the end of the argument list and using the +// identical input and output signatures. By passing variable values +// at the end of the argument list and using the // `return_updated_values_for_all_variables` option, we can ensure that the // input and output values of resources appear at the same positions. // @@ -175,6 +171,11 @@ class XlaCompiler { // computation. bool resolve_compile_time_constants = true; + // If 'always_return_tuple' is true, then the output of a computation will + // always be a tuple. Otherwise, a single-element output will not be wrapped + // in a tuple. + bool always_return_tuple = true; + // True when compiling the entry computation, false for subcomputations // (while, call, etc.) bool is_entry_computation = true; @@ -234,7 +235,8 @@ class XlaCompiler { tf2xla::HostComputeMetadata host_compute_metadata; // Resources whose values were updated by the computation, ordered - // by return value position. Resource updates follow the non-constant + // by return value position (which is the same as the order the resources + // were passed as arguments). Resource updates follow the non-constant // results in the outputs of XLA computation. std::vector resource_updates; @@ -242,13 +244,20 @@ class XlaCompiler { std::shared_ptr computation; }; - typedef std::function + typedef std::function(const TensorShape&, + DataType)> ShapeRepresentationFn; struct Options { // Name of the compilation device to use. It must be set by the caller. // The default empty value is invalid. DeviceType device_type = DeviceType(""); + // The device to use during compilation to execute instructions on, for + // example for auto-tuning. + // Valid values are defined by `xla::Backend::devices_ordinal_supported()`. + // -1 indicates the default device should be used. + int device_ordinal = -1; + xla::Client* client = nullptr; // Function library in which to find function definitions. Must be non-null. diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 613230452b74755ce7543ec2ab82861aa0dfeb7a..be00ed8813fdf2778d6af81556001ef51538dd34 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -206,9 +206,9 @@ TEST_F(XlaCompilerTest, Simple) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::Literal::CreateR1({7, 42}); + xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param1_literal = - xla::Literal::CreateR1({-3, 101}); + xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -222,12 +222,64 @@ TEST_F(XlaCompilerTest, Simple) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected0 = - xla::Literal::CreateR1({4, 143}); + xla::LiteralUtil::CreateR1({4, 143}); std::unique_ptr expected_literal = - xla::Literal::MakeTuple({expected0.get()}); + xla::LiteralUtil::MakeTuple({expected0.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } +// Tests compilation of a graph where the _Retval node is not necessarily last +// amongst the graph nodes in construction order, and always_return_tuple is +// false. Regression test for bug where the wrong value was returned. +TEST_F(XlaCompilerTest, OutOfOrderGraph) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1); + // The _Retval node is not last in construction order. + auto d = ops::_Retval(scope.WithOpName("D"), a, 0); + auto c = ops::Add(scope.WithOpName("C"), a, b); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[1].kind = XlaCompiler::Argument::kParameter; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompileOptions compile_options; + compile_options.always_return_tuple = false; + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + // Tests that the generated computation works. + std::unique_ptr param0_literal = + xla::LiteralUtil::CreateR1({7, 42}); + std::unique_ptr param1_literal = + xla::LiteralUtil::CreateR1({-3, 101}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr actual = + client_ + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal)); +} + TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { // Builds a graph that adds reshapes a tensor, but with the shape not // statically known. @@ -260,7 +312,7 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { str_util::StrContains(status.error_message(), "depends on a parameter")) << status.error_message(); EXPECT_TRUE( - str_util::StrContains(status.error_message(), "[[Node: C = Reshape")) + str_util::StrContains(status.error_message(), "[[{{node C}} = Reshape")) << status.error_message(); } @@ -306,7 +358,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::Literal::CreateR1({7, 42}); + xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -317,9 +369,9 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected0 = - xla::Literal::CreateR1({-7, -42}); + xla::LiteralUtil::CreateR1({-7, -42}); std::unique_ptr expected_literal = - xla::Literal::MakeTuple({expected0.get()}); + xla::LiteralUtil::MakeTuple({expected0.get()}); EXPECT_TRUE( xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -341,7 +393,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::Literal::CreateR1({7, 42}); + xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -351,11 +403,12 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { std::unique_ptr actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = xla::Literal::CreateR0(7); + std::unique_ptr expected0 = + xla::LiteralUtil::CreateR0(7); std::unique_ptr expected1 = - xla::Literal::CreateR1({-7, -42}); + xla::LiteralUtil::CreateR1({-7, -42}); std::unique_ptr expected = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal)); } } @@ -569,11 +622,11 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { // Tests that the generated computation works. std::unique_ptr input_base = - xla::Literal::CreateR1({7, 42}); + xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr input_grad2 = - xla::Literal::CreateR1({-3, 101}); + xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr input = - xla::Literal::MakeTuple({input_base.get(), input_grad2.get()}); + xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()}); std::unique_ptr param0_data = client_->TransferToServer(*input).ConsumeValueOrDie(); @@ -583,17 +636,18 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { std::unique_ptr actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr output_read = xla::Literal::CreateR0(42); + std::unique_ptr output_read = + xla::LiteralUtil::CreateR0(42); std::unique_ptr output_base = - xla::Literal::CreateR1({7, 42}); + xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr output_grad1 = - xla::Literal::CreateR1({0, 1}); + xla::LiteralUtil::CreateR1({0, 1}); std::unique_ptr output_grad2 = - xla::Literal::CreateR1({-3, 101}); - std::unique_ptr output_resource = xla::Literal::MakeTuple( + xla::LiteralUtil::CreateR1({-3, 101}); + std::unique_ptr output_resource = xla::LiteralUtil::MakeTuple( {output_base.get(), output_grad1.get(), output_grad2.get()}); std::unique_ptr expected_literal = - xla::Literal::MakeTuple({output_read.get(), output_resource.get()}); + xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -796,9 +850,9 @@ TEST_F(XlaCompilerTest, Variables) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::Literal::CreateR1({7, 42}); + xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param1_literal = - xla::Literal::CreateR1({-3, 101}); + xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -812,11 +866,11 @@ TEST_F(XlaCompilerTest, Variables) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected0 = - xla::Literal::CreateR1({5, 144}); + xla::LiteralUtil::CreateR1({5, 144}); std::unique_ptr expected1 = - xla::Literal::CreateR1({4, 143}); + xla::LiteralUtil::CreateR1({4, 143}); std::unique_ptr expected_literal = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -884,9 +938,9 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::Literal::CreateR2({{4, 55}, {1, -3}}); + xla::LiteralUtil::CreateR2({{4, 55}, {1, -3}}); std::unique_ptr param1_literal = - xla::Literal::CreateR1({22, 11, 33, 404}); + xla::LiteralUtil::CreateR1({22, 11, 33, 404}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -900,11 +954,11 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected0 = - xla::Literal::CreateR2({{27, 67}, {35, 402}}); + xla::LiteralUtil::CreateR2({{27, 67}, {35, 402}}); std::unique_ptr expected1 = - xla::Literal::CreateR1({26, 66, 34, 401}); + xla::LiteralUtil::CreateR1({26, 66, 34, 401}); std::unique_ptr expected_literal = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -953,9 +1007,9 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::Literal::CreateR1({4, 55, 1, -3}); + xla::LiteralUtil::CreateR1({4, 55, 1, -3}); std::unique_ptr param1_literal = - xla::Literal::CreateR1({22, 11, 33, 404}); + xla::LiteralUtil::CreateR1({22, 11, 33, 404}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -969,11 +1023,11 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected0 = - xla::Literal::CreateR1({27, 67, 35, 402}); + xla::LiteralUtil::CreateR1({27, 67, 35, 402}); std::unique_ptr expected1 = - xla::Literal::CreateR1({26, 66, 34, 401}); + xla::LiteralUtil::CreateR1({26, 66, 34, 401}); std::unique_ptr expected_literal = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -1021,8 +1075,9 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", std::move(graph), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE( - str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}")) + EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp")) + << status.error_message(); + EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node fill_fn}}")) << status.error_message(); } @@ -1048,6 +1103,8 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { EXPECT_TRUE(str_util::StrContains(status.error_message(), "is not in the list of allowed values")) << status.error_message(); + EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node Shape}}")) + << status.error_message(); } TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { @@ -1069,9 +1126,10 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", std::move(graph_copy), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), - "The following nodes are unreachable " - "from the source in the graph: NoOp")) + EXPECT_TRUE( + str_util::StrContains(status.error_message(), + "The following nodes are unreachable " + "from the source in the graph: {{node NoOp}}")) << status.error_message(); } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 098072d33cd4eb7f7dec0ec4196b43eca0220d4a..b24e3aabbe6ba858a8bfb4dd435726984cc7b0f5 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -25,9 +25,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -66,8 +67,8 @@ XlaContext::XlaContext( XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, bool is_entry_computation, - const std::function* - shape_representation_fn) + const std::function( + const TensorShape&, DataType)>* shape_representation_fn) : compiler_(compiler), builder_(builder), allow_cpu_custom_calls_(allow_cpu_custom_calls), @@ -92,7 +93,7 @@ void XlaContext::AddRetval(int retval_index, DataType type, } Status XlaContext::AddConstRetval(int retval_index, DataType dtype, - const xla::Literal& literal) { + const xla::LiteralSlice& literal) { VLOG(1) << "Adding retval index " << retval_index << " with non-data-dependent tensor to XLA computation"; if (retvals_.size() <= retval_index) { @@ -119,8 +120,8 @@ Status XlaContext::CreateResource( return Status::OK(); } -TensorShape XlaContext::RepresentationShape(const TensorShape& shape, - DataType type) const { +xla::StatusOr XlaContext::RepresentationShape( + const TensorShape& shape, DataType type) const { return (*shape_representation_fn_)(shape, type); } @@ -131,9 +132,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { xla::XlaBuilder b("max<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); - auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); - b.Max(x, y); + auto x = + xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = + xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + xla::Max(x, y); return b.Build().ConsumeValueOrDie(); }); } @@ -145,9 +148,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) { xla::XlaBuilder b("min<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); - auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); - b.Min(x, y); + auto x = + xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = + xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + xla::Min(x, y); return b.Build().ConsumeValueOrDie(); }); } @@ -159,9 +164,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) { xla::XlaBuilder b("add<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); - auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); - b.Add(x, y); + auto x = + xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = + xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + xla::Add(x, y); return b.Build().ConsumeValueOrDie(); }); } @@ -173,9 +180,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) { xla::XlaBuilder b("mul<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); - auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); - b.Mul(x, y); + auto x = + xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = + xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + xla::Mul(x, y); return b.Build().ConsumeValueOrDie(); }); } diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 341bf6ff1f37fa7cd81f41c02a941214067b1bd1..3db37afdba71342cfb20af8841a40cb54709ca73 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -22,8 +22,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -47,8 +48,8 @@ class XlaContext : public ResourceBase { XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, bool is_entry_computation, - const std::function* - shape_representation_fn); + const std::function( + const TensorShape&, DataType)>* shape_representation_fn); // Virtual method defined by ResourceBase. string DebugString() override; @@ -83,7 +84,7 @@ class XlaContext : public ResourceBase { // As for Retval, but for return values that are compile-time constants. Status AddConstRetval(int retval_index, DataType dtype, - const xla::Literal& literal); + const xla::LiteralSlice& literal); // Creates a resource with resource `kind` and initial value `handle`. `name` // is a descriptive name for use in error messages. See the `XlaResource` @@ -101,8 +102,8 @@ class XlaContext : public ResourceBase { // Returns the XLA shape to be used to represent a variable of TF `shape` // and `type`, or of an argument or return value of a top-level computation. - TensorShape RepresentationShape(const TensorShape& shape, - DataType type) const; + xla::StatusOr RepresentationShape(const TensorShape& shape, + DataType type) const; // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a @@ -160,7 +161,7 @@ class XlaContext : public ResourceBase { // should be represented in XLA. Parameters/return values will be shaped // according to this function, and reshaped back to/from their declared shapes // for computations. Must be non-null. - const std::function* + const std::function(const TensorShape&, DataType)>* shape_representation_fn_; // Cache of prebuilt computations indexed by their type. diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc index ead229aaccc292d4944db0c1eaf98c82583533cd..23d04d43b358e858ad1ab2463322ce0ab93b23c2 100644 --- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc @@ -31,6 +31,10 @@ bool CpuOpFilter(KernelDef* kdef) { DT_FLOAT); return true; } + // TODO(b/26783907): The CPU backend currently does not implement sort. + if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") { + return false; + } if (kdef->op() == "Const") { AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef); } diff --git a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc index 62168b648331844bfe2db1a4d5dcad895c8726f3..1398e9ee536a9675e5b703ec3fabf4a8b9d89cbf 100644 --- a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc @@ -20,12 +20,6 @@ limitations under the License. namespace tensorflow { bool GpuOpFilter(KernelDef* kdef) { - // TODO(b/31361304): The GPU backend does not parallelize PRNG ops, leading to - // slow code. - if (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" || - kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal") { - return false; - } if (kdef->op() == "Const") { AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef); } diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index a1da176fe30ddd0d4460a51b60b2568ecc1af6aa..8efb3d55c88757b9366bdf9622287bdd0a72e295 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -23,9 +23,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -34,103 +37,71 @@ namespace tensorflow { namespace { -Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, - const xla::XlaOp& input, const TensorShape& input_shape, - DataType input_type, DataType output_type, int axis, - bool is_min, xla::XlaOp* argminmax) { - xla::XlaOp init_value; - const xla::XlaComputation* reducer; - if (is_min) { - init_value = XlaHelpers::MaxValue(builder, input_type); - reducer = ctx->GetOrCreateMin(input_type); - } else { - init_value = XlaHelpers::MinValue(builder, input_type); - reducer = ctx->GetOrCreateMax(input_type); - } - - xla::PrimitiveType xla_output_type; - TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(output_type, &xla_output_type)); - - xla::XlaOp input_max = builder->Reduce(input, init_value, *reducer, - /*dimensions_to_reduce=*/{axis}); - std::vector broadcast_dims(input_shape.dims() - 1); - std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); - std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - // Compute a mask that has 1s for elements equal to the maximum. - xla::XlaOp partial_mask = builder->ConvertElementType( - builder->Eq(input, input_max, broadcast_dims), xla_output_type); - - // In order to make identity elements for a bitwise And, we: - // Left shift the 1 to the leftmost bit, yielding 0x10...0 - // Arithmetic right shift the 1 back to the rightmost bit, yielding - // 0xFF...F - int32 bits_in_type = - xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_output_type) * 8 - 1; - xla::XlaOp shift_amount = - XlaHelpers::IntegerLiteral(builder, output_type, bits_in_type); - xla::XlaOp full_mask = builder->ShiftRightArithmetic( - builder->ShiftLeft(partial_mask, shift_amount), shift_amount); - - // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its - // index. - xla::XlaOp iota; - - const int64 axis_size = input_shape.dim_size(axis); - TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota)); - xla::XlaOp product = - builder->And(full_mask, iota, /*broadcast_dimensions=*/{axis}); - - // If there are multiple maximum elements, choose the one with the highest - // index. - xla::XlaOp output = - builder->Reduce(product, XlaHelpers::MinValue(builder, output_type), - *ctx->GetOrCreateMax(output_type), - /*dimensions_to_reduce=*/{axis}); - *argminmax = output; - return Status::OK(); +xla::XlaOp ArgMinMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis, + bool is_min) { + xla::XlaBuilder* builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input)); + xla::XlaOp init_value; + xla::XlaComputation reducer; + if (is_min) { + init_value = xla::MaxValue(builder, input_shape.element_type()); + reducer = + xla::CreateScalarMinComputation(input_shape.element_type(), builder); + } else { + init_value = xla::MinValue(builder, input_shape.element_type()); + reducer = + xla::CreateScalarMaxComputation(input_shape.element_type(), builder); + } + + xla::XlaOp input_max = xla::Reduce(input, init_value, reducer, + /*dimensions_to_reduce=*/{axis}); + std::vector broadcast_dims(xla::ShapeUtil::Rank(input_shape) - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); + std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); + // Compute a mask that has 1s for elements equal to the maximum. + xla::XlaOp partial_mask = xla::ConvertElementType( + xla::Eq(input, input_max, broadcast_dims), output_type); + + // In order to make identity elements for a bitwise And, we: + // Left shift the 1 to the leftmost bit, yielding 0x10...0 + // Arithmetic right shift the 1 back to the rightmost bit, yielding + // 0xFF...F + int32 bits_in_type = + xla::ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1; + xla::XlaOp shift_amount = + xla::ConstantR0WithType(builder, output_type, bits_in_type); + xla::XlaOp full_mask = xla::ShiftRightArithmetic( + xla::ShiftLeft(partial_mask, shift_amount), shift_amount); + + // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its + // index. + + const int64 axis_size = xla::ShapeUtil::GetDimension(input_shape, axis); + xla::XlaOp iota = xla::Iota(builder, output_type, axis_size); + xla::XlaOp product = + xla::And(full_mask, iota, /*broadcast_dimensions=*/{axis}); + + // If there are multiple maximum elements, choose the one with the highest + // index. + return xla::Reduce(product, xla::MinValue(builder, output_type), + xla::CreateScalarMaxComputation(output_type, builder), + /*dimensions_to_reduce=*/{axis}); + }); } } // namespace -xla::XlaOp XlaHelpers::MinValue(xla::XlaBuilder* b, DataType data_type) { - xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::Literal::MinValue(type)); -} - -xla::XlaOp XlaHelpers::MaxValue(xla::XlaBuilder* b, DataType data_type) { - xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::Literal::MaxValue(type)); -} - xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::Literal::Zero(type)); + return xla::ConstantLiteral(b, xla::LiteralUtil::Zero(type)); } xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::Literal::One(type)); -} - -xla::XlaOp XlaHelpers::Epsilon(xla::XlaBuilder* b, DataType data_type) { - switch (data_type) { - case DT_HALF: - return b->ConstantR0( - static_cast(Eigen::NumTraits::epsilon())); - case DT_BFLOAT16: - return b->ConstantR0(bfloat16::epsilon()); - case DT_FLOAT: - return b->ConstantR0(std::numeric_limits::epsilon()); - case DT_DOUBLE: - return b->ConstantR0(std::numeric_limits::epsilon()); - default: - LOG(FATAL) << "Unsupported type in XlaHelpers::Epsilon: " - << DataTypeString(data_type); - } + return xla::ConstantLiteral(b, xla::LiteralUtil::One(type)); } xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type, @@ -178,45 +149,14 @@ static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) { return linspace; } -Status XlaHelpers::ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, - const xla::XlaOp& input, - const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, xla::XlaOp* argmax) { - return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type, - axis, /*is_min=*/false, argmax); +xla::XlaOp XlaHelpers::ArgMax(xla::XlaOp input, xla::PrimitiveType output_type, + int axis) { + return ArgMinMax(input, output_type, axis, /*is_min=*/false); } -Status XlaHelpers::ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, - const xla::XlaOp& input, - const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, xla::XlaOp* argmin) { - return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type, - axis, /*is_min=*/true, argmin); -} - -Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, - xla::XlaOp* iota) { - TensorShape linspace_shape({size}); - Tensor linspace; - switch (dtype) { - case DT_UINT8: - linspace = MakeLinspaceTensor(linspace_shape, size); - break; - case DT_INT32: - linspace = MakeLinspaceTensor(linspace_shape, size); - break; - case DT_INT64: - linspace = MakeLinspaceTensor(linspace_shape, size); - break; - default: - return errors::InvalidArgument("Invalid argument type ", - DataTypeString(dtype)); - } - xla::BorrowingLiteral linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); - - *iota = builder->ConstantLiteral(linspace_literal); - return Status::OK(); +xla::XlaOp XlaHelpers::ArgMin(xla::XlaOp input, xla::PrimitiveType output_type, + int axis) { + return ArgMinMax(input, output_type, axis, /*is_min=*/true); } Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, @@ -248,6 +188,7 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, return errors::InvalidArgument("Invalid argument type ", DataTypeString(index_type)); } + xla::BorrowingLiteral linspace_literal; TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); @@ -256,17 +197,19 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, std::vector broadcast_dims(indices_shape.dims()); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - xla::XlaOp one_hot_bool = builder->Eq( - indices, builder->ConstantLiteral(linspace_literal), broadcast_dims); + xla::XlaOp one_hot_bool = xla::Eq( + indices, xla::ConstantLiteral(builder, linspace_literal), broadcast_dims); // Selects the user-provided off_value and on_value values. - *one_hot = builder->Select( - one_hot_bool, builder->Broadcast(on_value, output_shape.dim_sizes()), - builder->Broadcast(off_value, output_shape.dim_sizes())); + *one_hot = xla::Select(one_hot_bool, + xla::Broadcast(on_value, output_shape.dim_sizes()), + xla::Broadcast(off_value, output_shape.dim_sizes())); return Status::OK(); } DataType XlaHelpers::SumAccumulationType(const DataType& dtype) { + // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from + // repeated floating point additions. if (dtype == DT_BFLOAT16 || dtype == DT_HALF) { return DT_FLOAT; } @@ -278,7 +221,7 @@ xla::XlaOp XlaHelpers::ConvertElementType(xla::XlaBuilder* const builder, const DataType new_element_type) { xla::PrimitiveType convert_to; TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to)); - return builder->ConvertElementType(operand, convert_to); + return xla::ConvertElementType(operand, convert_to); } } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index c3fdc5252e74363fe289eeabb2cb0d68298ee291..e6522157a535fc3e4ec96cb0496b6be2e525c336 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -19,7 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ #include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -28,14 +28,6 @@ namespace tensorflow { // Helper methods for building XLA computations. class XlaHelpers { public: - // Returns a handle representing the minimum value of a scalar - // element of data_type. - static xla::XlaOp MinValue(xla::XlaBuilder* b, DataType data_type); - - // Returns a handle representing the maximum value of a scalar - // element of data_type. - static xla::XlaOp MaxValue(xla::XlaBuilder* b, DataType data_type); - // Returns a handle representing the zero value of a scalar // element of data_type. static xla::XlaOp Zero(xla::XlaBuilder* b, DataType data_type); @@ -44,10 +36,6 @@ class XlaHelpers { // element of data_type. static xla::XlaOp One(xla::XlaBuilder* b, DataType data_type); - // Returns the machine epsilon for floating-point type `data_type`, i.e., - // the difference between 1.0 and the next representable value. - static xla::XlaOp Epsilon(xla::XlaBuilder* b, DataType data_type); - // Returns a handle representing the given value of an integer scalar // element of data_type. // Note that unlike One and Zero, does not work on boolean types. @@ -65,25 +53,15 @@ class XlaHelpers { gtl::ArraySlice shape, xla::Literal* output); - // Sets `argmax` to the argmax of `input` along `axis`. `input_shape` and - // `input_dtype` are the shape and dtype of `input` respectively, and - // `output_type` is the dtype to use for `argmax`. - static Status ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, - const xla::XlaOp& input, const TensorShape& input_shape, - DataType input_type, DataType output_type, int axis, - xla::XlaOp* argmax); - - // Sets `argmin` to the argmin of `input` along `axis`. `input_shape` and - // `input_dtype` are the shape and dtype of `input` respectively, and - // `output_type` is the dtype to use for `argmin`. - static Status ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, - const xla::XlaOp& input, const TensorShape& input_shape, - DataType input_type, DataType output_type, int axis, - xla::XlaOp* argmin); - - // Sets *iota to a rank 1 tensor with values [0, 1, 2, ...] of `dtype`. - static Status Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, - xla::XlaOp* iota); + // Returns the argmax of `input` along `axis`. `output_type` is the type to + // use for the output. + static xla::XlaOp ArgMax(xla::XlaOp input, xla::PrimitiveType output_type, + int axis); + + // Returns the argmin of `input` along `axis`. `output_type` is the type to + // use for the output. + static xla::XlaOp ArgMin(xla::XlaOp input, xla::PrimitiveType output_type, + int axis); // Converts `indices` into a one-hot representation. `depth` is the size // of the new axis to add. `axis` is the position at which to add the new diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 9e17756b27733e2453ea1688d13e1d718c25cfc8..86a78ee429e8913edb4a948727fa692083c472f4 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -23,6 +23,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -34,41 +36,6 @@ limitations under the License. namespace tensorflow { namespace { - -// Returns a vector of positional argument buffer sizes. -xla::StatusOr> ComputeArgSizes( - const xla::ProgramShape& program_shape) { - std::vector arg_sizes; - const size_t num_args = program_shape.parameters_size(); - arg_sizes.reserve(num_args); - for (int i = 0; i < num_args; ++i) { - const xla::Shape& arg_shape = program_shape.parameters(i); - constexpr size_t kPointerSize = sizeof(void*); - arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize)); - } - return std::move(arg_sizes); -} - -// Returns a vector of positional temporary buffer sizes. -xla::StatusOr> ComputeTempSizes( - const xla::BufferAssignment& buffer_assignment) { - const std::vector& allocations = - buffer_assignment.Allocations(); - std::vector temp_sizes; - temp_sizes.reserve(allocations.size()); - for (const xla::BufferAllocation& allocation : allocations) { - // Callers don't allocate temporary buffers for parameters. Nor for - // thread-local buffers, which are lowered to alloca. - if (allocation.is_entry_computation_parameter() || - allocation.is_thread_local()) { - temp_sizes.push_back(-1); - } else { - temp_sizes.push_back(allocation.size()); - } - } - return std::move(temp_sizes); -} - // Returns the index of the result in the temp buffers. xla::StatusOr ComputeResultIndex( const xla::BufferAssignment& buffer_assignment) { @@ -152,11 +119,11 @@ XlaJitCompiledCpuFunction::Compile( const xla::BufferAssignment& buffer_assignment = cpu_executable->buffer_assignment(); - // Compute buffer sizes and the result index, needed to run the raw function. - TF_ASSIGN_OR_RETURN(std::vector arg_sizes, - ComputeArgSizes(*program_shape)); - TF_ASSIGN_OR_RETURN(std::vector temp_sizes, - ComputeTempSizes(buffer_assignment)); + // Compute buffer infos and the result index, needed to run the raw function. + std::vector buffer_infos = + xla::cpu::CreateBufferInfosFromBufferAssignment(buffer_assignment); + std::vector arg_index_table = + xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos); TF_ASSIGN_OR_RETURN(size_t result_index, ComputeResultIndex(buffer_assignment)); @@ -164,28 +131,28 @@ XlaJitCompiledCpuFunction::Compile( new XlaJitCompiledCpuFunction); XlaJitCompiledCpuFunction* jit = jit_unique_ptr.get(); jit->executable_ = std::move(executable); - jit->arg_sizes_ = std::move(arg_sizes); - jit->temp_sizes_ = std::move(temp_sizes); + jit->buffer_infos_ = std::move(buffer_infos); + jit->arg_index_table_ = std::move(arg_index_table); jit->program_shape_ = std::move(program_shape); - jit->static_data_.raw_function = std::move(raw_function); - jit->static_data_.arg_sizes = jit->arg_sizes_.data(); - jit->static_data_.num_args = jit->arg_sizes_.size(); - jit->static_data_.temp_sizes = jit->temp_sizes_.data(); - jit->static_data_.num_temps = jit->temp_sizes_.size(); - jit->static_data_.result_index = result_index; + jit->static_data_.set_raw_function(raw_function); + jit->static_data_.set_buffer_infos(jit->buffer_infos_.data()); + jit->static_data_.set_num_buffers(jit->buffer_infos_.size()); + jit->static_data_.set_arg_index_table(jit->arg_index_table_.data()); + jit->static_data_.set_num_args(jit->arg_index_table_.size()); + jit->static_data_.set_result_index(result_index); // Optional metadata is collected and set below. CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_); CollectNames(config.fetch(), &jit->nonempty_result_names_, &jit->result_names_); - jit->static_data_.arg_names = jit->arg_names_.data(); - jit->static_data_.result_names = jit->result_names_.data(); - jit->static_data_.program_shape = jit->program_shape_.get(); + jit->static_data_.set_arg_names(jit->arg_names_.data()); + jit->static_data_.set_result_names(jit->result_names_.data()); + jit->static_data_.set_program_shape(jit->program_shape_.get()); if (cpu_executable->hlo_profiling_enabled()) { - jit->static_data_.hlo_profile_printer_data = - &cpu_executable->hlo_profile_printer_data(); - jit->static_data_.profile_counters_size = - cpu_executable->hlo_profile_printer_data().profile_counters_size(); + jit->static_data_.set_hlo_profile_printer_data( + &cpu_executable->hlo_profile_printer_data()); + jit->static_data_.set_profile_counters_size( + cpu_executable->hlo_profile_printer_data().profile_counters_size()); } return std::move(jit_unique_ptr); diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h index af307ae4eff74927242c4650d8a43710e991cc52..d3c8f22a8078d03d15447ed200c914390f40b04f 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h @@ -66,9 +66,11 @@ class XlaJitCompiledCpuFunction { // The static data is backed by the rest of the state in this class. XlaCompiledCpuFunction::StaticData static_data_; - // The backing arrays of arg and temp buffer sizes. - std::vector arg_sizes_; - std::vector temp_sizes_; + // The backing array for buffer infos. + std::vector buffer_infos_; + + // The backing array for the arg index table. + std::vector arg_index_table_; // The backing arrays of arg and result names. We hold the actual strings in // nonempty_*_names_, and hold arrays of pointers in *_names_ for the static diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 76c68d81af4dd9ec40fe6b1c33b03a876a0c6dc6..82028c8b9ca9f65a73f8b50edc0a47c7068aba9a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -19,7 +19,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/dma_helper.h" namespace tensorflow { @@ -38,8 +43,7 @@ xla::XlaBuilder* XlaOpKernelContext::builder() const { static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { const XlaExpression* expression = reinterpret_cast(tensor.tensor_data().data()); - CHECK(expression->handle().builder() != nullptr || - expression->resource() != nullptr); + CHECK(expression->handle().valid() || expression->resource() != nullptr); VLOG(1) << "Fetched T" << expression->handle(); return expression; } @@ -48,7 +52,7 @@ static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) { const XlaExpression* expression = reinterpret_cast(tensor->tensor_data().data()); - CHECK_EQ(expression->handle().builder(), nullptr); + CHECK(!expression->handle().valid()); return const_cast(expression); } @@ -63,10 +67,32 @@ const xla::XlaOp& XlaOpKernelContext::Input(int index) { return GetComputationFromTensor(context_->input(index)); } +const xla::XlaOp& XlaOpKernelContext::Input(StringPiece name) { + return GetComputationFromTensor(GetInputTensorByName(name)); +} + TensorShape XlaOpKernelContext::InputShape(int index) { return context_->input(index).shape(); } +TensorShape XlaOpKernelContext::InputShape(StringPiece name) { + return GetInputTensorByName(name).shape(); +} + +DataType XlaOpKernelContext::input_type(int index) const { + return context_->input(index).dtype(); +} + +xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) { + xla::PrimitiveType type; + Status status = DataTypeToPrimitiveType(input_type(index), &type); + if (!status.ok()) { + SetStatus(status); + return xla::PRIMITIVE_TYPE_INVALID; + } + return type; +} + Status XlaOpKernelContext::ConstantInput(int index, xla::Literal* constant_literal) { return ConstantInputReshaped( @@ -87,6 +113,25 @@ Status XlaOpKernelContext::ConstantInputReshaped( } const XlaExpression* expression = CastExpressionFromTensor(tensor); + auto copy_tensor_to_literal = [](const Tensor& tensor, + xla::Literal* literal) { + xla::Shape literal_shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), &literal_shape)); + + *literal = xla::Literal(literal_shape); + + // memcpy over the payload ... + // TODO(phawkins): handle string types. + size_t total_bytes = tensor.TotalBytes(); + if (total_bytes > 0) { + void* dst_ptr = literal->untyped_data(); + const void* src_ptr = DMAHelper::base(&tensor); + memcpy(dst_ptr, src_ptr, total_bytes); + } + return Status::OK(); + }; + // If the tensor has a known constant value, there is no need to invoke XLA. if (expression->has_constant_value()) { Tensor temp(tensor.dtype()); @@ -95,19 +140,21 @@ Status XlaOpKernelContext::ConstantInputReshaped( // with the enclosing Tensor. return errors::Internal("Incompatible shapes in ConstantInputReshaped."); } - return HostTensorToLiteral(temp, constant_literal); + + return copy_tensor_to_literal(temp, constant_literal); } // Make sure we treat zero-element tensors as constant. if (new_shape.num_elements() == 0) { Tensor temp(tensor.dtype(), new_shape); - return HostTensorToLiteral(temp, constant_literal); + + return copy_tensor_to_literal(temp, constant_literal); } xla::XlaOp handle = expression->handle(); if (new_shape != tensor.shape()) { // Reshape the handle to the desired shape. - handle = builder()->Reshape(handle, new_shape.dim_sizes()); + handle = xla::Reshape(handle, new_shape.dim_sizes()); } // The XLA layout is specified minor to major, and TensorFlow's minor @@ -162,7 +209,8 @@ Status XlaOpKernelContext::ConstantInputReshaped( } // Converts an int32 or int64 scalar literal to an int64. -static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) { +static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, + int64* out) { if (xla::ShapeUtil::Rank(literal.shape()) != 0) { return errors::InvalidArgument("value is not a scalar"); } @@ -177,7 +225,8 @@ static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) { } // Converts an float32 or float64 scalar literal to a float64. -static Status LiteralToFloat64Scalar(const xla::Literal& literal, double* out) { +static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal, + double* out) { if (xla::ShapeUtil::Rank(literal.shape()) != 0) { return errors::InvalidArgument("value is not a scalar"); } @@ -204,7 +253,7 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { } // Converts an int32 or int64 1D literal to an int64 vector. -static Status LiteralToInt64Vector(const xla::Literal& literal, +static Status LiteralToInt64Vector(const xla::LiteralSlice& literal, std::vector* out) { if (xla::ShapeUtil::Rank(literal.shape()) != 1) { return errors::InvalidArgument("value is not 1D"); @@ -292,10 +341,11 @@ Status XlaOpKernelContext::ConstantInputList( return Status::OK(); } -Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, - TensorShape* shape, - xla::XlaOp* value) { - const Tensor& tensor = context_->input(index); +namespace { + +Status ReadVariableInputTensor(const Tensor& tensor, DataType type, + const OpKernelContext* ctx, TensorShape* shape, + xla::XlaOp* value) { const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); @@ -313,18 +363,34 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, *shape = variable->shape(); } - XlaContext& xla_context = XlaContext::Get(context_); - TensorShape representation_shape = - xla_context.RepresentationShape(variable->shape(), variable->type()); + XlaContext& xla_context = XlaContext::Get(ctx); + TF_ASSIGN_OR_RETURN( + TensorShape representation_shape, + xla_context.RepresentationShape(variable->shape(), variable->type())); if (representation_shape == variable->shape()) { *value = variable->value(); } else { - *value = - builder()->Reshape(variable->value(), variable->shape().dim_sizes()); + *value = xla::Reshape(variable->value(), variable->shape().dim_sizes()); } return Status::OK(); } +} // namespace + +Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, + TensorShape* shape, + xla::XlaOp* value) { + return ReadVariableInputTensor(context_->input(index), type, context_, shape, + value); +} + +Status XlaOpKernelContext::ReadVariableInput(StringPiece name, DataType type, + TensorShape* shape, + xla::XlaOp* value) { + return ReadVariableInputTensor(GetInputTensorByName(name), type, context_, + shape, value); +} + Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, TensorShape* shape) const { const Tensor& tensor = context_->input(index); @@ -368,10 +434,11 @@ void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { const TensorShape& shape = constant.shape(); - xla::Literal literal; - OP_REQUIRES_OK(context_, HostTensorToLiteral(constant, &literal)); - xla::XlaOp handle = builder()->ConstantLiteral(literal); - CHECK_NE(handle.builder(), nullptr); + xla::BorrowingLiteral literal; + OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal)); + + xla::XlaOp handle = xla::ConstantLiteral(builder(), literal); + CHECK(handle.valid()); // Make the Tensor that will refer to the expression. Tensor* output = nullptr; @@ -414,17 +481,17 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { return Status::OK(); } -Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, - xla::XlaOp handle) { - TF_RET_CHECK(handle.builder() != nullptr); +namespace { - const XlaExpression* expression = - CastExpressionFromTensor(context_->input(input_index)); +Status AssignVariableTensor(const Tensor& tensor, DataType type, + const OpKernelContext* ctx, xla::XlaOp handle, + xla::XlaBuilder* builder) { + const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); TF_RET_CHECK(variable->kind() == XlaResource::kVariable); - auto shape_or_status = builder()->GetShape(handle); + auto shape_or_status = builder->GetShape(handle); if (!shape_or_status.ok()) { return shape_or_status.status(); } @@ -434,15 +501,31 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); - XlaContext& xla_context = XlaContext::Get(context_); - TensorShape representation_shape = - xla_context.RepresentationShape(shape, type); + XlaContext& xla_context = XlaContext::Get(ctx); + TF_ASSIGN_OR_RETURN(TensorShape representation_shape, + xla_context.RepresentationShape(shape, type)); if (shape != representation_shape) { - handle = builder()->Reshape(handle, representation_shape.dim_sizes()); + handle = xla::Reshape(handle, representation_shape.dim_sizes()); } return variable->SetValue(handle); } +} // namespace + +Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, + xla::XlaOp handle) { + TF_RET_CHECK(handle.valid()); + return AssignVariableTensor(context_->input(input_index), type, context_, + handle, builder()); +} + +Status XlaOpKernelContext::AssignVariable(StringPiece name, DataType type, + xla::XlaOp handle) { + TF_RET_CHECK(handle.valid()); + return AssignVariableTensor(GetInputTensorByName(name), type, context_, + handle, builder()); +} + XlaCompiler* XlaOpKernelContext::compiler() const { return XlaContext::Get(context_).compiler(); } @@ -482,6 +565,12 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( return XlaContext::Get(context_).GetOrCreateMul(type); } +const Tensor& XlaOpKernelContext::GetInputTensorByName(StringPiece name) { + const Tensor* tensor; + CHECK(context_->input(name, &tensor).ok()); + return *tensor; +} + XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {} void XlaOpKernel::Compute(OpKernelContext* context) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 667dc262ca03ca716ffbf015a78fc14c7a8b7c1a..ac9dfe3369078df7392a4ef04679f7d7beacf8bb 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -17,7 +17,9 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/macros.h" @@ -66,16 +68,26 @@ class XlaOpKernelContext { // Returns the number of inputs to the operator. int num_inputs() const { return context_->num_inputs(); } - // Returns the type of input 'index'. - DataType input_type(int index) { return context_->input(index).dtype(); } + // Returns the type of input `index`. + DataType input_type(int index) const; - // Returns the shape of input 'index'. + // Returns the type of input `index` as an xla::PrimitiveType. If the type + // is not representable as an XLA type, sets an error status and returns + // xla::PRIMITIVE_TYPE_INVALID. + xla::PrimitiveType input_xla_type(int index); + + // Returns the shape of input `index`. TensorShape InputShape(int index); - // Returns input 'index' as a XlaOp. Unlike + // Returns the shape of input `name`. + TensorShape InputShape(StringPiece name); + + // Returns input `index` as a XlaOp. Unlike // OpKernelContext::Input returns a symbolic value rather than a concrete // Tensor. const xla::XlaOp& Input(int index); + // Returns input `name` as a XlaOp. + const xla::XlaOp& Input(StringPiece name); // Returns true if all inputs are the same shape, otherwise sets the // status to a non-OK value and returns false. @@ -90,13 +102,13 @@ class XlaOpKernelContext { // Helper methods for constant inputs. - // Evaluates input 'index' and stores it in '*constant_literal'. If the + // Evaluates input `index` and stores it in `*constant_literal`. If the // expression cannot be evaluated, e.g., because it depends on unbound // parameters, returns a non-OK status. Status ConstantInput(int index, xla::Literal* constant_literal); - // Evaluates input 'index', reshapes it to 'new_shape' if new_shape != - // InputShape(index), and stores it in '*constant_literal'. If the input + // Evaluates input `index`, reshapes it to `new_shape` if new_shape != + // InputShape(index), and stores it in `*constant_literal`. If the input // cannot be evaluated, e.g., because it depends on unbound parameters, // returns a non-Ok status. If InputShape(index).num_elements() != // new_shape.num_elements(), returns an error status. @@ -131,17 +143,17 @@ class XlaOpKernelContext { return context_->expected_output_dtype(index); } - // Sets output 'index' to the XlaOp 'handle'. + // Sets output `index` to the XlaOp `handle`. // All outputs should be set using SetOutput and SetConstantOutput, not // via the underlying OpKernelContext. void SetOutput(int index, const xla::XlaOp& handle); - // Sets output 'index' to compile-time constant 'host_tensor', where - // 'host_tensor' is a tensor in host memory. It is preferable to use + // Sets output `index` to compile-time constant `host_tensor`, where + // `host_tensor` is a tensor in host memory. It is preferable to use // SetConstantOutput where possible. void SetConstantOutput(int index, const Tensor& host_tensor); - // Sets output 'index' to an invalid value. + // Sets output `index` to an invalid value. // Any subsequent attempt to consume this output will cause an error. void SetInvalidOutput(int index); @@ -151,10 +163,10 @@ class XlaOpKernelContext { // Variables - // Sets '*resource' to the resource associated with input `index`. + // Sets `*resource` to the resource associated with input `index`. Status GetResourceInput(int index, XlaResource** resource); - // Sets output 'index' to be a reference to resource 'resource'. + // Sets output `index` to be a reference to resource `resource`. void SetResourceOutput(int index, XlaResource* resource); // Sets `*type` and `*shape` to the current type and shape of a variable's @@ -163,17 +175,23 @@ class XlaOpKernelContext { TensorShape* shape) const; // Reads the current value of the resouce variable referred to by input - // 'index'. If `shape` is not nullptr, sets `*shape` to the shape of the + // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the // variable. Returns an error if the variable has not been initialized, or if // its type does not match `type`. Status ReadVariableInput(int index, DataType type, TensorShape* shape, xla::XlaOp* value); + // Reads the current value of the resouce variable referred to by input + // `name`. + Status ReadVariableInput(StringPiece name, DataType type, TensorShape* shape, + xla::XlaOp* value); // Assigns the value `handle` to the variable referenced by input // `input_index`. The variable must be of `type`. Returns an error if the // variable has been initialized with a different type or with a // different shape. Status AssignVariable(int input_index, DataType type, xla::XlaOp handle); + // Assigns the value `handle` to the variable referenced by input `name`. + Status AssignVariable(StringPiece name, DataType type, xla::XlaOp handle); // Helper routines for the OP_REQUIRES macros void CtxFailure(const Status& s); @@ -221,6 +239,9 @@ class XlaOpKernelContext { const xla::XlaComputation* GetOrCreateMul(const DataType type); private: + // Returns the tensor of input `name`. + const Tensor& GetInputTensorByName(StringPiece name); + OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 4692038b61f6871a8a16299fd4d11e963eb46a57..46785bc1f0a1279bfd67a55844fe238d9797382b 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -71,16 +71,18 @@ XlaOpRegistry::~XlaOpRegistry() = default; << " have incompatible allow_resource_types settings."; return false; } - if (!x.has_device_whitelist || !y.has_device_whitelist) { - LOG(WARNING) << "Registrations of " << x.name - << " do not both have device whitelists."; + if (!x.has_device_whitelist && !y.has_device_whitelist) { + LOG(WARNING) << "Duplicate registrations of " << x.name + << "with no device whitelists."; return false; } - for (const auto& device : x.device_whitelist) { - if (y.device_whitelist.count(device) != 0) { - LOG(WARNING) << "Multiple registrations of " << x.name << " on device " - << device; - return false; + if (x.has_device_whitelist && y.has_device_whitelist) { + for (const auto& device : x.device_whitelist) { + if (y.device_whitelist.count(device) != 0) { + LOG(WARNING) << "Multiple registrations of " << x.name << " on device " + << device; + return false; + } } } if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) { @@ -157,97 +159,143 @@ void XlaOpRegistry::RegisterCompilationKernels() { registry.jit_kernels_registered_ = true; OpRegistryInterface* op_registry = OpRegistry::Global(); - for (const auto& op : registry.ops_) { - const string& op_name = op.first; - const std::unique_ptr& op_registration = op.second; - const OpDef* op_def; - Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def); - if (!lookup_status.ok()) { - LOG(ERROR) << lookup_status.error_message(); - XLA_LOG_LINES( - ERROR, "Ops registered: \n" + - dynamic_cast(op_registry)->DebugString(true)); + // Order of op registration: + // The goal is to allow the co-existence of backend-specific kernels and + // generic kernels. To achieve this, we enforce the following order of + // registrations for one op: + // 1. Process op registration with device whitelists: + // this pass registers backend-specific kernels for this op. + // 2. Process op registration without device whitelists: + // this pass registers the kernels for all the other supported backends. + for (auto& ops : registry.ops_) { + const string& op_name = ops.first; + std::vector>& op_registrations = ops.second; + // Partition the op registration so that the ones with device whitelists + // precede the one without device whitelist. + std::partition(op_registrations.begin(), op_registrations.end(), + [](const std::unique_ptr& op_reg) { + return op_reg->has_device_whitelist; + }); + + // Collect a set of backend registered by ops with device whitelists. + // The op registration without whitelists will register a generic kernel + // for all other backends not in this set. + std::unordered_set whitelisted_backend; + for (auto& op_registration : op_registrations) { + if (op_registration->has_device_whitelist) { + whitelisted_backend.insert(op_registration->device_whitelist.begin(), + op_registration->device_whitelist.end()); + } } - TF_CHECK_OK(lookup_status); - std::unordered_set type_attrs; - for (const OpDef::AttrDef& attr_def : op_def->attr()) { - if (attr_def.type() == "type" || attr_def.type() == "list(type)") { - type_attrs.insert(attr_def.name()); + for (auto& op_registration : op_registrations) { + const OpDef* op_def; + Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def); + if (!lookup_status.ok()) { + LOG(ERROR) << lookup_status.error_message(); + XLA_LOG_LINES( + ERROR, + "Ops registered: \n" + + dynamic_cast(op_registry)->DebugString(true)); } - } + TF_CHECK_OK(lookup_status); - // Checks there are no type constraints referring to unknown attributes. - for (const auto& constraint : op_registration->type_constraints) { - if (type_attrs.find(constraint.first) == type_attrs.end()) { - LOG(FATAL) << "Unknown type attribute " << constraint.first - << " in XLA op registration for " << op_name; + std::unordered_set type_attrs; + for (const OpDef::AttrDef& attr_def : op_def->attr()) { + if (attr_def.type() == "type" || attr_def.type() == "list(type)") { + type_attrs.insert(attr_def.name()); + } } - } - for (auto& backend : registry.backends_) { - // If the operator has a device whitelist, only register on whitelisted - // devices. - if (op_registration->has_device_whitelist && - op_registration->device_whitelist.find(backend.first) == - op_registration->device_whitelist.end()) { - continue; + // Checks there are no type constraints referring to unknown attributes. + for (const auto& constraint : op_registration->type_constraints) { + if (type_attrs.find(constraint.first) == type_attrs.end()) { + LOG(FATAL) << "Unknown type attribute " << constraint.first + << " in XLA op registration for " << op_name; + } } - std::unique_ptr kdef(new KernelDef); - kdef->set_op(op_registration->name); - kdef->set_device_type(backend.first); - - // Constrain each type attribute to the intersection of: - // a) the types supported by the backend, and - // b) the types allowed by the OpDef, and - // c) the type constraints. - for (const string& type_attr : type_attrs) { - KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); - attr_constraint->set_name(type_attr); - auto* allowed_values = - attr_constraint->mutable_allowed_values()->mutable_list(); - - const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def); - const auto* op_def_allowed_types = - op_def_attr.has_allowed_values() - ? &op_def_attr.allowed_values().list().type() - : nullptr; - auto constraint_it = op_registration->type_constraints.find(type_attr); - const std::set* type_constraints = - constraint_it != op_registration->type_constraints.end() - ? &constraint_it->second - : nullptr; - for (DataType dtype : backend.second.supported_types) { - // Filter out types that aren't allowed by the OpDef. - if (op_def_allowed_types != nullptr && - std::find(op_def_allowed_types->begin(), - op_def_allowed_types->end(), - dtype) == op_def_allowed_types->end()) { - continue; + for (auto& backend : registry.backends_) { + // If the operator has a device whitelist, only register on whitelisted + // devices. + if (op_registration->has_device_whitelist && + op_registration->device_whitelist.find(backend.first) == + op_registration->device_whitelist.end()) { + continue; + } + + // If the operator does NOT has a device whitelist, skip all devices + // that has already been registered. + if (!op_registration->has_device_whitelist && + whitelisted_backend.find(backend.first) != + whitelisted_backend.end()) { + continue; + } + + std::unique_ptr kdef(new KernelDef); + kdef->set_op(op_registration->name); + kdef->set_device_type(backend.first); + + // Constrain each type attribute to the intersection of: + // a) the types supported by the backend, and + // b) the types allowed by the OpDef, and + // c) the type constraints. + bool unsatisfiable_type_constraint = false; + for (const string& type_attr : type_attrs) { + KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); + attr_constraint->set_name(type_attr); + auto* allowed_values = + attr_constraint->mutable_allowed_values()->mutable_list(); + + const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def); + const auto* op_def_allowed_types = + op_def_attr.has_allowed_values() + ? &op_def_attr.allowed_values().list().type() + : nullptr; + auto constraint_it = + op_registration->type_constraints.find(type_attr); + const std::set* type_constraints = + constraint_it != op_registration->type_constraints.end() + ? &constraint_it->second + : nullptr; + for (DataType dtype : backend.second.supported_types) { + // Filter out types that aren't allowed by the OpDef. + if (op_def_allowed_types != nullptr && + std::find(op_def_allowed_types->begin(), + op_def_allowed_types->end(), + dtype) == op_def_allowed_types->end()) { + continue; + } + // Filter out types based on the type constraints. + if (type_constraints != nullptr && + type_constraints->find(dtype) == type_constraints->end()) { + continue; + } + // Passed all the filters, this type is allowed. + allowed_values->add_type(dtype); } - // Filter out types based on the type constraints. - if (type_constraints != nullptr && - type_constraints->find(dtype) == type_constraints->end()) { - continue; + if (op_registration->allow_resource_types) { + allowed_values->add_type(DT_RESOURCE); + } + // Don't build KernelDefs that have unsatisfiable type constraints. + if (allowed_values->type().empty()) { + unsatisfiable_type_constraint = true; + break; } - // Passed all the filters, this type is allowed. - allowed_values->add_type(dtype); } - if (op_registration->allow_resource_types) { - allowed_values->add_type(DT_RESOURCE); + if (unsatisfiable_type_constraint) continue; + + if (backend.second.op_filter != nullptr && + !backend.second.op_filter(kdef.get())) { + continue; } + VLOG(2) << "XLA op registration: device: " << backend.first + << " op: " << op_name; + registry.kernel_registrars_.emplace_back( + new kernel_factory::OpKernelRegistrar( + new KernelDef(*kdef), "XlaJitOp", op_registration->factory)); + backend.second.kernel_defs.push_back(std::move(kdef)); } - if (backend.second.op_filter != nullptr && - !backend.second.op_filter(kdef.get())) { - continue; - } - VLOG(2) << "XLA op registration: device: " << backend.first - << " op: " << op_name; - registry.kernel_registrars_.emplace_back( - new kernel_factory::OpKernelRegistrar( - new KernelDef(*kdef), "XlaJitOp", op_registration->factory)); - backend.second.kernel_defs.push_back(std::move(kdef)); } } } @@ -265,12 +313,12 @@ std::vector XlaOpRegistry::DeviceKernels( << "Unknown backend " << compilation_device_name; for (const std::unique_ptr& k : it->second.kernel_defs) { auto op_iter = registry.ops_.find(k->op()); - CHECK(op_iter != registry.ops_.end()); + CHECK(op_iter != registry.ops_.end() && !op_iter->second.empty()); // The test in IsCompatible ensures that if there are multiple matching // registrations for this op name, they all have the same value of // compilation_only, so only the first match needs to be tested. if (include_compilation_only_kernels || - !op_iter->second->compilation_only) { + !op_iter->second.front()->compilation_only) { kernels.push_back(k.get()); } } @@ -282,10 +330,13 @@ XlaOpRegistry::CompileTimeConstantInputs(const string& op) { XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); auto it = registry.ops_.find(op); - if (it == registry.ops_.end()) { + if (it == registry.ops_.end() || it->second.empty()) { return nullptr; } - return &it->second->compile_time_constant_inputs; + // The test in IsCompatible ensures that if there are multiple matching + // registrations for this op name, they all have the same value of + // compile_time_constant_inputs, so only the first match is returned. + return &it->second.front()->compile_time_constant_inputs; } std::vector XlaOpRegistry::BackendNames() { @@ -378,16 +429,15 @@ XlaOpRegistrar::XlaOpRegistrar( std::unique_ptr registration) { XlaOpRegistry& registry = XlaOpRegistry::Instance(); mutex_lock lock(registry.mutex_); - auto existing_ops = registry.ops_.equal_range(registration->name); - for (auto existing = existing_ops.first; existing != existing_ops.second; - ++existing) { - if (!XlaOpRegistry::IsCompatible(*existing->second, *registration)) { + auto& existing_ops = registry.ops_[registration->name]; + for (auto& existing : existing_ops) { + if (!XlaOpRegistry::IsCompatible(*existing, *registration)) { LOG(FATAL) << "XLA op registration " << registration->name << " is incompatible with existing registration of the same name."; } } - registry.ops_.emplace(registration->name, std::move(registration)); + existing_ops.emplace_back(std::move(registration)); } XlaBackendRegistrar::XlaBackendRegistrar( diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index e255b01dd7fdcb095c7992d4352d2d9bb7d36ac3..fc14834ca6441ea785eacc57e1f502086f36657e 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -203,7 +203,7 @@ class XlaOpRegistry { // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP. // Registrations present under the same key must satisfy IsCompatible above, // and this is checked during registration. - std::unordered_multimap> ops_ + std::unordered_map>> ops_ GUARDED_BY(mutex_); // Have we already registered the JIT kernels on the JIT devices? @@ -279,7 +279,7 @@ class XlaOpRegistrar { #define REGISTER_XLA_OP_UNIQ(CTR, BUILDER, OP) \ static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \ - XlaOpRegistrationBuilder::BUILDER.Build( \ + ::tensorflow::XlaOpRegistrationBuilder::BUILDER.Build( \ [](::tensorflow::OpKernelConstruction* context) \ -> ::tensorflow::OpKernel* { return new OP(context); })); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry_test.cc b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7b3b15b1af7636fddd4c29477cbfe6f9761f2c47 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc @@ -0,0 +1,119 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +// This test is to verify the correctness of XLA op registration with specific +// backend overrides. + +// A dummy backend-specific OpKernel for CPU. +class DummyCPUOp : public XlaOpKernel { + public: + explicit DummyCPUOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + ctx->SetOutput(0, ctx->Input(0)); + } +}; + +// A dummy generic OpKernel for all backends. +class DummyGenericOp : public XlaOpKernel { + public: + explicit DummyGenericOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + ctx->SetOutput(0, ctx->Input(0)); + } +}; + +REGISTER_OP("DummyDuplicateOp") + .Attr("T: {float, int32}") + .Input("input: int32") + .Output("output: int32") + .Doc(R"doc( +A dummy Op. + +input: dummy input. +output: dummy output. +)doc"); + +// Register the DummyCPUOp kernel for CPU with type INT32. +REGISTER_XLA_OP(Name("DummyDuplicateOp") + .Device(DEVICE_CPU_XLA_JIT) + .TypeConstraint("T", DT_INT32), + DummyCPUOp); +// Register the DummyGeneric kernel for all registered device (except CPU since +// it is already registered), with type FLOAT. +REGISTER_XLA_OP(Name("DummyDuplicateOp").TypeConstraint("T", DT_FLOAT), + DummyGenericOp); + +// Test the correctness of registered kernels. The kernel registered for CPU +// should have type INT32 while all other kernels should have type FLOAT. +TEST(XlaOpRegistryTest, XlaOpRegistrationWithOverride) { + XlaOpRegistry::RegisterCompilationKernels(); + auto registered_kernels = GetAllRegisteredKernels().kernel(); + for (const auto& kernels : registered_kernels) { + if (kernels.op() == "DummyDuplicateOp") { + EXPECT_EQ(kernels.constraint_size(), 1); + EXPECT_EQ(kernels.constraint(0).name(), "T"); + if (kernels.device_type() == "XLA_CPU_JIT") { + EXPECT_EQ(kernels.constraint(0).allowed_values().list().type(0), + DT_INT32); + } else { + EXPECT_EQ(kernels.constraint(0).allowed_values().list().type(0), + DT_FLOAT); + } + } + } +} + +// A dummy generic OpKernel for all backends. +class DummyInfeasibleTypeConstraintOp : public XlaOpKernel { + public: + explicit DummyInfeasibleTypeConstraintOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + LOG(FATAL) << "unreachable"; + } +}; + +REGISTER_OP("DummyInfeasibleTypeConstraintOp") + .Attr("T: {float, string}") + .Input("input: T") + .Output("output: T") + .Doc(R"doc( +A dummy Op. + +input: dummy input. +output: dummy output. +)doc"); +REGISTER_XLA_OP( + Name("DummyInfeasibleTypeConstraintOp").TypeConstraint("T", DT_STRING), + DummyInfeasibleTypeConstraintOp); + +TEST(XlaOpRegistryTest, OpWithInfeasibleTypeConstraintIsNotRegistered) { + XlaOpRegistry::RegisterCompilationKernels(); + auto registered_kernels = GetAllRegisteredKernels().kernel(); + for (const auto& kernels : registered_kernels) { + // The operator should not be registered. + EXPECT_NE(kernels.op(), "DummyInfeasibleTypeConstraintOp"); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 540c65c597f20d5bb26494e56c09ff2187cfb0db..7928fa034725206a752cbfe086d01f15cd235df9 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { @@ -89,16 +90,16 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { } switch (kind_) { case kVariable: { - value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_), - shape_.dim_sizes()); + value_ = + xla::Broadcast(XlaHelpers::Zero(builder, type_), shape_.dim_sizes()); break; } case kTensorArray: { TensorShape ta_shape; ta_shape.AddDim(tensor_array_size_); ta_shape.AppendShape(shape_); - value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_), - ta_shape.dim_sizes()); + value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_), + ta_shape.dim_sizes()); break; } case kStack: { @@ -106,9 +107,9 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { ta_shape.AddDim(tensor_array_size_); ta_shape.AppendShape(shape_); value_ = - builder->Tuple({builder->Broadcast(XlaHelpers::Zero(builder, type_), - ta_shape.dim_sizes()), - builder->ConstantR0(0)}); + xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_), + ta_shape.dim_sizes()), + xla::ConstantR0(builder, 0)}); break; } @@ -130,8 +131,8 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, TensorShape ta_shape; ta_shape.AddDim(tensor_array_size_); ta_shape.AppendShape(shape_); - xla::XlaOp gradient_value = builder->Broadcast( - XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); + xla::XlaOp gradient_value = + xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); gradient.reset( new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, /*name=*/strings::StrCat("TensorArrayGrad: ", name_), @@ -152,7 +153,7 @@ Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const { for (const auto& gradient : tensor_array_gradients_) { elems.push_back(gradient.second->value_); } - *pack = builder->Tuple(elems); + *pack = xla::Tuple(builder, elems); } return Status::OK(); } @@ -168,7 +169,7 @@ Status XlaResource::SetFromPack(const std::set& gradient_sources, } else { TF_RET_CHECK(kind_ == kTensorArray); int pos = 0; - auto v = builder->GetTupleElement(pack, pos++); + auto v = xla::GetTupleElement(pack, pos++); if (!initialized()) { initial_value_ = v; } @@ -178,7 +179,7 @@ Status XlaResource::SetFromPack(const std::set& gradient_sources, XlaResource* gradient; TF_RETURN_IF_ERROR( GetOrCreateTensorArrayGradient(source, builder, &gradient)); - auto v = builder->GetTupleElement(pack, pos++); + auto v = xla::GetTupleElement(pack, pos++); if (!gradient->initialized()) { gradient->initial_value_ = v; } diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 9ce36d1aa7622334b2acfbe9aa85d7419c4772ed..2438490be13809b9f3571a362900b44cb838e76b 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" @@ -75,7 +75,7 @@ class XlaResource { const xla::XlaOp& initial_value() const { return initial_value_; } // A variable is initialized if it has a value. - bool initialized() const { return value_.builder() != nullptr; } + bool initialized() const { return value_.valid(); } // Sets the type and shape of the resource. The type and shape of a resource // must not change once the variable has been initialized. diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 1b8e516770c3e217dd7c2f26ce426895b478c2e4..fdf13bb18c2567d2994612d15119ae87cbfa9137 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -142,30 +142,15 @@ cc_library( cc_library( name = "statusor", - srcs = ["statusor.cc"], hdrs = [ "statusor.h", - "statusor_internals.h", ], visibility = ["//visibility:public"], deps = [ ":status", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - ], -) - -tf_cc_test( - name = "statusor_test", - size = "small", - srcs = ["statusor_test.cc"], - deps = [ - ":statusor", - ":test", - ":types", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", + "//tensorflow/stream_executor", ], ) @@ -175,6 +160,7 @@ cc_library( hdrs = [ "iterator_util.h", "map_util.h", + "overflow_util.h", "ptr_util.h", "util.h", ], @@ -250,7 +236,7 @@ cc_library( ":types", ":util", ":xla_data_proto", - "//tensorflow/core:framework_internal", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", @@ -268,6 +254,7 @@ tf_cc_test( ":types", ":util", ":xla_data_proto", + "//tensorflow/core:lib", "//tensorflow/core:test_main", ], ) @@ -295,9 +282,9 @@ tf_cc_test( ) cc_library( - name = "literal_util", - srcs = ["literal_util.cc"], - hdrs = ["literal_util.h"], + name = "literal", + srcs = ["literal.cc"], + hdrs = ["literal.h"], visibility = ["//visibility:public"], deps = [ ":array2d", @@ -309,17 +296,17 @@ cc_library( ":types", ":util", ":xla_data_proto", - "//tensorflow/core:framework", "//tensorflow/core:lib", ], ) tf_cc_test( - name = "literal_util_test", - srcs = ["literal_util_test.cc"], + name = "literal_test", + srcs = ["literal_test.cc"], deps = [ ":array3d", ":array4d", + ":literal", ":literal_util", ":shape_util", ":test", @@ -331,6 +318,26 @@ tf_cc_test( ], ) +cc_library( + name = "literal_util", + srcs = ["literal_util.cc"], + hdrs = ["literal_util.h"], + visibility = ["//visibility:public"], + deps = [ + ":array2d", + ":array3d", + ":array4d", + ":literal", + ":shape_util", + ":sparse_index_array", + ":status_macros", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:lib", + ], +) + cc_library( name = "error_spec", hdrs = ["error_spec.h"], @@ -342,6 +349,7 @@ cc_library( hdrs = ["literal_comparison.h"], deps = [ ":error_spec", + ":literal", ":literal_util", ":util", "//tensorflow/core:lib", @@ -473,7 +481,7 @@ cc_library( hdrs = ["packed_literal_reader.h"], visibility = [":internal"], deps = [ - ":literal_util", + ":literal", ":shape_util", ":status_macros", ":statusor", @@ -504,7 +512,7 @@ cc_library( hdrs = ["text_literal_reader.h"], visibility = [":internal"], deps = [ - ":literal_util", + ":literal", ":shape_util", ":status_macros", ":statusor", @@ -520,7 +528,7 @@ tf_cc_test( name = "text_literal_reader_test", srcs = ["text_literal_reader_test.cc"], deps = [ - ":literal_util", + ":literal", ":shape_util", ":test", ":text_literal_reader", @@ -537,7 +545,7 @@ cc_library( hdrs = ["text_literal_writer.h"], visibility = [":internal"], deps = [ - ":literal_util", + ":literal", ":shape_util", ":status_macros", ":types", @@ -550,6 +558,7 @@ tf_cc_test( name = "text_literal_writer_test", srcs = ["text_literal_writer_test.cc"], deps = [ + ":literal", ":literal_util", ":test", ":test_helpers", @@ -622,11 +631,12 @@ cc_library( ":array2d", ":array3d", ":array4d", + ":literal_util", ":util", ":window_util", ":xla_data_proto", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_evaluator", "//tensorflow/compiler/xla/service:shape_inference", @@ -642,7 +652,7 @@ tf_cc_test( ":array2d", ":array3d", ":array4d", - ":literal_util", + ":literal", ":reference_util", ":test", ":util", diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index ea75ad32d5df7bbadd37e89de6144b264ab6d5d1..2d5d078aa77423cc18bab053b80a7576acbd849e 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -409,7 +409,7 @@ class Array { // Returns the total number of elements in the array. int64 num_elements() const { - return std::accumulate(sizes_.begin(), sizes_.end(), 1, + return std::accumulate(sizes_.begin(), sizes_.end(), 1LL, std::multiplies()); } diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 8f08d3b2e04670ad6590aca1db0fd9d25faed83f..ad3fcee05b80181369bfdf3cdcdb5452ec9e7e89 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -64,8 +64,9 @@ cc_library( hdrs = ["client.h"], deps = [ ":global_data", + ":xla_computation", "//tensorflow/compiler/xla:execution_options_util", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:service_interface", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -73,7 +74,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", @@ -100,12 +100,12 @@ cc_library( deps = [ ":client", ":executable_build_options", + ":xla_computation", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", @@ -114,6 +114,7 @@ cc_library( "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:source_map_util", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:support", @@ -126,11 +127,11 @@ cc_library( hdrs = ["compile_only_client.h"], deps = [ ":client", + ":xla_computation", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:compile_only_service", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:stream_executor_no_cuda", @@ -174,3 +175,60 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", ], ) + +cc_library( + name = "xla_computation", + srcs = ["xla_computation.cc"], + hdrs = ["xla_computation.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + ], +) + +cc_library( + name = "xla_builder", + srcs = ["xla_builder.cc"], + hdrs = ["xla_builder.h"], + visibility = ["//visibility:public"], + deps = [ + ":padding", + ":sharding_builder", + ":xla_computation", + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service:shape_inference", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "xla_builder_test", + srcs = ["xla_builder_test.cc"], + deps = [ + ":xla_builder", + ":xla_computation", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 3d596a6e65430b6e9692aabd65fc8aa84b7b873d..d0ce5e8a6afa262d4cffdfe8431aab570ffd28df 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -18,9 +18,10 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -409,8 +410,10 @@ StatusOr Client::ExecutionStatsAsString( return string("[Execution Statistics] not available."); } -StatusOr Client::CreateChannelHandle() { +StatusOr Client::CreateChannelHandleByType( + ChannelHandle::ChannelType type) { CreateChannelHandleRequest request; + request.set_channel_type(type); CreateChannelHandleResponse response; VLOG(1) << "making create channel handle request"; @@ -424,4 +427,16 @@ StatusOr Client::CreateChannelHandle() { return response.channel(); } +StatusOr Client::CreateChannelHandle() { + return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_DEVICE); +} + +StatusOr Client::CreateHostToDeviceChannelHandle() { + return CreateChannelHandleByType(ChannelHandle::HOST_TO_DEVICE); +} + +StatusOr Client::CreateDeviceToHostChannelHandle() { + return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_HOST); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index 68f0d0ac78c859fde7a6a007cd250b047a7bfcda..be50cebfcc0e3c19002635dbd280b14048aa0c93 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -20,8 +20,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" @@ -178,10 +178,15 @@ class Client { StatusOr> GetComputationShape( const XlaComputation& computation); - // Creates a channel handle that can be used to transfer data between - // two computations via a pair of Send and Recv instructions. + // Creates a channel handle that can be used to transfer data between two + // computations on different devices via a pair of Send and Recv instructions. StatusOr CreateChannelHandle(); + // Create a channel for communicating with the host via a SendtoHost or + // RecvFromHost operation. + StatusOr CreateHostToDeviceChannelHandle(); + StatusOr CreateDeviceToHostChannelHandle(); + StatusOr LoadSnapshot(const HloSnapshot& module); ServiceInterface* stub() { return stub_; } @@ -192,6 +197,9 @@ class Client { StatusOr ExecutionStatsAsString(const XlaComputation& computation, const ExecutionProfile& profile); + StatusOr CreateChannelHandleByType( + ChannelHandle::ChannelType type); + ServiceInterface* stub_; // Stub that this client is connected on. TF_DISALLOW_COPY_AND_ASSIGN(Client); diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index 332c96503637344d56e363e19db4880c37ca9684..a551edeab0943ec5213c5cb035644c02c3cf54d7 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/compile_only_service.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index d49d959a6c8112d3701857a70cecb24701c7b6d9..a2f32ab97eab10294a607f35fc79ded1cc2c5792 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -13,28 +13,191 @@ filegroup( ]), ) +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") + +# Generate test_suites for all backends, named "${backend}_tests". +generate_backend_suites() + cc_library( name = "arithmetic", srcs = ["arithmetic.cc"], hdrs = ["arithmetic.h"], deps = [ + ":constants", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", ], ) +cc_library( + name = "constants", + srcs = ["constants.cc"], + hdrs = ["constants.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + ], +) + +xla_test( + name = "constants_test", + srcs = ["constants_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "math", + srcs = ["math.cc"], + hdrs = ["math.h"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/client:xla_builder", + ], +) + +xla_test( + name = "math_test", + srcs = ["math_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":math", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "numeric", + srcs = ["numeric.cc"], + hdrs = ["numeric.h"], + deps = [ + ":arithmetic", + ":constants", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "numeric_test", + srcs = ["numeric_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":numeric", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "pooling", + srcs = ["pooling.cc"], + hdrs = ["pooling.h"], + deps = [ + ":arithmetic", + ":constants", + "//tensorflow/compiler/tf2xla/lib:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "pooling_test", + srcs = ["pooling_test.cc"], + deps = [ + ":pooling", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "prng", + srcs = ["prng.cc"], + hdrs = ["prng.h"], + deps = [ + ":constants", + ":math", + ":numeric", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "sorting", + srcs = ["sorting.cc"], + hdrs = ["sorting.h"], + deps = [ + ":numeric", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + ], +) + +xla_test( + name = "sorting_test", + srcs = ["sorting_test.cc"], + blacklisted_backends = [ + "cpu", + "gpu", + ], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":sorting", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "testing", srcs = ["testing.cc"], hdrs = ["testing.h"], deps = [ "//tensorflow/compiler/xla:execution_options_util", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -42,8 +205,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", ], diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index a1d34796ccfd86f2025eff0ecb51338eb6a9b1da..9225b1acd69c214d6f08a45372a8082ed789c18c 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -17,8 +17,9 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -42,8 +43,8 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, } const Shape scalar = ShapeUtil::MakeShape(type, {}); - auto lhs = b->Parameter(0, scalar, "lhs"); - auto rhs = b->Parameter(1, scalar, "rhs"); + auto lhs = Parameter(b.get(), 0, scalar, "lhs"); + auto rhs = Parameter(b.get(), 1, scalar, "rhs"); generator(b.get(), lhs, rhs); return b->BuildAndNoteError(); } @@ -55,7 +56,7 @@ XlaComputation CreateScalarAddComputation(PrimitiveType type, return CreateScalarComputation( "add", type, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->Add(lhs, rhs); + return Add(lhs, rhs); }); } @@ -64,17 +65,15 @@ XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, return CreateScalarComputation( "mul", type, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->Mul(lhs, rhs); + return Mul(lhs, rhs); }); } XlaComputation CreateScalarGeComputation(PrimitiveType type, XlaBuilder* builder) { - return CreateScalarComputation( - "ge", type, builder, - [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->Ge(lhs, rhs); - }); + return CreateScalarComputation("ge", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, + const XlaOp& rhs) { return Ge(lhs, rhs); }); } XlaComputation CreateScalarMaxComputation(PrimitiveType type, @@ -82,7 +81,7 @@ XlaComputation CreateScalarMaxComputation(PrimitiveType type, return CreateScalarComputation( "max", type, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->Max(lhs, rhs); + return Max(lhs, rhs); }); } @@ -91,34 +90,37 @@ XlaComputation CreateScalarMinComputation(PrimitiveType type, return CreateScalarComputation( "min", type, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->Min(lhs, rhs); + return Min(lhs, rhs); }); } -XlaComputation CreateScalarAndComputation(XlaBuilder* builder) { +XlaComputation CreateScalarAndComputation(PrimitiveType type, + XlaBuilder* builder) { return CreateScalarComputation( - "and", PRED, builder, + "and", type, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->And(lhs, rhs); + return And(lhs, rhs); }); } -XlaComputation CreateScalarOrComputation(XlaBuilder* builder) { - return CreateScalarComputation( - "or", PRED, builder, - [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { - return b->Or(lhs, rhs); - }); +XlaComputation CreateScalarOrComputation(PrimitiveType type, + XlaBuilder* builder) { + return CreateScalarComputation("or", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, + const XlaOp& rhs) { return Or(lhs, rhs); }); } -StatusOr Any(const XlaOp& predicates, XlaBuilder* builder) { - auto f = builder->ConstantR0(false); - XlaComputation logical_or = CreateScalarOrComputation(builder); - TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, - builder->GetShape(predicates)); - std::vector all_dimensions(ShapeUtil::Rank(predicates_shape)); - std::iota(all_dimensions.begin(), all_dimensions.end(), 0); - return builder->Reduce(predicates, f, logical_or, all_dimensions); +XlaOp Any(XlaOp predicates) { + XlaBuilder* builder = predicates.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + auto f = ConstantR0(builder, false); + XlaComputation logical_or = CreateScalarOrComputation(PRED, builder); + TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, + builder->GetShape(predicates)); + std::vector all_dimensions(ShapeUtil::Rank(predicates_shape)); + std::iota(all_dimensions.begin(), all_dimensions.end(), 0); + return Reduce(predicates, f, logical_or, all_dimensions); + }); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index 64b6b7d63353165e45bf12d35126a7eeef9e56e4..632e8cc8bc64fad236a0226c6e93079aadde7050 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -45,15 +45,17 @@ XlaComputation CreateScalarMinComputation(PrimitiveType type, XlaBuilder* builder); // Creates a scalar logical AND computation and returns it. -XlaComputation CreateScalarAndComputation(XlaBuilder* builder); +XlaComputation CreateScalarAndComputation(PrimitiveType type, + XlaBuilder* builder); // Creates a scalar logical OR computation and returns it. -XlaComputation CreateScalarOrComputation(XlaBuilder* builder); +XlaComputation CreateScalarOrComputation(PrimitiveType type, + XlaBuilder* builder); // Returns whether any predicate in "predicates" is set. // // Note: if predicates is zero-sized, Any() vacuously returns false. -StatusOr Any(const XlaOp& predicates, XlaBuilder* builder); +XlaOp Any(XlaOp predicates); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/constants.cc b/tensorflow/compiler/xla/client/lib/constants.cc new file mode 100644 index 0000000000000000000000000000000000000000..031d62e4ffef188082303a28866bbc72a154e9b1 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/constants.cc @@ -0,0 +1,103 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/constants.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +XlaOp Zero(XlaBuilder* builder, PrimitiveType type) { + return ConstantLiteral(builder, LiteralUtil::Zero(type)); +} + +XlaOp Zeros(XlaBuilder* builder, const Shape& shape) { + return Broadcast(Zero(builder, shape.element_type()), + AsInt64Slice(shape.dimensions())); +} + +XlaOp ZerosLike(XlaOp prototype) { + XlaBuilder* builder = prototype.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); + return Zeros(builder, shape); + }); +} + +XlaOp One(XlaBuilder* builder, PrimitiveType type) { + return ConstantLiteral(builder, LiteralUtil::One(type)); +} + +XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) { + switch (type) { + case F16: + return ConstantR0( + builder, + static_cast(Eigen::NumTraits::epsilon())); + case BF16: + return ConstantR0(builder, bfloat16::epsilon()); + case F32: + return ConstantR0(builder, std::numeric_limits::epsilon()); + case F64: + return ConstantR0(builder, + std::numeric_limits::epsilon()); + default: + return builder->ReportError(InvalidArgument( + "Invalid type for Epsilon (%s).", PrimitiveType_Name(type).c_str())); + } +} + +XlaOp MinValue(XlaBuilder* builder, PrimitiveType type) { + return ConstantLiteral(builder, LiteralUtil::MinValue(type)); +} + +XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) { + switch (type) { + case F16: + return ConstantR0(builder, + Eigen::NumTraits::lowest()); + case BF16: + return ConstantR0(builder, bfloat16::lowest()); + case F32: + return ConstantR0(builder, -std::numeric_limits::max()); + case F64: + return ConstantR0(builder, -std::numeric_limits::max()); + default: + return MinValue(builder, type); + } +} + +XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) { + return ConstantLiteral(builder, LiteralUtil::MaxValue(type)); +} + +XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) { + switch (type) { + case F16: + return ConstantR0(builder, + Eigen::NumTraits::highest()); + case BF16: + return ConstantR0(builder, bfloat16::highest()); + case F32: + return ConstantR0(builder, std::numeric_limits::max()); + case F64: + return ConstantR0(builder, std::numeric_limits::max()); + default: + return MaxValue(builder, type); + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h new file mode 100644 index 0000000000000000000000000000000000000000..0c8a9b8cc02ba0c1ebdf6a060d4b99262dceb178 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/constants.h @@ -0,0 +1,124 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_ + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns scalar 'value' as a scalar of 'type'. Unlike ConstantR0, 'type' is +// determined at C++ run-time, rather than C++ compile-time. +// If 'value' is floating point but 'type' is not, or if 'value' is complex but +// 'type' is not, an error will be returned. This is to catch accidental +// truncation; in such cases, use an explicit cast. +template +XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { + if (std::is_floating_point::value && + !(primitive_util::IsFloatingPointType(type) || + primitive_util::IsComplexType(type))) { + return builder->ReportError(InvalidArgument( + "Invalid cast from floating point type to %s in ConstantR0WithType.", + PrimitiveType_Name(type).c_str())); + } + if (std::is_same::value && + !primitive_util::IsComplexType(type)) { + return builder->ReportError(InvalidArgument( + "Invalid cast from complex type to %s in ConstantR0WithType.", + PrimitiveType_Name(type).c_str())); + } + switch (type) { + case F16: + return ConstantR0(builder, static_cast(value)); + case BF16: + return ConstantR0(builder, static_cast(value)); + case F32: + return ConstantR0(builder, static_cast(value)); + case F64: + return ConstantR0(builder, static_cast(value)); + case C64: + return ConstantR0(builder, static_cast(value)); + case U8: + return ConstantR0(builder, static_cast(value)); + case U32: + return ConstantR0(builder, static_cast(value)); + case U64: + return ConstantR0(builder, static_cast(value)); + case S8: + return ConstantR0(builder, static_cast(value)); + case S32: + return ConstantR0(builder, static_cast(value)); + case S64: + return ConstantR0(builder, static_cast(value)); + default: + return builder->ReportError( + InvalidArgument("Invalid type for ConstantR0WithType (%s).", + PrimitiveType_Name(type).c_str())); + } +} + +// Returns a scalar containing 'value' cast to the same run-time type as +// 'prototype'. +// If 'value' is floating point but 'prototype' is not, or if 'value' is complex +// 'prototype' is not, an error will be returned. +template +XlaOp ScalarLike(XlaOp prototype, T value) { + XlaBuilder* builder = prototype.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); + return ConstantR0WithType(builder, shape.element_type(), value); + }); +} + +// Returns a scalar with value '0' of 'type'. +XlaOp Zero(XlaBuilder* builder, PrimitiveType type); + +// Returns a zero-filled tensor with shape `shape`. +XlaOp Zeros(XlaBuilder* builder, const Shape& shape); + +// Returns a zero-filled tensor with the same shape as `prototype`. +XlaOp ZerosLike(XlaOp prototype); + +// Returns a scalar with value '1' of 'type'. +XlaOp One(XlaBuilder* builder, PrimitiveType type); + +// Returns the machine epsilon for floating-point type `type`, i.e., +// the difference between 1.0 and the next representable value. +XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type); + +// Returns the minimum representable finite or infinite value for 'type'. +// Returns '-inf' for floating-point types. +XlaOp MinValue(XlaBuilder* builder, PrimitiveType type); + +// Returns the minimum representable finite value for 'type'. For a floating +// point type, this is equal to -MaxFiniteValue(). +XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type); + +// Returns the maximum representable finite or infinite value for 'type'. +// Returns 'inf' for floating-point types. +XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type); + +// Returns the maximum representable finite value for 'type'. +XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_ diff --git a/tensorflow/compiler/xla/client/lib/constants_test.cc b/tensorflow/compiler/xla/client/lib/constants_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f4320f65c1f76d4d4c384110b39d6606773aaf01 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/constants_test.cc @@ -0,0 +1,159 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +using ConstantsTest = ClientLibraryTestBase; + +using ::testing::HasSubstr; + +XLA_TEST_F(ConstantsTest, ConstantR0WithTypeS32) { + XlaBuilder builder(TestName()); + ConstantR0WithType(&builder, xla::S32, 4); + ComputeAndCompareR0(&builder, 4, {}); +} + +XLA_TEST_F(ConstantsTest, ConstantR0WithTypeS32DoesNotAcceptFloats) { + XlaBuilder builder(TestName()); + ConstantR0WithType(&builder, xla::S32, 4.5); + auto statusor = builder.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("Invalid cast")); +} + +XLA_TEST_F(ConstantsTest, ConstantR0WithTypeF32) { + XlaBuilder builder(TestName()); + ConstantR0WithType(&builder, xla::F32, -7); + ComputeAndCompareR0(&builder, -7, {}); + ConstantR0WithType(&builder, xla::F32, 0.5); + ComputeAndCompareR0(&builder, 0.5, {}); +} + +XLA_TEST_F(ConstantsTest, ScalarLikeS32) { + XlaBuilder builder(TestName()); + ScalarLike(ConstantR0(&builder, 42), -3); + ComputeAndCompareR0(&builder, -3, {}); +} + +XLA_TEST_F(ConstantsTest, ScalarLikeF32) { + XlaBuilder builder(TestName()); + ScalarLike(ConstantR0(&builder, 42.75), -3.2); + ComputeAndCompareR0(&builder, -3.2, {}); +} + +XLA_TEST_F(ConstantsTest, ZeroS32) { + XlaBuilder builder(TestName()); + Zero(&builder, S32); + ComputeAndCompareR0(&builder, 0, {}); +} + +XLA_TEST_F(ConstantsTest, ZeroF32) { + XlaBuilder builder(TestName()); + Zero(&builder, F32); + ComputeAndCompareR0(&builder, 0.0, {}); +} + +XLA_TEST_F(ConstantsTest, ZerosS32) { + XlaBuilder builder(TestName()); + Zeros(&builder, ShapeUtil::MakeShape(S32, {2, 2})); + ComputeAndCompareR2(&builder, {{0, 0}, {0, 0}}, {}); +} + +XLA_TEST_F(ConstantsTest, ZerosLikeF32) { + XlaBuilder builder(TestName()); + ZerosLike(ConstantR1(&builder, {1., 2., 3.})); + ComputeAndCompareR1(&builder, {0., 0., 0.}, {}); +} + +XLA_TEST_F(ConstantsTest, OneS32) { + XlaBuilder builder(TestName()); + One(&builder, S32); + ComputeAndCompareR0(&builder, 1, {}); +} + +XLA_TEST_F(ConstantsTest, OneF32) { + XlaBuilder builder(TestName()); + One(&builder, F32); + ComputeAndCompareR0(&builder, 1., {}); +} + +XLA_TEST_F(ConstantsTest, EpsilonF32) { + XlaBuilder builder(TestName()); + Epsilon(&builder, F32); + ComputeAndCompareR0(&builder, std::numeric_limits::epsilon(), + {}); +} + +XLA_TEST_F(ConstantsTest, MinFiniteValueS32) { + XlaBuilder builder(TestName()); + MinFiniteValue(&builder, S32); + ComputeAndCompareR0(&builder, std::numeric_limits::min(), {}); +} + +XLA_TEST_F(ConstantsTest, MaxFiniteValueS32) { + XlaBuilder builder(TestName()); + MaxFiniteValue(&builder, S32); + ComputeAndCompareR0(&builder, std::numeric_limits::max(), {}); +} + +XLA_TEST_F(ConstantsTest, MinFiniteValueF32) { + XlaBuilder builder(TestName()); + MinFiniteValue(&builder, F32); + ComputeAndCompareR0(&builder, -std::numeric_limits::max(), {}); +} + +XLA_TEST_F(ConstantsTest, MaxFiniteValueF32) { + XlaBuilder builder(TestName()); + MaxFiniteValue(&builder, F32); + ComputeAndCompareR0(&builder, std::numeric_limits::max(), {}); +} + +XLA_TEST_F(ConstantsTest, MinValueS32) { + XlaBuilder builder(TestName()); + MinValue(&builder, S32); + ComputeAndCompareR0(&builder, std::numeric_limits::min(), {}); +} + +XLA_TEST_F(ConstantsTest, MaxValueS32) { + XlaBuilder builder(TestName()); + MaxValue(&builder, S32); + ComputeAndCompareR0(&builder, std::numeric_limits::max(), {}); +} + +XLA_TEST_F(ConstantsTest, MinValueF32) { + XlaBuilder builder(TestName()); + MinValue(&builder, F32); + ComputeAndCompareR0(&builder, -std::numeric_limits::infinity(), + {}); +} + +XLA_TEST_F(ConstantsTest, MaxValueF32) { + XlaBuilder builder(TestName()); + MaxValue(&builder, F32); + ComputeAndCompareR0(&builder, std::numeric_limits::infinity(), + {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc new file mode 100644 index 0000000000000000000000000000000000000000..0221de7672c7b7c02b1f8b9c7ff4f92151e567c6 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -0,0 +1,304 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/math.h" + +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { + +XlaOp Sqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, 0.5)); } + +XlaOp Rsqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, -0.5)); } + +XlaOp Square(XlaOp operand) { return operand * operand; } + +XlaOp Reciprocal(XlaOp operand) { return ScalarLike(operand, 1.0) / operand; } + +namespace { + +// Polynomials for computing erf/erfc. Originally from cephes. +// Note we use float for compatibility across devices, at the cost of some +// precision for 64 bit computations. +// +// Coefficients are in descending order. +std::array kErfcPCoefficient = { + 2.46196981473530512524E-10, 5.64189564831068821977E-1, + 7.46321056442269912687E0, 4.86371970985681366614E1, + 1.96520832956077098242E2, 5.26445194995477358631E2, + 9.34528527171957607540E2, 1.02755188689515710272E3, + 5.57535335369399327526E2}; +std::array kErfcQCoefficient = { + 1.00000000000000000000E0, 1.32281951154744992508E1, + 8.67072140885989742329E1, 3.54937778887819891062E2, + 9.75708501743205489753E2, 1.82390916687909736289E3, + 2.24633760818710981792E3, 1.65666309194161350182E3, + 5.57535340817727675546E2}; +std::array kErfcRCoefficient = { + 5.64189583547755073984E-1, 1.27536670759978104416E0, + 5.01905042251180477414E0, 6.16021097993053585195E0, + 7.40974269950448939160E0, 2.97886665372100240670E0}; +std::array kErfcSCoefficient = { + 1.00000000000000000000E0, 2.26052863220117276590E0, + 9.39603524938001434673E0, 1.20489539808096656605E1, + 1.70814450747565897222E1, 9.60896809063285878198E0, + 3.36907645100081516050E0}; +std::array kErfTCoefficient = { + 9.60497373987051638749E0, 9.00260197203842689217E1, + 2.23200534594684319226E3, 7.00332514112805075473E3, + 5.55923013010394962768E4}; +std::array kErfUCoefficient = { + 1.00000000000000000000E0, 3.35617141647503099647E1, + 5.21357949780152679795E2, 4.59432382970980127987E3, + 2.26290000613890934246E4, 4.92673942608635921086E4}; +} // namespace + +// Evaluate the polynomial given coefficients and `x`. +// N.B. Coefficients should be supplied in decreasing order. +XlaOp EvaluatePolynomial(XlaOp x, + tensorflow::gtl::ArraySlice coefficients) { + XlaOp poly = ScalarLike(x, 0.0); + for (float c : coefficients) { + poly = poly * x + ScalarLike(x, c); + } + return poly; +} + +// Compute an approximation of the error function complement (1 - erf(x)). +XlaOp Erfc(XlaOp x) { + XlaOp abs_x = Abs(x); + XlaOp z = Exp(-x * x); + + XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient); + XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient); + XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient); + XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient); + + XlaOp y = Select(Lt(abs_x, ScalarLike(x, 8.0)), z * pp / pq, z * pr / ps); + + return Select(Lt(x, ScalarLike(x, 0.0)), ScalarLike(x, 2.0) - y, y); +} + +// Compute a polynomial approximation of the error function. +XlaOp Erf(XlaOp x) { + XlaOp z = x * x; + XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient); + XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient); + return x * pt / pu; +} + +// Approximation for the inverse error function from +// Giles, M., "Approximating the erfinv function". +// The approximation has the form: +// w = -log((1 - x) * (1 + x)) +// if ( w < 5 ) { +// w = w - 2.5 +// p = sum_{i=1}^n lq[i]*w^i +// } else { +// w = sqrt(w) - 3 +// p = sum_{i=1}^n gq[i]*w^i +// } +// return p*x +XlaOp ErfInv(XlaOp x) { + XlaBuilder* b = x.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x)); + constexpr int kDegree = 9; + constexpr std::array w_less_than_5_constants = { + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + constexpr std::array w_greater_than_5_constants = { + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; + + auto one = ScalarLike(x, 1.0); + auto w = -Log((one - x) * (one + x)); + + auto lt = Lt(w, ScalarLike(x, 5.0)); + auto coefficient = [&](int i) { + return Select(lt, + Broadcast(ScalarLike(x, w_less_than_5_constants[i]), + AsInt64Slice(shape.dimensions())), + Broadcast(ScalarLike(x, w_greater_than_5_constants[i]), + AsInt64Slice(shape.dimensions()))); + }; + w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0)); + auto p = coefficient(0); + for (int i = 1; i < kDegree; ++i) { + p = coefficient(i) + p * w; + } + return p * x; + }); +} + +namespace { +// Coefficients for the Lanczos approximation of the gamma function. The +// coefficients are uniquely determined by the choice of g and n (kLanczosGamma +// and kLanczosCoefficients.size() + 1). The coefficients below correspond to +// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7, +// 9] seemed to be the least sensitive to the quality of the log function. In +// particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5 +// for a particularly inaccurate log function. +static constexpr double kLanczosGamma = 7; // aka g +static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478; +static constexpr std::array kLanczosCoefficients = { + 676.520368121885098567009190444019, -1259.13921672240287047156078755283, + 771.3234287776530788486528258894, -176.61502916214059906584551354, + 12.507343278686904814458936853, -0.13857109526572011689554707, + 9.984369578019570859563e-6, 1.50563273514931155834e-7}; +} // namespace + +// Compute the Lgamma function using Lanczos' approximation from "A Precision +// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis +// series B. Vol. 1: +// lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z) +// t(z) = z + kLanczosGamma + 1/2 +// A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) +XlaOp Lgamma(XlaOp input) { + XlaOp one_half = ScalarLike(input, 0.5); + XlaOp one = ScalarLike(input, 1); + + XlaOp pi = ScalarLike(input, M_PI); + XlaOp log_pi = ScalarLike(input, std::log(M_PI)); + XlaOp log_sqrt_two_pi = ScalarLike(input, (std::log(2) + std::log(M_PI)) / 2); + + XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); + XlaOp log_lanczos_gamma_plus_one_half = + ScalarLike(input, std::log(kLanczosGamma + 0.5)); + + XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); + + // If the input is less than 0.5 use Gauss's reflection formula: + // gamma(x) = pi / sin(pi * x) * gamma(1 - x) + XlaOp need_to_reflect = Lt(Real(input), one_half); + XlaOp z = Select(need_to_reflect, -input, input - one); + + XlaOp x = base_lanczos_coeff; + for (int i = 0; i < kLanczosCoefficients.size(); ++i) { + XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); + XlaOp index = ScalarLike(input, i); + x = x + lanczos_coefficient / (z + index + one); + } + + // To improve accuracy on platforms with less-precise log implementations, + // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on + // the device. + // log(t) = log(kLanczosGamma + 0.5 + z) + // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) + XlaOp t = lanczos_gamma_plus_one_half + z; + XlaOp log_t = + log_lanczos_gamma_plus_one_half + Log1p(z / lanczos_gamma_plus_one_half); + + XlaOp log_y = log_sqrt_two_pi + (z + one_half) * log_t - t + Log(x); + + XlaOp reflection = log_pi - Log(Sin(pi * input)) - log_y; + XlaOp result = Select(need_to_reflect, reflection, log_y); + return result; +} + +// Compute the Digamma function using Lanczos' approximation from "A Precision +// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis +// series B. Vol. 1: +// digamma(z + 1) = log(t(z)) + A'(z) / A(z) - kLanczosGamma / t(z) +// t(z) = z + kLanczosGamma + 1/2 +// A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) +// A'(z) = sigma(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k)) +XlaOp Digamma(XlaOp input) { + XlaOp zero = ScalarLike(input, 0); + XlaOp one_half = ScalarLike(input, 0.5); + XlaOp one = ScalarLike(input, 1); + + XlaOp pi = ScalarLike(input, M_PI); + + XlaOp lanczos_gamma = ScalarLike(input, kLanczosGamma); + XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); + XlaOp log_lanczos_gamma_plus_one_half = + ScalarLike(input, std::log(kLanczosGamma + 0.5)); + + XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); + + // If the input is less than 0.5 use Gauss's reflection formula: + // digamma(x) = digamma(1 - x) - pi * cot(pi * x) + XlaOp need_to_reflect = Lt(Real(input), one_half); + XlaOp z = Select(need_to_reflect, -input, input - one); + + XlaOp num = zero; + XlaOp denom = base_lanczos_coeff; + for (int i = 0; i < kLanczosCoefficients.size(); ++i) { + XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); + XlaOp index = ScalarLike(input, i); + num = num - lanczos_coefficient / ((z + index + one) * (z + index + one)); + denom = denom + lanczos_coefficient / (z + index + one); + } + + // To improve accuracy on platforms with less-precise log implementations, + // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on + // the device. + // log(t) = log(kLanczosGamma + 0.5 + z) + // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) + XlaOp t = lanczos_gamma_plus_one_half + z; + XlaOp log_t = + log_lanczos_gamma_plus_one_half + Log1p(z / lanczos_gamma_plus_one_half); + + XlaOp y = log_t + num / denom - lanczos_gamma / t; + XlaOp reflection = y - pi * Cos(pi * input) / Sin(pi * input); + XlaOp result = Select(need_to_reflect, reflection, y); + return result; +} + +// Trigonometric functions. + +// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) +XlaOp Acos(XlaOp x) { + return ScalarLike(x, 2.0) * + Atan2(Sqrt(ScalarLike(x, 1.0) - x * x), ScalarLike(x, 1.0) + x); +} + +// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) +XlaOp Asin(XlaOp x) { + return ScalarLike(x, 2.0) * + Atan2(x, ScalarLike(x, 1.0) + Sqrt(ScalarLike(x, 1.0) - x * x)); +} + +XlaOp Atan(XlaOp x) { return Atan2(x, ScalarLike(x, 1.0)); } + +XlaOp Tan(XlaOp x) { return Sin(x) / Cos(x); } + +// Hyperbolic trigonometric functions. + +// acosh(x) = log(x + sqrt(x^2 - 1)) +// = log(x + sqrt((x+1)*(x-1))) +XlaOp Acosh(XlaOp x) { + return Log(x + Sqrt((x + ScalarLike(x, 1.0)) * (x - ScalarLike(x, 1.0)))); +} + +// asinh(x) = log(x + sqrt(x^2 + 1)) +XlaOp Asinh(XlaOp x) { return Log(x + Sqrt(x * x + ScalarLike(x, 1.0))); } + +// atanh(x) = 0.5 * log((1 + x) / (1 - x)) +XlaOp Atanh(XlaOp x) { + return Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) * + ScalarLike(x, 0.5); +} + +XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); } + +XlaOp Sinh(XlaOp x) { return (Exp(x) - Exp(-x)) * ScalarLike(x, 0.5); } + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h new file mode 100644 index 0000000000000000000000000000000000000000..13db2325569cf2e25e3ff1200adf4b2544dc2f73 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" + +namespace xla { + +// Computes the square root of 'operand'. +XlaOp Sqrt(XlaOp operand); + +// Computes the reciprocal of the square root of 'operand'. +XlaOp Rsqrt(XlaOp operand); + +// Computes the square of 'operand'. +XlaOp Square(XlaOp operand); + +// Computes the reciprocal of 'operand'. +XlaOp Reciprocal(XlaOp operand); + +// Evaluates a polynomial given coefficients and `x`. +// N.B. Coefficients should be supplied in decreasing order. +XlaOp EvaluatePolynomial(XlaOp x, + tensorflow::gtl::ArraySlice coefficients); + +// Computes an approximation of the error function complement (1 - erf(x)). +XlaOp Erfc(XlaOp x); + +// Computes an approximation of the error function. +XlaOp Erf(XlaOp x); + +// Computes an approximation of the inverse of the error function. +XlaOp ErfInv(XlaOp x); + +// Computes an approximation of the lgamma function. +XlaOp Lgamma(XlaOp input); + +// Computes an approximation of the digamma function. +XlaOp Digamma(XlaOp input); + +// Trigonometric functions + +// Computes the arc cosine of 'x'. +XlaOp Acos(XlaOp x); + +// Computes the arc sine of 'x'. +XlaOp Asin(XlaOp x); + +// Computes the arc tangent of 'x'. +XlaOp Atan(XlaOp x); + +// Computes the tangent of 'x'. +XlaOp Tan(XlaOp x); + +// Hyperbolic trigonometric functions + +// Computes the inverse hyperbolic cosine of 'x'. +XlaOp Acosh(XlaOp x); + +// Computes the inverse hyperbolic sine of 'x'. +XlaOp Asinh(XlaOp x); + +// Computes the inverse hyperbolic tangent of 'x'. +XlaOp Atanh(XlaOp x); + +// Computes the hyperbolic cosine of 'x'. +XlaOp Cosh(XlaOp x); + +// Computes the hyperbolic sine of 'x'. +XlaOp Sinh(XlaOp x); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..14c259a7fa2a47642663b65d2785e5bbdc040cfd --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -0,0 +1,140 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class MathTest : public ClientLibraryTestBase { + public: + ErrorSpec error_spec_{0.0001}; +}; + +XLA_TEST_F(MathTest, SqrtF32) { + XlaBuilder builder(TestName()); + Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F32); + + std::unique_ptr zero_data = + client_->TransferToServer(zero_literal).ConsumeValueOrDie(); + + XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero"); + Sqrt(zero); + + ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); +} + +XLA_TEST_F(MathTest, SquareTenValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Square(x); + + std::vector expected = {4.41, 6.76, 6.76, 16., 4.41, + 5.29, 25., 0.81, 5.76, 2.56}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(MathTest, ReciprocalTenValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Reciprocal(x); + + std::vector expected = { + 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048, + 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(MathTest, SqrtZeroes) { + XlaBuilder builder(TestName()); + auto x = ConstantR1(&builder, {0.0, -0.0}); + Sqrt(x); + + ComputeAndCompareR1(&builder, {0, 0}, {}, error_spec_); +} + +XLA_TEST_F(MathTest, SqrtSixValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1(&builder, {16.0, 1.0, 1024.0, 0.16, 0.2, 12345}); + Sqrt(x); + + std::vector expected = {4, 1, 32, 0.4, 0.4472, 111.1080}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(MathTest, Lgamma) { + XlaBuilder builder(TestName()); + auto x = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.5, 1.5, + 2.5, -1.5, -3.5, -5.5}); + Lgamma(x); + + std::vector expected = { + 0, + 0, + static_cast(std::log(2)), + static_cast(std::log(6)), + static_cast(std::log(24)), + static_cast(std::log(120)), + static_cast(std::log(M_PI) / 2), + static_cast(std::log(M_PI) / 2 - std::log(2)), + static_cast(std::log(M_PI) / 2 - std::log(4) + std::log(3)), + static_cast(std::log(M_PI) / 2 - std::log(3) + std::log(4)), + static_cast(std::log(M_PI) / 2 - std::log(105) + std::log(16)), + static_cast(std::log(M_PI) / 2 - std::log(10395) + std::log(64))}; + error_spec_ = ErrorSpec{0.001}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(MathTest, Digamma) { + XlaBuilder builder(TestName()); + auto x = ConstantR1(&builder, {1.0, 0.5, 1 / 3.0, 0.25, 1 / 6.0, 0.125, + 2.0, 3.0, 4.0, 6.0, 8.0, 9.0}); + Digamma(x); + + constexpr double euler_mascheroni = + 0.57721566490153286060651209008240243104215933593992; + std::vector expected = { + static_cast(-euler_mascheroni), + static_cast(-2 * std::log(2) - euler_mascheroni), + static_cast(-M_PI / 2 / std::sqrt(3) - 3 * std::log(3) / 2 - + euler_mascheroni), + static_cast(-M_PI / 2 - 3 * std::log(2) - euler_mascheroni), + static_cast(-M_PI * std::sqrt(3) / 2 - 2 * std::log(2) - + 3 * std::log(3) / 2 - euler_mascheroni), + static_cast( + -M_PI / 2 - 4 * std::log(2) - + (M_PI + std::log(2 + std::sqrt(2)) - std::log(2 - std::sqrt(2))) / + std::sqrt(2) - + euler_mascheroni), + static_cast(1 - euler_mascheroni), + static_cast(1.5 - euler_mascheroni), + static_cast(11 / 6.0 - euler_mascheroni), + static_cast(137 / 60.0 - euler_mascheroni), + static_cast(363 / 140.0 - euler_mascheroni), + static_cast(761 / 280.0 - euler_mascheroni)}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc new file mode 100644 index 0000000000000000000000000000000000000000..1c91237ae1574f92cda78c9bddc6f4ac1d68f47c --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/numeric.cc @@ -0,0 +1,137 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +namespace { + +template +XlaOp MakeIota(XlaBuilder* builder, int64 size) { + std::vector values(size); + for (int64 i = 0; i < size; ++i) { + values[i] = static_cast(i); + } + return ConstantR1(builder, values); +} + +} // namespace + +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) { + switch (type) { + case S8: + return MakeIota(builder, size); + case S16: + return MakeIota(builder, size); + case S32: + return MakeIota(builder, size); + case S64: + return MakeIota(builder, size); + case U8: + return MakeIota(builder, size); + case U16: + return MakeIota(builder, size); + case U32: + return MakeIota(builder, size); + case U64: + return MakeIota(builder, size); + case BF16: + return MakeIota(builder, size); + case F16: + return MakeIota(builder, size); + case F32: + return MakeIota(builder, size); + case F64: + return MakeIota(builder, size); + case C64: + return MakeIota(builder, size); + default: + return builder->ReportError( + InvalidArgument("Unimplemented type for Iota: %s.", + PrimitiveType_Name(type).c_str())); + } +} + +XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, + int64 n) { + auto a = Iota(builder, type, m); + auto b = Iota(builder, type, n); + auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0}); + return ConvertElementType(indicator, type); +} + +XlaOp GetMatrixDiagonal(XlaOp x) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + const int64 m = shape.dimensions(n_dims - 2); + const int64 n = shape.dimensions(n_dims - 1); + tensorflow::gtl::ArraySlice major_dims( + AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2); + auto a = Iota(builder, U32, n); + auto b = Iota(builder, U32, m); + auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + auto mask = Broadcast(indicator, major_dims); + + // TPUs don't support S64 add reduction at the moment. But fortunately + // OR-reductions work just as well for integers. + XlaComputation reducer = + primitive_util::IsIntegralType(shape.element_type()) + ? CreateScalarOrComputation(shape.element_type(), builder) + : CreateScalarAddComputation(shape.element_type(), builder); + + return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0), + reducer, {m >= n ? n_dims - 2 : n_dims - 1}); + }); +} + +XlaOp Triangle(XlaOp x, bool lower) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + const int64 m = shape.dimensions(n_dims - 2); + const int64 n = shape.dimensions(n_dims - 1); + tensorflow::gtl::ArraySlice major_dims( + AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2); + auto a = Iota(builder, U32, n); + auto b = Iota(builder, U32, m); + xla::XlaOp indicator; + if (lower) { + indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + } else { + indicator = Le(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + } + auto mask = Broadcast(indicator, major_dims); + + return Select(mask, x, Zeros(builder, shape)); + }); +} + +XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } + +XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/numeric.h new file mode 100644 index 0000000000000000000000000000000000000000..efd8cdc25724198633e0bf1c48c4e7d9e4b4c9e1 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/numeric.h @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns a rank 1 tensor of `type` containing values [0, 1, 2, ...]. +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); + +// Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere +// else. +XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); + +// Get the diagonals of the last two dimensions. If 'x' has shape +// [..., M, N], then the output has shape [..., min(M, N)], containing the +// diagonal elements (i.e., with indices [..., i, i]). +XlaOp GetMatrixDiagonal(XlaOp x); + +// Get the upper or lower triangle part of the last two dimensions +XlaOp Triangle(XlaOp x, bool lower); + +// Get the upper triangle part of the last two dimensions +XlaOp UpperTriangle(XlaOp x); + +// Get the lower triangle part of the last two dimensions +XlaOp LowerTriangle(XlaOp x); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8a96ec68d2dca8485215258b1f6731b934e6f2a8 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/numeric_test.cc @@ -0,0 +1,78 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class NumericTest : public ClientLibraryTestBase { + protected: + template + void TestMatrixDiagonal(); +}; + +// TODO(b/64798317): Delete this test case once xla::IotaGen is converted to +// xla::Iota. This test is already implemented for xla::IotaGen in +// xla/tests/iota_test.cc. +XLA_TEST_F(NumericTest, Iota) { + XlaBuilder builder(TestName()); + Iota(&builder, S32, 10); + + ComputeAndCompareR1(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {}); +} + +XLA_TEST_F(NumericTest, Triangle) { + XlaBuilder builder(TestName()); + Array3D input(2, 3, 4); + input.FillIota(0); + + XlaOp a; + auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); + LowerTriangle(a); + Array3D expected({{{0, 0, 0, 0}, {4, 5, 0, 0}, {8, 9, 10, 0}}, + {{12, 0, 0, 0}, {16, 17, 0, 0}, {20, 21, 22, 0}}}); + + ComputeAndCompareR3(&builder, expected, {a_data.get()}); +} + +template +void NumericTest::TestMatrixDiagonal() { + XlaBuilder builder("GetMatrixDiagonal"); + Array3D input(2, 3, 4); + input.FillIota(0); + + XlaOp a; + auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); + GetMatrixDiagonal(a); + Array2D expected({{0, 5, 10}, {12, 17, 22}}); + + ComputeAndCompareR2(&builder, expected, {a_data.get()}); +} + +XLA_TEST_F(NumericTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal(); } + +XLA_TEST_F(NumericTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal(); } + +XLA_TEST_F(NumericTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal(); } + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/pooling.cc b/tensorflow/compiler/xla/client/lib/pooling.cc new file mode 100644 index 0000000000000000000000000000000000000000..7199269a6c889f3589c1148687faf0bb2aaae90a --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/pooling.cc @@ -0,0 +1,183 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/pooling.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" + +namespace xla { + +namespace { + +// Common computation shared between AvgPool and AvgPoolGrad. Divide each +// element of an image by the count of elements that contributed to that +// element during pooling. +XlaOp AvgPoolDivideByCountWithGeneralPadding( + XlaOp sums, PrimitiveType dtype, + tensorflow::gtl::ArraySlice input_shape, + tensorflow::gtl::ArraySlice> spatial_padding, + tensorflow::gtl::ArraySlice ksize, + tensorflow::gtl::ArraySlice stride, + const TensorFormat& data_format) { + // The padding shouldn't be included in the counts. We use another + // ReduceWindow to find the right counts. + const int num_spatial_dims = spatial_padding.size(); + + std::vector input_dim_sizes(num_spatial_dims); + std::vector window_dims(num_spatial_dims); + std::vector window_ksize(num_spatial_dims); + std::vector window_stride(num_spatial_dims); + CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims) + << "Invalid number of spatial dimentions in data format specification"; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + input_dim_sizes[i] = input_shape[dim]; + window_dims[i] = dim; + window_ksize[i] = ksize[dim]; + window_stride[i] = stride[dim]; + } + + XlaBuilder* b = sums.builder(); + // Build a matrix of all 1s, with the same width/height as the input. + auto ones = Broadcast(One(b, dtype), input_dim_sizes); + PaddingConfig padding_config; + for (int i = 0; i < num_spatial_dims; ++i) { + auto dims = padding_config.add_dimensions(); + dims->set_edge_padding_low(spatial_padding[i].first); + dims->set_edge_padding_high(spatial_padding[i].second); + } + auto zero = Zero(b, dtype); + auto padded_ones = Pad(ones, zero, padding_config); + + // Perform a ReduceWindow with the same window size, strides, and padding + // to count the number of contributions to each result element. + auto counts = + ReduceWindow(padded_ones, zero, CreateScalarAddComputation(dtype, b), + window_ksize, window_stride, Padding::kValid); + + return Div(sums, counts, window_dims); +} + +// Sums all elements in the window specified by 'kernel_size' and 'stride'. +XlaOp ComputeSums(XlaOp operand, XlaOp init_value, + tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, + const TensorFormat& data_format) { + XlaBuilder* b = operand.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); + TF_ASSIGN_OR_RETURN(Shape init_shape, b->GetShape(init_value)); + PrimitiveType accumulation_type = init_shape.element_type(); + auto add_computation = CreateScalarAddComputation(accumulation_type, b); + return ReduceWindow(operand, init_value, add_computation, kernel_size, + stride, Padding::kValid); + }); +} + +// Creates a padding configuration out of spatial padding values. +PaddingConfig MakeSpatialPaddingConfig( + tensorflow::gtl::ArraySlice> spatial_padding, + tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, + const TensorFormat& data_format) { + const int num_spatial_dims = kernel_size.size() - 2; + PaddingConfig padding_config; + for (int i = 0; i < 2 + num_spatial_dims; ++i) { + padding_config.add_dimensions(); + } + CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims) + << "Invalid number of spatial dimentions in data format specification"; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + auto padding_dimension = padding_config.mutable_dimensions(dim); + padding_dimension->set_edge_padding_low(spatial_padding[i].first); + padding_dimension->set_edge_padding_high(spatial_padding[i].second); + } + return padding_config; +} + +} // namespace + +XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const TensorFormat& data_format) { + XlaBuilder* b = operand.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); + PrimitiveType dtype = operand_shape.element_type(); + auto max_computation = CreateScalarMaxComputation(dtype, b); + auto init_value = MinValue(b, dtype); + return ReduceWindow(operand, init_value, max_computation, kernel_size, + stride, padding); + }); +} + +XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, + tensorflow::gtl::ArraySlice> padding, + const TensorFormat& data_format, + const bool counts_include_padding) { + XlaBuilder* b = operand.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); + PrimitiveType dtype = operand_shape.element_type(); + auto init_value = Zero(b, dtype); + std::vector input_size(operand_shape.dimensions().begin(), + operand_shape.dimensions().end()); + auto padding_config = + MakeSpatialPaddingConfig(padding, kernel_size, stride, data_format); + auto padded_operand = Pad(operand, Zero(b, dtype), padding_config); + auto pooled = ComputeSums(padded_operand, init_value, kernel_size, stride, + data_format); + if (counts_include_padding) { + // If counts include padding, all windows have the same number of elements + // contributing to each average. Divide by the window size everywhere to + // get the average. + int64 window_size = + std::accumulate(kernel_size.begin(), kernel_size.end(), 1, + [](int64 x, int64 y) { return x * y; }); + + auto divisor = ConstantR0WithType(b, dtype, window_size); + return pooled / divisor; + } else { + return AvgPoolDivideByCountWithGeneralPadding( + pooled, dtype, input_size, padding, kernel_size, stride, data_format); + } + }); +} + +std::vector> MakeSpatialPadding( + tensorflow::gtl::ArraySlice input_size, + tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const TensorFormat& data_format) { + const int num_spatial_dims = kernel_size.size() - 2; + std::vector input_spatial_dimensions; + std::vector kernel_size_spatial_dimensions; + std::vector stride_spatial_dimensions; + CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims) + << "Invalid number of spatial dimentions in data format specification"; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + input_spatial_dimensions.push_back(input_size[dim]); + kernel_size_spatial_dimensions.push_back(kernel_size[dim]); + stride_spatial_dimensions.push_back(stride[dim]); + } + return MakePadding(input_spatial_dimensions, kernel_size_spatial_dimensions, + stride_spatial_dimensions, padding); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/pooling.h b/tensorflow/compiler/xla/client/lib/pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..1699c585d3b09a306c21cfa797a9023a8463bd1f --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/pooling.h @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace xla { + +// Tensor format for reduce window operations. +class TensorFormat { + public: + TensorFormat(int batch_dimension, int feature_dimension, + tensorflow::gtl::ArraySlice spatial_dimensions) + : batch_dimension_(batch_dimension), + feature_dimension_(feature_dimension), + spatial_dimensions_(spatial_dimensions.begin(), + spatial_dimensions.end()) {} + + int batch_dimension() const { return batch_dimension_; } + + int feature_dimension() const { return feature_dimension_; } + + int spatial_dimension(int dim) const { return spatial_dimensions_[dim]; } + + int num_spatial_dims() const { return spatial_dimensions_.size(); } + + private: + // The number of the dimension that represents the batch. + int batch_dimension_; + // The number of the dimension that represents the features. + int feature_dimension_; + // The dimension numbers for the spatial dimensions. + tensorflow::gtl::InlinedVector spatial_dimensions_; +}; + +// Computes the max pool of 'operand'. +XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const TensorFormat& data_format); + +// Computes the average pool of 'operand'. +XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, + tensorflow::gtl::ArraySlice> padding, + const TensorFormat& data_format, + const bool counts_include_padding); + +// Returns the list of low and high padding elements in each spatial dimension +// for the given 'padding' specification. +std::vector> MakeSpatialPadding( + tensorflow::gtl::ArraySlice input_size, + tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const TensorFormat& data_format); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ diff --git a/tensorflow/compiler/xla/client/lib/pooling_test.cc b/tensorflow/compiler/xla/client/lib/pooling_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b4553b60db555ad7c2ab6b695236df745e30683 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/pooling_test.cc @@ -0,0 +1,185 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/pooling.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +TensorFormat MakeNCHWFormat(int num_spatial_dims) { + tensorflow::gtl::InlinedVector spatial_dimensions; + for (int i = 0; i < num_spatial_dims; ++i) { + spatial_dimensions.push_back(i + 2); + } + return TensorFormat(/*batch_dimension=*/0, /*feature_dimension=*/1, + /*spatial_dimensions=*/spatial_dimensions); +} + +std::vector> MakeGeneralPadding( + XlaOp input, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const xla::TensorFormat& data_format) { + XlaBuilder* b = input.builder(); + Shape operand_shape = b->GetShape(input).ValueOrDie(); + std::vector input_size(operand_shape.dimensions().begin(), + operand_shape.dimensions().end()); + return MakeSpatialPadding(input_size, kernel_size, stride, padding, + data_format); +} + +// Add singleton batch and feature dimensions to spatial dimensions, according +// to 'data_format' specification. +std::vector ExpandWithBatchAndFeatureDimensions( + tensorflow::gtl::ArraySlice spatial_dim_sizes, + const xla::TensorFormat& data_format) { + const int num_spatial_dims = spatial_dim_sizes.size(); + std::vector tensor_sizes(num_spatial_dims + 2, 1); + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + tensor_sizes[dim] = spatial_dim_sizes[i]; + } + return tensor_sizes; +} + +class PoolingTest : public ClientLibraryTestBase { + public: + ErrorSpec error_spec_{0.0001}; +}; + +XLA_TEST_F(PoolingTest, MaxPool2D) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + MaxPool(input, kernel_size, stride, Padding::kValid, data_format); + + ComputeAndCompareR4(&builder, {{{{5, 4}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, MaxPool2DWithPadding) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + MaxPool(input, kernel_size, stride, Padding::kSame, data_format); + + ComputeAndCompareR4(&builder, {{{{5, 4, 5}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, MaxPool2DWithPaddingAndStride) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format); + MaxPool(input, kernel_size, stride, Padding::kSame, data_format); + + ComputeAndCompareR4(&builder, {{{{5, 4, 4, 5, 5}, {5, 4, 3, 2, 1}}}}, + {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2D) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kValid, + data_format); + AvgPool(input, kernel_size, stride, padding, data_format, + /*counts_include_padding=*/true); + + ComputeAndCompareR4(&builder, {{{{3, 3}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DWithPadding) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kSame, + data_format); + AvgPool(input, kernel_size, stride, padding, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, {{{{3, 3, 3}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DWithPaddingAndStride) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format); + auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kSame, + data_format); + AvgPool(input, kernel_size, stride, padding, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, + {{{{3, 3, 3, 3, 3}, {4.5, 3.5, 2.5, 1.5, 1}}}}, {}, + error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DWithGeneralPaddingCountNotIncludePadding) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({3, 3}, data_format); + auto stride = kernel_size; + AvgPool(input, kernel_size, stride, {{1, 1}, {2, 1}}, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, {{{{3, 3}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, + AvgPool2DWithGeneralPaddingCountNotIncludePaddingAndStride) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({3, 3}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + AvgPool(input, kernel_size, stride, {{2, 1}, {1, 1}}, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, {{{{1.5, 3, 4.5}, {3, 3, 3}}}}, {}, + error_spec_); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc new file mode 100644 index 0000000000000000000000000000000000000000..6ef81689489d8117d5951bcb75693c2e3413e4d6 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -0,0 +1,150 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/casts.h" + +namespace xla { +namespace { + +// Rotates a 32-bit integer 'v' left by 'distance' bits. +XlaOp RotateLeftS32(XlaOp v, int distance) { + return (v << ConstantR0(v.builder(), distance)) | + ShiftRightLogical(v, ConstantR0(v.builder(), 32 - distance)); +} + +using ThreeFry2x32State = std::array; + +// Implements the ThreeFry counter-based PRNG algorithm. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) { + XlaBuilder* builder = input[0].builder(); + // Rotation distances specified by the Threefry2x32 algorithm. + constexpr std::array rotations = {13, 15, 26, 6, 17, 29, 16, 24}; + ThreeFry2x32State x; + + std::array ks; + // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. + ks[2] = ConstantR0(builder, 0x1BD11BDA); + for (int i = 0; i < 2; ++i) { + ks[i] = key[i]; + x[i] = input[i]; + ks[2] = ks[2] ^ key[i]; + } + + x[0] = x[0] + ks[0]; + x[1] = x[1] + ks[1]; + + // Performs a single round of the Threefry2x32 algorithm, with a rotation + // amount 'rotation'. + auto round = [](ThreeFry2x32State v, int rotation) { + v[0] = v[0] + v[1]; + v[1] = RotateLeftS32(v[1], rotation); + v[1] = v[0] ^ v[1]; + return v; + }; + + // There are no known statistical flaws with 13 rounds of Threefry2x32. + // We are conservative and use 20 rounds. + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = x[0] + ks[1]; + x[1] = x[1] + ks[2] + ConstantR0(builder, 1); + + x = round(x, rotations[4]); + x = round(x, rotations[5]); + x = round(x, rotations[6]); + x = round(x, rotations[7]); + x[0] = x[0] + ks[2]; + x[1] = x[1] + ks[0] + ConstantR0(builder, 2); + + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = x[0] + ks[0]; + x[1] = x[1] + ks[1] + ConstantR0(builder, 3); + + x = round(x, rotations[4]); + x = round(x, rotations[5]); + x = round(x, rotations[6]); + x = round(x, rotations[7]); + x[0] = x[0] + ks[1]; + x[1] = x[1] + ks[2] + ConstantR0(builder, 4); + + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = x[0] + ks[2]; + x[1] = x[1] + ks[0] + ConstantR0(builder, 5); + + return x; +} + +} // namespace + +XlaOp StatelessRngUniform(std::array seeds, const Shape& shape, + XlaOp minval, XlaOp maxval) { + XlaBuilder* builder = seeds[0].builder(); + if (shape.element_type() != F32) { + return builder->ReportError(Unimplemented( + "Types other than F32 are not implemented by StatelessRngUniform.")); + } + ThreeFry2x32State key = seeds; + const int64 size = ShapeUtil::ElementsIn(shape); + + const int64 half_size = CeilOfRatio(size, 2); + const bool size_is_odd = (half_size * 2 != size); + + // Fill the generator inputs with unique counter values. + ThreeFry2x32State inputs; + inputs[0] = Iota(builder, S32, half_size); + inputs[1] = inputs[0] + ConstantR0(builder, half_size); + ThreeFry2x32State outputs = ThreeFry2x32(inputs, key); + + if (size_is_odd) { + outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1}); + } + + auto bits = Reshape(ConcatInDim(builder, outputs, 0), + AsInt64Slice(shape.dimensions())); + + // Form 23 random mantissa bits, with a leading 1 bit. The leading 1 bit + // forces the random bits into the mantissa. + constexpr int kFloatBits = 32; + constexpr int kMantissaBits = 23; + bits = ShiftRightLogical( + bits, ConstantR0(builder, kFloatBits - kMantissaBits)) | + ConstantR0(builder, tensorflow::bit_cast(1.0f)); + auto floats = BitcastConvertType(bits, F32); + + // We have a floating point number in the range [1.0, 2.0). + // Subtract 1.0f to shift to the range [0.0, 1.0) + floats = floats - ConstantR0(builder, 1.0f); + // Multiply and add to shift to the range [minval, maxval). + return floats * (maxval - minval) + minval; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/prng.h b/tensorflow/compiler/xla/client/lib/prng.h new file mode 100644 index 0000000000000000000000000000000000000000..ad000b1fa1d0655c8fccc0bb33379f2499b77f26 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/prng.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns a tensor containing 'shape' random values uniformly distributed in +// the range [minval, maxval). Requires 2 32-bit integer seeds. +// Currently only 'shape's of type F32 are implemented. +XlaOp StatelessRngUniform(std::array seeds, const Shape& shape, + XlaOp minval, XlaOp maxval); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc new file mode 100644 index 0000000000000000000000000000000000000000..a904be259a3870a679b2c4699ec01e2a11b1ce46 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" + +namespace xla { + +XlaOp TopK(XlaOp input, int64 k) { + XlaBuilder* const builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); + int last_dim = input_shape.dimensions_size() - 1; + int last_dim_size = input_shape.dimensions(last_dim); + + XlaOp iota_s32 = Iota(builder, S32, last_dim_size); + auto input_dims = input_shape.dimensions(); + std::vector broadcast_dims(input_dims.begin(), input_dims.end() - 1); + XlaOp broadcast_s32 = Broadcast(iota_s32, broadcast_dims); + XlaOp sort_result = Sort(Neg(input), broadcast_s32); + std::vector start_indices(input_shape.dimensions_size(), 0); + std::vector limit_indices(input_dims.begin(), input_dims.end()); + limit_indices[last_dim] = k; + std::vector strides(input_shape.dimensions_size(), 1); + + XlaOp values = Neg(Slice(GetTupleElement(sort_result, 0), start_indices, + limit_indices, strides)); + XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices, + limit_indices, strides); + return Tuple(builder, {values, indices}); + }); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/sorting.h b/tensorflow/compiler/xla/client/lib/sorting.h new file mode 100644 index 0000000000000000000000000000000000000000..b9dfafdd6f957ae050e0f5dbd076d5288235b490 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/sorting.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns a tuple composed of the top `k` values and corresponding indices in +// `input`. Output values are in descending order, from largest to smallest. +XlaOp TopK(XlaOp input, int64 k); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fef98c9923096e21a755c6d730de2c7c10852b2d --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace { + +using SortingTest = ClientLibraryTestBase; + +XLA_TEST_F(SortingTest, TopK3From8Values) { + XlaBuilder builder(TestName()); + auto x = + ConstantR1(&builder, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}); + xla::GetTupleElement(xla::TopK(x, 3), 0); + ComputeAndCompareR1(&builder, {7.0, 6.0, 5.0}, {}); +} + +XLA_TEST_F(SortingTest, TopK3From8Indices) { + XlaBuilder builder(TestName()); + auto x_rev = + ConstantR1(&builder, {7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0}); + xla::GetTupleElement(xla::TopK(x_rev, 3), 1); + ComputeAndCompareR1(&builder, {0, 1, 2}, {}); +} + +XLA_TEST_F(SortingTest, TopKFullSort) { + XlaBuilder builder(TestName()); + const int kSize = 16; + std::mt19937 eng; + std::uniform_real_distribution u_dist(0.0, 100.0); + auto gen = std::bind(u_dist, eng); + std::vector inputs(kSize); + std::generate(inputs.begin(), inputs.end(), gen); + auto x = ConstantR1(&builder, inputs); + xla::GetTupleElement(xla::TopK(x, kSize), 0); + + std::sort(inputs.begin(), inputs.end(), std::greater()); + ComputeAndCompareR1(&builder, inputs, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 3380af9f303b1dc2cec09aa37410ec40cdeaa526..081fec7ad92958aa285e4be41394d7b1876e0815 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/testing.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -48,15 +48,15 @@ int64 DataSizeOfShape(const Shape& shape) { // Creates a XlaOp for an op what generates fake data with the given shape. XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { if (ShapeUtil::IsArray(shape)) { - return builder->Broadcast( - builder->ConstantLiteral(Literal::One(shape.element_type())), + return Broadcast( + ConstantLiteral(builder, LiteralUtil::One(shape.element_type())), AsInt64Slice(shape.dimensions())); } std::vector parts; for (const Shape& s : shape.tuple_shapes()) { parts.push_back(BuildFakeDataOpOnDevice(s, builder)); } - return builder->Tuple(parts); + return Tuple(builder, parts); } std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, @@ -98,14 +98,13 @@ std::vector> MakeFakeArgumentsOrDie( << "Computation should have progran shape."; auto program_shape = computation.proto().program_shape(); - // For every (unbound) parameter that the computation wants, we manufacture - // some arbitrary data so that we can invoke the computation. - std::vector> fake_arguments; - for (const Shape& parameter : program_shape.parameters()) { - fake_arguments.push_back(MakeFakeDataOrDie(parameter, client)); - } - - return fake_arguments; + // Create and run a program which produces a tuple with one element per + // parameter, then return the tuple's constituent buffers. + std::vector param_shapes(program_shape.parameters().begin(), + program_shape.parameters().end()); + auto fake_input_tuple = + MakeFakeDataOrDie(ShapeUtil::MakeTupleShape(param_shapes), client); + return client->DeconstructTuple(*fake_input_tuple).ValueOrDie(); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index dc613099e2b42a60d0c11a654ab5cd41f8bd4f6f..03695ce2a339735e3e49522f4fe1bbf2d83a3834 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index ae0308020d014e038d2f0fd7de6c5f372d6cbed1..cffb24e29beda6a8c40dca2fe709be22892dd489 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -18,10 +18,12 @@ limitations under the License. #include #include "llvm/ADT/Triple.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/source_map_util.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/status_macros.h" using xla::source_map_util::InvalidParameterArgument; @@ -29,8 +31,8 @@ using xla::source_map_util::InvalidParameterArgument; namespace xla { namespace { -StatusOr BorrowStreamForDevice(int device_ordinal, - Backend* backend) { +StatusOr BorrowStreamForDevice(int device_ordinal, + Backend* backend) { if (device_ordinal < 0) { device_ordinal = backend->default_device_ordinal(); } @@ -51,24 +53,17 @@ LocalExecutable::LocalExecutable(std::unique_ptr executable, Status LocalExecutable::ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& run_options, const Backend& backend) { - const ComputationLayout& host_computation_layout = - executable_->module_config().host_entry_computation_layout(); - const ComputationLayout& device_computation_layout = - executable_->module_config().device_entry_computation_layout(); + const ComputationLayout& computation_layout = + executable_->module_config().entry_computation_layout(); // Check argument number, shapes, and layouts. - if (arguments.size() != host_computation_layout.parameter_count()) { + if (arguments.size() != computation_layout.parameter_count()) { return InvalidArgument( "invalid number of arguments for computation: expected %d, got %zu", - host_computation_layout.parameter_count(), arguments.size()); - } - if (arguments.size() != device_computation_layout.parameter_count()) { - return InvalidArgument( - "invalid number of arguments for computation: expected %d, got %zu", - device_computation_layout.parameter_count(), arguments.size()); + computation_layout.parameter_count(), arguments.size()); } for (int i = 0; i < arguments.size(); ++i) { - if (!host_computation_layout.parameter_layout(i).MatchesLayoutInShape( + if (!computation_layout.parameter_layout(i).MatchesLayoutInShape( arguments[i]->on_host_shape())) { return InvalidParameterArgument( executable_.get(), i, @@ -76,24 +71,10 @@ Status LocalExecutable::ValidateExecutionOptions( "parameter " "%d: want %s, got %s", i, - ShapeUtil::HumanString( - host_computation_layout.parameter_layout(i).shape()) + ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape()) .c_str(), ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str()); } - if (!device_computation_layout.parameter_layout(i).MatchesLayoutInShape( - arguments[i]->on_device_shape())) { - return InvalidParameterArgument( - executable_.get(), i, - "Argument does not match device shape or layout of computation " - "parameter " - "%d: want %s, got %s", - i, - ShapeUtil::HumanString( - device_computation_layout.parameter_layout(i).shape()) - .c_str(), - ShapeUtil::HumanString(arguments[i]->on_device_shape()).c_str()); - } } if (run_options.stream() != nullptr) { @@ -120,11 +101,14 @@ Status LocalExecutable::ValidateExecutionOptions( } } - // Verify that the device the executable was built for is equivalent to the - // device it will run on. - int run_device_ordinal = run_options.device_ordinal() == -1 - ? backend_->default_device_ordinal() - : run_options.device_ordinal(); + // Verify that the device the executable was built for is equivalent + // to the device it will run on. + int run_device_ordinal = run_options.device_ordinal(); + if (run_device_ordinal == -1) { + run_device_ordinal = run_options.stream() != nullptr + ? run_options.stream()->parent()->device_ordinal() + : backend_->default_device_ordinal(); + } TF_ASSIGN_OR_RETURN(bool devices_equivalent, backend_->devices_equivalent( run_device_ordinal, build_options_.device_ordinal())); @@ -162,7 +146,7 @@ StatusOr LocalExecutable::Run( TF_RETURN_IF_ERROR( ValidateExecutionOptions(arguments, run_options, *backend_)); - Backend::StreamPtr stream; + StreamPool::Ptr stream; if (run_options.stream() == nullptr) { // NB! The lifetime of `stream` needs to match the lifetime of // `actual_options` (otherwise we will end up using a returned stream in @@ -230,10 +214,9 @@ Status LocalExecutable::RecordResult(const ShapedBuffer* result, StatusOr> LocalExecutable::LiteralFromShapedBuffer( const ShapedBuffer& shaped_buffer) { - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - backend_->stream_executor(shaped_buffer.device_ordinal())); - return backend_->transfer_manager()->TransferLiteralFromDevice(executor, + TF_ASSIGN_OR_RETURN(auto stream, + backend_->BorrowStream(shaped_buffer.device_ordinal())); + return backend_->transfer_manager()->TransferLiteralFromDevice(stream.get(), shaped_buffer); } @@ -288,19 +271,18 @@ StatusOr LocalClient::LiteralToShapedBuffer( TF_ASSIGN_OR_RETURN(auto scoped_buffer, backend().transfer_manager()->AllocateScopedShapedBuffer( literal.shape(), allocator, device_ordinal)); - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - backend().stream_executor(device_ordinal)); + TF_ASSIGN_OR_RETURN(auto stream, + mutable_backend()->BorrowStream(device_ordinal)); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - executor, literal, scoped_buffer)); + stream.get(), literal, scoped_buffer)); return std::move(scoped_buffer); } StatusOr> LocalClient::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - backend().stream_executor(shaped_buffer.device_ordinal())); - return backend().transfer_manager()->TransferLiteralFromDevice(executor, + TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream( + shaped_buffer.device_ordinal())); + return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(), shaped_buffer); } @@ -321,7 +303,7 @@ StatusOr> LocalClient::TransferFromOutfeedLocal( const Shape& shape, int device_ordinal) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); - auto literal = MakeUnique(); + auto literal = Literal::CreateFromShape(shape); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed( executor, shape, literal.get())); return std::move(literal); diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 4d9e0d7cd9d6ddebead1e12b23e94b529038039b..ae23809261757c637ab4aec036750c371ac60cdc 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc similarity index 60% rename from tensorflow/compiler/xla/client/xla_client/xla_builder.cc rename to tensorflow/compiler/xla/client/xla_builder.cc index ae8fbdb2dc8b3f4fc21bcfef9692645f7e1d48b5..b3b00e2fffe1196b36190ec72d1425bae4e4e276 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include #include @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/sharding_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -43,22 +45,38 @@ int64 GetUniqueId() { return id; } -// Returns true if an instruction with the given opcode can be the root of the -// computation. -bool CanBeRoot(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kSend: - case HloOpcode::kSendDone: - case HloOpcode::kOutfeed: - case HloOpcode::kTrace: - return false; - default: - return true; - } -} - } // namespace +XlaOp operator-(const XlaOp& x) { return Neg(x); } +XlaOp operator+(const XlaOp& x, const XlaOp& y) { return Add(x, y); } +XlaOp operator-(const XlaOp& x, const XlaOp& y) { return Sub(x, y); } +XlaOp operator*(const XlaOp& x, const XlaOp& y) { return Mul(x, y); } +XlaOp operator/(const XlaOp& x, const XlaOp& y) { return Div(x, y); } +XlaOp operator%(const XlaOp& x, const XlaOp& y) { return Rem(x, y); } + +XlaOp operator~(const XlaOp& x) { return Not(x); } +XlaOp operator&(const XlaOp& x, const XlaOp& y) { return And(x, y); } +XlaOp operator|(const XlaOp& x, const XlaOp& y) { return Or(x, y); } +XlaOp operator^(const XlaOp& x, const XlaOp& y) { return Xor(x, y); } +XlaOp operator<<(const XlaOp& x, const XlaOp& y) { return ShiftLeft(x, y); } + +XlaOp operator>>(const XlaOp& x, const XlaOp& y) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + if (!ShapeUtil::ElementIsIntegral(shape)) { + return InvalidArgument( + "Argument to >> operator does not have an integral type (%s).", + ShapeUtil::HumanString(shape).c_str()); + } + if (ShapeUtil::ElementIsSigned(shape)) { + return ShiftRightArithmetic(x, y); + } else { + return ShiftRightLogical(x, y); + } + }); +} + StatusOr XlaBuilder::GetShape(const XlaOp& op) const { TF_RETURN_IF_ERROR(first_error_); @@ -81,7 +99,7 @@ XlaBuilder::XlaBuilder(const string& computation_name) XlaBuilder::~XlaBuilder() {} -void XlaBuilder::NoteError(const Status& error) { +XlaOp XlaBuilder::ReportError(const Status& error) { CHECK(!error.ok()); if (die_immediately_on_error_) { LOG(FATAL) << "error building computation: " << error; @@ -91,43 +109,31 @@ void XlaBuilder::NoteError(const Status& error) { first_error_ = error; first_error_backtrace_.CreateCurrent(/*skip_count=*/1); } + return XlaOp(this); } -XlaOp XlaBuilder::NoteErrorOrReturn( - const std::function()>& op_creator) { +XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr& op) { if (!first_error_.ok()) { - return {}; + return XlaOp(this); } - auto op = op_creator(); if (!op.ok()) { - NoteError(op.status()); - return {}; + return ReportError(op.status()); } - return op.ConsumeValueOrDie(); + return op.ValueOrDie(); } -StatusOr XlaBuilder::GetProgramShape(int64* root_id) const { - TF_RETURN_IF_ERROR(first_error_); +XlaOp XlaBuilder::ReportErrorOrReturn( + const std::function()>& op_creator) { + return ReportErrorOrReturn(op_creator()); +} - TF_RET_CHECK(root_id != nullptr); +StatusOr XlaBuilder::GetProgramShape(int64 root_id) const { + TF_RETURN_IF_ERROR(first_error_); + TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size())); ProgramShape program_shape; - // Not all instructions can be roots. Walk backwards from the last added - // instruction until a valid root is found. - int64 index = instructions_.size() - 1; - for (; index >= 0; index--) { - TF_ASSIGN_OR_RETURN(HloOpcode opcode, - StringToHloOpcode(instructions_[index].opcode())); - if (CanBeRoot(opcode)) { - break; - } - } - if (index < 0) { - return FailedPrecondition("no root instruction was found"); - } - *root_id = instructions_[index].id(); - *program_shape.mutable_result() = instructions_[index].shape(); + *program_shape.mutable_result() = instructions_[root_id].shape(); // Check that the parameter numbers are continuous from 0, and add parameter // shapes and names to the program shape. @@ -152,8 +158,15 @@ StatusOr XlaBuilder::GetProgramShape(int64* root_id) const { } StatusOr XlaBuilder::GetProgramShape() const { - int64 root; - return GetProgramShape(&root); + TF_RET_CHECK(!instructions_.empty()); + return GetProgramShape(instructions_.back().id()); +} + +StatusOr XlaBuilder::GetProgramShape(XlaOp root) const { + if (root.builder_ != this) { + return InvalidArgument("Given root operation is not in this computation."); + } + return GetProgramShape(root.handle()); } void XlaBuilder::IsConstantVisitor(const int64 op_handle, @@ -207,7 +220,7 @@ XlaComputation XlaBuilder::BuildAndNoteError() { DCHECK(parent_builder_ != nullptr); auto build_status = Build(); if (!build_status.ok()) { - parent_builder_->NoteError( + parent_builder_->ReportError( AddStatus(build_status.status(), tensorflow::strings::StrCat("error from: ", name_))); return {}; @@ -221,17 +234,29 @@ StatusOr XlaBuilder::Build() { first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); return AppendStatus(first_error_, backtrace); } + return Build(instructions_.back().id()); +} + +StatusOr XlaBuilder::Build(XlaOp root) { + if (root.builder_ != this) { + return InvalidArgument("Given root operation is not in this computation."); + } + return Build(root.handle()); +} + +StatusOr XlaBuilder::Build(int64 root_id) { + if (!first_error_.ok()) { + string backtrace; + first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); + return AppendStatus(first_error_, backtrace); + } HloComputationProto entry; entry.set_id(GetUniqueId()); // Give the computation a global unique id. entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique. - { - int64 root_id; - TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), - GetProgramShape(&root_id)); - entry.set_root_id(root_id); - } + TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), GetProgramShape(root_id)); + entry.set_root_id(root_id); for (auto& instruction : instructions_) { // Ensures that the instruction names are unique among the whole graph. @@ -315,7 +340,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, } XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), @@ -327,7 +352,7 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { XlaOp XlaBuilder::BinaryOp( HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -383,7 +408,7 @@ XlaOp XlaBuilder::BinaryOp( XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, const XlaOp& ehs) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -430,7 +455,7 @@ XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, } XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = literal.shape(); *instr.mutable_literal() = literal.ToProto(); @@ -440,7 +465,7 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { XlaOp XlaBuilder::Call(const XlaComputation& computation, tensorflow::gtl::ArraySlice operands) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); @@ -461,7 +486,7 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation, XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, const string& name) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (!parameter_numbers_.insert(parameter_number).second) { return InvalidArgument("parameter %lld already registered", @@ -476,7 +501,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, XlaOp XlaBuilder::Broadcast( const XlaOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( const Shape& shape, @@ -498,6 +523,14 @@ XlaOp XlaBuilder::Broadcast( }); } +XlaOp XlaBuilder::BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice broadcast_dimensions) { + return ReportErrorOrReturn([&]() -> StatusOr { + return InDimBroadcast(shape, operand, broadcast_dimensions); + }); +} + StatusOr XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) { TF_RETURN_IF_ERROR(first_error_); @@ -510,7 +543,7 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand, tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -530,7 +563,7 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand, XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); std::vector starts(ShapeUtil::Rank(shape), 0); std::vector limits(shape.dimensions().begin(), @@ -545,7 +578,7 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -566,7 +599,7 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -584,7 +617,7 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice operands, int64 dimension) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; @@ -603,7 +636,7 @@ XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice operands, XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, const PaddingConfig& padding_config) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -624,7 +657,7 @@ XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, XlaOp XlaBuilder::Reshape(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(const Shape& shape, ShapeInference::InferReshapeShape( @@ -638,7 +671,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, XlaOp XlaBuilder::Reshape(const XlaOp& operand, tensorflow::gtl::ArraySlice new_sizes) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand)); std::vector dimensions(shape.dimensions_size()); std::iota(dimensions.begin(), dimensions.end(), 0); @@ -648,7 +681,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, XlaOp XlaBuilder::Collapse(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { if (dimensions.size() <= 1) { // Not collapsing anything, trivially we can return the operand versus // enqueueing a trivial reshape. @@ -690,21 +723,29 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, } void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { - NoteErrorOrReturn([&]() -> StatusOr { + ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = ShapeUtil::MakeNil(); - *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto(); + *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto(); return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand}); }); } XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) { - return TernaryOp(HloOpcode::kSelect, pred, on_true, on_false); + return ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& true_shape, GetShape(on_true)); + TF_ASSIGN_OR_RETURN(const Shape& false_shape, GetShape(on_false)); + TF_RET_CHECK(ShapeUtil::IsTuple(true_shape) == + ShapeUtil::IsTuple(false_shape)); + HloOpcode opcode = ShapeUtil::IsTuple(true_shape) ? HloOpcode::kTupleSelect + : HloOpcode::kSelect; + return TernaryOp(opcode, pred, on_true, on_false); + }); } XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice elements) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); @@ -718,7 +759,7 @@ XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice elements) { } XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& tuple_shape, GetShape(tuple_data)); if (!ShapeUtil::IsTuple(tuple_shape)) { @@ -767,7 +808,7 @@ XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, } XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); DotDimensionNumbers dimension_numbers; @@ -780,7 +821,7 @@ XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -859,7 +900,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -905,7 +946,7 @@ XlaOp XlaBuilder::ConvGeneralDilated( tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -992,7 +1033,7 @@ StatusOr XlaBuilder::MakeWindow( XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, const tensorflow::gtl::ArraySlice fft_length) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1009,23 +1050,144 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, } XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument("Given shape to Infeed must have a layout"); } - *instr.mutable_shape() = shape; + const Shape infeed_instruction_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); + *instr.mutable_shape() = infeed_instruction_shape; instr.set_infeed_config(config); - return AddInstruction(std::move(instr), HloOpcode::kInfeed); + + if (ShapeUtil::IsArray(shape) && sharding() && + sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) { + // TODO(b/110793772): Support tiled array-shaped infeeds. + return InvalidArgument( + "Tiled sharding is not yet supported for array-shaped infeeds"); + } + + if (sharding() && + sharding()->type() == OpSharding::Type::OpSharding_Type_REPLICATED) { + return InvalidArgument( + "Replicated sharding is not yet supported for infeeds"); + } + + // The sharding is set by the client according to the data tuple shape. + // However, the shape of the infeed instruction is a tuple containing the + // data and a token. For tuple sharding type, the sharding must be changed + // to accommodate the token. + XlaOp infeed; + if (sharding() && + sharding()->type() == OpSharding::Type::OpSharding_Type_TUPLE) { + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + OpSharding infeed_instruction_sharding = *sharding(); + // Arbitrarily assign the token to device 0. + *infeed_instruction_sharding.add_tuple_shardings() = + sharding_builder::AssignDevice(0); + XlaScopedShardingAssignment scoped_sharding(this, + infeed_instruction_sharding); + TF_ASSIGN_OR_RETURN( + infeed, AddInstruction(std::move(instr), HloOpcode::kInfeed, {})); + } else { + TF_ASSIGN_OR_RETURN( + infeed, AddInstruction(std::move(instr), HloOpcode::kInfeed, {})); + } + + // The infeed instruction produces a tuple of the infed data and a token + // type. Return XLA op containing the data. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto infeed_data; + *infeed_data.mutable_shape() = shape; + infeed_data.set_tuple_index(0); + return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement, + {infeed}); + }); +} + +XlaOp XlaBuilder::InfeedWithToken(const XlaOp& token, const Shape& shape, + const string& config) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument("Given shape to Infeed must have a layout"); + } + const Shape infeed_instruction_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); + *instr.mutable_shape() = infeed_instruction_shape; + instr.set_infeed_config(config); + + if (ShapeUtil::IsArray(shape) && sharding() && + sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) { + // TODO(b/110793772): Support tiled array-shaped infeeds. + return InvalidArgument( + "Tiled sharding is not yet supported for array-shaped infeeds"); + } + + if (sharding() && + sharding()->type() == OpSharding::Type::OpSharding_Type_REPLICATED) { + return InvalidArgument( + "Replicated sharding is not yet supported for infeeds"); + } + + return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token}); }); } void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config) { - NoteErrorOrReturn([&]() -> StatusOr { + ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeNil(); + *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + + // Check and set outfeed shape. + if (!LayoutUtil::HasLayout(shape_with_layout)) { + return InvalidArgument("Given shape to Outfeed must have a layout"); + } + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { + return InvalidArgument( + "Outfeed shape %s must be compatible with operand shape %s", + ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), + ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + } + *instr.mutable_outfeed_shape() = shape_with_layout; + + instr.set_outfeed_config(outfeed_config); + + TF_RETURN_IF_ERROR( + AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand}) + .status()); + + // The outfeed instruction produces a token. However, existing users expect + // a nil shape (empty tuple). This should only be relevant if the outfeed is + // the root of a computation. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto tuple_instr; + *tuple_instr.mutable_shape() = ShapeUtil::MakeNil(); + + // The dummy tuple should have no sharding. + { + XlaScopedShardingAssignment scoped_sharding(this, OpSharding()); + TF_ASSIGN_OR_RETURN( + XlaOp empty_tuple, + AddInstruction(std::move(tuple_instr), HloOpcode::kTuple, {})); + return empty_tuple; + } + }); +} + +XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const string& outfeed_config) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); // Check and set outfeed shape. if (!LayoutUtil::HasLayout(shape_with_layout)) { @@ -1042,14 +1204,34 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, instr.set_outfeed_config(outfeed_config); - return AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand}); + return AddInstruction(std::move(instr), HloOpcode::kOutfeed, + {operand, token}); + }); +} + +XlaOp XlaBuilder::CreateToken() { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + return AddInstruction(std::move(instr), HloOpcode::kAfterAll); + }); +} + +XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice tokens) { + return ReportErrorOrReturn([&]() -> StatusOr { + if (tokens.empty()) { + return InvalidArgument("AfterAll requires at least one operand"); + } + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens); }); } XlaOp XlaBuilder::CustomCall(const string& call_target_name, tensorflow::gtl::ArraySlice operands, const Shape& shape) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (tensorflow::str_util::StartsWith(call_target_name, "$")) { return InvalidArgument( @@ -1066,7 +1248,7 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name, XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice operands, const string& channel_name, int64 cost_estimate_ns, const Shape& shape) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = shape; instr.set_channel_name(channel_name); @@ -1120,11 +1302,9 @@ XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs, return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions); } -// TODO(b/65209188): Create a dedicated lowering for Xor. XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return Or(And(Not(lhs), rhs, broadcast_dimensions), - And(lhs, Not(rhs), broadcast_dimensions)); + return BinaryOp(HloOpcode::kXor, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Not(const XlaOp& operand) { @@ -1223,7 +1403,7 @@ XlaOp XlaBuilder::IsFinite(const XlaOp& operand) { XlaOp XlaBuilder::Transpose(const XlaOp& operand, tensorflow::gtl::ArraySlice permutation) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1238,7 +1418,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand, XlaOp XlaBuilder::Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1251,13 +1431,31 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand, }); } -XlaOp XlaBuilder::Sort(const XlaOp& operand) { - return UnaryOp(HloOpcode::kSort, operand); -} - -XlaOp XlaBuilder::SqrtF32(const XlaOp& operand) { - return BinaryOp(HloOpcode::kPower, operand, ConstantR0(0.5), - /*broadcast_dimensions=*/{}); +XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional values, + int64 dimension) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); + operand_shape_ptrs.push_back(&keys_shape); + Shape values_shape; + if (values.has_value()) { + TF_ASSIGN_OR_RETURN(values_shape, GetShape(*values)); + operand_shape_ptrs.push_back(&values_shape); + } + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, operand_shape_ptrs)); + if (dimension == -1) { + TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); + dimension = ShapeUtil::Rank(keys_shape) - 1; + } + instr.add_dimensions(dimension); + return values.has_value() + ? AddInstruction(std::move(instr), HloOpcode::kSort, + {keys, *values}) + : AddInstruction(std::move(instr), HloOpcode::kSort, {keys}); + }); } XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, @@ -1267,7 +1465,7 @@ XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1279,7 +1477,7 @@ XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1290,16 +1488,6 @@ XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, }); } -XlaOp XlaBuilder::SquareF32(const XlaOp& operand) { - return BinaryOp(HloOpcode::kPower, operand, ConstantR0(2.0), - /*broadcast_dimensions=*/{}); -} - -XlaOp XlaBuilder::ReciprocalF32(const XlaOp& operand) { - return BinaryOp(HloOpcode::kPower, operand, ConstantR0(-1.0), - /*broadcast_dimensions=*/{}); -} - XlaOp XlaBuilder::Neg(const XlaOp& operand) { return UnaryOp(HloOpcode::kNegate, operand); } @@ -1313,13 +1501,12 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, const XlaComputation& computation, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice static_operands) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { if (!static_operands.empty()) { return Unimplemented("static_operands is not supported in Map"); } HloInstructionProto instr; - std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), @@ -1331,16 +1518,32 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, ShapeInference::InferMapShape(operand_shape_ptrs, called_program_shape, dimensions)); + const Shape& output_shape = instr.shape(); + const int64 output_rank = ShapeUtil::Rank(output_shape); AddCalledComputation(computation, &instr); + std::vector new_operands(operands.begin(), operands.end()); + for (XlaOp& new_operand : new_operands) { + TF_ASSIGN_OR_RETURN(Shape shape, GetShape(new_operand)); + const int64 rank = ShapeUtil::Rank(shape); + if (rank != output_rank) { + TF_ASSIGN_OR_RETURN(new_operand, + InDimBroadcast(output_shape, new_operand, {})); + TF_ASSIGN_OR_RETURN(shape, GetShape(new_operand)); + } + if (!ShapeUtil::SameDimensions(output_shape, shape)) { + TF_ASSIGN_OR_RETURN(new_operand, + AddBroadcastSequence(output_shape, new_operand)); + } + } - return AddInstruction(std::move(instr), HloOpcode::kMap, operands); + return AddInstruction(std::move(instr), HloOpcode::kMap, new_operands); }); } XlaOp XlaBuilder::RngOp(RandomDistribution distribution, tensorflow::gtl::ArraySlice parameters, const Shape& shape) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; // Check the number of parameters per RNG distribution. @@ -1378,7 +1581,7 @@ XlaOp XlaBuilder::RngUniform(const XlaOp& a, const XlaOp& b, XlaOp XlaBuilder::While(const XlaComputation& condition, const XlaComputation& body, const XlaOp& init) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; // Infer shape. @@ -1400,7 +1603,7 @@ XlaOp XlaBuilder::While(const XlaComputation& condition, XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, const GatherDimensionNumbers& dimension_numbers, tensorflow::gtl::ArraySlice window_bounds) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); @@ -1421,11 +1624,37 @@ XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, }); } +XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); + TF_ASSIGN_OR_RETURN(const Shape& scatter_indices_shape, + GetShape(scatter_indices)); + TF_ASSIGN_OR_RETURN(const Shape& updates_shape, GetShape(updates)); + TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape, + update_computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferScatterShape( + input_shape, scatter_indices_shape, updates_shape, + to_apply_shape, dimension_numbers)); + + *instr.mutable_scatter_dimension_numbers() = dimension_numbers; + + AddCalledComputation(update_computation, &instr); + return AddInstruction(std::move(instr), HloOpcode::kScatter, + {input, scatter_indices, updates}); + }); +} + XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, const XlaComputation& true_computation, const XlaOp& false_operand, const XlaComputation& false_computation) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& predicate_shape, GetShape(predicate)); @@ -1457,16 +1686,17 @@ XlaOp XlaBuilder::Reduce( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, tensorflow::gtl::ArraySlice dimensions_to_reduce) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value)); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferReduceShape( - operand_shape, init_shape, dimensions_to_reduce, + {&operand_shape, &init_shape}, dimensions_to_reduce, called_program_shape)); for (int64 dim : dimensions_to_reduce) { @@ -1482,7 +1712,7 @@ XlaOp XlaBuilder::Reduce( XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); std::vector all_dimnos(ShapeUtil::Rank(operand_shape)); std::iota(all_dimnos.begin(), all_dimnos.end(), 0); @@ -1495,7 +1725,7 @@ XlaOp XlaBuilder::ReduceWindow( const XlaComputation& computation, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1518,7 +1748,7 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1542,7 +1772,7 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale, const XlaOp& offset, float epsilon, int64 feature_index) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1565,7 +1795,7 @@ XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale, const XlaOp& offset, const XlaOp& mean, const XlaOp& variance, float epsilon, int64 feature_index) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1590,7 +1820,7 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, const XlaOp& batch_mean, const XlaOp& batch_var, const XlaOp& grad_output, float epsilon, int64 feature_index) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1614,7 +1844,7 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, XlaOp XlaBuilder::CrossReplicaSum( const XlaOp& operand, tensorflow::gtl::ArraySlice replica_group_ids) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {}); auto b = CreateSubBuilder("sum"); @@ -1630,12 +1860,7 @@ XlaOp XlaBuilder::CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, tensorflow::gtl::ArraySlice replica_group_ids, const tensorflow::gtl::optional& channel_id) { - return NoteErrorOrReturn([&]() -> StatusOr { - if (channel_id.has_value()) { - return Unimplemented( - "replica_group_ids and channel_id and is not supported in AllReduce"); - } - + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1645,6 +1870,10 @@ XlaOp XlaBuilder::CrossReplicaSum( instr.add_replica_group_ids(replica_group_id); } + if (channel_id.has_value()) { + instr.set_all_reduce_id(channel_id->handle()); + } + AddCalledComputation(computation, &instr); return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum, @@ -1652,13 +1881,68 @@ XlaOp XlaBuilder::CrossReplicaSum( }); } +XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups) { + return ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + + // The HloInstruction for Alltoall currently only handles the data + // communication: it accepts N already split parts and scatters them to N + // cores, and each core gathers the N received parts into a tuple as the + // output. So here we explicitly split the operand before the hlo alltoall, + // and concat the tuple elements. + // + // First, run shape inference to make sure the shapes are valid. + TF_RETURN_IF_ERROR( + ShapeInference::InferAllToAllShape(operand_shape, split_dimension, + concat_dimension, split_count) + .status()); + + // Split into N parts. + std::vector slices; + slices.reserve(split_count); + const int64 block_size = + operand_shape.dimensions(split_dimension) / split_count; + for (int i = 0; i < split_count; i++) { + slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size, + /*limit_index=*/(i + 1) * block_size, + /*stride=*/1, /*dimno=*/split_dimension)); + } + + // Handle data communication. + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices)); + std::vector slice_shape_ptrs; + c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs)); + for (const ReplicaGroup& group : replica_groups) { + *instr.add_replica_groups() = group; + } + TF_ASSIGN_OR_RETURN( + XlaOp alltoall, + AddInstruction(std::move(instr), HloOpcode::kAllToAll, slices)); + + // Concat the N received parts. + std::vector received; + received.reserve(split_count); + for (int i = 0; i < split_count; i++) { + received.push_back(this->GetTupleElement(alltoall, i)); + } + return this->ConcatInDim(received, concat_dimension); + }); +} + XlaOp XlaBuilder::SelectAndScatter( const XlaOp& operand, const XlaComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); return SelectAndScatterWithGeneralPadding( operand, select, window_dimensions, window_strides, @@ -1675,7 +1959,7 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( tensorflow::gtl::ArraySlice> padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1703,7 +1987,7 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits) { - return NoteErrorOrReturn([&]() -> StatusOr { + return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), @@ -1717,20 +2001,40 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, } void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { - NoteErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; + ReportErrorOrReturn([&]() -> StatusOr { + // Send HLO takes two operands: a data operand and a token. Generate the + // token to pass into the send. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto token_instr; + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), + HloOpcode::kAfterAll, {})); + + return SendWithToken(operand, token, handle); + }); +} + +XlaOp XlaBuilder::SendWithToken(const XlaOp& operand, const XlaOp& token, + const ChannelHandle& handle) { + return ReportErrorOrReturn([&]() -> StatusOr { + if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) { + return InvalidArgument("Send must use a device-to-device channel"); + } - // Send instruction produces a tuple of {aliased operand, U32 context}. + // Send instruction produces a tuple of {aliased operand, U32 context, + // token}. + HloInstructionProto send_instr; TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); - *instr.mutable_shape() = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); - instr.set_channel_id(handle.handle()); - TF_ASSIGN_OR_RETURN( - XlaOp send, - AddInstruction(std::move(instr), HloOpcode::kSend, {operand})); + *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + send_instr.set_channel_id(handle.handle()); + TF_ASSIGN_OR_RETURN(XlaOp send, + AddInstruction(std::move(send_instr), HloOpcode::kSend, + {operand, token})); HloInstructionProto send_done_instr; - *send_done_instr.mutable_shape() = ShapeUtil::MakeNil(); + *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); send_done_instr.set_channel_id(handle.handle()); return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, {send}); @@ -1738,19 +2042,133 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { } XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { - return NoteErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; + return ReportErrorOrReturn([&]() -> StatusOr { + // Recv HLO takes a single token operand. Generate the token to pass into + // the Recv and RecvDone instructions. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto token_instr; + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), + HloOpcode::kAfterAll, {})); + + XlaOp recv = RecvWithToken(token, shape, handle); + + // The RecvDone instruction produces a tuple of the data and a token + // type. Return XLA op containing the data. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto recv_data; + *recv_data.mutable_shape() = shape; + recv_data.set_tuple_index(0); + return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement, + {recv}); + }); +} - // Recv instruction produces a tuple of {receive buffer, U32 context}. - *instr.mutable_shape() = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); - instr.set_channel_id(handle.handle()); - TF_ASSIGN_OR_RETURN(XlaOp recv, - AddInstruction(std::move(instr), HloOpcode::kRecv, {})); +XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle) { + return ReportErrorOrReturn([&]() -> StatusOr { + if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) { + return InvalidArgument("Recv must use a device-to-device channel"); + } + + // Recv instruction produces a tuple of {receive buffer, U32 context, + // token}. + HloInstructionProto recv_instr; + *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + recv_instr.set_channel_id(handle.handle()); + TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr), + HloOpcode::kRecv, {token})); + + HloInstructionProto recv_done_instr; + *recv_done_instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); + recv_done_instr.set_channel_id(handle.handle()); + return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, + {recv}); + }); +} + +XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const ChannelHandle& handle) { + return ReportErrorOrReturn([&]() -> StatusOr { + if (!LayoutUtil::HasLayout(shape_with_layout)) { + return InvalidArgument("Shape passed to SendToHost must have a layout"); + } + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { + return InvalidArgument( + "SendToHost shape %s must be compatible with operand shape %s", + ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), + ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + } + // TODO(b/111544877): Support tuple shapes. + if (!ShapeUtil::IsArray(operand_shape)) { + return InvalidArgument("SendToHost only supports array shapes, shape: %s", + ShapeUtil::HumanString(operand_shape).c_str()); + } + + if (handle.type() != ChannelHandle::DEVICE_TO_HOST) { + return InvalidArgument("SendToHost must use a device-to-host channel"); + } + + // Send instruction produces a tuple of {aliased operand, U32 context, + // token}. + HloInstructionProto send_instr; + *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape( + {shape_with_layout, ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()}); + send_instr.set_channel_id(handle.handle()); + send_instr.set_is_host_transfer(true); + TF_ASSIGN_OR_RETURN(XlaOp send, + AddInstruction(std::move(send_instr), HloOpcode::kSend, + {operand, token})); + + HloInstructionProto send_done_instr; + *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + send_done_instr.set_channel_id(handle.handle()); + send_done_instr.set_is_host_transfer(true); + return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, + {send}); + }); +} + +XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle) { + return ReportErrorOrReturn([&]() -> StatusOr { + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument("Shape passed to RecvFromHost must have a layout"); + } + + // TODO(b/111544877): Support tuple shapes. + if (!ShapeUtil::IsArray(shape)) { + return InvalidArgument( + "RecvFromHost only supports array shapes, shape: %s", + ShapeUtil::HumanString(shape).c_str()); + } + + if (handle.type() != ChannelHandle::HOST_TO_DEVICE) { + return InvalidArgument("RecvFromHost must use a host-to-device channel"); + } + + // Recv instruction produces a tuple of {receive buffer, U32 context, + // token}. + HloInstructionProto recv_instr; + *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + recv_instr.set_channel_id(handle.handle()); + recv_instr.set_is_host_transfer(true); + TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr), + HloOpcode::kRecv, {token})); HloInstructionProto recv_done_instr; - *recv_done_instr.mutable_shape() = shape; + *recv_done_instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); recv_done_instr.set_channel_id(handle.handle()); + recv_done_instr.set_is_host_transfer(true); return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, {recv}); }); @@ -1789,11 +2207,6 @@ StatusOr XlaBuilder::BuildConstantSubGraph( TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, LookUpInstruction(root_op)); - TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode())); - if (!CanBeRoot(opcode)) { - return InvalidArgument("the operand with opcode %s cannot be root", - root->opcode().c_str()); - } HloComputationProto entry; entry.set_id(GetUniqueId()); // Give the computation a global unique id. @@ -1993,9 +2406,558 @@ StatusOr XlaBuilder::LookUpInstruction( return &instructions_[op.handle()]; } -XlaOp XlaBuilder::UnimplementedOp() { - NoteError(Unimplemented("Op not implemented")); - return {}; +// Enqueues a "retrieve parameter value" instruction for a parameter that was +// passed to the computation. +XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape, + const string& name) { + return builder->Parameter(parameter_number, shape, name); +} + +// Enqueues a constant with the value of the given literal onto the +// computation. +XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) { + return builder->ConstantLiteral(literal); +} + +XlaOp Broadcast(const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes) { + return operand.builder()->Broadcast(operand, broadcast_sizes); +} + +XlaOp BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice broadcast_dimensions) { + return operand.builder()->BroadcastInDim(operand, shape, + broadcast_dimensions); +} + +XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config) { + return operand.builder()->Pad(operand, padding_value, padding_config); +} + +XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes) { + return operand.builder()->Reshape(operand, dimensions, new_sizes); +} + +XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes) { + return operand.builder()->Reshape(operand, new_sizes); +} + +XlaOp Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions) { + return operand.builder()->Collapse(operand, dimensions); +} + +XlaOp Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides) { + return operand.builder()->Slice(operand, start_indices, limit_indices, + strides); +} + +XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, + int64 stride, int64 dimno) { + return operand.builder()->SliceInDim(operand, start_index, limit_index, + stride, dimno); +} + +XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes) { + return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); +} + +XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices) { + return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); +} + +XlaOp ConcatInDim(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + int64 dimension) { + return builder->ConcatInDim(operands, dimension); +} + +void Trace(const string& tag, const XlaOp& operand) { + return operand.builder()->Trace(tag, operand); +} + +XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) { + return pred.builder()->Select(pred, on_true, on_false); +} + +XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice elements) { + return builder->Tuple(elements); +} + +XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) { + return tuple_data.builder()->GetTupleElement(tuple_data, index); +} + +XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Eq(lhs, rhs, broadcast_dimensions); +} + +XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Ne(lhs, rhs, broadcast_dimensions); +} + +XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Ge(lhs, rhs, broadcast_dimensions); +} + +XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Gt(lhs, rhs, broadcast_dimensions); +} + +XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Lt(lhs, rhs, broadcast_dimensions); +} + +XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Le(lhs, rhs, broadcast_dimensions); +} + +XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs) { + return lhs.builder()->Dot(lhs, rhs); +} + +XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers) { + return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers); +} + +XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding) { + return lhs.builder()->Conv(lhs, rhs, window_strides, padding); +} + +XlaOp ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding) { + return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides, + padding); +} + +XlaOp ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers) { + return lhs.builder()->ConvWithGeneralDimensions(lhs, rhs, window_strides, + padding, dimension_numbers); +} + +XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers) { + return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding, + dimension_numbers); +} + +XlaOp ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers) { + return lhs.builder()->ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers); +} + +XlaOp Fft(const XlaOp& operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length) { + return operand.builder()->Fft(operand, fft_type, fft_length); +} + +XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) { + return builder->Infeed(shape, config); +} + +void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config) { + return operand.builder()->Outfeed(operand, shape_with_layout, outfeed_config); +} + +XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands) { + return builder->Call(computation, operands); +} + +XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape) { + return builder->CustomCall(call_target_name, operands, shape); +} + +XlaOp HostCompute(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, + const Shape& shape) { + return builder->HostCompute(operands, channel_name, cost_estimate_ns, shape); +} + +XlaOp Complex(const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return real.builder()->Complex(real, imag, broadcast_dimensions); +} + +XlaOp Conj(const XlaOp& operand) { return operand.builder()->Conj(operand); } + +XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Add(lhs, rhs, broadcast_dimensions); +} + +XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Sub(lhs, rhs, broadcast_dimensions); +} + +XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Mul(lhs, rhs, broadcast_dimensions); +} + +XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Div(lhs, rhs, broadcast_dimensions); +} + +XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Rem(lhs, rhs, broadcast_dimensions); +} + +XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Max(lhs, rhs, broadcast_dimensions); +} + +XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Min(lhs, rhs, broadcast_dimensions); +} + +XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->And(lhs, rhs, broadcast_dimensions); +} + +XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Or(lhs, rhs, broadcast_dimensions); +} + +XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Xor(lhs, rhs, broadcast_dimensions); +} + +XlaOp Not(const XlaOp& operand) { return operand.builder()->Not(operand); } + +XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->ShiftLeft(lhs, rhs, broadcast_dimensions); +} + +XlaOp ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->ShiftRightArithmetic(lhs, rhs, broadcast_dimensions); +} + +XlaOp ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->ShiftRightLogical(lhs, rhs, broadcast_dimensions); +} + +XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce) { + return operand.builder()->Reduce(operand, init_value, computation, + dimensions_to_reduce); +} + +XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation) { + return operand.builder()->ReduceAll(operand, init_value, computation); +} + +XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding) { + return operand.builder()->ReduceWindow(operand, init_value, computation, + window_dimensions, window_strides, + padding); +} + +XlaOp ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding) { + return operand.builder()->ReduceWindowWithGeneralPadding( + operand, init_value, computation, window_dimensions, window_strides, + padding); +} + +XlaOp CrossReplicaSum(const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids) { + return operand.builder()->CrossReplicaSum(operand, replica_group_ids); +} + +XlaOp CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids, + const tensorflow::gtl::optional& channel_id) { + return operand.builder()->CrossReplicaSum(operand, computation, + replica_group_ids, channel_id); +} + +XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups) { + return operand.builder()->AllToAll(operand, split_dimension, concat_dimension, + split_count, replica_groups); +} + +XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter) { + return operand.builder()->SelectAndScatter(operand, select, window_dimensions, + window_strides, padding, source, + init_value, scatter); +} + +XlaOp SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter) { + return operand.builder()->SelectAndScatterWithGeneralPadding( + operand, select, window_dimensions, window_strides, padding, source, + init_value, scatter); +} + +XlaOp Abs(const XlaOp& operand) { return operand.builder()->Abs(operand); } + +XlaOp Atan2(const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return y.builder()->Atan2(y, x, broadcast_dimensions); +} + +XlaOp Exp(const XlaOp& operand) { return operand.builder()->Exp(operand); } + +XlaOp Expm1(const XlaOp& operand) { return operand.builder()->Expm1(operand); } + +XlaOp Floor(const XlaOp& operand) { return operand.builder()->Floor(operand); } + +XlaOp Ceil(const XlaOp& operand) { return operand.builder()->Ceil(operand); } + +XlaOp Round(const XlaOp& operand) { return operand.builder()->Round(operand); } + +XlaOp Log(const XlaOp& operand) { return operand.builder()->Log(operand); } + +XlaOp Log1p(const XlaOp& operand) { return operand.builder()->Log1p(operand); } + +XlaOp Sign(const XlaOp& operand) { return operand.builder()->Sign(operand); } + +XlaOp Clz(const XlaOp& operand) { return operand.builder()->Clz(operand); } + +XlaOp Cos(const XlaOp& operand) { return operand.builder()->Cos(operand); } + +XlaOp Sin(const XlaOp& operand) { return operand.builder()->Sin(operand); } + +XlaOp Tanh(const XlaOp& operand) { return operand.builder()->Tanh(operand); } + +XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); } + +XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); } + +XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions); +} + +XlaOp IsFinite(const XlaOp& operand) { + return operand.builder()->IsFinite(operand); +} + +XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) { + return operand.builder()->ConvertElementType(operand, new_element_type); +} + +XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) { + return operand.builder()->BitcastConvertType(operand, new_element_type); +} + +XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); } + +XlaOp Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation) { + return operand.builder()->Transpose(operand, permutation); +} + +XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions) { + return operand.builder()->Rev(operand, dimensions); +} + +XlaOp Sort(XlaOp keys, tensorflow::gtl::optional values, + int64 dimension) { + return keys.builder()->Sort(keys, std::move(values), dimension); +} + +XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { + return min.builder()->Clamp(min, operand, max); +} + +XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands) { + return builder->Map(operands, computation, dimensions, static_operands); +} + +XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape) { + return mu.builder()->RngNormal(mu, sigma, shape); +} + +XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape) { + return a.builder()->RngUniform(a, b, shape); +} + +XlaOp While(const XlaComputation& condition, const XlaComputation& body, + const XlaOp& init) { + return init.builder()->While(condition, body, init); +} + +XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation) { + return predicate.builder()->Conditional(predicate, true_operand, + true_computation, false_operand, + false_computation); +} + +XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits) { + return operand.builder()->ReducePrecision(operand, exponent_bits, + mantissa_bits); +} + +XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds) { + return input.builder()->Gather(input, gather_indices, dimension_numbers, + window_bounds); +} + +XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers) { + return input.builder()->Scatter(input, scatter_indices, updates, + update_computation, dimension_numbers); +} + +void Send(const XlaOp& operand, const ChannelHandle& handle) { + return operand.builder()->Send(operand, handle); +} + +XlaOp Recv(XlaBuilder* builder, const Shape& shape, + const ChannelHandle& handle) { + return builder->Recv(shape, handle); +} + +XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, + const ChannelHandle& handle) { + return operand.builder()->SendWithToken(operand, token, handle); +} + +XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle) { + return token.builder()->RecvWithToken(token, shape, handle); +} + +XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, const ChannelHandle& handle) { + return operand.builder()->SendToHost(operand, token, shape_with_layout, + handle); +} + +XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle) { + return token.builder()->RecvFromHost(token, shape, handle); +} + +XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, + const string& config) { + return token.builder()->InfeedWithToken(token, shape, config); +} + +XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const string& outfeed_config) { + return operand.builder()->OutfeedWithToken(operand, token, shape_with_layout, + outfeed_config); +} + +XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); } + +XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice tokens) { + return builder->AfterAll(tokens); +} + +XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index) { + return operand.builder()->BatchNormTraining(operand, scale, offset, epsilon, + feature_index); +} + +XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index) { + return operand.builder()->BatchNormInference( + operand, scale, offset, mean, variance, epsilon, feature_index); +} + +XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index) { + return operand.builder()->BatchNormGrad(operand, scale, batch_mean, batch_var, + grad_output, epsilon, feature_index); +} + +XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size) { + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeShape(type, {size}); + return builder->ReportErrorOrReturn( + builder->AddInstruction(std::move(instr), HloOpcode::kIota)); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..9403d7ca8dabc80a3964b50d29f158a98091f843 --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -0,0 +1,2284 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stacktrace.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class XlaBuilder; + +// This represents an instruction that has been enqueued using the XlaBuilder. +// This is used to pass to subsequent computations that depends upon the +// instruction as an operand. +class XlaOp { + public: + XlaOp() : handle_(-1), builder_(nullptr) { + static_assert(std::is_trivially_destructible::value, + "XlaOp should be trivially destructible"); + } + ~XlaOp() = default; + + // Precondition: !IsUninitialized(). + // + // It's very common to do foo.builder()->bar(). Without this precondition, if + // foo.builder() is null, the call to bar will segfault at some point possibly + // deep in the callstack when we finally dereference `this`. The precondition + // lets us avoid this tricky-to-debug problem. + XlaBuilder* builder() const { + CHECK(builder_ != nullptr); + return builder_; + } + + // Returns true if the XlaOp represents valid, non-erroneous value. + bool valid() const { return handle_ >= 0; } + + // Returns true if the XlaOp was created by the XlaOp() constructor and + // not returned by a builder. + bool IsUninitialized() const { return builder_ == nullptr; } + + bool IsIdenticalTo(const XlaOp& rhs) const { + return handle_ == rhs.handle_ && builder_ == rhs.builder_; + } + + friend std::ostream& operator<<(std::ostream& out, const XlaOp& op) { + out << op.handle(); + return out; + } + + private: + explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {} + XlaOp(int64 handle, XlaBuilder* builder) + : handle_(handle), builder_(builder) {} + + int64 handle() const { return handle_; } + + friend class XlaBuilder; + + // < 0 means "invalid handle". + int64 handle_; + + // Not owned. Non-null for any handle returned by XlaBuilder, even if the + // handle is invalid. + XlaBuilder* builder_; +}; + +// Arithmetic operator overloads for the XlaOp type. +XlaOp operator-(const XlaOp& x); +XlaOp operator+(const XlaOp& x, const XlaOp& y); +XlaOp operator-(const XlaOp& x, const XlaOp& y); +XlaOp operator*(const XlaOp& x, const XlaOp& y); +XlaOp operator/(const XlaOp& x, const XlaOp& y); +XlaOp operator%(const XlaOp& x, const XlaOp& y); + +// Bitwise operator overloads for the XlaOp type. +XlaOp operator~(const XlaOp& x); +XlaOp operator&(const XlaOp& x, const XlaOp& y); +XlaOp operator|(const XlaOp& x, const XlaOp& y); +XlaOp operator^(const XlaOp& x, const XlaOp& y); +XlaOp operator<<(const XlaOp& x, const XlaOp& y); +// Performs a right arithmetic shift if 'x' is a signed type, otherwise performs +// a right logical shift. +XlaOp operator>>(const XlaOp& x, const XlaOp& y); + +// We don't overload the relational operators (==, !=, <, <=, >, >=) because the +// semantics might be surprising since their result types are usually 'bool'. +// Further programmers may expect == to be a structural equality. +// We also choose not to overload any of the mutating operators (e.g., +=, -=) +// because the semantics might be misleading — XLA computations are immutable. + +// A convenient interface for building up computations. +// +// Thread-compatible. +class XlaBuilder { + public: + // computation_name: name to use for the built computation. + XlaBuilder(const string& computation_name); + + XlaBuilder(const XlaBuilder&) = delete; + XlaBuilder& operator=(const XlaBuilder&) = delete; + + ~XlaBuilder(); + + // Returns the computation name. + const string& name() const { return name_; } + + // Sets OpMetadata that will be added to all instructions until cleared. + // + // OpMetadata is often applied to a series of XLA HLO instructions. As a + // result, OpMetadata is set on the Computation Builder. All subsequent + // instructions generated via this Computation Builder will have the same + // OpMetadata attached until a call to ClearOpMetadata. + void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } + + // Clears the HloMetadata state. + void ClearOpMetadata() { metadata_.Clear(); } + + // Sets an OpSharding that will be attached to all instructions until cleared. + void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } + + // Clears the sharding. Ops will be sharded according to the default placement + // policy. + void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; } + + // Returns the OpSharding that will be attached to all instructions. + const tensorflow::gtl::optional& sharding() const { + return sharding_; + } + + // Sets the builder to a mode where it will die immediately when an error is + // encountered, rather than producing it in a deferred fashion when Build() is + // called (which is the default). + void set_die_immediately_on_error(bool enabled) { + die_immediately_on_error_ = enabled; + } + + // Default dimension numbers used for a 2D convolution. + static constexpr int64 kConvBatchDimension = 0; + static constexpr int64 kConvFeatureDimension = 1; + static constexpr int64 kConvFirstSpatialDimension = 2; + static constexpr int64 kConvSecondSpatialDimension = 3; + static constexpr int64 kConvKernelOutputDimension = 0; + static constexpr int64 kConvKernelInputDimension = 1; + static constexpr int64 kConvKernelFirstSpatialDimension = 2; + static constexpr int64 kConvKernelSecondSpatialDimension = 3; + + // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for + // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for + // the kernel operand + // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. + static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( + int num_spatial_dims = 2); + + // Returns an error if the convolution dimension numbers have conflicts. + static Status Validate(const ConvolutionDimensionNumbers& dnum); + + // Returns a new XlaBuilder whose resultant Computation is used only by this + // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error + // behavior as the parent. + std::unique_ptr CreateSubBuilder(const string& computation_name); + + // Builds the computation with the requested operations, or returns a non-ok + // status. Note that all ops that have been enqueued will be moved to the + // computation being returned. The root of the computation will be the last + // added operation. + StatusOr Build(); + + // Overload of Build which specifies a particular root instruction for the + // computation. + StatusOr Build(XlaOp root); + + // Builds the computation with the requested operations, or notes an error in + // the parent XlaBuilder and returns an empty computation if building failed. + // This function is intended to be used where the returned XlaComputation is + // only used by the parent XlaBuilder and hence further operation on the + // returned XlaComputation will simply be error'ed out if an error occurred + // while building this computation. If the built computation is to be used by + // a XlaBuilder other than the parent XlaBuilder then Build() should be used + // instead. + XlaComputation BuildAndNoteError(); + + // Returns a subgraph that roots on the given root. If the root is not a + // compile-time constant (see `IsConstant`), returns an error. + // + // This will copy the needed ops/computations to the subgraph. + StatusOr BuildConstantSubGraph(const XlaOp& root_op) const; + + // Returns the first error that was encountered while building the + // computation. When an error is encountered, by default we return a vacuous + // XlaOp and inform the user of the error that occurred while + // building the computation when they make a final call to Build(). + // + // See also set_die_immediately_on_error(). + Status first_error() const { return first_error_; } + + // Returns the shape of the given op. + StatusOr GetShape(const XlaOp& op) const; + + // Returns the (inferred) result for the current computation's shape. This + // assumes the root instruction is the last added instruction. + StatusOr GetProgramShape() const; + + // Returns the (inferred) result for the current computation's shape using the + // given operation as the root. + StatusOr GetProgramShape(XlaOp root) const; + + // Reports an error to the builder, by + // * storing it internally and capturing a backtrace if it's the first error + // (this deferred value will be produced on the call to + // Build()/GetShape()/...) + // * dying if die_immediately_on_error_ is true. + // Returns an XlaOp with an invalid handle but a valid builder. This value can + // be returned in place of a value in APIs that return an XlaOp. + XlaOp ReportError(const Status& error); + + // A helper function that converts a StatusOr into an XlaOp. + // If the Status was an error, reports the error to builder and returns an + // invalid XlaOp handle. + XlaOp ReportErrorOrReturn(const StatusOr& op); + + // A helper function that runs a function that returns a StatusOr and + // returns an XlaOp. + XlaOp ReportErrorOrReturn(const std::function()>& op_creator); + + // Returns true if 'operand' is a compile-time constant. A compile-time + // constant does not depend on any parameters, or on stateful operators such + // as `RngNormal` or `Infeed`. + // + // This tests whether a computation is a compile-time constant without + // evaluating the computation. + StatusOr IsConstant(const XlaOp& operand) const; + + private: + // Build helper which takes the id of the root operation.. + StatusOr Build(int64 root_id); + + // Enqueues a "retrieve parameter value" instruction for a parameter that was + // passed to the computation. + XlaOp Parameter(int64 parameter_number, const Shape& shape, + const string& name); + + // Enqueues a constant with the value of the given literal onto the + // computation. + XlaOp ConstantLiteral(const LiteralSlice& literal); + + // Enqueues a constant onto the computation. Methods are templated on the + // native host type (NativeT) which corresponds to a specific XLA + // PrimitiveType as given in the following table: + // + // Native Type PrimitiveType + // ----------------------------- + // bool PRED + // int32 S32 + // int64 S64 + // uint32 U32 + // uint64 U64 + // float F32 + // double F64 + // + // Note: not all primitive types defined in xla_data.proto have a + // corresponding native type yet. + template + XlaOp ConstantR0(NativeT value); + template + XlaOp ConstantR1(tensorflow::gtl::ArraySlice values); + XlaOp ConstantR1(const tensorflow::core::Bitmap& values); + template + XlaOp ConstantR2( + std::initializer_list> values); + template + XlaOp ConstantFromArrayWithLayout(const Array& values, + const Layout& layout); + template + XlaOp ConstantFromArray(const Array& values); + template + XlaOp ConstantR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout); + template + XlaOp ConstantR2FromArray2D(const Array2D& values); + template + XlaOp ConstantR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout); + template + XlaOp ConstantR3FromArray3D(const Array3D& values); + template + XlaOp ConstantR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout); + template + XlaOp ConstantR4FromArray4D(const Array4D& values); + + // Enqueues a rank one constant (vector) onto the computation. The vector has + // size 'length' and every element has the value 'value'. + template + XlaOp ConstantR1(int64 length, NativeT value); + + // Adds dimensions to an array by duplicating the data in the array. + // + // The new dimensions are inserted on the left, i.e. if + // broadcast_sizes has values {a0, ..., aN} and the operand shape + // has dimensions {b0, ..., bM} then the shape of the output has + // dimensions {a0, ..., aN, b0, ..., bM}. + // + // The new dimensions index into copies of the operand, i.e. + // + // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] + XlaOp Broadcast(const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); + + // Performs in-dimension-style broadcast. + // + // Operand specifies the input to be broadcast. "shape" is expected output + // shape. "broadcast_dimensions" are the dimensions to be broadcasting into. + // Dimension numbers in broadcast_dimensions map to individual dimensions + // of the operand, and specify what dimension of the output shape they + // should be broadcast. + // e.g. + // Say operand = [1, 2], i.e., a 1D tensor with 2 elements. + // and dimension of shape is [2,2]. + // Specifying {1} as brodcast_dimension will generate output + // [1 , 2] + // [1 , 2] + // On the other hand, specifying {0} as broadcast_dimension + // will generate output + // [1 , 1] + // [2 , 2] + XlaOp BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice broadcast_dimensions); + + // Enqueues a pad operation onto the computation that pads the given value on + // the edges as well as between the elements of the input. padding_config + // specifies the padding amount for each dimension. + XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config); + + // Enqueues an operation onto the computation that flattens the operand based + // on the dimension order (major/slowest-varying to minor/fastest-varying) + // given, followed by reshaping it into the shape with the given dimension + // sizes (also major to minor). Conceptually, this is a limited form of + // "shape casting". + XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); + + // Enqueues an operation onto the computation that collapses the operand, from + // first to last dimension (C order), then reshapes it to the given dimension + // sizes. Conceptually, this is a limited form of "shape casting". + XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes); + + // Wrapper for Reshape. + // Enqueues an operation to collapse the provided dimensions; e.g. an + // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to + // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must + // be a consecutive, in-order subsequence of the operand dimensions. + // + // Note that collapsing a single dimension does nothing: + // + // {256} collapsing {0} => {256} + // {1} collapsing {0} => {1} + // + // Collapsing multiple dimensions produces a single result dimension: + // + // {256, 2} collapsing {0,1} => {512} + // {256, 2, 3} collapsing {0,1} => {512, 3} + // + // This could potentially cause data to be moved -- it provides a more + // structured form of reshaping than an arbitrary Reshape operation. + XlaOp Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + + // Enqueues a slice operation onto the computation that slices the operand + // from the start indices to the limit indices; e.g. + // + // x + // [ 0 1 2 3 ] + // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] + // [ 8 9 a b ] + // + // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D + // range notation. + // The strides parameter determines the stride over the slice + XlaOp Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + + // Enqueues a slice operation in a given dimension, taking all other + // dimensions as they are; e.g. if dimno is 1 from start_index 2 to + // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand + // for: + // + // array[:, 2:4:1, :] + XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, + int64 stride, int64 dimno); + + // Enqueues a slice operation onto the computation that slices the 'operand' + // from dynamic start indices which are passed in 'start_indices'. + // The size of the slice in each dimension is passed in 'slice_sizes', + // which specify the end point of exclusive slice intervals in each + // dimension [start, start + size). + // The shape of 'start_indices' must be rank == 1, with dimension size + // equal to the rank of the 'operand'. + // Slice index calculations are computed modulo input dimension sizes to + // prevent dynamic start indices from generating out-of-bound array accesses. + XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + + // Enqueues a dynamic update slice operation onto the computation, which + // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. + // The shape of 'update' determines the shape of the slice of 'operand' + // which is updated. + // The indices specified in 'start_indices' specify the offset of the slice + // of 'operand' which is updated. + // + // update = {10, 11} // calculated at runtime. + // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] + // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] + // [7 8 9] [7 8 9 ] + // + // The shape of 'start_indices' must be rank == 1, with dimension size + // equal to the rank of the 'operand'. + // Slice index calculations are computed modulo update dimension sizes to + // prevent dynamic start indices from generating out-of-bound array accesses. + XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices); + + // Enqueues a concatenate instruction onto the computation. 'operands' must + // have >= 1 entry. + XlaOp ConcatInDim(tensorflow::gtl::ArraySlice operands, + int64 dimension); + + // Enqueue a tracing operation onto the computation; the computation will emit + // a logging message with the operand. + void Trace(const string& tag, const XlaOp& operand); + + // Enqueues a conditional-move-like select operation onto the computation; + // predicated on pred, selects between on_true and on_false. + XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); + + // Enqueues a tuple-creation instruction onto the computation. + XlaOp Tuple(tensorflow::gtl::ArraySlice elements); + + // Enqueues a tuple-element-get instruction onto the computation. + XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); + + // Enqueues an equal-to comparison instruction onto the computation. + XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a not-equal comparison instruction onto the computation. + XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a greater-or-equal comparison instruction onto the computation. + XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a greater-than comparison instruction onto the computation. + XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a less-than comparison instruction onto the computation. + XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a less-or-equal comparison instruction onto the computation. + XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a dot instruction onto the computation. + XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + + // Enqueues a general dot instruction onto the computation. + XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers); + + // Enqueues a convolution instruction onto the computation, which uses the + // default convolution dimension numbers. + XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + Padding padding); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration in the format returned by MakePadding(). + XlaOp ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided dimension numbers configuration. + XlaOp ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration as well as the dimension numbers. + XlaOp ConvGeneral( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration, dilation factors and dimension numbers. + XlaOp ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues an FFT instruction onto the computation, of the given type and + // with the given FFT length. + XlaOp Fft(const XlaOp& operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + + // Enqueues an infeed instruction onto the computation, which writes data of + // the given shape to the infeed buffer of the device. + XlaOp Infeed(const Shape& shape, const string& config = ""); + XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, + const string& config = ""); + + // Enqueues an outfeed instruction onto the computation. This instruction + // generates outgoing data transfers for the given data. + // + // shape_with_layout communicates the laid out shape that we want to outfeed + // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error + // will occur. + void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config); + XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const string& outfeed_config); + + // Enqueues a call instruction onto the computation. + XlaOp Call(const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands); + + // Enqueues a custom call instruction onto the computation. + // During code generation, a call instruction is emitted which targets a + // symbol with the name |call_target_name|. The |operands| are passed to the + // call instruction. |shape| is the resultant shape. + XlaOp CustomCall(const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape); + + // Enqueues a pseudo-op to represent host-side computation data-dependencies. + // During code generation, host send and receive operations will be generated + // to transfer |operands| to the host and a single result of |shape| back to + // the device. Host send/recv operations are emitted using |channel_name|. + // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO + // instruction scheduling. + XlaOp HostCompute(tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, + const Shape& shape); + + // The following methods enqueue element-wise binary arithmetic operations + // onto the computation. The shapes of the operands have to match unless one + // of the operands is a scalar, or an explicit broadcast dimension is given + // (see g3doc for more details). + + // Enqueues a complex compose instruction onto the computation. + XlaOp Complex(const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a complex conjugate instruction onto the computation. + XlaOp Conj(const XlaOp& operand); + + // Enqueues an add instruction onto the computation. + XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a subtract instruction onto the computation. + XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a multiply instruction onto the computation. + XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a divide instruction onto the computation. + XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a remainder instruction onto the computation. + XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a max instruction onto the computation. + XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a min instruction onto the computation. + XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Element-wise logical operators + XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + XlaOp Not(const XlaOp& operand); + + XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + XlaOp ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + XlaOp ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Reduces an array among the provided dimensions, given "computation" as a + // reduction operator. + XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); + + // Convenience wrapper around the above that reduces all the dimensions in the + // operand shape. + XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation); + + // Enqueues a windowed reduce instruction onto the computation. + XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding); + + // As ReduceWindow(), but the padding is given in the format + // returned by MakePadding(). + XlaOp ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + + // Returns the sum of the operand value within each subgroup of replicas. All + // replicas supply one input to the sum and all replicas receive the resulting + // sum for each subgroup. + XlaOp CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids = {}); + + // Enqueues an operation that do an AllReduce of the operand cross cores. Here + // AllReduce means doing a reduction on the input operand cross cores and then + // broadcasting the reduction result to those cores. The reduction function is + // defined by `computation`, which should be a commutative computation on + // scalars, e.g., add, min, or max. The way that AllReduce is applied is + // configured by: + // + // - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all + // replicas belong to one group. Allreduce will be applied within subgroups. + // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, + // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. + // + // - `channel_id`: for Allreduce nodes from different modules, if they have + // the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will + // not be applied cross modules. + // + // TODO(b/79737069): Rename this to AllReduce when it's ready to use. + XlaOp CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids = {}, + const tensorflow::gtl::optional& channel_id = + tensorflow::gtl::nullopt); + + // Enqueues an operation that do an Alltoall of the operand cross cores. + // + // TODO(b/110096724): This is NOT YET ready to use. + XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups); + + // Enqueues an operation that scatters the `source` array to the selected + // indices of each window. + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, + const XlaComputation& scatter); + + // As SelectAndScatter(), but the padding is given in the format + // returned by MakePadding(). + XlaOp SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + + // Enqueues an abs instruction onto the computation. + XlaOp Abs(const XlaOp& operand); + + // Enqueues a atan2 instruction onto the computation. + XlaOp Atan2(const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues an exp instruction onto the computation. + XlaOp Exp(const XlaOp& operand); + + // Enqueues an expm1 instruction onto the computation. + XlaOp Expm1(const XlaOp& operand); + + // Enqueues a floor instruction onto the computation. + XlaOp Floor(const XlaOp& operand); + + // Enqueues a ceil instruction onto the computation. + XlaOp Ceil(const XlaOp& operand); + + // Enqueues a round instruction onto the computation, rounding to nearest even + // with half-way cases rounding away from zero. + XlaOp Round(const XlaOp& operand); + + // Enqueues an log instruction (natural logarithm) onto the computation. + XlaOp Log(const XlaOp& operand); + + // Enqueues an log1p instruction (log(x+1)) onto the computation. + XlaOp Log1p(const XlaOp& operand); + + // Enqueues a sign instruction onto the computation. + XlaOp Sign(const XlaOp& operand); + + // Enqueues a count leading zeros instruction onto the computation. + XlaOp Clz(const XlaOp& operand); + + // Enqueues a cosine instruction onto the computation. + XlaOp Cos(const XlaOp& operand); + + // Enqueues a sine instruction onto the computation. + XlaOp Sin(const XlaOp& operand); + + // Enqueues a tanh instruction onto the computation. + XlaOp Tanh(const XlaOp& operand); + + // Enqueues a real-part instruction onto the computation. + XlaOp Real(const XlaOp& operand); + + // Enqueues an imaginary-part instruction onto the computation. + XlaOp Imag(const XlaOp& operand); + + // Enqueues a lhs^rhs computation onto the computation. + XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues an operator that tests if the operand's values are finite, i.e., + // not Inf or NaN. Defined only for floating-point types. Returns an array of + // booleans with the same shape where entries are true iff the corresponding + // entry was NaN. + XlaOp IsFinite(const XlaOp& operand); + + // Enqueues a convert instruction onto the computation that changes the + // element type of the operand array to primitive_type. + XlaOp ConvertElementType(const XlaOp& operand, + PrimitiveType new_element_type); + + // Enqueues a no-op instruction onto the computation that changes + // the element type of the operand array to primitive_type. The + // bit-widths of the source and destination element types must be + // identical. + XlaOp BitcastConvertType(const XlaOp& operand, + PrimitiveType new_element_type); + + // Enqueues a negate instruction onto the computation. + XlaOp Neg(const XlaOp& operand); + + // Enqueues a transpose instruction onto the computation. + XlaOp Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation); + + // Enqueues a reverse instruction onto the computation. The order of the + // elements in the given dimensions is reversed (i.e., the element at index i + // is moved to index dimension_size - 1 - i). + XlaOp Rev(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + + // Enqueues a sort (as increasing order) instruction onto the computation. + // If only keys are provided: + // * If the keys are an rank-1 tensor (an array), the result is a sorted array + // of keys, in ascending order. + // * If the keys have higher rank, the keys are sorted along the provided + // dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension + // value of 0 will indepenently sort every column, and a dimension value of 1 + // will independently sort each row. If no dimension number is provided, then + // the last dimension is chosen by default. + // + // If both keys and values are provided: + // * The keys and the values must tensors with the same dimensions. The + // element types of the tensors may be different. + // * The result is a tuple that consists of a sorted tensor of keys (along the + // provided dimension, as above) as the first element, and a tensor with their + // corresponding values as the second element. + XlaOp Sort(XlaOp keys, + tensorflow::gtl::optional values = tensorflow::gtl::nullopt, + int64 dimension = -1); + + // Enqueues a clamp instruction onto the computation. + XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); + + // Enqueues a map instruction onto the computation. + XlaOp Map(tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands = {}); + + // Enqueues a N(mu, sigma) random number generation instruction onto the + // computation. + XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); + + // Enqueues a U(a, b) random number generation instruction onto the + // computation. Returns values in the semi-open interval [a, b). + XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); + + // Enqueues a while node onto the computation. + XlaOp While(const XlaComputation& condition, const XlaComputation& body, + const XlaOp& init); + + // Enqueues a conditional node onto the computation. + XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation); + + // Enqueues a ReducePrecision node onto the computation. + XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits); + + // Enqueues a Gather node onto the computation. + XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds); + + // Enqueues a Scatter node onto the computation. + XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); + + // Enqueues a Send node onto the computation for device-to-device + // communication, to send the given operand to a Recv instruction that shares + // the same channel handle. + void Send(const XlaOp& operand, const ChannelHandle& handle); + XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, + const ChannelHandle& handle); + + // Enqueues a Send node which sends data to the host. + XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, const ChannelHandle& handle); + + // Enqueues a Recv node which receives data from the host. + XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); + + // Enqueues an AfterAll operation with no operands producing a token-shaped + // value. + XlaOp CreateToken(); + + // Enqueues an AfterAll operation with no operands producing a token-shaped + // value. + XlaOp AfterAll(tensorflow::gtl::ArraySlice tokens); + + // Enqueues a Recv node onto the computation. The data comes from a Send + // instruction that shares the same channel handle and its shape must + // be the same as the given shape. + XlaOp Recv(const Shape& shape, const ChannelHandle& handle); + XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); + + // Normalizes operand across spatial and batch dimensions for each feature. + // + // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` + // is the normalized result and batch_mean and batch_var are the mean and + // variance, respectively, across batch for the operand. + XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index); + + // Normalizes operand across spatial and batch dimensions for each feature. + // + // `BatchNormInference` is equivalent to calling `BatchNormTraining` without + // computing `mean` and `variance` for each batch inside the operation. It + // uses the input `mean` and `variance` instead as estimated values. The + // purpose of this op is to reduce latency in inference, hence the name + // `BatchNormInference`. + // + // The output has the same shape as `operand`, and contains the normalized + // values for each batch. + XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index); + + // Calculates the gradients of a batch norm op. + // + // The inputs `batch_mean` and `batch_var` represent the mean and variance + // across the batch. + // + // Returns a tuple of three elements: + // - grad_operand: Gradient with respect to input `operand` + // - grad_offset: Gradient with respect to input `offset` + // - grad_scale: Gradient with respect to input `scale` + XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index); + + StatusOr AddInstruction( + HloInstructionProto&& instr, HloOpcode opcode, + tensorflow::gtl::ArraySlice operands = {}); + + void AddCalledComputation(const XlaComputation& computation, + HloInstructionProto* instr); + + StatusOr LookUpInstruction(const XlaOp& op) const; + + // Internal helper method that does the building for an arbitrary unary op. + XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand); + + // Internal helper method that does the building for an arbitrary binary op. + // broadcast_dimensions specifies which dimensions to use for broadcasting + // when the operation is between tensors of different ranks. + XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + + // Internal helper method that does the building for an arbitrary ternary op. + XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, + const XlaOp& ehs); + + XlaOp RngOp(RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters, + const Shape& shape); + + StatusOr InDimBroadcast( + const Shape& shape, const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_dimensions); + + // Internal helper method that creates a sequence of instructions that + // performs an explicit broadcast of the operand to the target shape. + StatusOr AddBroadcastSequence(const Shape& output_shape, + const XlaOp& operand); + + // Internal helper method for creating a Reshape op with the already inferred + // shape. + StatusOr Reshape(const Shape& shape, const XlaOp& operand); + + // Returns the (inferred) result for the program shape using the given root. + StatusOr GetProgramShape(int64 root_id) const; + + // Returns shapes for the operands. + StatusOr> GetOperandShapes( + tensorflow::gtl::ArraySlice operands) const; + + // A visitor which checks whether an operation is a compile-time constant, + // meaning that it doesn't depend on any parameters, or on any stateful + // operation such as `RngNormal` or `Infeed`. The visitor walks the + // computation starting at a given operation and sets is_constant to false iff + // a parameter or stateful operation is encountered. + void IsConstantVisitor(const int64 op_handle, std::set* visited, + bool* is_constant) const; + + // Checks bounds for convolution parameters. + Status VerifyConvolution( + const Shape& lhs_shape, const Shape& rhs_shape, + const ConvolutionDimensionNumbers& dimension_numbers) const; + + // Helper function for creating a Window proto from user-supplied data. + // Returns error if the user-supplied data was invalid. + StatusOr MakeWindow( + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation) const; + + string name_; // Name to use for the built computation. + + // The first error encountered while building the computation. + // This is OK until the first error is encountered. + Status first_error_; + + // The saved stack trace from the point at which the first error occurred. + tensorflow::SavedStackTrace first_error_backtrace_; + + // The instructions of this computation. + std::vector instructions_; + + // The embedded computations used by this computation. Each computation was + // the entry computation of some XlaComputation, the key is the unique id of + // that XlaComputation. + std::map embedded_; + + // The unique parameter numbers. + tensorflow::gtl::FlatSet parameter_numbers_; + + // The metadata to attach to each op. This is structured as a "modal"-like + // operation, in order to simplify client code (and not sprinkle this metadata + // throughout the TensorFlow op kernel implementations). + OpMetadata metadata_; + + // Sharding for this operator. This is structured as a "model"-like operation, + // in order to simplify client code, similar to metadata_. + tensorflow::gtl::optional sharding_; + + // Mode bit that indicates whether to die when a first error is encountered. + bool die_immediately_on_error_ = false; + + XlaBuilder* parent_builder_{nullptr}; + + friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, + const Shape& shape, const string& name); + friend XlaOp ConstantLiteral(XlaBuilder* builder, + const LiteralSlice& literal); + template + friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value); + template + friend XlaOp ConstantR1(XlaBuilder* builder, + tensorflow::gtl::ArraySlice values); + friend XlaOp ConstantR1(XlaBuilder* builder, + const tensorflow::core::Bitmap& values); + template + friend XlaOp ConstantR2( + XlaBuilder* builder, + std::initializer_list> values); + template + friend XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, + const Array& values, + const Layout& layout); + template + friend XlaOp ConstantFromArray(XlaBuilder* builder, + const Array& values); + template + friend XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, + const Array2D& values, + const Layout& layout); + template + friend XlaOp ConstantR2FromArray2D(XlaBuilder* builder, + const Array2D& values); + template + friend XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, + const Array3D& values, + const Layout& layout); + template + friend XlaOp ConstantR3FromArray3D(XlaBuilder* builder, + const Array3D& values); + template + friend XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, + const Array4D& values, + const Layout& layout); + template + friend XlaOp ConstantR4FromArray4D(XlaBuilder* builder, + const Array4D& values); + + template + friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); + + friend XlaOp Broadcast(const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); + + friend XlaOp BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice broadcast_dimensions); + + friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config); + + friend XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); + + friend XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes); + + friend XlaOp Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + + friend XlaOp Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + + friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno); + + friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + + friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices); + + friend XlaOp ConcatInDim(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + int64 dimension); + + friend void Trace(const string& tag, const XlaOp& operand); + + friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true, + const XlaOp& on_false); + friend XlaOp Tuple(XlaBuilder* builder, + tensorflow::gtl::ArraySlice elements); + friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); + friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers); + friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + Padding padding); + friend XlaOp ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + friend XlaOp ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers); + friend XlaOp ConvGeneral( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers); + friend XlaOp ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + friend XlaOp Fft(const XlaOp& operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, + const string& config); + friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config); + friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands); + friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape); + friend XlaOp HostCompute(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, + const Shape& shape); + friend XlaOp Complex(const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Conj(const XlaOp& operand); + friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Not(const XlaOp& operand); + friend XlaOp ShiftLeft( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); + friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation); + friend XlaOp ReduceWindow( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding); + friend XlaOp ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + friend XlaOp CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids); + friend XlaOp CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids, + const tensorflow::gtl::optional& channel_id); + friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups); + friend XlaOp SelectAndScatter( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + friend XlaOp SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + friend XlaOp Abs(const XlaOp& operand); + friend XlaOp Atan2(const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Exp(const XlaOp& operand); + friend XlaOp Expm1(const XlaOp& operand); + friend XlaOp Floor(const XlaOp& operand); + friend XlaOp Ceil(const XlaOp& operand); + friend XlaOp Round(const XlaOp& operand); + friend XlaOp Log(const XlaOp& operand); + friend XlaOp Log1p(const XlaOp& operand); + friend XlaOp Sign(const XlaOp& operand); + friend XlaOp Clz(const XlaOp& operand); + friend XlaOp Cos(const XlaOp& operand); + friend XlaOp Sin(const XlaOp& operand); + friend XlaOp Tanh(const XlaOp& operand); + friend XlaOp Real(const XlaOp& operand); + friend XlaOp Imag(const XlaOp& operand); + friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp IsFinite(const XlaOp& operand); + // TODO(b/64798317): Finish CPU & GPU implementation, then replace xla::Iota + // in xla/client/lib/numeric.h with this (renamed to xla::Iota). + friend XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size); + friend XlaOp ConvertElementType(const XlaOp& operand, + PrimitiveType new_element_type); + friend XlaOp BitcastConvertType(const XlaOp& operand, + PrimitiveType new_element_type); + friend XlaOp Neg(const XlaOp& operand); + friend XlaOp Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation); + friend XlaOp Rev(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + friend XlaOp Sort(XlaOp keys, tensorflow::gtl::optional values, + int64 dimension); + friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); + friend XlaOp Map(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands); + friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, + const Shape& shape); + friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); + friend XlaOp While(const XlaComputation& condition, + const XlaComputation& body, const XlaOp& init); + friend XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation); + friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits); + friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds); + friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); + friend void Send(const XlaOp& operand, const ChannelHandle& handle); + friend XlaOp Recv(XlaBuilder* builder, const Shape& shape, + const ChannelHandle& handle); + friend XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index); + friend XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index); + friend XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index); + friend XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, + const ChannelHandle& handle); + friend XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); + friend XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const ChannelHandle& handle); + friend XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); + friend XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, + const string& config); + friend XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const string& outfeed_config); + friend XlaOp CreateToken(XlaBuilder* builder); + friend XlaOp AfterAll(XlaBuilder* builder, + tensorflow::gtl::ArraySlice tokens); +}; + +// RAII-style object: sets the current sharding assignment in builder on +// construction, and sets back to the previous assignment on destruction. +class XlaScopedShardingAssignment { + public: + XlaScopedShardingAssignment(xla::XlaBuilder* builder, + tensorflow::gtl::optional sharding) + : builder_(builder), prev_sharding_(builder->sharding()) { + SetSharding(sharding); + } + + XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete; + XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) = + delete; + + ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); } + + private: + void SetSharding(const tensorflow::gtl::optional& sharding) { + if (sharding.has_value()) { + builder_->SetSharding(sharding.value()); + } else { + builder_->ClearSharding(); + } + } + + xla::XlaBuilder* const builder_; + tensorflow::gtl::optional prev_sharding_; +}; + +// Free functions for building XlaOps. The intention is that these will +// become the public API for building XlaOps rather than calling methods on +// XlaBuilder directly. + +// Enqueues a "retrieve parameter value" instruction for a parameter that was +// passed to the computation. +XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape, + const string& name); + +// Enqueues a constant with the value of the given literal onto the +// computation. +XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal); + +// Enqueues a constant onto the computation. Methods are templated on the +// native host type (NativeT) which corresponds to a specific XLA +// PrimitiveType as given in the following table: +// +// Native Type PrimitiveType +// ----------------------------- +// bool PRED +// int32 S32 +// int64 S64 +// uint32 U32 +// uint64 U64 +// float F32 +// double F64 +// +// Note: not all primitive types defined in xla_data.proto have a +// corresponding native type yet. +template +XlaOp ConstantR0(XlaBuilder* builder, NativeT value); +template +XlaOp ConstantR1(XlaBuilder* builder, + tensorflow::gtl::ArraySlice values); +XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values); +template +XlaOp ConstantR2(XlaBuilder* builder, + std::initializer_list> values); +template +XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, + const Array& values, + const Layout& layout); +template +XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values); +template +XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, + const Array2D& values, + const Layout& layout); +template +XlaOp ConstantR2FromArray2D(XlaBuilder* builder, + const Array2D& values); +template +XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, + const Array3D& values, + const Layout& layout); +template +XlaOp ConstantR3FromArray3D(XlaBuilder* builder, + const Array3D& values); +template +XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, + const Array4D& values, + const Layout& layout); +template +XlaOp ConstantR4FromArray4D(XlaBuilder* builder, + const Array4D& values); + +// Enqueues a rank one constant (XlaBuilder* builder, vector) onto the +// computation. The vector has size 'length' and every element has the value +// 'value'. +template +XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); + +// Adds dimensions to an array by duplicating the data in the array. +// +// The new dimensions are inserted on the left, i.e. if +// broadcast_sizes has values {a0, ..., aN} and the operand shape +// has dimensions {b0, ..., bM} then the shape of the output has +// dimensions {a0, ..., aN, b0, ..., bM}. +// +// The new dimensions index into copies of the operand, i.e. +// +// output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] +XlaOp Broadcast(const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); + +// Performs in-dimension-style broadcast. +// +// Operand specifies the input to be broadcast. "shape" is expected output +// shape. "broadcast_dimensions" are the dimensions to be broadcasting into. +// Dimension numbers in broadcast_dimensions map to individual dimensions +// of the operand, and specify what dimension of the output shape they +// should be broadcast. +// e.g. +// Say operand = [1, 2], i.e., a 1D tensor with 2 elements. +// and dimension of shape is [2,2]. +// Specifying {1} as brodcast_dimension will generate output +// [1 , 2] +// [1 , 2] +// On the other hand, specifying {0} as broadcast_dimension +// will generate output +// [1 , 1] +// [2 , 2] +XlaOp BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice broadcast_dimensions); + +// Enqueues a pad operation onto the computation that pads the given value on +// the edges as well as between the elements of the input. padding_config +// specifies the padding amount for each dimension. +XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config); + +// Enqueues an operation onto the computation that flattens the operand based +// on the dimension order (major/slowest-varying to minor/fastest-varying) +// given, followed by reshaping it into the shape with the given dimension +// sizes (also major to minor). Conceptually, this is a limited form of +// "shape casting". +XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); + +// Enqueues an operation onto the computation that collapses the operand, from +// first to last dimension (C order), then reshapes it to the given dimension +// sizes. Conceptually, this is a limited form of "shape casting". +XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes); + +// Wrapper for Reshape. +// Enqueues an operation to collapse the provided dimensions; e.g. an +// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to +// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must +// be a consecutive, in-order subsequence of the operand dimensions. +// +// Note that collapsing a single dimension does nothing: +// +// {256} collapsing {0} => {256} +// {1} collapsing {0} => {1} +// +// Collapsing multiple dimensions produces a single result dimension: +// +// {256, 2} collapsing {0,1} => {512} +// {256, 2, 3} collapsing {0,1} => {512, 3} +// +// This could potentially cause data to be moved -- it provides a more +// structured form of reshaping than an arbitrary Reshape operation. +XlaOp Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + +// Enqueues a slice operation onto the computation that slices the operand +// from the start indices to the limit indices; e.g. +// +// x +// [ 0 1 2 3 ] +// y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] +// [ 8 9 a b ] +// +// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D +// range notation. +// The strides parameter determines the stride over the slice +XlaOp Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + +// Enqueues a slice operation in a given dimension, taking all other +// dimensions as they are; e.g. if dimno is 1 from start_index 2 to +// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand +// for: +// +// array[:, 2:4:1, :] +XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, + int64 stride, int64 dimno); + +// Enqueues a slice operation onto the computation that slices the 'operand' +// from dynamic start indices which are passed in 'start_indices'. +// The size of the slice in each dimension is passed in 'slice_sizes', +// which specify the end point of exclusive slice intervals in each +// dimension [start, start + size). +// The shape of 'start_indices' must be rank == 1, with dimension size +// equal to the rank of the 'operand'. +// Slice index calculations are computed modulo input dimension sizes to +// prevent dynamic start indices from generating out-of-bound array accesses. +XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + +// Enqueues a dynamic update slice operation onto the computation, which +// updates a slice of 'operand' with 'update' at dynamic 'start_indices'. +// The shape of 'update' determines the shape of the slice of 'operand' +// which is updated. +// The indices specified in 'start_indices' specify the offset of the slice +// of 'operand' which is updated. +// +// update = {10, 11} // calculated at runtime. +// [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] +// [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] +// [7 8 9] [7 8 9 ] +// +// The shape of 'start_indices' must be rank == 1, with dimension size +// equal to the rank of the 'operand'. +// Slice index calculations are computed modulo update dimension sizes to +// prevent dynamic start indices from generating out-of-bound array accesses. +XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices); + +// Enqueues a concatenate instruction onto the computation. 'operands' must +// have >= 1 entry. +XlaOp ConcatInDim(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, int64 dimension); + +// Enqueue a tracing operation onto the computation; the computation will emit +// a logging message with the operand. +void Trace(const string& tag, const XlaOp& operand); + +// Enqueues a conditional-move-like select operation onto the computation; +// predicated on pred, selects between on_true and on_false. +XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); + +// Enqueues a tuple-creation instruction onto the computation. +XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice elements); + +// Enqueues a tuple-element-get instruction onto the computation. +XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); + +// Enqueues an equal-to comparison instruction onto the computation. +XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a not-equal comparison instruction onto the computation. +XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a greater-or-equal comparison instruction onto the computation. +XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a greater-than comparison instruction onto the computation. +XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a less-than comparison instruction onto the computation. +XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a less-or-equal comparison instruction onto the computation. +XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a dot instruction onto the computation. +XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + +// Enqueues a general dot instruction onto the computation. +XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers); + +// Enqueues a convolution instruction onto the computation, which uses the +// default convolution dimension numbers. +XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided padding configuration in the format returned by MakePadding(). +XlaOp ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided dimension numbers configuration. +XlaOp ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided padding configuration as well as the dimension numbers. +XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided padding configuration, dilation factors and dimension numbers. +XlaOp ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + +// Enqueues an FFT instruction onto the computation, of the given type and +// with the given FFT length. +XlaOp Fft(const XlaOp& operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + +// Enqueues an infeed instruction onto the computation, which writes data of +// the given shape to the infeed buffer of the device. +XlaOp Infeed(XlaBuilder* builder, const Shape& shape, + const string& config = ""); + +// Variant of Infeed which takes a token-shaped operand and produces a +// two-element tuple containing the data value and a token-shaped value. +// Tokens are used for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, + const string& config = ""); + +// Enqueues an outfeed instruction onto the computation. This instruction +// generates outgoing data transfers for the given data. +// +// shape_with_layout communicates the laid out shape that we want to outfeed +// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error +// will occur. +void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config); + +// Variant of Outfeed which takes a token-shaped operand and produces a +// token-shaped value. Tokens are used for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const string& outfeed_config); + +// Enqueues a call instruction onto the computation. +XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands); + +// Enqueues a custom call instruction onto the computation. +// During code generation, a call instruction is emitted which targets a +// symbol with the name |call_target_name|. The |operands| are passed to the +// call instruction. |shape| is the resultant shape. +XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape); + +// Enqueues a pseudo-op to represent host-side computation data-dependencies. +// During code generation, host send and receive operations will be generated +// to transfer |operands| to the host and a single result of |shape| back to +// the device. Host send/recv operations are emitted using |channel_name|. +// Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO +// instruction scheduling. +XlaOp HostCompute(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, + const Shape& shape); + +// The following methods enqueue element-wise binary arithmetic operations +// onto the computation. The shapes of the operands have to match unless one +// of the operands is a scalar, or an explicit broadcast dimension is given +// (see g3doc for more details). + +// Enqueues a complex compose instruction onto the computation. +XlaOp Complex(const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a complex conjugate instruction onto the computation. +XlaOp Conj(const XlaOp& operand); + +// Enqueues an add instruction onto the computation. +XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a subtract instruction onto the computation. +XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a multiply instruction onto the computation. +XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a divide instruction onto the computation. +XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a remainder instruction onto the computation. +XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a max instruction onto the computation. +XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues a min instruction onto the computation. +XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Element-wise logical operators +XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +XlaOp Not(const XlaOp& operand); + +XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); +XlaOp ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); +XlaOp ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Reduces an array among the provided dimensions, given "computation" as a +// reduction operator. +XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); + +// Convenience wrapper around the above that reduces all the dimensions in the +// operand shape. +XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation); + +// Enqueues a windowed reduce instruction onto the computation. +XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding); + +// As ReduceWindow(), but the padding is given in the format +// returned by MakePadding(). +XlaOp ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + +// Returns the sum of the operand value within each subgroup of replicas. All +// replicas supply one input to the sum and all replicas receive the resulting +// sum for each subgroup. +XlaOp CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids = {}); + +// Enqueues an operation that do an AllReduce of the operand cross cores. Here +// AllReduce means doing a reduction on the input operand cross cores and then +// broadcasting the reduction result to those cores. The reduction function is +// defined by `computation`, which should be a commutative computation on +// scalars, e.g., add, min, or max. The way that AllReduce is applied is +// configured by: +// +// - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all +// replicas belong to one group. Allreduce will be applied within subgroups. +// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, +// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. +// +// - `channel_id`: for Allreduce nodes from different modules, if they have the +// same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be +// applied cross modules. +// +// TODO(b/79737069): Rename this to AllReduce when it's ready to use. +XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids = {}, + const tensorflow::gtl::optional& + channel_id = tensorflow::gtl::nullopt); + +// Enqueues an operation that do an Alltoall of the operand cross cores. +// +// TODO(b/110096724): This is NOT YET ready to use. +XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups = {}); + +// Enqueues an operation that scatters the `source` array to the selected +// indices of each window. +XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter); + +// As SelectAndScatter(), but the padding is given in the format +// returned by MakePadding(). +XlaOp SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + +// Enqueues an abs instruction onto the computation. +XlaOp Abs(const XlaOp& operand); + +// Enqueues a atan2 instruction onto the computation. +XlaOp Atan2(const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues an exp instruction onto the computation. +XlaOp Exp(const XlaOp& operand); + +// Enqueues an expm1 instruction onto the computation. +XlaOp Expm1(const XlaOp& operand); + +// Enqueues a floor instruction onto the computation. +XlaOp Floor(const XlaOp& operand); + +// Enqueues a ceil instruction onto the computation. +XlaOp Ceil(const XlaOp& operand); + +// Enqueues a round instruction onto the computation, rounding to nearest even +// with half-way cases rounding away from zero. +XlaOp Round(const XlaOp& operand); + +// Enqueues an log instruction (natural logarithm) onto the computation. +XlaOp Log(const XlaOp& operand); + +// Enqueues an log1p instruction (log(x+1)) onto the computation. +XlaOp Log1p(const XlaOp& operand); + +// Enqueues a sign instruction onto the computation. +XlaOp Sign(const XlaOp& operand); + +// Enqueues a count leading zeros instruction onto the computation. +XlaOp Clz(const XlaOp& operand); + +// Enqueues a cosine instruction onto the computation. +XlaOp Cos(const XlaOp& operand); + +// Enqueues a sine instruction onto the computation. +XlaOp Sin(const XlaOp& operand); + +// Enqueues a tanh instruction onto the computation. +XlaOp Tanh(const XlaOp& operand); + +// Enqueues a real-part instruction onto the computation. +XlaOp Real(const XlaOp& operand); + +// Enqueues an imaginary-part instruction onto the computation. +XlaOp Imag(const XlaOp& operand); + +// Enqueues a lhs^rhs computation onto the computation. +XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + +// Enqueues an operator that tests if the operand's values are finite, i.e., +// not Inf or NaN. Defined only for floating-point types. Returns an array of +// booleans with the same shape where entries are true iff the corresponding +// entry was NaN. +XlaOp IsFinite(const XlaOp& operand); + +// Enqueues a convert instruction onto the computation that changes the +// element type of the operand array to primitive_type. +XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); + +// Enqueues a no-op instruction onto the computation that changes +// the element type of the operand array to primitive_type. The +// bit-widths of the source and destination element types must be +// identical. +XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); + +// Enqueues a negate instruction onto the computation. +XlaOp Neg(const XlaOp& operand); + +// Enqueues a transpose instruction onto the computation. +XlaOp Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation); + +// Enqueues a reverse instruction onto the computation. The order of the +// elements in the given dimensions is reversed (i.e., the element at index i +// is moved to index dimension_size - 1 - i). +XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions); + +// Enqueues a sort (as increasing order) instruction onto the computation. +// If only keys are provided: +// * If the keys are an rank-1 tensor (an array), the result is a sorted array +// of keys, in ascending order. +// * If the keys have higher rank, the keys are sorted along the provided +// dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension +// value of 0 will indepenently sort every column, and a dimension value of 1 +// will independently sort each row. If no dimension number is provided, then +// the last dimension is chosen by default. +// +// If both keys and values are provided: +// * The keys and the values must tensors with the same dimensions. The +// element types of the tensors may be different. +// * The result is a tuple that consists of a sorted tensor of keys (along the +// provided dimension, as above) as the first element, and a tensor with their +// corresponding values as the second element. +XlaOp Sort(XlaOp keys, + tensorflow::gtl::optional values = tensorflow::gtl::nullopt, + int64 dimension = -1); + +// Enqueues a clamp instruction onto the computation. +XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); + +// Enqueues a map instruction onto the computation. +XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands = {}); + +// Enqueues a N(mu, sigma) random number generation instruction onto the +// computation. +XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); + +// Enqueues a U(a, b) random number generation instruction onto the +// computation. Returns values in the semi-open interval [a, b). +XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); + +// Enqueues a while node onto the computation. +XlaOp While(const XlaComputation& condition, const XlaComputation& body, + const XlaOp& init); + +// Enqueues a conditional node onto the computation. +XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation); + +// Enqueues a ReducePrecision node onto the computation. +XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits); + +// Enqueues a Gather node onto the computation. +XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds); + +// Enqueues a Scatter node onto the computation. +XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); + +// Enqueues a Send node onto the computation for device-to-device +// communication. This operation sends the given operand to +// a Recv instruction in a different computation that shares the same channel +// handle. +void Send(const XlaOp& operand, const ChannelHandle& handle); + +// Variant of Send which takes a token-shaped operand and produces a +// token-shaped value. Tokens are used for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, + const ChannelHandle& handle); + +// Enqueues a Recv node onto the computation for device-to-device +// communication. The data comes from a Send instruction in a different +// computation that shares the same channel handle and its shape must be the +// same as the given shape. +XlaOp Recv(XlaBuilder* builder, const Shape& shape, + const ChannelHandle& handle); + +// Variant of Recv which takes a token-shaped operand and produces a two-element +// tuple containing the data value and a token-shaped value. Tokens are used +// for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); + +// Enqueues a Send node which transfers data from the device to the host. The +// 'shape_with_layout' argument defines the layout of the data transferred; its +// shape must be compatible with the shape of the operand. The operand must be +// array-shaped. +// TODO(b/111544877): Support tuple shapes. +XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, const ChannelHandle& handle); + +// Enqueues a Recv node which transfers data from the host to the device. The +// given shape must contain a layout and must be an array. +// TODO(b/111544877): Support tuple shapes. +XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); + +// Enqueues an operation (AfterAll) with no operands that produces a +// token-shaped value. Tokens are used for ordering side-effecting operations. +// This is a separate method from AfterAll to facility the removal of +// operand-less AfterAll instructions. +// TODO(b/110532604): Remove this function when all tokens are derived from a +// single token generated or passed into the entry computation. +XlaOp CreateToken(XlaBuilder* builder); + +// Enqueues an AfterAll instruction which produces a token-shaped value and +// takes a variadic number of token-shaped operands. The number of operands must +// be greater than zero. Used for joining tokens. +XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice tokens); + +// Normalizes operand across spatial and batch dimensions for each feature. +// +// Returns a tuple (normalized, batch_mean, batch_var) where `normalized` +// is the normalized result and batch_mean and batch_var are the mean and +// variance, respectively, across batch for the operand. +XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index); + +// Normalizes operand across spatial and batch dimensions for each feature. +// +// `BatchNormInference` is equivalent to calling `BatchNormTraining` without +// computing `mean` and `variance` for each batch inside the operation. It +// uses the input `mean` and `variance` instead as estimated values. The +// purpose of this op is to reduce latency in inference, hence the name +// `BatchNormInference`. +// +// The output has the same shape as `operand`, and contains the normalized +// values for each batch. +XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index); + +// Calculates the gradients of a batch norm op. +// +// The inputs `batch_mean` and `batch_var` represent the mean and variance +// across the batch. +// +// Returns a tuple of three elements: +// - grad_operand: Gradient with respect to input `operand` +// - grad_offset: Gradient with respect to input `offset` +// - grad_scale: Gradient with respect to input `scale` +XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index); + +// Implementation details below this point. + +template +XlaOp XlaBuilder::ConstantR0(NativeT value) { + return ConstantLiteral(*LiteralUtil::CreateR0(value)); +} + +template +XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice values) { + return ConstantLiteral(*LiteralUtil::CreateR1(values)); +} + +template +XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { + Literal literal(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {length})); + literal.PopulateWithValue(value); + return ConstantLiteral(literal); +} + +inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { + return ConstantLiteral(*LiteralUtil::CreateR1(values)); +} + +template +XlaOp XlaBuilder::ConstantR2( + std::initializer_list> values) { + return ConstantLiteral(*LiteralUtil::CreateR2(values)); +} + +template +XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array& values, + const Layout& layout) { + return ConstantLiteral( + *LiteralUtil::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp XlaBuilder::ConstantFromArray(const Array& values) { + return ConstantLiteral(*LiteralUtil::CreateFromArray(values)); +} + +template +XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { + return ConstantLiteral( + *LiteralUtil::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D& values) { + return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D(values)); +} + +template +XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( + const Array3D& values, const Layout& layout) { + return ConstantLiteral( + *LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); +} + +template +XlaOp XlaBuilder::ConstantR3FromArray3D(const Array3D& values) { + return ConstantFromArray(values); +} + +template +XlaOp XlaBuilder::ConstantR4FromArray4DWithLayout( + const Array4D& values, const Layout& layout) { + return ConstantFromArrayWithLayout(values, layout); +} + +template +XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { + return ConstantFromArray(values); +} + +// Free function template implementations. + +template +XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { + return ConstantLiteral(builder, *LiteralUtil::CreateR0(value)); +} + +template +XlaOp ConstantR1(XlaBuilder* builder, + tensorflow::gtl::ArraySlice values) { + return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); +} + +template +XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) { + Literal literal(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {length})); + literal.PopulateWithValue(value); + return ConstantLiteral(builder, literal); +} + +inline XlaOp ConstantR1(XlaBuilder* builder, + const tensorflow::core::Bitmap& values) { + return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); +} + +template +XlaOp ConstantR2(XlaBuilder* builder, + std::initializer_list> values) { + return ConstantLiteral(builder, *LiteralUtil::CreateR2(values)); +} + +template +XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, + const Array& values, + const Layout& layout) { + return ConstantLiteral( + builder, + *LiteralUtil::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values) { + return ConstantLiteral(builder, + *LiteralUtil::CreateFromArray(values)); +} + +template +XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, + const Array2D& values, + const Layout& layout) { + return ConstantLiteral( + builder, + *LiteralUtil::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp ConstantR2FromArray2D(XlaBuilder* builder, + const Array2D& values) { + return ConstantLiteral(builder, + *LiteralUtil::CreateR2FromArray2D(values)); +} + +template +XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, + const Array3D& values, + const Layout& layout) { + return ConstantLiteral( + builder, + *LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); +} + +template +XlaOp ConstantR3FromArray3D(XlaBuilder* builder, + const Array3D& values) { + return ConstantFromArray(builder, values); +} + +template +XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, + const Array4D& values, + const Layout& layout) { + return ConstantFromArrayWithLayout(builder, values, layout); +} + +template +XlaOp ConstantR4FromArray4D(XlaBuilder* builder, + const Array4D& values) { + return ConstantFromArray(builder, values); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..49a15ec3b449bdec07aa6ecfbc40b7b9f62c3f4e --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -0,0 +1,391 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/xla_builder.h" + +#include + +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +namespace { + +namespace op = xla::testing::opcode_matchers; + +using ::testing::HasSubstr; + +// TODO(b/74197823): Move the tests to service/. +class XlaBuilderTest : public ::testing::Test { + protected: + StatusOr> BuildHloModule(XlaBuilder* b) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build()); + const HloModuleProto& proto = computation.proto(); + TF_ASSIGN_OR_RETURN(const auto& config, + HloModule::CreateModuleConfigFromProto( + proto, legacy_flags::GetDebugOptionsFromFlags())); + return HloModule::CreateFromProto(proto, config); + } + + // Overload which explicitly specifies the root instruction. + StatusOr> BuildHloModule(XlaBuilder* b, + XlaOp root) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build(root)); + const HloModuleProto& proto = computation.proto(); + TF_ASSIGN_OR_RETURN(const auto& config, + HloModule::CreateModuleConfigFromProto( + proto, legacy_flags::GetDebugOptionsFromFlags())); + return HloModule::CreateFromProto(proto, config); + } + + // Returns the name of the test currently being run. + string TestName() const { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } +}; + +TEST_F(XlaBuilderTest, OnePlusTwo) { + XlaBuilder b(TestName()); + Add(ConstantR0(&b, 1.0), ConstantR0(&b, 2.0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); +} + +TEST_F(XlaBuilderTest, UnaryOperatorsBuildExpectedHLO) { + auto test_unary_operator = + [&](std::function op, + ::testing::Matcher matches_pattern) { + XlaBuilder b(TestName()); + op(ConstantR0(&b, 1)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, matches_pattern); + }; + test_unary_operator([](XlaOp x) { return -x; }, op::Negate(op::Constant())); + test_unary_operator([](XlaOp x) { return ~x; }, op::Not(op::Constant())); +} + +TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) { + auto test_binary_operator = + [&](std::function op, + ::testing::Matcher matches_pattern) { + XlaBuilder b(TestName()); + op(ConstantR0(&b, 1), ConstantR0(&b, 2)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, matches_pattern); + }; + + test_binary_operator([](XlaOp x, XlaOp y) { return x + y; }, + op::Add(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x - y; }, + op::Subtract(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x * y; }, + op::Multiply(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x / y; }, + op::Divide(op::Constant(), op::Constant())); + + test_binary_operator([](XlaOp x, XlaOp y) { return x & y; }, + op::And(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x | y; }, + op::Or(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x ^ y; }, + op::Xor(op::Constant(), op::Constant())); + test_binary_operator([](XlaOp x, XlaOp y) { return x << y; }, + op::ShiftLeft(op::Constant(), op::Constant())); + test_binary_operator( + [](XlaOp x, XlaOp y) { return x >> y; }, + op::ShiftRightArithmetic(op::Constant(), op::Constant())); + + auto test_unsigned_binary_operator = + [&](std::function op, + ::testing::Matcher matches_pattern) { + XlaBuilder b(TestName()); + op(ConstantR0(&b, 1), ConstantR0(&b, 2)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, matches_pattern); + }; + test_unsigned_binary_operator( + [](XlaOp x, XlaOp y) { return x >> y; }, + op::ShiftRightLogical(op::Constant(), op::Constant())); +} + +TEST_F(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) { + XlaBuilder b(TestName()); + ConstantR0(&b, 1) >> ConstantR0(&b, 2); + auto statusor = b.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Argument to >> operator does not have an integral type")); +} + +TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); + Add(x, ConstantR0(&b, 1.0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(), op::Broadcast(op::Constant()))); +} + +TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { + XlaBuilder b(TestName()); + const auto& x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6}); + const auto& y_shape = ShapeUtil::MakeShape(S32, {2, 4}); + auto x = Parameter(&b, 0, x_shape, "x"); + auto y = Parameter(&b, 1, y_shape, "y"); + auto add = Add(x, y, /*broadcast_dimensions=*/{0, 1}); + + TF_ASSERT_OK_AND_ASSIGN(auto add_shape, b.GetShape(add)); + EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape)); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(0), op::Broadcast(op::Parameter(1)))); +} + +TEST_F(XlaBuilderTest, XPlusX) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {1, 3, 5, 7}), "x"); + Add(x, x); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(0))); +} + +TEST_F(XlaBuilderTest, ShapeInferenceError) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x"); + auto y = Parameter(&b, 1, ShapeUtil::MakeShape(U32, {2, 4}), "y"); + Add(x, y); + auto statusor = BuildHloModule(&b); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("shape inference")); +} + +TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) { + XlaBuilder b_call("add"); + Parameter(&b_call, 0, ShapeUtil::MakeShape(PRED, {}), "x"); + + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "x"); + auto y = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "y"); + Add(x, y); + auto statusor = BuildHloModule(&b); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("parameter 0 already registered")); +} + +TEST_F(XlaBuilderTest, Call) { + XlaBuilder b_call("the_only_to_apply"); + auto p0 = Parameter(&b_call, 0, ShapeUtil::MakeShape(F32, {}), "p0"); + auto p1 = Parameter(&b_call, 1, ShapeUtil::MakeShape(F32, {}), "p1"); + Add(p0, p1); + TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build()); + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y"); + auto one = ConstantR0(&b, 1); + auto two = ConstantR0(&b, 2); + Add(Call(&b, call, {x, y}), Call(&b, call, {one, two})); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Call(op::Parameter(), op::Parameter()), + op::Call(op::Constant(), op::Constant()))); +} + +TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x"); + auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y"); + Add(x, y); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + + // Expected: + // + // x: f32[1,2,3] y: f32[1,2,1] + // | | + // | reshape: f32[1,2] + // | | + // | broadcast: f32[1,2,3] + // \ / + // add + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(0), + op::Broadcast(op::Reshape(op::Parameter(1))))); +} + +TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x"); + auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y"); + Add(x, y, /*broadcast_dimensions=*/{0, 1}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + + // The binary operation has in-dim broadcast and degenerate broadcast, should + // first do the in-dim broadcast then convert the degnerate broadcast into a + // reshape and a broadcast. + // + // Expected: + // + // x: f32[2,3] y: f32[2,1,4] + // | | + // broadcast: f32[2,3,4] reshape: f32[2,4] + // | | + // | broadcast: f32[2,3,4] + // \ / + // add + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Broadcast(op::Parameter(0)), + op::Broadcast(op::Reshape(op::Parameter(1))))); +} + +TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { + XlaBuilder b1("b1"); + auto p0 = Parameter(&b1, 0, ShapeUtil::MakeShape(F32, {}), "p0"); + XlaBuilder builder("main"); + auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "p"); + Add(p, p0); + auto statusor = builder.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "built by builder 'b1', but is trying to use it in builder 'main'")); +} + +TEST_F(XlaBuilderTest, ReshapeDefaultOrder) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); + Reshape(x, /*new_sizes=*/{6, 35}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reshape(op::Parameter())); +} + +TEST_F(XlaBuilderTest, ReshapeHasTranspose) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); + Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reshape(op::Transpose(op::Parameter()))); +} + +TEST_F(XlaBuilderTest, Transpose) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + Transpose(x, /*permutation=*/{1, 0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Transpose(op::Parameter())); +} + +TEST_F(XlaBuilderTest, AllToAll) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x"); + AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, + /*split_count=*/2); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + + // AllToAll is decomposed into slices -> all-to-all -> gte -> concat. + EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate); + EXPECT_EQ(root->operand(0)->operand(0)->opcode(), HloOpcode::kAllToAll); + EXPECT_TRUE( + ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8}))); +} + +TEST_F(XlaBuilderTest, ReportError) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + Add(b.ReportError(InvalidArgument("a test error")), x); + auto statusor = b.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error")); +} + +TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesNonErrors) { + XlaBuilder b(TestName()); + StatusOr op(ConstantR0(&b, 1.0)); + Add(b.ReportErrorOrReturn(op), ConstantR0(&b, 2.0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); +} + +TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) { + XlaBuilder b(TestName()); + StatusOr op(InvalidArgument("a test error")); + Add(b.ReportErrorOrReturn(op), ConstantR0(&b, 2.0)); + auto statusor = b.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error")); +} + +TEST_F(XlaBuilderTest, BuildWithSpecificRoot) { + XlaBuilder b(TestName()); + XlaOp constant = ConstantR0(&b, 1.0); + Add(constant, ConstantR0(&b, 2.0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/constant)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Constant()); +} + +TEST_F(XlaBuilderTest, BuildWithSpecificRootAndMultipleParameters) { + // Specifying a particular root in Build should still include all entry + // parameters. + XlaBuilder b(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {42, 123}); + XlaOp x = Parameter(&b, 0, shape, "x"); + XlaOp y = Parameter(&b, 1, shape, "y"); + XlaOp z = Parameter(&b, 2, shape, "z"); + Add(x, Sub(y, z)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/x)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Parameter()); + EXPECT_EQ(module->entry_computation()->num_parameters(), 3); + EXPECT_EQ(module->entry_computation()->instruction_count(), 5); +} + +TEST_F(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) { + XlaBuilder b(TestName()); + XlaBuilder other_b(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {42, 123}); + + Parameter(&b, 0, shape, "param"); + XlaOp other_param = Parameter(&other_b, 0, shape, "other_param"); + + Status status = b.Build(other_param).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("root operation is not in this computation")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD deleted file mode 100644 index 507a2dc5f088e159156f0ef3d663ba2819f6a2d4..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/xla_client/BUILD +++ /dev/null @@ -1,78 +0,0 @@ -# Description: -# The new XLA client libraries. -# -# This is NOT YET ready to use. - -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = [":friends"]) - -package_group( - name = "friends", - includes = [ - "//tensorflow/compiler/xla:friends", - ], -) - -# Filegroup used to collect source files for dependency checking. -filegroup( - name = "c_srcs", - data = glob([ - "**/*.cc", - "**/*.h", - ]), -) - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "xla_computation", - srcs = ["xla_computation.cc"], - hdrs = ["xla_computation.h"], - deps = [ - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_proto", - ], -) - -cc_library( - name = "xla_builder", - srcs = ["xla_builder.cc"], - hdrs = ["xla_builder.h"], - deps = [ - ":xla_computation", - "//tensorflow/compiler/xla:execution_options_util", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/compiler/xla/service:shape_inference", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "xla_builder_test", - srcs = ["xla_builder_test.cc"], - deps = [ - ":xla_builder", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_matchers", - "//tensorflow/core:test", - ], -) diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h deleted file mode 100644 index 0329e42ed1aef8edd1537e888ddcd78f08584407..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ /dev/null @@ -1,1039 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ - -#include -#include -#include - -#include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/stacktrace.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -class XlaBuilder; - -// This represents an instruction that has been enqueued using the XlaBuilder. -// This is used to pass to subsequent computations that depends upon the -// instruction as an operand. -class XlaOp { - public: - XlaOp() : handle_(0), builder_(nullptr) {} - ~XlaOp() {} - - const XlaBuilder* builder() const { return builder_; } - - bool operator==(const XlaOp& rhs) const { - return handle_ == rhs.handle_ && builder_ == rhs.builder_; - } - - bool operator!=(const XlaOp& rhs) const { - return handle_ != rhs.handle_ || builder_ != rhs.builder_; - } - - friend std::ostream& operator<<(std::ostream& out, const XlaOp& op) { - out << op.handle(); - return out; - } - - private: - XlaOp(int64 handle, XlaBuilder* builder) - : handle_(handle), builder_(builder) {} - - int64 handle() const { return handle_; } - - friend class XlaBuilder; - - int64 handle_; - XlaBuilder* builder_; // Not owned. -}; - -// A convenient interface for building up computations. -// -// Thread-compatible. -class XlaBuilder { - public: - // computation_name: name to use for the built computation. - XlaBuilder(const string& computation_name); - - XlaBuilder(const XlaBuilder&) = delete; - XlaBuilder& operator=(const XlaBuilder&) = delete; - - ~XlaBuilder(); - - // Returns the computation name. - const string& name() const { return name_; } - - // Sets OpMetadata that will be added to all instructions until cleared. - // - // OpMetadata is often applied to a series of XLA HLO instructions. As a - // result, OpMetadata is set on the Computation Builder. All subsequent - // instructions generated via this Computation Builder will have the same - // OpMetadata attached until a call to ClearOpMetadata. - void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } - - // Clears the HloMetadata state. - void ClearOpMetadata() { metadata_.Clear(); } - - // Sets an OpSharding that will be attached to all instructions until cleared. - void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } - - // Clears the sharding. Ops will be sharded according to the default placement - // policy. - void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; } - - // Returns the OpSharding that will be attached to all instructions. - const tensorflow::gtl::optional& sharding() const { - return sharding_; - } - - // Sets the builder to a mode where it will die immediately when an error is - // encountered, rather than producing it in a deferred fashion when Build() is - // called (which is the default). - void set_die_immediately_on_error(bool enabled) { - die_immediately_on_error_ = enabled; - } - - // Enqueues a "retrieve parameter value" instruction for a parameter that was - // passed to the computation. - XlaOp Parameter(int64 parameter_number, const Shape& shape, - const string& name); - - // Enqueues a constant with the value of the given literal onto the - // computation. - XlaOp ConstantLiteral(const LiteralSlice& literal); - - // Enqueues a constant onto the computation. Methods are templated on the - // native host type (NativeT) which corresponds to a specific XLA - // PrimitiveType as given in the following table: - // - // Native Type PrimitiveType - // ----------------------------- - // bool PRED - // int32 S32 - // int64 S64 - // uint32 U32 - // uint64 U64 - // float F32 - // double F64 - // - // Note: not all primitive types defined in xla_data.proto have a - // corresponding native type yet. - template - XlaOp ConstantR0(NativeT value); - template - XlaOp ConstantR1(tensorflow::gtl::ArraySlice values); - XlaOp ConstantR1(const tensorflow::core::Bitmap& values); - template - XlaOp ConstantR2( - std::initializer_list> values); - template - XlaOp ConstantFromArrayWithLayout(const Array& values, - const Layout& layout); - template - XlaOp ConstantFromArray(const Array& values); - template - XlaOp ConstantR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout); - template - XlaOp ConstantR2FromArray2D(const Array2D& values); - template - XlaOp ConstantR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout); - template - XlaOp ConstantR3FromArray3D(const Array3D& values); - template - XlaOp ConstantR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout); - template - XlaOp ConstantR4FromArray4D(const Array4D& values); - - // Enqueues a rank one constant (vector) onto the computation. The vector has - // size 'length' and every element has the value 'value'. - template - XlaOp ConstantR1(int64 length, NativeT value); - - // Adds dimensions to an array by duplicating the data in the array. - // - // The new dimensions are inserted on the left, i.e. if - // broadcast_sizes has values {a0, ..., aN} and the operand shape - // has dimensions {b0, ..., bM} then the shape of the output has - // dimensions {a0, ..., aN, b0, ..., bM}. - // - // The new dimensions index into copies of the operand, i.e. - // - // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] - XlaOp Broadcast(const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); - - // Enqueues a pad operation onto the computation that pads the given value on - // the edges as well as between the elements of the input. padding_config - // specifies the padding amount for each dimension. - XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, - const PaddingConfig& padding_config); - - // Enqueues an operation onto the computation that flattens the operand based - // on the dimension order (major/slowest-varying to minor/fastest-varying) - // given, followed by reshaping it into the shape with the given dimension - // sizes (also major to minor). Conceptually, this is a limited form of - // "shape casting". - XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); - - // Enqueues an operation onto the computation that collapses the operand, from - // first to last dimension (C order), then reshapes it to the given dimension - // sizes. Conceptually, this is a limited form of "shape casting". - XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes); - - // Wrapper for Reshape. - // Enqueues an operation to collapse the provided dimensions; e.g. an - // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to - // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must - // be a consecutive, in-order subsequence of the operand dimensions. - // - // Note that collapsing a single dimension does nothing: - // - // {256} collapsing {0} => {256} - // {1} collapsing {0} => {1} - // - // Collapsing multiple dimensions produces a single result dimension: - // - // {256, 2} collapsing {0,1} => {512} - // {256, 2, 3} collapsing {0,1} => {512, 3} - // - // This could potentially cause data to be moved -- it provides a more - // structured form of reshaping than an arbitrary Reshape operation. - XlaOp Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); - - // Enqueues a slice operation onto the computation that slices the operand - // from the start indices to the limit indices; e.g. - // - // x - // [ 0 1 2 3 ] - // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] - // [ 8 9 a b ] - // - // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D - // range notation. - // The strides parameter determines the stride over the slice - XlaOp Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); - - // Enqueues a slice operation in a given dimension, taking all other - // dimensions as they are; e.g. if dimno is 1 from start_index 2 to - // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand - // for: - // - // array[:, 2:4:1, :] - XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, - int64 stride, int64 dimno); - - // Enqueues a slice operation onto the computation that slices the 'operand' - // from dynamic start indices which are passed in 'start_indices'. - // The size of the slice in each dimension is passed in 'slice_sizes', - // which specify the end point of exclusive slice intervals in each - // dimension [start, start + size). - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. - // Slice index calculations are computed modulo input dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. - XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); - - // Enqueues a dynamic update slice operation onto the computation, which - // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. - // The shape of 'update' determines the shape of the slice of 'operand' - // which is updated. - // The indices specified in 'start_indices' specify the offset of the slice - // of 'operand' which is updated. - // - // update = {10, 11} // calculated at runtime. - // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] - // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] - // [7 8 9] [7 8 9 ] - // - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. - // Slice index calculations are computed modulo update dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. - XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, - const XlaOp& start_indices); - - // Enqueues a concatenate instruction onto the computation. 'operands' must - // have >= 1 entry. - XlaOp ConcatInDim(tensorflow::gtl::ArraySlice operands, - int64 dimension); - - // Enqueue a tracing operation onto the computation; the computation will emit - // a logging message with the operand. - void Trace(const string& tag, const XlaOp& operand); - - // Enqueues a conditional-move-like select operation onto the computation; - // predicated on pred, selects between on_true and on_false. - XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); - - // Enqueues a tuple-creation instruction onto the computation. - XlaOp Tuple(tensorflow::gtl::ArraySlice elements); - - // Enqueues a tuple-element-get instruction onto the computation. - XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); - - // Enqueues an equal-to comparison instruction onto the computation. - XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a not-equal comparison instruction onto the computation. - XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a greater-or-equal comparison instruction onto the computation. - XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a greater-than comparison instruction onto the computation. - XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a less-than comparison instruction onto the computation. - XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a less-or-equal comparison instruction onto the computation. - XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a dot instruction onto the computation. - XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); - - // Enqueues a general dot instruction onto the computation. - XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers); - - // Default dimension numbers used for a 2D convolution. - static constexpr int64 kConvBatchDimension = 0; - static constexpr int64 kConvFeatureDimension = 1; - static constexpr int64 kConvFirstSpatialDimension = 2; - static constexpr int64 kConvSecondSpatialDimension = 3; - static constexpr int64 kConvKernelOutputDimension = 0; - static constexpr int64 kConvKernelInputDimension = 1; - static constexpr int64 kConvKernelFirstSpatialDimension = 2; - static constexpr int64 kConvKernelSecondSpatialDimension = 3; - - // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for - // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for - // the kernel operand - // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. - static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( - int num_spatial_dims = 2); - - // Returns an error if the convolution dimension numbers have conflicts. - static Status Validate(const ConvolutionDimensionNumbers& dnum); - - // Enqueues a convolution instruction onto the computation, which uses the - // default convolution dimension numbers. - XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration in the format returned by MakePadding(). - XlaOp ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided dimension numbers configuration. - XlaOp ConvWithGeneralDimensions( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration as well as the dimension numbers. - XlaOp ConvGeneral( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration, dilation factors and dimension numbers. - XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); - - // Enqueues an FFT instruction onto the computation, of the given type and - // with the given FFT length. - XlaOp Fft(const XlaOp& operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); - - // Enqueues an infeed instruction onto the computation, which writes data of - // the given shape to the infeed buffer of the device. - XlaOp Infeed(const Shape& shape, const string& config = ""); - - // Enqueues an outfeed instruction onto the computation. This instruction - // generates outgoing data transfers for the given data. - // - // shape_with_layout communicates the laid out shape that we want to outfeed - // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error - // will occur. - void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, - const string& outfeed_config); - - // Enqueues a call instruction onto the computation. - XlaOp Call(const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands); - - // Enqueues a custom call instruction onto the computation. - // During code generation, a call instruction is emitted which targets a - // symbol with the name |call_target_name|. The |operands| are passed to the - // call instruction. |shape| is the resultant shape. - XlaOp CustomCall(const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape); - - // Enqueues a pseudo-op to represent host-side computation data-dependencies. - // During code generation, host send and receive operations will be generated - // to transfer |operands| to the host and a single result of |shape| back to - // the device. Host send/recv operations are emitted using |channel_name|. - // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO - // instruction scheduling. - XlaOp HostCompute(tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, - const Shape& shape); - - // The following methods enqueue element-wise binary arithmetic operations - // onto the computation. The shapes of the operands have to match unless one - // of the operands is a scalar, or an explicit broadcast dimension is given - // (see g3doc for more details). - - // Enqueues a complex compose instruction onto the computation. - XlaOp Complex(const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a complex conjugate instruction onto the computation. - XlaOp Conj(const XlaOp& operand); - - // Enqueues an add instruction onto the computation. - XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a subtract instruction onto the computation. - XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a multiply instruction onto the computation. - XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a divide instruction onto the computation. - XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a remainder instruction onto the computation. - XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a max instruction onto the computation. - XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a min instruction onto the computation. - XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Element-wise logical operators - XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - XlaOp Not(const XlaOp& operand); - - XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - XlaOp ShiftRightArithmetic( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - XlaOp ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Reduces an array among the provided dimensions, given "computation" as a - // reduction operator. - XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); - - // Convenience wrapper around the above that reduces all the dimensions in the - // operand shape. - XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation); - - // Enqueues a windowed reduce instruction onto the computation. - XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); - - // As ReduceWindow(), but the padding is given in the format - // returned by MakePadding(). - XlaOp ReduceWindowWithGeneralPadding( - const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); - - // Returns the sum of the operand value within each subgroup of replicas. All - // replicas supply one input to the sum and all replicas receive the resulting - // sum for each subgroup. - XlaOp CrossReplicaSum( - const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_group_ids = {}); - - // Enqueues an operation that do an AllReduce of the operand cross cores. Here - // AllReduce means doing a reduction on the input operand cross cores and then - // broadcasting the reduction result to those cores. The reduction function is - // defined by `computation`, which should be a commutative computation on - // scalars, e.g., add, min, or max. The way that AllReduce is applied is - // configured by: - // - // - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all - // replicas belong to one group. Allreduce will be applied within subgroups. - // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, - // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. - // - // - `channel_id`: for Allreduce nodes from different models, if they have the - // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be - // applied cross models. - // - // TODO(b/79737069): Rename this to AllReduce when it's ready to use. - XlaOp CrossReplicaSum( - const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_group_ids = {}, - const tensorflow::gtl::optional& channel_id = - tensorflow::gtl::nullopt); - - // Enqueues an operation that scatters the `source` array to the selected - // indices of each window. - XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding, const XlaOp& source, - const XlaOp& init_value, - const XlaComputation& scatter); - - // As SelectAndScatter(), but the padding is given in the format - // returned by MakePadding(). - XlaOp SelectAndScatterWithGeneralPadding( - const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter); - - // Enqueues an abs instruction onto the computation. - XlaOp Abs(const XlaOp& operand); - - // Enqueues a atan2 instruction onto the computation. - XlaOp Atan2(const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues an exp instruction onto the computation. - XlaOp Exp(const XlaOp& operand); - - // Enqueues an expm1 instruction onto the computation. - XlaOp Expm1(const XlaOp& operand); - - // Enqueues a floor instruction onto the computation. - XlaOp Floor(const XlaOp& operand); - - // Enqueues a ceil instruction onto the computation. - XlaOp Ceil(const XlaOp& operand); - - // Enqueues a round instruction onto the computation, rounding to nearest even - // with half-way cases rounding away from zero. - XlaOp Round(const XlaOp& operand); - - // Enqueues an log instruction (natural logarithm) onto the computation. - XlaOp Log(const XlaOp& operand); - - // Enqueues an log1p instruction (log(x+1)) onto the computation. - XlaOp Log1p(const XlaOp& operand); - - // Enqueues a sign instruction onto the computation. - XlaOp Sign(const XlaOp& operand); - - // Enqueues a count leading zeros instruction onto the computation. - XlaOp Clz(const XlaOp& operand); - - // Enqueues a cosine instruction onto the computation. - XlaOp Cos(const XlaOp& operand); - - // Enqueues a sine instruction onto the computation. - XlaOp Sin(const XlaOp& operand); - - // Enqueues a tanh instruction onto the computation. - XlaOp Tanh(const XlaOp& operand); - - // Enqueues a real-part instruction onto the computation. - XlaOp Real(const XlaOp& operand); - - // Enqueues an imaginary-part instruction onto the computation. - XlaOp Imag(const XlaOp& operand); - - // Enqueues a float32 sqrt instruction onto the computation. - // (float32 is specified as there is an implicit float32 0.5f constant - // exponent). - XlaOp SqrtF32(const XlaOp& operand); - - // Enqueues a float32 square instruction onto the computation. - // (float32 is specified as there is an implicit float32 2.0f constant - // exponent). - XlaOp SquareF32(const XlaOp& operand); - - // Enqueues a lhs^rhs computation onto the computation. - XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues an operator that tests if the operand's values are finite, i.e., - // not Inf or NaN. Defined only for floating-point types. Returns an array of - // booleans with the same shape where entries are true iff the corresponding - // entry was NaN. - XlaOp IsFinite(const XlaOp& operand); - - // Enqueues a convert instruction onto the computation that changes the - // element type of the operand array to primitive_type. - XlaOp ConvertElementType(const XlaOp& operand, - PrimitiveType new_element_type); - - // Enqueues a no-op instruction onto the computation that changes - // the element type of the operand array to primitive_type. The - // bit-widths of the source and destination element types must be - // identical. - XlaOp BitcastConvertType(const XlaOp& operand, - PrimitiveType new_element_type); - - // Enqueues a float32 reciprocal instruction onto the computation. - // (float32 is specified as there is an implicit float32 -1.0f constant - // exponent). - // - // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the - // shape of the operand. - XlaOp ReciprocalF32(const XlaOp& operand); - - // Enqueues a negate instruction onto the computation. - XlaOp Neg(const XlaOp& operand); - - // Enqueues a transpose instruction onto the computation. - XlaOp Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation); - - // Enqueues a reverse instruction onto the computation. The order of the - // elements in the given dimensions is reversed (i.e., the element at index i - // is moved to index dimension_size - 1 - i). - XlaOp Rev(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); - - // Enqueues a sort (as increasing order) instruction onto the computation. - XlaOp Sort(const XlaOp& operand); - - // Enqueues a clamp instruction onto the computation. - XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); - - // Enqueues a map instruction onto the computation. - XlaOp Map(tensorflow::gtl::ArraySlice operands, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands = {}); - - // Enqueues a N(mu, sigma) random number generation instruction onto the - // computation. - XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); - - // Enqueues a U(a, b) random number generation instruction onto the - // computation. Returns values in the semi-open interval [a, b). - XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); - - // Enqueues a while node onto the computation. - XlaOp While(const XlaComputation& condition, const XlaComputation& body, - const XlaOp& init); - - // Enqueues a conditional node onto the computation. - XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, - const XlaComputation& true_computation, - const XlaOp& false_operand, - const XlaComputation& false_computation); - - // Enqueues a ReducePrecision node onto the computation. - XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, - const int mantissa_bits); - - // Enqueues a Gather node onto the computation. - XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, - const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds); - - // Enqueues a Send node onto the computation, to send the given operand to - // a Recv instruction that shares the same channel handle. - void Send(const XlaOp& operand, const ChannelHandle& handle); - - // Enqueues a Recv node onto the computation. The data comes from a Send - // instruction that shares the same channel handle and its shape must - // be the same as the given shape. - XlaOp Recv(const Shape& shape, const ChannelHandle& handle); - - // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on any parameters, or on stateful operators such - // as `RngNormal` or `Infeed`. - // - // This tests whether a computation is a compile-time constant without - // evaluating the computation. - StatusOr IsConstant(const XlaOp& operand) const; - - // Normalizes operand across spatial and batch dimensions for each feature. - // - // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` - // is the normalized result and batch_mean and batch_var are the mean and - // variance, respectively, across batch for the operand. - XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, - const XlaOp& offset, float epsilon, - int64 feature_index); - - // Normalizes operand across spatial and batch dimensions for each feature. - // - // `BatchNormInference` is equivalent to calling `BatchNormTraining` without - // computing `mean` and `variance` for each batch inside the operation. It - // uses the input `mean` and `variance` instead as estimated values. The - // purpose of this op is to reduce latency in inference, hence the name - // `BatchNormInference`. - // - // The output has the same shape as `operand`, and contains the normalized - // values for each batch. - XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, - const XlaOp& offset, const XlaOp& mean, - const XlaOp& variance, float epsilon, - int64 feature_index); - - // Calculates the gradients of a batch norm op. - // - // The inputs `batch_mean` and `batch_var` represent the mean and variance - // across the batch. - // - // Returns a tuple of three elements: - // - grad_operand: Gradient with respect to input `operand` - // - grad_offset: Gradient with respect to input `offset` - // - grad_scale: Gradient with respect to input `scale` - XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, - const XlaOp& batch_mean, const XlaOp& batch_var, - const XlaOp& grad_output, float epsilon, - int64 feature_index); - - // Returns a new XlaBuilder whose resultant Computation is used only by this - // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error - // behavior as the parent. - std::unique_ptr CreateSubBuilder(const string& computation_name); - - // Builds the computation with the requested operations, or returns a non-ok - // status. Note that all ops that have been enqueued will be moved to the - // computation being returned. - StatusOr Build(); - - // Builds the computation with the requested operations, or notes an error in - // the parent XlaBuilder and returns an empty computation if building failed. - // This function is intended to be used where the returned XlaComputation is - // only used by the parent XlaBuilder and hence further operation on the - // returned XlaComputation will simply be error'ed out if an error occurred - // while building this computation. If the built computation is to be used by - // a XlaBuilder other than the parent XlaBuilder then Build() should be used - // instead. - XlaComputation BuildAndNoteError(); - - // Returns a subgraph that roots on the given root. If the root is not a - // compile-time constant (see `IsConstant`), returns an error. - // - // This will copy the needed ops/computations to the subgraph. - StatusOr BuildConstantSubGraph(const XlaOp& root_op) const; - - // Returns the first error that was encountered while building the - // computation. When an error is encountered, by default we return a vacuous - // XlaOp and inform the user of the error that occurred while - // building the computation when they make a final call to Build(). - // - // See also set_die_immediately_on_error(). - Status first_error() const { return first_error_; } - - // Returns the shape of the given op. - StatusOr GetShape(const XlaOp& op) const; - - // Returns the (inferred) result for the current computation's shape. - StatusOr GetProgramShape() const; - - private: - StatusOr AddInstruction( - HloInstructionProto&& instr, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands = {}); - - void AddCalledComputation(const XlaComputation& computation, - HloInstructionProto* instr); - - // Notes that the error occurred by: - // * storing it internally and capturing a backtrace if it's the first error - // (this deferred value will be produced on the call to Build()) - // * dying if die_immediately_on_error_ is true - void NoteError(const Status& error); - - XlaOp NoteErrorOrReturn(const std::function()>& op_creator); - - // Helper method that creates an empty op and notes error. - XlaOp UnimplementedOp(); - - StatusOr LookUpInstruction(const XlaOp& op) const; - - // Internal helper method that does the building for an arbitrary unary op. - XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand); - - // Internal helper method that does the building for an arbitrary binary op. - // broadcast_dimensions specifies which dimensions to use for broadcasting - // when the operation is between tensors of different ranks. - XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); - - // Internal helper method that does the building for an arbitrary ternary op. - XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, - const XlaOp& ehs); - - XlaOp RngOp(RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters, - const Shape& shape); - - StatusOr InDimBroadcast( - const Shape& shape, const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_dimensions); - - // Internal helper method that creates a sequence of instructions that - // performs an explicit broadcast of the operand to the target shape. - StatusOr AddBroadcastSequence(const Shape& output_shape, - const XlaOp& operand); - - // Internal helper method for creating a Reshape op with the already inferred - // shape. - StatusOr Reshape(const Shape& shape, const XlaOp& operand); - - // Returns the (inferred) result for the program shape for the current - // computation and fills the root_id in the pointer. - StatusOr GetProgramShape(int64* root_id) const; - - // Returns shapes for the operands. - StatusOr> GetOperandShapes( - tensorflow::gtl::ArraySlice operands) const; - - // A visitor which checks whether an operation is a compile-time constant, - // meaning that it doesn't depend on any parameters, or on any stateful - // operation such as `RngNormal` or `Infeed`. The visitor walks the - // computation starting at a given operation and sets is_constant to false iff - // a parameter or stateful operation is encountered. - void IsConstantVisitor(const int64 op_handle, std::set* visited, - bool* is_constant) const; - - // Checks bounds for convolution parameters. - Status VerifyConvolution( - const Shape& lhs_shape, const Shape& rhs_shape, - const ConvolutionDimensionNumbers& dimension_numbers) const; - - // Helper function for creating a Window proto from user-supplied data. - // Returns error if the user-supplied data was invalid. - StatusOr MakeWindow( - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation) const; - - string name_; // Name to use for the built computation. - - // The first error encountered while building the computation. - // This is OK until the first error is encountered. - Status first_error_; - - // The saved stack trace from the point at which the first error occurred. - tensorflow::SavedStackTrace first_error_backtrace_; - - // The instructions of this computation. - std::vector instructions_; - - // The embedded computations used by this computation. Each computation was - // the entry computation of some XlaComputation, the key is the unique id of - // that XlaComputation. - std::map embedded_; - - // The unique parameter numbers. - tensorflow::gtl::FlatSet parameter_numbers_; - - // The metadata to attach to each op. This is structured as a "modal"-like - // operation, in order to simplify client code (and not sprinkle this metadata - // throughout the TensorFlow op kernel implementations). - OpMetadata metadata_; - - // Sharding for this operator. This is structured as a "model"-like operation, - // in order to simplify client code, similar to metadata_. - tensorflow::gtl::optional sharding_; - - // Mode bit that indicates whether to die when a first error is encountered. - bool die_immediately_on_error_ = false; - - XlaBuilder* parent_builder_{nullptr}; -}; - -template -XlaOp XlaBuilder::ConstantR0(NativeT value) { - return ConstantLiteral(*Literal::CreateR0(value)); -} - -template -XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice values) { - return ConstantLiteral(*Literal::CreateR1(values)); -} - -template -XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { - Literal literal(ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), {length})); - literal.PopulateWithValue(value); - return ConstantLiteral(literal); -} - -inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { - return ConstantLiteral(*Literal::CreateR1(values)); -} - -template -XlaOp XlaBuilder::ConstantR2( - std::initializer_list> values) { - return ConstantLiteral(*Literal::CreateR2(values)); -} - -template -XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array& values, - const Layout& layout) { - return ConstantLiteral( - *Literal::CreateFromArrayWithLayout(values, layout)); -} - -template -XlaOp XlaBuilder::ConstantFromArray(const Array& values) { - return ConstantLiteral(*Literal::CreateFromArray(values)); -} - -template -XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { - return ConstantLiteral( - *Literal::CreateFromArrayWithLayout(values, layout)); -} - -template -XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D& values) { - return ConstantLiteral(*Literal::CreateR2FromArray2D(values)); -} - -template -XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout) { - return ConstantLiteral( - *Literal::CreateR3FromArray3DWithLayout(values, layout)); -} - -template -XlaOp XlaBuilder::ConstantR3FromArray3D(const Array3D& values) { - return ConstantFromArray(values); -} - -template -XlaOp XlaBuilder::ConstantR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout) { - return ConstantFromArrayWithLayout(values, layout); -} - -template -XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { - return ConstantFromArray(values); -} - -// RAII-style object: sets the current sharding assignment in builder on -// construction, and sets back to the previous assignment on destruction. -class XlaScopedShardingAssignment { - public: - XlaScopedShardingAssignment(xla::XlaBuilder* builder, - tensorflow::gtl::optional sharding) - : builder_(builder), prev_sharding_(builder->sharding()) { - SetSharding(sharding); - } - - XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete; - XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) = - delete; - - ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); } - - private: - void SetSharding(const tensorflow::gtl::optional& sharding) { - if (sharding.has_value()) { - builder_->SetSharding(sharding.value()); - } else { - builder_->ClearSharding(); - } - } - - xla::XlaBuilder* const builder_; - tensorflow::gtl::optional prev_sharding_; -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc deleted file mode 100644 index 2df3ea3af0d4fcfb9bc803feebd96f09042ab1f3..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc +++ /dev/null @@ -1,239 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" - -#include - -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { - -namespace { - -namespace op = xla::testing::opcode_matchers; - -using ::testing::HasSubstr; - -// TODO(b/74197823): Move the tests to service/. -class XlaBuilderTest : public ::testing::Test { - protected: - StatusOr> BuildHloModule(XlaBuilder* b) { - TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build()); - const HloModuleProto& proto = computation.proto(); - TF_ASSIGN_OR_RETURN(const auto& config, - HloModule::CreateModuleConfigFromProto( - proto, legacy_flags::GetDebugOptionsFromFlags())); - return HloModule::CreateFromProto(proto, config); - } - - // Returns the name of the test currently being run. - string TestName() const { - return ::testing::UnitTest::GetInstance()->current_test_info()->name(); - } -}; - -TEST_F(XlaBuilderTest, OnePlusTwo) { - XlaBuilder b(TestName()); - b.Add(b.ConstantR0(1.0), b.ConstantR0(2.0)); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); -} - -TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) { - XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); - b.Add(x, b.ConstantR0(1.0)); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Parameter(), op::Broadcast(op::Constant()))); -} - -TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { - XlaBuilder b(TestName()); - const auto& x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6}); - const auto& y_shape = ShapeUtil::MakeShape(S32, {2, 4}); - auto x = b.Parameter(0, x_shape, "x"); - auto y = b.Parameter(1, y_shape, "y"); - auto add = b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); - - TF_ASSERT_OK_AND_ASSIGN(auto add_shape, b.GetShape(add)); - EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape)); - - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Parameter(0), op::Broadcast(op::Parameter(1)))); -} - -TEST_F(XlaBuilderTest, XPlusX) { - XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(S32, {1, 3, 5, 7}), "x"); - b.Add(x, x); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(0))); -} - -TEST_F(XlaBuilderTest, ShapeInferenceError) { - XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x"); - auto y = b.Parameter(1, ShapeUtil::MakeShape(U32, {2, 4}), "y"); - b.Add(x, y); - auto statusor = BuildHloModule(&b); - ASSERT_FALSE(statusor.ok()); - EXPECT_THAT(statusor.status().error_message(), HasSubstr("shape inference")); -} - -TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) { - XlaBuilder b_call("add"); - b_call.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); - - XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); - auto y = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "y"); - b.Add(x, y); - auto statusor = BuildHloModule(&b); - ASSERT_FALSE(statusor.ok()); - EXPECT_THAT(statusor.status().error_message(), - HasSubstr("parameter 0 already registered")); -} - -TEST_F(XlaBuilderTest, Call) { - XlaBuilder b_call("the_only_to_apply"); - auto p0 = b_call.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); - auto p1 = b_call.Parameter(1, ShapeUtil::MakeShape(F32, {}), "p1"); - b_call.Add(p0, p1); - TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build()); - XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - auto one = b.ConstantR0(1); - auto two = b.ConstantR0(2); - b.Add(b.Call(call, {x, y}), b.Call(call, {one, two})); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Call(op::Parameter(), op::Parameter()), - op::Call(op::Constant(), op::Constant()))); -} - -TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) { - XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x"); - auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y"); - b.Add(x, y); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - - // Expected: - // - // x: f32[1,2,3] y: f32[1,2,1] - // | | - // | reshape: f32[1,2] - // | | - // | broadcast: f32[1,2,3] - // \ / - // add - auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Parameter(0), - op::Broadcast(op::Reshape(op::Parameter(1))))); -} - -TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { - XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3}), "x"); - auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y"); - b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - - // The binary operation has in-dim broadcast and degenerate broadcast, should - // first do the in-dim broadcast then convert the degnerate broadcast into a - // reshape and a broadcast. - // - // Expected: - // - // x: f32[2,3] y: f32[2,1,4] - // | | - // broadcast: f32[2,3,4] reshape: f32[2,4] - // | | - // | broadcast: f32[2,3,4] - // \ / - // add - auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Add(op::Broadcast(op::Parameter(0)), - op::Broadcast(op::Reshape(op::Parameter(1))))); -} - -TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { - XlaBuilder b1("b1"); - auto p0 = b1.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); - XlaBuilder builder("main"); - builder.Add(p0, p0); - auto statusor = builder.Build(); - ASSERT_FALSE(statusor.ok()); - EXPECT_THAT( - statusor.status().error_message(), - HasSubstr( - "built by builder 'b1', but is trying to use it in builder 'main'")); -} - -TEST_F(XlaBuilderTest, ReshapeDefaultOrder) { - XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); - b.Reshape(x, /*new_sizes=*/{6, 35}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Reshape(op::Parameter())); -} - -TEST_F(XlaBuilderTest, ReshapeHasTranspose) { - XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); - b.Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Reshape(op::Transpose(op::Parameter()))); -} - -TEST_F(XlaBuilderTest, Transpose) { - XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); - b.Transpose(x, /*permutation=*/{1, 0}); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Transpose(op::Parameter())); -} - -// TODO(b/65209188): Create a dedicated lowering for Xor. -TEST_F(XlaBuilderTest, Xor) { - XlaBuilder b(TestName()); - auto x = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); - auto y = b.Parameter(1, ShapeUtil::MakeShape(PRED, {}), "y"); - b.Xor(x, y); - TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - auto root = module->entry_computation()->root_instruction(); - LOG(ERROR) << module->ToString(); - EXPECT_THAT(root, - op::Or(op::And(op::Not(op::Parameter(0)), op::Parameter(1)), - op::And(op::Parameter(0), op::Not(op::Parameter(1))))); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_computation.cc similarity index 94% rename from tensorflow/compiler/xla/client/xla_client/xla_computation.cc rename to tensorflow/compiler/xla/client/xla_computation.cc index 72e3935696e0c44ae3893fc8f1ceb261fa5e2646..3543d41fc2656ec028646edebc0bf5b6af7f67a5 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_computation.cc +++ b/tensorflow/compiler/xla/client/xla_computation.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.h b/tensorflow/compiler/xla/client/xla_computation.h similarity index 90% rename from tensorflow/compiler/xla/client/xla_client/xla_computation.h rename to tensorflow/compiler/xla/client/xla_computation.h index 0ffba208b1f8683fe1d26107cbfd096b856267f1..71598ef8b296a760b0ee818fce0a59aed5cfc6b4 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_computation.h +++ b/tensorflow/compiler/xla/client/xla_computation.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_COMPUTATION_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_COMPUTATION_H_ #include @@ -64,4 +64,4 @@ class XlaComputation { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/BUILD b/tensorflow/compiler/xla/experimental/xla_sharding/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a26b20c861846501c911253d89619591c37322b3 --- /dev/null +++ b/tensorflow/compiler/xla/experimental/xla_sharding/BUILD @@ -0,0 +1,18 @@ +# Description: +# Python API for shardings in XLA. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +py_library( + name = "xla_sharding", + srcs = ["xla_sharding.py"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto_py", + "//tensorflow/compiler/xla/python_api:types", + "//tensorflow/compiler/xla/python_api:xla_shape", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py new file mode 100644 index 0000000000000000000000000000000000000000..fb135f5ceda67ce6c001de15b8f3f084ca164826 --- /dev/null +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -0,0 +1,204 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ====================================== +"""Experimental support for defining XLA shardings.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import numpy as _np # Avoids becoming a part of public Tensorflow API. + +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.compiler.xla.python_api import xla_shape +from tensorflow.core.framework import attr_value_pb2 + + +class Sharding(object): + """A class to support adding sharding attributes to Ops. + + Use the factory constructors and then call apply_to_tensor: + Sharding.replicate().apply_to_tensor(tensor) + """ + + def __init__(self, proto=None): + """Do not use this constructor; use the factory functions below.""" + self._proto = proto + + @classmethod + def replicate(cls): + """Returns a replicated sharding attribute. + + This causes an op to be computed in its entirety independently on all + cores in the XLA device. + """ + return Sharding( + proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)) + + @classmethod + def assign_device(cls, core): + """Returns an AssignDevice sharding attribute. + + This causes an op to be computed in its entirety only on one core in + the XLA device. + Args: + core: The core to assign this Op to. + """ + return Sharding( + proto=xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.MAXIMAL, + tile_assignment_dimensions=[1], + tile_assignment_devices=[core])) + + @classmethod + def tile(cls, tile_shape, tile_assignment): + """Returns a Tiled sharding attribute. + + This causes an op to be partially computed on multiple cores in the + XLA device. + + Args: + tile_shape: A xla_shape.Shape describing the tile shape that each core + will compute. + The tile shape does not need to be divisible by the tile assignment. + tile_assignment: An np.ndarray describing the topology of the tiling and + which device will compute which part of the topology. + + Raises: + TypeError: tile_assignment was not of np.array type or tile_shape was + not of xla_shape.Shape type. + + TODO(jmolloy): This concept is nefarious and is not + something we really want to expose to users (especially as the + contract for tile_assignment is very strict). + """ + if not isinstance(tile_assignment, _np.ndarray): + raise TypeError('Tile assignment must be of type np.ndarray') + if not isinstance(tile_shape, xla_shape.Shape): + raise TypeError('Tile shape must be of type xla_shape.Shape') + dims = list(tile_assignment.shape) + flattened_devices = tile_assignment.reshape(-1, order='C') + return Sharding( + proto=xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.OTHER, + tile_shape=tile_shape.message, + tile_assignment_dimensions=dims, + tile_assignment_devices=list(flattened_devices))) + + @classmethod + def split(cls, tensor, split_dimension, num_devices): + """Returns a Sharding that splits a tensor across a dimension. + + This creates a Tiled attribute, similar to tile(), but easier to use for the + common case of tiling a tensor N ways in one dimension. + + Args: + tensor: A tf.Tensor to split. + split_dimension: The dimension number to split. + num_devices: The number of cores to split `tensor` over. + + Raises: + ValueError: The tensor to split was smaller in the split dimension than + the number of devices to split over. + """ + tensor.shape.assert_is_fully_defined() + shape = tensor.shape.as_list() + if shape[split_dimension] < num_devices: + raise ValueError('Split dimension was smaller than the required number ' + 'of splits: shape=%r, dimension=%r, num_devices=%r', + shape, split_dimension, num_devices) + + tile_shape = shape + tile_shape[split_dimension] = int( + math.ceil(tile_shape[split_dimension] / num_devices)) + tile_shape_proto = xla_data_pb2.Shape( + element_type=xla_data_pb2.F32, dimensions=tile_shape) + + tile_assignment_dims = [1] * len(shape) + tile_assignment_dims[split_dimension] = num_devices + + return Sharding( + proto=xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.OTHER, + tile_shape=tile_shape_proto, + tile_assignment_dimensions=tile_assignment_dims, + tile_assignment_devices=range(num_devices))) + + def apply_to_tensor(self, tensor): + """Applies this Sharding attribute to `tensor`.""" + if len(tensor.op.outputs) > 1: + proto = self._get_or_create_tuple_proto(tensor.op) + # We can't mutate an element of old_proto.tuple_shardings, so create + # a new proto. + tuple_shardings = list(proto.tuple_shardings) + tuple_shardings[tensor.value_index] = self._proto + proto = xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings) + else: + proto = self._proto + + attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString()) + # TODO(jmolloy): This need to be seriously revisited before declaring this + # API available for public use. + # pylint: disable=protected-access + tensor.op._set_attr('_XlaSharding', attr_value) + + @property + def proto(self): + """Return the sharding protobuf of type xla_data_pb2.OpSharding.""" + return self._proto + + def _get_or_create_tuple_proto(self, op): + try: + attr = op.get_attr('_XlaSharding') + proto = xla_data_pb2.OpSharding() + proto.ParseFromString(attr) + return proto + except ValueError: + return self._create_tuple_proto(op) + + def _create_tuple_proto(self, op): + shardings = [ + xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED) + for _ in op.outputs + ] + return xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=shardings) + + +# Helpers for the above factory functions that allow easy application of +# shardings, for example: +# tensor = xla_sharding.replicate(tensor) + + +def replicate(tensor): + Sharding.replicate().apply_to_tensor(tensor) + return tensor + + +def assign_device(tensor, device): + Sharding.assign_device(device).apply_to_tensor(tensor) + return tensor + + +def tile(tensor, tile_shape, tile_assignment): + Sharding.tile(tile_shape, tile_assignment).apply_to_tensor(tensor) + return tensor + + +def split(tensor, split_dimension, num_devices): + Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(tensor) + return tensor diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index e8f29b83291a7cb238dc25b9f4bb743fe426a162..b72d190d54591384392e79e73e90cf52df04a902 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -190,9 +190,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } if (!ShapeUtil::IsArray(shape)) { - return InvalidArgument( - "shape of primitive type %s should not have a layout", - PrimitiveType_Name(shape.element_type()).c_str()); + if (layout.minor_to_major_size() != 0 || + layout.padded_dimensions_size() != 0) { + return InvalidArgument( + "shape of primitive type %s should not have a non-trivial layout", + PrimitiveType_Name(shape.element_type()).c_str()); + } + return Status::OK(); } if (layout.format() == INVALID_FORMAT) { @@ -244,6 +248,12 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } } + if (layout.format() == SPARSE) { + if (!layout.padded_dimensions().empty()) { + return InvalidArgument("Sparse layout has padded dimensions"); + } + } + return Status::OK(); } @@ -287,7 +297,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { shape.layout().padded_dimensions_size() == 0) { return false; } - CHECK(IsDenseArray(shape)); + CHECK(IsDenseArray(shape)) << shape.ShortDebugString(); CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size()); for (int64 i = 0; i < shape.dimensions_size(); ++i) { if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) { diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index f42fb92359f40ec763866af094972046f6407ae1..1bf8948ef6ded56573d588258c3d9bbfaa55a50d 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -31,7 +31,6 @@ std::vector* flag_objects; std::once_flag flags_init; void SetDebugOptionsDefaults(DebugOptions* flags) { - flags->set_xla_enable_fast_math(true); flags->set_xla_llvm_enable_alias_scope_metadata(true); flags->set_xla_llvm_enable_noalias_metadata(true); flags->set_xla_llvm_enable_invariant_load_metadata(true); @@ -53,6 +52,11 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { // the heuristics needed to decide when to run on multiple streams. See // b/77879207. flags->set_xla_gpu_disable_multi_streaming(true); + + // TODO(jlebar): Disable fastmath once doing so is not a performance + // regression. + flags->set_xla_cpu_enable_fast_math(true); + flags->set_xla_gpu_enable_fast_math(true); } // Allocates flag_values and flag_objects; this function must not be called more @@ -150,10 +154,16 @@ void AllocateFlags() { flag_values->mutable_xla_generate_hlo_text_to(), "Dump all HLO modules as text into the provided directory path."), tensorflow::Flag( - "xla_enable_fast_math", - bool_setter_for(&DebugOptions::set_xla_enable_fast_math), - flag_values->xla_enable_fast_math(), - "Enable unsafe fast-math optimizations in the compiler; " + "xla_cpu_enable_fast_math", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), + flag_values->xla_cpu_enable_fast_math(), + "Enable unsafe fast-math optimizations in the CPU compiler; " + "this may produce faster code at the expense of some accuracy."), + tensorflow::Flag( + "xla_gpu_enable_fast_math", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), + flag_values->xla_cpu_enable_fast_math(), + "Enable unsafe fast-math optimizations in the GPU compiler; " "this may produce faster code at the expense of some accuracy."), tensorflow::Flag( "xla_llvm_enable_alias_scope_metadata", diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc new file mode 100644 index 0000000000000000000000000000000000000000..36e472568ecfdb97c828817ed339260ee7878723 --- /dev/null +++ b/tensorflow/compiler/xla/literal.cc @@ -0,0 +1,2090 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/literal.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::strings::Printf; +using tensorflow::strings::StrCat; + +namespace xla { + +namespace { + +constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; + +// Converts between little and big endian. +// +// Precondition: size % 2 == 0 (elements in the array are 16 bits long) +void ConvertEndianShort(string* bytes) { + CHECK_EQ(bytes->size() / 2, 0); + for (int64 i = 0; i < bytes->size(); i += 2) { + std::swap((*bytes)[i], (*bytes)[i + 1]); + } +} + +void ConvertEndianShort(char* bytes, int64 size) { + CHECK_EQ(size / 2, 0); + for (int64 i = 0; i < size; i += 2) { + std::swap(bytes[i], bytes[i + 1]); + } +} + +} // namespace + +LiteralBase::~LiteralBase() {} + +std::ostream& operator<<(std::ostream& out, const Literal& literal) { + out << literal.ToString(); + return out; +} + +MutableLiteralBase::StrideConfig::StrideConfig( + const Shape& source_shape, const Shape& dest_shape, + tensorflow::gtl::ArraySlice dimensions) + : dimensions(dimensions), + base(dimensions.size(), 0), + step(dimensions.size(), 1) { + if (!dimensions.empty()) { + // Selects the shape with the largest minor dimension as the one upon + // which to run the tight stride loop. + if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >= + dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) { + minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0); + dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension); + } else { + minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0); + source_stride = + IndexUtil::GetDimensionStride(source_shape, minor_dimension); + } + minor_loop_size = dimensions[minor_dimension]; + step[minor_dimension] = minor_loop_size; + } +} + +Literal::Literal(const Shape& shape) + : Literal(shape, /*allocate_arrays=*/true) {} + +void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + SetPiece(subshape, &child_piece, allocate_arrays); + + piece->emplace_back(std::move(child_piece)); + } + } else if (ShapeUtil::IsArray(shape)) { + if (allocate_arrays) { + if (LayoutUtil::IsSparseArray(shape)) { + // For sparse arrays, the buffer must be of the size of the maximum + // number of sparse elements possible. + const int64 max_sparse_elements = + LayoutUtil::MaxSparseElements(shape.layout()); + piece->set_buffer( + new char[max_sparse_elements * + ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); + piece->set_sparse_indices( + new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape))); + } else { + piece->set_buffer(new char[piece->size_bytes()]); + } + } + } else { + // If the shape is neither an array nor tuple, then it must be + // zero-sized. Otherwise, some memory needs to be allocated for it. + CHECK_EQ(piece->size_bytes(), 0); + } +} + +Literal::Literal(const Shape& shape, bool allocate_arrays) + : MutableLiteralBase() { + shape_ = MakeUnique(shape); + CHECK(LayoutUtil::HasLayout(*shape_)); + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + CHECK(&root_piece_->subshape() == shape_.get()); + + SetPiece(*shape_, root_piece_, allocate_arrays); +} + +Literal::~Literal() { + if (root_piece_ != nullptr) { + DeallocateBuffers(); + delete root_piece_; + } +} + +void Literal::DeallocateBuffers() { + root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (piece->buffer() != nullptr) { + delete[] piece->buffer(); + delete piece->sparse_indices(); + } + }); +} + +Literal::Literal(Literal&& other) : MutableLiteralBase() { + *this = std::move(other); +} + +Literal& Literal::operator=(Literal&& other) { + DCHECK(&other.root_piece_->subshape() == other.shape_.get()); + using std::swap; + swap(shape_, other.shape_); + swap(root_piece_, other.root_piece_); + DCHECK(&root_piece_->subshape() == shape_.get()); + + return *this; +} + +std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { + auto literal = MakeUnique(shape); + literal->root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (ShapeUtil::IsArray(piece->subshape())) { + memset(piece->untyped_data(), 0, piece->size_bytes()); + } + }); + return literal; +} + +const SparseIndexArray* LiteralBase::sparse_indices( + const ShapeIndex& shape_index) const { + return piece(shape_index).sparse_indices(); +} + +SparseIndexArray* MutableLiteralBase::sparse_indices( + const ShapeIndex& shape_index) { + return piece(shape_index).sparse_indices(); +} + +template +Status MutableLiteralBase::CopySliceFromInternal( + const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); + TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size()); + + auto linear_index = [](const Shape& shape, + tensorflow::gtl::ArraySlice multi_index) { + return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index); + }; + + if (ShapeUtil::Rank(src_literal.shape()) == 0 || + ShapeUtil::Rank(shape()) == 0) { + // If any of the two shapes are scalars, we can just call the StridedCopy() + // directly, and we know we will be copying only one value. + TF_RET_CHECK(copy_size.empty()); + StridedCopy(data(), linear_index(shape(), dest_base), 0, + src_literal.data(), + linear_index(src_literal.shape(), src_base), 0, 1); + } else if (!ShapeUtil::IsZeroElementArray(shape()) && + !ShapeUtil::IsZeroElementArray(src_literal.shape())) { + // Perform copy if neither src nor dest has dimensions with zero element, + // otherwise it's a no-op. + TF_RET_CHECK(src_base.size() == dest_base.size()); + TF_RET_CHECK(src_base.size() == copy_size.size()); + + // Scan the source from minor, stepping in copy size blocks, then within + // the index enumaration functor, do a strided copy advancing source index + // by one (walking through the minor dimension), and destination index by + // proper stride size at the matching dimension. + DimensionVector src_indexes(src_base.size(), 0); + DimensionVector dest_indexes(dest_base.size(), 0); + MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(), + copy_size); + + auto copy_proc = [&](tensorflow::gtl::ArraySlice indexes) { + // Map from multi-dimensional index, to source index. + std::transform(indexes.begin(), indexes.end(), src_base.begin(), + src_indexes.begin(), std::plus()); + // Map from multi-dimensional index, to destination index. + std::transform(indexes.begin(), indexes.end(), dest_base.begin(), + dest_indexes.begin(), std::plus()); + + int64 src_index = linear_index(src_literal.shape(), src_indexes); + int64 dest_index = linear_index(shape(), dest_indexes); + + // `this->` is needed to workaround MSVC bug: #16882 + StridedCopy(this->data(), dest_index, stride_config.dest_stride, + src_literal.data(), src_index, + stride_config.source_stride, stride_config.minor_loop_size); + return true; + }; + + ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base, + stride_config.dimensions, stride_config.step, + copy_proc); + } + return Status::OK(); +} + +Status MutableLiteralBase::CopyElementFrom( + const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_index, + tensorflow::gtl::ArraySlice dest_index) { + DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); + const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex( + src_literal.shape(), src_index); + const int64 dest_linear_index = + IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index); + const int64 primitive_size = + ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); + + char* dest_address = + static_cast(untyped_data()) + dest_linear_index * primitive_size; + const char* source_address = + static_cast(src_literal.untyped_data()) + + src_linear_index * primitive_size; + if (dest_address != source_address) { + memcpy(dest_address, source_address, primitive_size); + } + return Status::OK(); +} + +/* static */ StatusOr> +MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { + if (!proto.has_shape()) { + return InvalidArgument("LiteralProto has no shape"); + } + if (!LayoutUtil::HasLayout(proto.shape())) { + return InvalidArgument("LiteralProto has no layout"); + } + + auto literal = MakeUnique(proto.shape()); + + TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + const LiteralProto* proto_element = &proto; + for (int64 i : index) { + CHECK(i < proto_element->tuple_literals_size()); + proto_element = &proto_element->tuple_literals(i); + } + + if (ShapeUtil::IsTuple(piece->subshape())) { + if (proto_element->tuple_literals_size() != + ShapeUtil::TupleElementCount(piece->subshape())) { + return InvalidArgument( + "Expected %lld tuple elements in LiteralProto, has %d", + ShapeUtil::TupleElementCount(piece->subshape()), + proto_element->tuple_literals_size()); + } + return Status::OK(); + } + if (piece->subshape().element_type() == TOKEN) { + return Status::OK(); + } + + CHECK(ShapeUtil::IsArray(piece->subshape())); + TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); + + return Status::OK(); + })); + + return std::move(literal); +} + +std::vector Literal::DecomposeTuple() { + CHECK(ShapeUtil::IsTuple(shape())); + std::vector elements; + for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { + elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), + /*allocate_arrays=*/false)); + Literal& element = elements.back(); + element.root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* dest_piece) { + ShapeIndex src_index = {i}; + for (int64 j : index) { + src_index.push_back(j); + } + Piece& src_piece = piece(src_index); + + // Move the respective buffer and sparse indices over to the element + // Literal. + dest_piece->set_buffer(src_piece.buffer()); + src_piece.set_buffer(nullptr); + dest_piece->set_sparse_indices(src_piece.sparse_indices()); + src_piece.set_sparse_indices(nullptr); + }); + } + // Set this literal to be nil-shaped. + *this = Literal(); + return elements; +} + +namespace { + +// Copies the elements in 'src' to 'dest'. The shape and layout of the data in +// the array slices are indicated by dest_shape and src_shape respectively. +template +void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, + tensorflow::gtl::ArraySlice src, + const Shape& dest_shape, const Shape& src_shape) { + CHECK(ShapeUtil::Compatible(dest_shape, src_shape)); + if (ShapeUtil::IsZeroElementArray(dest_shape)) { + return; + } + std::vector index(ShapeUtil::Rank(dest_shape)); + do { + dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] = + src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)]; + } while (IndexUtil::BumpIndices(dest_shape, &index)); +} + +} // namespace + +Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { + CHECK(subshape_ != nullptr); + CHECK(src.subshape_ != nullptr); + if (ShapeUtil::Equal(subshape(), src.subshape())) { + // If the layouts are equal it's faster just to memcpy. + memcpy(buffer(), src.buffer(), src.size_bytes()); + } else { + TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape())); + std::vector origin(ShapeUtil::Rank(subshape()), 0); + switch (subshape().element_type()) { +#define COPY_ELEMENTS(XLA_T, NATIVE_T) \ + case (XLA_T): \ + CopyElementsBetween(data(), src.data(), \ + subshape(), src.subshape()); \ + break; + COPY_ELEMENTS(U8, uint8); + COPY_ELEMENTS(U16, uint16); + COPY_ELEMENTS(U32, uint32); + COPY_ELEMENTS(U64, uint64); + COPY_ELEMENTS(S8, int8); + COPY_ELEMENTS(S16, int16); + COPY_ELEMENTS(S32, int32); + COPY_ELEMENTS(S64, int64); + COPY_ELEMENTS(F16, half); + COPY_ELEMENTS(BF16, bfloat16); + COPY_ELEMENTS(F32, float); + COPY_ELEMENTS(F64, double); + COPY_ELEMENTS(C64, complex64); + COPY_ELEMENTS(PRED, bool); +#undef COPY_ELEMENTS + default: + return Unimplemented( + "Copying a Literal object with element type %s is not implemented.", + PrimitiveType_Name(subshape().element_type()).c_str()); + } + } + return Status::OK(); +} + +Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal, + const ShapeIndex& dest_shape_index, + const ShapeIndex& src_shape_index) { + const Shape& dest_subshape = + ShapeUtil::GetSubshape(shape(), dest_shape_index); + const Shape& src_subshape = + ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index); + if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) { + return InvalidArgument( + "Destination subshape incompatible with source subshape: %s vs %s", + ShapeUtil::HumanString(dest_subshape).c_str(), + ShapeUtil::HumanString(src_subshape).c_str()); + } + return root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + if (!ShapeUtil::IsArray(piece->subshape())) { + return Status::OK(); + } + + // Determine if this index is in the part of this literal that we want + // to copy over from src_literal. + bool in_subtree_to_copy = true; + for (int i = 0; i < dest_shape_index.size(); ++i) { + if (index[i] != dest_shape_index[i]) { + in_subtree_to_copy = false; + break; + } + } + if (!in_subtree_to_copy) { + return Status::OK(); + } + // Construct the index of the corresponding piece in the source literal. + ShapeIndex src_piece_index = src_shape_index; + for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { + src_piece_index.push_back(index[i]); + } + TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index))); + return Status::OK(); + }); +} + +Status Literal::MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index) { + const Shape& dest_subshape = + ShapeUtil::GetSubshape(shape(), dest_shape_index); + if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) { + return InvalidArgument( + "Destination subshape not equal to source shape: %s vs %s", + ShapeUtil::HumanString(dest_subshape).c_str(), + ShapeUtil::HumanString(src_literal.shape()).c_str()); + } + + src_literal.root_piece_->ForEachSubpiece( + [&](const ShapeIndex& src_index, const Piece& src_piece) { + if (!ShapeUtil::IsArray(src_piece.subshape())) { + return; + } + + ShapeIndex dest_index = dest_shape_index; + for (int64 i : src_index) { + dest_index.push_back(i); + } + Piece& dest_piece = piece(dest_index); + delete[] dest_piece.buffer(); + dest_piece.set_buffer(src_piece.buffer()); + delete dest_piece.sparse_indices(); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); + }); + + src_literal.shape_ = MakeUnique(ShapeUtil::MakeNil()); + delete src_literal.root_piece_; + src_literal.root_piece_ = new LiteralBase::Piece(); + src_literal.root_piece_->set_subshape(src_literal.shape_.get()); + + return Status::OK(); +} + +Status MutableLiteralBase::CopySliceFrom( + const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); + TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) + << ShapeUtil::HumanString(src_literal.shape()); + TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape())); + + switch (shape().element_type()) { + case U8: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case U16: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case U32: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case U64: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case S8: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case S16: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case S32: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case S64: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case F16: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case BF16: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case F32: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case F64: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case C64: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case PRED: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + default: + break; + } + return Unimplemented( + "Copying a slice from a Literal object with element type %d is not " + "implemented.", + shape().element_type()); +} + +void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(element_count(), values.bits()); + CHECK_EQ(shape().element_type(), PRED); + for (int64 i = 0; i < static_cast(values.bits()); ++i) { + Set({i}, values.get(i)); + } +} + +std::unique_ptr LiteralBase::Relayout( + const Layout& new_layout, const ShapeIndex& shape_index) const { + // Create new shape with 'new_layout' set at the given shape index. + Shape new_shape = shape(); + Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); + TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); + *subshape->mutable_layout() = new_layout; + auto result = MakeUnique(new_shape); + TF_CHECK_OK(result->CopyFrom(*this)); + return result; +} + +std::unique_ptr LiteralBase::Relayout( + const Shape& shape_with_layout) const { + CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) + << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) + << " not compatible with literal shape " + << ShapeUtil::HumanString(shape()); + std::unique_ptr result = CreateFromShape(shape_with_layout); + ShapeUtil::ForEachSubshape( + result->shape(), + [this, &result](const Shape& subshape, const ShapeIndex& index) { + if (ShapeUtil::IsArray(subshape)) { + TF_CHECK_OK(result->CopyFrom(*this, + /*dest_shape_index=*/index, + /*src_shape_index=*/index)); + } + }); + return result; +} + +StatusOr> LiteralBase::Broadcast( + const Shape& result_shape, + tensorflow::gtl::ArraySlice dimensions) const { + if (!ShapeUtil::IsArray(shape())) { + return InvalidArgument("Broadcast only supports arrays."); + } + + for (int64 i = 0; i < dimensions.size(); i++) { + TF_RET_CHECK(shape().dimensions(i) == + result_shape.dimensions(dimensions[i])); + } + + std::unique_ptr result = MakeUnique(result_shape); + + // scratch_source_index is temporary storage space for the computed index into + // the input literal. We put it here to avoid allocating an std::vector in + // every iteration of ShapeUtil::ForEachIndex. + std::vector scratch_source_index(shape().dimensions_size()); + + char* dest_data = static_cast(result->untyped_data()); + const char* source_data = static_cast(untyped_data()); + const int64 primitive_size = + ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); + + ShapeUtil::ForEachIndex( + result_shape, [&](tensorflow::gtl::ArraySlice output_index) { + for (int64 i = 0; i < dimensions.size(); ++i) { + scratch_source_index[i] = output_index[dimensions[i]]; + } + int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex( + result_shape, output_index); + int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex( + shape(), scratch_source_index); + memcpy(dest_data + primitive_size * dest_index, + source_data + primitive_size * source_index, primitive_size); + return true; + }); + + return std::move(result); +} + +StatusOr> LiteralBase::Reshape( + tensorflow::gtl::ArraySlice dimensions) const { + if (!ShapeUtil::IsArray(shape())) { + return InvalidArgument("Reshape does not support tuples."); + } + std::unique_ptr output; + if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { + output = + Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); + } else { + output = CloneToUnique(); + } + // Because the layout is monotonic, we can simply reuse the same sequence of + // values without changing their order. + *output->mutable_shape_do_not_use() = + ShapeUtil::MakeShape(shape().element_type(), dimensions); + + int64 elements_before = ShapeUtil::ElementsIn(shape()); + int64 elements_after = ShapeUtil::ElementsIn(output->shape()); + if (elements_before != elements_after) { + return InvalidArgument( + "Shapes before and after Literal::Reshape have different numbers " + "of elements: %s vs %s.", + ShapeUtil::HumanString(shape()).c_str(), + ShapeUtil::HumanString(output->shape()).c_str()); + } + return std::move(output); +} + +std::unique_ptr LiteralBase::Transpose( + tensorflow::gtl::ArraySlice permutation) const { + CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; + CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) + << "Given permutation is not a permutation of dimension numbers"; + // To transpose the array, we just permute the dimensions and layout, and + // do a straight memory copy of the raw data set. + // This is considerably faster than iterating over every array element using + // the EachCell<>() and Set<>() APIs. + std::vector inverse_permutation = InversePermutation(permutation); + Shape permuted_shape = + ShapeUtil::PermuteDimensions(inverse_permutation, shape()); + // Replace the layout with one affine to this shape, such that a + // transpose operation can be performed by leaving the flat values + // representation intact. + // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation. + // The shape with affine layout resulting from that operation will be + // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the + // most minor. + // + // Essentially, given MinMaj(Di) the position of the Di dimension within the + // minor to major vector, and given T(Di) the index that the original Di + // dimension has within the transposed array, a layout is affine if + // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major + // vector of the affine layout. + CHECK(LayoutUtil::IsDenseArray(permuted_shape)); + Layout* layout = permuted_shape.mutable_layout(); + layout->clear_minor_to_major(); + for (auto index : LayoutUtil::MinorToMajor(shape())) { + layout->add_minor_to_major(inverse_permutation[index]); + } + auto new_literal = MakeUnique(permuted_shape); + DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), + ShapeUtil::ByteSizeOf(shape())); + std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); + return new_literal; +} + +template +std::unique_ptr LiteralBase::SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice start_indices) const { + auto result_literal = MakeUnique(result_shape); + DimensionVector new_indices(ShapeUtil::Rank(result_shape)); + result_literal->EachCell( + [&](tensorflow::gtl::ArraySlice indices, NativeT /*value*/) { + for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + new_indices[i] = indices[i] + start_indices[i]; + } + NativeT value = Get(new_indices); + result_literal->Set(indices, value); + }); + return result_literal; +} + +std::unique_ptr LiteralBase::Slice( + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) const { + CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; + + DimensionVector result_dimensions; + for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { + CHECK_GE(start_indices[dnum], 0); + CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)) + << "dnum = " << dnum; + int64 dimension = limit_indices[dnum] - start_indices[dnum]; + CHECK_GE(dimension, 0) << "dnum = " << dnum; + result_dimensions.push_back(dimension); + } + const auto result_shape = + ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, + LayoutUtil::MinorToMajor(shape())); + switch (result_shape.element_type()) { + case F32: + return SliceInternal(result_shape, start_indices); + case BF16: + return SliceInternal(result_shape, start_indices); + case C64: + return SliceInternal(result_shape, start_indices); + case S32: + return SliceInternal(result_shape, start_indices); + case U32: + return SliceInternal(result_shape, start_indices); + default: + LOG(FATAL) << "not yet implemented: " + << PrimitiveType_Name(result_shape.element_type()); + } +} + +Literal LiteralBase::Clone() const { + Literal result(shape()); + TF_CHECK_OK(result.CopyFrom(*this)); + return result; +} + +std::unique_ptr LiteralBase::CloneToUnique() const { + auto result = MakeUnique(shape()); + TF_CHECK_OK(result->CopyFrom(*this)); + return result; +} + +string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { + const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); + CHECK(LayoutUtil::IsDenseArray(subshape)); + switch (subshape.element_type()) { + case PRED: + return Get(multi_index, shape_index) ? "true" : "false"; + case S8: + return StrCat(Get(multi_index, shape_index)); + case S16: + return StrCat(Get(multi_index, shape_index)); + case S32: + return StrCat(Get(multi_index, shape_index)); + case S64: + return StrCat(Get(multi_index, shape_index)); + case U8: + return StrCat(Get(multi_index, shape_index)); + case U16: + return StrCat(Get(multi_index, shape_index)); + case U32: + return StrCat(Get(multi_index, shape_index)); + case U64: + return StrCat(Get(multi_index, shape_index)); + case F16: + return StrCat(static_cast(Get(multi_index, shape_index))); + case F32: + return StrCat(Get(multi_index, shape_index)); + case BF16: + return StrCat( + static_cast(Get(multi_index, shape_index))); + case F64: + return StrCat(Get(multi_index, shape_index)); + case C64: { + complex64 c = Get(multi_index, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } + default: + LOG(FATAL) << PrimitiveType_Name(subshape.element_type()); + } +} + +string LiteralBase::GetSparseElementAsString( + int64 sparse_element_number, const ShapeIndex& shape_index) const { + const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); + CHECK(LayoutUtil::IsSparseArray(subshape)); + switch (subshape.element_type()) { + case PRED: + return GetSparseElement(sparse_element_number, shape_index) + ? "true" + : "false"; + case S8: + return StrCat(GetSparseElement(sparse_element_number, shape_index)); + case S16: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case S32: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case S64: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U8: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U16: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U32: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U64: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case F16: + return StrCat(static_cast( + GetSparseElement(sparse_element_number, shape_index))); + case F32: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case BF16: + return StrCat(static_cast( + GetSparseElement(sparse_element_number, shape_index))); + case F64: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case C64: { + complex64 c = + GetSparseElement(sparse_element_number, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } + default: + LOG(FATAL) << "Invalid element type for sparse arrays: " + << PrimitiveType_Name(subshape.element_type()); + } +} + +StatusOr LiteralBase::GetIntegralAsS64( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(LayoutUtil::IsDenseArray(shape())); + switch (shape().element_type()) { + case PRED: + return Get(multi_index); + case U8: + return Get(multi_index); + case S32: + return Get(multi_index); + case S64: + return Get(multi_index); + case U32: + return Get(multi_index); + case U64: + return Get(multi_index); + default: + return FailedPrecondition( + "Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type()).c_str()); + } +} + +size_t LiteralBase::Hash() const { + using tensorflow::Hash64; + using tensorflow::Hash64Combine; + + size_t hash_value = ShapeUtil::Hash(shape()); + + ShapeUtil::ForEachSubshape( + shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (!ShapeUtil::IsArray(subshape)) { + return; + } + + CHECK(LayoutUtil::IsDense(subshape.layout())); + hash_value = Hash64Combine( + hash_value, Hash64(static_cast(untyped_data(index)), + size_bytes(index))); + }); + + return hash_value; +} + +Status MutableLiteralBase::SetIntegralAsS64( + tensorflow::gtl::ArraySlice multi_index, int64 value) { + CHECK(LayoutUtil::IsDenseArray(shape())); + switch (shape().element_type()) { + case PRED: + Set(multi_index, value); + break; + case U8: + Set(multi_index, value); + break; + case S32: + Set(multi_index, value); + break; + case S64: + Set(multi_index, value); + break; + case U32: + Set(multi_index, value); + break; + case U64: + Set(multi_index, value); + break; + default: + return FailedPrecondition( + "Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type()).c_str()); + } + return Status::OK(); +} + +tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index) const { + const Piece& p = piece(shape_index); + CHECK_GE(sparse_element_number, 0); + CHECK_LT(sparse_element_number, p.sparse_indices()->index_count()); + return p.sparse_indices()->At(sparse_element_number); +} + +void MutableLiteralBase::SortSparseElements(const ShapeIndex& shape_index) { + piece(shape_index).SortSparseElements(); +} + +void LiteralBase::Piece::SortSparseElements() { + switch (subshape().element_type()) { + case PRED: + SortSparseElementsInternal(); + break; + case S8: + SortSparseElementsInternal(); + break; + case U8: + SortSparseElementsInternal(); + break; + case S16: + SortSparseElementsInternal(); + break; + case U16: + SortSparseElementsInternal(); + break; + case S32: + SortSparseElementsInternal(); + break; + case U32: + SortSparseElementsInternal(); + break; + case S64: + SortSparseElementsInternal(); + break; + case U64: + SortSparseElementsInternal(); + break; + case F32: + SortSparseElementsInternal(); + break; + case F64: + SortSparseElementsInternal(); + break; + case C64: + SortSparseElementsInternal(); + break; + case F16: + SortSparseElementsInternal(); + break; + case BF16: + SortSparseElementsInternal(); + break; + default: + LOG(FATAL) << "Element type not valid for sparse array: " + << PrimitiveType_Name(subshape().element_type()); + } +} + +template +void LiteralBase::Piece::SortSparseElementsInternal() { + CHECK(LayoutUtil::IsSparseArray(subshape())); + int64 num_elements = sparse_indices()->index_count(); + auto values = data(); + CHECK_LE(num_elements, values.size()); + sparse_indices()->SortWithValues( + tensorflow::gtl::MutableArraySlice(values.data(), num_elements)); +} + +namespace { + +void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, + bool print_layout, std::vector* pieces) { + const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + CHECK(LayoutUtil::HasLayout(literal.shape())); + CHECK(LayoutUtil::HasLayout(subshape)); + + auto shape_to_string = [print_layout](const Shape& shape) { + if (print_layout) { + return ShapeUtil::HumanStringWithLayout(shape); + } else { + return ShapeUtil::HumanString(shape); + } + }; + + // TODO(b/32894291): refactor this code to reduce code duplication. + if (ShapeUtil::IsTuple(subshape)) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" (\n"); + std::vector tuple_pieces; + for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { + ShapeIndex element_index = shape_index; + element_index.push_back(i); + std::vector element_pieces; + ToStringHelper(literal, element_index, print_layout, &element_pieces); + tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, "")); + } + pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n")); + pieces->push_back("\n)"); + return; + } + + if (ShapeUtil::IsToken(subshape)) { + pieces->push_back("token"); + return; + } + + if (LayoutUtil::IsSparseArray(subshape)) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back("{"); + int64 rank = ShapeUtil::Rank(subshape); + int64 num_elements = literal.sparse_element_count(); + for (int64 i = 0; i < num_elements; ++i) { + if (i > 0) { + pieces->push_back(", "); + } + if (rank == 1) { + pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); + pieces->push_back(": "); + } else { + pieces->push_back("["); + pieces->push_back( + tensorflow::str_util::Join(literal.GetSparseIndex(i), ", ")); + pieces->push_back("]: "); + } + pieces->push_back(literal.GetSparseElementAsString(i)); + } + pieces->push_back("}"); + return; + } + + CHECK(LayoutUtil::IsDenseArray(subshape)); + + auto element_to_string = + [&](tensorflow::gtl::ArraySlice indices) -> string { + PrimitiveType element_type = subshape.element_type(); + if (element_type == PRED) { + // We display predicates in a densely packed form. + return literal.Get(indices, shape_index) ? "1" : "0"; + } + return ((!indices.empty() && indices.back() > 0) ? ", " : "") + + literal.GetAsString(indices, shape_index); + }; + + if (ShapeUtil::Rank(subshape) == 0) { + pieces->push_back(literal.GetAsString({}, shape_index)); + } else if (ShapeUtil::Rank(subshape) == 1) { + pieces->push_back("{"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(element_to_string({i0})); + } + pieces->push_back("}"); + } else if (ShapeUtil::Rank(subshape) == 2) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(" { "); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(element_to_string({i0, i1})); + } + pieces->push_back(" "); + pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n"); + } + pieces->push_back("}"); + } else if (ShapeUtil::Rank(subshape) == 3) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(i0 > 0 ? ",\n{" : "{"); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(i1 > 0 ? ",\n { " : " { "); + for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { + pieces->push_back(element_to_string({i0, i1, i2})); + } + pieces->push_back(" }"); + } + pieces->push_back(" }"); + } + pieces->push_back("\n}"); + } else if (ShapeUtil::Rank(subshape) == 4) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { + pieces->push_back(" {"); + for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { + pieces->push_back(element_to_string({i0, i1, i2, i3})); + } + pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n"); + } + pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" + : " },\n"); + } + pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n"); + } + pieces->push_back("}"); + } else if (ShapeUtil::Rank(subshape) == 5) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { + pieces->push_back(Printf(" { /*i2=%lld*/\n", i2)); + for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { + pieces->push_back(" {"); + for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) { + pieces->push_back(element_to_string({i0, i1, i2, i3, i4})); + } + pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n" + : "},\n"); + } + pieces->push_back(i2 == subshape.dimensions(2) - 1 ? " }\n" + : " },\n"); + } + pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" + : " },\n"); + } + pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n"); + } + pieces->push_back("}"); + } else { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {"); + literal.EachCellAsString( + [&](tensorflow::gtl::ArraySlice indices, const string& value) { + pieces->push_back(" "); + pieces->push_back(value); + }); + pieces->push_back("}"); + } +} + +} // namespace + +int64 LiteralBase::sparse_element_count() const { + CHECK(LayoutUtil::IsSparseArray(shape())); + return sparse_indices()->index_count(); +} + +string LiteralBase::ToString(bool print_layout) const { + std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); + ToStringHelper(*this, {}, print_layout, &pieces); + return tensorflow::str_util::Join(pieces, ""); +} + +void LiteralBase::EachCellAsString( + const std::function indices, + const string& value)>& per_cell) const { + if (ShapeUtil::IsZeroElementArray(shape())) { + return; + } + std::vector indices = IndexUtil::LinearIndexToMultidimensionalIndex( + shape(), /*linear_index=*/0); + do { + per_cell(indices, GetAsString(indices)); + } while (IndexUtil::BumpIndices(shape(), &indices)); +} + +namespace { +template +std::unique_ptr ConvertBetweenNativeTypesWithConverter( + const LiteralBase& src_literal, const ConverterType& converter) { + CHECK(ShapeUtil::IsArray(src_literal.shape())); + auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( + src_literal.shape(), + primitive_util::NativeToPrimitiveType())); + auto src_data = src_literal.data(); + auto dest_data = result_literal->template data(); + int64 num_elements = src_literal.element_count(); + + for (int64 i = 0; i < num_elements; ++i) { + dest_data[i] = converter(src_data[i]); + } + return result_literal; +} + +template +std::unique_ptr ConvertBetweenNativeTypes( + const LiteralBase& src_literal) { + auto converter = [](NativeSrcT src) { return static_cast(src); }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + +template +typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), + std::unique_ptr>::type +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { + auto converter = [](NativeSrcT src) { + return tensorflow::bit_cast(src); + }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + +// This template specialization is here to make the compiler happy. bit_cast has +// a static check that the types are the same size. This specialization should +// never be used because the source and destination types are checked for +// identical sizes higher up. +template +typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), + std::unique_ptr>::type +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { + LOG(FATAL) << "Invalid bitcast between types of different sizes."; +} + +template +std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { + CHECK(ShapeUtil::IsArray(src_literal.shape())); + auto result_literal = MakeUnique( + ShapeUtil::ChangeElementType(src_literal.shape(), C64)); + using NativeSrcT = + typename primitive_util::PrimitiveTypeToNative::type; + tensorflow::gtl::ArraySlice src_data = + src_literal.data(); + tensorflow::gtl::MutableArraySlice dest_data = + result_literal->data(); + int64 num_elements = src_literal.element_count(); + for (int64 i = 0; i < num_elements; ++i) { + dest_data[i] = complex64(static_cast(src_data[i]), 0); + } + return result_literal; +} + +template +std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, + bool bitcast) { + CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); + if (bitcast) { + return BitcastBetweenNativeTypes< + typename primitive_util::PrimitiveTypeToNative< + primitive_src_type>::type, + typename primitive_util::PrimitiveTypeToNative< + primitive_dest_type>::type>(src_literal); + } else { + return ConvertBetweenNativeTypes< + typename primitive_util::PrimitiveTypeToNative< + primitive_src_type>::type, + typename primitive_util::PrimitiveTypeToNative< + primitive_dest_type>::type>(src_literal); + } +} + +template +StatusOr> ConvertIfDestTypeMatches( + const LiteralBase& src_literal, PrimitiveType primitive_dest_type, + bool bitcast) { + switch (primitive_dest_type) { +#define CONVERT_IF_TYPES_MATCH(type) \ + case (type): \ + return ConvertIfTypesMatch(src_literal, \ + bitcast); + CONVERT_IF_TYPES_MATCH(PRED) + CONVERT_IF_TYPES_MATCH(S8) + CONVERT_IF_TYPES_MATCH(S32) + CONVERT_IF_TYPES_MATCH(S64) + CONVERT_IF_TYPES_MATCH(U8) + CONVERT_IF_TYPES_MATCH(U32) + CONVERT_IF_TYPES_MATCH(U64) + CONVERT_IF_TYPES_MATCH(F16) + CONVERT_IF_TYPES_MATCH(F32) + CONVERT_IF_TYPES_MATCH(F64) + CONVERT_IF_TYPES_MATCH(BF16) +#undef CONVERT_IF_TYPES_MATCH + case C64: + if (!bitcast) { + return ConvertToC64(src_literal); + } + break; + // Other types are not yet supported. + default: + break; + } + return Unimplemented( + "Converting from type %s to type %s is not implemented.", + PrimitiveType_Name(src_literal.shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); +} + +StatusOr> ConvertSwitch( + const LiteralBase& literal, PrimitiveType primitive_dest_type, + bool bitcast) { + TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); + if (literal.shape().element_type() == primitive_dest_type) { + return literal.CloneToUnique(); + } + switch (literal.shape().element_type()) { +#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ + case (type): \ + return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \ + bitcast); + CONVERT_IF_DEST_TYPE_MATCHES(PRED) + CONVERT_IF_DEST_TYPE_MATCHES(S8) + CONVERT_IF_DEST_TYPE_MATCHES(S32) + CONVERT_IF_DEST_TYPE_MATCHES(S64) + CONVERT_IF_DEST_TYPE_MATCHES(U8) + CONVERT_IF_DEST_TYPE_MATCHES(U32) + CONVERT_IF_DEST_TYPE_MATCHES(U64) + CONVERT_IF_DEST_TYPE_MATCHES(F16) + CONVERT_IF_DEST_TYPE_MATCHES(F32) + CONVERT_IF_DEST_TYPE_MATCHES(F64) + CONVERT_IF_DEST_TYPE_MATCHES(BF16) +#undef CONVERT_IF_DEST_TYPE_MATCHES + // Other types are not yet supported. + default: + return Unimplemented( + "%s from type %s to type %s is not implemented.", + (bitcast ? "Bitcast converting" : "Converting"), + PrimitiveType_Name(literal.shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); + } +} + +} // namespace + +StatusOr> LiteralBase::Convert( + PrimitiveType primitive_dest_type) const { + return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); +} + +StatusOr> LiteralBase::BitcastConvert( + PrimitiveType primitive_dest_type) const { + if (primitive_util::BitWidth(shape().element_type()) != + primitive_util::BitWidth(primitive_dest_type)) { + return InvalidArgument( + "Cannot bitcast convert from %s to %s, bit widths are different: %d != " + "%d", + PrimitiveType_Name(shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str(), + primitive_util::BitWidth(shape().element_type()), + primitive_util::BitWidth(primitive_dest_type)); + } + return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); +} + +StatusOr> LiteralBase::ConvertToShape( + const Shape& dest_shape, bool round_f32_to_bf16) const { + if (!ShapeUtil::IsTuple(dest_shape)) { + if (round_f32_to_bf16 && shape().element_type() == F32 && + dest_shape.element_type() == BF16) { + auto converter = [](float src) { + return tensorflow::bfloat16::round_to_bfloat16(src); + }; + return ConvertBetweenNativeTypesWithConverter(*this, + converter); + } + return Convert(dest_shape.element_type()); + } + std::vector elements; + for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { + auto element = LiteralSlice(*this, {i}); + TF_ASSIGN_OR_RETURN( + auto new_element, + element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); + elements.push_back(std::move(*new_element)); + } + auto converted = MakeUnique(); + *converted = MutableLiteralBase::MoveIntoTuple(&elements); + return std::move(converted); +} + +/* static */ Literal MutableLiteralBase::MoveIntoTuple( + tensorflow::gtl::MutableArraySlice elements) { + std::vector element_shapes; + for (const Literal& element : elements) { + element_shapes.push_back(element.shape()); + } + Literal literal(ShapeUtil::MakeTupleShape(element_shapes), + /*allocate_arrays=*/false); + for (int i = 0; i < elements.size(); ++i) { + TF_CHECK_OK( + literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i})); + } + return literal; +} + +template +bool LiteralBase::Piece::EqualElementsInternal( + const LiteralBase::Piece& other, std::vector* multi_index) const { + if (multi_index->size() == ShapeUtil::Rank(subshape())) { + return (Get(*multi_index) == other.Get(*multi_index)); + } + for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) { + multi_index->push_back(i); + if (!EqualElementsInternal(other, multi_index)) { + return false; + } + multi_index->pop_back(); + } + return true; +} + +bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { + DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); + + std::vector multi_index; + switch (subshape().element_type()) { + case PRED: + return EqualElementsInternal(other, &multi_index); + case U8: + return EqualElementsInternal(other, &multi_index); + case S32: + return EqualElementsInternal(other, &multi_index); + case S64: + return EqualElementsInternal(other, &multi_index); + case U32: + return EqualElementsInternal(other, &multi_index); + case U64: + return EqualElementsInternal(other, &multi_index); + case F32: + return EqualElementsInternal(other, &multi_index); + case F64: + return EqualElementsInternal(other, &multi_index); + case F16: + return EqualElementsInternal(other, &multi_index); + case BF16: + return EqualElementsInternal(other, &multi_index); + case C64: + return EqualElementsInternal(other, &multi_index); + default: + LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type " + << PrimitiveType_Name(subshape().element_type()); + } +} + +bool LiteralBase::operator==(const LiteralBase& other) const { + if (!ShapeUtil::Compatible(shape(), other.shape())) { + return false; + } + + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + const Piece& other_piece = other.piece(index); + if (!piece.EqualElements(other_piece)) { + return false; + } + return true; + }); +} + +namespace { + +template +static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, + NativeT value) { + for (int64 i = 0; i < data.size(); ++i) { + if (data[i] != value) { + return false; + } + } + return true; +} + +} // namespace + +bool LiteralBase::IsAll(int8 value) const { + return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index, + const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case U8: + if (value >= 0) { + return AllElementsEqualValue(piece.data(), value); + } + return false; + case U32: + if (value >= 0) { + return AllElementsEqualValue(piece.data(), value); + } + return false; + case U64: + if (value >= 0) { + return AllElementsEqualValue(piece.data(), value); + } + return false; + case S8: + return AllElementsEqualValue(piece.data(), value); + case S32: + return AllElementsEqualValue(piece.data(), value); + case S64: + return AllElementsEqualValue(piece.data(), value); + case F32: + return AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case PRED: + if (value == 0) { + return AllElementsEqualValue(piece.data(), false); + } + if (value == 1) { + return AllElementsEqualValue(piece.data(), true); + } + return false; + default: + return false; + } + return false; + }; + + if (!piece_is_all()) { + return false; + } + return true; + }); +} + +bool LiteralBase::IsAllFloat(float value) const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case F32: + return AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue( + piece.data(), static_cast(value)); + default: + return false; + } + }; + if (!piece_is_all()) { + return false; + } + return true; + }); +} + +bool LiteralBase::IsAllComplex(complex64 value) const { + switch (shape().element_type()) { + case C64: + return AllElementsEqualValue(root_piece().data(), + value); + default: + return false; + } +} + +bool LiteralBase::IsAllFirst() const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + // Empty shapes are not all the first element since there is no first + // element. + if (ShapeUtil::IsZeroElementArray(piece.subshape())) { + return false; + } + auto piece_is_all = [&]() { + switch (piece.subshape().element_type()) { + case PRED: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 8 bit types + case S8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 16 bit types + case BF16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 32 bit types + case F32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 64 bit types + case C64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + default: + return false; + } + }; + + if (!piece_is_all()) { + return false; + } + return true; + }); +} + +bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice indices) const { + CHECK(ShapeUtil::IsArray(shape())); + switch (shape().element_type()) { + case U8: + return Get(indices) == 0; + case U32: + return Get(indices) == 0; + case U64: + return Get(indices) == 0; + case S8: + return Get(indices) == 0; + case S32: + return Get(indices) == 0; + case S64: + return Get(indices) == 0; + case F32: + return Get(indices) == 0.0f; + case F64: + return Get(indices) == 0.0; + case C64: + return Get(indices) == complex64(0.0f, 0.0f); + case F16: + return Get(indices) == static_cast(0.0f); + case BF16: + return Get(indices) == static_cast(0.0f); + case PRED: + return Get(indices) == false; + default: + LOG(FATAL) << "Input literal must be an array."; + } +} + +namespace { + +template +void CopyToRepeatedField(RepeatedFieldT* dest, + const tensorflow::gtl::ArraySlice src) { + *dest = RepeatedFieldT(src.begin(), src.end()); +} + +} // namespace + +void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { + *proto->mutable_shape() = subshape(); + switch (subshape().element_type()) { + case PRED: + CopyToRepeatedField(proto->mutable_preds(), data()); + break; + case U8: + proto->set_u8s(static_cast(data().data()), + element_count()); + break; + case U32: + CopyToRepeatedField(proto->mutable_u32s(), data()); + break; + case U64: + CopyToRepeatedField(proto->mutable_u64s(), data()); + break; + case S32: + CopyToRepeatedField(proto->mutable_s32s(), data()); + break; + case S64: + CopyToRepeatedField(proto->mutable_s64s(), data()); + break; + case F16: + *proto->mutable_f16s() = string( + reinterpret_cast(data().data()), size_bytes()); + if (!kLittleEndian) { + ConvertEndianShort(proto->mutable_f16s()); + } + break; + case BF16: + *proto->mutable_bf16s() = string( + reinterpret_cast(data().data()), size_bytes()); + if (!kLittleEndian) { + ConvertEndianShort(proto->mutable_bf16s()); + } + break; + case F32: + CopyToRepeatedField(proto->mutable_f32s(), data()); + break; + case F64: + CopyToRepeatedField(proto->mutable_f64s(), data()); + break; + case C64: + for (complex64 value : data()) { + proto->add_c64s(value.real()); + proto->add_c64s(value.imag()); + } + break; + case TUPLE: + case TOKEN: + // Nothing to do but assign the shape which is done above. + return; + default: + // TODO(b/111551621): Support serializing more PrimitiveTypes. + LOG(FATAL) << "Unhandled primitive type " + << PrimitiveType_Name(subshape().element_type()); + } +} + +const void* LiteralBase::Piece::untyped_data() const { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + return buffer(); +} + +void* LiteralBase::Piece::untyped_data() { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + return buffer(); +} + +namespace { + +template +Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, + const RepeatedFieldT& src) { + if (dest.size() != src.size()) { + return InvalidArgument( + "Expected %lu elements in LiteralProto repeated field, has %d", + dest.size(), src.size()); + } + std::copy(src.begin(), src.end(), dest.begin()); + return Status::OK(); +} + +} // namespace + +Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { + // These conditions should have been checked in + // MutableLiteralBase::CreateFromProto. + TF_RET_CHECK(proto.has_shape()); + TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); + TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); + + switch (subshape().element_type()) { + case PRED: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.preds())); + break; + case U8: { + auto u8_data = data(); + TF_RET_CHECK(proto.u8s().size() == u8_data.size()); + std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin()); + } break; + case S32: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.s32s())); + break; + case S64: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.s64s())); + break; + case U32: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u32s())); + break; + case U64: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u64s())); + break; + case F16: { + const string& s(proto.f16s()); + TF_RET_CHECK(data().size() * sizeof(half) == s.size()); + memcpy(untyped_data(), s.data(), s.size()); + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); + } + } break; + + case BF16: { + const string& s(proto.bf16s()); + TF_RET_CHECK(data().size() * sizeof(bfloat16) == s.size()); + memcpy(untyped_data(), s.data(), s.size()); + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); + } + } break; + case F32: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.f32s())); + break; + case F64: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.f64s())); + break; + case C64: { + auto complex_data = data(); + TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2); + for (int64 i = 0; i < complex_data.size(); ++i) { + complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)}; + } + } break; + case TUPLE: + LOG(FATAL) << "Should not be called on tuple shapes: " + << ShapeUtil::HumanString(subshape()); + break; + default: + LOG(FATAL) << "Unhandled primitive type " << subshape().element_type(); + } + return Status::OK(); +} + +LiteralProto LiteralBase::ToProto() const { + LiteralProto proto; + root_piece().ForEachSubpiece( + [&](const ShapeIndex& index, const Piece& piece) { + LiteralProto* proto_piece = &proto; + for (int64 i : index) { + while (proto_piece->tuple_literals_size() <= i) { + proto_piece->add_tuple_literals(); + } + proto_piece = proto_piece->mutable_tuple_literals(i); + } + piece.WriteToProto(proto_piece); + }); + + if (LayoutUtil::IsSparseArray(shape())) { + CopyToRepeatedField(proto.mutable_sparse_indices(), + sparse_indices()->data()); + } + + return proto; +} + +const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const { + return piece(shape_index).untyped_data(); +} + +void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) { + return piece(shape_index).untyped_data(); +} + +int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const { + return piece(shape_index).size_bytes(); +} + +string LiteralBase::GetR1U8AsString() const { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(shape().element_type(), U8); + return string(tensorflow::bit_cast(data().data()), + ShapeUtil::ElementsIn(shape())); +} + +void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape, + Piece* src_piece, + Piece* dest_piece) { + DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape())) + << "src_piece has shape: " + << ShapeUtil::HumanString(src_piece->subshape()) + << "dest_piece has shape: " + << ShapeUtil::HumanString(dest_piece->subshape()); + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece); + + dest_piece->emplace_back(std::move(child_piece)); + } + } else if (ShapeUtil::IsArray(shape)) { + dest_piece->set_buffer(src_piece->buffer()); + } else { + // If the shape is neither an array nor tuple, then it must be + // zero-sized. Otherwise, some memory needs to be allocated for it. + CHECK_EQ(dest_piece->size_bytes(), 0); + } +} + +MutableLiteralBase::~MutableLiteralBase() {} + +MutableBorrowingLiteral::MutableBorrowingLiteral( + const MutableBorrowingLiteral& literal) + : MutableLiteralBase() { + shape_ = MakeUnique(literal.shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_); +} + +MutableBorrowingLiteral& MutableBorrowingLiteral::operator=( + const MutableBorrowingLiteral& literal) { + shape_ = MakeUnique(literal.shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_); + + return *this; +} + +MutableBorrowingLiteral::MutableBorrowingLiteral( + const MutableLiteralBase& literal) + : MutableLiteralBase() { + shape_ = MakeUnique(literal.shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_); +} + +MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal) + : MutableLiteralBase() { + shape_ = MakeUnique(literal->shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_); +} + +MutableBorrowingLiteral::MutableBorrowingLiteral( + MutableBorrowingLiteral literal, const ShapeIndex& view_root) + : MutableLiteralBase() { + shape_ = MakeUnique(literal.piece(view_root).subshape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_); +} + +MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr, + const Shape& shape) + : MutableLiteralBase() { + shape_ = MakeUnique(shape); + CHECK(LayoutUtil::HasLayout(*shape_)); + CHECK(!ShapeUtil::IsTuple(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_buffer(const_cast(src_buf_ptr)); + root_piece_->set_subshape(shape_.get()); +} + +MutableBorrowingLiteral::~MutableBorrowingLiteral() { + if (root_piece_ != nullptr) { + root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (piece->buffer() != nullptr) { + delete piece->sparse_indices(); + } + }); + delete root_piece_; + } +} + +LiteralSlice::LiteralSlice(const LiteralBase& literal) + : LiteralBase(), root_piece_(&literal.root_piece()) {} + +LiteralSlice::LiteralSlice(const LiteralBase& literal, + const ShapeIndex& view_root) + : LiteralBase(), root_piece_(&literal.piece(view_root)) {} + +void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { + CHECK(ShapeUtil::IsTuple(shape)); + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + if (ShapeUtil::IsTuple(subshape)) { + BuildPieceSubtree(subshape, &child_piece); + } + + piece->emplace_back(std::move(child_piece)); + } +} + +BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsArray(*shape_)); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = Piece(); + root_piece_.set_buffer(const_cast(src_buf_ptr)); + root_piece_.set_subshape(shape_.get()); +} + +BorrowingLiteral::BorrowingLiteral( + tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsTuple(*shape_)); + CHECK(!ShapeUtil::IsNestedTuple(*shape_)); + CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); + root_piece_ = Piece(); + root_piece_.set_subshape(shape_.get()); + BuildPieceSubtree(*shape_, &root_piece_); + + for (int i = 0; i < src_buf_ptrs.size(); ++i) { + const auto& src_shape = shape_->tuple_shapes(i); + CHECK(ShapeUtil::IsArray(src_shape)); + root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h new file mode 100644 index 0000000000000000000000000000000000000000..92c0f903cbe252a153103aa8514bb5531696bbfe --- /dev/null +++ b/tensorflow/compiler/xla/literal.h @@ -0,0 +1,1188 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_H_ +#define TENSORFLOW_COMPILER_XLA_LITERAL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/sparse_index_array.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/bitmap.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Forward declare Literal and LiteralSlice class to be used by the creation +// methods in the base class. +class Literal; +class LiteralSlice; + +// Abstract base class for literals. +class LiteralBase { + public: + virtual ~LiteralBase() = 0; + + // Literals are equal if they have compatible shapes and the same data + // values. Layout is not compared. + bool operator==(const LiteralBase& other) const; + bool operator!=(const LiteralBase& other) const { return !(*this == other); } + + // Returns the shape of the literal. + const Shape& shape() const { return root_piece().subshape(); } + + // Serialize to proto. + LiteralProto ToProto() const; + + // Returns an ArraySlice of the array for this literal for the given NativeT + // (e.g., float). CHECKs if the subshape of the literal at the given + // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type + // to native type. + template + tensorflow::gtl::ArraySlice data( + const ShapeIndex& shape_index = {}) const; + + // Returns a const pointer to the sparse index array. Returns nullptr if the + // literal is not a sparse array. + const SparseIndexArray* sparse_indices( + const ShapeIndex& shape_index = {}) const; + + // Returns a const pointer to (or size of) the underlying buffer holding the + // array at the given shape index. CHECKs if the subshape of the literal at + // the given ShapeIndex is not array. + const void* untyped_data(const ShapeIndex& shape_index = {}) const; + int64 size_bytes(const ShapeIndex& shape_index = {}) const; + + // Returns this literal's data as a string. This literal must be a rank-1 U8 + // array. + string GetR1U8AsString() const; + + // Returns a string representation of the literal value. + // Warning: this function can take minutes for multi-million element Literals. + string ToString(bool print_layout = false) const; + + // Gets an element in the literal at the given index. The multi_index is + // CHECKed against the dimension sizes. + template + NativeT Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const; + // Overloads of Get for array literals. CHECKs if the literal is not + // array-shaped and dense. + template + NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; + + // Returns the element value at index (0, ..., 0), however many zeroes are + // required for that index. + template + NativeT GetFirstElement() const; + + // As Get(), but determines the correct type and converts the value + // into text. + string GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index = {}) const; + // As GetSparseElement(), but determines the correct type and converts the + // value into text. + string GetSparseElementAsString(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + // As Get(), but determines the correct type and converts the value into + // int64. This literal must be an array. + StatusOr GetIntegralAsS64( + tensorflow::gtl::ArraySlice multi_index) const; + + // Returns the multi-index of the element in a sparse literal at the given + // sparse element number. The sparse element number is the position with in + // the sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + tensorflow::gtl::ArraySlice GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; + + // Returns the value of the element in a sparse literal at the given sparse + // element number. The sparse element number is the position with in the + // sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + template + NativeT GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + + // Invokes the "per cell" callback for each element in the provided + // literal with the element's indices and a string representation of + // the element's value. + // + // This function is useful if you want a polymorphic representation + // of the tensor's elements (turning it to a string for something + // like representation in a protobuf). + // + // This literal must have a dense layout. + void EachCellAsString( + const std::function indices, + const string& value)>& per_cell) const; + template + void EachCell(std::function indices, + NativeT value)> + per_cell) const; + + // Returns whether every element in this literal is equal to value. + // + // value is an int8 because we expect this to be called with small + // compile-time constants (0, -1, etc.) and so that whatever value you pass + // can be represented exactly by floating-point types as small as 16 bits. + // + // If value doesn't fit in this literal's type, returns false. Values of 1/0 + // are considered equal to true/false; other values are not considered equal + // to true. Also if this literal is not array-shaped false is returned. + bool IsAll(int8 value) const; + + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular floating-point number. + // + // If the literal is not a floating-point value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for values that can be expressed precisely as a float, + // e.g. -0.5. Also if this literal is not array-shaped false is returned. + bool IsAllFloat(float value) const; + + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular complex number. + // + // If the literal is not a complex value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for complex values that can be expressed precisely as + // float pairs e.g. (-0.5, 1.0). + // + // This literal must have a dense layout. + bool IsAllComplex(complex64 value) const; + + // Literal consists entirely of the first element of the literal. + bool IsAllFirst() const; + + // Returns whether this literal is zero at the specified index. This literal + // must be an array with a dense layout. + bool IsZero(tensorflow::gtl::ArraySlice indices) const; + + // Returns the count of the elements in the array at the given shape index in + // this literal. + int64 element_count(const ShapeIndex& index = {}) const { + return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); + } + + // Returns the count of the elements in the sparse array at the given shape + // index in this literal, which will be no larger than + // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). + int64 sparse_element_count() const; + + // Compute a hash for this literal. This literal must not be a sparse tensor + // or a tuple containing a sparse tensor. + size_t Hash() const; + + // Converts this literal to the given shape. Returns an error is the + // conversion is not possible. + // + // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding + // instead of truncation; otherwise, truncation is used. + // + // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes + // the default behavior. + StatusOr> ConvertToShape( + const Shape& dest_shape, bool round_f32_to_bf16 = false) const; + + // Converts this literal to another primitive type using a bitcast + // conversion. The to and from primitive types must have the same bit + // width. Returns an error if the conversion is not possible. This literal + // must be array-shaped. + StatusOr> BitcastConvert( + PrimitiveType primitive_dest_type) const; + + // Converts this literal to another primitive type. Returns an error if the + // conversion is not possible. This literal must be array-shaped. + StatusOr> Convert( + PrimitiveType primitive_dest_type) const; + + // Clones the underlying buffers into a new Literal, or new + // std::unique_ptr. + Literal Clone() const; + std::unique_ptr CloneToUnique() const; + + // TODO(b/67651157): The methods below which perform computation on Literals + // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with + // evaluator code which operates on Literals. + // + // Creates a new value that has the equivalent value as this + // literal, but conforms to new_layout; e.g. a literal matrix that was in {0, + // 1} minor-to-major dimension layout can be re-layed-out as {1, 0} + // minor-to-major dimension layout and the value in the cell at any given + // logical index (i0, i1) will be the same. + // + // For tuple shaped literals, shape_index should be used to select the inner + // array that the new layout applies to. + // + // Note: this is useful when the client wants to ensure that a value placed in + // the XLA allocation tracker has a particular layout; for efficiency + // purposes or avoiding unimplemented operation/layout combinations. + std::unique_ptr Relayout(const Layout& new_layout, + const ShapeIndex& shape_index = {}) const; + + // An overload of Relayout which changes the layout of the entire shape rather + // than being limited to a single array within the shape. + std::unique_ptr Relayout(const Shape& shape_with_layout) const; + + // Creates a new literal by reshaping this literal to have the given + // dimensions. The total number of elements must not change; The + // implementation currently only supports monotonic dim0-major layouts. + // This literal must be an array. + StatusOr> Reshape( + tensorflow::gtl::ArraySlice dimensions) const; + + // Creates a new literal by broadcasting this literal with `dimensions` to + // yield a literal of shape `result_shape`. + StatusOr> Broadcast( + const Shape& result_shape, + tensorflow::gtl::ArraySlice dimensions) const; + + // Creates a new literal by reordering the dimensions of this literal. + // The given `permutation` must be a permutation of the dimension numbers + // in the original literal, and it specifies the order of the new dimensions + // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). + // For example, a transpose call on a literal of shape [3 x 8 x 4] and + // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. + // This literal must be an array. + std::unique_ptr Transpose( + tensorflow::gtl::ArraySlice permutation) const; + + // Creates a sub-array from this literal by extracting the indices + // [start_index, limit_index) of each dimension. The result literal has the + // same rank and layout as for the given literal. The number of indices in + // start_indices and limit_indices must be the rank of the literal, and the + // indices follow the order of the dimensions. + // This literal must be an array. + std::unique_ptr Slice( + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) const; + + // Creates a literal with a prepended dimension with bound "times"; e.g. a + // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this + // literal replicated four times. + // This literal must be an array. + template + std::unique_ptr Replicate(int64 times) const; + + // Creates a new Literal object with the shape specified as parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + // + // Note: It's an antipattern to use this method then immediately call + // MutableLiteralBase::Populate on the result (since that results in zero + // initialization, then reinitialization. Conside if a call to + // MakeUnique(shape), followed by the call to + // MutableLiteralBase::Populate can be used instead. + static std::unique_ptr CreateFromShape(const Shape& shape); + + protected: + // A data structure representing a subshape at a particular ShapeIndex within + // the literal. For array-shaped ShapeIndexes, this data structure holds the + // pointer to the memory allocated for the array data. + class Piece { + public: + // Returns the buffer holding the array data for this piece as an array + // slice. This piece must be array-shaped. + template + tensorflow::gtl::ArraySlice data() const; + template + tensorflow::gtl::MutableArraySlice data(); + + // Returns the buffer holding the array data for this piece as a void*. This + // piece must be array-shaped. + void* untyped_data(); + const void* untyped_data() const; + + // Gets or sets an element in the array at the given index. The multi_index + // is CHECKed against the dimension sizes of the array. This piece must be + // array-shaped. + template + NativeT Get(tensorflow::gtl::ArraySlice index) const; + template + void Set(tensorflow::gtl::ArraySlice index, NativeT value); + + // Gets/sets the buffer holding the array data. + char* buffer() const { return buffer_; } + void set_buffer(char* buffer) { buffer_ = buffer; } + + // The array of multi-indices that provide the locations of non-zero + // elements in a sparse array. Only used if + // LayoutUtil::IsSparseArray(shape()) is true. + SparseIndexArray* sparse_indices() const { return sparse_indices_; } + void set_sparse_indices(SparseIndexArray* sparse_indices) { + sparse_indices_ = sparse_indices; + } + + // Gets or sets the subshape of this piece. This reference points to a + // subshape within the shape in the containing Literal (Literal::shape_). + const Shape& subshape() const { return *subshape_; } + void set_subshape(const Shape* subshape) { subshape_ = subshape; } + + // Returns the size in bytes of the buffer holding the array data. + int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } + + // Returns the number of elements in this piece's array. + int64 element_count() const { + // If this is a sparse array, use the number of elements represented by + // the indices in the associated SparseIndexArray. + return LayoutUtil::IsSparseArray(subshape()) + ? sparse_indices()->index_count() + : ShapeUtil::ElementsIn(subshape()); + } + + // Returns the child piece at 'index' of this piece. + Piece& child(int64 index) { return children_[index]; } + + // Adds a child piece to this piece's children. + void emplace_back(Piece child_piece) { + children_.emplace_back(std::move(child_piece)); + } + + // Returns the size of children pieces of this piece. + int64 children_size() { return children_.size(); } + + // Visitor functions that recursively traverses the piece and calls the + // given function at each child piece. The function has the type: + // void (const ShapeIndex& index, const Piece& piece) + template + void ForEachSubpiece(const Fn& func) const { + ShapeIndex index; + return ForEachHelper( + [&func](const ShapeIndex& index, const Piece& piece) { + func(index, piece); + return Status::OK(); + }, + *this, &index) + .IgnoreError(); + } + // Same as above, but the function has the type: + // Status (const ShapeIndex& index, const Piece& piece) + // The first non-OK return value is returned by the function. + template + Status ForEachSubpieceWithStatus(const Fn& func) const { + ShapeIndex index; + return ForEachHelper(func, *this, &index); + } + // Same as above, but the function has the type: + // Bool (const ShapeIndex& index, const Piece& piece) + // The first non-true return value is returned by the function. + template + bool ForEachSubpieceWithBool(const Fn& func) const { + ShapeIndex index; + return ForEachHelperBool(func, *this, &index); + } + // Same as above, but the function has the type: + // Void (const ShapeIndex& index, Piece& piece) + template + void ForEachMutableSubpiece(const Fn& func) { + ShapeIndex index; + return ForEachMutableHelper( + [&func](const ShapeIndex& index, Piece* piece) { + func(index, piece); + return Status::OK(); + }, + const_cast(this), &index) + .IgnoreError(); + } + // Same as above, but the function has the type: + // Status (const ShapeIndex& index, Piece& piece) + // The first non-OK return value is returned by the function. + template + Status ForEachMutableSubpieceWithStatus(const Fn& func) { + ShapeIndex index; + return ForEachMutableHelper( + func, const_cast(this), &index); + } + + // Returns true if this piece and 'other' contain the same data. This piece + // and 'other' must be array-shaped and compatible. + bool EqualElements(const Piece& other) const; + + // Writes the shape and data (if array-shaped) into the given proto. + void WriteToProto(LiteralProto* proto) const; + + // Copy the data from 'src' into this piece's buffer. Shapes of this piece + // and src must be compatible. + Status CopyFrom(const Piece& src); + + // Copies the data from the given proto into this piece. The shape of this + // piece must be equal (not just compatible) to the shape of the proto. + Status CopyFromProto(const LiteralProto& proto); + + // Sorts the elements in a sparse array. + void SortSparseElements(); + + private: + // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'. + // The first non-OK (or non-true) value is returned by the function. + // The callable 'func' has the same signature as described above in + // ForEachSubpiece*. + template + Status ForEachHelper(const Fn& func, const Piece& piece, + ShapeIndex* index) const { + TF_RETURN_IF_ERROR(func(*index, piece)); + for (int64 i = 0; i < piece.children_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index)); + index->pop_back(); + } + return Status::OK(); + } + template + bool ForEachHelperBool(const Fn& func, const Piece& piece, + ShapeIndex* index) const { + if (!func(*index, piece)) { + return false; + } + for (int64 i = 0; i < piece.children_.size(); ++i) { + index->push_back(i); + if (!ForEachHelperBool(func, piece.children_[i], index)) { + return false; + } + index->pop_back(); + } + return true; + } + template + Status ForEachMutableHelper(const Fn& func, Piece* piece, + ShapeIndex* index) { + TF_RETURN_IF_ERROR(func(*index, piece)); + for (int64 i = 0; i < piece->children_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR( + ForEachMutableHelper(func, &piece->children_[i], index)); + index->pop_back(); + } + return Status::OK(); + } + + // Recursive helper for EqualElements. + template + bool EqualElementsInternal(const Piece& other, + std::vector* multi_index) const; + + // Helper for SortSparseElements that has the element type as a template + // parameter. + template + void SortSparseElementsInternal(); + + // For array-shaped pieces, this is the buffer holding the literal data. + char* buffer_ = nullptr; + + // For sparse arrays, this is the array of indices. + SparseIndexArray* sparse_indices_ = nullptr; + + // The shape of piece. This points into the shape of the containing Literal + // (Literal::shape_). + const Shape* subshape_ = nullptr; + + // Children pieces for tuple shaped pieces. + std::vector children_ = {}; + }; // class Piece + + const Piece& piece(const ShapeIndex& shape_index) const { + Piece* piece = &const_cast(root_piece()); + for (const auto i : shape_index) { + DCHECK_GE(i, 0); + DCHECK_LT(i, piece->children_size()); + piece = &piece->child(i); + } + return *piece; + } + + // Returns the piece at the root of the shape. + virtual const Piece& root_piece() const = 0; + + // LiteralSlice and Literal must access Pieces of other Literals. + friend class MutableLiteralBase; + friend class LiteralSlice; + friend class BorrowingLiteral; + + private: + template + std::unique_ptr SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice start_indices) const; +}; + +// Abstract base class representing a mutable literal in XLA. +class MutableLiteralBase : public LiteralBase { + public: + virtual ~MutableLiteralBase() = 0; + + // Returns a MutableArraySlice view of the array for this literal for the + // given NativeT (e.g., float). CHECKs if the subshape of the literal at the + // given ShapeIndex is not array. See primitive_util.h for the mapping from + // XLA type to native type. + template + tensorflow::gtl::MutableArraySlice data( + const ShapeIndex& shape_index = {}); + // Unhide const method from parent class. + using LiteralBase::data; + + // Returns a pointer to the sparse index array. Returns nullptr if the literal + // is not a sparse array. + SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + + // TODO(b/67651157): Remove this accessor. Literal users should not be able to + // mutate the shape as this can produce malformed Literals. + Shape* mutable_shape_do_not_use() { return shape_.get(); } + + // Returns a pointer to the underlying buffer holding the array at the given + // shape index. CHECKs if the subshape of the literal at the given ShapeIndex + // is not array. + void* untyped_data(const ShapeIndex& shape_index = {}); + // Unhide const method from parent class. + using LiteralBase::untyped_data; + + // Populates a literal with a sparse layout with the given indices and values. + // Each index in the indices array is CHECKed against the dimensions in the + // literal's shape. If sort is true, then the indices and values will be + // sorted. If sort is false, then the indices and values are assumed to + // already be in sorted order. See CreateSparse for an example of how data + // are populated. + template + void PopulateSparse(SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, + bool sort = true); + + // Copy values from 'src_literal' rooted at 'src_shape_index' into this + // literal rooted at 'dest_shape_index'. The subshape of this literal rooted + // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' + // rooted at 'src_shape_index', but need not be arrays. + Status CopyFrom(const LiteralSlice& src_literal, + const ShapeIndex& dest_shape_index = {}, + const ShapeIndex& src_shape_index = {}); + + // Copies the values from src_literal, starting at src_base shape indexes, + // to this literal, starting at dest_base, where the copy size in each + // dimension is specified by copy_size. + // The src_literal and this literal must have the same primitive type, + // src_base+copy_size must fit the source literal dimensions, as well as + // dest_base+copy_size must fit the destination literal dimensions. + // Note: if either src_literal or this literal contains dimensions with zero + // element, then copy_size must be 0 in these dimensions while the + // corresponding base indices being 0. + // This literal and 'src_literal' must be arrays. + Status CopySliceFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); + + // Copies one element from src_literal[src_index] to (*this)[dest_index]. + Status CopyElementFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_index, + tensorflow::gtl::ArraySlice dest_index); + + // Sets an element in the literal at the given index. The multi_index is + // CHECKed against the dimension sizes. + template + void Set(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index, NativeT value); + // Overloads of Set for array literals. CHECKs if the literal is not + // array-shaped and dense. + template + void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); + + // Appends the given element to the literal. If the elements are not appended + // in sorted order, then SortSparseElements should be called before calling + // other methods. This literal must have a sparse layout. + template + void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, + NativeT value, const ShapeIndex& shape_index = {}); + + // Sorts the elements in a sparse array. + void SortSparseElements(const ShapeIndex& shape_index = {}); + + // As Set(), but truncates `value` to the literal element type before storing. + // This literal must be an array. + Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, + int64 value); + + // Populate this literal with the given values. Examples: + // + // // Populate with floats. + // Array2D float_values = ... + // literal.PopulateR2FromArray2D(values); + // + // // Populate with int32s. + // literal.PopulateR2({{1, 2}, {3, 4}}); + // + // The shape and element type of this literal must match given values. For + // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 + // array of S32. + template + void PopulateR1(tensorflow::gtl::ArraySlice values); + void PopulateR1(const tensorflow::core::Bitmap& values); + template + void PopulateR2(std::initializer_list> values); + template + void PopulateFromArray(const Array& values); + template + void PopulateR2FromArray2D(const Array2D& values); + template + void PopulateR3FromArray3D(const Array3D& values); + template + void PopulateR4FromArray4D(const Array4D& values); + + // Populates literal values by calling the generator function for every cell + // in this literal object. + // + // generator must be a callable of the type + // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. + // + // This literal must have a dense layout. + template + Status Populate(const FnType& generator); + + // A parallel version of Populate(). This can be used if the generator is + // thread-safe and the values for the shape's different elements are + // independent. + template + Status PopulateParallel(const FnType& generator); + + // Fills this literal with the given value. + template + void PopulateWithValue(NativeT value); + + // This operation is the inverse of DecomposeTuple. The given elements are + // moved into the tuple elements of a new tuple-shaped Literal which is + // returned. Upon return, each of the Literals in 'elements' is set to a nil + // shape (empty tuple). + static Literal MoveIntoTuple( + tensorflow::gtl::MutableArraySlice elements); + + // Serialize from a proto. + static StatusOr> CreateFromProto( + const LiteralProto& proto); + + protected: + // Returns the piece at the given ShapeIndex. + Piece& piece(const ShapeIndex& shape_index) { + return const_cast(LiteralBase::piece(shape_index)); + } + + Piece& root_piece() const override { return *root_piece_; }; + + // Internal template helper for the Literal::CopySliceFrom(), matching its + // arguments one by one. + template + Status CopySliceFromInternal(const LiteralBase& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); + + // Utility structure which is used to create the optimal configuration for + // a ShapeUtil::ForEachIndex() scan across two literals. + struct StrideConfig { + StrideConfig(const Shape& source_shape, const Shape& dest_shape, + tensorflow::gtl::ArraySlice dimensions); + + // The dimensions of the stride operation. Essentially every dimension + // will be iterated from base[i] to base[i]+dimensions[i], in step[i] + // steps. + tensorflow::gtl::ArraySlice dimensions; + DimensionVector base; + DimensionVector step; + int64 minor_dimension = 0; + // The size of the strides for source and destination. One of the two + // (the one looping through its most minor dimension) will be 1, while + // the other will be the stride size at the dimension matching the other + // shape most minor dimension being scanned. + int64 dest_stride = 1; + int64 source_stride = 1; + // The size of the inner loop on the most minor dimension. + int64 minor_loop_size = 1; + }; + + // Literal class always owns the shape. The parent class borrows this shape. + std::unique_ptr shape_; + + Piece* root_piece_ = nullptr; + + // Implementation details shared between Populate() and PopulateParallel() + template + Status PopulateInternal(const FnType& generator, bool parallel); + + friend class LiteralBase; + friend class MutableBorrowingLiteral; +}; +std::ostream& operator<<(std::ostream& out, const Literal& literal); + +// The underlying buffer and shape is always owned by this class. +class Literal : public MutableLiteralBase { + public: + Literal() : Literal(ShapeUtil::MakeNil()) {} + + // Create a literal of the given shape. The literal is allocated sufficient + // memory to hold the shape. Memory is uninitialized. + explicit Literal(const Shape& shape); + virtual ~Literal(); + + // Literals are moveable, but not copyable. To copy a literal use + // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies + // of literals which can be expensive. + Literal(const Literal& other) = delete; + Literal& operator=(const Literal& other) = delete; + Literal(Literal&& other); + // 'allocate_arrays' indicates whether to allocate memory for the arrays in + // the shape. If false, buffer pointers inside of the Literal::Pieces are set + // to nullptr. + Literal(const Shape& shape, bool allocate_arrays); + Literal& operator=(Literal&& other); + + // Similar to CopyFrom, but with move semantincs. The subshape of this literal + // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' + // (layouts and shapes must match), but need not be arrays. The memory + // allocated in this literal for the subshape at dest_shape_index is + // deallocated, and the respective buffers are replaced with those in + // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). + virtual Status MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index = {}); + + // Returns a vector containing the tuple elements of this Literal as separate + // Literals. This Literal must be tuple-shaped and can be a nested tuple. The + // elements are moved into the new Literals; no data is copied. Upon return + // this Literal is set to a nil shape (empty tuple) + std::vector DecomposeTuple(); + + private: + // Deallocate the buffers held by this literal. + void DeallocateBuffers(); + + // Recursively sets the subshapes and buffers of all subpieces rooted at + // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in + // the shape. + void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); +}; + +// The underlying buffer is not owned by this class and is always owned by +// others. The shape is not owned by this class and not mutable. +class MutableBorrowingLiteral : public MutableLiteralBase { + public: + virtual ~MutableBorrowingLiteral(); + + MutableBorrowingLiteral() : MutableLiteralBase() {} + + MutableBorrowingLiteral(const MutableBorrowingLiteral& literal); + MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal); + + // Implicit conversion constructors. + MutableBorrowingLiteral(const MutableLiteralBase& literal); + MutableBorrowingLiteral(MutableLiteralBase* literal); + MutableBorrowingLiteral(MutableBorrowingLiteral literal, + const ShapeIndex& view_root); + MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape); + + private: + // Recursively copies the subtree from the `src_piece` at the given child + // index to the `dest_piece`. For buffers only the pointers are copied, but + // not the content. + void CopyPieceSubtree(const Shape& shape, Piece* src_piece, + Piece* dest_piece); +}; + +// A read-only view of a Literal. A LiteralSlice contains pointers to shape and +// literal buffers always owned by others. +class LiteralSlice : public LiteralBase { + public: + LiteralSlice() : LiteralBase() {} + + // Implicit conversion constructors. + LiteralSlice(const LiteralBase& literal); + LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root); + + private: + const Piece& root_piece() const override { return *root_piece_; }; + + const Piece* root_piece_; // Not owned. +}; + +// A read-only Literal where the underlying buffers are never owned by this +// class. +class BorrowingLiteral : public LiteralBase { + public: + BorrowingLiteral() : LiteralBase() {} + + // 'src_buf_ptr' is not owned by this class and must outlive the + // lifetime of this class. It points to an appropirately sized buffer with + // data interpretered as indicated by 'shape'. + // This constructor is only used for array shapes. + BorrowingLiteral(const char* src_buf_ptr, const Shape& shape); + // Similar as above, except to be used for constructing non-nested tuples. + BorrowingLiteral(tensorflow::gtl::ArraySlice src_buf_ptrs, + const Shape& shape); + // TODO(b/79707221): adding constructors for nested tuples as well. + + private: + // Recursively builds the subtree for the given piece and sets the subshapes + // of the given piece with the given shape. + void BuildPieceSubtree(const Shape& shape, Piece* piece); + + // Accessor for the root piece of this literal. + const Piece& root_piece() const override { return root_piece_; }; + Piece root_piece_; + + // Shape of this literal. Stored as unique_ptr such that the (default) move + // construction of this class would be trivially correct: the pointer to Shape + // root_piece_ stores will still point to the correct address. + std::unique_ptr shape_; +}; + +template +tensorflow::gtl::ArraySlice LiteralBase::Piece::data() const { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + CHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType()) + << "Attempting to access " + << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) + << " type, but literal element type is " + << PrimitiveType_Name(subshape().element_type()); + return tensorflow::gtl::ArraySlice( + reinterpret_cast(buffer()), element_count()); +} + +template +tensorflow::gtl::MutableArraySlice LiteralBase::Piece::data() { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + CHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType()) + << "Attempting to access " + << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) + << " type, but literal element type is " + << PrimitiveType_Name(subshape().element_type()); + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(buffer()), element_count()); +} + +template +NativeT LiteralBase::Piece::Get( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(LayoutUtil::IsDenseArray(subshape())); + return data()[IndexUtil::MultidimensionalIndexToLinearIndex( + subshape(), multi_index)]; +} + +template +void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index, + NativeT value) { + CHECK(LayoutUtil::IsDenseArray(subshape())); + data()[IndexUtil::MultidimensionalIndexToLinearIndex( + subshape(), multi_index)] = value; +} + +template +tensorflow::gtl::ArraySlice LiteralBase::data( + const ShapeIndex& shape_index) const { + return piece(shape_index).data(); +} + +template +tensorflow::gtl::MutableArraySlice MutableLiteralBase::data( + const ShapeIndex& shape_index) { + return piece(shape_index).data(); +} + +template +inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { + return piece(shape_index).Get(multi_index); +} + +template +inline NativeT LiteralBase::Get( + tensorflow::gtl::ArraySlice multi_index) const { + return root_piece().Get(multi_index); +} + +template +inline void MutableLiteralBase::Set( + tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index, NativeT value) { + return piece(shape_index).Set(multi_index, value); +} + +template +inline void MutableLiteralBase::Set( + tensorflow::gtl::ArraySlice multi_index, NativeT value) { + return root_piece().Set(multi_index, value); +} + +template +NativeT LiteralBase::GetFirstElement() const { + return data().at(0); +} + +template +NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index) const { + CHECK( + LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index))); + return data(shape_index)[sparse_element_number]; +} + +template +void MutableLiteralBase::AppendSparseElement( + tensorflow::gtl::ArraySlice multi_index, NativeT value, + const ShapeIndex& shape_index) { + Piece& p = piece(shape_index); + const Shape& subshape = p.subshape(); + CHECK(LayoutUtil::IsSparseArray(subshape)); + int64 rank = ShapeUtil::Rank(subshape); + CHECK_EQ(multi_index.size(), rank); + int64 last_element = p.sparse_indices()->index_count(); + CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout())); + p.sparse_indices()->Append(multi_index); + CHECK_LT(last_element, p.data().size()); + p.data()[last_element] = value; +} + +template +void LiteralBase::EachCell( + std::function indices, + NativeT value)> + per_cell) const { + if (ShapeUtil::IsZeroElementArray(shape())) { + return; + } + std::vector indices(ShapeUtil::Rank(shape()), 0); + do { + per_cell(indices, Get(indices)); + } while (IndexUtil::BumpIndices(shape(), &indices)); +} + +template +inline void MutableLiteralBase::PopulateR1( + tensorflow::gtl::ArraySlice values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); + for (int64 i = 0; i < values.size(); ++i) { + Set({i}, values[i]); + } +} + +template +void MutableLiteralBase::PopulateR2( + std::initializer_list> values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 2); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); + + const int64 dim0_size = values.size(); + const int64 dim1_size = values.begin()->size(); + CHECK_EQ(dim0_size, shape().dimensions(0)); + CHECK_EQ(dim1_size, shape().dimensions(1)); + + int64 dim0 = 0; + for (auto inner_list : values) { + int64 dim1 = 0; + for (auto value : inner_list) { + Set({dim0, dim1}, value); + ++dim1; + } + CHECK_EQ(dim1_size, dim1); + ++dim0; + } +} + +template +void MutableLiteralBase::PopulateFromArray(const Array& values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); + CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions()); + for (int dim = 0; dim < values.num_dimensions(); ++dim) { + CHECK_EQ(values.dim(dim), shape().dimensions(dim)); + } + values.Each([this](tensorflow::gtl::ArraySlice indices, + NativeT value) { this->Set(indices, value); }); +} + +template +void MutableLiteralBase::PopulateR2FromArray2D(const Array2D& values) { + PopulateFromArray(values); +} + +template +void MutableLiteralBase::PopulateR3FromArray3D(const Array3D& values) { + PopulateFromArray(values); +} + +template +void MutableLiteralBase::PopulateR4FromArray4D(const Array4D& values) { + PopulateFromArray(values); +} + +template +void MutableLiteralBase::PopulateSparse( + SparseIndexArray indices, tensorflow::gtl::ArraySlice values, + bool sort) { + CHECK(LayoutUtil::IsSparseArray(shape())); + int rank = ShapeUtil::Rank(shape()); + CHECK_EQ(indices.rank(), rank); + int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout()); + CHECK_LE(indices.max_indices(), max_elements); + int64 num_elements = values.size(); + CHECK_LE(num_elements, max_elements); + CHECK_EQ(num_elements, indices.index_count()); + auto root_data = root_piece().data(); + // Piece::data() returns an ArraySlice of size equal to the number of indices + // in the SparseIndexArray. So there is no need to adjust the size of the data + // here. It is enough to just copy the incoming values into the data buffer. + std::copy(values.begin(), values.end(), root_data.begin()); + *this->root_piece().sparse_indices() = std::move(indices); + if (sort) { + auto root_data = this->root_piece().data(); + this->root_piece().sparse_indices()->SortWithValues(root_data); + } + DCHECK(this->root_piece().sparse_indices()->Validate(shape())); +} + +template +Status MutableLiteralBase::PopulateInternal(const FnType& generator, + bool parallel) { + const Shape& this_shape = shape(); + const int64 rank = ShapeUtil::Rank(this_shape); + TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); + TF_RET_CHECK(this_shape.element_type() == + primitive_util::NativeToPrimitiveType()); + tensorflow::gtl::MutableArraySlice literal_data = data(); + if (rank > 0) { + StrideConfig stride_config(this_shape, this_shape, + AsInt64Slice(this_shape.dimensions())); + int64 minor_dimension_size = + ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); + + auto init_function = [&](tensorflow::gtl::ArraySlice indexes) { + DimensionVector minor_scan_indexes(rank, 0); + const int64 index = + IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); + std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); + for (int64 i = 0; i < minor_dimension_size; ++i) { + minor_scan_indexes[stride_config.minor_dimension] = i; + literal_data.at(index + i) = generator(minor_scan_indexes); + } + }; + if (parallel) { + ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base, + stride_config.dimensions, + stride_config.step, init_function); + } else { + ShapeUtil::ForEachIndex( + this_shape, stride_config.base, stride_config.dimensions, + stride_config.step, + [&init_function](tensorflow::gtl::ArraySlice indexes) { + init_function(indexes); + return true; + }); + } + } else { + // For scalars. + literal_data.at(0) = generator({}); + } + return Status::OK(); +} +template +Status MutableLiteralBase::Populate(const FnType& generator) { + return PopulateInternal(generator, /*parallel=*/false); +} + +template +Status MutableLiteralBase::PopulateParallel(const FnType& generator) { + return PopulateInternal(generator, /*parallel=*/true); +} + +template +void MutableLiteralBase::PopulateWithValue(NativeT value) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); + for (NativeT& element : data()) { + element = value; + } +} + +template +std::unique_ptr LiteralBase::Replicate(int64 times) const { + DimensionVector bounds = {times}; + bounds.reserve(shape().dimensions_size() + 1); + for (int64 bound : shape().dimensions()) { + bounds.push_back(bound); + } + auto literal = + MakeUnique(ShapeUtil::MakeShape(shape().element_type(), bounds)); + int64 elements = ShapeUtil::ElementsIn(literal->shape()); + if (elements == 0) { + return literal; + } + + DimensionVector output_indices(bounds.size(), 0); + tensorflow::gtl::ArraySlice input_indices = output_indices; + input_indices.remove_prefix(1); + + bool done = false; + while (!done) { + const auto element = Get(input_indices); + literal->Set(output_indices, element); + + done = true; + for (int n = 0; n < output_indices.size(); ++n) { + ++output_indices[n]; + if (output_indices[n] < bounds[n]) { + done = false; + break; + } + output_indices[n] = 0; + } + } + return literal; +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LITERAL_H_ diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index bf9679cafec72c2e9dc5796e9058c6703239c508..94993cc87443ba8c22fd7c2eacfc8756d3f48edc 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -217,7 +218,7 @@ class NearComparator { return Printf( "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g", FpValueToString(actual).c_str(), FpValueToString(expected).c_str(), - Literal::MultiIndexAsString( + LiteralUtil::MultiIndexAsString( IndexUtil::LinearIndexToMultidimensionalIndex(shape, linear_index)) .c_str(), @@ -606,8 +607,8 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, } // namespace Status EqualShapes(const Shape& expected, const Shape& actual) { - if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { - return InvalidArgument("tupleness-mismatch! want: %s got %s", + if (expected.element_type() != actual.element_type()) { + return InvalidArgument("element type mismatch, want: %s got %s", ShapeUtil::HumanString(expected).c_str(), ShapeUtil::HumanString(actual).c_str()); } @@ -626,7 +627,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { return AppendStatus(result, StrCat("mismatch in tuple index", i)); } } - } else { + } else if (ShapeUtil::IsArray(expected)) { if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { return InvalidArgument("want rank of %s got rank of %s", ShapeUtil::HumanString(expected).c_str(), @@ -652,6 +653,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { } } } + // Non-array, non-tuple shapes are trivially equivalent. return Status::OK(); } @@ -705,6 +707,9 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { } break; } + case TOKEN: + // Tokens have no on-device representation and are trivially equal. + return Status::OK(); default: LOG(FATAL) << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " @@ -718,7 +723,7 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { return AppendStatus(result, tensorflow::strings::Printf( "\nat index: %s\nexpected: %s\nactual: %s", - Literal::MultiIndexAsString(multi_index).c_str(), + LiteralUtil::MultiIndexAsString(multi_index).c_str(), ToStringTruncated(expected).c_str(), ToStringTruncated(actual).c_str())); } diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h index 00a13e361932e74a9a1e614d5c851d3851208852..9e5bf7c1d062ef0f25d07a80d6ded8106df5dacc 100644 --- a/tensorflow/compiler/xla/literal_comparison.h +++ b/tensorflow/compiler/xla/literal_comparison.h @@ -20,7 +20,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ #include "tensorflow/compiler/xla/error_spec.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/lib/core/status.h" namespace xla { diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_test.cc similarity index 76% rename from tensorflow/compiler/xla/literal_util_test.cc rename to tensorflow/compiler/xla/literal_test.cc index 53b926163c472c3ed7b72bf8b035d13996d59e34..e8f919950f0efc8b508f7ad4aee5233176bc0abd 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" @@ -76,11 +77,11 @@ class LiteralUtilTest : public ::testing::Test { layout_r4_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2, 3}); literal_r4_2x2x3x3_dim0major_ = - Literal::CreateR4FromArray4DWithLayout(arr4d, - layout_r4_dim0major_); + LiteralUtil::CreateR4FromArray4DWithLayout(arr4d, + layout_r4_dim0major_); literal_r4_2x2x3x3_dim0minor_ = - Literal::CreateR4FromArray4DWithLayout(arr4d, - layout_r4_dim0minor_); + LiteralUtil::CreateR4FromArray4DWithLayout(arr4d, + layout_r4_dim0minor_); } Layout layout_r2_dim0major_; @@ -94,47 +95,47 @@ class LiteralUtilTest : public ::testing::Test { }; TEST_F(LiteralUtilTest, LiteralScalarToString) { - auto true_lit = Literal::CreateR0(true); + auto true_lit = LiteralUtil::CreateR0(true); ASSERT_EQ("true", true_lit->ToString()); - auto false_lit = Literal::CreateR0(false); + auto false_lit = LiteralUtil::CreateR0(false); ASSERT_EQ("false", false_lit->ToString()); - auto u32_lit = Literal::CreateR0(42); + auto u32_lit = LiteralUtil::CreateR0(42); ASSERT_EQ("42", u32_lit->ToString()); - auto s32_lit = Literal::CreateR0(-999); + auto s32_lit = LiteralUtil::CreateR0(-999); ASSERT_EQ("-999", s32_lit->ToString()); - auto f32_lit = Literal::CreateR0(3.14f); + auto f32_lit = LiteralUtil::CreateR0(3.14f); ASSERT_EQ("3.14", f32_lit->ToString()); - auto f16_lit = Literal::CreateR0(static_cast(0.5f)); + auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); ASSERT_EQ("0.5", f16_lit->ToString()); - auto c64_lit = Literal::CreateR0({3.14f, 2.78f}); + auto c64_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString()); - auto bf16_lit = Literal::CreateR0(static_cast(0.5f)); + auto bf16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); ASSERT_EQ("0.5", bf16_lit->ToString()); // 3.14 will be truncated to 3.125 in bfloat16 format. auto bf16_lit_truncated = - Literal::CreateR0(static_cast(3.14f)); + LiteralUtil::CreateR0(static_cast(3.14f)); ASSERT_EQ("3.125", bf16_lit_truncated->ToString()); auto bf16_lit_truncated2 = - Literal::CreateR0(static_cast(9.001f)); + LiteralUtil::CreateR0(static_cast(9.001f)); ASSERT_EQ("9", bf16_lit_truncated2->ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { - auto pred_vec = Literal::CreateR1({true, false, true}); + auto pred_vec = LiteralUtil::CreateR1({true, false, true}); ASSERT_EQ("{101}", pred_vec->ToString()); } TEST_F(LiteralUtilTest, R2ToString) { - const auto literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + const auto literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); const string expected = R"(s32[3,2] { { 1, 2 }, { 3, 4 }, @@ -144,7 +145,8 @@ TEST_F(LiteralUtilTest, R2ToString) { } TEST_F(LiteralUtilTest, R3ToString) { - const auto literal = Literal::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); + const auto literal = + LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); const string expected = R"(s32[3,2,1] { { { 1 }, { 2 } }, @@ -157,9 +159,9 @@ TEST_F(LiteralUtilTest, R3ToString) { } TEST_F(LiteralUtilTest, TupleToString) { - auto scalar = Literal::CreateR0(1.0); - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); + auto scalar = LiteralUtil::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); const string expected = R"((f32[], f32[2,2]) ( 1, f32[2,2] { @@ -182,7 +184,7 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { }); // clang-format on - auto literal = Literal::CreateR3FromArray3D(array_3d); + auto literal = LiteralUtil::CreateR3FromArray3D(array_3d); EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2)); string result = literal->ToString(); const string expected = R"(f32[2,3,2] { @@ -205,7 +207,7 @@ TEST_F(LiteralUtilTest, CreateSparse) { {3, 5, 6}, }; std::vector values = {7, 8, 9, 10}; - auto literal = Literal::CreateSparse( + auto literal = LiteralUtil::CreateSparse( dimensions, SparseIndexArray(indices.n1() + 3, indices), values); Array2D expected_indices = { @@ -224,7 +226,7 @@ TEST_F(LiteralUtilTest, CreateSparse) { TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { // clang-format off - auto literal = Literal::CreateR4Projected({ + auto literal = LiteralUtil::CreateR4Projected({ {1, 2}, {1001, 1002}, {2001, 2002}, @@ -284,7 +286,7 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { TEST_F(LiteralUtilTest, EachCellR2F32) { // clang-format off - auto literal = Literal::CreateR2({ + auto literal = LiteralUtil::CreateR2({ {3.1f, 4.2f}, {9.3f, 12.4f}, }); @@ -303,26 +305,27 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { TEST_F(LiteralUtilTest, ScalarEquality) { // Test equality with scalars. - auto f32_42 = Literal::CreateR0(42.0); - auto f32_42_clone = Literal::CreateR0(42.0); + auto f32_42 = LiteralUtil::CreateR0(42.0); + auto f32_42_clone = LiteralUtil::CreateR0(42.0); EXPECT_EQ(*f32_42, *f32_42); EXPECT_EQ(*f32_42, *f32_42_clone); - auto f32_123 = Literal::CreateR0(123.0); + auto f32_123 = LiteralUtil::CreateR0(123.0); EXPECT_NE(*f32_42, *f32_123); - auto f64_42 = Literal::CreateR0(42.0); + auto f64_42 = LiteralUtil::CreateR0(42.0); EXPECT_NE(*f32_42, *f64_42); } TEST_F(LiteralUtilTest, NonScalarEquality) { // Test equality with nonscalars. - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto matrix_clone = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto matrix_different = Literal::CreateR2({{4.0, 3.0}, {1.0, 2.0}}); - auto vector_literal = Literal::CreateR1({1.0, 2.0, 3.0, 4.0}); - auto scalar = Literal::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_clone = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_different = + LiteralUtil::CreateR2({{4.0, 3.0}, {1.0, 2.0}}); + auto vector_literal = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0}); + auto scalar = LiteralUtil::CreateR0(1.0); Literal nil(ShapeUtil::MakeNil()); EXPECT_EQ(*matrix, *matrix); @@ -334,6 +337,22 @@ TEST_F(LiteralUtilTest, NonScalarEquality) { EXPECT_EQ(nil, nil); } +TEST_F(LiteralUtilTest, TokenEquality) { + auto token0 = LiteralUtil::CreateToken(); + auto token1 = LiteralUtil::CreateToken(); + auto scalar = LiteralUtil::CreateR0(1.0); + + EXPECT_EQ(*token0, *token1); + EXPECT_NE(*token0, *scalar); + + EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get()}), + *LiteralUtil::MakeTuple({token0.get()})); + EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}), + *LiteralUtil::MakeTuple({token1.get(), scalar.get()})); + EXPECT_NE(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}), + *LiteralUtil::MakeTuple({scalar.get(), token1.get()})); +} + TEST_F(LiteralUtilTest, DifferentLayoutEquality) { // Test equality with literals which have different layouts. auto colmajor = @@ -355,43 +374,46 @@ TEST_F(LiteralUtilTest, DifferentLayoutEquality) { TEST_F(LiteralUtilTest, TupleEquality) { // Test equality with tuples. - auto scalar = Literal::CreateR0(1.0); - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple1 = Literal::MakeTuple({scalar.get(), matrix.get()}); + auto scalar = LiteralUtil::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); // Tuple with the same elements. One element is shared with the original // tuple, the other is a clone of the element in the original tuple. - auto scalar_clone = Literal::CreateR0(1.0); - auto tuple2 = Literal::MakeTuple({scalar_clone.get(), matrix.get()}); + auto scalar_clone = LiteralUtil::CreateR0(1.0); + auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()}); EXPECT_EQ(*tuple1, *tuple2); // Tuple with elements reversed. - auto reversed_tuple = Literal::MakeTuple({matrix.get(), scalar.get()}); + auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()}); EXPECT_NE(*tuple1, *reversed_tuple); // Tuple with different value. - auto scalar_42 = Literal::CreateR0(42.0); - auto different_tuple = Literal::MakeTuple({scalar_42.get(), matrix.get()}); + auto scalar_42 = LiteralUtil::CreateR0(42.0); + auto different_tuple = + LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()}); EXPECT_NE(*tuple1, *different_tuple); } TEST_F(LiteralUtilTest, C64Equality) { // Test equality with tuples. - auto vector = Literal::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + auto vector = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); // Tuple with the same elements. One element is shared with the original // tuple, the other is a clone of the element in the original tuple. - auto vector_clone = Literal::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + auto vector_clone = + LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); EXPECT_EQ(*vector, *vector_clone); - auto vector_reversed = Literal::CreateR1({{3.0, 4.0}, {1.0, 2.0}}); + auto vector_reversed = + LiteralUtil::CreateR1({{3.0, 4.0}, {1.0, 2.0}}); EXPECT_NE(*vector, *vector_reversed); } TEST_F(LiteralUtilTest, IsAllTuple) { - auto element1 = Literal::CreateR0(0.0); - auto element2 = Literal::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); - auto tuple = Literal::MakeTuple({element1.get(), element1.get()}); + auto element1 = LiteralUtil::CreateR0(0.0); + auto element2 = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); + auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()}); // Tuples should always return false for IsAll. EXPECT_FALSE(tuple->IsAll(0)); @@ -400,140 +422,141 @@ TEST_F(LiteralUtilTest, IsAllTuple) { // Verifies that CreateFromShape works for tuples. TEST_F(LiteralUtilTest, CreateFromShapeTuple) { - auto scalar = Literal::CreateR0(0.0); - auto matrix = Literal::CreateR2({{0, 0}, {0, 0}}); - auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); + auto scalar = LiteralUtil::CreateR0(0.0); + auto matrix = LiteralUtil::CreateR2({{0, 0}, {0, 0}}); + auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); auto x = Literal::CreateFromShape(tuple->shape()); EXPECT_EQ(*tuple, *x); } TEST_F(LiteralUtilTest, IsAll) { - EXPECT_TRUE(Literal::CreateR0(false)->IsAll(0)); - EXPECT_TRUE(Literal::CreateR0(true)->IsAll(1)); - EXPECT_FALSE(Literal::CreateR0(false)->IsAll(1)); - EXPECT_FALSE(Literal::CreateR0(false)->IsAll(2)); - EXPECT_FALSE(Literal::CreateR0(true)->IsAll(0)); - EXPECT_FALSE(Literal::CreateR0(true)->IsAll(2)); - EXPECT_FALSE(Literal::CreateR0(true)->IsAll(-1)); + EXPECT_TRUE(LiteralUtil::CreateR0(false)->IsAll(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(true)->IsAll(1)); + EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAll(1)); + EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAll(2)); + EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(2)); + EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(-1)); // We shouldn't reinterpret int8_min as an unsigned type and then decide that // it is equal to 255. auto int8_min = std::numeric_limits::min(); - EXPECT_FALSE(Literal::CreateR0(255)->IsAll(int8_min)); + EXPECT_FALSE(LiteralUtil::CreateR0(255)->IsAll(int8_min)); - EXPECT_TRUE(Literal::CreateR0(42.0)->IsAll(42)); - EXPECT_FALSE(Literal::CreateR0(42.0001)->IsAll(42)); + EXPECT_TRUE(LiteralUtil::CreateR0(42.0)->IsAll(42)); + EXPECT_FALSE(LiteralUtil::CreateR0(42.0001)->IsAll(42)); - EXPECT_TRUE(Literal::CreateR1({100, 100, 100})->IsAll(100)); - EXPECT_FALSE(Literal::CreateR1({100, 100, 100.001})->IsAll(100)); + EXPECT_TRUE(LiteralUtil::CreateR1({100, 100, 100})->IsAll(100)); + EXPECT_FALSE(LiteralUtil::CreateR1({100, 100, 100.001})->IsAll(100)); - EXPECT_TRUE(Literal::CreateR2({{8, 8}, {8, 8}})->IsAll(8)); - EXPECT_FALSE(Literal::CreateR2({{8, 8}, {8, 9}})->IsAll(8)); - EXPECT_FALSE(Literal::CreateR2({{9, 8}, {8, 8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{8, 8}, {8, 8}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{8, 8}, {8, 9}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{9, 8}, {8, 8}})->IsAll(8)); half h8(8.0f); half h9(9.0f); - EXPECT_TRUE(Literal::CreateR2({{h8}, {h8}})->IsAll(8)); - EXPECT_FALSE(Literal::CreateR2({{h8}, {h9}})->IsAll(8)); - EXPECT_FALSE(Literal::CreateR2({{h9}, {h8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{h8}, {h8}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{h8}, {h9}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{h9}, {h8}})->IsAll(8)); bfloat16 b8(8.0f); bfloat16 b9(9.0f); - EXPECT_TRUE(Literal::CreateR2({{b8}, {b8}})->IsAll(8)); - EXPECT_FALSE(Literal::CreateR2({{b8}, {b9}})->IsAll(8)); - EXPECT_FALSE(Literal::CreateR2({{b9}, {b8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{b8}, {b8}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{b8}, {b9}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{b9}, {b8}})->IsAll(8)); // 9.001 will be truncated to 9.0 bfloat16 b91(9.001f); bfloat16 b90(9.00f); - EXPECT_TRUE(Literal::CreateR2({{b91}, {b90}})->IsAll(9.0)); + EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, {b90}})->IsAll(9.0)); complex64 c8_9 = {8, 9}; - EXPECT_FALSE(Literal::CreateR2({{c8_9}, {c8_9}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}})->IsAll(8)); auto uint64_max = std::numeric_limits::max(); - EXPECT_FALSE(Literal::CreateR2( + EXPECT_FALSE(LiteralUtil::CreateR2( {{uint64_max, uint64_max}, {uint64_max, uint64_max}}) ->IsAll(-1)); } TEST_F(LiteralUtilTest, IsAllFloat) { // IsAllFloat always returns false when the literal is not floating-point. - EXPECT_FALSE(Literal::CreateR0(false)->IsAllFloat(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllFloat(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllFloat(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllFloat(0)); - - EXPECT_TRUE(Literal::CreateR0(0)->IsAllFloat(0)); - EXPECT_TRUE(Literal::CreateR0(.5)->IsAllFloat(.5)); - EXPECT_TRUE(Literal::CreateR0(-.5)->IsAllFloat(-.5)); - EXPECT_FALSE(Literal::CreateR0(-.5)->IsAllFloat(-.49)); + EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); + + EXPECT_TRUE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(.5)->IsAllFloat(.5)); + EXPECT_TRUE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.5)); + EXPECT_FALSE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.49)); EXPECT_FALSE( - Literal::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); - EXPECT_TRUE( - Literal::CreateR2({{.5, .5, .5}, {.5, .5, .5}})->IsAllFloat(.5)); - - EXPECT_TRUE(Literal::CreateR0(0)->IsAllFloat(0)); - EXPECT_TRUE(Literal::CreateR0(.5)->IsAllFloat(.5)); - EXPECT_TRUE(Literal::CreateR0(-.5)->IsAllFloat(-.5)); - EXPECT_FALSE(Literal::CreateR0(-.5)->IsAllFloat(-.49)); + LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); + EXPECT_TRUE(LiteralUtil::CreateR2({{.5, .5, .5}, {.5, .5, .5}}) + ->IsAllFloat(.5)); + + EXPECT_TRUE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(.5)->IsAllFloat(.5)); + EXPECT_TRUE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.5)); + EXPECT_FALSE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.49)); EXPECT_FALSE( - Literal::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); + LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); } TEST_F(LiteralUtilTest, IsAllComplex) { // IsAllComplex always returns false when the literal is not complex. - EXPECT_FALSE(Literal::CreateR0(false)->IsAllComplex(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); complex64 c8_9 = {8, 9}; complex64 c7_9 = {7, 9}; - EXPECT_TRUE(Literal::CreateR2({{c8_9}, {c8_9}}) + EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}) ->IsAllComplex({8.0f, 9.0f})); - EXPECT_FALSE(Literal::CreateR2({{c7_9}, {c8_9}}) + EXPECT_FALSE(LiteralUtil::CreateR2({{c7_9}, {c8_9}}) ->IsAllComplex({8.0f, 9.0f})); - EXPECT_FALSE(Literal::CreateR2({{c8_9}, {c7_9}}) + EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c7_9}}) ->IsAllComplex({8.0f, 9.0f})); } TEST_F(LiteralUtilTest, IsAllFirst) { // IsAllComplex always returns false when the literal is not complex. - EXPECT_FALSE(Literal::CreateR1({false, true})->IsAllFirst()); - EXPECT_TRUE(Literal::CreateR1({false, false})->IsAllFirst()); - EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(Literal::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(Literal::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(Literal::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({false, true})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({false, false})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); complex64 c8_9 = {8, 9}; complex64 c7_9 = {7, 9}; - EXPECT_TRUE(Literal::CreateR2({{c8_9}, {c8_9}})->IsAllFirst()); - EXPECT_FALSE(Literal::CreateR2({{c7_9}, {c8_9}})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}})->IsAllFirst()); + EXPECT_FALSE( + LiteralUtil::CreateR2({{c7_9}, {c8_9}})->IsAllFirst()); } TEST_F(LiteralUtilTest, IsZero) { - auto scalar_zero = Literal::CreateR0(0.0f); - auto scalar_one = Literal::CreateR0(1.0f); + auto scalar_zero = LiteralUtil::CreateR0(0.0f); + auto scalar_one = LiteralUtil::CreateR0(1.0f); EXPECT_TRUE(scalar_zero->IsZero({})); EXPECT_FALSE(scalar_one->IsZero({})); - auto array = Literal::CreateR2({{1, 2, 0, 3}, {1, 0, 1, 2}}); + auto array = LiteralUtil::CreateR2({{1, 2, 0, 3}, {1, 0, 1, 2}}); EXPECT_FALSE(array->IsZero({0, 1})); EXPECT_TRUE(array->IsZero({0, 2})); EXPECT_TRUE(array->IsZero({1, 1})); EXPECT_FALSE(array->IsZero({1, 2})); - auto complex_zero = Literal::CreateR0(0.0f); - auto complex_nonzero = Literal::CreateR0(0.5f); + auto complex_zero = LiteralUtil::CreateR0(0.0f); + auto complex_nonzero = LiteralUtil::CreateR0(0.5f); EXPECT_TRUE(complex_zero->IsZero({})); EXPECT_FALSE(complex_nonzero->IsZero({})); } @@ -547,7 +570,7 @@ TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes); TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { // Make a non-integer for floating point types. TypeParam half = TypeParam(1) / TypeParam(2); - auto data = Literal::CreateR2({{half, 2}, {3, 4}}); + auto data = LiteralUtil::CreateR2({{half, 2}, {3, 4}}); const Layout layout01 = LayoutUtil::MakeLayout({0, 1}); const Layout layout10 = LayoutUtil::MakeLayout({1, 0}); @@ -561,7 +584,7 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { } TEST_F(LiteralUtilTest, ReshapeR0) { - auto original = Literal::CreateR0(1.7f); + auto original = LiteralUtil::CreateR0(1.7f); auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); EXPECT_EQ(*original, *reshape); } @@ -569,13 +592,13 @@ TEST_F(LiteralUtilTest, ReshapeR0) { TEST_F(LiteralUtilTest, ReshapeR4) { // clang-format off // F32[1x3x2x4] - auto original = Literal::CreateR4WithLayout({{ + auto original = LiteralUtil::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); // F32[1x3x4x2] - auto expected = Literal::CreateR3WithLayout({ + auto expected = LiteralUtil::CreateR3WithLayout({ {{10, 11}, {12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}, {24, 25}}, {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, @@ -589,13 +612,13 @@ TEST_F(LiteralUtilTest, ReshapeR4) { TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { // clang-format off // F32[1x3x2x4] - auto original = Literal::CreateR4WithLayout({{ + auto original = LiteralUtil::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0minor_); // F32[1x3x4x2] - auto expected = Literal::CreateR3WithLayout({ + auto expected = LiteralUtil::CreateR3WithLayout({ {{10, 11}, {12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}, {24, 25}}, {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, @@ -607,7 +630,7 @@ TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { } TEST_F(LiteralUtilTest, TransposeR0) { - auto original = Literal::CreateR0(1.7f); + auto original = LiteralUtil::CreateR0(1.7f); auto reshape = original->Transpose(/*permutation=*/{}); EXPECT_EQ(*original, *reshape); } @@ -615,7 +638,7 @@ TEST_F(LiteralUtilTest, TransposeR0) { TEST_F(LiteralUtilTest, TransposeR4) { // clang-format off // F32[1x3x2x4] - auto original = Literal::CreateR4({{ + auto original = LiteralUtil::CreateR4({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, @@ -643,7 +666,7 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { TEST_F(LiteralUtilTest, TestR2LinearLayout) { // Test expected memory layout of R2 dim0-minor (column-major) literal. - auto mat_dim0minor = Literal::CreateR2WithLayout( + auto mat_dim0minor = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_); EXPECT_EQ(mat_dim0minor->element_count(), 6); EXPECT_THAT(mat_dim0minor->data(), ElementsAre(1, 4, 2, 5, 3, 6)); @@ -654,7 +677,7 @@ TEST_F(LiteralUtilTest, TestR2LinearLayout) { ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout of R2 created with dim0-major (row-major). - auto mat_dim0major = Literal::CreateR2WithLayout( + auto mat_dim0major = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_); EXPECT_EQ(mat_dim0major->element_count(), 6); EXPECT_THAT(mat_dim0major->data(), ElementsAre(1, 2, 3, 4, 5, 6)); @@ -679,8 +702,8 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { {10, 11, 12}, }, }); // clang-format on - auto lit_dim0minor = - Literal::CreateR3FromArray3DWithLayout(arr3d, layout_r3_dim0minor_); + auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout( + arr3d, layout_r3_dim0minor_); EXPECT_EQ(lit_dim0minor->element_count(), 12); std::vector expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}; @@ -694,8 +717,8 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout of R3 created with dim0-major (row-major). - auto lit_dim0major = - Literal::CreateR3FromArray3DWithLayout(arr3d, layout_r3_dim0major_); + auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout( + arr3d, layout_r3_dim0major_); EXPECT_EQ(lit_dim0major->element_count(), 12); EXPECT_THAT(lit_dim0major->data(), testing::ElementsAreArray(expected_dim0major)); @@ -707,28 +730,28 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { } TEST_F(LiteralUtilTest, SliceR0S32) { - auto input = Literal::CreateR0(1); + auto input = LiteralUtil::CreateR0(1); auto result = input->Slice({}, {}); EXPECT_EQ(*input, *result); } TEST_F(LiteralUtilTest, SliceR1F32) { - auto input = Literal::CreateR1({1.0, 2.0, 3.0, 4.0, 5.0}); + auto input = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0, 5.0}); auto result = input->Slice({3}, {4}); - auto expected = Literal::CreateR1({4.0}); + auto expected = LiteralUtil::CreateR1({4.0}); EXPECT_EQ(*expected, *result); } TEST_F(LiteralUtilTest, SliceR2U32) { - auto input_3x4 = - Literal::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); + auto input_3x4 = LiteralUtil::CreateR2( + {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); auto result = input_3x4->Slice({0, 2}, {2, 4}); - auto expected = Literal::CreateR2({{3, 4}, {7, 8}}); + auto expected = LiteralUtil::CreateR2({{3, 4}, {7, 8}}); EXPECT_EQ(*expected, *result); } TEST_F(LiteralUtilTest, SliceR3U32Full) { - auto input_2x3x2 = Literal::CreateR3( + auto input_2x3x2 = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2}); EXPECT_EQ(*input_2x3x2, *result); @@ -737,21 +760,21 @@ TEST_F(LiteralUtilTest, SliceR3U32Full) { TEST_F(LiteralUtilTest, PopulateR1S64) { Literal output(ShapeUtil::MakeShape(S64, {1})); output.PopulateR1({77}); - auto expected = Literal::CreateR1({77}); + auto expected = LiteralUtil::CreateR1({77}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateR1U64) { Literal output(ShapeUtil::MakeShape(U64, {2})); output.PopulateR1({{77, 88}}); - auto expected = Literal::CreateR1({{77, 88}}); + auto expected = LiteralUtil::CreateR1({{77, 88}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateR1C64) { Literal output(ShapeUtil::MakeShape(C64, {1})); output.PopulateR1({{77, 88}}); - auto expected = Literal::CreateR1({{77, 88}}); + auto expected = LiteralUtil::CreateR1({{77, 88}}); EXPECT_EQ(output, *expected); } @@ -759,7 +782,7 @@ TEST_F(LiteralUtilTest, PopulateR2C64) { Literal output(ShapeUtil::MakeShape(C64, {2, 2})); output.PopulateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); auto expected = - Literal::CreateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); + LiteralUtil::CreateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); EXPECT_EQ(output, *expected); } @@ -767,7 +790,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { Literal output(ShapeUtil::MakeShape(BF16, {})); bfloat16 h(0.25f); output.PopulateWithValue(h); - auto expected = Literal::CreateR0(h); + auto expected = LiteralUtil::CreateR0(h); EXPECT_EQ(output, *expected); } @@ -775,7 +798,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { Literal output(ShapeUtil::MakeShape(BF16, {3})); bfloat16 h(0.5f); output.PopulateWithValue(h); - auto expected = Literal::CreateR1({h, h, h}); + auto expected = LiteralUtil::CreateR1({h, h, h}); EXPECT_EQ(output, *expected); } @@ -783,28 +806,28 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { Literal output(ShapeUtil::MakeShape(BF16, {2, 2})); bfloat16 h(2.0f); output.PopulateWithValue(h); - auto expected = Literal::CreateR2({{h, h}, {h, h}}); + auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output(ShapeUtil::MakeShape(F32, {})); output.PopulateWithValue(2.5f); - auto expected = Literal::CreateR0(2.5f); + auto expected = LiteralUtil::CreateR0(2.5f); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1S64) { Literal output(ShapeUtil::MakeShape(S64, {3})); output.PopulateWithValue(-7); - auto expected = Literal::CreateR1({-7, -7, -7}); + auto expected = LiteralUtil::CreateR1({-7, -7, -7}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { Literal output(ShapeUtil::MakeShape(U64, {2, 2})); output.PopulateWithValue(42); - auto expected = Literal::CreateR2({{42, 42}, {42, 42}}); + auto expected = LiteralUtil::CreateR2({{42, 42}, {42, 42}}); EXPECT_EQ(output, *expected); } @@ -812,7 +835,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { Literal output(ShapeUtil::MakeShape(C64, {2, 2})); output.PopulateWithValue({4, 2}); auto expected = - Literal::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); + LiteralUtil::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); EXPECT_EQ(output, *expected); } @@ -820,7 +843,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { Literal output(ShapeUtil::MakeShape(F16, {})); half h(0.25f); output.PopulateWithValue(h); - auto expected = Literal::CreateR0(h); + auto expected = LiteralUtil::CreateR0(h); EXPECT_EQ(output, *expected); } @@ -828,7 +851,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { Literal output(ShapeUtil::MakeShape(F16, {3})); half h(0.5f); output.PopulateWithValue(h); - auto expected = Literal::CreateR1({h, h, h}); + auto expected = LiteralUtil::CreateR1({h, h, h}); EXPECT_EQ(output, *expected); } @@ -836,15 +859,15 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { Literal output(ShapeUtil::MakeShape(F16, {2, 2})); half h(2.0f); output.PopulateWithValue(h); - auto expected = Literal::CreateR2({{h, h}, {h, h}}); + auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, ReplicateR2U32) { - auto input = - Literal::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); + auto input = LiteralUtil::CreateR2( + {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); auto output = input->Replicate(3); - auto expected = Literal::CreateR3( + auto expected = LiteralUtil::CreateR3( {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}}); @@ -898,12 +921,12 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { } TEST_F(LiteralUtilTest, CopyFromScalars) { - auto zero = Literal::CreateR0(0); - auto nine = Literal::CreateR0(9); + auto zero = LiteralUtil::CreateR0(0); + auto nine = LiteralUtil::CreateR0(9); TF_EXPECT_OK(zero->CopyFrom(*nine)); EXPECT_EQ(*zero, *nine); - auto vect = Literal::CreateR1({3, 4, 9, 12, 5, 17, 21}); + auto vect = LiteralUtil::CreateR1({3, 4, 9, 12, 5, 17, 21}); TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {})); EXPECT_EQ(zero->Get({}), 17); TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {})); @@ -912,13 +935,13 @@ TEST_F(LiteralUtilTest, CopyFromScalars) { TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { const Shape empty_r1_shape = ShapeUtil::MakeShape(F32, {0}); - const auto const_nine = Literal::CreateR1({9}); + const auto const_nine = LiteralUtil::CreateR1({9}); const auto const_empty = Literal::CreateFromShape(empty_r1_shape); { // Source contains dimension with zero elements. const auto empty = Literal::CreateFromShape(empty_r1_shape); - auto nine = Literal::CreateR1({9}); + auto nine = LiteralUtil::CreateR1({9}); TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0})); EXPECT_EQ(*nine, *const_nine); @@ -927,7 +950,7 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { { // Copy 0 element to destination with zero elements. const auto empty = Literal::CreateFromShape(empty_r1_shape); - auto nine = Literal::CreateR1({9}); + auto nine = LiteralUtil::CreateR1({9}); TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0})); EXPECT_EQ(*empty, *const_empty); @@ -942,16 +965,16 @@ TEST_F(LiteralUtilTest, CopyFromNilShape) { } TEST_F(LiteralUtilTest, CopyFromArrays) { - auto scalar_42 = Literal::CreateR0(42.0); - auto scalar_123 = Literal::CreateR0(123.0); + auto scalar_42 = LiteralUtil::CreateR0(42.0); + auto scalar_123 = LiteralUtil::CreateR0(123.0); EXPECT_NE(*scalar_42, *scalar_123); TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{}, /*src_shape_index=*/{})); EXPECT_EQ(*scalar_42, *scalar_123); EXPECT_EQ(scalar_42->Get({}), 123.0f); - auto matrix_1234 = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto matrix_5678 = Literal::CreateR2({{5.0, 6.0}, {7.0, 8.0}}); + auto matrix_1234 = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_5678 = LiteralUtil::CreateR2({{5.0, 6.0}, {7.0, 8.0}}); EXPECT_NE(*matrix_1234, *matrix_5678); EXPECT_EQ(matrix_1234->Get({0, 0}), 1.0f); TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{}, @@ -961,19 +984,19 @@ TEST_F(LiteralUtilTest, CopyFromArrays) { } TEST_F(LiteralUtilTest, CopyFromTuples) { - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = Literal::MakeTuple( + auto nested_tuple = LiteralUtil::MakeTuple( {matrix.get(), - Literal::MakeTuple({Literal::CreateR0(42).get(), - Literal::CreateR1({23.0, 44.0}).get(), - &nil_literal}) + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(42).get(), + LiteralUtil::CreateR1({23.0, 44.0}).get(), &nil_literal}) .get()}); // Create a tuple the same shape as the inner tuple of nested_tuple but with // different values.. - auto tuple = Literal::MakeTuple({Literal::CreateR0(-5).get(), - Literal::CreateR1({2.0, 4.0}).get(), - &nil_literal}); + auto tuple = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(-5).get(), + LiteralUtil::CreateR1({2.0, 4.0}).get(), &nil_literal}); EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); EXPECT_EQ(nested_tuple->Get({}, {1, 0}), 42); @@ -994,8 +1017,8 @@ TEST_F(LiteralUtilTest, CopyFromTuples) { EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 4.0); } TEST_F(LiteralUtilTest, CopyBetweenSameTuple) { - auto tuple = Literal::MakeTuple( - {Literal::CreateR0(-2).get(), Literal::CreateR0(4).get()}); + auto tuple = LiteralUtil::MakeTuple({LiteralUtil::CreateR0(-2).get(), + LiteralUtil::CreateR0(4).get()}); EXPECT_EQ(tuple->Get({}, {0}), -2); EXPECT_EQ(tuple->Get({}, {1}), 4); @@ -1009,8 +1032,8 @@ TEST_F(LiteralUtilTest, CopyBetweenSameTuple) { } TEST_F(LiteralUtilTest, CopyFromDifferentShapes) { - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto vector = Literal::CreateR1({5.0, 7.0}); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto vector = LiteralUtil::CreateR1({5.0, 7.0}); Status status = matrix->CopyFrom(*vector); ASSERT_FALSE(status.ok()); ASSERT_THAT(status.error_message(), @@ -1035,7 +1058,7 @@ TEST_F(LiteralUtilTest, F16) { half h1(1.0f); half h2(2.0f); - auto m2 = Literal::CreateR2({{h1, h2}, {h2, h1}}); + auto m2 = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); Literal* l2 = m2.get(); const char* d2 = reinterpret_cast(l2->data().data()); EXPECT_EQ(d2[0], 0); @@ -1134,12 +1157,12 @@ TEST_F(LiteralUtilTest, PopulateParallel) { TEST_F(LiteralUtilTest, ConvertR4) { // clang-format off - auto original = Literal::CreateR4WithLayout({{ + auto original = LiteralUtil::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); - auto expected = Literal::CreateR4WithLayout({{ + auto expected = LiteralUtil::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, @@ -1153,42 +1176,42 @@ TEST_F(LiteralUtilTest, ConvertR4) { TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { // clang-format off - auto s8 = Literal::CreateR4WithLayout({{ + auto s8 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto s32 = Literal::CreateR4WithLayout({{ + auto s32 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto u32 = Literal::CreateR4WithLayout({{ + auto u32 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto s64 = Literal::CreateR4WithLayout({{ + auto s64 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto u64 = Literal::CreateR4WithLayout({{ + auto u64 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto pred = Literal::CreateR4WithLayout({{ + auto pred = LiteralUtil::CreateR4WithLayout({{ {{true, false, true, false}, {false, true, false, true}}, {{false, true, false, true}, {true, false, true, false}}, {{true, false, true, false}, {false, true, false, true}}, }}, layout_r4_dim0major_); - auto int32_pred = Literal::CreateR4WithLayout({{ + auto int32_pred = LiteralUtil::CreateR4WithLayout({{ {{1, 0, 1, 0}, {0, 1, 0, 1}}, {{0, 1, 0, 1}, {1, 0, 1, 0}}, {{1, 0, 1, 0}, {0, 1, 0, 1}}, }}, layout_r4_dim0major_); - auto f16 = Literal::CreateR4WithLayout({{ + auto f16 = LiteralUtil::CreateR4WithLayout({{ {{half(10.0), half(0.0), half(12.0), half(0.0)}, {half(0.0), half(15.0), half(0.0), half(17.0)}}, {{half(0.0), half(19.0), half(0.0), half(21.0)}, @@ -1196,7 +1219,7 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{half(26.0), half(0.0), half(28.0), half(0.0)}, {half(0.0), half(31.0), half(0.0), half(33.0)}}, }}, layout_r4_dim0major_); - auto bf16 = Literal::CreateR4WithLayout({{ + auto bf16 = LiteralUtil::CreateR4WithLayout({{ {{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)}, {bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}}, {{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)}, @@ -1204,17 +1227,17 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)}, {bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}}, }}, layout_r4_dim0major_); - auto f32 = Literal::CreateR4WithLayout({{ + auto f32 = LiteralUtil::CreateR4WithLayout({{ {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, }}, layout_r4_dim0major_); - auto f64 = Literal::CreateR4WithLayout({{ + auto f64 = LiteralUtil::CreateR4WithLayout({{ {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}}, {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}}, {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}}, }}, layout_r4_dim0major_); - auto c64 = Literal::CreateR4WithLayout({{ + auto c64 = LiteralUtil::CreateR4WithLayout({{ {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, @@ -1286,18 +1309,18 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { } TEST_F(LiteralUtilTest, BitcastConvert) { - auto original = - Literal::CreateR1({tensorflow::bit_cast(2.5f), - tensorflow::bit_cast(-42.25f), - tensorflow::bit_cast(100.f), 0xbeef}); - auto expected = Literal::CreateR1( + auto original = LiteralUtil::CreateR1( + {tensorflow::bit_cast(2.5f), + tensorflow::bit_cast(-42.25f), + tensorflow::bit_cast(100.f), 0xbeef}); + auto expected = LiteralUtil::CreateR1( {2.5f, -42.25f, 100.0f, tensorflow::bit_cast(0xbeef)}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr converted, original->BitcastConvert(F32)); } TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) { - auto literal = Literal::CreateR0(1234); + auto literal = LiteralUtil::CreateR0(1234); Status status = literal->BitcastConvert(F64).status(); EXPECT_NE(Status::OK(), status); EXPECT_TRUE(tensorflow::str_util::StrContains(status.error_message(), @@ -1332,7 +1355,7 @@ TEST_F(LiteralUtilTest, ToProto_f16) { half h1(1.0f); half h2(2.0f); - auto m = Literal::CreateR2({{h1, h2}, {h2, h1}}); + auto m = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); Literal* l = m.get(); EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape())); EXPECT_EQ(4, l->data().size()); @@ -1375,10 +1398,10 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { } TEST_F(LiteralUtilTest, LiteralSliceTest) { - auto scalar = Literal::CreateR0(1.0); - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); + auto scalar = LiteralUtil::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); Literal nil(ShapeUtil::MakeNil()); EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar); @@ -1397,10 +1420,10 @@ TEST_F(LiteralUtilTest, LiteralSliceTest) { } TEST_F(LiteralUtilTest, MutatingLiteralSlice) { - auto scalar = Literal::CreateR0(1.0); - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); + auto scalar = LiteralUtil::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); // Verify that changing the underlying data beneath the view changes the // data of the view itself. const auto nested_tuple_view = LiteralSlice(*nested_tuple); @@ -1420,15 +1443,16 @@ TEST_F(LiteralUtilTest, MutatingLiteralSlice) { } TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { - auto scalar = Literal::CreateR0(1.0); - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); + auto scalar = LiteralUtil::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); const auto nested_tuple_view = LiteralSlice(*nested_tuple); const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0}); const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1}); - EXPECT_EQ(matrix_view, *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); + EXPECT_EQ(matrix_view, + *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) { @@ -1472,7 +1496,7 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) { TEST_F(LiteralUtilTest, LiteralMove) { std::unique_ptr matrix = - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); Literal literal(std::move(*matrix)); EXPECT_TRUE( @@ -1485,11 +1509,11 @@ TEST_F(LiteralUtilTest, LiteralMove) { TEST_F(LiteralUtilTest, DecomposeTuple) { Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = Literal::MakeTuple( - {Literal::CreateR2({{1, 2}, {3, 4}}).get(), - Literal::MakeTuple({Literal::CreateR0(42).get(), - Literal::CreateR1({23.0, 44.0}).get(), - &nil_literal}) + auto nested_tuple = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1, 2}, {3, 4}}).get(), + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(42).get(), + LiteralUtil::CreateR1({23.0, 44.0}).get(), &nil_literal}) .get(), &nil_literal}); @@ -1526,13 +1550,13 @@ TEST_F(LiteralUtilTest, DecomposeEmptyTuple) { TEST_F(LiteralUtilTest, MoveIntoTuple) { std::vector elements; - elements.push_back(std::move(*Literal::CreateR0(1.0))); - elements.push_back(std::move(*Literal::CreateR1({4, 8}))); - elements.push_back(std::move( - *Literal::MakeTuple({Literal::CreateR0(42).get(), - Literal::CreateR1({23.0, 44.0}).get()}) + elements.push_back(std::move(*LiteralUtil::CreateR0(1.0))); + elements.push_back(std::move(*LiteralUtil::CreateR1({4, 8}))); + elements.push_back(std::move(*LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(42).get(), + LiteralUtil::CreateR1({23.0, 44.0}).get()}) - )); + )); Literal literal = Literal::MoveIntoTuple(&elements); ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); @@ -1561,7 +1585,7 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape())); std::unique_ptr matrix = - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); literal = std::move(*matrix); EXPECT_TRUE( @@ -1574,7 +1598,7 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { TEST_F(LiteralUtilTest, LiteralSliceCopy) { std::unique_ptr matrix = - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); const auto matrix_view = LiteralSlice(*matrix); LiteralSlice matrix_view_copy(matrix_view); @@ -1585,9 +1609,9 @@ TEST_F(LiteralUtilTest, LiteralSliceCopy) { } TEST_F(LiteralUtilTest, GetSetTuple) { - auto tuple = Literal::MakeTuple( - {Literal::CreateR0(42.0).get(), - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get()}); + auto tuple = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(42.0).get(), + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get()}); EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0); tuple->Set(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0); EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0); @@ -1628,20 +1652,20 @@ TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { TEST_F(LiteralUtilTest, ProtoRoundTrip) { // Test serializing then deserializing a Literal through a proto. - auto one_f32 = Literal::CreateR0(1.0); - auto two_f32 = Literal::CreateR0(2.0); - auto vector_int8 = Literal::CreateR1({-128, 0, 2, 4, 7, 56, 127}); - auto vector_c64 = Literal::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); - auto vector_bfloat16 = Literal::CreateR1( + auto one_f32 = LiteralUtil::CreateR0(1.0); + auto two_f32 = LiteralUtil::CreateR0(2.0); + auto vector_int8 = LiteralUtil::CreateR1({-128, 0, 2, 4, 7, 56, 127}); + auto vector_c64 = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + auto vector_bfloat16 = LiteralUtil::CreateR1( {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); auto vector_half = - Literal::CreateR1({half{10.0}, half{20.0}, half{-30.0}}); + LiteralUtil::CreateR1({half{10.0}, half{20.0}, half{-30.0}}); auto matrix_pred = - Literal::CreateR2({{true, false, true}, {false, false, true}}); - auto tuple = Literal::MakeTuple( + LiteralUtil::CreateR2({{true, false, true}, {false, false, true}}); + auto tuple = LiteralUtil::MakeTuple( {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()}); Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = Literal::MakeTuple( + auto nested_tuple = LiteralUtil::MakeTuple( {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal}); auto to_from_proto = [](const Literal& literal) -> Literal { @@ -1774,8 +1798,8 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { } TEST_F(LiteralUtilTest, SortSparseElements) { - auto literal = - Literal::CreateSparse({10, 10, 10}, SparseIndexArray(10, 3), {}); + auto literal = LiteralUtil::CreateSparse({10, 10, 10}, + SparseIndexArray(10, 3), {}); literal->AppendSparseElement({2, 3, 4}, 2.0); literal->AppendSparseElement({3, 4, 5}, 3.0); literal->AppendSparseElement({1, 2, 3}, 1.0); @@ -1789,21 +1813,22 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}}); ASSERT_EQ( - Literal::CreateSparse(dimensions, indices, {true, false, true}) + LiteralUtil::CreateSparse(dimensions, indices, {true, false, true}) ->GetSparseElementAsString(1), "false"); - ASSERT_EQ(Literal::CreateSparse(dimensions, indices, {1, 2, 3}) + ASSERT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {1, 2, 3}) ->GetSparseElementAsString(1), tensorflow::strings::StrCat(int64{2})); - ASSERT_EQ(Literal::CreateSparse(dimensions, indices, {1.0, 2.0, 3.0}) - ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(double{2.0})); - ASSERT_EQ(Literal::CreateSparse(dimensions, indices, - {half{1.0}, half{2.0}, half{3.0}}) + ASSERT_EQ( + LiteralUtil::CreateSparse(dimensions, indices, {1.0, 2.0, 3.0}) + ->GetSparseElementAsString(1), + tensorflow::strings::StrCat(double{2.0})); + ASSERT_EQ(LiteralUtil::CreateSparse(dimensions, indices, + {half{1.0}, half{2.0}, half{3.0}}) ->GetSparseElementAsString(1), tensorflow::strings::StrCat(static_cast(half{2.0}))); ASSERT_EQ( - Literal::CreateSparse( + LiteralUtil::CreateSparse( dimensions, indices, std::vector{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) ->GetSparseElementAsString(1), @@ -1811,33 +1836,36 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) { - std::unique_ptr literal = Literal::CreateR1({1, 2}); + std::unique_ptr literal = LiteralUtil::CreateR1({1, 2}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr broadcasted_literal, literal->Broadcast( /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), /*dimensions=*/{0})); - EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{1, 1}, {2, 2}})); + EXPECT_EQ(*broadcasted_literal, + *LiteralUtil::CreateR2({{1, 1}, {2, 2}})); } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) { - std::unique_ptr literal = Literal::CreateR1({1, 2}); + std::unique_ptr literal = LiteralUtil::CreateR1({1, 2}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr broadcasted_literal, literal->Broadcast( /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), /*dimensions=*/{1})); - EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{1, 2}, {1, 2}})); + EXPECT_EQ(*broadcasted_literal, + *LiteralUtil::CreateR2({{1, 2}, {1, 2}})); } TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) { - std::unique_ptr literal = Literal::CreateR0(9); + std::unique_ptr literal = LiteralUtil::CreateR0(9); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr broadcasted_literal, literal->Broadcast( /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}), /*dimensions=*/{})); - EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{9, 9}, {9, 9}})); + EXPECT_EQ(*broadcasted_literal, + *LiteralUtil::CreateR2({{9, 9}, {9, 9}})); } } // namespace diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 6b295897004cebce003ddd3999aacf63915ffe5f..5d33df7d40bf3bfcc8012ce1129d532b34555344 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -34,34 +34,15 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/types.h" -using tensorflow::strings::Printf; using tensorflow::strings::StrCat; namespace xla { namespace { -constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; - -// Converts between little and big endian. -// -// Precondition: size % 2 == 0 (elements in the array are 16 bits long) -void ConvertEndianShort(string* bytes) { - CHECK_EQ(bytes->size() / 2, 0); - for (int64 i = 0; i < bytes->size(); i += 2) { - std::swap((*bytes)[i], (*bytes)[i + 1]); - } -} - -void ConvertEndianShort(char* bytes, int64 size) { - CHECK_EQ(size / 2, 0); - for (int64 i = 0; i < size; i += 2) { - std::swap(bytes[i], bytes[i + 1]); - } -} - // Return a literal with all arrays of type FromNativeT converted to type // ToNativeT in the given literal. template @@ -103,498 +84,54 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } // namespace -LiteralBase::~LiteralBase() {} - -std::ostream& operator<<(std::ostream& out, const Literal& literal) { - out << literal.ToString(); - return out; -} - -Literal::StrideConfig::StrideConfig( - const Shape& source_shape, const Shape& dest_shape, - tensorflow::gtl::ArraySlice dimensions) - : dimensions(dimensions), - base(dimensions.size(), 0), - step(dimensions.size(), 1) { - if (!dimensions.empty()) { - // Selects the shape with the largest minor dimension as the one upon - // which to run the tight stride loop. - if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >= - dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) { - minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0); - dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension); - } else { - minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0); - source_stride = - IndexUtil::GetDimensionStride(source_shape, minor_dimension); - } - minor_loop_size = dimensions[minor_dimension]; - step[minor_dimension] = minor_loop_size; - } -} - -Literal::Literal(const Shape& shape) - : Literal(shape, /*allocate_arrays=*/true) {} - -void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { - if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - const Shape& subshape = shape.tuple_shapes(i); - - auto child_piece = Piece(); - child_piece.set_subshape(&subshape); - - SetPiece(subshape, &child_piece, allocate_arrays); - - piece->emplace_back(std::move(child_piece)); - } - } else { - CHECK(ShapeUtil::IsArray(shape)); - if (allocate_arrays) { - if (LayoutUtil::IsSparseArray(shape)) { - // For sparse arrays, the buffer must be of the size of the maximum - // number of sparse elements possible. - const int64 max_sparse_elements = - LayoutUtil::MaxSparseElements(shape.layout()); - piece->set_buffer( - new char[max_sparse_elements * - ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); - piece->set_sparse_indices( - new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape))); - } else { - piece->set_buffer(new char[piece->size_bytes()]); - } - } - } -} - -Literal::Literal(const Shape& shape, bool allocate_arrays) - : LiteralBase(), shape_(MakeUnique(shape)) { - CHECK(LayoutUtil::HasLayout(*shape_)); - root_piece_ = new Piece(); - root_piece_->set_subshape(shape_.get()); - CHECK(&root_piece_->subshape() == shape_.get()); - - SetPiece(*shape_, root_piece_, allocate_arrays); -} - -Literal::~Literal() { - if (root_piece_ != nullptr) { - DeallocateBuffers(); - delete root_piece_; - } -} - -void Literal::DeallocateBuffers() { - root_piece_->ForEachMutableSubpiece( - [&](const ShapeIndex& index, Piece* piece) { - if (piece->buffer() != nullptr) { - delete[] piece->buffer(); - delete piece->sparse_indices(); - } - }); -} - -Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); } - -Literal& Literal::operator=(Literal&& other) { - DCHECK(&other.root_piece_->subshape() == other.shape_.get()); - using std::swap; - swap(shape_, other.shape_); - swap(root_piece_, other.root_piece_); - DCHECK(&root_piece_->subshape() == shape_.get()); - - return *this; -} - -std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { - auto literal = MakeUnique(shape); - literal->root_piece_->ForEachMutableSubpiece( - [&](const ShapeIndex& index, Piece* piece) { - if (ShapeUtil::IsArray(piece->subshape())) { - memset(piece->untyped_data(), 0, piece->size_bytes()); - } - }); - return literal; -} - -const SparseIndexArray* LiteralBase::sparse_indices( - const ShapeIndex& shape_index) const { - return piece(shape_index).sparse_indices(); -} - -SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { - return piece(shape_index).sparse_indices(); -} - -/* static */ std::unique_ptr Literal::CreateFromDimensions( +/* static */ std::unique_ptr LiteralUtil::CreateFromDimensions( PrimitiveType primitive_type, tensorflow::gtl::ArraySlice dimensions) { - return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions)); + return Literal::CreateFromShape( + ShapeUtil::MakeShape(primitive_type, dimensions)); } -/* static */ std::unique_ptr Literal::ConvertBF16ToF32( +/* static */ std::unique_ptr LiteralUtil::ConvertBF16ToF32( const LiteralSlice& bf16_literal) { return ConvertType(bf16_literal); } -/* static */ std::unique_ptr Literal::ConvertF32ToBF16( +/* static */ std::unique_ptr LiteralUtil::ConvertF32ToBF16( const LiteralSlice& f32_literal) { return ConvertType(f32_literal); } -template -Status Literal::CopySliceFromInternal( - const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { - TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); - TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size()); - - auto linear_index = [](const Shape& shape, - tensorflow::gtl::ArraySlice multi_index) { - return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index); - }; - - if (ShapeUtil::Rank(src_literal.shape()) == 0 || - ShapeUtil::Rank(shape()) == 0) { - // If any of the two shapes are scalars, we can just call the StridedCopy() - // directly, and we know we will be copying only one value. - TF_RET_CHECK(copy_size.empty()); - StridedCopy(data(), linear_index(shape(), dest_base), 0, - src_literal.data(), - linear_index(src_literal.shape(), src_base), 0, 1); - } else if (!ShapeUtil::HasZeroElements(shape()) && - !ShapeUtil::HasZeroElements(src_literal.shape())) { - // Perform copy if neither src nor dest has dimensions with zero element, - // otherwise it's a no-op. - TF_RET_CHECK(src_base.size() == dest_base.size()); - TF_RET_CHECK(src_base.size() == copy_size.size()); - - // Scan the source from minor, stepping in copy size blocks, then within - // the index enumaration functor, do a strided copy advancing source index - // by one (walking through the minor dimension), and destination index by - // proper stride size at the matching dimension. - DimensionVector src_indexes(src_base.size(), 0); - DimensionVector dest_indexes(dest_base.size(), 0); - Literal::StrideConfig stride_config(src_literal.shape(), shape(), - copy_size); - - auto copy_proc = [&](tensorflow::gtl::ArraySlice indexes) { - // Map from multi-dimensional index, to source index. - std::transform(indexes.begin(), indexes.end(), src_base.begin(), - src_indexes.begin(), std::plus()); - // Map from multi-dimensional index, to destination index. - std::transform(indexes.begin(), indexes.end(), dest_base.begin(), - dest_indexes.begin(), std::plus()); - - int64 src_index = linear_index(src_literal.shape(), src_indexes); - int64 dest_index = linear_index(shape(), dest_indexes); - - // `this->` is needed to workaround MSVC bug: #16882 - StridedCopy(this->data(), dest_index, stride_config.dest_stride, - src_literal.data(), src_index, - stride_config.source_stride, stride_config.minor_loop_size); - return true; - }; - - ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base, - stride_config.dimensions, stride_config.step, - copy_proc); - } - return Status::OK(); -} - -Status Literal::CopyElementFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index) { - DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); - const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex( - src_literal.shape(), src_index); - const int64 dest_linear_index = - IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index); - const int64 primitive_size = - ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); - - char* dest_address = - static_cast(untyped_data()) + dest_linear_index * primitive_size; - const char* source_address = - static_cast(src_literal.untyped_data()) + - src_linear_index * primitive_size; - if (dest_address != source_address) { - memcpy(dest_address, source_address, primitive_size); - } - return Status::OK(); -} - -std::vector Literal::DecomposeTuple() { - CHECK(ShapeUtil::IsTuple(shape())); - std::vector elements; - for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { - elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), - /*allocate_arrays=*/false)); - Literal& element = elements.back(); - element.root_piece_->ForEachMutableSubpiece( - [&](const ShapeIndex& index, Piece* dest_piece) { - ShapeIndex src_index = {i}; - for (int64 j : index) { - src_index.push_back(j); - } - Piece& src_piece = piece(src_index); - - // Move the respective buffer and sparse indices over to the element - // Literal. - dest_piece->set_buffer(src_piece.buffer()); - src_piece.set_buffer(nullptr); - dest_piece->set_sparse_indices(src_piece.sparse_indices()); - src_piece.set_sparse_indices(nullptr); - }); - } - // Set this literal to be nil-shaped. - *this = Literal(); - return elements; -} - -/* static */ Literal Literal::MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements) { - std::vector element_shapes; - for (const Literal& element : elements) { - element_shapes.push_back(element.shape()); - } - Literal literal(ShapeUtil::MakeTupleShape(element_shapes), - /*allocate_arrays=*/false); - for (int i = 0; i < elements.size(); ++i) { - TF_CHECK_OK( - literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i})); - } - return literal; -} - -namespace { - -// Copies the elements in 'src' to 'dest'. The shape and layout of the data in -// the array slices are indicated by dest_shape and src_shape respectively. -template -void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, - tensorflow::gtl::ArraySlice src, - const Shape& dest_shape, const Shape& src_shape) { - CHECK(ShapeUtil::Compatible(dest_shape, src_shape)); - if (ShapeUtil::HasZeroElements(dest_shape)) { - return; - } - std::vector index(ShapeUtil::Rank(dest_shape)); - do { - dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] = - src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)]; - } while (IndexUtil::BumpIndices(dest_shape, &index)); -} - -} // namespace - -Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { - CHECK(subshape_ != nullptr); - CHECK(src.subshape_ != nullptr); - if (ShapeUtil::Equal(subshape(), src.subshape())) { - // If the layouts are equal it's faster just to memcpy. - memcpy(buffer(), src.buffer(), src.size_bytes()); - } else { - TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape())); - std::vector origin(ShapeUtil::Rank(subshape()), 0); - switch (subshape().element_type()) { -#define COPY_ELEMENTS(XLA_T, NATIVE_T) \ - case (XLA_T): \ - CopyElementsBetween(data(), src.data(), \ - subshape(), src.subshape()); \ - break; - COPY_ELEMENTS(U8, uint8); - COPY_ELEMENTS(U16, uint16); - COPY_ELEMENTS(U32, uint32); - COPY_ELEMENTS(U64, uint64); - COPY_ELEMENTS(S8, int8); - COPY_ELEMENTS(S16, int16); - COPY_ELEMENTS(S32, int32); - COPY_ELEMENTS(S64, int64); - COPY_ELEMENTS(F16, half); - COPY_ELEMENTS(BF16, bfloat16); - COPY_ELEMENTS(F32, float); - COPY_ELEMENTS(F64, double); - COPY_ELEMENTS(C64, complex64); - COPY_ELEMENTS(PRED, bool); -#undef COPY_ELEMENTS - default: - return Unimplemented( - "Copying a Literal object with element type %s is not implemented.", - PrimitiveType_Name(subshape().element_type()).c_str()); - } - } - return Status::OK(); -} - -Status Literal::CopyFrom(const LiteralSlice& src_literal, - const ShapeIndex& dest_shape_index, - const ShapeIndex& src_shape_index) { - const Shape& dest_subshape = - ShapeUtil::GetSubshape(shape(), dest_shape_index); - const Shape& src_subshape = - ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index); - if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) { - return InvalidArgument( - "Destination subshape incompatible with source subshape: %s vs %s", - ShapeUtil::HumanString(dest_subshape).c_str(), - ShapeUtil::HumanString(src_subshape).c_str()); - } - return root_piece_->ForEachMutableSubpieceWithStatus( - [&](const ShapeIndex& index, Piece* piece) { - if (!ShapeUtil::IsArray(piece->subshape())) { - return Status::OK(); - } - - // Determine if this index is in the part of this literal that we want - // to copy over from src_literal. - bool in_subtree_to_copy = true; - for (int i = 0; i < dest_shape_index.size(); ++i) { - if (index[i] != dest_shape_index[i]) { - in_subtree_to_copy = false; - break; - } - } - if (!in_subtree_to_copy) { - return Status::OK(); - } - // Construct the index of the corresponding piece in the source literal. - ShapeIndex src_piece_index = src_shape_index; - for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { - src_piece_index.push_back(index[i]); - } - TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index))); - return Status::OK(); - }); -} - -Status Literal::MoveFrom(Literal&& src_literal, - const ShapeIndex& dest_shape_index) { - const Shape& dest_subshape = - ShapeUtil::GetSubshape(shape(), dest_shape_index); - if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) { - return InvalidArgument( - "Destination subshape not equal to source shape: %s vs %s", - ShapeUtil::HumanString(dest_subshape).c_str(), - ShapeUtil::HumanString(src_literal.shape()).c_str()); - } - - src_literal.root_piece_->ForEachSubpiece( - [&](const ShapeIndex& src_index, const Piece& src_piece) { - if (!ShapeUtil::IsArray(src_piece.subshape())) { - return; - } - - ShapeIndex dest_index = dest_shape_index; - for (int64 i : src_index) { - dest_index.push_back(i); - } - Piece& dest_piece = piece(dest_index); - delete[] dest_piece.buffer(); - dest_piece.set_buffer(src_piece.buffer()); - delete dest_piece.sparse_indices(); - dest_piece.set_sparse_indices(src_piece.sparse_indices()); - }); - - src_literal.shape_ = MakeUnique(ShapeUtil::MakeNil()); - delete src_literal.root_piece_; - src_literal.root_piece_ = new LiteralBase::Piece(); - src_literal.root_piece_->set_subshape(src_literal.shape_.get()); - - return Status::OK(); -} - -Status Literal::CopySliceFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { - TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); - TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) - << ShapeUtil::HumanString(src_literal.shape()); - TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape())); - - switch (shape().element_type()) { - case U8: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case U16: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case U32: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case U64: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case S8: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case S16: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case S32: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case S64: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case F16: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case BF16: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case F32: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case F64: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case C64: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case PRED: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - default: - break; - } - return Unimplemented( - "Copying a slice from a Literal object with element type %d is not " - "implemented.", - shape().element_type()); +/* static */ std::unique_ptr LiteralUtil::CreateToken() { + return MakeUnique(ShapeUtil::MakeTokenShape()); } -/* static */ Literal Literal::Zero(PrimitiveType primitive_type) { +/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case U32: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case U64: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case S8: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case S32: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case S64: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case F16: - return std::move(*Literal::CreateR0(static_cast(0.0f))); + return std::move(*LiteralUtil::CreateR0(static_cast(0.0f))); case BF16: return std::move( - *Literal::CreateR0(static_cast(0.0f))); + *LiteralUtil::CreateR0(static_cast(0.0f))); case F32: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case F64: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case C64: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case PRED: - return std::move(*Literal::CreateR0(false)); + return std::move(*LiteralUtil::CreateR0(false)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -607,33 +144,33 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, } } -/* static */ Literal Literal::One(PrimitiveType primitive_type) { +/* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case U32: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case U64: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case S8: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case S32: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case S64: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case F16: - return std::move(*Literal::CreateR0(static_cast(1.0f))); + return std::move(*LiteralUtil::CreateR0(static_cast(1.0f))); case BF16: return std::move( - *Literal::CreateR0(static_cast(1.0f))); + *LiteralUtil::CreateR0(static_cast(1.0f))); case F32: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case F64: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case C64: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case PRED: - return std::move(*Literal::CreateR0(true)); + return std::move(*LiteralUtil::CreateR0(true)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -646,44 +183,44 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, } } -/* static */ Literal Literal::MinValue(PrimitiveType primitive_type) { +/* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: return std::move( - *Literal::CreateR0(std::numeric_limits::min())); + *LiteralUtil::CreateR0(std::numeric_limits::min())); case U32: return std::move( - *Literal::CreateR0(std::numeric_limits::min())); + *LiteralUtil::CreateR0(std::numeric_limits::min())); case U64: return std::move( - *Literal::CreateR0(std::numeric_limits::min())); + *LiteralUtil::CreateR0(std::numeric_limits::min())); case S8: return std::move( - *Literal::CreateR0(std::numeric_limits::min())); + *LiteralUtil::CreateR0(std::numeric_limits::min())); case S32: return std::move( - *Literal::CreateR0(std::numeric_limits::min())); + *LiteralUtil::CreateR0(std::numeric_limits::min())); case S64: return std::move( - *Literal::CreateR0(std::numeric_limits::min())); + *LiteralUtil::CreateR0(std::numeric_limits::min())); case F32: - return std::move( - *Literal::CreateR0(-std::numeric_limits::infinity())); + return std::move(*LiteralUtil::CreateR0( + -std::numeric_limits::infinity())); case F64: - return std::move( - *Literal::CreateR0(-std::numeric_limits::infinity())); + return std::move(*LiteralUtil::CreateR0( + -std::numeric_limits::infinity())); case C64: LOG(FATAL) << "C64 element type has no minimum value"; case PRED: - return std::move(*Literal::CreateR0(false)); + return std::move(*LiteralUtil::CreateR0(false)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return std::move(*Literal::CreateR0( + return std::move(*LiteralUtil::CreateR0( static_cast(-std::numeric_limits::infinity()))); case BF16: - return std::move(*Literal::CreateR0( + return std::move(*LiteralUtil::CreateR0( static_cast(-std::numeric_limits::infinity()))); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; @@ -694,42 +231,42 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, } } -/* static */ Literal Literal::MaxValue(PrimitiveType primitive_type) { +/* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: return std::move( - *Literal::CreateR0(std::numeric_limits::max())); + *LiteralUtil::CreateR0(std::numeric_limits::max())); case U32: return std::move( - *Literal::CreateR0(std::numeric_limits::max())); + *LiteralUtil::CreateR0(std::numeric_limits::max())); case U64: return std::move( - *Literal::CreateR0(std::numeric_limits::max())); + *LiteralUtil::CreateR0(std::numeric_limits::max())); case S8: return std::move( - *Literal::CreateR0(std::numeric_limits::max())); + *LiteralUtil::CreateR0(std::numeric_limits::max())); case S32: return std::move( - *Literal::CreateR0(std::numeric_limits::max())); + *LiteralUtil::CreateR0(std::numeric_limits::max())); case S64: return std::move( - *Literal::CreateR0(std::numeric_limits::max())); + *LiteralUtil::CreateR0(std::numeric_limits::max())); case F32: - return std::move( - *Literal::CreateR0(std::numeric_limits::infinity())); + return std::move(*LiteralUtil::CreateR0( + std::numeric_limits::infinity())); case F64: - return std::move( - *Literal::CreateR0(std::numeric_limits::infinity())); + return std::move(*LiteralUtil::CreateR0( + std::numeric_limits::infinity())); case PRED: - return std::move(*Literal::CreateR0(true)); + return std::move(*LiteralUtil::CreateR0(true)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return std::move(*Literal::CreateR0( + return std::move(*LiteralUtil::CreateR0( static_cast(std::numeric_limits::infinity()))); case BF16: - return std::move(*Literal::CreateR0( + return std::move(*LiteralUtil::CreateR0( static_cast(std::numeric_limits::infinity()))); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; @@ -740,7 +277,7 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, } } -/* static */ std::unique_ptr Literal::CreateR1( +/* static */ std::unique_ptr LiteralUtil::CreateR1( const tensorflow::core::Bitmap& values) { auto literal = MakeUnique( ShapeUtil::MakeShape(PRED, {static_cast(values.bits())})); @@ -748,17 +285,7 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, return literal; } -void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); - CHECK_EQ(element_count(), values.bits()); - CHECK_EQ(shape().element_type(), PRED); - for (int64 i = 0; i < static_cast(values.bits()); ++i) { - Set({i}, values.get(i)); - } -} - -/* static */ std::unique_ptr Literal::CreateR1U8( +/* static */ std::unique_ptr LiteralUtil::CreateR1U8( tensorflow::StringPiece value) { auto literal = MakeUnique( ShapeUtil::MakeShape(U8, {static_cast(value.size())})); @@ -768,116 +295,13 @@ void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { return literal; } -/* static */ std::unique_ptr Literal::CreateR2F32Linspace(float from, - float to, - int64 rows, - int64 cols) { +/* static */ std::unique_ptr LiteralUtil::CreateR2F32Linspace( + float from, float to, int64 rows, int64 cols) { auto value = MakeLinspaceArray2D(from, to, rows, cols); return CreateR2FromArray2D(*value); } -std::unique_ptr LiteralBase::Relayout( - const Layout& new_layout, const ShapeIndex& shape_index) const { - // Create new shape with 'new_layout' set at the given shape index. - Shape new_shape = shape(); - Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); - TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); - *subshape->mutable_layout() = new_layout; - auto result = MakeUnique(new_shape); - TF_CHECK_OK(result->CopyFrom(*this)); - return result; -} - -std::unique_ptr LiteralBase::Relayout( - const Shape& shape_with_layout) const { - CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) - << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) - << " not compatible with literal shape " - << ShapeUtil::HumanString(shape()); - std::unique_ptr result = CreateFromShape(shape_with_layout); - ShapeUtil::ForEachSubshape( - result->shape(), - [this, &result](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsArray(subshape)) { - TF_CHECK_OK(result->CopyFrom(*this, - /*dest_shape_index=*/index, - /*src_shape_index=*/index)); - } - }); - return result; -} - -StatusOr> LiteralBase::Broadcast( - const Shape& result_shape, - tensorflow::gtl::ArraySlice dimensions) const { - if (!ShapeUtil::IsArray(shape())) { - return InvalidArgument("Broadcast only supports arrays."); - } - - for (int64 i = 0; i < dimensions.size(); i++) { - TF_RET_CHECK(shape().dimensions(i) == - result_shape.dimensions(dimensions[i])); - } - - std::unique_ptr result = MakeUnique(result_shape); - - // scratch_source_index is temporary storage space for the computed index into - // the input literal. We put it here to avoid allocating an std::vector in - // every iteration of ShapeUtil::ForEachIndex. - std::vector scratch_source_index(shape().dimensions_size()); - - char* dest_data = static_cast(result->untyped_data()); - const char* source_data = static_cast(untyped_data()); - const int64 primitive_size = - ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); - - ShapeUtil::ForEachIndex( - result_shape, [&](tensorflow::gtl::ArraySlice output_index) { - for (int64 i = 0; i < dimensions.size(); ++i) { - scratch_source_index[i] = output_index[dimensions[i]]; - } - int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex( - result_shape, output_index); - int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex( - shape(), scratch_source_index); - memcpy(dest_data + primitive_size * dest_index, - source_data + primitive_size * source_index, primitive_size); - return true; - }); - - return std::move(result); -} - -StatusOr> LiteralBase::Reshape( - tensorflow::gtl::ArraySlice dimensions) const { - if (!ShapeUtil::IsArray(shape())) { - return InvalidArgument("Reshape does not support tuples."); - } - std::unique_ptr output; - if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { - output = - Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); - } else { - output = CloneToUnique(); - } - // Because the layout is monotonic, we can simply reuse the same sequence of - // values without changing their order. - *output->mutable_shape_do_not_use() = - ShapeUtil::MakeShape(shape().element_type(), dimensions); - - int64 elements_before = ShapeUtil::ElementsIn(shape()); - int64 elements_after = ShapeUtil::ElementsIn(output->shape()); - if (elements_before != elements_after) { - return InvalidArgument( - "Shapes before and after Literal::Reshape have different numbers " - "of elements: %s vs %s.", - ShapeUtil::HumanString(shape()).c_str(), - ShapeUtil::HumanString(output->shape()).c_str()); - } - return std::move(output); -} - -/* static */ std::unique_ptr Literal::ReshapeSlice( +/* static */ std::unique_ptr LiteralUtil::ReshapeSlice( tensorflow::gtl::ArraySlice new_dimensions, tensorflow::gtl::ArraySlice minor_to_major, const LiteralSlice& literal) { @@ -949,583 +373,77 @@ StatusOr> LiteralBase::Reshape( return new_literal; } -std::unique_ptr LiteralBase::Transpose( - tensorflow::gtl::ArraySlice permutation) const { - CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; - CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) - << "Given permutation is not a permutation of dimension numbers"; - // To transpose the array, we just permute the dimensions and layout, and - // do a straight memory copy of the raw data set. - // This is considerably faster than iterating over every array element using - // the EachCell<>() and Set<>() APIs. - std::vector inverse_permutation = InversePermutation(permutation); - Shape permuted_shape = - ShapeUtil::PermuteDimensions(inverse_permutation, shape()); - // Replace the layout with one affine to this shape, such that a - // transpose operation can be performed by leaving the flat values - // representation intact. - // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation. - // The shape with affine layout resulting from that operation will be - // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the - // most minor. - // - // Essentially, given MinMaj(Di) the position of the Di dimension within the - // minor to major vector, and given T(Di) the index that the original Di - // dimension has within the transposed array, a layout is affine if - // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major - // vector of the affine layout. - CHECK(LayoutUtil::IsDenseArray(permuted_shape)); - Layout* layout = permuted_shape.mutable_layout(); - layout->clear_minor_to_major(); - for (auto index : LayoutUtil::MinorToMajor(shape())) { - layout->add_minor_to_major(inverse_permutation[index]); - } - auto new_literal = MakeUnique(permuted_shape); - DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), - ShapeUtil::ByteSizeOf(shape())); - std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); - return new_literal; -} - -template -std::unique_ptr LiteralBase::SliceInternal( - const Shape& result_shape, - tensorflow::gtl::ArraySlice start_indices) const { - auto result_literal = MakeUnique(result_shape); - DimensionVector new_indices(ShapeUtil::Rank(result_shape)); - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, NativeT /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - NativeT value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; -} - -std::unique_ptr LiteralBase::Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const { - CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; - - DimensionVector result_dimensions; - for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { - CHECK_GE(start_indices[dnum], 0); - CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)) - << "dnum = " << dnum; - int64 dimension = limit_indices[dnum] - start_indices[dnum]; - CHECK_GE(dimension, 0) << "dnum = " << dnum; - result_dimensions.push_back(dimension); - } - const auto result_shape = - ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, - LayoutUtil::MinorToMajor(shape())); - switch (result_shape.element_type()) { - case F32: - return SliceInternal(result_shape, start_indices); - case BF16: - return SliceInternal(result_shape, start_indices); - case C64: - return SliceInternal(result_shape, start_indices); - case S32: - return SliceInternal(result_shape, start_indices); - case U32: - return SliceInternal(result_shape, start_indices); - default: - LOG(FATAL) << "not yet implemented: " - << PrimitiveType_Name(result_shape.element_type()); - } -} - -Literal LiteralBase::Clone() const { - Literal result(shape()); - TF_CHECK_OK(result.CopyFrom(*this)); - return result; -} - -std::unique_ptr LiteralBase::CloneToUnique() const { - auto result = MakeUnique(shape()); - TF_CHECK_OK(result->CopyFrom(*this)); - return result; -} - -string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const { - const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); - CHECK(LayoutUtil::IsDenseArray(subshape)); - switch (subshape.element_type()) { +/* static */ Literal LiteralUtil::GetFirstScalarLiteral( + const LiteralSlice& literal) { + CHECK(ShapeUtil::IsArray(literal.shape())); + CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0); + switch (literal.shape().element_type()) { case PRED: - return Get(multi_index, shape_index) ? "true" : "false"; + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); + // 8 bit types. case S8: - return StrCat(Get(multi_index, shape_index)); - case S16: - return StrCat(Get(multi_index, shape_index)); - case S32: - return StrCat(Get(multi_index, shape_index)); - case S64: - return StrCat(Get(multi_index, shape_index)); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); case U8: - return StrCat(Get(multi_index, shape_index)); - case U16: - return StrCat(Get(multi_index, shape_index)); - case U32: - return StrCat(Get(multi_index, shape_index)); - case U64: - return StrCat(Get(multi_index, shape_index)); - case F16: - return StrCat(static_cast(Get(multi_index, shape_index))); - case F32: - return StrCat(Get(multi_index, shape_index)); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); + // 16 bit types. case BF16: - return StrCat( - static_cast(Get(multi_index, shape_index))); - case F64: - return StrCat(Get(multi_index, shape_index)); - case C64: { - complex64 c = Get(multi_index, shape_index); - return StrCat("(", c.real(), ", ", c.imag(), ")"); - } - default: - LOG(FATAL) << PrimitiveType_Name(subshape.element_type()); - } -} - -string LiteralBase::GetSparseElementAsString( - int64 sparse_element_number, const ShapeIndex& shape_index) const { - const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); - CHECK(LayoutUtil::IsSparseArray(subshape)); - switch (subshape.element_type()) { - case PRED: - return GetSparseElement(sparse_element_number, shape_index) - ? "true" - : "false"; - case S8: - return StrCat(GetSparseElement(sparse_element_number, shape_index)); + return std::move(*LiteralUtil::CreateR0( + literal.GetFirstElement())); + case F16: + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); case S16: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case S32: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case S64: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case U8: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); case U16: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case U32: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case U64: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case F16: - return StrCat(static_cast( - GetSparseElement(sparse_element_number, shape_index))); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); + // 32 bit types. case F32: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case BF16: - return StrCat(static_cast( - GetSparseElement(sparse_element_number, shape_index))); - case F64: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case C64: { - complex64 c = - GetSparseElement(sparse_element_number, shape_index); - return StrCat("(", c.real(), ", ", c.imag(), ")"); - } - default: - LOG(FATAL) << "Invalid element type for sparse arrays: " - << PrimitiveType_Name(subshape.element_type()); - } -} - -StatusOr LiteralBase::GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const { - CHECK(LayoutUtil::IsDenseArray(shape())); - switch (shape().element_type()) { - case PRED: - return Get(multi_index); - case U8: - return Get(multi_index); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); case S32: - return Get(multi_index); - case S64: - return Get(multi_index); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); case U32: - return Get(multi_index); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); + // 64 bit types. + case C64: + return std::move(*LiteralUtil::CreateR0( + literal.GetFirstElement())); + case F64: + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); + case S64: + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); case U64: - return Get(multi_index); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); default: - return FailedPrecondition( - "Array element type is not integral: %s", - PrimitiveType_Name(shape().element_type()).c_str()); + LOG(FATAL) << "Unhandled primitive type " + << literal.shape().element_type(); } } -size_t LiteralBase::Hash() const { - using tensorflow::Hash64; - using tensorflow::Hash64Combine; - - size_t hash_value = ShapeUtil::Hash(shape()); - - ShapeUtil::ForEachSubshape( - shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsTuple(subshape)) { - return; - } - - CHECK(LayoutUtil::IsDense(subshape.layout())); - hash_value = Hash64Combine( - hash_value, Hash64(static_cast(untyped_data(index)), - size_bytes(index))); - }); - - return hash_value; +/* static */ std::unique_ptr LiteralUtil::MakeTuple( + tensorflow::gtl::ArraySlice elements) { + std::vector element_shapes; + for (const auto* element : elements) { + element_shapes.push_back(element->shape()); + } + auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + for (int i = 0; i < elements.size(); ++i) { + TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i})); + } + return literal; } -Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, - int64 value) { - CHECK(LayoutUtil::IsDenseArray(shape())); - switch (shape().element_type()) { - case PRED: - Set(multi_index, value); - break; - case U8: - Set(multi_index, value); - break; - case S32: - Set(multi_index, value); - break; - case S64: - Set(multi_index, value); - break; - case U32: - Set(multi_index, value); - break; - case U64: - Set(multi_index, value); - break; - default: - return FailedPrecondition( - "Array element type is not integral: %s", - PrimitiveType_Name(shape().element_type()).c_str()); - } - return Status::OK(); -} - -tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( - int64 sparse_element_number, const ShapeIndex& shape_index) const { - const Piece& p = piece(shape_index); - CHECK_GE(sparse_element_number, 0); - CHECK_LT(sparse_element_number, p.sparse_indices()->index_count()); - return p.sparse_indices()->At(sparse_element_number); -} - -void Literal::SortSparseElements(const ShapeIndex& shape_index) { - piece(shape_index).SortSparseElements(); -} - -Literal LiteralBase::GetFirstScalarLiteral() const { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_GT(ShapeUtil::ElementsIn(shape()), 0); - switch (shape().element_type()) { - case PRED: - return std::move(*Literal::CreateR0(GetFirstElement())); - // 8 bit types. - case S8: - return std::move(*Literal::CreateR0(GetFirstElement())); - case U8: - return std::move(*Literal::CreateR0(GetFirstElement())); - // 16 bit types. - case BF16: - return std::move( - *Literal::CreateR0(GetFirstElement())); - case F16: - return std::move(*Literal::CreateR0(GetFirstElement())); - case S16: - return std::move(*Literal::CreateR0(GetFirstElement())); - case U16: - return std::move(*Literal::CreateR0(GetFirstElement())); - // 32 bit types. - case F32: - return std::move(*Literal::CreateR0(GetFirstElement())); - case S32: - return std::move(*Literal::CreateR0(GetFirstElement())); - case U32: - return std::move(*Literal::CreateR0(GetFirstElement())); - // 64 bit types. - case C64: - return std::move( - *Literal::CreateR0(GetFirstElement())); - case F64: - return std::move(*Literal::CreateR0(GetFirstElement())); - case S64: - return std::move(*Literal::CreateR0(GetFirstElement())); - case U64: - return std::move(*Literal::CreateR0(GetFirstElement())); - default: - LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); - } -} - -void LiteralBase::Piece::SortSparseElements() { - switch (subshape().element_type()) { - case PRED: - SortSparseElementsInternal(); - break; - case S8: - SortSparseElementsInternal(); - break; - case U8: - SortSparseElementsInternal(); - break; - case S16: - SortSparseElementsInternal(); - break; - case U16: - SortSparseElementsInternal(); - break; - case S32: - SortSparseElementsInternal(); - break; - case U32: - SortSparseElementsInternal(); - break; - case S64: - SortSparseElementsInternal(); - break; - case U64: - SortSparseElementsInternal(); - break; - case F32: - SortSparseElementsInternal(); - break; - case F64: - SortSparseElementsInternal(); - break; - case C64: - SortSparseElementsInternal(); - break; - case F16: - SortSparseElementsInternal(); - break; - case BF16: - SortSparseElementsInternal(); - break; - default: - LOG(FATAL) << "Element type not valid for sparse array: " - << PrimitiveType_Name(subshape().element_type()); - } -} - -template -void LiteralBase::Piece::SortSparseElementsInternal() { - CHECK(LayoutUtil::IsSparseArray(subshape())); - int64 num_elements = sparse_indices()->index_count(); - auto values = data(); - CHECK_LE(num_elements, values.size()); - sparse_indices()->SortWithValues( - tensorflow::gtl::MutableArraySlice(values.data(), num_elements)); -} - -namespace { - -void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, - bool print_layout, std::vector* pieces) { - const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); - CHECK(LayoutUtil::HasLayout(literal.shape())); - CHECK(LayoutUtil::HasLayout(subshape)); - - auto shape_to_string = [print_layout](const Shape& shape) { - if (print_layout) { - return ShapeUtil::HumanStringWithLayout(shape); - } else { - return ShapeUtil::HumanString(shape); - } - }; - - // TODO(b/32894291): refactor this code to reduce code duplication. - if (ShapeUtil::IsTuple(subshape)) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" (\n"); - std::vector tuple_pieces; - for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { - ShapeIndex element_index = shape_index; - element_index.push_back(i); - std::vector element_pieces; - ToStringHelper(literal, element_index, print_layout, &element_pieces); - tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, "")); - } - pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n")); - pieces->push_back("\n)"); - return; - } - - if (LayoutUtil::IsSparseArray(subshape)) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back("{"); - int64 rank = ShapeUtil::Rank(subshape); - int64 num_elements = literal.sparse_element_count(); - for (int64 i = 0; i < num_elements; ++i) { - if (i > 0) { - pieces->push_back(", "); - } - if (rank == 1) { - pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); - pieces->push_back(": "); - } else { - pieces->push_back("["); - pieces->push_back( - tensorflow::str_util::Join(literal.GetSparseIndex(i), ", ")); - pieces->push_back("]: "); - } - pieces->push_back(literal.GetSparseElementAsString(i)); - } - pieces->push_back("}"); - return; - } - - CHECK(LayoutUtil::IsDenseArray(subshape)); - - auto element_to_string = - [&](tensorflow::gtl::ArraySlice indices) -> string { - PrimitiveType element_type = subshape.element_type(); - if (element_type == PRED) { - // We display predicates in a densely packed form. - return literal.Get(indices, shape_index) ? "1" : "0"; - } - return ((!indices.empty() && indices.back() > 0) ? ", " : "") + - literal.GetAsString(indices, shape_index); - }; - - if (ShapeUtil::Rank(subshape) == 0) { - pieces->push_back(literal.GetAsString({}, shape_index)); - } else if (ShapeUtil::Rank(subshape) == 1) { - pieces->push_back("{"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(element_to_string({i0})); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 2) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(" { "); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(element_to_string({i0, i1})); - } - pieces->push_back(" "); - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n"); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 3) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(i0 > 0 ? ",\n{" : "{"); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(i1 > 0 ? ",\n { " : " { "); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(element_to_string({i0, i1, i2})); - } - pieces->push_back(" }"); - } - pieces->push_back(" }"); - } - pieces->push_back("\n}"); - } else if (ShapeUtil::Rank(subshape) == 4) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(" {"); - for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { - pieces->push_back(element_to_string({i0, i1, i2, i3})); - } - pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n"); - } - pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" - : " },\n"); - } - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n"); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 5) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(Printf(" { /*i2=%lld*/\n", i2)); - for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { - pieces->push_back(" {"); - for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) { - pieces->push_back(element_to_string({i0, i1, i2, i3, i4})); - } - pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n" - : "},\n"); - } - pieces->push_back(i2 == subshape.dimensions(2) - 1 ? " }\n" - : " },\n"); - } - pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" - : " },\n"); - } - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n"); - } - pieces->push_back("}"); - } else { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {"); - literal.EachCellAsString( - [&](tensorflow::gtl::ArraySlice indices, const string& value) { - pieces->push_back(" "); - pieces->push_back(value); - }); - pieces->push_back("}"); - } -} - -} // namespace - -int64 LiteralBase::sparse_element_count() const { - CHECK(LayoutUtil::IsSparseArray(shape())); - return sparse_indices()->index_count(); -} - -string LiteralBase::ToString(bool print_layout) const { - std::vector pieces; - CHECK(LayoutUtil::HasLayout(this->shape())); - ToStringHelper(*this, {}, print_layout, &pieces); - return tensorflow::str_util::Join(pieces, ""); -} - -/* static */ std::unique_ptr Literal::MakeTuple( - tensorflow::gtl::ArraySlice elements) { - std::vector element_shapes; - for (const auto* element : elements) { - element_shapes.push_back(element->shape()); - } - auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); - for (int i = 0; i < elements.size(); ++i) { - TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i})); - } - return literal; -} - -/* static */ std::unique_ptr Literal::MakeTupleFromSlices( +/* static */ std::unique_ptr LiteralUtil::MakeTupleFromSlices( tensorflow::gtl::ArraySlice elements) { std::vector element_shapes; for (const auto& element : elements) { @@ -1538,7 +456,7 @@ string LiteralBase::ToString(bool print_layout) const { return literal; } -/* static */ std::unique_ptr Literal::MakeTupleOwned( +/* static */ std::unique_ptr LiteralUtil::MakeTupleOwned( std::vector> elements) { std::vector element_shapes; element_shapes.reserve(elements.size()); @@ -1553,819 +471,9 @@ string LiteralBase::ToString(bool print_layout) const { return literal; } -void LiteralBase::EachCellAsString( - const std::function indices, - const string& value)>& per_cell) const { - if (ShapeUtil::HasZeroElements(shape())) { - return; - } - std::vector indices = IndexUtil::LinearIndexToMultidimensionalIndex( - shape(), /*linear_index=*/0); - do { - per_cell(indices, GetAsString(indices)); - } while (IndexUtil::BumpIndices(shape(), &indices)); -} - -namespace { -template -std::unique_ptr ConvertBetweenNativeTypesWithConverter( - const LiteralBase& src_literal, const ConverterType& converter) { - CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( - src_literal.shape(), - primitive_util::NativeToPrimitiveType())); - auto src_data = src_literal.data(); - auto dest_data = result_literal->template data(); - int64 num_elements = src_literal.element_count(); - - for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = converter(src_data[i]); - } - return result_literal; -} - -template -std::unique_ptr ConvertBetweenNativeTypes( - const LiteralBase& src_literal) { - auto converter = [](NativeSrcT src) { return static_cast(src); }; - return ConvertBetweenNativeTypesWithConverter( - src_literal, converter); -} - -template -typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), - std::unique_ptr>::type -BitcastBetweenNativeTypes(const LiteralBase& src_literal) { - auto converter = [](NativeSrcT src) { - return tensorflow::bit_cast(src); - }; - return ConvertBetweenNativeTypesWithConverter( - src_literal, converter); -} - -// This template specialization is here to make the compiler happy. bit_cast has -// a static check that the types are the same size. This specialization should -// never be used because the source and destination types are checked for -// identical sizes higher up. -template -typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), - std::unique_ptr>::type -BitcastBetweenNativeTypes(const LiteralBase& src_literal) { - LOG(FATAL) << "Invalid bitcast between types of different sizes."; -} - -template -std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { - CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = MakeUnique( - ShapeUtil::ChangeElementType(src_literal.shape(), C64)); - using NativeSrcT = - typename primitive_util::PrimitiveTypeToNative::type; - tensorflow::gtl::ArraySlice src_data = - src_literal.data(); - tensorflow::gtl::MutableArraySlice dest_data = - result_literal->data(); - int64 num_elements = src_literal.element_count(); - for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = complex64(static_cast(src_data[i]), 0); - } - return result_literal; -} - -template -std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, - bool bitcast) { - CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); - if (bitcast) { - return BitcastBetweenNativeTypes< - typename primitive_util::PrimitiveTypeToNative< - primitive_src_type>::type, - typename primitive_util::PrimitiveTypeToNative< - primitive_dest_type>::type>(src_literal); - } else { - return ConvertBetweenNativeTypes< - typename primitive_util::PrimitiveTypeToNative< - primitive_src_type>::type, - typename primitive_util::PrimitiveTypeToNative< - primitive_dest_type>::type>(src_literal); - } -} - -template -StatusOr> ConvertIfDestTypeMatches( - const LiteralBase& src_literal, PrimitiveType primitive_dest_type, - bool bitcast) { - switch (primitive_dest_type) { -#define CONVERT_IF_TYPES_MATCH(type) \ - case (type): \ - return ConvertIfTypesMatch(src_literal, \ - bitcast); - CONVERT_IF_TYPES_MATCH(PRED) - CONVERT_IF_TYPES_MATCH(S8) - CONVERT_IF_TYPES_MATCH(S32) - CONVERT_IF_TYPES_MATCH(S64) - CONVERT_IF_TYPES_MATCH(U8) - CONVERT_IF_TYPES_MATCH(U32) - CONVERT_IF_TYPES_MATCH(U64) - CONVERT_IF_TYPES_MATCH(F16) - CONVERT_IF_TYPES_MATCH(F32) - CONVERT_IF_TYPES_MATCH(F64) - CONVERT_IF_TYPES_MATCH(BF16) -#undef CONVERT_IF_TYPES_MATCH - case C64: - if (!bitcast) { - return ConvertToC64(src_literal); - } - break; - // Other types are not yet supported. - default: - break; - } - return Unimplemented( - "Converting from type %s to type %s is not implemented.", - PrimitiveType_Name(src_literal.shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); -} - -StatusOr> ConvertSwitch( - const LiteralBase& literal, PrimitiveType primitive_dest_type, - bool bitcast) { - TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); - if (literal.shape().element_type() == primitive_dest_type) { - return literal.CloneToUnique(); - } - switch (literal.shape().element_type()) { -#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ - case (type): \ - return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \ - bitcast); - CONVERT_IF_DEST_TYPE_MATCHES(PRED) - CONVERT_IF_DEST_TYPE_MATCHES(S8) - CONVERT_IF_DEST_TYPE_MATCHES(S32) - CONVERT_IF_DEST_TYPE_MATCHES(S64) - CONVERT_IF_DEST_TYPE_MATCHES(U8) - CONVERT_IF_DEST_TYPE_MATCHES(U32) - CONVERT_IF_DEST_TYPE_MATCHES(U64) - CONVERT_IF_DEST_TYPE_MATCHES(F16) - CONVERT_IF_DEST_TYPE_MATCHES(F32) - CONVERT_IF_DEST_TYPE_MATCHES(F64) - CONVERT_IF_DEST_TYPE_MATCHES(BF16) -#undef CONVERT_IF_DEST_TYPE_MATCHES - // Other types are not yet supported. - default: - return Unimplemented( - "%s from type %s to type %s is not implemented.", - (bitcast ? "Bitcast converting" : "Converting"), - PrimitiveType_Name(literal.shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); - } -} - -} // namespace - -StatusOr> LiteralBase::Convert( - PrimitiveType primitive_dest_type) const { - return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); -} - -StatusOr> LiteralBase::BitcastConvert( - PrimitiveType primitive_dest_type) const { - if (primitive_util::BitWidth(shape().element_type()) != - primitive_util::BitWidth(primitive_dest_type)) { - return InvalidArgument( - "Cannot bitcast convert from %s to %s, bit widths are different: %d != " - "%d", - PrimitiveType_Name(shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str(), - primitive_util::BitWidth(shape().element_type()), - primitive_util::BitWidth(primitive_dest_type)); - } - return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); -} - -StatusOr> LiteralBase::ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16) const { - if (!ShapeUtil::IsTuple(dest_shape)) { - if (round_f32_to_bf16 && shape().element_type() == F32 && - dest_shape.element_type() == BF16) { - auto converter = [](float src) { - return tensorflow::bfloat16::round_to_bfloat16(src); - }; - return ConvertBetweenNativeTypesWithConverter(*this, - converter); - } - return Convert(dest_shape.element_type()); - } - std::vector elements; - for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { - auto element = LiteralSlice(*this, {i}); - TF_ASSIGN_OR_RETURN( - auto new_element, - element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); - elements.push_back(std::move(*new_element)); - } - auto converted = MakeUnique(); - *converted = Literal::MoveIntoTuple(&elements); - return std::move(converted); -} - -template -bool LiteralBase::Piece::EqualElementsInternal( - const LiteralBase::Piece& other, std::vector* multi_index) const { - if (multi_index->size() == ShapeUtil::Rank(subshape())) { - return (Get(*multi_index) == other.Get(*multi_index)); - } - for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) { - multi_index->push_back(i); - if (!EqualElementsInternal(other, multi_index)) { - return false; - } - multi_index->pop_back(); - } - return true; -} - -bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { - DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); - - std::vector multi_index; - switch (subshape().element_type()) { - case PRED: - return EqualElementsInternal(other, &multi_index); - case U8: - return EqualElementsInternal(other, &multi_index); - case S32: - return EqualElementsInternal(other, &multi_index); - case S64: - return EqualElementsInternal(other, &multi_index); - case U32: - return EqualElementsInternal(other, &multi_index); - case U64: - return EqualElementsInternal(other, &multi_index); - case F32: - return EqualElementsInternal(other, &multi_index); - case F64: - return EqualElementsInternal(other, &multi_index); - case F16: - return EqualElementsInternal(other, &multi_index); - case BF16: - return EqualElementsInternal(other, &multi_index); - case C64: - return EqualElementsInternal(other, &multi_index); - default: - LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type " - << PrimitiveType_Name(subshape().element_type()); - } -} - -bool LiteralBase::operator==(const LiteralBase& other) const { - if (!ShapeUtil::Compatible(shape(), other.shape())) { - return false; - } - - return root_piece().ForEachSubpieceWithBool( - [&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { - return true; - } - - const Piece& other_piece = other.piece(index); - if (!piece.EqualElements(other_piece)) { - return false; - } - return true; - }); -} - -namespace { - -template -static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, - NativeT value) { - for (int64 i = 0; i < data.size(); ++i) { - if (data[i] != value) { - return false; - } - } - return true; -} - -} // namespace - -bool LiteralBase::IsAll(int8 value) const { - return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index, - const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { - return true; - } - - auto piece_is_all = [&]() { - switch (shape().element_type()) { - case U8: - if (value >= 0) { - return AllElementsEqualValue(piece.data(), value); - } - return false; - case U32: - if (value >= 0) { - return AllElementsEqualValue(piece.data(), value); - } - return false; - case U64: - if (value >= 0) { - return AllElementsEqualValue(piece.data(), value); - } - return false; - case S8: - return AllElementsEqualValue(piece.data(), value); - case S32: - return AllElementsEqualValue(piece.data(), value); - case S64: - return AllElementsEqualValue(piece.data(), value); - case F32: - return AllElementsEqualValue(piece.data(), value); - case F64: - return AllElementsEqualValue(piece.data(), value); - case F16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - case BF16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - case PRED: - if (value == 0) { - return AllElementsEqualValue(piece.data(), false); - } - if (value == 1) { - return AllElementsEqualValue(piece.data(), true); - } - return false; - default: - return false; - } - return false; - }; - - if (!piece_is_all()) { - return false; - } - return true; - }); -} - -bool LiteralBase::IsAllFloat(float value) const { - return root_piece().ForEachSubpieceWithBool( - [&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { - return true; - } - - auto piece_is_all = [&]() { - switch (shape().element_type()) { - case F32: - return AllElementsEqualValue(piece.data(), value); - case F64: - return AllElementsEqualValue(piece.data(), value); - case F16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - case BF16: - return AllElementsEqualValue( - piece.data(), static_cast(value)); - default: - return false; - } - }; - if (!piece_is_all()) { - return false; - } - return true; - }); -} - -bool LiteralBase::IsAllComplex(complex64 value) const { - switch (shape().element_type()) { - case C64: - return AllElementsEqualValue(root_piece().data(), - value); - default: - return false; - } -} - -bool LiteralBase::IsAllFirst() const { - return root_piece().ForEachSubpieceWithBool( - [&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { - return true; - } - - // Empty shapes are not all the first element since there is no first - // element. - if (ShapeUtil::HasZeroElements(piece.subshape())) { - return false; - } - auto piece_is_all = [&]() { - switch (piece.subshape().element_type()) { - case PRED: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 8 bit types - case S8: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U8: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 16 bit types - case BF16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case F16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 32 bit types - case F32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 64 bit types - case C64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case F64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - default: - return false; - } - }; - - if (!piece_is_all()) { - return false; - } - return true; - }); -} - -bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice indices) const { - CHECK(ShapeUtil::IsArray(shape())); - switch (shape().element_type()) { - case U8: - return Get(indices) == 0; - case U32: - return Get(indices) == 0; - case U64: - return Get(indices) == 0; - case S8: - return Get(indices) == 0; - case S32: - return Get(indices) == 0; - case S64: - return Get(indices) == 0; - case F32: - return Get(indices) == 0.0f; - case F64: - return Get(indices) == 0.0; - case C64: - return Get(indices) == complex64(0.0f, 0.0f); - case F16: - return Get(indices) == static_cast(0.0f); - case BF16: - return Get(indices) == static_cast(0.0f); - case PRED: - return Get(indices) == false; - default: - LOG(FATAL) << "Input literal must be an array."; - } -} - -namespace { - -template -void CopyToRepeatedField(RepeatedFieldT* dest, - const tensorflow::gtl::ArraySlice src) { - *dest = RepeatedFieldT(src.begin(), src.end()); -} - -} // namespace - -void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { - *proto->mutable_shape() = subshape(); - switch (subshape().element_type()) { - case PRED: - CopyToRepeatedField(proto->mutable_preds(), data()); - break; - case U8: - proto->set_u8s(static_cast(data().data()), - element_count()); - break; - case U32: - CopyToRepeatedField(proto->mutable_u32s(), data()); - break; - case U64: - CopyToRepeatedField(proto->mutable_u64s(), data()); - break; - case S32: - CopyToRepeatedField(proto->mutable_s32s(), data()); - break; - case S64: - CopyToRepeatedField(proto->mutable_s64s(), data()); - break; - case F16: - *proto->mutable_f16s() = string( - reinterpret_cast(data().data()), size_bytes()); - if (!kLittleEndian) { - ConvertEndianShort(proto->mutable_f16s()); - } - break; - case BF16: - *proto->mutable_bf16s() = string( - reinterpret_cast(data().data()), size_bytes()); - if (!kLittleEndian) { - ConvertEndianShort(proto->mutable_bf16s()); - } - break; - case F32: - CopyToRepeatedField(proto->mutable_f32s(), data()); - break; - case F64: - CopyToRepeatedField(proto->mutable_f64s(), data()); - break; - case C64: - for (complex64 value : data()) { - proto->add_c64s(value.real()); - proto->add_c64s(value.imag()); - } - break; - case TUPLE: - // Nothing to do but assign the shape which is done above. - return; - default: - LOG(FATAL) << "Unhandled primitive type " << subshape().element_type(); - } -} - -const void* LiteralBase::Piece::untyped_data() const { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - return buffer(); -} - -void* LiteralBase::Piece::untyped_data() { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - return buffer(); -} - -namespace { - -template -Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, - const RepeatedFieldT& src) { - if (dest.size() != src.size()) { - return InvalidArgument( - "Expected %lu elements in LiteralProto repeated field, has %d", - dest.size(), src.size()); - } - std::copy(src.begin(), src.end(), dest.begin()); - return Status::OK(); -} - -} // namespace - -Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { - // These conditions should have been checked in Literal::CreateFromProto. - TF_RET_CHECK(proto.has_shape()); - TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); - TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); - - switch (subshape().element_type()) { - case PRED: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.preds())); - break; - case U8: { - auto u8_data = data(); - TF_RET_CHECK(proto.u8s().size() == u8_data.size()); - std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin()); - } break; - case S32: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.s32s())); - break; - case S64: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.s64s())); - break; - case U32: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u32s())); - break; - case U64: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u64s())); - break; - case F16: { - const string& s(proto.f16s()); - TF_RET_CHECK(data().size() * sizeof(half) == s.size()); - memcpy(untyped_data(), s.data(), s.size()); - if (!kLittleEndian) { - ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); - } - } break; - - case BF16: { - const string& s(proto.bf16s()); - TF_RET_CHECK(data().size() * sizeof(bfloat16) == s.size()); - memcpy(untyped_data(), s.data(), s.size()); - if (!kLittleEndian) { - ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); - } - } break; - case F32: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.f32s())); - break; - case F64: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.f64s())); - break; - case C64: { - auto complex_data = data(); - TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2); - for (int64 i = 0; i < complex_data.size(); ++i) { - complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)}; - } - } break; - case TUPLE: - LOG(FATAL) << "Should not be called on tuple shapes: " - << ShapeUtil::HumanString(subshape()); - break; - default: - LOG(FATAL) << "Unhandled primitive type " << subshape().element_type(); - } - return Status::OK(); -} - -LiteralProto LiteralBase::ToProto() const { - LiteralProto proto; - root_piece().ForEachSubpiece( - [&](const ShapeIndex& index, const Piece& piece) { - LiteralProto* proto_piece = &proto; - for (int64 i : index) { - while (proto_piece->tuple_literals_size() <= i) { - proto_piece->add_tuple_literals(); - } - proto_piece = proto_piece->mutable_tuple_literals(i); - } - piece.WriteToProto(proto_piece); - }); - - if (LayoutUtil::IsSparseArray(shape())) { - CopyToRepeatedField(proto.mutable_sparse_indices(), - sparse_indices()->data()); - } - - return proto; -} - -/* static */ -StatusOr> Literal::CreateFromProto( - const LiteralProto& proto) { - if (!proto.has_shape()) { - return InvalidArgument("LiteralProto has no shape"); - } - if (!LayoutUtil::HasLayout(proto.shape())) { - return InvalidArgument("LiteralProto has no layout"); - } - - auto literal = MakeUnique(proto.shape()); - - TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( - [&](const ShapeIndex& index, Piece* piece) { - const LiteralProto* proto_element = &proto; - for (int64 i : index) { - CHECK(i < proto_element->tuple_literals_size()); - proto_element = &proto_element->tuple_literals(i); - } - - if (ShapeUtil::IsTuple(piece->subshape())) { - if (proto_element->tuple_literals_size() != - ShapeUtil::TupleElementCount(piece->subshape())) { - return InvalidArgument( - "Expected %lld tuple elements in LiteralProto, has %d", - ShapeUtil::TupleElementCount(piece->subshape()), - proto_element->tuple_literals_size()); - } - return Status::OK(); - } - - CHECK(ShapeUtil::IsArray(piece->subshape())); - TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); - - return Status::OK(); - })); - - return std::move(literal); -} - -/* static */ string Literal::MultiIndexAsString( +/* static */ string LiteralUtil::MultiIndexAsString( tensorflow::gtl::ArraySlice multi_index) { return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}"); } -const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const { - return piece(shape_index).untyped_data(); -} - -void* Literal::untyped_data(const ShapeIndex& shape_index) { - return piece(shape_index).untyped_data(); -} - -int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const { - return piece(shape_index).size_bytes(); -} - -string LiteralBase::GetR1U8AsString() const { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); - CHECK_EQ(shape().element_type(), U8); - return string(tensorflow::bit_cast(data().data()), - ShapeUtil::ElementsIn(shape())); -} - -void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { - CHECK(ShapeUtil::IsTuple(shape)); - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - const Shape& subshape = shape.tuple_shapes(i); - - auto child_piece = Piece(); - child_piece.set_subshape(&subshape); - - if (ShapeUtil::IsTuple(subshape)) { - BuildPieceSubtree(subshape, &child_piece); - } - - piece->emplace_back(std::move(child_piece)); - } -} - -LiteralSlice::LiteralSlice(const LiteralBase& literal) - : LiteralBase(), root_piece_(&literal.root_piece()) {} - -LiteralSlice::LiteralSlice(const LiteralBase& literal, - const ShapeIndex& view_root) - : LiteralBase(), root_piece_(&literal.piece(view_root)) {} - -BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) - : LiteralBase(), shape_(MakeUnique(shape)) { - CHECK(ShapeUtil::IsArray(*shape_)); - CHECK_NE(src_buf_ptr, nullptr); - CHECK(LayoutUtil::HasLayout(*shape_)); - - root_piece_ = Piece(); - root_piece_.set_buffer(const_cast(src_buf_ptr)); - root_piece_.set_subshape(shape_.get()); -} - -BorrowingLiteral::BorrowingLiteral( - tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) - : LiteralBase(), shape_(MakeUnique(shape)) { - CHECK(ShapeUtil::IsTuple(*shape_)); - CHECK(!ShapeUtil::IsNestedTuple(*shape_)); - CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); - root_piece_ = Piece(); - root_piece_.set_subshape(shape_.get()); - BuildPieceSubtree(*shape_, &root_piece_); - - for (int i = 0; i < src_buf_ptrs.size(); ++i) { - const auto& src_shape = shape_->tuple_shapes(i); - CHECK(ShapeUtil::IsArray(src_shape)); - root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); - } -} - } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 8e4159e360e042beb31a75c432a3c7dfa7356007..e3737a9d0051b32dc0becc19e1849c856a50e52e 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -51,679 +52,12 @@ limitations under the License. namespace xla { -// Forward declare Literal and LiteralSlice class to be used by the creation -// methods in the base class. -class Literal; -class LiteralSlice; - -// Abstract base class for literals. -class LiteralBase { +class LiteralUtil { public: - virtual ~LiteralBase() = 0; - - // Literals are equal if they have compatible shapes and the same data - // values. Layout is not compared. - bool operator==(const LiteralBase& other) const; - bool operator!=(const LiteralBase& other) const { return !(*this == other); } - - // Returns the shape of the literal. - const Shape& shape() const { return root_piece().subshape(); } - - // Serialize to proto. - LiteralProto ToProto() const; - - // Returns an ArraySlice of the array for this literal for the given NativeT - // (e.g., float). CHECKs if the subshape of the literal at the given - // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type - // to native type. - template - tensorflow::gtl::ArraySlice data( - const ShapeIndex& shape_index = {}) const; - - // Returns a const pointer to the sparse index array. Returns nullptr if the - // literal is not a sparse array. - const SparseIndexArray* sparse_indices( - const ShapeIndex& shape_index = {}) const; - - // Returns a const pointer to (or size of) the underlying buffer holding the - // array at the given shape index. CHECKs if the subshape of the literal at - // the given ShapeIndex is not array. - const void* untyped_data(const ShapeIndex& shape_index = {}) const; - int64 size_bytes(const ShapeIndex& shape_index = {}) const; - - // Returns this literal's data as a string. This literal must be a rank-1 U8 - // array. - string GetR1U8AsString() const; - - // Returns a string representation of the literal value. - // Warning: this function can take minutes for multi-million element Literals. - string ToString(bool print_layout = false) const; - - // Gets an element in the literal at the given index. The multi_index is - // CHECKed against the dimension sizes. - template - NativeT Get(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const; - // Overloads of Get for array literals. CHECKs if the literal is not - // array-shaped and dense. - template - NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; - - // Returns the element value at index (0, ..., 0), however many zeroes are - // required for that index. - template - NativeT GetFirstElement() const; - - // As Get(), but determines the correct type and converts the value - // into text. - string GetAsString(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index = {}) const; - // As GetSparseElement(), but determines the correct type and converts the - // value into text. - string GetSparseElementAsString(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; - // As Get(), but determines the correct type and converts the value into - // int64. This literal must be an array. - StatusOr GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const; - - // Returns the multi-index of the element in a sparse literal at the given - // sparse element number. The sparse element number is the position with in - // the sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - tensorflow::gtl::ArraySlice GetSparseIndex( - int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; - - // Returns the value of the element in a sparse literal at the given sparse - // element number. The sparse element number is the position with in the - // sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - template - NativeT GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; - - // Invokes the "per cell" callback for each element in the provided - // literal with the element's indices and a string representation of - // the element's value. - // - // This function is useful if you want a polymorphic representation - // of the tensor's elements (turning it to a string for something - // like representation in a protobuf). - // - // This literal must have a dense layout. - void EachCellAsString( - const std::function indices, - const string& value)>& per_cell) const; - template - void EachCell(std::function indices, - NativeT value)> - per_cell) const; - - // Returns whether every element in this literal is equal to value. - // - // value is an int8 because we expect this to be called with small - // compile-time constants (0, -1, etc.) and so that whatever value you pass - // can be represented exactly by floating-point types as small as 16 bits. - // - // If value doesn't fit in this literal's type, returns false. Values of 1/0 - // are considered equal to true/false; other values are not considered equal - // to true. Also if this literal is not array-shaped false is returned. - bool IsAll(int8 value) const; - - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular floating-point number. - // - // If the literal is not a floating-point value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for values that can be expressed precisely as a float, - // e.g. -0.5. Also if this literal is not array-shaped false is returned. - bool IsAllFloat(float value) const; - - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular complex number. - // - // If the literal is not a complex value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for complex values that can be expressed precisely as - // float pairs e.g. (-0.5, 1.0). - // - // This literal must have a dense layout. - bool IsAllComplex(complex64 value) const; - - // Literal consists entirely of the first element of the literal. - bool IsAllFirst() const; - - // Returns whether this literal is zero at the specified index. This literal - // must be an array with a dense layout. - bool IsZero(tensorflow::gtl::ArraySlice indices) const; - - // Returns the count of the elements in the array at the given shape index in - // this literal. - int64 element_count(const ShapeIndex& index = {}) const { - return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); - } - - // Returns the count of the elements in the sparse array at the given shape - // index in this literal, which will be no larger than - // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). - int64 sparse_element_count() const; - - // Compute a hash for this literal. This literal must not be a sparse tensor - // or a tuple containing a sparse tensor. - size_t Hash() const; - - // Converts this literal to the given shape. Returns an error is the - // conversion is not possible. - // - // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding - // instead of truncation; otherwise, truncation is used. - // - // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes - // the default behavior. - StatusOr> ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16 = false) const; - - // Converts this literal to another primitive type using a bitcast - // conversion. The to and from primitive types must have the same bit - // width. Returns an error if the conversion is not possible. This literal - // must be array-shaped. - StatusOr> BitcastConvert( - PrimitiveType primitive_dest_type) const; - - // Converts this literal to another primitive type. Returns an error if the - // conversion is not possible. This literal must be array-shaped. - StatusOr> Convert( - PrimitiveType primitive_dest_type) const; + LiteralUtil() = delete; // Returns a literal scalar representing the first element. - Literal GetFirstScalarLiteral() const; - - // Clones the underlying buffers into a new Literal, or new - // std::unique_ptr. - Literal Clone() const; - std::unique_ptr CloneToUnique() const; - - // TODO(b/67651157): The methods below which perform computation on Literals - // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with - // evaluator code which operates on Literals. - // - // Creates a new value that has the equivalent value as this - // literal, but conforms to new_layout; e.g. a literal matrix that was in {0, - // 1} minor-to-major dimension layout can be re-layed-out as {1, 0} - // minor-to-major dimension layout and the value in the cell at any given - // logical index (i0, i1) will be the same. - // - // For tuple shaped literals, shape_index should be used to select the inner - // array that the new layout applies to. - // - // Note: this is useful when the client wants to ensure that a value placed in - // the XLA allocation tracker has a particular layout; for efficiency - // purposes or avoiding unimplemented operation/layout combinations. - std::unique_ptr Relayout(const Layout& new_layout, - const ShapeIndex& shape_index = {}) const; - - // An overload of Relayout which changes the layout of the entire shape rather - // than being limited to a single array within the shape. - std::unique_ptr Relayout(const Shape& shape_with_layout) const; - - // Creates a new literal by reshaping this literal to have the given - // dimensions. The total number of elements must not change; The - // implementation currently only supports monotonic dim0-major layouts. - // This literal must be an array. - StatusOr> Reshape( - tensorflow::gtl::ArraySlice dimensions) const; - - // Creates a new literal by broadcasting this literal with `dimensions` to - // yield a literal of shape `result_shape`. - StatusOr> Broadcast( - const Shape& result_shape, - tensorflow::gtl::ArraySlice dimensions) const; - - // Creates a new literal by reordering the dimensions of this literal. - // The given `permutation` must be a permutation of the dimension numbers - // in the original literal, and it specifies the order of the new dimensions - // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). - // For example, a transpose call on a literal of shape [3 x 8 x 4] and - // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. - // This literal must be an array. - std::unique_ptr Transpose( - tensorflow::gtl::ArraySlice permutation) const; - - // Creates a sub-array from this literal by extracting the indices - // [start_index, limit_index) of each dimension. The result literal has the - // same rank and layout as for the given literal. The number of indices in - // start_indices and limit_indices must be the rank of the literal, and the - // indices follow the order of the dimensions. - // This literal must be an array. - std::unique_ptr Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const; - - // Creates a literal with a prepended dimension with bound "times"; e.g. a - // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this - // literal replicated four times. - // This literal must be an array. - template - std::unique_ptr Replicate(int64 times) const; - - // Creates a new Literal object with the shape specified as parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). - // - // Note: It's an antipattern to use this method then immediately call - // Literal::Populate on the result (since that results in zero initialization, - // then reinitialization. Conside if a call to MakeUnique(shape), - // followed by the call to Literal::Populate can be used instead. - static std::unique_ptr CreateFromShape(const Shape& shape); - - protected: - // A data structure representing a subshape at a particular ShapeIndex within - // the literal. For array-shaped ShapeIndexes, this data structure holds the - // pointer to the memory allocated for the array data. - class Piece { - public: - // Returns the buffer holding the array data for this piece as an array - // slice. This piece must be array-shaped. - template - tensorflow::gtl::ArraySlice data() const; - template - tensorflow::gtl::MutableArraySlice data(); - - // Returns the buffer holding the array data for this piece as a void*. This - // piece must be array-shaped. - void* untyped_data(); - const void* untyped_data() const; - - // Gets or sets an element in the array at the given index. The multi_index - // is CHECKed against the dimension sizes of the array. This piece must be - // array-shaped. - template - NativeT Get(tensorflow::gtl::ArraySlice index) const; - template - void Set(tensorflow::gtl::ArraySlice index, NativeT value); - - // Gets/sets the buffer holding the array data. - char* buffer() const { return buffer_; } - void set_buffer(char* buffer) { buffer_ = buffer; } - - // The array of multi-indices that provide the locations of non-zero - // elements in a sparse array. Only used if - // LayoutUtil::IsSparseArray(shape()) is true. - SparseIndexArray* sparse_indices() const { return sparse_indices_; } - void set_sparse_indices(SparseIndexArray* sparse_indices) { - sparse_indices_ = sparse_indices; - } - - // Gets or sets the subshape of this piece. This reference points to a - // subshape within the shape in the containing Literal (Literal::shape_). - const Shape& subshape() const { return *subshape_; } - void set_subshape(const Shape* subshape) { subshape_ = subshape; } - - // Returns the size in bytes of the buffer holding the array data. - int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } - - // Returns the number of elements in this piece's array. - int64 element_count() const { - // If this is a sparse array, use the number of elements represented by - // the indices in the associated SparseIndexArray. - return LayoutUtil::IsSparseArray(subshape()) - ? sparse_indices()->index_count() - : ShapeUtil::ElementsIn(subshape()); - } - - // Returns the child piece at 'index' of this piece. - Piece& child(int64 index) { return children_[index]; } - - // Adds a child piece to this piece's children. - void emplace_back(Piece child_piece) { - children_.emplace_back(std::move(child_piece)); - } - - // Returns the size of children pieces of this piece. - int64 children_size() { return children_.size(); } - - // Visitor functions that recursively traverses the piece and calls the - // given function at each child piece. The function has the type: - // void (const ShapeIndex& index, const Piece& piece) - template - void ForEachSubpiece(const Fn& func) const { - ShapeIndex index; - return ForEachHelper( - [&func](const ShapeIndex& index, const Piece& piece) { - func(index, piece); - return Status::OK(); - }, - *this, &index) - .IgnoreError(); - } - // Same as above, but the function has the type: - // Status (const ShapeIndex& index, const Piece& piece) - // The first non-OK return value is returned by the function. - template - Status ForEachSubpieceWithStatus(const Fn& func) const { - ShapeIndex index; - return ForEachHelper(func, *this, &index); - } - // Same as above, but the function has the type: - // Bool (const ShapeIndex& index, const Piece& piece) - // The first non-true return value is returned by the function. - template - bool ForEachSubpieceWithBool(const Fn& func) const { - ShapeIndex index; - return ForEachHelperBool(func, *this, &index); - } - // Same as above, but the function has the type: - // Void (const ShapeIndex& index, Piece& piece) - template - void ForEachMutableSubpiece(const Fn& func) { - ShapeIndex index; - return ForEachMutableHelper( - [&func](const ShapeIndex& index, Piece* piece) { - func(index, piece); - return Status::OK(); - }, - const_cast(this), &index) - .IgnoreError(); - } - // Same as above, but the function has the type: - // Status (const ShapeIndex& index, Piece& piece) - // The first non-OK return value is returned by the function. - template - Status ForEachMutableSubpieceWithStatus(const Fn& func) { - ShapeIndex index; - return ForEachMutableHelper( - func, const_cast(this), &index); - } - - // Returns true if this piece and 'other' contain the same data. This piece - // and 'other' must be array-shaped and compatible. - bool EqualElements(const Piece& other) const; - - // Writes the shape and data (if array-shaped) into the given proto. - void WriteToProto(LiteralProto* proto) const; - - // Copy the data from 'src' into this piece's buffer. Shapes of this piece - // and src must be compatible. - Status CopyFrom(const Piece& src); - - // Copies the data from the given proto into this piece. The shape of this - // piece must be equal (not just compatible) to the shape of the proto. - Status CopyFromProto(const LiteralProto& proto); - - // Sorts the elements in a sparse array. - void SortSparseElements(); - - private: - // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'. - // The first non-OK (or non-true) value is returned by the function. - // The callable 'func' has the same signature as described above in - // ForEachSubpiece*. - template - Status ForEachHelper(const Fn& func, const Piece& piece, - ShapeIndex* index) const { - TF_RETURN_IF_ERROR(func(*index, piece)); - for (int64 i = 0; i < piece.children_.size(); ++i) { - index->push_back(i); - TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index)); - index->pop_back(); - } - return Status::OK(); - } - template - bool ForEachHelperBool(const Fn& func, const Piece& piece, - ShapeIndex* index) const { - if (!func(*index, piece)) { - return false; - } - for (int64 i = 0; i < piece.children_.size(); ++i) { - index->push_back(i); - if (!ForEachHelperBool(func, piece.children_[i], index)) { - return false; - } - index->pop_back(); - } - return true; - } - template - Status ForEachMutableHelper(const Fn& func, Piece* piece, - ShapeIndex* index) { - TF_RETURN_IF_ERROR(func(*index, piece)); - for (int64 i = 0; i < piece->children_.size(); ++i) { - index->push_back(i); - TF_RETURN_IF_ERROR( - ForEachMutableHelper(func, &piece->children_[i], index)); - index->pop_back(); - } - return Status::OK(); - } - - // Recursive helper for EqualElements. - template - bool EqualElementsInternal(const Piece& other, - std::vector* multi_index) const; - - // Helper for SortSparseElements that has the element type as a template - // parameter. - template - void SortSparseElementsInternal(); - - // For array-shaped pieces, this is the buffer holding the literal data. - char* buffer_ = nullptr; - - // For sparse arrays, this is the array of indices. - SparseIndexArray* sparse_indices_ = nullptr; - - // The shape of piece. This points into the shape of the containing Literal - // (Literal::shape_). - const Shape* subshape_ = nullptr; - - // Children pieces for tuple shaped pieces. - std::vector children_ = {}; - }; // class Piece - - const Piece& piece(const ShapeIndex& shape_index) const { - Piece* piece = &const_cast(root_piece()); - for (const auto i : shape_index) { - DCHECK_GE(i, 0); - DCHECK_LT(i, piece->children_size()); - piece = &piece->child(i); - } - return *piece; - } - - // Returns the piece at the root of the shape. - virtual const Piece& root_piece() const = 0; - - // LiteralSlice and Literal must access Pieces of other Literals. - friend class Literal; - friend class LiteralSlice; - friend class BorrowingLiteral; - - private: - template - std::unique_ptr SliceInternal( - const Shape& result_shape, - tensorflow::gtl::ArraySlice start_indices) const; -}; - -// Class representing literal values in XLA. -// -// The underlying buffer and shape is always owned by this class. -class Literal : public LiteralBase { - public: - Literal() : Literal(ShapeUtil::MakeNil()) {} - - // Create a literal of the given shape. The literal is allocated sufficient - // memory to hold the shape. Memory is uninitialized. - explicit Literal(const Shape& shape); - virtual ~Literal(); - - // Literals are moveable, but not copyable. To copy a literal use - // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies - // of literals which can be expensive. - Literal(const Literal& other) = delete; - Literal& operator=(const Literal& other) = delete; - Literal(Literal&& other); - // 'allocate_arrays' indicates whether to allocate memory for the arrays in - // the shape. If false, buffer pointers inside of the Literal::Pieces are set - // to nullptr. - Literal(const Shape& shape, bool allocate_arrays); - Literal& operator=(Literal&& other); - - // TODO(b/67651157): Remove this accessor. Literal users should not be able to - // mutate the shape as this can produce malformed Literals. - Shape* mutable_shape_do_not_use() { return shape_.get(); } - - // Returns a MutableArraySlice view of the array for this literal for the - // given NativeT (e.g., float). CHECKs if the subshape of the literal at the - // given ShapeIndex is not array. See primitive_util.h for the mapping from - // XLA type to native type. - template - tensorflow::gtl::MutableArraySlice data( - const ShapeIndex& shape_index = {}); - // Unhide const method from parent class. - using LiteralBase::data; - - // Returns a pointer to the sparse index array. Returns nullptr if the literal - // is not a sparse array. - SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); - - // Returns a pointer to the underlying buffer holding the array at the given - // shape index. CHECKs if the subshape of the literal at the given ShapeIndex - // is not array. - void* untyped_data(const ShapeIndex& shape_index = {}); - // Unhide const method from parent class. - using LiteralBase::untyped_data; - - // Populates a literal with a sparse layout with the given indices and values. - // Each index in the indices array is CHECKed against the dimensions in the - // literal's shape. If sort is true, then the indices and values will be - // sorted. If sort is false, then the indices and values are assumed to - // already be in sorted order. See CreateSparse for an example of how data - // are populated. - template - void PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort = true); - - // Copy values from 'src_literal' rooted at 'src_shape_index' into this - // literal rooted at 'dest_shape_index'. The subshape of this literal rooted - // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' - // rooted at 'src_shape_index', but need not be arrays. - Status CopyFrom(const LiteralSlice& src_literal, - const ShapeIndex& dest_shape_index = {}, - const ShapeIndex& src_shape_index = {}); - - // Similar to CopyFrom, but with move semantincs. The subshape of this literal - // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' - // (layouts and shapes must match), but need not be arrays. The memory - // allocated in this literal for the subshape at dest_shape_index is - // deallocated, and the respective buffers are replaced with those in - // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). - Status MoveFrom(Literal&& src_literal, - const ShapeIndex& dest_shape_index = {}); - - // Copies the values from src_literal, starting at src_base shape indexes, - // to this literal, starting at dest_base, where the copy size in each - // dimension is specified by copy_size. - // The src_literal and this literal must have the same primitive type, - // src_base+copy_size must fit the source literal dimensions, as well as - // dest_base+copy_size must fit the destination literal dimensions. - // Note: if either src_literal or this literal contains dimensions with zero - // element, then copy_size must be 0 in these dimensions while the - // corresponding base indices being 0. - // This literal and 'src_literal' must be arrays. - Status CopySliceFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); - - // Copies one element from src_literal[src_index] to (*this)[dest_index]. - Status CopyElementFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index); - - // Sets an element in the literal at the given index. The multi_index is - // CHECKed against the dimension sizes. - template - void Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value); - // Overloads of Set for array literals. CHECKs if the literal is not - // array-shaped and dense. - template - void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); - - // Appends the given element to the literal. If the elements are not appended - // in sorted order, then SortSparseElements should be called before calling - // other methods. This literal must have a sparse layout. - template - void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, - NativeT value, const ShapeIndex& shape_index = {}); - - // Sorts the elements in a sparse array. - void SortSparseElements(const ShapeIndex& shape_index = {}); - - // As Set(), but truncates `value` to the literal element type before storing. - // This literal must be an array. - Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, - int64 value); - - // Populate this literal with the given values. Examples: - // - // // Populate with floats. - // Array2D float_values = ... - // literal.PopulateR2FromArray2D(values); - // - // // Populate with int32s. - // literal.PopulateR2({{1, 2}, {3, 4}}); - // - // The shape and element type of this literal must match given values. For - // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 - // array of S32. - template - void PopulateR1(tensorflow::gtl::ArraySlice values); - void PopulateR1(const tensorflow::core::Bitmap& values); - template - void PopulateR2(std::initializer_list> values); - template - void PopulateFromArray(const Array& values); - template - void PopulateR2FromArray2D(const Array2D& values); - template - void PopulateR3FromArray3D(const Array3D& values); - template - void PopulateR4FromArray4D(const Array4D& values); - - // Populates literal values by calling the generator function for every cell - // in this literal object. - // - // generator must be a callable of the type - // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. - // - // This literal must have a dense layout. - template - Status Populate(const FnType& generator); - - // A parallel version of Populate(). This can be used if the generator is - // thread-safe and the values for the shape's different elements are - // independent. - template - Status PopulateParallel(const FnType& generator); - - // Fills this literal with the given value. - template - void PopulateWithValue(NativeT value); - - // Factory methods below. - // - - // Serialize from a proto. - static StatusOr> CreateFromProto( - const LiteralProto& proto); + static Literal GetFirstScalarLiteral(const LiteralSlice& literal); // Creates a new literal of a given rank. To minimize ambiguity (for users // and the compiler) these CreateR[0-2] methods should explicitly specify the @@ -889,7 +223,7 @@ class Literal : public LiteralBase { // As above, but intended to be invoked with move semantics; i.e. // // std::vector> elements = ...; - // auto result = Literal::MakeTupleOwned(std::move(elements)); + // auto result = LiteralUtil::MakeTupleOwned(std::move(elements)); // // This would have been declared as an overload, but there is ambiguity // in invocation between the above signature and this one. @@ -899,7 +233,7 @@ class Literal : public LiteralBase { // This overload lets you pass a braced list of unique_ptrs to // MakeTupleOwned: // - // Literal::MakeTupleOwned(Literal::CreateR1(...), ...). + // LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...). // // Simply relying on the MakeTupleOwned(std::vector>) // overload doesn't work because std::initializer_list's elements are always @@ -917,18 +251,8 @@ class Literal : public LiteralBase { return MakeTupleOwned(std::move(v)); } - // Returns a vector containing the tuple elements of this Literal as separate - // Literals. This Literal must be tuple-shaped and can be a nested tuple. The - // elements are moved into the new Literals; no data is copied. Upon return - // this Literal is set to a nil shape (empty tuple) - std::vector DecomposeTuple(); - - // This operation is the inverse of DecomposeTuple. The given elements are - // moved into the tuple elements of a new tuple-shaped Literal which is - // returned. Upon return, each of the Literals in 'elements' is set to a nil - // shape (empty tuple). - static Literal MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements); + // Create a constant token literal. Token types have no value. + static std::unique_ptr CreateToken(); // Creates a new Literal object with its values havings the primitive_type // type, and with dimensions defined by the dimensions parameter. @@ -997,194 +321,12 @@ class Literal : public LiteralBase { // dimension 1 equal to 8. static string MultiIndexAsString( tensorflow::gtl::ArraySlice multi_index); - - private: - // Recursively sets the subshapes and buffers of all subpieces rooted at - // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in - // the shape. - void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); - - // Returns the piece at the given ShapeIndex. - Piece& piece(const ShapeIndex& shape_index) { - return const_cast(LiteralBase::piece(shape_index)); - } - - Piece& root_piece() const override { return *root_piece_; }; - - // Internal template helper for the Literal::CopySliceFrom(), matching its - // arguments one by one. - template - Status CopySliceFromInternal(const LiteralBase& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); - - // Utility structure which is used to create the optimal configuration for - // a ShapeUtil::ForEachIndex() scan across two literals. - struct StrideConfig { - StrideConfig(const Shape& source_shape, const Shape& dest_shape, - tensorflow::gtl::ArraySlice dimensions); - - // The dimensions of the stride operation. Essentially every dimension - // will be iterated from base[i] to base[i]+dimensions[i], in step[i] - // steps. - tensorflow::gtl::ArraySlice dimensions; - DimensionVector base; - DimensionVector step; - int64 minor_dimension = 0; - // The size of the strides for source and destination. One of the two - // (the one looping through its most minor dimension) will be 1, while - // the other will be the stride size at the dimension matching the other - // shape most minor dimension being scanned. - int64 dest_stride = 1; - int64 source_stride = 1; - // The size of the inner loop on the most minor dimension. - int64 minor_loop_size = 1; - }; - - // Literal class always owns the shape. The parent class borrows this shape. - std::unique_ptr shape_; - - Piece* root_piece_ = nullptr; - - // Implementation details shared between Populate() and PopulateParallel() - template - Status PopulateInternal(const FnType& generator, bool parallel); - - // Deallocate the buffers held by this literal. - void DeallocateBuffers(); - - friend class LiteralBase; -}; -std::ostream& operator<<(std::ostream& out, const Literal& literal); - -// A read-only view of a Literal. A LiteralSlice contains pointers to shape and -// literal buffers always owned by others. -class LiteralSlice : public LiteralBase { - public: - LiteralSlice() : LiteralBase() {} - - // Implicit conversion constructors. - LiteralSlice(const LiteralBase& literal); - LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root); - - private: - const Piece& root_piece() const override { return *root_piece_; }; - - const Piece* root_piece_; // Not owned. -}; - -// A read-only Literal where the underlying buffers are never owned by this -// class. -class BorrowingLiteral : public LiteralBase { - public: - BorrowingLiteral() : LiteralBase() {} - - // 'src_buf_ptr' is not owned by this class and must outlive the - // lifetime of this class. It points to an appropirately sized buffer with - // data interpretered as indicated by 'shape'. - // This constructor is only used for array shapes. - BorrowingLiteral(const char* src_buf_ptr, const Shape& shape); - // Similar as above, except to be used for constructing non-nested tuples. - BorrowingLiteral(tensorflow::gtl::ArraySlice src_buf_ptrs, - const Shape& shape); - // TODO(b/79707221): adding constructors for nested tuples as well. - - private: - // Recursively builds the subtree for the given piece and sets the subshapes - // of the given piece with the given shape. - void BuildPieceSubtree(const Shape& shape, Piece* piece); - - // Accessor for the root piece of this literal. - const Piece& root_piece() const override { return root_piece_; }; - Piece root_piece_; - - // Shape of this literal. Stored as unique_ptr so such that the (default) - // move construction of this class would be trivially correct: the pointer to - // Shape root_piece_ stores will still point to the correct address. - std::unique_ptr shape_; }; -template -tensorflow::gtl::ArraySlice LiteralBase::Piece::data() const { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - CHECK_EQ(subshape().element_type(), - primitive_util::NativeToPrimitiveType()) - << "Attempting to access " - << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) - << " type, but literal element type is " - << PrimitiveType_Name(subshape().element_type()); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(buffer()), element_count()); -} - -template -tensorflow::gtl::MutableArraySlice LiteralBase::Piece::data() { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - CHECK_EQ(subshape().element_type(), - primitive_util::NativeToPrimitiveType()) - << "Attempting to access " - << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) - << " type, but literal element type is " - << PrimitiveType_Name(subshape().element_type()); - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(buffer()), element_count()); -} - -template -NativeT LiteralBase::Piece::Get( - tensorflow::gtl::ArraySlice multi_index) const { - CHECK(LayoutUtil::IsDenseArray(subshape())); - return data()[IndexUtil::MultidimensionalIndexToLinearIndex( - subshape(), multi_index)]; -} - -template -void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index, - NativeT value) { - CHECK(LayoutUtil::IsDenseArray(subshape())); - data()[IndexUtil::MultidimensionalIndexToLinearIndex( - subshape(), multi_index)] = value; -} - -template -tensorflow::gtl::ArraySlice LiteralBase::data( - const ShapeIndex& shape_index) const { - return piece(shape_index).data(); -} - -template -tensorflow::gtl::MutableArraySlice Literal::data( - const ShapeIndex& shape_index) { - return piece(shape_index).data(); -} - -template -inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const { - return piece(shape_index).Get(multi_index); -} - -template -inline NativeT LiteralBase::Get( - tensorflow::gtl::ArraySlice multi_index) const { - return root_piece().Get(multi_index); -} - -template -inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value) { - return piece(shape_index).Set(multi_index, value); -} - -template -inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - NativeT value) { - return root_piece().Set(multi_index, value); -} +std::ostream& operator<<(std::ostream& out, const Literal& literal); template -/* static */ std::unique_ptr Literal::CreateR0(NativeT value) { +/* static */ std::unique_ptr LiteralUtil::CreateR0(NativeT value) { auto literal = MakeUnique(ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), {})); literal->Set({}, value); @@ -1192,7 +334,7 @@ template } template -/* static */ std::unique_ptr Literal::CreateR1( +/* static */ std::unique_ptr LiteralUtil::CreateR1( tensorflow::gtl::ArraySlice values) { auto literal = MakeUnique( ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), @@ -1202,7 +344,7 @@ template } template -/* static */ std::unique_ptr Literal::CreateR2WithLayout( +/* static */ std::unique_ptr LiteralUtil::CreateR2WithLayout( std::initializer_list> values, const Layout& layout) { auto literal = MakeUnique(ShapeUtil::MakeShapeWithLayout( @@ -1215,13 +357,13 @@ template } template -/* static */ std::unique_ptr Literal::CreateR2( +/* static */ std::unique_ptr LiteralUtil::CreateR2( std::initializer_list> values) { return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } template -/* static */ std::unique_ptr Literal::CreateR3WithLayout( +/* static */ std::unique_ptr LiteralUtil::CreateR3WithLayout( std::initializer_list>> values, const Layout& layout) { @@ -1246,14 +388,14 @@ template } template -/* static */ std::unique_ptr Literal::CreateR3( +/* static */ std::unique_ptr LiteralUtil::CreateR3( std::initializer_list>> values) { return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); } template -/* static */ std::unique_ptr Literal::CreateR4WithLayout( +/* static */ std::unique_ptr LiteralUtil::CreateR4WithLayout( std::initializer_list>>> values, @@ -1284,7 +426,7 @@ template } template -/* static */ std::unique_ptr Literal::CreateSparse( +/* static */ std::unique_ptr LiteralUtil::CreateSparse( tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, tensorflow::gtl::ArraySlice values, bool sort) { int64 num_elements = values.size(); @@ -1299,7 +441,7 @@ template } template -/* static */ std::unique_ptr Literal::CreateR4( +/* static */ std::unique_ptr LiteralUtil::CreateR4( std::initializer_list>>> values) { @@ -1307,7 +449,7 @@ template } template -/* static */ std::unique_ptr Literal::CreateFromArrayWithLayout( +/* static */ std::unique_ptr LiteralUtil::CreateFromArrayWithLayout( const Array& values, const Layout& layout) { auto literal = MakeUnique(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), values.dimensions(), @@ -1317,38 +459,40 @@ template } template -/* static */ std::unique_ptr Literal::CreateFromArray( +/* static */ std::unique_ptr LiteralUtil::CreateFromArray( const Array& values) { return CreateFromArrayWithLayout( values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); } template -/* static */ std::unique_ptr Literal::CreateR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { +/* static */ std::unique_ptr +LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } template -/* static */ std::unique_ptr Literal::CreateR2FromArray2D( +/* static */ std::unique_ptr LiteralUtil::CreateR2FromArray2D( const Array2D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr Literal::CreateR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout) { +/* static */ std::unique_ptr +LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } template -/* static */ std::unique_ptr Literal::CreateR3FromArray3D( +/* static */ std::unique_ptr LiteralUtil::CreateR3FromArray3D( const Array3D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr Literal::CreateR3Projected( +/* static */ std::unique_ptr LiteralUtil::CreateR3Projected( std::initializer_list> values, int64 projection) { int64 dim0_size = projection; @@ -1373,7 +517,7 @@ template } template -/* static */ std::unique_ptr Literal::CreateR4Projected( +/* static */ std::unique_ptr LiteralUtil::CreateR4Projected( std::initializer_list> values, int64 projection_p, int64 projection_z) { int64 dim0_size = projection_p; @@ -1401,49 +545,21 @@ template } template -/* static */ std::unique_ptr Literal::CreateR4FromArray4D( +/* static */ std::unique_ptr LiteralUtil::CreateR4FromArray4D( const Array4D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr Literal::CreateR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout) { +/* static */ std::unique_ptr +LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } -template -NativeT LiteralBase::GetFirstElement() const { - return data().at(0); -} - -template -NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index) const { - CHECK( - LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index))); - return data(shape_index)[sparse_element_number]; -} - -template -void Literal::AppendSparseElement( - tensorflow::gtl::ArraySlice multi_index, NativeT value, - const ShapeIndex& shape_index) { - Piece& p = piece(shape_index); - const Shape& subshape = p.subshape(); - CHECK(LayoutUtil::IsSparseArray(subshape)); - int64 rank = ShapeUtil::Rank(subshape); - CHECK_EQ(multi_index.size(), rank); - int64 last_element = p.sparse_indices()->index_count(); - CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout())); - p.sparse_indices()->Append(multi_index); - CHECK_LT(last_element, p.data().size()); - p.data()[last_element] = value; -} - // Returns an identity matrix (rank 2) with the given row and column count. template -/* static */ std::unique_ptr Literal::MakeIdentityR2(int64 size) { +/* static */ std::unique_ptr LiteralUtil::MakeIdentityR2(int64 size) { Array2D array(size, size, 0); for (int64 i = 0; i < size; ++i) { array(i, i) = 1; @@ -1452,174 +568,8 @@ template } template -void LiteralBase::EachCell( - std::function indices, - NativeT value)> - per_cell) const { - if (ShapeUtil::HasZeroElements(shape())) { - return; - } - std::vector indices(ShapeUtil::Rank(shape()), 0); - do { - per_cell(indices, Get(indices)); - } while (IndexUtil::BumpIndices(shape(), &indices)); -} - -template -inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); - CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); - CHECK_EQ(shape().element_type(), - primitive_util::NativeToPrimitiveType()); - for (int64 i = 0; i < values.size(); ++i) { - Set({i}, values[i]); - } -} - -template -void Literal::PopulateR2( - std::initializer_list> values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 2); - CHECK_EQ(shape().element_type(), - primitive_util::NativeToPrimitiveType()); - - const int64 dim0_size = values.size(); - const int64 dim1_size = values.begin()->size(); - CHECK_EQ(dim0_size, shape().dimensions(0)); - CHECK_EQ(dim1_size, shape().dimensions(1)); - - int64 dim0 = 0; - for (auto inner_list : values) { - int64 dim1 = 0; - for (auto value : inner_list) { - Set({dim0, dim1}, value); - ++dim1; - } - CHECK_EQ(dim1_size, dim1); - ++dim0; - } -} - -template -void Literal::PopulateFromArray(const Array& values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(shape().element_type(), - primitive_util::NativeToPrimitiveType()); - CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions()); - for (int dim = 0; dim < values.num_dimensions(); ++dim) { - CHECK_EQ(values.dim(dim), shape().dimensions(dim)); - } - values.Each([this](tensorflow::gtl::ArraySlice indices, - NativeT value) { this->Set(indices, value); }); -} - -template -void Literal::PopulateR2FromArray2D(const Array2D& values) { - PopulateFromArray(values); -} - -template -void Literal::PopulateR3FromArray3D(const Array3D& values) { - PopulateFromArray(values); -} - -template -void Literal::PopulateR4FromArray4D(const Array4D& values) { - PopulateFromArray(values); -} - -template -void Literal::PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort) { - CHECK(LayoutUtil::IsSparseArray(shape())); - int rank = ShapeUtil::Rank(shape()); - CHECK_EQ(indices.rank(), rank); - int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout()); - CHECK_LE(indices.max_indices(), max_elements); - int64 num_elements = values.size(); - CHECK_LE(num_elements, max_elements); - CHECK_EQ(num_elements, indices.index_count()); - auto root_data = root_piece().data(); - // Piece::data() returns an ArraySlice of size equal to the number of indices - // in the SparseIndexArray. So there is no need to adjust the size of the data - // here. It is enough to just copy the incoming values into the data buffer. - std::copy(values.begin(), values.end(), root_data.begin()); - *this->root_piece().sparse_indices() = std::move(indices); - if (sort) { - auto root_data = this->root_piece().data(); - this->root_piece().sparse_indices()->SortWithValues(root_data); - } - DCHECK(this->root_piece().sparse_indices()->Validate(shape())); -} - -template -Status Literal::PopulateInternal(const FnType& generator, bool parallel) { - const Shape& this_shape = shape(); - const int64 rank = ShapeUtil::Rank(this_shape); - TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); - TF_RET_CHECK(this_shape.element_type() == - primitive_util::NativeToPrimitiveType()); - tensorflow::gtl::MutableArraySlice literal_data = data(); - if (rank > 0) { - StrideConfig stride_config(this_shape, this_shape, - AsInt64Slice(this_shape.dimensions())); - int64 minor_dimension_size = - ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); - - auto init_function = [&](tensorflow::gtl::ArraySlice indexes) { - DimensionVector minor_scan_indexes(rank, 0); - const int64 index = - IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); - std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); - for (int64 i = 0; i < minor_dimension_size; ++i) { - minor_scan_indexes[stride_config.minor_dimension] = i; - literal_data.at(index + i) = generator(minor_scan_indexes); - } - }; - if (parallel) { - ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base, - stride_config.dimensions, - stride_config.step, init_function); - } else { - ShapeUtil::ForEachIndex( - this_shape, stride_config.base, stride_config.dimensions, - stride_config.step, - [&init_function](tensorflow::gtl::ArraySlice indexes) { - init_function(indexes); - return true; - }); - } - } else { - // For scalars. - literal_data.at(0) = generator({}); - } - return Status::OK(); -} -template -Status Literal::Populate(const FnType& generator) { - return PopulateInternal(generator, /*parallel=*/false); -} - -template -Status Literal::PopulateParallel(const FnType& generator) { - return PopulateInternal(generator, /*parallel=*/true); -} - -template -void Literal::PopulateWithValue(NativeT value) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(shape().element_type(), - primitive_util::NativeToPrimitiveType()); - for (NativeT& element : data()) { - element = value; - } -} - -template -/* static */ std::unique_ptr Literal::CreateFullWithDescendingLayout( +/* static */ std::unique_ptr +LiteralUtil::CreateFullWithDescendingLayout( tensorflow::gtl::ArraySlice dimensions, NativeT value) { auto literal = MakeUnique(ShapeUtil::MakeShapeWithDescendingLayout( primitive_util::NativeToPrimitiveType(), dimensions)); @@ -1627,44 +577,9 @@ template return literal; } -template -std::unique_ptr LiteralBase::Replicate(int64 times) const { - DimensionVector bounds = {times}; - bounds.reserve(shape().dimensions_size() + 1); - for (int64 bound : shape().dimensions()) { - bounds.push_back(bound); - } - auto literal = - MakeUnique(ShapeUtil::MakeShape(shape().element_type(), bounds)); - int64 elements = ShapeUtil::ElementsIn(literal->shape()); - if (elements == 0) { - return literal; - } - - DimensionVector output_indices(bounds.size(), 0); - tensorflow::gtl::ArraySlice input_indices = output_indices; - input_indices.remove_prefix(1); - - bool done = false; - while (!done) { - const auto element = Get(input_indices); - literal->Set(output_indices, element); - - done = true; - for (int n = 0; n < output_indices.size(); ++n) { - ++output_indices[n]; - if (output_indices[n] < bounds[n]) { - done = false; - break; - } - output_indices[n] = 0; - } - } - return literal; -} - template -/* static */ StatusOr> Literal::CreateRandomLiteral( +/* static */ StatusOr> +LiteralUtil::CreateRandomLiteral( const Shape& shape, const std::function)>& generator) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; @@ -1678,8 +593,9 @@ template } template -/* static */ StatusOr> Literal::CreateRandomLiteral( - const Shape& shape, E* engine, T mean, T stddev) { +/* static */ StatusOr> +LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, + T stddev) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; std::normal_distribution generator(mean, stddev); return CreateRandomLiteral( @@ -1689,8 +605,8 @@ template } template -/* static */ StatusOr> Literal::CreateRandomLiteral( - const Shape& shape, T mean, T stddev) { +/* static */ StatusOr> +LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) { std::minstd_rand0 engine; return CreateRandomLiteral(shape, &engine, mean, stddev); } diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc index fed0e58e66a04df2ff9554cb0dd0053b7c669803..69ef4f7a2f3ea559a334a11cbe8392b610742bab 100644 --- a/tensorflow/compiler/xla/metric_table_report.cc +++ b/tensorflow/compiler/xla/metric_table_report.cc @@ -134,8 +134,7 @@ void MetricTableReport::AppendHeader() { void MetricTableReport::AppendCategoryTable() { const std::vector categories = MakeCategories(&entries_); - AppendLine("********** categories table **********"); - AppendLine("The left hand side numbers are ", metric_name_, "."); + AppendLine("********** categories table for ", metric_name_, " **********"); AppendLine(); double metric_sum = UnaccountedMetric(); @@ -185,8 +184,8 @@ void MetricTableReport::AppendCategoryTable() { } void MetricTableReport::AppendEntryTable() { - AppendLine("********** ", entry_name_, " table **********"); - AppendLine("The left hand side numbers are ", metric_name_, "."); + AppendLine("********** ", entry_name_, " table for ", metric_name_, + " **********"); AppendLine(); double metric_sum = UnaccountedMetric(); diff --git a/tensorflow/compiler/xla/overflow_util.h b/tensorflow/compiler/xla/overflow_util.h new file mode 100644 index 0000000000000000000000000000000000000000..8657d3a4bfa992b9ca0619f24923fd4542eed894 --- /dev/null +++ b/tensorflow/compiler/xla/overflow_util.h @@ -0,0 +1,50 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_ + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Multiply two nonnegative int64's, returning negative for overflow +inline int64 MultiplyWithoutOverflow(const int64 x, const int64 y) { + // Multiply in uint64 rather than int64 since signed overflow is undefined. + // Negative values will wrap around to large unsigned values in the casts + // (see section 4.7 [conv.integral] of the C++14 standard). + const uint64 ux = x; + const uint64 uy = y; + const uint64 uxy = ux * uy; + + // Check if we overflow uint64, using a cheap check if both inputs are small + if (TF_PREDICT_FALSE((ux | uy) >> 32 != 0)) { + // Ensure nonnegativity. Note that negative numbers will appear "large" + // to the unsigned comparisons above. + CHECK(x >= 0 && y >= 0); + + // Otherwise, detect overflow using a division + if (ux != 0 && uxy / ux != uy) return -1; + } + + // Cast back to signed. Any negative value will signal an error. + return static_cast(uxy); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_ diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 857aae0a7982a57bb3057a6f267f5f033a0fdde4..6b7fd10d63f8f97b0e0bf7570488c06323368d75 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h index 45a9fe012784d3e4168e7549240dec962aa1a17a..98dccaa9a246520bf60217b96d67a13a24c34b4a 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.h +++ b/tensorflow/compiler/xla/packed_literal_reader.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index 143c9a2366be5786b7ef2148580caeb97d67d2d8..b16147e3be71771269d8b7a18528bef3a8c72d99 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -85,5 +85,10 @@ PrimitiveType ComplexComponentType(PrimitiveType complex_type) { } } +bool IsArrayType(PrimitiveType primitive_type) { + return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && + primitive_type != OPAQUE && primitive_type != TOKEN; +} + } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index b26a10ade63a5dad3bf8f9f3a2a33c3c5e67bdb2..889e9a1ceca675689406d255d348c82c398563aa 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -133,6 +133,9 @@ bool IsUnsignedIntegralType(PrimitiveType type); bool IsIntegralType(PrimitiveType type); +// Returns true if values of the given primitive type are held in array shapes. +bool IsArrayType(PrimitiveType primitive_type); + // Returns the number of bits in the representation for a given type. int BitWidth(PrimitiveType type); diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 83834c1ff65ea2f9989fe08279c29056d9070adb..c8f2d65c223ccfe20862954c224d016cca421812 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -33,6 +33,7 @@ cc_library( srcs = ["numpy_bridge.cc"], hdrs = ["numpy_bridge.h"], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -52,9 +53,9 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", - "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", @@ -70,7 +71,7 @@ tf_py_wrap_cc( deps = [ ":local_computation_builder", ":numpy_bridge", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:cpu_plugin", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 445cee1aa7b462f7ae2b6b0771ff57f0c8f3db99..8246f76d3443d58f4174cc4f86100f54d6b46928 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/python/local_computation_builder.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" @@ -174,73 +176,73 @@ StatusOr> CompiledLocalComputation::Execute( GetReplicaCount()); for (int replica = 0; replica < GetReplicaCount(); ++replica) { - pool.Schedule([this, client, replica, &arguments, &shapes_with_layout, - &results] { - StatusOr device_ordinal_status = - client->ReplicaNumberToDeviceOrdinal(replica); - if (!device_ordinal_status.ok()) { - results[replica] = device_ordinal_status.status(); - return; - } - const int device_ordinal = device_ordinal_status.ValueOrDie(); - VLOG(3) << "Replica " << replica - << " mapped to device ordinal for execution: " - << device_ordinal; - - // Transfer arguments in - std::vector scoped_buffers; - scoped_buffers.reserve(arguments.size()); - for (int i = 0; i < arguments.size(); ++i) { - const Literal& argument = arguments[i]; - const tensorflow::gtl::optional& shape_with_layout = - shapes_with_layout[i]; - - StatusOr pushed; - if (shape_with_layout) { - std::unique_ptr relaid = - argument.Relayout(shape_with_layout.value()); - pushed = ToBuffer(client, device_ordinal, *relaid); - } else { - pushed = ToBuffer(client, device_ordinal, argument); - } - if (!pushed.ok()) { - results[replica] = pushed.status(); - return; - } - - scoped_buffers.push_back(std::move(pushed).ValueOrDie()); - } - - // Execute - std::vector argument_buffers; - argument_buffers.reserve(scoped_buffers.size()); - for (auto& buffer : scoped_buffers) { - argument_buffers.push_back(&buffer); - } - - DeviceAssignment device_assignment = - client->backend() - .computation_placer() - ->AssignDevices(GetReplicaCount(), /*computation_count=*/1) - .ConsumeValueOrDie(); - - ExecutableRunOptions options; - options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); - options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment); - StatusOr result_buffer_status = - executable_->Run(argument_buffers, options); - if (!result_buffer_status.ok()) { - results[replica] = result_buffer_status.status(); - return; - } - - // Transfer result out - results[replica] = client->ShapedBufferToLiteral( - std::move(result_buffer_status).ValueOrDie()); - }); + pool.Schedule( + [this, client, replica, &arguments, &shapes_with_layout, &results] { + StatusOr device_ordinal_status = + client->ReplicaNumberToDeviceOrdinal(replica); + if (!device_ordinal_status.ok()) { + results[replica] = device_ordinal_status.status(); + return; + } + const int device_ordinal = device_ordinal_status.ValueOrDie(); + VLOG(3) << "Replica " << replica + << " mapped to device ordinal for execution: " + << device_ordinal; + + // Transfer arguments in + std::vector scoped_buffers; + scoped_buffers.reserve(arguments.size()); + for (int i = 0; i < arguments.size(); ++i) { + const Literal& argument = arguments[i]; + const tensorflow::gtl::optional& shape_with_layout = + shapes_with_layout[i]; + + StatusOr pushed; + if (shape_with_layout) { + std::unique_ptr relaid = + argument.Relayout(shape_with_layout.value()); + pushed = ToBuffer(client, device_ordinal, *relaid); + } else { + pushed = ToBuffer(client, device_ordinal, argument); + } + if (!pushed.ok()) { + results[replica] = pushed.status(); + return; + } + + scoped_buffers.push_back(std::move(pushed).ValueOrDie()); + } + + // Execute + std::vector argument_buffers; + argument_buffers.reserve(scoped_buffers.size()); + for (auto& buffer : scoped_buffers) { + argument_buffers.push_back(&buffer); + } + + DeviceAssignment device_assignment = + client->backend() + .computation_placer() + ->AssignDevices(GetReplicaCount(), /*computation_count=*/1) + .ConsumeValueOrDie(); + + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client->backend().memory_allocator()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); + StatusOr result_buffer_status = + executable_->Run(argument_buffers, options); + if (!result_buffer_status.ok()) { + results[replica] = result_buffer_status.status(); + return; + } + + // Transfer result out + results[replica] = client->ShapedBufferToLiteral( + std::move(result_buffer_status).ValueOrDie()); + }); } } @@ -341,14 +343,11 @@ StatusOr LocalComputationBuilder::Build() { LocalOp LocalComputationBuilder::Parameter(int64 parameter_number, const Shape& shape, const string& name) { - return builder_.Parameter(parameter_number, shape, name); + return xla::Parameter(&builder_, parameter_number, shape, name); } -std::unique_ptr LocalComputationBuilder::GetShape( - const LocalOp& operand) { - auto result = MakeUnique(); - *result = builder_.GetShape(operand.op()).ValueOrDie(); - return result; +StatusOr LocalComputationBuilder::GetShape(const LocalOp& operand) { + return builder_.GetShape(operand.op()); } StatusOr LocalComputationBuilder::GetReturnValueShape() { @@ -357,72 +356,70 @@ StatusOr LocalComputationBuilder::GetReturnValueShape() { } LocalOp LocalComputationBuilder::Infeed(const Shape& shape) { - return builder_.Infeed(shape); + return xla::Infeed(&builder_, shape); } void LocalComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config) { - builder_.Outfeed(operand.op(), shape, outfeed_config); + xla::Outfeed(operand.op(), shape, outfeed_config); } LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { - return builder_.ConstantLiteral(literal); + return xla::ConstantLiteral(&builder_, literal); } LocalOp LocalComputationBuilder::Broadcast( const LocalOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - return builder_.Broadcast(operand.op(), broadcast_sizes); + return xla::Broadcast(operand.op(), broadcast_sizes); } LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, const LocalOp& padding_value, const PaddingConfig& padding_config) { - return builder_.Pad(operand.op(), padding_value.op(), padding_config); + return xla::Pad(operand.op(), padding_value.op(), padding_config); } LocalOp LocalComputationBuilder::Reshape( const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - return builder_.Reshape(operand.op(), dimensions, new_sizes); + return xla::Reshape(operand.op(), dimensions, new_sizes); } LocalOp LocalComputationBuilder::Collapse( const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { - return builder_.Collapse(operand.op(), dimensions); + return xla::Collapse(operand.op(), dimensions); } LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { - return builder_.CrossReplicaSum(operand.op()); + return xla::CrossReplicaSum(operand.op()); } LocalOp LocalComputationBuilder::Slice( const LocalOp& operand, tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - return builder_.Slice(operand.op(), start_indices, limit_indices, strides); + return xla::Slice(operand.op(), start_indices, limit_indices, strides); } LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno) { - return builder_.SliceInDim(operand.op(), start_index, limit_index, stride, - dimno); + return xla::SliceInDim(operand.op(), start_index, limit_index, stride, dimno); } LocalOp LocalComputationBuilder::DynamicSlice( const LocalOp& operand, const LocalOp& start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - return builder_.DynamicSlice(operand.op(), start_indices.op(), slice_sizes); + return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } LocalOp LocalComputationBuilder::DynamicUpdateSlice( const LocalOp& operand, const LocalOp& update, const LocalOp& start_indices) { - return builder_.DynamicUpdateSlice(operand.op(), update.op(), - start_indices.op()); + return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op()); } LocalOp LocalComputationBuilder::ConcatInDim( @@ -432,7 +429,7 @@ LocalOp LocalComputationBuilder::ConcatInDim( for (const auto& op : operands) { xla_ops.push_back(op.op()); } - return builder_.ConcatInDim(xla_ops, dimension); + return xla::ConcatInDim(&builder_, xla_ops, dimension); } LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( @@ -442,7 +439,7 @@ LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( tensorflow::gtl::ArraySlice> padding, const LocalOp& source, const LocalOp& init_value, const LocalComputation& scatter) { - return builder_.SelectAndScatterWithGeneralPadding( + return xla::SelectAndScatterWithGeneralPadding( operand.op(), select.computation(), window_dimensions, window_strides, padding, source.op(), init_value.op(), scatter.computation()); } @@ -455,22 +452,22 @@ LocalOp LocalComputationBuilder::Tuple( xla_ops.push_back(op.op()); } - return builder_.Tuple(xla_ops); + return xla::Tuple(&builder_, xla_ops); } LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data, int64 index) { - return builder_.GetTupleElement(tuple_data.op(), index); + return xla::GetTupleElement(tuple_data.op(), index); } LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { - return builder_.Dot(lhs.op(), rhs.op()); + return xla::Dot(lhs.op(), rhs.op()); } LocalOp LocalComputationBuilder::DotGeneral( const LocalOp& lhs, const LocalOp& rhs, const DotDimensionNumbers& dimension_numbers) { - return builder_.DotGeneral(lhs.op(), rhs.op(), dimension_numbers); + return xla::DotGeneral(lhs.op(), rhs.op(), dimension_numbers); } LocalOp LocalComputationBuilder::ConvGeneralDilated( @@ -480,14 +477,18 @@ LocalOp LocalComputationBuilder::ConvGeneralDilated( tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers) { - return builder_.ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, - padding, lhs_dilation, rhs_dilation, - dimension_numbers); + return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding, + lhs_dilation, rhs_dilation, dimension_numbers); } LocalOp LocalComputationBuilder::ConvertElementType( const LocalOp& operand, PrimitiveType new_element_type) { - return builder_.ConvertElementType(operand.op(), new_element_type); + return xla::ConvertElementType(operand.op(), new_element_type); +} + +LocalOp LocalComputationBuilder::BitcastConvertType( + const LocalOp& operand, PrimitiveType new_element_type) { + return xla::BitcastConvertType(operand.op(), new_element_type); } LocalOp LocalComputationBuilder::Call( @@ -498,46 +499,39 @@ LocalOp LocalComputationBuilder::Call( for (const auto& op : operands) { xla_ops.push_back(op.op()); } - return builder_.Call(local_computation.computation(), xla_ops); + return xla::Call(&builder_, local_computation.computation(), xla_ops); } LocalOp LocalComputationBuilder::Transpose( const LocalOp& operand, tensorflow::gtl::ArraySlice permutation) { - return builder_.Transpose(operand.op(), permutation); + return xla::Transpose(operand.op(), permutation); } LocalOp LocalComputationBuilder::Rev( const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { - return builder_.Rev(operand.op(), dimensions); + return xla::Rev(operand.op(), dimensions); } LocalOp LocalComputationBuilder::Map( tensorflow::gtl::ArraySlice operands, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { + tensorflow::gtl::ArraySlice dimensions) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { xla_ops.push_back(op.op()); } - std::vector static_xla_ops; - static_xla_ops.reserve(static_operands.size()); - for (const auto& op : static_operands) { - static_xla_ops.push_back(op.op()); - } - - return builder_.Map(xla_ops, local_computation.computation(), dimensions, - static_xla_ops); + return xla::Map(&builder_, xla_ops, local_computation.computation(), + dimensions); } LocalOp LocalComputationBuilder::Reduce( const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice dimensions_to_reduce) { - return builder_.Reduce(operand.op(), init_value.op(), - local_computation.computation(), dimensions_to_reduce); + return xla::Reduce(operand.op(), init_value.op(), + local_computation.computation(), dimensions_to_reduce); } LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( @@ -546,7 +540,7 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding) { - return builder_.ReduceWindowWithGeneralPadding( + return xla::ReduceWindowWithGeneralPadding( operand.op(), init_value.op(), local_computation.computation(), window_dimensions, window_strides, padding); } @@ -554,27 +548,27 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, const LocalOp& sigma, const Shape& shape) { - return builder_.RngNormal(mu.op(), sigma.op(), shape); + return xla::RngNormal(mu.op(), sigma.op(), shape); } LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape) { - return builder_.RngUniform(a.op(), b.op(), shape); + return xla::RngUniform(a.op(), b.op(), shape); } LocalOp LocalComputationBuilder::While(const LocalComputation& condition, const LocalComputation& body, const LocalOp& init) { - return builder_.While(condition.computation(), body.computation(), init.op()); + return xla::While(condition.computation(), body.computation(), init.op()); } LocalOp LocalComputationBuilder::Conditional( const LocalOp& predicate, const LocalOp& true_operand, const LocalComputation& true_computation, const LocalOp& false_operand, const LocalComputation& false_computation) { - return builder_.Conditional( - predicate.op(), true_operand.op(), true_computation.computation(), - false_operand.op(), false_computation.computation()); + return xla::Conditional(predicate.op(), true_operand.op(), + true_computation.computation(), false_operand.op(), + false_computation.computation()); } StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { @@ -590,7 +584,7 @@ StatusOr LocalComputationBuilder::BuildConstantSubGraph( #define _FORWARD(method_name, return_sig, args_sig, args) \ return_sig LocalComputationBuilder::method_name args_sig { \ - return builder_.method_name args; \ + return xla::method_name args; \ } #define _FORWARD_UNOP(method_name) \ @@ -624,6 +618,13 @@ _FORWARD_BINOP(Max) _FORWARD_BINOP(Min) _FORWARD_BINOP(And) _FORWARD_BINOP(Or) +_FORWARD_BINOP(Xor) +_FORWARD_BINOP(ShiftLeft) +_FORWARD_BINOP(ShiftRightArithmetic) +_FORWARD_BINOP(ShiftRightLogical) +_FORWARD_BINOP(Atan2) +_FORWARD_BINOP(Pow) +_FORWARD_BINOP(Complex) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) @@ -637,13 +638,30 @@ _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) _FORWARD_UNOP(Sin) _FORWARD_UNOP(Tanh) -_FORWARD_UNOP(SqrtF32) -_FORWARD_UNOP(SquareF32) -_FORWARD_BINOP(Pow) _FORWARD_UNOP(IsFinite) -_FORWARD_UNOP(ReciprocalF32) _FORWARD_UNOP(Neg) _FORWARD_UNOP(Sort) +_FORWARD_UNOP(Sqrt) +_FORWARD_UNOP(Rsqrt) +_FORWARD_UNOP(Square) +_FORWARD_UNOP(Reciprocal) +_FORWARD_UNOP(Erfc) +_FORWARD_UNOP(Erf) +_FORWARD_UNOP(ErfInv) +_FORWARD_UNOP(Lgamma) +_FORWARD_UNOP(Digamma) +_FORWARD_UNOP(Acos) +_FORWARD_UNOP(Asin) +_FORWARD_UNOP(Atan) +_FORWARD_UNOP(Tan) +_FORWARD_UNOP(Acosh) +_FORWARD_UNOP(Asinh) +_FORWARD_UNOP(Atanh) +_FORWARD_UNOP(Cosh) +_FORWARD_UNOP(Sinh) +_FORWARD_UNOP(Real) +_FORWARD_UNOP(Imag) +_FORWARD_UNOP(Conj) #undef _FORWARD #undef _FORWARD_UNOP diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 0da3964676e9c6729229686f38bb05c8b2427bff..a568c24c6376e1fe17f5e5a4f6626bf0970985a3 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -187,7 +187,7 @@ class LocalComputationBuilder { LocalOp Parameter(int64 parameter_number, const Shape& shape, const string& name); - std::unique_ptr GetShape(const LocalOp& operand); + StatusOr GetShape(const LocalOp& operand); // Returns the shape of the current return value for the computation. StatusOr GetReturnValueShape(); @@ -259,6 +259,9 @@ class LocalComputationBuilder { LocalOp ConvertElementType(const LocalOp& operand, PrimitiveType new_element_type); + LocalOp BitcastConvertType(const LocalOp& operand, + PrimitiveType new_element_type); + LocalOp Call(const LocalComputation& local_computation, tensorflow::gtl::ArraySlice operands); @@ -270,8 +273,7 @@ class LocalComputationBuilder { LocalOp Map(tensorflow::gtl::ArraySlice operands, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands); + tensorflow::gtl::ArraySlice dimensions); LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, @@ -333,6 +335,13 @@ class LocalComputationBuilder { _FORWARD_BINOP(Min) _FORWARD_BINOP(And) _FORWARD_BINOP(Or) + _FORWARD_BINOP(Xor) + _FORWARD_BINOP(ShiftLeft) + _FORWARD_BINOP(ShiftRightArithmetic) + _FORWARD_BINOP(ShiftRightLogical) + _FORWARD_BINOP(Atan2) + _FORWARD_BINOP(Pow) + _FORWARD_BINOP(Complex) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) @@ -346,13 +355,30 @@ class LocalComputationBuilder { _FORWARD_UNOP(Cos) _FORWARD_UNOP(Sin) _FORWARD_UNOP(Tanh) - _FORWARD_UNOP(SqrtF32) - _FORWARD_UNOP(SquareF32) - _FORWARD_BINOP(Pow) _FORWARD_UNOP(IsFinite) - _FORWARD_UNOP(ReciprocalF32) _FORWARD_UNOP(Neg) _FORWARD_UNOP(Sort) + _FORWARD_UNOP(Sqrt) + _FORWARD_UNOP(Rsqrt) + _FORWARD_UNOP(Square) + _FORWARD_UNOP(Reciprocal) + _FORWARD_UNOP(Erfc) + _FORWARD_UNOP(Erf) + _FORWARD_UNOP(ErfInv) + _FORWARD_UNOP(Lgamma) + _FORWARD_UNOP(Digamma) + _FORWARD_UNOP(Acos) + _FORWARD_UNOP(Asin) + _FORWARD_UNOP(Atan) + _FORWARD_UNOP(Tan) + _FORWARD_UNOP(Acosh) + _FORWARD_UNOP(Asinh) + _FORWARD_UNOP(Atanh) + _FORWARD_UNOP(Cosh) + _FORWARD_UNOP(Sinh) + _FORWARD_UNOP(Real) + _FORWARD_UNOP(Imag) + _FORWARD_UNOP(Conj) #undef _FORWARD #undef _FORWARD_UNOP diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 477df6fde25d0db760e08df9d335bd12e31ccb55..5d5a955bfee35b38a61b9a9f792c1b31259ce044 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -109,7 +109,7 @@ limitations under the License. // Must be included first #include "tensorflow/python/lib/core/numpy.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -957,6 +957,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Tuple; %unignore xla::swig::LocalComputationBuilder::GetTupleElement; %unignore xla::swig::LocalComputationBuilder::ConvertElementType; +%unignore xla::swig::LocalComputationBuilder::BitcastConvertType; %unignore xla::swig::LocalComputationBuilder::Call; %unignore xla::swig::LocalComputationBuilder::Transpose; %unignore xla::swig::LocalComputationBuilder::Rev; @@ -988,6 +989,10 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Min; %unignore xla::swig::LocalComputationBuilder::And; %unignore xla::swig::LocalComputationBuilder::Or; +%unignore xla::swig::LocalComputationBuilder::Xor; +%unignore xla::swig::LocalComputationBuilder::ShiftLeft; +%unignore xla::swig::LocalComputationBuilder::ShiftRightArithmetic; +%unignore xla::swig::LocalComputationBuilder::ShiftRightLogical; %unignore xla::swig::LocalComputationBuilder::Not; %unignore xla::swig::LocalComputationBuilder::Abs; %unignore xla::swig::LocalComputationBuilder::Exp; @@ -1001,13 +1006,33 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Cos; %unignore xla::swig::LocalComputationBuilder::Sin; %unignore xla::swig::LocalComputationBuilder::Tanh; -%unignore xla::swig::LocalComputationBuilder::SqrtF32; -%unignore xla::swig::LocalComputationBuilder::SquareF32; -%unignore xla::swig::LocalComputationBuilder::Pow; +%unignore xla::swig::LocalComputationBuilder::Atan2; %unignore xla::swig::LocalComputationBuilder::IsFinite; -%unignore xla::swig::LocalComputationBuilder::ReciprocalF32; +%unignore xla::swig::LocalComputationBuilder::Pow; %unignore xla::swig::LocalComputationBuilder::Neg; %unignore xla::swig::LocalComputationBuilder::Sort; +%unignore xla::swig::LocalComputationBuilder::Sqrt; +%unignore xla::swig::LocalComputationBuilder::Rsqrt; +%unignore xla::swig::LocalComputationBuilder::Square; +%unignore xla::swig::LocalComputationBuilder::Reciprocal; +%unignore xla::swig::LocalComputationBuilder::Erfc; +%unignore xla::swig::LocalComputationBuilder::Erf; +%unignore xla::swig::LocalComputationBuilder::ErfInv; +%unignore xla::swig::LocalComputationBuilder::Lgamma; +%unignore xla::swig::LocalComputationBuilder::Digamma; +%unignore xla::swig::LocalComputationBuilder::Acos; +%unignore xla::swig::LocalComputationBuilder::Asin; +%unignore xla::swig::LocalComputationBuilder::Atan; +%unignore xla::swig::LocalComputationBuilder::Tan; +%unignore xla::swig::LocalComputationBuilder::Acosh; +%unignore xla::swig::LocalComputationBuilder::Asinh; +%unignore xla::swig::LocalComputationBuilder::Atanh; +%unignore xla::swig::LocalComputationBuilder::Cosh; +%unignore xla::swig::LocalComputationBuilder::Sinh; +%unignore xla::swig::LocalComputationBuilder::Real; +%unignore xla::swig::LocalComputationBuilder::Imag; +%unignore xla::swig::LocalComputationBuilder::Conj; +%unignore xla::swig::LocalComputationBuilder::Complex; %unignore xla::swig::DestructureLocalShapedBufferTuple; %unignore xla::swig::DeleteLocalShapedBuffer; %unignore xla::swig::DeleteLocalComputation; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 68648a3a176363de69a56ecb8070f82862874e94..6f665faf61b25b23a32ce4d0a012543ba18d7e64 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/python/numpy_bridge.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/logging.h" @@ -49,6 +50,8 @@ int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { return NPY_FLOAT32; case F64: return NPY_FLOAT64; + case C64: + return NPY_COMPLEX64; case TUPLE: return NPY_OBJECT; default: @@ -82,6 +85,8 @@ PrimitiveType NumpyTypeToPrimitiveType(int np_type) { return F32; case NPY_FLOAT64: return F64; + case NPY_COMPLEX64: + return C64; case NPY_OBJECT: return TUPLE; default: @@ -103,6 +108,7 @@ bool NumpyTypeIsValid(int np_type) { case NPY_FLOAT16: case NPY_FLOAT32: case NPY_FLOAT64: + case NPY_COMPLEX64: case NPY_OBJECT: return true; default: @@ -374,7 +380,7 @@ StatusOr> XlaLiteralFromPyObject(PyObject* o) { TF_ASSIGN_OR_RETURN(auto literal, XlaLiteralFromPyObject(element)); elements.push_back(std::move(literal)); } - return Literal::MakeTupleOwned(std::move(elements)); + return LiteralUtil::MakeTupleOwned(std::move(elements)); } else if (PyArray_Check(o)) { PyArrayObject* py_array = reinterpret_cast(o); int rank = PyArray_NDIM(py_array); @@ -383,7 +389,7 @@ StatusOr> XlaLiteralFromPyObject(PyObject* o) { dimensions[i] = PyArray_DIM(py_array, i); } int np_type = PyArray_TYPE(py_array); - auto literal = Literal::CreateFromDimensions( + auto literal = LiteralUtil::CreateFromDimensions( NumpyTypeToPrimitiveType(np_type), dimensions); TF_RETURN_IF_ERROR( CopyNumpyArrayToLiteral(np_type, py_array, literal.get())); @@ -424,6 +430,9 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, case NPY_FLOAT64: CopyNumpyArrayToLiteral(py_array, literal); break; + case NPY_COMPLEX64: + CopyNumpyArrayToLiteral(py_array, literal); + break; default: return InvalidArgument( "No XLA literal container for Numpy type number: %d", np_type); @@ -461,6 +470,9 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, case NPY_FLOAT64: CopyLiteralToNumpyArray(literal, py_array); break; + case NPY_COMPLEX64: + CopyLiteralToNumpyArray(literal, py_array); + break; default: LOG(FATAL) << "No XLA literal container for Numpy type" << np_type; } diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 64f0aae0f9790f0199ac6cb931a5c9f6dc356f4c..a67c93a4fb7413f9bbcb9afd92c36fd118836e1f 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -25,7 +25,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/python/lib/core/numpy.h" diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index c025127c3cf1871d4def1297ed36c046cae61d4b..a2c6fc344d192265d536ef7e23ad5c6d7c847014 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -99,12 +99,30 @@ _UNARY_OPS = [ 'Cos', 'Sin', 'Tanh', - 'SqrtF32', - 'SquareF32', 'IsFinite', - 'ReciprocalF32', + 'Sqrt', + 'Rsqrt', + 'Square', + 'Reciprocal', 'Neg', 'Sort', + 'Erf', + 'Erfc', + 'ErfInv', + 'Lgamma', + 'Digamma', + 'Acos', + 'Asin', + 'Atan', + 'Tan', + 'Acosh', + 'Asinh', + 'Atanh', + 'Cosh', + 'Sinh', + 'Real', + 'Imag', + 'Conj', ] _BINARY_OPS = [ @@ -123,7 +141,13 @@ _BINARY_OPS = [ 'Min', 'And', 'Or', + 'Xor', 'Pow', + 'ShiftLeft', + 'ShiftRightArithmetic', + 'ShiftRightLogical', + 'Atan2', + 'Complex', ] @@ -257,9 +281,12 @@ class Shape(object): self._dimensions == other._dimensions and self._minor_to_major == other._minor_to_major) + def __ne__(self, other): + return not self == other + def __repr__(self): return ('xla_client.Shape(_dtype={!r}, _dimensions={!r}, ' - '_is_tuple={!r}), _minor_to_major={!r}').format( + '_is_tuple={!r}, _minor_to_major={!r})').format( self._dtype, self._dimensions, self._is_tuple, self._minor_to_major) @@ -457,14 +484,16 @@ class LocalComputation(object): if self.is_compiled: raise ValueError('Attempt to compile a compiled local XLA computation.') + result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape()) + if layout_fn: argument_shapes = [ shape.map_leaves(layout_fn) for shape in argument_shapes ] - result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape()) result_shape = result_shape.map_leaves(layout_fn) - compile_options = compile_options or CompileOptions() - compile_options.result_shape = result_shape + + compile_options = compile_options or CompileOptions() + compile_options.result_shape = result_shape return LocalComputation( self.c_local_computation.Compile(argument_shapes, compile_options), is_compiled=True) @@ -696,6 +725,18 @@ class ComputationBuilder(object): """ return self._client.ConvertElementType(operand, new_element_type) + def BitcastConvertType(self, operand, new_element_type): + """Enqueues a bitcast type conversion operation onto the computation. + + Args: + operand: the operand to convert. + new_element_type: the target primitive type. + + Returns: + A LocalOp representing the added conversion op. + """ + return self._client.BitcastConvertType(operand, new_element_type) + def GetShape(self, operand): return _wrap_shape(self._client.GetShape(operand)) @@ -905,20 +946,19 @@ class ComputationBuilder(object): """ return self._client.Call(computation_to_apply.c_local_computation, operands) - def Map(self, operands, computation_to_apply, dimensions, static_operands=()): + def Map(self, operands, computation_to_apply, dimensions): """Enqueues a map operation onto the computation. Args: operands: an iterable of LocalOp. computation_to_apply: a Computation object. dimensions: dimensions over which to apply map the function. - static_operands: auxiliary arguments passed to the applied computation. Returns: A LocalOp representing the added Map op. """ return self._client.Map(operands, computation_to_apply.c_local_computation, - dimensions, static_operands) + dimensions) def Reduce(self, operand, init_value, computation_to_apply, dimensions): """Enqueues a reduction operation onto the computation. diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 71e1d60a4e23dbfef333223c396e109533da9365..fd98e19457f61aade947aa354d2e415148d127f6 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -157,6 +157,13 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Constant(NumpyArrayBool([True, True, False, False]))) self._ExecuteAndCompareExact(c, expected=[True, True, True, False]) + def testBooleanXor(self): + c = self._NewComputation() + c.Xor( + c.Constant(NumpyArrayBool([True, False, True, False])), + c.Constant(NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[False, True, True, False]) + def testSum2DF32(self): c = self._NewComputation() c.Add( @@ -164,6 +171,24 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + def testShiftLeft(self): + c = self._NewComputation() + c.ShiftLeft(c.Constant(NumpyArrayS32([3])), + c.Constant(NumpyArrayS32([2]))) + self._ExecuteAndCompareClose(c, expected=[12]) + + def testShiftRightArithmetic(self): + c = self._NewComputation() + c.ShiftRightArithmetic(c.Constant(NumpyArrayS32([-2])), + c.Constant(NumpyArrayS32([1]))) + self._ExecuteAndCompareClose(c, expected=[-1]) + + def testShiftRightLogical(self): + c = self._NewComputation() + c.ShiftRightLogical(c.Constant(NumpyArrayS32([-1])), + c.Constant(NumpyArrayS32([1]))) + self._ExecuteAndCompareClose(c, expected=[2**31 - 1]) + def testGetProto(self): c = self._NewComputation() c.Add( @@ -464,6 +489,34 @@ class SingleOpTest(LocalComputationTest): for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): _ConvertAndTest(x, src_dtype, dst_dtype) + def testBitcastConvertType(self): + xla_x32_types = { + np.int32: xla_client.xla_data_pb2.S32, + np.float32: xla_client.xla_data_pb2.F32, + } + + xla_x64_types = { + np.int64: xla_client.xla_data_pb2.S64, + np.float64: xla_client.xla_data_pb2.F64, + } + + def _ConvertAndTest(template, src_dtype, dst_dtype, dst_etype): + c = self._NewComputation() + x = c.Constant(np.array(template, dtype=src_dtype)) + c.BitcastConvertType(x, dst_etype) + + result = c.Build().Compile().Execute() + expected = np.array(template, src_dtype).view(dst_dtype) + + self.assertEqual(result.shape, expected.shape) + self.assertEqual(result.dtype, expected.dtype) + np.testing.assert_equal(result, expected) + + x = [0, 1, 0, 0, 1] + for xla_types in [xla_x32_types, xla_x64_types]: + for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): + _ConvertAndTest(x, src_dtype, dst_dtype, xla_types[dst_dtype]) + def testCrossReplicaSumOneReplica(self): samples = [ NumpyArrayF32(42.0), @@ -1168,14 +1221,6 @@ class EmbeddedComputationsTest(LocalComputationTest): self._CreateBinaryDivF64Computation(), [0]) self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) - def DISABLED_testMapWithStaticOperands(self): - c = self._NewComputation() - factor = c.ConstantF32Scalar(3.0) - c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], - self._CreateMulF32ByParamComputation(), [0], - static_operands=[factor]) - self._ExecuteAndCompareClose(c, expected=[3.0, 6.0, 9.0, 12.0]) - def testSelectAndScatterF32(self): c = self._NewComputation() c.SelectAndScatter(c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])), diff --git a/tensorflow/compiler/xla/python_api/BUILD b/tensorflow/compiler/xla/python_api/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..d790c4db6c466a2bf4d2cf30365749fb901f74a0 --- /dev/null +++ b/tensorflow/compiler/xla/python_api/BUILD @@ -0,0 +1,38 @@ +# Description: +# Python API for XLA. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +py_library( + name = "types", + srcs = ["types.py"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto_py", + "//tensorflow/python:dtypes", + "//tensorflow/python:platform", + "//third_party/py/numpy", + ], +) + +py_library( + name = "xla_shape", + srcs = ["xla_shape.py"], + visibility = ["//visibility:public"], + deps = [ + ":types", + "//tensorflow/compiler/xla:xla_data_proto_py", + ], +) + +py_library( + name = "xla_literal", + srcs = ["xla_literal.py"], + visibility = ["//visibility:public"], + deps = [ + ":types", + ":xla_shape", + "//tensorflow/compiler/xla:xla_data_proto_py", + ], +) diff --git a/tensorflow/compiler/xla/python_api/types.py b/tensorflow/compiler/xla/python_api/types.py new file mode 100644 index 0000000000000000000000000000000000000000..57dfce3971b829d2a3052d347e5d2d322db0c841 --- /dev/null +++ b/tensorflow/compiler/xla/python_api/types.py @@ -0,0 +1,131 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ====================================== +"""Utilities for XLA-specific Python types.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import numpy as _np # Avoids becoming a part of public Tensorflow API. + +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.python.framework import dtypes + +# Records corresponsence between a XLA primitive type and Python/Numpy types. +# +# primitive_type: value of type xla_data_pb2.PrimitiveType +# numpy_dtype: corresponsing Numpy "dtype" (like np.float32) +# literal_field_name: name of the field in the LiteralProto message elements +# of this type go into. +# literal_field_type: type of the field named 'literal_field_name'. +# +# TODO(eliben): figure out how to avoid knowing the extra Python type and the +# astype cast when writing into Literals. +TypeConversionRecord = collections.namedtuple('TypeConversionRecord', [ + 'primitive_type', 'numpy_dtype', 'literal_field_name', 'literal_field_type' +]) + +# Maps from XLA primitive types to TypeConversionRecord. +MAP_XLA_TYPE_TO_RECORD = { + xla_data_pb2.BF16: + TypeConversionRecord( + primitive_type=xla_data_pb2.BF16, + numpy_dtype=dtypes.bfloat16.as_numpy_dtype, + literal_field_name='bf16s', + literal_field_type=float), + xla_data_pb2.F16: + TypeConversionRecord( + primitive_type=xla_data_pb2.F16, + numpy_dtype=_np.float16, + literal_field_name='f16s', + literal_field_type=float), + xla_data_pb2.F32: + TypeConversionRecord( + primitive_type=xla_data_pb2.F32, + numpy_dtype=_np.float32, + literal_field_name='f32s', + literal_field_type=float), + xla_data_pb2.F64: + TypeConversionRecord( + primitive_type=xla_data_pb2.F64, + numpy_dtype=_np.float64, + literal_field_name='f64s', + literal_field_type=float), + xla_data_pb2.S8: + TypeConversionRecord( + primitive_type=xla_data_pb2.S8, + numpy_dtype=_np.int8, + literal_field_name='s8s', + literal_field_type=int), + xla_data_pb2.S16: + TypeConversionRecord( + primitive_type=xla_data_pb2.S16, + numpy_dtype=_np.int16, + literal_field_name='s16s', + literal_field_type=int), + xla_data_pb2.S32: + TypeConversionRecord( + primitive_type=xla_data_pb2.S32, + numpy_dtype=_np.int32, + literal_field_name='s32s', + literal_field_type=int), + xla_data_pb2.S64: + TypeConversionRecord( + primitive_type=xla_data_pb2.S64, + numpy_dtype=_np.int64, + literal_field_name='s64s', + literal_field_type=int), + xla_data_pb2.U8: + TypeConversionRecord( + primitive_type=xla_data_pb2.U8, + numpy_dtype=_np.uint8, + literal_field_name='s8s', + literal_field_type=int), + xla_data_pb2.U16: + TypeConversionRecord( + primitive_type=xla_data_pb2.U16, + numpy_dtype=_np.uint16, + literal_field_name='s16s', + literal_field_type=int), + xla_data_pb2.U32: + TypeConversionRecord( + primitive_type=xla_data_pb2.U32, + numpy_dtype=_np.uint32, + literal_field_name='s32s', + literal_field_type=int), + xla_data_pb2.U64: + TypeConversionRecord( + primitive_type=xla_data_pb2.U64, + numpy_dtype=_np.uint64, + literal_field_name='s64s', + literal_field_type=int), + xla_data_pb2.PRED: + TypeConversionRecord( + primitive_type=xla_data_pb2.PRED, + numpy_dtype=_np.bool, + literal_field_name='preds', + literal_field_type=bool) +} + +# Maps from Numpy dtypes to TypeConversionRecord. +# Note the conversion on the key. Numpy has a known issue wherein dtype hashing +# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, +# when keying by dtype in this dict, we use the string form of dtypes. +MAP_DTYPE_TO_RECORD = { + str(_np.dtype(record.numpy_dtype)): record + for record in MAP_XLA_TYPE_TO_RECORD.values() +} diff --git a/tensorflow/compiler/xla/python_api/xla_literal.py b/tensorflow/compiler/xla/python_api/xla_literal.py new file mode 100644 index 0000000000000000000000000000000000000000..757e41a78ad2b57d2ef6e1f3055160be22c7b3ed --- /dev/null +++ b/tensorflow/compiler/xla/python_api/xla_literal.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ====================================== +"""XLA LiteralProto utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as _np # Avoids becoming a part of public Tensorflow API. + +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.compiler.xla.python_api import types +from tensorflow.compiler.xla.python_api import xla_shape + + +def ConvertLiteralToNumpyArray(literal): + """Converts a XLA literal to a Numpy array.""" + element_type = literal.shape.element_type + if element_type == xla_data_pb2.TUPLE: + return tuple( + ConvertLiteralToNumpyArray(subliteral) + for subliteral in literal.tuple_literals) + + type_record = types.MAP_XLA_TYPE_TO_RECORD[element_type] + if not literal.shape.dimensions: + return _np.array( + getattr(literal, type_record.literal_field_name)[0], + type_record.numpy_dtype) + else: + # Infer the proper Numpy order from the LiteralProto's layout. The repeated + # field representing the array's content in the Literal is linearized. + # Reading is done in two steps: + # + # 1. Read the array as 1D from the LiteralProto repeated field. + # 2. Reshape the array to its proper shape, using the right order depending + # on the LiteralProto's layout. + layout_order = literal.shape.layout.minor_to_major + numpy_shape = tuple(literal.shape.dimensions) + if layout_order == range(len(literal.shape.dimensions)): + numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='F') + elif layout_order == range(len(literal.shape.dimensions) - 1, -1, -1): + numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='C') + else: + raise NotImplementedError('Unsupported layout: {0}'.format(layout_order)) + ndarray = _np.array( + getattr(literal, type_record.literal_field_name), + copy=False, + dtype=type_record.numpy_dtype) + return numpy_reshaper(ndarray) + + +def _ConvertNumpyArrayToLiteral(ndarray): + """Converts a Numpy array to a XLA literal.""" + type_record = types.MAP_DTYPE_TO_RECORD[str(ndarray.dtype)] + literal = xla_data_pb2.LiteralProto() + literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(ndarray).message) + + if ndarray.ndim == 0: + getattr(literal, type_record.literal_field_name).append( + _np.asscalar(ndarray.astype(type_record.literal_field_type))) + else: + # Ndarrays with boolean dtypes need special type conversion with protobufs + if ndarray.dtype in {_np.bool_, _np.dtype('bool')}: + for element in _np.nditer(ndarray): + getattr(literal, type_record.literal_field_name).append( + type_record.literal_field_type(element)) + else: + ndarray_flat = ndarray.ravel(order='A') + getattr(literal, type_record.literal_field_name).extend(ndarray_flat) + return literal + + +def ConvertNumpyArrayToLiteral(value): + """Converts a Numpy array or a nested tuple thereof to an XLA literal.""" + if isinstance(value, tuple): + literal = xla_data_pb2.LiteralProto() + literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(value).message) + for component in value: + component_literal = literal.tuple_literals.add() + component_literal.CopyFrom(ConvertNumpyArrayToLiteral(component)) + return literal + else: + return _ConvertNumpyArrayToLiteral(value) diff --git a/tensorflow/compiler/xla/python_api/xla_shape.py b/tensorflow/compiler/xla/python_api/xla_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..f158f6b2410352432445f669155aff0af5526abf --- /dev/null +++ b/tensorflow/compiler/xla/python_api/xla_shape.py @@ -0,0 +1,155 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ====================================== +"""XLA Shape utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as _np # Avoids becoming a part of public Tensorflow API. + +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.compiler.xla.python_api import types + + +class Shape(object): + """Wraps a xla_data_pb2.Shape message with a convenient Python type. + + Provides direct access to the underlying xla_data_pb2.Shape message in the + message attribute, along with accessor wrappers to the message's fields. + Avoid direct access to .message unless interacting directly with protobuf APIs + like CopyFrom. In other words, prefer hauling the shape around in a Shape, and + only access .message when strictly required by the protobuf API. + """ + + def __init__(self, element_type, dimensions, layout=None): + """Creates a new XLA Shape. + + Args: + element_type: element type from xla_data_pb2. + dimensions: sequence of dimensions sizes (integers), or sequence + of Shapes in the case of a tuple, i.e. when element_type is + TUPLE. + layout: optional minor_to_major sequence for layout. If not given, the + default major-to-minor layout is used. + + Raises: + ValueError: if element_type is TUPLE but dimensions are not Shape objects. + """ + self.message = xla_data_pb2.Shape() + self.message.element_type = element_type + if element_type == xla_data_pb2.TUPLE: + if not all(isinstance(subshape, Shape) for subshape in dimensions): + raise ValueError( + 'XLA tuple requires sequence of Shape objects as dimensions') + self._tuple_shapes = tuple(dimensions) + for component_shape in self._tuple_shapes: + component_message = self.message.tuple_shapes.add() + component_message.CopyFrom(component_shape.message) + else: + self.message.dimensions.extend(dimensions) + if layout is None: + layout = list(reversed(range(len(dimensions)))) + self.message.layout.format = xla_data_pb2.DENSE + self.message.layout.minor_to_major.extend(layout) + + def element_type(self): + return self.message.element_type + + def is_tuple(self): + return self.element_type() == xla_data_pb2.TUPLE + + def dimensions(self): + if self.is_tuple(): + raise ValueError('Tuple shape has no dimensions. Try tuple_shapes()?') + return self.message.dimensions + + def tuple_shapes(self): + """If this is a tuple, returns its sequence of constituent Shape objects. + + Returns: + Tuple sub-shapes. + + Raises: + ValueError: if this is not a tuple. + """ + if not self.is_tuple(): + raise ValueError('tuple_shapes() called on a non-tuple shape') + return self._tuple_shapes + + def layout(self): + return self.message.layout + + @staticmethod + def from_pyval(pyval): + return CreateShapeFromNumpy(pyval) + + +def _CreateShapeFromNumpy(ndarray): # pylint: disable=invalid-name + """Create a Shape from a given Numpy array. + + Args: + ndarray: Numpy array. + + Returns: + A Shape object. + """ + element_type = types.MAP_DTYPE_TO_RECORD[str(ndarray.dtype)].primitive_type + dimensions = ndarray.shape + + # Set the shape's layout based on the ordering of ndarray. + # Numpy arrays come in two orders: Fortran (column-major) and C (row-major). + if _np.isfortran(ndarray): + # Column-major layout. This corresponds to a "dimension order is + # minor-to-major" layout in XLA. + layout = range(ndarray.ndim) + else: + # Row-major layout. This corresponds to a "dimension order is + # major-to-minor" layout int XLA. + layout = list(reversed(xrange(ndarray.ndim))) + + return Shape(element_type, dimensions, layout) + + +def CreateShapeFromNumpy(value): # pylint: disable=invalid-name + """Create a Shape from a Numpy array or a nested tuple structure thereof. + + Args: + value: Numpy array or (possibly nested) tuple structure that bottoms out in + Numpy arrays. + + Returns: + A Shape object. + """ + if isinstance(value, tuple): + return Shape( + xla_data_pb2.TUPLE, + [CreateShapeFromNumpy(component) for component in value]) + else: + return _CreateShapeFromNumpy(value) + + +def CreateShapeFromDtypeAndTuple(dtype, shape_tuple): # pylint: disable=invalid-name + """Create a shape from a Numpy dtype and a sequence of nonnegative integers. + + Args: + dtype: a numpy dtype, e.g. np.dtype('int32'). + shape_tuple: a sequence of nonnegative integers. + + Returns: + A Shape object. + """ + element_type = types.MAP_DTYPE_TO_RECORD[str(dtype)].primitive_type + return Shape(element_type, shape_tuple) diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index c289c84cff743871a7126cb932d6cda823ceb696..a803520876952a0ab67ecb827b1f256c915335f9 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -18,7 +18,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -510,8 +511,8 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( std::pair lhs_dilation, std::pair rhs_dilation, ConvolutionDimensionNumbers dnums) { HloComputation::Builder b("ConvArray4DGeneralDimensionDilated"); - auto lhs_literal = Literal::CreateR4FromArray4D(lhs); - auto rhs_literal = Literal::CreateR4FromArray4D(rhs); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(lhs); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(rhs); std::array ordered_kernel_strides; std::array ordered_input_dimensions; diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index 9da9bc60a2025e63b57a3be9ed360d150f88d73c..8091bed4996a753649a5ecedda69a1ae48fb5897 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -53,7 +53,7 @@ class ReferenceUtilTest : public ::testing::Test { TEST_F(ReferenceUtilTest, TransposeArray2D) { auto result = ReferenceUtil::TransposeArray2D(*matrix_); - auto actual_literal = Literal::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}}, *actual_literal, ErrorSpec(0.0001)); } @@ -65,7 +65,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { {11.f, 12.f}, }); auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); - auto actual_literal = Literal::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{58.f, 64.f}, {139.f, 154.f}}, *actual_literal, ErrorSpec(0.0001)); } @@ -73,7 +73,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { TEST_F(ReferenceUtilTest, ReduceToColArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add); - auto actual_literal = Literal::CreateR1(*result); + auto actual_literal = LiteralUtil::CreateR1(*result); LiteralTestUtil::ExpectR1Near({6.f, 15.f}, *actual_literal, ErrorSpec(0.0001)); } @@ -81,13 +81,13 @@ TEST_F(ReferenceUtilTest, ReduceToColArray2D) { TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add); - auto actual_literal = Literal::CreateR1(*result); + auto actual_literal = LiteralUtil::CreateR1(*result); LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) { - auto result = Literal::CreateR1(ReferenceUtil::Reduce4DTo1D( + auto result = LiteralUtil::CreateR1(ReferenceUtil::Reduce4DTo1D( Array4D(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2}, [](float a, float b) { return a + b; })); LiteralTestUtil::ExpectR1Equal({0}, *result); @@ -96,7 +96,7 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) { TEST_F(ReferenceUtilTest, MapArray2D) { auto identity = [](float value) { return log(exp(value)); }; auto result = ReferenceUtil::MapArray2D(*matrix_, identity); - auto actual_literal = Literal::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal, ErrorSpec(0.0001)); } @@ -106,7 +106,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { return value + row + col; }; auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index); - auto actual_literal = Literal::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}}, *actual_literal, ErrorSpec(0.0001)); } @@ -117,7 +117,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) { input->FillWithMultiples(1.0f); auto multiply_by_two = [](float value) { return 2 * value; }; auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two); - auto actual_literal = Literal::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.FillWithMultiples(2.0f); @@ -134,7 +134,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { return value - (3 * 4 * 5 * plane + 4 * 5 * depth + 5 * height + width); }; auto result = ReferenceUtil::MapWithIndexArray4D(*input, subtract_index); - auto actual_literal = Literal::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.Fill(0.0f); @@ -144,7 +144,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { TEST_F(ReferenceUtilTest, SliceArray2D) { auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}}); - auto actual_literal = Literal::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 2.f}, {4.f, 5.f}}, *actual_literal, ErrorSpec(0.0001)); @@ -152,7 +152,7 @@ TEST_F(ReferenceUtilTest, SliceArray2D) { TEST_F(ReferenceUtilTest, SliceStridedArray2D) { auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}}); - auto actual_literal = Literal::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 3.f}, {4.f, 6.f}}, *actual_literal, ErrorSpec(0.0001)); @@ -164,7 +164,7 @@ TEST_F(ReferenceUtilTest, SliceArray3D) { auto result = ReferenceUtil::Slice3D(input, {{0, 0, 0}}, {{2, 2, 2}}, {{1, 1, 1}}); - auto actual_literal = Literal::CreateR3FromArray3D(*result); + auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result); LiteralTestUtil::ExpectR3Near( {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal, @@ -177,7 +177,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray3D) { auto result = ReferenceUtil::Slice3D(input, {{0, 0, 0}}, {{2, 3, 4}}, {{1, 2, 2}}); - auto actual_literal = Literal::CreateR3FromArray3D(*result); + auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result); LiteralTestUtil::ExpectR3Near( {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, @@ -190,7 +190,7 @@ TEST_F(ReferenceUtilTest, SliceArray4D) { auto result = ReferenceUtil::Slice4D(input, {{1, 0, 0, 0}}, {{2, 2, 2, 2}}, {{1, 1, 1, 1}}); - auto actual_literal = Literal::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); LiteralTestUtil::ExpectR4Near( {{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}}, @@ -203,7 +203,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray4D) { auto result = ReferenceUtil::Slice4D(input, {{1, 0, 0, 0}}, {{2, 3, 4, 5}}, {{1, 2, 2, 2}}); - auto actual_literal = Literal::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); LiteralTestUtil::ExpectR4Near( {{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}}, @@ -218,7 +218,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) { ReferenceUtil::ConvArray3D(input, weights, 1, Padding::kSame); Array3D expected = {{{17, 28, 39, 20}}}; - auto actual_literal = Literal::CreateR3FromArray3D(*actual); + auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual); LiteralTestUtil::ExpectR3NearArray3D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -231,7 +231,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) { ReferenceUtil::ConvArray3D(input, weights, 1, Padding::kValid); Array3D expected = {{{17, 28, 39}}}; - auto actual_literal = Literal::CreateR3FromArray3D(*actual); + auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual); LiteralTestUtil::ExpectR3NearArray3D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -266,7 +266,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) { })); // clang-format on - auto actual_literal = Literal::CreateR4FromArray4D(*actual); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -300,7 +300,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) { })); // clang-format on - auto actual_literal = Literal::CreateR4FromArray4D(*actual); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -356,7 +356,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) { }}); // clang-format on - auto actual_literal = Literal::CreateR4FromArray4D(*actual); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -409,7 +409,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { Array4D expected({{{{2514, 2685}}}}); // clang-format on - auto actual_literal = Literal::CreateR4FromArray4D(*actual); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -422,7 +422,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) { auto actual = ReferenceUtil::ApplyElementwise2D( [](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c); - auto actual_literal = Literal::CreateR2FromArray2D(*actual); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual); LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}}, *actual_literal, ErrorSpec(0.0001)); } diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 1775666652f303ee095a11405537106a3eb9b056..44b22a5586dee3f7dd8ea0edbf9deb2090986ac8 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -39,10 +39,10 @@ tf_cc_binary( srcs = ["grpc_service_main.cc"], deps = [ ":grpc_service", + "//tensorflow:grpc++", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "@grpc//:grpc++", ], ) @@ -54,14 +54,14 @@ tf_cc_test( ], deps = [ ":grpc_stub", + "//tensorflow:grpc++", "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@grpc//:grpc++", ], ) @@ -71,9 +71,9 @@ cc_library( hdrs = ["grpc_service.h"], deps = [ ":xla_service_proto", + "//tensorflow:grpc++", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core/distributed_runtime/rpc:grpc_util", - "@grpc//:grpc++", ], ) diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index d7dd9786a2bbde2d18ae81a9a9d4cc4b2cc38411..67886761813f0bb45a600661b017be91ffeade73 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "grpcpp/security/credentials.h" #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/rpc/grpc_stub.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/lib/io/path.h" @@ -85,19 +85,19 @@ TEST_F(GRPCClientTestBase, ItsAlive) { TEST_F(GRPCClientTestBase, AxpyTenValues) { XlaBuilder builder("axpy_10"); - auto alpha = builder.ConstantR0(3.1415926535); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto y = builder.ConstantR1( - {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); - auto ax = builder.Mul(alpha, x); - auto axpy = builder.Add(ax, y); + auto alpha = ConstantR0(&builder, 3.1415926535); + auto x = ConstantR1( + &builder, {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); + auto y = ConstantR1( + &builder, {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); + auto ax = Mul(alpha, x); + Add(ax, y); std::vector expected = { 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796, 6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327}; std::unique_ptr expected_literal = - Literal::CreateR1(expected); + LiteralUtil::CreateR1(expected); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer( computation, {}, nullptr)); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 2942edbf71f29304ebb240f0547808ae0af1ac87..7331d2b54cfdb853ff6fcb2e02c5bdd9cf716779 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -32,6 +32,7 @@ tf_proto_library_py( name = "hlo_proto", # bzl adds a _py suffix only to the OSS target. srcs = ["hlo.proto"], visibility = ["//visibility:public"], + deps = ["//tensorflow/compiler/xla:xla_data_proto_py"], ) xla_proto_library( @@ -135,7 +136,7 @@ cc_library( ":hlo_dce", ":hlo_pass", ":tuple_simplifier", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", @@ -181,6 +182,7 @@ tf_cc_test( name = "shape_inference_test", srcs = ["shape_inference_test.cc"], deps = [ + ":hlo", ":shape_inference", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -226,6 +228,7 @@ cc_library( ":hlo", ":hlo_query", ":shape_inference", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -243,7 +246,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_evaluator", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -253,7 +256,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -293,6 +296,7 @@ cc_library( ":hlo_reachability", ":name_uniquer", "//tensorflow/compiler/xla:array", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_tree", @@ -395,6 +399,7 @@ tf_cc_test( deps = [ ":hlo_matchers", ":hlo_parser", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -406,7 +411,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_parser", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -423,7 +428,7 @@ tf_cc_test( srcs = ["hlo_sharding_test.cc"], deps = [ ":hlo", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -452,7 +457,7 @@ tf_cc_test( srcs = ["call_graph_test.cc"], deps = [ ":call_graph", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -486,6 +491,7 @@ cc_library( hdrs = ["call_inliner.h"], deps = [ ":call_graph", + ":hlo_dce", ":hlo_pass", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", @@ -501,7 +507,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -520,7 +526,7 @@ tf_cc_test( deps = [ ":call_graph", ":flatten_call_graph", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -558,13 +564,13 @@ cc_library( ":computation_placer", ":device_memory_allocator", ":platform_util", - ":pool", + ":stream_pool", ":transfer_manager", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:core_cpu_lib", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", @@ -592,6 +598,7 @@ cc_library( ":hlo_proto_util", ":platform_util", ":source_map_util", + ":stream_pool", ":transfer_manager", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:execution_options_util", @@ -606,6 +613,7 @@ cc_library( "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", + "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", ], alwayslink = 1, @@ -636,7 +644,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:executable_build_options", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], @@ -745,8 +753,8 @@ cc_library( ":hlo_execution_profile", ":hlo_graph_dumper", ":hlo_proto", - ":pool", ":shaped_buffer", + ":stream_pool", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -796,7 +804,7 @@ cc_library( hdrs = ["transfer_manager.h"], deps = [ ":shaped_buffer", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -832,7 +840,7 @@ cc_library( hdrs = ["execution_tracker.h"], deps = [ ":backend", - ":pool", + ":stream_pool", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -940,7 +948,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ], @@ -959,7 +966,7 @@ tf_cc_test( ":hlo", ":hlo_ordering", ":hlo_scheduling", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -1037,7 +1044,7 @@ tf_cc_test( ":hlo_ordering", ":hlo_value", ":tuple_points_to_analysis", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1051,6 +1058,7 @@ cc_library( hdrs = ["hlo_module_group_metadata.h"], deps = [ ":hlo", + ":hlo_casting_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -1101,6 +1109,7 @@ tf_cc_test( srcs = ["hlo_scheduling_test.cc"], deps = [ ":buffer_value", + ":heap_simulator", ":hlo", ":hlo_ordering", ":hlo_scheduling", @@ -1119,7 +1128,7 @@ cc_library( hdrs = ["hlo_query.h"], deps = [ ":hlo", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", ], ) @@ -1168,6 +1177,7 @@ cc_library( deps = [ ":hlo", ":shape_inference", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -1198,6 +1208,7 @@ cc_library( deps = [ ":hlo", ":hlo_pass", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1217,11 +1228,26 @@ cc_library( ":hlo_creation_utils", ":hlo_pass", ":while_util", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", ], ) +cc_library( + name = "scatter_expander", + srcs = ["scatter_expander.cc"], + hdrs = ["scatter_expander.h"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + ":while_util", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + ], +) + tf_cc_test( name = "batchnorm_expander_test", size = "small", @@ -1230,8 +1256,9 @@ tf_cc_test( ":batchnorm_expander", ":hlo", ":hlo_matchers", + ":hlo_parser", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -1253,6 +1280,7 @@ cc_library( ":hlo_pass", ":hlo_query", ":pattern_matcher", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1272,7 +1300,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -1308,7 +1336,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -1343,7 +1371,7 @@ cc_library( ":call_inliner", ":hlo", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -1359,6 +1387,7 @@ tf_cc_test( ":conditional_simplifier", ":hlo", ":hlo_matchers", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1370,6 +1399,18 @@ tf_cc_test( ], ) +cc_library( + name = "while_loop_analysis", + srcs = ["while_loop_analysis.cc"], + hdrs = ["while_loop_analysis.h"], + deps = [ + ":hlo", + ":hlo_evaluator", + "//tensorflow/compiler/xla:literal", + "//tensorflow/core:lib", + ], +) + cc_library( name = "while_loop_simplifier", srcs = ["while_loop_simplifier.cc"], @@ -1377,8 +1418,8 @@ cc_library( deps = [ ":call_inliner", ":hlo", - ":hlo_evaluator", ":hlo_pass", + ":while_loop_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", ], @@ -1418,7 +1459,7 @@ tf_cc_test( deps = [ ":defuser", ":hlo_matchers", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", ], @@ -1446,7 +1487,7 @@ tf_cc_test( deps = [ ":hlo_matchers", ":implicit_broadcast_remover", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", ], @@ -1488,7 +1529,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":tuple_simplifier", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -1503,7 +1544,7 @@ cc_library( hdrs = ["reshape_mover.h"], deps = [ ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", @@ -1518,7 +1559,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":reshape_mover", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -1553,7 +1594,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":inliner", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", @@ -1570,7 +1611,7 @@ cc_library( hdrs = ["computation_placer.h"], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -1602,7 +1643,7 @@ cc_library( hdrs = ["generic_transfer_manager.h"], deps = [ ":transfer_manager", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1649,8 +1690,8 @@ tf_cc_test( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -1693,7 +1734,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_matchers", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -1708,6 +1749,7 @@ tf_cc_binary( deps = [ ":hlo", ":hlo_graph_dumper", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1722,7 +1764,7 @@ tf_cc_test( srcs = ["hlo_module_test.cc"], deps = [ ":hlo", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", @@ -1820,7 +1862,7 @@ tf_cc_test( ":hlo_matchers", ":hlo_ordering", ":instruction_fusion", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -1857,7 +1899,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_liveness_analysis", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -1918,7 +1960,7 @@ tf_cc_test( ":hlo_matchers", ":hlo_ordering", ":instruction_fusion", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -1950,8 +1992,10 @@ cc_library( hdrs = ["tuple_points_to_analysis.h"], deps = [ ":hlo", + ":hlo_dataflow_analysis", ":logical_buffer", ":logical_buffer_analysis", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -1970,6 +2014,7 @@ tf_cc_test( ":hlo_matchers", ":instruction_fusion", ":tuple_points_to_analysis", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1993,6 +2038,7 @@ cc_library( deps = [ ":computation_layout", ":hlo", + ":hlo_casting_utils", ":hlo_dce", ":hlo_graph_dumper", ":hlo_pass", @@ -2041,7 +2087,7 @@ tf_cc_test( ":hlo_graph_dumper", ":hlo_matchers", ":hlo_runner", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -2092,6 +2138,7 @@ cc_library( hdrs = ["hlo_verifier.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_pass", ":shape_inference", "//tensorflow/compiler/xla:status_macros", @@ -2104,6 +2151,7 @@ tf_cc_test( srcs = ["hlo_verifier_test.cc"], deps = [ ":hlo", + ":hlo_parser", ":hlo_verifier", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -2123,6 +2171,7 @@ cc_library( ":buffer_liveness", ":buffer_value", ":call_graph", + ":copy_insertion", ":flatten_call_graph", ":hlo", ":hlo_dce", @@ -2130,6 +2179,7 @@ cc_library( ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", + ":tuple_simplifier", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -2143,6 +2193,7 @@ tf_cc_test( name = "hlo_rematerialization_test", srcs = ["hlo_rematerialization_test.cc"], deps = [ + ":flatten_call_graph", ":hlo", ":hlo_matchers", ":hlo_ordering", @@ -2162,6 +2213,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_dce", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -2182,7 +2234,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_module_dce", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -2206,7 +2258,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":layout_assignment", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -2265,7 +2317,7 @@ cc_library( ":hlo", ":hlo_domain_map", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -2281,7 +2333,7 @@ tf_cc_test( ":hlo", ":hlo_cse", ":hlo_matchers", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -2303,7 +2355,7 @@ cc_library( ":hlo_evaluator", ":hlo_pass", ":hlo_query", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", @@ -2318,7 +2370,7 @@ tf_cc_test( ":hlo_constant_folding", ":hlo_matchers", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -2355,6 +2407,20 @@ cc_library( ], ) +cc_library( + name = "hlo_domain_verifier", + srcs = ["hlo_domain_verifier.cc"], + hdrs = ["hlo_domain_verifier.h"], + deps = [ + ":hlo", + ":hlo_domain_map", + ":hlo_graph_dumper", + ":hlo_pass", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + ], +) + cc_library( name = "hlo_domain_isolator", srcs = ["hlo_domain_isolator.cc"], @@ -2374,8 +2440,8 @@ cc_library( hdrs = ["hlo_domain_remover.h"], deps = [ ":hlo", - ":hlo_domain_isolator", ":hlo_domain_map", + ":hlo_domain_verifier", ":hlo_graph_dumper", ":hlo_pass", "//tensorflow/compiler/xla:types", @@ -2395,6 +2461,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -2409,7 +2476,7 @@ cc_library( ":hlo_evaluator", ":hlo_pass", ":hlo_query", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", @@ -2544,7 +2611,7 @@ cc_library( hdrs = ["hlo_tfgraph_builder.h"], deps = [ ":hlo", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:framework", @@ -2575,7 +2642,7 @@ cc_library( ":hlo_casting_utils", ":hlo_execution_profile", ":hlo_tfgraph_builder", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:window_util", @@ -2593,6 +2660,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_graph_dumper", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/tests:test_utils", @@ -2624,12 +2692,12 @@ tf_cc_test( ":hlo_matchers", ":shape_inference", ":transpose_folding", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -2645,7 +2713,7 @@ cc_library( deps = [ ":hlo", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", @@ -2660,13 +2728,13 @@ tf_cc_test( ":hlo", ":shape_inference", ":zero_sized_hlo_elimination", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -2674,21 +2742,25 @@ tf_cc_test( ) cc_library( - name = "pool", - hdrs = ["pool.h"], + name = "stream_pool", + srcs = ["stream_pool.cc"], + hdrs = ["stream_pool.h"], deps = [ + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", ], ) tf_cc_test( - name = "pool_test", - srcs = ["pool_test.cc"], + name = "stream_pool_test", + srcs = ["stream_pool_test.cc"], deps = [ - ":pool", + ":stream_pool", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -2820,6 +2892,7 @@ cc_library( ":hlo", ":hlo_creation_utils", ":tuple_util", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/core:lib", ], ) @@ -2955,6 +3028,7 @@ cc_library( ":hlo", ":hlo_lexer", ":hlo_sharding_metadata", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 3b36939b8a6900f047bbec225aa232e0e805b5d1..37834e1cc2657ff56f65a4f94eb973b9022eb8e1 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -50,20 +51,15 @@ namespace { namespace m = match; -// Returns whether operand is a literal with the given value. -bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { - return operand->opcode() == HloOpcode::kConstant && - operand->literal().IsAll(value); -} - bool IsAll(const HloInstruction* op, int8 value) { - if (IsLiteralWithValue(op, value)) { - return true; - } - if (op->opcode() == HloOpcode::kBroadcast && IsAll(op->operand(0), value)) { - return true; + switch (op->opcode()) { + case HloOpcode::kBroadcast: + return IsAll(op->operand(0), value); + case HloOpcode::kConstant: + return op->literal().IsAll(value); + default: + return false; } - return false; } // Returns whether the given transpose produces a result which is bit-wise @@ -75,21 +71,22 @@ bool TransposeIsBitcast(const HloInstruction* transpose) { transpose->dimensions()); } -// Returns true if the given reshape produces a result which is bit-wise +// Returns true if the given reshape/copy produces a result which is bit-wise // identical to its operand and thus may be replaced with a bitcast. // // This function is conservative -- even if this function returns false, the // reshape may still be a bitcast. For example, a reshape from [28x28] to [784]. -bool ReshapeIsBitcast( - const HloInstruction* reshape, +bool ReshapeOrCopyIsBitcast( + const HloInstruction* instr, const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) { - CHECK_EQ(HloOpcode::kReshape, reshape->opcode()); + CHECK(HloOpcode::kReshape == instr->opcode() || + HloOpcode::kCopy == instr->opcode()); - const HloInstruction* operand = reshape->operand(0); + const HloInstruction* operand = instr->operand(0); // Can't insert bitcasts if the compiler used a memory layout which isn't // compatible. - return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape()) && - valid_bitcast_callback(operand->shape(), reshape->shape()); + return ShapeUtil::ReshapeIsBitcast(operand->shape(), instr->shape()) && + valid_bitcast_callback(operand->shape(), instr->shape()); } // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain @@ -153,15 +150,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) override; + Status HandleSort(HloInstruction* sort) override; + Status HandleTranspose(HloInstruction* transpose) override; Status HandleSubtract(HloInstruction* sub) override; Status HandleMap(HloInstruction* map) override; - Status HandleMaximum(HloInstruction* maximum) override; - Status HandleMinimum(HloInstruction* minimum) override; - // Returns whether algebraic simplification has occurred. const bool changed() const { return changed_; } @@ -200,8 +196,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Helper method to perform and add reduction in a single dimension. HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { - HloInstruction* zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction* zero = + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( @@ -433,7 +430,15 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op)); } // All copies can be eliminated (assuming layout constraints are satisified). - ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0)); + if (ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0))) { + return Status::OK(); + } + + if (is_layout_sensitive_ && + ReshapeOrCopyIsBitcast(copy, valid_bitcast_callback_)) { + ReplaceWithBitcast(copy); + } + return Status::OK(); } @@ -449,7 +454,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( // Filter out and remove empty operands. std::vector nonempty_operands; for (HloInstruction* operand : operands) { - if (!ShapeUtil::HasZeroElements(operand->shape())) { + if (!ShapeUtil::IsZeroElementArray(operand->shape())) { nonempty_operands.push_back(operand); } } @@ -528,11 +533,15 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { constant, BuildTupleConstant(computation_, constant->literal())); } + if (constant->shape().element_type() == TOKEN) { + return Status::OK(); + } + // If a literal is all the same element replace it with a scalar broadcast. if (ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsAllFirst()) { - std::unique_ptr unique_scalar = - MakeUnique(constant->literal().GetFirstScalarLiteral()); + std::unique_ptr unique_scalar = MakeUnique( + LiteralUtil::GetFirstScalarLiteral(constant->literal())); HloInstruction* scalar = computation_->AddInstruction( HloInstruction::CreateConstant(std::move(unique_scalar))); return ReplaceWithNewInstruction( @@ -563,6 +572,14 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { return Status::OK(); } +namespace { +template +Status InvertConstant(const HloInstruction& constant, Literal* result) { + return result->Populate([&](tensorflow::gtl::ArraySlice indices) { + return T{1.0} / constant.literal().Get(indices); + }); +} +} // namespace Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { Shape* shape; @@ -624,14 +641,31 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { // (Backends can do this transformation, but generally only if the constant is // a scalar.) if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) { - HloInstruction* one = - computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::One(a->shape().element_type()).CloneToUnique())); - HloInstruction* inverse = computation_->AddInstruction( - HloInstruction::CreateBinary(b->shape(), HloOpcode::kDivide, one, b)); - return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary(divide->shape(), - HloOpcode::kMultiply, a, inverse)); + Literal new_literal(b->shape()); + switch (b->shape().element_type()) { + case F16: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; + case F32: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; + case BF16: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; + case F64: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; + case C64: + TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + break; + default: + return Status::OK(); + } + auto inverse = computation_->AddInstruction( + HloInstruction::CreateConstant((new_literal.CloneToUnique()))); + TF_ASSIGN_OR_RETURN(auto new_divide, + MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); + return ReplaceInstruction(divide, new_divide); } // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) @@ -651,18 +685,18 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) { TF_ASSIGN_OR_RETURN(auto b_times_c, MakeBinaryHlo(HloOpcode::kMultiply, b, c)); - return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary(divide->shape(), - HloOpcode::kDivide, a, b_times_c)); + TF_ASSIGN_OR_RETURN(auto new_divide, + MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c)); + return ReplaceInstruction(divide, new_divide); } // A / (B / C) => (A*C) / B if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) { TF_ASSIGN_OR_RETURN(auto a_times_c, MakeBinaryHlo(HloOpcode::kMultiply, a, c)); - return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary(divide->shape(), - HloOpcode::kDivide, a_times_c, b)); + TF_ASSIGN_OR_RETURN(auto new_divide, + MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b)); + return ReplaceInstruction(divide, new_divide); } return Status::OK(); @@ -1058,11 +1092,11 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { } // Replace a zero element dot with a broadcast of the constant 0. - if (ShapeUtil::HasZeroElements(dot->shape()) || - ShapeUtil::HasZeroElements(lhs->shape()) || - ShapeUtil::HasZeroElements(rhs->shape())) { + if (ShapeUtil::IsZeroElementArray(dot->shape()) || + ShapeUtil::IsZeroElementArray(lhs->shape()) || + ShapeUtil::IsZeroElementArray(rhs->shape())) { auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); return ReplaceWithNewInstruction( dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } @@ -1124,6 +1158,19 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { return Status::OK(); } + // 0*A => 0. Only applies for integral types for correct NaN-handling. + if (IsAll(lhs, 0) && + primitive_util::IsIntegralType(multiply->shape().element_type()) && + ReplaceInstructionIfSameShape(multiply, lhs)) { + return Status::OK(); + } + // A*0 => 0 + if (IsAll(rhs, 0) && + primitive_util::IsIntegralType(multiply->shape().element_type()) && + ReplaceInstructionIfSameShape(multiply, rhs)) { + return Status::OK(); + } + // exp(A) * exp(B) => exp(A+B) if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) { auto add = computation_->AddInstruction(HloInstruction::CreateBinary( @@ -1221,9 +1268,10 @@ bool OutputIsPermutationOfOperandElements(HloInstruction* instruction, switch (instruction->opcode()) { case HloOpcode::kReshape: case HloOpcode::kReverse: - case HloOpcode::kSort: case HloOpcode::kTranspose: return true; + case HloOpcode::kSort: + return (!ShapeUtil::IsTuple(instruction->shape())); default: return false; } @@ -1392,7 +1440,7 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { } Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { - if (ShapeUtil::HasZeroElements(pad->operand(0)->shape())) { + if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) { return ReplaceWithNewInstruction( pad, HloInstruction::CreateBroadcast(pad->shape(), pad->mutable_operand(1), {})); @@ -1487,7 +1535,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs)))); if (IsAll(rhs, 0)) { auto one = HloInstruction::CreateConstant( - Literal::One(power->shape().element_type()).CloneToUnique()); + LiteralUtil::One(power->shape().element_type()).CloneToUnique()); std::unique_ptr ones; if (ShapeUtil::IsScalar(power->shape())) { ones = std::move(one); @@ -1522,7 +1570,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); if (IsAll(rhs, -1)) { auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::One(rhs->shape().element_type()).CloneToUnique())); + LiteralUtil::One(rhs->shape().element_type()).CloneToUnique())); // Explicitly broadcast scalar 1 to the output shape, to avoid implicit // broadcast in divide HLO as we are trying to eliminate implicit @@ -1638,7 +1686,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { // Reshape directly to empty constant if the shape contains zero-element // dimension. - if (ShapeUtil::HasZeroElements(reshape->shape())) { + if (ShapeUtil::IsZeroElementArray(reshape->shape())) { auto empty_constant = HloInstruction::CreateConstant( Literal::CreateFromShape(reshape->shape())); @@ -1672,7 +1720,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { // Make this a bitcast if possible. if (is_layout_sensitive_ && - ReshapeIsBitcast(reshape, valid_bitcast_callback_)) { + ReshapeOrCopyIsBitcast(reshape, valid_bitcast_callback_)) { ReplaceWithBitcast(reshape); return Status::OK(); } @@ -1698,19 +1746,37 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) { return Status::OK(); } + + auto is_unstrided_slice = [](const HloInstruction* hlo) { + return c_all_of(hlo->slice_strides(), + [](int64 stride) { return stride == 1; }); + }; + if (slice->operand(0)->opcode() == HloOpcode::kSlice && + is_unstrided_slice(slice) && is_unstrided_slice(slice->operand(0))) { + HloInstruction* operand_slice = slice->mutable_operand(0); + std::vector new_slice_starts = slice->slice_starts(); + std::vector new_slice_limits = slice->slice_limits(); + for (int64 i = 0; i < new_slice_starts.size(); ++i) { + new_slice_starts[i] += operand_slice->slice_starts(i); + new_slice_limits[i] += operand_slice->slice_starts(i); + } + return ReplaceWithNewInstruction( + slice, HloInstruction::CreateSlice( + slice->shape(), operand_slice->mutable_operand(0), + new_slice_starts, new_slice_limits, slice->slice_strides())); + } return Status::OK(); } Status AlgebraicSimplifierVisitor::HandleDynamicSlice( HloInstruction* dynamic_slice) { auto operand = dynamic_slice->mutable_operand(0); - auto start_indices = dynamic_slice->operand(1); if (ShapeUtil::IsScalar(dynamic_slice->shape())) { return ReplaceInstruction(dynamic_slice, operand); } - // DynamicSlice where operand has the same size as the output and - // start_indices are all zero is simply equal to operand. - if (IsAll(start_indices, 0) && SameShape(operand, dynamic_slice)) { + // DynamicSlice where operand has the same size as the output is simply equal + // to operand. + if (SameShape(operand, dynamic_slice)) { return ReplaceInstruction(dynamic_slice, operand); } return Status::OK(); @@ -1719,27 +1785,17 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice( Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) { auto update = dynamic_update_slice->mutable_operand(1); - auto start_indices = dynamic_update_slice->operand(2); - // DynamicUpdateSlice on a scalar just passes through the update argument. - if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) { - return ReplaceInstruction(dynamic_update_slice, update); - } - // DynamicUpdateSlice where operand and update have the same size and - // start_indices are all zero is simply equal to update. - // - // (We require start_indices to be all zero because we want this optimization - // not to affect the visible behavior of this op even when the indices are out - // of range. Currently dynamic-update-slice wraps out-of-range indices, so - // we can only remove the op if its indices never wrap.) - if (IsAll(start_indices, 0) && SameShape(dynamic_update_slice, update)) { + // DynamicUpdateSlice where operand and update have the same size is simply + // equal to update. + if (SameShape(dynamic_update_slice, update)) { return ReplaceInstruction(dynamic_update_slice, update); } // If any dimension of update is 0, elide the DynamicUpdateSlice. This // optimization becomes invalid should we later prefer to warn about out of // bound indices. - if (ShapeUtil::HasZeroElements(update->shape())) { + if (ShapeUtil::IsZeroElementArray(update->shape())) { return ReplaceInstruction(dynamic_update_slice, dynamic_update_slice->mutable_operand(0)); } @@ -1747,12 +1803,18 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( } Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { + // TODO(b/112040122): Most of those optimizations can be done for multi-output + // reduces. + if (ShapeUtil::IsTuple(reduce->shape())) { + return Status::OK(); + } + auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); - if (ShapeUtil::HasZeroElements(arg->shape()) || - ShapeUtil::HasZeroElements(reduce->shape())) { + if (ShapeUtil::IsZeroElementArray(arg->shape()) || + ShapeUtil::IsZeroElementArray(reduce->shape())) { return ReplaceWithNewInstruction( reduce, HloInstruction::CreateBroadcast(reduce->shape(), init_value, {})); @@ -1858,12 +1920,32 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { new_reduce_dimensions, function)); } } + // Convert Reduce(concat({a,b,...})) to + // map(reduce(a),map(reduce(b),...,)) + // + // This should make fusion easier or use less memory bandwidth in the unfused + // case. + if (arg->opcode() == HloOpcode::kConcatenate && + c_linear_search(reduce->dimensions(), arg->concatenate_dimension())) { + HloInstruction* old_reduce = nullptr; + for (HloInstruction* operand : arg->operands()) { + HloInstruction* new_reduce = computation_->AddInstruction( + HloInstruction::CreateReduce(reduce->shape(), operand, init_value, + reduce->dimensions(), function)); + if (old_reduce != nullptr) { + new_reduce = computation_->AddInstruction(HloInstruction::CreateMap( + reduce->shape(), {old_reduce, new_reduce}, function)); + } + old_reduce = new_reduce; + } + return ReplaceInstruction(reduce, old_reduce); + } return Status::OK(); } Status AlgebraicSimplifierVisitor::HandleReduceWindow( HloInstruction* reduce_window) { - if (ShapeUtil::HasZeroElements(reduce_window->operand(0)->shape())) { + if (ShapeUtil::IsZeroElementArray(reduce_window->operand(0)->shape())) { return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateBroadcast(reduce_window->shape(), @@ -2031,6 +2113,21 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( /*reduce_computation=*/function)); } +Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { + auto operand = sort->mutable_operand(0); + int64 dimension_to_sort = sort->dimensions(0); + if (ShapeUtil::IsZeroElementArray(operand->shape()) || + operand->shape().dimensions(dimension_to_sort) <= 1) { + if (sort->operand_count() == 1) { + return ReplaceInstruction(sort, operand); + } + // If it is key/value sort, the output of sort is a tuple. + return ReplaceWithNewInstruction( + sort, HloInstruction::CreateTuple({operand, sort->mutable_operand(1)})); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { auto operand = transpose->mutable_operand(0); if (std::is_sorted(transpose->dimensions().begin(), @@ -2059,16 +2156,15 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( HloInstruction* convolution) { auto lhs = convolution->mutable_operand(0); auto rhs = convolution->mutable_operand(1); - if (ShapeUtil::HasZeroElements(lhs->shape()) || - ShapeUtil::HasZeroElements(rhs->shape())) { + if (ShapeUtil::IsZeroElementArray(lhs->shape()) || + ShapeUtil::IsZeroElementArray(rhs->shape())) { return ReplaceWithNewInstruction( convolution, HloInstruction::CreateBroadcast( convolution->shape(), - computation_->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::MakeShape(convolution->shape().element_type(), {}), - computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))))), + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(convolution->shape().element_type()) + .CloneToUnique())), {})); } const auto& window = convolution->window(); @@ -2240,68 +2336,6 @@ Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) { return ReplaceWithNewInstruction(map, std::move(clone)); } -Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) { - // Match the following tree: - // min_operand operand - // \ / - // max_operand min - // \ / - // max - // where max_operand and min_operand are scalar constants. - { - HloInstruction* min; - HloInstruction* max_operand; - HloInstruction* min_operand; - HloInstruction* operand; - - if (hlo_query::MatchBinaryInstructionOperandOpcode( - HloOpcode::kMinimum, maximum, - /*matching_operand=*/&min, - /*other_operand=*/&max_operand) && - hlo_query::MatchBinaryInstructionOperand( - hlo_query::IsScalarConstant, min, - /*matching_operand=*/&min_operand, - /*other_operand=*/&operand) && - TransformToClampIfSameShape(maximum, min, min_operand, operand, maximum, - max_operand)) { - return Status::OK(); - } - } - - return Status::OK(); -} - -Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) { - // Match the following tree: - // max_operand operand - // \ / - // min_operand max - // \ / - // min - // where max_operand and min_operand are scalar constants. - { - HloInstruction* max; - HloInstruction* max_operand; - HloInstruction* min_operand; - HloInstruction* operand; - - if (hlo_query::MatchBinaryInstructionOperandOpcode( - HloOpcode::kMaximum, minimum, - /*matching_operand=*/&max, - /*other_operand=*/&min_operand) && - hlo_query::MatchBinaryInstructionOperand( - hlo_query::IsScalarConstant, max, - /*matching_operand=*/&max_operand, - /*other_operand=*/&operand) && - TransformToClampIfSameShape(minimum, minimum, min_operand, operand, max, - max_operand)) { - return Status::OK(); - } - } - - return Status::OK(); -} - StatusOr AlgebraicSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(2, "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 2605b0488cb7c6850746df94c4ab05d6b5d35de5..862cbeeba6b82e1f24a6616b3237dc47d022e9af 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -60,7 +60,7 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero)); @@ -74,12 +74,32 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { EXPECT_EQ(root, param0); } +// Test that A * 0 is simplified to 0 +TEST_F(AlgebraicSimplifierTest, MulZero) { + Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0s32, "param0")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0s32, HloOpcode::kMultiply, param0, zero)); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kMultiply); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_EQ(computation->root_instruction(), zero); +} + // Test that Reduce(Reduce(A)) -> Reduce(A) TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { HloComputation::Builder builder(TestName()); // Create add computation. HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); HloComputation* add_computation = nullptr; { HloComputation::Builder builder(TestName() + ".add"); @@ -119,7 +139,7 @@ TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0)); @@ -140,9 +160,9 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.14159f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.14159f))); HloInstruction* add1 = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, constant1)); @@ -165,7 +185,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); HloInstruction* bcast = builder.AddInstruction( HloInstruction::CreateBroadcast(r2f32, zero, {0, 1})); builder.AddInstruction( @@ -200,9 +220,12 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - builder.AddInstruction( - HloInstruction::CreateMap(r2f32, {param0, zero}, add_computation)); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + builder.AddInstruction(HloInstruction::CreateMap( + r2f32, + {param0, builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, zero, {}))}, + add_computation)); auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); @@ -211,7 +234,7 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, zero)); + EXPECT_THAT(root, op::Add(param0, op::Broadcast(zero))); } TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { @@ -220,7 +243,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 0, 0}))); HloInstruction* bcast = builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {1})); builder.AddInstruction( @@ -239,7 +262,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) { HloComputation::Builder builder(TestName()); builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({3.14f, 3.14f, 3.14f}))); + LiteralUtil::CreateR1({3.14f, 3.14f, 3.14f}))); auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); @@ -255,7 +278,7 @@ TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) { TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) { HloComputation::Builder builder(TestName()); builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({3.14, 3.14, 4}))); + LiteralUtil::CreateR1({3.14, 3.14, 4}))); auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); @@ -274,7 +297,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero)); @@ -295,7 +318,7 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kSubtract, param0, constant)); @@ -367,17 +390,16 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C). TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape r2f32 = ShapeUtil::MakeShape(F32, {42, 123}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r2f32, "param1")); HloInstruction* param2 = builder.AddInstruction( HloInstruction::CreateParameter(2, r2f32, "param2")); HloInstruction* param3 = builder.AddInstruction( - HloInstruction::CreateParameter(3, r0f32, "param3")); + HloInstruction::CreateParameter(3, r2f32, "param3")); HloInstruction* div0 = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, param1)); HloInstruction* div1 = builder.AddInstruction( @@ -398,8 +420,6 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { EXPECT_THAT( computation->root_instruction(), op::Divide(op::Multiply(param0, param3), op::Multiply(param1, param2))); - EXPECT_TRUE( - ShapeUtil::Compatible(computation->root_instruction()->shape(), r2f32)); } // Test that A/exp(B) is simplified to A*exp(-B). @@ -459,7 +479,6 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { // Test that broadcasting is done on the right step when simplifying A/pow(B,C) // to A*pow(B,-C). TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -467,7 +486,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r1f32, "param1")); HloInstruction* param2 = builder.AddInstruction( - HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction::CreateParameter(2, r1f32, "param2")); HloInstruction* power = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param1, param2)); builder.AddInstruction( @@ -484,14 +503,9 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { ASSERT_THAT(computation->root_instruction(), op::Multiply(param0, op::Power(param1, op::Negate(param2)))); - - const HloInstruction* negate = - computation->root_instruction()->operand(1)->operand(1); - const Shape& negate_shape = negate->shape(); - EXPECT_EQ(0, negate_shape.dimensions_size()); } -// A / Const => A * (1 / Const) +// A / Const => A * InvertedConst TEST_F(AlgebraicSimplifierTest, DivideByConstant) { Shape r1f32 = ShapeUtil::MakeShape(F32, {3}); HloComputation::Builder builder(TestName()); @@ -499,7 +513,7 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({0.f, 1.f, 2.f}))); + LiteralUtil::CreateR1({0.f, 1.f, 2.f}))); builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, constant)); @@ -510,20 +524,19 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Divide(op::Constant(), constant))); + op::Multiply(param0, op::Constant())); } // pow(pow(A, X), Y) => pow(A, X*Y) TEST_F(AlgebraicSimplifierTest, PowerOfPower) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); HloComputation::Builder builder(TestName()); HloInstruction* base = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* exp1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction::CreateParameter(1, r1f32, "param1")); HloInstruction* exp2 = builder.AddInstruction( - HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction::CreateParameter(2, r1f32, "param2")); HloInstruction* inner_power = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1)); builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, @@ -540,15 +553,14 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex // numbers. TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { - Shape r0c64 = ShapeUtil::MakeShape(C64, {}); Shape r1c64 = ShapeUtil::MakeShape(C64, {7}); HloComputation::Builder builder(TestName()); HloInstruction* base = builder.AddInstruction( HloInstruction::CreateParameter(0, r1c64, "param0")); HloInstruction* exp1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0c64, "param1")); + HloInstruction::CreateParameter(1, r1c64, "param1")); HloInstruction* exp2 = builder.AddInstruction( - HloInstruction::CreateParameter(2, r0c64, "param2")); + HloInstruction::CreateParameter(2, r1c64, "param2")); HloInstruction* inner_power = builder.AddInstruction( HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, base, exp1)); builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, @@ -567,7 +579,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one)); @@ -588,7 +600,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 1.0}, {1.0, 1.0}}))); + LiteralUtil::CreateR2({{1.0, 1.0}, {1.0, 1.0}}))); HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one)); @@ -868,7 +880,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero)); @@ -892,7 +904,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero)); @@ -920,7 +932,7 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one)); @@ -942,7 +954,7 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* two = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two)); @@ -964,7 +976,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* negative_one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(-1))); builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, negative_one)); @@ -1055,7 +1067,7 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { builder.AddInstruction(HloInstruction::CreateReduceWindow( ShapeUtil::MakeShape(F32, {5, 2}), param, builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), window, add_computation)); module().AddEntryComputation(builder.Build()); HloPassFix simplifier(/*is_layout_sensitive=*/false, @@ -1082,7 +1094,7 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {5, 2}), param, builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), padding)); module().AddEntryComputation(builder.Build()); EXPECT_THAT(module().entry_computation()->root_instruction(), @@ -1124,7 +1136,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); @@ -1159,6 +1171,33 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) { EXPECT_THAT(computation->root_instruction(), param0); } +TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 14, 14, 64}), "param")); + *param->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout({0, 1, 2, 3}); + HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShape(F32, {1, 14, 14, 64}), HloOpcode::kCopy, param)); + *copy->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout({1, 2, 0, 3}); + auto computation = module().AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), op::Copy(param)); + + AlgebraicSimplifier simplifier1(/*is_layout_sensitive=*/true, + non_bitcasting_callback()); + ASSERT_FALSE(simplifier1.Run(&module()).ValueOrDie()); + // Verify that the copy is not replaced. + EXPECT_THAT(computation->root_instruction(), op::Copy(param)); + + AlgebraicSimplifier simplifier2(/*is_layout_sensitive=*/true, + bitcasting_callback()); + ASSERT_TRUE(simplifier2.Run(&module()).ValueOrDie()); + // Verify that the copy is replaced. + EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); +} + // Test that unary concatenates are removed. TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); @@ -1189,7 +1228,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r1f32, "param1")); HloInstruction* empty_literal = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); HloInstruction* empty_slice = builder.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1})); @@ -1211,6 +1250,55 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { op::Concatenate(param0, param0, param1)); } +// Test that reduce of concat is simplified. +TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { + const int kParamLength = 100; + Shape r3f32 = + ShapeUtil::MakeShape(F32, {kParamLength, kParamLength, kParamLength}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r3f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r3f32, "param1")); + HloInstruction* param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, r3f32, "param2")); + Shape concat_shape = + ShapeUtil::MakeShape(F32, {kParamLength, 3 * kParamLength, kParamLength}); + HloInstruction* Concatenate = + builder.AddInstruction(HloInstruction::CreateConcatenate( + concat_shape, {param0, param1, param2}, 1)); + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module().AddEmbeddedComputation(builder.Build()); + } + Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); + Shape reduce_shape = ShapeUtil::MakeShape(F32, {kParamLength}); + + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + builder.AddInstruction(HloInstruction::CreateReduce( + reduce_shape, Concatenate, zero, {1, 2}, add_computation)); + + auto computation = module().AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT( + computation->root_instruction(), + op::Map(op::Map(op::Reduce(param0, zero), op::Reduce(param1, zero)), + op::Reduce(param2, zero))); +} + // Test a concatenate with only empty operands is removed. TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { const int kParamLength = 100; @@ -1219,7 +1307,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* empty_literal = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); HloInstruction* empty_slice = builder.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1})); @@ -1389,33 +1477,6 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape)); } -// Regression test for a bug in the reshape sinking transformation, where -// moving a reshape to a scalar led to a crash. -TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { - HloComputation::Builder builder(TestName()); - HloInstruction* param = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 1}), "param")); - HloInstruction* reshape = builder.AddInstruction( - HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), param)); - HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1., 2., 3.}))); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {3}), HloOpcode::kMaximum, reshape, zero)); - auto computation = module().AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Maximum(op::Reshape(param), zero)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); - - simplifier.Run(&module()).ValueOrDie(); - - EXPECT_THAT(computation->root_instruction(), - op::Maximum(op::Reshape(param), zero)); -} - // Regression test for a bug where if we failed to sink a reshape, we'd set the // 'changed' bit in AlgebraicSimplifier to false. TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { @@ -1428,7 +1489,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param0")), builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{0, 0}, {0, 0}}))))); + LiteralUtil::CreateR2({{0, 0}, {0, 0}}))))); builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add)); @@ -1451,7 +1512,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param0")), builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{0, 0}, {0, 0}}))))); + LiteralUtil::CreateR2({{0, 0}, {0, 0}}))))); builder.AddInstruction( HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add, @@ -1734,7 +1795,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {2, 2}), "param")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); PaddingConfig no_padding; for (int i = 0; i < 2; ++i) { auto dimension = no_padding.add_dimensions(); @@ -1765,7 +1826,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {10, 10}), "param")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); PaddingConfig padding; int64 low_padding[2] = {-1, -2}; int64 high_padding[2] = {2, -3}; @@ -1847,6 +1908,73 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { EXPECT_THAT(computation->root_instruction(), param); } +TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) { + HloComputation::Builder builder(TestName()); + const int64 dim0 = 11; + const int64 dim1 = 12; + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param")); + HloInstruction* original_slice = + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {dim0 - 2, dim1 - 4}), param, + /*start_indices=*/{1, 2}, + /*limit_indices=*/{dim0 - 1, dim1 - 2}, /*strides=*/{1, 1})); + + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {dim0 - 5, dim1 - 9}), original_slice, + /*start_indices=*/{2, 3}, + /*limit_indices=*/{dim0 - 3, dim1 - 6}, /*strides=*/{1, 1})); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Slice(op::Slice(param))); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Slice(param)); + EXPECT_EQ(computation->root_instruction()->slice_starts(0), 3); + EXPECT_EQ(computation->root_instruction()->slice_starts(1), 5); + EXPECT_EQ(computation->root_instruction()->slice_limits(0), dim0 - 2); + EXPECT_EQ(computation->root_instruction()->slice_limits(1), dim1 - 4); +} + +TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {1}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), keys); +} + +TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {5, 0}); + Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto values = builder.AddInstruction( + HloInstruction::CreateParameter(1, values_shape, "values")); + builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values)); +} + TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { struct ConvTestOptions { int in_batch = 10; @@ -1878,7 +2006,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { // Builds a convolution from and runs algebraic simplification on // the computation. Returns a string description of the result of // simplification. - auto build_and_simplify = [&options, this]() -> string { + auto build_and_simplify = [&options]() -> string { HloComputation::Builder b(TestName()); Window window; @@ -2076,160 +2204,6 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { EXPECT_EQ("NO_CHANGE", build_and_simplify()); } -// Test that max(min(A, x), y) is transformed to clamp(y, A, x) -TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); - HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMinimum, param0, min_value)); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Maximum(op::Minimum(param0, min_value), max_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Clamp(max_value, param0, min_value)); -} - -// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for scalar -// values. -TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); - HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMaximum, param0, max_value)); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Maximum(param0, max_value), min_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Clamp(max_value, param0, min_value)); -} - -// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for -// broadcasted scalar values. -TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); - HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( - r1f32, HloOpcode::kMaximum, param0, max_value)); - builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Maximum(param0, max_value), min_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Clamp(max_value, param0, min_value)); -} - -// Test that min(max(A, non-constant1), non-constant2) is not canonicalized to -// clamp(non-constant1, A, non-constant2) -TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32, "param1")); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateParameter(2, r0f32, "param2")); - HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMaximum, param0, max_value)); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Maximum(param0, max_value), min_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Maximum(param0, max_value), min_value)); -} - -// Test that min(f(max(A, constant1)), constant2) is not transformed to -// clamp(constant1, A, constant2) -TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); - HloComputation::Builder builder(TestName()); - HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); - HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); - HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMaximum, param0, max_value)); - HloInstruction* fmax = builder.AddInstruction( - HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, max, max_value)); - builder.AddInstruction(HloInstruction::CreateBinary( - r0f32, HloOpcode::kMinimum, fmax, min_value)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), - min_value)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), - min_value)); -} - // Test that slice(broadcast(/*scalar value*/)) simplifies to a single // broadcast. TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { @@ -2271,7 +2245,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloComputation::Builder builder(TestName()); HloInstruction* forty_two = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6}); HloInstruction* broadcast = builder.AddInstruction( @@ -2318,7 +2292,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { padding.mutable_dimensions(3)->set_edge_padding_high(2); HloInstruction* pad_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding)); @@ -2349,7 +2323,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { const Shape reduce_window_shape = ShapeUtil::MakeShape(F32, {111, 113, 113, 115}); HloInstruction* reduce_init_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); HloInstruction* reduce_window = builder.AddInstruction(HloInstruction::CreateReduceWindow( reduce_window_shape, pad, reduce_init_value, window, @@ -2400,7 +2374,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { padding.mutable_dimensions(3)->set_edge_padding_high(2); HloInstruction* pad_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(BF16, {1, 3, 3, 5}), parameter, pad_value, padding)); @@ -2435,7 +2409,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { const Shape reduce_window_shape = ShapeUtil::MakeShape(F32, {111, 113, 113, 115}); HloInstruction* reduce_init_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); HloInstruction* reduce_window = builder.AddInstruction(HloInstruction::CreateReduceWindow( reduce_window_shape, convert, reduce_init_value, window, @@ -2506,9 +2480,9 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { HloComputation::Builder call_builder(TestName() + ".Call"); HloInstruction* zero = call_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0.0f}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0.0f}))); HloInstruction* one = call_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1.0f}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1.0f}))); call_builder.AddInstruction( HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get())); @@ -2524,9 +2498,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { HloComputation::Builder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; - std::unique_ptr value = - Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), - Literal::CreateR1(constant_vector).get()}); + std::unique_ptr value = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(constant_scalar).get(), + LiteralUtil::CreateR1(constant_vector).get()}); builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); auto computation = module().AddEntryComputation(builder.Build()); @@ -2549,8 +2523,8 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { shape, builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "slice_from")), - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))), + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), /*slice_sizes=*/{10, 100, 1000})); auto computation = module().AddEntryComputation(builder.Build()); @@ -2583,8 +2557,8 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { builder.AddInstruction( HloInstruction::CreateParameter(2, slice_shape, "to_update")), slice, - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))))); + builder.AddInstruction(HloInstruction::CreateParameter( + 3, ShapeUtil::MakeShape(U32, {3}), "update_indices")))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -2599,7 +2573,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { HloComputation::Builder builder(TestName()); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* input_array = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({3, 4}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({3, 4}))); HloInstruction* inner_bcast = builder.AddInstruction( HloInstruction::CreateBroadcast(r2f32, input_array, {1})); Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2}); @@ -2708,7 +2682,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( pad_shape, input, builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), padding)); HloComputation* add_computation = nullptr; @@ -2727,7 +2701,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { Window window = window_util::MakeWindow( decorate_spatials(param.reduce_window_spatials, 1, 1)); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ShapeInference::InferReduceWindowShape( pad->shape(), zero->shape(), window, @@ -2866,7 +2840,7 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k}); auto* lhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k))); Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n}); @@ -2945,7 +2919,7 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n}); auto* rhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n))); DotDimensionNumbers dot_dnums; @@ -2992,7 +2966,7 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { HloInstruction* const update = builder.AddInstruction( HloInstruction::CreateParameter(1, update_shape, "update")); HloInstruction* const start_indices = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0}))); builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( dslice_shape, operand, update, start_indices)); const HloComputation* const computation = @@ -3041,7 +3015,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase); Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols}); auto* lhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows, /*cols=*/lhs_cols))); @@ -3049,7 +3023,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { int32 start_col = (spec.lcd == 0) ? spec.s : 0; const auto start_indices = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({start_row, start_col}))); + LiteralUtil::CreateR1({start_row, start_col}))); int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1; int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k; Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); @@ -3060,7 +3034,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k; Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols}); auto* rhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows, /*cols=*/rhs_cols))); @@ -3108,7 +3082,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k; Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols}); auto* lhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows, /*cols=*/lhs_cols))); @@ -3119,7 +3093,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase); Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols}); auto* rhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows, /*cols=*/rhs_cols))); @@ -3127,7 +3101,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { int32 start_col = (spec.rcd == 0) ? spec.s : 0; const auto start_indices = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({start_row, start_col}))); + LiteralUtil::CreateR1({start_row, start_col}))); int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1; int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k; Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 95b4cb6d2e694063b648b264bd2454ae0a5469ff..51ebc4763b612884a4453edec5711f78c4006fc3 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -109,11 +109,11 @@ Status AllocationTracker::Unregister(const GlobalDataHandle& data) { ResolveInternal(data)); for (const auto& shaped_buffer : replicated_buffers) { std::vector shape_indices; - ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), - [this, &shape_indices](const Shape& /*subshape*/, - const ShapeIndex& index) { - shape_indices.push_back(index); - }); + ShapeUtil::ForEachSubshape( + shaped_buffer->on_device_shape(), + [&shape_indices](const Shape& /*subshape*/, const ShapeIndex& index) { + shape_indices.push_back(index); + }); for (const ShapeIndex& index : shape_indices) { TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index), shaped_buffer->device_ordinal())); diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 349b32451a697dbd6804b44cd1a36419c753bb14..d12be3e007fe0b16ac850d64521f0025d481b5d2 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -96,24 +96,19 @@ Backend::CreateDefaultBackend() { return CreateBackend(backend_options); } -StatusOr Backend::BorrowStream(int device_ordinal) { - TF_ASSIGN_OR_RETURN(auto exec, stream_executor(device_ordinal)); - return BorrowStream(exec); +StatusOr Backend::BorrowStream(int device_ordinal) { + TF_ASSIGN_OR_RETURN(auto executor, stream_executor(device_ordinal)); + return BorrowStream(executor); } -StatusOr Backend::BorrowStream( - se::StreamExecutor* executor) { +StatusOr Backend::BorrowStream(se::StreamExecutor* executor) { tensorflow::mutex_lock l(mu_); if (0 == stream_pools_.count(executor)) { stream_pools_.emplace(std::piecewise_construct, std::forward_as_tuple(executor), - std::forward_as_tuple([executor]() { - auto stream = MakeUnique(executor); - stream->Init(); - return stream; - })); + std::forward_as_tuple()); } - return stream_pools_.at(executor).Allocate(); + return stream_pools_.at(executor).BorrowStream(executor); } Backend::Backend( diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index 6546602473e3381cf13879ddebd05d34d1f7a055..1bc3796fa48c1627538474d04ef5358ba64dfce9 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" -#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -63,11 +63,9 @@ class BackendOptions { // // It also offers a pooling API for creation/use of initialized streams: // -// StreamPtr stream = backend->BorrowStream().ConsumeValueOrDie(); +// StreamPool::Ptr stream = backend->BorrowStream().ConsumeValueOrDie(); class Backend { public: - using StreamPtr = Pool::SmartPtr; - // Creates a new backend. static StatusOr> CreateBackend( const BackendOptions& options); @@ -114,13 +112,13 @@ class Backend { // Borrows a stream for use by the caller, either by grabbing it from an // internal pool, or by constructing/initializating it, and returns the result // to the caller. - StatusOr BorrowStream(int device_ordinal); - StatusOr BorrowStream(se::StreamExecutor* executor); + StatusOr BorrowStream(int device_ordinal); + StatusOr BorrowStream(se::StreamExecutor* executor); // Returns a function to borrow a stream, as `BorrowStream` above does. // Purely for convenience, the caller could rather make this anonymous // function itself. - std::function(int)> StreamBorrower() { + std::function(int)> StreamBorrower() { return [this](int device_ordinal) { return BorrowStream(device_ordinal); }; } @@ -169,7 +167,7 @@ class Backend { tensorflow::mutex mu_; // Mapping from stream executor to stream pools, used by `BorrowStream` above. - std::map> stream_pools_ GUARDED_BY(mu_); + std::map stream_pools_ GUARDED_BY(mu_); // The default memory allocator to use. std::unique_ptr memory_allocator_; diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index ec13fadbc75e2315d1d6ef72e24a0faca0c7de40..c4cd60c1201f7ddbf0aba4b6d587952531b74bfa 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -34,6 +35,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -41,6 +43,8 @@ namespace xla { namespace { +using tensorflow::gtl::optional; + // BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm // operations into smaller operations. class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { @@ -97,7 +101,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { add_instruction(HloInstruction::CreateConvert( ShapeUtil::MakeShape(operand->shape().element_type(), {}), add_instruction(HloInstruction::CreateConstant( - Literal::CreateR0(-0.5f))))), + LiteralUtil::CreateR0(-0.5f))))), {})); return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower, operand, exponent); @@ -113,7 +117,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { add_instruction(HloInstruction::CreateConvert( ShapeUtil::MakeShape(operand->shape().element_type(), {}), add_instruction(HloInstruction::CreateConstant( - Literal::CreateR0(1.0 / element_count))))), + LiteralUtil::CreateR0(1.0 / element_count))))), {})); return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply, operand, elem_count_recip); @@ -200,11 +204,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( HloInstruction* offset = batch_norm->mutable_operand(2); const Shape feature_shape = scale->shape(); - auto zero_literal = Literal::CreateR0(0.0f); + auto zero_literal = LiteralUtil::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); - auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); + auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); auto epsilon = add(HloInstruction::CreateBroadcast( operand_shape, @@ -288,16 +292,22 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( int64 instruction_count_after = computation_->instruction_count(); CHECK_EQ(instruction_count_after, instruction_count_before + added_instructions.size()); + const HloSharding& sharding = batch_norm->sharding(); HloSharding operand_sharding = - batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0}); + sharding.GetAsShapeTree(batch_norm->shape()).element({0}); + optional unique_device = batch_norm->sharding_unique_device(); + HloSharding default_sharding = + unique_device.has_value() + ? HloSharding::AssignDevice(unique_device.value()) + : HloSharding::Replicate(); for (HloInstruction* inst : added_instructions) { if (ShapeUtil::Equal(inst->shape(), operand_shape)) { inst->set_sharding(operand_sharding); } else { - inst->set_sharding(HloSharding::Replicate()); + inst->set_sharding(default_sharding); } } - tuple->set_sharding(batch_norm->sharding()); + tuple->set_sharding(sharding); } TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple))); return Status::OK(); @@ -320,7 +330,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( HloInstruction* var = batch_norm->mutable_operand(4); const Shape feature_shape = scale->shape(); - auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); + auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast( operand_shape, @@ -388,14 +398,20 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( CHECK_EQ(instruction_count_after, instruction_count_before + added_instructions.size()); if (batch_norm->has_sharding()) { + const HloSharding& sharding = batch_norm->sharding(); + optional unique_device = batch_norm->sharding_unique_device(); + HloSharding default_sharding = + unique_device.has_value() + ? HloSharding::AssignDevice(unique_device.value()) + : HloSharding::Replicate(); for (HloInstruction* inst : added_instructions) { if (ShapeUtil::Equal(inst->shape(), operand_shape)) { - inst->set_sharding(batch_norm->sharding()); + inst->set_sharding(sharding); } else { - inst->set_sharding(HloSharding::Replicate()); + inst->set_sharding(default_sharding); } } - shifted_normalized->set_sharding(batch_norm->sharding()); + shifted_normalized->set_sharding(sharding); } TF_CHECK_OK( ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized))); @@ -447,11 +463,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( const int64 feature_count = activation_shape.dimensions(feature_index); const int64 elements_per_feature_int64 = size_in_elements / feature_count; - auto zero_literal = Literal::CreateR0(0.0f); + auto zero_literal = LiteralUtil::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); - auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); + auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); auto epsilon_scalar = add(HloInstruction::CreateConstant(std::move(epsilon_literal))); @@ -542,7 +558,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add)); auto elements_per_feature_literal = - Literal::CreateR0(elements_per_feature_int64); + LiteralUtil::CreateR0(elements_per_feature_int64); TF_ASSIGN_OR_RETURN(elements_per_feature_literal, elements_per_feature_literal->Convert(ptype)); auto elements_per_feature = add( @@ -562,19 +578,25 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( auto tuple = HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}); if (batch_norm->has_sharding()) { + const HloSharding& sharding = batch_norm->sharding(); int64 instruction_count_after = computation_->instruction_count(); CHECK_EQ(instruction_count_after, instruction_count_before + added_instructions.size()); HloSharding activation_sharding = - batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0}); + sharding.GetAsShapeTree(batch_norm->shape()).element({0}); + auto unique_device = batch_norm->sharding_unique_device(); + HloSharding default_sharding = + unique_device.has_value() + ? HloSharding::AssignDevice(unique_device.value()) + : HloSharding::Replicate(); for (HloInstruction* inst : added_instructions) { if (ShapeUtil::Equal(inst->shape(), activation_shape)) { inst->set_sharding(activation_sharding); } else { - inst->set_sharding(HloSharding::Replicate()); + inst->set_sharding(default_sharding); } } - tuple->set_sharding(batch_norm->sharding()); + tuple->set_sharding(sharding); } TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple))); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index aa36e64b07099a372dab67babc7a18a2d39596bc..a725351462809e5b670bbf1d79d2dded87e54f07 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -19,12 +19,13 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -114,5 +115,33 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) { EXPECT_EQ(root->opcode(), HloOpcode::kTuple); } +TEST_F(BatchNormExpanderTest, BatchNormTrainingSharding) { + const char* module_str = R"( +HloModule module +ENTRY entry { + %param.0 = f32[8,4] parameter(0) + %param.1 = f32[4] parameter(1) + %param.2 = f32[4] parameter(2) + ROOT %batch-norm-training = (f32[8,4], f32[4], f32[4]) + batch-norm-training(f32[8,4] %param.0, f32[4] %param.1, f32[4] %param.2), + epsilon=0.001, feature_index=1, sharding={maximal device=1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(module_str)); + BatchNormExpander rewriter(/*rewrite_training_op=*/true, + /*rewrite_inference_op=*/true, + /*rewrite_grad_op=*/true); + ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); + + for (auto* instruction : module->entry_computation()->instructions()) { + if (instruction->opcode() == HloOpcode::kParameter) { + continue; + } + auto device = instruction->sharding_unique_device(); + ASSERT_TRUE(device); + EXPECT_EQ(*device, 1); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index f7b4c1405dbc8719d8fba5476e6e41d2921ea877..7cf05ca443c00c3b40eeb7d756cf216b45c45c39 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -235,7 +235,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, - sum, /*replica_group_ids=*/{}, /*barrier=*/"")); + sum, /*replica_group_ids=*/{}, /*barrier=*/"", + /*all_reduce_id=*/tensorflow::gtl::nullopt)); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); HloInstruction* gte_b = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index 14c54ddd135af024327f63418b410da1ed3c4fd4..16e99b57220cc185fbfaa75d30a0de709cf61ee7 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -34,8 +34,10 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo) override; - // Special handling for cross-replica-sum which can have a tuple output. + // Special handling for cross-replica-sum and sort which can have a tuple + // output. Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleSort(HloInstruction* sort) override; static bool Run(HloComputation* computation, const BFloat16Support* bfloat16_support) { @@ -49,6 +51,10 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { // conversions between F32 and BF16 to make it supported. Status HandleInstruction(HloInstruction* hlo); + // Handle instructions with tuple outputs by examining each output + // independently. + Status HandleMultipleOutputs(HloInstruction* hlo); + // Inserts a conversion HLO that changes the given HLO's output type. Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to, HloComputation* computation); @@ -148,22 +154,35 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( HloInstruction* crs) { if (!ShapeUtil::IsTuple(crs->shape())) { return HandleInstruction(crs); + } else { + return HandleMultipleOutputs(crs); } +} + +Status BFloat16NormalizationVisitor::HandleSort(HloInstruction* sort) { + if (!ShapeUtil::IsTuple(sort->shape())) { + return HandleInstruction(sort); + } else { + return HandleMultipleOutputs(sort); + } +} - std::vector operand_types(crs->operand_count()); - std::vector output_types(crs->operand_count()); +Status BFloat16NormalizationVisitor::HandleMultipleOutputs( + HloInstruction* hlo) { + std::vector operand_types(hlo->operand_count()); + std::vector output_types(hlo->operand_count()); int64 f32_count = 0; int64 bf16_count = 0; bool has_unsupported_bf16_operand = false; bool has_unsupported_bf16_output = false; - for (int64 i = 0; i < crs->operand_count(); ++i) { - operand_types[i] = crs->operand(i)->shape().element_type(); - output_types[i] = ShapeUtil::GetSubshape(crs->shape(), {i}).element_type(); + for (int64 i = 0; i < hlo->operand_count(); ++i) { + operand_types[i] = hlo->operand(i)->shape().element_type(); + output_types[i] = ShapeUtil::GetSubshape(hlo->shape(), {i}).element_type(); if (operand_types[i] == F32) { f32_count += 1; } else if (operand_types[i] == BF16) { bf16_count += 1; - if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) { + if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) { has_unsupported_bf16_operand = true; } } @@ -171,7 +190,7 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( f32_count += 1; } else if (output_types[i] == BF16) { bf16_count += 1; - if (!bfloat16_support_->SupportsBF16Output(*crs)) { + if (!bfloat16_support_->SupportsBF16Output(*hlo)) { has_unsupported_bf16_output = true; } } @@ -185,43 +204,43 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( if (operand_types[i] != BF16) { return false; } - if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) { + if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) { return true; } - if (bfloat16_support_->SupportsMixedPrecisions(*crs)) { + if (bfloat16_support_->SupportsMixedPrecisions(*hlo)) { return false; } return has_unsupported_bf16_operand || has_unsupported_bf16_output || f32_count > 0; }; - for (int64 i = 0; i < crs->operand_count(); ++i) { + for (int64 i = 0; i < hlo->operand_count(); ++i) { if (should_convert_operand(i)) { - TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_)); + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); f32_count += 1; bf16_count -= 1; } } if (!has_unsupported_bf16_output && - (bfloat16_support_->SupportsMixedPrecisions(*crs) || f32_count == 0 || + (bfloat16_support_->SupportsMixedPrecisions(*hlo) || f32_count == 0 || bf16_count == 0)) { return Status::OK(); } - std::vector materialized_users = crs->users(); - std::vector output_elements(crs->operand_count()); - auto original_shape = crs->shape(); - for (int64 i = 0; i < crs->operand_count(); ++i) { - auto subshape = ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i}); + std::vector materialized_users = hlo->users(); + std::vector output_elements(hlo->operand_count()); + auto original_shape = hlo->shape(); + for (int64 i = 0; i < hlo->operand_count(); ++i) { + auto subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), {i}); if (output_types[i] != BF16) { output_elements[i] = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(*subshape, crs, i)); + HloInstruction::CreateGetTupleElement(*subshape, hlo, i)); continue; } subshape->set_element_type(F32); auto gte = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(*subshape, crs, i)); + HloInstruction::CreateGetTupleElement(*subshape, hlo, i)); output_elements[i] = computation_->AddInstruction(HloInstruction::CreateConvert( ShapeUtil::ChangeElementType(*subshape, BF16), gte)); @@ -229,11 +248,11 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( auto tuple = computation_->AddInstruction( HloInstruction::CreateTuple(output_elements)); - // Use the crs' shape temporarily, in order to pass checks in + // Use the hlo' shape temporarily, in order to pass checks in // ReplaceUseWith. - *tuple->mutable_shape() = crs->shape(); + *tuple->mutable_shape() = hlo->shape(); for (auto* user : materialized_users) { - TF_RETURN_IF_ERROR(crs->ReplaceUseWith(user, tuple)); + TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, tuple)); } *tuple->mutable_shape() = original_shape; return Status::OK(); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 830f26422bdc2b3bd789e7d5926bcebac815d34a..f9f1f64998f5b925102dc238941897ff6d441b3f 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -251,7 +251,8 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, - /*replica_group_ids=*/{}, /*barrier=*/"")); + /*replica_group_ids=*/{}, /*barrier=*/"", + /*all_reduce_id=*/tensorflow::gtl::nullopt)); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); @@ -265,6 +266,33 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32); } +TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {1024}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024}); + Shape s32_shape = ShapeUtil::MakeShape(BF16, {1024}); + + HloInstruction* key = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "key")); + HloInstruction* value = builder.AddInstruction( + HloInstruction::CreateParameter(1, s32_shape, "value")); + + HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, value)); + HloInstruction* gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction(), gte); + EXPECT_EQ(gte->shape().element_type(), BF16); + EXPECT_EQ(sort->operand(0)->shape().element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32); +} + // Tests that the normalization should not cause unsupported mixed precision due // to resolving unsupported BF16 operand. TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index ed0746980f87ac2bea79c308644dc63769f9e309..2fb401c4289728f3f59538464c5b8ad49957985b 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/bfloat16_propagation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -85,9 +85,9 @@ void BFloat16Propagation::RevertIfFusionInternalBF16Changes( auto root_changes_it = changes_to_bf16_.find(root); if (root_changes_it != changes_to_bf16_.end()) { - for (const auto& index : root_changes_it->second) { + for (const auto& entry : root_changes_it->second) { for (const HloValue* value : - dataflow_->GetValueSet(root, index).values()) { + dataflow_->GetValueSet(root, entry.second).values()) { changed_root_buffers.insert(value); } } @@ -204,12 +204,23 @@ void BFloat16Propagation::DetermineWhileComputationsPrecision( bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, const ShapeIndex& index) const { + // If the subshape isn't floating point then none of the users will be BF16. + const Shape& subshape = ShapeUtil::GetSubshape(hlo.shape(), index); + if (subshape.element_type() != BF16 && subshape.element_type() != F32) { + return false; + } + auto& value_set = dataflow_->GetValueSet(&hlo, index); for (const HloValue* value : value_set.values()) { if (ContainsKey(values_that_must_be_kept_as_f32_, value)) { return false; } - if (ValueTypeAfterChange(value) == BF16) { + // We use the original type for the value because we are going to examine + // the uses of it, instead of the value itself. If ValueTypeAfterChange() + // were used, it would cause problems when there are aliasing buffers, i.e., + // ResolveInconsistencyOfAliasingBuffers() would fail to revert the + // tentative change to BF16 even if the uses require F32. + if (value->shape().element_type() == BF16) { continue; } for (const HloUse& use : value->uses()) { @@ -257,23 +268,34 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, // If the op propagates precision and it outputs a BF16, then it's OK to // supply BF16 also as the input. In the backward pass, the users shapes // should have already been processed. - PrimitiveType user_output_type = PRIMITIVE_TYPE_INVALID; - if (use.instruction->opcode() == HloOpcode::kTuple || - (use.instruction->opcode() == HloOpcode::kCrossReplicaSum && - ShapeUtil::IsTuple(use.instruction->shape()))) { - ShapeIndex use_output_index{use.operand_number}; - for (int64 i : use.operand_index) { - use_output_index.push_back(i); - } - user_output_type = - OutputTypeAfterChange(use.instruction, use_output_index); - } else { - user_output_type = OutputTypeAfterChange(use.instruction, {}); - } if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision( - *use.instruction, use.operand_number) && - user_output_type == BF16) { - continue; + *use.instruction, use.operand_number)) { + if (use.instruction->opcode() == HloOpcode::kTuple || + (use.instruction->opcode() == HloOpcode::kCrossReplicaSum && + ShapeUtil::IsTuple(use.instruction->shape()))) { + ShapeIndex use_output_index{use.operand_number}; + for (int64 i : use.operand_index) { + use_output_index.push_back(i); + } + if (OutputTypeAfterChange(use.instruction, use_output_index) == + BF16) { + continue; + } + } else if (use.instruction->opcode() == HloOpcode::kGetTupleElement) { + ShapeIndex use_output_index; + for (int64 i = 1; i < use.operand_index.size(); ++i) { + use_output_index.push_back(use.operand_index[i]); + } + if (OutputTypeAfterChange(use.instruction, use_output_index) == + BF16) { + continue; + } + } else { + if (OutputTypeAfterChange(use.instruction, use.operand_index) == + BF16) { + continue; + } + } } return false; } @@ -368,6 +390,7 @@ bool BFloat16Propagation::InstructionIsCandidateForBF16Output( if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) && hlo->opcode() != HloOpcode::kTuple && hlo->opcode() != HloOpcode::kGetTupleElement && + hlo->opcode() != HloOpcode::kDomain && hlo->shape().element_type() != BF16) { for (int64 i = 0; i < hlo->operand_count(); ++i) { if (!bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo, @@ -548,6 +571,9 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( } visited_computations->insert(visited_in_while.begin(), visited_in_while.end()); + } else if (hlo->opcode() == HloOpcode::kFusion) { + ResolveInconsistencyOfAliasingBuffersHelper( + hlo->fused_instructions_computation(), visited_computations); } } // Now adjust parameters of called computations. @@ -559,7 +585,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( HloModule* module) { - std::list computations_topological_order = + const auto& computations_topological_order = module->MakeComputationPostOrder(); tensorflow::gtl::FlatSet resolved; for (auto comp_it = computations_topological_order.rbegin(); @@ -597,7 +623,6 @@ Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) { // (1) a is F32 but tuple is BF16 // (2) after adding conversion // (3) after tuple simplifier and DCE. - bool needs_tuple_simplifier = false; for (auto computation : module->MakeComputationPostOrder()) { auto insts = computation->MakeInstructionPostOrder(); for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { @@ -611,67 +636,25 @@ Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) { continue; } ShapeTree converted_outputs(hlo->shape()); - // Iterate through nodes in the shape tree in pre-order and initialize - // each non-root node with a corresponding get-tuple-element. For a leaf - // node, if its shape does not match the fusion output, create a - // conversion node to overwrite the node value. - for (auto it = converted_outputs.begin(); it != converted_outputs.end(); - ++it) { - ShapeIndex output_index = it->first; - HloInstruction*& output = it->second; - const Shape subshape = - ShapeUtil::GetSubshape(hlo->shape(), output_index); - if (output_index.empty()) { - output = fusion_root; - } else { - ShapeIndex parent_index = output_index; - parent_index.pop_back(); - output = fusion_computation->AddInstruction( - HloInstruction::CreateGetTupleElement( - subshape, converted_outputs.element(parent_index), - output_index.back())); - } - if (ShapeUtil::IsTuple(subshape)) { - continue; - } - if (!ShapeUtil::Compatible( - subshape, - ShapeUtil::GetSubshape(fusion_root->shape(), output_index))) { - output = fusion_computation->AddInstruction( - HloInstruction::CreateConvert(subshape, output)); - } - } - // Iterate through nodes in the shape tree in reverse pre-order and create - // a tuple instruction for each non-leaf node where the elements are the - // values of its child nodes. - for (auto it = converted_outputs.rbegin(); it != converted_outputs.rend(); - ++it) { - ShapeIndex output_index = it->first; - HloInstruction*& output = it->second; - const Shape& subshape = - ShapeUtil::GetSubshape(hlo->shape(), output_index); - if (!ShapeUtil::IsTuple(subshape)) { - continue; - } - std::vector elements( - ShapeUtil::TupleElementCount(subshape)); - ShapeIndex child_index = output_index; - for (int64 i = 0; i < elements.size(); ++i) { - child_index.push_back(i); - elements[i] = converted_outputs.element(child_index); - child_index.pop_back(); - } - output = fusion_computation->AddInstruction( - HloInstruction::CreateTuple(elements)); - } - fusion_computation->set_root_instruction(converted_outputs.element({})); - needs_tuple_simplifier |= ShapeUtil::IsTuple(hlo->shape()); + // Deep copy the fusion root, and convert a leaf node only if its shape + // does not match the fusion output. + TF_ASSIGN_OR_RETURN( + HloInstruction * copy, + fusion_computation->DeepCopyInstructionWithCustomCopier( + fusion_root, + [hlo](HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* comp) { + const Shape& hlo_subshape = + ShapeUtil::GetSubshape(hlo->shape(), leaf_index); + if (ShapeUtil::Compatible(leaf->shape(), hlo_subshape)) { + return leaf; + } + return comp->AddInstruction( + HloInstruction::CreateConvert(hlo_subshape, leaf)); + })); + fusion_computation->set_root_instruction(copy); } } - if (needs_tuple_simplifier) { - TupleSimplifier tuple_simplifier; - TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); - } return Status::OK(); } @@ -740,10 +723,38 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { changes_to_bf16_.clear(); changed_ = false; + auto computations_topological_order = module->MakeComputationPostOrder(); + + // Before running the propagation pass, we insert copies (kConvert to the same + // type) of F32 inputs to while loops. This prevents other uses of the same + // input from aliasing the while loop input/output, so that there's greater + // chance to use BF16 inside the loop. If some of these added copies do not + // help, they will remain F32 after BF16 propagation and will be removed since + // they are no-ops. + for (auto computation : computations_topological_order) { + for (auto inst : computation->MakeInstructionPostOrder()) { + if (inst->opcode() != HloOpcode::kWhile) { + continue; + } + + auto operand = inst->mutable_operand(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * copy, + computation->DeepCopyInstructionWithCustomCopier( + operand, [](HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* comp) { + if (leaf->shape().element_type() != F32) { + return leaf; + } + return comp->AddInstruction( + HloInstruction::CreateConvert(leaf->shape(), leaf)); + })); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(inst, copy)); + } + } + TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module)); - std::list computations_topological_order = - module->MakeComputationPostOrder(); // The first step is a forward pass (parameters to root), where we determine // the potential candidate instructions to use bfloat16 in the outputs that // are not likely to cause overhead from extra explicit conversions. This is @@ -766,8 +777,7 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { // propagation in reverse topological order. for (auto comp_it = computations_topological_order.rbegin(); comp_it != computations_topological_order.rend(); ++comp_it) { - if ((*comp_it)->IsFusionComputation()) { - // Fusion computations are handled when visiting the fusion instruction. + if (ContainsKey(computations_visited_in_backward_pass_, *comp_it)) { continue; } auto insts = (*comp_it)->MakeInstructionPostOrder(); @@ -775,6 +785,7 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/true); } + computations_visited_in_backward_pass_.insert(*comp_it); } // It's possible that an instruction does not define a buffer, but the @@ -784,39 +795,42 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { // Apply the changes in changes_to_bf16_. for (auto& change : changes_to_bf16_) { - auto shape = change.first->mutable_shape(); - for (const auto& index : change.second) { - auto subshape = ShapeUtil::GetMutableSubshape(shape, index); + for (const auto& entry : change.second) { + auto subshape = entry.first; CHECK_EQ(subshape->element_type(), F32); subshape->set_element_type(BF16); changed_ = true; } } + // Removes redundant HLOs added by this pass, either when inserting + // de-aliasing copies to while loop inputs, or later when converting output + // types. + auto clean_up = [this, module]() { + TF_RETURN_IF_ERROR(SkipNoopConversions(module)); + TupleSimplifier tuple_simplifier; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + HloDCE dce; + TF_RETURN_IF_ERROR(dce.Run(module).status()); + return Status::OK(); + }; + if (!changed_) { + TF_RETURN_IF_ERROR(clean_up()); return false; } TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module)); TF_RETURN_IF_ERROR(ResolveConvertedConstants(module)); - // This pass could have turned an F32 -> BF16 conversion to a no-op (BF16 -> - // BF16), so we skip them now. - TF_RETURN_IF_ERROR(SkipNoopConversions(module)); - - { - // We may have dead HLOs after ResolveInconsistentFusions, - // ResolveConvertedConstants and SkipNoopConversions. - HloDCE dce; - TF_RETURN_IF_ERROR(dce.Run(module).status()); - } + TF_RETURN_IF_ERROR(clean_up()); return true; } PrimitiveType BFloat16Propagation::OutputTypeAfterChange( HloInstruction* hlo, const ShapeIndex& index) const { - PrimitiveType type_on_hlo = - ShapeUtil::GetSubshape(hlo->shape(), index).element_type(); + Shape* subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index); + const PrimitiveType type_on_hlo = subshape->element_type(); if (type_on_hlo != F32) { return type_on_hlo; } @@ -824,7 +838,7 @@ PrimitiveType BFloat16Propagation::OutputTypeAfterChange( if (it == changes_to_bf16_.end()) { return type_on_hlo; } - return ContainsKey(it->second, index) ? BF16 : F32; + return ContainsKey(it->second, subshape) ? BF16 : F32; } PrimitiveType BFloat16Propagation::ValueTypeAfterChange( @@ -838,14 +852,16 @@ void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet( HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) { if (target_type == BF16) { auto& entry = changes_to_bf16_[hlo]; - entry.insert(index); + entry.emplace(ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index), + index); } else { CHECK_EQ(target_type, F32); auto it = changes_to_bf16_.find(hlo); if (it == changes_to_bf16_.end()) { return; } - it->second.erase(index); + it->second.erase( + ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index)); } } diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index de0355ddfca127753f90d1899b424a8e77c9b291..02b8cad089dd8465b7af5c1014e37b77ded6949d 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -194,17 +194,11 @@ class BFloat16Propagation : public HloPassInterface { // are subject to further adjustment, then finally applied to the HLOs. This // avoids setting changed_ to true but all changes are reverted during // adjustment. - struct IndexHasher { - int64 operator()(const ShapeIndex& index) const { - int64 hash = 0; - for (int64 i : index) { - hash = tensorflow::Hash64Combine(hash, std::hash()(i)); - } - return hash; - } - }; + // + // For each HloInstruction, changes_to_bf16_ stores the affected buffers in + // the output as a map from in-place pointers to subshapes to shape indices. tensorflow::gtl::FlatMap> + tensorflow::gtl::FlatMap> changes_to_bf16_; // Whether the last processed HLO module has been changed by this pass. diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 5e1499ee6b6ef397f95f7ed29e808d530777bd07..69b654d30e42b1ed69304206f09120e86831d468 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -133,9 +133,9 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { array_b.FillUnique(10.0f); HloInstruction* a = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateFromArray(array_a))); + HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a))); HloInstruction* b = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateFromArray(array_b))); + HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b))); HloInstruction* dot = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b)); @@ -150,11 +150,11 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); EXPECT_TRUE(LiteralTestUtil::Equal( - dot->operand(0)->literal(), - *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)))); + *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)), + dot->operand(0)->literal())); EXPECT_TRUE(LiteralTestUtil::Equal( - dot->operand(1)->literal(), - *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)))); + *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)), + dot->operand(1)->literal())); } // Tests that BF16 can be propagated through nested tuples. @@ -240,12 +240,10 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); - EXPECT_TRUE(OutputsBF16(add0)); EXPECT_TRUE(OutputsBF16(add1)); EXPECT_TRUE(OutputsBF16(lhs)); - // rhs is a get-tuple-element, which does not define a buffer, but its shape - // should also be adjusted accordingly. - EXPECT_TRUE(OutputsBF16(rhs)); + + // add0 and rhs have been eliminated by simplification and DCE. } // Tests that a non-fusion computation's root should not be changed. @@ -434,7 +432,7 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) { HloInstruction* tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({param, add1})); HloInstruction* sel = builder.AddInstruction(HloInstruction::CreateTernary( - tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); + tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1)); HloInstruction* gte0 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, sel, 0)); HloInstruction* gte1 = builder.AddInstruction( @@ -510,6 +508,63 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { EXPECT_FALSE(OutputsBF16(dot)); } +// Tests that if the while condition prevents using BF16, no changes should be +// made to the while body and thus the fusion node inside it. +TEST_F(BFloat16PropagationTest, + ConditionPreventsPropagationForFusionInsideWhile) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param1")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + + auto builder_cond = HloComputation::Builder("cond"); + auto cond_param = builder_cond.AddInstruction( + HloInstruction::CreateParameter(0, shape, "cond_param")); + builder_cond.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond_param, {0, 0}, {1, 1}, {1, 1})), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond_param, {1, 1}, {2, 2}, {1, 1})))); + auto cond = module->AddEmbeddedComputation(builder_cond.Build()); + + auto builder_body = HloComputation::Builder("body"); + auto body_param = builder_body.AddInstruction( + HloInstruction::CreateParameter(0, shape, "body_param")); + auto body_transpose = builder_body.AddInstruction( + HloInstruction::CreateTranspose(shape, body_param, {0, 1})); + + auto builder_f = HloComputation::Builder("fusion"); + HloInstruction* a_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + builder_f.AddInstruction(HloInstruction::CreateTranspose(shape, a_f, {0, 1})); + auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); + auto body_fusion = builder_body.AddInstruction(HloInstruction::CreateFusion( + shape, HloInstruction::FusionKind::kCustom, {body_transpose}, comp_f)); + auto body = module->AddEmbeddedComputation(builder_body.Build()); + + auto while_hlo = builder.AddInstruction( + HloInstruction::CreateWhile(shape, cond, body, add)); + + auto dot = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, while_hlo, while_hlo)); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_FALSE(OutputsBF16(add)); + EXPECT_FALSE(OutputsBF16(body_fusion)); + EXPECT_FALSE(OutputsBF16(body_param)); + EXPECT_FALSE(OutputsBF16(body_transpose)); + EXPECT_FALSE(OutputsBF16(a_f)); +} + // Tests that BF16 is propagated properly through while computations with // tuple-shaped input/output. TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { @@ -555,10 +610,14 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { HloInstruction::CreateGetTupleElement(shape, body_param, 0)); auto body_rhs = builder_body.AddInstruction( HloInstruction::CreateGetTupleElement(shape, body_param, 1)); - auto body_dot = builder_body.AddInstruction( + auto body_dot1 = builder_body.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); + auto body_dot2 = builder_body.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_rhs, body_lhs)); + auto body_transpose = builder_body.AddInstruction( + HloInstruction::CreateTranspose(shape, body_dot2, {0, 1})); builder_body.AddInstruction( - HloInstruction::CreateTuple({body_dot, body_rhs})); + HloInstruction::CreateTuple({body_dot1, body_transpose})); auto body = module->AddEmbeddedComputation(builder_body.Build()); auto while_hlo = builder.AddInstruction( @@ -577,9 +636,11 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(lhs)); EXPECT_FALSE(OutputsBF16(rhs)); - EXPECT_TRUE(OutputsBF16(body_dot)); + EXPECT_TRUE(OutputsBF16(body_dot1)); EXPECT_TRUE(OutputsBF16(body_lhs)); EXPECT_FALSE(OutputsBF16(body_rhs)); + EXPECT_FALSE(OutputsBF16(body_dot2)); + EXPECT_FALSE(OutputsBF16(body_transpose)); EXPECT_TRUE(OutputsBF16(cond_lhs)); EXPECT_FALSE(OutputsBF16(cond_rhs)); EXPECT_TRUE(OutputsBF16(add0)); @@ -734,12 +795,95 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), add2); - EXPECT_EQ(add2->operand(0), gte0); - EXPECT_EQ(add2->operand(1), gte1); - EXPECT_EQ(gte0->shape().element_type(), BF16); - EXPECT_EQ(gte1->shape().element_type(), BF16); + EXPECT_EQ(add2->operand(0), add0); + EXPECT_EQ(add2->operand(1), add1); EXPECT_EQ(add0->shape().element_type(), BF16); EXPECT_EQ(add1->shape().element_type(), BF16); } +TEST_F(BFloat16PropagationTest, TupleDomain) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* a_trans = + builder.AddInstruction(HloInstruction::CreateTranspose(shape, a, {0, 1})); + HloInstruction* b_trans = + builder.AddInstruction(HloInstruction::CreateTranspose(shape, b, {0, 1})); + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({a_trans, b_trans})); + HloInstruction* domain = builder.AddInstruction( + HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr)); + HloInstruction* a_gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, domain, 0)); + HloInstruction* b_gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, domain, 1)); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_gte, b_gte)); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_EQ(computation->root_instruction(), root); + + // test BF16 propagated through domain + EXPECT_EQ(ShapeUtil::GetTupleElementShape(domain->shape(), 0).element_type(), + BF16); + EXPECT_EQ(ShapeUtil::GetTupleElementShape(domain->shape(), 1).element_type(), + BF16); + + EXPECT_TRUE(OutputsBF16(a_trans)); + EXPECT_TRUE(OutputsBF16(b_trans)); + EXPECT_TRUE(OutputsBF16(a_gte)); + EXPECT_TRUE(OutputsBF16(b_gte)); + EXPECT_FALSE(OutputsBF16(a)); + EXPECT_FALSE(OutputsBF16(b)); +} + +// Tests that bf16 is not propagated through a domain in case its input cannot +// be propagated. In the case below the input of the domain is the parameter +// tuple which cannot be propagated, so the domain instruction is not propagated +// either. +TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + HloInstruction* domain = builder.AddInstruction( + HloInstruction::CreateDomain(param->shape(), param, nullptr, nullptr)); + HloInstruction* a_gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, domain, 0)); + HloInstruction* b_gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, domain, 1)); + HloInstruction* a_trans = builder.AddInstruction( + HloInstruction::CreateTranspose(shape, a_gte, {0, 1})); + HloInstruction* b_trans = builder.AddInstruction( + HloInstruction::CreateTranspose(shape, b_gte, {0, 1})); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_trans, b_trans)); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), root); + EXPECT_TRUE(OutputsBF16(a_trans)); + EXPECT_TRUE(OutputsBF16(b_trans)); + EXPECT_FALSE(OutputsBF16(a_gte)); + EXPECT_FALSE(OutputsBF16(b_gte)); + EXPECT_FALSE(OutputsBF16(domain)); + EXPECT_FALSE(OutputsBF16(param)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index 07b4b14b5ec1bdbc01345091105df69368b0b2fb..23645346e6f491beb5171cc839c013ce5f83d789 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -25,6 +25,7 @@ bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo, case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kCustomCall: + case HloOpcode::kDomain: case HloOpcode::kGetTupleElement: case HloOpcode::kTuple: case HloOpcode::kWhile: @@ -43,6 +44,7 @@ bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kCustomCall: + case HloOpcode::kDomain: case HloOpcode::kGetTupleElement: case HloOpcode::kTuple: case HloOpcode::kWhile: @@ -81,6 +83,7 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( case HloOpcode::kConcatenate: case HloOpcode::kConvert: case HloOpcode::kCopy: + case HloOpcode::kDomain: case HloOpcode::kGetTupleElement: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -92,11 +95,15 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( case HloOpcode::kTranspose: case HloOpcode::kTuple: return true; + case HloOpcode::kBitcast: + return hlo.shape().element_type() == + hlo.operand(0)->shape().element_type(); case HloOpcode::kDynamicSlice: return operand_index == 0; case HloOpcode::kDynamicUpdateSlice: return operand_index == 0 || operand_index == 1; case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: return operand_index == 1 || operand_index == 2; default: break; diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 5d3b0cb333928d3b7b042ef0a6f4969f87655d7f..cfd26fc778cbf9b031fb73259eb76538327b2a6c 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -139,6 +139,7 @@ Status GatherComputationsByAllocationType( case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: + case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: case HloOpcode::kFusion: // Map/reduce etc computations are always thread-local. @@ -270,7 +271,7 @@ BufferAllocationProto BufferAllocation::ToProto() const { proto.set_index(index_); proto.set_size(size_); proto.set_is_thread_local(is_thread_local_); - proto.set_is_reusable(is_reusable_); + proto.set_is_tuple(is_tuple_); proto.set_color(color_.value()); if (is_entry_computation_parameter_) { proto.set_is_entry_computation_parameter(true); @@ -279,6 +280,7 @@ BufferAllocationProto BufferAllocation::ToProto() const { } proto.set_parameter_number(parameter_number_); } + proto.set_is_constant(is_constant_); proto.set_maybe_live_out(maybe_live_out_); for (const auto& buffer_offset_size : assigned_buffers_) { BufferAllocationProto::Assigned* proto_assigned = proto.add_assigned(); @@ -304,6 +306,9 @@ string BufferAllocation::ToString() const { StrAppend(&output, ", parameter ", parameter_number(), " at ShapeIndex ", param_shape_index().ToString()); } + if (is_constant()) { + StrAppend(&output, ", constant"); + } if (is_thread_local()) { StrAppend(&output, ", thread-local"); } @@ -491,20 +496,16 @@ BufferAssignment::GetUniqueTopLevelOutputSlice() const { } BufferAllocation* BufferAssignment::NewEmptyAllocation( - int64 size, bool is_thread_local, bool is_reusable, - LogicalBuffer::Color color) { + int64 size, LogicalBuffer::Color color) { BufferAllocation::Index index = allocations_.size(); - allocations_.emplace_back(index, size, is_thread_local, is_reusable, color); + allocations_.emplace_back(index, size, color); BufferAllocation* allocation = &allocations_.back(); return allocation; } BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer, - int64 size, - bool is_thread_local, - bool is_reusable) { - BufferAllocation* allocation = - NewEmptyAllocation(size, is_thread_local, is_reusable, buffer.color()); + int64 size) { + BufferAllocation* allocation = NewEmptyAllocation(size, buffer.color()); AddAssignment(allocation, buffer, /*offset=*/0, size); allocation->peak_buffers_.push_back(&buffer); return allocation; @@ -517,7 +518,8 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation, CHECK_EQ(0, allocation_index_for_buffer_.count(&buffer)) << "LogicalBuffer " << buffer << " already has an allocation."; CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty()) - << "Non-reusable allocation already assigned a buffer"; + << "Non-reusable allocation already assigned a buffer: " + << allocation->ToString(); TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer)); @@ -609,6 +611,10 @@ Status BufferAssignment::ComputeSummaryStats() { stats_.parameter_allocation_count++; stats_.parameter_allocation_bytes += allocation.size(); } + if (allocation.is_constant()) { + stats_.constant_allocation_count++; + stats_.constant_allocation_bytes += allocation.size(); + } if (allocation.maybe_live_out()) { stats_.maybe_live_out_allocation_count++; stats_.maybe_live_out_allocation_bytes += allocation.size(); @@ -631,8 +637,9 @@ Status BufferAssignment::ComputeSummaryStats() { } } if (module_sequence.size() == module_->computation_count()) { - TF_ASSIGN_OR_RETURN(const int64 min_size, - MinimumMemoryForModule(module_sequence, buffer_size_)); + TF_ASSIGN_OR_RETURN( + const int64 min_size, + HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_)); stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size; } @@ -644,6 +651,8 @@ string BufferAssignment::Stats::ToString() const { Appendf(&s, "BufferAssignment stats:\n"); Appendf(&s, " parameter allocation: %10s\n", HumanReadableNumBytes(parameter_allocation_bytes).c_str()); + Appendf(&s, " constant allocation: %10s\n", + HumanReadableNumBytes(constant_allocation_bytes).c_str()); Appendf(&s, " maybe_live_out allocation: %10s\n", HumanReadableNumBytes(maybe_live_out_allocation_bytes).c_str()); Appendf(&s, " preallocated temp allocation: %10s\n", @@ -721,8 +730,10 @@ StatusOr> BufferAssigner::Run( const HloModule* module, std::unique_ptr hlo_ordering, LogicalBuffer::SizeFunction buffer_size, LogicalBuffer::AlignmentFunction color_alignment, - bool allow_input_output_aliasing, BufferLiveness::Colorer colorer) { - BufferAssigner assigner(allow_input_output_aliasing, std::move(colorer)); + bool allow_input_output_aliasing, bool allocate_buffers_for_constants, + BufferLiveness::Colorer colorer) { + BufferAssigner assigner(allow_input_output_aliasing, + allocate_buffers_for_constants, std::move(colorer)); return assigner.CreateAssignment(module, std::move(hlo_ordering), std::move(buffer_size), std::move(color_alignment)); @@ -750,8 +761,8 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, return false; } - if (allocation->is_entry_computation_parameter()) { - VLOG(4) << "Can't assign: allocation holds parameter"; + if (allocation->is_readonly()) { + VLOG(4) << "Can't assign: allocation is readonly"; return false; } @@ -807,8 +818,7 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, } Status BufferAssigner::AssignBuffersForComputation( - const HloComputation* computation, const DebugOptions& debug_options, - bool is_thread_local, + const HloComputation* computation, bool is_thread_local, const FlatSet& colocated_buffers, const FlatSet& colocated_allocations, FlatMap>* @@ -868,8 +878,8 @@ Status BufferAssigner::AssignBuffersForComputation( // important reuse case where an elementwise instruction reuses one of its // operand's buffer. This improves locality. std::sort(sorted_buffers.begin(), sorted_buffers.end(), - [this, has_sequential_order, &liveness, &post_order_position, - assignment](const LogicalBuffer* a, const LogicalBuffer* b) { + [has_sequential_order, &liveness, &post_order_position, assignment]( + const LogicalBuffer* a, const LogicalBuffer* b) { // Primary sort is by decreasing buffer size. const int64 a_size = assignment->buffer_size_(*a); const int64 b_size = assignment->buffer_size_(*b); @@ -904,15 +914,19 @@ Status BufferAssigner::AssignBuffersForComputation( TF_RET_CHECK(!assignment->HasAllocation(*buffer)); const HloInstruction* instruction = buffer->instruction(); + const int64 buffer_size = assignment->buffer_size_(*buffer); + if (instruction->opcode() == HloOpcode::kConstant) { - // No BufferAllocations for constants. - // TODO(b/32248867): For consistency, constants should get allocations. - VLOG(3) << "Skipping constant: " << *buffer; + if (allocate_buffers_for_constants_) { + BufferAllocation* allocation = + assignment->NewAllocation(*buffer, buffer_size); + allocation->set_constant(true); + VLOG(3) << "New allocation #" << allocation->index() << " for constant " + << *buffer; + } continue; } - const int64 buffer_size = assignment->buffer_size_(*buffer); - const bool is_entry_parameter = instruction->opcode() == HloOpcode::kParameter && computation == computation->parent()->entry_computation(); @@ -922,9 +936,7 @@ Status BufferAssigner::AssignBuffersForComputation( // computations do not need special allocations because they live inside // callers. BufferAllocation* allocation = - assignment->NewAllocation(*buffer, buffer_size, - /*is_thread_local=*/false, - /*is_reusable=*/false); + assignment->NewAllocation(*buffer, buffer_size); allocation->set_entry_computation_parameter( instruction->parameter_number(), buffer->index()); VLOG(3) << "New allocation #" << allocation->index() @@ -933,20 +945,18 @@ Status BufferAssigner::AssignBuffersForComputation( } if (is_thread_local) { - // We do not reuse thread-local buffers for now, because they are - // dynamically allocated and their lifetimes are hard to compute. - BufferAllocation* allocation = assignment->NewAllocation( - *buffer, buffer_size, is_thread_local, /*is_reusable=*/false); + BufferAllocation* allocation = + assignment->NewAllocation(*buffer, buffer_size); + allocation->set_is_thread_local(true); VLOG(3) << "New allocation #" << allocation->index() << " for thread-local: " << *buffer; continue; } if (ShapeUtil::IsTuple(buffer->shape())) { - // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend - // assumes longer buffer liveness than indicated by the analysis. - BufferAllocation* allocation = assignment->NewAllocation( - *buffer, buffer_size, is_thread_local, /*is_reusable=*/false); + BufferAllocation* allocation = + assignment->NewAllocation(*buffer, buffer_size); + allocation->set_is_tuple(true); VLOG(3) << "New allocation #" << allocation->index() << " for tuple-shaped buffer: " << *buffer; continue; @@ -1029,8 +1039,8 @@ Status BufferAssigner::AssignBuffersForComputation( } if (!assignment->HasAllocation(*buffer)) { - BufferAllocation* allocation = assignment->NewAllocation( - *buffer, buffer_size, is_thread_local, /*is_reusable=*/true); + BufferAllocation* allocation = + assignment->NewAllocation(*buffer, buffer_size); allocation_indices.push_back(allocation->index()); VLOG(3) << "New allocation #" << allocation->index() << " for: " << *buffer; @@ -1084,6 +1094,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); HeapSimulator::Options options; + options.alloc_constants = allocate_buffers_for_constants_; BufferValueFlatSet buffer_value_set = ToBufferValueFlatSet(single_colored_set.second); options.buffers_to_assign = &buffer_value_set; @@ -1226,8 +1237,8 @@ void BufferAssigner::AssignBuffersFromHeapSimulator( result.fragmentation_size; } - BufferAllocation* allocation = assignment->NewEmptyAllocation( - result.heap_size, /*is_thread_local=*/false, /*is_reusable=*/true, color); + BufferAllocation* allocation = + assignment->NewEmptyAllocation(result.heap_size, color); for (const auto& buffer_chunk : result.chunk_map) { // TODO(lauj) Remove this down_cast after downstream users of // BufferAllocation::assigned_buffers() are updated to use BufferValue. @@ -1331,11 +1342,25 @@ BufferAssigner::MergeColocatedBufferSets( auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness, &buffer_size, &is_entry_parameter](int64 i, int64 j) { - // Do not merge if one of the sets includes live outs or entry parameters. + // Do not merge if one of the sets includes live outs, entry parameters or + // constants. + // + // Buffer liveness does not report the correct live range for entry + // parameter and live out buffers so we have to special case them here. On + // backends that support constant buffer allocations, constant buffers are + // assigned globals in readonly storage so we can't merge colocated buffer + // sets containing constants with colocated buffer sets containing writing + // instructions or other constants. + // + // Moreover (on the CPU/GPU backends) the entry parameter buffers belong to + // the caller of the executable so we can't write to entry parameters + // either, and the argument for not merging constants also applies to entry + // parameters. for (int64 key : {i, j}) { for (auto& buffer : colocated_buffer_sets[key]) { if (buffer_liveness.MaybeLiveOut(*buffer) || - is_entry_parameter(*buffer)) { + is_entry_parameter(*buffer) || + buffer->instruction()->opcode() == HloOpcode::kConstant) { return true; } } @@ -1417,9 +1442,9 @@ void BufferAssigner::BuildColocatedBufferSets( const HloInstruction* while_hlo = instruction; ShapeUtil::ForEachSubshape( while_hlo->shape(), - [this, while_hlo, &points_to_analysis, &buffer_liveness, - buffer_size, computation, colocated_buffer_sets]( - const Shape& /*subshape*/, const ShapeIndex& index) { + [this, while_hlo, &points_to_analysis, buffer_size, + colocated_buffer_sets](const Shape& /*subshape*/, + const ShapeIndex& index) { std::vector colocated_set; // Add while.init. AddBufferToColocatedSet(while_hlo->operand(0), index, @@ -1443,8 +1468,23 @@ void BufferAssigner::BuildColocatedBufferSets( }); } else if (opcode == HloOpcode::kCall) { const HloInstruction* call_hlo = instruction; - const HloInstruction* root_hlo = - call_hlo->to_apply()->root_instruction(); + const HloComputation* callee = call_hlo->to_apply(); + const HloInstruction* root_hlo = callee->root_instruction(); + for (int64 i = 0; i < call_hlo->operand_count(); i++) { + const HloInstruction* call_param = callee->parameter_instruction(i); + const HloInstruction* call_operand = call_hlo->operand(i); + ShapeUtil::ForEachSubshape( + call_operand->shape(), + [&](const Shape& /*subshape*/, const ShapeIndex& index) { + std::vector colocated_set; + AddBufferToColocatedSet(call_param, index, points_to_analysis, + &colocated_set); + AddBufferToColocatedSet(call_operand, index, points_to_analysis, + &colocated_set); + AddSetToColocatedBufferSets(colocated_set, + colocated_buffer_sets); + }); + } ShapeUtil::ForEachSubshape( call_hlo->shape(), [this, call_hlo, root_hlo, &points_to_analysis, @@ -1550,6 +1590,7 @@ void BufferAssigner::AssignColocatedBufferSets( // param in 'colocated_buffer_set'. int64 entry_parameter_number = -1; const ShapeIndex* entry_parameter_shape_idx = nullptr; + bool is_constant = false; for (const LogicalBuffer* buffer : colocated_buffer_set) { const HloInstruction* instruction = buffer->instruction(); const HloComputation* computation = instruction->parent(); @@ -1557,10 +1598,14 @@ void BufferAssigner::AssignColocatedBufferSets( computation == computation->parent()->entry_computation()) { entry_parameter_number = instruction->parameter_number(); entry_parameter_shape_idx = &buffer->index(); - break; + } else if (instruction->opcode() == HloOpcode::kConstant) { + is_constant = true; } } + CHECK(!is_constant || entry_parameter_number == -1) + << "Copy insertion should have inserted copies to prevent this."; + for (const LogicalBuffer* buffer : colocated_buffer_set) { const int64 buffer_size = assignment->buffer_size_(*buffer); if (allocation == nullptr) { @@ -1568,18 +1613,14 @@ void BufferAssigner::AssignColocatedBufferSets( // allocations for each colocated buffer set. When liveness has // module-level scope, we can allow buffers to be shared across // computations (in some cases). - allocation = assignment->NewAllocation(*buffer, buffer_size, - /*is_thread_local=*/false, - /*is_reusable=*/true); + allocation = assignment->NewAllocation(*buffer, buffer_size); if (entry_parameter_number >= 0) { - // This colocated buffer set contains an entry parameter and other - // logical buffers which use the parameter as read-only in a while - // body computation (which updates in place). - // Set 'entry_computation_parameter' to indicate that it contains - // an entry parameter, and to prevent reuse in MaybeAssignBuffer. allocation->set_entry_computation_parameter( entry_parameter_number, *entry_parameter_shape_idx); } + if (is_constant) { + allocation->set_constant(true); + } colocated_allocations->insert(allocation->index()); } else { CHECK_EQ(buffer_size, allocation->size()) @@ -1637,7 +1678,7 @@ StatusOr> BufferAssigner::CreateAssignment( buffers_to_assign_sequentially; for (auto* computation : global_computations) { TF_RETURN_IF_ERROR(AssignBuffersForComputation( - computation, module->config().debug_options(), + computation, /*is_thread_local=*/false, colocated_buffers, colocated_allocations, &buffers_to_assign_sequentially, assignment.get())); } @@ -1658,7 +1699,7 @@ StatusOr> BufferAssigner::CreateAssignment( continue; } TF_RETURN_IF_ERROR(AssignBuffersForComputation( - computation, module->config().debug_options(), + computation, /*is_thread_local=*/true, colocated_buffers, colocated_allocations, /*buffers_to_assign_sequentially=*/nullptr, assignment.get())); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index ad0b0bf7c25d7194a06801e4ef1c9ee961f6b915..94495290c131e22392079dc2d0237d990b646d3e 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -58,13 +57,8 @@ class BufferAllocation { // contiguously and can be used as array indexes. using Index = int64; - BufferAllocation(Index index, int64 size, bool is_thread_local, - bool is_reusable, LogicalBuffer::Color color) - : index_(index), - size_(size), - is_thread_local_(is_thread_local), - is_reusable_(is_reusable), - color_(color) {} + BufferAllocation(Index index, int64 size, LogicalBuffer::Color color) + : index_(index), size_(size), color_(color) {} ~BufferAllocation() {} // Returns the index of this allocation. @@ -74,9 +68,28 @@ class BufferAllocation { // inside of a map or reduce computation. Such allocations need to be thread // local. bool is_thread_local() const { return is_thread_local_; } + void set_is_thread_local(bool is_thread_local) { + is_thread_local_ = is_thread_local; + } // Whether this allocation can be used by more than one logical buffer. - bool is_reusable() const { return is_reusable_; } + bool is_reusable() const { + // We do not reuse thread-local buffers for now, because they are + // dynamically allocated and their lifetimes are hard to compute. + // + // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend + // assumes longer buffer liveness than indicated by the analysis. + return !is_thread_local() && !is_tuple(); + } + + // Whether this allocation is readonly i.e. backed by memory we cannot write + // to. + bool is_readonly() const { + return is_entry_computation_parameter() || is_constant(); + } + + bool is_tuple() const { return is_tuple_; } + void set_is_tuple(bool is_tuple) { is_tuple_ = is_tuple; } // Whether this allocation holds a LogicalBuffer from a parameter of the entry // computation. These buffers have lifetimes which may be longer than the @@ -84,6 +97,13 @@ class BufferAllocation { bool is_entry_computation_parameter() const { return is_entry_computation_parameter_; } + + // Whether this allocation holds a constant. On the CPU and GPU backends + // constant allocations are not allocated dynamically, instead we resolve + // references to these buffer allocations to a global in the readonly section + // of the binary. + bool is_constant() const { return is_constant_; } + // If this allocation holds a Buffer from a parameter of the entry // computation, this methods returns the parameter number. CHECKs otherwise. int64 parameter_number() const { @@ -189,7 +209,9 @@ class BufferAllocation { // of the computation. !maybe_live_out() && // Thread-local buffers are allocated using `alloca`s. - !is_thread_local(); + !is_thread_local() && + // Constant buffers are allocated as global values. + !is_constant(); } // Add a heap trace which was used to assign slices to logical buffers in this @@ -245,6 +267,8 @@ class BufferAllocation { parameter_number_ = parameter_number; param_shape_index_ = std::move(param_shape_index); } + + void set_constant(bool is_constant) { is_constant_ = is_constant; } void set_maybe_live_out(bool value) { maybe_live_out_ = value; } void set_index(Index index) { index_ = index; } void set_size(int64 size) { size_ = size; } @@ -256,10 +280,10 @@ class BufferAllocation { int64 size_; // Whether this buffer needs to be thread-local. - bool is_thread_local_; + bool is_thread_local_ = false; - // Whether this buffer is usable by more than one logical buffer. - bool is_reusable_; + // Whether this buffer holds a tuple. + bool is_tuple_ = false; // Color of the allocation. LogicalBuffer::Color color_; @@ -283,6 +307,9 @@ class BufferAllocation { // might not actually escape. bool maybe_live_out_ = false; + // See comment on the is_constant() accessor. + bool is_constant_ = false; + // Mapping from the set of buffers assigned to this allocation to their // logical offsets and sizes. tensorflow::gtl::FlatMap assigned_buffers_; @@ -398,6 +425,8 @@ class BufferAssignment { struct Stats { int64 parameter_allocation_count = 0; int64 parameter_allocation_bytes = 0; + int64 constant_allocation_count = 0; + int64 constant_allocation_bytes = 0; int64 maybe_live_out_allocation_count = 0; int64 maybe_live_out_allocation_bytes = 0; int64 preallocated_temp_allocation_count = 0; @@ -426,14 +455,11 @@ class BufferAssignment { // Creates and returns a new BufferAllocation, with no assigned // LogicalBuffers. Ownership is maintained internally. - BufferAllocation* NewEmptyAllocation(int64 size, bool is_thread_local, - bool is_reusable, - LogicalBuffer::Color color); + BufferAllocation* NewEmptyAllocation(int64 size, LogicalBuffer::Color color); // Helper that calls NewEmptyAllocation and AddAssignment in one call, // creating an allocation containing a single LogicalBuffer. - BufferAllocation* NewAllocation(const LogicalBuffer& buffer, int64 size, - bool is_thread_local, bool is_reusable); + BufferAllocation* NewAllocation(const LogicalBuffer& buffer, int64 size); // Adds a LogicalBuffer to the set assigned to the given allocation. void AddAssignment(BufferAllocation* allocation, const LogicalBuffer& buffer, @@ -493,12 +519,15 @@ class BufferAssigner { LogicalBuffer::SizeFunction buffer_size, LogicalBuffer::AlignmentFunction color_alignment, bool allow_input_output_aliasing = false, + bool allocate_buffers_for_constants = false, BufferLiveness::Colorer colorer = BufferLiveness::DefaultColorer()); private: BufferAssigner(bool allow_input_output_aliasing, + bool allocate_buffers_for_constants, BufferLiveness::Colorer colorer) : allow_input_output_aliasing_(allow_input_output_aliasing), + allocate_buffers_for_constants_(allocate_buffers_for_constants), colorer_(colorer) {} virtual ~BufferAssigner() = default; @@ -513,8 +542,7 @@ class BufferAssigner { // true, then all assigned buffers have the is_thread_local flag set to // true. Status AssignBuffersForComputation( - const HloComputation* computation, const DebugOptions& debug_options, - bool is_thread_local, + const HloComputation* computation, bool is_thread_local, const tensorflow::gtl::FlatSet& colocated_buffers, const tensorflow::gtl::FlatSet& colocated_allocations, @@ -595,6 +623,9 @@ class BufferAssigner { // buffers can be shared if their sizes match. bool allow_input_output_aliasing_; + // If true, allocate buffers for constant instructions. + bool allocate_buffers_for_constants_; + // Functor used to assign colors to newly allocated logical buffers. BufferLiveness::Colorer colorer_; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index efa4696130ffeff669b0d674438a45c5a9d48ef2..eccb146a0d7d628870be179a540d9750df3fe41c 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/call_graph.h" @@ -89,7 +89,20 @@ class BufferAssignmentTest : public HloTestBase { return BufferAssigner::Run( module, xla::MakeUnique(module), backend().compiler()->BufferSizeBytesFunction(), - [alignment](LogicalBuffer::Color) { return alignment; }) + [alignment](LogicalBuffer::Color) { return alignment; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) + .ConsumeValueOrDie(); + } + + std::unique_ptr RunBufferAssignmentNoBuffersForConstants( + HloModule* module, int64 alignment = 1) { + return BufferAssigner::Run( + module, xla::MakeUnique(module), + backend().compiler()->BufferSizeBytesFunction(), + [alignment](LogicalBuffer::Color) { return alignment; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/false) .ConsumeValueOrDie(); } @@ -98,8 +111,9 @@ class BufferAssignmentTest : public HloTestBase { return BufferAssigner::Run( module, xla::MakeUnique(module), backend().compiler()->BufferSizeBytesFunction(), - [alignment](LogicalBuffer::Color) { return alignment; }, false, - std::move(colorer)) + [alignment](LogicalBuffer::Color) { return alignment; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true, std::move(colorer)) .ConsumeValueOrDie(); } @@ -115,7 +129,9 @@ class BufferAssignmentTest : public HloTestBase { module, xla::MakeUnique(module, module_sequence), backend().compiler()->BufferSizeBytesFunction(), - [alignment](LogicalBuffer::Color) { return alignment; }) + [alignment](LogicalBuffer::Color) { return alignment; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); } @@ -125,7 +141,7 @@ class BufferAssignmentTest : public HloTestBase { auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); auto value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value)); return builder.Build(); @@ -142,7 +158,7 @@ class BufferAssignmentTest : public HloTestBase { const string& name) { auto builder = HloComputation::Builder(name); auto const4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); auto index = builder.AddInstruction( @@ -167,9 +183,9 @@ class BufferAssignmentTest : public HloTestBase { const string& name) { auto builder = HloComputation::Builder(name); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto constv = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); auto indexc = builder.AddInstruction( @@ -290,13 +306,19 @@ static bool BuffersDistinct(const std::vector& a, TEST_F(BufferAssignmentTest, ScalarConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); - // Check that the constant does not have a buffer assigned. - EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); + { + auto buffers = RunBufferAssignment(module.get()); + EXPECT_TRUE(buffers->HasTopLevelAllocation(const0)); + } + + { + auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get()); + EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); + } } TEST_F(BufferAssignmentTest, BufferForConst) { @@ -304,20 +326,26 @@ TEST_F(BufferAssignmentTest, BufferForConst) { // no buffers assigned, and their consumer has a buffer. auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({4.1f, 4.2f, 4.3f, 4.4f}))); + LiteralUtil::CreateR1({4.1f, 4.2f, 4.3f, 4.4f}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); - // The two constant nodes have no buffers assigned. - EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); - EXPECT_FALSE(buffers->HasTopLevelAllocation(const1)); - // The add node has an output buffer. - GetAssignedOutputAllocation(*buffers, add); + { + auto buffers = RunBufferAssignment(module.get()); + EXPECT_TRUE(buffers->HasTopLevelAllocation(const0)); + EXPECT_TRUE(buffers->HasTopLevelAllocation(const1)); + GetAssignedOutputAllocation(*buffers, add); + } + { + auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get()); + EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); + EXPECT_FALSE(buffers->HasTopLevelAllocation(const1)); + GetAssignedOutputAllocation(*buffers, add); + } } TEST_F(BufferAssignmentTest, HasAllocationAt) { @@ -327,7 +355,7 @@ TEST_F(BufferAssignmentTest, HasAllocationAt) { auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, f32vec100_, "param0")); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); auto tuple = builder.AddInstruction( @@ -352,7 +380,7 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) { // This computation copies a constant to output. auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto copy = builder.AddInstruction( HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0)); auto module = CreateNewModule(); @@ -660,7 +688,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { auto exp2 = builder.AddInstruction( HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1)); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( /*shape=*/f32vec10_, /*operand=*/exp2, @@ -708,9 +736,9 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { // Creates the main kernel and verifies instruction counts. auto builder = HloComputation::Builder(TestName()); auto const3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto const4 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({const3, const4})); auto while_op = builder.AddInstruction(HloInstruction::CreateWhile( @@ -773,11 +801,11 @@ TEST_F(BufferAssignmentTest, ExampleConditional) { auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(56.4f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(56.4f))); auto const2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(12.4f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(12.4f))); auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( r0f32_, pred, const1, true_computation, const2, false_computation)); module->AddEntryComputation(builder.Build()); @@ -1094,7 +1122,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { // Allocations for the call computation should not be thread-local. auto& call_param_alloc = GetTopLevelAllocation(*assignment, call_param); - EXPECT_FALSE(call_param_alloc.is_entry_computation_parameter()); + EXPECT_TRUE(call_param_alloc.is_entry_computation_parameter()); EXPECT_FALSE(call_param_alloc.maybe_live_out()); EXPECT_FALSE(call_param_alloc.is_thread_local()); @@ -1196,12 +1224,13 @@ TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) { // TODO(b/32248867): Enable when buffer assignment gives allocations to // constants. -TEST_F(BufferAssignmentTest, DISABLED_TupleConstantAsOutput) { +TEST_F(BufferAssignmentTest, TupleConstantAsOutput) { // Test that a tuple constant which is forwarded to the computation output // is properly handled. auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction(HloInstruction::CreateConstant(Literal::MakeTuple( - {Literal::CreateR0(0).get(), Literal::CreateR0(1).get()}))); + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), + LiteralUtil::CreateR0(1).get()}))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1252,16 +1281,18 @@ TEST_F(BufferAssignmentTest, TupleCallAsOutput) { auto assignment = RunBufferAssignment(module.get()); - EXPECT_EQ(3, assignment->Allocations().size()); + EXPECT_EQ(2, assignment->Allocations().size()); // Buffers for call are colocated with the sub-computation. EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{}), GetAllocation(*assignment, sub_tuple, /*index=*/{})); EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{0}), GetAllocation(*assignment, sub_param, /*index=*/{})); - // The parameter isn't aliased with anything. + + // The parameter isn't aliased with the result tuple, but it is aliased with + // the call operand. EXPECT_NE(GetTopLevelAllocation(*assignment, param), GetTopLevelAllocation(*assignment, sub_tuple)); - EXPECT_NE(GetTopLevelAllocation(*assignment, param), + EXPECT_EQ(GetTopLevelAllocation(*assignment, param), GetTopLevelAllocation(*assignment, sub_param)); } @@ -1325,13 +1356,15 @@ TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) { GetAllocation(*assignment, c_call, /*index=*/{0})); EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{0}), GetAllocation(*assignment, d_param, /*index=*/{0})); - // The parameters aren't aliased with anything. + EXPECT_TRUE(BuffersDistinct({a_param}, {b_param}, *assignment)); EXPECT_TRUE(BuffersDistinct({a_param}, {c_param}, *assignment)); EXPECT_TRUE(BuffersDistinct({a_param}, {d_param}, *assignment)); - EXPECT_TRUE(BuffersDistinct({b_param}, {c_param}, *assignment)); - EXPECT_TRUE(BuffersDistinct({b_param}, {d_param}, *assignment)); - EXPECT_TRUE(BuffersDistinct({c_param}, {d_param}, *assignment)); + + EXPECT_EQ(GetAllocation(*assignment, b_param, /*index=*/{0}), + GetAllocation(*assignment, c_param, /*index=*/{0})); + EXPECT_EQ(GetAllocation(*assignment, c_param, /*index=*/{0}), + GetAllocation(*assignment, d_param, /*index=*/{0})); } TEST_F(BufferAssignmentTest, BitcastAsOutput) { @@ -1365,8 +1398,9 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) { HloInstruction::CreateParameter(1, tuple_shape, "param1")); auto pred_param = builder.AddInstruction(HloInstruction::CreateParameter( 2, ShapeUtil::MakeShape(PRED, {}), "param1")); - auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred_param, tuple_param0, tuple_param1)); + auto select = builder.AddInstruction( + HloInstruction::CreateTernary(tuple_shape, HloOpcode::kTupleSelect, + pred_param, tuple_param0, tuple_param1)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1583,7 +1617,7 @@ TEST_F(BufferAssignmentTest, PeakBuffersWhile) { auto b = HloComputation::Builder(TestName() + ".cond"); b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); condition = module->AddEmbeddedComputation(b.Build()); } HloComputation* body; @@ -1638,6 +1672,66 @@ TEST_F(BufferAssignmentTest, PeakBuffersWhile) { nonbcast_buffer->instruction() == condition->parameter_instruction(0)); } +TEST_F(BufferAssignmentTest, ConstantBuffersAreNotReused) { + const char* hlo_text = R"( +HloModule Module + +True { + ROOT x.0.1 = f32[] parameter(0) +} + +False { + x.0.0 = f32[] parameter(0) + ROOT copy.1 = f32[] copy(x.0.0) +} + +ENTRY main { + pred.1.0 = pred[] parameter(0) + constant.1.1 = f32[] constant(56) + copy.2 = f32[] copy(constant.1.1) + constant.1.2 = f32[] constant(12) + ROOT conditional.1.3 = f32[] conditional(pred.1.0, copy.2, constant.1.2), + true_computation=True, false_computation=False +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + HloInstruction* constant_1 = + module->entry_computation()->GetInstructionWithName("constant.1.1"); + HloInstruction* constant_2 = + module->entry_computation()->GetInstructionWithName("constant.1.2"); + + auto buffers = RunBufferAssignment(module.get()); + + { + const BufferAllocation& allocation_for_const_1 = + GetTopLevelAllocation(*buffers, constant_1); + EXPECT_TRUE(allocation_for_const_1.is_constant()); + for (const auto& buffer_offset_pair : + allocation_for_const_1.assigned_buffers()) { + EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(), + HloOpcode::kCopy); + EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(), + HloOpcode::kConditional); + } + } + + { + const BufferAllocation& allocation_for_const_2 = + GetTopLevelAllocation(*buffers, constant_2); + EXPECT_TRUE(allocation_for_const_2.is_constant()); + for (const auto& buffer_offset_pair : + allocation_for_const_2.assigned_buffers()) { + EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(), + HloOpcode::kCopy); + EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(), + HloOpcode::kConditional); + } + } +} + class WhileBufferAssignmentTest : public HloTestBase { protected: std::unique_ptr BuildWhileConditionComputation( @@ -1646,9 +1740,9 @@ class WhileBufferAssignmentTest : public HloTestBase { builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto ten = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(10))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten)); return builder.Build(); @@ -1677,7 +1771,9 @@ class WhileBufferAssignmentTest : public HloTestBase { return BufferAssigner::Run( module, xla::MakeUnique(module, sequence), ByteSizeOf, - [alignment](LogicalBuffer::Color) { return alignment; }) + [alignment](LogicalBuffer::Color) { return alignment; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); } @@ -1707,7 +1803,7 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { HloInstruction::CreateParameter(2, data_shape_, "weights1")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); auto output1 = builder.AddInstruction( @@ -1827,6 +1923,74 @@ ENTRY %test_module { EXPECT_NE(slice_param, slice_while1); } +TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithConstant) { + const Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + + const char* module_str = R"( +HloModule test_module + +%cond.v0 { + %param = s32[] parameter(0) + ROOT %constant = pred[] constant(true) +} + +%cond.v1 { + %param.0 = s32[] parameter(0) + ROOT %constant.0 = pred[] constant(true) +} + +%body.v0 { + ROOT %param.1 = s32[] parameter(0) +} + +%body.v1 { + %param.2 = s32[] parameter(0) + ROOT add = s32[] add(%param.2, %param.2) +} + +ENTRY %test_module { + %constant.42 = s32[] constant(42) + %while.0 = s32[] while(%constant.42), condition=%cond.v0, body=%body.v0 + %mul = s32[] multiply(%while.0, %while.0) + %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1 + ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + // Run CopyInsertion and check if the graph constructed above doesn't need + // any copies inserted for BufferAssignment to run. + int64 instruction_count = module->instruction_count(); + CopyInsertion copy_insertion; + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + ASSERT_EQ(instruction_count, module->instruction_count()); + + // Get the instructions in the module. + const HloInstruction* bcast = module->entry_computation()->root_instruction(); + const HloInstruction* constant = + module->entry_computation()->GetInstructionWithName("constant.42"); + ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); + const HloInstruction* while1 = bcast->operand(0); + ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); + const HloInstruction* while0 = while1->operand(0)->operand(0); + ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); + + // Run buffer assignment. + auto assignment = RunBufferAssignment(module.get()); + TF_ASSERT_OK_AND_ASSIGN(auto slice_constant, + assignment->GetUniqueSlice(constant, {})); + TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, + assignment->GetUniqueSlice(while0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto slice_while1, + assignment->GetUniqueSlice(while1, {})); + + // The constant slice is part of the while0's colocation set (init value), but + // not merged into the while1's colocation set. + EXPECT_EQ(slice_constant, slice_while0); + EXPECT_NE(slice_constant, slice_while1); +} + // Tests that the colocated buffers for while instructions are properly assigned // during buffer assignment such that the result tuple elements are not assigned // to the same buffer. @@ -1850,7 +2014,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { auto build_cond = [&]() { auto builder = HloComputation::Builder("cond"); auto const4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x")); builder.AddInstruction(HloInstruction::CreateBinary( @@ -1862,7 +2026,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { auto build_body = [&]() { auto builder = HloComputation::Builder("body"); auto const9 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(9))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(9))); auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x")); builder.AddInstruction( @@ -1874,11 +2038,15 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); - auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, "")); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); + auto infeed = + builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, token, "")); + auto infeed_data = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r0s32, infeed, 0)); auto cond0 = module->AddEmbeddedComputation(build_cond()); auto body0 = module->AddEmbeddedComputation(build_body()); auto while0 = builder.AddInstruction( - HloInstruction::CreateWhile(r0s32, cond0, body0, infeed)); + HloInstruction::CreateWhile(r0s32, cond0, body0, infeed_data)); auto cond1 = module->AddEmbeddedComputation(build_cond()); auto body1 = module->AddEmbeddedComputation(build_body()); @@ -1886,7 +2054,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { HloInstruction::CreateWhile(r0s32, cond1, body1, while0)); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero)); auto cond2 = module->AddEmbeddedComputation(build_cond()); @@ -1909,15 +2077,17 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // computation, since the issue this test stresses depends on the order the // nodes are traversed during BufferAssignment. SequentialHloOrdering::HloModuleSequence sequence; - sequence[module->entry_computation()] = {infeed, while0, while1, zero, - add, while2, tuple}; + sequence[module->entry_computation()] = { + token, infeed, infeed_data, while0, while1, zero, add, while2, tuple}; TF_ASSERT_OK_AND_ASSIGN( auto assignment, BufferAssigner::Run( module.get(), xla::MakeUnique(module.get(), sequence), backend().compiler()->BufferSizeBytesFunction(), - [](LogicalBuffer::Color) { return 1; })); + [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // The result tuple elements must be assigned with different buffers. TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0})); @@ -1948,7 +2118,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { HloInstruction::CreateParameter(1, data_shape_, "weights0")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); @@ -1992,16 +2162,16 @@ TEST_F(BufferAssignmentTest, TwoCalls) { auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param")); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1)); sub_computation = module->AddEmbeddedComputation(builder.Build(add)); } auto builder = HloComputation::Builder(TestName()); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto call1 = builder.AddInstruction( HloInstruction::CreateCall(r0f32, {constant2}, sub_computation)); auto call2 = builder.AddInstruction( @@ -2025,6 +2195,56 @@ TEST_F(BufferAssignmentTest, TwoCalls) { EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment)); } +TEST_F(BufferAssignmentTest, CallParamCoAllocation) { + const char* hlo_text = R"( +HloModule CallParamCoAllocation + +Callee { + param0 = (f32[100],(f32[200],f32[300])) parameter(0) + param1 = s32[20] parameter(1) + ROOT constant = f32[] constant(1) +} + +ENTRY Main { + entry_param0 = f32[100] parameter(0) + entry_param1 = s32[20] parameter(1) + custom_call = (f32[200],f32[300]) custom-call(), custom_call_target="call-target" + call_op0 = (f32[100],(f32[200],f32[300])) tuple(entry_param0, custom_call) + ROOT call_result = f32[] call(call_op0, entry_param1), to_apply=Callee +} +)"; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloRunner::CreateModuleFromString( + hlo_text, legacy_flags::GetDebugOptionsFromFlags())); + + auto buffers = RunBufferAssignment(module.get()); + + HloComputation* main = module->entry_computation(); + HloComputation* callee = module->GetComputationWithName("Callee"); + EXPECT_NE(callee, nullptr); + + HloInstruction* param0 = callee->parameter_instruction(0); + HloInstruction* param1 = callee->parameter_instruction(1); + + HloInstruction* entry_param0 = main->parameter_instruction(0); + HloInstruction* entry_param1 = main->parameter_instruction(1); + HloInstruction* custom_call = main->GetInstructionWithName("custom_call"); + + EXPECT_EQ(GetAllocation(*buffers, entry_param0, {}), + GetAllocation(*buffers, param0, {0})); + EXPECT_EQ(GetAllocation(*buffers, entry_param1, {}), + GetAllocation(*buffers, param1, {})); + + EXPECT_EQ(GetAllocation(*buffers, custom_call, {}), + GetAllocation(*buffers, param0, {1})); + EXPECT_EQ(GetAllocation(*buffers, custom_call, {0}), + GetAllocation(*buffers, param0, {1, 0})); + EXPECT_EQ(GetAllocation(*buffers, custom_call, {1}), + GetAllocation(*buffers, param0, {1, 1})); +} + static bool IsPostOrderTraversal( const std::vector& sequence) { tensorflow::gtl::FlatSet seen_so_far; @@ -2053,9 +2273,9 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto builder = HloComputation::Builder(TestName()); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto input0 = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape_, "input0")); @@ -2121,7 +2341,9 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { BufferAssigner::Run( module.get(), xla::MakeUnique(module.get(), sequence), - ByteSizeOf, [](LogicalBuffer::Color) { return 1; }) + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); @@ -2137,7 +2359,7 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { HloInstruction::CreateParameter(1, data_shape_, "weights0")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); auto output1 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index f623aef67a4f98b447a9a15634a78deb60cfe6f1..4a927b57674345f8b3493c098778182a299c5902 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -327,11 +327,12 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, param, param)); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); auto recv = builder.AddInstruction( - HloInstruction::CreateRecv(vec_, /*channel_id=*/0)); + HloInstruction::CreateRecv(vec_, token, /*channel_id=*/0)); auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); auto send = builder.AddInstruction( - HloInstruction::CreateSend(recv_done, /*channel_id=*/1)); + HloInstruction::CreateSend(recv_done, token, /*channel_id=*/1)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); auto module = CreateNewModule(); @@ -438,11 +439,13 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { // computation. The buffer containing {0, 1} is copied by GetTupleElement, and // the buffers containing {3} and 3 are dead. auto builder = HloComputation::Builder(TestName()); - auto inner_tuple0 = Literal::MakeTuple( - {Literal::CreateR0(0).get(), Literal::CreateR0(1).get()}); - auto inner_tuple1 = Literal::MakeTuple({Literal::CreateR0(3).get()}); + auto inner_tuple0 = + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), + LiteralUtil::CreateR0(1).get()}); + auto inner_tuple1 = + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(3).get()}); auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); + LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); builder.AddInstruction(HloInstruction::CreateGetTupleElement( inner_tuple0->shape(), tuple_constant, 0)); @@ -490,7 +493,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple_element0_shape, tuple_param0, 0)); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0)); @@ -502,7 +505,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple_element1_shape, tuple_param0, 1)); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( tuple_element1_shape, HloOpcode::kAdd, tuple_element1, const1)); @@ -554,7 +557,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple_element0_shape, tuple_param0, 0)); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0)); @@ -626,7 +629,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1)); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); HloInstruction* slice = nullptr; if (update_uses_tuple_element1) { // Create a slice instruction as an additional user of 'gte1'. @@ -637,7 +640,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -756,7 +759,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1)); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); if (tuple_element1_has_two_uses) { // Add 'gte0' and 'gte1' to create another user of 'gte1'. @@ -765,7 +768,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index a23427f00ccd88bb0fe1d973a667f80ca54b14cd..985ff30e80ac9c41c024e4c4d2d0ebb3cff75167 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -61,6 +61,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: + case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: case HloOpcode::kFusion: return CallContext::kParallel; diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index 1ea7d538cd515c3098b6a1f03c6146d288330406..cc80b7484313329104eec1ce71a150b47d8330c9 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/call_graph.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -82,7 +82,7 @@ class CallGraphTest : public HloTestBase { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); return builder.Build(); @@ -247,11 +247,11 @@ TEST_F(CallGraphTest, ComputationWithConditional) { HloComputation::Builder builder(TestName()); HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloInstruction* const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(56.4f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(56.4f))); HloInstruction* const2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(12.6f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(12.6f))); HloInstruction* conditional = builder.AddInstruction(HloInstruction::CreateConditional( kScalarShape, pred, const1, true_computation, const2, diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc index 482ccc5b67109258f544e5657ecfa0e8f62192c0..256d05a73e0bf61d959d21795c106286b52d0b19 100644 --- a/tensorflow/compiler/xla/service/call_inliner.cc +++ b/tensorflow/compiler/xla/service/call_inliner.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { @@ -151,6 +152,14 @@ StatusOr CallInliner::Run(HloModule* module) { } return Status::OK(); })); + if (did_mutate) { + // Run DCE to remove called computations which are now becoming unused. + // This can result then in problems if within the called computation, there + // were send/recv instructions, which the module group verifier will flag as + // error findingthe same channel ID used for multiple send/recv + // instructions. + TF_RETURN_IF_ERROR(HloDCE().Run(module).status()); + } return did_mutate; } diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index 738d00881dd057fc13c115006c15e8f5b6d14a1d..ff968bca297077c7cf869ff8d2becb8bf739dce3 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -48,9 +48,9 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { // the "one" value. HloComputation::Builder inner(TestName() + ".inner"); HloInstruction* zero = inner.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(24.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(24.0f))); HloInstruction* one = inner.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); TF_ASSERT_OK(zero->AddControlDependencyTo(one)); auto module = CreateNewModule(); HloComputation* inner_computation = @@ -87,7 +87,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { // little trickier. HloComputation::Builder just_false(TestName() + ".false"); just_false.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* false_computation = module->AddEmbeddedComputation(just_false.Build()); @@ -99,7 +99,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { HloComputation::Builder outer(TestName() + ".outer"); HloInstruction* init_value = outer.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); outer.AddInstruction( HloInstruction::CreateWhile(pred, call_false, call_false, init_value)); @@ -123,9 +123,9 @@ TEST_F(CallInlinerTest, InlineWithoutRunningPass) { HloComputation::Builder just_false(TestName() + ".false"); auto* true_constant = just_false.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({true}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({true}))); auto* false_constant = just_false.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); TF_ASSERT_OK(false_constant->AddControlDependencyTo(true_constant)); HloComputation* false_computation = module->AddEmbeddedComputation(just_false.Build()); @@ -147,15 +147,17 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { HloComputation::Builder outfeeder(TestName() + ".outfeeder"); auto value = outfeeder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + auto token = outfeeder.AddInstruction(HloInstruction::CreateToken()); outfeeder.AddInstruction( - HloInstruction::CreateOutfeed(f32, value, /*outfeed_config=*/"")); + HloInstruction::CreateOutfeed(f32, value, token, /*outfeed_config=*/"")); auto outfeed_computation = module->AddEmbeddedComputation(outfeeder.Build()); HloComputation::Builder outer(TestName() + ".outer"); outer.AddInstruction(HloInstruction::CreateCall( - ShapeUtil::MakeNil(), /*operands=*/{}, outfeed_computation)); + outfeed_computation->root_instruction()->shape(), /*operands=*/{}, + outfeed_computation)); module->AddEntryComputation(outer.Build()); diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc index a5b392cbc33c12c3255f3c06e9842fc116e672e5..13008efed1494402eaff47904c2e4797334381a1 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.cc +++ b/tensorflow/compiler/xla/service/channel_tracker.cc @@ -31,16 +31,23 @@ namespace xla { ChannelTracker::ChannelTracker() : next_channel_(1) {} -ChannelHandle ChannelTracker::NewChannel() { +StatusOr ChannelTracker::NewChannel( + ChannelHandle::ChannelType type) { + if (type != ChannelHandle::DEVICE_TO_DEVICE && + type != ChannelHandle::HOST_TO_DEVICE && + type != ChannelHandle::DEVICE_TO_HOST) { + return InvalidArgument("Invalid channel type: %d", type); + } tensorflow::mutex_lock lock(channel_mutex_); // Create a new channel handle with a unique value. - const ChannelHandle new_handle = AllocateHandle(); + ChannelHandle new_handle = AllocateHandle(type); // Register a channel object associated with the handle. Channel channel; channel.has_sender = false; channel.receiver_count = 0; + channel.type = type; opaque_to_channel_[new_handle.handle()] = channel; return new_handle; @@ -56,10 +63,11 @@ Status ChannelTracker::RegisterRecv(const ChannelHandle& handle) { return RegisterRecvInternal(handle); } -ChannelHandle ChannelTracker::AllocateHandle() { +ChannelHandle ChannelTracker::AllocateHandle(ChannelHandle::ChannelType type) { int64 handle_value = next_channel_++; ChannelHandle result; result.set_handle(handle_value); + result.set_type(type); return result; } @@ -68,6 +76,13 @@ Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { return NotFound("channel handle not found: %lld", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; + if (channel.type == ChannelHandle::HOST_TO_DEVICE) { + return FailedPrecondition( + "host-to-device channels cannot be used with a Send operation; " + "channel handle: %lld", + handle.handle()); + } + if (channel.has_sender) { return FailedPrecondition( "when registering send, passed a channel handle that is already used " @@ -83,6 +98,13 @@ Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { return NotFound("channel handle not found: %lld", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; + if (channel.type == ChannelHandle::DEVICE_TO_HOST) { + return FailedPrecondition( + "device-to-host channels cannot be used with a Recv operation; " + "channel handle: %lld", + handle.handle()); + } + // TODO(b/33942691): Allow more than 1 receivers for broadcast. if (channel.receiver_count >= 1) { return FailedPrecondition( diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h index fac0afd672ff3ed083aacf778dd9c4f90a2ee870..d773558c284a7d645f2766bb88c50f7da3777e5d 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.h +++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -48,11 +48,12 @@ class ChannelTracker { struct Channel { bool has_sender; int64 receiver_count; + ChannelHandle::ChannelType type; }; // Creates a new Channel object and returns the corresponding // ChannelHandle for it. - ChannelHandle NewChannel(); + StatusOr NewChannel(ChannelHandle::ChannelType type); // Informs that the given channel handle is used for a Send operation. // Returns an error status if the handle is already used by another Send. @@ -65,7 +66,8 @@ class ChannelTracker { private: // Bumps the next_channel_ number and returns the allocated number // wrapped in a ChannelHandle. - ChannelHandle AllocateHandle() EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_); + ChannelHandle AllocateHandle(ChannelHandle::ChannelType type) + EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_); Status RegisterSendInternal(const ChannelHandle& handle) EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_); diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 0dceed853dcbae211657f00433866cfe10c51fc7..6b3b9820f09803c8a04504e6c35c22de51abf04b 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -35,6 +35,13 @@ Compiler::ComputeBackendConfigs(const HloInstruction& hlo, return {}; } +std::unique_ptr +Compiler::ComputeDefaultBackendConfig(const HloInstruction& hlo, + se::StreamExecutor* executor) const { + CHECK(executor != nullptr); + return nullptr; +} + // Define a default version where metadata is not used. StatusOr>> Compiler::CompileAheadOfTime( diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index d1144f97bb2ab29d3d18f3b3f65a38af46e68dd1..34f7fe12cac5a4dcd3822865bee903d6eabc25c0 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -48,11 +48,6 @@ namespace xla { // compuation. using ObjectFileData = std::vector; -// Contains the buffer sizes information needed to allocate buffers to execute -// an ahead-of-time computation. Entries which contain -1 designate a parameter -// which should be skipped over during allocation. -using BufferSizes = std::vector; - // Abstract superclass describing the result of an ahead-of-time compilation. class AotCompilationResult { public: @@ -179,6 +174,16 @@ class Compiler { ComputeBackendConfigs(const HloInstruction& hlo, se::StreamExecutor* executor) const; + // Returns the backend configuration that the backend chooses by default for + // the given HLO. Returns no configuration if the backend does not support + // configurations for the given HLO. + // + // The stream executor is passed in to provide information about the hardware + // that the backend configurations would be targeting. + virtual std::unique_ptr + ComputeDefaultBackendConfig(const HloInstruction& hlo, + se::StreamExecutor* executor) const; + // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. virtual StatusOr>> diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index 7c1bacff92b231661477b9931a3066fd91110445..187ce568cbb6c6666e978b8c8114262313c70ba5 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" @@ -29,9 +29,13 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; + namespace xla { Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const { @@ -71,6 +75,19 @@ DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { return std::move(assignment); } +string DeviceAssignment::ToString() const { + string output = StrCat("Computations: ", computation_count(), + " Replicas: ", replica_count(), "\n"); + for (int computation = 0; computation < computation_count(); ++computation) { + StrAppend(&output, "Computation ", computation, ": "); + for (int replica = 0; replica < replica_count(); ++replica) { + StrAppend(&output, operator()(replica, computation), " "); + } + StrAppend(&output, "\n"); + } + return output; +} + StatusOr ComputationPlacer::DeviceId(int replica, int computation, int replica_count, int computation_count) { diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h index 737d00e93ecb51a9bd544bbcbe99d93374d108fb..c899ffb9dc562426ef14c0d414469c04debeec70 100644 --- a/tensorflow/compiler/xla/service/computation_placer.h +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -55,6 +55,8 @@ class DeviceAssignment : public Array2D { // due to a StatusOr of an incomplete type (DeviceAssignment). static StatusOr> Deserialize( const DeviceAssignmentProto& proto); + + string ToString() const; }; // A generic implementation of the XLA computation placer, which assigns device diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index e9ec796121fff223474c3e81a5e973cc37f8caec..b7be3ba605a89a736b032eaab5a5085ac64fc549 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index 868348547d9f5cbdc7576c7fc0697d72c3a3e557..c43a31b167d47af3c92ed35fa52594fa5da1e4af 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -55,7 +55,7 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) { true_computation_builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {}), "param")); auto one = true_computation_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); true_computation_builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, one)); @@ -73,7 +73,7 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) { HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "param")); auto forty_two = false_computation_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); false_computation_builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, forty_two)); @@ -82,11 +82,11 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) { } auto false_instrn = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto false_param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {}), "false_param")); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); builder.AddInstruction(HloInstruction::CreateConditional( ShapeUtil::MakeShape(S32, {}), false_instrn, one, true_computation, @@ -106,7 +106,7 @@ TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) { HloComputation* computation = MakeConditional(&module()); auto* true_op = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); TF_ASSERT_OK( true_op->AddControlDependencyTo(computation->root_instruction())); @@ -119,10 +119,11 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) { ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); auto* true_computation = conditional->true_computation(); + auto* token = true_computation->AddInstruction(HloInstruction::CreateToken()); auto* send = true_computation->AddInstruction(HloInstruction::CreateSend( true_computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))), - /*channel_id=*/0)); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))), + token, /*channel_id=*/0)); true_computation->AddInstruction(HloInstruction::CreateSendDone(send)); EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); } @@ -133,8 +134,9 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) { ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); auto* true_computation = conditional->true_computation(); + auto* token = true_computation->AddInstruction(HloInstruction::CreateToken()); auto* recv = true_computation->AddInstruction(HloInstruction::CreateRecv( - ShapeUtil::MakeShape(F32, {1}), /*channel_id=*/0)); + ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0)); true_computation->AddInstruction(HloInstruction::CreateRecvDone(recv)); EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); } @@ -144,8 +146,9 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { auto* conditional = computation->root_instruction(); ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); auto* false_computation = conditional->false_computation(); - false_computation->AddInstruction( - HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); + auto token = false_computation->AddInstruction(HloInstruction::CreateToken()); + false_computation->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::MakeShape(F32, {1}), token, "config")); EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); } diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 33d8338809d4e8c7c4774f062c3dda5494543ca6..3e39c1bab1e07d192a8c145be5103085fd3c189b 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -76,15 +76,6 @@ SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node, policy.copy_parameters_and_constants = true; policy.copy_root_replicated_buffers = true; } - for (const CallSite& site : node.caller_callsites()) { - // The AddCopiesForConditional() already adds copies, but the copy remover - // removes them, so we re-add them by returning the policy here. But really - // the copy remover should not be removing them. - if (site.instruction()->opcode() == HloOpcode::kConditional) { - policy.copy_parameters_and_constants = true; - policy.copy_root_replicated_buffers = true; - } - } return policy; } @@ -321,7 +312,7 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis, return Status::OK(); } -// We add copies for all the indices of the true and false computaiton roots, +// We add copies for all the indices of the true and false computation roots, // in order to resolve interference. We later rely on the CopyRemover to drop // the unnecessary ones. Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, @@ -360,26 +351,6 @@ Status StripControlDependenciesFrom(HloInstruction* instruction) { return Status::OK(); } -// Add kCopy instructions to the given module to guarantee there is no -// live-range interference. Generally interference can only occur around kWhile -// instructions which have update-in-place semantics. -Status AddCopiesToResolveInterference(HloModule* module) { - TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module)); - - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kWhile) { - TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction)); - } else if (instruction->opcode() == HloOpcode::kConditional) { - TF_RETURN_IF_ERROR( - AddCopiesForConditional(*alias_analysis, instruction)); - } - } - } - return Status::OK(); -} - // Class for removing unnecessary copies from the module. // // kCopy instructions are added conservatively to guarantee no live range @@ -472,6 +443,10 @@ class CopyRemover { // between copies added around aliased operations (kWhile) guarantees // this strict order. for (const HloValue* value_a : buffer.values()) { + if (ShapeUtil::IsToken(value_a->shape())) { + // Token values have no representation and cannot interfere. + continue; + } for (const HloValue* value_b : buffer.values()) { if (value_a != value_b) { DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b, @@ -613,7 +588,10 @@ class CopyRemover { VLOG(2) << copy->name() << " is not removable"; return false; } - + if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) { + VLOG(2) << copy->name() << " is not removable (shape mismatch)"; + return false; + } const CopyNodes& copy_node = copy_map_.at(copy); ValueNode* src = copy_node.src; ValueNode* dest = copy_node.dest; @@ -670,7 +648,12 @@ class CopyRemover { // We can only perform copy elision if the resulting merged values have // totally ordered live ranges; otherwise the merged buffer would have // live range interference. - if (IsHead(*dest)) { + if (src->next == dest) { + // In the process of eliding copies, its possible for a copy to have the + // same source and destination buffer. In this case, the copy can be + // safely removed. + VLOG(2) << copy->name() << " source and destination buffers are same."; + } else if (IsHead(*dest)) { // The copy copies an arbitrary value in the source buffer (call it s_x) // and defines d_0, the first value in the destination buffer. After // merging, the values in the combined buffer must be strictly ordered @@ -947,22 +930,30 @@ class CopyRemover { BufferValueTracker buffer_value_tracker_; }; -// Try to remove as many copies from the module as possible without introducing -// live range interference. Copy instructions (identified by their unique id) in -// the set copies_to_exclude are not considered for removal. -Status RemoveUnnecessaryCopies( - const HloOrdering& ordering, - const tensorflow::gtl::FlatSet& copies_to_exclude, HloModule* module) { +void MaybeDumpModule(const string& message, const HloModule& module) { + if (VLOG_IS_ON(3)) { + VLOG(3) << message; + XLA_VLOG_LINES(3, module.ToString()); + hlo_graph_dumper::MaybeDumpHloModule(module, message); + } +} + +} // namespace + +// Add kCopy instructions to the given module to guarantee there is no +// live-range interference. Generally interference can only occur around kWhile +// instructions which have update-in-place semantics. +Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) { TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module)); - CopyRemover copy_remover(*alias_analysis, ordering, module); - XLA_VLOG_LINES(3, copy_remover.ToString()); + HloAliasAnalysis::Run(module, fusion_can_share_buffer_)); for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy && - !ContainsKey(copies_to_exclude, instruction->unique_id())) { - TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); + if (instruction->opcode() == HloOpcode::kWhile) { + TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction)); + } else if (instruction->opcode() == HloOpcode::kConditional) { + TF_RETURN_IF_ERROR( + AddCopiesForConditional(*alias_analysis, instruction)); } } } @@ -979,9 +970,10 @@ Status RemoveUnnecessaryCopies( // // (3) Constants and parameters cannot be live out of the entry computation // -Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { +Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, + HloModule* module) { TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module)); + HloAliasAnalysis::Run(module, fusion_can_share_buffer_)); // Identify which shape indices of which instructions need to be copied. Store // these results in 'instructions_to_copy'. @@ -1065,13 +1057,23 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { HloInstruction* instruction = pair.first; const ShapeTree& indices_to_copy = pair.second; + ShapeTree copies_added(indices_to_copy.shape()); std::vector users = instruction->users(); TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, instruction->parent()->DeepCopyInstruction( - instruction, &indices_to_copy)); + instruction, &indices_to_copy, &copies_added)); for (HloInstruction* user : users) { TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy)); } + // Special case copies are not eligible for later copy elision passes. + indices_to_copy.ForEachElement([&](const ShapeIndex& index, bool has_copy) { + if (has_copy) { + HloInstruction* copy = *copies_added.mutable_element(index); + if (copy != nullptr) { + copy->SetCopyElisionAllowed(false); + } + } + }); if (instruction == instruction->parent()->root_instruction()) { instruction->parent()->set_root_instruction(deep_copy); } @@ -1079,23 +1081,36 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { return Status::OK(); } -Status VerifyNoLiveRangeInterference(HloModule* module) { +Status CopyInsertion::VerifyNoLiveRangeInterference(HloModule* module) { TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module)); + HloAliasAnalysis::Run(module, fusion_can_share_buffer_)); DependencyHloOrdering ordering(module); TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering)); return Status::OK(); } -void MaybeDumpModule(const string& message, const HloModule& module) { - if (VLOG_IS_ON(3)) { - VLOG(3) << message; - XLA_VLOG_LINES(3, module.ToString()); - hlo_graph_dumper::MaybeDumpHloModule(module, message); +Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering, + HloModule* module) { + MaybeDumpModule("after adding copies to resolve interference", *module); + + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module, fusion_can_share_buffer_)); + CopyRemover copy_remover(*alias_analysis, ordering, module); + XLA_VLOG_LINES(3, copy_remover.ToString()); + + std::unique_ptr call_graph = CallGraph::Build(module); + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy && + instruction->CopyElisionAllowed()) { + TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); + } + } } -} + MaybeDumpModule("after removing unnecessary copies", *module); -} // namespace + return Status::OK(); +} StatusOr CopyInsertion::Run(HloModule* module) { // Copy insertion is performed in three steps: @@ -1130,16 +1145,13 @@ StatusOr CopyInsertion::Run(HloModule* module) { "Call graph must be flattened before copy insertion."); } - // Gather Ids of existing kCopy instructions in the module. We avoid removing - // these copies (except via DCE in TupleSimplifier) because they may have been - // added for reasons not considered by copy insertion (eg, layout assignment). - // Instruction id is used instead of HloInstruction* because the pointer - // values may be recycled. - tensorflow::gtl::FlatSet existing_copies; - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy) { - existing_copies.insert(instruction->unique_id()); + int64 num_existing_copies = 0; + if (VLOG_IS_ON(1)) { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + ++num_existing_copies; + } } } } @@ -1158,13 +1170,8 @@ StatusOr CopyInsertion::Run(HloModule* module) { TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); - MaybeDumpModule("after adding copies to resolve interference", *module); - DependencyHloOrdering ordering(module); - TF_RETURN_IF_ERROR( - RemoveUnnecessaryCopies(ordering, existing_copies, module)); - - MaybeDumpModule("after removing unnecessary copies", *module); + TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module)); TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); @@ -1185,7 +1192,7 @@ StatusOr CopyInsertion::Run(HloModule* module) { } } } - VLOG(1) << "Num copies before copy-insertion: " << existing_copies.size(); + VLOG(1) << "Num copies before copy-insertion: " << num_existing_copies; VLOG(1) << "Num copies after copy-insertion: " << num_total_copies; } diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index 65e3d31e347e2cb249a072e7d06ca10c55401748..5ba64b78a3c9aff5f323691df2ece9b5e6bf3232 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -48,6 +47,15 @@ class CopyInsertion : public HloPassInterface { public: tensorflow::StringPiece name() const override { return "copy-insertion"; } + // fusion_can_share_buffer: backend specific function that decides whether a + // fusion can share buffer with its operand. + // + // TODO(b/80315712): Find a better way to tell whether a fusion can share + // buffer. + CopyInsertion(const HloDataflowAnalysis::FusionCanShareBufferFunction& + fusion_can_share_buffer = nullptr) + : fusion_can_share_buffer_(fusion_can_share_buffer) {} + // Run the pass on the given module. Returns whether the module was changed // (copies were inserted). StatusOr Run(HloModule* module) override; @@ -62,6 +70,25 @@ class CopyInsertion : public HloPassInterface { // // TODO(b/62548313): Remove this when buffer assignment is module-scoped. static StatusOr AddCopiesForBufferAssignment(HloModule* module); + + // Try to remove as many copies from the module as possible without + // introducing live range interference. Only copy instructions that are + // eligible for copy elision are considered for removal. + Status RemoveUnnecessaryCopies(const HloOrdering& ordering, + HloModule* module); + + private: + // Verifies that no HLO values have interfering live ranged assuming the + // ordering used by copy insertion. + Status VerifyNoLiveRangeInterference(HloModule* module); + + Status AddCopiesToResolveInterference(HloModule* module); + + Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module); + + // Backend specific function that decides whether a fusion can share buffer + // with its operand. + HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 684fff8a6fd4fd045a0de9df9660b887f32bdf40..892d0d7b547aaf1e7f1c55e4163d1e1fd9518def 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -108,7 +108,7 @@ TEST_F(CopyInsertionTest, SingleConstant) { // be copied before entering the tuple. auto builder = HloComputation::Builder(TestName()); HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant})); @@ -125,21 +125,27 @@ TEST_F(CopyInsertionTest, SingleConstant) { } TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { - // Verify that an kCopy instructions which exist in the pass before + // Verify that kCopy instructions which change layout and exist before // copy-insertion remain in the graph after copy-insertion. auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - HloInstruction* copy_1 = builder.AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kCopy, constant)); - HloInstruction* copy_2 = builder.AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kCopy, constant)); + HloInstruction* constant = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{0.f, 2.f}, {2.f, 4.f}}))); + auto minor_to_major = LayoutUtil::MinorToMajor(constant->shape()); + Layout reversed_layout = + LayoutUtil::MakeLayoutFromMajorToMinor(minor_to_major); + Shape copy_shape = constant->shape(); + *copy_shape.mutable_layout() = reversed_layout; + HloInstruction* copy_1 = builder.AddInstruction( + HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant)); + HloInstruction* copy_2 = builder.AddInstruction( + HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant)); HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( constant->shape(), HloOpcode::kAdd, copy_1, copy_2)); - HloInstruction* add_copy = builder.AddInstruction( - HloInstruction::CreateUnary(constant->shape(), HloOpcode::kCopy, add)); + builder.AddInstruction( + HloInstruction::CreateUnary(add->shape(), HloOpcode::kCopy, add)); module->AddEntryComputation(builder.Build()); @@ -147,12 +153,11 @@ TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { InsertCopies(module.get()); - EXPECT_EQ(CountCopies(*module), 3); + EXPECT_EQ(CountCopies(*module), 2); - EXPECT_EQ(module->entry_computation()->root_instruction(), add_copy); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - op::Copy(op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())))); + EXPECT_EQ(module->entry_computation()->root_instruction(), add); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Add(op::Copy(op::Constant()), op::Copy(op::Constant()))); } TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { @@ -162,9 +167,9 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { auto builder = HloComputation::Builder(TestName()); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); HloInstruction* x = builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); @@ -192,11 +197,11 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { // the computation result. Verify that copies are added properly. auto builder = HloComputation::Builder(TestName()); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); HloInstruction* constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); HloInstruction* tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -204,9 +209,9 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { HloInstruction::CreateTuple({constant3, constant2})); HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); EXPECT_THAT(constant1->users(), UnorderedElementsAre(tuple1)); EXPECT_THAT(constant2->users(), UnorderedElementsAre(tuple1, tuple2)); @@ -250,8 +255,9 @@ TEST_F(CopyInsertionTest, BitcastConstant) { // The output of a bitcast is its operand (same buffer), so a bitcast // constant feeding the result must have a copy added. auto builder = HloComputation::Builder(TestName()); - HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1.0, 42.0}))); + HloInstruction* constant = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1.0, 42.0}))); HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant)); @@ -365,9 +371,9 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { // copy is added. auto builder = HloComputation::Builder(TestName()); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); HloInstruction* tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -375,9 +381,9 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { HloInstruction::CreateTuple({constant2, constant1})); HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); HloInstruction* gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(select->shape(), {0}), select, 0)); @@ -408,7 +414,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { const Shape& loop_state_shape) { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(10))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); auto loop_state = builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); auto induction_variable = @@ -437,7 +443,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); // Update data GTE(1). @@ -475,7 +481,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); // add0 = Add(in0, 1) auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -544,7 +550,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); // add0 = Add(in0, 1) auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); @@ -559,8 +565,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest { data = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); } - auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto update = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); // add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data, update)); @@ -593,7 +600,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( gte0->shape(), HloOpcode::kAdd, gte0, inc)); @@ -603,8 +610,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // GTE(GTE(loop_state, 1), 0) -> Add auto gte10 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0)); - auto update10 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto update10 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); auto add10 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, gte10, update10)); @@ -628,10 +636,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest { bool nested = false) { auto builder = HloComputation::Builder(TestName() + ".While"); auto induction_var_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); - auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto data_init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); if (nested) { auto inner_init = builder.AddInstruction( @@ -654,8 +663,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest { HloInstruction* BuildWhileInstruction_InitPointsToConstant() { auto builder = HloComputation::Builder(TestName() + ".While"); - auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto data_init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, &builder); } @@ -672,11 +682,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto builder = HloComputation::Builder(TestName() + ".While"); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto v1 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, one, {1})); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto v2 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); @@ -684,9 +694,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto data_init = builder.AddInstruction(HloInstruction::CreateTernary( - nested_tuple_shape_, HloOpcode::kSelect, pred, tuple1, tuple2)); + nested_tuple_shape_, HloOpcode::kTupleSelect, pred, tuple1, tuple2)); return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_, data_init, &builder); @@ -696,7 +706,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto builder = HloComputation::Builder(TestName() + ".While"); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto one_vec = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, one, {1})); auto data_init = @@ -709,11 +719,12 @@ class WhileCopyInsertionTest : public CopyInsertionTest { HloInstruction* BuildWhileInstruction_InitPointsToInterfering() { auto builder = HloComputation::Builder(TestName() + ".While"); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto data_init = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, one, {1})); - auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto one_vec = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); // Take a reference to 'data_init' to make it interfere with while result. auto add = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data_init, one_vec)); @@ -745,7 +756,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { const bool nested = ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_); auto induction_var_init = builder->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto condition = module_->AddEmbeddedComputation( BuildConditionComputation(loop_state_shape)); auto body = module_->AddEmbeddedComputation( @@ -1247,7 +1258,6 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { auto loop_init = builder.AddInstruction( HloInstruction::CreateTuple({iter_param, data_param, data_param})); - // Two while loops shares the same loop init tuple. auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape, condition1, body1, loop_init)); @@ -1305,7 +1315,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); cond_builder.AddInstruction(HloInstruction::CreateUnary( cond_constant->shape(), HloOpcode::kNot, cond_constant)); HloComputation* condition = @@ -1313,9 +1323,9 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( @@ -1370,7 +1380,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); cond_builder.AddInstruction(HloInstruction::CreateUnary( cond_constant->shape(), HloOpcode::kNot, cond_constant)); HloComputation* condition = @@ -1378,9 +1388,9 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( @@ -1430,7 +1440,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); cond_builder.AddInstruction(HloInstruction::CreateUnary( cond_constant->shape(), HloOpcode::kNot, cond_constant)); HloComputation* condition = @@ -1438,7 +1448,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant, constant})); builder.AddInstruction( @@ -1515,7 +1525,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); cond_builder.AddInstruction(HloInstruction::CreateUnary( cond_constant->shape(), HloOpcode::kNot, cond_constant)); HloComputation* condition = @@ -1570,14 +1580,14 @@ TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { body_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param")); body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0))); HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); cond_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module->AddEmbeddedComputation(cond_builder.Build()); @@ -1595,12 +1605,51 @@ TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { EXPECT_THAT(condition->root_instruction(), op::Constant()); } +TEST_F(CopyInsertionTest, TokensShouldNotBeCopied) { + string module_string = R"( +HloModule TokensShouldNotBeCopied + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %TokensShouldNotBeCopied () -> s32[] { + %one = s32[] constant(1) + %negative_one = s32[] negate(%one) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + HloRunner::CreateModuleFromString( + module_string, GetDebugOptionsForTest())); + InsertCopies(module.get()); + + // There should be no copies added because tokens should not be copied. + EXPECT_EQ(CountCopies(*module), 0); +} + std::unique_ptr MakeTrivialCondition(const Shape& shape) { auto builder = HloComputation::Builder("trivial_condition"); builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "loop_state")); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kNot, constant)); return builder.Build(); @@ -1958,5 +2007,46 @@ ENTRY TestComputation { InsertCopies(module.get()); } +TEST_F(CopyInsertionTest, NestedWhiles) { + // Verify that only no unnecessary copies remain after copy insertion for + // trivial nested whiles (b/112472605). + const string& hlo_string = R"( +HloModule TestModule + +cond.inner { + ROOT param.cond.inner = pred[] parameter(0) +} + +body.inner { + param.body.inner = pred[] parameter(0) + ROOT neg = pred[] negate(param.body.inner) +} + +cond.outer { + ROOT param.cond.outer = pred[] parameter(0) +} + +body.outer { + param.cond.outer = pred[] parameter(0) + ROOT while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner +} + +ENTRY TestComputation { + entry_param = pred[] parameter(0) + ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + InsertCopies(module.get()); + + // There should only be a single copy inserted, and it's in the entry + // computation. + EXPECT_EQ(CountCopies(*module), 1); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::While(op::Copy(op::Parameter()))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index b703be0f39e2032bc58479f0b957f9d8b01a77c3..84779c60b0c790cd0be98f8b42996acf7e3cfd8b 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -20,7 +20,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") load( "//third_party/mkl:build_defs.bzl", - "if_mkl", + "mkl_deps", ) # Filegroup used to collect source files for dependency checking. @@ -37,6 +37,7 @@ cc_library( srcs = ["cpu_transfer_manager.cc"], hdrs = ["cpu_transfer_manager.h"], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -54,34 +55,23 @@ cc_library( ) cc_library( - name = "external_constant_pool", - srcs = ["external_constant_pool.cc"], - hdrs = ["external_constant_pool.h"], + name = "buffer_info_util", + srcs = ["buffer_info_util.cc"], + hdrs = ["buffer_info_util.h"], deps = [ - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/tf2xla:cpu_function_runtime", + "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/core:lib", ], ) -tf_cc_test( - name = "external_constant_pool_test", - srcs = ["external_constant_pool_test.cc"], - deps = [ - ":external_constant_pool", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:test", - ], -) - cc_library( name = "cpu_compiler", srcs = ["cpu_compiler.cc"], hdrs = ["cpu_compiler.h"], deps = [ ":compiler_functor", + ":buffer_info_util", ":conv_canonicalization", ":cpu_copy_insertion", ":cpu_executable", @@ -95,7 +85,9 @@ cc_library( ":ir_emitter", ":parallel_task_assignment", ":simple_orc_jit", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/tf2xla:cpu_function_runtime", + "//tensorflow/compiler/xla/service:scatter_expander", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -112,7 +104,6 @@ cc_library( "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", - "//tensorflow/compiler/xla/service:gather_expander", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", @@ -152,7 +143,7 @@ cc_library( "@llvm//:x86_code_gen", # fixdeps: keep "@llvm//:x86_disassembler", # fixdeps: keep ] + select({ - "@org_tensorflow//tensorflow:linux_ppc64le": [ + "//tensorflow:linux_ppc64le": [ "@llvm//:powerpc_disassembler", "@llvm//:powerpc_code_gen", ], @@ -175,7 +166,6 @@ cc_library( ":cpu_runtime", ":custom_call_target_registry", ":disassembler", - ":external_constant_pool", ":orc_jit_memory_mapper", ":runtime_fp16", ":runtime_conv2d", @@ -256,7 +246,6 @@ cc_library( ":cpu_options", ":cpu_runtime", ":dot_op_emitter", - ":external_constant_pool", ":ir_emission_utils", ":ir_function", ":parallel_loop_emitter", @@ -273,15 +262,17 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", + "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", - "//tensorflow/compiler/xla/service/llvm_ir:ops", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "@llvm//:code_gen", @@ -379,7 +370,7 @@ tf_cc_binary( srcs = ["sample_harness.cc"], deps = [ "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -387,8 +378,8 @@ tf_cc_binary( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", ], ) @@ -468,6 +459,7 @@ cc_library( deps = [ ":vector_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:math_ops", "//tensorflow/core:lib", "@llvm//:core", "@llvm//:transform_utils", @@ -506,10 +498,7 @@ cc_library( "//tensorflow/core:framework_lite", "//tensorflow/core/kernels:eigen_helpers", "//third_party/eigen3", - ] + if_mkl([ - "@mkl_dnn", - "//third_party/mkl:intel_binary_blob", - ]), + ] + mkl_deps(), ) cc_library( @@ -563,10 +552,7 @@ cc_library( "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", "//third_party/eigen3", - ] + if_mkl([ - "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ]), + ] + mkl_deps(), ) cc_library( @@ -741,7 +727,7 @@ tf_cc_test( deps = [ ":cpu_layout_assignment", ":target_machine_features_fake", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -833,7 +819,7 @@ tf_cc_test( ":cpu_executable", ":parallel_task_assignment", ":target_machine_features_fake", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -916,7 +902,7 @@ tf_cc_test( srcs = ["cpu_copy_insertion_test.cc"], deps = [ ":cpu_copy_insertion", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..408fe0f5bf5d729165eadd532d4740211620645d --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" + +namespace xla { +namespace cpu { + +using BufferInfo = ::tensorflow::cpu_function_runtime::BufferInfo; + +std::vector CreateBufferInfosFromBufferAssignment( + const BufferAssignment& buffer_assignment) { + std::vector buffer_infos; + for (const BufferAllocation& allocation : buffer_assignment.Allocations()) { + if (allocation.is_thread_local()) { + buffer_infos.push_back(BufferInfo::MakeOnStackBuffer(allocation.size())); + } else if (allocation.is_constant()) { + buffer_infos.push_back(BufferInfo::MakeConstant(allocation.size())); + } else if (allocation.is_entry_computation_parameter()) { + buffer_infos.push_back(BufferInfo::MakeEntryParameter( + /*size=*/allocation.size(), + /*param_number=*/allocation.parameter_number())); + } else { + buffer_infos.push_back(BufferInfo::MakeTempBuffer(allocation.size())); + } + } + return buffer_infos; +} + +std::vector CreateArgIndexTableFromBufferInfos( + tensorflow::gtl::ArraySlice buffer_infos) { + std::vector result; + for (int64 i = 0; i < buffer_infos.size(); i++) { + if (buffer_infos[i].is_entry_parameter()) { + if (buffer_infos[i].entry_parameter_number() >= result.size()) { + result.resize(buffer_infos[i].entry_parameter_number() + 1); + } + result[buffer_infos[i].entry_parameter_number()] = i; + } + } + return result; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.h b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h new file mode 100644 index 0000000000000000000000000000000000000000..05de70c72686dcbdaf0b47c46cde23ed45abdb42 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_ + +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { +namespace cpu { +// Creates and returns a list of BufferInfo instances containing relevant +// information from `buffer_assignment`. +std::vector<::tensorflow::cpu_function_runtime::BufferInfo> +CreateBufferInfosFromBufferAssignment( + const BufferAssignment& buffer_assignment); + +// Creates and returns a table containing the mapping from entry computation +// parameters to buffer allocation indices. +// +// If this function returns V then entry parameter i has buffer allocation index +// V[i]. +std::vector CreateArgIndexTableFromBufferInfos( + tensorflow::gtl::ArraySlice<::tensorflow::cpu_function_runtime::BufferInfo> + buffer_infos); +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 6a7eb85e3baec3517b8f3ddef6a8dcfae9c9e614..128eea4828b5e514b2ba6b398898e4a5d228e746 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -156,9 +156,26 @@ std::unique_ptr CompilerFunctor::operator()( target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream); codegen_passes.run(module); - // Construct ObjectFile from machine code buffer. - return std::unique_ptr( + std::unique_ptr memory_buffer( new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer))); + + if (VLOG_IS_ON(2)) { + llvm::Expected> obj_file = + llvm::object::ObjectFile::createObjectFile(*memory_buffer); + if (obj_file) { + StatusOr disasm_result = + disassembler_->DisassembleObjectFile(*obj_file.get()); + if (disasm_result.ok()) { + XLA_VLOG_LINES(2, disasm_result.ValueOrDie().text); + } else { + LOG(WARNING) << "Could not disassemble object file!"; + } + } else { + LOG(WARNING) << "Could convert memory buffer to object file!"; + } + } + + return memory_buffer; } static std::vector VectorFunctionsForTargetLibraryInfoImpl() { diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 375b017b09263c20c1b1ef8329f7e2f6a573dda4..547d4c696da5cfdde3dece03250ae5fa51c92f25 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -60,11 +60,11 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { auto builder = HloComputation::Builder(TestName()); // The input dimensions are in CNHW order. auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR4FromArray4D(Array4D( + LiteralUtil::CreateR4FromArray4D(Array4D( kInputFeatureCount, kBatchSize, kInputSize, kInputSize)))); // The kernel dimensions are in OIHW order. auto kernel = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR4FromArray4D(Array4D( + LiteralUtil::CreateR4FromArray4D(Array4D( kOutputFeatureCount, kInputFeatureCount, kWindowSize, kWindowSize)))); ConvolutionDimensionNumbers dnums; @@ -122,11 +122,11 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { auto builder = HloComputation::Builder(TestName()); // The input dimensions are in NHWC order. auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR4FromArray4D(Array4D( + LiteralUtil::CreateR4FromArray4D(Array4D( kBatchSize, kInputSize, kInputSize, kInputFeatureCount)))); // The kernel dimensions are in HWIO order. auto kernel = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR4FromArray4D(Array4D( + LiteralUtil::CreateR4FromArray4D(Array4D( kWindowSize, kWindowSize, kInputFeatureCount, kOutputFeatureCount)))); ConvolutionDimensionNumbers dnums; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index d039132535071661d047579587385210719fede3..35154af0482a2ea85f17add48b16de4a8bb6affb 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -30,6 +30,7 @@ limitations under the License. #include "llvm/ADT/Triple.h" #include "llvm/IR/Function.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Mangler.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" #include "llvm/Object/ObjectFile.h" @@ -38,7 +39,7 @@ limitations under the License. #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -49,6 +50,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" +#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" @@ -66,7 +68,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" -#include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" @@ -87,6 +88,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/scatter_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -103,6 +105,7 @@ limitations under the License. namespace xla { namespace cpu { +using BufferInfo = ::tensorflow::cpu_function_runtime::BufferInfo; CpuAotCompilationOptions::CpuAotCompilationOptions( string triple, string cpu_name, string features, string entry_point_name, @@ -120,11 +123,11 @@ se::Platform::Id CpuAotCompilationOptions::PlatformId() const { } CpuAotCompilationResult::CpuAotCompilationResult( - ObjectFileData object_file_data, BufferSizes buffer_sizes, + ObjectFileData object_file_data, std::vector buffer_infos, int64 result_buffer_index, std::unique_ptr hlo_profile_printer_data) : object_file_data_(std::move(object_file_data)), - buffer_sizes_(std::move(buffer_sizes)), + buffer_infos_(std::move(buffer_infos)), result_buffer_index_(result_buffer_index), hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {} @@ -269,6 +272,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, /*enable_dot_strength_reduction=*/false); + pass.AddPass(); // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. @@ -296,22 +300,26 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pipeline.AddPass(/*is_layout_sensitive=*/false); pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass( - module->mutable_device_entry_computation_layout(), - &target_machine_features); + module->mutable_entry_computation_layout(), &target_machine_features); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. - pipeline.AddPass>( - /*is_layout_sensitive=*/true, - [](const Shape&, const Shape&) { return true; }, - /*enable_dot_strength_reduction=*/false); - pipeline.AddPass(/*is_layout_sensitive=*/true); + { + auto& pass = pipeline.AddPass>( + "after layout assignement"); + pass.AddPass>( + /*is_layout_sensitive=*/true, + [](const Shape&, const Shape&) { return true; }, + /*enable_dot_strength_reduction=*/false); + pass.AddPass(); + pass.AddPass(/*is_layout_sensitive=*/true); + } pipeline.AddPass(BF16, F32); // Outline ops in the entry computation into calls to subcomputations. const int max_parallelism = @@ -351,7 +359,7 @@ llvm::TargetOptions CompilerTargetOptions( llvm::TargetOptions target_options; llvm_ir::SetTargetOptions( /*fast_math_enabled=*/module_config.debug_options() - .xla_enable_fast_math(), + .xla_cpu_enable_fast_math(), &target_options); return target_options; } @@ -518,7 +526,7 @@ StatusOr> CpuCompiler::RunBackend( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), options::OptimizeForSizeRequested(module->config()), - module->config().debug_options().xla_enable_fast_math(), + module->config().debug_options().xla_cpu_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), pre_optimization_ir_hook, post_optimization_ir_hook); llvm_module->setDataLayout(jit->data_layout()); @@ -559,7 +567,9 @@ StatusOr> CpuCompiler::RunBackend( BufferAssigner::Run( module.get(), xla::MakeUnique(module.get(), module_sequence), - BufferSizeBytesFunction(), memory_alignment)); + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -579,7 +589,9 @@ StatusOr> CpuCompiler::RunBackend( IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - &target_machine_features, jit->external_constant_pool()); + &target_machine_features); + + TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); for (auto embedded_computation : entry_computation->MakeEmbeddedComputationsList()) { @@ -602,7 +614,13 @@ StatusOr> CpuCompiler::RunBackend( /*is_top_level_computation=*/true, &module_sequence.at(entry_computation))); - string function_name = llvm_ir::AsString(entry_function->getName()); + string function_name = [&]() { + llvm::SmallVector function_name_vector; + llvm::Mangler::getNameWithPrefix( + function_name_vector, entry_function->getName(), jit->data_layout()); + return string(function_name_vector.begin(), function_name_vector.end()); + }(); + string ir_module_string; if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); @@ -638,9 +656,9 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, // so we bail if the configs have conflicting flags. At the moment, the only // flag that needs to be consistent is fast-math. const bool fast_math_enabled = - modules[0]->config().debug_options().xla_enable_fast_math(); + modules[0]->config().debug_options().xla_cpu_enable_fast_math(); for (const auto& module : modules) { - if (module->config().debug_options().xla_enable_fast_math() != + if (module->config().debug_options().xla_cpu_enable_fast_math() != fast_math_enabled) { return InvalidArgument( "All HLO module configs must have the same value for " @@ -738,7 +756,9 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, BufferAssigner::Run( module, xla::MakeUnique(module, module_sequence), - BufferSizeBytesFunction(), memory_alignment)); + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -766,8 +786,10 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, IrEmitter ir_emitter(*module, *assignment, &llvm_module, std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - &target_machine_features, - /*external_constant_pool=*/nullptr); + &target_machine_features); + + TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); + HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { @@ -813,7 +835,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, CompilerFunctor compiler_functor( target_machine.get(), &disassembler, opt_level, options::OptimizeForSizeRequested(module->config()), - module->config().debug_options().xla_enable_fast_math(), + module->config().debug_options().xla_cpu_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook); std::unique_ptr object_file = @@ -821,27 +843,14 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, ObjectFileData object_file_data(object_file->getBufferStart(), object_file->getBufferEnd()); - BufferSizes buffer_sizes; - for (const BufferAllocation& allocation : assignment->Allocations()) { - // Callers don't need to allocate temporary buffers for parameters. - if (allocation.is_entry_computation_parameter()) { - buffer_sizes.push_back(-1); - continue; - } - // Callers don't need to allocate anything for thread-local temporary - // buffers. They are lowered to allocas. - if (allocation.is_thread_local()) { - buffer_sizes.push_back(-1); - continue; - } - buffer_sizes.push_back(allocation.size()); - } + std::vector buffer_infos = + CreateBufferInfosFromBufferAssignment(*assignment); TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, assignment->GetUniqueTopLevelOutputSlice()); results.emplace_back(MakeUnique( - std::move(object_file_data), std::move(buffer_sizes), + std::move(object_file_data), std::move(buffer_infos), result_slice.index(), std::move(hlo_profile_printer_data))); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index e56f9f01134f84b4698c078b750b0c1fdca7748e..04e1c48872ed55ca7f2aa3bec08c44a1666b90f1 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "llvm/Target/TargetMachine.h" +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" @@ -78,7 +79,8 @@ class CpuAotCompilationOptions : public AotCompilationOptions { class CpuAotCompilationResult : public AotCompilationResult { public: CpuAotCompilationResult( - ObjectFileData object_file_data, BufferSizes buffer_sizes, + ObjectFileData object_file_data, + std::vector<::tensorflow::cpu_function_runtime::BufferInfo> buffer_infos, int64 result_buffer_index, std::unique_ptr hlo_profile_printer_data); ~CpuAotCompilationResult(); @@ -88,17 +90,20 @@ class CpuAotCompilationResult : public AotCompilationResult { } const ObjectFileData& object_file_data() const { return object_file_data_; } - const BufferSizes& buffer_sizes() const { return buffer_sizes_; } + const std::vector<::tensorflow::cpu_function_runtime::BufferInfo>& + buffer_infos() const { + return buffer_infos_; + } int64 result_buffer_index() const { return result_buffer_index_; } private: // Contains the compiled computation: an object file. const ObjectFileData object_file_data_; - // The list of buffer sizes which should be allocated in order to execute the - // compiled computation. These buffers are used for temporary buffers used - // ephemerally during computation as well as the output result. - const BufferSizes buffer_sizes_; + // A list of BufferInfo objects describing the buffers used by the XLA + // computation. + const std::vector<::tensorflow::cpu_function_runtime::BufferInfo> + buffer_infos_; // Contains which buffer index into |buffer_sizes| was designated to the // result of the computation. This buffer should be passed into the output diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc index a05a26941786cbf404c4685abb098c9ac8caaa09..4db7fa446ea9188940f930bcadf753bd3e6b79e3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -74,14 +74,14 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { body_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param")); body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0))); HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); cond_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module->AddEmbeddedComputation(cond_builder.Build()); @@ -114,7 +114,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) { auto sub_param = sub_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param")); auto constant = sub_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0))); auto add = sub_builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, sub_param, constant)); sub_builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index cf43b74c699ca8cbbef11a0abbaf4d69476f5d77..c376864c3e1f882e11bc05f8cf93f2fb1c88e4ec 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -69,12 +69,19 @@ CpuExecutable::CpuExecutable( // guarded by the mutex. compute_function_ = reinterpret_cast(cantFail(sym.getAddress())); + VLOG(1) << "compute_function_ at address " + << reinterpret_cast(compute_function_); } -Status CpuExecutable::AllocateBuffers( +StatusOr, + std::vector>> +CpuExecutable::CreateTempArray( DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector* buffers) { - CHECK_EQ(buffers->size(), assignment_->Allocations().size()); + tensorflow::gtl::ArraySlice arguments) { + std::vector unowning_buffers( + assignment_->Allocations().size()); + std::vector owning_buffers( + assignment_->Allocations().size()); VLOG(3) << "Allocating " << assignment_->Allocations().size() << " allocations for module " << module().name(); for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); @@ -84,44 +91,51 @@ Status CpuExecutable::AllocateBuffers( VLOG(3) << allocation.ToString(); if (allocation.is_entry_computation_parameter()) { + unowning_buffers[i] = arguments[allocation.parameter_number()]->buffer( + allocation.param_shape_index()); VLOG(3) << "allocation #" << i << " is a parameter"; continue; } + if (allocation.is_constant()) { + VLOG(3) << "allocation #" << i << " is a constant"; + continue; + } + if (allocation.is_thread_local()) { VLOG(3) << "buffer #" << i << " is thread-local"; continue; } int64 buffer_size = allocation.size(); - if (!(*buffers)[i].is_null()) { + if (!owning_buffers[i].is_null()) { VLOG(3) << "buffer #" << i << " is in the preallocated result ShapedBuffer"; } else { - TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate( - device_ordinal, buffer_size)); + TF_ASSIGN_OR_RETURN(owning_buffers[i], memory_allocator->Allocate( + device_ordinal, buffer_size)); + unowning_buffers[i] = owning_buffers[i].AsDeviceMemoryBase(); VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes [" - << (*buffers)[i].opaque() << "]"; + << owning_buffers[i].opaque() << "]"; } // Since the output buffer and all the temporary buffers were written into // by the JITed code, msan has no way of knowing their memory was // initialized. Mark them initialized so that msan doesn't flag loads from // these buffers. - TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size); + TF_ANNOTATE_MEMORY_IS_INITIALIZED(owning_buffers[i].opaque(), buffer_size); } TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, assignment_->GetUniqueTopLevelOutputSlice()); VLOG(3) << "result index: " << result_slice.index(); - return Status::OK(); + return {{std::move(unowning_buffers), std::move(owning_buffers)}}; } Status CpuExecutable::ExecuteComputeFunction( const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { // The calling convention for JITed functions is: @@ -131,17 +145,11 @@ Status CpuExecutable::ExecuteComputeFunction( // // result: Points at the result. // run_options: the ExecutableRunOptions object. - // args_array: An array of pointers, each of which points to a parameter. - // The size of this array is determined by the function's arity - // (ProgramShape). - // temps_array: An array of pointers, each of which points to a temporary - // buffer the computation needs. The size of this array is - // determined by buffer analysis. + // args_array: null + // temps_array: An array of pointers, containing pointers to temporary buffers + // required by the executable adn pointers to entry computation + // parameters. // - std::vector args_array; - for (const ShapedBuffer* argument : arguments) { - args_array.push_back(argument->root_buffer().opaque()); - } uint64 start_micros = tensorflow::Env::Default()->NowMicros(); @@ -164,16 +172,14 @@ Status CpuExecutable::ExecuteComputeFunction( if (VLOG_IS_ON(3)) { VLOG(3) << "Executing compute function:"; VLOG(3) << tensorflow::strings::Printf( - " func(void* result, void* params[%zu], void* temps[%zu], " + " func(void* result, void* params[null], void* temps[%zu], " "uint64 profile_counters[%zu])", - args_array.size(), buffer_pointers.size(), profile_counters_size); + buffer_pointers.size(), profile_counters_size); VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); auto ptr_printer = [](string* out, const void* p) { tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p)); }; - VLOG(3) << tensorflow::strings::Printf( - " params = [%s]", - tensorflow::str_util::Join(args_array, ", ", ptr_printer).c_str()); + VLOG(3) << " params = nullptr"; VLOG(3) << tensorflow::strings::Printf( " temps = [%s]", tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str()); @@ -181,8 +187,8 @@ Status CpuExecutable::ExecuteComputeFunction( profile_counters); } - compute_function_(result_buffer, run_options, args_array.data(), - buffer_pointers.data(), profile_counters); + compute_function_(result_buffer, run_options, nullptr, buffer_pointers.data(), + profile_counters); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -206,8 +212,8 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( tensorflow::gtl::MutableArraySlice buffers) { se::Stream* stream = run_options->stream(); ScopedShapedBuffer result_buffer( - /*on_host_shape=*/host_result_shape(), - /*on_device_shape=*/host_result_shape(), run_options->allocator(), + /*on_host_shape=*/result_shape(), + /*on_device_shape=*/result_shape(), run_options->allocator(), stream->parent()->device_ordinal()); // Move OwningDeviceMemory values which contain the array(s) of the result @@ -243,27 +249,11 @@ StatusOr CpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { - if (GetRootPointsToSet().IsAmbiguous()) { - return Unimplemented("Points-to set of root instruction is ambiguous"); - } - - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); - - TF_RETURN_IF_ERROR(AllocateBuffers( - memory_allocator, stream->parent()->device_ordinal(), &buffers)); - - std::vector unowning_buffers; - unowning_buffers.reserve(buffers.size()); - for (auto& buffer : buffers) { - unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); - } - TF_RETURN_IF_ERROR(ExecuteComputeFunction(&run_options->run_options(), - arguments, unowning_buffers, - hlo_execution_profile)); - - return CreateResultShapedBuffer(run_options, &buffers); + TF_ASSIGN_OR_RETURN( + auto result, + ExecuteAsyncOnStreamImpl(run_options, arguments, hlo_execution_profile)); + TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone()); + return std::move(result); } StatusOr CpuExecutable::ExecuteAsyncOnStream( @@ -274,22 +264,30 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( "Asynchronous execution on stream with hlo profiling is not yet " "supported on CPU."); } + return ExecuteAsyncOnStreamImpl(run_options, arguments, nullptr); +} + +StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile) { + if (GetRootPointsToSet().IsAmbiguous()) { + return Unimplemented("Points-to set of root instruction is ambiguous"); + } auto* host_stream = dynamic_cast( run_options->stream()->implementation()); se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); - TF_RETURN_IF_ERROR(AllocateBuffers( - memory_allocator, stream->parent()->device_ordinal(), &buffers)); - + std::vector owning_buffers; std::vector unowning_buffers; - unowning_buffers.reserve(buffers.size()); - for (auto& buffer : buffers) { - unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); - } + TF_ASSIGN_OR_RETURN( + std::tie(unowning_buffers, owning_buffers), + CreateTempArray(memory_allocator, stream->parent()->device_ordinal(), + arguments)); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, - CreateResultShapedBuffer(run_options, &buffers)); + CreateResultShapedBuffer(run_options, &owning_buffers)); // At this point, `unowning_buffers` contains unowning pointers to all of our // buffers, and `buffers` contains owning pointers to the non-live-out @@ -307,23 +305,22 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( struct AsyncRunTask { CpuExecutable* executable; ServiceExecutableRunOptions run_options; - std::vector arguments; std::vector unowning_buffers; std::shared_ptr> buffers; + HloExecutionProfile* hlo_execution_profile; void operator()() { // Failing a CHECK here is not great, but I don't see an obvious way to // return a failed Status asynchronously. TF_CHECK_OK(executable->ExecuteComputeFunction( - &run_options.run_options(), arguments, unowning_buffers, - /*hlo_execution_profile=*/nullptr)); + &run_options.run_options(), unowning_buffers, hlo_execution_profile)); } }; - host_stream->EnqueueTask(AsyncRunTask{ - this, *run_options, - std::vector(arguments.begin(), arguments.end()), - unowning_buffers, - std::make_shared>(std::move(buffers))}); + host_stream->EnqueueTask( + AsyncRunTask{this, *run_options, std::move(unowning_buffers), + std::make_shared>( + std::move(owning_buffers)), + hlo_execution_profile}); return std::move(result); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 8dd47bfb865e8a0552542f510d3365cff0d111e0..96e53de57eee013fe6f847c10e23a38f5beb9adc 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -85,20 +85,39 @@ class CpuExecutable : public Executable { const BufferAssignment& buffer_assignment() const { return *assignment_; } private: - // Allocate buffers required for execution and assign them to the elements of - // "buffers". "buffers" should be sized to the number of buffers in buffer - // assignment. Each vector element corresponds to a particular Index. If - // a vector element already contains a non-null DeviceMemoryBase, then no - // buffer is assigned for this element. - Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator, - int device_ordinal, - std::vector* buffers); + // This is for sharing the code between ExecuteOnStream and + // ExecuteAsyncOnStream. + // + // Notice that it's tricky to use correctly, as the profile object (when it + // exists) must out-live the task. + StatusOr ExecuteAsyncOnStreamImpl( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice arguments, + HloExecutionProfile* hlo_execution_profile); + + // Creates an array suitable for passing as the "temps" argument to the JIT + // compiled function pointer. + // + // Returns (unowning_buffers, owning_buffers) where: + // + // - unowning_buffers.data() can be passed as the temps argument as-is and + // includes pointers to the scratch storage required by the computation, + // the live-out buffer into which the result will be written and entry + // computation parameters. + // + // - owning_buffers contains owning pointers to the buffers that were + // allocated by this routine. This routine allocates buffers for temporary + // storage and the live-out buffer into which the computation writes it + // result. + StatusOr, + std::vector>> + CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal, + tensorflow::gtl::ArraySlice arguments); // Calls the generated function performing the computation with the given // arguments using the supplied buffers. Status ExecuteComputeFunction( const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 97e10a89a209c057685709e7a5034052ff4376ed..991b14f17dbc8cd061af98e032824d3f7075e78b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -282,7 +282,7 @@ class OpcodeFusionTest : public InstructionFusionTest { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "arg0")); HloInstruction* one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, arg0, one)); return module->AddEmbeddedComputation(builder.Build()); @@ -501,8 +501,8 @@ TEST_F(OpcodeFusionTest, UnaryMapOfExp) { HloInstruction* exp = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0)); - builder.AddInstruction(HloInstruction::CreateMap( - shape, {exp}, CreateAdderToOne(module.get()), /*static_operands=*/{})); + builder.AddInstruction( + HloInstruction::CreateMap(shape, {exp}, CreateAdderToOne(module.get()))); module->AddEntryComputation(builder.Build()); @@ -525,8 +525,8 @@ TEST_F(OpcodeFusionTest, BinaryMapOfExps) { HloInstruction* exp1 = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kExp, param1)); - builder.AddInstruction(HloInstruction::CreateMap( - shape, {exp0, exp1}, CreateMax(module.get()), /*static_operands=*/{})); + builder.AddInstruction( + HloInstruction::CreateMap(shape, {exp0, exp1}, CreateMax(module.get()))); module->AddEntryComputation(builder.Build()); @@ -595,7 +595,7 @@ TEST_F(OpcodeFusionTest, MessOfFusileNodes) { auto pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(S32, {5}), idx_choice, builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), padding_config)); auto slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 429fc7b78608da0e9cd794ac294851b326f5be24..3681d12d8da818d06d2f690024008c9ccb896286 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 54c52bc08f9c53b8c6898689b18c4cb7f4bdcfd0..639064040f521a9e84bd87c5d05f674204e4d6e2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -92,9 +92,10 @@ tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) { } // namespace -void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, - const void* shape, - xla::int32 shape_length) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void* +__xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, + const void* shape, + xla::int32 shape_length) { if (VLOG_IS_ON(2)) { LOG(INFO) << "AcquireInfeedBufferForDequeue: " << ShapeString(shape, shape_length); @@ -111,9 +112,11 @@ void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, return buffer->data(); } -void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( - xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, - xla::int32 shape_length) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length, + void* buffer_ptr, + const void* shape_ptr, + xla::int32 shape_length) { if (VLOG_IS_ON(2)) { LOG(INFO) << "ReleaseInfeedBufferAfterDeque: " << ShapeString(shape_ptr, shape_length); @@ -125,8 +128,10 @@ void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( std::move(shape)); } -void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( - xla::int32 buffer_length, const void* shape_ptr, xla::int32 shape_length) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void* +__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length, + const void* shape_ptr, + xla::int32 shape_length) { if (VLOG_IS_ON(2)) { LOG(INFO) << "AcquireOutfeedBufferForPopulation: " << ShapeString(shape_ptr, shape_length); @@ -143,9 +148,11 @@ void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( return buffer->data(); } -void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( - xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, - xla::int32 shape_length) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(xla::int32 buffer_length, + void* buffer_ptr, + const void* shape_ptr, + xla::int32 shape_length) { if (VLOG_IS_ON(2)) { LOG(INFO) << "ReleaseOutfeedBufferAfterPopulation: " << ShapeString(shape_ptr, shape_length); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index d97802ee45d6add3c466577d7624d9ca74e2f380..59bc7e0e16fcc66a010408259a1ccfb2b6bb35fd 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -160,9 +161,8 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, int32 size_32 = static_cast(size); CpuInfeedBuffer* queued_buffer = new CpuInfeedBuffer(size_32); - Status s = - TransferBufferToDevice(executor, /*size=*/size, - /*source=*/source, queued_buffer->device_memory()); + Status s = executor->SynchronousMemcpyH2D( + /*host_src=*/source, /*size=*/size, queued_buffer->device_memory()); if (!s.ok()) { queued_buffer->Done(s); @@ -173,7 +173,7 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, Status CpuTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) { + MutableBorrowingLiteral literal) { if (!ShapeUtil::IsTuple(literal_shape)) { int64 size = GetByteSizeRequirement(literal_shape); // Note: OSS build didn't like implicit conversion from @@ -181,18 +181,16 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( tensorflow::gtl::ArraySlice dimensions( tensorflow::bit_cast(literal_shape.dimensions().data()), literal_shape.dimensions().size()); - *literal = std::move(*Literal::CreateFromDimensions( - literal_shape.element_type(), dimensions)); - TF_ASSIGN_OR_RETURN(Shape received_shape, - TransferArrayBufferFromOutfeed( - executor, literal->untyped_data(), size)); - TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal->shape())) + TF_ASSIGN_OR_RETURN( + Shape received_shape, + TransferArrayBufferFromOutfeed(executor, literal.untyped_data(), size)); + TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal.shape())) << "Shape received from outfeed " << ShapeUtil::HumanString(received_shape) << " did not match the shape that was requested for outfeed: " << ShapeUtil::HumanString(literal_shape); TF_RET_CHECK(size == GetByteSizeRequirement(received_shape)); - *literal->mutable_shape_do_not_use() = received_shape; + *literal.mutable_shape_do_not_use() = received_shape; return Status::OK(); } @@ -201,22 +199,12 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( "Nested tuple outfeeds are not yet implemented on CPU."); } - std::vector> elements; std::vector> buffer_data; for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) { const Shape& tuple_element_shape = ShapeUtil::GetTupleElementShape(literal_shape, i); - // Note: OSS build didn't like implicit conversion from - // literal_shape.dimensions() to the array slice on 2017-07-10. - tensorflow::gtl::ArraySlice dimensions( - tensorflow::bit_cast( - tuple_element_shape.dimensions().data()), - tuple_element_shape.dimensions().size()); - auto empty = Literal::CreateFromDimensions( - tuple_element_shape.element_type(), dimensions); int64 size = GetByteSizeRequirement(tuple_element_shape); - buffer_data.push_back({empty->untyped_data(), size}); - elements.push_back(std::move(empty)); + buffer_data.push_back({literal.untyped_data({i}), size}); } TF_ASSIGN_OR_RETURN(Shape received_shape, @@ -230,11 +218,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( TF_RET_CHECK(GetByteSizeRequirement(literal_shape) == GetByteSizeRequirement(received_shape)); - for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) { - *elements[i]->mutable_shape_do_not_use() = received_shape.tuple_shapes(i); - } - *literal = std::move(*Literal::MakeTupleOwned(std::move(elements))); - TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape)); + TF_RET_CHECK(ShapeUtil::Equal(literal.shape(), literal_shape)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h index 6dfc666f09dfa6df740cd54bea0957e3144181bc..80ef953d532798281c10b7a212b9c4d84a790c27 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -39,13 +40,14 @@ class CpuTransferManager : public GenericTransferManager { Status TransferLiteralToInfeed(se::StreamExecutor* executor, const LiteralSlice& literal) override; - Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, - const void* source) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + MutableBorrowingLiteral literal) override; private: + Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, + const void* source); + // Transfers infeed data to device. InfeedBuffer->Done() must be // called to clean up the memory allocated for InfeedBuffer. StatusOr TransferBufferToInfeedInternal( diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 8eb39d615fd482cdcea716ba7b105c643a2d8b87..f2ac742b6e6fc12076e7a2a242155c005f4b05b8 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -49,15 +49,15 @@ class MemoryTile { // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at // `major_dim_offset` in the major dimension. The tile size along the minor // dimension is the vector size, and that is implicitly determined by `vsl`. - MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder, + MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* b, llvm::Value* matrix, int64 matrix_size_along_minor_dim, llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) - : vsl_(vsl), ir_builder_(ir_builder) { + : vsl_(vsl), b_(b) { pointers_.reserve(tile_size_along_major_dim); for (int64 i = 0; i < tile_size_along_major_dim; i++) { - llvm::Value* total_offset = ir_builder->CreateMul( - ir_builder->getInt64(matrix_size_along_minor_dim), - ir_builder->CreateAdd(ir_builder->getInt64(i), major_dim_offset)); + llvm::Value* total_offset = + b->CreateMul(b->getInt64(matrix_size_along_minor_dim), + b->CreateAdd(b->getInt64(i), major_dim_offset)); pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset)); } } @@ -101,8 +101,7 @@ class MemoryTile { for (int64 i = 0; i < pointers_.size(); i++) { for (int64 j = 0; j < tile_size_along_middle_dim; j++) { result[i].push_back(vsl_->LoadBroadcast( - pointers_[i], ir_builder_->CreateAdd(minor_dim_offset, - ir_builder_->getInt64(j)))); + pointers_[i], b_->CreateAdd(minor_dim_offset, b_->getInt64(j)))); } } return result; @@ -110,7 +109,7 @@ class MemoryTile { private: VectorSupportLibrary* vsl_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; std::vector pointers_; }; @@ -249,16 +248,15 @@ class ColumnMajorMatrixVectorProductEmitter ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, llvm::Value* result, - llvm::IRBuilder<>* ir_builder) + llvm::IRBuilder<>* b) : config_(config), lhs_(lhs), rhs_(rhs), addend_(addend), result_(result), - ir_builder_(ir_builder), - ksl_(ir_builder_), - vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), - ir_builder_, "") { + b_(b), + ksl_(b_), + vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), b_, "") { CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast(tile_rows()))); CHECK(!has_addend() || addend != nullptr); } @@ -272,7 +270,7 @@ class ColumnMajorMatrixVectorProductEmitter bool is_first_column); MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) { - return MemoryTile(&vsl_, ir_builder_, /*matrix=*/lhs_, + return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, /*matrix_size_along_minor_dim=*/m(), /*major_dim_offset=*/column_start, /*tile_size_along_major_dim=*/column_count); @@ -302,7 +300,7 @@ class ColumnMajorMatrixVectorProductEmitter llvm::Value* rhs_; llvm::Value* addend_; llvm::Value* result_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; KernelSupportLibrary ksl_; VectorSupportLibrary vsl_; }; @@ -331,7 +329,7 @@ void ColumnMajorMatrixVectorProductEmitter::Emit() { }); if (column_remainder != 0) { - EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder, + EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder, column_limit == 0); } } @@ -364,7 +362,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( return; } - llvm::Value* columns_llvm = ir_builder_->getInt64(columns); + llvm::Value* columns_llvm = b_->getInt64(columns); // for (col = current_tile_col; col < (columns + current_tile_col); col++) // for (row = row_start, row < m_; row++) { @@ -375,12 +373,11 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( ksl_.ForReturnVoid( "dot.inner.epilg.outer", /*start=*/current_tile_col, - /*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col), + /*end=*/b_->CreateAdd(columns_llvm, current_tile_col), /*step=*/1, /*peel_first_iteration=*/false, [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); - llvm::Value* total_offset = - ir_builder_->CreateMul(col, ir_builder_->getInt64(m())); + llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); ksl_.ForReturnVoid( @@ -388,9 +385,8 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( /*step=*/1, [&](llvm::Value* scalar_row) { llvm::Value* product = vsl_.Mul( vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); - llvm::Value* setting_result_first_time = ir_builder_->CreateAnd( - is_first_scalar_col, - ir_builder_->getInt1(is_first_tiled_column)); + llvm::Value* setting_result_first_time = b_->CreateAnd( + is_first_scalar_col, b_->getInt1(is_first_tiled_column)); ksl_.IfReturnVoid( setting_result_first_time, /*true_block_generator=*/ @@ -478,16 +474,15 @@ class RowMajorMatrixVectorProductEmitter RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, - llvm::Value* result, - llvm::IRBuilder<>* ir_builder) + llvm::Value* result, llvm::IRBuilder<>* b) : config_(config), lhs_(lhs), rhs_(rhs), addend_(addend), result_(result), - ir_builder_(ir_builder), - ksl_(ir_builder_), - vsl_(scalar_type(), /*vector_size=*/tile_cols(), ir_builder_, "") { + b_(b), + ksl_(b_), + vsl_(scalar_type(), /*vector_size=*/tile_cols(), b_, "") { CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast(tile_cols()))); CHECK(!has_addend() || addend != nullptr); } @@ -498,7 +493,7 @@ class RowMajorMatrixVectorProductEmitter private: MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) { - return MemoryTile(&vsl_, ir_builder_, /*matrix=*/lhs_, + return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, /*matrix_size_along_minor_dim=*/k(), /*major_dim_offset=*/row_start, /*tile_size_along_major_dim=*/row_count); @@ -517,7 +512,7 @@ class RowMajorMatrixVectorProductEmitter llvm::Value* rhs_; llvm::Value* addend_; llvm::Value* result_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; KernelSupportLibrary ksl_; VectorSupportLibrary vsl_; }; @@ -559,7 +554,7 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, for (int i = 0; i < row_count; i++) { llvm::Value* result_value = vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get()); - llvm::Value* offset = ir_builder_->CreateAdd(ir_builder_->getInt64(i), row); + llvm::Value* offset = b_->CreateAdd(b_->getInt64(i), row); if (addend_ && row_count != vsl_.vector_size()) { result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value); } @@ -578,7 +573,7 @@ void RowMajorMatrixVectorProductEmitter::Emit() { [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); if (row_remainder != 0) { - EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder); + EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder); } } @@ -609,9 +604,8 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( } for (int r = 0; r < rows; r++) { - llvm::Value* total_offset = ir_builder_->CreateMul( - ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row), - ir_builder_->getInt64(k())); + llvm::Value* total_offset = b_->CreateMul( + b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); ksl_.ForReturnVoid( @@ -722,13 +716,13 @@ class MatrixMatrixBlockPanelEmitter { // `lhs` with `rhs` and stores the result in `result`. explicit MatrixMatrixBlockPanelEmitter(Config config, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, - llvm::IRBuilder<>* ir_builder) + llvm::IRBuilder<>* b) : lhs_(lhs), rhs_(rhs), result_(result), config_(config), - ir_builder_(ir_builder), - ksl_(ir_builder_) { + b_(b), + ksl_(b_) { CHECK(max_vectorization_width() > 0 && IsPowerOfTwo(static_cast(max_vectorization_width()))); CHECK_GT(max_vector_count(), 0); @@ -761,7 +755,7 @@ class MatrixMatrixBlockPanelEmitter { int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end); - llvm::Value* GetInt64(int64 value) { return ir_builder_->getInt64(value); } + llvm::Value* GetInt64(int64 value) { return b_->getInt64(value); } Config config() const { return config_; } Dimensions dims() const { return config().dims(); } @@ -782,7 +776,7 @@ class MatrixMatrixBlockPanelEmitter { llvm::Value* result_; Config config_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; KernelSupportLibrary ksl_; }; @@ -804,8 +798,8 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { current_vectorization_width >= min_vectorization_width()) { int64 n_end = dims().n() - (dims().n() % current_vectorization_width); if (n_start != n_end) { - VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, - ir_builder_, "gebp"); + VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_, + "gebp"); HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); n_start = n_end; } @@ -819,10 +813,9 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { } if (n_start != dims().n()) { - VectorSupportLibrary vsl(scalar_type(), 1, ir_builder_, "gebp"); + VectorSupportLibrary vsl(scalar_type(), 1, b_, "gebp"); ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { - llvm::Value* n_i_next = - ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1)); + llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1)); HandleResiduesOnK(&vsl, n_i, n_i_next); }); } @@ -935,11 +928,11 @@ void MatrixMatrixBlockPanelEmitter::EmitTiledGemm( ksl_.ForReturnVoid( "dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { MemoryTile result_memory_tile( - vsl, ir_builder_, /*matrix=*/result_, + vsl, b_, /*matrix=*/result_, /*matrix_size_along_minor_dim=*/dims().n(), /*major_dim_offset=*/m_i, /*tile_size_along_major_dim=*/tile_size_m); - MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_, + MemoryTile lhs_memory_tile(vsl, b_, /*matrix=*/lhs_, /*matrix_size_along_minor_dim=*/dims().k(), /*major_dim_offset=*/m_i, /*tile_size_along_major_dim=*/tile_size_m); @@ -949,8 +942,8 @@ void MatrixMatrixBlockPanelEmitter::EmitTiledGemm( result_memory_tile.LoadTile(n_i)); ksl_.ForReturnVoid( "dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { - MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_, - dims().n(), k_i, tile_size_k); + MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i, + tile_size_k); std::vector> lhs_tile = lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); std::vector rhs_tile = @@ -980,7 +973,7 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* ir_builder, + llvm::IRBuilder<>* b, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) : dot_(dot), @@ -989,7 +982,7 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, rhs_array_(rhs_array), addend_array_(addend_array), executable_run_options_value_(executable_run_options_value), - ir_builder_(ir_builder), + b_(b), hlo_module_config_(hlo_module_config), target_machine_features_(target_machine_features) {} @@ -997,15 +990,14 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, const HloInstruction& dot, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { PrimitiveType type = target_array.GetShape().element_type(); TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type); DotOpEmitter dot_emitter(dot, target_array, lhs_array, rhs_array, - addend_array, executable_run_options_value, - ir_builder, hlo_module_config, - target_machine_features); + addend_array, executable_run_options_value, b, + hlo_module_config, target_machine_features); return dot_emitter.Emit(); } @@ -1050,13 +1042,13 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( } int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - ir_builder_->CreateMemSet( - target, ir_builder_->getInt8(0), size_bytes, + b_->CreateMemSet( + target, b_->getInt8(0), size_bytes, target_machine_features_.minimum_alignment_for_allocation(size_bytes)); int64 max_target_vector_width = target_machine_features_.vector_register_num_elements( - *ir_builder_->GetInsertBlock()->getParent(), primitive_type); + *b_->GetInsertBlock()->getParent(), primitive_type); int64 tile_size_m, tile_size_k, tile_size_n_in_vector_width; std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) = @@ -1074,18 +1066,18 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( << config.GetCacheKey(); const bool enable_fast_math = - hlo_module_config_.debug_options().xla_enable_fast_math(); + hlo_module_config_.debug_options().xla_cpu_enable_fast_math(); const bool optimize_for_size = options::OptimizeForSizeRequested(hlo_module_config_); KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, ir_builder_, - config.GetCacheKey(), lhs, rhs, target, + /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), lhs, + rhs, target, [this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) { - MatrixMatrixBlockPanelEmitter gebp_emitter( - config, /*lhs=*/lhs, /*rhs=*/rhs, - /*result=*/target, ir_builder_); + MatrixMatrixBlockPanelEmitter gebp_emitter(config, /*lhs=*/lhs, + /*rhs=*/rhs, + /*result=*/target, b_); gebp_emitter.Emit(); }); @@ -1157,13 +1149,13 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer(); const bool enable_fast_math = - hlo_module_config_.debug_options().xla_enable_fast_math(); + hlo_module_config_.debug_options().xla_cpu_enable_fast_math(); const bool optimize_for_size = options::OptimizeForSizeRequested(hlo_module_config_); const int target_vector_register_element_size = target_machine_features_.vector_register_num_elements( - *ir_builder_->GetInsertBlock()->getParent(), primitive_type); + *b_->GetInsertBlock()->getParent(), primitive_type); // We may not always know the vector register size for the target we're // compiling against, in which case target_vector_register_element_size is 0. @@ -1184,13 +1176,13 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, ir_builder_, - config.GetCacheKey(), lhs_op, rhs_op, + /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), + lhs_op, rhs_op, addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, llvm::Value* result_op) { ColumnMajorMatrixVectorProductEmitter emitter( - config, lhs_op, rhs_op, addend_op, result_op, ir_builder_); + config, lhs_op, rhs_op, addend_op, result_op, b_); emitter.Emit(); }); } else { @@ -1203,13 +1195,13 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, ir_builder_, - config.GetCacheKey(), lhs_op, rhs_op, + /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), + lhs_op, rhs_op, addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, llvm::Value* result_op) { - RowMajorMatrixVectorProductEmitter emitter( - config, lhs_op, rhs_op, addend_op, result_op, ir_builder_); + RowMajorMatrixVectorProductEmitter emitter(config, lhs_op, rhs_op, + addend_op, result_op, b_); emitter.Emit(); }); } @@ -1285,11 +1277,11 @@ Status DotOpEmitter::Emit() { // Create loop nests which loop through the LHS operand dimensions and the RHS // operand dimensions. The reduction dimension of the LHS and RHS are handled // in a separate innermost loop which performs the sum of products. - llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(&dot_), ir_builder_); - llvm_ir::IrArray::Index lhs_index = EmitOperandArrayLoopNest( - &loop_nest, lhs_array_, lhs_reduction_dimension, "lhs"); - llvm_ir::IrArray::Index rhs_index = EmitOperandArrayLoopNest( - &loop_nest, rhs_array_, rhs_reduction_dimension, "rhs"); + llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(&dot_), b_); + llvm_ir::IrArray::Index lhs_index = loop_nest.EmitOperandArrayLoopNest( + lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs"); + llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest( + rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs"); // Create the loop which does the sum of products reduction. // @@ -1319,68 +1311,61 @@ Status DotOpEmitter::Emit() { // Function entry basic block. // - Emit alloca for accumulator llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent(); - SetToFirstInsertPoint(&func->getEntryBlock(), ir_builder_); + SetToFirstInsertPoint(&func->getEntryBlock(), b_); llvm::Type* accum_type = target_array_.GetElementLlvmType(); - llvm::Value* accum_address = ir_builder_->CreateAlloca( - accum_type, /*ArraySize=*/nullptr, "accum_address"); + llvm::Value* accum_address = + b_->CreateAlloca(accum_type, /*ArraySize=*/nullptr, "accum_address"); // Preheader basic block of reduction loop: // - Initialize accumulator to zero. llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock(); - ir_builder_->SetInsertPoint(preheader_bb->getTerminator()); + b_->SetInsertPoint(preheader_bb->getTerminator()); - ir_builder_->CreateStore(llvm::Constant::getNullValue(accum_type), - accum_address); + b_->CreateStore(llvm::Constant::getNullValue(accum_type), accum_address); // Body basic block of reduction loop: // - Load elements from lhs and rhs array. // - Multiply lhs-element and rhs-element. // - Load accumulator and add to product. // - Store sum back into accumulator. - SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), ir_builder_); + SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), b_); - llvm::Value* lhs_element = - lhs_array_.EmitReadArrayElement(lhs_index, ir_builder_); - llvm::Value* rhs_element = - rhs_array_.EmitReadArrayElement(rhs_index, ir_builder_); + llvm::Value* lhs_element = lhs_array_.EmitReadArrayElement(lhs_index, b_); + llvm::Value* rhs_element = rhs_array_.EmitReadArrayElement(rhs_index, b_); - llvm::Value* accum = ir_builder_->CreateLoad(accum_address); + llvm::Value* accum = b_->CreateLoad(accum_address); llvm::Value* updated_accum; if (ShapeUtil::ElementIsComplex(lhs_shape)) { - auto real = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {1}); - }; - llvm::Value* product_real = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(real(lhs_element), real(rhs_element)), - ir_builder_->CreateFMul(imag(lhs_element), imag(rhs_element))); - llvm::Value* product_imag = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(lhs_element), imag(rhs_element)), - ir_builder_->CreateFMul(imag(lhs_element), real(rhs_element))); - updated_accum = ir_builder_->CreateInsertValue( - accum, ir_builder_->CreateFAdd(real(accum), product_real), {0}); - updated_accum = ir_builder_->CreateInsertValue( - updated_accum, ir_builder_->CreateFAdd(imag(accum), product_imag), {1}); + auto real = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {0}); }; + auto imag = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {1}); }; + llvm::Value* product_real = + b_->CreateFSub(b_->CreateFMul(real(lhs_element), real(rhs_element)), + b_->CreateFMul(imag(lhs_element), imag(rhs_element))); + llvm::Value* product_imag = + b_->CreateFAdd(b_->CreateFMul(real(lhs_element), imag(rhs_element)), + b_->CreateFMul(imag(lhs_element), real(rhs_element))); + updated_accum = b_->CreateInsertValue( + accum, b_->CreateFAdd(real(accum), product_real), {0}); + updated_accum = b_->CreateInsertValue( + updated_accum, b_->CreateFAdd(imag(accum), product_imag), {1}); } else { - llvm::Value* product = ir_builder_->CreateFMul(lhs_element, rhs_element); - updated_accum = ir_builder_->CreateFAdd(accum, product); + llvm::Value* product = b_->CreateFMul(lhs_element, rhs_element); + updated_accum = b_->CreateFAdd(accum, product); } - ir_builder_->CreateStore(updated_accum, accum_address); + b_->CreateStore(updated_accum, accum_address); // Exit basic block of reduction loop. // - Load accumulator value (the result). // - Store into output array. - SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), ir_builder_); + SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), b_); - llvm::Value* result = ir_builder_->CreateLoad(accum_address); + llvm::Value* result = b_->CreateLoad(accum_address); // Create index into target address. The target index is the concatenation of // the rhs and lhs indexes with the reduction dimensions removed. The terms // from the rhs index are the lower dimensions in the index so we add them // first. - llvm_ir::IrArray::Index target_index; + llvm_ir::IrArray::Index target_index(lhs_index.GetType()); for (int dimension = 0; dimension < lhs_index.size(); ++dimension) { if (dimension != lhs_reduction_dimension) { target_index.push_back(lhs_index[dimension]); @@ -1392,11 +1377,11 @@ Status DotOpEmitter::Emit() { } } - target_array_.EmitWriteArrayElement(target_index, result, ir_builder_); + target_array_.EmitWriteArrayElement(target_index, result, b_); // Set the IR builder insert point to the exit basic block of the outer most // loop. - ir_builder_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); + b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); return Status::OK(); } @@ -1404,28 +1389,31 @@ Status DotOpEmitter::Emit() { Status DotOpEmitter::EmitScalarDot() { // A scalar dot is just a scalar multiply. llvm::Value* result; + // Use the same index_type for all tensor accesses in the same kernel. + llvm::Type* index_type = b_->getInt64Ty(); + llvm_ir::IrArray::Index element_index(index_type); llvm::Value* lhs_value = - lhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_); + lhs_array_.EmitReadArrayElement(/*index=*/element_index, b_); llvm::Value* rhs_value = - rhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_); + rhs_array_.EmitReadArrayElement(/*index=*/element_index, b_); if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) { -#define REAL(x) ir_builder_->CreateExtractValue(x, {0}) -#define IMAG(x) ir_builder_->CreateExtractValue(x, {1}) - llvm::Value* real = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(REAL(lhs_value), REAL(rhs_value)), - ir_builder_->CreateFMul(IMAG(lhs_value), IMAG(rhs_value))); - llvm::Value* imag = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(REAL(lhs_value), IMAG(rhs_value)), - ir_builder_->CreateFMul(IMAG(lhs_value), REAL(rhs_value))); +#define REAL(x) b_->CreateExtractValue(x, {0}) +#define IMAG(x) b_->CreateExtractValue(x, {1}) + llvm::Value* real = + b_->CreateFSub(b_->CreateFMul(REAL(lhs_value), REAL(rhs_value)), + b_->CreateFMul(IMAG(lhs_value), IMAG(rhs_value))); + llvm::Value* imag = + b_->CreateFAdd(b_->CreateFMul(REAL(lhs_value), IMAG(rhs_value)), + b_->CreateFMul(IMAG(lhs_value), REAL(rhs_value))); #undef IMAG #undef REAL result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType()); - result = ir_builder_->CreateInsertValue(result, real, {0}); - result = ir_builder_->CreateInsertValue(result, imag, {1}); + result = b_->CreateInsertValue(result, real, {0}); + result = b_->CreateInsertValue(result, imag, {1}); } else { - result = ir_builder_->CreateFMul(lhs_value, rhs_value); + result = b_->CreateFMul(lhs_value, rhs_value); } - target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_); + target_array_.EmitWriteArrayElement(/*index=*/element_index, result, b_); return Status::OK(); } @@ -1448,7 +1436,7 @@ Status DotOpEmitter::EmitCallToRuntime() { fn_name = multi_threaded ? runtime::kEigenMatMulF16SymbolName : runtime::kEigenSingleThreadedMatMulF16SymbolName; - float_type = ir_builder_->getHalfTy(); + float_type = b_->getHalfTy(); break; case F32: fn_name = multi_threaded @@ -1457,7 +1445,7 @@ Status DotOpEmitter::EmitCallToRuntime() { : (use_mkl_dnn ? runtime::kMKLSingleThreadedMatMulF32SymbolName : runtime::kEigenSingleThreadedMatMulF32SymbolName); - float_type = ir_builder_->getFloatTy(); + float_type = b_->getFloatTy(); break; case F64: fn_name = multi_threaded @@ -1466,7 +1454,7 @@ Status DotOpEmitter::EmitCallToRuntime() { : (use_mkl_dnn ? runtime::kMKLSingleThreadedMatMulF64SymbolName : runtime::kEigenSingleThreadedMatMulF64SymbolName); - float_type = ir_builder_->getDoubleTy(); + float_type = b_->getDoubleTy(); break; default: return Unimplemented("Invalid type %s for dot operation", @@ -1474,16 +1462,16 @@ Status DotOpEmitter::EmitCallToRuntime() { } llvm::Type* float_ptr_type = float_type->getPointerTo(); - llvm::Type* int64_type = ir_builder_->getInt64Ty(); - llvm::Type* int32_type = ir_builder_->getInt32Ty(); - llvm::Type* int8_ptr_type = ir_builder_->getInt8Ty()->getPointerTo(); + llvm::Type* int64_type = b_->getInt64Ty(); + llvm::Type* int32_type = b_->getInt32Ty(); + llvm::Type* int8_ptr_type = b_->getInt8Ty()->getPointerTo(); llvm::FunctionType* matmul_type = llvm::FunctionType::get( - ir_builder_->getVoidTy(), + b_->getVoidTy(), {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type, int64_type, int64_type, int64_type, int32_type, int32_type}, /*isVarArg=*/false); - llvm::Function* function = ir_builder_->GetInsertBlock()->getParent(); + llvm::Function* function = b_->GetInsertBlock()->getParent(); llvm::Module* module = function->getParent(); llvm::Function* matmul_func = llvm::cast( @@ -1518,18 +1506,15 @@ Status DotOpEmitter::EmitCallToRuntime() { std::swap(transpose_lhs, transpose_rhs); } - ir_builder_->CreateCall( + b_->CreateCall( matmul_func, - {ir_builder_->CreateBitCast(executable_run_options_value_, int8_ptr_type), - ir_builder_->CreateBitCast(target_array_.GetBasePointer(), - float_ptr_type), - ir_builder_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type), - ir_builder_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type), - ir_builder_->getInt64(mat_mult_dims.m), - ir_builder_->getInt64(mat_mult_dims.n), - ir_builder_->getInt64(mat_mult_dims.k), - ir_builder_->getInt32(transpose_lhs), - ir_builder_->getInt32(transpose_rhs)}); + {b_->CreateBitCast(executable_run_options_value_, int8_ptr_type), + b_->CreateBitCast(target_array_.GetBasePointer(), float_ptr_type), + b_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type), + b_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type), + b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n), + b_->getInt64(mat_mult_dims.k), b_->getInt32(transpose_lhs), + b_->getInt32(transpose_rhs)}); return Status::OK(); } @@ -1552,36 +1537,6 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0}; } -llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( - llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array, - int64 reduction_dimension, tensorflow::StringPiece name_suffix) { - // Prepares the dimension list we will use to emit the loop nest. Outermost - // loops are added first. Add loops in major-to-minor order, and skip the - // reduction dimension. - std::vector dimensions; - const Shape& shape = operand_array.GetShape(); - for (int i = LayoutUtil::MinorToMajor(shape).size() - 1; i >= 0; --i) { - int64 dimension = LayoutUtil::Minor(shape.layout(), i); - if (dimension != reduction_dimension) { - dimensions.push_back(dimension); - } - } - - // Create loop nest with one for-loop for each dimension of the - // output. - llvm_ir::IrArray::Index index = - loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix); - // Verify every dimension except the reduction dimension was set in the index. - for (int dimension = 0; dimension < index.size(); ++dimension) { - if (dimension == reduction_dimension) { - DCHECK_EQ(nullptr, index[dimension]); - } else { - DCHECK_NE(nullptr, index[dimension]); - } - } - return index; -} - // Return whether the given shape is a matrix with no padding. static bool IsRank2WithNoPadding(const Shape& shape) { return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); @@ -1627,8 +1582,8 @@ bool PotentiallyImplementedAsEigenDot( const Shape& lhs_shape = hlo.operand(0)->shape(); const Shape& rhs_shape = hlo.operand(1)->shape(); - if (ShapeUtil::HasZeroElements(lhs_shape) || - ShapeUtil::HasZeroElements(rhs_shape)) { + if (ShapeUtil::IsZeroElementArray(lhs_shape) || + ShapeUtil::IsZeroElementArray(rhs_shape)) { return false; } diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index ed2a18976a0f1a88e7bb4632d3a63167d5c146ad..590032fbe907d7ca90bf69b7ccc3170b8efec72e 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -61,7 +61,7 @@ class DotOpEmitter { const HloInstruction& dot, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features); @@ -70,8 +70,7 @@ class DotOpEmitter { const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* ir_builder, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features); @@ -89,17 +88,6 @@ class DotOpEmitter { // Emits a call to the CPU runtime to perform the matrix multiply. Status EmitCallToRuntime(); - // Emits a series of nested loops for iterating over an operand array in the - // dot operation. Loops are constructed in major to minor dimension layout - // order. No loop is emitted for the given reduction_dimension. The function - // returns an IrArray index for the given operand_array containing the indvars - // of the loops. All dimensions of the index are filled except for the - // reduction dimension. name_suffix is the string to append to the names of - // LLVM constructs (eg, basic blocks) constructed by this method. - llvm_ir::IrArray::Index EmitOperandArrayLoopNest( - llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array, - int64 reduction_dimension, tensorflow::StringPiece name_suffix); - // Represents the dimensions of a matrix-matrix multiply operation. struct MatMultDims { // The number of rows in the LHS. @@ -171,7 +159,7 @@ class DotOpEmitter { const llvm_ir::IrArray& rhs_array_; const llvm_ir::IrArray* addend_array_; llvm::Value* executable_run_options_value_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; const HloModuleConfig& hlo_module_config_; const TargetMachineFeatures& target_machine_features_; }; diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index e97113dfa0f59e791d614c0093d0781e49c48ee4..db54454707983ade31594119b2e868fa168d4cc2 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -19,6 +19,8 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/types.h" @@ -28,48 +30,6 @@ limitations under the License. namespace xla { namespace cpu { -StatusOr CpuElementalIrEmitter::EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { - switch (op->opcode()) { - case HloOpcode::kTanh: { - PrimitiveType element_type = op->shape().element_type(); - bool cast_result_to_fp16 = false; - string function_name; - switch (element_type) { - case F16: - cast_result_to_fp16 = true; - operand_value = ir_builder_->CreateFPCast(operand_value, - ir_builder_->getFloatTy()); - TF_FALLTHROUGH_INTENDED; - case F32: - function_name = "tanhf"; - break; - case F64: - function_name = "tanh"; - break; - default: - return Unimplemented("tanh"); - } - // Create a function declaration. - llvm::Function* function = - llvm::cast(module_->getOrInsertFunction( - llvm_ir::AsStringRef(function_name), operand_value->getType(), - operand_value->getType())); - function->setCallingConv(llvm::CallingConv::C); - function->setDoesNotThrow(); - function->setDoesNotAccessMemory(); - // Create an instruction to call the function. - llvm::Value* result = ir_builder_->CreateCall(function, operand_value); - if (cast_result_to_fp16) { - result = ir_builder_->CreateFPCast(result, ir_builder_->getHalfTy()); - } - return result; - } - default: - return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value); - } -} - StatusOr CpuElementalIrEmitter::EmitAtan2( PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { string function_name; @@ -77,8 +37,8 @@ StatusOr CpuElementalIrEmitter::EmitAtan2( switch (prim_type) { case F16: cast_result_to_fp16 = true; - lhs = ir_builder_->CreateFPCast(lhs, ir_builder_->getFloatTy()); - rhs = ir_builder_->CreateFPCast(rhs, ir_builder_->getFloatTy()); + lhs = b_->CreateFPCast(lhs, b_->getFloatTy()); + rhs = b_->CreateFPCast(rhs, b_->getFloatTy()); TF_FALLTHROUGH_INTENDED; case F32: function_name = "atan2f"; @@ -98,9 +58,42 @@ StatusOr CpuElementalIrEmitter::EmitAtan2( function->setDoesNotThrow(); function->setDoesNotAccessMemory(); // Create an instruction to call the function. - llvm::Value* result = ir_builder_->CreateCall(function, {lhs, rhs}); + llvm::Value* result = b_->CreateCall(function, {lhs, rhs}); + if (cast_result_to_fp16) { + result = b_->CreateFPCast(result, b_->getHalfTy()); + } + return result; +} + +StatusOr CpuElementalIrEmitter::EmitTanh( + PrimitiveType prim_type, llvm::Value* value) const { + bool cast_result_to_fp16 = false; + string function_name; + switch (prim_type) { + case F16: + cast_result_to_fp16 = true; + value = b_->CreateFPCast(value, b_->getFloatTy()); + TF_FALLTHROUGH_INTENDED; + case F32: + function_name = "tanhf"; + break; + case F64: + function_name = "tanh"; + break; + default: + return Unimplemented("tanh"); + } + // Create a function declaration. + llvm::Function* function = llvm::cast( + module_->getOrInsertFunction(llvm_ir::AsStringRef(function_name), + value->getType(), value->getType())); + function->setCallingConv(llvm::CallingConv::C); + function->setDoesNotThrow(); + function->setDoesNotAccessMemory(); + // Create an instruction to call the function. + llvm::Value* result = b_->CreateCall(function, value); if (cast_result_to_fp16) { - result = ir_builder_->CreateFPCast(result, ir_builder_->getHalfTy()); + result = b_->CreateFPCast(result, b_->getHalfTy()); } return result; } @@ -118,9 +111,8 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( ElementwiseSourceIndex(index, *hlo, i))); operands.push_back(operand_value); } - return ir_emitter_->EmitScalarCall(hlo->shape().element_type(), - hlo->to_apply(), operands, - llvm_ir::IrName(hlo)); + return ir_emitter_->EmitElementalMap(*Cast(hlo), + operands, llvm_ir::IrName(hlo)); }; } return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator); diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index 4446dfd2821fb4b6e75f33694367392ecbcdd8bf..76833e765d05f2477961cd06cead66797c5be623 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -31,7 +31,7 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { public: CpuElementalIrEmitter(const HloModuleConfig& module_config, IrEmitter* ir_emitter, llvm::Module* module) - : ElementalIrEmitter(module_config, module, ir_emitter->ir_builder()), + : ElementalIrEmitter(module_config, module, ir_emitter->b()), ir_emitter_(ir_emitter) {} llvm_ir::ElementGenerator MakeElementGenerator( @@ -39,10 +39,10 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { const HloToElementGeneratorMap& operand_to_generator) const override; protected: - StatusOr EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const override; StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const override; + StatusOr EmitTanh(PrimitiveType prim_type, + llvm::Value* value) const override; IrEmitter* ir_emitter_; }; diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc deleted file mode 100644 index c56286559158758ca6db5ae097729286bde346f0..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" - -#include -#include -#include - -#include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/gtl/flatset.h" - -namespace xla { -namespace cpu { -void ExternalConstantPool::Insert(string name, const LiteralSlice& literal, - int64 alignment) { - CHECK(!ShapeUtil::IsTuple(literal.shape())); - CHECK(alignment > 0 && IsPowerOfTwo(static_cast(alignment))); - CHECK(entries_.find(name) == entries_.end()); - - const int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); - void* raw_pointer = tensorflow::port::AlignedMalloc( - literal_size, std::max(alignment, sizeof(void*))); - CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size - << " bytes with alignment of " << alignment; - - std::memcpy(raw_pointer, literal.untyped_data(), literal_size); - entries_.emplace(std::move(name), static_cast(raw_pointer)); -} - -const uint8* ExternalConstantPool::Find(const string& name) { - auto it = entries_.find(name); - return it == entries_.end() ? nullptr : it->second.get(); -} -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h deleted file mode 100644 index 0677f5f0b58005079890052a426e5f48c5d09ed1..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ - -#include - -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/platform/mem.h" - -namespace xla { -namespace cpu { -// An ExternalConstantPool maintains a set of constants kept external to -// generated LLVM IR. These constants are accessed from the IR via globals with -// extern linkage. This current incarnation of ExternalConstantPool only -// supports the JIT CPU backend; the AOT backend is not supported. -// -// Implementation-wise, this is a simple wrapper around a map of strings to byte -// buffers. This simply implementation works in a JIT scenario. This class -// will have to become smarter if we decide to support external constant pools -// on AOT compiles in the future. -class ExternalConstantPool { - public: - // Inserts a buffer with the contents of `literal` into the constant pool with - // the name `name`. It is an error to try to insert two constants with the - // same `name` into the same constant pool. The buffer for literal is aligned - // to `aligment` bytes, and `alignment` must be a power of 2. - // - // The constant pool copies out the contents of `literal` into a buffer it - // owns -- it does not keep pointers to `literal`, or to memory owned by - // `literal`. - void Insert(string name, const LiteralSlice& literal, int64 alignment); - - // Find the constant with name `name` in this constant pool. If there isn't - // such constant, return nullptr. - const uint8* Find(const string& name); - - private: - // We need to `AlignedFree` pointers allocated into `entries_` since we - // allocate them with `AlignedMalloc`. - struct FreeDeleter { - void operator()(void* ptr) { tensorflow::port::AlignedFree(ptr); } - }; - - tensorflow::gtl::FlatMap> - entries_; -}; -} // namespace cpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc deleted file mode 100644 index 9290a4e5dfc03ddb86e9d82f1f0f4f9a8ceebb88..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/core/platform/test.h" - -namespace xla { -namespace cpu { -namespace { -class ExternalConstantPoolTest : public ::testing::Test {}; - -template -T GetFromBuffer(const uint8* buffer, int64 index) { - T result; - std::memcpy(&result, buffer + index * sizeof(T), sizeof(T)); - return result; -} - -TEST(ExternalConstantPoolTest, Basic) { - ExternalConstantPool constant_pool; - EXPECT_EQ(constant_pool.Find("name-0"), nullptr); - const auto literal = Literal::CreateR2({{1, 2}, {3, 4}}); - constant_pool.Insert("name-0", *literal, 4); - const uint8* constant = constant_pool.Find("name-0"); - ASSERT_NE(constant, nullptr); - - EXPECT_EQ(GetFromBuffer(constant, 0), 1); - EXPECT_EQ(GetFromBuffer(constant, 1), 2); - EXPECT_EQ(GetFromBuffer(constant, 2), 3); - EXPECT_EQ(GetFromBuffer(constant, 3), 4); - - EXPECT_EQ(constant_pool.Find("name-1"), nullptr); -} - -TEST(ExternalConstantPoolTest, RowMinorLayout) { - ExternalConstantPool constant_pool; - EXPECT_EQ(constant_pool.Find("name-0"), nullptr); - const auto literal = Literal::CreateR2WithLayout( - {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); - constant_pool.Insert("name-0", *literal, 4); - const uint8* constant = constant_pool.Find("name-0"); - ASSERT_NE(constant, nullptr); - - EXPECT_EQ(GetFromBuffer(constant, 0), 1); - EXPECT_EQ(GetFromBuffer(constant, 1), 3); - EXPECT_EQ(GetFromBuffer(constant, 2), 2); - EXPECT_EQ(GetFromBuffer(constant, 3), 4); -} - -TEST(ExternalConstantPoolTest, Alignment) { - ExternalConstantPool constant_pool; - EXPECT_EQ(constant_pool.Find("name-0"), nullptr); - - for (int i = 0; i < 8; i++) { - int64 alignment = 1 << i; - string name = tensorflow::strings::StrCat("name-", i); - - const auto literal = Literal::CreateR2({{1, 2}, {3, 4}}); - constant_pool.Insert(name, *literal, alignment); - - const uint8* constant = constant_pool.Find(name); - ASSERT_NE(constant, nullptr); - EXPECT_EQ(reinterpret_cast(constant) % alignment, 0); - } -} - -} // namespace -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index b560b7531c0d24e6f670e61a15dce295d9fa2a49..1a8bedfe6afb4f096ddd4703c312b84d521a7ba5 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -64,8 +64,8 @@ bool PotentiallyImplementedAsEigenConvolution( return false; } - if (ShapeUtil::HasZeroElements(input_shape) || - ShapeUtil::HasZeroElements(kernel_shape)) { + if (ShapeUtil::IsZeroElementArray(input_shape) || + ShapeUtil::IsZeroElementArray(kernel_shape)) { return false; } // Make sure input and kernel has the same data type. diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 59223fddac2f5f7e2e85de4d37e4b6c5760ae697..6f433b4f30372da9cf4503396dbb60172cfc0cb0 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -48,11 +48,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -83,22 +86,20 @@ IrEmitter::IrEmitter( llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - const TargetMachineFeatures* target_machine_features, - ExternalConstantPool* external_constant_pool) + const TargetMachineFeatures* target_machine_features) : assignment_(assignment), module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), - ir_builder_(llvm_module->getContext()), + b_(llvm_module->getContext()), instruction_to_profile_idx_(std::move(instruction_to_profile_idx)), computation_to_profile_idx_(std::move(computation_to_profile_idx)), alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), hlo_module_config_(hlo_module.config()), is_top_level_computation_(false), - target_machine_features_(*target_machine_features), - external_constant_pool_(external_constant_pool) { - ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( + target_machine_features_(*target_machine_features) { + b_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() - .xla_enable_fast_math())); + .xla_cpu_enable_fast_math())); } StatusOr IrEmitter::EmitComputation( @@ -115,6 +116,19 @@ StatusOr IrEmitter::EmitComputation( computation->root_instruction()->outer_dimension_partitions().size(); } + if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) { + TF_ASSIGN_OR_RETURN( + computation_root_allocation_, + assignment_.GetUniqueTopLevelSlice(computation->root_instruction())); + } + + for (const HloInstruction* param : computation->parameter_instructions()) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice, + assignment_.GetUniqueTopLevelSlice(param)); + computation_parameter_allocations_[param_slice.allocation()->index()] = + param->parameter_number(); + } + InitializeIrFunction(function_name); // The rdtscp instruction is x86 specific. We will fallback to LLVM's generic // readcyclecounter if it is unavailable. @@ -131,6 +145,8 @@ StatusOr IrEmitter::EmitComputation( // Delete 'compute_function', finalizing 'ir_function' and restoring caller // IR insert point. compute_function_.reset(); + computation_root_allocation_ = BufferAllocation::Slice(); + computation_parameter_allocations_.clear(); return ir_function; } @@ -142,11 +158,11 @@ void IrEmitter::InitializeIrFunction(const string& function_name) { is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage : llvm::GlobalValue::InternalLinkage; // Create and initialize new IrFunction. - compute_function_.reset( - new IrFunction(function_name, linkage, - options::OptimizeForSizeRequested(hlo_module_config_), - hlo_module_config_.debug_options().xla_enable_fast_math(), - module_, &ir_builder_, num_dynamic_loop_bounds_)); + compute_function_.reset(new IrFunction( + function_name, linkage, + options::OptimizeForSizeRequested(hlo_module_config_), + hlo_module_config_.debug_options().xla_cpu_enable_fast_math(), module_, + &b_, num_dynamic_loop_bounds_)); } IrEmitter::~IrEmitter() {} @@ -154,71 +170,55 @@ IrEmitter::~IrEmitter() {} Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { VLOG(2) << "HandleBitcast: " << bitcast->ToString(); emitted_value_[bitcast] = - ir_builder_.CreateBitCast(GetEmittedValueFor(bitcast->operand(0)), - IrShapeType(bitcast->shape())->getPointerTo(), - AsStringRef(IrName(bitcast))); + b_.CreateBitCast(GetEmittedValueFor(bitcast->operand(0)), + IrShapeType(bitcast->shape())->getPointerTo(), + AsStringRef(IrName(bitcast))); return Status::OK(); } llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { - llvm::Constant* result; + llvm::Constant* initializer = + llvm_ir::ConvertLiteralToIrConstant(literal, module_); + llvm::GlobalVariable* result_global = new llvm::GlobalVariable( + /*Module=*/*module_, + /*Type=*/initializer->getType(), + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/initializer, + /*Name=*/""); + result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); + return llvm::ConstantExpr::getBitCast( + result_global, IrShapeType(literal.shape())->getPointerTo()); +} + +Status IrEmitter::EmitConstantGlobals() { + for (const BufferAllocation& allocation : assignment_.Allocations()) { + if (!allocation.is_constant()) { + continue; + } - // We avoid creating large constants in the LLVM IR since LLVM is not - // efficient for large constant arrays. We still emit "small enough" constant - // arrays into the Ir, in the off chance the LLVM optimizer can do something - // interesting with it. - // - // TODO(b/29904935): Remove the large constant pool. - const int kMaxInternalConstantSizeInBytes = 128; - if (external_constant_pool_ && - ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) { - string global_name = tensorflow::strings::StrCat( - "constant_global_", external_global_constant_counter_++); - llvm::GlobalVariable* result_global = new llvm::GlobalVariable( - /*Module=*/*module_, - /*Type=*/IrShapeType(literal.shape()), - /*isConstant=*/true, - /*Linkage=*/llvm::GlobalValue::ExternalLinkage, - /*Initializer=*/nullptr, - /*Name=*/AsStringRef(global_name)); - result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); - external_constant_pool_->Insert(global_name, literal, - MinimumAlignmentForShape(literal.shape())); - result = result_global; - } else { - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, module_); - llvm::GlobalVariable* result_global = new llvm::GlobalVariable( - /*Module=*/*module_, - /*Type=*/initializer->getType(), - /*isConstant=*/true, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/initializer, - /*Name=*/""); - result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); - result = llvm::ConstantExpr::getBitCast( - result_global, IrShapeType(literal.shape())->getPointerTo()); - } - return result; + const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation); + llvm::Constant* global_for_const; + auto it = emitted_literals_.find(&literal); + if (it != emitted_literals_.end()) { + global_for_const = it->second; + } else { + global_for_const = EmitGlobalForLiteral(literal); + InsertOrDie(&emitted_literals_, &literal, global_for_const); + } + + InsertOrDie(&constant_buffer_to_global_, allocation.index(), + global_for_const); + } + + return Status::OK(); } Status IrEmitter::HandleConstant(HloInstruction* constant) { VLOG(2) << "HandleConstant: " << constant->ToString(); - const Literal& literal = constant->literal(); - llvm::Constant* global_for_const; - - auto it = emitted_literals_.find(&literal); - if (it != emitted_literals_.end()) { - global_for_const = it->second; - } else { - global_for_const = EmitGlobalForLiteral(literal); - emitted_literals_[&literal] = global_for_const; - } - emitted_value_[constant] = global_for_const; - VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*global_for_const); - VLOG(2) << " its type: " - << llvm_ir::DumpToString(*global_for_const->getType()); - return Status::OK(); + // IrEmitter::EmitConstantGlobals has already taken care of emitting the body + // of the constant. + return EmitTargetAddressForOp(constant); } Status IrEmitter::HandleCopy(HloInstruction* copy) { @@ -226,10 +226,13 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { // kCopy shallow copies a tuple so just memcpy the top-level buffer. TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy)); return EmitMemcpy(*(copy->operand(0)), *copy); - } else { - // Use the elemental emitter for non-tuple shapes. + } else if (ShapeUtil::IsArray(copy->shape())) { + // Use the elemental emitter for array shapes. return DefaultAction(copy); } + return Unimplemented( + "unsupported operand type %s for copy instruction", + PrimitiveType_Name(copy->shape().element_type()).c_str()); } // Calculate the alignment of a buffer allocated for a given primitive type. @@ -297,51 +300,66 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { const Shape& shape = get_tuple_element->shape(); emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement( shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape), - GetEmittedValueFor(operand), &ir_builder_, module_); + GetEmittedValueFor(operand), &b_, module_); return Status::OK(); } Status IrEmitter::HandleSelect(HloInstruction* select) { auto pred = select->operand(0); - auto on_true = select->operand(1); - auto on_false = select->operand(2); TF_RET_CHECK(pred->shape().element_type() == PRED); - - if (ShapeUtil::IsTuple(select->shape())) { - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(select)); - llvm_ir::EmitTupleSelect( - GetIrArrayFor(select), GetIrArrayFor(pred), GetEmittedValueFor(on_true), - GetEmittedValueFor(on_false), &ir_builder_, module_); - return Status::OK(); - } - return DefaultAction(select); } -Status IrEmitter::HandleInfeed(HloInstruction* infeed) { - VLOG(2) << "HandleInfeed: " << infeed->ToString(); +Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) { + auto pred = tuple_select->operand(0); + auto on_true = tuple_select->operand(1); + auto on_false = tuple_select->operand(2); + TF_RET_CHECK(pred->shape().element_type() == PRED); + TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape())); + TF_RET_CHECK(ShapeUtil::IsTuple(tuple_select->shape())); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select)); + llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred), + GetEmittedValueFor(on_true), + GetEmittedValueFor(on_false), &b_, module_); + return Status::OK(); +} - const Shape& shape = infeed->shape(); +Status IrEmitter::HandleInfeed(HloInstruction* instruction) { + HloInfeedInstruction* infeed = Cast(instruction); + VLOG(2) << "HandleInfeed: " << infeed->ToString(); - // The infeed operation produces data (dequeued from the infeed queue) at this - // address, which has been provided by buffer assignment. + // The infeed operation produces a two-element tuple containing data and a + // token value. HloInfeedInstruction::infeed_shape gives us the data shape. + const Shape& data_shape = infeed->infeed_shape(); + DCHECK(ShapeUtil::Equal(data_shape, + ShapeUtil::GetTupleElementShape(infeed->shape(), 0))); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed)); - llvm_ir::IrArray infeed_array = GetIrArrayFor(infeed); - if (ShapeUtil::IsTuple(shape)) { - TF_RET_CHECK(!ShapeUtil::IsNestedTuple(shape)); + // Write the tuple index table. + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice, + assignment_.GetUniqueSlice(infeed, {0})); + llvm::Value* data_address = EmitTempBufferPointer(data_slice, data_shape); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice, + assignment_.GetUniqueSlice(infeed, {1})); + llvm::Value* token_address = EmitTempBufferPointer( + token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1)); + llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_, + module_); + + if (ShapeUtil::IsTuple(data_shape)) { + TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape)); // For a tuple, we first copy each of the internal elements to // their corresponding target locations. We then construct the // tuple outer buffer containing pointers to the internal // elements. std::vector tuple_element_addresses; - for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) { + for (int64 i = 0; i < data_shape.tuple_shapes_size(); ++i) { TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer, - assignment_.GetUniqueSlice(infeed, {i})); + assignment_.GetUniqueSlice(infeed, {0, i})); const Shape& tuple_element_shape = - ShapeUtil::GetTupleElementShape(shape, i); + ShapeUtil::GetTupleElementShape(data_shape, i); // Only the outer tuple buffer's target address is obtained from // GetEmittedValueFor, to handle the case when Infeed is the root @@ -356,11 +374,11 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { tuple_element_addresses.push_back(tuple_element_address); } - llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_, - module_); + llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape), + tuple_element_addresses, &b_, module_); } else { - TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, shape, - GetEmittedValueFor(infeed))); + TF_RETURN_IF_ERROR( + EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address)); } return Status::OK(); @@ -378,14 +396,14 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, int32 length_32 = static_cast(length); int32 shape_length; - TF_ASSIGN_OR_RETURN(llvm::Value * shape_ptr, - llvm_ir::EncodeSelfDescribingShapeConstant( - shape, &shape_length, &ir_builder_)); + TF_ASSIGN_OR_RETURN( + llvm::Value * shape_ptr, + llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_)); // The signature of the acquire infeed buffer function is: // // (void*)(int32 length); - llvm::Type* int32_type = ir_builder_.getInt32Ty(); + llvm::Type* int32_type = b_.getInt32Ty(); llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); llvm::FunctionType* acquire_type = llvm::FunctionType::get( i8_ptr_type, {int32_type, i8_ptr_type, int32_type}, @@ -405,8 +423,7 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, // // (void)(int32 length, void* buffer); llvm::FunctionType* release_type = llvm::FunctionType::get( - ir_builder_.getVoidTy(), - {int32_type, i8_ptr_type, i8_ptr_type, int32_type}, + b_.getVoidTy(), {int32_type, i8_ptr_type, i8_ptr_type, int32_type}, /*isVarArg=*/false); llvm::Function* release_func; @@ -423,25 +440,22 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, // of size exactly 'length_32', and the runtime is responsible for // check-failing the process if there is a mismatch, versus passing us back a // buffer that we might overrun. - llvm::Value* acquired_pointer = ir_builder_.CreateCall( - acquire_func, {ir_builder_.getInt32(length_32), shape_ptr, - ir_builder_.getInt32(shape_length)}); + llvm::Value* acquired_pointer = b_.CreateCall( + acquire_func, + {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)}); if (kind == XfeedKind::kInfeed) { // Copy to the program buffer address from the acquired buffer. - ir_builder_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, - acquired_pointer, - /*SrcAlign=*/1, length_32); + b_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer, + /*SrcAlign=*/1, length_32); } else { // Outfeed -- copy from the in-program address to the acquired buffer. - ir_builder_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, - program_buffer_address, - /*SrcAlign=*/1, length_32); + b_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address, + /*SrcAlign=*/1, length_32); } - ir_builder_.CreateCall(release_func, - {ir_builder_.getInt32(length_32), acquired_pointer, - shape_ptr, ir_builder_.getInt32(shape_length)}); + b_.CreateCall(release_func, {b_.getInt32(length_32), acquired_pointer, + shape_ptr, b_.getInt32(shape_length)}); return Status::OK(); } @@ -462,7 +476,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { ShapeUtil::GetTupleElementShape(operand_shape, i); llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement( tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape), - value, &ir_builder_, module_); + value, &b_, module_); TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed, tuple_element_shape, tuple_element)); } @@ -481,46 +495,96 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { for (auto operand : tuple->operands()) { base_ptrs.push_back(GetEmittedValueFor(operand)); } - llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &ir_builder_, module_); + llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_, module_); return Status::OK(); } -Status IrEmitter::HandleMap(HloInstruction* map) { - gtl::ArraySlice operands(map->operands()); - HloComputation* function = map->to_apply(); - // The called computation should have been emitted previously. - llvm::Function* mapped_ir_function = FindOrDie(emitted_functions_, function); +llvm::Value* IrEmitter::EmitElementalMap( + const HloMapInstruction& map_instr, + tensorflow::gtl::ArraySlice elemental_operands, + tensorflow::StringPiece name) { + return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name); +} + +StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( + HloReduceWindowInstruction* reduce_window, + const llvm_ir::IrArray::Index& index) { + const HloInstruction* operand = reduce_window->operand(0); + const Window& window = reduce_window->window(); + + // We fold inputs into the accumulator and initialize it to + // the initial value on the reduce_window. + PrimitiveType operand_element_type = operand->shape().element_type(); + llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), + "reduce_window_accumulator_address", &b_, + MinimumAlignmentForPrimitiveType(operand_element_type)); + b_.CreateStore(b_.CreateLoad(GetEmittedValueFor(reduce_window->operand(1))), + accumulator_address); - return EmitTargetElementLoop(map, [this, map, operands, mapped_ir_function]( - const llvm_ir::IrArray::Index& index) { - std::vector parameter_addresses; - for (const HloInstruction* operand : operands) { - const llvm_ir::IrArray& array = GetIrArrayFor(operand); - parameter_addresses.push_back( - array.EmitArrayElementAddress(index, &ir_builder_)); + llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_); + std::vector window_size; + for (const auto& dim : window.dimensions()) { + window_size.push_back(dim.size()); + } + const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape( + ShapeUtil::MakeShape(operand_element_type, window_size), "window"); + CHECK_EQ(window_index.size(), index.size()); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); + + llvm_ir::IrArray::Index input_index(b_.getInt64Ty(), index.size()); + llvm::Value* in_bounds_condition = nullptr; + for (size_t i = 0; i < index.size(); ++i) { + llvm::Value* strided_index = + b_.CreateNSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); + input_index[i] = + b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]), + b_.getInt64(window.dimensions(i).padding_low())); + + // We need to check if 0 <= input_index[i] < bound, as otherwise we are in + // the padding so that we can skip the computation. That is equivalent to + // input_index[i] < bound as an *unsigned* comparison, since a negative + // value will wrap to a large positive value. + llvm::Value* index_condition = b_.CreateICmpULT( + input_index[i], + b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + if (in_bounds_condition == nullptr) { + in_bounds_condition = index_condition; + } else { + in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); } - return EmitElementFunctionCall(mapped_ir_function, map->shape(), - parameter_addresses, "map_function"); - }); + } + CHECK(in_bounds_condition != nullptr); + + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); + SetToFirstInsertPoint(if_data.true_block, &b_); + + // We are not in the padding, so carry out the computation. + llvm_ir::IrArray input_array(GetIrArrayFor(operand)); + llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_); + llvm::Value* result = EmitThreadLocalCall( + *reduce_window->to_apply(), + {b_.CreateLoad(accumulator_address), input_value}, "reducer_function"); + b_.CreateStore(result, accumulator_address); + + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); + return b_.CreateLoad(accumulator_address); } Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { - auto operand = reduce_window->operand(0); - const Window& window = reduce_window->window(); - HloComputation* function = reduce_window->to_apply(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( - /*instruction=*/*reduce_window, /*operands=*/{operand}, - /*supported_types=*/{F32, BF16, S32})); + /*instruction=*/*reduce_window, + /*operands=*/{reduce_window->operand(0)}, + /*supported_types=*/{F32, BF16, S32, F16})); // TODO(b/31410564): Implement dilation for reduce-window. - if (window_util::HasDilation(window)) { + if (window_util::HasDilation(reduce_window->window())) { return Unimplemented( "Dilation for ReduceWindow is not implemented on CPU."); } - // The called computation should have been emitted previously. - llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); - // Pseudo code for reduce window: // // for (coordinates O in the output) @@ -535,72 +599,9 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { // This is completely un-optimized and just here to have something // that works. return EmitTargetElementLoop( - reduce_window, [this, reduce_window, operand, window, - reducer_function](const llvm_ir::IrArray::Index& index) { - // We fold inputs into the accumulator and initialize it to - // the initial value on the reduce_window. - PrimitiveType operand_element_type = operand->shape().element_type(); - llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), - "reduce_window_accumulator_address", &ir_builder_, - MinimumAlignmentForPrimitiveType(operand_element_type)); - ir_builder_.CreateStore(ir_builder_.CreateLoad(GetEmittedValueFor( - reduce_window->operand(1))), - accumulator_address); - - llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), - &ir_builder_); - std::vector window_size; - for (const auto& dim : window.dimensions()) { - window_size.push_back(dim.size()); - } - const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape( - ShapeUtil::MakeShape(operand_element_type, window_size), "window"); - CHECK_EQ(window_index.size(), index.size()); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); - - llvm_ir::IrArray::Index input_index(index.size()); - llvm::Value* in_bounds_condition = nullptr; - for (size_t i = 0; i < index.size(); ++i) { - llvm::Value* strided_index = ir_builder_.CreateNSWMul( - index[i], ir_builder_.getInt64(window.dimensions(i).stride())); - input_index[i] = ir_builder_.CreateNSWSub( - ir_builder_.CreateNSWAdd(strided_index, window_index[i]), - ir_builder_.getInt64(window.dimensions(i).padding_low())); - - // We need to check if 0 <= input_index[i] < bound, as - // otherwise we are in the padding so that we can skip the - // computation. That is equivalent to input_index[i] < bound - // as an *unsigned* comparison, since a negative value will - // wrap to a large positive value. - llvm::Value* index_condition = ir_builder_.CreateICmpULT( - input_index[i], ir_builder_.getInt64(ShapeUtil::GetDimension( - operand->shape(), i))); - if (in_bounds_condition == nullptr) { - in_bounds_condition = index_condition; - } else { - in_bounds_condition = - ir_builder_.CreateAnd(in_bounds_condition, index_condition); - } - } - CHECK(in_bounds_condition != nullptr); - - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - in_bounds_condition, "in-bounds", &ir_builder_); - SetToFirstInsertPoint(if_data.true_block, &ir_builder_); - - // We are not in the padding, so carry out the computation. - llvm_ir::IrArray input_array(GetIrArrayFor(operand)); - llvm::Value* input_value_address = - input_array.EmitArrayElementAddress(input_index, &ir_builder_); - llvm::Value* result = EmitElementFunctionCall( - reducer_function, reduce_window->shape(), - {accumulator_address, input_value_address}, "reducer_function"); - ir_builder_.CreateStore(result, accumulator_address); - - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); - return ir_builder_.CreateLoad(accumulator_address); + reduce_window, [&](const llvm_ir::IrArray::Index& index) { + return EmitTargetElementLoopBodyForReduceWindow( + Cast(reduce_window), index); }); } @@ -621,12 +622,6 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { "Dilation for SelectAndScatter is not implemented on CPU. "); } - // The select and scatter computations should have been emitted previously. - llvm::Function* select_function = - FindOrDie(emitted_functions_, select_and_scatter->select()); - llvm::Function* scatter_function = - FindOrDie(emitted_functions_, select_and_scatter->scatter()); - // Pseudo code for select-and-scatter: // // initialized_flag is initially off for every window, and is turned on after @@ -652,140 +647,128 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"), [this, init_value](const llvm_ir::IrArray::Index& target_index) { llvm::Value* init_value_addr = GetEmittedValueFor(init_value); - return ir_builder_.CreateLoad(init_value_addr); + return b_.CreateLoad(init_value_addr); })); // Create a loop to iterate over the source array to scatter to the output. - llvm_ir::ForLoopNest source_loops(IrName(select_and_scatter), &ir_builder_); + llvm_ir::ForLoopNest source_loops(IrName(select_and_scatter), &b_); const llvm_ir::IrArray::Index source_index = source_loops.AddLoopsForShape(source->shape(), "source"); - SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(), - &ir_builder_); + SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(), &b_); // Allocate space to keep the currently selected value, its index, and // the boolean initialized_flag, which is initially set to false. llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), - "selected_value_address", &ir_builder_, + "selected_value_address", &b_, MinimumAlignmentForPrimitiveType(operand_element_type)); llvm::Value* selected_index_address = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - ir_builder_.getInt64Ty(), ir_builder_.getInt32(rank), - "selected_index_address", &ir_builder_); + b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_); llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( - ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_); - ir_builder_.CreateStore(ir_builder_.getInt1(false), initialized_flag_address); + b_.getInt1Ty(), "initialized_flag_address", &b_); + b_.CreateStore(b_.getInt1(false), initialized_flag_address); // Create the inner loop to iterate over the window. - llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), - &ir_builder_); + llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_); std::vector window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); } const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape( ShapeUtil::MakeShape(operand_element_type, window_size), "window"); - SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), - &ir_builder_); + SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), &b_); // Compute the operand index to visit and evaluate the condition whether the // operand index is within the bounds. The unsigned comparison includes // checking whether the operand index >= 0. - llvm_ir::IrArray::Index operand_index(source_index.size()); - llvm::Value* in_bounds_condition = ir_builder_.getTrue(); + llvm_ir::IrArray::Index operand_index(b_.getInt64Ty(), source_index.size()); + llvm::Value* in_bounds_condition = b_.getTrue(); for (int64 i = 0; i < rank; ++i) { - llvm::Value* strided_index = ir_builder_.CreateNSWMul( - source_index[i], ir_builder_.getInt64(window.dimensions(i).stride())); - operand_index[i] = ir_builder_.CreateNSWSub( - ir_builder_.CreateNSWAdd(strided_index, window_index[i]), - ir_builder_.getInt64(window.dimensions(i).padding_low())); - llvm::Value* index_condition = ir_builder_.CreateICmpULT( + llvm::Value* strided_index = b_.CreateNSWMul( + source_index[i], b_.getInt64(window.dimensions(i).stride())); + operand_index[i] = + b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]), + b_.getInt64(window.dimensions(i).padding_low())); + llvm::Value* index_condition = b_.CreateICmpULT( operand_index[i], - ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); - in_bounds_condition = - ir_builder_.CreateAnd(in_bounds_condition, index_condition); + b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); } CHECK(in_bounds_condition != nullptr); // Only need to do something if the operand index is within the bounds. First // check if the initialized_flag is set. llvm_ir::LlvmIfData if_in_bounds = - llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_); - SetToFirstInsertPoint(if_in_bounds.true_block, &ir_builder_); - llvm_ir::LlvmIfData if_initialized = - llvm_ir::EmitIfThenElse(ir_builder_.CreateLoad(initialized_flag_address), - "initialized", &ir_builder_); + llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); + SetToFirstInsertPoint(if_in_bounds.true_block, &b_); + llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse( + b_.CreateLoad(initialized_flag_address), "initialized", &b_); // If the initialized_flag is false, initialize the selected value and index // with the currently visiting operand. - SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_); + SetToFirstInsertPoint(if_initialized.false_block, &b_); const auto save_operand_index = [&](const llvm_ir::IrArray::Index& operand_index) { for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - ir_builder_.CreateInBoundsGEP(selected_index_address, - {ir_builder_.getInt32(i)}); - ir_builder_.CreateStore(operand_index[i], - selected_index_address_slot); + b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); + b_.CreateStore(operand_index[i], selected_index_address_slot); } }; llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = - operand_array.EmitReadArrayElement(operand_index, &ir_builder_); - ir_builder_.CreateStore(operand_data, selected_value_address); + operand_array.EmitReadArrayElement(operand_index, &b_); + b_.CreateStore(operand_data, selected_value_address); save_operand_index(operand_index); - ir_builder_.CreateStore(ir_builder_.getInt1(true), initialized_flag_address); + b_.CreateStore(b_.getInt1(true), initialized_flag_address); // If the initialized_flag is true, call the `select` function to potentially // update the selected value and index with the currently visiting operand. - SetToFirstInsertPoint(if_initialized.true_block, &ir_builder_); - const Shape output_shape = ShapeUtil::MakeShape(PRED, {}); + SetToFirstInsertPoint(if_initialized.true_block, &b_); llvm::Value* operand_address = - operand_array.EmitArrayElementAddress(operand_index, &ir_builder_); - llvm::Value* result = EmitElementFunctionCall( - select_function, output_shape, {selected_value_address, operand_address}, + operand_array.EmitArrayElementAddress(operand_index, &b_); + llvm::Value* operand_element = b_.CreateLoad(operand_address); + llvm::Value* result = EmitThreadLocalCall( + *select_and_scatter->select(), + {b_.CreateLoad(selected_value_address), operand_element}, "select_function"); // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. - llvm::Value* cond = ir_builder_.CreateICmpNE( + llvm::Value* cond = b_.CreateICmpNE( result, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), "boolean_predicate"); llvm_ir::LlvmIfData if_select_lhs = - llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_); - SetToFirstInsertPoint(if_select_lhs.false_block, &ir_builder_); - ir_builder_.CreateStore(ir_builder_.CreateLoad(operand_address), - selected_value_address); + llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_); + SetToFirstInsertPoint(if_select_lhs.false_block, &b_); + b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address); save_operand_index(operand_index); // After iterating over the window elements, scatter the source element to // the selected index of the output. The value we store at the output // location is computed by calling the `scatter` function with the source // value and the current output value. - SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), - &ir_builder_); - llvm_ir::IrArray::Index selected_index; + SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &b_); + llvm_ir::IrArray::Index selected_index(source_index.GetType()); for (int64 i = 0; i < rank; ++i) { - llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP( - selected_index_address, {ir_builder_.getInt32(i)}); - selected_index.push_back( - ir_builder_.CreateLoad(selected_index_address_slot)); + llvm::Value* selected_index_address_slot = + b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); + selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); } llvm_ir::IrArray source_array(GetIrArrayFor(source)); - llvm::Value* source_value_address = - source_array.EmitArrayElementAddress(source_index, &ir_builder_); + llvm::Value* source_value = + source_array.EmitReadArrayElement(source_index, &b_); llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter)); - llvm::Value* output_value_address = - output_array.EmitArrayElementAddress(selected_index, &ir_builder_); - llvm::Value* scatter_value = EmitElementFunctionCall( - scatter_function, source->shape(), - {output_value_address, source_value_address}, "scatter_function"); - output_array.EmitWriteArrayElement(selected_index, scatter_value, - &ir_builder_); - - SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), - &ir_builder_); + llvm::Value* output_value = + output_array.EmitReadArrayElement(selected_index, &b_); + llvm::Value* scatter_value = + EmitThreadLocalCall(*select_and_scatter->scatter(), + {output_value, source_value}, "scatter_function"); + output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_); + + SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_); return Status::OK(); } @@ -824,21 +807,155 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Dot operation is complicated so we delegate to a helper class. return DotOpEmitter::EmitDotOperation( *dot, target_array, lhs_array, rhs_array, /*addend_array=*/nullptr, - GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_, + GetExecutableRunOptionsArgument(), &b_, hlo_module_config_, target_machine_features_); } +StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( + HloConvolutionInstruction* convolution, + const llvm_ir::IrArray::Index& index) { + const HloInstruction* lhs = convolution->operand(0); + const HloInstruction* rhs = convolution->operand(1); + const Window& window = convolution->window(); + + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); + int num_spatial_dims = dnums.output_spatial_dimensions_size(); + std::vector output_spatial(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + output_spatial[i] = index[dnums.output_spatial_dimensions(i)]; + } + llvm::Value* output_feature = index[dnums.output_feature_dimension()]; + llvm::Value* batch = index[dnums.output_batch_dimension()]; + + // We will accumulate the products into this sum to calculate the output entry + // at the given index. + PrimitiveType lhs_element_type = lhs->shape().element_type(); + llvm::Type* lhs_llvm_type = + llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_); + llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( + lhs_llvm_type, "convolution_sum_address", &b_, + MinimumAlignmentForPrimitiveType(lhs_element_type)); + llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type); + b_.CreateStore(constant_zero, sum_address); + + llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_); + std::vector kernel_spatial(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + kernel_spatial[i] = + loops + .AddLoop( + 0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)), + tensorflow::strings::StrCat("k", i)) + ->GetIndVarValue(); + } + llvm::Value* input_feature = + loops + .AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()), + "iz") + ->GetIndVarValue(); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); + + // Calculate the spatial index in the input array, taking striding, dilation + // and padding into account. An index in the padding will be out of the bounds + // of the array. + const auto calculate_input_index = [this](llvm::Value* output_index, + llvm::Value* kernel_index, + const WindowDimension& window_dim) { + llvm::Value* strided_index = + b_.CreateNSWMul(output_index, b_.getInt64(window_dim.stride())); + llvm::Value* dilated_kernel_index = b_.CreateNSWMul( + kernel_index, b_.getInt64(window_dim.window_dilation())); + return b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, dilated_kernel_index), + b_.getInt64(window_dim.padding_low())); + }; + std::vector input_spatial(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + input_spatial[i] = calculate_input_index( + output_spatial[i], kernel_spatial[i], window.dimensions(i)); + } + + // We need to check if 0 <= input dim < bound, as otherwise we are in the + // padding so that we can skip the computation. That is equivalent to input + // dim < bound as an *unsigned* comparison, since a negative value will wrap + // to a large positive value. The input dim is dilated, so we need to dilate + // the bound as well to match. + + // Also need to check that the input coordinates are not in one of the + // holes created by base dilation. + const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) { + llvm::Value* remainder = + b_.CreateSRem(input_index, b_.getInt64(base_dilation)); + return b_.CreateICmpEQ(remainder, b_.getInt64(0)); + }; + + llvm::Value* in_bounds_condition = b_.getInt1(true); + for (int i = 0; i < num_spatial_dims; ++i) { + llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound( + lhs->shape().dimensions(dnums.input_spatial_dimensions(i)), + window.dimensions(i).base_dilation())); + llvm::Value* dim_in_bound = b_.CreateICmpULT(input_spatial[i], input_bound); + llvm::Value* dim_not_in_hole = + not_in_hole(input_spatial[i], window.dimensions(i).base_dilation()); + llvm::Value* dim_ok = b_.CreateAnd(dim_in_bound, dim_not_in_hole); + in_bounds_condition = b_.CreateAnd(in_bounds_condition, dim_ok); + } + + // Now we need to map the dilated base coordinates back to the actual + // data indices on the lhs. + const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) { + return b_.CreateSDiv(input_index, b_.getInt64(base_dilation)); + }; + for (int i = 0; i < num_spatial_dims; ++i) { + input_spatial[i] = + undilate(input_spatial[i], window.dimensions(i).base_dilation()); + } + + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); + SetToFirstInsertPoint(if_data.true_block, &b_); + + // We are not in the padding, so carry out the computation. + int num_dims = num_spatial_dims + 2; + llvm_ir::IrArray::Index input_index(b_.getInt64Ty(), num_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i]; + } + input_index[dnums.input_feature_dimension()] = input_feature; + input_index[dnums.input_batch_dimension()] = batch; + + llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs)); + llvm_ir::IrArray::Index kernel_index(b_.getInt64Ty(), num_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + kernel_index[dnums.kernel_spatial_dimensions(i)] = + window.dimensions(i).window_reversal() + ? b_.CreateNSWSub(b_.getInt64(window.dimensions(i).size() - 1), + kernel_spatial[i]) + : kernel_spatial[i]; + } + + kernel_index[dnums.kernel_input_feature_dimension()] = input_feature; + kernel_index[dnums.kernel_output_feature_dimension()] = output_feature; + + llvm_ir::IrArray input_array(GetIrArrayFor(lhs)); + llvm::Value* product = + b_.CreateFMul(input_array.EmitReadArrayElement(input_index, &b_), + kernel_array.EmitReadArrayElement(kernel_index, &b_)); + llvm::Value* sum = b_.CreateFAdd(b_.CreateLoad(sum_address), product); + b_.CreateStore(sum, sum_address); + + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); + return b_.CreateLoad(sum_address); +} + Status IrEmitter::HandleConvolution(HloInstruction* convolution) { auto lhs = convolution->operand(0); auto rhs = convolution->operand(1); - const auto& window = convolution->window(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*convolution, /*operands=*/{lhs, rhs}, /*supported_types=*/{F16, F32, C64})); - const ConvolutionDimensionNumbers& dnums = - convolution->convolution_dimension_numbers(); - // TODO(tonywy): Add PotentiallyImplementedAsMKLCovolution to support // different data layouts. if (PotentiallyImplementedAsEigenConvolution(*convolution, @@ -918,12 +1035,12 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { PrimitiveType primitive_type = lhs->shape().element_type(); llvm::Type* ir_ptr_type = primitive_type == F16 - ? ir_builder_.getHalfTy()->getPointerTo() - : ir_builder_.getFloatTy()->getPointerTo(); - llvm::Type* int64_type = ir_builder_.getInt64Ty(); - llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo(); + ? b_.getHalfTy()->getPointerTo() + : b_.getFloatTy()->getPointerTo(); + llvm::Type* int64_type = b_.getInt64Ty(); + llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo(); llvm::FunctionType* conv_type = llvm::FunctionType::get( - ir_builder_.getVoidTy(), + b_.getVoidTy(), {int8_ptr_type, ir_ptr_type, ir_ptr_type, ir_ptr_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, @@ -955,34 +1072,34 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { conv_func->setCallingConv(llvm::CallingConv::C); conv_func->setDoesNotThrow(); conv_func->setOnlyAccessesArgMemory(); - ir_builder_.CreateCall( - conv_func, { - GetExecutableRunOptionsArgument(), - ir_builder_.CreateBitCast( - GetEmittedValueFor(convolution), ir_ptr_type), - ir_builder_.CreateBitCast(lhs_address, ir_ptr_type), - ir_builder_.CreateBitCast(rhs_address, ir_ptr_type), - ir_builder_.getInt64(input_batch), - ir_builder_.getInt64(input_rows), - ir_builder_.getInt64(input_cols), - ir_builder_.getInt64(input_channels), - ir_builder_.getInt64(kernel_rows), - ir_builder_.getInt64(kernel_cols), - ir_builder_.getInt64(kernel_channels), - ir_builder_.getInt64(kernel_filters), - ir_builder_.getInt64(output_rows), - ir_builder_.getInt64(output_cols), - ir_builder_.getInt64(row_stride), - ir_builder_.getInt64(col_stride), - ir_builder_.getInt64(padding_top), - ir_builder_.getInt64(padding_bottom), - ir_builder_.getInt64(padding_left), - ir_builder_.getInt64(padding_right), - ir_builder_.getInt64(lhs_row_dilation), - ir_builder_.getInt64(lhs_col_dilation), - ir_builder_.getInt64(rhs_row_dilation), - ir_builder_.getInt64(rhs_col_dilation), - }); + b_.CreateCall( + conv_func, + { + GetExecutableRunOptionsArgument(), + b_.CreateBitCast(GetEmittedValueFor(convolution), ir_ptr_type), + b_.CreateBitCast(lhs_address, ir_ptr_type), + b_.CreateBitCast(rhs_address, ir_ptr_type), + b_.getInt64(input_batch), + b_.getInt64(input_rows), + b_.getInt64(input_cols), + b_.getInt64(input_channels), + b_.getInt64(kernel_rows), + b_.getInt64(kernel_cols), + b_.getInt64(kernel_channels), + b_.getInt64(kernel_filters), + b_.getInt64(output_rows), + b_.getInt64(output_cols), + b_.getInt64(row_stride), + b_.getInt64(col_stride), + b_.getInt64(padding_top), + b_.getInt64(padding_bottom), + b_.getInt64(padding_left), + b_.getInt64(padding_right), + b_.getInt64(lhs_row_dilation), + b_.getInt64(lhs_col_dilation), + b_.getInt64(rhs_row_dilation), + b_.getInt64(rhs_col_dilation), + }); return Status::OK(); } @@ -995,149 +1112,9 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // See the description of convolution in the XLA documentation for the pseudo // code for convolution. return EmitTargetElementLoop( - convolution, [this, convolution, lhs, rhs, window, - dnums](const llvm_ir::IrArray::Index& index) { - int num_spatial_dims = dnums.output_spatial_dimensions_size(); - std::vector output_spatial(num_spatial_dims); - for (int i = 0; i < num_spatial_dims; ++i) { - output_spatial[i] = index[dnums.output_spatial_dimensions(i)]; - } - llvm::Value* output_feature = index[dnums.output_feature_dimension()]; - llvm::Value* batch = index[dnums.output_batch_dimension()]; - - // We will accumulate the products into this sum to calculate - // the output entry at the given index. - PrimitiveType lhs_element_type = lhs->shape().element_type(); - llvm::Type* lhs_llvm_type = - llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_); - llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( - lhs_llvm_type, "convolution_sum_address", &ir_builder_, - MinimumAlignmentForPrimitiveType(lhs_element_type)); - llvm::Value* constant_zero = - llvm::Constant::getNullValue(lhs_llvm_type); - ir_builder_.CreateStore(constant_zero, sum_address); - - llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &ir_builder_); - std::vector kernel_spatial(num_spatial_dims); - for (int i = 0; i < num_spatial_dims; ++i) { - kernel_spatial[i] = - loops - .AddLoop(0, - rhs->shape().dimensions( - dnums.kernel_spatial_dimensions(i)), - tensorflow::strings::StrCat("k", i)) - ->GetIndVarValue(); - } - llvm::Value* input_feature = - loops - .AddLoop( - 0, lhs->shape().dimensions(dnums.input_feature_dimension()), - "iz") - ->GetIndVarValue(); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); - - // Calculate the spatial index in the input array, taking striding, - // dilation and padding into account. An index in the padding will be - // out of the bounds of the array. - const auto calculate_input_index = - [this](llvm::Value* output_index, llvm::Value* kernel_index, - const WindowDimension& window_dim) { - llvm::Value* strided_index = ir_builder_.CreateNSWMul( - output_index, ir_builder_.getInt64(window_dim.stride())); - llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul( - kernel_index, - ir_builder_.getInt64(window_dim.window_dilation())); - return ir_builder_.CreateNSWSub( - ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index), - ir_builder_.getInt64(window_dim.padding_low())); - }; - std::vector input_spatial(num_spatial_dims); - for (int i = 0; i < num_spatial_dims; ++i) { - input_spatial[i] = calculate_input_index( - output_spatial[i], kernel_spatial[i], window.dimensions(i)); - } - - // We need to check if 0 <= input dim < bound, as otherwise we are in - // the padding so that we can skip the computation. That is equivalent - // to input dim < bound as an *unsigned* comparison, since a negative - // value will wrap to a large positive value. The input dim is dilated, - // so we need to dilate the bound as well to match. - - // Also need to check that the input coordinates are not in one of the - // holes created by base dilation. - const auto not_in_hole = [&](llvm::Value* input_index, - int64 base_dilation) { - llvm::Value* remainder = ir_builder_.CreateSRem( - input_index, ir_builder_.getInt64(base_dilation)); - return ir_builder_.CreateICmpEQ(remainder, ir_builder_.getInt64(0)); - }; - - llvm::Value* in_bounds_condition = ir_builder_.getInt1(true); - for (int i = 0; i < num_spatial_dims; ++i) { - llvm::ConstantInt* input_bound = - ir_builder_.getInt64(window_util::DilatedBound( - lhs->shape().dimensions(dnums.input_spatial_dimensions(i)), - window.dimensions(i).base_dilation())); - llvm::Value* dim_in_bound = - ir_builder_.CreateICmpULT(input_spatial[i], input_bound); - llvm::Value* dim_not_in_hole = not_in_hole( - input_spatial[i], window.dimensions(i).base_dilation()); - llvm::Value* dim_ok = - ir_builder_.CreateAnd(dim_in_bound, dim_not_in_hole); - in_bounds_condition = - ir_builder_.CreateAnd(in_bounds_condition, dim_ok); - } - - // Now we need to map the dilated base coordinates back to the actual - // data indices on the lhs. - const auto undilate = [&](llvm::Value* input_index, - int64 base_dilation) { - return ir_builder_.CreateSDiv(input_index, - ir_builder_.getInt64(base_dilation)); - }; - for (int i = 0; i < num_spatial_dims; ++i) { - input_spatial[i] = - undilate(input_spatial[i], window.dimensions(i).base_dilation()); - } - - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - in_bounds_condition, "in-bounds", &ir_builder_); - SetToFirstInsertPoint(if_data.true_block, &ir_builder_); - - // We are not in the padding, so carry out the computation. - int num_dims = num_spatial_dims + 2; - llvm_ir::IrArray::Index input_index(num_dims); - for (int i = 0; i < num_spatial_dims; ++i) { - input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i]; - } - input_index[dnums.input_feature_dimension()] = input_feature; - input_index[dnums.input_batch_dimension()] = batch; - - llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs)); - llvm_ir::IrArray::Index kernel_index(num_dims); - for (int i = 0; i < num_spatial_dims; ++i) { - kernel_index[dnums.kernel_spatial_dimensions(i)] = - window.dimensions(i).window_reversal() - ? ir_builder_.CreateNSWSub( - ir_builder_.getInt64(window.dimensions(i).size() - 1), - kernel_spatial[i]) - : kernel_spatial[i]; - } - - kernel_index[dnums.kernel_input_feature_dimension()] = input_feature; - kernel_index[dnums.kernel_output_feature_dimension()] = output_feature; - - llvm_ir::IrArray input_array(GetIrArrayFor(lhs)); - llvm::Value* product = ir_builder_.CreateFMul( - input_array.EmitReadArrayElement(input_index, &ir_builder_), - kernel_array.EmitReadArrayElement(kernel_index, &ir_builder_)); - llvm::Value* sum = ir_builder_.CreateFAdd( - ir_builder_.CreateLoad(sum_address), product); - ir_builder_.CreateStore(sum, sum_address); - - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); - return ir_builder_.CreateLoad(sum_address); + convolution, [&](const llvm_ir::IrArray::Index& index) { + return EmitTargetElementLoopBodyForConvolution( + Cast(convolution), index); }); } @@ -1161,11 +1138,11 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { } // Args have been computed, make the call. - llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo(); - llvm::Type* int32_type = ir_builder_.getInt32Ty(); - llvm::Type* int64_type = ir_builder_.getInt64Ty(); + llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo(); + llvm::Type* int32_type = b_.getInt32Ty(); + llvm::Type* int64_type = b_.getInt64Ty(); llvm::FunctionType* fft_type = llvm::FunctionType::get( - ir_builder_.getVoidTy(), + b_.getVoidTy(), {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type, int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); @@ -1182,16 +1159,15 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { fft_func->setDoesNotThrow(); fft_func->setOnlyAccessesInaccessibleMemOrArgMem(); const int fft_rank = fft_length.size(); - ir_builder_.CreateCall( + b_.CreateCall( fft_func, {GetExecutableRunOptionsArgument(), - ir_builder_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type), - ir_builder_.CreateBitCast(operand_address, int8_ptr_type), - ir_builder_.getInt32(fft->fft_type()), ir_builder_.getInt32(fft_rank), - ir_builder_.getInt64(input_batch), - ir_builder_.getInt64(fft_rank > 0 ? fft_length[0] : 0), - ir_builder_.getInt64(fft_rank > 1 ? fft_length[1] : 0), - ir_builder_.getInt64(fft_rank > 2 ? fft_length[2] : 0)}); + b_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type), + b_.CreateBitCast(operand_address, int8_ptr_type), + b_.getInt32(fft->fft_type()), b_.getInt32(fft_rank), + b_.getInt64(input_batch), b_.getInt64(fft_rank > 0 ? fft_length[0] : 0), + b_.getInt64(fft_rank > 1 ? fft_length[1] : 0), + b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)}); return Status::OK(); } @@ -1230,11 +1206,10 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape)); // TODO(b/63762267): Be more aggressive about specifying alignment. - ir_builder_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, - /*SrcAlign=*/1, - ShapeUtil::ByteSizeOf(operand_shape)); + b_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, + /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape)); } - llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &ir_builder_, module_); + llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_, module_); return Status::OK(); } @@ -1267,47 +1242,7 @@ static llvm_ir::IrArray::Index FillReducedDimensionIndex( Status IrEmitter::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); - auto param_number = parameter->parameter_number(); - auto param_shape = parameter->shape(); - - // We have to access the parameter at offset param_number in the params - // array. The code generated here is equivalent to this C code: - // - // i8* param_address_untyped = params[param_number]; - // Param* param_address_typed = (Param*)param_address_untyped; - // - // Where Param is the actual element type of the underlying buffer (for - // example, float for an XLA F32 element type). - llvm::Value* params = compute_function_->parameters_arg(); - llvm::Value* param_address_offset = - llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_); - llvm::LoadInst* param_address_untyped = - ir_builder_.CreateLoad(param_address_offset); - param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped"))); - if (is_top_level_computation_ && - hlo_module_config_.debug_options() - .xla_llvm_enable_invariant_load_metadata()) { - // In the entry computation the parameter slots in the %params argument are - // invariant through program execution. In computations that are called - // from the entry computation (via kWhile, kCall and kConditional) the - // parameter slots are *not* invariant since they're written to by their - // callers. - param_address_untyped->setMetadata( - llvm::LLVMContext::MD_invariant_load, - llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{})); - } - - llvm::Value* param_address_typed = ir_builder_.CreateBitCast( - param_address_untyped, IrShapeType(param_shape)->getPointerTo()); - emitted_value_[parameter] = param_address_typed; - - if (!ShapeUtil::IsOpaque(param_shape)) { - AttachAlignmentMetadataForLoad(param_address_untyped, param_shape); - AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape); - } - - VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*param_address_typed); - return Status::OK(); + return EmitTargetAddressForOp(parameter); } // Returns true if the relative order of the unreduced dimensions stays the same @@ -1405,58 +1340,61 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( return nullptr; case HloOpcode::kAdd: - return [root_is_integral](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, + return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { - return root_is_integral ? ir_builder->CreateAdd(lhs, rhs) - : ir_builder->CreateFAdd(lhs, rhs); + return root_is_integral ? b->CreateAdd(lhs, rhs) + : b->CreateFAdd(lhs, rhs); }; case HloOpcode::kMultiply: - return [root_is_integral](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, + return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { - return root_is_integral ? ir_builder->CreateMul(lhs, rhs) - : ir_builder->CreateFMul(lhs, rhs); + return root_is_integral ? b->CreateMul(lhs, rhs) + : b->CreateFMul(lhs, rhs); }; case HloOpcode::kAnd: - return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, - llvm::Value* rhs) { return ir_builder->CreateAnd(lhs, rhs); }; + return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { + return b->CreateAnd(lhs, rhs); + }; case HloOpcode::kOr: - return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, - llvm::Value* rhs) { return ir_builder->CreateOr(lhs, rhs); }; + return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { + return b->CreateOr(lhs, rhs); + }; + + case HloOpcode::kXor: + return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { + return b->CreateXor(lhs, rhs); + }; case HloOpcode::kMaximum: return [root_is_floating_point, root_is_signed]( - llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, - llvm::Value* rhs) { + llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { if (root_is_floating_point) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum, - {lhs, rhs}, {lhs->getType()}, - ir_builder); + {lhs, rhs}, {lhs->getType()}, b); } - return ir_builder->CreateSelect( - ir_builder->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SGE - : llvm::ICmpInst::ICMP_UGE, - lhs, rhs), + return b->CreateSelect( + b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SGE + : llvm::ICmpInst::ICMP_UGE, + lhs, rhs), lhs, rhs); }; case HloOpcode::kMinimum: return [root_is_floating_point, root_is_signed]( - llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, - llvm::Value* rhs) { + llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { if (root_is_floating_point) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum, - {lhs, rhs}, {lhs->getType()}, - ir_builder); + {lhs, rhs}, {lhs->getType()}, b); } - return ir_builder->CreateSelect( - ir_builder->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SLE - : llvm::ICmpInst::ICMP_ULE, - lhs, rhs), + return b->CreateSelect( + b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SLE + : llvm::ICmpInst::ICMP_ULE, + lhs, rhs), lhs, rhs); }; } @@ -1525,34 +1463,31 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( accumulator.reserve(accumulator_type.size()); for (auto accumulator_shard_type : accumulator_type) { accumulator.push_back(llvm_ir::EmitAllocaAtFunctionEntry( - accumulator_shard_type, "accumulator", &ir_builder_, 0)); + accumulator_shard_type, "accumulator", &b_, 0)); } - llvm::Value* init_value_ssa = - ir_builder_.CreateLoad(GetEmittedValueFor(init_value)); + llvm::Value* init_value_ssa = b_.CreateLoad(GetEmittedValueFor(init_value)); for (llvm::Value* accumulator_shard : accumulator) { llvm::Value* initial_value; auto shard_type = accumulator_shard->getType()->getPointerElementType(); if (auto vector_type = llvm::dyn_cast(shard_type)) { - initial_value = ir_builder_.CreateVectorSplat( - vector_type->getNumElements(), init_value_ssa); + initial_value = + b_.CreateVectorSplat(vector_type->getNumElements(), init_value_ssa); } else { initial_value = init_value_ssa; } - ir_builder_.CreateAlignedStore(initial_value, accumulator_shard, - element_alignment); + b_.CreateAlignedStore(initial_value, accumulator_shard, element_alignment); } llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"), - &ir_builder_); + &b_); llvm_ir::IrArray::Index reduced_dims_index = reduction_loop_nest.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, "reduction_dim"); - SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), - &ir_builder_); + SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), &b_); llvm_ir::IrArray arg_array(GetIrArrayFor(arg)); llvm_ir::IrArray::Index input_index = reduced_dims_index; @@ -1565,38 +1500,34 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( } CHECK(output_index.end() == it); - llvm::Value* input_address = ir_builder_.CreateBitCast( - arg_array.EmitArrayElementAddress(input_index, &ir_builder_), - ir_builder_.getInt8PtrTy()); + llvm::Value* input_address = b_.CreateBitCast( + arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy()); for (int i = 0; i < accumulator.size(); i++) { auto input_address_typed = - ir_builder_.CreateBitCast(input_address, accumulator[i]->getType()); + b_.CreateBitCast(input_address, accumulator[i]->getType()); auto current_accumulator_value = - ir_builder_.CreateAlignedLoad(accumulator[i], element_alignment); - auto addend = - ir_builder_.CreateAlignedLoad(input_address_typed, element_alignment); + b_.CreateAlignedLoad(accumulator[i], element_alignment); + auto addend = b_.CreateAlignedLoad(input_address_typed, element_alignment); arg_array.AnnotateLoadStoreInstructionWithMetadata(addend); auto reduced_result = - reduction_generator(&ir_builder_, current_accumulator_value, addend); - ir_builder_.CreateAlignedStore(reduced_result, accumulator[i], - element_alignment); + reduction_generator(&b_, current_accumulator_value, addend); + b_.CreateAlignedStore(reduced_result, accumulator[i], element_alignment); if (i != (accumulator.size() - 1)) { - input_address = ir_builder_.CreateConstInBoundsGEP1_32( - reduced_result->getType(), input_address_typed, 1); + input_address = b_.CreateConstInBoundsGEP1_32(reduced_result->getType(), + input_address_typed, 1); } } - SetToFirstInsertPoint(reduction_loop_nest.GetOuterLoopExitBasicBlock(), - &ir_builder_); + SetToFirstInsertPoint(reduction_loop_nest.GetOuterLoopExitBasicBlock(), &b_); ShardedVector result_ssa; result_ssa.reserve(accumulator.size()); for (auto accumulator_shard : accumulator) { result_ssa.push_back( - ir_builder_.CreateAlignedLoad(accumulator_shard, element_alignment)); + b_.CreateAlignedLoad(accumulator_shard, element_alignment)); } return result_ssa; } @@ -1605,17 +1536,17 @@ void IrEmitter::EmitShardedVectorStore( llvm::Value* store_address, const std::vector& value_to_store, const int alignment, const llvm_ir::IrArray& containing_array) { for (int i = 0; i < value_to_store.size(); i++) { - auto store_address_typed = ir_builder_.CreateBitCast( + auto store_address_typed = b_.CreateBitCast( store_address, llvm::PointerType::getUnqual(value_to_store[i]->getType())); - auto store_instruction = ir_builder_.CreateAlignedStore( + auto store_instruction = b_.CreateAlignedStore( value_to_store[i], store_address_typed, alignment); containing_array.AnnotateLoadStoreInstructionWithMetadata( store_instruction); if (i != (value_to_store.size() - 1)) { - store_address = ir_builder_.CreateConstInBoundsGEP1_32( + store_address = b_.CreateConstInBoundsGEP1_32( value_to_store[i]->getType(), store_address_typed, 1); } } @@ -1681,8 +1612,9 @@ StatusOr IrEmitter::EmitVectorizedReduce( // } // } - llvm_ir::ForLoopNest loop_nest(IrName(reduce), &ir_builder_); - llvm_ir::IrArray::Index array_index(reduce->shape().dimensions_size()); + llvm_ir::ForLoopNest loop_nest(IrName(reduce), &b_); + llvm_ir::IrArray::Index array_index(b_.getInt64Ty(), + reduce->shape().dimensions_size()); for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0; --i) { int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i); @@ -1700,7 +1632,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( if (llvm::BasicBlock* innermost_body_bb = loop_nest.GetInnerLoopBodyBasicBlock()) { - SetToFirstInsertPoint(innermost_body_bb, &ir_builder_); + SetToFirstInsertPoint(innermost_body_bb, &b_); } auto outermost_loop_exit_block = loop_nest.GetOuterLoopExitBasicBlock(); @@ -1714,7 +1646,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( tensorflow::strings::Printf("dim.%lld", innermost_dimension)); array_index[innermost_dimension] = loop->GetIndVarValue(); - SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &ir_builder_); + SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_); ShardedVectorType vector_type = CreateShardedVectorType( reduce->shape().element_type(), vectorization_factor); @@ -1725,16 +1657,16 @@ StatusOr IrEmitter::EmitVectorizedReduce( llvm_ir::IrArray target_array = GetIrArrayFor(reduce); llvm::Value* output_address = - target_array.EmitArrayElementAddress(array_index, &ir_builder_); + target_array.EmitArrayElementAddress(array_index, &b_); EmitShardedVectorStore(output_address, accumulator, element_alignment, target_array); if (auto exit_terminator = loop->GetExitBasicBlock()->getTerminator()) { CHECK_GT(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1); - ir_builder_.SetInsertPoint(exit_terminator); + b_.SetInsertPoint(exit_terminator); } else { CHECK_EQ(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1); - ir_builder_.SetInsertPoint(loop->GetExitBasicBlock()); + b_.SetInsertPoint(loop->GetExitBasicBlock()); } } @@ -1744,8 +1676,8 @@ StatusOr IrEmitter::EmitVectorizedReduce( if (innermost_dimension_size % vectorization_factor) { // TODO(b/63775531): Consider using a scalar loop here to save on code size. array_index[innermost_dimension] = - ir_builder_.getInt64(innermost_dimension_size - - (innermost_dimension_size % vectorization_factor)); + b_.getInt64(innermost_dimension_size - + (innermost_dimension_size % vectorization_factor)); ShardedVectorType vector_type = CreateShardedVectorType( reduce->shape().element_type(), @@ -1757,19 +1689,77 @@ StatusOr IrEmitter::EmitVectorizedReduce( llvm_ir::IrArray target_array = GetIrArrayFor(reduce); llvm::Value* output_address = - target_array.EmitArrayElementAddress(array_index, &ir_builder_); + target_array.EmitArrayElementAddress(array_index, &b_); EmitShardedVectorStore(output_address, accumulator, element_alignment, target_array); } if (outermost_loop_exit_block) { - ir_builder_.SetInsertPoint(outermost_loop_exit_block); + b_.SetInsertPoint(outermost_loop_exit_block); } return true; } +StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( + HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index) { + const HloInstruction* arg = reduce->mutable_operand(0); + const HloInstruction* init_value = reduce->mutable_operand(1); + gtl::ArraySlice dimensions(reduce->dimensions()); + + // Initialize an accumulator with init_value. + PrimitiveType accumulator_type = reduce->shape().element_type(); + llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator", + &b_, MinimumAlignmentForPrimitiveType(accumulator_type)); + llvm::Value* init_value_addr = GetEmittedValueFor(init_value); + llvm::Value* load_init_value = b_.CreateLoad(init_value_addr); + b_.CreateStore(load_init_value, accumulator_addr); + + // The enclosing loops go over all the target elements. Now we have to compute + // the actual target element. For this, we build a new loop nest to iterate + // over all the reduction dimensions in the argument. + // AddLoopsForShapeOnDimensions will return an Index where induction Value*s + // are placed for each dimension in dimensions, and all the rest are nullptrs. + llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_); + const llvm_ir::IrArray::Index reduced_dims_index = + loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, + "reduction_dim"); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); + + // Build a full index for the input argument, using reduced_dims_index as the + // base. In reduced_dims_index only the reduction dimensions are filled in. We + // fill in the rest of the dimensions with induction Value*s taken from + // 'index' which iterates over the target array. See the high-level + // description in the XLA documentation for details. + llvm_ir::IrArray arg_array(GetIrArrayFor(arg)); + llvm_ir::IrArray::Index input_index = reduced_dims_index; + llvm_ir::IrArray::Index::const_iterator it = index.begin(); + + for (size_t i = 0; i < input_index.size(); ++i) { + if (input_index[i] == nullptr) { + input_index[i] = *it++; + } + } + CHECK(index.end() == it); + + // Apply the reduction function to the loaded value. + llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_); + llvm::Value* result = EmitThreadLocalCall( + *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element}, + "reduce_function"); + b_.CreateStore(result, accumulator_addr); + + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); + return b_.CreateLoad(accumulator_addr); +} + Status IrEmitter::HandleReduce(HloInstruction* reduce) { + // TODO(b/112040122): Support variadic reduce. + if (!ShapeUtil::IsArray(reduce->shape())) { + return Unimplemented("Variadic reduce is not supported on CPU"); + } auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); gtl::ArraySlice dimensions(reduce->dimensions()); @@ -1790,61 +1780,11 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { } } - // The called computation should have been emitted previously. - llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); - return EmitTargetElementLoop( - reduce, [this, reduce, arg, init_value, dimensions, - reducer_function](const llvm_ir::IrArray::Index& index) { - // Initialize an accumulator with init_value. - PrimitiveType accumulator_type = reduce->shape().element_type(); - llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), - "accumulator", &ir_builder_, - MinimumAlignmentForPrimitiveType(accumulator_type)); - llvm::Value* init_value_addr = GetEmittedValueFor(init_value); - llvm::Value* load_init_value = ir_builder_.CreateLoad(init_value_addr); - ir_builder_.CreateStore(load_init_value, accumulator_addr); - - // The enclosing loops go over all the target elements. Now we have to - // compute the actual target element. For this, we build a new loop nest - // to iterate over all the reduction dimensions in the argument. - // AddLoopsForShapeOnDimensions will return an Index where induction - // Value*s are placed for each dimension in dimensions, and all the rest - // are nullptrs. - llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &ir_builder_); - const llvm_ir::IrArray::Index reduced_dims_index = - loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, - "reduction_dim"); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); - - // Build a full index for the input argument, using reduced_dims_index - // as the base. In reduced_dims_index only the reduction dimensions are - // filled in. We fill in the rest of the dimensions with induction - // Value*s taken from 'index' which iterates over the target array. - // See the high-level description in the XLA documentation for details. - llvm_ir::IrArray arg_array(GetIrArrayFor(arg)); - llvm_ir::IrArray::Index input_index = reduced_dims_index; - llvm_ir::IrArray::Index::const_iterator it = index.begin(); - - for (size_t i = 0; i < input_index.size(); ++i) { - if (input_index[i] == nullptr) { - input_index[i] = *it++; - } - } - CHECK(index.end() == it); - - // Apply the reduction function to the loaded value. - llvm::Value* input_address = - arg_array.EmitArrayElementAddress(input_index, &ir_builder_); - llvm::Value* result = EmitElementFunctionCall( - reducer_function, reduce->shape(), - {accumulator_addr, input_address}, "reduce_function"); - ir_builder_.CreateStore(result, accumulator_addr); - - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); - return ir_builder_.CreateLoad(accumulator_addr); - }); + return EmitTargetElementLoop(reduce, + [&](const llvm_ir::IrArray::Index& index) { + return EmitTargetElementLoopBodyForReduce( + Cast(reduce), index); + }); } Status IrEmitter::HandleSend(HloInstruction* send) { @@ -1857,6 +1797,10 @@ Status IrEmitter::HandleSendDone(HloInstruction* send_done) { return Unimplemented("Send-done is not implemented on CPU."); } +Status IrEmitter::HandleScatter(HloInstruction*) { + return Unimplemented("Scatter is not implemented on CPUs."); +} + Status IrEmitter::HandleSlice(HloInstruction* slice) { VLOG(2) << "HandleSlice: " << slice->ToString(); auto operand = slice->operand(0); @@ -1873,7 +1817,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice)); - if (ShapeUtil::HasZeroElements(slice->shape())) { + if (ShapeUtil::IsZeroElementArray(slice->shape())) { return Status::OK(); } @@ -1946,7 +1890,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { llvm_ir::IrArray target_array = GetIrArrayFor(slice); const int64 num_outer_loops = outer_dims.size(); - llvm_ir::ForLoopNest loops(IrName(slice), &ir_builder_); + llvm_ir::ForLoopNest loops(IrName(slice), &b_); llvm_ir::IrArray::Index target_index = loops.AddLoopsForShapeOnDimensions(slice->shape(), outer_dims, "slice"); @@ -1955,21 +1899,21 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { // for the rest of the dimensions the copy writes to the full dimension. std::replace(target_index.begin(), target_index.end(), static_cast(nullptr), - static_cast(ir_builder_.getInt64(0))); + static_cast(b_.getInt64(0))); if (num_outer_loops > 0) { - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); } llvm_ir::IrArray source_array = GetIrArrayFor(operand); const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice( /*shape=*/slice->shape(), /*starts=*/slice->slice_starts(), - /*strides=*/slice->slice_strides(), /*builder=*/&ir_builder_); + /*strides=*/slice->slice_strides(), /*builder=*/&b_); - llvm::Value* memcpy_dest = target_array.EmitArrayElementAddress( - target_index, &ir_builder_, "slice.dest"); - llvm::Value* memcpy_source = source_array.EmitArrayElementAddress( - source_index, &ir_builder_, "slice.source"); + llvm::Value* memcpy_dest = + target_array.EmitArrayElementAddress(target_index, &b_, "slice.dest"); + llvm::Value* memcpy_source = + source_array.EmitArrayElementAddress(source_index, &b_, "slice.source"); const int64 memcpy_elements = primitive_elements_per_logical_element * memcpy_logical_elements; @@ -1986,7 +1930,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { } if (num_outer_loops > 0) { - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); } return Status::OK(); @@ -2012,7 +1956,7 @@ Status IrEmitter::HandleDynamicUpdateSlice( auto operands = GetIrArraysForOperandsOf(dynamic_update_slice); return llvm_ir::EmitDynamicUpdateSliceInPlace( operands, GetIrArrayFor(dynamic_update_slice), - IrName(dynamic_update_slice, "in_place"), &ir_builder_); + IrName(dynamic_update_slice, "in_place"), &b_); } return DefaultAction(dynamic_update_slice); } @@ -2046,43 +1990,41 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { [this, pad](const llvm_ir::IrArray::Index& target_index) { const HloInstruction* padding_value = pad->operand(1); llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value); - return ir_builder_.CreateLoad(padding_value_addr); + return b_.CreateLoad(padding_value_addr); })); // Create a loop to iterate over the operand elements and update the output // locations where the operand elements should be stored. - llvm_ir::ForLoopNest loops(IrName(pad, "assign"), &ir_builder_); + llvm_ir::ForLoopNest loops(IrName(pad, "assign"), &b_); const HloInstruction* operand = pad->operand(0); const llvm_ir::IrArray::Index operand_index = loops.AddLoopsForShape(operand->shape(), "operand"); - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); // Load an element from the operand. llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = - operand_array.EmitReadArrayElement(operand_index, &ir_builder_); + operand_array.EmitReadArrayElement(operand_index, &b_); // Compute the output index the operand element should be assigned to. // output_index := edge_padding_low + operand_index * (interior_padding + 1) const PaddingConfig& padding_config = pad->padding_config(); - llvm_ir::IrArray::Index output_index; + llvm_ir::IrArray::Index output_index(operand_index.GetType()); for (size_t i = 0; i < operand_index.size(); ++i) { - llvm::Value* offset = ir_builder_.CreateMul( + llvm::Value* offset = b_.CreateMul( operand_index[i], - ir_builder_.getInt64(padding_config.dimensions(i).interior_padding() + - 1)); - llvm::Value* index = ir_builder_.CreateAdd( - offset, - ir_builder_.getInt64(padding_config.dimensions(i).edge_padding_low())); + b_.getInt64(padding_config.dimensions(i).interior_padding() + 1)); + llvm::Value* index = b_.CreateAdd( + offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low())); output_index.push_back(index); } // Store the operand element to the computed output location. llvm_ir::IrArray output_array(GetIrArrayFor(pad)); - output_array.EmitWriteArrayElement(output_index, operand_data, &ir_builder_); + output_array.EmitWriteArrayElement(output_index, operand_data, &b_); - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); return Status::OK(); } @@ -2104,8 +2046,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { // Delegate to common implementation of fused in-place dynamic-update-slice. auto operands = GetIrArraysForOperandsOf(fusion); return llvm_ir::EmitFusedDynamicUpdateSliceInPlace( - fusion, operands, GetIrArrayFor(fusion), &elemental_emitter, - &ir_builder_); + fusion, operands, GetIrArrayFor(fusion), &elemental_emitter, &b_); } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) { VLOG(3) << "HandleFusion kLoop"; CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); @@ -2140,7 +2081,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( *dot, target_array, lhs_array, rhs_array, &addend_array, - GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_, + GetExecutableRunOptionsArgument(), &b_, hlo_module_config_, target_machine_features_)); return Status::OK(); } else { @@ -2152,18 +2093,13 @@ Status IrEmitter::HandleCall(HloInstruction* call) { HloComputation* computation = call->to_apply(); llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation); - std::vector parameter_addresses; - for (const HloInstruction* operand : call->operands()) { - parameter_addresses.push_back(GetEmittedValueFor(operand)); - } - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call)); if (!computation->root_instruction()->outer_dimension_partitions().empty()) { // ParallelTaskAssignment assigned partitions, emit call to // ParallelForkJoin. std::vector call_args = GetArrayFunctionCallArguments( - parameter_addresses, &ir_builder_, computation->name(), + {}, &b_, computation->name(), /*return_value_buffer=*/emitted_value_[call], /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), /*temp_buffers_arg=*/GetTempBuffersArgument(), @@ -2171,11 +2107,10 @@ Status IrEmitter::HandleCall(HloInstruction* call) { HloInstruction* root = computation->root_instruction(); TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin( - call_args, root->shape(), root->outer_dimension_partitions(), - &ir_builder_, call_ir_function, computation->name())); + call_args, root->shape(), root->outer_dimension_partitions(), &b_, + call_ir_function, computation->name())); } else { - EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, - emitted_value_[call], computation->name()); + EmitGlobalCall(*computation, computation->name()); } return Status::OK(); @@ -2184,33 +2119,31 @@ Status IrEmitter::HandleCall(HloInstruction* call) { Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { gtl::ArraySlice operands(custom_call->operands()); tensorflow::StringPiece custom_call_target(custom_call->custom_call_target()); - llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy(); + llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - i8_ptr_type, ir_builder_.getInt32(operands.size()), - "cc_operands_alloca", &ir_builder_); + i8_ptr_type, b_.getInt32(operands.size()), "cc_operands_alloca", &b_); for (size_t i = 0; i < operands.size(); ++i) { const HloInstruction* operand = operands[i]; llvm::Value* operand_as_i8ptr = - ir_builder_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type); - llvm::Value* slot_in_operands_alloca = ir_builder_.CreateInBoundsGEP( - operands_alloca, {ir_builder_.getInt64(i)}); - ir_builder_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca); + b_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type); + llvm::Value* slot_in_operands_alloca = + b_.CreateInBoundsGEP(operands_alloca, {b_.getInt64(i)}); + b_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca); } auto* custom_call_ir_function = llvm::cast(module_->getOrInsertFunction( AsStringRef(custom_call_target), llvm::FunctionType::get( - /*Result=*/ir_builder_.getVoidTy(), + /*Result=*/b_.getVoidTy(), /*Params=*/{i8_ptr_type, operands_alloca->getType()}, /*isVarArg=*/false))); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); - auto* output_address_arg = ir_builder_.CreatePointerCast( - GetEmittedValueFor(custom_call), i8_ptr_type); + auto* output_address_arg = + b_.CreatePointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); - ir_builder_.CreateCall(custom_call_ir_function, - {output_address_arg, operands_alloca}); + b_.CreateCall(custom_call_ir_function, {output_address_arg, operands_alloca}); return Status::OK(); } @@ -2258,12 +2191,6 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { const HloInstruction* init = xla_while->operand(0); emitted_value_[xla_while] = GetEmittedValueFor(init); - // The called computation should have been emitted previously. - llvm::Function* condition_ir_function = - FindOrDie(emitted_functions_, condition); - llvm::Function* body_ir_function = - FindOrDie(emitted_functions_, xla_while->while_body()); - // Generating: // while (Condition(while_result)) { // // CopyInsertion pass inserts copies which enable 'while_result' to @@ -2275,17 +2202,15 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "header")), compute_function_->function()); - ir_builder_.CreateBr(header_bb); - ir_builder_.SetInsertPoint(header_bb); + b_.CreateBr(header_bb); + b_.SetInsertPoint(header_bb); // Calls the condition function to determine whether to proceed with the // body. It must return a bool, so use the scalar call form. - llvm::Value* while_result = GetEmittedValueFor(xla_while); - llvm::Value* while_condition = EmitElementFunctionCall( - condition_ir_function, condition->root_instruction()->shape(), - {while_result}, IrName(xla_while, "cond")); - llvm::Value* while_predicate = ir_builder_.CreateICmpNE( - while_condition, + EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond")); + llvm::Value* while_predicate = b_.CreateICmpNE( + b_.CreateLoad( + GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); // Branches to the body or to the while exit depending on the condition. @@ -2294,20 +2219,20 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { compute_function_->function()); llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "exit"))); - ir_builder_.CreateCondBr(while_predicate, body_bb, exit_bb); + b_.CreateCondBr(while_predicate, body_bb, exit_bb); // Calls the body function from the body block. - ir_builder_.SetInsertPoint(body_bb); + b_.SetInsertPoint(body_bb); // Calls the body function. - EmitArrayFunctionCallInto(body_ir_function, {while_result}, while_result, - IrName(xla_while, "body")); + EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body")); + // Finishes with a branch back to the header. - ir_builder_.CreateBr(header_bb); + b_.CreateBr(header_bb); // Adds the exit block to the function and sets the insert point there. compute_function_->function()->getBasicBlockList().push_back(exit_bb); - ir_builder_.SetInsertPoint(exit_bb); + b_.SetInsertPoint(exit_bb); return Status::OK(); } @@ -2349,21 +2274,21 @@ StatusOr IrEmitter::EmitFastConcatenate( std::vector outer_dims(std::next(concat_dim_layout_itr), output_min2maj.end()); - llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy(); - llvm::Type* i8_type = ir_builder_.getInt8Ty(); + llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); + llvm::Type* i8_type = b_.getInt8Ty(); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate)); llvm_ir::IrArray target_array = GetIrArrayFor(concatenate); - llvm_ir::ForLoopNest loops(IrName(concatenate), &ir_builder_); + llvm_ir::ForLoopNest loops(IrName(concatenate), &b_); llvm_ir::IrArray::Index outer_dims_index = loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat"); std::replace(outer_dims_index.begin(), outer_dims_index.end(), static_cast(nullptr), - static_cast(ir_builder_.getInt64(0))); + static_cast(b_.getInt64(0))); if (!outer_dims.empty()) { - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); } PrimitiveType primitive_type = output_shape.element_type(); @@ -2372,10 +2297,10 @@ StatusOr IrEmitter::EmitFastConcatenate( // Contiguous subregions from each operand to the concatenate contribute to a // contiguous subregion in the target buffer starting at target_region_begin. - llvm::Value* target_region_begin = ir_builder_.CreateBitCast( - target_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_, - "target_region"), - i8_ptr_type); + llvm::Value* target_region_begin = + b_.CreateBitCast(target_array.EmitArrayElementAddress( + outer_dims_index, &b_, "target_region"), + i8_ptr_type); int64 byte_offset_into_target_region = 0; int64 inner_dims_product = @@ -2389,14 +2314,13 @@ StatusOr IrEmitter::EmitFastConcatenate( for (HloInstruction* operand : operands) { const Shape& input_shape = operand->shape(); llvm_ir::IrArray source_array = GetIrArrayFor(operand); - llvm::Value* copy_source_address = ir_builder_.CreateBitCast( - source_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_, - "src_addr"), + llvm::Value* copy_source_address = b_.CreateBitCast( + source_array.EmitArrayElementAddress(outer_dims_index, &b_, "src_addr"), i8_ptr_type); - llvm::Value* copy_target_address = ir_builder_.CreateGEP( - i8_type, target_region_begin, - ir_builder_.getInt64(byte_offset_into_target_region)); + llvm::Value* copy_target_address = + b_.CreateGEP(i8_type, target_region_begin, + b_.getInt64(byte_offset_into_target_region)); EmitTransferElements( copy_target_address, copy_source_address, @@ -2409,7 +2333,7 @@ StatusOr IrEmitter::EmitFastConcatenate( } if (!outer_dims.empty()) { - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); } return true; @@ -2428,16 +2352,15 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, llvm_ir::PrimitiveTypeToIrType(primitive_type, module_)); if (element_count == 1) { - auto* load_instruction = ir_builder_.CreateAlignedLoad( - ir_builder_.CreateBitCast(source, primitive_ptr_type), - element_alignment); + auto* load_instruction = b_.CreateAlignedLoad( + b_.CreateBitCast(source, primitive_ptr_type), element_alignment); source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction); - auto* store_instruction = ir_builder_.CreateAlignedStore( - load_instruction, ir_builder_.CreateBitCast(target, primitive_ptr_type), + auto* store_instruction = b_.CreateAlignedStore( + load_instruction, b_.CreateBitCast(target, primitive_ptr_type), element_alignment); target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction); } else { - auto* memcpy_instruction = ir_builder_.CreateMemCpy( + auto* memcpy_instruction = b_.CreateMemCpy( target, /*DstAlign=*/element_alignment, source, /*SrcAlign=*/element_alignment, element_count * primitive_type_size); @@ -2471,8 +2394,6 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { Status IrEmitter::HandleConditional(HloInstruction* conditional) { auto pred = conditional->operand(0); - auto true_arg = conditional->operand(1); - auto false_arg = conditional->operand(2); TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) && pred->shape().element_type() == PRED) << "Predicate on a Conditional must be bool; got: " @@ -2494,37 +2415,60 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { << " and " << ShapeUtil::HumanString(false_computation->root_instruction()->shape()); - llvm::Function* true_function = - FindOrDie(emitted_functions_, true_computation); - llvm::Function* false_function = - FindOrDie(emitted_functions_, false_computation); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional)); - llvm::Value* conditional_result = GetEmittedValueFor(conditional); // Generating: // if (pred) // cond_result = true_computation(true_operand) // else // cond_result = false_computation(false_operand) - llvm::LoadInst* pred_value = ir_builder_.CreateLoad( + llvm::LoadInst* pred_value = b_.CreateLoad( GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); - llvm::Value* pred_cond = ir_builder_.CreateICmpNE( + llvm::Value* pred_cond = b_.CreateICmpNE( pred_value, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), "boolean_predicate"); llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(pred_cond, "conditional", &ir_builder_); + llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_); + + SetToFirstInsertPoint(if_data.true_block, &b_); + EmitGlobalCall(*conditional->true_computation(), + IrName(conditional, "_true")); - SetToFirstInsertPoint(if_data.true_block, &ir_builder_); - EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)}, - conditional_result, IrName(conditional, "_true")); + SetToFirstInsertPoint(if_data.false_block, &b_); + EmitGlobalCall(*conditional->false_computation(), + IrName(conditional, "_false")); - SetToFirstInsertPoint(if_data.false_block, &ir_builder_); - EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)}, - conditional_result, IrName(conditional, "_false")); + SetToFirstInsertPoint(if_data.after_block, &b_); + return Status::OK(); +} + +Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) { + TF_RET_CHECK(ByteSizeOf(gen_token->shape()) == 0); + // No code to generate, but we need to emit an address for book-keeping. + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(gen_token)); + return Status::OK(); +} + +Status IrEmitter::HandleIota(HloInstruction* iota) { + // TODO(b/64798317): implement iota on CPU. + return Unimplemented("Iota is not implemented on CPU."); +} + +Status IrEmitter::HandleRng(HloInstruction* rng) { + ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; + for (const HloInstruction* operand : rng->operands()) { + operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { + return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_); + }; + } + + CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); + TF_RETURN_IF_ERROR(EmitTargetElementLoop( + rng, elemental_emitter.MakeElementGenerator(rng, operand_to_generator))); + + llvm_ir::IncrementVariableForPhiloxRngState(1, module_, &b_); - SetToFirstInsertPoint(if_data.after_block, &ir_builder_); return Status::OK(); } @@ -2545,7 +2489,7 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { auto record_complete_computation = [&](llvm::Value* prof_counter) { if (prof_counter) { - profiling_state_.RecordCompleteComputation(&ir_builder_, prof_counter); + profiling_state_.RecordCompleteComputation(&b_, prof_counter); } }; @@ -2567,54 +2511,51 @@ llvm::Value* IrEmitter::GetProfileCounterCommon( int64 prof_counter_idx = it->second; string counter_name = IrName("prof_counter", hlo.name()); - return ir_builder_.CreateGEP(GetProfileCountersArgument(), - ir_builder_.getInt64(prof_counter_idx), - AsStringRef(counter_name)); + return b_.CreateGEP(GetProfileCountersArgument(), + b_.getInt64(prof_counter_idx), AsStringRef(counter_name)); } -void IrEmitter::ProfilingState::UpdateProfileCounter( - llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter, - llvm::Value* cycle_end, llvm::Value* cycle_start) { - auto* cycle_diff = ir_builder->CreateSub(cycle_end, cycle_start); +void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b, + llvm::Value* prof_counter, + llvm::Value* cycle_end, + llvm::Value* cycle_start) { + auto* cycle_diff = b->CreateSub(cycle_end, cycle_start); llvm::LoadInst* old_cycle_count = - ir_builder->CreateLoad(prof_counter, "old_cycle_count"); + b->CreateLoad(prof_counter, "old_cycle_count"); auto* new_cycle_count = - ir_builder->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count"); - ir_builder->CreateStore(new_cycle_count, prof_counter); + b->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count"); + b->CreateStore(new_cycle_count, prof_counter); } -llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter( - llvm::IRBuilder<>* ir_builder) { - llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); +llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter(llvm::IRBuilder<>* b) { + llvm::Module* module = b->GetInsertBlock()->getModule(); if (use_rdtscp_) { llvm::Function* func_llvm_readcyclecounter = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::readcyclecounter); - return ir_builder->CreateCall(func_llvm_readcyclecounter); + return b->CreateCall(func_llvm_readcyclecounter); } llvm::Function* func_llvm_x86_rdtscp = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::x86_rdtscp); if (!aux_i8ptr_) { - llvm::AllocaInst* rdtscp_aux = llvm_ir::EmitAllocaAtFunctionEntry( - ir_builder->getInt32Ty(), "rdtscp_aux", ir_builder); - aux_i8ptr_ = - ir_builder->CreateBitCast(rdtscp_aux, ir_builder->getInt8PtrTy()); + llvm::AllocaInst* rdtscp_aux = + llvm_ir::EmitAllocaAtFunctionEntry(b->getInt32Ty(), "rdtscp_aux", b); + aux_i8ptr_ = b->CreateBitCast(rdtscp_aux, b->getInt8PtrTy()); } - llvm::ConstantInt* alloca_size = ir_builder->getInt64(4); + llvm::ConstantInt* alloca_size = b->getInt64(4); llvm::Function* func_llvm_lifetime_start = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_start); - ir_builder->CreateCall(func_llvm_lifetime_start, {alloca_size, aux_i8ptr_}); - llvm::Value* rdtscp_call = - ir_builder->CreateCall(func_llvm_x86_rdtscp, aux_i8ptr_); + b->CreateCall(func_llvm_lifetime_start, {alloca_size, aux_i8ptr_}); + llvm::Value* rdtscp_call = b->CreateCall(func_llvm_x86_rdtscp, aux_i8ptr_); llvm::Function* func_llvm_lifetime_end = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_end); - ir_builder->CreateCall(func_llvm_lifetime_end, {alloca_size, aux_i8ptr_}); + b->CreateCall(func_llvm_lifetime_end, {alloca_size, aux_i8ptr_}); return rdtscp_call; } -void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* ir_builder, +void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo) { - auto* cycle_start = ReadCycleCounter(ir_builder); + auto* cycle_start = ReadCycleCounter(b); cycle_start->setName(AsStringRef(IrName(hlo, "cycle_start"))); cycle_starts_[hlo] = cycle_start; if (first_read_cycle_start_ == nullptr) { @@ -2622,20 +2563,20 @@ void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* ir_builder, } } -void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* ir_builder, +void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* b, HloInstruction* hlo, llvm::Value* prof_counter) { - auto* cycle_end = ReadCycleCounter(ir_builder); + auto* cycle_end = ReadCycleCounter(b); cycle_end->setName(AsStringRef(IrName(hlo, "cycle_end"))); auto* cycle_start = cycle_starts_[hlo]; - UpdateProfileCounter(ir_builder, prof_counter, cycle_end, cycle_start); + UpdateProfileCounter(b, prof_counter, cycle_end, cycle_start); last_read_cycle_end_ = cycle_end; } void IrEmitter::ProfilingState::RecordCompleteComputation( - llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter) { + llvm::IRBuilder<>* b, llvm::Value* prof_counter) { if (last_read_cycle_end_ && first_read_cycle_start_) { - UpdateProfileCounter(ir_builder, prof_counter, last_read_cycle_end_, + UpdateProfileCounter(b, prof_counter, last_read_cycle_end_, first_read_cycle_start_); } } @@ -2643,14 +2584,14 @@ void IrEmitter::ProfilingState::RecordCompleteComputation( Status IrEmitter::Preprocess(HloInstruction* hlo) { VLOG(3) << "Visiting: " << hlo->ToString(); if (instruction_to_profile_idx_.count(hlo)) { - profiling_state_.RecordCycleStart(&ir_builder_, hlo); + profiling_state_.RecordCycleStart(&b_, hlo); } return Status::OK(); } Status IrEmitter::Postprocess(HloInstruction* hlo) { if (auto* prof_counter = GetProfileCounterFor(*hlo)) { - profiling_state_.RecordCycleDelta(&ir_builder_, hlo, prof_counter); + profiling_state_.RecordCycleDelta(&b_, hlo, prof_counter); } return Status::OK(); } @@ -2697,42 +2638,76 @@ llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { return compute_function_->exec_run_options_arg(); } -llvm::Value* IrEmitter::EmitTempBufferPointer( +llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape) { - llvm::Type* element_type = IrShapeType(target_shape); - // The alignment and number of bytes within the temporary buffer is determined - // by the maximal shape as determined by buffer assignment. - const BufferAllocation& allocation = assignment_.GetAllocation(slice.index()); - if (allocation.is_thread_local()) { + const BufferAllocation& allocation = *slice.allocation(); + llvm::Value* tempbuf_address = [&]() -> llvm::Value* { + if (slice == computation_root_allocation_) { + llvm::Argument* retval = compute_function_->result_arg(); + llvm::AttrBuilder attr_builder; + attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); + attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); + retval->addAttrs(attr_builder); + return retval; + } + + auto param_it = + computation_parameter_allocations_.find(slice.allocation()->index()); + if (param_it != computation_parameter_allocations_.end()) { + int64 param_number = param_it->second; + // We have to access the parameter at offset param_number in the params + // array. The code generated here is equivalent to this C code: + // + // i8* param_address_untyped = params[param_number]; + // Param* param_address_typed = (Param*)param_address_untyped; + // + // Where Param is the actual element type of the underlying buffer (for + // example, float for an XLA F32 element type). + llvm::Value* params = compute_function_->parameters_arg(); + llvm::Value* param_address_offset = + llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); + llvm::LoadInst* param_address_untyped = + b_.CreateLoad(param_address_offset); + + if (!ShapeUtil::IsOpaque(target_shape)) { + AttachAlignmentMetadataForLoad(param_address_untyped, target_shape); + AttachDereferenceableMetadataForLoad(param_address_untyped, + target_shape); + } + return param_address_untyped; + } + // Thread-local allocations should only be assigned a single buffer. const auto& assigned_buffers = allocation.assigned_buffers(); CHECK_EQ(1, assigned_buffers.size()); const Shape& shape = assigned_buffers.begin()->first->shape(); - llvm::AllocaInst*& tempbuf_address = thread_local_buffers_[{ - ir_builder_.GetInsertBlock()->getParent(), slice}]; - if (tempbuf_address == nullptr) { - tempbuf_address = llvm_ir::EmitAllocaAtFunctionEntry( + std::pair key = { + compute_function_->function(), slice}; + auto buf_it = thread_local_buffers_.find(key); + if (buf_it == thread_local_buffers_.end()) { + llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry( IrShapeType(shape), - tensorflow::strings::StrCat("thread_local", slice.ToString()), - &ir_builder_, MinimumAlignmentForShape(target_shape)); + tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_, + MinimumAlignmentForShape(target_shape)); + auto it_inserted_pair = thread_local_buffers_.insert({key, buffer}); + CHECK(it_inserted_pair.second); + buf_it = it_inserted_pair.first; } - return ir_builder_.CreateBitCast(tempbuf_address, - element_type->getPointerTo()); - } + return buf_it->second; + }(); + return b_.CreateBitCast(tempbuf_address, + IrShapeType(target_shape)->getPointerTo()); +} +llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( + const BufferAllocation::Slice& slice, const Shape& target_shape) { + const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP( - GetTempBuffersArgument(), slice.index(), &ir_builder_); - llvm::LoadInst* tempbuf_address_base = - ir_builder_.CreateLoad(tempbuf_address_ptr); - if (is_top_level_computation_ && - hlo_module_config_.debug_options() + GetTempBuffersArgument(), slice.index(), &b_); + llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr); + if (hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { - // In the entry computation the parameter slots in the %params argument are - // invariant through program execution. In computations that are called - // from the entry computation (via kWhile, kCall and kConditional) the - // parameter slots are *not* invariant since they're written to by their - // callers. tempbuf_address_base->setMetadata( llvm::LLVMContext::MD_invariant_load, llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); @@ -2743,87 +2718,29 @@ llvm::Value* IrEmitter::EmitTempBufferPointer( llvm::Value* tempbuf_address_untyped = tempbuf_address_base; if (slice.offset() > 0) { // Adjust the address to account for the slice offset. - tempbuf_address_untyped = ir_builder_.CreateInBoundsGEP( - tempbuf_address_base, ir_builder_.getInt64(slice.offset())); + tempbuf_address_untyped = + b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset())); } - return ir_builder_.CreateBitCast(tempbuf_address_untyped, - element_type->getPointerTo()); + return b_.CreateBitCast(tempbuf_address_untyped, + IrShapeType(target_shape)->getPointerTo()); } -// Emits a function call returning a single array element. Allocates space -// for a single element_type value, and loads it after call. -llvm::Value* IrEmitter::EmitElementFunctionCall( - llvm::Function* function, const Shape& return_shape, - gtl::ArraySlice parameter_addresses, - tensorflow::StringPiece name) { - llvm::Value* return_value_buffer = EmitArrayFunctionCall( - function, return_shape, 1, parameter_addresses, name); - return ir_builder_.CreateLoad( - return_value_buffer, - AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); -} - -// Emits a core function call based on the following pseudo-code. -// -// char** parameter_addresses_buffer = -// allocate buffer with a pointer for each parameter to the function -// for each parameter index, i.e. for i = 0, ..., #parameters: -// parameter_addresses_buffer[i] = parameter_addresses[i] -// call function(return_value_buffer, -// parameter_addresses_buffer, -// temps) -// return return_value_buffer -- address of the return value. -void IrEmitter::EmitArrayFunctionCallInto( - llvm::Function* function, gtl::ArraySlice parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name) { - ir_builder_.CreateCall( - function, GetArrayFunctionCallArguments( - parameter_addresses, &ir_builder_, name, - /*return_value_buffer=*/return_value_buffer, - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), - /*profile_counters_arg=*/GetProfileCountersArgument())); -} - -llvm::Value* IrEmitter::EmitArrayFunctionCall( - llvm::Function* function, const Shape& return_shape, int64 element_count, - gtl::ArraySlice parameter_addresses, - tensorflow::StringPiece name) { - llvm::Value* elements = - llvm::ConstantInt::get(ir_builder_.getInt64Ty(), element_count); - PrimitiveType return_type = return_shape.element_type(); - llvm::Value* return_value_buffer = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements, - tensorflow::strings::StrCat(name, "_return_value_address"), - &ir_builder_, MinimumAlignmentForPrimitiveType(return_type)); - EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer, - name); - return return_value_buffer; +llvm::Value* IrEmitter::EmitTempBufferPointer( + const BufferAllocation::Slice& slice, const Shape& target_shape) { + if (slice.allocation()->is_thread_local()) { + return EmitThreadLocalTempBufferPointer(slice, target_shape); + } else if (slice.allocation()->is_constant()) { + return FindOrDie(constant_buffer_to_global_, slice.allocation()->index()); + } else { + return EmitGlobalTempBufferPointer(slice, target_shape); + } } Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { - llvm::Value* addr; const Shape& target_shape = op->shape(); - if (op == op->parent()->root_instruction()) { - // For the root node, we write directly to the output buffer of the - // function. - llvm::Argument* retval = compute_function_->result_arg(); - if (!ShapeUtil::IsNil(target_shape)) { - llvm::AttrBuilder attr_builder; - attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); - attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); - retval->addAttrs(attr_builder); - } - addr = ir_builder_.CreateBitCast(retval, - IrShapeType(target_shape)->getPointerTo()); - } else { - // For other nodes, we need the temporary buffer allocated for this node to - // write the result into. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - assignment_.GetUniqueTopLevelSlice(op)); - addr = EmitTempBufferPointer(slice, target_shape); - } + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + assignment_.GetUniqueTopLevelSlice(op)); + llvm::Value* addr = EmitTempBufferPointer(slice, target_shape); addr->setName(AsStringRef(IrName(op))); emitted_value_[op] = addr; return Status::OK(); @@ -2858,14 +2775,14 @@ Status IrEmitter::EmitTargetElementLoop( llvm_ir::IrArray(op_target_address, element_shape)); } TF_RETURN_IF_ERROR( - llvm_ir::LoopEmitter(element_generator, output_arrays, &ir_builder_) + llvm_ir::LoopEmitter(element_generator, output_arrays, &b_) .EmitLoop(IrName(target_op))); std::vector tuple_operand_ptrs; for (int64 i = 0; i < output_arrays.size(); ++i) { tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } - llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &ir_builder_, module_); + llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_, module_); } else { if (ShouldEmitParallelLoopFor(*target_op)) { @@ -2874,11 +2791,11 @@ Status IrEmitter::EmitTargetElementLoop( compute_function_->GetDynamicLoopBounds(); // Emit parallel loop with dynamic loop bounds for most-major dimensions. TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array, - &dynamic_loop_bounds, &ir_builder_) + &dynamic_loop_bounds, &b_) .EmitLoop(IrName(target_op))); } else { TF_RETURN_IF_ERROR( - llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_) + llvm_ir::LoopEmitter(element_generator, target_array, &b_) .EmitLoop(IrName(target_op))); } } @@ -2891,8 +2808,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source, llvm::Value* destination_value = GetEmittedValueFor(&destination); int64 source_size = ByteSizeOf(source.shape()); // TODO(b/63762267): Be more aggressive about specifying alignment. - ir_builder_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value, - /*SrcAlign=*/1, source_size); + b_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value, + /*SrcAlign=*/1, source_size); return Status::OK(); } @@ -2920,7 +2837,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : hlo->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { - return GetIrArrayFor(operand).EmitReadArrayElement(index, &ir_builder_); + return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_); }; } CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); @@ -2928,20 +2845,69 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator)); } -StatusOr IrEmitter::EmitScalarCall( - PrimitiveType return_type, HloComputation* computation, - const std::vector& arguments, tensorflow::StringPiece name) { - llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation); - std::vector argument_addrs; - for (auto argument : arguments) { - llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry( - argument->getType(), "arg_addr", &ir_builder_); - ir_builder_.CreateStore(argument, argument_addr); - argument_addrs.push_back(argument_addr); +llvm::Value* IrEmitter::EmitThreadLocalCall( + const HloComputation& callee, + tensorflow::gtl::ArraySlice parameters, + tensorflow::StringPiece name) { + const Shape& return_shape = callee.root_instruction()->shape(); + + // Lifting this restriction to allow "small" arrays should be easy. Allowing + // larger arrays is difficult because we allocate the buffer for this return + // value on the stack. + CHECK(ShapeUtil::IsScalar(return_shape)); + + PrimitiveType return_type = return_shape.element_type(); + + std::vector parameter_addrs; + for (llvm::Value* parameter : parameters) { + CHECK(!parameter->getType()->isPointerTy()); + llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry( + parameter->getType(), "arg_addr", &b_); + b_.CreateStore(parameter, parameter_addr); + parameter_addrs.push_back(parameter_addr); + } + + llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(return_type, module_), + tensorflow::strings::StrCat(name, "_retval_addr"), &b_, + MinimumAlignmentForPrimitiveType(return_type)); + + b_.CreateCall( + FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + parameter_addrs, &b_, name, + /*return_value_buffer=*/return_value_buffer, + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), + /*profile_counters_arg=*/GetProfileCountersArgument())); + + return b_.CreateLoad(return_value_buffer); +} + +void IrEmitter::EmitGlobalCall(const HloComputation& callee, + tensorflow::StringPiece name) { + b_.CreateCall(FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + /*parameter_addresses=*/{}, &b_, name, + /*return_value_buffer=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()), + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument())); +} + +llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue( + const HloComputation& callee) { + const HloInstruction* root_inst = callee.root_instruction(); + if (root_inst->opcode() == HloOpcode::kOutfeed) { + return llvm::Constant::getNullValue(b_.getInt8PtrTy()); } - return EmitElementFunctionCall(llvm_function, - ShapeUtil::MakeShape(return_type, {}), - argument_addrs, name); + + const BufferAllocation::Slice root_buffer = + assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie(); + return EmitTempBufferPointer(root_buffer, root_inst->shape()); } + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 32c536e18fee86cc60067ba3b25ab1eb0e4233df..c9a1dab62dcbcd926baa82737d24efa03fd326e9 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -30,12 +30,12 @@ limitations under the License. #include "llvm/IR/Value.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" #include "tensorflow/compiler/xla/service/cpu/ir_function.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -67,17 +67,13 @@ class IrEmitter : public DfsHloVisitorWithDefault { // index in the profiling array. // computation_to_profile_idx: the mapping from HLO computations to their // index in the profiling array. - // external_constant_pool: if non-null, points to an ExternalConstantPool - // instance into which the Ir emitter can spill - // constants. IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - const TargetMachineFeatures* target_machine, - ExternalConstantPool* external_constant_pool); + const TargetMachineFeatures* target_machine); ~IrEmitter() override; // Emit and return the given HLO computation as an LLVM IR @@ -102,12 +98,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { bool is_top_level_computation, std::vector* instruction_order); - llvm::IRBuilder<>* ir_builder() { return &ir_builder_; } + llvm::IRBuilder<>* b() { return &b_; } - // Emits a call to `computation` with scalar arguments `arguments`. - StatusOr EmitScalarCall( - PrimitiveType return_type, HloComputation* computation, - const std::vector& arguments, tensorflow::StringPiece name); + // Emit an LLVM global variable for every constant buffer allocation. + Status EmitConstantGlobals(); + + // Emit code to map one element according to `map_instr`. + llvm::Value* EmitElementalMap( + const HloMapInstruction& map_instr, + tensorflow::gtl::ArraySlice elemental_operands, + tensorflow::StringPiece name); protected: // @@ -122,6 +122,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleCopy(HloInstruction* copy) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; @@ -143,13 +144,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleRecvDone(HloInstruction* recv_done) override; Status HandlePad(HloInstruction* pad) override; Status HandleTuple(HloInstruction* tuple) override; - Status HandleMap(HloInstruction* map) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleWhile(HloInstruction* xla_while) override; Status HandleConcatenate(HloInstruction* concatenate) override; Status HandleConditional(HloInstruction* conditional) override; + Status HandleScatter(HloInstruction* scatter) override; + Status HandleAfterAll(HloInstruction* gen_token) override; + Status HandleIota(HloInstruction* iota) override; + Status HandleRng(HloInstruction* rng) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; @@ -215,9 +219,18 @@ class IrEmitter : public DfsHloVisitorWithDefault { // computation function being emitted by this emitter. llvm::Value* GetTempBuffersArgument(); - // Emits code that computes the address of the given temporary buffer to the - // function. target_shape is the shape of this temporary buffer. - // The returned Value's type is a pointer to element_type. + // Helper for EmitTempBufferPointer. + llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice, + const Shape& target_shape); + + // Helper for EmitTempBufferPointer. + llvm::Value* EmitThreadLocalTempBufferPointer( + const BufferAllocation::Slice& slice, const Shape& target_shape); + + // Emits code that computes the address of the given buffer allocation slice. + // + // TODO(sanjoy): This should be renamed to reflect that it no longer provides + // access to just temporaries. llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice, const Shape& target_shape); @@ -229,44 +242,27 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::StringPiece function_name_suffix); // Used for LLVM IR register names. - // Methods that emit a function call. - // Parameters: - // function - The LLVM function to call. - // return_shape - The return shape of the HLO computation that was used to - // make the function. Not the same as the return type of the function - // in LLVM, since we use output parameters for the return type. - // element_count - number of elements to return (array form only). - // parameter_addresses - pointers to be passed to the function as - // parameters. - // name - used for LLVM IR register names. - - // Emits a function call, returning a scalar, often an element of a larger - // array. Returns a Value for the scalar element returned by the function. - llvm::Value* EmitElementFunctionCall( - llvm::Function* function, const Shape& return_shape, - tensorflow::gtl::ArraySlice parameter_addresses, + // Emits a call to a thread local function (e.g. to the computation nested + // within a reduce or a map). Thread local callees (by definition) only write + // to and read from thread local allocations. + // + // `parameters` holds the *scalar values* that need to be passed to the + // callee. The return value is the scalar returned by the callee. + llvm::Value* EmitThreadLocalCall( + const HloComputation& callee, + tensorflow::gtl::ArraySlice parameters, tensorflow::StringPiece name); - // Array function call emitter. Stores the function's result into a supplied - // buffer. - // Parameters: - // function - The LLVM function to call. - // parameter_addresses - pointers to be passed to the function as - // parameters. - // return_value - pointer to a buffer where the call result is stored. - - void EmitArrayFunctionCallInto( - llvm::Function* function, - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name); - - // Array function call emitter. Returns a Value for the function's return - // value buffer address. The return value buffer is alloca'ed by this - // function. - llvm::Value* EmitArrayFunctionCall( - llvm::Function* function, const Shape& return_shape, int64 element_count, - tensorflow::gtl::ArraySlice parameter_addresses, - tensorflow::StringPiece name); + // Emits a call to a "global" function (e.g. to the computation nested within + // a kWhile or a kCall). Buffer assignment unabiguously assignes buffers to + // the parameters and return values for these computations so there is no need + // to explicitly pass parameters or return results. + void EmitGlobalCall(const HloComputation& callee, + tensorflow::StringPiece name); + + // Returns the buffer to which a global call to `callee` would have written + // its result. + llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee); // Verifies that the element types of all of the given operand instructions // match and are of one of the given supported types. @@ -405,11 +401,10 @@ class IrEmitter : public DfsHloVisitorWithDefault { NameUniquer name_uniquer_; // Map containing all previously emitted computations. - std::map emitted_functions_; + std::map emitted_functions_; // Map containing all previously emitted thread-local temporary buffers. - std::map, - llvm::AllocaInst*> + std::map, llvm::Value*> thread_local_buffers_; // The following fields track the IR emission state. According to LLVM memory @@ -417,7 +412,17 @@ class IrEmitter : public DfsHloVisitorWithDefault { // creates the encapsulated llvm::Function s.t. it is added to the llvm // module's function list). std::unique_ptr compute_function_; - llvm::IRBuilder<> ir_builder_; + llvm::IRBuilder<> b_; + + // The buffer allocation slice for the root of the computation being compiled. + // Only relevant for thread local computations. + BufferAllocation::Slice computation_root_allocation_; + + // Maps the buffer allocation slices for the parameters to the computation + // being compiled to their parameter numbers. Only relevant for thread local + // computations. + tensorflow::gtl::FlatMap + computation_parameter_allocations_; // Maps HLO instructions to their index into the profile counter array. const std::unordered_map @@ -453,23 +458,22 @@ class IrEmitter : public DfsHloVisitorWithDefault { : use_rdtscp_(use_rdtscp), prof_counters_(prof_counters) {} // Record the cycle counter before an HLO executes. - void RecordCycleStart(llvm::IRBuilder<>* ir_builder, HloInstruction* hlo); + void RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo); // Record the number of cycles it took for an HLO to execute. - void RecordCycleDelta(llvm::IRBuilder<>* ir_builder, HloInstruction* hlo, + void RecordCycleDelta(llvm::IRBuilder<>* b, HloInstruction* hlo, llvm::Value* prof_counter); // Record the number of cycles it took for the entire computation to // execute. - void RecordCompleteComputation(llvm::IRBuilder<>* ir_builder, + void RecordCompleteComputation(llvm::IRBuilder<>* b, llvm::Value* prof_counter); // Convenience function to generate a call to an intrinsic which reads the // CPU cycle counter. - llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* ir_builder); + llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* b); // Store the cycle counter delta to the per-HLO profile counter. - void UpdateProfileCounter(llvm::IRBuilder<>* ir_builder, - llvm::Value* prof_counter, llvm::Value* cycle_end, - llvm::Value* cycle_start); + void UpdateProfileCounter(llvm::IRBuilder<>* b, llvm::Value* prof_counter, + llvm::Value* cycle_end, llvm::Value* cycle_start); private: // Should we use the x86-specific rdtscp or the generic readcyclecounter @@ -517,6 +521,17 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Returns the number of bytes within the shape. int64 ByteSizeOf(const Shape& shape) const; + StatusOr EmitTargetElementLoopBodyForMap( + HloMapInstruction* map, const llvm_ir::IrArray::Index& index); + StatusOr EmitTargetElementLoopBodyForReduceWindow( + HloReduceWindowInstruction* reduce_window, + const llvm_ir::IrArray::Index& index); + StatusOr EmitTargetElementLoopBodyForConvolution( + HloConvolutionInstruction* convolution, + const llvm_ir::IrArray::Index& index); + StatusOr EmitTargetElementLoopBodyForReduce( + HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index); + enum class XfeedKind { kInfeed, kOutfeed, @@ -536,9 +551,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { const TargetMachineFeatures& target_machine_features_; - int64 external_global_constant_counter_ = 0; - ExternalConstantPool* external_constant_pool_; - struct LiteralPtrHashFunctor { size_t operator()(const Literal* literal) const { return literal->Hash(); } }; @@ -553,6 +565,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { LiteralPtrHashFunctor, LiteralPtrEqualityFunctor> emitted_literals_; + tensorflow::gtl::FlatMap + constant_buffer_to_global_; + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index 2d6f2f3818a7bd4424aaa7d918ca86abef15c0e9..2db4d000f5b149969c88fb4325ca28aa11dc3708 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -49,11 +49,10 @@ IrFunction::IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage, const bool optimize_for_size_requested, const bool enable_fast_math, llvm::Module* llvm_module, - llvm::IRBuilder<>* ir_builder, - int64 num_dynamic_loop_bounds) - : ir_builder_(ir_builder), + llvm::IRBuilder<>* b, int64 num_dynamic_loop_bounds) + : b_(b), llvm_module_(llvm_module), - caller_insert_point_guard_(*ir_builder), + caller_insert_point_guard_(*b), num_dynamic_loop_bounds_(num_dynamic_loop_bounds) { Initialize(function_name, linkage, optimize_for_size_requested, enable_fast_math); @@ -61,7 +60,7 @@ IrFunction::IrFunction(const string& function_name, IrFunction::~IrFunction() { // Emit function return value. - ir_builder_->CreateRetVoid(); + b_->CreateRetVoid(); } DynamicLoopBounds IrFunction::GetDynamicLoopBounds() { @@ -81,9 +80,16 @@ void IrFunction::Initialize(const string& function_name, // void function(i8* retval, i8* run_options, i8** params, i8** temps, // i64* dynamic_loop_bounds, i64* prof_counters) // - // retval: points to the returned value. - // params: address of an array with pointers to parameters. - // temps: address of an array with pointers to temporary buffers. + // For thread local functions: + // retval: points to the returned value. + // params: address of an array with pointers to parameters. + // temps: is null + // + // For global functions: + // retval: is null + // params: is null + // temps: address of an array with pointers to temporary buffers and entry + // computation parameters. // // Therefore, the generated function's signature (FunctionType) is statically // determined - parameter unpacking is done in code generated into the @@ -174,7 +180,7 @@ void IrFunction::Initialize(const string& function_name, function_->addAttribute(argument.getArgNo() + 1, llvm::Attribute::NoAlias); } - ir_builder_->SetInsertPoint(llvm::BasicBlock::Create( + b_->SetInsertPoint(llvm::BasicBlock::Create( /*Context=*/llvm_module_->getContext(), /*Name=*/"entry", /*Parent=*/function_)); @@ -184,9 +190,8 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { CHECK_GT(num_dynamic_loop_bounds_, 0); CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); - return ir_builder_->CreateLoad( - ir_builder_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), - ir_builder_->getInt64(offset), AsStringRef(name))); + return b_->CreateLoad(b_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), + b_->getInt64(offset), AsStringRef(name))); } // Emits code to allocate an array of parameter address pointers, and store @@ -195,27 +200,32 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { // address buffer). std::vector GetArrayFunctionCallArguments( tensorflow::gtl::ArraySlice parameter_addresses, - llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name, + llvm::IRBuilder<>* b, tensorflow::StringPiece name, llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { - llvm::Value* parameter_addresses_buffer = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - ir_builder->getInt8PtrTy(), - ir_builder->getInt32(parameter_addresses.size()), - tensorflow::strings::StrCat(name, "_parameter_addresses"), - ir_builder); - for (size_t i = 0; i < parameter_addresses.size(); ++i) { - llvm::Value* parameter_as_i8ptr = ir_builder->CreateBitCast( - parameter_addresses[i], ir_builder->getInt8PtrTy(), - AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, - "_address_as_i8ptr"))); - llvm::Value* slot_in_param_addresses = ir_builder->CreateInBoundsGEP( - parameter_addresses_buffer, {ir_builder->getInt64(i)}); - ir_builder->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); + llvm::Value* parameter_addresses_buffer; + + if (parameter_addresses.empty()) { + parameter_addresses_buffer = + llvm::Constant::getNullValue(b->getInt8PtrTy()->getPointerTo()); + } else { + parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()), + tensorflow::strings::StrCat(name, "_parameter_addresses"), b); + + for (size_t i = 0; i < parameter_addresses.size(); ++i) { + llvm::Value* parameter_as_i8ptr = + b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(), + AsStringRef(tensorflow::strings::StrCat( + name, "_parameter_", i, "_address_as_i8ptr"))); + llvm::Value* slot_in_param_addresses = + b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)}); + b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); + } } const auto to_int8_ptr = [=](llvm::Value* ptr) { - return ir_builder->CreatePointerCast(ptr, ir_builder->getInt8PtrTy()); + return b->CreatePointerCast(ptr, b->getInt8PtrTy()); }; std::vector arguments{ to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg), @@ -230,22 +240,21 @@ std::vector GetArrayFunctionCallArguments( // calls to 'parallel_function' (and joins threads before returning). Status EmitCallToParallelForkJoin( const std::vector& arguments, const Shape& shape, - const std::vector& dimension_partition_counts, - llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function, - const string& name) { - llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); + const std::vector& dimension_partition_counts, llvm::IRBuilder<>* b, + llvm::Function* parallel_function, const string& name) { + llvm::Module* module = b->GetInsertBlock()->getModule(); // Build ParallelForkJoin function type. std::vector compute_function_params = GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0); // Number of parallel compute functions. - compute_function_params.push_back(ir_builder->getInt32Ty()); + compute_function_params.push_back(b->getInt32Ty()); // Array of partitions. There is an array element for each // partition x partition_dim x 2 (for dimension start and limit). compute_function_params.push_back( llvm::Type::getInt64PtrTy(module->getContext())); // Number of partitioned most-major dimensions in 'shape'. - compute_function_params.push_back(ir_builder->getInt32Ty()); + compute_function_params.push_back(b->getInt32Ty()); // Function pointer for compute function to be dispatched in parallel. compute_function_params.push_back( llvm::Type::getInt8PtrTy(module->getContext())); @@ -268,7 +277,7 @@ Status EmitCallToParallelForkJoin( ShapePartitionIterator partition_iterator(shape, dimension_partition_counts); const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); // Add argument specifying the number of parallel partitions. - fork_join_arguments.push_back(ir_builder->getInt32(num_partitions)); + fork_join_arguments.push_back(b->getInt32(num_partitions)); // The number of partitioned most-major dimensions in 'shape'. const int32 num_partitioned_dims = dimension_partition_counts.size(); @@ -293,15 +302,15 @@ Status EmitCallToParallelForkJoin( const std::pair& dim_partition = dim_partitions[j]; const int32 index = partition_index + j * dim_partition_size; // Store partition [dim_start, dim_limit) intervals for each dimension. - partitions[index] = ir_builder->getInt64(dim_partition.first); + partitions[index] = b->getInt64(dim_partition.first); partitions[index + 1] = - ir_builder->getInt64(dim_partition.first + dim_partition.second); + b->getInt64(dim_partition.first + dim_partition.second); } } // Create global variable out of dimension partitions in 'partitions'. llvm::ArrayType* partitions_array_type = - llvm::ArrayType::get(ir_builder->getInt64Ty(), partition_array_size); + llvm::ArrayType::get(b->getInt64Ty(), partition_array_size); llvm::Constant* partitions_array = llvm::ConstantArray::get(partitions_array_type, partitions); llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( @@ -315,16 +324,16 @@ Status EmitCallToParallelForkJoin( tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); // Add argument specifying parallel dimension partitions. - fork_join_arguments.push_back(ir_builder->CreateBitCast( - global_partitions_array, - llvm::Type::getInt64PtrTy(module->getContext()))); + fork_join_arguments.push_back( + b->CreateBitCast(global_partitions_array, + llvm::Type::getInt64PtrTy(module->getContext()))); // Add argument specifying the number of partitioned most-major dimensions. - fork_join_arguments.push_back(ir_builder->getInt32(num_partitioned_dims)); + fork_join_arguments.push_back(b->getInt32(num_partitioned_dims)); // Add argument for parallel compute function pointer. fork_join_arguments.push_back( - ir_builder->CreateBitCast(parallel_function, ir_builder->getInt8PtrTy())); + b->CreateBitCast(parallel_function, b->getInt8PtrTy())); // Emit call to parallel fork/join. - ir_builder->CreateCall(fork_join_func, fork_join_arguments); + b->CreateCall(fork_join_func, fork_join_arguments); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h index 2e55181eed867aca762f2b9b8310624ea12c7487..a41cbb64cdd9f5b6de5d1eadfbf7e63e1e984801 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.h +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -54,7 +54,7 @@ class IrFunction { IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage, const bool optimize_for_size_requested, const bool enable_fast_math, llvm::Module* llvm_module, - llvm::IRBuilder<>* ir_builder, int64 num_dynamic_loop_bounds); + llvm::IRBuilder<>* b, int64 num_dynamic_loop_bounds); ~IrFunction(); // Emit ir to read and return the set of ir values representing the dynamic @@ -97,7 +97,7 @@ class IrFunction { // 'offset' from the "dynamic_loop_bounds" argument of this function. llvm::Value* GetDynamicLoopBound(int64 offset); - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; llvm::Module* llvm_module_; llvm::IRBuilder<>::InsertPointGuard caller_insert_point_guard_; @@ -116,7 +116,7 @@ class IrFunction { // Returns an array of compute function call argument ir values. std::vector GetArrayFunctionCallArguments( tensorflow::gtl::ArraySlice parameter_addresses, - llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name, + llvm::IRBuilder<>* b, tensorflow::StringPiece name, llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg); @@ -124,9 +124,8 @@ std::vector GetArrayFunctionCallArguments( // calls to 'parallel_function' (and joins threads before returning). Status EmitCallToParallelForkJoin( const std::vector& arguments, const Shape& shape, - const std::vector& dimension_partition_counts, - llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function, - const string& name); + const std::vector& dimension_partition_counts, llvm::IRBuilder<>* b, + llvm::Function* parallel_function, const string& name); } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index 2e5cc96098241415b82f225afc81981f3e1069e0..cef5e57b0b12b7ae93af0d2508b2b9d6a592d390 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/IR/Verifier.h" #include "llvm/Transforms/Utils/Cloning.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/logging.h" @@ -52,46 +53,14 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, llvm::BasicBlock* vector_tanh_body = llvm::BasicBlock::Create(*context, "body", vector_tanh_function); - llvm::IRBuilder<> ir_builder(vector_tanh_body); + llvm::IRBuilder<> b(vector_tanh_body); llvm::FastMathFlags fast_math_flags; - fast_math_flags.setFast(); - ir_builder.setFastMathFlags(fast_math_flags); - - VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "tanh_f32"); + fast_math_flags.setFast(enable_fast_math); + b.setFastMathFlags(fast_math_flags); llvm::Value* input = &*vector_tanh_function->arg_begin(); - CHECK_EQ(input->getType(), vsl.vector_type()); - - // This implements the same rational interpolant as implemented in Eigen3. - llvm::Value* input_clamped = - vsl.Clamp(input, /*low=*/GetIeeeF32(-9.0), /*high=*/GetIeeeF32(9.0)); - - std::array numerator_coeffs{ - -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, - 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, - 4.89352455891786e-03f}; - - std::array denominator_coeffs{ - 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, - 4.89352518554385e-03f}; - - llvm::Value* input_squared = vsl.Mul(input_clamped, input_clamped); - llvm::Value* numerator = vsl.SplatFloat(GetIeeeF32(numerator_coeffs[0])); - for (int i = 1; i < numerator_coeffs.size(); i++) { - numerator = - vsl.MulAdd(input_squared, numerator, GetIeeeF32(numerator_coeffs[i])); - } - - numerator = vsl.Mul(input_clamped, numerator); - - llvm::Value* denominator = vsl.SplatFloat(GetIeeeF32(denominator_coeffs[0])); - for (int i = 1; i < denominator_coeffs.size(); i++) { - denominator = vsl.MulAdd(input_squared, denominator, - GetIeeeF32(denominator_coeffs[i])); - } - - llvm::Value* result = vsl.Div(numerator, denominator); - ir_builder.CreateRet(result); + CHECK_EQ(vector_width, input->getType()->getVectorNumElements()); + b.CreateRet(llvm_ir::EmitFastTanh(&b, input)); DCHECK(!llvm::verifyFunction(*vector_tanh_function)); return vector_tanh_function; @@ -113,12 +82,12 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, llvm::BasicBlock* vector_exp_body = llvm::BasicBlock::Create(*context, "body", vector_exp_function); - llvm::IRBuilder<> ir_builder(vector_exp_body); + llvm::IRBuilder<> b(vector_exp_body); llvm::FastMathFlags fast_math_flags; fast_math_flags.setFast(); - ir_builder.setFastMathFlags(fast_math_flags); + b.setFastMathFlags(fast_math_flags); - VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "exp_f32"); + VectorSupportLibrary vsl(F32, vector_width, &b, "exp_f32"); // This implements the same polynomial approximation as implemented in Eigen3. @@ -160,21 +129,21 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, // VectorSupportLibrary (intentionally) can't juggle more than one type at a // time so drop down to IRBuilder for this bit. llvm::Value* vector_constant_0x7f = - ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f)); + b.CreateVectorSplat(vector_width, b.getInt32(0x7f)); llvm::Value* vector_constant_23 = - ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23)); + b.CreateVectorSplat(vector_width, b.getInt32(23)); llvm::Type* i32_vector_type = - llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width); + llvm::VectorType::get(b.getInt32Ty(), vector_width); // fx is clamped so we don't have to worry about it being out of range for // i32. - llvm::Value* emm0 = ir_builder.CreateFPToSI(fx, i32_vector_type); - emm0 = ir_builder.CreateAdd(emm0, vector_constant_0x7f); - emm0 = ir_builder.CreateShl(emm0, vector_constant_23); - llvm::Value* emm0_f32 = ir_builder.CreateBitCast(emm0, vsl.vector_type()); + llvm::Value* emm0 = b.CreateFPToSI(fx, i32_vector_type); + emm0 = b.CreateAdd(emm0, vector_constant_0x7f); + emm0 = b.CreateShl(emm0, vector_constant_23); + llvm::Value* emm0_f32 = b.CreateBitCast(emm0, vsl.vector_type()); llvm::Value* result = vsl.Max(vsl.Mul(y, emm0_f32), input); - ir_builder.CreateRet(result); + b.CreateRet(result); DCHECK(!llvm::verifyFunction(*vector_exp_function)); return vector_exp_function; @@ -196,13 +165,13 @@ llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module, llvm::BasicBlock* vector_log_body = llvm::BasicBlock::Create(*context, "body", vector_log_function); - llvm::IRBuilder<> ir_builder(vector_log_body); + llvm::IRBuilder<> b(vector_log_body); llvm::FastMathFlags fast_math_flags; fast_math_flags.setFast(); - ir_builder.setFastMathFlags(fast_math_flags); + b.setFastMathFlags(fast_math_flags); llvm::Value* input = &*vector_log_function->arg_begin(); - VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "log_f32"); + VectorSupportLibrary vsl(F32, vector_width, &b, "log_f32"); const llvm::APFloat half = GetIeeeF32(0.5); const llvm::APFloat one = GetIeeeF32(1.0); @@ -238,22 +207,21 @@ llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module, // VectorSupportLibrary (intentionally) can't juggle more than one type at a // time so drop down to IRBuilder for this bit. llvm::Value* vector_constant_0x7f = - ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f)); + b.CreateVectorSplat(vector_width, b.getInt32(0x7f)); llvm::Value* vector_constant_23 = - ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23)); + b.CreateVectorSplat(vector_width, b.getInt32(23)); llvm::Type* i32_vector_type = - llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width); + llvm::VectorType::get(b.getInt32Ty(), vector_width); - llvm::Value* emm0 = ir_builder.CreateLShr( - ir_builder.CreateBitCast(input, i32_vector_type), vector_constant_23); + llvm::Value* emm0 = + b.CreateLShr(b.CreateBitCast(input, i32_vector_type), vector_constant_23); // Keep only the fractional part. input = vsl.FloatAnd(input, inv_mant_mask); input = vsl.FloatOr(input, half); - emm0 = ir_builder.CreateSub(emm0, vector_constant_0x7f); - llvm::Value* e = - vsl.Add(one, ir_builder.CreateSIToFP(emm0, vsl.vector_type())); + emm0 = b.CreateSub(emm0, vector_constant_0x7f); + llvm::Value* e = vsl.Add(one, b.CreateSIToFP(emm0, vsl.vector_type())); // part2: // if( x < SQRTHF ) { @@ -294,7 +262,7 @@ llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module, llvm::Value* or_rhs = vsl.FloatAnd(iszero_mask, minus_inf); llvm::Value* result = vsl.FloatOr(or_lhs, or_rhs); - ir_builder.CreateRet(result); + b.CreateRet(result); DCHECK(!llvm::verifyFunction(*vector_log_function)); return vector_log_function; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index 54af40506dab48b3c2a3a44eb0b5f5fb213a32ec..8560e4296aa95fe791446abb1b4363b9145f343e 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -25,19 +25,21 @@ namespace cpu { ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, const llvm_ir::IrArray& target_array, - const DynamicLoopBounds* dynamic_loop_bounds, llvm::IRBuilder<>* ir_builder) - : LoopEmitter(target_element_generator, target_array, ir_builder), + const DynamicLoopBounds* dynamic_loop_bounds, llvm::IRBuilder<>* b) + : LoopEmitter(target_element_generator, target_array, b), dynamic_loop_bounds_(dynamic_loop_bounds) {} std::vector ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) { + tensorflow::StringPiece loop_name, llvm::Type* index_type) { + CHECK_NE(index_type, nullptr); + CHECK(!ShapeUtil::IsTuple(shape_)); CHECK(!ShapeUtil::IsScalar(shape_)); - llvm_ir::ForLoopNest loop_nest(loop_name, ir_builder_); + llvm_ir::ForLoopNest loop_nest(loop_name, b_); const int64 num_dims = shape_.dimensions_size(); - llvm_ir::IrArray::Index array_index(num_dims); + llvm_ir::IrArray::Index array_index(index_type, num_dims); // Add loops from outer-most to inner-most dimensions. for (int i = LayoutUtil::MinorToMajor(shape_).size() - 1; i >= 0; --i) { @@ -63,8 +65,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( } } // Point IR builder at inner loop BB. - llvm_ir::SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), - ir_builder_); + llvm_ir::SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), b_); // Set exit_bb_ to the exit block of the loop nest. exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock(); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h index 755715634aa70a822b21d25dcae20a8fe053477a..076c683ca566f2c53992c358903d2aadead290f9 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -54,14 +54,14 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator, const llvm_ir::IrArray& target_array, const DynamicLoopBounds* dynamic_loop_bounds, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete; ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) override; + tensorflow::StringPiece loop_name, llvm::Type* index_type) override; private: const DynamicLoopBounds* dynamic_loop_bounds_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index fc2efbaf9a22b02cd729da2f367d53bc15506836..36c9f743859ae2da6c4fb3fd753bd7862fe2d3ab 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -110,8 +110,9 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { const string hlo_string = R"( HloModule TestTaskParallel_infeed_outfeed ENTRY InfeedOutfeed { - infeed0 = u32[12345678,2]{1,0} infeed() - ROOT outfeed0 = u32[12345678,2]{1,0} outfeed(infeed0) + infeed0 = (u32[12345678,2]{1,0}, token[]) infeed() + infeed0.data = u32[12345678,2]{1,0} get-tuple-element((u32[12345678,2]{1,0}, token[]) infeed0), index=0 + ROOT outfeed0 = token[] outfeed(infeed0.data) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc index d03da46575b331de113cc5f33c2b4267504e8308..a5f34908d70dd18ec017bdf9833c7df40f80db07 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc @@ -20,6 +20,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -58,13 +59,14 @@ using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, // [partition1_dim2_start] // [partition1_dim2_limit] // -void __xla_cpu_runtime_ParallelForkJoin( +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( void* result_ptr, const void* run_options_ptr, const void** params, void** temps, uint64* prof_counters, int32 num_partitions, int64* partitions, int32 num_partitioned_dims, void* function_ptr) { VLOG(2) << "ParallelForkJoin ENTRY" << " num_partitions: " << num_partitions << " num_partitioned_dims: " << num_partitioned_dims; + CHECK_EQ(params, nullptr); CHECK_GT(num_partitions, 1); CHECK_GT(num_partitioned_dims, 0); const xla::ExecutableRunOptions* run_options = @@ -79,9 +81,9 @@ void __xla_cpu_runtime_ParallelForkJoin( for (int32 i = 1; i < num_partitions; ++i) { const int64 offset = i * stride; run_options->intra_op_thread_pool()->enqueueNoNotification( - [i, function, result_ptr, run_options_ptr, params, temps, prof_counters, + [i, function, result_ptr, run_options_ptr, temps, prof_counters, partitions, offset, &bc]() { - function(result_ptr, run_options_ptr, params, temps, + function(result_ptr, run_options_ptr, nullptr, temps, &partitions[offset], prof_counters); bc.DecrementCount(); VLOG(3) << "ParallelForkJoin partition " << i << " done."; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index 39b13183ff093611a42b3931d45f64eadb420622..a71a85913cfef271bc2a226cb0cf2dd4204499a4 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -20,6 +20,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h" +#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" using tensorflow::int32; @@ -77,27 +78,24 @@ void MatMulImpl(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, } // namespace -void __xla_cpu_runtime_EigenMatMulF16(const void* run_options_ptr, - Eigen::half* out, Eigen::half* lhs, - Eigen::half* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16( + const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, + Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, + int32 transpose_rhs) { MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } -void __xla_cpu_runtime_EigenMatMulF32(const void* run_options_ptr, float* out, - float* lhs, float* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32( + const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } -void __xla_cpu_runtime_EigenMatMulF64(const void* run_options_ptr, double* out, - double* lhs, double* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64( + const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc index f8c8dd5e93d53db8d87be0208b5cf4daac3464f1..8dc5f3c93b6ba1a722ea7b23b4b5190ac0600cd6 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML) +#if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "third_party/intel_mkl_ml/include/mkl_cblas.h" #include "third_party/intel_mkl_ml/include/mkl_service.h" @@ -23,6 +23,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool" +#include "tensorflow/core/platform/dynamic_annotations.h" using tensorflow::int32; using tensorflow::int64; @@ -74,10 +75,9 @@ void MatMulF64(const void* run_options_ptr, double* out, double* lhs, } // namespace -void __xla_cpu_runtime_MKLMatMulF32(const void* run_options_ptr, float* out, - float* lhs, float* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLMatMulF32( + const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread @@ -88,11 +88,11 @@ void __xla_cpu_runtime_MKLMatMulF32(const void* run_options_ptr, float* out, // Set thread number back to the previous number. mkl_set_num_threads_local(prev_num_threads); } + // BLAS GEMM API for 64-bit Matrix Multiplication -void __xla_cpu_runtime_MKLMatMulF64(const void* run_options_ptr, double* out, - double* lhs, double* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLMatMulF64( + const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread @@ -103,22 +103,26 @@ void __xla_cpu_runtime_MKLMatMulF64(const void* run_options_ptr, double* out, // Set thread number back to the previous number. mkl_set_num_threads_local(prev_num_threads); } -void __xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr, - float* out, float* lhs, - float* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr, + float* out, float* lhs, float* rhs, + int64 m, int64 n, int64 k, + int32 transpose_lhs, + int32 transpose_rhs) { // Set the thread number to 1 for single threaded excution. int prev_num_threads = mkl_set_num_threads_local(1); MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); // Set thread number back to the previous number. mkl_set_num_threads_local(prev_num_threads); } -void __xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr, - double* out, double* lhs, - double* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr, + double* out, double* lhs, + double* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { // Set the thread number to 1 for single threaded excution. int prev_num_threads = mkl_set_num_threads_local(1); MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index 17303e2f0d34e531a3a56aa147608b949e0f43ae..16692e7f2e6145b2649b67987eef47916e958be2 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -17,6 +17,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h" +#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" using tensorflow::int32; @@ -71,7 +72,8 @@ void SingleThreadedMatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, } // namespace -void __xla_cpu_runtime_EigenSingleThreadedMatMulF16( +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF16( const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { @@ -79,16 +81,22 @@ void __xla_cpu_runtime_EigenSingleThreadedMatMulF16( transpose_lhs, transpose_rhs); } -void __xla_cpu_runtime_EigenSingleThreadedMatMulF32( - const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF32(const void* run_options_ptr, + float* out, float* lhs, + float* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } -void __xla_cpu_runtime_EigenSingleThreadedMatMulF64( - const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr, + double* out, double* lhs, + double* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index 167aa4adda995a259190a932a76a34ca5883444c..f227e4ae139b92e56786e38ef8eef72c9e2cd424 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -21,9 +21,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -38,20 +38,21 @@ int main(int argc, char** argv) { // Transfer parameters. std::unique_ptr param0_literal = - xla::Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + xla::LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = client->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = xla::Literal::CreateR2( - {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); + std::unique_ptr param1_literal = + xla::LiteralUtil::CreateR2( + {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); std::unique_ptr param1_data = client->TransferToServer(*param1_literal).ConsumeValueOrDie(); // Build computation. xla::XlaBuilder builder(""); - auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto add = builder.Add(p1, p0, {0}); + auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Add(p1, p0, {0}); xla::StatusOr computation_status = builder.Build(); xla::XlaComputation computation = computation_status.ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index c4c90515ac7ec2721cb9ea48d42e3c5080e249af..be772cfb7e564cebc5725854dbf5678e5c507556 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -127,13 +127,6 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, } llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { - if (const uint8* from_constant_pool = - external_constant_pool_.Find(string(name))) { - return llvm::JITEvaluatedSymbol( - reinterpret_cast(from_constant_pool), - llvm::JITSymbolFlags::None); - } - void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); if (func_addr == nullptr) { return nullptr; diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index 1851a3ee0bb97b4860605d7211a6ae70ac88686b..d74b63fcf45bd70cd18ee41f1e9714ba6a222abd 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -29,7 +29,6 @@ limitations under the License. #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" -#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -91,10 +90,6 @@ class SimpleOrcJIT { llvm::TargetMachine* target_machine() const { return target_machine_.get(); } - ExternalConstantPool* external_constant_pool() { - return &external_constant_pool_; - } - // Creates an llvm::TargetMachine suitable for JITting code that will run on // the current machine. static std::unique_ptr InferTargetMachineForJIT( @@ -112,7 +107,6 @@ class SimpleOrcJIT { std::shared_ptr symbol_resolver_; ObjLayerT object_layer_; CompileLayerT compile_layer_; - ExternalConstantPool external_constant_pool_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 66ae5ef0f66e90982102d73e474f5d0582f5415c..181cec3cdddeb40daf5276d9d1d6a139417a6072 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -40,7 +40,7 @@ tf_cc_test( name = "cpu_fusion_test", srcs = ["cpu_fusion_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -82,7 +82,7 @@ tf_cc_test( name = "cpu_noalias_test", srcs = ["cpu_noalias_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -128,16 +128,16 @@ tf_cc_test( name = "cpu_infeed_test", srcs = ["cpu_infeed_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h b/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h index 7c8d07a10baf55dba8cbd347ebe1459b78e268e0..77b3a0301f2f90b577b7eaad86064dc30e2d9456 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h @@ -22,7 +22,7 @@ namespace xla { namespace cpu { // Tests that verify IR emitted by the CPU backend is as expected. -class CpuCodegenTest : public LLVMIRGenTestBase {}; +class CpuCodegenTest : public LlvmIrGenTestBase {}; } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc index faac927027c48e44eb8ff1fcc4109fbc177fc579..00a7aa2ad2f6bac4877302296ccb76222557535c 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -40,7 +40,7 @@ class CpuExternalConstantsTest : public CpuCodegenTest { HloInstruction* constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2FromArray2D(backing_array))); + LiteralUtil::CreateR2FromArray2D(backing_array))); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); builder.AddInstruction( @@ -56,7 +56,8 @@ class CpuExternalConstantsTest : public CpuCodegenTest { TEST_F(CpuExternalConstantsTest, Basic) { TestWithArray(/*rows=*/1024, /*cols=*/1024, R"( -CHECK: @constant_global_0 = external constant [1024 x [1024 x float]], align 16 +CHECK-NOT: @constant_global_0 = external constant [1024 x [1024 x float]], align 16 +CHECK: @0 = private constant [4194304 x i8] {{.*}}, align 16 )"); } @@ -65,7 +66,7 @@ TEST_F(CpuExternalConstantsTest, BasicNegative) { // to externalize it. TestWithArray(/*rows=*/4, /*cols=*/4, R"( CHECK-NOT: @constant_global_0 = external constant [16 x float], align 8 -CHECK: @0 = private constant [16 x float] {{.*}}, align 8 +CHECK: @0 = private constant [64 x i8] {{.*}}, align 8 )"); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index 23e7a3de4d8188a3add259582e11030539e154c1..d98856fdbf4165a5909f193ebe8512e21af83dfc 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -43,8 +43,8 @@ class CpuFusionTest : public HloTestBase { TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { auto builder = HloComputation::Builder(TestName()); - auto input_literal1 = Literal::CreateR1({1.0, 2.0, 3.0}); - auto input_literal2 = Literal::CreateR1({-2.0, -42.0, 2.0}); + auto input_literal1 = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); + auto input_literal2 = LiteralUtil::CreateR1({-2.0, -42.0, 2.0}); Shape vshape = input_literal1->shape(); auto input1 = builder.AddInstruction( @@ -83,7 +83,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { TEST_F(CpuFusionTest, FuseElementwiseOpChain) { auto builder = HloComputation::Builder(TestName()); - auto input_literal = Literal::CreateR1({-1.5, -2.5, -3.0}); + auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); Shape vshape = input_literal->shape(); auto input = builder.AddInstruction( @@ -96,8 +96,11 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { HloInstruction::CreateUnary(vshape, HloOpcode::kExp, ceil)); auto floor = builder.AddInstruction( HloInstruction::CreateUnary(vshape, HloOpcode::kFloor, exp)); - auto two = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + auto two = builder.AddInstruction(HloInstruction::CreateBroadcast( + vshape, + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))), + {})); builder.AddInstruction( HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floor)); @@ -114,9 +117,9 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { EXPECT_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); EXPECT_EQ(HloOpcode::kMultiply, fusion_instruction->fused_expression_root()->opcode()); - // There should be 7 fused instructions: 2 parameters and the fused + // There should be 8 fused instructions: 2 parameters and the fused // operations. - EXPECT_EQ(7, fusion_instruction->fused_instruction_count()); + EXPECT_EQ(8, fusion_instruction->fused_instruction_count()); // Compile and execute the computation. auto result = ExecuteAndTransfer(std::move(module), {}); @@ -131,7 +134,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { // middle. auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto input_literal = Literal::CreateR1({-1.5, -2.5, -3.0}); + auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); Shape vshape = input_literal->shape(); auto input = builder.AddInstruction( @@ -163,15 +166,18 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { ShapeUtil::MakeShape(F32, {6, 1}), concatenate)), /*init_value=*/ builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), /*dimensions_to_reduce=*/{1}, add_f32)); auto exp = builder.AddInstruction( HloInstruction::CreateUnary(cshape, HloOpcode::kExp, reduce)); auto floor = builder.AddInstruction( HloInstruction::CreateUnary(cshape, HloOpcode::kFloor, exp)); - auto two = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + auto two = builder.AddInstruction(HloInstruction::CreateBroadcast( + cshape, + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))), + {})); builder.AddInstruction( HloInstruction::CreateBinary(cshape, HloOpcode::kMultiply, two, floor)); @@ -188,9 +194,9 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { EXPECT_EQ(HloOpcode::kFusion, fusion_instruction1->opcode()); EXPECT_EQ(HloOpcode::kMultiply, fusion_instruction1->fused_expression_root()->opcode()); - // There should be 5 fused instructions in the root fusion instruction: 2 + // There should be 6 fused instructions in the root fusion instruction: 2 // parameters, multiply, floor, and exp. - EXPECT_EQ(5, fusion_instruction1->fused_instruction_count()) + EXPECT_EQ(6, fusion_instruction1->fused_instruction_count()) << fusion_instruction1->fused_instructions_computation()->ToString(); auto fusion_instruction2 = reduce->operand(0); @@ -225,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { // operand vectors. Test for this problem by counting the number of nodes in // each fusion instruction to ensure that negate is not duplicated. auto builder = HloComputation::Builder(TestName()); - auto input_literal = Literal::CreateR1({1.0, 2.0, 3.0}); + auto input_literal = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); Shape vshape = input_literal->shape(); auto constant = builder.AddInstruction( @@ -286,10 +292,10 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) { // computation. The duplication is caused by the other use of exp2 in the // tuple. auto builder = HloComputation::Builder(TestName()); - auto input_literal1 = Literal::CreateR1({1.0, 2.0, 3.0}); - auto input_literal2 = Literal::CreateR1({-2.0, -42.0, 2.0}); + auto input_literal1 = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); + auto input_literal2 = LiteralUtil::CreateR1({-2.0, -42.0, 2.0}); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); Shape shape = constant->shape(); auto exp1 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc index dd63b998e9b6d04981ec6f7300c883c9b23b154f..c35569c6619ba5b534c5d8bb7ad683d84b6ecf4b 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -47,7 +47,7 @@ class InfeedTest : public ClientLibraryTestBase { // don't use ResetDevice since it is not implemented on CPU. ASSERT_IS_OK(client_->TransferToInfeed(literal)); XlaBuilder builder(TestName()); - builder.Infeed(literal.shape()); + Infeed(&builder, literal.shape()); if (ShapeUtil::IsTuple(literal.shape())) { // TODO(b/30609564): Use ComputeAndCompareLiteral instead. ComputeAndCompareTuple(&builder, literal, {}); @@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase { }; TEST_F(InfeedTest, SingleInfeedR0Bool) { - TestInfeedRoundTrip(*Literal::CreateR0(true)); + TestInfeedRoundTrip(*LiteralUtil::CreateR0(true)); } TEST_F(InfeedTest, SingleInfeedR1U32) { - TestInfeedRoundTrip(*Literal::CreateR1({1, 2, 3})); + TestInfeedRoundTrip(*LiteralUtil::CreateR1({1, 2, 3})); } TEST_F(InfeedTest, SingleInfeedR2F32) { - TestInfeedRoundTrip(*Literal::CreateR2F32Linspace(0.0, 1.0, 128, 64)); + TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); } TEST_F(InfeedTest, SingleInfeedR3F32) { TestInfeedRoundTrip( - *Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); - TestInfeedRoundTrip( - *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, - r3_dim0minor)); + TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, + r3_dim0minor)); - TestInfeedRoundTrip( - *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, - r3_dim0major)); + TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, + r3_dim0major)); } TEST_F(InfeedTest, SingleInfeedR4S32) { - TestInfeedRoundTrip(*Literal::CreateR4( + TestInfeedRoundTrip(*LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } TEST_F(InfeedTest, SingleInfeedTuple) { TestInfeedRoundTrip( - *Literal::MakeTuple({Literal::CreateR1({1, 2, 3}).get(), - Literal::CreateR0(false).get()})); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), + LiteralUtil::CreateR0(false).get()})); } TEST_F(InfeedTest, SingleInfeedEmptyTuple) { - TestInfeedRoundTrip(*Literal::MakeTuple({})); + TestInfeedRoundTrip(*LiteralUtil::MakeTuple({})); } // Tests Infeed operation used in a while loop, as in the code below. The @@ -125,8 +125,8 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Gt(builder.ConstantR0(40.0f), prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Gt(ConstantR0(&builder, 40.0f), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add the reduced value of the Infeed @@ -134,17 +134,16 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto infeed = builder.Infeed(infeed_shape); - auto addend = - builder.Reduce(infeed, builder.ConstantR0(0.0f), - CreateScalarAddComputation(F32, &builder), {0}); - builder.Add(prev, addend); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto infeed = Infeed(&builder, infeed_shape); + auto addend = Reduce(infeed, ConstantR0(&builder, 0.0f), + CreateScalarAddComputation(F32, &builder), {0}); + Add(prev, addend); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - auto init = builder.ConstantR0(0.0f); - builder.While(condition, body, init); + auto init = ConstantR0(&builder, 0.0f); + While(condition, body, init); // Build and asynchronously launch the computation. auto computation = builder.Build().ConsumeValueOrDie(); @@ -157,13 +156,16 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { }); // Send 5 Infeed data of shape F32[3]. - ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1({1, 2, 3}))); - ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1({4, 5, 6}))); - ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1({7, 8, 9}))); ASSERT_IS_OK( - client_->TransferToInfeed(*Literal::CreateR1({10, 11, 12}))); + client_->TransferToInfeed(*LiteralUtil::CreateR1({1, 2, 3}))); + ASSERT_IS_OK( + client_->TransferToInfeed(*LiteralUtil::CreateR1({4, 5, 6}))); ASSERT_IS_OK( - client_->TransferToInfeed(*Literal::CreateR1({13, 14, 15}))); + client_->TransferToInfeed(*LiteralUtil::CreateR1({7, 8, 9}))); + ASSERT_IS_OK( + client_->TransferToInfeed(*LiteralUtil::CreateR1({10, 11, 12}))); + ASSERT_IS_OK( + client_->TransferToInfeed(*LiteralUtil::CreateR1({13, 14, 15}))); delete computation_thread; // Joins the thread. auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); @@ -207,8 +209,8 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.GetTupleElement(prev, 1); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + GetTupleElement(prev, 1); condition = builder.Build().ConsumeValueOrDie(); } @@ -218,47 +220,47 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // The body adds the reduced value of the Infeed data (first tuple element) // to the previous accumulator, and returns the accumulator and the continue // flag (second tuple element) as a tuple. - const auto build_body = [this, &result_shape](const Shape& infeed_shape) { + const auto build_body = [&result_shape](const Shape& infeed_shape) { XlaComputation body; XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto infeed = builder.Infeed(infeed_shape); - auto addend = builder.Reduce( - builder.GetTupleElement(infeed, 0), builder.ConstantR0(0.0f), - CreateScalarAddComputation(F32, &builder), {0}); - auto result = builder.Add(builder.GetTupleElement(prev, 0), addend); - builder.Tuple({result, builder.GetTupleElement(infeed, 1)}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto infeed = Infeed(&builder, infeed_shape); + auto addend = + Reduce(GetTupleElement(infeed, 0), ConstantR0(&builder, 0.0f), + CreateScalarAddComputation(F32, &builder), {0}); + auto result = Add(GetTupleElement(prev, 0), addend); + Tuple(&builder, {result, GetTupleElement(infeed, 1)}); return builder.Build().ConsumeValueOrDie(); }; // Create the first while loop with infeed1_shape. - auto init = builder.Tuple( - {builder.ConstantR0(0.0f), builder.ConstantR0(true)}); - auto while1 = builder.While(condition, build_body(infeed1_shape), init); - auto result1 = builder.Tuple( - {builder.GetTupleElement(while1, 0), builder.ConstantR0(true)}); + auto init = Tuple(&builder, {ConstantR0(&builder, 0.0f), + ConstantR0(&builder, true)}); + auto while1 = While(condition, build_body(infeed1_shape), init); + auto result1 = Tuple( + &builder, {GetTupleElement(while1, 0), ConstantR0(&builder, true)}); // Create the second while loop with infeed2_shape. Note that the result from // the first while loop is used as the initial value. - auto while2 = builder.While(condition, build_body(infeed2_shape), result1); - builder.GetTupleElement(while2, 0); + auto while2 = While(condition, build_body(infeed2_shape), result1); + GetTupleElement(while2, 0); // Build the computation. auto computation = builder.Build().ConsumeValueOrDie(); // Send the first 4 Infeed data of shape Tuple(F32[2], PRED). ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({1, 2}).get(), - Literal::CreateR0(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), + LiteralUtil::CreateR0(true).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({3, 4}).get(), - Literal::CreateR0(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({3, 4}).get(), + LiteralUtil::CreateR0(true).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({5, 6}).get(), - Literal::CreateR0(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({5, 6}).get(), + LiteralUtil::CreateR0(true).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({7, 8}).get(), - Literal::CreateR0(false).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({7, 8}).get(), + LiteralUtil::CreateR0(false).get()}))); // Asynchronously launch the execution on the device. std::unique_ptr result; @@ -273,14 +275,14 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED). sleep(1); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({1, 2, 3}).get(), - Literal::CreateR0(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), + LiteralUtil::CreateR0(true).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({7, 8, 9}).get(), - Literal::CreateR0(false).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({7, 8, 9}).get(), + LiteralUtil::CreateR0(false).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({4, 5, 6}).get(), - Literal::CreateR0(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({4, 5, 6}).get(), + LiteralUtil::CreateR0(true).get()}))); // Wait for the execution to be done, and transfer the result. delete computation_thread; // Joins the thread. diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index 27044b1d62027e3b83744c486cb790269e505aff..90b99c828e2fcfd77579026a39d3a6711599feee 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -38,7 +38,8 @@ while_body { while_cond { arg_cond = f32[2,3,2] parameter(0) - ROOT unknown = pred[] infeed() + infeed = (pred[], token[]) infeed() + ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { @@ -49,14 +50,14 @@ ENTRY main { {{2, 1}, {2001, 3002}, {2001, 2002}}}) const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body - out0 = () outfeed(f32[2,3,2] const_a) - ROOT out1 = () outfeed(f32[2,3,2] const_b) + out0 = token[] outfeed(f32[2,3,2] const_a) + ROOT out1 = token[] outfeed(f32[2,3,2] const_b) } )"; string filecheck_pattern = R"( -CHECK: private constant [12 x float] -CHECK-NOT: private constant [12 x float] +CHECK: private constant [48 x i8] +CHECK-NOT: private constant [48 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -84,7 +85,8 @@ while_body { while_cond { arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) - ROOT unknown = pred[] infeed() + infeed = (pred[], token[]) infeed() + ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } ENTRY main { @@ -98,10 +100,10 @@ ENTRY main { )"; string filecheck_pattern = R"( -CHECK: private constant [1 x float] -CHECK: private constant [2 x float] -CHECK-NOT: private constant [1 x float] -CHECK-NOT: private constant [2 x float] +CHECK: private constant [4 x i8] +CHECK: private constant [8 x i8] +CHECK-NOT: private constant [4 x i8] +CHECK-NOT: private constant [8 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index 3b6b0ed74065615fb9e47a0ec3c6c4ab078e45c4..01daed4bcd38323bfe33e798a78c2b00b150a1bc 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "llvm/IR/Module.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" @@ -42,7 +42,7 @@ TEST_F(CpuNoAliasTest, Concat) { HloComputation::Builder builder(TestName()); std::unique_ptr literal = - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* param_x = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "x")); @@ -78,7 +78,7 @@ TEST_F(CpuNoAliasTest, Concat) { llvm::Function* func = llvm::cast( ir_module.getOrInsertFunction("test_fn", llvm::Type::getVoidTy(context))); llvm::BasicBlock* bb = llvm::BasicBlock::Create(context, "body", func); - llvm::IRBuilder<> ir_builder(bb); + llvm::IRBuilder<> b(bb); auto* zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0); llvm_ir::IrArray::Index zero2D({zero, zero}); @@ -90,7 +90,7 @@ TEST_F(CpuNoAliasTest, Concat) { ir_module.getOrInsertGlobal("param_x", array2d_type); llvm_ir::IrArray param_x_array(param_x_val, param_shape); aa.AddAliasingInformationToIrArray(*param_x, ¶m_x_array); - param_x_array.EmitReadArrayElement(zero2D, &ir_builder) + param_x_array.EmitReadArrayElement(zero2D, &b) ->setName("read_param_x_array"); } @@ -100,7 +100,7 @@ TEST_F(CpuNoAliasTest, Concat) { auto shape = ShapeUtil::MakeShape(F32, {2, 4}); llvm_ir::IrArray concat1_array(concat1_val, shape); aa.AddAliasingInformationToIrArray(*concat1, &concat1_array); - concat1_array.EmitReadArrayElement(zero2D, &ir_builder) + concat1_array.EmitReadArrayElement(zero2D, &b) ->setName("read_concat1_array"); } @@ -110,7 +110,7 @@ TEST_F(CpuNoAliasTest, Concat) { auto shape = ShapeUtil::MakeShape(F32, {2, 6}); llvm_ir::IrArray concat2_array(concat2_val, shape); aa.AddAliasingInformationToIrArray(*concat2, &concat2_array); - concat2_array.EmitReadArrayElement(zero2D, &ir_builder) + concat2_array.EmitReadArrayElement(zero2D, &b) ->setName("read_concat2_array"); } diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index 1ee279290b6fcfe775ce9867d424b1c031f5d2bd..dac416e1c78c2f60d458480c5062f48b77d4878d 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -32,12 +32,13 @@ ENTRY main { {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) - ROOT out = () outfeed(f32[2,3,2] const_a) + outfeed = token[] outfeed(f32[2,3,2] const_a) + ROOT root = () tuple() } )"; string filecheck_pattern = R"( -CHECK: private constant [12 x float] +CHECK: private constant [48 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index c444d151858d3a152a01b99657ffae89ebc6b487..3274be8d9dbfaa55e250748a389ad34fdeb81922 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -23,14 +23,14 @@ namespace xla { namespace cpu { VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size, - llvm::IRBuilder<>* ir_builder, + llvm::IRBuilder<>* b, std::string name) : vector_size_(vector_size), primitive_type_(primitive_type), - ir_builder_(ir_builder), + b_(b), name_(std::move(name)) { scalar_type_ = llvm_ir::PrimitiveTypeToIrType( - primitive_type, ir_builder_->GetInsertBlock()->getModule()); + primitive_type, b_->GetInsertBlock()->getModule()); scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_); vector_type_ = llvm::VectorType::get(scalar_type_, vector_size); vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_); @@ -63,9 +63,9 @@ llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) { llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs, llvm::Value* rhs) { if (scalar_type_->isFloatingPointTy()) { - return ir_builder()->CreateFMul(lhs, rhs, name()); + return b()->CreateFMul(lhs, rhs, name()); } else { - return ir_builder()->CreateMul(lhs, rhs, name()); + return b()->CreateMul(lhs, rhs, name()); } } @@ -76,13 +76,13 @@ llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) { llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) { AssertCorrectTypes({lhs, rhs}); - return ir_builder()->CreateFSub(lhs, rhs); + return b()->CreateFSub(lhs, rhs); } llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) { AssertCorrectTypes({lhs, rhs}); if (scalar_type_->isFloatingPointTy()) { - return llvm_ir::EmitFloatMax(lhs, rhs, ir_builder_); + return llvm_ir::EmitFloatMax(lhs, rhs, b_); } else { LOG(FATAL) << "Max for integers is unimplemented"; } @@ -91,13 +91,13 @@ llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) { llvm::Value* VectorSupportLibrary::Floor(llvm::Value* a) { AssertCorrectTypes({a}); return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {a}, - {a->getType()}, ir_builder()); + {a->getType()}, b()); } llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) { AssertCorrectTypes({lhs, rhs}); if (scalar_type_->isFloatingPointTy()) { - return ir_builder()->CreateFDiv(lhs, rhs, name()); + return b()->CreateFDiv(lhs, rhs, name()); } else { LOG(FATAL) << "Division for integers is unimplemented"; } @@ -111,42 +111,41 @@ llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a, CHECK(low.compare(high) == llvm::APFloat::cmpLessThan); CHECK(scalar_type_->isFloatingPointTy()); return llvm_ir::EmitFloatMin( - llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), ir_builder_), - GetConstantFloat(type, high), ir_builder_); + llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), b_), + GetConstantFloat(type, high), b_); } llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs) { AssertCorrectTypes({lhs, rhs}); - return I1ToFloat(ir_builder()->CreateFCmpOEQ(lhs, rhs, name())); + return I1ToFloat(b()->CreateFCmpOEQ(lhs, rhs, name())); } llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs) { AssertCorrectTypes({lhs, rhs}); - return I1ToFloat(ir_builder()->CreateFCmpOLT(lhs, rhs, name())); + return I1ToFloat(b()->CreateFCmpOLT(lhs, rhs, name())); } llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs) { AssertCorrectTypes({lhs, rhs}); - return I1ToFloat(ir_builder()->CreateFCmpULE(lhs, rhs, name())); + return I1ToFloat(b()->CreateFCmpULE(lhs, rhs, name())); } llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) { bool is_vector = llvm::isa(i1->getType()); llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector); - return ir_builder()->CreateBitCast( - ir_builder()->CreateSExt(i1, integer_type, name()), - is_vector ? vector_type() : scalar_type(), name()); + return b()->CreateBitCast(b()->CreateSExt(i1, integer_type, name()), + is_vector ? vector_type() : scalar_type(), name()); } llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) { CHECK(scalar_type()->isFloatingPointTy()); const llvm::DataLayout& data_layout = - ir_builder()->GetInsertBlock()->getModule()->getDataLayout(); + b()->GetInsertBlock()->getModule()->getDataLayout(); int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type()); - llvm::Type* scalar_int_type = ir_builder()->getIntNTy(float_size_bits); + llvm::Type* scalar_int_type = b()->getIntNTy(float_size_bits); if (vector) { return llvm::VectorType::get(scalar_int_type, vector_size()); } else { @@ -156,7 +155,7 @@ llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) { llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) { CHECK_EQ(x->getType(), scalar_type()); - return ir_builder()->CreateVectorSplat(vector_size(), x, name()); + return b()->CreateVectorSplat(vector_size(), x, name()); } llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs, @@ -164,10 +163,9 @@ llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs, AssertCorrectTypes({lhs, rhs}); llvm::Type* int_type = IntegerTypeForFloatSize(lhs->getType() == vector_type()); - return ir_builder()->CreateBitCast( - ir_builder()->CreateAnd( - ir_builder()->CreateBitCast(lhs, int_type, name()), - ir_builder()->CreateBitCast(rhs, int_type, name()), name()), + return b()->CreateBitCast( + b()->CreateAnd(b()->CreateBitCast(lhs, int_type, name()), + b()->CreateBitCast(rhs, int_type, name()), name()), vector_type()); } @@ -175,9 +173,8 @@ llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) { AssertCorrectTypes({lhs}); llvm::Type* int_type = IntegerTypeForFloatSize(lhs->getType() == vector_type()); - return ir_builder()->CreateBitCast( - ir_builder()->CreateNot( - ir_builder()->CreateBitCast(lhs, int_type, name()), name()), + return b()->CreateBitCast( + b()->CreateNot(b()->CreateBitCast(lhs, int_type, name()), name()), vector_type()); } @@ -185,47 +182,43 @@ llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) { AssertCorrectTypes({lhs, rhs}); llvm::Type* int_type = IntegerTypeForFloatSize(lhs->getType() == vector_type()); - return ir_builder()->CreateBitCast( - ir_builder()->CreateOr(ir_builder()->CreateBitCast(lhs, int_type, name()), - ir_builder()->CreateBitCast(rhs, int_type, name()), - name()), + return b()->CreateBitCast( + b()->CreateOr(b()->CreateBitCast(lhs, int_type, name()), + b()->CreateBitCast(rhs, int_type, name()), name()), vector_type(), name()); } llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs, llvm::Value* rhs) { if (scalar_type_->isFloatingPointTy()) { - return ir_builder()->CreateFAdd(lhs, rhs, name()); + return b()->CreateFAdd(lhs, rhs, name()); } else { - return ir_builder()->CreateAdd(lhs, rhs, name()); + return b()->CreateAdd(lhs, rhs, name()); } } llvm::Value* VectorSupportLibrary::ComputeOffsetPointer( llvm::Value* base_pointer, llvm::Value* offset_elements) { if (base_pointer->getType() != scalar_pointer_type()) { - base_pointer = ir_builder()->CreateBitCast(base_pointer, - scalar_pointer_type(), name()); + base_pointer = + b()->CreateBitCast(base_pointer, scalar_pointer_type(), name()); } - return ir_builder()->CreateInBoundsGEP(base_pointer, {offset_elements}, - name()); + return b()->CreateInBoundsGEP(base_pointer, {offset_elements}, name()); } llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) { if (pointer->getType() != vector_pointer_type()) { - pointer = - ir_builder()->CreateBitCast(pointer, vector_pointer_type(), name()); + pointer = b()->CreateBitCast(pointer, vector_pointer_type(), name()); } - return ir_builder()->CreateAlignedLoad( + return b()->CreateAlignedLoad( pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name()); } llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) { if (pointer->getType() != scalar_pointer_type()) { - pointer = - ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); + pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name()); } - return ir_builder()->CreateAlignedLoad( + return b()->CreateAlignedLoad( pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name()); } @@ -233,30 +226,28 @@ void VectorSupportLibrary::StoreVector(llvm::Value* value, llvm::Value* pointer) { AssertCorrectTypes({value}); if (pointer->getType() != vector_pointer_type()) { - pointer = ir_builder()->CreateBitCast(pointer, vector_pointer_type()); + pointer = b()->CreateBitCast(pointer, vector_pointer_type()); } - ir_builder()->CreateAlignedStore( - value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); + b()->CreateAlignedStore(value, pointer, + ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); } void VectorSupportLibrary::StoreScalar(llvm::Value* value, llvm::Value* pointer) { AssertCorrectTypes({value}); if (pointer->getType() != scalar_pointer_type()) { - pointer = - ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); + pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name()); } - ir_builder()->CreateAlignedStore( - value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); + b()->CreateAlignedStore(value, pointer, + ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); } llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) { if (pointer->getType() != scalar_pointer_type()) { - pointer = - ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); + pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name()); } - return ir_builder()->CreateVectorSplat( - vector_size(), ir_builder()->CreateLoad(pointer), name()); + return b()->CreateVectorSplat(vector_size(), b()->CreateLoad(pointer), + name()); } llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) { @@ -267,20 +258,19 @@ llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) { for (unsigned j = 0; j < vector_size(); ++j) { if (j < (i / 2)) { - mask[j] = ir_builder()->getInt32(i / 2 + j); + mask[j] = b()->getInt32(i / 2 + j); } else { - mask[j] = llvm::UndefValue::get(ir_builder()->getInt32Ty()); + mask[j] = llvm::UndefValue::get(b()->getInt32Ty()); } } - llvm::Value* half_remaining_lanes = ir_builder()->CreateShuffleVector( - vector, llvm::UndefValue::get(vector_type()), - llvm::ConstantVector::get(mask), ""); + llvm::Value* half_remaining_lanes = + b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask), ""); vector = Add(vector, half_remaining_lanes); } - return ir_builder()->CreateExtractElement(vector, ir_builder()->getInt32(0), - name()); + return b()->CreateExtractElement(vector, b()->getInt32(0), name()); } llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs, @@ -307,19 +297,19 @@ llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs, // vector, which are the lanes 2 and 3 in the rhs vector. for (int i = 0; i < vector_size(); i += 2) { int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2); - mask_a.push_back(ir_builder()->getInt32(increment + i)); - mask_b.push_back(ir_builder()->getInt32(increment + i + 1)); + mask_a.push_back(b()->getInt32(increment + i)); + mask_b.push_back(b()->getInt32(increment + i + 1)); } for (int i = 0; i < vector_size(); i += 2) { int increment = i < vector_size() / 2 ? (vector_size() / 2) : vector_size(); - mask_a.push_back(ir_builder()->getInt32(increment + i)); - mask_b.push_back(ir_builder()->getInt32(increment + i + 1)); + mask_a.push_back(b()->getInt32(increment + i)); + mask_b.push_back(b()->getInt32(increment + i + 1)); } - llvm::Value* shuffle_0 = ir_builder()->CreateShuffleVector( - lhs, rhs, llvm::ConstantVector::get(mask_a)); - llvm::Value* shuffle_1 = ir_builder()->CreateShuffleVector( - lhs, rhs, llvm::ConstantVector::get(mask_b)); + llvm::Value* shuffle_0 = + b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_a)); + llvm::Value* shuffle_1 = + b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_b)); return Add(shuffle_0, shuffle_1); } @@ -327,23 +317,21 @@ llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs, llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) { llvm::SmallVector mask; for (int i = 0; i < vector_size() / 2; i++) { - mask.push_back(ir_builder()->getInt32(i)); + mask.push_back(b()->getInt32(i)); } - return ir_builder()->CreateShuffleVector(vector, - llvm::UndefValue::get(vector_type()), - llvm::ConstantVector::get(mask)); + return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask)); } llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) { llvm::SmallVector mask; for (int i = 0; i < vector_size() / 2; i++) { - mask.push_back(ir_builder()->getInt32(i + vector_size() / 2)); + mask.push_back(b()->getInt32(i + vector_size() / 2)); } - return ir_builder()->CreateShuffleVector(vector, - llvm::UndefValue::get(vector_type()), - llvm::ConstantVector::get(mask)); + return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask)); } std::vector VectorSupportLibrary::ComputeHorizontalSums( @@ -360,8 +348,8 @@ std::vector VectorSupportLibrary::ComputeHorizontalSums( [this](llvm::Value* vector) { return AddReduce(vector); }); if (init_values) { for (int64 i = 0, e = result.size(); i < e; i++) { - result[i] = Add(result[i], ir_builder()->CreateExtractElement( - init_values, ir_builder()->getInt32(i))); + result[i] = Add(result[i], + b()->CreateExtractElement(init_values, b()->getInt32(i))); } } return result; @@ -398,9 +386,9 @@ VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums( std::vector results; for (int i = 0; i < lane_width; i++) { - llvm::Value* scalar_result = ir_builder()->CreateExtractElement( - i < (lane_width / 2) ? low : high, - ir_builder()->getInt32(i % (lane_width / 2)), name()); + llvm::Value* scalar_result = + b()->CreateExtractElement(i < (lane_width / 2) ? low : high, + b()->getInt32(i % (lane_width / 2)), name()); results.push_back(scalar_result); } @@ -415,17 +403,14 @@ llvm::Value* VectorSupportLibrary::GetZeroScalar() { return llvm::Constant::getNullValue(scalar_type()); } -LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* ir_builder) - : ir_builder_(ir_builder) { - alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", ir_builder_); +LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* b) : b_(b) { + alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", b_); } -llvm::Value* LlvmVariable::Get() const { - return ir_builder_->CreateLoad(alloca_); -} +llvm::Value* LlvmVariable::Get() const { return b_->CreateLoad(alloca_); } void LlvmVariable::Set(llvm::Value* new_value) { - ir_builder_->CreateStore(new_value, alloca_); + b_->CreateStore(new_value, alloca_); } TileVariable::TileVariable(VectorSupportLibrary* vector_support, diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index 49c2a4e2f4bae9e1672b7d2fe891301bce08bd4b..c728f6df0aef83e6ddc6c932a347f14da06d9d0d 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -46,11 +46,11 @@ class VectorSupportLibrary { // instance (i.e. LoadVector will load a vector of type <`vector_size` x // `primitive_type`>). VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size, - llvm::IRBuilder<>* ir_builder, std::string name); + llvm::IRBuilder<>* b, std::string name); llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs); llvm::Value* Mul(int64 lhs, llvm::Value* rhs) { - return Mul(ir_builder()->getInt64(lhs), rhs); + return Mul(b()->getInt64(lhs), rhs); } llvm::Value* Mul(const llvm::APFloat& lhs, llvm::Value* rhs) { return Mul(GetConstantFloat(rhs->getType(), lhs), rhs); @@ -63,7 +63,7 @@ class VectorSupportLibrary { llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs); llvm::Value* Add(int64 lhs, llvm::Value* rhs) { - return Add(ir_builder()->getInt64(lhs), rhs); + return Add(b()->getInt64(lhs), rhs); } llvm::Value* Add(const llvm::APFloat& lhs, llvm::Value* rhs) { return Add(GetConstantFloat(rhs->getType(), lhs), rhs); @@ -147,13 +147,11 @@ class VectorSupportLibrary { llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, llvm::Value* offset_elements, int64 scale) { return ComputeOffsetPointer( - base_pointer, - ir_builder_->CreateMul(ir_builder_->getInt64(scale), offset_elements)); + base_pointer, b_->CreateMul(b_->getInt64(scale), offset_elements)); } llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, int64 offset_elements) { - return ComputeOffsetPointer(base_pointer, - ir_builder()->getInt64(offset_elements)); + return ComputeOffsetPointer(base_pointer, b()->getInt64(offset_elements)); } llvm::Value* LoadVector(llvm::Value* pointer); @@ -164,7 +162,7 @@ class VectorSupportLibrary { } llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) { - return LoadVector(base_pointer, ir_builder()->getInt64(offset_elements)); + return LoadVector(base_pointer, b()->getInt64(offset_elements)); } llvm::Value* LoadScalar(llvm::Value* pointer); @@ -175,7 +173,7 @@ class VectorSupportLibrary { } llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) { - return LoadScalar(base_pointer, ir_builder()->getInt64(offset_elements)); + return LoadScalar(base_pointer, b()->getInt64(offset_elements)); } void StoreVector(llvm::Value* value, llvm::Value* pointer); @@ -187,7 +185,7 @@ class VectorSupportLibrary { void StoreVector(llvm::Value* value, llvm::Value* base_pointer, int64 offset_elements) { - StoreVector(value, base_pointer, ir_builder()->getInt64(offset_elements)); + StoreVector(value, base_pointer, b()->getInt64(offset_elements)); } void StoreScalar(llvm::Value* value, llvm::Value* pointer); @@ -198,7 +196,7 @@ class VectorSupportLibrary { void StoreScalar(llvm::Value* value, llvm::Value* base_pointer, int64 offset_elements) { - StoreScalar(base_pointer, ir_builder()->getInt64(offset_elements)); + StoreScalar(base_pointer, b()->getInt64(offset_elements)); } llvm::Value* LoadBroadcast(llvm::Value* pointer); @@ -207,7 +205,7 @@ class VectorSupportLibrary { return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements)); } llvm::Value* LoadBroadcast(llvm::Value* base_pointer, int64 offset_elements) { - return LoadBroadcast(base_pointer, ir_builder()->getInt64(offset_elements)); + return LoadBroadcast(base_pointer, b()->getInt64(offset_elements)); } // Compute the horizontal sum of each vector in `vectors`. The i'th element @@ -220,7 +218,7 @@ class VectorSupportLibrary { llvm::Value* GetZeroVector(); llvm::Value* GetZeroScalar(); - llvm::IRBuilder<>* ir_builder() const { return ir_builder_; } + llvm::IRBuilder<>* b() const { return b_; } int64 vector_size() const { return vector_size_; } llvm::Type* vector_type() const { return vector_type_; } llvm::Type* vector_pointer_type() const { return vector_pointer_type_; } @@ -277,7 +275,7 @@ class VectorSupportLibrary { int64 vector_size_; PrimitiveType primitive_type_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; llvm::Type* vector_type_; llvm::Type* vector_pointer_type_; llvm::Type* scalar_type_; @@ -289,22 +287,21 @@ class VectorSupportLibrary { // can later convert to a SSA value. class LlvmVariable { public: - LlvmVariable(llvm::Type*, llvm::IRBuilder<>* ir_builder); + LlvmVariable(llvm::Type*, llvm::IRBuilder<>* b); llvm::Value* Get() const; void Set(llvm::Value* new_value); private: llvm::AllocaInst* alloca_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; }; class VectorVariable : public LlvmVariable { public: VectorVariable(VectorSupportLibrary* vector_support, llvm::Value* initial_value) - : LlvmVariable(vector_support->vector_type(), - vector_support->ir_builder()) { + : LlvmVariable(vector_support->vector_type(), vector_support->b()) { Set(initial_value); } }; @@ -313,8 +310,7 @@ class ScalarVariable : public LlvmVariable { public: ScalarVariable(VectorSupportLibrary* vector_support, llvm::Value* initial_value) - : LlvmVariable(vector_support->scalar_type(), - vector_support->ir_builder()) { + : LlvmVariable(vector_support->scalar_type(), vector_support->b()) { Set(initial_value); } }; diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc index 32b5c5d35fae61ae6cb17fafcada1abd6c3c088c..e727ba49cb6321e499b5d50d5f45e7f7f6bb6fef 100644 --- a/tensorflow/compiler/xla/service/defuser_test.cc +++ b/tensorflow/compiler/xla/service/defuser_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/defuser.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" @@ -124,7 +124,7 @@ TEST_F(DefuserTest, NonTrivialFusionInstruction) { auto div = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3)); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto add2 = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); @@ -162,7 +162,7 @@ TEST_F(DefuserTest, MultipleFusionInstructions) { auto div = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3)); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto add2 = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index ee2b455730f8f520db6652f0352f8a96291cac73..86d57581f84920e8005e8f3c420e7488fc095434 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" @@ -76,6 +76,7 @@ class DfsHloVisitorBase { virtual Status HandleClamp(HloInstructionPtr hlo) = 0; virtual Status HandleSelect(HloInstructionPtr hlo) = 0; + virtual Status HandleTupleSelect(HloInstructionPtr hlo) = 0; virtual Status HandleMaximum(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } @@ -105,6 +106,7 @@ class DfsHloVisitorBase { virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; + virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } @@ -183,6 +185,9 @@ class DfsHloVisitorBase { virtual Status HandleOr(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } + virtual Status HandleXor(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } virtual Status HandleShiftLeft(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } @@ -208,6 +213,7 @@ class DfsHloVisitorBase { virtual Status HandleReverse(HloInstructionPtr hlo) = 0; virtual Status HandleSort(HloInstructionPtr hlo) = 0; virtual Status HandleConstant(HloInstructionPtr hlo) = 0; + virtual Status HandleIota(HloInstructionPtr hlo) = 0; virtual Status HandleGetTupleElement(HloInstructionPtr hlo) = 0; virtual Status HandleReduce(HloInstructionPtr hlo) = 0; virtual Status HandleBitcast(HloInstructionPtr hlo) = 0; @@ -228,6 +234,7 @@ class DfsHloVisitorBase { virtual Status HandleWhile(HloInstructionPtr hlo) = 0; virtual Status HandleConditional(HloInstructionPtr hlo) = 0; virtual Status HandleGather(HloInstructionPtr hlo) = 0; + virtual Status HandleScatter(HloInstructionPtr hlo) = 0; virtual Status HandlePad(HloInstructionPtr hlo) = 0; @@ -243,7 +250,7 @@ class DfsHloVisitorBase { virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; - virtual Status HandleGenerateToken(HloInstructionPtr token) = 0; + virtual Status HandleAfterAll(HloInstructionPtr token) = 0; // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 6934e00a4b665e9e6a4302e0c0a8ce1d5bb94373..617a5a2eb4796d8003099e39e3d26389e532e954 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" @@ -79,6 +79,9 @@ class DfsHloVisitorWithDefaultBase Status HandleSelect(HloInstructionPtr select) override { return DefaultAction(select); } + Status HandleTupleSelect(HloInstructionPtr tuple_select) override { + return DefaultAction(tuple_select); + } Status HandleDot(HloInstructionPtr dot) override { return DefaultAction(dot); } @@ -91,6 +94,9 @@ class DfsHloVisitorWithDefaultBase Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } + Status HandleAllToAll(HloInstructionPtr crs) override { + return DefaultAction(crs); + } Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); } @@ -112,6 +118,9 @@ class DfsHloVisitorWithDefaultBase Status HandleConstant(HloInstructionPtr constant) override { return DefaultAction(constant); } + Status HandleIota(HloInstructionPtr iota) override { + return DefaultAction(iota); + } Status HandleGetTupleElement(HloInstructionPtr get_tuple_element) override { return DefaultAction(get_tuple_element); } @@ -188,7 +197,10 @@ class DfsHloVisitorWithDefaultBase Status HandleGather(HloInstructionPtr gather) override { return DefaultAction(gather); } - Status HandleGenerateToken(HloInstructionPtr token) override { + Status HandleScatter(HloInstructionPtr scatter) override { + return DefaultAction(scatter); + } + Status HandleAfterAll(HloInstructionPtr token) override { return DefaultAction(token); } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 93fea7ead7a86bb34c449668fd88a58145681eb1..2e9d6be2de4a2ab918d9a5ea4881ad3fd036792e 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -61,13 +61,13 @@ int64 GlobalRandomValue() { llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits, int64 mantissa_bits, - llvm::IRBuilder<>* ir_builder) { + llvm::IRBuilder<>* b) { // Integer and float types for casting and constant generation. llvm::Type* float_type = x->getType(); - llvm::IntegerType* int_type = ir_builder->getInt32Ty(); + llvm::IntegerType* int_type = b->getInt32Ty(); // Cast the input value to an integer for bitwise manipulation. - llvm::Value* x_as_int = ir_builder->CreateBitCast(x, int_type); + llvm::Value* x_as_int = b->CreateBitCast(x, int_type); if (mantissa_bits < 23) { // Last remaining mantissa bit. @@ -77,22 +77,22 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits, // equal to a base value of 0111... plus one bit if the last remaining // mantissa bit is 1. const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1; - llvm::Value* x_last_mantissa_bit = ir_builder->CreateLShr( - ir_builder->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)), + llvm::Value* x_last_mantissa_bit = b->CreateLShr( + b->CreateAnd(x_as_int, + llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)), (23 - mantissa_bits)); - llvm::Value* x_rounding_bias = ir_builder->CreateAdd( - x_last_mantissa_bit, - llvm::ConstantInt::get(int_type, base_rounding_bias)); + llvm::Value* x_rounding_bias = + b->CreateAdd(x_last_mantissa_bit, + llvm::ConstantInt::get(int_type, base_rounding_bias)); // Add rounding bias, and mask out truncated bits. Note that the case // where adding the rounding bias overflows into the exponent bits is // correct; the non-masked mantissa bits will all be zero, and the // exponent will be incremented by one. const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); - x_as_int = ir_builder->CreateAdd(x_as_int, x_rounding_bias); - x_as_int = ir_builder->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, truncation_mask)); + x_as_int = b->CreateAdd(x_as_int, x_rounding_bias); + x_as_int = b->CreateAnd(x_as_int, + llvm::ConstantInt::get(int_type, truncation_mask)); } if (exponent_bits < 8) { @@ -120,29 +120,29 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits, f32_exponent_bias - reduced_exponent_bias; // Do we overflow or underflow? - llvm::Value* x_exponent = ir_builder->CreateAnd( + llvm::Value* x_exponent = b->CreateAnd( x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); - llvm::Value* x_overflows = ir_builder->CreateICmpUGT( + llvm::Value* x_overflows = b->CreateICmpUGT( x_exponent, llvm::ConstantInt::get(int_type, reduced_max_exponent << 23)); - llvm::Value* x_underflows = ir_builder->CreateICmpULE( + llvm::Value* x_underflows = b->CreateICmpULE( x_exponent, llvm::ConstantInt::get(int_type, reduced_min_exponent << 23)); // Compute appropriately-signed values of zero and infinity. - llvm::Value* x_signed_zero = ir_builder->CreateAnd( + llvm::Value* x_signed_zero = b->CreateAnd( x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask)); - llvm::Value* x_signed_inf = ir_builder->CreateOr( + llvm::Value* x_signed_inf = b->CreateOr( x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); // Force to zero or infinity if overflow or underflow. (Note that this // truncates all denormal values to zero, rather than rounding them.) - x_as_int = ir_builder->CreateSelect(x_overflows, x_signed_inf, x_as_int); - x_as_int = ir_builder->CreateSelect(x_underflows, x_signed_zero, x_as_int); + x_as_int = b->CreateSelect(x_overflows, x_signed_inf, x_as_int); + x_as_int = b->CreateSelect(x_underflows, x_signed_zero, x_as_int); } // Cast the result back to a floating-point type. - llvm::Value* result = ir_builder->CreateBitCast(x_as_int, float_type); + llvm::Value* result = b->CreateBitCast(x_as_int, float_type); // Correct result for NaN inputs. // @@ -154,53 +154,49 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits, // // If the fast-math flags are set to assume no NaNs, the comparison is likely // to be optimized away, so there's no point in even emitting it. - if (!ir_builder->getFastMathFlags().noNaNs()) { - llvm::Value* x_is_nan = ir_builder->CreateFCmpUNO(x, x); + if (!b->getFastMathFlags().noNaNs()) { + llvm::Value* x_is_nan = b->CreateFCmpUNO(x, x); if (mantissa_bits > 0) { - result = ir_builder->CreateSelect(x_is_nan, x, result); + result = b->CreateSelect(x_is_nan, x, result); } else { - result = ir_builder->CreateSelect( + result = b->CreateSelect( x_is_nan, llvm::ConstantFP::getInfinity(float_type), result); } } return result; } -llvm::Value* EmitF32ToBF16(llvm::Value* f32_value, - llvm::IRBuilder<>* ir_builder) { +llvm::Value* EmitF32ToBF16(llvm::Value* f32_value, llvm::IRBuilder<>* b) { auto reduced_precision = EmitReducePrecisionFloat( f32_value, /*exponent_bits=*/primitive_util::kBFloat16ExponentBits, - /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, ir_builder); - auto as_int32 = - ir_builder->CreateBitCast(reduced_precision, ir_builder->getInt32Ty()); - auto shifted = ir_builder->CreateLShr(as_int32, 16); - auto truncated = ir_builder->CreateTrunc(shifted, ir_builder->getInt16Ty()); - return ir_builder->CreateBitCast(truncated, ir_builder->getInt16Ty()); + /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, b); + auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty()); + auto shifted = b->CreateLShr(as_int32, 16); + auto truncated = b->CreateTrunc(shifted, b->getInt16Ty()); + return b->CreateBitCast(truncated, b->getInt16Ty()); } -llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, - llvm::IRBuilder<>* ir_builder) { - auto as_int16 = - ir_builder->CreateBitCast(bf16_value, ir_builder->getInt16Ty()); - auto as_int32 = ir_builder->CreateZExt(as_int16, ir_builder->getInt32Ty()); - auto shifted = ir_builder->CreateShl(as_int32, 16); - return ir_builder->CreateBitCast(shifted, ir_builder->getFloatTy()); +llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, llvm::IRBuilder<>* b) { + auto as_int16 = b->CreateBitCast(bf16_value, b->getInt16Ty()); + auto as_int32 = b->CreateZExt(as_int16, b->getInt32Ty()); + auto shifted = b->CreateShl(as_int32, 16); + return b->CreateBitCast(shifted, b->getFloatTy()); } llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, PrimitiveType from_type, PrimitiveType to_type, llvm::Module* module, - llvm::IRBuilder<>* ir_builder) { + llvm::IRBuilder<>* b) { if (primitive_util::IsSignedIntegralType(from_type)) { - return ir_builder->CreateSIToFP( - integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module)); + return b->CreateSIToFP(integer_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module)); } else { CHECK(primitive_util::IsUnsignedIntegralType(from_type) || from_type == PRED); - return ir_builder->CreateUIToFP( - integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module)); + return b->CreateUIToFP(integer_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module)); } } @@ -226,39 +222,43 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); PrimitiveType to_type = op->shape().element_type(); - CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED); + CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED) + << from_type; if (from_type == to_type) { return operand_value; } + if (to_type == PRED) { + return b_->CreateZExt( + b_->CreateICmpNE(operand_value, llvm::ConstantInt::get( + operand_value->getType(), 0)), + llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + } if (primitive_util::IsIntegralType(to_type)) { - return ir_builder_->CreateIntCast( + return b_->CreateIntCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_), primitive_util::IsSignedIntegralType(from_type)); } if (primitive_util::IsFloatingPointType(to_type)) { if (to_type == BF16) { - return EmitF32ToBF16( - EmitIntegralToFloating(operand_value, from_type, F32, module_, - ir_builder_), - ir_builder_); + return EmitF32ToBF16(EmitIntegralToFloating(operand_value, from_type, + F32, module_, b_), + b_); } return EmitIntegralToFloating(operand_value, from_type, to_type, - module_, ir_builder_); + module_, b_); } if (primitive_util::IsComplexType(to_type)) { auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType( primitive_util::ComplexComponentType(to_type), module_); if (primitive_util::IsSignedIntegralType(from_type)) { return EmitComposeComplex( - op, - ir_builder_->CreateSIToFP(operand_value, to_ir_component_type), + op, b_->CreateSIToFP(operand_value, to_ir_component_type), nullptr); } if (primitive_util::IsUnsignedIntegralType(from_type) || from_type == PRED) { return EmitComposeComplex( - op, - ir_builder_->CreateUIToFP(operand_value, to_ir_component_type), + op, b_->CreateUIToFP(operand_value, to_ir_component_type), nullptr); } } @@ -275,7 +275,7 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } if (primitive_util::BitWidth(from_type) == primitive_util::BitWidth(to_type)) { - return ir_builder_->CreateBitCast( + return b_->CreateBitCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return InvalidArgument( @@ -293,18 +293,18 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); auto zero = llvm::ConstantInt::get(type, 0); - auto cmp = ir_builder_->CreateICmpSGE(operand_value, zero); - return ir_builder_->CreateSelect(cmp, operand_value, - ir_builder_->CreateNeg(operand_value)); + auto cmp = b_->CreateICmpSGE(operand_value, zero); + return b_->CreateSelect(cmp, operand_value, + b_->CreateNeg(operand_value)); } else { return operand_value; } } case HloOpcode::kClz: { - auto is_zero_undef = ir_builder_->getFalse(); - return llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::ctlz, {operand_value, is_zero_undef}, - {operand_value->getType()}, ir_builder_); + auto is_zero_undef = b_->getFalse(); + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctlz, + {operand_value, is_zero_undef}, + {operand_value->getType()}, b_); } case HloOpcode::kSign: { bool is_signed = @@ -312,31 +312,28 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); auto zero = llvm::ConstantInt::get(type, 0); - auto cmp = ir_builder_->CreateICmpEQ(operand_value, zero); + auto cmp = b_->CreateICmpEQ(operand_value, zero); if (is_signed) { - auto ashr = ir_builder_->CreateAShr(operand_value, - type->getIntegerBitWidth() - 1); - return ir_builder_->CreateSelect(cmp, zero, - ir_builder_->CreateOr(ashr, 1)); + auto ashr = + b_->CreateAShr(operand_value, type->getIntegerBitWidth() - 1); + return b_->CreateSelect(cmp, zero, b_->CreateOr(ashr, 1)); } else { - return ir_builder_->CreateSelect(cmp, zero, - llvm::ConstantInt::get(type, 1)); + return b_->CreateSelect(cmp, zero, llvm::ConstantInt::get(type, 1)); } } case HloOpcode::kNegate: - return ir_builder_->CreateNeg(operand_value); + return b_->CreateNeg(operand_value); case HloOpcode::kNot: { auto type = op->shape().element_type(); if (type == PRED) { // It is not sufficient to just call CreateNot() here because a PRED // is represented as an i8 and the truth value is stored only in the // bottom bit. - return ir_builder_->CreateZExt( - ir_builder_->CreateNot(ir_builder_->CreateTrunc( - operand_value, ir_builder_->getInt1Ty())), + return b_->CreateZExt( + b_->CreateNot(b_->CreateTrunc(operand_value, b_->getInt1Ty())), llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } else if (primitive_util::IsIntegralType(type)) { - return ir_builder_->CreateNot(operand_value); + return b_->CreateNot(operand_value); } return Unimplemented("unary op Not is not defined for type '%d'", type); } @@ -352,7 +349,7 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); PrimitiveType to_type = op->shape().element_type(); - CHECK(primitive_util::IsFloatingPointType(from_type)); + CHECK(primitive_util::IsFloatingPointType(from_type)) << from_type; if (from_type == to_type) { return operand_value; } @@ -364,32 +361,38 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } return EmitComposeComplex( op, - ir_builder_->CreateFPCast( - operand_value, - llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)), + b_->CreateFPCast(operand_value, llvm_ir::PrimitiveTypeToIrType( + to_component_type, module_)), nullptr); } if (from_type == BF16) { TF_RET_CHECK(to_type != BF16); - operand_value = EmitBF16ToF32(operand_value, ir_builder_); + operand_value = EmitBF16ToF32(operand_value, b_); from_type = F32; if (from_type == to_type) { return operand_value; } } if (from_type == F32 && to_type == BF16) { - return EmitF32ToBF16(operand_value, ir_builder_); + return EmitF32ToBF16(operand_value, b_); + } + if (to_type == PRED) { + return b_->CreateZExt( + b_->CreateFCmpUNE( + operand_value, + llvm::ConstantFP::get(operand_value->getType(), 0.0)), + llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } if (primitive_util::IsFloatingPointType(to_type)) { - return ir_builder_->CreateFPCast( + return b_->CreateFPCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsSignedIntegralType(to_type)) { - return ir_builder_->CreateFPToSI( + return b_->CreateFPToSI( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsUnsignedIntegralType(to_type)) { - return ir_builder_->CreateFPToUI( + return b_->CreateFPToUI( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return Unimplemented("unhandled conversion operation: %s => %s", @@ -405,7 +408,7 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } if (primitive_util::BitWidth(from_type) == primitive_util::BitWidth(to_type)) { - return ir_builder_->CreateBitCast( + return b_->CreateBitCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return InvalidArgument( @@ -428,46 +431,52 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return EmitCos(op->shape().element_type(), operand_value); case HloOpcode::kSin: return EmitSin(op->shape().element_type(), operand_value); + case HloOpcode::kTanh: + return EmitTanh(op->shape().element_type(), operand_value); case HloOpcode::kFloor: - return llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()}, - ir_builder_); + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, + {operand_value}, + {operand_value->getType()}, b_); case HloOpcode::kCeil: - return llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::ceil, {operand_value}, {operand_value->getType()}, - ir_builder_); + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ceil, + {operand_value}, + {operand_value->getType()}, b_); case HloOpcode::kAbs: - return llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::fabs, {operand_value}, {operand_value->getType()}, - ir_builder_); + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, + {operand_value}, + {operand_value->getType()}, b_); case HloOpcode::kRoundNearestAfz: - return llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::round, {operand_value}, {operand_value->getType()}, - ir_builder_); + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::round, + {operand_value}, + {operand_value->getType()}, b_); case HloOpcode::kSign: { // TODO(b/32151903): Ensure consistent sign behavior for -0.0. auto type = operand_value->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = ir_builder_->CreateFCmpOEQ(operand_value, zero); - auto olt = ir_builder_->CreateFCmpOLT(operand_value, zero); - return ir_builder_->CreateSelect( + auto oeq = b_->CreateFCmpOEQ(operand_value, zero); + auto olt = b_->CreateFCmpOLT(operand_value, zero); + return b_->CreateSelect( oeq, zero, - ir_builder_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0), - llvm::ConstantFP::get(type, 1.0))); + b_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0), + llvm::ConstantFP::get(type, 1.0))); } case HloOpcode::kIsFinite: { // abs(x) o!= inf, this works because the comparison returns false if // either operand is NaN. auto type = operand_value->getType(); auto abs_value = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_); + llvm::Intrinsic::fabs, {operand_value}, {type}, b_); auto infinity = llvm::ConstantFP::getInfinity(type); - auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity); - return ir_builder_->CreateZExt( - not_infinite, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + auto not_infinite = b_->CreateFCmpONE(abs_value, infinity); + return b_->CreateZExt(not_infinite, + llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } case HloOpcode::kNegate: - return ir_builder_->CreateFNeg(operand_value); + return b_->CreateFNeg(operand_value); + case HloOpcode::kReal: + return operand_value; + case HloOpcode::kImag: + return llvm::ConstantFP::get(operand_value->getType(), 0.0); default: return Unimplemented("unary floating-point op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -487,13 +496,12 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); - auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), - ir_builder_->CreateFMul(b, b)); + auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b)); TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a)); auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex( - op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); + return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq), + angle); } case HloOpcode::kLog1p: { // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) @@ -501,15 +509,14 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); auto one = llvm::ConstantFP::get(llvm_ty, 1.0); - auto a_plus_one = ir_builder_->CreateFAdd(a, one); - auto sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(a_plus_one, a_plus_one), - ir_builder_->CreateFMul(b, b)); + auto a_plus_one = b_->CreateFAdd(a, one); + auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a_plus_one, a_plus_one), + b_->CreateFMul(b, b)); TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one)); auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex( - op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); + return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq), + angle); } case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -523,12 +530,11 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( primitive_util::ComplexComponentType(to_type); auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(to_component_type, module_); - return EmitComposeComplex( - op, - ir_builder_->CreateFPCast(EmitExtractReal(operand_value), - to_ir_component_type), - ir_builder_->CreateFPCast(EmitExtractImag(operand_value), - to_ir_component_type)); + return EmitComposeComplex(op, + b_->CreateFPCast(EmitExtractReal(operand_value), + to_ir_component_type), + b_->CreateFPCast(EmitExtractImag(operand_value), + to_ir_component_type)); } case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) @@ -538,8 +544,8 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value))); TF_ASSIGN_OR_RETURN( auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); - return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), - ir_builder_->CreateFMul(exp_a, sin_b)); + return EmitComposeComplex(op, b_->CreateFMul(exp_a, cos_b), + b_->CreateFMul(exp_a, sin_b)); } case HloOpcode::kExpm1: { // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i @@ -550,9 +556,8 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( TF_ASSIGN_OR_RETURN( auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0); - auto real_result = - ir_builder_->CreateFSub(ir_builder_->CreateFMul(exp_a, cos_b), one); - auto imag_result = ir_builder_->CreateFMul(exp_a, sin_b); + auto real_result = b_->CreateFSub(b_->CreateFMul(exp_a, cos_b), one); + auto imag_result = b_->CreateFMul(exp_a, sin_b); return EmitComposeComplex(op, real_result, imag_result); } case HloOpcode::kCos: { @@ -567,18 +572,14 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); auto type = a->getType(); TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); - auto half_exp_b = - ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); auto half_exp_neg_b = - ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); + b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); return EmitComposeComplex( - op, - ir_builder_->CreateFMul( - cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)), - ir_builder_->CreateFMul( - sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b))); + op, b_->CreateFMul(cos_a, b_->CreateFAdd(half_exp_neg_b, half_exp_b)), + b_->CreateFMul(sin_a, b_->CreateFSub(half_exp_neg_b, half_exp_b))); } case HloOpcode::kSin: { // sin(z) = .5i(e^(-iz) - e^(iz)) @@ -594,18 +595,14 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); auto type = a->getType(); TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); - auto half_exp_b = - ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); auto half_exp_neg_b = - ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); + b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); return EmitComposeComplex( - op, - ir_builder_->CreateFMul( - sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)), - ir_builder_->CreateFMul( - cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b))); + op, b_->CreateFMul(sin_a, b_->CreateFAdd(half_exp_b, half_exp_neg_b)), + b_->CreateFMul(cos_a, b_->CreateFSub(half_exp_b, half_exp_neg_b))); } case HloOpcode::kTanh: { /* @@ -633,64 +630,61 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a)); TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b)); TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b)); - auto exp_neg_a = ir_builder_->CreateFDiv( - llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); - auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(exp_a, exp_a), - ir_builder_->CreateFMul(exp_neg_a, exp_neg_a)); - auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b); - auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b); - auto real_num = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), - ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); - auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b); - auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a); + auto exp_neg_a = + b_->CreateFDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); + auto exp_2a_minus_exp_neg_2a = b_->CreateFSub( + b_->CreateFMul(exp_a, exp_a), b_->CreateFMul(exp_neg_a, exp_neg_a)); + auto cos_b_sq = b_->CreateFMul(cos_b, cos_b); + auto sin_b_sq = b_->CreateFMul(sin_b, sin_b); + auto real_num = + b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), + b_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); + auto cos_b_sin_b = b_->CreateFMul(cos_b, sin_b); + auto exp_a_plus_exp_neg_a = b_->CreateFAdd(exp_a, exp_neg_a); auto exp_a_plus_exp_neg_a_sq = - ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); - auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a); + b_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); + auto exp_a_minus_exp_neg_a = b_->CreateFSub(exp_a, exp_neg_a); auto exp_a_minus_exp_neg_a_sq = - ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); - auto imag_num = ir_builder_->CreateFMul( - cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq, - exp_a_minus_exp_neg_a_sq)); - auto denom = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), - ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); - return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom), - ir_builder_->CreateFDiv(imag_num, denom)); + b_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); + auto imag_num = b_->CreateFMul( + cos_b_sin_b, + b_->CreateFSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq)); + auto denom = + b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), + b_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); + return EmitComposeComplex(op, b_->CreateFDiv(real_num, denom), + b_->CreateFDiv(imag_num, denom)); } case HloOpcode::kAbs: { - auto sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(EmitExtractReal(operand_value), - EmitExtractReal(operand_value)), - ir_builder_->CreateFMul(EmitExtractImag(operand_value), - EmitExtractImag(operand_value))); + auto sum_sq = + b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value), + EmitExtractReal(operand_value)), + b_->CreateFMul(EmitExtractImag(operand_value), + EmitExtractImag(operand_value))); return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq}, - {sum_sq->getType()}, ir_builder_); + {sum_sq->getType()}, b_); } case HloOpcode::kSign: { // Sign(c) = c / |c| - auto sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(EmitExtractReal(operand_value), - EmitExtractReal(operand_value)), - ir_builder_->CreateFMul(EmitExtractImag(operand_value), - EmitExtractImag(operand_value))); + auto sum_sq = + b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value), + EmitExtractReal(operand_value)), + b_->CreateFMul(EmitExtractImag(operand_value), + EmitExtractImag(operand_value))); auto cplx_abs = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_); + llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, b_); auto type = cplx_abs->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = ir_builder_->CreateFCmpOEQ(cplx_abs, zero); - return ir_builder_->CreateSelect( + auto oeq = b_->CreateFCmpOEQ(cplx_abs, zero); + return b_->CreateSelect( oeq, EmitComposeComplex(op, zero, zero), EmitComposeComplex( - op, - ir_builder_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs), - ir_builder_->CreateFDiv(EmitExtractImag(operand_value), - cplx_abs))); + op, b_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs), + b_->CreateFDiv(EmitExtractImag(operand_value), cplx_abs))); } case HloOpcode::kNegate: - return EmitComposeComplex( - op, ir_builder_->CreateFNeg(EmitExtractReal(operand_value)), - ir_builder_->CreateFNeg(EmitExtractImag(operand_value))); + return EmitComposeComplex(op, + b_->CreateFNeg(EmitExtractReal(operand_value)), + b_->CreateFNeg(EmitExtractImag(operand_value))); case HloOpcode::kReal: return EmitExtractReal(operand_value); case HloOpcode::kImag: @@ -724,15 +718,15 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( case HloOpcode::kComplex: return EmitComposeComplex(op, lhs_value, rhs_value); case HloOpcode::kAdd: - return ir_builder_->CreateFAdd(lhs_value, rhs_value); + return b_->CreateFAdd(lhs_value, rhs_value); case HloOpcode::kSubtract: - return ir_builder_->CreateFSub(lhs_value, rhs_value); + return b_->CreateFSub(lhs_value, rhs_value); case HloOpcode::kMultiply: - return ir_builder_->CreateFMul(lhs_value, rhs_value); + return b_->CreateFMul(lhs_value, rhs_value); case HloOpcode::kDivide: - return ir_builder_->CreateFDiv(lhs_value, rhs_value); + return b_->CreateFDiv(lhs_value, rhs_value); case HloOpcode::kRemainder: - return ir_builder_->CreateFRem(lhs_value, rhs_value); + return b_->CreateFRem(lhs_value, rhs_value); // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered // comparisons always return false when one of the operands is NaN, whereas // unordered comparisons return true. @@ -742,22 +736,22 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( // matches C++'s semantics. case HloOpcode::kEq: return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value, - rhs_value, ir_builder_); + rhs_value, b_); case HloOpcode::kNe: return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value, - rhs_value, ir_builder_); + rhs_value, b_); case HloOpcode::kLt: return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value, - rhs_value, ir_builder_); + rhs_value, b_); case HloOpcode::kGt: return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value, - rhs_value, ir_builder_); + rhs_value, b_); case HloOpcode::kLe: return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value, - rhs_value, ir_builder_); + rhs_value, b_); case HloOpcode::kGe: return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value, - rhs_value, ir_builder_); + rhs_value, b_); case HloOpcode::kMaximum: return EmitFloatMax(lhs_value, rhs_value); @@ -778,64 +772,56 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( llvm::Value* rhs_value) const { switch (op->opcode()) { case HloOpcode::kAdd: - return EmitComposeComplex( - op, - ir_builder_->CreateFAdd(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - ir_builder_->CreateFAdd(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); + return EmitComposeComplex(op, + b_->CreateFAdd(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + b_->CreateFAdd(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); case HloOpcode::kSubtract: - return EmitComposeComplex( - op, - ir_builder_->CreateFSub(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - ir_builder_->CreateFSub(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); + return EmitComposeComplex(op, + b_->CreateFSub(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + b_->CreateFSub(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); case HloOpcode::kMultiply: return EmitComposeComplex( op, - ir_builder_->CreateFSub( - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))), - ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value)))); + b_->CreateFSub(b_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + b_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))), + b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value)), + b_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value)))); case HloOpcode::kDivide: { // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di)) // = ((ac + bd) + (bc - ad)i) / (c^2 + d^2) - auto rhs_sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(EmitExtractReal(rhs_value), - EmitExtractReal(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(rhs_value), - EmitExtractImag(rhs_value))); + auto rhs_sum_sq = + b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(rhs_value), + EmitExtractReal(rhs_value)), + b_->CreateFMul(EmitExtractImag(rhs_value), + EmitExtractImag(rhs_value))); auto type = rhs_sum_sq->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = ir_builder_->CreateFCmpOEQ(rhs_sum_sq, zero); - auto real_inf_or_nan = - ir_builder_->CreateFDiv(EmitExtractReal(lhs_value), zero); - auto imag_inf_or_nan = - ir_builder_->CreateFDiv(EmitExtractImag(lhs_value), zero); - return ir_builder_->CreateSelect( + auto oeq = b_->CreateFCmpOEQ(rhs_sum_sq, zero); + auto real_inf_or_nan = b_->CreateFDiv(EmitExtractReal(lhs_value), zero); + auto imag_inf_or_nan = b_->CreateFDiv(EmitExtractImag(lhs_value), zero); + return b_->CreateSelect( oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan), EmitComposeComplex( op, - ir_builder_->CreateFDiv( - ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))), + b_->CreateFDiv( + b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + b_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))), rhs_sum_sq), - ir_builder_->CreateFDiv( - ir_builder_->CreateFSub( - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value)), - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value))), + b_->CreateFDiv( + b_->CreateFSub(b_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value)), + b_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value))), rhs_sum_sq))); } // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered @@ -846,21 +832,21 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( // unordered comparison. This makes x != y equivalent to !(x == y), and // matches C++'s semantics. case HloOpcode::kEq: - return ir_builder_->CreateAnd( + return b_->CreateAnd( llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), ir_builder_), + EmitExtractReal(rhs_value), b_), llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), ir_builder_)); + EmitExtractImag(rhs_value), b_)); case HloOpcode::kNe: - return ir_builder_->CreateOr( + return b_->CreateOr( llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), ir_builder_), + EmitExtractReal(rhs_value), b_), llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), ir_builder_)); + EmitExtractImag(rhs_value), b_)); case HloOpcode::kPower: { // (a+bi)^(c+di) = @@ -872,29 +858,26 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( auto b = EmitExtractImag(lhs_value); auto c = EmitExtractReal(rhs_value); auto d = EmitExtractImag(rhs_value); - auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), - ir_builder_->CreateFMul(b, b)); + auto aa_p_bb = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b)); auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); - auto half_c = ir_builder_->CreateFMul(one_half, c); + auto half_c = b_->CreateFMul(one_half, c); TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c, EmitPow(component_type, aa_p_bb, half_c)); - auto neg_d = ir_builder_->CreateFNeg(d); + auto neg_d = b_->CreateFNeg(d); TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a)); - auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs); + auto neg_d_arg_lhs = b_->CreateFMul(neg_d, arg_lhs); TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, EmitExp(component_type, neg_d_arg_lhs)); - auto coeff = - ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); + auto coeff = b_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb)); - auto half_d = ir_builder_->CreateFMul(one_half, d); - auto q = - ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs), - ir_builder_->CreateFMul(half_d, ln_aa_p_bb)); + auto half_d = b_->CreateFMul(one_half, d); + auto q = b_->CreateFAdd(b_->CreateFMul(c, arg_lhs), + b_->CreateFMul(half_d, ln_aa_p_bb)); TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); - return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q), - ir_builder_->CreateFMul(coeff, sin_q)); + return EmitComposeComplex(op, b_->CreateFMul(coeff, cos_q), + b_->CreateFMul(coeff, sin_q)); } default: return Unimplemented("binary complex op '%s'", @@ -904,12 +887,12 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value) const { - return llvm_ir::EmitFloatMax(lhs_value, rhs_value, ir_builder_); + return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_); } llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value) const { - return llvm_ir::EmitFloatMin(lhs_value, rhs_value, ir_builder_); + return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_); } StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, @@ -921,15 +904,14 @@ StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, "type F32."); } auto getFloat = [&](const float f) { - return llvm::ConstantFP::get(ir_builder_->getFloatTy(), f); + return llvm::ConstantFP::get(b_->getFloatTy(), f); }; auto multiply_add = [&](tensorflow::gtl::ArraySlice coefficients, llvm::Value* w) { llvm::Value* p = getFloat(coefficients.front()); coefficients.pop_front(); for (float coefficient : coefficients) { - p = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(p, w), - getFloat(coefficient)); + p = b_->CreateFAdd(b_->CreateFMul(p, w), getFloat(coefficient)); } return p; }; @@ -947,50 +929,48 @@ StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, // } // return p*x llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration( - module_, llvm::Intrinsic::log, {ir_builder_->getFloatTy()}); + module_, llvm::Intrinsic::log, {b_->getFloatTy()}); - llvm::Value* w = ir_builder_->CreateFNeg(ir_builder_->CreateCall( - logf_fn, - {ir_builder_->CreateFMul(ir_builder_->CreateFSub(getFloat(1.0f), x), - ir_builder_->CreateFAdd(getFloat(1.0f), x))})); + llvm::Value* w = b_->CreateFNeg(b_->CreateCall( + logf_fn, {b_->CreateFMul(b_->CreateFSub(getFloat(1.0f), x), + b_->CreateFAdd(getFloat(1.0f), x))})); - llvm::Value* p_addr = llvm_ir::EmitAllocaAtFunctionEntry( - ir_builder_->getFloatTy(), "p.addr", ir_builder_); + llvm::Value* p_addr = + llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_); - llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(ir_builder_->CreateFCmpOLT(w, getFloat(5.0f)), - "w_less_than_five", ir_builder_); + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + b_->CreateFCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_); // Handle true BB. - SetToFirstInsertPoint(if_data.true_block, ir_builder_); + SetToFirstInsertPoint(if_data.true_block, b_); { - llvm::Value* lw = ir_builder_->CreateFSub(w, getFloat(2.5f)); + llvm::Value* lw = b_->CreateFSub(w, getFloat(2.5f)); tensorflow::gtl::ArraySlice lq{ 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, -4.39150654e-06f, 0.00021858087f, -0.00125372503f, -0.00417768164f, 0.246640727f, 1.50140941f}; llvm::Value* p = multiply_add(lq, lw); - ir_builder_->CreateStore(p, p_addr); + b_->CreateStore(p, p_addr); } // Handle false BB. - SetToFirstInsertPoint(if_data.false_block, ir_builder_); + SetToFirstInsertPoint(if_data.false_block, b_); { llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( - module_, llvm::Intrinsic::sqrt, {ir_builder_->getFloatTy()}); + module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()}); - llvm::Value* gw = ir_builder_->CreateFSub( - ir_builder_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f)); + llvm::Value* gw = + b_->CreateFSub(b_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f)); tensorflow::gtl::ArraySlice gq{ -0.000200214257f, 0.000100950558f, 0.00134934322f, -0.00367342844f, 0.00573950773f, -0.0076224613f, 0.00943887047f, 1.00167406f, 2.83297682f}; llvm::Value* p = multiply_add(gq, gw); - ir_builder_->CreateStore(p, p_addr); + b_->CreateStore(p, p_addr); } - SetToFirstInsertPoint(if_data.after_block, ir_builder_); - llvm::Value* p = ir_builder_->CreateLoad(p_addr); - return ir_builder_->CreateFMul(p, x); + SetToFirstInsertPoint(if_data.after_block, b_); + llvm::Value* p = b_->CreateLoad(p_addr); + return b_->CreateFMul(p, x); } StatusOr ElementalIrEmitter::EmitErfcInv( @@ -998,13 +978,13 @@ StatusOr ElementalIrEmitter::EmitErfcInv( // Compute erfcinv(value) by calculating erfinv(1.0 - value). auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); - return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value)); + return EmitErfInv(prim_type, b_->CreateFSub(one, value)); } StatusOr ElementalIrEmitter::EmitLog(PrimitiveType prim_type, llvm::Value* value) const { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value}, - {value->getType()}, ir_builder_); + {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, @@ -1016,35 +996,34 @@ StatusOr ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, // When x is large, the naive evaluation of ln(x + 1) is more // accurate than the Taylor series. TF_ASSIGN_OR_RETURN(auto for_large_x, - EmitLog(prim_type, ir_builder_->CreateFAdd(x, one))); + EmitLog(prim_type, b_->CreateFAdd(x, one))); // The Taylor series for ln(x+1) is x - x^2/2 - x^3/3 + …. - auto for_small_x = ir_builder_->CreateFMul( - ir_builder_->CreateFAdd(ir_builder_->CreateFMul(negative_half, x), one), - x); + auto for_small_x = + b_->CreateFMul(b_->CreateFAdd(b_->CreateFMul(negative_half, x), one), x); const auto kAntilogarithmIsSmallThreshold = 1e-4; - auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, - {type}, ir_builder_); - auto x_is_small = ir_builder_->CreateFCmpOLT( + auto abs_x = + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); + auto x_is_small = b_->CreateFCmpOLT( abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold)); - return ir_builder_->CreateSelect(x_is_small, for_small_x, for_large_x); + return b_->CreateSelect(x_is_small, for_small_x, for_large_x); } StatusOr ElementalIrEmitter::EmitSin(PrimitiveType prim_type, llvm::Value* value) const { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, - {value->getType()}, ir_builder_); + {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitCos(PrimitiveType prim_type, llvm::Value* value) const { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value}, - {value->getType()}, ir_builder_); + {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitExp(PrimitiveType prim_type, llvm::Value* value) const { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value}, - {value->getType()}, ir_builder_); + {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, @@ -1056,25 +1035,25 @@ StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, // When the exponent is large, the naive evaluation of e^(x) - 1 is more // accurate than the Taylor series. TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value)); - auto for_large_x = ir_builder_->CreateFSub(exp_x, one); + auto for_large_x = b_->CreateFSub(exp_x, one); // The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + …. // We want exp(x)-1 which is x + x^2/2 + x^3/6 + …. - auto x_squared = ir_builder_->CreateFAdd(x, x); - auto x_squared_over_two = ir_builder_->CreateFMul(x_squared, half); - auto for_small_x = ir_builder_->CreateFAdd(x, x_squared_over_two); + auto x_squared = b_->CreateFAdd(x, x); + auto x_squared_over_two = b_->CreateFMul(x_squared, half); + auto for_small_x = b_->CreateFAdd(x, x_squared_over_two); const auto kExponentIsSmallThreshold = 1e-5; - auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, - {type}, ir_builder_); - auto x_is_small = ir_builder_->CreateFCmpOLT( + auto abs_x = + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); + auto x_is_small = b_->CreateFCmpOLT( abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold)); - return ir_builder_->CreateSelect(x_is_small, for_small_x, for_large_x); + return b_->CreateSelect(x_is_small, for_small_x, for_large_x); } StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs}, - {lhs->getType()}, ir_builder_); + {lhs->getType()}, b_); } StatusOr ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, @@ -1083,17 +1062,21 @@ StatusOr ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, return Unimplemented("atan2"); } +StatusOr ElementalIrEmitter::EmitTanh(PrimitiveType prim_type, + llvm::Value* value) const { + return Unimplemented("tanh"); +} + StatusOr ElementalIrEmitter::EmitReducePrecision( const HloInstruction* hlo, llvm::Value* x) const { if (hlo->operand(0)->shape().element_type() != F32) { return Unimplemented("reduce-precision only implemented for F32"); } return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(), - /*mantissa_bits=*/hlo->mantissa_bits(), - ir_builder_); + /*mantissa_bits=*/hlo->mantissa_bits(), b_); } -static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* ir_builder, +static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* shift_result, bool saturate_to_sign_bit) { @@ -1106,15 +1089,14 @@ static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* ir_builder, llvm::ConstantInt* minus_one = llvm::ConstantInt::get(integer_type, -1); llvm::Value* saturated_value; if (saturate_to_sign_bit) { - saturated_value = ir_builder->CreateSelect( - ir_builder->CreateICmpSLT(lhs, zero), minus_one, zero); + saturated_value = + b->CreateSelect(b->CreateICmpSLT(lhs, zero), minus_one, zero); } else { saturated_value = zero; } llvm::Value* shift_amt_in_range = - ir_builder->CreateICmpULT(rhs, integer_bitsize_constant, "shft.chk"); - return ir_builder->CreateSelect(shift_amt_in_range, shift_result, - saturated_value); + b->CreateICmpULT(rhs, integer_bitsize_constant, "shft.chk"); + return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value); } StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( @@ -1123,47 +1105,49 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( switch (op->opcode()) { // TODO(jingyue): add the "nsw" attribute for signed types. case HloOpcode::kAdd: - return ir_builder_->CreateAdd(lhs_value, rhs_value); + return b_->CreateAdd(lhs_value, rhs_value); case HloOpcode::kSubtract: - return ir_builder_->CreateSub(lhs_value, rhs_value); + return b_->CreateSub(lhs_value, rhs_value); case HloOpcode::kMultiply: - return ir_builder_->CreateMul(lhs_value, rhs_value); + return b_->CreateMul(lhs_value, rhs_value); case HloOpcode::kDivide: - return is_signed ? ir_builder_->CreateSDiv(lhs_value, rhs_value) - : ir_builder_->CreateUDiv(lhs_value, rhs_value); + return is_signed ? b_->CreateSDiv(lhs_value, rhs_value) + : b_->CreateUDiv(lhs_value, rhs_value); case HloOpcode::kRemainder: - return is_signed ? ir_builder_->CreateSRem(lhs_value, rhs_value) - : ir_builder_->CreateURem(lhs_value, rhs_value); + return is_signed ? b_->CreateSRem(lhs_value, rhs_value) + : b_->CreateURem(lhs_value, rhs_value); case HloOpcode::kEq: return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value, - rhs_value, ir_builder_); + rhs_value, b_); case HloOpcode::kNe: return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value, - rhs_value, ir_builder_); + rhs_value, b_); case HloOpcode::kLt: return llvm_ir::EmitComparison( is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT, - lhs_value, rhs_value, ir_builder_); + lhs_value, rhs_value, b_); case HloOpcode::kGt: return llvm_ir::EmitComparison( is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT, - lhs_value, rhs_value, ir_builder_); + lhs_value, rhs_value, b_); case HloOpcode::kLe: return llvm_ir::EmitComparison( is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE, - lhs_value, rhs_value, ir_builder_); + lhs_value, rhs_value, b_); case HloOpcode::kGe: return llvm_ir::EmitComparison( is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE, - lhs_value, rhs_value, ir_builder_); + lhs_value, rhs_value, b_); case HloOpcode::kMinimum: return EmitIntegralMin(lhs_value, rhs_value, is_signed); case HloOpcode::kMaximum: return EmitIntegralMax(lhs_value, rhs_value, is_signed); case HloOpcode::kAnd: - return ir_builder_->CreateAnd(lhs_value, rhs_value); + return b_->CreateAnd(lhs_value, rhs_value); case HloOpcode::kOr: - return ir_builder_->CreateOr(lhs_value, rhs_value); + return b_->CreateOr(lhs_value, rhs_value); + case HloOpcode::kXor: + return b_->CreateXor(lhs_value, rhs_value); // Shifting out bits >= the number of bits in the type being shifted // produces a poison value in LLVM which is basically "deferred undefined @@ -1171,20 +1155,17 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( // UB. We replace the poison value with a constant to avoid this deferred // UB. case HloOpcode::kShiftRightArithmetic: - return SaturateShiftIfNecessary( - ir_builder_, lhs_value, rhs_value, - ir_builder_->CreateAShr(lhs_value, rhs_value), - /*saturate_to_sign_bit=*/true); + return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, + b_->CreateAShr(lhs_value, rhs_value), + /*saturate_to_sign_bit=*/true); case HloOpcode::kShiftLeft: - return SaturateShiftIfNecessary( - ir_builder_, lhs_value, rhs_value, - ir_builder_->CreateShl(lhs_value, rhs_value), - /*saturate_to_sign_bit=*/false); + return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, + b_->CreateShl(lhs_value, rhs_value), + /*saturate_to_sign_bit=*/false); case HloOpcode::kShiftRightLogical: - return SaturateShiftIfNecessary( - ir_builder_, lhs_value, rhs_value, - ir_builder_->CreateLShr(lhs_value, rhs_value), - /*saturate_to_sign_bit=*/false); + return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, + b_->CreateLShr(lhs_value, rhs_value), + /*saturate_to_sign_bit=*/false); default: return Unimplemented("binary integer op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -1194,21 +1175,19 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, bool is_signed) const { - return ir_builder_->CreateSelect( - ir_builder_->CreateICmp( - is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE, - lhs_value, rhs_value), - lhs_value, rhs_value); + return b_->CreateSelect(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE + : llvm::ICmpInst::ICMP_UGE, + lhs_value, rhs_value), + lhs_value, rhs_value); } llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, bool is_signed) const { - return ir_builder_->CreateSelect( - ir_builder_->CreateICmp( - is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE, - lhs_value, rhs_value), - lhs_value, rhs_value); + return b_->CreateSelect(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE + : llvm::ICmpInst::ICMP_ULE, + lhs_value, rhs_value), + lhs_value, rhs_value); } llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( @@ -1220,204 +1199,296 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( const Shape& operand_shape = hlo.operand(operand_no)->shape(); // If the operand is scalar, the source index is always {}. if (ShapeUtil::IsScalar(operand_shape)) { - return llvm_ir::IrArray::Index(); + return llvm_ir::IrArray::Index(target_index.GetType()); } // If no implicit broadcast is needed for this operand, returns the target // index as the source index. - if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape())) { + // + // `IrArray::Index` may contain a physical linear which we can propagate to + // our operand only if our layouts match. "only if" is a bit strong since + // e.g. we can still forward the linear index if the operand shape is + // [5,1,1,5]{3,2,1,0} and the HLO shape is[5,1,1,5]{3,1,2,0}, but those cases + // are probably not worth handling here for now. + if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape()) && + LayoutUtil::Equal(operand_shape.layout(), hlo.shape().layout())) { return target_index; } // If implicit broadcast is needed, the source dimensions that are broadcast // have index 0. CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape())); - llvm_ir::IrArray::Index source_index; + llvm_ir::IrArray::Index source_index(target_index.GetType()); for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) { if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) { source_index.push_back(target_index[i]); } else { CHECK_EQ(1, operand_shape.dimensions(i)); - source_index.push_back(ir_builder_->getInt64(0)); + source_index.push_back(target_index.GetConstantWithIndexType(0)); } } return source_index; } -llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( +StatusOr ElementalIrEmitter::ConvertValueForDistribution( const HloInstruction* hlo, - const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) - const { - PrimitiveType param_prim_type = hlo->operand(0)->shape().element_type(); - llvm::Type* param_ir_type = - llvm_ir::PrimitiveTypeToIrType(param_prim_type, module_); - - // Same values as PCG library - // https://github.com/imneme/pcg-c/blob/master/include/pcg_variants.h - llvm::Value* multiplier = ir_builder_->getInt( - llvm::APInt(128, {0x4385DF649FCCF645, 0x2360ED051FC65DA4})); - llvm::Value* increment = ir_builder_->getInt( - llvm::APInt(128, {0x14057B7EF767814F, 0x5851F42D4C957F2D})); - - auto random_value_from_hlo = [hlo]() { - const HloModule* module = - hlo->IsFused() ? hlo->parent()->FusionInstruction()->parent()->parent() - : hlo->parent()->parent(); - return module->RandomNew64(); - }; + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const { + TF_ASSIGN_OR_RETURN(llvm::Value * a_or_mean, + operand_to_generator.at(hlo->operand(0))(index)); + TF_ASSIGN_OR_RETURN(llvm::Value * b_or_sigma, + operand_to_generator.at(hlo->operand(1))(index)); + PrimitiveType elem_prim_ty = hlo->shape().element_type(); + llvm::Type* elem_ir_ty = + llvm_ir::PrimitiveTypeToIrType(elem_prim_ty, module_); + llvm::Type* raw_value_ty = raw_value->getType(); + + // Convert raw integer to float in range [0, 1) if the element is a float. + llvm::Value* elem_value = raw_value; + if (elem_ir_ty->isFloatingPointTy()) { + unsigned raw_value_size_in_bits = raw_value_ty->getPrimitiveSizeInBits(); + CHECK(raw_value_size_in_bits == 32 || raw_value_size_in_bits == 64); + // Perform the division using the float type with the same number of bits + // as the raw value to avoid overflow. + if (raw_value_size_in_bits == 32) { + elem_value = b_->CreateUIToFP(elem_value, b_->getFloatTy()); + elem_value = b_->CreateFDiv( + elem_value, llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32))); + } else { + elem_value = b_->CreateUIToFP(elem_value, b_->getDoubleTy()); + elem_value = b_->CreateFDiv( + elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64))); + } + + if (elem_ir_ty != elem_value->getType()) { + elem_value = b_->CreateFPTrunc(elem_value, elem_ir_ty); + } + } - // Seed each RNG emitter with a new 64-bit seed from the HloModule. If the - // compilation order is deterministic (i.e., RandomNew64 invocation order is - // deterministic), then the order of RNG is deterministic for a given seed and - // hence tests will be deterministic. - // If the user provides a global seed instruction then we only use 64-bits of - // the host's random number generator to seed the 128 bit value with the other - // 64-bits is due to a user specified global seed instruction. - // Create a GlobalVariable to maintain state between invocations. There is a - // bug in NVPTX with GlobalVariable and 128 bit values, so using 2 64-bit + // Convert the value for the requested distribution. + switch (hlo->random_distribution()) { + case RNG_UNIFORM: { + if (elem_ir_ty->isFloatingPointTy()) { + return b_->CreateFAdd( + b_->CreateFMul(b_->CreateFSub(b_or_sigma, a_or_mean), elem_value), + a_or_mean); + } else { + // To generate a uniform random value in [a, b) from a raw random sample + // in range [0, 2^N), we let range = b - a and return + // (a + raw_value % range). If range is not a power of 2, raw values + // larger than (2^N - 2^N % range) are biased toward results in + // [a, a + (limit % range)). An unbiased algorithm would need to drop + // raw values and re-sample, but we don't do this because re-sampling in + // an efficient way is complex, and it's not clear that users need it. + // In particular, if one thread in a GPU warp needs to re-sample, we pay + // the same cost as if the whole warp were to re-sample. So an + // efficient re-sampling implementation on GPU would need to do + // nontrivial work to share entropy between threads in the warp. + auto range = b_->CreateSub(b_or_sigma, a_or_mean); + return b_->CreateAdd(a_or_mean, b_->CreateURem(elem_value, range)); + } + } + case RNG_NORMAL: { + TF_ASSIGN_OR_RETURN( + llvm::Value * r, + EmitErfcInv(elem_prim_ty, + b_->CreateFMul(llvm::ConstantFP::get(elem_ir_ty, 2.0), + elem_value))); + return b_->CreateFAdd(b_->CreateFMul(r, b_or_sigma), a_or_mean); + } + default: + return InvalidArgument( + "unhandled distribution %s", + RandomDistribution_Name(hlo->random_distribution()).c_str()); + } +} + +namespace { + +// Checks that the primitive type is supported by the elemental IR emitter for +// Philox RNG and returns the number of elements in each 128 bit sample of the +// Philox RNG algorithm. +int32 GetNumberOfElementsPerPhiloxRngSample(PrimitiveType elem_prim_ty) { + // Calculate the number of elements, that is the number of random numbers, in + // a 128 bit sample. + switch (elem_prim_ty) { + case U32: + case S32: + case F32: + // The algorithm uses 32 bits to generate values for F16. + case F16: + return 4; + case U64: + case S64: + case F64: + return 2; + default: + // BF16 is converted to F16 by the hlo pass HloElementTypeConverter. + // Other data types are not supported by XLA random operation. + LOG(FATAL) << "Unrecognized primitive type for RNG " << elem_prim_ty; + } + return 0; +} + +// Calculates the four uint32 values for the 128-bit Philox sample. +std::array CalculateSampleValues( + llvm::Value* sample_idx, llvm::Value* hlo_random_value, + llvm::Value* global_random_number, llvm::Value* rng_state, + llvm::IRBuilder<>* b) { + llvm::Type* index_ty = sample_idx->getType(); + + std::array counter_values; + + // Use the sample index to initialize counter[0] and counter[1]. + unsigned index_ty_size_in_bits = index_ty->getPrimitiveSizeInBits(); + CHECK(index_ty_size_in_bits == 32 || index_ty_size_in_bits == 64); + if (index_ty_size_in_bits == 32) { + counter_values[0] = sample_idx; + counter_values[1] = b->getInt32(0); + } else { + std::tie(counter_values[0], counter_values[1]) = + llvm_ir::SplitInt64ToInt32s(b, sample_idx); + } + + // Xor the global state variable with the global random number seed and use + // the result to initialize counter[2] and counter[3]. + std::tie(counter_values[2], counter_values[3]) = llvm_ir::SplitInt64ToInt32s( + b, b->CreateXor(rng_state, global_random_number)); + + // The algorithm uses a 64 bit key, which is also interpreted as two uint32 // values. - llvm::GlobalVariable* state_ptr0 = new llvm::GlobalVariable( - /*M=*/*module_, - /*Ty=*/ir_builder_->getInt64Ty(), - /*isConstant=*/false, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/ir_builder_->getInt64(random_value_from_hlo()), - /*Name=*/"state_ptr0"); - - // When the module config seed is 0, the expected result of a prng is a random - // value. Instead of using the random_value_from_hlo, we need a global random - // value as the graph seed. This is because if we use random_value_from_hlo - // here, then for a newly built hlo graph, it always gives the same number. - uint64 graph_seed = hlo_module_config_.seed() != 0 ? hlo_module_config_.seed() - : GlobalRandomValue(); - llvm::GlobalVariable* state_ptr1 = new llvm::GlobalVariable( - /*M=*/*module_, - /*Ty=*/ir_builder_->getInt64Ty(), - /*isConstant=*/false, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/ir_builder_->getInt64(graph_seed), - /*Name=*/"state_ptr1"); - - // We want each thread to use its own stream, so we modify the increment per - // thread. We want the increment to remain odd, so we shift the thread id left - // 1 and add it to the increment. - increment = ir_builder_->CreateAdd(increment, - ir_builder_->CreateShl(EmitThreadId(), 1)); - - // PCG-XSL-RR algorithm - // http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf - // state = multiplier * state + increment - // return uint64_t(state ^ (state >> 64))) >>> (state >> 122) - // where ">>>" is bitwise rotation - auto get_next_i64 = [=]() { - llvm::Value* state0 = ir_builder_->CreateZExtOrTrunc( - ir_builder_->CreateLoad(state_ptr0, "state0"), - ir_builder_->getInt128Ty()); - llvm::Value* state1 = ir_builder_->CreateShl( - ir_builder_->CreateZExtOrTrunc( - ir_builder_->CreateLoad(state_ptr1, "state1"), - ir_builder_->getInt128Ty()), - 64); - llvm::Value* state = ir_builder_->CreateOr(state0, state1); - llvm::Value* updated = ir_builder_->CreateAdd( - ir_builder_->CreateMul(state, multiplier), increment); - ir_builder_->CreateStore( - ir_builder_->CreateTrunc(updated, ir_builder_->getInt64Ty()), - state_ptr0); - ir_builder_->CreateStore( - ir_builder_->CreateTrunc(ir_builder_->CreateLShr(updated, 64), - ir_builder_->getInt64Ty()), - state_ptr1); - - return llvm_ir::CreateRor( - ir_builder_->CreateTrunc( - ir_builder_->CreateXor(state, ir_builder_->CreateLShr(state, 64)), - ir_builder_->getInt64Ty()), - ir_builder_->CreateTrunc(ir_builder_->CreateLShr(state, 122), - ir_builder_->getInt64Ty()), - ir_builder_); - }; + llvm::Value* key_values[2]; + + // Use a module random number to initialize the key. + std::tie(key_values[0], key_values[1]) = + llvm_ir::SplitInt64ToInt32s(b, hlo_random_value); + + // Prepare the constants used in the Philox RNG Algorithm. + llvm::Value* philoxW32A = b->getInt32(0x9E3779B9); + llvm::Value* philoxW32B = b->getInt32(0xBB67AE85); + llvm::Value* philoxM4xW32A = b->getInt32(0xD2511F53); + llvm::Value* philoxM4xW32B = b->getInt32(0xCD9E8D57); + + // Compute the 128 bit value for the current sample by repeating the + // single round computation and key raising computation for ten times. + for (int round = 0; round < 10; ++round) { + // A single round of computation of the counter values is as follows: + // MultiplyHighLow(kPhiloxM4x32A, counter[0], &lo0, &hi0); + // MultiplyHighLow(kPhiloxM4x32B, counter[2], &lo1, &hi1); + // counter[0] = hi1 ^ counter[1] ^ key[0]; + // counter[1] = lo1; + // counter[2] = hi0 ^ counter[3] ^ key[1]; + // counter[3] = lo0; + llvm::Value* lo0; + llvm::Value* hi0; + std::tie(lo0, hi0) = + llvm_ir::UMulLowHigh32(b, philoxM4xW32A, counter_values[0]); + llvm::Value* lo1; + llvm::Value* hi1; + std::tie(lo1, hi1) = + llvm_ir::UMulLowHigh32(b, philoxM4xW32B, counter_values[2]); + counter_values[0] = + b->CreateXor(hi1, b->CreateXor(counter_values[1], key_values[0])); + counter_values[1] = lo1; + counter_values[2] = + b->CreateXor(hi0, b->CreateXor(counter_values[3], key_values[1])); + counter_values[3] = lo0; + key_values[0] = b->CreateAdd(key_values[0], philoxW32A); + key_values[1] = b->CreateAdd(key_values[1], philoxW32B); + } - auto get_next_uniform_float = [=]() { - return ir_builder_->CreateFDiv( - ir_builder_->CreateUIToFP(get_next_i64(), param_ir_type), - llvm::ConstantFP::get(param_ir_type, 0x1p64)); - }; + return counter_values; +} +} // namespace + +// Implements the Philox algorithm to generate random numbers in parallel. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +// +// The paper presents a few variants of the Philox algorithm, we picked the +// 4x32_10 version of the algorithm for the following reasons: +// . 4x32 uses 32-bit multiplication which is fast on GPUs. +// . The authors recommend the 10-round variant, and TensorFlow also uses it. +// +// Precondition: the RNG instruction is not fused. +llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) + const { + VLOG(3) << "Using philox RNG algorithm"; + CHECK(!hlo->IsFused()); + // A random number generated by the per module random number generator. + // This ensures that each RNG HLO generates a different random sequence. + llvm::Value* hlo_random_value = b_->getInt64(hlo->GetModule()->RandomNew64()); + // A value specified by the configuration or generated by a global random + // number generator. + llvm::Value* global_random_number = + b_->getInt64(hlo_module_config_.seed() != 0 ? hlo_module_config_.seed() + : GlobalRandomValue()); + + int elems_per_sample = + GetNumberOfElementsPerPhiloxRngSample(hlo->shape().element_type()); + + // Allocate stack storage for the 128 bit sample as four int32. + llvm::Type* int32_ty = b_->getInt32Ty(); + llvm::Value* sample_address = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + int32_ty, /*element_count=*/b_->getInt32(4), "sample", b_); + + // Load the global state variable for the Philox RNG algorithm. + llvm::GlobalVariable* rng_state_ptr = + llvm_ir::GetOrCreateVariableForPhiloxRngState(module_, b_); + llvm::Value* rng_state = b_->CreateLoad(rng_state_ptr, "rng_state_value"); + + // Build and return the elemental IR generator to generate a random value for + // the element corresponding to the current thread. + // + // This elemental IR generator computes one sample with multiple random + // numbers but only returns one random number. As a result, neighboring + // threads may calculate the same sample unnecessarily. However, if the + // kernel containing the RNG hlo is unrolled, LLVM is able to optimize away + // the duplicated computation of the same sample. In particular, if the unroll + // factor is a multiplier of elems_per_sample, LLVM is able to completely + // remove such duplicated computation. If the unroll factor is a non-trivial + // factor of elems_per_sample, LLVM can only partially remove such duplicated + // computation. return [=](const llvm_ir::IrArray::Index& index) -> StatusOr { - switch (hlo->random_distribution()) { - case RNG_UNIFORM: { - TF_ASSIGN_OR_RETURN(llvm::Value * p, - operand_to_generator.at(hlo->operand(0))(index)); - TF_ASSIGN_OR_RETURN(llvm::Value * q, - operand_to_generator.at(hlo->operand(1))(index)); - if (primitive_util::IsFloatingPointType(param_prim_type)) { - return ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(ir_builder_->CreateFSub(q, p), - get_next_uniform_float()), - p); - } else { - auto r = ir_builder_->CreateSub(q, p); - auto leading_zeros = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(true)}, - {param_ir_type}, ir_builder_); - auto in_block = ir_builder_->GetInsertBlock(); - - // A terminator should be present iff we're emitting code - // into the middle (as opposed to the end) of a basic block. - CHECK_EQ(ir_builder_->GetInsertPoint() == in_block->end(), - in_block->getTerminator() == nullptr); - - llvm::BasicBlock* body_block; - llvm::BasicBlock* out_block; - - if (ir_builder_->GetInsertPoint() == in_block->end()) { - body_block = llvm_ir::CreateBasicBlock( - nullptr, IrName(hlo, "rng_body"), ir_builder_); - out_block = llvm_ir::CreateBasicBlock( - nullptr, IrName(hlo, "rng_out"), ir_builder_); - llvm::BranchInst::Create(body_block, in_block); - } else { - body_block = in_block->splitBasicBlock( - ir_builder_->GetInsertPoint(), "rng_body"); - out_block = body_block->splitBasicBlock( - ir_builder_->GetInsertPoint(), "rng_out"); - body_block->getTerminator()->eraseFromParent(); - } - - SetToFirstInsertPoint(body_block, ir_builder_); - auto random = ir_builder_->CreateAnd( - ir_builder_->CreateZExtOrTrunc(get_next_i64(), param_ir_type), - ir_builder_->CreateLShr(llvm::ConstantInt::get(param_ir_type, ~0), - leading_zeros)); - llvm::BranchInst::Create(out_block, body_block, - ir_builder_->CreateICmpULT(random, r), - body_block); - SetToFirstInsertPoint(out_block, ir_builder_); - return ir_builder_->CreateAdd( - p, ir_builder_->CreateSelect( - ir_builder_->CreateICmpEQ(p, q), - llvm::ConstantInt::get(param_ir_type, 0), random)); - } - } - case RNG_NORMAL: { - TF_ASSIGN_OR_RETURN(llvm::Value * m, - operand_to_generator.at(hlo->operand(0))(index)); - TF_ASSIGN_OR_RETURN(llvm::Value * s, - operand_to_generator.at(hlo->operand(1))(index)); - TF_ASSIGN_OR_RETURN( - llvm::Value * r, - EmitErfcInv(param_prim_type, - ir_builder_->CreateFMul( - llvm::ConstantFP::get(param_ir_type, 2.0), - get_next_uniform_float()))); - return ir_builder_->CreateFAdd(ir_builder_->CreateFMul(r, s), m); - } - default: - return InvalidArgument( - "unhandled distribution %s", - RandomDistribution_Name(hlo->random_distribution()).c_str()); + llvm::Type* index_ty = index.GetType(); + // Calculate the linear element index. + llvm::Value* elem_idx = index.linear(); + if (elem_idx == nullptr) { + elem_idx = index.Linearize(AsInt64Slice(hlo->shape().dimensions()), b_); } + + // Calculate the index for the 128 bit sample and the offset of the current + // element within the sample. + llvm::Value* elems_per_sample_value = + llvm::ConstantInt::get(index_ty, elems_per_sample); + llvm::Value* sample_idx = b_->CreateUDiv(elem_idx, elems_per_sample_value); + llvm::Value* elem_offset = b_->CreateURem(elem_idx, elems_per_sample_value); + + std::array counter_values = CalculateSampleValues( + sample_idx, hlo_random_value, global_random_number, rng_state, b_); + + // Store the four counter_values into the sample_address alloca so we can + // load the elem_offset'th one below. + for (int idx = 0; idx < 4; ++idx) { + b_->CreateStore(counter_values[idx], + b_->CreateInBoundsGEP(sample_address, b_->getInt32(idx))); + } + + llvm::Type* int64_ty = b_->getInt64Ty(); + CHECK(elems_per_sample == 2 || elems_per_sample == 4); + llvm::Type* raw_value_ty = elems_per_sample == 2 ? int64_ty : int32_ty; + // Retrieve the raw value for the current element from the current sample. + llvm::Value* raw_elem_value = b_->CreateLoad( + b_->CreateInBoundsGEP( + b_->CreatePointerCast(sample_address, raw_value_ty->getPointerTo()), + elem_offset), + "raw_elem_value"); + + return ConvertValueForDistribution(hlo, operand_to_generator, index, + raw_elem_value); }; } @@ -1434,9 +1505,8 @@ StatusOr ElementalIrEmitter::EmitElementalSelect( TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, operand_to_generator.at(hlo->operand(2))( ElementwiseSourceIndex(index, *hlo, 2))); - return ir_builder_->CreateSelect( - ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()), - on_true_value, on_false_value); + return b_->CreateSelect(b_->CreateTrunc(pred_value, b_->getInt1Ty()), + on_true_value, on_false_value); } StatusOr ElementalIrEmitter::EmitElementalClamp( @@ -1472,64 +1542,62 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( const int64 concat_dim = hlo->dimensions(0); auto source_index = target_index; - llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock(); + llvm::BasicBlock* init_block = b_->GetInsertBlock(); // A terminator should be present iff we're emitting code // into the middle (as opposed to the end) of a basic block. - CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(), + CHECK_EQ(b_->GetInsertPoint() == init_block->end(), init_block->getTerminator() == nullptr); llvm::BasicBlock* exit_block; - if (ir_builder_->GetInsertPoint() == init_block->end()) { + if (b_->GetInsertPoint() == init_block->end()) { exit_block = llvm_ir::CreateBasicBlock( - /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_); + /*insert_before=*/nullptr, IrName(hlo, "merge"), b_); } else { - exit_block = init_block->splitBasicBlock(ir_builder_->GetInsertPoint(), + exit_block = init_block->splitBasicBlock(b_->GetInsertPoint(), AsStringRef(IrName(hlo, "merge"))); init_block->getTerminator()->eraseFromParent(); } - llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_); - llvm::PHINode* output = ir_builder_->CreatePHI( + llvm_ir::SetToFirstInsertPoint(exit_block, b_); + llvm::PHINode* output = b_->CreatePHI( llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), hlo->operands().size()); - auto prior_insert_point = ir_builder_->GetInsertPoint(); + auto prior_insert_point = b_->GetInsertPoint(); - ir_builder_->SetInsertPoint(init_block); + b_->SetInsertPoint(init_block); for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); ++operand_idx) { const HloInstruction* operand = hlo->operand(operand_idx); auto true_block = llvm_ir::CreateBasicBlock( - exit_block, StrCat("concat_index_from_operand", operand_idx), - ir_builder_); + exit_block, StrCat("concat_index_from_operand", operand_idx), b_); auto false_block = llvm_ir::CreateBasicBlock( - exit_block, StrCat("concat_index_not_from_operand", operand_idx), - ir_builder_); + exit_block, StrCat("concat_index_not_from_operand", operand_idx), b_); auto concat_dim_size = llvm::ConstantInt::get(source_index[concat_dim]->getType(), operand->shape().dimensions(concat_dim)); - ir_builder_->CreateCondBr( - ir_builder_->CreateICmpULT(source_index[concat_dim], concat_dim_size), + b_->CreateCondBr( + b_->CreateICmpULT(source_index[concat_dim], concat_dim_size), true_block, false_block); // Create the terminator of the true block before calling operand // generators, because they require non-degenerate basic blocks. - ir_builder_->SetInsertPoint( + b_->SetInsertPoint( llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block)); TF_ASSIGN_OR_RETURN(llvm::Value * value, operand_to_generator.at(operand)(source_index)); - output->addIncoming(value, ir_builder_->GetInsertBlock()); + output->addIncoming(value, b_->GetInsertBlock()); // Subtract the size of the concat dimension of the current operand // from the source index. - ir_builder_->SetInsertPoint(false_block); + b_->SetInsertPoint(false_block); source_index[concat_dim] = - ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size); + b_->CreateSub(source_index[concat_dim], concat_dim_size); } - ir_builder_->CreateUnreachable(); - ir_builder_->SetInsertPoint(exit_block, prior_insert_point); + b_->CreateUnreachable(); + b_->SetInsertPoint(exit_block, prior_insert_point); return output; } @@ -1540,41 +1608,40 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( // Emit IR to read dynamic start indices from hlo->operand(1). const HloInstruction* input_hlo = hlo->operand(0); const int64 rank = ShapeUtil::Rank(input_hlo->shape()); - llvm_ir::IrArray::Index slice_start_index(rank); + // Use the same index type for all tensor accesses in the same kernel. + llvm::Type* index_type = index.GetType(); + llvm_ir::IrArray::Index slice_start_index(index_type, rank); for (int64 i = 0; i < rank; ++i) { - llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_type, c); + }; + llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, operand_to_generator.at(hlo->operand(1))(dim_index)); // Clamp the start index so that the sliced portion fits in the operand: // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) + start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type); + int64 largest_valid_start_index = + input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i); + CHECK_GE(largest_valid_start_index, 0); - // TODO(b/74360564): This is implementation defined behavior, but is - // currently respected by all implementations. Change this if we ever decide - // to oficially document different behavior. - start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value, - index[i]->getType()); - llvm::Value* operand_dim_size = llvm::ConstantInt::get( - start_index_value->getType(), input_hlo->shape().dimensions(i)); - llvm::Value* output_dim_size = llvm::ConstantInt::get( - start_index_value->getType(), hlo->shape().dimensions(i)); - + bool is_signed = ShapeUtil::ElementIsSigned(hlo->operand(1)->shape()); start_index_value = EmitIntegralMin( - ir_builder_->CreateSub(operand_dim_size, output_dim_size), - EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0), - start_index_value, /*is_signed=*/true), - /*is_signed=*/true); + index_typed_const(largest_valid_start_index), + EmitIntegralMax(index_typed_const(0), start_index_value, is_signed), + is_signed); start_index_value->setName( AsStringRef(IrName(hlo, StrCat("start_idx", i)))); slice_start_index[i] = start_index_value; } - llvm_ir::IrArray::Index input_index(rank); + llvm_ir::IrArray::Index input_index(index_type, rank); for (int64 i = 0; i < rank; ++i) { // Emit IR which computes: // input_index = start_index + offset_index - input_index[i] = ir_builder_->CreateAdd(slice_start_index[i], index[i]); + input_index[i] = b_->CreateAdd(slice_start_index[i], index[i]); } return operand_to_generator.at(input_hlo)(input_index); } @@ -1594,25 +1661,29 @@ StatusOr ElementalIrEmitter::EmitElementalGather( const llvm_ir::ElementGenerator& indices_generator = operand_to_generator.at(hlo->operand(1)); + llvm::Type* index_type = index.GetType(); // This is the index into `operand` that holds the element we want to - // generate. This index "unsafe" as in the components in here may be - // out of bounds. - IrArray::Index unsafe_operand_index; - - // First copy in the window indices to unsafe_operand_index. - for (int64 i = 0, e = operand_shape.dimensions_size(), - unsafe_operand_index_dim = 0; + // generate. + IrArray::Index operand_index(index_type); + + // First copy in the window indices to operand_index. Also collect a mapping + // from operand dimension to output window dimension. Elided window dimensions + // map to -1. + std::vector operand_to_output_dim(operand_shape.dimensions_size(), -1); + for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0; i < e; i++) { if (c_binary_search(dim_numbers.elided_window_dims(), i)) { - unsafe_operand_index.push_back(ir_builder_->getInt64(0)); + operand_index.push_back(index.GetConstantWithIndexType(0)); } else { - unsafe_operand_index.push_back( - index[dim_numbers.output_window_dims(unsafe_operand_index_dim++)]); + int64 output_window_dim = + dim_numbers.output_window_dims(operand_index_dim++); + operand_to_output_dim[i] = output_window_dim; + operand_index.push_back(index[output_window_dim]); } } // This is the index of the index vector in the gather_indices tensor. - IrArray::Index gather_index_index; + IrArray::Index gather_index_index(index_type); { std::vector gather_index_index_components; for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) { @@ -1626,40 +1697,52 @@ StatusOr ElementalIrEmitter::EmitElementalGather( } } - auto add_to_unsafe_operand_index = [&](llvm::Value* index_component, - int64 dim) { - llvm::Value* gather_dim_component_extended = ir_builder_->CreateSExtOrTrunc( - index_component, ir_builder_->getInt64Ty()); - unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] = - ir_builder_->CreateAdd( - unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)], - gather_dim_component_extended); + auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) { + llvm::Value* gather_dim_component_extended = + b_->CreateSExtOrTrunc(index_component, index_type); + int64 operand_dim = dim_numbers.gather_dims_to_operand_dims(dim); + int64 output_dim = operand_to_output_dim[operand_dim]; + // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim. + // This means we set the iteration index to 0, so for the purpose of the + // following calculations we can consider the output dimension size to be 1. + int64 output_dim_size = + output_dim == -1 ? 1 : output_shape.dimensions(output_dim); + int64 largest_valid_start_index = + operand_shape.dimensions(operand_dim) - output_dim_size; + CHECK_GE(largest_valid_start_index, 0); + + // Clamp the gather index so that the gather region fits in the operand. + // gather_dim_component_extended_inbound = + // clamp(gather_dim_component_extended, 0, largest_valid_start_index); + + // TODO(b/111078873): This is implementation defined behavior. + bool is_signed = ShapeUtil::ElementIsSigned(indices_shape); + auto gather_dim_component_extended_inbound = EmitIntegralMin( + index.GetConstantWithIndexType(largest_valid_start_index), + EmitIntegralMax(index.GetConstantWithIndexType(0), + gather_dim_component_extended, is_signed), + is_signed); + + operand_index[operand_dim] = b_->CreateAdd( + operand_index[operand_dim], gather_dim_component_extended_inbound); }; if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, indices_generator(gather_index_index)); - add_to_unsafe_operand_index(gather_dim_component, 0); + add_to_operand_index(gather_dim_component, 0); } else { int64 index_vector_size = indices_shape.dimensions(dim_numbers.index_vector_dim()); for (int64 i = 0; i < index_vector_size; i++) { gather_index_index[dim_numbers.index_vector_dim()] = - ir_builder_->getInt64(i); + index.GetConstantWithIndexType(i); TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, indices_generator(gather_index_index)); - add_to_unsafe_operand_index(gather_dim_component, i); + add_to_operand_index(gather_dim_component, i); } } - - IrArray::Index safe_operand_index; - for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) { - safe_operand_index.push_back(ir_builder_->CreateURem( - unsafe_operand_index[i], - ir_builder_->getInt64(operand_shape.dimensions(i)))); - } - - return operand_generator(safe_operand_index); + return operand_generator(operand_index); } StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( @@ -1671,49 +1754,46 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( const HloInstruction* start_hlo = hlo->operand(2); // Calculate slice start/end indices. const int64 rank = ShapeUtil::Rank(input_hlo->shape()); - llvm_ir::IrArray::Index slice_start_index(rank); - llvm_ir::IrArray::Index slice_limit_index(rank); + llvm_ir::IrArray::Index slice_start_index(index.GetType(), rank); + llvm_ir::IrArray::Index slice_limit_index(index.GetType(), rank); // Slice intersection gathers (ANDs) conditions on all ranks for which // 'input' is set to 'update' - llvm::Value* slice_intersection = ir_builder_->getTrue(); + llvm::Value* slice_intersection = b_->getTrue(); for (int64 i = 0; i < rank; ++i) { - llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); + llvm::Type* index_type = index[0]->getType(); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_type, c); + }; + llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, operand_to_generator.at(start_hlo)(dim_index)); // Clamp the start index so that the update region fits in the operand. // start_index = clamp(start_index, 0, input_dim_size - update_dim_size) - - // TODO(b/74360564): This is implementation defined behavior, but is - // currently respected by all implementations. Change this if we ever decide - // to oficially document different behavior. - start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value, - index[i]->getType()); - llvm::Value* input_dim_size = llvm::ConstantInt::get( - index[i]->getType(), input_hlo->shape().dimensions(i)); - llvm::Value* update_dim_size = llvm::ConstantInt::get( - index[i]->getType(), update_hlo->shape().dimensions(i)); - + start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type); + llvm::Value* update_dim_size = + index_typed_const(update_hlo->shape().dimensions(i)); + int64 largest_valid_start_index = + input_hlo->shape().dimensions(i) - update_hlo->shape().dimensions(i); + CHECK_GE(largest_valid_start_index, 0); + + bool is_signed = ShapeUtil::ElementIsSigned(start_hlo->shape()); start_index_value = EmitIntegralMin( - ir_builder_->CreateSub(input_dim_size, update_dim_size), - EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0), - start_index_value, /*is_signed=*/true), - /*is_signed=*/true); + index_typed_const(largest_valid_start_index), + EmitIntegralMax(index_typed_const(0), start_index_value, is_signed), + is_signed); start_index_value->setName( AsStringRef(IrName(hlo, StrCat("start_idx", i)))); slice_start_index[i] = start_index_value; - slice_limit_index[i] = - ir_builder_->CreateAdd(slice_start_index[i], update_dim_size); + slice_limit_index[i] = b_->CreateAdd(slice_start_index[i], update_dim_size); - slice_intersection = ir_builder_->CreateAnd( - slice_intersection, - ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), + slice_intersection = b_->CreateAnd( + slice_intersection, b_->CreateICmpSGE(index[i], slice_start_index[i]), "slice_intersection"); - slice_intersection = ir_builder_->CreateAnd( - slice_intersection, - ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]), + slice_intersection = b_->CreateAnd( + slice_intersection, b_->CreateICmpSLT(index[i], slice_limit_index[i]), "slice_intersection"); } @@ -1722,29 +1802,29 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // else -> return data from 'input'. llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), - "ret_value_addr", ir_builder_); - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - slice_intersection, "slice_intersection", ir_builder_); + "ret_value_addr", b_); + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(slice_intersection, "slice_intersection", b_); // Handle true BB (return data from 'update') - SetToFirstInsertPoint(if_data.true_block, ir_builder_); + SetToFirstInsertPoint(if_data.true_block, b_); // Compute update index for intersection case. - llvm_ir::IrArray::Index update_index(rank); + llvm_ir::IrArray::Index update_index(index.GetType(), rank); for (int64 i = 0; i < rank; ++i) { - update_index[i] = ir_builder_->CreateSub(index[i], slice_start_index[i]); + update_index[i] = b_->CreateSub(index[i], slice_start_index[i]); } TF_ASSIGN_OR_RETURN(llvm::Value * true_value, operand_to_generator.at(update_hlo)(update_index)); - ir_builder_->CreateStore(true_value, ret_value_addr); + b_->CreateStore(true_value, ret_value_addr); // Handle false BB (return data from 'input') - SetToFirstInsertPoint(if_data.false_block, ir_builder_); + SetToFirstInsertPoint(if_data.false_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * false_value, operand_to_generator.at(input_hlo)(index)); - ir_builder_->CreateStore(false_value, ret_value_addr); + b_->CreateStore(false_value, ret_value_addr); - SetToFirstInsertPoint(if_data.after_block, ir_builder_); - return ir_builder_->CreateLoad(ret_value_addr); + SetToFirstInsertPoint(if_data.after_block, b_); + return b_->CreateLoad(ret_value_addr); } StatusOr ElementalIrEmitter::EmitElementalPad( @@ -1752,29 +1832,29 @@ StatusOr ElementalIrEmitter::EmitElementalPad( const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& padded_index) const { auto index = padded_index; - llvm::Value* in_bounds = ir_builder_->getTrue(); + llvm::Value* in_bounds = b_->getTrue(); for (size_t i = 0; i < index.size(); ++i) { auto index_typed_const = [=](int64 n) { return llvm::ConstantInt::get(index[i]->getType(), n); }; const auto& pad_dim = hlo->padding_config().dimensions(i); - index[i] = ir_builder_->CreateSub( - index[i], index_typed_const(pad_dim.edge_padding_low())); - in_bounds = ir_builder_->CreateAnd( - in_bounds, ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)), - "in_bounds"); - in_bounds = ir_builder_->CreateAnd( + index[i] = + b_->CreateSub(index[i], index_typed_const(pad_dim.edge_padding_low())); + in_bounds = b_->CreateAnd(in_bounds, + b_->CreateICmpSGE(index[i], index_typed_const(0)), + "in_bounds"); + in_bounds = b_->CreateAnd( in_bounds, - ir_builder_->CreateICmpEQ( + b_->CreateICmpEQ( index_typed_const(0), - ir_builder_->CreateURem( - index[i], index_typed_const(pad_dim.interior_padding() + 1))), + b_->CreateURem(index[i], + index_typed_const(pad_dim.interior_padding() + 1))), "in_bounds"); - index[i] = ir_builder_->CreateSDiv( + index[i] = b_->CreateSDiv( index[i], index_typed_const(pad_dim.interior_padding() + 1)); - in_bounds = ir_builder_->CreateAnd( + in_bounds = b_->CreateAnd( in_bounds, - ir_builder_->CreateICmpSLT( + b_->CreateICmpSLT( index[i], index_typed_const(hlo->operand(0)->shape().dimensions(i))), "in_bounds"); @@ -1787,25 +1867,26 @@ StatusOr ElementalIrEmitter::EmitElementalPad( // } llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), - "pad_result_addr", ir_builder_); + "pad_result_addr", b_); llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); - SetToFirstInsertPoint(if_data.true_block, ir_builder_); + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_); + SetToFirstInsertPoint(if_data.true_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, operand_to_generator.at(hlo->operand(0))(index)); - ir_builder_->CreateStore(operand_value, ret_value_addr); + b_->CreateStore(operand_value, ret_value_addr); - SetToFirstInsertPoint(if_data.false_block, ir_builder_); + SetToFirstInsertPoint(if_data.false_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, - operand_to_generator.at(hlo->operand(1))({})); - ir_builder_->CreateStore(padding_value, ret_value_addr); + operand_to_generator.at(hlo->operand(1))( + IrArray::Index(index.GetType()))); + b_->CreateStore(padding_value, ret_value_addr); - SetToFirstInsertPoint(if_data.after_block, ir_builder_); + SetToFirstInsertPoint(if_data.after_block, b_); // Don't create phi(operand_value, padding_value) here, because invoking // operand_to_generator may create new basic blocks, making the parent // of operand_value or padding_value no longer a predecessor of // if_data.after_block. - return ir_builder_->CreateLoad(ret_value_addr); + return b_->CreateLoad(ret_value_addr); } StatusOr ElementalIrEmitter::EmitElementalDot( @@ -1824,21 +1905,25 @@ StatusOr ElementalIrEmitter::EmitElementalDot( int64 lhs_dims = hlo->operand(0)->shape().dimensions_size(); int64 rhs_dims = hlo->operand(1)->shape().dimensions_size(); + llvm::Type* index_type = dot_result_index[0]->getType(); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_type, c); + }; + std::unique_ptr inner_loop = llvm_ir::ForLoop::EmitForLoop( - IrName(hlo, "inner"), ir_builder_->getInt64(0), - ir_builder_->getInt64(contracted_dim_size), ir_builder_->getInt64(1), - ir_builder_); + IrName(hlo, "inner"), index_typed_const(0), + index_typed_const(contracted_dim_size), index_typed_const(1), b_); - SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), ir_builder_); + SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), b_); PrimitiveType primitive_type = hlo->shape().element_type(); llvm::Type* primitive_type_llvm = llvm_ir::PrimitiveTypeToIrType(primitive_type, module_); - llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry( - primitive_type_llvm, "dot_acc", ir_builder_); - ir_builder_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm), - accumulator_alloca); + llvm::Value* accumulator_alloca = + llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_); + b_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm), + accumulator_alloca); - SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_); + SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_); // This is the inner reduction loop for a dot operation that produces // one element in the output. If the operands to the dot operation have @@ -1846,7 +1931,7 @@ StatusOr ElementalIrEmitter::EmitElementalDot( // Given an output index [a,b,c,d,e] in the result, we compute: // sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T)) - IrArray::Index lhs_index, rhs_index; + IrArray::Index lhs_index(index_type), rhs_index(index_type); for (int64 i = 0; i < lhs_dims - 1; i++) { lhs_index.push_back(dot_result_index[i]); @@ -1858,43 +1943,36 @@ StatusOr ElementalIrEmitter::EmitElementalDot( } rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue()); - llvm::Value* current_accumulator = - ir_builder_->CreateLoad(accumulator_alloca); + llvm::Value* current_accumulator = b_->CreateLoad(accumulator_alloca); TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); llvm::Value* next_accumulator; if (primitive_util::IsComplexType(primitive_type)) { - llvm::Value* product_real = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); - llvm::Value* product_imag = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value))); - next_accumulator = ir_builder_->CreateInsertValue( + llvm::Value* product_real = b_->CreateFSub( + b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); + llvm::Value* product_imag = b_->CreateFAdd( + b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)), + b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))); + next_accumulator = b_->CreateInsertValue( current_accumulator, - ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator), - product_real), + b_->CreateFAdd(EmitExtractReal(current_accumulator), product_real), {0}); - next_accumulator = ir_builder_->CreateInsertValue( + next_accumulator = b_->CreateInsertValue( next_accumulator, - ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator), - product_imag), + b_->CreateFAdd(EmitExtractImag(current_accumulator), product_imag), {1}); } else if (primitive_util::IsFloatingPointType(primitive_type)) { - next_accumulator = ir_builder_->CreateFAdd( - current_accumulator, ir_builder_->CreateFMul(lhs_value, rhs_value)); + next_accumulator = b_->CreateFAdd(current_accumulator, + b_->CreateFMul(lhs_value, rhs_value)); } else { - next_accumulator = ir_builder_->CreateAdd( - current_accumulator, ir_builder_->CreateMul(lhs_value, rhs_value)); + next_accumulator = + b_->CreateAdd(current_accumulator, b_->CreateMul(lhs_value, rhs_value)); } - ir_builder_->CreateStore(next_accumulator, accumulator_alloca); + b_->CreateStore(next_accumulator, accumulator_alloca); - SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_); - return ir_builder_->CreateLoad(accumulator_alloca); + SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_); + return b_->CreateLoad(accumulator_alloca); } llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( @@ -1945,6 +2023,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kMultiply: case HloOpcode::kNe: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kShiftLeft: @@ -1993,7 +2072,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); auto source_index = target_index; for (int64 dim : hlo->dimensions()) { - source_index[dim] = ir_builder_->CreateSub( + source_index[dim] = b_->CreateSub( llvm::ConstantInt::get(target_index[dim]->getType(), hlo->shape().dimensions(dim) - 1), target_index[dim]); @@ -2006,16 +2085,16 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); // The `dimensions` member of the broadcast instruction maps from // input dimensions to output dimensions. - return operand_to_generator.at( - operand)(target_index.SourceIndexOfBroadcast( - hlo->shape(), operand->shape(), hlo->dimensions(), ir_builder_)); + return operand_to_generator.at(operand)( + target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(), + hlo->dimensions(), b_)); }; case HloOpcode::kSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { IrArray::Index sliced_index = index.SourceIndexOfSlice( /*shape=*/hlo->shape(), /*starts=*/hlo->slice_starts(), - /*strides=*/hlo->slice_strides(), /*builder=*/ir_builder_); + /*strides=*/hlo->slice_strides(), /*builder=*/b_); return operand_to_generator.at(hlo->operand(0))(sliced_index); }; case HloOpcode::kDynamicSlice: @@ -2040,27 +2119,26 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( ShapeUtil::ElementsIn(hlo->operand(0)->shape())); return [this, hlo, &operand_to_generator](const IrArray::Index& index) { const HloInstruction* operand = hlo->operand(0); - return operand_to_generator.at(operand)(index.SourceIndexOfBitcast( - hlo->shape(), operand->shape(), ir_builder_)); + return operand_to_generator.at(operand)( + index.SourceIndexOfBitcast(hlo->shape(), operand->shape(), b_)); }; case HloOpcode::kReshape: CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()), ShapeUtil::ElementsIn(hlo->operand(0)->shape())); return [this, hlo, &operand_to_generator](const IrArray::Index& index) { const HloInstruction* operand = hlo->operand(0); - return operand_to_generator.at(operand)(index.SourceIndexOfReshape( - hlo->shape(), operand->shape(), ir_builder_)); + return operand_to_generator.at(operand)( + index.SourceIndexOfReshape(hlo->shape(), operand->shape(), b_)); }; case HloOpcode::kTranspose: return [this, hlo, &operand_to_generator](const IrArray::Index& target_index) { return operand_to_generator.at(hlo->operand(0))( target_index.SourceIndexOfTranspose( - hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions(), - ir_builder_)); + hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions(), b_)); }; case HloOpcode::kRng: - return MakeRngElementGenerator(hlo, operand_to_generator); + return MakePhiloxRngElementGenerator(hlo, operand_to_generator); case HloOpcode::kPad: return [this, hlo, &operand_to_generator]( const IrArray::Index& padded_index) -> StatusOr { @@ -2074,7 +2152,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return EmitElementalDot(hlo, operand_to_generator, dot_result_index); }; default: - return [this, hlo, &operand_to_generator](const IrArray::Index& index) { + return [hlo](const IrArray::Index& index) { return Unimplemented("Unhandled opcode for elemental IR emission: %s", HloOpcodeString(hlo->opcode()).c_str()); }; @@ -2082,11 +2160,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( } llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const { - return ir_builder_->CreateExtractValue(value, {0}); + return b_->CreateExtractValue(value, {0}); } llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const { - return ir_builder_->CreateExtractValue(value, {1}); + return b_->CreateExtractValue(value, {1}); } llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, @@ -2094,10 +2172,10 @@ llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, llvm::Value* imag) const { auto cplx_type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); - auto complex = ir_builder_->CreateInsertValue( + auto complex = b_->CreateInsertValue( llvm::ConstantAggregateZero::get(cplx_type), real, {0}); if (imag != nullptr) { - complex = ir_builder_->CreateInsertValue(complex, imag, {1}); + complex = b_->CreateInsertValue(complex, imag, {1}); } return complex; } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index d199473374ad394913413a7d3fe805f8782936f7..1598a4dd85632cfa9835a81a21eddff3e57bfa1f 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -34,10 +34,8 @@ class ElementalIrEmitter { std::unordered_map; ElementalIrEmitter(const HloModuleConfig& hlo_module_config, - llvm::Module* module, llvm::IRBuilder<>* ir_builder) - : ir_builder_(ir_builder), - module_(module), - hlo_module_config_(hlo_module_config) {} + llvm::Module* module, llvm::IRBuilder<>* b) + : b_(b), module_(module), hlo_module_config_(hlo_module_config) {} virtual ~ElementalIrEmitter() = default; @@ -54,7 +52,7 @@ class ElementalIrEmitter { const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator) const; - llvm::IRBuilder<>* ir_builder() const { return ir_builder_; } + llvm::IRBuilder<>* b() const { return b_; } llvm::Module* module() const { return module_; } protected: @@ -124,6 +122,9 @@ class ElementalIrEmitter { llvm::Value* lhs, llvm::Value* rhs) const; + virtual StatusOr EmitTanh(PrimitiveType prim_type, + llvm::Value* value) const; + virtual StatusOr EmitReducePrecision(const HloInstruction* hlo, llvm::Value* x) const; @@ -144,9 +145,7 @@ class ElementalIrEmitter { int64 operand_no) const; // Identifier of the thread unique among all threads on the device - virtual llvm::Value* EmitThreadId() const { - return ir_builder_->getIntN(128, 0); - } + virtual llvm::Value* EmitThreadId() const { return b_->getIntN(128, 0); } StatusOr EmitElementalSelect( const HloInstruction* hlo, @@ -188,7 +187,7 @@ class ElementalIrEmitter { const HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& dot_result_index) const; - llvm::IRBuilder<>* const ir_builder_; + llvm::IRBuilder<>* const b_; llvm::Module* module_; @@ -197,10 +196,17 @@ class ElementalIrEmitter { const HloModuleConfig& hlo_module_config_; private: - // Returns a ElementGenerator for a RNG HloInstruction. - llvm_ir::ElementGenerator MakeRngElementGenerator( + // Returns a ElementGenerator for an RNG HloInstruction using the Philox + // random number generation algorithm. + llvm_ir::ElementGenerator MakePhiloxRngElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator) const; + // Converts the raw value generated by a random number generation algorithm + // to the distribution requested by the RNG HloInstruction. + StatusOr ConvertValueForDistribution( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc index 8980d4303353a132ada2b3c685b4f2856c33c6a1..addb016b0481b744ff42ba827104099b6cdc3bb9 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -57,8 +57,8 @@ ENTRY main { } )"; - std::unique_ptr lhs = Literal::CreateR3({{{1}, {2}}}); - std::unique_ptr rhs = Literal::CreateR3({{{3}, {4}}}); + std::unique_ptr lhs = LiteralUtil::CreateR3({{{1}, {2}}}); + std::unique_ptr rhs = LiteralUtil::CreateR3({{{3}, {4}}}); RunTest(hlo_text, {lhs.get(), rhs.get()}); } } // namespace diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 6df172db8e541c5cef7aab04f3d8611fc735e8b0..fd75847d0c0e737957401b8efc420d504a3c0706 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -82,7 +82,18 @@ StatusOr Executable::ExecuteOnStreamWrapper( StatusOr return_value = ExecuteOnStream(run_options, arguments, profile_ptr.get()); - TF_RETURN_IF_ERROR(return_value.status()); + if (!return_value.status().ok()) { + if (profile != nullptr) { + // Ensure the ThenStartTimer call has completed before we destroy timer. + // We already have a failure status to return, so just log this if it + // fails. + Status status = stream->BlockHostUntilDone(); + if (!status.ok()) { + LOG(ERROR) << "Failed to BlockHostUntilDone: " << status; + } + } + return return_value.status(); + } if (profile != nullptr) { VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; @@ -116,6 +127,11 @@ StatusOr Executable::ExecuteOnStreamWrapper( if (profile->compute_time_ns() == 0) { profile->set_compute_time_ns(profile->compute_and_transfer_time_ns()); } + + const int64 executable_size_in_bytes = SizeInBytes(); + if (executable_size_in_bytes != 0) { + profile->set_executable_size_in_bytes(executable_size_in_bytes); + } } if (profile_ptr != nullptr) { @@ -129,6 +145,8 @@ StatusOr Executable::ExecuteOnStreamWrapper( return return_value; } +int64 Executable::SizeInBytes() { return -1; } + Status Executable::DumpHloSnapshot() { TF_RET_CHECK(dumping_snapshot()); TF_RET_CHECK(hlo_snapshot_->has_hlo() && diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index dc1f26ea65cc707d4f0522af2aa3ec40621632f1..98eaeee30a693211ae564a5ef3c373f0364bef11 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -88,8 +88,7 @@ class Executable { // called explicitly for other (async, for example) variants after the stream // has completed. virtual Status PopulateExecutionProfile( - HloExecutionProfile* hlo_execution_profile, - se::StreamExecutor* executor) { + HloExecutionProfile* hlo_execution_profile, se::Stream* stream) { return Status::OK(); } @@ -132,10 +131,14 @@ class Executable { // The shape (including layout) that results from this execution. This is the // shape of the DeviceMemoryBase result value in ExecuteOnStream above. - const Shape& host_result_shape() const { - return hlo_module_->config().host_entry_computation_layout().result_shape(); + const Shape& result_shape() const { + return hlo_module_->config().entry_computation_layout().result_shape(); } + // Returns the size of the executable in bytes. Returns -1 by default if the + // method is not overridden to support this kind of query. + virtual int64 SizeInBytes(); + // Dumping helpers. void set_hlo_snapshot(std::unique_ptr hlo_snapshot) { hlo_snapshot_ = std::move(hlo_snapshot); diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index 6794cfe297b0fb9a15eb9b7e6906d225f9597d07..228c3fac95c3114484637bd93ec51c60b44403cc 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -25,7 +25,7 @@ limitations under the License. namespace xla { AsyncExecution::AsyncExecution(Backend* backend, - std::vector streams, + std::vector streams, const ExecutionProfile& profile, GlobalDataHandle result) : backend_(CHECK_NOTNULL(backend)), @@ -46,9 +46,10 @@ Status AsyncExecution::BlockUntilDone() const { ExecutionTracker::ExecutionTracker() : next_handle_(1) {} -ExecutionHandle ExecutionTracker::Register( - Backend* backend, std::vector streams, - const ExecutionProfile& profile, GlobalDataHandle result) { +ExecutionHandle ExecutionTracker::Register(Backend* backend, + std::vector streams, + const ExecutionProfile& profile, + GlobalDataHandle result) { tensorflow::mutex_lock lock(execution_mutex_); int64 handle = next_handle_++; auto inserted = handle_to_execution_.emplace( diff --git a/tensorflow/compiler/xla/service/execution_tracker.h b/tensorflow/compiler/xla/service/execution_tracker.h index 4458152dd9a98890fc3a3e7f324245ec68821467..4e9b9f883e26f5564a9c63a40d2b4b9348908214 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.h +++ b/tensorflow/compiler/xla/service/execution_tracker.h @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/backend.h" -#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -40,7 +40,7 @@ namespace xla { // the stream when destructed. class AsyncExecution { public: - AsyncExecution(Backend* backend, std::vector streams, + AsyncExecution(Backend* backend, std::vector streams, const ExecutionProfile& profile, GlobalDataHandle result); Status BlockUntilDone() const; @@ -54,7 +54,7 @@ class AsyncExecution { Backend* backend_; // Stream on which the execution is launched. - std::vector streams_; + std::vector streams_; // Profile object of the execution to be returned to the user. ExecutionProfile profile_; @@ -72,7 +72,7 @@ class ExecutionTracker { // Registers an execution with its backend, streams, and data handle to the // execution result. Returns a handle for the registered execution. ExecutionHandle Register(Backend* backend, - std::vector stream, + std::vector stream, const ExecutionProfile& profile, GlobalDataHandle data); diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index d3854b40de3572a60df1ad99d8a4589f59ad7194..8f6608241ed02bbb7e9fde9b6d767c002435e777 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/flatten_call_graph.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -80,7 +80,7 @@ class FlattenCallGraphTest : public HloTestBase { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); return builder.Build(); @@ -157,7 +157,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(PRED, {}), "param0")); HloInstruction* false_constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); builder.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, param0, false_constant)); @@ -168,7 +168,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { { HloComputation::Builder builder(TestName() + ".entry"); HloInstruction* false_constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); builder.AddInstruction(HloInstruction::CreateWhile( ShapeUtil::MakeShape(PRED, {}), cond_computation, cond_computation, false_constant)); @@ -232,11 +232,11 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { // computation in the true and false branch. HloComputation::Builder builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(56.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(12.0f))); builder.AddInstruction(HloInstruction::CreateConditional( kScalarShape, pred, constant1, sub_computation, constant2, sub_computation)); diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 2d3e4b1fcdf6675955714cab262a8b2ca8ff4297..e3a42d0d06be9e4c9ef96ed2e6ff5daa8eebaf3e 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -113,7 +114,7 @@ static StatusOr ExpandIndexVectorIntoOperandSpace( const Shape& index_shape = index_vector->shape(); HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateFromDimensions(index_shape.element_type(), {1}))); + LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); // We extract out individual components from the smaller index and concatenate // them (interspersing zeros as needed) into the larger index. @@ -300,7 +301,7 @@ static StatusOr PermuteGatherAndWindowDims( StatusOr GatherExpander::ExpandGather( HloInstruction* gather_instr) { - CHECK(!ShapeUtil::HasZeroElements(gather_instr->shape())); + CHECK(!ShapeUtil::IsZeroElementArray(gather_instr->shape())); HloComputation* computation = gather_instr->parent(); HloInstruction* operand = gather_instr->mutable_operand(0); @@ -369,7 +370,7 @@ StatusOr GatherExpander::Run(HloModule* module) { return inst->opcode() == HloOpcode::kGather && // Avoid expanding gather ops that produce zero sized tensors, // instead punt these to ZeroSizedHloElimination. - !ShapeUtil::HasZeroElements(inst->shape()); + !ShapeUtil::IsZeroElementArray(inst->shape()); }; std::vector gather_instrs; diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 5ee67ccb4ae147683c7b41941670c6fc413a0d09..0ce2db907b643f3beabd127388370dbe601179e1 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -20,11 +20,10 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -43,7 +42,7 @@ se::Platform::Id GenericTransferManager::PlatformId() const { } Status GenericTransferManager::WriteSingleTupleIndexTable( - se::StreamExecutor* executor, + se::Stream* stream, tensorflow::gtl::ArraySlice elements, const Shape& shape, se::DeviceMemoryBase* region) { TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape)); @@ -52,13 +51,27 @@ Status GenericTransferManager::WriteSingleTupleIndexTable( for (const se::DeviceMemoryBase& element : elements) { element_pointers.push_back(element.opaque()); } - return TransferBufferToDevice(executor, GetByteSizeRequirement(shape), - element_pointers.data(), region); + TF_RETURN_IF_ERROR(TransferBufferToDevice( + stream, GetByteSizeRequirement(shape), element_pointers.data(), region)); + // Ensure the buffer is transferred before we destroy element_pointers. + return stream->BlockHostUntilDone(); } -StatusOr> -GenericTransferManager::TransferLiteralFromDevice( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { +void GenericTransferManager::TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, std::function done) { + Status status = stream->BlockHostUntilDone(); + if (!status.ok()) { + return done(status); + } + + done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer, + literal)); +} + +Status GenericTransferManager::TransferLiteralFromDeviceInternal( + se::StreamExecutor* executor, const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal) { VLOG(2) << "transferring literal from device ordinal " << executor->device_ordinal() << "; device buffer: " << device_buffer; TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); @@ -68,28 +81,24 @@ GenericTransferManager::TransferLiteralFromDevice( TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(), device_buffer.on_host_shape())); - std::unique_ptr literal = - Literal::CreateFromShape(device_buffer.on_host_shape()); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_host_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { - if (!ShapeUtil::IsTuple(subshape)) { - TF_RETURN_IF_ERROR(TransferBufferFromDevice( - executor, + if (ShapeUtil::IsArray(subshape)) { + TF_RETURN_IF_ERROR(executor->SynchronousMemcpyD2H( /*source=*/device_buffer.buffer(index), /*size=*/GetByteSizeRequirement(subshape), /*destination=*/ - literal->untyped_data(index))); + literal.untyped_data(index))); } return Status::OK(); })); - return std::move(literal); + return Status::OK(); } -Status GenericTransferManager::TransferLiteralToDevice( - se::StreamExecutor* executor, const LiteralSlice& literal, +Status GenericTransferManager::TransferLiteralToDeviceAsync( + se::Stream* stream, const LiteralSlice& literal, const ShapedBuffer& device_buffer) { const Shape& shape = literal.shape(); VLOG(2) << "transferring literal shape to device: " @@ -103,9 +112,10 @@ Status GenericTransferManager::TransferLiteralToDevice( TF_RET_CHECK( ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape())); - TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); + TF_RET_CHECK(stream->parent()->device_ordinal() == + device_buffer.device_ordinal()); - TF_RETURN_IF_ERROR(WriteTupleIndexTables(executor, device_buffer)); + TF_RETURN_IF_ERROR(WriteTupleIndexTables(stream, device_buffer)); return ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_host_shape(), @@ -121,16 +131,21 @@ Status GenericTransferManager::TransferLiteralToDevice( if (LayoutUtil::Equal(device_subshape.layout(), subliteral.shape().layout())) { source = subliteral.untyped_data(); + return TransferBufferToDevice( + stream, + /*size=*/GetByteSizeRequirement(device_subshape), source, + &device_memory); } else { // Relayout data before transferring. relayed_out_literal = subliteral.Relayout(device_subshape.layout(), /*shape_index=*/{}); source = relayed_out_literal->untyped_data(); + TF_RETURN_IF_ERROR(TransferBufferToDevice( + stream, + /*size=*/GetByteSizeRequirement(device_subshape), source, + &device_memory)); + return stream->BlockHostUntilDone(); } - return TransferBufferToDevice( - executor, - /*size=*/GetByteSizeRequirement(device_subshape), source, - &device_memory); } return Status::OK(); }); @@ -141,16 +156,10 @@ Status GenericTransferManager::TransferLiteralToInfeed( return Unimplemented("Generic transfer to Infeed"); } -Status GenericTransferManager::TransferBufferToInfeed( - se::StreamExecutor* executor, int64 size, const void* source) { - return Unimplemented("Generic transfer to Infeed"); -} - Status GenericTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) { - return Unimplemented( - "Outfeed is not supported on this platform (b/30467474)"); + MutableBorrowingLiteral literal) { + return Unimplemented("Generic transfer from Outfeed"); } Status GenericTransferManager::ResetDevices( diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 3da9570ef7eebcdf618439f628fb4d5589993e4f..6c1a21587a7ef5199afb93715dc57be5139fbc22 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/transfer_manager.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -41,18 +40,20 @@ class GenericTransferManager : public TransferManager { se::Platform::Id PlatformId() const override; - StatusOr> TransferLiteralFromDevice( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer) override; + void TransferLiteralFromDevice(se::Stream* stream, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, + std::function done) override; - Status TransferLiteralToDevice(se::StreamExecutor* executor, - const LiteralSlice& literal, - const ShapedBuffer& device_buffer) override; + Status TransferLiteralToDeviceAsync( + se::Stream* stream, const LiteralSlice& literal, + const ShapedBuffer& device_buffer) override; Status TransferLiteralToInfeed(se::StreamExecutor* executor, const LiteralSlice& literal) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + MutableBorrowingLiteral literal) override; Status ResetDevices( tensorflow::gtl::ArraySlice executors) override; @@ -60,15 +61,16 @@ class GenericTransferManager : public TransferManager { int64 GetByteSizeRequirement(const Shape& shape) const override; protected: - Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, - const void* source) override; - Status WriteSingleTupleIndexTable( - se::StreamExecutor* executor, + se::Stream* stream, tensorflow::gtl::ArraySlice elements, const Shape& shape, se::DeviceMemoryBase* region) override; private: + Status TransferLiteralFromDeviceInternal(se::StreamExecutor* executor, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal); + // The platform this transfer manager targets. const se::Platform::Id platform_id_; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 5e02631a5856a45762027d246144d6f6dd913c26..bacd2c1f146385dac0ec5978c5f8ec9f463cf550 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -36,6 +36,7 @@ cc_library( hdrs = ["gpu_constants.h"], deps = [ "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework", ], ) @@ -113,11 +114,13 @@ cc_library( srcs = ["hlo_to_ir_bindings.cc"], hdrs = ["hlo_to_ir_bindings.h"], deps = [ + ":buffer_allocations", ":ir_emission_utils", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", @@ -141,6 +144,7 @@ cc_library( ], deps = [ ":backend_configs", + ":buffer_allocations", ":cudnn_convolution_runner", ":elemental_ir_emitter", ":gpu_constants", @@ -149,8 +153,7 @@ cc_library( ":ir_emission_utils", ":parallel_loop_emitter", ":partition_assignment", - ":while_transformer", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -162,13 +165,17 @@ cc_library( "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:name_uniquer", + "//tensorflow/compiler/xla/service:while_loop_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", + "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_tiling", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", - "//tensorflow/compiler/xla/service/llvm_ir:ops", + "//tensorflow/compiler/xla/service/llvm_ir:sort_util", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -199,7 +206,7 @@ cc_library( srcs = ["elemental_ir_emitter.cc"], hdrs = ["elemental_ir_emitter.h"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -214,6 +221,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", + "//tensorflow/compiler/xla/service/llvm_ir:math_ops", "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", @@ -237,6 +245,20 @@ cc_library( ], ) +cc_library( + name = "hlo_execution_profiler", + srcs = ["hlo_execution_profiler.cc"], + hdrs = ["hlo_execution_profiler.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_execution_profile", + "//tensorflow/compiler/xla/service:stream_pool", + "//tensorflow/core:lib", + "//tensorflow/core:ptr_util", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + cc_library( name = "gpu_executable", srcs = [ @@ -251,7 +273,9 @@ cc_library( "infeed_thunk.cc", "kernel_thunk.cc", "memset_thunk.cc", + "outfeed_thunk.cc", "sequential_thunk.cc", + "thunk.cc", "thunk_schedule.cc", "tuple_thunk.cc", "while_thunk.cc", @@ -268,6 +292,7 @@ cc_library( "infeed_thunk.h", "kernel_thunk.h", "memset_thunk.h", + "outfeed_thunk.h", "sequential_thunk.h", "thunk.h", "thunk_schedule.h", @@ -275,14 +300,16 @@ cc_library( "while_thunk.h", ], deps = [ - ":backend_configs", ":buffer_allocations", ":cudnn_convolution_runner", + ":hlo_execution_profiler", ":infeed_manager", ":ir_emission_utils", + ":outfeed_manager", ":partition_assignment", ":stream_assignment", "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -300,6 +327,7 @@ cc_library( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service:tuple_points_to_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", @@ -336,6 +364,7 @@ cc_library( ":cudnn_convolution_runner", ":gpu_executable", ":ir_emission_utils", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", @@ -367,7 +396,7 @@ cc_library( hdrs = ["cudnn_convolution_rewriter.h"], deps = [ ":ir_emission_utils", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -428,6 +457,8 @@ cc_library( srcs = ["multi_output_fusion.cc"], hdrs = ["multi_output_fusion.h"], deps = [ + ":instruction_fusion", + ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:multi_output_fusion", @@ -439,6 +470,7 @@ tf_cc_test( name = "multi_output_fusion_test", srcs = ["multi_output_fusion_test.cc"], deps = [ + ":instruction_fusion", ":multi_output_fusion", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", @@ -501,6 +533,7 @@ cc_library( hdrs = ["pad_insertion.h"], deps = [ ":ir_emission_utils", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", @@ -511,13 +544,48 @@ cc_library( ], ) +cc_library( + name = "pad_for_tensor_cores", + srcs = ["pad_for_tensor_cores.cc"], + hdrs = ["pad_for_tensor_cores.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_creation_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:shape_inference", + ], +) + +tf_cc_test( + name = "pad_for_tensor_cores_test", + srcs = ["pad_for_tensor_cores_test.cc"], + deps = [ + ":ir_emission_utils", + ":pad_for_tensor_cores", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep + ], +) + cc_library( name = "gpu_transfer_manager", srcs = ["gpu_transfer_manager.cc"], hdrs = ["gpu_transfer_manager.h"], deps = [ ":gpu_compiler", + ":outfeed_manager", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -536,8 +604,8 @@ cc_library( cc_library( name = "gpu_compiler", - srcs = ["gpu_compiler.cc"], - hdrs = ["gpu_compiler.h"], + srcs = ["nvptx_compiler.cc"], + hdrs = ["nvptx_compiler.h"], deps = [ ":cudnn_convolution_algorithm_picker", ":cudnn_convolution_rewriter", @@ -552,9 +620,11 @@ cc_library( ":ir_emission_utils", ":ir_emitter", ":multi_output_fusion", + ":pad_for_tensor_cores", ":pad_insertion", ":partition_assignment", ":stream_assignment", + ":stream_executor_util", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -566,10 +636,8 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", - "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", - "//tensorflow/compiler/xla/service:gather_expander", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", @@ -584,10 +652,10 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", - "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter", @@ -599,7 +667,6 @@ cc_library( "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:core", - "@llvm//:support", ], alwayslink = True, # Contains compiler registration ) @@ -610,24 +677,46 @@ cc_library( hdrs = ["cudnn_batchnorm_rewriter.h"], deps = [ ":ir_emission_utils", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", ], ) +cc_library( + name = "xfeed_queue", + hdrs = ["xfeed_queue.h"], + deps = ["//tensorflow/core:lib"], +) + cc_library( name = "infeed_manager", srcs = ["infeed_manager.cc"], hdrs = ["infeed_manager.h"], deps = [ + ":xfeed_queue", + "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], ) +cc_library( + name = "outfeed_manager", + srcs = ["outfeed_manager.cc"], + hdrs = ["outfeed_manager.h"], + deps = [ + ":xfeed_queue", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + cc_library( name = "gpu_layout_assignment", srcs = ["gpu_layout_assignment.cc"], @@ -659,6 +748,8 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep ], @@ -697,32 +788,17 @@ tf_cc_test( ], ) -cc_library( - name = "while_transformer", - srcs = ["while_transformer.cc"], - hdrs = ["while_transformer.h"], - deps = [ - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:lib", - ], -) - tf_cc_test( name = "while_transformer_test", srcs = ["while_transformer_test.cc"], deps = [ ":instruction_fusion", - ":while_transformer", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:copy_insertion", "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:while_loop_analysis", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -757,6 +833,8 @@ cc_library( hdrs = ["stream_executor_util.h"], deps = [ "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:stream_executor_no_cuda", ], diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index ab5149dcdb09290cd0c0b2233029d0988a95f036..537295292b6ced72c4b2c456557b3c06e0aa5254 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -44,17 +44,27 @@ StatusOr> BufferAllocations::Builder::Build( num_buffers, device_ordinal, memory_allocator, buffer_assignment)); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { + const BufferAllocation& allocation = buffer_assignment->GetAllocation(i); + const int64 expected_alignment = [&] { + if (allocation.is_entry_computation_parameter()) { + return kEntryParameterAlignBytes; + } else if (allocation.is_constant()) { + return kConstantBufferAlignBytes; + } else { + return kXlaAllocatedBufferAlignBytes; + } + }(); + // If buffer #i's address is already registered (e.g. external arguments or // result buffers), use that registered buffer. if (registered_buffers_.count(i)) { se::DeviceMemoryBase address = FindOrDie(registered_buffers_, i); - if (reinterpret_cast(address.opaque()) % - kCudaMallocAlignBytes != + if (reinterpret_cast(address.opaque()) % expected_alignment != 0) { return InternalError( "Address of registered buffer %lld must be a multiple of %llx, but " "was %p", - i, kCudaMallocAlignBytes, address.opaque()); + i, kEntryParameterAlignBytes, address.opaque()); } buffer_allocations->SetBuffer(i, FindOrDie(registered_buffers_, i)); continue; @@ -62,7 +72,6 @@ StatusOr> BufferAllocations::Builder::Build( // Allocate each allocation that might escape, or is the temp buffer. bool seen_temp_buffer = false; - const BufferAllocation& allocation = buffer_assignment->GetAllocation(i); if (allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer()) { const int64 buffer_size = allocation.size(); se::DeviceMemoryBase buffer_address; @@ -70,13 +79,12 @@ StatusOr> BufferAllocations::Builder::Build( OwningDeviceMemory buffer; TF_ASSIGN_OR_RETURN( buffer, memory_allocator->Allocate(device_ordinal, buffer_size)); - if (reinterpret_cast(buffer.opaque()) % - kCudaMallocAlignBytes != + if (reinterpret_cast(buffer.opaque()) % expected_alignment != 0) { return InternalError( "Address returned by memory_allocator->Allocate must be a " "multiple of %llx, but was %p", - kCudaMallocAlignBytes, buffer.opaque()); + kXlaAllocatedBufferAlignBytes, buffer.opaque()); } // We do manual memory management within BufferAllocations. Be sure not // to do a TF_RETURN_IF_ERROR between this line and the @@ -165,5 +173,10 @@ void BufferAllocations::SetBuffer(BufferAllocation::Index buffer_index, buffers_[buffer_index] = buffer; } +bool ShouldEmitLiteralInLlvmIr(const Literal& literal) { + // LLVM can sometimes do interesting optimizations using scalar constants. + return ShapeUtil::IsScalar(literal.shape()); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index 636623502597b3a66523938ba430e9d5a82f796c..f13eab0dd787a2bfa687c991f9d808568360fd24 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -107,6 +107,12 @@ class BufferAllocations { bool torn_down_ = false; }; +// LLVM and PTXAS don't deal well with large constants, so we only emit very +// small constants directly in LLVM IR. Larger constants are emitted with zero +// initializers in LLVM IR and are later overwritten when the PTX/CUBIN is +// loaded. +bool ShouldEmitLiteralInLlvmIr(const Literal& literal); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index 77a48965e031349b045a956fd3f28c58607328e5..5780e0af40699bb6ac2c190c09cd02023fb44db7 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -32,8 +33,11 @@ ConditionalThunk::ConditionalThunk( predicate_buffer_index_(predicate_buffer_index), true_operand_buffer_index_(true_operand_buffer_index), false_operand_buffer_index_(false_operand_buffer_index), - true_thunk_(std::move(true_thunk_sequence), hlo), - false_thunk_(std::move(false_thunk_sequence), hlo) {} + // Pass nullptr as the HloInstruction* to the true_thunk_ and false_thunk_ + // constructors because these SequentialThunks are logically "part of" + // this ConditionalThunk, and shouldn't be profiled separately from it. + true_thunk_(std::move(true_thunk_sequence), nullptr), + false_thunk_(std::move(false_thunk_sequence), nullptr) {} Status ConditionalThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { @@ -43,7 +47,9 @@ Status ConditionalThunk::Initialize(const GpuExecutable& executable, } Status ConditionalThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); // Copy the predicate value from device. bool predicate; se::DeviceMemoryBase predicate_address = @@ -59,10 +65,15 @@ Status ConditionalThunk::ExecuteOnStream( // Execute the true or the false computation depending on the value of the // predicate. if (predicate) { - TF_RETURN_IF_ERROR(true_thunk_.ExecuteOnStream(buffer_allocations, stream)); + profiler->StartHloComputation(); + TF_RETURN_IF_ERROR( + true_thunk_.ExecuteOnStream(buffer_allocations, stream, profiler)); + profiler->FinishHloComputation(hlo_instruction()->true_computation()); } else { + profiler->StartHloComputation(); TF_RETURN_IF_ERROR( - false_thunk_.ExecuteOnStream(buffer_allocations, stream)); + false_thunk_.ExecuteOnStream(buffer_allocations, stream, profiler)); + profiler->FinishHloComputation(hlo_instruction()->false_computation()); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h index ee03865d174469285a9e98b8a30fea90d997df37..aef24342c9fe182eb54b1c2beff840a76e7b8115 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_ #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -50,7 +51,8 @@ class ConditionalThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: BufferAllocation::Slice predicate_buffer_index_; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index f0881124128c9b043392ffc4fa3aee2cd5b754c7..7833a4077e6c6ee4960665f37fb01a35530fd302 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -55,7 +56,8 @@ ConvolutionThunk::ConvolutionThunk( tensor_ops_enabled_(tensor_ops_enabled) {} Status ConvolutionThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { se::DeviceMemoryBase input_data = buffer_allocations.GetDeviceAddress(input_buffer_); se::DeviceMemoryBase filter_data = @@ -68,6 +70,7 @@ Status ConvolutionThunk::ExecuteOnStream( se::dnn::AlgorithmConfig algorithm_config( se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_)); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); TF_RETURN_IF_ERROR(RunCudnnConvolution( convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data, filter_data, output_data, scratch, window_, dim_nums_, algorithm_config, diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 6d845025b1aef2b0a5f147401b6db0598ba94d6d..d76ca6698dcf462c3c4961ce6a9784822af3a81f 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" @@ -66,7 +67,8 @@ class ConvolutionThunk : public Thunk { // Does the convolution for the thunk on "stream". Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: class ScratchAllocator; diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc index ee38c0318a878c7bcdc02afdcd146bfb4498d9a2..92e03f94c11f68082f0a8caa64f82e8533557194 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -30,9 +31,11 @@ HostToDeviceCopyThunk::HostToDeviceCopyThunk( mem_size_(mem_size) {} Status HostToDeviceCopyThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenMemcpy(&destination_data, source_address_, mem_size_); return Status::OK(); } @@ -47,11 +50,13 @@ DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( mem_size_(mem_size) {} Status DeviceToDeviceCopyThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); se::DeviceMemoryBase source_data = buffer_allocations.GetDeviceAddress(source_buffer_); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenMemcpy(&destination_data, source_data, mem_size_); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h index 8b128386f61636de9ac41e856a2b00c578e05735..91564b520acae1839e0a466cf580db00bdf57e46 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -40,7 +41,8 @@ class HostToDeviceCopyThunk : public Thunk { HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const void* source_address_; @@ -63,7 +65,8 @@ class DeviceToDeviceCopyThunk : public Thunk { DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const BufferAllocation::Slice source_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc index db6924c742e4a949a3e939b6d6659e92c2d1e312..60289506524759580dbb9b82147c78c4ce1cb25e 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -66,11 +67,12 @@ Status Visitor::HandleBatchNormInference(HloInstruction* batch_norm) { return Status::OK(); } - HloInstruction* epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction* epsilon = + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(batch_norm->epsilon()))); HloInstruction* feature_index = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(batch_norm->feature_index()))); + LiteralUtil::CreateR0(batch_norm->feature_index()))); std::vector operands(batch_norm->operands().begin(), batch_norm->operands().end()); @@ -101,11 +103,12 @@ Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) { return Status::OK(); } - HloInstruction* epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction* epsilon = + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(batch_norm->epsilon()))); HloInstruction* feature_index = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(batch_norm->feature_index()))); + LiteralUtil::CreateR0(batch_norm->feature_index()))); std::vector operands(batch_norm->operands().begin(), batch_norm->operands().end()); @@ -126,12 +129,17 @@ Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) { HloInstruction* variance_plus_epsilon = computation_->AddInstruction(HloInstruction::CreateBinary( inverse_stddev->shape(), HloOpcode::kPower, inverse_stddev, - computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-2))))); + computation_->AddInstruction(HloInstruction::CreateBroadcast( + inverse_stddev->shape(), + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(-2))), + {})))); HloInstruction* variance = computation_->AddInstruction(HloInstruction::CreateBinary( variance_plus_epsilon->shape(), HloOpcode::kSubtract, - variance_plus_epsilon, epsilon)); + variance_plus_epsilon, + computation_->AddInstruction(HloInstruction::CreateBroadcast( + variance_plus_epsilon->shape(), epsilon, {})))); // Repackage the results. std::unique_ptr new_tuple = HloInstruction::CreateTuple({ @@ -164,23 +172,29 @@ Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) { return Status::OK(); } - HloInstruction* epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction* epsilon = + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(batch_norm->epsilon()))); HloInstruction* feature_index = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(batch_norm->feature_index()))); + LiteralUtil::CreateR0(batch_norm->feature_index()))); // The cudnn libcall expects its input to be rsqrt(variance + epsilon), but // the batchnorm HLO takes plain variance as input. Fix it up. HloInstruction* var_plus_epsilon = computation_->AddInstruction(HloInstruction::CreateBinary( batch_norm->operand(3)->shape(), HloOpcode::kAdd, - batch_norm->mutable_operand(3), epsilon)); + batch_norm->mutable_operand(3), + computation_->AddInstruction(HloInstruction::CreateBroadcast( + batch_norm->operand(3)->shape(), epsilon, {})))); HloInstruction* inverse_stddev = computation_->AddInstruction(HloInstruction::CreateBinary( var_plus_epsilon->shape(), HloOpcode::kPower, var_plus_epsilon, - computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-.5))))); + computation_->AddInstruction(HloInstruction::CreateBroadcast( + var_plus_epsilon->shape(), + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(-.5))), + {})))); std::vector operands(batch_norm->operands().begin(), batch_norm->operands().end()); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index 68099fd63847ef9993f9bc7ac0e28b2939631b35..7b172812c36bb141787ef3a9285d6f7ce13e343b 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -99,13 +100,15 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk( } Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { dnn::BatchDescriptor operand_desc; dnn::BatchDescriptor scale_offset_desc; std::tie(operand_desc, scale_offset_desc) = MakeDescriptors(hlo_instruction()->shape(), feature_index_); se::DeviceMemory output(buffer_allocations.GetDeviceAddress(output_)); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenBatchNormalizationForward( se::DeviceMemory(buffer_allocations.GetDeviceAddress(operand_)), se::DeviceMemory(buffer_allocations.GetDeviceAddress(scale_)), @@ -123,6 +126,7 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( /*is_training=*/false, // /*var_to_inv_var=*/nullptr, // /*inv_var_to_var=*/nullptr); + if (!stream->ok()) { return InternalError("BatchNormalizationForward call failed."); } @@ -158,7 +162,8 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( } Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { dnn::BatchDescriptor operand_desc; dnn::BatchDescriptor scale_offset_desc; // The BatchNormTraining HLO outputs a tuple of three elements: output data, @@ -175,6 +180,7 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( buffer_allocations.GetDeviceAddress(output_inv_stddev_)); se::DeviceMemory null_device_ptr(nullptr); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenBatchNormalizationForward( se::DeviceMemory(buffer_allocations.GetDeviceAddress(operand_)), se::DeviceMemory(buffer_allocations.GetDeviceAddress(scale_)), @@ -240,7 +246,8 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( } Status CudnnBatchNormBackwardThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { dnn::BatchDescriptor operand_desc; dnn::BatchDescriptor scale_offset_desc; @@ -257,6 +264,7 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream( se::DeviceMemory output_grad_offset( buffer_allocations.GetDeviceAddress(output_grad_offset_)); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenBatchNormalizationBackward( se::DeviceMemory( buffer_allocations.GetDeviceAddress(grad_output_)), diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h index 874f85a863092ee05ae5df1f92d732318c5a0554..d2143b3952984722d136757255aa0aa60e9cab7e 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" @@ -60,7 +61,8 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk { const CudnnBatchNormForwardInferenceThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: BufferAllocation::Slice operand_; @@ -90,7 +92,8 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk { const CudnnBatchNormForwardTrainingThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: BufferAllocation::Slice operand_; @@ -123,7 +126,8 @@ class CudnnBatchNormBackwardThunk : public Thunk { delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: BufferAllocation::Slice operand_; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 3dc98c4c93ea2b9b68dd3ee27794a39847f8756c..7348307ec8a7286dfb733d6b9685862b20f11ac9 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -14,12 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/mutex.h" namespace xla { namespace gpu { @@ -80,8 +82,7 @@ bool ShouldIncludeWinogradNonfusedAlgo(const Shape& input_shape, const ConvolutionDimensionNumbers& dnums, se::StreamExecutor* stream_exec) { // Skip this check for cudnn7 and newer. - auto version = - stream_exec->AsDnn()->GetVersion(); + auto version = stream_exec->AsDnn()->GetVersion(); if (version.ok() && version.ValueOrDie().major_version() >= 7) { return true; } @@ -137,6 +138,28 @@ string NumBytesToString(int64 bytes) { tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)"); } +// Acquires a process-global lock on the device pointed to by the given +// StreamExecutor. +// +// This is used to prevent other XLA instances from trying to autotune on this +// device while we're using it. +tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + // se::Platform*s are global singletons guaranteed to live forever. + static auto* mutexes = + new std::map, + tensorflow::mutex>(); + + tensorflow::mutex_lock global_lock(mu); + auto it = mutexes + ->emplace(std::piecewise_construct, + std::make_tuple(stream_exec->platform(), + stream_exec->device_ordinal()), + std::make_tuple()) + .first; + return tensorflow::mutex_lock{it->second}; +} + } // anonymous namespace // We could have caching here so that we don't redo this work for two identical @@ -155,6 +178,13 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) { + // Don't run this function concurrently on the same GPU. + // + // This is a bit of a hack and doesn't protect us against arbitrary concurrent + // use of a GPU, but it's sufficient to let us compile two HLO modules + // concurrently and then run them sequentially. + tensorflow::mutex_lock lock = LockGpu(stream_exec_); + // Create a stream for us to do our work on. se::Stream stream{stream_exec_}; stream.Init(); @@ -338,8 +368,8 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( computation->AddInstruction(HloInstruction::CreateTuple( {computation->AddInstruction(HloInstruction::CreateGetTupleElement( new_call_shape.tuple_shapes(0), new_call, 0)), - computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({})))})); + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({})))})); TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple)); return true; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index e0c73aa73acb7f3313eb54fb07390cb76590433e..905b5ee8767d0fa0514c7f1abf83bc089cd08045 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -42,8 +42,8 @@ bool CanImplementAsCudnnForwardConv(HloInstruction* conv) { } // CuDNN does not accept zero-element arguments - if (ShapeUtil::HasZeroElements(conv->operand(0)->shape()) || - ShapeUtil::HasZeroElements(conv->operand(1)->shape())) { + if (ShapeUtil::IsZeroElementArray(conv->operand(0)->shape()) || + ShapeUtil::IsZeroElementArray(conv->operand(1)->shape())) { return false; } diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index b812dd7d3fbb25f279e87f79b647e299f29073ea..9b6de115ad7e7f87e431f839c1690858f4bce3fd 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -29,12 +29,13 @@ limitations under the License. #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -67,8 +68,8 @@ bool IsFPLiteralWithValue(const HloInstruction* operand, float value) { GpuElementalIrEmitter::GpuElementalIrEmitter( const HloModuleConfig& hlo_module_config, llvm::Module* module, - llvm::IRBuilder<>* ir_builder, NestedComputer compute_nested) - : ElementalIrEmitter(hlo_module_config, module, ir_builder), + llvm::IRBuilder<>* b, NestedComputer compute_nested) + : ElementalIrEmitter(hlo_module_config, module, b), hlo_module_config_(hlo_module_config), compute_nested_(std::move(compute_nested)) {} @@ -92,8 +93,8 @@ StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( cast_result_to_fp16 = true; for (int64 i = 0; i < operands.size(); ++i) { if (input_types[i] == F16) { - converted_operands[i] = ir_builder_->CreateFPCast( - converted_operands[i], ir_builder_->getFloatTy()); + converted_operands[i] = + b_->CreateFPCast(converted_operands[i], b_->getFloatTy()); converted_input_types[i] = F32; } } @@ -112,7 +113,7 @@ StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( converted_input_types, output_type) .ValueOrDie(); if (cast_result_to_fp16) { - result = ir_builder_->CreateFPCast(result, ir_builder_->getHalfTy()); + result = b_->CreateFPCast(result, b_->getHalfTy()); } return result; } @@ -209,13 +210,15 @@ StatusOr GpuElementalIrEmitter::EmitPowerOp( return make_sqrt(); } - if (hlo_module_config_.debug_options().xla_enable_fast_math() && - IsFPLiteralWithValue(rhs, -.5)) { + if (IsFPLiteralWithValue(rhs, -.5)) { VLOG(10) << "emitting pow(A, -.5) as 1/sqrt(A): " << op->ToString(); // LLVM's NVPTX backend knows how to transform 1/sqrt(A) into the NVPTX // rsqrt.approx instruction. + // + // TODO(jlebar): Does this happen with fastmath disabled? If not, should + // we force-enable it? TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt()); - return ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); + return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); } VLOG(10) << "emitting pow as regular call to pow(): " << op->ToString(); @@ -271,17 +274,20 @@ StatusOr GpuElementalIrEmitter::EmitAtan2( prim_type); } -StatusOr GpuElementalIrEmitter::EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { - PrimitiveType input_type = op->operand(0)->shape().element_type(); - PrimitiveType output_type = op->shape().element_type(); - switch (op->opcode()) { - case HloOpcode::kTanh: - return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type}, - output_type); - default: - return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value); - } +StatusOr GpuElementalIrEmitter::EmitTanh( + PrimitiveType prim_type, llvm::Value* value) const { + // Emit a fast approximation of tanh instead of calling __nv_tanh. + // __nv_tanh is particularly bad because it contains branches, thus + // preventing LLVM's load-store vectorizer from working its magic across a + // function which contains tanh calls. + // + // This routine isn't numerically precise, but it's good enough for ML. + + // Upcast F16 to F32 if necessary. + llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType(); + llvm::Value* input = b_->CreateFPCast(value, type); + llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); + return b_->CreateFPCast(fast_tanh, value->getType()); } llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( @@ -302,32 +308,31 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( // Declares the callee if it is not declared already. llvm::Function* callee = llvm::cast( - ir_builder_->GetInsertBlock()->getModule()->getOrInsertFunction( + b_->GetInsertBlock()->getModule()->getOrInsertFunction( llvm_ir::AsStringRef(callee_name), callee_type)); for (auto attribute : attributes) { callee->addFnAttr(attribute); } - return ir_builder_->CreateCall(callee, llvm_ir::AsArrayRef(operands)); + return b_->CreateCall(callee, llvm_ir::AsArrayRef(operands)); } llvm::Value* GpuElementalIrEmitter::EmitThreadId() const { - llvm::Value* block_id = ir_builder_->CreateIntCast( + llvm::Value* block_id = b_->CreateIntCast( llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, - {}, {}, ir_builder_), - ir_builder_->getIntNTy(128), /*isSigned=*/true, "block.id"); - llvm::Value* thread_id_in_block = ir_builder_->CreateIntCast( + {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "block.id"); + llvm::Value* thread_id_in_block = b_->CreateIntCast( llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, - {}, {}, ir_builder_), - ir_builder_->getIntNTy(128), /*isSigned=*/true, "thread.id"); - llvm::Value* threads_per_block = ir_builder_->CreateIntCast( + {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "thread.id"); + llvm::Value* threads_per_block = b_->CreateIntCast( llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, - {}, {}, ir_builder_), - ir_builder_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); - return ir_builder_->CreateNSWAdd( - ir_builder_->CreateNSWMul(block_id, threads_per_block), - thread_id_in_block); + {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); + return b_->CreateNSWAdd(b_->CreateNSWMul(block_id, threads_per_block), + thread_id_in_block); } llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( @@ -373,14 +378,20 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( PrimitiveType operand_element_type = operand->shape().element_type(); llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), - "reduce_window_accum_ptr", ir_builder_); + "reduce_window_accum_ptr", b_); { TF_ASSIGN_OR_RETURN(llvm::Value * init_value, - operand_to_generator.at(hlo->operand(1))({})); - ir_builder_->CreateStore(init_value, accum_ptr); + operand_to_generator.at(hlo->operand(1))( + IrArray::Index(index.GetType()))); + b_->CreateStore(init_value, accum_ptr); } - llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_); + llvm::Type* index_type = index.GetType(); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return index.GetConstantWithIndexType(c); + }; + + llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type); std::vector window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); @@ -389,32 +400,32 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( ShapeUtil::MakeShape(operand_element_type, window_size), "window"); CHECK_EQ(window_index.size(), index.size()); - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder_); + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_); - IrArray::Index input_index(index.size()); - llvm::Value* in_bounds = ir_builder_->getInt1(true); + IrArray::Index input_index(index_type, index.size()); + llvm::Value* in_bounds = b_->getInt1(true); for (size_t i = 0; i < index.size(); ++i) { - llvm::Value* stridden_index = ir_builder_->CreateNSWMul( - index[i], ir_builder_->getInt64(window.dimensions(i).stride())); - input_index[i] = ir_builder_->CreateNSWSub( - ir_builder_->CreateNSWAdd(stridden_index, window_index[i]), - ir_builder_->getInt64(window.dimensions(i).padding_low())); + llvm::Value* stridden_index = b_->CreateNSWMul( + index[i], index_typed_const(window.dimensions(i).stride())); + input_index[i] = b_->CreateNSWSub( + b_->CreateNSWAdd(stridden_index, window_index[i]), + index_typed_const(window.dimensions(i).padding_low())); // We must check whether 0 ≤ input_index[i] < bound, as otherwise // we are in the pad and so can skip the computation. This // comparison is equivalent to the unsigned comparison // input_index[i] < bound, as a negative value wraps to a large // positive value. - in_bounds = ir_builder_->CreateAnd( + in_bounds = b_->CreateAnd( in_bounds, - ir_builder_->CreateICmpULT( + b_->CreateICmpULT( input_index[i], - ir_builder_->getInt64(operand->shape().dimensions(i)))); + index_typed_const(operand->shape().dimensions(i)))); } llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); - SetToFirstInsertPoint(if_data.true_block, ir_builder_); + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_); + SetToFirstInsertPoint(if_data.true_block, b_); // We are not in pad, so do the computation. TF_ASSIGN_OR_RETURN(llvm::Value * input_value, @@ -422,24 +433,28 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( TF_ASSIGN_OR_RETURN( llvm::Value * accum_value, compute_nested_(*hlo->to_apply(), - {ir_builder_->CreateLoad(accum_ptr), input_value})); - ir_builder_->CreateStore(accum_value, accum_ptr); + {b_->CreateLoad(accum_ptr), input_value})); + b_->CreateStore(accum_value, accum_ptr); - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), ir_builder_); - return ir_builder_->CreateLoad(accum_ptr); + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_); + return b_->CreateLoad(accum_ptr); }; case HloOpcode::kReduce: + // TODO(b/112040122): This should be supported. + CHECK_EQ(hlo->operand_count(), 2) << "Did not expect variadic reduce"; return [=, &operand_to_generator]( const IrArray::Index& output_index) -> StatusOr { const HloInstruction* operand = hlo->operand(0); llvm::Value* accum_ptr = - ir_builder()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType( + b()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType( hlo->shape().element_type(), module_)); + llvm::Type* index_type = output_index.GetType(); TF_ASSIGN_OR_RETURN(llvm::Value * init_value, - operand_to_generator.at(hlo->operand(1))({})); - ir_builder()->CreateStore(init_value, accum_ptr); + operand_to_generator.at(hlo->operand(1))( + IrArray::Index(index_type))); + b()->CreateStore(init_value, accum_ptr); - llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_); + llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type); IrArray::Index input_index = loops.AddLoopsForShapeOnDimensions( operand->shape(), hlo->dimensions(), "reduction_dim"); if (!ShapeUtil::IsScalar(hlo->shape())) { @@ -454,18 +469,17 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( CHECK_EQ(output_index.size(), j); } - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder()); + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b()); TF_ASSIGN_OR_RETURN( llvm::Value * input_value, operand_to_generator.at(hlo->operand(0))(input_index)); TF_ASSIGN_OR_RETURN( llvm::Value * accum_value, - compute_nested_( - *hlo->to_apply(), - {ir_builder()->CreateLoad(accum_ptr), input_value})); - ir_builder()->CreateStore(accum_value, accum_ptr); - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), ir_builder()); - return ir_builder()->CreateLoad(accum_ptr); + compute_nested_(*hlo->to_apply(), + {b()->CreateLoad(accum_ptr), input_value})); + b()->CreateStore(accum_value, accum_ptr); + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b()); + return b()->CreateLoad(accum_ptr); }; default: return ElementalIrEmitter::MakeElementGenerator(hlo, diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 91f4d960aa62fff3e0699ece37a8c74d7dcf2f59..84454d31bb820a3de6ef3364bd205b8115bd95c0 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -43,7 +43,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { const HloComputation&, tensorflow::gtl::ArraySlice)>; GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config, - llvm::Module* module, llvm::IRBuilder<>* ir_builder, + llvm::Module* module, llvm::IRBuilder<>* b, NestedComputer compute_nested); llvm_ir::ElementGenerator MakeElementGenerator( @@ -51,9 +51,6 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { const HloToElementGeneratorMap& operand_to_generator) const override; protected: - StatusOr EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const override; - StatusOr EmitFloatBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const override; @@ -85,6 +82,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const override; + StatusOr EmitTanh(PrimitiveType prim_type, + llvm::Value* value) const override; + llvm::Value* EmitThreadId() const override; private: diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index e14ee6918bf148861ecccac99355fccf7ae93103..0cdddf8bcfd4e849b311bf810eda471d79dbf106 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -107,7 +108,8 @@ FftThunk::FftThunk(FftType fft_type, output_shape_(output_shape) {} Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { + se::Stream* stream, + HloExecutionProfiler* profiler) { VLOG(3) << "FFT type: " << FftTypeToString(fft_type_); VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_); VLOG(3) << "Output shape: " @@ -116,6 +118,7 @@ Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, FftScratchAllocator scratch_allocator(buffer_allocations.device_ordinal(), buffer_allocations.memory_allocator()); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); if (fft_plan_ == nullptr) { const int64 fft_rank = fft_length_.size(); CHECK_LE(fft_rank, 3); diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index b0a22564f3a09bb67a3c01723f6e37c604656d45..8c53be5077b0c5a88d303c729457139c6cb800f1 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" @@ -72,7 +73,8 @@ class FftThunk : public Thunk { // Does the FFT for the thunk on "stream". Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const se::fft::Type fft_type_; diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index b36539e0cb8d0a2f4758dd90acbdd8fc7181b8ca..2fd2206324e5f763490780a54880825a772b7ea2 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -27,8 +28,11 @@ ForThunk::ForThunk(const int64 loop_limit, const HloInstruction* hlo) : Thunk(Kind::kWhile, hlo), loop_limit_(loop_limit), - body_thunk_sequence_( - MakeUnique(std::move(*body_thunk_sequence), hlo)) {} + body_thunk_sequence_(MakeUnique( + // Pass nullptr as the HloInstruction* to the body_thunk_sequence_ + // constructor because this SequentialThunk is logically "part of" + // this ForThunk, and shouldn't be profiled separately from it. + std::move(*body_thunk_sequence), nullptr)) {} Status ForThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { @@ -37,11 +41,17 @@ Status ForThunk::Initialize(const GpuExecutable& executable, } Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { + se::Stream* stream, + HloExecutionProfiler* profiler) { + VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for " + << (hlo_instruction() ? hlo_instruction()->ToString() : ""); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); for (int64 i = 0; i < loop_limit_; ++i) { + profiler->StartHloComputation(); // Invoke loop body thunk sequence. - TF_RETURN_IF_ERROR( - body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); + TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations, + stream, profiler)); + profiler->FinishHloComputation(hlo_instruction()->while_body()); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h index 41ddfe0ceb1d0516c1c64feca53212a925632209..c2d39071b292c6704e9b5857a68bd8b3f3b9a914 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -39,7 +40,8 @@ class ForThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const int64 loop_limit_; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 79fca43d022816645b8a07b9e806fe9cc3745e7c..74282c568c09921dbeec2e9cce79b6c73b6ea592 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -31,16 +32,19 @@ namespace { // dimensions. struct MatrixDescriptor { MatrixDescriptor(se::DeviceMemoryBase matrix_data, bool needs_transpose, - int64 matrix_num_rows, int64 matrix_num_cols) + int64 matrix_num_rows, int64 matrix_num_cols, + int64 matrix_batch_size) : data(matrix_data), transpose(needs_transpose), num_rows(matrix_num_rows), - num_cols(matrix_num_cols) {} + num_cols(matrix_num_cols), + batch_size(matrix_batch_size) {} se::DeviceMemoryBase data; bool transpose; // Whether this matrix needs to be transposed. int64 num_rows; int64 num_cols; + int64 batch_size; }; // Performs a gemm call without an explicit algorithm on lhs_matrix and @@ -50,6 +54,9 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, MatrixDescriptor output_matrix, double alpha, se::Stream* stream) { DCHECK(!output_matrix.transpose); + const int64 batch_size = lhs_matrix.batch_size; + CHECK_EQ(batch_size, rhs_matrix.batch_size); + CHECK_EQ(batch_size, output_matrix.batch_size); se::DeviceMemory lhs_data(lhs_matrix.data); se::DeviceMemory rhs_data(rhs_matrix.data); se::DeviceMemory output_data(output_matrix.data); @@ -60,13 +67,30 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, : se::blas::Transpose::kNoTranspose; auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols; + if (batch_size == 1) { + return stream + ->ThenBlasGemm( + lhs_transpose, rhs_transpose, output_matrix.num_rows, + output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha, + lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, + &output_data, /*leading dim of output=*/output_matrix.num_rows) + .ok(); + } + + int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols; + int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols; + int64 output_stride = output_matrix.num_rows * output_matrix.num_cols; return stream - ->ThenBlasGemm( + ->ThenBlasGemmStridedBatched( lhs_transpose, rhs_transpose, output_matrix.num_rows, - output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha, - lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, - /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, - &output_data, /*leading dim of output=*/output_matrix.num_rows) + output_matrix.num_cols, /*size of reduce dim=*/k, + /*alpha=*/alpha, lhs_data, + /*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride, + /*beta=*/0.0, &output_data, + /*leading dim of output=*/output_matrix.num_rows, output_stride, + batch_size) .ok(); } @@ -93,6 +117,10 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, se::blas::ProfileResult* output_profile_result) { DCHECK(!output_matrix.transpose); + CHECK_EQ(1, lhs_matrix.batch_size); + CHECK_EQ(1, rhs_matrix.batch_size); + CHECK_EQ(1, output_matrix.batch_size); + se::DeviceMemory lhs_data(lhs_matrix.data); se::DeviceMemory rhs_data(rhs_matrix.data); se::DeviceMemory output_data(output_matrix.data); @@ -141,9 +169,15 @@ StatusOr DoGemmAutotune( alpha, computation_type, algorithm, stream, &profile_result)); - if (profile_result.is_valid() && profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; + if (profile_result.is_valid()) { + VLOG(3) << "cublas gemm algorithm " << algorithm << " took " + << profile_result.elapsed_time_in_ms() << "ms"; + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } + } else { + VLOG(4) << "cublas gemm algorithm " << algorithm << " failed."; } } @@ -167,6 +201,8 @@ auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm) { return &DoGemm; case F64: return &DoGemm; + case C64: + return &DoGemm>; default: LOG(FATAL) << "Unsupported type."; } @@ -180,6 +216,8 @@ auto GetGemmWithAlgorithmFn(PrimitiveType type) return &DoGemmWithAlgorithm; case F64: return &DoGemmWithAlgorithm; + case C64: + return &DoGemmWithAlgorithm>; default: LOG(FATAL) << "Unsupported type."; } @@ -192,6 +230,8 @@ auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune) { return &DoGemmAutotune; case F64: return &DoGemmAutotune; + case C64: + return &DoGemmAutotune>; default: LOG(FATAL) << "Unsupported type."; } @@ -210,6 +250,8 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { return se::blas::ComputationType::kF32; case F64: return se::blas::ComputationType::kF64; + case C64: + return se::blas::ComputationType::kComplexF32; default: LOG(FATAL) << "Unsupported type."; } @@ -252,7 +294,8 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, alpha_(alpha) {} Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { + se::Stream* stream, + HloExecutionProfiler* profiler) { VLOG(2) << "Executing a GemmThunk"; se::DeviceMemoryBase lhs_data = @@ -262,12 +305,37 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, se::DeviceMemoryBase output_data = buffer_allocations.GetDeviceAddress(output_buffer_); + DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), + dim_nums.rhs_batch_dimensions_size()); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, + ShapeUtil::Rank(output_shape_)); + + int64 row_dim = dim_nums.lhs_batch_dimensions_size(); + int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1; + int64 batch_size = std::accumulate(output_shape_.dimensions().begin(), + output_shape_.dimensions().end() - 2, 1, + std::multiplies()); + + // Check that the batch dims don't cover the last two dims. + for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) { + CHECK_NE(row_dim, batch_dim); + CHECK_NE(col_dim, batch_dim); + } + + // Verify that the non-batch dimensions are minor-most. This is required for + // efficient access. + for (const auto* shape : {&lhs_shape_, &rhs_shape_, &output_shape_}) { + CHECK_LT(shape->layout().minor_to_major(row_dim), 2); + CHECK_LT(shape->layout().minor_to_major(col_dim), 2); + } + // BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between // matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of // their layout. Therefore, we should treat dimension 0 as row and dimension 1 // as column when mapping a matrix Dot to BLAS gemm. - int64 output_num_rows = output_shape_.dimensions(0); - int64 output_num_cols = output_shape_.dimensions(1); + int64 output_num_rows = output_shape_.dimensions(row_dim); + int64 output_num_cols = output_shape_.dimensions(col_dim); // BLAS gemm expects the inputs and the output are in column-major order. // Therefore, we need to convert dot between row-major matrices to that @@ -290,34 +358,46 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, // the leading dimension of the LHS matrix of gemm is the number of rows in // B^T and thus the number of columns in B. - auto make_descriptor = [this](se::DeviceMemoryBase data, const Shape& shape, - bool transpose) -> MatrixDescriptor { - bool is_row_major = LayoutUtil::Minor(shape.layout(), 0) != 0; - bool layout_mismatch = LayoutUtil::Minor(shape.layout(), 0) != - LayoutUtil::Minor(output_shape_.layout(), 0); - return MatrixDescriptor(data, transpose ^ layout_mismatch, - shape.dimensions(is_row_major), - shape.dimensions(!is_row_major)); + auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape& shape, + bool transpose) -> MatrixDescriptor { + bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0; + bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) != + LayoutUtil::Minor(output_shape_.layout(), row_dim); + return MatrixDescriptor( + data, transpose ^ layout_mismatch, + shape.dimensions(row_dim + static_cast(is_row_major)), + shape.dimensions(row_dim + static_cast(!is_row_major)), + batch_size); }; - DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); - const MatrixDescriptor lhs_descriptor = make_descriptor( - lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0); + lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == row_dim); const MatrixDescriptor rhs_descriptor = make_descriptor( - rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == 1); + rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == col_dim); // Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to // autotune this gemm to figure out the best algorithm. - auto launch = [this](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, se::Stream* stream) { + auto launch = [&](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, + MatrixDescriptor output_matrix, se::Stream* stream) { PrimitiveType element_type = output_shape_.element_type(); se::blas::ComputationType computation_type = GetBlasComputationType(element_type); + // TODO(b/112111608): Implement auto tune for batched gemm. + if (batch_size != 1) { + return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, + alpha_, stream); + } + + auto thunk_name = [&] { + return hlo_instruction() != nullptr ? hlo_instruction()->ToString() + : ""; + }; + const string& device_name = stream->parent()->GetDeviceDescription().name(); auto autotune_it = autotune_results_.find(device_name); if (autotune_it == autotune_results_.end()) { + VLOG(3) << "Starting autotune of GemmThunk " << thunk_name(); StatusOr best_algorithm = GetGemmAutotuneFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, alpha_, computation_type, stream); @@ -325,11 +405,11 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, autotune_results_.insert({device_name, best_algorithm}).first; if (autotune_it->second.ok()) { - VLOG(2) << "Autotune on GemmThunk " << this + VLOG(2) << "Autotune on GemmThunk " << thunk_name() << " successful; best algorithm is " << best_algorithm.ValueOrDie(); } else { - VLOG(2) << "Autotune on GemmThunk " << this + VLOG(2) << "Autotune on GemmThunk " << thunk_name() << " unsuccessful. Will use generic gemm."; } } @@ -339,7 +419,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, if (best_algorithm.ok()) { auto algorithm = best_algorithm.ValueOrDie(); VLOG(2) << "Using algorithm " << algorithm - << " chosen by autotuning on GemmThunk " << this; + << " chosen by autotuning on GemmThunk " << thunk_name(); return GetGemmWithAlgorithmFn(element_type)( lhs_matrix, rhs_matrix, output_matrix, alpha_, computation_type, algorithm, stream, @@ -352,17 +432,18 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, alpha_, stream); }; + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); bool launch_ok; - if (LayoutUtil::Minor(output_shape_.layout(), 0) == 0) { - launch_ok = launch( - lhs_descriptor, rhs_descriptor, - MatrixDescriptor(output_data, false, output_num_rows, output_num_cols), - stream); + if (LayoutUtil::Minor(output_shape_.layout(), row_dim) == 0) { + launch_ok = launch(lhs_descriptor, rhs_descriptor, + MatrixDescriptor(output_data, false, output_num_rows, + output_num_cols, batch_size), + stream); } else { - launch_ok = launch( - rhs_descriptor, lhs_descriptor, - MatrixDescriptor(output_data, false, output_num_cols, output_num_rows), - stream); + launch_ok = launch(rhs_descriptor, lhs_descriptor, + MatrixDescriptor(output_data, false, output_num_cols, + output_num_rows, batch_size), + stream); } if (!launch_ok) { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 7a4830d64e7caef5a1170cbdbf8ab373fdaf16e2..12c81f9bfc6bfdac63edf9c826b835057107fa41 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -48,14 +49,15 @@ class GemmThunk : public Thunk { // Does the gemm operation for the thunk on "stream", which must be non-null. Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; - - // Returns true if we'll perform autotuning if run on the given stream. If - // so, we want the GPU to be quiescent during autotuning, so as not to - // introduce noise in our results. - bool ShouldHaltAllActivityBeforeRunning(se::Stream* stream) override { - return autotune_results_.count( - stream->parent()->GetDeviceDescription().name()) != 0; + se::Stream* stream, + HloExecutionProfiler* profiler) override; + + bool WillAutotuneKernel(se::Stream* stream) override { + // We will autotune this kernel if we don't already have a autotune result + // for the stream device. + return autotune_results_.find( + stream->parent()->GetDeviceDescription().name()) == + autotune_results_.end(); } private: @@ -73,6 +75,8 @@ class GemmThunk : public Thunk { // results. The map's value is the best algorithm we've found for this thunk // on this device, or an error if none of the algorithms worked and we should // use the regular gemm without an algorithm. + // + // TODO(b/112415150): Make this thread safe. std::unordered_map> autotune_results_; }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.cc b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc index aa360c7f73de2f0f9cf59c22b552b8e60ddb3a87..7f0b030fece8f25578bd90a538279d455350278a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_constants.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc @@ -14,12 +14,23 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" +#include "tensorflow/core/framework/allocator.h" namespace xla { namespace gpu { -// http://docs.nvidia.com/cuda/cuda-c-programming-guide/#device-memory-accesses -const int64 kCudaMallocAlignBytes = 256; +// kEntryParameterAlignBytes is equal to EIGEN_MAX_ALIGN_BYTES, though including +// Eigen headers here to get that symbol may not be a good idea. +// EIGEN_MAX_ALIGN_BYTES may differ between CUDA-enabled builds vs CUDA-disabled +// builds and we don't want the IR generated by XLA:GPU to depend on that. +// +// TODO(b/111767313): Consider raising EIGEN_MAX_ALIGN_BYTES if it helps. +const int64 kEntryParameterAlignBytes = 16; + +const int64 kXlaAllocatedBufferAlignBytes = + tensorflow::Allocator::kAllocatorAlignment; + +const int64 kConstantBufferAlignBytes = kXlaAllocatedBufferAlignBytes; } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.h b/tensorflow/compiler/xla/service/gpu/gpu_constants.h index eb1ca4c6c95a23d2a08f5f9c3cbc85e7d47d4f89..6f5f1fa09c57dfd246d702c0adc92c7e2e76805a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_constants.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.h @@ -21,9 +21,15 @@ limitations under the License. namespace xla { namespace gpu { -// Minimum alignment of cudaMalloc. We require that buffers created by our -// DeviceMemoryAllocator, and all input/output buffers, have this alignment. -extern const int64 kCudaMallocAlignBytes; +// Minimum alignment for buffers passed as incoming arguments by TensorFlow. +extern const int64 kEntryParameterAlignBytes; + +// Minimum alignment for buffers allocated by XLA: the temp buffers and the live +// out (result) buffers. +extern const int64 kXlaAllocatedBufferAlignBytes; + +// Minimum alignment for constant buffers. +extern const int64 kConstantBufferAlignBytes; } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index c5ccdd4a7dcec02ddab8a1f748659de41f6202d2..75f414e47fe3edcc1b10b392ed5cc5038be6c190 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -48,79 +48,17 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module)); - TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, - HloDataflowAnalysis::Run(*module)); - - // Make sure all operands of a library call are in memory instead of constants - // in IR. - for (HloInstruction* hlo : - module->entry_computation()->MakeInstructionPostOrder()) { - // Inserts a copy of hlo->operand(n) if it's a constant. - auto copy_operand_if_constant = [&](int64 n) -> Status { - HloInstruction* operand = hlo->mutable_operand(n); - TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); - const auto& values = dataflow->GetValueSet(operand).values(); - if (std::any_of(values.begin(), values.end(), [](const HloValue* value) { - return value->defining_instruction()->opcode() == - HloOpcode::kConstant; - })) { - TF_ASSIGN_OR_RETURN(HloInstruction * copy, FindOrInsertCopy(operand)); - TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(n, copy)); - changed = true; - } - return Status::OK(); - }; - - if (IsCustomCallToDnnBatchNorm(*hlo)) { - // The epsilon and feature_index operands to a CUDNN batchnorm op don't - // need to be materialized in memory -- in fact, they must be constants. - // These are the last two operands of all three batchnorm ops. - for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } - } else if (ImplementedAsLibraryCall(*hlo) || - hlo->opcode() == HloOpcode::kCrossReplicaSum) { - // For all other library calls and cross-replica-sum, materialize all the - // operands into memory. (Cross-replica-sum gets its constant args - // materialized even if it's not implemented as a libcall to simplify the - // implementation. It's slower, but we can constant fold away constant - // args *anyway*, so we just need to make it work.) - for (int64 i = 0; i < hlo->operand_count(); ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } - } - } - - // Init values of while and conditional nodes cannot be constants. Insert - // copies for any constants found at the operands of these nodes. - tensorflow::gtl::FlatSet inserted_copies; + // Check the assumption that the epsilon and feature_index constants of the + // CUDNN batchnorm op are not shared with other ops where we would replace + // them with a copy. These custom op calls are generated with the + // CudnnBatchNormRewriter, so this would only happen if HloCSE merges them. for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kWhile && - instruction->opcode() != HloOpcode::kConditional) { + for (HloInstruction* hlo : computation->instructions()) { + if (!IsCustomCallToDnnBatchNorm(*hlo)) { continue; } - for (auto operand : instruction->operands()) { - // Skip the operands that have already been replaced with a copy in a - // previous iteration (which is possible when a constant is used as an - // operand in multiple places). - if (ContainsKey(inserted_copies, operand)) { - continue; - } - for (auto& pair : dataflow->GetInstructionValueSet(operand)) { - const HloValueSet& value_set = pair.second; - for (const HloValue* value : value_set.values()) { - if (value->defining_instruction()->IsConstant() && - !ContainsKey(hlo_to_copy_map_, value->defining_instruction())) { - HloInstruction* constant = value->defining_instruction(); - TF_ASSIGN_OR_RETURN(HloInstruction * copy, - FindOrInsertCopy(constant)); - TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy)); - inserted_copies.insert(copy); - changed = true; - } - } - } + for (int64 i = hlo->operand_count() - 2; i < hlo->operand_count(); ++i) { + CHECK_EQ(hlo->operand(i)->opcode(), HloOpcode::kConstant); } } } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 25d8f720ea4791a4c94efcad6909cd0c113fbe70..70608379048871cf6ee72145fa9afff71a3eabe6 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -22,8 +22,9 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -41,77 +42,6 @@ namespace { using tensorflow::tracing::ScopedAnnotation; -// A helper class for profiling HLO in the course of GPU program execution. -// All of the profiling is guarded internally, to avoid the caller needing to -// have lots of conditionals sprinkled around. -class HloExecutionProfiler { - public: - // If profiling is enabled, start an execution timer running. - explicit HloExecutionProfiler( - bool do_profile, HloExecutionProfile* profile, se::Stream* stream, - const std::vector::SmartPtr>& sub_streams, - const HloComputation* computation) - : do_profile_(do_profile), - profile_(profile), - stream_(stream), - sub_streams_(sub_streams), - computation_(computation) { - if (do_profile_) { - clock_rate_ghz_ = - stream->parent()->GetDeviceDescription().clock_rate_ghz(); - execution_timer_.reset(new se::Timer(stream->parent())); - per_op_timer_.reset(new se::Timer(stream->parent())); - stream->InitTimer(execution_timer_.get()) - .ThenStartTimer(execution_timer_.get()); - stream->InitTimer(per_op_timer_.get()); - } - } - - // If profiling is enabled, sets the total cycle count on the profile from the - // execution timer. - void FinishExecution() { - CHECK(!finished_execution_) << "Call FinishExecution only once!"; - finished_execution_ = true; - if (do_profile_) { - stream_->ThenWaitFor(&sub_streams_); - stream_->ThenStopTimer(execution_timer_.get()); - stream_->BlockHostUntilDone().IgnoreError(); - profile_->set_total_cycles_executed( - *computation_, execution_timer_->Nanoseconds() * clock_rate_ghz_); - } - } - - // If profiling is enabled, starts the per-operation timer. - void StartOperation() { - if (do_profile_) { - stream_->ThenStartTimer(per_op_timer_.get()); - } - } - - // If profiling is enabled, stops the per-operation timer and records the time - // that the hlo_instruction took to execute in the profile. - void FinishOperation(const HloInstruction* hlo_instruction) { - if (do_profile_) { - stream_->ThenWaitFor(&sub_streams_); - stream_->ThenStopTimer(per_op_timer_.get()); - stream_->BlockHostUntilDone().IgnoreError(); - profile_->SetCyclesTakenBy( - hlo_instruction, per_op_timer_->Nanoseconds() * clock_rate_ghz_); - } - } - - private: - const bool do_profile_; - double clock_rate_ghz_; - HloExecutionProfile* profile_; - se::Stream* stream_; - const std::vector::SmartPtr>& sub_streams_; - const HloComputation* computation_; - std::unique_ptr execution_timer_; - std::unique_ptr per_op_timer_; - bool finished_execution_ = false; -}; - } // namespace // Implementation note: HLO profiling is always enabled for GPU executables, @@ -155,7 +85,7 @@ Status GpuExecutable::ExecuteThunks( } // Stream 0 indicates `main_stream` and substreams start from stream 1. - std::vector::SmartPtr> sub_streams; + std::vector sub_streams; sub_streams.reserve(thunk_schedule_->StreamCount() - 1); while (sub_streams.size() + 1 < thunk_schedule_->StreamCount()) { sub_streams.emplace_back(); @@ -201,24 +131,24 @@ Status GpuExecutable::ExecuteThunks( stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); } - // If this thunk requests it, wait for all currently-executing thunks to - // finish. This is useful e.g. if the thunk is about to perform autotuning. - if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) { + // If this thunk is about to autotune then wait for all currently executing + // thunks to finish. This reduces noise and thus the probability of + // choosing a suboptimal algorithm. + if (thunk->WillAutotuneKernel(stream)) { TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone()); } - profiler.StartOperation(); VLOG(2) << "Executing the thunk for " << thunk->hlo_instruction()->ToString() << " on stream " << stream_no; - TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); + TF_RETURN_IF_ERROR( + thunk->ExecuteOnStream(buffer_allocations, stream, &profiler)); if (thunk_schedule_->Depended(thunk)) { auto finish_event = MakeUnique(main_stream->parent()); finish_event->Init(); stream->ThenRecordEvent(finish_event.get()); thunk_to_finish_event[thunk] = std::move(finish_event); } - profiler.FinishOperation(thunk->hlo_instruction()); } main_stream->ThenWaitFor(&sub_streams); @@ -253,6 +183,55 @@ Status GpuExecutable::ExecuteThunks( return Status::OK(); } +StatusOr +GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { + tensorflow::mutex_lock lock(module_handle_mutex_); + auto it = module_globals_.find(executor); + if (it != module_globals_.end()) { + return &it->second; + } + + se::MultiModuleLoaderSpec module_spec; + if (!cubin().empty()) { + module_spec.AddCudaCubinInMemory(cubin()); + } + module_spec.AddCudaPtxInMemory(ptx().c_str()); + + tensorflow::gtl::FlatMap globals; + se::ModuleHandle module_handle; + executor->LoadModule(module_spec, &module_handle); + + for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); + ++i) { + const BufferAllocation& allocation = assignment_->GetAllocation(i); + if (allocation.is_constant()) { + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase global, + executor->GetUntypedSymbol( + llvm_ir::ConstantBufferAllocationToGlobalName(allocation), + module_handle)); + VLOG(3) << "Resolved global " + << llvm_ir::ConstantBufferAllocationToGlobalName(allocation) + << " to " << global.opaque(); + InsertOrDie(&globals, i, global); + + const Literal& literal = + llvm_ir::LiteralForConstantAllocation(allocation); + CHECK(ShapeUtil::IsArray(literal.shape())); + if (!ShouldEmitLiteralInLlvmIr(literal)) { + VLOG(3) << "H2D memcpy for constant with shape " + << ShapeUtil::HumanString(literal.shape()); + TF_RETURN_IF_ERROR(executor->SynchronousMemcpyH2D( + literal.untyped_data(), allocation.size(), &global)); + } + } + } + + module_handles_.emplace(executor, + se::ScopedModuleHandle(executor, module_handle)); + return &module_globals_.emplace(executor, std::move(globals)).first->second; +} + StatusOr GpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -264,6 +243,10 @@ StatusOr GpuExecutable::ExecuteOnStream( } BufferAllocations::Builder buffer_allocations_builder; + se::StreamExecutor* executor = run_options->stream()->parent(); + + TF_ASSIGN_OR_RETURN(auto* const globals, ResolveConstantGlobals(executor)); + for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); ++i) { const BufferAllocation& allocation = assignment_->GetAllocation(i); @@ -285,8 +268,12 @@ StatusOr GpuExecutable::ExecuteOnStream( buffer_allocations_builder.RegisterBuffer(i, buffer); } + + if (allocation.is_constant()) { + buffer_allocations_builder.RegisterBuffer(i, FindOrDie(*globals, i)); + } } - se::StreamExecutor* executor = run_options->stream()->parent(); + TF_ASSIGN_OR_RETURN( auto buffer_allocations, buffer_allocations_builder.Build( @@ -307,7 +294,7 @@ StatusOr GpuExecutable::ExecuteOnStream( // the respective location in ShapedBuffer. std::set buffers_in_result; TF_RETURN_IF_ERROR(shaped_buffer.buffers().ForEachMutableElementWithStatus( - [&buffer_allocations, &buffers_in_result, &shaped_buffer, this]( + [&buffer_allocations, &buffers_in_result, this]( const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { const auto& sources = this->GetRootPointsToSet().element(index); // The points-to set is unambiguous so the set should be a diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 80ec38c3ac114fe4ad9d56784330c1144d913db1..c7ce6d0acbbbe594040271c0d45c71c016e36514 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -34,6 +34,8 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -66,7 +68,7 @@ class GpuExecutable : public Executable { } // Returns the compiled PTX for the computation. - tensorflow::StringPiece ptx() const { return ptx_; } + const string& ptx() const { return ptx_; } // Returns the cubin (compiled PTX) stored in this GpuExecutable. May be // empty, in which case compilation is left up to the GPU driver. @@ -98,6 +100,15 @@ class GpuExecutable : public Executable { // computation. Uses points-to analysis from buffer assignment. const PointsToSet& GetRootPointsToSet() const; + using BufferAllocToDeviceMemoryMap = + tensorflow::gtl::FlatMap; + + // Loads the PTX or CUBIN for this executable into `executor` and resolves the + // globals corresponding to constant buffers. Returns a map mapping buffer + // allocation indices to GPU pointers. + StatusOr ResolveConstantGlobals( + stream_executor::StreamExecutor* executor); + // The LLVM IR, in string format, of the unoptimized module generated for this // GpuExecutable. We save a string instead of an llvm::Module* because leaving // llvm::Module* in a singleton can cause the heap checker to emit false @@ -126,6 +137,14 @@ class GpuExecutable : public Executable { // memory for every output/temp buffers. const std::unique_ptr assignment_; + // Cache of module handles and constant buffer allocation maps used by + // `ResolveConstantGlobals`. + tensorflow::mutex module_handle_mutex_; + std::map + module_handles_ GUARDED_BY(module_handle_mutex_); + std::map + module_globals_ GUARDED_BY(module_handle_mutex_); + TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable); }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 8bf62dde8b9948375fc493fd1a524cfa7b062502..d033faee8d25ed81a1483f8314652ef999ab36c5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -31,52 +31,58 @@ limitations under the License. namespace xla { namespace gpu { -using stream_executor::dnn::DataLayout; -using stream_executor::dnn::FilterLayout; - -static bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) { - int major, minor; - CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major, - &minor)); - return major >= 7; -} +using se::dnn::DataLayout; +using se::dnn::FilterLayout; // Returns (input, filter, output) layouts. static std::tuple HeuristicLayoutAssignment(const HloInstruction* instr, - stream_executor::StreamExecutor* stream_executor) { + se::StreamExecutor* stream_executor) { // DataLayout and FilterLayout uses weird enum names. Translations: // N <=> Batch or Output // C <=> Depth or Input // H <=> Y // W <=> X // - // Therefore kOutputInputYX means NHWC; kBatchDepthYX means NCHW. + // Therefore kOutputInputYX and kBatchDepthYX mean NCHW. + // + // If you have trouble keeping these straight, consider that all that matters + // is the location of the channel dim: Is it major (NCHW), or minor (NHWC)? + + constexpr auto kAllNCHW = + std::make_tuple(DataLayout::kBatchDepthYX, FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX); + constexpr auto kAllNHWC = + std::make_tuple(DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput, + DataLayout::kBatchYXDepth); - // As of today, our empirical evidence is that cudnn 7.0 is faster on V100 x - // fp16 with the mostly-NHWC layout. The heuristic may change as cudnn version - // changes, as well as the hardware updates. + // If we're not Volta or not fp16, the decision is easy: Use NCHW. if (!(instr->operand(0)->shape().element_type() == xla::PrimitiveType::F16 && IsVoltaOrLater(*stream_executor))) { - return std::make_tuple(DataLayout::kBatchDepthYX, - FilterLayout::kOutputInputYX, - DataLayout::kBatchDepthYX); + return kAllNCHW; } + VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString(); - // For BackwardInput that has stride, full NHWC layouts run significantly - // slower than (NHWC, NCHW, NCHW) or (NHWC, NCHW, NHWC). + + // Empirically we've found with Volta and cudnn 7 that backward-input convs + // with stride are significantly faster with NCHW layouts. // - // TODO(timshen): more closely compare (NHWC, NCHW, NCHW) and (NHWC, NCHW, - // NHWC). + // We could have used a mixed layout combination, e.g. (NHWC, NCHW, NCHW), + // which on paper gives good performance. However, there are two observations: + // * a mixed layout combination is more cuDNN-bug prone, based on empirical + // envidence. + // * we've also observed that for mixed layouts, cuDNN transposes data back + // and forth from a different layout combination. If we end up with + // transposes anyway, we prefer to have them in XLA, as they can be fused. + // TODO(timshen): Figure out the exact condition. This may be achieved by + // auto-tuning layouts offline. if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && window_util::HasStride(instr->window())) { - return std::make_tuple(DataLayout::kBatchYXDepth, - FilterLayout::kOutputInputYX, - DataLayout::kBatchDepthYX); + return kAllNCHW; } - return std::make_tuple(DataLayout::kBatchYXDepth, - FilterLayout::kOutputYXInput, - DataLayout::kBatchYXDepth); + + // For other Volta f16 convolutions, use NHWC. + return kAllNHWC; } // Adds layout constraints on the cudnn custom-call instruction. The layout @@ -170,6 +176,38 @@ Status GpuLayoutAssignment::AddBackendConstraints( TF_RETURN_IF_ERROR( AddBackendConstraintsToDnnConvCustomCall(instruction, constraints)); } + + // For batched dot we require the default layout. + // TODO(b/112111608): This is overly conservative, the only real restriction + // is that batch dimensions must be major. + if (instruction->opcode() == HloOpcode::kDot && + ImplementedAsGemm(*instruction) && + instruction->dot_dimension_numbers().lhs_batch_dimensions_size() > 0) { + // Verify that the batch dims come before the row and col dims. + const DotDimensionNumbers& dim_nums = + instruction->dot_dimension_numbers(); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), + dim_nums.rhs_batch_dimensions_size()); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, + ShapeUtil::Rank(instruction->shape())); + for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) { + CHECK_LT(batch_dim, ShapeUtil::Rank(instruction->shape()) - 2); + } + + // Set both inputs and the output to default layout. + Shape op0_shape = instruction->operand(0)->shape(); + LayoutUtil::SetToDefaultLayout(&op0_shape); + Shape op1_shape = instruction->operand(1)->shape(); + LayoutUtil::SetToDefaultLayout(&op1_shape); + Shape output_shape = instruction->shape(); + LayoutUtil::SetToDefaultLayout(&output_shape); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(op0_shape, instruction, 0)); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(op1_shape, instruction, 1)); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(output_shape, instruction)); + } } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index e48165c1426ea04839c245bc20b851a0f1710246..286547ebae2f1a4b8d783a06d13b4dd96052b952 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -31,6 +33,8 @@ namespace xla { namespace gpu { namespace { +namespace op = xla::testing::opcode_matchers; + using LayoutAssignmentTest = HloTestBase; TEST_F(LayoutAssignmentTest, Elementwise) { @@ -132,10 +136,10 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { HloInstruction::CreateParameter(4, aux_shape, "variance")); auto* epsilon = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto* feature_index = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(kFeatureIndex))); + LiteralUtil::CreateR0(kFeatureIndex))); auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall( shape, @@ -201,10 +205,10 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { HloInstruction::CreateParameter(2, offset_scale_shape, "offset")); auto* epsilon = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto* feature_index = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(kFeatureIndex))); + LiteralUtil::CreateR0(kFeatureIndex))); auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall( batchnorm_shape, {operand, scale, offset, epsilon, feature_index}, @@ -278,10 +282,10 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { HloInstruction::CreateParameter(4, shape, "var")); auto* epsilon = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto* feature_index = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(kFeatureIndex))); + LiteralUtil::CreateR0(kFeatureIndex))); auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall( @@ -327,6 +331,33 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { } } +TEST_F(LayoutAssignmentTest, DotLayout) { + const char* hlo_text = R"( + HloModule DotLayout + ENTRY dot { + p0 = f32[8,8,256,64]{3,1,2,0} parameter(0) + p1 = f32[8,8,256,64]{3,1,2,0} parameter(1) + ROOT dot.1330.10585 = f32[8,8,256,256]{3,2,1,0} dot(p0, p1), + lhs_batch_dims={0,1}, lhs_contracting_dims={3}, + rhs_batch_dims={0,1}, rhs_contracting_dims={3} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape()); + GpuLayoutAssignment layout_assignment(&computation_layout, + backend().default_stream_executor()); + EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); + + Shape expected_shape = + ShapeUtil::MakeShapeWithLayout(F32, {8, 8, 256, 64}, {3, 2, 1, 0}); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Dot(op::ShapeWithLayout(expected_shape), + op::ShapeWithLayout(expected_shape))); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 7bb8df6581b49b1bf8c84a972f715e8dc119d8de..a2f53f844613da9fe8166489dc9959e8d30c6332 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include "llvm/IR/DataLayout.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" +#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h" +#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,15 +36,14 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { +namespace gpu { // TODO(b/30467474) Once GPU infeed implementation settles, consider // folding back the cpu and gpu infeed implementations into a generic // one if possible. -GpuTransferManager::GpuTransferManager() - : GenericTransferManager( - se::cuda::kCudaPlatformId, - /*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout) - .getPointerSize(0 /* default address space */)) {} +GpuTransferManager::GpuTransferManager(se::Platform::Id id, + unsigned pointer_size) + : GenericTransferManager(id, pointer_size) {} Status GpuTransferManager::TransferLiteralToInfeed( se::StreamExecutor* executor, const LiteralSlice& literal) { @@ -50,53 +51,28 @@ Status GpuTransferManager::TransferLiteralToInfeed( VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); - if (!ShapeUtil::IsTuple(shape)) { - int64 size = GetByteSizeRequirement(shape); - return TransferBufferToInfeed(executor, size, literal.untyped_data()); - } - - if (ShapeUtil::IsNestedTuple(shape)) { - return Unimplemented( - "Infeed with a nested tuple shape is not supported: %s", - ShapeUtil::HumanString(literal.shape()).c_str()); - } - // For a tuple, we transfer each of its elements to the device and // enqueue the resulting destination device addresses with the // infeed manager. - std::vector buffers; - buffers.reserve(ShapeUtil::TupleElementCount(shape)); - auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() { - for (gpu::InfeedBuffer* b : buffers) { - b->Done(); - } - }); - - for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - const Shape& tuple_element_shape = - ShapeUtil::GetTupleElementShape(shape, i); - int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape); - TF_ASSIGN_OR_RETURN( - gpu::InfeedBuffer * buffer, - TransferBufferToInfeedInternal(executor, tuple_element_size, - literal.untyped_data({i}))); - buffers.push_back(buffer); - } - - cleanup.release(); - return EnqueueBuffersToInfeed(executor, buffers); -} - -Status GpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor, - int64 size, - const void* source) { - TF_ASSIGN_OR_RETURN(gpu::InfeedBuffer * buffer, - TransferBufferToInfeedInternal(executor, size, source)); - return EnqueueBuffersToInfeed(executor, {buffer}); + ShapeTree buffer_tree(shape); + + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + shape, [&](const Shape& literal_subshape, const ShapeIndex& index) { + if (ShapeUtil::IsArray(literal_subshape)) { + int64 tuple_element_size = GetByteSizeRequirement(literal_subshape); + TF_ASSIGN_OR_RETURN( + *buffer_tree.mutable_element(index), + TransferBufferToInfeedInternal(executor, tuple_element_size, + literal.untyped_data(index))); + } + return Status::OK(); + })); + + return EnqueueBuffersToInfeed(executor, std::move(buffer_tree)); } Status GpuTransferManager::EnqueueBuffersToInfeed( - se::StreamExecutor* executor, std::vector buffers) { + se::StreamExecutor* executor, ShapeTree buffers) { gpu::InfeedManager* infeed_manager = gpu::GetOrCreateInfeedManager(); se::Stream* stream = infeed_manager->GetStream(executor); @@ -106,21 +82,18 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( // possible. Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { - for (gpu::InfeedBuffer* b : buffers) { - b->Done(); - } return InternalError("Failed to complete data transfer on stream %p: %s", stream, block_status.error_message().c_str()); } - infeed_manager->EnqueueBuffers(buffers); + infeed_manager->EnqueueDestination(std::move(buffers)); VLOG(2) << "Infeed data transferred"; return Status::OK(); } -StatusOr GpuTransferManager::TransferBufferToInfeedInternal( +StatusOr GpuTransferManager::TransferBufferToInfeedInternal( se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits::max()) { return InvalidArgument("Infeed shape is too large: needs %lld bytes", size); @@ -136,23 +109,85 @@ StatusOr GpuTransferManager::TransferBufferToInfeedInternal( return InternalError("Failed to obtain a stream"); } - gpu::InfeedBuffer* buffer = new gpu::InfeedBuffer(executor, size); - stream->ThenMemcpy(buffer->device_memory(), source, size); + InfeedBuffer buffer(executor, size); + stream->ThenMemcpy(buffer.device_memory(), source, size); VLOG(2) << "Queued infeed data on stream " << stream; - return buffer; + return std::move(buffer); +} + +static void ShapeTreeToLiteral( + ShapeTree>* shape_tree) { + // This is a struct instead of a lambda for std::function-free recursion. + struct Helper { + static void helper( + ShapeTree>* shape_tree, + ShapeIndex* index) { + const Shape& shape = ShapeUtil::GetSubshape(shape_tree->shape(), *index); + if (ShapeUtil::IsArray(shape)) { + (*shape_tree->mutable_element(*index))->WaitUntilAvailable(); + return; + } + + CHECK(ShapeUtil::IsTuple(shape)) + << ShapeUtil::HumanStringWithLayout(shape); + const int64 tuple_element_count = ShapeUtil::TupleElementCount(shape); + index->push_back(0); + for (int64 i = 0; i < tuple_element_count; ++i) { + index->back() = i; + helper(shape_tree, index); + } + index->pop_back(); + } + }; + ShapeIndex index; + Helper::helper(shape_tree, &index); +} + +Status GpuTransferManager::TransferLiteralFromOutfeed( + se::StreamExecutor* /*executor*/, const Shape& literal_shape, + MutableBorrowingLiteral literal) { + ShapeTree> outfeed_buffers( + &literal_shape); + + // First create a tree of literal buffers that the device can write to. + outfeed_buffers.ForEachMutableElement( + [&](const ShapeIndex& index, + std::unique_ptr* buffer) { + const Shape& shape = ShapeUtil::GetSubshape(literal_shape, index); + // Do not transfer tuple index buffers. + if (ShapeUtil::IsTuple(shape)) { + return; + } + *buffer = MakeUnique(GetByteSizeRequirement(shape)); + (*buffer)->set_destination( + MakeUnique(literal, index)); + }); + + // Give the tree of buffers to the outfeed mananger. The device will fill it + // while we're waiting for it below. + gpu::OutfeedManager* outfeed_manager = gpu::GetOrCreateOutfeedManager(); + outfeed_manager->EnqueueDestination(&outfeed_buffers); + + // Now wait for the tree of buffers are written. + ShapeTreeToLiteral(&outfeed_buffers); + return Status::OK(); } +} // namespace gpu } // namespace xla -static std::unique_ptr CreateGpuTransferManager() { - return xla::MakeUnique(); +static std::unique_ptr CreateNVPTXTransferManager() { + return xla::MakeUnique( + /*id=*/stream_executor::cuda::kCudaPlatformId, + /*pointer_size=*/llvm::DataLayout(xla::gpu::NVPTXCompiler::kDataLayout) + .getPointerSize(0 /* default address space */)); } static bool InitModule() { xla::TransferManager::RegisterTransferManager( - stream_executor::cuda::kCudaPlatformId, &CreateGpuTransferManager); + stream_executor::cuda::kCudaPlatformId, &CreateNVPTXTransferManager); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h index 09f8227f508a3159f3def285898e15bfad544552..7929042869763dfeab2fe8f87093b7ea758337d0 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" @@ -28,33 +29,36 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { +namespace gpu { // An implementation of the XLA GenericTransferManager that // handles GPU-specific infeed. class GpuTransferManager : public GenericTransferManager { public: - GpuTransferManager(); + GpuTransferManager(se::Platform::Id id, unsigned pointer_size); ~GpuTransferManager() override {} Status TransferLiteralToInfeed(se::StreamExecutor* executor, const LiteralSlice& literal) override; - Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, - const void* source) override; + Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, + const Shape& literal_shape, + MutableBorrowingLiteral literal) override; private: // Initiates the infeed data transfers. InfeedBuffer->Done() must be // called to clean up the memory allocated for InfeedBuffer. - StatusOr TransferBufferToInfeedInternal( + StatusOr TransferBufferToInfeedInternal( se::StreamExecutor* executor, int64 size, const void* source); // Enqueues infeed data buffers with the infeed manager after their // transfer completes. Status EnqueueBuffersToInfeed(se::StreamExecutor* executor, - std::vector buffers); + ShapeTree buffers); TF_DISALLOW_COPY_AND_ASSIGN(GpuTransferManager); }; +} // namespace gpu } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc new file mode 100644 index 0000000000000000000000000000000000000000..17226769302eef0dd01550b0bc5404e889ad78f8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc @@ -0,0 +1,122 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace xla { +namespace gpu { +namespace { +void InitAndStartTimer(std::stack>* timers, + se::Stream* stream) { + timers->push(MakeUnique(stream->parent())); + stream->InitTimer(timers->top().get()).ThenStartTimer(timers->top().get()); +} + +uint64 GetCyclesTaken(std::stack>* timers, + const std::vector& sub_streams, + se::Stream* stream, double clock_rate_ghz) { + CHECK_GT(timers->size(), 0); + stream->ThenWaitFor(&sub_streams); + stream->ThenStopTimer(timers->top().get()); + stream->BlockHostUntilDone().IgnoreError(); + double nanoseconds = timers->top()->Nanoseconds(); + timers->pop(); + return static_cast(nanoseconds * clock_rate_ghz); +} +} // namespace + +HloExecutionProfiler::HloExecutionProfiler( + bool do_profile, HloExecutionProfile* profile, se::Stream* stream, + const std::vector& sub_streams, + const HloComputation* computation) + : do_profile_(do_profile), + profile_(profile), + stream_(stream), + sub_streams_(sub_streams), + computation_(computation) { + if (do_profile_) { + clock_rate_ghz_ = stream->parent()->GetDeviceDescription().clock_rate_ghz(); + InitAndStartTimer(&timers_, stream); + } +} + +void HloExecutionProfiler::FinishExecution() { + CHECK(!finished_execution_) << "Call FinishExecution only once!"; + finished_execution_ = true; + if (do_profile_) { + profile_->set_total_cycles_executed( + *computation_, + GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_)); + } +} + +void HloExecutionProfiler::StartHloComputation() { + if (do_profile_) { + InitAndStartTimer(&timers_, stream_); + } +} + +void HloExecutionProfiler::FinishHloComputation( + const HloComputation* computation) { + if (do_profile_) { + profile_->set_total_cycles_executed( + *computation, + GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_)); + } +} + +void HloExecutionProfiler::StartHloInstruction() { + if (do_profile_) { + InitAndStartTimer(&timers_, stream_); + } +} + +void HloExecutionProfiler::FinishHloInstruction( + const HloInstruction* hlo_instruction) { + if (do_profile_) { + hlo_instructions_.erase(hlo_instruction); + profile_->SetCyclesTakenBy( + hlo_instruction, + GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_)); + } +} + +std::unique_ptr +HloExecutionProfiler::MakeScopedInstructionProfiler( + const HloInstruction* hlo_instruction) { + if (do_profile_ && hlo_instruction != nullptr) { + // Make sure that we are not already measuring the time for the same + // 'hlo_instruction'. + CHECK(hlo_instructions_.insert(hlo_instruction).second) + << hlo_instruction->name(); + } + return MakeUnique(this, hlo_instruction); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..80cde75f2bbb555f514fffea58ad92edf92fd0d1 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h @@ -0,0 +1,110 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +class ScopedInstructionProfiler; + +// A helper class for profiling HLO in the course of GPU program execution. +// All of the profiling is guarded internally, to avoid the caller needing to +// have lots of conditionals sprinkled around. +class HloExecutionProfiler { + public: + // If profiling is enabled, start an execution timer running. + explicit HloExecutionProfiler(bool do_profile, HloExecutionProfile* profile, + se::Stream* stream, + const std::vector& sub_streams, + const HloComputation* computation); + + // If profiling is enabled, sets the total cycle count on the profile from the + // execution timer. + void FinishExecution(); + + // If profiling is enabled, starts a timer for a (sub)computation. + void StartHloComputation(); + + // If profiling is enabled stops the timer for a (sub)computation and records + // the time that the computation took to execute in the profile. + void FinishHloComputation(const HloComputation* computation); + + // If profiling is enabled, starts a per-operation timer. + void StartHloInstruction(); + + // If profiling is enabled, stops the per-operation timer and records the time + // that the hlo_instruction took to execute in the profile. + void FinishHloInstruction(const HloInstruction* hlo_instruction); + + // Returns a ScopedInstructionProfiler and triggers a call to + // StartHloInstruction(). Once the returned ScopedInstructionProfiler goes + // out of scope, it triggers a call to FinishHloInstruction(). + std::unique_ptr MakeScopedInstructionProfiler( + const HloInstruction* hlo_instruction); + + private: + const bool do_profile_; + double clock_rate_ghz_; + HloExecutionProfile* profile_; + se::Stream* stream_; + const std::vector& sub_streams_; + const HloComputation* computation_; + std::stack> timers_; + // Contains the HLO instructions for which we are currently measuring the + // time. + std::unordered_set hlo_instructions_; + bool finished_execution_ = false; +}; + +// This class can be used within the ExecuteOnStream() implementations of +// Thunks. It ensures that we always have a pair of matching +// StartHloInstruction() and FinishHloInstruction() calls to the profiler. +class ScopedInstructionProfiler { + public: + ScopedInstructionProfiler(HloExecutionProfiler* profiler, + const HloInstruction* hlo_instruction) + : profiler_(profiler), hlo_instruction_(hlo_instruction) { + if (hlo_instruction != nullptr) { + profiler->StartHloInstruction(); + } + } + ~ScopedInstructionProfiler() { + if (hlo_instruction_ != nullptr) { + profiler_->FinishHloInstruction(hlo_instruction_); + } + } + + private: + HloExecutionProfiler* profiler_; + const HloInstruction* hlo_instruction_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index 375709150e08996ea6a40f5e9e66a8f8d9287008..19de37b0fbed15455e8c6a9bfe427ba3d9f0a9dc 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -100,7 +100,7 @@ GpuHloOrdering::GpuHloOrdering( if (last_instruction_per_stream[stream_no] != nullptr) { immediate_preds.push_back(last_instruction_per_stream[stream_no]); } - predecessor_map->SetReachabilityToUnion(immediate_preds, hlo); + predecessor_map->FastSetReachabilityToUnion(immediate_preds, hlo); last_instruction_per_stream[stream_no] = hlo; } else { // Only parameters and constants don't have an assigned stream, since they diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 061210352cf12e6802d066d311fd2cb481673f15..8c11cd05419289d82b033c936bb60884f45cb636 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -18,8 +18,10 @@ limitations under the License. #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -39,7 +41,7 @@ void HloToIrBindings::EmitBasePointersForHlos( // I/O HLOs are bound to the arguments of the current IR function. I.e., // // void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) { - llvm::Function* function = ir_builder_->GetInsertBlock()->getParent(); + llvm::Function* function = b_->GetInsertBlock()->getParent(); CHECK_EQ(io_hlos.size() + 1, function->arg_size()); // An HLO can have duplicated operands. This data structure remembers which @@ -79,8 +81,8 @@ void HloToIrBindings::EmitBasePointersForHlos( const int64 offset = slice.offset(); CHECK_NE(nullptr, temp_buffer_base_); // Emit IR for GetTupleElement instruction and bind to emitted value. - llvm::Value* base_ptr = ir_builder_->CreateInBoundsGEP( - temp_buffer_base_, ir_builder_->getInt64(offset)); + llvm::Value* base_ptr = + b_->CreateInBoundsGEP(temp_buffer_base_, b_->getInt64(offset)); BindHloToIrValue(*non_io_hlo, EmitGetTupleElement(non_io_hlo, base_ptr)); } @@ -108,15 +110,20 @@ void HloToIrBindings::EmitBasePointersForHlos( if (slice.allocation()->is_thread_local()) { llvm::Type* pointee_type = llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_); - BindHloToIrValue(*non_io_hlo, - ir_builder_->CreateAlloca(pointee_type), index); + BindHloToIrValue(*non_io_hlo, b_->CreateAlloca(pointee_type), + index); + } else if (slice.allocation()->is_constant()) { + llvm::Value* global_for_constant = + module_->getGlobalVariable(llvm_ir::AsStringRef( + llvm_ir::ConstantBufferAllocationToGlobalName( + *slice.allocation()))); + BindHloToIrValue(*non_io_hlo, global_for_constant); } else { const int64 offset = slice.offset(); CHECK_NE(nullptr, temp_buffer_base_); BindHloToIrValue( *non_io_hlo, - ir_builder_->CreateInBoundsGEP(temp_buffer_base_, - ir_builder_->getInt64(offset)), + b_->CreateInBoundsGEP(temp_buffer_base_, b_->getInt64(offset)), index); } }); @@ -129,15 +136,23 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) { return llvm_ir::EmitGetTupleElement( gte->shape(), gte->tuple_index(), /*alignment=*/1, - GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_, module_); + GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_, module_); } return llvm_ir::EmitGetTupleElement( gte->shape(), gte->tuple_index(), /*alignment=*/1, - EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_, module_); + EmitGetTupleElement(gte->operand(0), base_ptr), b_, module_); +} + +// Returns true if `value` has a name that should not be changed. +static bool HasMeaningfulName(llvm::Value* value) { + if (auto* global = llvm::dyn_cast(value)) { + return global->getLinkage() != llvm::GlobalValue::PrivateLinkage; + } + return false; } llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, - const ShapeIndex& shape_index, + ShapeIndexView shape_index, llvm::Value* ir_value) { llvm::Type* pointee_type = llvm_ir::ShapeToIrType( ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_); @@ -145,20 +160,24 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, llvm::Value* typed_ir_value; if (llvm::isa(ir_value)) { - typed_ir_value = llvm::ConstantExpr::getBitCast( + typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast( llvm::cast(ir_value), dest_type); } else { - typed_ir_value = - ir_builder_->CreateBitCast(ir_value, pointee_type->getPointerTo()); + typed_ir_value = b_->CreateBitCast(ir_value, pointee_type->getPointerTo()); + } + if (!HasMeaningfulName(ir_value)) { + ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw"))); + } + if (!HasMeaningfulName(typed_ir_value)) { + typed_ir_value->setName( + llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed"))); } - ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw"))); - typed_ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed"))); return typed_ir_value; } void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value, - const ShapeIndex& shape_index) { + ShapeIndexView shape_index) { VLOG(2) << "Binding " << hlo.ToString(); const Shape& hlo_shape = hlo.shape(); @@ -202,7 +221,7 @@ llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo, << " of " << hlo.ToString(); llvm_ir::IrArray ir_array(base_ptr, ShapeUtil::GetSubshape(hlo.shape(), shape_index)); - alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array); + alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array, shape_index); // The GPU backend emits one kernel per top-level HLO, and LLVM views // execution of one kernel as the "whole program" executed on the GPU. diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index 3d34311b4368d17cb074aaf33c71fc865e96387e..eee40b0e91fc03013a6978ae3cfe42b87633eed7 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -36,14 +36,13 @@ class HloToIrBindings { public: HloToIrBindings(const HloModule& module, const BufferAssignment* buffer_assignment, - llvm::IRBuilder<>* ir_builder, llvm::Module* llvm_module, + llvm::IRBuilder<>* b, llvm::Module* llvm_module, bool is_nested) : buffer_assignment_(buffer_assignment), is_nested_(is_nested), - ir_builder_(ir_builder), + b_(b), module_(llvm_module), - alias_analysis_(module, *buffer_assignment_, - &ir_builder_->getContext()) {} + alias_analysis_(module, *buffer_assignment_, &b_->getContext()) {} void EmitBasePointersForHlos( tensorflow::gtl::ArraySlice io_hlos, @@ -51,7 +50,7 @@ class HloToIrBindings { // Rebinds the given HLO to the LLVM IR value that represent its address. void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value, - const ShapeIndex& shape_index = {}); + ShapeIndexView shape_index = {}); // Unbinds all IR values that's defined in an LLVM function, e.g., function // arguments and stack variables. Global variables will be kept in bindings_. @@ -71,7 +70,7 @@ class HloToIrBindings { // A helper method that returns the base pointer of the IrArray containing the // output of "inst".at the given ShapeIndex. llvm::Value* GetBasePointer(const HloInstruction& hlo, - const ShapeIndex& shape_index = {}) const { + ShapeIndexView shape_index = {}) const { auto it = base_ptrs_.find(&hlo); CHECK(it != base_ptrs_.end()) << hlo.ToString(); return it->second.element(shape_index); @@ -97,14 +96,14 @@ class HloToIrBindings { // Returns an llvm typed ir representation of 'ir_value' based on 'hlo' shape. llvm::Value* GetTypedIrValue(const HloInstruction& hlo, - const ShapeIndex& shape_index, + ShapeIndexView shape_index, llvm::Value* ir_value); const BufferAssignment* buffer_assignment_; const bool is_nested_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; llvm::Module* module_; // Stores the underlying llvm::IrArray for each HloInstruction. diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc index ae310beefad0c81c17fd4140b441b3a19a002e2c..c5f0cdf6cd5d3e076bffa875fbba991bf0681ee8 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc @@ -15,76 +15,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" -#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/core/platform/logging.h" namespace xla { namespace gpu { -InfeedManager::InfeedManager() : host_to_device_executor_(nullptr) {} - -void InfeedManager::Reset() { - tensorflow::mutex_lock l(mu_); - CHECK(dequeued_buffer_.empty()); - for (auto buffer : enqueued_buffer_) { - buffer->Done(); - } - enqueued_buffer_.clear(); -} - -void InfeedManager::EnqueueBuffers(const std::vector& buffers) { - tensorflow::mutex_lock l(mu_); - bool was_empty = enqueued_buffer_.empty(); - for (gpu::InfeedBuffer* b : buffers) { - enqueued_buffer_.push_back(b); - } - if (was_empty) { - // This has the potential to suffer from the notified thread - // immediately trying and failing to acquire mu_, but seems - // preferable to the alternative of notifying outside the lock - // on every enqueue. - cv_.notify_one(); - } -} - -InfeedBuffer* InfeedManager::BlockingDequeueBuffer() { - bool became_empty = false; - InfeedBuffer* current_buffer; - { - tensorflow::mutex_lock l(mu_); - while (enqueued_buffer_.empty()) { - cv_.wait(l); - } - current_buffer = enqueued_buffer_.front(); - enqueued_buffer_.pop_front(); - dequeued_buffer_.insert(current_buffer); - if (enqueued_buffer_.empty()) { - became_empty = true; - } - } - if (became_empty) { - for (const auto& callback : on_empty_callbacks_) { - callback(); - } - } - return current_buffer; -} - -void InfeedManager::ReleaseBuffers(const std::vector& buffers) { - { - tensorflow::mutex_lock l(mu_); - for (gpu::InfeedBuffer* b : buffers) { - CHECK(ContainsKey(dequeued_buffer_, b)); - dequeued_buffer_.erase(b); - } - } - for (gpu::InfeedBuffer* b : buffers) { - b->Done(); - } -} - se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) { + tensorflow::mutex_lock l(host_to_device_stream_mu_); if (host_to_device_executor_ == nullptr) { host_to_device_executor_ = executor; host_to_device_stream_ = MakeUnique(executor); @@ -100,10 +37,6 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) { return host_to_device_stream_.get(); } -void InfeedManager::RegisterOnEmptyCallback(std::function callback) { - on_empty_callbacks_.push_back(std::move(callback)); -} - InfeedManager* GetOrCreateInfeedManager() { static InfeedManager* manager = new InfeedManager; return manager; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.h b/tensorflow/compiler/xla/service/gpu/infeed_manager.h index a3fc15cfe36a490f38daabca9ff36fbb1012aead..7e418882e051a77e10bd12000bbc9769980f5f14 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.h @@ -20,12 +20,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_ -#include -#include - +#include "tensorflow/compiler/xla/service/gpu/xfeed_queue.h" +#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -47,90 +44,41 @@ namespace gpu { // the client. The client manages the memory of the buffer. class InfeedBuffer { public: + InfeedBuffer() = default; InfeedBuffer(se::StreamExecutor* executor, int64 length) - : executor_(executor), length_(length) { - device_memory_ = executor_->AllocateArray(length); - CHECK(!device_memory_.is_null()); + : device_memory_(executor, executor->AllocateArray(length)), + length_(length) { + CHECK(!device_memory_->is_null()); } - ~InfeedBuffer() { executor_->Deallocate(&device_memory_); } - int64 length() const { return length_; } - // Callback to signal that this buffer is consumed. This helps the - // client to manage memory for the infeed buffers. - void Done() { delete this; } - - se::DeviceMemoryBase* device_memory() { return &device_memory_; } + se::DeviceMemoryBase* device_memory() { return device_memory_.ptr(); } private: - se::StreamExecutor* executor_; // Not owned. - const int64 length_; - se::DeviceMemoryBase device_memory_; + se::ScopedDeviceMemory device_memory_; + int64 length_; }; // Client-side class used to enqueue infeed buffers. -class InfeedManager { +class InfeedManager : public XfeedQueue> { public: - InfeedManager(); - - // Calls the completion callback for any enqueued buffers that have - // not been dequeued by the runtime, and empties the infeed - // queue. Reset may not be called while a runtime computation is - // processing a dequeued buffer. The only safe way to ensure this - // condition is to call Reset when no computation is taking place. - void Reset(); - - // Adds a set of buffers to the infeed queue atomically. buffer->Done - // will be called when the buffer will no longer be accessed by the - // InfeedManager, either as a result of a call to Reset or because the - // runtime has dequeued and used the buffer. - void EnqueueBuffers(const std::vector& buffers); - - // Blocks until the infeed queue is non-empty, then returns the - // buffer at the head of the queue. Adds the current buffer to the - // to-be released set. - InfeedBuffer* BlockingDequeueBuffer(); - - // Releases a set of buffers from the to-be released set. - void ReleaseBuffers(const std::vector& buffers); - // Returns a cached stream associated with an executor. Allocates a // new stream on the first invocation. On subsequent invocations, if // the cached executor is not the same as the requested executor, // returns null. se::Stream* GetStream(se::StreamExecutor* executor); - // Registers a callback that will be called when 'enqueued_buffer_' becomes - // empty. - void RegisterOnEmptyCallback(std::function callback); - private: - // TODO(b/30467474): Revisit if this mutex becomes a point of - // contention. - tensorflow::mutex mu_; - - // Condition variable that is signaled every time a buffer is - // enqueued to an empty queue. - tensorflow::condition_variable cv_; - - // InfeedBuffer* queue contents are not owned, but buffer->Done must - // be called when the buffer is no longer needed by the runtime. - std::deque enqueued_buffer_; - - // Buffers that are dequeued and currently being processed by the - // runtime. Not owned. - tensorflow::gtl::FlatSet dequeued_buffer_; + // Mutex for serializing the creation of host_to_device_stream_. + tensorflow::mutex host_to_device_stream_mu_; // Cached host to device stream for queuing infeed data. - std::unique_ptr host_to_device_stream_; + std::unique_ptr host_to_device_stream_ + GUARDED_BY(host_to_device_stream_mu_); // Executor that the host_to_device_stream belongs to. Not owned. - se::StreamExecutor* host_to_device_executor_; - - // List of callbacks which will be called when 'enqueued_buffer_' becomes - // empty. - std::vector> on_empty_callbacks_; + se::StreamExecutor* host_to_device_executor_ = nullptr; }; // Singleton creator-or-accessor: Returns the GPU infeed manager. diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index ea34d5b30c91e8b809e3e17a904e27e589fd6b5f..fee6d2af3bfd4976f5845edf592e8310b55a3feb 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -22,57 +23,82 @@ namespace xla { namespace gpu { InfeedThunk::InfeedThunk( - tensorflow::gtl::ArraySlice tuple_element_buffers, - const BufferAllocation::Slice& destination_buffer, + const ShapeTree& infeed_slices, const HloInstruction* hlo_instruction) - : Thunk(Kind::kInfeed, hlo_instruction), - tuple_element_buffers_(tuple_element_buffers.begin(), - tuple_element_buffers.end()), - destination_buffer_(destination_buffer) {} + : Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {} Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { - VLOG(2) << "Infeeding to GPU "; + se::Stream* stream, + HloExecutionProfiler* profiler) { + VLOG(2) << "Infeeding to GPU: " << hlo_instruction()->ToString(); + + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); + ShapeTree infeed_buffers = + GetOrCreateInfeedManager()->BlockingGetNextDestination(); + + { + // The infeed buffer has an extra outer tuple with a token. Adjust the index + // accordingly. + ShapeIndex index = {0}; + std::function*)> copy_tuple_contents = + [&](std::vector* tuple_element_addresses) { + const Shape& shape = ShapeUtil::GetSubshape(infeed_buffers.shape(), + ShapeIndexView(index, 1)); + // For the leaf buffers of the tuple copy the elements directly. + if (ShapeUtil::IsArray(shape)) { + const BufferAllocation::Slice& tuple_element_buffer = + infeed_slices_.element(index); + se::DeviceMemoryBase tuple_element_address = + buffer_allocations.GetDeviceAddress(tuple_element_buffer); - se::DeviceMemoryBase destination_address = - buffer_allocations.GetDeviceAddress(destination_buffer_); + InfeedBuffer* buffer = + infeed_buffers.mutable_element(ShapeIndexView(index, 1)); + stream->ThenMemcpy(&tuple_element_address, + *(buffer->device_memory()), buffer->length()); + tuple_element_addresses->push_back(tuple_element_address.opaque()); + return; + } + + const int64 tuple_element_count = ShapeUtil::TupleElementCount(shape); + index.push_back(0); + std::vector inner_tuple_element_addresses; + for (int64 i = 0; i < tuple_element_count; ++i) { + index.back() = i; + copy_tuple_contents(&inner_tuple_element_addresses); + } + index.pop_back(); + + // Create a buffer of pointers for non-leaf buffers. + CHECK_EQ(tuple_element_count, inner_tuple_element_addresses.size()); + auto host_size = inner_tuple_element_addresses.size() * sizeof(void*); + se::DeviceMemoryBase tuple_address = + buffer_allocations.GetDeviceAddress( + infeed_slices_.element(index)); + stream->ThenMemcpy(&tuple_address, + inner_tuple_element_addresses.data(), host_size); + tuple_element_addresses->push_back(tuple_address.opaque()); + }; - InfeedManager* infeed_manager = GetOrCreateInfeedManager(); - std::vector infeed_buffers; - if (ShapeUtil::IsTuple(hlo_instruction()->shape())) { - CHECK(!ShapeUtil::IsNestedTuple(hlo_instruction()->shape())); - // Transfer the tuple elements first. std::vector tuple_element_addresses; - for (BufferAllocation::Slice tuple_element_buffer : - tuple_element_buffers_) { - se::DeviceMemoryBase tuple_element_address = - buffer_allocations.GetDeviceAddress(tuple_element_buffer); - - InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); - infeed_buffers.push_back(buffer); - stream->ThenMemcpy(&tuple_element_address, *(buffer->device_memory()), - buffer->length()); - tuple_element_addresses.push_back(tuple_element_address.opaque()); - } - // Transfer the tuple outer buffer. - auto host_size = tuple_element_addresses.size() * sizeof(void*); - stream->ThenMemcpy(&destination_address, tuple_element_addresses.data(), - host_size); - } else { - InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); - infeed_buffers.push_back(buffer); - stream->ThenMemcpy(&destination_address, *(buffer->device_memory()), - buffer->length()); + copy_tuple_contents(&tuple_element_addresses); + CHECK_EQ(1, tuple_element_addresses.size()); } + // Construct top-level tuple of infeed containing the data and the token. Use + // a nullptr for the token, it should never be dereferenced. + se::DeviceMemoryBase data_address = + buffer_allocations.GetDeviceAddress(infeed_slices_.element({0})); + void* infeed_addresses[] = {data_address.opaque(), nullptr}; + se::DeviceMemoryBase top_level_address = + buffer_allocations.GetDeviceAddress(infeed_slices_.element({})); + stream->ThenMemcpy(&top_level_address, infeed_addresses, 2 * sizeof(void*)); + Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to complete data transfer on stream %p: %s", stream, block_status.error_message().c_str()); } - infeed_manager->ReleaseBuffers(infeed_buffers); - VLOG(2) << "Infeeding to GPU complete"; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index 93713cb12defd95bdd69cb0aa7ad7b4e37fc8fae..59487e245b78e66c45409fe712e86d3392e50580 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -32,23 +33,19 @@ namespace gpu { class InfeedThunk : public Thunk { public: // Constructs a InfeedThunk that copies data from the on-device - // infeed queue to the device buffer - // `destination_buffer`. `mem_size` is the size of the data in - // bytes. - InfeedThunk(tensorflow::gtl::ArraySlice - tuple_element_buffers, - const BufferAllocation::Slice& destination_buffer, + // infeed queue into the buffers in the given shape tree. + InfeedThunk(const ShapeTree& infeed_slices, const HloInstruction* hlo_instruction); InfeedThunk(const InfeedThunk&) = delete; InfeedThunk& operator=(const InfeedThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: - const std::vector tuple_element_buffers_; - const BufferAllocation::Slice destination_buffer_; + const ShapeTree infeed_slices_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 6c4519185b34989eb53c884ba214d69b824b113c..0f2c83aeb2633a007559d8caac78ea2d233539ed 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -40,6 +40,7 @@ bool IsFusile(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kDynamicSlice || hlo.opcode() == HloOpcode::kDynamicUpdateSlice || hlo.opcode() == HloOpcode::kFusion || + hlo.opcode() == HloOpcode::kGather || hlo.opcode() == HloOpcode::kPad || hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kReduceWindow || @@ -72,6 +73,67 @@ bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { } } +// This function limits the maximum number of operands to a fusion. +// +// There's a cap on how many parameters we can pass to a CUDA kernel, but +// exactly what that limit is is hazy, as it depends on (among other things) how +// much GPU constant memory is in use for other purposes. +// +// Moreover, we don't even know at the point that we're running fusion how many +// arguments the CUDA kernel for a fusion node will have: It depends on buffer +// assignment, where we will decide which of the fusion's operands live in XLA's +// big temp buffer versus in other allocations. +// +// As a heuristic, we simply cap the number of fusion operands plus outputs at +// kMaxOperandsAndOutputsPerFusion. This puts an upper bound on the number of +// parameters to the kernel, working around the correctness problem. +// +// This limit is also often good for performance. In a fusion with many +// operands, each GPU thread likely has to do a lot of work, and so possibly +// uses a lot of registers, thus limiting occupancy. +/*static*/ bool GpuInstructionFusion::FusionWouldBeTooLarge( + const HloInstruction* a, const HloInstruction* b) { + // Compute the number of outputs of the (possibly multi-output) fusion node + // we're considering creating. + // + // This isn't precise; we may be off by one if + // - We're creating a multi-output fusion out of two non-MOFs. Creating a + // MOF adds a new buffer, namely, the tuple buffer. + // - We're merging two MOFs. In this case, we should count the tuple buffer + // only once. + // - WLOG there's an edge from `a` to `b` and `b` is the only consumer of + // `a`. In this case the result of `a` is not part of the output of the + // fusion. + // + // But because this is a heuristic and our limit + // kMaxOperandsAndOutputsPerFusion is a large value (so +/- 1 doesn't make a + // big difference), we ignore this small inaccuracy in favor of simplicity. + int64 num_output_buffers = ShapeUtil::SubshapeCount(a->shape()) + + ShapeUtil::SubshapeCount(b->shape()); + + // The new fusion will have no more operands and outputs than + // producer_operands + consumer_operands - 1 + num_output_buffers + // (minus one because we may be fusing a producer->consumer edge between `a` + // and `b`). + // + // This fact may be enough to let us avoid having to compute the true total + // number of operands, which can be expensive. + if (a->operand_count() + b->operand_count() - 1 + num_output_buffers <= + kMaxOperandsAndOutputsPerFusion) { + return false; + } + + // Compute the precise number of operands to the new fusion. + tensorflow::gtl::FlatSet operands( + a->operands().begin(), a->operands().end()); + operands.insert(b->operands().begin(), b->operands().end()); + // If there's an edge between `a` and `b`, don't count it: We're fusing that + // producer -> consumer relationship. + operands.erase(a); + operands.erase(b); + return operands.size() + num_output_buffers > kMaxOperandsAndOutputsPerFusion; +} + bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); @@ -140,6 +202,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, IsIEEEFloatingPointScalarConstant(producer->operand(0)) && fused_parameter_users[0]->opcode() == HloOpcode::kMultiply; } + return false; } // Other output fusions are not currently supported on GPUs. @@ -182,8 +245,13 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return true; } - return IsFusile(*producer) && IsFusile(*consumer) && - InstructionFusion::ShouldFuse(consumer, operand_index); + if (!IsFusile(*producer) || !IsFusile(*consumer) || + !InstructionFusion::ShouldFuse(consumer, operand_index)) { + return false; + } + + // We put this check last because it's potentially expensive. + return !FusionWouldBeTooLarge(consumer, producer); } bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h index f629d9ff2c7165b652369612c30979150f93bd24..c91f6343a69268ca687004dbe0ffbb863271a95c 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -27,6 +27,19 @@ class GpuInstructionFusion : public InstructionFusion { explicit GpuInstructionFusion(bool may_duplicate) : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate) {} + // Maximum number of operands plus outputs allowed on a single fusion node. + // Exposed publicly mainly for tests. + static constexpr int64 kMaxOperandsAndOutputsPerFusion = 64; + + // Determines whether the combination of `a` and `b` into a (possibly + // multi-output) fusion would be "too large" -- i.e., have more operands and + // outputs than is allowed. + // + // `ShouldFuse` and `ShouldFuseIntoMultiOutput` call this; it's public so that + // other fusion passes (e.g. GPU multi-output fusion) can also call this. + static bool FusionWouldBeTooLarge(const HloInstruction* a, + const HloInstruction* b); + static bool IsExpensive(const HloInstruction& instruction); bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 1963d9eef72d41fa0a275bea98f959671fa7e737..8d0522bd8fd6659e64d18c52807df8dc7fc2f3b8 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -33,7 +33,7 @@ TEST_F(InstructionFusionTest, CostlyProducerAndOperandElementReusingConsumerNotFused) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); HloInstruction* broadcast2 = @@ -53,7 +53,7 @@ TEST_F(InstructionFusionTest, NonCostlyProducerAndOperandElementReusingConsumerFused) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); HloInstruction* negate1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, const0)); HloInstruction* broadcast2 = @@ -73,7 +73,7 @@ TEST_F(InstructionFusionTest, CostlyProducerAndNonOperandElementReusingConsumerFused_Reshape) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); HloInstruction* reshape2 = builder.AddInstruction( @@ -92,7 +92,7 @@ TEST_F(InstructionFusionTest, CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); HloInstruction* transpose2 = builder.AddInstruction( @@ -606,5 +606,35 @@ TEST_F(InstructionFusionTest, FuseScalarConstant) { op::Parameter())); } +// Check that we limit the number of operands to fusions we create. +TEST_F(InstructionFusionTest, AvoidsLargeFusion) { + constexpr int64 kNumParams = 200; + ASSERT_GT(kNumParams, GpuInstructionFusion::kMaxOperandsAndOutputsPerFusion); + + // Compute p0 + p1 + ... + pN. + HloComputation::Builder b(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {10, 100}); + auto param0 = + b.AddInstruction(HloInstruction::CreateParameter(0, shape, "p")); + auto sum = param0; + for (int64 i = 1; i < kNumParams; ++i) { + auto param = + b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p")); + sum = b.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sum, param)); + } + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(b.Build()); + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + SCOPED_TRACE(module->ToString()); + for (const HloInstruction* instr : computation->instructions()) { + EXPECT_LE(instr->operand_count(), + GpuInstructionFusion::kMaxOperandsAndOutputsPerFusion) + << instr->ToString(); + } +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 67890bfed1136796c83c7ef6912ffc1ab1b7e332..c349063c71f000435a05306101ad724505f2d197 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -38,40 +38,44 @@ namespace gpu { namespace { // Return whether the given shape is a matrix with no padding. -bool IsRank2WithNoPadding(const Shape& shape) { - return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); +bool IsRank2WithNoPadding(const Shape& shape, int64 batch_dimensions_size) { + return ShapeUtil::Rank(shape) == batch_dimensions_size + 2 && + !LayoutUtil::IsPadded(shape); } // In a gemm operation where output = lhs * rhs, check whether the given shapes // are valid for the operation. bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape) { + const Shape& output_shape, + int64 batch_dimensions_size) { // The inputs and the output must // 1) be matrices with no padding and a non-zero number of elements, // 2) have an allowed element type. PrimitiveType output_primitive_type = output_shape.element_type(); bool type_is_allowed = (output_primitive_type == F16 || output_primitive_type == F32 || - output_primitive_type == F64); - return type_is_allowed && IsRank2WithNoPadding(lhs_shape) && - IsRank2WithNoPadding(rhs_shape) && - IsRank2WithNoPadding(output_shape) && - !ShapeUtil::HasZeroElements(lhs_shape) && - !ShapeUtil::HasZeroElements(rhs_shape); + output_primitive_type == F64 || output_primitive_type == C64); + return type_is_allowed && + IsRank2WithNoPadding(lhs_shape, batch_dimensions_size) && + IsRank2WithNoPadding(rhs_shape, batch_dimensions_size) && + IsRank2WithNoPadding(output_shape, batch_dimensions_size) && + !ShapeUtil::IsZeroElementArray(lhs_shape) && + !ShapeUtil::IsZeroElementArray(rhs_shape); } bool DotImplementedAsGemm(const HloInstruction& dot) { CHECK_EQ(dot.opcode(), HloOpcode::kDot); const Shape& lhs_shape = dot.operand(0)->shape(); const Shape& rhs_shape = dot.operand(1)->shape(); + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); // If gemm can accept the operand shapes, use it rather than a custom // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) { + if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape(), + dim_numbers.lhs_batch_dimensions_size())) { // The size of the reduction dimension should match. The shape inference // guarantees this invariant, so the check here is for programming // errors. - const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); return true; @@ -81,11 +85,6 @@ bool DotImplementedAsGemm(const HloInstruction& dot) { } // namespace bool ImplementedAsGemm(const HloInstruction& hlo) { - // We can only do this if the HLO is unnested. - if (hlo.parent() != hlo.GetModule()->entry_computation()) { - return false; - } - // For certain types of Dot, we can call pre-canned BLAS gemm. if (hlo.opcode() == HloOpcode::kDot) { return DotImplementedAsGemm(hlo); @@ -242,15 +241,17 @@ llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, arguments_ptr}); } -llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset, - llvm::IRBuilder<>* builder) { +llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, + llvm::IRBuilder<>* builder) { int bit_width = value->getType()->getPrimitiveSizeInBits(); + llvm::Value* all_warps_mask = builder->getInt32(-1); // Special case for efficiency if (value->getType()->isFloatTy() && bit_width == 32) { return llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_shfl_down_f32, - {value, offset, builder->getInt32(kWarpSize - 1)}, {}, builder); + llvm::Intrinsic::nvvm_shfl_sync_down_f32, + {all_warps_mask, value, offset, builder->getInt32(kWarpSize - 1)}, {}, + builder); } // We must split values wider than 32 bits as the "shfl" instruction operates @@ -264,10 +265,11 @@ llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset, for (int i = 0; i < num_segments; ++i) { x = builder->CreateInsertElement( x, - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_shfl_down_i32, - {builder->CreateExtractElement(x, i), - offset, builder->getInt32(kWarpSize - 1)}, - {}, builder), + llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_shfl_sync_down_i32, + {all_warps_mask, builder->CreateExtractElement(x, i), offset, + builder->getInt32(kWarpSize - 1)}, + {}, builder), i); } return builder->CreateBitCast( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 59455f389e733fee2d6cace7486f919a0c5e834e..5d23a3d01842c7b4ff405171cd49c96a19f7e5b0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -31,6 +31,12 @@ namespace gpu { constexpr int64 kWarpSize = 32; // Returns true if `hlo` will be implemented as a call to BLAS gemm. +// +// Precondition: `hlo` is in an "unnested context", meaning, it lives within the +// entry computation, within the either of a while loop's subcomputations, +// within any of a conditional's subcomputations, etc., but *does not* live +// within a reduce subcomputation, a map subcomputation, a fusion +// subcomputation, etc. It's OK if `hlo` *is* a fusion. bool ImplementedAsGemm(const HloInstruction& hlo); // A call to cuDNN for batch normalization is represented as CustomCall HLO with @@ -125,13 +131,17 @@ llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, llvm::IRBuilder<>* builder); // Emits code to shuffle data between threads of a warp. This has the same -// semantics as the PTX "shfl.down" instruction [0] but works for values of any -// size. The last operand of the emitted "shfl" is `kWarpSize - 1`. +// semantics as the PTX "shfl.sync.down" instruction but works for values that +// aren't 32 bits in size. The last operand of the emitted "shfl" is +// `kWarpSize - 1`. +// +// This function emits a "full-warp" shuffle, which all threads of a warp +// participate in. *Do not use this function from a divergent context:* You +// can't correctly do so on both Volta and earlier GPUs. // -// [0] -// http://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl -llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset, - llvm::IRBuilder<>* builder); +// https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync +llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, + llvm::IRBuilder<>* builder); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 547af33e9a98c03e1429366172f9a401e385a9d1..6675dbd3f9eef8d13c9dec200e5bf47faa5b514d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -57,47 +57,30 @@ IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config, IrEmitterContext* ir_emitter_context, bool is_nested) : ir_emitter_context_(ir_emitter_context), module_(ir_emitter_context->llvm_module()), - ir_builder_(module_->getContext()), + b_(module_->getContext()), bindings_(ir_emitter_context->hlo_module(), - &ir_emitter_context->buffer_assignment(), &ir_builder_, module_, + &ir_emitter_context->buffer_assignment(), &b_, module_, is_nested), hlo_module_config_(hlo_module_config) { - ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( + b_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_enable_fast_math())); + .xla_gpu_enable_fast_math())); } Status IrEmitter::DefaultAction(HloInstruction* hlo) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : hlo->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*operand, *hlo) - .EmitReadArrayElement(index, &ir_builder_); + return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_); }; } return EmitTargetElementLoop( - *hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_, + *hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &b_, GetNestedComputer()) .MakeElementGenerator(hlo, operand_to_generator)); } Status IrEmitter::HandleConstant(HloInstruction* constant) { - const Literal& literal = constant->literal(); - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, module_); - llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( - *module_, initializer->getType(), - /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, - /*Name=*/""); - VLOG(2) << "HandleConstant: " << constant->ToString() << std::endl - << " emitted_value: " << llvm_ir::DumpToString(*global_for_const) - << std::endl - << " its type: " - << llvm_ir::DumpToString(*global_for_const->getType()); - llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( - global_for_const, - llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); - bindings_.BindHloToIrValue(*constant, shape_constant); return Status::OK(); } @@ -122,15 +105,10 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { get_tuple_element->shape(), get_tuple_element->tuple_index(), // TODO(b/26344050): tighten the alignment here // based on the real element type. - /*alignment=*/1, GetBasePointer(*operand), &ir_builder_, module_)); + /*alignment=*/1, GetBasePointer(*operand), &b_, module_)); return Status::OK(); } -Status IrEmitter::HandleSort(HloInstruction*) { - // TODO(b/26783907): Implement sort on GPU. - return Unimplemented("sort"); -} - Status IrEmitter::HandleSend(HloInstruction*) { return Unimplemented("Send is not implemented on GPU"); } @@ -147,13 +125,16 @@ Status IrEmitter::HandleRecvDone(HloInstruction*) { return Unimplemented("Recv-done is not implemented on GPU"); } +Status IrEmitter::HandleScatter(HloInstruction*) { + return Unimplemented("Scatter is not implemented on GPUs."); +} + Status IrEmitter::HandleTuple(HloInstruction* tuple) { std::vector base_ptrs; for (const HloInstruction* operand : tuple->operands()) { base_ptrs.push_back(GetBasePointer(*operand)); } - llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &ir_builder_, - module_); + llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_, module_); return Status::OK(); } @@ -174,7 +155,7 @@ Status IrEmitter::EmitCallToNestedComputation( std::vector arguments(operands.begin(), operands.end()); arguments.push_back(output); arguments.push_back(bindings_.GetTempBufferBase()); - ir_builder_.CreateCall(emitted_function, arguments); + b_.CreateCall(emitted_function, arguments); return Status::OK(); } @@ -194,45 +175,44 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( HloOpcode root_opcode = computation.root_instruction()->opcode(); PrimitiveType element_type = computation.root_instruction()->shape().element_type(); - llvm::Value* source = ir_builder_.CreateLoad(source_address, "source"); + bool is_atomic_integral = element_type == S32 || element_type == U32 || + element_type == S64 || element_type == U64; + llvm::Value* source = b_.CreateLoad(source_address, "source"); if (root_opcode == HloOpcode::kAdd) { // NVPTX supports atomicAdd on F32 and integer types. if (element_type == F32) { // F32 + F32 llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_atomic_load_add_f32, {output_address, source}, - {output_address->getType()}, &ir_builder_); + {output_address->getType()}, &b_); return true; } - if (primitive_util::IsIntegralType(element_type)) { + if (is_atomic_integral) { // integral + integral - ir_builder_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address, - source, - llvm::AtomicOrdering::SequentiallyConsistent); + b_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } } - // NVPTX supports atomicMax and atomicMin on only integer types. - if (root_opcode == HloOpcode::kMaximum && - primitive_util::IsIntegralType(element_type)) { + // NVPTX supports atomicMax and atomicMin only on integer types. + if (root_opcode == HloOpcode::kMaximum && is_atomic_integral) { // max(integral, integral) auto opcode = primitive_util::IsSignedIntegralType(element_type) ? llvm::AtomicRMWInst::Max : llvm::AtomicRMWInst::UMax; - ir_builder_.CreateAtomicRMW(opcode, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + b_.CreateAtomicRMW(opcode, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } - if (root_opcode == HloOpcode::kMinimum && - primitive_util::IsIntegralType(element_type)) { + if (root_opcode == HloOpcode::kMinimum && is_atomic_integral) { // min(integral, integral) auto opcode = primitive_util::IsSignedIntegralType(element_type) ? llvm::AtomicRMWInst::Min : llvm::AtomicRMWInst::UMin; - ir_builder_.CreateAtomicRMW(opcode, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + b_.CreateAtomicRMW(opcode, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } @@ -304,20 +284,20 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, llvm::Type* element_address_type = element_type->getPointerTo(); int atomic_size = (element_size < 32) ? 32 : element_size; - llvm::Type* atomic_type = ir_builder_.getIntNTy(atomic_size); + llvm::Type* atomic_type = b_.getIntNTy(atomic_size); llvm::Type* atomic_address_type = atomic_type->getPointerTo(output_address_type->getPointerAddressSpace()); // cas_old_output_address and cas_new_output_address point to the scratch // memory where we store the old and new values for the repeated atomicCAS // operations. - llvm::Value* cas_old_output_address = ir_builder_.CreateAlloca( + llvm::Value* cas_old_output_address = b_.CreateAlloca( atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address"); - llvm::Value* cas_new_output_address = ir_builder_.CreateAlloca( + llvm::Value* cas_new_output_address = b_.CreateAlloca( atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address"); // Emit preparation code to the preheader. - llvm::BasicBlock* loop_preheader_bb = ir_builder_.GetInsertBlock(); + llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock(); llvm::Value* atomic_memory_address; // binop_output_address points to the scratch memory that stores the @@ -328,77 +308,71 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, CHECK_EQ((element_size % sizeof(char)), 0); llvm::Type* address_int_type = module_->getDataLayout().getIntPtrType(output_address_type); - atomic_memory_address = - ir_builder_.CreatePtrToInt(output_address, address_int_type); + atomic_memory_address = b_.CreatePtrToInt(output_address, address_int_type); llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3); - llvm::Value* offset = ir_builder_.CreateAnd(atomic_memory_address, mask); + llvm::Value* offset = b_.CreateAnd(atomic_memory_address, mask); mask = llvm::ConstantInt::get(address_int_type, -4); - atomic_memory_address = ir_builder_.CreateAnd(atomic_memory_address, mask); + atomic_memory_address = b_.CreateAnd(atomic_memory_address, mask); atomic_memory_address = - ir_builder_.CreateIntToPtr(atomic_memory_address, atomic_address_type); - binop_output_address = ir_builder_.CreateAdd( - ir_builder_.CreatePtrToInt(cas_new_output_address, address_int_type), - offset); + b_.CreateIntToPtr(atomic_memory_address, atomic_address_type); + binop_output_address = b_.CreateAdd( + b_.CreatePtrToInt(cas_new_output_address, address_int_type), offset); binop_output_address = - ir_builder_.CreateIntToPtr(binop_output_address, element_address_type); + b_.CreateIntToPtr(binop_output_address, element_address_type); } else { atomic_memory_address = - ir_builder_.CreateBitCast(output_address, atomic_address_type); + b_.CreateBitCast(output_address, atomic_address_type); binop_output_address = - ir_builder_.CreateBitCast(cas_new_output_address, element_address_type); + b_.CreateBitCast(cas_new_output_address, element_address_type); } // Use the value from the memory that atomicCAS operates on to initialize // cas_old_output. llvm::Value* cas_old_output = - ir_builder_.CreateLoad(atomic_memory_address, "cas_old_output"); - ir_builder_.CreateStore(cas_old_output, cas_old_output_address); + b_.CreateLoad(atomic_memory_address, "cas_old_output"); + b_.CreateStore(cas_old_output, cas_old_output_address); llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock( - ir_builder_.GetInsertPoint(), "atomic_op_loop_exit"); - llvm::BasicBlock* loop_body_bb = - llvm::BasicBlock::Create(ir_builder_.getContext(), "atomic_op_loop_body", - ir_builder_.GetInsertBlock()->getParent()); - ir_builder_.SetInsertPoint(loop_body_bb); + b_.GetInsertPoint(), "atomic_op_loop_exit"); + llvm::BasicBlock* loop_body_bb = llvm::BasicBlock::Create( + b_.getContext(), "atomic_op_loop_body", b_.GetInsertBlock()->getParent()); + b_.SetInsertPoint(loop_body_bb); // Change preheader's successor from loop_exit_bb to loop_body_bb. loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb); // Emit the body of the loop that repeatedly invokes atomicCAS. // // Use cas_old_output to initialize cas_new_output. - cas_old_output = - ir_builder_.CreateLoad(cas_old_output_address, "cas_old_output"); - ir_builder_.CreateStore(cas_old_output, cas_new_output_address); + cas_old_output = b_.CreateLoad(cas_old_output_address, "cas_old_output"); + b_.CreateStore(cas_old_output, cas_new_output_address); // Emits code to calculate new_output = operation(old_output, source); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( computation, {binop_output_address, source_address}, binop_output_address)); llvm::Value* cas_new_output = - ir_builder_.CreateLoad(cas_new_output_address, "cas_new_output"); + b_.CreateLoad(cas_new_output_address, "cas_new_output"); // Emit code to perform the atomicCAS operation // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output, // cas_new_output); - llvm::Value* ret_value = ir_builder_.CreateAtomicCmpXchg( + llvm::Value* ret_value = b_.CreateAtomicCmpXchg( atomic_memory_address, cas_old_output, cas_new_output, llvm::AtomicOrdering::SequentiallyConsistent, llvm::AtomicOrdering::SequentiallyConsistent); // Extract the memory value returned from atomicCAS and store it as // cas_old_output. - ir_builder_.CreateStore( - ir_builder_.CreateExtractValue(ret_value, 0, "cas_old_output"), - cas_old_output_address); + b_.CreateStore(b_.CreateExtractValue(ret_value, 0, "cas_old_output"), + cas_old_output_address); // Extract the success bit returned from atomicCAS and generate a // conditional branch on the success bit. - ir_builder_.CreateCondBr( - ir_builder_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb, - loop_body_bb); + b_.CreateCondBr(b_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb, + loop_body_bb); // Set the insertion point to the exit basic block so that the caller of // this method can continue emitting code to the right place. - SetToFirstInsertPoint(loop_exit_bb, &ir_builder_); + SetToFirstInsertPoint(loop_exit_bb, &b_); return Status::OK(); } @@ -424,46 +398,49 @@ Status IrEmitter::EmitAtomicOperationForNestedComputation( Status IrEmitter::HandleSelect(HloInstruction* select) { auto pred = select->operand(0); - auto on_true = select->operand(1); - auto on_false = select->operand(2); TF_RET_CHECK(pred->shape().element_type() == PRED); - - if (ShapeUtil::IsTuple(select->shape())) { - llvm_ir::EmitTupleSelect(GetIrArray(*select, *select), - GetIrArray(*pred, *select), - GetBasePointer(*on_true), - GetBasePointer(*on_false), &ir_builder_, module_); - return Status::OK(); - } - // We must not call the subclass `DefaultAction` method, lest its // `HandleSelect` call `IrEmitter::HandleSelect` and its `DefaultAction` // assume no handler has already been called. return IrEmitter::DefaultAction(select); } +Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) { + auto pred = tuple_select->operand(0); + auto on_true = tuple_select->operand(1); + auto on_false = tuple_select->operand(2); + TF_RET_CHECK(pred->shape().element_type() == PRED); + TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape())); + TF_RET_CHECK(ShapeUtil::IsTuple(tuple_select->shape())); + llvm_ir::EmitTupleSelect(GetIrArray(*tuple_select, *tuple_select), + GetIrArray(*pred, *tuple_select), + GetBasePointer(*on_true), GetBasePointer(*on_false), + &b_, module_); + return Status::OK(); +} + namespace { -llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* ir_builder) { - return ir_builder->CreateExtractValue(x, {0}); -} - -llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* ir_builder) { - return ir_builder->CreateExtractValue(x, {1}); -} - -std::pair MultiplyComplex( - llvm::Value* lhs_value, llvm::Value* rhs_value, - llvm::IRBuilder<>* ir_builder) { - llvm::Value* lhs_real = Real(lhs_value, ir_builder); - llvm::Value* lhs_imag = Imag(lhs_value, ir_builder); - llvm::Value* rhs_real = Real(rhs_value, ir_builder); - llvm::Value* rhs_imag = Imag(rhs_value, ir_builder); - llvm::Value* real_result1 = ir_builder->CreateFMul(lhs_real, rhs_real); - llvm::Value* real_result2 = ir_builder->CreateFMul(lhs_imag, rhs_imag); - llvm::Value* real_result = ir_builder->CreateFSub(real_result1, real_result2); - llvm::Value* imag_result1 = ir_builder->CreateFMul(lhs_real, rhs_imag); - llvm::Value* imag_result2 = ir_builder->CreateFMul(lhs_imag, rhs_real); - llvm::Value* imag_result = ir_builder->CreateFAdd(imag_result1, imag_result2); +llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* b) { + return b->CreateExtractValue(x, {0}); +} + +llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* b) { + return b->CreateExtractValue(x, {1}); +} + +std::pair MultiplyComplex(llvm::Value* lhs_value, + llvm::Value* rhs_value, + llvm::IRBuilder<>* b) { + llvm::Value* lhs_real = Real(lhs_value, b); + llvm::Value* lhs_imag = Imag(lhs_value, b); + llvm::Value* rhs_real = Real(rhs_value, b); + llvm::Value* rhs_imag = Imag(rhs_value, b); + llvm::Value* real_result1 = b->CreateFMul(lhs_real, rhs_real); + llvm::Value* real_result2 = b->CreateFMul(lhs_imag, rhs_imag); + llvm::Value* real_result = b->CreateFSub(real_result1, real_result2); + llvm::Value* imag_result1 = b->CreateFMul(lhs_real, rhs_imag); + llvm::Value* imag_result2 = b->CreateFMul(lhs_imag, rhs_real); + llvm::Value* imag_result = b->CreateFAdd(imag_result1, imag_result2); return {real_result, imag_result}; } } // namespace @@ -477,23 +454,29 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { const Shape& lhs_shape = lhs_instruction->shape(); const Shape& rhs_shape = rhs_instruction->shape(); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + CHECK_EQ(dnums.lhs_batch_dimensions_size(), + dnums.rhs_batch_dimensions_size()); + // TODO(b/110211620): Convert to use i32 index_type when it is possible. + llvm::Type* index_type = b_.getInt64Ty(); + llvm_ir::IrArray::Index element_index(index_type); if (ShapeUtil::IsScalar(lhs_shape) && ShapeUtil::IsScalar(rhs_shape)) { // If the operands are scalar, don't emit any loops. llvm::Value* lhs_value = - lhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_); + lhs_array.EmitReadArrayElement(/*index=*/element_index, &b_); llvm::Value* rhs_value = - rhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_); + rhs_array.EmitReadArrayElement(/*index=*/element_index, &b_); llvm::Value* result; if (ShapeUtil::ElementIsComplex(lhs_shape)) { - auto value = MultiplyComplex(lhs_value, rhs_value, &ir_builder_); + auto value = MultiplyComplex(lhs_value, rhs_value, &b_); result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType()); - result = ir_builder_.CreateInsertValue(result, value.first, {0}); - result = ir_builder_.CreateInsertValue(result, value.second, {1}); + result = b_.CreateInsertValue(result, value.first, {0}); + result = b_.CreateInsertValue(result, value.second, {1}); } else { - result = ir_builder_.CreateFMul(lhs_value, rhs_value); + result = b_.CreateFMul(lhs_value, rhs_value); } - target_array.EmitWriteArrayElement(/*index=*/{}, result, &ir_builder_); + target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_); return Status::OK(); } @@ -509,9 +492,15 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { const int64 lhs_reduction_dimension = ShapeUtil::GetDimensionNumber(lhs_shape, -1); const int64 rhs_reduction_dimension = - ShapeUtil::Rank(rhs_shape) >= 2 + ShapeUtil::Rank(rhs_shape) >= 2 + dnums.lhs_batch_dimensions_size() ? ShapeUtil::GetDimensionNumber(rhs_shape, -2) - : 0; + : dnums.lhs_batch_dimensions_size(); + + // Check that the batch dims don't cover the last two dims. + for (int64 batch_dim : dnums.lhs_batch_dimensions()) { + CHECK_NE(lhs_reduction_dimension, batch_dim); + CHECK_NE(rhs_reduction_dimension, batch_dim); + } // Verify the reduction dimension in the two operands are the same size. TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == @@ -520,11 +509,18 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Create loop nests which loop through the LHS operand dimensions and the RHS // operand dimensions. The reduction dimension of the LHS and RHS are handled // in a separate innermost loop which performs the sum of products. - llvm_ir::ForLoopNest loop_nest(IrName(dot), &ir_builder_); - llvm_ir::IrArray::Index lhs_index = EmitOperandArrayLoopNest( - lhs_array, lhs_reduction_dimension, "lhs", &loop_nest); - llvm_ir::IrArray::Index rhs_index = EmitOperandArrayLoopNest( - rhs_array, rhs_reduction_dimension, "rhs", &loop_nest); + llvm_ir::ForLoopNest loop_nest(IrName(dot), &b_); + llvm_ir::IrArray::Index lhs_index = loop_nest.EmitOperandArrayLoopNest( + lhs_array, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs"); + llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest( + rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs"); + + // We don't have to iterate over the batch dimensions in both arrays, simplify + // the loop nest of the rhs. + for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) { + DCHECK(c_linear_search(dnums.lhs_batch_dimensions(), i)); + rhs_index[i] = lhs_index[i]; + } // Create the reduction loop which does the sum of products reduction. std::unique_ptr reduction_loop = loop_nest.AddLoop( @@ -544,7 +540,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { llvm::Value* accum_address = llvm_ir::EmitAllocaAtFunctionEntry( accum_type, // The pointee type of the alloca instruction. "accum_address", // The name of the alloca instruction. - &ir_builder_); + &b_); // Initialize the accumulator in the preheader to zero. new llvm::StoreInst( @@ -558,59 +554,58 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // updated_accum = accum + lhs_element * rhs_element // *accum_address = updated_accum TF_RET_CHECK(!reduction_loop->GetBodyBasicBlock()->empty()); - ir_builder_.SetInsertPoint( + b_.SetInsertPoint( &*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt()); - llvm::Value* lhs_element = - lhs_array.EmitReadArrayElement(lhs_index, &ir_builder_); - llvm::Value* rhs_element = - rhs_array.EmitReadArrayElement(rhs_index, &ir_builder_); - llvm::Value* accum = ir_builder_.CreateLoad(accum_address); + llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_); + llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_); + llvm::Value* accum = b_.CreateLoad(accum_address); llvm::Value* updated_accum; if (ShapeUtil::ElementIsComplex(lhs_shape)) { - auto value = MultiplyComplex(lhs_element, rhs_element, &ir_builder_); - llvm::Value* accum_real = Real(accum, &ir_builder_); - llvm::Value* real_sum = ir_builder_.CreateFAdd(accum_real, value.first); - updated_accum = ir_builder_.CreateInsertValue(accum, real_sum, {0}); - llvm::Value* accum_imag = Imag(accum, &ir_builder_); - llvm::Value* imag_sum = ir_builder_.CreateFAdd(accum_imag, value.second); - updated_accum = ir_builder_.CreateInsertValue(updated_accum, imag_sum, {1}); + auto value = MultiplyComplex(lhs_element, rhs_element, &b_); + llvm::Value* accum_real = Real(accum, &b_); + llvm::Value* real_sum = b_.CreateFAdd(accum_real, value.first); + updated_accum = b_.CreateInsertValue(accum, real_sum, {0}); + llvm::Value* accum_imag = Imag(accum, &b_); + llvm::Value* imag_sum = b_.CreateFAdd(accum_imag, value.second); + updated_accum = b_.CreateInsertValue(updated_accum, imag_sum, {1}); } else { - llvm::Value* product = ir_builder_.CreateFMul(lhs_element, rhs_element); - updated_accum = ir_builder_.CreateFAdd(accum, product); + llvm::Value* product = b_.CreateFMul(lhs_element, rhs_element); + updated_accum = b_.CreateFAdd(accum, product); } - ir_builder_.CreateStore(updated_accum, accum_address); + b_.CreateStore(updated_accum, accum_address); // After the reduction loop exits, store the accumulator into the target // address. The index into the target address is the concatenation of the rhs // and lhs indexes with the reduction dimensions removed. The terms from the // rhs index are the lower dimensions in the index so we add them first. - llvm_ir::IrArray::Index target_index; + llvm_ir::IrArray::Index target_index(index_type); for (size_t dimension = 0; dimension < lhs_index.size(); ++dimension) { if (dimension != lhs_reduction_dimension) { target_index.push_back(lhs_index[dimension]); } } - for (size_t dimension = 0; dimension < rhs_index.size(); ++dimension) { + // Skip over the batch dimensions to not have them in the index twice. + for (size_t dimension = dnums.lhs_batch_dimensions_size(); + dimension < rhs_index.size(); ++dimension) { if (dimension != rhs_reduction_dimension) { target_index.push_back(rhs_index[dimension]); } } - SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &ir_builder_); + SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_); target_array.EmitWriteArrayElement( target_index, - ir_builder_.CreateLoad( - accum_address), // The value written to the target array. - &ir_builder_); + b_.CreateLoad(accum_address), // The value written to the target array. + &b_); // Set the IR builder insert point to the exit basic block of the outer most // loop. This ensures later instructions are inserted after this loop nest. - ir_builder_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); + b_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); return Status::OK(); } Status IrEmitter::HandleConvolution(HloInstruction* convolution) { - if (ShapeUtil::HasZeroElements(convolution->shape())) { + if (ShapeUtil::IsZeroElementArray(convolution->shape())) { // Emit no code for an empty output. return Status::OK(); } @@ -620,7 +615,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { } Status IrEmitter::HandleFft(HloInstruction* fft) { - if (ShapeUtil::HasZeroElements(fft->shape())) { + if (ShapeUtil::IsZeroElementArray(fft->shape())) { // Emit no code for an empty output. return Status::OK(); } @@ -637,6 +632,10 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { } Status IrEmitter::HandleReduce(HloInstruction* reduce) { + // TODO(b/112040122): Support variadic reduce. + if (!ShapeUtil::IsArray(reduce->shape())) { + return Unimplemented("Variadic reduce is not supported on GPU"); + } auto arg = reduce->operand(0); auto init_value = reduce->operand(1); tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); @@ -646,11 +645,10 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { [=](const llvm_ir::IrArray::Index& index) -> StatusOr { // Initialize an accumulator with init_value. llvm::AllocaInst* accumulator_addr = - ir_builder_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType( + b_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType( reduce->shape().element_type(), module_)); - ir_builder_.CreateStore( - ir_builder_.CreateLoad(GetBasePointer(*init_value)), - accumulator_addr); + b_.CreateStore(b_.CreateLoad(GetBasePointer(*init_value)), + accumulator_addr); // The enclosing loops go over all the target elements. Now we have to // compute the actual target element. For this, we build a new loop nest @@ -658,12 +656,12 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { // AddLoopsForShapeOnDimensions will return an Index where induction // Value*s are placed for each dimension in dimensions, and all the rest // are nullptrs. - llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &ir_builder_); + llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_); const llvm_ir::IrArray::Index reduced_dims_index = loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, "reduction_dim"); - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); // Build a full index for the input argument, using reduced_dims_index // as the base. In reduced_dims_index only the reduction dimensions are @@ -682,13 +680,12 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { // Apply the reduction function to the loaded value. llvm::Value* input_address = - GetIrArray(*arg, *reduce) - .EmitArrayElementAddress(input_index, &ir_builder_); + GetIrArray(*arg, *reduce).EmitArrayElementAddress(input_index, &b_); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *function, {accumulator_addr, input_address}, accumulator_addr)); - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); - return ir_builder_.CreateLoad(accumulator_addr); + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); + return b_.CreateLoad(accumulator_addr); }); } @@ -701,8 +698,8 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { for (HloInstruction* operand : fusion->operands()) { parameter_arrays.push_back(GetIrArray(*operand, *fusion)); } - GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, - &ir_builder_, GetNestedComputer()); + GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_, + GetNestedComputer()); FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); @@ -732,24 +729,6 @@ Status IrEmitter::HandleOutfeed(HloInstruction*) { return Unimplemented("Outfeed is not supported on GPU."); } -Status IrEmitter::HandleRng(HloInstruction* random) { - ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; - for (const HloInstruction* operand : random->operands()) { - operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*operand, *random) - .EmitReadArrayElement(index, &ir_builder_); - }; - } - // Emits a single-threaded loop because the loop body generated by the element - // generator for Rng can't be parallelized (b/32333178). - return llvm_ir::LoopEmitter( - GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_, - GetNestedComputer()) - .MakeElementGenerator(random, operand_to_generator), - GetIrArray(*random, *random), &ir_builder_) - .EmitLoop(IrName(random)); -} - Status IrEmitter::HandleBatchNormInference(HloInstruction*) { return Unimplemented( "The GPU backend does not implement BatchNormInference directly. It " @@ -773,34 +752,9 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) { "to a cudnn CustomCall using CudnnBatchNormRewriter."); } -llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest( - const llvm_ir::IrArray& operand_array, int64 reduction_dimension, - tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest) { - // Prepares the dimension list we will use to emit the loop nest. Outermost - // loops are added first. Add loops in major-to-minor order, and skip the - // reduction dimension. - std::vector dimensions; - const Shape& shape = operand_array.GetShape(); - for (int i = 0; i < LayoutUtil::MinorToMajor(shape).size(); ++i) { - int64 dimension = LayoutUtil::Major(shape.layout(), i); - if (dimension != reduction_dimension) { - dimensions.push_back(dimension); - } - } - - // Create loop nest with one for-loop for each dimension of the - // output. - llvm_ir::IrArray::Index index = - loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix); - // Verify every dimension except the reduction dimension was set in the index. - for (size_t dimension = 0; dimension < index.size(); ++dimension) { - if (dimension == reduction_dimension) { - DCHECK_EQ(nullptr, index[dimension]); - } else { - DCHECK_NE(nullptr, index[dimension]); - } - } - return index; +Status IrEmitter::HandleIota(HloInstruction*) { + // TODO(b/64798317): implement iota on GPU. + return Unimplemented("Iota is not implemented on GPU."); } StatusOr IrEmitter::ComputeNestedElement( @@ -809,16 +763,16 @@ StatusOr IrEmitter::ComputeNestedElement( llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType( computation.root_instruction()->shape().element_type(), module_), - "return_buffer", &ir_builder_); + "return_buffer", &b_); std::vector parameter_buffers; for (llvm::Value* parameter_element : parameter_elements) { parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry( - parameter_element->getType(), "parameter_buffer", &ir_builder_)); - ir_builder_.CreateStore(parameter_element, parameter_buffers.back()); + parameter_element->getType(), "parameter_buffer", &b_)); + b_.CreateStore(parameter_element, parameter_buffers.back()); } TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers, return_buffer)); - return ir_builder_.CreateLoad(return_buffer); + return b_.CreateLoad(return_buffer); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index e55dfc6dae844ceb1d28ad389d133c80823bad9a..561c6838798aa92ce2c96b3c45d5ba42fe6edef3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -79,7 +79,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleInfeed(HloInstruction* infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; - Status HandleSort(HloInstruction* sort) override; Status HandleSend(HloInstruction* send) override; Status HandleSendDone(HloInstruction* send_done) override; Status HandleRecv(HloInstruction* recv) override; @@ -87,14 +86,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleParameter(HloInstruction* parameter) override; Status HandleReduce(HloInstruction* reduce) override; Status HandleTuple(HloInstruction* tuple) override; + Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call) override; - Status HandleRng(HloInstruction* random) override; Status HandleBatchNormInference(HloInstruction* batch_norm) override; Status HandleBatchNormTraining(HloInstruction* batch_norm) override; Status HandleBatchNormGrad(HloInstruction* batch_norm) override; + Status HandleIota(HloInstruction* iota) override; Status FinishVisit(HloInstruction* root) override { return Status::OK(); } @@ -161,7 +162,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { // The following fields track the IR emission state. According to LLVM memory // management rules, their memory is owned by the module. - llvm::IRBuilder<> ir_builder_; + llvm::IRBuilder<> b_; // Mapping from HLO to its underlying LLVM value. HloToIrBindings bindings_; @@ -170,17 +171,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { const HloModuleConfig& hlo_module_config_; private: - // Emits a series of nested loops for iterating over an operand array in the - // dot operation. Loops are constructed in major to minor dimension layout - // order. No loop is emitted for the given reduction_dimension. The function - // returns an IrArray index for the given operand_array containing the indvars - // of the loops. All dimensions of the index are filled except for the - // reduction dimension. name_suffix is the string to append to the names of - // LLVM constructs (eg, basic blocks) constructed by this method. - llvm_ir::IrArray::Index EmitOperandArrayLoopNest( - const llvm_ir::IrArray& operand_array, int64 reduction_dimension, - tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest); - // A helper method for EmitAtomicOperationForNestedComputation. Certain // computations, such as floating-point addition and integer maximization, can // be simply implemented using an LLVM atomic instruction. If "computation" is @@ -197,6 +187,13 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* output_address, llvm::Value* source_address); + // A helper method for HandleSort(). It adds the inner comparison loop where + // we compare elements pointed to by 'keys_index' and 'compare_keys_index'. + void EmitCompareLoop(int64 dimension_to_sort, + const llvm_ir::IrArray::Index& keys_index, + const llvm_ir::IrArray::Index& compare_keys_index, + const llvm_ir::IrArray& keys_array); + StatusOr ComputeNestedElement( const HloComputation& computation, tensorflow::gtl::ArraySlice parameter_elements); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index bb47a4280541ce2806472aa9365bb0ef38c0c3b3..5c827e5f9cf3e1c04af444dae338a2ec411ce372 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -70,10 +70,10 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation( argument_dereferenceable_bytes.push_back(root_size); } // The base pointer of the memory block for all pre-allocated temp buffers. - argument_types.push_back(ir_builder_.getInt8PtrTy()); + argument_types.push_back(b_.getInt8PtrTy()); llvm::FunctionType* function_type = - llvm::FunctionType::get(ir_builder_.getVoidTy(), argument_types, false); + llvm::FunctionType::get(b_.getVoidTy(), argument_types, false); llvm::Function* function = llvm::Function::Create( function_type, // The function type. llvm::GlobalValue::InternalLinkage, // The linkage type. @@ -96,8 +96,7 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation( llvm::BasicBlock::Create(function->getContext(), "entry", function); // Emit a "return void" at entry_bb's end, and sets the insert point before // that return instruction. - ir_builder_.SetInsertPoint( - llvm::ReturnInst::Create(function->getContext(), entry_bb)); + b_.SetInsertPoint(llvm::ReturnInst::Create(function->getContext(), entry_bb)); std::vector non_io_hlos; for (const auto* hlo : nested_computation.instructions()) { @@ -120,25 +119,24 @@ Status IrEmitterNested::EmitTargetElementLoop( // For MOF we give the loop emitter an array for every output it should // generate. if (hlo.IsMultiOutputFusion()) { + const int64 num_elems = ShapeUtil::TupleElementCount(hlo.shape()); std::vector target_arrays; - for (int64 i = 0, e = ShapeUtil::TupleElementCount(hlo.shape()); i != e; - ++i) { + target_arrays.reserve(num_elems); + for (int64 i = 0; i != num_elems; ++i) { target_arrays.push_back(GetIrArray(hlo, hlo, {i})); } TF_RETURN_IF_ERROR( - llvm_ir::LoopEmitter(element_generator, target_arrays, &ir_builder_) - .EmitLoop()); + llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop()); std::vector tuple_operand_ptrs; + tuple_operand_ptrs.reserve(num_elems); for (const llvm_ir::IrArray& array : target_arrays) { tuple_operand_ptrs.push_back(array.GetBasePointer()); } - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_, - module_); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_); return Status::OK(); } - return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), - &ir_builder_) + return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_) .EmitLoop(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 726434c3dfd4f1ef866d2eb9b6d7eb8b659e0984..1e81cbde35372d9f7d6ee234d2408038d6f99dc7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -28,11 +28,12 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" @@ -48,30 +49,35 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -79,6 +85,7 @@ namespace gpu { namespace { +using llvm_ir::IrArray; using llvm_ir::IrName; using tensorflow::gtl::ArraySlice; using tensorflow::gtl::InlinedVector; @@ -164,40 +171,6 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { return DfsHloVisitor::Postprocess(hlo); } -namespace { -bool ImplementedAsHostToDeviceMemcpy(const BufferAssignment& buffer_assignment, - const HloInstruction& hlo) { - // `hlo` needs to satisfy the following conditions to be implemented as a - // host-to-device cuMemcpy. - // - // 1. `hlo` is a kCopy instruction. - // 2. `hlo`'s only operand is a kConstant instruction. - // 3. `hlo` and its operand have the same shape (thus the same layout too). - // 4. The address of `hlo`'s buffer is known at runtime (without dereferencing - // pointers in a tuple). - return hlo.opcode() == HloOpcode::kCopy && - hlo.operand(0)->opcode() == HloOpcode::kConstant && - ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && - buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok(); -} - -bool ImplementedAsDeviceToDeviceMemcpy( - const BufferAssignment& buffer_assignment, const HloInstruction& hlo) { - // `hlo` needs to satisfy three conditions to be implemented as a - // device-to-device cuMemcpy. - // - // 1. `hlo` is a kCopy instruction. - // 2. `hlo` and its operand have the same shape (thus the same layout too). - // 3. `hlo` and its operand have a statically-known buffer assignment - // (constants do not, for instance), which means the source buffer also - // resides on the device. - return hlo.opcode() == HloOpcode::kCopy && - ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && - buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok() && - buffer_assignment.GetUniqueTopLevelSlice(hlo.operand(0)).ok(); -} -} // namespace - llvm::Function* IrEmitterUnnested::BuildKernelPrototype( const HloInstruction& inst, tensorflow::gtl::ArraySlice args) { @@ -211,7 +184,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( llvm::LLVMContext& context = module->getContext(); llvm::FunctionType* kernel_type = llvm::FunctionType::get( /*Result=*/llvm::Type::getVoidTy(context), - std::vector(args.size(), ir_builder_.getInt8PtrTy()), + std::vector(args.size(), b_.getInt8PtrTy()), /*isVarArg=*/false); llvm::Function* kernel = llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage, @@ -226,9 +199,20 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( ++arg_it; kernel->addDereferenceableAttr(arg_no + 1, alloc->size()); + + const int64 alignment = [&] { + if (alloc->is_entry_computation_parameter()) { + return kEntryParameterAlignBytes; + } else if (alloc->is_constant()) { + return kConstantBufferAlignBytes; + } else { + return kXlaAllocatedBufferAlignBytes; + } + }(); + kernel->addParamAttr( - arg_no, llvm::Attribute::get(context, llvm::Attribute::Alignment, - kCudaMallocAlignBytes)); + arg_no, + llvm::Attribute::get(context, llvm::Attribute::Alignment, alignment)); if (alloc->IsPreallocatedTempBuffer()) { fn_arg->setName("temp_buf"); @@ -247,7 +231,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( nvvm_annotations_node->addOperand(llvm::MDNode::get( context, {llvm::ConstantAsMetadata::get(kernel), llvm::MDString::get(context, "kernel"), - llvm::ConstantAsMetadata::get(ir_builder_.getInt32(1))})); + llvm::ConstantAsMetadata::get(b_.getInt32(1))})); // Update the insert point to the entry basic block. llvm::BasicBlock* entry_bb = @@ -255,7 +239,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( // Emit a "return void" at entry_bb's end, and set the insert point before // that return instruction. - ir_builder_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb)); + b_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb)); return kernel; } @@ -283,6 +267,69 @@ int ComputeMaxUnrollFactor(const HloInstruction* hlo) { // Cannot unroll. return 1; } + +// Returns the llvm type for the indices used in the kernel that contains the +// hlo instruction. Such indices include the index for the parallel loop and +// the indices for the tensors accessed by the kernel. The return type is i32 +// iff the following conditions are met: +// . The launch_size of the kernel is within the range of i32. +// . The sizes of all the tensors accessed within the kernel are within the +// range of i32. +// Otherwise, the return type is i64. +llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, + llvm::IRBuilder<>* b) { + // Find the unnested hlo instructon for which the kernel is generated for. + const HloInstruction* unnested_hlo = hlo; + const HloComputation* computation = hlo->parent(); + if (computation->IsFusionComputation()) { + unnested_hlo = computation->FusionInstruction(); + } + + auto shape_in_range = [&](const Shape& s) { + bool in_range = true; + ShapeUtil::ForEachSubshape( + s, [&](const Shape& sub_shape, const ShapeIndex& /*index*/) { + if (ShapeUtil::IsArray(sub_shape) && + !IsInt32(ShapeUtil::ElementsIn(sub_shape))) { + in_range = false; + } + }); + + return in_range; + }; + + llvm::Type* i64_ty = b->getInt64Ty(); + // Check launch dimension + if (!IsInt32(launch_size)) { + return i64_ty; + } + + // Check the size of result tensors + if (!shape_in_range(unnested_hlo->shape())) { + return i64_ty; + } + + auto hlo_shape_in_range = [&](const HloInstruction* operand) -> bool { + return shape_in_range(operand->shape()); + }; + + // Check the size of input tensors + if (!c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) { + return i64_ty; + } + + // Check the size of the internal result tensors + if (unnested_hlo->opcode() == HloOpcode::kFusion) { + if (!c_all_of( + unnested_hlo->fused_instructions_computation()->instructions(), + hlo_shape_in_range)) { + return i64_ty; + } + } + + return b->getInt32Ty(); +} + } // namespace Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { @@ -292,21 +339,18 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { unroll_factor = ComputeMaxUnrollFactor(hlo); } - thunk_sequence_->emplace_back(BuildKernelThunk(hlo, unroll_factor)); + thunk_sequence_->emplace_back(BuildKernelThunk( + hlo, /*implements_whole_instruction=*/true, unroll_factor)); return IrEmitter::DefaultAction(hlo); } Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { - const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); - if (dnums.lhs_batch_dimensions_size() > 0 || - dnums.rhs_batch_dimensions_size() > 0) { - return Unimplemented("Dot with batch dimensions not implemented."); - } if (ImplementedAsGemm(*dot)) { thunk_sequence_->emplace_back(BuildGemmThunk(dot)); return Status::OK(); } - thunk_sequence_->emplace_back(BuildKernelThunk(dot)); + thunk_sequence_->emplace_back( + BuildKernelThunk(dot, /*implements_whole_instruction=*/true)); return IrEmitter::HandleDot(dot); } @@ -316,7 +360,8 @@ Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) { } Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) { - thunk_sequence_->emplace_back(BuildKernelThunk(convolution)); + thunk_sequence_->emplace_back( + BuildKernelThunk(convolution, /*implements_whole_instruction=*/true)); return IrEmitter::HandleConvolution(convolution); } @@ -500,6 +545,11 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { switch (root->opcode()) { case HloOpcode::kTuple: case HloOpcode::kReduce: { + if (root->opcode() == HloOpcode::kReduce && + ShapeUtil::IsTuple(root->shape())) { + // TODO(b/112040122): Support variadic reduce. + return Unimplemented("Variadic reduce is not supported on GPU"); + } VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); std::vector> thunks; ArraySlice output_instructions = @@ -523,16 +573,17 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { } } CHECK(first_reduce != nullptr); - thunks.push_back(BuildKernelThunk(fusion)); + thunks.push_back( + BuildKernelThunk(fusion, /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), fusion)); - std::vector parameter_arrays; + std::vector parameter_arrays; for (HloInstruction* operand : fusion->operands()) { parameter_arrays.push_back(GetIrArray(*operand, *fusion)); } GpuElementalIrEmitter elemental_emitter( - hlo_module_config_, ir_emitter_context_->llvm_module(), - &ir_builder_, GetNestedComputer()); + hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, + GetNestedComputer()); FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); @@ -551,17 +602,16 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { if (root->opcode() == HloOpcode::kTuple) { output_shape_index = {i}; } - // TODO(kramerb): CHECK that layouts are equal. Currently this - // breaks multioutputfusion_test. The test has pre-fused - // instructions, but layout_assignment will not assign any layouts - // for instructions inside of a fused computation. It just removes - // the layouts instead. if (inst->opcode() == HloOpcode::kReduce) { - CHECK(ShapeUtil::Compatible(first_reduce->shape(), inst->shape())); - CHECK(ShapeUtil::Compatible(first_reduce->operand(0)->shape(), - inst->operand(0)->shape())); - CHECK(ShapeUtil::Compatible(first_reduce->operand(1)->shape(), - inst->operand(1)->shape())); + CHECK(IsReductionToVector(*inst)) + << "Only reductions to vector are supported"; + // Shapes, layouts and dimensions must be the same for all reduces + // inside of this fusion. + CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape())); + CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(), + inst->operand(0)->shape())); + CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(), + inst->operand(1)->shape())); CHECK(first_reduce->dimensions() == inst->dimensions()); input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); init_value_gens.push_back( @@ -569,8 +619,13 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { reducers.push_back(inst->to_apply()); reduce_output_shapes.push_back(std::move(output_shape_index)); } else { - CHECK(ShapeUtil::Compatible(first_reduce->operand(0)->shape(), - inst->shape())); + // For extra outputs we can relax shape equality to allow different + // types (with the same number of elements). Layouts still have to + // match. + CHECK(ShapeUtil::CompatibleIgnoringElementType( + first_reduce->operand(0)->shape(), inst->shape())); + CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(), + inst->shape().layout())); extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), std::move(output_shape_index)); } @@ -593,21 +648,22 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // touching the un-updated elements. // Set up kernel thunk and fused ir emitter. - thunk_sequence_->emplace_back(BuildKernelThunk(fusion)); - std::vector operand_arrays; + thunk_sequence_->emplace_back( + BuildKernelThunk(fusion, /*implements_whole_instruction=*/true)); + std::vector operand_arrays; for (HloInstruction* operand : fusion->operands()) { operand_arrays.push_back(GetIrArray(*operand, *fusion)); } GpuElementalIrEmitter elemental_emitter(hlo_module_config_, ir_emitter_context_->llvm_module(), - &ir_builder_, GetNestedComputer()); + &b_, GetNestedComputer()); // Shape of the dynamic-update-slice's "update" operand. Shape update_shape = root->operand(1)->shape(); // Array to write into. Because this is an in-place operation, this is the // same as operand 0's array. - llvm_ir::IrArray output_array = GetIrArray(*fusion, *fusion); + IrArray output_array = GetIrArray(*fusion, *fusion); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( update_shape, ir_emitter_context_->device_description()); @@ -618,346 +674,38 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace( fusion, operand_arrays, output_array, &elemental_emitter, - launch_dimensions, &ir_builder_); + launch_dimensions, &b_); } + if (ImplementedAsGemm(*fusion)) { thunk_sequence_->emplace_back(BuildGemmThunk(fusion)); return Status::OK(); } - CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop); - int unroll_factor = ComputeMaxUnrollFactor(fusion); + CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop); - thunk_sequence_->emplace_back(BuildKernelThunk(fusion, unroll_factor)); - return IrEmitter::HandleFusion(fusion); -} - -namespace { - -// Returns the indices of the first elements of all consecutive subarrays of the -// given array. For example: -// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4} -std::vector ConsecutiveSegments(tensorflow::gtl::ArraySlice xs) { - std::vector is = {0}; - for (size_t i = 1; i < xs.size(); ++i) { - if (1 != xs[i] - xs[i - 1]) { - is.push_back(i); - } - } - return is; -} - -// Merges the sequences of dimensions of the given shape which start at the -// given indices `segs`. -Shape MergeDimensions(tensorflow::gtl::ArraySlice segs, - const Shape& shape) { - std::vector dimensions; - for (size_t i = 1; i <= segs.size(); ++i) { - dimensions.push_back(std::accumulate( - shape.dimensions().begin() + segs[i - 1], - shape.dimensions().begin() + - (segs.size() == i ? shape.dimensions().size() : segs[i]), - 1, std::multiplies())); - } - return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), - dimensions); -} - -// Returns whether the given shapes and permutation are a 0-2-1 transpose, and -// if so, the normalized and rank-reduced shapes. The shapes must have the same -// dimensions, so this considers layout only. -// -// This function recognizes higher-rank transposes which are elementwise -// equivalent to a 0-2-1 transpose. -std::tuple IsTranspose021(const Shape& a, const Shape& b) { - CHECK(ShapeUtil::Compatible(a, b)); - std::vector perm(a.dimensions().size()); - { - auto layout_a_orig = LayoutUtil::MinorToMajor(a); - std::vector layout_a(layout_a_orig.rbegin(), layout_a_orig.rend()); - auto layout_b_orig = LayoutUtil::MinorToMajor(b); - std::vector layout_b(layout_b_orig.rbegin(), layout_b_orig.rend()); - for (size_t i = 0; i < perm.size(); ++i) { - perm[i] = PositionInContainer(layout_b, layout_a[i]); - } + if (CheckAndEmitHloWithTile021(fusion)) { + return Status::OK(); } - auto segs = ConsecutiveSegments(perm); - Shape norm_a = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a); - Shape norm_b = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(b); - if (3 == segs.size() && 0 == perm[0]) { - Shape reduced_a = MergeDimensions(segs, norm_a); - Shape reduced_b = ShapeUtil::MakeShapeWithDescendingLayout( - b.element_type(), - Permute({0, 2, 1}, AsInt64Slice(reduced_a.dimensions()))); - return std::make_tuple(true, reduced_a, reduced_b); - } - return std::make_tuple(false, ShapeUtil::MakeNil(), ShapeUtil::MakeNil()); -} - -// Returns whether the given shapes are potentially of a 0-2-1 transpose. -// As 0-2-1 is a self-inverse permutation, which shape is input or output is -// arbitrary. -bool AreShapesForTranspose021(const Shape& a, const Shape& b) { - return 3 == b.dimensions().size() && - ShapeUtil::Compatible( - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a), - ShapeUtil::PermuteDimensions( - {0, 2, 1}, - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - b))); -} - -// Emits a tiled 0-2-1 transpose, assuming both input and output lain out from -// major to minor. The x- and y- dimensions are tiled in square tiles of edge -// length `tile_size`. Each thread block of `tile_size` x `num_rows` threads -// transposes one tile: each thread copies a row from the input to a shared -// memory tile, then copies a column from the shared memory tile to the output. -// -// `tile_size` should usually be same as warp size. -// -// Returns (number of tiles = number of thread blocks needed). -// -// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient -// to launch fewer blocks so each transposes many tiles, and -// in any case, the number of blocks we can launch is limited. -// -// This is the same algorithm in CUDA: -// https://github.com/tensorflow/tensorflow/blob/d2693c8a70567cc78b2e8a9ac8020d321620ca83/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L189 -int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, - const int64 tile_size, const int64 num_rows, - llvm::IRBuilder<>* builder) { - // Adds `addend` to the given `dim` of `index`. - auto offset_dim = [builder](llvm_ir::IrArray::Index index, - llvm::Value* addend, int64 dim) { - index[dim] = builder->CreateAdd(index[dim], addend); - return index; - }; - - CHECK(AreShapesForTranspose021(input.GetShape(), output.GetShape())); - - Shape input_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input.GetShape()); - Shape output_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - output.GetShape()); - input = input.CastToShape(input_shape, builder); - output = output.CastToShape(output_shape, builder); - - llvm::Type* tile_type = llvm::ArrayType::get( - llvm::ArrayType::get(input.GetElementLlvmType(), tile_size), - // One extra here to avoid share memory bank conflict - tile_size + 1); - auto* tile = new llvm::GlobalVariable( - *builder->GetInsertBlock()->getParent()->getParent(), tile_type, - /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, - llvm::UndefValue::get(tile_type), "tile", nullptr, - llvm::GlobalValue::NotThreadLocal, - /*AddressSpace=*/3 /* GPU shared memory */); - - // let x = threadIdx.x - llvm::Value* x = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, num_rows * tile_size, - static_cast(x)); - x = builder->CreateIntCast(x, builder->getInt64Ty(), /*isSigned=*/true, - "thread.id.x"); - - // computing logical thread ids - // logical_x = x % tile_size - auto logical_x = builder->CreateURem(x, builder->getInt64(tile_size)); - - // logical_y = x / tile_size - auto logical_y = builder->CreateUDiv(x, builder->getInt64(tile_size)); - - // `emit_cp` emits equivalent to following pseudocode: - // if (tile_size == tile_width && tile_size == tile_height) { - // unroll for (i in range(0, tile_size, num_rows)) { - // emit_cp_element(index + {0, i, 0}, y + logical_y); - // } - // } else if (x < tile_width) { - // tile_height_upperbound = ceil(tile_height / num_rows) * num_rows; - // for (i in range(0, tile_height_upperbound, num_rows)) { - // y_loc = i + logical_y; - // if (y_loc < tile_height) - // emit_cp_element(index + {0, i, 0}, y_loc); - // } - // } - // - // We use this to emit both the copy from input to tile and the copy from tile - // to output. - // - // `index` is the origin of the row or column in the input or output array. - // - // `emit_cp_element(index, y)` emits code to copy a single element between the - // tile and the input or output array, where `y` is the `y`-position in the - // tile, whether which is row or column is a function of whether we're copying - // from input or to output, and `index` is the index into the input or output - // array. - auto emit_cp_tile = [builder, tile_size, &offset_dim, num_rows, logical_x, - logical_y]( - std::function - emit_cp_element, - llvm::Value* tile_width, llvm::Value* tile_height, - const llvm_ir::IrArray::Index& index, - const string& loop_name) { - llvm_ir::LlvmIfData if_not_last_row = llvm_ir::EmitIfThenElse( - builder->CreateAnd( - builder->CreateICmpEQ(builder->getInt64(tile_size), tile_width), - builder->CreateICmpEQ(builder->getInt64(tile_size), tile_height)), - "not_last_row", builder); - builder->SetInsertPoint(if_not_last_row.true_block->getTerminator()); - for (int64 i = 0; i < tile_size; i += num_rows) { - auto source_idx = offset_dim(index, builder->getInt64(i), /*dim=*/1); - auto y_loc = builder->CreateAdd(builder->getInt64(i), logical_y); - emit_cp_element(source_idx, y_loc); - } - builder->SetInsertPoint(if_not_last_row.false_block->getTerminator()); - llvm_ir::LlvmIfData if_in_tile = llvm_ir::EmitIfThenElse( - builder->CreateICmpULT(logical_x, tile_width), "x_in_tile", builder); - builder->SetInsertPoint(if_in_tile.true_block->getTerminator()); - - // tile_height_upper_bound = ceil(tile_height / num_rows) * num_rows - auto tile_height_upper_bound = builder->CreateMul( - builder->CreateUDiv( - builder->CreateAdd(tile_height, builder->getInt64(num_rows - 1)), - builder->getInt64(num_rows)), - builder->getInt64(num_rows)); - - auto loop = llvm_ir::ForLoop::EmitForLoop( - loop_name, builder->getInt64(0), tile_height_upper_bound, - builder->getInt64(num_rows), builder); - llvm_ir::SetToFirstInsertPoint(loop->GetHeaderBasicBlock(), builder); - builder->SetInsertPoint(loop->GetBodyBasicBlock()->getTerminator()); - - auto y_loc = builder->CreateAdd(loop->GetIndVarValue(), logical_y); - auto if_y_in_tile = llvm_ir::EmitIfThenElse( - builder->CreateICmpULT(y_loc, tile_height), "y_in_tile", builder); - builder->SetInsertPoint(if_y_in_tile.true_block->getTerminator()); - - emit_cp_element(offset_dim(index, loop->GetIndVarValue(), /*dim=*/1), - y_loc); - builder->SetInsertPoint(if_not_last_row.after_block->getTerminator()); - }; - auto input_dims_in_tiles = input_shape.dimensions(); - // Unpermuted dimensions are untiled. - for (int i = 1; i < 3; ++i) { - input_dims_in_tiles[i] = - CeilOfRatio(input_dims_in_tiles[i], tile_size); - } - int64 num_tiles = - std::accumulate(input_dims_in_tiles.begin(), input_dims_in_tiles.end(), 1, - std::multiplies()); - const llvm_ir::IrArray::Index input_tile_index( - /*linear=*/builder->CreateIntCast( - llvm_ir::AddRangeMetadata( - 0, num_tiles, - static_cast(llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, - builder))), - builder->getInt64Ty(), /*isSigned=*/true, "block.id.x"), - ShapeUtil::MakeShapeWithDescendingLayout( - PRED /*arbitrary*/, AsInt64Slice(input_dims_in_tiles)), - builder); - const llvm_ir::IrArray::Index input_tile_origin = ({ - llvm_ir::IrArray::Index index = input_tile_index; - for (int i = 1; i < 3; ++i) { - index[i] = builder->CreateMul(index[i], builder->getInt64(tile_size), - "tile_origin." + std::to_string(i)); - } - index; - }); - const llvm_ir::IrArray::Index input_index = - offset_dim(offset_dim(input_tile_origin, logical_x, /*dim=*/2), logical_y, - /*dim=*/1); - std::vector tile_dims(input_shape.dimensions().size()); - // Only last row or column may not have full size. - for (int i = 1; i < 3; ++i) { - tile_dims[i] = builder->CreateSelect( - builder->CreateICmpEQ(input_tile_index[i], - builder->getInt64(input_dims_in_tiles[i] - 1)), - builder->getInt64(input_shape.dimensions(i) - - (input_dims_in_tiles[i] - 1) * tile_size), - builder->getInt64(tile_size), "tile_size"); - } - - // Load data from input memory to shared memory tile. - emit_cp_tile( - // tile[y, x] = input_array[index] - [builder, tile, &input, logical_x](const llvm_ir::IrArray::Index& index, - llvm::Value* y) { - builder->CreateStore( - input.EmitReadArrayElement(index, builder, "input_element"), - builder->CreateGEP(tile, {builder->getInt64(0), y, logical_x})); - }, - tile_dims[2], tile_dims[1], input_index, "input"); + int unroll_factor = ComputeMaxUnrollFactor(fusion); - // Wait for all threads to reach this point, lest we copy a value from tile to - // output before the other thread copies it from input to tile. - // This is `__syncthreads` in CUDA. - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, builder); - - const llvm_ir::IrArray::Index output_tile_index( - Permute({0, 2, 1}, input_tile_index.multidim())); - const llvm_ir::IrArray::Index output_tile_origin( - Permute({0, 2, 1}, input_tile_origin.multidim())); - const llvm_ir::IrArray::Index output_index = - offset_dim(offset_dim(output_tile_origin, logical_x, /*dim=*/2), - logical_y, /*dim=*/1); - - // Store data from shared memory tile to output memory. - emit_cp_tile( - // output_array[index] = tile[x, y] - [builder, tile, &output, logical_x](const llvm_ir::IrArray::Index& index, - llvm::Value* y) { - output.EmitWriteArrayElement( - index, - builder->CreateLoad( - builder->CreateGEP(tile, {builder->getInt64(0), logical_x, y}), - "output_element"), - builder); - }, - tile_dims[1], tile_dims[2], output_index, "output"); - - return num_tiles; + thunk_sequence_->emplace_back(BuildKernelThunk( + fusion, /*implements_whole_instruction=*/true, unroll_factor)); + return IrEmitter::HandleFusion(fusion); } -} // namespace - Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { - if (ImplementedAsHostToDeviceMemcpy(ir_emitter_context_->buffer_assignment(), - *copy)) { - thunk_sequence_->emplace_back(BuildHostToDeviceCopyThunk(copy)); - return Status::OK(); - } - if (ImplementedAsDeviceToDeviceMemcpy( - ir_emitter_context_->buffer_assignment(), *copy)) { + CHECK(ShapeUtil::Compatible(copy->operand(0)->shape(), copy->shape())); + const BufferAssignment& buffer_assignment = + ir_emitter_context_->buffer_assignment(); + if (LayoutUtil::Equal(copy->operand(0)->shape().layout(), + copy->shape().layout()) && + buffer_assignment.GetUniqueTopLevelSlice(copy->operand(0)).ok()) { thunk_sequence_->emplace_back(BuildDeviceToDeviceCopyThunk(copy)); return Status::OK(); } - bool is_transpose_021; - Shape reduced_input_shape, reduced_output_shape; - std::tie(is_transpose_021, reduced_input_shape, reduced_output_shape) = - IsTranspose021(copy->operand(0)->shape(), copy->shape()); - if (is_transpose_021 && - reduced_input_shape.dimensions(1) >= kMinDimensionToTransposeTiled && - reduced_input_shape.dimensions(2) >= kMinDimensionToTransposeTiled) { - thunk_sequence_->emplace_back(BuildKernelThunk(copy)); - VLOG(3) << "Emitting tiled 0-2-1 transposition"; - constexpr int64 tile_size = 32; - constexpr int64 num_rows = 8; - int64 num_tiles = EmitTranspose021Tiled( - GetIrArray(*copy->operand(0), *copy) - .CastToShape(reduced_input_shape, &ir_builder_), - GetIrArray(*copy, *copy) - .CastToShape(reduced_output_shape, &ir_builder_), - tile_size, num_rows, &ir_builder_); - UpdateLaunchDimensions(LaunchDimensions(num_tiles, num_rows * tile_size), - LastThunk(), ir_emitter_context_->llvm_module()); + if (CheckAndEmitHloWithTile021(copy)) { return Status::OK(); } @@ -965,7 +713,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { } Status IrEmitterUnnested::EmitExtraOutputsForReduce( - const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, + const HloInstruction* reduce, const IrArray::Index& index, tensorflow::gtl::ArraySlice< std::pair> extra_output_gens) { @@ -973,11 +721,11 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( const HloInstruction* output = reduce->parent()->FusionInstruction(); llvm::Value* extra_output_address = GetIrArray(*output, *output, extra_output_gens[i].second) - .EmitArrayElementAddress(index, &ir_builder_, + .EmitArrayElementAddress(index, &b_, "extra_output_element_address"); TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, extra_output_gens[i].first(index)); - ir_builder_.CreateStore(extra_output_ir_value, extra_output_address); + b_.CreateStore(extra_output_ir_value, extra_output_address); } return Status::OK(); } @@ -1002,6 +750,18 @@ Status IrEmitterUnnested::EmitReductionToScalar( int64 num_tiles = RoundUpToNearest(CeilOfRatio(num_elems, kTileSize), kWarpSize); + Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( + reduce->shape().element_type(), {num_tiles}, {0}); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + tiled_input_shape, ir_emitter_context_->device_description()); + + llvm::Type* index_ty = + GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_); + + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + // Check whether every thread will process a full tile's worth of elements // without reading outside the bounds of the input. If this is true, we can // skip some bounds checks in the final algorithm. @@ -1040,57 +800,57 @@ Status IrEmitterUnnested::EmitReductionToScalar( // // and threads_per_block is a multiple of warpSize. // reduce_kernel<<>>(); // - auto loop_body_emitter = - [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { const int num_reduces = reducers.size(); llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); std::vector partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); + llvm::Value* partial_reduction_result_address = + b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](llvm_ir::IrArray::Index({}))); - ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + init_value_gens[i](IrArray::Index(index_ty))); + b_.CreateStore(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } llvm::Value* x_in_tiles = tile_index[0]; + x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty); // Emit an inner for-loop that reduces the elements in the tile. auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop("element_id_in_tile", - ir_builder_.getInt64(0), - ir_builder_.getInt64(kTileSize), - ir_builder_.getInt64(1), &ir_builder_); + llvm_ir::ForLoop::EmitForLoop( + "element_id_in_tile", index_typed_constant(0), + index_typed_constant(kTileSize), index_typed_constant(1), &b_); // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &ir_builder_); - llvm::Value* x = ir_builder_.CreateNSWAdd( - ir_builder_.CreateNSWMul(x_in_tiles, ir_builder_.getInt64(kTileSize)), + &b_); + llvm::Value* x = b_.CreateNSWAdd( + b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)), tile_element_loop->GetIndVarValue()); // Unless we know the tile is entirely in bounds, we have to emit a // x-in-bounds check before reading from the input. if (!tile_in_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(num_elems)), - "x_in_bounds", &ir_builder_); + b_.CreateICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", + &b_); // Emit code that reads the input element and accumulates it to // the partial reduction result. - llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); } - llvm_ir::IrArray::Index input_index( - /*linear=*/x, input_shape, &ir_builder_); - llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); + + IrArray::Index input_index( + /*linear=*/x, input_shape, &b_); + llvm::Value* input_address = b_.CreateAlloca(element_ir_type); for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); + b_.CreateStore(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], input_address}, @@ -1101,49 +861,48 @@ Status IrEmitterUnnested::EmitReductionToScalar( // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's // immediately beyond the tile. - llvm::Value* x_end = ir_builder_.CreateNSWAdd( - ir_builder_.getInt64(kTileSize), - ir_builder_.CreateNSWMul(x_in_tiles, ir_builder_.getInt64(kTileSize))); + llvm::Value* x_end = b_.CreateNSWAdd( + index_typed_constant(kTileSize), + b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize))); // The tile is entirely in bound if all_threads_in_bounds or // x_end <= num_elems. - llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.CreateICmpULE(x_end, ir_builder_.getInt64(num_elems)), - ir_builder_.getInt1(all_threads_in_bounds)); + llvm::Value* tile_in_bounds = + b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(num_elems)), + b_.getInt1(all_threads_in_bounds)); llvm_ir::LlvmIfData if_tile_in_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, - &ir_builder_); + llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_); TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, - &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, &b_); TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); // After the if-then-else statement on tile_in_bounds, emit calls to // shfl_down that accumulate the partial reduction results of all threads // from the warp. - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, - &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, &b_); int bit_width = llvm_ir::GetSizeInBits(element_ir_type); // bitcast cannot be applied to aggregate types (even packed ones), so we // instead bitcast addresses of load/store to intN* of the same bit-width. llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() - ? ir_builder_.getIntNTy(bit_width) + ? b_.getIntNTy(bit_width) : element_ir_type; for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; shuffle_distance /= 2) { - llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( - element_ir_type, nullptr, "result_from_other_lane"); + llvm::Value* result_from_other_lane = + b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane"); for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), + llvm::Value* partial_reduction_result = b_.CreateLoad( + b_.CreateBitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), "partial_reduction_result"); - ir_builder_.CreateStore( - EmitShuffleDown(partial_reduction_result, - ir_builder_.getInt32(shuffle_distance), - &ir_builder_), - ir_builder_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); + CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) + << "Requires block size a multiple of the warp size, otherwise we " + "will read undefined elements."; + b_.CreateStore( + EmitFullWarpShuffleDown(partial_reduction_result, + b_.getInt32(shuffle_distance), &b_), + b_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], result_from_other_lane}, @@ -1157,24 +916,23 @@ Status IrEmitterUnnested::EmitReductionToScalar( // Emit an atomic operation that accumulates the partial reduction result of // lane 0 (which holds the partially accumulated result for its warp) to the // output element. - llvm::Value* lane_id = ir_builder_.CreateURem( - x_in_tiles, ir_builder_.getInt64(kWarpSize), "lane_id"); + llvm::Value* lane_id = + b_.CreateURem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id"); llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpEQ(lane_id, ir_builder_.getInt64(0)), - "lane_id_is_zero", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, - &ir_builder_); + b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", + &b_); + llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); for (int i = 0; i != num_reduces; ++i) { llvm::Value* output_address = GetIrArray(*output, *output, reduce_output_shapes[i]) .EmitArrayElementAddress( - llvm_ir::IrArray::Index( - /*linear=*/ir_builder_.getInt64(0), + IrArray::Index( + /*linear=*/b_.getInt64(0), ShapeUtil::GetSubshape(output->shape(), reduce_output_shapes[i]), - &ir_builder_), - &ir_builder_, "output_element_address"); + &b_), + &b_, "output_element_address"); TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( *reducers[i], output_address, partial_reduction_result_addresses[i])); } @@ -1182,18 +940,14 @@ Status IrEmitterUnnested::EmitReductionToScalar( }; // Emit a parallel loop that iterates through all input tiles, one per thread. - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {num_tiles}, {0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); UpdateLaunchDimensions( launch_dimensions, static_cast(LastThunk())->thunks().back().get(), ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &ir_builder_) - .EmitLoop(IrName(reduce)); + launch_dimensions, &b_) + .EmitLoop(IrName(reduce), index_ty); } Status IrEmitterUnnested::EmitColumnReduction( @@ -1205,8 +959,8 @@ Status IrEmitterUnnested::EmitColumnReduction( tensorflow::gtl::ArraySlice< std::pair> extra_output_gens) { - // Divide the input matrix into tiles of size Kx1. For example, when the - // input matrix is 4x4 and K=2, the tiled matrix looks like + // Divide the input matrix into tiles of size KxL. For example, when the + // input matrix is 4x4, K=2, and L=1 the tiled matrix looks like // // 0123 // 0123 @@ -1218,85 +972,131 @@ Status IrEmitterUnnested::EmitColumnReduction( // // We choose 128 as the tile size based on empirical evidence. It's big enough // to reduce the amount of atomic adds in the end, maximizing the memory - // bandwidth. - constexpr int64 kTileSize = 128; + // bandwidth. A tile width of 2 allows for high memory bandwidth utilization + // on 16b input data. + constexpr int64 kTileHeight = 128; + constexpr int64 kTileWidth = 2; - // If the height is not a multiple of the tile size, we pad the bottom of the + // If the height is not a multiple of kTileHeight, we pad the bottom of the // input matrix. - const int64 height_in_tiles = CeilOfRatio(height, kTileSize); + const int64 height_in_tiles = CeilOfRatio(height, kTileHeight); + // If width is not a multiple of kTileWidth the rightmost thread will process + // fewer input elements. + const int64 width_in_tiles = CeilOfRatio(width, kTileWidth); + Shape tiled_input_shape = + ShapeUtil::MakeShapeWithLayout(reduce->shape().element_type(), + {height_in_tiles, width_in_tiles}, {1, 0}); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + tiled_input_shape, ir_emitter_context_->device_description()); + + // TODO(b/110211620): Convert to use i32 index_type when it is possible. + llvm::Type* index_ty = b_.getInt64Ty(); + + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; - // linear_index < height_in_tiles * width; + // linear_index < height_in_tiles * width_in_tiles; // linear_index += blockDim.x * gridDim.x) { - // y_in_tiles = linear_index / width; - // x = linear_index % width; + // y_in_tiles = linear_index / width_in_tiles; + // x_in_tiles = linear_index % width_in_tiles; // - // partial_result = init_value; - // if (height % kTileSize == 0 || - // y_in_tiles * kTileSize + kTileSize <= height) { - // for (element_id_in_tile : range(kTileSize)) { - // y = y_in_tiles * kTileSize + element_id_in_tile; - // partial_result = Reducer(partial_result, input[y][x]); + // partial_results[kTileWidth] = init_values; + // tile_in_y_bounds = height % kTileHeight == 0 || + // y_in_tiles * kTileHeight + kTileHeight <= height; + // tile_in_x_bounds = width % kTileWidth == 0 || + // x_in_tiles * kTileWidth + kTileWidth <= width; + // // The implementation handles y and x bound checks separately. + // if (tile_in_y_bounds && tile_in_x_bounds) { + // for (y_offset : range(kTileHeight)) { + // y = y_in_tiles * kTileHeight + y_offset; + // for (x_offset : range(kTileWidth)) { + // x = x_in_tiles * kTileWidth + x_offset; + // partial_result = Reducer(partial_result[x_offset], input[y][x]); + // } // } // } else { - // for (element_id_in_tile : range(kTileSize)) { - // y = y_in_tiles * kTileSize + element_id_in_tile; - // if (y < height) { - // partial_result = Reducer(partial_result, input[y][x]); + // for (y_offset : range(kTileHeight)) { + // y = y_in_tiles * kTileHeight + y_offset; + // for (y_offset : range(kTileHeight)) { + // x = x_in_tiles * kTileWidth + x_offset; + // if (y < height && x < width) { + // partial_result = Reducer(partial_result, input[y][x]); + // } // } // } // } - // AtomicReducer(&output[x], partial_result); + // for (x_offset : range(kTileWidth)) { + // AtomicReducer(&output[x + x_offset], partial_result[x_offset]); + // } // } - auto loop_body_emitter = - [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { const int num_reduces = reducers.size(); // Emit the loop body that reduces one tile. llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); std::vector partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](llvm_ir::IrArray::Index({}))); - ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); - partial_reduction_result_addresses.push_back( - partial_reduction_result_address); + for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { + llvm::Value* partial_reduction_result_address = + b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + + llvm::Twine(i * kTileWidth + x_offset)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](IrArray::Index(index_ty))); + b_.CreateStore(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); + } } // Emit an inner for-loop that partially reduces the elements in the given // tile. llvm::Value* y_in_tiles = tile_index[0]; - llvm::Value* x = tile_index[1]; + llvm::Value* x_in_tiles = tile_index[1]; - auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { + y_in_tiles = b_.CreateZExtOrTrunc(y_in_tiles, index_ty); + x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty); + + auto emit_tile_element_loop = [=](bool tile_in_y_bounds, + bool tile_in_x_bounds) -> Status { std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop("element_id_in_tile", - ir_builder_.getInt64(0), - ir_builder_.getInt64(kTileSize), - ir_builder_.getInt64(1), &ir_builder_); + llvm_ir::ForLoop::EmitForLoop( + "element_id_in_tile", index_typed_constant(0), + index_typed_constant(kTileHeight), index_typed_constant(1), &b_); // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &ir_builder_); - llvm::Value* y = ir_builder_.CreateNSWAdd( - ir_builder_.CreateNSWMul(y_in_tiles, ir_builder_.getInt64(kTileSize)), + &b_); + llvm::Value* y = b_.CreateNSWAdd( + b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)), tile_element_loop->GetIndVarValue()); - // Unless we know the tile is entirely in bounds, we have to emit a - // y-in-bounds check before reading from the input. - if (!tile_in_bounds) { + + // Unless we know that y is in bounds, we have to emit a check before + // reading from the input. + if (!tile_in_y_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(y, ir_builder_.getInt64(height)), - "y_in_bounds", &ir_builder_); + b_.CreateICmpULT(y, index_typed_constant(height)), "y_in_bounds", + &b_); // Emit code that reads the input element and accumulates it to // the partial reduction result. - llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); } - llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); - { + for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { + llvm::Value* x = b_.CreateNSWAdd( + b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)), + index_typed_constant(x_offset)); + // Unless we know that x is in bounds, we have to emit a check before + // reading from the input. + if (!tile_in_x_bounds) { + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + b_.CreateICmpULT(x, index_typed_constant(width)), "x_in_bounds", + &b_); + llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); + } + llvm::Value* input_address = b_.CreateAlloca(element_ir_type); // {y,x} is an index to input_matrix_shape [height,width]. We need to // convert that to an index to input_shape (the shape of the operand of // "reduce"). This conversion is composed of a transposition from @@ -1312,84 +1112,108 @@ Status IrEmitterUnnested::EmitColumnReduction( const Shape input_matrix_shape = ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), {height, width}); - const llvm_ir::IrArray::Index input_matrix_index( - {y, x}, input_matrix_shape, &ir_builder_); - const llvm_ir::IrArray::Index input_index = + const IrArray::Index input_matrix_index({y, x}, input_matrix_shape, + &b_); + const IrArray::Index input_index = input_matrix_index .SourceIndexOfReshape(input_matrix_shape, - normalized_input_shape, &ir_builder_) + normalized_input_shape, &b_) .SourceIndexOfTranspose(normalized_input_shape, input_shape, - transpose_dimension_mapping, - &ir_builder_); + transpose_dimension_mapping, &b_); for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); + b_.CreateStore(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], - {partial_reduction_result_addresses[i], input_address}, - partial_reduction_result_addresses[i])); + {partial_reduction_result_addresses[i * kTileWidth + x_offset], + input_address}, + partial_reduction_result_addresses[i * kTileWidth + x_offset])); + TF_RETURN_IF_ERROR(EmitExtraOutputsForReduce(reduce, input_index, + extra_output_gens)); } - return EmitExtraOutputsForReduce(reduce, input_index, - extra_output_gens); } + return Status::OK(); }; - // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's - // immediately beyond the tile. - llvm::Value* y_end = ir_builder_.CreateNSWAdd( - ir_builder_.getInt64(kTileSize), - ir_builder_.CreateNSWMul(y_in_tiles, ir_builder_.getInt64(kTileSize))); - llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.CreateICmpULE(y_end, ir_builder_.getInt64(height)), - ir_builder_.getInt1(height % kTileSize == 0)); - // The tile is entirely in bound if "height" is a multiple of kTileSize or + // y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location + // that's immediately beyond the tile. + llvm::Value* y_end = b_.CreateNSWAdd( + index_typed_constant(kTileHeight), + b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight))); + // x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location + // that's immediately beyond the tile. + llvm::Value* x_end = b_.CreateNSWAdd( + index_typed_constant(kTileWidth), + b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth))); + llvm::Value* tile_in_y_bounds = + b_.CreateOr(b_.CreateICmpULE(y_end, index_typed_constant(height)), + b_.getInt1(height % kTileHeight == 0)); + llvm::Value* tile_in_x_bounds = + b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(width)), + b_.getInt1(width % kTileWidth == 0)); + // The tile is in y bounds if "height" is a multiple of kTileHeight or // y_end <= height. - llvm_ir::LlvmIfData if_tile_in_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, - &ir_builder_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, - &ir_builder_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); - - // After the if-then-else statement on tile_in_bounds, emit atomic - // operations to accumulate the partial reduction result to the output - // element. - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, - &ir_builder_); + llvm_ir::LlvmIfData if_tile_in_y_bounds_data = + llvm_ir::EmitIfThenElse(tile_in_y_bounds, "tile_in_y_bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.true_block, &b_); + // The tile is in x bounds if "width" is a multiple of kTileWidth or + // x_end <= width. + llvm_ir::LlvmIfData if_tile_in_x_bounds_data = + llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true, + /*tile_in_x_bounds=*/true)); + llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true, + /*tile_in_x_bounds=*/false)); + llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.false_block, &b_); + if_tile_in_x_bounds_data = + llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false, + /*tile_in_x_bounds=*/true)); + llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false, + /*tile_in_x_bounds=*/false)); + + // After the nested if-then-else statement on tile_in_y_bounds and + // tile_in_x_bounds, emit atomic operations to accumulate the partial + // reduction result to the output element. + llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.after_block, &b_); const HloInstruction* output = reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; for (int i = 0; i != num_reduces; ++i) { - llvm::Value* output_address = - GetIrArray(*output, *output, reduce_output_shapes[i]) - .EmitArrayElementAddress( - llvm_ir::IrArray::Index( - x, - ShapeUtil::GetSubshape(output->shape(), - reduce_output_shapes[i]), - &ir_builder_), - &ir_builder_, "output_element_address"); - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, partial_reduction_result_addresses[i])); + for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { + llvm::Value* x = b_.CreateNSWAdd( + b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)), + index_typed_constant(x_offset)); + llvm::Value* output_address = + GetIrArray(*output, *output, reduce_output_shapes[i]) + .EmitArrayElementAddress( + IrArray::Index( + x, + ShapeUtil::GetSubshape(output->shape(), + reduce_output_shapes[i]), + &b_), + &b_, "output_element_address"); + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, + partial_reduction_result_addresses[i * kTileWidth + x_offset])); + } } return Status::OK(); }; // Emit a parallel loop that iterate through all input tiles. - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {height_in_tiles, width}, {1, 0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); UpdateLaunchDimensions( launch_dimensions, static_cast(LastThunk())->thunks().back().get(), ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &ir_builder_) - .EmitLoop(IrName(reduce)); + launch_dimensions, &b_) + .EmitLoop(IrName(reduce), index_ty); } static std::pair ComputeTilingSchemeForReduction( @@ -1441,7 +1265,7 @@ Status IrEmitterUnnested::EmitRowReduction( // for (element_id_in_tile : range(x_tile_size)) { // int x = x_in_tiles * x_tile_size + element_id_in_tile; // if (x < width) - // partial_result = reducer(partial_result, input[z][y][z]); + // partial_result = reducer(partial_result, input[z][y][x]); // } // AtomicReducer(&output[y], partial_result); // } @@ -1493,23 +1317,25 @@ Status IrEmitterUnnested::EmitRowReduction( // x + (x_tile_size - 1) * warpSize < width) { // // The entire x_tile is in bounds. // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size; - // ++element_id_in_z_tile) { + // ++element_id_in_z_tile) { // z = z_in_tiles * z_tile_size + element_id_in_z_tile; - // for (int element_id_in_x_tile = 0;element_id_in_x_tile < x_tile_size; - // ++element_id_in_x_tile, x += warpSize) { - // partial_result = Reducer(partial_result, input[z][y][x]); + // int tx = x; + // for (int element_id_in_x_tile = 0; + // element_id_in_x_tile < x_tile_size; + // ++element_id_in_x_tile, tx += warpSize) { + // partial_result = Reducer(partial_result, input[z][y][tx]); // } // } // } else { // // The tile is partially in bounds. // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size; - // ++element_id_in_z_tile) { + // ++element_id_in_z_tile) { // z = z_in_tiles * z_tile_size + element_id_in_z_tile; + // int tx = x; // for (int element_id_in_x_tile = 0; element_id_in_x_tile < - // x_tile_size; - // ++element_id_in_tile, x += warpSize) { - // if (x < width) - // partial_result = Reducer(partial_result, input[z][y][x]); + // x_tile_size; ++element_id_in_tile, tx += warpSize) { + // if (tx < width) + // partial_result = Reducer(partial_result, input[z][y][tx]); // } // } // } @@ -1531,20 +1357,30 @@ Status IrEmitterUnnested::EmitRowReduction( // the use of shfl_down is valid. const int64 width_in_tiles = RoundUpToNearest(CeilOfRatio(width, x_tile_size), kWarpSize); + Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( + reduce->shape().element_type(), + {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0}); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + tiled_input_shape, ir_emitter_context_->device_description()); + llvm::Type* index_ty = + GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_); - auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) { - // Emit the loop body that reduces one z-x-tile. + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + + auto loop_body_emitter = [=](const IrArray::Index& tile_index) { const int num_reduces = reducers.size(); llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( input_shape.element_type(), ir_emitter_context_->llvm_module()); std::vector partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); + llvm::Value* partial_reduction_result_address = + b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](llvm_ir::IrArray::Index({}))); - ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + init_value_gens[i](IrArray::Index(index_ty))); + b_.CreateStore(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } @@ -1552,24 +1388,26 @@ Status IrEmitterUnnested::EmitRowReduction( llvm::Value* z_tile = tile_index[0]; llvm::Value* y = tile_index[1]; llvm::Value* x_tile = tile_index[2]; - llvm::Value* warp_id = ir_builder_.CreateUDiv( - x_tile, ir_builder_.getInt64(kWarpSize), "warp_id"); - llvm::Value* lane_id = ir_builder_.CreateURem( - x_tile, ir_builder_.getInt64(kWarpSize), "lane_id"); + + x_tile = b_.CreateZExtOrTrunc(x_tile, index_ty); + + llvm::Value* warp_id = + b_.CreateUDiv(x_tile, index_typed_constant(kWarpSize), "warp_id"); + llvm::Value* lane_id = + b_.CreateURem(x_tile, index_typed_constant(kWarpSize), "lane_id"); // The x-location of the last element in this z-x-tile. - // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * - // x_tile_size); - llvm::Value* last_x = ir_builder_.CreateNSWAdd( - lane_id, ir_builder_.CreateNSWMul( - ir_builder_.getInt64(kWarpSize), - ir_builder_.CreateNSWAdd( - ir_builder_.getInt64(x_tile_size - 1), - ir_builder_.CreateNSWMul( - warp_id, ir_builder_.getInt64(x_tile_size))))); + // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size); + llvm::Value* last_x = b_.CreateNSWAdd( + lane_id, + b_.CreateNSWMul( + index_typed_constant(kWarpSize), + b_.CreateNSWAdd( + index_typed_constant(x_tile_size - 1), + b_.CreateNSWMul(warp_id, index_typed_constant(x_tile_size))))); KernelSupportLibrary ksl( - &ir_builder_, + &b_, /*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll, /*prevent_vectorization=*/false); @@ -1578,41 +1416,40 @@ Status IrEmitterUnnested::EmitRowReduction( auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, int64 x_tile_loop_bound) -> Status { auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { - llvm::Value* z = ir_builder_.CreateNSWAdd( - z_indvar, ir_builder_.CreateNSWMul( - ir_builder_.getInt64(z_tile_size), z_tile)); - + llvm::Value* z = b_.CreateNSWAdd( + z_indvar, + b_.CreateNSWMul(index_typed_constant(z_tile_size), z_tile)); TF_RETURN_IF_ERROR(ksl.For( "x_tile", - /*start=*/0, /*end=*/x_tile_loop_bound, /*step=*/1, - [&](llvm::Value* x_indvar) -> Status { - // x = lane_id + warpSize * (element_id_in_x_tile + warp_id * - // x_tile_size); - llvm::Value* x = ir_builder_.CreateNSWAdd( + /*start=*/index_typed_constant(0), + /*end=*/index_typed_constant(x_tile_loop_bound), + /*step=*/1, [&](llvm::Value* x_indvar) -> Status { + // x = lane_id + + // warpSize * (element_id_in_x_tile + warp_id * x_tile_size); + llvm::Value* x = b_.CreateNSWAdd( lane_id, - ir_builder_.CreateNSWMul( - ir_builder_.getInt64(kWarpSize), - ir_builder_.CreateNSWAdd( - x_indvar, - ir_builder_.CreateNSWMul( - warp_id, ir_builder_.getInt64(x_tile_size))))); + b_.CreateNSWMul( + index_typed_constant(kWarpSize), + b_.CreateNSWAdd( + x_indvar, b_.CreateNSWMul( + warp_id, llvm::ConstantInt::get( + index_ty, x_tile_size))))); // Unless we know the x-tile is entirely in bounds, we have to // emit a x-in-bounds check before reading from the input. if (!x_tile_in_bounds) { llvm_ir::LlvmIfData if_x_in_bounds_data = - llvm_ir::EmitIfThenElse(ir_builder_.CreateICmpULT( - x, ir_builder_.getInt64(width)), - "x_in_bounds", &ir_builder_); - // Points ir_builder_ to the then-block. + llvm_ir::EmitIfThenElse( + b_.CreateICmpULT(x, index_typed_constant(width)), + "x_in_bounds", &b_); + // Points b_ to the then-block. llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, - &ir_builder_); + &b_); } // Emit code that reads the input element and accumulates it // to the partial reduction result. - llvm::Value* input_address = - ir_builder_.CreateAlloca(element_ir_type); + llvm::Value* input_address = b_.CreateAlloca(element_ir_type); { // {z,y,x} is an index to input_3d_tensor_shape // [depth,height,width]. We need to convert that to an index @@ -1630,21 +1467,20 @@ Status IrEmitterUnnested::EmitRowReduction( const Shape input_3d_tensor_shape = ShapeUtil::MakeShapeWithDescendingLayout( input_shape.element_type(), {depth, height, width}); - const llvm_ir::IrArray::Index input_3d_tensor_index( - {z, y, x}, input_3d_tensor_shape, &ir_builder_); - const llvm_ir::IrArray::Index input_index = + const IrArray::Index input_3d_tensor_index( + {z, y, x}, input_3d_tensor_shape, &b_); + const IrArray::Index input_index = input_3d_tensor_index .SourceIndexOfReshape(input_3d_tensor_shape, - normalized_input_shape, - &ir_builder_) + normalized_input_shape, &b_) .SourceIndexOfTranspose( normalized_input_shape, input_shape, - transpose_dimension_mapping, &ir_builder_); + transpose_dimension_mapping, &b_); for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); + b_.CreateStore(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], input_address}, @@ -1658,13 +1494,14 @@ Status IrEmitterUnnested::EmitRowReduction( }; return ksl.For("z_tile", - /*start=*/0, /*end=*/z_tile_size, /*step=*/1, - emit_z_tile_element_loop); + /*start=*/index_typed_constant(0), + /*end=*/index_typed_constant(z_tile_size), + /*step=*/1, emit_z_tile_element_loop); }; - llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.getInt1(width % (x_tile_size * kWarpSize) == 0), - ir_builder_.CreateICmpULT(last_x, ir_builder_.getInt64(width))); + llvm::Value* tile_in_bounds = + b_.CreateOr(b_.getInt1(width % (x_tile_size * kWarpSize) == 0), + b_.CreateICmpULT(last_x, index_typed_constant(width))); TF_RETURN_IF_ERROR( ksl.If(tile_in_bounds, @@ -1687,23 +1524,25 @@ Status IrEmitterUnnested::EmitRowReduction( // bitcast cannot be applied to aggregate types (even packed ones), so we // instead bitcast addresses of load/store to intN* of the same bit-width. llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() - ? ir_builder_.getIntNTy(bit_width) + ? b_.getIntNTy(bit_width) : element_ir_type; for (int shuffle_distance = 16; shuffle_distance >= 1; shuffle_distance /= 2) { - llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( - element_ir_type, nullptr, "result_from_other_lane"); + llvm::Value* result_from_other_lane = + b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane"); for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), + llvm::Value* partial_reduction_result = b_.CreateLoad( + b_.CreateBitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), "partial_reduction_result"); - ir_builder_.CreateStore( - EmitShuffleDown(partial_reduction_result, - ir_builder_.getInt32(shuffle_distance), - &ir_builder_), - ir_builder_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); + CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) + << "Requires block size a multiple of the warp size, otherwise we " + "will read undefined elements."; + b_.CreateStore( + EmitFullWarpShuffleDown(partial_reduction_result, + b_.getInt32(shuffle_distance), &b_), + b_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], result_from_other_lane}, @@ -1718,48 +1557,43 @@ Status IrEmitterUnnested::EmitRowReduction( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpEQ(lane_id, ir_builder_.getInt64(0)), - "lane_id_is_zero", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, - &ir_builder_); + b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", + &b_); + llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); for (int i = 0; i != num_reduces; ++i) { llvm::Value* output_address = GetIrArray(*output, *output, reduce_output_shapes[i]) .EmitArrayElementAddress( - llvm_ir::IrArray::Index( - y, - ShapeUtil::GetSubshape(output->shape(), - reduce_output_shapes[i]), - &ir_builder_), - &ir_builder_, "output_element_address"); - if (x_tile_size * z_tile_size < depth * width) { - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, - partial_reduction_result_addresses[i])); - } else { + IrArray::Index(y, + ShapeUtil::GetSubshape( + output->shape(), reduce_output_shapes[i]), + &b_), + &b_, "output_element_address"); + // We don't need to emit atomic operations if there is only one tile of + // results. 'depth' is the z dimension, 'width' is the x dimension. + if (z_tile_size >= depth && x_tile_size >= width) { TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {output_address, partial_reduction_result_addresses[i]}, output_address)); + } else { + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, + partial_reduction_result_addresses[i])); } } return Status::OK(); }; // Emit a parallel loop that iterates through every input tiles. - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), - {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); UpdateLaunchDimensions( launch_dimensions, static_cast(LastThunk())->thunks().back().get(), ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &ir_builder_) - .EmitLoop(IrName(reduce)); + launch_dimensions, &b_) + .EmitLoop(IrName(reduce), index_ty); } // Figures out whether `reduce` is a row or column reduction, and which @@ -1865,6 +1699,10 @@ Status IrEmitterUnnested::EmitReductionToVector( } Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { + // TODO(b/112040122): Support multi-output reduce. + if (!ShapeUtil::IsArray(reduce->shape())) { + return Unimplemented("Multi-output reduce is not supported on GPU"); + } auto input = reduce->operand(0); auto init_value = reduce->operand(1); tensorflow::gtl::ArraySlice dimensions_to_reduce(reduce->dimensions()); @@ -1872,39 +1710,41 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { // HandleReduce specializes reduction from a multi-dimensional array to a 1D // array. The specialized version requires an initializer thunk that // initializes the output array to the initial value of the reduce. - if (IsReductionToVector(*reduce) && - // NVPTX backend can't do atomic cmpxchg any narrower than 32 bits - 32 <= primitive_util::BitWidth(reduce->shape().element_type())) { + if (IsReductionToVector(*reduce)) { TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, BuildInitializerThunk(reduce)); std::vector> thunks; thunks.push_back(std::move(initializer_thunk)); - thunks.push_back(BuildKernelThunk(reduce)); + thunks.push_back( + BuildKernelThunk(reduce, /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), reduce)); return EmitReductionToVector( - reduce, input->shape(), {[&](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*input, *reduce) - .EmitReadArrayElement(index, &ir_builder_); + reduce, input->shape(), {[&](const IrArray::Index& index) { + return GetIrArray(*input, *reduce).EmitReadArrayElement(index, &b_); }}, - {[&](const llvm_ir::IrArray::Index& index) { + {[&](const IrArray::Index& index) { return GetIrArray(*init_value, *reduce) - .EmitReadArrayElement(index, &ir_builder_); + .EmitReadArrayElement(index, &b_); }}, dimensions_to_reduce, {reducer}, {{}}, {}); } - thunk_sequence_->emplace_back(BuildKernelThunk(reduce)); + thunk_sequence_->emplace_back( + BuildKernelThunk(reduce, /*implements_whole_instruction=*/true)); return IrEmitter::HandleReduce(reduce); } Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { bool all_tuple_elements_have_buffer = c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { - return ir_emitter_context_->buffer_assignment().HasTopLevelAllocation( - tuple_element); + return ir_emitter_context_->buffer_assignment() + .GetUniqueTopLevelSlice(tuple_element) + .ok(); }); + // TODO(b/111689850): This logic isn't quite correct. + // // Tuples (especially tuples that are the final result of a computation) can // be so huge that if we were to emit a kernel that took each tuple element as // a parameter, we would exceed the max allowable number of parameters to a @@ -1912,9 +1752,9 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { // buffer, we collect their buffer addresses in a host array, and then copy // that array to the tuple's buffer. // - // Some tuple elements (e.g. const or bitcast of const) might not have a - // buffer -- their contents are stored in code. In that case, we fall back to - // emitting kernels which have access to their buffer addresses in code. + // Some tuple elements might not have an unambiguous buffer (like the result + // of a select-tuple). In that case, we fall back to emitting kernels which + // have access to their buffer addresses in code. if (all_tuple_elements_have_buffer) { std::vector tuple_element_buffers; for (const HloInstruction* tuple_element : tuple->operands()) { @@ -1924,7 +1764,8 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { tuple_element_buffers, GetAllocationSlice(*tuple), tuple)); return Status::OK(); } - thunk_sequence_->emplace_back(BuildKernelThunk(tuple)); + thunk_sequence_->emplace_back( + BuildKernelThunk(tuple, /*implements_whole_instruction=*/true)); return IrEmitter::HandleTuple(tuple); } @@ -1949,7 +1790,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( BuildInitializerThunk(select_and_scatter)); std::vector> thunks; thunks.push_back(std::move(initializer_thunk)); - thunks.push_back(BuildKernelThunk(select_and_scatter)); + thunks.push_back(BuildKernelThunk(select_and_scatter, + /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), select_and_scatter)); @@ -1959,6 +1801,14 @@ Status IrEmitterUnnested::HandleSelectAndScatter( "Dilation for SelectAndScatter not implemented on GPU."); } + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + source->shape(), ir_emitter_context_->device_description()); + llvm::Type* index_type = GetIndexTypeForKernel( + select_and_scatter, launch_dimensions.launch_bound(), &b_); + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_type, c); + }; + // kSelectAndScatter is implemented as two kernel launches: the first launch // initializes the output array to the given initial value, // and the second accumulates the "source" matrix to the @@ -1978,114 +1828,106 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // selected_index = I // initialized_flag = true // output(selected_index) = scatter(output(selected_index), source(S)) - auto loop_body_emitter = - [=](const llvm_ir::IrArray::Index& source_index) -> Status { + auto loop_body_emitter = [=](const IrArray::Index& source_index) -> Status { // Allocate space to keep the currently selected value, its index, and a // boolean flag if the value is initialized. The initialized_flag is set // false. llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(operand_element_type, ir_emitter_context_->llvm_module()), - "selected_value_address", &ir_builder_); + "selected_value_address", &b_); llvm::Value* selected_index_address = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - ir_builder_.getInt64Ty(), ir_builder_.getInt32(rank), - "selected_index_address", &ir_builder_); + index_type, index_typed_constant(rank), "selected_index_address", + &b_); llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( - ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_); - ir_builder_.CreateStore(ir_builder_.getInt1(false), - initialized_flag_address); + b_.getInt1Ty(), "initialized_flag_address", &b_); + b_.CreateStore(b_.getInt1(false), initialized_flag_address); // Create the inner loop to iterate over the window. - llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), - &ir_builder_); + llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_, + index_type); std::vector window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); CHECK_GT(dim.size(), 0); } - const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape( + const IrArray::Index window_index = window_loops.AddLoopsForShape( ShapeUtil::MakeShape(operand_element_type, window_size), "window"); llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), - &ir_builder_); + &b_); // Compute the operand index to visit and evaluate the condition whether the // operand index is within the bounds. The unsigned comparison includes // checking whether the operand index >= 0. - llvm_ir::IrArray::Index operand_index(source_index.size()); - llvm::Value* in_bounds_condition = ir_builder_.getInt1(true); + IrArray::Index operand_index(index_type, source_index.size()); + llvm::Value* in_bounds_condition = b_.getInt1(true); for (int64 i = 0; i < rank; ++i) { - llvm::Value* strided_index = ir_builder_.CreateNSWMul( - source_index[i], ir_builder_.getInt64(window.dimensions(i).stride())); - operand_index[i] = ir_builder_.CreateNSWSub( - ir_builder_.CreateNSWAdd(strided_index, window_index[i]), - ir_builder_.getInt64(window.dimensions(i).padding_low())); - llvm::Value* index_condition = ir_builder_.CreateICmpULT( + llvm::Value* strided_index = b_.CreateNSWMul( + source_index[i], index_typed_constant(window.dimensions(i).stride())); + operand_index[i] = b_.CreateNSWSub( + b_.CreateNSWAdd(strided_index, window_index[i]), + index_typed_constant(window.dimensions(i).padding_low())); + llvm::Value* index_condition = b_.CreateICmpULT( operand_index[i], - ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); - in_bounds_condition = - ir_builder_.CreateAnd(in_bounds_condition, index_condition); + index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i))); + in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); } CHECK(in_bounds_condition != nullptr); // Only need to do something if the operand index is within the bounds. // First check if the initialized_flag is set. llvm_ir::LlvmIfData if_in_bounds = - llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &ir_builder_); + llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_); llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse( - ir_builder_.CreateLoad(initialized_flag_address), "initialized", - &ir_builder_); + b_.CreateLoad(initialized_flag_address), "initialized", &b_); // If the initialized_flag is false, initialize the selected value and index // with the currently visiting operand. - llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_); - const auto save_operand_index = [&]( - const llvm_ir::IrArray::Index& operand_index) { + llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &b_); + const auto save_operand_index = [&](const IrArray::Index& operand_index) { for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - ir_builder_.CreateInBoundsGEP(selected_index_address, - {ir_builder_.getInt32(i)}); - ir_builder_.CreateStore(operand_index[i], selected_index_address_slot); + b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); + b_.CreateStore(operand_index[i], selected_index_address_slot); } }; - llvm_ir::IrArray operand_array = GetIrArray(*operand, *select_and_scatter); + IrArray operand_array = GetIrArray(*operand, *select_and_scatter); llvm::Value* operand_data = - operand_array.EmitReadArrayElement(operand_index, &ir_builder_); - ir_builder_.CreateStore(operand_data, selected_value_address); + operand_array.EmitReadArrayElement(operand_index, &b_); + b_.CreateStore(operand_data, selected_value_address); save_operand_index(operand_index); - ir_builder_.CreateStore(ir_builder_.getInt1(true), - initialized_flag_address); + b_.CreateStore(b_.getInt1(true), initialized_flag_address); // If the initialized_flag is true, call the `select` function to // potentially update the selected value and index with the currently // visiting operand. - llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_); const Shape output_shape = ShapeUtil::MakeShape(PRED, {}); llvm::Value* operand_address = - operand_array.EmitArrayElementAddress(operand_index, &ir_builder_); + operand_array.EmitArrayElementAddress(operand_index, &b_); llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(PRED, ir_emitter_context_->llvm_module()), - "select_return_buffer", &ir_builder_); + "select_return_buffer", &b_); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *select_and_scatter->select(), {selected_value_address, operand_address}, select_return_buffer)); - llvm::Value* result = ir_builder_.CreateLoad(select_return_buffer); + llvm::Value* result = b_.CreateLoad(select_return_buffer); // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. - llvm::Value* cond = ir_builder_.CreateICmpNE( + llvm::Value* cond = b_.CreateICmpNE( result, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType( PRED, ir_emitter_context_->llvm_module()), 0), "boolean_predicate"); llvm_ir::LlvmIfData if_select_lhs = - llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &ir_builder_); - ir_builder_.CreateStore(ir_builder_.CreateLoad(operand_address), - selected_value_address); + llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_); + llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_); + b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address); save_operand_index(operand_index); // After iterating over the window elements, scatter the source element to @@ -2093,27 +1935,24 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // location is computed by calling the `scatter` function with the source // value and the current output value. llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), - &ir_builder_); - llvm_ir::IrArray::Index selected_index; + &b_); + IrArray::Index selected_index(operand_index.GetType()); for (int64 i = 0; i < rank; ++i) { - llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP( - selected_index_address, {ir_builder_.getInt32(i)}); - selected_index.push_back( - ir_builder_.CreateLoad(selected_index_address_slot)); + llvm::Value* selected_index_address_slot = + b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); + selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); } llvm::Value* source_value_address = GetIrArray(*source, *select_and_scatter) - .EmitArrayElementAddress(source_index, &ir_builder_); + .EmitArrayElementAddress(source_index, &b_); llvm::Value* output_value_address = GetIrArray(*select_and_scatter, *select_and_scatter) - .EmitArrayElementAddress(selected_index, &ir_builder_); + .EmitArrayElementAddress(selected_index, &b_); return EmitAtomicOperationForNestedComputation( *select_and_scatter->scatter(), output_value_address, source_value_address); }; - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - source->shape(), ir_emitter_context_->device_description()); UpdateLaunchDimensions( launch_dimensions, // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk @@ -2123,8 +1962,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( static_cast(LastThunk())->thunks().back().get(), ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, source->shape(), - launch_dimensions, &ir_builder_) - .EmitLoop(IrName(select_and_scatter)); + launch_dimensions, &b_) + .EmitLoop(IrName(select_and_scatter), index_type); } Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { @@ -2133,33 +1972,147 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { condition->root_instruction()->shape().element_type() == PRED) << "While condition computation must return bool"; // Build ForThunk for conformant while loops, otherwise build WhileThunk. - auto result = CanTransformWhileToFor(xla_while); - if (result.ok()) { - auto tuple = result.ConsumeValueOrDie(); - // loop_trip_count = (limit - start + increment - 1) / increment - const int64 loop_trip_count = - (std::get<1>(tuple) - std::get<0>(tuple) + std::get<2>(tuple) - 1) / - std::get<2>(tuple); - thunk_sequence_->emplace_back(BuildForThunk(xla_while, loop_trip_count)); + // TODO(b/112163966): Move trip count computation earlier in the pipeline. + if (auto loop_trip_count = ComputeWhileLoopTripCount(xla_while)) { + thunk_sequence_->emplace_back(BuildForThunk(xla_while, *loop_trip_count)); VLOG(3) << "Built ForThunk for while: " << xla_while->name(); } else { thunk_sequence_->emplace_back(BuildWhileThunk(xla_while)); - VLOG(3) << "Built WhileThunk for while: " << xla_while->name() - << " while-to-for transform status: " << result.status(); + VLOG(3) << "Built WhileThunk for while: " << xla_while->name(); } return Status::OK(); } -Status IrEmitterUnnested::HandleRng(HloInstruction* random) { - thunk_sequence_->push_back(BuildKernelThunk(random)); - return IrEmitter::HandleRng(random); +Status IrEmitterUnnested::HandleRng(HloInstruction* rng) { + // Build the kernel to generate the random numbers. + // + // Unroll the kernel so that the duplicated computation that calculates the + // 128 bit sample can be optimized away by LLVM. + thunk_sequence_->emplace_back( + BuildKernelThunk(rng, /*implements_whole_instruction=*/false, + ComputeMaxUnrollFactor(rng))); + ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; + for (const HloInstruction* operand : rng->operands()) { + operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*operand, *rng).EmitReadArrayElement(index, &b_); + }; + } + TF_RETURN_IF_ERROR(EmitTargetElementLoop( + *rng, GpuElementalIrEmitter(hlo_module_config_, module_, &b_, + GetNestedComputer()) + .MakeElementGenerator(rng, operand_to_generator))); + std::unique_ptr rng_thunk = std::move(thunk_sequence_->back()); + thunk_sequence_->pop_back(); + + // Emit a kernel to increment the global state for Philox RNG algorithm. + thunk_sequence_->emplace_back( + BuildKernelThunk(rng, /*implements_whole_instruction=*/false)); + llvm_ir::IncrementVariableForPhiloxRngState(1, module_, &b_); + std::unique_ptr increment_seed_thunk = + std::move(thunk_sequence_->back()); + thunk_sequence_->pop_back(); + + // Build the SequentialThunk for the RNG hlo. + std::vector> thunks; + thunks.reserve(2); + thunks.push_back(std::move(rng_thunk)); + thunks.push_back(std::move(increment_seed_thunk)); + thunk_sequence_->emplace_back( + MakeUnique(std::move(thunks), rng)); + + return Status::OK(); } Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { - thunk_sequence_->push_back(BuildKernelThunk(select)); + thunk_sequence_->push_back( + BuildKernelThunk(select, /*implements_whole_instruction=*/true)); return IrEmitter::HandleSelect(select); } +Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { + std::vector> thunks; + auto keys = sort->operand(0); + auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr; + ShapeIndex keys_shape_index({}); + ShapeIndex values_shape_index({}); + if (values != nullptr) { + keys_shape_index = ShapeIndex({0}); + values_shape_index = ShapeIndex({1}); + } + auto keys_destination = GetAllocationSlice(*sort, keys_shape_index); + auto values_destination = GetAllocationSlice(*sort, values_shape_index); + + if (keys_destination != GetAllocationSlice(*keys)) { + thunks.push_back(MakeUnique( + /*source_address=*/GetAllocationSlice(*keys), + /*destination_buffer=*/keys_destination, + /*mem_size=*/ShapeUtil::ByteSizeOf(keys->shape()), nullptr)); + } + if (values != nullptr && values_destination != GetAllocationSlice(*values)) { + // TODO(b/26783907): Figure out why we never seem to share buffers for + // key/value sort. + thunks.push_back(MakeUnique( + /*source_address=*/GetAllocationSlice(*values), + /*destination_buffer=*/values_destination, + /*mem_size=*/ShapeUtil::ByteSizeOf(values->shape()), nullptr)); + } + + int64 dimension_to_sort = sort->dimensions(0); + int64 dimension_to_sort_bound = keys->shape().dimensions(dimension_to_sort); + int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); + auto index_type = b_.getInt64Ty(); + + // Naive C++ code for the outer loops: + // + // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound); + // ++stage) { + // int64 first_xor_mask = (1LL << (stage + 1)) - 1; + // SortInPlace(first_xor_mask); + // for (int64 mask = stage - 1; mask >= 0; --mask) { + // int64 later_xor_mask = 1LL << mask; + // SortInPlace(later_xor_mask); + // } + // } + // + // This follows the algorithm described on Wikipedia: + // https://en.wikipedia.org/wiki/Bitonic_sorter + + for (int64 stage = 0; stage < num_stages; ++stage) { + for (int64 mask = stage; mask >= 0; --mask) { + thunks.push_back( + BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + keys->shape(), ir_emitter_context_->device_description()); + UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), + ir_emitter_context_->llvm_module()); + + llvm::Value* xor_mask; + if (mask == stage) { + xor_mask = llvm::ConstantInt::get(index_type, (1LL << (stage + 1)) - 1); + } else { + xor_mask = llvm::ConstantInt::get(index_type, 1LL << mask); + } + + TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace( + dimension_to_sort, GetIrArray(*sort, *sort, keys_shape_index), + values != nullptr ? tensorflow::gtl::make_optional( + GetIrArray(*sort, *sort, values_shape_index)) + : tensorflow::gtl::nullopt, + IrName(sort), xor_mask, &b_, &launch_dimensions)); + } + } + + thunk_sequence_->emplace_back( + MakeUnique(std::move(thunks), sort)); + return Status::OK(); +} + +Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) { + thunk_sequence_->push_back( + BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true)); + return IrEmitter::HandleTupleSelect(tuple_select); +} + Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { if (hlo_module_config_.replica_count() != 1) { // TODO(b/33011107): Support nontrivial cross replica sum on GPU. @@ -2195,22 +2148,31 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { thunks.push_back(MakeUnique( /*source_address=*/GetAllocationSlice(*crs->operand(i)), /*destination_buffer=*/tuple_element_buffers.back(), - /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), crs)); + /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr)); } // Output a tuple of the buffers above. thunks.push_back(MakeUnique(tuple_element_buffers, - GetAllocationSlice(*crs), crs)); + GetAllocationSlice(*crs), nullptr)); thunk_sequence_->push_back( MakeUnique(std::move(thunks), crs)); return Status::OK(); } +Status IrEmitterUnnested::HandleAfterAll(HloInstruction* gen_token) { + return Status::OK(); +} + Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) { thunk_sequence_->emplace_back(BuildInfeedThunk(infeed)); return Status::OK(); } +Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) { + thunk_sequence_->emplace_back(BuildOutfeedThunk(outfeed)); + return Status::OK(); +} + // Figures out how to access the buffers for all subshapes of hlo's operands and // for hlo itself (i.e. all the buffers produced by HLO). // @@ -2298,11 +2260,6 @@ GetHloBufferSlices(const HloInstruction* hlo, // Adds entries for all subshapes of instr to `slices`. auto add_slices_for = [&](const HloInstruction* instr) { - // GPU constants don't have buffers; don't bother looking for one. - if (instr->IsConstant()) { - return; - } - ShapeUtil::ForEachSubshape( instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) { if (slices.count({instr, index})) { @@ -2329,13 +2286,9 @@ GetHloBufferSlices(const HloInstruction* hlo, return slices; } -Status IrEmitterUnnested::HandleGather(HloInstruction* gather) { - // TODO(b/72710576): Gather is not implemented on GPUs - return Unimplemented("Gather is not implemented on GPUs."); -} - std::unique_ptr IrEmitterUnnested::BuildKernelThunk( - const HloInstruction* inst, int unroll_factor) { + const HloInstruction* inst, bool implements_whole_instruction, + int unroll_factor) { const BufferAssignment& buffer_assn = ir_emitter_context_->buffer_assignment(); @@ -2368,21 +2321,25 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( // We'll pass a pointer to each of the elements of `buffers` to our kernel, in // this order. - std::vector buffers(buffers_needed.begin(), - buffers_needed.end()); - std::sort(buffers.begin(), buffers.end(), + std::vector non_constant_buffers; + c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers), + [](const BufferAllocation* allocation) { + return !allocation->is_constant(); + }); + + std::sort(non_constant_buffers.begin(), non_constant_buffers.end(), [](const BufferAllocation* a, const BufferAllocation* b) { return a->index() < b->index(); }); - llvm::Function* kernel = BuildKernelPrototype(*inst, buffers); + llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers); // Build a map from a BufferAllocation to the corresponding argument in our // kernel. std::unordered_map kernel_args; { auto arg_it = kernel->arg_begin(); - auto buffers_it = buffers.begin(); + auto buffers_it = non_constant_buffers.begin(); for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) { kernel_args[*buffers_it] = arg_it; } @@ -2400,18 +2357,24 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( << " is found in slice " << slice.ToString() << " at GTE index " << gte_index.ToString(); - llvm::Value* loc = - ir_builder_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), - {ir_builder_.getInt64(slice.offset())}); + llvm::Value* loc; + if (slice.allocation()->is_constant()) { + loc = ir_emitter_context_->llvm_module()->getGlobalVariable( + llvm_ir::AsStringRef(llvm_ir::ConstantBufferAllocationToGlobalName( + *slice.allocation()))); + CHECK_NE(loc, nullptr); + } else { + loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), + {b_.getInt64(slice.offset())}); + } // If gte_index is nonempty, we have to dereference `loc` to get to the // value we're ultimately interested in. llvm::Type* int8_double_pointer = - llvm::PointerType::get(ir_builder_.getInt8PtrTy(), /*AddressSpace=*/0); + llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0); for (int64 idx : gte_index) { - loc = ir_builder_.CreateBitCast(loc, int8_double_pointer); - loc = ir_builder_.CreateLoad( - ir_builder_.CreateInBoundsGEP(loc, {ir_builder_.getInt64(idx)})); + loc = b_.CreateBitCast(loc, int8_double_pointer); + loc = b_.CreateLoad(b_.CreateInBoundsGEP(loc, {b_.getInt64(idx)})); } bindings_.BindHloToIrValue(*instr, loc, index); @@ -2423,11 +2386,12 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( bindings_.SetTempBufferBase(kernel_args.at(*temp_buffer)); } else { bindings_.SetTempBufferBase( - llvm::ConstantPointerNull::get(ir_builder_.getInt8PtrTy())); + llvm::ConstantPointerNull::get(b_.getInt8PtrTy())); } - return MakeUnique(buffers, llvm_ir::AsString(kernel->getName()), - inst, unroll_factor); + return MakeUnique( + non_constant_buffers, llvm_ir::AsString(kernel->getName()), + implements_whole_instruction ? inst : nullptr, unroll_factor); } std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( @@ -2459,17 +2423,31 @@ std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( const HloInstruction* inst) { CHECK_EQ(HloOpcode::kInfeed, inst->opcode()); - std::vector tuple_element_buffers; - for (int64 i = 0; i < inst->shape().tuple_shapes_size(); ++i) { - BufferAllocation::Slice buffer = ir_emitter_context_->buffer_assignment() - .GetUniqueSlice(inst, {i}) - .ConsumeValueOrDie(); - tuple_element_buffers.push_back(buffer); - } + ShapeTree slices(inst->shape()); + slices.ForEachMutableElement( + [&](const ShapeIndex& index, BufferAllocation::Slice* slice) { + *slice = ir_emitter_context_->buffer_assignment() + .GetUniqueSlice(inst, index) + .ConsumeValueOrDie(); + }); + return MakeUnique(slices, inst); +} - return MakeUnique( - tuple_element_buffers, - /*destination_buffer=*/GetAllocationSlice(*inst), inst); +std::unique_ptr IrEmitterUnnested::BuildOutfeedThunk( + const HloInstruction* inst) { + CHECK_EQ(HloOpcode::kOutfeed, inst->opcode()); + + ShapeTree slices(inst->operand(0)->shape()); + slices.ForEachMutableElement( + [&](const ShapeIndex& index, BufferAllocation::Slice* slice) { + auto status_or_slice = + ir_emitter_context_->buffer_assignment().GetUniqueSlice( + inst->operand(0), index); + if (status_or_slice.ok()) { + *slice = status_or_slice.ConsumeValueOrDie(); + } + }); + return MakeUnique(std::move(slices), inst); } namespace { @@ -2587,6 +2565,11 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( init_value = hlo->operand(init_value->parameter_number()); } + // Initializer thunks don't implement a whole instruction, and we want to + // profile the whole instruction instead of the individual thunks it consists + // of. Therefore we pass nullptr as the HloInstruction* to the thunks we + // generate below. + // // In the common case, the initializer is a constant. In this case, emit a // device-memset call if we can. Currently StreamExecutor only supports // zeroing and 32-bit memsets. @@ -2600,24 +2583,26 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( ArraySlice literal_bytes( reinterpret_cast(literal.untyped_data()), num_bytes); if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { - return {MakeUnique(GetAllocationSlice(*hlo, index), hlo)}; + return { + MakeUnique(GetAllocationSlice(*hlo, index), nullptr)}; } // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by // repeating the literal 4 or 2 times, so long as the destination buffer is // an even multiple of 32 bits long. + const Shape& output_shape = ShapeUtil::GetSubshape(hlo->shape(), index); if ((num_bytes == 1 || num_bytes == 2) && - ShapeUtil::ByteSizeOf(hlo->shape()) % 4 == 0) { + ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) { uint16 pattern16; if (num_bytes == 1) { uint8 b = literal_bytes.front(); pattern16 = uint16{b} | (uint16{b} << 8); } else { - pattern16 = literal_bytes.front(); + memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16)); } uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); return {MakeUnique( - pattern32, GetAllocationSlice(*hlo, index), hlo)}; + pattern32, GetAllocationSlice(*hlo, index), nullptr)}; } // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit @@ -2628,12 +2613,13 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( uint32 word; memcpy(&word, literal_bytes.data(), sizeof(word)); return {MakeUnique( - word, GetAllocationSlice(*hlo, index), hlo)}; + word, GetAllocationSlice(*hlo, index), nullptr)}; } } // Otherwise fall back to our slow initializer code. - std::unique_ptr kernel_thunk = BuildKernelThunk(hlo); + std::unique_ptr kernel_thunk = + BuildKernelThunk(hlo, /*implements_whole_instruction=*/false); LaunchDimensions launch_dimensions = CalculateLaunchDimensions(ShapeUtil::GetSubshape(hlo->shape(), index), ir_emitter_context_->device_description()); @@ -2642,15 +2628,24 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( // If the init_value was fused into this reduce we have to generate it first. if (fused && init_value_operand->opcode() != HloOpcode::kParameter) { CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode()); - TF_RETURN_IF_ERROR(HandleConstant(const_cast(init_value))); + + const Literal& literal = init_value_operand->literal(); + llvm::Constant* initializer = + llvm_ir::ConvertLiteralToIrConstant(literal, module_); + + llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( + *module_, initializer->getType(), + /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, + /*Name=*/""); + global_for_const->setAlignment(kConstantBufferAlignBytes); + bindings_.BindHloToIrValue(*init_value_operand, global_for_const); } TF_RETURN_IF_ERROR(ParallelLoopEmitter( - [=](const llvm_ir::IrArray::Index& index) { + [=](const IrArray::Index& index) { return GetIrArray(*init_value, *hlo) - .EmitReadArrayElement(index, &ir_builder_); + .EmitReadArrayElement(index, &b_); }, - GetIrArray(*hlo, *hlo, index), launch_dimensions, - &ir_builder_) + GetIrArray(*hlo, *hlo, index), launch_dimensions, &b_) .EmitLoop(IrName(hlo))); // Clean up state left behind by emitting the loop above. (This is normally @@ -2761,13 +2756,13 @@ std::unique_ptr IrEmitterUnnested::BuildWhileThunk( HloComputation* condition = hlo->while_condition(); IrEmitterUnnested ir_emitter_condition(hlo_module_config_, condition, ir_emitter_context_); - TF_CHECK_OK(condition->root_instruction()->Accept(&ir_emitter_condition)); + TF_CHECK_OK(condition->Accept(&ir_emitter_condition)); // Generate thunk sequence for while 'body'. HloComputation* body = hlo->while_body(); IrEmitterUnnested ir_emitter_body(hlo_module_config_, body, ir_emitter_context_); - TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body)); + TF_CHECK_OK(body->Accept(&ir_emitter_body)); return MakeUnique( GetAllocationSlice(*condition->root_instruction()), // cond result @@ -2785,7 +2780,7 @@ std::unique_ptr IrEmitterUnnested::BuildForThunk( HloComputation* body = hlo->while_body(); IrEmitterUnnested ir_emitter_body(hlo_module_config_, body, ir_emitter_context_); - TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body)); + TF_CHECK_OK(body->Accept(&ir_emitter_body)); return MakeUnique(loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo); @@ -2801,12 +2796,12 @@ std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( HloComputation* true_computation = hlo->true_computation(); IrEmitterUnnested ir_emitter_true(hlo_module_config_, true_computation, ir_emitter_context_); - TF_CHECK_OK(true_computation->root_instruction()->Accept(&ir_emitter_true)); + TF_CHECK_OK(true_computation->Accept(&ir_emitter_true)); HloComputation* false_computation = hlo->false_computation(); IrEmitterUnnested ir_emitter_false(hlo_module_config_, false_computation, ir_emitter_context_); - TF_CHECK_OK(false_computation->root_instruction()->Accept(&ir_emitter_false)); + TF_CHECK_OK(false_computation->Accept(&ir_emitter_false)); return MakeUnique( GetAllocationSlice(*hlo->operand(0)), @@ -2834,37 +2829,588 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( ir_emitter_context_->llvm_module()); if (!hlo.IsMultiOutputFusion()) { return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo), - launch_dimensions, &ir_builder_, unroll_factor) - .EmitLoop(IrName(&hlo)); + launch_dimensions, &b_, unroll_factor) + .EmitLoop( + IrName(&hlo), + GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), &b_)); } - // For multiple outputs fusion, we need to emit each operand and the root. - std::vector output_arrays; + // For multioutput fusion, we need to emit each operand and the root. + std::vector output_arrays; for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { output_arrays.push_back(GetIrArray(hlo, hlo, {i})); } - TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays, - launch_dimensions, &ir_builder_, - unroll_factor) - .EmitLoop(IrName(&hlo))); + TF_RETURN_IF_ERROR( + ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, + &b_, unroll_factor) + .EmitLoop(IrName(&hlo), + GetIndexTypeForKernel( + &hlo, launch_dimensions.launch_bound(), &b_))); std::vector tuple_operand_ptrs; for (int64 i = 0; i < output_arrays.size(); ++i) { tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } - ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_, - module_); + b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_); return Status::OK(); } Status IrEmitterUnnested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) { - CHECK(Thunk::Kind::kKernel == LastThunk()->kind()); + CHECK_EQ(Thunk::Kind::kKernel, LastThunk()->kind()); return EmitTargetElementLoopInThunk(hlo, element_generator, static_cast(LastThunk())); } +int IrEmitterUnnested::ConstructIrArrayForOutputs( + const HloInstruction& hlo, std::vector* output_arrays) { + int64 num_outputs = 1; + if (hlo.IsMultiOutputFusion()) { + num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); + output_arrays->reserve(num_outputs); + for (int64 i = 0; i < num_outputs; ++i) { + output_arrays->push_back(GetIrArray(hlo, hlo, {i})); + } + } else { + output_arrays->push_back(GetIrArray(hlo, hlo)); + } + return num_outputs; +} + +int IrEmitterUnnested::ConstructIrArrayForInputs( + const HloInstruction& hlo, std::vector* param_arrays) { + int64 num_params = hlo.operands().size(); + param_arrays->reserve(num_params); + for (const HloInstruction* param : hlo.operands()) { + param_arrays->push_back(GetIrArray(*param, hlo)); + } + return num_params; +} + +int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape( + const HloInstruction& hlo, const std::vector& output_arrays, + tensorflow::gtl::ArraySlice reduced_output_dims, + std::vector* output_reduced_shapes, + std::vector* output_in_reduced_shape_arrays) { + int64 num_outputs = 1; + if (hlo.IsMultiOutputFusion()) { + num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); + output_in_reduced_shape_arrays->reserve(num_outputs); + output_reduced_shapes->reserve(num_outputs); + for (int64 i = 0; i < num_outputs; ++i) { + output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout( + ShapeUtil::GetSubshape(hlo.shape(), {i}).element_type(), + reduced_output_dims)); + output_in_reduced_shape_arrays->push_back( + output_arrays[i].CastToShape((*output_reduced_shapes)[i], &b_)); + } + } else { + output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout( + hlo.shape().element_type(), reduced_output_dims)); + output_in_reduced_shape_arrays->push_back( + output_arrays[0].CastToShape((*output_reduced_shapes)[0], &b_)); + } + return num_outputs; +} + +int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( + const HloInstruction& hlo, const std::vector& param_arrays, + const std::vector& param_buffers, + tensorflow::gtl::ArraySlice reduced_output_dims, + std::vector* param_reduced_shapes, + std::vector* param_in_reduced_shape_arrays) { + int64 num_params = hlo.operands().size(); + param_in_reduced_shape_arrays->reserve(num_params); + param_reduced_shapes->reserve(num_params); + for (int64 id = 0; id < num_params; ++id) { + if (param_buffers[id] == nullptr) { + param_reduced_shapes->push_back(Shape()); + param_in_reduced_shape_arrays->push_back(IrArray()); + continue; + } + const HloInstruction* param = hlo.operand(id); + param_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout( + param->shape().element_type(), + Permute({0, 2, 1}, reduced_output_dims))); + param_in_reduced_shape_arrays->push_back( + param_arrays[id].CastToShape((*param_reduced_shapes)[id], &b_)); + } + return num_params; +} + +namespace { + +// Reads thread_idx.x and converts it to a (y,x) coordinate, assuming that the +// thread lives within a square tile of size tile_size (so thread blocks are of +// size tile_size * tile_size). +std::tuple CalculateYXCoordinateWithinTile( + llvm::IRBuilder<>* builder, llvm::Value* tile_size, + int64 threads_per_tile) { + // Calculate the starting element coordinate within a tile for the current + // thread, (y, x) from thread_id. + llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder); + llvm_ir::AddRangeMetadata(0, threads_per_tile, + llvm::cast(thread_id)); + thread_id = builder->CreateIntCast(thread_id, tile_size->getType(), + /*isSigned=*/true, "thread.id.x"); + auto x = builder->CreateURem(thread_id, tile_size); + auto y = builder->CreateUDiv(thread_id, tile_size); + return std::make_tuple(y, x); +} + +// Reads block_idx.x, casts it to type index_ty, and adds the assumption that +// it's in the range [0, num_blocks]. +llvm::Value* GetBlockIdx(llvm::IRBuilder<>* builder, llvm::Type* index_ty, + int64 num_blocks) { + llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, builder); + llvm_ir::AddRangeMetadata(0, num_blocks, + llvm::cast(block_id)); + return builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true, + "block.id.x"); +} + +// Emits code to process up to (tile_size/num_rows) elements in a tile, given +// `emit_elem_function` is the function to emit code to process one element, `y` +// and `x` are the coordinates for the first element to process, and `index` is +// the index for the origin of the tile. Emits bounds check to ensure that each +// processed element is within the boundary defined by `tile_width` and +// `tile_height`. +void EmitTiledElementalCodeWithBoundsCheck( + int64 tile_size, int64 num_rows, const IrArray::Index& index, + const string& loop_name, KernelSupportLibrary* ksl, + llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, + llvm::Value* tile_width, llvm::Value* tile_height, + const std::function& + emit_elem_function) { + llvm::Type* index_ty = tile_width->getType(); + // Emits a constant value with index type. + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + // Adds `addend` to the given `dim` of `index`. + auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { + index[dim] = builder->CreateAdd(index[dim], addend); + return index; + }; + + auto emit_full_tile = [&] { + for (int64 i = 0; i < tile_size; i += num_rows) { + auto source_idx = offset_dim(index, index_typed_constant(i), /*dim=*/1); + auto y_loc = builder->CreateAdd(index_typed_constant(i), y); + emit_elem_function(source_idx, y_loc); + } + }; + + auto emit_last_row = [&] { + ksl->IfReturnVoid("x_in_tile", builder->CreateICmpULT(x, tile_width), [&] { + // tile_height_upper_bound = + // ceil(tile_height / num_rows) * num_rows + auto tile_height_upper_bound = builder->CreateMul( + builder->CreateUDiv( + builder->CreateAdd(tile_height, + index_typed_constant(num_rows - 1)), + index_typed_constant(num_rows)), + index_typed_constant(num_rows)); + ksl->ForReturnVoid( + loop_name, /*start=*/index_typed_constant(0), + /*end=*/tile_height_upper_bound, + /*step=*/index_typed_constant(num_rows), [&](llvm::Value* y_indvar) { + auto y_loc = builder->CreateAdd(y_indvar, y); + ksl->IfReturnVoid( + "y_in_tile", builder->CreateICmpULT(y_loc, tile_height), [&] { + emit_elem_function(offset_dim(index, y_indvar, /*dim=*/1), + y_loc); + }); + }); + }); + }; + ksl->IfReturnVoid( + "full_tile", + builder->CreateAnd( + builder->CreateICmpEQ(index_typed_constant(tile_size), tile_width), + builder->CreateICmpEQ(index_typed_constant(tile_size), tile_height)), + emit_full_tile, emit_last_row); +} +} // namespace + +// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose +// algorithm to improve the memory access patterns for the input parameters +// which have a shape that is a 0-2-1 transpose of the output tensors. +// +// For the purpose of tiling, the output tensors have a logical shape of three +// components 0-2-1 while the relevant input parameters have a logical shape of +// three components 0-1-2 in the order major to minor. The x- and y- dimensions +// of the tensors are tiled in square tiles of edge length `kTileSize`. Each +// thread block of `kTileSize` x `kNumRows` threads transposes one tile: each +// thread copies kTileSize/kNumRows elements from the input to a shared memory +// tile, then the otherwise "regular hlo kernel" reads from the shared memory +// instead of the original input. +// +// This is similar to the following CUDA algorithm in TensorFlow: +// https://goo.gl/MStRV6. +// +// `kTileSize` should usually be same as warp size. We currently choose 32 for +// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`. +// +// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient +// to launch fewer blocks so each transposes many tiles. +LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( + HloInstruction* hlo, tensorflow::gtl::ArraySlice reduced_output_dims, + tensorflow::gtl::ArraySlice tiled_param_ids) { + // Parameters for the tiling algorithm. + constexpr int64 kTileSize = 32; + constexpr int64 kNumRows = 4; + constexpr int64 kThreadsPerTile = kTileSize * kNumRows; + + // Construct IrArrays for the inputs and outputs. + std::vector output_arrays; + int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays); + std::vector param_arrays; + int64 num_params = ConstructIrArrayForInputs(*hlo, ¶m_arrays); + + // Allocate shared memory buffers to store the tiled inputs. + std::vector param_shmem_buffers(num_params, nullptr); + for (int64 id : tiled_param_ids) { + const HloInstruction* param = hlo->operand(id); + // Add 1 to the minor dimension to reduce shared memory bank conflicts. + llvm::Type* tile_type = llvm::ArrayType::get( + llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType( + param->shape().element_type(), module_), + kTileSize + 1), + kTileSize); + const int kNVPTXSharedMemoryAddrSpace = 3; + auto* tile_base_ptr = new llvm::GlobalVariable( + *b_.GetInsertBlock()->getParent()->getParent(), tile_type, + /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, + llvm::UndefValue::get(tile_type), + llvm_ir::AsStringRef(IrName(hlo, StrCat("tile", id))), nullptr, + llvm::GlobalValue::NotThreadLocal, kNVPTXSharedMemoryAddrSpace); + param_shmem_buffers[id] = tile_base_ptr; + VLOG(3) << "Added shmem buffer for parameter " << id << ": " + << llvm_ir::DumpToString(*tile_base_ptr); + } + + // The 0-2-1 shape of the tiling scheme is the reduced shape of the HLO result + // for the purpose of tiling. Calculate the logical output dimensions in the + // tile from the reduced output dimensions. + std::vector output_dims_in_tiles = std::vector( + reduced_output_dims.begin(), reduced_output_dims.end()); + CHECK_EQ(output_dims_in_tiles.size(), 3); + for (int i = 1; i < 3; ++i) { + output_dims_in_tiles[i] = + CeilOfRatio(output_dims_in_tiles[i], kTileSize); + } + const int64 num_tiles = + c_accumulate(output_dims_in_tiles, 1, std::multiplies()); + LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile); + + llvm::Type* index_ty = + GetIndexTypeForKernel(hlo, launch_dimensions.launch_bound(), &b_); + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + + // Cast each output IrArray to its corresponding reduced shape and keep the + // reduced shape live during IR emission. + std::vector output_in_reduced_shape_arrays; + std::vector output_reduced_shapes; + CHECK_EQ(ConstructOutputReducedShapeAndCastOutputIrArrayToShape( + *hlo, output_arrays, reduced_output_dims, &output_reduced_shapes, + &output_in_reduced_shape_arrays), + num_outputs); + + // For each tiled parameter, cast its input IrArray to the corresponding + // reduced shape and keep the reduced shape live during IR emission. + std::vector param_in_reduced_shape_arrays; + std::vector param_reduced_shapes; + CHECK_EQ(ConstructInputReducedShapeAndCastInputIrArrayToShape( + *hlo, param_arrays, param_shmem_buffers, reduced_output_dims, + ¶m_reduced_shapes, ¶m_in_reduced_shape_arrays), + num_params); + + // Calculate the starting element coordinate within a tile for the current + // thread, (y, x) from thread_id. + llvm::Value* x; + llvm::Value* y; + std::tie(y, x) = CalculateYXCoordinateWithinTile( + &b_, index_typed_constant(kTileSize), kThreadsPerTile); + + // Calculate the index for the current output tile from block_id. + const IrArray::Index output_tile_index( + GetBlockIdx(&b_, index_ty, num_tiles), + ShapeUtil::MakeShapeWithDescendingLayout(PRED /*arbitrary*/, + output_dims_in_tiles), + &b_); + + // Output tile origin is the index for the first element of the current output + // tile. + const IrArray::Index output_tile_origin = [&] { + IrArray::Index index = output_tile_index; + for (int i = 1; i < 3; ++i) { + index[i] = + b_.CreateMul(output_tile_index[i], index_typed_constant(kTileSize), + "tile_origin." + std::to_string(i)); + } + return index; + }(); + + // Calculate the input tile origin from the output tile origin. + const IrArray::Index input_tile_origin( + Permute({0, 2, 1}, output_tile_origin.multidim())); + + // Calculate the current output tile bounds in each of the logical dimensions. + std::vector output_tile_bounds(3); + for (int i = 1; i < 3; ++i) { + // Only last row or column may not have full size. + output_tile_bounds[i] = b_.CreateSelect( + b_.CreateICmpEQ(output_tile_index[i], + index_typed_constant(output_dims_in_tiles[i] - 1)), + index_typed_constant(reduced_output_dims[i] - + (output_dims_in_tiles[i] - 1) * kTileSize), + index_typed_constant(kTileSize), "kTileSize"); + } + + KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); + + // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck. + auto emit_tiled_elemental_code_with_bounds_check = + [&](const IrArray::Index& index, const string& loop_name, + llvm::Value* tile_width, llvm::Value* tile_height, + const std::function& + emit_elem_function) { + EmitTiledElementalCodeWithBoundsCheck( + kTileSize, kNumRows, index, loop_name, &ksl, &b_, y, x, tile_width, + tile_height, emit_elem_function); + }; + + // Adds `addend` to the given `dim` of `index`. + auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { + index[dim] = b_.CreateAdd(index[dim], addend); + return index; + }; + const IrArray::Index input_index = + offset_dim(offset_dim(input_tile_origin, x, /*dim=*/2), y, /*dim=*/1); + + // Copy input parameter values to shared memory buffers: + // tile[y, x] = input[index] + emit_tiled_elemental_code_with_bounds_check( + input_index, "input", output_tile_bounds[1], output_tile_bounds[2], + [&](const IrArray::Index& index, llvm::Value* y_loc) { + for (int64 id : tiled_param_ids) { + IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id]; + llvm::Value* shmem_buffer = param_shmem_buffers[id]; + // TODO(jlebar): Add AA metadata to this store. Tile buffers are + // global variables, so LLVM can't infer much about it. + b_.CreateStore( + input_in_logical_shape.EmitReadArrayElement(index, &b_, + "input_element"), + b_.CreateGEP(shmem_buffer, {index_typed_constant(0), y_loc, x})); + } + }); + + // Wait for all threads to reach this point, lest we copy a value from tile to + // output before the other thread copies it from input to tile. + // This is `__syncthreads` in CUDA. + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); + + llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x); + + const IrArray::Index output_index = + offset_dim(offset_dim(output_tile_origin, x, /*dim=*/2), y, /*dim=*/1); + + // Write to output[index] by emitting code like normal, except that values for + // the tiled parameters are read from the shmem buffers. + if (hlo->opcode() == HloOpcode::kCopy) { + emit_tiled_elemental_code_with_bounds_check( + output_index, "output", output_tile_bounds[2], output_tile_bounds[1], + [&](const IrArray::Index& index, llvm::Value* y_loc) { + // TODO(jlebar): Add AA metadata to this load. + llvm::Instruction* load_from_shmem_buffer = b_.CreateLoad( + b_.CreateGEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}), + "output_element"); + output_in_reduced_shape_arrays[0].EmitWriteArrayElement( + index, load_from_shmem_buffer, &b_); + }); + } else { + CHECK_EQ(hlo->opcode(), HloOpcode::kFusion); + emit_tiled_elemental_code_with_bounds_check( + output_index, "output", output_tile_bounds[2], output_tile_bounds[1], + [&](const IrArray::Index& index, llvm::Value* y_loc) { + GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, + GetNestedComputer()); + FusedIrEmitter fused_emitter(param_arrays, &elem_emitter); + tiled_param_info.set_y(y_loc); + fused_emitter.SetTiledParameterInfo(&tiled_param_info); + TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter)); + IrArray::Index untiled_index = llvm_ir::GetUnreducedOutputIndex( + index, output_reduced_shapes[0], output_arrays[0].GetShape(), + &b_); + const llvm_ir::ElementGenerator& output_generator = + fused_emitter.GetRootGenerator(); + llvm::Value* output_value = + output_generator(untiled_index).ValueOrDie(); + if (hlo->IsMultiOutputFusion()) { + CHECK(output_value->getType()->isStructTy()); + CHECK_EQ(output_value->getType()->getStructNumElements(), + output_in_reduced_shape_arrays.size()); + for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) { + output_in_reduced_shape_arrays[i].EmitWriteArrayElement( + index, b_.CreateExtractValue(output_value, i), &b_); + } + } else { + output_in_reduced_shape_arrays[0].EmitWriteArrayElement( + index, output_value, &b_); + } + }); + } + + // For multioutput fusion, emit a tuple with all the individual outputs. + if (hlo->IsMultiOutputFusion()) { + std::vector tuple_operand_ptrs; + for (int64 i = 0; i < output_arrays.size(); ++i) { + tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); + } + llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &b_, + module_); + } + + return launch_dimensions; +} + +bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { + HloOpcode opcode = hlo->opcode(); + CHECK(opcode == HloOpcode::kFusion || opcode == HloOpcode::kCopy); + CHECK(opcode != HloOpcode::kFusion || + hlo->fusion_kind() == HloInstruction::FusionKind::kLoop) + << "Only loop fusions are supported."; + + const Shape& output_shape = hlo->IsMultiOutputFusion() + ? ShapeUtil::GetSubshape(hlo->shape(), {0}) + : hlo->shape(); + + // If the output_shape is reduced to 021 shape, find all the parameters of the + // hlo that are in the corresponding 012 shape. + std::vector params_012; + optional> reduced_dims_021; + for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); + ++operand_idx) { + HloInstruction* operand = hlo->mutable_operand(operand_idx); + auto find_transpose_result = + llvm_ir::FindTranspose021(operand->shape(), output_shape); + if (!find_transpose_result.has_value()) { + continue; + } + const std::vector& curr_reduced_dims_021 = *find_transpose_result; + if (!reduced_dims_021.has_value()) { + reduced_dims_021 = curr_reduced_dims_021; + } + if (!ContainersEqual(*reduced_dims_021, curr_reduced_dims_021)) { + // There is more than one possible transpose. Instead of picking one + // transpose, we simply give up here. + return false; + } + params_012.push_back(operand_idx); + } + + if (!reduced_dims_021.has_value()) { + return false; + } + + if ((*reduced_dims_021)[1] < kMinDimensionToTransposeTiled || + (*reduced_dims_021)[2] < kMinDimensionToTransposeTiled) { + return false; + } + + // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the + // elements are of size 4 bytes), and CUDA has an architectural limit of 48kb + // shared memory per SM. (This is increased to 96kb in Volta, but we don't + // use this, in part because it eats into our L1 cache space.) + // + // For correctness we need to ensure that we don't make more than 48kb worth + // of shmem tiles per block. And for performance, we'd probably like to use + // significantly less, so that we can fit more than one block at a time on a + // gpu core. + // + // We say without benchmarks that we want at least 3 threads/block, + // corresponding to 3 shmem tiles if the elements are 32 bits wide. We choose + // which params get the shmem transpose treatment arbitrarily; it's not clear + // if there's a Right Choice. + // + // This is only sound if tiled transposes are the only place where we use + // shared memory in fusions. If in the future other fusile ops use shared + // memory, we'll have to adjust this heuristic. + constexpr int kMinBlocksPerCore = 3; + constexpr int64 kShmemPerCore = 48 * 1024; + int64 shmem_used = 0; + for (int64 i = 0; i < params_012.size(); ++i) { + const HloInstruction* operand = hlo->operand(params_012[i]); + shmem_used += + 32 * 33 * + ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type()); + + if (kMinBlocksPerCore * shmem_used > kShmemPerCore) { + // Erase this element and everything after it from params_012. + params_012.resize(i); + break; + } + } + + VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1" << hlo->ToString(); + thunk_sequence_->emplace_back( + BuildKernelThunk(hlo, /*implements_whole_instruction=*/true)); + const LaunchDimensions launch_dimensions = + EmitHlo021Tile(hlo, *reduced_dims_021, params_012); + UpdateLaunchDimensions(launch_dimensions, LastThunk(), + ir_emitter_context_->llvm_module()); + + return true; +} + +Status IrEmitterUnnested::EmitConstantGlobals() { + for (const BufferAllocation& allocation : + ir_emitter_context_->buffer_assignment().Allocations()) { + if (!allocation.is_constant()) { + continue; + } + + const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation); + const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal); + llvm::ArrayType* global_type = + llvm::ArrayType::get(b_.getInt8Ty(), allocation.size()); + llvm::Constant* initializer = + should_emit_initializer + ? llvm_ir::ConvertLiteralToIrConstant(literal, module_) + : llvm::ConstantAggregateZero::get(global_type); + if (should_emit_initializer) { + VLOG(3) << "Emitted initializer for constant with shape " + << ShapeUtil::HumanString(literal.shape()); + } + + // These globals will be looked up by name by GpuExecutable so we need to + // give them an external linkage. Not all of their uses are visible in the + // LLVM IR (e.g. TupleThunk) so we can't give then a linkage that merely + // preserves their names (like available_externally), we also need to ensure + // that they stick around even if they're "unused". + // + // We may have to be more more clever here in the future if we notice that + // we're keeping around too many globals because of their linkage. + llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( + global_type, /*isConstant=*/should_emit_initializer, + llvm::GlobalValue::ExternalLinkage, + /*Initializer=*/initializer, + llvm_ir::AsStringRef( + llvm_ir::ConstantBufferAllocationToGlobalName(allocation))); + global_for_const->setAlignment(kConstantBufferAlignBytes); + ir_emitter_context_->llvm_module()->getGlobalList().push_back( + global_for_const); + } + + return Status::OK(); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 202231b82f3877c11cf932bd00a8aac350fd0afa..525441990795e160ba0e8facb910d5cc9796c4bb 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h" namespace xla { namespace gpu { @@ -67,16 +68,19 @@ class IrEmitterUnnested : public IrEmitter { Status HandleDot(HloInstruction* dot) override; Status HandleFft(HloInstruction* fft) override; Status HandleFusion(HloInstruction* fusion) override; - Status HandleGather(HloInstruction* gather) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleReduce(HloInstruction* reduce) override; Status HandleSelectAndScatter(HloInstruction* instruction) override; Status HandleTuple(HloInstruction* tuple) override; Status HandleWhile(HloInstruction* xla_while) override; Status HandleInfeed(HloInstruction* xla_infeed) override; + Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleRng(HloInstruction* random) override; Status HandleSelect(HloInstruction* select) override; + Status HandleSort(HloInstruction* sort) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleAfterAll(HloInstruction* gen_token) override; Status EmitTargetElementLoop( const HloInstruction& hlo, @@ -88,6 +92,9 @@ class IrEmitterUnnested : public IrEmitter { const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter, KernelThunk* thunk); + // Emits LLVM global variables corresponding to constant instructions. + Status EmitConstantGlobals(); + private: // Builds the appropriate thunk for the instruction hlo and returns the owning // pointer to it. The caller needs to make sure `inst` outlives the lifetime @@ -115,7 +122,7 @@ class IrEmitterUnnested : public IrEmitter { // Emits code that reduces a matrix of shape [height x width] to a vector of // [width]. Other parameters have the same meaning as those of // `EmitReductionToVector`. Note that input shape might not be - // [height x width], but can be bitcast to [height x weight] with "height" + // [height x width], but can be bitcast to [height x width] with "height" // being the major dimension. Status EmitColumnReduction( int64 height, int64 width, HloInstruction* reduce, @@ -131,7 +138,7 @@ class IrEmitterUnnested : public IrEmitter { // Emits code that reduces a 3D tensor of shape [depth x height x width] to a // vector of shape [height]. Other parameters have the same meaning as those // of `EmitReductionToVector`. Note that input shape might not be - // [depth x height x width], but can be bitcast to [depth x height x weight] + // [depth x height x width], but can be bitcast to [depth x height x width] // with "depth" being the most major dimension. Status EmitRowReduction( int64 depth, int64 height, int64 width, HloInstruction* reduce, @@ -182,12 +189,56 @@ class IrEmitterUnnested : public IrEmitter { std::pair> extra_output_gens); + // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel + // for the hlo instruction. + bool CheckAndEmitHloWithTile021(HloInstruction* hlo); + // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and + // returns the launch dimensions for the kernel. This is a helper to support + // the implementation of CheckAndEmitHloWithTile021. + LaunchDimensions EmitHlo021Tile( + HloInstruction* hlo, + tensorflow::gtl::ArraySlice reduced_output_dims, + tensorflow::gtl::ArraySlice tiled_param_ids); + // Generates the IrArray for each output of hlo and returns the number of + // outputs. + int ConstructIrArrayForOutputs(const HloInstruction& hlo, + std::vector* output_arrays); + // Generates the IrArray for each input of hlo and returns the number of + // inputs. + int ConstructIrArrayForInputs(const HloInstruction& hlo, + std::vector* param_arrays); + // For each output of the `hlo` instruction, constructs the reduced shape for + // the output with the given `reduced_output_dims` and cast the original + // output IrArray element in `output_arrays` to the reduced shape. Returns + // the number of outputs. + int ConstructOutputReducedShapeAndCastOutputIrArrayToShape( + const HloInstruction& hlo, + const std::vector& output_arrays, + tensorflow::gtl::ArraySlice reduced_output_dims, + std::vector* output_reduced_shapes, + std::vector* output_in_reduced_shape_arrays); + // For each input of the `hlo` instruction, checks its value in + // `param_buffers` to find out whether the input has a reduced shape. If the + // input has a reduced shape, constructs the reduced shape for the input and + // casts the original input IrArray in `param_arrays` to the reduced shape. + // Return the total number of inputs. + int ConstructInputReducedShapeAndCastInputIrArrayToShape( + const HloInstruction& hlo, + const std::vector& param_arrays, + const std::vector& param_buffers, + tensorflow::gtl::ArraySlice reduced_output_dims, + std::vector* param_reduced_shapes, + std::vector* param_in_reduced_shape_arrays); + // Returns a KernelThunk that invokes the kernel emitted for `inst`. The // caller needs to make sure `inst` outlives the lifetime of the returned // Thunk object. The kernel implementation will be unrolled if unroll_factor - // is greater than one. - std::unique_ptr BuildKernelThunk(const HloInstruction* inst, - int unroll_factor = 1); + // is greater than one. 'implements_whole_instruction' specifies whether this + // KernelThunk implements the whole 'inst' HloInstruction. In some cases + // 'inst' will be implemented by a sequence of Thunks. + std::unique_ptr BuildKernelThunk( + const HloInstruction* inst, bool implements_whole_instruction, + int unroll_factor = 1); // Returns a FftThunk that calls cuFFT to implement `inst`. std::unique_ptr BuildFftThunk(const HloInstruction* inst); @@ -208,10 +259,14 @@ class IrEmitterUnnested : public IrEmitter { std::unique_ptr BuildDeviceToDeviceCopyThunk( const HloInstruction* inst); - // Returns an InfeedThunk that performs device-to-device memcpy to implement + // Returns an InfeedThunk that performs a host-to-device memcpy to implement // `inst`. std::unique_ptr BuildInfeedThunk(const HloInstruction* inst); + // Returns an OutfeedThunk that performs a device-to-host memcpy to implement + // `inst`. + std::unique_ptr BuildOutfeedThunk(const HloInstruction* inst); + // Returns a WhileThunk that invokes thunk sequences for 'condition' and // 'body' sub-computations of while instruction 'hlo'. std::unique_ptr BuildWhileThunk(const HloInstruction* hlo); diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index f56c1ce69f11ed79c8be76834269f29de93a9645..e76823ad103dfa5ba61a0d3ba81b2c028dfeb33e 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -75,7 +76,8 @@ void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) { } Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { + se::Stream* stream, + HloExecutionProfiler* profiler) { // Load the kernel. se::StreamExecutor* executor = stream->parent(); LaunchDimensions launch_dimensions; @@ -100,6 +102,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, VLOG(3) << " Arg: alloc #" << arg->index() << ": " << buf.opaque() << " (" << buf.size() << "B)"; } + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); if (!stream->parent()->Launch( stream, se::ThreadDim(launch_dimensions.threads_per_block()), se::BlockDim(launch_dimensions.block_count()), *kernel, diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index 7def27e189b66747569344a3dbe5c0c446f903be..d751de50ad6671b3bf88cd4de49a8feb448e13ba 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -62,7 +63,8 @@ class KernelThunk : public Thunk { // Executes the kernel for the thunk on "stream", which must be non-null. Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: // Buffers passed to the kernel as arguments. diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index 7de8f9e1ee922bdbf65fd1299702482e1843f17e..eb93efc560efbb4c14065ec98b980a1ca78605c6 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -17,12 +17,12 @@ cc_library( name = "llvm_gpu_backend", srcs = [ "dump_ir_pass.cc", - "gpu_backend_lib.cc", + "nvptx_backend_lib.cc", "utils.cc", ], hdrs = [ "dump_ir_pass.h", - "gpu_backend_lib.h", + "nvptx_backend_lib.h", "utils.h", ], deps = [ @@ -34,6 +34,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@llvm//:amdgpu_code_gen", "@llvm//:analysis", "@llvm//:bit_reader", "@llvm//:bit_writer", diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc similarity index 93% rename from tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc rename to tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index a4e4e85bf3d2c197cfc691b7fca0920aa6571729..ff4ae1f9ef2ad2fda4bb9100de93019c0b88fbd1 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h" #include #include @@ -114,20 +114,20 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, // Gets the GPU name as it's known to LLVM for a given compute capability. If // we see an unrecognized compute capability, we return "sm_30". static string GetSmName(std::pair compute_capability) { - static auto* m = new std::map, int>({{{2, 0}, 20}, - {{2, 1}, 21}, - {{3, 0}, 30}, - {{3, 2}, 32}, - {{3, 5}, 35}, - {{3, 7}, 37}, - {{5, 0}, 50}, - {{5, 2}, 52}, - {{5, 3}, 53}, - {{6, 0}, 60}, - {{6, 1}, 61}, - {{6, 2}, 62}, - // TODO: Change this to 70 once LLVM NVPTX supports it - {{7, 0}, 60}}); + static auto* m = new std::map, int>({ + {{3, 0}, 30}, + {{3, 2}, 32}, + {{3, 5}, 35}, + {{3, 7}, 37}, + {{5, 0}, 50}, + {{5, 2}, 52}, + {{5, 3}, 53}, + {{6, 0}, 60}, + {{6, 1}, 61}, + {{6, 2}, 62}, + {{7, 0}, 70}, + {{7, 2}, 72}, + }); int sm_version = 30; auto it = m->find(compute_capability); if (it != m->end()) { @@ -180,7 +180,7 @@ std::unique_ptr GetTargetMachine( TargetOptions target_options = InitTargetOptionsFromCodeGenFlags(); llvm_ir::SetTargetOptions( /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_enable_fast_math(), + .xla_gpu_enable_fast_math(), &target_options); // Enable FMA synthesis. @@ -206,7 +206,7 @@ std::unique_ptr GetTargetMachine( codegen_opt_level = CodeGenOpt::None; } return WrapUnique(target->createTargetMachine( - triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx42", target_options, + triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options, Optional(RelocModel), Optional(CMModel), codegen_opt_level)); } @@ -319,8 +319,8 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module, llvm::Linker linker(*module); string libdevice_path = tensorflow::io::JoinPath( - libdevice_dir_path, GetLibdeviceFilename(libdevice_dir_path, - compute_capability)); + libdevice_dir_path, + GetLibdeviceFilename(libdevice_dir_path, compute_capability)); TF_RETURN_IF_ERROR(tensorflow::Env::Default()->FileExists(libdevice_path)); VLOG(1) << "Linking with libdevice from: " << libdevice_path; std::unique_ptr libdevice_module = @@ -328,7 +328,7 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module, if (linker.linkInModule( std::move(libdevice_module), llvm::Linker::Flags::LinkOnlyNeeded, [](Module& M, const StringSet<>& GVS) { - internalizeModule(M, [&M, &GVS](const GlobalValue& GV) { + internalizeModule(M, [&GVS](const GlobalValue& GV) { return !GV.hasName() || (GVS.count(GV.getName()) == 0); }); })) { diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h similarity index 90% rename from tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h rename to tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h index 0a345191d34e6f40db043c559a67a44a6748321c..54e0e140dea1c3a8b21ffde2950c4bc9b703b71c 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // LLVM-based compiler backend. -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_BACKEND_LIB_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_BACKEND_LIB_H_ #include #include @@ -44,4 +44,4 @@ StatusOr CompileToPtx(llvm::Module* module, } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_BACKEND_LIB_H_ diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc index d4100a898b5bb9eec382c34932c2db104c9e985b..9fd6cf7157ecd659e7eb1d2c5228eca931ff6a01 100644 --- a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc @@ -14,21 +14,27 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" + +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/stream_executor/stream_executor.h" namespace xla { namespace gpu { Status MemzeroThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenMemZero(&dest_data, dest_data.size()); return Status::OK(); } Status Memset32BitValueThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenMemset32(&dest_data, value_, dest_data.size()); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.h b/tensorflow/compiler/xla/service/gpu/memset_thunk.h index 51c332d287d139335b356fc66411b5ffaa448b5a..d1fec0bd76b8a80f4a1e1c2e818f248997da7a75 100644 --- a/tensorflow/compiler/xla/service/gpu/memset_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MEMSET_THUNK_H_ #include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/status.h" @@ -36,7 +37,8 @@ class MemzeroThunk : public Thunk { : Thunk(Kind::kMemzero, hlo), dest_(dest) {} Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const BufferAllocation::Slice dest_; @@ -52,7 +54,8 @@ class Memset32BitValueThunk : public Thunk { : Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {} Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: uint32 value_; diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index e3f444a1268cf8e3e551a3c5b986fee0339fa327..c62bae0628f7b2fbfe822104fbe5f3528e0e09c3 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -23,9 +23,13 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -77,46 +81,69 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, return false; } // The elementwise output shapes must be the same (including layout). - return ShapeUtil::Equal(get_element_shape(element_instr_1), - get_element_shape(element_instr_2)); -} - -bool GpuMultiOutputFusion::IsProfitableOperand(HloInstruction* instr) { - // kConstant instruction will not have memory reads, so it won't be a profit - // source. Skip them. - if (instr->opcode() == HloOpcode::kConstant && - ShapeUtil::IsEffectiveScalar(instr->shape())) { - return false; - } - // We don't target to fuse producer/consumer instructions -- this should - // be taken care of by the instruction_fusion pass. If instr has only - // one user, it will not have sibling instructions. We won't consider it. - if (instr->user_count() < 2) { - return false; - } - return true; + return ShapeUtil::EqualIgnoringFpPrecision( + get_element_shape(element_instr_1), get_element_shape(element_instr_2)); } namespace { -bool IsReduction(HloInstruction* instr) { +bool IsInputFusibleReduction(HloInstruction* instr) { if (instr->IsMultiOutputFusion()) { for (const HloInstruction* operand : instr->fused_expression_root()->operands()) { if (operand->opcode() == HloOpcode::kReduce) { + CHECK(instr->fusion_kind() == HloInstruction::FusionKind::kInput) + << " Reduce multi-output fusion " << instr->ToString() + << " must be an input fusion."; return true; } } return false; } else if (instr->opcode() == HloOpcode::kFusion) { - return instr->fused_expression_root()->opcode() == HloOpcode::kReduce; + // The loop emitter can handle to-vector reduce fusions. Such reduce + // fusions have the fusion kind kLoop rather than kInput. We do not fuse + // to-vector reduce fusions, because the resulting fusions may no longer be + // supported by loop emitter. + return IsReductionToVector(*instr->fused_expression_root()); } else { - return instr->opcode() == HloOpcode::kReduce; + return IsReductionToVector(*instr); + } +} + +// The code emitted for reduction suffers from poor data locality if the layouts +// of input parameters differ. In such situtations it is beneficial not to fuse. +// We consider input params with maximum rank only. Params with smaller ranks +// will be broadcasted and have not been observed to cause data locality issues. +// TODO(b/111977086): Improve reduce emitters to remove this limitation. +bool ReduceFriendlyInputLayouts(HloInstruction* instr) { + std::vector params; + if (instr->opcode() == HloOpcode::kFusion) { + params = instr->fused_parameters(); + } else { + for (HloInstruction* operand : instr->operands()) { + params.push_back(operand); + } + } + int64 max_rank = 0; + const Layout* max_rank_layout; + for (HloInstruction* param : params) { + if (ShapeUtil::Rank(param->shape()) > max_rank) { + max_rank = ShapeUtil::Rank(param->shape()); + max_rank_layout = ¶m->shape().layout(); + } } + return c_all_of(params, [&](HloInstruction* param) { + return (ShapeUtil::Rank(param->shape()) < max_rank) || + (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout)); + }); } + } // namespace bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { - return IsReduction(instr); + // We can fuse reduces and loop fusions. + return IsInputFusibleReduction(instr) || + (instr->opcode() == HloOpcode::kFusion && + instr->fusion_kind() == HloInstruction::FusionKind::kLoop); } int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, @@ -140,5 +167,146 @@ int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, return profit; } +bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, + HloInstruction* instr2) { + if (!MultiOutputFusion::LegalToFuse(instr1, instr2)) { + return false; + } + + // If we're fusing fusions only do it if the fusion kind matches. Loop fusions + // merge into bigger loop fusions and input (reduce) fusions become fusions + // with multiple reduce outputs. We could fuse reduce and loop fusions + // together too (the result being an input fusion) if we find cases where this + // improves things. + CHECK(instr1->opcode() == HloOpcode::kFusion); + if ((instr2->opcode() == HloOpcode::kFusion && + instr1->fusion_kind() != instr2->fusion_kind()) || + (instr2->opcode() != HloOpcode::kFusion && + instr1->fusion_kind() == HloInstruction::FusionKind::kLoop)) { + return false; + } + + // Do this check last, as it may be expensive. + return !GpuInstructionFusion::FusionWouldBeTooLarge(instr1, instr2); +} + +bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { + bool changed = false; + RecomputeReachability(); + + tensorflow::gtl::FlatSet to_fuse; + // Keep a list of the instructions to fuse after making all the fusion + // decisions. We first aggressively add instructions to potential_fusion_list, + // then filter out instructions that will be no longer fusable because of + // reachability change. This avoids recalculating reachability on a large set + // of instructions. + std::vector> + potential_fusion_list; + std::vector> fusion_list; + std::vector instrs_to_update_reachability; + + // For each reduce or reduce multi-output fusion, try to fuse it with loop + // fusions operands. + for (HloInstruction* consumer : computation()->MakeInstructionPostOrder()) { + if (consumer->user_count() == 0) { + VLOG(3) << consumer->name() << " has no users."; + continue; + } + if (!IsInputFusibleReduction(consumer)) { + VLOG(3) << consumer->name() << " is not an input-fusable reduction."; + continue; + } + VLOG(3) << consumer->name() + << " is a fusion candidate. Looking for fuseable operands."; + + auto consumer_operands = consumer->operands(); + for (size_t i = 0; i < consumer_operands.size(); ++i) { + HloInstruction* producer = consumer_operands[i]; + if (!producer->IsFusable()) { + VLOG(3) << producer->name() << " is not fusable."; + continue; + } + const bool is_loop_fusion = + producer->opcode() == HloOpcode::kFusion && + producer->fusion_kind() == HloInstruction::FusionKind::kLoop; + if (!producer->IsElementwise() && !is_loop_fusion) { + VLOG(3) << producer->name() << " is not a loop fusion."; + continue; + } + if (!ShapesCompatibleForFusion(producer, consumer)) { + VLOG(3) << producer->name() << " has an incompatible shape."; + continue; + } + if (!ReduceFriendlyInputLayouts(producer)) { + VLOG(3) << producer->name() << " has inputs with mixed layouts."; + continue; + } + // If we have already decided to fuse this producer, skip it. + if (ContainsKey(to_fuse, producer)) { + VLOG(3) << producer->name() << " will be fused with another consumer."; + continue; + } + // Do not fuse a producer if the other operands of the fusion are + // reachable from the producer, this would create a cycle. + if (c_any_of(consumer_operands, [&](HloInstruction* operand) { + return producer != operand && + reachability()->IsReachable(producer, operand); + })) { + VLOG(3) << producer->name() << " would introduce a cycle when fused."; + break; + } + to_fuse.insert(producer); + potential_fusion_list.emplace_back(producer, consumer); + instrs_to_update_reachability.push_back(producer); + instrs_to_update_reachability.push_back(consumer); + break; + } + } + + // Filter out pairs that will be no longer fusable because of reachability + // change. + for (auto& fusion_pair : potential_fusion_list) { + HloInstruction* producer = fusion_pair.first; + HloInstruction* consumer = fusion_pair.second; + if (!c_any_of(consumer->operands(), [&](HloInstruction* operand) { + return producer != operand && + reachability()->IsReachable(producer, operand); + })) { + UpdateReachability(producer, consumer, instrs_to_update_reachability); + fusion_list.push_back(fusion_pair); + } + } + + for (auto fusions_to_create : fusion_list) { + HloInstruction* producer = fusions_to_create.first; + HloInstruction* consumer = fusions_to_create.second; + if (consumer->opcode() != HloOpcode::kFusion) { + // Fusing with a reduce (fusion) always results in an input fusion. + HloInstruction* input_fusion = + computation()->AddInstruction(HloInstruction::CreateFusion( + consumer->shape(), HloInstruction::FusionKind::kInput, consumer)); + VLOG(2) << "Fuse producer " << producer->name() << " and its consumer " + << consumer->name() << " into " << input_fusion->name(); + TF_CHECK_OK(computation()->ReplaceInstruction(consumer, input_fusion)); + if (producer->opcode() == HloOpcode::kFusion) { + input_fusion->MergeFusionInstructionIntoMultiOutput(producer); + } else { + input_fusion->FuseInstructionIntoMultiOutput(producer); + } + } else { + VLOG(2) << "Fuse producer " << producer->name() << " into its consumer " + << consumer->name(); + + if (producer->opcode() == HloOpcode::kFusion) { + consumer->MergeFusionInstructionIntoMultiOutput(producer); + } else { + consumer->FuseInstructionIntoMultiOutput(producer); + } + } + changed = true; + } + return changed; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h index 5451a93cec5e4aeca05717c181edb9dad0305c83..67ca5d49eee8508e93284b134f8410eb3a89f9ce 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h @@ -43,10 +43,11 @@ class GpuMultiOutputFusion : public MultiOutputFusion { // estimated as the size of the common operands b/w instr1 and instr2. int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) override; - // Whether fusing the instruction can reduce memory reads. - // - // TODO(tjoerg): Move this method up into the MultiOutputFusion base class. - bool IsProfitableOperand(HloInstruction* instr) override; + // Test if it's legal to fuse instr1 and instr2 into one fusion instruction. + bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2) override; + + // Fuse loop fusions into reduce fusions. + bool DoProducerConsumerMultiOutputFusion() override; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 924cfb11f3da76c475458ea14f201e809be61be8..14f157a5e518a0ec82c664c123629d04bd385bbf 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -27,7 +28,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace gpu { -using InstructionFusionTest = HloTestBase; +using MultiOutputFusionTest = HloTestBase; const char kModulePrefix[] = R"( HloModule test_module @@ -40,10 +41,10 @@ const char kModulePrefix[] = R"( scalar_mul_computation { scalar_lhs.1 = f32[] parameter(0) scalar_rhs.1 = f32[] parameter(1) - ROOT mul.1 = f32[] add(scalar_lhs.1, scalar_rhs.1) + ROOT mul.1 = f32[] multiply(scalar_lhs.1, scalar_rhs.1) })"; -TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { +TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { // Fusion with reduce instruction root and a sibling reduce instruction // sharing the same input param. auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( @@ -72,7 +73,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { op::Tuple(op::Reduce(), op::Reduce())); } -TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) { +TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) { auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[6400]{0} parameter(1) @@ -99,7 +100,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) { ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); } -TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { +TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[10,10]{1,0} parameter(1) @@ -126,7 +127,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); } -TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) { +TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceFusions) { // Two sibling fusions with reduce instruction roots sharing the same input // param. auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( @@ -160,7 +161,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) { op::Tuple(op::Reduce(), op::Reduce())); } -TEST_F(InstructionFusionTest, +TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) { // Multi-output fusion with two reduce instructions root and a sibling reduce // instruction sharing the same input param. @@ -193,7 +194,7 @@ TEST_F(InstructionFusionTest, op::Tuple(op::Reduce(), op::Reduce(), op::Reduce())); } -TEST_F(InstructionFusionTest, +TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) { // Verify that if we already have a multi-output fusion that we prefer to pick // a reduce op from its operands for checking shape compatibility. @@ -226,5 +227,271 @@ TEST_F(InstructionFusionTest, ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); } +TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + fused_computation_2 { + p0.2 = f32[6400]{0} parameter(0) + const.2 = f32[] constant(1) + ROOT div = f32[6400]{0} divide(p0.2, const.2) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_2 + ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Multiply(), op::Divide())); +} + +TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + exp = f32[2,2,2]{2,1,0} exponential(p0) + reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add_computation + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, exp) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement())); + const HloInstruction* fusion = root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Exp())); +} + +TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_add { + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + ROOT add = f32[2,2,2]{2,1,0} add(p0.1, p1.1) + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + c0 = f32[] constant(0) + add = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add + reduce = f32[2,2]{1,0} reduce(add, c0), dimensions={2}, to_apply=scalar_add_computation + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, add) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement())); + const HloInstruction* fusion = root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Add())); +} + +TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_select { + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + c0 = f32[] constant(0) + broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={} + greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast) + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast) + } + + fused_reduce { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + c1 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0.2, c1), dimensions={2}, to_apply=scalar_add_computation + mul = f32[2,2,2]{2,1,0} multiply(p0.2, p0.2) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + select = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select + fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce + gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0 + gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1 + ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(gte1, gte1, select) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement(), + op::GetTupleElement())); + const HloInstruction* fusion = root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce(), op::Select())); +} + +TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_element_wise { + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + ROOT root = f32[2,2,2]{2,1,0} add(p0.1, p1.1) + } + + fused_reduce { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2) + c1 = f32[] constant(0) + ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={1}, to_apply=scalar_add_computation + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + element_wise = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_element_wise + fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(element_wise), kind=kLoop, calls=fused_reduce + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(fusion, element_wise) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(MultiOutputFusionTest, + ProducerConsumerFusionFp16LoopFusionAndReduceFusion) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_select { + p1.1 = f16[2,2,2]{2,1,0} parameter(1) + c0 = f16[] constant(0) + broadcast = f16[2,2,2]{2,1,0} broadcast(f16[] c0), dimensions={} + greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast) + p0.1 = f16[2,2,2]{2,1,0} parameter(0) + ROOT select = f16[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f16[2,2,2]{2,1,0} p0.1, f16[2,2,2]{2,1,0} broadcast) + } + fused_reduce { + p0.2 = f16[2,2,2]{2,1,0} parameter(0) + convert = f32[2,2,2]{2,1,0} convert(p0.2) + c1 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(convert, c1), dimensions={2}, to_apply=scalar_add_computation + mul = f32[2,2,2]{2,1,0} multiply(convert, convert) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + ENTRY reduce { + p0 = f16[2,2,2]{2,1,0} parameter(0) + p1 = f16[2,2,2]{2,1,0} parameter(1) + select = f16[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select + fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce + gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0 + gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1 + ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) tuple(gte1, gte1, select) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement(), + op::GetTupleElement())); + const HloInstruction* fusion = root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce(), op::Select())); +} + +TEST_F(MultiOutputFusionTest, + ProducerConsumerFusionReduceUnfriendlyLoopFusion) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + mixed_input_layouts_computation { + p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0) + p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) + copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1) + c0 = f16[] constant(0) + broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={} + greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast) + ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast) + } + fused_reduce { + p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0) + convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2) + c0.2 = f32[] constant(0) + ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add_computation + } + ENTRY reduce { + p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0) + p1 = f16[128,1024,32,32]{1,3,2,0} parameter(1) + loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation + reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce + ROOT root = (f32[1024]{0}, f16[128,1024,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +// Check that we limit the number of operands to fusions we create. +TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) { + constexpr int64 kNumParams = 200; + ASSERT_GT(kNumParams, GpuInstructionFusion::kMaxOperandsAndOutputsPerFusion); + + // Compute + // p0 * p1, + // p0 * p1 + p1 * p2 + // p0 * p1 + p1 * p2 + p2 * p3 + // ... + // where each of the (pi * pj)'s is represented as a fusion node so that + // multi-output fusion will pay attention to it. + auto module = CreateNewModule(); + HloComputation::Builder b(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {10, 100}); + + std::vector params; + for (int64 i = 0; i < kNumParams; ++i) { + params.push_back( + b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p"))); + } + + // Creates a fusion node that calculates x*y. + auto make_fusion = [&](HloInstruction* x, HloInstruction* y) { + HloComputation::Builder sub_builder("subcomp"); + auto* p0 = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "p")); + auto* p1 = sub_builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "p")); + sub_builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1)); + HloComputation* subcomp = + module->AddEmbeddedComputation(sub_builder.Build()); + return HloInstruction::CreateFusion( + shape, HloInstruction::FusionKind::kLoop, {x, y}, subcomp); + }; + + auto* sum = b.AddInstruction(make_fusion(params[0], params[1])); + for (int64 i = 2; i < kNumParams; ++i) { + sum = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, sum, + b.AddInstruction(make_fusion(params[i - 1], params[i])))); + } + auto computation = module->AddEntryComputation(b.Build()); + EXPECT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + for (const HloInstruction* instr : computation->instructions()) { + EXPECT_LE(instr->operand_count() + ShapeUtil::SubshapeCount(instr->shape()), + GpuInstructionFusion::kMaxOperandsAndOutputsPerFusion) + << instr->ToString(); + } +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc similarity index 87% rename from tensorflow/compiler/xla/service/gpu/gpu_compiler.cc rename to tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index afefc740d707cd6fd01420e950eafd37abe80119..1ab663fb2ae424a508084b0165c6812f724f28f3 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" +#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h" #include #include @@ -34,9 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" -#include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" -#include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" @@ -51,11 +49,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" -#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h" #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -72,10 +72,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/scatter_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" -#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -97,8 +97,8 @@ limitations under the License. namespace xla { namespace gpu { -/* static */ const char* GpuCompiler::kTargetTriple = "nvptx64-nvidia-cuda"; -/* static */ const char* GpuCompiler::kDataLayout = +/* static */ const char* NVPTXCompiler::kTargetTriple = "nvptx64-nvidia-cuda"; +/* static */ const char* NVPTXCompiler::kDataLayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"; namespace { @@ -147,7 +147,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // support BF16 operations without directly implementing a BF16 lowering for // most ops. pipeline.AddPass(BF16, F32); - pipeline.AddPass(); { auto& pass = @@ -165,13 +164,12 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - // Rewrite gather ops into smaller ones. - pass.AddPass(); - // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. pipeline.AddPass(); + pipeline.AddPass(); + pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); @@ -203,13 +201,19 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddInvariantChecker(); pipeline.AddPass(); pipeline.AddPass(); + if (IsVoltaOrLater(*stream_exec)) { + pipeline.AddPass(); + // PadForTensorCores leaves behind unnecessary tuple/get-tuple-element + // pairs that TupleSimplifier fixes. + pipeline.AddPass(); + } TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } { HloPassPipeline pipeline("layout_assignment"); pipeline.AddPass( - hlo_module->mutable_device_entry_computation_layout(), stream_exec); + hlo_module->mutable_entry_computation_layout(), stream_exec); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. @@ -260,6 +264,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); fusion.AddPass(); + fusion.AddPass(/*is_layout_sensitive=*/true, + /*only_fusion_computations=*/true); TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline reduce_pipeline("reduce-precision"); @@ -277,14 +283,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, } } - { - // Do an aggressive LICM pass over while loops. In particular, this hoists - // constants that were sunk by WhileLoopConstantSinking. Leaving them in - // the while loop may result in unnecessary copies. - HloPassPipeline pipeline("while-loop-licm"); - pipeline.AddPass(true); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - } return Status::OK(); } @@ -356,16 +354,30 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) { return; } + // We need ptxas >= 9.0 as a hard requirement, because we compile targeting + // PTX 6.0. An older ptxas will just fail to compile any of our code. + // // ptxas 9.0 before 9.0.276 and ptxas 9.1 before 9.1.121 miscompile some // address calculations with large offsets (e.g. "load ptr + large_constant"), // b/70245379. - if ((vmaj == 9 && vmin == 0 && vdot < 276) || - (vmaj == 9 && vmin == 1 && vdot < 121)) { - LOG(WARNING) << "*** WARNING *** You are using ptxas " << vmaj << "." - << vmin << "." << vdot - << ", which is in range [9.0.0, 9.0.276) + [9.1.0, 9.1.121). " - "These versions are known to miscompile XLA code, leading " - "to incorrect results or invalid-address errors."; + // + // ptxas 9.1.121 miscompiles some large multioutput fusions, again in a way + // that appears related to address calculations, b/111107644. ptxas 9.2.88 + // appears to work, as far as we can tell. + if (vmaj < 9) { + LOG(ERROR) + << "You are using ptxas 8.x, but XLA requires ptxas 9.x (and strongly " + "prefers >= 9.2.88). Compilation of XLA kernels below will likely " + "fail.\n\nYou do not need to update CUDA; cherry-picking the ptxas " + "binary is sufficient."; + } else if ((vmaj < 9 || vmin < 2 || vdot < 88)) { + LOG(WARNING) + << "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "." + << vdot + << ", which older than 9.2.88. ptxas 9.x before 9.2.88 is known to " + "miscompile XLA code, leading to incorrect results or " + "invalid-address errors.\n\nYou do not need to update to CUDA " + "9.2.88; cherry-picking the ptxas binary is sufficient."; } } @@ -393,17 +405,18 @@ void WarnIfBadDriverJITVersion() { // - 384.x before 384.108 // - 387.x before 387.40 // - 390.x before 390.10. - auto vmaj = std::get<0>(version); - auto vmin = std::get<1>(version); - if ((vmaj == 384 && vmin < 108) || // - (vmaj == 387 && vmin < 40) || // - (vmaj == 390 && vmin < 10)) { + // + // In addition, only >= 396.20 contains ptxas >= 9.2.88, which contains the + // fix for the "large multioutput fusions" miscompile, b/111107644. + if (version < std::make_tuple(396, 20, 0)) { LOG(WARNING) << "*** WARNING *** Invoking the PTX->SASS JIT from driver version " << se::cuda::DriverVersionToString(version) - << ", which is in range [384.0.0, 384.108.0) + [387.0.0, 387.40.0) + " - "[390.0.0, 390.10.0). These versions are known to miscompile XLA " - "code, leading to incorrect results or invalid-address errors."; + << ", which is older than 396.20.0. These versions are known to " + "miscompile XLA code, leading to incorrect results or " + "invalid-address errors.\nXLA only uses the driver JIT if it " + "cannot find ptxas; you don't need to update your driver if " + "you can point XLA to ptxas 9.2.88 or newer."; } }); } @@ -475,14 +488,14 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, } // namespace -GpuCompiler::GpuCompiler() +NVPTXCompiler::NVPTXCompiler() : pointer_size_(llvm::DataLayout(kDataLayout) .getPointerSize(0 /* default address space */)) {} -StatusOr> GpuCompiler::RunHloPasses( +StatusOr> NVPTXCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) { - XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses"); + XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunHloPasses"); tracing::ScopedActivity activity("HLO Transforms", module->name(), /*is_expensive=*/true); TF_RETURN_IF_ERROR( @@ -490,10 +503,10 @@ StatusOr> GpuCompiler::RunHloPasses( return std::move(module); } -StatusOr> GpuCompiler::RunBackend( +StatusOr> NVPTXCompiler::RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) { - XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend"); + XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend"); TF_RET_CHECK(stream_exec != nullptr); @@ -527,11 +540,13 @@ StatusOr> GpuCompiler::RunBackend( // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr buffer_assignment, - BufferAssigner::Run(module.get(), hlo_schedule->ConsumeHloOrdering(), - BufferSizeBytesFunction(), - /*color_alignment=*/[](LogicalBuffer::Color) { - return kCudaMallocAlignBytes; - })); + BufferAssigner::Run( + module.get(), hlo_schedule->ConsumeHloOrdering(), + BufferSizeBytesFunction(), + /*color_alignment=*/ + [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::Stats::ToString() and BufferAssignment::ToString() // include headers, so no need for us to print them ourselves. XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString()); @@ -552,10 +567,12 @@ StatusOr> GpuCompiler::RunBackend( HloComputation* entry_computation = module->entry_computation(); IrEmitterUnnested ir_emitter(module->config(), entry_computation, &ir_emitter_context); + + TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); + { - XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission"); - TF_RETURN_IF_ERROR( - entry_computation->root_instruction()->Accept(&ir_emitter)); + XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend - IR emission"); + TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter)); } if (user_pre_optimization_hook_) { @@ -581,7 +598,8 @@ StatusOr> GpuCompiler::RunBackend( } { - XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier"); + XLA_SCOPED_LOGGING_TIMER( + "NVPTXCompiler::RunBackend - Running LLVM verifier"); std::string err; llvm::raw_string_ostream err_stream(err); @@ -621,7 +639,7 @@ StatusOr> GpuCompiler::RunBackend( string ptx; { - XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - CompileToPtx"); + XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend - CompileToPtx"); TF_ASSIGN_OR_RETURN(ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor}, module->config(), libdevice_dir)); } @@ -690,10 +708,10 @@ StatusOr> GpuCompiler::RunBackend( return std::unique_ptr(gpu_executable); } -std::vector GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, - int cc_major, - int cc_minor) { - XLA_SCOPED_LOGGING_TIMER("GpuCompiler::CompilePtxOrGetCachedResult"); +std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, + int cc_major, + int cc_minor) { + XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompilePtxOrGetCachedResult"); tracing::ScopedActivity activity("PTX->CUBIN", /*is_expensive=*/true); bool inserted; decltype(compilation_cache_.begin()) iter; @@ -766,12 +784,14 @@ std::vector GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, } StatusOr>> -GpuCompiler::CompileAheadOfTime(std::vector> module, - const AotCompilationOptions& options) { - return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime"); +NVPTXCompiler::CompileAheadOfTime( + std::vector> module, + const AotCompilationOptions& options) { + return Unimplemented( + "not yet implemented: NVPTXCompiler::CompileAheadOfTime"); } -se::Platform::Id GpuCompiler::PlatformId() const { +se::Platform::Id NVPTXCompiler::PlatformId() const { return se::cuda::kCudaPlatformId; } @@ -781,7 +801,7 @@ se::Platform::Id GpuCompiler::PlatformId() const { static bool InitModule() { xla::Compiler::RegisterCompilerFactory( stream_executor::cuda::kCudaPlatformId, - []() { return xla::MakeUnique(); }); + []() { return xla::MakeUnique(); }); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h similarity index 93% rename from tensorflow/compiler/xla/service/gpu/gpu_compiler.h rename to tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index f3b02ae5d8867bdf1d970e809bff95a15d9f54d2..d4d2909f1b2dc57c3ae0f9d67067e533574369dd 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NVPTX_COMPILER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NVPTX_COMPILER_H_ #include #include @@ -37,10 +37,10 @@ namespace xla { namespace gpu { // The GPU compiler generates efficient GPU executables. -class GpuCompiler : public LLVMCompiler { +class NVPTXCompiler : public LLVMCompiler { public: - GpuCompiler(); - ~GpuCompiler() override {} + NVPTXCompiler(); + ~NVPTXCompiler() override {} // Bring in // StatusOr>> Compile( @@ -64,7 +64,7 @@ class GpuCompiler : public LLVMCompiler { se::Platform::Id PlatformId() const override; HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { - // Capture just the pointer size, not the entire GpuCompiler object. + // Capture just the pointer size, not the entire NVPTXCompiler object. int64 pointer_size = pointer_size_; return [pointer_size](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape, pointer_size); @@ -146,10 +146,10 @@ class GpuCompiler : public LLVMCompiler { CompilationCacheHash, CompilationCacheEq> compilation_cache_ GUARDED_BY(mutex_); - TF_DISALLOW_COPY_AND_ASSIGN(GpuCompiler); + TF_DISALLOW_COPY_AND_ASSIGN(NVPTXCompiler); }; } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NVPTX_COMPILER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc new file mode 100644 index 0000000000000000000000000000000000000000..4aaf0c9e142106a0e74f319d71dad4c4c96d3f08 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h" + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace gpu { + +OutfeedManager* GetOrCreateOutfeedManager() { + static auto* manager = new OutfeedManager; + return manager; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.h b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..160ba4b691f818ff01b41b8603c11853ea12c253 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_ + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/gpu/xfeed_queue.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/notification.h" + +namespace xla { +namespace gpu { + +// TODO(b/30467474) Once GPU outfeed implementation settles, consider +// folding back the cpu and gpu outfeed implementations into a generic +// one if possible. + +// Defines a buffer holding the destination for an outfeed in host memory and a +// notification when that triggers when the transfer is done. +class OutfeedBuffer { + public: + OutfeedBuffer(int64 length) : length_(length) {} + + // Waits for the device transfer to be finished. + void WaitUntilAvailable() { done_.WaitForNotification(); } + + int64 length() const { return length_; } + void set_destination(std::unique_ptr destination) { + destination_ = std::move(destination); + } + MutableBorrowingLiteral* destination() { return destination_.get(); } + + // Callback to signal that this buffer is consumed. + void Done() { done_.Notify(); } + + private: + std::unique_ptr destination_; + const int64 length_; + tensorflow::Notification done_; +}; + +// Manages a thread-safe queue of buffers. The buffers are supposed to be +// produced by the transfer manager and consumed by the device. +using OutfeedManager = XfeedQueue>*>; + +// Singleton creator-or-accessor: Returns the GPU outfeed manager. +OutfeedManager* GetOrCreateOutfeedManager(); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc new file mode 100644 index 0000000000000000000000000000000000000000..b99d998c4d7df514c024b1f8d643d08c72059d0e --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc @@ -0,0 +1,107 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +OutfeedThunk::OutfeedThunk(ShapeTree outfeed_slices, + const HloInstruction* hlo_instruction) + : Thunk(Kind::kOutfeed, hlo_instruction), + outfeed_slices_(std::move(outfeed_slices)) {} + +Status OutfeedThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { + VLOG(2) << "Outfeeding from GPU: " << hlo_instruction()->ToString(); + + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); + OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager(); + ShapeTree>* outfeed_buffers = + outfeed_manager->BlockingGetNextDestination(); + + // Nothing to be done for empty tuples. + if (ShapeUtil::IsEmptyTuple(hlo_instruction()->operand(0)->shape())) { + return Status::OK(); + } + CHECK(ShapeUtil::Compatible(hlo_instruction()->operand(0)->shape(), + outfeed_buffers->shape())); + + TF_RETURN_IF_ERROR(outfeed_buffers->ForEachMutableElementWithStatus( + [&](const ShapeIndex& index, std::unique_ptr* buffer) { + if (!*buffer) { // Tuple pointers. + return Status::OK(); + } + + BufferAllocation::Slice slice = outfeed_slices_.element(index); + se::DeviceMemoryBase data_address; + if (slice.allocation()) { + // If we have a static allocation, read it from there. This avoids + // synchronizing the host and device just to read a pointer. + data_address = buffer_allocations.GetDeviceAddress(slice); + } else { + // Otherwise we have to read the tuple pointer first. + CHECK(!index.empty()); + // Copy the parent buffer to the host. + BufferAllocation::Slice tuple_slice = + outfeed_slices_.element(ShapeIndexView(index).ConsumeFront()); + if (!tuple_slice.allocation()) { + return Unimplemented( + "Nested dynamic tuples are not supported on GPU"); + } + se::DeviceMemoryBase tuple_address = + buffer_allocations.GetDeviceAddress(tuple_slice); + CHECK(tuple_slice.size() % sizeof(void*) == 0) + << "Tuple size must be a multiple of pointer size"; + std::vector tuple_element_buffer_addresses(tuple_slice.size() / + sizeof(void*)); + stream->ThenMemcpy(tuple_element_buffer_addresses.data(), + tuple_address, tuple_slice.size()); + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + // The data address is specified by the element of the tuple pointer + // buffer. + data_address = + se::DeviceMemoryBase(tuple_element_buffer_addresses[index.back()], + (*buffer)->length()); + } + + // TODO(b/111309141): Run this on a separate stream so it doesn't block + // the GPU from doing work during the transfer. This could be handled by + // making StreamAssignment do something intelligent with outfeed thunks. + stream + ->ThenMemcpy((*buffer)->destination()->untyped_data(), data_address, + (*buffer)->length()) + .ThenDoHostCallback([buffer]() { (*buffer)->Done(); }); + return Status::OK(); + })); + + Status block_status = stream->BlockHostUntilDone(); + if (!block_status.ok()) { + return InternalError("Failed to complete data transfer on stream %p: %s", + stream, block_status.error_message().c_str()); + } + + VLOG(2) << "Outfeeding from GPU complete"; + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h new file mode 100644 index 0000000000000000000000000000000000000000..8ed89f05f0c5bb2e3893e695d413bac3b231112d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_THUNK_H_ + +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// A thunk that outfeeds data. Data must be already resident on the host. This +// thunk performs a host to device copy from the buffer allocated for the +// outfeed op to the host location. +class OutfeedThunk : public Thunk { + public: + // Constructs a OutfeedThunk that copies data to the host-side + // outfeed queue from the buffers in the given shape tree. + OutfeedThunk(ShapeTree outfeed_slices, + const HloInstruction* hlo_instruction); + + OutfeedThunk(const OutfeedThunk&) = delete; + OutfeedThunk& operator=(const OutfeedThunk&) = delete; + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream, + HloExecutionProfiler* profiler) override; + + private: + const ShapeTree outfeed_slices_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc new file mode 100644 index 0000000000000000000000000000000000000000..79f7d31816baf0b95b967771b956a9c06ac81e91 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc @@ -0,0 +1,233 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" + +namespace xla { +namespace gpu { + +using tensorflow::gtl::ArraySlice; + +// We want the input/output feature counts of an f16 conv to be factors of 8, +// because without this cudnn can't use tensor cores on the conv. +static constexpr int64 kDesiredNumFeaturesFactor = 8; + +// We won't pad a conv if doing so increases the total number of bytes in the +// lhs, rhs, or result by more than this amount. +// +// TODO(jlebar): This number was tuned experimentally. It represents a +// compromise on our current benchmarks; it speeds some up significantly, and +// doesn't slow any down. But we can observe by changing this value that +// there's additional room for speedups. Achieving those speedups without also +// slowing other things down will likely require a more sophisticated heuristic, +// possibly some form of auto-tuning. +static constexpr double kMaxBytesTouchedIncrease = 1.2; + +// Pads the given dimensions in the given shape up to a multiple of +// kDesiredNumFeaturesFactor. +static Shape PadShape(Shape s, ArraySlice dims) { + for (int64 dim : dims) { + int64 dim_to_pad_size = s.dimensions(dim); + int64 new_dim_to_pad_size = + RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor); + s.set_dimensions(dim, new_dim_to_pad_size); + } + return s; +} + +// Creates and returns an HLO that zero-pads one or more dimensions in the given +// instruction so that its shape is equal to the given shape. +// +// Padding is added to the end of each relevant dimension. +// +// If the instruction already has the given shape, simply returns it without an +// intervening pad. +static HloInstruction* PadInstruction(HloInstruction* instr, + const Shape& new_shape) { + HloComputation* comp = instr->parent(); + + const Shape& shape = instr->shape(); + auto* zero = comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(shape.element_type()).CloneToUnique())); + + PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape)); + + bool added_padding = false; + for (int64 dim = 0; dim < ShapeUtil::Rank(shape); ++dim) { + if (shape.dimensions(dim) == new_shape.dimensions(dim)) { + continue; + } + CHECK_GT(new_shape.dimensions(dim), shape.dimensions(dim)); + pad_config.mutable_dimensions(dim)->set_edge_padding_high( + new_shape.dimensions(dim) - shape.dimensions(dim)); + added_padding = true; + } + + if (!added_padding) { + return instr; + } + return comp->AddInstruction( + HloInstruction::CreatePad(new_shape, instr, zero, pad_config)); +} + +// Pads the input/output feature dimensions of the given cudnn convolution +// custom-call to be multiples of kDesiredNumFeaturesFactor. +static StatusOr PadFeaturesDims(HloInstruction* conv) { + CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0)) + << "conv must use 0 scratch bytes, i.e. this pass must be run " + "before CudnnConvolutionAlgorithmPicker."; + + const auto& target = conv->custom_call_target(); + const auto& dnums = conv->convolution_dimension_numbers(); + auto* lhs = conv->mutable_operand(0); + auto* rhs = conv->mutable_operand(1); + const Shape& result_shape = conv->shape().tuple_shapes(0); + + Shape new_lhs_shape = [&] { + if (target == kCudnnConvForwardCallTarget || + target == kCudnnConvBackwardFilterCallTarget) { + // LHS is "input". + return PadShape(lhs->shape(), {dnums.input_feature_dimension()}); + } + CHECK_EQ(target, kCudnnConvBackwardInputCallTarget); + // LHS is "output". + return PadShape(lhs->shape(), {dnums.output_feature_dimension()}); + }(); + + Shape new_rhs_shape = [&] { + if (target == kCudnnConvForwardCallTarget || + target == kCudnnConvBackwardInputCallTarget) { + // RHS is "filter". + return PadShape(rhs->shape(), {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}); + } + CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget); + // RHS is "output". + return PadShape(rhs->shape(), {dnums.output_feature_dimension()}); + }(); + + if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) && + ShapeUtil::Equal(rhs->shape(), new_rhs_shape)) { + VLOG(3) << "No need to pad features of " << conv->ToString(); + return false; + } + + Shape new_result_shape = [&] { + if (target == kCudnnConvForwardCallTarget) { + // Result is "output". + return PadShape(result_shape, {dnums.output_feature_dimension()}); + } + if (target == kCudnnConvBackwardInputCallTarget) { + // Result is "input". + return PadShape(result_shape, {dnums.input_feature_dimension()}); + } + CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget); + // Result is "filter". + return PadShape(result_shape, {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}); + }(); + + // Check that padding wouldn't increase the total bytes read/written by this + // operation too much. + auto check_size_increase = [&](const Shape& old_shape, + const Shape& new_shape) { + int64 old_bytes = ShapeUtil::ByteSizeOf(old_shape); + int64 new_bytes = ShapeUtil::ByteSizeOf(new_shape); + if (new_bytes <= old_bytes * kMaxBytesTouchedIncrease) { + return true; + } + VLOG(3) << "Not padding convolution; doing so would change input / result " + "shape from " + << ShapeUtil::HumanString(old_shape) << " to " + << ShapeUtil::HumanString(new_shape) << ", a size increase of " + << new_bytes / static_cast(old_bytes) << "x > " + << kMaxBytesTouchedIncrease << "x: " << conv->ToString(); + return false; + }; + if (!check_size_increase(lhs->shape(), new_lhs_shape) || + !check_size_increase(rhs->shape(), new_rhs_shape) || + !check_size_increase(result_shape, new_result_shape)) { + return false; + } + + // OK, let's do the transformation! + + auto* new_lhs = PadInstruction(lhs, new_lhs_shape); + auto* new_rhs = PadInstruction(rhs, new_rhs_shape); + CHECK(new_lhs != lhs || new_rhs != rhs) + << "We should have had to pad either LHS or RHS."; + + auto add = [&](std::unique_ptr new_instr) { + return conv->parent()->AddInstruction(std::move(new_instr)); + }; + + Shape new_conv_shape = ShapeUtil::MakeTupleShape( + {new_result_shape, ShapeUtil::MakeShape(U8, {0})}); + auto* new_conv = + add(conv->CloneWithNewOperands(new_conv_shape, {new_lhs, new_rhs})); + + // Slice the new conv result if necessary, keeping in mind that new_conv has + // tuple shape (new_result_shape, u8[0]). + if (!ShapeUtil::Equal(result_shape, new_result_shape)) { + std::vector start_indices(result_shape.dimensions_size(), 0); + std::vector end_indices(result_shape.dimensions().begin(), + result_shape.dimensions().end()); + std::vector strides(result_shape.dimensions_size(), 1); + + auto* new_conv_result = add( + HloInstruction::CreateGetTupleElement(new_result_shape, new_conv, 0)); + auto* empty_temp_buffer = + add(HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); + auto* sliced_result = add(HloInstruction::CreateSlice( + result_shape, new_conv_result, start_indices, end_indices, strides)); + new_conv = + add(HloInstruction::CreateTuple({sliced_result, empty_temp_buffer})); + } + + VLOG(2) << "Padded features of " << conv->ToString() << ", replaced with " + << new_conv->ToString(); + TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(conv, new_conv)); + return true; +} + +static std::vector GetRelevantConvs(HloComputation* comp) { + std::vector convs; + for (HloInstruction* instr : comp->instructions()) { + if (IsCustomCallToDnnConvolution(*instr) && + instr->operand(0)->shape().element_type() == F16) { + convs.push_back(instr); + } + } + return convs; +} + +StatusOr PadForTensorCores::Run(HloModule* module) { + bool changed = false; + for (HloComputation* comp : module->MakeNonfusionComputations()) { + for (HloInstruction* conv : GetRelevantConvs(comp)) { + TF_ASSIGN_OR_RETURN(bool result, PadFeaturesDims(conv)); + changed |= result; + } + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h new file mode 100644 index 0000000000000000000000000000000000000000..192359f026bfb2f1d5436713e4a30725fa0ad6ba --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Ensures that f16 cudnn convolutions have input/output channel dimensions that +// are multiples of 8, inserting pads/slices as necessary. +// +// This is useful primarily for Volta and newer GPUs, where tensor cores can +// only be used if the channel dims are multiples of 8. It's probably the +// opposite of useful on other GPUs, so you should check what GPU you're +// targeting before running this pass. +// +// TODO(jlebar): Also pad dots. +class PadForTensorCores : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { + return "pad for tensor cores"; + } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..99e7580b826fc5cd6d98a037a5eb064552952e18 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc @@ -0,0 +1,164 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" + +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace gpu { +namespace { + +namespace op = xla::testing::opcode_matchers; +using ::testing::_; + +using PadForTensorCoresTest = HloVerifiedTestBase; + +TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,41] parameter(0) + filter = f16[2,2,41,40] parameter(1) + ROOT result = (f16[10,20,30,40], u8[0]) custom-call(input, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convForward" + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + + SCOPED_TRACE(module().ToString()); + EXPECT_THAT(root, op::CustomCall(kCudnnConvForwardCallTarget, + op::Pad(op::Parameter(0), _), + op::Pad(op::Parameter(1), _))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(), + ShapeUtil::MakeShape(F16, {10, 20, 30, 48}))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(), + ShapeUtil::MakeShape(F16, {2, 2, 48, 40}))); +} + +TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + output = f16[10,20,30,41] parameter(0) + filter = f16[2,2,40,41] parameter(1) + ROOT result = (f16[10,20,30,40], u8[0]) custom-call(output, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardInput" + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::CustomCall(kCudnnConvBackwardInputCallTarget, + op::Pad(op::Parameter(0), _), + op::Pad(op::Parameter(1), _))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(), + ShapeUtil::MakeShape(F16, {10, 20, 30, 48}))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(), + ShapeUtil::MakeShape(F16, {2, 2, 40, 48}))); +} + +TEST_F(PadForTensorCoresTest, PadF16ForwardConvOutputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,40] parameter(0) + filter = f16[2,2,40,41] parameter(1) + ROOT result = (f16[10,20,30,41], u8[0]) custom-call(input, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convForward" + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvForwardCallTarget, op::Parameter(0), + op::Pad(op::Parameter(1), _)))), + _)); +} + +TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvInputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + output = f16[10,20,30,40] parameter(0) + filter = f16[2,2,41,40] parameter(1) + result = (f16[10,20,30,41], u8[0]) custom-call(output, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardInput" + ROOT gte = f16[10,20,30,41] get-tuple-element(result), index=0 + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::GetTupleElement(op::Tuple( + op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvBackwardInputCallTarget, op::Parameter(0), + op::Pad(op::Parameter(1), _)))), + _))); +} + +TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,41] parameter(0) + output = f16[10,20,30,40] parameter(1) + result = (f16[2,2,41,40], u8[0]) custom-call(input, output), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardFilter" + ROOT gte = f16[2,2,41,40] get-tuple-element(result), index=0 + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::GetTupleElement(op::Tuple( + op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvBackwardFilterCallTarget, + op::Pad(op::Parameter(0), _), op::Parameter(1)))), + _))); +} + +TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,40] parameter(0) + output = f16[10,20,30,41] parameter(1) + result = (f16[2,2,40,41], u8[0]) custom-call(input, output), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardFilter" + ROOT gte = f16[2,2,40,41] get-tuple-element(result), index=0 + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::GetTupleElement(op::Tuple( + op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvBackwardFilterCallTarget, + op::Parameter(0), op::Pad(op::Parameter(1), _)))), + _))); +} + +} // anonymous namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index c8f0d4185c63c5bafca6f30acab31cbe8e987277..b22040eee167e784bed58dbc0d0ad2ae042037f3 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" @@ -68,7 +69,7 @@ HloInstruction* MaybePaddedAndSlicedInput( PrimitiveType element_type = input->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(Literal::Zero(element_type)))); + MakeUnique(LiteralUtil::Zero(element_type)))); input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } @@ -125,7 +126,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, PrimitiveType element_type = kernel->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(Literal::Zero(element_type)))); + MakeUnique(LiteralUtil::Zero(element_type)))); return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); } } // namespace @@ -234,9 +235,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // Create a new backward convolution replacing the old one. HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(1); - HloInstruction* padding = - computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(Literal::Zero(input->shape().element_type())))); + HloInstruction* padding = computation->AddInstruction( + HloInstruction::CreateConstant(MakeUnique( + LiteralUtil::Zero(input->shape().element_type())))); HloInstruction* padded_input = MakePadHlo(input, padding, input_padding_config).ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index d8c07dc3119fb81a3ef22822acb11b7c4d5bbca5..3838fee674566196e10ddd98462c1a1aa7835e1a 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -32,33 +32,33 @@ namespace gpu { ParallelLoopEmitter::ParallelLoopEmitter( BodyEmitter body_emitter, const Shape& shape, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, int unroll_factor) - : LoopEmitter(body_emitter, shape, ir_builder), + : LoopEmitter(body_emitter, shape, b), launch_dimensions_(launch_dimensions), unroll_factor_(unroll_factor) {} ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, int unroll_factor) - : LoopEmitter(target_element_generator, target_arrays, ir_builder), + : LoopEmitter(target_element_generator, target_arrays, b), launch_dimensions_(launch_dimensions), unroll_factor_(unroll_factor) {} ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, const llvm_ir::IrArray& target_array, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, int unroll_factor) - : LoopEmitter(target_element_generator, target_array, ir_builder), + : LoopEmitter(target_element_generator, target_array, b), launch_dimensions_(launch_dimensions), unroll_factor_(unroll_factor) {} std::vector ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) { + tensorflow::StringPiece loop_name, llvm::Type* index_type) { // Emit the following code in LLVM IR: // linear_index = blockIdx.x * blockDim.x + threadIdx.x; // if (linear_index < num_elements) { @@ -71,31 +71,30 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( // // %nctaid.x is currently specified as 2147483647. VLOG(3) << "EmitIndexAndSetExitBasicBlock unroll_factor " << unroll_factor_; + CHECK_NE(index_type, nullptr); std::vector array_indices; - llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, ir_builder_); + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(), static_cast(block_id)); - block_id = - ir_builder_->CreateZExt(block_id, ir_builder_->getInt64Ty(), "block_id"); + block_id = b_->CreateZExtOrTrunc(block_id, index_type, "block_id"); // Per the PTX documentation: // "It is guaranteed that [...] 0 <= %tid.x < %ntid.x" // // %ntid.x is currently specified as 1024. llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, ir_builder_); + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.threads_per_block(), static_cast(thread_id)); - thread_id = ir_builder_->CreateZExt(thread_id, ir_builder_->getInt64Ty(), - "thread_id"); - - llvm::Value* linear_index_base = ir_builder_->CreateAdd( - ir_builder_->CreateMul( - block_id, - ir_builder_->getInt64(launch_dimensions_.threads_per_block()), "", - /*HasNUW=*/true, /*HasNSW=*/true), + thread_id = b_->CreateZExtOrTrunc(thread_id, index_type, "thread_id"); + + llvm::Value* linear_index_base = b_->CreateAdd( + b_->CreateMul(block_id, + llvm::ConstantInt::get( + index_type, launch_dimensions_.threads_per_block()), + "", + /*HasNUW=*/true, /*HasNSW=*/true), thread_id, "linear_index", /*HasNUW=*/true, /*HasNSW=*/true); // Add an @llvm.assume(linear_index < threads_per_block * num_blocks). @@ -108,39 +107,41 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( // conditions in the same basic block as their operands. llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::assume, - {ir_builder_->CreateICmpULT( + {b_->CreateICmpULT( linear_index_base, - ir_builder_->getInt64(launch_dimensions_.threads_per_block() * - launch_dimensions_.block_count()), + llvm::ConstantInt::get(index_type, + launch_dimensions_.threads_per_block() * + launch_dimensions_.block_count()), "linear_index_in_range")}, - {}, ir_builder_); + {}, b_); if (unroll_factor_ > 1) { - linear_index_base = ir_builder_->CreateMul( - linear_index_base, ir_builder_->getInt64(unroll_factor_), + linear_index_base = b_->CreateMul( + linear_index_base, llvm::ConstantInt::get(index_type, unroll_factor_), "linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true); } - array_indices.emplace_back(linear_index_base, shape_, ir_builder_); + array_indices.emplace_back(linear_index_base, shape_, b_); for (int i = 1; i < unroll_factor_; ++i) { - llvm::Value* linear_index = ir_builder_->CreateAdd( - linear_index_base, ir_builder_->getInt64(i), "linear_index", - /*HasNUW=*/true, /*HasNSW=*/true); - array_indices.emplace_back(linear_index, shape_, ir_builder_); + llvm::Value* linear_index = + b_->CreateAdd(linear_index_base, llvm::ConstantInt::get(index_type, i), + "linear_index", + /*HasNUW=*/true, /*HasNSW=*/true); + array_indices.emplace_back(linear_index, shape_, b_); } auto if_in_bounds = llvm_ir::EmitIfThenElse( - ir_builder_->CreateICmpULT( + b_->CreateICmpULT( linear_index_base, - ir_builder_->getInt64(ShapeUtil::ElementsIn(shape_))), - llvm_ir::IrName(loop_name, "in_bounds"), ir_builder_, false); + llvm::ConstantInt::get(index_type, ShapeUtil::ElementsIn(shape_))), + llvm_ir::IrName(loop_name, "in_bounds"), b_, false); // Set exit_bb_ to the exit block of the if structure. exit_bb_ = if_in_bounds.after_block; CHECK_NE(nullptr, exit_bb_); // Set IR builder insertion point to the body of the if structure. - llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, b_); return array_indices; } diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index 25318b3bed8bf4a2dfe3a4a974269d0405c3bfec..b82a23419df08cafdc69b6d2f14528484b95dc73 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -34,13 +34,13 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { // The meanings of other parameters are the same as LoopEmitter. ParallelLoopEmitter(BodyEmitter body_emitter, const Shape& shape, const LaunchDimensions& launch_dimensions, - llvm::IRBuilder<>* ir_builder, int unroll_factor = 1); + llvm::IRBuilder<>* b, int unroll_factor = 1); // Constructs a ParallelLoopEmitter from an element generator that generates // each element of the given target array. ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator, const llvm_ir::IrArray& target_array, const LaunchDimensions& launch_dimensions, - llvm::IRBuilder<>* ir_builder, int unroll_factor = 1); + llvm::IRBuilder<>* b, int unroll_factor = 1); // Constructs a loop emitter for a loop that generates on element of each of N // arrays on each iteration. @@ -50,7 +50,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, int unroll_factor = 1); ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; @@ -58,7 +58,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) override; + tensorflow::StringPiece loop_name, llvm::Type* index_type) override; private: // The thread and block dimension to parallelize the loop on. diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index c125474edb1036090a926020f2b1e7fcf64c751a..02471129e004b4876ce20a62cade34060c65b478 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -47,6 +47,7 @@ class LaunchDimensions { int64 block_count() const { return block_count_; } int64 threads_per_block() const { return threads_per_block_; } + int64 launch_bound() const { return block_count() * threads_per_block(); } private: int64 block_count_; diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc index 88cb10883e97ae663dc492ad088e6daf9133d7f5..84285be70a4ba94101040a639c39b3eaecbb5bb3 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { @@ -33,9 +34,12 @@ Status SequentialThunk::Initialize(const GpuExecutable& executable, } Status SequentialThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); for (const auto& thunk : thunks_) { - TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); + TF_RETURN_IF_ERROR( + thunk->ExecuteOnStream(buffer_allocations, stream, profiler)); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h index 135f79e413dfaa27f2f2264e0daa3beb3c305e0f..3c4de1d1a6c912ba31f56c29b10ca004d1e56da6 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -41,7 +42,8 @@ class SequentialThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: // The list of sub-thunks. diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index e4cfc6999f2da04dd7e7a34d854fdb3d75b8bfc6..0806dd51614f4d2da12f3fbbc9fb98df5273d5c8 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -33,13 +33,13 @@ int StreamAssignment::StreamNumberForHlo(const HloInstruction& hlo) const { } void StreamAssignment::AssignStreamToHlo(const HloInstruction* hlo, - int stream_no) { - CHECK_GE(stream_no, 0); - if (stream_no >= stream_count_) { - stream_count_ = stream_no + 1; + int stream_num) { + CHECK_GE(stream_num, 0); + if (stream_num >= stream_count_) { + stream_count_ = stream_num + 1; } - InsertOrDie(&hlo_to_stream_number_, hlo, stream_no); - VLOG(2) << "Assign stream #" << stream_no << " to " << hlo->ToString(); + InsertOrDie(&hlo_to_stream_number_, hlo, stream_num); + VLOG(2) << "Assign stream #" << stream_num << " to " << hlo->ToString(); } namespace { @@ -51,6 +51,12 @@ bool CanRunConcurrently(const HloInstruction& a, const HloInstruction& b, return !reachability.IsConnected(&a, &b); } +constexpr int kInvalidStreamNum = -1; +// Returns true iff `stream_num` is an invalid stream number. +inline bool IsStreamNumValid(int stream_num) { + return stream_num != kInvalidStreamNum; +} + // Returns which existing stream to assign to `hlo`, or -1 if a stream is not // needed. `stream_assignment` is the existing stream assignment for all // instructions topologically before `hlo`. `seen_gemms` contains all GEMMs that @@ -62,7 +68,7 @@ int ComputeStreamToAssign( if (hlo.opcode() == HloOpcode::kParameter || hlo.opcode() == HloOpcode::kConstant) { // kParameter and kConstant do not need a thunk. - return -1; + return kInvalidStreamNum; } if (hlo.GetModule() @@ -75,17 +81,17 @@ int ComputeStreamToAssign( if (!ImplementedAsGemm(hlo)) { // If `hlo` is not implemented as a GEMM, keep it close to its operands to // avoid excessive synchronization. - int stream_no = -1; + int stream_num = -1; for (const auto* operand : hlo.operands()) { if (stream_assignment.HasStreamAssigned(*operand)) { - stream_no = - std::max(stream_no, stream_assignment.StreamNumberForHlo(*operand)); + stream_num = std::max(stream_num, + stream_assignment.StreamNumberForHlo(*operand)); } } - if (stream_no == -1) { - stream_no = 0; + if (!IsStreamNumValid(stream_num)) { + stream_num = 0; } - return stream_no; + return stream_num; } // Assign different streams to concurrent GEMMs. The code below uses a @@ -94,17 +100,17 @@ int ComputeStreamToAssign( // `hlo` a different stream. std::set forbidden_stream_numbers; for (const auto* seen_gemm : seen_gemms) { - int stream_no = stream_assignment.StreamNumberForHlo(*seen_gemm); - if (!forbidden_stream_numbers.count(stream_no) && + int stream_num = stream_assignment.StreamNumberForHlo(*seen_gemm); + if (!forbidden_stream_numbers.count(stream_num) && CanRunConcurrently(*seen_gemm, hlo, reachability)) { - forbidden_stream_numbers.insert(stream_no); + forbidden_stream_numbers.insert(stream_num); } } - for (int stream_no = 0; stream_no < stream_assignment.StreamCount(); - ++stream_no) { - if (!forbidden_stream_numbers.count(stream_no)) { - return stream_no; + for (int stream_num = 0; stream_num < stream_assignment.StreamCount(); + ++stream_num) { + if (!forbidden_stream_numbers.count(stream_num)) { + return stream_num; } } return stream_assignment.StreamCount(); @@ -118,11 +124,27 @@ std::unique_ptr AssignStreams(const HloModule& module) { std::unique_ptr reachability = computation.ComputeReachability(); std::vector seen_gemms; + // The execution of different RNG Hlo instructions in the same module updates + // a common global variable. To avoid a race condition, we simply assign all + // RNG kernels to the same stream to make them run sequentially. + // + // TODO(b/111791052): If we remove such a common variable, we will need to + // clean up the code here. + int stream_num_for_rng = kInvalidStreamNum; for (const auto* hlo : computation.MakeInstructionPostOrder()) { - int stream_no = ComputeStreamToAssign(*hlo, *stream_assignment, - *reachability, seen_gemms); - if (stream_no != -1) { - stream_assignment->AssignStreamToHlo(hlo, stream_no); + // If we ever enable fusion of RNG instructions, we will need to extend this + // code to look inside a fused instruction. + int stream_num = (hlo->opcode() == HloOpcode::kRng && + IsStreamNumValid(stream_num_for_rng)) + ? stream_num_for_rng + : ComputeStreamToAssign(*hlo, *stream_assignment, + *reachability, seen_gemms); + if (IsStreamNumValid(stream_num)) { + stream_assignment->AssignStreamToHlo(hlo, stream_num); + if (hlo->opcode() == HloOpcode::kRng && + !IsStreamNumValid(stream_num_for_rng)) { + stream_num_for_rng = stream_num; + } } if (ImplementedAsGemm(*hlo)) { seen_gemms.push_back(hlo); diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index a50ddf6ac63c7fa7ccace94bc7f40f438aedccf8..05b305ea4cdfdbaeb42544b626a6b9990bb42f57 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -20,10 +20,17 @@ limitations under the License. namespace xla { namespace gpu { -using stream_executor::dnn::DataLayout; -using stream_executor::dnn::DataLayoutString; -using stream_executor::dnn::FilterLayout; -using stream_executor::dnn::FilterLayoutString; +using se::dnn::DataLayout; +using se::dnn::DataLayoutString; +using se::dnn::FilterLayout; +using se::dnn::FilterLayoutString; + +bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) { + int major, minor; + CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major, + &minor)); + return major >= 7; +} StatusOr> StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 8218f4fd11d3978d0ecc53fc15e287aea4b69ec3..1fc46bafa10e7ba6c896f081d5c836bd400886c9 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -24,18 +26,20 @@ limitations under the License. namespace xla { namespace gpu { +// Returns true if the given StreamExecutor is for a Volta or newer nvidia GPU. +bool IsVoltaOrLater(const se::StreamExecutor& stream_exec); + // Returns (input, filter, output) XLA Layout protos given the StreamExecutor // layouts. StatusOr> StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, - stream_executor::dnn::DataLayout input, - stream_executor::dnn::FilterLayout filter, - stream_executor::dnn::DataLayout output); + se::dnn::DataLayout input, + se::dnn::FilterLayout filter, + se::dnn::DataLayout output); // Returns (input, filter, output) StreamExecutor layouts given the XLA layouts. -StatusOr> +StatusOr< + std::tuple> XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, const Layout& input, const Layout& filter, const Layout& output); diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..4fad3f46cf953945e4f395e751e5ba76db97ecc4 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -0,0 +1,223 @@ +# Description: GPU-specific XLA tests. For example, codegen tests that +# verify the IR emitted. +# +# TODO(jlebar): None of these tests actually use the GPU, so they should not +# need to run on machines with GPUs present. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = [":friends"]) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "gpu_codegen_test", + testonly = True, + srcs = ["gpu_codegen_test.cc"], + hdrs = ["gpu_codegen_test.h"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service/gpu:gpu_executable", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "gpu_copy_test", + srcs = ["gpu_copy_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_ftz_test", + srcs = ["gpu_ftz_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_index_test", + srcs = ["gpu_index_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_infeed_test", + srcs = ["infeed_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_kernel_tiling_test", + srcs = ["gpu_kernel_tiling_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_ldg_test", + srcs = ["gpu_ldg_test.cc"], + tags = ["requires-gpu-sm35"], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_noalias_test", + srcs = ["gpu_noalias_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_fusion_test", + srcs = ["gpu_fusion_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_unrolling_test", + srcs = ["gpu_unrolling_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_alignment_test", + testonly = True, + srcs = ["gpu_alignment_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", + "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..672c68e59b59dff19f0c5575db26dea455c45053 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuAlignmentTest : public GpuCodegenTest {}; + +TEST_F(GpuAlignmentTest, Test) { + const char* hlo_string = R"( +HloModule GpuAlignmentTest + +ENTRY main { + zero = f32[] constant(0) + tok = token[] after-all() + a = f32[100] parameter(0) + b_tup = (f32[200], token[]) infeed(tok) + b = f32[200] get-tuple-element(b_tup), index=0 + a_padded = f32[150] pad(a, zero), padding=0_50 + b_sliced = f32[150] slice(b), slice={[0:150]} + ROOT c = f32[150] add(a_padded, b_sliced) +} +)"; + + CompileAndVerifyIr(hlo_string, R"( +CHECK: @fusion(i8* align 64 dereferenceable(600) %alloc0, i8* align 16 dereferenceable(400) %alloc1, i8* align 64 dereferenceable(864) %temp_buf) +)"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b8415fe9106137e588f345a3492f93e46aeb5b6 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace gpu { + +std::unique_ptr GpuCodegenTest::CreateNewModuleWithFTZ(bool ftz) { + HloModuleConfig config; + auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); + debug_options.set_xla_gpu_ftz(ftz); + debug_options.set_xla_gpu_max_kernel_unroll_factor(1); + // TODO(b/38354253): Change tests to use Parameters instead of Constants. + debug_options.add_xla_disable_hlo_passes("constant_folding"); + config.set_debug_options(debug_options); + + return MakeUnique(TestName(), config); +} + +void GpuCodegenTest::CompileAndVerifyPtx(std::unique_ptr hlo_module, + const string& pattern) { + std::unique_ptr executable = + std::move(CompileToExecutable(std::move(hlo_module)).ValueOrDie()); + string ptx_str = + std::string(static_cast(executable.get())->ptx()); + StatusOr filecheck_result = RunFileCheck(ptx_str, pattern); + ASSERT_TRUE(filecheck_result.ok()); + EXPECT_TRUE(filecheck_result.ValueOrDie()); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h new file mode 100644 index 0000000000000000000000000000000000000000..e4a3573babb7ed746504c1466f85b582aa4d044f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TESTS_GPU_CODEGEN_TEST_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TESTS_GPU_CODEGEN_TEST_H_ + +#include + +#include "tensorflow/compiler/xla/tests/llvm_irgen_test_base.h" + +namespace xla { +namespace gpu { + +// Tests that verify IR or PTX emitted by the GPU backend is as expected. +class GpuCodegenTest : public LlvmIrGenTestBase { + protected: + // Like HloTestBase::CreateNewModule(), with a flag for configuring the ftz + // option. + std::unique_ptr CreateNewModuleWithFTZ(bool ftz); + + // Compiles the given HLO module to PTX and verifies the PTX matches the given + // FileCheck pattern. (See http://llvm.org/docs/CommandGuide/FileCheck.html). + void CompileAndVerifyPtx(std::unique_ptr hlo_module, + const string& pattern); +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TESTS_GPU_CODEGEN_TEST_H_ diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ce69e058e64aab1f3c292b2ad7c7b529d4666b35 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { + +class GpuCopyTest : public GpuCodegenTest {}; + +// The GPU backend should not emit a copy kernel for the kCopy instruction in +// this test. Instead, it should generate a CopyThunk which invokes cuMemcpy at +// runtime. +TEST_F(GpuCopyTest, UseMemcpy) { + HloComputation::Builder builder(TestName()); + + std::unique_ptr literal = + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kCopy, constant)); + + std::unique_ptr computation = builder.Build(); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + // There should not be any kernel prefixed "copy". + CompileAndVerifyIr(std::move(hlo_module), "; CHECK-NOT: define void @_copy", + /*match_optimized_ir=*/false); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..177b94934c7f519172508b5cc6e088f908401193 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc @@ -0,0 +1,119 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" + +// Check that the ftz (flush denormals to zero) flag is reflected in PTX as +// expected. + +namespace xla { +namespace gpu { +namespace { + +class GpuFtzTest : public GpuCodegenTest { + public: + explicit GpuFtzTest(bool ftz) : ftz_(ftz) {} + + // Creates an HLO module that performs the given binary operation on some + // data. + std::unique_ptr CreateBinaryOpModule(HloOpcode op) { + HloComputation::Builder builder(TestName()); + + Shape param_shape = ShapeUtil::MakeShapeWithLayout( + F32, /*dimensions=*/{100, 100}, /*minor_to_major=*/{1, 0}); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /* parameter_number=*/0, param_shape, "x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /* parameter_number=*/1, param_shape, "y")); + builder.AddInstruction(HloInstruction::CreateBinary(param_shape, op, x, y)); + + auto hlo_module = CreateNewModuleWithFTZ(ftz_); + hlo_module->AddEntryComputation(builder.Build()); + return hlo_module; + } + + // Creates an HLO module that performs the given unary operation on some data. + std::unique_ptr CreateUnaryOpModule(HloOpcode op) { + HloComputation::Builder builder(TestName()); + + Shape param_shape = ShapeUtil::MakeShapeWithLayout( + F32, /*dimensions=*/{100, 100}, /*minor_to_major=*/{1, 0}); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /* parameter_number=*/0, param_shape, "x")); + builder.AddInstruction(HloInstruction::CreateUnary(param_shape, op, x)); + + auto hlo_module = CreateNewModuleWithFTZ(ftz_); + hlo_module->AddEntryComputation(builder.Build()); + return hlo_module; + } + + bool ftz_; +}; + +class GpuFtzEnabledTest : public GpuFtzTest { + public: + GpuFtzEnabledTest() : GpuFtzTest(/*ftz=*/true) {} +}; + +class GpuFtzDisabledTest : public GpuFtzTest { + public: + GpuFtzDisabledTest() : GpuFtzTest(/*ftz=*/false) {} +}; + +// Check that we emit mul.ftz.f32 when in ftz mode, and plain mul.f32 otherwise. +TEST_F(GpuFtzEnabledTest, MultiplyFtz) { + CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"( + CHECK-NOT: mul.f32 + CHECK: mul.ftz.f32 + CHECK-NOT: mul.f32 + )"); +} +TEST_F(GpuFtzDisabledTest, MultiplyFtz) { + CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"( + CHECK-NOT: mul.ftz.f32 + CHECK: mul.f32 + CHECK-NOT: mul.ftz.f32 + )"); +} + +// In NVPTX, exp(float) is implemented in libdevice, and consults __nvvm_reflect +// to determine whether or not ftz is enabled. The implementation uses two +// calls to ex2.approx. When ftz is on, we get two calls to the ftz version; +// when ftz is off, we get one call to the ftz version and one call to the +// regular version. +TEST_F(GpuFtzEnabledTest, ExpFtz) { + CompileAndVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"( + CHECK-NOT: ex2.approx.f32 + CHECK: ex2.approx.ftz.f32 + CHECK-NOT: ex2.approx.f32 + CHECK: ex2.approx.ftz.f32 + CHECK-NOT: ex2.approx.f32 + CHECK-NOT: ex2.approx.ftz.f32 + )"); +} + +TEST_F(GpuFtzDisabledTest, ExpFtz) { + CompileAndVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"( + CHECK-NOT: ex2.approx.f32 + CHECK-DAG: ex2.approx.ftz.f32 + CHECK-DAG: ex2.approx.f32 + CHECK-NOT: ex2.approx.f32 + CHECK-NOT: ex2.approx.ftz.f32 + )"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..674b436a8e3135a5dfe3731647897696bf1321cd --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuFusionTest : public GpuCodegenTest {}; + +TEST_F(GpuFusionTest, FusedReshape) { + const char* hlo_text = R"( + HloModule test_module + + fused_computation { + p0.param_0 = f32[4,1,1]{2,1,0} parameter(0) + p1.param_1 = f32[4,1]{1,0} parameter(1) + reshape = f32[4,1]{1,0} reshape(p0.param_0) + ROOT add = f32[4,1] add(reshape, p1.param_1) + } + + ENTRY BroadcastIntoAdd { + p0 = f32[4,1,1]{2,1,0} parameter(0) + p1 = f32[4,1]{1,0} parameter(1) + ROOT fusion = f32[4,1]{1,0} fusion(p0, p1), kind=kLoop, + calls=fused_computation + } +)"; + + CompileAndVerifyIr(hlo_text, + R"( +; CHECK-LABEL: @fusion +; CHECK: fadd +; CHECK: } + )"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e5958165eff21d82faf821213e50fe30a11059a4 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -0,0 +1,147 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { + +// This file tests the index expressions used to reference source tensors. When +// the destination tensor and source tensor have compatible shapes, the linear +// index is used to access the source tensor. Otherwise, dimensional indices +// computed from the linear index are used to access the source tensor. + +class GpuIndexTest : public GpuCodegenTest {}; + +TEST_F(GpuIndexTest, CompatibleUseLinearIndex) { + HloComputation::Builder builder(TestName()); + + auto param_shape = ShapeUtil::MakeShape(F32, {5, 7, 2}); + HloInstruction* param_x = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "x")); + HloInstruction* param_y = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "y")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {5, 7, 2}), HloOpcode::kGe, param_x, param_y)); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(builder.Build()); + + // Check the optimized IR as the unoptimized IR contains dead udiv and urem. + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-NOT: udiv +; CHECK-NOT: urem + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshape) { + HloModuleConfig config; + config.set_debug_options(HloTestBase::GetDebugOptionsForTest()); + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY CompatibleUseLinearIndexWithReshape { + x = f32[5,7,2]{2,1,0} parameter(0) + y = f32[5,14]{1,0} parameter(1) + reshape = f32[5,7,2]{2,1,0} reshape(y) + ROOT gte = pred[5,7,2]{2,1,0} greater-than-or-equal-to(x, reshape) + })", + config) + .ValueOrDie(); + + // Check the optimized IR as the unoptimized IR contains dead udiv and urem. + CompileAndVerifyIr(std::move(module), + R"( +; CHECK-NOT: udiv +; CHECK-NOT: urem + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshapeAndBroadcast) { + HloModuleConfig config; + config.set_debug_options(HloTestBase::GetDebugOptionsForTest()); + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY CompatibleUseLinearIndexWithReshape { + x = f32[5,7,2]{2,1,0} parameter(0) + y = f32[14]{0} parameter(1) + reshape = f32[7,2]{1,0} reshape(y) + broadcast = f32[5,7,2]{2,1,0} broadcast(reshape), dimensions={1,2} + ROOT gte = pred[5,7,2]{2,1,0} greater-than-or-equal-to(x, broadcast) + })", + config) + .ValueOrDie(); + + // Check the optimized IR reuses the linear index by calculating modulo 14. + CompileAndVerifyIr(std::move(module), + R"( +; CHECK: %[[urem1:.*]] = urem i{{[0-9]*}} %[[linear_index:.*]], 14 +; CHECK: %[[bitcast:.*]] = bitcast i8 addrspace(1)* %[[alloc:.*]] to float addrspace(1)* +; CHECK: %[[idx1:.*]] = zext i{{[0-9]*}} %[[urem1]] to i64 +; CHECK: getelementptr inbounds float, float addrspace(1)* %[[bitcast]], i64 %[[idx1]] + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithSizeOneDimensions) { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_max_kernel_unroll_factor(1); + config.set_debug_options(debug_options); + + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY CompatibleUseLinearIndexWithSizeOneDimensions { + x = f32[1,1024,1,256]{3,2,1,0} parameter(0) + ROOT y = f16[1,1024,1,256]{2,3,1,0} convert(x) + })", + config) + .ValueOrDie(); + + // Check that the unoptimized IR reuses the linear index. + CompileAndVerifyIr(std::move(module), + R"( +; CHECK-LABEL: @fusion +; CHECK: udiv i32 %[[linear_index:.*]], 262144 +; CHECK: %[[ld_addr:.*]] = getelementptr inbounds float, float* {{.*}}, i32 %[[linear_index]] +; CHECK: load float, float* %[[ld_addr]] +; CHECK: %[[st_addr:.*]] = getelementptr inbounds half, half* {{.*}}, i32 %[[linear_index]] +; CHECK: store half {{.*}}, half* %[[st_addr]] + )", + /*match_optimized_ir=*/false); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cca35316f0c472d2a17c466f8cd1af7f22575a8b --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -0,0 +1,177 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuKernelTilingTest : public GpuCodegenTest { + protected: + GpuKernelTilingTest() { + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + config_.set_debug_options(debug_options); + // Disable layout_assignment to use the preassigned layouts. + debug_options.add_xla_disable_hlo_passes("layout_assignment"); + } + HloModuleConfig config_; +}; + +TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) { + const char *const kHloString = R"( + HloModule unnested_transpose_1 + + ENTRY unnested_transpose_1 { + para0 = f16[32,3,64]{2,1,0} parameter(0) + ROOT copy1 = f16[32,3,64]{1,0,2} copy(para0) + })"; + + // Check that a call to llvm.nvvm.barrier0 is generated. + auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @copy +; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); + + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + +TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) { + const char *const kHloString = R"( + HloModule unnested_transpose_2 + + ENTRY unnested_transpose_2 { + para0 = f16[2,3,64]{2,1,0} parameter(0) + ROOT copy1 = f16[2,3,64]{1,0,2} copy(para0) + })"; + + // Check that a call to llvm.nvvm.barrier0 is not generated. + auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @copy +; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) { + const char *const kHloString = R"( + HloModule multiple_output_fusion_1 + fused_computation.1 { + param0 = f32[4,5,6,7,8]{4,3,2,1,0} parameter(0) + copy = f32[4,5,6,7,8]{2,1,4,3,0} copy(param0) + ROOT convert = f16[4,5,6,7,8]{2,1,4,3,0} convert(copy) + } + + ENTRY copy_in_fusion_run_without_hlo_passes { + para0 = f32[4,5,6,7,8]{4,3,2,1,0} parameter(0) + ROOT fusion.1 = f16[4,5,6,7,8]{2,1,4,3,0} fusion(para0), kind=kLoop, + calls=fused_computation.1 + })"; + + // Check that a call to llvm.nvvm.barrier0 is generated. + auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); + + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + +TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) { + const char *const kHloString = R"( + HloModule multiple_output_fusion_1 + fused_computation.1 { + param0 = f16[8,31,31,65]{3,2,1,0} parameter(0) + param1 = f16[8,31,31,65]{3,2,1,0} parameter(1) + copy0 = f16[8,31,31,65]{2,1,3,0} copy(param0) + copy1 = f16[8,31,31,65]{2,1,3,0} copy(param1) + ROOT tuple1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0}) + tuple(copy0, copy1) + } + + ENTRY multiple_output_fusion_1 { + para0 = f16[8,31,31,65]{3,2,1,0} parameter(0) + para1 = f16[8,31,31,65]{3,2,1,0} parameter(1) + ROOT fusion.1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0}) + fusion(para0,para1), kind=kLoop, calls=fused_computation.1 + })"; + + // Check that a call to llvm.nvvm.barrier0 is generated. + auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); + + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + +TEST_F(GpuKernelTilingTest, + MultipleOutputFusionWithTwoPossibleTransposesNotTiled) { + const char *const kHloString = R"( + HloModule multiple_output_fusion_2 + fused_computation.1 { + param0 = f16[8,31,31,65]{3,2,1,0} parameter(0) + param1 = f16[8,31,31,65]{1,3,2,0} parameter(1) + copy2 = f16[8,31,31,65]{2,1,3,0} copy(param0) + copy3 = f16[8,31,31,65]{2,1,3,0} copy(param1) + ROOT tuple1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0}) + tuple(copy2, copy3) + } + + ENTRY multiple_output_fusion_2 { + para0 = f16[8,31,31,65]{3,2,1,0} parameter(0) + para1 = f16[8,31,31,65]{1,3,2,0} parameter(1) + ROOT fusion1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0}) + fusion(para0,para1), kind=kLoop, calls=fused_computation.1 + })"; + + // Check that a call to llvm.nvvm.barrier0 is not generated. + auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6c9ae7bada5e7545b558b6fcb872ece60850cbe9 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc @@ -0,0 +1,141 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests that we emit ld.global.nc (the PTX instruction corresponding to CUDA's +// __ldg builtin) for reads of buffers that don't change during a kernel's +// execution. + +#include +#include + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { + +class GpuLdgTest : public GpuCodegenTest {}; + +// Parameters are never overwritten, so parameter reads should get ld.global.nc +// reads. +TEST_F(GpuLdgTest, LdgForParamRead) { + HloComputation::Builder builder(TestName()); + + auto shape = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + std::unique_ptr computation = builder.Build(); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + CompileAndVerifyPtx(std::move(hlo_module), R"( + CHECK-NOT: ld.global.f32 + CHECK: ld.global.nc.f32 + )"); +} + +// Check that reading a buffer produced by a non-parameter HLO also results in +// ld.global.nc, if that buffer isn't modified within the instruction that reads +// it. +TEST_F(GpuLdgTest, LdgForNonParamRead) { + HloComputation::Builder builder(TestName()); + + auto shape = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + HloInstruction* square = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add, add)); + builder.AddInstruction(HloInstruction::CreateTuple({add, square})); + std::unique_ptr computation = builder.Build(); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + CompileAndVerifyPtx(std::move(hlo_module), R"( + CHECK: { + CHECK-NOT: ld.global.f32 + CHECK: ld.global.nc.f32 + CHECK: } + )"); +} + +// Check that reading a buffer that's modified in-place does not produce +// ld.global.nc. +// +// We do this by creating a reduce that feeds into a sin. We don't currently +// fuse sin into reduce, and the sin is elementwise, so it reuses its input +// buffer as its output. +// +// It seems like a fair bet that we won't start fusing sin into the output of +// reduce in the foreseeable future. But if that turns out to be wrong, I give +// you, future reader, permission to delete this test. +TEST_F(GpuLdgTest, NoLdgWhenSharingBuffer) { + auto hlo_module = CreateNewModule(); + HloComputation::Builder builder(TestName()); + + HloComputation* reduce_computation; + { + auto embedded_builder = HloComputation::Builder("add"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + reduce_computation = + hlo_module->AddEmbeddedComputation(embedded_builder.Build()); + } + + auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {2}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "x")); + HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce( + reduce_shape, + builder.AddInstruction(HloInstruction::CreateBinary( + param_shape, HloOpcode::kAdd, param, param)), + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), + {0}, reduce_computation)); + builder.AddInstruction( + HloInstruction::CreateUnary(reduce_shape, HloOpcode::kSin, reduce)); + + std::unique_ptr computation = builder.Build(); + hlo_module->AddEntryComputation(std::move(computation)); + + CompileAndVerifyPtx(std::move(hlo_module), R"( + CHECK-LABEL: .entry sin + CHECK: { + CHECK-NOT: ld.global.nc.f32 + CHECK: ld.global.f32 + CHECK: } + )"); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c42e5704a4d2e611a203293e60a86ba4104bca46 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { + +class GpuNoAliasTest : public GpuCodegenTest {}; + +TEST_F(GpuNoAliasTest, Concat) { + HloComputation::Builder builder(TestName()); + + auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* param_x = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "x")); + HloInstruction* param_y = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "y")); + HloInstruction* concat = + builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(F32, {2, 4}), {param_x, param_y}, 1)); + builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(F32, {2, 6}), {concat, param_x}, 1)); + + std::unique_ptr computation = builder.Build(); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK: %[[x_gep:.*]] = getelementptr inbounds [2 x [2 x float]], [2 x [2 x float]]* %x{{.*}}, i32 0 +; CHECK: load float, float* %[[x_gep]], {{.*}}, !noalias ![[param_noalias:.*]] +; CHECK: %[[y_gep:.*]] = getelementptr inbounds [2 x [2 x float]], [2 x [2 x float]]* %y{{.*}}, i32 0 +; CHECK: load float, float* %[[y_gep]], {{.*}}, !noalias ![[param_noalias]] +; CHECK: %[[result_ptr:.*]] = bitcast [2 x [6 x float]]* %fusion{{.*}} to float* +; CHECK: %[[result_gep:.*]] = getelementptr inbounds float, float* %[[result_ptr]] +; CHECK: store float {{.*}}, float* %[[result_gep]], !alias.scope ![[param_noalias]] +; CHECK: ![[param_noalias]] = !{![[retval_buffer:.*]]} + )", + /*match_optimized_ir=*/false); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..962293630683fcbbce3941f622061a2ff0f02dda --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -0,0 +1,185 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuUnrollingTest : public GpuCodegenTest {}; + +const char *const kAddModule = R"( + HloModule test_module + + fused_computation { + p0.param_0 = f32[2,2]{1,0} parameter(0) + p1.param_1 = f32[2,2]{1,0} parameter(1) + ROOT add = f32[2,2] add(p0.param_0, p1.param_1) + } + + ENTRY BroadcastIntoAdd { + p0 = f32[2,2]{1,0} parameter(0) + p1 = f32[2,2]{1,0} parameter(1) + ROOT fusion = f32[2,2]{1,0} fusion(p0, p1), kind=kLoop, + calls=fused_computation + })"; + +TEST_F(GpuUnrollingTest, DoNotUnroll) { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_max_kernel_unroll_factor(1); + config.set_debug_options(debug_options); + auto hlo_module = ParseHloString(kAddModule, config).ValueOrDie(); + + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: @fusion +; CHECK: fadd +; CHECK-NOT: fadd +; CHECK: } + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuUnrollingTest, UnrollFourTimes) { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + // We request a factor of 8, but the computation works on 4 elements, limiting + // the maximum unroll factor. + debug_options.set_xla_gpu_max_kernel_unroll_factor(8); + config.set_debug_options(debug_options); + auto hlo_module = ParseHloString(kAddModule, config).ValueOrDie(); + + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: @fusion +; CHECK: fadd +; CHECK: fadd +; CHECK: fadd +; CHECK: fadd +; CHECK-NOT: fadd +; CHECK: } + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuUnrollingTest, UnrollDefaultTimes) { + // The default unrolling factor is 4. + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + auto hlo_module = ParseHloString(kAddModule, config).ValueOrDie(); + + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: @fusion +; CHECK: load <4 x float> +; CHECK: fadd +; CHECK: fadd +; CHECK: fadd +; CHECK: fadd +; CHECK-NOT: fadd +; CHECK: store <4 x float> +; CHECK: } + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuUnrollingTest, UnrollUnfusedAdd) { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_max_kernel_unroll_factor(4); + config.set_debug_options(debug_options); + + const char *const kUnfusedAddModule = R"( + HloModule test_module + + ENTRY AddFunc { + p0 = f32[2,2]{1,0} parameter(0) + p1 = f32[2,2]{1,0} parameter(1) + ROOT add = f32[2,2]{1,0} add(p0, p1) + })"; + auto hlo_module = ParseHloString(kUnfusedAddModule, config).ValueOrDie(); + + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: @add +; CHECK: load <4 x float> +; CHECK: fadd +; CHECK: fadd +; CHECK: fadd +; CHECK: fadd +; CHECK-NOT: fadd +; CHECK: store <4 x float> +; CHECK: } + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuUnrollingTest, UnrollMultiOutputFusion) { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_max_kernel_unroll_factor(2); + config.set_debug_options(debug_options); + + const char *const kMultiOutputFusionModule = R"( + HloModule test_module + + fused_computation { + p0.param_0 = f32[2,2]{1,0} parameter(0) + p1.param_1 = f32[2,2]{1,0} parameter(1) + add = f32[2,2]{1,0} add(p0.param_0, p1.param_1) + mul = f32[2,2]{1,0} multiply(p0.param_0, p1.param_1) + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(add, mul) + } + + ENTRY BroadcastIntoAdd { + p0 = f32[2,2]{1,0} parameter(0) + p1 = f32[2,2]{1,0} parameter(1) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p0, p1), kind=kLoop, + calls=fused_computation + })"; + auto hlo_module = + ParseHloString(kMultiOutputFusionModule, config).ValueOrDie(); + + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: @fusion +; CHECK: load <2 x float> +; CHECK: load <2 x float> +; CHECK-NOT: load <2 x float> +; CHECK: fadd +; CHECK: fmul +; CHECK: fadd +; CHECK: fmul +; CHECK: store <2 x float> +; CHECK: store <2 x float> +; CHECK-NOT: store <2 x float> +; CHECK-NOT: fadd +; CHECK-NOT: fmul +; CHECK: } + )", + /*match_optimized_ir=*/true); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9072b30317d253fd6d50e9d98949cad4eaebfe7b --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc @@ -0,0 +1,121 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class InfeedTest : public ClientLibraryTestBase { + protected: + // Transfers the given literal to the infeed interface of the device, and + // check if the returned data from Infeed HLO is same as the literal. + void TestInfeedRoundTrip(const Literal& literal) { + // TODO(b/30481585) Explicitly reset the Infeed state so that the + // test is not affected by the state from the previous tests. + ASSERT_IS_OK(client_->TransferToInfeed(literal)); + XlaBuilder builder(TestName()); + Infeed(&builder, literal.shape()); + if (ShapeUtil::IsTuple(literal.shape())) { + // TODO(b/30609564): Use ComputeAndCompareLiteral instead. + ComputeAndCompareTuple(&builder, literal, {}); + } else { + ComputeAndCompareLiteral(&builder, literal, {}); + } + } +}; + +TEST_F(InfeedTest, SingleInfeedR0Bool) { + TestInfeedRoundTrip(*LiteralUtil::CreateR0(true)); +} + +TEST_F(InfeedTest, SingleInfeedR1U32) { + TestInfeedRoundTrip(*LiteralUtil::CreateR1({1, 2, 3})); +} + +TEST_F(InfeedTest, SingleInfeedR2F32) { + TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); +} + +TEST_F(InfeedTest, SingleInfeedR3F32) { + TestInfeedRoundTrip( + *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); +} + +TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { + const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); + const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); + + TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, + r3_dim0minor)); + + TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, + r3_dim0major)); +} + +TEST_F(InfeedTest, SingleInfeedR4S32) { + TestInfeedRoundTrip(*LiteralUtil::CreateR4( + {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, + {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); +} + +// Tests that a large infeed can be handled. +TEST_F(InfeedTest, LargeInfeed) { + Array4D array(80, 100, 8, 128); + array.FillIota(1.0f); + TestInfeedRoundTrip(*LiteralUtil::CreateR4FromArray4D(array)); +} + +TEST_F(InfeedTest, SingleInfeedTuple) { + TestInfeedRoundTrip( + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), + LiteralUtil::CreateR0(false).get()})); +} + +TEST_F(InfeedTest, SingleInfeedEmptyTuple) { + TestInfeedRoundTrip(*LiteralUtil::MakeTuple({})); +} + +// Tests that a large tuple infeed can be handled. +TEST_F(InfeedTest, SingleInfeedLargeTuple) { + Array4D array(40, 100, 8, 128); + array.FillIota(1.0f); + TestInfeedRoundTrip(*LiteralUtil::MakeTuple( + {LiteralUtil::CreateR4FromArray4D(array).get(), + LiteralUtil::CreateR0(5).get()})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/thunk.cc b/tensorflow/compiler/xla/service/gpu/thunk.cc new file mode 100644 index 0000000000000000000000000000000000000000..c78605cebbc671272b8df9faf0e0cc54be2f5b1c --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/thunk.cc @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/thunk.h" + +namespace xla { +namespace gpu { + +std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) { + switch (kind) { + case Thunk::kConditional: + return os << "kConditional"; + case Thunk::kConvolution: + return os << "kConvolution"; + case Thunk::kCopy: + return os << "kCopy"; + case Thunk::kCudnnBatchNormBackward: + return os << "kCudnnBatchNormBackward"; + case Thunk::kCudnnBatchNormForwardInference: + return os << "kCudnnBatchNormForwardInference"; + case Thunk::kCudnnBatchNormForwardTraining: + return os << "kCudnnBatchNormForwardTraining"; + case Thunk::kFft: + return os << "kFft"; + case Thunk::kGemm: + return os << "kGemm"; + case Thunk::kInfeed: + return os << "kInfeed"; + case Thunk::kKernel: + return os << "kKernel"; + case Thunk::kMemset32BitValue: + return os << "kMemset32BitValue"; + case Thunk::kMemzero: + return os << "kMemzero"; + case Thunk::kOutfeed: + return os << "kOutfeed"; + case Thunk::kSequential: + return os << "kSequential"; + case Thunk::kTuple: + return os << "kTuple"; + case Thunk::kWhile: + return os << "kWhile"; + } +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 931c0bffab850362dbd2df975657dd47d9cbd3ae..e68bee035a029178844282995429eaa960cc4817 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -40,7 +41,7 @@ class GpuExecutable; // This is thread-compatible. class Thunk { public: - enum class Kind { + enum Kind { kConditional, kConvolution, kCopy, @@ -53,6 +54,7 @@ class Thunk { kKernel, kMemset32BitValue, kMemzero, + kOutfeed, kSequential, kTuple, kWhile, @@ -80,25 +82,18 @@ class Thunk { return Status::OK(); } - // Users of Thunk should call ShouldHaltAllActivityBeforeRunning(stream) - // before calling ExecuteOnStream(stream). If it returns true, it's the - // user's responsibility to wait for all activity on the GPU to finish before - // calling ExecuteOnStream. - // - // This value is not required to be constant for a given Thunk. For example, - // a Thunk that performs autotuning may return true for its first run and - // false thereafter. - virtual bool ShouldHaltAllActivityBeforeRunning(se::Stream* /*stream*/) { - return false; - } + // Returns true if this kernel will autotune for the stream device the next + // time it is run. + virtual bool WillAutotuneKernel(se::Stream* /*stream*/) { return false; } // Execute the kernel for the thunk on the given stream. This method must be // called after Initialize and can be called multiple times over Thunk's - // lifetime. Stream argument must be non-null. + // lifetime. 'stream' and 'profiler' must be non-null. // // Precondition: Initialize(stream->parent()) has been called. virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) = 0; + se::Stream* stream, + HloExecutionProfiler* profiler) = 0; private: Kind kind_; @@ -108,6 +103,8 @@ class Thunk { // A sequence of thunks. using ThunkSequence = std::vector>; +std::ostream& operator<<(std::ostream& os, Thunk::Kind kind); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index 97cb04c38fbf18e516857f5269c984696ca204c3..8579b1545fd24f80621ac0f53b997e33586cbabe 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -15,30 +15,41 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace gpu { Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { - std::vector tuple_element_buffer_addresses; - for (BufferAllocation::Slice tuple_element_buffer : tuple_element_buffers_) { - tuple_element_buffer_addresses.push_back( - buffer_allocations.GetDeviceAddress(tuple_element_buffer).opaque()); + se::Stream* stream, + HloExecutionProfiler* profiler) { + auto size = tuple_element_buffers_.size(); + auto tuple_element_buffer_addresses = MakeUnique(size); + for (int i = 0; i != size; ++i) { + tuple_element_buffer_addresses[i] = + buffer_allocations.GetDeviceAddress(tuple_element_buffers_[i]).opaque(); } se::DeviceMemory dest_buffer_address( buffer_allocations.GetDeviceAddress(dest_buffer_)); - auto host_size = tuple_element_buffer_addresses.size() * sizeof(void*); + auto host_size = size * sizeof(void*); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); if (!stream ->ThenMemcpy(&dest_buffer_address, - tuple_element_buffer_addresses.data(), host_size) + tuple_element_buffer_addresses.get(), host_size) .ok()) { return InternalError( "Unable to launch MemcpyH2D from %p to %p with size %lu", - tuple_element_buffer_addresses.data(), dest_buffer_address.opaque(), - sizeof(void*) * tuple_element_buffer_addresses.size()); + tuple_element_buffer_addresses.get(), dest_buffer_address.opaque(), + host_size); + } + // Free the tuple address buffer when memcpy is done. + auto* buffers_raw = tuple_element_buffer_addresses.release(); + if (!stream->ThenDoHostCallback([buffers_raw] { delete[] buffers_raw; }) + .ok()) { + delete[] buffers_raw; + return InternalError("Unable to enqueue host callback!"); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h index 951f809b51937c97a6e7de0345ec58a8b66a4242..2d5735d6c40ccd26f0e527f1a02403910db4c812 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -46,7 +47,8 @@ class TupleThunk : public Thunk { TupleThunk& operator=(const TupleThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const std::vector tuple_element_buffers_; diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index 30b9640c4c75dae61e9a90da5fb10e9d4a90cd26..d81d87e7dc54cd752000b85f3ec173d66d7195e4 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -29,10 +30,14 @@ WhileThunk::WhileThunk( const HloInstruction* hlo) : Thunk(Kind::kWhile, hlo), condition_result_buffer_index_(condition_result_buffer_index), + // Pass nullptr as the HloInstruction* to the condition_thunk_sequence_ + // and body_thunk_sequence_ constructors because these SequentialThunks + // are logically "part of" this WhileThunk, and shouldn't be profiled + // separately from it. condition_thunk_sequence_(MakeUnique( - std::move(*condition_thunk_sequence), hlo)), - body_thunk_sequence_( - MakeUnique(std::move(*body_thunk_sequence), hlo)) {} + std::move(*condition_thunk_sequence), nullptr)), + body_thunk_sequence_(MakeUnique( + std::move(*body_thunk_sequence), nullptr)) {} Status WhileThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { @@ -43,18 +48,24 @@ Status WhileThunk::Initialize(const GpuExecutable& executable, } Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { + se::Stream* stream, + HloExecutionProfiler* profiler) { se::DeviceMemoryBase condition_result_data = buffer_allocations.GetDeviceAddress(condition_result_buffer_index_); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); while (true) { // Invoke thunk sequence for while 'condition' computation. - TF_RETURN_IF_ERROR( - condition_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); + profiler->StartHloComputation(); + VLOG(3) << "Executing condition computation"; + TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream( + buffer_allocations, stream, profiler)); + profiler->FinishHloComputation(hlo_instruction()->while_condition()); // Copy the result of condition computation and break the loop if 'false'. bool condition_result; stream->ThenMemcpy(&condition_result, condition_result_data, sizeof(bool)); + VLOG(3) << "condition_result = " << condition_result; Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError( @@ -66,9 +77,15 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, break; } - // Invoke thunk sequence for while 'body' computation. - TF_RETURN_IF_ERROR( - body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); + // We measure the time of one execution of the while body computation. The + // while body may be executed more than once, the last measurement "wins". + profiler->StartHloComputation(); + VLOG(3) << "Executing body computation"; + // Invoke thunk sequence for while 'body' computation, and pass on + // 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'. + TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations, + stream, profiler)); + profiler->FinishHloComputation(hlo_instruction()->while_body()); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index 22176685a92df9c95b10f755b209309843c0fa3a..9270f95ee67cf0bd3ab8082452a9d8703cb4304e 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -48,7 +49,8 @@ class WhileThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const BufferAllocation::Slice condition_result_buffer_index_; diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc deleted file mode 100644 index 7749201cbceece216a2db2569936949eb7de5125..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ /dev/null @@ -1,521 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" - -#include -#include - -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace xla { -namespace gpu { - -namespace { - -// TODO(b/33483676) Use an expression tree to specify computations to pattern -// match for while transformations. - -// ExprTree is a simple recursive data structure used to express computation -// patterns to match. -// -// Each ExprTree node is comprised of an HloOpcode, and a set of operands (each -// of type ExprTree). Operands can be added by specifying the index and -// HloOpcode of the operand. -// -// For example, the following computation: -// -// Parameter -// | -// Const GetTupleElement -// \ / -// Add (root) -// -// Can be matched with the following expression tree: -// -// ExprTree add(HloOpcode::kAdd, -// ExprTree(HloOpcode::kConstant), -// ExprTree(HloOpcode::kGetTupleElement, -// tuple_index, ExprTree(HloOpcode::kParameter))); -// -// Match the ExprTree root against an Hlo graph: -// -// ExprTree::TaggedInstructionMap tagged_instructions; -// TF_RETURN_IF_ERROR(add.Match(computation_->root_instruction(), -// &tagged_instructions)); -// -// Instructions that are "tagged" with a context-specific string will -// be returned in 'tagged_instructions' for further processing (i.e. parsing -// constants or recording the tuple_index). -// -class ExprTree { - public: - explicit ExprTree(HloOpcode opcode) : opcode_(opcode) {} - ExprTree(HloOpcode opcode, const string& tag) : opcode_(opcode), tag_(tag) {} - ExprTree(HloOpcode opcode, const ExprTree& operand0) : opcode_(opcode) { - SetOperand(0, operand0); - } - ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0) - : opcode_(opcode) { - SetOperand(index0, operand0); - } - ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0, - int64 index1, const ExprTree& operand1) - : opcode_(opcode) { - SetOperand(index0, operand0); - SetOperand(index1, operand1); - } - ExprTree(HloOpcode opcode, const string& tag, const ExprTree& operand0) - : opcode_(opcode), tag_(tag) { - SetOperand(0, operand0); - } - ExprTree(HloOpcode opcode, const ExprTree& operand0, const ExprTree& operand1) - : opcode_(opcode) { - SetOperand(0, operand0); - SetOperand(1, operand1); - } - - ExprTree(const ExprTree& to_copy) { - opcode_ = to_copy.opcode_; - tag_ = to_copy.tag_; - if (to_copy.fused_root_tree_ != nullptr) { - fused_root_tree_.reset(new ExprTree(*to_copy.fused_root_tree_)); - } - for (auto& pair : to_copy.operands_) { - CHECK(operands_.find(pair.first) == operands_.end()); - operands_.insert(std::make_pair( - pair.first, std::unique_ptr(new ExprTree(*pair.second)))); - } - } - - void SetFusedRoot(const ExprTree& fused_root) { - fused_root_tree_.reset(new ExprTree(fused_root)); - } - - typedef std::unordered_map - TaggedInstructionMap; - - // Matches 'instruction' HloOpcode against 'opcode_'. - // Recursively matches each operand in 'operands_'. - // Recursively matches fused instructions starting at 'fused_root_tree_' - // if 'opcode_ == kFusion'. - // Returns OK status, and instructions in 'tagged_instructions' for each - // matched ExprTree node with a non-empty 'tag_'. - // Returns error message on failure. - Status Match(const HloInstruction* instruction, - TaggedInstructionMap* tagged_instructions) const { - if (opcode_ != instruction->opcode()) { - return InvalidArgument("got opcode %s, want %s", - HloOpcodeString(instruction->opcode()).c_str(), - HloOpcodeString(opcode_).c_str()); - } - - VLOG(2) << "Matched " << HloOpcodeString(opcode_) << ": " << tag_; - if (!tag_.empty()) { - tagged_instructions->insert({tag_, instruction}); - } - - if (instruction->opcode() == HloOpcode::kFusion) { - CHECK(fused_root_tree_ != nullptr); - // Match fused instructions for this node starting a 'fused_root_tree'. - TF_RETURN_IF_ERROR(fused_root_tree_->Match( - instruction->fused_expression_root(), tagged_instructions)); - } - - // Match each operand in 'operands_'. - for (auto& pair : operands_) { - TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first), - tagged_instructions)); - } - return Status::OK(); - } - - private: - void SetOperand(int64 index, const ExprTree& operand) { - CHECK_EQ(0, operands_.count(index)); - operands_.insert(std::make_pair(index, MakeUnique(operand))); - } - - HloOpcode opcode_; - std::unordered_map> operands_; - std::unique_ptr fused_root_tree_; - string tag_; -}; - -// MatcherBase is a base class that provides common functionality for -// sub-classes which match specific target sub-computations (i.e. loop -// induction variable initialization, comparison and update). -class MatcherBase { - public: - MatcherBase() {} - virtual ~MatcherBase() {} - - // Attempts to match each ExprTree in 'expr_trees_'. - // Returns OK on the first successful match, error status otherwise. - virtual Status Run() { - Status status; - for (const ExprTree& expr_tree : expr_trees_) { - status = MatchExprTree(expr_tree); - if (status.ok()) { - return status; - } - } - return status; - } - - virtual Status MatchExprTree(const ExprTree& expr_tree) = 0; - - // Returns the constant value parsed form kConstant 'instruction'. - // Returns error status otherwise. - Status ParseConstInteger(const HloInstruction* instruction, - int64* const_value) const { - CHECK_EQ(HloOpcode::kConstant, instruction->opcode()); - PrimitiveType element_type = instruction->shape().element_type(); - if (element_type != S32 && element_type != S64) { - return InvalidArgument("Expected constant of integral type."); - } - const Literal& literal = instruction->literal(); - PrimitiveType type = literal.shape().element_type(); - if (type != S32 && type != S64) { - return InvalidArgument("Must use S32 or S64 integral types."); - } - if (type == S32) { - *const_value = static_cast(literal.GetFirstElement()); - } else if (type == S64) { - *const_value = literal.GetFirstElement(); - } - return Status::OK(); - } - - StatusOr GetTaggedInstruction( - const string& tag, - const ExprTree::TaggedInstructionMap& tagged_instructions) { - auto it = tagged_instructions.find(tag); - if (it == tagged_instructions.end()) { - return InvalidArgument("Cound not find instruction for tag: %s", - tag.c_str()); - } - return it->second; - } - - protected: - std::vector expr_trees_; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(MatcherBase); -}; - -// WhileConditionComputationMatcher attempts to match a target computation -// pattern in the while condition sub-computation. -// If the target pattern is matched, two pieces of information are extracted -// from 'tagged' instructions returned by the matcher: -// -// *) 'tuple_index': -// *) The loop induction variable tuple_index from the GetTupleElement -// instruction of the matched computation. -// *) Used in subsequent matching passes of while init operand and body -// computations to select loop induction variable tuple element. -// -// *) 'loop_limit': -// *) The integral value from Constant root operand in matched computation. -// *) Used as the constant for the loop limit. -// -class WhileConditionComputationMatcher : public MatcherBase { - public: - explicit WhileConditionComputationMatcher(const HloComputation* computation) - : computation_(computation) { - expr_trees_.emplace_back(BuildCondExprTree()); - } - - int64 loop_limit() const { return loop_limit_; } - int64 tuple_index() const { return tuple_index_; } - - private: - // Builds expression tree for the following condition computation: - // - // Const Parameter - // \ / - // Fusion ------------> FusionParam FusionParam - // \ / - // GTE / - // \ / - // LessThan (fused root) - // - ExprTree BuildCondExprTree() { - // Build ExprTree for fused instructions. - ExprTree fused_root( - HloOpcode::kLt, - ExprTree(HloOpcode::kGetTupleElement, "gte", - ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")), - ExprTree(HloOpcode::kParameter)); - - // Build top-level computation. - ExprTree root(HloOpcode::kFusion, - ExprTree(HloOpcode::kConstant, "loop_limit"), - ExprTree(HloOpcode::kParameter, "param0")); - - root.SetFusedRoot(fused_root); - return root; - } - - Status MatchExprTree(const ExprTree& expr_tree) override { - VLOG(2) << "MATCHING while condition"; - ExprTree::TaggedInstructionMap tagged_instructions; - TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), - &tagged_instructions)); - - // Get tagged GTE instruction and set 'tuple_index_'. - TF_ASSIGN_OR_RETURN(const HloInstruction* gte, - GetTaggedInstruction("gte", tagged_instructions)); - tuple_index_ = gte->tuple_index(); - - // Get tagged Constant instruction and parse 'loop_limit_'. - TF_ASSIGN_OR_RETURN( - const HloInstruction* const_hlo, - GetTaggedInstruction("loop_limit", tagged_instructions)); - TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_limit_)); - - // Get tagged "param0" instruction, and check that it matches - // 'computation_' parameter 0. - TF_ASSIGN_OR_RETURN(const HloInstruction* param0, - GetTaggedInstruction("param0", tagged_instructions)); - if (param0 != computation_->parameter_instruction(0)) { - return InvalidArgument("Unexpected Parameter0 instruction : %s", - param0->name().c_str()); - } - - // Get tagged 'gte.fusion_param.param0', find its associated fusion operand, - // and compare it to 'computation_' parameter0. - TF_ASSIGN_OR_RETURN( - const HloInstruction* gte_fusion_param0, - GetTaggedInstruction("gte.fusion_param.param0", tagged_instructions)); - CHECK_EQ(HloOpcode::kParameter, gte_fusion_param0->opcode()); - CHECK(gte_fusion_param0->IsFused()); - if (gte_fusion_param0->parent()->FusionInstruction()->operand( - gte_fusion_param0->parameter_number()) != - computation_->parameter_instruction(0)) { - return InvalidArgument("Could not match fusion param: %s", - gte_fusion_param0->name().c_str()); - } - - return Status::OK(); - } - - const HloComputation* computation_; - - int64 loop_limit_ = -1; - int64 tuple_index_ = -1; - - TF_DISALLOW_COPY_AND_ASSIGN(WhileConditionComputationMatcher); -}; - -// WhileInitOperandMatcher matches a target computation pattern of the -// while instructions 'init' operand, indexing the tuple at 'tuple_index'. -// On success, parses constant 'loop_start' which represents the loop induction -// variable start values, then returns OK. -// Returns error status otherwise. -class WhileInitOperandMatcher : public MatcherBase { - public: - WhileInitOperandMatcher(const HloInstruction* while_hlo, - const int64 tuple_index) - : while_hlo_(while_hlo), tuple_index_(tuple_index) { - expr_trees_.emplace_back(BuildInitExprTree()); - } - - int64 loop_start() const { return loop_start_; } - - private: - // Builds expression tree for the following while init operand subcomputation: - // - // Const - // | - // Copy - // | - // Tuple0 - // | - // While - // - ExprTree BuildInitExprTree() { - return ExprTree( - HloOpcode::kWhile, "while", - ExprTree(HloOpcode::kTuple, tuple_index_, - ExprTree(HloOpcode::kCopy, - ExprTree(HloOpcode::kConstant, "loop_start")))); - } - - Status MatchExprTree(const ExprTree& expr_tree) override { - VLOG(2) << "MATCHING while init"; - ExprTree::TaggedInstructionMap tagged_instructions; - TF_RETURN_IF_ERROR(expr_tree.Match(while_hlo_, &tagged_instructions)); - - // Get tagged while instruction check against 'while_hlo_'. - TF_ASSIGN_OR_RETURN(const HloInstruction* while_hlo, - GetTaggedInstruction("while", tagged_instructions)); - if (while_hlo != while_hlo_) { - return InvalidArgument("Expected While for instruction : %s", - while_hlo->name().c_str()); - } - - // Get tagged Constant instruction and parse 'loop_start_'. - TF_ASSIGN_OR_RETURN( - const HloInstruction* const_hlo, - GetTaggedInstruction("loop_start", tagged_instructions)); - TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_)); - - return Status::OK(); - } - - const HloInstruction* while_hlo_; - const int64 tuple_index_; - - int64 loop_start_ = -1; - - TF_DISALLOW_COPY_AND_ASSIGN(WhileInitOperandMatcher); -}; - -// WhileBodyComputationMatcher matches a target computation pattern for -// the loop induction variable update. Matching proceeds from the while body -// computation root[tuple_index] to param[tuple_index], where 'tuple_index' -// If the target pattern is matched, parses a constant which represents the -// loop induction variable increment value, then returns status OK. -// Returns error status otherwise. -class WhileBodyComputationMatcher : public MatcherBase { - public: - WhileBodyComputationMatcher(const HloComputation* computation, - const int64 tuple_index) - : computation_(computation), tuple_index_(tuple_index) { - expr_trees_.emplace_back(BuildBodyExprTree(0, 1)); - expr_trees_.emplace_back(BuildBodyExprTree(1, 0)); - } - - int64 loop_increment() const { return loop_increment_; } - - private: - // Builds expression tree for the following while body computation: - // - // - // FusionParam FusionParam - // \ / - // Const Param \ GTE1 - // \ / \ / - // Fusion -----------> Add - // | - // Copy - // | - // Tuple0 - // - ExprTree BuildBodyExprTree(const int64 const_index, const int64 gte_index) { - // Build ExprTree for fused instructions. - ExprTree gte1 = - ExprTree(HloOpcode::kGetTupleElement, "gte", - ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")); - ExprTree fused_root(HloOpcode::kAdd, const_index, - ExprTree(HloOpcode::kParameter), gte_index, gte1); - - // Build fusion instruction (and set fused root). - ExprTree fusion(HloOpcode::kFusion, 0, - ExprTree(HloOpcode::kConstant, "loop_increment"), 1, - ExprTree(HloOpcode::kParameter, "param0")); - fusion.SetFusedRoot(fused_root); - - // Build top-level computation. - ExprTree tuple0(HloOpcode::kTuple, tuple_index_, - ExprTree(HloOpcode::kCopy, fusion)); - return tuple0; - } - - Status MatchExprTree(const ExprTree& expr_tree) override { - VLOG(2) << "MATCHING while body"; - ExprTree::TaggedInstructionMap tagged_instructions; - TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), - &tagged_instructions)); - - for (const auto& pair : tagged_instructions) { - const auto& tag = pair.first; - const auto& inst = pair.second; - - if (tag == "gte" && inst->tuple_index() != tuple_index_) { - // Check that the matched GTE instruction is at the 'tuple_index' we - // matched in the while condition computation. - return InvalidArgument("Unexpected tuple index instruction : %s", - inst->name().c_str()); - } else if (tag == "loop_increment") { - // ParseHloString the constant which represents the loop induction - // variable increment value. - TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_)); - } else if (tag == "param0" && - inst != computation_->parameter_instruction(0)) { - // Check that the matched parameter == parameter 0 from 'computation_'. - return InvalidArgument("Unexpected Parameter0 instruction : %s", - inst->name().c_str()); - } else if (tag == "gte.fusion_param.param0") { - // Fusion parameter: lookup and compare with associated fusion operand. - CHECK_EQ(HloOpcode::kParameter, inst->opcode()); - CHECK(inst->IsFused()); - if (inst->parent()->FusionInstruction()->operand( - inst->parameter_number()) != - computation_->parameter_instruction(0)) { - return InvalidArgument("Could not match fusion param: %s", - inst->name().c_str()); - } - } - } - return Status::OK(); - } - - const HloComputation* computation_; - const int64 tuple_index_; - - int64 loop_increment_ = -1; - - TF_DISALLOW_COPY_AND_ASSIGN(WhileBodyComputationMatcher); -}; - -} // namespace - -StatusOr> CanTransformWhileToFor( - const HloInstruction* while_hlo) { - if (while_hlo->opcode() != HloOpcode::kWhile) { - return InvalidArgument("Expected While instruction."); - } - - WhileConditionComputationMatcher cond_matcher(while_hlo->while_condition()); - TF_RETURN_IF_ERROR(cond_matcher.Run()); - - WhileInitOperandMatcher init_matcher(while_hlo, cond_matcher.tuple_index()); - TF_RETURN_IF_ERROR(init_matcher.Run()); - - WhileBodyComputationMatcher body_matcher(while_hlo->while_body(), - cond_matcher.tuple_index()); - TF_RETURN_IF_ERROR(body_matcher.Run()); - - // Check for valid For loop parameters. - if (init_matcher.loop_start() >= cond_matcher.loop_limit()) { - return InvalidArgument("Loop start must be less than loop limit."); - } - if (body_matcher.loop_increment() <= 0) { - return InvalidArgument("Loop increment must greater than zero."); - } - return std::make_tuple(init_matcher.loop_start(), cond_matcher.loop_limit(), - body_matcher.loop_increment()); -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.h b/tensorflow/compiler/xla/service/gpu/while_transformer.h deleted file mode 100644 index fe3a954e1828ee4a323872eea81f64c7e780ad24..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ - -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/statusor.h" - -namespace xla { -namespace gpu { - -// Runs an analysis of the while loop instruction 'while_hlo' (and its -// associated sub-computations) to determine if it can be transformed into an -// equivalent "for" loop with the following "for" loop parameters: -// -// *) 'loop_start': loop induction variable starting value. -// *) 'loop_limit': loop induction variable limit value. -// *) 'loop_increment': loop induction variable per-iteration increment value. -// -// Returns an std::tuple = (loop_start, loop_limit, loop_increment) on success. -// The values in the returned tuple are values extracted from the 'while_hlo' -// operand (and its sub-computations) during analysis. -// Returns an error status on failure. -StatusOr> CanTransformWhileToFor( - const HloInstruction* while_hlo); - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index 2f290f61bd527e9827472a78256f015e066e44be..c5f3906356d821e059d2b1213c9083c4408a4d1c 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -13,11 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" - #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -42,7 +41,7 @@ class WhileTransformerTest : public HloTestBase { const int64 tuple_index, const int64 limit) { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(limit))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(limit))); auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter( 0, GetLoopStateShape(tuple_index), "loop_state")); auto induction_variable = @@ -65,8 +64,8 @@ class WhileTransformerTest : public HloTestBase { auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, ind_var_tuple_index)); - auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(increment))); + auto inc = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(increment))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); // Update data GTE(data_tuple_index). @@ -89,10 +88,12 @@ class WhileTransformerTest : public HloTestBase { const int64 ind_var_tuple_index, const int64 ind_var_init) { auto builder = HloComputation::Builder(TestName() + ".While"); - auto induction_var_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(ind_var_init))); - auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto induction_var_init = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(ind_var_init))); + auto data_init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); auto loop_state_init = ind_var_tuple_index == 0 ? builder.AddInstruction( @@ -108,12 +109,12 @@ class WhileTransformerTest : public HloTestBase { void RunFusionPasses() { // Run standard fusion passes. - EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .ValueOrDie()); - EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .ValueOrDie()); + TF_ASSERT_OK(gpu::GpuInstructionFusion(/*may_duplicate=*/false) + .Run(module_.get()) + .status()); + TF_ASSERT_OK(gpu::GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module_.get()) + .status()); } void RunCopyInsertionPass() { @@ -139,10 +140,7 @@ class WhileTransformerTest : public HloTestBase { Shape condition_result_shape_; }; -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) { +TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) { // Build computation with induction variable at tuple element 0. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); @@ -151,18 +149,13 @@ TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - TF_ASSERT_OK(result.status()); - // Check results. - EXPECT_THAT(result.ConsumeValueOrDie(), - Eq(std::tuple(0, 10, 1))); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_TRUE(result); + EXPECT_EQ(10, *result); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) { +TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) { // Build computation with induction variable at tuple element 1. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(1, 10)); @@ -171,19 +164,14 @@ TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - TF_ASSERT_OK(result.status()); - // Check results. - EXPECT_THAT(result.ConsumeValueOrDie(), - Eq(std::tuple(0, 10, 1))); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_TRUE(result); + EXPECT_EQ(10, *result); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) { - // Build computation with invalid loop limit. +TEST_F(WhileTransformerTest, ImpossibleLoopLimit) { + // Build computation with an impossible loop limit. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 5)); auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1)); @@ -191,17 +179,13 @@ TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - ASSERT_FALSE(result.ok()); - EXPECT_THAT(result.status().error_message(), - HasSubstr("Loop start must be less than loop limit.")); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_TRUE(result); + EXPECT_EQ(0, *result); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) { +TEST_F(WhileTransformerTest, InvalidLoopIncrement) { // Build computation with invalid loop increment. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); @@ -210,11 +194,9 @@ TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - ASSERT_FALSE(result.ok()); - EXPECT_THAT(result.status().error_message(), - HasSubstr("Loop increment must greater than zero.")); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_FALSE(result); } } // namespace diff --git a/tensorflow/compiler/xla/service/gpu/xfeed_queue.h b/tensorflow/compiler/xla/service/gpu/xfeed_queue.h new file mode 100644 index 0000000000000000000000000000000000000000..dd46ff433ba0ad6bfa3999b96845fdaebe148aca --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/xfeed_queue.h @@ -0,0 +1,90 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_ + +#include +#include +#include + +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace xla { +namespace gpu { + +// TODO(b/30467474) Once GPU outfeed implementation settles, consider +// folding back the cpu and gpu outfeed implementations into a generic +// one if possible. + +// Manages a thread-safe queue of buffers. +template +class XfeedQueue { + public: + // Adds a tree of buffers to the queue. The individual buffers correspond to + // the elements of a tuple and may be nullptr if the buffer is a tuple index + // buffer. + void EnqueueDestination(BufferType buffers) { + tensorflow::mutex_lock l(mu_); + enqueued_buffers_.push_back(std::move(buffers)); + cv_.notify_one(); + } + + // Blocks until the queue is non-empty, then returns the buffer at the head of + // the queue. + BufferType BlockingGetNextDestination() { + bool became_empty; + BufferType current_buffer; + { + tensorflow::mutex_lock l(mu_); + while (enqueued_buffers_.empty()) { + cv_.wait(l); + } + current_buffer = std::move(enqueued_buffers_.front()); + enqueued_buffers_.pop_front(); + became_empty = enqueued_buffers_.empty(); + } + if (became_empty) { + for (const auto& callback : on_empty_callbacks_) { + callback(); + } + } + return current_buffer; + } + + void RegisterOnEmptyCallback(std::function callback) { + on_empty_callbacks_.push_back(std::move(callback)); + } + + private: + tensorflow::mutex mu_; + + // Condition variable that is signaled every time a buffer is enqueued. + tensorflow::condition_variable cv_; + + // The queue of trees of buffers. Buffer* queue contents are not owned. + std::deque enqueued_buffers_ GUARDED_BY(mu_); + + // List of callbacks which will be called when 'enqueued_buffers_' becomes + // empty. + std::vector> on_empty_callbacks_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_ diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index acf661148699dab18916e3065ee647d37fda6208..aa89567ee86e59e197045c0b51eed3b9aa59fef7 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -47,7 +48,7 @@ HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) { auto x_value = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "x_value")); auto half = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.5))); builder.AddInstruction(HloInstruction::CreateBinary( half->shape(), HloOpcode::kAdd, x_value, half)); return module->AddEmbeddedComputation(builder.Build()); @@ -122,7 +123,7 @@ std::unique_ptr MakeBigGraph() { auto rng = builder.AddInstruction( HloInstruction::CreateRng(vshape, RNG_UNIFORM, {param_m, param_m})); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add_computation = ScalarSumComputation(module.get()); builder.AddInstruction( HloInstruction::CreateReduce(vshape, rng, one, {1}, add_computation)); diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 5dba50a63b1d77bda0835e0333cc7dd9ddbe2dcf..4005fc0d114a3ec7a38dfb5edecdaeb1e8497ade 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -26,7 +26,8 @@ namespace xla { using tensorflow::gtl::FlatMap; using tensorflow::gtl::FlatSet; -StatusOr MinimumMemoryForModule( +/*static*/ +StatusOr HeapSimulator::MinimumMemoryForModule( const SequentialHloOrdering::HloModuleSequence& module_sequence, const LogicalBuffer::SizeFunction& size_function) { if (module_sequence.empty()) { @@ -49,15 +50,19 @@ StatusOr MinimumMemoryForModule( return result.heap_size; } -StatusOr MinimumMemoryForComputation( +/*static*/ +StatusOr HeapSimulator::MinimumMemoryForComputation( const HloComputation& computation, const std::vector& sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap* + memory_by_computation) { TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, HeapSimulator::Run(MakeUnique(), computation, - sequence, points_to_analysis, size_function)); + sequence, points_to_analysis, size_function, + HeapSimulator::Options(), memory_by_computation)); return result.heap_size; } @@ -81,9 +86,11 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, - const BufferValue::SizeFunction& size_fn, const Options& options) { + const BufferValue::SizeFunction& size_fn, const Options& options, + const tensorflow::gtl::FlatMap* + memory_by_computation) { HeapSimulator heap(std::move(algorithm), size_fn, options, - /*module_sequence=*/nullptr); + /*module_sequence=*/nullptr, memory_by_computation); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -223,6 +230,9 @@ Status HeapSimulator::RunComputation( // // INVARIANT: Either Alloc or ShareBuffer will be called for each buffer // that we should assign. + + // Make sure each buffer get reused at most once. + FlatSet reused_buffers; for (const BufferValue* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; @@ -235,6 +245,9 @@ Status HeapSimulator::RunComputation( bool shared = false; if (options_.may_reuse_operand_buffers) { for (const BufferValue* operand_buffer : operand_buffers_to_free) { + if (reused_buffers.count(operand_buffer) != 0) { + continue; + } if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && buffer->instruction()->opcode() != HloOpcode::kCopy && points_to_analysis.CanShareOperandBufferWithUser( @@ -244,6 +257,7 @@ Status HeapSimulator::RunComputation( << operand_buffer->ToString(); ShareBuffer(buffer, operand_buffer, instruction); shared = true; + reused_buffers.insert(operand_buffer); break; } } @@ -254,6 +268,12 @@ Status HeapSimulator::RunComputation( Alloc(buffer, instruction); } } + // Account for the memory used by subcomputations when estimating the + // current heap size. + if (memory_by_computation_ != nullptr) { + algorithm_->AccountForSubcomputationMemory(instruction, + *memory_by_computation_); + } // If the whole module is sequential, we can save memory by running the // heap-simulation for sub-computations inline. E.g. the buffers for the @@ -321,12 +341,15 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence) + const SequentialHloOrdering::HloModuleSequence* module_sequence, + const tensorflow::gtl::FlatMap* + memory_by_computation) : no_fragmentation_stats_(MakeUnique()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), - module_sequence_(module_sequence) { + module_sequence_(module_sequence), + memory_by_computation_(memory_by_computation) { debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr); } @@ -495,6 +518,26 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) { } } +void NoFragmentationStatsHeap::AccountForSubcomputationMemory( + const HloInstruction* instruction, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + // We only count the memory usage of the largest subcomputation, instead of + // adding them all, because subcomputations won't execute in parallel. + int64 max_subcomputation_bytes = 0; + for (const auto* c : instruction->called_computations()) { + auto it = memory_by_computation.find(c); + if (it != memory_by_computation.end()) { + int64 subcomputation_bytes = it->second; + if (subcomputation_bytes > max_subcomputation_bytes) { + max_subcomputation_bytes = subcomputation_bytes; + } + } + } + max_heap_size_ = + std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes); +} + void NoFragmentationStatsHeap::Free(const BufferValue* buffer, int64 size) { current_heap_size_ -= size; } diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 3be3bb8e7fae9d22c9be00f3f81d2b93638c22ef..811a6042df9434ac3f4bed71b9c093433e25c1bb 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -34,21 +34,6 @@ limitations under the License. namespace xla { -// Returns the minimum memory required to compute an HLO module where all -// computations have been scheduled (represented by the given module_sequence), -// assuming no fragmentation. -StatusOr MinimumMemoryForModule( - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const LogicalBuffer::SizeFunction& size_function); - -// Returns the minimum memory required to compute the given computation, -// assuming no fragmentation. -StatusOr MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function); - // Forward declare classes defined below. class HeapAlgorithm; @@ -100,6 +85,23 @@ class HeapSimulator { const BufferValueFlatSet* buffers_to_assign; }; + // Returns the minimum memory required to compute an HLO module where all + // computations have been scheduled (represented by the given + // module_sequence), assuming no fragmentation. + static StatusOr MinimumMemoryForModule( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function); + + // Returns the minimum memory required to compute the given computation, + // assuming no fragmentation. + static StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap* + memory_by_computation = nullptr); + // Run the heap simulation with the given algorithm, assuming the given // module_sequence, which must contain a topologically-consistent total // ordering of all instructions within each computation. The result is invalid @@ -126,7 +128,9 @@ class HeapSimulator { const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, - const Options& options = Options()); + const Options& options = Options(), + const tensorflow::gtl::FlatMap* + memory_by_computation = nullptr); private: // If 'module_sequence' is non-null, it is used to find kCall and kWhile @@ -135,7 +139,9 @@ class HeapSimulator { HeapSimulator( std::unique_ptr algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence); + const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr, + const tensorflow::gtl::FlatMap* + memory_by_computation = nullptr); ~HeapSimulator(); Status RunComputation( @@ -159,7 +165,13 @@ class HeapSimulator { const std::unique_ptr algorithm_; const BufferValue::SizeFunction size_fn_; const Options options_; + // module_sequence_ is set by buffer assignment, and memory_by_computation_ is + // set by hlo scheduling. Then, in RunComputation, we check both in order to + // handle subcomputations. It would be good to unify the handling of + // subcomputations, but it's not clear how. const SequentialHloOrdering::HloModuleSequence* module_sequence_; + const tensorflow::gtl::FlatMap* + memory_by_computation_; // In addition to Alloc and Free, the heap simulator exposes a concept of // buffer sharing. When ShareBuffer is called, instead of allocating new @@ -204,6 +216,11 @@ class HeapAlgorithm { // Alloc allocates a buffer of 'size' bytes. virtual void Alloc(const BufferValue* buffer, int64 size) = 0; + virtual void AccountForSubcomputationMemory( + const HloInstruction* instruction, + const tensorflow::gtl::FlatMap& + memory_by_computation) {} + // Free de-allocates a previously allocated buffer. virtual void Free(const BufferValue* buffer, int64 size) = 0; @@ -222,7 +239,14 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { ~NoFragmentationStatsHeap() override = default; void Alloc(const BufferValue* buffer, int64 size) override; + + void AccountForSubcomputationMemory( + const HloInstruction* instruction, + const tensorflow::gtl::FlatMap& + memory_by_computation) override; + void Free(const BufferValue* buffer, int64 size) override; + Result Finish() override; private: diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 309ab85f784274835904015472f3f0d601885763..b41dc66fe9f5e869a114be96b7cc01fc1a3d59da 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -89,7 +89,8 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { cond_lt}; module_sequence[body_computation] = {body_param}; module_sequence[entry_computation] = {iter, data, tuple, while_op}; - EXPECT_EQ(56, MinimumMemoryForModule(module_sequence, size_fn).ValueOrDie()); + EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn) + .ValueOrDie()); } const char kAlloc[] = "Alloc"; @@ -197,6 +198,11 @@ class HeapSimulatorTracker { .ConsumeValueOrDie(); } + int64 OffsetAt(const HloInstruction* instruction, const ShapeIndex& index) { + const BufferValue* buffer = BufferAt(instruction, index); + return result_.chunk_map.at(buffer).offset; + } + // Ensures the expected sequence of Alloc/Free/Finish calls was performed. void ExpectCallSequence(const CallSequence& expected) const { EXPECT_EQ(expected, actual_calls_); @@ -208,10 +214,9 @@ class HeapSimulatorTracker { const ShapeIndex& index_a, const HloInstruction* instruction_b, const ShapeIndex& index_b) { - const BufferValue* a = BufferAt(instruction_a, index_a); - const BufferValue* b = BufferAt(instruction_b, index_b); - EXPECT_EQ(result_.chunk_map[a].offset, result_.chunk_map[b].offset) - << *a << ", " << *b; + int64 offset_a = OffsetAt(instruction_a, index_a); + int64 offset_b = OffsetAt(instruction_b, index_b); + EXPECT_EQ(offset_a, offset_b); } private: @@ -234,7 +239,7 @@ class HeapSimulatorTest : public HloTestBase { TEST_F(HeapSimulatorTest, ScalarConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); // Constants aren't assigned. See b/32248867 HeapSimulatorTracker tracker(TestName(), builder.Build(), {const0}); @@ -310,6 +315,43 @@ TEST_F(HeapSimulatorTest, MultiplyAdd) { tracker.ExpectSharedBuffers(add, {}, mul, {}); } +TEST_F(HeapSimulatorTest, BufferReusedOnce) { + HeapSimulatorTracker tracker(TestName()); + auto builder = HloComputation::Builder(TestName()); + + HloComputation::Builder fusion_builder("fusion"); + { + HloComputation::Builder& builder = fusion_builder; + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, f32vec4_, "A")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec4_, HloOpcode::kExp, a_param)); + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param)); + + builder.AddInstruction(HloInstruction::CreateTuple({exp, neg})); + } + auto fusion_computation = + tracker.module()->AddEmbeddedComputation(fusion_builder.Build()); + auto a_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec4_, "paramA")); + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param)); + auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( + ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}), + HloInstruction::FusionKind::kLoop, {neg}, fusion_computation)); + tracker.module()->AddEntryComputation(builder.Build()); + + tracker.RunWholeModule({a_param, neg, fusion}); + + auto neg_buffer = tracker.OffsetAt(neg, {}); + int64 output_buffer_0 = tracker.OffsetAt(fusion, {0}); + int64 output_buffer_1 = tracker.OffsetAt(fusion, {1}); + // Only one buffer should be shared. + EXPECT_TRUE((neg_buffer == output_buffer_0) ^ + (neg_buffer == output_buffer_1)); +} + TEST_F(HeapSimulatorTest, MultiplyDot) { auto builder = HloComputation::Builder(TestName()); auto paramA = builder.AddInstruction( @@ -632,7 +674,7 @@ class HeapAlgorithmTestBase : public ::testing::Test { const BufferValue* DummyBufferValue() { const BufferValue::Id id = buffers_.size(); auto const0 = builder_.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); buffers_.emplace_back(MakeUnique(id, const0, ShapeIndex{})); return buffers_.back().get(); } diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index e201359d3d25b7d2dda852762c6de1fcb75685d7..be9098f555e78f3cabfe55481356f8b6841a3a2b 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -145,12 +145,26 @@ message HloInstructionProto { repeated int64 operand_ids = 36; repeated int64 control_predecessor_ids = 37; repeated int64 called_computation_ids = 38; - repeated int64 replica_group_ids = 44; xla.OpSharding sharding = 40; // Backend configuration for the instruction. Has backend-specific meaning. string backend_config = 43; + + // Cross replica op fields. + // TODO(b/112107579): remove replica_group_ids field and always use + // replica_groups. + repeated int64 replica_group_ids = 44; + repeated ReplicaGroup replica_groups = 49; + int64 all_reduce_id = 45; + string cross_replica_sum_barrier = 46; + + // Whether this Send/Recv instruction transfers data to/from the host. Only + // present for Send and Recv instructions and their SendDone and RecvDone + // partners. + bool is_host_transfer = 47; + + xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; } // Serialization of HloComputation. @@ -235,8 +249,9 @@ message BufferAllocationProto { int64 index = 1; int64 size = 2; bool is_thread_local = 3; - bool is_reusable = 4; + bool is_tuple = 11; bool is_entry_computation_parameter = 5; + bool is_constant = 12; int64 parameter_number = 6; repeated int64 parameter_shape_index = 10; bool maybe_live_out = 7; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index a88283ed9a6459b4fa9310e160b59c77d51f1027..e8a4b034b4396860bd5873f43003844ce92dea6c 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -452,15 +452,16 @@ string HloAliasAnalysis::ToString() const { /* static */ StatusOr> HloAliasAnalysis::Run( - HloModule* module) { + HloModule* module, const HloDataflowAnalysis::FusionCanShareBufferFunction& + fusion_can_share_buffer) { VLOG(2) << "HloAliasAnalysis::Run on module " << module->name(); XLA_VLOG_LINES(2, module->ToString()); auto alias_analysis = WrapUnique(new HloAliasAnalysis(module)); - TF_ASSIGN_OR_RETURN( - alias_analysis->dataflow_analysis_, - HloDataflowAnalysis::Run(*module, /*ssa_form=*/true, - /*bitcast_defines_value=*/false)); + TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true, + /*bitcast_defines_value=*/false, + fusion_can_share_buffer)); BufferValueMap buffer_map(alias_analysis->dataflow_analysis()); buffer_map.MergeAliasedBuffers(); @@ -493,6 +494,16 @@ StatusOr> HloAliasAnalysis::Run( bool HloAliasAnalysis::HasLiveRangeInterference( const HloOrdering& ordering) const { for (const HloBuffer& buffer : buffers()) { + CHECK(!buffer.values().empty()); + if (ShapeUtil::IsToken(buffer.values().front()->shape())) { + // Tokens have no on-device representation and cannot interfere. + for (const HloValue* value : buffer.values()) { + // If one of the values is a token, all values must be a token. + DCHECK(ShapeUtil::IsToken(value->shape())); + } + continue; + } + // Check that the values in the buffer are totally ordered with respect to // 'ordering'. Begin by sorting the values with respect to 'ordering' with a // tie-break using value ID. The tie-break is necessary because we need a @@ -517,7 +528,6 @@ bool HloAliasAnalysis::HasLiveRangeInterference( // a buffer and A interferes with C, then necessarily A also interferes // with B. So to check interference you only need to check interference // between A and B, and between B and C. - CHECK(!values.empty()); for (int i = 1; i < values.size(); ++i) { if (!ordering.IsDefinedBefore(*values[i - 1], *values[i])) { VLOG(1) << values[i - 1]->ToShortString() << " and " diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h index 67dfd4301b3a027a496911ecf6f06841dfd6423a..1fea544730c27efdaa260f55ea81c163165f7ed5 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h @@ -39,7 +39,10 @@ class HloAliasAnalysis { public: // The callgraph of the given HloModule must be flattened // (xla::FlattenCallGraph) prior to running the analysis. - static StatusOr> Run(HloModule* module); + static StatusOr> Run( + HloModule* module, + const HloDataflowAnalysis::FusionCanShareBufferFunction& + fusion_can_share_buffer); string ToString() const; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 8f18d50f6e033fab1c01f42017b951c224c22799..da94ab5346e5628b4a603b3ac2d84071904d1e65 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -47,7 +47,9 @@ class HloAliasAnalysisTest : public HloTestBase { // reference to the generated analysis stored in analysis_. HloAliasAnalysis& RunAnalysis() { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis"); - analysis_ = HloAliasAnalysis::Run(module_.get()).ConsumeValueOrDie(); + analysis_ = HloAliasAnalysis::Run(module_.get(), + /*fusion_can_share_buffer=*/nullptr) + .ConsumeValueOrDie(); return *analysis_; } @@ -116,9 +118,9 @@ TEST_F(HloAliasAnalysisTest, BinaryOperation) { // Test the analysis on a single binary operation (Add). auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, constant1, constant2)); module_->AddEntryComputation(builder.Build()); @@ -228,9 +230,9 @@ TEST_F(HloAliasAnalysisTest, SingleCall) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); module_->AddEntryComputation(builder.Build()); @@ -267,9 +269,9 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto call1 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto call2 = builder.AddInstruction(HloInstruction::CreateCall( @@ -346,15 +348,15 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) { auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( @@ -439,15 +441,15 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while0 = builder.AddInstruction( @@ -498,7 +500,7 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); return cond_builder.Build(); }; // Build separate condition computations so the call graph is flat. The @@ -543,9 +545,9 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto entry_while = builder.AddInstruction( @@ -608,17 +610,17 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2, constant3})); auto xla_while = builder.AddInstruction( @@ -654,19 +656,18 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { } TEST_F(HloAliasAnalysisTest, TupleSelect) { - // Test a kSelect of a tuple value. Non-top-level element flow through the - // instruction. + // Test a kTupleSelect. Non-top-level element flow through the instruction. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -677,13 +678,13 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { builder.AddInstruction(HloInstruction::CreateTuple({constant4})); const Shape tuple_shape = tuple1->shape(); auto select11 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple1)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple1)); auto select12 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple2)); auto select34 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple3, tuple4)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple3, tuple4)); auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, select12, select34)); + tuple_shape, HloOpcode::kTupleSelect, pred, select12, select34)); module_->AddEntryComputation(builder.Build()); @@ -718,7 +719,7 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { } TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { - // Test a tuple-shaped kSelect feeding a kWhile instruction. HLO: + // Test a tuple-shaped kTupleSelect feeding a kWhile instruction. HLO: // // body((F32[], F32[]) %tuple_param): // %negate = Negate(%tuple_param{0}) @@ -754,22 +755,22 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({constant2})); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple2)); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, select)); @@ -806,7 +807,7 @@ TEST_F(HloAliasAnalysisTest, Bitcast) { // Bitcasting a value should not produce a new buffer. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kBitcast, constant)); @@ -825,7 +826,7 @@ TEST_F(HloAliasAnalysisTest, BitcastInterference) { // interference. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kBitcast, constant)); builder.AddInstruction(HloInstruction::CreateTuple({constant, bitcast})); @@ -844,13 +845,13 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) { // the other use of the init. auto builder = HloComputation::Builder(TestName()); auto init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto cond_builder = HloComputation::Builder("condition"); auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, init->shape(), "param")); auto cond_root = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index b158f44923982642615543cebbd54f00596e2641..441288da1a6859a3f393a298ee02eb4b435e42e0 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -120,6 +120,30 @@ HloInstruction* HloComputation::AddParameter( return instructions_.back().get(); } +namespace { + +// Returns the new name for a fusion parameter when we change its number. +// +// Fusion parameters are named foo.param_1, bar.param_2, etc. We are +// renumbering the parameters, so replace the final number in the name with +// the updated value. +string RenameFusionParameter(const string& original_name, int64 new_param_no) { + const string param_underscore = ".param_"; + size_t index = original_name.rfind(param_underscore); + if (index == string::npos) { + return original_name; + } + string after_param = original_name.substr(index + param_underscore.size()); + int64 numeric_suffix; + if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) { + return StrCat(original_name.substr(0, index + param_underscore.size()), + new_param_no); + } + return original_name; +} + +} // namespace + Status HloComputation::RemoveParameter(int64 param_no) { CHECK_GE(param_no, 0); CHECK_LT(param_no, param_instructions_.size()); @@ -132,21 +156,8 @@ Status HloComputation::RemoveParameter(int64 param_no) { while (param_no < param_instructions_.size()) { param_instruction = param_instructions_[param_no]; - string param_name = param_instruction->name(); - // Fusion parameters are named foo.param_1, bar.param_2, etc. We are - // renumbering the parameters, so replace the final number in the name with - // the updated value. - const string param_underscore = ".param_"; - size_t index = param_name.rfind(param_underscore); - if (index == string::npos) { - string after_param = name().substr(index + param_underscore.size()); - int64 numeric_suffix; - if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) { - param_name = - StrCat(param_name.substr(0, index), param_underscore, param_no); - } - } - + string param_name = + RenameFusionParameter(param_instruction->name(), param_no); HloInstruction* new_instr = AddInstructionInternal(HloInstruction::CreateParameter( param_no, param_instruction->shape(), param_name)); @@ -159,6 +170,34 @@ Status HloComputation::RemoveParameter(int64 param_no) { return Status::OK(); } +Status HloComputation::RemoveUnusedParameters() { + CHECK(IsFusionComputation()); + int64 removed = 0; + for (int64 i = 0; i < param_instructions_.size(); ++i) { + HloInstruction* param_instruction = param_instructions_[i]; + if (param_instruction->user_count() == 0 && + param_instruction != root_instruction()) { + TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); + ++removed; + continue; + } + + if (removed > 0) { + const int64 param_no = i - removed; + string param_name = + RenameFusionParameter(param_instruction->name(), param_no); + HloInstruction* new_instr = + AddInstructionInternal(HloInstruction::CreateParameter( + param_no, param_instruction->shape(), param_name)); + TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); + param_instructions_[param_no] = new_instr; + TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); + } + } + param_instructions_.resize(param_instructions_.size() - removed); + return Status::OK(); +} + bool HloComputation::IsRemovable(const HloInstruction* instruction) { // If the instruction has control predecessors or successors then we cannot // remove the instruction without violating ordering constraints (added, for @@ -234,7 +273,6 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) { TF_RET_CHECK(instruction_iterators_.count(instruction) != 0); auto inst_it = instruction_iterators_.at(instruction); (*inst_it)->set_parent(nullptr); - instruction->DetachFromOperands(); instructions_.erase(inst_it); return Status::OK(); } @@ -246,9 +284,8 @@ void HloComputation::set_root_instruction( if (!IsFusionComputation()) { CHECK(ShapeUtil::Compatible(new_root_instruction->shape(), root_instruction_->shape())) - << new_root_instruction->shape().ShortDebugString() - << " is incompatible with " - << root_instruction_->shape().ShortDebugString(); + << new_root_instruction->shape() << " is incompatible with " + << root_instruction_->shape(); } bool root_found = false; for (auto& instruction : instructions_) { @@ -264,46 +301,11 @@ void HloComputation::set_root_instruction( namespace { -// Helper class which computes the post order of an expression rooted at a -// particular instruction. -class InstructionPostOrderer : public DfsHloVisitorWithDefault { - public: - // added_instructions is the set of instructions which have already been - // accounted for in the post order in previous invocations of - // GetOrder. Without this mechanism, instructions which are predecessors of - // multiple root instructions of the computation can be added to the post - // order more than once. - static std::list GetOrder( - HloInstruction* root, - tensorflow::gtl::FlatSet* added_instructions) { - InstructionPostOrderer orderer(added_instructions); - TF_CHECK_OK(root->Accept(&orderer)); - return std::move(orderer.post_order_); - } - - private: - explicit InstructionPostOrderer( - tensorflow::gtl::FlatSet* added_instructions) - : added_instructions_(added_instructions) {} - ~InstructionPostOrderer() override {} - - Status DefaultAction(HloInstruction* hlo_instruction) override { - if (added_instructions_->count(hlo_instruction) == 0) { - post_order_.push_back(hlo_instruction); - added_instructions_->insert(hlo_instruction); - } - return Status::OK(); - } - - std::list post_order_; - tensorflow::gtl::FlatSet* added_instructions_; -}; - // Helper which builds a post order of the HLO call graph. void ComputeComputationPostOrder( HloComputation* computation, tensorflow::gtl::FlatSet* visited, - std::list* post_order) { + std::vector* post_order) { if (visited->insert(computation).second) { for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : @@ -315,48 +317,53 @@ void ComputeComputationPostOrder( } } -std::list ComputeInstructionPostOrder( - HloInstruction* root, tensorflow::gtl::FlatSet* visited) { - std::list post_order; - std::vector> dfs_stack; - dfs_stack.emplace_back(root, false); +enum State { kVisiting, kVisited }; + +void ComputeInstructionPostOrder( + std::vector* post_order, HloInstruction* root, + tensorflow::gtl::FlatMap* visited) { + std::vector dfs_stack; + dfs_stack.push_back(root); while (!dfs_stack.empty()) { const auto current = dfs_stack.back(); - if (current.second) { - dfs_stack.pop_back(); - if (!visited->insert(current.first).second) { - continue; - } - post_order.push_back(current.first); - } else { - if (visited->count(current.first)) { + auto it = visited->find(current); + if (it != visited->end()) { + if (it->second == kVisited) { + // Already visited. dfs_stack.pop_back(); continue; } - dfs_stack.back().second = true; - - // Add the operands to the stack in reverse order so the first operand is - // processed first. This will produce a more natural ordering and a nicer - // result for thigns like HLO stringification. - const auto& operands = current.first->operands(); - for (int64 i = operands.size() - 1; i >= 0; --i) { - dfs_stack.emplace_back(operands[i], false); - } + // Visit this node. + CHECK_EQ(kVisiting, it->second); + dfs_stack.pop_back(); + post_order->push_back(current); + it->second = kVisited; + continue; + } - for (HloInstruction* op : current.first->control_predecessors()) { - dfs_stack.emplace_back(op, false); - } + visited->insert({current, kVisiting}); + + // Add the operands to the stack in reverse order so the first operand is + // processed first. This will produce a more natural ordering and a nicer + // result for thigns like HLO stringification. + const auto& operands = current->operands(); + for (int64 i = operands.size() - 1; i >= 0; --i) { + dfs_stack.emplace_back(operands[i]); + } + + for (HloInstruction* op : current->control_predecessors()) { + dfs_stack.emplace_back(op); } } - return post_order; } } // namespace -std::list HloComputation::MakeInstructionPostOrder() const { - std::list post_order; - std::list trace_instructions; - tensorflow::gtl::FlatSet added_instructions; +std::vector HloComputation::MakeInstructionPostOrder() const { + std::vector post_order; + post_order.reserve(instruction_count()); + std::vector trace_instructions; + tensorflow::gtl::FlatMap visited; for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { // Trace instructions aren't handled by the DFS visitor. Add trace @@ -364,21 +371,20 @@ std::list HloComputation::MakeInstructionPostOrder() const { // users). trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - post_order.splice( - post_order.end(), - ComputeInstructionPostOrder(instruction.get(), &added_instructions)); + ComputeInstructionPostOrder(&post_order, instruction.get(), &visited); } } - post_order.splice(post_order.end(), trace_instructions); + post_order.insert(post_order.end(), trace_instructions.begin(), + trace_instructions.end()); CHECK_EQ(instructions_.size(), post_order.size()) << "number of instructions does not match post order size"; return post_order; } -std::list HloComputation::MakeEmbeddedComputationsList() +std::vector HloComputation::MakeEmbeddedComputationsList() const { tensorflow::gtl::FlatSet visited; - std::list post_order; + std::vector post_order; // To avoid special handling of this computation, cast away const of // 'this'. 'this' is immediately removed from the post order after @@ -522,23 +528,11 @@ HloInstruction* HloComputation::CreateFusionInstruction( } StatusOr HloComputation::DeepCopyHelper( - HloInstruction* instruction, const ShapeTree* indices_to_copy, - ShapeTree* copies_added, ShapeIndex* index) { - if (ShapeUtil::IsArray(instruction->shape())) { - if (indices_to_copy == nullptr || indices_to_copy->element(*index)) { - // Use kCopy to copy array elements - HloInstruction* copy = AddInstruction(HloInstruction::CreateUnary( - instruction->shape(), HloOpcode::kCopy, instruction)); - if (copies_added != nullptr) { - *copies_added->mutable_element(*index) = copy; - } - return copy; - } else { - // Array elements which are not to be copied are passed through - // transparently. - return instruction; - } - } else if (ShapeUtil::IsTuple(instruction->shape())) { + HloInstruction* instruction, ShapeIndex* index, + const std::function< + HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* computation)>& copy_leaf) { + if (ShapeUtil::IsTuple(instruction->shape())) { std::vector elements; for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); i++) { @@ -548,17 +542,22 @@ StatusOr HloComputation::DeepCopyHelper( instruction, i)); index->push_back(i); - TF_ASSIGN_OR_RETURN( - HloInstruction * element, - DeepCopyHelper(gte, indices_to_copy, copies_added, index)); + TF_ASSIGN_OR_RETURN(HloInstruction * element, + DeepCopyHelper(gte, index, copy_leaf)); elements.push_back(element); index->pop_back(); } return AddInstruction(HloInstruction::CreateTuple(elements)); - } else { - return FailedPrecondition( - "Can only copy array and tuple shaped instructions"); } + if (ShapeUtil::IsToken(instruction->shape())) { + // Tokens have no on-device representation and cannot be copied. Pass + // through transparently. + return instruction; + } + + // Array shape. + TF_RET_CHECK(ShapeUtil::IsArray(instruction->shape())); + return copy_leaf(instruction, *index, this); } StatusOr HloComputation::DeepCopyInstruction( @@ -580,7 +579,36 @@ StatusOr HloComputation::DeepCopyInstruction( } ShapeIndex index; - return DeepCopyHelper(instruction, indices_to_copy, copies_added, &index); + auto copy_leaf = [indices_to_copy, copies_added]( + HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* computation) { + if (indices_to_copy == nullptr || indices_to_copy->element(leaf_index)) { + HloInstruction* copy = computation->AddInstruction( + HloInstruction::CreateUnary(leaf->shape(), HloOpcode::kCopy, leaf)); + if (copies_added != nullptr) { + *copies_added->mutable_element(leaf_index) = copy; + } + return copy; + } + // Elements which are not to be copied are passed through + // transparently. + return leaf; + }; + return DeepCopyHelper(instruction, &index, copy_leaf); +} + +StatusOr HloComputation::DeepCopyInstructionWithCustomCopier( + HloInstruction* instruction, + const std::function< + HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* computation)>& copy_leaf) { + if (instruction->parent() != this) { + return FailedPrecondition( + "Can't deep copy instruction %s: instruction is not in computation %s", + instruction->name().c_str(), name().c_str()); + } + ShapeIndex index; + return DeepCopyHelper(instruction, &index, copy_leaf); } ProgramShape HloComputation::ComputeProgramShape() const { @@ -645,7 +673,7 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, std::unique_ptr HloComputation::ComputeReachability() const { - const std::list all = MakeInstructionPostOrder(); + const auto& all = MakeInstructionPostOrder(); auto result = MakeUnique(all); std::vector inputs; @@ -653,7 +681,7 @@ std::unique_ptr HloComputation::ComputeReachability() inputs.assign(hlo->operands().begin(), hlo->operands().end()); inputs.insert(inputs.end(), hlo->control_predecessors().begin(), hlo->control_predecessors().end()); - result->SetReachabilityToUnion(inputs, hlo); + result->FastSetReachabilityToUnion(inputs, hlo); } return result; } @@ -863,15 +891,6 @@ std::unique_ptr HloComputation::CloneWithReplacements( } } context->MapComputation(this, result.get()); - // We cloned the elements of 'replacements', so they're all going to be - // destroyed. HloInstructions need to be detached from their operands before - // they're destroyed, otherwise they stick around in the operands' users lists - // and cause use-after-frees. - for (auto& kv : replacements) { - if (std::unique_ptr& new_instr = kv.second) { - new_instr->DetachFromOperands(); - } - } return result; } @@ -879,4 +898,13 @@ void HloComputation::UniquifyName(NameUniquer* name_uniquer) { name_ = name_uniquer->GetUniqueName(name_); } +HloInstruction* HloComputation::GetInstructionWithName( + tensorflow::StringPiece name) { + auto instructions_in_computation = instructions(); + auto it = c_find_if(instructions_in_computation, [&](HloInstruction* instr) { + return instr->name() == name; + }); + return it == instructions_in_computation.end() ? nullptr : *it; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 0da4a305f3d5d694a1918fed294337100b0a27fd..49ed65910f519810740b89760ad815f287e59a91 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_ +#include #include #include #include @@ -113,6 +114,11 @@ class HloComputation { // instruction. Status RemoveParameter(int64 param_no); + // Remove unused parameters from the computation. + // Note this is only applicatable to the computation for the fusion + // instruction. + Status RemoveUnusedParameters(); + // Add new parameter instruction to the computation. // This should be a new parameter. Instruction will be appended to parameters // and inserted to the instruction list. @@ -199,7 +205,7 @@ class HloComputation { // Compute and return a post-order of the instructions in the computation. In // this order, definitions of values always appear before their uses. - std::list MakeInstructionPostOrder() const; + std::vector MakeInstructionPostOrder() const; // Computes and returns the reachability between HLO instructions in the // computation. The returned HloReachabilityMap is constructed such that @@ -221,7 +227,7 @@ class HloComputation { // transitively. The embedded computations are sorted such that if computation // A calls computation B (eg, via a map instruction) then A will appear after // B in the list. - std::list MakeEmbeddedComputationsList() const; + std::vector MakeEmbeddedComputationsList() const; // Creates a fusion instruction containing the given instructions. // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or fusion @@ -249,6 +255,14 @@ class HloComputation { const ShapeTree* indices_to_copy = nullptr, ShapeTree* copies_added = nullptr); + // As above, but uses a custom function to copy the leaf nodes, which could + // create alternative HLOs other than kCopy, or even pass-throughs. + StatusOr DeepCopyInstructionWithCustomCopier( + HloInstruction* instruction, + const std::function< + HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* computation)>& copy_leaf); + // Computes and returns the ProgramShape of this computation (shape of // parameters and result with layout). ProgramShape ComputeProgramShape() const; @@ -351,6 +365,10 @@ class HloComputation { unique_id_ = id; } + // Returns the instruction in this computation that has name `name`. Returns + // null if there is no such computation. + HloInstruction* GetInstructionWithName(tensorflow::StringPiece name); + int64 unique_id() const { return unique_id_; } private: @@ -373,8 +391,10 @@ class HloComputation { // Internal helper for recursive copying of an instruction. Creates and // returns a deep copy of the given instruction. StatusOr DeepCopyHelper( - HloInstruction* instruction, const ShapeTree* indices_to_copy, - ShapeTree* copies_added, ShapeIndex* index); + HloInstruction* instruction, ShapeIndex* index, + const std::function< + HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* computation)>& copy_leaf); // Internal helper to collect unreachable roots. std::vector CollectUnreachableRoots() const; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 25469a54c48f4f5cab478aba929f1cc18de8b81f..e4c547033139185d5dd4ef37db2d22a6431c1102 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -118,7 +118,7 @@ TEST_F(HloComputationTest, PostOrderSingleton) { // Test GetInstructionPostOrder for a computation with one instruction. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant)); @@ -129,7 +129,7 @@ TEST_F(HloComputationTest, PostOrderSimple) { // instructions. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto negate1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto negate2 = builder.AddInstruction( @@ -144,7 +144,7 @@ TEST_F(HloComputationTest, PostOrderTrace) { // Test GetInstructionPostOrder for a computation with a trace instruction. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto negate1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto trace = @@ -163,13 +163,13 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { // which are not connected. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), @@ -181,11 +181,11 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { // which are not connected. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -205,11 +205,11 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { // computation has multiple roots (dead code). auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); // Add three disconnected add expressions. builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, constant1, constant2)); @@ -256,7 +256,7 @@ TEST_F(HloComputationTest, DeepCopyArray) { // Test that DeepCopyInstruction properly copies an array. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); @@ -268,9 +268,9 @@ TEST_F(HloComputationTest, DeepCopyTuple) { // Test that DeepCopyInstruction properly copies a tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -289,7 +289,7 @@ TEST_F(HloComputationTest, DeepCopyArrayAtIndices) { // copy are specified. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto computation = builder.Build(); { @@ -314,9 +314,9 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { // specified by the given indices. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto computation = builder.Build(); @@ -371,11 +371,43 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { } } +TEST_F(HloComputationTest, DeepCopyToken) { + // Test that DeepCopyInstruction properly handles tokens which should not be + // copied. + auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + auto copy = computation->DeepCopyInstruction(token).ValueOrDie(); + + // No copy should be added. + EXPECT_THAT(copy, op::AfterAll()); +} + +TEST_F(HloComputationTest, DeepCopyTokenTuple) { + // Test that DeepCopyInstruction properly handles tokens which should not be + // copied. + auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({token, constant})); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + auto copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); + + // Only the array (second tuple element) should be copied. The token is passed + // through transparently. + EXPECT_THAT(copy, op::Tuple(op::GetTupleElement(tuple), + op::Copy(op::GetTupleElement(tuple)))); +} + TEST_F(HloComputationTest, CycleDetection) { // Test whether the visitor can detect cycles in the graph. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto add = builder.AddInstruction( @@ -385,6 +417,9 @@ TEST_F(HloComputationTest, CycleDetection) { // Add a control dependency to create a cycle. ASSERT_IS_OK(add->AddControlDependencyTo(negate)); + auto instructions = computation->MakeInstructionPostOrder(); + EXPECT_EQ(3, instructions.size()); + const auto visitor = [](HloInstruction* instruction) { return Status::OK(); }; auto visit_status = computation->Accept(visitor); ASSERT_FALSE(visit_status.ok()); @@ -398,7 +433,7 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { // twice. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto dead_negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto dead_add = builder.AddInstruction(HloInstruction::CreateBinary( @@ -421,9 +456,9 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { TEST_F(HloComputationTest, CloneWithControlDependency) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); @@ -467,9 +502,9 @@ TEST_F(HloComputationTest, Reachability) { // There is a control dependency from 'add' to 'exp'. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); auto negate = builder.AddInstruction( @@ -572,13 +607,14 @@ TEST_F(HloComputationTest, Stringification) { auto* computation = module->AddEntryComputation(builder.Build()); auto options = HloPrintOptions().set_print_metadata(false); - EXPECT_EQ(computation->ToString(options), - R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + const string expected_computation = + R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { %x = f32[5,10]{1,0} parameter(0) %y = f32[20,10]{1,0} parameter(1) %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"); +})"; + EXPECT_EQ(computation->ToString(options), expected_computation); } TEST_F(HloComputationTest, StringificationIndent) { @@ -604,13 +640,14 @@ TEST_F(HloComputationTest, StringificationIndent) { auto options = HloPrintOptions().set_print_metadata(false).set_indent_amount(2); - EXPECT_EQ(computation->ToString(options), - R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + const string expected_computation = + R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { %x = f32[5,10]{1,0} parameter(0) %y = f32[20,10]{1,0} parameter(1) %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} - })"); + })"; + EXPECT_EQ(computation->ToString(options), expected_computation); } TEST_F(HloComputationTest, StringificationCanonical) { @@ -635,21 +672,23 @@ TEST_F(HloComputationTest, StringificationCanonical) { auto* computation = module->AddEntryComputation(builder.Build()); auto options = HloPrintOptions().set_print_metadata(false); - EXPECT_EQ(computation->ToString(options), - R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + const string expected_computation1 = + R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { %x = f32[5,10]{1,0} parameter(0) %y = f32[20,10]{1,0} parameter(1) %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"); +})"; + EXPECT_EQ(computation->ToString(options), expected_computation1); options = HloPrintOptions().Canonical(); - EXPECT_EQ(computation->ToString(options), R"(TransposeDot { + const string expected_computation2 = R"(TransposeDot { tmp_0 = f32[5,10]{1,0} parameter(0) tmp_1 = f32[20,10]{1,0} parameter(1) tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"); +})"; + EXPECT_EQ(computation->ToString(options), expected_computation2); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 35ecd4428d0dfde2de445ea34472d2c78148c6c9..7229031c0c7f8bd374cfb495c7d8c11e9ca8b95e 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" @@ -51,14 +51,18 @@ StatusOr HloConstantFolding::Run(HloModule* module) { computation->root_instruction() != instruction) { continue; } - // Skip Constant, Parameter, Reduce operation. + // Skip Constant, Parameter, Reduce, and AfterAll operation. // TODO(b/35975797): Enable Reduce operation once arbitrary computation // are supported by the evaluator. // TODO(b/64407269): Enable Tuple once the timeout issue is resolved. + // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one + // operand in which case constant folding will be impossible and this + // special case is not necessary. if (instruction->opcode() == HloOpcode::kParameter || instruction->opcode() == HloOpcode::kConstant || instruction->opcode() == HloOpcode::kTuple || - instruction->opcode() == HloOpcode::kReduce) { + instruction->opcode() == HloOpcode::kReduce || + instruction->opcode() == HloOpcode::kAfterAll) { continue; } // Skip instructions with non-constant operands. diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 5d05ccfc0b223d8749a2577ba1bf96b1ab3e761b..64a42c1efc0c788ae8e66fb72b2d9aecec179082 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -41,7 +41,7 @@ using HloConstantFoldingTest = HloTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input)); @@ -62,7 +62,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); @@ -82,8 +82,8 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { HloComputation::Builder builder(TestName()); - HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({42.0f, 19.0f}))); + HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({42.0f, 19.0f}))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input)); @@ -120,7 +120,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) { for (auto csize : test_config.concat_sizes) { dimensions[test_config.concat_dimension] = csize; concat_size += csize; - auto literal = Literal::CreateFromDimensions(F32, dimensions); + auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions); HloInstruction* insn = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); operands.push_back(insn); @@ -149,7 +149,7 @@ TEST_F(HloConstantFoldingTest, Slice) { const int64 slice_limits[] = {10, 8, 6, 5, 9}; const int64 slice_strides[] = {1, 1, 1, 1, 1}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - Literal::CreateRandomLiteral( + LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -172,7 +172,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { HloComputation::Builder builder(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - Literal::CreateRandomLiteral( + LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); auto literal_clone = literal->Literal::CloneToUnique(); HloInstruction* literal_instruction = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 762e1afc71b108b2e32b5a7f7f1bbeb783fc6fbd..1bbb0ff08e26f626f4c3992a5f20ec4990f7db2d 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -49,9 +49,9 @@ Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) { // The default number of bytes accessed for an instruction is the sum of the // sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not // handle opaque types. - float bytes_accessed = shape_size_(hlo->shape()); + float bytes_accessed = GetShapeSize(hlo->shape()); for (const HloInstruction* operand : hlo->operands()) { - bytes_accessed += shape_size_(operand->shape()); + bytes_accessed += GetShapeSize(operand->shape()); } current_properties_[kBytesAccessedKey] = bytes_accessed; @@ -121,6 +121,13 @@ Status HloCostAnalysis::HandleElementwiseOp( } } +int64 HloCostAnalysis::GetShapeSize(const Shape& shape) const { + if (!LayoutUtil::HasLayout(shape)) { + return 0; + } + return shape_size_(shape); +} + Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) { return HandleElementwiseOp(hlo); } @@ -155,6 +162,10 @@ Status HloCostAnalysis::HandleConstant(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleIota(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) { // GetTupleElement forwards a pointer and does not touch each element in the // output. @@ -164,7 +175,11 @@ Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleSelect(const HloInstruction*) { +Status HloCostAnalysis::HandleSelect(const HloInstruction* hlo) { + return HandleElementwiseOp(hlo); +} + +Status HloCostAnalysis::HandleTupleSelect(const HloInstruction*) { return Status::OK(); } @@ -173,21 +188,21 @@ Status HloCostAnalysis::HandleReverse(const HloInstruction*) { } Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) { - current_properties_[kBytesAccessedKey] = shape_size_(slice->shape()) * 2; + current_properties_[kBytesAccessedKey] = GetShapeSize(slice->shape()) * 2; return Status::OK(); } Status HloCostAnalysis::HandleDynamicSlice( const HloInstruction* dynamic_slice) { current_properties_[kBytesAccessedKey] = - shape_size_(dynamic_slice->shape()) * 2; + GetShapeSize(dynamic_slice->shape()) * 2; return Status::OK(); } Status HloCostAnalysis::HandleDynamicUpdateSlice( const HloInstruction* dynamic_update_slice) { current_properties_[kBytesAccessedKey] = - shape_size_(dynamic_update_slice->operand(1)->shape()) * 2; + GetShapeSize(dynamic_update_slice->operand(1)->shape()) * 2; return Status::OK(); } @@ -196,7 +211,7 @@ Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) { // through them). The memory touched is then only the size of the output // index table of the tuple. - current_properties_[kBytesAccessedKey] = shape_size_(tuple->shape()); + current_properties_[kBytesAccessedKey] = GetShapeSize(tuple->shape()); return Status::OK(); } @@ -393,7 +408,7 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleGenerateToken(const HloInstruction*) { +Status HloCostAnalysis::HandleAfterAll(const HloInstruction*) { return Status::OK(); } @@ -518,12 +533,25 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { // TODO(b/33004697): Compute correct cost here, taking the actual number of // replicas into account. double flops = 0.0; - ShapeUtil::ForEachSubshape( - crs->shape(), [&, this](const Shape& subshape, const ShapeIndex&) { - if (ShapeUtil::IsArray(subshape)) { - flops += ShapeUtil::ElementsIn(subshape); - } - }); + ShapeUtil::ForEachSubshape(crs->shape(), + [&](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsArray(subshape)) { + flops += ShapeUtil::ElementsIn(subshape); + } + }); + current_properties_[kFlopsKey] = flops; + return Status::OK(); +} + +Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) { + // TODO(b/110096724): Compute correct cost here. + double flops = 0.0; + ShapeUtil::ForEachSubshape(hlo->shape(), + [&](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsArray(subshape)) { + flops += ShapeUtil::ElementsIn(subshape); + } + }); current_properties_[kFlopsKey] = flops; return Status::OK(); } @@ -538,15 +566,9 @@ Status HloCostAnalysis::HandleRng(const HloInstruction* random) { } Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { - // Compute the properties of the fused expression and attribute them to the - // fusion node. Use a dummy shape_size to avoid any errors from trying to - // calculate the size of a shape that does not have a layout, since nodes - // inside fusion nodes do not necessarily have a layout assigned. - ShapeSizeFunction shape_size = [](const Shape& shape) { return 0; }; TF_ASSIGN_OR_RETURN( current_properties_, - ProcessSubcomputation(fusion->fused_instructions_computation(), - &shape_size)); + ProcessSubcomputation(fusion->fused_instructions_computation())); // Fusion nodes that produce a tuple also produce the entries in the tuple. // Ignore the memory accessed inside fused ops, since fusion is supposed to @@ -555,11 +577,11 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { ShapeUtil::ForEachSubshape( fusion->shape(), [this](const Shape& subshape, const ShapeIndex& /*shape_index*/) { - current_properties_[kBytesAccessedKey] += shape_size_(subshape); + current_properties_[kBytesAccessedKey] += GetShapeSize(subshape); }); for (const HloInstruction* operand : fusion->operands()) { - current_properties_[kBytesAccessedKey] += shape_size_(operand->shape()); + current_properties_[kBytesAccessedKey] += GetShapeSize(operand->shape()); } return Status::OK(); @@ -640,6 +662,11 @@ Status HloCostAnalysis::HandleGather(const HloInstruction* gather) { return Status::OK(); } +Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) { + // TODO(b/32945756): Compute the properties of the sub-computation. + return Status::OK(); +} + Status HloCostAnalysis::FinishVisit(const HloInstruction*) { return Status::OK(); } @@ -677,11 +704,8 @@ float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const { } StatusOr HloCostAnalysis::ProcessSubcomputation( - HloComputation* computation, const ShapeSizeFunction* shape_size) { - if (shape_size == nullptr) { - shape_size = &shape_size_; - } - HloCostAnalysis visitor(*shape_size, per_second_rates_); + HloComputation* computation) { + HloCostAnalysis visitor(shape_size_, per_second_rates_); TF_RETURN_IF_ERROR(computation->Accept(&visitor)); return visitor.properties(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 0d66736fe1d0677d13a63ede7a203d6ac20c76f5..193a04bea0831de2b3aca19b17a445ad73e02e49 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -52,9 +52,11 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleElementwiseUnary(const HloInstruction* hlo) override; Status HandleElementwiseBinary(const HloInstruction* hlo) override; Status HandleConstant(const HloInstruction* constant) override; + Status HandleIota(const HloInstruction* iota) override; Status HandleGetTupleElement( const HloInstruction* get_tuple_element) override; - Status HandleSelect(const HloInstruction* select) override; + Status HandleSelect(const HloInstruction* hlo) override; + Status HandleTupleSelect(const HloInstruction* hlo) override; Status HandleCompare(const HloInstruction* compare) override; Status HandleClamp(const HloInstruction* clamp) override; Status HandleReducePrecision(const HloInstruction* hlo) override; @@ -69,6 +71,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleConvolution(const HloInstruction* convolution) override; Status HandleFft(const HloInstruction* fft) override; Status HandleCrossReplicaSum(const HloInstruction* crs) override; + Status HandleAllToAll(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; Status HandleHostCompute(const HloInstruction* host_compute) override; @@ -97,11 +100,12 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleBroadcast(const HloInstruction* broadcast) override; Status HandlePad(const HloInstruction* pad) override; Status HandleReshape(const HloInstruction* reshape) override; - Status HandleGenerateToken(const HloInstruction* token) override; + Status HandleAfterAll(const HloInstruction* token) override; Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; Status HandleGather(const HloInstruction* gather) override; + Status HandleScatter(const HloInstruction* scatter) override; Status FinishVisit(const HloInstruction* root) override; Status Preprocess(const HloInstruction* hlo) override; @@ -147,11 +151,8 @@ class HloCostAnalysis : public ConstDfsHloVisitor { const Properties& per_second_rates); // Returns the properties computed from visiting the computation rooted at the - // given hlo. Uses shape_size_ to calculate shape sizes if shape_size is null, - // otherwise uses shape_size_. - StatusOr ProcessSubcomputation( - HloComputation* computation, - const ShapeSizeFunction* shape_size = nullptr); + // given hlo. + StatusOr ProcessSubcomputation(HloComputation* computation); // Utility function to handle all element-wise operations. Status HandleElementwiseOp(const HloInstruction* hlo_instruction); @@ -168,6 +169,10 @@ class HloCostAnalysis : public ConstDfsHloVisitor { static float GetPropertyForHlo(const HloInstruction& hlo, const string& key, const HloToProperties& hlo_to_properties); + // Decorates shape_size_ by returning 0 immediately if the shape does not have + // a layout. + int64 GetShapeSize(const Shape& shape) const; + // Function which computes the size of the top-level of a given shape (not // including nested elements, if any). If null then bytes_accessed methods // return an error. diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index d22bef56730da194816b4ee89dc3196439b350f9..2c854eea18642eb7cb081b4fdfe3bc83627e41ae 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/service.h" @@ -59,9 +59,9 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a unary user function: x => exp(x + 0.5) { XlaBuilder builder("add_and_exp"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto half = builder.ConstantR0(0.5); - builder.Exp(builder.Add(x, half)); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto half = ConstantR0(&builder, 0.5); + Exp(Add(x, half)); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); add_and_exp_ = computation_status.ConsumeValueOrDie(); @@ -70,9 +70,9 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary user function: (x, y) => x + y { XlaBuilder builder("add"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Add(x, y); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); add_ = computation_status.ConsumeValueOrDie(); @@ -81,9 +81,9 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x)) { XlaBuilder builder("sigmoid"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto one = builder.ConstantR0(1.0); - builder.Div(one, builder.Add(one, builder.Exp(builder.Neg(x)))); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto one = ConstantR0(&builder, 1.0); + Div(one, Add(one, Exp(Neg(x)))); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); sigmoid_ = computation_status.ConsumeValueOrDie(); @@ -92,9 +92,9 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary max function: (x, y) => max (x, y) { XlaBuilder builder("max"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder.Max(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Max(x, y); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); max_ = computation_status.ConsumeValueOrDie(); @@ -103,9 +103,9 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary GT function: (x, y) => x > y { XlaBuilder builder("gt"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder.Gt(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Gt(x, y); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); gt_ = computation_status.ConsumeValueOrDie(); @@ -137,9 +137,9 @@ class HloCostAnalysisTest : public ::testing::Test { TEST_F(HloCostAnalysisTest, MatrixMultiply) { XlaBuilder builder("matrix_multiply"); - auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs"); - auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs"); - auto result = builder.Dot(lhs, rhs); + auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs"); + auto rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs"); + Dot(lhs, rhs); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -159,8 +159,8 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) { TEST_F(HloCostAnalysisTest, Map) { XlaBuilder builder("map"); - auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in"); - auto result = builder.Map({input}, add_and_exp_, {0}); + auto input = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10}), "in"); + Map(&builder, {input}, add_and_exp_, {0}); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -176,17 +176,17 @@ TEST_F(HloCostAnalysisTest, Map) { TEST_F(HloCostAnalysisTest, Convolution) { XlaBuilder builder("convolution"); - auto input = builder.Parameter( - 0, + auto input = Parameter( + &builder, 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, /*x_dim=*/20}), "input"); - auto kernel = builder.Parameter( - 1, + auto kernel = Parameter( + &builder, 1, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, /*x_dim=*/3}), "kernel"); - auto result = builder.Conv(input, kernel, {1, 1}, Padding::kValid); + Conv(input, kernel, {1, 1}, Padding::kValid); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -206,9 +206,8 @@ TEST_F(HloCostAnalysisTest, Convolution) { TEST_F(HloCostAnalysisTest, Reduce) { XlaBuilder builder("reduce"); auto input = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); - auto result = - builder.Reduce(input, builder.ConstantR0(0.0f), add_, {1}); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); + Reduce(input, ConstantR0(&builder, 0.0f), add_, {1}); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -224,9 +223,9 @@ TEST_F(HloCostAnalysisTest, Reduce) { TEST_F(HloCostAnalysisTest, ReduceWindow) { XlaBuilder builder("reduce_window"); auto input = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); - auto result = builder.ReduceWindow(input, builder.ConstantR0(0), add_, - {4, 5}, {4, 5}, Padding::kValid); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); + ReduceWindow(input, ConstantR0(&builder, 0), add_, {4, 5}, {4, 5}, + Padding::kValid); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -241,12 +240,11 @@ TEST_F(HloCostAnalysisTest, ReduceWindow) { TEST_F(HloCostAnalysisTest, SelectAndScatter) { XlaBuilder builder("select_and_scatter"); auto operand = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto source = - builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 4}), "source"); - auto result = - builder.SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid, - source, builder.ConstantR0(0), add_); + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 4}), "source"); + SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid, source, + ConstantR0(&builder, 0), add_); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -261,7 +259,7 @@ TEST_F(HloCostAnalysisTest, SelectAndScatter) { TEST_F(HloCostAnalysisTest, Broadcast) { XlaBuilder b("broadcast"); - b.Broadcast(b.ConstantR0(42), {10, 7}); + Broadcast(ConstantR0(&b, 42), {10, 7}); auto hlo_module = BuildHloGraph(&b); HloCostAnalysis analysis(ShapeSize); ASSERT_IS_OK( @@ -273,13 +271,12 @@ TEST_F(HloCostAnalysisTest, Broadcast) { TEST_F(HloCostAnalysisTest, FullyConnectedForward) { XlaBuilder builder("fully_connected_forward"); auto input = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "input"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "input"); auto weight = - builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 20}), "weight"); - auto bias = builder.Parameter(2, ShapeUtil::MakeShape(F32, {20}), "bias"); + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 20}), "weight"); + auto bias = Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {20}), "bias"); // sigmoid(input * weight + bias) - auto result = builder.Map( - {builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_, {0, 1}); + Map(&builder, {Add(Dot(input, weight), bias, {1})}, sigmoid_, {0, 1}); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -297,11 +294,11 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { HloCostAnalysis conv_analysis(ShapeSize); { XlaBuilder builder("conv_looking_matmul"); - auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), - "input"); - auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), - "weights"); - builder.Conv(lhs, rhs, {1, 1}, Padding::kSame); + auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), + "input"); + auto rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), + "weights"); + Conv(lhs, rhs, {1, 1}, Padding::kSame); auto hlo_module = BuildHloGraph(&builder); ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept( &conv_analysis)); @@ -311,10 +308,10 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { { XlaBuilder builder("matmul"); auto lhs = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64}), "input"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {64, 64}), "input"); auto rhs = - builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64}), "weights"); - builder.Dot(lhs, rhs); + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {64, 64}), "weights"); + Dot(lhs, rhs); auto hlo_module = BuildHloGraph(&builder); ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept( &matmul_analysis)); @@ -341,13 +338,13 @@ TEST_F(FusionCostAnalysis, LoopFusion) { // tuple = Tuple({sub, sub, mul, C1}) HloComputation::Builder builder(TestName()); auto c1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2))); auto c2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2))); auto c3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1, c2)); @@ -394,9 +391,9 @@ TEST_F(FusionCostAnalysis, NoLayout) { HloComputation::Builder builder(TestName()); auto c1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR4FromArray4D(Array4D(2, 3, 4, 5)))); + LiteralUtil::CreateR4FromArray4D(Array4D(2, 3, 4, 5)))); auto c2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(shape_without_layout, c2, {1})); @@ -419,9 +416,9 @@ TEST_F(HloCostAnalysisTest, TupleCost) { HloCostAnalysis analysis(ShapeSize); { XlaBuilder builder("matmul"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {123}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {42}), "y"); - auto tuple = builder.Tuple({x, y}); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y"); + Tuple(&builder, {x, y}); auto hlo_module = BuildHloGraph(&builder); ASSERT_IS_OK( @@ -435,21 +432,21 @@ TEST_F(HloCostAnalysisTest, TupleCost) { TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { XlaBuilder builder("BaseDilatedConvolution"); - auto input = builder.Parameter( - 0, + auto input = Parameter( + &builder, 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, /*x_dim=*/20}), "input"); - auto kernel = builder.Parameter( - 1, + auto kernel = Parameter( + &builder, 1, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, /*x_dim=*/3}), "kernel"); - auto result = builder.ConvGeneralDilated( - input, kernel, /*window_strides=*/{1, 1}, /*padding=*/{{1, 1}, {1, 1}}, - /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11}, - XlaBuilder::CreateDefaultConvDimensionNumbers(2)); + ConvGeneralDilated(input, kernel, /*window_strides=*/{1, 1}, + /*padding=*/{{1, 1}, {1, 1}}, + /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11}, + XlaBuilder::CreateDefaultConvDimensionNumbers(2)); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -463,8 +460,8 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { TEST_F(HloCostAnalysisTest, Slice) { // Test the analysis on a slice. XlaBuilder builder("slice"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); - auto slice = builder.Slice(x, {0}, {1}, {1}); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x"); + Slice(x, {0}, {1}, {1}); auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. @@ -478,8 +475,8 @@ TEST_F(HloCostAnalysisTest, Slice) { TEST_F(HloCostAnalysisTest, DynamicSlice) { // Test the analysis on a slice. XlaBuilder builder("dynamic-slice"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); - auto slice = builder.DynamicSlice(x, builder.ConstantR1({1}), {1}); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x"); + DynamicSlice(x, ConstantR1(&builder, {1}), {1}); auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. @@ -493,9 +490,9 @@ TEST_F(HloCostAnalysisTest, DynamicSlice) { TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) { // Test the analysis on a slice. XlaBuilder builder("dynamic-update-slice"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); - auto slice = builder.DynamicUpdateSlice(x, builder.ConstantR1({1.0}), - builder.ConstantR1({1})); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x"); + DynamicUpdateSlice(x, ConstantR1(&builder, {1.0}), + ConstantR1(&builder, {1})); auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 0fb65c845a6d4407c81171f6c1569fee98b1d16d..858992a3264a7fe6217374dc1e53f35ef77763c1 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -173,6 +174,29 @@ StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers)); } +StatusOr MakeMapHlo( + tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation) { + CHECK(!operands.empty()) << "Map Hlo requires at least one operand."; + HloComputation* computation = operands.front()->parent(); + std::vector operand_shapes; + int64 max_operand_rank = 0; + for (const HloInstruction* operand : operands) { + CHECK_EQ(computation, operand->parent()); + operand_shapes.push_back(&operand->shape()); + max_operand_rank = + std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); + } + std::vector map_dims(max_operand_rank); + std::iota(map_dims.begin(), map_dims.end(), 0); + TF_ASSIGN_OR_RETURN( + Shape map_shape, + ShapeInference::InferMapShape( + operand_shapes, map_computation->ComputeProgramShape(), map_dims)); + return computation->AddInstruction( + HloInstruction::CreateMap(map_shape, operands, map_computation)); +} + StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { CHECK_GT(n, 0); @@ -250,6 +274,38 @@ StatusOr ElideDegenerateDims(HloInstruction* operand, return MakeReshapeHlo(output_shape, operand); } +StatusOr InsertDegenerateDims( + HloInstruction* operand, ArraySlice dims_to_insert) { + CHECK(c_is_sorted(dims_to_insert)); + + const Shape& operand_shape = operand->shape(); + int64 output_shape_rank = + operand_shape.dimensions_size() + dims_to_insert.size(); + for (auto dim_to_insert : dims_to_insert) { + CHECK_LT(dim_to_insert, output_shape_rank); + } + + std::vector output_shape_dim_bounds; + output_shape_dim_bounds.reserve(output_shape_rank); + int64 operand_dims_idx = 0; + int64 dims_to_insert_idx = 0; + for (int64 i = 0; i < output_shape_rank; ++i) { + if (dims_to_insert_idx < dims_to_insert.size() && + i == dims_to_insert[dims_to_insert_idx]) { + output_shape_dim_bounds.push_back(1); + ++dims_to_insert_idx; + } else { + output_shape_dim_bounds.push_back( + operand_shape.dimensions(operand_dims_idx)); + ++operand_dims_idx; + } + } + + Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(), + output_shape_dim_bounds); + return MakeReshapeHlo(output_shape, operand); +} + StatusOr PadVectorWithZeros(HloInstruction* operand, int64 zeros_to_prepend, int64 zeros_to_append) { @@ -261,9 +317,9 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, padding_config_dim.set_edge_padding_high(zeros_to_append); *padding_config.add_dimensions() = padding_config_dim; - HloInstruction* zero = - computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(Literal::Zero(operand->shape().element_type())))); + HloInstruction* zero = computation->AddInstruction( + HloInstruction::CreateConstant(MakeUnique( + LiteralUtil::Zero(operand->shape().element_type())))); return MakePadHlo(operand, zero, padding_config); } @@ -272,7 +328,7 @@ StatusOr BroadcastZeros( ArraySlice broadcast_dimensions) { HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(Literal::Zero(element_type)))); + MakeUnique(LiteralUtil::Zero(element_type)))); return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, /*result_shape_bounds=*/broadcast_dimensions); } diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 49b1402d689a74874e34423a1832a0b6aa15f469..5ff8946fb098b57ae563a8ade47e8323f807a369 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -102,6 +102,12 @@ StatusOr MakeConcatHlo( StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dim_numbers); +// Creates a Map HLO instruction and adds it to the computation containing the +// operands. All operands must be in the same computation. +StatusOr MakeMapHlo( + tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation); + // ----------------------------------------------------------------------------- // Some other miscellaneous helpers to generate common HLO patterns. All of // these add all the instructions they generate into the computation containing @@ -144,6 +150,16 @@ StatusOr ExpandFirstDimIntoNDims( StatusOr ElideDegenerateDims( HloInstruction* operand, tensorflow::gtl::ArraySlice dims_to_elide); +// Inserts (via reshape) a set of degenerate dimensions (dimensions containing +// exactly one element), `dims_to_insert` into `operand`. The dimensions in +// `dims_to_insert` refer to the dimensions in the result, and hence should be +// less than the rank of the result. Also, `dims_to_insert` must be sorted. +// +// For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is +// {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34]. +StatusOr InsertDegenerateDims( + HloInstruction* operand, tensorflow::gtl::ArraySlice dims_to_insert); + // Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the // front and `zeros_to_append` zeros in the back. StatusOr PadVectorWithZeros(HloInstruction* operand, diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index 7e7c4f95fed737f40064224717f409b934e4ff27..60d3e71757d5ce31e025c744e089ff56091d9a43 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -60,8 +60,8 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, evaluator.Evaluate>( - *module, {Literal::CreateR1({3, 4})})); - CHECK_EQ(*result_literal, *Literal::CreateR1({3, 4})); + *module, {LiteralUtil::CreateR1({3, 4})})); + CHECK_EQ(*result_literal, *LiteralUtil::CreateR1({3, 4})); } TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { @@ -82,10 +82,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { std::unique_ptr result_literal, evaluator.Evaluate>( *module, - {Literal::CreateR3( + {LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})})); CHECK_EQ(*result_literal, - *Literal::CreateR2( + *LiteralUtil::CreateR2( {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}})); } @@ -103,10 +103,11 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { entry_computation->set_root_instruction(with_1_degenerate_dim_prepended); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {Literal::CreateR1({9, 10})})); - CHECK_EQ(*result_literal, *Literal::CreateR2({{9, 10}})); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {LiteralUtil::CreateR1({9, 10})})); + CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{9, 10}})); } TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { @@ -123,10 +124,11 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { entry_computation->set_root_instruction(with_2_degenerate_dims_prepended); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {Literal::CreateR1({9, 10})})); - CHECK_EQ(*result_literal, *Literal::CreateR3({{{9, 10}}})); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {LiteralUtil::CreateR1({9, 10})})); + CHECK_EQ(*result_literal, *LiteralUtil::CreateR3({{{9, 10}}})); } TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { @@ -145,8 +147,8 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, evaluator.Evaluate>( - *module, {Literal::CreateR0(9)})); - CHECK_EQ(*result_literal, *Literal::CreateR2({{9}})); + *module, {LiteralUtil::CreateR0(9)})); + CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{9}})); } TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { @@ -166,9 +168,9 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr result_literal, evaluator.Evaluate>( - *module, {Literal::CreateR1({1, 2, 3, 4, 5, 6})})); + *module, {LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6})})); CHECK_EQ(*result_literal, - *Literal::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); + *LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); } TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { @@ -188,8 +190,8 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, evaluator.Evaluate>( - *module, {Literal::CreateR1({3, 4})})); - CHECK_EQ(*result_literal, *Literal::CreateR1({0, 0, 0, 3, 4, 0})); + *module, {LiteralUtil::CreateR1({3, 4})})); + CHECK_EQ(*result_literal, *LiteralUtil::CreateR1({0, 0, 0, 3, 4, 0})); } TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { @@ -209,8 +211,8 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, evaluator.Evaluate>( - *module, {Literal::CreateR0(0)})); - CHECK_EQ(*result_literal, *Literal::CreateR2({{0, 0}, {0, 0}})); + *module, {LiteralUtil::CreateR0(0)})); + CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{0, 0}, {0, 0}})); } TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { @@ -230,9 +232,9 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, evaluator.Evaluate>( - *module, {Literal::CreateR0(0.0f)})); + *module, {LiteralUtil::CreateR0(0.0f)})); CHECK_EQ(*result_literal, - *Literal::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); + *LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index a0ee8896230d6dcacb5a8eb607fc00ae5226cfa5..06484f4012fc091f70df7bc8ec231ce3fcf89669 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -24,7 +24,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_domain_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -143,10 +143,8 @@ StatusOr HloCSE::Run(HloModule* module) { if (instruction->operand_count() == 0) { continue; } - // Skip instructions which have side effects or are a domain (which must - // not be CSE-ed). - if (instruction->HasSideEffect() || - instruction->opcode() == HloOpcode::kDomain) { + // Skip instructions which have side effects. + if (instruction->HasSideEffect()) { continue; } diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 16db374566c727f1f3efe2a6d419f1f3caf0aaf1..90fbaa37c5a70a78a9a818b4a8968f3406c671b1 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -53,9 +53,9 @@ TEST_F(HloCseTest, CombineTwoConstants) { // Test that two identical constants are commoned. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -72,7 +72,7 @@ TEST_F(HloCseTest, CombineTwoConstants) { EXPECT_EQ(42.0f, constant->literal().Get({})); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = Literal::CreateR0(84.0); + auto expected = LiteralUtil::CreateR0(84.0); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } @@ -81,10 +81,10 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { // the pass is not layout sensitive. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -104,7 +104,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { EXPECT_THAT(add, op::Add(first_operand, first_operand)); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); + auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } @@ -113,10 +113,10 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { // if the pass is layout sensitive. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -134,7 +134,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { EXPECT_THAT(add, op::Add(constant1, constant2)); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); + auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } @@ -144,20 +144,20 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { auto builder = HloComputation::Builder(TestName()); std::vector constants; constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)))); // Duplicate the float constant to verify something happens. constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)))); const Shape shape_r0 = ShapeUtil::MakeShape(F32, {}); for (int64 i = 0; i < constants.size(); ++i) { @@ -188,13 +188,13 @@ TEST_F(HloCseTest, NonscalarConstants) { // Test that identical nonscalar constants are merged. auto builder = HloComputation::Builder(TestName()); auto common_constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto common_constant2 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); // Create a constant which has the same shape but a different value. auto uncommon_constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}))); + LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}))); // Tie the constants together with a tuple. This makes it easier to refer to // the constant instructions via their use. @@ -223,7 +223,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { // Test that three identical instructions are commoned. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -239,7 +239,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { EXPECT_EQ(5, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3)); - HloCSE cse(/*is_layout_sensitive=*/false); + HloCSE cse(/*is_layout_sensitive=*/true); EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); @@ -248,12 +248,189 @@ TEST_F(HloCseTest, IdenticalInstructions) { EXPECT_THAT(tuple, op::Tuple(first_operand, first_operand, first_operand)); } +// Test two identical while loops with same inputs +TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) { + auto module = ParseHloString(R"( + HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput + + %body (param: (f32[], f32[])) -> (f32[], f32[]) { + %param = (f32[], f32[]) parameter(0) + %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param), +index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param), +index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1) + ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add) + } + + %condition (param.1: (f32[], f32[])) -> pred[] { + %param.1 = (f32[], f32[]) parameter(0) + ROOT %constant = pred[] constant(false) + } + + %condition.1 (param.2: (f32[], f32[])) -> pred[] { + %param.2 = (f32[], f32[]) parameter(0) + ROOT %constant.1 = pred[] constant(false) + } + + ENTRY %WhileLoopsIdenticalConditionsAndBodiesSameInput () -> (f32[], f32[]) +{ %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) %tuple.1 = +(f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3) %while = (f32[], +f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT +%while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), +condition=%condition.1, body=%body + } + )") + .ValueOrDie(); + + auto computation = module->entry_computation(); + + EXPECT_EQ(5, computation->instruction_count()); + HloCSE cse(true); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_EQ(4, computation->instruction_count()); +} + +// Test two while loops with same conditions, same inputs, but different +// bodies +TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) { + auto module = ParseHloString(R"( + HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies + + %body (param: (f32[], f32[])) -> (f32[], f32[]) { + %param = (f32[], f32[]) parameter(0) + %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param), +index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param), +index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1) + ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add) + } + + %body2 (param.1: (f32[], f32[])) -> (f32[], f32[]) { + %param.1 = (f32[], f32[]) parameter(0) + %get-tuple-element.2 = f32[] get-tuple-element((f32[], f32[]) %param.1), +index=0 %get-tuple-element.3 = f32[] get-tuple-element((f32[], f32[]) %param.1), +index=1 %sub = f32[] subtract(f32[] %get-tuple-element.2, f32[] +%get-tuple-element.3) ROOT %tuple.2 = (f32[], f32[]) tuple(f32[] +%get-tuple-element.2, f32[] %sub) + } + + %condition (param.2: (f32[], f32[])) -> pred[] { + %param.2 = (f32[], f32[]) parameter(0) + ROOT %constant = pred[] constant(false) + } + + %condition.1 (param.3: (f32[], f32[])) -> pred[] { + %param.3 = (f32[], f32[]) parameter(0) + ROOT %constant.1 = pred[] constant(false) + } + + ENTRY %WhileLoopsIdenticalConditionsSameInputAndDifferentBodies () -> +(f32[], f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) + %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3) + %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), +condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], +f32[]) %tuple.1), condition=%condition.1, body=%body2 + } + )") + .ValueOrDie(); + + auto computation = module->entry_computation(); + + EXPECT_EQ(5, computation->instruction_count()); + HloCSE cse(true); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_EQ(5, computation->instruction_count()); +} + +// Test two identical while loops with different inputs +TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) { + auto module = ParseHloString(R"( + HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput + + %body (param: (f32[], f32[])) -> (f32[], f32[]) { + %param = (f32[], f32[]) parameter(0) + %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param), +index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param), +index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1) + ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add) + } + + %condition (param.1: (f32[], f32[])) -> pred[] { + %param.1 = (f32[], f32[]) parameter(0) + ROOT %constant = pred[] constant(false) + } + + %condition.1 (param.2: (f32[], f32[])) -> pred[] { + %param.2 = (f32[], f32[]) parameter(0) + ROOT %constant.1 = pred[] constant(false) + } + + ENTRY %WhileLoopsIdenticalConditionsAndBodiesDifferentInput () -> (f32[], +f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) + %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3) + %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), +condition=%condition, body=%body %constant.4 = f32[] constant(1) %constant.5 = +f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[] +%constant.5) ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.2), +condition=%condition.1, body=%body + } + + )") + .ValueOrDie(); + + auto computation = module->entry_computation(); + + EXPECT_EQ(8, computation->instruction_count()); + HloCSE cse(true); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_EQ(8, computation->instruction_count()); +} + +// Test two while loops with identical bodies and same inputs, but different +// conditions +TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) { + auto module = ParseHloString(R"( + HloModule WhileLoopsIdenticalBodiesAndInputDifferntConditions + + %body (param: (f32[], f32[])) -> (f32[], f32[]) { + %param = (f32[], f32[]) parameter(0) + %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param), +index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param), +index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1) + ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add) + } + + %condition (param.1: (f32[], f32[])) -> pred[] { + %param.1 = (f32[], f32[]) parameter(0) + ROOT %constant = pred[] constant(false) + } + + %condition.1 (param.2: (f32[], f32[])) -> pred[] { + %param.2 = (f32[], f32[]) parameter(0) + ROOT %constant.1 = pred[] constant(true) + } + + ENTRY %WhileLoopsIdenticalBodiesAndInputDifferntConditions () -> (f32[], +f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) + %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3) + %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), +condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], +f32[]) %tuple.1), condition=%condition.1, body=%body + })") + .ValueOrDie(); + + auto computation = module->entry_computation(); + + EXPECT_EQ(5, computation->instruction_count()); + HloCSE cse(true); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_EQ(5, computation->instruction_count()); +} + TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { // Test that two identical instructions with different layouts are *not* // commoned if the pass is layout sensitive. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); @@ -284,7 +461,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { // the pass is layout insensitive. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); @@ -362,7 +539,7 @@ TEST_F(HloCseTest, IdenticalExpressions) { // The *1 instructions should be merged with the *2 instructions. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto negate1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kNegate, constant)); @@ -400,9 +577,9 @@ TEST_F(HloCseTest, DoNotCombineRng) { // Test that two RNG ops are not commoned. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); auto rng1 = builder.AddInstruction(HloInstruction::CreateRng( ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM, {constant1, constant2})); @@ -442,9 +619,9 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); auto builder = HloComputation::Builder(TestName() + "_rng_fun"); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); auto rng = builder.AddInstruction(HloInstruction::CreateRng( scalar_shape, RandomDistribution::RNG_UNIFORM, {constant1, constant2})); auto param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -459,7 +636,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({5.0f}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({5.0f}))); auto rng1 = builder.AddInstruction( HloInstruction::CreateMap(constant->shape(), {constant}, rng_function)); auto rng2 = builder.AddInstruction( @@ -521,9 +698,9 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { // in this case) are not collapsed. auto builder = HloComputation::Builder(TestName()); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -536,5 +713,40 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { EXPECT_EQ(2, computation->instruction_count()); } +TEST_F(HloCseTest, Domain) { + auto module = ParseHloString(R"( +HloModule module +ENTRY %entry { + %param = f32[] parameter(0), sharding={maximal device=0} + %domain.0 = f32[] domain(%param), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + %domain.1 = f32[] domain(%param), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + %domain.2 = f32[] domain(%param), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=2}} + %negate.0 = f32[] negate(%domain.0) + %negate.1 = f32[] negate(%domain.1) + %negate.2 = f32[] negate(%domain.2) + %domain.3 = f32[] domain(%negate.0), + domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} + %domain.4 = f32[] domain(%negate.1), + domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} + %domain.5 = f32[] domain(%negate.2), + domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}} + %add = f32[] add(%domain.3, %domain.4) + ROOT %sub = f32[] subtract(%add, %domain.5) +})") + .ValueOrDie(); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + LOG(INFO) << "AAAAA " << module->ToString(); + const HloInstruction* sub = module->entry_computation()->root_instruction(); + const HloInstruction* add = sub->operand(0); + EXPECT_EQ(add->operand(0), add->operand(1)); + EXPECT_NE(add->operand(0), sub->operand(1)); + EXPECT_NE(add->operand(1), sub->operand(1)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index d0200058683b2db8f5f0469d6c643014881f741e..bbfb0c253f583b633c4b2c34b2f068b563d3d9e0 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -34,16 +34,86 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace xla { +namespace { + +// We have this pattern in dynamaic update slice fusion, which should be +// supported: +// +// Parameters: p0, p1 +// Fusion +// ds = DynamicSlice(p0, p1) +// ROOT DynamicUpdateslice(p0, ds, p1) +// +// In this case, we should be able to reuse p0 and output, although p0 has +// multiple uses. +bool MultiDynamicSliceUseShareSameIndices( + tensorflow::gtl::ArraySlice uses) { + if (uses.empty()) { + return false; + } + const HloInstruction* indices = nullptr; + for (HloUse use : uses) { + auto user = use.instruction; + if (user->opcode() == HloOpcode::kDynamicUpdateSlice) { + if (indices == nullptr) { + indices = user->operand(2); + } else if (indices != user->operand(2)) { + return false; + } + if (use.operand_number != 0) { + return false; + } + } else if (user->opcode() == HloOpcode::kDynamicSlice) { + if (indices == nullptr) { + indices = user->operand(1); + } else if (indices != user->operand(1)) { + return false; + } + } else { + return false; + } + } + return true; +} + +} // namespace using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; -HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form, - bool bitcast_defines_value) +HloDataflowAnalysis::HloDataflowAnalysis( + const HloModule& module, bool ssa_form, bool bitcast_defines_value, + const FusionCanShareBufferFunction& fusion_can_share_buffer) : module_(module), ssa_form_(ssa_form), bitcast_defines_value_(bitcast_defines_value), - call_graph_(CallGraph::Build(&module)) {} + call_graph_(CallGraph::Build(&module)), + fusion_can_share_buffer_(fusion_can_share_buffer) {} + +bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( + const HloInstruction* inst) { + tensorflow::gtl::FlatSet visited; + tensorflow::gtl::InlinedVector stack; + stack.push_back(inst); + while (!stack.empty()) { + const HloInstruction* current = stack.back(); + stack.pop_back(); + visited.insert(current); + for (const HloInstruction* user : current->users()) { + // Found a user that is non-elementwise on current instruction. + for (const int64 use_index : user->OperandIndices(current)) { + if (!user->IsElementwiseOnOperand(use_index) && + user->opcode() != HloOpcode::kTuple) { + return false; + } + } + if (!visited.count(user)) { + stack.push_back(user); + } + } + } + return true; +} bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction, const ShapeIndex& index) const { @@ -328,18 +398,17 @@ bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) { bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) { CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone); bool changed = false; - // RecvDone forwards the operand value at {0} to the output. + // RecvDone forwards the operand value at {0} to element {0} of its output. for (auto& pair : GetInstructionValueSet(recv_done)) { ShapeIndex& index = pair.first; HloValueSet& value_set = pair.second; - ShapeIndex operand_index = {0}; - for (int64 i : index) { - operand_index.push_back(i); + if (index.empty() || index[0] != 0) { + continue; } const HloValueSet& operand_value_set = - GetValueSet(recv_done->operand(0), operand_index); + GetValueSet(recv_done->operand(0), index); if (value_set != operand_value_set) { value_set = operand_value_set; changed = true; @@ -396,6 +465,24 @@ bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) { return changed; } +bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) { + // Domain instructions just forward their operand. Given that domains can have + // a tuple operand, we iterate through its indexes, like for copies. + // Unlike copies though we also propagate the top-level value. + CHECK_EQ(domain->opcode(), HloOpcode::kDomain); + bool changed = false; + for (auto& pair : GetInstructionValueSet(domain)) { + const ShapeIndex& index = pair.first; + HloValueSet& value_set = pair.second; + HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + return changed; +} + bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) { CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement); bool changed = false; @@ -490,17 +577,17 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { } } -bool HloDataflowAnalysis::UpdateSelectValueSet(HloInstruction* select) { - CHECK_EQ(select->opcode(), HloOpcode::kSelect); - // A phi value is not defined at a kSelect instruction because kSelect does - // not create a new value. Rather it forwards a value from its operands. This - // contrasts with kWhile instruction (which does define a phi value) which has - // in-place update semantics. +bool HloDataflowAnalysis::UpdateTupleSelectValueSet(HloInstruction* select) { + CHECK_EQ(select->opcode(), HloOpcode::kTupleSelect); + // A phi value is not defined at a kTupleSelect instruction because + // kTupleSelect does not create a new value. Rather it forwards a value from + // its operands. This contrasts with kWhile instruction (which does define a + // phi value) which has in-place update semantics. bool changed = false; for (auto& pair : GetInstructionValueSet(select)) { const ShapeIndex& index = pair.first; if (index.empty()) { - // kSelect copies (not forwards) the top-level value. + // kTupleSelect copies (not forwards) the top-level value. continue; } HloValueSet& value_set = pair.second; @@ -556,12 +643,14 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( return UpdateBitcastValueSet(instruction); case HloOpcode::kSlice: return UpdateSliceValueSet(instruction); + case HloOpcode::kDomain: + return UpdateDomainValueSet(instruction); case HloOpcode::kCopy: return UpdateCopyValueSet(instruction); case HloOpcode::kGetTupleElement: return UpdateGetTupleElementValueSet(instruction); - case HloOpcode::kSelect: - return UpdateSelectValueSet(instruction); + case HloOpcode::kTupleSelect: + return UpdateTupleSelectValueSet(instruction); case HloOpcode::kTuple: return UpdateTupleValueSet(instruction); case HloOpcode::kParameter: @@ -734,6 +823,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kGetTupleElement: + case HloOpcode::kDomain: // These instructions define no values. The values in their output // flow from their operands or from cross computation dataflow. break; @@ -759,21 +849,25 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { } break; case HloOpcode::kCopy: - case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: case HloOpcode::kTuple: // These instructions only define their top-level values. Any other // values flow from their operands. define_top_level_only(); break; case HloOpcode::kRecvDone: - // RecvDone aliases its input tuple element {0}, therefore does not - // define any values. + // RecvDone produces a two-element tuple. Element zero aliases its + // input tuple element {0}; element one is a token. + define_value_at(/*index=*/{}); + define_value_at(/*index=*/{1}); break; case HloOpcode::kSend: - // Send produces a tuple of {aliased operand, U32 context}, therefore - // only defines the top-level tuple and the tuple element at {1}. + // Send produces a tuple of {aliased operand, U32 context, token}, + // therefore only defines the top-level tuple and the tuple elements + // at {1} and {2}. define_value_at(/*index=*/{}); define_value_at(/*index=*/{1}); + define_value_at(/*index=*/{2}); break; default: define_all_values(); @@ -787,12 +881,13 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { /* static */ StatusOr> HloDataflowAnalysis::Run( - const HloModule& module, bool ssa_form, bool bitcast_defines_value) { + const HloModule& module, bool ssa_form, bool bitcast_defines_value, + const FusionCanShareBufferFunction& fusion_can_share_buffer) { VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name(); XLA_VLOG_LINES(2, module.ToString()); - auto dataflow_analysis = WrapUnique( - new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value)); + auto dataflow_analysis = WrapUnique(new HloDataflowAnalysis( + module, ssa_form, bitcast_defines_value, fusion_can_share_buffer)); TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets()); dataflow_analysis->Propagate(); @@ -915,23 +1010,26 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( ShapeUtil::GetSubshape(operand->shape(), operand_index); const Shape& user_subshape = ShapeUtil::GetSubshape(user->shape(), user_index); + // Check that operand and user emit the same shape and layout. if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { return false; } if (user->opcode() == HloOpcode::kFusion) { + if (fusion_can_share_buffer_ != nullptr) { + return fusion_can_share_buffer_(user, operand); + } // Get the parameter associated with 'operand'; HloInstruction* fusion_param = user->fused_parameter(user->operand_index(operand)); const HloValue& value = GetValueDefinedAt(fusion_param, operand_index); - if (value.uses().size() != 1) { - return false; + if (MultiDynamicSliceUseShareSameIndices(value.uses())) { + return true; } - const HloUse& use = value.uses()[0]; - - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop || + user->fusion_kind() == HloInstruction::FusionKind::kInput) { if (user->fused_expression_root()->opcode() == HloOpcode::kDynamicUpdateSlice) { // Loop fusion with kDynamicUpdateSlice fused root. @@ -939,11 +1037,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // Returns true iff there is exactly one use of 'operand' at shape index // 'operand_index', and this singleton use is the fused root at operand // index 0. - return use.instruction == user->fused_expression_root() && - use.operand_number == 0; + if (value.uses().size() == 1) { + const HloUse& use = value.uses()[0]; + return use.instruction == user->fused_expression_root() && + use.operand_number == 0; + } + return false; } - } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && - user->fused_expression_root()->opcode() == HloOpcode::kAdd) { + return AreTransitiveUsesElementwiseOrTuple(fusion_param); + } + if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. // Check if one operand of kAdd fused root is kDot or kConvolution. @@ -964,8 +1068,12 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // Returns true iff there is exactly one use of 'operand' at shape index // 'operand_index', and this singleton use is the fused root (at operand // index 'other_add_operand_index'). - return use.instruction == user->fused_expression_root() && - use.operand_number == other_add_operand_index; + if (value.uses().size() == 1) { + const HloUse& use = value.uses()[0]; + return use.instruction == user->fused_expression_root() && + use.operand_number == other_add_operand_index; + } + return false; } } @@ -976,6 +1084,21 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( std::vector operand_indices = user->OperandIndices(operand); return operand_indices.size() == 1 && operand_indices[0] == 0; } + if (user->opcode() == HloOpcode::kSort) { + // Only valid if there are no other users. + if (operand->users().size() != 1) { + return false; + } + // If we only sort keys, the output of sort is not a tuple, so we can always + // share the buffer. + if (user->operand_count() == 1) { + return true; + } + CHECK(!user_index.empty()); + // Only share with the right tuple element buffer. + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && user_index[0] == operand_indices[0]; + } if (user->opcode() == HloOpcode::kCall) { // Get all uses of value defined by 'operand' at 'operand_index'. const auto& uses = GetValueDefinedAt(operand, operand_index).uses(); @@ -1003,9 +1126,6 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // Loop fusions that contain transposing copies won't reach here as they have // different layouts, which fails the check in the beginning of this function. - // - // Multi-output fusion will fail the check here as tuples are not considered - // an elementwise operation. return user->IsElementwiseOnOperand(user->operand_index(operand)); } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 9868746b6113881949e388cd2a4aa9f610b1fdb7..f4abc7a7c7dcfb223067fe946bec0c5ef32f206b 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -42,6 +42,20 @@ namespace xla { // Analysis which identifies all HLO values and their uses in an HLO module. class HloDataflowAnalysis { public: + // Different backends can have very different ways to do fusion, so we give + // backends the flexibility to decide whether an fusion instruction can share + // buffer with it's operands. If this is not specified, a default strategy + // will be used; if this is specified, it will be applied *in addition* to the + // default strategy. + // + // The first parameter of the function should be the fusion instruction, the + // second parameter should be an operand of the fusion instruction. + // + // TODO(b/80315712): Find a better way to tell whether a fusion can share + // buffer. + using FusionCanShareBufferFunction = std::function; + // Run dataflow analysis on the given module. Parameters: // // ssa_form : If true then new values are defined at the merge points of @@ -61,7 +75,10 @@ class HloDataflowAnalysis { // value of its operand. static StatusOr> Run( const HloModule& module, bool ssa_form = false, - bool bitcast_defines_value = false); + bool bitcast_defines_value = false, + const FusionCanShareBufferFunction& fusion_can_share_buffer = nullptr); + + static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst); // Returns true if 'instruction' defines an HLO value at the given shape index // of its output. @@ -136,8 +153,10 @@ class HloDataflowAnalysis { const ShapeIndex& user_index) const; protected: - HloDataflowAnalysis(const HloModule& module, bool ssa_form, - bool bitcast_defines_value = false); + HloDataflowAnalysis( + const HloModule& module, bool ssa_form, + bool bitcast_defines_value = false, + const FusionCanShareBufferFunction& fusion_can_share_buffer = nullptr); // Returns a new HloValue defined at the given instruction and shape index. HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, @@ -166,10 +185,11 @@ class HloDataflowAnalysis { bool UpdateCallValueSet(HloInstruction* call); bool UpdateConditionalValueSet(HloInstruction* conditional); bool UpdateCopyValueSet(HloInstruction* copy); + bool UpdateDomainValueSet(HloInstruction* domain); bool UpdateGetTupleElementValueSet(HloInstruction* gte); bool UpdateParameterValueSet(HloInstruction* parameter); bool UpdateRecvDoneValueSet(HloInstruction* recv_done); - bool UpdateSelectValueSet(HloInstruction* select); + bool UpdateTupleSelectValueSet(HloInstruction* select); bool UpdateSendValueSet(HloInstruction* send); bool UpdateTupleValueSet(HloInstruction* tuple); bool UpdateWhileValueSet(HloInstruction* xla_while); @@ -221,6 +241,10 @@ class HloDataflowAnalysis { // The Id to use for the next HloValue. HloValue::Id next_value_id_ = 0; + + // Backend specific function that decides whether a fusion can share buffer + // with its operand. + FusionCanShareBufferFunction fusion_can_share_buffer_ = nullptr; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index db1822ec47a7f52e2c3ef8dcbf433cd787ef75ab..4755c4a0cf8d268b1c47e596a14605eb2c60b36c 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -101,9 +101,9 @@ TEST_P(HloDataflowAnalysisTest, BinaryOperation) { // Test the dataflow for a simple binary operation (Add). auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, constant1, constant2)); module_->AddEntryComputation(builder.Build()); @@ -198,9 +198,9 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) { // Verify the dataflow through a nested tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto nested_tuple = builder.AddInstruction( @@ -259,9 +259,9 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); module_->AddEntryComputation(builder.Build()); @@ -308,9 +308,9 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto call1 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto call2 = builder.AddInstruction(HloInstruction::CreateCall( @@ -362,9 +362,9 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto call1 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto call2 = builder.AddInstruction(HloInstruction::CreateCall( @@ -426,9 +426,9 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, outer_computation)); module_->AddEntryComputation(builder.Build()); @@ -493,15 +493,15 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( @@ -594,15 +594,15 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while0 = builder.AddInstruction( @@ -653,7 +653,7 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); @@ -691,9 +691,9 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto entry_while = builder.AddInstruction( @@ -780,15 +780,15 @@ TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) { auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( @@ -840,11 +840,11 @@ TEST_P(HloDataflowAnalysisTest, ArraySelect) { // Test a kSelect of an array value. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( scalar_shape_, HloOpcode::kSelect, pred, constant1, constant2)); @@ -860,19 +860,18 @@ TEST_P(HloDataflowAnalysisTest, ArraySelect) { } TEST_P(HloDataflowAnalysisTest, TupleSelect) { - // Test a kSelect of a tuple value. Non-top-level element flow through the - // instruction. + // Test a kTupleSelect. Non-top-level element flow through the instruction. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -883,20 +882,20 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { builder.AddInstruction(HloInstruction::CreateTuple({constant4})); const Shape tuple_shape = tuple1->shape(); auto select11 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple1)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple1)); auto select12 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple2)); auto select34 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple3, tuple4)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple3, tuple4)); auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, select12, select34)); + tuple_shape, HloOpcode::kTupleSelect, pred, select12, select34)); module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - // Top-level value is always defined by a kSelect. + // Top-level value is always defined by a kTupleSelect. EXPECT_TRUE(analysis.ValueIsDefinedAt(select11)); EXPECT_TRUE(analysis.ValueIsDefinedAt(select12)); EXPECT_TRUE(analysis.ValueIsDefinedAt(select34)); @@ -937,20 +936,20 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { } TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { - // Test kSelect of a nested tuple. + // Test kTupleSelect of a nested tuple. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); auto constant5 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0))); auto inner_tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant2, constant3})); auto tuple1 = builder.AddInstruction( @@ -960,7 +959,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { auto tuple2 = builder.AddInstruction( HloInstruction::CreateTuple({constant4, inner_tuple2})); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); module_->AddEntryComputation(builder.Build()); @@ -983,7 +982,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { } TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) { - // Test a tuple-shaped kSelect feeding a kWhile instruction. HLO: + // Test a tuple-shaped kTupleSelect feeding a kWhile instruction. HLO: // // body((F32[], F32[]) %tuple_param): // %add = Add(%tuple_param{0}, %tuple_param{1}) @@ -1026,24 +1025,24 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({constant2})); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); auto gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape_, select, 0)); auto tuple = @@ -1089,7 +1088,7 @@ TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) { // Test the bitcast_defines_value flag to the dataflow analysis. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kBitcast, constant)); @@ -1158,44 +1157,50 @@ TEST_P(HloDataflowAnalysisTest, SendAndSendDone) { auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); auto send = builder.AddInstruction( - HloInstruction::CreateSend(param, /*channel_id=*/0)); + HloInstruction::CreateSend(param, token, /*channel_id=*/0)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - EXPECT_EQ(analysis.values().size(), 4); + EXPECT_EQ(analysis.values().size(), 6); EXPECT_TRUE(analysis.ValueIsDefinedAt(param)); EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{})); EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0})); EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{2})); EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done)); EXPECT_THAT(HloValuesAt(send, /*index=*/{0}), UnorderedElementsAre(analysis.GetValueDefinedAt(param))); } TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) { - // Test that a RecvDone forwards its operand tuple element at {0} to the - // output. + // Test that a RecvDone forwards its operand tuple element at {0} to element + // {0} of the output. auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); auto recv = builder.AddInstruction( - HloInstruction::CreateRecv(scalar_shape_, /*channel_id=*/0)); + HloInstruction::CreateRecv(scalar_shape_, token, /*channel_id=*/0)); auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - EXPECT_EQ(analysis.values().size(), 3); + EXPECT_EQ(analysis.values().size(), 7); EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{})); EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0})); EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1})); - EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done)); - EXPECT_THAT(HloValuesAt(recv_done), + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{2})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{0})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{1})); + EXPECT_THAT(HloValuesAt(recv_done, /*index=*/{0}), UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0}))); EXPECT_TRUE( analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module()); @@ -1304,13 +1309,13 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { auto body_param = body_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "body_param")); auto constant = body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto exp = body_builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant)); auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, exp, body_param)); auto dead_constant = body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto dead_negate = body_builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kNegate, dead_constant)); HloComputation* body = module_->AddEmbeddedComputation( @@ -1320,7 +1325,7 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "cond_param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); @@ -1571,11 +1576,11 @@ TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) { auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(56.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(12.0f))); auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( scalar_shape_, pred, constant1, true_computation, constant2, false_computation)); @@ -1662,11 +1667,11 @@ TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(56.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(12.0f))); auto tuple_operand = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( @@ -1792,15 +1797,15 @@ TEST_P(HloDataflowAnalysisTest, NestedConditionals) { // Build entry computation. auto builder = HloComputation::Builder(TestName()); auto pred1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto pred2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.2f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.2f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.3f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.3f))); auto tuple_operand = builder.AddInstruction( HloInstruction::CreateTuple({pred2, constant1, constant2})); auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( @@ -1880,9 +1885,14 @@ class HloDataflowAnalysisTestBase : public HloTestBase { computation_ = module_->AddEntryComputation(std::move(computation)); } - void RunAnalysis() { + void RunAnalysis(const HloDataflowAnalysis::FusionCanShareBufferFunction& + fusion_can_share_buffer = nullptr) { CHECK_NOTNULL(module_.get()); - dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie(); + dataflow_analysis_ = + HloDataflowAnalysis::Run(*module_, /*ssa_form=*/false, + /*bitcast_defines_value=*/false, + fusion_can_share_buffer) + .ConsumeValueOrDie(); } void BuildModuleAndRunAnalysis(std::unique_ptr computation) { @@ -1933,9 +1943,9 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -1998,7 +2008,7 @@ TEST_F(CanShareOperandBufferWithUserTest, } TEST_F(CanShareOperandBufferWithUserTest, - MultiOutputFusionCantAliasOperandBuffer) { + MultiOutputFusionCanAliasOperandBuffer) { auto builder = HloComputation::Builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); @@ -2022,14 +2032,14 @@ TEST_F(CanShareOperandBufferWithUserTest, {tuple, copy1, copy0}, HloInstruction::FusionKind::kLoop); RunAnalysis(); - EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, - fusion, {0})); - EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, - fusion, {1})); - EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, - fusion, {0})); - EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, - fusion, {1})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {0})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {1})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {0})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {1})); } TEST_F(CanShareOperandBufferWithUserTest, @@ -2038,7 +2048,7 @@ TEST_F(CanShareOperandBufferWithUserTest, Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -2057,6 +2067,31 @@ TEST_F(CanShareOperandBufferWithUserTest, fusion, {})); } +TEST_F(CanShareOperandBufferWithUserTest, + CanShareOperandWhenDynamicUpdateSliceIsFedByDynamicSliceWithSameIndex) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + Shape slice_shape = ShapeUtil::MakeShape(F32, {1, 2}); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "param0")); + auto index = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 0}))); + auto ds = builder.AddInstruction( + HloInstruction::CreateDynamicSlice(slice_shape, param, index, {1, 2, 2})); + + auto dus = builder.AddInstruction( + HloInstruction::CreateDynamicUpdateSlice(data_shape, param, ds, index)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dus, ds, index}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { auto builder = HloComputation::Builder(TestName()); @@ -2109,9 +2144,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -2132,7 +2167,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { } TEST_F(CanShareOperandBufferWithUserTest, - FusedDynamicUpdateSliceWithConvertCantShare) { + FusedDynamicUpdateSliceWithConvertCanShare) { auto builder = HloComputation::Builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {8}); @@ -2149,9 +2184,9 @@ TEST_F(CanShareOperandBufferWithUserTest, // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape_bf16, convert1, update, starts)); @@ -2166,8 +2201,7 @@ TEST_F(CanShareOperandBufferWithUserTest, HloInstruction::FusionKind::kLoop); RunAnalysis(); - // The fusion instruction can't share with tuple element 1. - EXPECT_FALSE( + EXPECT_TRUE( dataflow_analysis_->CanShareOperandBufferWithUser(gte1, {}, fusion, {})); } @@ -2198,14 +2232,56 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); } +TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto sort = + builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); + Shape values_shape = ShapeUtil::MakeShape(F32, {8}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto values = builder.AddInstruction( + HloInstruction::CreateParameter(1, values_shape, "values")); + auto sort = builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The buffer for the keys can be shared with the first tuple entry. + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {0})); + // The buffer for the values can be shared with the second tuple entry. + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(values, {}, sort, {1})); + // Verify that the buffers are not shared with the "wrong" tuple entry. + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {1})); + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(values, {}, sort, {0})); +} + TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { auto builder = HloComputation::Builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -2214,7 +2290,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add_operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -2236,7 +2312,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -2244,7 +2320,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { HloInstruction::CreateReverse(data_shape, operand, {0, 1})); auto two = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); @@ -2259,10 +2335,37 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { fusion, {})); } +TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kMultiply, operand, operand)); + auto two = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, mul, two)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, two, mul}, HloInstruction::FusionKind::kInput); + RunAnalysis(/*fusion_can_share_buffer=*/[](const HloInstruction* fusion, + const HloInstruction*) { + return fusion->fusion_kind() == HloInstruction::FusionKind::kLoop; + }); + + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {}, + fusion, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - auto make_cond = [this, &data_shape]() { + auto make_cond = [&data_shape]() { auto builder = HloComputation::Builder(TestName() + ".Cond"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); @@ -2271,7 +2374,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { return builder.Build(); }; - auto make_body = [this, &data_shape]() { + auto make_body = [&data_shape]() { auto builder = HloComputation::Builder(TestName() + ".Body"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); @@ -2309,7 +2412,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { auto sub_param = sub_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "sub_param")); auto one = sub_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto ones = sub_builder.AddInstruction( HloInstruction::CreateBroadcast(shape, one, {1})); auto add = sub_builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index fcd723af146e2227b8661b1a4993f1338f7de389..7d35e251ca21951036336ff1a1eb4aabc87bc5ca 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -41,20 +41,13 @@ StatusOr HloDCE::Run(HloModule* module) { XLA_VLOG_LINES(2, module->ToString()); for (auto* computation : module->MakeComputationPostOrder()) { - std::unordered_set live_instructions; - TF_RETURN_IF_ERROR(computation->root_instruction()->Accept( - [&live_instructions](HloInstruction* instruction) { - live_instructions.insert(instruction); - return Status::OK(); - })); - // Remove any dead roots and their dead transitive operands. Collect them // into a separate list first to avoid problems with iterating through the // computation's instruction while simultaneously removing instructions. std::vector dead_roots; for (auto* instruction : computation->instructions()) { - if (instruction->user_count() == 0 && - live_instructions.count(instruction) == 0 && + if (instruction != computation->root_instruction() && + instruction->user_count() == 0 && computation->IsRemovable(instruction) && !instruction->HasSideEffect()) { dead_roots.push_back(instruction); @@ -85,8 +78,7 @@ StatusOr HloDCE::Run(HloModule* module) { } // Remove dead computations. - std::list computations = module->MakeComputationPostOrder(); - for (auto* computation : computations) { + for (auto* computation : module->MakeComputationPostOrder()) { if (live_computations.count(computation) == 0) { TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation)); changed = true; diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 5a56607a665c4cbeb7b2572f182b88e890602968..26e3736e01270dbc6ca67647e814843aba2d1e3d 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -53,9 +53,9 @@ TEST_F(HloDceTest, NoDeadCode) { // Verify that no dead code is removed from a computation with no dead code. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -74,20 +74,21 @@ TEST_F(HloDceTest, InstructionsWithSideEffect) { // Verify that side-effect instructions (Send in this test) are not removed. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); builder.AddInstruction( - HloInstruction::CreateSend(constant, /*channel_id=*/0)); + HloInstruction::CreateSend(constant, token, /*channel_id=*/0)); builder.AddInstruction(HloInstruction::CreateTuple({})); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(3, computation->instruction_count()); + EXPECT_EQ(4, computation->instruction_count()); HloDCE dce; EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); - EXPECT_EQ(3, computation->instruction_count()); + EXPECT_EQ(4, computation->instruction_count()); } TEST_F(HloDceTest, DeadParameters) { @@ -126,9 +127,9 @@ TEST_F(HloDceTest, ControlDependencies) { // Verify that instructions with control dependencies are not removed. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); // Create two dead instructions: a negate and an add. auto dead_negate = builder.AddInstruction(HloInstruction::CreateUnary( @@ -223,7 +224,7 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) { auto param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "cond_param")); auto constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); cond_builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, constant)); } @@ -234,9 +235,9 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) { { auto param = body_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); - - auto infeed = - body_builder.AddInstruction(HloInstruction::CreateInfeed(shape, "")); + auto token = body_builder.AddInstruction(HloInstruction::CreateToken()); + auto infeed = body_builder.AddInstruction( + HloInstruction::CreateInfeed(shape, token, "")); body_builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, infeed)); } @@ -278,8 +279,10 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { { auto param = nested_callee_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); + auto token = + nested_callee_builder.AddInstruction(HloInstruction::CreateToken()); nested_callee_builder.AddInstruction( - HloInstruction::CreateOutfeed(shape, param, "")); + HloInstruction::CreateOutfeed(shape, param, token, "")); } auto nested_called_computation = module->AddEmbeddedComputation(nested_callee_builder.Build()); @@ -342,12 +345,12 @@ TEST_F(HloDceTest, RemoveDeadSubcomputation) { builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")), builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), /*dimensions_to_reduce=*/{0}, reduce_subcomp)); // Add another instruction as the root of the computation. builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); module->AddEntryComputation(builder.Build()); EXPECT_EQ(module->MakeComputationPostOrder().size(), 2); @@ -383,7 +386,7 @@ TEST_F(HloDceTest, KeepUsedSubcomputation) { builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")), builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), /*dimensions_to_reduce=*/{0}, reduce_subcomp)); // Add another instruction as the root of the computation that also uses @@ -393,7 +396,7 @@ TEST_F(HloDceTest, KeepUsedSubcomputation) { builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {100}), "param1")), builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), /*dimensions_to_reduce=*/{0}, reduce_subcomp)); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h index e0c5718509dabebb7b9307bf764b0ea1ce7369a0..eded3e78eead76c4564daee119034c5031eba409 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -26,10 +26,10 @@ limitations under the License. namespace xla { // Domain isolation is the task of placing kDomain instructions between HLO -// instructions having different shrading. A kDomain instruction is essentially +// instructions having different sharding. A kDomain instruction is essentially // used to break an HLO graph edge connecting two instructions with different // sharding. If a set of connected instructions have all the same sharding, no -// kDomain instruciton will be placed. +// kDomain instruction will be placed. class HloDomainIsolator : public HloPassInterface { public: // Creates a new kDomain instruction for the edge between the use instruction diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index ebd5adb5d573ce4b556046f85eb26a6ad59efcb9..9e096320db5048457435199627a1ef1fe1572177 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -41,11 +41,15 @@ namespace xla { bool HloDomainMap::InSameDomain(HloInstruction* instruction1, HloInstruction* instruction2) const { - int64 domain_id1 = FindOrDefault(instruction_to_domain_, instruction1, -1); - int64 domain_id2 = FindOrDefault(instruction_to_domain_, instruction2, -1); + int64 domain_id1 = GetDomainId(instruction1); + int64 domain_id2 = GetDomainId(instruction2); return domain_id1 >= 0 && domain_id1 == domain_id2; } +int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const { + return FindOrDefault(instruction_to_domain_, instruction, -1); +} + Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain); // We only check operands, so we are sure to not process the empty domain from @@ -58,6 +62,11 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } } + if (instruction == instruction->parent()->root_instruction()) { + auto domain = MakeUnique(); + domain->enter_domains.insert(instruction); + TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h index e62ef763fb3881ab6030b1f6a66266ac80a3d84d..1ca71597253eecfb45ae8f384240033a57045277 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.h +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -65,6 +65,10 @@ class HloDomainMap { // currently processing. bool IsDomainInstruction(HloInstruction* instruction) const; + // Retrieves the domain identifier of the instruction, or -1 in case + // instruction is not found within any domain. + int64 GetDomainId(HloInstruction* instruction) const; + private: HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {} diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h index aa0308100a21f109579de75788fce7d242d6a6b0..f855f2a1fc944fcc11c9afed278bef4af87813da 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -71,12 +71,6 @@ class DomainMetadata { // Returns a string representation of the metadata. virtual string ToString() const = 0; - - // Given a reachable set (the set of instructions which are reachable from - // each other via user/operand pathways, without crossing a kDomain - // instruciton), makes sure that all of them have metadata attributes which - // are coherent with this metadata object. - virtual Status NormalizeInstructions(const Domain& domain) const = 0; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.cc b/tensorflow/compiler/xla/service/hlo_domain_remover.cc index 1d06040b0e7c92b03f4cb5481bdee73a0f74f939..67fad0769f5eb5ceca64ebd2aa78c6469f2c813d 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_remover.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_domain_remover.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" #include "tensorflow/compiler/xla/service/hlo_domain_map.h" +#include "tensorflow/compiler/xla/service/hlo_domain_verifier.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -43,54 +43,16 @@ class HloDomainRemover::RunContext { Status HloDomainRemover::RunContext::VerifyAndNormalizeDomain( const DomainMetadata::Domain& domain) { - // Verify that the whole kDomain frontier bounding the instruction reach set, - // has matching metadata. - // A kDomain instruction has two sides of metadata, a user facing and an - // operand facing. - // A reachable instruction set can make contact with a kDomain instruction on - // a user facing side (the kDomain is operand of the instruction), or on a - // operand facing side (the kDomain is user of the instruction). - // And depending on the contact side, the proper metadata object - // (user_side_metadata() vs. operand_side_metadata()) needs to be used for - // consistency checks. - const DomainMetadata* ref_metadata = nullptr; - VLOG(4) << "Reach set:"; - for (HloInstruction* instruction : domain.instructions) { - VLOG(4) << " " << instruction->name(); - } - VLOG(4) << " Domains:"; - for (HloInstruction* instruction : domain.enter_domains) { - const DomainMetadata& meta = instruction->user_side_metadata(); - VLOG(4) << " User side: " << instruction->name(); - VLOG(4) << " " << meta.ToString(); - if (ref_metadata == nullptr) { - ref_metadata = &meta; - } else { - TF_RET_CHECK(meta.Matches(*ref_metadata)) - << "Metadata mismatch at instruction " << instruction->name() << " : " - << meta.ToString() << " vs " << ref_metadata->ToString(); - } - } - for (HloInstruction* instruction : domain.exit_domains) { - const DomainMetadata& meta = instruction->operand_side_metadata(); - VLOG(4) << " Operand side: " << instruction->name(); - VLOG(4) << " " << meta.ToString(); - if (ref_metadata == nullptr) { - ref_metadata = &meta; - } else { - TF_RET_CHECK(meta.Matches(*ref_metadata)) - << "Metadata mismatch at instruction " << instruction->name() << " : " - << meta.ToString() << " vs " << ref_metadata->ToString(); - } - } + TF_ASSIGN_OR_RETURN(const DomainMetadata* ref_metadata, + HloDomainVerifier::VerifyDomain(domain)); if (ref_metadata != nullptr) { VLOG(4) << "Applying domain normalization: " << ref_metadata->ToString(); - TF_RETURN_IF_ERROR(ref_metadata->NormalizeInstructions(domain)); + TF_RETURN_IF_ERROR(remover_->normalizer_(domain, ref_metadata)); } else { // No kDomain instruction was present within this domain, so call the // generic normalization functions and have them apply their heuristic. VLOG(2) << "Applying domain-less normalization"; - TF_RETURN_IF_ERROR(remover_->normalizer_(domain)); + TF_RETURN_IF_ERROR(remover_->normalizer_(domain, nullptr)); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h index 0c71dd34fd4d2944037dc965a2c9ad2c592d6e3e..c859e05f02e54d601804b641094ecdd11bbe1aed 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_remover.h +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h @@ -35,9 +35,10 @@ class HloDomainRemover : public HloPassInterface { // instructions in it with the same attributes (ie, sharding), a normalizer // function is tasked at applying attribute normalization on the instructions // within such domain. - HloDomainRemover( - tensorflow::StringPiece kind, - std::function normalizer) + HloDomainRemover(tensorflow::StringPiece kind, + std::function + normalizer) : kind_(kind.ToString()), normalizer_(std::move(normalizer)) {} tensorflow::StringPiece name() const override { return "domain_remover"; } @@ -48,7 +49,9 @@ class HloDomainRemover : public HloPassInterface { class RunContext; string kind_; - std::function normalizer_; + std::function + normalizer_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 5553ddb153f7f1f2e6a790890c11f35e192488c4..70271be304336767bd3fd01297217e9309a941b6 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -21,12 +21,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class HloDomainTest : public HloTestBase { +class HloDomainTest : public HloVerifiedTestBase { protected: bool FindUserViaDomainPath(HloInstruction* instruction, HloInstruction* operand) const { @@ -64,11 +65,11 @@ class HloDomainTest : public HloTestBase { return false; } - StatusOr> ParseModule( - tensorflow::StringPiece hlo_string) { + StatusOr ParseModule(tensorflow::StringPiece hlo_string) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - return ParseHloString(hlo_string, config); + ParseAndVerifyModule(hlo_string, config); + return &module(); } }; @@ -96,12 +97,6 @@ class OpNameMetadata : public DomainMetadata { string ToString() const override { return opname_; } - Status NormalizeInstructions( - const DomainMetadata::Domain& domain) const override { - // For the purposes of this test, nothing to do. - return Status::OK(); - } - static tensorflow::StringPiece KindName() { return "opname"; } private: @@ -123,7 +118,8 @@ std::unique_ptr OpNameDomainCreator(HloInstruction* instruction, std::move(user_side_metadata)); } -Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain) { +Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain, + const DomainMetadata* metadata) { // Nothing to do for the particular use this test make of the OpName domains. return Status::OK(); } @@ -143,32 +139,31 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator(CreateShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); - EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); - EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); - EXPECT_TRUE(HasDomainEdge(module.get(), "d", "b")); - EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); - EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + EXPECT_TRUE(HasDomainEdge(module, "c", "a")); + EXPECT_TRUE(HasDomainEdge(module, "c", "b")); + EXPECT_TRUE(HasDomainEdge(module, "d", "a")); + EXPECT_TRUE(HasDomainEdge(module, "d", "b")); + EXPECT_FALSE(HasDomainEdge(module, "e", "c")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); HloDomainRemover remover(ShardingMetadata::KindName(), - NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); EXPECT_TRUE(remover_changed); - EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); - EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); - EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); - EXPECT_FALSE(HasDomainEdge(module.get(), "d", "b")); - EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); - EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + EXPECT_FALSE(HasDomainEdge(module, "c", "a")); + EXPECT_FALSE(HasDomainEdge(module, "c", "b")); + EXPECT_FALSE(HasDomainEdge(module, "d", "a")); + EXPECT_FALSE(HasDomainEdge(module, "d", "b")); + EXPECT_FALSE(HasDomainEdge(module, "e", "c")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); } TEST_F(HloDomainTest, CheckNoDomainAddedIfNoSharding) { @@ -186,12 +181,11 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator(CreateShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(!isolator_changed); } @@ -202,37 +196,38 @@ HloModule Module ENTRY entry { p0 = (f32[4]) parameter(0) a = f32[4] get-tuple-element(p0), index=0 - b = (f32[4], u32[]) send(a), channel_id=1, sharding={maximal device=0} - c = () send-done(b), channel_id=1, sharding={maximal device=0} - d = (f32[4], u32[]) recv(), channel_id=2, sharding={maximal device=0} - e = f32[4] recv-done(d), channel_id=2, sharding={maximal device=0} - f = f32[4] add(a, e) - g = f32[4] subtract(a, e) + token = token[] after-all() + b = (f32[4], u32[], token[]) send(a, token), channel_id=1, sharding={maximal device=0} + c = token[] send-done(b), channel_id=1, sharding={maximal device=0} + d = (f32[4], u32[], token[]) recv(token), channel_id=2, sharding={maximal device=0} + e = (f32[4], token[]) recv-done(d), channel_id=2, sharding={maximal device=0} + e_element = f32[4] get-tuple-element(e), index=0, sharding={maximal device=0} + f = f32[4] add(a, e_element) + g = f32[4] subtract(a, e_element) ROOT h = (f32[4], f32[4]) tuple(f, g) } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator(CreateShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module.get(), "b", "a")); - EXPECT_TRUE(HasDomainEdge(module.get(), "f", "e")); - EXPECT_FALSE(HasDomainEdge(module.get(), "a", "p0")); - EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); - EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + EXPECT_TRUE(HasDomainEdge(module, "b", "a")); + EXPECT_TRUE(HasDomainEdge(module, "f", "e_element")); + EXPECT_FALSE(HasDomainEdge(module, "a", "p0")); + EXPECT_FALSE(HasDomainEdge(module, "c", "b")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); HloDomainRemover remover(ShardingMetadata::KindName(), - NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); EXPECT_TRUE(remover_changed); - EXPECT_FALSE(HasDomainEdge(module.get(), "b", "a")); - EXPECT_FALSE(HasDomainEdge(module.get(), "f", "e")); + EXPECT_FALSE(HasDomainEdge(module, "b", "a")); + EXPECT_FALSE(HasDomainEdge(module, "f", "e_element")); } TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { @@ -240,20 +235,21 @@ TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { HloModule Module ENTRY entry { - a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=-1} - b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=-1} - c = f32[4] add(b, b), sharding={maximal device=-1} - d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=-1} - ROOT e = () send-done(d), channel_id=2, sharding={maximal device=-1} + token = token[] after-all(), sharding={maximal device=-1} + a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=-1} + b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=-1} + b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=-1} + c = f32[4] add(b_element, b_element), sharding={maximal device=-1} + d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=-1} + ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=-1} } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator(CreateShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_FALSE(isolator_changed); } @@ -262,24 +258,25 @@ TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) { HloModule Module ENTRY entry { - a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=0} - b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=0} - c = f32[4] add(b, b) - d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=0} - ROOT e = () send-done(d), channel_id=2, sharding={maximal device=0} + token = token[] after-all(), sharding={maximal device=0} + a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=0} + b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=0} + b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=0} + c = f32[4] add(b_element, b_element) + d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=0} + ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=0} } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainRemover remover(ShardingMetadata::KindName(), - NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); EXPECT_FALSE(remover_changed); - HloInstruction* add = FindInstruction(module.get(), "c"); + HloInstruction* add = FindInstruction(module, "c"); ASSERT_NE(add, nullptr); auto device = add->sharding_unique_device(); EXPECT_TRUE(device.has_value()); @@ -302,42 +299,41 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator sharding_isolator(CreateShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed, - sharding_isolator.Run(module.get())); + sharding_isolator.Run(module)); EXPECT_TRUE(sharding_isolator_changed); HloDomainIsolator opname_isolator(OpNameDomainCreator); TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, - opname_isolator.Run(module.get())); + opname_isolator.Run(module)); EXPECT_TRUE(opname_isolator_changed); - EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); - EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); - EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); - EXPECT_TRUE(HasDomainEdge(module.get(), "d", "c")); - EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + EXPECT_TRUE(HasDomainEdge(module, "c", "a")); + EXPECT_TRUE(HasDomainEdge(module, "c", "b")); + EXPECT_TRUE(HasDomainEdge(module, "d", "a")); + EXPECT_TRUE(HasDomainEdge(module, "d", "c")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); HloDomainRemover sharding_remover(ShardingMetadata::KindName(), - NormalizeShardingDomain); + ShardingMetadata::NormalizeShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, - sharding_remover.Run(module.get())); + sharding_remover.Run(module)); EXPECT_TRUE(sharding_remover_changed); HloDomainRemover opname_remover(OpNameMetadata::KindName(), OpNameDomainNormalizer); TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, - opname_remover.Run(module.get())); + opname_remover.Run(module)); EXPECT_TRUE(opname_remover_changed); - EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); - EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); - EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); - EXPECT_FALSE(HasDomainEdge(module.get(), "d", "c")); + EXPECT_FALSE(HasDomainEdge(module, "c", "a")); + EXPECT_FALSE(HasDomainEdge(module, "c", "b")); + EXPECT_FALSE(HasDomainEdge(module, "d", "a")); + EXPECT_FALSE(HasDomainEdge(module, "d", "c")); } TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { @@ -345,33 +341,35 @@ TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { HloModule Module ENTRY entry { - infeed = (f32[4], f32[4]) infeed(), - sharding={{maximal device=1}, {maximal device=0}} - gte0 = f32[4] get-tuple-element(infeed), index=0 - gte1 = f32[4] get-tuple-element(infeed), index=1 + token = token[] after-all() + infeed = ((f32[4], f32[4]), token[]) infeed(token), + sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}} + infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0 + gte0 = f32[4] get-tuple-element(infeed.data), index=0 + gte1 = f32[4] get-tuple-element(infeed.data), index=1 copy0 = f32[4] copy(gte0) copy1 = f32[4] copy(gte1) ROOT add = f32[4] add(copy0, copy1) } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator(CreateShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module.get(), "gte0", "infeed")); - EXPECT_TRUE(HasDomainEdge(module.get(), "gte1", "infeed")); - EXPECT_FALSE(HasDomainEdge(module.get(), "copy0", "gte0")); - EXPECT_FALSE(HasDomainEdge(module.get(), "copy1", "gte1")); + EXPECT_TRUE(HasDomainEdge(module, "infeed.data", "infeed")); + EXPECT_FALSE(HasDomainEdge(module, "copy0", "gte0")); + EXPECT_FALSE(HasDomainEdge(module, "copy1", "gte1")); // Inject unassigned tuple/gte within the infeed domain, to simulate the // HLO passes adding unexpected instructions. // // infeed + // | + // infeed.data (tuple element 0 of infeed) // / \ // GTE0 GTE1 // / \ @@ -380,31 +378,36 @@ ENTRY entry { // \ / // TUPLE // | - // DOMAIN - HloInstruction* infeed = FindInstruction(module.get(), "infeed"); + HloInstruction* infeed = FindInstruction(module, "infeed"); ASSERT_NE(infeed, nullptr); - auto infeed_users = infeed->users(); - HloInstruction* new_gte0 = + HloInstruction* infeed_data = infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0)); + + auto infeed_data_users = infeed_data->users(); + HloInstruction* new_gte0 = infeed_data->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed_data->shape(), 0), infeed_data, + 0)); HloInstruction* new_copy0 = - infeed->parent()->AddInstruction(HloInstruction::CreateUnary( + infeed_data->parent()->AddInstruction(HloInstruction::CreateUnary( new_gte0->shape(), HloOpcode::kCopy, new_gte0)); - HloInstruction* new_gte1 = - infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::GetTupleElementShape(infeed->shape(), 1), infeed, 1)); + HloInstruction* new_gte1 = infeed_data->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed_data->shape(), 1), infeed_data, + 1)); HloInstruction* new_copy1 = - infeed->parent()->AddInstruction(HloInstruction::CreateUnary( + infeed_data->parent()->AddInstruction(HloInstruction::CreateUnary( new_gte1->shape(), HloOpcode::kCopy, new_gte1)); - HloInstruction* new_tuple = infeed->parent()->AddInstruction( + HloInstruction* new_tuple = infeed_data->parent()->AddInstruction( HloInstruction::CreateTuple({new_copy0, new_copy1})); - for (HloInstruction* user : infeed_users) { - TF_EXPECT_OK(infeed->ReplaceUseWith(user, new_tuple)); + for (HloInstruction* user : infeed_data_users) { + TF_EXPECT_OK(infeed_data->ReplaceUseWith(user, new_tuple)); } HloDomainRemover remover(ShardingMetadata::KindName(), - NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); EXPECT_TRUE(remover_changed); struct Assignment { @@ -418,7 +421,7 @@ ENTRY entry { }; for (auto& assignment : assignments) { auto device = assignment.instruction->sharding_unique_device(); - EXPECT_TRUE(device.has_value()); + ASSERT_TRUE(device.has_value()); EXPECT_EQ(*device, assignment.device); } EXPECT_TRUE(new_tuple->has_sharding()); @@ -428,5 +431,97 @@ ENTRY entry { HloSharding::AssignDevice(0)})); } +TEST_F(HloDomainTest, EmptyRootDomain) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + %param = f32[1] parameter(0), sharding={maximal device=0} + %tuple = (f32[1]) tuple(%param), + sharding={maximal device=1} + ROOT %gte = f32[1] get-tuple-element(%tuple), index=0, + sharding={maximal device=1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module, "tuple", "param")); + EXPECT_FALSE(HasDomainEdge(module, "gte", "tuple")); + + // Remove %tuple and %gte (tuple simplification) + HloInstruction* gte = FindInstruction(module, "gte"); + HloInstruction* tuple = FindInstruction(module, "tuple"); + module->entry_computation()->set_root_instruction(tuple->mutable_operand(0)); + TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(gte)); + TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(tuple)); + + HloDomainRemover remover(ShardingMetadata::KindName(), + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + EXPECT_TRUE(remover_changed); + + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_TRUE(root->has_sharding()); + EXPECT_EQ(root->sharding(), HloSharding::AssignDevice(1)); +} + +// Tests that text dumps of domain instructions can be parsed back, in the +// specific case of null shardings. +TEST_F(HloDomainTest, DumpParseNullSharding) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {}); + auto sharding_md_0 = MakeUnique(nullptr); + auto sharding_md_1 = MakeUnique(nullptr); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p")); + HloInstruction* domain = builder.AddInstruction(HloInstruction::CreateDomain( + shape, param, std::move(sharding_md_0), std::move(sharding_md_1))); + builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, domain, domain)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + auto hlo_string = module->ToString(); + ASSERT_TRUE(ParseModule(hlo_string).status().ok()); +} + +TEST_F(HloDomainTest, DomainTuple) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = f32[4] parameter(0), sharding={maximal device=0} + cst = u32[] constant(0), sharding={maximal device=1} + tpl = (u32[], f32[4]) tuple(cst, p0), sharding={{maximal device=1}, {maximal device=0}} + ROOT gte = f32[4] get-tuple-element(tpl), index=1, sharding={maximal device=0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + EXPECT_TRUE(isolator_changed); + + // Clear sharding of tpl instruction, in order to test domain sharding + // application. + auto tpl = FindInstruction(module, "tpl"); + tpl->clear_sharding(); + + HloDomainRemover remover(ShardingMetadata::KindName(), + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + EXPECT_TRUE(remover_changed); + + EXPECT_EQ(HloSharding::Tuple(tpl->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + tpl->sharding()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc new file mode 100644 index 0000000000000000000000000000000000000000..751fc677e2d955fd3d9f8970f7c0370a22c054bf --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc @@ -0,0 +1,124 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_verifier.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +class HloDomainVerifier::RunContext { + public: + RunContext(HloModule* module, HloDomainVerifier* verifier) + : module_(module), verifier_(verifier) {} + + Status Run(); + + private: + // If the verifier caller passed an empty vector for kinds, we collect all the + // avalable domain types. + Status PopulateDomainKinds(); + + HloModule* module_; + HloDomainVerifier* verifier_; +}; + +Status HloDomainVerifier::RunContext::PopulateDomainKinds() { + if (verifier_->kinds_.empty()) { + // The caller specified no domain kinds, collect all the ones available. + std::set kinds; + for (HloComputation* computation : module_->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kDomain) { + TF_RET_CHECK(instruction->user_side_metadata().Kind() == + instruction->operand_side_metadata().Kind()) + << instruction->ToString(); + kinds.insert(instruction->user_side_metadata().Kind().ToString()); + } + } + } + verifier_->kinds_.insert(verifier_->kinds_.end(), kinds.begin(), + kinds.end()); + } + return Status::OK(); +} + +Status HloDomainVerifier::RunContext::Run() { + VLOG(4) << "Running HLO Domain Verifier"; + TF_RETURN_IF_ERROR(PopulateDomainKinds()); + for (HloComputation* computation : module_->computations()) { + for (auto& kind : verifier_->kinds_) { + // First create the domain instruciton sets. A domain instruction set is + // the set of instructions whose edges never cross a kDomain instruction. + TF_ASSIGN_OR_RETURN(std::unique_ptr domain_map, + HloDomainMap::Create(computation, kind)); + // Verify every domain populated within the map. + for (auto& domain : domain_map->GetDomains()) { + TF_RETURN_IF_ERROR(VerifyDomain(*domain).status()); + } + } + } + return Status::OK(); +} + +StatusOr HloDomainVerifier::Run(HloModule* module) { + RunContext run_context(module, this); + TF_RETURN_IF_ERROR(run_context.Run()); + return false; +} + +StatusOr HloDomainVerifier::VerifyDomain( + const DomainMetadata::Domain& domain) { + const DomainMetadata* ref_metadata = nullptr; + VLOG(4) << "Reach set:"; + for (HloInstruction* instruction : domain.instructions) { + VLOG(4) << " " << instruction->name(); + } + VLOG(4) << " Domains:"; + for (HloInstruction* instruction : domain.enter_domains) { + const DomainMetadata& meta = instruction->user_side_metadata(); + VLOG(4) << " User side: " << instruction->name(); + VLOG(4) << " " << meta.ToString(); + if (ref_metadata == nullptr) { + ref_metadata = &meta; + } else { + TF_RET_CHECK(meta.Matches(*ref_metadata)) + << "Metadata mismatch at instruction " << instruction->name() << " : " + << meta.ToString() << " vs " << ref_metadata->ToString(); + } + } + for (HloInstruction* instruction : domain.exit_domains) { + const DomainMetadata& meta = instruction->operand_side_metadata(); + VLOG(4) << " Operand side: " << instruction->name(); + VLOG(4) << " " << meta.ToString(); + if (ref_metadata == nullptr) { + ref_metadata = &meta; + } else { + TF_RET_CHECK(meta.Matches(*ref_metadata)) + << "Metadata mismatch at instruction " << instruction->name() << " : " + << meta.ToString() << " vs " << ref_metadata->ToString(); + } + } + return ref_metadata; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h new file mode 100644 index 0000000000000000000000000000000000000000..8e53cf97f8ba9a88140a909ad20c1a938aec8c1f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_VERIFIER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_VERIFIER_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { + +// Verifies that the domain instructions are consistent, and the each domain is +// surrounded by the same metadata. +class HloDomainVerifier : public HloPassInterface { + public: + HloDomainVerifier(std::vector kinds) : kinds_(std::move(kinds)) {} + + tensorflow::StringPiece name() const override { return "domain_verifier"; } + + StatusOr Run(HloModule* module) override; + + // Verify that the whole kDomain frontier bounding the instruction reach set, + // has matching metadata. + // A kDomain instruction has two sides of metadata, a user facing and an + // operand facing. + // A reachable instruction set can make contact with a kDomain instruction on + // a user facing side (the kDomain is operand of the instruction), or on a + // operand facing side (the kDomain is user of the instruction). + // And depending on the contact side, the proper metadata object + // (user_side_metadata() vs. operand_side_metadata()) needs to be used for + // consistency checks. + // Returns the DomainMetadata pointer which surrounds the domain, and + // represents the common metadata within such domain. If the returned + // DomainMetadata pointer is nullptr, the input domain had no kDomain + // boundary. + static StatusOr VerifyDomain( + const DomainMetadata::Domain& domain); + + private: + class RunContext; + + std::vector kinds_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_VERIFIER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index 4ed1508d7067684a15d0fb7d86e69b055bc1333b..b9244b8e9e5f34e7ac5113c8eacb6f8243eea314 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" @@ -144,6 +144,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { opcode == HloOpcode::kCrossReplicaSum || opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap || opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || + opcode == HloOpcode::kScatter || opcode == HloOpcode::kSelectAndScatter || opcode == HloOpcode::kConditional) { continue; diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc index 5c5a059e0fd895f03bc26a975609b57333237faf..c170e36c73ad2bef830e528de3ec72d38683d888 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc @@ -57,8 +57,10 @@ TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) { const string& hlo_string = R"( HloModule InfeedOutfeed ENTRY RoundTrip16MiBR1.v2 { - ROOT infeed = bf16[4]{0} infeed() - outfeed = () outfeed(infeed) + token = token[] after-all() + infeed = (bf16[4]{0}, token[]) infeed(token) + ROOT infeed.data = bf16[4]{0} get-tuple-element(infeed), index=0 + outfeed = token[] outfeed(infeed.data, token) } )"; auto module = CreateModuleFromHloString(hlo_string); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index e0648e14672c45e9a691fd6a674c9a2cd7605a12..51353eea6e72d5a131897f3c3ae312046051103e 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -135,7 +136,6 @@ StatusOr> Compare( } // namespace - HloEvaluator::HloEvaluator(int64 max_loop_iterations) : max_loop_iterations_(max_loop_iterations) { typed_visitors_[PRED] = MakeUnique>(this); @@ -300,12 +300,6 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( instruction->CloneWithNewOperands(instruction->shape(), operands); auto result = Evaluate(cloned_instruction.get()); - // Clean up our cloned instructions before returning. - cloned_instruction->DetachFromOperands(); - for (auto& operand : owned_operands) { - operand->DetachFromOperands(); - } - return result; } @@ -321,7 +315,6 @@ StatusOr> HloEvaluator::EvaluateElementwiseBinaryOp( rhs_instr.get()); auto result = Evaluate(cloned_instruction.get()); - cloned_instruction->DetachFromOperands(); return result; } @@ -334,10 +327,27 @@ StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get()); auto result = Evaluate(cloned_instruction.get()); - cloned_instruction->DetachFromOperands(); return result; } +StatusOr> HloEvaluator::EvaluateDotOp( + const DotDimensionNumbers& dim_numbers, const Literal& lhs, + const Literal& rhs) { + std::unique_ptr lhs_instr = + HloInstruction::CreateConstant(lhs.CloneToUnique()); + std::unique_ptr rhs_instr = + HloInstruction::CreateConstant(rhs.CloneToUnique()); + + TF_ASSIGN_OR_RETURN( + Shape dot_shape, + ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers)); + + std::unique_ptr cloned_instruction = + HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(), + dim_numbers); + return Evaluate(cloned_instruction.get()); +} + Status HloEvaluator::HandleParameter(HloInstruction* parameter) { CHECK_LT(parameter->parameter_number(), arg_literals_.size()); const Literal* input_literal = arg_literals_[parameter->parameter_number()]; @@ -372,7 +382,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { // The result concatenate dimension is going to be the sum of all // concatenate dimensions of the operands taking part of the operation. const Shape& reference_shape = operands[0]->shape(); - CHECK(!ShapeUtil::IsTuple(reference_shape)); + CHECK(ShapeUtil::IsArray(reference_shape)); const int64 rank = ShapeUtil::Rank(reference_shape); const int64 concat_dim = concatenate->dimensions()[0]; CHECK_GE(concat_dim, 0); @@ -383,14 +393,14 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { for (int64 i = 1; i < operands.size(); ++i) { const Shape& operand_shape = operands[i]->shape(); - CHECK(!ShapeUtil::IsTuple(operand_shape)); + CHECK(ShapeUtil::IsArray(operand_shape)); // Accumulate the concat dimension from all tensors taking part to the // operation. concat_dimensions[concat_dim] += ShapeUtil::GetDimension(operand_shape, concat_dim); } - auto result_literal = Literal::CreateFromDimensions( + auto result_literal = LiteralUtil::CreateFromDimensions( reference_shape.element_type(), concat_dimensions); DimensionVector source_indices(rank, 0); DimensionVector dest_indices(concat_dimensions.size(), 0); @@ -541,7 +551,7 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) { operand_literals.push_back(&GetEvaluatedLiteralFor(operand)); } - evaluated_[tuple] = Literal::MakeTuple(operand_literals); + evaluated_[tuple] = LiteralUtil::MakeTuple(operand_literals); return Status::OK(); } @@ -765,6 +775,12 @@ class OutputWindowIndexToInputIndex { return ArraySlice(input_index_); } + // Returns for a given 'input_dim' the corresponding output dimension index, + // or -1 if 'input_dim' is an elided window dimension. + int64 input_dim_value_to_output_index(int64 input_dim) { + return input_dim_value_to_output_index_[input_dim]; + } + private: // Propagates window dimensions from the output index to input_index_ by // mutating input_index_ in place. @@ -782,7 +798,7 @@ class OutputWindowIndexToInputIndex { // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of // the input index from the output index. See - // PropagateOutputIndexToInputIndex. + // PropagateOutputIndexWindowDimsToInputIndex. std::vector input_dim_value_to_output_index_; // The result computed by this functor. operator() returns an ArraySlice into @@ -835,6 +851,8 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { // corresponding index in the input shape. std::vector input_index(operand.shape().dimensions_size()); std::vector output_index(gather->shape().dimensions_size()); + std::vector input_gather_index_clamped( + operand.shape().dimensions_size()); OutputGatherIndexToInputIndex output_gather_index_to_input_index( &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), @@ -856,14 +874,26 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { output_index[i] = output_gather_index[i] + output_window_index[i]; DCHECK_LT(output_index[i], shape.dimensions(i)); } + for (int i = 0, e = input_gather_index.size(); i < e; i++) { + int64 output_dim = + output_window_index_to_input_index.input_dim_value_to_output_index(i); + // If 'output_dim' is -1, it means 'i' is an elided window dim. This means + // we set the iteration index to 0, so for the purpose of the following + // calculations we can consider the output dimension size to be 1. + int64 output_dim_size = + output_dim == -1 ? 1 : shape.dimensions(output_dim); + // Clamp the gather index so that the gather region fits in the operand. + // input_gather_index_clamped[i] = clamp(input_gather_index[i], 0, + // operand_shape.dimensions(i) - + // output_dim_size); + input_gather_index_clamped[i] = + std::min(operand_shape.dimensions(i) - output_dim_size, + std::max(0LL, input_gather_index[i])); + } for (int i = 0, e = input_index.size(); i < e; i++) { - // TODO(b/74360564): We should implement whatever out of bounds behavior - // we decide for dynamic-slice here as well. - input_index[i] = (input_gather_index[i] + input_window_index[i]) % - operand_shape.dimensions(i); - if (input_index[i] < 0) { - input_index[i] += operand_shape.dimensions(i); - } + input_index[i] = input_gather_index_clamped[i] + input_window_index[i]; + DCHECK_GE(input_index[i], 0); + DCHECK_LT(input_index[i], operand_shape.dimensions(i)); } TF_RETURN_IF_ERROR( result->CopyElementFrom(operand, input_index, output_index)); @@ -910,11 +940,8 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } -Status HloEvaluator::HandleGenerateToken(HloInstruction* token) { - // Literals cannot represent a TOKEN shape so just create an empty tuple as - // the "result" of the kGenerateToken operation. - // TODO(b/109929053): Add support for TOKENs in Literals. - evaluated_[token] = Literal::MakeTuple({}); +Status HloEvaluator::HandleAfterAll(HloInstruction* token) { + evaluated_[token] = LiteralUtil::CreateToken(); return Status::OK(); } @@ -1035,8 +1062,6 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) { const auto& on_false = GetEvaluatedLiteralFor(select->operand(2)); // If predicate is of scalar type, no element-wise selection would be needed. - // This would also handle output array of tuple types as the DefaultAction - // would go through the HloEvaluatorTypedVisitor which doesn't handle tuples. if (ShapeUtil::IsScalar(pred.shape())) { if (pred.Get({})) { evaluated_[select] = on_true.CloneToUnique(); @@ -1049,6 +1074,19 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) { return DefaultAction(select); } +Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) { + const auto& pred = GetEvaluatedLiteralFor(tuple_select->operand(0)); + const auto& on_true = GetEvaluatedLiteralFor(tuple_select->operand(1)); + const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2)); + + if (pred.Get({})) { + evaluated_[tuple_select] = on_true.CloneToUnique(); + } else { + evaluated_[tuple_select] = on_false.CloneToUnique(); + } + return Status::OK(); +} + Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { HloComputation* cond_comp = while_hlo->while_condition(); HloComputation* body_comp = while_hlo->while_body(); @@ -1079,6 +1117,161 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { return Status::OK(); } +// Key-value sort is a special snowflake: it's templated on two different +// element types, one for the keys, and one for the values. Jump through some +// hoops to make this work. +namespace { +template +StatusOr> EvaluateSortInternal( + HloInstruction* sort, const Literal& keys_literal, + const Literal& values_literal) { + auto rank = ShapeUtil::Rank(keys_literal.shape()); + TF_RET_CHECK( + ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape())) + << "Sort keys and values must have the same dimensions"; + TF_RET_CHECK(rank > 0 && rank <= 2) + << "Sort is only supported for rank-1 and rank-2 shapes, rank is: " + << rank; + TF_RET_CHECK(sort->operand_count() == 2) << "Expected key-value sort"; + // We need to sort and array of keys and an array of values, where the + // sorted order of the values is determined by the keys. The simplest(?) + // way to do this is to go to an array-of-pairs representation, sort the + // array using the keys, and then go back to pair-of-arrays. + VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); + VLOG(3) << "HandleSort values_literal: " << values_literal.ToString(); + + auto sort_r1 = [](const Literal& keys_literal, + const Literal& values_literal) { + const auto& keys_data = keys_literal.data(); + const auto& values_data = values_literal.data(); + + using kv_pair = std::pair; + std::vector key_value_vector; + CHECK_EQ(keys_data.size(), values_data.size()); + key_value_vector.reserve(keys_data.size()); + for (int i = 0; i < keys_data.size(); ++i) { + key_value_vector.push_back(std::make_pair(keys_data[i], values_data[i])); + } + std::sort(key_value_vector.begin(), key_value_vector.end(), + [](const kv_pair& a, const kv_pair& b) { + return SafeLess(a.first, b.first); + }); + std::vector result_keys; + std::vector result_values; + for (const auto& key_value : key_value_vector) { + result_keys.push_back(key_value.first); + result_values.push_back(key_value.second); + } + auto result_keys_literal = MakeUnique(keys_literal.shape()); + result_keys_literal->PopulateR1( + tensorflow::gtl::ArraySlice(result_keys)); + auto result_values_literal = MakeUnique(values_literal.shape()); + result_values_literal->PopulateR1( + tensorflow::gtl::ArraySlice(result_values)); + return std::make_pair(std::move(result_keys_literal), + std::move(result_values_literal)); + }; + + std::unique_ptr result_tuple; + if (rank == 1) { + auto result_pair = sort_r1(keys_literal, values_literal); + result_tuple = LiteralUtil::MakeTuple( + {result_pair.first.get(), result_pair.second.get()}); + } else { + // For R2 sort, the desired semantics are to sort each matrix row + // independently. + auto keys_result_literal = MakeUnique(keys_literal.shape()); + auto values_result_literal = MakeUnique(values_literal.shape()); + int64 r1_length = keys_literal.shape().dimensions(1); + for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) { + TF_ASSIGN_OR_RETURN(auto keys_r1_slice, + keys_literal.Slice({row, 0}, {row + 1, r1_length}) + ->Reshape({r1_length})); + TF_ASSIGN_OR_RETURN(auto values_r1_slice, + values_literal.Slice({row, 0}, {row + 1, r1_length}) + ->Reshape({r1_length})); + auto r1_result_pair = sort_r1(*keys_r1_slice, *values_r1_slice); + TF_ASSIGN_OR_RETURN(auto sorted_keys, + r1_result_pair.first->Reshape({1, r1_length})); + TF_ASSIGN_OR_RETURN(auto sorted_values, + r1_result_pair.second->Reshape({1, r1_length})); + TF_RETURN_IF_ERROR(keys_result_literal->CopySliceFrom( + *sorted_keys, {0, 0}, {row, 0}, {1, r1_length})); + TF_RETURN_IF_ERROR(values_result_literal->CopySliceFrom( + *sorted_values, {0, 0}, {row, 0}, {1, r1_length})); + } + result_tuple = LiteralUtil::MakeTuple( + {keys_result_literal.get(), values_result_literal.get()}); + } + + VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString(); + return std::move(result_tuple); +} + +template +StatusOr> EvaluateSortCurried( + HloInstruction* sort, const Literal& keys_literal, + const Literal& values_literal) { + switch (sort->operand(1)->shape().element_type()) { + case F32: + return EvaluateSortInternal(sort, keys_literal, + values_literal); + case U32: + return EvaluateSortInternal(sort, keys_literal, + values_literal); + case S32: + return EvaluateSortInternal(sort, keys_literal, + values_literal); + case BF16: + return EvaluateSortInternal(sort, keys_literal, + values_literal); + default: + return InvalidArgument("Unsupported type for Sort"); + } +} + +StatusOr> EvaluateSort(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { + switch (sort->operand(0)->shape().element_type()) { + case F32: + return EvaluateSortCurried(sort, keys_literal, values_literal); + case U32: + return EvaluateSortCurried(sort, keys_literal, values_literal); + case S32: + return EvaluateSortCurried(sort, keys_literal, values_literal); + case BF16: + return EvaluateSortCurried(sort, keys_literal, values_literal); + default: + return InvalidArgument("Unsupported type for Sort"); + } +} +} // namespace + +Status HloEvaluator::HandleSort(HloInstruction* sort) { + const int64 sort_dim = sort->dimensions(0); + const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape()); + if (sort_dim != rank - 1) { + return Unimplemented( + "Trying to support along dimension %lld, which is not the last " + "dimension", + sort_dim); + } + + if (!ShapeUtil::IsTuple(sort->shape())) { + return DefaultAction(sort); + } else { + auto result = EvaluateSort(sort, GetEvaluatedLiteralFor(sort->operand(0)), + GetEvaluatedLiteralFor(sort->operand(1))); + if (result.ok()) { + evaluated_[sort] = std::move(result.ValueOrDie()); + return Status::OK(); + } else { + return result.status(); + } + } +} + Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index fc2fc9437b238a2e519401b2b121dfbef070e2dc..a4c37ef32827892194da070ee05ec6dc4f4c306f 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -115,6 +116,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault { StatusOr> EvaluateElementwiseUnaryOp( HloOpcode opcode, const Literal& operand); + StatusOr> EvaluateDotOp( + const DotDimensionNumbers& dim_numbers, const Literal& lhs, + const Literal& rhs); + protected: // Make HloEvaluatorTypedVisitor a friend because it is logically part of this // class. @@ -172,9 +177,13 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; + Status HandleBroadcast(HloInstruction* broadcast) override; - Status HandleGenerateToken(HloInstruction* token) override; + Status HandleAfterAll(HloInstruction* token) override; + + Status HandleSort(HloInstruction* sort) override; // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 72eb9930e92c340ab9f42cd563c27507623b2ba7..3ac6d68df30955d2e5e06e1e76d2182772151b47 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" @@ -112,9 +112,9 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp // with 3 operands. TEST_P(HloEvaluatorTest, DoesClamp) { - auto low = Literal::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); - auto value = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); - auto high = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); + auto low = LiteralUtil::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); + auto value = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + auto high = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); Shape shape = low->shape(); HloComputation::Builder b(TestName()); @@ -127,15 +127,15 @@ TEST_P(HloEvaluatorTest, DoesClamp) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({{0, 4}, {2, 4}}); + auto expected = LiteralUtil::CreateR2({{0, 4}, {2, 4}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { - auto low = Literal::CreateR0(0.f); - auto value = Literal::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); - auto high = Literal::CreateR0(1.f); + auto low = LiteralUtil::CreateR0(0.f); + auto value = LiteralUtil::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); + auto high = LiteralUtil::CreateR0(1.f); Shape shape = value->shape(); HloComputation::Builder b(TestName()); @@ -148,7 +148,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({{0, 0}, {1, 1}}); + auto expected = LiteralUtil::CreateR2({{0, 0}, {1, 1}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -156,9 +156,9 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { // Verifies that HloEvaluator evaluates a HLO instruction that performs select // with 3 operands. TEST_P(HloEvaluatorTest, DoesSelect) { - auto pred = Literal::CreateR2({{true, false}, {false, true}}); - auto on_true = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); - auto on_false = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + auto pred = LiteralUtil::CreateR2({{true, false}, {false, true}}); + auto on_true = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); + auto on_false = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); Shape shape = on_true->shape(); HloComputation::Builder b(TestName()); @@ -173,7 +173,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { std::unique_ptr result = Evaluate({}); - auto expected = Literal::CreateR2({{2, 5}, {0, 4}}); + auto expected = LiteralUtil::CreateR2({{2, 5}, {0, 4}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -181,37 +181,46 @@ TEST_P(HloEvaluatorTest, DoesSelect) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise addition with 2 operands. TEST_P(HloEvaluatorTest, DoesAdd) { - auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - auto expected = Literal::CreateR2({{3, 4}, {-96, 8}}); + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto expected = LiteralUtil::CreateR2({{3, 4}, {-96, 8}}); TestBinaryOp(HloOpcode::kAdd, std::move(expected), std::move(lhs), std::move(rhs)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise and with 2 operands. TEST_P(HloEvaluatorTest, DoesAnd) { - auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - auto expected = Literal::CreateR2({{0, 0}, {4, 4}}); + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto expected = LiteralUtil::CreateR2({{0, 0}, {4, 4}}); TestBinaryOp(HloOpcode::kAnd, std::move(expected), std::move(lhs), std::move(rhs)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise or with 2 operands. TEST_P(HloEvaluatorTest, DoesOr) { - auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - auto expected = Literal::CreateR2({{3, 4}, {-100, 4}}); + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto expected = LiteralUtil::CreateR2({{3, 4}, {-100, 4}}); TestBinaryOp(HloOpcode::kOr, std::move(expected), std::move(lhs), std::move(rhs)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise or with 2 operands. +TEST_P(HloEvaluatorTest, DoesXor) { + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto expected = LiteralUtil::CreateR2({{3, 4}, {-104, 0}}); + TestBinaryOp(HloOpcode::kXor, std::move(expected), std::move(lhs), + std::move(rhs)); +} +// Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise multiply with 2 operands. TEST_P(HloEvaluatorTest, DoesMultiply) { - auto lhs = Literal::CreateR2({{-1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2( + auto lhs = LiteralUtil::CreateR2({{-1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2( {{std::numeric_limits::min(), 4}, {4, 4}}); - auto expected = Literal::CreateR2( + auto expected = LiteralUtil::CreateR2( {{std::numeric_limits::min(), 0}, {-400, 16}}); TestBinaryOp(HloOpcode::kMultiply, std::move(expected), std::move(lhs), std::move(rhs)); @@ -219,17 +228,17 @@ TEST_P(HloEvaluatorTest, DoesMultiply) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise divide with 2 operands. TEST_P(HloEvaluatorTest, DoesDivideInt64) { - auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - auto expected = Literal::CreateR2({{0, 0}, {-25, 1}}); + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto expected = LiteralUtil::CreateR2({{0, 0}, {-25, 1}}); TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), std::move(rhs)); } TEST_P(HloEvaluatorTest, DoesDivideDouble) { - auto lhs = Literal::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); - auto rhs = Literal::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); + auto lhs = LiteralUtil::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); + auto rhs = LiteralUtil::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); auto expected = - Literal::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); + LiteralUtil::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), std::move(rhs)); } @@ -237,54 +246,54 @@ TEST_P(HloEvaluatorTest, DoesDivideDouble) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise abs op with 1 operand. TEST_P(HloEvaluatorTest, DoesAbsR2) { - auto operand = Literal::CreateR2({{1, -20}, {-100, 4}}); - auto expected = Literal::CreateR2({{1, 20}, {100, 4}}); + auto operand = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); + auto expected = LiteralUtil::CreateR2({{1, 20}, {100, 4}}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } TEST_P(HloEvaluatorTest, DoesAbsR0) { - auto operand = Literal::CreateR0(-1.0f); - auto expected = Literal::CreateR0(1.0f); + auto operand = LiteralUtil::CreateR0(-1.0f); + auto expected = LiteralUtil::CreateR0(1.0f); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } TEST_P(HloEvaluatorTest, DoesAbsR1WithZeroSize) { - auto operand = Literal::CreateR1({}); - auto expected = Literal::CreateR1({}); + auto operand = LiteralUtil::CreateR1({}); + auto expected = LiteralUtil::CreateR1({}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } TEST_P(HloEvaluatorTest, DoesNegateR2) { - auto operand = Literal::CreateR2( + auto operand = LiteralUtil::CreateR2( {{0, std::numeric_limits::min()}, {-1, 4}}); - auto expected = - Literal::CreateR2({{0, std::numeric_limits::min()}, {1, -4}}); + auto expected = LiteralUtil::CreateR2( + {{0, std::numeric_limits::min()}, {1, -4}}); TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand)); } TEST_P(HloEvaluatorTest, DoesCosR2) { - auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); - auto expected = Literal::CreateR2({{1, -1}, {-1, 1}}); + auto operand = LiteralUtil::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); + auto expected = LiteralUtil::CreateR2({{1, -1}, {-1, 1}}); TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand), use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } TEST_P(HloEvaluatorTest, DoesSinR2) { - auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); - auto expected = Literal::CreateR2({{0, 0}, {0, 0}}); + auto operand = LiteralUtil::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); + auto expected = LiteralUtil::CreateR2({{0, 0}, {0, 0}}); TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand), use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } TEST_P(HloEvaluatorTest, DoesNotR2) { auto operand = - Literal::CreateR2({{0, std::numeric_limits::min()}, - {-1, std::numeric_limits::max()}}); + LiteralUtil::CreateR2({{0, std::numeric_limits::min()}, + {-1, std::numeric_limits::max()}}); auto expected = - Literal::CreateR2({{-1, std::numeric_limits::max()}, - {0, std::numeric_limits::min()}}); + LiteralUtil::CreateR2({{-1, std::numeric_limits::max()}, + {0, std::numeric_limits::min()}}); TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand)); } // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { - auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - auto rhs2 = Literal::CreateR2({{1, -20}, {-100, 4}}); + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto rhs2 = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); std::vector args = {lhs.get(), rhs.get(), rhs2.get()}; Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); @@ -305,7 +314,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { std::unique_ptr result = Evaluate(args); - auto expected = Literal::CreateR2({{4, -16}, {-196, 12}}); + auto expected = LiteralUtil::CreateR2({{4, -16}, {-196, 12}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -315,7 +324,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { HloComputation::Builder b(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - Literal::CreateRandomLiteral( + LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); auto literal_clone = literal->CloneToUnique(); HloInstruction* literal_instruction = @@ -340,8 +349,8 @@ TEST_P(HloEvaluatorTest, DoesReshape) { // Verifies Broadcast operation is correctly evaluated. TEST_P(HloEvaluatorTest, DoesBroadcast) { HloComputation::Builder b(TestName()); - auto input_literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); - auto output_literal = Literal::CreateR3( + auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + auto output_literal = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{1, 2}, {3, 4}, {5, 6}}}); HloInstruction* literal_instruction = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -356,8 +365,8 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { HloComputation::Builder b(TestName()); - auto input_literal = Literal::CreateR0(111); - auto output_literal = Literal::CreateR2( + auto input_literal = LiteralUtil::CreateR0(111); + auto output_literal = LiteralUtil::CreateR2( {{111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}}); HloInstruction* literal_instruction = b.AddInstruction( @@ -377,9 +386,9 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{-1, -2}, {100, 200}}))); + LiteralUtil::CreateR2({{-1, -2}, {100, 200}}))); HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{-2, -3}, {-100, -200}}))); + LiteralUtil::CreateR2({{-2, -3}, {-100, -200}}))); std::vector operands = {operand1, operand2}; @@ -390,8 +399,8 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { std::unique_ptr result = Evaluate(); - auto expected = - Literal::CreateR2({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); + auto expected = LiteralUtil::CreateR2( + {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -399,9 +408,9 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({100, 200}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({100, 200}))); HloInstruction* operand2 = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); std::vector operands = {operand1, operand2}; @@ -412,16 +421,16 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR1({100, 200}); + auto expected = LiteralUtil::CreateR1({100, 200}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { HloComputation::Builder b(TestName()); - auto input_literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); auto expected = - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), expected->shape())); @@ -438,9 +447,9 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { HloComputation::Builder b(TestName()); - auto input_literal = Literal::CreateR2WithLayout( + auto input_literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1})); - auto expected = Literal::CreateR2WithLayout( + auto expected = LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0})); ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), expected->shape())); @@ -469,13 +478,13 @@ PaddingConfig CreatePaddingConfig( } TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { - auto operand = Literal::CreateR2({{}, {}}); + auto operand = LiteralUtil::CreateR2({{}, {}}); HloComputation::Builder b(TestName()); auto operand_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); constexpr int32 kPadValue = 10; - auto pad_value = Literal::CreateR0(kPadValue); + auto pad_value = LiteralUtil::CreateR0(kPadValue); auto padding_value_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(pad_value))); @@ -487,7 +496,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2( + auto expected = LiteralUtil::CreateR2( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); @@ -497,11 +506,11 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { HloComputation::Builder b(TestName()); Array4D input_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); - auto input = Literal::CreateR4FromArray4D(input_array); + auto input = LiteralUtil::CreateR4FromArray4D(input_array); HloInstruction* input_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); constexpr float kPadValue = 1.5; - auto pad_value = Literal::CreateR0(kPadValue); + auto pad_value = LiteralUtil::CreateR0(kPadValue); HloInstruction* pad_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(pad_value))); @@ -523,7 +532,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { (*expected_array)(7, 0, 0, 0) = 5.0f; (*expected_array)(7, 2, 0, 0) = 6.0f; - auto expected = Literal::CreateR4FromArray4D(*expected_array); + auto expected = LiteralUtil::CreateR4FromArray4D(*expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -540,12 +549,12 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { // } auto input_array = MakeUnique>(4, 3); input_array->FillUnique(1.0f); - auto input = Literal::CreateR2FromArray2D(*input_array); + auto input = LiteralUtil::CreateR2FromArray2D(*input_array); HloInstruction* input_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); auto pad_value_instruction = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.718f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.718f))); auto r2_padding_on_dim0_dim1 = CreatePaddingConfig({{{-1, -2, 0}}, {{-2, 4, 0}}}); @@ -565,7 +574,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 2) = 2.718f; (*expected_array)(0, 3) = 2.718f; (*expected_array)(0, 4) = 2.718f; - auto expected = Literal::CreateR2FromArray2D(*expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250))); } @@ -581,12 +590,12 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { // } auto input_array = MakeUnique>(4, 3); input_array->FillUnique(1.0f); - auto input = Literal::CreateR2FromArray2D(*input_array); + auto input = LiteralUtil::CreateR2FromArray2D(*input_array); HloInstruction* input_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); auto pad_value_instruction = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.718f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.718f))); PaddingConfig padding_config = MakeNoPaddingConfig(2); @@ -604,7 +613,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { std::unique_ptr result = Evaluate(); auto expected_array = MakeUnique>(0, 9); - auto expected = Literal::CreateR2FromArray2D(*expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -621,13 +630,13 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // } auto lhs_array = MakeUnique>(4, 1); lhs_array->FillUnique(1.0f); - auto lhs_literal = Literal::CreateR2FromArray2D(*lhs_array); + auto lhs_literal = LiteralUtil::CreateR2FromArray2D(*lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); // rhs: // f32[2] { 1, 2 }, - auto rhs_literal = Literal::CreateR2({{1, 2}}); + auto rhs_literal = LiteralUtil::CreateR2({{1, 2}}); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -649,7 +658,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { {4.f, 8.f}, }); // clang-format on - auto expected = Literal::CreateR2FromArray2D(expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -660,7 +669,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { // lhs: // f32[3] // { 1, 2, 3 }, - auto lhs_literal = Literal::CreateR1({1, 2, 3}); + auto lhs_literal = LiteralUtil::CreateR1({1, 2, 3}); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -672,7 +681,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { // } auto rhs_array = MakeUnique>(3, 2); rhs_array->FillUnique(1.0f); - auto rhs_literal = Literal::CreateR2FromArray2D(*rhs_array); + auto rhs_literal = LiteralUtil::CreateR2FromArray2D(*rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -686,7 +695,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR1({22.f, 28.f}); + auto expected = LiteralUtil::CreateR1({22.f, 28.f}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -703,7 +712,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { // } auto lhs_array = MakeUnique>(4, 3); lhs_array->FillUnique(1.0f); - auto lhs_literal = Literal::CreateR2FromArray2D(*lhs_array); + auto lhs_literal = LiteralUtil::CreateR2FromArray2D(*lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -715,7 +724,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { // } auto rhs_array = MakeUnique>(3, 2); rhs_array->FillUnique(1.0f); - auto rhs_literal = Literal::CreateR2FromArray2D(*rhs_array); + auto rhs_literal = LiteralUtil::CreateR2FromArray2D(*rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -735,7 +744,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { {94.f, 124.f}, {130.f, 172.f}, }); - auto expected = Literal::CreateR2FromArray2D(expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -744,12 +753,12 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { HloComputation::Builder b(TestName()); Array3D lhs_array = {{{1, 2, 3}}}; - auto lhs_literal = Literal::CreateR3FromArray3D(lhs_array); + auto lhs_literal = LiteralUtil::CreateR3FromArray3D(lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); Array3D rhs_array = {{{3.f, 4.f}}}; - auto rhs_literal = Literal::CreateR3FromArray3D(rhs_array); + auto rhs_literal = LiteralUtil::CreateR3FromArray3D(rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -783,7 +792,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { std::unique_ptr result = Evaluate(); Array3D expected_array = {{{11.f, 18.f, 9.f}}}; - auto expected = Literal::CreateR3FromArray3D(expected_array); + auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -800,7 +809,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { {13, 14, 15, 16}, })); // clang-format on - auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -811,7 +820,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { {7, 8}, })); // clang-format on - auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -845,7 +854,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { {149, 160, 171, 80}, })); // clang-format on - auto expected = Literal::CreateR4FromArray4D(expected_array); + auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -875,11 +884,11 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { }}); // clang-format on - auto lhs_literal = Literal::CreateR4FromArray4D(input); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(input); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); - auto rhs_literal = Literal::CreateR4FromArray4D(weight); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(weight); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); rhs_instruction = b.AddInstruction(HloInstruction::CreateReverse( @@ -924,7 +933,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { Array4D expected_array({{{{2514, 2685}}}}); Array4D expected_array_bf16({{{{2512, 2672}}}}); // clang-format on - auto expected = Literal::CreateR4FromArray4D( + auto expected = LiteralUtil::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); @@ -955,11 +964,11 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { }}); // clang-format on - auto lhs_literal = Literal::CreateR4FromArray4D(input); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(input); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); - auto rhs_literal = Literal::CreateR4FromArray4D(weight); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(weight); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -1001,7 +1010,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { Array4D expected_array({{{{2514, 2685}}}}); Array4D expected_array_bf16({{{{2512, 2672}}}}); // clang-format on - auto expected = Literal::CreateR4FromArray4D( + auto expected = LiteralUtil::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); @@ -1019,7 +1028,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { {13, 14, 15, 16}, })); // clang-format on - auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -1030,7 +1039,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { {7, 8}, })); // clang-format on - auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -1065,7 +1074,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { {91, 112, 98, 120, 105, 128, 112}, {65, 84, 70, 90, 75, 96, 80}, })); - auto expected = Literal::CreateR4FromArray4D(expected_array); + auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -1082,7 +1091,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { {13, 14, 15, 16}, })); // clang-format on - auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -1093,7 +1102,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { {7, 8}, })); // clang-format on - auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -1129,7 +1138,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { {104, 91, 112, 98, 120, 105, 128, 112}, {78, 65, 84, 70, 90, 75, 96, 80}, })); - auto expected = Literal::CreateR4FromArray4D(expected_array); + auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -1147,7 +1156,7 @@ TEST_P(HloEvaluatorTest, {13, 14, 15, 16}, })); // clang-format on - auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -1158,7 +1167,7 @@ TEST_P(HloEvaluatorTest, {8, 9, 10}, })); // clang-format on - auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -1201,7 +1210,7 @@ TEST_P(HloEvaluatorTest, {0, 0, 0}, {91, 98, 105}, })); - auto expected = Literal::CreateR4FromArray4D(expected_array); + auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -1216,9 +1225,9 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 std::vector v(kNumElements, 1.0f); HloInstruction* arg_instruction = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1(v))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1(v))); HloInstruction* init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); HloComputation::Builder add_computation("add"); Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1253,9 +1262,9 @@ void BM_ReducePrecisely(int num_iters) { constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 std::vector v(kNumElements, 1.0f); HloInstruction* arg_instruction = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1(v))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1(v))); auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); HloComputation::Builder add_computation("add"); Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1290,13 +1299,13 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { // } auto arg_array = MakeUnique>(2, 3); arg_array->FillUnique(1.0f); - auto arg_literal = Literal::CreateR2FromArray2D(*arg_array); + auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); HloInstruction* arg_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); HloComputation::Builder add_computation("add"); Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1317,7 +1326,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR1({6, 18}); + auto expected = LiteralUtil::CreateR1({6, 18}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -1332,13 +1341,13 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { // } auto arg_array = MakeUnique>(2, 3); arg_array->FillUnique(1.0f); - auto arg_literal = Literal::CreateR2FromArray2D(*arg_array); + auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); HloInstruction* arg_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); HloComputation::Builder max_computation("max"); Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1369,7 +1378,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({{6, 7}}); + auto expected = LiteralUtil::CreateR2({{6, 7}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -1383,13 +1392,13 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { // } auto arg_array = MakeUnique>(2, 3); arg_array->FillUnique(1.0f); - auto arg_literal = Literal::CreateR2FromArray2D(*arg_array); + auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); HloInstruction* arg_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); HloComputation::Builder add_computation("add"); Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1426,7 +1435,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({{1, 3, 5}, {5, 11, 13}}); + auto expected = LiteralUtil::CreateR2({{1, 3, 5}, {5, 11, 13}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -1436,13 +1445,13 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time. std::vector input_dims(6, 4); std::unique_ptr arg_literal = - Literal::CreateFullWithDescendingLayout(input_dims, 1.0f); + LiteralUtil::CreateFullWithDescendingLayout(input_dims, 1.0f); HloInstruction* arg_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); HloComputation::Builder add_computation("add"); Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1489,7 +1498,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { std::vector output_dims = {4, 3, 3, 3, 4, 4}; std::unique_ptr result_literal = - Literal::CreateFullWithDescendingLayout(output_dims, 8.0f); + LiteralUtil::CreateFullWithDescendingLayout(output_dims, 8.0f); EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result)); } @@ -1504,7 +1513,8 @@ TEST_P(HloEvaluatorTest, StridedSlice) { // } auto operand_array = MakeUnique>(3, 5); operand_array->FillUnique(1.0f); - auto operand_literal = Literal::CreateR2FromArray2D(*operand_array); + auto operand_literal = + LiteralUtil::CreateR2FromArray2D(*operand_array); HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); @@ -1518,7 +1528,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({ + auto expected = LiteralUtil::CreateR2({ {3}, {19}, }); @@ -1536,13 +1546,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { // } auto operand_array = MakeUnique>(2, 4); operand_array->FillUnique(1.0f); - auto operand_literal = Literal::CreateR2FromArray2D(*operand_array); + auto operand_literal = + LiteralUtil::CreateR2FromArray2D(*operand_array); HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, @@ -1551,7 +1562,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({ + auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, {6, 7, 8}, }); @@ -1571,13 +1582,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { // } auto operand_array = MakeUnique>(2, 4); operand_array->FillUnique(1.0f); - auto operand_literal = Literal::CreateR2FromArray2D(*operand_array); + auto operand_literal = + LiteralUtil::CreateR2FromArray2D(*operand_array); HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2, 1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2, 1}))); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, @@ -1586,7 +1598,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({ + auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, {6, 7, 8}, }); @@ -1604,16 +1616,17 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { // } auto operand_array = MakeUnique>(2, 3); operand_array->FillUnique(1.0); - auto operand_literal = Literal::CreateR2FromArray2D(*operand_array); + auto operand_literal = + LiteralUtil::CreateR2FromArray2D(*operand_array); HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); auto update = b.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{-2.0, -3.0}, {-6.0, -7.0}}))); + LiteralUtil::CreateR2({{-2.0, -3.0}, {-6.0, -7.0}}))); Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( @@ -1622,7 +1635,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({ + auto expected = LiteralUtil::CreateR2({ {1, -2, -3}, {5, -6, -7}, }); @@ -1640,12 +1653,13 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { // } auto operand_array = MakeUnique>(2, 3); operand_array->FillUnique(1.0); - auto operand_literal2 = Literal::CreateR2FromArray2D(*operand_array); + auto operand_literal2 = + LiteralUtil::CreateR2FromArray2D(*operand_array); HloInstruction* operand2 = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal2))); HloInstruction* operand1 = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); auto tuple = b.AddInstruction(HloInstruction::CreateTuple({operand1, operand2})); @@ -1657,7 +1671,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({ + auto expected = LiteralUtil::CreateR2({ {1, 2, 3}, {5, 6, 7}, }); @@ -1677,9 +1691,9 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { operand_array->FillUnique(1.0); HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2FromArray2D(*operand_array))); + LiteralUtil::CreateR2FromArray2D(*operand_array))); HloInstruction* operand1 = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); auto tuple1 = b.AddInstruction(HloInstruction::CreateTuple({operand1, operand2})); @@ -1697,8 +1711,8 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { std::unique_ptr result = Evaluate(); auto result_inner_literal = - Literal::CreateR2FromArray2D(*operand_array); - auto expected = Literal::MakeTuple({ + LiteralUtil::CreateR2FromArray2D(*operand_array); + auto expected = LiteralUtil::MakeTuple({ result_inner_literal.get(), result_inner_literal.get(), }); @@ -1726,7 +1740,7 @@ TEST_P(HloEvaluatorTest, Reverse) { {{23.0f}, {24.0f}}}, }); // clang-format on - auto operand_literal = Literal::CreateR4FromArray4D(input); + auto operand_literal = LiteralUtil::CreateR4FromArray4D(input); HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); @@ -1737,7 +1751,7 @@ TEST_P(HloEvaluatorTest, Reverse) { std::unique_ptr result = Evaluate(); // clang-format off - auto expected = Literal::CreateR4FromArray4D({ + auto expected = LiteralUtil::CreateR4FromArray4D({ {{{23.0f}, {24.0f}}, {{21.0f}, {22.0f}}, {{19.0f}, {20.0f}}}, @@ -1773,11 +1787,11 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { // Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}. HloEvaluator evaluator; auto result = evaluator.EvaluateWithSubstitutions( - add, {{param0, Literal::CreateR1({1, 2, 3, 4}).get()}, - {square, Literal::CreateR1({10, 20, 30, 40}).get()}}); + add, {{param0, LiteralUtil::CreateR1({1, 2, 3, 4}).get()}, + {square, LiteralUtil::CreateR1({10, 20, 30, 40}).get()}}); TF_ASSERT_OK(result.status()); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); + *LiteralUtil::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); } // Check that EvaluateWithSubstitutions works if one of the operands to the op @@ -1790,18 +1804,18 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param0")); HloInstruction* square = b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kMultiply, param0, param0)); - HloInstruction* constant = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + HloInstruction* constant = b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); HloInstruction* add = b.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, constant, square)); // Evaluate add with square = {10, 20, 30, 40}. HloEvaluator evaluator; auto result = evaluator.EvaluateWithSubstitutions( - add, {{square, Literal::CreateR1({10, 20, 30, 40}).get()}}); + add, {{square, LiteralUtil::CreateR1({10, 20, 30, 40}).get()}}); TF_ASSERT_OK(result.status()); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); + *LiteralUtil::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { @@ -1821,11 +1835,12 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); - EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}), - *Evaluate({operand.get(), gather_indices.get()}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { @@ -1845,10 +1860,11 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR2({{1, 3}, {4, 6}, {7, 9}}), + *LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -1869,11 +1885,11 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 2}, {2, 1}}); + LiteralUtil::CreateR2({{0, 2}, {2, 1}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR3( + *LiteralUtil::CreateR3( {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -1895,13 +1911,13 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{-1, 1}, {-4, 4}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -1923,13 +1939,13 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{-2, 2}, {-1, 1}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -1950,10 +1966,11 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({1, 1}); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{5}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{5}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -1974,11 +1991,11 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR2({{2, 1}, {1, 1}}); + LiteralUtil::CreateR2({{2, 1}, {1, 1}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR3({{{8}}, {{5}}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR3({{{8}}, {{5}}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -1998,10 +2015,11 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = Literal::CreateR2({{}, {}, {}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{}, {}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{}, {}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -2022,26 +2040,479 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = Literal::CreateR1({0, 1, 2}); + std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); std::unique_ptr gather_indices = - Literal::CreateR3({{{0}, {1}}, {{2}, {1}}}); + LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{0, 1}, {2, 1}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{0, 1}, {2, 1}}), *Evaluate({operand.get(), gather_indices.get()}))); } +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV2 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=mul_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_f32 (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(f32[] lhs, f32[] rhs) +} + +ENTRY main { + operand = f32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = f32[2,3] parameter(2) + ROOT scatter = f32[3,3] scatter(operand, indices, updates), + to_apply=add_f32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = LiteralUtil::CreateR2( + {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({2, 1}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); + EXPECT_TRUE(LiteralTestUtil::Near( + *LiteralUtil::CreateR2( + {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}), + ErrorSpec{0.1, 0.01})); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({1, 1}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { + const char* hlo_text = R"( +HloModule TensorFlowScatterMultipleBatchDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=2 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + std::unique_ptr updates = LiteralUtil::CreateR3( + {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNd + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); + std::unique_ptr expected = + LiteralUtil::CreateR3({{{-10, 10}, {-2, 2}, {-3, 3}}, // + {{-40, 40}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, + EvaluateScatter_TensorFlowScatterNd_NonDefaultIndexVectorDim) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNdNonDefaultIndexVectorDim + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); + std::unique_ptr expected = + LiteralUtil::CreateR3({{{-20, 20}, {-10, 10}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule DynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[1,1] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0,1}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({1, 1}); + std::unique_ptr updates = LiteralUtil::CreateR2({{10}}); + std::unique_ptr expected = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule BatchDynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,1,1] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + std::unique_ptr updates = + LiteralUtil::CreateR3({{{10}}, {{20}}}); + std::unique_ptr expected = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { + const char* hlo_text = R"( +HloModule TensorFlowScatter_ZeroDimBounds + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,0] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,0] parameter(2) + ROOT scatter = s32[3,0] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = LiteralUtil::CreateR2({{}, {}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *operand, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { + const string hlo_text = R"( +HloModule Scatter_NoUpdateWindowDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[2,2,1] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2 +} +)"; + ParseAndVerifyModule(hlo_text); + + std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20}, {30, 40}}); + std::unique_ptr expected = + LiteralUtil::CreateR1({10, 61, 32}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise comparison with 2 bfloat16 operands. TEST_P(HloEvaluatorTest, DoesCompareBF16) { // lhs >= rhs - auto lhs = Literal::CreateR2( + auto lhs = LiteralUtil::CreateR2( {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)}, {bfloat16(-0.25), bfloat16(-0.35), bfloat16(-0.125)}}); - auto rhs = Literal::CreateR2( + auto rhs = LiteralUtil::CreateR2( {{bfloat16(0.5), bfloat16(0.125), bfloat16(0.125)}, {bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}}); auto expected = - Literal::CreateR2({{false, true, true}, {false, true, true}}); + LiteralUtil::CreateR2({{false, true, true}, {false, true, true}}); TestBinaryOp(HloOpcode::kGe, std::move(expected), std::move(lhs), std::move(rhs)); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 13f46407e33e36bdbef4c9032630101d6c18268f..084b49b4783fe15e91917317d8b3746e2c7569d0 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/core/lib/core/casts.h" @@ -34,6 +35,37 @@ using is_complex_t = std::is_same; template using is_complex64_t = std::is_same; +// It's UB to use std::sort with std::less, because of NaNs. Define +// "safe" less functions which are actually strict weak orders. +template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> +bool SafeLess(const NativeT& a, const NativeT& b) { + return a < b; +} + +template ::value || + std::is_same::value>::type* = nullptr> +bool SafeLess(const NativeT& a, const NativeT& b) { + if (std::isnan(b)) { + return !std::isnan(a); + } else { + return a < b; + } +} + +template ::value>::type* = nullptr> +bool SafeLess(const NativeT& a, const NativeT& b) { + if (Eigen::half_impl::isnan(b)) { + return !Eigen::half_impl::isnan(a); + } else { + return a < b; + } +} + // Templated DfsHloVisitor for use by HloEvaluator. // // Typically ReturnT here indicates the resulting literal type of each evaluated @@ -269,6 +301,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleFloor(floor); } + Status HandleImag(HloInstruction* imag) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[imag], + ElementWiseUnaryOp(imag, [](ElementwiseT elem_operand) { + return std::imag(elem_operand); + })); + return Status::OK(); + } + Status HandleLog(HloInstruction* log) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { @@ -572,6 +612,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleReal(HloInstruction* real) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[real], + ElementWiseUnaryOp(real, [](ElementwiseT elem_operand) { + return std::real(elem_operand); + })); + return Status::OK(); + } + template < typename NativeT, typename std::enable_if::value>::type* = nullptr> @@ -610,12 +658,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleAnd(HloInstruction* and_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[and_], - ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el && rhs_el; - })); - return Status::OK(); + return InvalidArgument("Unsupported type for And"); } template < @@ -644,12 +687,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> Status HandleOr(HloInstruction* or_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[or_], - ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el || rhs_el; - })); - return Status::OK(); + return InvalidArgument("Unsupported type for Or"); } template < @@ -663,6 +701,35 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleOr(or_); } + template ::value>::type* = + nullptr> + Status HandleXor(HloInstruction* xor_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[xor_], + ElementWiseBinaryOp(xor_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el ^ rhs_el; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleXor(HloInstruction* xor_) { + return InvalidArgument("Unsupported type for Xor"); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleXor(HloInstruction* xor_) { + return InvalidArgument("Unsupported type for Xor"); + } + + Status HandleXor(HloInstruction* xor_) override { + return HandleXor(xor_); + } + template ::value && @@ -778,7 +845,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override { CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); - CHECK(!ShapeUtil::IsTuple(select->shape())); + CHECK(ShapeUtil::IsArray(select->shape())); std::function select_op = [](bool pred, ReturnT on_true, ReturnT on_false) { if (pred) { @@ -1006,83 +1073,47 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { CHECK_EQ(dnums.lhs_batch_dimensions_size(), dnums.rhs_batch_dimensions_size()); - std::vector lhs_non_contracting_dims; + DimensionVector lhs_index(lhs_rank); + DimensionVector rhs_index(rhs_rank); + + // result_index_locations[i] contains one or two pointers to the locations + // in lhs_index or rhs_index where the i'th result index should go. + tensorflow::gtl::InlinedVector, kInlineRank> + result_index_locations; + result_index_locations.reserve(lhs_rank + rhs_rank - 2); + + // The first components in the output shape are the LHS and RHS batch + // dimensions: + for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); i++) { + result_index_locations.push_back( + {&lhs_index[dnums.lhs_batch_dimensions(i)], + &rhs_index[dnums.rhs_batch_dimensions(i)]}); + } + + // Then we have the LHS and RHS non-contracting dimensions, if any: for (int64 i = 0; i < lhs_rank; i++) { - if (i != lhs_contracting_dimension) { - lhs_non_contracting_dims.push_back(i); + if (i != lhs_contracting_dimension && + !ArrayContains(AsInt64Slice(dnums.lhs_batch_dimensions()), i)) { + result_index_locations.push_back({&lhs_index[i], nullptr}); } } - - std::vector rhs_non_batch_non_contracting_dims; - tensorflow::gtl::FlatSet batch_dims_set( - dnums.rhs_batch_dimensions().begin(), - dnums.rhs_batch_dimensions().end()); for (int64 i = 0; i < rhs_rank; i++) { - if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) { - rhs_non_batch_non_contracting_dims.push_back(i); + if (i != rhs_contracting_dimension && + !ArrayContains(AsInt64Slice(dnums.rhs_batch_dimensions()), i)) { + result_index_locations.push_back({&rhs_index[i], nullptr}); } } - const int64 batch_dim_size = dnums.lhs_batch_dimensions_size(); - const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size(); - - DimensionVector lhs_index(lhs_rank); - DimensionVector rhs_index(rhs_rank); auto result = MakeUnique(dot->shape()); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice result_index) { ElementwiseT result_val = static_cast(0); - // Find the corresponding non-contracting indices for lhs and rhs. - // - // For `result_index`, its batch dimension, if exists, will be at the - // same dimension as the batch dimension of lhs and rhs. More - // specifically: - // - For lhs, the non-contracting dimensions, including the batch - // dimension have the same index as the `result_index`. - // - For rhs, the batch dimension is set seperately from other - // non-contracting dimensions, since these other non-contracting - // dimensions in rhs follow the non-contracting dimensions of lhs in - // the resulting index. - // - // As an example, for a resulting index: - // result_index [result_batch, result_x, result_y] - // the effecting lhs and rhs indices are: - // lhs [result_batch, lhs_non_contracting_dim, contracting_dim - // rhs [result_batch, contracting_dim, rhs_non_contracting_dim] - // `result_x` is only affected by the lhs_non_contracting_dim and - // likewise `result_y` only depends on rhs_non_contracting_dim. - // - // so we can look up the lhs and rhs indices by: - // - // lhs: - // batch index is the same as `result_batch`. - // non-contracting dimension is the same as - // result_index[lhs_non_contracting_dim] - // rhs: - // batch index: the same as `result_batch`. - // non-contracting dimension index: *not* the same as - // result_index[rhs_non_contractng_dim], since the - // non-contracting dimensions of lhs are included in the - // result_index first. Instead, the non_contracting_dim of rhs must - // be calculated as following: - // lhs_non_contracting_dimensions_size + - // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1 - // - // Note that (rhs_non_batch_contracting_dim - batch_dim_size) is - // the index offset to the result_index that only depends on - // the non_batch and non-contracting dimensions of rhs. -1 at the - // end translates size to index. - for (auto i : lhs_non_contracting_dims) { - lhs_index[i] = result_index[i]; - } - for (auto i : dnums.rhs_batch_dimensions()) { - rhs_index[i] = result_index[i]; - } - for (auto i : rhs_non_batch_non_contracting_dims) { - const int64 rhs_non_batch_non_contracting_dim = - lhs_non_contracting_size + (i - batch_dim_size) - 1; - rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim]; + for (int64 i = 0; i < result_index.size(); i++) { + *result_index_locations[i].first = result_index[i]; + if (result_index_locations[i].second) { + *result_index_locations[i].second = result_index[i]; + } } // Accumulates resulting product along the contracted dimension. @@ -1103,7 +1134,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandlePad(HloInstruction* pad) override { - CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape())); + CHECK(ShapeUtil::IsArray(pad->operand(0)->shape())); // Padding value must be scalar. CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()), @@ -1116,7 +1147,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { /*padding_config=*/pad->padding_config())); CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(pad->shape()) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); // Create new HLO of padded shape with padding value. @@ -1182,7 +1213,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { dynamic_slice->dynamic_slice_sizes())); TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(result_shape) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); TF_RET_CHECK( primitive_util::IsIntegralType(start_indices->shape().element_type())); @@ -1237,7 +1268,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { operand->shape(), update->shape(), start_indices->shape())); TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(result_shape) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); TF_RET_CHECK( primitive_util::IsIntegralType(start_indices->shape().element_type())); @@ -1302,7 +1333,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { parent_->GetEvaluatedLiteralFor(operand); auto curr_val = arg_literal.Get(multi_index); - auto curr_val_literal = Literal::CreateR0(curr_val); + auto curr_val_literal = LiteralUtil::CreateR0(curr_val); arg_literals.push_back(std::move(curr_val_literal)); } @@ -1378,7 +1409,74 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleSort(HloInstruction* sort) { + auto keys = sort->operand(0); + auto rank = ShapeUtil::Rank(keys->shape()); + TF_RET_CHECK(rank > 0 && rank <= 2) + << "Sort is only supported for R1 and R2 shapes"; + TF_RET_CHECK(sort->operand_count() == 1) + << "Typed visitor does not support key-value sort"; + + const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys); + + auto sort_r1 = [this](const Literal& keys_literal) { + VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); + const auto& keys_data = keys_literal.data(); + + std::vector result_data(keys_data.begin(), keys_data.end()); + std::sort(result_data.begin(), result_data.end(), + [](const ReturnT& a, const ReturnT& b) { + return SafeLess(a, b); + }); + auto result_literal = MakeUnique(keys_literal.shape()); + result_literal->PopulateR1( + tensorflow::gtl::ArraySlice(result_data)); + VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); + return result_literal; + }; + + if (rank == 1) { + parent_->evaluated_[sort] = std::move(sort_r1(keys_literal)); + } else { + // For R2 sort, the desired semantics are to sort each matrix row + // independently. + auto result_literal = MakeUnique(keys_literal.shape()); + int64 r1_length = keys->shape().dimensions(1); + for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { + TF_ASSIGN_OR_RETURN(auto r1_slice, + keys_literal.Slice({row, 0}, {row + 1, r1_length}) + ->Reshape({r1_length})); + auto r1_result = sort_r1(*r1_slice); + TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length})); + TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( + *r1_result, {0, 0}, {row, 0}, {1, r1_length})); + } + parent_->evaluated_[sort] = std::move(result_literal); + } + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleSort(HloInstruction* sort) { + return InvalidArgument("Unsupported type for Sort"); + } + + Status HandleSort(HloInstruction* sort) override { + return HandleSort(sort); + } + Status HandleReduce(HloInstruction* reduce) override { + // TODO(b/112040122): Support variadic reduce. + if (!ShapeUtil::IsArray(reduce->shape())) { + return Unimplemented("Variadic reduce is not supported in the Evaluator"); + } auto arg = reduce->operand(0); auto init_value = reduce->operand(1); tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); @@ -1387,13 +1485,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { ShapeUtil::Rank(arg->shape()) - dimensions.size()); TF_ASSIGN_OR_RETURN(auto inferred_return_shape, ShapeInference::InferReduceShape( - /*arg=*/arg->shape(), - /*init_value=*/init_value->shape(), + {&arg->shape(), &init_value->shape()}, /*dimensions_to_reduce=*/dimensions, /*to_apply=*/function->ComputeProgramShape())); TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape()) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); @@ -1450,8 +1547,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto curr_val = arg_literal.Get(input_index); // Evaluate computation with specified literal operands. - auto curr_val_literal = Literal::CreateR0(curr_val); - auto result_val_literal = Literal::CreateR0(result_val); + auto curr_val_literal = LiteralUtil::CreateR0(curr_val); + auto result_val_literal = + LiteralUtil::CreateR0(result_val); std::unique_ptr computed_result = embedded_evaluator @@ -1529,10 +1627,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Used in the dual IterateThroughWindow lambdas below. Hoisted to avoid // dynamic memory allocations. - auto curr_val_literal = Literal::CreateR0(ReturnT()); - auto selected_val_literal = Literal::CreateR0(ReturnT()); - auto source_literal_scatter = Literal::CreateR0(ReturnT()); - auto scattered_literal = Literal::CreateR0(ReturnT()); + auto curr_val_literal = LiteralUtil::CreateR0(ReturnT()); + auto selected_val_literal = LiteralUtil::CreateR0(ReturnT()); + auto source_literal_scatter = LiteralUtil::CreateR0(ReturnT()); + auto scattered_literal = LiteralUtil::CreateR0(ReturnT()); do { // For each element in `source`, we place a window in `operand`. For each // window placement, we iterate inside the window twice: @@ -1613,7 +1711,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanStringWithLayout(reduce_window->shape()) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanStringWithLayout(inferred_return_shape); const Literal& operand_literal = @@ -1653,9 +1751,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Evaluate computation with specified literal operands. const auto curr_val_literal = - Literal::CreateR0(curr_val); + LiteralUtil::CreateR0(curr_val); const auto result_val_literal = - Literal::CreateR0(result_val); + LiteralUtil::CreateR0(result_val); std::unique_ptr computed_result = embedded_evaluator .Evaluate( @@ -1677,6 +1775,388 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + // Reshapes the scatter indices input to have a trailing degenerate `1` + // dimension if necessary. Hands over the ownership of the newly created + // literal (if there is one) to `reshaped_indices`. + StatusOr> ReshapedScatterIndices( + int64 index_vector_dim, const Literal& indices, + std::unique_ptr* reshaped_indices) { + if (indices.shape().dimensions_size() != index_vector_dim) { + return std::cref(indices); + } + + std::vector new_shape(indices.shape().dimensions().begin(), + indices.shape().dimensions().end()); + new_shape.push_back(1); + TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape)); + return std::cref(**reshaped_indices); + } + + // Returns an ShapeUtil::IndexIterationSpace that iterates over the update + // scatter dimensions while keeping the rest of the update dimensions clamped + // to 0. + ShapeUtil::IndexIterationSpace IterationSpaceForUpdateScatterIndices( + const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { + int64 updates_rank = updates_shape.dimensions_size(); + std::vector index_base(updates_rank, 0); + std::vector index_count(updates_rank, 1); + for (int64 i = 0; i < updates_rank; i++) { + bool is_update_scatter_dim = + !c_binary_search(dim_numbers.update_window_dims(), i); + if (is_update_scatter_dim) { + index_count[i] = updates_shape.dimensions(i); + } + } + return {std::move(index_base), std::move(index_count), + std::vector(updates_rank, 1)}; + } + + // Return an ShapeUtil::IndexIterationSpace that iterates over the update + // window dimensions while keeping the rest of the update dimensions clamped + // to 0. + ShapeUtil::IndexIterationSpace IterationSpaceForUpdateWindowIndices( + const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { + int64 updates_rank = updates_shape.dimensions_size(); + std::vector index_base(updates_rank, 0); + std::vector index_count(updates_rank, 1); + for (int64 i = 0; i < updates_rank; i++) { + bool is_update_window_dim = + c_binary_search(dim_numbers.update_window_dims(), i); + if (is_update_window_dim) { + index_count[i] = updates_shape.dimensions(i); + } + } + return {std::move(index_base), std::move(index_count), + std::vector(updates_rank, 1)}; + } + + // This functor computes the contribution of scatter_indices to an input index + // corresponding to an update index. That is, given an update index I, it + // picks out the scatter indices in I and uses them to look up a scatter + // index, S, from the scatter indices tensor, and expands S into the input + // space according to scatter_dims_to_operand_dims. + // + // This is similar to the class HloEvaluator::OutputGatherIndexToInputIndex + // that does the corresponding function for Gather. + class UpdateScatterIndexToInputIndex { + public: + // The constructor does some setup work that is amortized across all + // iterations. + explicit UpdateScatterIndexToInputIndex( + const ScatterDimensionNumbers* dim_numbers, const Shape& input_shape, + const Shape& updates_shape, const Literal* scatter_indices) + : dim_numbers_(*dim_numbers), scatter_indices_(*scatter_indices) { + for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { + update_dim_is_scatter_dims_.push_back( + !c_binary_search(dim_numbers_.update_window_dims(), i)); + } + + for (int64 i = 0; i < input_shape.dimensions_size(); i++) { + int64 index_of_input_dim_in_index_vector = + FindIndex(dim_numbers_.scatter_dims_to_operand_dims(), i); + if (index_of_input_dim_in_index_vector == + dim_numbers_.scatter_dims_to_operand_dims_size()) { + input_dim_value_to_index_vector_.push_back(-1); + } else { + input_dim_value_to_index_vector_.push_back( + index_of_input_dim_in_index_vector); + } + } + + index_vector_index_.resize(scatter_indices_.shape().dimensions_size()); + input_index_.resize(input_shape.dimensions_size()); + int64 index_vector_size = + scatter_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); + index_vector_.resize(index_vector_size); + } + + // Returns the contribution of scatter_indices to the input index + // corresponding to update_index. See scatter_inner_loop_body. + // + // This is conceptually a stateless transformation from update_index to the + // scatter input index, but: + // + // - Instead of allocating memory to represent the scatter input index on + // every invocation we reuse the same storage for the result + // (input_index_), mutating it in place. + // - Instead of allocating buffers for temporary values like + // index_vector_index_ and index_vector on every invocation, we reuse the + // same storage for all invocations. + // + // This returns an arrayslice into memory owned by the class. + StatusOr> operator()( + tensorflow::gtl::ArraySlice update_index) { + PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index); + TF_RETURN_IF_ERROR(FetchIndexVector()); + PropagateIndexVectorToInputIndex(); + return tensorflow::gtl::ArraySlice(input_index_); + } + + private: + // Propagates the scatter index dimensions from the update index into + // index_vector_index_ by mutating index_vector_index_ in place. Does not + // update the dim_numbers.index_vector_dim() dimension -- that's the + // dimension we iterate over in FetchIndexVector. + void PropagateUpdateIndexScatterDimsToIndexVectorIndex( + tensorflow::gtl::ArraySlice update_index) { + int64 index_vector_index_i = 0; + for (int64 i = 0, e = update_index.size(); i < e; i++) { + if (!update_dim_is_scatter_dims_[i]) { + continue; + } + + if (index_vector_index_i == dim_numbers_.index_vector_dim()) { + index_vector_index_i++; + } + + index_vector_index_[index_vector_index_i++] = update_index[i]; + } + } + + // Populates index_vector_ by iterating over scatter_indices_ according to + // index_vector_index_. + Status FetchIndexVector() { + int64 index_vector_dim = dim_numbers_.index_vector_dim(); + for (int64 i = 0, e = index_vector_.size(); i < e; i++) { + index_vector_index_[index_vector_dim] = i; + TF_ASSIGN_OR_RETURN(index_vector_[i], scatter_indices_.GetIntegralAsS64( + index_vector_index_)); + } + return Status::OK(); + } + + // Populates input_index_. + void PropagateIndexVectorToInputIndex() { + for (int64 i = 0, e = input_index_.size(); i < e; i++) { + if (input_dim_value_to_index_vector_[i] != -1) { + input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]]; + } + + // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i] + // remains 0, as set by the constructor. + } + } + + // input_dim_value_to_index_vector_[i] tells us how to compute dimension i + // of the input index from the index vector. See + // PropagateIndexVectorToInputIndex. + std::vector input_dim_value_to_index_vector_; + + // update_dim_is_scatter_dims_[i] is true iff the update index i is a + // scatter dimension. + std::vector update_dim_is_scatter_dims_; + + // The buffer into which we construct an index into scatter_indices_ to + // fetch the index vector. + std::vector index_vector_index_; + + // The index vector fetched from scatter_indices_. + std::vector index_vector_; + + // The result computed by this functor. operator() returns an ArraySlice + // into this vector. + std::vector input_index_; + + const ScatterDimensionNumbers& dim_numbers_; + const Literal& scatter_indices_; + }; + + // This functor computes the contribution of the window indices in an update + // index to an input index. That is, given an update index I it picks out the + // update window indices in I and expands it into a window index into the + // input shape. + // + // This is similar to the class HloEvaluator::OutputWindowIndexToInputIndex + // that does the corresponding function for Gather. + class UpdateWindowIndexToInputIndex { + public: + // The constructor does some setup work that is amortized across all + // iterations. + explicit UpdateWindowIndexToInputIndex( + const ScatterDimensionNumbers& dim_numbers, const Shape& input_shape, + const Shape& updates_shape) { + std::vector window_index_to_update_index; + int64 update_index_count = 0; + for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { + if (c_binary_search(dim_numbers.update_window_dims(), i)) { + window_index_to_update_index.push_back(update_index_count++); + } else { + update_index_count++; + } + } + + int64 window_dim_count = 0; + for (int64 i = 0; i < input_shape.dimensions_size(); i++) { + if (c_binary_search(dim_numbers.inserted_window_dims(), i)) { + input_dim_value_to_update_index_.push_back(-1); + } else { + input_dim_value_to_update_index_.push_back( + window_index_to_update_index[window_dim_count++]); + } + } + + input_index_.resize(input_shape.dimensions_size()); + } + + // Returns the contribution of the window indices to the input index + // corresponding to update_index. See scatter_inner_loop_body. + // + // This is conceptually a stateless transformation from update_index to the + // window input index, but instead of allocating memory to represent the + // scatter input index on every invocation we reuse the same storage for the + // result (input_index_), mutating it in place. + // + // This returns an arrayslice into memory owned by the class. + StatusOr> operator()( + tensorflow::gtl::ArraySlice update_index) { + PropagateUpdateIndexWindowDimsToInputIndex(update_index); + return tensorflow::gtl::ArraySlice(input_index_); + } + + // Returns for a given 'input_dim' the corresponding update dimension index, + // or -1 if 'input_dim' is an elided window dimension. + int64 input_dim_value_to_update_index(int64 input_dim) { + return input_dim_value_to_update_index_[input_dim]; + } + + private: + // Propagates window dimensions from the update index to input_index_ by + // mutating input_index_ in place. + void PropagateUpdateIndexWindowDimsToInputIndex( + tensorflow::gtl::ArraySlice update_index) { + for (int64 i = 0, e = input_index_.size(); i < e; i++) { + if (input_dim_value_to_update_index_[i] != -1) { + input_index_[i] = update_index[input_dim_value_to_update_index_[i]]; + } + + // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i] + // remains 0, as set by the constructor. + } + } + + // input_dim_value_to_index_vector_[i] tells us how to compute dimension i + // of the input index from the update index. See + // PropagateUpdateIndexWindowDimsToInputIndex. + std::vector input_dim_value_to_update_index_; + + // The result computed by this functor. operator() returns an ArraySlice + // into this vector. + std::vector input_index_; + }; + + Status HandleScatter(HloInstruction* scatter) override { + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + const Literal& operand = + parent_->GetEvaluatedLiteralFor(scatter->operand(0)); + std::unique_ptr reshaped_scatter_indices; + TF_ASSIGN_OR_RETURN(const Literal& scatter_indices, + ReshapedScatterIndices(dim_numbers.index_vector_dim(), + parent_->GetEvaluatedLiteralFor( + scatter->operand(1)), + &reshaped_scatter_indices)); + const Literal& updates = + parent_->GetEvaluatedLiteralFor(scatter->operand(2)); + const Shape& updates_shape = updates.shape(); + const Shape& operand_shape = operand.shape(); + + ShapeUtil::IndexIterationSpace scatter_indices_iteration_space = + IterationSpaceForUpdateScatterIndices(updates_shape, dim_numbers); + ShapeUtil::IndexIterationSpace window_indices_iteration_space = + IterationSpaceForUpdateWindowIndices(updates_shape, dim_numbers); + + std::vector input_index(operand_shape.dimensions_size()); + std::vector update_index(updates_shape.dimensions_size()); + std::vector input_scatter_index_clamped( + operand_shape.dimensions_size()); + + UpdateScatterIndexToInputIndex update_scatter_index_to_input_index( + &scatter->scatter_dimension_numbers(), /*input_shape=*/operand_shape, + updates_shape, &scatter_indices); + UpdateWindowIndexToInputIndex update_window_index_to_input_index( + scatter->scatter_dimension_numbers(), /*input_shape=*/operand_shape, + updates_shape); + + // Initialize the result with the operand. This makes it easier to handle + // the updates even when the indices are repeated. + std::unique_ptr result = operand.CloneToUnique(); + HloEvaluator embedded_evaluator; + auto scatter_inner_loop_body = + [&](tensorflow::gtl::ArraySlice update_window_index, + tensorflow::gtl::ArraySlice input_scatter_index, + tensorflow::gtl::ArraySlice update_scatter_index) + -> StatusOr { + TF_ASSIGN_OR_RETURN( + tensorflow::gtl::ArraySlice input_window_index, + update_window_index_to_input_index(update_window_index)); + for (int i = 0, e = update_index.size(); i < e; i++) { + update_index[i] = update_scatter_index[i] + update_window_index[i]; + DCHECK_LT(update_index[i], updates_shape.dimensions(i)); + } + for (int i = 0, e = input_scatter_index.size(); i < e; i++) { + int64 update_dim = + update_window_index_to_input_index.input_dim_value_to_update_index( + i); + // If 'update_dim' is -1, it means 'i' is an elided window dim. This + // means we set the iteration index to 0, so for the purpose of the + // following calculations we can consider the update dimension size to + // be 1. + int64 update_dim_size = + update_dim == -1 ? 1 : updates_shape.dimensions(update_dim); + // Clamp the scatter index so that the scatter region fits in the + // operand. input_scatter_index_clamped[i] = + // clamp(input_scatter_index[i], 0, + // operand_shape.dimensions(i) - + // update_dim_size); + input_scatter_index_clamped[i] = + std::min(operand_shape.dimensions(i) - update_dim_size, + std::max(0LL, input_scatter_index[i])); + } + for (int i = 0, e = input_index.size(); i < e; i++) { + input_index[i] = input_scatter_index_clamped[i] + input_window_index[i]; + DCHECK_GE(input_index[i], 0); + DCHECK_LT(input_index[i], operand_shape.dimensions(i)); + } + + auto result_value_literal = + LiteralUtil::CreateR0(result->Get(input_index)); + auto update_value_literal = + LiteralUtil::CreateR0(updates.Get(update_index)); + std::unique_ptr updated_result = + embedded_evaluator + .Evaluate( + *scatter->to_apply(), + {result_value_literal.get(), update_value_literal.get()}) + .ConsumeValueOrDie(); + // Clear visit states so that the we can use the evaluate again on the + // same computation. + embedded_evaluator.ResetVisitStates(); + result->Set(input_index, updated_result->Get({})); + return true; + }; + + auto scatter_outer_loop_body = + [&](tensorflow::gtl::ArraySlice update_scatter_index) + -> StatusOr { + TF_ASSIGN_OR_RETURN( + tensorflow::gtl::ArraySlice input_scatter_index, + update_scatter_index_to_input_index(update_scatter_index)); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + updates_shape, window_indices_iteration_space, + [&](tensorflow::gtl::ArraySlice update_window_index) { + return scatter_inner_loop_body( + update_window_index, input_scatter_index, update_scatter_index); + })); + return true; + }; + + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + updates_shape, scatter_indices_iteration_space, + scatter_outer_loop_body)); + parent_->evaluated_[scatter] = std::move(result); + return Status::OK(); + } + Status HandleSlice(HloInstruction* slice) override { auto operand = slice->operand(0); const Shape& shape = slice->shape(); @@ -1700,7 +2180,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return operand_literal.Get(operand_index); }; - auto result = Literal::CreateFromDimensions( + auto result = LiteralUtil::CreateFromDimensions( shape.element_type(), AsInt64Slice(shape.dimensions())); TF_RETURN_IF_ERROR(result->Populate(func)); parent_->evaluated_[slice] = std::move(result); @@ -1902,6 +2382,30 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleReducePrecision(reduce_precision); } + template ::value || + std::is_same::value || + std::is_same::value>::type* = nullptr> + Status HandleIota(HloInstruction* iota) { + auto result = MakeUnique(iota->shape()); + auto data = result->data(); + std::iota(data.begin(), data.end(), 0); + parent_->evaluated_[iota] = std::move(result); + return Status::OK(); + } + template ::value || + std::is_same::value || + std::is_same::value)>::type* = nullptr> + Status HandleIota(HloInstruction* iota) { + return InvalidArgument("Unsupported type for iota"); + } + Status HandleIota(HloInstruction* iota) override { + return HandleIota(iota); + } + private: // Creates a vector of multipliers which can be used to create a linear index // into shape. @@ -1959,10 +2463,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { start_indices_typed.end()); // Clamp the start indices so the slice is in-bounds w.r.t the operand. - - // TODO(b/74360564): This is implementation defined behavior, but is - // currently respected by all implementations. Change this if we ever decide - // to officially document different behavior. for (int64 i = 0; i < start.size(); ++i) { start[i] = std::min( std::max(int64{0}, start[i]), @@ -1996,10 +2496,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { start_indices_typed.end()); // Clamp the update start indices so the slice is in-bounds w.r.t the // operand. - - // TODO(b/74360564): This is implementation defined behavior, but is - // currently respected by all implementations. Change this if we ever decide - // to oficially document different behavior. for (int64 i = 0; i < rank; ++i) { start[i] = std::min( std::max(0, start[i]), diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 28fc6c4209bcc14d890f28d6a9935c55b76586c0..1efa6eb5bda7e1cb90874e0466aafd2c788a3fbf 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -27,7 +27,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -832,19 +832,22 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(), // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which // is just noise. - if (!ShapeUtil::IsTuple(shape) && ShapeUtil::HasZeroElements(shape)) { + if (ShapeUtil::IsZeroElementArray(shape)) { return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape())); } // Print the literal value of constants with <= K elements. optional elem_count; - if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) { + if (ShapeUtil::IsArray(shape)) { elem_count = 1; for (int64 dim : shape.dimensions()) { *elem_count *= dim; } } - if (elem_count.has_value() && *elem_count <= 8) { + // Allow HloDotDumper to print HloInstruction reconstructed from HloProto + // collected from profiling tools. Those constants may not have a valid + // literal. + if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { return Printf("%s (%s)", constant->literal().ToString(), ShapeUtil::HumanString(constant->shape())); } @@ -948,6 +951,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kGe: case HloOpcode::kGt: case HloOpcode::kImag: + case HloOpcode::kIota: case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: @@ -960,11 +964,13 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kPower: case HloOpcode::kReal: case HloOpcode::kRemainder: case HloOpcode::kRng: case HloOpcode::kRoundNearestAfz: + case HloOpcode::kSelect: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: @@ -983,7 +989,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kBitcast: case HloOpcode::kGetTupleElement: case HloOpcode::kTrace: - case HloOpcode::kGenerateToken: + case HloOpcode::kAfterAll: case HloOpcode::kTuple: return kWhite; case HloOpcode::kBroadcast: @@ -1000,7 +1006,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kPad: case HloOpcode::kReshape: case HloOpcode::kReverse: - case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: case HloOpcode::kTranspose: // De-emphasize scalar-shaped data movement ops and all data movement ops // inside fusion nodes, both of which are essentially free. @@ -1016,6 +1022,8 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kWhite; } return kGreen; + case HloOpcode::kScatter: + // Do not de-emphasize Scatter, since it involves significant work. case HloOpcode::kCopy: // Emphasize copy nodes, which are either physical transposes (and thus // significant), or copies of read-only buffers (and thus dead weight). @@ -1040,6 +1048,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kMap: return kGray; case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kRecv: diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 68f41a1cbb4db228f5dcf8b4a6130f05e81262a8..1d7a062c55696de9db4b187efd86bce191279083 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -120,7 +121,7 @@ TEST(HloGraphDumperTest, NestedFusion) { TEST(HloGraphDumperTest, Constant) { HloComputation::Builder b("b"); auto instruction = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(-42))); instruction->SetAndSanitizeName("i_am_a_constant_root_instruction"); HloModuleConfig config; HloModule m(TestName(), config); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 39662d1735e4b411ef36cbc8421eabe52be6165a..8690f2cdaa9b45d126e91b123c6992cbe2f27e1d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -22,7 +22,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -65,31 +64,47 @@ StatusOr> HloInstruction::CreateFromProto( const auto operands = [&instruction_map, &proto](int index) { return instruction_map.at(proto.operand_ids(index)); }; + const auto all_operands = [&instruction_map, &proto]() { + std::vector result(proto.operand_ids_size()); + std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), + result.begin(), [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); + return result; + }; const auto computations = [&computation_map, &proto](int index) { return computation_map.at(proto.called_computation_ids(index)); }; switch (opcode) { // Ops migrated to subclasses. case HloOpcode::kBatchNormTraining: - CHECK_EQ(proto.operand_ids_size(), 3); + TF_RET_CHECK(proto.operand_ids_size() == 3) + << "BatchNormTraining instruction should have 3 operands but sees " + << proto.operand_ids_size(); instruction = CreateBatchNormTraining( proto.shape(), operands(0), operands(1), operands(2), proto.epsilon(), proto.feature_index()); break; case HloOpcode::kBatchNormInference: - CHECK_EQ(proto.operand_ids_size(), 5); + TF_RET_CHECK(proto.operand_ids_size() == 5) + << "BatchNormInference instruction should have 5 operands but sees " + << proto.operand_ids_size(); instruction = CreateBatchNormInference( proto.shape(), operands(0), operands(1), operands(2), operands(3), operands(4), proto.epsilon(), proto.feature_index()); break; case HloOpcode::kBatchNormGrad: - CHECK_EQ(proto.operand_ids_size(), 5); + TF_RET_CHECK(proto.operand_ids_size() == 5) + << "BatchNormGrad instruction should have 5 operands but sees " + << proto.operand_ids_size(); instruction = CreateBatchNormGrad(proto.shape(), operands(0), operands(1), operands(2), operands(3), operands(4), proto.epsilon(), proto.feature_index()); break; case HloOpcode::kFft: { - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Fft instruction should have 1 operand but sees " + << proto.operand_ids_size(); std::vector fft_length(proto.fft_length().begin(), proto.fft_length().end()); instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(), @@ -97,75 +112,100 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kSend: - CHECK_EQ(proto.operand_ids_size(), 1); - instruction = CreateSend(operands(0), proto.channel_id()); + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Send instruction should have 2 operand but sees " + << proto.operand_ids_size(); + instruction = CreateSend(operands(0), operands(1), proto.channel_id(), + proto.is_host_transfer()); break; case HloOpcode::kSendDone: - CHECK_EQ(proto.operand_ids_size(), 1); - instruction = CreateSendDone(operands(0)); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "SendDone instruction should have 1 operand but sees " + << proto.operand_ids_size(); + instruction = CreateSendDone(operands(0), proto.is_host_transfer()); break; case HloOpcode::kRecv: - CHECK_EQ(proto.operand_ids_size(), 0); - instruction = - CreateRecv(proto.shape().tuple_shapes(0), proto.channel_id()); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Recv instruction should have 1 operand but sees " + << proto.operand_ids_size(); + instruction = CreateRecv(proto.shape().tuple_shapes(0), operands(0), + proto.channel_id(), proto.is_host_transfer()); break; case HloOpcode::kRecvDone: - CHECK_EQ(proto.operand_ids_size(), 1); - instruction = CreateRecvDone(operands(0)); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "RecvDone instruction should have 1 operand but sees " + << proto.operand_ids_size(); + instruction = CreateRecvDone(operands(0), proto.is_host_transfer()); break; case HloOpcode::kReverse: - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Reverse instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateReverse(proto.shape(), operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); break; - case HloOpcode::kConcatenate: { - CHECK_EQ(proto.dimensions_size(), 1); - std::vector concat_operands(proto.operand_ids_size()); - std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), - concat_operands.begin(), - [&instruction_map](int64 operand_id) { - return instruction_map.at(operand_id); - }); - instruction = CreateConcatenate(proto.shape(), concat_operands, - proto.dimensions(0)); + case HloOpcode::kConcatenate: + TF_RET_CHECK(proto.dimensions_size() == 1) + << "Concatenate instruction should have 1 dimension but sees " + << proto.dimensions_size(); + instruction = + CreateConcatenate(proto.shape(), all_operands(), proto.dimensions(0)); break; - } case HloOpcode::kReduce: - CHECK_EQ(proto.operand_ids_size(), 2); - CHECK_EQ(proto.called_computation_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Reduce instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Reduce instruction should have 1 called computation but sees " + << proto.called_computation_ids_size(); instruction = CreateReduce(proto.shape(), operands(0), operands(1), std::vector(proto.dimensions().begin(), proto.dimensions().end()), computations(0)); break; + case HloOpcode::kSort: { + TF_RET_CHECK(proto.operand_ids_size() == 1 || + proto.operand_ids_size() == 2) + << "Sort instruction should have 1 or 2 operands but has " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.dimensions().size() == 1) + << "Sort instruction should have 1 dimension"; + HloInstruction* keys = operands(0); + HloInstruction* values = + proto.operand_ids_size() == 2 ? operands(1) : nullptr; + instruction = + CreateSort(proto.shape(), proto.dimensions(0), keys, values); + break; + } case HloOpcode::kTranspose: - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Transpose instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateTranspose(proto.shape(), operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); break; case HloOpcode::kBroadcast: - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Broadcast instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateBroadcast(proto.shape(), operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); break; - case HloOpcode::kMap: { - CHECK_EQ(proto.called_computation_ids_size(), 1); - std::vector map_operands(proto.operand_ids_size()); - std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), - map_operands.begin(), - [&instruction_map](int64 operand_id) { - return instruction_map.at(operand_id); - }); - instruction = CreateMap(proto.shape(), map_operands, computations(0)); + case HloOpcode::kMap: + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Map instruction should have 1 called computation but sees " + << proto.called_computation_ids_size(); + instruction = CreateMap(proto.shape(), all_operands(), computations(0)); break; - } case HloOpcode::kSlice: { - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Slice instruction should have 1 operand but sees " + << proto.operand_ids_size(); std::vector slice_starts, slice_limits, slice_strides; for (const HloInstructionProto::SliceDimensions& slice_dimensions : proto.slice_dimensions()) { @@ -178,17 +218,21 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kConstant: { - CHECK(proto.has_literal()); - TF_ASSIGN_OR_RETURN(auto literal, - Literal::CreateFromProto(proto.literal())); - instruction = CreateConstant(std::move(literal)); + // TODO(b/110214922): Revert this to CHECK(proto.has_literal()). + if (proto.has_literal()) { + TF_ASSIGN_OR_RETURN(auto literal, + Literal::CreateFromProto(proto.literal())); + instruction = CreateConstant(std::move(literal)); + } else { + instruction = MakeUnique(proto.shape()); + } break; } case HloOpcode::kTrace: { TF_RET_CHECK(proto.operand_ids_size() == 1) << "Trace instruction should have 1 operand but sees " << proto.operand_ids_size(); - CHECK(proto.has_literal()); + TF_RET_CHECK(proto.has_literal()); TF_ASSIGN_OR_RETURN(auto literal, Literal::CreateFromProto(proto.literal())); instruction = CreateTrace(literal->GetR1U8AsString(), operands(0)); @@ -204,37 +248,28 @@ StatusOr> HloInstruction::CreateFromProto( // Find the fused computation and set its fusion instruction. TF_RET_CHECK(proto.called_computation_ids_size() == 1) - << "Expect 1 called computation for fusion instruction, but sees " + << "Expect 1 called computation for fusion instruction but sees " << proto.called_computation_ids_size(); const int64 fusion_id = proto.called_computation_ids(0); auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); TF_RET_CHECK(fused_computation != nullptr) << "No fusion computation with id " << fusion_id; - std::vector fusion_operands(proto.operand_ids_size()); - std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), - fusion_operands.begin(), - [&instruction_map](int64 operand_id) { - return instruction_map.at(operand_id); - }); - instruction = CreateFusion(proto.shape(), fusion_kind, fusion_operands, + instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(), fused_computation); break; } - case HloOpcode::kRng: { - std::vector rng_parms(proto.operand_ids_size()); - std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), - rng_parms.begin(), [&instruction_map](int64 operand_id) { - return instruction_map.at(operand_id); - }); - instruction = CreateRng(proto.shape(), proto.distribution(), rng_parms); + case HloOpcode::kRng: + instruction = + CreateRng(proto.shape(), proto.distribution(), all_operands()); break; - } case HloOpcode::kParameter: instruction = CreateParameter(proto.parameter_number(), proto.shape(), proto.name()); break; case HloOpcode::kGetTupleElement: - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "GetTupleElement instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateGetTupleElement(proto.shape(), operands(0), proto.tuple_index()); break; @@ -243,6 +278,157 @@ StatusOr> HloInstruction::CreateFromProto( CreateReducePrecision(proto.shape(), operands(0), proto.exponent_bits(), proto.mantissa_bits()); break; + case HloOpcode::kInfeed: { + const Shape& data_shape = + ShapeUtil::GetTupleElementShape(proto.shape(), 0); + if (proto.operand_ids_size() == 0) { + // TODO(b/80000000): Remove this when all uses of infeed are + // converted to take tokens. + instruction = CreateInfeed(data_shape, proto.infeed_config()); + } else { + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = + CreateInfeed(data_shape, operands(0), proto.infeed_config()); + } + } break; + case HloOpcode::kOutfeed: + if (proto.operand_ids_size() == 1) { + // TODO(b/80000000): Remove this when all uses of outfeed are + // converted to take tokens. + instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), + proto.outfeed_config()); + } else { + CHECK_EQ(proto.operand_ids_size(), 2); + instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), + operands(1), proto.outfeed_config()); + } + break; + case HloOpcode::kCrossReplicaSum: { + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "CrossReplicaSum should have 1 called computation but sees " + << proto.called_computation_ids_size(); + tensorflow::gtl::optional all_reduce_id; + if (proto.all_reduce_id() > 0) { + all_reduce_id = proto.all_reduce_id(); + } + instruction = CreateCrossReplicaSum( + proto.shape(), all_operands(), computations(0), + /*replica_group_ids=*/ + std::vector(proto.replica_group_ids().begin(), + proto.replica_group_ids().end()), + /*barrier=*/proto.cross_replica_sum_barrier(), + /*all_reduce_id=*/all_reduce_id); + break; + } + case HloOpcode::kAllToAll: { + instruction = CreateAllToAll( + proto.shape(), all_operands(), + /*replica_groups=*/ + std::vector(proto.replica_groups().begin(), + proto.replica_groups().end()), + /*barrier=*/proto.cross_replica_sum_barrier()); + break; + } + case HloOpcode::kConvolution: + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Convolution instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_window()); + TF_RET_CHECK(proto.has_convolution_dimension_numbers()); + instruction = + CreateConvolve(proto.shape(), operands(0), operands(1), + proto.window(), proto.convolution_dimension_numbers()); + break; + case HloOpcode::kReduceWindow: + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "ReduceWindow instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "ReduceWindow should have 1 called computation but sees " + << proto.called_computation_ids_size(); + instruction = CreateReduceWindow(proto.shape(), operands(0), operands(1), + proto.window(), computations(0)); + break; + case HloOpcode::kSelectAndScatter: + TF_RET_CHECK(proto.operand_ids_size() == 3) + << "SelectAndScatter instruction should have 3 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.called_computation_ids_size() == 2) + << "SelectAndScatter should have 2 called computations but sees " + << proto.called_computation_ids_size(); + instruction = CreateSelectAndScatter( + proto.shape(), operands(0), computations(0), proto.window(), + operands(1), operands(2), computations(1)); + break; + case HloOpcode::kCustomCall: + instruction = CreateCustomCall(proto.shape(), all_operands(), + proto.custom_call_target()); + if (proto.has_window()) { + static_cast(instruction.get()) + ->set_window(proto.window()); + } + if (proto.has_convolution_dimension_numbers()) { + static_cast(instruction.get()) + ->set_convolution_dimension_numbers( + proto.convolution_dimension_numbers()); + } + break; + case HloOpcode::kHostCompute: + instruction = + CreateHostCompute(proto.shape(), all_operands(), proto.channel_name(), + proto.cost_estimate_ns()); + break; + case HloOpcode::kPad: + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Pad instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_padding_config()); + instruction = CreatePad(proto.shape(), operands(0), operands(1), + proto.padding_config()); + break; + case HloOpcode::kDynamicSlice: { + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "DynamicSlice instruction should have 2 operands but sees " + << proto.operand_ids_size(); + std::vector slice_sizes(proto.dynamic_slice_sizes_size()); + c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); + instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1), + slice_sizes); + break; + } + case HloOpcode::kGather: { + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Gather instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_gather_dimension_numbers()) + << "Gather instruction should have GatherDimensionNumbers set."; + std::unique_ptr gather_dimension_numbers = + MakeUnique(proto.gather_dimension_numbers()); + std::vector gather_window_bounds; + for (int64 bound : proto.gather_window_bounds()) { + gather_window_bounds.push_back(bound); + } + instruction = + CreateGather(proto.shape(), operands(0), operands(1), + *gather_dimension_numbers, gather_window_bounds); + break; + } + case HloOpcode::kScatter: { + TF_RET_CHECK(proto.operand_ids_size() == 3) + << "Scatter instruction should have 3 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_scatter_dimension_numbers()) + << "Scatter instruction should have ScatterDimensionNumbers set."; + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Scatter instruction should have 1 called computation but sees " + << proto.called_computation_ids_size(); + auto scatter_dimension_numbers = MakeUnique( + proto.scatter_dimension_numbers()); + instruction = + CreateScatter(proto.shape(), operands(0), operands(1), operands(2), + computations(0), *scatter_dimension_numbers); + break; + } default: { instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { @@ -273,52 +459,17 @@ StatusOr> HloInstruction::CreateFromProto( instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); - if (proto.has_window()) { - instruction->window_ = MakeUnique(proto.window()); - } - if (proto.has_convolution_dimension_numbers()) { - instruction->convolution_dimension_numbers_ = - MakeUnique( - proto.convolution_dimension_numbers()); - } if (proto.has_dot_dimension_numbers()) { instruction->dot_dimension_numbers_ = MakeUnique(proto.dot_dimension_numbers()); } - for (int64 dynamic_slice_size : proto.dynamic_slice_sizes()) { - instruction->dynamic_slice_sizes_.push_back(dynamic_slice_size); - } - if (proto.has_padding_config()) { - instruction->padding_config_ = - MakeUnique(proto.padding_config()); - } - instruction->outfeed_config_ = proto.outfeed_config(); - instruction->infeed_config_ = proto.infeed_config(); - instruction->custom_call_target_ = proto.custom_call_target(); - instruction->outfeed_shape_ = proto.outfeed_shape(); - if (proto.has_sharding()) { TF_ASSIGN_OR_RETURN(const auto& sharding, HloSharding::FromProto(proto.sharding())); instruction->set_sharding(sharding); } - if (proto.has_gather_dimension_numbers()) { - instruction->gather_dimension_numbers_ = - MakeUnique(proto.gather_dimension_numbers()); - } - for (int64 bound : proto.gather_window_bounds()) { - instruction->gather_window_bounds_.push_back(bound); - } - - instruction->channel_name_ = proto.channel_name(); - instruction->cost_estimate_ns_ = proto.cost_estimate_ns(); - - for (int64 replica_group_id : proto.replica_group_ids()) { - instruction->replica_group_ids_.push_back(replica_group_id); - } - return std::move(instruction); } @@ -337,6 +488,11 @@ StatusOr> HloInstruction::CreateFromProto( return MakeUnique(std::move(literal)); } +/* static */ std::unique_ptr HloInstruction::CreateIota( + const Shape& shape) { + return WrapUnique(new HloInstruction(HloOpcode::kIota, shape)); +} + /* static */ std::unique_ptr HloInstruction::CreateGetTupleElement(const Shape& shape, HloInstruction* operand, int64 index) { @@ -388,7 +544,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: - case HloOpcode::kSort: case HloOpcode::kTanh: break; default: @@ -423,6 +578,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kSubtract: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: @@ -440,8 +596,9 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, // Only certain opcodes are supported with CreateTernary: opcodes of ternary // instructions with no auxiliary fields. switch (opcode) { - case (HloOpcode::kClamp): - case (HloOpcode::kSelect): + case HloOpcode::kClamp: + case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: break; default: LOG(FATAL) << "Invalid ternary instruction opcode " @@ -459,30 +616,16 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateMap( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation, - tensorflow::gtl::ArraySlice static_operands) { - return MakeUnique(shape, operands, map_computation, - static_operands); + HloComputation* map_computation) { + return MakeUnique(shape, operands, map_computation); } /* static */ std::unique_ptr HloInstruction::CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConvolution, shape)); - if (window_util::HasBaseDilation(window)) { - instruction->name_ = instruction->name() + "-base-dilated"; - } - if (window_util::HasWindowDilation(window)) { - instruction->name_ = instruction->name() + "-window-dilated"; - } - instruction->AppendOperand(lhs); - instruction->AppendOperand(rhs); - instruction->window_ = MakeUnique(window); - instruction->convolution_dimension_numbers_ = - MakeUnique(dimension_numbers); - return instruction; + return MakeUnique(shape, lhs, rhs, window, + dimension_numbers); } /* static */ std::unique_ptr HloInstruction::CreateFft( @@ -531,66 +674,73 @@ HloInstruction::CreateCrossReplicaSum( HloComputation* reduce_computation, tensorflow::gtl::ArraySlice replica_group_ids, tensorflow::StringPiece barrier, - const tensorflow::gtl::optional& channel_id) { - // TODO(b/79737069): Remove the CHECK when supported. - CHECK(!channel_id.has_value()); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->called_computations_.push_back(reduce_computation); - instruction->replica_group_ids_.assign(replica_group_ids.begin(), - replica_group_ids.end()); - instruction->cross_replica_sum_barrier_ = std::string(barrier); - return instruction; + const tensorflow::gtl::optional& all_reduce_id) { + return MakeUnique( + shape, operands, reduce_computation, replica_group_ids, barrier, + all_reduce_id); +} + +/* static */ std::unique_ptr HloInstruction::CreateAllToAll( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + const std::vector& replica_groups, + tensorflow::StringPiece barrier) { + return MakeUnique(shape, operands, replica_groups, + barrier); } /* static */ std::unique_ptr HloInstruction::CreateInfeed( - const Shape& shape, const string& config) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kInfeed, shape)); - instruction->set_infeed_config(config); - return instruction; + const Shape& infeed_shape, HloInstruction* token_operand, + const string& config) { + return MakeUnique(infeed_shape, token_operand, config); +} + +/* static */ std::unique_ptr HloInstruction::CreateInfeed( + const Shape& infeed_shape, const string& config) { + return MakeUnique(infeed_shape, config); } /* static */ std::unique_ptr HloInstruction::CreateOutfeed( - const Shape& shape, HloInstruction* operand, + const Shape& outfeed_shape, HloInstruction* operand, + HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) { + return MakeUnique(outfeed_shape, operand, + token_operand, outfeed_config); +} + +/* static */ std::unique_ptr HloInstruction::CreateOutfeed( + const Shape& outfeed_shape, HloInstruction* operand, tensorflow::StringPiece outfeed_config) { - std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil())); - CHECK(ShapeUtil::Compatible(operand->shape(), shape)) - << "Outfeed shape " << shape << " must be compatible with operand shape " - << operand->shape(); - instruction->AppendOperand(operand); - instruction->outfeed_config_ = std::string(outfeed_config); - instruction->outfeed_shape_ = shape; - return instruction; + return MakeUnique(outfeed_shape, operand, + outfeed_config); } /* static */ std::unique_ptr HloInstruction::CreateSend( - HloInstruction* operand, int64 channel_id) { - return MakeUnique(operand, channel_id); + HloInstruction* operand, HloInstruction* token, int64 channel_id, + bool is_host_transfer) { + return MakeUnique(operand, token, channel_id, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateSendDone( - HloInstruction* operand) { + HloInstruction* operand, bool is_host_transfer) { auto send_operand = DynCast(operand); CHECK(send_operand != nullptr) << "SendDone must take the context operand from Send"; - return MakeUnique(send_operand); + return MakeUnique(send_operand, is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateRecv( - const Shape& shape, int64 channel_id) { - return MakeUnique(shape, channel_id); + const Shape& shape, HloInstruction* token, int64 channel_id, + bool is_host_transfer) { + return MakeUnique(shape, token, channel_id, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateRecvDone( - HloInstruction* operand) { + HloInstruction* operand, bool is_host_transfer) { auto recv_operand = DynCast(operand); CHECK(recv_operand != nullptr) << "RecvDone must take the context operand from Recv"; - return MakeUnique(recv_operand); + return MakeUnique(recv_operand, is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateReverse( @@ -599,17 +749,22 @@ HloInstruction::CreateCrossReplicaSum( return MakeUnique(shape, operand, dimensions); } -/* static */ std::unique_ptr -HloInstruction::CreateGenerateToken( +/* static */ std::unique_ptr HloInstruction::CreateAfterAll( tensorflow::gtl::ArraySlice operands) { - auto instruction = WrapUnique(new HloInstruction( - HloOpcode::kGenerateToken, ShapeUtil::MakeTokenShape())); + CHECK(!operands.empty()); + auto instruction = WrapUnique( + new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); for (auto operand : operands) { instruction->AppendOperand(operand); } return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateToken() { + return WrapUnique( + new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); +} + /* static */ std::unique_ptr HloInstruction::CreateWhile( const Shape& shape, HloComputation* condition, HloComputation* body, HloInstruction* init) { @@ -650,13 +805,8 @@ HloInstruction::CreateGenerateToken( /* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kDynamicSlice, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(start_indices); - instruction->dynamic_slice_sizes_.assign(slice_sizes.begin(), - slice_sizes.end()); - return instruction; + return MakeUnique(shape, operand, start_indices, + slice_sizes); } /* static */ std::unique_ptr @@ -695,23 +845,32 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, } /* static */ std::unique_ptr HloInstruction::CreateReduce( - const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + const Shape& shape, HloInstruction* operand, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation) { + auto instruction = WrapUnique(new HloReduceInstruction( + shape, {operand, init_value}, dimensions_to_reduce, reduce_computation)); + return std::move(instruction); +} + +/* static */ std::unique_ptr HloInstruction::CreateReduce( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::gtl::ArraySlice init_values, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation) { - return MakeUnique( - shape, arg, init_value, dimensions_to_reduce, reduce_computation); + std::vector all_args; + all_args.reserve(operands.size() * 2); + all_args.insert(all_args.end(), operands.begin(), operands.end()); + all_args.insert(all_args.end(), init_values.begin(), init_values.end()); + return MakeUnique(shape, all_args, dimensions_to_reduce, + reduce_computation); } /* static */ std::unique_ptr HloInstruction::CreateReduceWindow( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, const Window& window, HloComputation* reduce_computation) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kReduceWindow, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(init_value); - instruction->called_computations_.push_back(reduce_computation); - instruction->window_ = MakeUnique(window); - return instruction; + return MakeUnique(shape, operand, init_value, + window, reduce_computation); } /* static */ std::unique_ptr @@ -749,16 +908,8 @@ HloInstruction::CreateSelectAndScatter( const Shape& shape, HloInstruction* operand, HloComputation* select, const Window& window, HloInstruction* source, HloInstruction* init_value, HloComputation* scatter) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kSelectAndScatter, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(source); - instruction->AppendOperand(init_value); - // Select comes before scatter in the vector. - instruction->called_computations_.push_back(select); - instruction->called_computations_.push_back(scatter); - instruction->window_ = MakeUnique(window); - return instruction; + return MakeUnique( + shape, operand, select, window, source, init_value, scatter); } /* static */ std::unique_ptr HloInstruction::CreateBroadcast( @@ -823,11 +974,8 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreatePad( const Shape& shape, HloInstruction* operand, HloInstruction* padding_value, const PaddingConfig& padding_config) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kPad, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(padding_value); - instruction->padding_config_ = MakeUnique(padding_config); - return instruction; + return MakeUnique(shape, operand, padding_value, + padding_config); } /* static */ std::unique_ptr HloInstruction::CreateReshape( @@ -847,6 +995,12 @@ HloInstruction::CreateBroadcastSequence( return MakeUnique(shape, operand, dimensions); } +/* static */ std::unique_ptr HloInstruction::CreateSort( + const Shape& shape, int64 dimension, HloInstruction* keys, + HloInstruction* values) { + return MakeUnique(shape, dimension, keys, values); +} + /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { return MakeUnique(shape, fusion_kind, fused_root); @@ -891,6 +1045,8 @@ bool HloInstruction::HasSideEffectNoRecurse() const { case HloOpcode::kTrace: case HloOpcode::kHostCompute: return true; + case HloOpcode::kCrossReplicaSum: + return all_reduce_id().has_value(); default: return false; } @@ -924,26 +1080,15 @@ bool HloInstruction::HasSideEffect() const { /* static */ std::unique_ptr HloInstruction::CreateCustomCall( const Shape& shape, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece custom_call_target) { - std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kCustomCall, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->custom_call_target_ = std::string(custom_call_target); - return instruction; + return MakeUnique(shape, operands, + custom_call_target); } /* static */ std::unique_ptr HloInstruction::CreateHostCompute( const Shape& shape, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) { - std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kHostCompute, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->channel_name_ = std::string(channel_name); - instruction->cost_estimate_ns_ = cost_estimate_ns; - return instruction; + return MakeUnique(shape, operands, channel_name, + cost_estimate_ns); } /* static */ std::unique_ptr HloInstruction::CreateTuple( @@ -960,34 +1105,18 @@ bool HloInstruction::HasSideEffect() const { const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice window_bounds) { - std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kGather, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(gather_indices); - instruction->gather_dimension_numbers_ = - MakeUnique(gather_dim_numbers); - c_copy(window_bounds, std::back_inserter(instruction->gather_window_bounds_)); - return instruction; + return MakeUnique(shape, operand, gather_indices, + gather_dim_numbers, window_bounds); } -/* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice output_window_dims, - tensorflow::gtl::ArraySlice elided_window_dims, - tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, - int64 index_vector_dim) { - GatherDimensionNumbers gather_dim_numbers; - for (int64 output_window_dim : output_window_dims) { - gather_dim_numbers.add_output_window_dims(output_window_dim); - } - for (int64 elided_window_dim : elided_window_dims) { - gather_dim_numbers.add_elided_window_dims(elided_window_dim); - } - for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { - gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); - } - - gather_dim_numbers.set_index_vector_dim(index_vector_dim); - return gather_dim_numbers; +/* static */ std::unique_ptr HloInstruction::CreateScatter( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers) { + return MakeUnique(shape, operand, scatter_indices, + updates, update_computation, + scatter_dim_numbers); } /* static */ std::unique_ptr HloInstruction::CreateDomain( @@ -1040,6 +1169,21 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kParameter: case HloOpcode::kGetTupleElement: case HloOpcode::kReducePrecision: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kConvolution: + case HloOpcode::kCustomCall: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kHostCompute: + case HloOpcode::kPad: + case HloOpcode::kDynamicSlice: + case HloOpcode::kSort: + case HloOpcode::kGather: + case HloOpcode::kScatter: + case HloOpcode::kIota: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. @@ -1062,7 +1206,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: - case HloOpcode::kSort: case HloOpcode::kTanh: CHECK_EQ(new_operands.size(), 1); clone = CreateUnary(shape, opcode_, new_operands[0]); @@ -1086,6 +1229,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kRemainder: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: @@ -1095,6 +1239,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( // Ternary ops. case HloOpcode::kClamp: case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: CHECK_EQ(new_operands.size(), 3); clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1], new_operands[2]); @@ -1103,21 +1248,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kCall: clone = CreateCall(shape, new_operands, to_apply()); break; - case HloOpcode::kCustomCall: - clone = CreateCustomCall(shape, new_operands, custom_call_target_); - if (window_ != nullptr) { - clone->window_ = MakeUnique(*window_); - } - if (convolution_dimension_numbers_ != nullptr) { - clone->convolution_dimension_numbers_ = - MakeUnique( - *convolution_dimension_numbers_); - } - break; - case HloOpcode::kHostCompute: - clone = CreateHostCompute(shape, new_operands, channel_name_, - cost_estimate_ns_); - break; case HloOpcode::kConvert: CHECK_EQ(new_operands.size(), 1); clone = CreateConvert(shape, new_operands[0]); @@ -1126,45 +1256,15 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateBitcastConvert(shape, new_operands[0]); break; - case HloOpcode::kConvolution: - CHECK_EQ(new_operands.size(), 2); - clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_, - *convolution_dimension_numbers_); - break; case HloOpcode::kDot: CHECK_EQ(new_operands.size(), 2); clone = CreateDot(shape, new_operands[0], new_operands[1], *dot_dimension_numbers_); break; - case HloOpcode::kCrossReplicaSum: - clone = - CreateCrossReplicaSum(shape, new_operands, to_apply(), - replica_group_ids_, cross_replica_sum_barrier_); - break; - case HloOpcode::kPad: - CHECK_EQ(new_operands.size(), 2); - clone = - CreatePad(shape, new_operands[0], new_operands[1], *padding_config_); - break; - case HloOpcode::kReduceWindow: - CHECK_EQ(new_operands.size(), 2); - clone = CreateReduceWindow(shape, new_operands[0], new_operands[1], - *window_, to_apply()); - break; - case HloOpcode::kSelectAndScatter: - CHECK_EQ(new_operands.size(), 3); - clone = - CreateSelectAndScatter(shape, new_operands[0], select(), *window_, - new_operands[1], new_operands[2], scatter()); - break; case HloOpcode::kReshape: CHECK_EQ(new_operands.size(), 1); clone = CreateReshape(shape, new_operands[0]); break; - case HloOpcode::kDynamicSlice: - clone = CreateDynamicSlice(shape, new_operands[0], new_operands[1], - dynamic_slice_sizes_); - break; case HloOpcode::kDynamicUpdateSlice: CHECK_EQ(new_operands.size(), 3); clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], @@ -1179,33 +1279,24 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateWhile(shape, while_condition(), while_body(), new_operands[0]); break; - case HloOpcode::kInfeed: - CHECK_EQ(new_operands.size(), 0); - clone = CreateInfeed(shape, infeed_config()); - break; - case HloOpcode::kOutfeed: - CHECK_EQ(new_operands.size(), 1); - clone = CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config()); - break; case HloOpcode::kConditional: CHECK_EQ(new_operands.size(), 3); clone = CreateConditional(shape, new_operands[0], new_operands[1], true_computation(), new_operands[2], false_computation()); break; - case HloOpcode::kGather: - CHECK_EQ(new_operands.size(), 2); - clone = CreateGather(shape, new_operands[0], new_operands[1], - *gather_dimension_numbers_, gather_window_bounds_); - break; case HloOpcode::kDomain: CHECK_EQ(new_operands.size(), 1); clone = CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(), user_side_metadata_->Clone()); break; - case HloOpcode::kGenerateToken: - clone = CreateGenerateToken(new_operands); + case HloOpcode::kAfterAll: + if (new_operands.empty()) { + clone = CreateToken(); + } else { + clone = CreateAfterAll(new_operands); + } break; } SetupDerivedInstruction(clone.get()); @@ -1222,7 +1313,29 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( return clone; } -HloInstruction::~HloInstruction() {} +HloInstruction::~HloInstruction() { + // Detach from operands. An instruction may be repeated as an operand. To + // avoid calling RemoveUser twice on the same operand, check before remove. + for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { + HloInstruction* operand = operands_[operand_num]; + if (operand == nullptr) { + continue; + } + if (operand->user_set_.find(this) != operand->user_set_.end()) { + operand->RemoveUser(this); + } + operands_[operand_num] = nullptr; + } + + // Update users. Set `nullptr` to the correpsonding operand slot for users. + for (auto& user : this->users()) { + for (int i = 0; i < user->operand_count(); ++i) { + if (user->operands_[i] == this) { + user->operands_[i] = nullptr; + } + } + } +} std::unique_ptr HloInstruction::Clone( const string& suffix, HloCloneContext* context) const { @@ -1368,6 +1481,30 @@ void HloInstruction::AppendOperand(HloInstruction* operand) { operand->AddUser(this); } +void HloInstruction::RemoveOperandsAtAscendingIndices( + tensorflow::gtl::ArraySlice ascending_indices) { + if (ascending_indices.empty()) { + return; + } + int next_index = 0; + int removed_count = 0; + for (int to_remove : ascending_indices) { + while (next_index < to_remove) { + operands_[next_index - removed_count] = operands_[next_index]; + ++next_index; + } + CHECK_LT(to_remove, operands_.size()); + ++removed_count; + ++next_index; + } + while (next_index < operands_.size()) { + operands_[next_index - removed_count] = operands_[next_index]; + ++next_index; + } + CHECK_EQ(removed_count, ascending_indices.size()); + operands_.resize(operands_.size() - removed_count); +} + void HloInstruction::AddUser(HloInstruction* user) { if (!ContainsKey(user_set_, user)) { user_set_.insert(user); @@ -1405,7 +1542,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kDivide: - case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: case HloOpcode::kExp: @@ -1421,6 +1557,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kAnd: case HloOpcode::kNot: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -1441,75 +1578,36 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kSubtract: case HloOpcode::kTanh: case HloOpcode::kTuple: + case HloOpcode::kTupleSelect: return true; - // These opcodes have complex or special behavior so just return false. - case HloOpcode::kDomain: - case HloOpcode::kWhile: - case HloOpcode::kGenerateToken: + // This opcode has complex or special behavior so just return false. + case HloOpcode::kAfterAll: return false; - // Convolution has a window and dimensions. - case HloOpcode::kConvolution: - return protobuf_util::ProtobufEquals(window(), other.window()) && - protobuf_util::ProtobufEquals( - convolution_dimension_numbers(), - other.convolution_dimension_numbers()); // Check dot dimension numbers. case HloOpcode::kDot: return protobuf_util::ProtobufEquals(dot_dimension_numbers(), other.dot_dimension_numbers()); - case HloOpcode::kGather: - return protobuf_util::ProtobufEquals(gather_dimension_numbers(), - other.gather_dimension_numbers()) && - gather_window_bounds() == other.gather_window_bounds(); - - case HloOpcode::kReduceWindow: - return eq_computations(to_apply(), other.to_apply()) && - protobuf_util::ProtobufEquals(window(), other.window()); - - // SelectAndScatter is determined by both select and scatter - // computation as well as the window configuration. - case HloOpcode::kSelectAndScatter: - return eq_computations(select(), other.select()) && - eq_computations(scatter(), other.scatter()) && - protobuf_util::ProtobufEquals(window(), other.window()); - // Remaining instructions with special values. - case HloOpcode::kPad: - return protobuf_util::ProtobufEquals(padding_config(), - other.padding_config()); case HloOpcode::kCall: - case HloOpcode::kCrossReplicaSum: - return replica_group_ids() == other.replica_group_ids() && - cross_replica_sum_barrier() == other.cross_replica_sum_barrier() && - eq_computations(to_apply(), other.to_apply()); - case HloOpcode::kCustomCall: - if ((window_ == nullptr) != (other.window_ == nullptr) || - (window_ != nullptr && - !protobuf_util::ProtobufEquals(window(), other.window()))) { - return false; - } - if ((convolution_dimension_numbers_ == nullptr) != - (other.convolution_dimension_numbers_ == nullptr) || - (convolution_dimension_numbers_ != nullptr && - !protobuf_util::ProtobufEquals( - convolution_dimension_numbers(), - other.convolution_dimension_numbers()))) { - return false; - } - return custom_call_target_ == other.custom_call_target_; + return eq_computations(to_apply(), other.to_apply()); case HloOpcode::kConditional: return eq_computations(true_computation(), other.true_computation()) && eq_computations(false_computation(), other.false_computation()); - // These opcodes are not yet supported. - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: - case HloOpcode::kSort: - case HloOpcode::kHostCompute: + case HloOpcode::kWhile: { + if (eq_computations(while_body(), other.while_body()) && + eq_computations(while_condition(), other.while_condition())) { + return true; + } return false; + } + + case HloOpcode::kDomain: + return operand_side_metadata().Matches(other.operand_side_metadata()) && + user_side_metadata().Matches(other.user_side_metadata()); // Ops migrated to subclasses should never come to this line. // TODO(b/80131774): Remove this switch when migration is complete. @@ -1524,20 +1622,36 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kReverse: case HloOpcode::kConcatenate: case HloOpcode::kReduce: + case HloOpcode::kSort: case HloOpcode::kTranspose: case HloOpcode::kBroadcast: case HloOpcode::kMap: case HloOpcode::kSlice: case HloOpcode::kConstant: + case HloOpcode::kIota: case HloOpcode::kTrace: case HloOpcode::kFusion: case HloOpcode::kRng: case HloOpcode::kParameter: case HloOpcode::kGetTupleElement: case HloOpcode::kReducePrecision: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: + case HloOpcode::kConvolution: + case HloOpcode::kCustomCall: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kHostCompute: + case HloOpcode::kPad: + case HloOpcode::kDynamicSlice: + case HloOpcode::kGather: + case HloOpcode::kScatter: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } + return false; } void HloInstruction::RemoveUser(HloInstruction* user) { @@ -1569,6 +1683,10 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user, std::replace(user->operands_.begin(), user->operands_.end(), this, new_producer); new_producer->AddUser(user); + if (user->opcode() == HloOpcode::kFusion) { + TF_RETURN_IF_ERROR( + Cast(user)->DeduplicateFusionOperands()); + } return Status::OK(); } @@ -1577,10 +1695,14 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, TF_RET_CHECK(operand_num >= 0); TF_RET_CHECK(operand_num < operand_count()); HloInstruction* old_operand = mutable_operand(operand_num); + if (old_operand == new_operand) { + return Status::OK(); + } + TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(), new_operand->shape())) - << old_operand->shape().ShortDebugString() << " is not compatible with " - << new_operand->shape().ShortDebugString(); + << old_operand->shape() << " is not compatible with " + << new_operand->shape(); operands_[operand_num] = new_operand; VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with " @@ -1607,6 +1729,10 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { std::replace(user->operands_.begin(), user->operands_.end(), this, new_producer); new_producer->AddUser(user); + if (user->opcode() == HloOpcode::kFusion) { + TF_RETURN_IF_ERROR( + Cast(user)->DeduplicateFusionOperands()); + } } } users_.clear(); @@ -1621,22 +1747,6 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { return Status::OK(); } -void HloInstruction::DetachFromOperands() { - VLOG(3) << "DetachFromOperands:\n " << ToString(); - CHECK_EQ(0, user_count()); - // An instruction may be repeated as an operand. To avoid calling RemoveUser - // twice on the same operand, keep a set of already detached operands. - std::set detached_operands; - for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { - HloInstruction* operand = operands_[operand_num]; - if (!ContainsKey(detached_operands, operand)) { - operand->RemoveUser(this); - detached_operands.insert(operand); - } - operands_[operand_num] = nullptr; - } -} - HloComputation* HloInstruction::to_apply() const { switch (opcode_) { case HloOpcode::kCall: @@ -1644,6 +1754,7 @@ HloComputation* HloInstruction::to_apply() const { case HloOpcode::kReduceWindow: case HloOpcode::kReduce: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kScatter: CHECK_EQ(called_computations_.size(), 1); return called_computations_[0]; default: @@ -1661,6 +1772,8 @@ void HloInstruction::set_to_apply(HloComputation* computation) { case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kScatter: CHECK_EQ(called_computations_.size(), 1); called_computations_[0] = computation; break; @@ -1670,16 +1783,6 @@ void HloInstruction::set_to_apply(HloComputation* computation) { } } -const string& HloInstruction::custom_call_target() const { - CHECK_EQ(opcode_, HloOpcode::kCustomCall); - return custom_call_target_; -} - -const string& HloInstruction::outfeed_config() const { - CHECK_EQ(opcode_, HloOpcode::kOutfeed); - return outfeed_config_; -} - HloComputation* HloInstruction::while_condition() const { CHECK_EQ(HloOpcode::kWhile, opcode_); return called_computations_[kConditionComputationIndex]; @@ -1706,32 +1809,6 @@ void HloInstruction::set_while_body(HloComputation* computation) { called_computations_[kBodyComputationIndex] = computation; } -HloComputation* HloInstruction::select() const { - CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - return called_computations_[kSelectComputationIndex]; -} - -HloComputation* HloInstruction::scatter() const { - CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - return called_computations_[kScatterComputationIndex]; -} - -void HloInstruction::set_select(HloComputation* computation) { - // Don't allow changing the computation for fused instructions so we don't - // have to recompute called_instructions for the entire fusion instruction. - CHECK(!IsFused()); - CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - called_computations_[kSelectComputationIndex] = computation; -} - -void HloInstruction::set_scatter(HloComputation* computation) { - // Don't allow changing the computation for fused instructions so we don't - // have to recompute called_instructions for the entire fusion instruction. - CHECK(!IsFused()); - CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - called_computations_[kScatterComputationIndex] = computation; -} - HloComputation* HloInstruction::true_computation() const { CHECK_EQ(HloOpcode::kConditional, opcode_); return called_computations_[kTrueComputationIndex]; @@ -1827,6 +1904,7 @@ bool HloInstruction::IsElementwiseImpl( case HloOpcode::kSubtract: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: @@ -1835,15 +1913,21 @@ bool HloInstruction::IsElementwiseImpl( // Ternary elementwise operations. case HloOpcode::kSelect: - return !ShapeUtil::IsTuple(shape_); case HloOpcode::kClamp: return true; + case HloOpcode::kDynamicUpdateSlice: + return operand_idx.has_value() && operand_idx.value() == 0; + default: return false; } } +bool HloInstruction::IsCrossModuleAllReduce() const { + return opcode() == HloOpcode::kCrossReplicaSum && all_reduce_id(); +} + string HloInstruction::ToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { @@ -1901,6 +1985,11 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); } operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { + // If operand is already been deleted, put `null` to the string output. + if (operand == nullptr) { + StrAppend(out, "null "); + return; + } std::vector str; if (options.print_operand_shape()) { str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); @@ -1927,32 +2016,10 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( std::vector HloInstruction::ExtraAttributesToString( const HloPrintOptions& options) const { std::vector extra = ExtraAttributesToStringImpl(options); - if (window_ != nullptr && window_->dimensions_size() != 0) { - extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); - } - if (padding_config_ != nullptr) { - extra.push_back( - StrCat("padding=", xla::PaddingConfigToString(*padding_config_))); - } - if (opcode() == HloOpcode::kDynamicSlice) { - extra.push_back( - StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")); - } - - if (convolution_dimension_numbers_ != nullptr) { - extra.push_back(StrCat( - "dim_labels=", - ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_))); - } if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); } - if (gather_dimension_numbers_ != nullptr) { - extra.push_back(GatherDimensionNumbersToString()); - extra.push_back( - StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")); - } if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { @@ -1973,7 +2040,8 @@ std::vector HloInstruction::ExtraAttributesToString( } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || opcode() == HloOpcode::kReduceWindow || opcode() == HloOpcode::kReduce || - opcode() == HloOpcode::kCrossReplicaSum) { + opcode() == HloOpcode::kCrossReplicaSum || + opcode() == HloOpcode::kScatter) { extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { @@ -2008,6 +2076,8 @@ std::vector HloInstruction::ExtraAttributesToString( case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kScatter: extra.push_back( StrCat("to_apply=\n", to_apply()->ToString(new_options))); break; @@ -2036,32 +2106,10 @@ std::vector HloInstruction::ExtraAttributesToString( }), "}")); } - if (opcode() == HloOpcode::kInfeed && !infeed_config_.empty()) { - extra.push_back(StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")); - } - if (opcode() == HloOpcode::kOutfeed && !outfeed_config_.empty()) { - extra.push_back( - StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")); - } if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), - "\", entry=", operand_side_metadata_->ToString(), - ", exit=", user_side_metadata_->ToString(), "}")); - } - if (!replica_group_ids().empty()) { - extra.push_back( - StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}")); - } - if (!cross_replica_sum_barrier().empty()) { - extra.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); - } - - // By contract, we print the custom call target even if - // options.print_subcomputation_mode() == kOff, because the call target is not - // an HloComputation. - if (opcode() == HloOpcode::kCustomCall) { - extra.push_back( - StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); + "\", entry=", user_side_metadata_->ToString(), + ", exit=", operand_side_metadata_->ToString(), "}")); } return extra; @@ -2100,46 +2148,14 @@ HloInstructionProto HloInstruction::ToProto() const { } } - if (window_ != nullptr) { - *proto.mutable_window() = *window_; - } - if (convolution_dimension_numbers_ != nullptr) { - *proto.mutable_convolution_dimension_numbers() = - *convolution_dimension_numbers_; - } if (dot_dimension_numbers_ != nullptr) { *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; } - if (gather_dimension_numbers_ != nullptr) { - *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_; - } - if (opcode() == HloOpcode::kGather) { - for (int64 bound : gather_window_bounds()) { - proto.add_gather_window_bounds(bound); - } - } - - for (int64 slice_size : dynamic_slice_sizes_) { - proto.add_dynamic_slice_sizes(slice_size); - } - if (padding_config_ != nullptr) { - *proto.mutable_padding_config() = *padding_config_; - } - proto.set_outfeed_config(outfeed_config_); - proto.set_infeed_config(infeed_config_); - proto.set_custom_call_target(custom_call_target_); - *proto.mutable_outfeed_shape() = outfeed_shape_; if (has_sharding()) { *proto.mutable_sharding() = sharding().ToProto(); } - proto.set_channel_name(channel_name_); - proto.set_cost_estimate_ns(cost_estimate_ns_); - for (int64 replica_group_id : replica_group_ids_) { - proto.add_replica_group_ids(replica_group_id); - } - return proto; } @@ -2149,35 +2165,6 @@ string HloInstruction::ToCategory() const { return "data formatting"; } - if (opcode() == HloOpcode::kConvolution) { - string category = "convolution"; - if (window_util::HasBaseDilation(window())) { - category += " base-dilated"; - } - if (window_util::HasWindowDilation(window())) { - category += " window-dilated"; - } - return category; - } - - // Give transpose-dot and backwards-conv fusions the categories "dot" and - // "convolution" so they match the categories of proper kDot and kConvolution - // ops. These fusion categories are really just a way of expressing a - // particular kind of dot or conv, so they should have the same category as a - // vanilla dot/conv. - if (opcode() == HloOpcode::kFusion) { - switch (fusion_kind()) { - case FusionKind::kLoop: - return "loop fusion"; - case FusionKind::kInput: - return "input fusion"; - case FusionKind::kOutput: - return "output fusion"; - case FusionKind::kCustom: - return "custom fusion"; - } - } - if (IsElementwise()) { return "non-fusion elementwise"; } @@ -2263,6 +2250,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleAnd(this); case HloOpcode::kOr: return visitor->HandleOr(this); + case HloOpcode::kXor: + return visitor->HandleXor(this); case HloOpcode::kShiftLeft: return visitor->HandleShiftLeft(this); case HloOpcode::kShiftRightArithmetic: @@ -2287,12 +2276,16 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleRemainder(this); case HloOpcode::kSelect: return visitor->HandleSelect(this); + case HloOpcode::kTupleSelect: + return visitor->HandleTupleSelect(this); case HloOpcode::kConvolution: return visitor->HandleConvolution(this); case HloOpcode::kFft: return visitor->HandleFft(this); case HloOpcode::kCrossReplicaSum: return visitor->HandleCrossReplicaSum(this); + case HloOpcode::kAllToAll: + return visitor->HandleAllToAll(this); case HloOpcode::kTuple: return visitor->HandleTuple(this); case HloOpcode::kMap: @@ -2385,10 +2378,14 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSendDone(this); case HloOpcode::kGather: return visitor->HandleGather(this); + case HloOpcode::kScatter: + return visitor->HandleScatter(this); case HloOpcode::kDomain: return visitor->HandleDomain(this); - case HloOpcode::kGenerateToken: - return visitor->HandleGenerateToken(this); + case HloOpcode::kAfterAll: + return visitor->HandleAfterAll(this); + case HloOpcode::kIota: + return visitor->HandleIota(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -2629,12 +2626,6 @@ Status HloInstruction::AcceptOrdered( return visitor->FinishVisit(this); } -const Shape& HloInstruction::outfeed_shape() const { - DCHECK_EQ(opcode_, HloOpcode::kOutfeed); - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); - return outfeed_shape_; -} - const Shape& HloInstruction::shape() const { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); return shape_; @@ -2926,26 +2917,6 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } -string HloInstruction::GatherDimensionNumbersToString() const { - CHECK_NE(gather_dimension_numbers_.get(), nullptr); - string output_window_dims = - StrCat("output_window_dims={", - Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); - string elided_window_dims = - StrCat("elided_window_dims={", - Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); - string gather_dims_to_operand_dims = StrCat( - "gather_dims_to_operand_dims={", - Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); - string index_vector_dim = StrCat( - "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); - - return Join>( - {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, - index_vector_dim}, - ", "); -} - bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: @@ -3168,4 +3139,121 @@ int32 HloInstruction::exponent_bits() const { int32 HloInstruction::mantissa_bits() const { return Cast(this)->mantissa_bits(); } + +string HloInstruction::infeed_config() const { + return Cast(this)->infeed_config(); +} + +void HloInstruction::set_infeed_config(const string& config) { + return Cast(this)->set_infeed_config(config); +} + +const Shape& HloInstruction::outfeed_shape() const { + return Cast(this)->outfeed_shape(); +} + +const string& HloInstruction::outfeed_config() const { + return Cast(this)->outfeed_config(); +} + +const std::vector& HloInstruction::replica_group_ids() const { + return Cast(this)->replica_group_ids(); +} + +const std::vector& HloInstruction::replica_groups() const { + return Cast(this)->replica_groups(); +} + +string HloInstruction::cross_replica_sum_barrier() const { + if (opcode() == HloOpcode::kCrossReplicaSum) { + return Cast(this)->cross_replica_sum_barrier(); + } + return Cast(this)->cross_replica_sum_barrier(); +} + +void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) { + if (opcode() == HloOpcode::kCrossReplicaSum) { + return Cast(this)->set_cross_replica_sum_barrier( + barrier); + } + return Cast(this)->set_cross_replica_sum_barrier( + barrier); +} + +tensorflow::gtl::optional HloInstruction::all_reduce_id() const { + return Cast(this)->all_reduce_id(); +} + +const ConvolutionDimensionNumbers& +HloInstruction::convolution_dimension_numbers() const { + if (auto convolution = DynCast(this)) { + return convolution->convolution_dimension_numbers(); + } + if (auto custom_call = DynCast(this)) { + return custom_call->convolution_dimension_numbers(); + } + LOG(FATAL) << "Unimplemented method."; +} + +void HloInstruction::set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums) { + if (auto convolution = DynCast(this)) { + convolution->set_convolution_dimension_numbers(dnums); + } else if (auto custom_call = DynCast(this)) { + custom_call->set_convolution_dimension_numbers(dnums); + } else { + LOG(FATAL) << "Unimplemented method."; + } +} + +HloComputation* HloInstruction::select() const { + return Cast(this)->select(); +} + +HloComputation* HloInstruction::scatter() const { + return Cast(this)->scatter(); +} + +void HloInstruction::set_select(HloComputation* computation) { + return Cast(this)->set_select(computation); +} + +void HloInstruction::set_scatter(HloComputation* computation) { + return Cast(this)->set_scatter(computation); +} + +const string& HloInstruction::custom_call_target() const { + return Cast(this)->custom_call_target(); +} + +const string& HloInstruction::channel_name() const { + return Cast(this)->channel_name(); +} + +const PaddingConfig& HloInstruction::padding_config() const { + return Cast(this)->padding_config(); +} + +int64 HloInstruction::slice_sizes(int64 dimension) const { + return Cast(this)->slice_sizes(dimension); +} + +const std::vector& HloInstruction::dynamic_slice_sizes() const { + return Cast(this)->dynamic_slice_sizes(); +} + +const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const { + return Cast(this)->gather_dimension_numbers(); +} + +tensorflow::gtl::ArraySlice HloInstruction::gather_window_bounds() + const { + return Cast(this)->gather_window_bounds(); +} + +const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers() + const { + return Cast(this)->scatter_dimension_numbers(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index a206cdab2739cd5046295179ead5d8bf19d521c4..3c575ae6ea8e60f48def4debcd9cfbea63e396b2 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -33,7 +33,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/iterator_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -346,6 +346,9 @@ class HloInstruction { static std::unique_ptr CreateConstant( std::unique_ptr literal); + // Creates an Iota instruction. + static std::unique_ptr CreateIota(const Shape& shape); + // Creates a get tuple element instruction. static std::unique_ptr CreateGetTupleElement( const Shape& shape, HloInstruction* operand, int64 index); @@ -389,11 +392,10 @@ class HloInstruction { // Creates a map instruction, where the computation (given by the handle) is // applied element-wise to every element in operands (across the operands, - // at a given index) with the same `static_operands`. + // at a given index) static std::unique_ptr CreateMap( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation, - tensorflow::gtl::ArraySlice static_operands = {}); + HloComputation* map_computation); // Creates a convolution op, where rhs is the convolutional filter // and window describes how the filter is applied to lhs. @@ -435,9 +437,9 @@ class HloInstruction { // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // - // `channel_id`: for Allreduce nodes from different models, if they have the - // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be - // applied cross models. + // `all_reduce_id`: for Allreduce nodes from different modules, if they have + // the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will + // not be applied cross modules. // // TODO(b/79737069): Rename this to AllReduce. static std::unique_ptr CreateCrossReplicaSum( @@ -445,8 +447,27 @@ class HloInstruction { HloComputation* reduce_computation, tensorflow::gtl::ArraySlice replica_group_ids, tensorflow::StringPiece barrier, - const tensorflow::gtl::optional& channel_id = - tensorflow::gtl::nullopt); + const tensorflow::gtl::optional& all_reduce_id); + + // This op handles the communication of an Alltoall operation. On each core, + // the operands are N ops in the same shape, where N is the number of cores + // participating the Alltoall. Then the N operands are scattered to N cores, + // e.g., the ith operand is sent to the ith core. Then each core gathers the + // received data into a tuple. + // + // - `replica_groups`: each ReplicaGroup contains a list of replica id. If + // empty, all replicas belong to one group in the order of 0 - (n-1). Alltoall + // will be applied within subgroups in the specified order. For example, + // replica groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied + // within replica 1, 2, 3, and in the gather phase, the received blocks will + // be concatenated in the order of 1, 2, 3; another Alltoall will be applied + // within replica 4, 5, 0, and the concatenation order is 4, 5, 0. + // + // TODO(b/110096724): This is NOT YET ready to use. + static std::unique_ptr CreateAllToAll( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + const std::vector& replica_groups, + tensorflow::StringPiece barrier); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. @@ -459,36 +480,57 @@ class HloInstruction { const Shape& shape, HloInstruction* operand); // Creates an infeed instruction, which reads data of the given shape from the - // Infeed interface of the device. - static std::unique_ptr CreateInfeed(const Shape& shape, + // Infeed interface of the device. infeed_shape is the shape of the data + // received from the infeed *not* the shape of the infeed instruction which + // is a tuple containing the infeed_shape and the TOKEN. + static std::unique_ptr CreateInfeed( + const Shape& infeed_shape, HloInstruction* token_operand, + const string& config); + // Overload which does not require a token. + // TODO(b/80000000): Remove this overload when all uses of infeed are + // converted to take tokens. + static std::unique_ptr CreateInfeed(const Shape& infeed_shape, const string& config); - // Creates an outfeed instruction, which outputs data. + // Creates an outfeed instruction, which outputs data. outfeed_shape is the + // shape of the data being outfed *not* the shape of the outfeed instruction + // which is a TOKEN. static std::unique_ptr CreateOutfeed( - const Shape& shape, HloInstruction* operand, + const Shape& outfeed_shape, HloInstruction* operand, + HloInstruction* token_operand, tensorflow::StringPiece outfeed_config); + // Overload which does not require a token. + // TODO(b/80000000): Remove this overload when all uses of outfeed are + // converted to take tokens. + static std::unique_ptr CreateOutfeed( + const Shape& outfeed_shape, HloInstruction* operand, tensorflow::StringPiece outfeed_config); // Creates an asynchronous send instruction with the given channel id, which // initiates sending the operand data to a unique receive instruction in - // another computation that has the same channel id. - static std::unique_ptr CreateSend(HloInstruction* operand, - int64 channel_id); + // another computation that has the same channel id. If is_host_transfer is + // true, then this Send operation transfers data to the host. + static std::unique_ptr CreateSend( + HloInstruction* operand, HloInstruction* token, int64 channel_id, + bool is_host_transfer = false); // Blocks until data transfer for the Send instruction (operand) is complete. // The operand must be kSend. static std::unique_ptr CreateSendDone( - HloInstruction* operand); + HloInstruction* operand, bool is_host_transfer = false); // Creates an asynchronous receive instruction with the given channel id, // which allocates resources to receive data of the given shape from a unique - // send instruction in another computation that has the same channel id. - static std::unique_ptr CreateRecv(const Shape& shape, - int64 channel_id); + // send instruction in another computation that has the same channel id. If + // is_host_transfer is true, then this Send operation transfers data from the + // host. + static std::unique_ptr CreateRecv( + const Shape& shape, HloInstruction* token, int64 channel_id, + bool is_host_transfer = false); // Blocks until data transfer for the Recv instruction (operand) is complete // and returns the receive buffer. The operand must be kRecv. static std::unique_ptr CreateRecvDone( - HloInstruction* operand); + HloInstruction* operand, bool is_host_transfer = false); // Creates a slice instruction, where the operand is sliced by the given // start/limit indices. @@ -519,17 +561,34 @@ class HloInstruction { int64 dimension); // Creates a reduce instruction, where the computation (given by the handle) - // is applied successively to every element in operand. That is, if f is the - // function to apply (which either takes 2 [accumulator, value] or 3 - // [accumulator, index, value] arguments) and init is a reduction operator - // specified initial value (for example, 0 for addition), then this operation - // will compute: - // f(f(init, [index0], value0), [index1], value1), ...) + // is applied successively to every element in operand. For example, let f be + // the function to apply, which takes 2 arguments, an accumulator and the + // current value. Let init be an initial value (which is normally chosen to be + // the identity element for f, e.g. 0 if f is addition). + // Then the reduce HLO will compute: + // f(f(init, value0), value1), ...) static std::unique_ptr CreateReduce( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation); + // A more general, multiple-argument version of the above. + // The function to apply, f, now takes N arguments: + // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ..., + // init_valueN], and returns an N-tuple. The performed computation is (for + // commutative and associative f operators) equivalent to: + // + // f_1 = f(init0, ... initN, input0.value0, ..., inputN.value0) + // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1, + // ..., inputN.value1) + // ... + // TODO(b/112040122): Add support to this in HLO passes and in backends. + static std::unique_ptr CreateReduce( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::gtl::ArraySlice init_values, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation); + // Creates a reduce-window instruction, where the computation (given // by the handle) is applied window-wise at each valid window // position in the operand. @@ -596,6 +655,11 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); + // Creates a sort op, with a keys operand, and an optional values operand. + static std::unique_ptr CreateSort( + const Shape& shape, int64 dimension, HloInstruction* keys, + HloInstruction* values = nullptr); + // Creates a while instruction, given a condition computation, a body // computation, and the initial value for the input of the computations. For // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1 @@ -617,6 +681,12 @@ class HloInstruction { const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice window_bounds); + static std::unique_ptr CreateScatter( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers); + // Creates a kDomain instruction which delimits an HLO domain which have // the provided user and operand side metadata. static std::unique_ptr CreateDomain( @@ -665,17 +735,18 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); - // Creates a token instruction used for joining or creating token types which - // thread through side-effecting operations. - static std::unique_ptr CreateGenerateToken( + // Creates a Afterall instruction used for joining or creating new values of + // token type which thread through side-effecting operations. Operands must + // all be tokens, and there must be at least one operand. + static std::unique_ptr CreateAfterAll( tensorflow::gtl::ArraySlice operands); - // Creates an instance of GatherDimensionNumbers. - static GatherDimensionNumbers MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice output_window_dims, - tensorflow::gtl::ArraySlice elided_window_dims, - tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, - int64 index_vector_dim); + // Creates an AfterAll instruction which creates a token type out of thin air + // (no operands). This is a separate method from CreateAfterAll to facility + // the removal of operand-less AfterAll instructions. + // TODO(b/110532604): Remove this capability of creating a token from nothing + // when we plumb a primordial token from the entry computation. + static std::unique_ptr CreateToken(); // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } @@ -811,9 +882,15 @@ class HloInstruction { // Replaces the use of this instruction in "user" with "new_producer". Note // that there might be multiple uses of this instruction in "user"; all will // be replaced. + // + // If user is a fusion instruction, this function will remove any duplicated + // operands of it which could be created due to this replacement. Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer); // Replaces the specified operand with new_operand. + // + // This function does NOT remove duplicated operands even if this instruction + // is a fusion, so that the existing operand numbers do not change. Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand); // Replaces all uses of this instruction with the new producer. If @@ -822,14 +899,10 @@ class HloInstruction { // // If this instruction is the root of its computation, sets the computation's // root to new_producer. - Status ReplaceAllUsesWith(HloInstruction* new_producer); - - // Detaches an instruction from its operands. That is, remove the instruction - // from each operand's user set. This should only be called prior to - // deallocating the instruction. // - // TODO(b/78305363): Make this automatic when deleting an instruction. - void DetachFromOperands(); + // If a user is a fusion instruction, this function will remove any duplicated + // operands of it which could be created due to this replacement. + Status ReplaceAllUsesWith(HloInstruction* new_producer); // Performs a postorder DFS visit using this node as the root. If // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when @@ -903,18 +976,6 @@ class HloInstruction { HloComputation* to_apply() const; void set_to_apply(HloComputation* to_apply); - // Returns the custom_call_target for CustomCall. - // Precondition: opcode() == HloOpcode::kCustomCall - const string& custom_call_target() const; - - // Returns the config for the Outfeed instruction. - // Precondition: opcode() == HloOpcode::kOutfeed - const string& outfeed_config() const; - - // Returns the shape for the Outfeed instruction. - // Precondition: opcode() == HloOpcode::kOutfeed - const Shape& outfeed_shape() const; - // Gets/sets the while_condition or while_body HloComputation for While. The // setters should only be called by HloModule or HloComputation methods. // @@ -924,15 +985,6 @@ class HloInstruction { void set_while_condition(HloComputation* while_condition); void set_while_body(HloComputation* while_body); - // Gets/sets the select or scatter HloComputation for SelectAndScatter. The - // setters should only be called by HloModule or HloComputation methods. - // - // Precondition: opcode() == HloOpcode::kSelectAndScatter. - HloComputation* select() const; - HloComputation* scatter() const; - void set_select(HloComputation* select); - void set_scatter(HloComputation* scatter); - // Gets/sets the true and false HloComputation for Conditional. The setters // should only be called by HloModule or HloComputation methods. // @@ -974,7 +1026,7 @@ class HloInstruction { // Returns a category for the HLO. This could be something like "convolution" // or "elementwise". - string ToCategory() const; + virtual string ToCategory() const; // Returns a logging instruction, if the output of this instruction is logged. // @@ -982,18 +1034,6 @@ class HloInstruction { HloInstruction* tracing() const; void set_tracing(HloInstruction* trace_instruction); - // Returns the channel name associated with the instruction. The name is - // used to identify host Send/Recv operations. - // - // Precondition: opcode() == HloOpcode::kHostCompute - string channel_name() const { return channel_name_; } - - // Returns the infeed configuration string. The infeed configuration includes - // any metadata needed for the backend compiler (e.g., infeed buffer address) - // and is target-dependent. - string infeed_config() const { return infeed_config_; } - void set_infeed_config(const string& config) { infeed_config_ = config; } - // Returns true if this instruction is fused, ie contained within a fusion // instruction. bool IsFused() const; @@ -1017,9 +1057,7 @@ class HloInstruction { if (sharding_ == nullptr) { return tensorflow::gtl::optional(); } - auto device = sharding_->UniqueDevice(); - return device.ok() ? device.ValueOrDie() - : tensorflow::gtl::optional(); + return sharding_->UniqueDevice(); } // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. @@ -1060,54 +1098,17 @@ class HloInstruction { // instruction. void SetupDerivedInstruction(HloInstruction* derived_instruction) const; - // Returns the size of the slice in the given dimension for a dynamic - // slice node. - // - // Precondition: opcode() == HloOpcode::kDynamicSlice - int64 slice_sizes(int64 dimension) const { - CHECK_EQ(HloOpcode::kDynamicSlice, opcode_); - return dynamic_slice_sizes_[dimension]; - } - const std::vector& dynamic_slice_sizes() const { - CHECK_EQ(HloOpcode::kDynamicSlice, opcode_); - return dynamic_slice_sizes_; + // TODO(b/80249101): Remove these methods once HLO scheduling and copy + // insertion are integrated, and we don't need to run a separate pass + // of copy elision anymore. + bool CopyElisionAllowed() const { + CHECK_EQ(HloOpcode::kCopy, opcode_); + return copy_elision_allowed_; } - // Returns data on the window in a windowed operation such as - // convolution. - const Window& window() const { - CHECK(window_ != nullptr); - return *window_; - } - - // Sets the window data in a windowed operation such as convolution. - void set_window(const Window& window) { - window_ = MakeUnique(window); - } - - // Returns the padding configuration for a pad node. - // - // Precondition: opcode() == HloOpcode::kPad - const PaddingConfig& padding_config() const { - CHECK(padding_config_ != nullptr); - return *padding_config_; - } - - // Returns data on the dimension numbers used for a convolution operation, - // which may be a kConvolution instruction or a kCustomCall that implements a - // convolution. - const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { - CHECK(convolution_dimension_numbers_ != nullptr); - return *convolution_dimension_numbers_; - } - - // Sets the convolution dimension numbers on this instruction. In general you - // shouldn't need to call this; instead, specify the convolution dimension - // numbers when you create the instruction. - void set_convolution_dimension_numbers( - const ConvolutionDimensionNumbers& dnums) { - convolution_dimension_numbers_ = - MakeUnique(dnums); + void SetCopyElisionAllowed(bool value) { + CHECK_EQ(HloOpcode::kCopy, opcode_); + copy_elision_allowed_ = value; } // Returns data on the dimension numbers used for a dot operation. @@ -1119,19 +1120,6 @@ class HloInstruction { // Returns the dump string of the dot dimension numbers. string DotDimensionNumbersToString() const; - const GatherDimensionNumbers& gather_dimension_numbers() const { - CHECK(gather_dimension_numbers_ != nullptr); - return *gather_dimension_numbers_; - } - - tensorflow::gtl::ArraySlice gather_window_bounds() const { - CHECK_EQ(opcode(), HloOpcode::kGather); - return gather_window_bounds_; - } - - // Returns the dump string of the gather dimension numbers. - string GatherDimensionNumbersToString() const; - // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of @@ -1186,6 +1174,9 @@ class HloInstruction { // Returns true if this instruction is elementwise on all its operands. bool IsElementwise() const; + // Returns true if this is an cross module all-reduce instrucion. + bool IsCrossModuleAllReduce() const; + // Returns true if this elementwise instruction implicitly broadcasts operand // `operand_idx`. // @@ -1422,27 +1413,96 @@ class HloInstruction { // Delegates to HloGetTupleElementInstruction::tuple_index. int64 tuple_index() const; - // Returns the number of exponent bits for a reduce-precision node. + // Delegates to HloReducePrecisionInstruction::exponent_bits. int32 exponent_bits() const; - // Returns the number of mantissa bits for a reduce-precision node. + // Delegates to HloReducePrecisionInstruction::mantissa_bits. int32 mantissa_bits() const; - // Old methods kept for smooth subclassing transition END. - // Returns the group ids of each replica for CrossReplicaSum op. - const std::vector& replica_group_ids() const { - return replica_group_ids_; - } + // Delegates to HloInfeedInstruction::infeed_config. + string infeed_config() const; + + // Delegates to HloInfeedInstruction::set_infeed_config. + void set_infeed_config(const string& config); + + // Returns the config for the Outfeed instruction. + const string& outfeed_config() const; + + // Returns the shape for the Outfeed instruction. + const Shape& outfeed_shape() const; + + // Delegates to HloAllReduceInstruction::replica_group_ids. + const std::vector& replica_group_ids() const; + + // Delegates to HloAllToAllInstruction::replica_groups. + const std::vector& replica_groups() const; + + // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier. + string cross_replica_sum_barrier() const; + void set_cross_replica_sum_barrier(const string& barrier); + + // Delegates to HloAllReduceInstruction::all_reduce_id. + tensorflow::gtl::optional all_reduce_id() const; - // Returns the barrier config used for the CrossReplicaSum implementation of - // each backend. - string cross_replica_sum_barrier() const { - return cross_replica_sum_barrier_; + // Returns data on the window in a windowed operation such as + // convolution. + virtual const Window& window() const { + LOG(FATAL) << "Unimplemented method."; } - void set_cross_replica_sum_barrier(string barrier) { - cross_replica_sum_barrier_ = barrier; + + // Sets the window data in a windowed operation such as convolution. + virtual void set_window(const Window& window) { + LOG(FATAL) << "Unimplemented method."; } + // Returns data on the dimension numbers used for a convolution operation, + // which may be a kConvolution instruction or a kCustomCall that implements a + // convolution. + const ConvolutionDimensionNumbers& convolution_dimension_numbers() const; + + // Sets the convolution dimension numbers on this instruction. In general you + // shouldn't need to call this; instead, specify the convolution dimension + // numbers when you create the instruction. + void set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums); + + // Delegates to HloSelectAndScatterInstruction::select. + HloComputation* select() const; + + // Delegates to HloSelectAndScatterInstruction::scatter. + HloComputation* scatter() const; + + // Delegates to HloSelectAndScatterInstruction::set_select. + void set_select(HloComputation* computation); + + // Delegates to HloSelectAndScatterInstruction::set_scatter. + void set_scatter(HloComputation* computation); + + // Delegates to HloCustomCallInstruction::custom_call_target. + const string& custom_call_target() const; + + // Delegates to HloHostComputeInstruction::channel_name. + const string& channel_name() const; + + // Delegates to HloPadInstruction::padding_config. + const PaddingConfig& padding_config() const; + + // Delegates to HloDynamicSliceInstruction::slice_sizes. + int64 slice_sizes(int64 dimension) const; + + // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes. + const std::vector& dynamic_slice_sizes() const; + + // Delegates to HloGatherInstruction::gather_dimension_numbers. + const GatherDimensionNumbers& gather_dimension_numbers() const; + // Delegates to HloGatherInstruction::gather_window_bounds. + tensorflow::gtl::ArraySlice gather_window_bounds() const; + + // Delegates to HloScatterInstruction::scatter_dimension_numbers(). + const ScatterDimensionNumbers& scatter_dimension_numbers() const; + + // Old methods kept for smooth subclassing transition END. + protected: enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; // Helper class for computing OperandElementUse for kFusion. @@ -1460,12 +1520,35 @@ class HloInstruction { operands_.erase(operands_.begin() + index); } + // Removes a list of operands with the given indices in ascending order. + void RemoveOperandsAtAscendingIndices( + tensorflow::gtl::ArraySlice ascending_indices); + void AppendComputation(HloComputation* computation) { called_computations_.push_back(computation); } void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); } + void set_called_computation(int index, HloComputation* computation) { + called_computations_[index] = computation; + } + // Indices of computations in called_computations_ for instructions which call + // multiple computations. + enum { + // kWhile computations. + kBodyComputationIndex = 0, + kConditionComputationIndex = 1, + + // kSelectAndScatter computations. + kSelectComputationIndex = 0, + kScatterComputationIndex = 1, + + // kConditional computations. + kTrueComputationIndex = 0, + kFalseComputationIndex = 1, + }; + private: // Implementation for non-common logic of CloneWithNewOperands. virtual std::unique_ptr CloneWithNewOperandsImpl( @@ -1555,31 +1638,14 @@ class HloInstruction { // The computation in which this instruction is contained. HloComputation* parent_ = nullptr; - // Shape of outfeed request. - Shape outfeed_shape_; - // Result shape of this instruction. Shape shape_; - // Describes the window in a windowed operation such as convolution. - std::unique_ptr window_; - - // Describes the dimension numbers used for a convolution. - std::unique_ptr convolution_dimension_numbers_; - // Describes the dimension numbers used for a dot. std::unique_ptr dot_dimension_numbers_; - std::unique_ptr gather_dimension_numbers_; - std::vector gather_window_bounds_; - - // Describes the [start, start + size) range size for a dynamic slice - // ('start' is specified dynamically in the second operand of the operation). - std::vector dynamic_slice_sizes_; - - // The padding configuration that describes the edge padding and interior - // padding of this pad instruction. Only set for pad instructions. - std::unique_ptr padding_config_; + // Used to tag kCopy instructions that are eligible for copy elision. + bool copy_elision_allowed_ = true; // The sharding, if one exists. std::unique_ptr sharding_; @@ -1588,56 +1654,19 @@ class HloInstruction { std::unique_ptr operand_side_metadata_; std::unique_ptr user_side_metadata_; - // Name of a global symbol to call, only present for kCustomCall. - string custom_call_target_; - - // Name to use for host send/recv channels, only present for kHostCompute. - string channel_name_; - - // Estimate of the duration of a host computation in nanoseconds. - int64 cost_estimate_ns_ = 0; - // Computations called by this instruction. std::vector called_computations_; - // Indices of computations in called_computations_ for instructions which call - // multiple computations. - enum { - // kWhile computations. - kBodyComputationIndex = 0, - kConditionComputationIndex = 1, - - // kSelectAndScatter computations. - kSelectComputationIndex = 0, - kScatterComputationIndex = 1, - - // kConditional computations. - kTrueComputationIndex = 0, - kFalseComputationIndex = 1, - }; - - // Outfeed configuration information, only present for kOutfeed. - string outfeed_config_; - // A trace instruction that consumes this instruction. // // Invariant: if trace_instruction_ != nullptr, trace_instruction has this as // an operand. HloInstruction* trace_instruction_ = nullptr; - // The string representation of the infeed configuration. - string infeed_config_; - // The backend-specific configuration for how a backend should compile this // HLO. See the documentation on backend_config(). string backend_config_; - // The group id of each replica for CrossReplicaSum. - std::vector replica_group_ids_; - - // The string representation of the barrier config used for CrossReplicaSum. - string cross_replica_sum_barrier_; - // String identifier for instruction. string name_; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 5d6f8b931f0c665fba03e1c845214fa83aabf12e..8a694dde8066ab9a1138b9f7981153d451ddb89e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -20,10 +20,11 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -249,7 +250,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperands) { auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r0f32_, "param1")); auto c0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto addleft = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param0, c0)); auto addright = builder.AddInstruction( @@ -294,7 +295,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) { auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r0f32_, "param1")); auto c0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto neg1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, c0)); auto addleft = builder.AddInstruction( @@ -334,7 +335,7 @@ TEST_F(HloInstructionTest, TrivialMap) { auto param = embedded_builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "x")); auto value = embedded_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); embedded_builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value)); auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build()); @@ -383,9 +384,9 @@ TEST_F(HloInstructionTest, TrivialReduce) { auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, f32a100x10, "p")); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto reduce = builder.AddInstruction( HloInstruction::CreateReduce(f32v100, param0, const0, /*dimensions_to_reduce=*/{1}, add_f32)); @@ -626,7 +627,7 @@ TEST_F(HloInstructionTest, SingletonFusionOp) { HloComputation::Builder builder(TestName()); // Create a fusion instruction containing a single unary operation. auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); auto module = CreateNewModule(); @@ -642,9 +643,9 @@ TEST_F(HloInstructionTest, BinaryFusionOp) { HloComputation::Builder builder(TestName()); // Create a fusion instruction containing a single binary operation. auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.1f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); auto module = CreateNewModule(); @@ -661,7 +662,7 @@ TEST_F(HloInstructionTest, ChainFusionOp) { HloComputation::Builder builder(TestName()); // Create a chain of fused unary ops. auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto exp1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); auto exp2 = builder.AddInstruction( @@ -682,7 +683,7 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { HloComputation::Builder builder(TestName()); // Create a chain of fused unary ops. auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto exp1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); auto exp2 = builder.AddInstruction( @@ -710,16 +711,17 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { HloComputation::Builder builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({ + HloInstruction::CreateConstant(LiteralUtil::CreateR2({ {1, 2}, {3, 4}, }))); auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); auto outfeed10 = builder.AddInstruction( - HloInstruction::CreateOutfeed(shape10, constant, "")); + HloInstruction::CreateOutfeed(shape10, constant, token, "")); auto outfeed01 = builder.AddInstruction( - HloInstruction::CreateOutfeed(shape01, constant, "")); + HloInstruction::CreateOutfeed(shape01, constant, token, "")); auto clone01 = builder.AddInstruction(outfeed01->Clone()); auto clone10 = builder.AddInstruction(outfeed10->Clone()); @@ -731,7 +733,7 @@ TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) { HloComputation::Builder builder(TestName()); auto* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({ + HloInstruction::CreateConstant(LiteralUtil::CreateR2({ {1, 2}, {3, 4}, }))); @@ -762,13 +764,13 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { HloComputation::Builder builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); - auto map_1_x = builder.AddInstruction(HloInstruction::CreateMap( - scalar_shape, {constant}, computation_x, /*static_operands=*/{})); - auto map_2_x = builder.AddInstruction(HloInstruction::CreateMap( - scalar_shape, {map_1_x}, computation_x, /*static_operands=*/{})); - auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap( - scalar_shape, {map_2_x}, computation_y, /*static_operands=*/{})); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); + auto map_1_x = builder.AddInstruction( + HloInstruction::CreateMap(scalar_shape, {constant}, computation_x)); + auto map_2_x = builder.AddInstruction( + HloInstruction::CreateMap(scalar_shape, {map_1_x}, computation_x)); + auto map_3_y = builder.AddInstruction( + HloInstruction::CreateMap(scalar_shape, {map_2_x}, computation_y)); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( @@ -797,11 +799,11 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { // Notable complexities are repeated operands in the same instruction, // different shapes, use of value in different expressions. auto c1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto c2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.1f))); auto c3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(9.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(9.0f))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2)); @@ -872,11 +874,11 @@ TEST_F(HloInstructionTest, IdenticalInstructions) { // Create a set of random constant operands to use below. Make them matrices // so dimensions are interesting. auto operand1 = HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); auto operand2 = HloInstruction::CreateConstant( - Literal::CreateR2({{10.0, 20.0}, {30.0, 40.0}})); - auto vector_operand = - HloInstruction::CreateConstant(Literal::CreateR1({42.0, 123.0})); + LiteralUtil::CreateR2({{10.0, 20.0}, {30.0, 40.0}})); + auto vector_operand = HloInstruction::CreateConstant( + LiteralUtil::CreateR1({42.0, 123.0})); Shape shape = operand1->shape(); // Convenient short names for the operands. @@ -923,6 +925,40 @@ TEST_F(HloInstructionTest, IdenticalInstructions) { *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2))); } +TEST_F(HloInstructionTest, IdenticalCallInstructions) { + const char* const hlo_string = R"( +HloModule Module + +subcomp1 (x: f32[]) -> f32[] { + x = f32[] parameter(0) + ROOT n = f32[] sine(x) +} + +subcomp2 (x: f32[]) -> f32[] { + x = f32[] parameter(0) + ROOT n = f32[] cosine(x) +} + +ENTRY entry (param: f32[]) -> (f32[], f32[], f32[]) { + p = f32[] parameter(0) + t1 = f32[] call(p), to_apply=subcomp1 + t2 = f32[] call(p), to_apply=subcomp1 + t3 = f32[] call(p), to_apply=subcomp2 + ROOT t = (f32[], f32[], f32[]) tuple(t1, t2, t3) + } +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + auto* root = module->entry_computation()->root_instruction(); + auto* t1 = root->operand(0); + auto* t2 = root->operand(1); + auto* t3 = root->operand(2); + + EXPECT_TRUE(StructuralEqual(*t1, *t2)); + EXPECT_FALSE(StructuralEqual(*t1, *t3)); +} + TEST_F(HloInstructionTest, FunctionVisitor) { // Verify the function visitor HloInstruction::Accept visits all instructions // from a root properly given the following graph: @@ -1136,6 +1172,40 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { EXPECT_TRUE(StructuralEqual(*fusion, *fusion2)); } +TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) { + // Fused expression: + // + // x y + // | | + // | transpose + // \ / + // dot + const Shape s = ShapeUtil::MakeShape(F32, {10, 10}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(s, x, reshape, dot_dnums)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + HloInstruction* fusion = computation->CreateFusionInstruction( + {dot, reshape}, HloInstruction::FusionKind::kLoop); + + EXPECT_TRUE(x->ReplaceAllUsesWith(y).ok()); + + EXPECT_THAT(fusion->operands(), UnorderedElementsAre(y)); + EXPECT_EQ(fusion->fused_instructions_computation()->num_parameters(), 1); +} + TEST_F(HloInstructionTest, FusionEquality) { auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); @@ -1165,9 +1235,9 @@ TEST_F(HloInstructionTest, NestedFusionEquality) { // Build a nested fusion computation. Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto b_t = builder.AddInstruction( HloInstruction::CreateTranspose(data_shape, b, {1, 0})); DotDimensionNumbers dot_dnums; @@ -1176,7 +1246,7 @@ TEST_F(HloInstructionTest, NestedFusionEquality) { auto dot = builder.AddInstruction( HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add_operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); auto add = builder.AddInstruction(HloInstruction::CreateBinary( @@ -1273,7 +1343,7 @@ TEST_F(HloInstructionTest, Stringification) { "condition=%TransposeDot, body=%TransposeDot"); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); HloInstruction* conditional = builder.AddInstruction(HloInstruction::CreateConditional( sout, pred, x, computation, x, computation)); @@ -1300,7 +1370,7 @@ TEST_F(HloInstructionTest, StringifyGather_0) { HloInstruction* gather_instruction = builder.AddInstruction(HloInstruction::CreateGather( gather_result_shape, input, gather_indices, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1336,7 +1406,7 @@ TEST_F(HloInstructionTest, StringifyGather_1) { HloInstruction* gather_instruction = builder.AddInstruction(HloInstruction::CreateGather( gather_result_shape, input, gather_indices, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1355,6 +1425,55 @@ TEST_F(HloInstructionTest, StringifyGather_1) { "index_vector_dim=2, window_bounds={30,29,28,27,26}"); } +TEST_F(HloInstructionTest, StringifyScatter) { + Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); + Shape scatter_indices_tensor_shape = + ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); + Shape scatter_updates_shape = + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}); + + HloComputation::Builder builder("Scatter"); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); + HloInstruction* scatter_indices = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, scatter_indices_tensor_shape, "scatter_indices")); + HloInstruction* scatter_updates = + builder.AddInstruction(HloInstruction::CreateParameter( + 2, scatter_updates_shape, "scatter_updates")); + + HloComputation::Builder update_builder("Scatter.update"); + update_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p1")); + update_builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p2")); + + auto module = CreateNewModule(); + auto* update_computation = + module->AddEmbeddedComputation(update_builder.Build()); + + HloInstruction* scatter_instruction = + builder.AddInstruction(HloInstruction::CreateScatter( + input_tensor_shape, input, scatter_indices, scatter_updates, + update_computation, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2))); + module->AddEntryComputation(builder.Build()); + + EXPECT_EQ( + scatter_instruction->ToString(), + "%scatter = f32[50,49,48,47,46]{4,3,2,1,0} " + "scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " + "s64[10,9,5,7,6]{4,3,2,1,0} %scatter_indices, " + "f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %scatter_updates), " + "update_window_dims={4,5,6,7,8}, inserted_window_dims={}, " + "scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=2, " + "to_apply=%Scatter.update"); +} + TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { // Tests stringification of a simple op, fusion, while, and conditional. const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); @@ -1386,15 +1505,15 @@ TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kLoop); - EXPECT_EQ( - fusion->ToString(options), + const string expected_fusion = R"(f32[5,20]{1,0} fusion(f32[5,10]{1,0}, f32[20,10]{1,0}), kind=kLoop, calls= { tmp_0 = f32[5,10]{1,0} parameter(0) tmp_1 = f32[20,10]{1,0} parameter(1) tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"); +})"; + EXPECT_EQ(fusion->ToString(options), expected_fusion); } TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { @@ -1426,8 +1545,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { HloInstruction::CreateWhile(sout, computation, computation, x)); auto options = HloPrintOptions().Canonical(); - EXPECT_EQ(loop->ToString(options), - R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition= + const string expected_loop = + R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition= { tmp_0 = f32[5,10]{1,0} parameter(0) tmp_1 = f32[20,10]{1,0} parameter(1) @@ -1449,7 +1568,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} } -})"); +})"; + EXPECT_EQ(loop->ToString(options), expected_loop); } TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { @@ -1481,13 +1601,12 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { HloInstruction::CreateWhile(sout, computation, computation, x)); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); HloInstruction* conditional = builder.AddInstruction(HloInstruction::CreateConditional( sout, pred, x, computation, x, computation)); auto options = HloPrintOptions().Canonical(); - EXPECT_EQ( - conditional->ToString(options), + const string expected_conditional = R"(f32[5,20]{1,0} conditional(pred[], f32[5,10]{1,0}, f32[5,10]{1,0}), true_computation= { tmp_0 = f32[5,10]{1,0} parameter(0) @@ -1510,7 +1629,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} } -})"); +})"; + EXPECT_EQ(conditional->ToString(options), expected_conditional); } TEST_F(HloInstructionTest, CheckDeepClone) { diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index d326d5d009fa3a378cfe92e1b081445f09d982e1..1de5032670ff47cda5599cf736bbd3529cfcaba9 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -17,13 +17,17 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { +using ::tensorflow::str_util::CEscape; using ::tensorflow::str_util::Join; using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; @@ -177,8 +181,11 @@ std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, - int64 channel_id) - : HloInstruction(opcode, shape), channel_id_(channel_id) {} + int64 channel_id, + bool is_host_transfer) + : HloInstruction(opcode, shape), + channel_id_(channel_id), + is_host_transfer_(is_host_transfer) {} HloInstructionProto HloSendRecvInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); @@ -188,7 +195,12 @@ HloInstructionProto HloSendRecvInstruction::ToProto() const { std::vector HloSendRecvInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("channel_id=", channel_id_)}; + std::vector attrs; + attrs.push_back(StrCat("channel_id=", channel_id_)); + if (is_host_transfer()) { + attrs.push_back("is_host_transfer=true"); + } + return attrs; } bool HloSendRecvInstruction::IdenticalSlowPath( @@ -201,26 +213,32 @@ bool HloSendRecvInstruction::IdenticalSlowPath( // Send instruction produces a tuple of {aliased operand, U32 context}. HloSendInstruction::HloSendInstruction(HloInstruction* operand, - int64 channel_id) + HloInstruction* token, int64 channel_id, + bool is_host_transfer) : HloSendRecvInstruction( HloOpcode::kSend, - ShapeUtil::MakeTupleShape( - {CHECK_NOTNULL(operand)->shape(), ShapeUtil::MakeShape(U32, {})}), - channel_id) { + ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(), + ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()}), + channel_id, is_host_transfer) { AppendOperand(operand); + AppendOperand(token); } std::unique_ptr HloSendInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 1); - return MakeUnique(new_operands[0], channel_id()); + CHECK_EQ(new_operands.size(), 2); + return MakeUnique(new_operands[0], new_operands[1], + channel_id(), is_host_transfer()); } -HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand) - : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil(), - CHECK_NOTNULL(operand)->channel_id()) { +HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand, + bool is_host_transfer) + : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(), + CHECK_NOTNULL(operand)->channel_id(), + is_host_transfer) { AppendOperand(operand); } @@ -231,30 +249,39 @@ HloSendDoneInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return MakeUnique( - Cast(new_operands[0])); + Cast(new_operands[0]), is_host_transfer()); } // Recv instruction produces a tuple of {receive buffer, U32 context}. -HloRecvInstruction::HloRecvInstruction(const Shape& shape, int64 channel_id) +HloRecvInstruction::HloRecvInstruction(const Shape& shape, + HloInstruction* token, int64 channel_id, + bool is_host_transfer) : HloSendRecvInstruction( HloOpcode::kRecv, - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}), - channel_id) {} + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()}), + channel_id, is_host_transfer) { + AppendOperand(token); +} std::unique_ptr HloRecvInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 0); + CHECK_EQ(new_operands.size(), 1); return MakeUnique( - ShapeUtil::GetTupleElementShape(shape, 0), channel_id()); + ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(), + is_host_transfer()); } -HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand) +HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand, + bool is_host_transfer) : HloSendRecvInstruction( HloOpcode::kRecvDone, - ShapeUtil::GetTupleElementShape(operand->shape(), 0), - CHECK_NOTNULL(operand)->channel_id()) { + ShapeUtil::MakeTupleShape( + {ShapeUtil::GetTupleElementShape(operand->shape(), 0), + ShapeUtil::MakeTokenShape()}), + CHECK_NOTNULL(operand)->channel_id(), is_host_transfer) { AppendOperand(operand); } @@ -265,7 +292,132 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return MakeUnique( - Cast(new_operands[0])); + Cast(new_operands[0]), is_host_transfer()); +} + +HloAllReduceInstruction::HloAllReduceInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids, + tensorflow::StringPiece barrier, + const tensorflow::gtl::optional& all_reduce_id) + : HloInstruction(HloOpcode::kCrossReplicaSum, shape), + replica_group_ids_(replica_group_ids.begin(), replica_group_ids.end()), + cross_replica_sum_barrier_(barrier.begin(), barrier.end()), + all_reduce_id_(all_reduce_id) { + for (auto operand : operands) { + AppendOperand(operand); + } + AppendComputation(reduce_computation); +} + +HloInstructionProto HloAllReduceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 i : replica_group_ids_) { + proto.add_replica_group_ids(i); + } + // Proto3 is so sad. + if (all_reduce_id_) { + proto.set_all_reduce_id(*all_reduce_id_); + } + proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); + return proto; +} + +std::vector HloAllReduceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + std::vector result = { + StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}")}; + if (!cross_replica_sum_barrier().empty()) { + result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); + } + if (all_reduce_id_) { + result.push_back(StrCat("all_reduce_id=", *all_reduce_id_)); + } + return result; +} + +bool HloAllReduceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return replica_group_ids() == casted_other.replica_group_ids() && + eq_computations(to_apply(), casted_other.to_apply()) && + cross_replica_sum_barrier() == + casted_other.cross_replica_sum_barrier() && + all_reduce_id() == casted_other.all_reduce_id(); +} + +std::unique_ptr +HloAllReduceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* /*context*/) const { + return MakeUnique( + shape, new_operands, to_apply(), replica_group_ids(), + cross_replica_sum_barrier(), all_reduce_id()); +} + +HloAllToAllInstruction::HloAllToAllInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + const std::vector& replica_groups, + tensorflow::StringPiece barrier) + : HloInstruction(HloOpcode::kAllToAll, shape), + replica_groups_(replica_groups), + cross_replica_sum_barrier_(barrier.begin(), barrier.end()) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +bool HloAllToAllInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return ContainersEqual(replica_groups(), casted_other.replica_groups(), + [](const ReplicaGroup& a, const ReplicaGroup& b) { + return ContainersEqual(a.replica_ids(), + b.replica_ids()); + }) && + cross_replica_sum_barrier() == + casted_other.cross_replica_sum_barrier(); +} + +std::unique_ptr +HloAllToAllInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* /*context*/) const { + return MakeUnique( + shape, new_operands, replica_groups(), cross_replica_sum_barrier()); +} + +std::vector HloAllToAllInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector result; + std::vector replica_group_str; + for (const ReplicaGroup& group : replica_groups()) { + replica_group_str.push_back( + StrCat("{", Join(group.replica_ids(), ","), "}")); + } + result.push_back( + StrCat("replica_groups={", Join(replica_group_str, ","), "}")); + + if (!cross_replica_sum_barrier().empty()) { + result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); + } + + return result; +} + +HloInstructionProto HloAllToAllInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_replica_groups() = {replica_groups_.begin(), + replica_groups_.end()}; + proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); + return proto; } HloReverseInstruction::HloReverseInstruction( @@ -347,13 +499,14 @@ HloConcatenateInstruction::CloneWithNewOperandsImpl( } HloReduceInstruction::HloReduceInstruction( - const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + const Shape& shape, tensorflow::gtl::ArraySlice args, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation) : HloInstruction(HloOpcode::kReduce, shape), dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) { - AppendOperand(arg); - AppendOperand(init_value); + for (HloInstruction* arg : args) { + AppendOperand(arg); + } AppendComputation(reduce_computation); } @@ -386,8 +539,48 @@ std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique( - shape, new_operands[0], new_operands[1], dimensions(), to_apply()); + return MakeUnique(shape, new_operands, dimensions(), + to_apply()); +} + +HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension, + HloInstruction* keys, + HloInstruction* values) + : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) { + AppendOperand(keys); + if (values) { + AppendOperand(values); + } +} + +HloInstructionProto HloSortInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloSortInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloSortInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr HloSortInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + HloInstruction* keys = new_operands[0]; + HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr; + return MakeUnique(shape, dimensions(0), keys, values); } HloTransposeInstruction::HloTransposeInstruction( @@ -485,10 +678,8 @@ HloBroadcastInstruction::CloneWithNewOperandsImpl( HloMapInstruction::HloMapInstruction( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation, - tensorflow::gtl::ArraySlice static_operands) + HloComputation* map_computation) : HloInstruction(HloOpcode::kMap, shape) { - CHECK(static_operands.empty()) << "static_operands not yet supported"; for (auto operand : operands) { AppendOperand(operand); } @@ -609,9 +800,14 @@ HloConstantInstruction::HloConstantInstruction(std::unique_ptr literal) : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()), literal_(std::move(literal)) {} +HloConstantInstruction::HloConstantInstruction(const Shape& shape) + : HloInstruction(HloOpcode::kConstant, shape) {} + HloInstructionProto HloConstantInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); - *proto.mutable_literal() = literal_->ToProto(); + if (literal_ != nullptr) { + *proto.mutable_literal() = literal_->ToProto(); + } return proto; } @@ -657,8 +853,9 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( CanonicalNameMap* canonical_name_map) const { string operands; // For constants, show the actual value in place of an empty operand list. - if ((!ShapeUtil::IsTuple(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || - options.print_large_constants()) { + if (literal_ != nullptr && + ((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || + options.print_large_constants())) { // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. string tmp = literal().ToString(); @@ -684,7 +881,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( HloTraceInstruction::HloTraceInstruction(const string& tag, HloInstruction* operand) : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()), - literal_(Literal::CreateR1U8(tag)) { + literal_(LiteralUtil::CreateR1U8(tag)) { AppendOperand(operand); operand->set_tracing(this); } @@ -733,6 +930,19 @@ HloFusionInstruction::HloFusionInstruction( fusion_computation->SetFusionInstruction(this); } +string HloFusionInstruction::ToCategory() const { + switch (fusion_kind()) { + case FusionKind::kLoop: + return "loop fusion"; + case FusionKind::kInput: + return "input fusion"; + case FusionKind::kOutput: + return "output fusion"; + case FusionKind::kCustom: + return "custom fusion"; + } +} + HloInstructionProto HloFusionInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_fusion_kind(xla::ToString(fusion_kind())); @@ -743,10 +953,6 @@ HloInstructionProto HloFusionInstruction::ToProto() const { bool HloFusionInstruction::IsElementwiseImpl( const tensorflow::gtl::optional& operand_idx) const { - if (fusion_kind() != FusionKind::kLoop) { - return false; - } - if (!operand_idx.has_value()) { for (auto* fused : fused_instructions()) { if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) { @@ -830,10 +1036,8 @@ void HloFusionInstruction::MergeFusionInstruction( // Fuse 'unfused_instructions' into 'this'. for (auto& instruction : unfused_instructions) { FuseInstruction(instruction); - instruction->DetachFromOperands(); } CHECK_EQ(0, cloned_fusion->user_count()); - cloned_fusion->DetachFromOperands(); TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation( cloned_fusion->fused_instructions_computation())); } @@ -963,8 +1167,6 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); clone = fused_expression_root(); } else { - clone = fused_instructions_computation()->AddInstruction( - instruction_to_fuse->Clone(/*suffix=*/"")); // When add_output is false, instruction_to_fuse is necessarily an operand // of the fusion instruction. After fusion this will no longer be the // case. Remove the operand from the operand list and remove its @@ -974,6 +1176,16 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( bool in_operand_list = std::find(operands().begin(), operands().end(), instruction_to_fuse) != operands().end(); CHECK(add_output || in_operand_list); + if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { + // We assume all uses of a kTuple operation are GTE ops, not another + // fusion node. In this case, we don't need to clone + // 'instruction_to_fuse'. + CHECK(!in_operand_list); + clone = instruction_to_fuse; + } else { + clone = fused_instructions_computation()->AddInstruction( + instruction_to_fuse->Clone(/*suffix=*/"")); + } const std::vector& fused_parameters = fused_instructions_computation()->parameter_instructions(); for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { @@ -1070,9 +1282,10 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( } int64 index = tuple_elements.size(); if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { - index -= instruction_to_fuse->operand_count(); + CHECK_EQ(clone, instruction_to_fuse); + index -= clone->operand_count(); std::vector to_be_removed; - for (auto old_gte : instruction_to_fuse->users()) { + for (auto old_gte : clone->users()) { CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement); int64 old_tuple_index = old_gte->tuple_index(); HloInstruction* new_gte = @@ -1084,7 +1297,6 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( for (auto old_gte : to_be_removed) { TF_CHECK_OK(parent()->RemoveInstruction(old_gte)); } - TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone)); } else { HloInstruction* new_gte = parent()->AddInstruction(HloInstruction::CreateGetTupleElement( @@ -1093,7 +1305,9 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( } } - VLOG(2) << "New clone:\n" << clone->ToString(); + if (clone != instruction_to_fuse) { + VLOG(2) << "New clone:\n" << clone->ToString(); + } return clone; } @@ -1129,6 +1343,26 @@ std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( new_fused_computation); } +Status HloFusionInstruction::DeduplicateFusionOperands() { + tensorflow::gtl::FlatMap operand_indices; + std::vector operands_to_remove; + for (int i = 0; i < operand_count(); ++i) { + auto emplace_result = operand_indices.emplace(operand(i), i); + if (!emplace_result.second) { + TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith( + fused_parameter(emplace_result.first->second))); + operands_to_remove.push_back(i); + } + } + if (operands_to_remove.empty()) { + return Status::OK(); + } + TF_RETURN_IF_ERROR( + fused_instructions_computation()->RemoveUnusedParameters()); + RemoveOperandsAtAscendingIndices(operands_to_remove); + return Status::OK(); +} + HloRngInstruction::HloRngInstruction( const Shape& shape, RandomDistribution distribution, tensorflow::gtl::ArraySlice parameters) @@ -1284,4 +1518,650 @@ HloReducePrecisionInstruction::CloneWithNewOperandsImpl( shape, new_operands[0], exponent_bits(), mantissa_bits()); } +HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape, + HloInstruction* token_operand, + const string& config) + : HloInstruction(HloOpcode::kInfeed, + ShapeUtil::MakeTupleShape( + {infeed_shape, ShapeUtil::MakeTokenShape()})), + infeed_config_(config) { + AppendOperand(token_operand); +} + +HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape, + const string& config) + : HloInstruction(HloOpcode::kInfeed, + ShapeUtil::MakeTupleShape( + {infeed_shape, ShapeUtil::MakeTokenShape()})), + infeed_config_(config) {} + +HloInstructionProto HloInfeedInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_infeed_config(infeed_config_); + return proto; +} + +std::vector HloInfeedInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + if (infeed_config_.empty()) { + return {}; + } + return {StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")}; +} + +bool HloInfeedInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + // Not yet supported. + return false; +} + +std::unique_ptr HloInfeedInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + if (new_operands.empty()) { + return MakeUnique(infeed_shape(), infeed_config()); + } else { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(infeed_shape(), new_operands[0], + infeed_config()); + } +} + +HloOutfeedInstruction::HloOutfeedInstruction( + const Shape& outfeed_shape, HloInstruction* operand, + HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) + : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), + outfeed_shape_(outfeed_shape), + outfeed_config_(outfeed_config.begin(), outfeed_config.end()) { + CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) + << "Outfeed shape " << outfeed_shape + << " must be compatible with operand shape " << operand->shape(); + AppendOperand(operand); + AppendOperand(token_operand); +} + +HloOutfeedInstruction::HloOutfeedInstruction( + const Shape& outfeed_shape, HloInstruction* operand, + tensorflow::StringPiece outfeed_config) + : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), + outfeed_shape_(outfeed_shape), + outfeed_config_(outfeed_config.begin(), outfeed_config.end()) { + CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) + << "Outfeed shape " << outfeed_shape + << " must be compatible with operand shape " << operand->shape(); + AppendOperand(operand); +} + +HloInstructionProto HloOutfeedInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_outfeed_config(outfeed_config()); + *proto.mutable_outfeed_shape() = outfeed_shape(); + return proto; +} + +std::vector HloOutfeedInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + if (outfeed_config_.empty()) { + return {}; + } + return {StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")}; +} + +bool HloOutfeedInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + // Not yet supported. + return false; +} + +std::unique_ptr HloOutfeedInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + if (new_operands.size() == 1) { + return MakeUnique(outfeed_shape(), new_operands[0], + outfeed_config()); + } else { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique(outfeed_shape(), new_operands[0], + new_operands[1], outfeed_config()); + } +} + +HloConvolutionInstruction::HloConvolutionInstruction( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers) + : HloInstruction(HloOpcode::kConvolution, shape), + window_(window), + convolution_dimension_numbers_(dimension_numbers) { + if (window_util::HasBaseDilation(window)) { + SetAndSanitizeName(StrCat(name(), "-base-dilated")); + } + if (window_util::HasWindowDilation(window)) { + SetAndSanitizeName(StrCat(name(), "-window-dilated")); + } + AppendOperand(lhs); + AppendOperand(rhs); +} + +string HloConvolutionInstruction::ToCategory() const { + string category = "convolution"; + if (window_util::HasBaseDilation(window())) { + category += " base-dilated"; + } + if (window_util::HasWindowDilation(window())) { + category += " window-dilated"; + } + return category; +} + +HloInstructionProto HloConvolutionInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_window() = window_; + *proto.mutable_convolution_dimension_numbers() = + convolution_dimension_numbers_; + return proto; +} + +std::vector HloConvolutionInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra; + if (window_.dimensions_size() != 0) { + extra.push_back(StrCat("window={", window_util::ToString(window()), "}")); + } + extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString( + convolution_dimension_numbers_))); + return extra; +} + +bool HloConvolutionInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return protobuf_util::ProtobufEquals(window(), casted_other.window()) && + protobuf_util::ProtobufEquals( + convolution_dimension_numbers(), + casted_other.convolution_dimension_numbers()); +} + +std::unique_ptr +HloConvolutionInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique(shape, new_operands[0], + new_operands[1], window(), + convolution_dimension_numbers_); +} + +HloReduceWindowInstruction::HloReduceWindowInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* init_value, + const Window& window, HloComputation* reduce_computation) + : HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) { + AppendOperand(operand); + AppendOperand(init_value); + AppendComputation(reduce_computation); +} + +HloInstructionProto HloReduceWindowInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_window() = window_; + return proto; +} + +std::vector HloReduceWindowInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra; + if (window_.dimensions_size() != 0) { + extra.push_back(StrCat("window={", window_util::ToString(window()), "}")); + } + return extra; +} + +bool HloReduceWindowInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return eq_computations(to_apply(), casted_other.to_apply()) && + protobuf_util::ProtobufEquals(window(), casted_other.window()); +} + +std::unique_ptr +HloReduceWindowInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique( + shape, new_operands[0], new_operands[1], window(), to_apply()); +} + +HloSelectAndScatterInstruction::HloSelectAndScatterInstruction( + const Shape& shape, HloInstruction* operand, HloComputation* select, + const Window& window, HloInstruction* source, HloInstruction* init_value, + HloComputation* scatter) + : HloInstruction(HloOpcode::kSelectAndScatter, shape), window_(window) { + AppendOperand(operand); + AppendOperand(source); + AppendOperand(init_value); + // Select comes before scatter in the vector. + AppendComputation(select); + AppendComputation(scatter); +} + +HloInstructionProto HloSelectAndScatterInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_window() = window_; + return proto; +} + +std::vector HloSelectAndScatterInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra; + if (window_.dimensions_size() != 0) { + extra.push_back(StrCat("window={", window_util::ToString(window()), "}")); + } + return extra; +} + +bool HloSelectAndScatterInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return eq_computations(select(), casted_other.select()) && + eq_computations(scatter(), casted_other.scatter()) && + protobuf_util::ProtobufEquals(window(), casted_other.window()); +} + +std::unique_ptr +HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 3); + return MakeUnique( + shape, new_operands[0], select(), window(), new_operands[1], + new_operands[2], scatter()); +} + +HloCustomCallInstruction::HloCustomCallInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece custom_call_target) + : HloInstruction(HloOpcode::kCustomCall, shape), + custom_call_target_(custom_call_target.begin(), + custom_call_target.end()) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +HloInstructionProto HloCustomCallInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + if (window_ != nullptr) { + *proto.mutable_window() = *window_; + } + if (convolution_dimension_numbers_ != nullptr) { + *proto.mutable_convolution_dimension_numbers() = + *convolution_dimension_numbers_; + } + proto.set_custom_call_target(custom_call_target_); + return proto; +} + +std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra; + if (window_ != nullptr && window_->dimensions_size() != 0) { + extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); + } + if (convolution_dimension_numbers_ != nullptr) { + extra.push_back(StrCat( + "dim_labels=", + ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_))); + } + // By contract, we print the custom call target even if + // options.print_subcomputation_mode() == kOff, because the call target is not + // an HloComputation. + extra.push_back( + StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); + return extra; +} + +bool HloCustomCallInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + if ((window_ == nullptr) != (casted_other.window_ == nullptr) || + (window_ != nullptr && + !protobuf_util::ProtobufEquals(*window_, *casted_other.window_))) { + return false; + } + if ((convolution_dimension_numbers_ == nullptr) != + (casted_other.convolution_dimension_numbers_ == nullptr) || + (convolution_dimension_numbers_ != nullptr && + !protobuf_util::ProtobufEquals( + convolution_dimension_numbers(), + casted_other.convolution_dimension_numbers()))) { + return false; + } + return custom_call_target_ == casted_other.custom_call_target_; +} + +std::unique_ptr +HloCustomCallInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + auto cloned = MakeUnique(shape, new_operands, + custom_call_target()); + if (window_ != nullptr) { + cloned->set_window(*window_); + } + if (convolution_dimension_numbers_ != nullptr) { + cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_); + } + return std::move(cloned); +} + +HloHostComputeInstruction::HloHostComputeInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) + : HloInstruction(HloOpcode::kHostCompute, shape), + channel_name_(channel_name.begin(), channel_name.end()), + cost_estimate_ns_(cost_estimate_ns) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +HloInstructionProto HloHostComputeInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_channel_name(channel_name_); + proto.set_cost_estimate_ns(cost_estimate_ns_); + return proto; +} + +bool HloHostComputeInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + // Not yet supported. + return false; +} + +std::unique_ptr +HloHostComputeInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique( + shape, new_operands, channel_name_, cost_estimate_ns_); +} + +HloPadInstruction::HloPadInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* padding_value, + const PaddingConfig& padding_config) + : HloInstruction(HloOpcode::kPad, shape), padding_config_(padding_config) { + AppendOperand(operand); + AppendOperand(padding_value); +} + +HloInstructionProto HloPadInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_padding_config() = padding_config_; + return proto; +} + +std::vector HloPadInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("padding=", xla::PaddingConfigToString(padding_config_))}; +} + +bool HloPadInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return protobuf_util::ProtobufEquals(padding_config(), + casted_other.padding_config()); +} + +std::unique_ptr HloPadInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique(shape, new_operands[0], new_operands[1], + padding_config_); +} + +HloDynamicSliceInstruction::HloDynamicSliceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, + tensorflow::gtl::ArraySlice slice_sizes) + : HloInstruction(HloOpcode::kDynamicSlice, shape), + dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) { + AppendOperand(operand); + AppendOperand(start_indices); +} + +HloInstructionProto HloDynamicSliceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 slice_size : dynamic_slice_sizes_) { + proto.add_dynamic_slice_sizes(slice_size); + } + return proto; +} + +std::vector HloDynamicSliceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return { + StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")}; +} + +bool HloDynamicSliceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return true; +} + +std::unique_ptr +HloDynamicSliceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique( + shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); +} + +HloGatherInstruction::HloGatherInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice window_bounds) + : HloInstruction(HloOpcode::kGather, shape) { + AppendOperand(operand); + AppendOperand(gather_indices); + gather_dimension_numbers_ = + MakeUnique(gather_dim_numbers); + c_copy(window_bounds, std::back_inserter(gather_window_bounds_)); +} + +string HloGatherInstruction::GatherDimensionNumbersToString() const { + CHECK(gather_dimension_numbers_ != nullptr); + string output_window_dims = + StrCat("output_window_dims={", + Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); + string elided_window_dims = + StrCat("elided_window_dims={", + Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); + string gather_dims_to_operand_dims = StrCat( + "gather_dims_to_operand_dims={", + Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); + string index_vector_dim = StrCat( + "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); + + return Join>( + {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, + index_vector_dim}, + ", "); +} + +/* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers( + tensorflow::gtl::ArraySlice output_window_dims, + tensorflow::gtl::ArraySlice elided_window_dims, + tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, + int64 index_vector_dim) { + GatherDimensionNumbers gather_dim_numbers; + for (int64 output_window_dim : output_window_dims) { + gather_dim_numbers.add_output_window_dims(output_window_dim); + } + for (int64 elided_window_dim : elided_window_dims) { + gather_dim_numbers.add_elided_window_dims(elided_window_dim); + } + for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { + gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); + } + + gather_dim_numbers.set_index_vector_dim(index_vector_dim); + return gather_dim_numbers; +} + +HloInstructionProto HloGatherInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers(); + for (int64 bound : gather_window_bounds()) { + proto.add_gather_window_bounds(bound); + } + return proto; +} + +std::vector HloGatherInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {GatherDimensionNumbersToString(), + StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")}; +} + +bool HloGatherInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return protobuf_util::ProtobufEquals( + gather_dimension_numbers(), + casted_other.gather_dimension_numbers()) && + gather_window_bounds() == casted_other.gather_window_bounds(); +} + +std::unique_ptr HloGatherInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique( + shape, new_operands[0], new_operands[1], gather_dimension_numbers(), + gather_window_bounds()); +} + +HloScatterInstruction::HloScatterInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers) + : HloInstruction(HloOpcode::kScatter, shape) { + AppendOperand(operand); + AppendOperand(scatter_indices); + AppendOperand(updates); + AppendComputation(update_computation); + scatter_dimension_numbers_ = + MakeUnique(scatter_dim_numbers); +} + +string HloScatterInstruction::ScatterDimensionNumbersToString() const { + string update_window_dims = + StrCat("update_window_dims={", + Join(scatter_dimension_numbers().update_window_dims(), ","), "}"); + string inserted_window_dims = StrCat( + "inserted_window_dims={", + Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}"); + string scatter_dims_to_operand_dims = StrCat( + "scatter_dims_to_operand_dims={", + Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","), + "}"); + string index_vector_dim = StrCat( + "index_vector_dim=", scatter_dimension_numbers().index_vector_dim()); + + return Join>( + {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, + index_vector_dim}, + ", "); +} + +/* static */ ScatterDimensionNumbers +HloScatterInstruction::MakeScatterDimNumbers( + tensorflow::gtl::ArraySlice update_window_dims, + tensorflow::gtl::ArraySlice inserted_window_dims, + tensorflow::gtl::ArraySlice scatter_dims_to_operand_dims, + int64 index_vector_dim) { + ScatterDimensionNumbers scatter_dim_numbers; + for (int64 update_window_dim : update_window_dims) { + scatter_dim_numbers.add_update_window_dims(update_window_dim); + } + for (int64 inserted_window_dim : inserted_window_dims) { + scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim); + } + for (int64 scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) { + scatter_dim_numbers.add_scatter_dims_to_operand_dims( + scatter_dim_to_operand_dim); + } + scatter_dim_numbers.set_index_vector_dim(index_vector_dim); + return scatter_dim_numbers; +} + +HloInstructionProto HloScatterInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers(); + return proto; +} + +std::vector HloScatterInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {ScatterDimensionNumbersToString()}; +} + +bool HloScatterInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return protobuf_util::ProtobufEquals( + scatter_dimension_numbers(), + casted_other.scatter_dimension_numbers()) && + eq_computations(to_apply(), casted_other.to_apply()); +} + +std::unique_ptr HloScatterInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 3); + return MakeUnique( + shape, new_operands[0], new_operands[1], new_operands[2], to_apply(), + scatter_dimension_numbers()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 6749d875559008b3a2bd479ff075c83d85d87509..9586ad667345111d05015e035c93fe6578e3b665 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -141,12 +141,15 @@ class HloSendRecvInstruction : public HloInstruction { // channel. int64 channel_id() const { return channel_id_; } + // Returns whether this send/recv instruction sends data to/from the host. + bool is_host_transfer() const { return is_host_transfer_; } + // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; protected: explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, - int64 channel_id); + int64 channel_id, bool is_host_transfer); private: std::vector ExtraAttributesToStringImpl( @@ -157,11 +160,15 @@ class HloSendRecvInstruction : public HloInstruction { eq_computations) const override; // Represents a unique identifier for each Send/Recv instruction pair. int64 channel_id_; + + // Whether this send/recv instruction sends data to/from the host. + bool is_host_transfer_; }; class HloSendInstruction : public HloSendRecvInstruction { public: - explicit HloSendInstruction(HloInstruction* operand, int64 channel_id); + explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token, + int64 channel_id, bool is_host_transfer); private: // Implementation for non-common logic of CloneWithNewOperands. @@ -173,7 +180,8 @@ class HloSendInstruction : public HloSendRecvInstruction { class HloSendDoneInstruction : public HloSendRecvInstruction { public: - explicit HloSendDoneInstruction(HloSendInstruction* operand); + explicit HloSendDoneInstruction(HloSendInstruction* operand, + bool is_host_transfer); private: // Implementation for non-common logic of CloneWithNewOperands. @@ -185,7 +193,8 @@ class HloSendDoneInstruction : public HloSendRecvInstruction { class HloRecvInstruction : public HloSendRecvInstruction { public: - explicit HloRecvInstruction(const Shape& shape, int64 channel_id); + explicit HloRecvInstruction(const Shape& shape, HloInstruction* token, + int64 channel_id, bool is_host_transfer); private: // Implementation for non-common logic of CloneWithNewOperands. @@ -197,14 +206,112 @@ class HloRecvInstruction : public HloSendRecvInstruction { class HloRecvDoneInstruction : public HloSendRecvInstruction { public: - explicit HloRecvDoneInstruction(HloRecvInstruction* operand); + explicit HloRecvDoneInstruction(HloRecvInstruction* operand, + bool is_host_transfer); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloAllReduceInstruction : public HloInstruction { + public: + explicit HloAllReduceInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids, + tensorflow::StringPiece barrier, + const tensorflow::gtl::optional& all_reduce_id); + + // Returns the group ids of each replica for CrossReplicaSum op. + const std::vector& replica_group_ids() const { + return replica_group_ids_; + } + + // Returns the barrier config used for the CrossReplicaSum implementation of + // each backend. + string cross_replica_sum_barrier() const { + return cross_replica_sum_barrier_; + } + void set_cross_replica_sum_barrier(string barrier) { + cross_replica_sum_barrier_ = barrier; + } + + tensorflow::gtl::optional all_reduce_id() const { + return all_reduce_id_; + } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The group id of each replica for CrossReplicaSum. + std::vector replica_group_ids_; + + // The string representation of the barrier config used for CrossReplicaSum. + string cross_replica_sum_barrier_; + + // For Allreduce nodes from different modules, if they have the same + // all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be + // applied cross modules. + tensorflow::gtl::optional all_reduce_id_; +}; + +class HloAllToAllInstruction : public HloInstruction { + public: + explicit HloAllToAllInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operand, + const std::vector& replica_groups, + tensorflow::StringPiece barrier); + + const std::vector& replica_groups() const { + return replica_groups_; + } + + // TODO(b/110096724): rename this. + void set_cross_replica_sum_barrier(string barrier) { + cross_replica_sum_barrier_ = barrier; + } + string cross_replica_sum_barrier() const { + return cross_replica_sum_barrier_; + } + + HloInstructionProto ToProto() const override; private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const override; + + std::vector replica_groups_; + + // The string representation of the barrier config. + string cross_replica_sum_barrier_; }; class HloReverseInstruction : public HloInstruction { @@ -265,7 +372,7 @@ class HloConcatenateInstruction : public HloInstruction { class HloReduceInstruction : public HloInstruction { public: explicit HloReduceInstruction( - const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + const Shape& shape, tensorflow::gtl::ArraySlice args, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation); // Returns the dimension sizes or numbers associated with this instruction. @@ -274,6 +381,47 @@ class HloReduceInstruction : public HloInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns the input tensors to be reduced. + tensorflow::gtl::ArraySlice inputs() const { + return tensorflow::gtl::ArraySlice(operands(), 0, + operand_count() / 2); + } + + // Returns the init values of the reduction. + tensorflow::gtl::ArraySlice init_values() const { + return tensorflow::gtl::ArraySlice( + operands(), operand_count() / 2, operand_count()); + } + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloSortInstruction : public HloInstruction { + public: + explicit HloSortInstruction(const Shape& shape, int64 dimension, + HloInstruction* keys, + HloInstruction* values = nullptr); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns the sort dimension for this instruction + int64 sort_dimension() { return dimensions(0); } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -350,8 +498,7 @@ class HloMapInstruction : public HloInstruction { public: explicit HloMapInstruction( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation, - tensorflow::gtl::ArraySlice static_operands = {}); + HloComputation* map_computation); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -436,8 +583,12 @@ class HloSliceInstruction : public HloInstruction { class HloConstantInstruction : public HloInstruction { public: explicit HloConstantInstruction(std::unique_ptr literal); + // Used when the literal is too large and dropped. + explicit HloConstantInstruction(const Shape& shape); // Returns the literal associated with this instruction. const Literal& literal() const { return *literal_; } + // Returns whether there is literal associated with this instruction. + bool HasLiteral() const { return literal_ != nullptr; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -498,6 +649,7 @@ class HloFusionInstruction : public HloInstruction { tensorflow::gtl::ArraySlice operands, HloComputation* fusion_computation); + string ToCategory() const override; // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -576,6 +728,9 @@ class HloFusionInstruction : public HloInstruction { void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; } + // If multiple operands are the same instruction, keeps only one of them. + Status DeduplicateFusionOperands(); + private: // Fuses the given instruction into this fusion instruction. When add_output // is false (which is the default), instruction_to_fuse is cloned and the @@ -722,6 +877,421 @@ class HloReducePrecisionInstruction : public HloInstruction { int32 exponent_bits_ = 0; int32 mantissa_bits_ = 0; }; + +class HloInfeedInstruction : public HloInstruction { + public: + explicit HloInfeedInstruction(const Shape& infeed_shape, + HloInstruction* token_operand, + const string& config); + // TODO(b/80000000): Remove this constructor when all uses of infeed are + // converted to take tokens. + explicit HloInfeedInstruction(const Shape& infeed_shape, + const string& config); + // Returns the infeed configuration string. The infeed configuration includes + // any metadata needed for the backend compiler (e.g., infeed buffer address) + // and is target-dependent. + string infeed_config() const { return infeed_config_; } + void set_infeed_config(const string& config) { infeed_config_ = config; } + // Returns the shape of the data received by the infeed. This is not the same + // as the shape of the infeed instruction which produces a tuple containing + // the infeed data shape and a TOKEN. + const Shape& infeed_shape() const { + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape())); + return ShapeUtil::GetSubshape(shape(), {0}); + } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The string representation of the infeed configuration. + string infeed_config_; +}; + +class HloOutfeedInstruction : public HloInstruction { + public: + explicit HloOutfeedInstruction(const Shape& outfeed_shape, + HloInstruction* operand, + HloInstruction* token_operand, + tensorflow::StringPiece outfeed_config); + // TODO(b/80000000): Remove this constructor when all uses of outfeed are + // converted to take tokens. + explicit HloOutfeedInstruction(const Shape& outfeed_shape, + HloInstruction* operand, + tensorflow::StringPiece outfeed_config); + + // Returns the shape for the Outfeed instruction. + const Shape& outfeed_shape() const { + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_)); + return outfeed_shape_; + } + // Returns the config for the Outfeed instruction. + const string& outfeed_config() const { return outfeed_config_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Shape of outfeed request. + Shape outfeed_shape_; + // Outfeed configuration information, only present for kOutfeed. + string outfeed_config_; +}; + +class HloConvolutionInstruction : public HloInstruction { + public: + explicit HloConvolutionInstruction( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers); + const Window& window() const override { return window_; } + void set_window(const Window& window) override { window_ = window; } + const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { + return convolution_dimension_numbers_; + } + void set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums) { + convolution_dimension_numbers_ = dnums; + } + string ToCategory() const override; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + Window window_; + // Describes the dimension numbers used for a convolution. + ConvolutionDimensionNumbers convolution_dimension_numbers_; +}; + +class HloReduceWindowInstruction : public HloInstruction { + public: + explicit HloReduceWindowInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* init_value, + const Window& window, + HloComputation* reduce_computation); + const Window& window() const override { return window_; } + void set_window(const Window& window) override { window_ = window; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + Window window_; +}; + +class HloSelectAndScatterInstruction : public HloInstruction { + public: + explicit HloSelectAndScatterInstruction( + const Shape& shape, HloInstruction* operand, HloComputation* select, + const Window& window, HloInstruction* source, HloInstruction* init_value, + HloComputation* scatter); + const Window& window() const override { return window_; } + void set_window(const Window& window) override { window_ = window; } + // Gets/sets the select or scatter HloComputation for SelectAndScatter. The + // setters should only be called by HloModule or HloComputation methods. + HloComputation* select() const { + return called_computations()[kSelectComputationIndex]; + } + + HloComputation* scatter() const { + return called_computations()[kScatterComputationIndex]; + } + + void set_select(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); + set_called_computation(kSelectComputationIndex, computation); + } + + void set_scatter(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); + set_called_computation(kScatterComputationIndex, computation); + } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + Window window_; +}; + +class HloCustomCallInstruction : public HloInstruction { + public: + explicit HloCustomCallInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece custom_call_target); + const Window& window() const override { + CHECK(window_ != nullptr); + return *window_; + } + + void set_window(const Window& window) override { + window_ = MakeUnique(window); + } + + const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { + CHECK(convolution_dimension_numbers_ != nullptr); + return *convolution_dimension_numbers_; + } + + void set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums) { + convolution_dimension_numbers_ = + MakeUnique(dnums); + } + const string& custom_call_target() const { return custom_call_target_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // Name of a global symbol to call, only present for kCustomCall. + string custom_call_target_; + // Describes the window in a windowed operation such as convolution. + std::unique_ptr window_; + // Describes the dimension numbers used for a convolution. + std::unique_ptr convolution_dimension_numbers_; +}; + +class HloHostComputeInstruction : public HloInstruction { + public: + explicit HloHostComputeInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece channel_name, const int64 cost_estimate_ns); + // Returns the channel name associated with the instruction. The name is + // used to identify host Send/Recv operations. + const string& channel_name() const { return channel_name_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // Name to use for host send/recv channels. + string channel_name_; + // Estimate of the duration of a host computation in nanoseconds. + int64 cost_estimate_ns_ = 0; +}; + +class HloPadInstruction : public HloInstruction { + public: + explicit HloPadInstruction(const Shape& shape, HloInstruction* operand, + HloInstruction* padding_value, + const PaddingConfig& padding_config); + // Returns the padding configuration for a pad node. + const PaddingConfig& padding_config() const { return padding_config_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The padding configuration that describes the edge padding and interior + // padding of this pad instruction. + PaddingConfig padding_config_; +}; + +class HloDynamicSliceInstruction : public HloInstruction { + public: + explicit HloDynamicSliceInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + // Old methods kept for smooth subclassing transition END. + // Returns the size of the slice in the given dimension for a dynamic + // slice node. + int64 slice_sizes(int64 dimension) const { + return dynamic_slice_sizes_[dimension]; + } + const std::vector& dynamic_slice_sizes() const { + return dynamic_slice_sizes_; + } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Describes the [start, start + size) range size for a dynamic slice + // ('start' is specified dynamically in the second operand of the operation). + std::vector dynamic_slice_sizes_; +}; + +class HloGatherInstruction : public HloInstruction { + public: + explicit HloGatherInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* gather_indices, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice window_bounds); + const GatherDimensionNumbers& gather_dimension_numbers() const { + CHECK(gather_dimension_numbers_ != nullptr); + return *gather_dimension_numbers_; + } + tensorflow::gtl::ArraySlice gather_window_bounds() const { + return gather_window_bounds_; + } + // Returns the dump string of the gather dimension numbers. + string GatherDimensionNumbersToString() const; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Creates an instance of GatherDimensionNumbers. + static GatherDimensionNumbers MakeGatherDimNumbers( + tensorflow::gtl::ArraySlice output_window_dims, + tensorflow::gtl::ArraySlice elided_window_dims, + tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, + int64 index_vector_dim); + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::unique_ptr gather_dimension_numbers_; + std::vector gather_window_bounds_; +}; + +class HloScatterInstruction : public HloInstruction { + public: + explicit HloScatterInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers); + const ScatterDimensionNumbers& scatter_dimension_numbers() const { + CHECK(scatter_dimension_numbers_ != nullptr); + return *scatter_dimension_numbers_; + } + // Returns the dump string of the scatter dimension numbers. + string ScatterDimensionNumbersToString() const; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Creates an instance of ScatterDimensionNumbers. + static ScatterDimensionNumbers MakeScatterDimNumbers( + tensorflow::gtl::ArraySlice update_window_dims, + tensorflow::gtl::ArraySlice inserted_window_dims, + tensorflow::gtl::ArraySlice scatter_dims_to_operand_dims, + int64 index_vector_dim); + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::unique_ptr scatter_dimension_numbers_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index f0d9fdbc8f86da0bb9d7f9235239df677c9506bc..8e0d38b6a63917582b8bfa10f205e1ed511efef3 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -143,8 +143,47 @@ TokKind HloLexer::LexToken() { return TokKind::kLparen; case ')': return TokKind::kRparen; - case '/': - return LexComment(); + case '/': { + if (PeekCurrentChar() == '*') { + // This is the start of a /*...*/ delimited comment. Save the current + // location in case the comment is unterminated so the error message + // will point to the beginning of the comment. + const char* comment_start = current_ptr_; + current_ptr_++; + // Advance until '*/' is found. + while (true) { + int current = GetNextChar(); + if (current == '*' && PeekCurrentChar() == '/') { + // End of comment. + current_ptr_++; + break; + } + if (current == kEOF) { + // Unterminated comment. + current_ptr_ = comment_start; + return TokKind::kError; + } + } + // Return no token for the comment. Keep lexing. + continue; + } else if (PeekCurrentChar() == '/') { + // This is the start of a '//' delimited comment. Throw away + // everything until end of line or file. The end-of-line character(s) + // are left unlexed in the buffer which is harmless because these are + // skipped later by the lexer. This approach enables support for + // different end-of-line encodings. + while (true) { + int current = PeekCurrentChar(); + if (current == kEOF || current == '\n' || current == '\r') { + break; + } + current_ptr_++; + } + continue; + } + // A lone '/' is an error. + return TokKind::kError; + } case '"': return LexString(); } @@ -299,9 +338,12 @@ TokKind HloLexer::LexNumberOrPattern() { static LazyRE2 int_pattern = {R"([-]?\d+)"}; if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); - tensorflow::strings::safe_strto64( - StringPieceFromPointers(token_start_, current_ptr_), &int64_val_); - return TokKind::kInt; + auto slice = StringPieceFromPointers(token_start_, current_ptr_); + if (tensorflow::strings::safe_strto64(slice, &int64_val_)) { + return TokKind::kInt; + } + LOG(ERROR) << "Failed to parse int literal: " << slice; + return TokKind::kError; } static LazyRE2 neg_inf = {"-inf"}; @@ -354,16 +396,6 @@ tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const { return StringPieceFromPointers(start, end); } -TokKind HloLexer::LexComment() { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); - static LazyRE2 comment_pattern = {R"(\/\*.*?\*\/)"}; - if (RE2::Consume(&consumable, *comment_pattern)) { - current_ptr_ = consumable.begin(); - return TokKind::kComment; - } - return TokKind::kError; -} - // Lexes quoted string with escaping characters. If matched, the quoted string // will be unescaped and stored to str_val_. TokKind HloLexer::LexString() { @@ -409,8 +441,6 @@ string TokKindToString(TokKind kind) { return "kRparen"; case TokKind::kArrow: return "kArrow"; - case TokKind::kComment: - return "kComment"; case TokKind::kw_HloModule: return "kw_HloModule"; case TokKind::kw_ENTRY: diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index ceb674f25e94ac3ac2e6a4a0687a93ffdcd065e0..003ac34ace5713446afa74eb3af96ae33087223e 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -105,7 +105,6 @@ class HloLexer { TokKind LexShape(); TokKind LexConstant(); TokKind LexNumberOrPattern(); - TokKind LexComment(); TokKind LexString(); const tensorflow::StringPiece buf_; diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index 0275294a1a86cef13e5b267ad578f30cc18858dc..01b625c29ca2823b2a2490b30a9d4d5128b4c22e 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index c570b420c21fed4d7828feb24ee5c7859db94a79..c577b4359aae6c66f29860a0e56c3487b07afc02 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -187,6 +187,7 @@ HLO_MATCHER(Exp); HLO_MATCHER(Floor); HLO_MATCHER(Fusion); HLO_MATCHER(Ge); +HLO_MATCHER(AfterAll); HLO_MATCHER(Gt); HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); @@ -195,6 +196,7 @@ HLO_MATCHER(Log); HLO_MATCHER(And); HLO_MATCHER(Not); HLO_MATCHER(Or); +HLO_MATCHER(Xor); HLO_MATCHER(Lt); HLO_MATCHER(Map); HLO_MATCHER(Maximum); @@ -229,6 +231,7 @@ HLO_MATCHER(Tanh); HLO_MATCHER(Trace); HLO_MATCHER(Transpose); HLO_MATCHER(Tuple); +HLO_MATCHER(TupleSelect); HLO_MATCHER(While); // The special cases below let you check additional information about the diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 9a3010cf1ff75e840130d8442bbe26d6041cef25..7961aece541faeb66875885b380158756c503250 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -75,8 +76,10 @@ TEST(HloMatchersTest, Test) { } TEST(HloMatchersTest, CustomCallMatcher) { - auto c1 = HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3})); - auto c2 = HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3})); + auto c1 = + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3})); + auto c2 = + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3})); auto call = HloInstruction::CreateCustomCall( ShapeUtil::MakeShape(F32, {1}), {c1.get(), c2.get()}, "foo_target"); @@ -154,9 +157,8 @@ TEST(HloMatchersTest, ShardingMatcher) { Array assignment({2}); assignment.SetValues({0, 1}); auto sharding = HloSharding::Tuple( - tuple_shape, - {HloSharding::Tile(ShapeUtil::MakeShape(F32, {5}), assignment), - HloSharding::AssignDevice(1), HloSharding::Replicate()}); + tuple_shape, {HloSharding::Tile(assignment), HloSharding::AssignDevice(1), + HloSharding::Replicate()}); p2->set_sharding(sharding); EXPECT_THAT(p0.get(), op::NoSharding()); @@ -169,8 +171,7 @@ TEST(HloMatchersTest, ShardingMatcher) { EXPECT_THAT( p2.get(), - op::Sharding( - "{{f32[5] devices=[2]0,1}, {maximal device=1}, {replicated}}")); + op::Sharding("{{devices=[2]0,1}, {maximal device=1}, {replicated}}")); EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))), "%param.0 = f32[5]{0} parameter(0) has no sharding (expected: " diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 9c59374b4a9d7e3dbfb99d8a6b30d4230e553658..55ff073d3faf34aa0f1b8f0886946837e7a49bcc 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -58,7 +58,7 @@ HloComputation* HloModule::AddComputationInternal( // If the module configuration has no entry layout computation set, create a // default one based on the program shape. - if (!config_.has_host_entry_computation_layout()) { + if (!config_.has_entry_computation_layout()) { config_.SetDefaultComputationLayout( entry_computation_->ComputeProgramShape()); } @@ -231,14 +231,11 @@ StatusOr> HloModule::CreateFromProto( TF_RET_CHECK(proto.has_program_shape()) << "No program shape found in the proto"; const auto& expected_program_shape = proto.program_shape(); - TF_RET_CHECK( - expected_program_shape.parameters_size() == - module_config.device_entry_computation_layout().parameter_count()); + TF_RET_CHECK(expected_program_shape.parameters_size() == + module_config.entry_computation_layout().parameter_count()); for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { const Shape& parameter_shape = - module_config.device_entry_computation_layout() - .parameter_layout(i) - .shape(); + module_config.entry_computation_layout().parameter_layout(i).shape(); TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i), parameter_shape)) << "HloModuleConfig has different shape for parameter " << i @@ -248,7 +245,7 @@ StatusOr> HloModule::CreateFromProto( << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape); } const Shape& result_shape = - module_config.device_entry_computation_layout().result_layout().shape(); + module_config.entry_computation_layout().result_layout().shape(); TF_RET_CHECK( ShapeUtil::Compatible(expected_program_shape.result(), result_shape)) << "HloModuleConfig has different result shape than the HLO module. " @@ -327,7 +324,7 @@ StatusOr HloModule::CreateModuleConfigFromProto( // The module config is constructed with default layouts regardless of what is // passed in via the ProgramShape. Set the layouts to the appropriate values. ComputationLayout* entry_layout = - module_config.mutable_host_entry_computation_layout(); + module_config.mutable_entry_computation_layout(); for (int64 i = 0; i < entry_layout->parameter_count(); ++i) { TF_RETURN_IF_ERROR( entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( @@ -335,9 +332,6 @@ StatusOr HloModule::CreateModuleConfigFromProto( } TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape( program_shape.result())); - *module_config.mutable_device_entry_computation_layout() = - module_config.host_entry_computation_layout(); - return module_config; } @@ -451,7 +445,7 @@ int64 HloModule::instruction_count() const { return n; } -std::list HloModule::MakeComputationPostOrder() const { +std::vector HloModule::MakeComputationPostOrder() const { // First determine all root computations by building a set of nonroot // computations (computations which are called by an instruction in the // module). @@ -469,7 +463,7 @@ std::list HloModule::MakeComputationPostOrder() const { // order. This prevents duplication as an embedded computation may be called // from two different root computations. std::set added_computations; - std::list post_order; + std::vector post_order; for (auto& computation : computations_) { if (nonroot_computations.count(computation.get()) == 0) { for (HloComputation* embedded_computation : @@ -543,10 +537,11 @@ uint64 HloModule::RandomNew64() const { HloComputation* HloModule::GetComputationWithName( tensorflow::StringPiece name) { - auto it = c_find_if(computations(), [&](HloComputation* computation) { + auto computations_in_module = computations(); + auto it = c_find_if(computations_in_module, [&](HloComputation* computation) { return computation->name() == name; }); - return it == computations().end() ? nullptr : *it; + return it == computations_in_module.end() ? nullptr : *it; } /* static */ std::atomic HloModule::next_unique_module_id_(0); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 757e65bda286d983d05e5a791aa7dffe97bac945..d2e726a0db63f622cd5092d56b4f746232d04aad 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -105,20 +105,19 @@ class HloModule { return entry_computation_; } - ComputationLayout* mutable_host_entry_computation_layout() { - return config_.mutable_host_entry_computation_layout(); + // Creates the ComputationLayout which describes the current status of the HLO + // module entry computation. + ComputationLayout compute_computation_layout() const { + return ComputationLayout(entry_computation()->ComputeProgramShape(), + /*ignore_layouts=*/false); } - const ComputationLayout& host_entry_computation_layout() const { - return config_.host_entry_computation_layout(); + ComputationLayout* mutable_entry_computation_layout() { + return config_.mutable_entry_computation_layout(); } - ComputationLayout* mutable_device_entry_computation_layout() { - return config_.mutable_device_entry_computation_layout(); - } - - const ComputationLayout& device_entry_computation_layout() const { - return config_.device_entry_computation_layout(); + const ComputationLayout& entry_computation_layout() const { + return config_.entry_computation_layout(); } // Gets the computations in this module. @@ -154,7 +153,7 @@ class HloModule { // Compute and return a post order of all computations in the module. The sort // is defined like so: if computation A has an instruction which calls // computation B, then A will appear after B in the sort. - std::list MakeComputationPostOrder() const; + std::vector MakeComputationPostOrder() const; // Gets the computations in this module which aren't for fusion nodes. // diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index dae5578a3158fecb8219e518841dec1020b2ca98..07a8c798dbee072db3b75d5e99ca0dcabb5fdf6b 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -28,16 +28,14 @@ namespace xla { using tensorflow::strings::StrAppend; -HloModuleConfig::HloModuleConfig() {} - -HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape) - : host_entry_computation_layout_(program_shape), - device_entry_computation_layout_(program_shape) {} +HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape, + bool ignore_layouts) + : entry_computation_layout_( + ComputationLayout(program_shape, ignore_layouts)) {} void HloModuleConfig::SetDefaultComputationLayout( const ProgramShape& program_shape) { - host_entry_computation_layout_ = ComputationLayout(program_shape); - device_entry_computation_layout_ = ComputationLayout(program_shape); + entry_computation_layout_ = ComputationLayout(program_shape); } string HloModuleConfig::compilation_cache_key() const { @@ -46,18 +44,11 @@ string HloModuleConfig::compilation_cache_key() const { StrAppend(&key, "::("); std::vector params; for (const ShapeLayout& param_layout : - host_entry_computation_layout_->parameter_layouts()) { + entry_computation_layout_->parameter_layouts()) { params.push_back(param_layout.shape().DebugString()); } StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ", - host_entry_computation_layout_->result_shape().SerializeAsString()); - for (const ShapeLayout& param_layout : - device_entry_computation_layout_->parameter_layouts()) { - params.push_back(param_layout.shape().DebugString()); - } - StrAppend( - &key, tensorflow::str_util::Join(params, ", "), ") => ", - device_entry_computation_layout_->result_shape().SerializeAsString()); + entry_computation_layout_->result_shape().SerializeAsString()); if (seed() != 0) { // TODO(b/32083678): force recompilation to reset global state. static std::atomic counter{0}; diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index cdb0b29a2399b387bc617262032e9083ba079625..074e9c90705d432b8344aebaf3c15aeb41a59fa3 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -37,48 +37,34 @@ class HloModuleConfig { // ComputationLayout. The default ctor creates it without -- in this case // accessing entry_computation_layout will CHECK-fail. The ctor accepting a // ProgramShape creates a computation layout using this shape. - HloModuleConfig(); - explicit HloModuleConfig(const ProgramShape& program_shape); + // The layouts in the ProgramShape will be reset to default unless + // ignore_layouts is set to false. + HloModuleConfig() = default; - // Checks if this config has an entry computation layout already. - bool has_host_entry_computation_layout() const { - return host_entry_computation_layout_.has_value(); - } + explicit HloModuleConfig(const ProgramShape& program_shape, + bool ignore_layouts = true); - bool has_device_entry_computation_layout() const { - return device_entry_computation_layout_.has_value(); + // Checks if this config has an entry computation layout already. + bool has_entry_computation_layout() const { + return entry_computation_layout_.has_value(); } // Sets the entry computation layout for this config. If the entry computation // layout already exists, it is silently replaced. void SetDefaultComputationLayout(const ProgramShape& program_shape); - // Returns a constant reference to the on-host layout of the entry - // computation. Assumes the layout was set. - const ComputationLayout& host_entry_computation_layout() const { - CHECK(host_entry_computation_layout_.has_value()); - return *host_entry_computation_layout_; - } - - // Returns a mutable pointer to the layout of the on-host entry computation. + // Returns a constant reference to the layout of the entry computation. // Assumes the layout was set. - ComputationLayout* mutable_host_entry_computation_layout() { - CHECK(host_entry_computation_layout_.has_value()); - return &(*host_entry_computation_layout_); - } - - // Returns a constant reference to the on-device layout of the entry - // computation. Assumes the layout was set. - const ComputationLayout& device_entry_computation_layout() const { - CHECK(device_entry_computation_layout_.has_value()); - return *device_entry_computation_layout_; + const ComputationLayout& entry_computation_layout() const { + CHECK(entry_computation_layout_.has_value()); + return *entry_computation_layout_; } - // Returns a mutable pointer to the layout of the on-device entry computation. + // Returns a mutable pointer to the layout of the entry computation. // Assumes the layout was set. - ComputationLayout* mutable_device_entry_computation_layout() { - CHECK(device_entry_computation_layout_.has_value()); - return &(*device_entry_computation_layout_); + ComputationLayout* mutable_entry_computation_layout() { + CHECK(entry_computation_layout_.has_value()); + return &(*entry_computation_layout_); } // Returns whether to enable HLO-level profiling. @@ -127,8 +113,7 @@ class HloModuleConfig { private: // If you add new members, be sure to update compilation_cache_key. - tensorflow::gtl::optional host_entry_computation_layout_; - tensorflow::gtl::optional device_entry_computation_layout_; + tensorflow::gtl::optional entry_computation_layout_; // Whether this is a 'host module'. bool is_host_module_ = false; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index bf33640db16638803f4f8e6c66f35d6bb6e2c9fe..10bf9ffd6c1960df5ca2a3555d120b0874407f15 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -75,10 +77,23 @@ Status HloModuleGroupMetadata::Build() { if (tracked == nullptr) { return Status::OK(); } - // Add the parent computation of this channel instruction and its peer - // computation (both must be while computations) as companions. + + std::vector peers; if (IsChannelInstruction(hlo)) { - HloComputation* peer_computation = PeerComputation(hlo); + peers.push_back(PeerComputation(hlo)); + } else if (hlo->IsCrossModuleAllReduce()) { + for (HloInstruction* instr : GetAllReduceGroup(*hlo->all_reduce_id())) { + if (instr == hlo) { + continue; + } + peers.push_back(instr->parent()); + } + } + + // Add the parent computation of this channel (or all-reduce) instruction + // and its peer computation(s) (both must be while computations) as + // companions. + for (HloComputation* peer_computation : peers) { const TrackedInstruction* peer_tracked = GetTrackedInstruction(peer_computation); TF_RET_CHECK(peer_tracked != nullptr) @@ -162,8 +177,12 @@ bool HloModuleGroupMetadata::IsChannelInstruction( case HloOpcode::kSend: case HloOpcode::kRecv: case HloOpcode::kSendDone: - case HloOpcode::kRecvDone: - return true; + case HloOpcode::kRecvDone: { + const HloSendRecvInstruction* send_recv_instr = + DynCast(instruction); + CHECK(send_recv_instr != nullptr); + return !send_recv_instr->is_host_transfer(); + } default: return false; } @@ -175,7 +194,8 @@ bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const { bool HloModuleGroupMetadata::InstructionCommunicates( HloInstruction* hlo) const { - return IsChannelInstruction(hlo) || IsCompanionInstruction(hlo); + return IsChannelInstruction(hlo) || IsCompanionInstruction(hlo) || + hlo->IsCrossModuleAllReduce(); } const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel( @@ -200,6 +220,13 @@ HloComputation* HloModuleGroupMetadata::PeerComputation( } } +const std::vector& HloModuleGroupMetadata::GetAllReduceGroup( + int64 all_reduce_id) const { + auto it = all_reduce_map_.find(all_reduce_id); + CHECK(it != all_reduce_map_.end()); + return it->second; +} + std::vector HloModuleGroupMetadata::GetCompanionsPath(const HloInstruction* hlo) const { std::vector path; @@ -278,10 +305,27 @@ Status HloModuleGroupMetadata::RecordInstructions() { tracked_instructions_[hlo->to_apply()] = TrackedInstruction(hlo, ComputationKind::kCallFunction); } + + // Group cross module all-reduce instructions by the all_reduce id. + if (hlo->IsCrossModuleAllReduce()) { + TF_RET_CHECK(channel_id_map_.find(*hlo->all_reduce_id()) == + channel_id_map_.end()) + << "all_reduce_id " << *hlo->all_reduce_id() + << " is already used by a send/recv instruction"; + all_reduce_map_[*hlo->all_reduce_id()].push_back(hlo); + max_channel_id_ = std::max(max_channel_id_, *hlo->all_reduce_id()); + return Status::OK(); + } + if (!IsChannelInstruction(hlo)) { return Status::OK(); } + TF_RET_CHECK(all_reduce_map_.find(hlo->channel_id()) == + all_reduce_map_.end()) + << "channel id " << hlo->channel_id() + << " is already used by an all-reduce instruction"; + // Add a new channel if needed. if (channel_id_map_.find(hlo->channel_id()) == channel_id_map_.end()) { channels_.emplace_back(); @@ -324,6 +368,7 @@ Status HloModuleGroupMetadata::RecordInstructions() { } } VLOG(2) << "Created " << channels_.size() << " channels"; + VLOG(2) << "Created " << all_reduce_map_.size() << " all-reduce groups"; return Status::OK(); } @@ -382,7 +427,8 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { // Check if the shapes match for each channel. for (const Channel& channel : channels_) { const Shape& send_shape = channel.send->operand(0)->shape(); - const Shape& recv_shape = channel.recv_done->shape(); + const Shape& recv_shape = + ShapeUtil::GetTupleElementShape(channel.recv_done->shape(), 0); if (!ShapeUtil::Compatible(send_shape, recv_shape)) { return FailedPrecondition("send/recv shapes do not match"); } diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index ffde3a332dfc141ca928a44cfdf4686900e9f47b..84f2d3f5fbc1a6ff1df8ba3c0babd122e5701148 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -92,7 +92,7 @@ class HloModuleGroupMetadata { ComputationKind kind_ = ComputationKind::kInvalid; }; - // Represents a channel and the 4 instructions that form the channel. + // Represents a channel and the instructions that form the channel. struct Channel { int64 id = -1; HloInstruction* send = nullptr; @@ -118,13 +118,17 @@ class HloModuleGroupMetadata { // comment above on companion instructions. bool IsCompanionInstruction(HloInstruction* hlo) const; - // Returns true if the instruction is either a channel instruction or a - // companion instruction. + // Returns true if the instruction is either a channel instruction, a + // cross-module all-reduce instruction, or a companion instruction. bool InstructionCommunicates(HloInstruction* hlo) const; // Returns the Channel instance for the given channel id. const Channel& GetChannel(int64 channel_id) const; + // Returns the all-reduce instructions with the same all_reduce_id. + const std::vector& GetAllReduceGroup( + int64 all_reduce_id) const; + // Returns the computation that contains the peer channel instructions for // the given instruction. // @@ -187,13 +191,14 @@ class HloModuleGroupMetadata { // Returns all channels in the module group. const std::vector& channels() const { return channels_; } - // Returns the maximum channel id used in the module group. + // Returns the maximum channel id or all_reduce_id used in the module group. int64 max_channel_id() const { return max_channel_id_; } private: Status Build(); - // Record all channel instructions and While instructions. + // Record all channel instructions, cross-module AllReduce instructions, and + // While/Conditional/Call instructions. Status RecordInstructions(); // Verifies the given HloModules are well-formed and follow the specification, @@ -255,6 +260,9 @@ class HloModuleGroupMetadata { // Map from channel ids to the index in channels_. tensorflow::gtl::FlatMap channel_id_map_; + // Map from all-reduce ids to the all reduce instructions. + tensorflow::gtl::FlatMap> all_reduce_map_; + // The maximum channel id used in the module group. int64 max_channel_id_ = -1; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 5a0d1e264eb5095ff53721416ebcf4842a063f97..9fd0ade153109c6c809c37aa08257f83a82c44d5 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -56,12 +56,17 @@ std::vector HloModuleGroupUtil::GlobalPredecessors( }; // If the given instruction is a companion instruction, we need to find the - // predecessors of all of its companion instructions. + // predecessors of all of its companion instructions. If the instruction is an + // all-reduce, we need to find the predecessors of all the peer all-reduce + // instructions. std::vector instruction_group; if (metadata_.IsCompanionInstruction(instruction)) { for (HloInstruction* companion : metadata_.Companions(instruction)) { instruction_group.push_back(companion); } + } else if (instruction->IsCrossModuleAllReduce()) { + instruction_group = + metadata_.GetAllReduceGroup(*instruction->all_reduce_id()); } else { instruction_group.push_back(instruction); } @@ -112,12 +117,17 @@ std::vector HloModuleGroupUtil::GlobalSuccessors( }; // If the given instruction is a companion instruction, we need to find the - // successors of all of its companion instructions. + // successors of all of its companion instructions. If the instruction is an + // all-reduce, we need to find the successors of all its peer all-reduce + // instructions. std::vector instruction_group; if (metadata_.IsCompanionInstruction(instruction)) { for (HloInstruction* companion : metadata_.Companions(instruction)) { instruction_group.push_back(companion); } + } else if (instruction->IsCrossModuleAllReduce()) { + instruction_group = + metadata_.GetAllReduceGroup(*instruction->all_reduce_id()); } else { instruction_group.push_back(instruction); } @@ -170,15 +180,17 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( HloInstruction* hlo = stack.top(); // Find the instruction group of the currently visited instruction. The - // instruction group represents all companion instructions of the - // current instruction, and are considered to be a single entity for the - // purpose of the traversal (i.e., they must always be in the same visit - // state). + // instruction group represents all companion instructions of the current + // instruction, or all the all-reduce instructions that belong to the same + // group, or are considered to be a single entity for the purpose of the + // traversal (i.e., they must always be in the same visit state). std::vector instruction_group; if (metadata_.IsCompanionInstruction(hlo)) { for (HloInstruction* companion : metadata_.Companions(hlo)) { instruction_group.push_back(companion); } + } else if (hlo->IsCrossModuleAllReduce()) { + instruction_group = metadata_.GetAllReduceGroup(*hlo->all_reduce_id()); } else { instruction_group.push_back(hlo); } @@ -277,7 +289,7 @@ Status HloModuleGroupUtil::VerifyComputations( StatusOr> HloModuleGroupUtil::ComputeReachability( tensorflow::gtl::ArraySlice computations) { - std::list post_order; + std::vector post_order; auto visit_function = [&](HloInstruction* instruction, const std::vector& instruction_group) { @@ -292,7 +304,7 @@ HloModuleGroupUtil::ComputeReachability( } auto reachability = MakeUnique(post_order); for (HloInstruction* hlo : post_order) { - reachability->SetReachabilityToUnion(GlobalPredecessors(hlo), hlo); + reachability->FastSetReachabilityToUnion(GlobalPredecessors(hlo), hlo); } return std::move(reachability); } diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 7f28a804bfec9c2f1bbb5fa08f7dd4e68be14d35..236f4500860a8673e61cbd2f861a8fc40c7861f7 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -38,7 +38,7 @@ class HloModuleTest : public HloTestBase { std::unique_ptr CreateConstantComputation() { auto builder = HloComputation::Builder("Constant"); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); return builder.Build(); } @@ -122,7 +122,7 @@ TEST_F(HloModuleTest, CloneHasFusion) { { auto b = HloComputation::Builder("Entry"); auto input = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); b.AddInstruction( HloInstruction::CreateFusion(r0f32_, HloInstruction::FusionKind::kInput, /*operands=*/{input}, fused_computation)); @@ -173,7 +173,7 @@ TEST_F(HloModuleTest, LargeConstantToString) { auto builder = HloComputation::Builder("Constant"); std::vector values(16, 42.0); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1(values))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1(values))); module->AddEntryComputation(builder.Build()); EXPECT_EQ( diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index a35546f5f41b149d119ee141fd734da8bfd055b2..ec279867e595b66a22882703cc06046e3e916c96 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -47,6 +47,7 @@ namespace xla { #define HLO_OPCODE_LIST(V) \ V(kAbs, "abs") \ V(kAdd, "add") \ + V(kAllToAll, "all-to-all") \ V(kAtan2, "atan2") \ V(kBatchNormGrad, "batch-norm-grad") \ V(kBatchNormInference, "batch-norm-inference") \ @@ -81,12 +82,13 @@ namespace xla { V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ - V(kGenerateToken, "generate-token", kHloOpcodeIsVariadic) \ + V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ V(kHostCompute, "host-compute") \ V(kImag, "imag") \ V(kInfeed, "infeed") \ + V(kIota, "iota") \ V(kIsFinite, "is-finite") \ V(kLe, "less-than-or-equal-to", kHloOpcodeIsComparison) \ V(kLog, "log") \ @@ -94,6 +96,7 @@ namespace xla { V(kAnd, "and") \ V(kNot, "not") \ V(kOr, "or") \ + V(kXor, "xor") \ V(kLt, "less-than", kHloOpcodeIsComparison) \ V(kMap, "map", kHloOpcodeIsVariadic) \ V(kMaximum, "maximum") \ @@ -116,6 +119,7 @@ namespace xla { V(kReverse, "reverse") \ V(kRng, "rng") \ V(kRoundNearestAfz, "round-nearest-afz") \ + V(kScatter, "scatter") \ V(kSelect, "select") \ V(kSelectAndScatter, "select-and-scatter") \ V(kSend, "send") \ @@ -132,6 +136,7 @@ namespace xla { V(kTrace, "trace") \ V(kTranspose, "transpose") \ V(kTuple, "tuple", kHloOpcodeIsVariadic) \ + V(kTupleSelect, "tuple-select") \ V(kWhile, "while") enum class HloOpcode { diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index 774345124b4ad62e35d9423a23f1dbaa28e44d80..6f3f83f63a05fafaa3f3ddcff8a7cac7cb7b06d5 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -58,7 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) { case HloOpcode::kConcatenate: case HloOpcode::kFusion: case HloOpcode::kMap: - case HloOpcode::kGenerateToken: + case HloOpcode::kAfterAll: case HloOpcode::kTuple: EXPECT_TRUE(HloOpcodeIsVariadic(opcode)); break; diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index dcd4725fe78e8b9b5d14437e964cb5aaf1664117..6c1e015f77a62c3e3ff7ffa5ce9dea735f46e10a 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -232,6 +232,11 @@ bool HloOrdering::UseIsBeforeValueDefinition( << " and def is in FALSE computation"; return true; } + if (value.defining_instruction() == use.instruction) { + VLOG(4) << " use is conditional " << use << " and def is " + << value.ToShortString(); + return true; + } } VLOG(4) << " use is not before value"; diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index cfe5dace05ac03f1573f90b2ce664c94837837b4..126d3a2d9c70bff1d2a022e395652049768d6d21 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -57,7 +57,7 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { auto builder_c = HloComputation::Builder("C"); HloInstruction* c = builder_c.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); HloComputation* computation_c = module->AddEmbeddedComputation(builder_c.Build()); @@ -145,7 +145,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(scalar_shape, condition, body, constant)); module->AddEntryComputation(builder.Build()); @@ -208,7 +208,7 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(scalar_shape, condition, body, constant)); auto add = builder.AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index fef475380c5c810e1c4712406dde6b1135be3d97..4b3cd99dc06520bfeb60430d9d4316db66ea04b3 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -15,8 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -117,11 +119,13 @@ class HloParser { // Types of attributes. enum class AttrTy { + kBool, kInt64, kInt32, kFloat, kString, kBracedInt64List, + kBracedInt64ListList, kHloComputation, kFftType, kWindow, @@ -202,6 +206,10 @@ class HloParser { bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result); + // 'parse_and_add_item' is an lambda to parse an element in the list and add + // the parsed element to the result. It's supposed to capture the result. + bool ParseList(const TokKind start, const TokKind end, const TokKind delim, + const std::function& parse_and_add_item); bool ParseParamListToShape(Shape* shape, LocTy* shape_loc); bool ParseParamList(); @@ -327,22 +335,15 @@ bool HloParser::ParseComputations() { // set the layouts to what the hlo text says. for (int p = 0; p < computation->num_parameters(); p++) { const Shape& param_shape = computation->parameter_instruction(p)->shape(); - TF_CHECK_OK(module_->mutable_host_entry_computation_layout() - ->mutable_parameter_layout(p) - ->CopyLayoutFromShape(param_shape)); - TF_CHECK_OK(module_->mutable_device_entry_computation_layout() + TF_CHECK_OK(module_->mutable_entry_computation_layout() ->mutable_parameter_layout(p) ->CopyLayoutFromShape(param_shape)); } const Shape& result_shape = computation->root_instruction()->shape(); - TF_CHECK_OK(module_->mutable_host_entry_computation_layout() - ->mutable_result_layout() - ->CopyLayoutFromShape(result_shape)); - TF_CHECK_OK(module_->mutable_device_entry_computation_layout() + TF_CHECK_OK(module_->mutable_entry_computation_layout() ->mutable_result_layout() ->CopyLayoutFromShape(result_shape)); } - return true; } @@ -496,6 +497,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateConstant(std::move(literal))); break; } + case HloOpcode::kIota: { + if (!ParseOperands(&operands, /*expected_size=*/0) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateIota(shape)); + break; + } // Unary ops. case HloOpcode::kAbs: case HloOpcode::kRoundNearestAfz: @@ -516,7 +525,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: - case HloOpcode::kSort: case HloOpcode::kTanh: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -545,6 +553,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kRemainder: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: { @@ -558,7 +567,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } // Ternary ops. case HloOpcode::kClamp: - case HloOpcode::kSelect: { + case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: { if (!ParseOperands(&operands, /*expected_size=*/3) || !ParseAttributes(attrs)) { return false; @@ -590,27 +600,52 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional to_apply; optional> replica_group_ids; optional barrier; + optional all_reduce_id; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; attrs["replica_group_ids"] = { /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids}; attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; + attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64, + &all_reduce_id}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - if (replica_group_ids) { instruction = builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( shape, operands, *to_apply, *replica_group_ids, - barrier ? *barrier : "")); + barrier ? *barrier : "", all_reduce_id)); } else { instruction = builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( - shape, operands, *to_apply, {}, barrier ? *barrier : "")); + shape, operands, *to_apply, {}, barrier ? *barrier : "", + all_reduce_id)); } break; } + case HloOpcode::kAllToAll: { + optional>> tmp_groups; + optional barrier; + attrs["replica_groups"] = {/*required=*/false, + AttrTy::kBracedInt64ListList, &tmp_groups}; + attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + std::vector replica_groups; + if (tmp_groups) { + c_transform(*tmp_groups, std::back_inserter(replica_groups), + [](const std::vector& ids) { + ReplicaGroup group; + *group.mutable_replica_ids() = {ids.begin(), ids.end()}; + return group; + }); + } + instruction = builder->AddInstruction(HloInstruction::CreateAllToAll( + shape, operands, replica_groups, barrier ? *barrier : "")); + break; + } case HloOpcode::kReshape: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -620,12 +655,42 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateReshape(shape, operands[0])); break; } - case HloOpcode::kGenerateToken: { + case HloOpcode::kAfterAll: { if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateGenerateToken(operands)); + if (operands.empty()) { + instruction = builder->AddInstruction(HloInstruction::CreateToken()); + } else { + instruction = + builder->AddInstruction(HloInstruction::CreateAfterAll(operands)); + } + break; + } + case HloOpcode::kSort: { + auto loc = lexer_.GetLoc(); + + optional> dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs) || + dimensions->size() != 1) { + return false; + } + switch (operands.size()) { + case 1: + instruction = builder->AddInstruction(HloInstruction::CreateSort( + shape, dimensions->at(0), /*keys=*/operands[0])); + break; + case 2: + instruction = builder->AddInstruction(HloInstruction::CreateSort( + shape, dimensions->at(0), + /*keys=*/operands[0], /*values=*/operands[1])); + break; + default: + return Error(loc, StrCat("expects either 1 or 2 operands, but has ", + operands.size(), " operands")); + } break; } case HloOpcode::kTuple: { @@ -652,18 +717,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kRecv: { optional channel_id; + // If the is_host_transfer attribute is not present then default to false. + optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; - if (!ParseOperands(&operands, /*expected_size=*/0) || + attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, + &is_host_transfer}; + if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateRecv(shape.tuple_shapes(0), *channel_id)); + // If the is_host_transfer attribute is not present then default to false. + instruction = builder->AddInstruction(HloInstruction::CreateRecv( + shape.tuple_shapes(0), operands[0], *channel_id, *is_host_transfer)); break; } case HloOpcode::kRecvDone: { optional channel_id; + // If the is_host_transfer attribute is not present then default to false. + optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, + &is_host_transfer}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { return false; @@ -671,24 +745,32 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (channel_id != operands[0]->channel_id()) { return false; } - instruction = - builder->AddInstruction(HloInstruction::CreateRecvDone(operands[0])); + instruction = builder->AddInstruction( + HloInstruction::CreateRecvDone(operands[0], *is_host_transfer)); break; } case HloOpcode::kSend: { optional channel_id; + // If the is_host_transfer attribute is not present then default to false. + optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; - if (!ParseOperands(&operands, /*expected_size=*/1) || + attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, + &is_host_transfer}; + if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateSend(operands[0], *channel_id)); + instruction = builder->AddInstruction(HloInstruction::CreateSend( + operands[0], operands[1], *channel_id, *is_host_transfer)); break; } case HloOpcode::kSendDone: { optional channel_id; + // If the is_host_transfer attribute is not present then default to false. + optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, + &is_host_transfer}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { return false; @@ -696,8 +778,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (channel_id != operands[0]->channel_id()) { return false; } - instruction = - builder->AddInstruction(HloInstruction::CreateSendDone(operands[0])); + instruction = builder->AddInstruction( + HloInstruction::CreateSendDone(operands[0], *is_host_transfer)); break; } case HloOpcode::kGetTupleElement: { @@ -810,18 +892,28 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReduce: { + auto loc = lexer_.GetLoc(); + optional reduce_computation; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &reduce_computation}; optional> dimensions_to_reduce; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions_to_reduce}; - if (!ParseOperands(&operands, /*expected_size=*/2) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } + if (operands.size() % 2) { + return Error(loc, StrCat("expects an even number of operands, but has ", + operands.size(), " operands")); + } instruction = builder->AddInstruction(HloInstruction::CreateReduce( - shape, /*operand=*/operands[0], /*init_value=*/operands[1], + shape, /*operands=*/ + tensorflow::gtl::ArraySlice(operands, 0, + operands.size() / 2), + /*init_values=*/ + tensorflow::gtl::ArraySlice( + operands, operands.size() / 2, operands.size()), *dimensions_to_reduce, *reduce_computation)); break; } @@ -981,23 +1073,53 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kInfeed: { optional config; attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config}; - if (!ParseOperands(&operands, /*expected_size=*/0) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateInfeed(shape, config ? *config : "")); + // We need to know the infeed data shape to construct the infeed + // instruction. This is the zero-th element of the tuple-shaped output of + // the infeed instruction. ShapeUtil::GetTupleElementShape will check fail + // if the shape is not a non-empty tuple, so add guard so an error message + // can be emitted instead of a check fail + if (!ShapeUtil::IsTuple(shape) && !ShapeUtil::IsEmptyTuple(shape)) { + return Error(lexer_.GetLoc(), + "infeed must have a non-empty tuple shape"); + } + + if (operands.empty()) { + // TODO(b/80000000): Remove this when all uses of infeed are + // converted to take tokens. + instruction = builder->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::GetTupleElementShape(shape, 0), config ? *config : "")); + } else if (operands.size() == 1) { + instruction = builder->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::GetTupleElementShape(shape, 0), operands[0], + config ? *config : "")); + } else { + return Error(lexer_.GetLoc(), + "infeed must have exactly zero or one operands"); + } break; } case HloOpcode::kOutfeed: { optional config; attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config}; - if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateOutfeed( - operands[0]->shape(), operands[0], config ? *config : "")); + if (operands.size() == 1) { + // TODO(b/80000000): Remove this when all uses of outfeed are + // converted to take tokens. + instruction = builder->AddInstruction(HloInstruction::CreateOutfeed( + operands[0]->shape(), operands[0], config ? *config : "")); + } else if (operands.size() == 2) { + instruction = builder->AddInstruction( + HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0], + operands[1], config ? *config : "")); + } else { + return Error(lexer_.GetLoc(), + "outfeed must have exactly one or two operands"); + } break; } case HloOpcode::kRng: { @@ -1047,13 +1169,24 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kCustomCall: { optional custom_call_target; + optional window; + optional dnums; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; + attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; + attrs["dim_labels"] = {/*required=*/false, + AttrTy::kConvolutionDimensionNumbers, &dnums}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( shape, operands, *custom_call_target)); + if (window.has_value()) { + instruction->set_window(*window); + } + if (dnums.has_value()) { + instruction->set_convolution_dimension_numbers(*dnums); + } break; } case HloOpcode::kHostCompute: { @@ -1134,17 +1267,54 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, return false; } - GatherDimensionNumbers dim_numbers = HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/*output_window_dims, - /*elided_window_dims=*/*elided_window_dims, - /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims, - /*index_vector_dim=*/*index_vector_dim); + GatherDimensionNumbers dim_numbers = + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/*output_window_dims, + /*elided_window_dims=*/*elided_window_dims, + /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims, + /*index_vector_dim=*/*index_vector_dim); instruction = builder->AddInstruction(HloInstruction::CreateGather( shape, /*operand=*/operands[0], /*gather_indices=*/operands[1], dim_numbers, *window_bounds)); break; } + case HloOpcode::kScatter: { + optional> update_window_dims; + attrs["update_window_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &update_window_dims}; + optional> inserted_window_dims; + attrs["inserted_window_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims}; + optional> scatter_dims_to_operand_dims; + attrs["scatter_dims_to_operand_dims"] = {/*required=*/true, + AttrTy::kBracedInt64List, + &scatter_dims_to_operand_dims}; + optional index_vector_dim; + attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, + &index_vector_dim}; + + optional update_computation; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &update_computation}; + + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { + return false; + } + + ScatterDimensionNumbers dim_numbers = + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/*update_window_dims, + /*inserted_window_dims=*/*inserted_window_dims, + /*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims, + /*index_vector_dim=*/*index_vector_dim); + + instruction = builder->AddInstruction(HloInstruction::CreateScatter( + shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1], + /*updates=*/operands[2], *update_computation, dim_numbers)); + break; + } case HloOpcode::kDomain: { DomainData domain; attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain}; @@ -1153,8 +1323,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, return false; } instruction = builder->AddInstruction(HloInstruction::CreateDomain( - shape, operands[0], std::move(domain.entry_metadata), - std::move(domain.exit_metadata))); + shape, operands[0], std::move(domain.exit_metadata), + std::move(domain.entry_metadata))); break; } case HloOpcode::kTrace: @@ -1240,7 +1410,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, bool replicated = false; std::vector devices; std::vector tile_assignment_dimensions; - Shape tile_shape; while (lexer_.GetKind() != TokKind::kRbrace) { switch (lexer_.GetKind()) { case TokKind::kw_maximal: @@ -1291,7 +1460,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, break; } case TokKind::kShape: - tile_shape = lexer_.GetShapeVal(); + // TODO(b/112302613): Left here for backward compatibility to ignore the + // removed tile shape data. lexer_.Lex(); break; case TokKind::kRbrace: @@ -1306,19 +1476,12 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return Error(loc, "replicated shardings should not have any devices assigned"); } - if (!ShapeUtil::Equal(tile_shape, Shape())) { - return Error(loc, - "replicated shardings should not have any tile shape set"); - } sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED); } else if (maximal) { if (devices.size() != 1) { return Error(loc, "maximal shardings should have exactly one device assigned"); } - if (!ShapeUtil::Equal(tile_shape, Shape())) { - return Error(loc, "maximal shardings should not have any tile shape set"); - } sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); sharding->add_tile_assignment_devices(devices[0]); } else { @@ -1326,9 +1489,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return Error( loc, "non-maximal shardings must have more than one device assigned"); } - if (ShapeUtil::Equal(tile_shape, Shape())) { - return Error(loc, "non-maximal shardings should have a tile shape set"); - } if (tile_assignment_dimensions.empty()) { return Error( loc, @@ -1336,7 +1496,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, "dimensions"); } sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); - *sharding->mutable_tile_shape() = tile_shape; for (tensorflow::int64 dim : tile_assignment_dimensions) { sharding->add_tile_assignment_dimensions(dim); } @@ -1493,6 +1652,24 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value, "value ", value, " is out of range for literal's primitive type ", PrimitiveType_Name(literal->shape().element_type()))); } + } else if (std::is_unsigned::value) { + CHECK((std::is_same::value || + std::is_same::value)) + << "Unimplemented checking for ParsedElemT"; + + ParsedElemT upper_bound; + if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) { + upper_bound = std::numeric_limits::max(); + } else { + upper_bound = + static_cast(std::numeric_limits::max()); + } + if (value > upper_bound || value < 0) { + // Value is out of range for LiteralNativeT. + return TokenError(StrCat( + "value ", value, " is out of range for literal's primitive type ", + PrimitiveType_Name(literal->shape().element_type()))); + } } else if (value > static_cast( std::numeric_limits::max()) || value < static_cast( @@ -1561,7 +1738,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, } } } - *literal = Literal::MakeTupleOwned(std::move(elements)); + *literal = LiteralUtil::MakeTupleOwned(std::move(elements)); return ParseToken(TokKind::kRparen, StrCat("expects ')' at the end of the tuple with ", ShapeUtil::TupleElementCount(shape), "elements")); @@ -1589,8 +1766,8 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, } // Create a literal with the given shape in default layout. - *literal = Literal::CreateFromDimensions(shape.element_type(), - AsInt64Slice(shape.dimensions())); + *literal = LiteralUtil::CreateFromDimensions( + shape.element_type(), AsInt64Slice(shape.dimensions())); tensorflow::int64 nest_level = 0; tensorflow::int64 linear_index = 0; // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for @@ -1647,7 +1824,6 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, break; } case TokKind::kComma: - case TokKind::kComment: // Skip. lexer_.Lex(); break; @@ -1983,6 +2159,14 @@ bool HloParser::ParseAttributeHelper( bool success = [&] { LocTy attr_loc = lexer_.GetLoc(); switch (attr_type) { + case AttrTy::kBool: { + bool result; + if (!ParseBool(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } case AttrTy::kInt64: { tensorflow::int64 result; if (!ParseInt64(&result)) { @@ -2086,6 +2270,26 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kBracedInt64ListList: { + std::vector> result; + auto parse_and_add_item = [&]() { + std::vector item; + if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, + TokKind::kComma, &item)) { + return false; + } + result.push_back(item); + return true; + }; + if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item)) { + return false; + } + static_cast>>*>( + attr_out_ptr) + ->emplace(result); + return true; + } case AttrTy::kSliceRanges: { SliceRanges result; if (!ParseSliceRanges(&result)) { @@ -2428,6 +2632,26 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end, end, StrCat("expects an int64 list to end with ", TokKindToString(end))); } +bool HloParser::ParseList(const TokKind start, const TokKind end, + const TokKind delim, + const std::function& parse_and_add_item) { + if (!ParseToken(start, StrCat("expects a list starting with ", + TokKindToString(start)))) { + return false; + } + if (lexer_.GetKind() == end) { + // empty + } else { + do { + if (!parse_and_add_item()) { + return false; + } + } while (EatIfPresent(delim)); + } + return ParseToken( + end, StrCat("expects a list to end with ", TokKindToString(end))); +} + // param_list_to_shape ::= param_list '->' shape bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) { if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index f834d34d57106b11cf398f966d8c0224f00d1b8d..5990a3d4784750feef2e375492851974214db779 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -277,12 +277,28 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { "SendRecv", R"(HloModule TwoSendRecvBothWayRecvFist_module -ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = (f32[], u32[]) recv(), channel_id=15, sharding={maximal device=1} - ROOT %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15, sharding={maximal device=1} +ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { + %token = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, sharding={maximal device=1} + ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1} + %constant = f32[] constant(2.1), sharding={maximal device=0} + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0} +} + +)" +}, +{ +"SendRecvWithHostTransfer", +R"(HloModule HostTransferSendRecv_module + +ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { + %token = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, is_host_transfer=true + ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, is_host_transfer=true %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} - %send-done = () send-done((f32[], u32[]) %send), channel_id=16, sharding={maximal device=0} + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, is_host_transfer=true + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, is_host_transfer=true } )" @@ -742,6 +758,46 @@ ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5 ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} } +)" +}, +{ +"scatter", +R"(HloModule StringifyScatter + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] { + %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) + %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2) + ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%add_F32.v3 +} + +)" +}, +{ + "ConstantUnsignedNoUnderflow", + R"(HloModule ConstantUnsignedNoUnderflow_module + +ENTRY %ConstantUnsignedNoUnderflow () -> u64[] { + ROOT %constant = u64[] constant(1) +} + +)" +}, + +{ + "ConstantUnsignedNoOverflow", + R"(HloModule ConstantUnsignedNoOverflow_module + +ENTRY %ConstantUnsignedNoOverflow () -> u64[] { + ROOT %constant = u64[] constant(9223372036854775807) +} + )" }, }); @@ -787,6 +843,32 @@ ENTRY ReduceR3ToR2.v3 { ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 } +)" +}, +// tuple reduce +{ +"TupleReduce", +R"(HloModule TupleReduce + +max_argmax { + value = f32[] parameter(2) + prev_max = f32[] parameter(0) + is_next_larger = pred[] greater-than-or-equal-to(value, prev_max) + max = f32[] select(is_next_larger, value, prev_max) + index = s32[] parameter(3) + prev_argmax = s32[] parameter(1) + argmax = s32[] select(is_next_larger, index, prev_argmax) + ROOT pair = (f32[], s32[]) tuple(max, argmax) +} + +ENTRY reduce_entry { + values = f32[1024]{0} parameter(0) + indices = f32[1024]{0} parameter(1) + init_value = f32[] constant(-inf) + init_index = s32[] constant(-1) + ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax +} + )" }, // infeed/outfeed @@ -795,10 +877,14 @@ ENTRY ReduceR3ToR2.v3 { R"(HloModule outfeed_module ENTRY InfeedToOutfeed { - infeed = (u32[3]{0}, pred[]) infeed() - outfeed = () outfeed(infeed) - ROOT infeed.1 = (u32[3]{0}, pred[]) infeed() - outfeed.1 = () outfeed(infeed.1) + token = token[] after-all() + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 + outfeed = token[] outfeed(infeed.data, token) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 + infeed.1.token = token[] get-tuple-element(infeed.1), index=1 + outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) } )" @@ -826,6 +912,56 @@ ENTRY ReducePrecision { ROOT reduce-precision = f32[1]{0} reduce-precision(constant), exponent_bits=8, mantissa_bits=10 } +)" +}, +// Sort (Key) +{ +"SortKey", +R"(HloModule sort + +ENTRY Sort { + x = f32[1024]{0} parameter(0) + ROOT sorted = f32[1024]{0} sort(x), dimensions={0} +} + +)" +}, +// Sort (Key, Value) +{ +"SortKeyValue", +R"(HloModule sort + +ENTRY Sort { + keys = f32[1024]{0} parameter(0) + values = s32[1024]{0} parameter(1) + ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0} +} + +)" +}, +// R2 Sort (Key) +{ +"SortKeyR2", +R"(HloModule sort + +ENTRY Sort { + x = f32[1024,16]{0,1} parameter(0) + ROOT sorted = f32[1024,16]{0,1} sort(x), dimensions={0} +} + +)" +}, +// R2 Sort (Key, Value) +{ +"SortKeyValueR2", +R"(HloModule sort + +ENTRY Sort { + keys = f32[1024,16]{0,1} parameter(0) + values = s32[1024,16]{0,1} parameter(1) + ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0} +} + )" }, // Conditional @@ -913,7 +1049,7 @@ add { ENTRY CRS { input = f32[8]{0} parameter(0) - ROOT crs = f32[8]{0} cross-replica-sum(input), to_apply=add + ROOT crs = f32[8]{0} cross-replica-sum(input), replica_group_ids={}, to_apply=add } )" @@ -931,7 +1067,53 @@ add { ENTRY CrossReplicaSumWithSubgroups { input = f32[128,32]{0,1} parameter(0) - ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), to_apply=add, replica_group_ids={0,0,1,1}, barrier="abc" + ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_group_ids={0,0,1,1}, barrier="abc", to_apply=add +} + +)" +}, +// all-to-all +{ +"AllToAll", +R"(HloModule AllToAll + +ENTRY AllToAll { + input = f32[128,32]{0,1} parameter(0) + ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={} +} + +)" +}, +// all-to-all with subgroups +{ +"AllToAllWithSubgroups", +R"(HloModule AllToAllWithSubgroups + +ENTRY AllToAllWithSubgroups { + input = f32[128,32]{0,1} parameter(0) + ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}}, barrier="abc" +} + +)" +}, +// Iota +{ +"Iota", +R"(HloModule iota + +ENTRY Iota { + ROOT iota = f32[100]{0} iota() +} + +)" +}, +// custom-call with window and dim_labels +{ +"CustomCallWithWindowAndDimLabels", +R"(HloModule CustomCallWithWindowAndDimLabels + +ENTRY Computation { + ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="target" } )" @@ -1132,6 +1314,40 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { "is out of range for literal's primitive type F16"); } +TEST_F(HloParserTest, ConstantUnsignedUnderflow) { + const string original = R"( + HloModule ConstantUnsignedUnderflow_module + ENTRY %ConstantUnsignedUnderflow () -> u64[] { + ROOT %constant = u64[] constant(-1) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); + ExpectHasSubstr(result.status().error_message(), + "is out of range for literal's primitive type U64"); +} + +TEST_F(HloParserTest, ConstantUnsignedOverflow) { + const string original = R"( + HloModule ConstantUnsignedOverflow_module + ENTRY %ConstantUnsignedOverflow () -> u32[] { + ROOT %constant = u32[] constant(4294967296) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); + ExpectHasSubstr(result.status().error_message(), + "is out of range for literal's primitive type U32"); +} + +TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) { + const string original = R"( + HloModule ConstantUnsignedOverflow_module + ENTRY %ConstantUnsignedOverflow () -> u64[] { + ROOT %constant = u64[] constant(9223372036854775808) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); +} + TEST_F(HloParserTest, ConstantWithExp) { const string original = R"(HloModule ConstantWithExp_module @@ -1192,11 +1408,12 @@ TEST_F(HloParserTest, UnexpectedAttribute) { const string original = R"(HloModule unexpected_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = (f32[], u32[]) recv(), channel_id=15 - %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + %token = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, calls=%recv - %send-done = () send-done((f32[], u32[]) %send), channel_id=16 + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, calls=%recv + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } )"; @@ -1208,11 +1425,12 @@ TEST_F(HloParserTest, MissingAttribute) { const string original = R"(HloModule missing_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = (f32[], u32[]) recv(), channel_id=15 - %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + %token = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(-2.1) - %send = (f32[], u32[]) send(f32[] %constant) - %send-done = () send-done((f32[], u32[]) %send), channel_id=16 + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token) + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } )"; @@ -1224,11 +1442,12 @@ TEST_F(HloParserTest, PredecessorUndefined) { const string original = R"(HloModule pre_not_found_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = (f32[], u32[]) recv(), channel_id=15 - %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + %token = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, control-predecessors={%done} - %send-done = () send-done((f32[], u32[]) %send), channel_id=16 + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, control-predecessors={%done} + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } )"; @@ -1302,7 +1521,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); - auto program_layout = module.ValueOrDie()->host_entry_computation_layout(); + auto program_layout = module.ValueOrDie()->entry_computation_layout(); ASSERT_EQ(program_layout.parameter_count(), 1); auto param_layout = program_layout.parameter_layout(0).layout(); auto result_layout = program_layout.result_layout().layout(); @@ -1341,6 +1560,81 @@ ENTRY consts { "last"); } +TEST_F(HloParserTest, Comments) { + const string original = R"(/* module description. */ +HloModule comments: + +ENTRY /*comment*/ c1 { + /* blah */ + ROOT const1 = /*foo*/f32[1]{0} constant({12345 /*bar*/}) + /* comment */ +} + +/* something else */ + +)"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, MultilineComments) { + const string original = R"(HloModule multiline_comment: +ENTRY c1 { + /* + ROOT foo = f32[1]{0} constant({12345}) + */ + ROOT const1 = f32[1]{0} constant({12345}) +/* +a +b +c +d + +*/ +})"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, UnterminatedComment) { + const string original = R"(HloModule unterminated_comment: +ENTRY c1 { +/* unterminated + ROOT const1 = f32[1]{0} constant({12345}) +})"; + // Verify that the error message points to the beginning of the unterminated + // comment. + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "/* unterminated\n^"); +} + +TEST_F(HloParserTest, SlashSlashComments) { + const string original = R"(HloModule slash_slash_comment: +// Garbage +ENTRY c1 { + // Foo bar + ROOT const1 = f32[1]{0} constant({12345}) // Something else +})"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, SlashSlashCommentMsDosEolFormat) { + const string original = + "HloModule slash_slash_comment:\r\n// Garbage\r\nENTRY c1 {\r\n// Foo " + "bar\r\nROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n}"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, SlashSlashCommentMacEolFormat) { + const string original = + "HloModule slash_slash_comment:\r// Garbage\rENTRY c1 {\r// Foo " + "bar\rROOT const1 = f32[1]{0} constant({12345}) // Something else\r}"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + TEST_F(HloParserTest, MultipleEntries) { const string original = R"(HloModule multiple_entries: ENTRY c1 { @@ -1418,5 +1712,15 @@ TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) { EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums)); } +TEST_F(HloParserTest, NontupleInfeed) { + const string original = R"(HloModule nontuple_infeed: +ENTRY nontuple_infeed { + token = token[] after-all() + ROOT infeed = pred[] infeed(token) +})"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "infeed must have a non-empty tuple shape"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_fix.h b/tensorflow/compiler/xla/service/hlo_pass_fix.h index b3d0a07add39968c6310392ea01daeab8a7dd9af..28194deb0e32252b372a328b006dabaf250fa2c7 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_fix.h +++ b/tensorflow/compiler/xla/service/hlo_pass_fix.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_ +#include + #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,9 +36,19 @@ class HloPassFix : public Pass { StatusOr Run(HloModule* module) override { bool changed = false; bool changed_this_iteration = true; + int64 iteration_count = 0; + int64 limit = + std::max(static_cast(1000), module->instruction_count()); while (changed_this_iteration) { TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module)); changed |= changed_this_iteration; + ++iteration_count; + if (iteration_count == limit) { + LOG(ERROR) + << "Unexpectedly number of iterations in HLO passes (" + << iteration_count + << ")\nIf compilation hangs here, please file a bug with XLA."; + } } return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index d45038f1f4a2e4aa19234eec93fdc9a068a902e1..2a07b6fcbc243d955e136ccdf097c8155a115845 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_query.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -61,7 +61,7 @@ bool AllOperandsAreConstants(const HloInstruction& instruction) { } HloInstruction* GetMatchingOperand( - std::function matcher, + const std::function& matcher, HloInstruction* instruction) { for (HloInstruction* op : instruction->operands()) { if (matcher(op)) { @@ -72,7 +72,7 @@ HloInstruction* GetMatchingOperand( } bool MatchBinaryInstructionOperand( - std::function matcher, + const std::function& matcher, HloInstruction* instruction, HloInstruction** matching_operand, HloInstruction** other_operand) { CHECK_EQ(instruction->operand_count(), 2); diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h index c79347bbf9d6146943b7b787f713369cb37fadee..c0826a6aee1f693484207a86ec258c6604d92318 100644 --- a/tensorflow/compiler/xla/service/hlo_query.h +++ b/tensorflow/compiler/xla/service/hlo_query.h @@ -45,7 +45,7 @@ bool IsScalarConstant(const HloInstruction* instruction); // multiple matching operands, then the first matching operand is returned. If // there are no matching operands then nullptr is returned. HloInstruction* GetMatchingOperand( - std::function matcher, + const std::function& matcher, HloInstruction* instruction); // Returns whether a binary instruction has a matching operand. Sets @@ -53,7 +53,7 @@ HloInstruction* GetMatchingOperand( // other_operand. Note: in the case where both operands match, the first operand // of the instruction is returned. bool MatchBinaryInstructionOperand( - std::function matcher, + const std::function& matcher, HloInstruction* instruction, HloInstruction** matching_operand, HloInstruction** other_operand); diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 4738e46f8aeb96a4c25d04b3246bd21f644fe3ea..01b088a957554821e65db7bf9cedf334db49728f 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -18,7 +18,7 @@ limitations under the License. namespace xla { HloReachabilityMap::HloReachabilityMap( - const std::list& instructions) + tensorflow::gtl::ArraySlice instructions) : size_(instructions.size()) { bit_vectors_.reserve(size_); for (const HloInstruction* hlo : instructions) { diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 69bb2b3cee6dafe058c45b4e74e93401bea2cfc9..48215d32a8284919cce6beb1663e6a723eefc1c4 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -41,7 +41,8 @@ class HloReachabilityMap { public: // Sets up a graph with no edges and where the nodes correspond to the given // instructions. - explicit HloReachabilityMap(const std::list& instructions); + explicit HloReachabilityMap( + tensorflow::gtl::ArraySlice instructions); // Set the reachability set of 'instruction' to the union of the reachability // sets of 'inputs'. Upon return, IsReachable(x, instruction) where diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc index 657a9ee83d29e72b95660325f9139f44159d6508..585c95972b0e01abc14543205af71b4b0c0bdf3c 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc @@ -39,15 +39,15 @@ TEST_F(HloReachabilityTest, Reachability) { */ auto builder = HloComputation::Builder(TestName()); auto a = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto b = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto c = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto d = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto e = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.Build(); HloReachabilityMap reachability({a, b, c, d, e}); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 9c7bc7a5ea7c77dadb8772f08b823c3579cf2154..cf0be30c7ad5cbeb7fd3d71c7c649b6b448360b8 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -1201,7 +1202,8 @@ StatusOr HloRematerialization::RematerializeComputation( StatusOr HloRematerialization::Run( HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit_bytes, RematerializationSizes* sizes) { + int64 memory_limit_bytes, RematerializationSizes* sizes, + CopyInsertion* copy_insertion) { // The sequence is constructed entirely by this method. TF_RET_CHECK(sequence->empty()); @@ -1236,6 +1238,16 @@ StatusOr HloRematerialization::Run( return size_function_(buffer.shape()); }, scheduler_algorithm_)); + if (copy_insertion) { + // We run a separate pass of copy elision here because the sequential + // ordering from the HLO schedule allows for more copies to be eliminated. + // TODO(b/80249101): Instead of a separate copy elision pass, use the + // ordering from the HLO schedule directly for copy insertion. + SequentialHloOrdering ordering(module, *sequence); + TF_RETURN_IF_ERROR( + copy_insertion->RemoveUnnecessaryCopies(ordering, module)); + } + // Compute peak memory usage of all computations in the module called in a // sequential context. call_graph_ = CallGraph::Build(module); @@ -1338,9 +1350,10 @@ StatusOr HloRematerialization::Run( int64 memory_limit_bytes, HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes) { + RematerializationSizes* sizes, CopyInsertion* copy_insertion) { HloRematerialization remat(scheduler_algorithm, size_function); - return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes); + return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes, + copy_insertion); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 2ee2dd0571ae8c6604e4ca722351fd48a913bda5..2ec004350ad88ff31ece90ec419d90a55b965166 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -17,6 +17,7 @@ #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -57,6 +58,13 @@ class HloRematerialization { // sizes: Optional outparam that indicates the peak memory usage of the HLO // module before/after rematerialization. // + // copy_insertion: If non-null, run copy elision after scheduling. This + // pass is used to eliminate copies that were inserted by copy insertion + // before HLO scheduling. + // + // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy + // insertion is integrated with HLO scheduling. + // // Returns whether any instructions were rematerialized. If memory use is // already below the given limit then no instructions are rematerialized and // false is returned. @@ -68,7 +76,7 @@ class HloRematerialization { const ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes = nullptr); + RematerializationSizes* sizes, CopyInsertion* copy_insertion = nullptr); protected: HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, @@ -83,7 +91,8 @@ class HloRematerialization { // contains the memory-minimizing order in which to emit the HLO instructions. StatusOr Run(HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit, RematerializationSizes* sizes); + int64 memory_limit, RematerializationSizes* sizes, + CopyInsertion* copy_insertion); // Rematerializes instructions within the given computation. 'order' is the // order in which the computation's instructions will be emitted in the diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index e81334d5a84268a129cd4e90091e97dc23243226..ac8c97d380953764b66135ad1c5fcee0d481c004 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -132,7 +132,7 @@ class HloRematerializationTest : public HloTestBase { builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); return builder.Build(); } @@ -147,7 +147,7 @@ class HloRematerializationTest : public HloTestBase { TF_EXPECT_OK(verifier().Run(module).status()); return HloRematerialization::RematerializeAndSchedule( ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, - sequence); + sequence, /*sizes=*/nullptr); } // Various shapes used in the canned computations. @@ -226,7 +226,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); HloComputation* while_cond = module->AddEmbeddedComputation(cond_builder.Build()); @@ -263,7 +263,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); HloComputation* while_cond = module->AddEmbeddedComputation(cond_builder.Build()); @@ -296,7 +296,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); HloComputation* while_cond = module->AddEmbeddedComputation(cond_builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index e1f9d8efd4974055947438c8a2e15cb77d1b5c75..b2725e2918ce76248d9f2cdbb2a6e5a63226bf9a 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -98,8 +98,10 @@ StatusOr HloRunner::TransferLiteralToDevice( backend().transfer_manager()->AllocateScopedShapedBuffer( literal.shape(), backend().memory_allocator(), backend().default_device_ordinal())); + TF_ASSIGN_OR_RETURN( + auto stream, backend().BorrowStream(backend().default_stream_executor())); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - backend().default_stream_executor(), literal, buffer)); + stream.get(), literal, buffer)); return std::move(buffer); } @@ -127,8 +129,10 @@ StatusOr> HloRunner::TransferLiteralsToDevice( StatusOr> HloRunner::TransferLiteralFromDevice( const ShapedBuffer& buffer) { - return backend().transfer_manager()->TransferLiteralFromDevice( - backend().default_stream_executor(), buffer); + TF_ASSIGN_OR_RETURN( + auto stream, backend().BorrowStream(backend().default_stream_executor())); + return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(), + buffer); } StatusOr> HloRunner::Execute( @@ -176,8 +180,12 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( TF_ASSIGN_OR_RETURN(std::unique_ptr executable, CreateExecutable(std::move(module), run_hlo_passes)); - return executable->ExecuteOnStreamWrapper(&service_run_options, - /*profile=*/profile, arguments); + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer retval, + executable->ExecuteOnStreamWrapper(&service_run_options, + /*profile=*/profile, arguments)); + TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); + return std::move(retval); } StatusOr HloRunner::ExecuteWithDeviceBuffers( @@ -237,7 +245,7 @@ StatusOr>> HloRunner::ExecuteReplicated( backend().transfer_manager()->AllocateScopedShapedBuffer( argument->shape(), backend().memory_allocator(), device)); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - executor, *argument, argument_buffer)); + streams.back().get(), *argument, argument_buffer)); argument_buffers.push_back(std::move(argument_buffer)); argument_buffer_ptrs[index++] = &argument_buffers.back(); } @@ -305,9 +313,10 @@ StatusOr>> HloRunner::ExecuteReplicated( std::vector> exec_results; for (int64 i = 0; i < options.num_replicas; ++i) { + TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone()); TF_ASSIGN_OR_RETURN(std::unique_ptr literal, backend().transfer_manager()->TransferLiteralFromDevice( - streams[i]->parent(), results[i])); + streams[i].get(), results[i])); exec_results.push_back(std::move(literal)); } return std::move(exec_results); diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index b14ade3549d093acdb5cdc7ae99dd025a42d5621..27cc5361cde2fa021b9489f98217ae5648afc2ad 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -375,7 +375,7 @@ int64 SumLogicalBufferSizes( return size; } -StatusOr> ScheduleComputationsInModule( +StatusOr> ScheduleComputationHelper( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -399,12 +399,9 @@ StatusOr> DFSMemoryScheduler( const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap& memory_by_computation) { - // This ordering is based on DFS post-order, with a heuristic to decide which - // operand to visit first. The heuristic is based on 'extra_users', which is - // simply users-1 for each instruction. By subtracting 1, we're saying that - // instructions with no users or a single user don't count; instructions with - // lots of fan-out will be visited earlier. + // These variables are a hack to prevent overflows. int64 cumulative_total_size = 0; + int64 total_hlos = computation.parent()->NumUniqueInstructionIds(); tensorflow::gtl::FlatMap extra_users; tensorflow::gtl::FlatMap total_sizes; for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { @@ -413,6 +410,11 @@ StatusOr> DFSMemoryScheduler( total_sizes[hlo] = 0; continue; } + // This ordering is based on DFS post-order, with a heuristic to decide + // which operand to visit first. The heuristic is based on 'extra_users', + // which is simply users-1 for each instruction. By subtracting 1, we're + // saying that instructions with no users or a single user don't count; + // instructions with lots of fan-out will be visited earlier. extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1; int64 logical_buffer_size = SumLogicalBufferSizes( points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); @@ -428,10 +430,13 @@ StatusOr> DFSMemoryScheduler( // lead to it. But computation is a DAG, so we are double-counting nodes, // which can lead to overflows for large programs. // cumulative_total_size caps the size to prevent overflows. + // Same for total_hlos: it prevents overflows on very large and branchy + // models, where the number of paths is exponential to the number of nodes. // NOTE(dimvar): this is quite ugly and should be changed. It's unclear // why we care about transitive sizes; when scheduling a node, its input // and output buffers should be all that matters, not its "history". total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size); + extra_users[hlo] = std::min(extra_users[hlo], total_hlos); } CHECK_EQ(extra_users.size(), computation.instruction_count()); CHECK_EQ(total_sizes.size(), computation.instruction_count()); @@ -498,29 +503,29 @@ StatusOr> DefaultMemoryScheduler( std::vector list_sequence, ListMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); - TF_ASSIGN_OR_RETURN( - const int64 list_memory, - MinimumMemoryForComputation(computation, list_sequence, - points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN(const int64 list_memory, + HeapSimulator::MinimumMemoryForComputation( + computation, list_sequence, points_to_analysis, + size_function, &memory_by_computation)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); TF_ASSIGN_OR_RETURN(std::vector dfs_sequence, DFSMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); - TF_ASSIGN_OR_RETURN( - const int64 dfs_memory, - MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, - size_function)); + TF_ASSIGN_OR_RETURN(const int64 dfs_memory, + HeapSimulator::MinimumMemoryForComputation( + computation, dfs_sequence, points_to_analysis, + size_function, &memory_by_computation)); VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); TF_ASSIGN_OR_RETURN( std::vector post_order_sequence, PostOrderMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); - TF_ASSIGN_OR_RETURN( - const int64 post_order_memory, - MinimumMemoryForComputation(computation, post_order_sequence, - points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN(const int64 post_order_memory, + HeapSimulator::MinimumMemoryForComputation( + computation, post_order_sequence, points_to_analysis, + size_function, &memory_by_computation)); VLOG(2) << "Min-memory post order sequence: " << HumanReadableNumBytes(post_order_memory); @@ -551,16 +556,18 @@ StatusOr ScheduleComputationsInModule( for (const auto* computation : module.MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { TF_ASSIGN_OR_RETURN(auto one_computation_sequence, - ScheduleComputationsInModule( + ScheduleComputationHelper( *computation, *points_to_analysis, size_function, algorithm, memory_by_computation)); memory_by_computation[computation] = - MinimumMemoryForComputation(*computation, one_computation_sequence, - *points_to_analysis, size_function) + HeapSimulator::MinimumMemoryForComputation( + *computation, one_computation_sequence, *points_to_analysis, + size_function, &memory_by_computation) .ValueOrDie(); sequence[computation] = std::move(one_computation_sequence); } } + VLOG(1) << "Module schedule:\n" << sequence; return sequence; } @@ -571,8 +578,8 @@ StatusOr> ScheduleOneComputation( TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); tensorflow::gtl::FlatMap empty_map; - return ScheduleComputationsInModule(computation, *points_to_analysis, - size_function, nullptr, empty_map); + return ScheduleComputationHelper(computation, *points_to_analysis, + size_function, nullptr, empty_map); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 6f1b1215d39dfbaeff768de70fa0a0859cd97381..9ec983c2bc353955cb23d441d200ac8aa36951b1 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -144,7 +145,7 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { // ROOT %subtract = f32[4]{0} subtract( // f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1) // } - // %SubcomputationsNotAccounted () -> f32[2,4] { + // %ListAccountsForSubcomputations () -> f32[2,4] { // %constant.3 = f32[2,4]{1,0} constant( // f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } }) // %transpose = f32[2,4]{1,0} transpose( @@ -167,8 +168,9 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { auto cond_builder = HloComputation::Builder("WhileCond"); HloInstruction* cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "cond_param")); - HloInstruction* zero_vector = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{0, 0, 0, 0}}))); + HloInstruction* zero_vector = + cond_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{0, 0, 0, 0}}))); cond_builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); @@ -178,16 +180,18 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { auto body_builder = HloComputation::Builder("WhileBody"); HloInstruction* body_param = body_builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "body_param")); - HloInstruction* one_vector = body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + HloInstruction* one_vector = + body_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 1, 1, 1}}))); body_builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kSubtract, body_param, one_vector)); auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); // transpose(matrix) + bcast(while) auto builder = HloComputation::Builder(TestName()); - HloInstruction* while_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + HloInstruction* while_init = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 1, 1, 1}}))); // Creates 16 bytes, ignoring subcomputations HloInstruction* while_loop = builder.AddInstruction(HloInstruction::CreateWhile( @@ -198,7 +202,7 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { HloInstruction::CreateBroadcast(r2f32, while_loop, {0})); HloInstruction* matrix = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2( + HloInstruction::CreateConstant(LiteralUtil::CreateR2( {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}}))); // Creates 32 bytes HloInstruction* transpose = builder.AddInstruction( @@ -210,16 +214,16 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule( - *module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - }, - ListMemoryScheduler)); + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); + auto entry_computation = module->entry_computation(); + EXPECT_EQ(entry_computation->instruction_count(), + sequence.at(entry_computation).size()); SequentialHloOrdering ordering(module.get(), sequence); // This schedule is an example of List's greedy heuristics being suboptimal. // The while_loop is more expensive than transpose, so it would have been @@ -228,6 +232,24 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast)); EXPECT_TRUE(ordering.ExecutesBefore(bcast, add)); EXPECT_TRUE(ordering.ExecutesBefore(transpose, add)); + + tensorflow::gtl::FlatMap memory_by_computation; + memory_by_computation[cond_computation] = 17; + memory_by_computation[body_computation] = 16; + std::unique_ptr points_to_analysis = + TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); + + // HeapSimulator doesn't account for subcomputations + EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, sequence.at(entry_computation), + *points_to_analysis, size_fn) + .ValueOrDie()); + // HeapSimulator accounts for subcomputations. The max mem doesn't change + // because the while body isn't live during the peak. + EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, sequence.at(entry_computation), + *points_to_analysis, size_fn, &memory_by_computation) + .ValueOrDie()); } TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { @@ -238,7 +260,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { // Wrap lit in abs because constants are considered free by // IgnoreInstruction, and it skews the accounting. auto lit = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1, 1, 1, 1, 1, 1}))); + LiteralUtil::CreateR1({1, 1, 1, 1, 1, 1}))); auto abs_const = builder.AddInstruction( HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit)); @@ -260,7 +282,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, ScheduleComputationsInModule(*module, - [&TUPLE_SIZE](const BufferValue& buffer) { + [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf( buffer.shape(), TUPLE_SIZE); }, @@ -281,11 +303,11 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { HloComputation::Builder builder(TestName()); auto c1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1, 1, 1, 1, 1}))); + LiteralUtil::CreateR1({1, 1, 1, 1, 1}))); auto c2 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1, 2, 3, 4, 5}))); + LiteralUtil::CreateR1({1, 2, 3, 4, 5}))); auto c3 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({0, 2, 4, 6, 8}))); + LiteralUtil::CreateR1({0, 2, 4, 6, 8}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2)); @@ -325,5 +347,73 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion)); } +TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { + auto module = CreateNewModule(); + const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); + const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); + + // param != 0 + // Needs 17 bytes + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "cond_param")); + HloInstruction* zero_vector = + cond_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{0, 0, 0, 0}}))); + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); + auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); + + // param - 1 + // Needs 16 bytes + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "body_param")); + HloInstruction* one_vector = + body_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 1, 1, 1}}))); + body_builder.AddInstruction(HloInstruction::CreateBinary( + r1f32, HloOpcode::kSubtract, body_param, one_vector)); + auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* while_init = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 1, 1, 1}}))); + // Creates 16 bytes, ignoring subcomputations + builder.AddInstruction(HloInstruction::CreateWhile( + r1f32, cond_computation, body_computation, while_init)); + + module->AddEntryComputation(builder.Build()); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + // Verify that all instructions are in the sequence. + auto entry_computation = module->entry_computation(); + EXPECT_EQ(entry_computation->instruction_count(), + sequence.at(entry_computation).size()); + + tensorflow::gtl::FlatMap memory_by_computation; + memory_by_computation[cond_computation] = 17; + memory_by_computation[body_computation] = 16; + std::unique_ptr points_to_analysis = + TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); + + // HeapSimulator doesn't account for subcomputations + EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, sequence.at(entry_computation), + *points_to_analysis, size_fn) + .ValueOrDie()); + // HeapSimulator accounts for subcomputations + EXPECT_EQ(33, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, sequence.at(entry_computation), + *points_to_analysis, size_fn, &memory_by_computation) + .ValueOrDie()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 9fb15df7c26951fb7f0d62b0d6533d6312e7a4d5..879fb3bbab2ada0f924282f16b3d9ccb4c2cb203 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -31,12 +31,9 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { CHECK_EQ(1, ShapeUtil::Rank(input_shape)); CHECK_GT(num_tiles, 1); std::vector dimensions(1, num_tiles); - Shape tile_shape = input_shape; - auto& tile_dimension = (*tile_shape.mutable_dimensions())[0]; - tile_dimension = CeilOfRatio(static_cast(tile_dimension), num_tiles); Array assignment(dimensions); std::iota(assignment.begin(), assignment.end(), 0); - return HloSharding(tile_shape, assignment); + return HloSharding(assignment); } HloSharding HloSharding::Tuple(const ShapeTree& sub_shardings) { @@ -60,6 +57,9 @@ HloSharding HloSharding::Tuple( const Shape& tuple_shape, tensorflow::gtl::ArraySlice shardings) { CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + for (auto& sharding : shardings) { + CHECK(!sharding.IsTuple()) << sharding.ToString(); + } std::vector flattened_list(shardings.begin(), shardings.end()); CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape)) << "Flat list has " << flattened_list.size() << ", required " @@ -67,6 +67,24 @@ HloSharding HloSharding::Tuple( return HloSharding(flattened_list); } +HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, + const HloSharding& sharding) { + CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + CHECK(!sharding.IsTuple()) << sharding.ToString(); + int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape); + std::vector flattened_list; + flattened_list.reserve(leaf_count); + for (int64 i = 0; i < leaf_count; ++i) { + flattened_list.push_back(sharding); + } + return HloSharding(flattened_list); +} + +HloSharding HloSharding::Single(const Shape& shape, + const HloSharding& sharding) { + return ShapeUtil::IsTuple(shape) ? SingleTuple(shape, sharding) : sharding; +} + string HloSharding::ToString() const { if (IsTuple()) { std::vector parts; @@ -83,8 +101,7 @@ string HloSharding::ToString() const { return StrCat( "{maximal device=", static_cast(*tile_assignment_.begin()), "}"); } else { - return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", "devices=[", - Join(tile_assignment_.dimensions(), ","), "]", + return StrCat("{devices=[", Join(tile_assignment_.dimensions(), ","), "]", Join(tile_assignment_, ","), "}"); } } @@ -100,8 +117,30 @@ bool HloSharding::UsesDevice(int64 device) const { std::find(devices.begin(), devices.end(), device) != devices.end(); } +std::map HloSharding::UsedDevices(int64* count) const { + int64 element_count = 1; + std::map device_map; + if (IsTuple()) { + for (auto& tuple_element_sharding : tuple_elements()) { + auto unique_device = tuple_element_sharding.UniqueDevice(); + if (unique_device) { + device_map[*unique_device] += 1; + } + } + element_count = tuple_elements().size(); + } else { + auto unique_device = UniqueDevice(); + if (unique_device) { + device_map[*unique_device] += 1; + } + } + if (count != nullptr) { + *count = element_count; + } + return device_map; +} + std::vector HloSharding::TileIndexForDevice(int64 device) const { - CHECK(!ShapeUtil::IsTuple(tile_shape_)); CHECK(!maximal_); CHECK(!IsTuple()); std::vector ret_index; @@ -121,32 +160,43 @@ int64 HloSharding::DeviceForTileIndex( if (maximal_) { return *tile_assignment_.begin(); } - CHECK_EQ(ShapeUtil::Rank(tile_shape_), tile_assignment_.dimensions().size()); return tile_assignment_(index); } -std::vector HloSharding::TileOffsetForDevice(int64 device) const { +std::vector HloSharding::TileOffsetForDevice(const Shape& shape, + int64 device) const { CHECK(!IsTuple()); - std::vector index = TileIndexForDevice(device); if (maximal_) { - // Index will always be all zeroes if we're maximal, and tile_shape_ is not - // valid. - return index; + return std::vector(shape.dimensions_size(), 0); } + + CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions()); + std::vector index = TileIndexForDevice(device); for (int64 i = 0; i < index.size(); ++i) { - index[i] *= tile_shape_.dimensions(i); + const int64 shape_dim = shape.dimensions(i); + index[i] = std::min( + index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim); } return index; } -std::vector HloSharding::TileLimitForDevice(int64 device) const { +std::vector HloSharding::TileLimitForDevice(const Shape& shape, + int64 device) const { CHECK(!IsTuple()); - CHECK(!maximal_); // Maximal shardings do not have a valid tile shape. + if (maximal_) { + return std::vector(shape.dimensions().begin(), + shape.dimensions().end()); + } + + CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions()); std::vector index = TileIndexForDevice(device); for (int64 i = 0; i < index.size(); ++i) { - index[i] = (index[i] + 1) * tile_shape_.dimensions(i); + const int64 shape_dim = shape.dimensions(i); + index[i] = std::min( + (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), + shape_dim); } return index; } @@ -194,40 +244,31 @@ StatusOr HloSharding::GetTupleSharding(const Shape& shape) const { return Tuple(ShapeTree(shape, *this)); } -StatusOr HloSharding::UniqueDevice() const { +tensorflow::gtl::optional HloSharding::UniqueDevice() const { if (IsTuple()) { if (tuple_elements_.empty()) { - return tensorflow::errors::InvalidArgument( - "UniqueDevice() called on empty tuple"); + return tensorflow::gtl::nullopt; } - std::vector> results; - std::transform(tuple_elements_.begin(), tuple_elements_.end(), - std::back_inserter(results), - [](const HloSharding& s) { return s.UniqueDevice(); }); - if (std::all_of(results.begin(), results.end(), - [&](const StatusOr& s) { - return s.ok() && results[0].ok() && - s.ValueOrDie() == results[0].ValueOrDie(); - })) { - return results[0]; - } else { - return tensorflow::errors::InvalidArgument( - "Tuple did not contain a unique device"); + tensorflow::gtl::optional unique_device; + for (auto& tuple_sharding : tuple_elements_) { + auto device = tuple_sharding.UniqueDevice(); + if (!device || (unique_device && *device != *unique_device)) { + return tensorflow::gtl::nullopt; + } + unique_device = device; } + return unique_device; } - if (!replicated_ && maximal_ && !IsTuple()) { + if (!replicated_ && maximal_) { return static_cast(*tile_assignment_.begin()); } - return tensorflow::errors::InvalidArgument( - "UniqueDevice() called on sharding that executes on multiple devices"); + return tensorflow::gtl::nullopt; } -bool HloSharding::HasUniqueDevice() const { - if (IsTuple()) { - return UniqueDevice().status().ok(); - } else { - return !IsReplicated() && IsTileMaximal(); - } +int64 HloSharding::GetUniqueDevice() const { + auto device = UniqueDevice(); + CHECK(device) << "Sharding does not have a unique device: " << *this; + return *device; } Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const { @@ -301,11 +342,12 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, return Status::OK(); } - // The tile rank must be the same as the input rank. - if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) { + // The tile assignment tensor must have the same rank as the input. + if (ShapeUtil::Rank(shape) != tile_assignment_.num_dimensions()) { return tensorflow::errors::InvalidArgument( - "Tile rank is different to the input rank. sharding=", ToString(), - ", input_shape=", ShapeUtil::HumanString(shape)); + "Number of tile assignment dimensions is different to the input rank. " + "sharding=", + ToString(), ", input_shape=", ShapeUtil::HumanString(shape)); } // The correct constructor have to be used to create tile maximal shardings. @@ -315,20 +357,6 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, "sharding was intended, use HloSharding::Replicated(). If a device " "placement was intended, use HloSharding::AssignDevice()"); } - - // The tile assignment tensor must contain enough element to cover the full - // shape with tiles of the specified size. - for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) { - int64 total_tile_size = tile_assignment_.dim(i) * tile_shape_.dimensions(i); - if (shape.dimensions(i) > total_tile_size) { - return tensorflow::errors::InvalidArgument( - StrCat("Tile assignment tensor has too few element to cover the full " - "shape. Dimension ", - i, ", shape ", shape.dimensions(i), ", total size ", - total_tile_size)); - } - } - return Status::OK(); } @@ -358,7 +386,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, proto.tile_assignment_dimensions().end())); std::copy(proto.tile_assignment_devices().begin(), proto.tile_assignment_devices().end(), tile_assignment.begin()); - return HloSharding(proto.tile_shape(), tile_assignment); + return HloSharding(tile_assignment); } OpSharding HloSharding::ToProto() const { @@ -372,7 +400,6 @@ OpSharding HloSharding::ToProto() const { return result; } - *result.mutable_tile_shape() = tile_shape_; for (int64 dim : tile_assignment_.dimensions()) { result.add_tile_assignment_dimensions(dim); } @@ -389,30 +416,16 @@ OpSharding HloSharding::ToProto() const { return result; } -HloSharding HloSharding::TransformShardedTileShape( - const Shape& new_shape, - const std::function& transform) const { - CHECK(!IsTuple()); +Shape HloSharding::TileShape(const Shape& shape) const { if (IsTileMaximal()) { - return *this; + return shape; } - CHECK_EQ(ShapeUtil::Rank(new_shape), ShapeUtil::Rank(tile_shape())); - Shape new_tile_shape; - new_tile_shape.set_element_type(tile_shape().element_type()); - for (int64 i = 0; i < ShapeUtil::Rank(new_shape); ++i) { - int64 dim; - if (tile_assignment().dim(i) == 1) { - dim = new_shape.dimensions(i); - } else if (transform) { - dim = transform(i, tile_shape().dimensions(i)); - } else { - dim = tile_shape().dimensions(i); - } - new_tile_shape.add_dimensions(dim); + Shape result_shape = shape; + for (int64 i = 0; i < shape.dimensions_size(); ++i) { + (*result_shape.mutable_dimensions())[i] = + CeilOfRatio(shape.dimensions(i), tile_assignment_.dim(i)); } - TF_CHECK_OK( - LayoutUtil::CopyLayoutBetweenShapes(tile_shape_, &new_tile_shape)); - return HloSharding::Tile(new_tile_shape, tile_assignment()); + return result_shape; } HloSharding HloSharding::GetSubSharding(const Shape& shape, @@ -439,6 +452,24 @@ tensorflow::gtl::optional HloSharding::ExtractSingleSharding() return tuple_elements_.front(); } +size_t HloSharding::Hash() const { + if (!tuple_) { + size_t h = 0; + for (const auto& element : tuple_elements_) { + h = tensorflow::Hash64Combine(h, element.Hash()); + } + return h; + } + if (replicated_) { + return 0; + } + size_t h = 0; + for (uint32 v : tile_assignment_) { + h = tensorflow::Hash64Combine(h, std::hash{}(v)); + } + return h; +} + std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { out << sharding.ToString(); return out; diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 6a744e0247273e25c5de3143b7bbba2b79ee816a..894783e5d1538fa4e8e91b65827121f32040af83 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -19,10 +19,12 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ +#include #include +#include #include "tensorflow/compiler/xla/array.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -46,22 +48,10 @@ class HloSharding { // the input shape (one tile) assigned to a single device. static HloSharding AssignDevice(int64 device_id); - // Creates a new sharding which splits a shape into tiles each with shape - // `tile_shape`. Each tile is assigned to one device, which is specified by - // `tile_assignment`. Any tensor not a multiple of the tile size in any - // dimension is implicitly padded to the tile size. - // - // e.g. Tile({2, 2}, {0, 1}) on a tensor of shape {3, 2} would look like: - // 2 1 padding - // <------><-> - // +----+----+ - // | 0 | 1 | - // +----+----+ - // - // Split into two tiles, one of which is implicitly padded by one. - static HloSharding Tile(const Shape& tile_shape, - const Array& tile_assignment) { - return HloSharding(tile_shape, tile_assignment); + // Creates a new sharding which splits a shape into tiles amongst the devices + // specified by `tile_assignment`. + static HloSharding Tile(const Array& tile_assignment) { + return HloSharding(tile_assignment); } // Creates a new sharding which splits a one-dimensional input shape into @@ -78,6 +68,15 @@ class HloSharding { static HloSharding Tuple(const Shape& tuple_shape, tensorflow::gtl::ArraySlice shardings); + // Creates a new sharding for a tuple type, with a single input sharding + // repeated on each leaf. + static HloSharding SingleTuple(const Shape& tuple_shape, + const HloSharding& sharding); + + // If shape is an array, returns sharding, otherwise returns the tuple shaped + // sharding with all the leaf nodes having the same input sharding. + static HloSharding Single(const Shape& shape, const HloSharding& sharding); + // Create a new sharding from a protobuf OpSharding. static StatusOr FromProto(const OpSharding& proto); @@ -118,6 +117,14 @@ class HloSharding { // Returns true if the sharding defines an operation on the given device. bool UsesDevice(int64 device) const; + // Retrieves an histogram of the devices used by the sharding. The returned + // map has the device number as key, and the occurrence count as value. + // If a sharding does not have a device, it will not be incuded in the + // histogram. The count argument, if not nullptr, will receive the total + // number of elements this sharding is made of (one for array, N leaves for + // tuples). + std::map UsedDevices(int64* count) const; + // Returns the tile that should be executed on the given device. // REQUIRES: !IsTuple() std::vector TileIndexForDevice(int64 device) const; @@ -127,24 +134,30 @@ class HloSharding { // REQUIRES: !IsTuple() int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice index) const; - // Given a device ID, returns the offset within the input space of the + // Given a device ID, returns the offset within the specified shape of the // tile that should be executed on the given core. This returns the lower // extent of the tile in the input space. // REQUIRES: !IsTuple() - std::vector TileOffsetForDevice(int64 device) const; + std::vector TileOffsetForDevice(const Shape& shape, + int64 device) const; - // Given a device ID, returns the limit within the input space of the + // Given a device ID, returns the limit within the specified shape of the // tile that should be executed on the given core. This returns the upper // extent of the tile in the input space. // REQUIRES: !IsTuple() - std::vector TileLimitForDevice(int64 device) const; + std::vector TileLimitForDevice(const Shape& shape, int64 device) const; - // Returns the single device this op operates on. - // REQUIRES: !IsTuple&& !Replicated() && IsTileMaximal() - StatusOr UniqueDevice() const; + // Returns the single device this op operates on. If the sharding does not + // span a single device, the return value will be empty. + // In order for a sharding to span a single device, every leaf sharding must + // be maximal and not replicated, and the used device must match. + tensorflow::gtl::optional UniqueDevice() const; + + // Retrieves the unique device or fails with a CHECK. + int64 GetUniqueDevice() const; // Returns true if this op only uses a single device. - bool HasUniqueDevice() const; + bool HasUniqueDevice() const { return UniqueDevice().has_value(); } // Returns the ShapeTree containing the shardings for each element of this // tuple, if IsTuple, or a ShapeTree with a single element containing this @@ -173,32 +186,12 @@ class HloSharding { bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && - ShapeUtil::Compatible(tile_shape_, other.tile_shape_) && tile_assignment_ == other.tile_assignment_ && tuple_elements_ == other.tuple_elements_; } bool operator!=(const HloSharding& other) const { return !(*this == other); } - size_t Hash() const { - if (!tuple_) { - size_t h = 0; - for (const auto& element : tuple_elements_) { - h = tensorflow::Hash64Combine(h, element.Hash()); - } - return h; - } - if (replicated_) { - return 0; - } - size_t h = 0; - for (uint32 v : tile_assignment_) { - h = tensorflow::Hash64Combine(h, std::hash{}(v)); - } - for (uint32 v : tile_shape_.dimensions()) { - h = tensorflow::Hash64Combine(h, std::hash{}(v)); - } - return h; - } + size_t Hash() const; struct Hasher { size_t operator()(const HloSharding& sharding) const { @@ -206,9 +199,6 @@ class HloSharding { } }; - // Gets the tile shape. - // REQUIRES: !IsTileMaximal() && !IsTuple() - const Shape& tile_shape() const { return tile_shape_; } // Gets the tile assignment tensor. // REQUIRES: !IsReplicated() && !IsTuple() const Array& tile_assignment() const { return tile_assignment_; } @@ -220,39 +210,33 @@ class HloSharding { return tuple_elements_; } - // Return a new sharding that can apply to the given new shape. - // If this sharding is tile-maximal, the returned sharding will be the same as - // this sharding. If this sharding is not tile-maximal, the returned - // sharding's tile size will differ: - // - Non-sharded dimensions will be adapted to be the same as `new_shape`; - // tile_dimension(i) = new_shape.dimensions(i); - // - Sharded dimensions will be kept the same unless `transform` is supplied - // in which case tile_dimension(i) = transform(i, tile_dimension(i)); - // REQUIRES: !IsTuple(). - HloSharding TransformShardedTileShape( - const Shape& new_shape, - const std::function& transform = nullptr) const; + // Gets the tile shape. + // REQUIRES: !IsTuple() + Shape TileShape(const Shape& shape) const; private: HloSharding() : replicated_(true), maximal_(true), tuple_(false), - tile_shape_(), tile_assignment_({0}) {} + // device_id values: + // -2: magic number to mean unassigned device, used by spatial partitioning + // -1: the id of the host + // 0 or positive: the id of a device + // NOTE(dimvar): -1 is needed for outside compilation. It can be removed once + // we have fully switched to the side-effect tokens. explicit HloSharding(int64 device_id) : replicated_(false), maximal_(true), tuple_(false), - tile_shape_(), tile_assignment_({1}, device_id) {} - HloSharding(const Shape& tile_shape, const Array& tile_assignment) + explicit HloSharding(const Array& tile_assignment) : replicated_(false), maximal_(false), tuple_(false), - tile_shape_(tile_shape), tile_assignment_(tile_assignment) {} - HloSharding(const std::vector& tuple_shardings) + explicit HloSharding(const std::vector& tuple_shardings) : replicated_(false), maximal_(false), tuple_(true), @@ -275,7 +259,6 @@ class HloSharding { bool replicated_; bool maximal_; bool tuple_; - Shape tile_shape_; Array tile_assignment_; // Only non-empty when tuple_ is true, but because empty tuples are allowed // may also be empty even then. This is a flattened list of all the leaf diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 7b4b071af46df19520f9ba1f1f632692d489de59..a2c1d39d0d4893333b3c2ed0e3418b01dac8cefd 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -88,6 +88,12 @@ std::vector LocatePassThroughDomainLinks( VLOG(2) << " " << instruction->ToString(); } } + if (instruction == instruction->parent()->root_instruction()) { + pass_through.emplace_back(nullptr, instruction); + VLOG(2) << "Found passthrough domain link:"; + VLOG(2) << " "; + VLOG(2) << " " << instruction->ToString(); + } } return pass_through; } @@ -101,8 +107,12 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, HloInstruction::CreateGetTupleElement(pass_through.operand->shape(), tuple, 0)); gte->set_sharding(sharding); - TF_RETURN_IF_ERROR( - pass_through.operand->ReplaceUseWith(pass_through.user, gte)); + if (pass_through.user != nullptr) { + TF_RETURN_IF_ERROR( + pass_through.operand->ReplaceUseWith(pass_through.user, gte)); + } else { + pass_through.operand->parent()->set_root_instruction(gte); + } } return Status::OK(); } @@ -148,7 +158,6 @@ ShapeTree GetTupleSharding(HloInstruction* tuple) { const HloSharding* GetOperandSharding(const HloInstruction* operand, const DomainMetadata::Domain& domain, const HloSharding& sharding) { - DCHECK_EQ(domain.reach_set.count(const_cast(operand)), 1); // Here the user of operand is within the domain instruction set, and since it // is user of operand, we need to look into the enter_domains set. If this is // not a kDomain within the user domains set, then return the operand @@ -193,10 +202,17 @@ StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, for (int64 i = 0; i < instruction->operand_count(); ++i) { const HloSharding* operand_sharding = GetOperandSharding(instruction->operand(i), domain, sharding); - if (operand_sharding != nullptr && - shape_tree.element({i}) != *operand_sharding) { - *shape_tree.mutable_element({i}) = *operand_sharding; - ++tuple_assigned; + if (operand_sharding != nullptr) { + HloSharding operand_subsharding = HloSharding::Replicate(); + if (operand_sharding == &sharding) { + operand_subsharding = + sharding.GetSubSharding(instruction->shape(), {i}); + operand_sharding = &operand_subsharding; + } + if (shape_tree.element({i}) != *operand_sharding) { + *shape_tree.mutable_element({i}) = *operand_sharding; + ++tuple_assigned; + } } } if (tuple_assigned > 0) { @@ -235,6 +251,8 @@ StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, Status ApplyDomainSharding(const DomainMetadata::Domain& domain, const HloSharding& sharding) { + // None of the external normalizers handled the domain sharding, try to see + // whether this is a single sharding first. auto single_sharding = sharding.ExtractSingleSharding(); if (single_sharding) { // Shortcut the simple case. We have a unique sharding, so we call @@ -360,28 +378,39 @@ bool ShardingMetadata::Matches(const DomainMetadata& other) const { } string ShardingMetadata::ToString() const { - return sharding_ != nullptr ? sharding_->ToString() : "None"; + return sharding_ != nullptr ? sharding_->ToString() : "{}"; } -Status ShardingMetadata::NormalizeInstructions( - const DomainMetadata::Domain& domain) const { - if (sharding_ != nullptr) { - VLOG(4) << "Normalizing sharding to " << sharding_->ToString() << ":"; - TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding_)); - TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding_)); +/*static*/ StatusOr +ShardingMetadata::ToShardingMetadata(const DomainMetadata* metadata) { + if (metadata->Kind() != ShardingMetadata::KindName()) { + return Status( + tensorflow::error::INVALID_ARGUMENT, + "ShardingMetadata normalizer called with incorrect domain metadata"); } - return Status::OK(); + return static_cast(metadata); } -Status NormalizeShardingDomain(const DomainMetadata::Domain& domain) { - TF_ASSIGN_OR_RETURN(std::unique_ptr sharding, - ExtractOriginalCommonSharding(domain.instructions)); - if (sharding != nullptr) { - VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString() - << ":"; - TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding)); +Status ShardingMetadata::NormalizeShardingDomain( + const DomainMetadata::Domain& domain, const DomainMetadata* metadata) { + if (metadata != nullptr) { + TF_ASSIGN_OR_RETURN(const auto& sharding_metadata, + ToShardingMetadata(metadata)); + const HloSharding* sharding = sharding_metadata->sharding(); + if (sharding != nullptr) { + VLOG(4) << "Normalizing sharding to " << sharding->ToString() << ":"; + TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding)); + TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding)); + } } else { - VLOG(1) << "Unable to find common sharding"; + TF_ASSIGN_OR_RETURN(std::unique_ptr sharding, + ExtractOriginalCommonSharding(domain.instructions)); + if (sharding != nullptr) { + VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString(); + TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding)); + } else { + VLOG(1) << "Unable to find common sharding"; + } } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h index ec162c34904ee2dfac3daeeee37133282a9c9698..5e01fc0e22ae8f3421c2cb5790adf44b1200a804 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -38,23 +38,26 @@ class ShardingMetadata : public DomainMetadata { string ToString() const override; - Status NormalizeInstructions( - const DomainMetadata::Domain& domain) const override; + const HloSharding* sharding() const { return sharding_.get(); } static tensorflow::StringPiece KindName() { return "sharding"; } + static StatusOr ToShardingMetadata( + const DomainMetadata* metadata); + + // Apply the specified domain metadata onto the specified domain. If no + // metadata is specified then apply sharding heuristics and normalize the + // instructions whose sharding deviates from the one which is inferred as to + // be the original one. Policy wise, HLO passes are allowed to create new + // unassigned instructions, but if they do create assigned ones, they have to + // conform to the ones around. + static Status NormalizeShardingDomain(const DomainMetadata::Domain& domain, + const DomainMetadata* metadata); + private: std::unique_ptr sharding_; }; -// Within a set of instructions which had common sharding attributes before -// entring the HLO passes pipeline, apply sharding heuristics and normalize the -// instructions whose sharding deviates from the one which is inferred as to be -// the original one. -// Policy wise, HLO passes are allowed to create new unassigned instructions, -// but if they do create assigned ones, they have to conform to the ones around. -Status NormalizeShardingDomain(const DomainMetadata::Domain& domain); - // Given an HLO graph edge between instruction and one of its operands, creates // a ShardingMetadata based kDomain instruction if the sharding between // instruction and operand changes. Returns nullptr if there is no need for a diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 54b7402b866361748d9eb35182b0bf486c4c9bdc..45fc300fcaf5a301fe11768da77a7c0907919c39 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -39,7 +39,6 @@ Array MakeArray(tensorflow::gtl::ArraySlice dimensions, class HloShardingTest : public HloTestBase {}; TEST_F(HloShardingTest, Replicate) { - Shape tile_shape = ShapeUtil::MakeShape(U32, {4}); HloSharding sharding = HloSharding::Replicate(); EXPECT_TRUE(sharding.IsReplicated()); EXPECT_TRUE(sharding.IsTileMaximal()); @@ -51,7 +50,7 @@ TEST_F(HloShardingTest, Replicate) { EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4}), /*num_devices=*/2)); - EXPECT_IS_NOT_OK(sharding.UniqueDevice()); + EXPECT_FALSE(sharding.HasUniqueDevice()); } TEST_F(HloShardingTest, DevicePlacement) { @@ -60,7 +59,7 @@ TEST_F(HloShardingTest, DevicePlacement) { EXPECT_TRUE(sharding.IsTileMaximal()); EXPECT_FALSE(sharding.UsesDevice(0)); EXPECT_TRUE(sharding.UsesDevice(5)); - EXPECT_EQ(5, sharding.UniqueDevice().ValueOrDie()); + EXPECT_EQ(5, sharding.GetUniqueDevice()); HloSharding other = HloSharding::Replicate(); EXPECT_NE(other, sharding); @@ -79,37 +78,22 @@ TEST_F(HloShardingTest, DevicePlacement) { TEST_F(HloShardingTest, Tile) { { // Test should fail because of a duplicate tile assignment. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 0, 2, 3})); + HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 0, 2, 3})); EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {4, 6}), /*num_devices=*/4)); } { // Test should fail because of more devices used then `num_device`. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); + HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3})); EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4, 6}), /*num_devices=*/2)); } - { - // Test should fail because the total tiled size in dimension 0 is 4 but we - // have 6 elements along that dimensions. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); - EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {6, 3}), - /*num_devices=*/4)); - } - { // Test should pass. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); + Shape shape = ShapeUtil::MakeShape(U32, {4, 5}); + HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {3, 5}), /*num_devices=*/5)); @@ -118,12 +102,16 @@ TEST_F(HloShardingTest, Tile) { EXPECT_EQ(2, sharding.DeviceForTileIndex({1, 0})); EXPECT_EQ(1, sharding.DeviceForTileIndex({1, 1})); - EXPECT_EQ(sharding.TileOffsetForDevice(0), (std::vector{0, 0})); - EXPECT_EQ(sharding.TileOffsetForDevice(3), (std::vector{0, 3})); - EXPECT_EQ(sharding.TileOffsetForDevice(2), (std::vector{2, 0})); - EXPECT_EQ(sharding.TileOffsetForDevice(1), (std::vector{2, 3})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 0), + (std::vector{0, 0})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 3), + (std::vector{0, 3})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 2), + (std::vector{2, 0})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 1), + (std::vector{2, 3})); - EXPECT_IS_NOT_OK(sharding.UniqueDevice()); + EXPECT_FALSE(sharding.HasUniqueDevice()); } } @@ -135,8 +123,7 @@ TEST_F(HloShardingTest, NestedTuple) { ShapeUtil::MakeShape(F32, {4, 6}), }); - HloSharding tiled_sharding = HloSharding::Tile( - ShapeUtil::MakeShape(F32, {4, 3}), Array({{0, 1}})); + HloSharding tiled_sharding = HloSharding::Tile(Array({{0, 1}})); OpSharding proto; proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE); *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto(); @@ -187,32 +174,11 @@ TEST_F(HloShardingTest, Hash) { } { - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding1 = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); - HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}), - MakeArray({2, 2}, {0, 3, 2, 1})); - EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); - } - - { - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding1 = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); - HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}), - MakeArray({2, 2}, {0, 3, 2, 1})); + HloSharding sharding1 = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); + HloSharding sharding2 = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); } - { - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding1 = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); - HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}), - MakeArray({2, 2}, {0, 3, 1, 2})); - EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); - } - HloSharding default_sharding = HloSharding::Replicate(); { ShapeTree shape_tree(ShapeUtil::MakeTupleShape({}), @@ -259,19 +225,6 @@ TEST_F(HloShardingTest, Hash) { } } -TEST_F(HloShardingTest, TransformShardedTileShapeTest) { - HloSharding sharding = - HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), - Array4D({{{{0, 1}, {2, 3}}}})); - HloSharding result = sharding.TransformShardedTileShape( - ShapeUtil::MakeShape(F32, {13, 15, 17, 19}), - [](int dim, int value) { return dim * 111; }); - HloSharding expected = - HloSharding::Tile(ShapeUtil::MakeShape(F32, {13, 15, 222, 333}), - Array4D({{{{0, 1}, {2, 3}}}})); - EXPECT_EQ(result, expected); -} - TEST_F(HloShardingTest, ToStringReplicatedTest) { HloSharding sharding = HloSharding::Replicate(); EXPECT_EQ(sharding.ToString(), "{replicated}"); @@ -284,9 +237,8 @@ TEST_F(HloShardingTest, ToStringAssignDeviceTest) { TEST_F(HloShardingTest, ToStringTiledTest) { HloSharding sharding = - HloSharding::Tile(ShapeUtil::MakeShape(S32, {7, 11, 13}), - Array3D({{{2, 3}}, {{5, 7}}})); - EXPECT_EQ(sharding.ToString(), "{s32[7,11,13] devices=[2,1,2]2,3,5,7}"); + HloSharding::Tile(Array3D({{{2, 3}}, {{5, 7}}})); + EXPECT_EQ(sharding.ToString(), "{devices=[2,1,2]2,3,5,7}"); } TEST_F(HloShardingTest, ToStringTupleTest) { @@ -294,21 +246,18 @@ TEST_F(HloShardingTest, ToStringTupleTest) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}), ShapeUtil::MakeShape(U32, {7, 25}), ShapeUtil::MakeShape(S32, {9, 11})}), - {HloSharding::Replicate(), - HloSharding::Tile(ShapeUtil::MakeShape(U32, {7, 13}), - Array2D({{3, 5}})), + {HloSharding::Replicate(), HloSharding::Tile(Array2D({{3, 5}})), HloSharding::AssignDevice(3)}); EXPECT_EQ(sharding.ToString(), - "{{replicated}, {u32[7,13] devices=[1,2]3,5}, {maximal device=3}}"); + "{{replicated}, {devices=[1,2]3,5}, {maximal device=3}}"); } TEST_F(HloShardingTest, OstreamTest) { HloSharding sharding = - HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), - Array4D({{{{0, 1}, {2, 3}}}})); + HloSharding::Tile(Array4D({{{{0, 1}, {2, 3}}}})); std::ostringstream oss; oss << sharding; - EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}"); + EXPECT_EQ(oss.str(), "{devices=[1,1,2,2]0,1,2,3}"); } TEST_F(HloShardingTest, ParseHloString) { @@ -319,8 +268,7 @@ TEST_F(HloShardingTest, ParseHloString) { }; check(HloSharding::Replicate()); check(HloSharding::AssignDevice(2)); - check(HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), - Array4D({{{{0}, {1}}}}))); + check(HloSharding::Tile(Array4D({{{{0}, {1}}}}))); // Empty tuple. One sharding is required for empty tuples, as we need to be // able to assign sharding to them, even though they have no leaves. check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}), @@ -332,8 +280,7 @@ TEST_F(HloShardingTest, ParseHloString) { ShapeUtil::MakeShape(F32, {3, 5, 7}), ShapeUtil::MakeShape(F32, {3, 7})}); check(HloSharding::Tuple( - tuple_shape, {HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), - Array4D({{{{0}, {1}}}})), + tuple_shape, {HloSharding::Tile(Array4D({{{{0}, {1}}}})), HloSharding::Replicate(), HloSharding::AssignDevice(1)})); } { @@ -343,8 +290,7 @@ TEST_F(HloShardingTest, ParseHloString) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5, 7}), ShapeUtil::MakeShape(F32, {3, 7})})}); std::vector leaf_shardings = { - HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), - Array4D({{{{0}, {1}}}})), + HloSharding::Tile(Array4D({{{{0}, {1}}}})), HloSharding::Replicate(), HloSharding::AssignDevice(1)}; ShapeTree sharding_tree(tuple_shape, HloSharding::Replicate()); // Assign leaf_shardings to sharding_tree leaves. diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc index 7b601f9a9578cfa6b293cf7f002255f7db8b1257..45c684d66752862eec301b8943d350804f070309 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc @@ -75,7 +75,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { module->AddEmbeddedComputation(CreateR0S32IdentityComputation()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); auto x = builder.AddInstruction( HloInstruction::CreateCall(r0s32_, {constant}, callee1)); auto y = builder.AddInstruction( @@ -112,9 +112,9 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { module->AddEmbeddedComputation(CreateR0S32AdditionComputation()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3))); auto x = builder.AddInstruction( HloInstruction::CreateCall(r0s32_, {constant1, constant2}, callee1)); auto y = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 3dc733940fc89952bd5e75a9b28d9cbf356f8000..b78bfa0cdf4db605576fa11e18ce6c654c6a0b6d 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -101,11 +101,11 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction( } }; string node_name; - if (debug_options_.xla_hlo_tfgraph_device_scopes() && - instruction->has_sharding() && - instruction->sharding().HasUniqueDevice()) { - node_name = StrCat( - "dev", instruction->sharding().UniqueDevice().ConsumeValueOrDie()); + if (debug_options_.xla_hlo_tfgraph_device_scopes()) { + auto device = instruction->sharding_unique_device(); + if (device) { + node_name = StrCat("dev", *device); + } } // If an instruction is fused, put it in the subgraph of the fusion; // otherwise, put it in the computation subgraph. @@ -215,10 +215,10 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) { NodeDef* node_def = graph_def_.add_node(); node_def->set_name(GetNodeNameForInstruction(instruction)); node_def->set_op(GetOpDefName(instruction)); - if (instruction->has_sharding() && - instruction->sharding().HasUniqueDevice()) { - TF_ASSIGN_OR_RETURN(int64 device, instruction->sharding().UniqueDevice()); - node_def->set_device(GetDeviceName(device)); + + auto device = instruction->sharding_unique_device(); + if (device) { + node_def->set_device(GetDeviceName(*device)); } SetNodeAttrs(instruction, node_def); if (instruction->opcode() == HloOpcode::kFusion) { diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index be156d765dc10d54eaf301e90883babbc5693e28..1e2b31a1f2bb4865faafc3d14e2b194e3aa171a1 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -90,7 +90,7 @@ TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) { TEST_F(HloTfGraphBuilderTest, CheckScalarValue) { auto builder = HloComputation::Builder("Const"); HloInstruction *instruction = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); OpMetadata metadata; metadata.set_op_name("x"); metadata.set_op_type("y"); diff --git a/tensorflow/compiler/xla/service/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h index 533429608bc2e13626a3e746fbe465398e1f4bb4..4458c251dee4af365e39027dd4289925c8890efd 100644 --- a/tensorflow/compiler/xla/service/hlo_token.h +++ b/tensorflow/compiler/xla/service/hlo_token.h @@ -44,7 +44,6 @@ enum class TokKind { kRparen, // ( ) kArrow, // -> - kComment, // /*xxx*/ // Keywords kw_HloModule, diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 7b27dbfec376b8ba16d00285f10e2cc291e07a61..7fd99fc93050b386c5ad24e6dcd2fea1bf652c3f 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -125,7 +125,7 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, // transparently. CHECK_EQ(operand_number, 0); return index.empty(); - case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: // Select does not use any nested elements of its selected-from operands // (operand 1 and 2) CHECK_GE(operand_number, 0); @@ -283,8 +283,7 @@ std::ostream& operator<<(std::ostream& out, string InstructionValueSet::ToString() const { string out = StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n"); - ForEachElement([this, &out](const ShapeIndex& index, - const HloValueSet& value_set) { + ForEachElement([&out](const ShapeIndex& index, const HloValueSet& value_set) { StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n"); }); return out; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 9034073cc8a82311297ccd087741e6713110a5a7..e7674f3ddd5baa87c872d1c0b40bff340f3cd911 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -15,6 +15,8 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -39,6 +41,10 @@ Status ShapeVerifier::HandleSelect(HloInstruction* select) { return CheckTernaryShape(select); } +Status ShapeVerifier::HandleTupleSelect(HloInstruction* tuple_select) { + return CheckTernaryShape(tuple_select); +} + Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) { std::vector operand_shapes; for (const HloInstruction* operand : concatenate->operands()) { @@ -99,6 +105,15 @@ Status ShapeVerifier::HandleCrossReplicaSum(HloInstruction* crs) { ShapeInference::InferCrossReplicaSumShape(operand_shapes)); } +Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { + std::vector operand_shapes; + for (const HloInstruction* operand : hlo->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape(hlo, + ShapeInference::InferAllToAllTupleShape(operand_shapes)); +} + Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( reduce_precision->operand(0)->shape(), @@ -106,29 +121,140 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { reduce_precision->mantissa_bits())); } -Status ShapeVerifier::HandleInfeed(HloInstruction*) { return Status::OK(); } +namespace { + +Status CheckIsTokenOperand(const HloInstruction* instruction, + int64 operand_no) { + const HloInstruction* token = instruction->operand(operand_no); + if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) { + return InternalError( + "Expected operand %lld to be token-shaped, actual shape is " + "%s:\n%s", + operand_no, ShapeUtil::HumanString(token->shape()).c_str(), + instruction->ToString().c_str()); + } + return Status::OK(); +} + +Status CheckOperandAndParameter(const HloInstruction* instruction, + int64 operand_number, + const HloComputation* computation, + int64 parameter_number) { + const HloInstruction* operand = instruction->operand(operand_number); + const HloInstruction* parameter = + computation->parameter_instruction(parameter_number); + if (!ShapeUtil::Compatible(operand->shape(), parameter->shape())) { + return InternalError("Operand %s shape does not match parameter's %s in %s", + operand->ToString().c_str(), + parameter->ToString().c_str(), + instruction->ToString().c_str()); + } + return Status::OK(); +} + +} // namespace + +Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { + HloInfeedInstruction* infeed = Cast(instruction); + // Infeed has an optional single token operand. + // TODO(b/80000000): Update when token is not optional. + if (infeed->operand_count() == 1) { + TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); + } + + // The output of infeed is a tuple containing the data value and a token. + return CheckShape(infeed, + ShapeUtil::MakeTupleShape( + {infeed->infeed_shape(), ShapeUtil::MakeTokenShape()})); +} + +Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { + HloOutfeedInstruction* outfeed = Cast(instruction); + // Outfeed has an optional token operand (operand 1). + // TODO(b/80000000): Update when token is not optional. + if (outfeed->operand_count() == 2) { + TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); + } -Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { // Outfeed has a separate shape field for the value which is outfed to the - // host. The shape of the instruction itself is always nil because the outfeed - // produces no HLO value in the graph. + // host. The shape of the instruction itself is always a token. if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { return InternalError( - "Expected outfeed to have shape compatible with operand's shape %s, " + "Expected outfeed shape to be compatible with operand's shape %s, " "actual shape is %s:\n%s", ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(), outfeed->ToString().c_str()); } - return CheckShape(outfeed, ShapeUtil::MakeNil()); + return CheckShape(outfeed, ShapeUtil::MakeTokenShape()); } Status ShapeVerifier::HandleHostCompute(HloInstruction*) { return Status::OK(); } -Status ShapeVerifier::HandleRng(HloInstruction*) { return Status::OK(); } +bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0, + const Shape& shape_1, + const Shape& result_shape) { + return ShapeUtil::SameElementType(shape_0, shape_1) && + (ShapeUtil::SameElementType(shape_0, result_shape) || + (allow_mixed_precision_ && + ShapeUtil::SameElementTypeIgnoringFpPrecision(shape_0, + result_shape))); +} + +Status ShapeVerifier::HandleRng(HloInstruction* instruction) { + if (instruction->operand_count() != 2) { + return InternalError("Expected two operands for Rng instruction: %s", + instruction->ToString().c_str()); + } + + const Shape& shape_0 = instruction->operand(0)->shape(); + const Shape& shape_1 = instruction->operand(1)->shape(); + if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) { + return InternalError( + "Expected scalar types for the two operands of Rng instruction: %s", + instruction->ToString().c_str()); + } + + if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) { + return InternalError( + "Expected compatible element types for the result and the two operands" + " of Rng instruction: %s", + instruction->ToString().c_str()); + } + + PrimitiveType element_type = shape_0.element_type(); + switch (instruction->random_distribution()) { + case RNG_UNIFORM: + if (!primitive_util::IsFloatingPointType(element_type) && + !primitive_util::IsIntegralType(element_type) && + element_type != PRED) { + return InternalError( + "Element type not supported." + " Expected element to be of floating point type, integral type or" + " predicate type for RngUniform: %s", + instruction->ToString().c_str()); + } + break; + + case RNG_NORMAL: + if (!primitive_util::IsFloatingPointType(element_type)) { + return InternalError( + "Element type not supported." + " Expected element to be FloatingPointType for RngNormal: %s", + instruction->ToString().c_str()); + } + break; + default: + return InternalError( + "Invalid Rng distribution %s", + RandomDistribution_Name(instruction->random_distribution()).c_str()); + } + + return Status::OK(); +} Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { return CheckShape( @@ -137,13 +263,28 @@ Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { } Status ShapeVerifier::HandleSort(HloInstruction* sort) { - return CheckUnaryShape(sort); + if (sort->operand_count() == 2 && + !ShapeUtil::SameDimensions(sort->operand(0)->shape(), + sort->operand(1)->shape())) { + return InternalError( + "Expected sort to have to have the same dimensions for the keys and " + "the values. Keys shape is: %s\n, Values shape is: %s", + ShapeUtil::HumanString(sort->operand(0)->shape()).c_str(), + ShapeUtil::HumanString(sort->operand(1)->shape()).c_str()); + } + return CheckVariadicShape(sort); } Status ShapeVerifier::HandleConstant(HloInstruction* constant) { return CheckShape(constant, constant->literal().shape()); } +Status ShapeVerifier::HandleIota(HloInstruction* iota) { + return ShapeUtil::Rank(iota->shape()) == 1 + ? Status::OK() + : InternalError("Iota only supports arrays of rank 1."); +} + Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { return CheckShape(get_tuple_element, ShapeInference::InferGetTupleElementShape( @@ -152,10 +293,13 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { } Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { + if (!ShapeUtil::IsArray(reduce->shape())) { + return InvalidArgument("Variadic reduce is not supported."); + } return CheckShape( reduce, ShapeInference::InferReduceShape( - reduce->operand(0)->shape(), reduce->operand(1)->shape(), + {&reduce->operand(0)->shape(), &reduce->operand(1)->shape()}, reduce->dimensions(), reduce->to_apply()->ComputeProgramShape())); } @@ -203,8 +347,11 @@ Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleCall(HloInstruction* call) { + for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) { + TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i)); + } // The shape of kCall should match the shape of the computation it calls. - return CheckShape(call, call->to_apply()->ComputeProgramShape().result()); + return CheckShape(call, call->to_apply()->root_instruction()->shape()); } Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); } @@ -273,19 +420,37 @@ Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { } Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { + TF_RETURN_IF_ERROR( + CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0)); + TF_RETURN_IF_ERROR( + CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0)); + const Shape& conditional_shape = + xla_while->while_condition()->root_instruction()->shape(); + if (!ShapeUtil::Compatible(conditional_shape, + ShapeUtil::MakeShape(PRED, {}))) { + return InternalError( + "Conditional computation shape does not lead to a scalar predicate " + "shape: %s", + ShapeUtil::HumanString(conditional_shape).c_str()); + } // The shape of kWhile should match the shape of the body computation it // calls. return CheckShape(xla_while, - xla_while->while_body()->ComputeProgramShape().result()); + xla_while->while_body()->root_instruction()->shape()); } Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { + TF_RETURN_IF_ERROR(CheckOperandAndParameter( + conditional, 1, conditional->true_computation(), 0)); + TF_RETURN_IF_ERROR(CheckOperandAndParameter( + conditional, 2, conditional->false_computation(), 0)); + TF_RETURN_IF_ERROR( + CheckShape(conditional, + conditional->true_computation()->root_instruction()->shape())); TF_RETURN_IF_ERROR(CheckShape( conditional, - conditional->true_computation()->ComputeProgramShape().result())); - return CheckShape( - conditional, - conditional->false_computation()->ComputeProgramShape().result()); + conditional->false_computation()->root_instruction()->shape())); + return Status::OK(); } Status ShapeVerifier::HandlePad(HloInstruction* pad) { @@ -295,39 +460,29 @@ Status ShapeVerifier::HandlePad(HloInstruction* pad) { } Status ShapeVerifier::HandleSend(HloInstruction* send) { - TF_RET_CHECK(send->users().size() == 1); - const HloInstruction* send_done = send->users().front(); - TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); - TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); - return CheckShape( - send, ShapeUtil::MakeTupleShape( - {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})})); + return CheckShape(send, + ShapeUtil::MakeTupleShape({send->operand(0)->shape(), + ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()})); } Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { - TF_RET_CHECK(send_done->operands().size() == 1); - const HloInstruction* send = send_done->operand(0); - TF_RET_CHECK(send->opcode() == HloOpcode::kSend); - TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); - return CheckShape(send_done, ShapeUtil::MakeNil()); + return CheckShape(send_done, ShapeUtil::MakeTokenShape()); } Status ShapeVerifier::HandleRecv(HloInstruction* recv) { - TF_RET_CHECK(recv->users().size() == 1); - const HloInstruction* recv_done = recv->users().front(); - TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); - TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); - return CheckShape(recv, - ShapeUtil::MakeTupleShape( - {recv_done->shape(), ShapeUtil::MakeShape(U32, {})})); + return CheckShape( + recv, ShapeUtil::MakeTupleShape( + {ShapeUtil::GetTupleElementShape(recv->shape(), 0), + ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})); } Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { - TF_RET_CHECK(recv_done->operands().size() == 1); - const HloInstruction* recv = recv_done->operand(0); - TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv); - TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); - return CheckShape(recv_done, recv->shape().tuple_shapes(0)); + return CheckShape( + recv_done, + ShapeUtil::MakeTupleShape( + {ShapeUtil::GetTupleElementShape(recv_done->operand(0)->shape(), 0), + ShapeUtil::MakeTokenShape()})); } Status ShapeVerifier::HandleBatchNormTraining( @@ -368,9 +523,9 @@ namespace { // inputs. Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { switch (instruction->opcode()) { - // White list the following opcodes for mixed-precision check, because they - // involve data pass through or grouping via tuples, where the precisions - // of buffers can be different. + // White list the following opcodes for mixed-precision check, because + // they involve data pass through or grouping via tuples, where the + // precisions of buffers can be different. case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kConstant: @@ -386,6 +541,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kRecvDone: case HloOpcode::kReducePrecision: case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kTuple: @@ -426,12 +582,21 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { gather->gather_dimension_numbers(), gather->gather_window_bounds())); } -Status ShapeVerifier::HandleGenerateToken(HloInstruction* token) { +Status ShapeVerifier::HandleScatter(HloInstruction* scatter) { + return CheckShape( + scatter, ShapeInference::InferScatterShape( + scatter->operand(0)->shape(), scatter->operand(1)->shape(), + scatter->operand(2)->shape(), + scatter->to_apply()->ComputeProgramShape(), + scatter->scatter_dimension_numbers())); +} + +Status ShapeVerifier::HandleAfterAll(HloInstruction* token) { std::vector operand_shapes; for (const HloInstruction* operand : token->operands()) { operand_shapes.push_back(&operand->shape()); } - return CheckShape(token, ShapeInference::InferTokenShape(operand_shapes)); + return CheckShape(token, ShapeInference::InferAfterAllShape(operand_shapes)); } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, @@ -448,16 +613,10 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, // We treat BF16 and F32 as compatible types if mixed precision is allowed, // but only when the instruction defines the BF16/F32 buffer. switch (instruction->opcode()) { - case HloOpcode::kSelect: - if (ShapeUtil::IsTuple(inferred_shape) || !allow_mixed_precision_) { - // Select only defines the top-level buffer, which in this case is the - // tuple, so we cannot allow mixed precision. - compatible = - ShapeUtil::Compatible(instruction->shape(), inferred_shape); - } else { - compatible = ShapeUtil::CompatibleIgnoringFpPrecision( - instruction->shape(), inferred_shape); - } + case HloOpcode::kTupleSelect: + // TupleSelect only defines the top-level buffer, which in this case is + // the tuple, so we cannot allow mixed precision. + compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); break; case HloOpcode::kGetTupleElement: case HloOpcode::kTuple: @@ -538,19 +697,6 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { instruction->opcode(), instruction->operands())); } -// Checks if the given two instructions shares the same channel id. -Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1, - const HloInstruction* instr2) { - if (instr1->channel_id() != instr2->channel_id()) { - return InternalError( - "Expected to have the same channel id, actual channel ids are: %s " - "(%lld), %s (%lld)", - instr1->ToString().c_str(), instr1->channel_id(), - instr2->ToString().c_str(), instr2->channel_id()); - } - return Status::OK(); -} - string ComputationsToString( tensorflow::gtl::ArraySlice computations) { return tensorflow::str_util::Join( @@ -561,7 +707,8 @@ string ComputationsToString( // Verifies various invariants about the structure of the HLO: // -// (1) each instruction has a non-null parent() set to the HloComputation which +// (1) each instruction has a non-null parent() set to the HloComputation +// which // contains it. // // (2) each computation has a non-null parent() set to the HloModule which @@ -595,9 +742,9 @@ Status VerifyHloStructure(HloModule* module) { } // Check that operands are in the same computation separately from verifying - // parent() correctness so conditions like a null HloInstruction::parent() are - // identified and reported explicitly above rather than reporting a mismatched - // operand. + // parent() correctness so conditions like a null HloInstruction::parent() + // are identified and reported explicitly above rather than reporting a + // mismatched operand. for (const HloComputation* computation : module->computations()) { for (const HloInstruction* instruction : computation->instructions()) { for (int i = 0; i < instruction->operand_count(); ++i) { @@ -621,13 +768,14 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { HloComputation* fused_computation = fusion->fused_instructions_computation(); if (fusion != fused_computation->FusionInstruction()) { return InternalError( - "Instruction of fused computation does not match expected instruction " + "Instruction of fused computation does not match expected " + "instruction " "%s.", fusion->ToString().c_str()); } - // Fused root instruction and fused parameters must all be owned by the fusion - // computation. + // Fused root instruction and fused parameters must all be owned by the + // fusion computation. bool root_owned = false; const std::vector& fused_parameters = fusion->fused_parameters(); @@ -669,8 +817,8 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { fusion->ToString().c_str()); } - // All uses of fused instructions must be in the fusion computation, and every - // non-root instruction must have at least one use. + // All uses of fused instructions must be in the fusion computation, and + // every non-root instruction must have at least one use. for (auto* instruction : fusion->fused_instructions_computation()->instructions()) { if (instruction != fused_root) { @@ -714,7 +862,8 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { if (!ShapeUtil::Compatible(fused_param->shape(), fusion->operand(param_no)->shape())) { return InternalError( - "Shape mismatch between parameter number %lld and its operand in %s.", + "Shape mismatch between parameter number %lld and its operand in " + "%s.", param_no, fusion->ToString().c_str()); } } @@ -750,33 +899,23 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { "While loop must have exactly one operand; had %lld : %s", instruction->operand_count(), instruction->ToString().c_str()); } - auto* init = instruction->operand(0); - auto* cond_param = while_cond->parameter_instruction(0); - if (!ShapeUtil::Compatible(init->shape(), cond_param->shape())) { - return FailedPrecondition( - "While condition's parameter must have the same shape as the " - "loop's 'init'. init: %s, param: %s", - init->ToString().c_str(), cond_param->ToString().c_str()); - } - auto* cond_root = while_cond->root_instruction(); - if (!ShapeUtil::Compatible(cond_root->shape(), - ShapeUtil::MakeShape(PRED, {}))) { - return FailedPrecondition("While condition should have shape PRED: %s", - cond_root->ToString().c_str()); - } - auto* body_param = while_body->parameter_instruction(0); - if (!ShapeUtil::Compatible(init->shape(), body_param->shape())) { + return Status::OK(); +} + +Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) { + if (instruction->true_computation()->num_parameters() != 1) { return FailedPrecondition( - "While body's parameter must have the same shape as the loop's" - " 'init'. init: %s, param: %s", - init->ToString().c_str(), body_param->ToString().c_str()); + "True computation %s of %s must have 1 parameter insted of %lld", + instruction->true_computation()->name().c_str(), + instruction->ToString().c_str(), + instruction->true_computation()->num_parameters()); } - auto* body_root = while_body->root_instruction(); - if (!ShapeUtil::Compatible(init->shape(), body_root->shape())) { + if (instruction->false_computation()->num_parameters() != 1) { return FailedPrecondition( - "While body should have same shape as the loop's 'init'." - "init: %s, body: %s", - init->ToString().c_str(), body_root->ToString().c_str()); + "False computation %s of %s must have 1 parameter insted of %lld", + instruction->false_computation()->name().c_str(), + instruction->ToString().c_str(), + instruction->false_computation()->num_parameters()); } return Status::OK(); } @@ -785,8 +924,7 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { const Shape& out_shape = instruction->shape(); for (HloInstruction* operand : instruction->operands()) { const Shape& operand_shape = operand->shape(); - if (!ShapeUtil::IsScalar(operand_shape) && - !ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) { + if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) { return FailedPrecondition( "Implicit broadcast is not allowed in HLO." "Found non-compatible shapes for instruction %s.\n" @@ -814,9 +952,10 @@ bool ShapeContainsToken(const Shape& shape) { } // Verifies that all types entering and exiting the entry computation are -// legal. For example, TOKEN types have no Literal representation and cannot be -// on the interface of the entry computation (parameters and root instruction). +// legal. Status VerifyEntryAndExitShapes(const HloModule& module) { + // Tokens cannot be passed as entry parameters. + // TODO(b/80000000): Remove this constraint. for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { HloInstruction* param = module.entry_computation()->parameter_instruction(i); @@ -826,13 +965,102 @@ Status VerifyEntryAndExitShapes(const HloModule& module) { ShapeUtil::HumanString(param->shape()).c_str()); } } - if (ShapeContainsToken( - module.entry_computation()->root_instruction()->shape())) { + return Status::OK(); +} + +// Checks if the given two instructions share the same channel id. +Status CheckSameChannel(const HloInstruction* instr1, + const HloInstruction* instr2) { + if (instr1->channel_id() != instr2->channel_id()) { + return InternalError( + "Expected to have the same channel id, actual channel ids are: %s " + "(%lld), %s (%lld)", + instr1->ToString().c_str(), instr1->channel_id(), + instr2->ToString().c_str(), instr2->channel_id()); + } + return Status::OK(); +} + +// Checks if the given two instructions have the same is_host_transfer +// attribute value. Intsructions must be send/recv instructions or their +// 'done' variant. +Status CheckSameIsHostTransfer(const HloInstruction* instr1, + const HloInstruction* instr2) { + const HloSendRecvInstruction* send_recv1 = + DynCast(instr1); + const HloSendRecvInstruction* send_recv2 = + DynCast(instr2); + TF_RET_CHECK(send_recv1 != nullptr); + TF_RET_CHECK(send_recv2 != nullptr); + if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) { return InternalError( - "Entry root is or contains a token shape: %s", - ShapeUtil::HumanString( - module.entry_computation()->root_instruction()->shape()) - .c_str()); + "Expected instructions to have the same is-host-transfer property: " + "%s, " + "%s ", + instr1->ToString().c_str(), instr2->ToString().c_str()); + } + return Status::OK(); +} + +// Checks various invariants of send and recv instructions. +Status VerifySendsAndRecvs(const HloModule& module) { + tensorflow::gtl::FlatMap host_channels; + // Host send/recv instructions must have their own unique channel. + auto check_unique_host_channel = [&](const HloInstruction* instruction) { + const HloSendRecvInstruction* sendrecv = + DynCast(instruction); + if (sendrecv->is_host_transfer()) { + auto it_inserted = + host_channels.insert({sendrecv->channel_id(), sendrecv}); + if (!it_inserted.second) { + return FailedPrecondition( + "Channel %lld is used for multiple host send/recv instructions: " + "%s " + "and " + "%s", + sendrecv->channel_id(), sendrecv->ToString().c_str(), + it_inserted.first->second->ToString().c_str()); + } + } + + return Status::OK(); + }; + + // Send/Recv instruction must have a single user: the corresponding + // SendDone/RecvDone. with matching channel. + for (const HloComputation* computation : module.computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + switch (instruction->opcode()) { + case HloOpcode::kSend: { + TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); + TF_RET_CHECK(instruction->users().size() == 1); + const HloInstruction* send_done = instruction->users().front(); + TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done)); + TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done)); + break; + } + case HloOpcode::kRecv: { + TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); + TF_RET_CHECK(instruction->users().size() == 1); + const HloInstruction* recv_done = instruction->users().front(); + TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done)); + TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done)); + break; + } + case HloOpcode::kSendDone: + TF_RET_CHECK(instruction->operands().size() == 1); + TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend); + break; + case HloOpcode::kRecvDone: + TF_RET_CHECK(instruction->operands().size() == 1); + TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv); + break; + default: + break; + } + } } return Status::OK(); } @@ -841,6 +1069,7 @@ Status VerifyEntryAndExitShapes(const HloModule& module) { StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(VerifyHloStructure(module)); + TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module)); tensorflow::gtl::FlatMap instructions; @@ -880,7 +1109,11 @@ StatusOr HloVerifier::Run(HloModule* module) { << " != " << ShapeUtil::Rank(instruction->operand(0)->shape()); } else if (instruction->opcode() == HloOpcode::kWhile) { TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction)); - } else if (instruction->IsElementwise()) { + } else if (instruction->opcode() == HloOpcode::kConditional) { + TF_RETURN_IF_ERROR(CheckConditionalInstruction(instruction)); + } else if (instruction->opcode() != + HloOpcode::kRng /* Rng operands are always scalar. */ + && instruction->IsElementwise()) { TF_RETURN_IF_ERROR(CheckElementwiseInstruction(instruction)); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 7283b3e7dcdbed5be18a1da1571287cf0c089288..c942fab08e1ace75bccb8762954787a4366922a9 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -35,7 +35,9 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleElementwiseBinary(HloInstruction* hlo) override; Status HandleClamp(HloInstruction* clamp) override; Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleConcatenate(HloInstruction* concatenate) override; + Status HandleIota(HloInstruction* iota) override; Status HandleConvert(HloInstruction* convert) override; Status HandleBitcastConvert(HloInstruction* convert) override; Status HandleCopy(HloInstruction* copy) override; @@ -43,6 +45,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleAllToAll(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; Status HandleInfeed(HloInstruction*) override; Status HandleOutfeed(HloInstruction*) override; @@ -81,7 +84,8 @@ class ShapeVerifier : public DfsHloVisitor { HloInstruction* batch_norm_inference) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleGather(HloInstruction* gather) override; - Status HandleGenerateToken(HloInstruction* token) override; + Status HandleScatter(HloInstruction* scatter) override; + Status HandleAfterAll(HloInstruction* token) override; Status FinishVisit(HloInstruction*) override { return Status::OK(); } @@ -101,11 +105,14 @@ class ShapeVerifier : public DfsHloVisitor { Status CheckTernaryShape(const HloInstruction* instruction); Status CheckVariadicShape(const HloInstruction* instruction); - // Checks if the given two instructions share the same channel id. - Status CheckSameChannel(const HloInstruction* instr1, - const HloInstruction* instr2); - private: + // Return true if the shapes of the two operands have the same element type, + // and the result shape either has the same element type as the operand + // shapes or mixed precision is allowed and the result shape and the operand + // shapes have floating point element types. + bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1, + const Shape& result_shape); + // Whether the inputs and output of an instruction can contain both F32s and // BF16s. Tuples that include both F32s and BF16s are allowed regardless of // this flag. @@ -145,6 +152,8 @@ class HloVerifier : public HloPassInterface { Status CheckWhileInstruction(HloInstruction* instruction); + Status CheckConditionalInstruction(HloInstruction* instruction); + // Checks that the non-scalar operand shapes are compatible to the output // shape, i.e., that there are no implicit broadcasts of size-one dimensions. Status CheckElementwiseInstruction(HloInstruction* instruction); diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index c92db0be14dceb32ea86521dcc99b8f63738e4a5..d764964f3c3dc58a54bd0307f8b625076c14f3e5 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -33,7 +34,17 @@ namespace { using ::testing::HasSubstr; -using HloVerifierTest = HloTestBase; +class HloVerifierTest : public HloTestBase { + public: + HloVerifierTest() + : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/false) {} +}; + +class HloVerifierTestAllowMixedPrecision : public HloTestBase { + public: + HloVerifierTestAllowMixedPrecision() + : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/true) {} +}; TEST_F(HloVerifierTest, NullInstructionParent) { HloComputation::Builder builder(TestName()); @@ -123,5 +134,146 @@ TEST_F(HloVerifierTest, ResetsShapeVerifierState) { EXPECT_FALSE(verifier().Run(module.get()).status().ok()); } +TEST_F(HloVerifierTest, CheckCallOperandParameterShapesMismatch) { + const char* const hlo_string = R"( +HloModule Module + +callme { + ROOT param = (s32[], f32[4]) parameter(0) +} + +ENTRY entry { + p0 = (f32[4], s32[]) parameter(0) + ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("shape does not match parameter")); +} + +TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) { + const char* const hlo_string = R"( +HloModule Module + +true_branch { + tparam = (s32[], f32[4]) parameter(0) + ROOT tgte1 = f32[4] get-tuple-element(tparam), index=1 +} + +false_branch { + fparam = (s32[], f32[4]) parameter(0) + ROOT fgte1 = f32[4] get-tuple-element(fparam), index=1 +} + +ENTRY entry { + p0 = (f32[4], s32[]) parameter(0) + constant = pred[] constant(true) + ROOT conditional = f32[4] conditional(constant, p0, p0), + true_computation=true_branch, false_computation=false_branch +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("shape does not match parameter")); +} + +TEST_F(HloVerifierTest, RngOpnd0NotScalar) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngOpnd0NotScalar { + constant.0 = f32[] constant(0) + constant.1 = f16[2] constant({1, 3}) + ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[2] constant.1), + distribution=rng_uniform + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), HasSubstr("Expected scalar type")); +} + +TEST_F(HloVerifierTest, RngOperandElementTypesDoNotMatch) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngOperandElementTypesNotMatch { + constant.0 = f32[] constant(0) + constant.1 = f16[] constant(1) + ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[] constant.1), + distribution=rng_normal + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected compatible element types")); +} + +TEST_F(HloVerifierTest, RngMixedPrecisionNotAllowed) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngResultElementTypeNotMatch { + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(1) + ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1), + distribution=rng_normal + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected compatible element types")); +} + +TEST_F(HloVerifierTestAllowMixedPrecision, RngMixedPrecisionAllowed) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngResultElementTypeNotMatch { + constant.0 = f32[] constant(0) + constant.1 = f32[] constant(1) + ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1), + distribution=rng_normal + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTest, RngElementTypeNotSupported) { + const char* const hlo_string = R"( + HloModule Module + + ENTRY RngElementTypeNotSupported { + constant.0 = s32[] constant(0) + constant.1 = s32[] constant(1) + ROOT rng.0 = s32[10]{0} rng(s32[] constant.0, s32[] constant.1), + distribution=rng_normal + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), HasSubstr("Element type not supported")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index d7458c338e9f1df9fac90270845aae0b8f779ee2..bb5b40a8a87c5eab5a5b1599581a81bbd064511b 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -36,7 +36,8 @@ string HumanReadableProfileBuilder::ToString() const { computation_name_.c_str(), HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str()); - auto print_op = [&](const OpInfo& op) { + int64 cumulative_cycles = 0; + auto print_op = [&](const OpInfo& op, bool is_total = false) { // Skip ops with 0 optimal seconds and 0 actual cycles. These are ops that // were expected to be free and are actually free -- things like (on most // backends) kParameter or kConstant HLOs. There's no need to clutter the @@ -59,27 +60,44 @@ string HumanReadableProfileBuilder::ToString() const { } } + double cumulative_cycles_percent = 0; double cycles_percent = 0; + if (!is_total) { + cumulative_cycles += op.cycles; + } if (total_cycles_ > 0) { cycles_percent = op.cycles / static_cast(total_cycles_) * 100; + cumulative_cycles_percent = + cumulative_cycles / static_cast(total_cycles_) * 100; + } + + string cycles_percent_str; + if (is_total) { + // Leaving off the two trailing decimal points of "100.%" lets us save two + // columns in the output. + cycles_percent_str = "100.% 100Σ"; + } else { + cycles_percent_str = + Printf("%5.2f%% %2.0fΣ", cycles_percent, cumulative_cycles_percent); } double nsecs = op.cycles / clock_rate_ghz_; - Appendf(&s, - "%15lld cycles (%6.2f%%) :: %12.1f usec %22s :: %18s " - ":: %18s :: %14s :: %16s :: %s\n", - op.cycles, cycles_percent, CyclesToMicroseconds(op.cycles), - op.optimal_seconds < 0 - ? "" - : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(), - op.flop_count <= 0 - ? "" - : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), - op.transcendental_count <= 0 ? "" - : HumanReadableNumTranscendentalOps( - op.transcendental_count, nsecs) - .c_str(), - bytes_per_sec.c_str(), bytes_per_cycle.c_str(), op.name.c_str()); + Appendf( + &s, + "%15lld cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: " + "%16s :: %s\n", + op.cycles, cycles_percent_str.c_str(), CyclesToMicroseconds(op.cycles), + op.optimal_seconds < 0 + ? "" + : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(), + op.flop_count <= 0 + ? "" + : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), + op.transcendental_count <= 0 + ? "" + : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs) + .c_str(), + bytes_per_sec.c_str(), bytes_per_cycle.c_str(), op.name.c_str()); }; float optimal_seconds_sum = 0.0; @@ -98,7 +116,8 @@ string HumanReadableProfileBuilder::ToString() const { VLOG(1) << "Total floating point ops: " << total_flops; print_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops, - total_transcendentals, total_bytes, optimal_seconds_sum}); + total_transcendentals, total_bytes, optimal_seconds_sum}, + /*is_total=*/true); // Sort ops in decreasing order of cycles, and print them. std::vector sorted_ops(op_infos_); diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc index 8c7b38dd1bf73e0be7b669d7215812aaef1cee17..f85d31d5225b8012b68f851b2bfec219d736ba0d 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 8b3fa6c1572cf0ed91fc427722edcb23d8b8529d..3531b7223fb11df212fa8d30e3adba6aac6c5679 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -28,6 +29,7 @@ namespace { using Analysis = IndexedArrayAnalysis; using UnknownArray = Analysis::UnknownArray; using ConstantArray = Analysis::ConstantArray; +using ReshapedArray = Analysis::ReshapedArray; using ScalarIndexedArray = Analysis::ScalarIndexedArray; using tensorflow::gtl::ArraySlice; using tensorflow::str_util::Join; @@ -52,6 +54,13 @@ string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { "(constant ", ShapeUtil::HumanString(root->shape()), ")"); } + case Array::kReshaped: { + ReshapedArray* reshaped_array = root->as(); + return tensorflow::strings::StrCat( + "(reshape ", ToString(reshaped_array->operand(), print_constants), + " to ", ShapeUtil::HumanString(reshaped_array->shape()), ")"); + } + case Array::kScalarIndexedConstant: case Array::kScalarIndexed: { auto* indexed_array = root->as(); @@ -152,6 +161,12 @@ StatusOr IndexedArrayAnalysis::ComputeArrayFor( computed_array, ComputeArrayForReshape(instr->shape(), FindOrDie(cache_, instr->operand(0)))); + } else if (instr->opcode() == HloOpcode::kDot) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(), + FindOrDie(cache_, instr->operand(0)), + FindOrDie(cache_, instr->operand(1)))); } else { computed_array = nullptr; } @@ -239,15 +254,40 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForGather( tensorflow::gtl::ArraySlice window_bounds, Array* source, Array* indices) { if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { + VLOG(3) << "ComputeArrayForGather: indices are not scalar"; return nullptr; } CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1); - if (!c_binary_search(dim_numbers.elided_window_dims(), - dim_numbers.gather_dims_to_operand_dims(0))) { + + // We can also handle dim_numbers.elided_window_dims_size() == 0 here, should + // it become relevant. + + if (dim_numbers.elided_window_dims_size() != 1 || + dim_numbers.elided_window_dims(0) != + dim_numbers.gather_dims_to_operand_dims(0)) { + VLOG(3) << "ComputeArrayForGather: gather operations must elide " + "gather_dims_to_operand_dims[0] and " + "gather_dims_to_operand_dims[0] only"; return nullptr; } + // ScalarIndexedArray cannot represent gathers that "slice" along some + // dimensions -- for instance it cannot represent a gather that picks 5 [2,3] + // arrays from an array of size [7,4,6]. We check that condition down below: + + for (int64 i = 0, e = source->shape().dimensions_size(); i < e; i++) { + if (i != dim_numbers.elided_window_dims(0) && + source->shape().dimensions(i) != window_bounds[i]) { + VLOG(3) << "ComputeArrayForGather: window_bounds[" << i + << "] != source->shape().dimensions(" << i << ") -- " + << source->shape().dimensions(i) << " vs. " << window_bounds[i] + << " with dim_numbers.elided_window_dims(0) = " + << dim_numbers.elided_window_dims(0); + return nullptr; + } + } + int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0); std::vector output_dims; for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { @@ -257,8 +297,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForGather( } if (auto* indexed = dynamic_cast(source)) { - auto it = c_find(indexed->output_dims(), source_dim); - if (it != indexed->output_dims().end()) { + if (c_linear_search(indexed->output_dims(), source_dim)) { return FoldGatherOfGather(indexed, indices, source_dim, output_dims, shape); } @@ -336,7 +375,11 @@ std::vector ComputeReshapePassthroughDimPairs( // result_subarray_size does not include the elements in the current // `result_dim` dimension (we multiply in result_shape[result_dim] at the // end of loop body) so candidate_operand_dim can never be zero. - CHECK_NE(candidate_operand_dim, 0); + CHECK_NE(candidate_operand_dim, 0) + << "result_dim = " << result_dim + << ", result_subarray_size = " << result_subarray_size + << ", result_shape = [" << Join(result_shape, ",") << "]" + << ", operand_shape = [" << Join(operand_shape, ",") << "]"; if (candidate_operand_dim != -1 && result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) { @@ -357,7 +400,7 @@ std::vector ComputeReshapePassthroughDimPairs( }); VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to [" << Join(result_shape, ",") << "] passthrough indices are [" - << Join(result_strings, ",") << "]"; + << Join(result_strings, ",") << "] (legend: `result`->`operand`)"; } DCHECK(c_is_sorted( @@ -398,22 +441,202 @@ int64 MapPassthroughOperandDimToResultDim( int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, ArraySlice result_shape, int64 source_passthrough_dim) { + VLOG(3) << "FindSourcePositionForPassthroughResultDim([" + << Join(operand_shape, ",") << "], [" << Join(result_shape, ",") + << "], " << source_passthrough_dim << ")"; + int64 indexed_source_subarray_size = std::accumulate(operand_shape.begin() + source_passthrough_dim + 1, - operand_shape.end(), 1, std::multiplies()); + operand_shape.end(), 1LL, std::multiplies()); return FindSuffixWithProduct(result_shape, indexed_source_subarray_size); } +Shape StripDegenerateDimensions(const Shape& shape) { + DimensionVector new_dims; + c_copy_if(shape.dimensions(), std::back_inserter(new_dims), + [](int64 dim) { return dim != 1; }); + return ShapeUtil::MakeShape(shape.element_type(), new_dims); +} }; // namespace -StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( - const Shape& shape, Array* operand) { - auto* scalar_indexed = dynamic_cast(operand); - if (!scalar_indexed) { +StatusOr +IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims( + ScalarIndexedArray* operand) { + const Shape& shape = operand->shape(); + if (!ShapeUtil::HasDegenerateDimensions(shape)) { + return operand; + } + + // We only need to reshape out the degenerate dims from the indices and the + // source (except the source dim). + + const Shape& source_shape = operand->source()->shape(); + DimensionVector new_source_shape_dims; + for (int64 i = 0, e = source_shape.dimensions_size(); i < e; i++) { + if (i == operand->source_dim() || source_shape.dimensions(i) != 1) { + new_source_shape_dims.push_back(source_shape.dimensions(i)); + } + } + + Shape new_source_shape = + ShapeUtil::MakeShape(shape.element_type(), new_source_shape_dims); + Shape new_indices_shape = + StripDegenerateDimensions(operand->indices()->shape()); + + TF_ASSIGN_OR_RETURN( + Array* const new_source, + ComputeArrayForReshape(new_source_shape, operand->source())); + TF_ASSIGN_OR_RETURN( + Array* const new_indices, + ComputeArrayForReshape(new_indices_shape, operand->indices())); + + // Build the new output dims while keeping track of the degenerate dims that + // will no longer be present. + DimensionVector new_output_dims; + int64 degenerate_dims_seen = 0; + for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { + if (shape.dimensions(i) == 1) { + degenerate_dims_seen++; + } else if (ArrayContains(operand->output_dims(), i)) { + new_output_dims.push_back(i - degenerate_dims_seen); + } + } + + // Similarly, build the new source dim while keeping track of the degenerate + // dims that will no longer be present. + int64 degenerate_dims_before_source_dim = + std::count(source_shape.dimensions().begin(), + source_shape.dimensions().begin() + operand->source_dim(), 1); + int64 new_source_dim = + operand->source_dim() - degenerate_dims_before_source_dim; + + return ConstructScalarIndexedArray( + new_source, new_indices, new_source_dim, + InlinedVectorToVector(new_output_dims), + StripDegenerateDimensions(operand->shape())); +} + +StatusOr IndexedArrayAnalysis::ReshapeToAddDegenerateDims( + ScalarIndexedArray* operand, + tensorflow::gtl::ArraySlice degenerate_dims) { + if (degenerate_dims.empty()) { + return operand; + } + + CHECK(!ShapeUtil::HasDegenerateDimensions(operand->shape())); + + DimensionVector new_output_dims = [&]() { + // To make things easy we use a "scratch" buffer of bools where the i'th + // element is true iff the i'th component of the result index is an output + // index. + + gtl::InlinedVector output_dims_bitvector( + operand->shape().dimensions_size()); + for (int64 output_dim : operand->output_dims()) { + output_dims_bitvector[output_dim] = true; + } + + for (int64 degenerate_dim : degenerate_dims) { + InsertAt(&output_dims_bitvector, degenerate_dim, false); + } + + DimensionVector result; + result.reserve(operand->output_dims().size()); + for (int64 i = 0, e = output_dims_bitvector.size(); i < e; i++) { + if (output_dims_bitvector[i]) { + result.push_back(i); + } + } + + return result; + }(); + + DimensionVector new_result_shape_dims; + c_copy(operand->shape().dimensions(), + std::back_inserter(new_result_shape_dims)); + for (int64 degenerate_dim : degenerate_dims) { + InsertAt(&new_result_shape_dims, degenerate_dim, 1); + } + + DimensionVector new_source_shape_dims = new_result_shape_dims; + for (int64 output_dim : new_output_dims) { + EraseAt(&new_source_shape_dims, output_dim); + } + + int64 new_source_dim = [&]() { + for (int i = 0, e = new_source_shape_dims.size(); i < e; i++) { + int64 non_degenerate_dims_seen = 0; + if (non_degenerate_dims_seen == operand->source_dim()) { + return i; + } + if (new_source_shape_dims[new_source_dim] != 1) { + non_degenerate_dims_seen++; + } + } + LOG(FATAL) << "Did not find source dim in " << ToString(operand); + }(); + + int64 source_dim_size = + operand->source()->shape().dimensions(operand->source_dim()); + InsertAt(&new_source_shape_dims, /*index=*/new_source_dim, + /*value=*/source_dim_size); + + Shape new_source_shape = ShapeUtil::MakeShape(operand->shape().element_type(), + new_source_shape_dims); + Shape new_result_shape = ShapeUtil::MakeShape(operand->shape().element_type(), + new_result_shape_dims); + + TF_ASSIGN_OR_RETURN( + Array* const new_source, + ComputeArrayForReshape(new_source_shape, operand->source())); + return ConstructScalarIndexedArray( + new_source, operand->indices(), new_source_dim, + InlinedVectorToVector(new_output_dims), new_result_shape); +} + +StatusOr IndexedArrayAnalysis::FoldReshapeOfGather( + const Shape& shape, ScalarIndexedConstantArray* operand) { + VLOG(3) << "FoldReshapeOfGather(" << ToString(operand) << ")"; + + // To make things easier on ourselves, instead of directly trying to fold the + // reshape of `operand` to `shape`, we call + // `FoldReshapeOfGatherNoDegenerateDims` on shapes without degenerate dims and + // handle the degenerate dimensions here by inserting reshapes. + + TF_ASSIGN_OR_RETURN(ScalarIndexedArray* const operand_without_degenerate_dims, + ReshapeToRemoveDegenerateDims(operand)); + + Shape output_shape_without_degenerate_dims = StripDegenerateDimensions(shape); + TF_ASSIGN_OR_RETURN( + ScalarIndexedArray* const folded_reshape_without_degenerate_dims, + FoldReshapeOfGatherNoDegenerateDims( + output_shape_without_degenerate_dims, + operand_without_degenerate_dims->as())); + + if (folded_reshape_without_degenerate_dims == nullptr) { return nullptr; } + DimensionVector degenerate_result_dims; + for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { + if (shape.dimensions(i) == 1) { + degenerate_result_dims.push_back(i); + } + } + + return ReshapeToAddDegenerateDims(folded_reshape_without_degenerate_dims, + degenerate_result_dims); +} + +StatusOr +IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( + const Shape& shape, ScalarIndexedConstantArray* scalar_indexed) { + VLOG(3) << "FoldReshapeOfGatherNoDegenerateDims(" << ToString(scalar_indexed) + << ")"; + CHECK(!ShapeUtil::HasDegenerateDimensions(shape)); + CHECK(!ShapeUtil::HasDegenerateDimensions(scalar_indexed->shape())); + // Try to fold Reshape(ScalarIndexed(Const, Indices)) // => ScalarIndexed(Const', Indices) // @@ -464,7 +687,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( std::vector reshape_passthrough_dims = ComputeReshapePassthroughDimPairs( - /*operand_shape=*/AsInt64Slice(operand->shape().dimensions()), + /*operand_shape=*/AsInt64Slice(scalar_indexed->shape().dimensions()), /*result_shape=*/AsInt64Slice(shape.dimensions())); auto is_reshape_passthrough_operand_dim = [&](int64 operand_dim) { @@ -474,6 +697,8 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( if (!c_all_of(scalar_indexed->output_dims(), is_reshape_passthrough_operand_dim)) { + VLOG(3) << "Not all output dims are passthrough dims " + << ToString(scalar_indexed); return nullptr; } @@ -527,6 +752,11 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( // (a.k.a. isn't pass-through) than the [3,5,2] array. if (source_dim_for_new_scalar_indexed_node == -1) { + VLOG(3) << "Could not compute the source dim for the new scalar indexed " + "node: scalar_indexed_source_shape = [" + << Join(scalar_indexed_source_shape.dimensions(), ",") + << "] and new_scalar_indexed_source_shape = [" + << Join(new_scalar_indexed_source_shape, ",") << "]"; return nullptr; } @@ -534,6 +764,10 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node, scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim())); + CHECK_EQ(c_accumulate(new_scalar_indexed_source_shape, 1LL, + std::multiplies()), + ShapeUtil::ElementsIn(scalar_indexed_source_shape)); + CHECK(IsReshapePassthroughOperandDim( ComputeReshapePassthroughDimPairs( /*operand_shape=*/AsInt64Slice( @@ -564,6 +798,31 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( output_dims_for_new_scalar_indexed_node, shape); } +StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( + const Shape& shape, Array* operand) { + if (ShapeUtil::Compatible(operand->shape(), shape)) { + return operand; + } + + if (auto* scalar_indexed = + dynamic_cast(operand)) { + TF_ASSIGN_OR_RETURN(Analysis::Array * reshape_folded_into_gather, + FoldReshapeOfGather(shape, scalar_indexed)); + if (reshape_folded_into_gather) { + return reshape_folded_into_gather; + } + } + + if (auto* constant_array = dynamic_cast(operand)) { + TF_ASSIGN_OR_RETURN(Literal* const new_literal, + TakeOwnership(constant_array->literal()->Reshape( + AsInt64Slice(shape.dimensions())))); + return Construct(new_literal); + } + + return Construct(operand, shape); +} + StatusOr IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, Array* lhs, @@ -703,11 +962,177 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, return Construct( new_source, scalar_indexed_const->indices(), scalar_indexed_const->source_dim(), - std::vector(scalar_indexed_const->output_dims().begin(), - scalar_indexed_const->output_dims().end()), + ArraySliceToVector(scalar_indexed_const->output_dims()), scalar_indexed_const->shape()); } +namespace { + +// Returns the non-contracting non-batch dimension (as per `contracting_dims` +// and `batch_dims`) if there is exactly one, otherwise returns nullopt. +gtl::optional GetOnlyNonContractingNonBatchDim( + int64 rank, ArraySlice contracting_dims, + ArraySlice batch_dims) { + gtl::optional result; + for (int64 dim = 0; dim < rank; dim++) { + if (!ArrayContains(contracting_dims, dim) && + !ArrayContains(batch_dims, dim)) { + if (result.has_value()) { + return gtl::nullopt; + } + result = dim; + } + } + return result; +} + +// Returns true if `indexed_array`, which is either the LHS or the RHS of a Dot +// HLO, can be folded into the dot operation. For now these conditions are both +// necessary and sufficient. +// +// `tag` describes the caller. Used only for logging. +// +// `contracting_dims` and `batch_dims` are the contracting and batch dimensions +// of whatever operand `indexed_array` is to the dot (LHS or RHS). +bool CanFoldDotIntoIndexedArray( + tensorflow::StringPiece tag, + Analysis::ScalarIndexedConstantArray* indexed_array, + ArraySlice contracting_dims, ArraySlice batch_dims) { + gtl::optional non_contracting_non_batch_dim = + GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()), + contracting_dims, batch_dims); + if (!non_contracting_non_batch_dim.has_value()) { + VLOG(3) << tag << ": multiple or no non-contracting non-batch dimensions"; + return false; + } + + if (indexed_array->output_dims().size() != 1 || + indexed_array->output_dims()[0] != *non_contracting_non_batch_dim) { + VLOG(3) << tag << ": output dims != the lhs non-contracting non-batch dim"; + return false; + } + + int64 indexed_array_rank = ShapeUtil::Rank(indexed_array->shape()); + if (indexed_array->source_dim() < (indexed_array_rank - 2)) { + // This restriction can be lifted by inserting reshape nodes. + VLOG(3) << tag + << ": source dim is not in the low two dims, won't be able to form " + "a matmul"; + return false; + } + + return true; +} + +} // namespace + +StatusOr +IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( + const Shape& shape, const DotDimensionNumbers& dim_numbers, + ScalarIndexedConstantArray* lhs, ConstantArray* rhs) { + VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " " + << ToString(rhs); + if (!CanFoldDotIntoIndexedArray( + "ComputeArrayForDotWithIndexedLhs", lhs, /*contracting_dims=*/ + AsInt64Slice(dim_numbers.lhs_contracting_dimensions()), + /*batch_dims=*/AsInt64Slice(dim_numbers.lhs_batch_dimensions()))) { + return nullptr; + } + + int64 lhs_rank = ShapeUtil::Rank(lhs->shape()); + DotDimensionNumbers new_dim_numbers = dim_numbers; + new_dim_numbers.set_lhs_contracting_dimensions( + 0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1)); + + TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateDotOp( + new_dim_numbers, lhs->literal(), *rhs->literal()))); + + // The new source dimension is wherever the non-batch non-contracting LHS + // dimension "went". + int64 new_source_dim = dim_numbers.lhs_batch_dimensions_size() + + dim_numbers.rhs_batch_dimensions_size(); + + ConstantArray* new_source = Construct(literal_for_new_source); + return Construct( + new_source, lhs->indices(), new_source_dim, + ArraySliceToVector(lhs->output_dims()), shape); +} + +StatusOr +IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( + const Shape& shape, const DotDimensionNumbers& dim_numbers, + ConstantArray* lhs, ScalarIndexedConstantArray* rhs) { + VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " " + << ToString(rhs); + if (!CanFoldDotIntoIndexedArray( + "ComputeArrayForDotWithIndexedRhs", rhs, /*contracting_dims=*/ + AsInt64Slice(dim_numbers.rhs_contracting_dimensions()), + /*batch_dims=*/AsInt64Slice(dim_numbers.rhs_batch_dimensions()))) { + return nullptr; + } + + int64 rhs_rank = ShapeUtil::Rank(rhs->shape()); + + DotDimensionNumbers new_dim_numbers = dim_numbers; + new_dim_numbers.set_rhs_contracting_dimensions( + 0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1)); + + TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateDotOp( + new_dim_numbers, *lhs->literal(), rhs->literal()))); + + // The new source dimension is wherever the non-batch non-contracting RHS + // dimension "went". + int64 new_source_dim = dim_numbers.lhs_batch_dimensions_size() + + dim_numbers.rhs_batch_dimensions_size() + 1; + + ConstantArray* new_source = Construct(literal_for_new_source); + return Construct( + new_source, rhs->indices(), new_source_dim, + ArraySliceToVector(rhs->output_dims()), shape); +} + +StatusOr IndexedArrayAnalysis::ComputeArrayForDot( + const Shape& shape, const DotDimensionNumbers& dim_numbers, Array* lhs, + Array* rhs) { + // Intuitively, if + // + // - The LHS of a dot product is a gathered sequence of rows from a constant + // array (i.e. LHS[I,J] = Const[Indices[I],J]) and the RHS is a constant + // + // OR + // + // - If the RHS of a dot product is a gathered sequence of columns from a + // constant array (i.e. RHS[I,J] = Const[I, Indices[J]]) and the LHS is a + // constant + // + // then the result of the dot product itself is a gather from a constant + // array. E.g. Dot(LHS, ConstRhs) where LHS[I,J] = Const[Indices[I],J] can be + // rewritten as Result where Result[I,J] = Dot(Const, ConstRhs)[Indices[I], + // J]. + // + // We do a general version of this rewrite here. + VLOG(3) << "ComputeArrayForDot(" << ToString(lhs) << " " << ToString(rhs); + if (auto* lhs_indexed_array = + dynamic_cast(lhs)) { + if (auto* rhs_constant = dynamic_cast(rhs)) { + return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers, + lhs_indexed_array, rhs_constant); + } + } + + if (auto* rhs_indexed_array = + dynamic_cast(rhs)) { + if (auto* lhs_constant = dynamic_cast(lhs)) { + return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, lhs_constant, + rhs_indexed_array); + } + } + + return nullptr; +} + tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const { return "indexed-array-analysis-printer-pass"; } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index ce92fd2919c90fa8a2fb7b796ed6f0fdaf48fe62..e923dc39f7f464a8d3c400294499a6f5efda3991 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -39,7 +39,13 @@ class IndexedArrayAnalysis { // Array instances are immutable once created. class Array { public: - enum Kind { kUnknown, kConstant, kScalarIndexedConstant, kScalarIndexed }; + enum Kind { + kUnknown, + kConstant, + kReshaped, + kScalarIndexedConstant, + kScalarIndexed + }; virtual Kind kind() const = 0; virtual const Shape& shape() const = 0; @@ -96,6 +102,27 @@ class IndexedArrayAnalysis { friend class IndexedArrayAnalysis; }; + // Represents an Array that is a reshape of another Array. + class ReshapedArray : public Array { + public: + Kind kind() const override { return kReshaped; } + + // The array to reshape. + Array* operand() const { return operand_; } + + // The output shape. + const Shape& shape() const override { return shape_; } + + private: + explicit ReshapedArray(Array* operand, Shape shape) + : operand_(operand), shape_(shape) {} + + Array* operand_; + const Shape shape_; + + friend class IndexedArrayAnalysis; + }; + // --------------------------------------------------------------------------- // Indexed Array Overview // --------------------------------------------------------------------------- @@ -241,6 +268,18 @@ class IndexedArrayAnalysis { tensorflow::gtl::ArraySlice window_bounds, Array* source, Array* indices); + StatusOr ComputeArrayForDotWithIndexedLhs( + const Shape& shape, const DotDimensionNumbers& dim_numbers, + ScalarIndexedConstantArray* lhs, ConstantArray* rhs); + + StatusOr ComputeArrayForDotWithIndexedRhs( + const Shape& shape, const DotDimensionNumbers& dim_numbers, + ConstantArray* lhs, ScalarIndexedConstantArray* rhs); + + StatusOr ComputeArrayForDot(const Shape& shape, + const DotDimensionNumbers& dim_numbers, + Array* lhs, Array* rhs); + // This tries to fold a ScalarIndexedArray which has another // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a // ScalarIndexedArray as indices. If `source` happened to be a @@ -266,6 +305,21 @@ class IndexedArrayAnalysis { ScalarIndexedArray* source, Array* indices, int64 source_dim, tensorflow::gtl::ArraySlice output_dims, Shape shape); + // Reshapes a scalar-indexed node to remove the degenerate dimensions in its + // output. The result is always a scalar-indexed node. + StatusOr ReshapeToRemoveDegenerateDims( + ScalarIndexedArray* operand); + + // Reshapes a scalar-indexed node such that the result has the degenerate + // dimensions `degenerate_dims`. The result is always a scalar-indexed node. + StatusOr ReshapeToAddDegenerateDims( + ScalarIndexedArray* operand, + tensorflow::gtl::ArraySlice degenerate_dims); + + StatusOr FoldReshapeOfGather( + const Shape& shape, ScalarIndexedConstantArray* operand); + StatusOr FoldReshapeOfGatherNoDegenerateDims( + const Shape& shape, ScalarIndexedConstantArray* scalar_indexed); StatusOr ComputeArrayForReshape(const Shape& shape, Array* operand); StatusOr ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 373556ebeba883f7dc2116bdf0ffc3274182f775..5f4b42799b1c26ea544f9d4447cc45b5ae9d5a48 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -34,6 +36,27 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase { } private: + // Replaces seqences of whitespace with a single space. This makes the + // strings being matched against "whitespace insensitive" which lets us indent + // them for readability. + string CanonicalizeWhitespace(const string& text) { + string result; + + for (char c : text) { + if (!isspace(c)) { + result.push_back(c); + } else if (!result.empty() && result.back() != ' ') { + result.push_back(' '); + } + } + + while (!result.empty() && result.back() == ' ') { + result.pop_back(); + } + + return result; + } + void AssertArrayForRootExpressionIsImpl(const string& hlo_text, const string& root_expression, bool print_constants) { @@ -44,10 +67,10 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase { IndexedArrayAnalysis::Array* const array_result, indexed_tensor_analysis.GetArrayFor( module().entry_computation()->root_instruction())); - string string_result = - indexed_tensor_analysis.ToString(array_result, print_constants); + string string_result = CanonicalizeWhitespace( + indexed_tensor_analysis.ToString(array_result, print_constants)); LOG(INFO) << string_result; - ASSERT_EQ(string_result, root_expression); + ASSERT_EQ(string_result, CanonicalizeWhitespace(root_expression)); } }; @@ -91,6 +114,82 @@ ENTRY main { hlo_text, "(scalar-indexed-const (constant s32[3,3]) %indices 0->[0])"); } +TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed0) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + indices = s32[5,2] parameter(0) + ROOT gather = s32[5] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%gather"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed1) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3,1] parameter(0) + indices = s32[5] parameter(1) + ROOT gather = s32[5,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,2}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3,1} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%gather"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed2) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3,1] parameter(0) + indices = s32[5] parameter(1) + ROOT gather = s32[5,2,3] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={2}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={2,3,1} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%gather"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed3) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[5] parameter(1) + ROOT gather = s32[5,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,2} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%gather"); +} + TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) { string hlo_text = R"( HloModule SimpleGather @@ -273,7 +372,157 @@ ENTRY main { "(scalar-indexed-const (constant s32[3,3,4]) %indices 0->[0,3])"); } -TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative0) { +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather3) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[2,6] constant(s32[2,6]{ + {1,2,3,4,5,6},{1,2,3,4,5,6}}) + indices = s32[1] parameter(0) + gather = s32[1,6] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,6} + ROOT reshape = s32[1,1,6] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[2,1,1,6]) + (reshape %indices to s32[]) + 0->[]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather4) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } }) + + i.0 = s64[1,3]{1,0} parameter(0) + g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), output_window_dims={2}, + elided_window_dims={0}, gather_dims_to_operand_dims={0}, + index_vector_dim=2, window_bounds={1,3} + + i.1 = s64[1] parameter(1) + g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), output_window_dims={0,2}, + elided_window_dims={1}, gather_dims_to_operand_dims={1}, + index_vector_dim=1, window_bounds={1,1,3} + + ROOT reshape = s32[1,3]{1,0} reshape(g.1) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[2,1,3]) + (reshape + (scalar-indexed %i.0 %i.1 1->[1]) + to s64[]) + 0->[]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather5) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}}) + indices = s32[1] parameter(0) + gather = s32[1,6] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,6} + ROOT reshape = s32[1,1,6] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[1,1,1,6]) + (reshape %indices to s32[]) + 0->[]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather6) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[1,2,6] constant(s32[1,2,6]{{ + {1,2,3,4,5,6},{1,2,3,4,5,6}}}) + indices = s32[1] parameter(0) + gather = s32[1,1,6] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={1,1,6} + ROOT reshape = s32[1,1,1,6] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[2,1,1,1,6] s32[2,1,1,1,6] { + { /*i0=0*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } }, + { /*i0=1*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } } }) + (reshape %indices to s32[]) + 0->[]) +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, + expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather7) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[2,6] constant(s32[2,6]{ + {1,2,3,4,5,6},{1,2,3,4,5,6}}) + indices = s32[1,5] parameter(0) + gather = s32[1,5,6] gather(operand, indices), + output_window_dims={2}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,6} + ROOT reshape = s32[1,1,5,6] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[2,1,1,6] s32[2,1,1,6] { + { /*i0=0*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } }, + { /*i0=1*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } } }) + (reshape %indices to s32[5]) + 0->[2]) +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, + expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold0) { string hlo_text = R"( HloModule ReshapeOfGather @@ -290,10 +539,19 @@ ENTRY main { } )"; - AssertArrayForRootExpressionIs(hlo_text, "%reshape"); + const char* expected_root_expression = R"( +(reshape + (scalar-indexed-const + (constant s32[3,4]) + %indices + 0->[0,2]) + to s32[5,2,2,2,3]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); } -TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative1) { +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold1) { string hlo_text = R"( HloModule ReshapeOfGather @@ -313,7 +571,48 @@ ENTRY main { } )"; - AssertArrayForRootExpressionIs(hlo_text, "%reshape"); + const char* expected_root_expression = R"( +(reshape + (scalar-indexed-const + (constant s32[3,5,2]) + %indices + 1->[2]) + to s32[6,7]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold2) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4,1] constant(s32[3,4,1]{ + {{1},{2},{3},{4}}, + {{1},{2},{3},{4}}, + {{1},{2},{3},{4}}}) + indices = s32[5,6] parameter(0) + gather = s32[5,4,6,1] gather(operand, indices), + output_window_dims={1,3}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,4,1} + ROOT reshape = s32[5,2,2,2,3,1] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(reshape + (scalar-indexed-const + (constant s32[3,4,1]) + %indices + 0->[0,2]) + to s32[5,2,2,2,3,1]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); } TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) { @@ -500,5 +799,170 @@ ENTRY main { AssertArrayForRootExpressionIs(hlo_text, "%add"); } +TEST_F(IndexedArrayAnalysisTest, DotOpBasic_0) { + string hlo_text = R"( +HloModule DotOp + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + indices = s32[5] parameter(0) + dot_lhs = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"( +(scalar-indexed-const + (constant s32[3,3] s32[3,3] { + { 70, 80, 90 }, + { 158, 184, 210 }, + { 246, 288, 330 } }) + %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, DotOpBasic_1) { + string hlo_text = R"( +HloModule DotOp + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}}) + indices = s32[5] parameter(0) + dot_lhs = s32[3,5] gather(gather_operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3,1} + ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"( +(scalar-indexed-const + (constant s32[4,3] s32[4,3] { + { 84, 99, 114 }, + { 96, 114, 132 }, + { 108, 129, 150 }, + { 120, 144, 168 } }) + %indices 0->[1]))"); +} + +TEST_F(IndexedArrayAnalysisTest, DotOpBasic_2) { + string hlo_text = R"( +HloModule DotOp + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + indices = s32[5] parameter(0) + dot_rhs = s32[3,5] gather(gather_operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3,1} + ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"( +(scalar-indexed-const + (constant s32[4,4] s32[4,4] { + { 38, 44, 50, 56 }, + { 83, 98, 113, 128 }, + { 128, 152, 176, 200 }, + { 173, 206, 239, 272 } }) + %indices 1->[1]) +)"); +} + +TEST_F(IndexedArrayAnalysisTest, DotOpBasic_3) { + string hlo_text = R"( +HloModule DotOp + +ENTRY main { + gather_operand = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + indices = s32[5] parameter(0) + dot_rhs = s32[5,3] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} + ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1} +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"( +(scalar-indexed-const + (constant s32[4,4] s32[4,4] { + { 14, 32, 50, 68 }, + { 32, 77, 122, 167 }, + { 50, 122, 194, 266 }, + { 68, 167, 266, 365 } }) + %indices 1->[0]) +)"); +} + +TEST_F(IndexedArrayAnalysisTest, DotOpWithBatch) { + string hlo_text = R"( +HloModule DotOp + +ENTRY main { + gather_operand = s32[2,3,2] constant(s32[2,3,2]{{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}}) + dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}}) + indices = s32[4] parameter(0) + dot_rhs = s32[2,3,4] gather(gather_operand, indices), + output_window_dims={0,1}, + elided_window_dims={2}, + gather_dims_to_operand_dims={2}, + index_vector_dim=1, + window_bounds={2,3,1} + ROOT dot = s32[2,2,4] dot(dot_lhs_constant, dot_rhs), + lhs_contracting_dims={2}, rhs_contracting_dims={1}, + lhs_batch_dims={0}, rhs_batch_dims={0} +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"( +(scalar-indexed-const + (constant s32[2,2,2] s32[2,2,2] { + { { 22, 28 }, + { 49, 64 } }, + { { 220, 244 }, + { 301, 334 } } }) + %indices 3->[2]) +)"); +} + +TEST_F(IndexedArrayAnalysisTest, DotOpNegative) { + string hlo_text = R"( +HloModule DotOp + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}}) + indices = s32[2] parameter(0) + dot_lhs = s32[3,2] gather(gather_operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3,1} + ROOT dot = s32[3,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, "%dot"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index d2af261008f40ee83e0676cfc7e67c45f8be1844..32937b33b3737482f07d4c7607f7f1c5c183a56b 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -51,10 +51,10 @@ TEST_F(InlinerTest, MapMax) { auto max_f32 = max_builder.Build(); auto builder = HloComputation::Builder("MapMaxFunction"); - auto lhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); - auto rhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({4, 3, 2, 1}))); + auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); + auto rhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({4, 3, 2, 1}))); builder.AddInstruction( HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); @@ -70,7 +70,7 @@ TEST_F(InlinerTest, MapMax) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - auto expected = Literal::CreateR1({4, 3, 3, 4}); + auto expected = LiteralUtil::CreateR1({4, 3, 3, 4}); EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } @@ -83,12 +83,12 @@ TEST_F(InlinerTest, MapConstant) { HloInstruction::CreateParameter(0, r0f32, "x")); (void)param1; const2_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); auto const2_f32 = const2_builder.Build(); auto builder = HloComputation::Builder("MapConstFunction"); auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}}))); + LiteralUtil::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}}))); builder.AddInstruction( HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get())); @@ -104,7 +104,7 @@ TEST_F(InlinerTest, MapConstant) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - auto expected = Literal::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); + auto expected = LiteralUtil::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } @@ -123,10 +123,10 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { auto max_f32 = max_builder.Build(); auto builder = HloComputation::Builder("MapSubFunction"); - auto lhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); - auto rhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({4, 3, 2, 1}))); + auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); + auto rhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({4, 3, 2, 1}))); builder.AddInstruction( HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); @@ -142,7 +142,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - auto expected = Literal::CreateR1({3, 1, -1, -3}); + auto expected = LiteralUtil::CreateR1({3, 1, -1, -3}); EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index abedb4063d3763516e66cff36633dbd90c8cafde..f33942d67907d8f40811bde5041350a2e1e1f1fc 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -73,6 +73,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kGt: case HloOpcode::kImag: case HloOpcode::kInfeed: + case HloOpcode::kIota: case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLt: @@ -83,6 +84,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kOr: + case HloOpcode::kXor: case HloOpcode::kOutfeed: case HloOpcode::kPad: case HloOpcode::kReal: @@ -96,9 +98,10 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kShiftRightLogical: case HloOpcode::kSlice: case HloOpcode::kSubtract: - case HloOpcode::kGenerateToken: + case HloOpcode::kAfterAll: case HloOpcode::kTranspose: case HloOpcode::kTuple: + case HloOpcode::kTupleSelect: return false; // Cheap instructions for reals, but expensive for complex. @@ -117,6 +120,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kConditional: case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: case HloOpcode::kCustomCall: case HloOpcode::kDivide: case HloOpcode::kDomain: @@ -138,6 +142,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kReduceWindow: case HloOpcode::kRemainder: case HloOpcode::kRng: + case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: case HloOpcode::kSend: case HloOpcode::kSendDone: @@ -237,6 +242,30 @@ InstructionFusion::ComputeGloballyUnfusable( if (EffectivelyAtMostUnary(producer)) { continue; } + + // If the total size of the inputs is less than or equal to the total size + // of the outputs for the producer then duplicating it won't increase the + // memory traffic. In that case, we do not forbid fusion of the operation + // here. + auto total_size = [](const Shape& shape) { + int64 size = 0; + ShapeUtil::ForEachSubshape( + shape, + [&size](const Shape& subshape, const ShapeIndex& shape_index) { + if (ShapeUtil::IsArray(subshape)) { + size += ShapeUtil::ElementsIn(subshape); + } + }); + return size; + }; + int64 operands_size = 0; + for (const HloInstruction* op : producer->operands()) { + operands_size += total_size(op->shape()); + } + if (operands_size <= total_size(producer->shape())) { + continue; + } + // Otherwise we will forbid fusing the op unless we can fuse it into // all of its consumers on all paths. // @@ -281,10 +310,8 @@ StatusOr InstructionFusion::Run(HloModule* module) { // map from HloInstruction* to the instruction's index in the vector. An // instruction is "removed" from the vector by setting it's element to // nullptr. - std::list post_order_list = + std::vector post_order = computation_->MakeInstructionPostOrder(); - std::vector post_order(post_order_list.begin(), - post_order_list.end()); tensorflow::gtl::FlatMap post_order_index; for (size_t i = 0; i < post_order.size(); ++i) { diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 21db2338995960bde00ec9c4b325e5562fc3a592..9e7a15f0330d3f06779c850a4b575f84fe0b9505 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -167,7 +167,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1")); HloInstruction* binary1 = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); - builder.AddInstruction(HloInstruction::CreateSend(binary1, 0)); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); + builder.AddInstruction(HloInstruction::CreateSend(binary1, token, 0)); HloInstruction* unary = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); @@ -258,7 +259,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { add = f32[4,3]{1,0} add(p0, p0) abs1 = f32[4,3]{1,0} abs(add) log = f32[4,3]{1,0} log(abs1) - send = f32[4,3]{1,0} send(log), channel_id=0 + token = token[] after-all() + send = f32[4,3]{1,0} send(log, token), channel_id=0 abs2 = f32[4,3]{1,0} abs(log) ROOT root = f32[4,3]{1,0} subtract(abs2, add) })") @@ -288,7 +290,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { p0 = f32[4,3]{1,0} parameter(0) add1 = f32[4,3]{1,0} add(p0, p0) log = f32[4,3]{1,0} log(p0) - send = f32[4,3]{1,0} send(log), channel_id=0 + token = token[] after-all() + send = f32[4,3]{1,0} send(log, token), channel_id=0 add2 = f32[4,3]{1,0} add(log, add1) ROOT root = f32[4,3]{1,0} subtract(add1, add2) })") @@ -321,7 +324,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { add1 = f32[4,3]{1,0} add(p0, p0) add2 = f32[4,3]{1,0} add(add1, add1) log = f32[4,3]{1,0} log(add2) - send = f32[4,3]{1,0} send(log), channel_id=0 + token = token[] after-all() + send = f32[4,3]{1,0} send(log, token), channel_id=0 sub1 = f32[4,3]{1,0} subtract(log, add2) sub2 = f32[4,3]{1,0} subtract(add2, add1) ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2) @@ -352,7 +356,8 @@ TEST_F(InstructionFusionTest, AllowUnaryDuplication) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0")); HloInstruction* unary1 = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0)); - builder.AddInstruction(HloInstruction::CreateSend(unary1, 0)); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); + builder.AddInstruction(HloInstruction::CreateSend(unary1, token, 0)); HloInstruction* unary2 = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1)); @@ -375,7 +380,8 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1")); HloInstruction* binary1 = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); - builder.AddInstruction(HloInstruction::CreateSend(binary1, 0)); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); + builder.AddInstruction(HloInstruction::CreateSend(binary1, token, 0)); HloInstruction* unary = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 524d3234eb4eff9c7d000eca1a0d9f5c4fae90af..8652599dc6d48ff8c2aaa703fead161f891a57d1 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -74,7 +74,7 @@ cc_library( hdrs = ["executable.h"], deps = [ ":executor", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index c1666530687f2f8407a9dcb4e271c9d95552a689..9f8f4bda875cdff5e20fa8ca8eeecaa1140e2b9c 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -44,7 +44,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); pipeline.AddPass( - hlo_module->mutable_device_entry_computation_layout()); + hlo_module->mutable_entry_computation_layout()); return pipeline.Run(hlo_module).status(); } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 029e71058a7373b9310c6d9ffdb65f72ca28e5af..8d40c08d555a232b7cf3b81cc0f9970804c2f896 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -75,9 +75,9 @@ StatusOr InterpreterExecutable::ExecuteOnStream( // consumes. std::vector> arg_literals; for (int64 p = 0; p < computation->num_parameters(); ++p) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr arg_literal, - transfer_manager->TransferLiteralFromDevice(executor, *arguments[p])); + TF_ASSIGN_OR_RETURN(std::unique_ptr arg_literal, + transfer_manager->TransferLiteralFromDevice( + run_options->stream(), *arguments[p])); arg_literals.push_back(std::move(arg_literal)); } @@ -96,7 +96,7 @@ StatusOr InterpreterExecutable::ExecuteOnStream( result_literal->shape(), run_options->allocator(), executor->device_ordinal())); TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( - executor, *result_literal, result)); + run_options->stream(), *result_literal, result)); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc index 97e9fa2c8e8ecd918ffe3df2fd4e731f3b91e6db..4fb67bd0b72fc591c1ffa76ebb0513bf14ed3737 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.cc +++ b/tensorflow/compiler/xla/service/interpreter/executor.cc @@ -53,6 +53,7 @@ bool XlaInterpreterExecutor::Memcpy(Stream *stream, void *host_dst, AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() { port::Status ok = SynchronousMemcpy(host_dst, dev_src, size); }); + AsExecutorStream(stream)->BlockUntilDone(); return true; } @@ -61,6 +62,7 @@ bool XlaInterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() { port::Status ok = SynchronousMemcpy(dev_dst, host_src, size); }); + AsExecutorStream(stream)->BlockUntilDone(); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index 9b109022fbfc698f7dadc678ef837da270a5e74a..db6b910b32f8ec234c4cf1c331a1aa3bb2f9389f 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -104,7 +104,7 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { } // No "synchronize all activity" implemented for this platform at the moment. - bool SynchronizeAllActivity() override { return false; } + bool SynchronizeAllActivity() override { return true; } bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override { return false; } diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 7067b6f86a0fb24fb946ad236bca9bbd48d53722..805fdb2d5bd8a08490b354d60f281c8f99bc20d8 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -30,10 +30,12 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" @@ -59,7 +61,6 @@ namespace xla { // anonymous namespace, instead of three or four spread all over this file. namespace { - } // namespace std::ostream& operator<<(std::ostream& out, @@ -113,14 +114,18 @@ LayoutConstraints::LayoutConstraints( HloComputation* computation) : points_to_analysis_(points_to_analysis), computation_(computation) { // Gather all array-shaped logical buffers into unconstrained_buffer_ids. - for (LogicalBuffer::Id id = 0; id < points_to_analysis_.num_logical_buffers(); - id++) { - auto& buffer = points_to_analysis_.logical_buffer(id); - // The points to analysis is computed per module, restrict constraints to - // array buffers in this computation. - if (buffer.IsArray() && buffer.instruction()->parent() == computation) { - unconstrained_buffer_ids_.insert(buffer.id()); - } + for (HloInstruction* inst : computation_->instructions()) { + points_to_analysis_.GetPointsToSet(inst).ForEachElement( + [&](const ShapeIndex&, const PointsToSet::BufferList& buffers) { + for (const LogicalBuffer* buffer : buffers) { + // The points to analysis is computed per module, restrict + // constraints to array buffers in this computation. + if (buffer->IsArray() && + buffer->instruction()->parent() == computation) { + unconstrained_buffer_ids_.insert(buffer->id()); + } + } + }); } } @@ -175,41 +180,32 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, TF_RETURN_IF_ERROR( LayoutUtil::ValidateLayoutForShape(layout, buffer.shape())); - const BufferLayoutConstraint* curr_constraint = - GetBufferLayoutConstraint(buffer); - if (curr_constraint != nullptr) { - if (LayoutUtil::Equal(curr_constraint->layout(), layout)) { + auto iter = buffer_constraints_.find(&buffer); + if (iter != buffer_constraints_.end()) { + const BufferLayoutConstraint& curr_constraint = iter->second; + if (LayoutUtil::Equal(curr_constraint.layout(), layout)) { // New constraint matches existing constraint. Nothing to do. return Status::OK(); } - if (curr_constraint->mandatory()) { + if (curr_constraint.mandatory()) { return FailedPrecondition( "Buffer %s already has the layout constraint %s, cannot add " "incompatible constraint %s", buffer.ToString().c_str(), - LayoutUtil::HumanString(curr_constraint->layout()).c_str(), + LayoutUtil::HumanString(curr_constraint.layout()).c_str(), LayoutUtil::HumanString(layout).c_str()); } - } - - auto iter = buffer_constraints_.find(&buffer); - bool overwrite = iter != buffer_constraints_.end(); - if (!overwrite) { + iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs); + } else { + TF_RET_CHECK(unconstrained_buffer_ids_.erase(buffer.id()) == 1) + << buffer.ToString(); iter = buffer_constraints_ .insert(std::make_pair( &buffer, BufferLayoutConstraint(layout, buffer, mandatory, dfs))) .first; - } else { - iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs); } added_constraints_.push_back(&iter->second); - - // Remove buffer from the set of unconstrained buffers. - TF_RET_CHECK(unconstrained_buffer_ids_.count(buffer.id()) == - static_cast(!overwrite)); - unconstrained_buffer_ids_.erase(buffer.id()); - return Status::OK(); } @@ -401,6 +397,43 @@ string LayoutConstraints::ToString() const { return output; } +namespace { + +bool IsHostSendRecv(const HloInstruction* instruction) { + const HloSendRecvInstruction* send_recv_instr = + DynCast(instruction); + return send_recv_instr != nullptr && send_recv_instr->is_host_transfer(); +} + +} // namespace + +Status LayoutAssignment::BuildHostChannelConstraints( + HloComputation* computation) { + for (auto* instruction : computation->instructions()) { + const HloSendRecvInstruction* send_recv_instr = + DynCast(instruction); + if (send_recv_instr == nullptr || !send_recv_instr->is_host_transfer()) { + continue; + } + + // For host transfers the Send and Recv instruction carry the layout. + if (instruction->opcode() == HloOpcode::kSend || + instruction->opcode() == HloOpcode::kRecv) { + const Shape& data_shape = + ShapeUtil::GetTupleElementShape(send_recv_instr->shape(), 0); + TF_RET_CHECK(ShapeUtil::IsArray(data_shape)); + TF_RET_CHECK(LayoutUtil::HasLayout(data_shape)); + const Layout* prev_layout = host_channel_constraints_.ConstrainChannel( + send_recv_instr->channel_id(), data_shape.layout()); + TF_RET_CHECK(prev_layout == nullptr) + << "Cannot constrain host transfer layout as it was set to " + << LayoutUtil::HumanString(*prev_layout) << ": " + << send_recv_instr->ToString(); + } + } + return Status::OK(); +} + Status LayoutAssignment::AddMandatoryConstraints( const ComputationLayout* computation_layout, ChannelLayoutConstraints* channel_constraints, HloComputation* computation, @@ -408,6 +441,11 @@ Status LayoutAssignment::AddMandatoryConstraints( VLOG(3) << "Adding mandatory layout constraints to computation " << computation->name(); + auto get_channel_constraints = [&](const HloInstruction* instruction) { + return IsHostSendRecv(instruction) ? &host_channel_constraints_ + : channel_constraints; + }; + // Constrain layouts of instructions which define values with pre-existing // layouts. for (auto* instruction : computation->instructions()) { @@ -444,18 +482,21 @@ Status LayoutAssignment::AddMandatoryConstraints( if (instruction->opcode() == HloOpcode::kSend || instruction->opcode() == HloOpcode::kRecv) { - CHECK(channel_constraints) + CHECK(get_channel_constraints(instruction)) << "Multi-module layout assignment requires ChannelLayoutConstraints"; int64 channel_id = instruction->channel_id(); - if (!channel_constraints->IsChannelConstrained(channel_id)) { + if (!get_channel_constraints(instruction) + ->IsChannelConstrained(channel_id)) { continue; } if (instruction->opcode() == HloOpcode::kSend) { // TODO(b/68493863): Change to use SetOperandLayout(). const Shape send_buffer_shape = instruction->operand(0)->shape(); TF_RET_CHECK(ShapeUtil::IsArray(send_buffer_shape)); - Shape new_buffer_shape = channel_constraints->LayoutShapeForChannel( - send_buffer_shape, instruction->channel_id()); + Shape new_buffer_shape = + get_channel_constraints(instruction) + ->LayoutShapeForChannel(send_buffer_shape, + instruction->channel_id()); TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( new_buffer_shape, instruction->operand(0))); } else { @@ -466,8 +507,9 @@ Status LayoutAssignment::AddMandatoryConstraints( const LogicalBuffer* buffer, constraints->points_to_analysis().GetBufferDefinedAt(instruction, {0})); - Shape new_shape = channel_constraints->LayoutShapeForChannel( - recv_buffer_shape, instruction->channel_id()); + Shape new_shape = get_channel_constraints(instruction) + ->LayoutShapeForChannel( + recv_buffer_shape, instruction->channel_id()); TF_RETURN_IF_ERROR( constraints->SetBufferLayout(new_shape.layout(), *buffer)); } @@ -716,7 +758,8 @@ Status CheckParameterLayout(HloInstruction* parameter, const ComputationLayout& computation_layout) { const ShapeLayout& parameter_layout = computation_layout.parameter_layout(parameter->parameter_number()); - if (!parameter_layout.MatchesLayoutInShape(parameter->shape())) { + if (parameter_layout.LayoutIsSet() && + !parameter_layout.MatchesLayoutInShape(parameter->shape())) { return InternalError( "parameter instruction %s does not match layout of computation " "shape: %s", @@ -831,8 +874,8 @@ void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction, // HostCompute module. // Otherwise it is preferable to leave the new instruction without device, // and let the automatic device placer to choose the best location. - if (!sharding.HasUniqueDevice() || - HloSharding::IsReservedDevice(sharding.UniqueDevice().ValueOrDie())) { + auto device = sharding.UniqueDevice(); + if (!device || HloSharding::IsReservedDevice(*device)) { copy->set_sharding(sharding); } } @@ -936,14 +979,15 @@ LayoutAssignment::LayoutAssignment( ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints) : entry_computation_layout_(entry_computation_layout), + saved_entry_computation_layout_(*entry_computation_layout), channel_layout_constraints_(channel_constraints) { + if (channel_layout_constraints_ != nullptr) { + // Save a copy of the input ChannelLayoutConstraints so that we can reset it + // if we have to undo previous operations (ClearPreviousPassSideEffects()). + channel_constraints_ = *channel_layout_constraints_; + } VLOG(1) << "Entry computation layout given to layout assignment: " << entry_computation_layout_->ToString(); - // Layouts of all parameter instructions must be set. - for (const ShapeLayout& parameter_layout : - entry_computation_layout_->parameter_layouts()) { - CHECK(parameter_layout.LayoutIsSet()); - } } std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( @@ -1184,7 +1228,7 @@ Status LayoutAssignment::PropagateUseConstraintToDefs( const PointsToSet& points_to_set = constraints->points_to_analysis().GetPointsToSet(instruction); return points_to_set.ForEachElementWithStatus( - [this, &shape_layout, constraints]( + [&shape_layout, constraints]( const ShapeIndex& index, const PointsToSet::BufferList& buffers) -> Status { if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) { @@ -1519,7 +1563,7 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { // and the computation result. The latter two are specified in // computation_layout, so we only need to keep the existing layouts for // infeeds. Clearing the layouts here avoids hiding potential bugs in the - // layout assignment pass that may accidently use the existing layout. + // layout assignment pass that may accidentally use the existing layout. for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kBitcast) { // bitcasts are inherently layout sensitive and so a bitcast instruction @@ -1542,6 +1586,10 @@ Status LayoutAssignment::RunOnComputation( ChannelLayoutConstraints* channel_constraints) { VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name() << ")"; + + // Must be run before clearing layouts. + TF_RETURN_IF_ERROR(BuildHostChannelConstraints(computation)); + TF_RETURN_IF_ERROR(ClearComputationLayouts(computation)); if (computation_layout != nullptr) { auto it = computation_layouts_.find(computation); @@ -1572,6 +1620,13 @@ Status LayoutAssignment::RunOnComputation( // Propagates layouts from mandatory and backend constraints. TF_RETURN_IF_ERROR(PropagateConstraints(&constraints)); + // Prior to applying default layouts, we take note of all HLO instructions + // which lack a layout constraint. + for (LogicalBuffer::Id buffer_id : constraints.unconstrained_buffer_ids()) { + unconstrained_layout_instructions_.insert( + points_to_analysis.GetBuffer(buffer_id).instruction()); + } + // While any unconstrained buffers remain, pick an arbitrary buffer, give it a // layout and propagate the change. while (!constraints.unconstrained_buffer_ids().empty()) { @@ -1614,13 +1669,65 @@ Status LayoutAssignment::RunOnComputation( // Record the layouts assigned for any communication ops in // channel_constraints so that they are constrained for future modules. + if (channel_constraints != nullptr) { + TF_RETURN_IF_ERROR( + ConstrainChannelLayouts(computation, channel_constraints)); + } + return Status::OK(); +} + +Status LayoutAssignment::ConstrainChannelLayouts( + HloComputation* computation, + ChannelLayoutConstraints* channel_constraints) { + auto get_channel_constraints = [&](const HloInstruction* instruction) { + return IsHostSendRecv(instruction) ? &host_channel_constraints_ + : channel_constraints; + }; + // We go through the kRecvDone before. These must either impose their layout, + // or find a matching one already existing (ConstrainChannel() returns + // nullptr). for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kRecvDone) { + const Layout* layout = + get_channel_constraints(instruction) + ->ConstrainChannel( + instruction->channel_id(), + ShapeUtil::GetSubshape(instruction->shape(), {0}).layout()); + TF_RET_CHECK(layout == nullptr) + << instruction->ToString() + << " cannot constrain layout as it was set to " + << LayoutUtil::HumanString(*layout); + } + } + // After that we go through the kSend. These are likely going to have a kCopy + // as operand (otherwise we add it), so in case the constrained layout does + // not match, we can change the kCopy layout (and the kSend one as well). + for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { if (instruction->opcode() == HloOpcode::kSend) { - channel_constraints->ConstrainChannel( - instruction->channel_id(), instruction->operand(0)->shape().layout()); - } else if (instruction->opcode() == HloOpcode::kRecvDone) { - channel_constraints->ConstrainChannel(instruction->channel_id(), - instruction->shape().layout()); + HloInstruction* operand = instruction->mutable_operand(0); + const Layout* layout = get_channel_constraints(instruction) + ->ConstrainChannel(instruction->channel_id(), + operand->shape().layout()); + if (layout != nullptr) { + // We found an already constrained layout which does not match the one + // the kSend wants to impose. Either add a new kCopy, or use the + // existing one to marshal the correct shape. + Shape shape = operand->shape(); + *shape.mutable_layout() = *layout; + if (operand->opcode() != HloOpcode::kCopy) { + HloInstruction* copy = operand->parent()->AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand)); + RegisterAddedCopy(copy); + SetupCopiedInstruction(*operand, copy, {}); + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy)); + operand = copy; + } else { + *operand->mutable_shape() = shape; + } + Shape* send_shape = + ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0}); + *send_shape = shape; + } } } return Status::OK(); @@ -1672,13 +1779,14 @@ StatusOr LayoutAssignment::Run(HloModule* module) { // when seen from an outer instruction, which has across-computation // constraints to impose. // For example, the kWhile instruction needs to enforce the same layouts for - // the parameters and root of the bosy, as well as the condition parameters. + // the parameters and root of the body, as well as the condition parameters. // Similarly, the kConditional instruction needs to enforce the same layouts // for the root of the true and false computations. // So in the first pass, while allowing the layouts to flow to parameters and // root, we also fix up the eventually inconsistent ComputationLayout, which // will be then made mandatory by the second pass. for (int64 i = 0; i < 2; ++i) { + VLOG(5) << "Running " << (i == 0 ? "un" : "") << "constrained pass"; TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module)); TF_ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(module)); @@ -1716,10 +1824,12 @@ StatusOr LayoutAssignment::Run(HloModule* module) { Status LayoutAssignment::Init() { computation_layouts_.clear(); + *entry_computation_layout_ = saved_entry_computation_layout_; return Status::OK(); } Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) { + VLOG(5) << "Clearing previous side effects"; // Clear all the copies which have been added, and all the related // instructions (like GTE and tuples). int64 removed_copies = 0; @@ -1737,12 +1847,14 @@ Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) { } } added_copies_.clear(); + unconstrained_layout_instructions_.clear(); if (removed_copies > 0) { TupleSimplifier tuple_simplifier; HloDCE dce; TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); TF_RETURN_IF_ERROR(dce.Run(module).status()); } + ResetChannelConstraints(); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index c287cca0c54ba1bb514bd8d243c137eca99b258f..f9e8dbea2f8aa224318adf3cf4b5e493792d3093 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -249,25 +249,30 @@ class ChannelLayoutConstraints { // Given `shape`, apply the layout for `channel_id`. `channel_id` must already // be constrained. Shape LayoutShapeForChannel(Shape shape, int64 channel_id) const { - CHECK(IsChannelConstrained(channel_id)); - *shape.mutable_layout() = constraints_.at(channel_id); + auto it = constraints_.find(channel_id); + CHECK(it != constraints_.end()) << "Channel " << channel_id; + *shape.mutable_layout() = it->second; return shape; } // Returns the layout constraint for `channel_id`, which must already be // constrained. - Layout LayoutForChannel(int64 channel_id) const { - CHECK(IsChannelConstrained(channel_id)); - return constraints_.at(channel_id); + const Layout& LayoutForChannel(int64 channel_id) const { + auto it = constraints_.find(channel_id); + CHECK(it != constraints_.end()) << "Channel " << channel_id; + return it->second; } // Adds a new layout constraint for `channel_id`. If a constraint for - // `channel_id` already exists, this operation requires that the new layout is - // the same as the previously constrained layout. - void ConstrainChannel(int64 channel_id, const Layout& layout) { - CHECK(!IsChannelConstrained(channel_id) || - LayoutUtil::Equal(layout, constraints_[channel_id])); - constraints_[channel_id] = layout; + // `channel_id` has been added, this API returns nullptr, otherwise returns + // the layout which has already been set for the channel. + const Layout* ConstrainChannel(int64 channel_id, const Layout& layout) { + auto it = constraints_.emplace(std::make_pair(channel_id, layout)); + if (it.second) { + return nullptr; + } + return LayoutUtil::Equal(layout, it.first->second) ? nullptr + : &it.first->second; } private: @@ -427,8 +432,13 @@ class LayoutAssignment : public HloPassInterface { Status PropagateComputationLayouts(HloComputation* computation, ComputationLayout* computation_layout); + // The pointer to the ComputationLayout passed as constructor parameter. ComputationLayout* entry_computation_layout_; + // A copy of entry_computation_layout_ used to reset it to the initial values + // during the multiple passes done by the layout assignment operation. + ComputationLayout saved_entry_computation_layout_; + protected: // Sets up the copy instruction according to the characteristic (sharding, // metadata, ...) of the reference instruction. The index argument is used @@ -464,6 +474,23 @@ class LayoutAssignment : public HloPassInterface { // itself). Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number); + // Apply the channel layout constraints by populating the channel_constraints + // data structure passed in at constructor time. Eventually adds copies in + // case two ends of a channel ended up with a different leyout. + Status ConstrainChannelLayouts(HloComputation* computation, + ChannelLayoutConstraints* channel_constraints); + + // Resets the input ChannelLayoutConstraints to the original copy received + // from the constructor input. + void ResetChannelConstraints() { + if (channel_layout_constraints_ != nullptr) { + *channel_layout_constraints_ = channel_constraints_; + } + } + + // Adds constraints related to host Send/Recv instructions. + Status BuildHostChannelConstraints(HloComputation* computation); + // Map containing the layouts of all computations assigned so // far. Computations are handled in a topological sort where computations are // handled before their caller instructions so the layouts of caller @@ -474,7 +501,23 @@ class LayoutAssignment : public HloPassInterface { // here. tensorflow::gtl::FlatSet added_copies_; - ChannelLayoutConstraints* channel_layout_constraints_; + // The pointer to the channel layout constraints passed in with the + // constructor. If not nullptr, this is an input/output argument. + ChannelLayoutConstraints* channel_layout_constraints_ = nullptr; + + // A copy of the input layout constraints used to reset the above pointer in + // case we have to undo operations due to the multiple passes over the + // computations/instructions. + ChannelLayoutConstraints channel_constraints_; + + // Layout constraints for send/recv instructions which communicate with the + // host. + ChannelLayoutConstraints host_channel_constraints_; + + // The set of HLO instructions which lacked any layout constraint, thus + // receiving propagated default layouts. + tensorflow::gtl::FlatSet + unconstrained_layout_instructions_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index bf0448a67674f24591d866b646b98aea09ebb12c..a16fa75e3032cfa4257d9b5608dd176fdb4ddbdb 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -52,10 +52,18 @@ using ::testing::ElementsAre; class LayoutAssignmentTest : public HloTestBase { protected: void AssignLayouts(HloModule* module, - ComputationLayout* entry_computation_layout) { - LayoutAssignment layout_assignment(entry_computation_layout); + ComputationLayout* entry_computation_layout, + ChannelLayoutConstraints* channel_constraints = nullptr) { + LayoutAssignment layout_assignment( + entry_computation_layout, /*channel_constraints=*/channel_constraints); EXPECT_IS_OK(layout_assignment.Run(module).status()); } + + std::vector LayoutOf(HloModule* module, tensorflow::StringPiece name) { + auto minor_to_major = + FindInstruction(module, name)->shape().layout().minor_to_major(); + return std::vector(minor_to_major.begin(), minor_to_major.end()); + } }; TEST_F(LayoutAssignmentTest, ComputationLayout) { @@ -133,9 +141,9 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { std::vector> minor_to_majors = {{0, 1}, {1, 0}}; for (auto& minor_to_major : minor_to_majors) { auto builder = HloComputation::Builder(TestName()); - auto constant_literal1 = Literal::CreateR2WithLayout( + auto constant_literal1 = LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major)); - auto constant_literal2 = Literal::CreateR2WithLayout( + auto constant_literal2 = LiteralUtil::CreateR2WithLayout( {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major)); Shape ashape = constant_literal1->shape(); @@ -184,10 +192,10 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { // match their source). auto builder = HloComputation::Builder(TestName()); auto constant0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant0, constant1})); @@ -221,10 +229,10 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { // Verify layouts of a select with tuple operands is assigned properly. auto builder = HloComputation::Builder(TestName()); auto constant0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto tuple0 = builder.AddInstruction( HloInstruction::CreateTuple({constant0, constant1})); @@ -232,7 +240,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { HloInstruction::CreateTuple({constant0, constant1})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); @@ -266,7 +274,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { // tuple and assigning the layouts of the copied arrays as needed. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto inner_tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant})); auto nested_tuple = builder.AddInstruction( @@ -576,7 +584,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { auto builder = HloComputation::Builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7}); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); auto broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(input_shape, constant, {})); auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -707,17 +715,10 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { LayoutUtil::MakeLayout({2, 1, 0})); AssignLayouts(module.get(), &computation_layout); - auto layout_of = [&](tensorflow::StringPiece name) { - return FindInstruction(module.get(), name) - ->shape() - .layout() - .minor_to_major(); - }; - - EXPECT_THAT(layout_of("gte0"), ElementsAre(0, 1, 2)); - EXPECT_THAT(layout_of("gte1a"), ElementsAre(1, 2, 0)); - EXPECT_THAT(layout_of("gte1b"), ElementsAre(2, 0, 1)); - EXPECT_THAT(layout_of("fresult"), ElementsAre(2, 1, 0)); + EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2)); + EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0)); + EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1)); + EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0)); EXPECT_THAT(FindInstruction(module.get(), "gte1") ->shape() .tuple_shapes(0) @@ -769,9 +770,12 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { false_builder.AddInstruction( HloInstruction::CreateParameter(0, tshape, "param")); // Using infeed as layout assignment does not mess up with it. - auto infeed = - false_builder.AddInstruction(HloInstruction::CreateInfeed(xshape, "")); - false_builder.AddInstruction(HloInstruction::CreateTuple({infeed})); + auto token = false_builder.AddInstruction(HloInstruction::CreateToken()); + auto infeed = false_builder.AddInstruction( + HloInstruction::CreateInfeed(xshape, token, "")); + auto infeed_data = false_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(xshape, infeed, 0)); + false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data})); } HloComputation* false_computation = module->AddEmbeddedComputation(false_builder.Build()); @@ -798,7 +802,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { auto builder = HloComputation::Builder(TestName()); auto constant0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); builder.AddInstruction(HloInstruction::CreateUnary( constant0->shape(), HloOpcode::kBitcast, constant0)); @@ -816,5 +820,46 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { "Unexpected bitcast operation seen during layout assignment")); } +TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { + // Pin non matching layouts to parameter and root. + const char* module_str = R"( + HloModule test_module + + ENTRY entry_computation { + param = (f32[2,2]) parameter(0) + gte = f32[2,2] get-tuple-element(param), index=0 + token = token[] after-all() + recv = (f32[2,2], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=1} + recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1, + sharding={maximal device=1} + ROOT root = f32[2,2] get-tuple-element(recv-done), index=0 + send = (f32[2,2], u32[], token[]) send(gte, token), channel_id=1, + sharding={maximal device=0} + send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0} + } + )"; + + auto module = ParseHloString(module_str).ValueOrDie(); + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape()); + Shape param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); + TF_ASSERT_OK( + computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape( + param_shape)); + computation_layout.mutable_result_layout()->ResetLayout( + LayoutUtil::MakeLayout({1, 0})); + + ChannelLayoutConstraints channel_constraints; + AssignLayouts(module.get(), &computation_layout, &channel_constraints); + + EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0)); + EXPECT_TRUE( + ShapeUtil::Equal(ShapeUtil::GetSubshape( + FindInstruction(module.get(), "send")->shape(), {0}), + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index f1e7fc29532ce7e6841010a5258f4000a7c70383..cdd3daf73b8ac1a4d1ec3c81224c2c0bfe8e5811 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -21,6 +21,11 @@ filegroup( ]), ) +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + cc_library( name = "alias_analysis", srcs = ["alias_analysis.cc"], @@ -37,12 +42,25 @@ cc_library( ], ) +tf_cc_test( + name = "alias_analysis_test", + srcs = ["alias_analysis_test.cc"], + deps = [ + ":alias_analysis", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/core:test", + ], +) + cc_library( name = "llvm_util", srcs = ["llvm_util.cc"], hdrs = ["llvm_util.h"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -106,12 +124,31 @@ cc_library( ], ) +cc_library( + name = "kernel_tiling", + srcs = ["kernel_tiling.cc"], + hdrs = ["kernel_tiling.h"], + deps = [ + ":ir_array", + ":llvm_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + cc_library( name = "fused_ir_emitter", srcs = ["fused_ir_emitter.cc"], hdrs = ["fused_ir_emitter.h"], deps = [ ":ir_array", + ":kernel_tiling", ":llvm_util", ":loop_emitter", ":tuple_ops", @@ -127,9 +164,9 @@ cc_library( ) cc_library( - name = "ops", - srcs = ["ops.cc"], - hdrs = ["ops.h"], + name = "dynamic_update_slice_util", + srcs = ["dynamic_update_slice_util.cc"], + hdrs = ["dynamic_update_slice_util.h"], deps = [ ":fused_ir_emitter", ":ir_array", @@ -143,6 +180,23 @@ cc_library( ], ) +cc_library( + name = "sort_util", + srcs = ["sort_util.cc"], + hdrs = ["sort_util.h"], + deps = [ + ":ir_array", + ":llvm_loop", + ":llvm_util", + ":loop_emitter", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", + "//tensorflow/compiler/xla/service/gpu:partition_assignment", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + cc_library( name = "tuple_ops", srcs = ["tuple_ops.cc"], @@ -169,3 +223,22 @@ cc_library( "@llvm//:core", ], ) + +cc_library( + name = "buffer_assignment_util", + srcs = ["buffer_assignment_util.cc"], + hdrs = ["buffer_assignment_util.h"], + deps = [ + "//tensorflow/compiler/xla/service:buffer_assignment", + ], +) + +cc_library( + name = "math_ops", + srcs = ["math_ops.cc"], + hdrs = ["math_ops.h"], + deps = [ + ":llvm_util", + "@llvm//:core", + ], +) diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index 21bca1d6beff5b2804531724b94b123d4523c173..e5370eca56f2e3a891523ba2b72961d66ec809aa 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -28,19 +28,20 @@ namespace llvm_ir { // Sentry allocation used to represent parameters of the entry computation in // alias_scope_metadata_ and noalias_metadata_. static const BufferAllocation* kParameterAllocation = new BufferAllocation( - /*index=*/-1, /*size=*/0, /*is_thread_local=*/false, /*is_reusable=*/false, - LogicalBuffer::Color(0)); + /*index=*/-1, /*size=*/0, LogicalBuffer::Color(0)); void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, - llvm_ir::IrArray* array) { + llvm_ir::IrArray* array, + const ShapeIndex& index) { BufferAllocation::Slice buffer_slice; - if (hlo.opcode() == HloOpcode::kParameter) { - // Parameters may alias with each other but may not alias with our temporary - // buffers. + if (hlo.opcode() == HloOpcode::kParameter && + hlo.parent() == hlo.parent()->parent()->entry_computation()) { + // Entry computation parameters may alias with each other but may not alias + // with our temporary buffers. buffer_slice = BufferAllocation::Slice(kParameterAllocation, 0, 0); } else { const std::set slices = - assignment_.GetAllSlices(&hlo, /*index=*/{}); + assignment_.GetAllSlices(&hlo, index); if (slices.empty() || slices.size() > 1) { // Skip HLOs which don't have a buffer assigned or for which the // buffer can't be determined statically. We cannot determine their @@ -137,16 +138,18 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( // 2. Operands of users of the given hlo. // 3. Operands of the given hlo. // - // This set can be increased as we need. For now only consider top-level - // buffers (index = {}) not buffers nested within the instruction's - // operands/output which are not typically touched. + // This set can be increased as we need. std::vector worklist; auto add_buffers_to_worklist = [&worklist, &assignment](const HloInstruction* instruction) { - for (const LogicalBuffer* buffer : - assignment.GetSourceBuffers(instruction, /*index=*/{})) { - worklist.push_back(buffer); - } + ShapeUtil::ForEachSubshape( + instruction->shape(), + [&](const Shape& /*shape*/, const ShapeIndex& index) { + for (const LogicalBuffer* buffer : + assignment.GetSourceBuffers(instruction, index)) { + worklist.push_back(buffer); + } + }); }; for (HloInstruction* user : hlo.users()) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index 5244ac61e56307857aca659854647bd6c3e991d7..fe9eab93aae95557e3ee27a64c09b78f37ac2348 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -38,7 +38,8 @@ class AliasAnalysis { // Augments IrArray with aliasing information. void AddAliasingInformationToIrArray(const HloInstruction& hlo, - llvm_ir::IrArray* array); + llvm_ir::IrArray* array, + const ShapeIndex& index = {}); private: // Returns a unique alias domain for this emitter. diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fe5ec1cc66d06e85ce70625ef7cf764a37b29166 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { +namespace { +class AliasAnalysisTest : public CpuCodegenTest {}; + +void FakeCustomCallTarget(float* out, float** in) {} + +REGISTER_CUSTOM_CALL_TARGET(FakeCustomCallTarget); + +TEST_F(AliasAnalysisTest, EmbeddedComputationParamsMayAliasTemps) { + const char* hlo_string = R"( +HloModule while + +body { + const.0.125 = f32[] constant(0.125) + body.state = f32[] parameter(0) + ROOT add.2.2 = f32[] add(const.0.125, body.state) +} + +condition { + const.100 = f32[] constant(100) + condition.state = f32[] parameter(0) + addend = f32[] custom-call(condition.state), custom_call_target="FakeCustomCallTarget" + add = f32[] add(addend, condition.state) + ROOT greater-than = pred[] greater-than(const.100, add) +} + +ENTRY while3 { + const.0 = f32[] constant(0) + ROOT while = f32[] while(const.0), condition=condition, body=body +} +)"; + + CompileAndVerifyIr(hlo_string, R"( +; CHECK-LABEL: @body(i8* %retval +; CHECK: %[[add_result:.*]] = fadd fast float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]] +; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:[0-9]+]] +; +; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params +; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %temps, i64 0 +; CHECK: %[[cond_state_buf_untyped:.*]] = load i8*, i8** %[[cond_state_buf_ptr]] +; CHECK: %[[cond_state_buf_typed:.*]] = bitcast i8* %[[cond_state_buf_untyped]] to float* +; CHECK: load float, float* %[[cond_state_buf_typed]], !alias.scope ![[alias_scope_md_for_store]], !noalias ![[noalias_md_for_load:.*]] +; +; CHECK-LABEL: @while3( + +![[alias_scope_md_for_store]] = !{![[buffer_idx_0:.*]]} +![[buffer_idx_0]] = !{!"buffer: {index:0, offset:0, size:4}", ![[aa_md_root:.*]]} +![[aa_md_root]] = !{!"XLA global AA domain"} +![[buffer_idx_1:.*]] = !{!"buffer: {index:1, offset:0, size:4}", !3} +![[buffer_idx_1_offset_16:.*]] = !{!"buffer: {index:1, offset:16, size:1}", !3} +![[noalias_md_for_load]] = !{![[buffer_idx_1_offset_16]], ![[buffer_idx_1]]} +} +)"); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..4eb5d9fb4750927ca189e02f312b2d6be7fdd418 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" + +namespace xla { +namespace llvm_ir { +static const HloInstruction& InstrForConstantBufferAllocation( + const BufferAllocation& allocation) { + CHECK(allocation.is_constant()); + HloInstruction* const_instr = nullptr; + for (const auto& buffer_offset_pair : allocation.assigned_buffers()) { + const LogicalBuffer* buffer = buffer_offset_pair.first; + // BufferAssignment may have assigned non-constant instructions to this + // allocation too so we can't CHECK this condition. E.g. for + // + // while(init = constant, body = identity, cond = ...) + // + // the LogicalBuffer for the kWhile instruction will have the same + // BufferAllocation as the LogicalBuffer for the (init) constant. + if (buffer->instruction()->opcode() == HloOpcode::kConstant) { + CHECK_EQ(const_instr, nullptr) + << const_instr->ToString() << " " << buffer->ToString(); + const_instr = buffer->instruction(); + } + } + CHECK_NE(const_instr, nullptr); + return *const_instr; +} + +string ConstantBufferAllocationToGlobalName( + const BufferAllocation& allocation) { + string instr_name = InstrForConstantBufferAllocation(allocation).name(); + for (char& c : instr_name) { + if (c == '.') { + c = '_'; + } + } + return tensorflow::strings::StrCat("buffer_for_", instr_name); +} + +const Literal& LiteralForConstantAllocation( + const BufferAllocation& allocation) { + return InstrForConstantBufferAllocation(allocation).literal(); +} +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h new file mode 100644 index 0000000000000000000000000000000000000000..bfb6eecb87f6a1b756b3a8da3377f608dd7f0be7 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" + +namespace xla { +namespace llvm_ir { +// In XLA:GPU we map constant buffer allocations to globals in the generated +// LLVM IR. This function gives us the name of the global variable a constant +// buffer is mapped to. Not used on XLA:CPU. +string ConstantBufferAllocationToGlobalName(const BufferAllocation& allocation); + +// Returns the Literal corresponding to `allocation`, which must be a constant +// allocation. +const Literal& LiteralForConstantAllocation(const BufferAllocation& allocation); +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc similarity index 75% rename from tensorflow/compiler/xla/service/llvm_ir/ops.cc rename to tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index dacc54742c0897bbd92315f1e33a484aae56bb7f..27fbb11e2ede66a1268e7e949634b2c7d29cbc1c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" @@ -38,16 +38,16 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, // Emits a sequential loop if launch_dimensions is null. static Status EmitDynamicUpdateSliceInPlaceImpl( const Shape& update_shape, const ElementGenerator& start_indices_generator, - ElementGenerator update_array_generator, const IrArray& output_array, - const gpu::LaunchDimensions* launch_dimensions, - tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) { + bool is_signed, ElementGenerator update_array_generator, + const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions, + tensorflow::StringPiece name, llvm::IRBuilder<>* b) { const Shape& output_shape = output_array.GetShape(); // Read start indices from start_indices_generator. const int64 rank = ShapeUtil::Rank(output_shape); - IrArray::Index start_index(rank); + IrArray::Index start_index(b->getInt64Ty(), rank); for (int64 i = 0; i < rank; ++i) { - IrArray::Index dim_index({ir_builder->getInt64(i)}); + IrArray::Index dim_index({b->getInt64(i)}); TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index)); llvm::Value* output_dim_size = llvm::ConstantInt::get( start_index[i]->getType(), output_shape.dimensions(i)); @@ -56,21 +56,19 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( // Clamp the start index so that the update region fits in the operand. // start_index = clamp(start_index, 0, output_dim_size - update_dim_size) - - // TODO(b/74360564): This is implementation defined behavior, but is - // currently respected by all implementations. Change this if we ever decide - // to oficially document different behavior. - llvm::Value* max_bound = - ir_builder->CreateSub(output_dim_size, update_dim_size); + llvm::Value* max_bound = b->CreateSub(output_dim_size, update_dim_size); llvm::Value* zero = llvm::ConstantInt::get(start_index[i]->getType(), 0); - start_index[i] = ir_builder->CreateSelect( - ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SGE, zero, start_index[i]), - zero, start_index[i]); - - start_index[i] = ir_builder->CreateSelect( - ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SLE, max_bound, - start_index[i]), - max_bound, start_index[i]); + start_index[i] = + b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE + : llvm::ICmpInst::ICMP_UGE, + zero, start_index[i]), + zero, start_index[i]); + + start_index[i] = + b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE + : llvm::ICmpInst::ICMP_ULE, + max_bound, start_index[i]), + max_bound, start_index[i]); } auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status { @@ -79,33 +77,32 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( // // output_index[dim] = start_index[dim] + update_index[dim] // - IrArray::Index output_index(rank); + IrArray::Index output_index(start_index.GetType(), rank); for (int64 i = 0; i < rank; ++i) { - llvm::Value* start_index0 = ir_builder->CreateSExtOrBitCast( - start_index[i], update_index[i]->getType()); - output_index[i] = ir_builder->CreateAdd(start_index0, update_index[i]); + llvm::Value* start_index0 = + b->CreateSExtOrBitCast(start_index[i], update_index[i]->getType()); + output_index[i] = b->CreateAdd(start_index0, update_index[i]); } // Do output[output_index] = update[update_index]. TF_ASSIGN_OR_RETURN(llvm::Value * update_data, update_array_generator(update_index)); - output_array.EmitWriteArrayElement(output_index, update_data, ir_builder); + output_array.EmitWriteArrayElement(output_index, update_data, b); return Status::OK(); }; if (launch_dimensions != nullptr) { return gpu::ParallelLoopEmitter(loop_body_emitter, update_shape, - *launch_dimensions, ir_builder) + *launch_dimensions, b) .EmitLoop(name); } - return LoopEmitter(loop_body_emitter, update_shape, ir_builder) - .EmitLoop(name); + return LoopEmitter(loop_body_emitter, update_shape, b).EmitLoop(name); } Status EmitDynamicUpdateSliceInPlace( tensorflow::gtl::ArraySlice operand_arrays, const IrArray& output_array, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder) { + llvm::IRBuilder<>* b) { VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name; // No need to use operand_arrays[0], the input array of the @@ -116,15 +113,16 @@ Status EmitDynamicUpdateSliceInPlace( Shape update_shape = update_array.GetShape(); ElementGenerator start_indices_generator = [&](const IrArray::Index& index) { - return start_indices_array.EmitReadArrayElement(index, ir_builder); + return start_indices_array.EmitReadArrayElement(index, b); }; ElementGenerator update_array_generator = [&](const IrArray::Index& index) { - return update_array.EmitReadArrayElement(index, ir_builder); + return update_array.EmitReadArrayElement(index, b); }; + bool is_signed = ShapeUtil::ElementIsSigned(start_indices_array.GetShape()); return EmitDynamicUpdateSliceInPlaceImpl( - update_shape, start_indices_generator, update_array_generator, - output_array, /*launch_dimensions=*/nullptr, name, ir_builder); + update_shape, start_indices_generator, is_signed, update_array_generator, + output_array, /*launch_dimensions=*/nullptr, name, b); } // Shared implementation for EmitFusedDynamicUpdateSliceInPlace and @@ -135,8 +133,7 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( HloInstruction* fusion, tensorflow::gtl::ArraySlice fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, - const gpu::LaunchDimensions* launch_dimensions, - llvm::IRBuilder<>* ir_builder) { + const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) { CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for " << fusion->ToShortString(); @@ -170,30 +167,30 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( ElementGenerator start_indices_generator = fused_emitter.GetGenerator(start_indices); + bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape()); return EmitDynamicUpdateSliceInPlaceImpl( - update_shape, start_indices_generator, update_array_generator, - fusion_output_array, launch_dimensions, IrName(fusion), ir_builder); + update_shape, start_indices_generator, is_signed, update_array_generator, + fusion_output_array, launch_dimensions, IrName(fusion), b); } Status EmitFusedDynamicUpdateSliceInPlace( HloInstruction* fusion, tensorflow::gtl::ArraySlice fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, - llvm::IRBuilder<>* ir_builder) { + llvm::IRBuilder<>* b) { return EmitFusedDynamicUpdateSliceInPlaceImpl( fusion, fusion_operand_arrays, fusion_output_array, elemental_emitter, - /*launch_dimensions=*/nullptr, ir_builder); + /*launch_dimensions=*/nullptr, b); } Status EmitParallelFusedDynamicUpdateSliceInPlace( HloInstruction* fusion, tensorflow::gtl::ArraySlice fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, - const gpu::LaunchDimensions& launch_dimensions, - llvm::IRBuilder<>* ir_builder) { + const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) { return EmitFusedDynamicUpdateSliceInPlaceImpl( fusion, fusion_operand_arrays, fusion_output_array, elemental_emitter, - &launch_dimensions, ir_builder); + &launch_dimensions, b); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h similarity index 91% rename from tensorflow/compiler/xla/service/llvm_ir/ops.h rename to tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h index 175b081e84d31779b15560cb0998011fe046ca01..3502577d236a099e0b721b98217b758696966821 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_DYNAMIC_UPDATE_SLICE_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_DYNAMIC_UPDATE_SLICE_UTIL_H_ #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" @@ -66,7 +66,7 @@ inline bool CanEmitFusedDynamicUpdateSliceInPlace( Status EmitDynamicUpdateSliceInPlace( tensorflow::gtl::ArraySlice operand_arrays, const IrArray& output_array, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Given a loop-fusion node whose root is a dynamic-update-slice op whose // array-to-be-updated and output share the same buffer slice, emits @@ -76,7 +76,7 @@ Status EmitFusedDynamicUpdateSliceInPlace( HloInstruction* fusion, tensorflow::gtl::ArraySlice fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with // the given launch dimensions. @@ -84,10 +84,9 @@ Status EmitParallelFusedDynamicUpdateSliceInPlace( HloInstruction* fusion, tensorflow::gtl::ArraySlice fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, - const gpu::LaunchDimensions& launch_dimensions, - llvm::IRBuilder<>* ir_builder); + const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b); } // namespace llvm_ir } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_DYNAMIC_UPDATE_SLICE_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index d909845a3a21fc55e44b0037371fca30e577980f..72ede377e1a505d5e4916915e18827e1a0f3fdf9 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -52,7 +52,7 @@ Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) { // that would be regenerated without caching. But this might increase the // JIT compilation time. if (generated_value_bb == nullptr || - generated_value_bb == ir_builder_->GetInsertBlock()) { + generated_value_bb == b_->GetInsertBlock()) { VLOG(3) << "The cached generated value is reused."; return generated_value; } @@ -60,8 +60,7 @@ Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) { "a different BB (" << llvm_ir::AsString(generated_value_bb->getName()) << ") from the current insertion block (" - << llvm_ir::AsString(ir_builder_->GetInsertBlock()->getName()) - << ")."; + << llvm_ir::AsString(b_->GetInsertBlock()->getName()) << ")."; } TF_ASSIGN_OR_RETURN( @@ -77,14 +76,14 @@ Status FusedIrEmitter::HandleConstant(HloInstruction* constant) { llvm::Constant* initializer = llvm_ir::ConvertLiteralToIrConstant(literal, module_); llvm::GlobalVariable* global = new llvm::GlobalVariable( - *ir_builder_->GetInsertBlock()->getModule(), initializer->getType(), + *b_->GetInsertBlock()->getModule(), initializer->getType(), /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer, /*Name=*/""); llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( global, llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); generators_[constant] = [=](const IrArray::Index& index) { return IrArray(shape_constant, constant->shape()) - .EmitReadArrayElement(index, ir_builder_); + .EmitReadArrayElement(index, b_); }; return Status::OK(); @@ -104,7 +103,7 @@ Status FusedIrEmitter::HandleGetTupleElement( // Emit code to lookup tuple element pointer, and store it in 'gte_values_'. llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement( get_tuple_element->shape(), get_tuple_element->tuple_index(), - /*alignment=*/1, it->second, ir_builder_, module_); + /*alignment=*/1, it->second, b_, module_); gte_values_.insert(std::make_pair(get_tuple_element, tuple_element_ptr)); // Emit code to read base tuple element array (if non-tuple shaped). if (!ShapeUtil::IsTuple(get_tuple_element->shape())) { @@ -112,16 +111,32 @@ Status FusedIrEmitter::HandleGetTupleElement( [=](const IrArray::Index& index) -> StatusOr { // TODO(b/34080002) Add aliasing information to tuple element IrArray. return IrArray(tuple_element_ptr, get_tuple_element->shape()) - .EmitReadArrayElement(index, ir_builder_); + .EmitReadArrayElement(index, b_); }; } return Status::OK(); } Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) { - generators_[parameter] = [=](const IrArray::Index& index) { + generators_[parameter] = [=](const IrArray::Index& index) -> llvm::Value* { + if (tiled_parameter_info_) { + if (llvm::Value* param_tile_buffer = + tiled_parameter_info_->GetBufferForParameter( + parameter->parameter_number())) { + // TODO(jlebar): Add AA metadata to this load. Tile buffers are global + // variables, so LLVM's points-to analysis doesn't help us much. And we + // want the AA info to be present before address spaces are inferred + // (which is pretty late in the pipeline), so even if we had + // address-space-based AA in LLVM, it wouldn't help us much here. + return b_->CreateLoad( + b_->CreateGEP(param_tile_buffer, {index.GetConstantWithIndexType(0), + tiled_parameter_info_->x(), + tiled_parameter_info_->y()}), + "tiled_buffer"); + } + } return parameter_arrays_[parameter->parameter_number()] - .EmitReadArrayElement(index, ir_builder_); + .EmitReadArrayElement(index, b_); }; // Store ir value for fusion operand associated with fusion parameter to be // accessed by subsequent fused GetTupleElement instructions. @@ -140,11 +155,11 @@ Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) { } generators_[tuple] = [=](const IrArray::Index& index) -> StatusOr { - llvm::Value* ret = llvm::UndefValue::get(llvm::StructType::get( - ir_builder_->getContext(), operand_elemental_ir_types)); + llvm::Value* ret = llvm::UndefValue::get( + llvm::StructType::get(b_->getContext(), operand_elemental_ir_types)); for (size_t i = 0; i < ShapeUtil::TupleElementCount(tuple->shape()); ++i) { TF_ASSIGN_OR_RETURN(llvm::Value * val_i, generators_[operands[i]](index)); - ret = ir_builder_->CreateInsertValue(ret, val_i, i); + ret = b_->CreateInsertValue(ret, val_i, i); } return ret; }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index b3b6026ef17daa184c0a015fdea618597ef068b3..30471480c4fb3ce3bf3226a28e9d2ffa79ae5f29 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -56,8 +57,9 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { FusedIrEmitter(tensorflow::gtl::ArraySlice parameter_arrays, ElementalIrEmitter* elemental_emitter) : parameter_arrays_(parameter_arrays), + tiled_parameter_info_(nullptr), elemental_emitter_(elemental_emitter), - ir_builder_(elemental_emitter->ir_builder()), + b_(elemental_emitter->b()), module_(elemental_emitter->module()) {} Status DefaultAction(HloInstruction* hlo) override; @@ -86,9 +88,14 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { return it->second; } + void SetTiledParameterInfo(const llvm_ir::TiledParameterInfo* info) { + tiled_parameter_info_ = info; + } + private: // Arrays of parameters of fusion instruction tensorflow::gtl::ArraySlice parameter_arrays_; + const llvm_ir::TiledParameterInfo* tiled_parameter_info_; ElementalIrEmitter* elemental_emitter_; @@ -96,7 +103,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { const HloInstruction* fused_root_ = nullptr; // Borrowed - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; llvm::Module* module_; // Map from instruction pointers to functions to generate elements of their diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 7323abeb2077154f82828bcda3e90eb45a67138a..2b6caee6aa72f426cf85c8c56c3ef500ff8c5d3d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -29,9 +29,9 @@ limitations under the License. namespace xla { namespace llvm_ir { -static void Delinearize(std::vector* multidim, - llvm::Value* linear, const Shape& shape, - llvm::IRBuilder<>* ir_builder) { +void IrArray::Index::Delinearize(std::vector* multidim, + llvm::Value* linear, const Shape& shape, + llvm::IRBuilder<>* b) const { int64 divisor = 1; const Layout& layout = shape.layout(); for (int64 i = 0; i < layout.minor_to_major_size(); ++i) { @@ -48,10 +48,10 @@ static void Delinearize(std::vector* multidim, // useful because cuda-memcheck can't help us much in XLA: Most of our // memory lives in one big allocation, so cuda-memcheck can't detect // out-of-bounds accesses. - auto* quot = ir_builder->CreateUDiv(linear, ir_builder->getInt64(divisor)); + auto* quot = b->CreateUDiv(linear, GetConstantWithIndexType(divisor)); if (i < layout.minor_to_major_size() - 1) { - (*multidim)[dimension] = ir_builder->CreateURem( - quot, ir_builder->getInt64(size_of_current_dimension)); + (*multidim)[dimension] = b->CreateURem( + quot, GetConstantWithIndexType(size_of_current_dimension)); } else { (*multidim)[dimension] = quot; } @@ -60,15 +60,17 @@ static void Delinearize(std::vector* multidim, } IrArray::Index::Index(llvm::Value* linear, const Shape& shape, - llvm::IRBuilder<>* ir_builder) + llvm::IRBuilder<>* b) : multidim_(ShapeUtil::Rank(shape)), linear_(linear), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { + CHECK_NE(linear, nullptr); + index_type_ = linear->getType(); CHECK(LayoutUtil::HasLayout(shape)) << "Shape " << ShapeUtil::HumanStringWithLayout(shape) << " should have a layout."; - Delinearize(&multidim_, linear, shape, ir_builder); + Delinearize(&multidim_, linear, shape, b); } IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, @@ -77,6 +79,13 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, linear_(linear), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { + if (size()) { + index_type_ = multidim_[0]->getType(); + } else { + CHECK_NE(linear_, nullptr); + index_type_ = linear_->getType(); + } + CHECK_NE(index_type_, nullptr); CHECK_EQ(shape.dimensions_size(), multidim.size()); CHECK(LayoutUtil::HasLayout(shape)) << "Shape " << ShapeUtil::HumanStringWithLayout(shape) @@ -84,10 +93,13 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, } IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, - const Shape& shape, llvm::IRBuilder<>* ir_builder) + const Shape& shape, llvm::IRBuilder<>* b) : multidim_(multidim.begin(), multidim.end()), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { + CHECK_GT(multidim_.size(), 0); + index_type_ = multidim[0]->getType(); + CHECK_NE(index_type_, nullptr); CHECK_EQ(shape.dimensions_size(), multidim.size()); CHECK(LayoutUtil::HasLayout(shape)); } @@ -130,15 +142,15 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( CommonFactors(AsInt64Slice(input_shape.dimensions()), AsInt64Slice(output_shape.dimensions())); std::vector source_multidim_index( - ShapeUtil::Rank(input_shape), - llvm::UndefValue::get(builder->getInt64Ty())); + ShapeUtil::Rank(input_shape), llvm::UndefValue::get(index_type_)); // We compute the source indices in each common factor from only the target // indices in the same common factor. for (ssize_t k = common_factors.size() - 2; k >= 0; --k) { llvm::Value* logical_linear_index = Index(tensorflow::gtl::ArraySlice( multidim_, common_factors[k].second, - common_factors[k + 1].second - common_factors[k].second)) + common_factors[k + 1].second - common_factors[k].second), + index_type_) .Linearize( tensorflow::gtl::ArraySlice( AsInt64Slice(output_shape.dimensions()), @@ -150,9 +162,10 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( // linear index by each dimension size. for (int64 i = common_factors[k + 1].first - 1; i >= common_factors[k].first; --i) { - llvm::Value* divisor = builder->getInt64(input_shape.dimensions(i)); + llvm::Value* divisor = + GetConstantWithIndexType(input_shape.dimensions(i)); if (input_shape.dimensions(i) == 1) { - source_multidim_index[i] = builder->getInt64(0); + source_multidim_index[i] = GetConstantWithIndexType(0); } else if (i == common_factors[k].first) { source_multidim_index[i] = logical_linear_index; } else { @@ -168,14 +181,14 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) { return Index(source_multidim_index, linear(), input_shape); } - return Index(source_multidim_index); + return Index(source_multidim_index, index_type_); } IrArray::Index IrArray::Index::SourceIndexOfSlice( const Shape& shape, tensorflow::gtl::ArraySlice starts, tensorflow::gtl::ArraySlice strides, llvm::IRBuilder<>* builder) const { - Index source_index(multidim_.size()); + Index source_index(index_type_, multidim_.size()); for (int i = 0; i < multidim_.size(); ++i) { int64 stride = strides[i]; auto type = multidim_[i]->getType(); @@ -224,11 +237,12 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast( // the physical index of the element in the buffer. This is like Linearize, // but takes the layout into account. int64 scale = 1; - llvm::Value* linear_index = builder->getInt64(0); + llvm::Value* linear_index = GetConstantWithIndexType(0); for (auto dimension : LayoutUtil::MinorToMajor(shape)) { linear_index = builder->CreateAdd( linear_index, - builder->CreateMul(multidim_[dimension], builder->getInt64(scale), "", + builder->CreateMul(multidim_[dimension], + GetConstantWithIndexType(scale), "", /*HasNUW=*/true, /*HasNSW=*/true), "", /*HasNUW=*/true, /*HasNSW=*/true); scale *= shape.dimensions(dimension); @@ -252,7 +266,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( } if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) || !LayoutUtil::HasLayout(shape)) { - return Index(source_index); + return Index(source_index, index_type_); } // High-level idea: we can reuse the linear index if the broadcasted // dimensions are contiguous, and this part of the operation is a bitcast. @@ -274,7 +288,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( bool contiguous_broadcast_dimensions = max_broadcasted_dimension - min_broadcasted_dimension == rank - 1; if (!contiguous_broadcast_dimensions) { - return Index(source_index); + return Index(source_index, index_type_); } // Check if the mapped dimensions are a bitcast. std::vector operand_logical_to_physical = @@ -282,7 +296,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( for (int64 i = 0; i < rank; ++i) { if (operand_logical_to_physical[i] != logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) { - return Index(source_index); + return Index(source_index, index_type_); } } llvm::Value* linear = linear_; @@ -291,7 +305,9 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); } if (divisor > 1) { - linear = builder->CreateUDiv(linear, builder->getInt64(divisor)); + linear = builder->CreateUDiv( + linear, + IrArray::Index(linear->getType()).GetConstantWithIndexType(divisor)); } if (min_broadcasted_dimension > 0) { int64 mod = 1; @@ -299,7 +315,9 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( ++i) { mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); } - linear = builder->CreateURem(linear, builder->getInt64(mod)); + linear = builder->CreateURem( + linear, + IrArray::Index(linear->getType()).GetConstantWithIndexType(mod)); } return Index(source_index, linear, operand_shape); } @@ -309,12 +327,14 @@ llvm::Value* IrArray::Index::Linearize( llvm::IRBuilder<>* builder) const { // Each dimension is multiplied by the product of the sizes of all // earlier dimensions and added to the accumulator logical_linear_index. - llvm::Value* logical_linear_index = builder->getInt64(0); + CHECK_EQ(size(), dimensions.size()); + llvm::Value* logical_linear_index = GetConstantWithIndexType(0); int64 multiplier = 1; for (ssize_t i = size() - 1; i >= 0; --i) { llvm::Value* addend = - builder->CreateMul((*this)[i], builder->getInt64(multiplier), "", + builder->CreateMul((*this)[i], GetConstantWithIndexType(multiplier), "", /*HasNUW=*/true, /*HasNSW=*/true); + addend = builder->CreateZExtOrTrunc(addend, index_type_); logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "", /*HasNUW=*/true, /*HasNSW=*/true); multiplier *= dimensions[i]; @@ -323,7 +343,7 @@ llvm::Value* IrArray::Index::Linearize( } llvm::Value* IrArray::EmitArrayElementAddress( - const IrArray::Index& index, llvm::IRBuilder<>* ir_builder, + const IrArray::Index& index, llvm::IRBuilder<>* b, tensorflow::StringPiece name) const { if (ShapeUtil::IsScalar(*shape_)) { // Special handling of scalars: a scalar pretends to have the same value for @@ -334,12 +354,11 @@ llvm::Value* IrArray::EmitArrayElementAddress( CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_)); if (index.LinearValidOnShape(*shape_)) { - llvm::Module* module = - ir_builder->GetInsertBlock()->getParent()->getParent(); - return ir_builder->CreateInBoundsGEP( - ir_builder->CreateBitCast( - base_ptr_, PrimitiveTypeToIrType(shape_->element_type(), module) - ->getPointerTo()), + llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); + return b->CreateInBoundsGEP( + b->CreateBitCast(base_ptr_, + PrimitiveTypeToIrType(shape_->element_type(), module) + ->getPointerTo()), {index.linear()}, llvm_ir::AsStringRef(name)); } @@ -349,7 +368,8 @@ llvm::Value* IrArray::EmitArrayElementAddress( // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to // produce better code in some cases. auto dim = shape_->dimensions(i); - actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]); + actual_index.push_back( + dim == 1 ? llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]); } // "base_ptr_" has the type of "*" @@ -357,13 +377,15 @@ llvm::Value* IrArray::EmitArrayElementAddress( // should be computed by // // getelementptr base_ptr_, 0, most major index, ..., most minor index - std::vector gep_indices(1, ir_builder->getInt64(0)); + CHECK_GT(index.size(), 0); + std::vector gep_indices( + 1, llvm::ConstantInt::get(index[0]->getType(), 0)); for (int64 i = 0; i < LayoutUtil::MinorToMajor(*shape_).size(); ++i) { int64 dimension = LayoutUtil::Major(shape_->layout(), i); gep_indices.push_back(actual_index[dimension]); } - return ir_builder->CreateInBoundsGEP(base_ptr_, gep_indices, - llvm_ir::AsStringRef(name)); + return b->CreateInBoundsGEP(base_ptr_, gep_indices, + llvm_ir::AsStringRef(name)); } void IrArray::AnnotateLoadStoreInstructionWithMetadata( @@ -379,38 +401,40 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata( } llvm::Value* IrArray::EmitReadArrayElement(const Index& index, - llvm::IRBuilder<>* ir_builder, + llvm::IRBuilder<>* b, tensorflow::StringPiece name) const { - llvm::Value* element_address = - EmitArrayElementAddress(index, ir_builder, name); - llvm::LoadInst* load = ir_builder->CreateLoad(element_address); + llvm::Value* element_address = EmitArrayElementAddress(index, b, name); + llvm::LoadInst* load = b->CreateLoad(element_address); AnnotateLoadStoreInstructionWithMetadata(load); return load; } void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value, - llvm::IRBuilder<>* ir_builder) const { - llvm::Value* element_address = EmitArrayElementAddress(index, ir_builder); - llvm::StoreInst* store = ir_builder->CreateStore(value, element_address); + llvm::IRBuilder<>* b) const { + llvm::Value* element_address = EmitArrayElementAddress(index, b); + llvm::StoreInst* store = b->CreateStore(value, element_address); AnnotateLoadStoreInstructionWithMetadata(store); } IrArray IrArray::CastToShape(const Shape& new_shape, - llvm::IRBuilder<>* ir_builder) const { - llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent(); + llvm::IRBuilder<>* b) const { + llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module); - return IrArray( - ir_builder->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()), - new_shape); + IrArray new_irarray( + b->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()), new_shape); + new_irarray.metadata_ = metadata_; + return new_irarray; } /* static */ IrArray::Index IrArray::BumpIndex(const Index& index, int64 which_dimension, int64 addend, - llvm::IRBuilder<>* ir_builder) { + llvm::IRBuilder<>* b) { Index new_index = index; - new_index[which_dimension] = ir_builder->CreateAdd( - index[which_dimension], ir_builder->getInt64(addend), "", /*HasNUW=*/true, + new_index[which_dimension] = b->CreateAdd( + index[which_dimension], + llvm::ConstantInt::get(index[which_dimension]->getType(), addend), "", + /*HasNUW=*/true, /*HasNSW=*/true); return new_index; } diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 4c3195c29c859c9eef08e3f6531b059edbebfc47..28ca793e3eeaed86664bfa6aa859a38f2c4dc6f3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -53,34 +53,53 @@ class IrArray { // multidimensional index, which LLVM DCE can delete. class Index { public: - // Constructs an empty zero-dimensional index. - Index() {} - // Constructs an index of rank "size". Each dimension of the index is // initialized to "value". - explicit Index(size_t size, llvm::Value* value = nullptr) - : multidim_(size, value) {} + explicit Index(size_t size, llvm::Value* value) + : multidim_(size, value), index_type_(value->getType()) { + CHECK_NE(index_type_, nullptr); + } + + // Constructs an index of rank "size". Each dimension of the index is + // initialized to nullptr. + explicit Index(llvm::Type* index_ty, size_t size = 0) + : multidim_(size, nullptr), index_type_(index_ty) { + CHECK(index_ty->isIntegerTy()); + } // Constructs an index from multi-dimensional index "multidim". The linear // index is set to nullptr. - explicit Index(tensorflow::gtl::ArraySlice multidim) - : multidim_(multidim.begin(), multidim.end()) {} + explicit Index(tensorflow::gtl::ArraySlice multidim, + llvm::Type* index_ty = nullptr) + : multidim_(multidim.begin(), multidim.end()) { + if (size() == 0) { + index_type_ = index_ty; + } else { + index_type_ = (*this)[0]->getType(); + if (index_ty != nullptr) { + CHECK_EQ(index_type_, index_ty); + } + } + CHECK_NE(index_type_, nullptr); + CHECK(c_all_of(multidim, [&](llvm::Value* v) { + return index_type_ == v->getType(); + })); + } // Constructs an index from linear index "linear" and computes the - // multi-dimensional index from "linear" and "shape". "ir_builder" is the IR + // multi-dimensional index from "linear" and "shape". "b" is the IR // builder to emit the index of each dimension in the multi-dimensional // index. // // Precondition: "shape" has a layout. - Index(llvm::Value* linear, const Shape& shape, - llvm::IRBuilder<>* ir_builder); + Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b); // Constructs an index from the given multi-dimensional index and the shape // that it indexes into. // // Precondition: "shape" has a layout. Index(tensorflow::gtl::ArraySlice multidim, - const Shape& shape, llvm::IRBuilder<>* ir_builder); + const Shape& shape, llvm::IRBuilder<>* b); // Constructs an index from both a multi-dimensional index and a linear // index. "shape" has the same meaning as that in the constructor that takes @@ -94,19 +113,19 @@ class IrArray { size_t size() const { return multidim().size(); } llvm::Value* operator[](size_t i) const { return multidim()[i]; } - llvm::Value*& operator[](size_t i) { return multidim()[i]; } + llvm::Value*& operator[](size_t i) { return mutable_multidim()[i]; } - void push_back(llvm::Value* value) { multidim().push_back(value); } + void push_back(llvm::Value* value) { mutable_multidim().push_back(value); } void InsertAt(int64 index, llvm::Value* value) { CHECK_LE(index, size()); - multidim().insert(multidim().begin() + index, value); + mutable_multidim().insert(mutable_multidim().begin() + index, value); } using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; - iterator begin() { return multidim().begin(); } - iterator end() { return multidim().end(); } + iterator begin() { return mutable_multidim().begin(); } + iterator end() { return mutable_multidim().end(); } const_iterator begin() const { return multidim().begin(); } const_iterator end() const { return multidim().end(); } @@ -154,13 +173,25 @@ class IrArray { llvm::Value* Linearize(tensorflow::gtl::ArraySlice dimensions, llvm::IRBuilder<>* builder) const; + llvm::Type* GetType() const { return index_type_; } + + llvm::Constant* GetConstantWithIndexType(int64 c) const { + // The LLVM function makes sure that the value can be represented by the + // specified type, see ConstantInt::ConstantInt(IntegerType *Ty, const + // APInt &V). + return llvm::ConstantInt::get(index_type_, c); + } + private: // Changing the multi-dimensional index invalidates the linear index. - std::vector& multidim() { + std::vector& mutable_multidim() { linear_ = nullptr; return multidim_; } + void Delinearize(std::vector* multidim, llvm::Value* linear, + const Shape& shape, llvm::IRBuilder<>* b) const; + std::vector multidim_; // These values are purely for efficiency; `multidim_` is enough to find the @@ -177,6 +208,8 @@ class IrArray { llvm::Value* linear_ = nullptr; Layout layout_; std::vector dims_; + + llvm::Type* index_type_; }; // Default constructor. Constructs an IrArray in a null status. @@ -206,8 +239,7 @@ class IrArray { // // The optional name is useful for debugging when looking at // the emitted LLVM IR. - llvm::Value* EmitArrayElementAddress(const Index& index, - llvm::IRBuilder<>* ir_builder, + llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b, tensorflow::StringPiece name = "") const; // Attach metadata this IrArray instance knows about to "instruction". @@ -221,18 +253,16 @@ class IrArray { // // The optional name is useful for debugging when looking at // the emitted LLVM IR. - llvm::Value* EmitReadArrayElement(const Index& index, - llvm::IRBuilder<>* ir_builder, + llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, tensorflow::StringPiece name = "") const; // Emit IR to write the given value to the array element at the given index. void EmitWriteArrayElement(const Index& index, llvm::Value* value, - llvm::IRBuilder<>* ir_builder) const; + llvm::IRBuilder<>* b) const; // Returns a new IrArray whose shape is "new_shape" and base pointer is a // bitcast of the base pointer of "this" IrArray. - IrArray CastToShape(const Shape& new_shape, - llvm::IRBuilder<>* ir_builder) const; + IrArray CastToShape(const Shape& new_shape, llvm::IRBuilder<>* b) const; void AddAliasScopeMetadata(llvm::MDNode* alias_scope) { CHECK_NE(alias_scope, nullptr); @@ -278,7 +308,7 @@ class IrArray { // Bumps the "which_dimension" value within the provided index by the provided // addend. static Index BumpIndex(const Index& index, int64 which_dimension, - int64 addend, llvm::IRBuilder<>* ir_builder); + int64 addend, llvm::IRBuilder<>* b); private: // Add the specified LLVM IR metadata to loads/stores associated with this diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index 1f6e3c829f890d68aa251b101f0402c120a19d61..b79567369aa532c4963e3941f6cb9844cd1476dd 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -22,9 +22,9 @@ Status KernelSupportLibrary::For( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - return If(ir_builder_->CreateICmpSLT(start, end), [&]() -> Status { + return If(b_->CreateICmpSLT(start, end), [&]() -> Status { TF_RETURN_IF_ERROR(for_body_generator(start, /*is_first_iteration=*/true)); - return For(name, ir_builder_->CreateAdd(start, step), end, step, + return For(name, b_->CreateAdd(start, step), end, step, [&](llvm::Value* iv) { return for_body_generator(iv, false); }); }); } @@ -37,44 +37,44 @@ Status KernelSupportLibrary::For( if (peel_first_iteration) { return For(name, start, end, step, true, [&](llvm::Value* indvar, bool is_first_iteration) -> Status { - return for_body_generator( - indvar, ir_builder_->getInt1(is_first_iteration)); + return for_body_generator(indvar, + b_->getInt1(is_first_iteration)); }); } else { std::unique_ptr loop = llvm_ir::ForLoop::EmitForLoop( - name, start, end, step, ir_builder_, + name, start, end, step, b_, /*unroll_mode=*/unroll_mode_, /*prevent_vectorization=*/prevent_vectorization_); - ir_builder_->SetInsertPoint(&loop->GetBodyBasicBlock()->back()); + b_->SetInsertPoint(&loop->GetBodyBasicBlock()->back()); TF_RETURN_IF_ERROR( for_body_generator(loop->GetIndVarValue(), - /*is_first_iteration=*/ir_builder_->CreateICmpEQ( + /*is_first_iteration=*/b_->CreateICmpEQ( loop->GetIndVarValue(), start))); - llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), ir_builder_); + llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), b_); return Status::OK(); } } Status KernelSupportLibrary::If( - llvm::Value* condition, const std::function& true_block_generator, + tensorflow::StringPiece name, llvm::Value* condition, + const std::function& true_block_generator, const std::function& false_block_generator) { - llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(condition, "", ir_builder_); - ir_builder_->SetInsertPoint(&if_data.true_block->back()); + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, name, b_); + b_->SetInsertPoint(&if_data.true_block->back()); TF_RETURN_IF_ERROR(true_block_generator()); - ir_builder_->SetInsertPoint(&if_data.false_block->back()); + b_->SetInsertPoint(&if_data.false_block->back()); TF_RETURN_IF_ERROR(false_block_generator()); - llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_); + llvm_ir::SetToLastInsertPoint(if_data.after_block, b_); return Status::OK(); } void KernelSupportLibrary::EmitAndCallOutlinedKernel( - bool enable_fast_math, bool optimize_for_size, - llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, + tensorflow::StringPiece kernel_name, KernelSupportLibrary::ArgumentVector arguments, const std::function& kernel_body_generator) { - llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); + llvm::Module* module = b->GetInsertBlock()->getModule(); llvm::Function* function = module->getFunction(llvm_ir::AsStringRef(kernel_name)); @@ -97,22 +97,22 @@ void KernelSupportLibrary::EmitAndCallOutlinedKernel( std::back_inserter(arg_types), [](llvm::Value* arg) { return arg->getType(); }); - auto* function_type = llvm::FunctionType::get( - ir_builder->getVoidTy(), arg_types, /*isVarArg=*/false); + auto* function_type = + llvm::FunctionType::get(b->getVoidTy(), arg_types, /*isVarArg=*/false); function = llvm_ir::CreateFunction( function_type, llvm::GlobalValue::InternalLinkage, /*enable_fast_math=*/enable_fast_math, /*optimize_for_size=*/optimize_for_size, kernel_name, module); - llvm::IRBuilder<>::InsertPointGuard guard(*ir_builder); + llvm::IRBuilder<>::InsertPointGuard guard(*b); auto* entry_bb = - llvm::BasicBlock::Create(ir_builder->getContext(), "entry", function); - auto* return_inst = llvm::ReturnInst::Create(ir_builder->getContext(), + llvm::BasicBlock::Create(b->getContext(), "entry", function); + auto* return_inst = llvm::ReturnInst::Create(b->getContext(), /*retVal=*/nullptr, entry_bb); // Set the insert point to before return_inst. - ir_builder->SetInsertPoint(return_inst); + b->SetInsertPoint(return_inst); std::vector arg_values; /* @@ -132,7 +132,7 @@ void KernelSupportLibrary::EmitAndCallOutlinedKernel( VLOG(3) << "Re-using kernel for " << kernel_name; } - ir_builder->CreateCall(function, llvm_ir::AsArrayRef(sanitized_args)); + b->CreateCall(function, llvm_ir::AsArrayRef(sanitized_args)); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index e17c649e5272a9e7c0d5126083ab76542abfdf48..b00f903d56a83c5b76188007702470c44c55c213 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -30,14 +30,14 @@ namespace xla { // flow more readable. class KernelSupportLibrary { public: - // `ir_builder` is the llvm::IRBuilder instance used to generate LLVM IR. + // `b` is the llvm::IRBuilder instance used to generate LLVM IR. // `unroll_mode` specifies the desired LLVM unrolling behavior for every loop // generated by this instance of KernelSupportLibrary. explicit KernelSupportLibrary( - llvm::IRBuilder<>* ir_builder, + llvm::IRBuilder<>* b, llvm_ir::UnrollMode unroll_mode = llvm_ir::UnrollMode::kNoUnroll, bool prevent_vectorization = true) - : ir_builder_(ir_builder), + : b_(b), unroll_mode_(unroll_mode), prevent_vectorization_(prevent_vectorization) {} @@ -71,18 +71,18 @@ class KernelSupportLibrary { const std::function& for_body_generator) { - return For(name, /*start=*/ir_builder_->getInt64(start), - /*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + return For(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } void ForReturnVoid( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start), - /*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + ForReturnVoid(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } // Generates the following control flow structure if `peel_first_iteration` is @@ -125,8 +125,8 @@ class KernelSupportLibrary { llvm::Value* is_first_iteration)>& for_body_generator) { return For(name, /*start=*/start, /*end=*/end, - /*step=*/ir_builder_->getInt64(step), peel_first_iteration, - for_body_generator); + /*step=*/llvm::ConstantInt::get(start->getType(), step), + peel_first_iteration, for_body_generator); } void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, @@ -135,8 +135,8 @@ class KernelSupportLibrary { llvm::Value* is_first_iteration)>& for_body_generator) { ForReturnVoid(name, /*start=*/start, /*end=*/end, - /*step=*/ir_builder_->getInt64(step), peel_first_iteration, - for_body_generator); + /*step=*/llvm::ConstantInt::get(start->getType(), step), + peel_first_iteration, for_body_generator); } Status For( @@ -165,7 +165,7 @@ class KernelSupportLibrary { tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { - return For(name, start, end, ir_builder_->getInt64(step), + return For(name, start, end, llvm::ConstantInt::get(start->getType(), step), /*peel_first_iteration=*/false, [&](llvm::Value* indvar, llvm::Value*) -> Status { return for_body_generator(indvar); @@ -176,24 +176,25 @@ class KernelSupportLibrary { tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, start, end, ir_builder_->getInt64(step), + ForReturnVoid(name, start, end, + llvm::ConstantInt::get(start->getType(), step), for_body_generator); } Status For( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - return For(name, /*start=*/ir_builder_->getInt64(start), - /*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + return For(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } void ForReturnVoid( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start), - /*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + ForReturnVoid(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } // Generates the following control flow structure: @@ -202,16 +203,30 @@ class KernelSupportLibrary { // `true_block_generator()`; // else // `false_block_generator()`; - Status If(llvm::Value* condition, + Status If(tensorflow::StringPiece name, llvm::Value* condition, const std::function& true_block_generator, const std::function& false_block_generator = []() -> Status { return Status::OK(); }); + Status If(llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = + []() -> Status { return Status::OK(); }) { + return If("", condition, true_block_generator, false_block_generator); + } + void IfReturnVoid(llvm::Value* condition, const std::function& true_block_generator, const std::function& false_block_generator = []() { }) { - TF_CHECK_OK(If(condition, + IfReturnVoid("", condition, true_block_generator, false_block_generator); + } + + void IfReturnVoid(tensorflow::StringPiece name, llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() { + }) { + TF_CHECK_OK(If(name, condition, [&]() { true_block_generator(); return Status::OK(); @@ -243,41 +258,39 @@ class KernelSupportLibrary { // in a nullptr llvm::Value* in its position to `kernel_body_generator`. // Currently we only support at most one nullptr value in `arguments`. static void EmitAndCallOutlinedKernel( - bool enable_fast_math, bool optimize_for_size, - llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, - ArgumentVector arguments, + bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, + tensorflow::StringPiece kernel_name, ArgumentVector arguments, const std::function& kernel_body_generator); // Thin wrappers around the more general EmitAndCallOutlinedKernel above. static void EmitAndCallOutlinedKernel( - bool enable_fast_math, bool optimize_for_size, - llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, - llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, + bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, + tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1, + llvm::Value* arg2, const std::function& kernel_body_generator) { EmitAndCallOutlinedKernel( - enable_fast_math, optimize_for_size, ir_builder, kernel_name, - {arg0, arg1, arg2}, [&](ArgumentVector args) { + enable_fast_math, optimize_for_size, b, kernel_name, {arg0, arg1, arg2}, + [&](ArgumentVector args) { kernel_body_generator(args[0], args[1], args[2]); }); } static void EmitAndCallOutlinedKernel( - bool enable_fast_math, bool optimize_for_size, - llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, - llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, - llvm::Value* arg3, + bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, + tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1, + llvm::Value* arg2, llvm::Value* arg3, const std::function& kernel_body_generator) { EmitAndCallOutlinedKernel( - enable_fast_math, optimize_for_size, ir_builder, kernel_name, + enable_fast_math, optimize_for_size, b, kernel_name, {arg0, arg1, arg2, arg3}, [&](ArgumentVector args) { kernel_body_generator(args[0], args[1], args[2], args[3]); }); } private: - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; llvm_ir::UnrollMode unroll_mode_; bool prevent_vectorization_; }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc new file mode 100644 index 0000000000000000000000000000000000000000..35b394127288d816952b48c84b193257bab0bcda --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -0,0 +1,118 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace llvm_ir { + +namespace { +// Returns the indices of the first elements of all consecutive subarrays of the +// given array. For example: +// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4} +std::vector ConsecutiveSegments(tensorflow::gtl::ArraySlice xs) { + std::vector is = {0}; + for (size_t i = 1; i < xs.size(); ++i) { + if (1 != xs[i] - xs[i - 1]) { + is.push_back(i); + } + } + return is; +} + +// Merges the sequences of dimensions of the given shape which start at the +// given indices `segs`. +Shape MergeDimensions(tensorflow::gtl::ArraySlice segs, + const Shape& shape) { + std::vector dimensions; + for (size_t i = 1; i <= segs.size(); ++i) { + dimensions.push_back(std::accumulate( + shape.dimensions().begin() + segs[i - 1], + shape.dimensions().begin() + + (segs.size() == i ? shape.dimensions().size() : segs[i]), + 1, std::multiplies())); + } + return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), + dimensions); +} +} // namespace + +tensorflow::gtl::optional > FindTranspose021( + const Shape& a, const Shape& b) { + if (!ShapeUtil::CompatibleIgnoringElementType(a, b)) { + return tensorflow::gtl::nullopt; + } + + std::vector perm(a.dimensions().size()); + { + auto layout_a_orig = LayoutUtil::MinorToMajor(a); + std::vector layout_a(layout_a_orig.rbegin(), layout_a_orig.rend()); + auto layout_b_orig = LayoutUtil::MinorToMajor(b); + std::vector layout_b(layout_b_orig.rbegin(), layout_b_orig.rend()); + for (size_t i = 0; i < perm.size(); ++i) { + perm[i] = PositionInContainer(layout_b, layout_a[i]); + } + } + auto segs = ConsecutiveSegments(perm); + if ((3 == segs.size() && 0 == perm[0]) || 2 == segs.size()) { + Shape norm_a = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a); + Shape reduced_a = MergeDimensions(segs, norm_a); + auto reduced_a_dims = reduced_a.dimensions(); + std::vector dims_021; + if (2 == segs.size()) { + // The logical component-0 is of size one. + dims_021 = {1, reduced_a_dims[1], reduced_a_dims[0]}; + } else { + dims_021 = {reduced_a_dims[0], reduced_a_dims[2], reduced_a_dims[1]}; + } + + return dims_021; + } + + return tensorflow::gtl::nullopt; +} + +IrArray::Index GetUnreducedOutputIndex( + const IrArray::Index& reduced_output_index, + const Shape& reduced_output_shape, const Shape& unreduced_output_shape, + llvm::IRBuilder<>* b) { + auto bounds = reduced_output_shape.dimensions(); + auto minor_to_major = reduced_output_shape.layout().minor_to_major(); + llvm::Value* linear_index = reduced_output_index.GetConstantWithIndexType(0); + int64 multiplier = 1; + for (int i = 0; i < reduced_output_index.size(); ++i) { + int64 dim = minor_to_major[i]; + llvm::Value* addend = + b->CreateMul(reduced_output_index[dim], + reduced_output_index.GetConstantWithIndexType(multiplier), + "linearizing", + /*HasNUW=*/true, /*HasNSW=*/true); + linear_index = b->CreateAdd(linear_index, addend, "", + /*HasNUW=*/true, /*HasNSW=*/true); + multiplier *= bounds[dim]; + } + + return IrArray::Index(linear_index, unreduced_output_shape, b); +} + +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h new file mode 100644 index 0000000000000000000000000000000000000000..ccb9b8ba3e6b0079664f2da92ce67224e176fa1d --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -0,0 +1,80 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_ + +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" + +namespace xla { +namespace llvm_ir { + +// About 0-2-1 transpose: +// +// If a shape can be viewed as three logical components 0-1-2 in the order of +// major to minor, a 0-2-1-transpose changes the order of such logical +// components to 0-2-1. We call the shape being transposed the input shape and +// the transposed shape the output shape. The logical view of the input and +// output shapes for the transpose are called the 0-1-2 shape or reduced input +// shape and the 0-2-1 shape or the reduced output shape respectively. The +// original input and output shapes are called the unreduced input and output +// shapes. + +// If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the +// reduced shape of `b` or the 0-2-1 shape. +tensorflow::gtl::optional > FindTranspose021(const Shape& a, + const Shape& b); + +// Return the unreduced output index corresponding to the given reduced output +// index. +IrArray::Index GetUnreducedOutputIndex( + const IrArray::Index& reduced_output_index, + const Shape& reduced_output_shape, const Shape& unreduced_output_shape, + llvm::IRBuilder<>* b); + +// A class to represent information for tiled parameters to support IR emission +// for 021 transpose. +class TiledParameterInfo { + public: + TiledParameterInfo(tensorflow::gtl::ArraySlice param_buffers, + llvm::Value* y, llvm::Value* x) + : param_buffers_(param_buffers), y_(y), x_(x) {} + + llvm::Value* x() const { return x_; } + llvm::Value* y() const { return y_; } + + void set_x(llvm::Value* x) { x_ = x; } + void set_y(llvm::Value* y) { y_ = y; } + + llvm::Value* GetBufferForParameter(int64 index) const { + return param_buffers_[index]; + } + + private: + // Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr + // if the parameter is not tiled. + tensorflow::gtl::ArraySlice param_buffers_; + // The y coordinate within a tile. + llvm::Value* y_; + // The x coordinate within a tile. + llvm::Value* x_; +}; + +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 9f867014fb015845448c4fcf9c165750f8a61935..ba7f94834c7fd04d97cec012537244323308b8ce 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -47,27 +47,27 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, /* static */ std::unique_ptr ForLoop::EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, - llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, + llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode, bool prevent_vectorization) { std::unique_ptr loop(new ForLoop(prefix, /*suffix=*/"", start_index, end_index, step, unroll_mode, prevent_vectorization)); - loop->Emit(ir_builder); + loop->Emit(b); return loop; } -void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { +void ForLoop::Emit(llvm::IRBuilder<>* b) { // The preheader block is the block the builder is currently emitting // code into. - preheader_bb_ = ir_builder->GetInsertBlock(); + preheader_bb_ = b->GetInsertBlock(); - llvm::BasicBlock::iterator insert_point = ir_builder->GetInsertPoint(); + llvm::BasicBlock::iterator insert_point = b->GetInsertPoint(); if (insert_point == preheader_bb_->end()) { // We're emitting the loop at the end of a basic block. Verify there is no // terminator (eg, branch) in the basic block. CHECK_EQ(nullptr, preheader_bb_->getTerminator()); - exit_bb_ = CreateLoopBB("loop_exit", ir_builder); + exit_bb_ = CreateLoopBB("loop_exit", b); } else { // We're emitting the loop into the middle of a basic block. splitBasicBlock // requires that this basic block be well-formed (have a terminator). @@ -86,51 +86,50 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { insert_before_bb_ = exit_bb_; // Create remaining basic block which form the inside of the loop. - header_bb_ = CreateLoopBB("loop_header", ir_builder); - body_bb_ = CreateLoopBB("loop_body", ir_builder); + header_bb_ = CreateLoopBB("loop_header", b); + body_bb_ = CreateLoopBB("loop_body", b); // Function entry basic block. // Emit alloca for the induction variable. We do this at the entry to the // basic block to ensure the alloc only executes once per function (we could // be emitting a nested loop). llvm::Function* func = preheader_bb_->getParent(); - ir_builder->SetInsertPoint(&func->getEntryBlock(), - func->getEntryBlock().getFirstInsertionPt()); + b->SetInsertPoint(&func->getEntryBlock(), + func->getEntryBlock().getFirstInsertionPt()); llvm::Value* indvar_address = - ir_builder->CreateAlloca(ir_builder->getInt64Ty(), nullptr, - AsStringRef(GetQualifiedName("invar_address"))); + b->CreateAlloca(start_index_->getType(), nullptr, + AsStringRef(GetQualifiedName("invar_address"))); // Preheader basic block. // Initialize induction variable starting index. Create branch to the header. - ir_builder->SetInsertPoint(preheader_bb_); - ir_builder->CreateStore(start_index_, indvar_address); + b->SetInsertPoint(preheader_bb_); + b->CreateStore(start_index_, indvar_address); // The preheader should not have a branch yet. CHECK_EQ(preheader_bb_->getTerminator(), nullptr); - ir_builder->CreateBr(header_bb_); + b->CreateBr(header_bb_); // Header basic block. // Emit the loop conditional branch. Load and compare indvar with ending // index and jump to loop exit if equal. Jump to body otherwise. - ir_builder->SetInsertPoint(header_bb_); - indvar_ = ir_builder->CreateLoad(indvar_address, - AsStringRef(GetQualifiedName("indvar"))); - llvm::Value* exit_cond = ir_builder->CreateICmpUGE(indvar_, end_index_); - ir_builder->CreateCondBr(/*Cond=*/exit_cond, - /*True=*/exit_bb_, /*False=*/body_bb_); + b->SetInsertPoint(header_bb_); + indvar_ = + b->CreateLoad(indvar_address, AsStringRef(GetQualifiedName("indvar"))); + llvm::Value* exit_cond = b->CreateICmpUGE(indvar_, end_index_); + b->CreateCondBr(/*Cond=*/exit_cond, + /*True=*/exit_bb_, /*False=*/body_bb_); // Body basic block. // Increment indvar, store indvar, and jump to header. - ir_builder->SetInsertPoint(body_bb_); + b->SetInsertPoint(body_bb_); llvm::Value* step = step_; llvm::Value* indvar = indvar_; - llvm::Value* indvar_inc = - ir_builder->CreateAdd(indvar, step, "invar.inc", - /*HasNUW=*/true, /*HasNSW=*/true); - ir_builder->CreateStore(indvar_inc, indvar_address); - llvm::BranchInst* back_branch = ir_builder->CreateBr(header_bb_); + llvm::Value* indvar_inc = b->CreateAdd(indvar, step, "invar.inc", + /*HasNUW=*/true, /*HasNSW=*/true); + b->CreateStore(indvar_inc, indvar_address); + llvm::BranchInst* back_branch = b->CreateBr(header_bb_); - std::vector loop_metadata = GetLoopMetadata(ir_builder); + std::vector loop_metadata = GetLoopMetadata(b); if (!loop_metadata.empty()) { llvm::LLVMContext* ctx = &start_index_->getContext(); auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None); @@ -141,11 +140,10 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { } // Re-point the IR builder to the loop exit block. - ir_builder->SetInsertPoint(exit_bb_); + b->SetInsertPoint(exit_bb_); } -std::vector ForLoop::GetLoopMetadata( - llvm::IRBuilder<>* ir_builder) { +std::vector ForLoop::GetLoopMetadata(llvm::IRBuilder<>* b) { const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable"; const char* const kLlvmLoopUnrollFullMDName = "llvm.loop.unroll.full"; const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable"; @@ -160,7 +158,7 @@ std::vector ForLoop::GetLoopMetadata( if (prevent_vectorization_) { result.push_back(llvm::MDNode::get( *ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName), - llvm::ConstantAsMetadata::get(ir_builder->getFalse())})); + llvm::ConstantAsMetadata::get(b->getFalse())})); } if (unroll_mode_ == xla::llvm_ir::UnrollMode::kFullyUnroll) { @@ -175,9 +173,8 @@ string ForLoop::GetQualifiedName(tensorflow::StringPiece name) { } llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder) { - return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), - ir_builder); + llvm::IRBuilder<>* b) { + return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), b); } std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, @@ -185,7 +182,7 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* end_index, UnrollMode unroll_mode, bool prevent_vectorization) { - return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1), + return AddLoop(suffix, start_index, end_index, GetConstantWithIndexType(1), unroll_mode, prevent_vectorization); } @@ -197,12 +194,12 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, bool prevent_vectorization) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. - ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); + b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); } std::unique_ptr loop(new ForLoop( /*prefix=*/name_, suffix, start_index, end_index, stride, unroll_mode, prevent_vectorization)); - loop->Emit(ir_builder_); + loop->Emit(b_); if (outer_loop_preheader_bb_ == nullptr) { outer_loop_preheader_bb_ = loop->GetPreheaderBasicBlock(); @@ -223,8 +220,8 @@ std::unique_ptr ForLoopNest::AddLoop(int64 start_index, UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); - return AddLoop(suffix, ir_builder_->getInt64(start_index), - ir_builder_->getInt64(end_index), unroll_mode, + return AddLoop(suffix, GetConstantWithIndexType(start_index), + GetConstantWithIndexType(end_index), unroll_mode, prevent_vectorization); } @@ -234,9 +231,9 @@ std::unique_ptr ForLoopNest::AddLoop(int64 start_index, UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); - return AddLoop(suffix, ir_builder_->getInt64(start_index), - ir_builder_->getInt64(end_index), - ir_builder_->getInt64(stride), unroll_mode, + return AddLoop(suffix, GetConstantWithIndexType(start_index), + GetConstantWithIndexType(end_index), + GetConstantWithIndexType(stride), unroll_mode, prevent_vectorization); } @@ -250,7 +247,7 @@ IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( const Shape& shape, tensorflow::gtl::ArraySlice dimensions, tensorflow::StringPiece suffix) { - llvm_ir::IrArray::Index index(shape.dimensions_size(), nullptr); + llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size()); for (int64 dimension : dimensions) { std::unique_ptr loop = AddLoop( /*start_index=*/0, @@ -262,5 +259,35 @@ IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( return index; } +IrArray::Index ForLoopNest::EmitOperandArrayLoopNest( + const llvm_ir::IrArray& operand_array, int64 dimension_to_skip, + tensorflow::StringPiece name_suffix) { + // Prepares the dimension list we will use to emit the loop nest. Outermost + // loops are added first. Add loops in major-to-minor order, and skip the + // 'dimension_to_skip' dimension. + std::vector dimensions; + const Shape& shape = operand_array.GetShape(); + for (int64 dimension : LayoutUtil::MinorToMajor(shape)) { + if (dimension != dimension_to_skip) { + dimensions.push_back(dimension); + } + } + + // Create loop nest with one for-loop for each dimension of the + // output. + llvm_ir::IrArray::Index index = + AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix); + // Verify every dimension except the 'dimension_to_skip' dimension was set in + // the index. + for (size_t dimension = 0; dimension < index.size(); ++dimension) { + if (dimension == dimension_to_skip) { + DCHECK_EQ(nullptr, index[dimension]); + } else { + DCHECK_NE(nullptr, index[dimension]); + } + } + return index; +} + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index 4e403cd994874c27453574283c5c573c876628db..a4fed5c8dc55d38d25031252e3960404a5bf84e6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -79,7 +79,7 @@ class ForLoop { // loop. static std::unique_ptr EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, - llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, + llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode = llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); @@ -138,10 +138,10 @@ class ForLoop { UnrollMode unroll_mode, bool prevent_vectorization); // Emit the loop at the insert point of the builder. - void Emit(llvm::IRBuilder<>* ir_builder); + void Emit(llvm::IRBuilder<>* b); llvm::BasicBlock* CreateLoopBB(tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Creates a name for an LLVM construct, appending prefix_ and suffix_, if // they are set. @@ -149,7 +149,7 @@ class ForLoop { // Return a list of metadata nodes that should be associated with the // llvm::Loop for this `ForLoop`. - std::vector GetLoopMetadata(llvm::IRBuilder<>* ir_builder); + std::vector GetLoopMetadata(llvm::IRBuilder<>* b); string prefix_; string suffix_; @@ -177,15 +177,20 @@ class ForLoop { // A simple class for constructing nested for-loops. class ForLoopNest { public: - explicit ForLoopNest(llvm::IRBuilder<>* ir_builder) - : ForLoopNest(/*name=*/"", ir_builder) {} + explicit ForLoopNest(llvm::IRBuilder<>* b, llvm::Type* index_ty = nullptr) + : ForLoopNest(/*name=*/"", b) { + SetIndexType(index_ty); + } - ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) + ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* b, + llvm::Type* index_ty = nullptr) : name_(std::string(name)), outer_loop_preheader_bb_(nullptr), outer_loop_exit_bb_(nullptr), inner_loop_body_bb_(nullptr), - ir_builder_(ir_builder) {} + b_(b) { + SetIndexType(index_ty); + } // Adds a loop to the nest. If no loop has been added yet then emit a loop at // the current insert point of the given builder. If one or more loops have @@ -242,6 +247,17 @@ class ForLoopNest { const Shape& shape, tensorflow::gtl::ArraySlice dimensions, tensorflow::StringPiece suffix); + // Emits a series of nested loops for iterating over an operand array. Loops + // are constructed in major to minor dimension layout order. No loop is + // emitted for the given 'dimension_to_skip'. The function returns an IrArray + // index for the given operand_array containing the indvars of the loops. All + // dimensions of the index are filled except for 'dimension_to_skip'. + // name_suffix is the string to append to the names of LLVM constructs (eg, + // basic blocks) constructed by this method. + IrArray::Index EmitOperandArrayLoopNest(const llvm_ir::IrArray& operand_array, + int64 dimension_to_skip, + tensorflow::StringPiece name_suffix); + // Convenience methods which return particular basic blocks of the outermost // or innermost loops. These methods return nullptr if no loops have been // added yet. @@ -252,6 +268,14 @@ class ForLoopNest { llvm::BasicBlock* GetInnerLoopBodyBasicBlock() { return inner_loop_body_bb_; } private: + void SetIndexType(llvm::Type* index_ty) { + index_type_ = index_ty == nullptr ? b_->getInt64Ty() : index_ty; + } + + llvm::Constant* GetConstantWithIndexType(int64 c) const { + return llvm::ConstantInt::get(index_type_, c); + } + // Human-friendly name of the loop nest. string name_; @@ -264,7 +288,9 @@ class ForLoopNest { // has been added yet. llvm::BasicBlock* inner_loop_body_bb_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; + + llvm::Type* index_type_; TF_DISALLOW_COPY_AND_ASSIGN(ForLoopNest); }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index ff64da87e9c9acf8a9d7ff87d3b1be7a9e9106bb..e6126881af8b8123e08a4eaa934b52a7fd378ce6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -26,7 +26,7 @@ limitations under the License. #include "llvm/Target/TargetOptions.h" #include "llvm/Transforms/Utils/Cloning.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -47,8 +48,8 @@ namespace { // Note, this function is only useful in an insertion context; in a global // (e.g. constants) context it will CHECK fail. -llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* ir_builder) { - auto block = CHECK_NOTNULL(ir_builder->GetInsertBlock()); +llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* b) { + auto block = CHECK_NOTNULL(b->GetInsertBlock()); auto fn = CHECK_NOTNULL(block->getParent()); auto module = CHECK_NOTNULL(fn->getParent()); return module; @@ -86,41 +87,41 @@ llvm::Value* EmitCallToIntrinsic( llvm::Intrinsic::ID intrinsic_id, tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice overloaded_types, - llvm::IRBuilder<>* ir_builder) { - llvm::Module* module = ModuleFromIRBuilder(ir_builder); + llvm::IRBuilder<>* b) { + llvm::Module* module = ModuleFromIRBuilder(b); llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( module, intrinsic_id, AsArrayRef(overloaded_types)); - return ir_builder->CreateCall(intrinsic, AsArrayRef(operands)); + return b->CreateCall(intrinsic, AsArrayRef(operands)); } llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, - llvm::IRBuilder<>* ir_builder) { - if (ir_builder->getFastMathFlags().noNaNs()) { - auto cmp = ir_builder->CreateFCmpUGE(lhs_value, rhs_value); - return ir_builder->CreateSelect(cmp, lhs_value, rhs_value); + llvm::IRBuilder<>* b) { + if (b->getFastMathFlags().noNaNs()) { + auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value); + return b->CreateSelect(cmp, lhs_value, rhs_value); } else { - auto cmp_ge = ir_builder->CreateFCmpOGE(lhs_value, rhs_value); - auto lhs_is_nan = ir_builder->CreateFCmpUNE(lhs_value, lhs_value); - auto sel_lhs = ir_builder->CreateOr(cmp_ge, lhs_is_nan); - return ir_builder->CreateSelect(sel_lhs, lhs_value, rhs_value); + auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value); + auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value); + auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan); + return b->CreateSelect(sel_lhs, lhs_value, rhs_value); } } llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value, - llvm::IRBuilder<>* ir_builder) { - if (ir_builder->getFastMathFlags().noNaNs()) { - auto cmp = ir_builder->CreateFCmpULE(lhs_value, rhs_value); - return ir_builder->CreateSelect(cmp, lhs_value, rhs_value); + llvm::IRBuilder<>* b) { + if (b->getFastMathFlags().noNaNs()) { + auto cmp = b->CreateFCmpULE(lhs_value, rhs_value); + return b->CreateSelect(cmp, lhs_value, rhs_value); } else { - auto cmp_le = ir_builder->CreateFCmpOLE(lhs_value, rhs_value); - auto lhs_is_nan = ir_builder->CreateFCmpUNE(lhs_value, lhs_value); - auto sel_lhs = ir_builder->CreateOr(cmp_le, lhs_is_nan); - return ir_builder->CreateSelect(sel_lhs, lhs_value, rhs_value); + auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value); + auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value); + auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan); + return b->CreateSelect(sel_lhs, lhs_value, rhs_value); } } llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index, - llvm::IRBuilder<>* ir_builder) { + llvm::IRBuilder<>* b) { llvm::Type* array_type = array->getType(); CHECK(array_type->isPointerTy()); llvm::PointerType* array_type_as_pointer = @@ -130,16 +131,16 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index, << " array=" << llvm_ir::DumpToString(*array) << " index=" << llvm_ir::DumpToString(*index); - return ir_builder->CreateInBoundsGEP( + return b->CreateInBoundsGEP( array_type_as_pointer->getElementType(), array, llvm::isa(array) - ? llvm::ArrayRef({ir_builder->getInt64(0), index}) + ? llvm::ArrayRef({b->getInt64(0), index}) : index); } llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index, - llvm::IRBuilder<>* ir_builder) { - return EmitBufferIndexingGEP(array, ir_builder->getInt64(index), ir_builder); + llvm::IRBuilder<>* b) { + return EmitBufferIndexingGEP(array, b->getInt64(index), b); } llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, @@ -193,6 +194,10 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, // An Opaque is like a void*, use i8*. case OPAQUE: return llvm::Type::getInt8PtrTy(module->getContext()); + case TOKEN: + // Tokens do not have a physical representation, but the compiler needs + // some placeholder type, so use int8*. + return llvm::Type::getInt8PtrTy(module->getContext()); default: LOG(FATAL) << "unsupported type " << element_type; } @@ -227,14 +232,15 @@ llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) { return result_type; } -StatusOr EncodeSelfDescribingShapeConstant( - const Shape& shape, int32* shape_size, llvm::IRBuilder<>* ir_builder) { +StatusOr EncodeSelfDescribingShapeConstant(const Shape& shape, + int32* shape_size, + llvm::IRBuilder<>* b) { string encoded_shape = shape.SerializeAsString(); if (encoded_shape.size() > std::numeric_limits::max()) { return InternalError("Encoded shape size exceeded int32 size limit."); } *shape_size = static_cast(encoded_shape.size()); - return ir_builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(encoded_shape)); + return b->CreateGlobalStringPtr(llvm_ir::AsStringRef(encoded_shape)); } StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, @@ -245,224 +251,69 @@ StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, return shape; } -namespace { - -// Recursively construct a multidimensional LLVM constant which represents the -// given literal. The minor-to-major dimension ordering in the constant matches -// that of the literal. For example, given a [2 x 3 x 4] Literal (dimension 0 -// has size 4, dimension 1 has size 3, etc) of primitive type F32 with a -// minor_to_major value of [2, 1, 0] (column major), a LLVM constant of type -// [4 x [3 x [2 x float]] will be returned. -// -// multi_index is a multidimensional index into the array. dimension_index is an -// index into the minor_to_major field in the literal shape. This determines -// which dimension is iterated over in this level of the recursion. Dimensions -// are iterated from most major down to most minor (highest dimension_index -// value down to zero). -llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, - std::vector* multi_index, - llvm::Module* module) { - const Shape& shape = literal.shape(); - llvm::Type* ir_element_type = - llvm_ir::PrimitiveTypeToIrType(shape.element_type(), module); - if (dimension_index == -1) { - // Base case of the recursion. Index into the data field of the protobuf - // with the multi index. - llvm::Constant* value; - switch (shape.element_type()) { - case PRED: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case U8: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case S32: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case U32: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case S64: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case U64: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case F32: - value = llvm::ConstantFP::get(ir_element_type, - literal.Get(*multi_index)); - break; - case BF16: - value = llvm::ConstantInt::get( - ir_element_type, - tensorflow::bit_cast(literal.Get(*multi_index))); - break; - case F16: - value = llvm::ConstantFP::get( - ir_element_type, - static_cast(literal.Get(*multi_index))); - break; - case F64: - value = llvm::ConstantFP::get(ir_element_type, - literal.Get(*multi_index)); - break; - case C64: { - complex64 x = literal.Get(*multi_index); - value = llvm::ConstantStruct::get( - static_cast(ir_element_type), - llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module), - x.real()), - llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module), - x.imag())); - break; - } - default: - LOG(FATAL) << "unsupported type " << shape.element_type(); - } - return value; - } - - // The dimension index starts at the one less than the rank of the array and - // decrements with each recursive call. We want to iterate through the - // dimensions in major-to-minor order as we recurse so just index into - // minor_to_major to get the dimension number for this level of the recursion. - int64 dimension = LayoutUtil::Minor(shape.layout(), dimension_index); - - // Recursively call LiteralToConstant to construct subarrays for the - // more-minor dimensions. Gather the subarrays into a vector for bundling into - // a new (higher-dimensional) ConstantArray. - std::vector elements; - for (int64 i = 0; i < shape.dimensions(dimension); ++i) { - (*multi_index)[dimension] = i; - elements.push_back( - LiteralToConstant(literal, dimension_index - 1, multi_index, module)); - } - - llvm::Type* element_type; - if (elements.empty()) { - element_type = ir_element_type; - for (int i = 0; i < dimension_index; ++i) { - int64 index = LayoutUtil::Minor(shape.layout(), i); - element_type = - llvm::ArrayType::get(element_type, shape.dimensions(index)); - } - } else { - element_type = elements[0]->getType(); - } - llvm::ArrayType* aggregate_type = - llvm::ArrayType::get(element_type, shape.dimensions(dimension)); - return llvm::ConstantArray::get(aggregate_type, elements); -} - -template -llvm::Constant* GetConstantDataArray(const Literal& literal, - llvm::Module* module) { - const T* data = static_cast(literal.untyped_data()); - int64 num_elements = literal.size_bytes() / sizeof(T); - return llvm::ConstantDataArray::get(module->getContext(), - llvm::makeArrayRef(data, num_elements)); -} - -} // namespace - llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, llvm::Module* module) { - const Shape& shape = literal.shape(); - // TODO(b/29904935): We can get rid of this switch by exposing a - // ConstantDataArray factory method that takes a llvm::Type and a StringRef. - switch (shape.element_type()) { - case U64: - return GetConstantDataArray(literal, module); - case U32: - return GetConstantDataArray(literal, module); - case U8: - return GetConstantDataArray(literal, module); - case S64: - return GetConstantDataArray(literal, module); - case S32: - return GetConstantDataArray(literal, module); - case F64: - return GetConstantDataArray(literal, module); - case F32: - return GetConstantDataArray(literal, module); - case BF16: - case F16: - return GetConstantDataArray(literal, module); - case PRED: - return GetConstantDataArray(literal, module); - // TODO(b/29904935): Also use ConstantDataArray for complex numbers. - case C64: { - int64 dimensions = ShapeUtil::Rank(shape); - std::vector multi_index(dimensions, 0); - return LiteralToConstant(literal, /*dimension_index=*/dimensions - 1, - &multi_index, module); - } - default: - LOG(FATAL) << "unsupported type " << shape.element_type(); - } + const char* data = static_cast(literal.untyped_data()); + CHECK_EQ(module->getDataLayout().isLittleEndian(), + tensorflow::port::kLittleEndian); + return llvm::ConstantDataArray::getString( + module->getContext(), llvm::StringRef(data, literal.size_bytes()), + /*AddNull=*/false); } llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder, + llvm::IRBuilder<>* b, int alignment) { - return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, ir_builder, - alignment); + return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment); } llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder, int alignment) { - llvm::IRBuilder<>::InsertPoint insert_point = ir_builder->saveIP(); - llvm::Function* function = ir_builder->GetInsertBlock()->getParent(); - ir_builder->SetInsertPoint(&function->getEntryBlock(), - function->getEntryBlock().getFirstInsertionPt()); + llvm::IRBuilder<>* b, int alignment) { + llvm::IRBuilder<>::InsertPoint insert_point = b->saveIP(); + llvm::Function* function = b->GetInsertBlock()->getParent(); + b->SetInsertPoint(&function->getEntryBlock(), + function->getEntryBlock().getFirstInsertionPt()); llvm::AllocaInst* alloca = - ir_builder->CreateAlloca(type, element_count, AsStringRef(name)); + b->CreateAlloca(type, element_count, AsStringRef(name)); if (alignment != 0) { alloca->setAlignment(alignment); } - ir_builder->restoreIP(insert_point); + b->restoreIP(insert_point); return alloca; } llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder) { + llvm::IRBuilder<>* b) { return llvm::BasicBlock::Create( - /*Context=*/ir_builder->getContext(), + /*Context=*/b->getContext(), /*Name=*/AsStringRef(name), - /*Parent=*/ir_builder->GetInsertBlock()->getParent(), + /*Parent=*/b->GetInsertBlock()->getParent(), /*InsertBefore*/ insert_before); } LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder, bool emit_else) { + llvm::IRBuilder<>* b, bool emit_else) { llvm_ir::LlvmIfData if_data; - if_data.if_block = ir_builder->GetInsertBlock(); - if_data.true_block = CreateBasicBlock( - nullptr, tensorflow::strings::StrCat(name, "-true"), ir_builder); + if_data.if_block = b->GetInsertBlock(); + if_data.true_block = + CreateBasicBlock(nullptr, tensorflow::strings::StrCat(name, "-true"), b); if_data.false_block = - emit_else ? CreateBasicBlock(nullptr, - tensorflow::strings::StrCat(name, "-false"), - ir_builder) + emit_else ? CreateBasicBlock( + nullptr, tensorflow::strings::StrCat(name, "-false"), b) : nullptr; // Add a terminator to the if block, if necessary. if (if_data.if_block->getTerminator() == nullptr) { - ir_builder->SetInsertPoint(if_data.if_block); + b->SetInsertPoint(if_data.if_block); if_data.after_block = CreateBasicBlock( - nullptr, tensorflow::strings::StrCat(name, "-after"), ir_builder); - ir_builder->CreateBr(if_data.after_block); + nullptr, tensorflow::strings::StrCat(name, "-after"), b); + b->CreateBr(if_data.after_block); } else { if_data.after_block = if_data.if_block->splitBasicBlock( - ir_builder->GetInsertPoint(), + b->GetInsertPoint(), AsStringRef(tensorflow::strings::StrCat(name, "-after"))); } @@ -470,39 +321,37 @@ LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, // we're going to replace it with a conditional branch. if_data.if_block->getTerminator()->eraseFromParent(); - ir_builder->SetInsertPoint(if_data.if_block); - ir_builder->CreateCondBr( - condition, if_data.true_block, - emit_else ? if_data.false_block : if_data.after_block); + b->SetInsertPoint(if_data.if_block); + b->CreateCondBr(condition, if_data.true_block, + emit_else ? if_data.false_block : if_data.after_block); - ir_builder->SetInsertPoint(if_data.true_block); - ir_builder->CreateBr(if_data.after_block); + b->SetInsertPoint(if_data.true_block); + b->CreateBr(if_data.after_block); if (emit_else) { - ir_builder->SetInsertPoint(if_data.false_block); - ir_builder->CreateBr(if_data.after_block); + b->SetInsertPoint(if_data.false_block); + b->CreateBr(if_data.after_block); } - ir_builder->SetInsertPoint(if_data.after_block, - if_data.after_block->getFirstInsertionPt()); + b->SetInsertPoint(if_data.after_block, + if_data.after_block->getFirstInsertionPt()); return if_data; } llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate, llvm::Value* lhs_value, llvm::Value* rhs_value, - llvm::IRBuilder<>* ir_builder) { + llvm::IRBuilder<>* b) { llvm::Value* comparison_result; if (lhs_value->getType()->isIntegerTy()) { - comparison_result = ir_builder->CreateICmp(predicate, lhs_value, rhs_value); + comparison_result = b->CreateICmp(predicate, lhs_value, rhs_value); } else { - comparison_result = ir_builder->CreateFCmp(predicate, lhs_value, rhs_value); + comparison_result = b->CreateFCmp(predicate, lhs_value, rhs_value); } // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1 // arrays. So we extend it to i8 so that it's addressable. - return ir_builder->CreateZExt( - comparison_result, - llvm_ir::PrimitiveTypeToIrType(PRED, ModuleFromIRBuilder(ir_builder))); + return b->CreateZExt(comparison_result, llvm_ir::PrimitiveTypeToIrType( + PRED, ModuleFromIRBuilder(b))); } // Internal helper that is called from emitted code to log an int64 value with a @@ -511,17 +360,14 @@ static void LogS64(const char* tag, int64 value) { LOG(INFO) << tag << " (int64): " << value; } -void EmitLogging(const char* tag, llvm::Value* value, - llvm::IRBuilder<>* ir_builder) { +void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b) { llvm::FunctionType* log_function_type = llvm::FunctionType::get( - ir_builder->getVoidTy(), - {ir_builder->getInt64Ty(), ir_builder->getInt64Ty()}, /*isVarArg=*/false); - ir_builder->CreateCall( + b->getVoidTy(), {b->getInt64Ty(), b->getInt64Ty()}, /*isVarArg=*/false); + b->CreateCall( log_function_type, - ir_builder->CreateIntToPtr( - ir_builder->getInt64(tensorflow::bit_cast(&LogS64)), - log_function_type->getPointerTo()), - {ir_builder->getInt64(tensorflow::bit_cast(tag)), value}); + b->CreateIntToPtr(b->getInt64(tensorflow::bit_cast(&LogS64)), + log_function_type->getPointerTo()), + {b->getInt64(tensorflow::bit_cast(tag)), value}); } void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) { @@ -811,5 +657,56 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { } } +std::pair UMulLowHigh32(llvm::IRBuilder<>* b, + llvm::Value* src0, + llvm::Value* src1) { + CHECK_EQ(src0->getType()->getPrimitiveSizeInBits(), 32); + CHECK_EQ(src1->getType()->getPrimitiveSizeInBits(), 32); + llvm::Type* int64_ty = b->getInt64Ty(); + src0 = b->CreateZExt(src0, int64_ty); + src1 = b->CreateZExt(src1, int64_ty); + return SplitInt64ToInt32s(b, b->CreateMul(src0, src1)); +} + +std::pair SplitInt64ToInt32s( + llvm::IRBuilder<>* b, llvm::Value* value_64bits) { + CHECK_EQ(value_64bits->getType()->getPrimitiveSizeInBits(), 64); + llvm::Type* int32_ty = b->getInt32Ty(); + llvm::Value* low_32bits = b->CreateTrunc(value_64bits, int32_ty); + llvm::Value* high_32bits = + b->CreateTrunc(b->CreateLShr(value_64bits, 32), int32_ty); + return std::make_pair(low_32bits, high_32bits); +} + +llvm::GlobalVariable* GetOrCreateVariableForPhiloxRngState( + llvm::Module* module, llvm::IRBuilder<>* b) { + static const char* kPhiloxRngStateVariableName = "philox_rng_state"; + llvm::GlobalVariable* state_ptr = + module->getNamedGlobal(kPhiloxRngStateVariableName); + if (!state_ptr) { + state_ptr = new llvm::GlobalVariable( + /*M=*/*module, + /*Ty=*/b->getInt64Ty(), + /*isConstant=*/false, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/b->getInt64(0), + /*Name=*/kPhiloxRngStateVariableName); + } + return state_ptr; +} + +void IncrementVariableForPhiloxRngState(int64 value, llvm::Module* module, + llvm::IRBuilder<>* builder) { + llvm::GlobalVariable* state_ptr = + GetOrCreateVariableForPhiloxRngState(module, builder); + llvm::Value* state_value_old = builder->CreateLoad(state_ptr, "load_state"); + // If the 64-bit value overflows, we use the wraparound value. This should + // be fine in practice as we only add one to the value each time when a RNG is + // executed. + llvm::Value* state_value_new = builder->CreateAdd( + state_value_old, builder->getInt64(value), "inc_state"); + builder->CreateStore(state_value_new, state_ptr); +} + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 4a10ec466dae6fdb56546fb8d8b353dcff6a5b8d..09583985342033d486d50910b6f5ca732a9a3756 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -27,7 +27,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "llvm/Support/raw_ostream.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/types.h" @@ -105,26 +105,26 @@ llvm::Value* EmitCallToIntrinsic( llvm::Intrinsic::ID intrinsic_id, tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice overloaded_types, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Emit float max. Emit maxnum intrinsic is fast math is disabled, or // fcmp+select otherwise llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Emit float min. Emit minnum intrinsic is fast math is disabled, or // fcmp+select otherwise llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Convenience methods for emitting a GEP instruction that indexes into a buffer // (1-dimensional array), equivalent to array[index]. The type is automatically // determined from the element type of the array. The int64 index overload // wraps the index in a i64 llvm::Value. llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Returns the LLVM type which represents the given XLA primitive type. llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, @@ -139,8 +139,9 @@ llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module); // Returns a value that represents a pointer to a global string constant that // encodes the shape as a serialized protobuf. -StatusOr EncodeSelfDescribingShapeConstant( - const Shape& shape, int32* shape_size, llvm::IRBuilder<>* ir_builder); +StatusOr EncodeSelfDescribingShapeConstant(const Shape& shape, + int32* shape_size, + llvm::IRBuilder<>* b); // Inverses the encoding of a Shape protobuf into an LLVM global variable. // @@ -164,21 +165,21 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, // through a loop. llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder, + llvm::IRBuilder<>* b, int alignment = 0); // As EmitAllocaAtFunctionEntry, but allocates element_count entries // instead of a single element. llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder, int alignment = 0); + llvm::IRBuilder<>* b, int alignment = 0); // Creates a basic block with the same context and function as for the // builder. Inserts at the end of the function if insert_before is // null. llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Struct with data on a conditional branch in a diamond shape created // via EmitIfThenElse. @@ -210,13 +211,13 @@ struct LlvmIfData { // block with a terminator. If you need to use this for a // non-terminated block, just make the function able to do that too. LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder, bool emit_else = true); + llvm::IRBuilder<>* b, bool emit_else = true); // Emits a compare operation between "lhs" and "rhs" with the given predicate, // and then converts the result to i8 so that it is addressable. llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate, llvm::Value* lhs, llvm::Value* rhs, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Emits a call that logs the given value with the given tag as a prefix. // The provided tag and value are passed to a runtime logging call that is @@ -228,8 +229,7 @@ llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate, // Precondition: value must be an int64. // Precondition: tag must be a stable pointer for the lifetime of the generated // program (the constant pointer is burned in to the program). -void EmitLogging(const char* tag, llvm::Value* value, - llvm::IRBuilder<>* ir_builder); +void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b); // Adds alignment metadata to a load instruction using the given alignment. // The alignment refers to the result of the load, not the load itself. @@ -292,6 +292,27 @@ llvm::Function* CreateFunction(llvm::FunctionType* function_type, // don't start with xla_ to LLVM. void InitializeLLVMCommandLineOptions(const HloModuleConfig& config); +// Zero-extends two 32-bit values to 64 bits, multiplies them, and returns the +// result as a pair of (low 32 bits, high 32 bits). +std::pair UMulLowHigh32(llvm::IRBuilder<>* b, + llvm::Value* src0, + llvm::Value* src1); +// Splits the 64-bit integer value into its high and low 32 bits. +std::pair SplitInt64ToInt32s( + llvm::IRBuilder<>* b, llvm::Value* value_64bits); + +// Checks whether a global variable is already created to represent a +// state passed between RNG calls implemented with Philox algorithm. If not, +// creates such a variable. Returns the global variable. +llvm::GlobalVariable* GetOrCreateVariableForPhiloxRngState( + llvm::Module* module, llvm::IRBuilder<>* b); + +// Adds a value to the global state variable each time when a RNG hlo is +// executed. The value of this global state variable is added to the seed +// of the Philox RNG algorithm so that calling the same RNG Hlo multiple times +// should rarely produce the same result. +void IncrementVariableForPhiloxRngState(int64 value, llvm::Module* module, + llvm::IRBuilder<>* b); } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index dc2934a34c23f8229947210cacc9863d47c2ea55..36f5fa195224c20e30a14f72b32eb42a681bb5e9 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -33,26 +33,24 @@ namespace xla { namespace llvm_ir { LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, - llvm::IRBuilder<>* ir_builder) - : body_emitter_(body_emitter), shape_(shape), ir_builder_(ir_builder) {} + llvm::IRBuilder<>* b) + : body_emitter_(body_emitter), shape_(shape), b_(b) {} LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, - const IrArray& target_array, - llvm::IRBuilder<>* ir_builder) + const IrArray& target_array, llvm::IRBuilder<>* b) : body_emitter_([=](const llvm_ir::IrArray::Index array_index) -> Status { // Convert target_element_generator to a BodyEmitter. TF_ASSIGN_OR_RETURN(llvm::Value * target_element, target_element_generator(array_index)); - target_array.EmitWriteArrayElement(array_index, target_element, - ir_builder); + target_array.EmitWriteArrayElement(array_index, target_element, b); return Status::OK(); }), shape_(target_array.GetShape()), - ir_builder_(ir_builder) {} + b_(b) {} static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion( const ElementGenerator& target_element_generator, - const std::vector& target_arrays, llvm::IRBuilder<>* ir_builder) { + const std::vector& target_arrays, llvm::IRBuilder<>* b) { return [=](const llvm_ir::IrArray::Index array_index) { TF_ASSIGN_OR_RETURN(llvm::Value * target_element, target_element_generator(array_index)); @@ -64,8 +62,7 @@ static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion( for (int64 i = 0; i < target_arrays.size(); ++i) { target_arrays[i].EmitWriteArrayElement( - array_index, ir_builder->CreateExtractValue(target_element, i), - ir_builder); + array_index, b->CreateExtractValue(target_element, i), b); } return Status::OK(); }; @@ -73,13 +70,12 @@ static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion( LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, - llvm::IRBuilder<>* ir_builder) + llvm::IRBuilder<>* b) : body_emitter_(MakeBodyEmitterForMultiOutputFusion( target_element_generator, - std::vector(target_arrays.begin(), target_arrays.end()), - ir_builder)), + std::vector(target_arrays.begin(), target_arrays.end()), b)), shape_(target_arrays[0].GetShape()), - ir_builder_(ir_builder) { + b_(b) { // Sanity check: In multi-output fusion, all shapes produced must have the // same dimensions. for (const IrArray& array : target_arrays) { @@ -90,19 +86,20 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, } std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) { + tensorflow::StringPiece loop_name, llvm::Type* index_type) { + CHECK_NE(index_type, nullptr); if (ShapeUtil::IsScalar(shape_)) { // No loop needed, so set exit_bb_ to nullptr. exit_bb_ = nullptr; - return {IrArray::Index()}; + return {IrArray::Index(index_type)}; } // Create loop nest with one for-loop for each dimension of the target shape. // Loops are added from outermost to innermost order with the ForLoopNest // class so emit loops in order from most-major dimension down to most-minor // dimension (of the target shape). - ForLoopNest loop_nest(loop_name, ir_builder_); - IrArray::Index array_index(shape_.dimensions_size()); + ForLoopNest loop_nest(loop_name, b_); + IrArray::Index array_index(index_type, shape_.dimensions_size()); for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { int64 dimension = LayoutUtil::Major(shape_.layout(), i); std::unique_ptr loop = loop_nest.AddLoop( @@ -115,8 +112,8 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( // Set IR builder insertion point to the loop body basic block of the // innermost loop. llvm::BasicBlock* innermost_body_bb = loop_nest.GetInnerLoopBodyBasicBlock(); - ir_builder_->SetInsertPoint(innermost_body_bb, - innermost_body_bb->getFirstInsertionPt()); + b_->SetInsertPoint(innermost_body_bb, + innermost_body_bb->getFirstInsertionPt()); // Set exit_bb_ to the exit block of the loop nest. exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock(); @@ -125,16 +122,21 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( return {array_index}; } -Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { +Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name, + llvm::Type* index_type) { + if (index_type == nullptr) { + index_type = b_->getInt64Ty(); + } + for (const IrArray::Index& array_index : - EmitIndexAndSetExitBasicBlock(loop_name)) { + EmitIndexAndSetExitBasicBlock(loop_name, index_type)) { TF_RETURN_IF_ERROR(body_emitter_(array_index)); } - // Set the insertion point of ir_builder_ to the loop exit, so that + // Set the insertion point of b_ to the loop exit, so that // code emitted for later instructions will be correctly placed. if (exit_bb_ != nullptr) { - ir_builder_->SetInsertPoint(exit_bb_); + b_->SetInsertPoint(exit_bb_); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index b70d28ecd3033eb26629718e50ce48f39b162273..c4f5c82086ccfa233e0be118b1de10cce55a51b1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -41,11 +41,11 @@ class LoopEmitter { using BodyEmitter = std::function; LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Constructs a LoopEmitter from an element generator that generates each // element of the given target array. LoopEmitter(const ElementGenerator& target_element_generator, - const IrArray& target_array, llvm::IRBuilder<>* ir_builder); + const IrArray& target_array, llvm::IRBuilder<>* b); // Constructs a LoopEmitter that emits one element into each of N separate // arrays on each iteration of the loop. @@ -54,7 +54,7 @@ class LoopEmitter { // produce an LLVM struct with N elements. LoopEmitter(const ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); LoopEmitter(const LoopEmitter&) = delete; LoopEmitter& operator=(const LoopEmitter&) = delete; @@ -65,13 +65,15 @@ class LoopEmitter { // specifies the element, will return multiple indices if the loop is // unrolled. std::vector EmitIndexAndSetExitBasicBlock() { - return EmitIndexAndSetExitBasicBlock(/*loop_name=*/""); + return EmitIndexAndSetExitBasicBlock(/*loop_name=*/"", b_->getInt64Ty()); } + virtual std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name); + tensorflow::StringPiece loop_name, llvm::Type* index_type); // Emits a complete loop nest for every element in the given shape. - Status EmitLoop(tensorflow::StringPiece loop_name = ""); + Status EmitLoop(tensorflow::StringPiece loop_name = "", + llvm::Type* index_type = nullptr); protected: // An IR emitter that generates the loop body. @@ -84,7 +86,7 @@ class LoopEmitter { // scalar, no loops are emitted and exit_bb_ is nullptr in that case. llvm::BasicBlock* exit_bb_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; }; } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/math_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/math_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..0e115cdabf4b290617700276dba8f2e5648a7c07 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/math_ops.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" + +namespace xla { +namespace llvm_ir { + +llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input) { + llvm::Type* type = input->getType(); + + // Clamp the input to [-9, 9]. + llvm::Value* input_clamped = llvm_ir::EmitFloatMin( + llvm_ir::EmitFloatMax(input, llvm::ConstantFP::get(type, -9.0), b), + llvm::ConstantFP::get(type, 9.0), b); + + static constexpr std::array numerator_coeffs{ + -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, + 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, + 4.89352455891786e-03f}; + + static constexpr std::array denominator_coeffs{ + 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, + 4.89352518554385e-03f}; + + llvm::Value* input_squared = b->CreateFMul(input_clamped, input_clamped); + llvm::Value* numerator = llvm::ConstantFP::get(type, numerator_coeffs[0]); + for (int i = 1; i < numerator_coeffs.size(); i++) { + numerator = b->CreateFAdd(b->CreateFMul(input_squared, numerator), + llvm::ConstantFP::get(type, numerator_coeffs[i])); + } + + numerator = b->CreateFMul(input_clamped, numerator); + + llvm::Value* denominator = llvm::ConstantFP::get(type, denominator_coeffs[0]); + for (int i = 1; i < denominator_coeffs.size(); i++) { + denominator = + b->CreateFAdd(b->CreateFMul(input_squared, denominator), + llvm::ConstantFP::get(type, denominator_coeffs[i])); + } + + return b->CreateFDiv(numerator, denominator); +} + +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/math_ops.h b/tensorflow/compiler/xla/service/llvm_ir/math_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..6c8bc3a076367eae2f1829966be2872e5f258178 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/math_ops.h @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_MATH_OPS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_MATH_OPS_H_ + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" + +namespace xla { +namespace llvm_ir { + +// Emits an approximation of tanh. The implementation uses the same rational +// interpolant as implemented in Eigen3. +llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input); + +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_MATH_OPS_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..e546f5cc4ae305b40c1bdbcae090daadee11241b --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -0,0 +1,161 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" + +// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace llvm_ir { + +namespace { +// Adds the inner comparison loop where we compare elements pointed to by +// 'keys_index' and 'compare_keys_index'. +void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, + const IrArray::Index& compare_keys_index, + const IrArray& keys_array, + const tensorflow::gtl::optional& values_array, + llvm::IRBuilder<>* b) { + // if (is_smaller_index && + // compare_keys[dimension_to_sort] < dimension_to_sort_bound) + llvm::Value* is_smaller_index = b->CreateICmpSLT( + keys_index[dimension_to_sort], compare_keys_index[dimension_to_sort]); + int64 dimension_to_sort_bound = + keys_array.GetShape().dimensions(dimension_to_sort); + auto if_data = EmitIfThenElse( + b->CreateAnd(is_smaller_index, + b->CreateICmpSLT(compare_keys_index[dimension_to_sort], + keys_index.GetConstantWithIndexType( + dimension_to_sort_bound))), + "smaller_comparison_index", b, /*emit_else=*/false); + SetToFirstInsertPoint(if_data.true_block, b); + auto key1 = keys_array.EmitReadArrayElement(keys_index, b); + auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, b); + auto key_type = keys_array.GetShape().element_type(); + auto comparison = + primitive_util::IsFloatingPointType(key_type) + // TODO(b/26783907): Figure out how to handle NaNs. + ? b->CreateFCmp(llvm::FCmpInst::FCMP_ULT, key2, key1) + : b->CreateICmp(primitive_util::IsSignedIntegralType(key_type) + ? llvm::ICmpInst::ICMP_SLT + : llvm::ICmpInst::ICMP_ULT, + key2, key1); + // If key2 < key1 + auto if_smaller_data = + EmitIfThenElse(comparison, "is_smaller_than", b, /*emit_else=*/false); + SetToFirstInsertPoint(if_smaller_data.true_block, b); + // Swap key1 with key2. + keys_array.EmitWriteArrayElement(keys_index, key2, b); + keys_array.EmitWriteArrayElement(compare_keys_index, key1, b); + if (values_array.has_value()) { + // Also swap the values. + auto value1 = values_array.value().EmitReadArrayElement(keys_index, b); + auto value2 = + values_array.value().EmitReadArrayElement(compare_keys_index, b); + values_array.value().EmitWriteArrayElement(keys_index, value2, b); + values_array.value().EmitWriteArrayElement(compare_keys_index, value1, b); + } +} +} // namespace + +Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, + const tensorflow::gtl::optional& values_array, + tensorflow::StringPiece name, llvm::Value* xor_mask, + llvm::IRBuilder<>* b, + const gpu::LaunchDimensions* launch_dimensions) { + const Shape& keys_shape = keys_array.GetShape(); + + // Create loop nests which loop through the operand dimensions. The sort + // dimension is handled in the innermost loop which performs the sorting. + ForLoopNest loop_nest(name, b); + IrArray::Index keys_index = + loop_nest.EmitOperandArrayLoopNest(keys_array, dimension_to_sort, "keys"); + if (loop_nest.GetInnerLoopBodyBasicBlock() != nullptr) { + SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), b); + } + + // 'compare_keys_index' is the index of the element that 'keys_index' should + // be compared to. + IrArray::Index compare_keys_index(keys_index.GetType()); + for (size_t dimension = 0; dimension < keys_index.size(); ++dimension) { + if (dimension != dimension_to_sort) { + compare_keys_index.push_back(keys_index[dimension]); + } else { + compare_keys_index.push_back(nullptr); + } + } + + // Naive C++ code for the inner compare loop: + // + // for (int64 i = 0; i < dimension_to_sort_bound; ++i) { + // int64 j = i ^ xor_mask; + // if (i < j && j < dimension_to_sort_bound) { + // int64 min_key = std::min(keys[i], keys[j]); + // keys[j] = std::max(keys[i], keys[j]); + // keys[i] = min_key; + // } + // } + // + // This follows the algorithm described on Wikipedia: + // https://en.wikipedia.org/wiki/Bitonic_sorter + + int64 dimension_to_sort_bound = + keys_array.GetShape().dimensions(dimension_to_sort); + Shape compare_shape = ShapeUtil::MakeShape(keys_shape.element_type(), + {dimension_to_sort_bound}); + auto compare_loop_body_emitter = + [&](const IrArray::Index& compare_index) -> Status { + keys_index[dimension_to_sort] = compare_index[0]; + compare_keys_index[dimension_to_sort] = + b->CreateXor(compare_index[0], xor_mask); + EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index, + keys_array, values_array, b); + return Status::OK(); + }; + if (launch_dimensions != nullptr) { + TF_RETURN_IF_ERROR(gpu::ParallelLoopEmitter(compare_loop_body_emitter, + compare_shape, + *launch_dimensions, b) + .EmitLoop(name)); + } else { + TF_RETURN_IF_ERROR(LoopEmitter(compare_loop_body_emitter, compare_shape, b) + .EmitLoop(name)); + } + + // Set the IR builder insert point to the exit basic block of the outer most + // loop. This ensures later instructions are inserted after this loop nest. + b->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); + + return Status::OK(); +} + +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h new file mode 100644 index 0000000000000000000000000000000000000000..8458744c6bc0e50a1c1cc8d3e66e29c7d4f74d73 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ + +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace llvm_ir { +// Emits llvm IR to do pairwise comparisons/swaps in the 'dimension_to_sort' +// dimension of 'keys_array'. All other dimensions are kept as-is. This +// implements the inner loop of BitonicSort. If 'launch_dimensions' is nullptr, +// the inner compare loop will not be parallelized. +Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, + const tensorflow::gtl::optional& values_array, + tensorflow::StringPiece name, llvm::Value* xor_mask, + llvm::IRBuilder<>* b, + const gpu::LaunchDimensions* launch_dimensions); +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index 5fc08aab916e377b245b6221108956c06da70767..11ed6ee59f1bf8e7004b8bef7319b37ef41a304c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -31,12 +31,12 @@ namespace llvm_ir { void EmitTupleSelect(const IrArray& select, const IrArray& pred, llvm::Value* on_true, llvm::Value* on_false, - llvm::IRBuilder<>* ir_builder, llvm::Module* module) { + llvm::IRBuilder<>* b, llvm::Module* module) { CHECK(ShapeUtil::IsScalar(pred.GetShape())); llvm::LoadInst* pred_value = - ir_builder->CreateLoad(pred.GetBasePointer(), "load_predicate_value"); - llvm::Value* pred_cond = ir_builder->CreateICmpNE( + b->CreateLoad(pred.GetBasePointer(), "load_predicate_value"); + llvm::Value* pred_cond = b->CreateICmpNE( pred_value, llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, module), 0), "boolean_predicate"); @@ -46,47 +46,42 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred, VLOG(2) << " pred_cond: " << DumpToString(*pred_cond); for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) { - llvm::Value* const element_index[] = {ir_builder->getInt64(0), - ir_builder->getInt64(i)}; + llvm::Value* const element_index[] = {b->getInt64(0), b->getInt64(i)}; llvm::Value* on_true_element_address = - ir_builder->CreateInBoundsGEP(on_true, element_index); - llvm::Value* on_true_element = ir_builder->CreateLoad( + b->CreateInBoundsGEP(on_true, element_index); + llvm::Value* on_true_element = b->CreateLoad( on_true_element_address, "on_true_element_" + llvm::Twine(i)); llvm::Value* on_false_element_address = - ir_builder->CreateInBoundsGEP(on_false, element_index); - llvm::Value* on_false_element = ir_builder->CreateLoad( + b->CreateInBoundsGEP(on_false, element_index); + llvm::Value* on_false_element = b->CreateLoad( on_false_element_address, "on_false_element_" + llvm::Twine(i)); llvm::Value* output_element_address = - ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index); - ir_builder->CreateStore( - ir_builder->CreateSelect(pred_cond, on_true_element, on_false_element, - "select_output_element_" + llvm::Twine(i)), - output_element_address); + b->CreateInBoundsGEP(select.GetBasePointer(), element_index); + b->CreateStore(b->CreateSelect(pred_cond, on_true_element, on_false_element, + "select_output_element_" + llvm::Twine(i)), + output_element_address); } } void EmitTuple(const IrArray& tuple, tensorflow::gtl::ArraySlice operands, - llvm::IRBuilder<>* ir_builder, llvm::Module* module) { + llvm::IRBuilder<>* b, llvm::Module* module) { for (size_t i = 0; i < operands.size(); ++i) { - auto* store = ir_builder->CreateStore( - ir_builder->CreatePointerCast(operands[i], - PrimitiveTypeToIrType(TUPLE, module)), - ir_builder->CreateInBoundsGEP( - tuple.GetBasePointer(), - {ir_builder->getInt64(0), ir_builder->getInt64(i)})); + auto* store = b->CreateStore( + b->CreatePointerCast(operands[i], PrimitiveTypeToIrType(TUPLE, module)), + b->CreateInBoundsGEP(tuple.GetBasePointer(), + {b->getInt64(0), b->getInt64(i)})); tuple.AnnotateLoadStoreInstructionWithMetadata(store); } } llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, int alignment, llvm::Value* operand, - llvm::IRBuilder<>* ir_builder, - llvm::Module* module) { - llvm::Value* element_ptr = ir_builder->CreateInBoundsGEP( - operand, {ir_builder->getInt64(0), ir_builder->getInt64(index)}); - llvm::LoadInst* src_buffer = ir_builder->CreateLoad(element_ptr); + llvm::IRBuilder<>* b, llvm::Module* module) { + llvm::Value* element_ptr = + b->CreateInBoundsGEP(operand, {b->getInt64(0), b->getInt64(index)}); + llvm::LoadInst* src_buffer = b->CreateLoad(element_ptr); // Mark the loaded pointer as dereferenceable if we know its shape. if (!ShapeUtil::IsOpaque(target_shape)) { @@ -98,7 +93,7 @@ llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, llvm::Type* element_type = ShapeToIrType(target_shape, module); llvm::Value* ret_val = - ir_builder->CreateBitCast(src_buffer, element_type->getPointerTo()); + b->CreateBitCast(src_buffer, element_type->getPointerTo()); return ret_val; } diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h index 352d34ebf839c6c2465abade7c3d3eb3b7a34506..cf6bf5d0b14ba71cbed67f9a1dc728c0eef5e393 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -61,13 +61,13 @@ namespace llvm_ir { // output[i] = pred ? tuple_on_true[i] : tuple_on_false[i] void EmitTupleSelect(const IrArray& select, const IrArray& pred, llvm::Value* on_true, llvm::Value* on_false, - llvm::IRBuilder<>* ir_builder, llvm::Module* module); + llvm::IRBuilder<>* b, llvm::Module* module); // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. void EmitTuple(const IrArray& tuple, tensorflow::gtl::ArraySlice operands, - llvm::IRBuilder<>* ir_builder, llvm::Module* module); + llvm::IRBuilder<>* b, llvm::Module* module); // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. A GetTupleElement instruction @@ -75,8 +75,7 @@ void EmitTuple(const IrArray& tuple, // Returns an llvm value representing a pointer to the tuple element buffer. llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, int alignment, llvm::Value* operand, - llvm::IRBuilder<>* ir_builder, - llvm::Module* module); + llvm::IRBuilder<>* b, llvm::Module* module); } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 296d04d4362b12fdc39798a016ca9e8795e02586..5e02096ee501b23a7976a50f13bb7e7f3c5e2d34 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" @@ -154,7 +155,8 @@ StatusOr> LocalService::CompileExecutable( for (int i = 0; i < argument_layouts.size(); ++i) { const Shape& argument_shape = *argument_layouts[i]; - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape)); if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) { tensorflow::gtl::optional metadata = ParameterMetadata(computation, /*parameter_number=*/i); @@ -178,8 +180,8 @@ StatusOr> LocalService::CompileExecutable( } } if (build_options.result_layout() != nullptr) { - TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( - *build_options.result_layout(), program_shape.result())); + TF_RETURN_IF_ERROR(ValidateResultShape(*build_options.result_layout(), + program_shape.result())); } ExecutionOptions execution_options = @@ -189,6 +191,9 @@ StatusOr> LocalService::CompileExecutable( std::unique_ptr module_config, CreateModuleConfig(program_shape, argument_layouts, &execution_options)); + VLOG(3) << "Computation Layout: " + << module_config->entry_computation_layout().ToString(); + TF_ASSIGN_OR_RETURN( se::StreamExecutor * executor, execute_backend_->stream_executor(build_options.device_ordinal())); diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 39d6734c3fc06df6832cf67edddbc7c14c815cd1..8f707ea9046a00a15cac469672a7a992f20bf483 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/executable_build_options.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index f410921b4b5337192bdeae5924631d9c06b7d5a5..d631fb5ee42df6525681a5cd1fe1a8241824121d 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -131,18 +131,23 @@ Status LogicalBufferAnalysis::HandleDomain(HloInstruction*) { return Status::OK(); } -Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) { - // RecvDone doesn't create a new buffer but rather aliases its input (Recv) - // tuple element at {0} to its output. +Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction* recv_done) { + // RecvDone produces a two-element tuple containing the data value (which + // aliases part of its operand) and a token. Only the tuple index table and + // the token are defined by the RecvDone. + NewLogicalBuffer(recv_done, /*index=*/{}); + NewLogicalBuffer(recv_done, /*index=*/{1}); return Status::OK(); } Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) { - // Send creates new buffers for the top-level tuple and the context (tuple - // element at {1}). Tuple element at {0} is an alias of the Send operand, so - // we don't need to create a new Logical Buffer for that. + // Send creates new buffers for the top-level tuple, the context (tuple + // element at {1}), and the token (tuple element at {2}). Tuple element at {0} + // is an alias of the Send operand, so we don't need to create a new Logical + // Buffer for that. NewLogicalBuffer(send, /*index=*/{}); NewLogicalBuffer(send, /*index=*/{1}); + NewLogicalBuffer(send, /*index=*/{2}); return Status::OK(); } @@ -152,10 +157,10 @@ Status LogicalBufferAnalysis::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -Status LogicalBufferAnalysis::HandleSelect(HloInstruction* select) { +Status LogicalBufferAnalysis::HandleTupleSelect(HloInstruction* tuple_select) { // Select allocates a new buffer and then shallow copies the on_true or // on_false buffer into this new buffer. - NewLogicalBuffer(select, /*index=*/{}); + NewLogicalBuffer(tuple_select, /*index=*/{}); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index b5ef3967875a58b35631d5f69c210f5cbcd91250..81f524d84a8091e1fff13dc7c55b401143a02753 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -63,7 +63,7 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; - Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; // A map from the buffer ID to the logical buffer std::vector> logical_buffers_; diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 29f787b86b9cbb6f80d048b46b78bdad8074f488..4166ef5baf9c891968b584a0c498005e9ae87784 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -28,7 +28,7 @@ StatusOr MultiOutputFusion::Run(HloModule* module) { for (auto* computation : module->MakeNonfusionComputations()) { computation_ = computation; - reachability_ = computation_->ComputeReachability(); + RecomputeReachability(); candidates_.clear(); candidates_index_.clear(); all_fusion_candidates_.clear(); @@ -115,42 +115,37 @@ HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1, HloInstruction* fused = instr2; // Make sure that if only one of the instructions is a fusion, or if only one // of the instructions is a multi-output fusion, it's what will be fused into. - // - // An invariant is that no bitcast nodes will show up in the middle of a - // fusion node. This invariant must hold in order for us to lower it. Given - // that, we require that during multi-output fusion, a fusion node ending with - // bitcast to preserve its structure as a nested fusion instead being - // merged and flattened. - if (fused->opcode() == HloOpcode::kFusion && - fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + if (fused->opcode() == HloOpcode::kFusion) { std::swap(remaining, fused); } if (fused->IsMultiOutputFusion()) { std::swap(remaining, fused); } - if (fused->opcode() == HloOpcode::kFusion && - fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + if (fused->opcode() == HloOpcode::kFusion) { remaining->MergeFusionInstructionIntoMultiOutput(fused); } else { - if (remaining->opcode() == HloOpcode::kFusion && - remaining->fused_expression_root()->opcode() == HloOpcode::kBitcast) { - auto parent_computation = remaining->parent(); - // Create a nested fusion node. - auto remaining_nested_fused = - parent_computation->AddInstruction(HloInstruction::CreateFusion( - remaining->shape(), HloInstruction::FusionKind::kLoop, - remaining)); - TF_CHECK_OK(parent_computation->ReplaceInstruction( - remaining, remaining_nested_fused)); - remaining = remaining_nested_fused; - } remaining->FuseInstructionIntoMultiOutput(fused); } - return remaining; } +bool MultiOutputFusion::IsProfitableOperand(HloInstruction* instr) { + // kConstant instruction will not have memory reads, so it won't be a profit + // source. Skip them. + if (instr->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(instr->shape())) { + return false; + } + // We don't target to fuse producer/consumer instructions -- this should + // be taken care of by the instruction_fusion pass. If instr has only + // one user, it will not have sibling instructions. We won't consider it. + if (instr->user_count() < 2) { + return false; + } + return true; +} + void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) { HloInstruction* fusion = instr1; HloInstruction* fused = instr2; @@ -261,6 +256,10 @@ bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1, return true; } +void MultiOutputFusion::RecomputeReachability() { + reachability_ = computation_->ComputeReachability(); +} + void MultiOutputFusion::UpdateReachability( HloInstruction* instr1, HloInstruction* instr2, tensorflow::gtl::ArraySlice instrs_to_update, @@ -329,14 +328,11 @@ bool MultiOutputFusion::Perform() { --fuel_; } } - if (DoProducerConsumerMultiOutputFusion(computation_)) { + if (DoProducerConsumerMultiOutputFusion()) { changed = true; } return changed; } -bool MultiOutputFusion::DoProducerConsumerMultiOutputFusion( - HloComputation* /*computation*/) { - return false; -} +bool MultiOutputFusion::DoProducerConsumerMultiOutputFusion() { return false; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index cfdf83cfe856a7c3b05f51129446cd4e1055a8d6..0019cd725417d81900974b462c3b05075ce3e893 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -72,12 +72,25 @@ class MultiOutputFusion : public HloPassInterface { // multi-output fusion instruction. virtual int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) = 0; - // Whether fusing the instruction can reduce cost. - virtual bool IsProfitableOperand(HloInstruction* instr) = 0; + // Whether fusing the instruction can reduce memory reads. + virtual bool IsProfitableOperand(HloInstruction* instr); // Test if it's legal to fuse instr1 and instr2 into one fusion instruction. virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2); + // Fuse HloInstrctuion instr1 and instr2 and return the fused instruction. + // The other instruction is removed from its parent computation. + virtual HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2); + + // Recompute reachability for the current computation. + void RecomputeReachability(); + + // Returns the reachability map for the current computation. + HloReachabilityMap* reachability() const { return reachability_.get(); } + + // Returns the computation for the pass. + HloComputation* computation() const { return computation_; } + // Update the reachability map after fusing instr1 and instr2. void UpdateReachability( HloInstruction* instr1, HloInstruction* instr2, @@ -89,13 +102,9 @@ class MultiOutputFusion : public HloPassInterface { // // TODO(b/80420762): Perform producer-consumer multi-output fusion in // InstructionFusion instead. - virtual bool DoProducerConsumerMultiOutputFusion(HloComputation* computation); + virtual bool DoProducerConsumerMultiOutputFusion(); private: - // Fuse HloInstrctuion instr1 and instr2 and return the fused instruction. - // The other instruction is removed from its parent computation. - HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2); - // Update the internal data structures after instr1 and instr2 are fused into // one fusion instruction. void Update(HloInstruction* instr1, HloInstruction* instr2); diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index 3a6a7c25f4b727c7112dbcbcb4f3d892679a0011..f6e7578a89551ec2f23d4d8c8b488c3c10e0bf1c 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -67,22 +67,17 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { has_numeric_suffix = true; // Remove numeric suffix from root. root = root.substr(0, separator_index); - // Update count to at least the numeric suffix value to avoid future - // colisions with this name. - generated_names_[root] = std::max(generated_names_[root], numeric_suffix); } } - int64* count = &(generated_names_[root]); - if (*count == 0) { - *count = 1; + + SequentialIdGenerator& id_generator = generated_names_[root]; + numeric_suffix = id_generator.RegisterId(numeric_suffix); + if (numeric_suffix == 0) { return has_numeric_suffix ? tensorflow::strings::StrCat(root, separator_, 0) : root; - } else { - tensorflow::strings::StrAppend(&root, separator_, *count); - // Increment lookup under old 'root' name. - (*count)++; - return root; } + tensorflow::strings::StrAppend(&root, separator_, numeric_suffix); + return root; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index 4139c2700b25e8600182a034a8ac6f4f041c12e6..4423d6106920eaeab830bd9dc08529ff409a5161 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -17,10 +17,11 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_NAME_UNIQUER_H_ #include -#include #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -44,13 +45,40 @@ class NameUniquer { static string GetSanitizedName(const string& name); private: + // Used to track and generate new identifiers for the same instruction name + // root. + class SequentialIdGenerator { + public: + SequentialIdGenerator() = default; + + // Tries to register id as used identifier. If id is not already used, the + // id itself will be returned. Otherwise a new one will be generated, and + // returned. + int64 RegisterId(int64 id) { + if (used_.insert(id).second) { + return id; + } + while (!used_.insert(next_).second) { + ++next_; + } + return next_++; + } + + private: + // The next identifier to be tried. + int64 next_ = 0; + + // Set of all the identifiers which has been used. + tensorflow::gtl::FlatSet used_; + }; + // The string to use to separate the prefix of the name from the uniquing // integer value. string separator_; - // Map from name prefix to the number of names generated using that prefix - // so far. - std::unordered_map generated_names_; + // Map from name prefix to the generator data structure which tracks used + // identifiers and generates new ones. + tensorflow::gtl::FlatMap generated_names_; TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer); }; diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc index 2ec255558c4ed3695ec6c824458cbedac44dc297..3e2592c6ac626143f1421e545a31d9be91e376bc 100644 --- a/tensorflow/compiler/xla/service/name_uniquer_test.cc +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -54,12 +54,13 @@ TEST_F(NameUniquerTest, NumericSuffixes) { EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54")); - EXPECT_EQ("foo.55", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.1", uniquer.GetUniqueName("foo")); EXPECT_EQ("foo.55.1", uniquer.GetUniqueName("foo.55.1")); - EXPECT_EQ("foo.55.2", uniquer.GetUniqueName("foo.55.1")); - EXPECT_EQ("bar.0", uniquer.GetUniqueName("bar.-1000")); - EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.-2000")); - EXPECT_EQ("bar.2", uniquer.GetUniqueName("bar.1")); + EXPECT_EQ("foo.55.0", uniquer.GetUniqueName("foo.55.1")); + EXPECT_EQ("bar.1000", uniquer.GetUniqueName("bar.1000")); + EXPECT_EQ("bar.2000", uniquer.GetUniqueName("bar.2000")); + EXPECT_EQ("bar.-2000", uniquer.GetUniqueName("bar.-2000")); + EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.1")); } TEST_F(NameUniquerTest, PrefixHasSuffix) { @@ -77,12 +78,12 @@ TEST_F(NameUniquerTest, Sanitize) { EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54")); EXPECT_EQ("foo_54", uniquer.GetUniqueName("foo_54")); EXPECT_EQ("foo_54.1", uniquer.GetUniqueName("foo_54.1")); - EXPECT_EQ("foo_55", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo_2", uniquer.GetUniqueName("foo")); // Invalid characters will be replaced with '_'. - EXPECT_EQ("bar_0", uniquer.GetUniqueName("bar<-1000")); - EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar<-2000")); - EXPECT_EQ("bar_2", uniquer.GetUniqueName("bar_1")); + EXPECT_EQ("bar_1000", uniquer.GetUniqueName("bar<1000")); + EXPECT_EQ("bar_2000", uniquer.GetUniqueName("bar<2000")); + EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar_1")); // Separator is only recognized in the middle of the prefix. EXPECT_EQ("_10", uniquer.GetUniqueName( @@ -93,5 +94,15 @@ TEST_F(NameUniquerTest, Sanitize) { EXPECT_EQ("foobar__1", uniquer.GetUniqueName("foobar_")); } +TEST_F(NameUniquerTest, KeepNamesInRandomOrder) { + NameUniquer uniquer("."); + + EXPECT_EQ("foo.11", uniquer.GetUniqueName("foo.11")); + EXPECT_EQ("foo.10", uniquer.GetUniqueName("foo.10")); + EXPECT_EQ("foo.1", uniquer.GetUniqueName("foo.1")); + EXPECT_EQ("foo.12", uniquer.GetUniqueName("foo.12")); + EXPECT_EQ("foo.3", uniquer.GetUniqueName("foo.3")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 2515222cf2db3d9699c85c13f4fe72b3488fa217..ac6ea4c72f61a47726b3ae7dd000837d3fba1b93 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -86,8 +86,8 @@ namespace xla { // are provided below. // // Example nullary instruction: -// Recv() == Op().WithOpcode(HloOpcode::kRecv) -// Recv(&a) == Op(&a).WithOpcode(HloOpcode::kRecv) +// Param() == Op().WithOpcode(HloOpcode::kParam) +// Param(&a) == Op(&a).WithOpcode(HloOpcode::kParam) // // Example unary instruction: // Abs() == Op().WithOpcode(HloOpcode::kAbs) @@ -726,6 +726,32 @@ class HloInstructionPatternFusionKindImpl { ::xla::HloInstruction::FusionKind kind_; }; +// An HloInstructionPattern implementation that matches only if the instruction +// is a kGetTupleElement with a particular tuple index. +template +class HloInstructionPatternTupleIndexImpl { + public: + explicit constexpr HloInstructionPatternTupleIndexImpl( + const Previous& previous, int64 tuple_index) + : previous_(previous), tuple_index_(tuple_index) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && + inst->opcode() == HloOpcode::kGetTupleElement && + inst->tuple_index() == tuple_index_; + } + + bool Match(::xla::HloInstruction* inst) const { + return previous_.Match(inst) && + inst->opcode() == HloOpcode::kGetTupleElement && + inst->tuple_index() == tuple_index_; + } + + private: + Previous previous_; + int64 tuple_index_; +}; + // A pattern that matches HloInstructions. template class HloInstructionPattern { @@ -841,6 +867,17 @@ class HloInstructionPattern { HloInstructionPatternFusionKindImpl(impl_, kind), matched_inst_); } + // Modifies the pattern to match only if the instruction is a + // get-tuple-element with the given tuple index. + constexpr HloInstructionPattern> + WithTupleIndex(int64 tuple_index) const { + return HloInstructionPattern>( + HloInstructionPatternTupleIndexImpl(impl_, tuple_index), + matched_inst_); + } + private: Impl impl_; HloInstructionType** matched_inst_; @@ -880,9 +917,7 @@ Op(::xla::HloInstruction** matched_inst) { return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \ } XLA_NULLOP_PATTERN(Constant) -XLA_NULLOP_PATTERN(Infeed) XLA_NULLOP_PATTERN(Parameter) -XLA_NULLOP_PATTERN(Recv) #undef XLA_NULLOP_PATTERN // Helpers for unary instructions. @@ -919,18 +954,21 @@ XLA_UNOP_PATTERN(Cos) XLA_UNOP_PATTERN(Exp) XLA_UNOP_PATTERN(Fft) XLA_UNOP_PATTERN(Floor) +XLA_UNOP_PATTERN(GetTupleElement) XLA_UNOP_PATTERN(Imag) +XLA_UNOP_PATTERN(Infeed) XLA_UNOP_PATTERN(IsFinite) XLA_UNOP_PATTERN(Log) XLA_UNOP_PATTERN(Not) XLA_UNOP_PATTERN(Negate) -XLA_UNOP_PATTERN(Outfeed) XLA_UNOP_PATTERN(Real) +XLA_UNOP_PATTERN(Recv) +XLA_UNOP_PATTERN(RecvDone) XLA_UNOP_PATTERN(Reduce) XLA_UNOP_PATTERN(ReducePrecision) XLA_UNOP_PATTERN(Reshape) XLA_UNOP_PATTERN(Reverse) -XLA_UNOP_PATTERN(Send) +XLA_UNOP_PATTERN(SendDone) XLA_UNOP_PATTERN(Sign) XLA_UNOP_PATTERN(Sin) XLA_UNOP_PATTERN(Sort) @@ -981,8 +1019,10 @@ XLA_BINOP_PATTERN(Maximum) XLA_BINOP_PATTERN(Minimum) XLA_BINOP_PATTERN(Multiply) XLA_BINOP_PATTERN(Ne) +XLA_BINOP_PATTERN(Outfeed) XLA_BINOP_PATTERN(Power) XLA_BINOP_PATTERN(Remainder) +XLA_BINOP_PATTERN(Send) XLA_BINOP_PATTERN(Subtract) XLA_BINOP_PATTERN(And) XLA_BINOP_PATTERN(Or) @@ -1040,6 +1080,32 @@ inline auto NonConstant(HloInstructionType** matched_inst) return Op(matched_inst).IsNonConstant(); } +// Add overloads for GetTupleElement which take a int64 specifying which tuple +// element is selected. +template +inline auto GetTupleElement(Arg&& arg, int64 tuple_index) + -> decltype(Op().WithOpcode(HloOpcode::kGetTupleElement) + .WithOperand(0, std::forward(arg)) + .WithTupleIndex(tuple_index)) { + return Op() + .WithOpcode(HloOpcode::kGetTupleElement) + .WithOperand(0, std::forward(arg)) + .WithTupleIndex(tuple_index); +} + +template +inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg, + int64 tuple_index) + -> decltype(Op(matched_inst) + .WithOpcode(HloOpcode::kGetTupleElement) + .WithOperand(0, std::forward(arg)) + .WithTupleIndex(tuple_index)) { + return Op(matched_inst) + .WithOpcode(HloOpcode::kGetTupleElement) + .WithOperand(0, std::forward(arg)) + .WithTupleIndex(tuple_index); +} + } // namespace match } // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index fef3c132b0f3467a01b02f2be88b419459179277..a530581c34bf1d699eae3c53203c197f7943cc53 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -193,5 +193,23 @@ TEST(PatternMatcherTest, FusionKind) { HloInstruction::FusionKind::kLoop))); } +TEST(PatternMatcherTest, GetTupleElement) { + constexpr char kModuleStr[] = R"( + HloModule test_module + + ENTRY while.v11 { + p0 = (f32[], f32[], f32[]) parameter(0) + ROOT gte = f32[] get-tuple-element(p0), index=1 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + + auto* root = hlo_module->entry_computation()->root_instruction(); + EXPECT_FALSE(Match(root, match::Op().WithTupleIndex(0))); + EXPECT_TRUE(Match(root, match::Op().WithTupleIndex(1))); + EXPECT_FALSE(Match(root, match::Op().WithTupleIndex(2))); + EXPECT_FALSE(Match(root, match::GetTupleElement(match::Op(), 0))); + EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 7c63c0acc7764d558b2151190f0fa79fac355cbf..39fe3c7835d1c74c0f1e5bc0ebf5916ec734c24a 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -75,19 +75,6 @@ PlatformUtil::GetSupportedPlatforms() { auto* platform = platform_pair.second; auto compiler_status = Compiler::GetForPlatform(platform); if (compiler_status.ok()) { - if (platform->VisibleDeviceCount() > 0) { - LOG(INFO) << "platform " << platform->Name() << " present with " - << platform->VisibleDeviceCount() << " visible devices"; - } else { - LOG(WARNING) << "platform " << platform->Name() << " present but no " - << "visible devices found"; - } - // Note: currently we call zero device platforms "supported" on the basis - // that, if the platform support was linked in, it was probably intended - // to be used for execution, and this way we can flag an error. - // - // TODO(b/33730287) If we want an alternative version of this behavior we - // could add an --xla_fallback_to_host flag. platforms.push_back(platform); } else { LOG(INFO) << "platform " << platform->Name() << " present but no " diff --git a/tensorflow/compiler/xla/service/pool.h b/tensorflow/compiler/xla/service/pool.h deleted file mode 100644 index 8e710ebb6dc17e0e204ba6ab3c6c159627cd9d3b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/pool.h +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_POOL_H_ -#define TENSORFLOW_COMPILER_XLA_POOL_H_ - -#include -#include - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/core/platform/mutex.h" - -namespace xla { - -// Pool of values, which are created as needed and destroyed when the `Pool` is -// destroyed -template -class Pool { - public: - struct Deleter { - void operator()(T* ptr) { pool->Deallocate(ptr); } - - Pool* pool; - }; - - // A pointer to a taken element of a `Pool` which returns it to the pool on - // destruction - using SmartPtr = std::unique_ptr; - - // Constructs a `Pool` with given factory function, which need not be - // thread-safe. - explicit Pool(std::function()> factory) - : factory_(factory) {} - - explicit Pool() : Pool([]() { return MakeUnique(); }) {} - - // Returns a pointer to a value in the pool, creating a new value if none is - // free. The returned smart pointer returns the element to the pool on - // destruction. - // - // This method is thread-safe. - SmartPtr Allocate() { - tensorflow::mutex_lock lock(mu_); - T* ptr; - if (!xs_.empty()) { - ptr = std::move(xs_.back()).release(); - xs_.pop_back(); - } else { - ptr = factory_().release(); - } - Deleter del = {this}; - return std::unique_ptr(ptr, del); - } - - private: - // Puts a pointer to a value back into the pool, leaving it free for future - // use. - // - // This method is thread-safe. - void Deallocate(T* ptr) { - tensorflow::mutex_lock lock(mu_); - xs_.push_back(std::unique_ptr(ptr)); - } - - const std::function()> factory_ GUARDED_BY(mu_); - std::vector> xs_ GUARDED_BY(mu_); - tensorflow::mutex mu_; -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_POOL_H_ diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 49ec38eb62c7b51c7a2d301d882cef032b288036..ca86c5d13e98a98c62d0c9e8e32e28fe99e0fa1f 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -38,7 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index 13e2d3258e3b92f52320201c382594962c0e3b2b..ccb9fb3e3af5e308accc924d3501213841d7d6c7 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -76,9 +76,13 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); - auto rng0 = builder.AddInstruction( - HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {1, 8, 1, 7, 1}), - RandomDistribution::RNG_UNIFORM, {})); + auto rng0 = builder.AddInstruction(HloInstruction::CreateRng( + ShapeUtil::MakeShape(F32, {1, 8, 1, 7, 1}), + RandomDistribution::RNG_UNIFORM, + {builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(1.0f)))})); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, rng0)); @@ -175,8 +179,9 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{true, true, false}, {false, false, true}}))); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2( + {{true, true, false}, {false, false, true}}))); auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param1")); @@ -255,12 +260,12 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {3, 2}); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const0)); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1)); @@ -309,7 +314,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param0")); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); builder.AddInstruction(HloInstruction::CreateBinary( @@ -348,7 +353,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) { auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1, 3}), "param0")); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({9, 8, 7}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({9, 8, 7}))); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); auto reshape1 = diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc new file mode 100644 index 0000000000000000000000000000000000000000..45ca731153bf4312b1c78b3c74224b1ec7ed8436 --- /dev/null +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -0,0 +1,350 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/scatter_expander.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +using tensorflow::gtl::ArraySlice; + +// Transposes the given scatter_indices such that the index_vector_dim becomes +// the most-minor dimension. +static StatusOr TransposeIndexVectorDimToLast( + HloInstruction* scatter_indices, int64 index_vector_dim) { + const Shape& scatter_indices_shape = scatter_indices->shape(); + + if (scatter_indices_shape.dimensions_size() == index_vector_dim) { + return scatter_indices; + } + + if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) { + return scatter_indices; + } + + std::vector permutation; + permutation.reserve(scatter_indices_shape.dimensions_size()); + for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { + if (i != index_vector_dim) { + permutation.push_back(i); + } + } + permutation.push_back(index_vector_dim); + return MakeTransposeHlo(scatter_indices, permutation); +} + +// Canonicalizes the scatter_indices tensor in order to keep them uniform while +// performing the scatter operation. +static StatusOr CanonicalizeScatterIndices( + HloInstruction* scatter_indices, int64 index_vector_dim) { + // Transpose the non-index-vector dimensions to the front. + TF_ASSIGN_OR_RETURN( + HloInstruction * transposed_scatter_indices, + TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim)); + bool indices_are_scalar = + index_vector_dim == scatter_indices->shape().dimensions_size(); + + // The number of dimensions in scatter_indices that are index dimensions. + const int64 index_dims_in_scatter_indices = indices_are_scalar ? 0 : 1; + + // If there is only one index (i.e. scatter_indices has rank 1 and this + // scatter is really just a dynamic update slice) add a leading degenerate + // dimension for uniformity. Otherwise create a "collapsed" leading dimension + // that subsumes all of the non-index-vector dimensions. + const Shape& shape = transposed_scatter_indices->shape(); + if (shape.dimensions_size() == index_dims_in_scatter_indices) { + return PrependDegenerateDims(transposed_scatter_indices, 1); + } else { + // Collapse all but the dimensions (0 or 1) in scatter_indices containing + // the index vectors. + return CollapseFirstNDims( + transposed_scatter_indices, + shape.dimensions_size() - index_dims_in_scatter_indices); + } +} + +// Permutes the `updates` tensor such that all the scatter dims appear in the +// major dimensions and all the window dimensions appear in the minor +// dimensions. +static StatusOr PermuteScatterAndWindowDims( + HloInstruction* updates, ArraySlice update_window_dims) { + std::vector permutation; + const int64 updates_rank = ShapeUtil::Rank(updates->shape()); + permutation.reserve(updates_rank); + + for (int64 i = 0; i < updates_rank; ++i) { + bool is_scatter_dim = !c_binary_search(update_window_dims, i); + if (is_scatter_dim) { + permutation.push_back(i); + } + } + for (auto window_dim : update_window_dims) { + permutation.push_back(window_dim); + } + + return MakeTransposeHlo(updates, permutation); +} + +// Expands or contracts the scatter indices in the updates tensor. +static StatusOr AdjustScatterDims( + const Shape& scatter_indices_shape, HloInstruction* updates, + int64 index_vector_dim) { + int64 num_scatter_dims = scatter_indices_shape.dimensions_size(); + if (index_vector_dim < scatter_indices_shape.dimensions_size()) { + --num_scatter_dims; + } + if (num_scatter_dims == 0) { + // If there are no scatter dims, this must be a dynamic-update-slice kind of + // scatter. In this case, we prepend a degenerate dimension to work + // uniformly in the while loop. + return PrependDegenerateDims(updates, 1); + } + return CollapseFirstNDims(updates, num_scatter_dims); +} + +// Expands an index vector from the scatter_indices tensor into a vector that +// can be used to dynamic-update-slice to perform the scatter update. +static StatusOr ExpandIndexVectorIntoOperandSpace( + HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers, + int64 operand_rank) { + HloComputation* computation = index_vector->parent(); + const Shape& index_shape = index_vector->shape(); + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); + + // We extract out individual components from the smaller index and concatenate + // them (interspersing zeros as needed) into the larger index. + std::vector expanded_index_components; + + for (int i = 0; i < operand_rank; i++) { + int64 index_vector_dim_index = + FindIndex(dim_numbers.scatter_dims_to_operand_dims(), i); + if (index_vector_dim_index != + dim_numbers.scatter_dims_to_operand_dims_size()) { + TF_ASSIGN_OR_RETURN( + HloInstruction * component_to_concat, + MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index}, + /*limit_indices=*/{index_vector_dim_index + 1}, + /*strides=*/{1})); + expanded_index_components.push_back(component_to_concat); + } else { + expanded_index_components.push_back(zero); + } + } + + return MakeConcatHlo(expanded_index_components, /*dimension=*/0); +} + +// Body of the while loop that performs the scatter operation using other HLOs. +static StatusOr> ScatterLoopBody( + HloInstruction* scatter, HloInstruction* induction_var, + const std::vector& loop_state) { + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + CHECK_EQ(loop_state.size(), 3); + HloInstruction* operand = loop_state[0]; + HloInstruction* scatter_indices = loop_state[1]; + HloInstruction* updates = loop_state[2]; + + bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1; + CHECK_EQ(has_scalar_indices, + dim_numbers.index_vector_dim() == + scatter->operand(1)->shape().dimensions_size()); + + // Build a vector form of the induction variable of the while loop. + TF_ASSIGN_OR_RETURN( + HloInstruction * induction_var_as_vector, + MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, + /*result_shape_bounds=*/{1})); + + // Pick the index to scatter from scatter_indices based on the induction_var + // and transform that to an index into the `operand` space. + HloInstruction* index_vector; + if (has_scalar_indices) { + TF_ASSIGN_OR_RETURN( + index_vector, + MakeDynamicSliceHlo(scatter_indices, induction_var_as_vector, {1})); + } else { + TF_ASSIGN_OR_RETURN( + HloInstruction * index_into_scatter_indices, + PadVectorWithZeros(induction_var_as_vector, + /*zeros_to_prepend=*/0, /*zeros_to_append=*/1)); + int index_vector_size = scatter_indices->shape().dimensions(1); + TF_ASSIGN_OR_RETURN( + HloInstruction * index_vector_2d, + MakeDynamicSliceHlo(scatter_indices, index_into_scatter_indices, + {1, index_vector_size})); + TF_ASSIGN_OR_RETURN(index_vector, + ElideDegenerateDims(index_vector_2d, {0})); + } + TF_ASSIGN_OR_RETURN( + HloInstruction * scatter_slice_start, + ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers, + operand->shape().dimensions_size())); + + // Extract the slice to be used to update from `updates` tensor for the + // induction_var corresponding to this iteration of the while loop. + TF_ASSIGN_OR_RETURN( + HloInstruction * index_into_updates, + PadVectorWithZeros( + induction_var_as_vector, /*zeros_to_prepend=*/0, + /*zeros_to_append=*/updates->shape().dimensions_size() - 1)); + std::vector update_slice_bounds(updates->shape().dimensions().begin(), + updates->shape().dimensions().end()); + update_slice_bounds[0] = 1; + TF_ASSIGN_OR_RETURN( + HloInstruction * update_slice, + MakeDynamicSliceHlo(updates, index_into_updates, update_slice_bounds)); + TF_ASSIGN_OR_RETURN(HloInstruction * update_slice_for_scatter, + ElideDegenerateDims(update_slice, {0})); + TF_ASSIGN_OR_RETURN( + HloInstruction * update_slice_with_dims_inserted, + InsertDegenerateDims(update_slice_for_scatter, + AsInt64Slice(dim_numbers.inserted_window_dims()))); + + // Extact the slice to update from `operand` tensor. + const Shape& update_slice_shape = update_slice_with_dims_inserted->shape(); + TF_ASSIGN_OR_RETURN( + HloInstruction * operand_slice_to_update, + MakeDynamicSliceHlo(operand, scatter_slice_start, + AsInt64Slice(update_slice_shape.dimensions()))); + + // Compute the new value for the slice to be updated in `operand` tensor by + // combining the existing value and the update value using the update + // computation. + TF_ASSIGN_OR_RETURN( + HloInstruction * updated_operand_slice, + MakeMapHlo({operand_slice_to_update, update_slice_with_dims_inserted}, + scatter->to_apply())); + + // Write the updated value of the slice into `operand` tensor. + TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand, + MakeDynamicUpdateSliceHlo(operand, updated_operand_slice, + scatter_slice_start)); + + return StatusOr>{ + {updated_operand, scatter_indices, updates}}; +} + +// High Level Algorithm. +// +// 1. Canonicalize the scatter_indices tensor such that it has rank 2, where +// each row is an index into the operand. +// 2. Canonicalize the updates tensor such that is has rank `num_window_dims+1` +// and the scatter dim is the most-major dimension. +// 3. Iterate over the set of indices in the canonicalized scatter_indices +// tensor using a while loop, updating the operand for each such index. Each +// iteration of this while loop performs the following: +// a. Pick the index from scatter_indices for this iteration. +// b. Transfrom this index into an index into the operand space. +// c. Extract the slice to be used to update from the updates tensor. +// d. Extract the slice to update from the operand tensor. +// e. Compute the new value for the slice to update by combining the slices +// from c. and d. using the update_computation of scatter. +// f. Write the updated value of the slice into the operand tensor. + +StatusOr ScatterExpander::ExpandScatter( + HloInstruction* scatter) { + HloInstruction* operand = scatter->mutable_operand(0); + HloInstruction* scatter_indices = scatter->mutable_operand(1); + HloInstruction* updates = scatter->mutable_operand(2); + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + + // If the updates tensor is empty, there is no need to update the operand. We + // can return the operand as is. + if (ShapeUtil::IsZeroElementArray(updates->shape())) { + return operand; + } + + // Compute the trip count for the while loop to be used for scatter. This + // should be the number of indices we should scatter into the operand. + const Shape& scatter_indices_shape = scatter_indices->shape(); + int64 scatter_loop_trip_count = 1; + for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { + if (i != dim_numbers.index_vector_dim()) { + scatter_loop_trip_count *= scatter_indices_shape.dimensions(i); + } + } + if (!IsInt32(scatter_loop_trip_count)) { + return Unimplemented( + "Scatter operations with more than 2147483647 scatter indices are not " + "supported. This error occurred for %s.", + scatter->ToString().c_str()); + } + + // Canonicalize the scatter_indices, after which the size of its most-major + // dimension must be same as the while loop trip count. + TF_ASSIGN_OR_RETURN(HloInstruction * canonical_scatter_indices, + CanonicalizeScatterIndices( + scatter_indices, dim_numbers.index_vector_dim())); + CHECK_EQ(scatter_loop_trip_count, + canonical_scatter_indices->shape().dimensions(0)); + + // Canonicalize the updates, after which the size of its most-major dimension + // must be same as the while loop trip count. + TF_ASSIGN_OR_RETURN( + HloInstruction * canonical_updates, + PermuteScatterAndWindowDims( + updates, AsInt64Slice(dim_numbers.update_window_dims()))); + TF_ASSIGN_OR_RETURN( + HloInstruction * adjusted_canonical_updates, + AdjustScatterDims(scatter_indices->shape(), canonical_updates, + dim_numbers.index_vector_dim())); + CHECK_EQ(scatter_loop_trip_count, + adjusted_canonical_updates->shape().dimensions(0)); + + // The while loop that implements the scatter operation. + StatusOr> scatter_loop_result_status = + WhileUtil::MakeCountedLoop( + scatter->parent(), scatter_loop_trip_count, + {operand, canonical_scatter_indices, adjusted_canonical_updates}, + [&](HloInstruction* induction_var, + const std::vector& loop_state) { + return ScatterLoopBody(scatter, induction_var, loop_state); + }); + TF_ASSIGN_OR_RETURN(std::vector scatter_loop_result, + scatter_loop_result_status); + return scatter_loop_result.front(); +} + +StatusOr ScatterExpander::Run(HloModule* module) { + std::vector scatter_instrs; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + for (HloInstruction* instr : computation->instructions()) { + if (instr->opcode() == HloOpcode::kScatter) { + scatter_instrs.push_back(instr); + } + } + } + + for (auto instr : scatter_instrs) { + TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(instr)); + TF_RETURN_IF_ERROR( + instr->parent()->ReplaceInstruction(instr, expanded_root)); + } + + return !scatter_instrs.empty(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h new file mode 100644 index 0000000000000000000000000000000000000000..8f735e877d270c10b494e1cd974904c4e2d960c9 --- /dev/null +++ b/tensorflow/compiler/xla/service/scatter_expander.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +class ScatterExpander : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { return "scatter_expander"; } + StatusOr Run(HloModule* module) override; + + private: + StatusOr ExpandScatter(HloInstruction* scatter); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index d01c35b99231310692f85d0f9fbf4f2c3709d44c..1dbf540d13d1fb6f6a4052caeff922cc0290f1b8 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/source_map_util.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -52,10 +53,10 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/ptr_util.h" using ::tensorflow::strings::Printf; using ::tensorflow::strings::StrCat; -using ::xla::source_map_util::InvalidParameterArgument; namespace xla { @@ -64,25 +65,25 @@ namespace { // Records the arguments used to invoke a computation in an HloSnapshot proto. Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, - se::StreamExecutor* executor, TransferManager* transfer_manager, + se::Stream* stream, TransferManager* transfer_manager, HloSnapshot* module) { module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN( std::unique_ptr literal, - transfer_manager->TransferLiteralFromDevice(executor, *argument)); + transfer_manager->TransferLiteralFromDevice(stream, *argument)); *module->add_arguments() = literal->ToProto(); } return Status::OK(); } // Records the result of a computation in a HloSnapshot proto. -Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor, +Status RecordResult(const ShapedBuffer& result, se::Stream* stream, TransferManager* transfer_manager, HloSnapshot* module) { module->clear_result(); TF_ASSIGN_OR_RETURN( std::unique_ptr literal, - transfer_manager->TransferLiteralFromDevice(executor, result)); + transfer_manager->TransferLiteralFromDevice(stream, result)); *module->mutable_result() = literal->ToProto(); return Status::OK(); } @@ -169,7 +170,8 @@ Service::Service(const ServiceOptions& options, Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg, CreateChannelHandleResponse* result) { - *result->mutable_channel() = channel_tracker_.NewChannel(); + TF_ASSIGN_OR_RETURN(*result->mutable_channel(), + channel_tracker_.NewChannel(arg->channel_type())); return Status::OK(); } @@ -191,21 +193,17 @@ Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, return Status::OK(); } -Status Service::ValidateResultShapeWithLayout(const Shape& shape_with_layout, - const Shape& result_shape) const { - if (!ShapeUtil::Compatible(shape_with_layout, result_shape)) { +Status Service::ValidateResultShape(const Shape& client_shape, + const Shape& result_shape) const { + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape)); + if (!ShapeUtil::Compatible(client_shape, result_shape)) { return InvalidArgument( "Shape used to set computation result layout %s is not compatible " "with result shape %s", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), + ShapeUtil::HumanStringWithLayout(client_shape).c_str(), ShapeUtil::HumanString(result_shape).c_str()); } - if (!LayoutUtil::HasLayout(shape_with_layout)) { - return InvalidArgument( - "Shape used to set computation result layout %s does not have layout", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); - } - return ShapeUtil::ValidateShape(shape_with_layout); + return Status::OK(); } StatusOr>> @@ -248,10 +246,8 @@ StatusOr> Service::CreateModuleConfig( tensorflow::gtl::ArraySlice argument_shapes, const ExecutionOptions* execution_options) { auto config = MakeUnique(program_shape); - ComputationLayout* host_computation_layout = - config->mutable_host_entry_computation_layout(); - ComputationLayout* device_computation_layout = - config->mutable_device_entry_computation_layout(); + ComputationLayout* computation_layout = + config->mutable_entry_computation_layout(); if (program_shape.parameters_size() != argument_shapes.size()) { return InvalidArgument("computation takes %d parameters, but %zu given", program_shape.parameters_size(), @@ -268,32 +264,22 @@ StatusOr> Service::CreateModuleConfig( i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), ShapeUtil::HumanString(*argument_shapes[i]).c_str()); } - TF_RETURN_IF_ERROR(host_computation_layout->mutable_parameter_layout(i) - ->CopyLayoutFromShape(*argument_shapes[i])); - TF_RETURN_IF_ERROR(device_computation_layout->mutable_parameter_layout(i) - ->CopyLayoutFromShape(*argument_shapes[i])); + TF_RETURN_IF_ERROR( + computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( + *argument_shapes[i])); } if (execution_options != nullptr && execution_options->has_shape_with_output_layout()) { const auto& shape_with_output_layout = execution_options->shape_with_output_layout(); - TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(shape_with_output_layout, - program_shape.result())); TF_RETURN_IF_ERROR( - host_computation_layout->mutable_result_layout()->CopyLayoutFromShape( - shape_with_output_layout)); + ValidateResultShape(shape_with_output_layout, program_shape.result())); TF_RETURN_IF_ERROR( - device_computation_layout->mutable_result_layout()->CopyLayoutFromShape( + computation_layout->mutable_result_layout()->CopyLayoutFromShape( shape_with_output_layout)); } else { // If the result layout is not set, then choose the default. - // TODO(b/29118294): Allow the compiler to choose a better layout in this - // case. - // TODO(b/78356948): We are forcing the default layout here. We should fix - // clients which expect a default layout, to be explicit about it, by - // passing the proper ExecutionOptions with shape_with_output_layout set. - host_computation_layout->mutable_result_layout()->SetToDefaultLayout(); - device_computation_layout->mutable_result_layout()->SetToDefaultLayout(); + computation_layout->mutable_result_layout()->SetToDefaultLayout(); } config->set_replica_count(options_.number_of_replicas()); @@ -348,8 +334,8 @@ StatusOr>> Service::BuildExecutables( module_protos[i]->entry_computation_name().c_str()); TF_RETURN_IF_ERROR( Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); - hlo_snapshots.push_back(std::move(hlo_snapshot)); } + hlo_snapshots.push_back(std::move(hlo_snapshot)); } VLOG(1) << "Computations:"; @@ -381,22 +367,6 @@ StatusOr>> Service::BuildExecutables( return std::move(executables); } -Status Service::ValidateEntryComputationLayout(HloModule* module) { - const ComputationLayout& on_device = - module->device_entry_computation_layout(); - for (int64 i = 0; i < on_device.parameter_count(); ++i) { - TF_RET_CHECK(ShapeUtil::Equal( - on_device.parameter_shape(i), - execute_backend_->transfer_manager()->HostShapeToDeviceShape( - module->host_entry_computation_layout().parameter_shape(i)))); - } - TF_RET_CHECK(ShapeUtil::Equal( - module->device_entry_computation_layout().result_shape(), - execute_backend_->transfer_manager()->HostShapeToDeviceShape( - module->host_entry_computation_layout().result_shape()))); - return Status::OK(); -} - StatusOr> Service::ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, @@ -407,7 +377,7 @@ Service::ExecuteParallelAndRegisterResult( ExecutionProfile* profile) { // Streams where the computation are launched, so we can wait on the streams // to complete. - std::vector::SmartPtr> streams; + std::vector streams; std::vector> timers; // Global data handles for the computation results, one for each computation. @@ -434,12 +404,12 @@ Service::ExecuteParallelAndRegisterResult( CHECK_EQ(replicas.size(), arguments[i].size()); std::vector result_buffers; for (int64 replica = 0; replica < replicas.size(); ++replica) { - TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, + TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream, backend->BorrowStream(replicas[replica])); streams.push_back(std::move(stream)); if (replica == 0 && profile != nullptr) { - timers.emplace_back(new se::Timer(streams.back()->parent())); + timers.push_back(MakeUnique(streams.back()->parent())); streams.back() ->InitTimer(timers.back().get()) .ThenStartTimer(timers.back().get()); @@ -471,7 +441,7 @@ Service::ExecuteParallelAndRegisterResult( streams.back()->ThenStopTimer(timers.back().get()); } - result_buffers.emplace_back(std::move(result)); + result_buffers.push_back(std::move(result)); } TF_ASSIGN_OR_RETURN(GlobalDataHandle handle, allocation_tracker_.RegisterReplicatedBuffers( @@ -498,7 +468,7 @@ Service::ExecuteParallelAndRegisterResult( HloExecutionProfile hlo_profile(&executable->hlo_profile_printer_data(), &executable->hlo_profile_index_map()); TF_RETURN_IF_ERROR( - executable->PopulateExecutionProfile(&hlo_profile, stream->parent())); + executable->PopulateExecutionProfile(&hlo_profile, stream)); XLA_LOG_LINES( tensorflow::INFO, hlo_profile.ToString(streams[0]->parent()->GetDeviceDescription())); @@ -546,13 +516,13 @@ StatusOr Service::ExecuteAndRegisterResult( arguments, Backend* backend, const string& result_tag, ExecutionProfile* profile) { // Set up streams. - std::vector::SmartPtr> streams; + std::vector streams; TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, SingleComputationDeviceHandle())); TF_RET_CHECK(!replicas.empty()); for (se::StreamExecutor* executor : replicas) { - TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, + TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream, backend->BorrowStream(executor)); streams.push_back(std::move(stream)); } @@ -564,7 +534,7 @@ StatusOr Service::ExecuteAndRegisterResult( // Set up run options. std::vector run_options; - for (const Pool::SmartPtr& stream : streams) { + for (const StreamPool::Ptr& stream : streams) { ExecutableRunOptions options; options.set_stream(stream.get()); options.set_device_ordinal(stream->parent()->device_ordinal()); @@ -589,7 +559,7 @@ StatusOr Service::ExecuteAndRegisterResult( std::vector> replicated_arguments; for (const auto& arg : arguments) { - replicated_arguments.emplace_back(arg); + replicated_arguments.push_back(arg); } TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( @@ -692,7 +662,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, request.execution_options())); VLOG(3) << "ExecuteGraphParallel created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); + << module_config->entry_computation_layout().ToString(); // Adds to the vectors to build and execute the computations after the loop. all_arguments.push_back(replicated_arguments); @@ -721,6 +691,17 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, executable_ptrs.push_back(executable.get()); } + for (int i = 0; i < executable_ptrs.size(); i++) { + if (executable_ptrs[i]->dumping_snapshot()) { + TF_ASSIGN_OR_RETURN(auto stream, + execute_backend_->BorrowStream( + all_executors[i][0]->device_ordinal())); + TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), stream.get(), + execute_backend_->transfer_manager(), + executable_ptrs[i]->hlo_snapshot())); + } + } + // Execute the generated executables in parallel and return the device // handles for each computation's output. ExecutionProfile profile; @@ -736,6 +717,20 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, *result->add_responses() = response; } + for (int i = 0; i < executable_ptrs.size(); i++) { + if (executable_ptrs[i]->dumping_snapshot()) { + TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, + allocation_tracker_.ResolveForReplica(outputs[i], 0)); + TF_ASSIGN_OR_RETURN(auto stream, + execute_backend_->BorrowStream(all_executors[i][0])); + TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), + execute_backend_->transfer_manager(), + executable_ptrs[i]->hlo_snapshot())); + // Dump out the ith snapshot. + TF_RETURN_IF_ERROR(executable_ptrs[i]->DumpHloSnapshot()); + } + } + VLOG(1) << "successfully completed 'execute-graph-parallel' request"; return Status::OK(); } @@ -828,13 +823,15 @@ StatusOr> Service::BuildExecutable( TF_ASSIGN_OR_RETURN( module, backend->compiler()->RunHloPasses(std::move(module), executor, device_allocator)); - // Check that on-host and on-device shapes are consistent. - TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get())); TF_ASSIGN_OR_RETURN(std::unique_ptr executable, backend->compiler()->RunBackend( std::move(module), executor, device_allocator)); + if (!execution_directory_path.empty()) { + executable->set_hlo_snapshot(std::move(hlo_snapshot)); + } + return std::move(executable); } @@ -872,12 +869,14 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, execute_backend_->default_stream_executor(), /*device_allocator=*/nullptr)); + TF_ASSIGN_OR_RETURN(auto stream, + execute_backend_->BorrowStream( + execute_backend_->default_stream_executor())); if (executable->dumping_snapshot()) { executable->hlo_snapshot()->set_execution_platform( execute_backend_->platform()->Name()); TF_RETURN_IF_ERROR(RecordArguments( - replicated_arguments.front(), - execute_backend_->default_stream_executor(), + replicated_arguments.front(), stream.get(), execute_backend_->transfer_manager(), executable->hlo_snapshot())); } @@ -891,9 +890,9 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, TF_ASSIGN_OR_RETURN( const ShapedBuffer* result_buffer, allocation_tracker_.ResolveForReplica(result->output(), 0)); - TF_RETURN_IF_ERROR(RecordResult( - *result_buffer, execute_backend_->default_stream_executor(), - execute_backend_->transfer_manager(), executable->hlo_snapshot())); + TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), + execute_backend_->transfer_manager(), + executable->hlo_snapshot())); TF_RETURN_IF_ERROR(executable->DumpHloSnapshot()); } @@ -931,14 +930,13 @@ Status Service::TransferToClient(const TransferToClientRequest* arg, return_shape = &shaped_buffer->on_host_shape(); } - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - execute_backend_->stream_executor(shaped_buffer->device_ordinal())); + TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream( + shaped_buffer->device_ordinal())); TF_ASSIGN_OR_RETURN( std::unique_ptr result_literal, execute_backend_->transfer_manager()->TransferLiteralFromDevice( - executor, *shaped_buffer)); + stream.get(), *shaped_buffer)); if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal->shape())) { @@ -988,9 +986,10 @@ Status Service::TransferToServer(const TransferToServerRequest* arg, execute_backend_->transfer_manager()->AllocateScopedShapedBuffer( shape, execute_backend_->memory_allocator(), executor->device_ordinal())); + TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor)); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( - executor, *literal, shaped_buffer)); + stream.get(), *literal, shaped_buffer)); replicated_buffers.emplace_back(std::move(shaped_buffer)); } TF_ASSIGN_OR_RETURN(*result->mutable_data(), @@ -1054,11 +1053,12 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, executor = replicas[arg->replica_id()]; } - Literal literal; + auto literal = Literal::CreateFromShape(arg->shape_with_layout()); + TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( - executor, arg->shape_with_layout(), &literal)); - *result->mutable_literal() = literal.ToProto(); + executor, arg->shape_with_layout(), *literal)); + *result->mutable_literal() = literal->ToProto(); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 8748a4c1447eca691abc0f7ca48feda48ceb86e1..47d196fb2aaee897ce1fd3745129af10bf5b2d2d 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -193,9 +193,6 @@ class Service : public ServiceInterface { const ExecutionOptions& execution_options, tensorflow::gtl::ArraySlice arguments); - // Assert that host- and device-shapes are in a consistent state. - Status ValidateEntryComputationLayout(HloModule* module); - protected: friend class LocalExecutable; @@ -266,11 +263,11 @@ class Service : public ServiceInterface { // will be the result of this computation. Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result); - // Convenience function which checks whether the given shape_with_layout + // Convenience function which checks whether the given client_shape // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. - Status ValidateResultShapeWithLayout(const Shape& shape_with_layout, - const Shape& result_shape) const; + Status ValidateResultShape(const Shape& client_shape, + const Shape& result_shape) const; // Returns the stream executors assigned to the replicas represented by the // given device handle. Each device_handle is a virtual replicated device that diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index 7f3910cdb0366078b97fb5f6a2dc498b37570926..dbfed628bfcabffe66bef41a82e0e2430897d80d 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_EXECUTABLE_RUN_OPTIONS_H_ #include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/stream_executor/stream_executor.h" @@ -27,8 +27,7 @@ namespace xla { // data, now only a stream cache for GPU backend. class ServiceExecutableRunOptions { public: - using StreamBorrower = - std::function::SmartPtr>(int)>; + using StreamBorrower = std::function(int)>; ServiceExecutableRunOptions() : ServiceExecutableRunOptions(ExecutableRunOptions()) {} @@ -51,7 +50,7 @@ class ServiceExecutableRunOptions { // Borrows a stream and returns a smart pointer which returns the stream on // destruction. - StatusOr::SmartPtr> BorrowStream(int device_ordinal) const { + StatusOr BorrowStream(int device_ordinal) const { return borrow_stream_ ? borrow_stream_(device_ordinal) : Status(tensorflow::error::UNIMPLEMENTED, "No stream cache"); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index bd98e86b08b7507b4a7a0d1a7ebac4b654ff2171..a4ea2b28f4dbf41d61702f1af2d65c4d2c86d578 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -49,81 +49,110 @@ bool AllUnique(tensorflow::gtl::ArraySlice slice) { return std::set(slice.begin(), slice.end()).size() == slice.size(); } -Status ExpectNotTupleOrOpaque(const Shape& shape, - tensorflow::StringPiece op_type) { - if (ShapeUtil::IsTuple(shape)) { - return InvalidArgument("Expected non-tuple argument for %s, but got %s.", +Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) { + if (!ShapeUtil::IsArray(shape)) { + return InvalidArgument("Expected array argument for %s, but got %s.", std::string(op_type).c_str(), ShapeUtil::HumanString(shape).c_str()); - } else if (ShapeUtil::IsOpaque(shape)) { - return InvalidArgument("Expected non-opaque argument for %s, but got %s.", - std::string(op_type).c_str(), - ShapeUtil::HumanString(shape).c_str()); - } else { - return Status::OK(); } + return Status::OK(); } -Status VerifyReducerShape(const ProgramShape& reducer_shape, - const Shape& init_value_shape, - const PrimitiveType& input_element_type) { - if (reducer_shape.parameters_size() != 2) { +Status VerifyReducerShape( + const ProgramShape& reducer_shape, + tensorflow::gtl::ArraySlice init_value_shapes, + tensorflow::gtl::ArraySlice input_element_types, + int64 inputs) { + if (reducer_shape.parameters_size() != inputs * 2) { return InvalidArgument( - "Reduction function must take 2 parameters, but " + "Reduction function must take %lld parameters, but " "takes %d parameter(s).", - reducer_shape.parameters_size()); + inputs * 2, reducer_shape.parameters_size()); } const Shape& accumulator_shape = reducer_shape.result(); - if (ShapeUtil::Rank(accumulator_shape) != 0) { - return InvalidArgument( - "Reduction function must have rank 0 (rank %lld reduction function " - "given).", - ShapeUtil::Rank(accumulator_shape)); - } - - // Check that the accumulator can be passed in as the first argument. - // Note: comparing here and below with Compatible since we don't care about - // layout in scalars - see b/26668201 for a longer-term vision. - if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(0))) { + std::vector accumulator_subshapes; + if (ShapeUtil::IsArray(accumulator_shape)) { + if (inputs != 1) { + return InvalidArgument( + "Reduction function must produce a tuple with %lld elements, but " + "produces a scalar", + inputs); + } + accumulator_subshapes.push_back(&accumulator_shape); + } else if (ShapeUtil::IsTuple(accumulator_shape)) { + if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) { + return InvalidArgument( + "Reduction function must produce a tuple with %lld elements, but has " + "%lld elements", + inputs, ShapeUtil::TupleElementCount(accumulator_shape)); + } + for (const Shape& element_shape : accumulator_shape.tuple_shapes()) { + accumulator_subshapes.push_back(&element_shape); + } + } else { return InvalidArgument( - "Reduction function's first parameter shape differs from the " - "result shape: %s vs %s", - ShapeUtil::HumanString(reducer_shape.parameters(0)).c_str(), + "Reduction function must produce a scalar or tuple of scalars, but has " + "shape: %s", ShapeUtil::HumanString(accumulator_shape).c_str()); } - // Check that init_value's shape is suitable for reducer_shape. - if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, - init_value_shape)) { - return InvalidArgument( - "Reduction function's accumulator shape differs from the " - "init_value shape: %s vs %s", - ShapeUtil::HumanString(accumulator_shape).c_str(), - ShapeUtil::HumanString(init_value_shape).c_str()); - } - - // Check that the inputs can be passed in as the second argument. - const Shape& input_element_shape = - ShapeUtil::MakeShape(input_element_type, {}); - if (!ShapeUtil::CompatibleIgnoringFpPrecision(input_element_shape, - reducer_shape.parameters(1))) { - return InvalidArgument( - "Reduction function's second parameter shape differs from the " - "input type element type: %s vs %s", - ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(), - ShapeUtil::HumanString(input_element_shape).c_str()); + for (const Shape* element_shape : accumulator_subshapes) { + if (ShapeUtil::Rank(*element_shape) != 0) { + return InvalidArgument( + "Reduction function must return a scalar or tuple of scalars but " + "returns shape: %s", + ShapeUtil::HumanString(accumulator_shape).c_str()); + } } - // Currently the accumulator and inputs must be the same type, - // though that restriction could be relaxed. - if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, - reducer_shape.parameters(1))) { - return InvalidArgument( - "Reduction function's second parameter shape must " - "match the result shape, but got %s vs %s.", - ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(), - ShapeUtil::HumanString(accumulator_shape).c_str()); + for (int64 i = 0; i < inputs; ++i) { + // Check that the accumulator can be passed in as the first argument. + // Note: comparing here and below with Compatible since we don't care about + // layout in scalars - see b/26668201 for a longer-term vision. + if (!ShapeUtil::Compatible(*accumulator_subshapes[i], + reducer_shape.parameters(i))) { + return InvalidArgument( + "Reduction function's %lld-th parameter shape differs from the " + "result shape: %s vs %s", + i, ShapeUtil::HumanString(reducer_shape.parameters(i)).c_str(), + ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + } + // Check that init_value's shapes are suitable for reducer_shape. + if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i], + *init_value_shapes[i])) { + return InvalidArgument( + "Reduction function's accumulator shape at index %lld differs from " + "the init_value shape: %s vs %s", + i, ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str(), + ShapeUtil::HumanString(*init_value_shapes[i]).c_str()); + } + // Check that the inputs can be passed in as the non-accumulator arguments. + const Shape input_element_shape = + ShapeUtil::MakeShape(input_element_types[i], {}); + if (!ShapeUtil::CompatibleIgnoringFpPrecision( + input_element_shape, reducer_shape.parameters(inputs + i))) { + return InvalidArgument( + "Reduction function's %lld-th parameter shape differs from the " + "input type element type: %s vs %s", + inputs + i, + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), + ShapeUtil::HumanString(input_element_shape).c_str()); + } + // Check that the accumulator and inputs to the reducer function match. + // If the accumulator is scalar, it must have the same type as the inputs + // (up to fp precision). If it is a tuple, then the k-th element of the + // tuple must have the same type as the K-th input (again, up to fp + // precision.) + if (!ShapeUtil::CompatibleIgnoringFpPrecision( + *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) { + return InvalidArgument( + "Reduction function's %lld-th parameter shape must " + "match the result shape, but got %s vs %s.", + inputs + i, + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), + ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + } } return Status::OK(); @@ -198,8 +227,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return shape; } - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(shape, "operand of unary operation")); + TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation")); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); switch (opcode) { @@ -229,13 +257,16 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return shape; case HloOpcode::kReal: case HloOpcode::kImag: - if (!ShapeUtil::ElementIsComplex(shape)) { + if (ShapeUtil::ElementIsComplex(shape)) { + return ShapeUtil::ComplexComponentShape(shape); + } else if (ShapeUtil::ElementIsFloating(shape)) { + return shape; + } else { return InvalidArgument( - "Expected element type in shape to be complex for real/imag " - "operation; got %s.", + "Expected element type in shape to be floating or complex for " + "real/imag operation; got %s.", PrimitiveType_Name(shape.element_type()).c_str()); } - return ShapeUtil::ChangeElementType(shape, F32); case HloOpcode::kAbs: if (ShapeUtil::ElementIsComplex(shape)) { return ShapeUtil::ChangeElementType( @@ -246,7 +277,6 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case HloOpcode::kNegate: case HloOpcode::kRoundNearestAfz: case HloOpcode::kSign: - case HloOpcode::kSort: return shape; case HloOpcode::kNot: @@ -289,8 +319,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, const Shape* arg_shape = nullptr; PrimitiveType element_type = PRIMITIVE_TYPE_INVALID; for (const Shape* shape : arg_shapes) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(*shape, "operand of concatenation")); + TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation")); if (!arg_shape) { arg_shape = shape; element_type = arg_shape->element_type(); @@ -337,7 +366,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(element_type, new_dimensions); } -/* static */ StatusOr ShapeInference::InferTokenShape( +/* static */ StatusOr ShapeInference::InferAfterAllShape( tensorflow::gtl::ArraySlice arg_shapes) { for (const Shape* arg_shape : arg_shapes) { if (arg_shape->element_type() != TOKEN) { @@ -358,12 +387,13 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } - if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) { + if (!ShapeUtil::IsArray(operand_shape) || + !primitive_util::IsArrayType(new_element_type)) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions // are valid. For now we just reject them, though. return InvalidArgument( - "Convert does not allow tuples, so cannot convert from %s to %s.", + "Convert does not allow non-arrays, so cannot convert from %s to %s.", ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } @@ -380,7 +410,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } - if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) { + if (!ShapeUtil::IsArray(operand_shape) || + !primitive_util::IsArrayType(new_element_type)) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions // are valid. For now we just reject them, though. @@ -427,7 +458,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr ShapeInference::InferPadShape( const Shape& operand_shape, const Shape& padding_value_shape, const PaddingConfig& padding_config) { - if (ShapeUtil::IsTuple(operand_shape)) { + if (!ShapeUtil::IsArray(operand_shape)) { return InvalidArgument( "Pad operation does not support tuple-shape operands."); } @@ -566,8 +597,8 @@ Status ValidateDotDimensionNumbers( /* static */ StatusOr ShapeInference::InferDotOpShape( const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot")); + TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot")); + TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot")); auto fail = [lhs, rhs](const string& addendum) -> Status { string message = tensorflow::strings::Printf( @@ -786,10 +817,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferElementwiseBinaryOpShape( HloOpcode operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation")); + TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation")); + TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( @@ -853,12 +882,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - lhs, tensorflow::strings::StrCat("lhs of binary operation ", - HloOpcodeString(opcode)))); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - rhs, tensorflow::strings::StrCat("rhs of binary operation ", - HloOpcodeString(opcode)))); + TF_RETURN_IF_ERROR( + ExpectArray(lhs, tensorflow::strings::StrCat("lhs of binary operation ", + HloOpcodeString(opcode)))); + TF_RETURN_IF_ERROR( + ExpectArray(rhs, tensorflow::strings::StrCat("rhs of binary operation ", + HloOpcodeString(opcode)))); switch (opcode) { case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -893,6 +922,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kXor: if (lhs.element_type() != PRED && !primitive_util::IsIntegralType(lhs.element_type())) { return InvalidArgument( @@ -937,6 +967,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InferClampShape(lhs, rhs, ehs); case HloOpcode::kSelect: return InferSelectShape(lhs, rhs, ehs); + case HloOpcode::kTupleSelect: + return InferTupleSelectShape(lhs, rhs, ehs); default: return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode).c_str()); @@ -947,6 +979,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, HloOpcode opcode, tensorflow::gtl::ArraySlice operands) { std::vector operand_shapes; + operand_shapes.reserve(operands.size()); for (const HloInstruction* operand : operands) { operand_shapes.push_back(&operand->shape()); } @@ -962,11 +995,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, switch (opcode) { case HloOpcode::kTuple: { Shape result = ShapeUtil::MakeTupleShape({}); + result.mutable_tuple_shapes()->Reserve(operand_shapes.size()); for (const Shape* shape : operand_shapes) { ShapeUtil::AppendShapeToTuple(*shape, &result); } return result; } + case HloOpcode::kSort: { + if (operand_shapes.size() == 1) { + return *operand_shapes[0]; + } else if (operand_shapes.size() == 2) { + if (!ShapeUtil::SameDimensions(*operand_shapes[0], + *operand_shapes[1])) { + return InvalidArgument( + "Sort keys and values dimensions must match. " + "Keys shape is: %s\n, Values shape is: %s", + ShapeUtil::HumanString(*operand_shapes[0]).c_str(), + ShapeUtil::HumanString(*operand_shapes[1]).c_str()); + } + return ShapeUtil::MakeTupleShape( + {*operand_shapes[0], *operand_shapes[1]}); + } + return InvalidArgument("Unexpected number of operands for sort"); + } default: return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode).c_str()); @@ -984,15 +1035,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, // All arguments must have the same shape. const Shape* arg_shape = arg_shapes[0]; for (size_t i = 1; i < arg_shapes.size(); ++i) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map")); + TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map")); if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) { continue; } - if (!ShapeUtil::IsTuple(*arg_shapes[i]) && - !ShapeUtil::IsTuple(*arg_shape) && - ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i], + if (ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) { if (ShapeUtil::IsScalar(*arg_shapes[i])) { continue; @@ -1075,11 +1123,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& operand_shape, const Shape& scale_shape, const Shape& offset_shape, int64 feature_index) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm training")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - offset_shape, "offset input of batch norm training")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - scale_shape, "scale input of batch norm training")); + ExpectArray(operand_shape, "operand of batch norm training")); + TF_RETURN_IF_ERROR( + ExpectArray(offset_shape, "offset input of batch norm training")); + TF_RETURN_IF_ERROR( + ExpectArray(scale_shape, "scale input of batch norm training")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == Status::OK()); @@ -1181,11 +1229,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& offset_shape, const Shape& mean_shape, const Shape& variance_shape, int64 feature_index) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm inference")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - offset_shape, "offset input of batch norm inference")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - scale_shape, "scale input of batch norm inference")); + ExpectArray(operand_shape, "operand of batch norm inference")); + TF_RETURN_IF_ERROR( + ExpectArray(offset_shape, "offset input of batch norm inference")); + TF_RETURN_IF_ERROR( + ExpectArray(scale_shape, "scale input of batch norm inference")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == Status::OK()); @@ -1328,16 +1376,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& operand_shape, const Shape& scale_shape, const Shape& mean_shape, const Shape& var_shape, const Shape& output_grad_shape, int64 feature_index) { + TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad")); TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm grad")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(scale_shape, "scale input of batch norm grad")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(mean_shape, "mean input of batch norm grad")); + ExpectArray(scale_shape, "scale input of batch norm grad")); + TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad")); + TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad")); TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(var_shape, "var input of batch norm grad")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - output_grad_shape, "output_grad input of batch norm grad")); + ExpectArray(output_grad_shape, "output_grad input of batch norm grad")); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape)); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape)); @@ -1486,8 +1531,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferConvolveShape( const Shape& lhs, const Shape& rhs, const Window& window, const ConvolutionDimensionNumbers& dnums) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution")); + TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); + TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( @@ -1722,7 +1767,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, tensorflow::gtl::ArraySlice operand_shapes) { for (const Shape* operand_shape : operand_shapes) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(*operand_shape, "operand of cross replica sum")); + ExpectArray(*operand_shape, "operand of cross replica sum")); } if (operand_shapes.size() == 1) { return *operand_shapes[0]; @@ -1734,11 +1779,83 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeTupleShape(operand_shape_values); } +/* static */ StatusOr ShapeInference::InferAllToAllShape( + const Shape& shape, int64 split_dimension, int64 concat_dimension, + int64 split_count) { + TF_RET_CHECK(split_count > 0); + if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) { + return InvalidArgument( + "AllToAll split_dimension %lld is out-of-bounds in shape %s.", + split_dimension, ShapeUtil::HumanString(shape).c_str()); + } + if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) { + return InvalidArgument( + "AllToAll concat_dimension %lld is out-of-bounds in shape %s.", + concat_dimension, ShapeUtil::HumanString(shape).c_str()); + } + if (shape.dimensions(split_dimension) % split_count != 0) { + return InvalidArgument( + "AllToAll split dimension size %lld must be dividable by split_count " + "%lld.", + shape.dimensions(split_dimension), split_count); + } + std::vector new_dimensions(shape.dimensions().begin(), + shape.dimensions().end()); + new_dimensions[split_dimension] /= split_count; + new_dimensions[concat_dimension] *= split_count; + return ShapeUtil::MakeShape(shape.element_type(), new_dimensions); +} + +/* static */ StatusOr ShapeInference::InferAllToAllTupleShape( + tensorflow::gtl::ArraySlice operand_shapes) { + // An Alltoall HLO instruction receives N operands (with the same shape) and + // returns a tuple that contains N array shapes. + TF_RET_CHECK(!operand_shapes.empty()); + for (int i = 0; i < operand_shapes.size(); i++) { + if (!ShapeUtil::Equal(*operand_shapes[0], *operand_shapes[i])) { + return InvalidArgument( + "HLO all-to-all has operands with different shapes: the 0th " + "operand shape %s, but the %dth operand has shape %s.", + ShapeUtil::HumanString(*operand_shapes[0]).c_str(), i, + ShapeUtil::HumanString(*operand_shapes[i]).c_str()); + } + } + + return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes); +} + /* static */ StatusOr ShapeInference::InferReduceShape( - const Shape& arg, const Shape& init_value, + tensorflow::gtl::ArraySlice arg_shapes, tensorflow::gtl::ArraySlice dimensions_to_reduce, const ProgramShape& to_apply) { - // Check that the dimension to reduce are in-bounds for the given shape. + if (arg_shapes.empty()) { + return InvalidArgument("Reduce must have at least 2 arguments, has 0"); + } + if (arg_shapes.size() % 2) { + return InvalidArgument( + "Reduce must have an even number of arguments, has %lu", + arg_shapes.size()); + } + int64 num_reduced_args = arg_shapes.size() / 2; + + tensorflow::gtl::ArraySlice reduced_args(arg_shapes, 0, + num_reduced_args); + // Check that all of the reduced tensors have the same dimensions. The element + // types may be different. + for (int64 i = 1; i < num_reduced_args; ++i) { + if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) { + return InvalidArgument( + "All reduced tensors must have the sime dimension. Tensor 0 has " + "shape %s, Tensor %lld has shape %s", + ShapeUtil::HumanString(*reduced_args[0]).c_str(), i, + ShapeUtil::HumanString(*reduced_args[i]).c_str()); + } + } + + // Check that the dimensions to reduce are in-bounds for the given shape. + // We've already verified all reduced tensors have the same dimensions, so it + // doesn't matter which one we choose. + const Shape& arg = *reduced_args[0]; for (int64 dimension : dimensions_to_reduce) { if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) { return InvalidArgument( @@ -1746,8 +1863,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(arg).c_str()); } } - TF_RETURN_IF_ERROR( - VerifyReducerShape(to_apply, init_value, arg.element_type())); + + tensorflow::gtl::ArraySlice init_values( + arg_shapes, num_reduced_args, arg_shapes.size()); + std::vector element_types; + for (const Shape* arg : reduced_args) { + element_types.push_back(arg->element_type()); + } + TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types, + num_reduced_args)); std::set dimensions_to_reduce_set(dimensions_to_reduce.begin(), dimensions_to_reduce.end()); @@ -1758,16 +1882,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } - return ShapeUtil::MakeShape(to_apply.result().element_type(), new_dimensions); + if (ShapeUtil::IsScalar(to_apply.result())) { + return ShapeUtil::MakeShape(to_apply.result().element_type(), + new_dimensions); + } else { + std::vector result_subshapes; + for (const Shape& subshape : to_apply.result().tuple_shapes()) { + result_subshapes.push_back( + ShapeUtil::MakeShape(subshape.element_type(), new_dimensions)); + } + return ShapeUtil::MakeTupleShape(result_subshapes); + } } /* static */ StatusOr ShapeInference::InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value_shape, const Window& window, const ProgramShape& to_apply_shape) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of reduce-window")); - TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape, - operand_shape.element_type())); + TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window")); + TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape}, + {operand_shape.element_type()}, + /*inputs=*/1)); return InferWindowOutputShape(operand_shape, window, init_value_shape.element_type(), /*allow_negative_padding=*/false); @@ -1778,7 +1912,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Window& window, const Shape& source_shape, const Shape& init_value_shape, const ProgramShape& scatter_shape) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of select-and-scatter")); + ExpectArray(operand_shape, "operand of select-and-scatter")); // Check if the select function has a proper shape of (T,T) -> PRED. if (select_shape.parameters_size() != 2) { @@ -1812,8 +1946,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } // Check if the scatter function has a proper shape as a reduction. - TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, init_value_shape, - source_shape.element_type())); + TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape}, + {source_shape.element_type()}, + /*inputs=*/1)); // Check if the result shape of window operation matches the source shape. TF_ASSIGN_OR_RETURN(const Shape& window_result_shape, @@ -1843,7 +1978,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, Join(starts, ",").c_str(), Join(limits, ",").c_str(), Join(strides, ",").c_str()); }; - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice")); + TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice")); VLOG(2) << tensorflow::strings::Printf( "slicing shape %s starts={%s} limits={%s}", ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(), @@ -1902,10 +2037,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferDynamicSliceShape( const Shape& operand_shape, const Shape& start_indices_shape, tensorflow::gtl::ArraySlice slice_sizes) { + TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic slice")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(start_indices_shape, - "start indices of dynamic slice")); + ExpectArray(start_indices_shape, "start indices of dynamic slice")); VLOG(2) << tensorflow::strings::Printf( "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", @@ -1963,11 +2097,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& operand_shape, const Shape& update_shape, const Shape& start_indices_shape) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic update slice")); + ExpectArray(operand_shape, "operand of dynamic update slice")); TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(update_shape, "update of dynamic update slice")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - start_indices_shape, "start indices of dynamic update slice")); + ExpectArray(update_shape, "update of dynamic update slice")); + TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape, + "start indices of dynamic update slice")); VLOG(2) << tensorflow::strings::Printf( "updating slice of shape %s at dynamic start_indices %s with update " @@ -2035,8 +2169,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /*static */ StatusOr ShapeInference::InferReverseShape( const Shape& operand_shape, tensorflow::gtl::ArraySlice dimensions) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of reverse")); + TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse")); if (!AllUnique(dimensions)) { return InvalidArgument("a dimension number is duplicated in reverse"); } @@ -2166,7 +2299,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferBroadcastShape( const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast")); + TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast")); for (int64 size : broadcast_sizes) { if (size < 0) { return InvalidArgument("Broadcast with negative dimension size %lld.", @@ -2185,7 +2318,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferReshapeShape( const Shape& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "reshape")); + TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape")); Shape inferred_shape = ShapeUtil::MakeShape(operand.element_type(), new_sizes); @@ -2217,7 +2350,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferTransposeShape( const Shape& operand, tensorflow::gtl::ArraySlice dimensions) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "transpose")); + TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); std::vector indices(ShapeUtil::Rank(operand)); std::iota(indices.begin(), indices.end(), 0); @@ -2238,9 +2371,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, // "degenerate" cases, as with binary elementwise ops. /* static */ StatusOr ShapeInference::InferClampShape( const Shape& min, const Shape& operand, const Shape& max) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max")); + TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min")); + TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand")); + TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) { return InvalidArgument("Clamp with different operand types: %s, %s, %s.", @@ -2273,15 +2406,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, // broadcast from all operands, not just the predicate. /* static */ StatusOr ShapeInference::InferSelectShape( const Shape& pred, const Shape& on_true, const Shape& on_false) { - bool compatible; - if (ShapeUtil::IsTuple(on_true)) { - // Select only defines the top-level buffer, so if it's a tuple, the two - // input must match exactly. - compatible = ShapeUtil::Compatible(on_true, on_false); - } else { - compatible = ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false); - } - if (!compatible) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) { return InvalidArgument( "Operands to select must be the same shape; got %s and %s.", ShapeUtil::HumanString(on_true).c_str(), @@ -2293,7 +2418,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(pred).c_str()); } if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) || - ShapeUtil::Rank(pred) == 0) { + ShapeUtil::IsScalar(pred)) { // By this stage we know that pred's element type is PRED. Therefore, this // check restricts pred to be a PRED scalar, or a PRED array with the same // dimensions as on_true and on_false. @@ -2307,6 +2432,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } +/* static */ StatusOr ShapeInference::InferTupleSelectShape( + const Shape& pred, const Shape& on_true, const Shape& on_false) { + // Select only defines the top-level buffer, so if it's a tuple, the two + // input must match exactly. + if (!ShapeUtil::Compatible(on_true, on_false)) { + return InvalidArgument( + "Operands to tuple-select must be the same shape; got %s and %s.", + ShapeUtil::HumanString(on_true).c_str(), + ShapeUtil::HumanString(on_false).c_str()); + } + if (pred.element_type() != PRED) { + return InvalidArgument( + "TupleSelect's pred operand must have PRED element type; got %s.", + ShapeUtil::HumanString(pred).c_str()); + } + if (!ShapeUtil::IsScalar(pred)) { + return InvalidArgument( + "TupleSelect operation with non-scalar predicate: %s.", + ShapeUtil::HumanString(pred).c_str()); + } + return on_true; +} + /* static */ StatusOr ShapeInference::InferCallShape( tensorflow::gtl::ArraySlice arg_shapes, const ProgramShape& to_apply) { @@ -2439,9 +2587,9 @@ static Status ValidateGatherDimensionNumbers( const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice window_bounds) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(input_shape, "input tensor operand gather op")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - gather_indices_shape, "gather indices operand of gather op")); + ExpectArray(input_shape, "input tensor operand gather op")); + TF_RETURN_IF_ERROR( + ExpectArray(gather_indices_shape, "gather indices operand of gather op")); if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { return InvalidArgument( @@ -2546,4 +2694,194 @@ static Status ValidateGatherDimensionNumbers( return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds); } +namespace { + +Status ValidateScatterDimensionNumbers( + const Shape& operand_shape, + tensorflow::gtl::ArraySlice scatter_indices_shape, + const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { + // Validate update_window_dims in ScatterDimensionNumbers. + if (!c_is_sorted(dim_numbers.update_window_dims())) { + return InvalidArgument( + "update_window_dims in scatter op must be sorted; got: %s.", + Join(dim_numbers.update_window_dims(), ", ").c_str()); + } + if (c_adjacent_find(dim_numbers.update_window_dims()) != + dim_numbers.update_window_dims().end()) { + return InvalidArgument( + "update_window_dims in scatter op must not repeat; got: %s.", + Join(dim_numbers.update_window_dims(), ", ").c_str()); + } + const int64 updates_rank = ShapeUtil::Rank(updates_shape); + for (int64 window_dim : dim_numbers.update_window_dims()) { + if (window_dim < 0 || window_dim >= updates_rank) { + return InvalidArgument( + "Invalid update_window_dims set in scatter op; valid range is [0, " + "%lld). got: %lld.", + updates_rank, window_dim); + } + } + + // Validate inserted_window_dims in ScatterDimensionNumbers. + if (!c_is_sorted(dim_numbers.inserted_window_dims())) { + return InvalidArgument( + "inserted_window_dims in scatter op must be sorted; got: %s.", + Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + } + if (c_adjacent_find(dim_numbers.inserted_window_dims()) != + dim_numbers.inserted_window_dims().end()) { + return InvalidArgument( + "inserted_window_dims in scatter op must not repeat; got: %s.", + Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + } + for (int64 inserted_dim : dim_numbers.inserted_window_dims()) { + if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) { + return InvalidArgument( + "Invalid inserted_window_dims set in scatter op; valid range is [0, " + "%d), got: %lld.", + operand_shape.dimensions_size(), inserted_dim); + } + } + + // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers. + if (dim_numbers.scatter_dims_to_operand_dims_size() != + scatter_indices_shape[dim_numbers.index_vector_dim()]) { + return InvalidArgument( + "Scatter op has %d elements in scatter_dims_to_operand_dims and the " + "bound of dimension index_vector_dim=%lld of scatter_indices is %lld. " + "These two numbers must be equal.", + dim_numbers.scatter_dims_to_operand_dims_size(), + dim_numbers.index_vector_dim(), + scatter_indices_shape[dim_numbers.index_vector_dim()]); + } + for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); ++i) { + int64 scatter_dim_to_operand_dim = + dim_numbers.scatter_dims_to_operand_dims(i); + if (scatter_dim_to_operand_dim < 0 || + scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) { + return InvalidArgument( + "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), " + "got: %d->%lld.", + operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim); + } + } + std::vector sorted_scatter_dims_to_operand_dims( + dim_numbers.scatter_dims_to_operand_dims().begin(), + dim_numbers.scatter_dims_to_operand_dims().end()); + c_sort(sorted_scatter_dims_to_operand_dims); + if (c_adjacent_find(sorted_scatter_dims_to_operand_dims) != + sorted_scatter_dims_to_operand_dims.end()) { + return InvalidArgument( + "Repeated dimensions not allowed in scatter_dims_to_operand_dims; " + "got: %s.", + Join(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str()); + } + + return Status::OK(); +} + +} // namespace + +/*static*/ StatusOr ShapeInference::InferScatterShape( + const Shape& operand_shape, const Shape& scatter_indices_shape, + const Shape& updates_shape, const ProgramShape& to_apply_shape, + const ScatterDimensionNumbers& scatter_dim_numbers) { + TF_RETURN_IF_ERROR( + ExpectArray(operand_shape, "operand tensor of scatter op")); + TF_RETURN_IF_ERROR( + ExpectArray(scatter_indices_shape, "scatter indices of scatter op")); + TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op")); + + if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) { + return InvalidArgument( + "Scatter indices parameter must be an integral tensor; got %s.", + ShapeUtil::HumanString(scatter_indices_shape).c_str()); + } + + if (scatter_indices_shape.dimensions_size() < + scatter_dim_numbers.index_vector_dim() || + scatter_dim_numbers.index_vector_dim() < 0) { + return InvalidArgument( + "Scatter index leaf dimension must be within [0, rank(scatter_indices)" + " + 1). rank(scatter_indices) is %d and scatter index leaf dimension " + "is %lld.", + scatter_indices_shape.dimensions_size(), + scatter_dim_numbers.index_vector_dim()); + } + + // Check if the update computation has a proper shape as a reduction. + const Shape init_value_shape = + ShapeUtil::MakeShape(operand_shape.element_type(), {}); + TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape}, + {updates_shape.element_type()}, + /*inputs=*/1)); + + std::vector expanded_scatter_indices_shape = + ArraySliceToVector(AsInt64Slice(scatter_indices_shape.dimensions())); + if (expanded_scatter_indices_shape.size() == + scatter_dim_numbers.index_vector_dim()) { + expanded_scatter_indices_shape.push_back(1); + } + + int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 + + scatter_dim_numbers.update_window_dims_size(); + if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) { + return InvalidArgument("Updates tensor must be of rank %lld; got %lld.", + expected_updates_rank, + ShapeUtil::Rank(updates_shape)); + } + + TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers( + operand_shape, expanded_scatter_indices_shape, updates_shape, + scatter_dim_numbers)); + + int64 inserted_dims_seen = 0; + std::vector max_update_window_bounds; + for (int i = 0; i < operand_shape.dimensions_size(); ++i) { + if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() && + scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) { + ++inserted_dims_seen; + } else { + max_update_window_bounds.push_back(operand_shape.dimensions(i)); + } + } + for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) { + auto update_window_dim = scatter_dim_numbers.update_window_dims(i); + if (updates_shape.dimensions(update_window_dim) > + max_update_window_bounds[i]) { + return InvalidArgument( + "Bounds of the window dimensions of updates must not exceed the " + "bounds of the corresponding dimensions of operand. For dimension " + "%lld, updates bound is %lld, operand bound is %lld.", + update_window_dim, updates_shape.dimensions(update_window_dim), + max_update_window_bounds[i]); + } + } + + int64 scatter_dims_seen = 0; + for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) { + bool is_update_window_dim = + c_binary_search(scatter_dim_numbers.update_window_dims(), i); + if (is_update_window_dim) { + continue; + } + if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) { + ++scatter_dims_seen; + } + if (updates_shape.dimensions(i) != + expanded_scatter_indices_shape[scatter_dims_seen]) { + return InvalidArgument( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices. For " + "scatter dimension %lld, updates bound is %lld, scatter_indices " + "bound is %lld.", + i, updates_shape.dimensions(i), + expanded_scatter_indices_shape[scatter_dims_seen]); + } + ++scatter_dims_seen; + } + + return operand_shape; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index f1f7b50902d899c0c629c3098d80fc400fb1388d..c185b0a1bd79e23e0d76daad50fb4a9708a743dd 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -119,11 +119,22 @@ class ShapeInference { const Shape& in, FftType fft_type, tensorflow::gtl::ArraySlice fft_length); - // Infers the shape produced a cross replica sum with the given operand + // Infers the shape produced by a cross replica sum with the given operand // shapes. static StatusOr InferCrossReplicaSumShape( tensorflow::gtl::ArraySlice operand_shapes); + // Infers final shape of an Alltoall operation that is created by the xla + // builder. + static StatusOr InferAllToAllShape(const Shape& shape, + int64 split_dimension, + int64 concat_dimension, + int64 split_count); + + // Infers the shape of an HLO all-to-all instruction. + static StatusOr InferAllToAllTupleShape( + tensorflow::gtl::ArraySlice operand_shapes); + // Infers the shape produced by applying the given reduction computation // shape to the given input operand shape. // @@ -131,7 +142,7 @@ class ShapeInference { // index as the leading parameter, and the program shape should match // accordingly (or an error will result). static StatusOr InferReduceShape( - const Shape& arg, const Shape& init_value, + tensorflow::gtl::ArraySlice arg_shapes, tensorflow::gtl::ArraySlice dimensions_to_reduce, const ProgramShape& to_apply); @@ -216,11 +227,11 @@ class ShapeInference { static StatusOr InferConcatOpShape( tensorflow::gtl::ArraySlice arg_shapes, int64 dimension); - // Infers the shape produced by a kGenerateToken operation. Trivially this - // shape is always a TOKEN shape. However, ShapeInference serves two purposes: - // inferring shapes and checking operand shapes. This method verifies that the - // operand shapes are all TOKENs. - static StatusOr InferTokenShape( + // Infers the shape produced by a kAfterAll. Trivially this shape is always a + // TOKEN shape. However, ShapeInference serves two purposes: inferring shapes + // and checking operand shapes. This method verifies that the operand shapes + // are all TOKENs. + static StatusOr InferAfterAllShape( tensorflow::gtl::ArraySlice arg_shapes); // Helper that validates the given operand shape can be converted to the @@ -268,6 +279,14 @@ class ShapeInference { const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice window_bounds); + // Helper that validates the given input shape, scatter indices shape, updates + // shape, and scatter dimension numbers that constitute a scatter operation, + // and returns the result shape of the scatter operation. + static StatusOr InferScatterShape( + const Shape& operand_shape, const Shape& scatter_indices_shape, + const Shape& updates_shape, const ProgramShape& to_apply_shape, + const ScatterDimensionNumbers& scatter_dim_numbers); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. @@ -286,6 +305,10 @@ class ShapeInference { static StatusOr InferSelectShape(const Shape& pred, const Shape& on_true, const Shape& on_false); + // Helper for inferring the shape of TupleSelect ops. + static StatusOr InferTupleSelectShape(const Shape& pred, + const Shape& on_true, + const Shape& on_false); // Helper for inferring shapes of binary operations which use degenerate // dimension broadcasting (a dimension of size 1 in one operand is broadcast diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 6d017dffe2d8f927abad4a62bff7fe41bc871975..a73fa181cdd13dc7fabcdc367ae117e19bdc3e5f 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -62,7 +63,7 @@ class ReduceShapeInferenceTest : public ShapeInferenceTest { tensorflow::gtl::ArraySlice dimensions_to_reduce) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); auto inferred_status = ShapeInference::InferReduceShape( - arg, f32_, dimensions_to_reduce, to_apply); + {&arg, &f32_}, dimensions_to_reduce, to_apply); EXPECT_IS_OK(inferred_status.status()); EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape, inferred_status.ValueOrDie())); @@ -702,11 +703,99 @@ TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) { /*dimensions_to_reduce=*/{0, 1, 2}); } +TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_IS_OK(inferred_status.status()); + EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTupleShape({f32_, s32_}), + inferred_status.ValueOrDie())); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = + ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_, f32_, s32_}, + ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("must take 4 parameters, but takes 6 parameter(s)")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT( + inferred_status.status().error_message(), + HasSubstr( + "parameter shape differs from the result shape: s32[] vs f32[]")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) { + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape({}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("must have at least 2 arguments, has 0")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = + ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT( + inferred_status.status().error_message(), + HasSubstr("must produce a tuple with 2 elements, but produces a scalar")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT( + inferred_status.status().error_message(), + HasSubstr("must produce a tuple with 2 elements, but has 3 elements")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("accumulator shape at index 0 differs from the " + "init_value shape: s32[] vs f32[]")); +} + TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); auto inferred_status = ShapeInference::InferReduceShape( - ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{3, 4}, - to_apply); + {&arg_shape, &f32_}, + /*dimensions_to_reduce=*/{3, 4}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), HasSubstr("out-of-bounds dimension")); @@ -714,8 +803,9 @@ TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_); + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); auto inferred_status = - ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, + ShapeInference::InferReduceShape({&arg_shape, &f32_}, /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), @@ -724,12 +814,13 @@ TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_); + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); auto inferred_status = - ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, + ShapeInference::InferReduceShape({&arg_shape, &f32_}, /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), - HasSubstr("first parameter shape differs")); + HasSubstr("0-th parameter shape differs")); } TEST_F(ShapeInferenceTest, InferSliceShapeRank2) { @@ -1311,7 +1402,7 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT( inferred_status_error4.status().error_message(), - HasSubstr("Expected non-tuple argument for operand of concatenation")); + HasSubstr("Expected array argument for operand of concatenation")); const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32}); auto inferred_status_error5 = ShapeInference::InferConcatOpShape( @@ -1387,7 +1478,7 @@ TEST_F(ShapeInferenceTest, ReverseInvalidDimension) { ShapeInference::InferReverseShape(tuple_shape, {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - HasSubstr("Expected non-tuple argument")); + HasSubstr("Expected array argument")); } TEST_F(ShapeInferenceTest, Call) { @@ -1523,7 +1614,19 @@ TEST_F(ShapeInferenceTest, BadSlice) { << statusor.status(); } -class GatherShapeInferenceTest : public ShapeInferenceTest { +TEST_F(ShapeInferenceTest, BadSort) { + auto keys = ShapeUtil::MakeShape(F32, {4}); + auto values = ShapeUtil::MakeShape(F32, {5}); + StatusOr statusor = + ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values}); + ASSERT_FALSE(statusor.ok()); + + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("dimensions must match")) + << statusor.status(); +} + +class ScatterGatherShapeInferenceTest : public ShapeInferenceTest { protected: const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {}); const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5}); @@ -1540,59 +1643,63 @@ class GatherShapeInferenceTest : public ShapeInferenceTest { ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_}); + const ProgramShape to_apply_ = + ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); }; -TEST_F(GatherShapeInferenceTest, TensorFlowGather) { - TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, - HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/1), - /*window_bounds=*/{64, 1})); +// Shape inference tests for Gather. + +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) { + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + matrix_64_48_, s64_vector_32_, + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), + /*window_bounds=*/{64, 1})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32}))) << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { - TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, - HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{1}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, - /*index_vector_dim=*/1), - /*window_bounds=*/{1, 48})); +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) { + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + matrix_64_48_, s64_vector_32_, + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{1}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1), + /*window_bounds=*/{1, 48})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { - TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape(matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, - HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, - /*index_vector_dim=*/4), - /*window_bounds=*/{1, 48})); +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) { + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4), + /*window_bounds=*/{1, 48})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1604,12 +1711,12 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1622,12 +1729,12 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1640,13 +1747,13 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { +TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) { // This is equivalent to a dynamic slice. TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{0, 1, 2, 3, 4}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1658,13 +1765,13 @@ TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { +TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) { // The gather indices "tensor" is a scalar S here that's used to slice out // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result. TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_scalar_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{0, 1, 2, 3}, /*elided_window_dims=*/{0}, /*gather_dims_to_operand_dims=*/{0}, @@ -1676,41 +1783,44 @@ TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { +TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) { StatusOr statusor = ShapeInference::InferGatherShape( tuple_shape_, s64_vector_32_, - HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/1), + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Expected non-tuple argument for input")) + HasSubstr("Expected array argument for input")) << statusor.status(); } -TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { +TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) { StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, tuple_shape_, - HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/0), + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Expected non-tuple argument for gather indices")) + HasSubstr("Expected array argument for gather indices")) << statusor.status(); } -TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { +TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) { StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, vector_32_, - HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/0), + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1718,11 +1828,11 @@ TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingWindowIndices) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 8, 7}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1735,11 +1845,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowIndices) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 7}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1752,11 +1862,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexOutOfBounds) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 99, 100, 101}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1768,11 +1878,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 9}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1784,11 +1894,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingElidedWindowDims) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{4}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1802,11 +1912,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{0, 1, 2, 3, 19}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1819,11 +1929,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{0, 1, 2, 3, 3}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1837,11 +1947,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingGatherToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}, @@ -1856,11 +1966,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}, @@ -1874,11 +1984,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedGatherToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}, @@ -1892,11 +2002,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingElidedWindowDims) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{2, 1}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1908,10 +2018,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { +TEST_F(ScatterGatherShapeInferenceTest, + InvalidGatherDimNumbers_WindowBoundsTooLarge) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7}, /*elided_window_dims=*/{2}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1924,11 +2035,11 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1942,11 +2053,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7}, /*elided_window_dims=*/{1}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1959,10 +2070,10 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { +TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1976,5 +2087,575 @@ TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { << statusor.status(); } +// Shape inference tests for Scatter. + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdates) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {64, 32}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {32, 48}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdates) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {10, 32}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {32, 8}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {65, 32}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Bounds of the window dimensions of updates must not exceed " + "the bounds of the corresponding dimensions of operand.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {32, 49}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Bounds of the window dimensions of updates must not exceed " + "the bounds of the corresponding dimensions of operand.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + TfScatterWithUpdatesNotMatchingIndices) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {64, 31}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + TfScatterWithUpdatesNotMatchingIndicesV2) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {31, 48}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdates) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 64}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdates) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 10}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 12}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 65}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Bounds of the window dimensions of updates must not exceed " + "the bounds of the corresponding dimensions of operand.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + TfScatterNdWithUpdatesNotMatchingIndices) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {9, 9, 8, 7, 64}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfBatchDynamicUpdateSlice) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDim) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, NoUpdateScatterDims) { + // This is equivalent to a dynamic update slice. + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0, 1, 2, 3, 4}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, ScalarScatterIndices) { + // The scalar indices "tensor" is a scalar S here that's used to update a + // [30,29,28,27] shaped tensor within the operand at position S. + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_scalar_, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0, 1, 2, 3}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/0))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedTensorInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + tuple_shape_, s64_vector_32_, s64_vector_32_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected array argument for operand")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + ScatterWithTupleShapedScatterIndicesInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + s64_vector_32_, tuple_shape_, s64_vector_32_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected array argument for scatter indices")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + s64_vector_32_, s64_vector_32_, tuple_shape_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected array argument for updates")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, FloatingPointScatterIndicesInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + s64_vector_32_, vector_32_, s64_vector_32_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Scatter indices parameter must be an integral tensor")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/10)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Scatter index leaf dimension must be within [0, " + "rank(scatter_indices) + 1)")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdates) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 50}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Updates tensor must be of rank 7; got 8.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdateComputation) { + const ProgramShape invalid_update_computation = + ShapeUtil::MakeProgramShape({f32_}, f32_); + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), + invalid_update_computation, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Reduction function must take 2 parameters, but takes 1")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 8, 7}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("update_window_dims in scatter op must be sorted")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_RepeatedUpdateWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 7}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("update_window_dims in scatter op must not repeat")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 9}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid update_window_dims set in scatter op; valid " + "range is [0, 9)")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{2, 1}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("inserted_window_dims in scatter op must be sorted")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_RepeatedInsertedWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 1}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("inserted_window_dims in scatter op must not repeat")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 5}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid inserted_window_dims set in scatter op; valid " + "range is [0, 5)")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Scatter op has 4 elements in scatter_dims_to_operand_dims and " + "the bound of dimension index_vector_dim=4 of scatter_indices " + "is 5. These two numbers must be equal")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 10}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid scatter_dims_to_operand_dims mapping; domain " + "is [0, 5), got: 4->10")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 2, 3}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Repeated dimensions not allowed in scatter_dims_to_operand_dims")) + << statusor.status(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc new file mode 100644 index 0000000000000000000000000000000000000000..c0582c6a2d3a05e2ed5aead5faac54e536d350cd --- /dev/null +++ b/tensorflow/compiler/xla/service/stream_pool.cc @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/stream_pool.h" + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) { + std::unique_ptr stream; + { + tensorflow::mutex_lock lock(mu_); + if (!streams_.empty()) { + // Re-use an existing stream from the pool. + stream = std::move(streams_.back()); + streams_.pop_back(); + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool reusing existing stream"; + } + } + + if (!stream) { + // Create a new stream. + stream = MakeUnique(executor); + stream->Init(); + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool created new stream"; + } + + // Return the stream wrapped in Ptr, which has our special deleter semantics. + PtrDeleter deleter = {this}; + return Ptr(stream.release(), deleter); +} + +void StreamPool::ReturnStream(se::Stream* stream) { + if (stream->ok()) { + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool returning ok stream"; + tensorflow::mutex_lock lock(mu_); + streams_.emplace_back(stream); + } else { + // If the stream has encountered any errors, all subsequent operations on it + // will fail. So just delete the stream, and rely on new streams to be + // created in the future. + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool deleting !ok stream"; + delete stream; + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/stream_pool.h b/tensorflow/compiler/xla/service/stream_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..7221d323a61593ac4b203a81b6046d81a5beaaf0 --- /dev/null +++ b/tensorflow/compiler/xla/service/stream_pool.h @@ -0,0 +1,64 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_STREAM_POOL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_STREAM_POOL_H_ + +#include +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// Pool of stream_executor::Streams, which are created as needed and +// destroyed when the pool is destroyed. +class StreamPool { + public: + struct PtrDeleter { + void operator()(se::Stream* stream) { pool->ReturnStream(stream); } + StreamPool* pool; + }; + + // Stream pointer type returned by BorrowStream, which returns the + // stream to the pool on destruction. + using Ptr = std::unique_ptr; + + StreamPool() {} + + // Returns a pointer to a stream in the pool, creating a new stream + // if none are available in the pool. The returned smart pointer + // returns the stream to the pool on destruction. + // + // This method is thread-safe. + Ptr BorrowStream(se::StreamExecutor* executor); + + private: + // Puts a pointer to a stream back into the pool, leaving it free + // for future use. Streams that have previously encountered errors + // are deleted, and not returned to the pool. + // + // This method is thread-safe. + void ReturnStream(se::Stream* stream); + + tensorflow::mutex mu_; + std::vector> streams_ GUARDED_BY(mu_); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_STREAM_POOL_H_ diff --git a/tensorflow/compiler/xla/service/stream_pool_test.cc b/tensorflow/compiler/xla/service/stream_pool_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..aaf5c37b0d250f78cb57639255ac9b59e1b462f7 --- /dev/null +++ b/tensorflow/compiler/xla/service/stream_pool_test.cc @@ -0,0 +1,136 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/stream_pool.h" + +#include + +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace { + +class StreamPoolTest : public ::testing::Test { + protected: + std::unique_ptr NewStreamExecutor() { + se::Platform* platform = + se::MultiPlatformManager::PlatformWithName("Host").ConsumeValueOrDie(); + se::StreamExecutorConfig config(/*ordinal=*/0); + return platform->GetUncachedExecutor(config).ConsumeValueOrDie(); + } +}; + +TEST_F(StreamPoolTest, EmptyPool) { StreamPool pool; } + +TEST_F(StreamPoolTest, OneStreamPool) { + std::unique_ptr executor = NewStreamExecutor(); + StreamPool pool; + + // Borrow and return a stream. + StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); + se::Stream* stream1_ptr = stream1.get(); + EXPECT_TRUE(stream1->ok()); + stream1 = nullptr; + + // Borrow and return another stream. + StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); + se::Stream* stream2_ptr = stream2.get(); + EXPECT_TRUE(stream2->ok()); + stream2 = nullptr; + + // The underlying streams should be the same, since stream1 was the + // only stream available in the pool when stream2 was borrowed. + EXPECT_EQ(stream1_ptr, stream2_ptr); +} + +TEST_F(StreamPoolTest, TwoStreamPool) { + std::unique_ptr executor = NewStreamExecutor(); + StreamPool pool; + + // Borrow two streams. + StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); + se::Stream* stream1_ptr = stream1.get(); + EXPECT_TRUE(stream1->ok()); + StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); + se::Stream* stream2_ptr = stream2.get(); + EXPECT_TRUE(stream2->ok()); + + // The underlying streams should be different, since we haven't + // returned either of them yet. + EXPECT_NE(stream1_ptr, stream2_ptr); + + // Return stream1 and borrow stream3. + stream1 = nullptr; + StreamPool::Ptr stream3 = pool.BorrowStream(executor.get()); + se::Stream* stream3_ptr = stream3.get(); + EXPECT_TRUE(stream3->ok()); + + // stream1 and stream3 should be the same. + EXPECT_EQ(stream1_ptr, stream3_ptr); + EXPECT_NE(stream2_ptr, stream3_ptr); + + // Return stream2, and borrow stream4. + stream2 = nullptr; + StreamPool::Ptr stream4 = pool.BorrowStream(executor.get()); + se::Stream* stream4_ptr = stream4.get(); + EXPECT_TRUE(stream4->ok()); + + // Stream2 and stream4 should be the same. + EXPECT_EQ(stream2_ptr, stream4_ptr); + EXPECT_NE(stream3_ptr, stream4_ptr); +} + +TEST_F(StreamPoolTest, BadStreamDiscarded) { + std::unique_ptr executor = NewStreamExecutor(); + StreamPool pool; + + // Borrow a stream. + StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); + EXPECT_TRUE(stream1->ok()); + + // Force an error on the stream; here we call a method that requires + // DNN support, which we know the Host platform doesn't support. + stream1->ThenDepthConcatenate({}, {}, nullptr); + EXPECT_FALSE(stream1->ok()); + + // Return stream1 and borrow stream2. + stream1 = nullptr; + StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); + se::Stream* stream2_ptr = stream2.get(); + EXPECT_TRUE(stream2->ok()); + + // The underlying streams should be different. They would have been + // the same, but since we forced an error on stream1, it cannot be + // put back into the pool. Sadly we can't just check: + // EXPECT_NE(stream1_ptr, stream2_ptr); + // + // The above should hold logically, but it may fail if the new + // stream instance allocated for stream2 happens to reside in the + // same memory address as stream1, which has been deleted. + // + // The check that stream2->ok() serves as a good-enough check. + + // Return stream2 and borrow stream3. The previous error on stream1 + // has no effect on these streams, and they are the same. + stream2 = nullptr; + StreamPool::Ptr stream3 = pool.BorrowStream(executor.get()); + se::Stream* stream3_ptr = stream3.get(); + EXPECT_TRUE(stream3->ok()); + EXPECT_EQ(stream2_ptr, stream3_ptr); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index c4d01562c4e32225ebb984d8fcd93ec3fa86e403..32d368a90429ec026120bdf033957617eeaba23e 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -22,8 +22,12 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/notification.h" + +using ::tensorflow::strings::StrCat; namespace xla { /* static */ tensorflow::mutex @@ -36,8 +40,104 @@ TransferManager::GetPlatformTransferManagers() { return r; } +StatusOr> TransferManager::TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer) { + StatusOr> ret; + + se::Stream* substream = stream->GetOrCreateSubStream(); + substream->ThenWaitFor(stream); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + + tensorflow::Notification n; + Status s; + Literal literal(device_buffer.on_host_shape()); + TransferLiteralFromDevice(substream, device_buffer, literal, + [&](Status status) { + s = status; + n.Notify(); + }); + n.WaitForNotification(); + if (!s.ok()) { + return s; + } + return MakeUnique(std::move(literal)); +} + +Status TransferManager::TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + const MutableBorrowingLiteral& literal) { + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + + Status ret; + tensorflow::Notification n; + TransferLiteralFromDevice(substream, device_buffer, literal, + [&](Status status) { + ret = status; + n.Notify(); + }); + n.WaitForNotification(); + return ret; +} + +Status TransferManager::TransferLiteralToDevice( + se::Stream* stream, const LiteralSlice& literal, + const ShapedBuffer& device_buffer) { + // Implement the synchronous version by waiting on the asynchronous version. + // Use a substream so that if we are called from a HostCallback we don't + // deadlock. + se::Stream* substream = stream->GetOrCreateSubStream(); + substream->ThenWaitFor(stream); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + TF_RETURN_IF_ERROR( + TransferLiteralToDeviceAsync(substream, literal, device_buffer)); + return substream->BlockHostUntilDone(); +} + +StatusOr> TransferManager::TransferArrayFromDevice( + se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& source) { + StatusOr> ret; + // Implement the synchronous version by waiting on the asynchronous version. + // Use a substream so that if we are called from a HostCallback we don't + // deadlock. + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + + tensorflow::Notification n; + Literal literal(shape); + Status s; + TransferArrayFromDevice(substream, shape, source, literal, + [&](Status status) { + s = status; + n.Notify(); + }); + n.WaitForNotification(); + if (!s.ok()) { + return s; + } + return MakeUnique(std::move(literal)); +} + Status TransferManager::TransferArrayToDevice( - se::StreamExecutor* executor, const LiteralSlice& literal, + se::Stream* stream, const LiteralSlice& literal, + const se::DeviceMemoryBase& dest) { + // Implement the synchronous version by waiting on the asynchronous version. + // Use a substream so that if we are called from a HostCallback we don't + // deadlock. + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + TF_RETURN_IF_ERROR(TransferArrayToDeviceAsync(substream, literal, dest)); + return substream->BlockHostUntilDone(); +} + +Status TransferManager::TransferArrayToDeviceAsync( + se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest) { const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape)) @@ -51,28 +151,33 @@ Status TransferManager::TransferArrayToDevice( dest.size(), GetByteSizeRequirement(on_device_shape)); } ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape, - executor->platform(), executor->device_ordinal()); + stream->parent()->platform(), + stream->parent()->device_ordinal()); shaped_buffer.set_buffer(dest, /*index=*/{}); - return TransferLiteralToDevice(executor, literal, shaped_buffer); + return TransferLiteralToDevice(stream, literal, shaped_buffer); } -StatusOr> TransferManager::TransferArrayFromDevice( - se::StreamExecutor* executor, const Shape& shape, - const se::DeviceMemoryBase& source) { - TF_RET_CHECK(ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) - << "Shape " << ShapeUtil::HumanString(shape) - << " has a differently shaped representation on-device: " - << ShapeUtil::HumanString(HostShapeToDeviceShape(shape)); +void TransferManager::TransferArrayFromDevice( + se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source, + const MutableBorrowingLiteral& literal, std::function done) { + if (!ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) { + auto error = StrCat("Shape ", ShapeUtil::HumanString(shape), + " has a differently shaped representation on-device: ", + ShapeUtil::HumanString(HostShapeToDeviceShape(shape))); + return done(FailedPrecondition("%s", error.c_str())); + } if (source.size() < GetByteSizeRequirement(shape)) { - return FailedPrecondition( - "Allocation on device not large enough for array: " - "%lld < %lld", - source.size(), GetByteSizeRequirement(shape)); + return done( + FailedPrecondition("Allocation on device not large enough for array: " + "%lld < %lld", + source.size(), GetByteSizeRequirement(shape))); } ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape, - executor->platform(), executor->device_ordinal()); + stream->parent()->platform(), + stream->parent()->device_ordinal()); shaped_buffer.set_buffer(source, /*index=*/{}); - return TransferLiteralFromDevice(executor, shaped_buffer); + return TransferLiteralFromDevice(stream, shaped_buffer, literal, + std::move(done)); } /* static */ void TransferManager::RegisterTransferManager( @@ -108,10 +213,14 @@ StatusOr> TransferManager::TransferArrayFromDevice( } Status TransferManager::WriteTupleIndexTables( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { - VLOG(2) << "Writing tuple index tables for " << device_buffer; + se::Stream* stream, const ShapedBuffer& device_buffer) { + TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer)); + return stream->BlockHostUntilDone(); +} - TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); +Status TransferManager::WriteTupleIndexTablesAsync( + se::Stream* stream, const ShapedBuffer& device_buffer) { + VLOG(2) << "Writing tuple index tables for " << device_buffer; return ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_device_shape(), @@ -129,7 +238,7 @@ Status TransferManager::WriteTupleIndexTables( elements.push_back(device_buffer.buffer(element_index)); element_index.pop_back(); } - return WriteSingleTupleIndexTable(executor, elements, device_subshape, + return WriteSingleTupleIndexTable(stream, elements, device_subshape, &device_memory); } @@ -138,26 +247,20 @@ Status TransferManager::WriteTupleIndexTables( } Status TransferManager::TransferBufferFromDevice( - se::StreamExecutor* executor, const se::DeviceMemoryBase& source, - int64 size, void* destination) { + se::Stream* stream, const se::DeviceMemoryBase& source, int64 size, + void* destination) { if (source.size() < size) { return FailedPrecondition( "Source allocation on device not large enough for data tranfer: " "%lld < %lld", source.size(), size); } - auto copy_status = executor->SynchronousMemcpyD2H(source, size, destination); - if (!copy_status.ok()) { - return AddStatus( - Status(static_cast(copy_status.code()), - copy_status.error_message()), - "failed transfer from device to buffer"); - } + stream->ThenMemcpy(destination, source, size); return Status::OK(); } Status TransferManager::TransferBufferToDevice( - se::StreamExecutor* executor, int64 size, const void* source, + se::Stream* stream, int64 size, const void* source, se::DeviceMemoryBase* destination) { if (destination->size() < size) { return FailedPrecondition( @@ -165,13 +268,7 @@ Status TransferManager::TransferBufferToDevice( "%lld < %lld", destination->size(), size); } - auto copy_status = executor->SynchronousMemcpyH2D(source, size, destination); - if (!copy_status.ok()) { - return AddStatus( - Status(static_cast(copy_status.code()), - copy_status.error_message()), - "failed transfer of buffer to device"); - } + stream->ThenMemcpy(destination, source, size); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 43a8092b06fba0e2495bce0ee1a309c85a908273..475a2e5c141d66fa689fb402da1ee81fb4ab80f7 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -52,30 +52,69 @@ class TransferManager { return host_shape; } - // Returns a literal containing the data held in the given ShapedBuffer. - // using the provided executor. The optional literal_shape will be the shape - // for the literal. The shape of the ShapedBuffer and - // DeviceShape(literal_shape) must be compatible, but need not have the same - // layout. + // Returns a literal containing the data held in the given ShapedBuffer + // using the provided executor. This operation is performed synchronously + // without waiting for any other operation on a stream to complete. + // + // This function should be avoided in favor of the asynchronous version below. virtual StatusOr> TransferLiteralFromDevice( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer) = 0; + se::Stream* stream, const ShapedBuffer& device_buffer); + virtual Status TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + const MutableBorrowingLiteral& literal); + + // Begins transferring a literal containing the data held in the given + // ShapedBuffer using the provided executor. + // + // This operation is performed asynchronously on the given stream. It returns + // once the transfer is enqueued. 'done' is invoked with the result when + // complete. + // + // device_buffer is copied by reference and must live at least until done() is + // invoked. + virtual void TransferLiteralFromDevice(se::Stream* stream, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, + std::function done) = 0; // Transfers the given literal into the previously allocated device memory // represented by the given ShapedBuffer using the given executor. The shape // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, - // but need not have the same layout - virtual Status TransferLiteralToDevice(se::StreamExecutor* executor, + // but need not have the same layout. + // + // This operation is performed synchronously without waiting for any other + // operation on a stream to complete. This function should be avoided in favor + // of the asynchronous version below. + virtual Status TransferLiteralToDevice(se::Stream* stream, const LiteralSlice& literal, - const ShapedBuffer& device_buffer) = 0; + const ShapedBuffer& device_buffer); + + // Transfers the given literal into the previously allocated device memory + // represented by the given ShapedBuffer using the given executor. The shape + // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, + // but need not have the same layout. + // + // This operation is performed asynchronously on the given stream. It returns + // once the transfer is enqueued. + virtual Status TransferLiteralToDeviceAsync( + se::Stream* stream, const LiteralSlice& literal, + const ShapedBuffer& device_buffer) = 0; // Convenience methods for transferring an array to or from the device at a // known address. This avoids having to construct a ShapedBuffer just to // transfer an array at a known address. - Status TransferArrayToDevice(se::StreamExecutor* executor, - const LiteralSlice& literal, + Status TransferArrayToDevice(se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest); + void TransferArrayFromDevice(se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& source, + const MutableBorrowingLiteral& literal, + std::function done); + + Status TransferArrayToDeviceAsync(se::Stream* stream, + const LiteralSlice& literal, + const se::DeviceMemoryBase& dest); StatusOr> TransferArrayFromDevice( - se::StreamExecutor* executor, const Shape& shape, + se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source); // Transfers the given literal into the Infeed interface of the device, @@ -85,9 +124,9 @@ class TransferManager { // Transfers the given literal from the Outfeed interface of the device, // using the given executor. - virtual Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, - const Shape& literal_shape, - Literal* literal) = 0; + virtual Status TransferLiteralFromOutfeed( + se::StreamExecutor* executor, const Shape& literal_shape, + MutableBorrowingLiteral literal) = 0; // Resets the devices associated with this transfer manager. virtual Status ResetDevices( @@ -96,8 +135,10 @@ class TransferManager { // Given an allocated ShapedBuffer, constructs the tuple index table(s) in // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the // ShapedBuffer is array-shaped this method does nothing. - Status WriteTupleIndexTables(se::StreamExecutor* executor, + Status WriteTupleIndexTables(se::Stream* stream, const ShapedBuffer& device_buffer); + Status WriteTupleIndexTablesAsync(se::Stream* stream, + const ShapedBuffer& device_buffer); // Determines the byte size requirement for the given shape on the underlying // architecture. This will be used to allocate an appropriately sized memory @@ -130,21 +171,11 @@ class TransferManager { const se::Platform* platform); protected: - // Transfer a memory block of the given size from 'source' buffer to the - // Infeed interface of the device using the given executor. - // - // size is the size to transfer from source in bytes. - // - // source is the source data that must be in the target-dependent layout that - // the Infeed HLO used in the computation expects. - virtual Status TransferBufferToInfeed(se::StreamExecutor* executor, - int64 size, const void* source) = 0; - // Transfer a memory block of the given size from the device source into the // 'destination' buffer. // // size is the size to transfer to destination in bytes. - virtual Status TransferBufferFromDevice(se::StreamExecutor* executor, + virtual Status TransferBufferFromDevice(se::Stream* stream, const se::DeviceMemoryBase& source, int64 size, void* destination); @@ -152,15 +183,15 @@ class TransferManager { // destination of the device. // // size is the size to transfer from source in bytes. - virtual Status TransferBufferToDevice(se::StreamExecutor* executor, - int64 size, const void* source, + virtual Status TransferBufferToDevice(se::Stream* stream, int64 size, + const void* source, se::DeviceMemoryBase* destination); // Writes the given device-memory pointers in 'elements' to the given region // to construct a tuple index table in the platform-specific tuple // representation. virtual Status WriteSingleTupleIndexTable( - se::StreamExecutor* executor, + se::Stream* stream, tensorflow::gtl::ArraySlice elements, const Shape& shape, se::DeviceMemoryBase* region) = 0; diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index cccb8f2fbb0266bbf1f40b09170938a1e5d3e78d..58f767e913fbc0023e0c45a4f0e82ecefeeef2d6 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -160,11 +160,11 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { auto builder = HloComputation::Builder("entry"); // (1.0 + 2.0) * (2.0 - 3.0) HloInstruction* const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); HloInstruction* const2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); HloInstruction* const3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( const1->shape(), HloOpcode::kAdd, const1, const2)); HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index eb6d1ada6b553f998fe06917dfdf0b5092cd79cd..0447807a41b8b32ee297e1ca94393da8c687c5e6 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" @@ -121,7 +122,6 @@ void PointsToSet::add_tuple_source(const ShapeIndex& index, } namespace { - // Gather fusion instructions from 'instruction' into 'fusion_instructions'. void GatherFusionInstructions( HloInstruction* instruction, @@ -232,8 +232,7 @@ Status TuplePointsToAnalysis::HandleGetTupleElement( // Copy the points-to set (and tuple sources) at index {element_index} of the // operand to the points-to set for this GetTupleElement instruction. points_to_set.ForEachMutableElement( - [&, this](const ShapeIndex& target_index, - PointsToSet::BufferList* points_to) { + [&](const ShapeIndex& target_index, PointsToSet::BufferList* points_to) { // Construct an index into the operand by prepending element_index to // the index for the GetTupleElement instruction's points-to set. ShapeIndex src_index; @@ -292,22 +291,29 @@ Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) { } Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { - // RecvDone aliases its input (Recv) tuple element {0} to its output. + // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its + // output. The other indices ({} and {1}) define their own buffers. PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done); + points_to_set.AddPointedToBuffer( + logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{}), + /*index=*/{}); + points_to_set.AddPointedToBuffer( + logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{1}), + /*index=*/{1}); + const PointsToSet& operand_points_to_set = GetPointsToSet(recv_done->operand(0)); - // Recursively copy the points to set of the operand tuple {0}. + // Recursively copy the points to set of the operand tuple {0} to the output + // element {0}. points_to_set.ForEachMutableElement( - [this, &points_to_set, &operand_points_to_set]( + [&points_to_set, &operand_points_to_set]( const ShapeIndex& index, PointsToSet::BufferList* buffers) { - ShapeIndex src_index({0}); - for (auto element : index) { - src_index.push_back(element); + if (index.empty() || index[0] != 0) { + return; } - *buffers = operand_points_to_set.element(src_index); - for (auto& tuple_source : - operand_points_to_set.tuple_sources(src_index)) { + *buffers = operand_points_to_set.element(index); + for (auto& tuple_source : operand_points_to_set.tuple_sources(index)) { points_to_set.add_tuple_source(index, tuple_source); } }); @@ -315,7 +321,7 @@ Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { } Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) { - // Send creates a tuple of {aliased operand, U32 context}. + // Send creates a tuple of {aliased operand, U32 context, token}. PointsToSet& points_to_set = CreateEmptyPointsToSet(send); // Creates the points to set for the tuple and its element at {1}. @@ -328,6 +334,10 @@ Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) { context_buffer->push_back( &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1}))); + auto token_buffer = points_to_set.mutable_element(ShapeIndex({2})); + token_buffer->push_back( + &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({2}))); + // Recursively copy the points to set of the operand to output tuple {0}. const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0)); operand_points_to_set.ForEachElement( @@ -388,7 +398,7 @@ Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select) { +Status TuplePointsToAnalysis::HandleTupleSelect(HloInstruction* tuple_select) { // Select allocates a new buffer and then shallow copies the on_true or // on_false buffer into this new buffer. Which side is chosen cannot be // determined statically so conservatively set the points-to set to the union @@ -396,9 +406,9 @@ Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select) { // // First create a copy of the on_true points-to set (and tuple sources), then // add in elements of the on_false points-to set (tuple sources). - auto on_true = select->operand(1); - auto on_false = select->operand(2); - PointsToSet& points_to_set = CreateCopiedPointsToSet(select, on_true); + auto on_true = tuple_select->operand(1); + auto on_false = tuple_select->operand(2); + PointsToSet& points_to_set = CreateCopiedPointsToSet(tuple_select, on_true); const PointsToSet& false_points_to_set = *PerInst(on_false)->points_to_set; points_to_set.ForEachMutableElement( [&](const ShapeIndex& index, PointsToSet::BufferList* buffers) { @@ -416,7 +426,7 @@ Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select) { // respective element in the points-to set should contain only itself. points_to_set.mutable_element({})->clear(); points_to_set.AddPointedToBuffer( - logical_buffer_analysis_->GetBuffer(select, /*index=*/{}), + logical_buffer_analysis_->GetBuffer(tuple_select, /*index=*/{}), /*index=*/{}); return Status::OK(); } @@ -506,7 +516,7 @@ Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction( const HloInstruction* instruction, TuplePointsToAnalysis::BufferDefinitionVector* buffers) { GetPointsToSet(instruction) - .ForEachElement([this, buffers, instruction]( + .ForEachElement([buffers, instruction]( const ShapeIndex& index, const PointsToSet::BufferList& source_buffers) { // Add buffers which 'instruction' is the source of. @@ -536,7 +546,7 @@ PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet( PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction); const PointsToSet& src_points_to_set = GetPointsToSet(src); dst_points_to_set.ForEachMutableElement( - [this, &dst_points_to_set, &src_points_to_set]( + [&dst_points_to_set, &src_points_to_set]( const ShapeIndex& index, PointsToSet::BufferList* buffers) { *buffers = src_points_to_set.element(index); for (auto& tuple_source : src_points_to_set.tuple_sources(index)) { @@ -707,6 +717,7 @@ bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt( // root at operand 0 or 1. Or... // (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index // 0. +// (5) The 'user' of 'operand' is Sort, and it is the only user. // // (2) and (3) can only be determined if points-to analysis is available. bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( @@ -723,7 +734,8 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( return false; } if (user->opcode() == HloOpcode::kFusion) { - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop || + user->fusion_kind() == HloInstruction::FusionKind::kInput) { if (user->fused_expression_root()->opcode() == HloOpcode::kDynamicUpdateSlice) { // Loop fusion with kDynamicUpdateSlice fused root. @@ -732,6 +744,11 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( // 'operand_index', and this singleton use is the fused root at operand // index 0. return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0); + } else { + HloInstruction* fusion_param = + user->fused_parameter(user->operand_index(operand)); + return HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( + fusion_param); } } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { @@ -766,6 +783,21 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( std::vector operand_indices = user->OperandIndices(operand); return operand_indices.size() == 1 && operand_indices[0] == 0; } + if (user->opcode() == HloOpcode::kSort) { + // Only valid if there are no other users. + if (operand->users().size() != 1) { + return false; + } + // If we only sort keys, the output of sort is not a tuple, so we can always + // share the buffer. + if (user->operand_count() == 1) { + return true; + } + CHECK(!user_index.empty()); + // Only share with the right tuple element buffer. + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && user_index[0] == operand_indices[0]; + } if (user->opcode() == HloOpcode::kCall) { // TODO(b/62548313): Remove when buffer assignment is module scoped and // does not assign buffers to calls. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index c0d82414806d9a6ff57aec59d077f444137fec9a..686bb053288fbd6a46ca50a2c65c739354fd2678 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -253,7 +253,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; - Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; string ToString() const; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 5734f284071944bc22011405898cf86f33dc48d7..10d382e8abc92145c1804cbf18bbed714fa34571 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -124,9 +124,9 @@ class TuplePointsToAnalysisTest : public HloTestBase { TEST_F(TuplePointsToAnalysisTest, SimpleTuple) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -177,14 +177,14 @@ TEST_F(TuplePointsToAnalysisTest, NestedTuple) { // tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto inner_tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({inner_tuple, constant3})); @@ -238,14 +238,14 @@ TEST_F(TuplePointsToAnalysisTest, GetTupleElement) { // tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto inner_tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({inner_tuple, constant3})); @@ -270,7 +270,7 @@ TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) { // Create a tuple which contains duplicate elements. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant, constant, constant})); @@ -291,9 +291,9 @@ TEST_F(TuplePointsToAnalysisTest, TupleCopy) { // the same. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto copy = builder.AddInstruction( @@ -317,9 +317,10 @@ TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) { // Send forwards its operand to the output tuple at {0}. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); auto send = builder.AddInstruction( - HloInstruction::CreateSend(constant, /*channel_id=*/0)); + HloInstruction::CreateSend(constant, token, /*channel_id=*/0)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); BuildModuleAndRunAnalysis(builder.Build()); @@ -342,8 +343,9 @@ TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) { TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) { // RecvDone forwards its operand tuple element at {0} to the output. auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); auto recv = builder.AddInstruction(HloInstruction::CreateRecv( - ShapeUtil::MakeShape(F32, {1, 2, 3}), /*channel_id=*/0)); + ShapeUtil::MakeShape(F32, {1, 2, 3}), token, /*channel_id=*/0)); auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); BuildModuleAndRunAnalysis(builder.Build()); @@ -355,7 +357,7 @@ TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) { ExpectHasTopLevelBuffers( points_to_analysis_->GetPointsToSet(recv).element({}), {recv}); - ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {}}}); + ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {0}}}); } TEST_F(TuplePointsToAnalysisTest, TupleSelect) { @@ -363,18 +365,18 @@ TEST_F(TuplePointsToAnalysisTest, TupleSelect) { // set containing the union of both sides. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto tuple2 = builder.AddInstruction( HloInstruction::CreateTuple({constant2, constant2})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); BuildModuleAndRunAnalysis(builder.Build()); @@ -401,9 +403,9 @@ TEST_F(TuplePointsToAnalysisTest, SelectTupleParameters) { auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, tuple_shape, "param1")); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, param0, param1)); + tuple_shape, HloOpcode::kTupleSelect, pred, param0, param1)); auto copy = builder.AddInstruction( HloInstruction::CreateUnary(tuple_shape, HloOpcode::kCopy, select)); @@ -441,18 +443,18 @@ TEST_F(TuplePointsToAnalysisTest, UnambiguousTupleSelect) { // Select from two identical tuples. The result should not be ambiguous. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto tuple2 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); BuildModuleAndRunAnalysis(builder.Build()); @@ -472,9 +474,9 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) { // the right values. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto inner_tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto inner_tuple2 = builder.AddInstruction( @@ -486,9 +488,9 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) { builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple2})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); BuildModuleAndRunAnalysis(builder.Build()); @@ -519,9 +521,9 @@ TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) { // have the operand of the bitcast in its points-to set. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( constant2->shape(), HloOpcode::kBitcast, constant2)); auto tuple = @@ -555,9 +557,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) { // Construct a tuple constant and kCopy it. Verify the points-to set of the // copy correctly correctly points into the nested elements of the constant. auto builder = HloComputation::Builder(TestName()); - auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::MakeTuple({Literal::CreateR2({{1.0}, {2.0}}).get(), - Literal::CreateR1({2.0, 42}).get()}))); + auto tuple_constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), + LiteralUtil::CreateR1({2.0, 42}).get()}))); auto copy = builder.AddInstruction(HloInstruction::CreateUnary( tuple_constant->shape(), HloOpcode::kCopy, tuple_constant)); @@ -577,9 +580,9 @@ TEST_F(TuplePointsToAnalysisTest, BufferAliases) { // times. Verify buffer alias sets. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto inner_tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto tuple = builder.AddInstruction( @@ -618,7 +621,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { auto tuple_element1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(update_shape, tuple_param0, 1)); auto ones = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); + LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f}))); // Create 'update' = Add(GetTupleElement(tuple_param0, 1), ones) auto update = builder.AddInstruction(HloInstruction::CreateBinary( update_shape, HloOpcode::kAdd, tuple_element1, ones)); @@ -866,9 +869,9 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -960,9 +963,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -1009,14 +1012,56 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); } +TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto sort = + builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); + Shape values_shape = ShapeUtil::MakeShape(F32, {8}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto values = builder.AddInstruction( + HloInstruction::CreateParameter(1, values_shape, "values")); + auto sort = builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The buffer for the keys can be shared with the first tuple entry. + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {0})); + // The buffer for the values can be shared with the second tuple entry. + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(values, {}, + sort, {1})); + // Verify that the buffers are not shared with the "wrong" tuple entry. + EXPECT_FALSE( + points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {1})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(values, {}, + sort, {0})); +} + TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { auto builder = HloComputation::Builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -1025,7 +1070,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add_operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -1047,7 +1092,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -1055,7 +1100,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { HloInstruction::CreateReverse(data_shape, operand, {0, 1})); auto two = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); @@ -1073,7 +1118,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - auto make_cond = [this, &data_shape]() { + auto make_cond = [&data_shape]() { auto builder = HloComputation::Builder(TestName() + ".Cond"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); @@ -1082,7 +1127,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { return builder.Build(); }; - auto make_body = [this, &data_shape]() { + auto make_body = [&data_shape]() { auto builder = HloComputation::Builder(TestName() + ".Body"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); @@ -1120,7 +1165,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { auto sub_param = sub_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "sub_param")); auto one = sub_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto ones = sub_builder.AddInstruction( HloInstruction::CreateBroadcast(shape, one, {1})); auto add = sub_builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index d3635eae81ec7017f9bf6a69250d10716309c9ec..39b693872da6bd985d95c2abc9519662c838a3f5 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc new file mode 100644 index 0000000000000000000000000000000000000000..af2cb6dc2a3f4a004351acc62796e0daf46719c2 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -0,0 +1,238 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { + +using tensorflow::gtl::nullopt; +using tensorflow::gtl::optional; + +// Finds and returns the non-constant operand in instr. +// +// CHECK-fails if instr doesn't have exactly one unique non-constant operand. +static const HloInstruction* NonConstantOperand(const HloInstruction* instr) { + const HloInstruction* result = nullptr; + for (const HloInstruction* operand : instr->operands()) { + if (!operand->IsConstant()) { + if (result != nullptr) { + CHECK_EQ(result, operand); + } + result = operand; + } + } + CHECK_NE(result, nullptr); + return result; +} + +// If all of instr's operands are either constants or have the form +// get-tuple-element(gte_operand, N) +// for the same value N, returns N. Otherwise, returns nullopt. +static optional GetGTEOperandIndex(const HloInstruction* instr, + const HloInstruction* gte_operand) { + VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", " + << gte_operand->ToString() << ")"; + optional tuple_idx; + for (const HloInstruction* operand : instr->operands()) { + if (operand->IsConstant()) { + continue; + } + // Look through copies. + // TODO(b/68830972): We wouldn't need this if for loop matching on the GPU + // would run before copy insertion. + if (operand->opcode() == HloOpcode::kCopy) { + operand = operand->operand(0); + } + if (operand->opcode() != HloOpcode::kGetTupleElement) { + VLOG(2) << "instr uses something other than gte(gte_operand): " + << operand->ToString(); + return nullopt; + } + if (operand->operand(0) != gte_operand) { + VLOG(2) << "instr has gte whose operand is not gte_operand: " + << operand->ToString(); + return nullopt; + } + if (tuple_idx && tuple_idx != operand->tuple_index()) { + VLOG(2) << "instr has operands with conflicting gte indices, " + << *tuple_idx << " vs " << operand->tuple_index(); + return nullopt; + } + + tuple_idx = operand->tuple_index(); + } + return tuple_idx; +} + +// Tries to get the tuple index of the induction variable of a while loop. +// +// Checks that the loop condition and root both plumb the induction variable +// through the same tuple index, and that they both apply exactly one op to the +// induction variable before deciding whether to do another loop iteration (in +// the loop condition's case) or packing the induction variable into the result +// tuple (in the loop body's case). +// +// Specifically, checks that the loop condition has structure +// +// root = op(constants, get-tuple-elem(param0, N), constants) +// +// and the loop body has the structure +// +// inc = op(constants, get-tuple-elem(param0, N), constants) +// root = tuple(..., inc, ...) // inc is N'th operand of tuple(). +// +// If so, returns N. Otherwise, returns nullopt. +static optional GetLoopInductionVarTupleIdx( + const HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + VLOG(2) << "Finding induction variable for loop " + << while_op->ToShortString(); + + // The while_cond computation should have the form + // + // while_cond_root = + // op(constants, get-tuple-elem(while_cond_param, N), constants). + // + // If it does, set indvar_tuple_idx to N. + auto* while_cond = while_op->while_condition(); + auto* while_cond_root = while_cond->root_instruction(); + auto* while_cond_param = while_cond->parameter_instruction(0); + optional indvar_tuple_idx = + GetGTEOperandIndex(while_cond_root, while_cond_param); + if (!indvar_tuple_idx) { + VLOG(2) << "Induction variable not found in loop condition: " + << while_cond->root_instruction()->ToString(); + return nullopt; + } + + // The while_body computation should have the form + // + // while_body_inc = + // op(constants, get-tuple-elem(while_body_param, N), constants) + // while_body_root = tuple(..., while_body_inc, ...) + // + // where while_body_inc is operand N of while_body_root. + auto* while_body = while_op->while_body(); + auto* while_body_root = while_body->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple instruction: " + << while_body_root->ToString(); + return nullopt; + } + + auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx); + auto* while_body_param = while_body->parameter_instruction(0); + optional while_body_indvar_tuple_idx = + GetGTEOperandIndex(while_body_inc, while_body_param); + if (!while_body_indvar_tuple_idx) { + VLOG(2) + << "Induction variable not found in while body increment instruction: " + << while_body_inc->ToString(); + return nullopt; + } + if (while_body_indvar_tuple_idx != indvar_tuple_idx) { + VLOG(2) << "Tuple index of induction variable does not match between loop " + "condition (" + << *indvar_tuple_idx << ") and while body (" + << *while_body_indvar_tuple_idx << ")"; + return nullopt; + } + + // Finally, check that the while loop's initial value is a tuple with enough + // elements. + auto* while_init = while_op->operand(0); + if (while_init->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While init expected to be a tuple: " << while_init->ToString(); + return nullopt; + } + + VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx; + return indvar_tuple_idx; +} + +optional ComputeWhileLoopTripCount(HloInstruction* while_op, + int64 max_value_returned) { + VLOG(2) << "Getting trip count for loop " << while_op->ToString(); + + // The loop's induction variable is found at + // + // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx), + // + // where comp is while_op->while_body() or while_op->while_condition(). + optional indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op); + if (!indvar_tuple_idx) { + return nullopt; + } + + // Now that we know the index of the induction variable, we can we can try to + // compute how many times the loop executes. Start by computing the induction + // variable's initial value. + HloEvaluator evaluator(/*max_loop_iterations=*/0); + auto* while_init = while_op->mutable_operand(0); + auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); + StatusOr> indvar_init_result = + evaluator.Evaluate(indvar_init); + if (!indvar_init_result.ok()) { + VLOG(2) << "Couldn't evaluate induction variable init: " + << indvar_init_result.status(); + return nullopt; + } + + auto* while_body = while_op->while_body(); + auto* while_body_indvar_update = + while_body->root_instruction()->operand(*indvar_tuple_idx); + auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); + + // The initial value of the induction variable. + std::unique_ptr indvar_iter_val = + std::move(indvar_init_result).ValueOrDie(); + for (int64 trip_count = 0; trip_count != max_value_returned + 1; + ++trip_count) { + auto* while_cond = while_op->while_condition(); + auto* while_cond_root = while_cond->root_instruction(); + auto* while_cond_indvar = NonConstantOperand(while_cond_root); + StatusOr> result = + evaluator.EvaluateWithSubstitutions( + while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}}); + if (!result.ok()) { + VLOG(2) << "Couldn't evaluate while cond: " << result.status(); + return nullopt; + } + if (result.ValueOrDie()->data() == + tensorflow::gtl::ArraySlice{false}) { + VLOG(2) << "Loop has static trip count of " << trip_count; + return trip_count; + } + + // Calculate the value of the induction variable after one iteration of the + // loop, and check whether the while condition is true with this new value. + StatusOr> indvar_next_result = + evaluator.EvaluateWithSubstitutions( + while_body_indvar_update, + {{while_body_indvar, indvar_iter_val.get()}}); + if (!indvar_next_result.ok()) { + VLOG(2) << "Couldn't evaluate induction variable update: " + << indvar_next_result.status(); + return nullopt; + } + indvar_iter_val = std::move(indvar_next_result).ValueOrDie(); + } + + VLOG(2) << "Loop has unknown trip count."; + return nullopt; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..bf59813e8c405a8709446bf8457729348ceae4ec --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_analysis.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/lib/gtl/optional.h" + +namespace xla { + +// Returns the precise trip count of the loop if it's statically known, +// nullopt otherwise. max_value_returned limits the number of steps that are +// evaluated while trying to brute force a loop trip count, trip counts larger +// than max_value_returned result in nullopt. +tensorflow::gtl::optional ComputeWhileLoopTripCount( + HloInstruction *while_op, int64 max_value_returned = 128); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc index 10fc4958fae06414dbe7a3a0a798cb5c6e0f35c2..62af45128ad2fb7bf886bef78ec3ab42529a181e 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -61,6 +61,12 @@ StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileBody( WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) { int64 index = invariant_gte->tuple_index(); const HloInstruction& invariant_value = *init_value.operand(index); + + // Should have at least one user that's not while_body_root. + if (invariant_gte->user_count() <= 1) { + continue; + } + if (invariant_value.opcode() == HloOpcode::kConstant) { auto* constant_instr = while_body->AddInstruction(invariant_value.Clone(/*suffix=*/".sunk")); diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 393e75803888d8a642881c4d525b170d1e1180ba..266039d2ff8ef4befba0d1023ac1914737207d4f 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -196,5 +196,50 @@ ENTRY entry { op::GetTupleElement(op::Parameter(0)), op::GetTupleElement(op::Parameter(0)))); } + +TEST_F(WhileLoopConstantSinkingTest, DontCreateDeadConstant) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2],f32[2]) parameter(0) + p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0 + p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1 + + outfeed = token[] outfeed(p_body.0) + ROOT root = (f32[2],f32[2],f32[2]) tuple(p_body.0, p_body.1, p_body.1) +} + +condition { + p_cond = (f32[2],f32[2]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = f32[2] constant({2, 1}) + while_init = (f32[2],f32[2]) tuple(const_0, const_1) + ROOT while = (f32[2],f32[2],f32[2]) while(while_init), condition=condition, + body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::GetTupleElement(), op::GetTupleElement(), + op::GetTupleElement())); + for (const HloInstruction* inst : while_body->instructions()) { + if (inst->opcode() == HloOpcode::kConstant) { + EXPECT_GT(inst->user_count(), 0); + } + } +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 8831c513eee66e36163135b732f833d46cb7eb03..32e69c335b713c438bd7fcb2053709b0624f58ed 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -53,7 +53,7 @@ HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation( builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "param")); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); return module->AddEmbeddedComputation(builder.Build()); } @@ -125,7 +125,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { builder.AddInstruction(HloInstruction::CreateUnary( scalar_s32, HloOpcode::kNegate, mul_result)); HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); HloInstruction* sub_result = builder.AddInstruction(HloInstruction::CreateBinary( scalar_s32, HloOpcode::kSubtract, negate_result, constant)); @@ -248,7 +248,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest, TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + auto token_shape = ShapeUtil::MakeTokenShape(); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape}); HloComputation* while_body = [&]() { HloComputation::Builder builder(TestName() + ".while_body"); @@ -258,25 +260,32 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); HloInstruction* gte_1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* in_token = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(token_shape, param, 2)); + HloInstruction* out_token = builder.AddInstruction( + HloInstruction::CreateOutfeed(scalar_s32, gte_0, in_token, "")); builder.AddInstruction( - HloInstruction::CreateOutfeed(scalar_s32, gte_0, "")); - builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1})); + HloInstruction::CreateTuple({gte_0, gte_1, out_token})); return module().AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); + auto* scalar_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_s32, "param")); + auto* token = builder.AddInstruction(HloInstruction::CreateToken()); auto* init_value = builder.AddInstruction( - HloInstruction::CreateParameter(0, while_shape, "init_value")); + HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( while_shape, MakeAlwaysTrueComputation(while_shape, &module()), while_body, init_value)); - + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0)); module().AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, WhileLoopInvariantCodeMotion{}.Run(&module())); - EXPECT_FALSE(simplified_loop); + ASSERT_FALSE(simplified_loop); EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Outfeed())); @@ -287,7 +296,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { // bitcast either. auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); - Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + auto token_shape = ShapeUtil::MakeTokenShape(); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape}); HloComputation* while_body = [&]() { HloComputation::Builder builder(TestName() + ".while_body"); @@ -297,21 +308,29 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); HloInstruction* gte_1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* in_token = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(token_shape, param, 2)); HloInstruction* bitcast_inst = builder.AddInstruction( HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); + HloInstruction* out_token = builder.AddInstruction( + HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, in_token, "")); builder.AddInstruction( - HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, "")); - builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1})); + HloInstruction::CreateTuple({gte_0, gte_1, out_token})); return module().AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); + auto* scalar_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_s32, "param")); + auto* token = builder.AddInstruction(HloInstruction::CreateToken()); auto* init_value = builder.AddInstruction( - HloInstruction::CreateParameter(0, while_shape, "init_value")); + HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( while_shape, MakeAlwaysTrueComputation(while_shape, &module()), while_body, init_value)); + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0)); module().AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index ec05a74e286c89dd8db5ae07580e461938d7c087..dd8697e680c56165f87c365a721eda2de1ebc085 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/call_inliner.h" -#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -26,23 +26,6 @@ namespace xla { using tensorflow::gtl::nullopt; using tensorflow::gtl::optional; -// Finds and returns the non-constant operand in instr. -// -// CHECK-fails if instr doesn't have exactly one unique non-constant operand. -static const HloInstruction* NonConstantOperand(const HloInstruction* instr) { - const HloInstruction* result = nullptr; - for (const HloInstruction* operand : instr->operands()) { - if (!operand->IsConstant()) { - if (result != nullptr) { - CHECK_EQ(result, operand); - } - result = operand; - } - } - CHECK_NE(result, nullptr); - return result; -} - // Determines whether the given instruction is a send/recv node, or has a // subcomputation which contains a send/recv node. static bool IsOrContainsSendOrRecv(const HloInstruction* instr); @@ -72,211 +55,6 @@ static bool IsOrContainsSendOrRecv(const HloInstruction* instr) { return false; } -// If all of instr's operands are either constants or have the form -// get-tuple-element(gte_operand, N) -// for the same value N, returns N. Otherwise, returns nullopt. -static optional GetGTEOperandIndex(const HloInstruction* instr, - const HloInstruction* gte_operand) { - VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", " - << gte_operand->ToString() << ")"; - optional tuple_idx; - for (const HloInstruction* operand : instr->operands()) { - if (operand->IsConstant()) { - continue; - } - if (operand->opcode() != HloOpcode::kGetTupleElement) { - VLOG(2) << "instr uses something other than gte(gte_operand): " - << operand->ToString(); - return nullopt; - } - if (operand->operand(0) != gte_operand) { - VLOG(2) << "instr has gte whose operand is not gte_operand: " - << operand->ToString(); - return nullopt; - } - if (tuple_idx && tuple_idx != operand->tuple_index()) { - VLOG(2) << "instr has operands with conflicting gte indices, " - << *tuple_idx << " vs " << operand->tuple_index(); - return nullopt; - } - - tuple_idx = operand->tuple_index(); - } - return tuple_idx; -} - -// Tries to get the tuple index of the induction variable of a while loop. -// -// Checks that the loop condition and root both plumb the induction variable -// through the same tuple index, and that they both apply exactly one op to the -// induction variable before deciding whether to do another loop iteration (in -// the loop condition's case) or packing the induction variable into the result -// tuple (in the loop body's case). -// -// Specifically, checks that the loop condition has structure -// -// root = op(constants, get-tuple-elem(param0, N), constants) -// -// and the loop body has the structure -// -// inc = op(constants, get-tuple-elem(param0, N), constants) -// root = tuple(..., inc, ...) // inc is N'th operand of tuple(). -// -// If so, returns N. Otherwise, returns nullopt. -static optional GetLoopInductionVarTupleIdx( - const HloInstruction* while_op) { - CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - VLOG(2) << "Finding induction variable for loop " - << while_op->ToShortString(); - - // The while_cond computation should have the form - // - // while_cond_root = - // op(constants, get-tuple-elem(while_cond_param, N), constants). - // - // If it does, set indvar_tuple_idx to N. - auto* while_cond = while_op->while_condition(); - auto* while_cond_root = while_cond->root_instruction(); - auto* while_cond_param = while_cond->parameter_instruction(0); - optional indvar_tuple_idx = - GetGTEOperandIndex(while_cond_root, while_cond_param); - if (!indvar_tuple_idx) { - VLOG(2) << "Induction variable not found in loop condition: " - << while_cond->root_instruction()->ToString(); - return nullopt; - } - - // The while_body computation should have the form - // - // while_body_inc = - // op(constants, get-tuple-elem(while_body_param, N), constants) - // while_body_root = tuple(..., while_body_inc, ...) - // - // where while_body_inc is operand N of while_body_root. - auto* while_body = while_op->while_body(); - auto* while_body_root = while_body->root_instruction(); - if (while_body_root->opcode() != HloOpcode::kTuple) { - VLOG(2) << "While body's root is not a tuple instruction: " - << while_body_root->ToString(); - return nullopt; - } - - auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx); - auto* while_body_param = while_body->parameter_instruction(0); - optional while_body_indvar_tuple_idx = - GetGTEOperandIndex(while_body_inc, while_body_param); - if (!while_body_indvar_tuple_idx) { - VLOG(2) - << "Induction variable not found in while body increment instruction: " - << while_body_inc->ToString(); - return nullopt; - } - if (while_body_indvar_tuple_idx != indvar_tuple_idx) { - VLOG(2) << "Tuple index of induction variable does not match between loop " - "condition (" - << *indvar_tuple_idx << ") and while body (" - << *while_body_indvar_tuple_idx << ")"; - return nullopt; - } - - // Finally, check that the while loop's initial value is a tuple with enough - // elements. - auto* while_init = while_op->operand(0); - if (while_init->opcode() != HloOpcode::kTuple) { - VLOG(2) << "While init expected to be a tuple: " << while_init->ToString(); - return nullopt; - } - - VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx; - return indvar_tuple_idx; -} - -// Tries to determine the number of times the given loop executes. Currently -// simply returns 0, 1, or "can't tell" (nullopt). -static optional GetLoopTripCount(HloInstruction* while_op) { - CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - VLOG(2) << "Getting trip count for loop " << while_op->ToString(); - - // The loop's induction variable is found at - // - // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx), - // - // where comp is while_op->while_body() or while_op->while_condition(). - optional indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op); - if (!indvar_tuple_idx) { - return nullopt; - } - - VLOG(2) << "Induction variable is at index " << *indvar_tuple_idx - << " in input tuple."; - - // Now that we know the index of the induction variable, we can we can try to - // compute how many times the loop executes. Start by computing the induction - // variable's initial value. - HloEvaluator evaluator(/*max_loop_iterations=*/0); - auto* while_init = while_op->mutable_operand(0); - auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); - StatusOr> indvar_init_result = - evaluator.Evaluate(indvar_init); - if (!indvar_init_result.ok()) { - VLOG(2) << "Couldn't evaluate induction variable init: " - << indvar_init_result.status(); - return nullopt; - } - - // Evaluates the while loop's condition, returning either "true" (continue - // looping), "false" (stop looping), or nullopt (can't evaluate). - auto evaluate_while_cond = [&](const Literal& indvar) -> optional { - auto* while_cond = while_op->while_condition(); - auto* while_cond_root = while_cond->root_instruction(); - auto* while_cond_indvar = NonConstantOperand(while_cond_root); - StatusOr> result = - evaluator.EvaluateWithSubstitutions(while_cond_root, - {{while_cond_indvar, &indvar}}); - if (!result.ok()) { - VLOG(2) << "Couldn't evaluate while cond: " << result.status(); - return nullopt; - } - return result.ValueOrDie()->data() == - tensorflow::gtl::ArraySlice{true}; - }; - - // The initial value of the induction variable. - const Literal& indvar_iter0_val = *indvar_init_result.ValueOrDie(); - - // Evaluate whether the while condition is true when seeded with - // indvar_iter0_val. - optional while_cond_iter0_val = evaluate_while_cond(indvar_iter0_val); - if (while_cond_iter0_val == false) { - VLOG(2) << "Loop has static trip count of 0."; - return 0; - } - - // Calculate the value of the induction variable after one iteration of the - // loop, and check whether the while condition is true with this new value. - auto* while_body = while_op->while_body(); - auto* while_body_indvar_update = - while_body->root_instruction()->operand(*indvar_tuple_idx); - auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); - StatusOr> indvar_iter1_result = - evaluator.EvaluateWithSubstitutions( - while_body_indvar_update, {{while_body_indvar, &indvar_iter0_val}}); - if (!indvar_iter1_result.ok()) { - VLOG(2) << "Couldn't evaluate induction variable update: " - << indvar_iter1_result.status(); - return nullopt; - } - const Literal& indvar_iter1_val = *indvar_iter1_result.ValueOrDie(); - optional while_cond_iter1_val = evaluate_while_cond(indvar_iter1_val); - if (while_cond_iter1_val == false) { - VLOG(2) << "Determined that loop has static trip count of 1."; - return 1; - } - - VLOG(2) << "Loop has unknown trip count >= 1."; - return nullopt; -} - // Tries to remove elements in a while loop's tuple that aren't used within the // loop. // @@ -577,7 +355,9 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { } // Remove while loops with static trip count of 0. - optional trip_count = GetLoopTripCount(while_op); + optional trip_count = + ComputeWhileLoopTripCount(while_op, + /*max_value_returned=*/1); if (trip_count && *trip_count == 0) { // The loop never executes, so the value of the loop is the value of its // "init" operand. diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 619e87caa5b6d0f6ec3c3b1489b0d4f50ef29963..2e1571943e537f772ee7dcd95c80ba540445b76e 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -157,7 +157,7 @@ TEST_F(WhileLoopSimplifierTest, auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* true_op = while_op->while_body()->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); TF_ASSERT_OK(true_op->AddControlDependencyTo( while_op->while_body()->root_instruction())); ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); @@ -175,9 +175,11 @@ TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) { auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); + auto* token = while_body->AddInstruction(HloInstruction::CreateToken()); auto* send = while_body->AddInstruction(HloInstruction::CreateSend( while_body->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))), + token, /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateSendDone(send)); EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); @@ -190,8 +192,9 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); + auto* token = while_body->AddInstruction(HloInstruction::CreateToken()); auto* recv = while_body->AddInstruction( - HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), + HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateRecvDone(recv)); EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); @@ -208,8 +211,9 @@ TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) { auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); - while_body->AddInstruction( - HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); + auto token = while_body->AddInstruction(HloInstruction::CreateToken()); + while_body->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::MakeShape(F32, {1}), token, "config")); EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); } diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index 473eab2ea84eb8faf745cbe299bc80bcc1b62a35..1ef17b9d7d2e769aadf39f8a70f78200b88e9d2c 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/tuple_util.h" @@ -38,7 +39,7 @@ static StatusOr WidenWhileCondition( // the root instruction later. We later change the root instruction to // something more appropriate. builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); return narrow_condition->parent()->AddEmbeddedComputation(builder.Build()); }(); @@ -154,7 +155,7 @@ MakeCountedLoopConditionComputation(const Shape& loop_state_shape, {&loop_state_shape}, scalar_pred, "while_cond")); HloInstruction* trip_count_constant = cond_computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(trip_count))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(trip_count))); HloInstruction* param = cond_computation->parameter_instruction(0); TF_ASSIGN_OR_RETURN(HloInstruction * indvar, @@ -175,7 +176,7 @@ static StatusOr> MakeCountedLoopBodyComputation( CreateComputationWithSignature( {&loop_state_shape}, loop_state_shape, "while_body")); HloInstruction* one = body_computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); HloInstruction* param = body_computation->parameter_instruction(0); TF_ASSIGN_OR_RETURN(HloInstruction * indvar, MakeGetTupleElementHlo(param, 0)); @@ -203,7 +204,7 @@ static StatusOr MakeInitTupleFromInitValues( std::vector init_values_with_indvar; init_values_with_indvar.reserve(init_values.size() + 1); HloInstruction* zero = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); init_values_with_indvar.push_back(zero); c_copy(init_values, std::back_inserter(init_values_with_indvar)); return computation->AddInstruction( diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index d79d3297213e832306ea4726483b0f215df0f5d3..2ccb919acf9c4e7c59a1ebaf36f42a6781068b5e 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -179,7 +179,9 @@ body { cond { param.c = (s32[], s32[]) parameter(0) - ROOT condition = pred[] infeed() + token = token[] after-all() + infeed = (pred[], token[]) infeed(token) + ROOT condition = pred[] get-tuple-element(infeed), index=0 } ENTRY main { diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc index aa40b5cb264803097f52966d6f61f1f41b6b3017..83d696fe0915086c3c98b6d7cbdaeaeb4d9d0bdb 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -32,11 +32,12 @@ StatusOr ZeroSizedHloElimination::Run(HloModule* module) { for (HloComputation* comp : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { if (instruction->HasSideEffect() || - ShapeUtil::IsTuple(instruction->shape())) { + !ShapeUtil::IsArray(instruction->shape()) || + instruction->opcode() == HloOpcode::kConstant) { continue; } if (comp->IsRemovable(instruction) && - ShapeUtil::HasZeroElements(instruction->shape())) { + ShapeUtil::IsZeroElementArray(instruction->shape())) { TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( instruction, HloInstruction::CreateConstant( Literal::CreateFromShape(instruction->shape())))); diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index f5331280ee9f252aa5717baab88f2c203be5c372..b9ef18892d7aa859f6b0b505db4c004e4f5c5066 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -67,7 +67,16 @@ TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateParameter) { } TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateSideEffects) { - builder_.AddInstruction(HloInstruction::CreateSend(zero_sized_param_, 0)); + auto token = builder_.AddInstruction(HloInstruction::CreateToken()); + builder_.AddInstruction( + HloInstruction::CreateSend(zero_sized_param_, token, 0)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination()); + EXPECT_FALSE(changed); +} + +TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateConstant) { + builder_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination()); EXPECT_FALSE(changed); } diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index 7ee366b27a82bdbcb7a63a57ea80194db8ca7df4..caad31d6ce7ce35fa362ec364b0d7f1d95973715 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -67,6 +67,14 @@ void ShapeLayout::ResetLayout(const Layout& layout) { TF_CHECK_OK(ShapeUtil::ValidateShape(shape_)); } +void ShapeLayout::ResetLayout(const Layout& layout, + ShapeIndexView shape_index) { + CHECK(ShapeUtil::IsTuple(shape_)); + *ShapeUtil::GetMutableSubshape(&shape_, shape_index)->mutable_layout() = + layout; + TF_CHECK_OK(ShapeUtil::ValidateShape(shape_)); +} + bool ShapeLayout::operator==(const ShapeLayout& other) const { return ShapeUtil::Equal(shape_, other.shape_); } diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h index 36806da599cc9b27286e67c128bb7f496f29c105..214cf98854938414c23c5031f4114016140ae9a7 100644 --- a/tensorflow/compiler/xla/shape_layout.h +++ b/tensorflow/compiler/xla/shape_layout.h @@ -72,6 +72,10 @@ class ShapeLayout { // tuple. void ResetLayout(const Layout& layout); + // Resets the layout on the shape at the provided ShapeIndex to the provided + // layout. Shape must be a tuple. + void ResetLayout(const Layout& layout, ShapeIndexView shape_index); + // Returns a string representation of this object. string ToString() const { return ShapeUtil::HumanStringWithLayout(shape_); } diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 18e54d23c241ae0d4c61d8be79ff021dfb02a3e6..c74dd648addd70633edc2ec10a60879a00942716 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -44,10 +44,6 @@ struct ShapeTreeNode { // Data corresponding to this node. std::pair data; - // Children of this node, as indices into the container's nodes_ array. - std::vector children; - - // Tells whether this is a leaf node. bool is_leaf = true; explicit ShapeTreeNode(ShapeIndex index) @@ -56,6 +52,20 @@ struct ShapeTreeNode { : data(std::move(index), std::move(data)) {} }; +// Internal representation of an index table entry. +struct IndexTableEntry { + // Index of the node in the ShapeTreeNode vector. + uint32 index; + // Index of the first child in a IndexTableEntry vector. In the index + // table all children entries for a given node will be placed next to each + // other. This allows us to use a single field to index them. + uint32 children_start; +#ifndef NDEBUG + // Number of children, used for bounds checking. + uint32 children_count; +#endif +}; + } // namespace internal template @@ -84,6 +94,7 @@ template class ShapeTree { public: using Node = internal::ShapeTreeNode; + using Index = internal::IndexTableEntry; // Default constructor creates a tree with a nil shape (i.e. an empty tuple). ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} @@ -105,8 +116,8 @@ class ShapeTree { // Returns the data element associated with the array in the shape at the // given index (see ShapeUtil::GetSubshape for how indexes are defined). - const T& element(const ShapeIndex& index) const; - T* mutable_element(const ShapeIndex& index); + const T& element(ShapeIndexView index) const; + T* mutable_element(ShapeIndexView index); // Return the shape represented with this ShapeTree. const Shape& shape() const { return *shape_; } @@ -125,7 +136,7 @@ class ShapeTree { // Returns true if the node at the given index is a leaf node (an array // shape). - bool IsLeaf(const ShapeIndex& index) const { return Lookup(index)->is_leaf; } + bool IsLeaf(ShapeIndexView index) const { return Lookup(index)->is_leaf; } ShapeTree(const ShapeTree&) = default; ShapeTree& operator=(const ShapeTree&) = default; @@ -211,12 +222,12 @@ class ShapeTree { // Returns an iterator pointing to the given ShapeIndex. // REQUIRES: index must exist in the ShapeTree. - iterator find(const ShapeIndex& index) { + iterator find(ShapeIndexView index) { Node* element = Lookup(index); return iterator(&nodes_, typename std::vector::iterator(element), /*iterate_leaves_only=*/false); } - const_iterator find(const ShapeIndex& index) const { + const_iterator find(ShapeIndexView index) const { Node* element = Lookup(index); return iterator(&nodes_, typename std::vector::const_iterator(element), @@ -267,11 +278,12 @@ class ShapeTree { private: // Initialize node->children based on 'shape'. All children are assigned the // the given 'init_value'. - void InitChildren(const Shape& shape, const T& init_value, Node* node); + void InitChildren(const Shape& shape, const T& init_value, Node* node, + Index* index); // Initialize node->children based on 'shape'. All children have // default-constructed data values. - void InitChildren(const Shape& shape, Node* node); + void InitChildren(const Shape& shape, Node* node, Index* index); // Returns the number of subshapes, including interior nodes, in shape. int64 CountSubshapes(const Shape& shape); @@ -285,12 +297,15 @@ class ShapeTree { static Status ForEachMutableHelper(const Fn& func, std::vector* nodes); // Return the tree node at the given index. - Node* Lookup(const ShapeIndex& index); - const Node* Lookup(const ShapeIndex& index) const; + Node* Lookup(ShapeIndexView index); + const Node* Lookup(ShapeIndexView index) const; // The nodes in this shape tree. std::vector nodes_; + // Index table for node lookups. + std::vector index_table_; + // If we own our Shape, this field contains it, and shape_ is a pointer into // here. Otherwise if we don't own our shape, this is nullptr. std::shared_ptr shape_storage_; @@ -373,36 +388,74 @@ int64 ShapeTree::CountSubshapes(const Shape& shape) { template void ShapeTree::InitChildren(const Shape& shape, const T& init_value, - Node* node) { + Node* node, Index* index) { if (ShapeUtil::IsTuple(shape)) { const int64 size = ShapeUtil::TupleElementCount(shape); - node->children.reserve(size); +#ifndef NDEBUG + index->children_count = size; +#endif node->is_leaf = false; ShapeIndex shape_index = node->data.first; shape_index.push_back(0); + + // At the end of the index_table, reserve a continuous space to hold the + // children of current node. In order to enforce the invariant that all + // children of a given node are placed together, we need to do the + // reservation before we recurse into any of its children. + int64 children_start_position = index_table_.size(); + index_table_.resize(index_table_.size() + size); + for (int i = 0; i < size; ++i) { shape_index[shape_index.size() - 1] = i; - node->children.push_back(nodes_.size()); + index_table_[children_start_position + i].index = nodes_.size(); + // The first child of the node in the index table is placed at the end of + // the table. + index_table_[children_start_position + i].children_start = + index_table_.size(); nodes_.emplace_back(shape_index, init_value); - InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back()); + InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back(), + &index_table_[children_start_position + i]); } + } else { +#ifndef NDEBUG + index->children_count = 0; +#endif } } template -void ShapeTree::InitChildren(const Shape& shape, Node* node) { +void ShapeTree::InitChildren(const Shape& shape, Node* node, Index* index) { if (ShapeUtil::IsTuple(shape)) { const int64 size = ShapeUtil::TupleElementCount(shape); - node->children.reserve(size); +#ifndef NDEBUG + index->children_count = size; +#endif node->is_leaf = false; ShapeIndex shape_index = node->data.first; shape_index.push_back(0); + + // At the end of the index_table, reserve a continuous space to hold the + // children of current node. In order to enforce the invariant that all + // children of a given node are placed together, we need to do the + // reservation before we recurse into any of its children. + int64 children_start_position = index_table_.size(); + index_table_.resize(index_table_.size() + size); + for (int i = 0; i < size; ++i) { shape_index[shape_index.size() - 1] = i; - node->children.push_back(nodes_.size()); + index_table_[children_start_position + i].index = nodes_.size(); + // The first child of the node in the index table is placed at the end of + // the table. + index_table_[children_start_position + i].children_start = + index_table_.size(); nodes_.emplace_back(shape_index); - InitChildren(shape.tuple_shapes(i), &nodes_.back()); + InitChildren(shape.tuple_shapes(i), &nodes_.back(), + &index_table_[children_start_position + i]); } + } else { +#ifndef NDEBUG + index->children_count = 0; +#endif } } @@ -413,24 +466,36 @@ ShapeTree::ShapeTree(Shape shape) // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. LayoutUtil::ClearLayout(shape_storage_.get()); - nodes_.reserve(CountSubshapes(*shape_)); + const int64 count = CountSubshapes(*shape_); + nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}); - InitChildren(*shape_, &nodes_[0]); + + index_table_.reserve(count); + index_table_.emplace_back(Index{0, 1}); + InitChildren(*shape_, &nodes_[0], &index_table_[0]); } template ShapeTree::ShapeTree(const Shape* shape) : shape_(shape) { - nodes_.reserve(CountSubshapes(*shape_)); + const int64 count = CountSubshapes(*shape_); + nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}); - InitChildren(*shape_, &nodes_[0]); + + index_table_.reserve(count); + index_table_.emplace_back(Index{0, 1}); + InitChildren(*shape_, &nodes_[0], &index_table_[0]); } template ShapeTree::ShapeTree(const std::shared_ptr& shape) : shape_storage_(shape), shape_(shape_storage_.get()) { - nodes_.reserve(CountSubshapes(*shape_)); + const int64 count = CountSubshapes(*shape_); + nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}); - InitChildren(*shape_, &nodes_[0]); + + index_table_.reserve(count); + index_table_.emplace_back(Index{0, 1}); + InitChildren(*shape_, &nodes_[0], &index_table_[0]); } template @@ -440,52 +505,67 @@ ShapeTree::ShapeTree(Shape shape, const T& init_value) // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. LayoutUtil::ClearLayout(shape_storage_.get()); - nodes_.reserve(CountSubshapes(*shape_)); + const int64 count = CountSubshapes(*shape_); + nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}, init_value); - InitChildren(*shape_, init_value, &nodes_[0]); + + index_table_.reserve(count); + index_table_.emplace_back(Index{0, 1}); + InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]); } template ShapeTree::ShapeTree(const Shape* shape, const T& init_value) : shape_(shape) { - nodes_.reserve(CountSubshapes(*shape_)); + const int64 count = CountSubshapes(*shape_); + nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}, init_value); - InitChildren(*shape_, init_value, &nodes_[0]); + + index_table_.reserve(count); + index_table_.emplace_back(Index{0, 1}); + InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]); } template ShapeTree::ShapeTree(const std::shared_ptr& shape, const T& init_value) : shape_storage_(shape), shape_(shape_storage_.get()) { - nodes_.reserve(CountSubshapes(*shape_)); + const int64 count = CountSubshapes(*shape_); + nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}, init_value); - InitChildren(*shape_, init_value, &nodes_[0]); + + index_table_.reserve(count); + index_table_.emplace_back(Index{0, 1}); + InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]); } template -const T& ShapeTree::element(const ShapeIndex& index) const { +const T& ShapeTree::element(ShapeIndexView index) const { return Lookup(index)->data.second; } template -T* ShapeTree::mutable_element(const ShapeIndex& index) { +T* ShapeTree::mutable_element(ShapeIndexView index) { return &Lookup(index)->data.second; } template -internal::ShapeTreeNode* ShapeTree::Lookup(const ShapeIndex& index) { - Node* node = &nodes_[0]; +internal::ShapeTreeNode* ShapeTree::Lookup(ShapeIndexView index) { + Index* iter = &index_table_[0]; for (const int64 i : index) { CHECK_GE(i, 0); - CHECK_LT(i, node->children.size()); - node = &nodes_[node->children[i]]; +#ifndef NDEBUG + CHECK_LT(i, iter->children_count); +#endif + iter = &index_table_[iter->children_start + i]; } - return node; + + return &nodes_[iter->index]; } template const internal::ShapeTreeNode* ShapeTree::Lookup( - const ShapeIndex& index) const { + ShapeIndexView index) const { return const_cast(this)->Lookup(index); } diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index 51de82e95746281ed6e587b545dc933b48ce1ad4..c4c958be4a18f23b8e34f9e619e447c6bf4334b5 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -172,7 +172,7 @@ TEST_F(ShapeTreeTest, TupleShape) { // Write zero to all data elements. shape_tree.ForEachMutableElement( - [&sum](const ShapeIndex& /*index*/, int* data) { *data = 0; }); + [](const ShapeIndex& /*index*/, int* data) { *data = 0; }); EXPECT_EQ(0, shape_tree.element({})); EXPECT_EQ(0, shape_tree.element({0})); EXPECT_EQ(0, shape_tree.element({1})); @@ -227,14 +227,16 @@ TEST_F(ShapeTreeTest, NestedTupleShape) { TEST_F(ShapeTreeTest, InvalidIndexingTuple) { ShapeTree shape_tree{tuple_shape_}; - +#ifndef NDEBUG EXPECT_DEATH(shape_tree.element({4}), ""); +#endif } TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) { ShapeTree shape_tree{nested_tuple_shape_}; - +#ifndef NDEBUG EXPECT_DEATH(shape_tree.element({0, 0}), ""); +#endif } TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) { @@ -602,12 +604,15 @@ void BM_Iterate(int iters, int depth, int fan_out) { } } -BENCHMARK(BM_Construct)->ArgPair(2, 8); -BENCHMARK(BM_ConstructUnowned)->ArgPair(2, 8); -BENCHMARK(BM_Copy)->ArgPair(2, 8); -BENCHMARK(BM_Move)->ArgPair(2, 8); -BENCHMARK(BM_ForEach)->ArgPair(2, 8); -BENCHMARK(BM_Iterate)->ArgPair(2, 8); +#define BENCHMARK_WITH_ARGS(name) \ + BENCHMARK(name)->ArgPair(2, 8)->ArgPair(1, 1000) + +BENCHMARK_WITH_ARGS(BM_Construct); +BENCHMARK_WITH_ARGS(BM_ConstructUnowned); +BENCHMARK_WITH_ARGS(BM_Copy); +BENCHMARK_WITH_ARGS(BM_Move); +BENCHMARK_WITH_ARGS(BM_ForEach); +BENCHMARK_WITH_ARGS(BM_Iterate); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 5db66599324913b9214d7623597060950246fb03..34869cc5078699603c006387161fddd4fee4a9f8 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -24,9 +24,11 @@ limitations under the License. #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/overflow_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/iterator_range.h" @@ -44,28 +46,14 @@ namespace xla { using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; -string ShapeIndex::ToString() const { - return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}"); -} +string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); } string ShapeIndexView::ToString() const { - return StrCat("{", - tensorflow::str_util::Join( - tensorflow::gtl::make_range(begin_, end_), ","), - "}"); + return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}"); } bool ShapeIndexView::operator==(const ShapeIndexView& other) const { - if (size() != other.size()) { - return false; - } - for (auto it = begin(), other_it = other.begin(); it != end(); - ++it, ++other_it) { - if (*it != *other_it) { - return false; - } - } - return true; + return indices_ == other.indices_; } bool ShapeIndexView::operator!=(const ShapeIndexView& other) const { @@ -93,8 +81,11 @@ bool IsArrayPrimitiveType(PrimitiveType primitive_type) { // Recursive helper for comparing the equality of two shapes. Returns true if // the shapes are the same. If compare_layouts is true, then layouts must also // match. -bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { - if (!ShapeUtil::SameElementType(lhs, rhs)) { +bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, + bool ignore_fp_precision) { + if ((ignore_fp_precision && + !ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) || + (!ignore_fp_precision && !ShapeUtil::SameElementType(lhs, rhs))) { VLOG(3) << "CompareShapes: lhs element type != rhs element type"; return false; } @@ -102,7 +93,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { if (ShapeUtil::IsTuple(lhs)) { return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), [=](const Shape& l, const Shape& r) { - return CompareShapes(l, r, compare_layouts); + return CompareShapes(l, r, compare_layouts, + ignore_fp_precision); }); } else if (!ShapeUtil::IsArray(lhs)) { // Non-tuple, non-array tupes such as opaque and token types are trivially @@ -169,7 +161,8 @@ StatusOr MakeShapeWithLayoutInternal( } // namespace /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) { - bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true); + bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true, + /*ignore_fp_precision=*/false); if (!equal && VLOG_IS_ON(3)) { VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); @@ -178,6 +171,18 @@ StatusOr MakeShapeWithLayoutInternal( return equal; } +/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs, + const Shape& rhs) { + bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true, + /*ignore_fp_precision=*/true); + if (!equal && VLOG_IS_ON(3)) { + VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = " + << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString(); + } + + return equal; +} + /* static */ int64 ShapeUtil::Rank(const Shape& shape) { CHECK(ShapeUtil::IsArray(shape)) << "Non-arrays do not have a rank, shape: " << shape; @@ -263,6 +268,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( tensorflow::gtl::ArraySlice shapes) { Shape result; result.set_element_type(TUPLE); + result.mutable_tuple_shapes()->Reserve(shapes.size()); for (const auto& shape : shapes) { AppendShapeToTuple(shape, &result); } @@ -363,7 +369,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::IsNil(const Shape& shape) { - return IsTuple(shape) ? IsEmptyTuple(shape) : HasZeroElements(shape); + return IsEmptyTuple(shape); } /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) { @@ -379,6 +385,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return shape.tuple_shapes(index); } +/* static */ int64 ShapeUtil::SubshapeCount(const Shape& shape) { + int64 n = 0; + ForEachSubshape(shape, [&](const Shape& literal_subshape, + const ShapeIndex& index) { ++n; }); + return n; +} + /* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start, int64 limit) { TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple)); @@ -413,15 +426,26 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( std::multiplies()); } -/* static */ bool ShapeUtil::HasZeroElements(const Shape& shape) { - return ElementsIn(shape) == 0; +/* static */ int64 ShapeUtil::ElementsInRecursive(const Shape& shape) { + CHECK(IsArray(shape) || IsTuple(shape)); + if (IsArray(shape)) { + return ElementsIn(shape); + } + int64 count = 0; + for (const Shape& element_shape : shape.tuple_shapes()) { + count += ElementsInRecursive(element_shape); + } + return count; +} + +/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) { + return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0; } /* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) { return shape.element_type() == F32 && Rank(shape) == 0; } - namespace { // Class to memoize the computation of @@ -554,12 +578,11 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { // tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so // we convert in to the RE2-consumable type and then consume the corresponding // amount from our StringPiece type. + static LazyRE2 shape_pattern = { + "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?"}; tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); - if (RE2::Consume( - &s_consumable, - "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?", - &element_type_string, &dimensions_string, &format_string, - &layout_string)) { + if (RE2::Consume(&s_consumable, *shape_pattern, &element_type_string, + &dimensions_string, &format_string, &layout_string)) { size_t consumed = s->size() - s_consumable.size(); s->remove_prefix(consumed); auto string_to_int64 = [&s](const string& input) -> StatusOr { @@ -573,8 +596,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { }; auto comma_list_to_int64s = - [&s, - string_to_int64](const string& input) -> StatusOr> { + [string_to_int64](const string& input) -> StatusOr> { std::vector results; for (const string& piece : tensorflow::str_util::Split(input, ',')) { TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece)); @@ -645,15 +667,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - if (IsArray(lhs)) { - return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs); - } else if (lhs.element_type() == TUPLE) { - return rhs.element_type() == TUPLE && - ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible); - } else { - // Opaque, token, etc types are vacuously compatible. - return true; - } + return CompareShapes(lhs, rhs, /*compare_layouts=*/false, + /*ignore_fp_precision=*/false); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, @@ -666,7 +681,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { CompatibleIgnoringElementType); } else { // Opaque, token, etc types are vacuously compatible. - return true; + return lhs.element_type() == rhs.element_type(); } } @@ -681,7 +696,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { CompatibleIgnoringFpPrecision); } else { // Opaque, token, etc types are vacuously compatible. - return true; + return lhs.element_type() == rhs.element_type(); } } @@ -776,7 +791,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (LayoutUtil::IsSparseArray(shape)) { allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout()); } else { - CHECK(LayoutUtil::IsDenseArray(shape)); + CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString(); tensorflow::gtl::ArraySlice padded_dimensions = LayoutUtil::PaddedDimensions(shape); if (!padded_dimensions.empty()) { @@ -855,6 +870,71 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { } } + TF_RETURN_IF_ERROR(ValidateShapeSize(shape)); + return Status::OK(); +} + +/* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) { + VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape); + + if (!IsArray(shape)) { + return Status::OK(); + } + + int64 shape_size = [&shape]() { + if (LayoutUtil::IsSparseArray(shape)) { + int64 max_sparse_elements = LayoutUtil::MaxSparseElements(shape.layout()); + if (max_sparse_elements < 0) { + return max_sparse_elements; + } + int64 sparse_elements_size = MultiplyWithoutOverflow( + max_sparse_elements, ByteSizeOfPrimitiveType(shape.element_type())); + if (sparse_elements_size < 0) { + return sparse_elements_size; + } + int64 sparse_indices_size = + MultiplyWithoutOverflow(max_sparse_elements, ShapeUtil::Rank(shape)); + if (sparse_indices_size < 0) { + return sparse_indices_size; + } + sparse_indices_size = + MultiplyWithoutOverflow(sparse_indices_size, sizeof(int64)); + if (sparse_indices_size < 0) { + return sparse_indices_size; + } + // At this point, both sparse_indices_size and sparse_elements_size are + // non-negative, so we can easily check if adding them wraps. + if (static_cast(sparse_elements_size) + + static_cast(sparse_indices_size) > + INT64_MAX) { + return static_cast(-1); + } + } + + // This is intentionally unconditional: even if the shape is sparse, we want + // to verify the densified version has a reasonable size. + int64 dense_shape_size = 1; + if (shape.dimensions().empty()) { + return dense_shape_size; + } + + for (int64 dim : shape.dimensions()) { + dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, dim); + if (dense_shape_size < 0) { + return dense_shape_size; + } + } + dense_shape_size = MultiplyWithoutOverflow( + dense_shape_size, ByteSizeOfPrimitiveType(shape.element_type())); + return dense_shape_size; + }(); + + if (shape_size < 0) { + return InvalidArgument("Shape %s size may overflow int64.", + ShapeUtil::HumanString(shape).c_str()); + } + + VLOG(3) << "Shape size is valid: " << shape_size; return Status::OK(); } @@ -903,6 +983,21 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return *return_shape; } +/* static */ StatusOr ShapeUtil::TryGetSubshape( + const Shape& shape, ShapeIndexView index) { + const Shape* return_shape = &shape; + for (auto i : index) { + if (!IsTuple(*return_shape) || i < 0 || + i >= return_shape->tuple_shapes_size()) { + return InvalidArgument( + "Shape index %s not a valid subshape index for tuple with shape %s", + index.ToString().c_str(), shape.DebugString().c_str()); + } + return_shape = &return_shape->tuple_shapes(i); + } + return return_shape; +} + /* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape, ShapeIndexView index) { Shape* return_shape = shape; @@ -939,6 +1034,11 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { return leaves; } +/* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) { + CHECK(ShapeUtil::IsArray(shape)); + return ArrayContains(AsInt64Slice(shape.dimensions()), 1); +} + namespace { // Helper for ForEachSubshape which visits the subshapes of the given shape in @@ -1022,12 +1122,41 @@ Status ForEachMutableSubshapeHelper( for (auto dim : Permute(permutation, shape.dimensions())) { new_shape.add_dimensions(dim); } + + // If `shape` has a layout, by contract we choose a new layout such that the + // transpose defined by this permutation is a bitcast. + // + // Some formalism helps to understand the correct way to do this. We're going + // to do algebra in the group of permutations of the dimensions of `shape`. + // + // Since the order of `shape`'s dimensions is not permuted relative to itself, + // `shape`'s list of dimensions is isomorphic to the identity I. + // + // Let `shape`'s layout be L. A layout is a permutation which maps a + // minor-to-major physical layout to the order of a shape's logical dims. + // Therefore inverse of a layout maps from logical to physical dims, and so + // the physical layout of I is simply L'.I = L', where L' is the inverse of L. + // + // Let the argument `permutation` be P. This is a permutation over `shape`'s + // dimensions, so our return value will be a shape with dims P.I = P. Our + // goal is to construct a layout permutation L* that we can apply to P such + // that that the physical dimension ordering of the returned shape is the same + // as that of the original shape, namely L'. + // + // Our returned shape has dims P and layout L*, so its in-memory layout is + // L*'.P. Setting this equal to L' and solving for L*, we get: + // + // L*'.P = L' => + // L*' = L'P' => + // L* = P.L + // if (shape.has_layout()) { CHECK(LayoutUtil::IsDenseArray(shape)); Layout* new_layout = new_shape.mutable_layout(); new_layout->set_format(DENSE); new_layout->clear_minor_to_major(); - for (auto index : Permute(permutation, shape.layout().minor_to_major())) { + for (auto index : ComposePermutations( + permutation, AsInt64Slice(shape.layout().minor_to_major()))) { new_layout->add_minor_to_major(index); } if (shape.layout().padded_dimensions_size() > 0) { @@ -1037,6 +1166,13 @@ Status ForEachMutableSubshapeHelper( new_layout->add_padded_dimensions(dim); } } + // The permutation accepted by TransposeIsBitcast is the inverse of the + // permutation here. + CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation))) + << "shape=" << HumanStringWithLayout(shape) + << ", new_shape=" << HumanStringWithLayout(new_shape) + << ", permutation={" << tensorflow::str_util::Join(permutation, ",") + << "}"; } return new_shape; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index ae2d17d6bbbfed96e1da192253838ae5e9a67e17..d6f17fc965d24bbbbd083b8dd0ec11a59e49ed4e 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" @@ -73,10 +74,12 @@ class ShapeIndex { // push_front is O(n^2), but shapes don't usually have a ton of dimensions. void push_front(int64 value) { indices_.insert(indices_.begin(), value); } - std::vector::const_iterator begin() const { return indices_.begin(); } - std::vector::const_iterator end() const { return indices_.end(); } - std::vector::iterator begin() { return indices_.begin(); } - std::vector::iterator end() { return indices_.end(); } + using container_type = tensorflow::gtl::InlinedVector; + + container_type::const_iterator begin() const { return indices_.begin(); } + container_type::const_iterator end() const { return indices_.end(); } + container_type::iterator begin() { return indices_.begin(); } + container_type::iterator end() { return indices_.end(); } const int64* data() const { return indices_.data(); } @@ -97,7 +100,7 @@ class ShapeIndex { string ToString() const; private: - std::vector indices_; + container_type indices_; }; // A view into a ShapeIndex as above, with the cheap/easy ability to consume the @@ -110,31 +113,33 @@ class ShapeIndex { class ShapeIndexView { public: ShapeIndexView(const ShapeIndex& shape_index, int64 offset = 0) - : ShapeIndexView(shape_index.data() + offset, - shape_index.data() + shape_index.size()) { + : indices_(shape_index.data() + offset, shape_index.size() - offset) { CHECK_LE(offset, shape_index.size()); } - ShapeIndexView(std::initializer_list indices) - : ShapeIndexView(indices.begin(), indices.end()) {} + ShapeIndexView(std::initializer_list indices) : indices_(indices) {} ShapeIndexView(const ShapeIndexView& other) = default; using iterator = const int64*; - iterator begin() const { return begin_; } - iterator end() const { return end_; } - int64 size() const { return std::distance(begin_, end_); } - bool empty() const { return begin_ == end_; } + iterator begin() const { return indices_.begin(); } + iterator end() const { return indices_.end(); } + int64 size() const { return indices_.size(); } + bool empty() const { return indices_.empty(); } int64 front() const { CHECK(!empty()); - return *begin_; + return indices_.front(); } ShapeIndexView ConsumeFront() const { - CHECK(!empty()); - auto new_begin = begin_; - ++new_begin; - return ShapeIndexView(new_begin, end_); + ShapeIndexView result = *this; + result.indices_.pop_front(); + return result; } - ShapeIndex ToShapeIndex() const { return ShapeIndex(begin_, end_); } + ShapeIndexView ConsumeBack() const { + ShapeIndexView result = *this; + result.indices_.pop_back(); + return result; + } + ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); } bool operator==(const ShapeIndexView& other) const; bool operator!=(const ShapeIndexView& other) const; @@ -142,10 +147,7 @@ class ShapeIndexView { string ToString() const; private: - ShapeIndexView(iterator begin, iterator end) : begin_(begin), end_(end) {} - - iterator begin_; - iterator end_; + tensorflow::gtl::ArraySlice indices_; }; std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index); @@ -175,8 +177,11 @@ class ShapeUtil { // Precondition: IsArray(shape) static int64 ElementsIn(const Shape& shape); - // Returns true if 'shape' has zero elements. - static bool HasZeroElements(const Shape& shape); + // As ElementsIn(), but recurses through tuples. + static int64 ElementsInRecursive(const Shape& shape); + + // Returns true if 'shape' is an array with zero elements. + static bool IsZeroElementArray(const Shape& shape); // Returns the number of bytes required for an allocation of shape. The // |pointer_size| parameter is used for calculating the size of tuple @@ -277,6 +282,9 @@ class ShapeUtil { // Returns whether the lhs and rhs shapes are identical protobufs. static bool Equal(const Shape& lhs, const Shape& rhs); + // As Equal, but allow one of lhs and rhs to be F16 while the other is F32. + static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); + // Returns the rank (number of dimensions) of the given shape. // Precondition: !IsTuple(shape) static int64 Rank(const Shape& shape); @@ -336,7 +344,7 @@ class ShapeUtil { // Appends a major dimension to the shape with the given bound. static void AppendMajorDimension(int bound, Shape* shape); - // Returns an empty tuple shape. Can be used to indicate side-effects. + // Returns an empty tuple shape. Can be used as a sentinel Shape value. static Shape MakeNil() { return MakeTupleShape({}); } // Checks whether the shape is initialized. @@ -446,7 +454,7 @@ class ShapeUtil { // Returns true if shape is an empty tuple. static bool IsEmptyTuple(const Shape& shape); - // Returns true if shape is an empty tuple, or is an array with no elements. + // Returns true if shape is the nil shape (an empty tuple). static bool IsNil(const Shape& shape); // Returns the number of elements in the given tuple shape. @@ -457,6 +465,9 @@ class ShapeUtil { // Precondition: IsTuple(shape) && TupleElementCount(shape) > index static const Shape& GetTupleElementShape(const Shape& shape, int64 index); + // Returns the number of elements, recursively, in the given shape. + static int64 SubshapeCount(const Shape& shape); + // Slices tuple elements in the range [start, limit) and returns a new tuple // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32). static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit); @@ -476,8 +487,11 @@ class ShapeUtil { static bool IndexIsValid(const Shape& shape, ShapeIndexView index); // GetSubshape and GetMutableSubshape return a particular nested Shape within - // the given Shape argument. + // the given Shape argument. The non-Try variants check fail if index is + // invalid. static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index); + static StatusOr TryGetSubshape(const Shape& shape, + ShapeIndexView index); static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index); // Returns whether the given index in the given shape is a leaf element of the @@ -513,8 +527,18 @@ class ShapeUtil { static Status ForEachMutableSubshapeWithStatus( Shape* shape, const MutatingStatusVisitorFunction& func); + // Returns true if `shape` (which must be an array) with degenerate dimensions + // (dimensions with bound 1). + static bool HasDegenerateDimensions(const Shape& shape); + // Permutes the dimensions by the given permutation, so - // return_value.dimensions[permutation[i]] = argument.dimensions[i] + // return_value.dimensions[permutation[i]] = argument.dimensions[i]. + // + // Postcondition: For any valid permutation, + // + // !HasLayout(shape) || + // TransposeIsBitcast(shape, PermuteDimensions(permutation, shape), + // InversePermutation(permutation)). static Shape PermuteDimensions(tensorflow::gtl::ArraySlice permutation, const Shape& shape); @@ -686,6 +710,10 @@ class ShapeUtil { static size_t Hash(const Shape& shape); private: + // Validates the shape size is sane. This makes sure it's safe to do + // calculations in int64 without overflowing. + static Status ValidateShapeSize(const Shape& shape); + // Validates all of the non-layout properties of the shape -- this is a helper // used by both the layout-optional and layout-required public method. static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape); @@ -697,7 +725,7 @@ class ShapeUtil { tensorflow::gtl::ArraySlice incr, const FnType& visitor_function, bool parallel = false) { - if (ShapeUtil::HasZeroElements(shape)) { + if (ShapeUtil::IsZeroElementArray(shape)) { return Status::OK(); } CHECK_EQ(Rank(shape), base.size()); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 0ff514564bdb27b7afa4cf99b0d727f2c029a5ae..e5dd62ae9a3dd9b961a7ae03a99c19220dbd43e7 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" +#include #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" @@ -22,12 +23,23 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { using ::testing::ElementsAre; +TEST(ShapeUtilTest, ShapeIndexViewTest) { + ShapeIndex index = {1, 2, 3, 4}; + ShapeIndexView index_view(index, 1); + EXPECT_EQ(3, index_view.size()); + EXPECT_EQ(ShapeIndexView({2, 3, 4}), index_view); + EXPECT_EQ(ShapeIndexView({3, 4}), index_view.ConsumeFront()); + EXPECT_EQ(ShapeIndexView({2, 3}), index_view.ConsumeBack()); +} + TEST(ShapeUtilTest, GetDimensionHelperCanNegativeIndex) { Shape matrix = ShapeUtil::MakeShape(F32, {2, 3}); EXPECT_EQ(3, ShapeUtil::GetDimension(matrix, -1)); @@ -172,6 +184,41 @@ TEST(ShapeUtilTest, CompatibleIdenticalShapes) { ASSERT_TRUE(ShapeUtil::Compatible(shape1, shape2)); } +TEST(ShapeUtilTest, TokenCompatibility) { + EXPECT_TRUE(ShapeUtil::Compatible(ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTokenShape())); + EXPECT_FALSE(ShapeUtil::Compatible(ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShape(F32, {}))); + EXPECT_FALSE(ShapeUtil::Compatible(ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeTokenShape())); + EXPECT_TRUE(ShapeUtil::Compatible( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}))); +} + +TEST(ShapeUtilTest, TokensEqualShapes) { + EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTokenShape())); + EXPECT_FALSE(ShapeUtil::Equal(ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShape(F32, {}))); + EXPECT_FALSE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeTokenShape())); + EXPECT_TRUE(ShapeUtil::Equal( + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1})}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1})}))); + EXPECT_FALSE(ShapeUtil::Equal( + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1})}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {1, 0})}))); +} + TEST(ShapeUtilTest, CompatibleNotIdenticalShapes) { Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2}); auto layout_1 = shape_1.mutable_layout(); @@ -207,6 +254,24 @@ TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) { EXPECT_FALSE(ShapeUtil::Compatible(shape_1, shape_2)); } +TEST(ShapeUtilTest, EqualIgnoringFpPrecision) { + EXPECT_TRUE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {4, 3}, {0, 1}))); +} + +TEST(ShapeUtilTest, UnequalIgnoringFpPrecision) { + EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {0, 1}))); + EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {3, 4}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {1, 0}))); + EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision( + ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}), + ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1}))); +} + TEST(ShapeUtilTest, CompatibleTuples) { Shape tuple1 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})}); @@ -269,6 +334,17 @@ TEST(ShapeUtilTest, IncompatibleScalarVsTuple) { EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape2, shape1)); } +TEST(ShapeUtilTest, OpaqueVsArray) { + Shape shape1 = ShapeUtil::MakeShape(F32, {5, 7}); + Shape shape2 = ShapeUtil::MakeOpaqueShape(); + EXPECT_FALSE(ShapeUtil::Compatible(shape1, shape2)); + EXPECT_FALSE(ShapeUtil::Compatible(shape2, shape1)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape2, shape1)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(shape1, shape2)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(shape2, shape1)); +} + TEST(ShapeUtilTest, CompareShapesWithPaddedDimensionsMismatch) { Shape shape1 = ShapeUtil::MakeShape(F32, {20, 30}); shape1.mutable_layout()->add_padded_dimensions(10); @@ -329,6 +405,16 @@ TEST(ShapeUtilTest, ByteSizeOfWithPadding) { EXPECT_EQ(15 * 21 * 4, ShapeUtil::ByteSizeOf(shape)); } +TEST(ShapeUtilTest, NilShape) { + EXPECT_TRUE(ShapeUtil::IsNil(ShapeUtil::MakeNil())); + EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {1, 2, 3}))); + EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {0, 1}))); + EXPECT_FALSE(ShapeUtil::IsNil( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {})}))); + EXPECT_FALSE(ShapeUtil::IsNil( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {0})}))); +} + TEST(ShapeUtilTest, NestedTuple) { EXPECT_FALSE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape({}))); EXPECT_FALSE(ShapeUtil::IsNestedTuple( @@ -359,25 +445,30 @@ TEST(ShapeUtilTest, ElementsIn) { EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17}))); } -TEST(ShapeUtilTest, HasZeroElements) { - EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {}))); - EXPECT_EQ(true, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0}))); - EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 1}))); - EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2, 1}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 5}))); - EXPECT_EQ(true, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 0, 5}))); - EXPECT_EQ(true, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0, 3, 0}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 3, 5}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {13, 17}))); +TEST(ShapeUtilTest, IsZeroElementArray) { + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {}))); + EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0}))); + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1, 1}))); + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {2}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {2, 1}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {3, 5}))); + EXPECT_TRUE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {3, 0, 5}))); + EXPECT_TRUE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0, 3, 0}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1, 3, 5}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {13, 17}))); + + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeNil())); + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeTupleShape({}))); + EXPECT_FALSE(ShapeUtil::IsZeroElementArray( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {0, 3, 0})}))); } TEST(ShapeUtilTest, SameDimensions) { @@ -742,6 +833,39 @@ TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) { ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1}))); } +TEST(ShapeUtilTest, HasDegenerateDimensions) { + EXPECT_TRUE( + ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 1, 2}))); + EXPECT_TRUE( + ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 1, 1}))); + EXPECT_FALSE( + ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 3, 5}))); + EXPECT_FALSE( + ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 0, 5}))); +} + +TEST(ShapeUtilTest, PermuteDimensionsLayout) { + std::vector layout(3); + std::iota(layout.begin(), layout.end(), 0); + do { + Shape s = ShapeUtil::MakeShapeWithLayout(F32, {10, 100, 1000}, layout); + SCOPED_TRACE(tensorflow::strings::StrCat("s=", ShapeUtil::HumanString(s))); + + std::vector permutation(3); + std::iota(permutation.begin(), permutation.end(), 0); + do { + SCOPED_TRACE(tensorflow::strings::StrCat( + "permutation=", tensorflow::str_util::Join(permutation, ","))); + + // TransposeIsBitcast takes the inverse of the permutation that + // PermuteDimensions takes. + EXPECT_TRUE(ShapeUtil::TransposeIsBitcast( + s, ShapeUtil::PermuteDimensions(permutation, s), + InversePermutation(permutation))); + } while (std::next_permutation(permutation.begin(), permutation.end())); + } while (std::next_permutation(layout.begin(), layout.end())); +} + TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast( ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}), diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h index 0e1387c93938fa520562fcd63ac107a82b089a51..a32e2ad9851b0b5644f7e6f0f9ead6c438934c07 100644 --- a/tensorflow/compiler/xla/statusor.h +++ b/tensorflow/compiler/xla/statusor.h @@ -12,297 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -// StatusOr is the union of a Status object and a T object. StatusOr models -// the concept of an object that is either a value, or an error Status -// explaining why such a value is not present. To this end, StatusOr does not -// allow its Status value to be Status::OK. -// -// The primary use-case for StatusOr is as the return value of a -// function which may fail. -// -// Example client usage for a StatusOr, where T is not a pointer: -// -// StatusOr result = DoBigCalculationThatCouldFail(); -// if (result.ok()) { -// float answer = result.ValueOrDie(); -// printf("Big calculation yielded: %f", answer); -// } else { -// LOG(ERROR) << result.status(); -// } -// -// Example client usage for a StatusOr: -// -// StatusOr result = FooFactory::MakeNewFoo(arg); -// if (result.ok()) { -// std::unique_ptr foo(result.ValueOrDie()); -// foo->DoSomethingCool(); -// } else { -// LOG(ERROR) << result.status(); -// } -// -// Example client usage for a StatusOr>: -// -// StatusOr> result = FooFactory::MakeNewFoo(arg); -// if (result.ok()) { -// std::unique_ptr foo = std::move(result.ValueOrDie()); -// foo->DoSomethingCool(); -// } else { -// LOG(ERROR) << result.status(); -// } -// -// Example factory implementation returning StatusOr: -// -// StatusOr FooFactory::MakeNewFoo(int arg) { -// if (arg <= 0) { -// return tensorflow::InvalidArgument("Arg must be positive"); -// } else { -// return new Foo(arg); -// } -// } -// -// Note that the assignment operators require that destroying the currently -// stored value cannot invalidate the argument; in other words, the argument -// cannot be an alias for the current value, or anything owned by the current -// value. #ifndef TENSORFLOW_COMPILER_XLA_STATUSOR_H_ #define TENSORFLOW_COMPILER_XLA_STATUSOR_H_ #include "tensorflow/compiler/xla/status.h" -#include "tensorflow/compiler/xla/statusor_internals.h" -#include "tensorflow/core/platform/macros.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace xla { -#if defined(__clang__) -// Only clang supports warn_unused_result as a type annotation. -template -class TF_MUST_USE_RESULT StatusOr; -#endif - -template -class StatusOr : private internal_statusor::StatusOrData, - private internal_statusor::TraitsBase< - std::is_copy_constructible::value, - std::is_move_constructible::value> { - template - friend class StatusOr; - - typedef internal_statusor::StatusOrData Base; - - public: - typedef T element_type; - - // Constructs a new StatusOr with Status::UNKNOWN status. This is marked - // 'explicit' to try to catch cases like 'return {};', where people think - // StatusOr> will be initialized with an empty vector, - // instead of a Status::UNKNOWN status. - explicit StatusOr(); - - // StatusOr will be copy constructible/assignable if T is copy - // constructible. - StatusOr(const StatusOr&) = default; - StatusOr& operator=(const StatusOr&) = default; - - // StatusOr will be move constructible/assignable if T is move - // constructible. - StatusOr(StatusOr&&) = default; - StatusOr& operator=(StatusOr&&) = default; - - // Conversion copy/move constructor, T must be convertible from U. - template ::value>::type* = nullptr> - StatusOr(const StatusOr& other); - template ::value>::type* = nullptr> - StatusOr(StatusOr&& other); - - // Conversion copy/move assignment operator, T must be convertible from U. - template ::value>::type* = nullptr> - StatusOr& operator=(const StatusOr& other); - template ::value>::type* = nullptr> - StatusOr& operator=(StatusOr&& other); - - // Constructs a new StatusOr with the given value. After calling this - // constructor, calls to ValueOrDie() will succeed, and calls to status() will - // return OK. - // - // NOTE: Not explicit - we want to use StatusOr as a return type - // so it is convenient and sensible to be able to do 'return T()' - // when the return type is StatusOr. - // - // REQUIRES: T is copy constructible. - StatusOr(const T& value); - - // Constructs a new StatusOr with the given non-ok status. After calling - // this constructor, calls to ValueOrDie() will CHECK-fail. - // - // NOTE: Not explicit - we want to use StatusOr as a return - // value, so it is convenient and sensible to be able to do 'return - // Status()' when the return type is StatusOr. - // - // REQUIRES: !status.ok(). This requirement is DCHECKed. - // In optimized builds, passing Status::OK() here will have the effect - // of passing tensorflow::error::INTERNAL as a fallback. - StatusOr(const Status& status); - StatusOr& operator=(const Status& status); - - // TODO(b/62186997): Add operator=(T) overloads. - - // Similar to the `const T&` overload. - // - // REQUIRES: T is move constructible. - StatusOr(T&& value); - - // RValue versions of the operations declared above. - StatusOr(Status&& status); - StatusOr& operator=(Status&& status); - - // Returns this->status().ok() - bool ok() const { return this->status_.ok(); } - - // Returns a reference to our status. If this contains a T, then - // returns Status::OK(). - const Status& status() const &; - Status status() &&; - - // Returns a reference to our current value, or CHECK-fails if !this->ok(). - // - // Note: for value types that are cheap to copy, prefer simple code: - // - // T value = statusor.ValueOrDie(); - // - // Otherwise, if the value type is expensive to copy, but can be left - // in the StatusOr, simply assign to a reference: - // - // T& value = statusor.ValueOrDie(); // or `const T&` - // - // Otherwise, if the value type supports an efficient move, it can be - // used as follows: - // - // T value = std::move(statusor).ValueOrDie(); - // - // The std::move on statusor instead of on the whole expression enables - // warnings about possible uses of the statusor object after the move. - // C++ style guide waiver for ref-qualified overloads granted in cl/143176389 - // See go/ref-qualifiers for more details on such overloads. - const T& ValueOrDie() const &; - T& ValueOrDie() &; - const T&& ValueOrDie() const &&; - T&& ValueOrDie() &&; - - T ConsumeValueOrDie() { return std::move(ValueOrDie()); } - - // Ignores any errors. This method does nothing except potentially suppress - // complaints from any tools that are checking that errors are not dropped on - // the floor. - void IgnoreError() const; -}; - -//////////////////////////////////////////////////////////////////////////////// -// Implementation details for StatusOr - -template -StatusOr::StatusOr() : Base(Status(tensorflow::error::UNKNOWN, "")) {} - -template -StatusOr::StatusOr(const T& value) : Base(value) {} - -template -StatusOr::StatusOr(const Status& status) : Base(status) {} - -template -StatusOr& StatusOr::operator=(const Status& status) { - this->Assign(status); - return *this; -} - -template -StatusOr::StatusOr(T&& value) : Base(std::move(value)) {} - -template -StatusOr::StatusOr(Status&& status) : Base(std::move(status)) {} - -template -StatusOr& StatusOr::operator=(Status&& status) { - this->Assign(std::move(status)); - return *this; -} - -template -template ::value>::type*> -inline StatusOr::StatusOr(const StatusOr& other) - : Base(static_cast::Base&>(other)) {} - -template -template ::value>::type*> -inline StatusOr& StatusOr::operator=(const StatusOr& other) { - if (other.ok()) - this->Assign(other.ValueOrDie()); - else - this->Assign(other.status()); - return *this; -} - -template -template ::value>::type*> -inline StatusOr::StatusOr(StatusOr&& other) - : Base(static_cast::Base&&>(other)) {} - -template -template ::value>::type*> -inline StatusOr& StatusOr::operator=(StatusOr&& other) { - if (other.ok()) { - this->Assign(std::move(other).ValueOrDie()); - } else { - this->Assign(std::move(other).status()); - } - return *this; -} - -template -const Status& StatusOr::status() const & { - return this->status_; -} -template -Status StatusOr::status() && { - return ok() ? Status::OK() : std::move(this->status_); -} - -template -const T& StatusOr::ValueOrDie() const & { - this->EnsureOk(); - return this->data_; -} - -template -T& StatusOr::ValueOrDie() & { - this->EnsureOk(); - return this->data_; -} - -template -const T&& StatusOr::ValueOrDie() const && { - this->EnsureOk(); - return std::move(this->data_); -} - -template -T&& StatusOr::ValueOrDie() && { - this->EnsureOk(); - return std::move(this->data_); -} - +// Use steam_executor's StatusOr so we don't duplicate code. template -void StatusOr::IgnoreError() const { - // no-op -} +using StatusOr = ::stream_executor::port::StatusOr; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index e7e0a19db0516e4210f6bb78d6b5e6968bf78b2a..b8e700ae9705d596061628700036e223aca3a0f1 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -65,6 +65,7 @@ cc_library( srcs = ["test_utils.cc"], hdrs = ["test_utils.h"], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", @@ -88,6 +89,7 @@ cc_library( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:error_spec", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_comparison", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test", @@ -152,8 +154,8 @@ tf_cc_binary( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", @@ -179,6 +181,7 @@ cc_library( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -189,8 +192,8 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -209,6 +212,7 @@ cc_library( deps = [ ":codegen_test_base", ":filecheck", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:test", @@ -258,7 +262,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", @@ -286,8 +290,8 @@ xla_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -302,7 +306,7 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -310,8 +314,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -330,8 +334,8 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -345,16 +349,16 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -372,9 +376,10 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -391,8 +396,8 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -406,7 +411,7 @@ xla_test( tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -415,9 +420,9 @@ xla_test( "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -435,14 +440,14 @@ xla_test( tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -460,9 +465,9 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -479,8 +484,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -497,8 +502,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -515,9 +520,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -531,6 +536,7 @@ xla_test( srcs = ["scalar_computations_test.cc"], shard_count = 32, deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -538,8 +544,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -557,8 +563,8 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -573,7 +579,7 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -581,8 +587,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -599,7 +605,7 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -607,8 +613,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -633,7 +639,7 @@ xla_test( deps = [ ":client_library_test_base", ":literal_test_util", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -645,7 +651,7 @@ xla_test( tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -653,7 +659,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -676,8 +682,8 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -697,6 +703,22 @@ xla_test( "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +xla_test( + name = "scatter_test", + srcs = ["scatter_test.cc"], + deps = [ + ":client_library_test_base", + ":hlo_test_base", + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -719,8 +741,8 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -743,8 +765,8 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -763,11 +785,12 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -779,7 +802,7 @@ xla_test( CONVOLUTION_TEST_DEPS = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -788,7 +811,7 @@ CONVOLUTION_TEST_DEPS = [ "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -826,13 +849,13 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -855,8 +878,8 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -873,7 +896,7 @@ xla_test( ":test_utils", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -884,9 +907,10 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -905,7 +929,7 @@ xla_test( ":test_utils", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -916,9 +940,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -938,12 +962,12 @@ xla_test( ], deps = [ ":test_utils", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -964,7 +988,7 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -983,8 +1007,8 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1005,7 +1029,7 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", @@ -1029,14 +1053,15 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1056,9 +1081,9 @@ xla_test( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1077,6 +1102,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", @@ -1086,9 +1112,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1113,9 +1139,9 @@ xla_test_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1147,16 +1173,16 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1174,10 +1200,10 @@ xla_test( deps = [ ":client_library_test_base", "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1226,12 +1252,13 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1244,10 +1271,12 @@ xla_test( name = "custom_call_test", srcs = ["custom_call_test.cc"], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -1270,8 +1299,8 @@ xla_test( "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1288,12 +1317,13 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1313,8 +1343,8 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1331,8 +1361,8 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1348,8 +1378,8 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1365,7 +1395,7 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -1373,8 +1403,8 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1388,14 +1418,14 @@ xla_test( name = "prng_test", srcs = ["prng_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -1413,6 +1443,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", @@ -1423,8 +1454,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1443,7 +1474,7 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1466,9 +1497,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1492,8 +1523,8 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1512,8 +1543,8 @@ xla_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1527,15 +1558,15 @@ xla_test( name = "cross_replica_sum_test", srcs = ["cross_replica_sum_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1557,7 +1588,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1571,15 +1602,15 @@ xla_test( name = "compilation_cache_test", srcs = ["compilation_cache_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1597,8 +1628,8 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1611,7 +1642,7 @@ xla_test( name = "compute_constant_test", srcs = ["compute_constant_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1620,8 +1651,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1641,8 +1672,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1658,8 +1689,8 @@ xla_test( deps = [ ":client_library_test_base", "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -1672,8 +1703,8 @@ xla_test( deps = [ ":client_library_test_base", "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -1686,15 +1717,15 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1711,7 +1742,7 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -1728,6 +1759,7 @@ tf_cc_test( srcs = ["llvm_compiler_test.cc"], tags = ["requires-gpu-sm35"], deps = [ + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:cpu_plugin", @@ -1748,7 +1780,7 @@ xla_test( name = "round_trip_packed_literal_test", srcs = ["round_trip_packed_literal_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:packed_literal_reader", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -1771,15 +1803,16 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -1798,14 +1831,14 @@ xla_test( srcs = ["multioutput_fusion_test.cc"], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", @@ -1838,11 +1871,11 @@ xla_test( name = "local_client_allocation_test", srcs = ["local_client_allocation_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1861,7 +1894,7 @@ xla_test( shard_count = 30, tags = ["optonly"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -1869,8 +1902,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", @@ -1886,6 +1919,16 @@ xla_test( ], ) +xla_test( + name = "outfeed_in_nested_computation_test", + srcs = ["outfeed_in_nested_computation_test.cc"], + deps = [ + "//tensorflow/compiler/xla/tests:local_client_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "hlo_metadata_test", srcs = [ @@ -1895,7 +1938,7 @@ tf_cc_test( ":local_client_test_base", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/core:test_main", @@ -1907,7 +1950,7 @@ xla_test( srcs = ["round_trip_transfer_test.cc"], deps = [ "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", @@ -1928,7 +1971,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1937,8 +1980,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1951,7 +1994,7 @@ xla_test( name = "deep_graph_test", srcs = ["deep_graph_test.cc"], deps = [ - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -1976,7 +2019,7 @@ xla_test( ":literal_test_util", ":local_client_test_base", ":xla_internal_test_main", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -1984,8 +2027,10 @@ xla_test( "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:generic_transfer_manager", "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", ], ) @@ -2035,9 +2080,30 @@ xla_test( ":local_client_test_base", ":test_utils", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], ) + +xla_test( + name = "iota_test", + srcs = ["iota_test.cc"], + blacklisted_backends = [ + "cpu", + "gpu", + ], + tags = [ + "enable_for_xla_interpreter", + ], + deps = [ + ":client_library_test_base", + ":literal_test_util", + ":xla_internal_test_main", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 36a706496918ac8c15780473019e2a8d098ffa22..74f2e36f826cd82ce4015df857f3de67950beaeb 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -24,9 +24,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -51,16 +51,16 @@ class ArrayElementwiseOpTestParamCount XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.Neg(a); + auto a = ConstantR1(&builder, {}); + Neg(a); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); - builder.Neg(a); + auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); + Neg(a); ComputeAndCompareR1(&builder, {2.5f, -3.14f, -2.25f, 10.0f, -6.0f}, {}, error_spec_); @@ -68,10 +68,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) { XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-1, 0, 1, 324, - std::numeric_limits::min(), - std::numeric_limits::max()}); - builder.Neg(a); + auto a = ConstantR1(&builder, + {-1, 0, 1, 324, std::numeric_limits::min(), + std::numeric_limits::max()}); + Neg(a); // -min == min for int32 due to an overflow. In C++ it is undefined behavior // to do this calculation. For XLA we have not specified that, so it @@ -84,17 +84,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.Neg(a); + auto a = ConstantR1(&builder, {}); + Neg(a); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}}); - builder.Neg(a); + auto a = ConstantR1( + &builder, {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}}); + Neg(a); ComputeAndCompareR1( &builder, {{2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}}, @@ -103,16 +103,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({ - -1, - 1, - 0, - 0x12345678, - static_cast(0xffffffff12345678l), - static_cast(0x8000000000000000LL), - static_cast(0x8000000000000001LL), - }); - builder.Neg(a); + auto a = + ConstantR1(&builder, { + -1, + 1, + 0, + 0x12345678, + static_cast(0xffffffff12345678l), + static_cast(0x8000000000000000LL), + static_cast(0x8000000000000001LL), + }); + Neg(a); LOG(INFO) << -static_cast(0x7FFFFFFFFFFFFFFFLL); ComputeAndCompareR1(&builder, @@ -130,8 +131,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) { XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.IsFinite(a); + auto a = ConstantR1(&builder, {}); + IsFinite(a); ComputeAndCompareR1(&builder, {}, {}); } @@ -141,21 +142,21 @@ static const float kNonCanonicalNaN = tensorflow::bit_cast(0x7FD01234); XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) { XlaBuilder builder(TestName()); - builder.IsFinite(builder.ConstantR0(NAN)); + IsFinite(ConstantR0(&builder, NAN)); ComputeAndCompareR0(&builder, false, {}); EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); - builder.IsFinite(builder.ConstantR0(kNonCanonicalNaN)); + IsFinite(ConstantR0(&builder, kNonCanonicalNaN)); ComputeAndCompareR0(&builder, false, {}); const float inf = std::numeric_limits::infinity(); - builder.IsFinite(builder.ConstantR0(inf)); + IsFinite(ConstantR0(&builder, inf)); ComputeAndCompareR0(&builder, false, {}); - builder.IsFinite(builder.ConstantR0(-inf)); + IsFinite(ConstantR0(&builder, -inf)); ComputeAndCompareR0(&builder, false, {}); - builder.IsFinite(builder.ConstantR0(0.0f)); + IsFinite(ConstantR0(&builder, 0.0f)); ComputeAndCompareR0(&builder, true, {}); } @@ -163,9 +164,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) { XlaBuilder builder(TestName()); const float inf = std::numeric_limits::infinity(); EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); - auto a = builder.ConstantR1( - {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}}); - builder.IsFinite(a); + auto a = ConstantR1(&builder, + {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}}); + IsFinite(a); ComputeAndCompareR1(&builder, {false, true, false, true, false, false}, {}); @@ -173,9 +174,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); - auto b = builder.ConstantR1({100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); - builder.Add(a, b); + auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); + auto b = ConstantR1(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); + Add(a, b); ComputeAndCompareR1(&builder, {97.5f, 6.27f, 5.0f, 0.5f, -993.0f}, {}, error_spec_); @@ -183,20 +184,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Add(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Add(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}}); - auto b = builder.ConstantR1( - {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}}); - builder.Add(a, b); + auto a = ConstantR1( + &builder, {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}}); + auto b = ConstantR1( + &builder, {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}}); + Add(a, b); ComputeAndCompareR1( &builder, {97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}}, {}, @@ -205,9 +206,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Add(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Add(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -224,8 +225,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 0x8000000000000000LL, 0x8000000000000000LL, 1}; - std::unique_ptr lhs_literal = Literal::CreateR1({lhs}); - auto lhs_param = b.Parameter(0, lhs_literal->shape(), "lhs_param"); + std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); std::unique_ptr lhs_data = client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); @@ -238,12 +239,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 0, 1, 0x8000000000000000LL}; - std::unique_ptr rhs_literal = Literal::CreateR1({rhs}); - auto rhs_param = b.Parameter(1, rhs_literal->shape(), "rhs_param"); + std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); std::unique_ptr rhs_data = client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); - b.Add(lhs_param, rhs_param); + Add(lhs_param, rhs_param); std::vector expected(lhs.size()); for (int64 i = 0; i < lhs.size(); ++i) { @@ -264,8 +265,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 1, 0, -1}; - std::unique_ptr lhs_literal = Literal::CreateR1({lhs}); - auto lhs_param = b.Parameter(0, lhs_literal->shape(), "lhs_param"); + std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); std::unique_ptr lhs_data = client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); @@ -277,12 +278,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 0x7FFFFFFFFFFFFFFLL, 0x7FFFFFFFFFFFFFFFLL, 0x7FFFFFFFFFFFFFFFLL}; - std::unique_ptr rhs_literal = Literal::CreateR1({rhs}); - auto rhs_param = b.Parameter(1, rhs_literal->shape(), "rhs_param"); + std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); std::unique_ptr rhs_data = client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); - auto sub = b.Sub(lhs_param, rhs_param); + Sub(lhs_param, rhs_param); std::vector expected(lhs.size()); for (int64 i = 0; i < lhs.size(); ++i) { @@ -302,26 +303,26 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { b_values.push_back(2 * i / static_cast(count + 2)); } - std::unique_ptr a_literal = Literal::CreateR1({a_values}); + std::unique_ptr a_literal = LiteralUtil::CreateR1({a_values}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a_constant = builder.ConstantR1(a_values); - auto a_param = builder.Parameter(0, a_literal->shape(), "a_param"); + auto a_constant = ConstantR1(&builder, a_values); + auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param"); - std::unique_ptr b_literal = Literal::CreateR1({b_values}); + std::unique_ptr b_literal = LiteralUtil::CreateR1({b_values}); std::unique_ptr b_data = client_->TransferToServer(*b_literal).ConsumeValueOrDie(); - auto b_constant = builder.Parameter(1, a_literal->shape(), "b_param"); - auto b_param = builder.ConstantR1(b_values); + auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param"); + auto b_param = ConstantR1(&builder, b_values); - auto sum1 = builder.Add(a_constant, b_constant); - auto sum2 = builder.Add(a_constant, b_param); - auto sum3 = builder.Add(a_param, b_constant); - auto sum4 = builder.Add(a_param, b_param); + auto sum1 = Add(a_constant, b_constant); + auto sum2 = Add(a_constant, b_param); + auto sum3 = Add(a_param, b_constant); + auto sum4 = Add(a_param, b_param); - auto sum = builder.Add(sum1, sum2); - sum = builder.Add(sum, sum3); - sum = builder.Add(sum, sum4); + auto sum = Add(sum1, sum2); + sum = Add(sum, sum3); + sum = Add(sum, sum4); std::vector expected; for (int64 i = 0; i < count; ++i) { @@ -334,9 +335,9 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); - auto b = builder.ConstantR1({100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); - builder.Sub(a, b); + auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); + auto b = ConstantR1(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); + Sub(a, b); ComputeAndCompareR1(&builder, {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f}, {}, error_spec_); @@ -344,38 +345,38 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Sub(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Sub(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-1, 0, 2, 1000000000}); - auto b = builder.ConstantR1({-1, 2, 1, -1}); - builder.Sub(a, b); + auto a = ConstantR1(&builder, {-1, 0, 2, 1000000000}); + auto b = ConstantR1(&builder, {-1, 2, 1, -1}); + Sub(a, b); ComputeAndCompareR1(&builder, {0, -2, 1, 1000000001}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Sub(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Sub(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}}); - auto b = builder.ConstantR1( - {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}}); - builder.Sub(a, b); + auto a = ConstantR1(&builder, + {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}}); + auto b = ConstantR1( + &builder, {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}}); + Sub(a, b); ComputeAndCompareR1( &builder, {{-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}}, {}, @@ -384,18 +385,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Sub(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Sub(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); - auto b = builder.ConstantR1({10.0f, 5.1f, 1.0f, 10.0f, -6.0f}); - builder.Div(a, b); + auto a = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto b = ConstantR1(&builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f}); + Div(a, b); ComputeAndCompareR1(&builder, {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f}, {}, error_spec_); @@ -403,9 +404,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Div(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Div(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -442,7 +443,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - builder.Div(dividend, divisor); + Div(dividend, divisor); ComputeAndCompareR1(&builder, quotients, {dividend_data.get(), divisor_data.get()}); @@ -454,7 +455,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - builder.Div(dividend, builder.ConstantR1(divisors)); + Div(dividend, ConstantR1(&builder, divisors)); ComputeAndCompareR1(&builder, quotients, {dividend_data.get()}); } @@ -467,7 +468,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - builder.Rem(dividend, divisor); + Rem(dividend, divisor); ComputeAndCompareR1(&builder, remainders, {dividend_data.get(), divisor_data.get()}); @@ -479,7 +480,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - builder.Rem(dividend, builder.ConstantR1(divisors)); + Rem(dividend, ConstantR1(&builder, divisors)); ComputeAndCompareR1(&builder, remainders, {dividend_data.get()}); } @@ -513,7 +514,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { &builder, ÷nd); auto divisor_data = CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - builder.Div(dividend, divisor); + Div(dividend, divisor); ComputeAndCompareR1(&builder, quotients, {dividend_data.get(), divisor_data.get()}); @@ -524,7 +525,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - builder.Div(dividend, builder.ConstantR1(divisors)); + Div(dividend, ConstantR1(&builder, divisors)); ComputeAndCompareR1(&builder, quotients, {dividend_data.get()}); } @@ -537,7 +538,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { &builder, ÷nd); auto divisor_data = CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - builder.Rem(dividend, divisor); + Rem(dividend, divisor); ComputeAndCompareR1(&builder, remainders, {dividend_data.get(), divisor_data.get()}); @@ -548,7 +549,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - builder.Rem(dividend, builder.ConstantR1(divisors)); + Rem(dividend, ConstantR1(&builder, divisors)); ComputeAndCompareR1(&builder, remainders, {dividend_data.get()}); } @@ -556,11 +557,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}}); - auto b = builder.ConstantR1( - {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}}); - builder.Div(a, b); + auto a = ConstantR1( + &builder, {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}}); + auto b = ConstantR1(&builder, + {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}}); + Div(a, b); ComputeAndCompareR1( &builder, {{-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}}, {}, error_spec_); @@ -568,20 +569,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Div(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Div(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f}); - auto b = builder.ConstantR1( - {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f}); - builder.Rem(a, b); + auto a = ConstantR1( + &builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f}); + auto b = ConstantR1( + &builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f}); + Rem(a, b); ComputeAndCompareR1( &builder, {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f, 1.0f, 1.0f, -1.0f, -0.0f}, {}, @@ -590,20 +591,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Rem(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Rem(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0}); - auto b = builder.ConstantR1( - {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0}); - builder.Rem(a, b); + auto a = ConstantR1( + &builder, {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0}); + auto b = ConstantR1( + &builder, {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0}); + Rem(a, b); ComputeAndCompareR1( &builder, {-2.5, 0.0, 0.25, 0.0, -0.0, 1.0, 1.0, -1.0, -0.0}, {}, @@ -612,9 +613,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) { XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); - auto b = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - builder.Mul(a, b); + auto a = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto b = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); + Mul(a, b); ComputeAndCompareR1(&builder, {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f}, {}, error_spec_); @@ -622,9 +623,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Mul(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Mul(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -648,18 +649,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) { } XlaBuilder builder(TestName()); - auto a = builder.ConstantR1(a_data); - auto b = builder.ConstantR1(b_data); - builder.Mul(a, b); + auto a = ConstantR1(&builder, a_data); + auto b = ConstantR1(&builder, b_data); + Mul(a, b); ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Mul(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Mul(a, b); ComputeAndCompareR1(&builder, {}, {}); } @@ -679,20 +680,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { } XlaBuilder builder(TestName()); - auto a = builder.ConstantR1(a_data); - auto b = builder.ConstantR1(b_data); - builder.Mul(a, b); + auto a = ConstantR1(&builder, a_data); + auto b = ConstantR1(&builder, b_data); + Mul(a, b); ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}}); - auto b = builder.ConstantR1( - {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}}); - builder.Mul(a, b); + auto a = ConstantR1( + &builder, {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}}); + auto b = ConstantR1(&builder, + {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}}); + Mul(a, b); ComputeAndCompareR1( &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}}, {}, @@ -701,27 +702,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Mul(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Mul(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({false, false, true, true}); - auto b = builder.ConstantR1({false, true, false, true}); - builder.And(a, b); + auto a = ConstantR1(&builder, {false, false, true, true}); + auto b = ConstantR1(&builder, {false, true, false, true}); + And(a, b); ComputeAndCompareR1(&builder, {false, false, false, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{false, false}, {true, true}}); - auto b = builder.ConstantR2({{false, true}, {false, true}}); - builder.And(a, b); + auto a = ConstantR2(&builder, {{false, false}, {true, true}}); + auto b = ConstantR2(&builder, {{false, true}, {false, true}}); + And(a, b); Array2D expected_array({{false, false}, {false, true}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -729,27 +730,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) { XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.And(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({0, -1, -8}); - auto b = builder.ConstantR1({5, -7, 12}); - builder.And(a, b); + auto a = ConstantR1(&builder, {0, -1, -8}); + auto b = ConstantR1(&builder, {5, -7, 12}); + And(a, b); ComputeAndCompareR1(&builder, {0, -7, 8}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{0, -5}, {-1, 5}}); - auto b = builder.ConstantR2({{1, -6}, {4, 5}}); - builder.And(a, b); + auto a = ConstantR2(&builder, {{0, -5}, {-1, 5}}); + auto b = ConstantR2(&builder, {{1, -6}, {4, 5}}); + And(a, b); Array2D expected_array({{0, -6}, {4, 5}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -757,27 +758,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) { XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.And(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({0, 1, 8}); - auto b = builder.ConstantR1({5, 7, 12}); - builder.And(a, b); + auto a = ConstantR1(&builder, {0, 1, 8}); + auto b = ConstantR1(&builder, {5, 7, 12}); + And(a, b); ComputeAndCompareR1(&builder, {0, 1, 8}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{0, 1}, {3, 8}}); - auto b = builder.ConstantR2({{1, 0}, {7, 6}}); - builder.And(a, b); + auto a = ConstantR2(&builder, {{0, 1}, {3, 8}}); + auto b = ConstantR2(&builder, {{1, 0}, {7, 6}}); + And(a, b); Array2D expected_array({{0, 0}, {3, 0}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -785,27 +786,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) { XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.And(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({false, false, true, true}); - auto b = builder.ConstantR1({false, true, false, true}); - builder.Or(a, b); + auto a = ConstantR1(&builder, {false, false, true, true}); + auto b = ConstantR1(&builder, {false, true, false, true}); + Or(a, b); ComputeAndCompareR1(&builder, {false, true, true, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{false, false}, {true, true}}); - auto b = builder.ConstantR2({{false, true}, {false, true}}); - builder.Or(a, b); + auto a = ConstantR2(&builder, {{false, false}, {true, true}}); + auto b = ConstantR2(&builder, {{false, true}, {false, true}}); + Or(a, b); Array2D expected_array({{false, true}, {true, true}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -813,27 +814,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) { XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Or(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({0, -1, 8}); - auto b = builder.ConstantR1({5, -7, 4}); - builder.Or(a, b); + auto a = ConstantR1(&builder, {0, -1, 8}); + auto b = ConstantR1(&builder, {5, -7, 4}); + Or(a, b); ComputeAndCompareR1(&builder, {5, -1, 12}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{0, -1}, {8, 8}}); - auto b = builder.ConstantR2({{5, -7}, {4, 1}}); - builder.Or(a, b); + auto a = ConstantR2(&builder, {{0, -1}, {8, 8}}); + auto b = ConstantR2(&builder, {{5, -7}, {4, 1}}); + Or(a, b); Array2D expected_array({{5, -1}, {12, 9}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -841,27 +842,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) { XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Or(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({0, 1, 8}); - auto b = builder.ConstantR1({5, 7, 4}); - builder.Or(a, b); + auto a = ConstantR1(&builder, {0, 1, 8}); + auto b = ConstantR1(&builder, {5, 7, 4}); + Or(a, b); ComputeAndCompareR1(&builder, {5, 7, 12}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{0, 1}, {8, 8}}); - auto b = builder.ConstantR2({{5, 7}, {4, 1}}); - builder.Or(a, b); + auto a = ConstantR2(&builder, {{0, 1}, {8, 8}}); + auto b = ConstantR2(&builder, {{5, 7}, {4, 1}}); + Or(a, b); Array2D expected_array({{5, 7}, {12, 9}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -869,25 +870,108 @@ XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) { XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.Or(a, b); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, XorPredR1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {false, false, true, true}); + auto b = ConstantR1(&builder, {false, true, false, true}); + Xor(a, b); + + ComputeAndCompareR1(&builder, {false, true, true, false}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorPredR2) { + XlaBuilder builder(TestName()); + auto a = ConstantR2(&builder, {{false, false}, {true, true}}); + auto b = ConstantR2(&builder, {{false, true}, {false, true}}); + Xor(a, b); + + Array2D expected_array({{false, true}, {true, false}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementPredR1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Xor(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorS32R1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {0, -1, 8}); + auto b = ConstantR1(&builder, {5, -7, 4}); + Xor(a, b); + + ComputeAndCompareR1(&builder, {5, 6, 12}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorS32R2) { + XlaBuilder builder(TestName()); + auto a = ConstantR2(&builder, {{0, -1}, {8, 8}}); + auto b = ConstantR2(&builder, {{5, -7}, {4, 1}}); + Xor(a, b); + + Array2D expected_array({{5, 6}, {12, 9}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementS32R1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Xor(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorU32R1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {0, 1, 8}); + auto b = ConstantR1(&builder, {5, 7, 4}); + Xor(a, b); + + ComputeAndCompareR1(&builder, {5, 6, 12}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorU32R2) { + XlaBuilder builder(TestName()); + auto a = ConstantR2(&builder, {{0, 1}, {8, 8}}); + auto b = ConstantR2(&builder, {{5, 7}, {4, 1}}); + Xor(a, b); + + Array2D expected_array({{5, 6}, {12, 9}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementU32R1) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + Xor(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({false, true, true, false}); - builder.Not(a); + auto a = ConstantR1(&builder, {false, true, true, false}); + Not(a); ComputeAndCompareR1(&builder, {true, false, false, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{false, true}, {true, false}}); - builder.Not(a); + auto a = ConstantR2(&builder, {{false, true}, {true, false}}); + Not(a); Array2D expected_array({{true, false}, {false, true}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -895,24 +979,24 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) { XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.Not(a); + auto a = ConstantR1(&builder, {}); + Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-1, 0, 1}); - builder.Not(a); + auto a = ConstantR1(&builder, {-1, 0, 1}); + Not(a); ComputeAndCompareR1(&builder, {0, -1, -2}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{-1, 0}, {1, 8}}); - builder.Not(a); + auto a = ConstantR2(&builder, {{-1, 0}, {1, 8}}); + Not(a); Array2D expected_array({{0, -1}, {-2, -9}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -920,24 +1004,24 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) { XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.Not(a); + auto a = ConstantR1(&builder, {}); + Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({0, 4294967295}); - builder.Not(a); + auto a = ConstantR1(&builder, {0, 4294967295}); + Not(a); ComputeAndCompareR1(&builder, {4294967295, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{0, 4294967295}, {1, 4294967294}}); - builder.Not(a); + auto a = ConstantR2(&builder, {{0, 4294967295}, {1, 4294967294}}); + Not(a); Array2D expected_array({{4294967295, 0}, {4294967294, 1}}); ComputeAndCompareR2(&builder, expected_array, {}); @@ -945,19 +1029,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) { XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.Not(a); + auto a = ConstantR1(&builder, {}); + Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({static_cast(0x12345678), - static_cast(0xF0001000), 1, 3, 77, - 1, -3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 15, 32, 100, -1}); - builder.ShiftLeft(a, b); + auto a = ConstantR1( + &builder, {static_cast(0x12345678), static_cast(0xF0001000), + 1, 3, 77, 1, -3, 77}); + auto b = ConstantR1(&builder, {4, 8, 2, 7, 15, 32, 100, -1}); + ShiftLeft(a, b); ComputeAndCompareR1(&builder, {static_cast(0x23456780), 0x00100000, 0x4, @@ -967,11 +1051,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({static_cast(0x92345678), - static_cast(0x10001000), 1, 3, 77, - 1, -3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 2, 32, 100, -1}); - builder.ShiftRightArithmetic(a, b); + auto a = ConstantR1( + &builder, {static_cast(0x92345678), static_cast(0x10001000), + 1, 3, 77, 1, -3, 77}); + auto b = ConstantR1(&builder, {4, 8, 2, 7, 2, 32, 100, -1}); + ShiftRightArithmetic(a, b); ComputeAndCompareR1( &builder, @@ -982,11 +1066,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) { XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({static_cast(0x92345678), - static_cast(0x10001000), 1, 3, 77, - 1, -3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 5, 32, 100, -1}); - builder.ShiftRightLogical(a, b); + auto a = ConstantR1( + &builder, {static_cast(0x92345678), static_cast(0x10001000), + 1, 3, 77, 1, -3, 77}); + auto b = ConstantR1(&builder, {4, 8, 2, 7, 5, 32, 100, -1}); + ShiftRightLogical(a, b); ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {}); @@ -994,10 +1078,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) { XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 15, 32, 100, ~0u}); - builder.ShiftLeft(a, b); + auto a = ConstantR1(&builder, + {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77}); + auto b = ConstantR1(&builder, {4, 8, 2, 7, 15, 32, 100, ~0u}); + ShiftLeft(a, b); ComputeAndCompareR1( &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136, 0, 0, 0}, {}); @@ -1005,10 +1089,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) { XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 2, 32, 100, ~0u}); - builder.ShiftRightArithmetic(a, b); + auto a = ConstantR1(&builder, + {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); + auto b = ConstantR1(&builder, {4, 8, 2, 7, 2, 32, 100, ~0u}); + ShiftRightArithmetic(a, b); ComputeAndCompareR1( &builder, {0xF9234567, 0x00100010, 0, 0, 19, 0, ~0u, 0}, {}); @@ -1016,10 +1100,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) { XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 5, 32, 100, ~0u}); - builder.ShiftRightLogical(a, b); + auto a = ConstantR1(&builder, + {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); + auto b = ConstantR1(&builder, {4, 8, 2, 7, 5, 32, 100, ~0u}); + ShiftRightLogical(a, b); ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {}); @@ -1028,18 +1112,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({10.0f, 5.0f, 2.25f, 10.0f, NAN}); - builder.Eq(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 2.25f, 10.0f, NAN}); + Eq(lhs, rhs); ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - builder.Eq(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } @@ -1047,9 +1131,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - builder.Ge(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN}); + Ge(lhs, rhs); ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); } @@ -1057,9 +1141,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - builder.Gt(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN}); + Gt(lhs, rhs); ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); } @@ -1067,9 +1151,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.5f, 5.0f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - builder.Le(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.5f, 5.0f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN}); + Le(lhs, rhs); ComputeAndCompareR1(&builder, {true, true, false, false, false}, {}); } @@ -1077,9 +1161,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - builder.Lt(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN}); + Lt(lhs, rhs); ComputeAndCompareR1(&builder, {true, false, false, false, false}, {}); } @@ -1088,9 +1172,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); - auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - builder.Eq(lhs, rhs); + auto lhs = + ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max}); + auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); + Eq(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, false, true, false, false, false, true}, @@ -1099,9 +1184,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - builder.Eq(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } @@ -1109,26 +1194,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, - {1.0f, 25.5f}, - {2.25f, -3.0f}, - {NAN, 0.0f}, - {1.0f, 6.0f}}); - auto rhs = builder.ConstantR1({{0.0f, 10.0f}, - {1.0f, 5.0f}, - {2.25f, -3.0f}, - {10.0f, 0.0f}, - {1.0f, NAN}}); - builder.Eq(lhs, rhs); + auto lhs = ConstantR1(&builder, {{-2.5f, 10.0f}, + {1.0f, 25.5f}, + {2.25f, -3.0f}, + {NAN, 0.0f}, + {1.0f, 6.0f}}); + auto rhs = ConstantR1(&builder, {{0.0f, 10.0f}, + {1.0f, 5.0f}, + {2.25f, -3.0f}, + {10.0f, 0.0f}, + {1.0f, NAN}}); + Eq(lhs, rhs); ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - builder.Eq(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } @@ -1138,17 +1223,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, - {1.0f, 25.5f}, - {2.25f, -3.0f}, - {NAN, 0.0f}, - {1.0f, 6.0f}}); - auto rhs = builder.ConstantR1({{0.0f, 10.0f}, - {1.0f, 5.0f}, - {2.25f, -3.0f}, - {10.0f, 0.0f}, - {1.0f, NAN}}); - builder.Ne(lhs, rhs); + auto lhs = ConstantR1(&builder, {{-2.5f, 10.0f}, + {1.0f, 25.5f}, + {2.25f, -3.0f}, + {NAN, 0.0f}, + {1.0f, 6.0f}}); + auto rhs = ConstantR1(&builder, {{0.0f, 10.0f}, + {1.0f, 5.0f}, + {2.25f, -3.0f}, + {10.0f, 0.0f}, + {1.0f, NAN}}); + Ne(lhs, rhs); ComputeAndCompareR1(&builder, {true, true, false, true, true}, {}); } @@ -1158,9 +1243,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({10.0f, 25.5f, 1.0f, 10.0f, NAN}); - builder.Ne(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 25.5f, 1.0f, 10.0f, NAN}); + Ne(lhs, rhs); ComputeAndCompareR1(&builder, {true, false, true, true, true}, {}); } @@ -1169,9 +1254,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); - auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - builder.Ne(lhs, rhs); + auto lhs = + ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max}); + auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); + Ne(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, true, false, true, true, true, false}, {}); @@ -1181,9 +1267,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); - auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - builder.Ge(lhs, rhs); + auto lhs = + ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max}); + auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); + Ge(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, true, true, false, true, true, true}, {}); @@ -1193,9 +1280,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); - auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - builder.Gt(lhs, rhs); + auto lhs = + ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max}); + auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); + Gt(lhs, rhs); ComputeAndCompareR1( &builder, {false, false, false, true, false, false, true, true, false}, @@ -1206,9 +1294,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); - auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - builder.Le(lhs, rhs); + auto lhs = + ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max}); + auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); + Le(lhs, rhs); ComputeAndCompareR1( &builder, {true, true, true, false, true, true, false, false, true}, {}); @@ -1218,9 +1307,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); - auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - builder.Lt(lhs, rhs); + auto lhs = + ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max}); + auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max}); + Lt(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, false, false, true, false, false, false}, @@ -1230,9 +1320,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); - auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - builder.Eq(lhs, rhs); + auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); + Eq(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, false, true, false, false, false, true}, @@ -1242,9 +1332,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); - auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - builder.Ne(lhs, rhs); + auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); + Ne(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, true, false, true, true, true, false}, {}); @@ -1253,9 +1343,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); - auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - builder.Ge(lhs, rhs); + auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); + Ge(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, true, true, false, true, true, true}, {}); @@ -1264,9 +1354,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); - auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - builder.Gt(lhs, rhs); + auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); + Gt(lhs, rhs); ComputeAndCompareR1( &builder, {false, false, false, true, false, false, true, true, false}, @@ -1276,9 +1366,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); - auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - builder.Le(lhs, rhs); + auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); + Le(lhs, rhs); ComputeAndCompareR1( &builder, {true, true, true, false, true, true, false, false, true}, {}); @@ -1287,9 +1377,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); - auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - builder.Lt(lhs, rhs); + auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); + auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); + Lt(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, false, false, true, false, false, false}, @@ -1300,10 +1390,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); auto lhs = - builder.ConstantR1({4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f}); + ConstantR1(&builder, {4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f}); auto rhs = - builder.ConstantR1({2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f}); - builder.Pow(lhs, rhs); + ConstantR1(&builder, {2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f}); + Pow(lhs, rhs); ComputeAndCompareR1( &builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {}, error_spec_); @@ -1312,9 +1402,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({-2.0f, -0.6f, -0.6f, 0.0f}); - auto rhs = builder.ConstantR1({0.5f, 0.6f, -0.6f, -0.6f}); - builder.Pow(lhs, rhs); + auto lhs = ConstantR1(&builder, {-2.0f, -0.6f, -0.6f, 0.0f}); + auto rhs = ConstantR1(&builder, {0.5f, 0.6f, -0.6f, -0.6f}); + Pow(lhs, rhs); ComputeAndCompareR1(&builder, {NAN, NAN, NAN, INFINITY}, {}, error_spec_); @@ -1322,9 +1412,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) { XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - builder.Pow(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Pow(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -1336,14 +1426,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { std::vector values = {1.0f, 2.0f, 3.2f, -4.0f}; std::vector exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr param_literal = Literal::CreateR1(values); + std::unique_ptr param_literal = LiteralUtil::CreateR1(values); std::unique_ptr param_data = client_->TransferToServer(*param_literal).ConsumeValueOrDie(); - auto sum = b.ConstantR0(0.0f); - auto param = b.Parameter(0, param_literal->shape(), "param"); + auto sum = ConstantR0(&b, 0.0f); + auto param = Parameter(&b, 0, param_literal->shape(), "param"); for (float exponent : exponents) { - sum = b.Add(sum, b.Pow(param, b.ConstantR0(exponent))); + sum = Add(sum, Pow(param, ConstantR0(&b, exponent))); } std::vector expected; @@ -1364,15 +1454,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - b.Pow(b.Exp(param0), param1); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + Pow(Exp(param0), param1); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1389,15 +1479,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - b.Log(b.Pow(param0, param1)); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + Log(Pow(param0, param1)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1414,15 +1504,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - b.Mul(b.Exp(param0), b.Exp(param1)); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + Mul(Exp(param0), Exp(param1)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1439,15 +1529,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - b.Div(param0, b.Exp(param1)); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + Div(param0, Exp(param1)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1465,21 +1555,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = client_->TransferToServer(*literal2).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - auto param2 = b.Parameter(2, literal2->shape(), "param2"); - b.Div(b.Div(param0, param1), param2); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + Div(Div(param0, param1), param2); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1497,22 +1587,22 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = client_->TransferToServer(*literal2).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - auto param2 = b.Parameter(2, literal2->shape(), "param2"); - b.Div(param0, b.Div(param1, param2)); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + Div(param0, Div(param1, param2)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1530,22 +1620,22 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = client_->TransferToServer(*literal2).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - auto param2 = b.Parameter(2, literal2->shape(), "param2"); - b.Div(param0, b.Pow(param1, param2)); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + Div(param0, Pow(param1, param2)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1564,27 +1654,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) { std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; std::vector values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = client_->TransferToServer(*literal2).ConsumeValueOrDie(); - std::unique_ptr literal3 = Literal::CreateR1(values3); + std::unique_ptr literal3 = LiteralUtil::CreateR1(values3); std::unique_ptr data3 = client_->TransferToServer(*literal3).ConsumeValueOrDie(); - auto param0 = b.Parameter(0, literal0->shape(), "param0"); - auto param1 = b.Parameter(1, literal1->shape(), "param1"); - auto param2 = b.Parameter(2, literal2->shape(), "param2"); - auto param3 = b.Parameter(3, literal3->shape(), "param2"); - b.Div(b.Div(param0, param1), b.Div(param2, param3)); + auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + auto param3 = Parameter(&b, 3, literal3->shape(), "param2"); + Div(Div(param0, param1), Div(param2, param3)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { @@ -1604,8 +1694,8 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { for (int i = 0; i < count; ++i) { values.push_back(i / static_cast(count)); } - auto x = builder.ConstantR1(values); - builder.Pow(x, builder.ConstantR0(2.0f)); + auto x = ConstantR1(&builder, values); + Pow(x, ConstantR0(&builder, 2.0f)); std::vector expected; expected.reserve(values.size()); @@ -1630,8 +1720,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) { Array4D expected(2, 2, 2, 2, expected_vector); - auto x = builder.ConstantR4FromArray4D(values); - builder.Pow(x, builder.ConstantR0(2.0f)); + auto x = ConstantR4FromArray4D(&builder, values); + Pow(x, ConstantR0(&builder, 2.0f)); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } @@ -1641,8 +1731,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) { Array4D values(2, 2, 0, 2); Array4D expected(2, 2, 0, 2); - auto x = builder.ConstantR4FromArray4D(values); - builder.Pow(x, builder.ConstantR0(2.0f)); + auto x = ConstantR4FromArray4D(&builder, values); + Pow(x, ConstantR0(&builder, 2.0f)); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } @@ -1650,9 +1740,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) { XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) { XlaBuilder builder(TestName()); SetFastMathDisabled(true); - auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f, 10.0f, NAN}); - builder.Min(lhs, rhs); + auto lhs = ConstantR1(&builder, {1.0f, 1.0f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {2.0f, -5.0f, 1.0f, 10.0f, NAN}); + Min(lhs, rhs); ComputeAndCompareR1(&builder, {1.0f, -5.0f, 1.0f, NAN, NAN}, {}, error_spec_); @@ -1660,18 +1750,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - builder.Min(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Min(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) { XlaBuilder builder(TestName()); SetFastMathDisabled(true); - auto lhs = builder.ConstantR1({1.0, 1.0, 2.25, NAN, 6.0}); - auto rhs = builder.ConstantR1({2.0, -5.0, 1.0, 10.0, NAN}); - builder.Min(lhs, rhs); + auto lhs = ConstantR1(&builder, {1.0, 1.0, 2.25, NAN, 6.0}); + auto rhs = ConstantR1(&builder, {2.0, -5.0, 1.0, 10.0, NAN}); + Min(lhs, rhs); ComputeAndCompareR1(&builder, {1.0, -5.0, 1.0, NAN, NAN}, {}, error_spec_); @@ -1680,9 +1770,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) { XlaBuilder builder(TestName()); SetFastMathDisabled(true); - auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f, 10.0f, NAN}); - builder.Max(lhs, rhs); + auto lhs = ConstantR1(&builder, {1.0f, 1.0f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {2.0f, -5.0f, 1.0f, 10.0f, NAN}); + Max(lhs, rhs); ComputeAndCompareR1(&builder, {2.0f, 1.0f, 2.25f, NAN, NAN}, {}, error_spec_); @@ -1690,18 +1780,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - builder.Max(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Max(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) { XlaBuilder builder(TestName()); SetFastMathDisabled(true); - auto lhs = builder.ConstantR1({1.0, 1.0, 2.25, NAN, 6.0}); - auto rhs = builder.ConstantR1({2.0, -5.0, 1.0, 10.0, NAN}); - builder.Max(lhs, rhs); + auto lhs = ConstantR1(&builder, {1.0, 1.0, 2.25, NAN, 6.0}); + auto rhs = ConstantR1(&builder, {2.0, -5.0, 1.0, 10.0, NAN}); + Max(lhs, rhs); ComputeAndCompareR1(&builder, {2.0, 1.0, 2.25, NAN, NAN}, {}, error_spec_); @@ -1711,11 +1801,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); - auto y = builder.ConstantR1( - {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); - builder.Max(x, y); + auto x = ConstantR1( + &builder, {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); + auto y = ConstantR1( + &builder, {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); + Max(x, y); std::vector expected = {min, max, 0, -1, 0, 0, 0, 1, 1, 10, max, max, max}; @@ -1726,11 +1816,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); - auto y = builder.ConstantR1( - {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); - builder.Min(x, y); + auto x = ConstantR1( + &builder, {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); + auto y = ConstantR1( + &builder, {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); + Min(x, y); std::vector expected = {min, min, min, -10, -1, -1, 0, 0, 0, 1, 0, max, min}; @@ -1740,9 +1830,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0, 0, 1, 1, 1, max, max, max}); - auto y = builder.ConstantR1({0, 1, 0, 1, 10, 0, 234234, max}); - builder.Max(x, y); + auto x = ConstantR1(&builder, {0, 0, 1, 1, 1, max, max, max}); + auto y = ConstantR1(&builder, {0, 1, 0, 1, 10, 0, 234234, max}); + Max(x, y); std::vector expected = {0, 1, 1, 1, 10, max, max, max}; ComputeAndCompareR1(&builder, expected, {}); @@ -1751,9 +1841,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0, 0, 1, 1, 1, max, max, max}); - auto y = builder.ConstantR1({0, 1, 0, 1, 10, 0, 234234, max}); - builder.Min(x, y); + auto x = ConstantR1(&builder, {0, 0, 1, 1, 1, max, max, max}); + auto y = ConstantR1(&builder, {0, 1, 0, 1, 10, 0, 234234, max}); + Min(x, y); std::vector expected = {0, 0, 0, 1, 1, 0, 234234, max}; ComputeAndCompareR1(&builder, expected, {}); @@ -1761,11 +1851,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); - auto y = builder.ConstantR1( - {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); - builder.Max(x, y); + auto x = ConstantR1( + &builder, {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); + auto y = ConstantR1( + &builder, {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); + Max(x, y); std::vector expected = {-0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; @@ -1774,9 +1864,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) { XlaBuilder builder(TestName()); - auto u = builder.ConstantR1({3.5}); - auto v = builder.ConstantR1({}); - builder.Max(u, v); + auto u = ConstantR1(&builder, {3.5}); + auto v = ConstantR1(&builder, {}); + Max(u, v); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -1784,9 +1874,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) { for (int broadcast_dim : {0, 1}) { XlaBuilder builder(TestName()); - auto u = builder.ConstantR1({3.5}); - auto v = builder.ConstantR2FromArray2D(Array2D(0, 2)); - builder.Max(u, v, /*broadcast_dimensions=*/{broadcast_dim}); + auto u = ConstantR1(&builder, {3.5}); + auto v = ConstantR2FromArray2D(&builder, Array2D(0, 2)); + Max(u, v, /*broadcast_dimensions=*/{broadcast_dim}); ComputeAndCompareR2(&builder, Array2D(0, 2), {}, error_spec_); } @@ -1794,10 +1884,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) { XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) { XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({2.0f, 3.0f, 4.0f}); - auto m = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - builder.Max(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + auto m = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + Max(v, m, /*broadcast_dimensions=*/{1}); Array2D expected({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); @@ -1805,9 +1895,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) { XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({}); - auto m = builder.ConstantR2({{}, {}}); - builder.Max(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {}); + auto m = ConstantR2(&builder, {{}, {}}); + Max(v, m, /*broadcast_dimensions=*/{1}); Array2D expected({{}, {}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); @@ -1815,10 +1905,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) { XlaBuilder builder(TestName()); - auto scalar = builder.ConstantR0(2); + auto scalar = ConstantR0(&builder, 2); Array3D a_3d({{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}}); - auto array = builder.ConstantR3FromArray3D(a_3d); - builder.Max(array, scalar, /*broadcast_dimensions=*/{}); + auto array = ConstantR3FromArray3D(&builder, a_3d); + Max(array, scalar, /*broadcast_dimensions=*/{}); Array3D expected({{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}}); ComputeAndCompareR3(&builder, expected, {}); @@ -1826,10 +1916,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) { XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) { XlaBuilder builder(TestName()); - auto scalar = builder.ConstantR0(2); + auto scalar = ConstantR0(&builder, 2); Array3D a_3d(2, 0, 3); - auto array = builder.ConstantR3FromArray3D(a_3d); - builder.Max(array, scalar, /*broadcast_dimensions=*/{}); + auto array = ConstantR3FromArray3D(&builder, a_3d); + Max(array, scalar, /*broadcast_dimensions=*/{}); Array3D expected(2, 0, 3); ComputeAndCompareR3(&builder, expected, {}); @@ -1837,10 +1927,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) { XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) { XlaBuilder builder(TestName()); - auto m = - builder.ConstantR2({{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}}); - auto v = builder.ConstantR1({-10.2f, 16.4f}); - builder.Min(m, v, /*broadcast_dimensions=*/{0}); + auto m = ConstantR2(&builder, + {{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}}); + auto v = ConstantR1(&builder, {-10.2f, 16.4f}); + Min(m, v, /*broadcast_dimensions=*/{0}); Array2D expected({{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); @@ -1848,9 +1938,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) { XlaBuilder builder(TestName()); - auto m = builder.ConstantR2({{}, {}}); - auto v = builder.ConstantR1({-10.2f, 16.4f}); - builder.Min(m, v, /*broadcast_dimensions=*/{0}); + auto m = ConstantR2(&builder, {{}, {}}); + auto v = ConstantR1(&builder, {-10.2f, 16.4f}); + Min(m, v, /*broadcast_dimensions=*/{0}); Array2D expected({{}, {}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); @@ -1859,11 +1949,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) { XlaBuilder builder(TestName()); auto array2d = - builder.ConstantR2({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); - auto array4d = builder.ConstantR4FromArray4D( - {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}}, - {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}}); - builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3}); + ConstantR2(&builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); + auto array4d = ConstantR4FromArray4D( + &builder, {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}}, + {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}}); + Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3}); Array4D expected( {{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}}, @@ -1874,10 +1964,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) { XlaBuilder builder(TestName()); auto array2d = - builder.ConstantR2({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); + ConstantR2(&builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); Array4D arg(2, 2, 0, 3); - auto array4d = builder.ConstantR4FromArray4D(arg); - builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3}); + auto array4d = ConstantR4FromArray4D(&builder, arg); + Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3}); Array4D expected(2, 2, 0, 3); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -1885,9 +1975,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y = builder.ConstantR1({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); - builder.Min(x, y); + auto x = ConstantR1(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = ConstantR1(&builder, {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); + Min(x, y); std::vector expected = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0}; ComputeAndCompareR1(&builder, expected, {}); @@ -1895,9 +1985,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto y = builder.ConstantR1({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); - builder.Max(x, y); + auto x = ConstantR1(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = ConstantR1(&builder, {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); + Max(x, y); std::vector expected = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9}; ComputeAndCompareR1(&builder, expected, {}); @@ -1905,19 +1995,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) { XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-3, 26, 2, -1, 1}); - auto b = builder.ConstantR1({10, 5, 1, 10, -10}); - builder.Rem(a, b); + auto a = ConstantR1(&builder, {-3, 26, 2, -1, 1}); + auto b = ConstantR1(&builder, {10, 5, 1, 10, -10}); + Rem(a, b); ComputeAndCompareR1(&builder, {-3, 1, 0, -1, 1}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) { XlaBuilder builder(TestName()); - auto minimum = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); - auto argument = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 10.0f}); - auto maximum = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); - builder.Clamp(minimum, argument, maximum); + auto minimum = ConstantR1(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); + auto argument = + ConstantR1(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f}); + auto maximum = ConstantR1(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0}); + Clamp(minimum, argument, maximum); ComputeAndCompareR1(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {}, error_spec_); @@ -1925,10 +2016,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) { XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) { XlaBuilder builder(TestName()); - auto minimum = builder.ConstantR0(0.0f); - auto argument = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); - auto maximum = builder.ConstantR0(5.0f); - builder.Clamp(minimum, argument, maximum); + auto minimum = ConstantR0(&builder, 0.0f); + auto argument = ConstantR1(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); + auto maximum = ConstantR0(&builder, 5.0f); + Clamp(minimum, argument, maximum); ComputeAndCompareR1(&builder, {2.0f, 5.0f, 0.0f, 1.0f, 4.0f}, {}, error_spec_); @@ -1936,16 +2027,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) { XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) { XlaBuilder builder(TestName()); - auto min_scalar = builder.ConstantR0(0.0f); - auto min_vector = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); - auto arg_vector = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); - auto max_scalar = builder.ConstantR0(3.0f); - auto max_vector = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); + auto min_scalar = ConstantR0(&builder, 0.0f); + auto min_vector = + ConstantR1(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); + auto arg_vector = + ConstantR1(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); + auto max_scalar = ConstantR0(&builder, 3.0f); + auto max_vector = + ConstantR1(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0}); // Perform clamp with broadcasted scalar and vector. - builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + Add(Add(Clamp(min_vector, arg_vector, max_scalar), + Clamp(min_scalar, arg_vector, max_vector)), + Add(Clamp(min_vector, arg_vector, max_vector), + Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8.0f, 7.0f, 2.0f, 6.5f, 14.0f}, {}, error_spec_); @@ -1953,52 +2047,52 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) { XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) { XlaBuilder builder(TestName()); - auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0, -5}); - auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4, 10}); - auto max_vector = builder.ConstantR1({3, 0, 25, 5, 123, -1}); - builder.Clamp(min_vector, arg_vector, max_vector); + auto min_vector = ConstantR1(&builder, {1, -6, 1, 2, 0, -5}); + auto arg_vector = ConstantR1(&builder, {2, 10, -5, 1, 4, 10}); + auto max_vector = ConstantR1(&builder, {3, 0, 25, 5, 123, -1}); + Clamp(min_vector, arg_vector, max_vector); ComputeAndCompareR1(&builder, {2, 0, 1, 2, 4, -1}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) { XlaBuilder builder(TestName()); - auto min_scalar = builder.ConstantR0(0); - auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0}); - auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4}); - auto max_scalar = builder.ConstantR0(3); - auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); + auto min_scalar = ConstantR0(&builder, 0); + auto min_vector = ConstantR1(&builder, {1, -6, 1, 2, 0}); + auto arg_vector = ConstantR1(&builder, {2, 10, -5, 1, 4}); + auto max_scalar = ConstantR0(&builder, 3); + auto max_vector = ConstantR1(&builder, {3, 1, 25, 5, 123}); // Perform clamp with broadcasted scalar and vector. - builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + Add(Add(Clamp(min_vector, arg_vector, max_scalar), + Clamp(min_scalar, arg_vector, max_vector)), + Add(Clamp(min_vector, arg_vector, max_vector), + Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) { XlaBuilder builder(TestName()); - auto min_vector = builder.ConstantR1({1, 2, 1, 2, 0, ~0u - 4}); - auto arg_vector = builder.ConstantR1({2, 10, 5, 1, 4, 10}); - auto max_vector = builder.ConstantR1({3, 5, 25, 5, 123, ~0u}); - builder.Clamp(min_vector, arg_vector, max_vector); + auto min_vector = ConstantR1(&builder, {1, 2, 1, 2, 0, ~0u - 4}); + auto arg_vector = ConstantR1(&builder, {2, 10, 5, 1, 4, 10}); + auto max_vector = ConstantR1(&builder, {3, 5, 25, 5, 123, ~0u}); + Clamp(min_vector, arg_vector, max_vector); ComputeAndCompareR1(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) { XlaBuilder builder(TestName()); - auto min_scalar = builder.ConstantR0(0); - auto min_vector = builder.ConstantR1({1, 0, 1, 2, 0}); - auto arg_vector = builder.ConstantR1({2, 10, 0, 1, 4}); - auto max_scalar = builder.ConstantR0(3); - auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); + auto min_scalar = ConstantR0(&builder, 0); + auto min_vector = ConstantR1(&builder, {1, 0, 1, 2, 0}); + auto arg_vector = ConstantR1(&builder, {2, 10, 0, 1, 4}); + auto max_scalar = ConstantR0(&builder, 3); + auto max_vector = ConstantR1(&builder, {3, 1, 25, 5, 123}); // Perform clamp with broadcasted scalar and vector. - builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + Add(Add(Clamp(min_vector, arg_vector, max_scalar), + Clamp(min_scalar, arg_vector, max_vector)), + Add(Clamp(min_vector, arg_vector, max_vector), + Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); } @@ -2007,18 +2101,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - Literal::CreateR1({7.2f, 2.3f, 3.4f, 5.6f}); + LiteralUtil::CreateR1({7.2f, 2.3f, 3.4f, 5.6f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Add(p0, p1); + auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Add(p0, p1); ComputeAndCompareR1(&builder, {8.3f, 4.5f, 6.7f, 11.1f}, {param0_data.get(), param1_data.get()}, @@ -2029,18 +2123,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR3FromArray3D(Array3D(0, 7, 0)); + LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - Literal::CreateR3FromArray3D(Array3D(0, 7, 0)); + LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Add(p0, p1); + auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Add(p0, p1); Array3D expected(0, 7, 0); ComputeAndCompareR3( @@ -2051,13 +2145,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto a = builder.ConstantR1({1.1f, 2.2f, 3.3f, 4.4f}); - auto p = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Add(a, p); + auto a = ConstantR1(&builder, {1.1f, 2.2f, 3.3f, 4.4f}); + auto p = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Add(a, p); ComputeAndCompareR1(&builder, {2.2f, 4.4f, 6.6f, 9.9f}, {param0_data.get()}, error_spec_); @@ -2065,8 +2159,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({3.14159f, 0.0f, 1.570796f, -0.78539f}); - builder.Cos(a); + auto a = ConstantR1(&builder, {3.14159f, 0.0f, 1.570796f, -0.78539f}); + Cos(a); ComputeAndCompareR1(&builder, {-1.0f, 1.0f, 0.0f, 0.707107f}, {}, error_spec_); @@ -2074,8 +2168,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) { XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({3.14159f, 0.0f, 1.570796f, -0.78539f}); - builder.Sin(a); + auto a = ConstantR1(&builder, {3.14159f, 0.0f, 1.570796f, -0.78539f}); + Sin(a); ComputeAndCompareR1(&builder, {0.0f, 0.0f, 1.0f, -0.707107f}, {}, error_spec_); @@ -2083,9 +2177,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f}); - auto b = builder.ConstantR1({6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f}); - builder.Atan2(a, b); + auto a = ConstantR1(&builder, {0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f}); + auto b = ConstantR1(&builder, {6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f}); + Atan2(a, b); ComputeAndCompareR1( &builder, @@ -2095,8 +2189,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) { XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f}); - builder.Tanh(a); + auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f}); + Tanh(a); ComputeAndCompareR1(&builder, {-0.986614f, 0.996260f, 0.978026}, {}, error_spec_); @@ -2107,7 +2201,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { // the input tensor is large enough to exercise the vectorized tanh // implementation on XLA CPU. XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateR1( + auto input_literal = LiteralUtil::CreateR1( {1.02, -0.32, 0.85, 0.90, 1.23, -0.91, -0.49, 0.80, -0.67, 0.16, -0.07, 0.39, -0.41, 0.04, 1.36, 1.25, 0.41, 0.65, -1.08, 0.32, -1.45, -0.77, -1.09, 0.91, -1.03, -0.30, -1.11, -1.17, 1.50, -0.85, @@ -2118,8 +2212,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*input_literal)); - auto input = builder.Parameter(0, input_literal->shape(), "input"); - builder.Tanh(input); + auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + Tanh(input); ComputeAndCompareR1( &builder, @@ -2149,7 +2243,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { // Just to help make sense of the scales here -- exp(89) saturates float32 and // exp(-10) is smaller than our error spec. - std::unique_ptr input_literal = Literal::CreateR1( + std::unique_ptr input_literal = LiteralUtil::CreateR1( {1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31, -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5, -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4, @@ -2164,8 +2258,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, client_->TransferToServer(*input_literal)); - auto input = builder.Parameter(0, input_literal->shape(), "input"); - builder.Exp(input); + auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + Exp(input); std::vector expected_result; int64 input_size = input_literal->shape().dimensions(0); @@ -2183,7 +2277,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { // implementation on XLA CPU. XlaBuilder builder(TestName()); - std::unique_ptr input_literal = Literal::CreateR1( + std::unique_ptr input_literal = LiteralUtil::CreateR1( {-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198, -167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9, 198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04, @@ -2202,8 +2296,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, client_->TransferToServer(*input_literal)); - auto input = builder.Parameter(0, input_literal->shape(), "input"); - builder.Log(input); + auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + Log(input); std::vector expected_result; int64 input_size = input_literal->shape().dimensions(0); @@ -2218,9 +2312,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {0, 1, 0x10, 0x10000, 0x700000, 0x12345678, 0xF2345678}); - builder.Clz(a); + auto a = ConstantR1( + &builder, {0, 1, 0x10, 0x10000, 0x700000, 0x12345678, 0xF2345678}); + Clz(a); ComputeAndCompareR1(&builder, {32, 31, 27, 15, 9, 3, 0}, {}); } @@ -2228,8 +2322,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) { XLA_TEST_F(ArrayElementwiseOpTest, ClzS64s) { XlaBuilder builder(TestName()); auto a = - builder.ConstantR1({0, 1, 0x80000000, 0x7FFFFFFFF2345678ul, -1}); - builder.Clz(a); + ConstantR1(&builder, {0, 1, 0x80000000, 0x7FFFFFFFF2345678ul, -1}); + Clz(a); ComputeAndCompareR1(&builder, {64, 63, 32, 1, 0}, {}); } @@ -2241,12 +2335,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) { // c---------------------/ XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({1.1f, 2.2f, 3.3f, 4.4f}); - auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); - auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); + auto a = ConstantR1(&builder, {1.1f, 2.2f, 3.3f, 4.4f}); + auto b = ConstantR1(&builder, {2.1f, 3.2f, 4.3f, 5.4f}); + auto c = ConstantR1(&builder, {-3.3f, -15.5f, -7.7f, -29.9f}); - auto add = builder.Add(a, b); - builder.Add(add, c); + auto add = Add(a, b); + Add(add, c); ComputeAndCompareR1(&builder, {-0.1f, -10.1f, -0.1f, -20.1f}, {}, error_spec_); @@ -2259,12 +2353,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) { // a---------------------/ XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); - auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); - auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); + auto a = ConstantR1(&builder, {91.1f, 2.2f, 3.3f, 4.4f}); + auto b = ConstantR1(&builder, {2.1f, 3.2f, 4.3f, 5.4f}); + auto c = ConstantR1(&builder, {-3.3f, -15.5f, -7.7f, -29.9f}); - auto add = builder.Add(b, c); - builder.Add(a, add); + auto add = Add(b, c); + Add(a, add); ComputeAndCompareR1(&builder, {89.9f, -10.1f, -0.1f, -20.1f}, {}, error_spec_); @@ -2276,12 +2370,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) { // b ----- (neg) ----/ XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); - auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); + auto a = ConstantR1(&builder, {91.1f, 2.2f, 3.3f, 4.4f}); + auto b = ConstantR1(&builder, {2.1f, 3.2f, 4.3f, 5.4f}); - auto neg_a = builder.Neg(a); - auto neg_b = builder.Neg(b); - builder.Add(neg_a, neg_b); + auto neg_a = Neg(a); + auto neg_b = Neg(b); + Add(neg_a, neg_b); ComputeAndCompareR1(&builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {}, error_spec_); @@ -2297,14 +2391,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) { // d -----/ XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); - auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); - auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); - auto d = builder.ConstantR1({-19.0f, 10.0f, -40.0f, 20.2f}); + auto a = ConstantR1(&builder, {91.1f, 2.2f, 3.3f, 4.4f}); + auto b = ConstantR1(&builder, {2.1f, 3.2f, 4.3f, 5.4f}); + auto c = ConstantR1(&builder, {-3.3f, -15.5f, -7.7f, -29.9f}); + auto d = ConstantR1(&builder, {-19.0f, 10.0f, -40.0f, 20.2f}); - auto add_ab = builder.Add(a, b); - auto add_cd = builder.Add(c, d); - builder.Add(add_ab, add_cd); + auto add_ab = Add(a, b); + auto add_cd = Add(c, d); + Add(add_ab, add_cd); ComputeAndCompareR1(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {}, error_spec_); @@ -2312,11 +2406,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) { XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) { XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto b = - builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - builder.Add(a, b); + auto a = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto b = ConstantR2(&builder, + {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); + Add(a, b); Array2D expected_array( {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); @@ -2326,10 +2420,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) { XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) { // Add a scalar + matrix. XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto scalar = builder.ConstantR0(3.0f); - builder.Add(scalar, a); + auto a = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto scalar = ConstantR0(&builder, 3.0f); + Add(scalar, a); Array2D expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2338,10 +2432,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) { XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) { // Add a matrix + scalar. XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto scalar = builder.ConstantR0(3.0f); - builder.Add(a, scalar); + auto a = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto scalar = ConstantR0(&builder, 3.0f); + Add(a, scalar); Array2D expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2351,13 +2445,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) { // Test simple broadcasting of a R1F32 over R2F32. The vector's size matches // only dim 0 of the matrix. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({20.0f, 40.0f, 60.0f}); + auto v = ConstantR1(&builder, {20.0f, 40.0f, 60.0f}); // clang-format off - auto m = builder.ConstantR2({ + auto m = ConstantR2(&builder, { {-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); // clang-format on - builder.Add(v, m, /*broadcast_dimensions=*/{1}); + Add(v, m, /*broadcast_dimensions=*/{1}); Array2D expected_array( {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2366,27 +2460,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { // Test broadcasting in Eq comparison. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({42, 73}); - auto m = builder.ConstantR2({{42, 73}, {42, 52}}); + auto v = ConstantR1(&builder, {42, 73}); + auto m = ConstantR2(&builder, {{42, 73}, {42, 52}}); // This test exercises both possible broadcast dimensions for a vector/matrix // comparison. - auto cmp_dim_0 = builder.Eq(v, m, /*broadcast_dimensions=*/{1}); - auto cmp_dim_1 = builder.Eq(v, m, /*broadcast_dimensions=*/{0}); - auto result = builder.Tuple({cmp_dim_0, cmp_dim_1}); + auto cmp_dim_0 = Eq(v, m, /*broadcast_dimensions=*/{1}); + auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0}); + Tuple(&builder, {cmp_dim_0, cmp_dim_1}); - auto expected = Literal::MakeTuple( - {Literal::CreateR2({{true, true}, {true, false}}).get(), - Literal::CreateR2({{true, false}, {false, false}}).get()}); + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{true, true}, {true, false}}).get(), + LiteralUtil::CreateR2({{true, false}, {false, false}}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { // Test broadcasting in Ne comparison. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({42, 73}); - auto m = builder.ConstantR2({{42, 73}, {42, 52}}); - builder.Ne(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {42, 73}); + auto m = ConstantR2(&builder, {{42, 73}, {42, 52}}); + Ne(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,2] { { 00 }, @@ -2398,9 +2492,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { // Test broadcasting in Ge comparison. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({1, 2, 3, 4}); - auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - builder.Ge(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {1, 2, 3, 4}); + auto m = ConstantR2(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}}); + Ge(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 1100 }, @@ -2412,9 +2506,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { // Test broadcasting in Gt comparison. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({1, 2, 3, 4}); - auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - builder.Gt(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {1, 2, 3, 4}); + auto m = ConstantR2(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}}); + Gt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 0100 }, @@ -2426,9 +2520,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { // Test broadcasting in Le comparison. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({1, 2, 3, 4}); - auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - builder.Le(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {1, 2, 3, 4}); + auto m = ConstantR2(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}}); + Le(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 1011 }, @@ -2440,9 +2534,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) { // Test broadcasting in Lt comparison. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({1, 2, 3, 4}); - auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - builder.Lt(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {1, 2, 3, 4}); + auto m = ConstantR2(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}}); + Lt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 0011 }, @@ -2455,9 +2549,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) { // Test simple broadcasting of a R1F32 over R2F32 when the order of binary op // arguments is reversed. XlaBuilder builder(TestName()); - auto m = builder.ConstantR2({{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}}); - auto v = builder.ConstantR1({2.0f, 4.0f, 6.0f}); - builder.Mul(m, v, /*broadcast_dimensions=*/{1}); + auto m = + ConstantR2(&builder, {{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}}); + auto v = ConstantR1(&builder, {2.0f, 4.0f, 6.0f}); + Mul(m, v, /*broadcast_dimensions=*/{1}); Array2D expected_array({{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } @@ -2468,10 +2563,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) { // m's shape in XLA notation is {3, 2} // md's shape in XLA notation is {3, 1} // The result has shape {3, 2}, where md is broadcast over m - auto m = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto md = builder.ConstantR2({{10.0f, 20.0f, 30.0f}}); - builder.Add(m, md); + auto m = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto md = ConstantR2(&builder, {{10.0f, 20.0f, 30.0f}}); + Add(m, md); Array2D expected_array( {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2483,10 +2578,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) { // m's shape in XLA notation is {3, 2} // md's shape in XLA notation is {1, 2} // The result has shape {3, 2}, where md is broadcast over m - auto m = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto md = builder.ConstantR2({{10.0f}, {20.0f}}); - builder.Add(m, md); + auto m = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto md = ConstantR2(&builder, {{10.0f}, {20.0f}}); + Add(m, md); Array2D expected_array( {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2501,9 +2596,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) { // a's shape in XLA notation is {1, 4} // b's shape in XLA notation is {3, 1} // The result has shape {3, 4}. - auto a = builder.ConstantR2({{0.0f}, {10.0f}, {20.0f}, {30.0f}}); - auto b = builder.ConstantR2({{1.0f, 2.0f, 3.0f}}); - builder.Add(a, b); + auto a = ConstantR2(&builder, {{0.0f}, {10.0f}, {20.0f}, {30.0f}}); + auto b = ConstantR2(&builder, {{1.0f, 2.0f, 3.0f}}); + Add(a, b); Array2D expected_array({{1.0f, 2.0f, 3.0f}, {11.0f, 12.0f, 13.0f}, {21.0f, 22.0f, 23.0f}, @@ -2515,9 +2610,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) { // Add together a (2,2) array and a (2) array, using dimension 0 for // broadcasting (though there are two ways to broadcast these shapes). XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({20.0f, 40.0f}); - auto m = builder.ConstantR2({{10.0f, 50.0f}, {77.0f, 88.0f}}); - builder.Add(v, m, /*broadcast_dimensions=*/{1}); + auto v = ConstantR1(&builder, {20.0f, 40.0f}); + auto m = ConstantR2(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}}); + Add(v, m, /*broadcast_dimensions=*/{1}); Array2D expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } @@ -2526,9 +2621,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) { // Add together a (2,2) array and a (2) array, using dimension 1 for // broadcasting (though there are two ways to broadcast these shapes). XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({20.0f, 40.0f}); - auto m = builder.ConstantR2({{10.0f, 50.0f}, {77.0f, 88.0f}}); - builder.Add(v, m, /*broadcast_dimensions=*/{0}); + auto v = ConstantR1(&builder, {20.0f, 40.0f}); + auto m = ConstantR2(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}}); + Add(v, m, /*broadcast_dimensions=*/{0}); Array2D expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } @@ -2538,12 +2633,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) { XlaBuilder builder(TestName()); Array3D a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); - auto a = builder.ConstantR3FromArray3D(a_3d); + auto a = ConstantR3FromArray3D(&builder, a_3d); Array3D b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}}, {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}}); - auto b = builder.ConstantR3FromArray3D(b_3d); - builder.Add(a, b); + auto b = ConstantR3FromArray3D(&builder, b_3d); + Add(a, b); Array3D expected_3d( {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}}, @@ -2565,9 +2660,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) { {11.0f, 12.0f}}, }); // clang-format on - auto a = builder.ConstantR3FromArray3D(a_3d); - auto v = builder.ConstantR1({10.0f, 20.0f}); - builder.Add(a, v, /*broadcast_dimensions=*/{2}); + auto a = ConstantR3FromArray3D(&builder, a_3d); + auto v = ConstantR1(&builder, {10.0f, 20.0f}); + Add(a, v, /*broadcast_dimensions=*/{2}); Array3D expected_3d( {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}}, @@ -2589,9 +2684,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) { {11.0f, 12.0f}}, }); // clang-format on - auto a = builder.ConstantR3FromArray3D(a_3d); - auto v = builder.ConstantR1({10.0f, 20.0f}); - builder.Add(a, v, /*broadcast_dimensions=*/{0}); + auto a = ConstantR3FromArray3D(&builder, a_3d); + auto v = ConstantR1(&builder, {10.0f, 20.0f}); + Add(a, v, /*broadcast_dimensions=*/{0}); // clang-format off Array3D expected_3d({ @@ -2619,12 +2714,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) { {9.0f, 10.0f}, {11.0f, 12.0f}}, }); - auto a = builder.ConstantR3FromArray3D(a_3d); - auto m = builder.ConstantR2({ + auto a = ConstantR3FromArray3D(&builder, a_3d); + auto m = ConstantR2(&builder, { {10.0f, 20.0f, 30.0f}, {40.0f, 50.0f, 60.0f}, }); - builder.Add(a, m, /*broadcast_dimensions=*/{0, 1}); + Add(a, m, /*broadcast_dimensions=*/{0, 1}); Array3D expected_3d({ {{11.0f, 12.0f}, @@ -2644,12 +2739,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { XlaBuilder builder(TestName()); Array3D a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); - auto a = builder.ConstantR3FromArray3D(a_3d); + auto a = ConstantR3FromArray3D(&builder, a_3d); Array3D b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}}); - auto b = builder.ConstantR3FromArray3D(b_3d); + auto b = ConstantR3FromArray3D(&builder, b_3d); - builder.Gt(a, b); + Gt(a, b); Array3D expected_3d( {{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}}); @@ -2684,9 +2779,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) { } } - auto a = builder.ConstantR4FromArray4D(*operand_a_4d); - auto b = builder.ConstantR4FromArray4D(*operand_b_4d); - builder.Add(a, b); + auto a = ConstantR4FromArray4D(&builder, *operand_a_4d); + auto b = ConstantR4FromArray4D(&builder, *operand_b_4d); + Add(a, b); ComputeAndCompareR4(&builder, *expected_4d, {}, error_spec_); } @@ -2712,9 +2807,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) { } } - auto a = builder.ConstantR4FromArray4D(*operand_a_4d); - auto b = builder.ConstantR1(operand_b_1d); - builder.Add(a, b, {1}); + auto a = ConstantR4FromArray4D(&builder, *operand_a_4d); + auto b = ConstantR1(&builder, operand_b_1d); + Add(a, b, {1}); ComputeAndCompareR4(&builder, *expected_4d, {}, error_spec_); } @@ -2730,11 +2825,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { std::iota(r1.begin(), r1.end(), 1.0); XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR4FromArray4DWithLayout( - r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); - auto a = builder.ConstantLiteral(*a_literal); - auto b = builder.ConstantR1(r1); - builder.Add(a, b, {1}); + std::unique_ptr a_literal = + LiteralUtil::CreateR4FromArray4DWithLayout( + r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); + auto a = ConstantLiteral(&builder, *a_literal); + auto b = ConstantR1(&builder, r1); + Add(a, b, {1}); for (int i0 = 0; i0 < d0; ++i0) { for (int i1 = 0; i1 < d1; ++i1) { @@ -2752,22 +2848,22 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { XlaBuilder builder(TestName()); auto shape = ShapeUtil::MakeOpaqueShape(); - auto x = builder.Parameter(0, shape, "x"); - builder.Add(x, x); + auto x = Parameter(&builder, 0, shape, "x"); + Add(x, x); auto computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), ::testing::ContainsRegex( - "Expected non-opaque argument for lhs of binary operation")); + "Expected array argument for lhs of binary operation")); } XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto b = - builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - builder.Add(a, b, /*broadcast_dimensions=*/{0, 1}); + auto a = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto b = ConstantR2(&builder, + {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); + Add(a, b, /*broadcast_dimensions=*/{0, 1}); Array2D expected_array( {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); @@ -2776,11 +2872,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) { XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); - auto b = - builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - builder.Add(a, b, /*broadcast_dimensions=*/{1, 0}); + auto a = ConstantR2(&builder, + {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto b = ConstantR2(&builder, + {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); + Add(a, b, /*broadcast_dimensions=*/{1, 0}); auto computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); @@ -2792,15 +2888,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) { // broadcast. XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { XlaBuilder builder(TestName()); - auto x_literal = Literal::CreateR1({1, 2, 3}); - auto y_literal = Literal::CreateR1({4, 5}); + auto x_literal = LiteralUtil::CreateR1({1, 2, 3}); + auto y_literal = LiteralUtil::CreateR1({4, 5}); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); - auto x = builder.Parameter(0, x_literal->shape(), "x"); - auto y = builder.Parameter(1, y_literal->shape(), "y"); - auto slice = builder.Slice(x, {1}, {2}, {1}); - builder.Sub(slice, y); + auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto y = Parameter(&builder, 1, y_literal->shape(), "y"); + auto slice = Slice(x, {1}, {2}, {1}); + Sub(slice, y); ComputeAndCompareR1(&builder, {-2, -3}, {x_data.get(), y_data.get()}, error_spec_); diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index fcd9ff55e393f64476ddd4754e0fa74427f1cb51..caeb0bf49a0dde9eeac02037b2ea04fd024d100c 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -29,10 +29,10 @@ class AxpySimpleTest : public ClientLibraryTestBase {}; TEST_F(AxpySimpleTest, AxTenValues) { XlaBuilder builder("ax_10"); - auto alpha = builder.ConstantR0(3.1415926535); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - builder.Mul(alpha, x); + auto alpha = ConstantR0(&builder, 3.1415926535); + auto x = ConstantR1( + &builder, {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); + Mul(alpha, x); std::vector expected = { -3.14159265, 3.14159265, 6.28318531, -6.28318531, -9.42477796, @@ -42,11 +42,11 @@ TEST_F(AxpySimpleTest, AxTenValues) { XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) { XlaBuilder builder("axpy_10"); - auto alpha = builder.ConstantR0(3.1415926535); - auto x = builder.ConstantR1({}); - auto y = builder.ConstantR1({}); - auto ax = builder.Mul(alpha, x); - builder.Add(ax, y); + auto alpha = ConstantR0(&builder, 3.1415926535); + auto x = ConstantR1(&builder, {}); + auto y = ConstantR1(&builder, {}); + auto ax = Mul(alpha, x); + Add(ax, y); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -54,13 +54,13 @@ XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) { TEST_F(AxpySimpleTest, AxpyTenValues) { XlaBuilder builder("axpy_10"); - auto alpha = builder.ConstantR0(3.1415926535); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto y = builder.ConstantR1( - {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); - auto ax = builder.Mul(alpha, x); - builder.Add(ax, y); + auto alpha = ConstantR0(&builder, 3.1415926535); + auto x = ConstantR1( + &builder, {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); + auto y = ConstantR1( + &builder, {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); + auto ax = Mul(alpha, x); + Add(ax, y); TF_ASSERT_OK_AND_ASSIGN(ProgramShape shape, builder.GetProgramShape()); diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc index 22c3394e6f34bd018ffaaaa4d9d68339673c3764..af0b8522394a0c591e6c42ad12db8853ef66243c 100644 --- a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc +++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -35,10 +35,10 @@ class BadRngShapeValidationTest : public ClientLibraryTestBase {}; TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) { XlaBuilder builder(TestName()); - auto zero = builder.ConstantR0(0.0); - auto one = builder.ConstantR0(1.0); + auto zero = ConstantR0(&builder, 0.0); + auto one = ConstantR0(&builder, 1.0); Shape default_constructed; - builder.RngUniform(zero, one, default_constructed); + RngUniform(zero, one, default_constructed); StatusOr computation = builder.Build(); EXPECT_FALSE(computation.ok()); @@ -49,13 +49,13 @@ TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) { TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) { XlaBuilder builder(TestName()); - auto zero = builder.ConstantR0(0.0); - auto one = builder.ConstantR0(1.0); + auto zero = ConstantR0(&builder, 0.0); + auto one = ConstantR0(&builder, 1.0); Shape sans_layout; sans_layout.set_element_type(F32); sans_layout.add_dimensions(1); - builder.RngUniform(zero, one, sans_layout); + RngUniform(zero, one, sans_layout); StatusOr computation = builder.Build(); ASSERT_TRUE(computation.ok()); diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index f3dac75a44b948c4b45b80b93e7462073010979e..24b17b71007a1872462bed1f6b86ae1a5bb9922c 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -20,10 +20,11 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -62,7 +63,7 @@ class BatchNormalizationTest {5.0f, 4.4f}, // p2 }); input_array_.FillWithPZ(pz); - input_literal_ = std::move(*Literal::CreateR4FromArray4D(input_array_)); + input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_)); CHECK_EQ(kSamples, input_array_.planes()); CHECK_EQ(kZ, input_array_.depth()); CHECK_EQ(kY, input_array_.height()); @@ -101,9 +102,9 @@ INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest, XLA_TEST_P(BatchNormalizationTest, SubtractInZ) { XlaBuilder builder("subtract_in_z_one_sample"); - auto x = builder.ConstantLiteral(input_literal_); - auto y = builder.ConstantR1({3.14, 4.25}); - builder.Sub(x, y, /*broadcast_dimensions=*/{1}); + auto x = ConstantLiteral(&builder, input_literal_); + auto y = ConstantR1(&builder, {3.14, 4.25}); + Sub(x, y, /*broadcast_dimensions=*/{1}); Array4D expected(kSamples, kZ, kY, kX); Array2D pz({ @@ -117,8 +118,8 @@ XLA_TEST_P(BatchNormalizationTest, SubtractInZ) { XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) { XlaBuilder builder("square_tesseract_elementwise"); - auto x = builder.ConstantLiteral(input_literal_); - builder.SquareF32(x); + auto x = ConstantLiteral(&builder, input_literal_); + Square(x); using tensorflow::MathUtil; @@ -134,11 +135,10 @@ XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) { XLA_TEST_P(BatchNormalizationTest, SumToZ) { XlaBuilder builder("sum_to_z"); - auto input_activations = builder.ConstantLiteral(input_literal_); + auto input_activations = ConstantLiteral(&builder, input_literal_); XlaComputation add = CreateScalarAddComputation(F32, &builder); // Reduce all but the Z dimension. - builder.Reduce(input_activations, builder.ConstantR0(0.0f), add, - {0, 2, 3}); + Reduce(input_activations, ConstantR0(&builder, 0.0f), add, {0, 2, 3}); std::vector expected = {6, 12.6}; ComputeAndCompareR1(&builder, expected, {}, error_spec_); @@ -146,13 +146,13 @@ XLA_TEST_P(BatchNormalizationTest, SumToZ) { XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) { XlaBuilder builder("square_and_reduce"); - auto input_activations = builder.ConstantLiteral(input_literal_); - auto set_means = builder.ConstantR1({2.f, 4.2f}); - auto activation_deviations = builder.Sub(input_activations, set_means, - /*broadcast_dimensions=*/{1}); + auto input_activations = ConstantLiteral(&builder, input_literal_); + auto set_means = ConstantR1(&builder, {2.f, 4.2f}); + auto activation_deviations = Sub(input_activations, set_means, + /*broadcast_dimensions=*/{1}); XlaComputation add = CreateScalarAddComputation(F32, &builder); - auto dev_squares = builder.SquareF32(activation_deviations); - builder.Reduce(dev_squares, builder.ConstantR0(0.0f), add, {0, 2, 3}); + auto dev_squares = Square(activation_deviations); + Reduce(dev_squares, ConstantR0(&builder, 0.0f), add, {0, 2, 3}); std::vector expected = {18, 0.06}; ComputeAndCompareR1(&builder, expected, {}, error_spec_); @@ -160,8 +160,8 @@ XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) { XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) { XlaBuilder builder("variance_to_stddev"); - auto variance = builder.ConstantR1({6.f, .02f}); - builder.SqrtF32(variance); + auto variance = ConstantR1(&builder, {6.f, .02f}); + Sqrt(variance); std::vector expected = {2.44948974f, 0.14142136f}; ComputeAndCompareR1(&builder, expected, {}, error_spec_); @@ -172,50 +172,50 @@ XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) { XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) { XlaBuilder builder("batch_normalize_per_spec"); auto input_activations = - CheckShape(&builder, builder.ConstantLiteral(input_literal_), + CheckShape(&builder, ConstantLiteral(&builder, input_literal_), ShapeUtil::MakeShape(F32, {3, 2, 1, 1})); - auto gamma = builder.ConstantR1({1.0, 1.0}); - auto beta = builder.ConstantR1({0.0, 0.0}); + auto gamma = ConstantR1(&builder, {1.0, 1.0}); + auto beta = ConstantR1(&builder, {0.0, 0.0}); XlaComputation add = CreateScalarAddComputation(F32, &builder); // Reduce all dimensions except dimension 1. Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2}); auto sum = CheckShape( &builder, - builder.Reduce(input_activations, builder.ConstantR0(0.0f), add, - /*dimensions_to_reduce=*/{0, 2, 3}), + Reduce(input_activations, ConstantR0(&builder, 0.0f), add, + /*dimensions_to_reduce=*/{0, 2, 3}), TwoElementVectorF32); auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie(); auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie(); - auto count = builder.ConstantR0(ShapeUtil::ElementsIn(input_shape) / - ShapeUtil::ElementsIn(sum_shape)); - auto set_means = builder.Div(sum, count); + auto count = + ConstantR0(&builder, ShapeUtil::ElementsIn(input_shape) / + ShapeUtil::ElementsIn(sum_shape)); + auto set_means = Div(sum, count); const float kEpsilon = 1e-9f; - auto epsilon = builder.ConstantR0(kEpsilon); - auto epsilon2 = builder.ConstantR1({kEpsilon, kEpsilon}); - auto activation_deviations = builder.Sub(input_activations, set_means, - /*broadcast_dimensions=*/{1}); - auto dev_squares = builder.SquareF32(activation_deviations); - auto sum_of_squares = CheckShape( - &builder, - builder.Reduce(dev_squares, builder.ConstantR0(0.0f), add, - /*dimensions_to_reduce=*/{0, 2, 3}), - TwoElementVectorF32); - auto variance = builder.Div(sum_of_squares, count); - auto standard_deviation = builder.SqrtF32(variance); + auto epsilon = ConstantR0(&builder, kEpsilon); + auto epsilon2 = ConstantR1(&builder, {kEpsilon, kEpsilon}); + auto activation_deviations = Sub(input_activations, set_means, + /*broadcast_dimensions=*/{1}); + auto dev_squares = Square(activation_deviations); + auto sum_of_squares = + CheckShape(&builder, + Reduce(dev_squares, ConstantR0(&builder, 0.0f), add, + /*dimensions_to_reduce=*/{0, 2, 3}), + TwoElementVectorF32); + auto variance = Div(sum_of_squares, count); + auto standard_deviation = Sqrt(variance); auto standard_deviation_above_epsilon = - CheckShape(&builder, builder.Gt(standard_deviation, epsilon), + CheckShape(&builder, Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2})); - auto gt_eps = builder.Select(standard_deviation_above_epsilon, - standard_deviation, epsilon2); - auto normalization_factors = builder.ReciprocalF32(gt_eps); + auto gt_eps = + Select(standard_deviation_above_epsilon, standard_deviation, epsilon2); + auto normalization_factors = Reciprocal(gt_eps); auto normalized_input_activations = - builder.Mul(activation_deviations, normalization_factors, - /*broadcast_dimensions=*/{1}); - /* auto output_activations = */ builder.Add( - builder.Mul(normalized_input_activations, gamma, - /*broadcast_dimensions=*/{1}), - beta, /*broadcast_dimensions=*/{1}); + Mul(activation_deviations, normalization_factors, + /*broadcast_dimensions=*/{1}); + /* auto output_activations = */ Add(Mul(normalized_input_activations, gamma, + /*broadcast_dimensions=*/{1}), + beta, /*broadcast_dimensions=*/{1}); Array4D expected(kSamples, kZ, kY, kX); Array2D pz({ @@ -232,46 +232,47 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) { const int kFeatureIndex = 3; XlaBuilder builder(TestName()); - auto operand = builder.ConstantR4FromArray4D( - {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}}); + auto operand = ConstantR4FromArray4D( + &builder, {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}}); - auto scale = builder.ConstantR1({2.0f, 3.0f}); + auto scale = ConstantR1(&builder, {2.0f, 3.0f}); - auto offset = builder.ConstantR1({1.0f, 2.0f}); + auto offset = ConstantR1(&builder, {1.0f, 2.0f}); - builder.BatchNormTraining(operand, scale, offset, - /*epsilon=*/0.001, kFeatureIndex); + BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, - {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, + {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) .get(), - Literal::CreateR1({4, 5}).get(), - Literal::CreateR1({5, 5}).get()}); + LiteralUtil::CreateR1({4, 5}).get(), + LiteralUtil::CreateR1({5, 5}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); } -XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) { +XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); - auto operand = builder.ConstantR4FromArray4D( + auto operand = ConstantR4FromArray4D( + &builder, {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); - auto scale = builder.ConstantR1({2.0f, 3.0f}); + auto scale = ConstantR1(&builder, {2.0f, 3.0f}); - auto offset = builder.ConstantR1({1.0f, 2.0f}); + auto offset = ConstantR1(&builder, {1.0f, 2.0f}); - builder.BatchNormTraining(operand, scale, offset, - /*epsilon=*/0.001, kFeatureIndex); + BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, - {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, + {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) .get(), - Literal::CreateR1({4, 5}).get(), - Literal::CreateR1({5, 5}).get()}); + LiteralUtil::CreateR1({4, 5}).get(), + LiteralUtil::CreateR1({5, 5}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); } @@ -294,14 +295,14 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { CreateR1Parameter(std::vector(260, 1.0f), /*parameter_number=*/2, "offset", &builder, &h2); - builder.BatchNormTraining(h0, h1, h2, - /*epsilon=*/1, kFeatureIndex); + BatchNormTraining(h0, h1, h2, + /*epsilon=*/1, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) .get(), - Literal::CreateR1(std::vector(260, 1.0f)).get(), - Literal::CreateR1(std::vector(260, 0.0f)).get()}); + LiteralUtil::CreateR1(std::vector(260, 1.0f)).get(), + LiteralUtil::CreateR1(std::vector(260, 0.0f)).get()}); ComputeAndCompareTuple(&builder, *expected, {operand.get(), scale.get(), offset.get()}, @@ -327,14 +328,15 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) { /*parameter_number=*/2, "offset", &builder, &h2); // var = 125, mean = 15, epsilon = -100 - builder.BatchNormTraining(h0, h1, h2, - /*epsilon=*/-100, kFeatureIndex); + BatchNormTraining(h0, h1, h2, + /*epsilon=*/-100, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR3FromArray3D({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR3FromArray3D( + {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) .get(), - Literal::CreateR1(std::vector(1, 15.0f)).get(), - Literal::CreateR1(std::vector(1, 125.0f)).get()}); + LiteralUtil::CreateR1(std::vector(1, 15.0f)).get(), + LiteralUtil::CreateR1(std::vector(1, 125.0f)).get()}); ComputeAndCompareTuple(&builder, *expected, {operand.get(), scale.get(), offset.get()}, @@ -346,26 +348,27 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) { XlaBuilder builder(TestName()); auto operand = - builder.ConstantR4FromArray4D(Array4D(2, 2, 2, 1, 0.0f)); + ConstantR4FromArray4D(&builder, Array4D(2, 2, 2, 1, 0.0f)); - auto scale = builder.ConstantR1({1.0f, 1.0f}); + auto scale = ConstantR1(&builder, {1.0f, 1.0f}); - auto mean = builder.ConstantR1({0.0f, 0.0f}); + auto mean = ConstantR1(&builder, {0.0f, 0.0f}); - auto var = builder.ConstantR1({1.0f, 1.0f}); + auto var = ConstantR1(&builder, {1.0f, 1.0f}); - auto grad_output = builder.ConstantR4FromArray4D( + auto grad_output = ConstantR4FromArray4D( + &builder, {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); - builder.BatchNormGrad(operand, scale, mean, var, grad_output, - /*epsilon=*/0.0, kFeatureIndex); + BatchNormGrad(operand, scale, mean, var, grad_output, + /*epsilon=*/0.0, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR4({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, - {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}) + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR4({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, + {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}) .get(), - Literal::CreateR1({0, 0}).get(), - Literal::CreateR1({16, 20}).get()}); + LiteralUtil::CreateR1({0, 0}).get(), + LiteralUtil::CreateR1({16, 20}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); } @@ -511,22 +514,23 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D, scale4D, offset4D, epsilon); - auto expected_normalized = Literal::CreateR4FromArray4D(normalized); + auto expected_normalized = + LiteralUtil::CreateR4FromArray4D(normalized); - auto offset_literal = Literal::CreateR1(offset); - auto scale_literal = Literal::CreateR1(scale); - auto input_literal = Literal::CreateR4FromArray4D(input_array); + auto offset_literal = LiteralUtil::CreateR1(offset); + auto scale_literal = LiteralUtil::CreateR1(scale); + auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); auto input_activations = - builder.Parameter(0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal->shape(), "input"); auto scale_activations = - builder.Parameter(1, scale_literal->shape(), "offset"); + Parameter(&builder, 1, scale_literal->shape(), "offset"); auto offset_activations = - builder.Parameter(2, offset_literal->shape(), "scale"); + Parameter(&builder, 2, offset_literal->shape(), "scale"); - auto expected = Literal::MakeTuple({expected_normalized.get(), - Literal::CreateR1(mean).get(), - Literal::CreateR1(var).get()}); + auto expected = LiteralUtil::MakeTuple( + {expected_normalized.get(), LiteralUtil::CreateR1(mean).get(), + LiteralUtil::CreateR1(var).get()}); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -535,8 +539,8 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { std::unique_ptr offset_data = client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); - builder.BatchNormTraining(input_activations, scale_activations, - offset_activations, epsilon, feature_index); + BatchNormTraining(input_activations, scale_activations, offset_activations, + epsilon, feature_index); // Run all HLO passes during this test. In particular, ClientLibraryTestBase // disables constant folding, but we want it enabled for our zero-sized tensor @@ -611,21 +615,21 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D, scale4D, offset4D, epsilon); - auto offset_literal = Literal::CreateR1(offset); - auto scale_literal = Literal::CreateR1(scale); - auto mean_literal = Literal::CreateR1(mean); - auto var_literal = Literal::CreateR1(var); - auto input_literal = Literal::CreateR4FromArray4D(input_array); + auto offset_literal = LiteralUtil::CreateR1(offset); + auto scale_literal = LiteralUtil::CreateR1(scale); + auto mean_literal = LiteralUtil::CreateR1(mean); + auto var_literal = LiteralUtil::CreateR1(var); + auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); auto input_activations = - builder.Parameter(0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal->shape(), "input"); auto scale_activations = - builder.Parameter(1, scale_literal->shape(), "offset"); + Parameter(&builder, 1, scale_literal->shape(), "offset"); auto offset_activations = - builder.Parameter(2, offset_literal->shape(), "scale"); - auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean"); + Parameter(&builder, 2, offset_literal->shape(), "scale"); + auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean"); auto variance_activations = - builder.Parameter(4, var_literal->shape(), "variance"); + Parameter(&builder, 4, var_literal->shape(), "variance"); Array4D expected = normalized; @@ -640,9 +644,9 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { std::unique_ptr variance_data = client_->TransferToServer(*var_literal).ConsumeValueOrDie(); - builder.BatchNormInference(input_activations, scale_activations, - offset_activations, mean_activations, - variance_activations, epsilon, feature_index); + BatchNormInference(input_activations, scale_activations, offset_activations, + mean_activations, variance_activations, epsilon, + feature_index); // Run all HLO passes during this test. In particular, ClientLibraryTestBase // disables constant folding, but we want it enabled for our zero-sized tensor @@ -729,7 +733,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { var4D, [epsilon](float a) { return a + epsilon; }); auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D( - var_add_epsilon, [epsilon](float a) { return 1 / std::sqrt(a); }); + var_add_epsilon, [](float a) { return 1 / std::sqrt(a); }); auto grad_output_times_var = *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon, @@ -798,21 +802,23 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { }); auto expected_grad_activation = - Literal::CreateR4FromArray4D(grad_activation); + LiteralUtil::CreateR4FromArray4D(grad_activation); - auto input_literal = Literal::CreateR4FromArray4D(input_array); - auto scale_literal = Literal::CreateR1(scale); - auto mean_literal = Literal::CreateR1(mean); - auto var_literal = Literal::CreateR1(var); + auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); + auto scale_literal = LiteralUtil::CreateR1(scale); + auto mean_literal = LiteralUtil::CreateR1(mean); + auto var_literal = LiteralUtil::CreateR1(var); auto grad_output_literal = - Literal::CreateR4FromArray4D(grad_output_array); - - auto input_parameter = builder.Parameter(0, input_literal->shape(), "input"); - auto scale_parameter = builder.Parameter(1, scale_literal->shape(), "scale"); - auto mean_parameter = builder.Parameter(2, mean_literal->shape(), "mean"); - auto var_parameter = builder.Parameter(3, var_literal->shape(), "variance"); + LiteralUtil::CreateR4FromArray4D(grad_output_array); + + auto input_parameter = + Parameter(&builder, 0, input_literal->shape(), "input"); + auto scale_parameter = + Parameter(&builder, 1, scale_literal->shape(), "scale"); + auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean"); + auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance"); auto grad_output_parameter = - builder.Parameter(4, grad_output_literal->shape(), "grad_output"); + Parameter(&builder, 4, grad_output_literal->shape(), "grad_output"); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -825,14 +831,13 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { std::unique_ptr grad_output_data = client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie(); - builder.BatchNormGrad(input_parameter, scale_parameter, mean_parameter, - var_parameter, grad_output_parameter, epsilon, - feature_index); + BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter, + grad_output_parameter, epsilon, feature_index); auto expected = - Literal::MakeTuple({expected_grad_activation.get(), - Literal::CreateR1(grad_scale).get(), - Literal::CreateR1(grad_offset).get()}); + LiteralUtil::MakeTuple({expected_grad_activation.get(), + LiteralUtil::CreateR1(grad_scale).get(), + LiteralUtil::CreateR1(grad_offset).get()}); // Run all HLO passes during this test. In particular, ClientLibraryTestBase // disables constant folding, but we want it enabled for our zero-sized tensor diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index ca337e78840e77377719636cd4cf33af2578210d..6c20f654fe3df6a28e9633cd832c11b487894bad 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -51,9 +51,9 @@ class Bfloat16Test : public ClientLibraryTestBase { XLA_TEST_F(Bfloat16Test, ScalarOperation) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR0(static_cast(2.0f)); - auto y = builder.ConstantR0(static_cast(1.0f)); - builder.Add(x, y); + auto x = ConstantR0(&builder, static_cast(2.0f)); + auto y = ConstantR0(&builder, static_cast(1.0f)); + Add(x, y); ComputeAndCompareR0(&builder, static_cast(3.0f), {}, error_spec_); @@ -61,8 +61,8 @@ XLA_TEST_F(Bfloat16Test, ScalarOperation) { XLA_TEST_F(Bfloat16Test, LogOperation) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR0(static_cast(4.0f)); - builder.Log(x); + auto x = ConstantR0(&builder, static_cast(4.0f)); + Log(x); ComputeAndCompareR0(&builder, static_cast(1.387f), {}, error_spec_); @@ -70,7 +70,7 @@ XLA_TEST_F(Bfloat16Test, LogOperation) { XLA_TEST_F(Bfloat16Test, NegateScalarF16) { XlaBuilder builder(TestName()); - builder.Neg(builder.ConstantR0(static_cast(2.1f))); + Neg(ConstantR0(&builder, static_cast(2.1f))); ComputeAndCompareR0(&builder, static_cast(-2.1f), {}, error_spec_); @@ -80,33 +80,33 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); - auto operand = builder.ConstantR4FromArray4D( + auto operand = ConstantR4FromArray4D( + &builder, {{{{static_cast(1.f)}, {static_cast(2.f)}}, {{static_cast(3.f)}, {static_cast(4.f)}}}, {{{static_cast(5.f)}, {static_cast(6.f)}}, {{static_cast(7.f)}, {static_cast(8.f)}}}}); - auto scale = builder.ConstantR1( - {static_cast(2.0f), static_cast(3.0f)}); + auto scale = ConstantR1( + &builder, {static_cast(2.0f), static_cast(3.0f)}); - auto offset = builder.ConstantR1( - {static_cast(1.0f), static_cast(2.0f)}); + auto offset = ConstantR1( + &builder, {static_cast(1.0f), static_cast(2.0f)}); - auto tuple = builder.BatchNormTraining(operand, scale, offset, - /*epsilon=*/0.001, kFeatureIndex); + BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR4( + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR4( {{{{static_cast(-1.6875f)}, {static_cast(-2.04f)}}, {{static_cast(0.105f)}, {static_cast(0.66f)}}}, {{{static_cast(1.89f)}, {static_cast(3.35f)}}, {{static_cast(3.7f)}, {static_cast(6.04f)}}}}) .get(), - Literal::CreateR1( + LiteralUtil::CreateR1( {static_cast(4), static_cast(5)}) .get(), - Literal::CreateR1( + LiteralUtil::CreateR1( {static_cast(5), static_cast(5)}) .get()}); @@ -117,38 +117,39 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) { const int kFeatureIndex = 2; XlaBuilder builder(TestName()); - auto operand = builder.ConstantR4FromArray4D( - Array4D(2, 2, 2, 1, static_cast(0.0f))); + auto operand = ConstantR4FromArray4D( + &builder, Array4D(2, 2, 2, 1, static_cast(0.0f))); - auto scale = builder.ConstantR1( - {static_cast(1.0f), static_cast(1.0f)}); + auto scale = ConstantR1( + &builder, {static_cast(1.0f), static_cast(1.0f)}); - auto mean = builder.ConstantR1( - {static_cast(0.0f), static_cast(0.0f)}); + auto mean = ConstantR1( + &builder, {static_cast(0.0f), static_cast(0.0f)}); - auto var = builder.ConstantR1( - {static_cast(1.0f), static_cast(1.0f)}); + auto var = ConstantR1( + &builder, {static_cast(1.0f), static_cast(1.0f)}); - auto grad_output = builder.ConstantR4FromArray4D( + auto grad_output = ConstantR4FromArray4D( + &builder, {{{{static_cast(1.f)}, {static_cast(2.f)}}, {{static_cast(3.f)}, {static_cast(4.f)}}}, {{{static_cast(5.f)}, {static_cast(6.f)}}, {{static_cast(7.f)}, {static_cast(8.f)}}}}); - builder.BatchNormGrad(operand, scale, mean, var, grad_output, - /*epsilon=*/0.0, kFeatureIndex); + BatchNormGrad(operand, scale, mean, var, grad_output, + /*epsilon=*/0.0, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR4( + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR4( {{{{static_cast(-3.f)}, {static_cast(-3.f)}}, {{static_cast(-1.f)}, {static_cast(-1.f)}}}, {{{static_cast(1.f)}, {static_cast(1.f)}}, {{static_cast(3.f)}, {static_cast(3.f)}}}}) .get(), - Literal::CreateR1( + LiteralUtil::CreateR1( {static_cast(0), static_cast(0)}) .get(), - Literal::CreateR1( + LiteralUtil::CreateR1( {static_cast(16), static_cast(20)}) .get()}); diff --git a/tensorflow/compiler/xla/tests/binop_scaling_test.cc b/tensorflow/compiler/xla/tests/binop_scaling_test.cc index 48203b1d40ea69ff00a57c2c9e42620739b23d59..0d7a3aa46a9c12c19d954c11ae3a2cccbed886ef 100644 --- a/tensorflow/compiler/xla/tests/binop_scaling_test.cc +++ b/tensorflow/compiler/xla/tests/binop_scaling_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -33,9 +33,9 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_32x4) { auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 4); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR2FromArray2D(*alhs); - auto rhs = builder.ConstantR2FromArray2D(*arhs); - builder.Add(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, *alhs); + auto rhs = ConstantR2FromArray2D(&builder, *arhs); + Add(lhs, rhs); auto aexpected = ReferenceUtil::MapWithIndexArray2D( *alhs, [&](float lhs_value, int64 row, int64 col) { @@ -49,9 +49,9 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_129x129) { auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 129); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR2FromArray2D(*alhs); - auto rhs = builder.ConstantR2FromArray2D(*arhs); - builder.Add(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, *alhs); + auto rhs = ConstantR2FromArray2D(&builder, *arhs); + Add(lhs, rhs); auto aexpected = ReferenceUtil::MapWithIndexArray2D( *alhs, [&](float lhs_value, int64 row, int64 col) { @@ -65,9 +65,9 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_9x5) { auto arhs = MakeLinspaceArray2D(0.0, 1.0, 9, 1); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR2FromArray2D(*alhs); - auto rhs = builder.ConstantR2FromArray2D(*arhs); - builder.Add(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, *alhs); + auto rhs = ConstantR2FromArray2D(&builder, *arhs); + Add(lhs, rhs); auto aexpected = ReferenceUtil::MapWithIndexArray2D( *alhs, [&](float lhs_value, int64 row, int64 col) { @@ -81,9 +81,9 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_129x257) { auto arhs = MakeLinspaceArray2D(0.0, 1.0, 129, 1); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR2FromArray2D(*alhs); - auto rhs = builder.ConstantR2FromArray2D(*arhs); - builder.Add(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, *alhs); + auto rhs = ConstantR2FromArray2D(&builder, *arhs); + Add(lhs, rhs); auto aexpected = ReferenceUtil::MapWithIndexArray2D( *alhs, [&](float lhs_value, int64 row, int64 col) { @@ -94,11 +94,12 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_129x257) { TEST_F(BinopScalingTest, R0PlusR2F32) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR0(42.0); - auto rhs = builder.ConstantR2({ - {1.0, 2.0}, {3.0, 4.0}, - }); - builder.Add(lhs, rhs); + auto lhs = ConstantR0(&builder, 42.0); + auto rhs = ConstantR2(&builder, { + {1.0, 2.0}, + {3.0, 4.0}, + }); + Add(lhs, rhs); Array2D expected(2, 2); expected(0, 0) = 42.0 + 1.0; @@ -129,9 +130,9 @@ TEST_F(BinopScalingTest, R4PlusR0S32) { }); // clang-format on - auto lhs = builder.ConstantR4FromArray4D(lhs_array); - auto rhs = builder.ConstantR0(42); - builder.Add(lhs, rhs); + auto lhs = ConstantR4FromArray4D(&builder, lhs_array); + auto rhs = ConstantR0(&builder, 42); + Add(lhs, rhs); ComputeAndCompareR4(&builder, expected, {}); } diff --git a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc index bff60f25ec8f15d372d251ac313200301a04f20f..c6b5108fe9e5bcf843982676d822f1942359da71 100644 --- a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc +++ b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -43,8 +43,8 @@ class BitcastConvertTest : public ClientLibraryTestBase { TEST_F(BitcastConvertTest, ConvertR1S32ToR1S32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42, 64}); - builder.BitcastConvertType(a, S32); + auto a = ConstantR1(&builder, {42, 64}); + BitcastConvertType(a, S32); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); @@ -52,8 +52,8 @@ TEST_F(BitcastConvertTest, ConvertR1S32ToR1S32) { TEST_F(BitcastConvertTest, ConvertR1F32ToR1F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0f, 64.0f}); - builder.BitcastConvertType(a, F32); + auto a = ConstantR1(&builder, {42.0f, 64.0f}); + BitcastConvertType(a, F32); std::vector expected = {42.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}); @@ -62,10 +62,10 @@ TEST_F(BitcastConvertTest, ConvertR1F32ToR1F32) { TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) { XlaBuilder builder(TestName()); auto a = - builder.ConstantR1({0, static_cast(0x80000000), 0x3F800000, - static_cast(0xBF800000), 0x3F000000, - static_cast(0xBF000000)}); - builder.BitcastConvertType(a, F32); + ConstantR1(&builder, {0, static_cast(0x80000000), + 0x3F800000, static_cast(0xBF800000), + 0x3F000000, static_cast(0xBF000000)}); + BitcastConvertType(a, F32); std::vector expected = {0.0f, -0.0f, 1.0f, -1.0f, 0.5f, -0.5f}; ComputeAndCompareR1(&builder, expected, {}); @@ -73,8 +73,8 @@ TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) { XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.BitcastConvertType(a, F32); + auto a = ConstantR1(&builder, {}); + BitcastConvertType(a, F32); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}); @@ -82,8 +82,8 @@ XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) { TEST_F(BitcastConvertTest, ConvertR1F32ToR1S32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.6, 64.4}); - builder.BitcastConvertType(a, S32); + auto a = ConstantR1(&builder, {42.6, 64.4}); + BitcastConvertType(a, S32); std::vector expected = {0x422a6666, 0x4280cccd}; ComputeAndCompareR1(&builder, expected, {}); @@ -91,9 +91,9 @@ TEST_F(BitcastConvertTest, ConvertR1F32ToR1S32) { TEST_F(BitcastConvertTest, ConvertS32Extremes) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {std::numeric_limits::min(), std::numeric_limits::max()}); - builder.BitcastConvertType(a, F32); + auto a = ConstantR1(&builder, {std::numeric_limits::min(), + std::numeric_limits::max()}); + BitcastConvertType(a, F32); std::vector expected = {-0.0f, NAN}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0, 0)); @@ -102,10 +102,10 @@ TEST_F(BitcastConvertTest, ConvertS32Extremes) { TEST_F(BitcastConvertTest, ConvertMapToS32) { XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); - auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in"); - b->BitcastConvertType(param, S32); - auto a = builder.ConstantR1({42.0f, 64.0f}); - builder.Map({a}, b->BuildAndNoteError(), {0}); + auto param = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "in"); + BitcastConvertType(param, S32); + auto a = ConstantR1(&builder, {42.0f, 64.0f}); + Map(&builder, {a}, b->BuildAndNoteError(), {0}); std::vector expected = {0x42280000, 0x42800000}; ComputeAndCompareR1(&builder, expected, {}); @@ -114,10 +114,10 @@ TEST_F(BitcastConvertTest, ConvertMapToS32) { TEST_F(BitcastConvertTest, ConvertMapToF32) { XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); - auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in"); - b->BitcastConvertType(param, F32); - auto a = builder.ConstantR1({0x42280000, 0x42800000}); - builder.Map({a}, b->BuildAndNoteError(), {0}); + auto param = Parameter(b.get(), 0, ShapeUtil::MakeShape(S32, {}), "in"); + BitcastConvertType(param, F32); + auto a = ConstantR1(&builder, {0x42280000, 0x42800000}); + Map(&builder, {a}, b->BuildAndNoteError(), {0}); std::vector expected = {42.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}); @@ -130,9 +130,9 @@ TEST_F(BitcastConvertTest, ConvertMapToF32) { // the new convert should have the same element type as the old convert. TEST_F(BitcastConvertTest, ConvertReshape) { XlaBuilder builder(TestName()); - auto input = builder.ConstantR1({0x42280000}); - auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); - builder.BitcastConvertType(reshape, F32); + auto input = ConstantR1(&builder, {0x42280000}); + auto reshape = Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); + BitcastConvertType(reshape, F32); ComputeAndCompareR0(&builder, 42.0f, {}); } diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 3a0f51fc66d65c8684bd607b9e8103559cd4d8d4..1d28e85b16596b0ec2717138fb2081878203e8b2 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -20,7 +20,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -37,17 +38,17 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { XlaBuilder* builder) { switch (op) { case HloOpcode::kMinimum: { - return builder->Min(lhs, rhs); + return Min(lhs, rhs); } case HloOpcode::kMaximum: { - return builder->Max(lhs, rhs); + return Max(lhs, rhs); } case HloOpcode::kMultiply: { - return builder->Mul(lhs, rhs); + return Mul(lhs, rhs); } default: { // Default to Add - return builder->Add(lhs, rhs); + return Add(lhs, rhs); } } } @@ -58,7 +59,7 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { Array3D* r3_array, float start, float end, int seed) { *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r3_array->FillRandom(start, end, seed); - auto r3_data = Literal::CreateR3FromArray3D(*r3_array)->Relayout( + auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout( LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r3_global_data = client_->TransferToServer(*r3_data).ConsumeValueOrDie(); @@ -71,7 +72,7 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { Array2D* r2_array, float start, float end, int seed) { *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r2_array->FillRandom(start, end, seed); - auto r2_data = Literal::CreateR2FromArray2D(*r2_array)->Relayout( + auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout( LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r2_global_data = client_->TransferToServer(*r2_data).ConsumeValueOrDie(); @@ -104,13 +105,13 @@ using ::testing::HasSubstr; XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR0(1.5), {}); + Broadcast(ConstantR0(&b, 1.5), {}); ComputeAndCompareR0(&b, 1.5, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR0(2.25), {2, 3}); + Broadcast(ConstantR0(&b, 2.25), {2, 3}); Array2D expected(2, 3, 2.25); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } @@ -122,7 +123,7 @@ XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) { CreateR0Parameter(2.25f, /*parameter_number=*/0, /*name=*/"src", /*builder=*/&b, /*data_handle=*/&src); - b.Broadcast(src, {2, 3}); + Broadcast(src, {2, 3}); Array2D expected(2, 3, 2.25); ComputeAndCompareR2(&b, expected, {param_data.get()}, ErrorSpec(0.0001)); @@ -130,21 +131,21 @@ XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) { XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR0(2.25), {2, 0}); + Broadcast(ConstantR0(&b, 2.25), {2, 0}); Array2D expected(2, 0); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR0(2.25), {0, 2}); + Broadcast(ConstantR0(&b, 2.25), {0, 2}); Array2D expected(0, 2); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR1({1, 2, 3}), {2}); + Broadcast(ConstantR1(&b, {1, 2, 3}), {2}); Array2D expected(2, 3); expected(0, 0) = 1; @@ -156,6 +157,86 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } +XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR1(&b, {1, 2}), + ShapeUtil::MakeShape(F32, {2, 2}), {1}); + + Array2D expected(2, 2); + expected(0, 0) = 1; + expected(0, 1) = 2; + expected(1, 0) = 1; + expected(1, 1) = 2; + + ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR1(&b, {1, 2}), + ShapeUtil::MakeShape(F32, {2, 2}), {0}); + + Array2D expected(2, 2); + expected(0, 0) = 1; + expected(0, 1) = 1; + expected(1, 0) = 2; + expected(1, 1) = 2; + + ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), + ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 1}); + + Array3D expected(2, 2, 2); + expected(0, 0, 0) = 1.0; + expected(1, 0, 0) = 2.0; + expected(0, 0, 1) = 1.0; + expected(1, 0, 1) = 2.0; + expected(0, 1, 0) = 5.0; + expected(1, 1, 0) = 6.0; + expected(1, 1, 1) = 6.0; + expected(0, 1, 1) = 5.0; + + ComputeAndCompareR3(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), + ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 2}); + + Array3D expected(2, 2, 2); + expected(0, 0, 0) = 1.0; + expected(1, 0, 0) = 2.0; + expected(0, 0, 1) = 5.0; + expected(1, 0, 1) = 6.0; + expected(0, 1, 0) = 1.0; + expected(1, 1, 0) = 2.0; + expected(1, 1, 1) = 6.0; + expected(0, 1, 1) = 5.0; + + ComputeAndCompareR3(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsNotPossibleWithBroadCast) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR1(&b, {1, 2}), + ShapeUtil::MakeShape(F32, {3, 2}), {1}); + + Array2D expected(3, 2); + expected(0, 0) = 1; + expected(0, 1) = 2; + expected(1, 0) = 1; + expected(1, 1) = 2; + expected(2, 0) = 1; + expected(2, 1) = 2; + + ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); +} + // Tests implicit broadcasting of PREDs. XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { XlaBuilder b(TestName()); @@ -172,7 +253,7 @@ XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { XlaOp x, y; auto x_data = CreateR2Parameter(x_vals, 0, "x", &b, &x); auto y_data = CreateR3Parameter(y_vals, 1, "y", &b, &y); - b.And(x, y, /*broadcast_dimensions=*/{1, 2}); + And(x, y, /*broadcast_dimensions=*/{1, 2}); Array3D expected(2, 2, 1); expected(0, 0, 0) = false; @@ -185,7 +266,7 @@ XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR1({}), {2}); + Broadcast(ConstantR1(&b, {}), {2}); Array2D expected(2, 0); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); @@ -193,7 +274,7 @@ XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) { XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) { XlaBuilder b(TestName()); - b.Broadcast(b.ConstantR1({1, 2, 3}), {0}); + Broadcast(ConstantR1(&b, {1, 2, 3}), {0}); Array2D expected(0, 3); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); @@ -209,14 +290,14 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { // dimensions. XlaBuilder b(TestName()); - b.Add(b.ConstantR2({{1.0, 5.0}}), - b.ConstantLiteral(*Literal::CreateR3( - {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), - /*broadcast_dimensions=*/{1, 2}); + Add(ConstantR2(&b, {{1.0, 5.0}}), + ConstantLiteral(&b, *LiteralUtil::CreateR3( + {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), + /*broadcast_dimensions=*/{1, 2}); auto expected = - Literal::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, - {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); + LiteralUtil::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, + {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -260,9 +341,10 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { MakeR3Data(spec.input_bounds, spec.minor2major_layout, &r3_implicit_shape, &r3_implicit_array, 1.0, 0.2, 56789); - auto r3_implicit_parameter = builder.Parameter(0, r3_implicit_shape, "input"); - auto r3_parameter = builder.Parameter(1, r3_shape, "input"); - XlaOp op = BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder); + auto r3_implicit_parameter = + Parameter(&builder, 0, r3_implicit_shape, "input"); + auto r3_parameter = Parameter(&builder, 1, r3_shape, "input"); + BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder); Array3D expected_array(spec.output_bounds[0], spec.output_bounds[1], spec.output_bounds[2]); @@ -284,7 +366,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { } } } - auto expected = Literal::CreateR3FromArray3D(expected_array); + auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); ComputeAndCompareLiteral( &builder, *expected, {r3_implicit_global_data.get(), r3_global_data.get()}, @@ -306,10 +388,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { auto r1 = CreateR3Parameter(r1d, 1, "r1", &b, &r1h); auto r3 = CreateR3Parameter(r3d, 0, "r3", &b, &r3h); - b.Add(r3h, r1h); + Add(r3h, r1h); auto expected = - Literal::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); + LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()}, ErrorSpec(0.0001)); @@ -317,79 +399,81 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}}})); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}}})); + auto r3 = ConstantLiteral( + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1); auto expected = - Literal::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); + LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1}, {2}}})); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1}, {2}}})); + auto r3 = ConstantLiteral( + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1); auto expected = - Literal::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); + LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}, {3, 4}}})); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1); + auto r1 = + ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); + auto r3 = ConstantLiteral( + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1); auto expected = - Literal::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); + LiteralUtil::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}}, {{3, 4}}})); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1); + auto r1 = + ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); + auto r3 = ConstantLiteral( + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1); auto expected = - Literal::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); + LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { XlaBuilder b(TestName()); - auto r1 = - b.ConstantLiteral(*Literal::CreateR3({{{1}, {2}}, {{3}, {4}}})); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1); + auto r1 = ConstantLiteral( + &b, *LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); + auto r3 = ConstantLiteral( + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1); auto expected = - Literal::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); + LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1}}})); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1}}})); + auto r3 = ConstantLiteral( + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1); auto expected = - Literal::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); + LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -509,14 +593,14 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { &r2_implicit_shape2, &r2_implicit_array2, 0.8, 0.4, 56789); auto r2_implicit_parameter1 = - builder.Parameter(0, r2_implicit_shape1, "input0"); - auto r2_parameter = builder.Parameter(1, r2_shape, "input1"); + Parameter(&builder, 0, r2_implicit_shape1, "input0"); + auto r2_parameter = Parameter(&builder, 1, r2_shape, "input1"); auto r2_implicit_parameter2 = - builder.Parameter(2, r2_implicit_shape2, "input2"); + Parameter(&builder, 2, r2_implicit_shape2, "input2"); XlaOp op1 = BuildBinOp(spec.op1, r2_implicit_parameter1, r2_parameter, &builder); - XlaOp op2 = BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder); + BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder); Array2D expected_array(spec.output_bounds[0], spec.output_bounds[1]); @@ -530,7 +614,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { *v = ApplyOpToFloats(spec.op2, tmp, v3); }); - auto expected = Literal::CreateR2FromArray2D(expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral( &builder, *expected, {r2_implicit_global_data1.get(), r2_global_data.get(), @@ -544,80 +628,82 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}})); - auto r2 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}, {3, 4}})); - b.Add(r2, r1); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}})); + auto r2 = + ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + Add(r2, r1); - auto expected = Literal::CreateR2({{2, 4}, {4, 6}}); + auto expected = LiteralUtil::CreateR2({{2, 4}, {4, 6}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { XlaBuilder b(TestName()); - auto r1 = b.ConstantLiteral(*Literal::CreateR2({{1}, {2}})); - auto r2 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}, {3, 4}})); - b.Add(r2, r1); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2({{1}, {2}})); + auto r2 = + ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + Add(r2, r1); - auto expected = Literal::CreateR2({{2, 3}, {5, 6}}); + auto expected = LiteralUtil::CreateR2({{2, 3}, {5, 6}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { XlaBuilder b(TestName()); - auto r1 = b.ConstantR1({10, 20}); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r3, r1, {0}); + auto r1 = ConstantR1(&b, {10, 20}); + auto r3 = ConstantLiteral( + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r3, r1, {0}); - auto expected = - Literal::CreateR3({{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); + auto expected = LiteralUtil::CreateR3( + {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { XlaBuilder b(TestName()); - auto r1 = b.ConstantR1({10, 20}); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r1, r3, {1}); + auto r1 = ConstantR1(&b, {10, 20}); + auto r3 = ConstantLiteral( + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r1, r3, {1}); - auto expected = - Literal::CreateR3({{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); + auto expected = LiteralUtil::CreateR3( + {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { XlaBuilder b(TestName()); - auto r1 = b.ConstantR1({10, 20}); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); - b.Add(r1, r3, {2}); + auto r1 = ConstantR1(&b, {10, 20}); + auto r3 = ConstantLiteral( + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + Add(r1, r3, {2}); - auto expected = - Literal::CreateR3({{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); + auto expected = LiteralUtil::CreateR3( + {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { XlaBuilder b(TestName()); - auto r1_0 = b.ConstantR1({1000, 2000}); - auto r1_1 = b.ConstantR1({100, 200}); - auto r1_2 = b.ConstantR1({10, 20}); - auto r3 = b.ConstantLiteral( - *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + auto r1_0 = ConstantR1(&b, {1000, 2000}); + auto r1_1 = ConstantR1(&b, {100, 200}); + auto r1_2 = ConstantR1(&b, {10, 20}); + auto r3 = ConstantLiteral( + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); for (int i = 0; i < 3; ++i) { - r3 = b.Add(r1_0, r3, {0}); - r3 = b.Add(r3, r1_1, {1}); - r3 = b.Add(r1_2, r3, {2}); + r3 = Add(r1_0, r3, {0}); + r3 = Add(r3, r1_1, {1}); + r3 = Add(r1_2, r3, {2}); } - r3 = b.Mul(r3, b.ConstantR0(-2)); + r3 = Mul(r3, ConstantR0(&b, -2)); - auto expected = Literal::CreateR3( + auto expected = LiteralUtil::CreateR3( {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}}, {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}}); @@ -626,19 +712,19 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { XlaBuilder b(TestName()); - auto r1_0 = b.ConstantR1({1000, 2000}); - auto r1_1 = b.ConstantR1({100, 200}); - auto r1_2 = b.ConstantR1({10, 20}); - auto r0 = b.ConstantR0(3); - auto r3 = b.Broadcast(r0, {2, 2, 2}); + auto r1_0 = ConstantR1(&b, {1000, 2000}); + auto r1_1 = ConstantR1(&b, {100, 200}); + auto r1_2 = ConstantR1(&b, {10, 20}); + auto r0 = ConstantR0(&b, 3); + auto r3 = Broadcast(r0, {2, 2, 2}); for (int i = 0; i < 3; ++i) { - r3 = b.Add(r1_0, r3, {0}); - r3 = b.Add(r3, r1_1, {1}); - r3 = b.Add(r1_2, r3, {2}); + r3 = Add(r1_0, r3, {0}); + r3 = Add(r3, r1_1, {1}); + r3 = Add(r1_2, r3, {2}); } - r3 = b.Mul(r3, b.ConstantR0(-1)); + r3 = Mul(r3, ConstantR0(&b, -1)); - auto expected = Literal::CreateR3( + auto expected = LiteralUtil::CreateR3( {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}}, {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}}); @@ -650,10 +736,10 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { // results in a shape incompatible with the lhs [2, 3, 1]. XlaBuilder b(TestName()); - b.Add(b.ConstantR2({{1.0, 5.0}, {1.0, 5.0}}), - b.ConstantLiteral(*Literal::CreateR3( - {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), - /*broadcast_dimensions=*/{1, 2}); + Add(ConstantR2(&b, {{1.0, 5.0}, {1.0, 5.0}}), + ConstantLiteral(&b, *LiteralUtil::CreateR3( + {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), + /*broadcast_dimensions=*/{1, 2}); auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); @@ -665,8 +751,8 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { // Test invalid broadcasting with [1, 2] and [2, 3] inputs. XlaBuilder b(TestName()); - b.Add(b.ConstantR2({{1.0, 2.0}}), - b.ConstantR2({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); + Add(ConstantR2(&b, {{1.0, 2.0}}), + ConstantR2(&b, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); @@ -678,8 +764,8 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { // Test invalid broadcasting with [1, 2] and [2, 3] inputs. XlaBuilder b(TestName()); - b.Add(b.ConstantR2({{1.0, 2.0}}), - b.ConstantR2({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); + Add(ConstantR2(&b, {{1.0, 2.0}}), + ConstantR2(&b, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 51b9f0d3e330e73f5d110f0a62f824179d5c7cf7..c7b94b5bbaaa512ad36056f9e68a87cc706c24b1 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -37,7 +37,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { // Test degenerate case of broadcasting a scalar into a scalar. auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {}), input, {})); @@ -46,14 +46,14 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR0(42.0), *result, - error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR0(42.0), + *result, error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 2}), input, {})); @@ -63,14 +63,14 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, + *LiteralUtil::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); // Broadcast vector in both dimension 0 and dimension 1. Join them in a tuple // to enable testing of the results. @@ -86,18 +86,18 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), + *LiteralUtil::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), LiteralSlice(*result, {0}), error_spec_)); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), + *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), LiteralSlice(*result, {1}), error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1})); @@ -106,9 +106,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE( - LiteralTestUtil::Near(*Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, + error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { @@ -116,7 +116,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { // the dimensions, ie transpose. auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0})); @@ -125,15 +125,15 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE( - LiteralTestUtil::Near(*Literal::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + *LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, + error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2})); @@ -143,15 +143,15 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), + *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1.0, 2.0}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1.0, 2.0}))); // Broadcast vector in dimension 1. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -166,8 +166,9 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { Array2D pz({{1, 2}, {1, 2}}); expected.FillWithPZ(pz); - EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { @@ -176,7 +177,7 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { int64 r1_size = input_data.size(); std::iota(input_data.begin(), input_data.end(), 0.0f); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1(input_data))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1(input_data))); // Broadcast vector in dimension 3. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -196,8 +197,9 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { } expected.FillWithYX(yx); - EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), + *result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { @@ -207,7 +209,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { std::vector r1_array(64, 42.0); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1(r1_array))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1(r1_array))); // Broadcast vector in dimension 1. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -218,14 +220,14 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR4FromArray4D(r4_array), + EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(r4_array), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {})); @@ -238,15 +240,16 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { Array4D expected(64, 64, 3, 3); expected.Fill(1.0f); - EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { auto builder = HloComputation::Builder(TestName()); Array2D to_broadcast({{1.0f, 2.0f}, {3.0f, 4.0f}}); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2FromArray2D(to_broadcast))); + LiteralUtil::CreateR2FromArray2D(to_broadcast))); // Broadcast vector in dimensions 2 and 3. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -260,8 +263,9 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { Array4D expected(3, 3, 2, 2); expected.FillWithYX(to_broadcast); - EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { @@ -280,7 +284,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { } } auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR3FromArray3D(input_vals))); + LiteralUtil::CreateR3FromArray3D(input_vals))); // Broadcast vector in dimensions 2 and 3. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -291,8 +295,9 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), + *result, error_spec_)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index 5fd33b50c94356839bbed58acd43b7d0286f4a7e..b1d18210eaafdfec0920c0cccaa0dfdbd6de5609 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -16,8 +16,9 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -34,7 +35,7 @@ class CallOpTest : public ClientLibraryTestBase { protected: XlaComputation CreateR0F32IdentityComputation() { XlaBuilder builder("Identity"); - builder.Parameter(0, r0f32_, "x"); + Parameter(&builder, 0, r0f32_, "x"); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -42,9 +43,9 @@ class CallOpTest : public ClientLibraryTestBase { XlaComputation CreateR1S0F32AdditionComputation() { XlaBuilder builder("Addition"); - auto x = builder.Parameter(0, r1s0f32_, "x"); - auto y = builder.Parameter(1, r1s0f32_, "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, r1s0f32_, "x"); + auto y = Parameter(&builder, 1, r1s0f32_, "y"); + Add(x, y); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -52,9 +53,9 @@ class CallOpTest : public ClientLibraryTestBase { XlaComputation CreateR1S2F32AdditionComputation() { XlaBuilder builder("Addition"); - auto x = builder.Parameter(0, r1s2f32_, "x"); - auto y = builder.Parameter(1, r1s2f32_, "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, r1s2f32_, "x"); + auto y = Parameter(&builder, 1, r1s2f32_, "y"); + Add(x, y); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -62,7 +63,7 @@ class CallOpTest : public ClientLibraryTestBase { XlaComputation CreateR0F32TupleComputation() { XlaBuilder builder("Tuple"); - builder.Tuple({builder.Parameter(0, r0f32_, "x")}); + Tuple(&builder, {Parameter(&builder, 0, r0f32_, "x")}); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -76,8 +77,9 @@ class CallOpTest : public ClientLibraryTestBase { XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32IdentityComputation(); - auto constant = builder.ConstantLiteral(*Literal::CreateR0(42.0)); - builder.Call(callee, {constant}); + auto constant = + ConstantLiteral(&builder, *LiteralUtil::CreateR0(42.0)); + Call(&builder, callee, {constant}); ComputeAndCompareR0(&builder, 42.0, {}, ErrorSpec(0.01f)); } @@ -85,9 +87,9 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S0F32AdditionComputation(); - auto x = builder.ConstantLiteral(*Literal::CreateR1({})); - auto y = builder.ConstantLiteral(*Literal::CreateR1({})); - builder.Call(callee, {x, y}); + auto x = ConstantLiteral(&builder, *LiteralUtil::CreateR1({})); + auto y = ConstantLiteral(&builder, *LiteralUtil::CreateR1({})); + Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.01f)); } @@ -95,9 +97,11 @@ XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S2F32AdditionComputation(); - auto x = builder.ConstantLiteral(*Literal::CreateR1({1.0f, 2.0f})); - auto y = builder.ConstantLiteral(*Literal::CreateR1({2.0f, 3.0f})); - builder.Call(callee, {x, y}); + auto x = + ConstantLiteral(&builder, *LiteralUtil::CreateR1({1.0f, 2.0f})); + auto y = + ConstantLiteral(&builder, *LiteralUtil::CreateR1({2.0f, 3.0f})); + Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f)); } @@ -105,40 +109,40 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { XlaBuilder builder("inner"); { - auto x = builder.Parameter(0, r0f32_, "x"); - builder.Add(x, builder.ConstantR0(1.0)); + auto x = Parameter(&builder, 0, r0f32_, "x"); + Add(x, ConstantR0(&builder, 1.0)); } TF_ASSERT_OK_AND_ASSIGN(XlaComputation inner, builder.Build()); XlaBuilder builder2("outer"); { - auto x = builder2.Parameter(0, r0f32_, "x"); - x = builder2.Call(inner, {x}); - x = builder2.Call(inner, {x}); - x = builder2.Call(inner, {x}); + auto x = Parameter(&builder2, 0, r0f32_, "x"); + x = Call(&builder2, inner, {x}); + x = Call(&builder2, inner, {x}); + x = Call(&builder2, inner, {x}); } TF_ASSERT_OK_AND_ASSIGN(XlaComputation outer, builder2.Build()); XlaBuilder builder3("outermost"); { - auto x = builder3.Parameter(0, r0f32_, "x"); - x = builder3.Call(outer, {x}); - x = builder3.Call(outer, {x}); - x = builder3.Call(outer, {x}); + auto x = Parameter(&builder3, 0, r0f32_, "x"); + x = Call(&builder3, outer, {x}); + x = Call(&builder3, outer, {x}); + x = Call(&builder3, outer, {x}); } TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr start, - client_->TransferToServer(*Literal::CreateR0(1.0f))); + client_->TransferToServer(*LiteralUtil::CreateR0(1.0f))); ComputeAndCompareR0(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f)); } XLA_TEST_F(CallOpTest, CallR0F32Tuple) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32TupleComputation(); - auto elem = Literal::CreateR0(42.0); - auto tuple = Literal::MakeTuple({elem.get()}); - builder.Call(callee, {builder.ConstantLiteral(*elem)}); + auto elem = LiteralUtil::CreateR0(42.0); + auto tuple = LiteralUtil::MakeTuple({elem.get()}); + Call(&builder, callee, {ConstantLiteral(&builder, *elem)}); ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f)); } diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index 660ff0cad5666219a4a7cb1eedbed03f06e651ba..a4eb57fc7b9abd460a7d158d0dc629eba88018cd 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -36,11 +36,11 @@ class CheckExecutionArityTest : public ClientLibraryTestBase {}; TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { XlaBuilder builder("add_two_params"); - auto param_literal = Literal::CreateR1({1.1f, 2.2f}); + auto param_literal = LiteralUtil::CreateR1({1.1f, 2.2f}); - auto p0 = builder.Parameter(0, param_literal->shape(), "param0"); - auto p1 = builder.Parameter(1, param_literal->shape(), "param1"); - auto add = builder.Add(p0, p1); + auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0"); + auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1"); + Add(p0, p1); auto param0_data = client_->TransferToServer(*param_literal).ConsumeValueOrDie(); @@ -77,20 +77,20 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { XlaBuilder builder("add_two_params"); - auto p0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); - auto p1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4}), "param1"); - auto add = builder.Mul(p0, p1); + auto p0 = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0"); + auto p1 = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {4}), "param1"); + Mul(p0, p1); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); auto computation = computation_status.ConsumeValueOrDie(); - auto f32_literal = Literal::CreateR0(1.1f); + auto f32_literal = LiteralUtil::CreateR0(1.1f); auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie(); - auto f32_4_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); + auto f32_4_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); auto f32_4_data = client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie(); - auto u8_4_literal = Literal::CreateR1U8("hola"); + auto u8_4_literal = LiteralUtil::CreateR1U8("hola"); auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie(); // Match diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index bf8ed4d9fb0bc61b86ef0b5872711a122a3d416b..59d917054be2ebe3a25f902f51972a682a5231b6 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -156,7 +157,7 @@ string ClientLibraryTestBase::ExecuteToString( void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, const tensorflow::core::Bitmap& expected, tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = Literal::CreateR1(expected); + std::unique_ptr expected_literal = LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -272,10 +273,16 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), arguments_passed_in.end()); + + // Transfer and use elements of arguments_, if the AddParam() API was used. + std::vector> owning_arguments; if (!arguments_.empty()) { CHECK(arguments.empty()); for (const auto& argument : arguments_) { - arguments.push_back(argument.get()); + owning_arguments.push_back( + client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)) + .ValueOrDie()); + arguments.push_back(owning_arguments.back().get()); } } @@ -294,7 +301,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( std::unique_ptr converted_expected; Shape layout_shape; if (use_bfloat16_) { - converted_expected = Literal::ConvertF32ToBF16(expected); + converted_expected = LiteralUtil::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; @@ -330,10 +337,16 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( ErrorSpec error, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), arguments_passed_in.end()); + + // Transfer and use elements of arguments_, if the AddParam() API was used. + std::vector> owning_arguments; if (!arguments_.empty()) { CHECK(arguments.empty()); for (const auto& argument : arguments_) { - arguments.push_back(argument.get()); + owning_arguments.push_back( + client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)) + .ValueOrDie()); + arguments.push_back(owning_arguments.back().get()); } } @@ -346,7 +359,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( std::unique_ptr converted_expected; Shape layout_shape; if (use_bfloat16_) { - converted_expected = Literal::ConvertF32ToBF16(expected); + converted_expected = LiteralUtil::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; @@ -388,7 +401,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( auto actual = actual_status.ConsumeValueOrDie(); // Turn the expected value into a literal. - std::unique_ptr expected_literal = Literal::CreateR1U8(expected); + std::unique_ptr expected_literal = LiteralUtil::CreateR1U8(expected); VLOG(1) << "expected: " << expected_literal->ToString(); VLOG(1) << "actual: " << actual->ToString(); @@ -453,6 +466,14 @@ ClientLibraryTestBase::ComputeValueAndReference( // function. std::vector> argument_data; std::vector> ref_argument_data; + + // Use `arguments_` if the AddParam() API was used. Otherwise, use + // plain `arguments`. + if (!arguments_.empty()) { + CHECK_EQ(arguments.size(), 0); + arguments = arguments_; + } + for (const auto& arg : arguments) { TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg.Clone())); TF_ASSIGN_OR_RETURN(auto ref_data, ref_client_->TransferToServer(arg)); @@ -486,11 +507,11 @@ ClientLibraryTestBase::ComputeValueAndReference( XlaComputation ClientLibraryTestBase::CreateScalarRelu() { XlaBuilder builder("relu"); auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); - auto z_value = builder.Parameter(0, shape, "z_value"); + auto z_value = Parameter(&builder, 0, shape, "z_value"); auto zero = use_bfloat16_ - ? builder.ConstantR0(static_cast(0.0f)) - : builder.ConstantR0(0.0f); - builder.Max(z_value, zero); + ? ConstantR0(&builder, static_cast(0.0f)) + : ConstantR0(&builder, 0.0f); + Max(z_value, zero); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -499,9 +520,9 @@ XlaComputation ClientLibraryTestBase::CreateScalarRelu() { XlaComputation ClientLibraryTestBase::CreateScalarMax() { XlaBuilder builder("max"); auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); - auto x = builder.Parameter(0, shape, "x"); - auto y = builder.Parameter(1, shape, "y"); - builder.Max(x, y); + auto x = Parameter(&builder, 0, shape, "x"); + auto y = Parameter(&builder, 1, shape, "y"); + Max(x, y); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -510,13 +531,13 @@ XlaComputation ClientLibraryTestBase::CreateScalarMax() { XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() { XlaBuilder builder("relu_sensitivity"); auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); - auto activation = builder.Parameter(0, shape, "activation"); - auto backprop = builder.Parameter(1, shape, "backprop"); + auto activation = Parameter(&builder, 0, shape, "activation"); + auto backprop = Parameter(&builder, 1, shape, "backprop"); auto zero = use_bfloat16_ - ? builder.ConstantR0(static_cast(0.0f)) - : builder.ConstantR0(0.0f); - auto activation_gtz = builder.Gt(activation, zero); - builder.Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero); + ? ConstantR0(&builder, static_cast(0.0f)) + : ConstantR0(&builder, 0.0f); + auto activation_gtz = Gt(activation, zero); + Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); @@ -551,16 +572,16 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaBuilder* builder) { - XlaOp data_handle; - arguments_.push_back(CreateParameterAndTransferLiteral( - arguments_.size(), argument, "", builder, &data_handle)); - return data_handle; + arguments_.push_back(argument.Clone()); + return Parameter(builder, /*parameter_number=*/arguments_.size() - 1, + MaybeConvertShapeToBfloat16(argument.shape()), ""); } XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder) { - return builder->ConstantLiteral( - use_bfloat16_ ? *Literal::ConvertF32ToBF16(literal) : literal); + return ConstantLiteral(builder, use_bfloat16_ + ? *LiteralUtil::ConvertF32ToBF16(literal) + : literal); } std::unique_ptr @@ -573,22 +594,39 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, nullptr, builder, data_handle); } +Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) { + if (!use_bfloat16_) { + return shape; + } + Shape new_shape = shape; + ShapeUtil::ForEachMutableSubshape(&new_shape, + [](Shape* subshape, const ShapeIndex&) { + if (subshape->element_type() == F32) { + subshape->set_element_type(BF16); + } + }); + return new_shape; +} + +Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16( + const Literal& literal) { + if (use_bfloat16_) { + return std::move(*LiteralUtil::ConvertF32ToBF16(literal)); + } + return literal.Clone(); +} + std::unique_ptr ClientLibraryTestBase::CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, const DeviceHandle* device_handle, XlaBuilder* builder, XlaOp* data_handle) { - const Literal* param_literal = &literal; - std::unique_ptr converted_literal; - if (use_bfloat16_) { - converted_literal = Literal::ConvertF32ToBF16(literal); - param_literal = converted_literal.get(); - } + Literal param_literal = MaybeConvertLiteralToBfloat16(literal); std::unique_ptr data = - client_->TransferToServer(*param_literal, device_handle) + client_->TransferToServer(param_literal, device_handle) .ConsumeValueOrDie(); *data_handle = - builder->Parameter(parameter_number, param_literal->shape(), name); + Parameter(builder, parameter_number, param_literal.shape(), name); return data; } diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 0499fec5898a42affa0e0a712dee10187355c13e..b04a3b105ca017b6a91d271e603dcd0cc2068a33 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -26,8 +26,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -73,8 +74,9 @@ class ClientLibraryTestBase : public ::testing::Test { string TestName() const; void SetFastMathDisabled(bool disabled) { - execution_options_.mutable_debug_options()->set_xla_enable_fast_math( - !disabled); + auto* opts = execution_options_.mutable_debug_options(); + opts->set_xla_cpu_enable_fast_math(!disabled); + opts->set_xla_gpu_enable_fast_math(!disabled); } void SetSeed(uint64 seed) { execution_options_.set_seed(seed); } @@ -284,7 +286,7 @@ class ClientLibraryTestBase : public ::testing::Test { template XlaOp AddParam(const Array& argument, XlaBuilder* builder) { - return AddParam(*Literal::CreateFromArray(argument), builder); + return AddParam(*LiteralUtil::CreateFromArray(argument), builder); } // Creates a constant instruction with the given literal. When the @@ -299,13 +301,14 @@ class ClientLibraryTestBase : public ::testing::Test { template XlaOp CreateConstantFromArray(const Array& array, XlaBuilder* builder) { - return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); + return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array), + builder); } // Same as CreateConstantFromArray, but for scalars. template XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) { - return CreateConstantFromLiteral(*Literal::CreateR0(value), + return CreateConstantFromLiteral(*LiteralUtil::CreateR0(value), builder); } @@ -373,6 +376,13 @@ class ClientLibraryTestBase : public ::testing::Test { // The float type used in this test, BF16 or F32 according to use_bfloat16. PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; } + // Executes the computation and calculates the expected reference value using + // the reference client. Returns two literals in the order of (expected, + // actual). + StatusOr, std::unique_ptr>> + ComputeValueAndReference(XlaBuilder* builder, + tensorflow::gtl::ArraySlice arguments); + Client* client_; Client* ref_client_; // To compute reference result. ExecutionOptions execution_options_; @@ -390,19 +400,16 @@ class ClientLibraryTestBase : public ::testing::Test { const string& error_message)>& verify_output, const Shape* output_with_layout = nullptr); - // Executes the computation and calculates the expected reference value using - // the reference client. Returns two literals in the order of (expected, - // actual). - StatusOr, std::unique_ptr>> - ComputeValueAndReference(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments); + // Converts an f32 shape/literal to bf16 if use_bfloat16_ is true. + Literal MaybeConvertLiteralToBfloat16(const Literal& literal); + Shape MaybeConvertShapeToBfloat16(const Shape& shape); // Whether to run tests with all float-type input/output converted to // bfloat16. bool use_bfloat16_ = false; // Arguments to be passed to the computation when it runs. - std::vector> arguments_; + std::vector arguments_; }; template @@ -410,7 +417,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - Literal::CreateR0(expected); + LiteralUtil::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -426,7 +433,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - Literal::CreateR0(expected); + LiteralUtil::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -436,7 +443,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - Literal::CreateR1(expected); + LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -452,7 +459,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - Literal::CreateR1(expected); + LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -462,7 +469,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( XlaBuilder* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - Literal::CreateR2FromArray2D(expected); + LiteralUtil::CreateR2FromArray2D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -478,7 +485,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - Literal::CreateR2FromArray2D(expected); + LiteralUtil::CreateR2FromArray2D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -488,7 +495,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( XlaBuilder* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - Literal::CreateR3FromArray3D(expected); + LiteralUtil::CreateR3FromArray3D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -504,7 +511,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - Literal::CreateR3FromArray3D(expected); + LiteralUtil::CreateR3FromArray3D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -514,7 +521,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( XlaBuilder* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - Literal::CreateR4FromArray4D(expected); + LiteralUtil::CreateR4FromArray4D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -530,7 +537,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - Literal::CreateR4FromArray4D(expected); + LiteralUtil::CreateR4FromArray4D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -539,13 +546,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = Literal::CreateR0(value); + std::unique_ptr literal = LiteralUtil::CreateR0(value); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = Literal::ConvertF32ToBF16(*literal); + literal = LiteralUtil::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = builder->Parameter(parameter_number, literal->shape(), name); + *data_handle = Parameter(builder, parameter_number, literal->shape(), name); return data; } @@ -553,13 +560,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = Literal::CreateR1(values); + std::unique_ptr literal = LiteralUtil::CreateR1(values); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = Literal::ConvertF32ToBF16(*literal); + literal = LiteralUtil::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = builder->Parameter(parameter_number, literal->shape(), name); + *data_handle = Parameter(builder, parameter_number, literal->shape(), name); return data; } @@ -567,13 +574,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = Literal::CreateR2FromArray2D(array_2d); + std::unique_ptr literal = LiteralUtil::CreateR2FromArray2D(array_2d); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = Literal::ConvertF32ToBF16(*literal); + literal = LiteralUtil::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = builder->Parameter(parameter_number, literal->shape(), name); + *data_handle = Parameter(builder, parameter_number, literal->shape(), name); return data; } @@ -581,13 +588,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = Literal::CreateR3FromArray3D(array_3d); + std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(array_3d); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = Literal::ConvertF32ToBF16(*literal); + literal = LiteralUtil::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = builder->Parameter(parameter_number, literal->shape(), name); + *data_handle = Parameter(builder, parameter_number, literal->shape(), name); return data; } diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 08671cf62445826649b5c97003f998ae98a59d97..c898dacf489db97223e2918414daf5de88bece64 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -43,8 +43,8 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { std::vector> layouts = {{0, 1}, {1, 0}}; for (const std::vector& execute_layout : layouts) { for (const std::vector& transfer_layout : layouts) { - b.Add(b.ConstantR2({{1, 2}, {3, 4}}), - b.ConstantR2({{10, 20}, {30, 40}})); + Add(ConstantR2(&b, {{1, 2}, {3, 4}}), + ConstantR2(&b, {{10, 20}, {30, 40}})); TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); ExecutionOptions execution_options = execution_options_; @@ -56,7 +56,7 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { client_->Execute(computation, {}, &execution_options)); std::unique_ptr expected_literal = - Literal::CreateR2WithLayout( + LiteralUtil::CreateR2WithLayout( {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); TF_ASSERT_OK_AND_ASSIGN( @@ -72,8 +72,8 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { XlaBuilder b(TestName()); - b.Tuple({b.ConstantR2({{1, 2}, {3, 4}}), - b.ConstantR2({{10, 20}, {30, 40}})}); + Tuple(&b, {ConstantR2(&b, {{1, 2}, {3, 4}}), + ConstantR2(&b, {{10, 20}, {30, 40}})}); TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); @@ -112,13 +112,13 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg; Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr const_arg, - client_->TransferToServer(*Literal::CreateR2({{5, 6}, {7, 8}}))); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr const_arg, + client_->TransferToServer( + *LiteralUtil::CreateR2({{5, 6}, {7, 8}}))); XlaBuilder b(TestName() + ".add"); - b.Add(b.Parameter(0, shape, "param_0"), - b.ConstantR2({{1, 2}, {3, 4}})); + Add(Parameter(&b, 0, shape, "param_0"), + ConstantR2(&b, {{1, 2}, {3, 4}})); TF_ASSERT_OK_AND_ASSIGN(add_with_one_arg, b.Build()); // We can't really test parallel execution on CPU since all of the cores in a @@ -136,7 +136,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { TF_ASSERT_OK_AND_ASSIGN(auto results, client_->ExecuteParallel(computation_instances)); - auto expected_result = Literal::CreateR2({{6, 8}, {10, 12}}); + auto expected_result = LiteralUtil::CreateR2({{6, 8}, {10, 12}}); TF_ASSERT_OK_AND_ASSIGN( auto result_literal, diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 50a006964869b3e5dce431d441f7cd81af9df910..7c52c9fbbb57f9291ea9f0966e2efa715819fb67 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -50,7 +50,7 @@ class CompilationCacheTest : public ClientLibraryTestBase { &execution_profile) .ConsumeValueOrDie(); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR0(expected_result), *result, error_spec_)); + *LiteralUtil::CreateR0(expected_result), *result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -67,7 +67,7 @@ class CompilationCacheTest : public ClientLibraryTestBase { std::unique_ptr result = client_->Transfer(*data_handle).ConsumeValueOrDie(); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2(expected_result), *result, error_spec_)); + *LiteralUtil::CreateR2(expected_result), *result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -77,7 +77,7 @@ class CompilationCacheTest : public ClientLibraryTestBase { // TODO(b/74197823): Disabled because there is no cache in the new design. XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) { XlaBuilder builder(TestName()); - builder.Neg(builder.ConstantR0(42.0)); + Neg(ConstantR0(&builder, 42.0)); XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false); @@ -89,17 +89,17 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) { XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledWithDifferentParameters) { std::unique_ptr data_42 = - client_->TransferToServer(*Literal::CreateR0(42.0f)) + client_->TransferToServer(*LiteralUtil::CreateR0(42.0f)) .ConsumeValueOrDie(); std::unique_ptr data_123 = - client_->TransferToServer(*Literal::CreateR0(123.0f)) + client_->TransferToServer(*LiteralUtil::CreateR0(123.0f)) .ConsumeValueOrDie(); std::unique_ptr data_456 = - client_->TransferToServer(*Literal::CreateR0(456.0f)) + client_->TransferToServer(*LiteralUtil::CreateR0(456.0f)) .ConsumeValueOrDie(); XlaBuilder builder(TestName()); - builder.Neg(builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); + Neg(Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param")); XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecuteComputationR0F32(computation, {data_42.get()}, -42.0, @@ -115,16 +115,16 @@ XLA_TEST_F(CompilationCacheTest, // TODO(b/74197823): Disabled because there is no cache in the new design. XLA_TEST_F(CompilationCacheTest, DISABLED_MultipleComputations) { XlaBuilder builder_neg(TestName() + "_neg"); - builder_neg.Neg(builder_neg.ConstantR0(42.0)); + Neg(ConstantR0(&builder_neg, 42.0)); XlaComputation computation_neg = builder_neg.Build().ConsumeValueOrDie(); XlaBuilder builder_exp(TestName() + "_exp"); - builder_exp.Exp(builder_exp.ConstantR0(1.0)); + Exp(ConstantR0(&builder_exp, 1.0)); XlaComputation computation_exp = builder_exp.Build().ConsumeValueOrDie(); XlaBuilder builder_add(TestName() + "_add"); - builder_add.Add(builder_add.ConstantR0(2.0), - builder_add.ConstantR0(3.0)); + Add(ConstantR0(&builder_add, 2.0), + ConstantR0(&builder_add, 3.0)); XlaComputation computation_add = builder_add.Build().ConsumeValueOrDie(); ExecuteComputationR0F32(computation_neg, {}, -42.0, @@ -143,18 +143,18 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) { // layouts. Use these arrays as parameters to a simple computation. If the // layout of the array changes then computation should be recompiled (cache // miss). - auto rowmaj_array = Literal::CreateR2WithLayout( + auto rowmaj_array = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0})); auto rowmaj_handle = client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie(); - auto colmaj_array = Literal::CreateR2WithLayout( + auto colmaj_array = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})); auto colmaj_handle = client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecuteComputationR2F32(computation, {colmaj_handle.get()}, diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index ba22530f1cfee56337f862c25122d399dbf0f1e4..5a06d061f0d83fff547502495ff8ab13fb421b70 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -99,7 +99,7 @@ TEST_F(ComputeConstantTest, ScalarInt32Literal) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = b.ConstantR0(42); + auto computation = ConstantR0(&b, 42); EXPECT_TRUE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -113,7 +113,7 @@ TEST_F(ComputeConstantTest, ScalarFloatAdd) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto computation = - b.Add(b.ConstantR0(42.5f), b.ConstantR0(1.5f)); + Add(ConstantR0(&b, 42.5f), ConstantR0(&b, 1.5f)); EXPECT_TRUE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -127,8 +127,8 @@ TEST_F(ComputeConstantTest, ScalarRng) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto computation = - b.RngUniform(b.ConstantR0(1.1f), b.ConstantR0(2.1f), - ShapeUtil::MakeShape(F32, {})); + RngUniform(ConstantR0(&b, 1.1f), ConstantR0(&b, 2.1f), + ShapeUtil::MakeShape(F32, {})); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -141,7 +141,7 @@ TEST_F(ComputeConstantTest, DirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"); + auto computation = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param"); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -156,8 +156,8 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto computation = - b.Add(b.ConstantR0(1.0f), - b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); + Add(ConstantR0(&b, 1.0f), + Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param")); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -174,18 +174,18 @@ TEST_F(ComputeConstantTest, UnrelatedParam) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0"); + auto param_a = Parameter(&b, 10, ShapeUtil::MakeShape(F32, {}), "param0"); auto constant_4 = - b.Add(b.ConstantR0(2.5f), b.ConstantR0(1.5f)); - auto not_constant_a = b.Add(constant_4, param_a); + Add(ConstantR0(&b, 2.5f), ConstantR0(&b, 1.5f)); + auto not_constant_a = Add(constant_4, param_a); - auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1"); + auto param_b = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "param1"); auto constant_9 = - b.Mul(b.ConstantR0(2.0f), b.ConstantR0(4.5f)); - auto not_constant_b = b.Add(param_b, constant_9); + Mul(ConstantR0(&b, 2.0f), ConstantR0(&b, 4.5f)); + auto not_constant_b = Add(param_b, constant_9); - auto constant_13 = b.Add(constant_4, constant_9); - b.Add(not_constant_b, b.Add(constant_13, not_constant_a)); + auto constant_13 = Add(constant_4, constant_9); + Add(not_constant_b, Add(constant_13, not_constant_a)); EXPECT_TRUE(IsConstant(constant_13, &b)); @@ -201,13 +201,13 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { XlaBuilder b(TestName()); auto computation = - b.Add(b.ConstantR1({1, 2}), b.ConstantR1({3, 4})); + Add(ConstantR1(&b, {1, 2}), ConstantR1(&b, {3, 4})); EXPECT_TRUE(IsConstant(computation, &b)); TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); std::unique_ptr expected_literal = - Literal::CreateR1({4, 6}); + LiteralUtil::CreateR1({4, 6}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } @@ -216,12 +216,12 @@ TEST_F(ComputeConstantTest, IntegerDivide) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = b.Div(b.ConstantR0(15), b.ConstantR0(3)); + auto computation = Div(ConstantR0(&b, 15), ConstantR0(&b, 3)); EXPECT_TRUE(IsConstant(computation, &b)); TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); - std::unique_ptr expected_literal = Literal::CreateR0(5); + std::unique_ptr expected_literal = LiteralUtil::CreateR0(5); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } @@ -237,13 +237,13 @@ XLA_TEST_F(ComputeConstantTest, Layout) { TF_ASSERT_OK_AND_ASSIGN( auto computed, ComputeConstantLiteral( client, - b.Add(b.ConstantR2({{1, 2}, {3, 4}}), - b.ConstantR2({{10, 20}, {30, 40}})), + Add(ConstantR2(&b, {{1, 2}, {3, 4}}), + ConstantR2(&b, {{10, 20}, {30, 40}})), &b, &layout_proto)); std::unique_ptr expected_literal = - Literal::CreateR2WithLayout({{11, 22}, {33, 44}}, - LayoutUtil::MakeLayout(layout)); + LiteralUtil::CreateR2WithLayout( + {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( expected_literal->shape(), computed->shape())); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index a4c8a83eb15f7cc279b6c8f1bf1394c0afb9f7cf..be017477d84eb9faf5aa79dcdf54d6b6aaf6fd8e 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -39,7 +39,7 @@ using ::testing::HasSubstr; // Concatenate expects at least one argument. XLA_TEST_F(ConcatTest, Concat_Nothing) { XlaBuilder builder(TestName()); - builder.ConcatInDim({}, 0); + ConcatInDim(&builder, {}, 0); StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), @@ -49,8 +49,8 @@ XLA_TEST_F(ConcatTest, Concat_Nothing) { // Concatenate with one argument works. XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0, 64.0}); - builder.ConcatInDim({a}, 0); + auto a = ConstantR1(&builder, {42.0, 64.0}); + ConcatInDim(&builder, {a}, 0); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -58,8 +58,8 @@ XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) { XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.ConcatInDim({a}, 0); + auto a = ConstantR1(&builder, {}); + ConcatInDim(&builder, {a}, 0); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -69,9 +69,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { // to concatenate on. XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR0(42.0); - auto b = builder.ConstantR0(64.0); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR0(&builder, 42.0); + auto b = ConstantR0(&builder, 64.0); + ConcatInDim(&builder, {a, b}, 0); StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), @@ -80,9 +80,9 @@ XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({}); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {}); + ConcatInDim(&builder, {a, b}, 0); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -90,9 +90,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) { XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - auto b = builder.ConstantR1({256.0}); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR1(&builder, {}); + auto b = ConstantR1(&builder, {256.0}); + ConcatInDim(&builder, {a, b}, 0); std::vector expected = {256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -100,9 +100,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) { XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0, 64.0}); - auto b = builder.ConstantR1({}); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR1(&builder, {42.0, 64.0}); + auto b = ConstantR1(&builder, {}); + ConcatInDim(&builder, {a, b}, 0); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -110,9 +110,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) { XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0, 64.0}); - auto b = builder.ConstantR1({256.0}); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR1(&builder, {42.0, 64.0}); + auto b = ConstantR1(&builder, {256.0}); + ConcatInDim(&builder, {a, b}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -130,9 +130,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) { } XlaBuilder builder(TestName()); - auto a = builder.ConstantR1(lhs); - auto b = builder.ConstantR1(rhs); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR1(&builder, lhs); + auto b = ConstantR1(&builder, rhs); + ConcatInDim(&builder, {a, b}, 0); ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } @@ -140,9 +140,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) { XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) { for (int dim : {0, 1}) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 0)); - auto b = builder.ConstantR2FromArray2D(Array2D(0, 0)); - builder.ConcatInDim({a, b}, dim); + auto a = ConstantR2FromArray2D(&builder, Array2D(0, 0)); + auto b = ConstantR2FromArray2D(&builder, Array2D(0, 0)); + ConcatInDim(&builder, {a, b}, dim); ComputeAndCompareR2(&builder, Array2D(0, 0), {}, ErrorSpec(0.0001)); @@ -153,9 +153,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) { XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(1, 1); auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0); - auto a = builder.ConstantR2FromArray2D(*a_array); - auto b = builder.ConstantR2FromArray2D(*b_array); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR2FromArray2D(&builder, *a_array); + auto b = ConstantR2FromArray2D(&builder, *b_array); + ConcatInDim(&builder, {a, b}, 0); Array2D expected({ {0}, @@ -168,9 +168,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) { XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(1, 1); auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0); - auto a = builder.ConstantR2FromArray2D(*a_array); - auto b = builder.ConstantR2FromArray2D(*b_array); - builder.ConcatInDim({a, b}, 1); + auto a = ConstantR2FromArray2D(&builder, *a_array); + auto b = ConstantR2FromArray2D(&builder, *b_array); + ConcatInDim(&builder, {a, b}, 1); Array2D expected({ {0, 64}, @@ -181,9 +181,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) { XLA_TEST_F(ConcatTest, Concat2x0With2x5) { XlaBuilder builder(TestName()); auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0); - auto a = builder.ConstantR2FromArray2D(Array2D(2, 0)); - auto b = builder.ConstantR2FromArray2D(*b_array); - builder.ConcatInDim({a, b}, 1); + auto a = ConstantR2FromArray2D(&builder, Array2D(2, 0)); + auto b = ConstantR2FromArray2D(&builder, *b_array); + ConcatInDim(&builder, {a, b}, 1); ComputeAndCompareR2(&builder, *b_array, {}, ErrorSpec(0.0001)); } @@ -192,9 +192,9 @@ XLA_TEST_F(ConcatTest, Concat2x3With2x5) { XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(2, 3); auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0); - auto a = builder.ConstantR2FromArray2D(*a_array); - auto b = builder.ConstantR2FromArray2D(*b_array); - builder.ConcatInDim({a, b}, 1); + auto a = ConstantR2FromArray2D(&builder, *a_array); + auto b = ConstantR2FromArray2D(&builder, *b_array); + ConcatInDim(&builder, {a, b}, 1); Array2D expected({ {0, 1, 2, 64, 65, 66, 67, 68}, @@ -206,9 +206,9 @@ XLA_TEST_F(ConcatTest, Concat2x3With2x5) { XLA_TEST_F(ConcatTest, Concat3x2With0x2) { XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(3, 2); - auto a = builder.ConstantR2FromArray2D(*a_array); - auto b = builder.ConstantR2FromArray2D(Array2D(0, 2)); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR2FromArray2D(&builder, *a_array); + auto b = ConstantR2FromArray2D(&builder, Array2D(0, 2)); + ConcatInDim(&builder, {a, b}, 0); ComputeAndCompareR2(&builder, *a_array, {}, ErrorSpec(0.0001)); } @@ -217,9 +217,9 @@ XLA_TEST_F(ConcatTest, Concat3x2With5x2) { XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(3, 2); auto b_array = CreatePatternedMatrix(5, 2, /*offset=*/64.0); - auto a = builder.ConstantR2FromArray2D(*a_array); - auto b = builder.ConstantR2FromArray2D(*b_array); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR2FromArray2D(&builder, *a_array); + auto b = ConstantR2FromArray2D(&builder, *b_array); + ConcatInDim(&builder, {a, b}, 0); Array2D expected({ {0, 1}, @@ -236,9 +236,9 @@ XLA_TEST_F(ConcatTest, Concat3x2With5x2) { XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR3FromArray3D(Array3D(3, 0, 2)); - auto b = builder.ConstantR3FromArray3D(Array3D(3, 0, 1)); - builder.ConcatInDim({a, b}, 2); + auto a = ConstantR3FromArray3D(&builder, Array3D(3, 0, 2)); + auto b = ConstantR3FromArray3D(&builder, Array3D(3, 0, 1)); + ConcatInDim(&builder, {a, b}, 2); ComputeAndCompareR3(&builder, Array3D(3, 0, 3), {}, ErrorSpec(0.0001)); } @@ -257,9 +257,9 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) { {{7}}, {{8}}, }); - auto a = builder.ConstantR3FromArray3D(a_array); - auto b = builder.ConstantR3FromArray3D(b_array); - builder.ConcatInDim({a, b}, 2); + auto a = ConstantR3FromArray3D(&builder, a_array); + auto b = ConstantR3FromArray3D(&builder, b_array); + ConcatInDim(&builder, {a, b}, 2); Array3D expected({ {{0, 1, 6}}, @@ -271,10 +271,10 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) { XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0}); - auto b = builder.ConstantR1({64.0}); - auto c = builder.ConstantR1({256.0}); - builder.ConcatInDim({a, b, c}, 0); + auto a = ConstantR1(&builder, {42.0}); + auto b = ConstantR1(&builder, {64.0}); + auto c = ConstantR1(&builder, {256.0}); + ConcatInDim(&builder, {a, b, c}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -300,10 +300,10 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) { {{7}}, {{11}}, }); - auto a = builder.ConstantR3FromArray3D(a_array); - auto b = builder.ConstantR3FromArray3D(b_array); - auto c = builder.ConstantR3FromArray3D(c_array); - builder.ConcatInDim({a, b, c}, 2); + auto a = ConstantR3FromArray3D(&builder, a_array); + auto b = ConstantR3FromArray3D(&builder, b_array); + auto c = ConstantR3FromArray3D(&builder, c_array); + ConcatInDim(&builder, {a, b, c}, 2); Array3D expected({ {{0, 1, 2, 3}}, @@ -315,11 +315,11 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) { XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0}); - auto b = builder.ConstantR1({64.0}); - auto c = builder.ConstantR1({256.0}); + auto a = ConstantR1(&builder, {42.0}); + auto b = ConstantR1(&builder, {64.0}); + auto c = ConstantR1(&builder, {256.0}); // concatenated = (a concat b) concat c - builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0); + ConcatInDim(&builder, {ConcatInDim(&builder, {a, b}, 0), c}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -327,11 +327,11 @@ XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) { XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0}); - auto b = builder.ConstantR1({64.0}); - auto c = builder.ConstantR1({256.0}); + auto a = ConstantR1(&builder, {42.0}); + auto b = ConstantR1(&builder, {64.0}); + auto c = ConstantR1(&builder, {256.0}); // concatenated = a concat (b concat c) - builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0); + ConcatInDim(&builder, {a, ConcatInDim(&builder, {b, c}, 0)}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -346,9 +346,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) { } XlaBuilder builder(TestName()); - auto a = builder.ConstantR2FromArray2D(lhs); - auto b = builder.ConstantR2FromArray2D(rhs); - builder.ConcatInDim({a, b}, 0); + auto a = ConstantR2FromArray2D(&builder, lhs); + auto b = ConstantR2FromArray2D(&builder, rhs); + ConcatInDim(&builder, {a, b}, 0); Array2D expected(2, 1024); for (int i = 0; i < 1024; ++i) { @@ -367,9 +367,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) { } XlaBuilder builder(TestName()); - auto a = builder.ConstantR2FromArray2D(lhs); - auto b = builder.ConstantR2FromArray2D(rhs); - builder.ConcatInDim({a, b}, 1); + auto a = ConstantR2FromArray2D(&builder, lhs); + auto b = ConstantR2FromArray2D(&builder, rhs); + ConcatInDim(&builder, {a, b}, 1); Array2D expected(1, 2048); for (int i = 0; i < 1024; ++i) { @@ -392,9 +392,9 @@ XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) { } XlaBuilder builder(TestName()); - auto a = builder.ConstantR2FromArray2D(lhs); - auto b = builder.ConstantR2FromArray2D(rhs); - builder.ConcatInDim({a, b}, 1); + auto a = ConstantR2FromArray2D(&builder, lhs); + auto b = ConstantR2FromArray2D(&builder, rhs); + ConcatInDim(&builder, {a, b}, 1); Array2D expected(64, 66); for (int i0 = 0; i0 < 64; ++i0) { @@ -410,22 +410,37 @@ XLA_TEST_F(ConcatTest, CannotConcatOpaques) { XlaBuilder builder(TestName()); auto opaque_shape = ShapeUtil::MakeOpaqueShape(); auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1}); - auto x = builder.Parameter(0, r1f32, "x"); - auto y = builder.Parameter(1, opaque_shape, "y"); - builder.ConcatInDim({x, y}, 0); + auto x = Parameter(&builder, 0, r1f32, "x"); + auto y = Parameter(&builder, 1, opaque_shape, "y"); + ConcatInDim(&builder, {x, y}, 0); StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT( computation_status.status().ToString(), - HasSubstr("Expected non-opaque argument for operand of concatenation")); + HasSubstr("Expected array argument for operand of concatenation")); +} + +// Show that we can't concatenate with tokens. +XLA_TEST_F(ConcatTest, CannotConcatTokens) { + XlaBuilder builder(TestName()); + auto token_shape = ShapeUtil::MakeTokenShape(); + auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1}); + auto x = Parameter(&builder, 0, r1f32, "x"); + auto y = Parameter(&builder, 1, token_shape, "y"); + ConcatInDim(&builder, {x, y}, 0); + StatusOr computation_status = builder.Build(); + ASSERT_FALSE(computation_status.ok()); + EXPECT_THAT( + computation_status.status().ToString(), + HasSubstr("Expected array argument for operand of concatenation")); } XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) { XlaBuilder builder(TestName()); - auto p0 = builder.ConstantR1({true}); - auto p1 = builder.ConstantR1({false}); - auto p2 = builder.ConstantR1({true}); - builder.ConcatInDim({p0, p1, p2}, 0); + auto p0 = ConstantR1(&builder, {true}); + auto p1 = ConstantR1(&builder, {false}); + auto p2 = ConstantR1(&builder, {true}); + ConcatInDim(&builder, {p0, p1, p2}, 0); bool expected[] = {true, false, true}; ComputeAndCompareR1(&builder, expected, {}); @@ -433,11 +448,11 @@ XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) { XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) { XlaBuilder builder(TestName()); - auto a0 = builder.ConstantR1({1}); - auto a1 = builder.ConstantR1({2, 3}); - auto a2 = builder.ConstantR1({4, 5, 6}); - auto a3 = builder.ConstantR1({7, 8, 9, 10}); - builder.ConcatInDim({a0, a1, a2, a3}, 0); + auto a0 = ConstantR1(&builder, {1}); + auto a1 = ConstantR1(&builder, {2, 3}); + auto a2 = ConstantR1(&builder, {4, 5, 6}); + auto a3 = ConstantR1(&builder, {7, 8, 9, 10}); + ConcatInDim(&builder, {a0, a1, a2, a3}, 0); std::vector expected(10); std::iota(expected.begin(), expected.end(), 1); @@ -472,7 +487,7 @@ XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { auto p1 = CreateR3Parameter(arr1, /*parameter_number=*/1, "p1", &builder, &h1); - builder.ConcatInDim({h0, h1}, 2); + ConcatInDim(&builder, {h0, h1}, 2); ComputeAndCompareR3(&builder, expected, {p0.get(), p1.get()}); } @@ -499,9 +514,9 @@ TEST_P(ConcatR2BinaryTest, DoIt) { rhs.FillUnique(1000); XlaBuilder builder(TestName()); - auto a0 = builder.ConstantR2FromArray2D(lhs); - auto a1 = builder.ConstantR2FromArray2D(rhs); - builder.ConcatInDim({a0, a1}, spec.concat_dimension); + auto a0 = ConstantR2FromArray2D(&builder, lhs); + auto a1 = ConstantR2FromArray2D(&builder, rhs); + ConcatInDim(&builder, {a0, a1}, spec.concat_dimension); std::unique_ptr> expected = ReferenceUtil::Concat2D(lhs, rhs, spec.concat_dimension); @@ -519,19 +534,19 @@ TEST_P(ConcatR2BinaryTest, DoIt) { // concat XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); - auto x_literal = Literal::CreateR0(2.f); - auto y_literal = Literal::CreateR0(3.f); + auto x_literal = LiteralUtil::CreateR0(2.f); + auto y_literal = LiteralUtil::CreateR0(3.f); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, f32_scalar, "x"); - auto y = builder.Parameter(1, f32_scalar, "y"); - auto mul = builder.Mul(x, y); - auto add1 = builder.Add(mul, builder.ConstantR1({1.f, 2.f})); - auto add2 = builder.Add(mul, builder.ConstantR1({3.f, 4.f})); - auto add3 = builder.Add(mul, builder.ConstantR1({5.f, 6.f})); - builder.ConcatInDim({add1, add2, add3}, /*dimension=*/0); + auto x = Parameter(&builder, 0, f32_scalar, "x"); + auto y = Parameter(&builder, 1, f32_scalar, "y"); + auto mul = Mul(x, y); + auto add1 = Add(mul, ConstantR1(&builder, {1.f, 2.f})); + auto add2 = Add(mul, ConstantR1(&builder, {3.f, 4.f})); + auto add3 = Add(mul, ConstantR1(&builder, {5.f, 6.f})); + ConcatInDim(&builder, {add1, add2, add3}, /*dimension=*/0); ComputeAndCompareR1(&builder, {7., 8., 9., 10., 11., 12.}, {x_data.get(), y_data.get()}, ErrorSpec(1e-4)); @@ -541,21 +556,21 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { // produces the correct result in rank 1. XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); - auto x_literal = Literal::CreateR1({2.0f, 3.0f, 5.0f, 6.0f}); - auto y_literal = Literal::CreateR0(1.5f); - auto z_literal = Literal::CreateR0(5.5f); + auto x_literal = LiteralUtil::CreateR1({2.0f, 3.0f, 5.0f, 6.0f}); + auto y_literal = LiteralUtil::CreateR0(1.5f); + auto z_literal = LiteralUtil::CreateR0(5.5f); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, x_literal->shape(), "x"); - auto y = builder.Parameter(1, f32_scalar, "y"); - auto z = builder.Parameter(2, f32_scalar, "z"); - auto bcast = builder.Broadcast(y, {5}); - auto bcast2 = builder.Broadcast(z, {3}); - auto concat = builder.ConcatInDim({bcast, x}, /*dimension=*/0); - builder.ConcatInDim({concat, bcast2}, /*dimension=*/0); + auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto y = Parameter(&builder, 1, f32_scalar, "y"); + auto z = Parameter(&builder, 2, f32_scalar, "z"); + auto bcast = Broadcast(y, {5}); + auto bcast2 = Broadcast(z, {3}); + auto concat = ConcatInDim(&builder, {bcast, x}, /*dimension=*/0); + ConcatInDim(&builder, {concat, bcast2}, /*dimension=*/0); ComputeAndCompareR1( &builder, @@ -569,21 +584,21 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); Array3D x3d(3, 5, 7, 3.14f); - auto x_literal = Literal::CreateR3FromArray3D(x3d); - auto y_literal = Literal::CreateR0(1.5f); - auto z_literal = Literal::CreateR0(5.5f); + auto x_literal = LiteralUtil::CreateR3FromArray3D(x3d); + auto y_literal = LiteralUtil::CreateR0(1.5f); + auto z_literal = LiteralUtil::CreateR0(5.5f); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, x_literal->shape(), "x"); - auto y = builder.Parameter(1, f32_scalar, "y"); - auto z = builder.Parameter(2, f32_scalar, "y"); - auto y_bcast = builder.Broadcast(y, {1, 5, 7}); - auto z_bcast = builder.Broadcast(z, {4, 1, 7}); - auto concat = builder.ConcatInDim({y_bcast, x}, /*dimension=*/0); - builder.ConcatInDim({concat, z_bcast}, /*dimension=*/1); + auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto y = Parameter(&builder, 1, f32_scalar, "y"); + auto z = Parameter(&builder, 2, f32_scalar, "y"); + auto y_bcast = Broadcast(y, {1, 5, 7}); + auto z_bcast = Broadcast(z, {4, 1, 7}); + auto concat = ConcatInDim(&builder, {y_bcast, x}, /*dimension=*/0); + ConcatInDim(&builder, {concat, z_bcast}, /*dimension=*/1); Array3D y_bcast3d(1, 5, 7, 1.5f); Array3D z_bcast3d(4, 1, 7, 5.5f); auto concat0 = ReferenceUtil::Concat3D(y_bcast3d, x3d, 0); diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index 7ff6706935740c7d76ee5cd03eae292386760397..b27c1044baf2c0002f166c53a81e4361c60d012a 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -26,8 +26,8 @@ class ConditionalOpTest : public ClientLibraryTestBase { protected: XlaComputation CreateR0ConstantComputation(float value) { XlaBuilder builder("Constant"); - builder.Parameter(0, empty_tuple_, "tuple"); - builder.ConstantR0(value); + Parameter(&builder, 0, empty_tuple_, "tuple"); + ConstantR0(&builder, value); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -35,7 +35,7 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateR0IdentityComputation() { XlaBuilder builder("Identity"); - builder.Parameter(0, r0f32_, "x"); + Parameter(&builder, 0, r0f32_, "x"); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -43,8 +43,8 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateCeilComputation(const Shape& shape) { XlaBuilder builder("Ceil"); - auto param = builder.Parameter(0, shape, "param"); - builder.Ceil(param); + auto param = Parameter(&builder, 0, shape, "param"); + Ceil(param); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -60,8 +60,8 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateFloorComputation(const Shape& shape) { XlaBuilder builder("Floor"); - auto param = builder.Parameter(0, shape, "param"); - builder.Floor(param); + auto param = Parameter(&builder, 0, shape, "param"); + Floor(param); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -78,12 +78,12 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateTupleCeilComputation(const string& computation_name, const Shape& tuple_shape) { XlaBuilder builder(computation_name); - auto tuple = builder.Parameter(0, tuple_shape, "tuple"); - auto x = builder.GetTupleElement(tuple, 0); - auto y = builder.GetTupleElement(tuple, 1); - auto x_ceil = builder.Ceil(x); - auto y_ceil = builder.Ceil(y); - builder.Tuple({x_ceil, y_ceil}); + auto tuple = Parameter(&builder, 0, tuple_shape, "tuple"); + auto x = GetTupleElement(tuple, 0); + auto y = GetTupleElement(tuple, 1); + auto x_ceil = Ceil(x); + auto y_ceil = Ceil(y); + Tuple(&builder, {x_ceil, y_ceil}); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -100,12 +100,12 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateTupleFloorComputation(const string& computation_name, const Shape& tuple_shape) { XlaBuilder builder(computation_name); - auto tuple = builder.Parameter(0, tuple_shape, "tuple"); - auto x = builder.GetTupleElement(tuple, 0); - auto y = builder.GetTupleElement(tuple, 1); - auto x_floor = builder.Floor(x); - auto y_floor = builder.Floor(y); - builder.Tuple({x_floor, y_floor}); + auto tuple = Parameter(&builder, 0, tuple_shape, "tuple"); + auto x = GetTupleElement(tuple, 0); + auto y = GetTupleElement(tuple, 1); + auto x_floor = Floor(x); + auto y_floor = Floor(y); + Tuple(&builder, {x_floor, y_floor}); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -122,10 +122,10 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateTupleAddComputation(const string& computation_name, const Shape& tuple_shape) { XlaBuilder builder(computation_name); - auto tuple = builder.Parameter(0, tuple_shape, "tuple"); - auto x = builder.GetTupleElement(tuple, 0); - auto y = builder.GetTupleElement(tuple, 1); - builder.Add(x, y); + auto tuple = Parameter(&builder, 0, tuple_shape, "tuple"); + auto x = GetTupleElement(tuple, 0); + auto y = GetTupleElement(tuple, 1); + Add(x, y); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -142,10 +142,10 @@ class ConditionalOpTest : public ClientLibraryTestBase { XlaComputation CreateTupleSubComputation(const string& computation_name, const Shape& tuple_shape) { XlaBuilder builder(computation_name); - auto tuple = builder.Parameter(0, tuple_shape, "tuple"); - auto x = builder.GetTupleElement(tuple, 0); - auto y = builder.GetTupleElement(tuple, 1); - builder.Sub(x, y); + auto tuple = Parameter(&builder, 0, tuple_shape, "tuple"); + auto x = GetTupleElement(tuple, 0); + auto y = GetTupleElement(tuple, 1); + Sub(x, y); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); @@ -172,198 +172,215 @@ class ConditionalOpTest : public ClientLibraryTestBase { // Test true and false computations that do not take any parameters. XLA_TEST_F(ConditionalOpTest, Parameters0) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto operands = builder.Tuple({}); + XlaOp pred; + auto pred_arg = CreateR0Parameter(true, 0, "pred", &builder, &pred); + auto operands = Tuple(&builder, {}); auto true_computation = CreateR0ConstantComputation(56.0f); auto false_computation = CreateR0ConstantComputation(12.0f); - builder.Conditional(pred, operands, true_computation, operands, - false_computation); + Conditional(pred, operands, true_computation, operands, false_computation); - ComputeAndCompareR0(&builder, 56.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 56.0f, {pred_arg.get()}, error_spec_); } // Test true and false computations that take in 1 parameter. XLA_TEST_F(ConditionalOpTest, Parameters1) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(56.0f); - auto operand2 = builder.ConstantR0(12.0f); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); + auto operand1 = ConstantR0(&builder, 56.0f); + auto operand2 = ConstantR0(&builder, 12.0f); auto identity = CreateR0IdentityComputation(); - builder.Conditional(pred, operand1, identity, operand2, identity); + Conditional(pred, operand1, identity, operand2, identity); - ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test conditional with two different computations in the true and false cases // that take in different arguments. XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(56.4f); - auto operand2 = builder.ConstantR0(12.6f); - builder.Conditional(pred, operand1, CreateR0CeilComputation(), operand2, - CreateR0FloorComputation()); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); + auto operand1 = ConstantR0(&builder, 56.4f); + auto operand2 = ConstantR0(&builder, 12.6f); + Conditional(pred, operand1, CreateR0CeilComputation(), operand2, + CreateR0FloorComputation()); - ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test conditional with two different computations in the true and false cases // that take in the same arguments. XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand = builder.ConstantR0(12.6f); - builder.Conditional(pred, operand, CreateR0CeilComputation(), operand, - CreateR0FloorComputation()); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); + auto operand = ConstantR0(&builder, 12.6f); + Conditional(pred, operand, CreateR0CeilComputation(), operand, + CreateR0FloorComputation()); - ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test conditional with the same computation in the true and false cases but // take in different arguments. XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(56.4f); - auto operand2 = builder.ConstantR0(12.6f); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); + auto operand1 = ConstantR0(&builder, 56.4f); + auto operand2 = ConstantR0(&builder, 12.6f); auto floor = CreateR0FloorComputation(); - builder.Conditional(pred, operand1, floor, operand2, floor); + Conditional(pred, operand1, floor, operand2, floor); - ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test conditional with the same computation in the true and false cases that // take in the same arguments. XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand = builder.ConstantR0(12.6f); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); + auto operand = ConstantR0(&builder, 12.6f); auto floor = CreateR0FloorComputation(); - builder.Conditional(pred, operand, floor, operand, floor); + Conditional(pred, operand, floor, operand, floor); - ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test conditional with different instances of the same computation in the true // and false cases. XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(56.4f); - auto operand2 = builder.ConstantR0(12.6f); - builder.Conditional(pred, operand1, CreateR0FloorComputation(), operand2, - CreateR0FloorComputation()); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); + auto operand1 = ConstantR0(&builder, 56.4f); + auto operand2 = ConstantR0(&builder, 12.6f); + Conditional(pred, operand1, CreateR0FloorComputation(), operand2, + CreateR0FloorComputation()); - ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test the case when a call invokes a computation that contains a conditional. XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); XlaBuilder inner_builder(TestName() + ".inner_conditional"); - auto pred_cond = inner_builder.Parameter(0, r0bool, "param0"); - auto true_operand = inner_builder.Parameter(1, r0f32_, "param1"); - auto false_operand = inner_builder.Parameter(2, r0f32_, "param2"); - inner_builder.Conditional(pred_cond, true_operand, CreateR0CeilComputation(), - false_operand, CreateR0FloorComputation()); + auto pred_cond = Parameter(&inner_builder, 0, r0bool, "param0"); + auto true_operand = Parameter(&inner_builder, 1, r0f32_, "param1"); + auto false_operand = Parameter(&inner_builder, 2, r0f32_, "param2"); + Conditional(pred_cond, true_operand, CreateR0CeilComputation(), false_operand, + CreateR0FloorComputation()); auto inner_builder_result = inner_builder.Build(); XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(56.4f); - auto operand2 = builder.ConstantR0(12.6f); - builder.Call(inner_builder_result.ConsumeValueOrDie(), - {pred, operand1, operand2}); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); + auto operand1 = ConstantR0(&builder, 56.4f); + auto operand2 = ConstantR0(&builder, 12.6f); + Call(&builder, inner_builder_result.ConsumeValueOrDie(), + {pred, operand1, operand2}); - ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test true and false computations that take in 2 parameters and predicate is // true. XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto operand1 = builder.ConstantR0(56.0f); - auto operand2 = builder.ConstantR0(12.0f); - auto operands = builder.Tuple({operand1, operand2}); - builder.Conditional(pred, operands, CreateR0TupleAddComputation(), operands, - CreateR0TupleSubComputation()); - - ComputeAndCompareR0(&builder, 68.0f, {}, error_spec_); + XlaOp pred; + auto pred_arg = CreateR0Parameter(true, 0, "pred", &builder, &pred); + auto operand1 = ConstantR0(&builder, 56.0f); + auto operand2 = ConstantR0(&builder, 12.0f); + auto operands = Tuple(&builder, {operand1, operand2}); + Conditional(pred, operands, CreateR0TupleAddComputation(), operands, + CreateR0TupleSubComputation()); + + ComputeAndCompareR0(&builder, 68.0f, {pred_arg.get()}, error_spec_); } // Test true and false computations that take in 2 parameters and predicate is // false. XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(56.0f); - auto operand2 = builder.ConstantR0(12.0f); - auto operands = builder.Tuple({operand1, operand2}); - builder.Conditional(pred, operands, CreateR0TupleAddComputation(), operands, - CreateR0TupleSubComputation()); - - ComputeAndCompareR0(&builder, 44.0f, {}, error_spec_); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); + auto operand1 = ConstantR0(&builder, 56.0f); + auto operand2 = ConstantR0(&builder, 12.0f); + auto operands = Tuple(&builder, {operand1, operand2}); + Conditional(pred, operands, CreateR0TupleAddComputation(), operands, + CreateR0TupleSubComputation()); + + ComputeAndCompareR0(&builder, 44.0f, {pred_arg.get()}, error_spec_); } // Test true and false computations that take in 2 array parameters and // predicate is true. XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto operand1 = builder.ConstantR1({24.0f, 56.0f}); - auto operand2 = builder.ConstantR1({10.0f, 11.0f}); - auto operands = builder.Tuple({operand1, operand2}); - builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands, - CreateR1TupleSubComputation()); - - ComputeAndCompareR1(&builder, {34.0f, 67.0f}, {}, error_spec_); + XlaOp pred; + auto pred_arg = CreateR0Parameter(true, 0, "pred", &builder, &pred); + auto operand1 = ConstantR1(&builder, {24.0f, 56.0f}); + auto operand2 = ConstantR1(&builder, {10.0f, 11.0f}); + auto operands = Tuple(&builder, {operand1, operand2}); + Conditional(pred, operands, CreateR1TupleAddComputation(), operands, + CreateR1TupleSubComputation()); + + ComputeAndCompareR1(&builder, {34.0f, 67.0f}, {pred_arg.get()}, + error_spec_); } // Test true and false computations that take in 2 array parameters and // predicate is false. XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operand1 = builder.ConstantR1({24.0f, 56.0f}); - auto operand2 = builder.ConstantR1({10.0f, 11.0f}); - auto operands = builder.Tuple({operand1, operand2}); - builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands, - CreateR1TupleSubComputation()); - - ComputeAndCompareR1(&builder, {14.0f, 45.0f}, {}, error_spec_); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); + auto operand1 = ConstantR1(&builder, {24.0f, 56.0f}); + auto operand2 = ConstantR1(&builder, {10.0f, 11.0f}); + auto operands = Tuple(&builder, {operand1, operand2}); + Conditional(pred, operands, CreateR1TupleAddComputation(), operands, + CreateR1TupleSubComputation()); + + ComputeAndCompareR1(&builder, {14.0f, 45.0f}, {pred_arg.get()}, + error_spec_); } // Test true and false computations that return a tuple of scalars. XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operands = builder.Tuple( - {builder.ConstantR0(12.2f), builder.ConstantR0(25.6f)}); - builder.Conditional(pred, operands, CreateR0TupleCeilComputation(), operands, - CreateR0TupleFloorComputation()); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); + auto operands = Tuple(&builder, {ConstantR0(&builder, 12.2f), + ConstantR0(&builder, 25.6f)}); + Conditional(pred, operands, CreateR0TupleCeilComputation(), operands, + CreateR0TupleFloorComputation()); ComputeAndCompareTuple( &builder, - *Literal::MakeTuple({Literal::CreateR0(12.0f).get(), - Literal::CreateR0(25.0f).get()}), - {}, error_spec_); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(12.0f).get(), + LiteralUtil::CreateR0(25.0f).get()}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a tuple of arrays. XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto operands = builder.Tuple({builder.ConstantR1({12.2f, 15.8f}), - builder.ConstantR1({25.6f, 29.2f})}); - builder.Conditional(pred, operands, CreateR1TupleCeilComputation(), operands, - CreateR1TupleFloorComputation()); + XlaOp pred; + auto pred_arg = CreateR0Parameter(true, 0, "pred", &builder, &pred); + auto operands = + Tuple(&builder, {ConstantR1(&builder, {12.2f, 15.8f}), + ConstantR1(&builder, {25.6f, 29.2f})}); + Conditional(pred, operands, CreateR1TupleCeilComputation(), operands, + CreateR1TupleFloorComputation()); ComputeAndCompareTuple( &builder, - *Literal::MakeTuple({Literal::CreateR1({13.0f, 16.0f}).get(), - Literal::CreateR1({26.0f, 30.0f}).get()}), - {}, error_spec_); + *LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1({13.0f, 16.0f}).get(), + LiteralUtil::CreateR1({26.0f, 30.0f}).get()}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a tuple of a predicate, a @@ -371,85 +388,91 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { XlaBuilder true_builder(TestName() + ".true"); { - true_builder.Parameter(0, empty_tuple_, "tuple"); - auto true_pred = true_builder.ConstantR0(true); - auto true_scalar = true_builder.ConstantR0(12.2f); - auto true_array = true_builder.ConstantR1({12.8f, 14.6f}); - true_builder.Tuple({true_pred, true_scalar, true_array}); + Parameter(&true_builder, 0, empty_tuple_, "tuple"); + auto true_pred = ConstantR0(&true_builder, true); + auto true_scalar = ConstantR0(&true_builder, 12.2f); + auto true_array = ConstantR1(&true_builder, {12.8f, 14.6f}); + Tuple(&true_builder, {true_pred, true_scalar, true_array}); } auto true_builder_result = true_builder.Build(); EXPECT_IS_OK(true_builder_result.status()); XlaBuilder false_builder(TestName() + ".false"); { - false_builder.Parameter(0, empty_tuple_, "tuple"); - auto false_pred = false_builder.ConstantR0(false); - auto false_scalar = false_builder.ConstantR0(25.6f); - auto false_array = false_builder.ConstantR1({26.4f, 32.6f}); - false_builder.Tuple({false_pred, false_scalar, false_array}); + Parameter(&false_builder, 0, empty_tuple_, "tuple"); + auto false_pred = ConstantR0(&false_builder, false); + auto false_scalar = ConstantR0(&false_builder, 25.6f); + auto false_array = ConstantR1(&false_builder, {26.4f, 32.6f}); + Tuple(&false_builder, {false_pred, false_scalar, false_array}); } auto false_builder_result = false_builder.Build(); EXPECT_IS_OK(false_builder_result.status()); XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto operands = builder.Tuple({}); - builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), - operands, false_builder_result.ConsumeValueOrDie()); + XlaOp pred; + auto pred_arg = CreateR0Parameter(true, 0, "pred", &builder, &pred); + auto operands = Tuple(&builder, {}); + Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands, + false_builder_result.ConsumeValueOrDie()); ComputeAndCompareTuple( &builder, - *Literal::MakeTuple({Literal::CreateR0(true).get(), - Literal::CreateR0(12.2f).get(), - Literal::CreateR1({12.8f, 14.6f}).get()}), - {}, error_spec_); + *LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(true).get(), + LiteralUtil::CreateR0(12.2f).get(), + LiteralUtil::CreateR1({12.8f, 14.6f}).get()}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a nested tuple. XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { XlaBuilder true_builder(TestName() + ".true"); { - true_builder.Parameter(0, empty_tuple_, "tuple"); - auto true_constant1 = true_builder.ConstantR0(12.2f); - auto true_constant2 = true_builder.ConstantR1({12.8f, 14.6f}); - auto true_constant3 = true_builder.ConstantR1({25.4f, 29.8f}); - auto true_constant4 = true_builder.ConstantR0(35.6f); - true_builder.Tuple({true_builder.Tuple({true_constant1, true_constant2}), - true_builder.Tuple({true_constant3, true_constant4})}); + Parameter(&true_builder, 0, empty_tuple_, "tuple"); + auto true_constant1 = ConstantR0(&true_builder, 12.2f); + auto true_constant2 = ConstantR1(&true_builder, {12.8f, 14.6f}); + auto true_constant3 = ConstantR1(&true_builder, {25.4f, 29.8f}); + auto true_constant4 = ConstantR0(&true_builder, 35.6f); + Tuple(&true_builder, + {Tuple(&true_builder, {true_constant1, true_constant2}), + Tuple(&true_builder, {true_constant3, true_constant4})}); } auto true_builder_result = true_builder.Build(); EXPECT_IS_OK(true_builder_result.status()); XlaBuilder false_builder(TestName() + ".false"); { - false_builder.Parameter(0, empty_tuple_, "tuple"); - auto false_constant1 = false_builder.ConstantR0(46.6f); - auto false_constant2 = false_builder.ConstantR1({54.4f, 58.4f}); - auto false_constant3 = false_builder.ConstantR1({62.1f, 67.4f}); - auto false_constant4 = false_builder.ConstantR0(9.3f); - false_builder.Tuple( - {false_builder.Tuple({false_constant1, false_constant2}), - false_builder.Tuple({false_constant3, false_constant4})}); + Parameter(&false_builder, 0, empty_tuple_, "tuple"); + auto false_constant1 = ConstantR0(&false_builder, 46.6f); + auto false_constant2 = ConstantR1(&false_builder, {54.4f, 58.4f}); + auto false_constant3 = ConstantR1(&false_builder, {62.1f, 67.4f}); + auto false_constant4 = ConstantR0(&false_builder, 9.3f); + Tuple(&false_builder, + {Tuple(&false_builder, {false_constant1, false_constant2}), + Tuple(&false_builder, {false_constant3, false_constant4})}); } auto false_builder_result = false_builder.Build(); EXPECT_IS_OK(false_builder_result.status()); XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto operands = builder.Tuple({}); - builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), - operands, false_builder_result.ConsumeValueOrDie()); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); + auto operands = Tuple(&builder, {}); + Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands, + false_builder_result.ConsumeValueOrDie()); ComputeAndCompareTuple( &builder, - *Literal::MakeTuple( - {Literal::MakeTuple({Literal::CreateR0(46.6f).get(), - Literal::CreateR1({54.4f, 58.4f}).get()}) + *LiteralUtil::MakeTuple( + {LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(46.6f).get(), + LiteralUtil::CreateR1({54.4f, 58.4f}).get()}) .get(), - Literal::MakeTuple({Literal::CreateR1({62.1f, 67.4f}).get(), - Literal::CreateR0(9.3f).get()}) + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1({62.1f, 67.4f}).get(), + LiteralUtil::CreateR0(9.3f).get()}) .get()}), - {}, error_spec_); + {pred_arg.get()}, error_spec_); } // Test conditional that takes in scalar operands in the form of external @@ -464,8 +487,8 @@ XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) { CreateR0Parameter(56.3f, 1, "operand1", &builder, &operand1); auto operand2_param = CreateR0Parameter(12.7f, 2, "operand2", &builder, &operand2); - builder.Conditional(pred, operand1, CreateR0CeilComputation(), operand2, - CreateR0FloorComputation()); + Conditional(pred, operand1, CreateR0CeilComputation(), operand2, + CreateR0FloorComputation()); ComputeAndCompareR0( &builder, 57.0f, @@ -484,8 +507,8 @@ XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) { &builder, &operand1); auto operand2_param = CreateR1Parameter({10.2f, 11.6f}, 2, "operand2", &builder, &operand2); - builder.Conditional(pred, operand1, CreateR1CeilComputation(), operand2, - CreateR1FloorComputation()); + Conditional(pred, operand1, CreateR1CeilComputation(), operand2, + CreateR1FloorComputation()); ComputeAndCompareR1( &builder, {10.0f, 11.0f}, @@ -499,29 +522,29 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) { { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); - auto param0 = inner_builder.Parameter(0, tuple_shape, "param0"); - auto pred_cond = inner_builder.GetTupleElement(param0, 0); - auto true_operand = inner_builder.GetTupleElement(param0, 1); - auto false_operand = inner_builder.GetTupleElement(param0, 2); - inner_builder.Conditional(pred_cond, true_operand, - CreateR0CeilComputation(), false_operand, - CreateR0FloorComputation()); + auto param0 = Parameter(&inner_builder, 0, tuple_shape, "param0"); + auto pred_cond = GetTupleElement(param0, 0); + auto true_operand = GetTupleElement(param0, 1); + auto false_operand = GetTupleElement(param0, 2); + Conditional(pred_cond, true_operand, CreateR0CeilComputation(), + false_operand, CreateR0FloorComputation()); } auto inner_builder_result = inner_builder.Build(); EXPECT_IS_OK(inner_builder_result.status()); XlaBuilder builder(TestName()); - auto pred1 = builder.ConstantR0(true); - auto pred2 = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(1.1f); - auto operand2 = builder.ConstantR0(12.2f); - auto operand3 = builder.ConstantR0(43.3f); - auto tuple_operand = builder.Tuple({pred2, operand1, operand2}); - builder.Conditional(pred1, tuple_operand, - inner_builder_result.ConsumeValueOrDie(), operand3, - CreateR0IdentityComputation()); - - ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); + XlaOp pred1, pred2; + auto pred1_arg = CreateR0Parameter(true, 0, "pred1", &builder, &pred1); + auto pred2_arg = CreateR0Parameter(false, 1, "pred2", &builder, &pred2); + auto operand1 = ConstantR0(&builder, 1.1f); + auto operand2 = ConstantR0(&builder, 12.2f); + auto operand3 = ConstantR0(&builder, 43.3f); + auto tuple_operand = Tuple(&builder, {pred2, operand1, operand2}); + Conditional(pred1, tuple_operand, inner_builder_result.ConsumeValueOrDie(), + operand3, CreateR0IdentityComputation()); + + ComputeAndCompareR0(&builder, 12.0f, + {pred1_arg.get(), pred2_arg.get()}, error_spec_); } XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { @@ -529,36 +552,36 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); - auto param0 = inner_builder.Parameter(0, tuple_shape, "param0"); - auto pred_cond = inner_builder.GetTupleElement(param0, 0); - auto true_operand = inner_builder.GetTupleElement(param0, 1); - auto false_operand = inner_builder.GetTupleElement(param0, 2); - inner_builder.Conditional(pred_cond, true_operand, - CreateR0CeilComputation(), false_operand, - CreateR0FloorComputation()); + auto param0 = Parameter(&inner_builder, 0, tuple_shape, "param0"); + auto pred_cond = GetTupleElement(param0, 0); + auto true_operand = GetTupleElement(param0, 1); + auto false_operand = GetTupleElement(param0, 2); + Conditional(pred_cond, true_operand, CreateR0CeilComputation(), + false_operand, CreateR0FloorComputation()); } auto inner_builder_result = inner_builder.Build(); EXPECT_IS_OK(inner_builder_result.status()); XlaBuilder builder(TestName()); - auto pred2 = builder.ConstantR0(false); - auto operand1 = builder.ConstantR0(1.1f); - auto operand2 = builder.ConstantR0(12.2f); - auto tuple_operand = builder.Tuple({pred2, operand1, operand2}); - builder.Call(inner_builder_result.ConsumeValueOrDie(), {tuple_operand}); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); + auto operand1 = ConstantR0(&builder, 1.1f); + auto operand2 = ConstantR0(&builder, 12.2f); + auto tuple_operand = Tuple(&builder, {pred, operand1, operand2}); + Call(&builder, inner_builder_result.ConsumeValueOrDie(), {tuple_operand}); - ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test a mismatch in the shape of the true operand and true computation. XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto operand1 = builder.ConstantR0(56.0f); - auto operand2 = builder.ConstantR0(12.0f); - auto operands = builder.Tuple({operand1, operand2}); - builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands, - CreateR0TupleSubComputation()); + auto pred = ConstantR0(&builder, true); + auto operand1 = ConstantR0(&builder, 56.0f); + auto operand2 = ConstantR0(&builder, 12.0f); + auto operands = Tuple(&builder, {operand1, operand2}); + Conditional(pred, operands, CreateR1TupleAddComputation(), operands, + CreateR0TupleSubComputation()); auto result = builder.Build(); EXPECT_FALSE(result.ok()); @@ -572,46 +595,47 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { XlaComputation swapper; { XlaBuilder builder(TestName() + ".swapper"); - auto param0 = builder.Parameter(0, tuple_shape, "sp0"); - auto x = builder.GetTupleElement(param0, 0); - auto y = builder.GetTupleElement(param0, 1); - builder.Tuple({y, x}); + auto param0 = Parameter(&builder, 0, tuple_shape, "sp0"); + auto x = GetTupleElement(param0, 0); + auto y = GetTupleElement(param0, 1); + Tuple(&builder, {y, x}); swapper = builder.Build().ConsumeValueOrDie(); } XlaComputation forwarder; { XlaBuilder builder(TestName() + ".forwarder"); - auto param0 = builder.Parameter(0, tuple_shape, "fp0"); - auto x = builder.GetTupleElement(param0, 0); - auto y = builder.GetTupleElement(param0, 1); - builder.Tuple({x, y}); + auto param0 = Parameter(&builder, 0, tuple_shape, "fp0"); + auto x = GetTupleElement(param0, 0); + auto y = GetTupleElement(param0, 1); + Tuple(&builder, {x, y}); forwarder = builder.Build().ConsumeValueOrDie(); } XlaComputation main; { XlaBuilder builder(TestName() + ".main"); - auto param0 = builder.Parameter(0, tuple_shape, "mp0"); - auto x = builder.GetTupleElement(param0, 0); - auto y = builder.GetTupleElement(param0, 1); - auto lt_pred = builder.Lt(x, y); - auto res = builder.Conditional(lt_pred, param0, forwarder, param0, swapper); - auto ge_pred = builder.Ge(x, y); - builder.Conditional(ge_pred, res, swapper, res, forwarder); + auto param0 = Parameter(&builder, 0, tuple_shape, "mp0"); + auto x = GetTupleElement(param0, 0); + auto y = GetTupleElement(param0, 1); + auto lt_pred = Lt(x, y); + auto res = Conditional(lt_pred, param0, forwarder, param0, swapper); + auto ge_pred = Ge(x, y); + Conditional(ge_pred, res, swapper, res, forwarder); main = builder.Build().ConsumeValueOrDie(); } auto test_swap = [&](float a, float b) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR0(a); - auto y = builder.ConstantR0(b); - auto tuple_operand = builder.Tuple({x, y}); - builder.Call(main, {tuple_operand}); + XlaOp x, y; + auto x_arg = CreateR0Parameter(a, 0, "x", &builder, &x); + auto y_arg = CreateR0Parameter(b, 1, "y", &builder, &y); + auto tuple_operand = Tuple(&builder, {x, y}); + Call(&builder, main, {tuple_operand}); ComputeAndCompareTuple( &builder, - *Literal::MakeTuple({Literal::CreateR0(a).get(), - Literal::CreateR0(b).get()}), - {}, error_spec_); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(a).get(), + LiteralUtil::CreateR0(b).get()}), + {x_arg.get(), y_arg.get()}, error_spec_); }; test_swap(3.11f, 9.4f); diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 916ffadbc798ec0dd016f45b0bc4c36233455ee7..49375748319ad5fe40db507a034ec4b07adb7e84 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -22,10 +22,11 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -39,7 +40,7 @@ class ConstantsTest : public ClientLibraryTestBase { TEST_F(ConstantsTest, ZeroCellF32) { XlaBuilder builder(TestName()); - builder.ConstantR1({}); + ConstantR1(&builder, {}); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -48,7 +49,7 @@ TEST_F(ConstantsTest, OneCellF32) { std::vector constant = {2.0}; XlaBuilder builder(TestName()); - builder.ConstantR1(constant); + ConstantR1(&builder, constant); ComputeAndCompareR1(&builder, constant, {}, error_spec_); } @@ -57,7 +58,7 @@ TEST_F(ConstantsTest, OneCellS32) { std::vector constant = {2}; XlaBuilder builder(TestName()); - builder.ConstantR1(constant); + ConstantR1(&builder, constant); ComputeAndCompareR1(&builder, constant, {}); } @@ -66,7 +67,7 @@ TEST_F(ConstantsTest, OneCellU32) { std::vector constant = {2}; XlaBuilder builder(TestName()); - builder.ConstantR1(constant); + ConstantR1(&builder, constant); ComputeAndCompareR1(&builder, constant, {}); } @@ -75,7 +76,7 @@ TEST_F(ConstantsTest, EightCells) { std::vector constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; XlaBuilder builder(TestName()); - builder.ConstantR1(constant); + ConstantR1(&builder, constant); ComputeAndCompareR1(&builder, constant, {}, error_spec_); } @@ -85,14 +86,14 @@ TEST_F(ConstantsTest, SixteenCells) { 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0}; XlaBuilder builder(TestName()); - builder.ConstantR1(constant); + ConstantR1(&builder, constant); ComputeAndCompareR1(&builder, constant, {}, error_spec_); } TEST_F(ConstantsTest, Empty_0x2) { XlaBuilder builder(TestName()); - builder.ConstantR2FromArray2D(Array2D(0, 2)); + ConstantR2FromArray2D(&builder, Array2D(0, 2)); ComputeAndCompareR2(&builder, Array2D(0, 2), {}, error_spec_); } @@ -102,15 +103,15 @@ TEST_F(ConstantsTest, Small_2x2) { MakeLinspaceArray2D(100.0, 200.0, 2, 2); XlaBuilder builder(TestName()); - builder.ConstantR2FromArray2D(*constant); + ConstantR2FromArray2D(&builder, *constant); ComputeAndCompareR2(&builder, *constant, {}, error_spec_); } TEST_F(ConstantsTest, Empty_3x0x2) { XlaBuilder builder(TestName()); - auto constant = builder.ConstantLiteral( - *Literal::CreateR3FromArray3D(Array3D(3, 0, 2))); + ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D( + Array3D(3, 0, 2))); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {}); } @@ -125,8 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) { {{5.f, 6.f}, // y0 {7.f, 8.f}}, // y1 }); - auto constant = - builder.ConstantLiteral(*Literal::CreateR3FromArray3D(array3d)); + ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D(array3d)); ComputeAndCompareR3(&builder, array3d, {}); } @@ -141,17 +141,17 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { }); input_array.FillWithPZ(pz); std::unique_ptr input_literal = - Literal::CreateR4FromArray4D(input_array); + LiteralUtil::CreateR4FromArray4D(input_array); { XlaBuilder builder(TestName()); - builder.ConstantLiteral(*input_literal); + ConstantLiteral(&builder, *input_literal); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } { XlaBuilder builder(TestName()); - builder.ConstantR4FromArray4D(input_array); + ConstantR4FromArray4D(&builder, input_array); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } } @@ -159,17 +159,26 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { // TODO(b/29263943): Support tuple constants. TEST_F(ConstantsTest, DISABLED_TupleConstant) { XlaBuilder builder(TestName()); - builder.ConstantLiteral( - *Literal::MakeTuple({Literal::CreateR2({{1.0}, {2.0}}).get(), - Literal::CreateR1({2.0, 42}).get()})); + ConstantLiteral(&builder, + *LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), + LiteralUtil::CreateR1({2.0, 42}).get()})); std::unique_ptr result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); - LiteralTestUtil::ExpectR2Near( - {{1.0}, {2.0}}, LiteralSlice(*result, {0}), error_spec_); - LiteralTestUtil::ExpectR1Near( - {2.0, 42.0}, LiteralSlice(*result, {1}), error_spec_); + LiteralTestUtil::ExpectR2Near({{1.0}, {2.0}}, + LiteralSlice(*result, {0}), error_spec_); + LiteralTestUtil::ExpectR1Near({2.0, 42.0}, LiteralSlice(*result, {1}), + error_spec_); +} + +TEST_F(ConstantsTest, Token) { + XlaBuilder builder(TestName()); + ConstantLiteral(&builder, *LiteralUtil::CreateToken()); + // TODO(b/80000000): tokens cannot be returned from computations. + Tuple(&builder, {}); + TF_ASSERT_OK(Execute(&builder, {}).status()); } } // namespace diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 722d882471a41a75c1e5e60f8c1a151b76c7e004..1adc68cc4839dcd7d89741ec016f27bc9047c9a5 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -45,44 +46,107 @@ class ConvertTest : public ClientLibraryTestBase { TEST_F(ConvertTest, ConvertR1S32ToR1S32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42, 64}); - builder.ConvertElementType(a, S32); + auto a = ConstantR1(&builder, {42, 64}); + ConvertElementType(a, S32); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); } +TEST_F(ConvertTest, ConvertR1S32ToR1U32) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {42, 64}); + ConvertElementType(a, U32); + + std::vector expected = {42, 64}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1S32ToR1PRED) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {42, 0, -64}); + ConvertElementType(a, PRED); + + std::array expected = {true, false, true}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1U32ToR1U32) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {42, 64}); + ConvertElementType(a, U32); + + std::vector expected = {42, 64}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1U32ToR1S32) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {42, 64}); + ConvertElementType(a, S32); + + std::vector expected = {42, 64}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1U32ToR1PRED) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {42, 0, 64}); + ConvertElementType(a, PRED); + + std::array expected = {true, false, true}; + ComputeAndCompareR1(&builder, expected, {}); +} + TEST_F(ConvertTest, ConvertR1F32ToR1F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.0f, 64.0f}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {42.0f, 64.0f}); + ConvertElementType(a, F32); std::vector expected = {42.0f, 64.0f}; - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1F32ToR1PRED) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {42.0f, 0.0f, 64.0f}); + ConvertElementType(a, PRED); + + std::array expected = {true, false, true}; + ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1S32ToR1F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42, 64}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {42, 64}); + ConvertElementType(a, F32); std::vector expected = {42.0f, 64.0f}; - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1PREDToR1S32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({true, false, true}); - builder.ConvertElementType(a, S32); + auto a = ConstantR1(&builder, {true, false, true}); + ConvertElementType(a, S32); std::vector expected = {1, 0, 1}; ComputeAndCompareR1(&builder, expected, {}); } +TEST_F(ConvertTest, ConvertR1PREDToR1U32) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {true, false, true}); + ConvertElementType(a, U32); + + std::vector expected = {1, 0, 1}; + ComputeAndCompareR1(&builder, expected, {}); +} + TEST_F(ConvertTest, ConvertR1PREDToR1F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({true, false, true}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {true, false, true}); + ConvertElementType(a, F32); std::vector expected = {1., 0., 1.}; ComputeAndCompareR1(&builder, expected, {}); @@ -90,17 +154,17 @@ TEST_F(ConvertTest, ConvertR1PREDToR1F32) { XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {}); + ConvertElementType(a, F32); std::vector expected = {}; - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1F32ToR1S32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({42.6, 64.4}); - builder.ConvertElementType(a, S32); + auto a = ConstantR1(&builder, {42.6, 64.4}); + ConvertElementType(a, S32); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); @@ -145,12 +209,12 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { static_cast(0x8000008000000000LL), static_cast(0x8000010000000000LL), }; - std::unique_ptr arg_literal = Literal::CreateR1({arg}); - auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); - builder.ConvertElementType(arg_param, F32); + ConvertElementType(arg_param, F32); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { @@ -164,12 +228,12 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000000, 0x80000001, 0x80000002, 0x80000003, 0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF}; - std::unique_ptr arg_literal = Literal::CreateR1({arg}); - auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); - builder.ConvertElementType(arg_param, F32); + ConvertElementType(arg_param, F32); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { @@ -182,12 +246,12 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XlaBuilder builder(TestName()); std::vector arg{0.0f, 1.0f, 16777216.0f, 16777218.0f, 2147483647.0f, 4294967040.0f}; - std::unique_ptr arg_literal = Literal::CreateR1({arg}); - auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); - builder.ConvertElementType(arg_param, U32); + ConvertElementType(arg_param, U32); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { @@ -199,12 +263,12 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; - std::unique_ptr arg_literal = Literal::CreateR1({arg}); - auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); - builder.ConvertElementType(arg_param, S64); + ConvertElementType(arg_param, S64); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { @@ -216,12 +280,12 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, -1, -0x1000}; - std::unique_ptr arg_literal = Literal::CreateR1({arg}); - auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); - builder.ConvertElementType(arg_param, S64); + ConvertElementType(arg_param, S64); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { @@ -253,12 +317,12 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { 9223370937343148032.f, -9223371487098961920.f, -9223370937343148032.f}; - std::unique_ptr arg_literal = Literal::CreateR1({arg}); - auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); - builder.ConvertElementType(arg_param, S64); + ConvertElementType(arg_param, S64); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { @@ -269,8 +333,8 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({32, 64}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {32, 64}); + ConvertElementType(a, F32); std::vector expected = {32.0, 64.0}; ComputeAndCompareR1(&builder, expected, {}); @@ -278,8 +342,8 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({32, 64}); - builder.ConvertElementType(a, S32); + auto a = ConstantR1(&builder, {32, 64}); + ConvertElementType(a, S32); std::vector expected = {32, 64}; ComputeAndCompareR1(&builder, expected, {}); @@ -287,8 +351,8 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) { XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({32, 64}); - builder.ConvertElementType(a, U32); + auto a = ConstantR1(&builder, {32, 64}); + ConvertElementType(a, U32); std::vector expected = {32, 64}; ComputeAndCompareR1(&builder, expected, {}); @@ -296,8 +360,8 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) { XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({32.0f, 64.0f}); - builder.ConvertElementType(a, F64); + auto a = ConstantR1(&builder, {32.0f, 64.0f}); + ConvertElementType(a, F64); std::vector expected = {32.0, 64.0}; ComputeAndCompareR1(&builder, expected, {}); @@ -305,8 +369,8 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) { XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({32.0, 64.0}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {32.0, 64.0}); + ConvertElementType(a, F32); std::vector expected = {32.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}); @@ -314,9 +378,9 @@ XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) { TEST_F(ConvertTest, ConvertS32Extremes) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1( - {std::numeric_limits::min(), std::numeric_limits::max()}); - builder.ConvertElementType(a, F32); + auto a = ConstantR1(&builder, {std::numeric_limits::min(), + std::numeric_limits::max()}); + ConvertElementType(a, F32); std::vector expected = { static_cast(std::numeric_limits::min()), @@ -327,10 +391,10 @@ TEST_F(ConvertTest, ConvertS32Extremes) { TEST_F(ConvertTest, ConvertMapToS32) { XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); - auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in"); - b->ConvertElementType(param, S32); - auto a = builder.ConstantR1({42.0f, 64.0f}); - builder.Map({a}, b->BuildAndNoteError(), {0}); + auto param = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "in"); + ConvertElementType(param, S32); + auto a = ConstantR1(&builder, {42.0f, 64.0f}); + Map(&builder, {a}, b->BuildAndNoteError(), {0}); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); @@ -339,10 +403,10 @@ TEST_F(ConvertTest, ConvertMapToS32) { TEST_F(ConvertTest, ConvertMapToF32) { XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); - auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in"); - b->ConvertElementType(param, F32); - auto a = builder.ConstantR1({42, 64}); - builder.Map({a}, b->BuildAndNoteError(), {0}); + auto param = Parameter(b.get(), 0, ShapeUtil::MakeShape(S32, {}), "in"); + ConvertElementType(param, F32); + auto a = ConstantR1(&builder, {42, 64}); + Map(&builder, {a}, b->BuildAndNoteError(), {0}); std::vector expected = {42.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -355,9 +419,9 @@ TEST_F(ConvertTest, ConvertMapToF32) { // the new convert should have the same element type as the old convert. TEST_F(ConvertTest, ConvertReshape) { XlaBuilder builder(TestName()); - auto input = builder.ConstantR1({42}); - auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); - builder.ConvertElementType(reshape, F32); + auto input = ConstantR1(&builder, {42}); + auto reshape = Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); + ConvertElementType(reshape, F32); ComputeAndCompareR0(&builder, 42.0f, {}, ErrorSpec(0.0001)); } @@ -391,13 +455,13 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, - client_->TransferToServer(*Literal::CreateR1(input))); + client_->TransferToServer(*LiteralUtil::CreateR1(input))); XlaBuilder builder(TestName()); - builder.ConvertElementType( - builder.Parameter( - 0, ShapeUtil::MakeShape(F16, {static_cast(input.size())}), - "param"), + ConvertElementType( + Parameter(&builder, 0, + ShapeUtil::MakeShape(F16, {static_cast(input.size())}), + "param"), F32); ComputeAndCompareR1(&builder, expected_output, {dot_lhs_handle.get()}); @@ -411,13 +475,13 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, - client_->TransferToServer(*Literal::CreateR1(input))); + client_->TransferToServer(*LiteralUtil::CreateR1(input))); XlaBuilder builder(TestName()); - builder.ConvertElementType( - builder.Parameter( - 0, ShapeUtil::MakeShape(F32, {static_cast(input.size())}), - "param"), + ConvertElementType( + Parameter(&builder, 0, + ShapeUtil::MakeShape(F32, {static_cast(input.size())}), + "param"), F16); ComputeAndCompareR1(&builder, expected_output, {dot_lhs_handle.get()}); @@ -426,28 +490,28 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { XLA_TEST_F(ConvertTest, ConvertC64ToC64) { XlaBuilder builder(TestName()); std::vector x = {{42.0f, 64.0f}}; - builder.ConvertElementType(builder.ConstantR1(x), C64); + ConvertElementType(ConstantR1(&builder, x), C64); ComputeAndCompareR1(&builder, x, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConvertTest, ConvertS64S64) { XlaBuilder builder(TestName()); std::vector x = {{-42, 64}}; - builder.ConvertElementType(builder.ConstantR1(x), S64); + ConvertElementType(ConstantR1(&builder, x), S64); ComputeAndCompareR1(&builder, x, {}); } XLA_TEST_F(ConvertTest, ConvertU64U64) { XlaBuilder builder(TestName()); std::vector x = {{42, 64}}; - builder.ConvertElementType(builder.ConstantR1(x), U64); + ConvertElementType(ConstantR1(&builder, x), U64); ComputeAndCompareR1(&builder, x, {}); } XLA_TEST_F(ConvertTest, ConvertU64S64) { XlaBuilder builder(TestName()); std::vector unsigned_x = {{42, UINT64_MAX}}; - builder.ConvertElementType(builder.ConstantR1(unsigned_x), S64); + ConvertElementType(ConstantR1(&builder, unsigned_x), S64); std::vector signed_x = {{42, -1}}; ComputeAndCompareR1(&builder, signed_x, {}); } @@ -455,11 +519,31 @@ XLA_TEST_F(ConvertTest, ConvertU64S64) { XLA_TEST_F(ConvertTest, ConvertS64U64) { XlaBuilder builder(TestName()); std::vector signed_x = {{42, -1, INT64_MIN}}; - builder.ConvertElementType(builder.ConstantR1(signed_x), U64); + ConvertElementType(ConstantR1(&builder, signed_x), U64); std::vector unsigned_x = { {42, UINT64_MAX, tensorflow::MathUtil::IPow(2, 63)}}; ComputeAndCompareR1(&builder, unsigned_x, {}); } +XLA_TEST_F(ConvertTest, ConvertBF16F32) { + XlaBuilder builder(TestName()); + + std::vector all_bfloats(1 << 16); + for (int i = 0; i < all_bfloats.size(); ++i) { + all_bfloats[i].value = i; + } + + std::vector expected(all_bfloats.size()); + for (int i = 0; i < expected.size(); ++i) { + expected[i] = (1U << 16) * i; + } + + // Exhaustively test all bf16 to f32 conversions. + xla::XlaOp all_bfloats_bf16 = ConstantR1(&builder, all_bfloats); + xla::XlaOp all_bfloats_f32 = ConvertElementType(all_bfloats_bf16, F32); + BitcastConvertType(all_bfloats_f32, U32); + ComputeAndCompareR1(&builder, expected, {}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index b5a42e305987df030c15d089f5877f73bb61de1b..7b6bbc4f571af2e11306f95c24e243e78e0f4f4e 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -93,14 +93,15 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, auto weight_array = MakeUnique>(4, 3, 1, 1); weight_array->FillWithMultiples(0.2); auto weight_data = - client_->TransferToServer(*Literal::CreateR4FromArray4D(*weight_array)) + client_ + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array)) .ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto input = builder.ConstantR4FromArray4D(*input_array); + auto input = ConstantR4FromArray4D(&builder, *input_array); auto weight = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {4, 3, 1, 1}), "weight"); - auto conv1 = builder.Conv(input, weight, {1, 1}, Padding::kValid); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {4, 3, 1, 1}), "weight"); + auto conv1 = Conv(input, weight, {1, 1}, Padding::kValid); ConvolutionDimensionNumbers dim_nums = XlaBuilder::CreateDefaultConvDimensionNumbers(); @@ -117,8 +118,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, dim_nums.set_kernel_input_feature_dimension( dim_nums.kernel_output_feature_dimension()); dim_nums.set_kernel_output_feature_dimension(old_kernel_input_feature_dim); - builder.ConvWithGeneralDimensions(input, conv1, {1, 1}, Padding::kValid, - dim_nums); + ConvWithGeneralDimensions(input, conv1, {1, 1}, Padding::kValid, dim_nums); auto expected_conv1 = ReferenceUtil::ConvArray4D(*input_array, *weight_array, {1, 1}, Padding::kValid); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 346bb3a3996ee5bf662b0f74dd0c2096efbf5295..5ed8122e0073bde77bb2507a0ddd89c4365627c9 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -23,9 +23,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -89,9 +89,9 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { ASSERT_EQ(2, arhs->height()); XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR4FromArray4D(*alhs); - auto rhs = builder.ConstantR4FromArray4D(*arhs); - builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); + auto lhs = ConstantR4FromArray4D(&builder, *alhs); + auto rhs = ConstantR4FromArray4D(&builder, *arhs); + Conv(lhs, rhs, {1, 1}, Padding::kValid); ComputeAndCompare(&builder, {}, error_spec_); } @@ -109,9 +109,9 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 1, 2}); Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 1, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D input_data(1, 1, 1, 2); input_data.FillWithYX(Array2D({ @@ -123,8 +123,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { })); ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, + {std::move(*LiteralUtil::CreateFromArray(input_data)), + std::move(*LiteralUtil::CreateFromArray(filter_data))}, error_spec_); } }; @@ -140,9 +140,9 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 4, 4}); Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D input_data(1, 1, 4, 4); input_data.FillWithYX(Array2D({ @@ -157,8 +157,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { {7.0f, 8.0f}, })); ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, + {std::move(*LiteralUtil::CreateFromArray(input_data)), + std::move(*LiteralUtil::CreateFromArray(filter_data))}, error_spec_); } }; @@ -174,9 +174,9 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 4, 4}); Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kSame); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1, 1}, Padding::kSame); Array4D input_data(1, 1, 4, 4); input_data.FillWithYX(Array2D({ @@ -192,8 +192,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { })); ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, + {std::move(*LiteralUtil::CreateFromArray(input_data)), + std::move(*LiteralUtil::CreateFromArray(filter_data))}, error_spec_); } }; @@ -210,9 +210,9 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 4, 4}); Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 3, 3}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kSame); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1, 1}, Padding::kSame); Array4D input_data(1, 1, 4, 4); input_data.FillWithYX(Array2D({{1.0f, 2.0f, 3.0f, 4.0f}, @@ -224,8 +224,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}})); // clang-format on ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, + {std::move(*LiteralUtil::CreateFromArray(input_data)), + std::move(*LiteralUtil::CreateFromArray(filter_data))}, error_spec_); } }; @@ -238,9 +238,9 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1}, Padding::kValid); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1}, Padding::kValid); } Array3D input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}}); @@ -249,10 +249,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { Array3D expected({{{510, 610, 710, 810}}}); auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -268,10 +268,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { { Shape input_shape = ShapeUtil::MakeShapeWithType({1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Convolution dimensions are bf0_oi0->bo0. - builder.ConvGeneralDilated( + ConvGeneralDilated( input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2}, /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); @@ -284,10 +284,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { Array3D expected({{{570.0f, 670.0f, 770.0f}}}); auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -304,10 +304,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Convolution dimensions are bf0_oi0->bo0. - builder.ConvGeneralDilated( + ConvGeneralDilated( input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, /*lhs_dilation=*/{2}, /*rhs_dilation=*/{1}, /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); @@ -319,10 +319,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { Array3D expected({{{190, 320, 230, 380, 270, 440, 310, 500}}}); auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -335,10 +335,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Convolution dimensions are bf0_oi0->bo0. - builder.ConvGeneralDilated( + ConvGeneralDilated( input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, /*lhs_dilation=*/{2}, /*rhs_dilation=*/{2}, /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); @@ -350,10 +350,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { Array3D expected({{{510, 0, 610, 0, 710, 0, 810}}}); auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -369,10 +369,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { { Shape input_shape = ShapeUtil::MakeShapeWithType({1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Convolution dimensions are bf0_oi0->bo0. - builder.ConvGeneralDilated( + ConvGeneralDilated( input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}}, /*lhs_dilation=*/{1}, /*rhs_dilation=*/{1}, /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); @@ -386,10 +386,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}}); auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -408,8 +408,8 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { Shape input_shape = ShapeUtil::MakeShape(F32, input_dims); Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims); { - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Tensorflow dimension numbers for 3D convolution. ConvolutionDimensionNumbers dnums; @@ -429,21 +429,20 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { dnums.set_kernel_input_feature_dimension(3); dnums.set_kernel_output_feature_dimension(4); - builder.ConvWithGeneralDimensions(input, filter, {1, 1, 1}, Padding::kValid, - dnums); + ConvWithGeneralDimensions(input, filter, {1, 1, 1}, Padding::kValid, dnums); } std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota(input_elems.begin(), input_elems.end(), 1.0f); - auto input_r1 = Literal::CreateR1(input_elems); + auto input_r1 = LiteralUtil::CreateR1(input_elems); auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota(filter_elems.begin(), filter_elems.end(), 1.0f); - auto filter_r1 = Literal::CreateR1(filter_elems); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); - auto expected_r1 = Literal::CreateR1( + auto expected_r1 = LiteralUtil::CreateR1( {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446, 38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470}); auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie(); @@ -475,8 +474,8 @@ class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); { - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Tensorflow dimension numbers for 2D convolution. ConvolutionDimensionNumbers dnums; @@ -493,21 +492,20 @@ class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { dnums.set_kernel_input_feature_dimension(2); dnums.set_kernel_output_feature_dimension(3); - builder.ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, - dnums); + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums); } std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota_int_init_value(input_elems, 1); - auto input_r1 = Literal::CreateR1(input_elems); + auto input_r1 = LiteralUtil::CreateR1(input_elems); auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota_int_init_value(filter_elems, 1); - auto filter_r1 = Literal::CreateR1(filter_elems); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); - auto expected_r1 = Literal::CreateR1( + auto expected_r1 = LiteralUtil::CreateR1( {static_cast(92115), static_cast(93150), static_cast(94185)}); auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie(); @@ -541,8 +539,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, Shape input_shape = ShapeUtil::MakeShape(F32, {4, 29}); Shape filter_shape = ShapeUtil::MakeShape(F32, {4, 10}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); ConvolutionDimensionNumbers dnums; dnums.set_input_feature_dimension(0); @@ -551,7 +549,7 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, dnums.set_kernel_output_feature_dimension(1); dnums.set_output_batch_dimension(0); dnums.set_output_feature_dimension(1); - builder.ConvWithGeneralDimensions(input, filter, {}, Padding::kValid, dnums); + ConvWithGeneralDimensions(input, filter, {}, Padding::kValid, dnums); Array2D param0(4, 29); param0.FillUnique(); @@ -563,8 +561,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, expected_result.Fill(0); ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(param0)), - std::move(*Literal::CreateFromArray(param1))}, + {std::move(*LiteralUtil::CreateFromArray(param0)), + std::move(*LiteralUtil::CreateFromArray(param1))}, error_spec_); } @@ -599,8 +597,8 @@ class Convolve1D1WindowTestBase Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); { - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); // Tensorflow dimension numbers for 1D convolution. ConvolutionDimensionNumbers dnums; @@ -614,24 +612,23 @@ class Convolve1D1WindowTestBase dnums.set_kernel_input_feature_dimension(1); dnums.set_kernel_output_feature_dimension(2); - builder.ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, - dnums); + ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, dnums); } std::vector input_elems(ShapeUtil::ElementsIn(input_shape), static_cast(1.0f)); - auto input_r1 = Literal::CreateR1(input_elems); + auto input_r1 = LiteralUtil::CreateR1(input_elems); auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), static_cast(1.0f)); - auto filter_r1 = Literal::CreateR1(filter_elems); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); std::vector expect_elems(batch * output_feature * num_windows, static_cast(window_size * input_feature)); - auto expected_r1 = Literal::CreateR1(expect_elems); + auto expected_r1 = LiteralUtil::CreateR1(expect_elems); auto expected_r3 = expected_r1->Reshape({batch, num_windows, output_feature}) .ConsumeValueOrDie(); @@ -726,9 +723,9 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); Shape filter_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D input_data(1, 1, 1, 2); input_data.FillWithYX(Array2D({ @@ -740,8 +737,8 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { })); ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, + {std::move(*LiteralUtil::CreateFromArray(input_data)), + std::move(*LiteralUtil::CreateFromArray(filter_data))}, error_spec_); } @@ -754,9 +751,9 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D input_data(1, 1, 1, 2); input_data.FillIota(0); @@ -764,8 +761,8 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { filter_data.FillIota(10); ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}); + {std::move(*LiteralUtil::CreateFromArray(input_data)), + std::move(*LiteralUtil::CreateFromArray(filter_data))}); } } // namespace diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index fea850dc135e33fe098aa755c6fdd93319cd2837..6784c16715da72d337edf70fa51db42c59404136 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -27,8 +27,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -55,12 +55,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Minimal) { XlaBuilder builder(TestName()); const Array4D input_array(1, 1, 1, 1, {2}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 1, {3}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); const Array4D expected(1, 1, 1, 1, {6}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -70,12 +70,12 @@ XLA_TEST_F(ConvolutionVariantsTest, MinimalWithBatch) { XlaBuilder builder(TestName()); const Array4D input_array(5, 1, 1, 1, {1, 2, 3, 4, 5}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 1, {2}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); const Array4D expected(5, 1, 1, 1, {2, 4, 6, 8, 10}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -86,12 +86,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Flat1x1) { Array4D input_array(2, 1, 3, 4); input_array.FillWithMultiples(1); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 1, {2.3}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(2, 1, 3, 4); expected.FillWithMultiples(2.3); @@ -102,12 +102,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Deep1x1) { XlaBuilder builder(TestName()); Array4D input_array(1, 2, 1, 1, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 3, 1, 1, {12, 34, 56}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -117,12 +117,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 2, {1, 2}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 1, 1, {12}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -132,12 +132,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 3, {1, 2, 3}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 1, 2, {12, 23}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -147,12 +147,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 2, 1, {12, 34}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -162,12 +162,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 2, 1, {10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 1, 2, {13, 24}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -177,12 +177,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 2, 2, {1000, 100, 10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 1, 1, {1234}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -194,13 +194,13 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) { Array4D input_array( 2, 2, 2, 3, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, // plane 0 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 0, 0}); // plane 1 - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array( 2, 2, 1, 2, {1000, 100, 10, 1, 0.1, 0.01, 0.001, 0.0001}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected( 2, 2, 2, 2, @@ -213,12 +213,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 4, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 1, {10}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 2}, Padding::kValid); + Conv(input, filter, {1, 2}, Padding::kValid); Array4D expected(1, 1, 1, 2, {10, 30}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -228,12 +228,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 5, {1, 2, 3, 4, 5}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 1, {10}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 2}, Padding::kValid); + Conv(input, filter, {1, 2}, Padding::kValid); Array4D expected(1, 1, 1, 3, {10, 30, 50}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -243,12 +243,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 4, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 3, {100, 10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 2}, Padding::kValid); + Conv(input, filter, {1, 2}, Padding::kValid); Array4D expected(1, 1, 1, 1, {123}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -258,12 +258,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 5, {1, 2, 3, 4, 5}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 3, {100, 10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 2}, Padding::kValid); + Conv(input, filter, {1, 2}, Padding::kValid); Array4D expected(1, 1, 1, 2, {123, 345}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -273,12 +273,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 1, {10}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {2, 2}, Padding::kValid); + Conv(input, filter, {2, 2}, Padding::kValid); Array4D expected(1, 1, 2, 2, {10, 30, 70, 90}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -288,12 +288,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 1, {1}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 3, {10, 20, 30}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kSame); + Conv(input, filter, {1, 1}, Padding::kSame); Array4D expected(1, 1, 1, 1, {20}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -303,12 +303,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 3, {1, 2, 3}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 5, {10000, 1000, 100, 10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kSame); + Conv(input, filter, {1, 1}, Padding::kSame); Array4D expected(1, 1, 1, 3, {123, 1230, 12300}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -318,15 +318,15 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 3, 3, {10000, 0, 1000, // row 0 0, 100, 0, // row 1 10, 0, 1}); // row 2 - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kSame); + Conv(input, filter, {1, 1}, Padding::kSame); Array4D expected(1, 1, 2, 2, {104, 230, 2300, 10400}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -336,12 +336,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) { XlaBuilder builder(TestName()); Array4D input_array(1, 2, 1, 2, {1, 2, 3, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 2, 1, 1, {10, 1}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kSame); + Conv(input, filter, {1, 1}, Padding::kSame); Array4D expected(1, 1, 1, 2, {13, 24}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -351,12 +351,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 2, 2, {7, 13, 17, 23}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 2, 2, {216, 276, 396, 456}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -366,12 +366,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) { XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 3, {1, 2, 3}); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); const Array4D filter_array(1, 1, 1, 2, {7, 13}); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 1, 1, 2, {33, 53}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -383,15 +383,15 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) { std::vector input_data(64); std::iota(input_data.begin(), input_data.end(), 0.0); Array4D input_array(1, 1, 8, 8, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(128); std::fill(filter_data.begin(), filter_data.begin() + 64, 1.0); std::fill(filter_data.begin() + 64, filter_data.begin() + 128, 2.0); const Array4D filter_array(2, 1, 8, 8, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 2, 1, 1, {2016, 4032}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -403,14 +403,14 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) { std::vector input_data(16 * 1 * 1 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); Array4D input_array(16, 1, 1, 1, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * 1 * 1); std::iota(filter_data.begin(), filter_data.end(), 1.0); const Array4D filter_array(1, 1, 1, 1, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); std::vector expected_data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; @@ -432,14 +432,14 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) { } } } - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * ky * kx); std::iota(filter_data.begin(), filter_data.end(), 1.0); const Array4D filter_array(1, 1, ky, kx, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); std::vector expected_data(bs); for (int i = 0; i < bs; ++i) { @@ -463,14 +463,14 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) { } } } - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * ky * kx); std::iota(filter_data.begin(), filter_data.end(), 1.0); const Array4D filter_array(1, 1, ky, kx, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); std::vector expected_data = { 23, @@ -492,14 +492,14 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) { } } } - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * 8 * 8); std::iota(filter_data.begin(), filter_data.end(), 1.0); const Array4D filter_array(1, 1, 8, 8, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); std::vector expected_data = { 19664, 21744, 23824, 25904, 27984, 30064, 32144, 34224, @@ -515,7 +515,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) { std::vector input_data(2 * 8 * 8); std::iota(input_data.begin(), input_data.end(), 0.0); Array4D input_array(1, 2, 8, 8, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(2 * 2 * 8 * 8); std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4, @@ -527,9 +527,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) { std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(), 4.0); const Array4D filter_array(2, 2, 8, 8, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(1, 2, 1, 1, {14240, 30496}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -541,7 +541,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) { std::vector input_data(2 * 2 * 8 * 8); std::iota(input_data.begin(), input_data.end(), 0.0); Array4D input_array(2, 2, 8, 8, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(2 * 2 * 8 * 8); std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4, @@ -553,9 +553,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) { std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(), 4.0); const Array4D filter_array(2, 2, 8, 8, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(2, 2, 1, 1, {14240, 30496, 38816, 87840}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -567,7 +567,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) { std::vector input_data(32 * 2 * 8 * 8); std::iota(input_data.begin(), input_data.end(), 0.0); Array4D input_array(32, 2, 8, 8, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(2 * 2 * 8 * 8); std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4, @@ -579,9 +579,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) { std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(), 4.0); const Array4D filter_array(2, 2, 8, 8, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + Conv(input, filter, {1, 1}, Padding::kValid); std::vector expected_data = { 14240, 30496, 38816, 87840, 63392, 145184, 87968, @@ -613,9 +613,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) { } } - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + Conv(input, filter, {1, 1}, Padding::kValid); Array4D expected(16, 16, 1, 1); for (int i0 = 0; i0 < 16; ++i0) { @@ -635,9 +635,9 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatRhsDilation) { Array4D input_array(1, 1, 4, 6, input_data); Array4D filter_array(1, 1, 2, 3, {1, 10, 100, 2, 20, 200}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneralDilated( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{2, 2}, XlaBuilder::CreateDefaultConvDimensionNumbers()); @@ -654,9 +654,9 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation1D) { Array4D input_array(1, 1, 1, 5, input_data); Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneralDilated( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{}, /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{}, XlaBuilder::CreateDefaultConvDimensionNumbers()); @@ -677,9 +677,9 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { 200, 20, 2, // 300, 30, 3, // 400, 40, 4}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneralDilated( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{2, 1}, /*padding=*/{{1, 0}, {0, 0}}, /*lhs_dilation=*/{3, 2}, /*rhs_dilation=*/{}, XlaBuilder::CreateDefaultConvDimensionNumbers()); @@ -699,9 +699,9 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingOnBothEnds) { Array4D input_array(1, 1, 1, 5, input_data); Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneral( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneral( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {-1, -1}}, XlaBuilder::CreateDefaultConvDimensionNumbers()); @@ -718,9 +718,9 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingLowAndPositivePaddingHigh) { Array4D input_array(1, 1, 1, 5, input_data); Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneral( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneral( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {-1, 2}}, XlaBuilder::CreateDefaultConvDimensionNumbers()); @@ -737,9 +737,9 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingLowAndNegativePaddingHigh) { Array4D input_array(1, 1, 1, 5, input_data); Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneral( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneral( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {2, -1}}, XlaBuilder::CreateDefaultConvDimensionNumbers()); @@ -756,9 +756,9 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) { Array4D input_array(1, 1, 1, 5, input_data); Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneralDilated( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {3, 2}}, /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2}, @@ -781,9 +781,9 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) { Array4D input_array(1, 1, 1, 5, input_data); Array4D filter_array(1, 1, 1, 2, {10, 1}); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.ConvGeneralDilated( + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {-3, -2}}, /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2}, @@ -821,9 +821,9 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x1x2x3_Filter2x1x1x2) { Array4D filter_array(oz, iz, ky, kx, kernel_data); XlaBuilder builder(TestName()); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + Conv(input, filter, {1, 1}, Padding::kValid); std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); @@ -854,9 +854,9 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x16x1x1_Filter1x16x1x1) { Array4D filter_array(oz, iz, ky, kx, kernel_data); XlaBuilder builder(TestName()); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + Conv(input, filter, {1, 1}, Padding::kValid); std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); @@ -887,9 +887,9 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter1x16x1x1) { Array4D filter_array(oz, iz, ky, kx, kernel_data); XlaBuilder builder(TestName()); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + Conv(input, filter, {1, 1}, Padding::kValid); std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); @@ -920,9 +920,9 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) { Array4D filter_array(oz, iz, ky, kx, kernel_data); XlaBuilder builder(TestName()); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + Conv(input, filter, {1, 1}, Padding::kValid); std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); @@ -954,9 +954,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Array4D filter_array(oz, iz, ky, kx, kernel_data); XlaBuilder builder(TestName()); - auto input = builder.ConstantR4FromArray4D(input_array); - auto filter = builder.ConstantR4FromArray4D(filter_array); - builder.Conv(input, filter, {1, 1}, Padding::kValid); + auto input = ConstantR4FromArray4D(&builder, input_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); + Conv(input, filter, {1, 1}, Padding::kValid); std::unique_ptr> expected = ReferenceUtil::ConvArray4D( input_array, filter_array, {1, 1}, Padding::kValid); @@ -970,12 +970,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { std::vector input_data(1 * 2 * 3 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); Array4D input_array(1, 2, 3, 1, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 2 * 1 * 1); std::iota(filter_data.begin(), filter_data.end(), 1.0); Array4D filter_array(1, 2, 1, 1, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); ConvolutionDimensionNumbers dnums; // NHWC input format. @@ -995,7 +995,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { dnums.set_kernel_output_feature_dimension(3); // Tests padding sizes that don't correspond either to SAME or VALID padding. - builder.ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums); + ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums); std::vector expected_data = { 0, 0, 0, 0, 0, 0, 0, // @@ -1014,12 +1014,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { std::vector input_data(1 * 2 * 3 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); Array4D input_array(1, 2, 3, 1, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * 1 * 1); std::iota(filter_data.begin(), filter_data.end(), 2.0); Array4D filter_array(1, 1, 1, 1, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); ConvolutionDimensionNumbers dnums; // NHWC input format. @@ -1039,7 +1039,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { dnums.set_kernel_output_feature_dimension(3); // Tests padding sizes that don't correspond either to SAME or VALID padding. - builder.ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums); + ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums); std::vector expected_data = { 0, 0, 0, 0, 0, 0, 0, 0, // @@ -1058,12 +1058,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { std::vector input_data(1 * 2 * 3 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); Array4D input_array(1, 2, 3, 1, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * 1 * 1); std::iota(filter_data.begin(), filter_data.end(), 2.0); Array4D filter_array(1, 1, 1, 1, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); ConvolutionDimensionNumbers dnums; // NHWC input format. @@ -1083,7 +1083,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { dnums.set_kernel_output_feature_dimension(3); // Tests zero padding sizes. This can use matmul for computation. - builder.ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums); + ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums); std::vector expected_data = { 2, 4, 6, // @@ -1099,12 +1099,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { std::vector input_data(1 * 2 * 3 * 2); std::iota(input_data.begin(), input_data.end(), 1.0); Array4D input_array(1, 2, 3, 2, input_data); - auto input = builder.ConstantR4FromArray4D(input_array); + auto input = ConstantR4FromArray4D(&builder, input_array); std::vector filter_data(1 * 1 * 2 * 3); std::iota(filter_data.begin(), filter_data.end(), 2.0); Array4D filter_array(1, 1, 2, 3, filter_data); - auto filter = builder.ConstantR4FromArray4D(filter_array); + auto filter = ConstantR4FromArray4D(&builder, filter_array); ConvolutionDimensionNumbers dnums; // NHWC input format. @@ -1124,7 +1124,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { dnums.set_kernel_output_feature_dimension(3); // Tests zero padding sizes. This can use matmul for computation. - builder.ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums); + ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums); std::vector expected_data = { 12, 15, 18, // @@ -1148,14 +1148,14 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) { XlaBuilder builder(TestName()); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{1, 2, 3})); - auto weights = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 2, /*values=*/{5, 6})); - auto mirrored_weights = builder.Rev(weights, {2, 3}); - builder.ConvWithGeneralPadding(gradients, mirrored_weights, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {1, 0}}); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{1, 2, 3})); + auto weights = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 2, /*values=*/{5, 6})); + auto mirrored_weights = Rev(weights, {2, 3}); + ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {1, 0}}); ComputeAndCompareR4(&builder, {{{{5, 16, 27}}}}, {}, error_spec_); } @@ -1167,16 +1167,16 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingGreaterThanHighPadding) { XlaBuilder builder(TestName()); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 1, /*values=*/{1})); - auto weights = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{1, 10, 100})); - auto mirrored_weights = builder.Rev(weights, {2, 3}); - builder.ConvGeneralDilated(gradients, mirrored_weights, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {0, 3}}, - /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{}, - XlaBuilder::CreateDefaultConvDimensionNumbers()); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 1, /*values=*/{1})); + auto weights = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{1, 10, 100})); + auto mirrored_weights = Rev(weights, {2, 3}); + ConvGeneralDilated(gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {0, 3}}, + /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{}, + XlaBuilder::CreateDefaultConvDimensionNumbers()); ComputeAndCompareR4(&builder, {{{{100, 0}}}}, {}, error_spec_); } @@ -1187,14 +1187,14 @@ XLA_TEST_F(ConvolutionVariantsTest, XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) { XlaBuilder builder(TestName()); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 1, /*values=*/{1})); - auto weights = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{1, 10, 100})); - auto mirrored_weights = builder.Rev(weights, {2, 3}); - builder.ConvWithGeneralPadding(gradients, mirrored_weights, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {1, 1}}); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 1, /*values=*/{1})); + auto weights = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{1, 10, 100})); + auto mirrored_weights = Rev(weights, {2, 3}); + ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {1, 1}}); ComputeAndCompareR4(&builder, {{{{10}}}}, {}, error_spec_); } @@ -1208,14 +1208,14 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) { XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) { XlaBuilder builder(TestName()); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{1, 2, 3})); - auto weights = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 2, /*values=*/{1, 10})); - auto mirrored_weights = builder.Rev(weights, {2, 3}); - builder.ConvWithGeneralPadding(gradients, mirrored_weights, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {0, 2}}); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{1, 2, 3})); + auto weights = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 2, /*values=*/{1, 10})); + auto mirrored_weights = Rev(weights, {2, 3}); + ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {0, 2}}); ComputeAndCompareR4(&builder, {{{{12, 23, 30, 0}}}}, {}, error_spec_); } @@ -1229,17 +1229,17 @@ XLA_TEST_F(ConvolutionVariantsTest, // weight gradients: 24,130,240 // // This pattern will be fused to backward convolution with padding=(1,2). - auto activations = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); - auto forward_conv = builder.ConvGeneralDilated( - activations, gradients, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {1, 2}}, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, - XlaBuilder::CreateDefaultConvDimensionNumbers()); - builder.Transpose(forward_conv, {0, 1, 2, 3}); + auto activations = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); + auto forward_conv = + ConvGeneralDilated(activations, gradients, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {1, 2}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, + XlaBuilder::CreateDefaultConvDimensionNumbers()); + Transpose(forward_conv, {0, 1, 2, 3}); ComputeAndCompareR4(&builder, {{{{24, 130, 240}}}}, {}, error_spec_); } @@ -1255,17 +1255,17 @@ XLA_TEST_F(ConvolutionVariantsTest, // This pattern will be fused to backward convolution with padding=(2,1). // Note: both (2,1) and (2,0) are valid padding for the backward convolution // because the stride is 2. - auto activations = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); - auto forward_conv = builder.ConvGeneralDilated( - activations, gradients, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {2, 0}}, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, - XlaBuilder::CreateDefaultConvDimensionNumbers()); - builder.Transpose(forward_conv, {0, 1, 2, 3}); + auto activations = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); + auto forward_conv = + ConvGeneralDilated(activations, gradients, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {2, 0}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, + XlaBuilder::CreateDefaultConvDimensionNumbers()); + Transpose(forward_conv, {0, 1, 2, 3}); ComputeAndCompareR4(&builder, {{{{13, 24}}}}, {}, error_spec_); } @@ -1282,17 +1282,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) { // because the stride is 2. ConvolutionFolding prefers (2,2) because cuDNN // supports even padding only -- using (2,1) would need extra effort of // canonicalization. - auto activations = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); - auto gradients = builder.ConstantR4FromArray4D( - Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); - auto forward_conv = builder.ConvGeneralDilated( - activations, gradients, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {2, 1}}, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, - XlaBuilder::CreateDefaultConvDimensionNumbers()); - builder.Transpose(forward_conv, {0, 1, 2, 3}); + auto activations = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 4, /*values=*/{1, 2, 3, 4})); + auto gradients = ConstantR4FromArray4D( + &builder, Array4D(1, 1, 1, 3, /*values=*/{100, 10, 1})); + auto forward_conv = + ConvGeneralDilated(activations, gradients, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {2, 1}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, + XlaBuilder::CreateDefaultConvDimensionNumbers()); + Transpose(forward_conv, {0, 1, 2, 3}); ComputeAndCompareR4(&builder, {{{{13, 24, 130}}}}, {}, error_spec_); } @@ -1300,14 +1300,14 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) { XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) { XlaBuilder builder(TestName()); - auto gradients = builder.ConstantR3FromArray3D( - Array3D(1, 1, 1, /*value=*/1)); + auto gradients = ConstantR3FromArray3D( + &builder, Array3D(1, 1, 1, /*value=*/1)); auto weights = - builder.ConstantR3FromArray3D(Array3D({{{1, 10, 100}}})); - auto mirrored_weights = builder.Rev(weights, {2}); - builder.ConvWithGeneralPadding(gradients, mirrored_weights, - /*window_strides=*/{1}, - /*padding=*/{{1, 1}}); + ConstantR3FromArray3D(&builder, Array3D({{{1, 10, 100}}})); + auto mirrored_weights = Rev(weights, {2}); + ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1}, + /*padding=*/{{1, 1}}); ComputeAndCompareR3(&builder, {{{10}}}, {}, error_spec_); } @@ -1315,17 +1315,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) { XlaBuilder builder(TestName()); auto activations = - builder.ConstantR3FromArray3D(Array3D({{{1, 2, 3, 4}}})); + ConstantR3FromArray3D(&builder, Array3D({{{1, 2, 3, 4}}})); auto gradients = - builder.ConstantR3FromArray3D(Array3D({{{100, 10, 1}}})); + ConstantR3FromArray3D(&builder, Array3D({{{100, 10, 1}}})); auto forward_conv = - builder.ConvGeneralDilated(activations, gradients, - /*window_strides=*/{1}, - /*padding=*/{{2, 1}}, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{2}, - XlaBuilder::CreateDefaultConvDimensionNumbers( - /*num_spatial_dims=*/1)); - builder.Transpose(forward_conv, {0, 1, 2}); + ConvGeneralDilated(activations, gradients, + /*window_strides=*/{1}, + /*padding=*/{{2, 1}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{2}, + XlaBuilder::CreateDefaultConvDimensionNumbers( + /*num_spatial_dims=*/1)); + Transpose(forward_conv, {0, 1, 2}); ComputeAndCompareR3(&builder, {{{13, 24, 130}}}, {}, error_spec_); } @@ -1333,52 +1333,52 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) { XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { XlaBuilder builder(TestName()); - auto gradients_flat = Literal::CreateR1({1}); + auto gradients_flat = LiteralUtil::CreateR1({1}); auto gradients_literal = gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); - auto gradients = builder.ConstantLiteral(*gradients_literal); + auto gradients = ConstantLiteral(&builder, *gradients_literal); - auto weights_flat = Literal::CreateR1({1, 10, 100}); + auto weights_flat = LiteralUtil::CreateR1({1, 10, 100}); auto weights_literal = weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto weights = builder.ConstantLiteral(*weights_literal); + auto weights = ConstantLiteral(&builder, *weights_literal); - auto expected_flat = Literal::CreateR1({10}); + auto expected_flat = LiteralUtil::CreateR1({10}); auto expected_literal = expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); - auto mirrored_weights = builder.Rev(weights, {2, 3, 4}); - builder.ConvWithGeneralPadding(gradients, mirrored_weights, - /*window_strides=*/{1, 1, 1}, - /*padding=*/{{0, 0}, {0, 0}, {1, 1}}); + auto mirrored_weights = Rev(weights, {2, 3, 4}); + ConvWithGeneralPadding(gradients, mirrored_weights, + /*window_strides=*/{1, 1, 1}, + /*padding=*/{{0, 0}, {0, 0}, {1, 1}}); ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { XlaBuilder builder(TestName()); - auto activations_flat = Literal::CreateR1({1, 2, 3, 4}); + auto activations_flat = LiteralUtil::CreateR1({1, 2, 3, 4}); auto activations_literal = activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie(); - auto activations = builder.ConstantLiteral(*activations_literal); + auto activations = ConstantLiteral(&builder, *activations_literal); - auto gradients_flat = Literal::CreateR1({100, 10, 1}); + auto gradients_flat = LiteralUtil::CreateR1({100, 10, 1}); auto gradients_literal = gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto gradients = builder.ConstantLiteral(*gradients_literal); + auto gradients = ConstantLiteral(&builder, *gradients_literal); - auto expected_flat = Literal::CreateR1({13, 24, 130}); + auto expected_flat = LiteralUtil::CreateR1({13, 24, 130}); auto expected_literal = expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto forward_conv = builder.ConvGeneralDilated( - activations, gradients, - /*window_strides=*/{1, 1, 1}, - /*padding=*/{{0, 0}, {0, 0}, {2, 1}}, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 1, 2}, - XlaBuilder::CreateDefaultConvDimensionNumbers( - /*num_spatial_dims=*/3)); - builder.Transpose(forward_conv, {0, 1, 2, 3, 4}); + auto forward_conv = + ConvGeneralDilated(activations, gradients, + /*window_strides=*/{1, 1, 1}, + /*padding=*/{{0, 0}, {0, 0}, {2, 1}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 1, 2}, + XlaBuilder::CreateDefaultConvDimensionNumbers( + /*num_spatial_dims=*/3)); + Transpose(forward_conv, {0, 1, 2, 3, 4}); ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 2b3390ca98cb2922410d451c06811aa9d4ff8c0b..5ef273e5a26ea8a16db864974c9bfa2c296cbce8 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -58,37 +58,38 @@ class CopyOpTest : public HloTestBase { }; XLA_TEST_F(CopyOpTest, CopyR0Bool) { - TestCopyOp(*Literal::CreateR0(true)); + TestCopyOp(*LiteralUtil::CreateR0(true)); } XLA_TEST_F(CopyOpTest, CopyR1S0U32) { - TestCopyOp(*Literal::CreateR1({})); + TestCopyOp(*LiteralUtil::CreateR1({})); } XLA_TEST_F(CopyOpTest, CopyR1S3U32) { - TestCopyOp(*Literal::CreateR1({1, 2, 3})); + TestCopyOp(*LiteralUtil::CreateR1({1, 2, 3})); } XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) { - TestCopyOp(*Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + TestCopyOp( + *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) { - TestCopyOp(*Literal::CreateR4( + TestCopyOp(*LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) { - TestCopyOp(*Literal::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); + TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); } XLA_TEST_F(CopyOpTest, CopyParameterScalar) { auto builder = HloComputation::Builder(TestName()); // Copy literal to device to use as parameter. - auto literal = Literal::CreateR0(42.0); + auto literal = LiteralUtil::CreateR0(42.0); Shape shape = literal->shape(); auto param0 = builder.AddInstruction( @@ -109,7 +110,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { auto builder = HloComputation::Builder(TestName()); - auto literal = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -131,7 +132,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { HloComputation::Builder builder(TestName()); std::unique_ptr literal = - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); // Reverse the minor-to-major order of the literal. Layout* literal_layout = literal->mutable_shape_do_not_use()->mutable_layout(); @@ -168,7 +169,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = Literal::CreateR3FromArray3D(a); + std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -202,7 +203,7 @@ void CopyOpTest::TestCopyConstantLayoutR4( HloComputation::Builder builder(TestName()); - std::unique_ptr literal = Literal::CreateR4FromArray4D(a); + std::unique_ptr literal = LiteralUtil::CreateR4FromArray4D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -248,7 +249,7 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) { auto empty = Literal::CreateFromShape(in_shape); XlaBuilder builder(TestName()); - auto param0 = builder.Parameter(0, in_shape, "input"); + Parameter(&builder, 0, in_shape, "input"); auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie(); auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape) diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc index b151187c4b8f01c5b46ccadf27d2e22a7c902e98..d12a4e7fcd7813775a81677bcaa07af60ff9b477 100644 --- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc +++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -45,7 +45,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); - auto literal = Literal::CreateR1({1, 2, 3}); + auto literal = LiteralUtil::CreateR1({1, 2, 3}); EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()})); } @@ -66,10 +66,10 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); - auto literal0 = Literal::CreateR1({1, 2, 3}); - auto literal1 = Literal::CreateR1({10, 20}); + auto literal0 = LiteralUtil::CreateR1({1, 2, 3}); + auto literal1 = LiteralUtil::CreateR1({10, 20}); EXPECT_EQ( - *Literal::MakeTuple({literal0.get(), literal1.get()}), + *LiteralUtil::MakeTuple({literal0.get(), literal1.get()}), *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()})); } @@ -93,9 +93,9 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); - auto literal0 = Literal::CreateR1({1, 2, 3}); - auto literal1 = Literal::CreateR1({10, 20}); - EXPECT_EQ(*Literal::MakeTuple({literal0.get(), literal1.get()}), + auto literal0 = LiteralUtil::CreateR1({1, 2, 3}); + auto literal1 = LiteralUtil::CreateR1({10, 20}); + EXPECT_EQ(*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}), *ExecuteAndTransfer(std::move(module), {literal0.get()})); } diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index b43d5c9ff5d75ee0e1b3c9ceb2bc295e631ac107..13c777835eb2d2519d39205cdc96f0aac4850c7d 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" @@ -73,7 +74,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateCustomCall(r0f32_, {constant}, "R0F32Add2")); @@ -94,7 +95,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { array(1, 1) = 4.0f; auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2FromArray2D(array))); + HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(array))); builder.AddInstruction( HloInstruction::CreateCustomCall(r0f32_, {constant}, "R2F32ReduceSum")); @@ -110,7 +111,7 @@ XLA_TEST_F(CustomCallTest, auto b = HloComputation::Builder(TestName()); auto input = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2FromArray2D( + HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D( Array2D{{1.0f, 2.0f}, {3.0f, 4.0f}}))); auto incremented = b.AddInstruction(HloInstruction::CreateCustomCall( ShapeUtil::MakeShape(F32, {1, 2, 2}), {input}, "Add1ToValues")); @@ -135,8 +136,8 @@ class CustomCallClientAPITest : public ClientLibraryTestBase {}; // are reserved for internal use. XLA_TEST_F(CustomCallClientAPITest, IllegalCustomCallTarget) { XlaBuilder builder(TestName()); - builder.CustomCall("$illegal", /*operands=*/{}, - ShapeUtil::MakeShape(F32, {1})); + CustomCall(&builder, "$illegal", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {1})); StatusOr> result = Execute(&builder, /*arguments=*/{}); diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index bfe688e20d182d581c3e3b545ac2289413deef7c..5f234f36a8543ad408fb3430b27844beb16a54b5 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -48,7 +48,7 @@ class DeallocationTest : public ClientLibraryTestBase { TEST_F(DeallocationTest, DeallocateScalar) { XlaBuilder builder(TestName()); - builder.ConstantR0(42.0); + ConstantR0(&builder, 42.0); auto global_data = ExecuteAndCheckTransfer(&builder, {}); // A result can be transferred an arbitrary number of times. Add an extra @@ -66,7 +66,7 @@ TEST_F(DeallocationTest, DeallocateScalar) { TEST_F(DeallocationTest, DeallocateVector) { XlaBuilder builder(TestName()); - builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); + ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); ASSERT_IS_OK(client_->Unregister(*global_data)); @@ -79,7 +79,7 @@ TEST_F(DeallocationTest, DeallocateVector) { TEST_F(DeallocationTest, DeallocateEmptyVector) { XlaBuilder builder(TestName()); - builder.ConstantR1({}); + ConstantR1(&builder, {}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); ASSERT_IS_OK(client_->Unregister(*global_data)); @@ -92,8 +92,8 @@ TEST_F(DeallocationTest, DeallocateEmptyVector) { XLA_TEST_F(DeallocationTest, DeallocateTuple) { XlaBuilder builder(TestName()); - builder.Tuple({builder.ConstantR0(42.0), - builder.ConstantR1({1.0, 2.0, 3.0})}); + Tuple(&builder, {ConstantR0(&builder, 42.0), + ConstantR1(&builder, {1.0, 2.0, 3.0})}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); ASSERT_IS_OK(client_->Unregister(*global_data)); @@ -106,9 +106,10 @@ XLA_TEST_F(DeallocationTest, DeallocateTuple) { XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) { XlaBuilder builder(TestName()); - auto element = builder.ConstantR0(42.0); - auto inner_tuple = builder.Tuple({builder.ConstantR0(42.0), element}); - builder.Tuple({element, inner_tuple, element}); + auto element = ConstantR0(&builder, 42.0); + auto inner_tuple = + Tuple(&builder, {ConstantR0(&builder, 42.0), element}); + Tuple(&builder, {element, inner_tuple, element}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); ASSERT_IS_OK(client_->Unregister(*global_data)); @@ -122,9 +123,9 @@ XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) { XLA_TEST_F(DeallocationTest, DeallocateNestedTuple) { XlaBuilder builder(TestName()); auto inner_tuple = - builder.Tuple({builder.ConstantR0(42.0), - builder.ConstantR1({1.0, 2.0, 3.0})}); - builder.Tuple({inner_tuple, builder.ConstantR1({0.123, 0.456})}); + Tuple(&builder, {ConstantR0(&builder, 42.0), + ConstantR1(&builder, {1.0, 2.0, 3.0})}); + Tuple(&builder, {inner_tuple, ConstantR1(&builder, {0.123, 0.456})}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); ASSERT_IS_OK(client_->Unregister(*global_data)); diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index 12789fe66530fe03eb33316eda652336f29971ab..2db6503afab748d7b778e26b2f9350ac64c7778b 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -54,9 +54,9 @@ class DeconstructTupleTest : public ClientLibraryTestBase { TEST_F(DeconstructTupleTest, DeconstructTuple) { XlaBuilder builder(TestName()); - auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); - auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); - builder.Tuple({const1, const2}); + auto const1 = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); + auto const2 = ConstantR1(&builder, {2.0, 4.0, 6.0, 8.0}); + Tuple(&builder, {const1, const2}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); auto result_status = client_->DeconstructTuple(*global_data); @@ -73,9 +73,9 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) { TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { XlaBuilder builder(TestName()); - auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); - auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); - builder.Tuple({const1, const2}); + auto const1 = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); + auto const2 = ConstantR1(&builder, {2.0, 4.0, 6.0, 8.0}); + Tuple(&builder, {const1, const2}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); auto result_status1 = client_->DeconstructTuple(*global_data); @@ -103,9 +103,9 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { XlaBuilder builder(TestName()); - auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); - auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); - builder.Tuple({const1, const2, const2, const1}); + auto const1 = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); + auto const2 = ConstantR1(&builder, {2.0, 4.0, 6.0, 8.0}); + Tuple(&builder, {const1, const2, const2, const1}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); auto result_status = client_->DeconstructTuple(*global_data); @@ -129,9 +129,9 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { XlaBuilder builder(TestName()); - auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); - auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); - builder.Tuple({const1, const2, const1}); + auto const1 = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); + auto const2 = ConstantR1(&builder, {2.0, 4.0, 6.0, 8.0}); + Tuple(&builder, {const1, const2, const1}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); auto result_status = client_->DeconstructTuple(*global_data); @@ -159,7 +159,7 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { TEST_F(DeconstructTupleTest, DeconstructNonTuple) { XlaBuilder builder(TestName()); - builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); + ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); auto result_status = client_->DeconstructTuple(*global_data); @@ -171,11 +171,11 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) { XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({3.14f, -100.25f}); + LiteralUtil::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0"); - builder.Tuple({p}); + auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0"); + Tuple(&builder, {p}); auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()}); auto result_status = client_->DeconstructTuple(*global_data); @@ -186,9 +186,9 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) { XlaBuilder builder(TestName()); - auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); - auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); - builder.Tuple({builder.Tuple({const1, const2}), const1}); + auto const1 = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0}); + auto const2 = ConstantR1(&builder, {2.0, 4.0, 6.0, 8.0}); + Tuple(&builder, {Tuple(&builder, {const1, const2}), const1}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); auto result_status = client_->DeconstructTuple(*global_data); diff --git a/tensorflow/compiler/xla/tests/deep_graph_test.cc b/tensorflow/compiler/xla/tests/deep_graph_test.cc index 085a5105aca1c173a7cbc211aebbeb5b254b0753..3f3e8ab712fea14be9e4a7015effdf8ce518309b 100644 --- a/tensorflow/compiler/xla/tests/deep_graph_test.cc +++ b/tensorflow/compiler/xla/tests/deep_graph_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" namespace xla { @@ -30,7 +30,7 @@ TEST_F(ClientLibraryTestBase, DeepGraph) { auto y_data = CreateR0Parameter(1, 1, "y", &b, &y); XlaOp z = x; for (int i = 0; i < kDepth; ++i) { - z = b.Add(z, y); + z = Add(z, y); } ComputeAndCompareR0(&b, /*expected=*/kDepth + 3, {x_data.get(), y_data.get()}); diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 0fd846cef8095a857dd7b2c12d8afdf409e2bd66..0e9e92ed996fbb34826d19b670c7c4920a1aad13 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -67,15 +67,16 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { XlaOp param; auto param_data = CreateParameterAndTransferLiteral( 0, - *Literal::MakeTuple({Literal::CreateR2({{1, 2}, {3, 4}}).get(), - Literal::CreateR2({{5, 6}, {7, 8}}).get()}), + *LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1, 2}, {3, 4}}).get(), + LiteralUtil::CreateR2({{5, 6}, {7, 8}}).get()}), "arg0", &builder, ¶m); - auto lhs = builder.GetTupleElement(param, 0); - auto rhs = builder.GetTupleElement(param, 1); - builder.Dot(lhs, rhs); + auto lhs = GetTupleElement(param, 0); + auto rhs = GetTupleElement(param, 1); + Dot(lhs, rhs); ComputeAndCompareLiteral(&builder, - *Literal::CreateR2({{19, 22}, {43, 50}}), + *LiteralUtil::CreateR2({{19, 22}, {43, 50}}), {param_data.get()}); } @@ -87,9 +88,9 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ZeroElementVectorDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR1(&builder, {}); + auto rhs = ConstantR1(&builder, {}); + Dot(lhs, rhs); this->template ComputeAndCompareR0(&builder, static_cast(0.0), {}, this->error_spec_); @@ -102,20 +103,20 @@ TYPED_TEST_CASE(DotOperationTest_F16F32F64, TypesF16F32F64); XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR2FromArray2D({{3.0f, 4.0f}}); - auto rhs = builder.ConstantFromArray({3.0f, 4.0f}); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, {{3.0f, 4.0f}}); + auto rhs = ConstantFromArray(&builder, {3.0f, 4.0f}); + Dot(lhs, rhs); this->template ComputeAndCompareR1(&builder, {static_cast(25.0f)}, {}, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, OneElementVectorDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR1({static_cast(2.0f)}); - auto rhs = builder.ConstantR1({static_cast(3.0f)}); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR1(&builder, {static_cast(2.0f)}); + auto rhs = ConstantR1(&builder, {static_cast(3.0f)}); + Dot(lhs, rhs); this->template ComputeAndCompareR0(&builder, static_cast(6.0f), {}, this->error_spec_); @@ -124,9 +125,9 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, VectorDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantFromArray({1.0f, 2.5f, 42.0f}); - auto rhs = builder.ConstantFromArray({11.0f, -1.0f, 0.5f}); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantFromArray(&builder, {1.0f, 2.5f, 42.0f}); + auto rhs = ConstantFromArray(&builder, {11.0f, -1.0f, 0.5f}); + Dot(lhs, rhs); this->template ComputeAndCompareR0(&builder, static_cast(29.5f), {}, this->error_spec_); @@ -136,69 +137,69 @@ std::vector MinorToMajorForIsRowMajor(bool row_major) { return {row_major ? 1 : 0, row_major ? 0 : 1}; } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x0) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 2)); + auto rhs = ConstantR2FromArray2D(&builder, Array2D(2, 0)); + Dot(lhs, rhs); this->template ComputeAndCompareR2(&builder, Array2D(0, 0), {}, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x3) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto rhs = builder.ConstantR2FromArray2D( - {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}}); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 2)); + auto rhs = ConstantR2FromArray2D( + &builder, {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}}); + Dot(lhs, rhs); this->template ComputeAndCompareR2(&builder, Array2D(0, 3), {}, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_3x2_2x0) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR2FromArray2D( - {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}}); - auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR2FromArray2D( + &builder, {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}}); + auto rhs = ConstantR2FromArray2D(&builder, Array2D(2, 0)); + Dot(lhs, rhs); this->template ComputeAndCompareR2(&builder, Array2D(3, 0), {}, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_2x0_0x2) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto lhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); - auto rhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto result = builder.Dot(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, Array2D(2, 0)); + auto rhs = ConstantR2FromArray2D(&builder, Array2D(0, 2)); + Dot(lhs, rhs); this->template ComputeAndCompareR2( &builder, Array2D(2, 2, static_cast(0.0f)), {}, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto param0 = - builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 4}), "arg0"); + Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 4}), "arg0"); auto param1 = - builder.Parameter(1, ShapeUtil::MakeShapeWithType({4, 1}), "arg1"); - auto exp0 = builder.Exp(param0); - auto result = builder.Dot(exp0, param1); + Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({4, 1}), "arg1"); + auto exp0 = Exp(param0); + Dot(exp0, param1); auto lhs_handle = this->client_ - ->TransferToServer(*Literal::CreateR2FromArray2D( + ->TransferToServer(*LiteralUtil::CreateR2FromArray2D( {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}})) .ConsumeValueOrDie(); auto rhs_handle = this->client_ - ->TransferToServer(*Literal::CreateR2FromArray2D( + ->TransferToServer(*LiteralUtil::CreateR2FromArray2D( {{1.0f}, {2.0f}, {3.0f}, {4.0f}})) .ConsumeValueOrDie(); @@ -217,23 +218,22 @@ class SquareMatrixDot : public DotOperationTest { void TestImpl(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*Literal::CreateFromArrayWithLayout( + ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 2.0f}, {3.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*Literal::CreateFromArrayWithLayout( + ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 6.0f}, {7.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) .ConsumeValueOrDie(); XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); - auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs")); + Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), + Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs")); Array2D expected({{15.0f, -2.0f}, {-25.0f, 34.0f}}); ComputeAndCompareR2(&builder, expected, @@ -287,9 +287,10 @@ void ParametricDotTest::TestImpl() { std::unique_ptr> dot_lhs_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); - std::unique_ptr dot_lhs_lit = Literal::CreateR2FromArray2DWithLayout( - *dot_lhs_data, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(param.dot_lhs_row_major))); + std::unique_ptr dot_lhs_lit = + LiteralUtil::CreateR2FromArray2DWithLayout( + *dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor( + param.dot_lhs_row_major))); std::unique_ptr dot_lhs_handle = client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie(); @@ -298,7 +299,7 @@ void ParametricDotTest::TestImpl() { Layout rhs_layout = LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.dot_rhs_row_major)); std::unique_ptr dot_rhs_lit = - Literal::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout); + LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout); std::unique_ptr dot_rhs_handle = client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie(); @@ -308,7 +309,7 @@ void ParametricDotTest::TestImpl() { if (param.has_addend) { addend_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.n); - addend_lit = Literal::CreateR2FromArray2DWithLayout( + addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout( *addend_data, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.addend_row_major))); addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie(); @@ -316,26 +317,26 @@ void ParametricDotTest::TestImpl() { XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); - auto result = builder.Dot( - builder.Parameter(0, - ShapeUtil::MakeShapeWithLayout( - prim_type, {param.m, param.k}, - MinorToMajorForIsRowMajor(param.dot_lhs_row_major)), - "dot_lhs"), - builder.Parameter(1, - ShapeUtil::MakeShapeWithLayout( - prim_type, {param.k, param.n}, - MinorToMajorForIsRowMajor(param.dot_rhs_row_major)), - "dot_rhs")); + auto result = + Dot(Parameter(&builder, 0, + ShapeUtil::MakeShapeWithLayout( + prim_type, {param.m, param.k}, + MinorToMajorForIsRowMajor(param.dot_lhs_row_major)), + "dot_lhs"), + Parameter(&builder, 1, + ShapeUtil::MakeShapeWithLayout( + prim_type, {param.k, param.n}, + MinorToMajorForIsRowMajor(param.dot_rhs_row_major)), + "dot_rhs")); if (param.has_addend) { - result = builder.Add( - result, builder.Parameter( - 2, - ShapeUtil::MakeShapeWithLayout( - prim_type, {param.m, param.n}, - MinorToMajorForIsRowMajor(param.addend_row_major)), - "addend")); + result = + Add(result, + Parameter(&builder, 2, + ShapeUtil::MakeShapeWithLayout( + prim_type, {param.m, param.n}, + MinorToMajorForIsRowMajor(param.addend_row_major)), + "addend")); } std::unique_ptr> expected; @@ -477,14 +478,14 @@ class NonsquareMatrixDot : public DotOperationTest { void TestImpl(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*Literal::CreateFromArrayWithLayout( + ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*Literal::CreateFromArrayWithLayout( + ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) @@ -492,9 +493,8 @@ class NonsquareMatrixDot : public DotOperationTest { XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); - auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs")); + Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), + Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs")); Array2D expected({{26.0f, 0.0f}, {-12.0f, 10.0f}}); @@ -512,21 +512,20 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); } XLA_TEST_F(DotOperationTest, MatrixVectorC64) { auto lhs_handle = client_ - ->TransferToServer(*Literal::CreateR2WithLayout( + ->TransferToServer(*LiteralUtil::CreateR2WithLayout( {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*Literal::CreateR2WithLayout( + ->TransferToServer(*LiteralUtil::CreateR2WithLayout( {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); - auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs")); + Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"), + Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs")); Array2D expected({{30.0, -2.0}}); @@ -534,15 +533,17 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) { &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, ConcurrentMatMult) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ConcurrentMatMult) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto matrix1 = builder.ConstantR2FromArray2D({{1.0f, 2.0f}, {3.0f, 4.0f}}); - auto matrix2 = builder.ConstantR2FromArray2D({{5.0f, 6.0f}, {7.0f, 8.0f}}); - auto matrix12 = builder.Dot(matrix1, matrix2); - auto matrix21 = builder.Dot(matrix2, matrix1); - builder.Add(matrix12, matrix21); + auto matrix1 = + ConstantR2FromArray2D(&builder, {{1.0f, 2.0f}, {3.0f, 4.0f}}); + auto matrix2 = + ConstantR2FromArray2D(&builder, {{5.0f, 6.0f}, {7.0f, 8.0f}}); + auto matrix12 = Dot(matrix1, matrix2); + auto matrix21 = Dot(matrix2, matrix1); + Add(matrix12, matrix21); Array2D expected({{42.0f, 56.0f}, {74.0f, 96.0f}}); this->template ComputeAndCompareR2(&builder, expected, {}, @@ -559,32 +560,32 @@ TYPED_TEST_CASE(DotOperationTestForBatchMatMul, TypesF16F32F64); XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { using T = TypeParam; XlaBuilder builder(this->TestName()); - auto x = - builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), "x"); - auto y = - builder.Parameter(1, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), "y"); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), + "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), + "y"); - auto x_flat = builder.Reshape(x, {0, 1, 2, 3}, {4, 2, 2}); - auto y_flat = builder.Reshape(y, {0, 1, 2, 3}, {4, 2, 2}); + auto x_flat = Reshape(x, {0, 1, 2, 3}, {4, 2, 2}); + auto y_flat = Reshape(y, {0, 1, 2, 3}, {4, 2, 2}); // Slice batches into individual matrices and multiply them. std::vector out_slices; for (int i = 0; i < 4; ++i) { // Slice off individual matrices and reshape to 2D tensors. - auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); - x_slice = builder.Reshape(x_slice, {0, 1, 2}, {2, 2}); - auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); - y_slice = builder.Reshape(y_slice, {0, 1, 2}, {2, 2}); + auto x_slice = Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); + x_slice = Reshape(x_slice, {0, 1, 2}, {2, 2}); + auto y_slice = Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); + y_slice = Reshape(y_slice, {0, 1, 2}, {2, 2}); - auto out = builder.Dot(x_slice, y_slice); - out = builder.Reshape(out, {0, 1}, {1, 2, 2}); + auto out = Dot(x_slice, y_slice); + out = Reshape(out, {0, 1}, {1, 2, 2}); out_slices.push_back(out); } - auto out_flat = builder.ConcatInDim(out_slices, 0); - builder.Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); + auto out_flat = ConcatInDim(&builder, out_slices, 0); + Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); auto x_data = this->client_ - ->TransferToServer(*Literal::CreateR4FromArray4D( + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( {{{{1000.0f, 100.0f}, {10.0f, 1.0f}}, {{2000.0f, 200.0f}, {20.0f, 2.0f}}}, {{{3000.0f, 300.0f}, {30.0f, 3.0f}}, @@ -592,7 +593,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { .ConsumeValueOrDie(); auto y_data = this->client_ - ->TransferToServer(*Literal::CreateR4FromArray4D( + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, {{{11.0f, 22.0f}, {33.0f, 44.0f}}, {{55.0f, 66.0f}, {77.0f, 88.0f}}}})) @@ -611,14 +612,14 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { {x_data.get(), y_data.get()}, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto x = - builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2, 2}), "x"); + Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 2, 2}), "x"); auto y = - builder.Parameter(1, ShapeUtil::MakeShapeWithType({2, 2, 2}), "y"); + Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({2, 2, 2}), "y"); DotDimensionNumbers dnums; dnums.add_lhs_contracting_dimensions(2); @@ -626,17 +627,17 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { dnums.add_lhs_batch_dimensions(0); dnums.add_rhs_batch_dimensions(0); - auto out = builder.DotGeneral(x, y, dnums); + DotGeneral(x, y, dnums); auto x_data = this->client_ - ->TransferToServer(*Literal::CreateR3FromArray3D( + ->TransferToServer(*LiteralUtil::CreateR3FromArray3D( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) .ConsumeValueOrDie(); auto y_data = this->client_ - ->TransferToServer(*Literal::CreateR3FromArray3D( + ->TransferToServer(*LiteralUtil::CreateR3FromArray3D( {{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}})) .ConsumeValueOrDie(); @@ -647,7 +648,49 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { {x_data.get(), y_data.get()}, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { + using T = TypeParam; + + XlaBuilder builder(this->TestName()); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), + "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), + "y"); + + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(3); + dnums.add_rhs_contracting_dimensions(2); + dnums.add_lhs_batch_dimensions(0); + dnums.add_lhs_batch_dimensions(1); + dnums.add_rhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(1); + + DotGeneral(x, y, dnums); + + auto x_data = + this->client_ + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + {{{9.0f, 10.0f}, {11.0f, 12.0f}}, + {{13.0f, 14.0f}, {15.0f, 16.0f}}}})) + .ConsumeValueOrDie(); + + auto y_data = + this->client_ + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}, + {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}})) + .ConsumeValueOrDie(); + + this->template ComputeAndCompareR4( + &builder, + /*expected=*/ + {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + {{{10.0f, 9.0f}, {12.0f, 11.0f}}, {{14.0f, 13.0f}, {16.0f, 15.0f}}}}, + {x_data.get(), y_data.get()}, this->error_spec_); +} + +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) { using T = TypeParam; for (bool transpose_lhs : {false, true}) { for (bool transpose_rhs : {false, true}) { @@ -665,32 +708,36 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { } auto lhs_handle = this->client_ - ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( - *lhs, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(row_major)))) + ->TransferToServer( + *LiteralUtil::CreateR2FromArray2DWithLayout( + *lhs, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); auto rhs_handle = this->client_ - ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( - *rhs, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(row_major)))) + ->TransferToServer( + *LiteralUtil::CreateR2FromArray2DWithLayout( + *rhs, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); XlaBuilder builder(this->TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); - auto lhs_arg = builder.Parameter( - 0, ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}), + auto lhs_arg = Parameter( + &builder, 0, + ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}), "lhs"); - auto rhs_arg = builder.Parameter( - 1, ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}), + auto rhs_arg = Parameter( + &builder, 1, + ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}), "rhs"); if (transpose_lhs) { - lhs_arg = builder.Transpose(lhs_arg, {1, 0}); + lhs_arg = Transpose(lhs_arg, {1, 0}); } if (transpose_rhs) { - rhs_arg = builder.Transpose(rhs_arg, {1, 0}); + rhs_arg = Transpose(rhs_arg, {1, 0}); } - auto result = builder.Dot(lhs_arg, rhs_arg); + Dot(lhs_arg, rhs_arg); Array2D expected({{26.0f, 0.0f}, {-12.0f, 10.0f}}); VLOG(1) << "TestTransposeFolding " << transpose_lhs << " " @@ -703,7 +750,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { } } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, DotOfConcatOptimizationWithConstLHS) { using T = TypeParam; auto prim_type = primitive_util::NativeToPrimitiveType(); @@ -713,15 +760,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}})); XlaBuilder builder(this->TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), - "rhs_arg_0"); - auto rhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), - "rhs_arg_1"); - auto rhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {1, 2}), - "rhs_arg_2"); - auto result = builder.Dot( - lhs_constant, builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0)); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_arg_0 = Parameter( + &builder, 0, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs_arg_0"); + auto rhs_arg_1 = Parameter( + &builder, 1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs_arg_1"); + auto rhs_arg_2 = Parameter( + &builder, 2, ShapeUtil::MakeShape(prim_type, {1, 2}), "rhs_arg_2"); + Dot(lhs_constant, + ConcatInDim(&builder, {rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0)); std::unique_ptr> arg_0_value_array( new Array2D({{1.0f, 2.0f}, {3.0f, 4.0f}})); @@ -732,15 +779,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_0_value_array))); + *LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_1_value_array))); + *LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_2_value_array))); + *LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); Array2D expected({{53.0f, 74.0f}, {45.0f, 66.0f}}); this->template ComputeAndCompareR2( @@ -749,7 +796,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, DotOfConcatOptimizationWithConstRHS) { using T = TypeParam; std::unique_ptr> constant_rhs_array( @@ -761,15 +808,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, {2.0f, 1.0f}})); XlaBuilder builder(this->TestName()); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2}), - "lhs_arg_0"); - auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShapeWithType({2, 3}), - "lhs_arg_1"); - auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShapeWithType({2, 1}), - "lhs_arg_2"); - auto result = builder.Dot( - builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), rhs_constant); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto lhs_arg_0 = Parameter( + &builder, 0, ShapeUtil::MakeShapeWithType({2, 2}), "lhs_arg_0"); + auto lhs_arg_1 = Parameter( + &builder, 1, ShapeUtil::MakeShapeWithType({2, 3}), "lhs_arg_1"); + auto lhs_arg_2 = Parameter( + &builder, 2, ShapeUtil::MakeShapeWithType({2, 1}), "lhs_arg_2"); + Dot(ConcatInDim(&builder, {lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), + rhs_constant); std::unique_ptr> arg_0_value_array( new Array2D({{1.0f, 2.0f}, {3.0f, 4.0f}})); @@ -781,15 +828,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_0_value_array))); + *LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_1_value_array))); + *LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_2_value_array))); + *LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); Array2D expected({{38.0f, 36.0f}, {93.0f, 91.0f}}); this->template ComputeAndCompareR2( @@ -811,16 +858,15 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) { // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({1, 0}); - auto dynamic_slice = - builder.DynamicSlice(lhs_constant, start_constant, {1, 6}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {1, 0}); + auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + DotGeneral(dynamic_slice, rhs_constant, dot_dnums); Array2D expected({{96.0, 105.0, 114.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); @@ -839,25 +885,23 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) { // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({0, 1}); - auto dynamic_slice = - builder.DynamicSlice(rhs_constant, start_constant, {6, 1}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {0, 1}); + auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + DotGeneral(lhs_constant, dynamic_slice, dot_dnums); Array2D expected({{105.0}, {105.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstRHSReverseMM)))) { + + DotOfGatherOptimizationWithConstRHSReverseMM) { std::unique_ptr> constant_lhs_array( new Array2D({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, @@ -870,25 +914,21 @@ XLA_TEST_F(DotOperationTest, // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({0, 1}); - auto dynamic_slice = - builder.DynamicSlice(lhs_constant, start_constant, {6, 1}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {0, 1}); + auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(1); - auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + DotGeneral(dynamic_slice, rhs_constant, dot_dnums); Array2D expected({{105.0, 105.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstLHSReverseMM)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) { std::unique_ptr> constant_lhs_array( new Array2D({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, @@ -901,25 +941,21 @@ XLA_TEST_F(DotOperationTest, // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({1, 0}); - auto dynamic_slice = - builder.DynamicSlice(rhs_constant, start_constant, {1, 6}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {1, 0}); + auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(1); - auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + DotGeneral(lhs_constant, dynamic_slice, dot_dnums); Array2D expected({{96.0}, {105.0}, {114.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU( - DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSRows)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) { std::unique_ptr> constant_lhs_array( new Array2D({{1.0, 2.0}, {3.0, 4.0}, @@ -937,25 +973,21 @@ XLA_TEST_F(DotOperationTest, // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({0, 1}); - auto dynamic_slice = - builder.DynamicSlice(lhs_constant, start_constant, {6, 1}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {0, 1}); + auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); - auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + DotGeneral(dynamic_slice, rhs_constant, dot_dnums); Array2D expected({{126.0, 129.0, 132.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU( - DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSRows)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) { std::unique_ptr> constant_lhs_array( new Array2D({{1.0, 2.0}, {3.0, 4.0}, @@ -973,25 +1005,21 @@ XLA_TEST_F(DotOperationTest, // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({0, 1}); - auto dynamic_slice = - builder.DynamicSlice(rhs_constant, start_constant, {6, 1}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {0, 1}); + auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {6, 1}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); - auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + DotGeneral(lhs_constant, dynamic_slice, dot_dnums); Array2D expected({{129.0}, {129.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU( - DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSCols)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) { std::unique_ptr> constant_lhs_array(new Array2D( {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); std::unique_ptr> constant_rhs_array( @@ -1001,25 +1029,21 @@ XLA_TEST_F(DotOperationTest, // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({1, 0}); - auto dynamic_slice = - builder.DynamicSlice(lhs_constant, start_constant, {1, 6}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {1, 0}); + auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(1); - auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + DotGeneral(dynamic_slice, rhs_constant, dot_dnums); Array2D expected({{56.0, 168.0, 91.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU( - DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSCols)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) { std::unique_ptr> constant_lhs_array(new Array2D( {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); std::unique_ptr> constant_rhs_array( @@ -1029,19 +1053,41 @@ XLA_TEST_F(DotOperationTest, // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} XlaBuilder builder(TestName()); - auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); - auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto start_constant = builder.ConstantR1({1, 0}); - auto dynamic_slice = - builder.DynamicSlice(rhs_constant, start_constant, {1, 6}); + auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array); + auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array); + auto start_constant = ConstantR1(&builder, {1, 0}); + auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {1, 6}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(1); - auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + DotGeneral(lhs_constant, dynamic_slice, dot_dnums); Array2D expected({{168.0}, {168.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } + +XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) { + XlaBuilder builder(TestName()); + + Array2D lhs_array({{1.0f, 2.0f}, {3.0f, 4.0f}}); + auto lhs_constant = ConstantR2FromArray2D(&builder, lhs_array); + + Array2D rhs_array({{5.0f, 6.0f}, {7.0f, 8.0f}}); + auto rhs_constant = ConstantR2FromArray2D(&builder, rhs_array); + + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + DotGeneral(lhs_constant, rhs_constant, dot_dnums); + + Array2D expected({ + {26.f, 30.f}, + {38.f, 44.f}, + }); + + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 49f3a10d227f2f9edfe76405ba13498fe822f8d8..7f6f203a1ba48e0053f799c58bbbeae87aef1f7f 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" @@ -124,11 +124,11 @@ class DynamicSliceTest : public ClientLibraryTestBase { // vector is special so that it cannot be an ArraySlice, which // is what the code below wants. So instead we do this. Literal input_values = - std::move(*Literal::CreateR1(input_values_int) + std::move(*LiteralUtil::CreateR1(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_values = - std::move(*Literal::CreateR1(expected_values_int) + std::move(*LiteralUtil::CreateR1(expected_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -138,8 +138,8 @@ class DynamicSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_values); - builder.DynamicSlice(input, starts, slice_sizes); + auto input = ConstantLiteral(&builder, input_values); + DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -150,11 +150,11 @@ class DynamicSliceTest : public ClientLibraryTestBase { const std::vector& slice_sizes, const Array2D& expected_values_int) { Literal input_values = - std::move(*Literal::CreateR2FromArray2D(input_values_int) + std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_values = - std::move(*Literal::CreateR2FromArray2D(expected_values_int) + std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -164,8 +164,8 @@ class DynamicSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_values); - builder.DynamicSlice(input, starts, slice_sizes); + auto input = ConstantLiteral(&builder, input_values); + DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -176,11 +176,11 @@ class DynamicSliceTest : public ClientLibraryTestBase { const std::vector& slice_sizes, const Array3D& expected_values_int) { Literal input_values = - std::move(*Literal::CreateR3FromArray3D(input_values_int) + std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_values = - std::move(*Literal::CreateR3FromArray3D(expected_values_int) + std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -190,8 +190,8 @@ class DynamicSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_values); - builder.DynamicSlice(input, starts, slice_sizes); + auto input = ConstantLiteral(&builder, input_values); + DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -202,18 +202,28 @@ XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R1OOB) { TestR1OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } +XLA_TEST_F(DynamicSliceTest, UInt32R1OOB) { + RunR1({0, 1, 2, 3, 4}, {2147483648u}, {2}, {3, 4}); +} XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R2OOB) { TestR2OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2(); } +XLA_TEST_F(DynamicSliceTest, UInt32R2OOB) { + RunR2({{0, 1}, {2, 3}}, {2147483648u, 0}, {1, 1}, {{2}}); +} XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R3OOB) { TestR3OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3(); } XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } +XLA_TEST_F(DynamicSliceTest, UInt32R3OOB) { + RunR3({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}, + {2147483648u, 0, 2147483648u}, {1, 1, 1}, {{{5}}}); +} XLA_TEST_F(DynamicSliceTest, Int32R1Pred) { // Slice at dimension start. @@ -349,15 +359,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { void RunR0(int input_value_int, int update_value_int, const std::vector slice_starts, int expected_value_int) { Literal input_value = - std::move(*Literal::CreateR0(input_value_int) + std::move(*LiteralUtil::CreateR0(input_value_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal update_value = - std::move(*Literal::CreateR0(update_value_int) + std::move(*LiteralUtil::CreateR0(update_value_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_value = - std::move(*Literal::CreateR0(expected_value_int) + std::move(*LiteralUtil::CreateR0(expected_value_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -367,9 +377,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_value); - auto update = builder.ConstantLiteral(update_value); - builder.DynamicUpdateSlice(input, update, starts); + auto input = ConstantLiteral(&builder, input_value); + auto update = ConstantLiteral(&builder, update_value); + DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_value, {start_data.get()}); } @@ -380,15 +390,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, tensorflow::gtl::ArraySlice expected_values_int) { Literal input_values = - std::move(*Literal::CreateR1(input_values_int) + std::move(*LiteralUtil::CreateR1(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal update_values = - std::move(*Literal::CreateR1(update_values_int) + std::move(*LiteralUtil::CreateR1(update_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_values = - std::move(*Literal::CreateR1(expected_values_int) + std::move(*LiteralUtil::CreateR1(expected_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -398,9 +408,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_values); - auto update = builder.ConstantLiteral(update_values); - builder.DynamicUpdateSlice(input, update, starts); + auto input = ConstantLiteral(&builder, input_values); + auto update = ConstantLiteral(&builder, update_values); + DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -411,15 +421,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, const Array2D& expected_values_int) { Literal input_values = - std::move(*Literal::CreateR2FromArray2D(input_values_int) + std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal update_values = - std::move(*Literal::CreateR2FromArray2D(update_values_int) + std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_values = - std::move(*Literal::CreateR2FromArray2D(expected_values_int) + std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -429,9 +439,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_values); - auto update = builder.ConstantLiteral(update_values); - builder.DynamicUpdateSlice(input, update, starts); + auto input = ConstantLiteral(&builder, input_values); + auto update = ConstantLiteral(&builder, update_values); + DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -442,15 +452,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, const Array3D& expected_values_int) { Literal input_values = - std::move(*Literal::CreateR3FromArray3D(input_values_int) + std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal update_values = - std::move(*Literal::CreateR3FromArray3D(update_values_int) + std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_values = - std::move(*Literal::CreateR3FromArray3D(expected_values_int) + std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -460,9 +470,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantLiteral(input_values); - auto update = builder.ConstantLiteral(update_values); - builder.DynamicUpdateSlice(input, update, starts); + auto input = ConstantLiteral(&builder, input_values); + auto update = ConstantLiteral(&builder, update_values); + DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } @@ -508,8 +518,8 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { XlaOp update; std::unique_ptr update_data = CreateR3Parameter( update_values, 1, "update_values", &builder, &update); - auto starts = builder.ConstantR1({index, 0, 0}); - builder.DynamicUpdateSlice(input, update, starts); + auto starts = ConstantR1(&builder, {index, 0, 0}); + DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. ComputeAndCompareR3(&builder, expected_values, @@ -520,7 +530,7 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { template void DumpArray(const string& name, const Array3D values) { std::unique_ptr literal = - Literal::CreateR3FromArray3D(values); + LiteralUtil::CreateR3FromArray3D(values); LOG(INFO) << name << ":" << literal->ToString(); } }; @@ -530,21 +540,32 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R0) { TestR0(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R0) { TestR0(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R0) { TestR0(); } -// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. XLA_TEST_F(DynamicUpdateSliceTest, Int32R1BF16) { TestR1(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt32R1OOB) { + RunR1({0, 1, 2, 3, 4}, {5, 6}, {2147483648u}, {0, 1, 2, 5, 6}); +} XLA_TEST_F(DynamicUpdateSliceTest, Int32R2BF16) { TestR2(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt32R2OOB) { + RunR2({{0, 1}, {2, 3}}, {{4}}, {2147483648u, 0}, + {{0, 1}, {4, 3}}); +} XLA_TEST_F(DynamicUpdateSliceTest, Int32R3BF16) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt32R3OOB) { + RunR3({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}, {{{8}}}, + {2147483648u, 0, 2147483648u}, + {{{0, 1}, {2, 3}}, {{4, 8}, {6, 7}}}); +} XLA_TEST_F(DynamicUpdateSliceTest, Int32OOBBF16) { TestOOB(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32OOB) { TestOOB(); } @@ -695,17 +716,17 @@ void BM_DynamicSlice(int num_iters) { XlaBuilder builder("DynamicSlice"); // Create input as a constant: shape [1, 2, 3, 4] - auto input_literal = Literal::CreateR4( + auto input_literal = LiteralUtil::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); - auto input = builder.ConstantLiteral(*input_literal); + auto input = ConstantLiteral(&builder, *input_literal); // Create dynamic slice start indices as a parameter: shape [4] auto start_indices_shape = ShapeUtil::MakeShape(S32, {4}); auto start_indices = - builder.Parameter(0, start_indices_shape, "start_indices"); + Parameter(&builder, 0, start_indices_shape, "start_indices"); // Add DynamicSlice op to the computatation. - builder.DynamicSlice(input, start_indices, {1, 1, 1, 1}); + DynamicSlice(input, start_indices, {1, 1, 1, 1}); auto computation = builder.Build().ConsumeValueOrDie(); // Initialize and transfer parameter buffer. @@ -715,9 +736,11 @@ void BM_DynamicSlice(int num_iters) { start_indices_shape, &allocator, /*device_ordinal=*/0) .ConsumeValueOrDie(); - auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); + auto start_indices_literal = LiteralUtil::CreateR1({0, 1, 2, 3}); + auto stream = + client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *start_indices_literal, buffer)); + stream.get(), *start_indices_literal, buffer)); std::unique_ptr executable = client diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc index a6ba6db5d3bf86de91f6fda022c46afee01281c2..5116e60ca63ef5f94b25b15e6616086fb9e44bbb 100644 --- a/tensorflow/compiler/xla/tests/execution_profile_test.cc +++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/platform/test.h" @@ -31,10 +31,10 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr input, client_->TransferToServer( - *Literal::CreateR2F32Linspace(1e0, 1e5, 256, 256))); + *LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256))); XlaBuilder b(TestName() + ".add"); - b.Dot(b.Parameter(0, shape, "param_0"), b.Parameter(1, shape, "param_1")); + Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1")); TF_ASSERT_OK_AND_ASSIGN(XlaComputation dot_product, b.Build()); ExecutionProfile execution_profile; diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index 0a37e4d423620122f2e109343a86a964f46d778f..bf1de02ba9dbd97db9ee31484402fe9b92385219 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -39,7 +39,7 @@ class ExhaustiveF32ElementwiseOpTest XlaBuilder builder(TestName()); std::unique_ptr input_literal = - Literal::CreateFromDimensions(F32, {input_size}); + LiteralUtil::CreateFromDimensions(F32, {input_size}); for (int64 i = begin; i < end; i++) { if (i >= known_incorrect_range.first && i < known_incorrect_range.second) { @@ -54,7 +54,7 @@ class ExhaustiveF32ElementwiseOpTest TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, client_->TransferToServer(*input_literal)); - auto input = builder.Parameter(0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal->shape(), "input"); enqueue_op(&builder, input); std::vector expected_result; @@ -79,8 +79,8 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, LogF32) { #endif ExhaustivelyTestF32Op( - [](XlaBuilder* builder, const XlaOp& input) { builder->Log(input); }, - std::log, known_incorrect_range); + [](XlaBuilder* builder, const XlaOp& input) { Log(input); }, std::log, + known_incorrect_range); } XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ExpF32) { @@ -95,14 +95,14 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ExpF32) { #endif ExhaustivelyTestF32Op( - [](XlaBuilder* builder, const XlaOp& input) { builder->Exp(input); }, - std::exp, known_incorrect_range); + [](XlaBuilder* builder, const XlaOp& input) { Exp(input); }, std::exp, + known_incorrect_range); } XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, TanhF32) { ExhaustivelyTestF32Op( - [](XlaBuilder* builder, const XlaOp& input) { builder->Tanh(input); }, - std::tanh, /*known_incorrect_range=*/{0, 0}); + [](XlaBuilder* builder, const XlaOp& input) { Tanh(input); }, std::tanh, + /*known_incorrect_range=*/{0, 0}); } std::vector> CreateExhaustiveParameters() { diff --git a/tensorflow/compiler/xla/tests/filecheck.cc b/tensorflow/compiler/xla/tests/filecheck.cc index 93d1c921c4a138cda55ed7338b8e3aa82518d114..dcb469087e0064d17ce3b04fdeaf0b6136069a55 100644 --- a/tensorflow/compiler/xla/tests/filecheck.cc +++ b/tensorflow/compiler/xla/tests/filecheck.cc @@ -76,6 +76,11 @@ StatusOr RunFileCheck(const string& input, const string& pattern) { XLA_LOG_LINES(tensorflow::WARNING, input); LOG(WARNING) << "FileCheck pattern was:"; XLA_LOG_LINES(tensorflow::WARNING, pattern); + } else if (!standard_error.empty()) { + LOG(INFO) << "FileCheck stderr:"; + XLA_LOG_LINES(tensorflow::INFO, standard_error); + LOG(INFO) << "FileCheck input was:"; + XLA_LOG_LINES(tensorflow::INFO, input); } return succeeded; } diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index 71eb914a8e5eaef2e38b9e6e7d45b8a10ce1bd7a..39cc6c5927f1d416e31f689487efc10c20371abe 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -42,12 +42,12 @@ class FloorCeilTest : public ClientLibraryTestBase { LOG(INFO) << "input: {" << tensorflow::str_util::Join(expected, ", ") << "}"; XlaBuilder builder(TestName()); - auto c = builder.ConstantR1(input); + auto c = ConstantR1(&builder, input); if (f == kCeil) { - builder.Ceil(c); + Ceil(c); } else { ASSERT_EQ(kFloor, f); - builder.Floor(c); + Floor(c); } ComputeAndCompareR1(&builder, expected, /*arguments=*/{}); } @@ -55,12 +55,12 @@ class FloorCeilTest : public ClientLibraryTestBase { void TestR0F32(float input, float expected, Function f) { LOG(INFO) << "input: " << expected; XlaBuilder builder(TestName()); - auto c = builder.ConstantR0(input); + auto c = ConstantR0(&builder, input); if (f == kCeil) { - builder.Ceil(c); + Ceil(c); } else { ASSERT_EQ(kFloor, f); - builder.Floor(c); + Floor(c); } ComputeAndCompareR0(&builder, expected, /*arguments=*/{}); } diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc index 73f029b59bc56aa6c3e86200a49fcae0fd177101..c5bbbe778df15d63a2586bd6291a7a33fc82aa52 100644 --- a/tensorflow/compiler/xla/tests/fmax_test.cc +++ b/tensorflow/compiler/xla/tests/fmax_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/platform/test.h" @@ -28,11 +28,11 @@ class FmaxSimpleTest : public ClientLibraryTestBase {}; TEST_F(FmaxSimpleTest, FmaxTenValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); - auto y = builder.ConstantR1( - {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); - builder.Max(x, y); + auto x = ConstantR1( + &builder, {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); + auto y = ConstantR1( + &builder, {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); + Max(x, y); std::vector expected = {-0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index e6f79b5ac55dddfbb213a36cadbee53bc9443d9d..792be0d3fcd55621b9f8cdf0fdc28f7bb49294d1 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -25,14 +25,15 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -89,7 +90,7 @@ class FusionTest : public HloTestBase { HloInstruction* hlos[4]; for (int i = 0; i < Arity; ++i) { hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2FromArray2D(operand_data[i]))); + LiteralUtil::CreateR2FromArray2D(operand_data[i]))); } auto answer_shape = ShapeUtil::MakeShape(prim_type, {test_width, test_height}); @@ -115,7 +116,7 @@ class FusionTest : public HloTestBase { ArraySlice(hlos, 0, Arity + 1), HloInstruction::FusionKind::kLoop); - auto expected = Literal::CreateR2FromArray2D(answer_data); + auto expected = LiteralUtil::CreateR2FromArray2D(answer_data); auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); if (primitive_util::IsFloatingPointType(prim_type)) { EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4))); @@ -186,27 +187,28 @@ XLA_TEST_F(FusionTest, Test) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0}, {2.0}, {3.0}}))); + LiteralUtil::CreateR2({{1.0}, {2.0}, {3.0}}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{-1.0}, {-1.0}, {-1.0}}))); + LiteralUtil::CreateR2({{-1.0}, {-1.0}, {-1.0}}))); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1)); auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0})); auto const4 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.62, 2.72, 3.14}}))); + LiteralUtil::CreateR2({{1.62, 2.72, 3.14}}))); auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate( ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0)); auto const6 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}))); + LiteralUtil::CreateR2({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}))); auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6)); auto add8 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7)); auto const9 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}}))); - auto const10 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{true, false, true}, {false, true, false}}))); + LiteralUtil::CreateR2({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}}))); + auto const10 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2( + {{true, false, true}, {false, true, false}}))); auto select11 = builder.AddInstruction( HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kSelect, const10, add8, const9)); @@ -222,7 +224,7 @@ XLA_TEST_F(FusionTest, Test) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2({{0.5}, {2.72}}), + *LiteralUtil::CreateR2({{0.5}, {2.72}}), *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } @@ -233,11 +235,11 @@ XLA_TEST_F(FusionTest, Parameter) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0, 3.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}}))); auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0)); auto const2 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{-2.0, -2.0, -2.0}}))); + LiteralUtil::CreateR2({{-2.0, -2.0, -2.0}}))); // add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1} auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2)); @@ -248,7 +250,7 @@ XLA_TEST_F(FusionTest, Parameter) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2({{-1.0, 0.0, 1.0}}), + *LiteralUtil::CreateR2({{-1.0, 0.0, 1.0}}), *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } @@ -269,7 +271,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { auto hlo_module = CreateNewModule(); auto two = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto x = builder.AddInstruction(HloInstruction::CreateBroadcast(shape, two, {})); auto y = builder.AddInstruction( @@ -292,9 +294,9 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto const_array = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}))); + LiteralUtil::CreateR2({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}))); auto broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1})); // add2 = broadcast(const_vector) + const_array @@ -308,7 +310,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), + *LiteralUtil::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } @@ -316,14 +318,14 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto single_element_array = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{5}}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR2({{5}}))); auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {}), single_element_array)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR0(5), + LiteralTestUtil::Equal(*LiteralUtil::CreateR0(5), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -331,14 +333,14 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); + LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 2, 3}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), + *LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -346,14 +348,14 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); + LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}), + *LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -361,14 +363,14 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR3({{{7}}}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR3({{{7}}}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR0(7), + LiteralTestUtil::Equal(*LiteralUtil::CreateR0(7), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -376,14 +378,14 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(7))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR3({{{7}}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR3({{{7}}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -391,14 +393,14 @@ XLA_TEST_F(FusionTest, Reshape__) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(7))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR0(7), + LiteralTestUtil::Equal(*LiteralUtil::CreateR0(7), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -406,14 +408,14 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), + *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -421,14 +423,14 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0})); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR2({{1, 4}, {2, 5}, {3, 6}}), + *LiteralUtil::CreateR2({{1, 4}, {2, 5}, {3, 6}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -436,14 +438,14 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0})); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), + *LiteralUtil::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -451,7 +453,7 @@ XLA_TEST_F(FusionTest, Reverse) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( ShapeUtil::MakeShape(S32, {3}), const0, {0})); hlo_module->AddEntryComputation(builder.Build()) @@ -459,7 +461,7 @@ XLA_TEST_F(FusionTest, Reverse) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR1({3, 2, 1}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR1({3, 2, 1}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -467,7 +469,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( ShapeUtil::MakeShape(S32, {3}), const0, {0})); auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -477,7 +479,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR1({-3, -2, -1}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-3, -2, -1}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -485,7 +487,7 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(S32, {2}), const0, {})); auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -495,15 +497,15 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR1({-1, -1}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-1, -1}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, SliceNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(S32, {2}), const0, {0}, {4}, {2})); auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -513,17 +515,17 @@ XLA_TEST_F(FusionTest, SliceNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR1({-1, -3}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-1, -3}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DynamicSliceNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1}))); auto dynamic_slice2 = builder.AddInstruction(HloInstruction::CreateDynamicSlice( ShapeUtil::MakeShape(S32, {2}), const0, const1, {2})); @@ -535,15 +537,15 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR1({-2, -3}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-2, -3}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReshapeNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {2, 2}), const0)); auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -552,17 +554,16 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1}, HloInstruction::FusionKind::kLoop); - EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{-1, -2}, {-3, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{-1, -2}, {-3, -4}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } -// TODO(b/64070202): Investigate failure. -XLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) { +XLA_TEST_F(FusionTest, TransposeNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2}, {3, 4}}))); + LiteralUtil::CreateR2({{1, 2}, {3, 4}}))); auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {2, 2}), const0, {1, 0})); auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -571,9 +572,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1}, HloInstruction::FusionKind::kLoop); - EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{-1, -3}, {-2, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{-1, -3}, {-2, -4}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } std::unique_ptr MakeReduceTestComputation() { @@ -591,10 +592,10 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { auto hlo_module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 4, 8}))); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 4, 8}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); @@ -603,7 +604,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR0(15), + LiteralTestUtil::Equal(*LiteralUtil::CreateR0(15), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -611,10 +612,10 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { auto hlo_module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 4, 8}))); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 4, 8}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); @@ -625,7 +626,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR0(-15), + LiteralTestUtil::Equal(*LiteralUtil::CreateR0(-15), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -633,9 +634,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); + LiteralUtil::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); Window window; ASSERT_TRUE( tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n" @@ -675,7 +676,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR2({{462, 2145}, {24871, 62491}}), + *LiteralUtil::CreateR2({{462, 2145}, {24871, 62491}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -687,9 +688,9 @@ XLA_TEST_F(FusionTest, SharedConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0)); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -711,7 +712,7 @@ XLA_TEST_F(FusionTest, SharedConstant) { EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR1({8}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR1({8}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -765,6 +766,79 @@ XLA_TEST_F(FusionTest, Clamp2D) { TestElementwise2D(HloOpcode::kClamp); } +// TODO(b/73903144): Enable on interpreter once interpreter supports bitcast. +XLA_TEST_F(FusionTest, DISABLED_ON_INTERPRETER(FusionWithLayout)) { + const string hlo_text = R"( +HloModule Cluster + +fusion_c { + fusion.arg = f32[2,2]{1,0} parameter(0) + bitcast.0 = f32[2,2,1]{2,1,0} bitcast(fusion.arg) + tanh.0 = f32[2,2,1]{0,2,1} tanh(bitcast.0) + ROOT bitcast.2 = f32[2,2,1]{1,2,0} bitcast(tanh.0) +} + +ENTRY main { + arg = f32[2,2]{1,0} parameter(0) + ROOT fusion = f32[2,2,1]{1,2,0} fusion(arg), kind=kLoop, calls=fusion_c +} +)"; + + std::unique_ptr operand = + LiteralUtil::CreateR2({{0., 0.}, {1., 0.}}); + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text, config)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + test_runner_.Execute(std::move(module), {operand.get()}, + /*run_hlo_passes=*/false)); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), + *result)); +} + +class FusionClientLibraryTest : public ClientLibraryTestBase {}; + +XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) { + // On the GPU backend, it's possible to have too many transposes within one + // fusion, causing the kernel to run out shared memory and thus not compile. + // We want to check that doesn't happen. + // + // To do this, we create a computation that computes + // + // P0 + P0*P1*P1 + P0*P2*P2 ... + // + // where even parameters have layout 1 and odd parameters have layout 2. + // + // Our goal is to tempt the backend into creating one giant multi-output + // fusion for the whole computation, including the transposes. Currently + // multi-output fusion only fuses fusions, so each of the terms in the sum + // needs to be a fusion itself, thus the contortions above. + constexpr int kNumParams = 25; + XlaBuilder b("ManyLayoutTransformations"); + + // This test produces values that overflow int32, which is UB, so use uint32, + // where overflow is OK. + Array2D arr(32, 32); + arr.FillUnique(); + std::unique_ptr l1 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout( + LayoutUtil::MakeLayout({0, 1})); + + std::unique_ptr l2 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout( + LayoutUtil::MakeLayout({1, 0})); + + XlaOp p0 = AddParam(*l1, &b); + XlaOp sum = p0; + for (int i = 1; i < kNumParams; ++i) { + auto pN = AddParam((i % 2 == 0 ? *l1 : *l2), &b); + sum = sum + p0 * pN * pN; + } + + ComputeAndCompare(&b, {}); +} + void BM_ParallelFusion(int num_iters) { // Simple element-wise computation to benchmark parallel task partitioning. tensorflow::testing::StopTiming(); @@ -793,31 +867,31 @@ void BM_ParallelFusion(int num_iters) { // Create computation. XlaBuilder builder("ParallelFusion"); Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1}); - auto param0 = builder.Parameter(0, shape0, "param0"); + auto param0 = Parameter(&builder, 0, shape0, "param0"); Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1}); - auto param1 = builder.Parameter(1, shape1, "param1"); + auto param1 = Parameter(&builder, 1, shape1, "param1"); Shape shape2 = ShapeUtil::MakeShape(F32, {param2_dim0, param2_dim1}); - auto param2 = builder.Parameter(2, shape2, "param2"); + auto param2 = Parameter(&builder, 2, shape2, "param2"); - auto x = builder.Mul(param0, param1); - auto y = builder.Add(x, param2); + auto x = Mul(param0, param1); + Add(x, param2); auto computation = builder.Build().ConsumeValueOrDie(); // Transfer literals to device. auto param0_literal = - Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); + LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); ScopedShapedBuffer buffer0 = client->LiteralToShapedBuffer(*param0_literal, device_ordinal) .ConsumeValueOrDie(); auto param1_literal = - Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); + LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); ScopedShapedBuffer buffer1 = client->LiteralToShapedBuffer(*param1_literal, device_ordinal) .ConsumeValueOrDie(); auto param2_literal = - Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); + LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); ScopedShapedBuffer buffer2 = client->LiteralToShapedBuffer(*param2_literal, device_ordinal) .ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 143ffbdeb409d91ab6d46d386aa5ff98ebc4ae10..b77bece85ad1b2192b04330af9e60d3a424b59f4 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -21,9 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -// NB! TODO(b/74360564): These tests do not test out of bounds behavior since -// that hasn't been specced yet. - namespace xla { namespace { @@ -62,8 +60,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -83,8 +82,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -104,9 +104,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 2}, {2, 1}}); + LiteralUtil::CreateR2({{0, 2}, {2, 1}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -126,9 +126,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -148,9 +148,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -170,11 +170,11 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -194,11 +194,11 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -218,8 +218,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({1, 1}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -239,9 +240,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR2({{2, 1}, {1, 1}}); + LiteralUtil::CreateR2({{2, 1}, {1, 1}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -260,18 +261,15 @@ ENTRY main { window_bounds={1, 0} } )"; - std::unique_ptr operand = Literal::CreateR2({{}, {}, {}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); RunTest(hlo_text, operand.get(), gather_indices.get()); } XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) { // Out of bounds indices must not crash, and the indices in range should // produce the same values across all backends. - // - // TODO(b/74360564): Once we have a well defined semantics for OOB accesses, - // we should get rid of the mask and check that backends produce the same - // value for OOB indices too. const string hlo_text = R"( HloModule BatchDynamicSlice @@ -285,29 +283,45 @@ ENTRY main { gather_dims_to_operand_dims={0,1}, index_vector_dim=1, window_bounds={1,1} - gather_reshaped = s32[6]{0} reshape(gather) - in_bounds_mask = s32[6]{0} parameter(2) - ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask) + ROOT result = s32[6]{0} reshape(gather) } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR2( + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); - std::unique_ptr in_bounds_mask = - Literal::CreateR1({0, 1, 1, 0, 0, 1}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) { + // Out of bounds indices must not crash, and the indices in range should + // produce the same values across all backends. - RunTest(hlo_text, - {operand.get(), gather_indices.get(), in_bounds_mask.get()}); + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = u32[6,2]{1,0} parameter(1) + gather = s32[6,1,1]{2,1,0} gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} + ROOT result = s32[6]{0} reshape(gather) +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); } XLA_TEST_F(GatherOperationTest, NegativeIndex) { // Negative indices must not crash, and the indices in range should produce // the same values across all backends. - // - // TODO(b/74360564): Once we have a well defined semantics for negative - // accesses, we should get rid of the mask and check that backends produce the - // same value for negative indices too. const string hlo_text = R"( HloModule BatchDynamicSlice @@ -321,20 +335,40 @@ ENTRY main { gather_dims_to_operand_dims={0,1}, index_vector_dim=1, window_bounds={1,1} - gather_reshaped = s32[6]{0} reshape(gather) - in_bounds_mask = s32[6]{0} parameter(2) - ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask) + ROOT result = s32[6]{0} reshape(gather) } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR2( + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = LiteralUtil::CreateR2( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - std::unique_ptr in_bounds_mask = - Literal::CreateR1({0, 1, 1, 0, 0, 1}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) { + // Negative indices must not crash, and the indices in range should produce + // the same values across all backends. - RunTest(hlo_text, - {operand.get(), gather_indices.get(), in_bounds_mask.get()}); + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = u32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + gather = u32[6,1,1]{2,1,0} gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} + ROOT result = u32[6]{0} reshape(gather) +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = LiteralUtil::CreateR2( + {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); } XLA_TEST_F(GatherOperationTest, OneScalarIndex) { @@ -352,9 +386,9 @@ ENTRY main { window_bounds={1,3,2} } )"; - std::unique_ptr operand = Literal::CreateR3( + std::unique_ptr operand = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - std::unique_ptr gather_indices = Literal::CreateR0(1); + std::unique_ptr gather_indices = LiteralUtil::CreateR0(1); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -373,8 +407,8 @@ ENTRY main { window_bounds={1} } )"; - std::unique_ptr operand = Literal::CreateR1({1, 2, 3, 4}); - std::unique_ptr gather_indices = Literal::CreateR0(1); + std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3, 4}); + std::unique_ptr gather_indices = LiteralUtil::CreateR0(1); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -394,8 +428,8 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = LiteralUtil::CreateR1({}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -418,8 +452,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -442,9 +477,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 2}, {2, 1}}); + LiteralUtil::CreateR2({{0, 2}, {2, 1}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -467,9 +502,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -492,11 +527,11 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -520,11 +555,11 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -547,8 +582,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({1, 1}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -571,9 +607,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR2({{2, 1}, {1, 1}}); + LiteralUtil::CreateR2({{2, 1}, {1, 1}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -598,22 +634,23 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3}); Shape indices_shape = ShapeUtil::MakeShape(S32, {2}); - auto operand = builder.Parameter(0, operand_shape, "operand"); - auto indices = builder.Parameter(1, indices_shape, "indices"); + auto operand = Parameter(&builder, 0, operand_shape, "operand"); + auto indices = Parameter(&builder, 1, indices_shape, "indices"); GatherDimensionNumbers dim_numbers; dim_numbers.add_output_window_dims(1); dim_numbers.add_elided_window_dims(0); dim_numbers.add_gather_dims_to_operand_dims(0); dim_numbers.set_index_vector_dim(1); - builder.Gather(operand, indices, dim_numbers, {1, 3}); + Gather(operand, indices, dim_numbers, {1, 3}); std::vector expected = {}; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr operand_arg, - client_->TransferToServer(*Literal::CreateR2( - {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr operand_arg, + client_->TransferToServer( + *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr indices_arg, - client_->TransferToServer(*Literal::CreateR1({0, 2}))); + client_->TransferToServer(*LiteralUtil::CreateR1({0, 2}))); TF_ASSERT_OK_AND_ASSIGN(std::vector devices, client_->GetDeviceHandles(1)); xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions(); @@ -629,8 +666,8 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { client_->ExecuteParallel(computation_instances)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, client_->Transfer(*(result_data[0]))); - EXPECT_TRUE(LiteralTestUtil::Equal( - *result_literal, *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}))); + LiteralTestUtil::ExpectR2Equal({{1, 2, 3}, {7, 8, 9}}, + *result_literal); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc index 76bf47845ca045b4eede9a3b47ae5c2ce93ce577..51450314b611b49c643fb6fd5b0c0d2e7205a2d2 100644 --- a/tensorflow/compiler/xla/tests/half_test.cc +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -37,8 +37,7 @@ class HalfTestBase : public ClientLibraryTestBase { static const int kNumElements = 4; }; -using UnaryBuildFuncTy = - std::function; +using UnaryBuildFuncTy = std::function; struct UnaryOpTestParam { std::function compute_func; @@ -49,7 +48,8 @@ class UnaryOpTest : public HalfTestBase, public ::testing::WithParamInterface {}; XLA_TEST_P(UnaryOpTest, Ops) { - std::vector x({half(1.4), half(-2.3), half(3.2), half(-4.1)}); + std::vector x({half(1.4), half(-2.3), half(3.2), half(-4.1), half(9.0), + half(42.0), half(-9.0), half(-100.0)}); XlaBuilder builder(TestName()); XlaOp x_opnd; auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", @@ -62,7 +62,7 @@ XLA_TEST_P(UnaryOpTest, Ops) { } UnaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd); + build_func(x_opnd); ComputeAndCompareR1(&builder, expected, {x_data.get()}, error_spec_); } @@ -79,18 +79,17 @@ half round_imp(half value) { INSTANTIATE_TEST_CASE_P( half, UnaryOpTest, ::testing::Values( - UnaryOpTestParam{[](half x) { return abs(x); }, &XlaBuilder::Abs}, - UnaryOpTestParam{[](half x) { return round_imp(x); }, - &XlaBuilder::Round}, - UnaryOpTestParam{[](half x) { return ceil(x); }, &XlaBuilder::Ceil}, - UnaryOpTestParam{[](half x) { return cos(x); }, &XlaBuilder::Cos}, - UnaryOpTestParam{[](half x) { return exp(x); }, &XlaBuilder::Exp}, - UnaryOpTestParam{[](half x) { return floor(x); }, &XlaBuilder::Floor}, - UnaryOpTestParam{[](half x) { return log(x); }, &XlaBuilder::Log}, - UnaryOpTestParam{[](half x) { return -x; }, &XlaBuilder::Neg}, - UnaryOpTestParam{[](half x) { return sign_imp(x); }, &XlaBuilder::Sign}, - UnaryOpTestParam{[](half x) { return sin(x); }, &XlaBuilder::Sin}, - UnaryOpTestParam{[](half x) { return tanh(x); }, &XlaBuilder::Tanh} + UnaryOpTestParam{[](half x) { return abs(x); }, &Abs}, + UnaryOpTestParam{[](half x) { return round_imp(x); }, &Round}, + UnaryOpTestParam{[](half x) { return ceil(x); }, &Ceil}, + UnaryOpTestParam{[](half x) { return cos(x); }, &Cos}, + UnaryOpTestParam{[](half x) { return exp(x); }, &Exp}, + UnaryOpTestParam{[](half x) { return floor(x); }, &Floor}, + UnaryOpTestParam{[](half x) { return log(x); }, &Log}, + UnaryOpTestParam{[](half x) { return -x; }, &Neg}, + UnaryOpTestParam{[](half x) { return sign_imp(x); }, &Sign}, + UnaryOpTestParam{[](half x) { return sin(x); }, &Sin}, + UnaryOpTestParam{[](half x) { return tanh(x); }, &Tanh} )); @@ -118,19 +117,18 @@ XLA_TEST_P(UnaryPredTest, Ops) { } UnaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd); + build_func(x_opnd); ComputeAndCompareR1(&builder, expected, {x_data.get()}); } INSTANTIATE_TEST_CASE_P(half, UnaryPredTest, ::testing::Values(UnaryPredTestParam{ - [](half x) { return isfinite(x); }, - &XlaBuilder::IsFinite})); + [](half x) { return isfinite(x); }, &IsFinite})); -using BinaryBuildFuncTy = std::function)>; +using BinaryBuildFuncTy = + std::function)>; struct BinaryOpTestParam { std::function compute_func; @@ -159,7 +157,7 @@ XLA_TEST_P(BinaryOpTest, Ops) { } BinaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd, y_opnd, {}); + build_func(x_opnd, y_opnd, {}); ComputeAndCompareR1(&builder, expected, {x_data.get(), y_data.get()}, error_spec_); @@ -173,22 +171,15 @@ half atan2_imp(half x, half y) { INSTANTIATE_TEST_CASE_P( half, BinaryOpTest, ::testing::Values( - BinaryOpTestParam{[](half x, half y) { return x + y; }, - &XlaBuilder::Add}, + BinaryOpTestParam{[](half x, half y) { return x + y; }, &Add}, BinaryOpTestParam{[](half x, half y) { return atan2_imp(x, y); }, - &XlaBuilder::Atan2}, - BinaryOpTestParam{[](half x, half y) { return x / y; }, - &XlaBuilder::Div}, - BinaryOpTestParam{[](half x, half y) { return max(x, y); }, - &XlaBuilder::Max}, - BinaryOpTestParam{[](half x, half y) { return min(x, y); }, - &XlaBuilder::Min}, - BinaryOpTestParam{[](half x, half y) { return x * y; }, - &XlaBuilder::Mul}, - BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, - &XlaBuilder::Pow}, - BinaryOpTestParam{[](half x, half y) { return x - y; }, - &XlaBuilder::Sub} + &Atan2}, + BinaryOpTestParam{[](half x, half y) { return x / y; }, &Div}, + BinaryOpTestParam{[](half x, half y) { return max(x, y); }, &Max}, + BinaryOpTestParam{[](half x, half y) { return min(x, y); }, &Min}, + BinaryOpTestParam{[](half x, half y) { return x * y; }, &Mul}, + BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, &Pow}, + BinaryOpTestParam{[](half x, half y) { return x - y; }, &Sub} )); @@ -221,27 +212,22 @@ XLA_TEST_P(BinaryPredTest, Ops) { } BinaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd, y_opnd, {}); + build_func(x_opnd, y_opnd, {}); ComputeAndCompareR1(&builder, expected, {x_data.get(), y_data.get()}); } INSTANTIATE_TEST_CASE_P( half, BinaryPredTest, - ::testing::Values(BinaryPredTestParam{[](half x, half y) { return x == y; }, - &XlaBuilder::Eq}, - BinaryPredTestParam{[](half x, half y) { return x != y; }, - &XlaBuilder::Ne}, - BinaryPredTestParam{[](half x, half y) { return x >= y; }, - &XlaBuilder::Ge}, - BinaryPredTestParam{[](half x, half y) { return x > y; }, - &XlaBuilder::Gt}, - BinaryPredTestParam{[](half x, half y) { return x <= y; }, - &XlaBuilder::Le}, - BinaryPredTestParam{[](half x, half y) { return x < y; }, - &XlaBuilder::Lt} - - )); + ::testing::Values( + BinaryPredTestParam{[](half x, half y) { return x == y; }, &Eq}, + BinaryPredTestParam{[](half x, half y) { return x != y; }, &Ne}, + BinaryPredTestParam{[](half x, half y) { return x >= y; }, &Ge}, + BinaryPredTestParam{[](half x, half y) { return x > y; }, &Gt}, + BinaryPredTestParam{[](half x, half y) { return x <= y; }, &Le}, + BinaryPredTestParam{[](half x, half y) { return x < y; }, &Lt} + + )); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc index cf971dd61b71ad329b20b0bb7c16166126562681..5511190caf95544e2ac48d91c0a138db06a2544c 100644 --- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc +++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" @@ -30,9 +30,9 @@ class HloMetadataTest : public LocalClientTestBase { } void BuildAddComputation(XlaBuilder* builder) { - auto x = builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder->Add(x, y); + auto x = Parameter(builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Add(x, y); } OpMetadata metadata_; diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 242cc5db11ff2bdf69209df7537216573d8afbf3..f05d1a8b9d372e720ae1634a9c8d5c0591e39b89 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -83,13 +83,16 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) { } // namespace -HloTestBase::HloTestBase() - : HloTestBase(GetTestPlatform(), GetReferencePlatform()) {} +HloTestBase::HloTestBase(bool allow_mixed_precision_in_hlo_verifier) + : HloTestBase(GetTestPlatform(), GetReferencePlatform(), + allow_mixed_precision_in_hlo_verifier) {} HloTestBase::HloTestBase(se::Platform* test_platform, - se::Platform* reference_platform) + se::Platform* reference_platform, + bool allow_mixed_precision_in_hlo_verifier) : test_runner_(test_platform), reference_runner_(reference_platform) { - hlo_verifier_ = MakeUnique(/*allow_mixed_precision=*/true); + hlo_verifier_ = + MakeUnique(allow_mixed_precision_in_hlo_verifier); } /* static */ @@ -233,6 +236,29 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( reference_preprocessor); } +::testing::AssertionResult HloTestBase::Run(const StringPiece hlo_string) { + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_or_status.status().ToString(); + } + const auto& fake_arguments = + MakeFakeArguments(module_or_status.ValueOrDie().get()) + .ConsumeValueOrDie(); + std::vector fake_argument_ptrs; + c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const std::unique_ptr& literal) { return literal.get(); }); + return test_runner_ + .Execute(std::move(module_or_status.ValueOrDie()), + fake_argument_ptrs, /*run_hlo_passes=*/true) + .ok() + ? ::testing::AssertionSuccess() + : ::testing::AssertionFailure(); +} + ::testing::AssertionResult HloTestBase::RunAndCompareFromFile( const string& filename, const tensorflow::gtl::optional& error, const std::function& reference_preprocessor) { @@ -276,9 +302,10 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( HloComputation* HloTestBase::FindComputation(HloModule* module, tensorflow::StringPiece name) { - auto it = c_find_if(module->computations(), + auto computations = module->computations(); + auto it = c_find_if(computations, [&](HloComputation* c) { return c->name() == name; }); - if (it == module->computations().end()) { + if (it == computations.end()) { return nullptr; } return *it; @@ -287,9 +314,10 @@ HloComputation* HloTestBase::FindComputation(HloModule* module, HloInstruction* HloTestBase::FindInstruction(HloModule* module, tensorflow::StringPiece name) { for (const HloComputation* c : module->computations()) { - auto it = c_find_if(c->instructions(), + auto instructions = c->instructions(); + auto it = c_find_if(instructions, [&](HloInstruction* i) { return i->name() == name; }); - if (it != c->instructions().end()) { + if (it != instructions.end()) { return *it; } } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 249da87f489324ed9d377cc46a15cef5a9e74192..4232eeceb10b37a209f247ffa70fb9a08be337e6 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -80,12 +80,13 @@ class HloTestBase : public ::testing::Test { // automatically finds another supported backend as the test backend. If the // interpreter is the only supported backend, it will be both the test backend // and the reference backend. - HloTestBase(); + HloTestBase(bool allow_mixed_precision_in_hlo_verifier = true); // If your test doesn't use interpreter as the reference backend, you can use // this constructor. Note that your test target is responsible for linking in // both needed backends. - HloTestBase(se::Platform* test_platform, se::Platform* reference_platform); + HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, + bool allow_mixed_precision_in_hlo_verifier = true); ~HloTestBase() override {} @@ -166,6 +167,8 @@ class HloTestBase : public ::testing::Test { const tensorflow::gtl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; + ::testing::AssertionResult Run(const tensorflow::StringPiece hlo_string) + TF_MUST_USE_RESULT; ::testing::AssertionResult RunAndCompareFromFile( const string& filename, const tensorflow::gtl::optional& error, const std::function& reference_preprocessor = nullptr) @@ -185,13 +188,9 @@ class HloTestBase : public ::testing::Test { // 'layout'. void ForceParameterLayout(HloModule* module, int64 param_no, const Layout& layout) { - ASSERT_LT( - param_no, - module->mutable_host_entry_computation_layout()->parameter_count()); - module->mutable_host_entry_computation_layout() - ->mutable_parameter_layout(param_no) - ->ResetLayout(layout); - module->mutable_device_entry_computation_layout() + ASSERT_LT(param_no, + module->mutable_entry_computation_layout()->parameter_count()); + module->mutable_entry_computation_layout() ->mutable_parameter_layout(param_no) ->ResetLayout(layout); } @@ -199,21 +198,22 @@ class HloTestBase : public ::testing::Test { // Convenience method to force the layout of the computation result in a // module. The result layout of 'module' is set to 'layout'. void ForceResultLayout(HloModule* module, const Layout& layout) { - module->mutable_host_entry_computation_layout() + module->mutable_entry_computation_layout() ->mutable_result_layout() ->ResetLayout(layout); - module->mutable_device_entry_computation_layout() + } + + void ForceResultLayout(HloModule* module, const Layout& layout, + ShapeIndexView shape_index) { + module->mutable_entry_computation_layout() ->mutable_result_layout() - ->ResetLayout(layout); + ->ResetLayout(layout, shape_index); } // Convenience method to clear the layout of the computation result in // 'module'. void ForceClearResultLayout(HloModule* module) { - module->mutable_host_entry_computation_layout() - ->mutable_result_layout() - ->Clear(); - module->mutable_device_entry_computation_layout() + module->mutable_entry_computation_layout() ->mutable_result_layout() ->Clear(); } diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index 22c664d1426c598dbb695ff1b66ce009b0a19c00..ad1f5b9eed8b5b140100c1fa35dc7d698e3db48b 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -72,10 +72,10 @@ HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { return modules_.back().get(); } -void HloVerifiedTestBase::ParseAndVerifyModule( - tensorflow::StringPiece hlo_text) { +void HloVerifiedTestBase::ParseAndVerifyModule(tensorflow::StringPiece hlo_text, + const HloModuleConfig& config) { CHECK(!module_) << "Called ParseModule when test already has a module."; - TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config)); VerifyModule(module_.get()); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index 5b59cc77f61b05092d3afb331e73932c9edc5840..5b28c01c369fa1ae1c7941f5c8139882c4dbed08 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -44,7 +44,8 @@ class HloVerifiedTestBase : public HloTestBase { // Returns the default HloModule, lazily creating it if necessary via // HloTestBase::CreateNewModule(). HloModule& module(); - void ParseAndVerifyModule(tensorflow::StringPiece hlo_text); + void ParseAndVerifyModule(tensorflow::StringPiece hlo_text, + const HloModuleConfig& config = HloModuleConfig()); // Sets the shape-size function used during hlo verification. If this isn't // called, a default ShapeVerifier is used instead. diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..17ac95ae0198d98490b25f7f2edd32d1e0495803 --- /dev/null +++ b/tensorflow/compiler/xla/tests/iota_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace { + +class IotaTest : public ClientLibraryTestBase { + public: + explicit IotaTest(se::Platform* platform = nullptr) + : ClientLibraryTestBase(platform) {} + template + std::vector GetExpected(const int64 num_elements) { + std::vector result(num_elements); + std::iota(result.begin(), result.end(), 0); + return result; + } +}; + +XLA_TEST_F(IotaTest, SimpleR1) { + for (int num_elements = 1; num_elements < 10000001; num_elements *= 10) { + { + XlaBuilder builder(TestName() + "_f32"); + IotaGen(&builder, F32, num_elements); + ComputeAndCompareR1(&builder, GetExpected(num_elements), {}, + ErrorSpec{0.0001}); + } + { + XlaBuilder builder(TestName() + "_u32"); + IotaGen(&builder, U32, num_elements); + ComputeAndCompareR1(&builder, GetExpected(num_elements), + {}); + } + { + XlaBuilder builder(TestName() + "_s32"); + IotaGen(&builder, S32, num_elements); + ComputeAndCompareR1(&builder, GetExpected(num_elements), + {}); + } + } +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index d1b8a6cf0b2552f1b7d95a2560d502da14ddc39a..31a099c15f1f20457c90de97054f68a31eb49011 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/error_spec.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -154,20 +155,20 @@ class LiteralTestUtil { template /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR0(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR0(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR1Equal( tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR1(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR1(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2Equal( std::initializer_list> expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR2(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR2(expected), actual)); } template @@ -175,46 +176,46 @@ template std::initializer_list>> expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR3(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR3(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2EqualArray2D( const Array2D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR2FromArray2D(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR2FromArray2D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( const Array3D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR3FromArray3D(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR3FromArray3D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( const Array4D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR4FromArray4D(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR4FromArray4D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR0(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR0(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR1Near( tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR1(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR1(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2Near( std::initializer_list> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR2(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR2(expected), actual, error)); } template @@ -222,7 +223,7 @@ template std::initializer_list>> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR3(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR3(expected), actual, error)); } template @@ -231,28 +232,28 @@ template std::initializer_list>>> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR4(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR4(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2NearArray2D( const Array2D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR2FromArray2D(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR2FromArray2D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR3NearArray3D( const Array3D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR3FromArray3D(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR3FromArray3D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR4NearArray4D( const Array4D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR4FromArray4D(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR4FromArray4D(expected), actual, error)); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index bbac7285aefbb1f028fad152e4b7fe6af01e9f6d..f297b2b847f570d26e71ddcd8e34bc626f982e1f 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -31,8 +31,9 @@ namespace xla { namespace { TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { - std::unique_ptr literal = Literal::MakeTuple({ - Literal::CreateR0(42).get(), Literal::CreateR0(64).get(), + std::unique_ptr literal = LiteralUtil::MakeTuple({ + LiteralUtil::CreateR0(42).get(), + LiteralUtil::CreateR0(64).get(), }); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal)); } @@ -42,11 +43,13 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { // un-fail an assertion failure. The CHECK-failure is death, so we can make a // death assertion. auto unequal_things_are_equal = [] { - std::unique_ptr lhs = Literal::MakeTuple({ - Literal::CreateR0(42).get(), Literal::CreateR0(64).get(), + std::unique_ptr lhs = LiteralUtil::MakeTuple({ + LiteralUtil::CreateR0(42).get(), + LiteralUtil::CreateR0(64).get(), }); - std::unique_ptr rhs = Literal::MakeTuple({ - Literal::CreateR0(64).get(), Literal::CreateR0(42).get(), + std::unique_ptr rhs = LiteralUtil::MakeTuple({ + LiteralUtil::CreateR0(64).get(), + LiteralUtil::CreateR0(42).get(), }); CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal"; }; @@ -55,8 +58,8 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { auto dummy_lambda = [] { - auto two = Literal::CreateR0(2); - auto four = Literal::CreateR0(4); + auto two = LiteralUtil::CreateR0(2); + auto four = LiteralUtil::CreateR0(4); ErrorSpec error(0.001); CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four"; }; @@ -98,8 +101,8 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { } TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { - auto expected = Literal::CreateR1({1, 2, 3}); - auto actual = Literal::CreateR1({4, 5, 6}); + auto expected = LiteralUtil::CreateR1({1, 2, 3}); + auto actual = LiteralUtil::CreateR1({4, 5, 6}); ::testing::AssertionResult result = LiteralTestUtil::Equal(*expected, *actual); EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}")); @@ -107,25 +110,26 @@ TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { } TEST(LiteralTestUtilTest, NearComparatorR1) { - auto a = - Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); - auto b = - Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + auto a = LiteralUtil::CreateR1( + {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + auto b = LiteralUtil::CreateR1( + {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); } TEST(LiteralTestUtilTest, NearComparatorR1Nan) { - auto a = - Literal::CreateR1({0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); - auto b = - Literal::CreateR1({0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); + auto a = LiteralUtil::CreateR1( + {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); + auto b = LiteralUtil::CreateR1( + {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); } TEST(LiteralTestUtil, NearComparatorDifferentLengths) { - auto a = - Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); - auto b = Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}); + auto a = LiteralUtil::CreateR1( + {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + auto b = + LiteralUtil::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}); EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001})); } diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index 082bc34136e004795ce300c66591758f47c665fe..e719da54d45d3e6eb3f3e14d3fa3076db2081e04 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -14,9 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_compiler.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" +#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -64,7 +65,7 @@ class LLVMCompilerTest : public ::testing::Test { // Create HLO module, and run the compiler. auto builder = HloComputation::Builder(TestName()); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); @@ -86,7 +87,7 @@ class LLVMCompilerTest : public ::testing::Test { void TestMultiModuleCompilation(LLVMCompiler *compiler) { HloComputation::Builder builder(TestName()); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); std::unique_ptr hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); @@ -144,7 +145,7 @@ TEST_F(CpuCompilerTest, HooksTest) { } TEST_F(GpuCompilerTest, HooksTest) { - gpu::GpuCompiler compiler; + gpu::NVPTXCompiler compiler; TestCompilerHooks(&compiler); } @@ -154,7 +155,7 @@ TEST_F(CpuCompilerTest, MultiModuleCompilation) { } TEST_F(GpuCompilerTest, MultModuleCompilation) { - gpu::GpuCompiler compiler; + gpu::NVPTXCompiler compiler; TestMultiModuleCompilation(&compiler); } } // namespace diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc index 2c45f19c090d2690878430363bf0d20252b2f3df..6fc11150978931f980349799372872f9fb68f292 100644 --- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc +++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/tests/filecheck.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,28 +26,28 @@ limitations under the License. namespace xla { -void LLVMIRGenTestBase::SetIrHook(bool match_optimized_ir) { +void LlvmIrGenTestBase::SetIrHook(bool match_optimized_ir) { auto llvm_compiler = GetLLVMCompiler(); using std::placeholders::_1; // Add the IR inspection hook to the LLVM compiler. if (match_optimized_ir) { llvm_compiler->SetPostOptimizationHook( - std::bind(&LLVMIRGenTestBase::IrHook, this, _1)); + std::bind(&LlvmIrGenTestBase::IrHook, this, _1)); } else { llvm_compiler->SetPreOptimizationHook( - std::bind(&LLVMIRGenTestBase::IrHook, this, _1)); + std::bind(&LlvmIrGenTestBase::IrHook, this, _1)); } } -void LLVMIRGenTestBase::ResetIrHook() { +void LlvmIrGenTestBase::ResetIrHook() { auto llvm_compiler = GetLLVMCompiler(); llvm_compiler->RemovePreOptimizationHook(); llvm_compiler->RemovePostOptimizationHook(); } -void LLVMIRGenTestBase::CompileAndVerifyIr( +void LlvmIrGenTestBase::CompileAndVerifyIr( std::unique_ptr hlo_module, const string& pattern, bool match_optimized_ir) { SetIrHook(match_optimized_ir); @@ -58,7 +59,17 @@ void LLVMIRGenTestBase::CompileAndVerifyIr( EXPECT_TRUE(filecheck_result.ValueOrDie()); } -void LLVMIRGenTestBase::CompileAheadOfTimeAndVerifyIr( +void LlvmIrGenTestBase::CompileAndVerifyIr(const string& hlo_text, + const string& expected_llvm_ir, + bool match_optimized_ir) { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text, config)); + CompileAndVerifyIr(std::move(module), expected_llvm_ir, match_optimized_ir); +} + +void LlvmIrGenTestBase::CompileAheadOfTimeAndVerifyIr( std::unique_ptr hlo_module, const AotCompilationOptions& options, const string& pattern, bool match_optimized_ir) { SetIrHook(match_optimized_ir); @@ -71,11 +82,11 @@ void LLVMIRGenTestBase::CompileAheadOfTimeAndVerifyIr( EXPECT_TRUE(filecheck_result.ValueOrDie()); } -LLVMCompiler* LLVMIRGenTestBase::GetLLVMCompiler() { +LLVMCompiler* LlvmIrGenTestBase::GetLLVMCompiler() { return static_cast(backend().compiler()); } -Status LLVMIRGenTestBase::IrHook(const llvm::Module& module) { +Status LlvmIrGenTestBase::IrHook(const llvm::Module& module) { ir_ = llvm_ir::DumpModuleToString(module); return Status::OK(); } diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h index 74cbb5f5df662992046a5b0f9a31e52879f375ad..018f9546afc3e408686a9ac75a74320a05b27182 100644 --- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h +++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h @@ -24,7 +24,7 @@ limitations under the License. namespace xla { // Tests that verify IR emitted by the CPU/GPU backend is as expected. -class LLVMIRGenTestBase : public CodegenTestBase { +class LlvmIrGenTestBase : public CodegenTestBase { protected: // Compiles the given HLO module to LLVM IR and verifies the IR matches the // given pattern. `pattern` is in the FileCheck pattern matching syntax @@ -38,6 +38,12 @@ class LLVMIRGenTestBase : public CodegenTestBase { void CompileAndVerifyIr(std::unique_ptr hlo_module, const string& pattern, bool match_optimized_ir); + // A thin wrapper around CompileAndVerifyIr that parses `hlo_text` to create + // an HLO module. + void CompileAndVerifyIr(const string& hlo_text, + const string& expected_llvm_ir, + bool match_optimized_ir = false); + // Compiles the given HLO module to LLVM IR and verifies the IR matches the // given pattern. `pattern` is in the FileCheck pattern matching syntax // (http://llvm.org/docs/CommandGuide/FileCheck.html). diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc index f21f83992ffb7c07dff31c68a7e9e3f7944bf512..e2cd5bcc5a95f692dcf4a43d717252bfe876aa81 100644 --- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" @@ -38,14 +38,14 @@ class LocalClientAllocationTest : public LocalClientTestBase { XLA_TEST_F(LocalClientAllocationTest, AddVectors) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0.0f, 1.0f, 2.0f}); - auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); - builder.Add(x, y); + auto x = ConstantR1(&builder, {0.0f, 1.0f, 2.0f}); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + Add(x, y); TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform()); auto x_array = - LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); int64 allocation_count_before = allocator_->allocation_count(); @@ -74,9 +74,9 @@ XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) { // Run a computation on every device on the system. Verify that allocation // occurs on the proper device. XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0.0f, 1.0f, 2.0f}); - auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); - builder.Add(x, y); + auto x = ConstantR1(&builder, {0.0f, 1.0f, 2.0f}); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + Add(x, y); auto computation = builder.Build().ConsumeValueOrDie(); TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform()); diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test.cc b/tensorflow/compiler/xla/tests/local_client_aot_test.cc index 47cab796041e9669affaebd7866d0d80100730f1..115448c908ac9e7f0b01772ce348d23bf4d838ed 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc @@ -42,13 +42,12 @@ extern "C" void SumStructElements(float* out, void** parameters) { TEST_F(LocalClientAotTest, Constant) { xla::ExecutableRunOptions run_options; OpaqueData opaque_data{100, 20, 3}; - void* parameters[] = {&opaque_data}; float out = 0; - void* temporary_buffers[] = {nullptr, &out}; - SumAndDouble(&out, &run_options, parameters, temporary_buffers); + void* temporary_buffers[] = {&opaque_data, &out}; + SumAndDouble(&out, &run_options, nullptr, temporary_buffers); EXPECT_EQ(out, 246.0f); opaque_data = {1, 2, 3}; - SumAndDouble(&out, &run_options, parameters, temporary_buffers); + SumAndDouble(&out, &run_options, nullptr, temporary_buffers); EXPECT_EQ(out, 12.0f); } diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index a366afe8262e1f537b225e395bba9cb2fc22683a..60eb21aafd23a8d724d1f08d5c87098b7c3dcd6b 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -21,8 +21,8 @@ limitations under the License. #include "llvm/ADT/Triple.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/types.h" @@ -37,8 +37,8 @@ using xla::string; xla::XlaComputation Doubler() { xla::XlaBuilder builder("doubler"); auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {}); - auto x = builder.Parameter(0, r0f32, "x"); - builder.Mul(x, builder.ConstantR0(2.0)); + auto x = xla::Parameter(&builder, 0, r0f32, "x"); + xla::Mul(x, xla::ConstantR0(&builder, 2.0)); return std::move(builder.Build().ValueOrDie()); } @@ -51,10 +51,10 @@ int main(int argc, char** argv) { xla::XlaBuilder builder("aot_test_helper"); auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); - auto opaque_param = builder.Parameter(0, opaque_shape, "x"); + auto opaque_param = Parameter(&builder, 0, opaque_shape, "x"); auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {}); - auto sum = builder.CustomCall("SumStructElements", {opaque_param}, r0f32); - builder.Call(Doubler(), {sum}); + auto sum = CustomCall(&builder, "SumStructElements", {opaque_param}, r0f32); + Call(&builder, Doubler(), {sum}); if (argc != 2) { LOG(FATAL) << "local_client_aot_test_helper TARGET_CPU"; @@ -92,9 +92,10 @@ int main(int argc, char** argv) { // It's lame to hard-code the buffer assignments, but we need // local_client_aot_test.cc to be able to easily invoke the function. CHECK_EQ(result->result_buffer_index(), 1); - CHECK_EQ(result->buffer_sizes().size(), 2); - CHECK_EQ(result->buffer_sizes()[0], -1); // param buffer - CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // result buffer + CHECK_EQ(result->buffer_infos().size(), 3); + CHECK(result->buffer_infos()[0].is_entry_parameter()); // param buffer + CHECK_EQ(result->buffer_infos()[1].size(), sizeof(float)); // result buffer + CHECK(result->buffer_infos()[2].is_constant()); // const buffer if (triple.isOSBinFormatELF()) { // Check the ELF magic. CHECK_EQ(result->object_file_data()[0], 0x7F); diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 96858c00d6bbe59b673a34e7d5ca261756709596..1a823cf189b310c62c735419936544ea99fcfbaf 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -54,7 +54,7 @@ class LocalClientExecuteTest : public LocalClientTestBase { XLA_TEST_F(LocalClientExecuteTest, Constant) { XlaBuilder builder(TestName()); - auto y = builder.ConstantR0(123.0f); + ConstantR0(&builder, 123.0f); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); @@ -64,11 +64,11 @@ XLA_TEST_F(LocalClientExecuteTest, Constant) { XLA_TEST_F(LocalClientExecuteTest, AddScalars) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.ConstantR0(123.0f); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = ConstantR0(&builder, 123.0f); + Add(x, y); - auto x_value = LiteralToShapedBuffer(*Literal::CreateR0(42.0f)); + auto x_value = LiteralToShapedBuffer(*LiteralUtil::CreateR0(42.0f)); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value}); LiteralTestUtil::ExpectR0Near(165.f, *ShapedBufferToLiteral(result), @@ -77,11 +77,11 @@ XLA_TEST_F(LocalClientExecuteTest, AddScalars) { XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "x"); - auto y = builder.ConstantR1({}); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "x"); + auto y = ConstantR1(&builder, {}); + Add(x, y); - auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({})); + auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR1({})); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); LiteralTestUtil::ExpectR1Near({}, *ShapedBufferToLiteral(result), @@ -90,12 +90,12 @@ XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) { XLA_TEST_F(LocalClientExecuteTest, AddVectors) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); - auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + Add(x, y); auto x_array = - LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); LiteralTestUtil::ExpectR1Near( @@ -104,12 +104,12 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectors) { XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); - auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + Add(x, y); auto x_array = - LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ExecutionProfile profile; ScopedShapedBuffer result = ExecuteLocallyOrDie( builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(), @@ -122,19 +122,19 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + Add(x, y); auto computation = builder.Build().ConsumeValueOrDie(); // Create x as a col-major array. - auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( + auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. - auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( + auto y_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout( {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); @@ -155,15 +155,15 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + Add(x, y); auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *Literal::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *Literal::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); // Run with col-major result layout. ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie( @@ -192,15 +192,15 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { XLA_TEST_F(LocalClientExecuteTest, TupleResult) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); - builder.Tuple({x, y, x}); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + Tuple(&builder, {x, y, x}); auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *Literal::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *Literal::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); @@ -209,27 +209,26 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); - LiteralTestUtil::ExpectR2Equal( - {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralSlice(*result_literal, {1})); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {2})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {0})); + LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, + LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {2})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); - auto inner_tuple = builder.Tuple({x, y, x}); - builder.Tuple({inner_tuple, x}); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + auto inner_tuple = Tuple(&builder, {x, y, x}); + Tuple(&builder, {inner_tuple, x}); auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *Literal::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *Literal::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); @@ -238,28 +237,25 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0, 0})); - LiteralTestUtil::ExpectR2Equal( - {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralSlice(*result_literal, {0, 1})); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0, 2})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {0, 0})); + LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, + LiteralSlice(*result_literal, {0, 1})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {0, 2})); } XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { // Verify setting the result layout of a computation with a tuple output. XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); - builder.Tuple({x, y}); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); + Tuple(&builder, {x, y}); auto array = LiteralToShapedBuffer( - *Literal::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); ExecutableBuildOptions options = DefaultExecutableBuildOptions(); Shape shape_with_layout = ShapeUtil::MakeTupleShape( @@ -273,10 +269,10 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { options, DefaultExecutableRunOptions()); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {0})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { @@ -291,23 +287,23 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { // Computation adds the respective array and vector elements from each tuple // argument and returns the results as a tuple. XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, tuple_shape0, "x"); - auto y = builder.Parameter(1, tuple_shape1, "y"); - auto x_0 = builder.GetTupleElement(x, 0); - auto x_1 = builder.GetTupleElement(x, 1); - auto y_0 = builder.GetTupleElement(y, 0); - auto y_1 = builder.GetTupleElement(y, 1); - auto array_sum = builder.Add(x_0, y_1); - auto vector_diff = builder.Sub(x_1, y_0); - builder.Tuple({array_sum, vector_diff}); + auto x = Parameter(&builder, 0, tuple_shape0, "x"); + auto y = Parameter(&builder, 1, tuple_shape1, "y"); + auto x_0 = GetTupleElement(x, 0); + auto x_1 = GetTupleElement(x, 1); + auto y_0 = GetTupleElement(y, 0); + auto y_1 = GetTupleElement(y, 1); + auto array_sum = Add(x_0, y_1); + auto vector_diff = Sub(x_1, y_0); + Tuple(&builder, {array_sum, vector_diff}); auto computation = builder.Build().ConsumeValueOrDie(); - auto x_literal = Literal::MakeTuple( - {Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - Literal::CreateR1({42.0, 75.0, 123.0}).get()}); - auto y_literal = Literal::MakeTuple( - {Literal::CreateR1({2.0, 4.0, 6.0}).get(), - Literal::CreateR2({{55.0, 44.0}, {33.0, 22.0}}).get()}); + auto x_literal = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), + LiteralUtil::CreateR1({42.0, 75.0, 123.0}).get()}); + auto y_literal = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1({2.0, 4.0, 6.0}).get(), + LiteralUtil::CreateR2({{55.0, 44.0}, {33.0, 22.0}}).get()}); auto x_buffer = LiteralToShapedBuffer(*x_literal); auto y_buffer = LiteralToShapedBuffer(*y_literal); @@ -319,11 +315,10 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{56.0f, 46.0f}, {36.0f, 26.0f}}, - LiteralSlice(*result_literal, {0})); - LiteralTestUtil::ExpectR1Equal( - {40.0f, 71.0f, 117.0f}, LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{56.0f, 46.0f}, {36.0f, 26.0f}}, + LiteralSlice(*result_literal, {0})); + LiteralTestUtil::ExpectR1Equal({40.0f, 71.0f, 117.0f}, + LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { @@ -338,32 +333,32 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { // Computation negates the array element and sums the two vector elements in // the nested tuple. The resulting array and vector are returned as a tuple. XlaBuilder builder(TestName()); - auto param = builder.Parameter(0, nested_tuple_shape, "param"); - auto inner_tuple = builder.GetTupleElement(param, 0); - auto inner_array = builder.GetTupleElement(inner_tuple, 0); - auto inner_vector = builder.GetTupleElement(inner_tuple, 1); - auto outer_vector = builder.GetTupleElement(param, 1); - - auto negate_array = builder.Neg(inner_array); - auto vector_sum = builder.Add(inner_vector, outer_vector); - builder.Tuple({negate_array, vector_sum}); + auto param = Parameter(&builder, 0, nested_tuple_shape, "param"); + auto inner_tuple = GetTupleElement(param, 0); + auto inner_array = GetTupleElement(inner_tuple, 0); + auto inner_vector = GetTupleElement(inner_tuple, 1); + auto outer_vector = GetTupleElement(param, 1); + + auto negate_array = Neg(inner_array); + auto vector_sum = Add(inner_vector, outer_vector); + Tuple(&builder, {negate_array, vector_sum}); auto computation = builder.Build().ConsumeValueOrDie(); - auto arg_literal = Literal::MakeTuple( - {Literal::MakeTuple( - {Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - Literal::CreateR1({42.0, 75.0, 123.0}).get()}) + auto arg_literal = LiteralUtil::MakeTuple( + {LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), + LiteralUtil::CreateR1({42.0, 75.0, 123.0}).get()}) .get(), - Literal::CreateR1({222.0, -2.0, 10.0}).get()}); + LiteralUtil::CreateR1({222.0, -2.0, 10.0}).get()}); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{-1.0, -2.0}, {-3.0, -4}}, LiteralSlice(*result_literal, {0})); - LiteralTestUtil::ExpectR1Equal( - {264.0, 73.0, 133.0}, LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4}}, + LiteralSlice(*result_literal, {0})); + LiteralTestUtil::ExpectR1Equal({264.0, 73.0, 133.0}, + LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { @@ -376,31 +371,30 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { ShapeUtil::MakeTupleShape({array_shape, array_shape}); XlaBuilder builder(TestName()); - auto param = builder.Parameter(0, tuple_shape, "param"); - auto element_0 = builder.GetTupleElement(param, 0); - auto element_1 = builder.GetTupleElement(param, 1); - builder.Tuple({builder.Neg(element_0), builder.Add(element_1, element_1)}); + auto param = Parameter(&builder, 0, tuple_shape, "param"); + auto element_0 = GetTupleElement(param, 0); + auto element_1 = GetTupleElement(param, 1); + Tuple(&builder, {Neg(element_0), Add(element_1, element_1)}); auto computation = builder.Build().ConsumeValueOrDie(); - auto arg_literal = Literal::MakeTuple( - {Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - Literal::CreateR2({{11.0, 3.0}, {4.0, 5.0}}).get()}); + auto arg_literal = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), + LiteralUtil::CreateR2({{11.0, 3.0}, {4.0, 5.0}}).get()}); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer}); std::unique_ptr result_0_literal = ShapedBufferToLiteral(result_0); - LiteralTestUtil::ExpectR2Equal( - {{-1.0, -2.0}, {-3.0, -4.0}}, - LiteralSlice(*result_0_literal, {0})); - LiteralTestUtil::ExpectR2Equal( - {{22.0, 6.0}, {8.0, 10}}, LiteralSlice(*result_0_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4.0}}, + LiteralSlice(*result_0_literal, {0})); + LiteralTestUtil::ExpectR2Equal({{22.0, 6.0}, {8.0, 10}}, + LiteralSlice(*result_0_literal, {1})); ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0}); std::unique_ptr result_1_literal = ShapedBufferToLiteral(result_1); - LiteralTestUtil::ExpectR2Equal( - {{1.0, 2.0}, {3.0, 4.0}}, LiteralSlice(*result_1_literal, {0})); - LiteralTestUtil::ExpectR2Equal( - {{44.0, 12.0}, {16.0, 20}}, LiteralSlice(*result_1_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{1.0, 2.0}, {3.0, 4.0}}, + LiteralSlice(*result_1_literal, {0})); + LiteralTestUtil::ExpectR2Equal({{44.0, 12.0}, {16.0, 20}}, + LiteralSlice(*result_1_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { @@ -420,26 +414,25 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { const Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes); XlaBuilder builder(TestName()); - auto param = builder.Parameter(0, tuple_shape, "param"); + auto param = Parameter(&builder, 0, tuple_shape, "param"); // Add each element's tuple index value to every element. std::vector result_elements; for (int i = 0; i < kElementCount; ++i) { - auto element = builder.GetTupleElement(param, i); - result_elements.push_back( - builder.Add(element, builder.ConstantR0(i))); + auto element = GetTupleElement(param, i); + result_elements.push_back(Add(element, ConstantR0(&builder, i))); } - builder.Tuple(result_elements); + Tuple(&builder, result_elements); auto computation = builder.Build().ConsumeValueOrDie(); // Feed in a tuple where each two-element vector element is {tuple_index, // -tuple_index}. std::vector> arg_elements; for (int i = 0; i < kElementCount; ++i) { - arg_elements.push_back(Literal::CreateR1({1.0f * i, -1.0f * i})); + arg_elements.push_back(LiteralUtil::CreateR1({1.0f * i, -1.0f * i})); } std::unique_ptr arg_literal = - Literal::MakeTupleOwned(std::move(arg_elements)); + LiteralUtil::MakeTupleOwned(std::move(arg_elements)); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); @@ -447,8 +440,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { for (int i = 0; i < kElementCount; ++i) { LiteralTestUtil::ExpectR1Near( - {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), - error_spec_); + {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_); } } @@ -465,22 +457,22 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { const Shape tuple_shape = ShapeUtil::MakeTupleShape(inner_tuple_shapes); XlaBuilder builder(TestName()); - auto param = builder.Parameter(0, tuple_shape, "param"); + auto param = Parameter(&builder, 0, tuple_shape, "param"); // The computation increments each leaf value by an amount equal to the leaf's // ordinal position in a traversal of the tuple. std::vector result_elements; for (int i = 0; i < kFanout; ++i) { - auto outer_element = builder.GetTupleElement(param, i); + auto outer_element = GetTupleElement(param, i); std::vector inner_result_elements; for (int j = 0; j < kFanout; ++j) { - auto inner_element = builder.GetTupleElement(outer_element, j); - inner_result_elements.push_back(builder.Add( - inner_element, builder.ConstantR0(i * kFanout + j))); + auto inner_element = GetTupleElement(outer_element, j); + inner_result_elements.push_back( + Add(inner_element, ConstantR0(&builder, i * kFanout + j))); } - result_elements.push_back(builder.Tuple(inner_result_elements)); + result_elements.push_back(Tuple(&builder, inner_result_elements)); } - builder.Tuple(result_elements); + Tuple(&builder, result_elements); auto computation = builder.Build().ConsumeValueOrDie(); // Construct the argument to pass to the computation. @@ -488,12 +480,13 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { for (int i = 0; i < kFanout; ++i) { std::vector> inner_tuple_elements; for (int j = 0; j < kFanout; ++j) { - inner_tuple_elements.push_back(Literal::CreateR0(i + j)); + inner_tuple_elements.push_back(LiteralUtil::CreateR0(i + j)); } outer_tuple_elements.push_back( - Literal::MakeTupleOwned(std::move(inner_tuple_elements))); + LiteralUtil::MakeTupleOwned(std::move(inner_tuple_elements))); } - auto arg_literal = Literal::MakeTupleOwned(std::move(outer_tuple_elements)); + auto arg_literal = + LiteralUtil::MakeTupleOwned(std::move(outer_tuple_elements)); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); @@ -520,23 +513,23 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { } XlaBuilder builder(TestName()); - auto element = builder.Parameter(0, shape, "param"); + auto element = Parameter(&builder, 0, shape, "param"); for (int i = 0; i < kTupleDepth; ++i) { - element = builder.GetTupleElement(element, 0); + element = GetTupleElement(element, 0); } - auto output = builder.Add(element, builder.ConstantR0(42.0)); + auto output = Add(element, ConstantR0(&builder, 42.0)); for (int i = 0; i < kTupleDepth; ++i) { - output = builder.Tuple({output}); + output = Tuple(&builder, {output}); } auto computation = builder.Build().ConsumeValueOrDie(); // Construct the argument to pass to the computation. - std::unique_ptr arg_literal = Literal::CreateR0(123.0); + std::unique_ptr arg_literal = LiteralUtil::CreateR0(123.0); for (int i = 0; i < kTupleDepth; ++i) { std::vector> arg_vector; arg_vector.push_back(std::move(arg_literal)); - arg_literal = Literal::MakeTupleOwned(std::move(arg_vector)); + arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector)); } auto arg_buffer = LiteralToShapedBuffer(*arg_literal); @@ -547,19 +540,19 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { for (int i = 0; i < kTupleDepth; ++i) { index.push_back(0); } - LiteralTestUtil::ExpectR0Equal( - 165.0, LiteralSlice(*result_literal, index)); + LiteralTestUtil::ExpectR0Equal(165.0, + LiteralSlice(*result_literal, index)); } XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { // Test passing in an invalid number of arguments. XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {3}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {3}), "y"); + Add(x, y); auto x_array = - LiteralToShapedBuffer(*Literal::CreateR1({1.0f, 2.0f, 3.0f})); + LiteralToShapedBuffer(*LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f})); auto execute_status = ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); @@ -571,11 +564,11 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { // Test passing in an argument with the wrong shape. XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); - builder.Neg(x); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + Neg(x); auto x_array = LiteralToShapedBuffer( - *Literal::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); + *LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); @@ -588,11 +581,11 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) { // Test passing in an invalid result layout parameter. XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); - builder.Neg(x); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); + Neg(x); auto x_array = LiteralToShapedBuffer( - *Literal::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); + *LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally( builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions().set_result_layout( @@ -611,7 +604,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) { // Try to run a trivial computation on every device on the system. If a // specific device is not supported, check that the right error is returned. XlaBuilder builder(TestName()); - builder.ConstantR0(42.0f); + ConstantR0(&builder, 42.0f); auto computation = builder.Build().ConsumeValueOrDie(); for (int d = 0; d < local_client_->device_count(); ++d) { if (!local_client_->device_ordinal_supported(d)) { @@ -638,7 +631,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) { // Try running computations on devices with device ordinal values which do not // exist. XlaBuilder builder(TestName()); - builder.ConstantR0(42.0f); + ConstantR0(&builder, 42.0f); auto computation = builder.Build().ConsumeValueOrDie(); auto execute_status = @@ -655,7 +648,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) { XLA_TEST_F(LocalClientExecuteTest, RunOnStream) { // Run a computation on a specific stream on each device on the system. XlaBuilder builder(TestName()); - builder.ConstantR0(42.0f); + ConstantR0(&builder, 42.0f); auto computation = builder.Build().ConsumeValueOrDie(); for (int d = 0; d < local_client_->device_count(); ++d) { @@ -691,7 +684,7 @@ XLA_TEST_F(LocalClientExecuteTest, wrong_stream.Init(); XlaBuilder builder(TestName()); - builder.ConstantR0(42.0f); + ConstantR0(&builder, 42.0f); auto execute_status = ExecuteLocally( builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions().set_stream(&wrong_stream)); @@ -708,7 +701,7 @@ XLA_TEST_F(LocalClientExecuteTest, TestAllocator allocator(wrong_platform); XlaBuilder builder(TestName()); - auto y = builder.ConstantR0(123.0f); + ConstantR0(&builder, 123.0f); auto execute_status = ExecuteLocally( builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), @@ -721,7 +714,7 @@ XLA_TEST_F(LocalClientExecuteTest, XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) { // Try to run a computation on a stream that has not been initialized. XlaBuilder builder(TestName()); - builder.ConstantR0(42.0f); + ConstantR0(&builder, 42.0f); LOG(INFO) << "default device = " << local_client_->default_device_ordinal(); se::StreamExecutor* executor = @@ -744,26 +737,26 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; - auto tuple12 = builder.Tuple( - {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); - auto tuple21 = builder.Tuple( - {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); - builder.Select(builder.ConstantR0(false), tuple12, tuple21); + auto tuple12 = Tuple(&builder, {ConstantR1(&builder, vec1), + ConstantR1(&builder, vec2)}); + auto tuple21 = Tuple(&builder, {ConstantR1(&builder, vec2), + ConstantR1(&builder, vec1)}); + Select(ConstantR0(&builder, false), tuple12, tuple21); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); std::unique_ptr tuple_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR1Equal( - {2.0f, 4.0f, 6.0f}, LiteralSlice(*tuple_literal, {0})); - LiteralTestUtil::ExpectR1Equal( - {1.0f, 2.0f, 3.0f}, LiteralSlice(*tuple_literal, {1})); + LiteralTestUtil::ExpectR1Equal({2.0f, 4.0f, 6.0f}, + LiteralSlice(*tuple_literal, {0})); + LiteralTestUtil::ExpectR1Equal({1.0f, 2.0f, 3.0f}, + LiteralSlice(*tuple_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); - auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + Add(x, y); Shape argument_layout = ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0}); @@ -775,10 +768,14 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { executable_status.ConsumeValueOrDie(); auto x_array = - LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ScopedShapedBuffer result = executable->Run({&x_array}, DefaultExecutableRunOptions()) .ConsumeValueOrDie(); + ASSERT_IS_OK(local_client_->mutable_backend() + ->BorrowStream(0) + .ValueOrDie() + ->BlockHostUntilDone()); LiteralTestUtil::ExpectR1Near( {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); @@ -799,29 +796,29 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { }; // Array shapes. - test_to_device_and_back(*Literal::CreateR0(42.0)); - test_to_device_and_back(*Literal::CreateR0(true)); - test_to_device_and_back(*Literal::CreateR1({1.0, 42.0, 744.4})); + test_to_device_and_back(*LiteralUtil::CreateR0(42.0)); + test_to_device_and_back(*LiteralUtil::CreateR0(true)); + test_to_device_and_back(*LiteralUtil::CreateR1({1.0, 42.0, 744.4})); test_to_device_and_back( - *Literal::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); - test_to_device_and_back(*Literal::CreateR2({{2, 1}, {4444, 56}})); + *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(*LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); // Null shape (empty tuple). - test_to_device_and_back(*Literal::MakeTuple({})); + test_to_device_and_back(*LiteralUtil::MakeTuple({})); // Non-nested tuples. test_to_device_and_back( - *Literal::MakeTuple({Literal::CreateR0(12223.0).get()})); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(12223.0).get()})); test_to_device_and_back( - *Literal::MakeTuple({Literal::CreateR1({1.0, -42.0}).get(), - Literal::CreateR0(123456.0).get()})); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1.0, -42.0}).get(), + LiteralUtil::CreateR0(123456.0).get()})); // Nested tuple. - test_to_device_and_back(*Literal::MakeTuple( - {Literal::MakeTuple({Literal::CreateR1({1.0, -42.0}).get(), - Literal::CreateR0(123456.0).get()}) + test_to_device_and_back(*LiteralUtil::MakeTuple( + {LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1.0, -42.0}).get(), + LiteralUtil::CreateR0(123456.0).get()}) .get(), - Literal::CreateR0(false).get()})); + LiteralUtil::CreateR0(false).get()})); } XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { @@ -839,24 +836,47 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { }; test_to_device_and_back( - *Literal::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); - test_to_device_and_back(*Literal::CreateR2({{2, 1}, {4444, 56}})); + *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(*LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); test_to_device_and_back( - *Literal::CreateR2({{20000000000ULL, 1}, {4444, 56}})); - test_to_device_and_back( - *Literal::MakeTuple({Literal::CreateR1({1.0, -42.0}).get(), - Literal::CreateR0(123456789000LL).get()})); + *LiteralUtil::CreateR2({{20000000000ULL, 1}, {4444, 56}})); + test_to_device_and_back(*LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1({1.0, -42.0}).get(), + LiteralUtil::CreateR0(123456789000LL).get()})); } -// TODO(b/34359662): Support infeed/outfeed on GPU and CPU parallel. -// 2017-10-18. -XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(InfeedOutfeedTest)) { +XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { XlaBuilder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {3}); - auto in = builder.Infeed(shape); - auto constant = builder.ConstantR1({1.0f, 2.0f, 3.0f}); - auto sum = builder.Add(in, constant); - builder.Outfeed(sum, shape, /*outfeed_config=*/""); + auto in = Infeed(&builder, shape); + auto constant = ConstantR1(&builder, {1.0f, 2.0f, 3.0f}); + Add(in, constant); + + std::unique_ptr result; + std::unique_ptr thread( + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions(), "execute_thread", [&] { + result = ShapedBufferToLiteral(ExecuteLocallyOrDie( + builder.Build().ValueOrDie(), /*arguments=*/{})); + })); + + ASSERT_IS_OK(local_client_->TransferToInfeedLocal( + *LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), + local_client_->default_device_ordinal())); + + // Join the thread. + thread.reset(); + + LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); +} + +XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) { + XlaBuilder builder(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {3}); + auto in = Infeed(&builder, shape); + auto constant = ConstantR1(&builder, {1.0f, 2.0f, 3.0f}); + auto sum = Add(in, constant); + Outfeed(sum, shape, /*outfeed_config=*/""); std::unique_ptr thread( tensorflow::Env::Default()->StartThread( @@ -864,7 +884,7 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(InfeedOutfeedTest)) { [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); })); ASSERT_IS_OK(local_client_->TransferToInfeedLocal( - *Literal::CreateR1({-5.0, 123.0, 42.0}), + *LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), local_client_->default_device_ordinal())); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, @@ -891,17 +911,19 @@ void BM_LocalClientOverhead(int num_iters) { // Use a tiny add operation as the computation. XlaBuilder builder("Add"); auto shape = ShapeUtil::MakeShape(F32, {2, 3}); - auto x = builder.Parameter(0, shape, "x"); - builder.Add(x, x); + auto x = Parameter(&builder, 0, shape, "x"); + Add(x, x); auto computation = builder.Build().ConsumeValueOrDie(); auto buffer = transfer_manager ->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0) .ConsumeValueOrDie(); - auto literal = Literal::CreateR2({{0, 0, 0}, {0, 0, 0}}); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *literal, buffer)); + auto literal = LiteralUtil::CreateR2({{0, 0, 0}, {0, 0, 0}}); + auto stream = + client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal, + buffer)); const int kWarmups = 2; @@ -911,11 +933,8 @@ void BM_LocalClientOverhead(int num_iters) { std::unique_ptr executable = executable_status.ConsumeValueOrDie(); - se::Stream stream(executors[client->default_device_ordinal()]); - stream.Init(); - ExecutableRunOptions run_options; - run_options.set_allocator(&allocator).set_stream(&stream); + run_options.set_allocator(&allocator).set_stream(stream.get()); for (int i = 0; i < kWarmups; ++i) { auto result = executable->Run({&buffer}, run_options); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 88797a7d0a7d0567b3a380c5fb1ad0c0ee875587..eaddf756dbc913dd9668cd22228fbd18c2c33309 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -20,6 +20,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -189,7 +190,19 @@ StatusOr LocalClientTestBase::ExecuteLocally( TF_ASSIGN_OR_RETURN( std::unique_ptr executable, local_client_->Compile(computation, argument_layouts, build_options)); - return executable->Run(arguments, run_options); + TF_ASSIGN_OR_RETURN(auto ret, executable->Run(arguments, run_options)); + + auto device_ordinal = + build_options.device_ordinal() == -1 ? 0 : build_options.device_ordinal(); + auto* stream = run_options.stream(); + if (!stream) { + stream = local_client_->mutable_backend() + ->BorrowStream(device_ordinal) + .ValueOrDie() + .get(); + } + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + return std::move(ret); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 258226523d830b40ecaa761df95988dc90f5ca47..b4477e9a6b23363ee3a1380f9f98f4b8226f6920 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/platform_util.h" diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc index c0c02e584c2348f64a9d7d0800038f5ca67a2171..2d622242e657ce032a17f7b26c94227d343e2a38 100644 --- a/tensorflow/compiler/xla/tests/log_test.cc +++ b/tensorflow/compiler/xla/tests/log_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -30,8 +30,8 @@ class LogTest : public ClientLibraryTestBase {}; XLA_TEST_F(LogTest, LogZeroValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR3FromArray3D(Array3D(3, 0, 0)); - builder.Log(x); + auto x = ConstantR3FromArray3D(&builder, Array3D(3, 0, 0)); + Log(x); ComputeAndCompareR3(&builder, Array3D(3, 0, 0), {}, ErrorSpec(0.0001)); @@ -42,8 +42,8 @@ TEST_F(LogTest, LogTenValues) { 5.0, 6.0, -7.0, -8.0, 9.0}; XlaBuilder builder(TestName()); - auto x = builder.ConstantR1(input); - builder.Log(x); + auto x = ConstantR1(&builder, input); + Log(x); std::vector expected; expected.reserve(input.size()); diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 3975e9125703ee081d4e84fa8bd27fcbe483ac34..0732e195d44d738b264361e43d38259c26a4116e 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -52,9 +52,9 @@ class MapTest : public ClientLibraryTestBase { // 1.0f ---------/ XlaComputation CreateAdderToOne() { XlaBuilder mapped_builder(TestName()); - auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto one = mapped_builder.ConstantR0(1.0); - mapped_builder.Add(x, one); + auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto one = ConstantR0(&mapped_builder, 1.0); + Add(x, one); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -62,9 +62,9 @@ class MapTest : public ClientLibraryTestBase { XlaComputation CreateMax() { XlaBuilder b(TestName()); - auto lhs = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto rhs = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - b.Max(lhs, rhs); + auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Max(lhs, rhs); auto computation_status = b.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -75,8 +75,8 @@ class MapTest : public ClientLibraryTestBase { template XlaComputation CreateScalarOne() { XlaBuilder mapped_builder("scalar_one"); - (void)mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - mapped_builder.ConstantR0(1); + (void)Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + ConstantR0(&mapped_builder, 1); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -89,9 +89,9 @@ class MapTest : public ClientLibraryTestBase { // 2.0f ---------/ XlaComputation CreateMulByTwo() { XlaBuilder mapped_builder(TestName()); - auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto two = mapped_builder.ConstantR0(2.0); - mapped_builder.Mul(x, two); + auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto two = ConstantR0(&mapped_builder, 2.0); + Mul(x, two); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -107,10 +107,10 @@ class MapTest : public ClientLibraryTestBase { // 1.0f ---------/ XlaComputation CreateAdderToOneTimesItself() { XlaBuilder mapped_builder(TestName()); - auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto one = mapped_builder.ConstantR0(1.0); - auto adder_to_one = mapped_builder.Add(x, one); - mapped_builder.Mul(x, adder_to_one); + auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto one = ConstantR0(&mapped_builder, 1.0); + auto adder_to_one = Add(x, one); + Mul(x, adder_to_one); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -125,10 +125,10 @@ class MapTest : public ClientLibraryTestBase { XlaComputation CreateMapPlusN(const XlaComputation& embedded_computation, float n) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto map = builder.Map({x}, embedded_computation, {}); - auto constant_n = builder.ConstantR0(n); - builder.Add(map, constant_n); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto map = Map(&builder, {x}, embedded_computation, {}); + auto constant_n = ConstantR0(&builder, n); + Add(map, constant_n); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -138,9 +138,9 @@ class MapTest : public ClientLibraryTestBase { // defined by (x, y) -> x > y. XlaComputation CreateGt() { XlaBuilder b("Gt"); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - b.Gt(x, y); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Gt(x, y); auto computation_status = b.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -155,11 +155,11 @@ class MapTest : public ClientLibraryTestBase { // z {R0F32} ---------------/ XlaComputation CreateTernaryAdder() { XlaBuilder mapped_builder("TernaryAdder"); - auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = mapped_builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - auto z = mapped_builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "z"); - auto xy = mapped_builder.Add(x, y); - mapped_builder.Add(xy, z); + auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&mapped_builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + auto z = Parameter(&mapped_builder, 2, ShapeUtil::MakeShape(F32, {}), "z"); + auto xy = Add(x, y); + Add(xy, z); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -169,12 +169,12 @@ class MapTest : public ClientLibraryTestBase { TEST_F(MapTest, MapEachElemPlusOneR0) { // Applies lambda (x) (+ x 1)) to an input scalar. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR0(42.0); + std::unique_ptr param0_literal = LiteralUtil::CreateR0(42.0); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateAdderToOne(), {}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateAdderToOne(), {}); ComputeAndCompareR0(&builder, 43.0, {param0_data.get()}, ErrorSpec(0.01f)); @@ -183,12 +183,12 @@ TEST_F(MapTest, MapEachElemPlusOneR0) { XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR1({}); + std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateAdderToOne(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -198,12 +198,12 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4. XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateAdderToOne(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {3.2f, 4.3f, 5.4f, 6.5f}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -212,12 +212,12 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { TEST_F(MapTest, MapEachF32ElementToS32Constant) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateScalarOne(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); } @@ -225,12 +225,12 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) { TEST_F(MapTest, MapEachF32ElementToU32Constant) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateScalarOne(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); } @@ -239,12 +239,12 @@ TEST_F(MapTest, MapEachElemLongerChainR1) { // Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector. XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); + LiteralUtil::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateAdderToOneTimesItself(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateAdderToOneTimesItself(), {0}); ComputeAndCompareR1( &builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f}, @@ -255,13 +255,13 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then // maps (lambda (x) (* x 2)) on the result. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR1({}); + std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map1 = builder.Map({param}, CreateAdderToOne(), {0}); - builder.Map({map1}, CreateMulByTwo(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0}); + Map(&builder, {map1}, CreateMulByTwo(), {0}); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -272,13 +272,13 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { // maps (lambda (x) (* x 2)) on the result. XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map1 = builder.Map({param}, CreateAdderToOne(), {0}); - builder.Map({map1}, CreateMulByTwo(), {0}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0}); + Map(&builder, {map1}, CreateMulByTwo(), {0}); ComputeAndCompareR1(&builder, {6.4f, 8.6f, 10.8f, 13.0f}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -287,13 +287,13 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { TEST_F(MapTest, MapEachElemPlusOneR2) { // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR2( + std::unique_ptr param0_literal = LiteralUtil::CreateR2( {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param}, CreateAdderToOne(), {0, 1}); + auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param}, CreateAdderToOne(), {0, 1}); Array2D expected_array( {{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}}); @@ -319,10 +319,10 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { auto embed3 = CreateMapPlusN(embed1, 4.0); XlaBuilder embed4_builder("embed4"); - auto embed4_param = embed4_builder.Parameter(0, scalar_shape, "x"); - auto embed4_map_lhs = embed4_builder.Map({embed4_param}, embed2, {}); - auto embed4_map_rhs = embed4_builder.Map({embed4_param}, embed3, {}); - embed4_builder.Add(embed4_map_lhs, embed4_map_rhs); + auto embed4_param = Parameter(&embed4_builder, 0, scalar_shape, "x"); + auto embed4_map_lhs = Map(&embed4_builder, {embed4_param}, embed2, {}); + auto embed4_map_rhs = Map(&embed4_builder, {embed4_param}, embed3, {}); + Add(embed4_map_lhs, embed4_map_rhs); auto embed4_status = embed4_builder.Build(); ASSERT_IS_OK(embed4_status.status()); auto embed4 = embed4_status.ConsumeValueOrDie(); @@ -330,11 +330,11 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { auto embed5 = CreateMapPlusN(embed2, 6.0); XlaBuilder builder(TestName()); - auto constant_42 = builder.ConstantR0(42.0); - auto constant_7 = builder.ConstantR0(7.0); - auto map_42 = builder.Map({constant_42}, embed5, {}); - auto map_7 = builder.Map({constant_7}, embed4, {}); - builder.Add(map_42, map_7); + auto constant_42 = ConstantR0(&builder, 42.0); + auto constant_7 = ConstantR0(&builder, 7.0); + auto map_42 = Map(&builder, {constant_42}, embed5, {}); + auto map_7 = Map(&builder, {constant_7}, embed4, {}); + Add(map_42, map_7); ComputeAndCompareR0(&builder, 73.0, {}, ErrorSpec(0.01f)); } @@ -343,17 +343,18 @@ TEST_F(MapTest, MapBinaryAdder) { // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors. XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - Literal::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, CreateScalarAddComputation(F32, &builder), {0}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder), + {0}); ComputeAndCompareR1(&builder, {7.3f, 7.7, 4.3f, 0}, {param0_data.get(), param1_data.get()}, @@ -364,20 +365,20 @@ TEST_F(MapTest, MapBinaryAdder) { // for Map that used to fail in shape inference (b/28989438). XLA_TEST_F(MapTest, AddWithMixedLayouts) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR2WithLayout( + std::unique_ptr param0_literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0})); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = Literal::CreateR2WithLayout( + std::unique_ptr param1_literal = LiteralUtil::CreateR2WithLayout( {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1})); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder), - {0, 1}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder), + {0, 1}); Array2D expected(2, 2); expected(0, 0) = 11; @@ -391,19 +392,19 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { XLA_TEST_F(MapTest, AddR3_3x0x2) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR3FromArray3D(Array3D(3, 0, 2)); + LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - Literal::CreateR3FromArray3D(Array3D(3, 0, 2)); + LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder), - {0, 1, 2}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder), + {0, 1, 2}); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {param0_data.get(), param1_data.get()}); @@ -413,22 +414,22 @@ TEST_F(MapTest, MapTernaryAdder) { // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors. XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - Literal::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); std::unique_ptr param2_literal = - Literal::CreateR1({-10.0f, -100.0f, -900.0f, -400.0f}); + LiteralUtil::CreateR1({-10.0f, -100.0f, -900.0f, -400.0f}); std::unique_ptr param2_data = client_->TransferToServer(*param2_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto param2 = builder.Parameter(2, param2_literal->shape(), "param2"); - builder.Map({param0, param1, param2}, CreateTernaryAdder(), {0}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param2 = Parameter(&builder, 2, param2_literal->shape(), "param2"); + Map(&builder, {param0, param1, param2}, CreateTernaryAdder(), {0}); ComputeAndCompareR1( &builder, {-2.7f, -92.3f, -895.7f, -400.0f}, @@ -440,7 +441,8 @@ TEST_F(MapTest, MapGt) { // Maps (x,y) -> x > y onto two R1F32 vectors. XlaBuilder b(TestName()); auto gt = CreateGt(); - b.Map({b.ConstantR1({1, 20}), b.ConstantR1({10, 2})}, gt, {0}); + Map(&b, {ConstantR1(&b, {1, 20}), ConstantR1(&b, {10, 2})}, gt, + {0}); ComputeAndCompareR1(&b, {false, true}, {}); } @@ -449,15 +451,15 @@ TEST_F(MapTest, NestedBinaryMap) { { // max_with_square(x) = do max(x, x^2) via a map. XlaBuilder b("max_with_square"); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - b.Map({x, b.Mul(x, x)}, CreateMax(), {}); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); + Map(&b, {x, Mul(x, x)}, CreateMax(), {}); auto computation_status = b.Build(); ASSERT_IS_OK(computation_status.status()); max_with_square = computation_status.ConsumeValueOrDie(); } XlaBuilder b(TestName()); - auto input = b.ConstantR1({0.1f, 0.5f, -0.5f, 1.0f, 2.0f}); - b.Map({input}, max_with_square, {0}); + auto input = ConstantR1(&b, {0.1f, 0.5f, -0.5f, 1.0f, 2.0f}); + Map(&b, {input}, max_with_square, {0}); ComputeAndCompareR1(&b, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {}); } @@ -468,23 +470,23 @@ TEST_F(MapTest, MapOperantionWithBuildError) { XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("ErrorAdd"); - auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(U16, {}), "y"); - sub_builder->Add(x, y); + auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(U16, {}), "y"); + Add(x, y); auto error_add = sub_builder->BuildAndNoteError(); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - Literal::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, error_add, {0}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Map(&builder, {param0, param1}, error_add, {0}); StatusOr computation_status = builder.Build(); ASSERT_TRUE(!computation_status.ok()); @@ -506,21 +508,21 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("power"); - auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - sub_builder->Pow(x, y); + auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(F32, {}), "y"); + Pow(x, y); auto power = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = Literal::CreateR0(2.0f); - std::unique_ptr param1_literal = Literal::CreateR0(5.0f); + std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); + std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, power, {}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Map(&builder, {param0, param1}, power, {}); ComputeAndCompareR0(&builder, 32.0f, {param0_data.get(), param1_data.get()}, @@ -533,21 +535,21 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("power"); - auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - sub_builder->Sub(y, x); // note that this is y - x, not x - y + auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(F32, {}), "y"); + Sub(y, x); // note that this is y - x, not x - y auto sub_opposite = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = Literal::CreateR0(2.0f); - std::unique_ptr param1_literal = Literal::CreateR0(5.0f); + std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); + std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, sub_opposite, {}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + Map(&builder, {param0, param1}, sub_opposite, {}); ComputeAndCompareR0( &builder, 3.0f, {param0_data.get(), param1_data.get()}, ErrorSpec(0.01f)); @@ -559,16 +561,16 @@ TEST_F(MapTestWithFullOpt, MapSquare) { XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("power"); - auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - sub_builder->Mul(x, x); + auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x"); + Mul(x, x); auto square = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = Literal::CreateR0(10.0f); + std::unique_ptr param0_literal = LiteralUtil::CreateR0(10.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param0}, square, {}); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + Map(&builder, {param0}, square, {}); ComputeAndCompareR0(&builder, 100.0f, {param0_data.get()}, ErrorSpec(0.01f)); diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 27fd36e06acdc589f3a84ad561164e4a33b93506..da8c42d465340f2af3d6acd2c3676b69512f193f 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -56,15 +56,15 @@ TYPED_TEST_CASE(MatOpsSimpleTest_F16F32, TypesF16F32); XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { using T = TypeParam; XlaBuilder builder("exp_2x2"); - auto data = builder.ConstantR2FromArray2D({ - {1.0f, 0.0f}, // row 0 - {-1.0f, 0.5f}, // row 1 - }); - builder.Exp(data); + auto data = ConstantR2FromArray2D(&builder, { + {1.0f, 0.0f}, // row 0 + {-1.0f, 0.5f}, // row 1 + }); + Exp(data); std::unique_ptr expected = - Literal::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 - {0.36788f, 1.64872f}}); // row 1 + LiteralUtil::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 + {0.36788f, 1.64872f}}); // row 1 this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } @@ -76,43 +76,43 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { // add_half(x) = x + 0.5 XlaBuilder builder("add_half"); auto x_value = - builder.Parameter(0, ShapeUtil::MakeShapeWithType({}), "x_value"); - auto half = builder.ConstantR0(static_cast(0.5)); - builder.Add(x_value, half); + Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({}), "x_value"); + auto half = ConstantR0(&builder, static_cast(0.5)); + Add(x_value, half); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); add_half = computation_status.ConsumeValueOrDie(); } XlaBuilder builder("map_2x2"); - auto data = builder.ConstantR2FromArray2D({ - {1.0f, 0.0f}, // row 0 - {-1.0f, 0.5f}, // row 1 - }); - auto map = builder.Map({data}, add_half, {0, 1}); + auto data = ConstantR2FromArray2D(&builder, { + {1.0f, 0.0f}, // row 0 + {-1.0f, 0.5f}, // row 1 + }); + Map(&builder, {data}, add_half, {0, 1}); std::unique_ptr expected = - Literal::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 - {-0.5f, 1.0f}}); // row 1 + LiteralUtil::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 + {-0.5f, 1.0f}}); // row 1 this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { using T = TypeParam; XlaBuilder builder("max_2x2"); - auto lhs = builder.ConstantR2FromArray2D({ - {7.0f, 2.0f}, // row 0 - {3.0f, -4.0f}, // row 1 - }); - auto rhs = builder.ConstantR2FromArray2D({ - {5.0f, 6.0f}, // row 0 - {1.0f, -8.0f}, // row 1 - }); - auto max = builder.Max(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, { + {7.0f, 2.0f}, // row 0 + {3.0f, -4.0f}, // row 1 + }); + auto rhs = ConstantR2FromArray2D(&builder, { + {5.0f, 6.0f}, // row 0 + {1.0f, -8.0f}, // row 1 + }); + Max(lhs, rhs); std::unique_ptr expected = - Literal::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 - {3.0f, -4.0f}}); // row 1 + LiteralUtil::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 + {3.0f, -4.0f}}); // row 1 this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); } @@ -137,9 +137,9 @@ class TestLinspaceMaxParametric XlaBuilder builder( tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols)); - auto lhs = builder.ConstantR2FromArray2D(*alhs); - auto rhs = builder.ConstantR2FromArray2D(*arhs); - auto max = builder.Max(lhs, rhs); + auto lhs = ConstantR2FromArray2D(&builder, *alhs); + auto rhs = ConstantR2FromArray2D(&builder, *arhs); + Max(lhs, rhs); Array2D expected(rows, cols); for (int row = 0; row < rows; ++row) { @@ -200,31 +200,33 @@ class MatOpsDotAddTest TF_ASSERT_OK_AND_ASSIGN( auto lhs_handle, - client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( - lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + client_->TransferToServer( + *LiteralUtil::CreateR2FromArray2DWithLayout( + lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); TF_ASSERT_OK_AND_ASSIGN( auto rhs_handle, - client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( - rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + client_->TransferToServer( + *LiteralUtil::CreateR2FromArray2DWithLayout( + rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); XlaBuilder builder(TestName()); - auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); + auto lhs_arg = Parameter(&builder, 0, lhs_shape, "lhs"); auto lhs_mat_arg = lhs_arg; if (transpose) { - lhs_mat_arg = builder.Transpose(lhs_mat_arg, {1, 0}); + lhs_mat_arg = Transpose(lhs_mat_arg, {1, 0}); } - auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs"); - auto result = builder.Dot(lhs_mat_arg, rhs_arg); + auto rhs_arg = Parameter(&builder, 1, rhs_shape, "rhs"); + auto result = Dot(lhs_mat_arg, rhs_arg); Array2D expected; if (add_lhs) { - result = builder.Add(result, lhs_arg); + result = Add(result, lhs_arg); if (transpose) { expected = Array2D({{47.0f, 52.0f}, {71.0f, 78.0f}}); } else { expected = Array2D({{35.0f, 39.0f}, {81.0f, 89.0f}}); } } else { - result = builder.Add(result, rhs_arg); + result = Add(result, rhs_arg); if (transpose) { expected = Array2D({{56.0f, 61.0f}, {80.0f, 87.0f}}); } else { diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc index 0791a71aacf7614286fe964623a3172a174d4722..955dbef6dcd28421fb351c6ee064ac53eda1fd08 100644 --- a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc +++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -33,9 +33,10 @@ class SliceTest : public ClientLibraryTestBase {}; XLA_TEST_F(SliceTest, Slice2D) { XlaBuilder builder("slice_2d"); - auto original = builder.ConstantR2( + auto original = ConstantR2( + &builder, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}, {10.0, 11.0, 12.0}}); - builder.Slice(original, {2, 1}, {4, 3}, {1, 1}); + Slice(original, {2, 1}, {4, 3}, {1, 1}); Array2D expected({{8.0f, 9.0f}, {11.0f, 12.0f}}); ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.000001)); @@ -45,8 +46,8 @@ XLA_TEST_F(SliceTest, Slice3D) { XlaBuilder builder("slice_3d"); Array3D array_3d( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}); - auto original = builder.ConstantR3FromArray3D(array_3d); - builder.Slice(original, {0, 0, 1}, {2, 1, 2}, {1, 1, 1}); + auto original = ConstantR3FromArray3D(&builder, array_3d); + Slice(original, {0, 0, 1}, {2, 1, 2}, {1, 1, 1}); Array3D expected_3d({{{2.0f}}, {{6.0f}}}); ComputeAndCompareR3(&builder, expected_3d, {}, ErrorSpec(0.000001)); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 41f723edf1ff3518686231f31b61b64291b1f6bf..eb06b115daa96bccd73de30bb7fa30733a6fd947 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -60,7 +60,7 @@ class MultiOutputFusionTest : public HloTestBase { const Shape elem_shape2 = ShapeUtil::MakeShape(F32, {size, size}); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(8.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(8.0f))); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, elem_shape0, "0")); @@ -105,8 +105,9 @@ class MultiOutputFusionTest : public HloTestBase { Literal expect(ShapeUtil::MakeShape(F32, {size, size})); expect.PopulateWithValue(size * 1.5f * 3.5f); - auto actual = ExecuteAndTransfer( - std::move(hlo_module), {Literal::CreateR0(-9.0f).get(), &arg1}); + auto actual = + ExecuteAndTransfer(std::move(hlo_module), + {LiteralUtil::CreateR0(-9.0f).get(), &arg1}); EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); } @@ -165,7 +166,8 @@ class MultiOutputFusionTest : public HloTestBase { Literal input1(ShapeUtil::MakeShape(F64, {size})); input1.PopulateWithValue(1.); - Literal expect = std::move(*Literal::CreateR1({size * 1.5f * 3.5f})); + Literal expect = + std::move(*LiteralUtil::CreateR1({size * 1.5f * 3.5f})); auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1}); EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); } @@ -198,16 +200,16 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::MakeTupleOwned( - Literal::MakeTupleOwned( - Literal::MakeTupleOwned(Literal::CreateR0(42)), - Literal::CreateR0(1.0)), - Literal::MakeTupleOwned(Literal::CreateR0(3.0), - Literal::CreateR0(4))); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + auto param = LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), + LiteralUtil::CreateR0(1.0)), + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(3.0), + LiteralUtil::CreateR0(4))); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::MakeTupleOwned(Literal::CreateR0(42)))); + *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), *result)); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { @@ -232,11 +234,10 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR1({1.0, 2.0, 3.0, -1.0}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); - EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::CreateR1({0.0, 4.0, 9.0, 1.0}))); + auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0, -1.0}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); + LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, *result); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { @@ -266,11 +267,10 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR1({1.0, 2.0, 3.0}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); - EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::CreateR1({0.0, 4.0, 9.0}))); + auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); + LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, *result); } const char* const kScalarOps = R"( @@ -310,13 +310,15 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + auto param = + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, - *Literal::MakeTupleOwned(Literal::CreateR2({{3, 7}, {11, 15}}), - Literal::CreateR2({{5, 16}, {36, 64}})))); + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2({{3, 7}, {11, 15}}), + LiteralUtil::CreateR2({{5, 16}, {36, 64}})), + *result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -340,13 +342,15 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + auto param = + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::MakeTupleOwned( - Literal::CreateR2({{6, 8}, {10, 12}}), - Literal::CreateR2({{25, 36}, {49, 64}})))); + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2({{6, 8}, {10, 12}}), + LiteralUtil::CreateR2({{25, 36}, {49, 64}})), + *result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -371,13 +375,15 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + auto param = + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::MakeTupleOwned(Literal::CreateR1({14, 22}), - Literal::CreateR1({36, 64}), - Literal::CreateR1({66, 138})))); + *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({14, 22}), + LiteralUtil::CreateR1({36, 64}), + LiteralUtil::CreateR1({66, 138})), + *result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -402,15 +408,16 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + auto param = + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, - *Literal::MakeTupleOwned( - Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), - Literal::CreateR2({{3, 7}, {11, 15}}), - Literal::CreateR2({{5, 16}, {36, 64}})))); + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), + LiteralUtil::CreateR2({{3, 7}, {11, 15}}), + LiteralUtil::CreateR2({{5, 16}, {36, 64}})), + *result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -435,15 +442,17 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + auto param = + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, - *Literal::MakeTupleOwned( - Literal::CreateR2({{6, 8}, {10, 12}}), - Literal::CreateR3({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), - Literal::CreateR2({{25, 36}, {49, 64}})))); + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2({{6, 8}, {10, 12}}), + LiteralUtil::CreateR3( + {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), + LiteralUtil::CreateR2({{25, 36}, {49, 64}})), + *result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -455,7 +464,8 @@ XLA_TEST_F(MultiOutputFusionTest, r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add mul = f32[2,2,2]{2,1,0} multiply(p0, p0) c1 = f32[] constant(5) - mul2 = f32[2,2,2]{2,1,0} multiply(p0, c1) + b1 = f32[2,2,2]{2,1,0} broadcast(c1), dimensions={} + mul2 = f32[2,2,2]{2,1,0} multiply(p0, b1) ROOT tuple = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) tuple(r1, mul, mul2) } @@ -468,16 +478,18 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + auto param = + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, - *Literal::MakeTupleOwned( - Literal::CreateR1({14, 22}), - Literal::CreateR3({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), - Literal::CreateR3( - {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})))); + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR1({14, 22}), + LiteralUtil::CreateR3( + {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), + LiteralUtil::CreateR3( + {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})), + *result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -502,16 +514,57 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - auto init1 = Literal::CreateR0(5); - auto init2 = Literal::CreateR0(6); - TF_ASSERT_OK_AND_ASSIGN( - auto result, - Execute(std::move(module), {param.get(), init1.get(), init2.get()})); + auto param = + LiteralUtil::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto init1 = LiteralUtil::CreateR0(5); + auto init2 = LiteralUtil::CreateR0(6); + std::unique_ptr result = ExecuteNoHloPasses( + std::move(module), {param.get(), init1.get(), init2.get()}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2({{167, 172}, {176, 180}}), + LiteralUtil::CreateR2({{6, 6}, {6, 8}})), + *result)); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionDifferentElementTypes)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce (p0: f16[2,2,2]) -> (f32[2,2], f32[2,2], f16[2,2,2]) { + p0 = f16[2,2,2]{2,1,0} parameter(0) + convert = f32[2,2,2]{2,1,0} convert(p0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(convert, c0), dimensions={2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(convert, convert) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) + tuple(r1, r2, p0) + } + + ENTRY reduce { + p = f16[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = LiteralUtil::CreateR3( + {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}}, + {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::MakeTupleOwned( - Literal::CreateR2({{167, 172}, {176, 180}}), - Literal::CreateR2({{6, 6}, {6, 8}})))); + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2({{3, 7}, {11, 15}}), + LiteralUtil::CreateR2({{5, 16}, {36, 64}}), + LiteralUtil::CreateR3( + {{{Eigen::half(1), Eigen::half(2)}, + {Eigen::half(3), Eigen::half(4)}}, + {{Eigen::half(5), Eigen::half(6)}, + {Eigen::half(7), Eigen::half(8)}}})), + *result)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0a0426adcbc1b5b89be0841fa2c4204e2b65abf4 --- /dev/null +++ b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc @@ -0,0 +1,169 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/local_client_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +// Tests that ensure outfeed instructions that are contained in nested +// computations in non-root positions are executed. + +class OutfeedInNestedComputationTest : public LocalClientTestBase {}; + +XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) { + XlaBuilder b(TestName()); + + Shape state_tuple_array_shape = ShapeUtil::MakeShape(xla::S32, {10, 5}); + Shape int_shape = ShapeUtil::MakeShape(xla::S32, {}); + Shape state_tuple_shape = + ShapeUtil::MakeTupleShape({int_shape, state_tuple_array_shape}); + Shape xfeed_shape = ShapeUtil::MakeShape(xla::S32, {2}); + + XlaOp some_buffer = Broadcast(ConstantR0(&b, 0), {10, 5}); + XlaOp num_iter = Infeed(&b, int_shape); + XlaOp init_tuple = Tuple(&b, {num_iter, some_buffer}); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation loop_cond, [&] { + // Condition: iteration variable > 0 + XlaBuilder cond_builder("loop_condition"); + XlaOp state_tuple = Parameter(&cond_builder, 0, state_tuple_shape, "state"); + XlaOp loop_counter = GetTupleElement(state_tuple, 0); + Outfeed(loop_counter, int_shape, ""); + Gt(loop_counter, ConstantR0(&cond_builder, 0)); + return cond_builder.Build(); + }()); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation loop_body, [&] { + XlaBuilder body_builder("loop_body"); + XlaOp state_tuple = Parameter(&body_builder, 0, state_tuple_shape, "state"); + XlaOp loop_counter = GetTupleElement(state_tuple, 0); + XlaOp buffer_inside = GetTupleElement(state_tuple, 1); + + // Read some stuff from Infeed. + XlaOp some_input = Infeed(&body_builder, xfeed_shape); + XlaOp sum = Add(some_input, Broadcast(loop_counter, {2})); + Outfeed(sum, xfeed_shape, ""); + + XlaOp iter_left = Sub(loop_counter, ConstantR0(&body_builder, 1)); + + Tuple(&body_builder, {iter_left, buffer_inside}); + return body_builder.Build(); + }()); + + // Build loop. + XlaOp result_tuple = While(loop_cond, loop_body, init_tuple); + GetTupleElement(result_tuple, 0); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build()); + + std::unique_ptr comp_result; + std::unique_ptr thread( + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions(), "execute_thread", [&] { + comp_result = local_client_->ExecuteAndTransfer(computation, {}) + .ConsumeValueOrDie(); + })); + + VLOG(1) << "Transferring trip count to computation"; + // Transfer number of iterations to Infeed. + TF_ASSERT_OK( + local_client_->TransferToInfeed(*LiteralUtil::CreateR0(1))); + + // Pick up value from outfeed + { + VLOG(1) << "Reading from condition outfeed"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + local_client_->TransferFromOutfeed(&int_shape)); + EXPECT_EQ(r->Get({}), 1); + } + + VLOG(1) << "Writing data to infeed"; + // Transfer some stuff to Infeed for use inside of loop. + TF_ASSERT_OK(local_client_->TransferToInfeed( + *LiteralUtil::CreateR1({10, 20}))); + + // Pick up value from outfeed + { + VLOG(1) << "Reading from body outfeed"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + local_client_->TransferFromOutfeed(&xfeed_shape)); + EXPECT_EQ(r->Get({0}), 11); + EXPECT_EQ(r->Get({1}), 21); + } + + { + VLOG(1) << "Reading from condition outfeed"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + local_client_->TransferFromOutfeed(&int_shape)); + EXPECT_EQ(r->Get({}), 0); + } + + // Joins the thread + thread.reset(); + + EXPECT_EQ(comp_result->Get({}), 0); +} + +XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { + XlaBuilder b(TestName()); + + Shape condition_shape = ShapeUtil::MakeShape(xla::PRED, {}); + Shape result_shape = ShapeUtil::MakeShape(xla::PRED, {}); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation true_computation, [&] { + XlaBuilder inner_builder("true_computation"); + XlaOp param = Parameter(&inner_builder, 0, result_shape, "param"); + Outfeed(param, result_shape, ""); + Or(param, param); + return inner_builder.Build(); + }()); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation false_computation, [&] { + XlaBuilder inner_builder("false_computation"); + Parameter(&inner_builder, 0, result_shape, "param"); + return inner_builder.Build(); + }()); + + XlaOp pred = Infeed(&b, condition_shape); + Conditional(/*predicate=*/pred, /*true_operand=*/pred, + /*true_computation=*/true_computation, /*false_operand=*/pred, + /*false_computation=*/false_computation); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build()); + + std::unique_ptr comp_result; + std::unique_ptr thread( + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions(), "execute_thread", [&] { + comp_result = local_client_->ExecuteAndTransfer(computation, {}) + .ConsumeValueOrDie(); + })); + + TF_ASSERT_OK( + local_client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + local_client_->TransferFromOutfeed(&result_shape)); + + EXPECT_EQ(r->Get({}), true); + + // Join the thread + thread.reset(); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index ce295b832d79e4f00656f2893c2ba1162693dd73..ca21b0b2ba590a6daadf2c8d3d9ad213514b0f0f 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -93,8 +93,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(0); - b.Pad(AddParam(*Literal::CreateR1({}), &b), - AddParam(*Literal::CreateR0(0.1), &b), padding_config); + Pad(AddParam(*LiteralUtil::CreateR1({}), &b), + AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, {}, {}, DefaultErrorSpec()); } @@ -108,8 +108,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) { dimension->set_edge_padding_high(4); dimension->set_interior_padding(7); - b.Pad(AddParam(*Literal::CreateR1({}), &b), - AddParam(*Literal::CreateR0(0.1), &b), padding_config); + Pad(AddParam(*LiteralUtil::CreateR1({}), &b), + AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, std::vector(5, 0.1), {}, DefaultErrorSpec()); } @@ -123,16 +123,17 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(1); - b.Pad(AddParam(*Literal::CreateR1({1, 2, 3}), &b), - AddParam(*Literal::CreateR0(0.1), &b), padding_config); + Pad(AddParam(*LiteralUtil::CreateR1({1, 2, 3}), &b), + AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); std::vector expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3}); ComputeAndCompareR1(&b, expected, {}, DefaultErrorSpec()); } XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { XlaBuilder b(TestName()); - b.Pad(AddParam(Array4D(2, 0, 3, 2), &b), - AddParam(*Literal::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); + Pad(AddParam(Array4D(2, 0, 3, 2), &b), + AddParam(*LiteralUtil::CreateR0(1.5), &b), + r4_padding_on_dim0_dim1_); ComputeAndCompareR4(&b, Array4D(5, 2, 3, 2, 1.5f), {}, DefaultErrorSpec()); } @@ -147,8 +148,8 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { }); input->FillWithYX(input_xy); - b.Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0(1.5), &b), - r4_padding_on_dim0_dim1_); + Pad(AddParam(*input, &b), AddParam(*LiteralUtil::CreateR0(1.5), &b), + r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(2, 3, 3, 2); expected->Fill(1.5); @@ -166,8 +167,9 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { const float pad_value = 1.5f; Array4D input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); - b.Pad(AddParam(input, &b), AddParam(*Literal::CreateR0(pad_value), &b), - r4_padding_on_dim0_dim1_); + Pad(AddParam(input, &b), + AddParam(*LiteralUtil::CreateR0(pad_value), &b), + r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(8, 5, 1, 1); expected->Fill(pad_value); @@ -205,11 +207,11 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) { const float pad_value = -5.123f; Array4D input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6}); - auto input = Literal::CreateR4FromArray4D(input_array); + auto input = LiteralUtil::CreateR4FromArray4D(input_array); input = input->Relayout(layout); - b.Pad(AddParam(*input, &b), - AddParam(*Literal::CreateR0(pad_value), &b), padding_config); + Pad(AddParam(*input, &b), + AddParam(*LiteralUtil::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 1, 5, 8); expected_array.Fill(pad_value); @@ -251,11 +253,11 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { input_array(0, 0, 0, 0) = 1.0f; input_array(0, 24, 6, 6) = 2.0f; input_array(0, 17, 2, 5) = 3.0f; - auto input = Literal::CreateR4FromArray4D(input_array); + auto input = LiteralUtil::CreateR4FromArray4D(input_array); input = input->Relayout(layout); - b.Pad(AddParam(*input, &b), - AddParam(*Literal::CreateR0(pad_value), &b), padding_config); + Pad(AddParam(*input, &b), + AddParam(*LiteralUtil::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 25, 17, 11); expected_array.Fill(pad_value); @@ -275,8 +277,8 @@ XLA_TEST_F(PadTest, Pad4DU8Array) { }); input->FillWithYX(input_xy); - b.Pad(AddParam(*input, &b), b.ConstantR0(35), - r4_padding_on_dim0_dim1_); + Pad(AddParam(*input, &b), ConstantR0(&b, 35), + r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(2, 3, 3, 2); expected->Fill(35); @@ -294,16 +296,16 @@ XLA_TEST_F(PadTest, Pad4DPredArray) { // Since bool is currently not well supported, use Broadcast operation to // create the operand for Pad. - auto input = b.Broadcast(b.ConstantR0(true), {1, 1, 3, 2}); + auto input = Broadcast(ConstantR0(&b, true), {1, 1, 3, 2}); auto padded = - b.Pad(input, b.ConstantR0(false), r4_padding_on_dim0_dim1_); + Pad(input, ConstantR0(&b, false), r4_padding_on_dim0_dim1_); // For the same reason, use Select to convert boolean values to int32. auto zeros = MakeUnique>(2, 3, 3, 2); auto ones = MakeUnique>(2, 3, 3, 2); zeros->Fill(0); ones->Fill(1); - b.Select(padded, AddParam(*ones, &b), AddParam(*zeros, &b)); + Select(padded, AddParam(*ones, &b), AddParam(*zeros, &b)); auto expected = MakeUnique>(2, 3, 3, 2); expected->Fill(0); @@ -329,7 +331,7 @@ XLA_TEST_P(PadTestFloat, Large2DPad) { padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 + 100 * dim); } - b.Pad(input, AddParam(*Literal::CreateR0(0.0f), &b), padding_config); + Pad(input, AddParam(*LiteralUtil::CreateR0(0.0f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -351,7 +353,8 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) { padding_config.mutable_dimensions(1)->set_edge_padding_low(6); padding_config.mutable_dimensions(1)->set_edge_padding_high(4); padding_config.mutable_dimensions(1)->set_interior_padding(2); - b.Pad(input, AddParam(*Literal::CreateR0(3.14f), &b), padding_config); + Pad(input, AddParam(*LiteralUtil::CreateR0(3.14f), &b), + padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -376,7 +379,8 @@ XLA_TEST_P(PadTestFloat, High2DPad) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); + Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -403,7 +407,8 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); + Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -430,7 +435,8 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding[dim]); } - b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); + Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -446,12 +452,13 @@ XLA_TEST_P(PadTestFloat, ReducePad) { XlaComputation add = CreateScalarAddComputation(FloatType(), &b); auto reduce = - b.Reduce(input, AddParam(*Literal::CreateR0(0.0), &b), add, {0}); + Reduce(input, AddParam(*LiteralUtil::CreateR0(0.0), &b), add, {0}); PaddingConfig padding_config = MakeNoPaddingConfig(3); padding_config.mutable_dimensions(0)->set_edge_padding_low(1); padding_config.mutable_dimensions(0)->set_edge_padding_high(1); - b.Pad(reduce, AddParam(*Literal::CreateR0(0.0f), &b), padding_config); + Pad(reduce, AddParam(*LiteralUtil::CreateR0(0.0f), &b), + padding_config); Array3D expected({{{0.0, 0.0}, {0.0, 0.0}}, {{2.0, 2.0}, {2.0, 2.0}}, diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index 838f1b4e2f0f0e0871ec717bdeefcbbc653397e3..f6c762e7a4bee91a26c4c2e033c3717fef6d91d0 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -21,10 +21,10 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -42,11 +42,12 @@ class ParamsTest : public ClientLibraryTestBase {}; XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR0(3.14159f); + std::unique_ptr param0_literal = + LiteralUtil::CreateR0(3.14159f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0"); ComputeAndCompareR0(&builder, 3.14159f, {param0_data.get()}, ErrorSpec(0.0001f)); @@ -54,11 +55,11 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR1({}); + std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "param0"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "param0"); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -67,11 +68,11 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({3.14f, -100.25f}); + LiteralUtil::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0"); ComputeAndCompareR1(&builder, {3.14f, -100.25f}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -80,12 +81,13 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XlaBuilder builder(TestName()); string str("hello world"); - std::unique_ptr param0_literal = Literal::CreateR1U8(str); + std::unique_ptr param0_literal = LiteralUtil::CreateR1U8(str); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter( - 0, ShapeUtil::MakeShape(U8, {static_cast(str.size())}), "param0"); + Parameter(&builder, 0, + ShapeUtil::MakeShape(U8, {static_cast(str.size())}), + "param0"); ComputeAndCompareR1U8(&builder, str, {param0_data.get()}); } @@ -93,11 +95,11 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR2FromArray2D(Array2D(3, 0)); + LiteralUtil::CreateR2FromArray2D(Array2D(3, 0)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 0}), "param0"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 0}), "param0"); ComputeAndCompareR2(&builder, Array2D(3, 0), {param0_data.get()}, ErrorSpec(0.01f)); @@ -105,12 +107,12 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR2( + std::unique_ptr param0_literal = LiteralUtil::CreateR2( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 2}), "param0"); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 2}), "param0"); Array2D expected_array( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); @@ -121,28 +123,28 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XLA_TEST_F(ParamsTest, TwoParameters) { XlaBuilder builder(TestName()); - std::unique_ptr literal0 = Literal::CreateR1({1, 2}); + std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, literal0->shape(), "param0"); + auto param0 = Parameter(&builder, 0, literal0->shape(), "param0"); - std::unique_ptr literal1 = Literal::CreateR1({10, 20}); + std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param1 = builder.Parameter(1, literal1->shape(), "param1"); + auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); // Use both parameters // // {1, 2} + {10, 20} = {11, 22} - auto sum = builder.Add(param0, param1); - sum = builder.Add(param0, param1); + auto sum = Add(param0, param1); + sum = Add(param0, param1); // Use only the second parameter again, to show that it can be used // twice and to make the computation asymmetric in the two // parameters to test that the parameters are not swapped. // // {11, 22} * {10, 20} = {110, 440} - auto prod = builder.Mul(sum, param1); + Mul(sum, param1); ComputeAndCompareR1(&builder, {110, 440}, {param0_data.get(), param1_data.get()}, @@ -152,12 +154,12 @@ XLA_TEST_F(ParamsTest, TwoParameters) { XLA_TEST_F(ParamsTest, MissingParameter) { // Test that an error is returned when a computation with an incomplete set of // parameters (parameter numbers not contiguous from 0) is executed. - std::unique_ptr literal = Literal::CreateR0(3.14159f); + std::unique_ptr literal = LiteralUtil::CreateR0(3.14159f); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto p = builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2"); + Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {}), "param2"); auto computation_status = builder.Build(); ASSERT_NE(computation_status.status(), Status::OK()); @@ -166,15 +168,15 @@ XLA_TEST_F(ParamsTest, MissingParameter) { XLA_TEST_F(ParamsTest, UnusedParameter) { XlaBuilder builder(TestName()); - std::unique_ptr literal0 = Literal::CreateR1({1, 2}); + std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, literal0->shape(), "param0"); + Parameter(&builder, 0, literal0->shape(), "param0"); - std::unique_ptr literal1 = Literal::CreateR1({10, 20}); + std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param1 = builder.Parameter(1, literal1->shape(), "param1"); + Parameter(&builder, 1, literal1->shape(), "param1"); ComputeAndCompareR1(&builder, {10, 20}, {param0_data.get(), param1_data.get()}, @@ -186,22 +188,23 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { // unused expression. XlaBuilder builder(TestName()); - std::unique_ptr literal0 = Literal::CreateR1({1, 2}); + std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1({10, 20, 30}); + std::unique_ptr literal1 = + LiteralUtil::CreateR1({10, 20, 30}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = builder.Parameter(0, literal0->shape(), "param0"); - auto param1 = builder.Parameter(1, literal1->shape(), "param1"); - auto param2 = builder.Parameter(2, literal1->shape(), "param2"); + auto param0 = Parameter(&builder, 0, literal0->shape(), "param0"); + auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); + auto param2 = Parameter(&builder, 2, literal1->shape(), "param2"); // This add is unused. - builder.Add(param1, param2); + Add(param1, param2); - builder.Neg(param0); + Neg(param0); ComputeAndCompareR1( &builder, {-1, -2}, @@ -215,7 +218,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::vector init_value = {{0, 1}}; init_value.resize(size); - XlaOp sum_handle = builder.ConstantR1(init_value); + XlaOp sum_handle = ConstantR1(&builder, init_value); std::vector sum = {{0, 1}}; sum.resize(size); @@ -230,11 +233,11 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::vector sum_value = {{entry0, entry1}}; sum_value.resize(size); - std::unique_ptr literal = Literal::CreateR1(sum_value); + std::unique_ptr literal = LiteralUtil::CreateR1(sum_value); param_data_owner.push_back( client_->TransferToServer(*literal).ConsumeValueOrDie()); - XlaOp param = builder.Parameter(i, literal->shape(), "param"); - sum_handle = builder.Add(sum_handle, param); + XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + sum_handle = Add(sum_handle, param); } std::vector param_data; @@ -260,16 +263,16 @@ XLA_TEST_F(ParamsTest, XlaBuilder builder(TestName()); std::vector> param_data_owner; - XlaOp sum_handle = builder.ConstantR0(0.0f); + XlaOp sum_handle = ConstantR0(&builder, 0.0f); float target = 0.0; constexpr int kParamCount = 3000; for (int i = 0; i < kParamCount; ++i) { target += i; - std::unique_ptr literal = Literal::CreateR0(i); + std::unique_ptr literal = LiteralUtil::CreateR0(i); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = builder.Parameter(i, literal->shape(), "param"); - sum_handle = builder.Add(sum_handle, param); + XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + sum_handle = Add(sum_handle, param); } std::vector param_data; @@ -291,26 +294,26 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( XlaBuilder builder(TestName()); std::vector> param_data_owner; - XlaOp sum_handle = builder.ConstantR1({0, 0}); + XlaOp sum_handle = ConstantR1(&builder, {0, 0}); int32 target = 0; constexpr int kParamCount = 3000; std::vector params; for (int i = 0; i < kParamCount; ++i) { target += i; - std::unique_ptr literal = Literal::CreateR1({i, i}); + std::unique_ptr literal = LiteralUtil::CreateR1({i, i}); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = builder.Parameter(i, literal->shape(), "param"); + XlaOp param = Parameter(&builder, i, literal->shape(), "param"); params.push_back(param); - sum_handle = builder.Add(sum_handle, param); + sum_handle = Add(sum_handle, param); } std::vector outputs; for (int i = 0; i < kParamCount; ++i) { - outputs.push_back(builder.Add(params[i], sum_handle)); + outputs.push_back(Add(params[i], sum_handle)); } - builder.Tuple(outputs); + Tuple(&builder, outputs); std::vector param_data; param_data.reserve(param_data_owner.size()); @@ -321,10 +324,10 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( std::vector> elements; std::vector ptrs; for (int i = 0; i < kParamCount; ++i) { - elements.push_back(Literal::CreateR1({target + i, target + i})); + elements.push_back(LiteralUtil::CreateR1({target + i, target + i})); ptrs.push_back(elements.back().get()); } - ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data); + ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data); } // Test large number of parameters flowing into a while-loop. @@ -353,25 +356,25 @@ XLA_TEST_F(ParamsTest, std::vector params; std::vector parameter_shapes; for (int i = 0; i < kParamCount; ++i) { - std::unique_ptr literal = Literal::CreateR1({i, i}); + std::unique_ptr literal = LiteralUtil::CreateR1({i, i}); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = builder.Parameter(i, literal->shape(), "param"); + XlaOp param = Parameter(&builder, i, literal->shape(), "param"); params.push_back(param); parameter_shapes.push_back(literal->shape()); } // Add bool parameter for the loop condition. Use a parameter HLO instead of a // constant because DCE may eliminate the while-body otherwise. - std::unique_ptr bool_literal = Literal::CreateR0(false); + std::unique_ptr bool_literal = LiteralUtil::CreateR0(false); param_data_owner.push_back( std::move(client_->TransferToServer(*bool_literal)).ValueOrDie()); XlaOp bool_param = - builder.Parameter(kParamCount, bool_literal->shape(), "bool_param"); + Parameter(&builder, kParamCount, bool_literal->shape(), "bool_param"); params.push_back(bool_param); parameter_shapes.push_back(bool_literal->shape()); - auto init = builder.Tuple(params); + auto init = Tuple(&builder, params); // Create a computation for the condition: while(bool_param). Shape while_shape = ShapeUtil::MakeTupleShape(parameter_shapes); @@ -379,8 +382,8 @@ XLA_TEST_F(ParamsTest, { XlaBuilder builder("condition"); auto condition_parameter = - builder.Parameter(0, while_shape, "condition_parameter"); - builder.GetTupleElement(condition_parameter, kParamCount); + Parameter(&builder, 0, while_shape, "condition_parameter"); + GetTupleElement(condition_parameter, kParamCount); condition = builder.Build().ConsumeValueOrDie(); } @@ -389,27 +392,27 @@ XLA_TEST_F(ParamsTest, XlaComputation body; { XlaBuilder builder("body"); - auto body_parameter = builder.Parameter(0, while_shape, "body_parameter"); + auto body_parameter = Parameter(&builder, 0, while_shape, "body_parameter"); std::vector updates; for (int i = 0; i < kParamCount; ++i) { - auto add = builder.Add(builder.GetTupleElement(body_parameter, i), - builder.ConstantR1({1, 1})); + auto add = Add(GetTupleElement(body_parameter, i), + ConstantR1(&builder, {1, 1})); updates.push_back(add); } // Add bool parameter. - updates.push_back(builder.GetTupleElement(body_parameter, kParamCount)); + updates.push_back(GetTupleElement(body_parameter, kParamCount)); - builder.Tuple(updates); + Tuple(&builder, updates); body = builder.Build().ConsumeValueOrDie(); } - auto loop = builder.While(condition, body, init); + auto loop = While(condition, body, init); std::vector outputs; for (int i = 0; i < kParamCount; ++i) { - outputs.push_back(builder.GetTupleElement(loop, i)); + outputs.push_back(GetTupleElement(loop, i)); } - builder.Tuple(outputs); + Tuple(&builder, outputs); std::vector param_data; param_data.reserve(param_data_owner.size()); @@ -420,10 +423,10 @@ XLA_TEST_F(ParamsTest, std::vector> elements; std::vector ptrs; for (int i = 0; i < kParamCount; ++i) { - elements.push_back(Literal::CreateR1({i, i})); + elements.push_back(LiteralUtil::CreateR1({i, i})); ptrs.push_back(elements.back().get()); } - ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data); + ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data); } #endif @@ -433,16 +436,16 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { Shape r1f32_3 = ShapeUtil::MakeShape(F32, {3}); Shape tuple_shape = ShapeUtil::MakeTupleShape({r1f32_3, r1f32_3}); - auto input = builder.Parameter(0, tuple_shape, "input"); - auto lhs = builder.GetTupleElement(input, 0); - auto rhs = builder.GetTupleElement(input, 1); - builder.Add(lhs, rhs); + auto input = Parameter(&builder, 0, tuple_shape, "input"); + auto lhs = GetTupleElement(input, 0); + auto rhs = GetTupleElement(input, 1); + Add(lhs, rhs); std::unique_ptr data = client_ - ->TransferToServer(*Literal::MakeTuple({ - Literal::CreateR1({1, 2, 3}).get(), - Literal::CreateR1({4, 5, 6}).get(), + ->TransferToServer(*LiteralUtil::MakeTuple({ + LiteralUtil::CreateR1({1, 2, 3}).get(), + LiteralUtil::CreateR1({4, 5, 6}).get(), })) .ConsumeValueOrDie(); @@ -454,10 +457,10 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { // Verifies that passing a 2x2 with {0, 1} layout returns the same value back // when (transferred to the server and) passed through a parameter. XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { - std::unique_ptr literal = Literal::CreateR2WithLayout( + std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); XlaBuilder builder(TestName()); - builder.Parameter(0, literal->shape(), "input"); + Parameter(&builder, 0, literal->shape(), "input"); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -466,10 +469,10 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { // As above, but for {1, 0} layout. XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { - std::unique_ptr literal = Literal::CreateR2WithLayout( + std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0})); XlaBuilder builder(TestName()); - builder.Parameter(0, literal->shape(), "input"); + Parameter(&builder, 0, literal->shape(), "input"); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -477,8 +480,9 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { } XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { - std::unique_ptr literal = Literal::CreateR2({ - {1, 3}, {2, 4}, + std::unique_ptr literal = LiteralUtil::CreateR2({ + {1, 3}, + {2, 4}, }); const Shape original = literal->shape(); { @@ -494,9 +498,9 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { } // Use the original shape in building the computation. XlaBuilder builder(TestName()); - auto input = builder.Parameter(0, original, "input"); + auto input = Parameter(&builder, 0, original, "input"); // Use the slice operator to get an off-diagonal element. - builder.Slice(input, {0, 1}, {1, 2}, {1, 1}); + Slice(input, {0, 1}, {1, 2}, {1, 1}); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 77159efb26f3b7dd4918f24305f7269a2d6ff647..2fc7f816b56db6f57ca835d1847476b6d622ce5e 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -29,64 +29,63 @@ namespace { class PredTest : public ClientLibraryTestBase { protected: - void TestCompare( - bool lhs, bool rhs, bool expected, - XlaOp (XlaBuilder::*op)(const xla::XlaOp&, const xla::XlaOp&, - tensorflow::gtl::ArraySlice)) { + void TestCompare(bool lhs, bool rhs, bool expected, + std::function)> + op) { XlaBuilder builder(TestName()); - XlaOp lhs_op = builder.ConstantR0(lhs); - XlaOp rhs_op = builder.ConstantR0(rhs); - XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); + XlaOp lhs_op = ConstantR0(&builder, lhs); + XlaOp rhs_op = ConstantR0(&builder, rhs); + op(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } }; TEST_F(PredTest, ConstantR0PredTrue) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR0(true); + ConstantR0(&builder, true); ComputeAndCompareR0(&builder, true, {}); } TEST_F(PredTest, ConstantR0PredFalse) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR0(false); + ConstantR0(&builder, false); ComputeAndCompareR0(&builder, false, {}); } TEST_F(PredTest, ConstantR0PredCompareEq) { - TestCompare(true, false, false, &XlaBuilder::Eq); + TestCompare(true, false, false, &Eq); } TEST_F(PredTest, ConstantR0PredCompareNe) { - TestCompare(true, false, true, &XlaBuilder::Ne); + TestCompare(true, false, true, &Ne); } TEST_F(PredTest, ConstantR0PredCompareLe) { - TestCompare(true, false, false, &XlaBuilder::Le); + TestCompare(true, false, false, &Le); } TEST_F(PredTest, ConstantR0PredCompareLt) { - TestCompare(true, false, false, &XlaBuilder::Lt); + TestCompare(true, false, false, &Lt); } TEST_F(PredTest, ConstantR0PredCompareGe) { - TestCompare(true, false, true, &XlaBuilder::Ge); + TestCompare(true, false, true, &Ge); } TEST_F(PredTest, ConstantR0PredCompareGt) { - TestCompare(true, false, true, &XlaBuilder::Gt); + TestCompare(true, false, true, &Gt); } TEST_F(PredTest, ConstantR1Pred) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({true, false, false, true}); + ConstantR1(&builder, {true, false, false, true}); ComputeAndCompareR1(&builder, {true, false, false, true}, {}); } TEST_F(PredTest, ConstantR2Pred) { XlaBuilder builder(TestName()); - auto a = - builder.ConstantR2({{false, true, true}, {true, false, false}}); + ConstantR2(&builder, {{false, true, true}, {true, false, false}}); const string expected = R"(pred[2,3] { { 011 }, { 100 } @@ -96,44 +95,44 @@ TEST_F(PredTest, ConstantR2Pred) { TEST_F(PredTest, AnyR1True) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({true, false}); - TF_ASSERT_OK(Any(a, &builder).status()); + auto a = ConstantR1(&builder, {true, false}); + Any(a); ComputeAndCompareR0(&builder, true, {}); } TEST_F(PredTest, AnyR1False) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({false, false}); - TF_ASSERT_OK(Any(a, &builder).status()); + auto a = ConstantR1(&builder, {false, false}); + Any(a); ComputeAndCompareR0(&builder, false, {}); } TEST_F(PredTest, AnyR1VacuouslyFalse) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR1({}); - TF_ASSERT_OK(Any(a, &builder).status()); + auto a = ConstantR1(&builder, {}); + Any(a); ComputeAndCompareR0(&builder, false, {}); } TEST_F(PredTest, AnyR2True) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({ - {false, false, false}, - {false, false, false}, - {false, false, true}, - }); - TF_ASSERT_OK(Any(a, &builder).status()); + auto a = ConstantR2(&builder, { + {false, false, false}, + {false, false, false}, + {false, false, true}, + }); + Any(a); ComputeAndCompareR0(&builder, true, {}); } TEST_F(PredTest, AnyR2False) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({ - {false, false, false}, - {false, false, false}, - {false, false, false}, - }); - TF_ASSERT_OK(Any(a, &builder).status()); + auto a = ConstantR2(&builder, { + {false, false, false}, + {false, false, false}, + {false, false, false}, + }); + Any(a); ComputeAndCompareR0(&builder, false, {}); } diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 1a2de6937c3e134852a730f62f7b56417cf49b28..326e13b3867f2f804e882e00e35850d0189ad8d7 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -53,8 +53,8 @@ template std::unique_ptr PrngTest::UniformTest( T a, T b, tensorflow::gtl::ArraySlice dims, int64 seed) { XlaBuilder builder(TestName()); - builder.RngUniform( - builder.ConstantR0(a), builder.ConstantR0(b), + RngUniform( + ConstantR0(&builder, a), ConstantR0(&builder, b), ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), dims)); SetSeed(seed); @@ -141,9 +141,9 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count, int32 sample_size = range_size * expected_count; XlaBuilder builder(TestName()); - builder.RngUniform(builder.ConstantR0(0), - builder.ConstantR0(range_size), - ShapeUtil::MakeShape(S32, {sample_size})); + RngUniform(ConstantR0(&builder, 0), + ConstantR0(&builder, range_size), + ShapeUtil::MakeShape(S32, {sample_size})); SetSeed(seed); auto actual = @@ -177,28 +177,29 @@ XLA_TEST_F(PrngTest, Uniformity108) { EXPECT_LT(UniformChiSquared(108, 256), 132.144); } XLA_TEST_F(PrngTest, Uniformity256) { - EXPECT_LT(UniformChiSquared(256, 256), 293.248); + EXPECT_LT(UniformChiSquared(256, 512), 293.248); } XLA_TEST_F(PrngTest, MapUsingRng) { // Build a x -> (x + U[0,1)) computation. - auto build_sum_rng = [this](XlaBuilder& builder) { + auto build_sum_rng = [](XlaBuilder& builder) { auto b = builder.CreateSubBuilder("sum_with_rng"); - auto x = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "input"); - b->Add(x, b->RngUniform(b->ConstantR0(0), b->ConstantR0(1), - ShapeUtil::MakeShape(F32, {}))); + auto x = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "input"); + Add(x, + RngUniform(ConstantR0(b.get(), 0), ConstantR0(b.get(), 1), + ShapeUtil::MakeShape(F32, {}))); return b->BuildAndNoteError(); }; XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr param0_data, client_->TransferToServer(*param0_literal)); - auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); + auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); auto fn = build_sum_rng(builder); - builder.Map({param0}, fn, {0}); + Map(&builder, {param0}, fn, {0}); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); @@ -226,9 +227,8 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { // Build a U[0,1) computation. auto build_computation = [this]() { XlaBuilder builder(TestName()); - builder.RngUniform(builder.ConstantR0(0), - builder.ConstantR0(1), - ShapeUtil::MakeShape(F32, {10})); + RngUniform(ConstantR0(&builder, 0), ConstantR0(&builder, 1), + ShapeUtil::MakeShape(F32, {10})); return builder.Build(); }; @@ -282,8 +282,8 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { XLA_TEST_F(PrngTest, TenValuesN01) { XlaBuilder builder(TestName()); - builder.RngNormal(builder.ConstantR0(0), builder.ConstantR0(1), - ShapeUtil::MakeShape(F32, {10})); + RngNormal(ConstantR0(&builder, 0), ConstantR0(&builder, 1), + ShapeUtil::MakeShape(F32, {10})); SetSeed(42); ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); @@ -294,9 +294,9 @@ XLA_TEST_F(PrngTest, RngUniformCrash) { XlaBuilder builder(TestName()); // This used to crash XLA during LLVM IR generation for CPUs. - auto rng_uniform = builder.RngUniform(builder.ConstantR0(0), - builder.ConstantR0(1000 * 1000), - ShapeUtil::MakeShape(S32, {})); + RngUniform(ConstantR0(&builder, 0), + ConstantR0(&builder, 1000 * 1000), + ShapeUtil::MakeShape(S32, {})); SetSeed(0); ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc index f95e75648343aa88bd7c39de4ee9f387f2b60506..fab2a65de109c670a6854c0fc1118162acf3d312 100644 --- a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc +++ b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -31,8 +31,8 @@ class QueryInferredShapeTest : public ClientLibraryTestBase {}; TEST_F(QueryInferredShapeTest, OnePlusOneShape) { XlaBuilder builder("one_plus_one"); - auto one = builder.ConstantR0(1.0); - auto result = builder.Add(one, one); + auto one = ConstantR0(&builder, 1.0); + auto result = Add(one, one); StatusOr shape_status = builder.GetShape(result); ASSERT_IS_OK(shape_status.status()); auto shape = shape_status.ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc index 9052b188ed09a715b6ad7c3a40dc853d02cdd70c..a080dd1732bde21712cf47b4b57538cf4040f30e 100644 --- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -95,21 +95,21 @@ XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) { *reduce_input_shape->mutable_layout() = LayoutUtil::MakeLayout(reduce_layout.input_minor_to_major); - std::unique_ptr reduce_input = - Literal::CreateR4({{ /*i0=0*/ - {/*i1=0*/ - {-0.246092796, -0.179497838, -0.161181688}, - {-0.151643038, -0.240213156, -0.198156}}, - {/*i1=1*/ - {-0.14222312, -0.162200093, -0.193907976}, - {-0.239411, -0.198166847, -0.172471642}}}, - { /*i0=1*/ - {/*i1=0*/ - {-0.22965157, -0.218723893, -0.129257083}, - {-0.188762426, -0.16123569, -0.181166649}}, - {/*i1=1*/ - {-0.241772294, -0.245131493, -0.160247207}, - {-0.179881215, -0.23383224, -0.121976733}}}}); + std::unique_ptr reduce_input = LiteralUtil::CreateR4( + {{ /*i0=0*/ + {/*i1=0*/ + {-0.246092796, -0.179497838, -0.161181688}, + {-0.151643038, -0.240213156, -0.198156}}, + {/*i1=1*/ + {-0.14222312, -0.162200093, -0.193907976}, + {-0.239411, -0.198166847, -0.172471642}}}, + { /*i0=1*/ + {/*i1=0*/ + {-0.22965157, -0.218723893, -0.129257083}, + {-0.188762426, -0.16123569, -0.181166649}}, + {/*i1=1*/ + {-0.241772294, -0.245131493, -0.160247207}, + {-0.179881215, -0.23383224, -0.121976733}}}}); EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); } diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index b311785449f1774c3bc1e4d7ad35c2866e3b4061..531648fe3eb8e3941c5e3c012847ee68c616590f 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -22,9 +22,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -230,12 +230,13 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR1({input_values}); + std::unique_ptr a_literal = + LiteralUtil::CreateR1({input_values}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = builder.Parameter(0, a_literal->shape(), "a"); + auto a = Parameter(&builder, 0, a_literal->shape(), "a"); - builder.ReducePrecision(a, exponent_bits, mantissa_bits); + ReducePrecision(a, exponent_bits, mantissa_bits); ComputeAndCompareR1(&builder, expected_values, {a_data.get()}); } @@ -253,18 +254,18 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR1({1.00001}); + std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = builder.Parameter(0, a_literal->shape(), "a"); + auto a = Parameter(&builder, 0, a_literal->shape(), "a"); // Abs doesn't affect resolution. - auto abs = builder.Abs(a); + auto abs = Abs(a); // Near 1.0, Log(x) approximates x - 1; this lets us confirm that the // reduce-precision operation showed up in the correct place in the // graph. - builder.Log(abs); + Log(abs); // Insert precision-reduction after the Abs(x) operation, rounding that // result to exactly 1.0f. @@ -282,14 +283,14 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR1({1.00001}); + std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = builder.Parameter(0, a_literal->shape(), "a"); + auto a = Parameter(&builder, 0, a_literal->shape(), "a"); // These two operations should be fused by any reasonable backend. - auto abs = builder.Abs(a); - builder.Neg(abs); + auto abs = Abs(a); + Neg(abs); // Add a pass after operation fusion, suffixing kAbs operations. This // should not see into the fusion nodes and thus should not affect the @@ -308,14 +309,14 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR1({1.00001}); + std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = builder.Parameter(0, a_literal->shape(), "a"); + auto a = Parameter(&builder, 0, a_literal->shape(), "a"); // These two operations should be fused by any reasonable backend. - auto abs = builder.Abs(a); - builder.Neg(abs); + auto abs = Abs(a); + Neg(abs); // Add a pass after operation fusion, suffixing kFusion operations. auto reduce_precision_pass = execution_options_.mutable_debug_options() @@ -332,14 +333,14 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR1({1.00001}); + std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = builder.Parameter(0, a_literal->shape(), "a"); + auto a = Parameter(&builder, 0, a_literal->shape(), "a"); // These two operations should be fused by any reasonable backend. - auto abs = builder.Abs(a); - builder.Neg(abs); + auto abs = Abs(a); + Neg(abs); // Add a pass suffixing fusion nodes containing kCos operations. This // should have no effect. @@ -357,14 +358,14 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR1({1.00001}); + std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = builder.Parameter(0, a_literal->shape(), "a"); + auto a = Parameter(&builder, 0, a_literal->shape(), "a"); // These two operations should be fused by any reasonable backend. - auto abs = builder.Abs(a); - builder.Neg(abs); + auto abs = Abs(a); + Neg(abs); // Add a pass suffixing fusion nodes containing kAbs operations. This // should see the kAbs operation within the above fusion node. diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index d671d40456a276a44b462f390c95aa4af301263a..2065271a7f686c52c88df80b0efe8f2e1542d198 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -37,8 +37,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -67,12 +67,12 @@ class ReduceTest : public ClientLibraryTestBase { ReduceTest() { // Implementation note: laid out z >> y >> x by default. // clang-format off - literal_2d_ = Literal::CreateR2({ + literal_2d_ = LiteralUtil::CreateR2({ // x0 x1 x2 { 1.f, 2.f, 3.f}, // y0 { 4.f, 5.f, 6.f}, // y1 }); - literal_3d_ = Literal::CreateR3Projected({ + literal_3d_ = LiteralUtil::CreateR3Projected({ // x0 x1 x2 { 1.f, 2.f, 3.f}, // y0 { 4.f, 5.f, 6.f}, // y1 @@ -89,9 +89,9 @@ class ReduceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {element_count}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(0.0); - builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0}); std::vector input_data(element_count); for (int64 i = 0; i < element_count; ++i) { @@ -101,7 +101,7 @@ class ReduceTest : public ClientLibraryTestBase { } } std::unique_ptr input_literal = - Literal::CreateR1(AsSlice(input_data)); + LiteralUtil::CreateR1(AsSlice(input_data)); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -118,22 +118,22 @@ class ReduceTest : public ClientLibraryTestBase { const int element_count = input_data.size(); XlaBuilder builder(TestName()); const Shape input_shape = ShapeUtil::MakeShape(S32, {element_count}); - auto input_par = builder.Parameter(0, input_shape, "input"); + auto input_par = Parameter(&builder, 0, input_shape, "input"); auto pred_values = - builder.Eq(input_par, builder.ConstantR1(element_count, 1)); + Eq(input_par, ConstantR1(&builder, element_count, 1)); XlaOp init_value; XlaComputation reduce; if (and_reduce) { - init_value = builder.ConstantR0(true); - reduce = CreateScalarAndComputation(&builder); + init_value = ConstantR0(&builder, true); + reduce = CreateScalarAndComputation(PRED, &builder); } else { - init_value = builder.ConstantR0(false); - reduce = CreateScalarOrComputation(&builder); + init_value = ConstantR0(&builder, false); + reduce = CreateScalarOrComputation(PRED, &builder); } - builder.Reduce(pred_values, init_value, reduce, - /*dimensions_to_reduce=*/{0}); + Reduce(pred_values, init_value, reduce, + /*dimensions_to_reduce=*/{0}); - std::unique_ptr input_literal = Literal::CreateR1(input_data); + std::unique_ptr input_literal = LiteralUtil::CreateR1(input_data); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -156,26 +156,26 @@ class ReduceTest : public ClientLibraryTestBase { int64 major = 0) { XlaBuilder builder(TestName()); const Shape input_shape = ShapeUtil::MakeShape(U8, {rows, cols}); - auto input = builder.Parameter(0, input_shape, "input"); - auto input_pred = builder.Eq(input, builder.ConstantR0(1)); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto input_pred = Eq(input, ConstantR0(&builder, 1)); XlaOp init_value; XlaComputation reduce_op; if (and_reduce) { - init_value = builder.ConstantR0(true); - reduce_op = CreateScalarAndComputation(&builder); + init_value = ConstantR0(&builder, true); + reduce_op = CreateScalarAndComputation(PRED, &builder); } else { - init_value = builder.ConstantR0(false); - reduce_op = CreateScalarOrComputation(&builder); + init_value = ConstantR0(&builder, false); + reduce_op = CreateScalarOrComputation(PRED, &builder); } - builder.Reduce(input_pred, init_value, reduce_op, - /*dimensions_to_reduce=*/{0}); + Reduce(input_pred, init_value, reduce_op, + /*dimensions_to_reduce=*/{0}); Array2D input_data(rows, cols); input_data.FillRandom(0, 1); std::unique_ptr input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = @@ -202,14 +202,14 @@ class ReduceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(0.0); - builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0, 1}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0, 1}); Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = @@ -230,14 +230,14 @@ class ReduceTest : public ClientLibraryTestBase { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(0.0); - builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0}); Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = @@ -287,15 +287,15 @@ class ReduceTest : public ClientLibraryTestBase { XlaComputation reduction_function = reduction_function_generator(&builder); const Shape input_shape = ShapeUtil::MakeShape( xla::primitive_util::NativeToPrimitiveType(), {rows, cols}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(initial_value); - builder.Reduce(input, zero, reduction_function, - /*dimensions_to_reduce=*/{0}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, initial_value); + Reduce(input, zero, reduction_function, + /*dimensions_to_reduce=*/{0}); Array2D input_data(rows, cols); input_data.FillUnique(initial_value); std::unique_ptr input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = @@ -442,15 +442,15 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(0.0); - auto log_ = builder.Log(input); - builder.Reduce(log_, zero, add_f32, /*dimensions_to_reduce=*/{0}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + auto log_ = Log(input); + Reduce(log_, zero, add_f32, /*dimensions_to_reduce=*/{0}); Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -473,16 +473,16 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(0.0); - auto log_ = builder.Log(input); - auto transpose = builder.Transpose(log_, {1, 0}); - builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{1}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + auto log_ = Log(input); + auto transpose = Transpose(log_, {1, 0}); + Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{1}); Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -505,10 +505,10 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {12, 111, 50}); - XlaOp input = builder.Parameter(0, input_shape, "input"); - XlaOp zero = builder.ConstantR0(0.0); - XlaOp transpose = builder.Transpose(input, /*permutation=*/{1, 0, 2}); - builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); + XlaOp input = Parameter(&builder, 0, input_shape, "input"); + XlaOp zero = ConstantR0(&builder, 0.0); + XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2}); + Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, MakeFakeLiteral(input_shape)); @@ -522,16 +522,16 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { XlaBuilder builder(TestName()); XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, 2, cols / 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto zero = builder.ConstantR0(0.0); - auto log_ = builder.Tanh(input); - auto reshape = builder.Reshape(log_, {rows, cols}); - builder.Reduce(reshape, zero, add_f32, /*dimensions_to_reduce=*/{0}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + auto log_ = Tanh(input); + auto reshape = Reshape(log_, {rows, cols}); + Reduce(reshape, zero, add_f32, /*dimensions_to_reduce=*/{0}); Array3D input_data(rows, 2, cols / 2); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - Literal::CreateR3FromArray3D(input_data); + LiteralUtil::CreateR3FromArray3D(input_data); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -568,9 +568,9 @@ void PrintTo(const BoundsLayout& spec, std::ostream* os) { XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) { XlaBuilder builder(TestName()); auto add = CreateScalarAddComputation(F32, &builder); - auto scalar = builder.ConstantR0(42.0); - auto broadcasted = builder.Broadcast(scalar, {500, 500}); - builder.Reduce(broadcasted, builder.ConstantR0(0.0f), add, {0, 1}); + auto scalar = ConstantR0(&builder, 42.0); + auto broadcasted = Broadcast(scalar, {500, 500}); + Reduce(broadcasted, ConstantR0(&builder, 0.0f), add, {0, 1}); float expected = 42.0f * static_cast(500 * 500); ComputeAndCompareR0(&builder, expected, {}, ErrorSpec(0.0001)); @@ -580,9 +580,9 @@ XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) { XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) { XlaBuilder builder(TestName()); auto max = CreateScalarMaxComputation(F32, &builder); - auto scalar = builder.ConstantR0(42.0); - auto broadcasted = builder.Broadcast(scalar, {500, 500}); - builder.Reduce(broadcasted, builder.ConstantR0(0.0f), max, {0, 1}); + auto scalar = ConstantR0(&builder, 42.0); + auto broadcasted = Broadcast(scalar, {500, 500}); + Reduce(broadcasted, ConstantR0(&builder, 0.0f), max, {0, 1}); float expected = 42.0f; ComputeAndCompareR0(&builder, expected, {}, ErrorSpec(0.0001)); @@ -594,9 +594,9 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { auto max = CreateScalarMaxComputation(F32, &builder); Array2D input(300, 250); input.FillRandom(214.0f); - auto input_literal = Literal::CreateR2FromArray2D(input); - builder.Reduce(builder.ConstantLiteral(*input_literal), - builder.ConstantR0(FLT_MIN), max, {0, 1}); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input); + Reduce(ConstantLiteral(&builder, *input_literal), + ConstantR0(&builder, FLT_MIN), max, {0, 1}); auto input_max = FLT_MIN; input.Each( [&](int64, int64, float* v) { input_max = std::max(input_max, *v); }); @@ -609,9 +609,9 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) { auto min = CreateScalarMinComputation(F32, &builder); Array2D input(150, 130); input.FillRandom(214.0f); - auto input_literal = Literal::CreateR2FromArray2D(input); - builder.Reduce(builder.ConstantLiteral(*input_literal), - builder.ConstantR0(FLT_MAX), min, {0, 1}); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input); + Reduce(ConstantLiteral(&builder, *input_literal), + ConstantR0(&builder, FLT_MAX), min, {0, 1}); auto input_min = FLT_MAX; input.Each( @@ -623,12 +623,11 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) { XlaBuilder builder(TestName()); Array2D input({{1}, {2}}); auto min = CreateScalarMinComputation(U32, &builder); - auto input_literal = Literal::CreateR2FromArray2D(input); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input); auto initial_value = - builder.ConstantR0(std::numeric_limits::max()); + ConstantR0(&builder, std::numeric_limits::max()); - builder.Reduce(builder.ConstantLiteral(*input_literal), initial_value, min, - {0, 1}); + Reduce(ConstantLiteral(&builder, *input_literal), initial_value, min, {0, 1}); ComputeAndCompareR0(&builder, 1, {}); } @@ -636,21 +635,20 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) { XlaBuilder builder(TestName()); Array2D input({{1}, {2}}); auto max = CreateScalarMaxComputation(U32, &builder); - auto input_literal = Literal::CreateR2FromArray2D(input); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input); auto initial_value = - builder.ConstantR0(std::numeric_limits::min()); + ConstantR0(&builder, std::numeric_limits::min()); - builder.Reduce(builder.ConstantLiteral(*input_literal), initial_value, max, - {0, 1}); + Reduce(ConstantLiteral(&builder, *input_literal), initial_value, max, {0, 1}); ComputeAndCompareR0(&builder, 2, {}); } // Reduces a matrix among dimension 1. XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_2d_); + auto m = ConstantLiteral(&builder, *literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {1}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {1}); std::vector expected = {6.f, 15.f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -659,9 +657,9 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar). XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_2d_); + auto m = ConstantLiteral(&builder, *literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {0, 1}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1}); ComputeAndCompareR0(&builder, 21.0f, {}, ErrorSpec(0.0001, 1e-4)); } @@ -669,9 +667,9 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Tests 2D matrix ReduceToRow operation. XLA_TEST_F(ReduceTest, Reduce2DAmongY) { XlaBuilder builder("reduce_among_y"); - auto m = builder.ConstantLiteral(*literal_2d_); + auto m = ConstantLiteral(&builder, *literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {0}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {0}); std::vector expected = {5.f, 7.f, 9.f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -679,9 +677,9 @@ XLA_TEST_F(ReduceTest, Reduce2DAmongY) { XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_3d_); + auto m = ConstantLiteral(&builder, *literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {1, 2}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {1, 2}); std::vector expected = {21.f, 21.f, 21.f, 21.f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -689,9 +687,9 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_3d_); + auto m = ConstantLiteral(&builder, *literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {0, 1}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1}); std::vector expected = {20.f, 28.f, 36.f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -699,9 +697,9 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { XLA_TEST_F(ReduceTest, ReduceR3ToR0) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_3d_); + auto m = ConstantLiteral(&builder, *literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {0, 1, 2}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1, 2}); float expected = 21.0f * 4.0; ComputeAndCompareR0(&builder, expected, {}, ErrorSpec(0.0001)); @@ -709,9 +707,9 @@ XLA_TEST_F(ReduceTest, ReduceR3ToR0) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_3d_); + auto m = ConstantLiteral(&builder, *literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {0}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {0}); // clang-format off Array2D expected({ @@ -724,9 +722,9 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_3d_); + auto m = ConstantLiteral(&builder, *literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {1}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {1}); // clang-format off Array2D expected({ @@ -741,9 +739,9 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) { XlaBuilder builder(TestName()); - auto m = builder.ConstantLiteral(*literal_3d_); + auto m = ConstantLiteral(&builder, *literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); - builder.Reduce(m, builder.ConstantR0(0.0f), add, {2}); + Reduce(m, ConstantR0(&builder, 0.0f), add, {2}); // clang-format off Array2D expected({ @@ -800,13 +798,17 @@ XLA_TEST_F(ReduceTest, VectorizedReduce_Min) { XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanAnd) { RunVectorizedReduceTestForType( - static_cast(CreateScalarAndComputation), + static_cast([](XlaBuilder* builder) { + return CreateScalarAndComputation(PRED, builder); + }), [](bool a, bool b) { return a && b; }, true); } XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanOr) { RunVectorizedReduceTestForType( - static_cast(CreateScalarOrComputation), + static_cast([](XlaBuilder* builder) { + return CreateScalarOrComputation(PRED, builder); + }), [](bool a, bool b) { return a || b; }, false); } @@ -820,17 +822,17 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { // input_array.FillRandom(3.14f, 0.05); input_array.Fill(1.0f); - auto input_literal = Literal::CreateR3FromArray3D(input_array); + auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout)); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); auto input_activations = - builder.Parameter(0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal->shape(), "input"); XlaComputation add = CreateScalarAddComputation(F32, &builder); - auto sum = builder.Reduce(input_activations, builder.ConstantR0(0.0f), - add, GetParam().reduce_dims); + Reduce(input_activations, ConstantR0(&builder, 0.0f), add, + GetParam().reduce_dims); auto expected = ReferenceUtil::Reduce3DTo2D(input_array, 0.0f, GetParam().reduce_dims, @@ -871,14 +873,15 @@ XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) { XlaBuilder builder(TestName()); XlaComputation max_f32 = CreateScalarMaxComputation(F32, &builder); - auto a = builder.ConstantR0(2.0f); - auto a2 = builder.Abs(a); + auto a = ConstantR0(&builder, 2.0f); + auto a2 = Abs(a); - std::unique_ptr b_literal = Literal::CreateR1({1.0f, 4.0f}); + std::unique_ptr b_literal = + LiteralUtil::CreateR1({1.0f, 4.0f}); std::unique_ptr b_data = client_->TransferToServer(*b_literal).ConsumeValueOrDie(); - auto b = builder.Parameter(0, b_literal->shape(), "b"); - auto max = builder.Reduce(b, a2, max_f32, {0}); + auto b = Parameter(&builder, 0, b_literal->shape(), "b"); + Reduce(b, a2, max_f32, {0}); ComputeAndCompareR0(&builder, 4.0f, {b_data.get()}); } @@ -900,13 +903,13 @@ class ReduceInitializerTest : public ReduceTest { XlaComputation max_fn = CreateScalarMaxComputation( primitive_util::NativeToPrimitiveType(), &builder); - auto init = builder.ConstantR0(initializer); + auto init = ConstantR0(&builder, initializer); std::vector input_arr(num_elems, std::numeric_limits::lowest()); - auto input_literal = Literal::CreateR1(input_arr); + auto input_literal = LiteralUtil::CreateR1(input_arr); auto input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - builder.Reduce(builder.Parameter(0, input_literal->shape(), "input"), init, - max_fn, {0}); + Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init, + max_fn, {0}); ComputeAndCompareR0(&builder, initializer, {input_data.get()}); } @@ -939,23 +942,24 @@ XLA_TEST_F(ReduceInitializerTest, U64InitializerBigValue) { XLA_TEST_F(ReduceTest, ReduceIdentity) { XlaBuilder builder(TestName()); Shape single_float = ShapeUtil::MakeShape(F32, {}); - builder.Parameter(0, single_float, "lhs-unused"); - builder.Parameter(1, single_float, "rhs-used"); + Parameter(&builder, 0, single_float, "lhs-unused"); + Parameter(&builder, 1, single_float, "rhs-used"); auto computation_status = builder.Build(); TF_ASSERT_OK(computation_status.status()); Shape operand_shape = ShapeUtil::MakeShape(F32, {1}); - builder.Reduce(builder.Parameter(0, operand_shape, "operand"), - builder.Parameter(1, single_float, "init"), - computation_status.ValueOrDie(), {0}); + Reduce(Parameter(&builder, 0, operand_shape, "operand"), + Parameter(&builder, 1, single_float, "init"), + computation_status.ValueOrDie(), {0}); float operand[] = {42.0f}; float init = 58.5f; float expected = 42.0f; - std::unique_ptr input_literal = Literal::CreateR1(operand); + std::unique_ptr input_literal = + LiteralUtil::CreateR1(operand); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - std::unique_ptr input_literal2 = Literal::CreateR0(init); + std::unique_ptr input_literal2 = LiteralUtil::CreateR0(init); std::unique_ptr input_global_data2 = client_->TransferToServer(*input_literal2).ConsumeValueOrDie(); ComputeAndCompareR0( @@ -963,5 +967,32 @@ XLA_TEST_F(ReduceTest, ReduceIdentity) { ErrorSpec(0.0001)); } +XLA_TEST_F(ReduceTest, AndReduceU64) { + XlaBuilder builder(TestName()); + Array2D initializer = {{0x123456789ABCDEF0LL, 0x3BCDEF12A4567890LL}, + {0XFFFFFFFFFFFFFFD6LL, 101}, + {1, 0XFFFFFFFFFFFFFFFFLL}}; + auto reducer = CreateScalarAndComputation(U64, &builder); + auto m = ConstantR2FromArray2D(&builder, initializer); + Reduce(m, ConstantR0(&builder, 0xFFFFFFFFFFFFFFFFLL), reducer, {1}); + + std::vector expected = {0x1204461080145890LL, 68, 1}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ReduceTest, OrReduceU64) { + XlaBuilder builder(TestName()); + Array2D initializer = {{0x123456789ABCDEF0LL, 0x3BCDEF12A4567890LL}, + {0xFFFFFFFFFFFFFFD6LL, 101}, + {1, 0xCAFEBEEFABABABABLL}}; + auto reducer = CreateScalarOrComputation(U64, &builder); + auto m = ConstantR2FromArray2D(&builder, initializer); + Reduce(m, ConstantR0(&builder, 0), reducer, {1}); + + std::vector expected = {0X3BFDFF7ABEFEFEF0LL, 0XFFFFFFFFFFFFFFF7LL, + 0xCAFEBEEFABABABABLL}; + ComputeAndCompareR1(&builder, expected, {}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 266760e8202fddc48792ac66dda334255e428808..92c93f08b2e8e543aeaa58020eddacd109b2e2da 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -24,8 +24,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -70,31 +70,33 @@ class ReduceWindowTest : public ::testing::WithParamInterface, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - auto init = - CreateConstantFromLiteral(*Literal::CreateR0(0.0f), &builder_); - builder_.ReduceWindow(input, init, - CreateScalarAddComputation(FloatType(), &builder_), - window_dimensions, window_strides, padding); + auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0(0.0f), + &builder_); + ReduceWindow(input, init, + CreateScalarAddComputation(FloatType(), &builder_), + window_dimensions, window_strides, padding); } void ReduceWindowMax(const XlaOp& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - auto init = CreateConstantFromLiteral(Literal::MinValue(F32), &builder_); - builder_.ReduceWindow(input, init, - CreateScalarMaxComputation(FloatType(), &builder_), - window_dimensions, window_strides, padding); + auto init = + CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_); + ReduceWindow(input, init, + CreateScalarMaxComputation(FloatType(), &builder_), + window_dimensions, window_strides, padding); } void ReduceWindowMin(const XlaOp& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - auto init = CreateConstantFromLiteral(Literal::MaxValue(F32), &builder_); - builder_.ReduceWindow(input, init, - CreateScalarMinComputation(FloatType(), &builder_), - window_dimensions, window_strides, padding); + auto init = + CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_); + ReduceWindow(input, init, + CreateScalarMinComputation(FloatType(), &builder_), + window_dimensions, window_strides, padding); } XlaBuilder builder_; @@ -102,14 +104,14 @@ class ReduceWindowTest : public ::testing::WithParamInterface, TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { const auto input = CreateConstantFromLiteral( - *Literal::CreateR1({1, 1, 1, 1}), &builder_); + *LiteralUtil::CreateR1({1, 1, 1, 1}), &builder_); const auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(0), &builder_); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(0), &builder_); TF_ASSERT_OK(builder_.first_error()); - builder_.ReduceWindow(input, init_value, - CreateScalarAddComputation(FloatType(), &builder_), - /*window_dimensions=*/{1, 2}, - /*window_strides=*/{1}, Padding::kValid); + ReduceWindow(input, init_value, + CreateScalarAddComputation(FloatType(), &builder_), + /*window_dimensions=*/{1, 2}, + /*window_strides=*/{1}, Padding::kValid); ASSERT_EQ(builder_.first_error().code(), tensorflow::error::INVALID_ARGUMENT) << builder_.first_error(); ASSERT_THAT(builder_.first_error().error_message(), @@ -119,33 +121,32 @@ TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { // Regression test for b/68964348. TEST_P(ReduceWindowTest, R0ReduceWindow) { const auto input = - CreateConstantFromLiteral(*Literal::CreateR0(42.0), &builder_); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(42.0), &builder_); const auto init = - CreateConstantFromLiteral(*Literal::CreateR0(1.0), &builder_); - builder_.ReduceWindow(input, init, - CreateScalarAddComputation(FloatType(), &builder_), - /*window_dimensions=*/{}, - /*window_strides=*/{}, Padding::kSame); - ComputeAndCompareLiteral(&builder_, *Literal::CreateR0(43.0), {}, + CreateConstantFromLiteral(*LiteralUtil::CreateR0(1.0), &builder_); + ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), + /*window_dimensions=*/{}, + /*window_strides=*/{}, Padding::kSame); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR0(43.0), {}, ErrorSpec(0.00001)); } TEST_P(ReduceWindowTest, Min3In5Stride2) { const auto input = CreateConstantFromLiteral( - *Literal::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + *LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, {3}, {2}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({100, 1}), {}, - ErrorSpec(0.00001)); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({100, 1}), + {}, ErrorSpec(0.00001)); } TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) { const auto input = CreateConstantFromLiteral( - *Literal::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + *LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1}, Padding::kSame); ComputeAndCompareLiteral(&builder_, - *Literal::CreateR1({1000, 100, 10, 1, 1}), {}, - ErrorSpec(0.00001)); + *LiteralUtil::CreateR1({1000, 100, 10, 1, 1}), + {}, ErrorSpec(0.00001)); } XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { @@ -157,7 +158,7 @@ XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -172,7 +173,7 @@ TEST_P(ReduceWindowTest, NonSquareSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -186,7 +187,7 @@ TEST_P(ReduceWindowTest, MiddleDimsSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -203,7 +204,7 @@ TEST_P(ReduceWindowTest, Along2ndMinorDim) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -225,8 +226,8 @@ TEST_P(ReduceWindowTest, AmongMajor2Dims) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { @@ -248,8 +249,8 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } // Tests the super windowing logic w.r.t handling prime number of windows in a @@ -273,8 +274,8 @@ TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { @@ -290,8 +291,8 @@ TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } // Tests a reduction function that is not a simple add/min/max/etc. @@ -306,15 +307,15 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { Padding padding = Padding::kValid; const Shape scalar = ShapeUtil::MakeShape(FloatType(), {}); auto b = builder_.CreateSubBuilder("unusual"); - auto lhs = b->Parameter(0, scalar, "lhs"); - auto rhs = b->Parameter(1, scalar, "rhs"); - b->Min(b->Add(lhs, rhs), - CreateConstantFromLiteral(*Literal::CreateR0(8.0f), b.get())); + auto lhs = Parameter(b.get(), 0, scalar, "lhs"); + auto rhs = Parameter(b.get(), 1, scalar, "rhs"); + Min(Add(lhs, rhs), + CreateConstantFromLiteral(*LiteralUtil::CreateR0(8.0f), b.get())); XlaComputation reduce_fn = b->BuildAndNoteError(); - builder_.ReduceWindow( + ReduceWindow( input, - CreateConstantFromLiteral(*Literal::CreateR0(0.0f), &builder_), + CreateConstantFromLiteral(*LiteralUtil::CreateR0(0.0f), &builder_), reduce_fn, /*window_dimensions=*/{1, 1, 2, 1}, /*window_strides=*/{1, 1, 1, 1}, padding); @@ -328,15 +329,15 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { /*window=*/{1, 1, 2, 1}, /*stride=*/{1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*expected), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*expected), + {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, R4UnitWindow) { Array4D input_array(13, 12, 8, 15); input_array.FillRandom(2.f, 2.f); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( @@ -348,7 +349,7 @@ TEST_P(ReduceWindowTest, R4UnitWindow) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -377,7 +378,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) { auto shape = ShapeUtil::MakeShape(F32, input_dims); std::unique_ptr arg_literal = - Literal::CreateFullWithDescendingLayout(input_dims, 1.0f); + LiteralUtil::CreateFullWithDescendingLayout(input_dims, 1.0f); const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); @@ -386,7 +387,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) { std::vector output_dims = {8, 8, 6, 6, 8, 8}; std::unique_ptr expected = - Literal::CreateFullWithDescendingLayout(output_dims, 9.0f); + LiteralUtil::CreateFullWithDescendingLayout(output_dims, 9.0f); ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } @@ -395,7 +396,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { Array4D input_array(2, 1, 27, 119); input_array.FillRandom(2.0f); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( @@ -409,7 +410,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -417,7 +418,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { Array4D input_array(3, 2, 4, 64); input_array.FillRandom(2.0f); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( @@ -431,7 +432,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -439,7 +440,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { Array4D input_array(1, 3, 12, 200); input_array.FillRandom(2.0f); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( @@ -453,7 +454,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -474,18 +475,18 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) { std::vector input_vector(128 * 9, 1); const auto input = CreateConstantFromLiteral( - *Literal::CreateR1(input_vector), &builder_); + *LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {32}, {128}, Padding::kValid); ComputeAndCompareLiteral( &builder_, - *Literal::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, + *LiteralUtil::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, DefaultErrorSpec()); } @@ -500,9 +501,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; const auto input = CreateConstantFromLiteral( - *Literal::CreateR1(input_vector), &builder_); + *LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {128}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({1088}), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({1088}), {}, DefaultErrorSpec()); } @@ -517,9 +518,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; const auto input = CreateConstantFromLiteral( - *Literal::CreateR1(input_vector), &builder_); + *LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {1}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({1088}), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({1088}), {}, DefaultErrorSpec()); } @@ -536,14 +537,15 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd( input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, + *LiteralUtil::CreateFromArray(*res), {}, + DefaultErrorSpec()); } TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { Array2D input_array(6, 4, 1.0f); - XlaOp input = builder_.Broadcast( - CreateConstantFromLiteral(Literal::One(F32), &builder_), {6, 4}); + XlaOp input = Broadcast( + CreateConstantFromLiteral(LiteralUtil::One(F32), &builder_), {6, 4}); Padding padding = Padding::kSame; ReduceWindowAdd(input, {4, 2}, {3, 3}, padding); @@ -551,8 +553,9 @@ TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, + *LiteralUtil::CreateFromArray(*res), {}, + DefaultErrorSpec()); } INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest, @@ -610,7 +613,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, param.base_bounds[2], param.base_bounds[3]); input.FillIota(1); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", @@ -622,12 +625,12 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, } auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); auto computation = param.reducer == kAdd ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); - b.ReduceWindowWithGeneralPadding( + ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, /*computation=*/computation, @@ -648,7 +651,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*stride=*/param.strides, /*padding=*/padding); std::unique_ptr expected_literal = - Literal::CreateFromArray(*expected); + LiteralUtil::CreateFromArray(*expected); const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout( input_literal->shape().element_type(), AsInt64Slice(expected_literal->shape().dimensions()), param.layout); @@ -960,25 +963,25 @@ TEST_P(R3ReduceWindowTest, Add) { Array3D input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2], 1.0f); std::unique_ptr input_literal = - Literal::CreateR3FromArray3DWithLayout( + LiteralUtil::CreateR3FromArray3DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); - b.ReduceWindow(/*operand=*/parameter, - /*init_value=*/init_value, - /*computation=*/CreateScalarAddComputation(FloatType(), &b), - /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/param.padding); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + ReduceWindow(/*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/CreateScalarAddComputation(FloatType(), &b), + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/param.padding); auto expected = ReferenceUtil::ReduceWindow3DAdd( /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/param.padding); - ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected), {input_arg.get()}, DefaultErrorSpec()); } @@ -1094,7 +1097,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, const float kInitValue = 0.0f; Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); std::unique_ptr input_literal = - Literal::CreateR2FromArray2DWithLayout( + LiteralUtil::CreateR2FromArray2DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; @@ -1108,8 +1111,8 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); - b.ReduceWindowWithGeneralPadding( + CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, /*computation=*/computation, @@ -1124,7 +1127,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected), {input_arg.get()}, DefaultErrorSpec()); } }; @@ -1293,7 +1296,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { std::vector input_vector(param.base_bounds[0]); std::iota(std::begin(input_vector), std::end(input_vector), 0); std::unique_ptr input_literal = - Literal::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); + LiteralUtil::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); @@ -1305,8 +1308,8 @@ TEST_P(R1ReduceWindowTest, DoIt) { ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); - b.ReduceWindowWithGeneralPadding( + CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, /*computation=*/computation, @@ -1324,7 +1327,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *Literal::CreateR1(*expected), + ComputeAndCompareLiteral(&b, *LiteralUtil::CreateR1(*expected), {input_arg.get()}, DefaultErrorSpec()); } @@ -1338,7 +1341,7 @@ INSTANTIATE_TEST_CASE_P( // results on the interpreter backend. class ReduceWindowTextTest : public HloTestBase {}; -TEST_F(ReduceWindowTextTest, R2General256x384) { +XLA_TEST_F(ReduceWindowTextTest, R2General256x384) { const string hlo_string = R"( HloModule R2Window mul { @@ -1355,7 +1358,7 @@ ENTRY R2Window { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } -TEST_F(ReduceWindowTextTest, R2General256x384Layout01) { +XLA_TEST_F(ReduceWindowTextTest, R2General256x384Layout01) { const string hlo_string = R"( HloModule R2Window mul { @@ -1372,7 +1375,7 @@ ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window= EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } -TEST_F(ReduceWindowTextTest, R2General2x5) { +XLA_TEST_F(ReduceWindowTextTest, R2General2x5) { const string hlo_string = R"( HloModule R2Window mul { @@ -1389,7 +1392,7 @@ ENTRY R2Window { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } -TEST_F(ReduceWindowTextTest, R2EffectiveScalar) { +XLA_TEST_F(ReduceWindowTextTest, R2EffectiveScalar) { const string hlo_string = R"( HloModule R2Window mul { @@ -1407,7 +1410,7 @@ ENTRY R2Window { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } -TEST_F(ReduceWindowTextTest, R3EffectiveScalar) { +XLA_TEST_F(ReduceWindowTextTest, R3EffectiveScalar) { const string hlo_string = R"( HloModule R3Window mul { @@ -1425,7 +1428,7 @@ ENTRY R3Window { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } -TEST_F(HloTestBase, ReduceWindowIdentity) { +XLA_TEST_F(HloTestBase, ReduceWindowIdentity) { const string hlo_string = R"( HloModule ReduceWindowIdentity identity.pad_to_reduce_window { @@ -1442,7 +1445,7 @@ ENTRY reduce-window-identity { EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); } -TEST_F(HloTestBase, ReduceWindowS32) { +XLA_TEST_F(HloTestBase, ReduceWindowS32) { const string hlo_string = R"( HloModule reduce-window @@ -1461,5 +1464,24 @@ ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] { EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); } +XLA_TEST_F(HloTestBase, ReduceWindowF16) { + const string hlo_string = R"( +HloModule reduce-window + +%identity.pad_to_reduce_window (param0: f16[], param1: f16[]) -> f16[] { + %param0 = f16[] parameter(0) + ROOT %param1 = f16[] parameter(1) +} + +ENTRY %reduce-window (parameter.0: f16[81,8], parameter.1: f16[]) -> f16[82,8] { + %parameter.0 = f16[81,8]{1,0} parameter(0) + %parameter.1 = f16[] parameter(1) + ROOT %reduce-window = f16[82,8]{1,0} reduce-window(f16[81,8]{1,0} %parameter.0, f16[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index 36d763b0f7f4267ede076c0b25cfaf9654e96e0d..d8914513819415368a628eab1f482f9644dd46b1 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -39,8 +39,8 @@ class ReplayTest : public ClientLibraryTestBase {}; TEST_F(ReplayTest, TwoPlusTwoReplay) { // Make 2+2 computation. XlaBuilder builder(TestName()); - auto two = builder.ConstantR0(2); - builder.Add(two, two); + auto two = ConstantR0(&builder, 2); + Add(two, two); XlaComputation computation = builder.Build().ConsumeValueOrDie(); // Serialize it out. @@ -70,9 +70,9 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) { XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { // Make computation. XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(S32, {}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(S32, {}), "y"); + Add(x, y); XlaComputation computation = builder.Build().ConsumeValueOrDie(); // Serialize it out. @@ -91,10 +91,10 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { // Run it. std::unique_ptr x_data = - client_->TransferToServer(*Literal::CreateR0(2)) + client_->TransferToServer(*LiteralUtil::CreateR0(2)) .ConsumeValueOrDie(); std::unique_ptr y_data = - client_->TransferToServer(*Literal::CreateR0(3)) + client_->TransferToServer(*LiteralUtil::CreateR0(3)) .ConsumeValueOrDie(); std::unique_ptr literal = client_ @@ -111,13 +111,13 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { // As above, but with map(+2) over some constant array. XlaBuilder plus_two_builder("plus two"); auto input = - plus_two_builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "input"); - plus_two_builder.Add(input, plus_two_builder.ConstantR0(2)); + Parameter(&plus_two_builder, 0, ShapeUtil::MakeShape(S32, {}), "input"); + Add(input, ConstantR0(&plus_two_builder, 2)); XlaComputation plus_two = plus_two_builder.Build().ConsumeValueOrDie(); XlaBuilder mapper_builder(TestName()); - auto original = mapper_builder.ConstantR1({1, 2, 3}); - mapper_builder.Map({original}, plus_two, {0}); + auto original = ConstantR1(&mapper_builder, {1, 2, 3}); + Map(&mapper_builder, {original}, plus_two, {0}); XlaComputation computation = mapper_builder.Build().ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index da1b588ec41cef711412367e89b2a9b1029bca71..368f5583c9ce3773e57b858ff7606f679346529a 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -22,9 +22,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -44,11 +44,11 @@ using ReshapeMotionTest = ClientLibraryTestBase; TEST_F(ReshapeMotionTest, ElementwiseOfReshapesWithNonSameInputShapes) { XlaBuilder builder(TestName()); - auto a = builder.ConstantR2({{2, 3, 5}, {7, 11, 13}}); - auto b = builder.ConstantR2({{17, 19}, {23, 29}, {31, 37}}); - auto c = builder.Reshape(a, {6}); - auto d = builder.Reshape(b, {6}); - auto e = builder.Mul(c, d); + auto a = ConstantR2(&builder, {{2, 3, 5}, {7, 11, 13}}); + auto b = ConstantR2(&builder, {{17, 19}, {23, 29}, {31, 37}}); + auto c = Reshape(a, {6}); + auto d = Reshape(b, {6}); + Mul(c, d); ComputeAndCompareR1(&builder, {34, 57, 115, 203, 341, 481}, {}); } diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index a4580cd71d46ad0a0186eddd51291f9c322b6f49..382d1b1ae741285dcd1f7761edb82a5c333887af 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -55,39 +55,39 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { XlaBuilder builder(TestName()); Array2D input_array(1, 1); input_array.Fill(1.0f); - auto input_literal = Literal::CreateR2FromArray2D(input_array); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); - auto expected_literal = Literal::CreateR1({1.0f}); + auto expected_literal = LiteralUtil::CreateR1({1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateR1({1.0f}); + auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{}); + Collapse(/*operand=*/parameter, /*dimensions=*/{}); - auto expected_literal = Literal::CreateR1({1.0f}); + auto expected_literal = LiteralUtil::CreateR1({1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateR1({1.0f}); + auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0}); - auto expected_literal = Literal::CreateR1({1.0f}); + auto expected_literal = LiteralUtil::CreateR1({1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -97,15 +97,15 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { XlaBuilder builder(TestName()); Array2D input_array(1, 1); input_array.Fill(1.0f); - auto input_literal = Literal::CreateR2FromArray2D(input_array); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); - auto reshape = builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, - /*new_sizes=*/{}); + auto reshape = Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{}); auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie(); - auto expected_literal = Literal::CreateR0(1.0f); + auto expected_literal = LiteralUtil::CreateR0(1.0f); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -113,63 +113,54 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR0(1.0f); + std::unique_ptr param0_literal = LiteralUtil::CreateR0(1.0f); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", &builder, ¶meter); - auto a = builder.Neg(parameter); - builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); + auto a = Neg(parameter); + Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); - auto expected_literal = Literal::CreateR1({-1.0f}); + auto expected_literal = LiteralUtil::CreateR1({-1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) { +XLA_TEST_P(ReshapeTest, Trivial0x3) { XlaBuilder builder(TestName()); Array2D input_array(0, 3); - auto input_literal = Literal::CreateR2FromArray2D(input_array); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); - auto expected_literal = Literal::CreateR1({}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = LiteralUtil::CreateR1({}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-05-15 -// with an incorrect result rank. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { +XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR2FromArray2D(Array2D(0, 3)); + LiteralUtil::CreateR2FromArray2D(Array2D(0, 3)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); - auto expected_literal = Literal::CreateR1({}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = LiteralUtil::CreateR1({}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) { +XLA_TEST_P(ReshapeTest, Trivial3x0) { XlaBuilder builder(TestName()); Array2D input_array(3, 0); - auto input_literal = Literal::CreateR2FromArray2D(input_array); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); - auto expected_literal = Literal::CreateR1({}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = LiteralUtil::CreateR1({}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -177,12 +168,12 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) { // Collapses a 2-dimensional row vector to 1 dimension. XLA_TEST_P(ReshapeTest, Trivial1x3) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateR2({{1.0f, 2.0f, 3.0f}}); + auto input_literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); - auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -190,30 +181,26 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) { // Collapses a 2-dimensional column vector to 1 dimension. XLA_TEST_P(ReshapeTest, Trivial3x1) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateR2({{1.0f}, {2.0f}, {3.0f}}); + auto input_literal = LiteralUtil::CreateR2({{1.0f}, {2.0f}, {3.0f}}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); - auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); + Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -// // Splits an empty vector into an empty matrix. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { +XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateR1({}); + auto input_literal = LiteralUtil::CreateR1({}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, - /*new_sizes=*/{2, 0}); - auto expected_literal = Literal::CreateR2({{}, {}}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0}, + /*new_sizes=*/{2, 0}); + auto expected_literal = LiteralUtil::CreateR2({{}, {}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -222,32 +209,28 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { XlaBuilder builder(TestName()); auto input_literal = - Literal::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, - /*new_sizes=*/{2, 3}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0}, + /*new_sizes=*/{2, 3}); auto expected_literal = - Literal::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); + LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -// // Transposes a 2x0 array to a 0x2 array. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { +XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(Array2D(0, 2)); + auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 2)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 0}); - auto expected_literal = Literal::CreateR2({{}, {}}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 0}); + auto expected_literal = LiteralUtil::CreateR2({{}, {}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -256,15 +239,15 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { XlaBuilder builder(TestName()); auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); - auto input_literal = Literal::CreateFromArray(*simple); + auto input_literal = LiteralUtil::CreateFromArray(*simple); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, - /*new_sizes=*/{3, 1}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{3, 1}); auto expected = ReferenceUtil::TransposeArray2D(*simple); - auto expected_literal = Literal::CreateFromArray(*expected); + auto expected_literal = LiteralUtil::CreateFromArray(*expected); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -273,32 +256,28 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { XLA_TEST_P(ReshapeTest, TransposeAsReshape) { XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto input_literal = Literal::CreateFromArray(*a4x3); + auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, - /*new_sizes=*/{3, 4}); + Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{3, 4}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); - auto expected_literal = Literal::CreateFromArray(*expected); + auto expected_literal = LiteralUtil::CreateFromArray(*expected); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -// // Transposes a 0x4 array with XlaBuilder::Transpose. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { +XLA_TEST_P(ReshapeTest, Transpose0x4) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(Array2D(0, 4)); + auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 4)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Transpose(parameter, {1, 0}); - auto expected_literal = Literal::CreateR2({{}, {}, {}, {}}); + Transpose(parameter, {1, 0}); + auto expected_literal = LiteralUtil::CreateR2({{}, {}, {}, {}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -307,49 +286,43 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { XLA_TEST_P(ReshapeTest, Transpose4x3) { XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto input_literal = Literal::CreateFromArray(*a4x3); + auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Transpose(parameter, {1, 0}); + Transpose(parameter, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); - auto expected_literal = Literal::CreateFromArray(*expected); + auto expected_literal = LiteralUtil::CreateFromArray(*expected); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -// // Reshapes an empty 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) { +XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(Array2D(6, 0)); + auto input_literal = LiteralUtil::CreateFromArray(Array2D(6, 0)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 3, 0, 0}); - auto expected_literal = Literal::CreateFromArray(Array4D(2, 3, 0, 0)); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 3, 0, 0}); + auto expected_literal = + LiteralUtil::CreateFromArray(Array4D(2, 3, 0, 0)); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) { +XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(Array4D(2, 3, 4, 0)); + auto input_literal = LiteralUtil::CreateFromArray(Array4D(2, 3, 4, 0)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{24, 0}); - auto expected_literal = Literal::CreateFromArray(Array2D(24, 0)); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{24, 0}); + auto expected_literal = LiteralUtil::CreateFromArray(Array2D(24, 0)); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -359,32 +332,28 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) { XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto input_literal = Literal::CreateFromArray(*a4x3); + auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 6}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 6}); auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); - auto expected_literal = Literal::CreateFromArray(*expected); + auto expected_literal = LiteralUtil::CreateFromArray(*expected); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } -// TODO(b/29185393): Make this work with the GPU backend. The GPU backend -// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 -// with an incorrect result rank. -// -XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) { +XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(Array2D(0, 6)); + auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 6)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, - /*new_sizes=*/{3, 0}); - auto expected_literal = Literal::CreateFromArray(Array2D(3, 0)); + Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{3, 0}); + auto expected_literal = LiteralUtil::CreateFromArray(Array2D(3, 0)); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -394,15 +363,15 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) { XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto input_literal = Literal::CreateFromArray(*a4x3); + auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, - /*new_sizes=*/{2, 6}); + Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{2, 6}); Array2D expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); - auto expected_literal = Literal::CreateFromArray(expected); + auto expected_literal = LiteralUtil::CreateFromArray(expected); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -420,13 +389,13 @@ static Array3D ArrayForDocR3Tests() { XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, - /*new_sizes=*/{24}); - auto expected_literal = Literal::CreateR1( + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, + /*new_sizes=*/{24}); + auto expected_literal = LiteralUtil::CreateR1( {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, @@ -435,33 +404,33 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, - /*new_sizes=*/{8, 3}); - auto expected_literal = Literal::CreateR2({{10, 11, 12}, - {15, 16, 17}, - {20, 21, 22}, - {25, 26, 27}, - {30, 31, 32}, - {35, 36, 37}, - {40, 41, 42}, - {45, 46, 47}}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, + /*new_sizes=*/{8, 3}); + auto expected_literal = LiteralUtil::CreateR2({{10, 11, 12}, + {15, 16, 17}, + {20, 21, 22}, + {25, 26, 27}, + {30, 31, 32}, + {35, 36, 37}, + {40, 41, 42}, + {45, 46, 47}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{24}); - auto expected_literal = Literal::CreateR1( + Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{24}); + auto expected_literal = LiteralUtil::CreateR1( {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, @@ -470,33 +439,33 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{8, 3}); - auto expected_literal = Literal::CreateR2({{10, 20, 30}, - {40, 11, 21}, - {31, 41, 12}, - {22, 32, 42}, - {15, 25, 35}, - {45, 16, 26}, - {36, 46, 17}, - {27, 37, 47}}); + Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{8, 3}); + auto expected_literal = LiteralUtil::CreateR2({{10, 20, 30}, + {40, 11, 21}, + {31, 41, 12}, + {22, 32, 42}, + {15, 25, 35}, + {45, 16, 26}, + {36, 46, 17}, + {27, 37, 47}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{2, 6, 2}); - auto expected_literal = Literal::CreateR3( + Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{2, 6, 2}); + auto expected_literal = LiteralUtil::CreateR3( {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}}, {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, @@ -523,12 +492,12 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { Array4D t2x2x2x3(2, 2, 2, 3); auto filler2x3 = MakeLinspaceArray2D(1.0f, 6.0f, 2, 3); t2x2x2x3.FillWithYX(*filler2x3); - auto input_literal = Literal::CreateFromArray(t2x2x2x3); + auto input_literal = LiteralUtil::CreateFromArray(t2x2x2x3); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); - auto expected_literal = Literal::CreateR2( + Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); + auto expected_literal = LiteralUtil::CreateR2( {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); @@ -548,15 +517,15 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { t(1, 0, 0, 1) = 5; t(1, 0, 1, 0) = 6; t(1, 0, 1, 1) = 7; - auto input_literal = Literal::CreateFromArray(t); + auto input_literal = LiteralUtil::CreateFromArray(t); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{2, 4}); + Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{2, 4}); auto expected_literal = - Literal::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); + LiteralUtil::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -575,9 +544,9 @@ XLA_TEST_P(ReshapeTest, ToScalar) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, ¶meter); - b.Reshape(parameter, dimensions, {}); + Reshape(parameter, dimensions, {}); - auto expected_literal = Literal::CreateR0(83.0f); + auto expected_literal = LiteralUtil::CreateR0(83.0f); ComputeAndCompareLiteral(&b, *expected_literal, {input.get()}, zero_error_spec_); } @@ -585,11 +554,11 @@ XLA_TEST_P(ReshapeTest, ToScalar) { XLA_TEST_P(ReshapeTest, BadDimensions) { XlaBuilder b(TestName()); - auto input_literal = Literal::CreateR1({1.0f}); + auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, ¶meter); - b.Reshape(parameter, {}, {}); + Reshape(parameter, {}, {}); EXPECT_THAT( ExecuteToString(&b, {}), ::testing::HasSubstr("not a permutation of the operand dimensions")); @@ -597,11 +566,11 @@ XLA_TEST_P(ReshapeTest, BadDimensions) { XLA_TEST_P(ReshapeTest, BadNewSizes) { XlaBuilder b(TestName()); - auto input_literal = Literal::CreateR1({1.0f, 2.0f}); + auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, ¶meter); - b.Reshape(parameter, {1}, {}); + Reshape(parameter, {1}, {}); EXPECT_THAT(ExecuteToString(&b, {}), ::testing::HasSubstr("mismatched element counts")); } @@ -609,7 +578,8 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) { XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { XlaBuilder builder(TestName()); // clang-format off - auto input_literal = Literal::CreateR4FromArray4DWithLayout(Array4D{ + auto input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + Array4D{ { { {0, 1}, @@ -637,7 +607,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); + Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); Array2D expected_array({ {0, 1, 2, 3, 100, 101, 102, 103}, @@ -654,16 +624,16 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) .ConsumeValueOrDie(); std::unique_ptr expected = - Literal::CreateR2FromArray2D(expected_array); + LiteralUtil::CreateR2FromArray2D(expected_array); if (use_bfloat16()) { - expected = Literal::ConvertF32ToBF16(*expected); + expected = LiteralUtil::ConvertF32ToBF16(*expected); } EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); } XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { XlaBuilder builder(TestName()); - std::unique_ptr input_literal = Literal::CreateR2({ + std::unique_ptr input_literal = LiteralUtil::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, @@ -671,10 +641,10 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); + Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off - auto expected_literal = Literal::CreateR4({ + auto expected_literal = LiteralUtil::CreateR4({ {{{0, 1, 2, 3}}, {{4, 5, 6, 7}}}, {{{100, 101, 102, 103}}, @@ -690,7 +660,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { // Tests R2->R4 reshape with the reshape dimensions {1, 0}. XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { XlaBuilder builder(TestName()); - std::unique_ptr input_literal = Literal::CreateR2({ + std::unique_ptr input_literal = LiteralUtil::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, @@ -698,10 +668,10 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); + Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off - auto expected_literal = Literal::CreateR4({ + auto expected_literal = LiteralUtil::CreateR4({ {{{0, 100, 200, 1}}, {{101, 201, 2, 102}}}, {{{202, 3, 103, 203}}, @@ -723,15 +693,15 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); + Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); std::unique_ptr expected = - Literal::ReshapeSlice({2, 1}, {1, 0}, *input_literal); + LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, *input_literal); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } @@ -745,15 +715,15 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); + Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); std::unique_ptr expected = - Literal::ReshapeSlice({4, 2}, {1, 0}, *input_literal); + LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, *input_literal); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } @@ -768,20 +738,20 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, - /*new_sizes=*/{5, 60}); + Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, + /*new_sizes=*/{5, 60}); Array2D expected_array(5, 60); input.Each([&](tensorflow::gtl::ArraySlice indices, float* cell) { expected_array(indices[0], indices[2] * 30 + indices[1] * 3 + indices[3]) = *cell; }); - auto expected = Literal::CreateR2FromArray2D(expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } @@ -795,13 +765,13 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, - /*new_sizes=*/{7, 2, 3, 5}); + Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, + /*new_sizes=*/{7, 2, 3, 5}); XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecutionOptions execution_options = execution_options_; @@ -817,7 +787,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. if (use_bfloat16()) { - auto expected = Literal::ConvertF32ToBF16(*input_literal); + auto expected = LiteralUtil::ConvertF32ToBF16(*input_literal); EXPECT_EQ(expected->data(), output_literal->data()); } else { EXPECT_EQ(input_literal->data(), output_literal->data()); @@ -826,21 +796,21 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { XlaBuilder builder(TestName()); - auto literal_1x2x3x4 = Literal::CreateR4( + auto literal_1x2x3x4 = LiteralUtil::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{1, 2, 3, 4}); + Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{1, 2, 3, 4}); ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()}); } XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { - auto literal_1x2x3x4 = Literal::CreateR4( + auto literal_1x2x3x4 = LiteralUtil::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); @@ -848,11 +818,11 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, - /*new_sizes=*/{2, 4, 3, 1}); + Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, + /*new_sizes=*/{2, 4, 3, 1}); // clang-format off - auto expected_2x4x3x1 = Literal::CreateR4( + auto expected_2x4x3x1 = LiteralUtil::CreateR4( {{{{1}, {5}, {9}}, {{2}, {6}, {10}}, {{3}, {7}, {11}}, @@ -876,17 +846,17 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, - /*new_sizes=*/new_bounds); + Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = - Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -905,17 +875,17 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, - /*new_sizes=*/new_bounds); + Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = - Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -934,17 +904,17 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, - /*new_sizes=*/new_bounds); + Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = - Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -964,17 +934,17 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, - /*new_sizes=*/new_bounds); + Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = - Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -993,17 +963,17 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({0, 1, 2, 3})); XlaBuilder builder(TestName()); XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); - builder.Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, - /*new_sizes=*/new_bounds); + Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = - Literal::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal) + LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal) ->Relayout(input_literal->shape().layout()); // Specify the requested output shape explicitly to ensure that this reshape diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index e7bd142dc9ddefbd8bebfb77d72218d662645c31..41e49b4003236d55d85592315652a0ddefd5c485 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -82,12 +82,12 @@ TEST_P(FloatReverseTest, Reverses) { std::vector input_vector( ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims))); std::iota(input_vector.begin(), input_vector.end(), 0.0); - auto r1_literal = Literal::CreateR1(input_vector); + auto r1_literal = LiteralUtil::CreateR1(input_vector); auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie(); XlaBuilder builder(TestName()); auto a = AddParam(*input_literal, &builder); - builder.Rev(a, spec.reversal); + Rev(a, spec.reversal); std::unique_ptr expected = input_literal->CloneToUnique(); std::vector output_indices(spec.input_dims.size()); @@ -127,7 +127,7 @@ XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) { }}); // clang-format on - b.Rev(b.ConstantR4FromArray4D(input), {0, 3}); + Rev(ConstantR4FromArray4D(&b, input), {0, 3}); // clang-format off Array4D expected({{ @@ -163,7 +163,7 @@ TEST_F(ReverseTest, Reverse4DFloatArrayOnDim01) { }); // clang-format on - b.Rev(b.ConstantR4FromArray4D(input), {0, 1}); + Rev(ConstantR4FromArray4D(&b, input), {0, 1}); // clang-format off Array4D expected({ diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index 7cfca781acda15879075f4386c2096e537877aac..a620fe19085d98c8b6642b25b159d6c2308bdae2 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/packed_literal_reader.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc index f334a8c1318a59bbfdd27dd1a63ed162600089ce..a8193c2eac05ba4f0df339909f3e82a28ac35253 100644 --- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -46,61 +46,62 @@ class RoundTripTransferTest : public ClientLibraryTestBase { }; TEST_F(RoundTripTransferTest, R0S32) { - RoundTripTest(*Literal::CreateR0(42)); + RoundTripTest(*LiteralUtil::CreateR0(42)); } TEST_F(RoundTripTransferTest, R0F32) { - RoundTripTest(*Literal::CreateR0(42.0)); + RoundTripTest(*LiteralUtil::CreateR0(42.0)); } TEST_F(RoundTripTransferTest, R1F32_Len0) { - RoundTripTest(*Literal::CreateR1({})); + RoundTripTest(*LiteralUtil::CreateR1({})); } TEST_F(RoundTripTransferTest, R1F32_Len2) { - RoundTripTest(*Literal::CreateR1({42.0, 64.0})); + RoundTripTest(*LiteralUtil::CreateR1({42.0, 64.0})); } TEST_F(RoundTripTransferTest, R1F32_Len256) { std::vector values(256); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*Literal::CreateR1(values)); + RoundTripTest(*LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1024) { std::vector values(1024); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*Literal::CreateR1(values)); + RoundTripTest(*LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1025) { std::vector values(1025); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*Literal::CreateR1(values)); + RoundTripTest(*LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len4096) { std::vector values(4096); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*Literal::CreateR1(values)); + RoundTripTest(*LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R2F32_Len10x0) { - RoundTripTest(*Literal::CreateR2FromArray2D(Array2D(10, 0))); + RoundTripTest( + *LiteralUtil::CreateR2FromArray2D(Array2D(10, 0))); } TEST_F(RoundTripTransferTest, R2F32_Len2x2) { - RoundTripTest(*Literal::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); + RoundTripTest(*LiteralUtil::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); } TEST_F(RoundTripTransferTest, R3F32) { RoundTripTest( - *Literal::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); + *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); } TEST_F(RoundTripTransferTest, R4F32) { - RoundTripTest(*Literal::CreateR4({{ + RoundTripTest(*LiteralUtil::CreateR4({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, @@ -108,33 +109,36 @@ TEST_F(RoundTripTransferTest, R4F32) { } TEST_F(RoundTripTransferTest, EmptyTuple) { - RoundTripTest(*Literal::MakeTuple({})); + RoundTripTest(*LiteralUtil::MakeTuple({})); } TEST_F(RoundTripTransferTest, TupleOfR1F32) { - RoundTripTest(*Literal::MakeTuple({Literal::CreateR1({1, 2}).get(), - Literal::CreateR1({3, 4}).get()})); + RoundTripTest( + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), + LiteralUtil::CreateR1({3, 4}).get()})); } TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) { - RoundTripTest(*Literal::MakeTuple({Literal::CreateR1({}).get(), - Literal::CreateR1({3, 4}).get()})); + RoundTripTest( + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({}).get(), + LiteralUtil::CreateR1({3, 4}).get()})); } TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) { - RoundTripTest(*Literal::MakeTuple({Literal::CreateR0(1.0).get(), - Literal::CreateR1({2, 3}).get()})); + RoundTripTest( + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(1.0).get(), + LiteralUtil::CreateR1({2, 3}).get()})); } // Below two tests are added to identify the cost of large data transfers. TEST_F(RoundTripTransferTest, R2F32_Large) { - RoundTripTest(*Literal::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); + RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); } TEST_F(RoundTripTransferTest, R4F32_Large) { Array4D array4d(2, 2, 256, 256); array4d.FillWithMultiples(1.0f); - RoundTripTest(*Literal::CreateR4FromArray4D(array4d)); + RoundTripTest(*LiteralUtil::CreateR4FromArray4D(array4d)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 308d3fc78a51e63c0e3db8c0cda18caf11f665bd..e42c71eb284deb2e50d6ea4b47fa707e4bc14ffc 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -19,8 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -44,74 +45,75 @@ class ScalarComputationsTest : public ClientLibraryTestBase { protected: // A template for building and running a binary comparison test. template - void TestCompare( - NativeT lhs, NativeT rhs, bool expected, - XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&, - tensorflow::gtl::ArraySlice)) { + void TestCompare(NativeT lhs, NativeT rhs, bool expected, + std::function)> + op) { XlaBuilder builder(TestName()); - XlaOp lhs_op = builder.ConstantR0(lhs); - XlaOp rhs_op = builder.ConstantR0(rhs); - XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); + XlaOp lhs_op = ConstantR0(&builder, lhs); + XlaOp rhs_op = ConstantR0(&builder, rhs); + op(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } template void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected, - XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&, - tensorflow::gtl::ArraySlice)) { + std::function)> + op) { XlaBuilder builder(TestName()); - XlaOp lhs_op = builder.ConstantR0(lhs); - XlaOp rhs_op = builder.ConstantR0(rhs); - XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); + XlaOp lhs_op = ConstantR0(&builder, lhs); + XlaOp rhs_op = ConstantR0(&builder, rhs); + op(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } }; XLA_TEST_F(ScalarComputationsTest, ReturnScalarF32) { XlaBuilder builder(TestName()); - builder.ConstantR0(2.1f); + ConstantR0(&builder, 2.1f); ComputeAndCompareR0(&builder, 2.1f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, NegateScalarF32) { XlaBuilder builder(TestName()); - builder.Neg(builder.ConstantR0(2.1f)); + Neg(ConstantR0(&builder, 2.1f)); ComputeAndCompareR0(&builder, -2.1f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, NegateScalarS32) { XlaBuilder builder(TestName()); - builder.Neg(builder.ConstantR0(2)); + Neg(ConstantR0(&builder, 2)); ComputeAndCompareR0(&builder, -2, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF32) { XlaBuilder builder(TestName()); - builder.Add(builder.ConstantR0(2.1f), builder.ConstantR0(5.5f)); + Add(ConstantR0(&builder, 2.1f), ConstantR0(&builder, 5.5f)); ComputeAndCompareR0(&builder, 7.6f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS32) { XlaBuilder builder(TestName()); - builder.Add(builder.ConstantR0(2), builder.ConstantR0(5)); + Add(ConstantR0(&builder, 2), ConstantR0(&builder, 5)); ComputeAndCompareR0(&builder, 7, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU32) { XlaBuilder builder(TestName()); - builder.Add(builder.ConstantR0(35), builder.ConstantR0(57)); + Add(ConstantR0(&builder, 35), ConstantR0(&builder, 57)); ComputeAndCompareR0(&builder, 92, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU8) { XlaBuilder builder(TestName()); - builder.Add(builder.ConstantR0(35), builder.ConstantR0(57)); + Add(ConstantR0(&builder, 35), ConstantR0(&builder, 57)); ComputeAndCompareR0(&builder, 92, {}); } @@ -120,7 +122,7 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU64) { XlaBuilder builder(TestName()); const uint64 a = static_cast(1) << 63; const uint64 b = a + 1; - builder.Add(builder.ConstantR0(a), builder.ConstantR0(b)); + Add(ConstantR0(&builder, a), ConstantR0(&builder, b)); ComputeAndCompareR0(&builder, a + b, {}); } @@ -129,40 +131,39 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS64) { XlaBuilder builder(TestName()); const int64 a = static_cast(1) << 62; const int64 b = a - 1; - builder.Add(builder.ConstantR0(a), builder.ConstantR0(b)); + Add(ConstantR0(&builder, a), ConstantR0(&builder, b)); ComputeAndCompareR0(&builder, a + b, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) { XlaBuilder builder(TestName()); - builder.Add(builder.ConstantR0(0.25), - builder.ConstantR0(3.5)); + Add(ConstantR0(&builder, 0.25), ConstantR0(&builder, 3.5)); ComputeAndCompareR0(&builder, 3.75, {}); } XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) { XlaBuilder builder(TestName()); - builder.Sub(builder.ConstantR0(2.1f), builder.ConstantR0(5.5f)); + Sub(ConstantR0(&builder, 2.1f), ConstantR0(&builder, 5.5f)); ComputeAndCompareR0(&builder, -3.4f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) { XlaBuilder builder(TestName()); - builder.Sub(builder.ConstantR0(2), builder.ConstantR0(5)); + Sub(ConstantR0(&builder, 2), ConstantR0(&builder, 5)); ComputeAndCompareR0(&builder, -3, {}); } XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { XlaBuilder builder(TestName()); - auto a = builder.Parameter(0, ShapeUtil::MakeShape(S64, {}), "a"); - builder.ConvertElementType(a, F32); + auto a = Parameter(&builder, 0, ShapeUtil::MakeShape(S64, {}), "a"); + ConvertElementType(a, F32); int64 value = 3LL << 35; - std::unique_ptr a_literal = Literal::CreateR0(value); + std::unique_ptr a_literal = LiteralUtil::CreateR0(value); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); ComputeAndCompareR0(&builder, static_cast(value), @@ -171,9 +172,8 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) { XlaBuilder builder(TestName()); - builder.Mul(builder.Mul(builder.ConstantR0(2.1f), - builder.ConstantR0(5.5f)), - builder.ConstantR0(0.5f)); + Mul(Mul(ConstantR0(&builder, 2.1f), ConstantR0(&builder, 5.5f)), + ConstantR0(&builder, 0.5f)); ComputeAndCompareR0(&builder, 5.775f, {}, error_spec_); } @@ -190,7 +190,7 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsS32) { for (int32 x : data) { for (int32 y : data) { XlaBuilder builder(TestName()); - builder.Mul(builder.ConstantR0(x), builder.ConstantR0(y)); + Mul(ConstantR0(&builder, x), ConstantR0(&builder, y)); // Signed integer overflow is undefined behavior in C++. Convert the input // integers to unsigned, perform the multiplication unsigned, and convert @@ -209,7 +209,7 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) { for (uint32 x : data) { for (uint32 y : data) { XlaBuilder builder(TestName()); - builder.Mul(builder.ConstantR0(x), builder.ConstantR0(y)); + Mul(ConstantR0(&builder, x), ConstantR0(&builder, y)); uint32 expected = x * y; ComputeAndCompareR0(&builder, expected, {}); @@ -219,18 +219,17 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) { XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { XlaBuilder builder(TestName()); - builder.Mul( - builder.Mul(builder.ConstantR0(2), builder.ConstantR0(5)), - builder.ConstantR0(1)); + Mul(Mul(ConstantR0(&builder, 2), ConstantR0(&builder, 5)), + ConstantR0(&builder, 1)); ComputeAndCompareR0(&builder, 10, {}); } XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR0(2.1f); - std::unique_ptr b_literal = Literal::CreateR0(5.5f); - std::unique_ptr c_literal = Literal::CreateR0(0.5f); + std::unique_ptr a_literal = LiteralUtil::CreateR0(2.1f); + std::unique_ptr b_literal = LiteralUtil::CreateR0(5.5f); + std::unique_ptr c_literal = LiteralUtil::CreateR0(0.5f); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); @@ -239,10 +238,10 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { std::unique_ptr c_data = client_->TransferToServer(*c_literal).ConsumeValueOrDie(); - XlaOp a = builder.Parameter(0, a_literal->shape(), "a"); - XlaOp b = builder.Parameter(1, b_literal->shape(), "b"); - XlaOp c = builder.Parameter(2, c_literal->shape(), "c"); - builder.Mul(builder.Mul(a, b), c); + XlaOp a = Parameter(&builder, 0, a_literal->shape(), "a"); + XlaOp b = Parameter(&builder, 1, b_literal->shape(), "b"); + XlaOp c = Parameter(&builder, 2, c_literal->shape(), "c"); + Mul(Mul(a, b), c); ComputeAndCompareR0(&builder, 5.775f, {a_data.get(), b_data.get(), c_data.get()}, @@ -251,14 +250,14 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) { XlaBuilder builder(TestName()); - builder.Div(builder.ConstantR0(5.0f), builder.ConstantR0(2.5f)); + Div(ConstantR0(&builder, 5.0f), ConstantR0(&builder, 2.5f)); ComputeAndCompareR0(&builder, 2.0f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) { XlaBuilder builder(TestName()); - builder.Rem(builder.ConstantR0(2.5f), builder.ConstantR0(5.0f)); + Rem(ConstantR0(&builder, 2.5f), ConstantR0(&builder, 5.0f)); ComputeAndCompareR0(&builder, 2.5f, {}, error_spec_); } @@ -281,8 +280,8 @@ class DivS32Test : public ClientLibraryTestBase, XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) { DivS32Params p = GetParam(); XlaBuilder builder(TestName()); - builder.Div(builder.ConstantR0(p.dividend), - builder.ConstantR0(p.divisor)); + Div(ConstantR0(&builder, p.dividend), + ConstantR0(&builder, p.divisor)); ComputeAndCompareR0(&builder, p.quotient, {}); } @@ -290,8 +289,8 @@ XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) { XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) { DivS32Params p = GetParam(); XlaBuilder builder(TestName()); - builder.Rem(builder.ConstantR0(p.dividend), - builder.ConstantR0(p.divisor)); + Rem(ConstantR0(&builder, p.dividend), + ConstantR0(&builder, p.divisor)); ComputeAndCompareR0(&builder, p.remainder, {}); } @@ -305,7 +304,7 @@ XLA_TEST_P(DivS32Test, DivideTwoScalarsNonConstS32) { CreateR0Parameter(p.dividend, 0, "dividend", &builder, ÷nd); auto divisord = CreateR0Parameter(p.divisor, 1, "divisor", &builder, &divisor); - builder.Div(dividend, divisor); + Div(dividend, divisor); ComputeAndCompareR0(&builder, p.quotient, {dividendd.get(), divisord.get()}); @@ -320,7 +319,7 @@ XLA_TEST_P(DivS32Test, RemainderTwoScalarsNonConstDivisorS32) { CreateR0Parameter(p.dividend, 0, "dividend", &builder, ÷nd); auto divisord = CreateR0Parameter(p.divisor, 1, "divisor", &builder, &divisor); - builder.Rem(dividend, divisor); + Rem(dividend, divisor); ComputeAndCompareR0(&builder, p.remainder, {dividendd.get(), divisord.get()}); @@ -367,18 +366,18 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { XlaBuilder builder(TestName()); XlaOp dividend = - builder.Parameter(0, ShapeUtil::MakeShape(U32, {}), "dividend"); + Parameter(&builder, 0, ShapeUtil::MakeShape(U32, {}), "dividend"); XlaOp divisor = - builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); - builder.Div(dividend, divisor); + Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {}), "divisor"); + Div(dividend, divisor); TF_ASSERT_OK_AND_ASSIGN(div_computation, builder.Build()); } for (uint32 divisor : vals) { if (divisor != 0) { for (uint32 dividend : vals) { - auto dividend_literal = Literal::CreateR0(dividend); - auto divisor_literal = Literal::CreateR0(divisor); + auto dividend_literal = LiteralUtil::CreateR0(dividend); + auto divisor_literal = LiteralUtil::CreateR0(divisor); TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, client_->TransferToServer(*dividend_literal)); TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, @@ -389,7 +388,8 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { {dividend_data.get(), divisor_data.get()}, &execution_options_) .ConsumeValueOrDie(); - auto expected_literal = Literal::CreateR0(dividend / divisor); + auto expected_literal = + LiteralUtil::CreateR0(dividend / divisor); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } @@ -408,18 +408,18 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { XlaBuilder builder(TestName()); XlaOp dividend = - builder.Parameter(0, ShapeUtil::MakeShape(U32, {}), "dividend"); + Parameter(&builder, 0, ShapeUtil::MakeShape(U32, {}), "dividend"); XlaOp divisor = - builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); - builder.Rem(dividend, divisor); + Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {}), "divisor"); + Rem(dividend, divisor); TF_ASSERT_OK_AND_ASSIGN(rem_computation, builder.Build()); } for (uint32 divisor : vals) { if (divisor != 0) { for (uint32 dividend : vals) { - auto dividend_literal = Literal::CreateR0(dividend); - auto divisor_literal = Literal::CreateR0(divisor); + auto dividend_literal = LiteralUtil::CreateR0(dividend); + auto divisor_literal = LiteralUtil::CreateR0(divisor); TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, client_->TransferToServer(*dividend_literal)); TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, @@ -430,7 +430,8 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { {dividend_data.get(), divisor_data.get()}, &execution_options_) .ConsumeValueOrDie(); - auto expected_literal = Literal::CreateR0(dividend % divisor); + auto expected_literal = + LiteralUtil::CreateR0(dividend % divisor); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } @@ -439,10 +440,10 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { XlaBuilder builder(TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); - builder.Rem(x, builder.ConstantR0(80000)); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x"); + Rem(x, ConstantR0(&builder, 80000)); - std::unique_ptr literal = Literal::CreateR0(87919); + std::unique_ptr literal = LiteralUtil::CreateR0(87919); TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal)); ComputeAndCompareR0(&builder, 7919, {input_data.get()}); } @@ -451,15 +452,15 @@ XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) { XlaBuilder builder(TestName()); // This verifies 0xFFFFFFFE / 2 = 0x7FFFFFFF. If XLA incorrectly treated U32 // as S32, it would output -2 / 2 = -1 (0xFFFFFFFF). - builder.Div(builder.ConstantR0(0xFFFFFFFE), - builder.ConstantR0(2)); + Div(ConstantR0(&builder, 0xFFFFFFFE), + ConstantR0(&builder, 2)); ComputeAndCompareR0(&builder, 0x7FFFFFFF, {}); } XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) { XlaBuilder builder(TestName()); - builder.Rem(builder.ConstantR0(11), builder.ConstantR0(3)); + Rem(ConstantR0(&builder, 11), ConstantR0(&builder, 3)); ComputeAndCompareR0(&builder, 2, {}); } @@ -468,7 +469,7 @@ XLA_TEST_F(ScalarComputationsTest, AndBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { XlaBuilder builder(TestName()); - builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); + And(ConstantR0(&builder, x), ConstantR0(&builder, y)); ComputeAndCompareR0(&builder, x && y, {}); } @@ -479,7 +480,7 @@ XLA_TEST_F(ScalarComputationsTest, AndS32) { for (int32 x : {0, 8}) { for (int32 y : {1, -16}) { XlaBuilder builder(TestName()); - builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); + And(ConstantR0(&builder, x), ConstantR0(&builder, y)); ComputeAndCompareR0(&builder, x & y, {}); } @@ -490,7 +491,7 @@ XLA_TEST_F(ScalarComputationsTest, AndU32) { for (uint32 x : {0, 8}) { for (uint32 y : {1, 16}) { XlaBuilder builder(TestName()); - builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); + And(ConstantR0(&builder, x), ConstantR0(&builder, y)); ComputeAndCompareR0(&builder, x & y, {}); } @@ -501,7 +502,7 @@ XLA_TEST_F(ScalarComputationsTest, OrBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { XlaBuilder builder(TestName()); - builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); + Or(ConstantR0(&builder, x), ConstantR0(&builder, y)); ComputeAndCompareR0(&builder, x || y, {}); } @@ -512,7 +513,7 @@ XLA_TEST_F(ScalarComputationsTest, OrS32) { for (int32 x : {0, 8}) { for (int32 y : {1, -16}) { XlaBuilder builder(TestName()); - builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); + Or(ConstantR0(&builder, x), ConstantR0(&builder, y)); ComputeAndCompareR0(&builder, x | y, {}); } @@ -523,7 +524,7 @@ XLA_TEST_F(ScalarComputationsTest, OrU32) { for (uint32 x : {0, 8}) { for (uint32 y : {1, 16}) { XlaBuilder builder(TestName()); - builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); + Or(ConstantR0(&builder, x), ConstantR0(&builder, y)); ComputeAndCompareR0(&builder, x | y, {}); } @@ -533,7 +534,7 @@ XLA_TEST_F(ScalarComputationsTest, OrU32) { XLA_TEST_F(ScalarComputationsTest, NotBool) { for (bool x : {false, true}) { XlaBuilder builder(TestName()); - builder.Not(builder.ConstantR0(x)); + Not(ConstantR0(&builder, x)); ComputeAndCompareR0(&builder, !x, {}); } @@ -542,7 +543,7 @@ XLA_TEST_F(ScalarComputationsTest, NotBool) { XLA_TEST_F(ScalarComputationsTest, NotS32) { for (int32 x : {-1, 0, 1}) { XlaBuilder builder(TestName()); - builder.Not(builder.ConstantR0(x)); + Not(ConstantR0(&builder, x)); ComputeAndCompareR0(&builder, ~x, {}); } @@ -551,7 +552,7 @@ XLA_TEST_F(ScalarComputationsTest, NotS32) { XLA_TEST_F(ScalarComputationsTest, NotU32) { for (uint32 x : {0, 1, 2}) { XlaBuilder builder(TestName()); - builder.Not(builder.ConstantR0(x)); + Not(ConstantR0(&builder, x)); ComputeAndCompareR0(&builder, ~x, {}); } @@ -559,18 +560,18 @@ XLA_TEST_F(ScalarComputationsTest, NotU32) { XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) { XlaBuilder builder(TestName()); - builder.Select(builder.ConstantR0(true), // The predicate. - builder.ConstantR0(123.0f), // The value on true. - builder.ConstantR0(42.0f)); // The value on false. + Select(ConstantR0(&builder, true), // The predicate. + ConstantR0(&builder, 123.0f), // The value on true. + ConstantR0(&builder, 42.0f)); // The value on false. ComputeAndCompareR0(&builder, 123.0f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) { XlaBuilder builder(TestName()); - builder.Select(builder.ConstantR0(false), // The predicate. - builder.ConstantR0(123.0f), // The value on true. - builder.ConstantR0(42.0f)); // The value on false. + Select(ConstantR0(&builder, false), // The predicate. + ConstantR0(&builder, 123.0f), // The value on true. + ConstantR0(&builder, 42.0f)); // The value on false. ComputeAndCompareR0(&builder, 42.0f, {}, error_spec_); } @@ -579,313 +580,311 @@ XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) { // templatized comparison tests. XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) { XlaBuilder builder(TestName()); - builder.Gt(builder.ConstantR0(2.0f), builder.ConstantR0(1.0f)); + Gt(ConstantR0(&builder, 2.0f), ConstantR0(&builder, 1.0f)); ComputeAndCompareR0(&builder, true, {}); } // S32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqS32Greater) { - TestCompare(2, 1, false, &XlaBuilder::Eq); + TestCompare(2, 1, false, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareEqS32Equal) { - TestCompare(3, 3, true, &XlaBuilder::Eq); + TestCompare(3, 3, true, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeS32) { - TestCompare(2, 1, true, &XlaBuilder::Ne); + TestCompare(2, 1, true, &Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeS32) { - TestCompare(2, 1, true, &XlaBuilder::Ge); + TestCompare(2, 1, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtS32) { - TestCompare(1, 5, false, &XlaBuilder::Gt); + TestCompare(1, 5, false, &Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeS32) { - TestCompare(2, 1, false, &XlaBuilder::Le); + TestCompare(2, 1, false, &Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtS32) { - TestCompare(9, 7, false, &XlaBuilder::Lt); + TestCompare(9, 7, false, &Lt); TestCompare(std::numeric_limits::min(), - std::numeric_limits::max(), true, &XlaBuilder::Lt); + std::numeric_limits::max(), true, &Lt); } // U32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqU32False) { - TestCompare(2, 1, false, &XlaBuilder::Eq); + TestCompare(2, 1, false, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeU32) { - TestCompare(2, 1, true, &XlaBuilder::Ne); + TestCompare(2, 1, true, &Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeU32Greater) { - TestCompare(2, 1, true, &XlaBuilder::Ge); + TestCompare(2, 1, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeU32Equal) { - TestCompare(3, 3, true, &XlaBuilder::Ge); + TestCompare(3, 3, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtU32) { - TestCompare(1, 5, false, &XlaBuilder::Gt); - TestCompare(5, 5, false, &XlaBuilder::Gt); - TestCompare(5, 1, true, &XlaBuilder::Gt); + TestCompare(1, 5, false, &Gt); + TestCompare(5, 5, false, &Gt); + TestCompare(5, 1, true, &Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeU32) { - TestCompare(2, 1, false, &XlaBuilder::Le); + TestCompare(2, 1, false, &Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtU32) { - TestCompare(9, 7, false, &XlaBuilder::Lt); - TestCompare(0, std::numeric_limits::max(), true, - &XlaBuilder::Lt); + TestCompare(9, 7, false, &Lt); + TestCompare(0, std::numeric_limits::max(), true, &Lt); } // F32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqF32False) { - TestCompare(2.0, 1.3, false, &XlaBuilder::Eq); + TestCompare(2.0, 1.3, false, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeF32) { - TestCompare(2.0, 1.3, true, &XlaBuilder::Ne); + TestCompare(2.0, 1.3, true, &Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32Greater) { - TestCompare(2.0, 1.9, true, &XlaBuilder::Ge); + TestCompare(2.0, 1.9, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32Equal) { - TestCompare(3.5, 3.5, true, &XlaBuilder::Ge); + TestCompare(3.5, 3.5, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtF32) { - TestCompare(1.0, 5.2, false, &XlaBuilder::Gt); + TestCompare(1.0, 5.2, false, &Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeF32) { - TestCompare(2.0, 1.2, false, &XlaBuilder::Le); + TestCompare(2.0, 1.2, false, &Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32) { - TestCompare(9.0, 7.2, false, &XlaBuilder::Lt); + TestCompare(9.0, 7.2, false, &Lt); } // F32 comparisons with exceptional values. The test names encode the // left/right operands at the end, and use Minf and Mzero for -inf and -0.0. XLA_TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) { - TestCompare(-INFINITY, -0.0, true, &XlaBuilder::Lt); + TestCompare(-INFINITY, -0.0, true, &Lt); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) { // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. - TestCompare(-0.0, 0.0, false, &XlaBuilder::Lt); + TestCompare(-0.0, 0.0, false, &Lt); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) { - TestCompare(0.0, INFINITY, true, &XlaBuilder::Lt); + TestCompare(0.0, INFINITY, true, &Lt); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) { - TestCompare(-INFINITY, -0.0, false, &XlaBuilder::Ge); + TestCompare(-INFINITY, -0.0, false, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) { // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. - TestCompare(-0.0, 0.0, true, &XlaBuilder::Ge); + TestCompare(-0.0, 0.0, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) { - TestCompare(0.0, INFINITY, false, &XlaBuilder::Ge); + TestCompare(0.0, INFINITY, false, &Ge); } XLA_TEST_F(ScalarComputationsTest, ExpScalar) { XlaBuilder builder(TestName()); - builder.Exp(builder.ConstantR0(2.0f)); + Exp(ConstantR0(&builder, 2.0f)); ComputeAndCompareR0(&builder, 7.3890562, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, LogScalar) { XlaBuilder builder("log"); - builder.Log(builder.ConstantR0(2.0f)); + Log(ConstantR0(&builder, 2.0f)); ComputeAndCompareR0(&builder, 0.6931471, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, TanhScalar) { XlaBuilder builder(TestName()); - builder.Tanh(builder.ConstantR0(2.0f)); + Tanh(ConstantR0(&builder, 2.0f)); ComputeAndCompareR0(&builder, 0.96402758, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, TanhDoubleScalar) { XlaBuilder builder(TestName()); - builder.Tanh(builder.ConstantR0(2.0)); + Tanh(ConstantR0(&builder, 2.0)); ComputeAndCompareR0(&builder, 0.96402758, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, PowScalar) { XlaBuilder builder(TestName()); - builder.Pow(builder.ConstantR0(2.0f), builder.ConstantR0(3.0f)); + Pow(ConstantR0(&builder, 2.0f), ConstantR0(&builder, 3.0f)); ComputeAndCompareR0(&builder, 8.0, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(-1), // The lower bound. - builder.ConstantR0(5), // The operand to be clamped. - builder.ConstantR0(3)); // The upper bound. + Clamp(ConstantR0(&builder, -1), // The lower bound. + ConstantR0(&builder, 5), // The operand to be clamped. + ConstantR0(&builder, 3)); // The upper bound. ComputeAndCompareR0(&builder, 3, {}); } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleS32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(-1), // The lower bound. - builder.ConstantR0(2), // The operand to be clamped. - builder.ConstantR0(3)); // The upper bound. + Clamp(ConstantR0(&builder, -1), // The lower bound. + ConstantR0(&builder, 2), // The operand to be clamped. + ConstantR0(&builder, 3)); // The upper bound. ComputeAndCompareR0(&builder, 2, {}); } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowS32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(-1), // The lower bound. - builder.ConstantR0(-5), // The operand to be clamped. - builder.ConstantR0(3)); // The upper bound. + Clamp(ConstantR0(&builder, -1), // The lower bound. + ConstantR0(&builder, -5), // The operand to be clamped. + ConstantR0(&builder, 3)); // The upper bound. ComputeAndCompareR0(&builder, -1, {}); } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighU32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(1), // The lower bound. - builder.ConstantR0(5), // The operand to be clamped. - builder.ConstantR0(3)); // The upper bound. + Clamp(ConstantR0(&builder, 1), // The lower bound. + ConstantR0(&builder, 5), // The operand to be clamped. + ConstantR0(&builder, 3)); // The upper bound. ComputeAndCompareR0(&builder, 3, {}); } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleU32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(1), // The lower bound. - builder.ConstantR0(2), // The operand to be clamped. - builder.ConstantR0(3)); // The upper bound. + Clamp(ConstantR0(&builder, 1), // The lower bound. + ConstantR0(&builder, 2), // The operand to be clamped. + ConstantR0(&builder, 3)); // The upper bound. ComputeAndCompareR0(&builder, 2, {}); } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowU32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(1), // The lower bound. - builder.ConstantR0(0), // The operand to be clamped. - builder.ConstantR0(3)); // The upper bound. + Clamp(ConstantR0(&builder, 1), // The lower bound. + ConstantR0(&builder, 0), // The operand to be clamped. + ConstantR0(&builder, 3)); // The upper bound. ComputeAndCompareR0(&builder, 1, {}); } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. - builder.ConstantR0(5.0f), // The operand to be clamped. - builder.ConstantR0(3.0f)); // The upper bound. + Clamp(ConstantR0(&builder, 2.0f), // The lower bound. + ConstantR0(&builder, 5.0f), // The operand to be clamped. + ConstantR0(&builder, 3.0f)); // The upper bound. ComputeAndCompareR0(&builder, 3.0, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. - builder.ConstantR0(2.5f), // The operand to be clamped. - builder.ConstantR0(3.0f)); // The upper bound. + Clamp(ConstantR0(&builder, 2.0f), // The lower bound. + ConstantR0(&builder, 2.5f), // The operand to be clamped. + ConstantR0(&builder, 3.0f)); // The upper bound. ComputeAndCompareR0(&builder, 2.5, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) { XlaBuilder builder(TestName()); - builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. - builder.ConstantR0(-5.0f), // The operand to be clamped. - builder.ConstantR0(3.0f)); // The upper bound. + Clamp(ConstantR0(&builder, 2.0f), // The lower bound. + ConstantR0(&builder, -5.0f), // The operand to be clamped. + ConstantR0(&builder, 3.0f)); // The upper bound. ComputeAndCompareR0(&builder, 2.0, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, MinS32Above) { - TestMinMax(10, 3, 3, &XlaBuilder::Min); + TestMinMax(10, 3, 3, &Min); } XLA_TEST_F(ScalarComputationsTest, MinS32Below) { - TestMinMax(-100, 3, -100, &XlaBuilder::Min); + TestMinMax(-100, 3, -100, &Min); } XLA_TEST_F(ScalarComputationsTest, MaxS32Above) { - TestMinMax(10, 3, 10, &XlaBuilder::Max); + TestMinMax(10, 3, 10, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxS32Below) { - TestMinMax(-100, 3, 3, &XlaBuilder::Max); + TestMinMax(-100, 3, 3, &Max); } XLA_TEST_F(ScalarComputationsTest, MinU32Above) { const uint32 large = std::numeric_limits::max(); - TestMinMax(large, 3, 3, &XlaBuilder::Min); + TestMinMax(large, 3, 3, &Min); } XLA_TEST_F(ScalarComputationsTest, MinU32Below) { - TestMinMax(0, 5, 0, &XlaBuilder::Min); + TestMinMax(0, 5, 0, &Min); } XLA_TEST_F(ScalarComputationsTest, MaxU32Above) { const uint32 large = std::numeric_limits::max(); - TestMinMax(large, 3, large, &XlaBuilder::Max); + TestMinMax(large, 3, large, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxU32Below) { - TestMinMax(0, 5, 5, &XlaBuilder::Max); + TestMinMax(0, 5, 5, &Max); } XLA_TEST_F(ScalarComputationsTest, MinF32Above) { - TestMinMax(10.1f, 3.1f, 3.1f, &XlaBuilder::Min); + TestMinMax(10.1f, 3.1f, 3.1f, &Min); } XLA_TEST_F(ScalarComputationsTest, MinF32Below) { - TestMinMax(-100.1f, 3.1f, -100.1f, &XlaBuilder::Min); + TestMinMax(-100.1f, 3.1f, -100.1f, &Min); } XLA_TEST_F(ScalarComputationsTest, MinPropagatesNan) { SetFastMathDisabled(true); - TestMinMax(NAN, 3.1f, NAN, &XlaBuilder::Min); - TestMinMax(-3.1f, NAN, NAN, &XlaBuilder::Min); + TestMinMax(NAN, 3.1f, NAN, &Min); + TestMinMax(-3.1f, NAN, NAN, &Min); } XLA_TEST_F(ScalarComputationsTest, MaxF32Above) { - TestMinMax(10.1f, 3.1f, 10.1f, &XlaBuilder::Max); + TestMinMax(10.1f, 3.1f, 10.1f, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxF32Below) { - TestMinMax(-100.1f, 3.1f, 3.1f, &XlaBuilder::Max); + TestMinMax(-100.1f, 3.1f, 3.1f, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxPropagatesNan) { SetFastMathDisabled(true); - TestMinMax(NAN, 3.1f, NAN, &XlaBuilder::Max); - TestMinMax(-3.1f, NAN, NAN, &XlaBuilder::Max); + TestMinMax(NAN, 3.1f, NAN, &Max); + TestMinMax(-3.1f, NAN, NAN, &Max); } XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { // Compute the expression (1 * (3 - 1) * (7 + 0) - 4) / 20. XlaBuilder b(TestName()); - b.Div( - b.Sub(b.Mul(b.ConstantR0(1), - b.Mul(b.Sub(b.ConstantR0(3), b.ConstantR0(1)), - b.Add(b.ConstantR0(7), b.ConstantR0(0)))), - b.ConstantR0(4)), - b.ConstantR0(20)); + Div(Sub(Mul(ConstantR0(&b, 1), + Mul(Sub(ConstantR0(&b, 3), ConstantR0(&b, 1)), + Add(ConstantR0(&b, 7), ConstantR0(&b, 0)))), + ConstantR0(&b, 4)), + ConstantR0(&b, 20)); ComputeAndCompareR0(&b, 0.5, {}, error_spec_); } @@ -893,30 +892,18 @@ XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) { // Compute the expression 1 * (3 - 1) * (7 + 0) - 4. XlaBuilder b(TestName()); - b.Sub(b.Mul(b.ConstantR0(1), - b.Mul(b.Sub(b.ConstantR0(3), b.ConstantR0(1)), - b.Add(b.ConstantR0(7), b.ConstantR0(0)))), - b.ConstantR0(4)); + Sub(Mul(ConstantR0(&b, 1), + Mul(Sub(ConstantR0(&b, 3), ConstantR0(&b, 1)), + Add(ConstantR0(&b, 7), ConstantR0(&b, 0)))), + ConstantR0(&b, 4)); ComputeAndCompareR0(&b, 10, {}); } -XLA_TEST_F(ScalarComputationsTest, SqrtF320) { - XlaBuilder builder(TestName()); - Literal zero_literal = Literal::Zero(PrimitiveType::F32); - - std::unique_ptr zero_data = - client_->TransferToServer(zero_literal).ConsumeValueOrDie(); - - XlaOp zero = builder.Parameter(0, zero_literal.shape(), "zero"); - builder.SqrtF32(zero); - - ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); -} XLA_TEST_F(ScalarComputationsTest, RoundScalar) { XlaBuilder builder(TestName()); - builder.Round(builder.ConstantR0(1.4f)); + Round(ConstantR0(&builder, 1.4f)); ComputeAndCompareR0(&builder, 1.0f, {}, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..769d07e9d0cbacce656b9327c13417a13976e3d8 --- /dev/null +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -0,0 +1,616 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +using tensorflow::gtl::nullopt; + +class ScatterTest : public HloTestBase { + protected: + void RunTest(const string& hlo_text, Literal* operand, + Literal* scatter_indices, Literal* updates) { + RunTest(hlo_text, {operand, scatter_indices, updates}); + } + + void RunTest(const string& hlo_text, + tensorflow::gtl::ArraySlice args) { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text, config)); + EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt)); + } +}; + +XLA_TEST_F(ScatterTest, TensorFlowScatterV1_Update) { + const string hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV2 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) { + const string hlo_text = R"( +HloModule TensorFlowScatter_Add + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) { + const string hlo_text = R"( +HloModule TensorFlowScatter_Mul + +mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=mul_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) { + const string hlo_text = R"( +HloModule TensorFlowScatter_F32 + +add_f32 (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(f32[] lhs, f32[] rhs) +} + +ENTRY main { + operand = f32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = f32[2,3] parameter(2) + ROOT scatter = f32[3,3] scatter(operand, indices, updates), + to_apply=add_f32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = LiteralUtil::CreateR2( + {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({2, 1}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({1, 1}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) { + const char* hlo_text = R"( +HloModule TensorFlowScatterMultipleBatchDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=2 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + std::unique_ptr updates = LiteralUtil::CreateR3( + {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatterNd) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNd + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNdNonDefaultIndexVectorDim + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, DynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule DynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[1,1] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0,1}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({1, 1}); + std::unique_ptr updates = LiteralUtil::CreateR2({{10}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule BatchDynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,1,1] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + std::unique_ptr updates = + LiteralUtil::CreateR3({{{10}}, {{20}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, ZeroDimBounds) { + const char* hlo_text = R"( +HloModule TensorFlowScatter_ZeroDimBounds + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,0] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,0] parameter(2) + ROOT scatter = s32[3,0] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = LiteralUtil::CreateR2({{}, {}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, NoUpdateWindowDims) { + const string hlo_text = R"( +HloModule Scatter_NoUpdateWindowDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[2,2,1] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2 +} +)"; + std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20}, {30, 40}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, OutOfBoundsIndex) { + const string hlo_text = R"( +HloModule BatchDynamicSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); + std::unique_ptr updates = LiteralUtil::CreateR3( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) { + const string hlo_text = R"( +HloModule BatchDynamicSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = u32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); + std::unique_ptr updates = LiteralUtil::CreateR3( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, NegativeIndex) { + const string hlo_text = R"( +HloModule BatchDynamicSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); + std::unique_ptr updates = LiteralUtil::CreateR3( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, OneScalarIndex) { + const char* hlo_text = R"( +HloModule OneScalarIndex + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[2,3,2]{2,1,0} parameter(0) + index = s32[] parameter(1) + updates = s32[1,3,2]{2,1,0} parameter(2) + ROOT scatter = s32[2,3,2]{2,1,0} scatter(operand, index, updates), + to_apply=update_s32, + update_window_dims={0,1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=0 +} +)"; + std::unique_ptr operand = LiteralUtil::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); + std::unique_ptr scatter_indices = LiteralUtil::CreateR0(1); + std::unique_ptr updates = + LiteralUtil::CreateR3({{{10, 20}, {30, 40}, {50, 60}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, ScalarUpdate) { + const char* hlo_text = R"( +HloModule ScalarUpdate + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[4]{0} parameter(0) + index = s32[] parameter(1) + updates = s32[] parameter(2) + ROOT scatter = s32[4]{0} scatter(operand, index, updates), + to_apply=update_s32, + update_window_dims={}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=0 +} +)"; + std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3, 4}); + std::unique_ptr scatter_indices = LiteralUtil::CreateR0(1); + std::unique_ptr updates = LiteralUtil::CreateR0(25); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, EmptyIndices) { + const string hlo_text = R"( +HloModule EmptyIndices + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[0] parameter(1) + updates = s32[0,0] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = LiteralUtil::CreateR1({}); + std::unique_ptr updates = LiteralUtil::CreateR2({{}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index 7015e5a6a31f506d30c2629d7735482cf354455a..e3d4f98dd7432d1dce7e697586e8b17105dc82e7 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -22,10 +22,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -73,16 +73,16 @@ XLA_TEST_P(SelectAndScatterTest, ParamTest) { auto operand_shape = GetParam().operand_shape; Array o(operand_shape); o.FillRandom(1.5f); - auto operand = builder_.ConstantFromArray(o); + auto operand = ConstantFromArray(&builder_, o); auto source_shape = GetParam().source_shape; Array s(source_shape); s.FillRandom(12.0f); - auto source = builder_.ConstantFromArray(s); + auto source = ConstantFromArray(&builder_, s); - builder_.SelectAndScatter(operand, ge_f32_, GetParam().window_dimensions, - GetParam().window_strides, GetParam().padding_type, - source, builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, GetParam().window_dimensions, + GetParam().window_strides, GetParam().padding_type, source, + ConstantR0(&builder_, 0.0f), add_f32_); ComputeAndCompare(&builder_, {}, ErrorSpec(1e-5)); } @@ -197,110 +197,110 @@ INSTANTIATE_TEST_CASE_P( // Test for F32 1D array, with a zero-element input. XLA_TEST_F(SelectAndScatterTest, R1S0F32) { - const auto operand = builder_.ConstantR1({}); - const auto source = builder_.ConstantR1({}); - builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3}, - /*window_strides=*/{3}, Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + const auto operand = ConstantR1(&builder_, {}); + const auto source = ConstantR1(&builder_, {}); + SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3}, + /*window_strides=*/{3}, Padding::kValid, source, + ConstantR0(&builder_, 0.0f), add_f32_); ComputeAndCompareR1(&builder_, {}, {}, ErrorSpec(1e-7)); } // Test for F32 1D array, when windows do not overlap. XLA_TEST_F(SelectAndScatterTest, R1F32) { const auto operand = - builder_.ConstantR1({1.f, 9.f, 3.f, 7.f, 5.f, 6.f}); - const auto source = builder_.ConstantR1({34.f, 42.f}); + ConstantR1(&builder_, {1.f, 9.f, 3.f, 7.f, 5.f, 6.f}); + const auto source = ConstantR1(&builder_, {34.f, 42.f}); const std::vector expected = {0.f, 34.f, 0.f, 42.f, 0.f, 0.f}; - builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3}, - /*window_strides=*/{3}, Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3}, + /*window_strides=*/{3}, Padding::kValid, source, + ConstantR0(&builder_, 0.0f), add_f32_); ComputeAndCompareR1(&builder_, expected, {}, ErrorSpec(1e-7)); } // Test for S32 1D array, when windows do not overlap and the init value is 1. XLA_TEST_F(SelectAndScatterTest, R1S32) { - const auto operand = builder_.ConstantR1({-1, 0, 6, 4, -4, 10}); - const auto source = builder_.ConstantR1({-10, 20}); + const auto operand = ConstantR1(&builder_, {-1, 0, 6, 4, -4, 10}); + const auto source = ConstantR1(&builder_, {-10, 20}); const std::vector expected = {1, 1, -9, 1, 1, 21}; - builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3}, - /*window_strides=*/{3}, Padding::kValid, source, - builder_.ConstantR0(1), add_s32_); + SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3}, + /*window_strides=*/{3}, Padding::kValid, source, + ConstantR0(&builder_, 1), add_s32_); ComputeAndCompareR1(&builder_, expected, {}); } // Test for S32 1D array, when windows overlap with each other. XLA_TEST_F(SelectAndScatterTest, R1S32OverlappingWindow) { - const auto operand = builder_.ConstantR1({1, 9, 3, 7, 5, 6}); - const auto source = builder_.ConstantR1({34, 42, 53, 19}); + const auto operand = ConstantR1(&builder_, {1, 9, 3, 7, 5, 6}); + const auto source = ConstantR1(&builder_, {34, 42, 53, 19}); const std::vector expected = {0, 76, 0, 72, 0, 0}; - builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3}, - /*window_strides=*/{1}, Padding::kValid, source, - builder_.ConstantR0(0), add_s32_); + SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3}, + /*window_strides=*/{1}, Padding::kValid, source, + ConstantR0(&builder_, 0), add_s32_); ComputeAndCompareR1(&builder_, expected, {}); } // Test for S32 2D array, when windows do not overlap. XLA_TEST_F(SelectAndScatterTest, R2S32) { const auto operand = - builder_.ConstantR2({{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}}); - const auto source = builder_.ConstantR2({{2, 6}}); + ConstantR2(&builder_, {{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}}); + const auto source = ConstantR2(&builder_, {{2, 6}}); Array2D expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}}); - builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3}, - /*window_strides=*/{2, 3}, Padding::kValid, source, - builder_.ConstantR0(0), add_s32_); + SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3}, + /*window_strides=*/{2, 3}, Padding::kValid, source, + ConstantR0(&builder_, 0), add_s32_); ComputeAndCompareR2(&builder_, expected, {}); } // Test for tie breaking rule in ge_f32_. When a tie is present, the operand // that has the lower lexicographical order (smaller index) should be chosen. XLA_TEST_F(SelectAndScatterTest, R2F32Tie) { - const auto operand = builder_.ConstantR2( - {{0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}}); - const auto source = builder_.ConstantR2( - {{1.0f, 2.0f, 3.0f}, {4.f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}); + const auto operand = ConstantR2( + &builder_, {{0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}}); + const auto source = ConstantR2( + &builder_, {{1.0f, 2.0f, 3.0f}, {4.f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}); Array2D expected( {{12.f, 9.f, 0.f}, {15.f, 9.f, 0.f}, {0.f, 0.f, 0.f}}); - builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3, 3}, - /*window_strides=*/{1, 1}, Padding::kSame, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3, 3}, + /*window_strides=*/{1, 1}, Padding::kSame, source, + ConstantR0(&builder_, 0.0f), add_f32_); ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(1e-7)); } // Similar to SelectAndScatterTest.R2S32 but the input is transposed. XLA_TEST_F(SelectAndScatterTest, ReshapeR2S32) { - const auto operand = builder_.ConstantR2( - {{7, 3}, {2, 8}, {5, 9}, {3, 3}, {10, 4}, {2, 2}}); + const auto operand = ConstantR2( + &builder_, {{7, 3}, {2, 8}, {5, 9}, {3, 3}, {10, 4}, {2, 2}}); const auto reshape = - builder_.Reshape(operand, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6}); - const auto source = builder_.ConstantR2({{2, 6}}); + Reshape(operand, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6}); + const auto source = ConstantR2(&builder_, {{2, 6}}); Array2D expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}}); - builder_.SelectAndScatter(reshape, ge_s32_, /*window_dimensions=*/{2, 3}, - /*window_strides=*/{2, 3}, Padding::kValid, source, - builder_.ConstantR0(0), add_s32_); + SelectAndScatter(reshape, ge_s32_, /*window_dimensions=*/{2, 3}, + /*window_strides=*/{2, 3}, Padding::kValid, source, + ConstantR0(&builder_, 0), add_s32_); ComputeAndCompareR2(&builder_, expected, {}); } // Test for S32 2D array, when windows overlap with each other. XLA_TEST_F(SelectAndScatterTest, R2S32OverlappingWindow) { const auto operand = - builder_.ConstantR2({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); - const auto source = builder_.ConstantR2({{2, 6, 4}}); + ConstantR2(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); + const auto source = ConstantR2(&builder_, {{2, 6, 4}}); Array2D expected({{0, 0, 0, 0, 0}, {0, 0, 12, 0, 0}}); - builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3}, - /*window_strides=*/{1, 1}, Padding::kValid, source, - builder_.ConstantR0(0), add_s32_); + SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3}, + /*window_strides=*/{1, 1}, Padding::kValid, source, + ConstantR0(&builder_, 0), add_s32_); ComputeAndCompareR2(&builder_, expected, {}); } // Test for S32 2D array, when the padding is Padding::kSAME. XLA_TEST_F(SelectAndScatterTest, R2S32SamePadding) { const auto operand = - builder_.ConstantR2({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); - const auto source = builder_.ConstantR2({{2, 6, 4}}); + ConstantR2(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); + const auto source = ConstantR2(&builder_, {{2, 6, 4}}); Array2D expected({{0, 0, 0, 0, 4}, {0, 2, 6, 0, 0}}); - builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2}, - /*window_strides=*/{2, 2}, Padding::kSame, source, - builder_.ConstantR0(0), add_s32_); + SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2}, + /*window_strides=*/{2, 2}, Padding::kSame, source, + ConstantR0(&builder_, 0), add_s32_); ComputeAndCompareR2(&builder_, expected, {}); } @@ -308,25 +308,26 @@ XLA_TEST_F(SelectAndScatterTest, R2S32SamePadding) { // with each other. XLA_TEST_F(SelectAndScatterTest, R2S32SamePaddingOverlappingWindow) { const auto operand = - builder_.ConstantR2({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); + ConstantR2(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}}); const auto source = - builder_.ConstantR2({{2, 6, 4, 7, 1}, {3, 5, 8, 9, 10}}); + ConstantR2(&builder_, {{2, 6, 4, 7, 1}, {3, 5, 8, 9, 10}}); Array2D expected({{0, 0, 0, 0, 8}, {0, 5, 23, 0, 19}}); - builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2}, - /*window_strides=*/{1, 1}, Padding::kSame, source, - builder_.ConstantR0(0), add_s32_); + SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2}, + /*window_strides=*/{1, 1}, Padding::kSame, source, + ConstantR0(&builder_, 0), add_s32_); ComputeAndCompareR2(&builder_, expected, {}); } XLA_TEST_F(SelectAndScatterTest, R2F32OverlappingR2Source) { - const auto operand = builder_.ConstantR2( - {{1.5f, 2.5f, 1.5f}, {3.5f, 1.5f, 3.5f}, {4.5f, 2.5f, 4.5f}}); - const auto source = builder_.ConstantR2({{1.0f, 2.0f}, {3.0f, 4.0f}}); + const auto operand = ConstantR2( + &builder_, {{1.5f, 2.5f, 1.5f}, {3.5f, 1.5f, 3.5f}, {4.5f, 2.5f, 4.5f}}); + const auto source = + ConstantR2(&builder_, {{1.0f, 2.0f}, {3.0f, 4.0f}}); Array2D expected( {{0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 2.0f}, {3.0f, 0.0f, 4.0f}}); - builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{2, 2}, - /*window_strides=*/{1, 1}, Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{2, 2}, + /*window_strides=*/{1, 1}, Padding::kValid, source, + ConstantR0(&builder_, 0.0f), add_f32_); ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(1e-7)); } @@ -342,16 +343,16 @@ TEST_F(SelectAndScatterTest, R4F32Valid) { {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}}; Array4D o(4, 6, 15, 220); o.FillWithPZ(pzo); - auto operand = builder_.ConstantR4FromArray4D(o); + auto operand = ConstantR4FromArray4D(&builder_, o); Array4D e(4, 6, 15, 220); e.FillWithPZ(pze); Array4D s(2, 2, 15, 220); s.FillWithPZ(pzs); - auto source = builder_.ConstantR4FromArray4D(s); + auto source = ConstantR4FromArray4D(&builder_, s); s.FillWithPZ(pzs); - builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1}, - Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1}, + Padding::kValid, source, ConstantR0(&builder_, 0.0f), + add_f32_); ComputeAndCompareR4(&builder_, e, {}, ErrorSpec(1e-7)); } @@ -367,16 +368,16 @@ TEST_F(SelectAndScatterTest, R4F32Overlap) { {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}}; Array4D o(4, 5, 17, 128); o.FillWithPZ(pzo); - auto operand = builder_.ConstantR4FromArray4D(o); + auto operand = ConstantR4FromArray4D(&builder_, o); Array4D e(4, 5, 17, 128); e.FillWithPZ(pze); Array4D s(2, 2, 17, 128); s.FillWithPZ(pzs); - auto source = builder_.ConstantR4FromArray4D(s); + auto source = ConstantR4FromArray4D(&builder_, s); s.FillWithPZ(pzs); - builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1}, - Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1}, + Padding::kValid, source, ConstantR0(&builder_, 0.0f), + add_f32_); ComputeAndCompareR4(&builder_, e, {}, ErrorSpec(1e-7)); } @@ -392,16 +393,16 @@ TEST_F(SelectAndScatterTest, R4F32OverlapSmall) { {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}}; Array4D o(4, 5, 1, 1); o.FillWithPZ(pzo); - auto operand = builder_.ConstantR4FromArray4D(o); + auto operand = ConstantR4FromArray4D(&builder_, o); Array4D e(4, 5, 1, 1); e.FillWithPZ(pze); Array4D s(2, 2, 1, 1); s.FillWithPZ(pzs); - auto source = builder_.ConstantR4FromArray4D(s); + auto source = ConstantR4FromArray4D(&builder_, s); s.FillWithPZ(pzs); - builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1}, - Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1}, + Padding::kValid, source, ConstantR0(&builder_, 0.0f), + add_f32_); ComputeAndCompareR4(&builder_, e, {}, ErrorSpec(1e-7)); } @@ -414,39 +415,39 @@ TEST_F(SelectAndScatterTest, R4F32RefValidFixedSmall) { Array2D pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}}; Array4D o(4, 6, 4, 4); o.FillWithPZ(pzo); - auto operand = builder_.ConstantR4FromArray4D(o); + auto operand = ConstantR4FromArray4D(&builder_, o); Array4D s(2, 2, 4, 4); s.FillWithPZ(pzs); - auto source = builder_.ConstantR4FromArray4D(s); + auto source = ConstantR4FromArray4D(&builder_, s); s.FillWithPZ(pzs); - builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1}, - Padding::kValid, source, - builder_.ConstantR0(0.0f), add_f32_); + SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1}, + Padding::kValid, source, ConstantR0(&builder_, 0.0f), + add_f32_); auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {2, 3, 1, 1}, {2, 3, 1, 1}, false); ComputeAndCompareR4(&builder_, *e, {}, ErrorSpec(1e-7)); } XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMaxScatter) { - const auto operand = builder_.ConstantR1({1, 2, 3, 100, 3, 2, 1}); - const auto source = builder_.ConstantR1({34, 42, 53, 19}); + const auto operand = ConstantR1(&builder_, {1, 2, 3, 100, 3, 2, 1}); + const auto source = ConstantR1(&builder_, {34, 42, 53, 19}); const std::vector expected = {0, 0, 0, 53, 0, 0, 0}; - builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4}, - /*window_strides=*/{1}, Padding::kValid, source, - builder_.ConstantR0(0), max_f32_); + SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4}, + /*window_strides=*/{1}, Padding::kValid, source, + ConstantR0(&builder_, 0), max_f32_); ComputeAndCompareR1(&builder_, expected, {}, ErrorSpec(1e-7)); } XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMinScatter) { - const auto operand = builder_.ConstantR1({1, 2, 3, 100, 3, 2, 1}); - const auto source = builder_.ConstantR1({34, 42, 53, 19}); + const auto operand = ConstantR1(&builder_, {1, 2, 3, 100, 3, 2, 1}); + const auto source = ConstantR1(&builder_, {34, 42, 53, 19}); const float max_float = std::numeric_limits::max(); const std::vector expected = {max_float, max_float, max_float, 19, max_float, max_float, max_float}; - builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4}, - /*window_strides=*/{1}, Padding::kValid, source, - builder_.ConstantR0(max_float), min_f32_); + SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4}, + /*window_strides=*/{1}, Padding::kValid, source, + ConstantR0(&builder_, max_float), min_f32_); ComputeAndCompareR1(&builder_, expected, {}, ErrorSpec(1e-7)); } diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc index 72707f224446c7585d1d90ac6681a7b38c41d5f1..1c01402798658877889527a5dd02d5c74787ff99 100644 --- a/tensorflow/compiler/xla/tests/select_test.cc +++ b/tensorflow/compiler/xla/tests/select_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -35,50 +35,52 @@ class SelectTest : public ClientLibraryTestBase { TEST_F(SelectTest, SelectScalarF32True) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto on_true = builder.ConstantR0(123.0f); - auto on_false = builder.ConstantR0(42.0f); - auto result = builder.Select(pred, on_true, on_false); + auto pred = ConstantR0(&builder, true); + auto on_true = ConstantR0(&builder, 123.0f); + auto on_false = ConstantR0(&builder, 42.0f); + Select(pred, on_true, on_false); ComputeAndCompareR0(&builder, 123.0f, {}, error_spec_); } TEST_F(SelectTest, SelectScalarS32True) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto on_true = builder.ConstantR0(-42); - auto on_false = builder.ConstantR0(42); - auto result = builder.Select(pred, on_true, on_false); + auto pred = ConstantR0(&builder, true); + auto on_true = ConstantR0(&builder, -42); + auto on_false = ConstantR0(&builder, 42); + Select(pred, on_true, on_false); ComputeAndCompareR0(&builder, -42, {}); } TEST_F(SelectTest, SelectScalarF32False) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto on_true = builder.ConstantR0(123.0f); - auto on_false = builder.ConstantR0(42.0f); - auto result = builder.Select(pred, on_true, on_false); + auto pred = ConstantR0(&builder, false); + auto on_true = ConstantR0(&builder, 123.0f); + auto on_false = ConstantR0(&builder, 42.0f); + Select(pred, on_true, on_false); ComputeAndCompareR0(&builder, 42.0f, {}, error_spec_); } XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR1({}); - auto on_true = builder.ConstantR1({}); - auto on_false = builder.ConstantR1({}); - auto select = builder.Select(pred, on_true, on_false); + auto pred = ConstantR1(&builder, {}); + auto on_true = ConstantR1(&builder, {}); + auto on_false = ConstantR1(&builder, {}); + Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR1({false, true, false, true, false}); - auto on_true = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); - auto on_false = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - auto select = builder.Select(pred, on_true, on_false); + auto pred = ConstantR1(&builder, {false, true, false, true, false}); + auto on_true = + ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto on_false = + ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); + Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {}, error_spec_); @@ -88,12 +90,12 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) { // Similar to SelectR1S0F32WithConstantR1S0PRED, except that the pred vector // is not a constant, but rather the result of comparing two other vectors. XlaBuilder builder(TestName()); - auto v1 = builder.ConstantR1({}); - auto v2 = builder.ConstantR1({}); - auto cmp = builder.Eq(v1, v2); - auto on_true = builder.ConstantR1({}); - auto on_false = builder.ConstantR1({}); - auto select = builder.Select(cmp, on_true, on_false); + auto v1 = ConstantR1(&builder, {}); + auto v2 = ConstantR1(&builder, {}); + auto cmp = Eq(v1, v2); + auto on_true = ConstantR1(&builder, {}); + auto on_false = ConstantR1(&builder, {}); + Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -102,12 +104,14 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) { // Similar to SelectR1F32WithConstantR1PRED, except that the pred vector is // not a constant, but rather the result of comparing two other vectors. XlaBuilder builder(TestName()); - auto v1 = builder.ConstantR1({1, 2, 3, 4, 5}); - auto v2 = builder.ConstantR1({9, 2, 9, 4, 9}); - auto cmp = builder.Eq(v1, v2); - auto on_true = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); - auto on_false = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - auto select = builder.Select(cmp, on_true, on_false); + auto v1 = ConstantR1(&builder, {1, 2, 3, 4, 5}); + auto v2 = ConstantR1(&builder, {9, 2, 9, 4, 9}); + auto cmp = Eq(v1, v2); + auto on_true = + ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto on_false = + ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); + Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {}, error_spec_); @@ -116,12 +120,14 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) { TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) { // Similar to SelectR1F32WithCmpR1S32s, except "gt"-comparing two R1F32s. XlaBuilder builder(TestName()); - auto v1 = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); - auto v2 = builder.ConstantR1({-1.0f, -2.0f, 13.0f, 14.0f, 4.4f}); - auto cmp = builder.Gt(v1, v2); - auto on_true = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); - auto on_false = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - auto select = builder.Select(cmp, on_true, on_false); + auto v1 = ConstantR1(&builder, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + auto v2 = ConstantR1(&builder, {-1.0f, -2.0f, 13.0f, 14.0f, 4.4f}); + auto cmp = Gt(v1, v2); + auto on_true = + ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); + auto on_false = + ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); + Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {-2.5f, 25.5f, 1.0f, 10.0f, 6.0f}, {}, error_spec_); @@ -140,8 +146,8 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) { {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); - auto cmp = builder.Gt(v1, v2); - auto select = builder.Select(cmp, v1, v2); + auto cmp = Gt(v1, v2); + Select(cmp, v1, v2); ComputeAndCompareR1(&builder, {41.0f, 22.0f, 23.0f, 84.0f}, {param0_data.get(), param1_data.get()}, error_spec_); @@ -181,8 +187,8 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) { CreateR1Parameter(v2vec, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); - auto cmp = builder.Gt(v1, v2); - auto select = builder.Select(cmp, v1, v2); + auto cmp = Gt(v1, v2); + Select(cmp, v1, v2); ComputeAndCompareR1(&builder, expected_vec, {param0_data.get(), param1_data.get()}, error_spec_); @@ -192,14 +198,14 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) { // "gt"-compares a R1S32 with a S32 scalar, and uses the resulting R1PRED to // select between two R1F32s. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({1, -1, 2, -2}); - auto s = builder.ConstantR0(0); - auto cmp = builder.Gt(v, s); + auto v = ConstantR1(&builder, {1, -1, 2, -2}); + auto s = ConstantR0(&builder, 0); + auto cmp = Gt(v, s); - auto on_true = builder.ConstantR1({11.0f, 22.0f, 33.0f, 44.0f}); + auto on_true = ConstantR1(&builder, {11.0f, 22.0f, 33.0f, 44.0f}); auto on_false = - builder.ConstantR1({-111.0f, -222.0f, -333.0f, -444.0f}); - auto select = builder.Select(cmp, on_true, on_false); + ConstantR1(&builder, {-111.0f, -222.0f, -333.0f, -444.0f}); + Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {11.0f, -222.0f, 33.0f, -444.0f}, {}, error_spec_); @@ -209,14 +215,14 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) { // "gt"-compares a R1F32 with a F32 scalar, and uses the resulting R1PRED to // select between two R1F32s. XlaBuilder builder(TestName()); - auto v = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f}); - auto s = builder.ConstantR0(2.5f); - auto cmp = builder.Gt(v, s); + auto v = ConstantR1(&builder, {1.0f, 2.0f, 3.0f, 4.0f}); + auto s = ConstantR0(&builder, 2.5f); + auto cmp = Gt(v, s); - auto on_true = builder.ConstantR1({11.0f, 22.0f, 33.0f, 44.0f}); + auto on_true = ConstantR1(&builder, {11.0f, 22.0f, 33.0f, 44.0f}); auto on_false = - builder.ConstantR1({-111.0f, -222.0f, -333.0f, -444.0f}); - auto select = builder.Select(cmp, on_true, on_false); + ConstantR1(&builder, {-111.0f, -222.0f, -333.0f, -444.0f}); + Select(cmp, on_true, on_false); ComputeAndCompareR1(&builder, {-111.0f, -222.0f, 33.0f, 44.0f}, {}, error_spec_); @@ -225,10 +231,10 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) { XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) { for (bool which : {false, true}) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(which); - auto on_true = builder.ConstantR1({}); - auto on_false = builder.ConstantR1({}); - auto select = builder.Select(pred, on_true, on_false); + auto pred = ConstantR0(&builder, which); + auto on_true = ConstantR1(&builder, {}); + auto on_false = ConstantR1(&builder, {}); + Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -236,20 +242,20 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) { TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(true); - auto on_true = builder.ConstantR1({-2.5f, 25.5f}); - auto on_false = builder.ConstantR1({10.0f, 5.0f}); - auto select = builder.Select(pred, on_true, on_false); + auto pred = ConstantR0(&builder, true); + auto on_true = ConstantR1(&builder, {-2.5f, 25.5f}); + auto on_false = ConstantR1(&builder, {10.0f, 5.0f}); + Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {-2.5f, 25.5f}, {}, error_spec_); } TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) { XlaBuilder builder(TestName()); - auto pred = builder.ConstantR0(false); - auto on_true = builder.ConstantR1({-2.5f, 25.5f}); - auto on_false = builder.ConstantR1({10.0f, 5.0f}); - auto select = builder.Select(pred, on_true, on_false); + auto pred = ConstantR0(&builder, false); + auto on_true = ConstantR1(&builder, {-2.5f, 25.5f}); + auto on_false = ConstantR1(&builder, {10.0f, 5.0f}); + Select(pred, on_true, on_false); ComputeAndCompareR1(&builder, {10.0f, 5.0f}, {}, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 5653bf11a7364bf9ed79bcb6b53f7db31f454803..b8ad6668f80a3002eff3cc458997966ee67c8d4b 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -42,8 +42,8 @@ TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) { values.FillIota(0); XlaBuilder builder(TestName()); - auto original = builder.ConstantR3FromArray3D(values); - builder.Slice(original, {0, 0, 0}, {3, 3, 1}, {1, 1, 1}); + auto original = ConstantR3FromArray3D(&builder, values); + Slice(original, {0, 0, 0}, {3, 3, 1}, {1, 1, 1}); Array3D expected{ {{0.0}, {3.0}, {6.0}}, {{9.0}, {12.0}, {15.0}}, {{18.0}, {21.0}, {24.0}}}; @@ -55,8 +55,8 @@ TEST_F(SliceTest, Slice3x3x3_To_3x1x3_F32) { values.FillIota(0); XlaBuilder builder(TestName()); - auto original = builder.ConstantR3FromArray3D(values); - builder.Slice(original, {0, 0, 0}, {3, 1, 3}, {1, 1, 1}); + auto original = ConstantR3FromArray3D(&builder, values); + Slice(original, {0, 0, 0}, {3, 1, 3}, {1, 1, 1}); Array3D expected{ {{0.0, 1.0, 2.0}}, {{9.0, 10.0, 11.0}}, {{18.0, 19.0, 20.0}}}; @@ -68,8 +68,8 @@ TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) { values.FillIota(0); XlaBuilder builder(TestName()); - auto original = builder.ConstantR3FromArray3D(values); - builder.Slice(original, {0, 0, 0}, {1, 3, 3}, {1, 1, 1}); + auto original = ConstantR3FromArray3D(&builder, values); + Slice(original, {0, 0, 0}, {1, 3, 3}, {1, 1, 1}); Array3D expected{ {{{0.0, 1.0, 2.0}, {3.0, 4.0, 5.0}, {6.0, 7.0, 8.0}}}}; @@ -78,24 +78,24 @@ TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) { XLA_TEST_F(SliceTest, Slice0x0to0x0F32) { XlaBuilder builder(TestName()); - auto original = builder.ConstantR2FromArray2D(Array2D(0, 0)); - builder.Slice(original, {0, 0}, {0, 0}, {1, 1}); + auto original = ConstantR2FromArray2D(&builder, Array2D(0, 0)); + Slice(original, {0, 0}, {0, 0}, {1, 1}); ComputeAndCompareR2(&builder, Array2D(0, 0), {}); } XLA_TEST_F(SliceTest, Slice0x20to0x5F32) { XlaBuilder builder(TestName()); - auto original = builder.ConstantR2FromArray2D(Array2D(0, 20)); - builder.Slice(original, {0, 15}, {0, 20}, {1, 1}); + auto original = ConstantR2FromArray2D(&builder, Array2D(0, 20)); + Slice(original, {0, 15}, {0, 20}, {1, 1}); ComputeAndCompareR2(&builder, Array2D(0, 5), {}); } XLA_TEST_F(SliceTest, Slice3x0to2x0F32) { XlaBuilder builder(TestName()); - auto original = builder.ConstantR2FromArray2D(Array2D(3, 0)); - builder.Slice(original, {1, 0}, {3, 0}, {1, 1}); + auto original = ConstantR2FromArray2D(&builder, Array2D(3, 0)); + Slice(original, {1, 0}, {3, 0}, {1, 1}); ComputeAndCompareR2(&builder, Array2D(2, 0), {}); } @@ -109,8 +109,8 @@ XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) { } XlaBuilder builder(TestName()); - auto original = builder.ConstantR2FromArray2D(values); - builder.Slice(original, {128, 128}, {256, 256}, {1, 1}); + auto original = ConstantR2FromArray2D(&builder, values); + Slice(original, {128, 128}, {256, 256}, {1, 1}); Array2D expected(128, 128); for (int row = 0; row < 128; ++row) { @@ -127,8 +127,8 @@ TEST_F(SliceTest, Slice_1x4096_To_1x1024) { std::iota(values.data(), values.data() + 4096, 0.0); XlaBuilder builder(TestName()); - auto original = builder.ConstantR2FromArray2D(values); - builder.Slice(original, {0, 3072}, {1, 4096}, {1, 1}); + auto original = ConstantR2FromArray2D(&builder, values); + Slice(original, {0, 3072}, {1, 4096}, {1, 1}); Array2D expected(1, 1024); std::iota(expected.data(), expected.data() + 1024, 3072.0); @@ -148,8 +148,8 @@ TEST_F(SliceTest, Slice_16x4_To_16x2) { } } XlaBuilder builder(TestName()); - auto original = builder.ConstantR2FromArray2D(values); - builder.Slice(original, {0, 0}, {16, 2}, {1, 1}); + auto original = ConstantR2FromArray2D(&builder, values); + Slice(original, {0, 0}, {16, 2}, {1, 1}); ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.000001)); } @@ -160,8 +160,8 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) { auto expected = ReferenceUtil::Slice4D( values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}, /*strides=*/{{1, 1, 1, 1}}); XlaBuilder builder(TestName()); - auto original = builder.ConstantR4FromArray4D(values); - builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1}); + auto original = ConstantR4FromArray4D(&builder, values); + Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1}); ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001)); } @@ -170,11 +170,11 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { values.FillRandom(3.14f); auto expected = ReferenceUtil::Slice4D(values, {{0, 0, 0, 0}}, {{2, 4, 6, 8}}, /*strides=*/{{1, 1, 2, 1}}); - auto expected_literal = Literal::CreateR4FromArray4DWithLayout( + auto expected_literal = LiteralUtil::CreateR4FromArray4DWithLayout( *expected, LayoutUtil::MakeLayout({0, 1, 2, 3})); XlaBuilder builder(TestName()); - auto original = builder.ConstantR4FromArray4D(values); - builder.Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1}); + auto original = ConstantR4FromArray4D(&builder, values); + Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1}); ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001), &expected_literal->shape()); } @@ -197,12 +197,12 @@ class SliceR1Test : public ClientLibraryTestBase, // vector. tensorflow::gtl::InlinedVector input(spec.input_dim0); std::iota(input.begin(), input.end(), NativeT()); - auto literal = Literal::CreateR1(input); + auto literal = LiteralUtil::CreateR1(input); XlaBuilder builder(TestName()); - auto original = builder.Parameter(0, literal->shape(), "p0"); - builder.Slice(original, {spec.slice_start}, {spec.slice_limit}, - {spec.slice_stride}); + auto original = Parameter(&builder, 0, literal->shape(), "p0"); + Slice(original, {spec.slice_start}, {spec.slice_limit}, + {spec.slice_stride}); // Ditto. tensorflow::gtl::InlinedVector expected; @@ -344,7 +344,11 @@ INSTANTIATE_TEST_CASE_P( R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 2}, R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 8}, R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 7}, - R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 125} + R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 125}, + R1Spec{16 * 1024 * 1024, 0, 16 * 1024 * 1024, 4097}, + R1Spec{16 * 1024 * 1024, 0, 16 * 1024 * 1024, 4093}, + R1Spec{16 * 1024 * 1024, 12 * 1024 + 17, 16 * 1024 * 1024 - 231, 4097}, + R1Spec{16 * 1024 * 1024, 12 * 1024 + 17, 16 * 1024 * 1024 - 231, 4093} ), SliceR1TestDataToString ); @@ -368,12 +372,12 @@ XLA_TEST_P(SliceR2Test, DoIt) { const R2Spec& spec = GetParam(); Array2D input(spec.input_dim0, spec.input_dim1); input.FillUnique(); - auto literal = Literal::CreateR2FromArray2DWithLayout( + auto literal = LiteralUtil::CreateR2FromArray2DWithLayout( input, LayoutUtil::MakeLayout(spec.layout)); XlaBuilder builder(TestName()); - auto a = builder.Parameter(0, literal->shape(), "p0"); - builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); + auto a = Parameter(&builder, 0, literal->shape(), "p0"); + Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, client_->TransferToServer(*literal)); @@ -463,13 +467,12 @@ class SliceR4Test : public ClientLibraryTestBase, auto expected = ReferenceUtil::Slice4D( values, spec.slice_starts, spec.slice_limits, spec.slice_strides); XlaBuilder builder(TestName()); - auto literal = Literal::CreateR4FromArray4DWithLayout( + auto literal = LiteralUtil::CreateR4FromArray4DWithLayout( values, LayoutUtil::MakeLayout(spec.input_layout)); - auto parameter = builder.Parameter(0, literal->shape(), "p0"); + auto parameter = Parameter(&builder, 0, literal->shape(), "p0"); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, client_->TransferToServer(*literal)); - builder.Slice(parameter, spec.slice_starts, spec.slice_limits, - spec.slice_strides); + Slice(parameter, spec.slice_starts, spec.slice_limits, spec.slice_strides); ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001)); } }; diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index dd7c541733634213606b5a7983b59bb1f14bf75c..2647937013222ccfdae98b0c1d141f461020b5c9 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" @@ -110,7 +111,7 @@ StatusOr> MakeFakeLiteralInternal( MakeFakeLiteralInternal(element_shape, engine)); elements.push_back(std::move(element)); } - return Literal::MakeTupleOwned(std::move(elements)); + return LiteralUtil::MakeTupleOwned(std::move(elements)); } if (engine == nullptr) { return Literal::CreateFromShape(shape); @@ -161,6 +162,9 @@ StatusOr> MakeFakeLiteralInternal( })); break; } + // Token requires no data. + case TOKEN: + break; default: return Unimplemented("Unsupported type for fake literal generation: %s", ShapeUtil::HumanString(shape).c_str()); @@ -217,7 +221,7 @@ std::unique_ptr MakeRandomNonwrappingSliceIndex( start_indices[i] = generator(*engine); } } - return Literal::CreateR1(start_indices); + return LiteralUtil::CreateR1(start_indices); } // Use dataflow analysis on each parameter to see if there are uses that would @@ -270,14 +274,22 @@ StatusOr> CreateLiteralForConstrainedUses( switch (use->opcode()) { case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: - if (needs_index != nullptr && - !ShapeUtil::Equal(needs_index->shape(), use->shape())) { - return Unimplemented( - "Conflicting operand generation slice index constraints\n"); + if (needs_index != nullptr) { + auto needs_index_shape = needs_index->shape(); + auto use_shape = use->shape(); + if (needs_index->opcode() == HloOpcode::kDynamicSlice) { + needs_index_shape = needs_index->operand(0)->shape(); + } + if (use->opcode() == HloOpcode::kDynamicSlice) { + use_shape = use->operand(0)->shape(); + } + if (!ShapeUtil::Equal(needs_index_shape, use_shape)) { + return Unimplemented( + "Conflicting operand generation slice index constraints\n"); + } } needs_index = use; break; - case HloOpcode::kReduce: case HloOpcode::kReduceWindow: needs_constant = use; @@ -307,9 +319,9 @@ StatusOr> CreateLiteralForConstrainedUses( } else if (needs_constant != nullptr) { switch (constant_type) { case ConstantType::kZero: - return Literal::Zero(param.shape().element_type()).CloneToUnique(); + return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique(); case ConstantType::kOne: - return Literal::One(param.shape().element_type()).CloneToUnique(); + return LiteralUtil::One(param.shape().element_type()).CloneToUnique(); case ConstantType::kUnknown: // We want the identity element for the computation, but we don't really // know what it is - so any value we generate will be just as wrong. diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index a8689f64981569ceb7c8a712f8ece00c99e8cf2d..e59f215a9a3ace80d7a23e1bbc40970c7a63ea0d 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 59afd28a80c0fbf3df38457cd05961c883769856..a2f0338e25977d7c76dbc48b3afc649b77ba4ee2 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -15,7 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -31,16 +32,16 @@ XLA_TEST_F(TestUtilsTest, UnusedParam) { XlaBuilder builder(TestName()); // Make the reduction lambda. Shape single_float = ShapeUtil::MakeShape(F32, {}); - builder.Parameter(0, single_float, "unused"); - builder.Parameter(1, single_float, "used"); + Parameter(&builder, 0, single_float, "unused"); + Parameter(&builder, 1, single_float, "used"); auto computation_status = builder.Build(); TF_ASSERT_OK(computation_status.status()); // Make the reduction. Shape pair_float = ShapeUtil::MakeShape(F32, {2}); - builder.Reduce(builder.Parameter(0, pair_float, "operand"), - builder.Parameter(1, single_float, "init"), - computation_status.ValueOrDie(), {0}); + Reduce(Parameter(&builder, 0, pair_float, "operand"), + Parameter(&builder, 1, single_float, "init"), + computation_status.ValueOrDie(), {0}); computation_status = builder.Build(); TF_ASSERT_OK(computation_status.status()); @@ -53,5 +54,23 @@ XLA_TEST_F(TestUtilsTest, UnusedParam) { TF_ASSERT_OK(MakeFakeArguments(&module).status()); } +XLA_TEST_F(TestUtilsTest, Token) { + auto module = ParseHloString( + R"(HloModule outfeed_module + + ENTRY InfeedToOutfeed { + token = token[] parameter(0) + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 + outfeed = token[] outfeed(infeed.data, token) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 + infeed.1.token = token[] get-tuple-element(infeed.1), index=1 + outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) + })") + .ValueOrDie(); + TF_ASSERT_OK(MakeFakeArguments(module.get()).status()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index 4585244ce81c14ab6d4d629bb7d208d73c82248d..2bdbd08309a81b201fc224110805549f7fb5bb55 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -28,32 +28,32 @@ namespace { class TokenHloTest : public HloTestBase {}; -// TODO(b/79770375): Compile, not just verify the HLO module when the backends -// support kGenerateToken. XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { std::unique_ptr module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction(HloInstruction::CreateGenerateToken({})); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + builder.AddInstruction(HloInstruction::CreateToken()); module->AddEntryComputation(builder.Build()); - EXPECT_IS_OK(HloVerifier().Run(module.get()).status()); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + Execute(std::move(module), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken())); } XLA_TEST_F(TokenHloTest, TokenTree) { std::unique_ptr module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto token0 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); - auto token1 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); - auto token2 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); + auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); + auto token2 = builder.AddInstruction(HloInstruction::CreateToken()); builder.AddInstruction( - HloInstruction::CreateGenerateToken({token0, token0, token1, token2})); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + HloInstruction::CreateAfterAll({token0, token0, token1, token2})); module->AddEntryComputation(builder.Build()); - EXPECT_IS_OK(HloVerifier().Run(module.get()).status()); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + Execute(std::move(module), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken())); } XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { @@ -64,7 +64,7 @@ XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { builder.AddInstruction( HloInstruction::CreateParameter(1, ShapeUtil::MakeTokenShape(), "p1")); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); module->AddEntryComputation(builder.Build()); Status status = HloVerifier().Run(module.get()).status(); @@ -91,26 +91,14 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { ::testing::HasSubstr("Entry parameter 0 is or contains a token shape")); } -XLA_TEST_F(TokenHloTest, InvalidTokenRoot) { - std::unique_ptr module = CreateNewModule(); - auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction(HloInstruction::CreateGenerateToken({})); - module->AddEntryComputation(builder.Build()); - - Status status = HloVerifier().Run(module.get()).status(); - ASSERT_IS_NOT_OK(status); - EXPECT_THAT(status.error_message(), - ::testing::HasSubstr("Entry root is or contains a token shape")); -} - XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { std::unique_ptr module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); - builder.AddInstruction(HloInstruction::CreateGenerateToken({param})); + builder.AddInstruction(HloInstruction::CreateAfterAll({param})); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); module->AddEntryComputation(builder.Build()); Status status = HloVerifier().Run(module.get()).status(); @@ -120,5 +108,99 @@ XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { "Operands of token instructions must be TOKEN types")); } +XLA_TEST_F(TokenHloTest, TokenInWhileLoop) { + // Thread a token around a while loop. Token is created and consumed by a + // AfterAll instruction in the while body. + string module_string = R"( +HloModule TokenInWhileLoop + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %TokenInWhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + DebugOptions debug_options = GetDebugOptionsForTest(); + // Module DCE pass removes the generate token instructions. + debug_options.add_xla_disable_hlo_passes("hlo-module-dce"); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloRunner::CreateModuleFromString(module_string, debug_options)); + + EXPECT_TRUE(RunAndCompare(std::move(module), error_spec_)); +} + +XLA_TEST_F(TokenHloTest, TokenInConditional) { + string module_string = R"( +HloModule TokenInConditional + +%True (param.1: token[]) -> (s32[], token[]) { + %param.1 = token[] parameter(0) + %forty_two = s32[] constant(42) + ROOT %tuple = (s32[], token[]) tuple(s32[] %forty_two, token[] %param.1) +} + +%False (param.2: s32[]) -> (s32[], token[]) { + %param.2 = s32[] parameter(0) + %new_token = token[] after-all() + ROOT %tuple = (s32[], token[]) tuple(s32[] %param.2, token[] %new_token) +} + +ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { + %param.3 = pred[] parameter(0) + %init_token = token[] after-all() + %seven = s32[] constant(7) + %cond = (s32[], token[]) conditional(pred[] %param.3, token[] %init_token, s32[] %seven), true_computation=True, false_computation=False + ROOT %root = s32[] get-tuple-element((s32[], token[]) %cond), index=0 +} +)"; + + DebugOptions debug_options = GetDebugOptionsForTest(); + // Module DCE pass removes the generate token instructions. + debug_options.add_xla_disable_hlo_passes("hlo-module-dce"); + + { + // True case. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloRunner::CreateModuleFromString(module_string, debug_options)); + auto arg = LiteralUtil::CreateR0(true); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + Execute(std::move(module), {arg.get()})); + EXPECT_EQ(42, result->Get({})); + } + + { + // False case. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloRunner::CreateModuleFromString(module_string, debug_options)); + auto arg = LiteralUtil::CreateR0(false); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + Execute(std::move(module), {arg.get()})); + EXPECT_EQ(7, result->Get({})); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index 0063e7ad415e9b6718c164f415ced6fb76cbf44a..125513ddfd16cb4e742e7d589e22b721307621ee 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -18,10 +18,11 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -31,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -41,7 +43,12 @@ class TransferManagerTest : public LocalClientTestBase { TransferManagerTest() : shape_size_fn_([this](const Shape& shape) { return transfer_manager_->GetByteSizeRequirement(shape); - }) {} + }) { + stream_ptr_ = local_client_->mutable_backend() + ->BorrowStream(stream_executor_) + .ValueOrDie(); + stream_ = stream_ptr_.get(); + } ~TransferManagerTest() override = default; @@ -53,37 +60,41 @@ class TransferManagerTest : public LocalClientTestBase { .ValueOrDie(); } + protected: + StreamPool::Ptr stream_ptr_; + se::Stream* stream_; + private: std::function shape_size_fn_; }; XLA_TEST_F(TransferManagerTest, TransferR0U32) { - std::unique_ptr literal = Literal::CreateR0(42); + std::unique_ptr literal = LiteralUtil::CreateR0(42); const Shape& shape = literal->shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR0Equal(42, *result); } XLA_TEST_F(TransferManagerTest, TransferR1F32) { std::unique_ptr literal = - Literal::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); + LiteralUtil::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); const Shape& shape = literal->shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR1Equal({1.25f, 2.5f, -17.0f, -20.125f}, *result); @@ -92,48 +103,48 @@ XLA_TEST_F(TransferManagerTest, TransferR1F32) { XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { std::vector test_vector(1024 * 1024); std::iota(test_vector.begin(), test_vector.end(), 0); - std::unique_ptr literal = Literal::CreateR1(test_vector); + std::unique_ptr literal = LiteralUtil::CreateR1(test_vector); const Shape& shape = literal->shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR1Equal(test_vector, *result); } XLA_TEST_F(TransferManagerTest, TransferR1U8) { const char* test_string = "0123456789abcdef"; - std::unique_ptr literal = Literal::CreateR1U8(test_string); + std::unique_ptr literal = LiteralUtil::CreateR1U8(test_string); const Shape& shape = literal->shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_EQ(result->GetR1U8AsString(), test_string); } XLA_TEST_F(TransferManagerTest, TransferR2F32) { std::unique_ptr literal = - Literal::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); + LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); const Shape& shape = literal->shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR2Equal( {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); @@ -141,7 +152,7 @@ XLA_TEST_F(TransferManagerTest, TransferR2F32) { XLA_TEST_F(TransferManagerTest, TransferR2F32AndChangeLayoutTransferringToDevice) { - std::unique_ptr literal = Literal::CreateR2WithLayout( + std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1})); const Shape ondevice_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}); @@ -149,11 +160,11 @@ XLA_TEST_F(TransferManagerTest, // Round trip literal through device. Set the on-device layout to something // different than the literal layout. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_FALSE( LayoutUtil::Equal(result->shape().layout(), literal->shape().layout())); @@ -162,89 +173,237 @@ XLA_TEST_F(TransferManagerTest, } XLA_TEST_F(TransferManagerTest, TransferTuple) { - std::unique_ptr literal = Literal::MakeTuple( - {Literal::CreateR0(123.0f).get(), - Literal::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - Literal::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}); + std::unique_ptr literal = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(123.0f).get(), + LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}); auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { - std::unique_ptr literal = Literal::MakeTuple({}); + std::unique_ptr literal = LiteralUtil::MakeTuple({}); auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { - std::unique_ptr literal = Literal::MakeTuple( - {Literal::CreateR0(123.0f).get(), - Literal::MakeTuple( - {Literal::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - Literal::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) + std::unique_ptr literal = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(123.0f).get(), + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) .get(), - Literal::CreateR1({-10.0f, 123.0f}).get()}); + LiteralUtil::CreateR1({-10.0f, 123.0f}).get()}); auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValue) { - std::unique_ptr literal = Literal::CreateR1( + std::unique_ptr literal = LiteralUtil::CreateR1( {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}); auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { - std::unique_ptr literal = Literal::MakeTuple( - {Literal::CreateR1( + std::unique_ptr literal = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1( {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}) .get(), - Literal::CreateR1({1, 2, 3, 4, 5, 6}).get(), - Literal::CreateR0(complex64(0.3f, -0.4f)).get()}); + LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6}).get(), + LiteralUtil::CreateR0(complex64(0.3f, -0.4f)).get()}); auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } +XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) { + // "Copy" a token from the device. The token has no physical representation so + // no copying is actually performed, but it shouldn't fail. + // TODO(b/110532604): Add transferring the token to device when this is + // supported. + auto device_buffer = AllocateDeviceBuffer(ShapeUtil::MakeTokenShape()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); + EXPECT_TRUE(LiteralTestUtil::Equal(*LiteralUtil::CreateToken(), *result)); +} + +XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) { + const int64 kIterationCount = 5000; + std::unique_ptr literal1 = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(123.0f).get(), + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) + .get(), + LiteralUtil::CreateR1({-10.0f, 123.0f}).get()}); + std::unique_ptr literal2 = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(456.0f).get(), + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(), + LiteralUtil::CreateR1({44.0f, -11.0f, 3333333.3f}).get()}) + .get(), + LiteralUtil::CreateR1({-98.0f, 153.0f}).get()}); + + auto device_buffer1 = AllocateDeviceBuffer(literal1->shape()); + auto device_buffer2 = AllocateDeviceBuffer(literal2->shape()); + + auto stream1 = stream_; + auto stream2 = stream_->GetOrCreateSubStream(); + + std::unique_ptr result1, result2; + + // Round trip literals through device in multiple streams asynchronously. + for (int i = 0; i < kIterationCount; ++i) { + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, *literal1, + device_buffer1)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, *literal2, + device_buffer2)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr this_result1, + transfer_manager_->TransferLiteralFromDevice(stream1, device_buffer1)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr this_result2, + transfer_manager_->TransferLiteralFromDevice(stream2, device_buffer2)); + result1 = std::move(this_result1); + result2 = std::move(this_result2); + } + + EXPECT_TRUE(LiteralTestUtil::Equal(*literal1, *result1)); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal2, *result2)); +} + +class TransferDeviceToHostBenchmark : public TransferManagerTest { + public: + using TransferManagerTest::TransferManagerTest; + ~TransferDeviceToHostBenchmark() override {} + + void Run(int iters, int num_tuple_elements, int array_size) { + tensorflow::testing::StopTiming(); + SetUp(); + + std::vector> tuple_elements; + for (int i = 0; i < num_tuple_elements; ++i) { + tuple_elements.push_back( + LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); + } + std::unique_ptr literal = + LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); + } + tensorflow::testing::StopTiming(); + TearDown(); + } + + void TestBody() override {} +}; + +class TransferHostToDeviceBenchmark : public TransferManagerTest { + public: + using TransferManagerTest::TransferManagerTest; + ~TransferHostToDeviceBenchmark() override {} + + void Run(int iters, int num_tuple_elements, int array_size) { + tensorflow::testing::StopTiming(); + SetUp(); + + std::vector> tuple_elements; + for (int i = 0; i < num_tuple_elements; ++i) { + tuple_elements.push_back( + LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); + } + std::unique_ptr literal = + LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + } + tensorflow::testing::StopTiming(); + TearDown(); + } + + void TestBody() override {} +}; + +void BM_TransferDeviceToHost(int iters, int num_tuple_elements, + int array_size) { + TransferDeviceToHostBenchmark bm; + bm.Run(iters, num_tuple_elements, array_size); +} + +void BM_TransferHostToDevice(int iters, int num_tuple_elements, + int array_size) { + TransferHostToDeviceBenchmark bm; + bm.Run(iters, num_tuple_elements, array_size); +} + +BENCHMARK(BM_TransferHostToDevice) + ->ArgPair(1, 256) + ->ArgPair(1, 257) + ->ArgPair(100, 256) + ->ArgPair(100, 257); + +BENCHMARK(BM_TransferDeviceToHost) + ->ArgPair(1, 256) + ->ArgPair(1, 257) + ->ArgPair(100, 256) + ->ArgPair(100, 257); + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + tensorflow::testing::RunBenchmarks(); + return RUN_ALL_TESTS(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc index fe1e3da7eca00e128377e6e56af877868aafa836..fbe9d1b64aa0c06d65b547c45cfa981800d40ff3 100644 --- a/tensorflow/compiler/xla/tests/transpose_test.cc +++ b/tensorflow/compiler/xla/tests/transpose_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -38,34 +38,35 @@ class TransposeTest : public ClientLibraryTestBase { XLA_TEST_F(TransposeTest, Transpose0x0) { XlaBuilder builder("Transpose"); - auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 0)); - auto result = builder.Transpose(lhs, {1, 0}); + auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 0)); + Transpose(lhs, {1, 0}); ComputeAndCompareR2(&builder, Array2D(0, 0), {}, error_spec_); } XLA_TEST_F(TransposeTest, Transpose0x42) { XlaBuilder builder("Transpose"); - auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 42)); - auto result = builder.Transpose(lhs, {1, 0}); + auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 42)); + Transpose(lhs, {1, 0}); ComputeAndCompareR2(&builder, Array2D(42, 0), {}, error_spec_); } XLA_TEST_F(TransposeTest, Transpose7x0) { XlaBuilder builder("Transpose"); - auto lhs = builder.ConstantR2FromArray2D(Array2D(7, 0)); - auto result = builder.Transpose(lhs, {1, 0}); + auto lhs = ConstantR2FromArray2D(&builder, Array2D(7, 0)); + Transpose(lhs, {1, 0}); ComputeAndCompareR2(&builder, Array2D(0, 7), {}, error_spec_); } TEST_F(TransposeTest, Transpose2x2) { XlaBuilder builder("Transpose"); - auto lhs = builder.ConstantR2({ - {1.0, 2.0}, {3.0, 4.0}, - }); - auto result = builder.Transpose(lhs, {1, 0}); + auto lhs = ConstantR2(&builder, { + {1.0, 2.0}, + {3.0, 4.0}, + }); + Transpose(lhs, {1, 0}); Array2D expected({{1.0f, 3.0f}, {2.0f, 4.0f}}); @@ -74,16 +75,18 @@ TEST_F(TransposeTest, Transpose2x2) { XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) { XlaBuilder builder("Transpose"); - auto operand = builder.ConstantR3FromArray3D(Array3D(0, 2, 3)); - auto result = builder.Transpose(operand, {1, 2, 0}); + auto operand = + ConstantR3FromArray3D(&builder, Array3D(0, 2, 3)); + Transpose(operand, {1, 2, 0}); ComputeAndCompareR3(&builder, Array3D(2, 3, 0), {}); } TEST_F(TransposeTest, Transpose1x2x3_2x3x1) { XlaBuilder builder("Transpose"); - auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); - auto result = builder.Transpose(operand, {1, 2, 0}); + auto operand = + ConstantR3FromArray3D(&builder, {{{1, 2, 3}, {4, 5, 6}}}); + Transpose(operand, {1, 2, 0}); Array3D expected({{{1}, {2}, {3}}, {{4}, {5}, {6}}}); @@ -92,8 +95,9 @@ TEST_F(TransposeTest, Transpose1x2x3_2x3x1) { TEST_F(TransposeTest, Transpose1x2x3_3x2x1) { XlaBuilder builder("Transpose"); - auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); - auto result = builder.Transpose(operand, {2, 1, 0}); + auto operand = + ConstantR3FromArray3D(&builder, {{{1, 2, 3}, {4, 5, 6}}}); + Transpose(operand, {2, 1, 0}); Array3D expected({{{1}, {4}}, {{2}, {5}}, {{3}, {6}}}); @@ -102,8 +106,9 @@ TEST_F(TransposeTest, Transpose1x2x3_3x2x1) { TEST_F(TransposeTest, Transpose1x2x3_1x2x3) { XlaBuilder builder("Transpose"); - auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); - auto result = builder.Transpose(operand, {0, 1, 2}); + auto operand = + ConstantR3FromArray3D(&builder, {{{1, 2, 3}, {4, 5, 6}}}); + Transpose(operand, {0, 1, 2}); Array3D expected({{{1, 2, 3}, {4, 5, 6}}}); @@ -116,9 +121,9 @@ TEST_F(TransposeTest, MultiTranspose3x2) { for (int transposes = 0; transposes <= 10; ++transposes) { XlaBuilder builder("Transpose"); - auto computed = builder.ConstantR2FromArray2D(input); + auto computed = ConstantR2FromArray2D(&builder, input); for (int i = 0; i < transposes; ++i) { - computed = builder.Transpose(computed, {1, 0}); + computed = Transpose(computed, {1, 0}); } const Array2D& expected = transposes % 2 == 0 ? input : transposed; ComputeAndCompareR2(&builder, expected, {}, error_spec_); @@ -130,8 +135,8 @@ TEST_F(TransposeTest, Small_1x1) { auto aoperand = MakeLinspaceArray2D(0.0, 1.0, 1, 1); XlaBuilder builder("transpose_1x1"); - auto operand = builder.ConstantR2FromArray2D(*aoperand); - builder.Transpose(operand, {1, 0}); + auto operand = ConstantR2FromArray2D(&builder, *aoperand); + Transpose(operand, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*aoperand); ComputeAndCompareR2(&builder, *expected, {}, ErrorSpec(1e-4)); @@ -142,8 +147,8 @@ TEST_F(TransposeTest, Small_2x2) { auto aoperand = MakeLinspaceArray2D(0.0, 4.0, 2, 2); XlaBuilder builder("transpose_2x2"); - auto operand = builder.ConstantR2FromArray2D(*aoperand); - builder.Transpose(operand, {1, 0}); + auto operand = ConstantR2FromArray2D(&builder, *aoperand); + Transpose(operand, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*aoperand); ComputeAndCompareR2(&builder, *expected, {}, ErrorSpec(1e-4)); @@ -162,8 +167,8 @@ void TransposeTest::TestTransposeConstant021(size_t n1, size_t n2, size_t n3) { } XlaBuilder builder(TestName()); - auto operand = builder.ConstantR3FromArray3D(aoperand); - builder.Transpose(operand, {0, 2, 1}); + auto operand = ConstantR3FromArray3D(&builder, aoperand); + Transpose(operand, {0, 2, 1}); ComputeAndCompareR3(&builder, expected, {}); } diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 41189231b90e842292830a932cf381af60456d4c..97bbf80aff80e995ea5cdd3e5d8807ee4d380067 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -49,12 +50,12 @@ XLA_TEST_F(TupleTest, TupleConstant) { {1.1f, 2.2f, 3.5f}, // row 0 {4.8f, 5.0f, 6.7f}, // row 1 }; - auto value = - Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), - Literal::CreateR1(constant_vector).get(), - Literal::CreateR2(constant_matrix).get()}); + auto value = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(constant_scalar).get(), + LiteralUtil::CreateR1(constant_vector).get(), + LiteralUtil::CreateR2(constant_matrix).get()}); - builder.ConstantLiteral(*value); + ConstantLiteral(&builder, *value); ComputeAndCompareTuple(&builder, *value, {}, error_spec_); } @@ -64,11 +65,11 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) { const float constant_scalar1 = 7.3f; const float constant_scalar2 = 1.2f; - auto value = - Literal::MakeTuple({Literal::CreateR0(constant_scalar1).get(), - Literal::CreateR0(constant_scalar2).get()}); + auto value = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(constant_scalar1).get(), + LiteralUtil::CreateR0(constant_scalar2).get()}); - builder.ConstantLiteral(*value); + ConstantLiteral(&builder, *value); ComputeAndCompareTuple(&builder, *value, {}, error_spec_); } @@ -82,14 +83,14 @@ XLA_TEST_F(TupleTest, TupleCreate) { {1.1f, 2.2f, 3.5f}, // row 0 {4.8f, 5.0f, 6.7f}, // row 1 }; - builder.Tuple({builder.ConstantR0(constant_scalar), - builder.ConstantR1(constant_vector), - builder.ConstantR2(constant_matrix)}); - - auto expected = - Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), - Literal::CreateR1(constant_vector).get(), - Literal::CreateR2(constant_matrix).get()}); + Tuple(&builder, {ConstantR0(&builder, constant_scalar), + ConstantR1(&builder, constant_vector), + ConstantR2(&builder, constant_matrix)}); + + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(constant_scalar).get(), + LiteralUtil::CreateR1(constant_vector).get(), + LiteralUtil::CreateR2(constant_matrix).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -97,19 +98,20 @@ XLA_TEST_F(TupleTest, TupleCreate) { XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { XlaBuilder builder(TestName()); - builder.Tuple( - {builder.ConstantR0(7.0), builder.ConstantR1({})}); + Tuple(&builder, + {ConstantR0(&builder, 7.0), ConstantR1(&builder, {})}); - auto expected = Literal::MakeTuple({Literal::CreateR0(7.0).get(), - Literal::CreateR1({}).get()}); + auto expected = + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(7.0).get(), + LiteralUtil::CreateR1({}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } // Tests the creation of an empty tuple. XLA_TEST_F(TupleTest, EmptyTupleCreate) { XlaBuilder builder(TestName()); - builder.Tuple({}); - auto expected = Literal::MakeTuple({}); + Tuple(&builder, {}); + auto expected = LiteralUtil::MakeTuple({}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -121,9 +123,10 @@ XLA_TEST_F(TupleTest, GetTupleElement) { {1.f, 2.f, 3.f}, // row 0 {4.f, 5.f, 6.f}, // row 1 }; - auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), - builder.ConstantR2(constant_matrix)}); - builder.GetTupleElement(tuple_data, 1); + auto tuple_data = + Tuple(&builder, {ConstantR1(&builder, constant_vector), + ConstantR2(&builder, constant_matrix)}); + GetTupleElement(tuple_data, 1); ComputeAndCompareR2(&builder, Array2D(constant_matrix), {}, error_spec_); } @@ -131,17 +134,18 @@ XLA_TEST_F(TupleTest, GetTupleElement) { // Trivial test for extracting a tuple element with GetTupleElement. XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) { XlaBuilder builder(TestName()); - auto tuple_data = builder.Tuple( - {builder.ConstantR1({}), - builder.ConstantR2FromArray2D(Array2D(0, 101))}); - builder.GetTupleElement(tuple_data, 1); + auto tuple_data = + Tuple(&builder, + {ConstantR1(&builder, {}), + ConstantR2FromArray2D(&builder, Array2D(0, 101))}); + GetTupleElement(tuple_data, 1); ComputeAndCompareR2(&builder, Array2D(0, 101), {}, error_spec_); } XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) { XlaBuilder builder(TestName()); - auto value = builder.ConstantR1({4.5f}); - builder.GetTupleElement(value, 1); + auto value = ConstantR1(&builder, {4.5f}); + GetTupleElement(value, 1); auto result_status = builder.Build(); EXPECT_FALSE(result_status.ok()); EXPECT_THAT( @@ -158,14 +162,15 @@ XLA_TEST_F(TupleTest, AddTupleElements) { {1.f, 2.f, 3.f}, // row 0 {4.f, 5.f, 6.f}, // row 1 }; - auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), - builder.ConstantR2(constant_matrix)}); - auto vector_element = builder.GetTupleElement(tuple_data, 0); - auto matrix_element = builder.GetTupleElement(tuple_data, 1); + auto tuple_data = + Tuple(&builder, {ConstantR1(&builder, constant_vector), + ConstantR2(&builder, constant_matrix)}); + auto vector_element = GetTupleElement(tuple_data, 0); + auto matrix_element = GetTupleElement(tuple_data, 1); auto vector_shape = builder.GetShape(vector_element).ConsumeValueOrDie(); auto matrix_shape = builder.GetShape(matrix_element).ConsumeValueOrDie(); - builder.Add(matrix_element, vector_element, - /*broadcast_dimensions=*/{1}); + Add(matrix_element, vector_element, + /*broadcast_dimensions=*/{1}); Array2D expected({ {2.f, 4.f, 6.f}, // row 0 @@ -185,13 +190,14 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { {1.f, 2.f, 3.f}, // row 0 {4.f, 5.f, 6.f}, // row 1 }; - auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), - builder.ConstantR2(constant_matrix)}); - builder.Tuple({builder.GetTupleElement(tuple_data, 1), - builder.GetTupleElement(tuple_data, 0)}); - auto expected = - Literal::MakeTuple({Literal::CreateR2(constant_matrix).get(), - Literal::CreateR1(constant_vector).get()}); + auto tuple_data = + Tuple(&builder, {ConstantR1(&builder, constant_vector), + ConstantR2(&builder, constant_matrix)}); + Tuple(&builder, + {GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)}); + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2(constant_matrix).get(), + LiteralUtil::CreateR1(constant_vector).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -206,14 +212,14 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { std::unique_ptr v2_data = CreateR0Parameter(1.0f, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&b, /*data_handle=*/&v2); - auto v1_gt = b.Gt(v1, v2); // false - auto v2_gt = b.Gt(v2, v1); // true - auto v1_v2 = b.Tuple({v1_gt, v2_gt}); // {false, true} - auto v2_v1 = b.Tuple({v2_gt, v1_gt}); // {true, false} - b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); + auto v1_gt = Gt(v1, v2); // false + auto v2_gt = Gt(v2, v1); // true + auto v1_v2 = Tuple(&b, {v1_gt, v2_gt}); // {false, true} + auto v2_v1 = Tuple(&b, {v2_gt, v1_gt}); // {true, false} + Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); auto expected = - Literal::MakeTuple({Literal::CreateR0(direction).get(), - Literal::CreateR0(!direction).get()}); + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(direction).get(), + LiteralUtil::CreateR0(!direction).get()}); ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()}, error_spec_); @@ -243,22 +249,23 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { {1.f, 2.f, 3.f}, // row 0 {4.f, 5.f, 6.f}, // row 1 }; - auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), - builder.ConstantR2(constant_matrix)}); - auto new_tuple01 = builder.Tuple({builder.GetTupleElement(tuple_data, 0), - builder.GetTupleElement(tuple_data, 1)}); - auto new_tuple10 = builder.Tuple({builder.GetTupleElement(tuple_data, 1), - builder.GetTupleElement(tuple_data, 0)}); - auto vector_from_01 = builder.GetTupleElement(new_tuple01, 0); - auto vector_from_10 = builder.GetTupleElement(new_tuple10, 1); - auto matrix_from_01 = builder.GetTupleElement(new_tuple01, 1); - auto matrix_from_10 = builder.GetTupleElement(new_tuple10, 0); - - auto addvectors = builder.Add(vector_from_01, vector_from_10); - auto addmatrices = builder.Add(matrix_from_01, matrix_from_10); - - builder.Add(addmatrices, addvectors, - /*broadcast_dimensions=*/{1}); + auto tuple_data = + Tuple(&builder, {ConstantR1(&builder, constant_vector), + ConstantR2(&builder, constant_matrix)}); + auto new_tuple01 = Tuple(&builder, {GetTupleElement(tuple_data, 0), + GetTupleElement(tuple_data, 1)}); + auto new_tuple10 = Tuple(&builder, {GetTupleElement(tuple_data, 1), + GetTupleElement(tuple_data, 0)}); + auto vector_from_01 = GetTupleElement(new_tuple01, 0); + auto vector_from_10 = GetTupleElement(new_tuple10, 1); + auto matrix_from_01 = GetTupleElement(new_tuple01, 1); + auto matrix_from_10 = GetTupleElement(new_tuple10, 0); + + auto addvectors = Add(vector_from_01, vector_from_10); + auto addmatrices = Add(matrix_from_01, matrix_from_10); + + Add(addmatrices, addvectors, + /*broadcast_dimensions=*/{1}); Array2D expected({ {4.f, 8.f, 12.f}, // row 0 @@ -273,14 +280,15 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) { std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; - auto tuple12 = builder.Tuple( - {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); - auto tuple21 = builder.Tuple( - {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); - - builder.Select(builder.ConstantR0(false), tuple12, tuple21); - auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), - Literal::CreateR1(vec1).get()}); + auto tuple12 = Tuple(&builder, {ConstantR1(&builder, vec1), + ConstantR1(&builder, vec2)}); + auto tuple21 = Tuple(&builder, {ConstantR1(&builder, vec2), + ConstantR1(&builder, vec1)}); + + Select(ConstantR0(&builder, false), tuple12, tuple21); + auto expected = + LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), + LiteralUtil::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -292,22 +300,22 @@ XLA_TEST_F(TupleTest, TuplesInAMap) { // Need to put a select in there to prevent HLO-level optimizations from // optimizing out the tuples. XlaBuilder b("sort_square"); - auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto x2 = b.Mul(x, x); - auto x_smaller_tuple = b.Tuple({x, x2}); - auto x2_smaller_tuple = b.Tuple({x2, x}); - auto sorted = b.Select(b.Lt(x, x2), x_smaller_tuple, x2_smaller_tuple); - auto smaller = b.GetTupleElement(sorted, 0); - auto greater = b.GetTupleElement(sorted, 1); - b.Add(greater, b.Mul(b.ConstantR0(100.0f), smaller)); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto x2 = Mul(x, x); + auto x_smaller_tuple = Tuple(&b, {x, x2}); + auto x2_smaller_tuple = Tuple(&b, {x2, x}); + auto sorted = Select(Lt(x, x2), x_smaller_tuple, x2_smaller_tuple); + auto smaller = GetTupleElement(sorted, 0); + auto greater = GetTupleElement(sorted, 1); + Add(greater, Mul(ConstantR0(&b, 100.0f), smaller)); auto computation_status = b.Build(); ASSERT_IS_OK(computation_status.status()); tuple_computation = computation_status.ConsumeValueOrDie(); } XlaBuilder b(TestName()); - auto input = b.ConstantR1({-1.0f, 1.0f, 2.1f}); - b.Map({input}, tuple_computation, {0}); + auto input = ConstantR1(&b, {-1.0f, 1.0f, 2.1f}); + Map(&b, {input}, tuple_computation, {0}); ComputeAndCompareR1(&b, {-99.0f, 101.0f, 214.41f}, {}, error_spec_); } @@ -317,14 +325,15 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) { std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; - auto tuple12 = builder.Tuple( - {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); - auto tuple21 = builder.Tuple( - {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); - - builder.Select(builder.ConstantR0(true), tuple12, tuple21); - auto expected = Literal::MakeTuple({Literal::CreateR1(vec1).get(), - Literal::CreateR1(vec2).get()}); + auto tuple12 = Tuple(&builder, {ConstantR1(&builder, vec1), + ConstantR1(&builder, vec2)}); + auto tuple21 = Tuple(&builder, {ConstantR1(&builder, vec2), + ConstantR1(&builder, vec1)}); + + Select(ConstantR0(&builder, true), tuple12, tuple21); + auto expected = + LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec1).get(), + LiteralUtil::CreateR1(vec2).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -335,14 +344,13 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) { std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; - auto tuple12 = builder.Tuple( - {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); - auto tuple21 = builder.Tuple( - {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); + auto tuple12 = Tuple(&builder, {ConstantR1(&builder, vec1), + ConstantR1(&builder, vec2)}); + auto tuple21 = Tuple(&builder, {ConstantR1(&builder, vec2), + ConstantR1(&builder, vec1)}); - auto select = - builder.Select(builder.ConstantR0(false), tuple12, tuple21); - builder.GetTupleElement(select, 0); + auto select = Select(ConstantR0(&builder, false), tuple12, tuple21); + GetTupleElement(select, 0); ComputeAndCompareR1(&builder, vec2, {}, error_spec_); } @@ -371,19 +379,16 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesCascaded) { std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; - auto pred_tuple = builder.Tuple( - {builder.ConstantR0(true), builder.ConstantR0(false)}); - auto tuple12 = builder.Tuple( - {builder.ConstantR1(vec1), builder.ConstantR1(vec2)}); - auto tuple21 = builder.Tuple( - {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); + auto pred_tuple = Tuple(&builder, {ConstantR0(&builder, true), + ConstantR0(&builder, false)}); + auto tuple12 = Tuple(&builder, {ConstantR1(&builder, vec1), + ConstantR1(&builder, vec2)}); + auto tuple21 = Tuple(&builder, {ConstantR1(&builder, vec2), + ConstantR1(&builder, vec1)}); - auto select1 = - builder.Select(builder.GetTupleElement(pred_tuple, 0), tuple12, tuple21); - auto select2 = - builder.Select(builder.GetTupleElement(pred_tuple, 1), tuple21, select1); - builder.Add(builder.GetTupleElement(select2, 0), - builder.GetTupleElement(select2, 1)); + auto select1 = Select(GetTupleElement(pred_tuple, 0), tuple12, tuple21); + auto select2 = Select(GetTupleElement(pred_tuple, 1), tuple21, select1); + Add(GetTupleElement(select2, 0), GetTupleElement(select2, 1)); ComputeAndCompareR1(&builder, {3.f, 6.f, 9.f}, {}, error_spec_); } @@ -395,31 +400,32 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) { std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; - auto c1 = builder.ConstantR1(vec1); - auto c2 = builder.ConstantR1(vec2); - auto tuple12 = builder.Tuple({c1, c2}); - auto tuple21 = builder.Tuple({c2, c1}); + auto c1 = ConstantR1(&builder, vec1); + auto c2 = ConstantR1(&builder, vec2); + auto tuple12 = Tuple(&builder, {c1, c2}); + auto tuple21 = Tuple(&builder, {c2, c1}); - builder.Select(builder.ConstantR0(false), tuple12, tuple21); + Select(ConstantR0(&builder, false), tuple12, tuple21); - auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), - Literal::CreateR1(vec1).get()}); + auto expected = + LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), + LiteralUtil::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } XLA_TEST_F(TupleTest, NestedTuples) { XlaBuilder builder(TestName()); - auto inner_tuple = builder.Tuple( - {builder.ConstantR1({1.0, 2.0}), builder.ConstantR0(42.0)}); - builder.Tuple({inner_tuple, builder.ConstantR1({22.0, 44.0})}); + auto inner_tuple = Tuple(&builder, {ConstantR1(&builder, {1.0, 2.0}), + ConstantR0(&builder, 42.0)}); + Tuple(&builder, {inner_tuple, ConstantR1(&builder, {22.0, 44.0})}); - auto expected_v1 = Literal::CreateR1({1.0, 2.0}); - auto expected_s = Literal::CreateR0(42.0); + auto expected_v1 = LiteralUtil::CreateR1({1.0, 2.0}); + auto expected_s = LiteralUtil::CreateR0(42.0); auto expected_inner_tuple = - Literal::MakeTuple({expected_v1.get(), expected_s.get()}); - auto expected_v2 = Literal::CreateR1({22.0, 44.0}); + LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()}); + auto expected_v2 = LiteralUtil::CreateR1({22.0, 44.0}); auto expected = - Literal::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); + LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -432,21 +438,21 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { Shape outer_tuple_shape = ShapeUtil::MakeTupleShape({inner_tuple_shape, data_shape}); - auto input = builder.Parameter(0, outer_tuple_shape, "input"); - auto gte0 = builder.GetTupleElement(input, 0); - auto gte1 = builder.GetTupleElement(gte0, 1); - builder.Add(gte1, builder.ConstantR1({10.0, 11.0, 12.0})); + auto input = Parameter(&builder, 0, outer_tuple_shape, "input"); + auto gte0 = GetTupleElement(input, 0); + auto gte1 = GetTupleElement(gte0, 1); + Add(gte1, ConstantR1(&builder, {10.0, 11.0, 12.0})); std::unique_ptr data = client_ - ->TransferToServer(*Literal::MakeTuple({ - Literal::MakeTuple( + ->TransferToServer(*LiteralUtil::MakeTuple({ + LiteralUtil::MakeTuple( { - Literal::CreateR1({1.0, 2.0, 3.0}).get(), - Literal::CreateR1({4.0, 5.0, 6.0}).get(), + LiteralUtil::CreateR1({1.0, 2.0, 3.0}).get(), + LiteralUtil::CreateR1({4.0, 5.0, 6.0}).get(), }) .get(), - Literal::CreateR1({7.0, 8.0, 9.0}).get(), + LiteralUtil::CreateR1({7.0, 8.0, 9.0}).get(), })) .ConsumeValueOrDie(); @@ -463,25 +469,26 @@ XLA_TEST_F(TupleTest, ComplexTuples) { Shape c64r2 = ShapeUtil::MakeShape(C64, {3, 2}); Shape arg0_shape = ShapeUtil::MakeTupleShape( {c64r0, ShapeUtil::MakeTupleShape({c64r1, c64r2})}); - auto input0 = builder.Parameter(0, arg0_shape, "input0"); - auto t0 = builder.GetTupleElement(input0, 0); - auto t1 = builder.GetTupleElement(input0, 1); - auto t10 = builder.GetTupleElement(t1, 0); - auto t11 = builder.GetTupleElement(t1, 1); - auto sum = builder.Add(builder.Add(t10, t11, {1}), t0); - auto input1 = builder.Parameter(1, c64r1, "input1"); - auto prod = builder.Mul(input1, sum, {1}); - builder.Tuple({builder.Tuple({prod, sum}), - builder.ConstantR0({123, 456})}); + auto input0 = Parameter(&builder, 0, arg0_shape, "input0"); + auto t0 = GetTupleElement(input0, 0); + auto t1 = GetTupleElement(input0, 1); + auto t10 = GetTupleElement(t1, 0); + auto t11 = GetTupleElement(t1, 1); + auto sum = Add(Add(t10, t11, {1}), t0); + auto input1 = Parameter(&builder, 1, c64r1, "input1"); + auto prod = Mul(input1, sum, {1}); + Tuple(&builder, {Tuple(&builder, {prod, sum}), + ConstantR0(&builder, {123, 456})}); } std::unique_ptr arg0 = client_ - ->TransferToServer(*Literal::MakeTuple( - {Literal::CreateR0({1, 2}).get(), - Literal::MakeTuple( - {Literal::CreateR1({{10, 20}, {30, 40}}).get(), - Literal::CreateR2( + ->TransferToServer(*LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0({1, 2}).get(), + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1({{10, 20}, {30, 40}}) + .get(), + LiteralUtil::CreateR2( {{{100, 200}, {300, 400}}, {{1000, 2000}, {3000, 4000}}, {{10000, 20000}, {30000, 40000}}}) @@ -490,11 +497,13 @@ XLA_TEST_F(TupleTest, ComplexTuples) { .ConsumeValueOrDie(); std::unique_ptr arg1 = client_ - ->TransferToServer(*Literal::CreateR1({{1, 2}, {1, -2}})) + ->TransferToServer( + *LiteralUtil::CreateR1({{1, 2}, {1, -2}})) .ConsumeValueOrDie(); - auto sum = Literal::CreateR2({{{111, 222}, {331, 442}}, - {{1011, 2022}, {3031, 4042}}, - {{10011, 20022}, {30031, 40042}}}); + auto sum = + LiteralUtil::CreateR2({{{111, 222}, {331, 442}}, + {{1011, 2022}, {3031, 4042}}, + {{10011, 20022}, {30031, 40042}}}); auto prod = MakeUnique(sum->shape()); ASSERT_TRUE(prod->Populate( [&sum](tensorflow::gtl::ArraySlice indexes) { @@ -504,9 +513,9 @@ XLA_TEST_F(TupleTest, ComplexTuples) { : complex64(1, -2)); }) .ok()); - auto expected = - Literal::MakeTuple({Literal::MakeTuple({prod.get(), sum.get()}).get(), - Literal::CreateR0({123, 456}).get()}); + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::MakeTuple({prod.get(), sum.get()}).get(), + LiteralUtil::CreateR0({123, 456}).get()}); ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()}, error_spec_); } @@ -529,11 +538,58 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::MakeTupleOwned(Literal::CreateR1({1, 2, 3})); + auto param = + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({1, 2, 3})); auto result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, - *Literal::MakeTupleOwned(Literal::CreateR2({{1, 2, 3}})))); + *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2({{1, 2, 3}})), + *result)); +} + +// Disabled on interpreter due to lack of outfeed. +XLA_TEST_F(TupleHloTest, + DISABLED_ON_INTERPRETER(NonAmbiguousTopLevelAllocation)) { + const char* testcase = R"( + HloModule tuple + + ENTRY main { + a = f32[2] parameter(0) + b = f32[2] parameter(1) + c = f32[2] parameter(2) + d = f32[2] parameter(3) + cond = pred[] parameter(4) + + tup0 = (f32[2],f32[2]) tuple(a, b) + tup1 = (f32[2],f32[2]) tuple(c, d) + + s = (f32[2],f32[2]) tuple-select(cond, tup0, tup1) + gte = f32[2] get-tuple-element(s), index=0 + tuple = (f32[2]) tuple(gte) + token = token[] after-all() + ROOT outfeed = token[] outfeed(tuple, token) + } + )"; + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param0 = LiteralUtil::CreateR1({1, 2}); + auto param1 = LiteralUtil::CreateR1({2, 3}); + auto param4 = LiteralUtil::CreateR0(false); + // Put execution on a separate thread so we can block on outfeed. + std::unique_ptr thread( + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions(), "execute_thread", [&] { + TF_EXPECT_OK(Execute(std::move(module), + {param0.get(), param1.get(), param1.get(), + param0.get(), param4.get()}) + .status()); + })); + auto expected = + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({2, 3})); + auto literal = Literal::CreateFromShape(expected->shape()); + TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( + backend().default_stream_executor(), expected->shape(), *literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *literal)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index c3abe22797f5eaa76ced2ad8534bd68c32983e60..20ae68ab74026936c43e5f525eb796eb402a19cb 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -38,8 +38,8 @@ class UnaryOpTest : public ClientLibraryTestBase { template void AbsSize0TestHelper() { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1({}); - auto abs = builder.Abs(arg); + auto arg = ConstantR1(&builder, {}); + Abs(arg); if (primitive_util::NativeToPrimitiveType() == C64) { ComputeAndCompareR1(&builder, {}, {}); @@ -51,8 +51,8 @@ class UnaryOpTest : public ClientLibraryTestBase { template void AbsTestHelper() { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1({-2, 25, 0, -123, inf(), -inf()}); - auto abs = builder.Abs(arg); + auto arg = ConstantR1(&builder, {-2, 25, 0, -123, inf(), -inf()}); + Abs(arg); ComputeAndCompareR1(&builder, {2, 25, 0, 123, inf(), inf()}, {}); } @@ -60,9 +60,9 @@ class UnaryOpTest : public ClientLibraryTestBase { template void SignTestHelper() { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1( - {-2, 25, 0, static_cast(-0.0), -123, inf(), -inf()}); - auto sign = builder.Sign(arg); + auto arg = ConstantR1( + &builder, {-2, 25, 0, static_cast(-0.0), -123, inf(), -inf()}); + Sign(arg); ComputeAndCompareR1(&builder, {-1, 1, 0, 0, -1, 1, -1}, {}); } @@ -70,10 +70,10 @@ class UnaryOpTest : public ClientLibraryTestBase { template void SignAbsTestHelper() { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1({-2, 25, 0, -123}); - auto sign = builder.Sign(arg); - auto abs = builder.Abs(arg); - builder.Sub(builder.Mul(sign, abs), arg); + auto arg = ConstantR1(&builder, {-2, 25, 0, -123}); + auto sign = Sign(arg); + auto abs = Abs(arg); + Sub(Mul(sign, abs), arg); ComputeAndCompareR1(&builder, {0, 0, 0, 0}, {}); } @@ -92,27 +92,28 @@ int64 UnaryOpTest::inf() { template <> void UnaryOpTest::AbsTestHelper() { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1({{-2, 0}, - {0, 25}, - {0, 0}, - {-0.3f, 0.4f}, - {0, inf()}, - {-inf(), 0}}); - auto abs = builder.Abs(arg); + auto arg = ConstantR1(&builder, {{-2, 0}, + {0, 25}, + {0, 0}, + {-0.3f, 0.4f}, + {0, inf()}, + {-inf(), 0}}); + Abs(arg); std::unique_ptr expected = - Literal::CreateR1({2, 25, 0, 0.5, inf(), inf()}); + LiteralUtil::CreateR1({2, 25, 0, 0.5, inf(), inf()}); ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); } template <> void UnaryOpTest::SignTestHelper() { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1( + auto arg = ConstantR1( + &builder, {{-2, 0}, {0, 25}, {0, 0}, {static_cast(-0.0), 0}, {-1, 1}}); - auto sign = builder.Sign(arg); + Sign(arg); - std::unique_ptr expected = Literal::CreateR1( + std::unique_ptr expected = LiteralUtil::CreateR1( {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}}); ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); } @@ -121,13 +122,13 @@ template <> void UnaryOpTest::SignAbsTestHelper() { XlaBuilder builder(TestName()); auto arg = - builder.ConstantR1({{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}}); - auto sign = builder.Sign(arg); - auto abs = builder.Abs(arg); - builder.Sub(builder.Mul(sign, builder.ConvertElementType(abs, C64)), arg); + ConstantR1(&builder, {{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}}); + auto sign = Sign(arg); + auto abs = Abs(arg); + Sub(Mul(sign, ConvertElementType(abs, C64)), arg); std::unique_ptr expected = - Literal::CreateR1({0, 0, 0, 0}); + LiteralUtil::CreateR1({0, 0, 0, 0}); ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); } @@ -145,37 +146,34 @@ XLA_TEST_F(UnaryOpTest, AbsTestR1) { XLA_TEST_F(UnaryOpTest, AbsTestR0) { XlaBuilder builder(TestName()); - auto argi = builder.ConstantR0(-5); - auto absi = builder.Abs(argi); - auto argf = builder.ConstantR0(-3.0f); - auto absf = builder.Abs(argf); - auto argf0 = builder.ConstantR0(-0.0f); - auto absf0 = builder.Abs(argf0); - auto argc = builder.ConstantR0({-0.3f, 0.4f}); - auto absc = builder.Abs(argc); - builder.Add(builder.Add(absc, absf0), - builder.Add(absf, builder.ConvertElementType(absi, F32))); + auto argi = ConstantR0(&builder, -5); + auto absi = Abs(argi); + auto argf = ConstantR0(&builder, -3.0f); + auto absf = Abs(argf); + auto argf0 = ConstantR0(&builder, -0.0f); + auto absf0 = Abs(argf0); + auto argc = ConstantR0(&builder, {-0.3f, 0.4f}); + auto absc = Abs(argc); + Add(Add(absc, absf0), Add(absf, ConvertElementType(absi, F32))); ComputeAndCompareR0(&builder, 8.5f, {}); } XLA_TEST_F(UnaryOpTest, SignTestR0) { XlaBuilder builder(TestName()); - auto argi = builder.ConstantR0(-5); - auto sgni = builder.Sign(argi); // -1 - auto argf = builder.ConstantR0(-4.0f); - auto sgnf = builder.Sign(argf); // -1 - auto argf0 = builder.ConstantR0(-0.0f); - auto sgnf0 = builder.Sign(argf0); // 0 - auto argc = builder.ConstantR0({-.3, .4}); - auto sgnc = builder.Sign(argc); // (-.6, .8) - builder.Add(sgnc, builder.ConvertElementType( - builder.Add(builder.Add(sgnf0, sgnf), - builder.ConvertElementType(sgni, F32)), - C64)); + auto argi = ConstantR0(&builder, -5); + auto sgni = Sign(argi); // -1 + auto argf = ConstantR0(&builder, -4.0f); + auto sgnf = Sign(argf); // -1 + auto argf0 = ConstantR0(&builder, -0.0f); + auto sgnf0 = Sign(argf0); // 0 + auto argc = ConstantR0(&builder, {-.3, .4}); + auto sgnc = Sign(argc); // (-.6, .8) + Add(sgnc, ConvertElementType( + Add(Add(sgnf0, sgnf), ConvertElementType(sgni, F32)), C64)); std::unique_ptr expected = - Literal::CreateR0({-2.6f, 0.8f}); + LiteralUtil::CreateR0({-2.6f, 0.8f}); ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); } @@ -194,9 +192,9 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR1) { XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1( - {2, 25, 0, 123, std::numeric_limits::max()}); - auto abs = builder.Abs(arg); + auto arg = ConstantR1( + &builder, {2, 25, 0, 123, std::numeric_limits::max()}); + Abs(arg); ComputeAndCompareR1( &builder, {2, 25, 0, 123, std::numeric_limits::max()}, {}); @@ -204,37 +202,37 @@ XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR1( - {2, 25, 0, 123, std::numeric_limits::max()}); - auto sign = builder.Sign(arg); + auto arg = ConstantR1( + &builder, {2, 25, 0, 123, std::numeric_limits::max()}); + Sign(arg); ComputeAndCompareR1(&builder, {1, 1, 0, 1, 1}, {}); } XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { XlaBuilder builder(TestName()); - auto arg = builder.ConstantR2({{1.0, -2.0}, {-3.0, 4.0}}); - auto sign = builder.Sign(arg); - auto abs = builder.Abs(arg); - builder.Sub(builder.Mul(sign, abs), arg); + auto arg = ConstantR2(&builder, {{1.0, -2.0}, {-3.0, 4.0}}); + auto sign = Sign(arg); + auto abs = Abs(arg); + Sub(Mul(sign, abs), arg); ComputeAndCompareR2(&builder, {{0, 0}, {0, 0}}, {}); } XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 1}); - auto rhs = builder.ConstantR1({1, 1}); - builder.ConvertElementType(builder.Eq(lhs, rhs), S32); + auto lhs = ConstantR1(&builder, {0, 1}); + auto rhs = ConstantR1(&builder, {1, 1}); + ConvertElementType(Eq(lhs, rhs), S32); ComputeAndCompareR1(&builder, {0, 1}, {}); } XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToF32) { XlaBuilder builder(TestName()); - auto lhs = builder.ConstantR1({0, 1}); - auto rhs = builder.ConstantR1({1, 1}); - builder.ConvertElementType(builder.Eq(lhs, rhs), F32); + auto lhs = ConstantR1(&builder, {0, 1}); + auto rhs = ConstantR1(&builder, {1, 1}); + ConvertElementType(Eq(lhs, rhs), F32); ComputeAndCompareR1(&builder, {0.0, 1.0}, {}); } diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc index 82d301983fc7885ef5c1c1ed05b74fc017bb7727..ef1b1445bbe555da00db4446d59439b752735a80 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -46,7 +46,7 @@ class VecOpsReduceTest : public ClientLibraryTestBase { {{1.0, 2.0, 3.0}, // } plane 2 in dim 0 {4.0, 5.0, 6.0}}}); // clang-format on - return builder_.ConstantR3FromArray3D(x3d); + return ConstantR3FromArray3D(&builder_, x3d); } XlaBuilder builder_; @@ -56,11 +56,10 @@ class VecOpsReduceTest : public ClientLibraryTestBase { TEST_F(VecOpsReduceTest, AddReduceR1F32) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); - auto x = builder_.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0}); + auto x = ConstantR1( + &builder_, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); ComputeAndCompareR0(&builder_, -4.2f, {}, errspec_); } @@ -71,10 +70,9 @@ TEST_F(VecOpsReduceTest, AddReduceBigR1F32) { std::vector input(3000); std::iota(input.begin(), input.end(), 100.0f); - auto x = builder_.ConstantR1(input); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0}); + auto x = ConstantR1(&builder_, input); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); float expected = std::accumulate(input.begin(), input.end(), 0.0f); ComputeAndCompareR0(&builder_, expected, {}, errspec_); @@ -83,11 +81,10 @@ TEST_F(VecOpsReduceTest, AddReduceBigR1F32) { TEST_F(VecOpsReduceTest, MaxReduceR1F32) { auto max_reducer = CreateScalarMax(); - auto x = builder_.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto max_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), max_reducer, - /*dimensions_to_reduce=*/{0}); + auto x = ConstantR1( + &builder_, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Reduce(x, ConstantR0(&builder_, 0.0f), max_reducer, + /*dimensions_to_reduce=*/{0}); ComputeAndCompareR0(&builder_, 2.6f, {}, errspec_); } @@ -95,11 +92,10 @@ TEST_F(VecOpsReduceTest, MaxReduceR1F32) { TEST_F(VecOpsReduceTest, MaxReduceR1F32WithNontrivialInit) { auto max_reducer = CreateScalarMax(); - auto x = builder_.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto max_reduce = - builder_.Reduce(x, builder_.ConstantR0(4.0f), max_reducer, - /*dimensions_to_reduce=*/{0}); + auto x = ConstantR1( + &builder_, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Reduce(x, ConstantR0(&builder_, 4.0f), max_reducer, + /*dimensions_to_reduce=*/{0}); ComputeAndCompareR0(&builder_, 4.0f, {}, errspec_); } @@ -108,15 +104,14 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim1) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); // clang-format off - auto x = builder_.ConstantR2({ + auto x = ConstantR2(&builder_, { {1.0, 2.0, 3.0}, // | dim 0 {4.0, 5.0, 6.0}}); // | // ------ dim 1 ---------- // clang-format on - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{1}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{1}); ComputeAndCompareR1(&builder_, {6.0, 15.0}, {}, errspec_); } @@ -125,13 +120,12 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim0) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); // clang-format off - auto x = builder_.ConstantR2({ + auto x = ConstantR2(&builder_, { {1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}); // clang-format on - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); ComputeAndCompareR1(&builder_, {5.0, 7.0, 9.0}, {}, errspec_); } @@ -139,9 +133,8 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim0) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dim2) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{2}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{2}); Array2D expected_array({{6.0f, 15.0f}, {6.0f, 15.0f}, {6.0f, 15.0f}}); @@ -151,9 +144,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim2) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dim1) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{1}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{1}); Array2D expected_array( {{5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}}); @@ -164,9 +156,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim1) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dim0) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0}); Array2D expected_array({{3.0f, 6.0f, 9.0f}, {12.0f, 15.0f, 18.0f}}); @@ -176,9 +167,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim0) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dims1and2) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{1, 2}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{1, 2}); ComputeAndCompareR1(&builder_, {21.0, 21.0, 21.0}, {}, errspec_); } @@ -186,9 +176,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dims1and2) { XLA_TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and2) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0, 2}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0, 2}); ComputeAndCompareR1(&builder_, {18.0, 45.0}, {}, errspec_); } @@ -196,9 +185,8 @@ XLA_TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and2) { TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and1) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0, 1}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0, 1}); ComputeAndCompareR1(&builder_, {15.0, 21.0, 27.0}, {}, errspec_); } @@ -206,9 +194,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and1) { TEST_F(VecOpsReduceTest, AddReduceR3F32AllDims) { auto sum_reducer = CreateScalarAddComputation(F32, &builder_); auto x = BuildSampleConstantCube(); - auto add_reduce = - builder_.Reduce(x, builder_.ConstantR0(0.0f), sum_reducer, - /*dimensions_to_reduce=*/{0, 1, 2}); + Reduce(x, ConstantR0(&builder_, 0.0f), sum_reducer, + /*dimensions_to_reduce=*/{0, 1, 2}); ComputeAndCompareR0(&builder_, 63.0, {}, errspec_); } diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index 5cce7a2bf82c1a8403536a91e67910f949ef185a..3848ec1684cdc9186e14ac0b60315b7520d127f3 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -50,9 +50,9 @@ class VecOpsSimpleTest : public ClientLibraryTestBase { XLA_TEST_F(VecOpsSimpleTest, ExpTenValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto exp = builder.Exp(x); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Exp(x); std::vector expected = {8.1662, 7.4274e-02, 13.4637, 1.8316e-02, 8.1662, 9.9742, 6.7379e-03, 4.0657e-01, @@ -69,8 +69,8 @@ XLA_TEST_F(VecOpsSimpleTest, ExpManyValues) { for (int i = 0; i < count; ++i) { exponents.push_back(i / static_cast(count)); } - auto x = builder.ConstantR1(exponents); - auto exp = builder.Exp(x); + auto x = ConstantR1(&builder, exponents); + Exp(x); std::vector expected; expected.reserve(exponents.size()); @@ -98,8 +98,8 @@ XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) { Array4D expected(2, 2, 2, 2, expected_vector); - auto x = builder.ConstantR4FromArray4D(exponents); - auto exp = builder.Exp(x); + auto x = ConstantR4FromArray4D(&builder, exponents); + Exp(x); ComputeAndCompareR4(&builder, expected, {}, ErrorSpec(/*aabs=*/1e-2, /*arel=*/1e-3)); @@ -107,9 +107,9 @@ XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) { XLA_TEST_F(VecOpsSimpleTest, NegateTenFloatValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - builder.Neg(x); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Neg(x); std::vector expected = {-2.1, 2.6, -2.6, 4.0, -2.1, -2.3, 5.0, 0.9, 2.4, -1.6}; @@ -118,8 +118,8 @@ XLA_TEST_F(VecOpsSimpleTest, NegateTenFloatValues) { XLA_TEST_F(VecOpsSimpleTest, NegateTenInt32Values) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({2, -2, 12, -4, 5, 20, -15, 0, -2, 1}); - builder.Neg(x); + auto x = ConstantR1(&builder, {2, -2, 12, -4, 5, 20, -15, 0, -2, 1}); + Neg(x); std::vector expected = {-2, 2, -12, 4, -5, -20, 15, 0, 2, -1}; ComputeAndCompareR1(&builder, expected, {}); @@ -127,59 +127,19 @@ XLA_TEST_F(VecOpsSimpleTest, NegateTenInt32Values) { XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {0, 1, 42, static_cast(-1), static_cast(-12)}); - builder.Neg(x); + auto x = ConstantR1( + &builder, {0, 1, 42, static_cast(-1), static_cast(-12)}); + Neg(x); std::vector expected = {0, static_cast(-1), static_cast(-42), 1, 12}; ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(VecOpsSimpleTest, SquareTenValues) { - XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - builder.SquareF32(x); - - std::vector expected = {4.41, 6.76, 6.76, 16., 4.41, - 5.29, 25., 0.81, 5.76, 2.56}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); -} - -XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) { - XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - builder.ReciprocalF32(x); - - std::vector expected = { - 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048, - 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); -} - -XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) { - XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({0.0, -0.0}); - auto exp = builder.SqrtF32(x); - - ComputeAndCompareR1(&builder, {0, 0}, {}, error_spec_); -} - -XLA_TEST_F(VecOpsSimpleTest, SqrtSixValues) { - XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({16.0, 1.0, 1024.0, 0.16, 0.2, 12345}); - auto exp = builder.SqrtF32(x); - - std::vector expected = {4, 1, 32, 0.4, 0.4472, 111.1080}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); -} - XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) { XlaBuilder builder(TestName()); - auto x = - builder.ConstantR1({16.0, 1.0, 1024.0, 0.16, 0.2, 12345, 1.2345}); - auto exp = builder.Pow(x, builder.ConstantR0(-.5f)); + auto x = ConstantR1(&builder, + {16.0, 1.0, 1024.0, 0.16, 0.2, 12345, 1.2345}); + Pow(x, ConstantR0(&builder, -.5f)); std::vector expected = {.25, 1, .03125, 2.5, 2.23607, .009000, .900025}; @@ -191,11 +151,11 @@ XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) { XlaBuilder builder(TestName()); auto add = CreateScalarAddComputation(F32, &builder); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto y = builder.ConstantR1( - {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); - auto max = builder.Map({x, y}, add, {0}); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto y = ConstantR1( + &builder, {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); + Map(&builder, {x, y}, add, {0}); std::vector expected = {1.7, -3.2, -0.4, -3.8, 5.9, 0.1, -6.8, 4., -1., 2.2}; @@ -204,11 +164,11 @@ XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) { XLA_TEST_F(VecOpsSimpleTest, MaxTenValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto y = builder.ConstantR1( - {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); - auto max = builder.Max(x, y); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto y = ConstantR1( + &builder, {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); + Max(x, y); std::vector expected = {2.1, -0.6, 2.6, 0.2, 3.8, 2.3, -1.8, 4.9, 1.4, 1.6}; @@ -227,7 +187,7 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) { {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); - auto max = builder.Max(v1, v2); + Max(v1, v2); ComputeAndCompareR1(&builder, {41.0f, 22.0f, 23.0f, 84.0f}, {param0_data.get(), param1_data.get()}, error_spec_); @@ -267,7 +227,7 @@ XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) { CreateR1Parameter(v2vec, /*parameter_number=*/1, /*name=*/"v2", /*builder=*/&builder, /*data_handle=*/&v2); - auto max = builder.Max(v1, v2); + Max(v1, v2); ComputeAndCompareR1(&builder, expected_vec, {param0_data.get(), param1_data.get()}, error_spec_); @@ -275,10 +235,10 @@ XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) { XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto y = builder.ConstantR0(0); - auto max = builder.Max(x, y); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto y = ConstantR0(&builder, 0); + Max(x, y); std::vector expected = {2.1, 0.0, 2.6, 0.0, 2.1, 2.3, 0.0, 0.0, 0.0, 1.6}; @@ -287,11 +247,11 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) { XLA_TEST_F(VecOpsSimpleTest, MinTenValues) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto y = builder.ConstantR1( - {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); - auto min = builder.Min(x, y); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + auto y = ConstantR1( + &builder, {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); + Min(x, y); std::vector expected = {-0.4, -2.6, -3.0, -4.0, 2.1, -2.2, -5.0, -0.9, -2.4, 0.6}; @@ -300,11 +260,11 @@ XLA_TEST_F(VecOpsSimpleTest, MinTenValues) { XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) { XlaBuilder builder(TestName()); - auto zero = builder.ConstantR0(0); - auto one = builder.ConstantR0(1); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); - auto clamp = builder.Min(builder.Max(x, zero), one); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); + Min(Max(x, zero), one); std::vector expected = {1.0, 0.0, 1.0, 0.3, 1.0, 0.9, 0.0, 0.1, 0.0, 0.6}; @@ -313,11 +273,11 @@ XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) { XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) { XlaBuilder builder(TestName()); - auto zero = builder.ConstantR0(0); - auto one = builder.ConstantR0(1); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); - auto clamp = builder.Clamp(zero, x, one); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 1); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); + Clamp(zero, x, one); std::vector expected = {1.0, 0.0, 1.0, 0.3, 1.0, 0.9, 0.0, 0.1, 0.0, 0.6}; @@ -326,10 +286,10 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) { XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) { XlaBuilder builder(TestName()); - auto zero = builder.ConstantR1({0.0f, 0.0f}); - auto one = builder.ConstantR1({1.0f, 1.0f}); - auto x = builder.ConstantR1({2.1, -2.6}); - auto clamp = builder.Clamp(zero, x, one); + auto zero = ConstantR1(&builder, {0.0f, 0.0f}); + auto one = ConstantR1(&builder, {1.0f, 1.0f}); + auto x = ConstantR1(&builder, {2.1, -2.6}); + Clamp(zero, x, one); std::vector expected = {1.0, 0.0}; ComputeAndCompareR1(&builder, expected, {}); @@ -337,11 +297,11 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) { XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { XlaBuilder builder(TestName()); - auto one = builder.ConstantR0(1); - auto two = builder.ConstantR0(2); - auto x = builder.ConstantR1( - {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); - auto clamp = builder.Clamp(one, x, two); + auto one = ConstantR0(&builder, 1); + auto two = ConstantR0(&builder, 2); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6}); + Clamp(one, x, two); std::vector expected = {2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0}; @@ -350,10 +310,10 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { XLA_TEST_F(VecOpsSimpleTest, ClampValuesConstantS64) { XlaBuilder builder(TestName()); - auto zero = builder.ConstantR0(0); - auto one = builder.ConstantR0(10); - auto x = builder.ConstantR1({-3, 3, 9, 13}); - auto clamp = builder.Clamp(zero, x, one); + auto zero = ConstantR0(&builder, 0); + auto one = ConstantR0(&builder, 10); + auto x = ConstantR1(&builder, {-3, 3, 9, 13}); + Clamp(zero, x, one); std::vector expected = {0, 3, 9, 10}; ComputeAndCompareR1(&builder, expected, {}); @@ -365,9 +325,9 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { // add_half(x) = x + 0.5 XlaBuilder builder("add_half"); auto x_value = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value"); - auto half = builder.ConstantR0(0.5); - builder.Add(x_value, half); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x_value"); + auto half = ConstantR0(&builder, 0.5); + Add(x_value, half); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); add_half = computation_status.ConsumeValueOrDie(); @@ -378,9 +338,9 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { // clamp(y) = clamp<0,5>(y) XlaBuilder builder("clamp"); auto y_value = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y_value"); - auto zero = builder.ConstantR0(0.0); - auto clamped = builder.Clamp(zero, y_value, builder.ConstantR0(5)); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "y_value"); + auto zero = ConstantR0(&builder, 0.0); + Clamp(zero, y_value, ConstantR0(&builder, 5)); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); clamp = computation_status.ConsumeValueOrDie(); @@ -391,13 +351,13 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { // mult_relu_add(z) = clamp(add_half(2 * max(z, 0))) XlaBuilder builder("mult_relu_add"); auto z_value = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value"); - auto zero = builder.ConstantR0(0.0); - auto two = builder.ConstantR0(2.0); - auto max = builder.Max(z_value, zero); - auto mult = builder.Mul(two, max); - auto inner = builder.Map({mult}, add_half, {}); - builder.Map({inner}, clamp, {}); + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "z_value"); + auto zero = ConstantR0(&builder, 0.0); + auto two = ConstantR0(&builder, 2.0); + auto max = Max(z_value, zero); + auto mult = Mul(two, max); + auto inner = Map(&builder, {mult}, add_half, {}); + Map(&builder, {inner}, clamp, {}); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); mult_relu_add = computation_status.ConsumeValueOrDie(); @@ -405,9 +365,9 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { XlaBuilder builder("map10"); { - auto x = builder.ConstantR1( - {2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto activations = builder.Map({x}, mult_relu_add, {0}); + auto x = ConstantR1( + &builder, {2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Map(&builder, {x}, mult_relu_add, {0}); } std::vector expected = {4.7, 0.5, 5.0, 0.5, 4.7, @@ -417,9 +377,9 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({-5, -4, -3, -2, -1, 0, 1, 2, 3, 4}); - auto y = builder.ConstantR0(3); - builder.Rem(x, y); + auto x = ConstantR1(&builder, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto y = ConstantR0(&builder, 3); + Rem(x, y); std::vector expected = {-2, -1, 0, -2, -1, 0, 1, 2, 0, 1}; ComputeAndCompareR1(&builder, expected, {}); @@ -427,9 +387,9 @@ XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) { XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({false, true}); - auto y = builder.ConstantR1({true, false}); - builder.Eq(x, y); + auto x = ConstantR1(&builder, {false, true}); + auto y = ConstantR1(&builder, {true, false}); + Eq(x, y); std::array expected = {{false, false}}; ComputeAndCompareR1(&builder, expected, {}); @@ -437,9 +397,9 @@ XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) { XLA_TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) { XlaBuilder builder(TestName()); - auto x = builder.ConstantR1({false, true}); - auto y = builder.ConstantR1({true, false}); - builder.Ne(x, y); + auto x = ConstantR1(&builder, {false, true}); + auto y = ConstantR1(&builder, {true, false}); + Ne(x, y); std::array expected = {{true, true}}; ComputeAndCompareR1(&builder, expected, {}); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index c463f3eac55e5b8ab32dc52d5a38e7840241bc58..1bdf1867b9330b715b0ba4aca71d56307883c775 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -55,8 +55,8 @@ TEST_F(WhileTest, WhileWithScalarS32Result) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Gt(builder.ConstantR0(5), prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Gt(ConstantR0(&builder, 5), prev); condition = builder.Build().ConsumeValueOrDie(); } @@ -64,16 +64,16 @@ TEST_F(WhileTest, WhileWithScalarS32Result) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR0(1); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR0(&builder, 1); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder(TestName()); - auto init = builder.ConstantR0(0); - builder.While(condition, body, init); + auto init = ConstantR0(&builder, 0); + While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -91,8 +91,8 @@ TEST_F(WhileTest, WhileWithScalarS64Result) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Gt(builder.ConstantR0(5), prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Gt(ConstantR0(&builder, 5), prev); condition = builder.Build().ConsumeValueOrDie(); } @@ -100,16 +100,16 @@ TEST_F(WhileTest, WhileWithScalarS64Result) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR0(1); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR0(&builder, 1); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder(TestName()); - auto init = builder.ConstantR0(0); - builder.While(condition, body, init); + auto init = ConstantR0(&builder, 0); + While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -122,8 +122,8 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Gt(builder.ConstantR0(5), prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Gt(ConstantR0(&builder, 5), prev); condition = builder.Build().ConsumeValueOrDie(); } @@ -131,18 +131,18 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR0(1); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR0(&builder, 1); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder(TestName()); - auto init = builder.Reduce(builder.ConstantR1(2, 1), - builder.ConstantR0(0), - CreateScalarAddComputation(S32, &builder), {0}); - builder.While(condition, body, init); + auto init = + Reduce(ConstantR1(&builder, 2, 1), ConstantR0(&builder, 0), + CreateScalarAddComputation(S32, &builder), {0}); + While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -154,8 +154,8 @@ TEST_F(WhileTest, WhileWithPredicateResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Ne(builder.ConstantR0(true), prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Ne(ConstantR0(&builder, true), prev); condition = builder.Build().ConsumeValueOrDie(); } @@ -163,16 +163,16 @@ TEST_F(WhileTest, WhileWithPredicateResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Or(prev, builder.ConstantR0(true)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Or(prev, ConstantR0(&builder, true)); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder(TestName()); - auto init = builder.Ne(builder.ConstantR0(false), - builder.ConstantR0(true)); - builder.While(condition, body, init); + auto init = + Ne(ConstantR0(&builder, false), ConstantR0(&builder, true)); + While(condition, body, init); ComputeAndCompareR0(&builder, true, {}); } @@ -184,17 +184,16 @@ TEST_F(WhileTest, WhileWithPredicateResult) { // while (result.sum() < 15.5f) { // result = result + vector(0); // } -// TODO(b/29185393): does not terminate on CPU. -TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { +TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) { Shape result_shape = ShapeUtil::MakeShape(F32, {0}); // Create a computation for the reduction. XlaComputation add; { XlaBuilder builder("add"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Add(x, y); add = builder.Build().ConsumeValueOrDie(); } @@ -203,10 +202,10 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, - /*dimensions_to_reduce=*/{0}); - builder.Gt(builder.ConstantR0(15.5f), sum); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto sum = Reduce(prev, ConstantR0(&builder, 0.0f), add, + /*dimensions_to_reduce=*/{0}); + Gt(ConstantR0(&builder, 15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } @@ -215,16 +214,16 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR1({}); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR1(&builder, {}); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.ConstantR1({}); - auto result = builder.While(condition, body, init); + auto init = ConstantR1(&builder, {}); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -246,9 +245,9 @@ TEST_F(WhileTest, WhileWithVectorResult) { XlaComputation add; { XlaBuilder builder("add"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Add(x, y); add = builder.Build().ConsumeValueOrDie(); } @@ -257,10 +256,10 @@ TEST_F(WhileTest, WhileWithVectorResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, - /*dimensions_to_reduce=*/{0}); - builder.Gt(builder.ConstantR0(15.5f), sum); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto sum = Reduce(prev, ConstantR0(&builder, 0.0f), add, + /*dimensions_to_reduce=*/{0}); + Gt(ConstantR0(&builder, 15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } @@ -269,16 +268,16 @@ TEST_F(WhileTest, WhileWithVectorResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR1(8, 0.125f); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR1(&builder, 8, 0.125f); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.ConstantR1(8, 0.f); - auto result = builder.While(condition, body, init); + auto init = ConstantR1(&builder, 8, 0.f); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -306,9 +305,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { XlaComputation add; { XlaBuilder builder("add"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - builder.Add(x, y); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y"); + Add(x, y); add = builder.Build().ConsumeValueOrDie(); } @@ -317,10 +316,10 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, - /*dimensions_to_reduce=*/{0}); - builder.Gt(builder.ConstantR0(15.5f), sum); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto sum = Reduce(prev, ConstantR0(&builder, 0.0f), add, + /*dimensions_to_reduce=*/{0}); + Gt(ConstantR0(&builder, 15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } @@ -329,27 +328,27 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR1(8, 0.125f); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR1(&builder, 8, 0.125f); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.ConstantR1(8, 0.f); - auto result = builder.While(condition, body, init); + auto init = ConstantR1(&builder, 8, 0.f); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - builder.Tuple({result}); + Tuple(&builder, {result}); // Individual elements with increase by 1/8 each time through the loop, so // the sum will increase by 1.0. It will first be >15.5 when the elements // have all reached 2.0. auto expected_data = - Literal::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); - auto expected = Literal::MakeTuple({expected_data.get()}); + LiteralUtil::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); + auto expected = LiteralUtil::MakeTuple({expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -366,9 +365,9 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Gt(builder.ConstantR0(N), iteration); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Gt(ConstantR0(&builder, N), iteration); condition = builder.Build().ConsumeValueOrDie(); } @@ -377,32 +376,34 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto w1 = builder.GetTupleElement(prev, 1); - auto w2 = builder.GetTupleElement(prev, 2); - auto w3 = builder.GetTupleElement(prev, 3); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto w1 = GetTupleElement(prev, 1); + auto w2 = GetTupleElement(prev, 2); + auto w3 = GetTupleElement(prev, 3); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), w3, w1, w2}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), - builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); - auto result = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 3, 1.f), + ConstantR1(&builder, 3, 2.f), + ConstantR1(&builder, 3, 3.f)}); + auto result = While(condition, body, init); VLOG(2) << "result = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0(N); - auto expected_w1 = Literal::CreateR1({1.0f, 1.0f, 1.0f}); - auto expected_w2 = Literal::CreateR1({2.0f, 2.0f, 2.0f}); - auto expected_w3 = Literal::CreateR1({3.0f, 3.0f, 3.0f}); - auto expected = Literal::MakeTuple({expected_counter.get(), expected_w2.get(), - expected_w3.get(), expected_w1.get()}); + auto expected_counter = LiteralUtil::CreateR0(N); + auto expected_w1 = LiteralUtil::CreateR1({1.0f, 1.0f, 1.0f}); + auto expected_w2 = LiteralUtil::CreateR1({2.0f, 2.0f, 2.0f}); + auto expected_w3 = LiteralUtil::CreateR1({3.0f, 3.0f, 3.0f}); + auto expected = + LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(), + expected_w3.get(), expected_w1.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -419,9 +420,9 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Gt(builder.ConstantR0(N), iteration); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Gt(ConstantR0(&builder, N), iteration); condition = builder.Build().ConsumeValueOrDie(); } @@ -430,26 +431,27 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto w1 = builder.GetTupleElement(prev, 1); - auto w2 = builder.GetTupleElement(prev, 2); - auto w3 = builder.GetTupleElement(prev, 3); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto w1 = GetTupleElement(prev, 1); + auto w2 = GetTupleElement(prev, 2); + auto w3 = GetTupleElement(prev, 3); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), w3, w1, w2}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), - builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); - auto xla_while = builder.While(condition, body, init); - - auto add12 = builder.Add(builder.GetTupleElement(xla_while, 1), - builder.GetTupleElement(xla_while, 2)); - auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3)); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 3, 1.f), + ConstantR1(&builder, 3, 2.f), + ConstantR1(&builder, 3, 3.f)}); + auto xla_while = While(condition, body, init); + + auto add12 = + Add(GetTupleElement(xla_while, 1), GetTupleElement(xla_while, 2)); + auto result = Add(add12, GetTupleElement(xla_while, 3)); VLOG(2) << "result = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -474,9 +476,9 @@ TEST_F(WhileTest, WhileWithTupleResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Gt(builder.ConstantR0(5), iteration); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Gt(ConstantR0(&builder, 5), iteration); condition = builder.Build().ConsumeValueOrDie(); } @@ -486,30 +488,30 @@ TEST_F(WhileTest, WhileWithTupleResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto weights = builder.GetTupleElement(prev, 1); - auto input = builder.ConstantR1(10, 1.f); - auto new_weights = builder.Add(weights, input); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto weights = GetTupleElement(prev, 1); + auto input = ConstantR1(&builder, 10, 1.f); + auto new_weights = Add(weights, input); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), new_weights}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); - auto result = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 10, 0.f)}); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0(5); - auto expected_data = Literal::CreateR1( + auto expected_counter = LiteralUtil::CreateR0(5); + auto expected_data = LiteralUtil::CreateR1( {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}); auto expected = - Literal::MakeTuple({expected_counter.get(), expected_data.get()}); + LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -524,9 +526,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Gt(builder.ConstantR0(5), iteration); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Gt(ConstantR0(&builder, 5), iteration); condition = builder.Build().ConsumeValueOrDie(); } @@ -535,29 +537,28 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto pred = builder.GetTupleElement(prev, 1); - auto new_pred = builder.Or(pred, builder.ConstantR0(true)); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), new_pred}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto pred = GetTupleElement(prev, 1); + auto new_pred = Or(pred, ConstantR0(&builder, true)); + Tuple(&builder, {Add(iteration, ConstantR0(&builder, 1)), new_pred}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple({builder.ConstantR0(0), - builder.Ne(builder.ConstantR0(false), - builder.ConstantR0(true))}); - auto result = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + Ne(ConstantR0(&builder, false), + ConstantR0(&builder, true))}); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0(5); - auto expected_predicate = Literal::CreateR0(true); - auto expected = - Literal::MakeTuple({expected_counter.get(), expected_predicate.get()}); + auto expected_counter = LiteralUtil::CreateR0(5); + auto expected_predicate = LiteralUtil::CreateR0(true); + auto expected = LiteralUtil::MakeTuple( + {expected_counter.get(), expected_predicate.get()}); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); } @@ -571,9 +572,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Gt(builder.ConstantR0(5), iteration); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Gt(ConstantR0(&builder, 5), iteration); condition = builder.Build().ConsumeValueOrDie(); } @@ -583,26 +584,26 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Tuple({builder.Add(iteration, builder.ConstantR0(1)), - builder.ConstantR0(7)}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Tuple(&builder, {Add(iteration, ConstantR0(&builder, 1)), + ConstantR0(&builder, 7)}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR0(7)}); - auto result = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR0(&builder, 7)}); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0(5); - auto expected_data = Literal::CreateR0(7); + auto expected_counter = LiteralUtil::CreateR0(5); + auto expected_data = LiteralUtil::CreateR0(7); auto expected = - Literal::MakeTuple({expected_counter.get(), expected_data.get()}); + LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -632,9 +633,9 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { const int c1 = 5; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(c1)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } @@ -642,9 +643,9 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { const int c2 = 7; { XlaBuilder builder("condition2"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(c2)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, c2)); TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build()); } @@ -654,43 +655,43 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto weights = builder.GetTupleElement(prev, 1); - auto input = builder.ConstantR1(10, 1.f); - auto new_weights = builder.Add(weights, input); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto weights = GetTupleElement(prev, 1); + auto input = ConstantR1(&builder, 10, 1.f); + auto new_weights = Add(weights, input); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } XlaComputation body2; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto weights = builder.GetTupleElement(prev, 1); - auto input = builder.ConstantR1(10, 1.f); - auto new_weights = builder.Add(weights, input); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto weights = GetTupleElement(prev, 1); + auto input = ConstantR1(&builder, 10, 1.f); + auto new_weights = Add(weights, input); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build()); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); - auto while1 = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 10, 0.f)}); + auto while1 = While(condition, body, init); - auto while2 = builder.While(condition2, body2, while1); + auto while2 = While(condition2, body2, while1); - auto while_result1 = builder.GetTupleElement(while1, 1); - auto while_result2 = builder.GetTupleElement(while2, 1); + auto while_result1 = GetTupleElement(while1, 1); + auto while_result2 = GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( builder.GetShape(while_result2).ConsumeValueOrDie()); - auto result = builder.Add(while_result1, while_result2); + auto result = Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -711,9 +712,9 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { const int c1 = 5; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(c1)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } @@ -721,9 +722,9 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { const int c2 = 7; { XlaBuilder builder("condition2"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(c2)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, c2)); TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build()); } @@ -733,30 +734,30 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto weights = builder.GetTupleElement(prev, 1); - auto input = builder.ConstantR1(10, 1.f); - auto new_weights = builder.Add(weights, input); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto weights = GetTupleElement(prev, 1); + auto input = ConstantR1(&builder, 10, 1.f); + auto new_weights = Add(weights, input); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); - auto while1 = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 10, 0.f)}); + auto while1 = While(condition, body, init); - auto while2 = builder.While(condition2, body, while1); + auto while2 = While(condition2, body, while1); - auto while_result1 = builder.GetTupleElement(while1, 1); - auto while_result2 = builder.GetTupleElement(while2, 1); + auto while_result1 = GetTupleElement(while1, 1); + auto while_result2 = GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( builder.GetShape(while_result2).ConsumeValueOrDie()); - auto result = builder.Add(while_result1, while_result2); + auto result = Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -778,9 +779,9 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { const int c1 = 5; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(c1)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } @@ -788,9 +789,9 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { const int c2 = 7; { XlaBuilder builder("condition2"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(c2)); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, c2)); TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build()); } @@ -800,29 +801,29 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - auto weights = builder.GetTupleElement(prev, 1); - auto input = builder.ConstantR1(10, 1.f); - auto new_weights = builder.Add(weights, input); - builder.Tuple( - {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + auto weights = GetTupleElement(prev, 1); + auto input = ConstantR1(&builder, 10, 1.f); + auto new_weights = Add(weights, input); + Tuple(&builder, + {Add(iteration, ConstantR0(&builder, 1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); - auto while1 = builder.While(condition, body, init); - auto while2 = builder.While(condition2, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 10, 0.f)}); + auto while1 = While(condition, body, init); + auto while2 = While(condition2, body, init); - auto while_result1 = builder.GetTupleElement(while1, 1); - auto while_result2 = builder.GetTupleElement(while2, 1); + auto while_result1 = GetTupleElement(while1, 1); + auto while_result2 = GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( builder.GetShape(while_result2).ConsumeValueOrDie()); - auto result = builder.Add(while_result1, while_result2); + auto result = Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); @@ -844,9 +845,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Gt(builder.ConstantR0(5), iteration); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Gt(ConstantR0(&builder, 5), iteration); condition = builder.Build().ConsumeValueOrDie(); } @@ -856,38 +857,37 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); + auto prev = Parameter(&builder, 0, result_shape, "prev"); // TupleElement 0 - auto iteration = builder.GetTupleElement(prev, 0); - auto out0 = builder.Add(iteration, builder.ConstantR0(1)); + auto iteration = GetTupleElement(prev, 0); + auto out0 = Add(iteration, ConstantR0(&builder, 1)); // TupleElement 1 - auto input = builder.GetTupleElement(prev, 1); + auto input = GetTupleElement(prev, 1); // Update. - auto update = builder.ConvertElementType(builder.Broadcast(out0, {2}), F32); + auto update = ConvertElementType(Broadcast(out0, {2}), F32); // Starts = iteration * 2; - auto starts = builder.Reshape( - builder.Mul(iteration, builder.ConstantR0(2)), {1}); + auto starts = Reshape(Mul(iteration, ConstantR0(&builder, 2)), {1}); // UpdateSlice. - auto out1 = builder.DynamicUpdateSlice(input, update, starts); + auto out1 = DynamicUpdateSlice(input, update, starts); - builder.Tuple({out0, out1}); + Tuple(&builder, {out0, out1}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder("while"); - auto init = builder.Tuple( - {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); - auto result = builder.While(condition, body, init); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), + ConstantR1(&builder, 10, 0.f)}); + auto result = While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0(5); - auto expected_data = Literal::CreateR1( + auto expected_counter = LiteralUtil::CreateR0(5); + auto expected_data = LiteralUtil::CreateR1( {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f}); auto expected = - Literal::MakeTuple({expected_counter.get(), expected_data.get()}); + LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -913,10 +913,9 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { // Create a computation for the condition: repeat for count iterations. auto build_condition = [this, v6s32](int count) { XlaBuilder builder(TestName()); - auto prev = builder.Reshape( - builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0}, - {}); - builder.Gt(builder.ConstantR0(count), prev); + auto prev = Reshape( + Slice(Parameter(&builder, 0, v6s32, "prev"), {0}, {1}, {1}), {0}, {}); + Gt(ConstantR0(&builder, count), prev); return builder.Build().ConsumeValueOrDie(); }; @@ -924,22 +923,22 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, v6s32, "prev"); - auto inc = builder.ConcatInDim( - {builder.ConstantR1({1}), - builder.RngUniform(builder.ConstantR0(0), - builder.ConstantR0(100), - ShapeUtil::MakeShape(S32, {5}))}, - 0); - builder.Add(inc, prev); + auto prev = Parameter(&builder, 0, v6s32, "prev"); + auto inc = ConcatInDim(&builder, + {ConstantR1(&builder, {1}), + RngUniform(ConstantR0(&builder, 0), + ConstantR0(&builder, 100), + ShapeUtil::MakeShape(S32, {5}))}, + 0); + Add(inc, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. auto while_loop = [this, &body, build_condition](int count) { XlaBuilder builder(TestName()); - auto init = builder.ConstantR1({0, 0, 0, 0, 0, 0}); - builder.While(build_condition(count), body, init); + auto init = ConstantR1(&builder, {0, 0, 0, 0, 0, 0}); + While(build_condition(count), body, init); return builder.Build(); }; @@ -958,33 +957,30 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); XlaBuilder outer("outer"); - auto p = outer.Parameter(0, element_shape, "param"); - auto t = outer.Tuple({p, outer.ConstantR1({1, 1})}); + auto p = Parameter(&outer, 0, element_shape, "param"); + auto t = Tuple(&outer, {p, ConstantR1(&outer, {1, 1})}); TF_ASSERT_OK_AND_ASSIGN(Shape tuple_shape, outer.GetShape(t)); XlaBuilder cond("cond"); - auto cond_t = cond.Parameter(0, tuple_shape, "t"); - TF_ASSERT_OK(Any(cond.Eq(cond.GetTupleElement(cond_t, 0), - cond.ConstantR1({42, 42})), - &cond) - .status()); + auto cond_t = Parameter(&cond, 0, tuple_shape, "t"); + Any(Eq(GetTupleElement(cond_t, 0), ConstantR1(&cond, {42, 42}))); XlaBuilder body("body"); - auto body_t = body.Parameter(0, tuple_shape, "t"); - auto e = body.GetTupleElement(body_t, 1); - body.Tuple({e, e}); + auto body_t = Parameter(&body, 0, tuple_shape, "t"); + auto e = GetTupleElement(body_t, 1); + Tuple(&body, {e, e}); TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); - outer.While(cond_computation, body_computation, t); + While(cond_computation, body_computation, t); - auto expected_element = Literal::CreateR1({1, 1}); + auto expected_element = LiteralUtil::CreateR1({1, 1}); auto expected = - Literal::MakeTuple({expected_element.get(), expected_element.get()}); + LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*Literal::CreateR1({42, 42}))); + client_->TransferToServer(*LiteralUtil::CreateR1({42, 42}))); ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -993,24 +989,23 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); XlaBuilder outer("outer"); - auto p = outer.Parameter(0, element_shape, "param"); + auto p = Parameter(&outer, 0, element_shape, "param"); XlaBuilder cond("cond"); - auto cond_t = cond.Parameter(0, element_shape, "t"); - TF_ASSERT_OK( - Any(cond.Eq(cond_t, cond.ConstantR1({42, 42})), &cond).status()); + auto cond_t = Parameter(&cond, 0, element_shape, "t"); + Any(Eq(cond_t, ConstantR1(&cond, {42, 42}))); XlaBuilder body("body"); - auto body_t = body.Parameter(0, element_shape, "t"); - auto e = body.Broadcast(body.ConstantR0(1.0), {2}); + Parameter(&body, 0, element_shape, "t"); + Broadcast(ConstantR0(&body, 1.0), {2}); TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); - outer.While(cond_computation, body_computation, p); + While(cond_computation, body_computation, p); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*Literal::CreateR1({42, 42}))); + client_->TransferToServer(*LiteralUtil::CreateR1({42, 42}))); ComputeAndCompareR1(&outer, {1.0f, 1.0f}, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1019,25 +1014,24 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {}); XlaBuilder outer("outer"); - auto p = outer.Parameter(0, element_shape, "param"); + auto p = Parameter(&outer, 0, element_shape, "param"); XlaBuilder cond("cond"); - auto cond_t = cond.Parameter(0, element_shape, "t"); - cond.Eq(cond_t, cond.ConstantR0(42)); + auto cond_t = Parameter(&cond, 0, element_shape, "t"); + Eq(cond_t, ConstantR0(&cond, 42)); XlaBuilder body("body"); - auto body_t = body.Parameter(0, element_shape, "t"); - auto tuple = - body.Tuple({body_t, body.Add(body_t, body.ConstantR0(1))}); - auto e = body.GetTupleElement(tuple, 1); + auto body_t = Parameter(&body, 0, element_shape, "t"); + auto tuple = Tuple(&body, {body_t, Add(body_t, ConstantR0(&body, 1))}); + GetTupleElement(tuple, 1); TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); - outer.While(cond_computation, body_computation, p); + While(cond_computation, body_computation, p); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*Literal::CreateR0(42))); + client_->TransferToServer(*LiteralUtil::CreateR0(42))); ComputeAndCompareR0(&outer, 43.0f, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1056,33 +1050,31 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) { XlaBuilder outer("outer"); auto p = - outer.Tuple({outer.ConstantR0(0), - outer.Parameter(0, ShapeUtil::MakeShape(S32, {}), "t")}); + Tuple(&outer, {ConstantR0(&outer, 0), + Parameter(&outer, 0, ShapeUtil::MakeShape(S32, {}), "t")}); XlaBuilder cond("cond"); - auto params = cond.Parameter(0, result_shape, "prev"); - auto cond_t = cond.Add(cond.GetTupleElement(params, 1), - cond.GetTupleElement(params, 0)); - cond.Lt(cond_t, cond.ConstantR0(30)); + auto params = Parameter(&cond, 0, result_shape, "prev"); + auto cond_t = Add(GetTupleElement(params, 1), GetTupleElement(params, 0)); + Lt(cond_t, ConstantR0(&cond, 30)); XlaBuilder body("body"); - auto body_t = body.Parameter(0, result_shape, "t"); + auto body_t = Parameter(&body, 0, result_shape, "t"); - auto tuple = body.Tuple( - {body.Add(body.GetTupleElement(body_t, 0), body.ConstantR0(1)), - body.Add(body.GetTupleElement(body_t, 1), body.ConstantR0(1))}); + Tuple(&body, {Add(GetTupleElement(body_t, 0), ConstantR0(&body, 1)), + Add(GetTupleElement(body_t, 1), ConstantR0(&body, 1))}); TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); - outer.While(cond_computation, body_computation, p); + While(cond_computation, body_computation, p); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*Literal::CreateR0(1))); + client_->TransferToServer(*LiteralUtil::CreateR0(1))); - auto add1 = Literal::CreateR0(15); - auto add2 = Literal::CreateR0(16); - auto expected = Literal::MakeTuple({add1.get(), add2.get()}); + auto add1 = LiteralUtil::CreateR0(15); + auto add2 = LiteralUtil::CreateR0(16); + auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()}); ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1105,9 +1097,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { XlaComputation inner_condition; { XlaBuilder builder("inner_condition"); - auto params = builder.Parameter(0, inner_result_shape, "prev"); - auto i = builder.GetTupleElement(params, 0); - builder.Lt(i, builder.ConstantR0(7)); + auto params = Parameter(&builder, 0, inner_result_shape, "prev"); + auto i = GetTupleElement(params, 0); + Lt(i, ConstantR0(&builder, 7)); inner_condition = builder.Build().ConsumeValueOrDie(); } @@ -1116,8 +1108,8 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { XlaComputation outer_condition; { XlaBuilder builder("outer_condition"); - auto prev = builder.Parameter(0, outer_result_shape, "prev"); - builder.Lt(prev, builder.ConstantR0(30)); + auto prev = Parameter(&builder, 0, outer_result_shape, "prev"); + Lt(prev, ConstantR0(&builder, 30)); outer_condition = builder.Build().ConsumeValueOrDie(); } @@ -1126,12 +1118,12 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { XlaComputation inner_body; { XlaBuilder builder("inner_body"); - auto params = builder.Parameter(0, inner_result_shape, "prev"); - auto i = builder.GetTupleElement(params, 0); - auto result = builder.GetTupleElement(params, 1); - i = builder.Add(builder.ConstantR0(1), i); - result = builder.Add(builder.ConstantR0(2), result); - builder.Tuple({i, result}); + auto params = Parameter(&builder, 0, inner_result_shape, "prev"); + auto i = GetTupleElement(params, 0); + auto result = GetTupleElement(params, 1); + i = Add(ConstantR0(&builder, 1), i); + result = Add(ConstantR0(&builder, 2), result); + Tuple(&builder, {i, result}); inner_body = builder.Build().ConsumeValueOrDie(); } @@ -1139,17 +1131,17 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { XlaComputation outer_body; { XlaBuilder builder("outer_body"); - auto prev = builder.Parameter(0, outer_result_shape, "prev"); - auto init = builder.Tuple({builder.ConstantR0(0), prev}); - auto result = builder.While(inner_condition, inner_body, init); - builder.GetTupleElement(result, 1); + auto prev = Parameter(&builder, 0, outer_result_shape, "prev"); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), prev}); + auto result = While(inner_condition, inner_body, init); + GetTupleElement(result, 1); outer_body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder(TestName()); - auto init = builder.ConstantR0(0); - builder.While(outer_condition, outer_body, init); + auto init = ConstantR0(&builder, 0); + While(outer_condition, outer_body, init); ComputeAndCompareR0(&builder, 42, {}); } @@ -1167,8 +1159,8 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { XlaComputation condition_callee; { XlaBuilder builder("condition_callee"); - auto prev = builder.Parameter(0, result_shape, "prev"); - builder.Tuple({builder.Gt(builder.ConstantR0(5), prev)}); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + Tuple(&builder, {Gt(ConstantR0(&builder, 5), prev)}); condition_callee = builder.Build().ConsumeValueOrDie(); } @@ -1176,9 +1168,9 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto result = builder.Call(condition_callee, {prev}); - builder.GetTupleElement(result, 0); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto result = Call(&builder, condition_callee, {prev}); + GetTupleElement(result, 0); condition = builder.Build().ConsumeValueOrDie(); } @@ -1186,16 +1178,16 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, result_shape, "prev"); - auto input = builder.ConstantR0(1); - builder.Add(input, prev); + auto prev = Parameter(&builder, 0, result_shape, "prev"); + auto input = ConstantR0(&builder, 1); + Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. XlaBuilder builder(TestName()); - auto init = builder.ConstantR0(0); - builder.While(condition, body, init); + auto init = ConstantR0(&builder, 0); + While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -1210,40 +1202,69 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { XlaComputation condition; { XlaBuilder builder("condition"); - auto state = builder.Parameter(0, while_shape, "state"); - builder.Gt(builder.ConstantR0(5), builder.GetTupleElement(state, 0)); + auto state = Parameter(&builder, 0, while_shape, "state"); + Gt(ConstantR0(&builder, 5), GetTupleElement(state, 0)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } XlaComputation body; { XlaBuilder builder("body"); - auto state = builder.Parameter(0, while_shape, "state"); - auto indvar = builder.GetTupleElement(state, 0); - auto input_0 = builder.GetTupleElement(state, 1); - auto input_1 = builder.GetTupleElement(state, 2); - auto output = builder.Tanh(builder.Dot(input_0, input_1)); - auto indvar_next = builder.Add(indvar, builder.ConstantR0(1)); - builder.Tuple({indvar_next, input_0, input_1, output}); + auto state = Parameter(&builder, 0, while_shape, "state"); + auto indvar = GetTupleElement(state, 0); + auto input_0 = GetTupleElement(state, 1); + auto input_1 = GetTupleElement(state, 2); + auto output = Tanh(Dot(input_0, input_1)); + auto indvar_next = Add(indvar, ConstantR0(&builder, 1)); + Tuple(&builder, {indvar_next, input_0, input_1, output}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } XlaBuilder builder(TestName()); - auto matrix_input = builder.Parameter(0, matrix_shape, "matrix"); - auto init = builder.Tuple( - {builder.ConstantR0(0), matrix_input, matrix_input, matrix_input}); - auto while_instruction = builder.While(condition, body, init); - builder.GetTupleElement(while_instruction, 3); + auto matrix_input = Parameter(&builder, 0, matrix_shape, "matrix"); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), matrix_input, + matrix_input, matrix_input}); + auto while_instruction = While(condition, body, init); + GetTupleElement(while_instruction, 3); - TF_ASSERT_OK_AND_ASSIGN(auto param_value, - client_->TransferToServer(*Literal::CreateR2( - {{1.0, 2.0}, {-1.0, -2.0}}))); + TF_ASSERT_OK_AND_ASSIGN( + auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2( + {{1.0, 2.0}, {-1.0, -2.0}}))); ComputeAndCompareR2( &builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}}, {param_value.get()}, ErrorSpec(4e-5)); } +TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) { + auto while_shape = ShapeUtil::MakeShape(S32, {}); + + XlaComputation condition; + { + XlaBuilder builder("condition"); + Parameter(&builder, 0, while_shape, "state"); + Infeed(&builder, ShapeUtil::MakeShape(PRED, {})); + TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); + } + + XlaComputation body; + { + XlaBuilder builder("body"); + auto indvar = Parameter(&builder, 0, while_shape, "state"); + Add(indvar, ConstantR0(&builder, 1)); + TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); + } + + XlaBuilder builder(TestName()); + While(condition, body, ConstantR0(&builder, 0)); + + TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); + TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); + TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(false))); + + ComputeAndCompareR0(&builder, 2, {}); +} + void BM_WhileLoop(int num_iters) { // Benchmark a simple kernel to measure while loop overheads. tensorflow::testing::StopTiming(); @@ -1264,9 +1285,9 @@ void BM_WhileLoop(int num_iters) { XlaComputation condition; { XlaBuilder builder("condition"); - auto prev = builder.Parameter(0, loop_state_shape, "prev"); - auto iteration = builder.GetTupleElement(prev, 0); - builder.Lt(iteration, builder.ConstantR0(loop_limit)); + auto prev = Parameter(&builder, 0, loop_state_shape, "prev"); + auto iteration = GetTupleElement(prev, 0); + Lt(iteration, ConstantR0(&builder, loop_limit)); condition = builder.Build().ConsumeValueOrDie(); } @@ -1274,29 +1295,29 @@ void BM_WhileLoop(int num_iters) { XlaComputation body; { XlaBuilder builder("body"); - auto prev = builder.Parameter(0, loop_state_shape, "prev"); + auto prev = Parameter(&builder, 0, loop_state_shape, "prev"); // TupleElement 0 - auto iteration = builder.GetTupleElement(prev, 0); - auto out0 = builder.Add(iteration, builder.ConstantR0(1)); + auto iteration = GetTupleElement(prev, 0); + auto out0 = Add(iteration, ConstantR0(&builder, 1)); // TupleElement 1 - auto input = builder.GetTupleElement(prev, 1); + auto input = GetTupleElement(prev, 1); // Update. - auto one = builder.ConstantR0(1.0); - auto update = builder.Broadcast(one, {1, 1024, 1024}); + auto one = ConstantR0(&builder, 1.0); + auto update = Broadcast(one, {1, 1024, 1024}); // Starts = iteration * 2; - auto starts = builder.ConstantR1({0, 0, 0}); + auto starts = ConstantR1(&builder, {0, 0, 0}); // UpdateSlice. - auto out1 = builder.DynamicUpdateSlice(input, update, starts); - builder.Tuple({out0, out1}); + auto out1 = DynamicUpdateSlice(input, update, starts); + Tuple(&builder, {out0, out1}); body = builder.Build().ConsumeValueOrDie(); } // Create a While instruction. XlaBuilder builder("while"); - auto zero = builder.ConstantR0(0.0); - auto input = builder.Broadcast(zero, {seq_len, 1024, 1024}); - auto init = builder.Tuple({builder.ConstantR0(0), input}); - builder.While(condition, body, init); + auto zero = ConstantR0(&builder, 0.0); + auto input = Broadcast(zero, {seq_len, 1024, 1024}); + auto init = Tuple(&builder, {ConstantR0(&builder, 0), input}); + While(condition, body, init); auto computation = builder.Build().ConsumeValueOrDie(); std::unique_ptr executable = diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 3c9a01653c67203cbc962a3d3d967142f7a2102c..11f3efb1f34ad23ebdcbb65c90aa5fb7a6adeae5 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -18,10 +18,11 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -79,10 +80,12 @@ struct ParsedProfileOutputLine { Status ParseOneProfileOutputLine( const string& line, bool expect_hlo, - gtl::FlatMap* parsed_results) { + gtl::FlatMap* parsed_results, + tensorflow::gtl::ArraySlice opcodes_to_ignore = + {}) { string separator = "[^:]*:: +"; - string match_percentage = "\\d+\\.\\d\\d%"; - string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)"; + string match_percentage = R"(\d+\.\d*% +\d+Σ)"; + string match_cycles = R"((\d+) cycles +\( *()" + match_percentage + R"()\))"; string match_usecs = "([0-9.]+) usec"; string match_flops = "([^ ]*)"; string match_trops = "([^ ]*)"; @@ -113,7 +116,9 @@ Status ParseOneProfileOutputLine( ", Regexp: ", regexp_pattern); } - InsertOrDie(parsed_results, parsed_line.opcode, parsed_line); + if (!c_linear_search(opcodes_to_ignore, parsed_line.opcode)) { + InsertOrDie(parsed_results, parsed_line.opcode, parsed_line); + } return Status::OK(); } @@ -128,20 +133,23 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, se::StreamExecutor* executor = backend->default_stream_executor(); DeviceMemoryAllocator* allocator = backend->memory_allocator(); auto* transfer_manager = backend->transfer_manager(); + TF_ASSERT_OK_AND_ASSIGN( + StreamPool::Ptr stream_ptr, + backend->BorrowStream(backend->default_device_ordinal())); TF_ASSERT_OK_AND_ASSIGN( ScopedShapedBuffer lhs_arg, transfer_manager->AllocateScopedShapedBuffer( lhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - executor, *Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); + stream_ptr.get(), *Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); TF_ASSERT_OK_AND_ASSIGN( ScopedShapedBuffer rhs_arg, transfer_manager->AllocateScopedShapedBuffer( rhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - executor, *Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); + stream_ptr.get(), *Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr local_executable, @@ -153,9 +161,6 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, &executable->hlo_profile_printer_data(), &executable->hlo_profile_index_map()); - TF_ASSERT_OK_AND_ASSIGN( - Backend::StreamPtr stream_ptr, - backend->BorrowStream(backend->default_device_ordinal())); ExecutableRunOptions exec_run_options; exec_run_options.set_stream(stream_ptr.get()); exec_run_options.set_allocator(backend->memory_allocator()); @@ -168,6 +173,7 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, auto execution_result, executable->ExecuteOnStream(&run_options, {&lhs_arg, &rhs_arg}, &hlo_execution_profile)); + TF_ASSERT_OK(stream_ptr->BlockHostUntilDone()); (void)execution_result; *profile_output = @@ -187,9 +193,9 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { ClientLibrary::GetOrCreateLocalClient(platform)); XlaBuilder builder(TestName()); - auto result = builder.Tanh(builder.Add( - builder.Parameter(0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs"))); + Tanh(Add( + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"), + Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs"))); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); @@ -219,7 +225,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { MaybeFind(parsed_profile_lines, "tanh")); EXPECT_GT(total_profile.cycles, 0); - EXPECT_EQ(total_profile.cycles_percentage, "100.00%"); + EXPECT_EQ(total_profile.cycles_percentage, "100.% 100Σ"); EXPECT_TRUE(HasFlops(total_profile)); EXPECT_TRUE(HasTrops(total_profile)); @@ -239,9 +245,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { EXPECT_TRUE(HasTrops(tanh_profile)); } -// TODO(b/71544591): The GPU backend does not record cycles spent in on Hlo -// instructions "interior" to while nodes. -XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) { +XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { const int64 size = 256; Shape matrix_shape = ShapeUtil::MakeShape(F32, {size, size}); Shape while_result_shape = @@ -255,30 +259,30 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) { XlaComputation condition; { XlaBuilder builder("condition"); - auto state = builder.Parameter(0, while_result_shape, "state"); - auto iteration = builder.GetTupleElement(state, 0); - builder.Gt(builder.ConstantR0(5), iteration); + auto state = Parameter(&builder, 0, while_result_shape, "state"); + auto iteration = GetTupleElement(state, 0); + Gt(ConstantR0(&builder, 5), iteration); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } XlaComputation body; { XlaBuilder builder("body"); - auto state = builder.Parameter(0, while_result_shape, "state"); - auto matrix = builder.GetTupleElement(state, 1); - auto next_iteration = builder.Add(builder.GetTupleElement(state, 0), - builder.ConstantR0(1)); - builder.Tuple({next_iteration, builder.Add(matrix, matrix)}); + auto state = Parameter(&builder, 0, while_result_shape, "state"); + auto matrix = GetTupleElement(state, 1); + auto next_iteration = + Add(GetTupleElement(state, 0), ConstantR0(&builder, 1)); + Tuple(&builder, {next_iteration, Mul(matrix, matrix)}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } XlaBuilder builder(TestName()); auto initial_while_state = - builder.Tuple({builder.ConstantR0(0), - builder.Parameter(0, matrix_shape, "initial_value")}); - auto while_result = builder.While(condition, body, initial_while_state); - builder.Add(builder.GetTupleElement(while_result, 1), - builder.Parameter(1, matrix_shape, "other_value")); + Tuple(&builder, {ConstantR0(&builder, 0), + Parameter(&builder, 0, matrix_shape, "initial_value")}); + auto while_result = While(condition, body, initial_while_state); + Add(GetTupleElement(while_result, 1), + Parameter(&builder, 1, matrix_shape, "other_value")); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); @@ -290,36 +294,50 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) { tensorflow::str_util::Split(profile_output, '\n'); auto while_body_profile_start = - std::find_if(profile_output_lines.begin(), profile_output_lines.end(), + c_find_if(profile_output_lines, [](tensorflow::StringPiece s) { + return tensorflow::str_util::StartsWith(s, + "Execution profile for body"); + }); + + ASSERT_NE(while_body_profile_start, profile_output_lines.cend()); + + auto while_body_profile_end = + std::find_if(while_body_profile_start, profile_output_lines.end(), [](tensorflow::StringPiece s) { return tensorflow::str_util::StartsWith( - s, "Execution profile for body"); + s, "********** microseconds report **********"); }); - ASSERT_NE(while_body_profile_start, profile_output_lines.end()); + // We emit a blank line before the "********** microseconds report **********" + // line. + while_body_profile_end--; - gtl::FlatMap parsed_profile_lines; + ASSERT_NE(while_body_profile_end, profile_output_lines.end()); - TF_ASSERT_OK( - ParseOneProfileOutputLine(*std::next(while_body_profile_start, 1), - /*expect_hlo=*/false, &parsed_profile_lines)); + gtl::FlatMap parsed_profile_lines; - TF_ASSERT_OK( - ParseOneProfileOutputLine(*std::next(while_body_profile_start, 2), - /*expect_hlo=*/true, &parsed_profile_lines)); + for (auto while_body_profile_i = while_body_profile_start + 1; + while_body_profile_i != while_body_profile_end; while_body_profile_i++) { + // There are multiple "get-tuple-element" instructions in the while body so + // we ignore them -- we don't want parsed_profile_lines to be a multi-map. + TF_ASSERT_OK(ParseOneProfileOutputLine( + *while_body_profile_i, + /*expect_hlo=*/while_body_profile_i != (while_body_profile_start + 1), + &parsed_profile_lines, {"get-tuple-element"})); + } TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_while_body_profile, MaybeFind(parsed_profile_lines, "[total]")); - TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile, - MaybeFind(parsed_profile_lines, "add")); + TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine multiply_profile, + MaybeFind(parsed_profile_lines, "multiply")); EXPECT_GT(total_while_body_profile.cycles, 0); EXPECT_EQ(total_while_body_profile.opcode, "[total]"); - EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.00%"); + EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.% 100Σ"); - EXPECT_GT(total_while_body_profile.cycles, dot_profile.cycles); - EXPECT_NE(dot_profile.cycles_percentage, "0.00%"); - EXPECT_NE(dot_profile.cycles_percentage, "100.00%"); + EXPECT_GT(total_while_body_profile.cycles, multiply_profile.cycles); + EXPECT_NE(multiply_profile.cycles_percentage, "0.00%"); + EXPECT_NE(multiply_profile.cycles_percentage, "100.00%"); } } // namespace } // namespace xla @@ -336,8 +354,11 @@ static std::pair AddXlaHloProfileFlag(int argc, char** argv) { new_argv[argc] = strdup("--xla_hlo_profile"); // Fusion can change the Hlo instructions that show up in the final Hlo - // executable, so block it here. - new_argv[argc + 1] = strdup("--xla_disable_hlo_passes=fusion"); + // executable, so block it here. Also block the WhileLoopInvariantCodeMotion + // pass, otherwise a while loop is transformed and we could not match the + // original name in the ProfileWhileComputation test. + new_argv[argc + 1] = strdup( + "--xla_disable_hlo_passes=fusion,while-loop-invariant-code-motion"); return {argc + 2, new_argv}; } diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc index a9f2915b458b1816926de727b3da21982d06f6c0..a075195618c42aaa11f7b1c17730e67889a2c308 100644 --- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc +++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc @@ -49,6 +49,7 @@ GTEST_API_ int main(int argc, char** argv) { } // Unfortunately Google's internal benchmark infrastructure has a // different API than Tensorflow's. + testing::InitGoogleTest(&argc, argv); #if defined(PLATFORM_GOOGLE) base::SetFlag(&FLAGS_benchmarks, pattern); RunSpecifiedBenchmarks(); diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 56702feab9a4e8d00df3a165ab994aef2d42d830..897123d7606db60abc1105b03beb3f23ab249579 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h index e45e5291c9b10803f5e5008b72c7dd0116a0dea0..708e8c80d8b5c09454eb64d4e12df51a5b7ea628 100644 --- a/tensorflow/compiler/xla/text_literal_reader.h +++ b/tensorflow/compiler/xla/text_literal_reader.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc index 23070b663870a2b78b38663e09a32fcb28d9c2dc..92f9b4f9f0efa2dc08287bdcbefc88f879164308 100644 --- a/tensorflow/compiler/xla/text_literal_reader_test.cc +++ b/tensorflow/compiler/xla/text_literal_reader_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 373c0d2d8d8ab05dec11e51f265d41b91e7920bf..24e0784741a4c9779b0adb7a7740c3d6e2fb033a 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h index 0a1235b5e04675da0f412bafab6c4ecf04367787..159ac1b7e1b6f9c07dac795fb640cd0b2d284bcb 100644 --- a/tensorflow/compiler/xla/text_literal_writer.h +++ b/tensorflow/compiler/xla/text_literal_writer.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ #define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc index 70cf2fb1b8a1b4f2ecfdaeaef3a00ddc974e2652..4ea02faffcd52065b05c0444202bd1a3d9d87ee6 100644 --- a/tensorflow/compiler/xla/text_literal_writer_test.cc +++ b/tensorflow/compiler/xla/text_literal_writer_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -30,8 +31,9 @@ namespace xla { namespace { TEST(TextLiteralWriterTest, WritesFloatLiteral) { - auto literal = Literal::CreateR2({ - {3.14, 2.17}, {1.23, 4.56}, + auto literal = LiteralUtil::CreateR2({ + {3.14, 2.17}, + {1.23, 4.56}, }); string path = tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/whatever"); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index e4a052c8f1c0009619c3a94606f6384d04006e4e..40d28a57bfddd3403cad8252df985b746362631f 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -37,6 +37,7 @@ cc_library( "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_proto", @@ -74,7 +75,7 @@ cc_library( srcs = ["replay_computation.cc"], deps = [ "//tensorflow/compiler/xla:execution_options_util", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -84,7 +85,9 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:testing", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service/gpu:infeed_manager", @@ -123,7 +126,7 @@ tf_cc_binary( name = "show_literal", srcs = ["show_literal.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", @@ -145,7 +148,7 @@ tf_cc_binary( name = "show_text_literal", srcs = ["show_text_literal.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:text_literal_reader", "//tensorflow/compiler/xla:types", @@ -164,6 +167,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", @@ -181,6 +185,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_proto", @@ -198,6 +203,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_graph_dumper", diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index befb55453777dce30af89bcaad2ffe1647097576..f20dcef382b86d27d7c176ae7e4132ad1db7b901 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index cfb8f37487d6499b803438a135be54524fcf17d2..f0af0580c1fbca455c6ed5f87f82971faee50a06 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 5dd5150be339846d0775880931f615b92c5b08d8..f03e1b1f965af761c101555fd0275bc0425b9cf0 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index a5dce20456c6a2402f425ebb3d575d1bb625f839..dc5c106d02cb679f3e6f5b2bea40bbb42f8bd1cc 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index f7574e0b1cc95daee6d6743ba4e2e490ee87e7c6..b4774233e588dc407bfb88defca9bf55e08eea09 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -30,6 +30,9 @@ limitations under the License. // The output format is: // // file_path: computation_name :: type:literal_str +// +// Note: If you pass multiple modules, they will be compiled in parallel but run +// in series. #include #include @@ -42,8 +45,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" @@ -74,6 +79,18 @@ struct Options { int num_runs = 1; }; +std::unique_ptr CompileExecutable(const HloSnapshot& module, + LocalClient* client) { + XlaComputation computation(module.hlo().hlo_module()); + std::vector argument_layouts; + for (const auto& param : computation.proto().program_shape().parameters()) { + argument_layouts.push_back(¶m); + } + return client + ->Compile(computation, argument_layouts, ExecutableBuildOptions()) + .ValueOrDie(); +} + // Invokes the given computation passing arbitrary data for every (unbound) // parameter if use_fake_data, Otherwise use recorded data if available. // @@ -84,6 +101,7 @@ struct Options { // If neither generate_fake_infeed is true nor a fake_infeed_shape is provided, // no infeed is performed. StatusOr ReplayComputation(const HloSnapshot& module, + LocalExecutable* executable, LocalClient* client, const Options& opts) { XlaComputation computation(module.hlo().hlo_module()); @@ -166,34 +184,36 @@ StatusOr ReplayComputation(const HloSnapshot& module, }); } - std::vector argument_layouts; - for (const auto& param : computation.proto().program_shape().parameters()) { - argument_layouts.push_back(¶m); + // Do not attempt to run the executable if num_runs is less than 1. + if (opts.num_runs < 1) { + return Cancelled("Cancelled after compilation since --num_runs < 1."); } - std::unique_ptr executable = - client->Compile(computation, argument_layouts, ExecutableBuildOptions()) - .ValueOrDie(); // Run the computation num_runs times, and return the result from the last // execution. + const bool xla_hlo_profile = + legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile(); StreamExecutorMemoryAllocator allocator( client->platform(), {client->platform()->ExecutorForDevice(0).ValueOrDie()}); tensorflow::gtl::optional result; for (int i = 0; i < opts.num_runs; ++i) { + // If xla_hlo_profile is enabled, print a noisy message before the last run, + // making it easier to separate this profile from the others in the logspam. + if (xla_hlo_profile && i == opts.num_runs - 1) { + LOG(INFO) << "\n\n***** Final run below ******"; + } ExecutionProfile profile; ExecutableRunOptions run_options; run_options.set_execution_profile(&profile); run_options.set_allocator(&allocator); TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options)); - LOG(INFO) << "Execution took " - << static_cast(profile.compute_time_ns()) / 1e9 << "s"; + LOG(INFO) << "Done executing in " + << static_cast(profile.compute_time_ns()) / 1e9 + << "s: " << module.hlo().hlo_module().name(); } - // Check that --num_runs > 0, otherwise *result below will fail with an - // unhelpful error (because the loop didn't run any iterations). - CHECK_GT(opts.num_runs, 0) << "--num_runs must be > 0"; TF_ASSIGN_OR_RETURN(std::unique_ptr result_literal, client->ShapedBufferToLiteral(*result)); return std::move(*result_literal); @@ -203,9 +223,13 @@ StatusOr ParseInputFile(const string& filename, const Options& opts) { tensorflow::Env* env = tensorflow::Env::Default(); HloSnapshot snapshot; - if (tensorflow::ReadBinaryProto(env, filename, &snapshot).ok()) { + auto s = tensorflow::ReadBinaryProto(env, filename, &snapshot); + if (s.ok()) { return snapshot; } + if (s.code() == tensorflow::error::NOT_FOUND) { + return s; + } CHECK(opts.use_fake_data) << "Without --use_fake_data, you must pass an HloSnapshot -- HloProto " "and textual HLO don't carry real data."; @@ -232,15 +256,42 @@ StatusOr ParseInputFile(const string& filename, int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { LocalClient* client = ClientLibrary::LocalClientOrDie(); int exit_status = EXIT_SUCCESS; + + std::vector snapshots; for (char* arg : args) { StatusOr maybe_snapshot = ParseInputFile(arg, opts); - if (!maybe_snapshot.ok()) { - continue; + if (maybe_snapshot.ok()) { + snapshots.push_back(std::move(maybe_snapshot).ValueOrDie()); + } else { + LOG(ERROR) << "Can't handle file " << arg << ": " + << maybe_snapshot.status(); } - HloSnapshot snapshot = std::move(maybe_snapshot).ValueOrDie(); - StatusOr result_status = ReplayComputation(snapshot, client, opts); + } + + // Compile all the modules in parallel. + LOG(INFO) << "Compiling " << snapshots.size() << " modules in parallel."; + std::vector> executables; + { + // ThreadPool CHECK-fails if we give it 0 threads. + tensorflow::thread::ThreadPool thread_pool( + tensorflow::Env::Default(), tensorflow::ThreadOptions(), + "compile_modules", std::max(size_t{1}, snapshots.size()), + /*low_latency_hint=*/false); + executables.resize(snapshots.size()); + for (int64 i = 0; i < snapshots.size(); ++i) { + thread_pool.Schedule([&snapshots, &executables, client, i] { + executables[i] = CompileExecutable(snapshots[i], client); + }); + } + } + LOG(INFO) << "Done compiling; now running the modules."; + + for (int64 i = 0; i < executables.size(); ++i) { + LocalExecutable* executable = executables[i].get(); + StatusOr result_status = + ReplayComputation(snapshots[i], executable, client, opts); if (!result_status.ok()) { - fprintf(stderr, "%s: error: %s\n", arg, + fprintf(stderr, "%s: error: %s\n", args[i], result_status.status().ToString().c_str()); exit_status = EXIT_FAILURE; continue; @@ -248,10 +299,11 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { if (opts.print_result) { Literal result = std::move(result_status).ValueOrDie(); - fprintf(stdout, "%s: %s :: %s:%s\n", arg, - snapshot.hlo().hlo_module().name().c_str(), + fprintf(stdout, "%s: %s :: %s:%s\n", args[i], + executable->executable()->module().name().c_str(), ShapeUtil::HumanString(result.shape()).c_str(), result.ToString().c_str()); + auto& snapshot = snapshots[i]; if (snapshot.has_result()) { std::unique_ptr literal = Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc index fe8e72ba32bb4493b2751cfdfeb977f271092f9c..51909190a3ef20c3df78d08796e88bdbb650609d 100644 --- a/tensorflow/compiler/xla/tools/show_literal.cc +++ b/tensorflow/compiler/xla/tools/show_literal.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc index 8525873e913185554d18df8c8c3584bfcdcdcabe..48c837481181f6ad8f864569fd62e0e23fa02ecd 100644 --- a/tensorflow/compiler/xla/tools/show_text_literal.cc +++ b/tensorflow/compiler/xla/tools/show_text_literal.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/text_literal_reader.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index b4f45cc972d3d397ddff8e8d9163d1fef387392f..5ae099a4622bb7116c7a17f93060b699ead6e3a6 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -499,17 +500,17 @@ bool c_is_sorted(const C& c, Compare&& comp) { } template -auto c_adjacent_find(const C& c) -> decltype(std::begin(c)) { +auto c_adjacent_find(C& c) -> decltype(std::begin(c)) { return std::adjacent_find(std::begin(c), std::end(c)); } template -auto c_find_if(const C& c, Pred&& pred) -> decltype(std::begin(c)) { +auto c_find_if(C& c, Pred&& pred) -> decltype(std::begin(c)) { return std::find_if(std::begin(c), std::end(c), std::forward(pred)); } template -auto c_find(const C& c, Value&& value) -> decltype(std::begin(c)) { +auto c_find(C& c, Value&& value) -> decltype(std::begin(c)) { return std::find(std::begin(c), std::end(c), std::forward(value)); } @@ -533,12 +534,24 @@ c_count_if(const C& c, Pred&& pred) { return std::count_if(std::begin(c), std::end(c), std::forward(pred)); } +// Determines whether `value` is present in `c`. +template +bool c_linear_search(const C& c, T&& value) { + auto last = std::end(c); + return std::find(std::begin(c), last, std::forward(value)) != last; +} + template int64 FindIndex(const C& c, Value&& value) { auto it = c_find(c, std::forward(value)); return std::distance(c.begin(), it); } +template +bool ArrayContains(tensorflow::gtl::ArraySlice c, const T& value) { + return c_find(c, value) != c.end(); +} + template void InsertAt(C* c, int64 index, Value&& value) { c->insert(c->begin() + index, std::forward(value)); @@ -549,6 +562,17 @@ void EraseAt(C* c, int64 index) { c->erase(c->begin() + index); } +template +std::vector ArraySliceToVector(tensorflow::gtl::ArraySlice slice) { + return std::vector(slice.begin(), slice.end()); +} + +template +std::vector InlinedVectorToVector( + const tensorflow::gtl::InlinedVector& inlined_vector) { + return std::vector(inlined_vector.begin(), inlined_vector.end()); +} + // Returns true if `x` fits in 32-bits. template bool IsInt32(T x) { diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 6f07e4606bef015214f2c564515c8258a906205b..3b72eb17c600abf542caffb66fe150a051b4bb4d 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -104,15 +104,6 @@ message DebugOptions { // interpretation of this value is left to the backends. int32 xla_backend_optimization_level = 31; - // When true, "unsafe" mathematical optimizations are enabled. These - // transformations include but are not limited to: - // - // - Reducing the precision of operations (e.g. using an approximate sin - // function, or transforming x/y into x * (1/y)). - // - Assuming that operations never produce or consume NaN or +/- Inf. - // - Assuming that +0 and -0 are indistinguishable. - bool xla_enable_fast_math = 32; - // Embed the compiler IR as a string in the executable. bool xla_embed_ir_in_executable = 33; @@ -194,6 +185,16 @@ message DebugOptions { // Maximum kernel unroll factor for the GPU backend. int32 xla_gpu_max_kernel_unroll_factor = 98; + // When true, "unsafe" mathematical optimizations are enabled. These + // transformations include but are not limited to: + // + // - Reducing the precision of operations (e.g. using an approximate sin + // function, or transforming x/y into x * (1/y)). + // - Assuming that operations never produce or consume NaN or +/- Inf. + // - Assuming that +0 and -0 are indistinguishable. + bool xla_cpu_enable_fast_math = 99; + bool xla_gpu_enable_fast_math = 100; + // Extra options to pass to the compilation backend; specific interpretation // of these values is left to the backend. map xla_backend_extra_options = 500; @@ -293,6 +294,7 @@ message ComputationStatsResponse { } message CreateChannelHandleRequest { + ChannelHandle.ChannelType channel_type = 1; } message CreateChannelHandleResponse { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 0af73e8a93060f4569ddef9697b89a6fa2b8674b..4c35e93d38450b8263290da8e327d1f2126c1532 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -274,6 +274,9 @@ message ExecutionProfile { // for the input data transfer since the memory is initialized with the proper // values before the execution. int64 compute_and_transfer_time_ns = 5; + + // The size of the binary code in the executable. + int64 executable_size_in_bytes = 6; } // Handle given to a user that represents an execution that the user launched @@ -305,6 +308,22 @@ message DeviceHandle { // Send instructions will be blocked until the data is transferred. message ChannelHandle { int64 handle = 1; + enum ChannelType { + // Invalid primitive type to serve as default. + CHANNEL_TYPE_INVALID = 0; + + // A channel for sending data between devices. + DEVICE_TO_DEVICE = 1; + + // A channel for sending data from the device to the host. Can only be used + // with a Send operation. + DEVICE_TO_HOST = 2; + + // A channel for sending data from the host to the device. Can only be used + // with a Recv operation. + HOST_TO_DEVICE = 3; + } + ChannelType type = 2; } // DeviceAssignmentProto is a serialized form of DeviceAssignment class, which @@ -428,6 +447,20 @@ message GatherDimensionNumbers { int64 index_vector_dim = 4; } +// Describes the dimension numbers for a scatter operation. +// +// All the fields are similar to the corresponding fields in +// GatherDimensionNumbers. Differences are noted below. +message ScatterDimensionNumbers { + // The set of dimensions in the updates shape that are window dimensions. + repeated int64 update_window_dims = 1; + // The set of window dimensions that must be inserted into the updates shape. + repeated int64 inserted_window_dims = 2; + + repeated int64 scatter_dims_to_operand_dims = 3; + int64 index_vector_dim = 4; +} + message ConvolutionDimensionNumbers { // The number of the dimension that represents batch in the input. int64 input_batch_dimension = 7; @@ -528,3 +561,11 @@ message OpSharding { // to. repeated OpSharding tuple_shardings = 5; } + +// Describes the replica groups in a cross replica op (e.g., all-reduce and +// all-to-all). +message ReplicaGroup { + // The ids of the replicas that belongs to the same group. The ordering of the + // ids matters in some op (e.g., all-to-all). + repeated int64 replica_ids = 1; +} diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 50b1ae5cc3cba2d6ac89c4415a3419ffdf7aec93..23bb783e2207da7076833138f4421980ad20bd96 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -7,8 +7,8 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load("//third_party/mpi:mpi.bzl", "if_mpi") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") load("//tensorflow:tensorflow.bzl", "if_not_windows") +load("//tensorflow:tensorflow.bzl", "if_not_windows_cuda") py_library( name = "contrib_py", @@ -26,14 +26,11 @@ py_library( "//tensorflow/contrib/bayesflow:bayesflow_py", "//tensorflow/contrib/boosted_trees:init_py", "//tensorflow/contrib/checkpoint/python:checkpoint", - "//tensorflow/contrib/cloud:cloud_py", - "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/coder:coder_py", "//tensorflow/contrib/compiler:compiler_py", "//tensorflow/contrib/autograph", "//tensorflow/contrib/constrained_optimization", - "//tensorflow/contrib/control_flow", "//tensorflow/contrib/copy_graph:copy_graph_py", "//tensorflow/contrib/crf:crf_py", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", @@ -46,10 +43,10 @@ py_library( "//tensorflow/contrib/factorization:factorization_py", "//tensorflow/contrib/feature_column:feature_column_py", "//tensorflow/contrib/framework:framework_py", - "//tensorflow/contrib/fused_conv:fused_conv_py", "//tensorflow/contrib/gan", "//tensorflow/contrib/graph_editor:graph_editor_py", "//tensorflow/contrib/grid_rnn:grid_rnn_py", + "//tensorflow/contrib/hadoop", "//tensorflow/contrib/hooks", "//tensorflow/contrib/image:distort_image_py", "//tensorflow/contrib/image:image_py", @@ -111,20 +108,33 @@ py_library( "//tensorflow/contrib/tfprof", "//tensorflow/contrib/timeseries", "//tensorflow/contrib/tpu", - "//tensorflow/contrib/tpu:tpu_py", "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:util", - ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_tensorrt([ - "//tensorflow/contrib/tensorrt:init_py", - ]) + select({ + "//tensorflow/python/estimator:estimator_py", + ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + select({ "//tensorflow:with_kafka_support_windows_override": [], "//tensorflow:with_kafka_support": [ "//tensorflow/contrib/kafka", ], "//conditions:default": [], - }) + if_not_windows([ + }) + select({ + "//tensorflow:with_aws_support_windows_override": [], + "//tensorflow:with_aws_support": [ + "//tensorflow/contrib/kinesis", + ], + "//conditions:default": [], + }) + if_not_windows_cuda([ + "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols + ]) + if_not_windows([ + "//tensorflow/contrib/bigtable", # depends on bigtable + "//tensorflow/contrib/cloud:cloud_py", # doesn't compile on Windows "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", + # TODO(aaroey): tensorrt dependency has to appear before tflite so the + # build can resolve its flatbuffers symbols within the tensorrt library. + # This is an issue with the tensorrt static library and will be fixed by + # the next tensorrt release, so fix the order here after that. + "//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows "//tensorflow/contrib/lite/python:lite", # unix dependency, need to fix code ]), ) @@ -137,6 +147,7 @@ cc_library( "//tensorflow/contrib/coder:all_kernels", "//tensorflow/contrib/data/kernels:dataset_kernels", "//tensorflow/contrib/factorization/kernels:all_kernels", + "//tensorflow/contrib/hadoop:dataset_kernels", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels", "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_kernels", @@ -154,6 +165,12 @@ cc_library( "//tensorflow/contrib/kafka:dataset_kernels", ], "//conditions:default": [], + }) + select({ + "//tensorflow:with_aws_support_windows_override": [], + "//tensorflow:with_aws_support": [ + "//tensorflow/contrib/kinesis:dataset_kernels", + ], + "//conditions:default": [], }), ) @@ -166,6 +183,7 @@ cc_library( "//tensorflow/contrib/data:dataset_ops_op_lib", "//tensorflow/contrib/factorization:all_ops", "//tensorflow/contrib/framework:all_ops", + "//tensorflow/contrib/hadoop:dataset_ops_op_lib", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib", "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", "//tensorflow/contrib/nccl:nccl_ops_op_lib", @@ -183,5 +201,11 @@ cc_library( "//tensorflow/contrib/kafka:dataset_ops_op_lib", ], "//conditions:default": [], + }) + select({ + "//tensorflow:with_aws_support_windows_override": [], + "//tensorflow:with_aws_support": [ + "//tensorflow/contrib/kinesis:dataset_ops_op_lib", + ], + "//conditions:default": [], }), ) diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index ad8c40395c2cdcc5e4288e04bb2115bd3627cdc9..e18ea8df4df719a7317333cf9038ce7facf8d6ac 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -22,15 +22,16 @@ from __future__ import print_function import os # Add projects here, they will show up under tf.contrib. +from tensorflow.contrib import autograph from tensorflow.contrib import batching from tensorflow.contrib import bayesflow from tensorflow.contrib import checkpoint -from tensorflow.contrib import cloud +if os.name != "nt": + from tensorflow.contrib import cloud from tensorflow.contrib import cluster_resolver from tensorflow.contrib import coder from tensorflow.contrib import compiler from tensorflow.contrib import constrained_optimization -from tensorflow.contrib import control_flow from tensorflow.contrib import copy_graph from tensorflow.contrib import crf from tensorflow.contrib import cudnn_rnn diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py index 159d985db5c48f8fe1a26350255f8d8f68482473..3b539734a236804026826a8117d9c668c0dd089a 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce.py @@ -32,10 +32,10 @@ def _flatten_tensors(tensors): """Check tensors for isomorphism and flatten. Args: - tensors: list of T @{tf.Tensor} which must all have the same shape. + tensors: list of T `tf.Tensor` which must all have the same shape. Returns: - tensors: a list of T @{tf.Tensor} which are flattened (1D) views of tensors + tensors: a list of T `tf.Tensor` which are flattened (1D) views of tensors shape: the original shape of each element of input tensors Raises: @@ -61,12 +61,12 @@ def _reshape_tensors(tensors, shape): """Reshape tensors flattened by _flatten_tensors. Args: - tensors: list of T @{tf.Tensor} of identical length 1D tensors. + tensors: list of T `tf.Tensor` of identical length 1D tensors. shape: list of integers describing the desired shape. Product of the elements must equal the length of each tensor. Returns: - list of T @{tf.Tensor} which are the reshaped inputs. + list of T `tf.Tensor` which are the reshaped inputs. """ reshaped = [] for t in tensors: @@ -79,12 +79,12 @@ def _padded_split(tensor, pieces): """Like split for 1D tensors but pads-out case where len % pieces != 0. Args: - tensor: T @{tf.Tensor} that must be 1D. + tensor: T `tf.Tensor` that must be 1D. pieces: a positive integer specifying the number of pieces into which tensor should be split. Returns: - list of T @{tf.Tensor} of length pieces, which hold the values of + list of T `tf.Tensor` of length pieces, which hold the values of thin input tensor, in order. The final tensor may be zero-padded on the end to make its size equal to those of all of the other tensors. @@ -132,11 +132,11 @@ def _strip_padding(tensors, pad_len): """Strip the suffix padding added by _padded_split. Args: - tensors: list of T @{tf.Tensor} of identical length 1D tensors. + tensors: list of T `tf.Tensor` of identical length 1D tensors. pad_len: number of elements to be stripped from the end of each tensor. Returns: - list of T @{tf.Tensor} which are the stripped inputs. + list of T `tf.Tensor` which are the stripped inputs. Raises: ValueError: tensors must be a non-empty list of 1D tensors, and @@ -161,12 +161,12 @@ def _ragged_split(tensor, pieces): """Like split for 1D tensors but allows case where len % pieces != 0. Args: - tensor: T @{tf.Tensor} that must be 1D. + tensor: T `tf.Tensor` that must be 1D. pieces: a positive integer specifying the number of pieces into which tensor should be split. Returns: - list of T @{tf.Tensor} of length pieces, which hold the values of + list of T `tf.Tensor` of length pieces, which hold the values of the input tensor, in order. The final tensor may be shorter than the others, which will all be of equal length. @@ -256,7 +256,7 @@ def build_ring_all_reduce(input_tensors, num_workers, num_subchunks, """Construct a subgraph performing a ring-style all-reduce of input_tensors. Args: - input_tensors: a list of T @{tf.Tensor} objects, which must all + input_tensors: a list of T `tf.Tensor` objects, which must all have the same shape and type. num_workers: number of worker tasks spanned by input_tensors. num_subchunks: number of subchunks each device should process in one tick. @@ -272,7 +272,7 @@ def build_ring_all_reduce(input_tensors, num_workers, num_subchunks, size. Returns: - a list of T @{tf.Tensor} identical sum-reductions of input_tensors. + a list of T `tf.Tensor` identical sum-reductions of input_tensors. """ if len(input_tensors) < 2: raise ValueError("input_tensors must be length 2 or longer") @@ -299,7 +299,7 @@ def _build_ring_gather(input_tensors, devices, num_subchunks, """Construct a subgraph for the first (reduction) pass of ring all-reduce. Args: - input_tensors: a list of T @{tf.Tensor} 1D input tensors of same + input_tensors: a list of T `tf.Tensor` 1D input tensors of same shape and type. devices: array of device name strings num_subchunks: number of subchunks each device should process in one tick. @@ -311,7 +311,7 @@ def _build_ring_gather(input_tensors, devices, num_subchunks, ValueError: tensors must all be one dimensional. Returns: - list of list of T @{tf.Tensor} of (partially) reduced values where + list of list of T `tf.Tensor` of (partially) reduced values where exactly num_subchunks chunks at each device are fully reduced. """ num_devices = len(input_tensors) @@ -360,11 +360,11 @@ def _apply_unary_to_chunks(f, chunks_by_dev): """Apply a unary op to each tensor in chunks_by_dev, on same device. Args: - f: a unary function over T @{tf.Tensor}. - chunks_by_dev: list of lists of T @{tf.Tensor}. + f: a unary function over T `tf.Tensor`. + chunks_by_dev: list of lists of T `tf.Tensor`. Returns: - new list of lists of T @{tf.Tensor} with the same structure as + new list of lists of T `tf.Tensor` with the same structure as chunks_by_dev containing the derived tensors. """ output = [] @@ -381,14 +381,14 @@ def _build_ring_scatter(pred_by_s_d, rank_by_s_d, Args: pred_by_s_d: as produced by _ring_permutations rank_by_s_d: as produced by _ring_permutations - chunks_by_dev: list of list of T @{tf.Tensor} indexed by ints + chunks_by_dev: list of list of T `tf.Tensor` indexed by ints (device, chunk) Raises: ValueError: chunks_by_dev is not well-formed Returns: - list of T @{tf.Tensor} which are the fully reduced tensors, one + list of T `tf.Tensor` which are the fully reduced tensors, one at each device corresponding to the outer dimension of chunks_by_dev. """ num_devices = len(chunks_by_dev) @@ -448,12 +448,12 @@ def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None): the future with edge-case specific logic. Args: - input_tensors: list of T @{tf.Tensor} to be elementwise reduced. + input_tensors: list of T `tf.Tensor` to be elementwise reduced. red_op: a binary elementwise reduction Op. un_op: an optional unary elementwise Op to apply to reduced values. Returns: - list of T @{tf.Tensor} which are the fully reduced tensors, one + list of T `tf.Tensor` which are the fully reduced tensors, one at each device of input_tensors. Raises: @@ -475,13 +475,13 @@ def _build_recursive_hd_gather(input_tensors, devices, red_op): """Construct the gather phase of recursive halving-doubling all-reduce. Args: - input_tensors: list of T @{tf.Tensor} to be elementwise reduced. + input_tensors: list of T `tf.Tensor` to be elementwise reduced. devices: a list of strings naming the devices hosting input_tensors, which will also be used to host the (partial) reduction values. red_op: a binary elementwise reduction Op. Returns: - list of T @{tf.Tensor} which are the fully reduced tensor shards. + list of T `tf.Tensor` which are the fully reduced tensor shards. Raises: ValueError: num_devices not a power of 2, or tensor len not divisible @@ -516,12 +516,12 @@ def _build_recursive_hd_scatter(input_tensors, devices): """Construct the scatter phase of recursive halving-doublng all-reduce. Args: - input_tensors: list of T @{tf.Tensor} that are fully-reduced shards. + input_tensors: list of T `tf.Tensor` that are fully-reduced shards. devices: a list of strings naming the devices on which the reconstituted full tensors should be placed. Returns: - list of T @{tf.Tensor} which are the fully reduced tensors. + list of T `tf.Tensor` which are the fully reduced tensors. """ num_devices = len(devices) num_hops = int(math.log(num_devices, 2)) @@ -571,7 +571,7 @@ def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None): un_op: optional elementwise unary Op to be applied to fully-reduced values. Returns: - list of T @{tf.Tensor} which are the fully reduced tensors. + list of T `tf.Tensor` which are the fully reduced tensors. """ input_tensors, shape = _flatten_tensors(input_tensors) dst_devices = [t.device for t in input_tensors] @@ -594,7 +594,7 @@ def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None): un_op: optional elementwise unary Op to be applied to fully-reduced values. Returns: - list of T @{tf.Tensor} which are the fully reduced shards. + list of T `tf.Tensor` which are the fully reduced shards. Raises: ValueError: inputs not well-formed. @@ -629,7 +629,7 @@ def _build_shuffle_scatter(reduced_shards, dst_devices): should be reconstituted. Returns: - list of T @{tf.Tensor} scattered tensors. + list of T `tf.Tensor` scattered tensors. """ num_devices = len(dst_devices) out_tensors = [] @@ -644,7 +644,7 @@ def _split_by_task(devices, values): Args: devices: list of device name strings - values: list of T @{tf.tensor} of same length as devices. + values: list of T `tf.tensor` of same length as devices. Returns: (per_task_devices, per_task_values) where both values are @@ -680,14 +680,14 @@ def build_nccl_all_reduce(input_tensors, red_op, un_op=None): """Build a subgraph that does one full all-reduce, using NCCL. Args: - input_tensors: list of T @{tf.Tensor} of same-shape and type values to + input_tensors: list of T `tf.Tensor` of same-shape and type values to be reduced. red_op: binary elementwise reduction operator. Must be one of {tf.add} un_op: optional unary elementwise Op to apply to fully-reduce values. Returns: - list of T @{tf.Tensor} of reduced values. + list of T `tf.Tensor` of reduced values. Raises: ValueError: red_op not supported. @@ -709,14 +709,14 @@ def _build_nccl_hybrid(input_tensors, red_op, upper_level_f): """Construct a subgraph for NCCL hybrid all-reduce. Args: - input_tensors: list of T @{tf.Tensor} of same-shape and type values to + input_tensors: list of T `tf.Tensor` of same-shape and type values to be reduced. red_op: binary elementwise reduction operator. upper_level_f: function for reducing one value per worker, across workers. Returns: - list of T @{tf.Tensor} of reduced values. + list of T `tf.Tensor` of reduced values. Raises: ValueError: inputs not well-formed. @@ -797,7 +797,7 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f): """Construct a subgraph for Shuffle hybrid all-reduce. Args: - input_tensors: list of T @{tf.Tensor} of same-shape and type values to + input_tensors: list of T `tf.Tensor` of same-shape and type values to be reduced. gather_devices: list of device names on which to host gather shards. red_op: binary elementwise reduction operator. @@ -805,7 +805,7 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f): workers. Returns: - list of T @{tf.Tensor} of reduced values. + list of T `tf.Tensor` of reduced values. Raises: ValueError: inputs not well-formed. diff --git a/tensorflow/contrib/android/BUILD b/tensorflow/contrib/android/BUILD index c10179ba8b290b6209f5567d6323df4bcf711585..f0b1c92cf7e4b760381da38febd9682ce2a4f27c 100644 --- a/tensorflow/contrib/android/BUILD +++ b/tensorflow/contrib/android/BUILD @@ -1,6 +1,8 @@ # Description: # JNI-based Java inference interface for TensorFlow. +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/android/cmake/src/main/AndroidManifest.xml b/tensorflow/contrib/android/cmake/src/main/AndroidManifest.xml index bced47e046db889366bf88e563d086a8c367431a..c17110a78be49f70ef108be79a624d87ad9ed28d 100644 --- a/tensorflow/contrib/android/cmake/src/main/AndroidManifest.xml +++ b/tensorflow/contrib/android/cmake/src/main/AndroidManifest.xml @@ -1,6 +1,10 @@ + + diff --git a/tensorflow/contrib/autograph/BUILD b/tensorflow/contrib/autograph/BUILD index 30dd846893c30b9205972bd5216cc1871ab03d76..ad700ac4a0342e2a7bc07a6ecf6710cea892e296 100644 --- a/tensorflow/contrib/autograph/BUILD +++ b/tensorflow/contrib/autograph/BUILD @@ -23,9 +23,9 @@ py_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/contrib/autograph/impl", + "//tensorflow/contrib/autograph/lang", "//tensorflow/contrib/autograph/pyct", "//tensorflow/contrib/autograph/utils", - "@gast_archive//:gast", - "@six_archive//:six", + "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/autograph/CONTRIBUTING.md b/tensorflow/contrib/autograph/CONTRIBUTING.md index a4aec8c74a9ad1418072471a5d3cde8c3b968a38..06fb7b03d5dbbfd2fcb6d6a2ecfe5c817f94a469 100644 --- a/tensorflow/contrib/autograph/CONTRIBUTING.md +++ b/tensorflow/contrib/autograph/CONTRIBUTING.md @@ -1,4 +1,4 @@ -# How to Contribute +# How to contribute We'd love to have your patches and contributions! Here are some guidelines. In general, we follow the [TensorFlow contributing guidelines](../../CONTRIBUTING.md), but have some [AutoGraph-specific style guidelines](STYLE_GUIDE.md). More details below. @@ -46,3 +46,50 @@ bazel test --config=opt --copt=-O3 --copt=-march=native \ ``` from the root of the `tensorflow` repository. For more details see the [main TensorFlow Contributing File](../../CONTRIBUTING.md) + +## Developer info + +### Module structure + +The graph below describes the dependencies between AutoGraph modules (not to be mistaken with the directory structure for these modules, which is flat): + +```dot +digraph d_modules { + autograph [style=filled]; + converters; + core; + impl; + lang; + operators; + + autograph -> impl + autograph -> lang + + impl -> converters + impl -> core + impl -> operators + + lang -> operators + + converters -> core + converters -> lang +} +``` + +`autograph` is the sole user-visible module. + +A short description of the modules: + + * `autograph`: the main module imported by the user and by the generated code; only contains declarations + * `impl`: high level code and the implementation of the api frontend + * `core`: base classes for the AutoGraph source code transformation logic; see in particular `converter.py` + * `lang`: special user-visible functions that serve as extensions to the Python language + * `converters`: collection of source code transformation modules specialized for particular AutoGraph features + * `operators`: collection of operators that AutoGraph overloads; these correspond to Python operators as well as Python syntactic structures, like control flow + +There are two additional modules, `pyct` and `utils`. These are independent of AutoGraph: + + * `pyct`: a general purpose Python source code transformation library + * `utils`: the kitchen sync; deprecated + +Note: we have a long term plan to factor out an implementation of `impl` and `converters` that is independent of autograph, into a general purpose Python operator overloading library. diff --git a/tensorflow/contrib/autograph/LIMITATIONS.md b/tensorflow/contrib/autograph/LIMITATIONS.md new file mode 100644 index 0000000000000000000000000000000000000000..d8b1cb7616ac348981bf2b69d6e2fd8d8a6e6b78 --- /dev/null +++ b/tensorflow/contrib/autograph/LIMITATIONS.md @@ -0,0 +1,50 @@ +# Capabilities and Limitations + +TF AutoGraph converts Eager Python code into TensorFlow graph-mode code. For example, users write code with `if` and `while` and AutoGraph automatically converts it into the equivalent `tf.cond`, and `tf.while_loop`. + +Python is a large language, so hoping to convert arbitrary Python code directly to TF graphs is overly ambitious. However, the Python code written to metaprogram TF graphs is in practice a restricted subset. We aim to support as much of this subset as possible. The table below lays out what we currently handle, what we hope to support, and what we have no plans to support. + +# Python Language Support Status + +Note: as more complex features in TensorFlow are made more accessible using AutoGraph, we expect to come across use cases that haven't been tried before, some of which might reveal rare bugs. If we do find any such bugs, we may add additional restrictions for the affected configurations, until those bugs are resolved. + + Construct | Supported now? | Plan to support? | Notes + :--------- | :--------------: | :----------------: | :----- +If statement | Yes | | Converts to `tf.cond`. If variables are created in one branch that don’t exist in another, which is inexpressible in TF, we throw a clear error. +For statement | Yes | | We will specialize `for` loops with unknown and known lengths, as well as for loops over TF datasets. Converts to `tf.while_loop`, with an additional `maximum_iterations` hint, if that is known. Creating variables inside the loop that are used later outside the loop is not supported, as the loop may have no iterations. +While statement | Yes | | Converts to `tf.while_loop`. Creating variables inside the loop is not supported, as the loop may have no iterations. +Continue and break | Yes | | Converts to boolean flags and extra predicates in loop tests. +Composition of control flow | Yes | | Arbitrary composition of `if`, `while`, `for`, `break`, and `continue`, along with other supported language elements, is supported and tested. +Iterators | Some | Yes | Not all iterators supported, but we plan to support everything that can be desugared, such as `enumerate` and `zip`. +Multiple return values | Yes | | We desugar them into variables, boolean flags and conditionals so that the function has a single return value at the end, and provide a clear error if we are unable to do so. +Print expression | Yes | | Wrapped in `PyFunc`, and given proper control dependencies. Optional support for using tf.Log when py_func is undesirable exists. +Static function calls | Yes | | Non-recursive function calls +Nested call trees | Yes | | For example, `f` calls `g` which calls `h`, all of which need conversion. +Recursive function calls | No | Maybe | Based on available support in TF. Currently `function.Defun` is the best candidate, but it is not reentrant. +Python built-ins | Some | Yes | `print`, `len`, `range`, `xrange`, `int`, `float` are supported, and we plan to support or clearly error on all [Python built-ins](https://docs.python.org/3/library/functions.html). +List operations | Yes | | We convert list creation, append, pop and indexing to their TF TensorArray equivalents. However, we do need some extra type hints to fully convert correctly. We hope to remove this limitation. +Function variables | Yes | | e.g. `f_new = f_orig; f_new()` +Lambda functions | No | Yes | Planned feature. +Classes | Yes | | Classes can be converted all at once, or method-by-method. Some limitations exist around static and class methods. +Subclasses | Yes | | Subclassing library objects like tf.keras.Model is also supported. +Dynamic types | Some | | `o = C1() if foo else C2(); o.bar()`. Some scenarios where types are data-dependent may not be supported. We will raise a meaningful error in that case. +Dynamic code / exec | No | | +Reflection | No | | +Try / Except | No | No | No current sane TF equivalent. +Global variables | Restricted | | In general, we only support read-only access to arguments or variables defined outside the converted code. A few exceptions include TensorFlow library code. +Functions with side effects | Some | | Side effects are allowed, under certain circumstances. +Collections | Some | Yes | We currently support lists. There are currently no TF equivalents of dictionaries or tuples. +List Comprehensions | Yes | | We desugar `ListComp` into the appropriate combination of `For` and `If` statements. Other comprehensions are currently very low priority. +Custom context managers | No | Yes | Currently low priority. Left unconverted currently. +Generators | No | Maybe | Could be achievable using queues; very low priority. +Assertions | Yes | | As `tf.Assert` +Deletion | Yes | Maybe | Currently unconverted. If new semanti cs are required for `del`, we are able to add it in. +Inline imports | No | Yes | For example, `import numpy as np; np.eye(3)`. Currently low priority. +Async | No | No | + +## Extra capabilities + + - We liberally add name scopes to generated functions + - Operations get decent default names everywhere (planned) + - Statements that have no output values are given correct control dependencies. For example, `for i in range(n): print(i)` will have control dependencies to ensure the `print` statements are executed serially. + diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md index 674859bed4ec157d5d5b33b6fc015c930e54b392..cc54da4daa9a5bb4e64145963ffec63021d08876 100644 --- a/tensorflow/contrib/autograph/README.md +++ b/tensorflow/contrib/autograph/README.md @@ -1,10 +1,10 @@ # AutoGraph -IMPORTANT: AutoGraph is alpha software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)). +IMPORTANT: AutoGraph is beta software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)). AutoGraph is a Python to TensorFlow compiler. -With AutoGraph, you can write [Eager style](https://www.tensorflow.org/programmers_guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops. +With AutoGraph, you can write [Eager style](https://www.tensorflow.org/guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops. [Please see this file for which parts of the Python language we currently support](LIMITATIONS.md). For example, this Python function: @@ -68,12 +68,21 @@ Then import the `autograph` module from `tf.contrib`: from tensorflow.contrib import autograph as ag ``` -### Interactive demo notebooks +### Related links -For more extensive examples, check out these interactive notebooks: +Articles: - * [RNN trained using Keras and Estimators](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb) + * [TensorFlow blog post](https://medium.com/tensorflow/autograph-converts-python-into-tensorflow-graphs-b2a871f87ec7) + +Interactive notebooks: + + * [Quick guide](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/guide/autograph.ipynb) + * [RNN trained using Keras and Estimators](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb) * [Demo from the TF Dev Summit 2018](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb) + * [Basic control flow speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_collatz_speed_test.ipynb) + * [MNIST training speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_mnist_speed_test.ipynb) + * [Basic algorithm samples](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb) + * [Introductory workshop support notebook](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb) ## Using with annotations @@ -120,3 +129,15 @@ You can use the functional API to inspect the generated code as well: print(ag.to_code(f)) # Output: ``` + +## Filing bugs and feature requests + +### Reporting a bug + + - If AutoGraph-generated code is compiling and running, but producing an incorrect result, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message. + - If AutoGraph-generated code is compiling, but not running, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message. + - If AutoGraph-generated code is not compiling, send us two minimal pieces of code. First, the Eager code that you would like to write, and second, the Graph code that you would like AutoGraph to have generated for you. + +### Requesting a feature + +If you’d like AutoGraph to convert a feature of Python or TF that we currently don’t handle, please let us know by filing a bug. We’ll make it as easy as possible to interact with us through there. diff --git a/tensorflow/contrib/autograph/STYLE_GUIDE.md b/tensorflow/contrib/autograph/STYLE_GUIDE.md index 866e5f583a34570dfddc733f57561ed1d2b7c5bf..7e6b0cc27dd1cf8c0f459a0a34f98092728342a2 100644 --- a/tensorflow/contrib/autograph/STYLE_GUIDE.md +++ b/tensorflow/contrib/autograph/STYLE_GUIDE.md @@ -20,7 +20,17 @@ Naming conventions: Below are AutoGraph-specific conventions. In the event of conflict, it supercedes all previous conventions. -1. __Citations in Docstrings.__ Write a `#### References` subsection at the +1. __Types in docstrings.__ Use [PEP 484][https://www.python.org/dev/peps/pep-0484/] + notation to describe the type for args, return values and attributes. + + Example: + + ``` + Args: + foo: Dict[str, List[int]], a dictionary of sorts + ``` + +2. __Citations in Docstrings.__ Write a `#### References` subsection at the bottom of any docstring with citations. Use ICLR’s bibliography style to write references; for example, order entries by the first author's last name. Add a link to the paper if the publication is open source (ideally, @@ -60,12 +70,12 @@ it supercedes all previous conventions. https://arxiv.org/abs/1803.04386 ``` -2. Avoid LaTeX in docstrings. +3. Avoid LaTeX in docstrings. * It is not rendered in many (if not most) editors and can be hard to read for both LaTeX experts and non-experts. -3. Write docstring and comment math using ASCII friendly notation; python using +4. Write docstring and comment math using ASCII friendly notation; python using operators. E.g., `x**2` better than `x^2`, `x[i, j]` better than `x_{i,j}`, `sum{ f(x[i]) : i=1...n }` better than `\sum_{i=1}^n f(x_i)` `int{sin(x) dx: x in [0, 2 pi]}` better than `\int_0^{2\pi} sin(x) dx`. diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py index dbdbad8f4c91c725294baa36acebbaf5b5e8cf5c..26e7a4a4d38e264486c981e6fc4c547bcc53b302 100644 --- a/tensorflow/contrib/autograph/__init__.py +++ b/tensorflow/contrib/autograph/__init__.py @@ -22,17 +22,21 @@ from __future__ import division from __future__ import print_function # TODO(mdan): Bring only the relevant symbols to the top level. -from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph import operators +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.core.errors import GraphConstructionError +from tensorflow.contrib.autograph.core.errors import TfRuntimeError +from tensorflow.contrib.autograph.core.errors import improved_errors +from tensorflow.contrib.autograph.impl.api import RunMode from tensorflow.contrib.autograph.impl.api import convert from tensorflow.contrib.autograph.impl.api import converted_call from tensorflow.contrib.autograph.impl.api import do_not_convert -from tensorflow.contrib.autograph.impl.api import RunMode from tensorflow.contrib.autograph.impl.api import to_code from tensorflow.contrib.autograph.impl.api import to_graph -from tensorflow.contrib.autograph.impl.directives import set_element_type -from tensorflow.contrib.autograph.impl.directives import set_loop_options -from tensorflow.contrib.autograph.impl.special_functions import stack +from tensorflow.contrib.autograph.lang.directives import set_element_type +from tensorflow.contrib.autograph.lang.directives import set_loop_options +from tensorflow.contrib.autograph.lang.special_functions import stack +from tensorflow.contrib.autograph.lang.special_functions import tensor_list from tensorflow.contrib.autograph.pyct.transformer import AutographParseError from tensorflow.python.util.all_util import remove_undocumented @@ -46,10 +50,15 @@ _allowed_symbols = [ 'to_graph', # Overloaded operators 'operators', - # Special functions and directives + # Errors + 'improved_errors', + 'GraphConstructionError', + 'TfRuntimeError', + # Python language "extensions" 'set_element_type', 'set_loop_options', 'stack', + 'tensor_list', # Exceptions 'AutographParseError', # Utilities: to be removed diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD index 284ad84be566199adaaa1ab641d37528ae4dfd2d..2d2ab7040a8bb76f9538f201f75a2e4dcba0f511 100644 --- a/tensorflow/contrib/autograph/converters/BUILD +++ b/tensorflow/contrib/autograph/converters/BUILD @@ -21,40 +21,29 @@ py_library( "break_statements.py", "builtin_functions.py", "call_trees.py", + "conditional_expressions.py", "continue_statements.py", "control_flow.py", "decorators.py", - "ifexp.py", - "list_comprehension.py", + "directives.py", + "error_handlers.py", + "list_comprehensions.py", "lists.py", "logical_expressions.py", "name_scopes.py", + "return_statements.py", "side_effect_guards.py", - "single_return.py", "slices.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ - "@gast_archive//:gast", - ], -) - -py_library( - name = "test_lib", - srcs = [ - "converter_test_base.py", - ], - srcs_version = "PY2AND3", - visibility = ["//tensorflow:__subpackages__"], - deps = [ - ":converters", - "//tensorflow/contrib/autograph/operators", + "//tensorflow/contrib/autograph/core", + "//tensorflow/contrib/autograph/lang", "//tensorflow/contrib/autograph/pyct", "//tensorflow/contrib/autograph/pyct/static_analysis", - "//tensorflow/contrib/autograph/utils", + "//tensorflow/python:util", "@gast_archive//:gast", - "@six_archive//:six", ], ) @@ -64,7 +53,8 @@ py_test( srcs_version = "PY2AND3", tags = ["no_windows"], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -74,7 +64,8 @@ py_test( srcs = ["break_statements_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -85,7 +76,8 @@ py_test( srcs_version = "PY2AND3", tags = ["no_windows"], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -97,18 +89,31 @@ py_test( srcs_version = "PY2AND3", tags = ["no_windows"], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/impl", "//tensorflow/python:client_testlib", ], ) +py_test( + name = "conditional_expressions_test", + srcs = ["conditional_expressions_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "continue_statements_test", srcs = ["continue_statements_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -118,7 +123,8 @@ py_test( srcs = ["control_flow_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -127,8 +133,25 @@ py_test( name = "decorators_test", srcs = ["decorators_test.py"], srcs_version = "PY2AND3", + tags = [ + "no_pip", + "no_windows", + ], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "directives_test", + srcs = ["directives_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", + "//tensorflow/contrib/autograph/lang", "//tensorflow/python:client_testlib", ], ) @@ -137,18 +160,20 @@ py_test( name = "name_scopes_test", srcs = ["name_scopes_test.py"], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], ) py_test( - name = "list_comprehension_test", - srcs = ["list_comprehension_test.py"], + name = "list_comprehensions_test", + srcs = ["list_comprehensions_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -158,7 +183,8 @@ py_test( srcs = ["lists_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -168,7 +194,8 @@ py_test( srcs = ["logical_expressions_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -177,34 +204,33 @@ py_test( name = "side_effect_guards_test", srcs = ["side_effect_guards_test.py"], srcs_version = "PY2AND3", - tags = [ - # TODO(mdan): Fix. - "flaky", - "notap", - ], + tags = ["notsan"], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) py_test( - name = "single_return_test", - srcs = ["single_return_test.py"], + name = "return_statements_test", + srcs = ["return_statements_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], ) py_test( - name = "ifexp_test", - srcs = ["ifexp_test.py"], + name = "error_handlers_test", + srcs = ["error_handlers_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], @@ -215,7 +241,8 @@ py_test( srcs = ["slices_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], diff --git a/tensorflow/contrib/autograph/converters/__init__.py b/tensorflow/contrib/autograph/converters/__init__.py index e4e8eda42f655e204310eaa9defdd5c90bf06e15..6325ac78dc3a08d14c1abf5e0f1ae60258639162 100644 --- a/tensorflow/contrib/autograph/converters/__init__.py +++ b/tensorflow/contrib/autograph/converters/__init__.py @@ -18,5 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# TODO(mdan): Define a base transformer class that can recognize skip_processing -# TODO(mdan): All converters are incomplete, especially those that change blocks +# Naming conventions: +# * each converter should specialize on a single idiom; be consistent with +# the Python reference for naming +# * all converters inherit core.converter.Base +# * module names describe the idiom that the converter covers, plural +# * the converter class is named consistent with the module, singular and +# includes the word Transformer +# +# Example: +# +# lists.py +# class ListTransformer(converter.Base) diff --git a/tensorflow/contrib/autograph/converters/asserts.py b/tensorflow/contrib/autograph/converters/asserts.py index 3b0db677ce5e417e7afea8d8fe4121a0352bb6d7..af2f20f267d5cc64a6e9507a08c44f7e52245c28 100644 --- a/tensorflow/contrib/autograph/converters/asserts.py +++ b/tensorflow/contrib/autograph/converters/asserts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Converts Assert statements to their corresponding TF calls.""" +"""Converts assert statements to their corresponding TF calls.""" from __future__ import absolute_import from __future__ import division @@ -20,12 +20,12 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class AssertsTransformer(transformer.Base): - """Transforms Print nodes to Call so they can be handled as functions.""" +class AssertTransformer(converter.Base): + """Transforms Assert nodes to Call so they can be handled as functions.""" def visit_Assert(self, node): self.generic_visit(node) @@ -45,5 +45,5 @@ class AssertsTransformer(transformer.Base): raise NotImplementedError('can only convert string messages for now.') -def transform(node, context): - return AssertsTransformer(context).visit(node) +def transform(node, ctx): + return AssertTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/asserts_test.py b/tensorflow/contrib/autograph/converters/asserts_test.py index cc913febe8d0f411588af69b87ec52ce58f4469c..38faba45df6746d56933a1647594af133b671628 100644 --- a/tensorflow/contrib/autograph/converters/asserts_test.py +++ b/tensorflow/contrib/autograph/converters/asserts_test.py @@ -21,21 +21,21 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.converters import asserts -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.platform import test -class AssertsTest(converter_test_base.TestCase): +class AssertsTest(converter_testing.TestCase): def test_transform(self): def test_fn(a): assert a > 0 - node = self.parse_and_analyze(test_fn, {}) - node = asserts.transform(node, self.ctx) + node, ctx = self.prepare(test_fn, {}) + node = asserts.transform(node, ctx) - self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call)) + self.assertTrue(isinstance(node.body[0].value, gast.Call)) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py index 775d92c1d9f8bc35d1eda62f3f3ef7ee43414779..180779670d91abd7d395bda0b63f592967c5015b 100644 --- a/tensorflow/contrib/autograph/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -12,32 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Canonicalizes break statements by de-sugaring into a control boolean.""" +"""Lowers break statements to conditionals.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno -# Tags for local state. -BREAK_USED = 'break_used' -CONTROL_VAR_NAME = 'control_var_name' +class _Break(object): + def __init__(self): + self.used = False + self.control_var_name = None -class BreakStatementTransformer(transformer.Base): + def __repr__(self): + return 'used: %s, var: %s' % (self.used, self.control_var_name) + + +class BreakTransformer(converter.Base): """Canonicalizes break statements into additional conditionals.""" def visit_Break(self, node): - self.set_local(BREAK_USED, True) - var_name = self.get_local(CONTROL_VAR_NAME) + self.state[_Break].used = True + var_name = self.state[_Break].control_var_name # TODO(mdan): This will fail when expanded inside a top-level else block. template = """ - var_name = True + var_name = tf.constant(True) continue """ return templates.replace(template, var_name=var_name) @@ -57,20 +62,20 @@ class BreakStatementTransformer(transformer.Base): block=block) return node - def _track_body(self, nodes, break_var): - self.enter_local_scope() - self.set_local(CONTROL_VAR_NAME, break_var) + def _process_body(self, nodes, break_var): + self.state[_Break].enter() + self.state[_Break].control_var_name = break_var nodes = self.visit_block(nodes) - break_used = self.get_local(BREAK_USED, False) - self.exit_local_scope() + break_used = self.state[_Break].used + self.state[_Break].exit() return nodes, break_used def visit_While(self, node): scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.context.namer.new_symbol('break_', scope.referenced) + break_var = self.ctx.namer.new_symbol('break_', scope.referenced) node.test = self.visit(node.test) - node.body, break_used = self._track_body(node.body, break_var) + node.body, break_used = self._process_body(node.body, break_var) # A break in the else clause applies to the containing scope. node.orelse = self.visit_block(node.orelse) @@ -80,7 +85,7 @@ class BreakStatementTransformer(transformer.Base): guarded_orelse = self._guard_if_present(node.orelse, break_var) template = """ - var_name = False + var_name = tf.constant(False) while test and not var_name: body else: @@ -97,11 +102,11 @@ class BreakStatementTransformer(transformer.Base): def visit_For(self, node): scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.context.namer.new_symbol('break_', scope.referenced) + break_var = self.ctx.namer.new_symbol('break_', scope.referenced) node.target = self.visit(node.target) node.iter = self.visit(node.iter) - node.body, break_used = self._track_body(node.body, break_var) + node.body, break_used = self._process_body(node.body, break_var) # A break in the else clause applies to the containing scope. node.orelse = self.visit_block(node.orelse) @@ -117,7 +122,7 @@ class BreakStatementTransformer(transformer.Base): # the control variable is marked as used. # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name) template = """ - var_name = False + var_name = tf.constant(False) for target in iter_: (var_name,) body @@ -137,5 +142,5 @@ class BreakStatementTransformer(transformer.Base): return node -def transform(node, context): - return BreakStatementTransformer(context).visit(node) +def transform(node, ctx): + return BreakTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/break_statements_test.py b/tensorflow/contrib/autograph/converters/break_statements_test.py index 1af59e9b5260fe0d3a3ef72c7a003dc451e230f3..fcae7d68c0f90817e001b45fa86ca6be08456027 100644 --- a/tensorflow/contrib/autograph/converters/break_statements_test.py +++ b/tensorflow/contrib/autograph/converters/break_statements_test.py @@ -19,13 +19,20 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph.converters import break_statements -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.python.eager import context as tfe_ctx +from tensorflow.python.framework import constant_op from tensorflow.python.platform import test -class BreakCanonicalizationTest(converter_test_base.TestCase): +class BreakCanonicalizationTest(converter_testing.TestCase): - def test_basic_while(self): + def assertTransformedEquivalent(self, test_fn, *inputs): + with self.converted(test_fn, break_statements, {}, + constant_op.constant) as result: + self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) + + def test_while_loop(self): def test_fn(x): v = [] @@ -36,15 +43,12 @@ class BreakCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v - node = self.parse_and_analyze(test_fn, {}) - node = break_statements.transform(node, self.ctx) - - with self.compiled(node) as result: - self.assertEqual([], result.test_fn(0)) - self.assertEqual([], result.test_fn(1)) - self.assertEqual([3], result.test_fn(4)) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 1) + self.assertTransformedEquivalent(test_fn, 4) - def test_basic_for(self): + def test_for_loop(self): def test_fn(a): v = [] @@ -55,18 +59,13 @@ class BreakCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v - node = self.parse_and_analyze(test_fn, {}) - node = break_statements.transform(node, self.ctx) - - with self.compiled(node) as result: + with self.converted(test_fn, break_statements, {}, + constant_op.constant) as result: # The break is incompletely canonicalized. The loop will not interrupt, # but the section following the break will be skipped. - self.assertEqual([], result.test_fn([])) - self.assertEqual([3, 3], result.test_fn([4, 4])) - self.assertEqual([3], result.test_fn([4, 5])) self.assertEqual([3], result.test_fn([5, 4])) - def test_deeply_nested(self): + def test_nested(self): def test_fn(x): v = [] @@ -83,13 +82,10 @@ class BreakCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v, u, w - node = self.parse_and_analyze(test_fn, {}) - node = break_statements.transform(node, self.ctx) - - with self.compiled(node) as result: - self.assertEqual(([], [], []), result.test_fn(0)) - self.assertEqual(([2, 1], [2], [0]), result.test_fn(3)) - self.assertEqual(([10, 9, 8, 7], [10, 8], [6]), result.test_fn(11)) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 11) def test_nested_loops(self): @@ -109,16 +105,13 @@ class BreakCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v, u - node = self.parse_and_analyze(test_fn, {}) - node = break_statements.transform(node, self.ctx) - - with self.compiled(node) as result: - self.assertEqual(([], []), result.test_fn(0)) - self.assertEqual(([1], []), result.test_fn(2)) - self.assertEqual(([2, 1], [1]), result.test_fn(3)) - self.assertEqual(([4, 3, 2, 1], [3, 1]), result.test_fn(5)) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 5) - def test_loop_else(self): + def test_loop_orelse(self): def test_fn(x): v = [] @@ -134,13 +127,10 @@ class BreakCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v, u - node = self.parse_and_analyze(test_fn, {}) - node = break_statements.transform(node, self.ctx) - - with self.compiled(node) as result: - self.assertEqual(([], []), result.test_fn(0)) - self.assertEqual(([], [1]), result.test_fn(2)) - self.assertEqual(([2], [1]), result.test_fn(3)) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, 3) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py index 231e4ee35a72f51845a476d9f605986ac73b4676..b26c52294c2d1c11ce14d8a2903f7f88079a703f 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions.py @@ -20,11 +20,11 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class BuiltinFunctionTransformer(transformer.Base): +class BuiltinFunctionTransformer(converter.Base): """Handles builtin functions. This transformer only covers functions that are translated into a @@ -68,5 +68,5 @@ class BuiltinFunctionTransformer(transformer.Base): return self.visit(function_call) -def transform(node, context): - return BuiltinFunctionTransformer(context).visit(node) +def transform(node, ctx): + return BuiltinFunctionTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py index 30272409df322560b04ba75b3e1cb6f9ad5ff0af..d5c3e2c250cc1ee0205fd1941040bf70de4a149a 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py @@ -18,73 +18,55 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys - import six from tensorflow.contrib.autograph.converters import builtin_functions -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class BuiltinFunctionsTest(converter_test_base.TestCase): +class BuiltinFunctionsTest(converter_testing.TestCase): def test_len(self): def test_fn(a): return len(a) - node = self.parse_and_analyze(test_fn, {'len': len}) - node = builtin_functions.transform(node, self.ctx) - - with self.compiled(node, array_ops.shape) as result: + with self.converted(test_fn, builtin_functions, {'len': len}, + array_ops.shape) as result: with self.test_session() as sess: - self.assertEqual(3, - sess.run( - result.test_fn(constant_op.constant([0, 0, 0])))) - - self.assertEqual(3, result.test_fn([0, 0, 0])) + ops = result.test_fn(constant_op.constant([0, 0, 0])) + self.assertEqual(sess.run(ops), 3) def test_print(self): - def test_fn(a): - print(a) + if six.PY2: + return - node = self.parse_and_analyze(test_fn, {'print': print}) - node = builtin_functions.transform(node, self.ctx) + def test_fn(a): + return print(a) - with self.compiled(node) as result: + with self.converted(test_fn, builtin_functions, {'print': print}) as result: with self.test_session() as sess: - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - result.test_fn(constant_op.constant('a')) - sess.run(sess.graph.get_operations()) - self.assertEqual(out_capturer.getvalue(), 'a\n') - finally: - sys.stdout = sys.__stdout__ + with self.assertPrints('a\n'): + sess.run(result.test_fn('a')) - def test_print_with_op_multiple_values(self): + def test_print_multiple_values(self): - def test_fn(a, b, c): - print(a, b, c) + if six.PY2: + return - node = self.parse_and_analyze(test_fn, {'print': print}) - node = builtin_functions.transform(node, self.ctx) + def test_fn(a, b, c): + return print(a, b, c) - with self.compiled(node) as result: + with self.converted(test_fn, builtin_functions, {'print': print}) as result: with self.test_session() as sess: - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - result.test_fn( - constant_op.constant('a'), constant_op.constant(1), [2, 3]) - sess.run(sess.graph.get_operations()) - self.assertEqual(out_capturer.getvalue(), 'a 1 [2, 3]\n') - finally: - sys.stdout = sys.__stdout__ + with self.assertPrints('a 1 [2, 3]\n'): + sess.run( + result.test_fn( + constant_op.constant('a'), constant_op.constant(1), [2, 3])) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py index b6ecdcb7809b1ad7e7461324cb6a110ef4180609..2d1bed3367fa0b283200b775c5953da80c855367 100644 --- a/tensorflow/contrib/autograph/converters/call_trees.py +++ b/tensorflow/contrib/autograph/converters/call_trees.py @@ -26,12 +26,12 @@ from collections import namedtuple import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import inspect_utils from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect @@ -45,6 +45,9 @@ KNOWN_NUMPY_FUNCTIONS = { } +# TODO(mdan): Get rid of these interfaces. Can now depend directly on Namer. + + class FunctionNamer(object): """Describes the interface for CallTreeTransformer's namer.""" @@ -76,20 +79,18 @@ class FunctionNamer(object): raise NotImplementedError() -class CallTreeTransformer(transformer.Base): - """Transforms the call tree by renaming transformed symbols.""" +# TODO(mdan): Rename to CallsTransformer. - def __init__(self, context, uncompiled_modules, nocompile_decorators): - super(CallTreeTransformer, self).__init__(context) - self.uncompiled_modules = uncompiled_modules - self.nocompile_decorators = nocompile_decorators + +class CallTreeTransformer(converter.Base): + """Transforms the call tree by renaming transformed symbols.""" def _resolve_name(self, node): """Used to resolve decorator info.""" if isinstance(node, gast.Call): return self._resolve_name(node.func) if isinstance(node, gast.Name): - return self.context.namespace.get(node.id) + return self.ctx.namespace.get(node.id) if isinstance(node, gast.Attribute): parent = self._resolve_name(node.value) if parent is not None: @@ -119,12 +120,12 @@ class CallTreeTransformer(transformer.Base): """Determines whether an entity should be compiled in the context.""" # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether. module_name = fqn[0] - for mod in self.uncompiled_modules: + for mod in self.ctx.program.uncompiled_modules: if module_name.startswith(mod[0] + '.'): return False for i in range(1, len(fqn)): - if fqn[:i] in self.uncompiled_modules: + if fqn[:i] in self.ctx.program.uncompiled_modules: return False # Check for local decorations @@ -140,7 +141,7 @@ class CallTreeTransformer(transformer.Base): if hasattr(target_entity, '__pyct_is_compile_decorator'): return False - if target_entity in self.nocompile_decorators: + if target_entity in self.ctx.program.autograph_decorators: return False # Inspect the target function decorators. If any include a @convert @@ -159,7 +160,7 @@ class CallTreeTransformer(transformer.Base): for dec in target_node.decorator_list: decorator_fn = self._resolve_name(dec) if (decorator_fn is not None and - decorator_fn in self.nocompile_decorators): + decorator_fn in self.ctx.program.autograph_decorators): return False return True @@ -174,7 +175,7 @@ class CallTreeTransformer(transformer.Base): return node if anno.hasanno(node, 'is_constructor'): - new_name = self.context.namer.compiled_class_name( + new_name = self.ctx.namer.compiled_class_name( target_fqn, live_entity=target_entity) do_rename = True else: @@ -183,7 +184,7 @@ class CallTreeTransformer(transformer.Base): else: # Fallback - not reliable. owner_type = inspect_utils.getmethodclass(target_entity) - new_name, do_rename = self.context.namer.compiled_function_name( + new_name, do_rename = self.ctx.namer.compiled_function_name( target_fqn, live_entity=target_entity, owner_type=owner_type) if do_rename: @@ -237,7 +238,7 @@ class CallTreeTransformer(transformer.Base): # Before we could convert all the time though, we'd need a reasonable # caching mechanism. template = """ - ag__.converted_call(func, True, False, {}, args) + ag__.converted_call(func, True, False, False, {}, args) """ call_expr = templates.replace(template, func=node.func, args=node.args) new_call = call_expr[0].value @@ -264,15 +265,16 @@ class CallTreeTransformer(transformer.Base): return node def visit_Call(self, node): - # If the function is wrapped by one of the marker decorators, + # If the function call is wrapped by one of the marker decorators, # consider it graph ready. if anno.hasanno(node.func, 'live_val'): target_entity = anno.getanno(node.func, 'live_val') - if target_entity in self.nocompile_decorators: + if target_entity in self.ctx.program.autograph_decorators: if len(node.args) < 1: raise ValueError( 'Found call to decorator function "%s", but it had no arguments. ' - 'A decorator needs at least an argument.') + 'A decorator needs at least one positional argument.' % + target_entity) anno.setanno(node.args[0], 'graph_ready', True) self.generic_visit(node) @@ -309,27 +311,20 @@ class CallTreeTransformer(transformer.Base): # ensure that they return the correct value. return node - if self.context.recursive: + if self.ctx.program.recursive: node = self._insert_dynamic_conversion(node) return node -def transform(node, context, uncompiled_modules, nocompile_decorators): +def transform(node, ctx): """Transform function call to the compiled counterparts. Args: - node: AST to transform. - context: An EntityContext object. - uncompiled_modules: set of string tuples, each tuple represents the fully - qualified name of a package containing functions that will not be - compiled. - nocompile_decorators: A tuple containing decorators to be stripped from - functions during conversion. + node: AST + ctx: EntityContext Returns: A tuple (node, new_names): node: The transformed AST new_names: set(string), containing any newly-generated names """ - t = CallTreeTransformer(context, uncompiled_modules, nocompile_decorators) - node = t.visit(node) - return node + return CallTreeTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/call_trees_test.py b/tensorflow/contrib/autograph/converters/call_trees_test.py index 303dd54a4ee49de27fad0c5cdc2d6274abfe0fa8..8cdba659eee264717204cc6048bbe0b8bbfe245f 100644 --- a/tensorflow/contrib/autograph/converters/call_trees_test.py +++ b/tensorflow/contrib/autograph/converters/call_trees_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.autograph.converters import call_trees -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -29,44 +29,41 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class CallTreesTest(converter_test_base.TestCase): +class CallTreesTest(converter_testing.TestCase): def test_basic(self): def test_fn_1(_): raise ValueError('This should not be called in the compiled version.') - def renamed_test_fn_1(a): + def other_test_fn_1(a): return a + 1 def test_fn_2(a): return test_fn_1(a) + 1 - node = self.parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1}) - node = call_trees.transform(node, self.ctx, (), ()) + ns = {'test_fn_1': test_fn_1} + node, ctx = self.prepare(test_fn_2, ns) + node = call_trees.transform(node, ctx) - with self.compiled(node) as result: - # Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 - # manually. - result.renamed_test_fn_1 = renamed_test_fn_1 - self.assertEquals(3, result.test_fn_2(1)) + with self.compiled(node, ns) as result: + new_name, _ = ctx.namer.compiled_function_name(('test_fn_1',)) + setattr(result, new_name, other_test_fn_1) + self.assertEquals(result.test_fn_2(1), 3) def test_dynamic_function(self): def test_fn_1(): - raise ValueError('This should be masked by the mock.') + raise ValueError('This should be masked by the mock in self.compiled.') def test_fn_2(f): return f() + 3 - node = self.parse_and_analyze(test_fn_2, {}) - node = call_trees.transform(node, self.ctx, (), ()) - - with self.compiled(node) as result: + with self.converted(test_fn_2, call_trees, {}) as result: # 10 = 7 (from the mock) + 3 (from test_fn_2) self.assertEquals(10, result.test_fn_2(test_fn_1)) - def test_simple_methods(self): + def test_basic_method(self): class TestClass(object): @@ -76,49 +73,43 @@ class CallTreesTest(converter_test_base.TestCase): def test_fn_2(self, a): return self.test_fn_1(a) + 1 - node = self.parse_and_analyze( - TestClass.test_fn_2, {'TestClass': TestClass}, - namer=converter_test_base.FakeNoRenameNamer(), + ns = {'TestClass': TestClass} + node, ctx = self.prepare( + TestClass.test_fn_2, + ns, + namer=converter_testing.FakeNoRenameNamer(), arg_types={'self': (TestClass.__name__, TestClass)}) - node = call_trees.transform(node, self.ctx, (), ()) + node = call_trees.transform(node, ctx) - with self.compiled(node) as result: + with self.compiled(node, ns) as result: tc = TestClass() self.assertEquals(3, result.test_fn_2(tc, 1)) - def test_py_func_wrap_no_retval(self): + def test_py_func_no_retval(self): def test_fn(a): setattr(a, 'foo', 'bar') - node = self.parse_and_analyze(test_fn, {'setattr': setattr}) - node = call_trees.transform(node, self.ctx, (), ()) - - with self.compiled(node) as result: + with self.converted(test_fn, call_trees, {'setattr': setattr}) as result: with self.test_session() as sess: - # The function has no return value, so we do some tricks to grab the - # generated py_func node and ensure its effect only happens at graph - # execution. class Dummy(object): pass a = Dummy() result.test_fn(a) + py_func_op, = sess.graph.get_operations() self.assertFalse(hasattr(a, 'foo')) - sess.run(sess.graph.get_operations()[0]) + sess.run(py_func_op) self.assertEquals('bar', a.foo) - def test_py_func_wrap_known_function(self): + def test_py_func_known_function(self): def test_fn(): return np.random.binomial(2, 0.5) - node = self.parse_and_analyze(test_fn, {'np': np}) - node = call_trees.transform(node, self.ctx, (), ()) - - with self.compiled(node, dtypes.int64) as result: - result.np = np + with self.converted(test_fn, call_trees, {'np': np}, + dtypes.int64) as result: with self.test_session() as sess: self.assertTrue(isinstance(result.test_fn(), ops.Tensor)) self.assertIn(sess.run(result.test_fn()), (0, 1, 2)) @@ -130,22 +121,17 @@ class CallTreesTest(converter_test_base.TestCase): a = math_ops.add(a, constant_op.constant(1)) return a - node = self.parse_and_analyze(test_fn, { - 'math_ops': math_ops, - 'constant_op': constant_op - }) - node = call_trees.transform(node, self.ctx, - set(((math_ops.__name__,), - (constant_op.__name__,))), ()) - - with self.compiled(node) as result: - result.math_ops = math_ops - result.constant_op = constant_op + ns = {'math_ops': math_ops, 'constant_op': constant_op} + node, ctx = self.prepare( + test_fn, + ns, + arg_types=set(((math_ops.__name__,), (constant_op.__name__,)))) + node = call_trees.transform(node, ctx) + + with self.compiled(node, ns) as result: with self.test_session() as sess: - # Not renamed, because the converter doesn't rename the definition - # itself (the caller is responsible for that). result_tensor = result.test_fn(constant_op.constant(1)) - self.assertEquals(3, sess.run(result_tensor)) + self.assertEquals(sess.run(result_tensor), 3) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/conditional_expressions.py b/tensorflow/contrib/autograph/converters/conditional_expressions.py new file mode 100644 index 0000000000000000000000000000000000000000..63f649dfdf5f740ba66260a51175a0ec2b716ea3 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/conditional_expressions.py @@ -0,0 +1,129 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converts the ternary conditional operator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno + + +class _FunctionDefs(object): + + def __init__(self): + self.nodes = [] + + +class _Statement(object): + + def __init__(self): + self.scope = None + + +class ConditionalExpressionTransformer(converter.Base): + """Converts contitional expressions to functional form.""" + + def _postprocess_statement(self, node): + """Inserts any separate functions that node may use.""" + replacements = [] + for def_node in self.state[_FunctionDefs].nodes: + replacements.extend(def_node) + replacements.append(node) + node = replacements + # The corresponding enter is called by self.visit_block (see _process_block) + self.state[_FunctionDefs].exit() + return node, None + + def _create_branch(self, expr, name_stem): + scope = self.state[_Statement].scope + name = self.ctx.namer.new_symbol(name_stem, scope.referenced) + template = """ + def name(): + return expr, + """ + node = templates.replace(template, name=name, expr=expr) + self.state[_FunctionDefs].nodes.append(node) + return name + + def visit_IfExp(self, node): + if anno.hasanno(node.test, anno.Basic.QN): + name_root = anno.getanno(node.test, anno.Basic.QN).ssf() + else: + name_root = 'ifexp' + + true_fn_name = self._create_branch(node.body, '%s_true' % name_root) + false_fn_name = self._create_branch(node.orelse, '%s_false' % name_root) + + return templates.replace_as_expression( + 'ag__.utils.run_cond(test, true_fn_name, false_fn_name)', + test=node.test, + true_fn_name=true_fn_name, + false_fn_name=false_fn_name) + + def _process_block(self, scope, block): + self.state[_Statement].enter() + self.state[_Statement].scope = scope + block = self.visit_block( + block, + before_visit=self.state[_FunctionDefs].enter, + after_visit=self._postprocess_statement) + self.state[_Statement].exit() + return block + + def visit_FunctionDef(self, node): + node.args = self.generic_visit(node.args) + node.decorator_list = self.visit_block(node.decorator_list) + node.body = self._process_block( + anno.getanno(node, anno.Static.SCOPE), node.body) + return node + + def visit_For(self, node): + node.target = self.visit(node.target) + node.body = self._process_block( + anno.getanno(node, NodeAnno.BODY_SCOPE), node.body) + node.orelse = self._process_block( + anno.getanno(node, NodeAnno.ORELSE_SCOPE), node.orelse) + return node + + def visit_While(self, node): + node.test = self.visit(node.test) + node.body = self._process_block( + anno.getanno(node, NodeAnno.BODY_SCOPE), node.body) + node.orelse = self._process_block( + anno.getanno(node, NodeAnno.ORELSE_SCOPE), node.orelse) + return node + + def visit_If(self, node): + node.test = self.visit(node.test) + node.body = self._process_block( + anno.getanno(node, NodeAnno.BODY_SCOPE), node.body) + node.orelse = self._process_block( + anno.getanno(node, NodeAnno.ORELSE_SCOPE), node.orelse) + return node + + def visit_With(self, node): + node.items = self.visit_block(node.items) + node.body = self._process_block( + anno.getanno(node, NodeAnno.BODY_SCOPE), node.body) + return node + + +def transform(node, ctx): + node = ConditionalExpressionTransformer(ctx).visit(node) + return node diff --git a/tensorflow/contrib/autograph/converters/conditional_expressions_test.py b/tensorflow/contrib/autograph/converters/conditional_expressions_test.py new file mode 100644 index 0000000000000000000000000000000000000000..95a3108741800c5fe504690f92876fa63edd8651 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/conditional_expressions_test.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for conditional_expressions module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.converters import conditional_expressions +from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.python.platform import test + + +class ConditionalExpressionsTest(converter_testing.TestCase): + + def assertTransformedEquivalent(self, test_fn, *inputs): + ns = {} + with self.converted(test_fn, conditional_expressions, ns) as result: + self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) + + def test_basic(self): + + def test_fn(x): + return 1 if x else 0 + + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 3) + + def test_nested_orelse(self): + + def test_fn(x): + y = x * x if x > 0 else x if x else 1 + return y + + self.assertTransformedEquivalent(test_fn, -2) + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 2) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py index 0417817a77e706fc0ce805f7391bea600f5fbb2d..0476e97c15e33dcfc09b3555cf8dc7ff3fd7ce19 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements.py +++ b/tensorflow/contrib/autograph/converters/continue_statements.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -31,13 +31,13 @@ GUARD_CREATED = 'guard_created' CREATE_GUARD_NEXT = 'create_guard_next' -class ContinueCanonicalizationTransformer(transformer.Base): +class ContinueCanonicalizationTransformer(converter.Base): """Canonicalizes continue statements into additional conditionals.""" def visit_Continue(self, node): self.set_local(CONTINUE_USED, True) template = """ - var_name = True + var_name = tf.constant(True) """ return templates.replace( template, var_name=self.get_local(CONTROL_VAR_NAME)) @@ -85,14 +85,14 @@ class ContinueCanonicalizationTransformer(transformer.Base): def _visit_loop_body(self, node, nodes): self.enter_local_scope() scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - continue_var = self.context.namer.new_symbol('continue_', scope.referenced) + continue_var = self.ctx.namer.new_symbol('continue_', scope.referenced) self.set_local(CONTROL_VAR_NAME, continue_var) nodes = self.visit_block(nodes, after_visit=self._postprocess_statement) if self.get_local(CONTINUE_USED, False): template = """ - var_name = False + var_name = tf.constant(False) """ control_var_init = templates.replace(template, var_name=continue_var) nodes = control_var_init + nodes @@ -135,5 +135,5 @@ class ContinueCanonicalizationTransformer(transformer.Base): return node -def transform(node, namer): - return ContinueCanonicalizationTransformer(namer).visit(node) +def transform(node, ctx): + return ContinueCanonicalizationTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/continue_statements_test.py b/tensorflow/contrib/autograph/converters/continue_statements_test.py index bcbb316d7459aa5a25bb0bd128cd6e359a393288..37c15211b4fe266e57879249fe7e060ded44dc1f 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements_test.py +++ b/tensorflow/contrib/autograph/converters/continue_statements_test.py @@ -19,13 +19,20 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph.converters import continue_statements -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.python.eager import context as tfe_ctx +from tensorflow.python.framework import constant_op from tensorflow.python.platform import test -class ContinueCanonicalizationTest(converter_test_base.TestCase): +class ContinueCanonicalizationTest(converter_testing.TestCase): - def test_basic_continue(self): + def assertTransformedEquivalent(self, test_fn, *inputs): + with self.converted(test_fn, continue_statements, {}, + constant_op.constant) as result: + self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) + + def test_basic(self): def test_fn(x): v = [] @@ -36,17 +43,13 @@ class ContinueCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v - node = self.parse_and_analyze(test_fn, {}) - node = continue_statements.transform(node, self.ctx) - - with self.compiled(node) as result: - self.assertEqual(test_fn(0), result.test_fn(0)) - self.assertEqual(test_fn(1), result.test_fn(1)) - self.assertEqual(test_fn(2), result.test_fn(2)) - self.assertEqual(test_fn(3), result.test_fn(3)) - self.assertEqual(test_fn(4), result.test_fn(4)) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 1) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 4) - def test_basic_continue_for_loop(self): + def test_for_loop(self): def test_fn(a): v = [] @@ -57,16 +60,13 @@ class ContinueCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v - node = self.parse_and_analyze(test_fn, {}) - node = continue_statements.transform(node, self.ctx) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, []) + self.assertTransformedEquivalent(test_fn, [1]) + self.assertTransformedEquivalent(test_fn, [2]) + self.assertTransformedEquivalent(test_fn, [1, 2, 3]) - with self.compiled(node) as result: - self.assertEqual(test_fn([]), result.test_fn([])) - self.assertEqual(test_fn([1]), result.test_fn([1])) - self.assertEqual(test_fn([2]), result.test_fn([2])) - self.assertEqual(test_fn([1, 2, 3]), result.test_fn([1, 2, 3])) - - def test_continue_deeply_nested(self): + def test_nested(self): def test_fn(x): v = [] @@ -83,15 +83,11 @@ class ContinueCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v, u, w - node = self.parse_and_analyze(test_fn, {}) - node = continue_statements.transform(node, self.ctx) - - with self.compiled(node) as result: - self.assertEqual(test_fn(0), result.test_fn(0)) - self.assertEqual(test_fn(1), result.test_fn(1)) - self.assertEqual(test_fn(2), result.test_fn(2)) - self.assertEqual(test_fn(3), result.test_fn(3)) - self.assertEqual(test_fn(4), result.test_fn(4)) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 1) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 4) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py index d7ddbe8a04f64848d6ec21155d8d85f60e19d276..5a5a6ad63a777f463e80e061d4870f2ee7491c39 100644 --- a/tensorflow/contrib/autograph/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -20,13 +20,12 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -from tensorflow.contrib.autograph.pyct.static_analysis import cfg -from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct.static_analysis import annos class SymbolNamer(object): @@ -45,7 +44,7 @@ class SymbolNamer(object): raise NotImplementedError() -class ControlFlowTransformer(transformer.Base): +class ControlFlowTransformer(converter.Base): """Transforms control flow structures like loops an conditionals.""" def _create_cond_branch(self, body_name, aliased_orig_names, @@ -91,60 +90,56 @@ class ControlFlowTransformer(transformer.Base): return templates.replace( template, test=test, body_name=body_name, orelse_name=orelse_name) - def visit_If(self, node): - self.generic_visit(node) + def _fmt_symbol_list(self, symbol_set): + if not symbol_set: + return 'no variables' + return ', '.join(map(str, symbol_set)) - body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE) - body_defs = body_scope.created | body_scope.modified - orelse_defs = orelse_scope.created | orelse_scope.modified - live = anno.getanno(node, 'live_out') - - # We'll need to check if we're closing over variables that are defined - # elsewhere in the function - # NOTE: we can only detect syntactic closure in the scope - # of the code passed in. If the AutoGraph'd function itself closes - # over other variables, this analysis won't take that into account. - defined = anno.getanno(node, 'defined_in') - - # We only need to return variables that are - # - modified by one or both branches - # - live (or has a live parent) at the end of the conditional - modified = [] - for def_ in body_defs | orelse_defs: - def_with_parents = set((def_,)) | def_.support_set - if live & def_with_parents: - modified.append(def_) - - # We need to check if live created variables are balanced - # in both branches - created = live & (body_scope.created | orelse_scope.created) - - # The if statement is illegal if there are variables that are created, - # that are also live, but both branches don't create them. - if created: - if created != (body_scope.created & live): - raise ValueError( - 'The main branch does not create all live symbols that the else ' - 'branch does.') - if created != (orelse_scope.created & live): - raise ValueError( - 'The else branch does not create all live symbols that the main ' - 'branch does.') - - # Alias the closure variables inside the conditional functions - # to avoid errors caused by the local variables created in the branch - # functions. + def visit_If(self, node): + node = self.generic_visit(node) + + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE) + defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN) + live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) + + modified_in_cond = body_scope.modified | orelse_scope.modified + returned_from_cond = set() + for s in modified_in_cond: + if s in live_out: + returned_from_cond.add(s) + elif s.is_composite(): + # Special treatment for compound objects: if any of their owner entities + # are live, then they are outputs as well. + if any(owner in live_out for owner in s.owner_set): + returned_from_cond.add(s) + + need_alias_in_body = body_scope.modified & defined_in + need_alias_in_orelse = orelse_scope.modified & defined_in + + created_in_body = body_scope.modified & returned_from_cond - defined_in + created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in + + if created_in_body != created_in_orelse: + raise ValueError( + 'if statement may not initialize all variables: the true branch' + ' creates %s, while the false branch creates %s. Make sure all' + ' these variables are initialized either in both' + ' branches or before the if statement.' % + (self._fmt_symbol_list(created_in_body), + self._fmt_symbol_list(created_in_orelse))) + + # Alias the closure variables inside the conditional functions, to allow + # the functions access to the respective variables. # We will alias variables independently for body and orelse scope, # because different branches might write different variables. - aliased_body_orig_names = tuple(body_scope.modified - body_scope.created) - aliased_orelse_orig_names = tuple(orelse_scope.modified - - orelse_scope.created) + aliased_body_orig_names = tuple(need_alias_in_body) + aliased_orelse_orig_names = tuple(need_alias_in_orelse) aliased_body_new_names = tuple( - self.context.namer.new_symbol(s.ssf(), body_scope.referenced) + self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced) for s in aliased_body_orig_names) aliased_orelse_new_names = tuple( - self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced) + self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced) for s in aliased_orelse_orig_names) alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names)) @@ -154,59 +149,47 @@ class ControlFlowTransformer(transformer.Base): node_body = ast_util.rename_symbols(node.body, alias_body_map) node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map) - if not modified: + returned_from_cond = tuple(returned_from_cond) + if returned_from_cond: + if len(returned_from_cond) == 1: + # TODO(mdan): Move this quirk into the operator implementation. + cond_results = returned_from_cond[0] + else: + cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None) + + returned_from_body = tuple( + alias_body_map[s] if s in need_alias_in_body else s + for s in returned_from_cond) + returned_from_orelse = tuple( + alias_orelse_map[s] if s in need_alias_in_orelse else s + for s in returned_from_cond) + + else: # When the cond would return no value, we leave the cond called without # results. That in turn should trigger the side effect guards. The # branch functions will return a dummy value that ensures cond # actually has some return value as well. - results = None - elif len(modified) == 1: - results = modified[0] - else: - results = gast.Tuple([s.ast() for s in modified], None) - - body_name = self.context.namer.new_symbol('if_true', body_scope.referenced) - orelse_name = self.context.namer.new_symbol('if_false', - orelse_scope.referenced) - if modified: - - def build_returns(aliased_names, alias_map, scope): - """Builds list of return variables for a branch of a conditional.""" - returns = [] - for s in modified: - if s in aliased_names: - returns.append(alias_map[s]) - else: - if s not in scope.created | defined: - raise ValueError( - 'Attempting to return variable "%s" from the true branch of ' - 'a conditional, but it was not closed over, or created in ' - 'this branch.' % str(s)) - else: - returns.append(s) - return tuple(returns) - - body_returns = build_returns(aliased_body_orig_names, alias_body_map, - body_scope) - orelse_returns = build_returns(aliased_orelse_orig_names, - alias_orelse_map, orelse_scope) + cond_results = None + # TODO(mdan): This doesn't belong here; it's specific to the operator. + returned_from_body = templates.replace_as_expression('tf.constant(1)') + returned_from_orelse = templates.replace_as_expression('tf.constant(1)') - else: - body_returns = orelse_returns = templates.replace('tf.ones(())')[0].value + body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced) + orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced) body_def = self._create_cond_branch( body_name, - aliased_orig_names=tuple(aliased_body_orig_names), - aliased_new_names=tuple(aliased_body_new_names), + aliased_orig_names=aliased_body_orig_names, + aliased_new_names=aliased_body_new_names, body=node_body, - returns=body_returns) + returns=returned_from_body) orelse_def = self._create_cond_branch( orelse_name, - aliased_orig_names=tuple(aliased_orelse_orig_names), - aliased_new_names=tuple(aliased_orelse_new_names), + aliased_orig_names=aliased_orelse_orig_names, + aliased_new_names=aliased_orelse_new_names, body=node_orelse, - returns=orelse_returns) - cond_expr = self._create_cond_expr(results, node.test, body_name, + returns=returned_from_orelse) + cond_expr = self._create_cond_expr(cond_results, node.test, body_name, orelse_name) return body_def + orelse_def + cond_expr @@ -214,11 +197,11 @@ class ControlFlowTransformer(transformer.Base): def visit_While(self, node): self.generic_visit(node) - body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) body_closure = body_scope.modified - body_scope.created all_referenced = body_scope.referenced - cond_scope = anno.getanno(node, NodeAnno.COND_SCOPE) + cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE) cond_closure = set() for s in cond_scope.referenced: for root in s.support_set: @@ -235,7 +218,7 @@ class ControlFlowTransformer(transformer.Base): raise ValueError('cannot convert while loop: no outputs') state_ssf = [ - self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state + self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state ] ssf_map = { name: ssf @@ -267,11 +250,9 @@ class ControlFlowTransformer(transformer.Base): state=state, state_ssf=state_ssf, state_ast_tuple=state_ast_tuple, - test_name=self.context.namer.new_symbol('loop_test', - body_scope.referenced), + test_name=self.ctx.namer.new_symbol('loop_test', body_scope.referenced), test=test, - body_name=self.context.namer.new_symbol('loop_body', - body_scope.referenced), + body_name=self.ctx.namer.new_symbol('loop_body', body_scope.referenced), body=node_body, extra_deps=tuple(s.ast() for s in cond_closure), ) @@ -281,14 +262,14 @@ class ControlFlowTransformer(transformer.Base): def visit_For(self, node): self.generic_visit(node) - body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) body_closure = body_scope.modified - body_scope.created all_referenced = body_scope.referenced state = list(body_closure) state_ssf = [ - self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state + self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state ] ssf_map = { name: ssf @@ -326,17 +307,14 @@ class ControlFlowTransformer(transformer.Base): state_ast_tuple=state_ast_tuple, iter_=node.iter, iterate=node.target, - extra_test_name=self.context.namer.new_symbol('extra_test', - all_referenced), + extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced), extra_test_expr=extra_test, - body_name=self.context.namer.new_symbol('loop_body', all_referenced), + body_name=self.ctx.namer.new_symbol('loop_body', all_referenced), body=node_body) return node -def transform(node, context): - cfg.run_analyses(node, cfg.Liveness(context)) - cfg.run_analyses(node, cfg.Defined(context)) - node = ControlFlowTransformer(context).visit(node) +def transform(node, ctx): + node = ControlFlowTransformer(ctx).visit(node) return node diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py index 9d23d9b5b7e8e8480e04fccc1c8c81799abf382b..ade35014263c3ae4ec14b40ee0f2507b70627d41 100644 --- a/tensorflow/contrib/autograph/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -19,17 +19,24 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph.converters import control_flow -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test -class ControlFlowTest(converter_test_base.TestCase): +class ControlFlowTest(converter_testing.TestCase): - def test_simple_while(self): + def assertTransformedResult(self, test_fn, inputs, expected): + if not isinstance(inputs, tuple): + inputs = (inputs,) + with self.converted(test_fn, control_flow, {}, + constant_op.constant) as result: + with self.test_session() as sess: + self.assertEqual(sess.run(result.test_fn(*inputs)), expected) + + def test_while_basic(self): def test_fn(n): i = 0 @@ -39,29 +46,18 @@ class ControlFlowTest(converter_test_base.TestCase): i += 1 return s, i, n - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, self.ctx) - - with self.compiled(node) as result: - with self.test_session() as sess: - self.assertEqual((10, 5, 5), - sess.run(result.test_fn(constant_op.constant(5)))) + self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 5, 5)) - def test_while_single_var(self): + def test_while_single_output(self): def test_fn(n): while n > 0: n -= 1 return n - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, self.ctx) + self.assertTransformedResult(test_fn, constant_op.constant(5), 0) - with self.compiled(node) as result: - with self.test_session() as sess: - self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5)))) - - def test_simple_if(self): + def test_if_basic(self): def test_fn(n): a = 0 @@ -72,114 +68,85 @@ class ControlFlowTest(converter_test_base.TestCase): b = 2 * n return a, b - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, self.ctx) + self.assertTransformedResult(test_fn, constant_op.constant(1), (-1, 0)) + self.assertTransformedResult(test_fn, constant_op.constant(-1), (0, -2)) + + def test_if_complex_outputs(self): + + class TestClass(object): - with self.compiled(node) as result: + def __init__(self, a, b): + self.a = a + self.b = b + + def test_fn(n, obj): + obj.a = 0 + obj.b = 0 + if n > 0: + obj.a = -n + else: + obj.b = 2 * n + return obj + + with self.converted(test_fn, control_flow, {}) as result: with self.test_session() as sess: - self.assertEqual((-1, 0), - sess.run(result.test_fn(constant_op.constant(1)))) - self.assertEqual((0, -2), - sess.run(result.test_fn(constant_op.constant(-1)))) + res_obj = result.test_fn(constant_op.constant(1), TestClass(0, 0)) + self.assertEqual(sess.run((res_obj.a, res_obj.b)), (-1, 0)) + res_obj = result.test_fn(constant_op.constant(-1), TestClass(0, 0)) + self.assertEqual(sess.run((res_obj.a, res_obj.b)), (0, -2)) - def test_if_single_var(self): + def test_if_single_output(self): def test_fn(n): if n > 0: n = -n return n - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, self.ctx) + self.assertTransformedResult(test_fn, constant_op.constant(1), -1) - with self.compiled(node) as result: - with self.test_session() as sess: - self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1)))) - - def test_imbalanced_aliasing(self): + def test_if_semi(self): def test_fn(n): if n > 0: n = 3 return n - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, self.ctx) - - with self.compiled(node, control_flow_ops.cond) as result: - with self.test_session() as sess: - self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(2)))) - self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3)))) + self.assertTransformedResult(test_fn, constant_op.constant(2), 3) + self.assertTransformedResult(test_fn, constant_op.constant(-3), -3) - def test_ignore_unread_variable(self): + def test_if_local_var(self): def test_fn(n): - b = 3 # pylint: disable=unused-variable if n > 0: b = 4 + n = b + 1 return n - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, self.ctx) + self.assertTransformedResult(test_fn, constant_op.constant(1), 5) + self.assertTransformedResult(test_fn, constant_op.constant(-1), -1) - with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: - with self.test_session() as sess: - self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(3)))) - self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3)))) + def test_if_no_outputs(self): - def test_handle_temp_variable(self): + def test_fn(n): + if n > 0: + b = 4 # pylint:disable=unused-variable + return n - def test_fn_using_temp(x, y, w): - if x < y: - z = x + y - else: - w = 2 - tmp = w - z = x - tmp - return z, w + # Without side effect guards, the if statement will stage a cond, + # but that will be pruned at execution. + self.assertTransformedResult(test_fn, constant_op.constant(1), 1) + self.assertTransformedResult(test_fn, constant_op.constant(-1), -1) - node = self.parse_and_analyze(test_fn_using_temp, {}) - node = control_flow.transform(node, self.ctx) + def test_if_imbalanced_outputs(self): - with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: - with self.test_session() as sess: - z, w = sess.run( - result.test_fn_using_temp( - constant_op.constant(-3), constant_op.constant(3), - constant_op.constant(3))) - self.assertEqual(0, z) - self.assertEqual(3, w) - z, w = sess.run( - result.test_fn_using_temp( - constant_op.constant(3), constant_op.constant(-3), - constant_op.constant(3))) - self.assertEqual(1, z) - self.assertEqual(2, w) - - def test_fn_ignoring_temp(x, y, w): - if x < y: - z = x + y - else: - w = 2 - tmp = w - z = x - tmp - return z + def test_fn(n): + if n > 0: + b = 4 + return b - node = self.parse_and_analyze(test_fn_ignoring_temp, {}) - node = control_flow.transform(node, self.ctx) - - with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: - with self.test_session() as sess: - z = sess.run( - result.test_fn_ignoring_temp( - constant_op.constant(-3), constant_op.constant(3), - constant_op.constant(3))) - self.assertEqual(0, z) - z = sess.run( - result.test_fn_ignoring_temp( - constant_op.constant(3), constant_op.constant(-3), - constant_op.constant(3))) - self.assertEqual(1, z) + node, ctx = self.prepare(test_fn, {}) + with self.assertRaises(transformer.AutographParseError): + control_flow.transform(node, ctx) def test_simple_for(self): @@ -191,22 +158,11 @@ class ControlFlowTest(converter_test_base.TestCase): s2 += e * e return s1, s2 - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, self.ctx) + self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), (4, 10)) + empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32) + self.assertTransformedResult(test_fn, empty_vector, (0, 0)) - with self.compiled(node) as result: - with self.test_session() as sess: - l = [1, 2, 3] - self.assertEqual( - test_fn(l), sess.run(result.test_fn(constant_op.constant(l)))) - l = [] - self.assertEqual( - test_fn(l), - sess.run( - result.test_fn( - constant_op.constant(l, shape=(0,), dtype=dtypes.int32)))) - - def test_for_single_var(self): + def test_for_single_output(self): def test_fn(l): s = 0 @@ -214,22 +170,11 @@ class ControlFlowTest(converter_test_base.TestCase): s += e return s - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, self.ctx) + self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), 4) + empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32) + self.assertTransformedResult(test_fn, empty_vector, 0) - with self.compiled(node) as result: - with self.test_session() as sess: - l = [1, 2, 3] - self.assertEqual( - test_fn(l), sess.run(result.test_fn(constant_op.constant(l)))) - l = [] - self.assertEqual( - test_fn(l), - sess.run( - result.test_fn( - constant_op.constant(l, shape=(0,), dtype=dtypes.int32)))) - - def test_for_with_iterated_expression(self): + def test_for_iterated_expression(self): eval_count = [0] @@ -243,14 +188,13 @@ class ControlFlowTest(converter_test_base.TestCase): s += e return s - node = self.parse_and_analyze(test_fn, {'count_evals': count_evals}) - node = control_flow.transform(node, self.ctx) + ns = {'count_evals': count_evals} + node, ctx = self.prepare(test_fn, ns) + node = control_flow.transform(node, ctx) - with self.compiled(node) as result: - result.count_evals = count_evals - self.assertEqual(test_fn(5), result.test_fn(5)) - # count_evals ran twice, once for test_fn and another for result.test_fn - self.assertEqual(eval_count[0], 2) + with self.compiled(node, ns) as result: + self.assertEqual(result.test_fn(5), 10) + self.assertEqual(eval_count[0], 1) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/decorators.py b/tensorflow/contrib/autograph/converters/decorators.py index 92445f31746cf94856ea43893f99a2ba60355fb5..3471bd11d6073f57a2703b438df95a60f19e8e0c 100644 --- a/tensorflow/contrib/autograph/converters/decorators.py +++ b/tensorflow/contrib/autograph/converters/decorators.py @@ -24,19 +24,14 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import pretty_printer +from tensorflow.python.util import tf_inspect -class DecoratorsTransformer(gast.NodeTransformer): +class DecoratorsTransformer(converter.Base): """Converts or removes decorators.""" - def __init__(self, remove_decorators): - self.remove_decorators = remove_decorators - self.additional_dependencies = set() - - # pylint:disable=invalid-name - def visit_FunctionDef(self, node): self.generic_visit(node) kept_decorators = [] @@ -58,31 +53,53 @@ class DecoratorsTransformer(gast.NodeTransformer): # This is currently verified by tests. continue - if not anno.hasanno(dec_func, 'live_val'): - raise ValueError( - 'Could not resolve decorator: %s' % pretty_printer.fmt(dec_func)) - + original_dec = anno.getanno(dec_func, anno.Basic.QN) dec_value = anno.getanno(dec_func, 'live_val') - if dec_value not in self.remove_decorators: - kept_decorators.append((dec, dec_value)) - for _, dec_value in kept_decorators: - if dec_value.__module__ == '__main__': + if dec_value in self.ctx.program.autograph_decorators: + # AutoGraph decorators do not need to be preserved. + continue + + # When using foo.bar.baz, we only really need to grab foo and import + # that. + dec_support_node = dec_func + while isinstance(dec_support_node, gast.Attribute): + dec_support_node = dec_support_node.value + + if not anno.hasanno(dec_support_node, 'live_val'): raise ValueError( - 'decorator "%s" was not allowed because it is declared ' - 'in the module "%s". To fix this, declare it in a separate ' - 'module that we can import it from.' % (dec_value, - dec_value.__module__)) + 'could not resolve symbol "%s" when looking up decorator "%s"' % + (anno.getanno(dec_support_node, anno.Basic.QN), original_dec)) + + dec_support = anno.getanno(dec_support_node, 'live_val') + # The tuple contains: + # * the AST that represents the decorator + # * the entity supporting the decorator (i.e., what we need to import) + # * the name of the module that needs to be imported for this decorator + # to properly resolve. + # Examples: + # for foo.bar, the tuple is (, , 'foo') + # for baz, the tuple is (, , 'baz') + kept_decorators.append((dec, dec_support, + anno.getanno(dec_support_node, anno.Basic.QN))) + + for _, dec_support, name in kept_decorators: + if tf_inspect.ismodule(dec_support): + self.ctx.program.additional_imports.add( + 'import %s as %s' % (dec_support.__name__, name)) else: - self.additional_dependencies.add(dec_value) - - node.decorator_list = [dec for dec, _ in kept_decorators] + if dec_support.__module__ == '__main__': + raise ValueError( + 'decorator "%s" was not allowed because it is declared ' + 'in the module "%s". To fix this, declare it in a separate ' + 'module that we can import it from.' % (dec_support, + dec_support.__module__)) + self.ctx.program.additional_imports.add( + 'from %s import %s' % (dec_support.__module__, name)) + + node.decorator_list = [dec for dec, _, _ in kept_decorators] return node - # pylint:enable=invalid-name - -def transform(node, remove_decorators): - transformer = DecoratorsTransformer(remove_decorators) - node = transformer.visit(node) - return node, transformer.additional_dependencies +def transform(node, ctx): + return DecoratorsTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/decorators_test.py b/tensorflow/contrib/autograph/converters/decorators_test.py index 9c01f689127dbedad7669c65b03e7da071b2d64d..095abc5edc02de55cd0b28d9aa9f9c4e7cec13c3 100644 --- a/tensorflow/contrib/autograph/converters/decorators_test.py +++ b/tensorflow/contrib/autograph/converters/decorators_test.py @@ -20,9 +20,10 @@ from __future__ import print_function from functools import wraps -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import decorators +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.platform import test @@ -39,28 +40,35 @@ def simple_decorator(f): return lambda a: f(a) + 1 -def self_removing_decorator(removing_wrapper): +def self_transform_decorator(transform): + def decorator(f): @wraps(f) def wrapper(*args): # This removing wrapper is defined in the test below. This setup is so - # intricate just to simulate how we use the transformer in practice. - transformed_f = removing_wrapper(f, (self_removing_decorator,)) + # intricate in order to simulate how we use the transformer in practice. + transformed_f = transform(f, (self_transform_decorator,)) return transformed_f(*args) + 1 return wrapper return decorator -class DecoratorsTest(converter_test_base.TestCase): +class DecoratorsTest(converter_testing.TestCase): - def _remover_wrapper(self, f, remove_decorators): + def _transform(self, f, autograph_decorators): namespace = { - 'self_removing_decorator': self_removing_decorator, - 'simple_decorator': simple_decorator + 'self_transform_decorator': self_transform_decorator, + 'simple_decorator': simple_decorator, + 'converter_testing': converter_testing, } - node = self.parse_and_analyze(f, namespace) - node, _ = decorators.transform(node, remove_decorators=remove_decorators) - result, _ = compiler.ast_to_object(node) + node, ctx = self.prepare( + f, + namespace, + recursive=False, + autograph_decorators=autograph_decorators) + node = decorators.transform(node, ctx) + import_line = '\n'.join(ctx.program.additional_imports) + result, _ = compiler.ast_to_object(node, source_prefix=import_line) return getattr(result, f.__name__) def test_noop(self): @@ -68,16 +76,12 @@ class DecoratorsTest(converter_test_base.TestCase): def test_fn(a): return a - node = self.parse_and_analyze(test_fn, {}) - node, deps = decorators.transform(node, remove_decorators=()) - result, _ = compiler.ast_to_object(node) - - self.assertFalse(deps) - self.assertEqual(1, result.test_fn(1)) + with self.converted(test_fn, decorators, {}) as result: + self.assertEqual(1, result.test_fn(1)) def test_function(self): - @self_removing_decorator(self._remover_wrapper) + @self_transform_decorator(self._transform) def test_fn(a): return a @@ -88,7 +92,7 @@ class DecoratorsTest(converter_test_base.TestCase): class TestClass(object): - @self_removing_decorator(self._remover_wrapper) + @self_transform_decorator(self._transform) def test_fn(self, a): return a @@ -101,38 +105,39 @@ class DecoratorsTest(converter_test_base.TestCase): # Note that reversing the order of this two doesn't work. @classmethod - @self_removing_decorator(self._remover_wrapper) + @self_transform_decorator(self._transform) def test_fn(cls, a): return a # 2 = 1 (a) + 1 (decorator applied exactly once) self.assertEqual(2, TestClass.test_fn(1)) - def test_nested_decorators(self): + def test_nested_decorators_local(self): - @self_removing_decorator(self._remover_wrapper) + @self_transform_decorator(self._transform) def test_fn(a): @simple_decorator def inner_fn(b): return b + 11 return inner_fn(a) - with self.assertRaises(ValueError): + # Expected to fail because simple_decorator could not be imported. + with self.assertRaises(transformer.AutographParseError): test_fn(1) - # TODO(mdan): Uncomment this test once converter_test_base is updated. - # (can't do it now because it has unrelated pending changes) - # def test_nested_decorators(self): - # - # @self_removing_decorator(self._remover_wrapper) - # def test_fn(a): - # @imported_decorator - # def inner_fn(b): - # return b + 11 - # return inner_fn(a) - # - # # 14 = 1 (a) + 1 (simple_decorator) + 11 (inner_fn) - # self.assertEqual(14, test_fn(1)) + def test_nested_decorators_imported(self): + + @self_transform_decorator(self._transform) + def test_fn(a): + + @converter_testing.imported_decorator + def inner_fn(b): + return b + 11 + + return inner_fn(a) + + # 14 = 1 (a) + 1 (simple_decorator) + 11 (inner_fn) + self.assertEqual(14, test_fn(1)) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/directives.py b/tensorflow/contrib/autograph/converters/directives.py new file mode 100644 index 0000000000000000000000000000000000000000..77f625bac792621c45799d1a220f99eb4b99f7af --- /dev/null +++ b/tensorflow/contrib/autograph/converters/directives.py @@ -0,0 +1,128 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Handles directives. + +This converter removes the directive functions from the code and moves the +information they specify into AST annotations. It is a specialized form of +static analysis, one that is specific to AutoGraph. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.lang import directives +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.python.util import tf_inspect + +ENCLOSING_LOOP = 'enclosing_loop' + + +def _map_args(call_node, function): + """Maps AST call nodes to the actual function's arguments. + + Args: + call_node: ast.Call + function: Callable[..., Any], the actual function matching call_node + Returns: + Dict[Text, ast.AST], mapping each of the function's argument names to + the respective AST node. + Raises: + ValueError: if the default arguments are not correctly set + """ + args = call_node.args + kwds = {kwd.arg: kwd.value for kwd in call_node.keywords} + call_args = tf_inspect.getcallargs(function, *args, **kwds) + + # Keyword arguments not specified in kwds will be mapped to their defaults, + # which are Python values. Since we don't currently have a way to transform + # those into AST references, we simply remove them. By convention, directives + # use UNSPECIFIED as default value for for optional arguments. No other + # defaults should be present. + unexpected_defaults = [] + for k in call_args: + if (k not in kwds + and call_args[k] not in args + and call_args[k] is not directives.UNSPECIFIED): + unexpected_defaults.append(k) + if unexpected_defaults: + raise ValueError('Unexpected keyword argument values, %s, for function %s' + % (zip(unexpected_defaults, + [call_args[k] for k in unexpected_defaults]), + function)) + return {k: v for k, v in call_args.items() if v is not directives.UNSPECIFIED} + + +class DirectivesTransformer(converter.Base): + """Parses compiler directives and converts them into AST annotations.""" + + def _process_symbol_directive(self, call_node, directive): + if len(call_node.args) < 1: + raise ValueError('"%s" requires a positional first argument' + ' as the target' % directive.__name__) + target = call_node.args[0] + defs = anno.getanno(target, anno.Static.ORIG_DEFINITIONS) + for def_ in defs: + def_.directives[directive] = _map_args(call_node, directive) + return call_node + + def _process_statement_directive(self, call_node, directive): + if self.local_scope_level < 1: + raise ValueError( + '"%s" must be used inside a statement' % directive.__name__) + target = self.get_local(ENCLOSING_LOOP) + node_anno = anno.getanno(target, converter.AgAnno.DIRECTIVES, {}) + node_anno[directive] = _map_args(call_node, directive) + anno.setanno(target, converter.AgAnno.DIRECTIVES, node_anno) + return call_node + + def visit_Expr(self, node): + if isinstance(node.value, gast.Call): + call_node = node.value + if anno.hasanno(call_node.func, 'live_val'): + live_val = anno.getanno(call_node.func, 'live_val') + + if live_val is directives.set_element_type: + call_node = self._process_symbol_directive(call_node, live_val) + elif live_val is directives.set_loop_options: + call_node = self._process_statement_directive(call_node, live_val) + else: + return self.generic_visit(node) + + return None # Directive calls are not output in the generated code. + return self.generic_visit(node) + + # TODO(mdan): This will be insufficient for other control flow. + # That means that if we ever have a directive that affects things other than + # loops, we'll need support for parallel scopes, or have multiple converters. + def _track_and_visit_loop(self, node): + self.enter_local_scope() + self.set_local(ENCLOSING_LOOP, node) + node = self.generic_visit(node) + self.exit_local_scope() + return node + + def visit_While(self, node): + return self._track_and_visit_loop(node) + + def visit_For(self, node): + return self._track_and_visit_loop(node) + + +def transform(node, ctx): + return DirectivesTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/directives_test.py b/tensorflow/contrib/autograph/converters/directives_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a2d083b891314d2f8f3fa61b46edc347ca8e24eb --- /dev/null +++ b/tensorflow/contrib/autograph/converters/directives_test.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for directives module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.converters import directives as directives_converter +from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.contrib.autograph.core.converter import AgAnno +from tensorflow.contrib.autograph.lang import directives +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.python.platform import test + + +class DirectivesTest(converter_testing.TestCase): + + def test_local_target(self): + + def test_fn(): + l = [] + string_var = 0 + directives.set_element_type(l, 'a', string_var) + + node, ctx = self.prepare(test_fn, {'directives': directives}) + node = directives_converter.transform(node, ctx) + + def_, = anno.getanno(node.body[0].targets[0], + anno.Static.DEFINITIONS) + d = def_.directives[directives.set_element_type] + self.assertEqual(d['dtype'].s, 'a') + self.assertEqual(d['shape'].id, 'string_var') + + def test_argument_target(self): + + def test_fn(a): + directives.set_element_type(a, 1, shape=2) + + node, ctx = self.prepare(test_fn, {'directives': directives}) + node = directives_converter.transform(node, ctx) + + def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS) + d = def_.directives[directives.set_element_type] + self.assertEqual(d['dtype'].n, 1) + self.assertEqual(d['shape'].n, 2) + + def test_loop_target(self): + + def test_fn(): + a = True + while True: + directives.set_loop_options(parallel_iterations=10, back_prop=a) + + node, ctx = self.prepare(test_fn, {'directives': directives}) + node = directives_converter.transform(node, ctx) + + d = anno.getanno(node.body[1], AgAnno.DIRECTIVES) + d = d[directives.set_loop_options] + self.assertEqual(d['parallel_iterations'].n, 10) + self.assertEqual(d['back_prop'].id, 'a') + self.assertNotIn('swap_memory', d) + + def test_invalid_default(self): + + def invalid_directive(valid_arg, invalid_default=object()): + del valid_arg + del invalid_default + return + + def call_invalid_directive(): + invalid_directive(1) + + node, _ = parser.parse_entity(call_invalid_directive) + # Find the call to the invalid directive + node = node.body[0].body[0].value + with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'): + directives_converter._map_args(node, invalid_directive) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/converters/error_handlers.py b/tensorflow/contrib/autograph/converters/error_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..193682139438c1d0133b17165d7f7fb84e2eaaac --- /dev/null +++ b/tensorflow/contrib/autograph/converters/error_handlers.py @@ -0,0 +1,53 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wraps function bodies with a try/except to rewrite error tracebacks. + +Only adds try/except wrappers to functions that have the anno.Basic.ORIGIN +annotation because these are the functions originally written by the user. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates + + +class ErrorRewritingTransformer(converter.Base): + """Possibly wraps the body of a function in a try/except. + + Only wraps functions that were originally defined by the user, detected by + checking for the anno.Basic.ORIGIN annotation. + """ + + def visit_FunctionDef(self, node): + node = self.generic_visit(node) + + if (anno.hasanno(node, anno.Basic.ORIGIN) and + len(self.enclosing_entities) <= 1): + template = """ + try: + body + except: + ag__.rewrite_graph_construction_error(ag_source_map__) + """ + node.body = templates.replace(template, body=node.body) + return node + + +def transform(node, ctx): + return ErrorRewritingTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/error_handlers_test.py b/tensorflow/contrib/autograph/converters/error_handlers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5d61b220afa0fcf9a9e619bbd78f83a5076c473a --- /dev/null +++ b/tensorflow/contrib/autograph/converters/error_handlers_test.py @@ -0,0 +1,59 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for error_handlers module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.converters import error_handlers +from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.contrib.autograph.core import errors +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import origin_info +from tensorflow.python.platform import test + + +class ErrorHandlersTest(converter_testing.TestCase): + + def test_basic(self): + + def test_fn(): + raise ValueError() + + node, ctx = self.prepare(test_fn, {}) + anno.setanno( + node, anno.Basic.ORIGIN, + origin_info.OriginInfo(None, 'test_function_name', 'test_code', + 'test_comment')) + node = error_handlers.transform(node, ctx) + with self.compiled(node, {}) as result: + with self.assertRaises(errors.GraphConstructionError): + # Here we just assert that the handler works. Its correctness is + # verified by errors_test.py. + result.test_fn() + + def test_no_origin_annotation(self): + + def test_fn(): + raise ValueError() + + with self.converted(test_fn, error_handlers, {}) as result: + with self.assertRaises(ValueError): + result.test_fn() + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/converters/ifexp.py b/tensorflow/contrib/autograph/converters/ifexp.py deleted file mode 100644 index 616d222762e09feeba1809f119d915dfbe522283..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/converters/ifexp.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Canonicalizes the ternary conditional operator.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer - - -class IfExp(transformer.Base): - """Canonicalizes all IfExp nodes into plain conditionals.""" - - def visit_IfExp(self, node): - template = """ - ag__.utils.run_cond(test, lambda: (body,), lambda: (orelse,)) - """ - desugared_ifexp = templates.replace_as_expression( - template, test=node.test, body=node.body, orelse=node.orelse) - return desugared_ifexp - - -def transform(node, context): - """Desugar IfExp nodes into plain conditionals. - - Args: - node: an AST node to transform - context: a context object - - Returns: - new_node: an AST with no IfExp nodes, only conditionals. - """ - - node = IfExp(context).visit(node) - return node diff --git a/tensorflow/contrib/autograph/converters/ifexp_test.py b/tensorflow/contrib/autograph/converters/ifexp_test.py deleted file mode 100644 index ac6849dcb4bd7dacd84bb205f5c65395d8c2f51e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/converters/ifexp_test.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for ifexp module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.autograph import utils -from tensorflow.contrib.autograph.converters import converter_test_base -from tensorflow.contrib.autograph.converters import ifexp -from tensorflow.python.platform import test - - -class IfExpTest(converter_test_base.TestCase): - - def compiled_fn(self, test_fn, *args): - node = self.parse_and_analyze(test_fn, {}) - node = ifexp.transform(node, self.ctx) - module = self.compiled(node, *args) - return module - - def test_simple(self): - - def test_fn(x): - return 1 if x else 0 - - with self.compiled_fn(test_fn) as result: - result.autograph_util = utils - for x in [0, 1]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_fn(self): - - def f(x): - return 3 * x - - def test_fn(x): - y = f(x * x if x > 0 else x) - return y - - with self.compiled_fn(test_fn) as result: - result.autograph_util = utils - result.f = f - for x in [-2, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_exp(self): - - def test_fn(x): - return x * x if x > 0 else x - - with self.compiled_fn(test_fn) as result: - result.autograph_util = utils - for x in [-2, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_nested(self): - - def test_fn(x): - return x * x if x > 0 else x if x else 1 - - with self.compiled_fn(test_fn) as result: - result.autograph_util = utils - for x in [-2, 0, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_in_cond(self): - - def test_fn(x): - if x > 0: - return x * x if x < 5 else x * x * x - return -x - - with self.compiled_fn(test_fn) as result: - result.autograph_util = utils - for x in [-2, 2, 5]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_assign_in_cond(self): - - def test_fn(x): - if x > 0: - x = -x if x < 5 else x - return x - - with self.compiled_fn(test_fn) as result: - result.autograph_util = utils - for x in [-2, 2, 5]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/autograph/converters/list_comprehension.py b/tensorflow/contrib/autograph/converters/list_comprehension.py deleted file mode 100644 index d7f292015164e047d054c5d1fb0b391e960bb73d..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/converters/list_comprehension.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Canonicalizing list comprehensions into for and if statements. - -e.g. -result = [x * x for x in xs] - -becomes - -result = [] -for x in xs: - elt = x * x - result.append(elt) -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import gast - -from tensorflow.contrib.autograph.pyct import parser -from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer - - -class ListCompCanonicalizationTransformer(transformer.Base): - """NodeTransformer to canonicalize list comprehensions.""" - - def __init__(self, context): - super(ListCompCanonicalizationTransformer, self).__init__(context) - - def make_update_list_node(self, list_, elt): - return templates.replace('list_.append(elt)', list_=list_, elt=elt)[0] - - def instantiate_list_node(self): - return parser.parse_str('[]').body[0].value - - def visit_Assign(self, node): - if not isinstance(node.value, gast.ListComp): - return node - if len(node.targets) > 1: - raise ValueError('Only support single assignment.') - return self.canonicalize_listcomp(node.targets[0], node.value) - - def canonicalize_listcomp(self, result_node, list_comp_node): - - make_list = templates.replace( - 'list_ = create_list', - list_=result_node, - create_list=self.instantiate_list_node()) - loop_body = self.make_update_list_node(result_node, list_comp_node.elt) - - for gen in reversed(list_comp_node.generators): - for gen_if in reversed(gen.ifs): - loop_body = templates.replace( - 'if test: loop_body', test=gen_if, loop_body=loop_body) - loop_body = templates.replace( - 'for target in iter_: loop_body', - iter_=gen.iter, - target=gen.target, - loop_body=loop_body) - - return make_list + loop_body - - -def transform(node, context): - return ListCompCanonicalizationTransformer(context).visit(node) diff --git a/tensorflow/contrib/autograph/converters/list_comprehensions.py b/tensorflow/contrib/autograph/converters/list_comprehensions.py new file mode 100644 index 0000000000000000000000000000000000000000..ecf4628816201a0a6ef4ca14b0f351d818d905b3 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/list_comprehensions.py @@ -0,0 +1,82 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Lowers list comprehensions into for and if statements. + +Example: + + result = [x * x for x in xs] + +becomes + + result = [] + for x in xs: + elt = x * x + result.append(elt) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.pyct import templates + + +# TODO(mdan): This should covert directly to operator calls. + + +class ListCompTransformer(converter.Base): + """Lowers list comprehensions into standard control flow.""" + + def visit_Assign(self, node): + if not isinstance(node.value, gast.ListComp): + return self.generic_visit(node) + if len(node.targets) > 1: + raise NotImplementedError('multiple assignments') + + target, = node.targets + list_comp_node = node.value + + template = """ + target = [] + """ + initialization = templates.replace(template, target=target) + + template = """ + target.append(elt) + """ + body = templates.replace(template, target=target, elt=list_comp_node.elt) + + for gen in reversed(list_comp_node.generators): + for gen_if in reversed(gen.ifs): + template = """ + if test: + body + """ + body = templates.replace(template, test=gen_if, body=body) + template = """ + for target in iter_: + body + """ + body = templates.replace( + template, iter_=gen.iter, target=gen.target, body=body) + + return initialization + body + + +def transform(node, ctx): + return ListCompTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/list_comprehension_test.py b/tensorflow/contrib/autograph/converters/list_comprehensions_test.py similarity index 54% rename from tensorflow/contrib/autograph/converters/list_comprehension_test.py rename to tensorflow/contrib/autograph/converters/list_comprehensions_test.py index 4758671f5ec83c26cfa54be0ef68f5f564094f6c..59b5ce9ca052bd1f2201285bef90f398b35e536c 100644 --- a/tensorflow/contrib/autograph/converters/list_comprehension_test.py +++ b/tensorflow/contrib/autograph/converters/list_comprehensions_test.py @@ -12,18 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for list_comprehension module.""" +"""Tests for list_comprehensions module.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base -from tensorflow.contrib.autograph.converters import list_comprehension +from tensorflow.contrib.autograph.converters import list_comprehensions +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.platform import test -class ListCompTest(converter_test_base.TestCase): +class ListCompTest(converter_testing.TestCase): + + def assertTransformedEquivalent(self, test_fn, *inputs): + with self.converted(test_fn, list_comprehensions, {}) as result: + self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) def test_basic(self): @@ -31,14 +35,8 @@ class ListCompTest(converter_test_base.TestCase): s = [e * e for e in l] return s - node = self.parse_and_analyze(test_fn, {}) - node = list_comprehension.transform(node, self.ctx) - - with self.compiled(node) as result: - l = [1, 2, 3] - self.assertEqual(test_fn(l), result.test_fn(l)) - l = [] - self.assertEqual(test_fn(l), result.test_fn(l)) + self.assertTransformedEquivalent(test_fn, []) + self.assertTransformedEquivalent(test_fn, [1, 2, 3]) def test_multiple_generators(self): @@ -46,29 +44,17 @@ class ListCompTest(converter_test_base.TestCase): s = [e * e for sublist in l for e in sublist] return s - node = self.parse_and_analyze(test_fn, {}) - node = list_comprehension.transform(node, self.ctx) + self.assertTransformedEquivalent(test_fn, []) + self.assertTransformedEquivalent(test_fn, [[1], [2], [3]]) - with self.compiled(node) as result: - l = [[1], [2], [3]] - self.assertEqual(test_fn(l), result.test_fn(l)) - l = [] - self.assertEqual(test_fn(l), result.test_fn(l)) - - def test_conds(self): + def test_cond(self): def test_fn(l): s = [e * e for e in l if e > 1] return s - node = self.parse_and_analyze(test_fn, {}) - node = list_comprehension.transform(node, self.ctx) - - with self.compiled(node) as result: - l = [1, 2, 3] - self.assertEqual(test_fn(l), result.test_fn(l)) - l = [] - self.assertEqual(test_fn(l), result.test_fn(l)) + self.assertTransformedEquivalent(test_fn, []) + self.assertTransformedEquivalent(test_fn, [1, 2, 3]) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/contrib/autograph/converters/lists.py index c15dfff9e8ebd8b96fd4aff82459a6fd7d0ac8ab..a02fc827b8bd92b36549599b5433118fcd9a28cf 100644 --- a/tensorflow/contrib/autograph/converters/lists.py +++ b/tensorflow/contrib/autograph/converters/lists.py @@ -32,10 +32,11 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.lang import directives from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -43,7 +44,7 @@ from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno POP_USES = 'pop_uses' -class ListTransformer(transformer.Base): +class ListTransformer(converter.Base): """Converts lists and related operations to their TF counterpart.""" def visit_List(self, node): @@ -88,13 +89,13 @@ class ListTransformer(transformer.Base): scope = anno.getanno(node, NodeAnno.ARGS_SCOPE) target_node = node.func.value - # Attempt to use a related name if can get one. Otherwise use something + # Attempt to use a related name if one exists. Otherwise use something # generic. if anno.hasanno(target_node, anno.Basic.QN): target_name = anno.getanno(target_node, anno.Basic.QN).ssf() else: - target_name = 'list' - pop_var_name = self.context.namer.new_symbol(target_name, scope.referenced) + target_name = 'list_' + pop_var_name = self.ctx.namer.new_symbol(target_name, scope.referenced) pop_uses = self.get_local(POP_USES, []) pop_uses.append((node, pop_var_name)) @@ -104,9 +105,10 @@ class ListTransformer(transformer.Base): def _replace_stack_call(self, node): assert len(node.args) == 1 - dtype = anno.getanno( + dtype = self.get_definition_directive( node.args[0], - 'element_type', + directives.set_element_type, + 'dtype', default=templates.replace_as_expression('None')) template = """ ag__.list_stack( @@ -134,7 +136,10 @@ class ListTransformer(transformer.Base): node = self._replace_append_call(node) elif func_name == 'pop' and (len(node.args) <= 1): node = self._replace_pop_call(node) - elif func_name == 'stack' and (len(node.args) == 1): + elif (func_name == 'stack' and (len(node.args) == 1) and + (not node.keywords or node.keywords[0].arg == 'strict')): + # This avoids false positives with keyword args. + # TODO(mdan): handle kwargs properly. node = self._replace_stack_call(node) return node @@ -146,15 +151,22 @@ class ListTransformer(transformer.Base): pop_element = original_call_node.args[0] else: pop_element = parser.parse_expression('None') + # The call will be something like "target.pop()", and the dtype is hooked to # target, hence the func.value. - dtype = anno.getanno( + # TODO(mdan): For lists of lists, this won't work. + # The reason why it won't work is because it's unclear how to annotate + # the list as a "list of lists with a certain element type" when using + # operations like `l.pop().pop()`. + dtype = self.get_definition_directive( original_call_node.func.value, - 'element_type', + directives.set_element_type, + 'dtype', default=templates.replace_as_expression('None')) - shape = anno.getanno( + shape = self.get_definition_directive( original_call_node.func.value, - 'element_shape', + directives.set_element_type, + 'shape', default=templates.replace_as_expression('None')) template = """ @@ -223,5 +235,5 @@ class ListTransformer(transformer.Base): return node -def transform(node, context): - return ListTransformer(context).visit(node) +def transform(node, ctx): + return ListTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py index 9f18ab9f44dd8c3f341a02b950f75317c676eff8..996e99ee61b3713a03ff167b892101fca35eaeac 100644 --- a/tensorflow/contrib/autograph/converters/lists_test.py +++ b/tensorflow/contrib/autograph/converters/lists_test.py @@ -18,9 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph import utils -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import lists +from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.contrib.autograph.lang import directives +from tensorflow.contrib.autograph.lang import special_functions +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -28,17 +31,17 @@ from tensorflow.python.ops import list_ops from tensorflow.python.platform import test -class ListTest(converter_test_base.TestCase): +tf = None # Will be replaced by a mock. + + +class ListTest(converter_testing.TestCase): def test_empty_list(self): def test_fn(): return [] - node = self.parse_and_analyze(test_fn, {}) - node = lists.transform(node, self.ctx) - - with self.compiled(node) as result: + with self.converted(test_fn, lists, {}) as result: tl = result.test_fn() # Empty tensor lists cannot be evaluated or stacked. self.assertTrue(isinstance(tl, ops.Tensor)) @@ -49,27 +52,19 @@ class ListTest(converter_test_base.TestCase): def test_fn(): return [1, 2, 3] - node = self.parse_and_analyze(test_fn, {}) - node = lists.transform(node, self.ctx) - - with self.compiled(node) as result: - with self.test_session() as sess: - tl = result.test_fn() - r = list_ops.tensor_list_stack(tl, dtypes.int32) - self.assertAllEqual(sess.run(r), [1, 2, 3]) + with self.converted(test_fn, lists, {}) as result: + self.assertAllEqual(result.test_fn(), [1, 2, 3]) def test_list_append(self): def test_fn(): - l = [1] + l = special_functions.tensor_list([1]) l.append(2) l.append(3) return l - node = self.parse_and_analyze(test_fn, {}) - node = lists.transform(node, self.ctx) - - with self.compiled(node) as result: + ns = {'special_functions': special_functions} + with self.converted(test_fn, lists, ns) as result: with self.test_session() as sess: tl = result.test_fn() r = list_ops.tensor_list_stack(tl, dtypes.int32) @@ -78,24 +73,21 @@ class ListTest(converter_test_base.TestCase): def test_list_pop(self): def test_fn(): - l = [1, 2, 3] - utils.set_element_type(l, dtypes.int32, ()) + l = special_functions.tensor_list([1, 2, 3]) s = l.pop() return s, l - node = self.parse_and_analyze( - test_fn, - { - 'utils': utils, - 'dtypes': dtypes - }, - include_type_analysis=True, - ) - node = lists.transform(node, self.ctx) - - with self.compiled(node) as result: - result.utils = utils - result.dtypes = dtypes + ns = {'special_functions': special_functions} + node, ctx = self.prepare(test_fn, ns) + def_, = anno.getanno(node.body[0].targets[0], + anno.Static.ORIG_DEFINITIONS) + def_.directives[directives.set_element_type] = { + 'dtype': parser.parse_expression('tf.int32'), + 'shape': parser.parse_expression('()'), + } + node = lists.transform(node, ctx) + + with self.compiled(node, ns, dtypes.int32) as result: with self.test_session() as sess: ts, tl = result.test_fn() r = list_ops.tensor_list_stack(tl, dtypes.int32) @@ -108,10 +100,7 @@ class ListTest(converter_test_base.TestCase): s = l.pop().pop() return s - node = self.parse_and_analyze(test_fn, {}) - node = lists.transform(node, self.ctx) - - with self.compiled(node) as result: + with self.converted(test_fn, lists, {}) as result: test_input = [1, 2, [1, 2, 3]] # TODO(mdan): Pass a list of lists of tensor when we fully support that. # For now, we just pass a regular Python list of lists just to verify that @@ -120,29 +109,24 @@ class ListTest(converter_test_base.TestCase): def test_list_stack(self): - tf = None # Will be replaced with a mock. - def test_fn(): l = [1, 2, 3] - utils.set_element_type(l, dtypes.int32) return tf.stack(l) - node = self.parse_and_analyze( - test_fn, - { - 'utils': utils, - 'dtypes': dtypes - }, - include_type_analysis=True, - ) - node = lists.transform(node, self.ctx) - - with self.compiled(node, array_ops.stack, dtypes.int32) as result: - result.utils = utils - result.dtypes = dtypes + node, ctx = self.prepare(test_fn, {}) + def_, = anno.getanno(node.body[0].targets[0], + anno.Static.ORIG_DEFINITIONS) + def_.directives[directives.set_element_type] = { + 'dtype': parser.parse_expression('tf.int32') + } + node = lists.transform(node, ctx) + + with self.compiled(node, {}, array_ops.stack, dtypes.int32) as result: with self.test_session() as sess: self.assertAllEqual(sess.run(result.test_fn()), [1, 2, 3]) + # TODO(mdan): Add a test with tf.stack with axis kwarg. + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/converters/logical_expressions.py b/tensorflow/contrib/autograph/converters/logical_expressions.py index 3a795a315a3c2aa08ac1577a204102755b6e849c..16eb1f0e3f8ad34e615931882ab2896db485f457 100644 --- a/tensorflow/contrib/autograph/converters/logical_expressions.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions.py @@ -23,10 +23,10 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer # TODO(mdan): Properly extrack boolean ops according to lazy eval rules. @@ -39,11 +39,11 @@ from tensorflow.contrib.autograph.pyct import transformer SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND' -class LogicalExpressionTransformer(transformer.Base): +class LogicalExpressionTransformer(converter.Base): """Converts logical expressions to corresponding TF calls.""" - def __init__(self, context): - super(LogicalExpressionTransformer, self).__init__(context) + def __init__(self, ctx): + super(LogicalExpressionTransformer, self).__init__(ctx) # TODO(mdan): Look into replacing with bitwise operators instead. # TODO(mdan): Skip replacing if the function is trivial. self.op_mapping = { @@ -128,5 +128,5 @@ class LogicalExpressionTransformer(transformer.Base): return right -def transform(node, context): - return LogicalExpressionTransformer(context).visit(node) +def transform(node, ctx): + return LogicalExpressionTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/logical_expressions_test.py b/tensorflow/contrib/autograph/converters/logical_expressions_test.py index 2814060c4d831e4dddacb3dcbcbe1db42160db20..ca07de5e8a1f870391ecbe41bf1341dc52c25347 100644 --- a/tensorflow/contrib/autograph/converters/logical_expressions_test.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions_test.py @@ -18,23 +18,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import logical_expressions +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class GradientsFunctionTest(converter_test_base.TestCase): +class GradientsFunctionTest(converter_testing.TestCase): def test_equals(self): def test_fn(a, b): return a == b - node = self.parse_and_analyze(test_fn, {}) - node = logical_expressions.transform(node, self.ctx) - - with self.compiled(node, math_ops.equal) as result: + with self.converted(test_fn, logical_expressions, {}, + math_ops.equal) as result: with self.test_session() as sess: self.assertTrue(sess.run(result.test_fn(1, 1))) self.assertFalse(sess.run(result.test_fn(1, 2))) @@ -44,11 +42,8 @@ class GradientsFunctionTest(converter_test_base.TestCase): def test_fn(a, b, c): return (a or b) and (a or b or c) - node = self.parse_and_analyze(test_fn, {}) - node = logical_expressions.transform(node, self.ctx) - - with self.compiled(node, math_ops.logical_or, - math_ops.logical_and) as result: + with self.converted(test_fn, logical_expressions, {}, math_ops.logical_or, + math_ops.logical_and) as result: with self.test_session() as sess: self.assertTrue(sess.run(result.test_fn(True, False, True))) diff --git a/tensorflow/contrib/autograph/converters/name_scopes.py b/tensorflow/contrib/autograph/converters/name_scopes.py index dfee529abaa8c14d9b408819b32c5199500a2c2f..dd6c6bf960c52d094a16d4cd72fa84f65b9322a1 100644 --- a/tensorflow/contrib/autograph/converters/name_scopes.py +++ b/tensorflow/contrib/autograph/converters/name_scopes.py @@ -20,11 +20,11 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class FunctionNameScopeTransformer(transformer.Base): +class FunctionNameScopeTransformer(converter.Base): """Wrap a function body with a `name_scope` of the function name.""" def _name_for_current_scope(self): @@ -70,5 +70,5 @@ class FunctionNameScopeTransformer(transformer.Base): return node -def transform(node, context): - return FunctionNameScopeTransformer(context).visit(node) +def transform(node, ctx): + return FunctionNameScopeTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/name_scopes_test.py b/tensorflow/contrib/autograph/converters/name_scopes_test.py index 17692cbd880dbc1db4bb40ad7345e27907499f9d..a329b0db70e2c6559fa5cf36694cf808fa28a6cb 100644 --- a/tensorflow/contrib/autograph/converters/name_scopes_test.py +++ b/tensorflow/contrib/autograph/converters/name_scopes_test.py @@ -18,30 +18,26 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import name_scopes +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.platform import test -class FunctionNameScopeTransformer(converter_test_base.TestCase): +class FunctionNameScopeTransformer(converter_testing.TestCase): def test_basic(self): def test_fn(l): """This should stay here.""" - a = 5 + a = 1 l += a return l - node = self.parse_and_analyze(test_fn, {}) - node = name_scopes.transform(node, self.ctx) - - with self.compiled(node, ops.name_scope) as result: + with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result: result_op = result.test_fn(constant_op.constant(1)) self.assertIn('test_fn/', result_op.op.name) - self.assertEqual('This should stay here.', result.test_fn.__doc__) def test_long_docstring(self): @@ -54,13 +50,12 @@ class FunctionNameScopeTransformer(converter_test_base.TestCase): Returns: l """ - return l - - node = self.parse_and_analyze(test_fn, {}) - node = name_scopes.transform(node, self.ctx) + return l + 1 - with self.compiled(node, ops.name_scope) as result: - self.assertIn('Multi-line', result.test_fn.__doc__) + with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result: + result_op = result.test_fn(constant_op.constant(1)) + self.assertIn('test_fn/', result_op.op.name) + self.assertIn('Multi-line docstring.', result.test_fn.__doc__) self.assertIn('Returns:', result.test_fn.__doc__) def test_nested_functions(self): @@ -68,21 +63,16 @@ class FunctionNameScopeTransformer(converter_test_base.TestCase): def test_fn(l): def inner_fn(i): - return i ** 2 - - l += 4 - return inner_fn(l) + return i + 1 - node = self.parse_and_analyze(test_fn, {}) - node = name_scopes.transform(node, self.ctx) + l += 1 + return l, inner_fn(l) - with self.compiled(node, ops.name_scope) as result: - result_op = result.test_fn(constant_op.constant(1)) - first_result_input_name = result_op.op.inputs[0].name - second_result_input_name = result_op.op.inputs[1].name - self.assertIn('test_fn/', first_result_input_name) - self.assertNotIn('inner_fn', first_result_input_name) - self.assertIn('test_fn/inner_fn/', second_result_input_name) + with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result: + first, second = result.test_fn(constant_op.constant(1)) + self.assertIn('test_fn/', first.op.name) + self.assertNotIn('inner_fn', first.op.name) + self.assertIn('test_fn/inner_fn/', second.op.name) def test_method(self): @@ -91,48 +81,20 @@ class FunctionNameScopeTransformer(converter_test_base.TestCase): def test_fn(self, l): def inner_fn(i): - return i ** 2 - - l += 4 - return inner_fn(l) + return i + 1 - # Note that 'TestClass' was needed in the namespace here. - node = self.parse_and_analyze( - TestClass, {'TestClass': TestClass}, owner_type=TestClass) - node = name_scopes.transform(node, self.ctx) + l += 1 + return l, inner_fn(l) - with self.compiled(node, ops.name_scope) as result: - result_op = result.TestClass().test_fn(constant_op.constant(1)) - first_result_input_name = result_op.op.inputs[0].name - second_result_input_name = result_op.op.inputs[1].name - self.assertIn('TestClass/test_fn/', first_result_input_name) - self.assertNotIn('inner_fn', first_result_input_name) - self.assertIn('TestClass/test_fn/inner_fn/', second_result_input_name) + ns = {'TestClass': TestClass} + node, ctx = self.prepare(TestClass, ns, owner_type=TestClass) + node = name_scopes.transform(node, ctx) - def test_operator(self): - - class TestClass(object): - - def __call__(self, l): - - def inner_fn(i): - return i ** 2 - - l += 4 - return inner_fn(l) - - # Note that 'TestClass' was needed in the namespace here. - node = self.parse_and_analyze( - TestClass.__call__, {'TestClass': TestClass}, owner_type=TestClass) - node = name_scopes.transform(node, self.ctx) - - with self.compiled(node, ops.name_scope) as result: - result_op = result.__call__(TestClass(), constant_op.constant(1)) - first_result_input_name = result_op.op.inputs[0].name - second_result_input_name = result_op.op.inputs[1].name - self.assertIn('call__/', first_result_input_name) - self.assertNotIn('inner_fn', first_result_input_name) - self.assertIn('call__/inner_fn/', second_result_input_name) + with self.compiled(node, {}, ops.name_scope) as result: + first, second = result.TestClass().test_fn(constant_op.constant(1)) + self.assertIn('TestClass/test_fn/', first.op.name) + self.assertNotIn('inner_fn', first.op.name) + self.assertIn('TestClass/test_fn/inner_fn/', second.op.name) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/single_return.py b/tensorflow/contrib/autograph/converters/return_statements.py similarity index 94% rename from tensorflow/contrib/autograph/converters/single_return.py rename to tensorflow/contrib/autograph/converters/return_statements.py index bcc9ca9dfeb00ef2d2e60edf6a1abfba19a1bad7..a351cd81b82f7fb32f62ac1579355ace0501759d 100644 --- a/tensorflow/contrib/autograph/converters/single_return.py +++ b/tensorflow/contrib/autograph/converters/return_statements.py @@ -20,21 +20,21 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno # TODO(mdan): Move this logic into transformer_base. -class BodyVisitor(transformer.Base): +class BodyVisitor(converter.Base): """Walks breadth- or depth-first the list-of-nodes bodies of AST nodes.""" - def __init__(self, context, depth_first=False): + def __init__(self, ctx, depth_first=False): + super(BodyVisitor, self).__init__(ctx) self.depth_first = depth_first self.changes_made = False - super(BodyVisitor, self).__init__(context) def visit_nodelist(self, nodelist): for node in nodelist: @@ -144,13 +144,13 @@ def contains_return(node): return False -class LiftReturn(transformer.Base): +class LiftReturn(converter.Base): """Move return statements out of If and With blocks.""" - def __init__(self, context): + def __init__(self, ctx): + super(LiftReturn, self).__init__(ctx) self.changes_made = False self.common_return_name = None - super(LiftReturn, self).__init__(context) def visit_If(self, node): # Depth-first traversal of if statements @@ -195,8 +195,8 @@ class LiftReturn(transformer.Base): last_return_name = self.common_return_name body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) referenced_names = body_scope.referenced - self.common_return_name = self.context.namer.new_symbol( - 'return_', referenced_names) + self.common_return_name = self.ctx.namer.new_symbol('return_', + referenced_names) node = self.generic_visit(node) self.common_return_name = last_return_name return node @@ -265,7 +265,7 @@ class DetectReturnInFunctionDef(gast.NodeVisitor): 'Each function definition should contain at least one return.') -def transform(node, context): +def transform(node, ctx): """Ensure a function has only a single return. This transforms an AST node with multiple returns successively into containing @@ -280,8 +280,8 @@ def transform(node, context): this is an error. Args: - node: an AST node to transform - context: a context object + node: ast.AST + ctx: converter.EntityContext Returns: new_node: an AST with a single return value @@ -301,10 +301,10 @@ def transform(node, context): while True: # Try to lift all returns out of if statements and with blocks - lr = LiftReturn(context) + lr = LiftReturn(ctx) node = lr.visit(node) changes_made = lr.changes_made - fe = FoldElse(context) + fe = FoldElse(ctx) node = fe.visit(node) changes_made = changes_made or fe.changes_made diff --git a/tensorflow/contrib/autograph/converters/return_statements_test.py b/tensorflow/contrib/autograph/converters/return_statements_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3c7c8c8a2586c6716e78960ee964ff3b0735fa47 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/return_statements_test.py @@ -0,0 +1,167 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for return_statements module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.converters import return_statements +from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.python.framework import ops +from tensorflow.python.platform import test + + +class SingleReturnTest(converter_testing.TestCase): + + def assertTransformedEquivalent(self, test_fn, *inputs): + ns = {'ops': ops} + with self.converted(test_fn, return_statements, ns) as result: + self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) + + def test_straightline(self): + + def test_fn(x): + return x * x + + self.assertTransformedEquivalent(test_fn, 2) + + def test_conditional(self): + + def test_fn(x): + if x > 0: + return x + else: + return x * x + + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, -2) + + def test_missing_orelse(self): + + def test_fn(x): + if x > 0: + return x + + node, ctx = self.prepare(test_fn, {}) + with self.assertRaises(ValueError): + return_statements.transform(node, ctx) + + def test_missing_orelse_recovrable(self): + + def test_fn(x): + if x > 0: + return x + return x * x + + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, -2) + + def test_missing_branch_return_recoverable(self): + + def test_fn(x): + if x < 0: + x *= x + else: + return x + return x + + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, -2) + + def test_conditional_nested(self): + + def test_fn(x): + if x > 0: + if x < 5: + return x + else: + return x * x + else: + return x * x * x + + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, -2) + self.assertTransformedEquivalent(test_fn, 5) + + def test_context_manager(self): + + def test_fn(x): + with ops.name_scope(''): + return x * x + + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, -2) + + def test_context_manager_in_conditional(self): + + def test_fn(x): + if x > 0: + with ops.name_scope(''): + return x * x + else: + return x + + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, -2) + + def text_conditional_in_context_manager(self): + + def test_fn(x): + with ops.name_scope(''): + if x > 0: + return x * x + else: + return x + + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, -2) + + def test_no_return(self): + + def test_fn(x): + x *= x + + self.assertTransformedEquivalent(test_fn, 2) + + def test_nested_functions(self): + + def test_fn(x): + + def inner_fn(y): + if y > 0: + return y * y + else: + return y + + return inner_fn(x) + + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, -2) + + def test_loop(self): + + def test_fn(x): + for _ in range(10): + return x + return x + + node, ctx = self.prepare(test_fn, {}) + with self.assertRaises(ValueError): + return_statements.transform(node, ctx) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards.py b/tensorflow/contrib/autograph/converters/side_effect_guards.py index 3bcb2d3c42c6e0663c8f78523199a364b6ac231f..b808604f0ab2d42f41a560035ab046ff782a3431 100644 --- a/tensorflow/contrib/autograph/converters/side_effect_guards.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards.py @@ -36,11 +36,11 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -59,14 +59,9 @@ class SymbolNamer(object): raise NotImplementedError() -class SideEffectGuardTransformer(transformer.Base): +class SideEffectGuardTransformer(converter.Base): """Adds control dependencies to functions with side effects.""" - def __init__(self, context): - super(SideEffectGuardTransformer, self).__init__(context) - - # pylint:disable=invalid-name - def _visit_and_reindent(self, nodes): new_nodes = [] current_dest = new_nodes @@ -149,7 +144,7 @@ class SideEffectGuardTransformer(transformer.Base): s for s in guarded_args if s not in args_scope.parent.modified) aliased_new_names = tuple( qual_names.QN( - self.context.namer.new_symbol( + self.ctx.namer.new_symbol( s.ssf(), args_scope.parent.referenced)) for s in need_alias) alias_map = dict(zip(need_alias, aliased_new_names)) if len(guarded_args) == 1: @@ -183,8 +178,6 @@ class SideEffectGuardTransformer(transformer.Base): (node.body, alias_map)) return node - # pylint:enable=invalid-name - -def transform(node, context): - return SideEffectGuardTransformer(context).visit(node) +def transform(node, ctx): + return SideEffectGuardTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py index ce0ce33243a1352107eb8121050ee76474869809..bee512abbc2e115d69bc9a5d53b6c54d428cc73a 100644 --- a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py @@ -18,147 +18,145 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import side_effect_guards +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variables +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test -class SideEffectGuardsTest(converter_test_base.TestCase): +tf = None # Will be replaced by a mock. - def test_side_effect_on_return_only_variable(self): - tf = None +class SideEffectGuardsTest(converter_testing.TestCase): + + def test_side_effect_on_return_only_variable(self): def test_fn(a): tf.assign(a, a + 1) return a - node = self.parse_and_analyze(test_fn, {}) - node = side_effect_guards.transform(node, self.ctx) + node, ctx = self.prepare(test_fn, {}) + node = side_effect_guards.transform(node, ctx) - with self.compiled(node, state_ops.assign) as result: - self.assertEqual(len(node.body[0].body), 1) + self.assertEqual(len(node.body), 1) + + with self.compiled(node, {}, state_ops.assign) as result: with self.test_session() as sess: - v = variables.Variable(2) + v = variable_scope.get_variable('test', initializer=2) sess.run(v.initializer) - # NOTE: We don't expect the assignment to execute in this case, because - # variables cannot be reliably guarded. - self.assertEqual(2, sess.run(result.test_fn(v))) + sess.run(result.test_fn(v)) + # TODO(mdan): Add support for this use case. + # Right now the variable `a` is not conditioned on the `assign` because + # there's no way to add control dependencies to a variable object. + self.assertEqual(2, sess.run(v)) def test_side_effect_on_used_variable(self): - tf = None - def test_fn(a): tf.assign(a, a + 1) return a + 1 - node = self.parse_and_analyze(test_fn, {}) - node = side_effect_guards.transform(node, self.ctx) + node, ctx = self.prepare(test_fn, {}) + node = side_effect_guards.transform(node, ctx) - with self.compiled(node, state_ops.assign) as result: - self.assertEqual(len(node.body[0].body), 1) + self.assertEqual(len(node.body), 1) + + with self.compiled(node, {}, state_ops.assign) as result: with self.test_session() as sess: - v = variables.Variable(2) + v = variable_scope.get_variable('test', initializer=2) sess.run(v.initializer) - # NOTE: Unlike test_side_effect_on_return_only_variable, the variable - # was used in the local scope and so we could catch the assign's side - # effect. - self.assertEqual(4, sess.run(result.test_fn(v))) + sess.run(result.test_fn(v)) + # TODO(mdan): Ensure the result of test_fn(v) is also deterministic. + # Right now it's 3 or 4 based on whether the read is synchronized. + self.assertEqual(3, sess.run(v)) def test_side_effect_on_tensor(self): - tf = None - def test_fn(a): tf.Assert(a > 0, ['expected in throw']) return a - node = self.parse_and_analyze(test_fn, {}) - node = side_effect_guards.transform(node, self.ctx) + node, ctx = self.prepare(test_fn, {}) + node = side_effect_guards.transform(node, ctx) - with self.compiled(node, control_flow_ops.Assert) as result: - self.assertEqual(len(node.body[0].body), 1) + self.assertEqual(len(node.body), 1) + + with self.compiled(node, {}, control_flow_ops.Assert) as result: with self.test_session() as sess: - # NOTE: In this case we can also capture the side effect because the - # argument is a tensor ans we can wrap it inside an identity. with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'expected in throw'): sess.run(result.test_fn(constant_op.constant(-1))) def test_multiline_block(self): - tf = None - def test_fn(a): - tf.assign(a, a + 1) + tf.assign_add(a, 1) b = a + 1 - tf.assign(a, b + 1) - c = b + 1 - d = c + 1 - return d + tf.assign_add(a, 1) + b += 1 + return b - node = self.parse_and_analyze(test_fn, {}) - node = side_effect_guards.transform(node, self.ctx) + node, ctx = self.prepare(test_fn, {}) + node = side_effect_guards.transform(node, ctx) - with self.compiled(node, state_ops.assign) as result: - self.assertEqual(len(node.body[0].body), 1) + self.assertEqual(len(node.body), 1) + + with self.compiled(node, {}, state_ops.assign_add) as result: with self.test_session() as sess: - v = variables.Variable(2) + v = variable_scope.get_variable('test', initializer=2) sess.run(v.initializer) - self.assertEqual(6, sess.run(result.test_fn(v))) + sess.run(result.test_fn(v)) + # TODO(mdan): Ensure the result of test_fn(v) is also deterministic. + self.assertEqual(4, sess.run(v)) def test_multiline_nested_block(self): - tf = None - def test_fn(a): with tf.name_scope('foo'): tf.assign(a, a + 1) b = a + 1 - c = b + 1 - d = c + 1 - return d + return b - node = self.parse_and_analyze(test_fn, {}) - node = side_effect_guards.transform(node, self.ctx) + node, ctx = self.prepare(test_fn, {}) + node = side_effect_guards.transform(node, ctx) - with self.compiled(node, state_ops.assign, ops.name_scope) as result: - self.assertEqual(len(node.body[0].body[0].body), 1) + self.assertEqual(len(node.body[0].body), 1) + + with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result: with self.test_session() as sess: - v = variables.Variable(2) + v = variable_scope.get_variable('test', initializer=2) sess.run(v.initializer) - self.assertEqual(6, sess.run(result.test_fn(v))) + sess.run(result.test_fn(v)) + # TODO(mdan): Ensure the result of test_fn(v) is also deterministic. + self.assertEqual(3, sess.run(v)) def test_multiline_block_unsafe(self): - tf = None - def test_fn(a): tf.assign(a, a + 1) b = a + 1 - tf.assign(a, a + 1) + tf.assign_add(a, 1) c = b + 1 - d = c + 1 - return d + return c + + node, ctx = self.prepare(test_fn, {}) + node = side_effect_guards.transform(node, ctx) - node = self.parse_and_analyze(test_fn, {}) - node = side_effect_guards.transform(node, self.ctx) + self.assertEqual(len(node.body), 1) - with self.compiled(node, state_ops.assign) as result: - self.assertEqual(len(node.body[0].body), 1) + with self.compiled(node, {}, state_ops.assign, + state_ops.assign_add) as result: with self.test_session() as sess: - v = variables.Variable(2) + v = variable_scope.get_variable('test', initializer=2) sess.run(v.initializer) - # NOTE: This intentionally highlights the flakiness. The test should be - # tightened down once that is solved. - self.assertTrue(sess.run(result.test_fn(v)) in (6, 7)) + sess.run(result.test_fn(v)) + # TODO(mdan): Ensure the result of test_fn(v) is also deterministic. + self.assertEqual(4, sess.run(v)) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/single_return_test.py b/tensorflow/contrib/autograph/converters/single_return_test.py deleted file mode 100644 index d483005a09537ea8227814f65aa7e6402c853f60..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/converters/single_return_test.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for single_return module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.autograph.converters import converter_test_base -from tensorflow.contrib.autograph.converters import single_return -from tensorflow.python.framework.ops import name_scope -from tensorflow.python.platform import test - - -class SingleReturnTest(converter_test_base.TestCase): - - def compiled_fn(self, test_fn, *args): - node = self.parse_and_analyze(test_fn, {}) - node = single_return.transform(node, self.ctx) - module = self.compiled(node, *args) - return module - - def test_noop(self): - # Noop - def test_fn(x): - return x - - with self.compiled_fn(test_fn) as result: - self.assertEqual(test_fn(2.0), result.test_fn(2.0)) - - def test_return_expression(self): - # ANF - def test_fn(x): - return x * x - - with self.compiled_fn(test_fn) as result: - x = 2 - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_merge(self): - # Simple merge - def test_fn(x): - if x > 0: - return x - else: - return x * x - - with self.compiled_fn(test_fn) as result: - for x in [-2, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_orphan_branch(self): - - def test_fn(x): - if x > 0: - return x - - with self.assertRaises(ValueError): - self.compiled_fn(test_fn) - - def test_lift_body_into_false_branch(self): - - def test_fn(x): - if x > 0: - return x - return x * x - - with self.compiled_fn(test_fn) as result: - for x in [-2, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_lift_body_into_true_branch(self): - - def test_fn(x): - if x < 0: - x *= x - else: - # TODO(alexbw): linter bug here that requires us suppress this warning. - return x # pylint: disable=undefined-loop-variable - return x - - with self.compiled_fn(test_fn) as result: - for x in [-2, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_nested_if(self): - - def test_fn(x): - if x > 0: - if x < 5: - return x - else: - return x * x - else: - return x * x * x - - with self.compiled_fn(test_fn) as result: - for x in [-2, 2, 5]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_context_manager(self): - - def test_fn(x): - - with name_scope(''): - return x * x - - with self.compiled_fn(test_fn) as result: - result.name_scope = name_scope - for x in [-2, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_context_manager_in_conditional(self): - - def test_fn(x): - if x > 0: - with name_scope(''): - return x * x - else: - return x - - with self.compiled_fn(test_fn, name_scope) as result: - result.name_scope = name_scope - for x in [-2, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def text_conditional_in_context_manager(self): - - def test_fn(x): - with name_scope(''): - if x > 0: - return x * x - else: - return x - - with self.compiled_fn(test_fn) as result: - result.name_scope = name_scope - for x in [-2, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_no_return(self): - - def test_fn(x): - x *= x - - with self.compiled_fn(test_fn) as result: - self.assertEqual(test_fn(2), result.test_fn(2)) - - def test_nested_functiondefs(self): - - def test_fn(x): - - def inner_fn(y): - if y > 0: - return y * y - else: - return y - - return inner_fn(x) - - with self.compiled_fn(test_fn) as result: - for x in [-2, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_loop(self): - - def test_fn(x): - for _ in range(10): - return x - return x - - with self.assertRaises(ValueError): - self.compiled_fn(test_fn) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/contrib/autograph/converters/slices.py index 85aeda9c4164eb70329bd50f789eea5441c8fc87..c527f98613a2ffebf35141d4dac85e972a89c93b 100644 --- a/tensorflow/contrib/autograph/converters/slices.py +++ b/tensorflow/contrib/autograph/converters/slices.py @@ -20,12 +20,12 @@ from __future__ import print_function import gast -from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.lang import directives from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class SliceTransformer(transformer.Base): +class SliceTransformer(converter.Base): """Converts slicing operations to their TF counterpart. Currently, relying on the default slice operator that Tensor uses is @@ -36,12 +36,14 @@ class SliceTransformer(transformer.Base): def _process_single_assignment(self, target, value): if not isinstance(target, gast.Subscript): return None + if not isinstance(target.slice, gast.Index): + return None template = """ target = ag__.set_item(target, key, item) """ return templates.replace( - template, target=target.value, key=target.slice, item=value) + template, target=target.value, key=target.slice.value, item=value) def visit_Assign(self, node): node = self.generic_visit(node) @@ -56,17 +58,17 @@ class SliceTransformer(transformer.Base): def visit_Subscript(self, node): node = self.generic_visit(node) if not isinstance(node.slice, gast.Index): - # TODO(mdan): It might make more sense to wave them through. - raise NotImplementedError('non-index slice') + return node if not isinstance(node.ctx, gast.Load): # Index writes are handled at a higher level, one at which the rvalue is # also available. return node - dtype = anno.getanno( + dtype = self.get_definition_directive( node.value, - 'element_type', + directives.set_element_type, + 'dtype', default=templates.replace_as_expression('None')) template = """ @@ -76,8 +78,8 @@ class SliceTransformer(transformer.Base): opts=ag__.GetItemOpts(element_dtype=dtype)) """ return templates.replace_as_expression( - template, target=node.value, key=node.slice, dtype=dtype) + template, target=node.value, key=node.slice.value, dtype=dtype) -def transform(node, context): - return SliceTransformer(context).visit(node) +def transform(node, ctx): + return SliceTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/contrib/autograph/converters/slices_test.py index 6c2d7e1ea1a6c46fcc3a2c6972a24507646ef858..c822d53a4a2810755fd6841af85544dd8fc76a5e 100644 --- a/tensorflow/contrib/autograph/converters/slices_test.py +++ b/tensorflow/contrib/autograph/converters/slices_test.py @@ -18,42 +18,59 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph import utils -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import slices +from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.contrib.autograph.lang import directives +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import list_ops from tensorflow.python.platform import test -class SliceTest(converter_test_base.TestCase): +class SliceTest(converter_testing.TestCase): def test_index_access(self): def test_fn(l): - utils.set_element_type(l, dtypes.int32) return l[1] - node = self.parse_and_analyze( - test_fn, - { - 'utils': utils, - 'dtypes': dtypes - }, - include_type_analysis=True, - ) - node = slices.transform(node, self.ctx) - - with self.compiled(node, dtypes.int32) as result: - result.utils = utils - result.dtypes = dtypes + node, ctx = self.prepare(test_fn, {}) + def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS) + def_.directives[directives.set_element_type] = { + 'dtype': parser.parse_expression('tf.int32') + } + node = slices.transform(node, ctx) + + with self.compiled(node, {}, dtypes.int32) as result: with self.test_session() as sess: tl = list_ops.tensor_list_from_tensor( [1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32)) y = result.test_fn(tl) self.assertEqual(2, sess.run(y)) + def test_index_access_multiple_definitions(self): + + def test_fn(l): + if l: + l = [] + return l[1] + + node, ctx = self.prepare(test_fn, {}) + def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS) + def_.directives[directives.set_element_type] = { + 'dtype': parser.parse_expression('tf.int32') + } + def_, = anno.getanno(node.body[0].body[0].targets[0], + anno.Static.DEFINITIONS) + def_.directives[directives.set_element_type] = { + 'dtype': parser.parse_expression('tf.float32') + } + with self.assertRaises(transformer.AutographParseError): + slices.transform(node, ctx) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/core/BUILD b/tensorflow/contrib/autograph/core/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..1873045a921f8af6068d8fccca6a5625b2aedcf8 --- /dev/null +++ b/tensorflow/contrib/autograph/core/BUILD @@ -0,0 +1,75 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "core", + srcs = [ + "config.py", + "converter.py", + "errors.py", + "naming.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/pyct/static_analysis", + "//tensorflow/contrib/autograph/utils", + ], +) + +py_test( + name = "errors_test", + srcs = ["errors_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":core", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + ], +) + +py_test( + name = "naming_test", + srcs = ["naming_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":core", + "//tensorflow/python:client_testlib", + ], +) + +py_library( + name = "test_lib", + srcs = [ + "converter_testing.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":core", + "//tensorflow/contrib/autograph/operators", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/pyct/static_analysis", + "//tensorflow/contrib/autograph/utils", + "@gast_archive//:gast", + "@six_archive//:six", + ], +) diff --git a/tensorflow/contrib/autograph/impl/config.py b/tensorflow/contrib/autograph/core/config.py similarity index 100% rename from tensorflow/contrib/autograph/impl/config.py rename to tensorflow/contrib/autograph/core/config.py diff --git a/tensorflow/contrib/autograph/core/converter.py b/tensorflow/contrib/autograph/core/converter.py new file mode 100644 index 0000000000000000000000000000000000000000..83a80c1f52123c325782a67c651e892163af83b3 --- /dev/null +++ b/tensorflow/contrib/autograph/core/converter.py @@ -0,0 +1,330 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converter construction support. + +This module contains a base class for all converters, as well as supporting +structures. These structures are referred to as contexts. + +The class hierarchy is as follows: + + + [extends] converter.Base + [extends] transformer.Base + [extends] gast.nodeTransformer + [uses] transfomer.SourceInfo + [uses] converter.EntityContext + [uses] converter.ProgramContext + [uses] transfomer.SourceInfo + +converter.Base is a specialization of transformer.Base for AutoGraph. It's a +very lightweight subclass that adds a `ctx` attribute holding the corresponding +EntityContext object (see below). Note that converters are not reusable, and +`visit` will raise an error if called more than once. + +converter.EntityContext contains mutable state associated with an entity that +the converter processes. + +converter.ProgramContext contains mutable state across related entities. For +example, when converting several functions that call one another, the +ProgramContext should be shared across these entities. + +Below is the overal flow at conversion: + + program_ctx = ProgramContext(, , ...) + while : + entity, source_info = + entity_ctx = EntityContext(program_ctx, source_info) + for : + converter = ConverterClass(entity_ctx) + + # May update entity_ctx and program_ctx + entity = converter.visit(entity) + + + +Note that pyct contains a small number of transformers used for static analysis. +These implement transformer.Base, rather than converter.Base, to avoid a +dependency on AutoGraph. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +from enum import Enum + + +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import naming +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import cfg +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import liveness +from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions +from tensorflow.contrib.autograph.pyct.static_analysis import type_info + +# TODO(mdan): These contexts can be refactored into first class objects. +# For example, we could define Program and Entity abstractions that hold on +# to the actual entity and have conversion methods. + +# TODO(mdan): Add a test specific to this converter. + + +class ProgramContext(object): + """ProgramContext keeps track of converting function hierarchies. + + This object is mutable, and is updated during conversion. Not thread safe. + + Attributes: + recursive: bool, whether to recursively convert any functions that the + decorator function may call. + autograph_decorators: Tuple[Callable, ...], decorator functions that belong + to AutoGraph. These require special treatment. + dependency_cache: Dict[Any, ast.AST], the original entities mapped to their + converted AST + additional_imports: Set[Any], additional entities which for any reason + cannot be attached after loading and need to be explicitly imported + in the generated code + name_map: Dict[str, str], map of original entity name to the name of + their converted counterparts + autograph_module: Module, a reference to the autograph module. This + needs to be specified by the caller to avoid circular dependencies. + uncompiled_modules: Set[Tuple[str, ...]], with each tuple representing the + fully qualified name of a package containing functions that will not be + compiled. + required_imports: str, containing an import statement on each line. These + are all the imports necessary for the compiled code to run, in addition + to the closures of each entity, which are attached dynamically. + """ + + def __init__( + self, + recursive, + autograph_decorators, + partial_types, + autograph_module, + uncompiled_modules, + ): + self.recursive = recursive + self.autograph_decorators = autograph_decorators + self.partial_types = partial_types if partial_types else () + self.autograph_module = autograph_module + self.uncompiled_modules = uncompiled_modules + + # Required to output dependencies in discovery order, which should match + # the reverse dependency order. + self.dependency_cache = collections.OrderedDict() + self.additional_imports = set() + self.name_map = {} + + @property + def required_imports(self): + """Returns a block containing all imports required by the converted code.""" + # TODO(mdan): Check that these don't clobber one another. + return '\n'.join(config.COMPILED_IMPORT_STATEMENTS + + tuple(self.additional_imports)) + + def new_namer(self, namespace): + return naming.Namer(namespace, self.recursive, self.name_map, + self.partial_types) + + def update_name_map(self, namer): + """Updates renamed_calls based on the recent activity from the namer. + + Whenever we convert a new entity, any references to other entities are being + renamed to match their soon-to-be-converted counterparts. The namer keeps + track of these renames. When conversion is complete, we copy those renames + so that when those referenced entities are being converted, their new name + matches. + + Args: + namer: naming.Namer + + Raises: + ValueError: when an entity was renamed twice and to different names. + """ + # TODO(mdan): Have call_trees do this directly. + # This is done so indirectly, via the namer, for historic reasons. But + # now we can have the converter that does the rename record the new name + # as well and skip this step altogether. + for o, name in namer.renamed_calls.items(): + if o in self.name_map: + if self.name_map[o] != name: + raise ValueError( + 'Calls to %s were converted using multiple names (%s). This is ' + 'possible when an entity with one of these names already ' + 'existed. To fix, avoid using any of these names.' % + (o, (name, self.name_map[o]))) + else: + self.name_map[o] = name + + def add_to_cache(self, original_entity, converted_ast): + self.dependency_cache[original_entity] = converted_ast + + +class EntityContext(object): + """Tracks the conversion of a single entity. + + This object is mutable, and is updated during conversion. Not thread safe. + + Attributes: + namer: Namer + info: transformer.EntityInfo + program: ProgramContext + """ + + def __init__(self, namer, entity_info, program_ctx): + self.namer = namer + self.info = entity_info + self.program = program_ctx + + +class Base(transformer.Base): + """All converters should inherit from this class. + + Attributes: + ctx: EntityContext + """ + + def __init__(self, ctx): + super(Base, self).__init__(ctx.info) + self.ctx = ctx # Keeping this short because it's used frequently. + + self._used = False + self._ast_depth = 0 + + def get_definition_directive(self, node, directive, arg, default): + """Returns the unique directive for a symbol, or a default if none exist. + + See lang/directives.py for details on directives. + + Args: + node: ast.AST + directive: Callable[..., Any] + arg: str + default: Any + + Raises: + ValueError: if conflicting annotations have been found + """ + defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ()) + if not defs: + return default + + # TODO(mdan): Simplify this. + arg_values = [] + for def_ in defs: + if (directive not in def_.directives or + arg not in def_.directives[directive]): + continue + arg_value = def_.directives[directive][arg] + for prev_value in arg_values: + if not ast_util.matches(arg_value, prev_value): + qn = anno.getanno(node, anno.Basic.QN) + raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' % + (qn, directive.__name__, arg, + compiler.ast_to_source(arg_value).strip(), + compiler.ast_to_source(prev_value).strip())) + arg_values.append(arg_value) + + if not arg_values: + return default + + arg_value, = arg_values + return arg_value + + def visit(self, node): + if not self._ast_depth: + if self._used: + raise ValueError('converter objects cannot be reused') + self._used = True + + self._ast_depth += 1 + try: + return super(Base, self).visit(node) + finally: + self._ast_depth -= 1 + + +class AnnotatedDef(reaching_definitions.Definition): + + def __init__(self): + super(AnnotatedDef, self).__init__() + self.directives = {} + + +class AgAnno(Enum): + """Annotation labels specific to AutoGraph. See anno.py.""" + + DIRECTIVES = 'User directives associated with the annotated statement.' + + def __repr__(self): + return self.name + + +def standard_analysis(node, context, is_initial=False): + """Performs a complete static analysis of the given code. + + Args: + node: ast.AST + context: converter.EntityContext + is_initial: bool, whether this is the initial analysis done on the input + source code + + Returns: + ast.AST, same as node, with the static analysis annotations added + """ + # TODO(mdan): Clear static analysis here. + # TODO(mdan): Consider not running all analyses every time. + # TODO(mdan): Don't return a node because it's modified by reference. + graphs = cfg.build(node) + node = qual_names.resolve(node) + node = activity.resolve(node, context.info, None) + node = reaching_definitions.resolve(node, context.info, graphs, AnnotatedDef) + node = liveness.resolve(node, context.info, graphs) + node = live_values.resolve(node, context.info, config.PYTHON_LITERALS) + node = type_info.resolve(node, context.info) + # This second call allows resolving first-order class attributes. + node = live_values.resolve(node, context.info, config.PYTHON_LITERALS) + if is_initial: + anno.dup( + node, + { + anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS, + }, + ) + return node + + +def apply_(node, context, converter_module): + """Applies a converter to an AST. + + Args: + node: ast.AST + context: converter.EntityContext + converter_module: converter.Base + + Returns: + ast.AST, the result of applying converter to node + """ + node = standard_analysis(node, context) + node = converter_module.transform(node, context) + return node diff --git a/tensorflow/contrib/autograph/converters/converter_test_base.py b/tensorflow/contrib/autograph/core/converter_testing.py similarity index 63% rename from tensorflow/contrib/autograph/converters/converter_test_base.py rename to tensorflow/contrib/autograph/core/converter_testing.py index 41c2e71702e7e3ee3811a2cbee27c8c988eb3a5c..5ee2c3fffd7474cb8ca28349385a9d543e92a72d 100644 --- a/tensorflow/contrib/autograph/converters/converter_test_base.py +++ b/tensorflow/contrib/autograph/core/converter_testing.py @@ -20,20 +20,27 @@ from __future__ import print_function import contextlib import imp +import sys + +import six from tensorflow.contrib.autograph import operators from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.core import errors from tensorflow.contrib.autograph.pyct import compiler -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import pretty_printer -from tensorflow.contrib.autograph.pyct import qual_names -from tensorflow.contrib.autograph.pyct.static_analysis import activity -from tensorflow.contrib.autograph.pyct.static_analysis import live_values -from tensorflow.contrib.autograph.pyct.static_analysis import type_info +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.platform import test +def imported_decorator(f): + return lambda a: f(a) + 1 + + +# TODO(mdan): We might be able to use the real namer here. class FakeNamer(object): """A fake namer that uses a global counter to generate unique names.""" @@ -67,7 +74,17 @@ class TestCase(test.TestCase): """Base class for unit tests in this module. Contains relevant utilities.""" @contextlib.contextmanager - def compiled(self, node, *symbols): + def assertPrints(self, expected_result): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + yield + self.assertEqual(out_capturer.getvalue(), expected_result) + finally: + sys.stdout = sys.__stdout__ + + @contextlib.contextmanager + def compiled(self, node, namespace, *symbols): source = None self.dynamic_calls = [] @@ -77,12 +94,17 @@ class TestCase(test.TestCase): return 7 try: - result, source = compiler.ast_to_object(node) + result, source = compiler.ast_to_object(node, include_source_map=True) + result.tf = self.make_fake_mod('fake_tf', *symbols) fake_ag = self.make_fake_mod('fake_ag', converted_call) fake_ag.__dict__.update(operators.__dict__) fake_ag.__dict__['utils'] = utils + fake_ag.__dict__['rewrite_graph_construction_error'] = ( + errors.rewrite_graph_construction_error) result.__dict__['ag__'] = fake_ag + for k, v in namespace.items(): + result.__dict__[k] = v yield result except Exception: # pylint:disable=broad-except if source is None: @@ -91,6 +113,13 @@ class TestCase(test.TestCase): print('Offending compiled code:\n%s' % source) raise + @contextlib.contextmanager + def converted(self, entity, converter_module, namespace, *tf_symbols): + node, ctx = self.prepare(entity, namespace) + node = converter_module.transform(node, ctx) + with self.compiled(node, namespace, *tf_symbols) as result: + yield result + def make_fake_mod(self, name, *symbols): fake_mod = imp.new_module(name) for s in symbols: @@ -107,30 +136,31 @@ class TestCase(test.TestCase): for k, v in ns.items(): setattr(module, k, v) - def parse_and_analyze(self, - test_fn, - namespace, - namer=None, - arg_types=None, - include_type_analysis=True, - owner_type=None, - recursive=True): + def prepare(self, + test_fn, + namespace, + namer=None, + arg_types=None, + owner_type=None, + recursive=True, + autograph_decorators=()): node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=namer or FakeNamer(), + node = node.body[0] + if namer is None: + namer = FakeNamer() + program_ctx = converter.ProgramContext( + recursive=recursive, + autograph_decorators=autograph_decorators, + partial_types=None, + autograph_module=None, + uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) + entity_info = transformer.EntityInfo( source_code=source, - source_file=None, + source_file='', namespace=namespace, arg_values=None, arg_types=arg_types, - owner_type=owner_type, - recursive=recursive, - type_annotation_func=utils.set_element_type) - node = qual_names.resolve(node) - node = activity.resolve(node, ctx) - node = live_values.resolve(node, ctx, {}) - if include_type_analysis: - node = type_info.resolve(node, ctx) - node = live_values.resolve(node, ctx, {}) - self.ctx = ctx - return node + owner_type=owner_type) + ctx = converter.EntityContext(namer, entity_info, program_ctx) + node = converter.standard_analysis(node, ctx, is_initial=True) + return node, ctx diff --git a/tensorflow/contrib/autograph/core/errors.py b/tensorflow/contrib/autograph/core/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..5a57d57e7d4c6461f05030b72cc9bfe1b33210db --- /dev/null +++ b/tensorflow/contrib/autograph/core/errors.py @@ -0,0 +1,258 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Error rewriting logic. + +Contains the functions responsible for rewriting tracebacks of errors raised +in AutoGraph (AG) code to refer to user written code, so that errors only refer +to the original user code. + +When 'user code' is used in comments it refers to the original source code that +the user wrote and is converting using AutoGraph. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import logging +import sys +import traceback + +from tensorflow.contrib.autograph.pyct import origin_info +from tensorflow.python.framework import errors_impl + +# TODO(mdan): Add a superclass common to all errors. + + +class GraphConstructionError(Exception): + """Error for graph construction errors from AutoGraph generated code.""" + + def __init__(self, original_error, custom_traceback): + self.original_error = original_error + self.custom_traceback = custom_traceback + super(GraphConstructionError, self).__init__() + + def __str__(self): + traceback_str = ''.join(traceback.format_list(self.custom_traceback)) + return ('Traceback (most recent call last):\n' + traceback_str + '\n' + str( + self.original_error) + '\n') + + +class TfRuntimeError(Exception): + """Error wrapper for runtime errors raised by AutoGraph generated code.""" + + def __init__(self, op_name, op_message, custom_traceback): + self.op_name = op_name + self.op_message = op_message + self.custom_traceback = custom_traceback + super(TfRuntimeError, self).__init__() + + def __str__(self): + message = '%s\n\nCaused by op %r, defined at:\n' % (self.op_message, + self.op_name) + return message + ''.join(traceback.format_list(self.custom_traceback)) + + +def _rewrite_tb(source_map, tb): + """Rewrites code references in a traceback. + + Args: + source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping + locations to their origin + tb: List[Tuple[Text, Text, Text, Text]], consistent with + traceback.extract_tb. + Returns: + List[Tuple[Text, Text, Text, Text]], the rewritten traceback + """ + new_tb = [] + for frame in tb: + filename, lineno, _, _ = frame + loc = origin_info.LineLocation(filename, lineno) + origin = source_map.get(loc) + if origin is not None: + new_tb.append(origin.as_frame()) + else: + new_tb.append(frame) + return new_tb + + +# TODO(mdan): rename to raise_* +def rewrite_graph_construction_error(source_map): + """Rewrites errors raised by non-AG APIs inside AG generated code. + + This is called from the except handler inside an AutoGraph generated function + (that is, during exception handling). Only rewrites the frames corresponding + to the function that this is called from, so each function is responsible + to call this to have its own frames rewritten. + + This function always raises an error. + + Args: + source_map: Dict[origin_info.Location, origin_info.OriginInfo], the source + map belonging to the calling function + + Raises: + GraphConstructionError: The rewritten underlying error. + Exception: The underlying error, if it could not be rewritten. + """ + error_info = sys.exc_info() + _, original_error, e_traceback = error_info + assert original_error is not None + try: + current_traceback = _cut_traceback_loops(source_map, + traceback.extract_tb(e_traceback)) + if isinstance(original_error, GraphConstructionError): + # TODO(mdan): This is incomplete. + # The error might have bubbled through a non-converted function. + previous_traceback = original_error.custom_traceback + cleaned_traceback = [current_traceback[0]] + previous_traceback + else: + cleaned_traceback = current_traceback + + cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback) + + if isinstance(original_error, GraphConstructionError): + original_error.custom_traceback = cleaned_traceback + new_error = original_error + else: + new_error = GraphConstructionError(original_error, cleaned_traceback) + except Exception: + logging.exception('Error while rewriting AutoGraph error:') + # TODO(mdan): Should reraise here, removing the top frame as well. + raise original_error + else: + raise new_error + finally: + # Addresses warning https://docs.python.org/2/library/sys.html#sys.exc_info. + del e_traceback + + +def _cut_traceback_loops(source_map, original_traceback): + """Check for cases where we leave a user method and re-enter it. + + This is done by looking at the function names when the filenames are from any + files the user code is in. If we find a case where we return to a user method + after leaving it then we cut out the frames in between because we assume this + means these in between frames are from internal AutoGraph code that shouldn't + be included. + + An example of this is: + + File "file1.py", line 57, in my_func + ... + File "control_flow_ops.py", line 231, in cond + ... + File "control_flow_ops.py", line 1039, in inner_cond + ... + File "file1.py", line 68, in my_func + ... + + Where we would remove the control_flow_ops.py frames because we re-enter + my_func in file1.py. + + The source map keys are (file_path, line_number) so get the set of all user + file_paths. + + Args: + source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping + locations to their origin + original_traceback: List[Tuple[Text, Text, Text, Text]], consistent with + traceback.extract_tb. + + Returns: + List[Tuple[Text, Text, Text, Text]], the traceback with any loops removed. + """ + all_user_files = set(loc.filename for loc in source_map) + cleaned_traceback = [] + last_user_frame_index = None + last_user_user_file_path = None + # TODO(mdan): Simplify this logic. + for fi, frame in enumerate(original_traceback): + frame_file_path, lineno, _, _ = frame + src_map_key = origin_info.LineLocation(frame_file_path, lineno) + if frame_file_path in all_user_files: + if src_map_key in source_map: + if (last_user_frame_index is not None and + last_user_user_file_path == frame_file_path): + cleaned_traceback = cleaned_traceback[:last_user_frame_index] + last_user_frame_index = fi + last_user_user_file_path = frame_file_path + cleaned_traceback.append(frame) + return cleaned_traceback + + +# TODO(mdan): This should be consistent with rewrite_graph_construction_error +# Both should either raise or return. +def rewrite_tf_runtime_error(error, source_map): + """Rewrites TensorFlow runtime errors raised by ops created in AG code. + + Args: + error: tf.OpError + source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo] + + Returns: + TfRuntimeError, the rewritten underlying error. + """ + try: + cleaned_traceback = _cut_traceback_loops(source_map, error.op.traceback) + # cleaned_traceback = error.op.traceback + cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback) + + op_name = error.op.name + op_message = error.message + rewritten_error = TfRuntimeError(op_name, op_message, cleaned_traceback) + return rewritten_error + except Exception: # pylint: disable=broad-except + logging.exception('Error while rewriting AutoGraph error:') + return error + + +# TODO(znado): Add arg to enable different levels of error rewriting. +@contextlib.contextmanager +def improved_errors(converted_function): + """Context manager that rewrites runtime errors. + + This context manager will rewrite runtime errors so that their traceback + is relative to the original code before conversion. + + Use with the output of to_graph, and wrap the execution of respective ops. + Example: + + converted_my_func = ag.to_graph(my_func) + ops = converted_my_func(...) + + with ag.improved_errors(converted_my_func): + sess.run(ops) + + Args: + converted_function: Callable[..., Any], the output of a to_graph call + + Yields: + None + + Raises: + TfRuntimeError: if any OpError originates in the converted code, it will + be wrapped into a TfRuntimeError + ValueError: If converted_function is not generated by AutoGraph + """ + if (getattr(converted_function, 'ag_source_map', None) is None or + not isinstance(converted_function.ag_source_map, dict)): + raise ValueError( + 'converted_function must be the result of an autograph.to_graph call') + try: + yield + except errors_impl.OpError as e: + raise rewrite_tf_runtime_error(e, converted_function.ag_source_map) diff --git a/tensorflow/contrib/autograph/core/errors_test.py b/tensorflow/contrib/autograph/core/errors_test.py new file mode 100644 index 0000000000000000000000000000000000000000..404c1f5456f9654724d068e3007fe9ced15cbf07 --- /dev/null +++ b/tensorflow/contrib/autograph/core/errors_test.py @@ -0,0 +1,105 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for errors module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.core import errors +from tensorflow.contrib.autograph.pyct import origin_info +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors as tf_errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test +from tensorflow.python.util import tf_inspect + + +def zero_div(): + x = array_ops.constant(10, dtype=dtypes.int32) + return x // 0 + + +def zero_div_caller(): + return zero_div() + + +class RuntimeErrorsTest(test.TestCase): + + def fake_origin(self, function, line_offset): + _, lineno = tf_inspect.getsourcelines(function) + filename = tf_inspect.getsourcefile(function) + lineno += line_offset + loc = origin_info.LineLocation(filename, lineno) + origin = origin_info.OriginInfo(loc, 'test_function_name', 'test_code', + 'test_comment') + return loc, origin + + def test_improved_errors_basic(self): + loc, origin = self.fake_origin(zero_div, 2) + zero_div_caller.ag_source_map = {loc: origin} + + ops = zero_div_caller() + with self.assertRaises(errors.TfRuntimeError) as cm: + with errors.improved_errors(zero_div_caller): + with self.test_session() as sess: + sess.run(ops) + + for frame in cm.exception.custom_traceback: + _, _, function_name, _ = frame + self.assertNotEqual('zero_div', function_name) + self.assertIn(origin.as_frame(), set(cm.exception.custom_traceback)) + + def test_improved_errors_no_matching_lineno(self): + loc, origin = self.fake_origin(zero_div, -1) + zero_div_caller.ag_source_map = {loc: origin} + + ops = zero_div_caller() + with self.assertRaises(errors.TfRuntimeError) as cm: + with errors.improved_errors(zero_div_caller): + with self.test_session() as sess: + sess.run(ops) + + all_function_names = set() + for frame in cm.exception.custom_traceback: + _, _, function_name, _ = frame + all_function_names.add(function_name) + self.assertNotEqual('test_function_name', function_name) + self.assertIn('zero_div', all_function_names) + + def test_improved_errors_failures(self): + loc, _ = self.fake_origin(zero_div, 2) + zero_div_caller.ag_source_map = {loc: 'bogus object'} + + ops = zero_div_caller() + with self.assertRaises(tf_errors.InvalidArgumentError): + with errors.improved_errors(zero_div_caller): + with self.test_session() as sess: + sess.run(ops) + + def test_improved_errors_validation(self): + with self.assertRaisesRegexp( + ValueError, + 'converted_function must be the result of an autograph.to_graph call'): + errors.improved_errors(zero_div).__enter__() + with self.assertRaisesRegexp( + ValueError, + 'converted_function must be the result of an autograph.to_graph call'): + zero_div_caller.ag_source_map = 'not a dict' + errors.improved_errors(zero_div_caller).__enter__() + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/impl/naming.py b/tensorflow/contrib/autograph/core/naming.py similarity index 100% rename from tensorflow/contrib/autograph/impl/naming.py rename to tensorflow/contrib/autograph/core/naming.py diff --git a/tensorflow/contrib/autograph/impl/naming_test.py b/tensorflow/contrib/autograph/core/naming_test.py similarity index 98% rename from tensorflow/contrib/autograph/impl/naming_test.py rename to tensorflow/contrib/autograph/core/naming_test.py index 73fc0894655cb49e4f61bf8ca51995b06feb3072..d2bebd0478b1074e421b5da1427a0dbaf91b6c9f 100644 --- a/tensorflow/contrib/autograph/impl/naming_test.py +++ b/tensorflow/contrib/autograph/core/naming_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.impl import naming +from tensorflow.contrib.autograph.core import naming from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md b/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md new file mode 100644 index 0000000000000000000000000000000000000000..bcbb920cc53de4b89dc67128c9c2c2312f030f0a --- /dev/null +++ b/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md @@ -0,0 +1,33 @@ +# Specifying return data type for `py_func` calls + +The `py_func` op requires specifying a +[data type](https://www.tensorflow.org/guide/tensors#data_types). + +When wrapping a function with `py_func`, for instance using +`@autograph.do_not_convert(run_mode=autograph.RunMode.PY_FUNC)`, you have two +options to specify the returned data type: + + * explicitly, with a specified `tf.DType` value + * by matching the data type of an input argument, which is then assumed to be + a `Tensor` + +Examples: + +Specify an explicit data type: + +``` + def foo(a): + return a + 1 + + autograph.util.wrap_py_func(f, return_dtypes=[tf.float32]) +``` + +Match the data type of the first argument: + +``` + def foo(a): + return a + 1 + + autograph.util.wrap_py_func( + f, return_dtypes=[autograph.utils.py_func.MatchDType(0)]) +``` diff --git a/tensorflow/contrib/autograph/examples/integration_tests/BUILD b/tensorflow/contrib/autograph/examples/integration_tests/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..6c281485b4a3c4d09292a4d7af16330cdc44edd4 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/integration_tests/BUILD @@ -0,0 +1,54 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_test( + name = "errors_test", + srcs = [ + "errors_test.py", + ], + srcs_version = "PY2AND3", + tags = ["no_windows"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "keras_test", + srcs = [ + "keras_test.py", + ], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "list_literals_test", + srcs = [ + "list_literals_test.py", + ], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py b/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b9159942bcf8837b97dfac000d8fb34d15a314 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py @@ -0,0 +1,162 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Error traceback rewriting integration tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib import autograph as ag +from tensorflow.python.util import tf_inspect + + +class ErrorsTest(tf.test.TestCase): + + def test_graph_construction_error_rewriting_call_tree(self): + + def innermost(x): + if x > 0: + return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32) + return tf.zeros((2, 3)) + + def inner_caller(): + return innermost(1.0) + + def caller(): + return inner_caller() + + with self.assertRaises(ag.GraphConstructionError) as error: + graph = ag.to_graph(caller) + graph() + expected = error.exception + custom_traceback = expected.custom_traceback + found_correct_filename = False + num_innermost_names = 0 + num_inner_caller_names = 0 + num_caller_names = 0 + ag_output_filename = tf_inspect.getsourcefile(graph) + for frame in custom_traceback: + filename, _, fn_name, _ = frame + self.assertFalse('control_flow_ops.py' in filename) + self.assertFalse(ag_output_filename in filename) + found_correct_filename |= __file__ in filename + self.assertNotEqual('tf__test_fn', fn_name) + num_innermost_names += int('innermost' == fn_name) + self.assertNotEqual('tf__inner_caller', fn_name) + num_inner_caller_names += int('inner_caller' == fn_name) + self.assertNotEqual('tf__caller', fn_name) + num_caller_names += int('caller' == fn_name) + self.assertTrue(found_correct_filename) + self.assertEqual(num_innermost_names, 1) + self.assertEqual(num_inner_caller_names, 1) + self.assertEqual(num_caller_names, 1) + + def test_graph_construction_error_rewriting_class(self): + + class TestClass(object): + + def test_fn(self): + return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32) + + def inner_caller(self): + return self.test_fn() + + def caller(self): + return self.inner_caller() + + # Note we expect a TypeError here because the traceback will not be + # rewritten for classes. + with self.assertRaises(TypeError): + graph = ag.to_graph(TestClass) + graph().caller() + + def test_runtime_error_rewriting(self): + + def g(x, s): + while tf.reduce_sum(x) > s: + x //= 0 + return x + + def test_fn(x): + return g(x, 10) + + compiled_fn = ag.to_graph(test_fn) + + with self.assertRaises(ag.TfRuntimeError) as error: + with self.test_session() as sess: + x = compiled_fn(tf.constant([4, 8])) + with ag.improved_errors(compiled_fn): + sess.run(x) + expected = error.exception + custom_traceback = expected.custom_traceback + found_correct_filename = False + num_test_fn_frames = 0 + num_g_frames = 0 + ag_output_filename = tf_inspect.getsourcefile(compiled_fn) + for frame in custom_traceback: + filename, _, fn_name, source_code = frame + self.assertFalse(ag_output_filename in filename) + self.assertFalse('control_flow_ops.py' in filename) + self.assertFalse('ag__.' in fn_name) + self.assertFalse('tf__g' in fn_name) + self.assertFalse('tf__test_fn' in fn_name) + found_correct_filename |= __file__ in filename + num_test_fn_frames += int('test_fn' == fn_name and + 'return g(x, 10)' in source_code) + # This makes sure that the code is correctly rewritten from "x_1 //= 0" to + # "x //= 0". + num_g_frames += int('g' == fn_name and 'x //= 0' in source_code) + self.assertTrue(found_correct_filename) + self.assertEqual(num_test_fn_frames, 1) + self.assertEqual(num_g_frames, 1) + + def test_runtime_error_rewriting_nested(self): + + def test_fn(x): + + def g(y): + return y**2 // 0 + + s = 0 + for xi in x: + s += g(xi) + return s + + compiled_fn = ag.to_graph(test_fn) + + # TODO(b/111408261): Nested functions currently do not rewrite correctly, + # when they do we should change this test to check for the same traceback + # properties as the other tests. This should throw a runtime error with a + # frame with "g" as the function name but because we don't yet add + # try/except blocks to inner functions the name is "tf__g". + with self.assertRaises(ag.TfRuntimeError) as error: + with self.test_session() as sess: + x = compiled_fn(tf.constant([4, 8])) + with ag.improved_errors(compiled_fn): + sess.run(x) + expected = error.exception + custom_traceback = expected.custom_traceback + num_tf_g_frames = 0 + for frame in custom_traceback: + _, _, fn_name, _ = frame + self.assertNotEqual('g', fn_name) + num_tf_g_frames += int('tf__g' == fn_name) + self.assertEqual(num_tf_g_frames, 1) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py b/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7e7ef5a3e2bbf6a15936eb181c9c4112f8b820e6 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py @@ -0,0 +1,103 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras integration tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib import autograph + + +class MinimalKeras(tf.keras.Model): + + def call(self, x): + return x * 3 + + +class ModelWithStaticConditional(object): + + def __init__(self, initial): + self.initial = initial + if self.initial: + self.h = 15 + + @autograph.convert() + def call(self): + x = 10 + if self.initial: + x += self.h + return x + + +class BasicBlock(tf.keras.Model): + + def __init__(self): + super(BasicBlock, self).__init__() + self.conv1 = tf.keras.layers.Conv2D(8, 3) + self.pool = tf.keras.layers.GlobalAveragePooling2D() + self.dense = tf.keras.layers.Dense(3) + + def call(self, x): + x = self.conv1(x) + x = self.pool(x) + x = self.dense(x) + return x + + +class CompoundModel(tf.keras.Model): + + def __init__(self): + super(CompoundModel, self).__init__() + self.block = BasicBlock() + + @autograph.convert(recursive=True) + def call(self, x): + x = self.block(x) # pylint: disable=not-callable + return x + + +class KerasTest(tf.test.TestCase): + + def test_basic(self): + MinimalKeras() + + def test_conditional_attributes_False(self): + model = ModelWithStaticConditional(False) + self.assertEqual(model.call(), 10) + + def test_conditional_attributes_True(self): + model = ModelWithStaticConditional(True) + self.assertEqual(model.call(), 25) + + def test_recursive_true(self): + with self.assertRaisesRegexp(NotImplementedError, + 'Object conversion is not yet supported.'): + with tf.Graph().as_default(): + model = CompoundModel() + model.build(tf.TensorShape((None, 10, 10, 1))) + init = tf.global_variables_initializer() + + with tf.Session() as sess: + sess.run(init) + sample_input = tf.random_uniform((1, 10, 10, 1)) + output = model(sample_input) # pylint: disable=not-callable + self.assertEqual(sess.run(output).shape, (1, 3)) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/autograph/utils/type_hints.py b/tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py similarity index 54% rename from tensorflow/contrib/autograph/utils/type_hints.py rename to tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py index aeb9e545610460afbe364dfcfc7a54b9aede29fe..680b6dbaf07fc10e11dfa1e9d3a075624024c103 100644 --- a/tensorflow/contrib/autograph/utils/type_hints.py +++ b/tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py @@ -12,30 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""No-op utilities that provide static type hints. - -These are used when the data type is not known at creation, for instance in the -case of empty lists. -""" +"""Tests of functions that use list literals.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +import tensorflow as tf + +from tensorflow.contrib import autograph as ag + + +def list_used_as_tuple(): + return tf.constant([1, 2, 3]) + + +class ListLiteralsTest(tf.test.TestCase): + + def test_basic(self): + converted = ag.to_graph(list_used_as_tuple) + result = converted() -def set_element_type(entity, dtype, shape=None): - """Indicates that the entity is expected hold items of specified type. + with self.test_session() as sess: + self.assertAllEqual(sess.run(result), [1, 2, 3]) - This function is a no-op. Its presence merely marks the data type of its - argument. The staged TensorFlow ops will reflect and assert this data type. - Args: - entity: A Tensor or TensorArray. - dtype: TensorFlow dtype value to assert for entity. - shape: Optional shape to assert for entity. - Returns: - The value of entity, unchanged. - """ - del dtype - del shape - return entity +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_collatz_speed_test.ipynb b/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_collatz_speed_test.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..c10a5741f640be5ab7d2604dd32f2f4d6ddf1a22 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_collatz_speed_test.ipynb @@ -0,0 +1,299 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "aQkTGc-d8I1k" + }, + "source": [ + "This notebook runs a basic speed test for a simple algorithm that implements the process described in Collatz Conjecture.\n", + "\n", + "https://en.wikipedia.org/wiki/Collatz_conjecture" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "x5ChBlH09jk_" + }, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "X-QAUpWdPxUh" + }, + "outputs": [], + "source": [ + "!pip install -U -q tf-nightly" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "wiKQu3w05eCa" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from matplotlib import pyplot as plt\n", + "import tensorflow as tf\n", + "from tensorflow.contrib import autograph as ag\n", + "from tensorflow.python.eager import context" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "_cRFTcwT9mnn" + }, + "source": [ + "### Plotting helpers" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "ww7rc0GQ9pMu" + }, + "outputs": [], + "source": [ + "def plot_results(counts, times, title):\n", + " plt.plot(counts, np.array(times) * 1000., 'o')\n", + " plt.ylabel('Time (milliseconds)')\n", + " plt.xlabel('Collatz counter')\n", + " plt.title(title)\n", + " plt.ylim(0, 30)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ESZGw9s9-Y5_" + }, + "source": [ + "### Collatz function definition" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "qeunWm9m-dT7" + }, + "outputs": [], + "source": [ + "def collatz(a):\n", + " count = 0\n", + " while a \u003e 1.1:\n", + " if a % 2 \u003c 0.1:\n", + " a //= 2\n", + " else:\n", + " a = 3 * a + 1\n", + " count += 1\n", + " return count\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "nnFmPDvScsDo" + }, + "source": [ + "# AutoGraph" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 301 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 9153, + "status": "ok", + "timestamp": 1531757473651, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "6fU4vlxYcsDe", + "outputId": "11b50f28-aced-4506-a743-4b749e9645c3" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAEcCAYAAAAydkhNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XtcVGXCB/DfGRBUQA0ZURQvyIspm1reQkxNSPICgoqW\npWZu1vbmjZJV3Jc+axappVLu7guV25rU5g3wlq3iBd1wXHSN3hXy9ZaCgoOIIKAzMOf9g5dZkTkz\nB5i7v+9fzJlzzjzPHD2/Oc/znOcIoiiKICIiMkBh6wIQEZH9YkgQEZEkhgQREUliSBARkSSGBBER\nSWJIEBGRJIYEkQNYsWIFkpOTbV0MegQxJMipzJ49G8OHD4dWq5W9zeOPP45r164163O2bt2KqKgo\nDB48GKNGjcKcOXOwf//+5haXyO4xJMhpFBUV4fTp0xAEAYcPH5a9nSAIzfqc9957D1999RVWrFiB\nU6dO4fjx41iyZAmOHz8uuQ3vWSVHxZAgp5GRkYHBgwdj6tSpSE9P1y+fPXs2duzYoX+dnp6OWbNm\nAQBefvlliKKIqKgoPPXUU/juu+8AANu2bcP48eMxYsQIvPnmm7h58yYA4PLly/jmm2+wYcMGhISE\nwM3NDYIg4KmnnkJSUlKjz9ywYQNefPFFDB48GIWFhdi1axcmTpyIp556Cs899xy+/fZb/fqnTp3C\nmDFjkJKSgqeffhphYWHYs2dPo/rduXMHr7/+Op566inMnDmz2Vc/RC3BkCCnkZmZiaioKEyePBkn\nTpxAWVmZ5LoNVw9bt24FAOzevRtnzpzBhAkTkJOTg/Xr1+OTTz7BiRMn4Ofnh7i4OACASqVCt27d\nMGDAAJPl2bNnD1avXo0zZ86gW7du6Ny5M1JTU3HmzBkkJSUhKSkJ+fn5+vVLS0tRXl6O48eP48MP\nP0RiYiKuXLmif3/fvn1YuHAhcnNz4e/vj40bN7bkayJqFoYEOYXc3Fxcv34dEyZMQHBwMHr27Nnk\nl7hce/fuxfTp0/H444+jTZs2iIuLw9mzZ3H9+nXcvn0bSqWy0fpjxozBsGHDMHDgQNy4cUO/PCYm\nBn379oVCoYCrqyvGjBmDHj16AACGDh2K0NBQ5Obm6tcXBAFLlixBmzZtMGzYMIwZM0Z/ZQMA48eP\nx69+9SsoFApERkY2ChgiS2FIkFPIzMzEqFGj0LFjRwDApEmTkJGR0aJ93bx5E35+fvrX7du3R6dO\nnVBSUoJOnTrpm54aHDt2DCdPnoRWq23U99C1a9cm682cORMjRozAsGHDkJ2djdu3b+vf79ChA9zd\n3fWv/fz8Gn2Wj4+P/u927dqhurq6RfUjag5XWxeAqLXu37+P7777DjqdDqNGjQIAaDQaVFZWoqCg\nAO3bt8e9e/f066vVaqP769KlC65fv65/XV1djfLycvj6+qJTp05YvXo1/vWvfyE4OLjRdg93Tj/Y\nIa7RaLB48WKsW7cOYWFhUCgU+M///M9G21RUVODevXto27YtAODGjRsICgpq5rdBZF68kiCHd/Dg\nQbi4uOC7775DZmYmMjMz8d1332Ho0KHIzMxE//798be//Q337t3DL7/8gp07dzba3sfHp1En8OTJ\nk7Fr1y4UFBRAo9Fg/fr1GDRoEPz8/NCnTx/MnDkTcXFx+OGHH3D//n3odDqcOXPG6CgprVYLrVaL\nxx57DAqFAseOHcPf//73RuuIoohPPvkEWq0Wubm5OHr0KCZMmGDeL4uomXglQQ4vIyMD06ZNg6+v\nb6Pls2bNwvvvv4+9e/fip59+QmhoKPr164fIyEjk5OTo11u4cCHi4+Oh0WiwatUqPP/881i8eDEW\nLlyIiooKPPnkk1i/fr1+/cTERGzduhVJSUm4du0avLy80Lt3b2zcuFHfTPVwYHh4eGDlypVYvHgx\ntFotnn32WYSFhTVaR6lUomPHjnjmmWfQvn17rFq1Cr179zbzt0XUPIIlHzqk0Wjw0ksvQavVoq6u\nDhEREXjrrbdQWFiIuLg43LlzB8HBwVi7di1cXZlX9Og6deoU4uPjcfToUVsXhagRizY3ubm5YcuW\nLcjIyEBGRgays7Px448/4qOPPsK8efPw/fffw8vLq9EYdiIish8W75No164dgPqritraWgiCAJVK\nhYiICAD1wwQPHjxo6WIQEVELWDwkdDodoqOjERoaitDQUPj7+6NDhw5QKOo/umvXrk2GFBI9aoYP\nH86mJrJLFg8JhUKhb2rKy8vDxYsXm6zT3LlziIjIOqw2BNbT0xPDhg3Djz/+iIqKCuh0OgBAcXEx\nunTpYnJ7TpBGRGR9Fh1SVFZWhjZt2sDLywv37t1DTk4OFixYgBEjRuDAgQOYOHEi0tPTmwwFNEQQ\nBKjVlZYsrk0plV6sn4Ny5roBrJ+jUyq9WrW9RUNCrVZj+fLl0Ol00Ol0mDhxIsaMGYOAgADExcUh\nOTkZ/fv3x/Tp0y1ZDCIiaiGL3idhbs6e9qyfY3LmugGsn6Nr7ZUEp+UgIiJJDAkiIpLEkCAiIkkM\nCSIiksSQICIiSQwJIiKSxJAgIiJJDAkiIpLEkCAiIkkMCSIiksSQICIiSQwJIiKSxJAgIiJJDAki\nIpLEkCAiIkkMCSIiksSQICIiSQwJIiKSxJAgIiJJDAkiIpLEkCAiIkkMCSIiksSQICIiSQwJIiKS\nxJAgIiJJDAkiIpLEkCAiIkkMCSIikuRqyZ0XFxcjPj4epaWlcHFxwYwZMzB79mxs2rQJ27ZtQ+fO\nnQEAS5cuxejRoy1ZFCIiagGLhoSLiwtWrFiB/v37o6qqClOnTsXIkSMBAPPmzcO8efMs+fFERNRK\nFg0JpVIJpVIJAPDw8EDfvn1x8+ZNAIAoipb8aCIiMgOr9UkUFhaioKAAAwcOBACkpaVhypQpWLly\nJSorK61VDCIiagarhERVVRUWLVqEhIQEeHh4YNasWTh06BAyMzPh4+ODpKQkaxSDiIiaSRAt3O5T\nW1uL119/HaNHj8bcuXObvF9UVIQ33ngDe/bssWQxiIioBSzaJwEACQkJCAwMbBQQarVa31dx8OBB\nBAUFydqXWu28zVJKpRfr56CcuW4A6+folEqvVm1v0ZA4ffo09uzZg6CgIERHR0MQBCxduhR79+5F\nfn4+FAoFunfvjlWrVlmyGERE1EIWDYkhQ4YgPz+/yXLeE0FE5Bh4xzUREUliSBARkSSGBBERSWJI\nEBGRJIYEERFJYkgQEZEkhgQREUliSBARkSSGBBERSWJIEBGRJIYEERFJYkgQEZEkhgQREUliSBAR\nkSSGBBERSWJIEBGRJIYEERFJYkgQEZEk2Y8vvXfvHtRqNdzd3dGlSxdLlomIiOyE0ZDQ6XTIyMjA\n9u3bUVBQAE9PT2g0Gri6uiI8PByvvPIK+vTpY62yEhGRlRkNiRdffBGDBw/GihUrEBwcDBcXFwDA\nrVu3cPz4cSQmJuKFF17ApEmTrFJYIiKyLkEURVHqzbKyMnh7exvdgZx1zEWtrrTK59iCUunF+jko\nZ64bwPo5OqXSq1XbG+24NnTyv3XrFs6ePWt0HSIicg6yRjfNmjULlZWVqKioQHR0NFauXIk1a9ZY\numxERGRjskKiuroaXl5eOHLkCCIjI7Fnzx6cOHHC0mUjIiIbkxUSGo0GAKBSqTBy5EgoFAp9JzYR\nETkvWSExfPhwREREIDc3F8OHD0dFRQUUCt6HR0Tk7GTdTPfuu++ioKAA/v7+cHNzw927d7F69WpL\nl42IiGzMaEhcuHBB/3ebNm1QXFysf+3m5ma5UhERkV0wGhILFiyAIAgQRRE3btyAp6cnAODu3bvo\n1q0bDh8+bHTnxcXFiI+PR2lpKVxcXBAbG4s5c+bgzp07WLp0KYqKitCjRw9s3LgRXl6tG8tLRETm\nZzQkGkJg9erVGDJkCCZMmAAAOHDgAM6dO2dy5y4uLlixYgX69++PqqoqTJ06FaGhodi1axdCQkLw\n2muvITU1FSkpKXjnnXfMUB0iIjInWb3PeXl5+oAAgOeffx4nT540uZ1SqUT//v0BAB4eHujbty9K\nSkqQlZWFmJgYAEBMTAwOHTrUkrITEZGFyQqJmpoa5Obm6l/n5uaipqamWR9UWFiIgoICDBo0CLdu\n3YKPjw+A+iC5fft2s/ZFRETWIXt0U1xcHNq1awcAuH//Pj7++GPZH1JVVYVFixYhISEBHh4eEASh\nRYVt7Rwk9o71c1zOXDeA9XuUyQqJoUOH4tChQ7h8+TJEUURAQIDs0U21tbVYtGgRpkyZgvDwcABA\n586dUVpaCh8fH6jVatnzPzn7JFysn2Ny5roBrJ+js+gEfw+qq6uDm5sbXF1dcfXq1UbDY41JSEhA\nYGAg5s6dq182btw47Nq1CwCQnp6OsLCwZhabiIisQdaVRFpaGj766CN06tRJ31QkCAKysrKMbnf6\n9Gns2bMHQUFBiI6OhiAIWLp0KV577TUsWbIEO3fuhJ+fH5KTk1tfEyIiMjtZIbF582bs3bsX3bt3\nb9bOhwwZgvz8fIPvffnll83aFxERWZ+s5ialUtnsgCAiIscn60pi5MiRWLt2LSZNmgR3d3f98sDA\nQIsVjIiIbE9WSGRkZACov9O6gZw+CSIicmyyQsLUHE1EROScZIUEUD8jrEqlAgA8/fTT6Nu3r8UK\nRURE9kFWx3VGRgZeeeUV5OfnIz8/H/PmzcPu3bstXTYiIrIx2UNg09PToVQqAQBqtRrz589HVFSU\nRQtHRES2JfuO64aAePhvIiJyXrJComfPnvjkk09QUlKCmzdvYtOmTfD397d02YiIyMZkhcTvf/97\nXL58GVFRUYiKisKlS5ewatUqS5eNiIhsTFafROfOnbFhwwZLl4WIiOyMrCuJ1NRUlJeX61/fvn0b\nn3/+ucUKRURE9kFWSOzbtw+dOnXSv37sscewd+9eixWKiIjsg6yQEEWxybK6ujqzF4aIiOyLrJDo\n3bs3/vznP0MUReh0OmzevBk9e/a0dNmIiMjGZIXEypUrceTIEQwcOBCDBw/GsWPHkJiYaOmyERGR\njcka3eTr64stW7aguroaANC+fXuLFoqIiOyD7D6J7du3449//CPat2+PwsJCnDlzxtJlIyIiG5MV\nEklJSTh58iQOHToEAPDw8MAHH3xg0YIREZHtyQoJlUqFjz76CG3btgVQPwT2/v37Fi0YERHZnqyQ\ncHd3hyAI+tc6nc5iBSIiIvshq+M6KCgIu3fvhiiKKCwsRGpqKoYMGWLpshERkY3JupJYvnw5Tp06\nBbVajdjYWNTV1WHZsmWWLhsREdmYrCsJT09PrF692tJlISIiOyPrSmL//v24e/cuACA5ORnz58/H\n//zP/1i0YEREZHuyQuJPf/oTPD09kZeXhxMnTiA6OppXFkREjwBZIeHqWt8q9fe//x2xsbGIjIzk\nEFgiokeArJAQBAG7d+/Gvn37EBISAgDQarUWLRgREdmerJD43e9+hwMHDiA2Nhb+/v64cuUKRowY\nYXK7hIQEjBw5EpGRkfplmzZtwujRoxETE4OYmBhkZ2e3vPRERGRRgmjoYRFmkpubCw8PD8THx2PP\nnj0A6kPCw8MD8+bNa/b+1OpKcxfRbiiVXqyfg3LmugGsn6NTKr1atb3RIbB/+ctfMHfuXKxdu9bg\n+/Hx8UZ3PnToUBQVFTVZbsFcIiIiMzIaEu7u7gDMPzV4WloaMjMz8atf/QrLly+Hl1frko6IiCzD\nos1NAFBUVIQ33nhD39xUVlaGxx57DIIgYMOGDVCr1ZxRlojIThm9kkhLSzO68UsvvdTsD/T29tb/\nPWPGDLzxxhuyt3X2dkPWzzE5c90A1s/RWbRPwhx3VT98oaJWq6FUKgEABw8eRFBQUKs/g4iILMNo\nSCQlJbVq52+//TZUKhXKy8sxduxYLFy4ECqVCvn5+VAoFOjevTtWrVrVqs8gIiLLMRoSx44dM7rx\nmDFjjL7/8ccfN1k2bdo0GcUiIiJ7YDQkPv/8c8n3BEEwGRJEROTYjIbEV199Za1yEBGRHTIaEteu\nXYO/vz8uXLhg8P3AwECLFIqIiOyD0ZBYvXo1UlJSsGDBgibvCYKArKwsixWMiIhsz2hIpKSkAAAO\nHz5slcIQEZF9kfX4UgCoqalBcXEx6urq9MvY3ERE5NxkhcSWLVuwYcMGdOzYEQpF/ezibG4iInJ+\nskLiL3/5Cw4cOABfX19Ll4eIiOyIrIcOde3alQFBRPQIknUlsXDhQqxcuRJjxozRTx8OmL7jmoiI\nHJuskDhy5AiOHDmCK1euNOqTYEgQETk3WSFx8OBBHD58GG3btrV0eYiIyI7I6pPw9/eHq6vs0bJE\nROQkZJ35e/Xqhblz5yI8PBxubm765S156BARETkOWSGh1WrRs2dPnD9/3tLlISIiOyIrJFr78CEi\nInJMRvskTD2+VKPR4OLFi2YtEBER2Q+TE/zV1NRg8uTJGDRoEHx8fHD//n1cvnwZx48fx7Fjx7B8\n+XL07dvXWuUlIiIrMhoSn376KfLy8vDtt9/iD3/4A4qLi9GuXTsEBQUhPDwcaWlp8PT0tFZZiYjI\nykz2SQwcOBADBw60RlmIiMjOyLpPgoiIHk0MCSIiksSQICIiSQwJIiKSJCskbt26hXfeeUc/DUdB\nQQG++eYbixaMiIhsT1ZI/O53v8OQIUNQUVEBAAgICMDXX39t0YIREZHtyQqJkpISvPjii3BxcQEA\nuLm56Z8rQUREzkvWmf7hacIrKiogiqJFCkRERPZDVkiMHz8eiYmJqKqqwq5du/Dqq69i2rRpJrdL\nSEjAyJEjERkZqV92584dvPrqq4iIiMD8+fNRWVnZ8tITEZFFyQqJX//61xg6dCiCg4Nx7NgxzJ49\nG3PnzjW53dSpU/HFF180WpaamoqQkBB8//33GDFiBFJSUlpWciIisjjZj5uLiopCVFRUs3Y+dOhQ\nFBUVNVqWlZWFrVu3AgBiYmIwe/ZsvPPOO83aLxERWYeskLh16xa2bt2Kq1evora2Vr88OTm52R9Y\nVlYGHx8fAIBSqcTt27ebvQ8iIrIOWSHx5ptvYsCAAQgJCdGPcLIFpdLLZp9tDayf43LmugGs36NM\nVkjU1NTg3XffNcsHdu7cGaWlpfDx8YFarYa3t7fsbdVq5+3kViq9WD8H5cx1A1g/R9faAJTVcT1o\n0CD8/PPPLfqAh4fKjhs3Drt27QIApKenIywsrEX7JSIiy5N1JfHCCy/g5ZdfRteuXeHu7q5fvmPH\nDqPbvf3221CpVCgvL8fYsWOxcOFCLFiwAIsXL8bOnTvh5+fXon4NIiKyDlkhsWzZMrzxxhsYMGBA\ns/okPv74Y4PLv/zyS9n7ICIi25EVEu7u7pg/f76ly0JERHZGVp/EM888g+zsbEuXhYiI7IysK4lt\n27YhNTUVHh4ecHNzgyiKEAQBOTk5li4fERHZkKyQ2Llzp6XLQUREdkhWSHTv3t3S5SAiIjtkNCSW\nLVuGdevWYdq0aRAEocn7pobAEhGRYzMaEg0zvf72t7+1SmGIiMi+GA2Jr7/+Gh988AGGDx9urfIQ\nEZEdMToENj8/31rlICIiO8QHVRMRkSSjzU3nz59HSEhIk+W8T4KI6NFgNCR69+6N1NRUa5WFiIjs\njNGQcHNz4z0SRESPMKN9Em3atLFWOYiIyA4ZDYlt27ZZqxxERGSHOLqJiIgkMSSIiEgSQ4KIiCQx\nJIiISBJDgoiIJDEkiIhIEkOCiIgkMSSIiEgSQ4KIiCTJesY1EdGjTnWuBPtyruB6aTX8fNpjUkhv\njBjga+tiWRxDgojIBNW5EqTs/pf+daG6Sv/a2YOCIUFEVuHIv8T35VyRWP6Lw9ShpRgSRGRxjv5L\n/HpptcHlN25VWbkk1mezkBg3bhw8PT2hUCjg6uqKHTt22KooRGRhjv5L3M+nPQrVTQOhW2cPG5TG\numwWEoIg4KuvvkLHjh1tVQQishJH/yU+KaR3oyuhfy/v1ei1IzepSbFZSIiiCJ1OZ6uPJ3I69nyC\nsuYvcUt8Dw3b78v5BTduVaFbZw9MCunVaL+O3qQmxaZXEvPnz4cgCJg5cyZmzJhhq6IQOTx7P0HJ\n/SXeWpb8HkYM8DW6D0dvUpNis5D461//CqVSibKyMsybNw8BAQEYOnSorYpDZHXm/MVr7ycoOb/E\nzcGW34OjN6lJsVlIKJVKAIC3tzeee+45/PTTTyZDQqn0skbRbIb1c1zNrVv2PwsN/uLt0KEtRj/Z\no9F627P+F1dLKtHT1wuxYf/R6P0G129Jn6DM8b2bYx+Tx3hh8pjAZm0jt/4NWvo9mKN+Pbt64cqN\niibL/X29HPrfvk1CoqamBjqdDh4eHqiursaJEyfw1ltvmdxOra60QulsQ6n0Yv0cVEvq9s33BRLL\nf0b/HvWDOR5uOrlyowLrtp5GRcW9Jr+K/TpLt/m39nu31bFrTv0btOR7MFf9Iob5G2xSixjmb9N/\n+60NKJuERGlpKd566y0IgoC6ujpERkZi1KhRtigKkU3IaZpoTtOJsTZ/e+7QNqYlTUfW6vswxFpN\natZmk5Dw9/dHZmamLT6ayCKaeyI2NdpHda7E4PuA4TZuqRMUALvu0DbGVJAa+85tdaI21bntiHjH\nNVErSfUvANInYlO//A2916Cjp5vB5YZOUIlfqAyuay8d2sYYC1JTo5jsvW6OhFOF0yNLda4EiV+o\n8Os1R5D4hQqqcyUt2s/2rP81uHxfzi+S24wY4IvXo4LRQ+kJF4WAHkpPvB4VjBEDfCWbWRqUVdyX\nXVZHHnEzKaS3xPJeRpuiyLx4JUGPJHOOp79aYrhT0tSJWOoXr9SJ/UFyrwQceToJY01Hn+05Z3Ab\nRwg/R8OQoEeSOcfT9/Q1PPSxpSdiqRP7gwrVd/HrNUdM9n/YsiPXHKSC1JHDz9GwuYkeSVK/1otK\n7za7CSo27D8MLm/piViqmeVhOlHUXwFJldNYs5YjM9YURebFKwl6JEn9EhVF6JfLbYIa/WQPVFTc\nM9uImobtth+5gLLK+7K2MXYF5IwdubYexfQoYUhQqznSOPyGshaVym+7bugMNVZHS5yI5QYE0LQt\n3pGOSUs5Y/jZI4YEtZjqXEmTX7v2NA7/4RNlv56PIet0YZP1BAF4zNNd8qRcVHrX6vcaSPWZtHFR\nQFvXdPbkB9vi7X2yP3IsDAlqEVNj+c05Dr8lv4oNnSilOoMf83JHWYX0r3ZXheETsyXvNZDqM6mV\nmF7/wbZ4e5/sjxwLQ4JaxNRYfnMNRZTzq9hQiJgq34OMBQQgfWK25HBLqT6T7j6e/3+fgHRbvCPf\nG0H2hyHhpCzdJm1qLL+5hiJuP3LB8PKjFzBigK9kiAiCWT4er0cFY1/OFbMPtzR1fIwNXTXVFs/h\noWRODAknZI02aVNj+Zs7FFHqpCnVT9Dw61/qikGqicgQby/D/RE9lJ7678uc9xrIOT6tGb3j6PdG\nkH1hSDgha7RJS52IvDu4I3ZsYLM+x9hJ09R2zW27DxvSAz9fLTc6CV6DhvfMPdxS7vFp6egdDg8l\nc2JIOCFrtEmb80Rk7KQp9SsfqD+xu0g0K8lpuzf0eVLrmnO4pbWOD0OBzIEh4YRMzZ7Zmr4KS/R1\nSJ00C9V3YaproU40vFxO2/2DrHlSZZ8BORKGhBOSagrq17NTq/oqpJqFUnf/C92VHvrASDt4Htln\ni6CtE+EiAO3atkH1/Vr4dTYcKsb6NyQyoIk2LgroRNEhmlbYZ0COhCHhhKSaglrTV6E6V4LN+wzP\nvCni34Fx4qcb+NflMv17dSJwt0YLQDqUpE6azaETRXwW/2yr9mEt7DMgR8KQcFKGmk9aOr2yqRvn\nHvRgQEgx1EELGO48lsvRmmrYZ0COgiHhhKT6DVraFt6cG9PkkHr8ptT9CA28vdwBwfDNb2yqIbIM\nThXuZBp+9Reqq5pMJd3S6ZXlPASnOaRCydQU2bHPBuKjN0OdcuprInvFKwknY6zfYdX84fq/m9MW\nLuchOA2C+3ibbHKSCqUH2+qLSu/CVaFAnU4Hv/8fzvrgjWYMBSLrYEg4GVNj8FtygpXqWDZ0Y9q/\nRzddh7ZOBxdBQLu2rqi5XysrlBgARPaFIWFhDz6/QCEIqNPVD+r09nJH7LPNuzNZDkuMwW/uaJyX\nngvCS88FNVqmVHpBrTb8LGgisl8MiYcYegbBz1dvm7x5zFBnMdB4xE6d+O9R/2WV9y0yx7+lxuDz\nFz7Ro4kh8QBTzyCQGucvdZOZt5e7yc809xz/HINPRObEkHiA3KGeD5/YpbaT8/hJS8zxz1/9RGQu\nHAL7ALlDPR8+sbdmiKij3QRGRI8WhsQD/Hzay1rv4RO71HZympt4ExgR2TObhUR2djaef/55RERE\nIDU1tUX7UJ0rQeIXKvx6zREkfqGC6lxJq8pk6mauf6/X66HXhreLfTZQf+OXQgBcFP+e09Tby503\ngRGR3bNJn4ROp8N7772HL7/8El26dMH06dMRFhaGvn37Sm4zZdnuRrOIWuLpa4Y6ffv17GTwXgBT\n2z188xcRkSOySUjk5eWhV69e6N69OwBg0qRJyMrKMhoSOp3YKAgs9fS11jwNjGFARM7GJs1NJSUl\n6Natm/61r68vbt68KXv7fTm/WOXpXkREjzqbhIQoyn2UjGE3blVJdhZztBARkfnYpLmpa9euuH79\nuv51SUkJunTpInt7f18vxIb9B9ZtPd3kvRcj+kGp9DJLOa3NUcstlzPXz5nrBrB+jzJBbO3P+hao\nq6vD888/jy+//BJKpRKxsbFYv3690T4JIiKyPptcSbi4uOC//uu/8Oqrr0IURUyfPp0BQURkh2xy\nJUFERI6Bd1wTEZEkhgQREUliSBARkSS7DwlzzPFkb8aNG4eoqChER0dj+vTpAIA7d+7g1VdfRURE\nBObPn4/KSsd5iltCQgJGjhyJyMhI/TJj9Vm9ejXGjx+PKVOmID8/3xZFbhZD9du0aRNGjx6NmJgY\nxMTEIDs7W/9eSkoKxo8fjwkTJuDEiRO2KLJsxcXFmDNnDiZOnIjIyEhs2bIFgPMcv4fr99VXXwFw\nnuOn0WiMkXxQAAAKiUlEQVQQGxuL6OhoREZGYtOmTQCAwsJCzJgxAxEREYiLi0Ntba1+/aVLl2L8\n+PGYOXNmo1sRJIl2rK6uTgwPDxcLCwtFjUYjRkVFiRcuXLB1sVpt3LhxYnl5eaNla9euFVNTU0VR\nFMWUlBRx3bp1tihai/zjH/8Qz507J06ePFm/TKo+R48eFV977TVRFEXx7NmzYmxsrPUL3EyG6vfp\np5+KmzdvbrLuhQsXxClTpoharVa8du2aGB4eLup0OmsWt1lu3rwpnjt3ThRFUbx79644fvx48cKF\nC05z/KTq5yzHTxRFsbq6WhRFUaytrRVjY2PFs2fPiosXLxb3798viqIoJiYmit98840oiqKYlpYm\nvvvuu6IoiuK+ffvEJUuWmNy/XV9JPDjHU5s2bfRzPDk6URSh0+kaLcvKykJMTAwAICYmBocOHbJF\n0Vpk6NCh6NChQ6NlD9en4bhlZWUhOjoaADBo0CBUVlaitLTUugVuJkP1AwzPHJCVlYWJEyfC1dUV\nPXr0QK9evZCXl2eNYraIUqlE//79AQAeHh7o27cvSkpKnOb4GapfwxRAznD8AKBdu3YA6q8Samtr\nIQgCVCoVIiIiADQ+nzx4XCMiIpCTk2Ny/3YdEq2d48leCYKA+fPnY9q0adi+fTsA4NatW/Dx8QFQ\n/w/79u3btixiq5WVlTWqT1lZGQDg5s2b6Nq1q349X19flJS0bop3W0lLS8OUKVOwcuVKfXOMoX+z\njlK/wsJCFBQUYNCgQU3+PTrD8Wuo38CBAwE4z/HT6XSIjo5GaGgoQkND4e/vjw4dOkChqD+9d+3a\nVV+HB4+fi4sLOnTogPLycqP7t+uQMJT0zuCvf/0rdu3ahc8++wxpaWnIzc2FIAimN3QCho6pI9Z9\n1qxZOHToEDIzM+Hj44MPP/wQgOPWr6qqCosWLUJCQgI8PDwky+ws9XOm46dQKJCRkYHs7Gzk5eXh\n4sWLTdZpqMPD9RNF0WT97DokWjvHk71SKpUAAG9vb4SHhyMvLw+dO3fWX7ar1Wp4e3vbsoitJlUf\nX19fFBcX69crLi52yGPq7e2t/881Y8YMfZNE165dcePGDf16jlC/2tpaLFq0CFOmTEF4eDgA5zp+\nhurnTMevgaenJ4YNG4Yff/wRFRUV+ibtB+vw4PGrq6vD3bt30bFjR6P7teuQeOKJJ3D16lUUFRVB\no9Fg3759CAsLs3WxWqWmpgZVVfXTmVdXV+PEiRMICgrCuHHjsGvXLgBAenq6w9Xz4V8oUvUJCwtD\nRkYGAODs2bPo0KGDvlnDnj1cP7Varf/74MGDCAoKAlBf7/3790Oj0eDatWu4evWqvnnDXiUkJCAw\nMBBz587VL3Om42eofs5y/MrKyvRNZffu3UNOTg4CAwMxYsQIHDhwAEDj4zdu3Dikp6cDAA4cOICn\nn37a5GfY/bQc2dnZeP/99/VzPC1YsMDWRWqVa9eu4a233oIgCKirq0NkZCQWLFiA8vJyLFmyBDdu\n3ICfnx+Sk5MNdpbao7fffhsqlQrl5eXw8fHBwoULER4ejsWLFxusz6pVq3D8+HG0a9cOSUlJCA4O\ntnENjDNUP5VKhfz8fCgUCnTv3h2rVq3SnyxTUlKwY8cOuLq6YuXKlRg1apSNayDt9OnTePnllxEU\nFARBECAIApYuXYqBAwdK/nt0pOMnVb+9e/c6xfH7+eefsXz5cuh0Ouh0OkycOBG/+c1vcO3aNcTF\nxaGiogL9+/fHunXr0KZNG2g0Gixbtgz5+fno1KkT1q9fjx49ehj9DLsPCSIish27bm4iIiLbYkgQ\nEZEkhgQREUliSBARkSSGBBERSWJIEBGRJIYE2b3a2lokJycjIiICkZGRmDRpEtasWYO6ujqj261Y\nsQJpaWkA6qeGXrt2rcnPOnToEH766SezlNsSioqKsG3bNlsXgx4hDAmye8uXL8fFixeRkZGBPXv2\nYPfu3QgICIBGozH7Z2VlZdn1rJ+FhYX49ttvW7StqVAlMsTV1gUgMuaXX35BVlaW/g5foH72ytjY\nWAD1M2CuW7dO/3CYUaNGIT4+3uikZefPn8fvf/971NTUQKPRYMaMGZgzZw5OnDiBw4cPIycnBzt2\n7MArr7yCwsJCHDx4EIIgQKPR4NKlS/jHP/4BT0/PRvv85z//iXXr1qGqqgqCICA+Ph4jR45EXl4e\nPvjgA9TU1KBdu3ZYuXIlnnjiCZw6dQpr1qzBzp07AaDR61OnTuGDDz7AwIEDcfbsWSgUCqxfvx4B\nAQF47733UFRUhJiYGPTs2RPJycm4dOkSkpKSUF5eDq1Wizlz5mDq1KkAgMcffxzLli3D0aNHMWzY\nMCxatMjsx4icnFmeekFkIfv37xejo6Ml3//666/FefPmibW1taJWqxXnzp2rf8DK8uXLxa1bt4qi\nWP+QoDVr1oiiKIpVVVWiRqPR/z1x4kTx4sWLTbZ52LJly8QPP/ywyfLy8nIxNDRUPHv2rCiKoqjT\n6cSKigpRo9GIY8eOFXNyckRRFMUffvhBHDt2rKjVakWVSiVOmzZNv48HX6tUKjE4OFjMz88XRVEU\n//SnP4nvvPNOk/VEsf5BMzExMeKlS5dEUax/sE5ERIT+db9+/cTPP/9c8vsjMoVXEmTXRBOzxuTk\n5CAmJgYuLi4AgKlTp+LQoUN44YUXJLepqanBu+++i4KCAigUCqjVahQUFCAgIEBym40bN6Kmpga/\n/e1vm7x39uxZBAYGYtCgQQDqp2X28vLC+fPn4ebmpp9ELSQkBG5ubrh8+bLJevfp0wePP/44gPqH\n+xw9etTgeleuXMGlS5cQFxen/660Wi0uXryIPn36AID+IUFELcGQILsWHByMK1euoLKyEl5eXk3e\nFw3Mh29qfvz169dDqVRi7dq1+gdAGevf2LlzJ06ePKl//rOhMshd3lBeFxeXRk8nvH//fqP13N3d\n9X+7uLjon1FsaH/e3t76mT0fJggC2rdvb/A9IjnYcU12rVevXhg3bhwSExP1U6zX1dVhy5YtqKmp\nwciRI5Geno7a2lpotVpkZGQgNDTU6D4rKyvRrVs3CIKA8+fPIzc3V/+eh4cH7t69q3/9ww8/4LPP\nPsMf//hHuLm5Gdzfk08+iQsXLuDHH38EUN9PUlFRgYCAAGi1Wpw6dQoAcPLkSdTW1qJ3797o0aMH\nCgsLUVlZCVEUsW/fPlnfh6enp35qaKD+iqNt27bIzMzUL7t06ZL+uzJ1JUZkCq8kyO6tWbMGn376\nKaZOnQo3NzeIoojRo0fDzc0NM2fOxNWrV/XP7X3mmWf0ndpSfvOb3yA+Ph67d+9Gz549MWzYMP17\nU6ZMwYoVK3DgwAG88sor2LlzJ2pqajB//nz9VUBaWlqjX+cdO3bEpk2bkJSUhOrqari4uCA+Ph4h\nISH45JNPsHr1an3H9aeffgpXV1f4+vpi3rx5iImJgb+/P5544glcuHDB5HfRr18/9OnTB5GRkQgI\nCEBycjL++7//G++//z42b96Muro6+Pj4YOPGjQDs/6lqZP84VTgREUlicxMREUliSBARkSSGBBER\nSWJIEBGRJIYEERFJYkgQEZEkhgQREUliSBARkaT/AzLfG+oMx+5pAAAAAElFTkSuQmCC\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7fc3b259add0\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "counts = []\n", + "times = []\n", + "for n in np.logspace(0, 7, 50):\n", + "\n", + " with tf.Graph().as_default():\n", + " tf_collatz = ag.to_graph(collatz)\n", + " count = tf_collatz(tf.constant(n, dtype=tf.float32))\n", + " with tf.Session() as sess:\n", + " count_value = sess.run(count)\n", + "\n", + " res = %timeit -n10 -r1 -o -q sess.run(count)\n", + " counts.append(count_value)\n", + " times.append(res.best)\n", + " \n", + "plot_results(counts, times, 'AutoGraph')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "RRENYzLRF_f3" + }, + "source": [ + "# Eager" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 301 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 5003, + "status": "ok", + "timestamp": 1531757478713, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "dhDf8LLdF_f-", + "outputId": "3de0a5a5-7a11-4b41-8ab0-e4e21ce8d59b" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAEcCAYAAAAydkhNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XtYVWW+B/Dv2hshBdSQHaighhwas7Qeb6GFDjIyI3LZ\nGphdJLLMzqSlKaPQsTPm5KhZkZ7moKOnManGK17wsUfIS87QNj2jnEnIg5cQEtyAyDWBvdf5g4d9\nBPbaLGCvfeP7+QvW2mvt38tGvq71vut9BVEURRAREZmhsncBRETkuBgSREQkiSFBRESSGBJERCSJ\nIUFERJIYEkREJIkhQUREktzsXQCRvYWHh6OiogJqtRqiKEIQBMyePRtvv/22vUsjsjuGBBGA9PR0\nPPHEE3Z5b4PBALVabZf3JuoMbzcRATA38cCNGzeQmJiISZMmITQ0FMuXL0dtba1p//fffw+tVotx\n48bhjTfewNKlS5GWlmbaf+LECcTFxWHChAmYN28efvjhB9O+8PBwbNu2DTExMXj88cdhNBqVbSBR\nNzEkiCSIoohFixbhb3/7G44ePYqysjJs3rwZANDU1ITFixdjzpw5OHv2LGbNmoXjx4+bjv3++++R\nmpqKd999F2fPnsXcuXPx2muvoampyfSao0ePYtu2bTh37hxUKv5TJMfE30wiAL/97W8xceJETJgw\nARMnTsSePXswbNgwhIaGws3NDffffz8SExPx3XffAQAuXLgAg8GA559/Hmq1Gr/61a8wZswY0/n2\n7NmDZ555Bo8++igEQUBcXBzc3d1x8eJF02vmz58PPz8/uLu727y9RHKxT4IIwCeffNKhT6KyshJr\n167FuXPnUF9fD4PBgIEDBwIA9Ho9/Pz82rx+8ODBpq9/+uknHDx4ELt27QLQclXS3NyMW7dumV7j\n7++vVHOIrIYhQQTzfRKbNm2CIAg4cuQI+vfvj+zsbKxduxYAoNFoUFZW1ub1N2/exLBhwwC0BMCi\nRYvw6quvKl88kYJ4u4lIQl1dHTw9PeHl5YWysjJs377dtO+xxx6DWq1GRkYGDAYDsrOzkZeXZ9qf\nkJCAL7/80rStvr4ep06dQn19vc3bQdQTvJIgAvDaa69BpVKZnpOYMmUK3nzzTaxYsQLjx4/H8OHD\nERsbi08//RQA0KdPH2zevBmpqanYtGkTwsLCEB4ebupfeOSRR/Duu+9izZo1KCoqgoeHB8aNG4cJ\nEyYAAARBsFdTibpEUHLRocbGRjz33HNoamqCwWBAZGQkXn/9dRQXF2PZsmW4c+cORo8ejQ0bNsDN\njXlFzi0hIQHz5s2DVqu1dylEVqPo7SZ3d3fs3LkTmZmZyMzMxOnTp3Hx4kW8//77SEpKwldffQVv\nb2/s3btXyTKIFPHdd9+hvLwcBoMBBw4cwOXLl/HUU0/Zuywiq1K8T6Jv374AWq4qmpubIQgCdDod\nIiMjAQBarbbN+HIiZ3Ht2jXExsZi/Pjx+PTTT/Hxxx/D19fX3mURWZXi93iMRiNmz56NoqIiPPfc\ncwgMDET//v1NDw/5+/u3GRZI5CwSEhKQkJBg7zKIFKX4lYRKpTLdasrLy8OVK1c6vIadeEREjslm\nQ2C9vLwwYcIEXLx4EdXV1aa5akpLS/HAAw90eryC/etERCRB0dtNlZWV6NOnD7y9vfHzzz8jNzcX\nCxcuxKRJk3Ds2DHMnDkTBw4cwPTp0zs9lyAI0OtrlCzXrjQab7bPSbly2wC2z9lpNN49Ol7RkNDr\n9Vi5ciWMRiOMRiNmzpyJqVOnIigoCMuWLUNaWhpGjRqFp59+WskyiIiomxR9TsLaXD3t2T7n5Mpt\nA9g+Z9fTKwlOy0FERJIYEkREJIkhQUREkhgSREQkiSFBRESSGBJERCSJIUFERJIYEkREJIkhQURE\nkhgSREQkiSFBRESSGBJERCSJIUFERJIYEkREJIkhQUREkhgSREQkiSFBRESSGBJERCSJIUFERJIY\nEkREJIkhQUREkhgSREQkiSFBRESSGBJERCSJIUFERJIYEkREJIkhQUREkhgSREQkyU3Jk5eWliI5\nORnl5eVQq9VISEjACy+8gC1btmD37t0YNGgQAGDp0qUICwtTshQiIuoGRUNCrVZj1apVGDVqFOrq\n6jB79mxMnjwZAJCUlISkpCQl356IiHpI0ZDQaDTQaDQAAE9PT4wcORK3bt0CAIiiqORbExGRFdis\nT6K4uBgFBQUYM2YMACAjIwOxsbFITU1FTU2NrcogIqIusElI1NXVYcmSJUhJSYGnpyeeffZZZGdn\n4+DBg/D19cW6detsUQYREXWRICp836e5uRmvvvoqwsLCkJiY2GF/SUkJFi1ahMOHDytZBhERdYOi\nfRIAkJKSguDg4DYBodfrTX0Vx48fR0hIiKxz6fWue1tKo/Fm+5yUK7cNYPucnUbj3aPjFQ2J8+fP\n4/DhwwgJCUFcXBwEQcDSpUtx5MgR5OfnQ6VSYejQoVizZo2SZRARUTcpGhLjxo1Dfn5+h+18JoKI\nyDnwiWsiIpLEkCAiIkkMCSIiksSQICIiSQwJIiKSxJAgIiJJDAkiIpLEkCAiIkkMCSIiksSQICIi\nSQwJIiKSxJAgIiJJDAkiIpLEkCAiIkkMCSIiksSQICIiSQwJIiKSxJAgIiJJspcv/fnnn6HX6+Hh\n4YEHHnhAyZqIiMhBWAwJo9GIzMxM7NmzBwUFBfDy8kJjYyPc3NwQERGBF198EQ8++KCtaiUiIhuz\nGBLz5s3DY489hlWrVmH06NFQq9UAgIqKCnzzzTdYvXo1nnnmGURFRdmkWCIisi1BFEVRamdlZSV8\nfHwsnkDOa6xFr6+xyfvYg0bjzfY5KVduG8D2OTuNxrtHx1vsuDb3x7+iogIXLlyw+BoiInINskY3\nPfvss6ipqUF1dTXi4uKQmpqK9evXK10bERHZmayQqK+vh7e3N06cOIHo6GgcPnwYZ86cUbo2IiKy\nM1kh0djYCADQ6XSYPHkyVCqVqRObiIhcl6yQmDhxIiIjI3Hu3DlMnDgR1dXVUKn4HB4RkauT9TDd\nO++8g4KCAgQGBsLd3R21tbVYu3at0rUREZGdWQyJwsJC09d9+vRBaWmp6Xt3d3flqiIiIodgMSQW\nLlwIQRAgiiJu3rwJLy8vAEBtbS0GDx6Mr7/+2uLJS0tLkZycjPLycqjVasTHx2P+/Pm4c+cOli5d\nipKSEgQEBOCjjz6Ct3fPxvISEZH1WQyJ1hBYu3Ytxo0bh9/85jcAgGPHjuHSpUudnlytVmPVqlUY\nNWoU6urqMHv2bEyZMgX79+9HaGgoXnnlFWzduhXp6elYvny5FZpDRETWJKv3OS8vzxQQAPDrX/8a\n3377bafHaTQajBo1CgDg6emJkSNHoqysDDk5OdBqtQAArVaL7Ozs7tROREQKkxUSDQ0NOHfunOn7\nc+fOoaGhoUtvVFxcjIKCAowdOxYVFRXw9fUF0BIkt2/f7tK5iIjINmSPblq2bBn69u0LALh79y42\nbdok+03q6uqwZMkSpKSkwNPTE4IgdKvYns5B4ujYPuflym0D2D65Tv+jGHty/hdFZTUY5ueN+On/\ngrDHA6xybnuRFRLjx49HdnY2rl27BlEUERQUJHt0U3NzM5YsWYLY2FhEREQAAAYNGoTy8nL4+vpC\nr9fLnv/J1SfhYvuckyu3DWD75NJdKkP6oe9N31+/WY2Nu86juvpnTHrYr8fn7y5FJ/i7l8FggLu7\nO9zc3FBUVNRmeKwlKSkpCA4ORmJiomlbeHg49u/fDwA4cOAApk+f3sWyiYgcS1budYntP9q0DmuT\ndSWRkZGB999/HwMHDjTdKhIEATk5ORaPO3/+PA4fPoyQkBDExcVBEAQsXboUr7zyCt58803s27cP\nQ4YMQVpaWs9bQkRkRz+V15vdfrOizsaVWJeskNixYweOHDmCoUOHdunk48aNQ35+vtl9n376aZfO\nRUTkyIb49kOxvmMgDB7kaYdqrEfW7SaNRtPlgCAi6k2iQkdIbB9u20KsTNaVxOTJk7FhwwZERUXB\nw8PDtD04OFixwoiIHInuUhmycq/jp/J6DPHth6jQEW06pFu/zsr9ETcr6jB4kCeiQofbtdPaGmSF\nRGZmJoCWJ61byemTICJyBe1HLhXr60zftw8KZw+F9mSFRGdzNBERuTJLI5dcLRTakxUSQMuMsDqd\nDgDwxBNPYOTIkYoVRUTUXZ3dFuoOVx25JIesjuvMzEy8+OKLyM/PR35+PpKSknDo0CGlayMi6pLW\n20LF+joYRdF0W0h3qaxH5x3i28/sdmcfuSSH7CGwBw4cgEajAQDo9XosWLAAMTExihZHRNQVSt0W\nigod0aZP4v+3O/fIJTlk325qDYj2XxMROQqlbgu56sglOWSFxLBhw/Dxxx9j7ty5EAQBu3fvRmBg\noNK1ERF1iZIPtLniyCU5ZPVJ/P73v8e1a9cQExODmJgYXL16FWvWrFG6NiKiLnHVB9rsSdaVxKBB\ng/Dhhx8qXQsRUY/05ttCSpEVElu3bkVCQgIGDhwIALh9+zb27duHl19+WdHiiIi6qrfeFlKKrNtN\nWVlZpoAAgPvvvx9HjhxRrCgiInIMskJCFMUO2wwGg9WLISIixyIrJEaMGIH/+q//giiKMBqN2LFj\nB4YNG6Z0bUREZGeyQiI1NRUnTpzAmDFj8Nhjj+HUqVNYvXq10rUREZGdyeq49vPzw86dO1Ff3/Kg\nSr9+5h9RJyIi1yK7T2LPnj345JNP0K9fPxQXF+O///u/la6NiIjsTFZIrFu3Dt9++y2ys7MBAJ6e\nnnjvvfcULYyIiOxP1u0mnU6HzMxMaLVaAC1DYO/evatoYUREligxJTh1JCskPDw8IAiC6Xuj0ahY\nQUREnZG7Uhz1nKzbTSEhITh06BBEUURxcTH+/d//HePGjVO6NiIisyxNCU7WJSskVq5cibNnz0Kv\n1yM+Ph4GgwErVqxQujYiIrN680pxtibrdpOXlxfWrl2rdC1ERLIoOSU4tSXrSuLo0aOora0FAKSl\npWHBggX45z//qWhhRES6S2VYvV2Hl9efwOrtOtMypJwS3HZkhcSf/vQneHl5IS8vD2fOnEFcXByv\nLIhIUZbWq570sB9ejRmNAI0X1CoBARovvBozmp3WCpB1u8nNreVlf/vb3xAfH4/o6Gjs2LFD0cKI\nqHfrbL1qTgluG7KuJARBwKFDh5CVlYXQ0FAAQFNTk6KFEVHvxs5pxyArJN5++20cO3YM8fHxCAwM\nxPXr1zFp0qROj0tJScHkyZMRHR1t2rZlyxaEhYVBq9VCq9Xi9OnT3a+eiFzWEF/zc8Sxc9q2BNHc\nYhFWcu7cOXh6eiI5ORmHDx8G0BISnp6eSEpK6vL59Poaa5foMDQab7bPSbly2wD7ta/9A3OtrN33\n0Bs+v56w2Cfxl7/8BYmJidiwYYPZ/cnJyRZPPn78eJSUlHTYrmAuEZGD6uo0Glyv2jFYDAkPDw8A\n1p8aPCMjAwcPHsQjjzyClStXwtu7Z0lHRI6tu9NosHPa/hS93QQAJSUlWLRokel2U2VlJe6//34I\ngoAPP/wQer2eM8oSubjF75/A9ZvVHbaPGNwfm5f/0g4VkVwWryQyMjIsHvzcc891+Q19fHxMXyck\nJGDRokWyj3X1+4Zsn3Ny5bYB1mlfUan542+U1dj9Z9cbPr+esBgS1niquv2Fil6vh0ajAQAcP34c\nISEhPX4PInJsnEbDeVkMiXXr1vXo5G+99RZ0Oh2qqqowbdo0LF68GDqdDvn5+VCpVBg6dCjWrFnT\no/cgIscXFTrC7EglTqPh+CyGxKlTpywePHXqVIv7N23a1GHbnDlzZJRFRK6EI5Wcl8WQ+POf/yy5\nTxCETkOCiKgVRyo5J4sh8dlnn9mqDiIickAWQ+LGjRsIDAxEYWGh2f3BwcGKFEVERI7BYkisXbsW\n6enpWLhwYYd9giAgJydHscKIiMj+LIZEeno6AODrr7+2STFERORYZK0nAQANDQ0oLS2FwWAwbePt\nJiLH1NV5koikyAqJnTt34sMPP8SAAQOgUrXMLs7bTUSOqbvzJBGZIysk/vKXv+DYsWPw8+MvGJGj\n62xFN6KukLXokL+/PwOCyElIrehWrK/F6u066C6V2bgicmayriQWL16M1NRUTJ061TR9OND5E9dE\nJM3a/Qat5zNamNiZt56oq2SFxIkTJ3DixAlcv369TZ8EQ4Koe6zdbyC1ipsU3noiuWSFxPHjx/H1\n11/jvvvuU7oeol6hJ/0G5q5ApM4n5WZFxxlZicyRFRKBgYFwc5M9WpaIOiHVb9DZH2+pKxBB6Nr7\nc4pukkvWX/7hw4cjMTERERERcHd3N23vzqJDRNT5+gpS/RVSVwxuKhWaDMYO2328PVBZc7fDdk7R\nTXLJCommpiYMGzYMly9fVroeol7B0voKlvorpK5Amo0dAwIA4n/Z8sArp+im7pIVEj1dfIiI2rK0\nvsLq7Tqzx6Qf+h591AKMho77hvp6ISp0uGQYMBSouzpdvvSRRx6R3N/Y2IgbN25g5MiRVi+MyNVJ\nra8gdbUAAE0G88NbWwOBYUDW1ukEfw0NDZg1axbGjh0LX19f3L17F9euXcM333yDU6dOYeXKlQwJ\nIiuS6q+4Vx+1CkZR5O0jUpzFkNi8eTPy8vLw17/+Ff/xH/+B0tJS9O3bFyEhIYiIiEBGRga8vLxs\nVStRryDVX3EvoyhiW/IvbVQR9Wad9kmMGTMGY8aMsUUtRISW21B7ThSaHZXUikNYyVZkzd1ERLbV\nOipJCoewkq3wCTkiO+hs3qZJD/uhsOQOcs4Xdzh2+rgA9kGQzTAkiGxM7rxNz/0qBMFDB/AZB7Ir\nhgSRjXVl3iYOayV7k9UnUVFRgeXLl5um4SgoKMAXX3yhaGFErqq78zYR2YOskHj77bcxbtw4VFdX\nAwCCgoLw+eefK1oYkasa4tvP7HaOWCJHJCskysrKMG/ePKjVagCAu7u7aV0JIuqaqNAREts5Yokc\nj6w+ifbThFdXV0O0sPoVUW/TlVXmLM3bRORoZIXEjBkzsHr1atTV1WH//v34/PPPMWfOnE6PS0lJ\nwcmTJzFo0CAcPnwYAHDnzh0sXboUJSUlCAgIwEcffQRvb++etYLIjk7/o7jLq8yxQ5qchax7Ri+/\n/DLGjx+P0aNH49SpU3jhhReQmJjY6XGzZ8/G9u3b22zbunUrQkND8dVXX2HSpElIT0/vXuVEDmJP\nzv+a3Z6V+6ONKyGyPtlDYGNiYhATE9Olk48fPx4lJSVttuXk5GDXrl0AAK1WixdeeAHLly/v0nmJ\nHElRWY3Z7RytRK5AVkhUVFRg165dKCoqQnNzs2l7Wlpal9+wsrISvr6+AACNRoPbt293+RxE1tCV\nfgRLrx3m543rN6s7HMPRSuQKZIXEv/7rv+Lhhx9GaGioaYSTPWg0rt13wfbZjlQ/wtbD32O4f3/E\nT/8XhD0eYPG1/fvfh7DHAxA//V+wcdf5Du8xL/Ihh2pzT7hKO6S4evt6QlZINDQ04J133rHKGw4a\nNAjl5eXw9fWFXq+Hj4+P7GP1evOX9a5Ao/Fm+2zoi68KzG4XReD6zWps3HUe2w/+E/G/DJZ8QvqL\nr37AqIABCHs8ANXVP3cYrTQqYIBDtbm7HO2zs7be0L6ekBUSY8eOxQ8//ICHHnqoy2/QfqhseHg4\n9u/fj4ULF+LAgQOYPn16l89J1FOWVn9rVVlzF+mHvocgmN9/b58DRyuRq5IVEs888wyef/55+Pv7\nw8PDw7R97969Fo976623oNPpUFVVhWnTpmHx4sVYuHAh3njjDezbtw9DhgzpVr8GUU/JWf2tlZtK\nhSaDscN29jlQbyArJFasWIFFixbh4Ycf7lKfxKZNm8xu//TTT2Wfg8iaWjugS8rljzxqNnYMCIBP\nSFPvICskPDw8sGDBAqVrIeoWuaOU2k/RLddQXy9EhQ7nE9LUK8kKiaeeegqnT59GWFiY0vUQdYnc\ntRkA6Sm6AzQtIbDnZCEqqzsuGdoaCAwF6o1khcTu3buxdetWeHp6wt3dHaIoQhAE5ObmKl0fkUWW\n1mZo3d96hSF1i+lmRZ0pBFquSnjFQNRKVkjs27dP6TqIukVqlFJJeW2HKwwp93ZA84qBqC1ZITF0\n6FCl6yDqFqlRSlIjksxhBzSRNIshsWLFCmzcuBFz5syBYGaweGdDYImUcG9H9UAvd7OvkRqRJAgt\nHdG8nUQkj8WQaJ3p9Xe/+51NiiEyp30oVNb8f+dy69c+3h64U9do+sOflXvd7BXGUF8vrFkw0UaV\nEzk/iyHx+eef47333sPEifxHRfbRfvTSvQFxr3739cH7v53SZpu54a68tUTUNRZDIj8/31Z1EJm1\n50ShrNe1n5abq78RWYfs9SSIlNT+gbiHht2PH4puS145tGduigyOVCLqOYshcfnyZYSGhnbYzuck\nyJrMPRAnd16lVryNRKQMiyExYsQIbN261Va1UC8l9UCcHH3UKrwUNYpXDEQKsRgS7u7ufEaCFCdn\n2m4pDAgiZaks7ezTp4+t6qBebIhvP9mv7aNWQSW0zLf0asxoBgSRwixeSezevdtWdVAvFhU6Qtbs\nrAwFItvj6Cayu9Y//FKzsPr090D8tGAGBJEdMCTI6syt7zBrquV1djkLK5FjEsT2i1A7MFdfrNzZ\n2mcuDADzTzq3zJnkKbkgkDNzxs+uK9g+56bRWP4PWmd4JUHdIrXYj4+3h9nXi6LlBYGIyDFZHN1E\nJEXq2QY5T0i3LghERI6PIUHd0pNnG9rPs0REjou3m0iW9v0P/e5zQ21DU7fOZW6eJSJyTAwJ6pS5\n/oee4DxLRM6DIUGd6sncSq1UAjDE14tDWomcDEOCOtWd/of2K8XNmhrs0sMMiVwVQ4I6NcS3n+xb\nTHw6msi1MCSoU3LmVgrQ8FYSkStiSFCnTHMrnSg0+xwEJ94jcl12C4nw8HB4eXlBpVLBzc0Ne/fu\ntVcpJAPnViLqnewWEoIg4LPPPsOAAQPsVUKvZG6+pa78kee60US9i91CQhRFGI1Ge719ryQ13xLA\nuZSIyDy7XkksWLAAgiBg7ty5SEhIsFcpLkfqakHqeYes3B8ZEkRklt1C4ssvv4RGo0FlZSWSkpIQ\nFBSE8ePH26sclyF1tVBYckfyeQfOpUREUhxiPYktW7bA09MTSUlJ9i7F6S1+/wSu36w2u893YF+U\nVzV02D5icH9sXv5LpUsjIidklyuJhoYGGI1GeHp6or6+HmfOnMHrr7/e6XGu/MSutRY+KSqVPodU\nH1DkhEDFf7auvLCLK7cNYPucnVMuOlReXo7XX38dgiDAYDAgOjoaTz75pD1KcTmWno6+U9uIV2NG\ncwgrEclml5AIDAzEwYMH7fHWLs/S09GDB3lyCCsRdQkXHXIxkx72w/RxAWb3cYpuIuoqTsvhgp77\nVQiChw7gbSUi6jGGhIvibSUisgbebiIiIkkMCSIiksSQICIiSeyTsKGezsBKRGRrDAkbyTh+GTnn\ni03fcwZWInIGDAmF6S6VSa7oBnAGViJybAwJBbWfkdUczsBKRI6MIWEl5vobpNZvuNfgQZ5Kl0ZE\n1G0MCSuQWsNBEDo/llNlEJEj4xBYK5C6YnBTWf7xTh8XwP4IInJovJKwAqkV35ol1m/w8fZA/C+D\nGRBE5PAYElYgtYbDUF8vRIUO50R7ROS0GBKdkPMAnNQaDq2BwFAgImfFkLBAqkMaaPsAXOvXvGIg\nIlfDkGjn3isHtUS/s7kH4HjFQESuiCFxj/ZXDkaD+dfxATgi6i04BPYech5+A/gAHBH1HgyJe0gN\nZW2PD8ARUW/B2033kBrK2ketglEU2SFNRL0OQ+IeUkNZX4oaxWAgol6JIXEPDmUlImqLIdEOh7IS\nEf0/dlwTEZEkp76S4JrRRETKctqQkDtlBhERdZ/dbjedPn0av/71rxEZGYmtW7d2+XipB9+ycn/s\nWWFERGRil5AwGo149913sX37dhw5cgRZWVm4cuVKl84h9eAbp8wgIrIeu4REXl4ehg8fjqFDh6JP\nnz6IiopCTk6OxWNiVxzC6u066C6VAWh58M0cTplBRGQ9dgmJsrIyDB482PS9n58fbt26ZfEYo1E0\n9TvoLpUhKnSE2ddxygwiIuuxS0iIotij41un6n41ZjQCNF5QqwQEaLzwasxodloTEVmRXUY3+fv7\n46effjJ9X1ZWhgceeED28Tcr6qDReGPWVG/MmhqsRIl2odF427sERbly+1y5bQDb15vZJSQeffRR\nFBUVoaSkBBqNBllZWfjggw8sHnN4U6yNqiMiolZ2CQm1Wo1/+7d/w0svvQRRFPH0009j5MiR9iiF\niIgsEMSedhAQEZHL4txNREQkiSFBRESSGBJERCTJ4UOip3M8OaLw8HDExMQgLi4OTz/9NADgzp07\neOmllxAZGYkFCxagpqbGzlXKl5KSgsmTJyM6Otq0zVJ71q5dixkzZiA2Nhb5+fn2KLlLzLVvy5Yt\nCAsLg1arhVarxenTp0370tPTMWPGDPzmN7/BmTNn7FGybKWlpZg/fz5mzpyJ6Oho7Ny5E4DrfH7t\n2/fZZ58BcJ3Pr7GxEfHx8YiLi0N0dDS2bNkCACguLkZCQgIiIyOxbNkyNDc3m16/dOlSzJgxA3Pn\nzm3zKIIk0YEZDAYxIiJCLC4uFhsbG8WYmBixsLDQ3mX1WHh4uFhVVdVm24YNG8StW7eKoiiK6enp\n4saNG+1RWrd899134qVLl8RZs2aZtkm15+TJk+Irr7wiiqIoXrhwQYyPj7d9wV1krn2bN28Wd+zY\n0eG1hYWFYmxsrNjU1CTeuHFDjIiIEI1Goy3L7ZJbt26Jly5dEkVRFGtra8UZM2aIhYWFLvP5SbXP\nVT4/URTF+vp6URRFsbm5WYyPjxcvXLggvvHGG+LRo0dFURTF1atXi1988YUoiqKYkZEhvvPOO6Io\nimJWVpb45ptvdnp+h76S6M4cT85AFEUYjcY223JycqDVagEAWq0W2dnZ9iitW8aPH4/+/fu32da+\nPa2fW04mDIcbAAAJLUlEQVRODuLi4gAAY8eORU1NDcrLy21bcBeZax9gfuaAnJwczJw5E25ubggI\nCMDw4cORl5dnizK7RaPRYNSoUQAAT09PjBw5EmVlZS7z+ZlrX+sUQK7w+QFA3759AbRcJTQ3N0MQ\nBOh0OkRGRgJo+/fk3s81MjISubm5nZ7foUOiO3M8OQNBELBgwQLMmTMHe/bsAQBUVFTA19cXQMsv\n9u3bt+1ZYo9VVla2aU9lZSUA4NatW/D39ze9zs/PD2VlZXapsacyMjIQGxuL1NRU0+0Yc7+zztK+\n4uJiFBQUYOzYsR1+H13h82tt35gxYwC4zudnNBoRFxeHKVOmYMqUKQgMDET//v2hUrX8eff39ze1\n4d7PT61Wo3///qiqqrJ4focOCXNJ7wq+/PJL7N+/H9u2bUNGRgbOnTsHQRDsXZZNmPtMnbHtzz77\nLLKzs3Hw4EH4+vrij3/8IwDnbV9dXR2WLFmClJQUeHp6StbsKu1zpc9PpVIhMzMTp0+fRl5entll\nF1rb0L59oih22j6HDomezvHkqDQaDQDAx8cHERERyMvLw6BBg0yX7Xq9Hj4+PvYsscek2uPn54fS\n0lLT60pLS53yM/Xx8TH940pISDDdkvD398fNmzdNr3OG9jU3N2PJkiWIjY1FREQEANf6/My1z5U+\nv1ZeXl6YMGECLl68iOrqatMt7XvbcO/nZzAYUFtbiwEDBlg8r0OHxL1zPDU2NiIrKwvTp0+3d1k9\n0tDQgLq6loWR6uvrcebMGYSEhCA8PBz79+8HABw4cMDp2tn+fyhS7Zk+fToyMzMBABcuXED//v1N\ntzUcWfv26fV609fHjx9HSEgIgJZ2Hz16FI2Njbhx4waKiopMtzccVUpKCoKDg5GYmGja5kqfn7n2\nucrnV1lZabpV9vPPPyM3NxfBwcGYNGkSjh07BqDt5xceHo4DBw4AAI4dO4Ynnnii0/dw+Gk5Tp8+\njT/84Q+mOZ4WLlxo75J65MaNG3j99dchCAIMBgOio6OxcOFCVFVV4c0338TNmzcxZMgQpKWlme0s\ndURvvfUWdDodqqqq4Ovri8WLFyMiIgJvvPGG2fasWbMG33zzDfr27Yt169Zh9OjRdm6BZebap9Pp\nkJ+fD5VKhaFDh2LNmjWmP5bp6enYu3cv3NzckJqaiieffNLOLZB2/vx5PP/88wgJCYEgCBAEAUuX\nLsWYMWMkfx+d6fOTat+RI0dc4vP74YcfsHLlShiNRhiNRsycOROvvfYabty4gWXLlqG6uhqjRo3C\nxo0b0adPHzQ2NmLFihXIz8/HwIED8cEHHyAgIMDiezh8SBARkf049O0mIiKyL4YEERFJYkgQEZEk\nhgQREUliSBARkSSGBBERSWJIkMNrbm5GWloaIiMjER0djaioKKxfvx4Gg8HicatWrUJGRgaAlqmh\nN2zY0Ol7ZWdn43/+53+sUrcSSkpKsHv3bnuXQb0IQ4Ic3sqVK3HlyhVkZmbi8OHDOHToEIKCgtDY\n2Gj198rJyXHoWT+Li4vx17/+tVvHdhaqROa42bsAIkt+/PFH5OTkmJ7wBVpmr4yPjwfQMgPmxo0b\nTYvDPPnkk0hOTrY4adnly5fx+9//Hg0NDWhsbERCQgLmz5+PM2fO4Ouvv0Zubi727t2LF198EcXF\nxTh+/DgEQUBjYyOuXr2K7777Dl5eXm3O+Y9//AMbN25EXV0dBEFAcnIyJk+ejLy8PLz33ntoaGhA\n3759kZqaikcffRRnz57F+vXrsW/fPgBo8/3Zs2fx3nvvYcyYMbhw4QJUKhU++OADBAUF4d1330VJ\nSQm0Wi2GDRuGtLQ0XL16FevWrUNVVRWampowf/58zJ49GwDwi1/8AitWrMDJkycxYcIELFmyxOqf\nEbk4q6x6QaSQo0ePinFxcZL7P//8czEpKUlsbm4Wm5qaxMTERNMCKytXrhR37dolimLLIkHr168X\nRVEU6+rqxMbGRtPXM2fOFK9cudLhmPZWrFgh/vGPf+ywvaqqSpwyZYp44cIFURRF0Wg0itXV1WJj\nY6M4bdo0MTc3VxRFUfz73/8uTps2TWxqahJ1Op04Z84c0znu/V6n04mjR48W8/PzRVEUxT/96U/i\n8uXLO7xOFFsWmtFqteLVq1dFUWxZWCcyMtL0/UMPPST++c9/lvz5EXWGVxLk0MROZo3Jzc2FVquF\nWq0GAMyePRvZ2dl45plnJI9paGjAO++8g4KCAqhUKuj1ehQUFCAoKEjymI8++ggNDQ343e9+12Hf\nhQsXEBwcjLFjxwJomZbZ29sbly9fhru7u2kStdDQULi7u+PatWudtvvBBx/EL37xCwAti/ucPHnS\n7OuuX7+Oq1evYtmyZaafVVNTE65cuYIHH3wQAEyLBBF1B0OCHNro0aNx/fp11NTUwNvbu8N+0cx8\n+J3Nj//BBx9Ao9Fgw4YNpgWgLPVv7Nu3D99++61p/WdzNcjd3lqvWq1uszrh3bt327zOw8PD9LVa\nrTatUWzufD4+PqaZPdsTBAH9+vUzu49IDnZck0MbPnw4wsPDsXr1atMU6waDATt37kRDQwMmT56M\nAwcOoLm5GU1NTcjMzMSUKVMsnrOmpgaDBw+GIAi4fPkyzp07Z9rn6emJ2tpa0/d///vfsW3bNnzy\nySdwd3c3e77HH38chYWFuHjxIoCWfpLq6moEBQWhqakJZ8+eBQB8++23aG5uxogRIxAQEIDi4mLU\n1NRAFEVkZWXJ+nl4eXmZpoYGWq447rvvPhw8eNC07erVq6afVWdXYkSd4ZUEObz169dj8+bNmD17\nNtzd3SGKIsLCwuDu7o65c+eiqKjItG7vU089ZerUlvLaa68hOTkZhw4dwrBhwzBhwgTTvtjYWKxa\ntQrHjh3Diy++iH379qGhoQELFiwwXQVkZGS0+d/5gAEDsGXLFqxbtw719fVQq9VITk5GaGgoPv74\nY6xdu9bUcb1582a4ubnBz88PSUlJ0Gq1CAwMxKOPPorCwsJOfxYPPfQQHnzwQURHRyMoKAhpaWn4\nz//8T/zhD3/Ajh07YDAY4Ovri48++giA46+qRo6PU4UTEZEk3m4iIiJJDAkiIpLEkCAiIkkMCSIi\nksSQICIiSQwJIiKSxJAgIiJJDAkiIpL0f3zF2/hGE4QYAAAAAElFTkSuQmCC\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7fc3af690a50\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "with context.eager_mode():\n", + "\n", + " counts = []\n", + " times = [] \n", + " for n in np.logspace(0, 7, 50):\n", + "\n", + " n_tensor = tf.constant(n, dtype=tf.float32)\n", + " count = collatz(n_tensor)\n", + "\n", + " res = %timeit -n10 -r1 -o -q collatz(n_tensor)\n", + " times.append(res.best)\n", + " counts.append(count)\n", + " \n", + "plot_results(counts, times, 'Eager')\n" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "x5ChBlH09jk_", + "_cRFTcwT9mnn" + ], + "default_view": {}, + "last_runtime": { + "build_target": "", + "kind": "local" + }, + "name": "Autograph vs. Eager Collatz speed test", + "provenance": [ + { + "file_id": "0B8bm7KvwJklpMUQtbnVpYkdJUjRtOTRyWVVfSEhpRl9HYm5n", + "timestamp": 1531512047714 + } + ], + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_mnist_speed_test.ipynb b/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_mnist_speed_test.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..952ec091fb1883e4f17314efa8c458bfe7f01eda --- /dev/null +++ b/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_mnist_speed_test.ipynb @@ -0,0 +1,652 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "etTmZVFN8fYO" + }, + "source": [ + "This notebook runs a basic speed test for a short training loop of a neural network training on the MNIST dataset." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eqOvRhOz8SWs" + }, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "nHY0tntRizGb" + }, + "outputs": [], + "source": [ + "!pip install -U -q tf-nightly" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "Pa2qpEmoVOGe" + }, + "outputs": [], + "source": [ + "import gzip\n", + "import os\n", + "import shutil\n", + "import time\n", + "\n", + "import numpy as np\n", + "import six\n", + "from six.moves import urllib\n", + "import tensorflow as tf\n", + "\n", + "from tensorflow.contrib import autograph as ag\n", + "from tensorflow.contrib.eager.python import tfe\n", + "from tensorflow.python.eager import context\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PZWxEJFM9A7b" + }, + "source": [ + "### Testing boilerplate" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "kfZk9EFZ5TeQ" + }, + "outputs": [], + "source": [ + "# Test-only parameters. Test checks successful completion not correctness. \n", + "burn_ins = 1\n", + "trials = 1\n", + "max_steps = 2\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "k0GKbZBJ9Gt9" + }, + "source": [ + "### Speed test configuration" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "gWXV8WHn43iZ" + }, + "outputs": [], + "source": [ + "#@test {\"skip\": true} \n", + "burn_ins = 3\n", + "trials = 10\n", + "max_steps = 500\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "kZV_3pGy8033" + }, + "source": [ + "### Data source setup" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "YfnHJbBOBKae" + }, + "outputs": [], + "source": [ + "def download(directory, filename):\n", + " filepath = os.path.join(directory, filename)\n", + " if tf.gfile.Exists(filepath):\n", + " return filepath\n", + " if not tf.gfile.Exists(directory):\n", + " tf.gfile.MakeDirs(directory)\n", + " url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n", + " zipped_filepath = filepath + '.gz'\n", + " print('Downloading %s to %s' % (url, zipped_filepath))\n", + " urllib.request.urlretrieve(url, zipped_filepath)\n", + " with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:\n", + " shutil.copyfileobj(f_in, f_out)\n", + " os.remove(zipped_filepath)\n", + " return filepath\n", + "\n", + "\n", + "def dataset(directory, images_file, labels_file):\n", + " images_file = download(directory, images_file)\n", + " labels_file = download(directory, labels_file)\n", + "\n", + " def decode_image(image):\n", + " # Normalize from [0, 255] to [0.0, 1.0]\n", + " image = tf.decode_raw(image, tf.uint8)\n", + " image = tf.cast(image, tf.float32)\n", + " image = tf.reshape(image, [784])\n", + " return image / 255.0\n", + "\n", + " def decode_label(label):\n", + " label = tf.decode_raw(label, tf.uint8)\n", + " label = tf.reshape(label, [])\n", + " return tf.to_int32(label)\n", + "\n", + " images = tf.data.FixedLengthRecordDataset(\n", + " images_file, 28 * 28, header_bytes=16).map(decode_image)\n", + " labels = tf.data.FixedLengthRecordDataset(\n", + " labels_file, 1, header_bytes=8).map(decode_label)\n", + " return tf.data.Dataset.zip((images, labels))\n", + "\n", + "\n", + "def mnist_train(directory):\n", + " return dataset(directory, 'train-images-idx3-ubyte',\n", + " 'train-labels-idx1-ubyte')\n", + "\n", + "def mnist_test(directory):\n", + " return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')\n", + "\n", + "def setup_mnist_data(is_training, hp, batch_size):\n", + " if is_training:\n", + " ds = mnist_train('/tmp/autograph_mnist_data')\n", + " ds = ds.cache()\n", + " ds = ds.shuffle(batch_size * 10)\n", + " else:\n", + " ds = mnist_test('/tmp/autograph_mnist_data')\n", + " ds = ds.cache()\n", + " ds = ds.repeat()\n", + " ds = ds.batch(batch_size)\n", + " return ds\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "qzkZyZcS9THu" + }, + "source": [ + "### Keras model definition" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "x_MU13boiok2" + }, + "outputs": [], + "source": [ + "def mlp_model(input_shape):\n", + " model = tf.keras.Sequential((\n", + " tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),\n", + " tf.keras.layers.Dense(100, activation='relu'),\n", + " tf.keras.layers.Dense(10, activation='softmax')))\n", + " model.build()\n", + " return model\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "DXt4GoTxtvn2" + }, + "source": [ + "# AutoGraph" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "W51sfbONiz_5" + }, + "outputs": [], + "source": [ + "def predict(m, x, y):\n", + " y_p = m(x)\n", + " losses = tf.keras.losses.categorical_crossentropy(y, y_p)\n", + " l = tf.reduce_mean(losses)\n", + " accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n", + " accuracy = tf.reduce_mean(accuracies)\n", + " return l, accuracy\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "CsAD0ajbi9iZ" + }, + "outputs": [], + "source": [ + "def fit(m, x, y, opt):\n", + " l, accuracy = predict(m, x, y)\n", + " opt.minimize(l)\n", + " return l, accuracy\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "RVw57HdTjPzi" + }, + "outputs": [], + "source": [ + "def get_next_batch(ds):\n", + " itr = ds.make_one_shot_iterator()\n", + " image, label = itr.get_next()\n", + " x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))\n", + " y = tf.one_hot(tf.squeeze(label), 10)\n", + " return x, y\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "UUI0566FjZPx" + }, + "outputs": [], + "source": [ + "def train(train_ds, test_ds, hp):\n", + " m = mlp_model((28 * 28,))\n", + " opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n", + "\n", + " train_losses = []\n", + " test_losses = []\n", + " train_accuracies = []\n", + " test_accuracies = []\n", + " ag.set_element_type(train_losses, tf.float32)\n", + " ag.set_element_type(test_losses, tf.float32)\n", + " ag.set_element_type(train_accuracies, tf.float32)\n", + " ag.set_element_type(test_accuracies, tf.float32)\n", + "\n", + " i = tf.constant(0)\n", + " while i \u003c hp.max_steps:\n", + " train_x, train_y = get_next_batch(train_ds)\n", + " test_x, test_y = get_next_batch(test_ds)\n", + " step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)\n", + " step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n", + "\n", + " train_losses.append(step_train_loss)\n", + " test_losses.append(step_test_loss)\n", + " train_accuracies.append(step_train_accuracy)\n", + " test_accuracies.append(step_test_accuracy)\n", + "\n", + " i += 1\n", + " return (ag.stack(train_losses), ag.stack(test_losses),\n", + " ag.stack(train_accuracies), ag.stack(test_accuracies))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 215 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 12156, + "status": "ok", + "timestamp": 1531752050611, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "K1m8TwOKjdNd", + "outputId": "bd5746f2-bf91-44aa-9eff-38eb11ced33f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "('Duration:', 0.6226680278778076)\n", + "('Duration:', 0.6082069873809814)\n", + "('Duration:', 0.6223258972167969)\n", + "('Duration:', 0.6176440715789795)\n", + "('Duration:', 0.6309840679168701)\n", + "('Duration:', 0.6180410385131836)\n", + "('Duration:', 0.6219630241394043)\n", + "('Duration:', 0.6183009147644043)\n", + "('Duration:', 0.6176400184631348)\n", + "('Duration:', 0.6476900577545166)\n", + "('Mean duration:', 0.62254641056060789, '+/-', 0.0099792188690656976)\n" + ] + } + ], + "source": [ + "#@test {\"timeout\": 90}\n", + "with tf.Graph().as_default():\n", + " hp = tf.contrib.training.HParams(\n", + " learning_rate=0.05,\n", + " max_steps=max_steps,\n", + " )\n", + " train_ds = setup_mnist_data(True, hp, 500)\n", + " test_ds = setup_mnist_data(False, hp, 100)\n", + " tf_train = ag.to_graph(train)\n", + " losses = tf_train(train_ds, test_ds, hp)\n", + "\n", + " with tf.Session() as sess:\n", + " durations = []\n", + " for t in range(burn_ins + trials):\n", + " sess.run(tf.global_variables_initializer())\n", + "\n", + " start = time.time()\n", + " (train_losses, test_losses, train_accuracies,\n", + " test_accuracies) = sess.run(losses)\n", + "\n", + " if t \u003c burn_ins:\n", + " continue\n", + "\n", + " duration = time.time() - start\n", + " durations.append(duration)\n", + " print('Duration:', duration)\n", + "\n", + " print('Mean duration:', np.mean(durations), '+/-', np.std(durations))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "A06kdgtZtlce" + }, + "source": [ + "# Eager" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "hBKOKGrWty4e" + }, + "outputs": [], + "source": [ + "def predict(m, x, y):\n", + " y_p = m(x)\n", + " losses = tf.keras.losses.categorical_crossentropy(tf.cast(y, tf.float32), y_p)\n", + " l = tf.reduce_mean(losses)\n", + " accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n", + " accuracy = tf.reduce_mean(accuracies)\n", + " return l, accuracy\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "HCgTZ0MTt6vt" + }, + "outputs": [], + "source": [ + "def train(ds, hp):\n", + " m = mlp_model((28 * 28,))\n", + " opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n", + "\n", + " train_losses = []\n", + " test_losses = []\n", + " train_accuracies = []\n", + " test_accuracies = []\n", + "\n", + " i = 0\n", + " train_test_itr = tfe.Iterator(ds)\n", + " for (train_x, train_y), (test_x, test_y) in train_test_itr:\n", + " train_x = tf.to_float(tf.reshape(train_x, (-1, 28 * 28)))\n", + " train_y = tf.one_hot(tf.squeeze(train_y), 10)\n", + " test_x = tf.to_float(tf.reshape(test_x, (-1, 28 * 28)))\n", + " test_y = tf.one_hot(tf.squeeze(test_y), 10)\n", + "\n", + " if i \u003e hp.max_steps:\n", + " break\n", + "\n", + " with tf.GradientTape() as tape:\n", + " step_train_loss, step_train_accuracy = predict(m, train_x, train_y)\n", + " grad = tape.gradient(step_train_loss, m.variables)\n", + " opt.apply_gradients(zip(grad, m.variables))\n", + " step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n", + "\n", + " train_losses.append(step_train_loss)\n", + " test_losses.append(step_test_loss)\n", + " train_accuracies.append(step_train_accuracy)\n", + " test_accuracies.append(step_test_accuracy)\n", + "\n", + " i += 1\n", + " return train_losses, test_losses, train_accuracies, test_accuracies\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 215 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 52499, + "status": "ok", + "timestamp": 1531752103279, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "plv_yrn_t8Dy", + "outputId": "55d5ab3d-252d-48ba-8fb4-20ec3c3e6d00" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "('Duration:', 3.9973549842834473)\n", + "('Duration:', 4.018772125244141)\n", + "('Duration:', 3.9740989208221436)\n", + "('Duration:', 3.9922947883605957)\n", + "('Duration:', 3.9795801639556885)\n", + "('Duration:', 3.966722011566162)\n", + "('Duration:', 3.986541986465454)\n", + "('Duration:', 3.992305040359497)\n", + "('Duration:', 4.012261867523193)\n", + "('Duration:', 4.004716157913208)\n", + "('Mean duration:', 3.9924648046493529, '+/-', 0.015681688635624851)\n" + ] + } + ], + "source": [ + "#@test {\"timeout\": 90}\n", + "with context.eager_mode():\n", + " durations = []\n", + " for t in range(burn_ins + trials):\n", + " hp = tf.contrib.training.HParams(\n", + " learning_rate=0.05,\n", + " max_steps=max_steps,\n", + " )\n", + " train_ds = setup_mnist_data(True, hp, 500)\n", + " test_ds = setup_mnist_data(False, hp, 100)\n", + " ds = tf.data.Dataset.zip((train_ds, test_ds))\n", + " start = time.time()\n", + " (train_losses, test_losses, train_accuracies,\n", + " test_accuracies) = train(ds, hp)\n", + " \n", + " train_losses[-1].numpy()\n", + " test_losses[-1].numpy()\n", + " train_accuracies[-1].numpy()\n", + " test_accuracies[-1].numpy()\n", + "\n", + " if t \u003c burn_ins:\n", + " continue\n", + "\n", + " duration = time.time() - start\n", + " durations.append(duration)\n", + " print('Duration:', duration)\n", + "\n", + " print('Mean duration:', np.mean(durations), '+/-', np.std(durations))\n" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "eqOvRhOz8SWs", + "PZWxEJFM9A7b", + "kZV_3pGy8033" + ], + "default_view": {}, + "name": "Autograph vs. Eager MNIST speed test", + "provenance": [ + { + "file_id": "1tAQW5tHUgAc8M4-iwwJm6Xs6dV9nEqtD", + "timestamp": 1530297010607 + }, + { + "file_id": "18dCjshrmHiPTIe1CNsL8tnpdGkuXgpM9", + "timestamp": 1530289467317 + }, + { + "file_id": "1DcfimonWU11tmyivKBGVrbpAl3BIOaRG", + "timestamp": 1522272821237 + }, + { + "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K", + "timestamp": 1522238054357 + }, + { + "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ", + "timestamp": 1521743157199 + }, + { + "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-", + "timestamp": 1520522344607 + } + ], + "version": "0.3.2", + "views": {} + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb b/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..bf824e2760e694ae3c00c9f08d9aa5d5522a9b84 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb @@ -0,0 +1,1512 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "b9R-4ezU3NH0" + }, + "source": [ + "## AutoGraph: examples of simple algorithms\n", + "\n", + "This notebook shows how you can use AutoGraph to compile simple algorithms and run them in TensorFlow.\n", + "\n", + "It requires the nightly build of TensorFlow, which is installed below." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "TuWj26KWz1fZ" + }, + "outputs": [], + "source": [ + "!pip install -U -q tf-nightly" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "3kudk1elq0Gh" + }, + "source": [ + "### Fibonacci numbers\n", + "\n", + "https://en.wikipedia.org/wiki/Fibonacci_number" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 197 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 7512, + "status": "ok", + "timestamp": 1532101577266, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "H7olFlMXqrHe", + "outputId": "472dbfe0-9449-4f93-e908-1a0785188a92" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 : 1\n", + "1 : 2\n", + "2 : 3\n", + "3 : 5\n", + "4 : 8\n", + "5 : 13\n", + "6 : 21\n", + "7 : 34\n", + "8 : 55\n", + "9 : 89\n" + ] + } + ], + "source": [ + "import tensorflow as tf\n", + "from tensorflow.contrib import autograph as ag\n", + "\n", + "\n", + "def fib(n):\n", + " f1 = 0\n", + " f2 = 1\n", + " for i in range(n):\n", + " tmp = f2\n", + " f2 = f2 + f1\n", + " f1 = tmp\n", + " print(i, ': ', f2)\n", + " return f2\n", + "\n", + "\n", + "with tf.Graph().as_default():\n", + " final_fib = ag.to_graph(fib)(tf.constant(10))\n", + " with tf.Session() as sess:\n", + " sess.run(final_fib)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "p8zZyj-tq4K3" + }, + "source": [ + "#### Generated code" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 541 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 103, + "status": "ok", + "timestamp": 1532101577412, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "UeWjK8rHq6Cj", + "outputId": "73ece895-12fb-489a-e52c-032945d7ed7a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "from __future__ import print_function\n", + "import tensorflow as tf\n", + "\n", + "def tf__fib(n):\n", + " try:\n", + " with tf.name_scope('fib'):\n", + " f1 = 0\n", + " f2 = 1\n", + "\n", + " def extra_test(f1_1, f2_1):\n", + " with tf.name_scope('extra_test'):\n", + " return True\n", + "\n", + " def loop_body(i, f1_1, f2_1):\n", + " with tf.name_scope('loop_body'):\n", + " tmp = f2_1\n", + " f2_1 = f2_1 + f1_1\n", + " f1_1 = tmp\n", + " with ag__.utils.control_dependency_on_returns(ag__.utils.\n", + " dynamic_print(i, ': ', f2_1)):\n", + " f2, i_1 = ag__.utils.alias_tensors(f2_1, i)\n", + " return f1_1, f2\n", + " f1, f2 = ag__.for_stmt(ag__.utils.dynamic_builtin(range, n),\n", + " extra_test, loop_body, (f1, f2))\n", + " return f2\n", + " except:\n", + " ag__.rewrite_graph_construction_error(ag_source_map__)\n", + "\n" + ] + } + ], + "source": [ + "print(ag.to_code(fib))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eIfVy6ZTrFEH" + }, + "source": [ + "### Fizz Buzz\n", + "\n", + "https://en.wikipedia.org/wiki/Fizz_buzz" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 125 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 233, + "status": "ok", + "timestamp": 1532101577681, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "33CAheYsrEQ7", + "outputId": "82a493ee-15b5-419d-8c9c-5f4159090a05" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Buzz\n", + "11\n", + "Fizz\n", + "13\n", + "14\n", + "FizzBuzz\n" + ] + } + ], + "source": [ + "import tensorflow as tf\n", + "from tensorflow.contrib import autograph as ag\n", + "\n", + "def fizzbuzz(i, n):\n", + " while i \u003c n:\n", + " msg = ''\n", + " if i % 3 == 0:\n", + " msg += 'Fizz'\n", + " if i % 5 == 0:\n", + " msg += 'Buzz'\n", + " if msg == '':\n", + " msg = tf.as_string(i)\n", + " print(msg)\n", + " i += 1\n", + " return i\n", + "\n", + "with tf.Graph().as_default():\n", + " final_i = ag.to_graph(fizzbuzz)(tf.constant(10), tf.constant(16))\n", + " with tf.Session() as sess:\n", + " sess.run(final_i)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Lkq3DBGOv3fA" + }, + "source": [ + "#### Generated code" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 1081 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 289, + "status": "ok", + "timestamp": 1532101578003, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "bBhFIIaZrxvx", + "outputId": "d076a7ea-e643-4689-f90a-57f5d086dedc" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "from __future__ import print_function\n", + "import tensorflow as tf\n", + "\n", + "def tf__fizzbuzz(i, n):\n", + " try:\n", + " with tf.name_scope('fizzbuzz'):\n", + "\n", + " def loop_test(i_1):\n", + " with tf.name_scope('loop_test'):\n", + " return tf.less(i_1, n)\n", + "\n", + " def loop_body(i_1):\n", + " with tf.name_scope('loop_body'):\n", + " msg = ''\n", + "\n", + " def if_true():\n", + " with tf.name_scope('if_true'):\n", + " msg_1, = msg,\n", + " msg_1 += 'Fizz'\n", + " return msg_1,\n", + "\n", + " def if_false():\n", + " with tf.name_scope('if_false'):\n", + " return msg,\n", + " msg = ag__.utils.run_cond(tf.equal(i_1 % 3, 0), if_true, if_false)\n", + "\n", + " def if_true_1():\n", + " with tf.name_scope('if_true_1'):\n", + " msg_2, = msg,\n", + " msg_2 += 'Buzz'\n", + " return msg_2,\n", + "\n", + " def if_false_1():\n", + " with tf.name_scope('if_false_1'):\n", + " return msg,\n", + " msg = ag__.utils.run_cond(tf.equal(i_1 % 5, 0), if_true_1, if_false_1\n", + " )\n", + "\n", + " def if_true_2():\n", + " with tf.name_scope('if_true_2'):\n", + " msg_3, = msg,\n", + " msg_3 = tf.as_string(i_1)\n", + " return msg_3,\n", + "\n", + " def if_false_2():\n", + " with tf.name_scope('if_false_2'):\n", + " return msg,\n", + " msg = ag__.utils.run_cond(tf.equal(msg, ''), if_true_2, if_false_2)\n", + " with ag__.utils.control_dependency_on_returns(ag__.utils.\n", + " dynamic_print(msg)):\n", + " msg_4 = ag__.utils.alias_tensors(msg)\n", + " i_1 += 1\n", + " return i_1,\n", + " i = ag__.while_stmt(loop_test, loop_body, (i,), (tf, n, ag__, i))\n", + " return i\n", + " except:\n", + " ag__.rewrite_graph_construction_error(ag_source_map__)\n", + "\n" + ] + } + ], + "source": [ + "print(ag.to_code(fizzbuzz))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BNRtprSvwJgk" + }, + "source": [ + "### Conway's Game of Life\n", + "\n", + "https://en.wikipedia.org/wiki/Conway%27s_Game_of_Life" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "r8_0ioEuAI-a" + }, + "source": [ + "#### Testing boilerplate" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "7moIlf8VABkl" + }, + "outputs": [], + "source": [ + "NUM_STEPS = 1" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "QlEvfIQPAYF5" + }, + "source": [ + "#### Game of Life for AutoGraph" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "5pCK2qQSAAK4" + }, + "outputs": [], + "source": [ + "#@test {\"skip\": true} \n", + "NUM_STEPS = 100" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 308 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 14892, + "status": "ok", + "timestamp": 1532101593030, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "hC3qMqryPDHS", + "outputId": "8405c0e9-e518-41d6-f5bc-e78df6474169" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\u003cvideo width=\"432.0\" height=\"288.0\" controls autoplay loop\u003e\n", + " \u003csource type=\"video/mp4\" src=\"data:video/mp4;base64,AAAAHGZ0eXBNNFYgAAACAGlzb21pc28yYXZjMQAAAAhmcmVlAACZUm1kYXQAAAKuBgX//6rcRem9\n", + "5tlIt5Ys2CDZI+7veDI2NCAtIGNvcmUgMTQ4IHIyNzk1IGFhYTlhYTggLSBILjI2NC9NUEVHLTQg\n", + "QVZDIGNvZGVjIC0gQ29weWxlZnQgMjAwMy0yMDE3IC0gaHR0cDovL3d3dy52aWRlb2xhbi5vcmcv\n", + "eDI2NC5odG1sIC0gb3B0aW9uczogY2FiYWM9MSByZWY9MyBkZWJsb2NrPTE6MDowIGFuYWx5c2U9\n", + "MHgzOjB4MTEzIG1lPWhleCBzdWJtZT03IHBzeT0xIHBzeV9yZD0xLjAwOjAuMDAgbWl4ZWRfcmVm\n", + "PTEgbWVfcmFuZ2U9MTYgY2hyb21hX21lPTEgdHJlbGxpcz0xIDh4OGRjdD0xIGNxbT0wIGRlYWR6\n", + "b25lPTIxLDExIGZhc3RfcHNraXA9MSBjaHJvbWFfcXBfb2Zmc2V0PS0yIHRocmVhZHM9OSBsb29r\n", + "YWhlYWRfdGhyZWFkcz0xIHNsaWNlZF90aHJlYWRzPTAgbnI9MCBkZWNpbWF0ZT0xIGludGVybGFj\n", + "ZWQ9MCBibHVyYXlfY29tcGF0PTAgY29uc3RyYWluZWRfaW50cmE9MCBiZnJhbWVzPTMgYl9weXJh\n", + "bWlkPTIgYl9hZGFwdD0xIGJfYmlhcz0wIGRpcmVjdD0xIHdlaWdodGI9MSBvcGVuX2dvcD0wIHdl\n", + "aWdodHA9MiBrZXlpbnQ9MjUwIGtleWludF9taW49MTAgc2NlbmVjdXQ9NDAgaW50cmFfcmVmcmVz\n", + "aD0wIHJjX2xvb2thaGVhZD00MCByYz1jcmYgbWJ0cmVlPTEgY3JmPTIzLjAgcWNvbXA9MC42MCBx\n", + "cG1pbj0wIHFwbWF4PTY5IHFwc3RlcD00IGlwX3JhdGlvPTEuNDAgYXE9MToxLjAwAIAAAAPQZYiE\n", + "ABH//veIHzLLafk613IR560urR9Q7kZxXqS9/iAAAAMAFpyZZ6/h5MpYA5/oqv4s2qPbYpW3jfK6\n", + "zQ6q7WMrNj7Hy8jZzmBpfHCwAAO1W4riBNsrapcCk+5V1W0XkkFULR4Qe+H3uGA2HgNW0zFAAUgt\n", + "W4tdpXv2OEg0Vuy5W5l/xGRmEGKDyeXyrM0S6q/1EKbad0x2mcHseUqNmeOGLy1N3b376XZKZcPY\n", + "IXC5F2332tNMj8CwOQiXM9PiCLyCVfZ3rQSkKBTZErkpS5kXUyoJG3FdIqLjRFKEapbUjcW64HIo\n", + "BeIbtRyWV9FyZfcTakx2KW3eB4ZI//MDykSe8CRgN76uBEqZFXwO63wmUREhHOb5AdaLV3xyGl/I\n", + "RV70rU/3t9t1aq5mFD3hy1aLTAV2U7nG072dyX87F7NgCxZHT2kFxu44fxf6gqVzE3PEbGr5fx9x\n", + "7TKXtmY53VP8UaeCd2HJiZ/sd165SutTnfiWvaLuCnmmXGF0AGqbj9S19kgOhTubZIJBydTTqQOV\n", + "YRlxbgKn2nzvunv9+NDG0/2ikyyp73W15QClmjyt8dUeynoN8CwtEQ59DdrAPZe4ARZTwWAfsRXw\n", + "1vcZ6Gr1nCNWllQw5IyZyxQtXrfc5p4wjPvGaltciG7d3FG1SGk6HDsZy5i/PsnkjRXLUvGbzYp2\n", + "2gs7ZSGfSJbEifctcMGeSqhOOYORKy6f/9omoieCVEEkniBXwWZ/eImb3nxF7SFIaBjgG2j9w5ut\n", + "BY6zSuQ5zRCdajzJ1loNO0havI8mp5yViAeAlLKYCxeK0Lha1FskL67W1YsARZVZ5EkhqAYEeTNI\n", + "M38Og48OXmj6QBN7c1b9uDUTacYEXO88ZQ1gCIREIMnm2Fgkir8pN4gtSeQ12sfOVz5x5KX7sa95\n", + "L4LyFQPDrFZcDBr4PWLeEEv8yzk0cYHE97GmAlA6WQ0HlWsS42cnXefvTPXnx4vcq8pbEo/slAuH\n", + "IBsrJEN1+aMCc9FNxwUPVbZVaWVjwLY0qh+mNWEaiNGRmacDXrYWw0NjqMPiLiFHacY5oGELRgym\n", + "S2mSo6zhsD1wKQ3EUQtwrjKPiDYc/HCqhkVwoWKUdI8xTS60kn4f5UqB0L77Yevh/wt7AnvQKQAq\n", + "QAEEevggRl1uigbOBTtscnYRnAj0edW4QExAzdo+RwLWXTzW/l3cBWTrh3ORzZQlxJ8jQTvPLB+f\n", + "bLazJZWFQQDcWhuhQ3gYcP1ruNwIroINRIr8px0UOgAhnk6CllxMN6gA5S0YPhFVFKd3n0AAAC9f\n", + "vYgISQAAAltBmiRsQR/+tSqC8p1IAOZemTPutEfx0mzK8zG8tdIxonBsDpoLZ+NnIOp4qK6idP1s\n", + "vbGvZz/zHM86Bg3q0yx2atmtgoo/Trt3YRy3se4HTjou+tCi7oJt2d7A8vEhVDu33JNJx+WCOgP0\n", + "03nVdg9lBs15v/0w7qMc3zqqJXCOy/Whl9aRhcaeOEWcD7uK6mCV8a6MpDJ959xBRfv2i/qFOFbL\n", + "Grs58WiGJcq4MQJI+rVWuFN50oiqBgiunfUrRmdviPYpNN11V9pwcOJwssWfIE3agnor/RC7vfLY\n", + "YoXzaJjtWLEL92OOaHLZT0j555xfb4FZcoJee+RXovB9IaoDdYRusngtBXPMUvnO+g2Z5Qdo9P8q\n", + "Zb8ItBAeHT8IBZAD/Z2nEA6qbxqOBSBtQNW6ZFYLtCTIoP/bLjCDHgtZk3cf+N1CpXs15pUIYWDW\n", + "elZtlTkM4w4EJlLdjLZyQPAeaBx/qoLmKyTKAEhm0hU8EcTq00f6fwkWgz2J6GTGtL/vJXgC8u4o\n", + "nTnf+Ou7sVJGVaouXxrzx+yGVHEcp/eV4gaFA95rInngQAOZWbA3558nK61JBPZl3NjEv5B9r9pg\n", + "2+SYY3wBAUeu2fgAB2+yYGw82pkoJJKpzYWORs6i1vn3GEgUTcwlYsdJcraYC5SnGvqSZhX7KM72\n", + "uE1e9bkpvpVyG/mkACn5R4jwX3xc2utCjjZgM101rirIF/7VfDtmJsSTDes+UVhbSr3SeMSI9ixJ\n", + "+fVuFZ5bnQPoRIfPc+Erw+K99JiGN+HE98/eq4pPlMY9oCfVPSdNyOAAAAFfQZ5CeId/AUuqOi5D\n", + "jlKfxuJGZZ1+rVyomjOIykvxtsjsuCiGtElbraCSFWcn3aIYWLrF3fPovVLcOnroBkiRMsdf5yJA\n", + "F87MQuoKeTaGOrxojCCCS64RiHrqNsE+7mfRRUDuB4sAEHFQHxBorgTukPSvrdFr5QDq+BhZj/6H\n", + "KN+IutwFWKX3ZX9pO3sI8My78TgRY5AA6FEcT91WcvnMypB/OWXzK6M8fYuhVVWipAZigjVOYhcF\n", + "9i6GweQFX9AV9EUQOp2qFbkrT5jceBRFLX6j4JUQ781/UGTekv1fcpCmzlpNpp8GdSeWxRL4gasp\n", + "F5uO5KW63rlhYccBo1cFwIN8txHNnwyQNiP00XC0PWDRZfaWSxsACRWrISow71IyUfcL7JNhjTII\n", + "rwDYATS0xZ9ep8siFC3JTxg1eNaroYfeI4tbkRHok47Vk+CUOQPuagVBtFMOOcy2OUbw8AWlAAAA\n", + "ugGeYXRDfwHM79ghzBo9nMnzfQPPIuvorxBb6AC8F4fYGD/t93kNSKNSEuhUXq9FKGtxnCkxN880\n", + "BPb/uTbjLTQVyPNuYlGl/gTlyLcVA/cDoLrl5TvaR/AcSLFE7C/t3kLx0STNibmdAf4TsHWKSblH\n", + "VWB4X7oQHrrDdhwIivRgUZf7f63j2XaGB+cbp5aHCCwJoovY51YTqsZZTz70FlSnypPHQBNzif7h\n", + "uvZkXhtEzpu9rYMo3YECkgAAAXIBnmNqQ38BDchAitLfY16mYQAQlVmv7062W8KLpIS1/zhS50Ib\n", + "b3ERigmkZKZMPaCsAi+zsLcku/gHGHnVZpuCZMFs72gmyuL4JFo6VjWcr5FtBvzIgD26rBNvP73P\n", + "nJjl3JImmFHiKjNez/gG3zTuYyCACuJCEYXyuEmzCM13hdCPHKg5GZtso0Z1qk6T1k2oiqF/3RIn\n", + "kyjRWuxBlHHmJ46TXULiUY14G+RAGoXI+u/G6muNclld2bq+6Zztuy+5ynaDWNNjuN1Ag9KUIx2F\n", + "XwNdepmp52/rOvISNPbMJ0U26OvqplXi+qHTbg8MLpUSIGCY8w9FZ5woLAENgvgu9M79yGlL20e7\n", + "ypJ4RMBqHYDpEz6Z+SSjXD8LsJ7VKlwo22A5Yukp1vTp6HHA35nV+PXK09DuRWKKdQUzmXVihF51\n", + "/+bB0PEFdoNxGdbbM7WveaCJN8XI7JgQWvw2nPlHX8M5QyPGSJ2HEexumoFrABvRAAAB70GaaEmo\n", + "QWiZTAgj//61KoCPNGHq/MxnjqmxxQAEHvTwibmyMZGX3ES9Abh1tMR+/DjR+6dnqRr/VxCl6gEP\n", + "wJ/5EYCYfGaGmQYsLOeM3v2SZjdvqQBwrwKk5A/63kFm8fc3QCLe93Mldv3KWXHdFT7/mudSntDc\n", + "vJwStG4jgi5LKlWdSrVaAxOmElsF+zWNzaCIQ1dOiZqi3JKj64hOeq1XIWyGvRvh6OLKBpB4rL6W\n", + "ugf7H/IPbSQuF5jWV7zL5LhxWiTiI+kAZTUMfO2YOLzmhCUSN9GAmNzgY4D2awYB4V4QTDjI7kdQ\n", + "tL+3Pmfl1HVilu7nC9CzQSvWIosiwv4btyHTL7IPT2gusybyNfW8QO133L6KbDhhXSDWUtcIFCgn\n", + "QUm36C9hvgGjorpKYr5VnErpJX6fRJm76fFYs8/nt763alyqdcSrqaTOLaf/72Wkkmlwbq3nLOIw\n", + "ADFDkkAPwzaM811K11iK/3HaYRT3nEhjJQFk5v4WBXwIVLAZeKdtC8YoGN9K6isN142fOG3s6fm4\n", + "J1nMtOEZHIwep8In4slLmHh39qBzhGZO3igiVpgz7u+JMBeFkVHe72vduBjIy+1dqvxL/TPics3s\n", + "+alwfTMNQKave1qW+5Uj8jZQTjcLAtKvzoako9VMIOfQUQAAAQpBnoZFESw7/wC9ZU4P+UeGsidW\n", + "4n5tFkXmtxppYvKQ+WGj/x3AAdl6+9c9x7N2b/yJykTvVggfpMnFUWtxla4sr1ouwANom+Uf4IBJ\n", + "/zXPovndpGdy98nJbZxFU4rrWpr8aI4YmRX65+IGTn756CZWwXKY5DyMgKnDcCtk0HEuoHgdGhh7\n", + "1PG8+nue+pE9pBHqiBNWAjPd90qfMtABmMShLoXtUObqYbqXhJvVjjFhKdPS03IF24fu9Z0ax15V\n", + "DnkiLmgyOCvJmcdIX70L2ZEECd/hxrSq9JUVjC41OX0F/ayI6GtkPMUuZ2xWkMFo5rqOAo7v0Zlk\n", + "ke/79TjeY13FNiowqcbhMwfDuwAAATIBnqV0Q38BDXNpg2t4nJdhAA5ru/5Co2KbB/AnQt7fa959\n", + "0crOQgtTxL36jtVyKPmfuQMYuWbJ/7bYTEV8sEjceHvN6B0CSEZzVCjaPLzOQJZMQpQ4K4WKPlGc\n", + "lnEwYAC9Dsejj7Fbk2RyCFiJinyU2HOscjUR6fW2jRsAFpVq/PtZDVPvesPG3AqooVaKHp9Ex+Da\n", + "AH0OvccSugyDKsRBAEiYR8645aXxbFSzraQsELDsIIr6HRN8F3lUNVBvzNO3mxBhq4th/kgZSjjJ\n", + "JZrYmg3UfIUO/jn4xs2XQ9Pa7Uy5K3JhuIQwAOUKDmAMC0p6fgz2on4ceyEcfiCGDPZpPyL3391F\n", + "dXID0ctPQ1a+Hk7UcAc9gSDL8CZKz59YyO0ACPjfAKV3Y2dbTAKdWBsUU0EAAAFEAZ6nakN/AItk\n", + "aaqbMCcBE0iEIDnEBfRZN0neHQxaz5DPSzK0ZSL640q0AA5jkP0YAYAumNCN0MxJYpWFoQ9r43H0\n", + "i9SZLdv1UbgpG3aX6KESZW7AgdlevaBngH/w8xYsqWx5t90zzi7x9VyRYpIAD+XTrxvgBoFILNCs\n", + "gd+zDA9uvbAPlLMwG/qFltlwvLokMt344erv3a/C/ySOwZHFzpakInpJ7MQHkmKi1KHZB5KrfqwF\n", + "FnglZJwWbe7LtVojTdwQnAksziDNlEWCkMQQJwziY1KYtlXMNX8mZ3MtYR1KNf/CNin7/ys9ZQyx\n", + "4Zlk//H5KDc/8O2+JaxH20CAaAABxgSxo+yJal1LnRHYfOQ1TygNueW/rPAA37g/6fLS7mbYKz7k\n", + "dsiSiy1mAV7n/qq81UHJPShQSXK+E4Y5XKuXEWG4AAAB8UGarEmoQWyZTAgj//61KoAW7kO9JCjl\n", + "XSE6nAngAJVxWWFl/YDS0gZ32xjwUFed4hmI6rj18z16nS3Mz1iMmFblrtaE4zGXS046COODiIwH\n", + "QG5lRmcBExMKlnynQruQtA8n/NitzdP/ysLrucGyp5nKV+XyJURULfxk4kwNp0a5TFlJ1fusOOJm\n", + "y0hvsvEg+d4Jz3anvWT6M9n5A84CGucNifV+WlN9gI9gs3qSoCZdU/gglcFYM5u8YchzhQFyMKxn\n", + "kpfWK2LU7aaZHt6xLbqjuv74523K9/dtrrsFq/LySiv1P9Wk6/6d5RC72z4cyaUq6hMMn4IWWRo0\n", + "zJIM1/lSYsWxt5/M1Mkv00Rt8OZvmLxuFfd1BIVlANlpgZ39RYhqqzU6v1HwaW0EudelFBGhr5mf\n", + "GaDE05Z8ywp5rN4Qq4D4GNAGD/qgEjtaDDf4ZBAD/TAHBwxfNjm2nPAdbbbIuWSkkv8NK6EMlKqH\n", + "mOktd+CB3P6Szd1+HPnUsyQ3659r3XLnoi0cvM4usfW+BgxqT0mgHSgn/F6ajdTNM+a8xJQnT036\n", + "7195r0uF5vwi7PIviCQ2E4Vs4Wx80/8tBDEJS4qOY1YJ5aNV1OV82fB3HOimLHd2vU/d4Cv7OBh8\n", + "k3gNFcjeBGh+3lQcDCLZrG1mAAAA3kGeykUVLDv/AGVBMHxAlJYGEpFnv2bb0ADrwvVKxe7+SIJI\n", + "g0dPJdL0s9Hd2mGX7rpdIiUH9ZgtnBO+m3uPNae/YtN3u2p0kkCez2KiPNqgSoEcHM+ePgq7afkq\n", + "0HHTSZl/+QbjsyfbI/0lv1mLAJUd3u7VZPPHSdXK3vwLfAwOe3Nid72slU892DijWVvanzM1IzDQ\n", + "XfN6x6GH2qfaLrHePrJTJxXC/RSxcAol7x2JJ5OA8VjN8jXu0yKirBiYqgcdFf9odG8j4bRmE2wD\n", + "MG0SKuGrJfd91b6B7hbRUwAAAPYBnul0Q38Ahz7YAbwPIqnkAA5sEIcKo2/sVUP0LEeFOLjKjaet\n", + "5YFAjDbL5BIdGqWouG/H8ozoec2ZpUbIZu0ELtG5yXc/5opSZlnqbOpqdTQkLs6gr9dv5GbFvVjS\n", + "Os1j9FIMQsdc8pttosNtygWB8gLxr65El6umAZE5CVU9Mc8Xxg/tenmTduGK9Cd7qRDiu1sLYR2f\n", + "or3KBMo8ebz5q5EmWucvREbYSziQIIycIwJg9OG+aH+ZUEQbjbfHfaiX7yoxGJGP78aNOHP7GvC+\n", + "JwM6DxnSyowUBAqkW8ckgrhet8gYYrt8MIe1MPJQB6sv8hHuAXkAAAFWAZ7rakN/AI9XvmYGr0rf\n", + "QEvrPPTQWEAA5ru3wBCXPJiC8OaE25OBvVl2wRXqp61wQU4HxGJCAxkSOz+G3Yzvg36uCK8bPZTq\n", + "avaOG/H9WxjsuwAl/bIYJdnyD151CiUZ34aErVIixKJ53oKrLeHr3xLgxuH+y3w5uH5lQRsL0Pmp\n", + "0jQItTBkKwlPywxFk55pROuYZWi/h/N19QaFlF7WPobUElLlr+nCH+pVt1nW9/YwVGz/cO8zwmWe\n", + "Fb0OnFji7CYSsi9ScC3a50GjUP7IpaY5NAHv33V57bkO/BD6dnreymTbSmQdcj7PAJkvz610fMqn\n", + "mDGTMB31oxAIE5eWeH7mBZouSgmtxEamul7sYaTPe7mP6FqNCz0h6wLot/zAFwx9/D2+XB0x8mmS\n", + "b086o+gqkoYoHQeQm2Sb3MU1Bz0KHDGo9jCmsBmecxs3oNHV4KaIoLKAAAABrEGa8EmoQWyZTAgj\n", + "//61KoAcdmk2P6doyaR4wEHxsIcmssCD5f+3/v8PGtlbWZ+A0oGGFPTAdgmU2TFbrRxlmwUCouNe\n", + "8freV7blHDodFImzwP3saA3AZT6NUl7vDGH/tw5n9y8rP4XGnhEXBHK+6jIhoAYc6G1CDX0mqczJ\n", + "7tbei5I0YSkDjza4rJSbAF6cRoJQH3s2Q+ggBQR0BfH6N3QlPVwd9YFvP6++J+XrbNU56Pxu6Wey\n", + "51asar4AaARXHregTXL4xn/VNt8Ppk2xD3/1jXAVXdqMlS0tYGM/TtrcuTC63Lx21RQtklG6k0xA\n", + "eWm6W0oL0KTvxuyegpC2ySp5v6zpSEYvzWR4IYirfT0RYU+jLtX0t4M/L/0k8xOLTHbouoUPD6DN\n", + "dYYLYlVX5noJzjCAVCiS21OCcIKqWD/YiU/+dTZpdFFNdHEa/MPvUEq7cJD7ANJ0YUweepq2Eqdh\n", + "57SC4Tpg6jyEnFgMaHQLSz1nJNh4lxM1TPouGZ9bmQdDr9WY+nwzRBa+ZLnaqBSYKWSKEs/TNtNZ\n", + "ev7d+EnJUf9G9CAmmiSDlRAvAAAAz0GfDkUVLDv/AGU2nAwHHyQlvUxuENDSO8vXFIAPilnMlQWb\n", + "nTHwb8wkIo6JKOaIP9blrrNXcWeeQDVprB1Bn//+nbSDHls1apJcUyMHUmojA58P91gutTiF40zp\n", + "fDaF096G01gcvpH5Za4+DfUvxQpt/wH5PntJzggww1tLhP1NyH5U2TTgrnA/BevK2aCa9xCuCVgA\n", + "JJZF4uqHE//COeWbJ6LIFJPoadxAxbrAcxPQQHMzEG5G5S3Yfd+YJBLrdO35JvVrsUTYO4AfvJeC\n", + "zwAAAe8Bny10Q38Aj03WPPyvISnWAC7KM5WfLH925SBeAKcvJaYOa5WZCzX9H5nU/7qAFTCgAnl3\n", + "rAoSnKk1337XDAnLfPYAAOSIcqQwF++e4HouwNVAWCEsVyl7Y6DnBaBT2mD1H8560KoMvm3kKNNC\n", + "oxFCc4BdAIXk45JUbGFNGYAjCbBbJInMjwa41HA404yKnJG7rNXdBctnsSL/36UoXvVx3J2tGX84\n", + "+FHk7e72CsAyB49ajd62idmFQji9Jj1GaiqtCIjWs5o6Mz8s5QfrvipNYYD0YZ7gBBGm4AEz17d8\n", + "isscgsp4QI2odbuEJDq1nfJbW6+1HGcN1XfDC1Xfa5IptM5UYHm5zIT4rSPBIDE6l8/NhVxlFP21\n", + "JPQ0DZxnZFvxIBznQbqkhaGZjMafgFoRzC9Nl17x+K6e75RlplRZtXaUIbjAUFBJIQPkoIrT6/O9\n", + "NtkAmnl8qqUC1RktW/RjiJqOyRTTITHqNKvKy/0gb88xEvvGPgzcSs2KpkbHJWmCGIlSWEkuqcCE\n", + "jBn3Y8XOQxMUxEYeLPJ/9s/F2fT5NAnko+RFlv75fWLekZZP2s17yJ5ccFGhZyrkGX6u7xXK7N8G\n", + "Qlz8qfOHvgMQrlB8p4j7qtnPgBPf8mcsM295CuAZxkK+sut074W+0hM24VMAAADaAZ8vakN/AI9G\n", + "UrhSy/Rrhc/LGXguupji5cAHC2DVoxU1gWUkKeMT366GcmuxH5O8lBZJeHl8r2KNT0EaVARyW7pN\n", + "L4uNsKKl/WAzLJ1OZWTQf4NaAfodQGO9KzZS0j6oGvr/urKiQwbP44Tv//glYQyyCFeq+8nnrHBj\n", + "aACu2w1otySh0DYMX412uY6EYcx3GtQaRpNPiKQniWdVV2KH48fVxDy0uLS0SmCZEAWLVNvtWqO+\n", + "q2OwCBr1m50s0i8eRTlSP9xoKtxWC4ZqL77eAW3kYEBJOAywYUAAAAH6QZs0SahBbJlMCCP//rUq\n", + "gBY3NzYDjVIwwAKbp/vtZn3NtK6t0V/4sA0MV4ijJVoTZ+e36T0E9eQ0LOyzsqR0ULZJUDRy41oM\n", + "RdsBwM4wyEJC67daWmuDEXKhZo862uqAH8A0QJ5u5RKBPFpngChYYJdWzP3onEWImG8Yryy/SXt0\n", + "jQ5te76AagLius72bzwZ4AZfLm/04ID6oXhPwqkf1cNsu4/kIt7oCOETiL+lzwHLEnEsdPSz3DxD\n", + "uLGkH8o6jHofDxEXcB6cOS43aUxGKPYPtHCj2gw6RzcRoX5lD5mwqtoCTxk6N8TxyipSUyNnbA2b\n", + "G5NuBUVLHTce3QKY3SdkbyH/wzdOpT3YHUE+FYQwMKCF6SMyMBxp2gI9k4yUZYljUiekF2XIFkfv\n", + "TFy1RUmikOycLKkTYTreTarsMD5JfjZ2FJWrroj/YX+uNeGtKNZl9Zyt+k8u4Htq1bPYEjCrLHds\n", + "qeIuFWmvxTYEQblStjDXmWfITtxy8KvOgn9iV+KlidrnVhlE7Dz30fuHXxxFZvIzhgU9uv6sSC7T\n", + "vZuGMsKGBGTYmSe0P9hLI2VyM/8GUWwG/AITiU4a7OVDjUNRPaiIEt8jt2oImPIY8qcrJ82CVd+P\n", + "mSjoppoeHUTHmeo+koGqjhwT7ueVHNT5VZ4yuGKEDdFfEIkAAAEMQZ9SRRUsO/8AYrbCELHs5dcg\n", + "AyOPuRHZUWtdXLx9XaNQixO/8Cc4Q2MgEa/wKETsHiR8C1XOv7rI3JB0rg46JfjEArbHaTHmANKo\n", + "+czcI/sIduYNFOE3TvObMh/KtGpZSdF+qnDDtY8zD+7RQUdzmkG5zeDj3u4Vq+f3qnKCwgbU+U0R\n", + "dQR9Q60wXqL03p/iYVxkI8jJqvkECuxT7efJI+5rmzyP1yn+WKY2EsjjB7bwwVfe6RxBmzR9Ed/9\n", + "CA95ILUJxNg4HsmCO2Ko+MqZAH3wMlG18kUm2ogL3cKIkVXogjofyKhbsSpKLpFFk71DzB6NrY/3\n", + "HfknWM2yn9yeQB/joufGEf/bvMAS8QAAAN4Bn3F0Q38Ado97WJWiqN4XS53kTA5YWsnJBdebpf+9\n", + "lcN5zPySAC6fH/XzBsBKbxdm4pTiPFVrmGXyhaRiB6dxtlwj8MyI40Do8AXHq41BAunk4K4PTgzR\n", + "rFycWqaL549wB2C5jNCLXlq6Tuytik3ijlMSkx9noeIG2Lc83eWkRkQieksQSO4xI1tzzkdqaNhG\n", + "ExZARu3MauZwrBopslb/ZLdR5ZS0G6p8o9DD5cphJjxJoSV/70/0Gr+woS8Zj0JpVvvpygE5bXQp\n", + "/YBCqjmq4uOCyt9SvCzPelUEwXEAAAGyAZ9zakN/AHZ6+HiwE6fxvgA5rqP9zmI+FShvhJS43N4N\n", + "sc5a7qq0DK7DHadXkQxf+APmeqLrIGM9X5aCQgeyxdoAlcQoyNsm6ol85w5z6JV8A3YntmCae+s8\n", + "+8/Yheg1ctJWrSharoeypUyemQeq9Rm5cIkSOS9Ej0hbIHyFhPQW6K3SawgMNVKQ0s1BpJvXDQSY\n", + "x3jIEdIgEtwe7zce/DjcO3RNN3g+SlPoM7cl0qJbM44NIDG9JGXcwVrY/YKNrpChX0yegP2ZHDI1\n", + "MzOs5eWP/2l5loJrLid2mK4Qhw6EGFrIadsV8rSjzgHRNuzJ4U3JdubidEobU0ehkU0P6MYRK/XM\n", + "58mVywGbsw6LPu56h1S4w3zHGYMd1zPKOsnCUhaRfrSZTxvjerNQ22prVPqBstk4JgHdnSScrwGw\n", + "eQcqvIw7gKhonPDKM4fJtO4n2EsI5Cd0iGMjmgPw/PU3FL8ZP3QbYLMwZ81Wd7BLLBDf+ngKiFIe\n", + "it4neyhhaE/a71b8TxeM/ZrgH9+D76dlgPI1ZJW6CCVyIs6Y5gK2plkcgRYa0MwWF+1A6zPtBEgA\n", + "LOAAAAIIQZt4SahBbJlMCCP//rUqgBY9we30eRuAA2kMf/9/gX2SHKs8Uq31+W7Vx4LugxILnhMT\n", + "6icG5WQzdpL8yjIXjBq99nVaYweUdJE3LrdOpsVxNJ3kODVBkposYOoRuOMi/SNhcjrJwShp6ljG\n", + "Qs7tSeRJSYDkvm+SI2ckjbManbEesw6wo2ZffuryaLuWkU9SNALC+2QbPJD4bFy7sTmB9+6VOdMm\n", + "rnLvYN4ZyAJz7OhQG85P+JnxdgXgvSv66sWBs05p3vOE+53H+HQCMTLVgvoYmHNTIYtZ5CIln4hA\n", + "GrjLg53unVVQTiYlSzZrRE2vmtsqac+v6CrcbtgC4HktflvPTsvgqWNHri9NWa+EuXgx/AgGkZVJ\n", + "r1n6gAd3jtjLtv6YvbPiBBo2AhBUxCbYyroAjcvjwUBtRjXTdDEvdYfItmTKA7W3+KvVi/PCtod6\n", + "/3gOoaA7zRdO+8+MHlGl/c2xzQhj2O1n8eJkOu+NcsBkpmxyosDi11EOEaiQ6vfnOvH9MSM+7D/v\n", + "k91SLlwv/nF+5eDPHSLZQIoFUjHjwVoSGCdOLqmIe6tsfTERCeAhC+1bhRhe0612KIL6izjolsR2\n", + "nUgrl1o39HqnKAVqQ/HguEezLTgmGW27Df2kp4E1wRl/EQgEcsMfBPga1ndY4uHPYq84ArNCWk+c\n", + "YwxlHAPVC3PK3Zp2kQAAAWFBn5ZFFSw7/wBXFVHDEfqz5TAg6AmqzzGCl9B1ICKhB+tKz4Y9Km1L\n", + "/vZyZ1OR5rO815FlrTgGoncUDKVNjpKrVerCm+HleHb1b4FhYQG8B61zGq10uLuoQHIyL4Cv2/mm\n", + "s5Mi7ZftErBt64oWYphUyh0Hmn9dYYheGFzLdE9gvqcAEGJDyLZq+nfiK0Px8pHIgaIfsEdSUYcC\n", + "8Otyxta0EKY+Dm2m8AtQ8jjuDmkSHm/uLhgf1uCnztOKFhkR+ydRCeR9tnIlTfiv3gJbsPT8swjP\n", + "0OUm6yT8LhwwCJU0AGI9hN0/kTkz+NeSHjSPaBx26MAfS2Y5NEtva844h4B/RttjqxMsNDiDrfB4\n", + "5xn/Cl/3XrcF40eivyUSC+FHzx3M4BoLQLOKf7iz8hKiUrqRGVkGToUMxkr5192x9xCjbuvLRMd8\n", + "9Pel4WIOhSi52xuSf1eEhC5VVAp4lHpZmHCbgAAAAaABn7V0Q38AdnTaV3jxqK844c19uepGJJSA\n", + "C7DQuTz6pWfCzxcMbX5JwHItpyM9y3YT46z61a7h5Lyukp+nSKoO0zQhT0EB/u6ILUCNvVbb/89X\n", + "7TVI5UN6EFwYYfi4uoFmqb+5Cd0J/+d2405yTsK/f6WH/T+vNB1DYWrW67ctgHOgMHAWDLG9mitl\n", + "16bXmPVSi2sWzpWYg3147nlnaD00aZHqQlrMPzYTLLFwWHOLNqCoWpNLMMEevc8AnQWeykk9VNTU\n", + "NXzAXhrKDXl1tLQTxZG7GX3K9cQyeUnjfH3rMBGDD2zCLGXrMfPVl9EJ/F5M49Rjn38sXUf2JvF8\n", + "D9r9tV1APCHN27+egfFIMDg9OhrQMtjAe3WEfpYS7pl5yHh7ZZ2CedEo/Wf/ygYTAQFI72AaUTrV\n", + "n47d9OSqAdYs7lkgV0864auRyPQeTKK1Sp3ADeIFS134VGBNG1VnrfyZuznYkI2r0FVkGFrAXpUu\n", + "ZJmyKqqILhJ1OTBM8C0VBV2QXBYa2aSn2jj9t40/wJJWc9IGAVR0vj/u+wFocjwf4QAAAZYBn7dq\n", + "Q38AeUc/pR5QUuADgu7/kKjYlIf8yn+MfKKvFMJ4eRJz/DRqteBIBJsZW3T3phi3NzuSw0zOvEhr\n", + "CHz7xEUteyaR+fa6YCBeiCtangbUerW/UGoCobzV/74XB/lXH53NcEw+6x9o3/ZgwG/7l4psK3P0\n", + "EqSwtCrcKAAv8Wi0Z88mFp3Sp19shMF41mqYa8pNsyefrruQONS60LHg/1GySbrTeTWW74lCDwnt\n", + "BGXpwghp/QF087PP7hxkE8lvu8APh5F1FTiOCBSvJFm6yFC/tz24gmveLoV4Rq/qtYWRE09VDCDH\n", + "yjftToPMsyi4DoCtXsPRk5Jxr9Mn6xDxGjfz8uMmOKJ15ejPi/Sx9cR1QrBsU9dhcYifdB+c0AMF\n", + "PolB3N4pBZAASP6m7EzaTer6yZ2sIKcQdlGt9xsZ0SHtS2313gpdJkLEVrHpO5/BTcfUTTcK1+bC\n", + "PwRYX+iIyInP1m6htprdy84ySZ5IaGCpRKFxMCf5w22wXyyon+dlMPKACguyEPTCCZQ2MqEuC+sa\n", + "uB/hAAABxUGbvEmoQWyZTAgj//61KoAXgR9s4tVmwJ9HTza3s57iAAoQf/wjqzjlXnP+29f12EfR\n", + "S7B+4I2epG2qM/uoQ7VlrfXFlhjyX/aTq0n55QXAKa2xUKolKsuMfmZFFc6+GP96b13JiSidvPgt\n", + "2SSGnq9Yw4MfceFmgOaZRcwoMnpdb0UpI73YdP+DfypKyrkDqKWcBc/BGhrH8+XdnpCNDXfg5rMl\n", + "b0uFlQ11yUxnDYOfRwLbdjJA6FYddawSEVorFtY7jkSQx+OUBUgWkKC9rhKB+uV/yqQsvbuFiyYV\n", + "MviBpsZgSSN0TOC5JedQ5H38ENVBLjXnWZD9PQyueLoT4qwtI+7lodFSnBG3zboWdj6P7XDbgKT/\n", + "zKkFObUjwhstiQtohzxd5AXhBH3DQqNv6mRzuMxFDcTEo5ut/0/1HrPGOF4R3sJ/eQT+YnYseqvc\n", + "0m5njpgI3qkLmn8efBB4q3zWGpHCxBwC84HKjuugMICuXfcJHKn0aWkn65aEjT8AdxDWE09InGyo\n", + "EM1wsU0JgJ/qq/6MdHWfQW6+bt5xWlpYJ4axi9wZc3Aoz+Rixn8UVM2e/bd31+W37ucz9udquxnL\n", + "2JdNUAAAARlBn9pFFSw7/wBZVXkLa/7xg9HEtDOpc+GkSv0gCD3x6eQNkROUaCyL6QH8m/0USPLW\n", + "nllgC+uXg2X8kUpaUiErsLvwKd9y+trtKwV7xlvkAn0JqEnToCvptE1Sb8eF86DTi2ywy7WE/imn\n", + "jNBYQny1cV38ScnZp/V3phWQAYBG3kUdNNuj/FyVB7DgbQbTLK48AO5nLYv8B3LvBNBfBJ+ym1yg\n", + "YJXKwjm8kt8xUjO2UGKeggZOs7YHWr5Fj8OX4jV/B3/cMzP+f6YyrayA/80F6f9vgrbTlhWdlFQ8\n", + "QtrHKjmrl874OSSPJYH5wfQfF/1NrQd6soxjmSWYI9/FqOPoy6ujUPxQvg1fUda+wK31Cv8gD96H\n", + "LPqpgQAAAXkBn/l0Q38AeBaU9hYCjxV6lA176iBcJKIHTfhwkqkAB+a0LmdvcgdK3vyEsSkCI+8U\n", + "up3OQ4OQId/B45+Mf5P4Fc2VsfnQAACxyzNkvgEEYwZk+TyOR6/VZmeFNYMrBdqc2NNBlh56ISK/\n", + "h5V9lagvsX7yv0p9Hk6RXo3uoMgKhKOv/QgBAqhUvAKDw4DS7G31tehd/myRMmCPxIJ79bZsQe2/\n", + "iq7Nquzc/VDpPXFZHPvOmiyfyrt6Fxc2jLHZJGpvacPTIeLJiSaBxgRTEKBr/xXaKQjc5nLhlwgc\n", + "HSz1WRlyOsXOkob3rY8KoGVETaaIvHEl7sVHsV3QN7iR2rIGzf6YHv+c3l8OW1b7tAMShtcCLifl\n", + "8k1OtS8Z5o7MNTObuLXIONSPGo1fC97qRzqHFEfMZntEMqsFjjWPM6JduvRiAv8p/h0kRdcTeRox\n", + "t4PEdFJikYgCJgtFa00LDpNvd6Vv6MImiivCAgL9L7zEaNCr8p/p5ZiDugAAAO8Bn/tqQ38AfAnX\n", + "r+Rl0wYAC9kEZglKr0YEZPxbFiynbDVLyUoB5/4mwbggJCKqWcWLXkOc702XkfuMANGy7OD7QUCV\n", + "nopFHkp77AuzGvvM2JQndhYVkdbX30/kmHQDID1DcpthKQBbzUjm7wgAOqbulxKDc1OUw1plN1OA\n", + "iXs8Ju+zQDtZelKPfekDEF5iPA8IQMn3LLocZ168PVHW73hdmgfMFTsqduJxZ1oiezDuUBPUKdNQ\n", + "1lGg5KUsS5A9iNuo+n1shJKCmk20FfXGeNEywAjYeaq4bao/dd8nZn//htlIayY083IymAgdHbKW\n", + "UQAAAW1Bm/5JqEFsmUwUTBH//rUqgBbB5O6qXkABRezeefAxp9PjwxeDBuTTFSUNk2voPSz0T3Lj\n", + "1K/LmQtEI6YkskJKgxvIXHGf8LHTV/h2Mg/qV3IQ4zvBygOQs98iZyR5jgV+hQ58R6xIcus/6y5a\n", + "HrkViRrv8Sk7So3LYWmfkLzyR6vcCKhF/sCJsY8RS8BK5OOGU2Ll4Qs1n4jPQwTLDELf8SF2+07z\n", + "zB5hexERnOHmWZ9THKXS8j6NXPrj2p32k0gvmlI4b/Of9evEX9mDBp5GtQHOvTswQ/VYUajAUXz4\n", + "5w6EHuB/k+FBz9pe+B69syJ2X5MYn7Qi9rKpCl2kZv4uAWXuNo7oIaU7hr6elcFz53tdL9AEjCAb\n", + "BlT3p448134hjvo9lj95CHF5teK1w+R310Gc3NQ0eeJcsiYD2EoVrHHjVDF/m8I8JtTUFdJ3xm+G\n", + "muADOcIpcqYbeqyKWwHmgvRze+DMQbkLo4AlgQAAAR4Bnh1qQ38AfBSmnoPKZzTuFWeZOcrkeWeU\n", + "yVIALsozlefbqRZf6f7w7fkPoFSkdlxkJJsnO6qzfbc/Kotbm2yeFrIQw5yspszQL8gAAvMHKSnw\n", + "f4CTQ2vfLY55MADj1baDD7LZtn0UK1Eh1HnwXobc+mdHd/JEl/a2Tszf/EZ9+J7oMl+BYsjWKwNY\n", + "vOv5flnnPLcex/hWFIF4n+hpBybvasl5hI9mV0CeAAyAclftj8N9n7hadcpM/TOVmHbSkJ3cr/k+\n", + "StSwI8gY9k3tmbMSZc42caMpFr6YdNCCIj52zmNBccPNFxW+UT/4qCqtX1gc2j7obKDaWzC1yj1A\n", + "td8/VAjqVn+FzuuEokhhvubRT3RCdxeWnBTCG0CxwC7gAAACMkGaAknhClJlMCCP//61KoAXgkIw\n", + "VJpvAgAqN7f+5rJJcY8tkjj7p4LozjswOy2dTydK33mOBGS+NojRzBOlwt3ro+/vdQIUTIVrXKwh\n", + "2SrHPCPJXQoCjJUPkRODCmqbZeBHsv1r7iIOZPpX66HYYhWgPLvPzAb/Nqu9nQqKoyphhNy32+S5\n", + "qAFvjRKLSjPAx7GoKGUNMbYduhsBsrvVTwhrV8uWAls2mxYggJzVuRUZSL9cSt+tjl44BXjlbo1a\n", + "I7ybNHG97GCzcbSNcg0RA+iqwDsdnrZCO0zsNdWK1qVmER0PsSf0dicSrZwIcxZWy6JbkwQn5TnO\n", + "kAah3wAs6pJvW+a5ZiJHl6sVlU3yCOlrECAESqWu0YR75WfiMXgesBOuXGGNsC3icmPYNzM93us1\n", + "7GQTI6RmmFHGo+B2yAB2YJiK1YN/T0ltUuXfFAvL4UdHgEXOVIqVj+S+YpITMKy740IvYQ5zuZPD\n", + "ahdXF7HIU7xE0W12w+6qkuyZwxUMXLXdgx6svudMor1GNfDCdymcKIidhuuXh7vdQrgbivH7usVC\n", + "zjMqgjGahkW1YlmytCooEIoULx5ux9DK360iAi4u/nAomESdiosanRfQ9jQdJSpo4rurLfeCLF1Z\n", + "XsQAQRTcezHlxp1tz3A3WsYMA9urPBB8pUlDdB63MfZDCBphVx/Ddv1AMvPXFEPu18oREsV3BdKx\n", + "e3lxLWWpytzF3zXttYGgBb90j9DgRGE1uaAWyEAAAAEiQZ4gRTRMO/8AWVV6uU/hFqUNYqrP23yu\n", + "FpB+ECoAQNVnJ92i7ZF1i7u1D6K4L4gxm2RaiGsRDmf2iYWEjO8yGHAqwpcDep1/+H221WMh98AE\n", + "VV9Ferf+hy0D7Zu5rX4Hp3s1TpcNcEBIKPHVSHIzaZKKfPXkqE/ga/eepp8Bzdc39OW6g91hVVvf\n", + "WJxrnf77rapWbmivuJFfeO9u+RRykk/agdEi5E/5a475KGQprA2yl390PNrCvoamPyXbETwtbYAQ\n", + "pF9uDZkHdN/NQ1P4rz+zQLJx21eQsP9WBLswpDFYg9BjPw+3VrVEzeid2j5wJBlq+56Hw+Ex6fI6\n", + "1O0GbWSAC5/5Zg+kGX0Yx7/We9PseMWGwXWIVwqI7oHPEnK6wUkAAADgAZ5fdEN/AHk02mburIzA\n", + "1V5U+8CauxZABexQ9zxvy3GIkNn2+19EyZqnRm0DMMsXP4ZwiY8vW/qdBTlATfbmIFDxCTzt76+L\n", + "X3WaNfG+rqTfzj6gLFFHl5IJDtQmIC9KAmTgQM0Lp8TEDdYJnPYGFybq0Xdyl74+130DteV0SYTD\n", + "hgB6230zJvCx8ZW04pZHmYvtJ1LZAxF3BAWKPXcstkh7/Er8zYdPblR7K6t0r3b/sIHpME53VRBk\n", + "ggj1uN/p+iN4KwToxjP8kZ1opB7xpkyOQpicygiGnwjU7EpZpywAAAF2AZ5BakN/AIdka2Wer/IA\n", + "EJVZr+9KNmiS7zXHA/5uJU6D0CbJOrsLPWcfwAUCZZjhlCsnAlgzrrGOONmuxU3En1TfTKb/7Pu5\n", + "1R8PfIYkV/dZFitvMyRPMvzwXX1OcxtjbhM+M0LCh6zNEWJFi2Pi95t8cspIknD4iXNUblA3oEFp\n", + "VGuXt+8S3Upf64YqAxWADhb5zxXL+O/gnWiyawM9fyRrYcExecMkEiv5MHRsJs8Euzdps1vwxzNA\n", + "Zu4bu6ic2K2ueNja78qXGaHz7xLoPIVJv/T4KAuseyOhznfFtKf0Ey0eSBVK9qutGGF83lfe5Wtv\n", + "xb73lHTKLAyiyJassoDHBSQLAcUPb4nB6xWNr9G9gWtqEIp4Or9tKJzZIZ1tnIKZFZGb0ELAlV2+\n", + "pKKDz5nW+syHi871Soc3HtgomT3Y1cp83yQG1GdKkcJPkU1uJVzsVPzbXbSU7/z2Q7cikc4seN2D\n", + "ryQ1l58HjUs0ikCXV/V/CDkAAAH6QZpGSahBaJlMCCP//rUqgBbmS0XBN5gNQAaCJTjyhVwVkMwl\n", + "GF6KXnd0XUyzqjFCJEv0D2xQiJu8if6sKo6qHl+BP/MZw8ss5OKq407INzCjWOsjf2HTKyC5fNLK\n", + "wiJv+PzieOozn64ZK7RRud2QUaDe0kuhk4uCClSYQBImrxmWeEf/X9zH3+ilYhfoZigVm0IoMiuu\n", + "YX1ERVdg0Ld9E6wxbYMiQAGJU1qeeTwc8vb3w3kiJheTA2PNXtrJ98RwtpnhN6QxMe1dw+aQWI7S\n", + "j0oQ9iNx73N93RuNVRxXj/57S9VltjA0RTZBjLvYS81QDA3fBgaNHNzOBZ7dztz/rTxxOpumjTTw\n", + "x9FgnvlMsjx7FYPKUcXD5quVKd8lwTlOiGVI7X1HEv3Hh4EvpYVt6azhUBI1qGunVb3X1lyMhWJ9\n", + "p3muqcicwInEt+BuHY92HoNXaaJJbbQmNX5s3QJbI28Pg4gc2gaUF4SQRcBgM8uwcYUzxEkBS06L\n", + "0moZm8bwMsLYCLj3fgXOyFudpfg6jkYPDeVK811WbzEz8Hcd42XVL0EwE3bwDc+i2I4+NERo6J6l\n", + "d4d7nOIvqUuorZnDPtlYcfSWgBqdP0tQHvFb4Sv9QUCBvXlH2IEiNzo/daaHVtbFRNZ3cag2HOiP\n", + "lMxyt8xYJMnG7di2JiwAAAD7QZ5kRREsO/8AVwwP3fRRACC0tQoY45xe6yfL8KMHlR1wbd4HcPUC\n", + "+4PcnqOzdoNv80ufRyOopFYryJahX+qWFUVKK+nDtdvegTv/PqvENcT8ykEwwQ7z2oNUdaMITYi5\n", + "4tC5YA9FaLSBorMGx3aocAbiF8065MBqyaTkiW7FtGRHVSPubGixAl7hiQRoBoEipfCxkE/EBoII\n", + "omSCNrFRyjd8oY66cDfZt+iBI44uLDeP6eHMEpBALsV0FY7iWjBLaYO1t2PsklOb93SAExoyIX1I\n", + "TiPXiUgrCYe7dgepAF31BCnOuxiIAPWKLDHZLhGOJBLqdemk1EZoKCEAAAE5AZ6DdEN/AIteG4cJ\n", + "hGXgWAAHNd3/IaNiUh/zKhTXYgf+UKkbUvWJoLo7whMXByWkvy3MotNcPaSHeaKS5vKy/hBJIgk5\n", + "CWcdsbd5QzFHyjOIZiaEAA1AziqRPTDRRVYKhcrm181rAlAdaYmvKZAOu92pmI39/PSQjhiMouSe\n", + "XVT3pg0s+/zN7WMQCHqTmey2TTctwD0YnAH9CK4EMAw1jPCCTXgop9epuL/iXjup2S+LS3pGE3iO\n", + "oIHon+1ERGRC2Vp3b2QAstSXzK/2zI+bVnxf0PhgKqa/NeuEaF2SBGZ/TyqGPDnQfJRorCp1s+mw\n", + "tm/3aVbjKRTXeSwl+OCfF6rMqjf/Zw8/4yrjLNmiyOgD8OWqATkM50NFqOShrrTCaHdcxgVW70ss\n", + "cCXKxvzAUCe+4nK4C3zP8QAAAWMBnoVqQ38Ai2Rc7ISR6q0L0pberS7nbElvP1eAuajd6ehFPCEk\n", + "va4007gA4DkP0YAYAumNCN0kma3A2DvFPa+NTDmrilkXNhiNVTFRLzynsy8rdgQPBH6k5DFr/4eZ\n", + "jmJjfYPWB5+2eEYYc9uJ5Ni70hsVFfV+T8zp+ZkLZnd2wv7AZ7A8baF9R5O9oQlCkoVPxkDHTrmt\n", + "rElQhX8Fi0yj2+BVP5O9UNPGQU0+M3KYUTg9yTBG2cCw6Drt49/5M/86NN03F5R9JS9KGOfJjIlA\n", + "koCavGpTFqq7OYU0RM3ilfXBmxvL5QoIK28Uvs71J3h/IvKmg4v/14n3/eoSpqNUCC77ty2SgAAi\n", + "rxQNIHz2GF/lpTynlwsORrYNT1lJMVud8AAQb+/SaHWQXmhJ+8cZTt8XuMgG/t/hdF6GqyG0A/Pn\n", + "hWRq+asN+zBaeyQUWZrjl8ry0h3WPkAZksFb/gV7ABWxAAAB/0GaikmoQWyZTAgj//61KoAWw9mB\n", + "34Nmlq4DQoTYIkneVdOFHxDDrFwsv7yxZXXwNkGuLMduj7QGT/7lr2bNfzApMJfo9/ffM5g789Cz\n", + "1Mn0zxePHMHBL6IHHRVXWyqDMhVLYnQ9xFtc1jml18If/8STBCOf+AZjMnARcFmX1IwLt/ziVSoN\n", + "e4GPKKZqfZWytoW7461OuaeZ9dvtxrCL+W45zobgR5vOrVM+Opl+w/eFlupHlgpQBWgJcPy8sZC4\n", + "/O9laiYA63xx6M701UUvGFsRI+RM6anXyjKc7TVrmZ/YQKRjqB6Mejs2G1mTDkBn7T2ZURI2vZ3u\n", + "VXRNsQnGYDxRUokS3YRHs9LEF/gxKSdLEEiHDqcoIHyS2FPM+cIJRSvB7sxIA3hgfN/O4qDK6VO+\n", + "t71oi1H0Bkz1ugONnVTpQr+WeMS5AtXXNBMXU+ycO0+R9eRe9BwSk0V6tHm/HJ45oIYvyWTj3yZa\n", + "JQ6q+o4isbf26PsTbuSAcvQoMnzEXJkqElGJ8Z3rZtdkIzQW0DDnXeNRbj2wQmuUNBknMsWOw2/t\n", + "fD8BErzYLXI65PwTY+6R5c6RWYzF9HNMLBaO1c6cI4yEu1DMKtZW5FrmVuc6hg7VnWxgAgOdFKFA\n", + "QvmmcrbHsqCH4rkez1y5GoMlxeOuW5WKa/JdcefAflYgakEAAAEQQZ6oRRUsO/8AZUEtmg0dqwLy\n", + "ubLYtABfXw0ri+bvSnwBqWW9hB3/jYP94x5LyZNY560IvuBe5T4EX3/71Gbqj7BS5SJLQ7X1JK0z\n", + "I9iR6McwRU2BDEhu+2JQm1RA2fBVxnzCyNr1JVnfyyuumlkNzE8n1UgnkIbS/FMxc8DghB7zqZzK\n", + "rkagW0hHwSjNf+LJf3DnbXyvnzmB1lcv8Z9QlsnPKDef2giSgbZeTNWRMfeu91kckRy0SSKkaYVK\n", + "KUUpf450Vl2TzPLRaNhk7Du1IJzIJRf9supxssXD9v31LAVibgyznyLU/cS57Vr8KEXG+WpKysV+\n", + "6iQmQ/hCoRg82drzuniAPltxm8MMUZwVMGAAAAEzAZ7HdEN/AHUKF3WsfCAA7NAZyuGlRySXJzA8\n", + "WtPYIqCp+udF6BaVoG3w794kSqeP3syNbVlr+uFhruNMOOzTsNGrbATFZMl9DU6mhIXZ1HEAskmI\n", + "VVSgXlz4sVX35JqYrDPP8r9Bsg/O9tAp7LnTMjWlqOdgOPhHpyqf/hmokPsCwqtKfsDhxP/tmX60\n", + "fhM4KsfvpygzK8jmUmY/GDBCISRQeW6U8uaq8guf+cvy+sP09JLJ4HsULhIsm6kyYO04HBdOFUDr\n", + "/8IzlOKX3w/FCxhimlJIduY8iySAFQmALOuag1Ry1Z3p7NpGIGhZp/q5hzsMAsH2jpHXQPdtFNFH\n", + "4VkqDlRDeGqieCr6gwu3hPQQfF9yauq4qf5R+bfPha9tZ3XjpRO4eqNaj2xEQrcb5cIJOAAAAUsB\n", + "nslqQ38Aj1e+ZhXsJE07lvgA5ryx/X3Tt1hQ2T/wP93u+Km2fQtCsS47kHT/v+BMMbdxEWzwYvcd\n", + "d3NYalS7o/aUthPBRfYGmx2hUIQijLOXN4leC3SONeoCputIRor3Lgsy985K8UL4nvf1+pFmRQg0\n", + "eJgJ9ubt7jVqU4S6enDDZ82+hYwxDWOROomkxsOv8nlizRgAHHE1n42Dq5sLIu8oVYp/4M1h4rCy\n", + "m7AmDrR9dbHlpV6pqPLshIJSKr7R6XCF5H/mgt+78ttEoS2XxbrmVQj6DQtTzcYF1gqzE9DaiXTc\n", + "rKcf1aBAFclenBiNHhbAMEE20Br4FIkr51a0ynzJocMgaUhstOH+7gKJGCsTPkykOiVzQeIGOfi6\n", + "AmLkbzIds0NOnV21ExFbxIFAMu1BymG8Kjwvo1cLb7372R2f+Qt5Z8LjmGrBAAABxUGazkmoQWyZ\n", + "TAgj//61KoAWP/AeMmkxh4qDG8hcZFMZjYIY//v8PGtlbWZ+A0oGGFPTAdgmU2TFbrR0QmwUCouN\n", + "e8fq+V7LhZ4IhSGjAEZXRALCc6lvXQaVk4Hy29vGup69bTfpCSIWWGXFW7WfQjL50GRbZZRZHQ2m\n", + "pjAJ2N9/bloCCNQEfrVxCeDkKfJqKlRpIdnOUaiQpsnEysqkLqMfxaCLAtiv1vFXcLPLizzlMPs7\n", + "NIiiAuhD4+CMokPsODEut5yq6fM1zRym2P9iids6rfyvN0EtWlvUXkAIdmS8HfE5DlX5rtipWZ2i\n", + "d9rb+tQcwCfWN6erokI6tARQJu2c+ZSF/sI7qofDkfNVCHii2Msza0cnJEbLkEfdF+gBET2KrdRv\n", + "E5mgO+6ICEAI6O/h7r7DxvTQ9Wxzo3mHNo6898yojVZYUAEyiEUBn5+alz6XfA0d5GcOXFRjv906\n", + "SVSt5h/ZyjXd+HmcrubYPlDuxhjCrkqyrKcbhfJHp/Mq+DI065H9OXdNO/+uDSHvPcKkibqiAVhI\n", + "DqTA+NZM5+PbtXMsqU6iKpSzqr3AN5mBITP84n9JoTkmCR2U/+5h8eajZc3UcAAAAOdBnuxFFSw7\n", + "/wBlSP3uCsGGoV8bqfG+TF6JTvUuRSAD4pZzJUFnxrFOJYnshFJtjPOw7rAcguf7FPJIlPqbN5qs\n", + "fqCPl7TU74m2w4/OJHMnDpS1+crxo620hZORUqqaN/UeMSuSm/KKx2/MSsIgkvOy0fYS1MAD67Fk\n", + "Z5FUhBYQOPZatG+Xc3Icj+kvLjp5v9fX+nJsaNN4CCl0quEK1R//8eZO87p6DKKxlnRfV62uCNE9\n", + "o2MWYwf9qwHYbtyqG6I4xWPTngQnrsOmiw1Sy0bIvHiKKw6nsCsKdLVPqCFU/q5rppy8Ah4AAAIT\n", + "AZ8LdEN/AI9CIO0JMMhrV/0AB0HLuqwUdobO4BdVbPV1Ioua5WZC0IWTaPE/7qAFTCgAnl3rAoSn\n", + "Kk1336t4zGyyPYAAOSIcqQwF8zee7dn7XFk1tvgy6W/qOMTmkEiEdwceoRsnhNmrNp/TK9OoMIUg\n", + "ShyIuwXG8nP6tDCpAEYSuvpzo5kchXf9jICMUEGqQZjLulIdzbNUEecLTDRk1r3gpdToPPcXdXTM\n", + "AElxf3acmkXSo1kx4tBmKJrXm4kNQ2oDIaqLOc1dGZ+ccoProxsI+jQiCldj17rGF1/E4alcIa3L\n", + "dIofRLGOPkev2msNj9eN+tELiQktxoUq9fKnDsRx9Nbc5IkysRYA/KsIu02gpfPyisLPQwjLSjpr\n", + "jTxnZViCfPC6UCMSLVKUvso8AB0eV8Q+lldoHmqd+EeBeeJOkPU3vuU/GQacMWsLnKmVt/65Nw0r\n", + "y1AnL9+YKkDmvNgpqgQANfZvj5NhddHche/p4la1cXWhY3W/jmtWxMTkOC4tX16bao5sNwcVWRvt\n", + "UHjkDIOIXB+3akBV5Lzaef6YjjT1MeUeFh/FB0tOMV3Bhvdw35krP/ItZ1RF5hRCk1oYqz0ykGZW\n", + "YkciBlvCsweWM2wXwX55h7SZHtxiKM3rO4Aff+TOWGbe8hXaapPE+4wKof+j5KoQ530gP62KsQIG\n", + "BV49pf0LYkAEd7yVzO9dhYYFAAAA+QGfDWpDfwCPWoxxjdaiaFtca/OwfG9dSAC6jYuqYuZmzKSC\n", + "kzbTtnf9idy9v7frgKuFjQymibohZCHRXBQdujo9Laqcw233I4Za+//Mdf06kxHe/IBTsCsxcSfV\n", + "ksVUEdqCe9dEwWwg//4Ee8Le2gLXqz21e4jiFyBOjP5GsM1hpupcfwZtr5Mo/ou28BY4QZExXJ0H\n", + "FzCqK0jKq6c//ut1tsd+kiOyZUVGRAFVkS8bi0vvjrj3zga9Zaa6Mt7yQii43DdcrobbVIWdc0QI\n", + "3+rsc8fgmOnJ+GJGdWYzpFLd5zMjS5ofw5IMBt0GmHVcG82Z6YQkqKJHzQAAAe9BmxJJqEFsmUwI\n", + "I//+tSqAFjc3NgONUfiwAKbp/vtZn3NtK6t0V/4sA0MV4unWIJlE1N72EjQeUPmvxOpceaVXIrAK\n", + "21oMRdsBwM4wyEJDPiji6fXmMlmmsCvOtr78Aj8gA+xKnVDFjoVlH7PPNvnMo0iZJruZeFy1B4T9\n", + "/2iVnlLy1r3LZhoykeyNXqaKEANWeqYl2HjpH92g+fHSONko5D2m4SRKJwFWFllUBg2RTQ3etVYS\n", + "PdQGNCLeaZwhH8zjnIe5Vuu46VBC79Le/PF0x5A18FileZQS8Adcvcamp8leUQ9dML537b7ARaSt\n", + "9Lyu3Sdke9BouNe3+hTyxzxAi1Setn//aNMjVtdKZIT0wLvPIMCsfe3gvhpNMtez9cWJYRUO4qU0\n", + "Dlg6h/pUIog+BzidDDvn6SZ9WUgEXhGZOFeOBYowQfwTGI3ac1V8O93aTpJwa/om7scQbOrwAjjK\n", + "gaYt9yqViBt3FWYRIoJJGYqmGJkf0tLvcymA+Hyayho8kg3J33tLzi7Gkd8xVzsn0AbjvoJ9u5le\n", + "OKsB4L1kcStddnytXouu9GStBCQSRLPeb+iGeZTwQ5uYY8D5fTAcb3C6Ob+B7IWRbbytzq93Kz0y\n", + "yYvbeUq1qJCNW3/zJeXeH+8yV69x5FRyM+55j6UAAAEdQZ8wRRUsO/8AYsUcQvOGOSSADI46r94B\n", + "/W+PEO3biH5wUahFid/4E5wZcJb1S+5KPsyD0qQEL2HibG5BPsDLysut2eDJfU6ijjP6zrYmNEWR\n", + "huQfgh9NsMVuoggiphkYt9ccXxVhYHn++9K8YAnkm28Kzp0jUWHgD2VeIoDjCfJPNnBqH+CERm3s\n", + "nubUQ9LmttVf/+MNJAJgtOFW5A6IBAcBpJtd5kPS+zJ8VxzguhOiD6Pf/zfgjMDUsehmT57QUanw\n", + "gbdNgBf1mSXZw3Czfs4swXmaj+42V39PQblTRJ5hVxxBfyBMHdtD+eP+pUlQP8pBAAnf3v75+Q0T\n", + "L19oeS5dx79IIwiodA3vtFf2KOiU2gODZqY3kJGizWNAAAAA3AGfT3RDfwB2j3tYlaKo3hdLneRM\n", + "Dlhayh8NourV4B4kYRi+kgAOdUf8hAGAI5XCPTeroAwXn8G2yGEphnv3FPeZqmLNmvgLgUkPciaQ\n", + "A3x0WVLvMk+lZn6cJdklOXHEnjNKsClw6wU0RbMDBk1zQUzYb/75rZ2h0N0KqL096XGATDutyhUZ\n", + "RVkyTgfbEgHdPAmzdroStgpcOUEN4xVVZX2E+XrryGs2/tIi+iUaglsBszkGSHUeEuoEpHc8PRHH\n", + "tDc+6s5rO2oABm+Gux/PUd+4yoXEBbF4DtdMIooAAAHGAZ9RakN/AHaNgkMVTymoPnXABzXUf7nM\n", + "R8KlDfCSlxubwbY5y13VVoGV2GO0t+vExf+APmeqLrIGM9X5aCQgGSaQJX4OQoECqyNRzFZQDLhW\n", + "KA4dfYJp7oYRPF8AMOzGYqm7AO7w7FtM2J0yD1XqM3LrKYS1dGZTAzMM0YXyhFuS7+8HWwRTCnl1\n", + "B1MtLMYaA8qvJY/AATH13D2takXBcx78I1sCsI+P57X6Q2Nh62/bggQuV3uhAAN0tyrIgbNQYVBH\n", + "gFwoUmXrxaEApAv0P2E40tM9SJDDcZe8DyE7ljCyxGjQA+gKJHzTkZCCQsmlxDg5It6wsdQ6cusN\n", + "DyWnlyoq3MMo7ugMYcm1YMEY73l36Y/R5wo4wUzuNvV2tJ3rSYBCfXsVjc5o1oA8OllKUpgpBG5u\n", + "9AavXOqCqjA07sUF9WlQ9JPrhiXa9bThYRp0lNBazKKlKwsBPK9zJ1/OayuptCCUOtFLyDYWpp2k\n", + "qNXWH8r0IpnJjxnQFcNmI3LKk+rH0vqX+48vd2BUqTcJ4rwX4e+V6oU1+lJyU8fmS4Kj/iQFUx5A\n", + "ntiGKLVWwqfkoYN2YexrEPVBTpKi81wf61aU8NAxYQAAAjdBm1ZJqEFsmUwII//+tSqAFj3B7fR5\n", + "G4ADaQx//3+BfZIcqzxSrotcVc8CLm7cBBc8JifUTg3KyGbsl0UtvUGR3t77PRffuzjjVfcKeiAp\n", + "EmDpLoqmMXTQU5wmHksjapt36fasfEiGyN1dOKyOI9nT0TFFL0pzQSss7Ux5GajOaQUF29zSIoeo\n", + "7hOusjWiFyZylISVuEBU8nCgDYn9P601XpFko2u3FAuYp/svCLJOzc9W7b14FY05eVZdhfmiv0Wm\n", + "d+i5ZPIv9mhB+8Cb50V0LQeFfsyfPeAABtfp/HIPaN+amWONE9vQ2YbC1JsqKljPbi6Vrd258gHB\n", + "PNyXvESqATfkK1Gnk0AWxo7XFr5y0Ce95pJr1n6gAd91M5RV5lL/XAgE7sYG4524aA+cXAa2XPdd\n", + "1BugfbN6YGWbktwAoVIXoUq7TnrmhBrw2FHa1aE9uMJerl9x/Rs847iKP+iuBUD2VIUOVa/G9Po0\n", + "ksPo1bHVIsITIKnrhXV1NabDgHAc5kIv+PJk6IroGA19oMw2I1d4rGiaYQZE9dmK1VRARJ9VXDBJ\n", + "Vlz3aoQhCyQZvwzvxWhVA1iU1RO1TWnJsppajNeO4Vg4/b+BSviIvrSwwqmjaRr8iuCpVTgz+ZJ6\n", + "95zLiSdnoIFqQJA1Hz4YR/KIOmAfhTTnHcdDelso1m8Bx2oHlzAOiYwR4NhSSRD6EhhCU2kXf5vn\n", + "vYdShk1Y3/pp+Wd9yZwIwTneJB0AoI0bbmfrtbbWj1oAAAFQQZ90RRUsO/8AVxVRwqizyog1fzvw\n", + "w3oFk0s5kH60rPhj0qbUv+9nJnU5H1hbksC+yivmpdt3FAylOp/Re8NoooEKQr4q7MX/kjNCB5zj\n", + "aCmG5E3TxVGWGCYMCsdEF1I+HuXX2a3wLCwf1iqCfznNMRG46GE6nIgxc91oY/zfMduLLCzyb8AQ\n", + "b20W2eRODsXd4+7XC1RndLreJ7Km543AdL1iUo99hYdoASXjyWRNv6wvJrmyFngIDlQOrLluZf/9\n", + "T8Y21pcggXpfTtvdj+B+3lZv29AFHkL2xGPZvyL4UyVUgb3U1DWd/iySeGzlK1IbRNu7obP1czi4\n", + "Rchm1nI/pS+cSuamJbhlQHIreF0u2/zcrSGkuOpbObSfAY//5j6RVfcQovw5wL1RQN0tcA1GtFxu\n", + "ZpovaLthGUkeOPh8iV5bEpupJR1R79Ew1sEkTDugAAABwQGfk3RDfwB2dNpntdq7wHtHkfExb8Mi\n", + "4AOIW+6weDVD4WeLhja/JOA5FtORnuW7CfHWfWrXcPJWyNJJfpx2maEKeggtR3RVEAdA1a1truYO\n", + "N3PBvt2C5hri51AyWveiUQtRNh8OhcT8b+NVPo5dLHlfN2wr8ZipKDuUP3k1md+EiPqVCrK5TuMQ\n", + "knvfHHEV8fXqrrFiHhWYrAGbSJdOrXgrQTN4JDv0LMwXs1Nl1nmEdfSgT5BF3DohYi4r2xGfiJcJ\n", + "KMZ1oPHaRBjgxhu40ZP5HqUG5rQWHD92UCH/Terh0cf4e0554mxHgDF9CBXD2Ey6LaV8LB9Jb9nA\n", + "f7tFFMQRIVaLiP+uig+B5OoeaCY5+GdEeHuY+ZE9jNToZ4yOUwNfysZaXJBrtfqEkQosI3EYRZQA\n", + "COu9BHjZjXsKjEmWe9Jj9yWusbXq4WMANyEJEPNSeDcqy2nLsc2OqSE4CgyCqy8blbRZqycUiZt/\n", + "3NpFflI5dk/7eeQ8Uo727U5FhceNm/3Tv/0N3CZNlPGV4f+3/HHJknpIjibzMw4AkTq3Lkxy1XZ+\n", + "FA9yAR3cZ0/eN1EscyudULe5dTvs1EvlYMWBAAABtgGflWpDfwB5Rz+lHWcxYALocP/IVGxKQ/5l\n", + "P8Y+UVeKYTw8iTn+GjVV8vbhgCZ5cI/70wvHdrfJYaZZyRIawh8+61+/vwo8HAkEyAQL0QVrU8Db\n", + "Z7+ORIRATWUQyS/LIyP8q4/O5rf7OuybqgrrJ5JQm3dvb5EYgnYLHCULt4xtpfvTsT5gEynxu9HL\n", + "Km20sO4q1oqcF4MPx2dj7xETa3veUfVJqfvwop/9NWsmPrdhY/wz7rinYt2HcWm7+ulSBZtWIRv3\n", + "yMRoNM+lyCvZDr0PaN2HfwYWOYr/NgyLM3qvI6TujkJkGWBIPuiFK/SHsSPx7iAMcrZ3CQvQC1rq\n", + "psLEx1Lx0vtWsdQAcjEYe6l7VHqUFbgcjcHAYPQIIgi8NauIxLhxUOQnkJo1mXO/e5w2N9AAHA22\n", + "RlXXsFU92TGe3GmYdLlI4OC3IklyabPhxs95veQzY6n0a2BnyANXxWrQG1vVVVAYgtb88NEdo6By\n", + "gCh1aEE1VpUTP0of4shaZpNk/2gd6T34r4uIClLqdADAAdaA4/epPc357p2Ro8OkrT9okATGaQDM\n", + "AYBiPC2kAQBkyn5ImAAAAdBBm5pJqEFsmUwII//+tSqAF4In0o7iUdIU6DQAMu59v/f4eNbK2my3\n", + "LFfU4bVvmOXvurgANJp+yhdNshfKZWyf1yiq02eNo25TtXkBg+c9UZquU5KtxkSr2wTyRJb5fWbg\n", + "+NL8Fosje7XYkSxYEiB3sVwPhHSvNWh2d4v6fN1lP9qvuUnfb1Bn+TdruqmJdM2vx9efbO5Th2CP\n", + "KiH3jeuRzoCzSIUG7cY38FVzT4nUIJdz+2KjjjJ0E7ZNKQ6lROaPqjFN4utrXaZfqGFX2nWmlL+h\n", + "PxS7plcEcSC1oWpbRWphWgodqD5c2VmFV0yO9NkxWYeDoEeaPVORAB/gqWAbIHdoZVHMBBV6fLyv\n", + "D3u5FppjGB4tzB+WC5jnXJKg0Sk3SkInESay6cwWUVJt/G4Tfg6wbMdEkCvCKlRosg/RTpp5P6wR\n", + "Z2iZfctuN2EQi36vtriULh4PVI/bw9ZXWlyhMpAYPlW3C1NvZrlJMNaSqGSSnh5cJMfrxHquXcAN\n", + "CTgojRhZ3tMe14Ny/HV3UfnpEJgrqxN8KZxlRpYS28Q96uqEu6NBBsBIIz0ei/Mg1x57c0aguL4j\n", + "dVBDXATm12Zi0uXfiRBRiIror0O2CDrlUQAAAPNBn7hFFSw7/wBgSQL3wIE2Tv5B6OJXPcoXMcSb\n", + "cE8qv/1v/uy5HaAJNUQCTSWlcVovOwe/GLZOdN2BNEgb1OlzNEinzyASzg3GuZ9zFeyJHe/zvxXW\n", + "qHgQlhmuH8QdE1M1s5tXy5mwAyoAiCrzupaN60ez6jWL/yRvGdGiPt3qJJLeMG60zAMKa7QhUJFJ\n", + "FMWUFrcLW6iQXx7VTZR7Qo0gz/aCe+BxT2h34J4bdpQTH59SHjOd2X4DMr2kpW5buE3EQBEKSUD8\n", + "yEiNy7MVRtsZHXt1V4Pb6TljTGXtC9pzGwEXtgadiRP8dhtDjxgpVN3IyoEAAAFOAZ/XdEN/AHkx\n", + "u7J3fsEfo6cXtbkNOd4swcOB3voAJyKHu0c0/MGiiYXv+2wca3XUwSOEG+s8df2rHPxj/J/Armyt\n", + "j86AAAWOWZsl8AgjGF9fWv1mQf9jrWNuA4APvfeLBFbZJZm7otp6Fc0DFqB0XCbEvLTkRU5ySc7e\n", + "Y4CD3ziWyxgWkLgxNxAV0V3rzOqUGhFxcTbBCJI75knYyulzgB9+SazwgLVSR2N8nND844Y7GLCN\n", + "0aeRWZgNIAWJkPPhP1VnSRo1jOpV+axgAXL8ExpNwIvLk+O8lekZ0/1o7sI+uJ46XyI2SuA6uJHd\n", + "bwUKNMI2qDKAM6f4kKlJLSQWqzXAi8hAQzI017i25Vpi5npQJ4TsJeyOHRvmO1wY5ZnIEZHyhgB4\n", + "IoLWrdA5opbAou9XxH6m1F6osqepeJLd97Dr7+5BqWzoHoOLhOxNwAAAAQ4Bn9lqQ38Ah1fDGltb\n", + "SoFNBABy4LNe514R+dnaDTYn5E46OmsRrJgYyAm1lSXdflAXI1+CFQXE0A4eKb0poyZSLaaXfRBJ\n", + "r/tA3jW8xYt/UxFDszVrqnPHP/Ny6pw3mJ+pwWr+YYAHxNaLyZj85nxRNPFMUkOr96iCB+MslYrg\n", + "cr/vUoZCrrFka9nw08yFJlyN4Ky9KHUYJOXDrBIiz8KQQaHFalCe3rENKk9raHLB9E2PdI37xydW\n", + "9R3Ktqa3KW5rMJCOoArO2/3trkkCh+/FDlbsei4VdbDQ32DjCaAkDFjCyuqOJNsi8nSI2KDSRFCB\n", + "83l81kCObhPemVMTlMBQzSDvOtDFUtuVwHtirD8AAAFqQZvcSahBbJlMFEwR//61KoAWweTusUEY\n", + "AFR7WLigAceU/KgvW9LBBRTRioW652v1Xpv5tYMFhkRmmlUca4/8lM9NJwOZFgbdLq3dhRjr1SQ+\n", + "iitgTnIKVe77qt/yWy3INzcVxffYfGucVy2ypyvLSUZVvVzu37Ufe4d1uKQAC1EE3Wwzkx7sEK4N\n", + "QwJyCdTZZnLiyrlEXcLAMbB36CvMtmCiaP8XPpa1U2RaJxnBB9qYeP0+JCORflaC8m/hyWfMppd0\n", + "XeCFuAYTEakC9vO4HVF02QH4GZZigg7j7bXnvstEtP5QgYZViZcOoAaQGKtWm3PCHoS8mKWfCUk8\n", + "ZLC6z2a10V0U2DavVH2m02W1Lc4/2WzrwUTHr66DOaP+urnPdabeHdXruv1HJ087InGSipJtxGko\n", + "4rppNbdlP4z6g2o/ksCKcSZ76uS1diKM/39wzVYDu1tkCD1lomve9NoQwUToKqCn30PDqMAAAAEr\n", + "AZ/7akN/AIdka2XuDkeawxOj/BZhZtP+kNbRABb4RmWT8vSOMSH2HVKuz5/n3pn38gQM6YQqY5bV\n", + "v8KsLMWKt//3BpX7BUiSjA/GsXEpiGachc2o+KqjjRfujy3SLc+TvzNfgePwT9w0Jj9Y8j6ORxA7\n", + "13x9/iM5Lx1s2OQQyRluiOYKxXDE9QjNulPCcMLJFKpvAfnZmzl0pzzHw/ANcBEDhABHQ9ftCkUs\n", + "Q4pQOQF20mJ1++bXoRcUz/lR79ACwohpzpGuaQCknCVhUL3lnnyQzloB0PAIRq1VnOd+y8D18t8/\n", + "IEva3L9FTrRi90eT/2pNxjMaqrOmFzrhjd2kmSd3YBlll+A3KrjDn/HtXx8SDjztM7Km7BEd2LVO\n", + "U1pVGn0+C8gCov9gxoEAAAIMQZvgSeEKUmUwII///rUqgBet471BV4xl2QAFRvb+6Uilj9hVaCt9\n", + "oXOXB19FM5G4bNDJAOl9w7HrxMOF2dPOUf977Rp9NoBObCR9cN42Ht77Y+l36qfp5SrWPFz3DG9k\n", + "Uks1s5yfRvMME5RxPYk9+qohbe5TR7z2WNWBJjaTvhnu4485WU3BaTyIbA4BRRdj0/JwsbCXRVZy\n", + "OMmFdXnFdxhNGZ5JMCQy+ip435WTv8KevLzG3OUTxX5d8x0gaiQZdaPwNC9GVrgmtqTc0z7He5Hx\n", + "p/UnXiE+WgHU095CwXga4AbeOtQbj0tjxKUoS9sAoJ5fyTlHv9FnU0ujgUuoA3Kj0ma5qF69zgnv\n", + "MTXEIqf8zuYuInk435YB6s5Aa1W77q49/ZLR70JdKU9F42nWnuaGIFvaX8JNp0NTGvA0s1VSOWIl\n", + "YVdpY6hSPbDqLYXO/LE7X1D3sWpexh+/kcA2B6pYDzx14bD7OD1f9pMDWxIrW6BpNH75M54gOMY1\n", + "SxoTsfh6KVoyFK4Yqd6lPKCLY4O17tm0vzqLEva8zNeuM7b2yHKwMHpqK8FV5yaEer9Zd+uSgIqd\n", + "eftECExc0GDPrda1mDLPyRR8iDjZRvRS/EElnceTaWiUEonB934ThxItQqnJINdKSyNdNwx44Jgq\n", + "H9/Zh55FLA3sdVDr+1aesKMfNmYnbwaje7GN0y0AAAENQZ4eRTRMO/8AYEUc98FD5/CYkGD6VZTK\n", + "7qaMD8JeD5Yvz1s+LaCSFWcn3aLtkXWLu76WBTjEp2boTz2lISGgYIiIhTqGBdSAvn4GaApcqQ2+\n", + "sy0LjwIg9aZXDdjP9AWFTV1H8wY3dWCf+Rn8X8p7dsAFRxXZ4015PG0t6STtIq5DOqARSPJ32oCq\n", + "OenP2L2rQhT0bU7kBXZqDOvuedMFko4K8dbR3EOKtstAjt1gHGNubjQIVeNhJsdrdMtXEY7juX3P\n", + "NuPteAILXrR8S3R5mIOtuZ+vWEUdS+Inr7FnZsbQiIv9i7KDzU2m3LJLNdjmArFBBLgFXYHDvQmL\n", + "9VT51Mb8gx1TyNar/CPWDggAAADyAZ49dEN/AInJdfYNr4ilmYSAMFB4GADpypoeWWXE3q20mGL8\n", + "wfGmH6ZgcbtTXJWZn5/uB2IPeQFG/rqNYZ/bmIUcKhccFRuPa9wOgu4Qnm9oi81y+ChWQK1KoKDK\n", + "TWWDeg/SDhV8w/q9dFY0rcekgnjPKbKFgzK+IO7hoMF7vhpMoVCqvwMtBaesBfF4bzxIufyftMba\n", + "VRaJWuZpM22/FtH8FxujQ6EjGNr9PHZg3rsxXbkYHRqZvH6RGypNdfKRL4serPMKtCeuCWEKaj1Z\n", + "h+pr+ULdNvwpLLHfA3OCu3Ql8v/sLDD/O1LVB9ug+l/wHpAAAAGVAZ4/akN/AInJdjcgUcZACEqh\n", + "GvWiTtr19IbQdv8WE1dBOa+lNipi00vM+C9W8F7IDH0aaS+KKFaekfOwUNG520lVemVKNYbjnPl7\n", + "LimE+s4N2NJ5SYT5+XRMb+vTvKCkG/By5wQO/WbZo9HorEm10+Tu4CVIj+2Ky5hDZl+kA6mkBK7E\n", + "3LwAW+4rGYiO9JH1BLFQj0ZOJq0ybrdVynOYOw8TudsCI+I3fiT5nmYCkIO1N7h++s67fASBLfgP\n", + "CYo7yLNwfifRM3ay+JhoRmwX5tGJ8l9w676Zo1wDaqZ0Q5guAYSxSJk2jHShR6LxlZmIVJnq7S00\n", + "iBOM0mxomzMhjpxeX6zqy/aA2SEREi4ulxZsEvlIWhLQ5YFv6LMkVEh9RITRQOsKGEls7Y4eSRWc\n", + "f23FGWOVxL2MZUmPGVh++Xygx19XCiXwoatt/s2T7zGfLkQ2IBiMKXoeDb7yiR4q+0v6UjACWT2H\n", + "kOIRMpG/B4KQPsfMRT0Rk3cAwV9dNnKm4XTlo9P9TmyT71B/Greq+KvhEBDxAAACJkGaJEmoQWiZ\n", + "TAgj//61KoAW5ktFwTkgtAAhBassVgP2a7WSOTniW7GlpUC5YARIimzpboyDKn/53KIxVBS+A0NS\n", + "3NuuWMzq53zfHvhoSdYO4dYooBUDN2VkLpVK3v3kQo1FoE02X3cyV2j6ziOTJORgWGzqU5k0XKJO\n", + "1VCPDS1gJclQYem5NlGAENmSiR9I8XvNQLGvpLGF/2+aU31xCZzIPp4tUxyLu/gVqq+6L5DezfDz\n", + "gPP3+vv4JFttE5Nyc7LysmCaQfUhi6zPymHmdLjs3bZdma4hV61UMMsGBNZfYf2GUkV1dVZ9kkfz\n", + "RyUYJPFdwjA5S++T8sc03o81MYXnXYkO9hGiG6RRLRRV2fPSgGhghnaqxRhYVQiuVS0ENIpjxqqc\n", + "KBEaAMs1VoaLKEOrNhZ8yB1VLLV9KSiM7/prkkNKRuNLp0WeTv2eHtXhIdAfhKb+ic7Pb48CqpOl\n", + "FnnbgphlxDaS1dplrA4VxMNzEL/27xNMQzhuRvnSDNb60j/kSJHw5x2JG6G/VwCoVAfFrZll45AB\n", + "Puajv4y9+7flMd/pR8Rg9UAn+cey+vNCcCbbn7FNSWq2hl9cymk4fwW6iqBgiFEQ7YZtyDoNCyYz\n", + "KAnW0gvHCg+5n6+qxC+xDS291Y4JfSW927ZZudU0tXxvupwcKf6fDXxz/bqsOMvxj6Y81+e6Dezh\n", + "B2/8nCpk1Qc7N5s0JoStEQ8+K2ir0vIXayhFQIgAAAEeQZ5CRREsO/8AZTZTJbuKD3PiQhYpzA/Q\n", + "3Iqsld8XUz3sHppFsAHZevvXPBLN2cIUd+YCbEEH6MplVFEcbuDDV0dnlBcrCNrbp3+CAOdBsr6h\n", + "0YfLGDPxHlFlUCi4qTS1o0TT2Jzkq8/O+TU7SSImG1EjEmOGpKvxjn7KxERq2Pbd/0y1sNHk5hiQ\n", + "eJwHwc7Z19aIrWes4h3UYQqHeU6kfCpUHVgnGubU2A0Xjg0UrouNSumFogz0StLk4fuhL5slF3Bb\n", + "3NpP7YhgiVLV0FNM21/pfbXvRQFzmliOaZuScgePqa02nvOdEHEpGVRPLCGL/tvzSkZqhXResmQg\n", + "1qZ/TxlvqjWYqPRThBIk2nP66jbd6NLagdWz1BtbrwB3TQAAAVkBnmF0Q38Ajz7dDL7wKLyRAA5r\n", + "u/5Co2KbB/AnQg3XvWeaImUuto8KuobiZ5Rpi0jf/+r5lFprj/mYxpQ5OwqjQqFG0eXwqi1D6M23\n", + "HLH/3LvgYXkbAAGr9uWkQaEU+TeJ38WNXodDC29t8Y0uYEpwNzyC6FqtgkCyDYDpd/nESpdVRRJh\n", + "15SV0TP88AKwZsT7yWH2r5gpJv8AhXnnWmKJ/WMwiS/2+Kf3ikj614P+BDohXhMYGO4GSZ19EkRI\n", + "RjwO1zoy3Umd4iOMuBBPzevAs74sU7IUdkUF24rNAstoyqnAUgY510L3SgPXbZmJYMv+tRpT7ZuM\n", + "oLxE5ACIQ+eHStmGZgh2P1nvrIaZRiBxoWZ1B+DDOtu5OZpc7LbajGP/oy8HbEFyJIcGXHGB5VXY\n", + "HnskMmabuu5xyFIJcVaqbGg3TlqrbBE29OX6xO7K38oavU/okVlIM+AAAAGEAZ5jakN/AIdXv9ZL\n", + "/wCpeCQF0zyG8897iu+TVNq8xXl3pE8eXm424VBKoADmOQ/RgBgC6Y0IzpqUKPVKwCZafdEIuhUv\n", + "zhgtxewRpr3F4VdMy9NUqqvPfGroLPxDW64Af18RtCEv8t7amX9ezvEWK8AgZjHjHXeVi2k8dp4r\n", + "TuMjdngEOGe6y0V0qXE0vJudyGSblaiStnW6rV0e34JxbdN3Qbajy6ozlLfOkq7Wqx1iLXxa4foY\n", + "IPBIjzxdye8gOjZW7bP0axd+wppVHkXrrvuxUf9dp18AanJIIFv6MCm6ujRO2wyu4ZfSbZp/KVFm\n", + "xvxpBAJyjKSdCoPxWylEDyms9NAmwAADmUiy6WUOIsiAC130X9MRKfeLHi3miJh/YDGeINuX+P+e\n", + "NWBXxp3RqAzo1eISPcPztmgXUHCSN2VRpnCOFQoF4yyryK4v7s2U4a7V5e2sVJBhb7kguiVFACK3\n", + "rbLSCnWI4OCs6u017nghnGW3Juq0rF80iqmo5QCt19S62wAAAkZBmmhJqEFsmUwII//+tSqAFu/w\n", + "HjJpMYeKfGxaFh4NwH9VzFzipiNnWLhZf3lim8qQP0NcWviT9hCfSjxxrnYEE59yPQn7u6+tCr/u\n", + "vn8/iyWB73TxWIDTyqwOWzo0R8Wj7McP4QWP8yE0svd//Wkug5+3cHmcpP/ONbeBn+TAQ0VzErlc\n", + "2hXFLnmGW7EB004qvGi/S7JfG21T+V5Sx9Nre0PuomioWltV0uJSYiMg18UwZktQhoyeO+qpPgky\n", + "U9/xX6NUrUyAfCz03v4wSV58lpzV7BxftApX8ZGWBx2zWQV/YeOCEWbmbHqvN18Jd5FxK1iHRqe+\n", + "nBGg6SyBQEQQfCMxCo37AXM212ulRN9X2fE3P9HkhvkaOxQZ5AElyFJ4BlaM9J8bcUgOX6NS6Cqb\n", + "n7IHMcCIPjAIJ36atWVr0EheDYyrwatT/sRxqfSoF0RgoVqtGqstMXZF7XACu2N9LDV5Ss0B+mSl\n", + "kJJqGxc50wazbtpofP341QOLrRCoQigLO2IFkJyqTpln4FgoWIMbx8x6cKkFmIESXv7mZEx6LOrL\n", + "ggZa/EdzllkBPCO/+zBjmey1Y55MrbMpoidNDpdQ6yZ4UDU0ai3HtghNjtrUaVDC+dCrSCASLB02\n", + "bO819PX27qwUTWW1MCrVhUzQkUkht4Xa4bdnUW7zTudPa++EPxUMVY36vPDJoCGilCgIXzTOV6S9\n", + "OVTh4+OA6S/XkcoA6ZjbQLERX5kZSQMoFJs4bPot93titzpDSKAhc1QMx6eKK6Ol2IEAAAEkQZ6G\n", + "RRUsO/8AZUEFdKFRxHYcrgnLV1IJewAc5dAL6/Pr5YWcZb4ejev9b/lpY1ea5Xk1AlTe44c3rPkF\n", + "DXI6yAdEC7kxPh5StAse03AARSF2nro+Dr5bfPJyYF/ERJ9NScPmUIVihvTCsyh5qmuoAH9P7eCu\n", + "Y8rdH1hF/pTSa+Z1tzZc8gwGtgV/YsMtlWLs3VbLWxt2KTDW5Y2b0HA6zgNn25rXu72r6iiN5aw7\n", + "sjFipq/8rjgHE9K0EK2Opn+0SPK2Rbo28aoNdC9V8VxW1CpMNxKjFOs8YmQmJE6Qtkw+Uo5mh3ic\n", + "7Ng6Xje5wAF7a8Iyr8DMIwvMZnnVp6ilQ1B/LSGEPncviRIHH8w83Grtt0CsL1L2isuyMboY11N9\n", + "lxQPpwAAAUABnqV0Q38Aiz6zZgMl5b2XXQAXQ9yHCqNv7FVD9CxHdTnw5pqRTLAoFiba5ss3lqXG\n", + "QCf4/o32jzmzNKjZDN2ghdo3OS7n/NFKTMs4yX0NTqaEhdnVRvrbcGvcKo0NYMgzE8UNwneueU22\n", + "1vpuKbOkae4P82iS9XSi8TlOPcF8mmD+n9qfVTXzL4r0M/s5xxZempvnxqhz38EgmSM/Zw7kEyiv\n", + "giyuP/YjNhFl3FVcOSLiQTCj+F0nLUE7lia+UkuO/YNBXwUKZKD8Add8BG6ZTC4bD/RSktc7uv8w\n", + "NB82AXgnpuELTB2xZFOLAYJncjo03/3uAK678Cl8cw8fzlbnSpp5eUkHacCUtAY9LPrz/OMf2bA9\n", + "vBE2eUwrxz/W0Sg0tjzkUrpnJSF+xYsA2fgRolT6A0NA++mVN8PJVhaGzQAAAX4BnqdqQ38Aj1eg\n", + "HO2BrhbSJp3bjAA7Lyx/X3Tt1hQ2T/wP93u+Km2fQtCsS47kHT/v6cxSu0EEWzwOVr17m7uMIt8s\n", + "rOS2NL0s+wNbNsQiUhFGWcubxLdtukca9QFTdaQjRXuW15l7gz2QnuVPe/r9SLMinrQ8TAT7c4JB\n", + "GrUpwbYY2wvPKUw4NOIKdjGz2TGxM02Yhqm+YQD7nu+MPeXg/5dBf+XeKfPK+RchTbfnRfx28pUm\n", + "+MUq+ynmpWVmmfO3TbD8gZCbZRUeK4LOH5lP3nvVvkbZlQVhN5vPlxxNouZsDfsmprxmWrHzH3vb\n", + "E+c7VsDA88L9wCH+ZmQGzxFjyOQ8cz4P9rsZSuU8vQS1h6fmk4XXUosrmweEGKJT/Sv5qb0OG8e9\n", + "voRxFaPrroiqkALWSnA5n4zcQMwfY/xXX1aR5rslt9ItB406qJIsbsrkl8pXUe2CwOVm9B72bhd1\n", + "lqsCRNktqyPMF/Ek4JsxscPvDjbSqbQZL+uT8zjgAAAB5EGarEmoQWyZTAgj//61KoAZQB+OVG5p\n", + "SZHABUb2//v8PGtlbWZ+A0oGGFPTAdgmU2TFbsuJ6mwUCouNe8f1I2ythN04JSJ5lx+ik6KpnC91\n", + "1FD3eD5Jit+kJIg5holbnldcijL50GRMV+Tt0L65TPBxqSAUdrQu+eLUTHPpJCL4CV5RJau8pEIv\n", + "uK3a7QA/UMQ/nrDjeZ6jqf1BF3JjbyaeIc5drvnYbR6lQ0gBIzp/QRU9xrHm8FESnIe42aooWDJ9\n", + "bVMccs59QBQd45WisW0MXV7NFtyepgfK7biPJN57MDsWL2A4LYHAXH6f6In3GVsSrYQ2HUKGlxpv\n", + "Yf/Xvk0pBnHsuIEsslXTjxwTTzuRb2YT7QCJp6yHiUVL67n8RfvHMNoHfUzP4rVgPSXcPL8FOP2d\n", + "F8GxovHNOmsOSUyc+t9OZXQFF+4FJNSN23FsgARohBEJ3c1u0ax3ACLYlwfCd3/U1mT29ftZkWMR\n", + "uj01t9v2AGHvgKM29X2Vs/ALzLNDd2OM9z+AC4TlcpgcRujIhnjHf17Je/8RMBqJCZtdfrFmz6AW\n", + "Z/aNIv/p/WX6adpvStFWxoDAnf+Tai9COS20TO4GHDviQkpMo6tbNTk4tiYWsmvBNq5u/aO08r2y\n", + "Bs1eH2kAAAD6QZ7KRRUsO/8AZUj9pUTz7rNMoHjJ4gSsLw2wABNFEVCVBZ8at73oa3C8UmeDMVba\n", + "M3uHP8p2EFDXTkl9EiChbxZZgpuvefKfc50lYhoTJ/7H62X0Z9NX2I7S32WT1XJeJtD32zfVBu3K\n", + "VmE+30x6+W2pKnyMM0ZejDKLq8WyIyi+9rC0QVVyU0N739nDCyt6aqRfMfSdljqTnwOmgDB5pHyK\n", + "U8Nf/BZxnIET5uBVX/VcS4bjmT9sCYYwmAz5vBy8cv5J53FYPh0/wF7kP2myhm8SfTnmNtpTej0y\n", + "JjLbrdGSBUAu+lwbCsr/YdOCYrxvvrklZP4j4s5VlQAAAgYBnul0Q38Aiz6zZf6skuDOogA4jl3V\n", + "YKO0NncAuqtob34dJ/eVmQtCFk2jxP+6gBUwoAJ5d6wKEpypNd+AlIf83kNIAAC8trXyGAv3zzzV\n", + "tAa7kzCHOXS39Rxic+qZEHcHH0Hx0iIZnH1UNeoS6dQYQqolDkQpOXG8nP6tDCpAEYSQsJzo5kch\n", + "Xf9jICMUCBjMQXeVS1i3FdA07mrKCBowVzEdee9WvqvXV7KuMTufiL0hA8BHvtD6VFvEZ6eiqgvN\n", + "8RNM5cYXQ2i+4Lx4R2QlAIN1NNxqM8GvSjSh/rgipqY8DwHJh8p9Jbu0Zs+w86pgxJN8m/cvWxRZ\n", + "yFAtI7sBhDbJnNXx83ll0o93YVJhxi0TxWXPf6PlHZeEyvr6QOF2VVafQjsZUg34P/p6tj3lkAer\n", + "aZouLIrbfbTrpoGdtXuXR2qC418s780GZsUBVTlvppC7dgGYqQzB5daoV61BoiIg6tQyG20Yk/Ib\n", + "TtwSJmeU5Eiu/zRo0bpbU2jgV79WVCB/SVzxsmoD1jJEhzN1FHxsbajOijl9Vp76GofsezNr+37n\n", + "UWWhPPzCk1rCLQgaI34ekcMUWq/vBK2WDe7wKACe/5M5UglN5Ct9Orsd3SfYPc0336usW56marFA\n", + "xW2XgVLc1GludnoFyQrT+oASHSl68jJc1j3I4WTIeU/p+eW8RtUF4AAAAR4BnutqQ38Ai1egJmdK\n", + "YqnGBlYUAF9obzNVJ+s4Wyt0Rq0YuZmzKSClvCu/741bUzMW9+2RqBxHf8xROd9WCD2DFO6m3iiG\n", + "ZOgLMC6WQsGlrWDKBATBQkW8M70y/ztO1ZzNQj1ow5FREW75+T8qWeYnaEkP0sDPfhS/8A++EHpT\n", + "ONUZpoNHugOpCj8EFvE/MnQhkWbqDB+V4zYJeD+V1h9PGTTPeM5Ykyq4ZMi+8E5Gka9dd2CFXMaQ\n", + "M99mRo+FOH0+y87A4U4JusoMgrnGwBHn7tNdR1Jgk+wKYqmIwBj2jGPnQFJXhHhE3ZkpIjaeakM2\n", + "8MH5c8xC359KRjK1nfiZHGSkxS98YPps7lGGiAJ2WdM/l0XaVpItX1VPHy/wAAACGUGa8EmoQWyZ\n", + "TAgj//61KoAWNzc2A41R+LAApun++OIZUz7EikV/szjfxvYPLx+f9K2/F/he8DHawkBMdV2wRLxA\n", + "t50GIuRUSWE/39Xo4nAQqkjDTJdufKMgNIx0erMAcY2QA5ejjVo1tlzncJOxCqGpuGwA+5/4IKyu\n", + "bmTzdPecTw0ZdpVPq5j/sb/uUTmyS5oriK2QJUn4uMhurpWU0pM90BFHxmx/55iJQnC/E4AiRjGv\n", + "TSfvy9eol7L6q3/AmWDGKQmta5h6TQecJSS7keMMTmFMkcgh+dQEUTFbphGIZpTz6vxfkWPPyqpQ\n", + "VmS0gectGBeLssajkGiu1ivhXeMUvGnpqjpc6XSD8FJ8sVdfwdsse9JozsVq/t5YFq5+AnEYcopl\n", + "mlIiLVwif6/glDa/FvPVZyUrYuYY9L3TA7eEHe1IcHWSOPxpnafEFBrVGoeZPrbfymiVcHOQ/3CX\n", + "aGrpVwdWrmOHr8jLuajUxWOW37ajHobcyT1hYWMxRTx80fZmsfvsrNw/Nztdx7LidHGE8jPZ4gQZ\n", + "DABlByR/bof6mTmjqkfbsR1PCXy4RDNnn9nCnaSnb8pCApsF6YsDTv0+UmVzx2ZPSdm2LhZIqOim\n", + "mhiXHWt+ZE1dnYkLwTdsgNYEeAUTjY5XG25CAykSMfKGwGWeeOwqKmLAqTmb7mCXXxxpy4+bbELo\n", + "RAxOLFOR7z+Rlt4VIVMH4QAAASRBnw5FFSw7/wBiyP2mEJvZyVx6ACpM7CM8ZBKHKR5j7ndOem+L\n", + "X5lQTliSlHrc19blDxI+BarmPxVVRFr/CorqLGvI+vHNUfF9L5rOth1seL+LchCRD6bYXJMlctoQ\n", + "KBnrSfN8OsFA3rCX0rxhgXIKgdEDuCNRYd4XCiw0AyO8VPwgQ3UKQOwN4T9AdwOVZht3xWSjlGSY\n", + "LTfR+DOcni9vpFUI/V99yTFNeriW/Ezi0Mmb4Xp+UrrTAn+/oqePQryHATZ97i1I4TzdZJ6ol421\n", + "ZZiGDIa6I2z+mz36WJISXYfn5PcaqZon5evy7wkHdXdLSXQuyy6RoW3UMK1kv4eYGMx6MEUBV881\n", + "1DxJ4Az2tfQhJ60iq3lK6xGARpoGTWiGA3pBAAABAwGfLXRDfwCHPtdry+v+2nyY2Sk+gF5YW5HN\n", + "XoAL6QRR4alJgXnPRJGLu1H/XzBsCOVwj2OHZ7/Befz18ioG7PdTUWTo/DFmzXwFwKSHq5MESJ/K\n", + "+czoaBaMU0SilMUvvgF9NaNkzEcYOJjCpUUkl+lvc9iWY7aNcNT0YkO2YuPLl1ZJa6XpXyzgvJfC\n", + "YABMMMlHP4hWdgac8C4JyYJle4OEiXwhanMhhDIkpZpmZqqPP6iXGzuSTb+0ZDMJHqoDGqJmkb8S\n", + "IJuvyZGNE4panvJTPVd9f7g4/aXxMPm3Cn3wfT3mTthI056NzanOEWKjM1qGy4olpTOi0cV3zUKu\n", + "VGl1k7sAAAHXAZ8vakN/AInJcXImIY9AsY+/nZAB2XUf7nMR8KlDfCSlxubwbY5yyAvaK6FdhjtI\n", + "iTEMX/gD5nqi6yBjPV+WgerMVdQiwmsTWCh4ZDRMTEvRNiTK06p6H4BM93iWfwAaKh8Gz9Gaukwy\n", + "InHLEZ0yD1XqM2twrrM9K/zMIWUOeN0Z6Qpdges4mCaPjYBUMA0KTxEuHmES85gUYlt0s0Ks9Nu+\n", + "2hfyb2t0rmyvRs70WgBBgYrdeTZMCwmoCbRHPK4oxsSlCang/p1gu/DmbjnwYRln/v7ufz7R3gdP\n", + "Fr7XrHKEZc+f98DBxQMF82PBbmDGtLAQXHwptz6g5mqHfaJhvvgj78jkqTGrQ4WXMBaKzHGNvGYe\n", + "XIR0bHtcMMQd0uz0UHs+NS8bhlZ93PGBn0DI4S7X4qFOiND2PCIg5ogjbfFqU4Kuh5oLH4L3vi2E\n", + "bzWP7DaofhwjMqjCqAvZAgznNJDsvnJzQxJ6Pqjj2ny04t1drdQRUisSLN+PcLenLQZbe401Xg2H\n", + "yhW845ouHrITGSqb9EOEeoN97gj42PjsdYRMVLRDVvCV2BOAqdLbEmICPHZnyy75qPsejK7duPuc\n", + "fJ9rEnjynB/HxYz7zf/RM6xyYbzIoc3AAAACEkGbNEmoQWyZTAgj//61KoAbj1lLPyvb6PAZgAh9\n", + "7f/9/gX2SHKs8Uq31kdycpXc3bf6XPCYn1E4Nyshm7SbxYTXwR3t77AgzFtBuE6fBgZeY48yXmAW\n", + "rqOr3iMlgArjVOjemrjz47grY/T9rKmhvhaqPi8pvZTzkzZCl+tV6nzXVbBFw15yZW9xk2z611V7\n", + "GITjv5GH4Oi/06B5IbjEMVKEcRpvt893HwIyUBXniM9I90uh0TBxOedvsxxE2iLZsr/m/GNXryb+\n", + "9as6btju6GU5FfXHAHKy97PxI2Rac5Rx/FoPiuKEecRx7EQrDfRmlggPPP63oMY4jkBeTzC7Drwp\n", + "8ik2Z4rhoAMWlcRPfXCI56oe4Jt09oRInuaD3ww9/jGDjhHIXGbNYM/s5UG1XuYLCqaLxESIyPG/\n", + "eNnETthXX/QZDvDCFX3YINANkqDvHlUQ+vcUvksaWF/g1aVcMu45c8BoP1coWBAVWVE6iyDMwfYl\n", + "RYTcnNfp26mpOfqiSJnYH+AFj0qGJttgeZBuJCzdV4F5EDreo0WWAiq/0jdXljJ+ZxDij/UazQOM\n", + "0ct15Q7rTOqLKy+lpOVa/koSWj06e8eyy0wY1FBSVaROGYbDgXze1QzYiVyP6+WTk1fjz+Do+J+/\n", + "TxVlHJsfUOz0tbPJ3R4cSjRVigTxPg9VAYynpzzMlIr0/pCOGd4XYyl3SGTwAAABOUGfUkUVLDv/\n", + "AGU2ltMhgssRVFnYDYHdfwUIOpARUIP1pWfDHpU2pf97OTOpyP7SrW+j72yMHgCy10/KQJvVenOE\n", + "eMrSHUfyq6lVIsdEDgl0M+/NXx5VMpg+IZB+I7xozsY2f0ARjiAjA8ZSqG32YEqaGwpGp+vfKL3P\n", + "hav1CfnyaUmopPCa0Y5ww/PZN4YINPOwE+Gg36kaKP/ME/B0d8v00CzvLXmI8pIa3TqrGIa7PF4X\n", + "8miGO6oXkRH45ag0gFdgkGj+BD1PvtIptIkuqTa5jzG/NewDN9cCfws/hjc474K6NoCTyr++7Tth\n", + "LSIM60DcVje0csuhEMwOmCNob99l/AJp/9hMVsVsEaxUNsWBZFMKnZoLJU/ljkNlTtF1zcUwJoZD\n", + "oLTT6FmWVzlFnyfjiJdVIqMAAYsAAAIPAZ9xdEN/AI8+s1VkrBucudR5tN1L4cUDsugAOgW+6weD\n", + "VD4WeLhja/JOA5FtORnuW7CfHWfWrXcPJlwit0rQdaNL8wYmpMOBxVMKErdopYTnWfb0EZST9ZFP\n", + "kGeAI5wBNyE7pmk7U/hz6/Uncd5yONsvInzdtLdlFGIUuwPsZsiC4nxcPKJ4ER73zqMcPC62dMwB\n", + "YeP2JTSzcWxmsY8AuUeSUMff3wugzCWo2dZWIqj8MEevc9dnI6e4RX4rfqOmeKfJ7QFxuPllAOzz\n", + "FkyERujhdmr2mdRExctZgI01tg+iF/NwBCqP+hQ0BZaq12BgDPwBcWyuj8PXGo/75aroqbic3atK\n", + "78lcQoP6TccBH3q4TpJbdFKZCXZFrS7Hh71ZQxzuADlZ8DDRzGHyvFJs8+7LX0Z3SVEeli/7hzNR\n", + "3en2BovQV52x/rwTox00ojUHS89/I6QK5rr9xZ5z1Evdog7ewBETCofR8FQPxE+2X576ofb9SYpa\n", + "RU+FFWJ4WPQBj/u1ljXdmoINHOgs90YcpGG37DHSgRaxKh3h9samVWdsr/7ZPH7Krx9nfE8zJoXc\n", + "5Frf0sUOO22BhUTf6MatKarbA54SuNAmIi3ejRZKQJ4XCjhpsLBrmw33yy9Nk6OT0LCi0ELysL29\n", + "OvbOK/J+/iRz4bP6v+/3ppYXG9MzSEeggmS96wm6yOsevJy9wrAAAAHWAZ9zakN/AIdXwVSZADwX\n", + "ZeAC6HD/yFRsSkP+ZT/GPlFXimE8PIk5/ho1VfL2NNL2pqViOd6YYnwc7ksNMs5IkNYQ+fdC2XMm\n", + "GpZcBQdS+anJcAkZpOHFxqdIo1pLhI3h3bcsWXXBd+BTXZhbA2JSmhm8EWBGqSBNaO0U3Qcdcea5\n", + "428f3xthr08dSK0oFN+HNErgBuKfL3JZNShDHaW66u0MaG1B/cF2Go8z1F6LGKUAmsy0D/C2CM25\n", + "q38c827dgYTnZjZnTFxlPuxm+JuWvYpOeWyy3J/wjV/USVL+4BKz61/Ccy+EH/JkQUqRmUOtvYei\n", + "XxTdexyug9nI6kyTGc2H3hy0C3uFxKKFKo9PfiwDCQWhQ1+vZIsII4FYexn+pQbkz5kmdlWKB5Lx\n", + "ONpNVggWvIuTYEFI34NTLTOf285YYkebB68ywIJ5f1uX/OXMZ5RxH3gjNZ8mKLNX9suvs06qOt/Q\n", + "e2ZfZ7Orgt/l3O7GLxwWvzugIsO88I1KhpZhgYDdYZ//1lVBcwG/tKVYjF1obqjtyFctY9LPGIag\n", + "318ehZmIvkhW9djj90e+pnWknudbQDv3Os17s3l7qFADdqSGqYyGaSU47a6O12HCRSwmepV1bewA\n", + "AAIrQZt4SahBbJlMCCH//qpVAC8LE+AX+ndLRI9AAL65x3/f4eNbK2tvWi3seP5qm31GHdf4edmk\n", + "0/ZKv9BuxjUGH/qoYxXDUlaWZFHb65x0lomfbckqRBtklU+1LGTmYtvnPAbKnUSAh/jTBATZpFND\n", + "l6V6ofQ5PTBcFjOWwgI6YqalXUkmqnN6g77O4xvodhM7XQWhsA44ADmvatn61wvReF9d9MqoCN9N\n", + "Twpkx2kbbrSoHJrSyqidCsv+e2gnLoWDEdLGn/42++dseweQBj40iKRQ7paDrpDRwTZVjGQJ+52c\n", + "gaUSUp5A/cAn4FgESmp/sZ0NpfD9/7ZAmCbSUfPUar6ndxZ3XG2DXWcNFu473rzFQZNpJnXg/Pfh\n", + "QCQDuu/iX2Vi2NjGs1QVI3BReUxvD8Z/YeLy6w0jDh9dcJGJdKoNjb9Epdy5r0lFeFb9L8AWhdEd\n", + "sGreMPdTiMRlq+JOqjdogseyQTcuDo5iesxIsb0dhY+P9VqSJtTxyPO42dn6TXPZDgt1vROlp+Ic\n", + "VTutbib7FY5U+jSckVQsLzLRwDuIoa+HpEcHjzuwHMaHrKVljgiPeRI3Afdpqx3nHgy0MFCOhGEr\n", + "Jkw+Dadh5qrWjCGOX2K5HPLV0E5qw7krTDhpWX8sTsYsIqvxr/V2EjIFiKwnheBvunmhlbHNUKTl\n", + "ykWRC9Afa8QE+vO8sLJHYNqVh5kOrsn0+NP1Mm4JPbYiahSDJa4o8TJzkXFBAAABAkGflkUVLDv/\n", + "AGBJAvfAgTZO/kHo4lc9yaSVZkgaxkXEQAgySaAqoJy8U1XmJXFaLzsHv4KqZnckX0gP1AYFUr5X\n", + "3Zof5zltHp7OQG87KhkyMuJLOz4diYjf3ctsH2KA3/S29L1hP4qjZ9kfgNEsjrH/nSlX3ikiiFcQ\n", + "/2mu5vwlzQMTIUj5/0pAslvbULpI2rwxcgfjtpeW3qe/Q0sCZXyJ3L7VhEaeyKZo/ALUAi114xdn\n", + "Gao6fyKpZhWohGCsI53i8XO3Y7Dq+aD4ONx4A265BL770fTZiNNw+oM7dwTK1vcPMdOTVjz4fi6j\n", + "bCMBPzMCGM7CsAz7OQTIKiUTlOi8YAAAAakBn7V0Q38AeTG7snd+wR+ioRwfka+slSBm7w4HiigA\n", + "mYoe7RzT8waKJhe/5/xyHdk2lI4Qb6yur2vWdYx/k/gVzZWx+dAAALHLM2W5kE06MD+/WY8W9vMg\n", + "jgsWx+NCob+sUo3r0m3kC7Z6vE5pa/kp8NVK1XizBU/gSaY6/S/NP+nzZeAUHhvnb6LPnQnTmhI7\n", + "+CLAa1UiK6P+lwPbKP0S0Q5RWiopmhls/AKTmwxXB+WRWyrrFglLMCCi/H7yBlZCPn3f1nUi1WXW\n", + "txmtCNftDVTPLfu3fbw+YSszpG0LQoe/d+Hn14JtNEXcVveVKgdRtrJ2SZSzkDZoD5uTokEopKbG\n", + "geSmsxJSe6mDenK/tstnSjFiozTKWgyJb1mTK9iBWStV+uPeceDypkgatRgkwgz17Zgn457UL8xo\n", + "RIb3Rzvhn1PaM6KKHv4wQMqvpqRXKRm+SScKgBhgUzc706tHx+sk3QXrFbfmTj3VwEqpASdMV8SQ\n", + "Rc7Pl7VdiwexHM38nPcgZguGyvH4NF1CZay1mT9d+wee9MfU3VHZJgMp057sUGFJIJZNmQAAASYB\n", + "n7dqQ38Ah1fDGltbSoFNBABy4LNfpqaOuQiA03rsvInHR01iNZMDGQE2sq9jRvjWYcCsjv8TgHDx\n", + "TelM9UgK8aIkbW5xZBO7YH31DMzHB/HcoCKmBUni45/7i/CIo8gF1pGPr0DAA7wV6D09MIgWLTIz\n", + "u2RlgzWHXLOhQSqpesq6gEgghz4eO+szzJWiaji2cgnbFYV7gS1iXMpBIisJc8i3U9gywhFgtGxt\n", + "IPW/7TiYEwGOLwxyjZX1HkROuSI8lAAdZBpungwbYVpPKSngzu3PnOIcBqes7c29MHD8jRPn7Zrt\n", + "720E/jZ4jB2yT62h5AEs+TCYeJmiY6lwGwXm58hIVqeMFafCwAYhd3vDCtfE6mymrvYwtLYQ0YeE\n", + "Ebj2MbA5+zEAAAFwQZu6SahBbJlMFEwR//61KoAWx89GABUe1i4OfaowcQHQyqHCv9PnwkHOB5jh\n", + "ZaY1nqaJvfgMHLxnx0HRU319XsFiIgZ3fycxZ7MoTbod+V6rFy2y2Qtld8RvCt0Ug4PVQuLFLU9x\n", + "N6gbeWntqj92UVkXYHO8rtnoyHbc5vkyDRwK85+1rEknOmV2fCPAJQWJQHZKzqn/akJ6R91HlWya\n", + "u/8GgP8q7KTtX0XyZMALsB3jT/UhmW5AlGIwNHeW1rtDiMG/Xy+69i+m2kTOjww4y5o0/8WfwLLR\n", + "RKlhEE1LYjJQjoy3+hNy7YguxzdtR0GOg0UsPQLFZIBnnCwGmFharg9MSkzKoZck80tBnNzVcu5F\n", + "Ot8W+bdDLv2E/9UTXci1RXlM26z5jearPa/9d/CciU6kElsImbzJ5J2YpzVs+pvW89XbvAJMExZq\n", + "wXD26iUkefzti1p2cc2CbM5qN5CGCTCmR13du1Y9J/JQwXkxhEAAAAFiAZ/ZakN/AHwUpp6Dymc0\n", + "2L536BR5shJlFypABdlGcrzfdaw/6f5GB/atQKmEnLjISTsAvG6zfbdBMs7bm2yeFrIQxXuK81kC\n", + "9pAAAXcBlvswH72knWeKBsU0Ht1g5h3YcKtQv4e82ah693wXobc+mdHgPA3TBKIFWUv/iM+/E90G\n", + "S/NmTeZC+lgt/zT/+HMt/QSFK9C1+AMdH9l6Wmy5eJzA8pumBNuqAArwclv8LW1AC9Ryj7J7dIqZ\n", + "2nhKIYQ08cavMFAGExrDHt7RiTs4Auer+jpijDT1MWhCFcQjNZn9nbOp1MdYUZ3batlHR94YKH39\n", + "SB9iaEe1H+vDrSDRsP3b0PfVLevCUtQQ7tTMju5YxLigI0SkXHby6oMGwH35DOmYdZ/QEHihEbbH\n", + "ljlaWypqm6TR7b/zNBCPoaZiHS0IlbTr/gzMbXxGasP7GssB89XtUV2jZihKJYcij8456L2VAAAC\n", + "WkGb3knhClJlMCCH//6qVQAvW48vGhnpxPcAFRvWsRQfCH0ZQNKlkI/Fmy/VFBZqjdqwlFWyRDRU\n", + "ATa/x8nSCThm/LYIboN0iejGj3Uchm8nyLv3P3+HOOnCw7+XGsyycSpaT/SKI8hu4RwjrdDxqaYn\n", + "k6pZ6qjZtX+IZ04XS8X44piBkZKHHklQnddyez3eJG0JjT0fN5b/c72jAD+sOeXlR6iPKkSUzu0o\n", + "3ha2oHN6UEDmISbP1cbB3piI/SHrisHlFNjIuHiEdkqSzG95tlcEE5RmJMFHyIZtmV+VUnHUg//H\n", + "WOVjyT0+oFlaS4c8th8dtoQJgchjo9u+OPpSDxEJgWI6zeeh28ogNTGzlwRqjfRSsrTItvjA1MD/\n", + "oBFhKLk5Gm5LLSkMpDHu9T5I2IaoH3PKDFRJp5FswrHAqK+C6EMiKJRw3UfQ++e71IzTL0xpDNJL\n", + "z6AeitOHT7WHH1q0lcaxtRKIXyzlri2FOeAU+zEh7DbcM3wvbzCPYrbD4ePmP1flYALif0DM+F20\n", + "woqO1ciEp6KvfcdLwkVhOi6HukmunTXGsruYaqjkaLT2QlUIMJVPTAaXGvEAsJSG/0vfsDXKkk6Z\n", + "sB3ElNrSO3yHej1aIEgW5xnCNisEQsWn6TKnOYGilPN4ZN8EB64V0F8PWNB9Aq0baX+T8kKesmFw\n", + "2y/668NRP8ypn4s+0TEew3V5nLH+An+XxWolypflMoVnWhEhG2W+IIgxfWfPuSgDmqBKtSemnfnO\n", + "mj2z1HJ4yEmqNoBjJwYnWfK8e0PHHb381Mk1zGGJOgWAAAABUEGf/EU0TDv/AFlVerlP4Rak+BQA\n", + "rfH1MAekqKZtO9rI3YpPu0XbIusXd4D2mikBBjNWCs5ZCx1/nIkAW78LpHSyCScRX686DgqeELvg\n", + "+6gjEvz9oPv/Q5SyPMBeMNrb/QJ3ato+Qw19nLJWjl0bduh+HilMsrklIYKHCWBaC/dNC4s7Xl/r\n", + "RCzM7ZJuRKmUY/D5sEAdr/H6TIVmiD0u2jiehC8y8Gw6flB5fdlWyz5ArpMes88RS9cHH1n4Dp5A\n", + "9YiKoxa6XsjMVtwy/Q1CE1CcjEE8nX1x2wi3FF+AiuFwqQsSRlHtfUsVksDBdXLvE8zjbyOIuIMV\n", + "pnJU22cEHHqRAVAAAQz/a8I3JUwtCYefKDlHQuITIdlhxtkj1S9/MOKY0At1R1tnioLMWN7HUVCo\n", + "b6XS9uoGwS6oOJgKcTFbR1vNa4wchWq0XCPds0DBwQAAAPYBnht0Q38AeTSjvudgsbkOLNHOwJSE\n", + "7MIAOT4Tae/DlzyAOhFcKHSt+XmND2K3krM1WAe1ksxoXOx8R5ib25iI4yoXHAvjcPvcDoLvQIYy\n", + "rfzkEj8FCsgVqTty2M7mcrrsvBMmGI/tSEAq1Wpq/wSUg2I4oZj0GjiChzewD+uw3YnWAi/Ntf5Y\n", + "Cv2dU9qEo9e3jPCavhxnj6HVQyqcvxekJ6cEcAGQvRh8PwiQyys4LYMz+Th6jmnZO6zDQlY1h459\n", + "aXiX/1NPDVjhvbOibPxdXy1nW8ZFN/ZpmMtUtTAz4mvuGfLCJYTZv8r0n1cztBPRieehovEAAAGy\n", + "AZ4dakN/AHwTrqiSAEDVZr7cfUIfCi6SEtf6z4BBmn/qEvCbGFYoG0hJzipIIEfgPxGLOPb5hgYo\n", + "3EqlxYfhyi3ADlPB0rSvUe/2K1c1bOHHkBdbN7v2fRCe6cTgBUViIyBzKbW8+YVzs1NjLsftvDLF\n", + "Jws+AVbFUOsz2XZO6+tJqS4okplORVfI8Zh8pjE7ly6+HI7Omo301kEp6VZks8VHiVKJOuTRsuFe\n", + "1lak9cDIgZS7IV3MkEjdmu8V6wPVTOui5KhgRegdKpe7dvKwiZROacSHUyEpgoiQ49NAkgd9ICSC\n", + "nOG96XtcVUK5qLGXI1ECEXtJcuaFVMtCmmOBBiFL8jC1MpHbxQ+4k2qRSUjP3JvFi0NfrsxeXbrH\n", + "Ebg5vBmNpJE6T+wdC73c70xC+Mtp+wYFzu5kfTKcL8d+Nzu4GlIr338e6SWwNSpXRGjfdLp9o3Ic\n", + "2PzMtQmrlpbEeUDp1vnkaZoqSF5M9xanIk/zohgoPX5++NN/ebYvr56WROjUeIUdsOf6nrJlmboT\n", + "DZEat6r4aY15lVCgiz4Mpb/mqSazxzrszmdRYRxGsW8DnzAAAAHfQZoCSahBaJlMCHf//qmWALFy\n", + "5oM61QiAB+cxK4+jNCOHXw6RALujtnWF0llKsvjvaSIz+44BdTBn8Dqmduydu0Ab2yYLL8rBa9BR\n", + "bM/WBrO6FCt4pfpaT57HiAbORTevnWHgnUCdwsiqbddvhjkiuJYbgCMD0kEP1SURu/b2Z5hWsq5s\n", + "eIdJwlVUmffx/GFsHH2OVg2kldaudIzyWEsMXsnZccvZ4+1TTMECSDKdUtlhUW9AAgPUraaePKP1\n", + "hatMAsKbsEP5g1nzjTlmyHjs7FjRbwjKng4/qsqVQ+s9Z8Le9mq44VPerxrlkKxdRgf8PQXTEpxP\n", + "gMR8UP9I/vRSJBbzTafYsMhPytfC8ESUe9ySga0pNZKSvC+bN1h7zO9OEjqF3rsnXJU2SZN7NAbS\n", + "01WCPkWQIdWN39TZ8BwhuM2E1/XfXA9OxCI/7PAG40Z8M1rKVJPTY+iwZnIQA6cEF3rnJVasn/JZ\n", + "rircnzzi1JQr5NiwthCEkD02k7GAoyHtF8lIKArvw+GqH7Ox1Tpd6DhPPJm2hmyijeFH6E+9UCJk\n", + "Iiolc9K3UW1rmUlHlF/p9jHAvsiiJUpuG/KCfna2LEYj9yn6P2oNlWfqq5P2HNtctaJeVRZv9Qb/\n", + "mNVjyjAAAAErQZ4gRREsO/8AZUEtk8LzOoS4AAhIFC88oI10PfUAs3UxxCOOtSzHREgn4/jgVfHt\n", + "0r483Tf2Y8D+zGlycQw2lUV6Nidlo0k0sASUCm4dEwF8Hb0+IzseFE0dYexJdLqvhcI7IIUIH6RG\n", + "uv8cjTXFD8CTksvYGpGc+uBYXhlwc3/jHhNGtm8G24uHniey+Zy/NtEpSl5dub3bE324kx+/N1gF\n", + "sU/CxkQF6UQWvd6Br4nL+i2L6udCLqM/JAVJhScc01UR/bE+NX2i3upx0qofgxfWL8unNZ/BP9Vc\n", + "CvVXAtxPw+0JopAnWMlwtBFG9wd+oP4zOIJ88u/VEvyZQd0JJP1Y3qhYk13Deyiv0C1r6ci1z7CQ\n", + "UwYqgUT64pT/hlIvHeCzEZxqH+WbUbEAAAGYAZ5fdEN/AIteE+hbrZmAAHNd3/IVGxTYP4E6C+Wr\n", + "63le3xAHjzqOqEil1tIAAUY3LvF62/277H30QskV8sEjceHvPe7bE0mfZ44avBY2gS0AAAMByRDk\n", + "EKOyh31Y2H0mdsy+zcGsPrGm3pHtO2riBcgILxHO0F5398HG90hK8UgtDUfp9CQyPOvDSyEU4WTb\n", + "6/WT9Z3aca6tb4C53W6p8Geyjq/mwbvNpnCVbbqIcx1ZT2+dencovmeYmPlI7jrhk6KwLYEd+5gO\n", + "J2YeKk4iWai6BsaO9+Tb5P52jBVHcSZ+Vws5QhTxkBSpdHlWJRcbh50V4ViVltwUN//XNx+jx2bk\n", + "KsfglI41FGmS2xAJtr8ZhKDk1VRRL2tGsNB5nztuRXCFd8q4MIuVVWGjim0ntcxZ/R18mzJZN+sI\n", + "qKUvfsxoaeZp+oIaU1hLeXzgcHEe+3/6emdZeJWoDNhUqhkfWzWzVZbEzUKpDBS9AbVIA5KR27LD\n", + "3HEfRMw9yt8eYILg7m/Rm2ubtU8u6V2QuxVXq1OHry5oY2TAAAABvQGeQWpDfwCPV5unds/RGF4o\n", + "aWlq+XwTSVpG+igacFOApaqyNJIXSXT4q7gA4DkP0YAYAumNCN0MwD7HSEeIsv3Q3L9kZ2RagxvU\n", + "jle4yQq6Zl5W7AgdlZnaBngH/w8xYsqWx5t90zzi7s9VyRY9jaNshfxuJAZcRgFILNTmQNCPoCtl\n", + "wyo5Ht91VCy2qSby6JDLeTD096PzM4KOK7/I+amuefuT0S/QnDNs952oi11JV2mbadqtKDqJE9x4\n", + "nX/OjU9PBP1uhsFLNkjsz6ZHlTOcsZvWUxabbw0HBNFuLXWIYqtAYdWN7c/QUoqY2IlVBR//v+NN\n", + "Bxf/rxPv+9QlTTeUOAVhzyU/kQACorW+VEL2KFNUPF85LUxlbSGEYQv/98/fAQAu6hKRw3yoJoPy\n", + "tyr7S7Za9gGurMYseuvuasNoB+fPCmp37VWgm4yNZQ0LM+8CPtaQgShVMs2/RIG2cXksHuYVqEB7\n", + "PJtzP2tl8EYDen8RohIb2UO5d/Xdc8aoi/Nu4IzGq8ApuZIxjC5J9bUYtMDEDA6eChGKPjb20vqg\n", + "2PRBI2fSXJrcSROGTC4m+VsF+VagO1LnjrakndEAAAHtQZpDSahBbJlMCG///qeEAVH55ayIAL6z\n", + "9D9Go2JR/VsPgULYIy+HM1JNQWUio64eqKV59gHDbxQ77xKGvVi/RlMeepNHF+Cplpp4rKqgivaK\n", + "14o0jVVjKwdzXmYfm8QJck76NrSj9rXzMi3Th9DbQ5HQHvlFr1+Ft6fGVXaubVoF+Bx3J4nvsWO+\n", + "FhXDphKaWh9geM/3PqX1TK4zqhRL2wKgDCWdLvIi2s2e48RSWR1zksj0SjkMINJfgjA7wVj0dW8Z\n", + "NZGlcRPjgkoSgpomI+x9/l7dJ5fHEj4WOkMQMTJnj+KOqaXfgtXbhBachZ0Av1Z6rh+qw/iObJOy\n", + "7q2gUdlftEWI7In7KZjqqg18Bg+z35wI2FmknOyXdEiDAPaFiRrhqkKOLfgLssw1BdohiuTGWlKn\n", + "NvPL4EzIbAUeS+0qv5cFdXvRjnn1zOMYTMpyN1CZYg4pqjj8mGtGdm1F7w0Xo4Mnm3hRmvZyyOaW\n", + "yf38s1SCwyOkhQcwJhrAAebvkxMWrAUWrTq9K9PdCUqFbMVB9+93aovoux8zBfM/WLangtLLXd/D\n", + "T9TcgY0eosWGZeAhQk2sxNC3bgvMT328AT2T2XCg2nG4jsOakPWfscwbc0zKfItj/1eXvyR2tk+K\n", + "fpgdg9dJ/OdcXINTUAAAB95tb292AAAAbG12aGQAAAAAAAAAAAAAAAAAAAPoAAAnEAABAAABAAAA\n", + "AAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAEAAAAAAAAAAAAAAAAAAEAAAAAAAAAAAAAAAAAAAAAA\n", + "AAAAAAAAAAAAAAAAAAACAAAHCHRyYWsAAABcdGtoZAAAAAMAAAAAAAAAAAAAAAEAAAAAAAAnEAAA\n", + "AAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAEAAAAAAAAAAAAAAAAAAEAAAAABsAAAASAA\n", + "AAAAACRlZHRzAAAAHGVsc3QAAAAAAAAAAQAAJxAAAAgAAAEAAAAABoBtZGlhAAAAIG1kaGQAAAAA\n", + "AAAAAAAAAAAAACgAAAGQAFXEAAAAAAAtaGRscgAAAAAAAAAAdmlkZQAAAAAAAAAAAAAAAFZpZGVv\n", + "SGFuZGxlcgAAAAYrbWluZgAAABR2bWhkAAAAAQAAAAAAAAAAAAAAJGRpbmYAAAAcZHJlZgAAAAAA\n", + "AAABAAAADHVybCAAAAABAAAF63N0YmwAAACzc3RzZAAAAAAAAAABAAAAo2F2YzEAAAAAAAAAAQAA\n", + "AAAAAAAAAAAAAAAAAAABsAEgAEgAAABIAAAAAAAAAAEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\n", + "AAAAAAAAAAAY//8AAAAxYXZjQwFkABX/4QAYZ2QAFazZQbCWhAAAAwAEAAADAFA8WLZYAQAGaOvj\n", + "yyLAAAAAHHV1aWRraEDyXyRPxbo5pRvPAyPzAAAAAAAAABhzdHRzAAAAAAAAAAEAAABkAAAEAAAA\n", + "ABRzdHNzAAAAAAAAAAEAAAABAAADMGN0dHMAAAAAAAAAZAAAAAEAAAgAAAAAAQAAFAAAAAABAAAI\n", + "AAAAAAEAAAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAABQA\n", + "AAAAAQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAA\n", + "AAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAABQAAAAAAQAACAAAAAABAAAAAAAA\n", + "AAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAMAAAAAAEAAAQAAAAA\n", + "AQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAAB\n", + "AAAEAAAAAAEAABQAAAAAAQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEA\n", + "AAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAABQAAAAAAQAA\n", + "CAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAM\n", + "AAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgA\n", + "AAAAAQAAAAAAAAABAAAEAAAAAAEAABQAAAAAAQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAA\n", + "AAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAA\n", + "AAEAABQAAAAAAQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAA\n", + "AQAABAAAAAABAAAMAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAAB\n", + "AAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAAAgAAAAAHHN0c2MAAAAAAAAAAQAAAAEA\n", + "AABkAAAAAQAAAaRzdHN6AAAAAAAAAAAAAABkAAAGhgAAAl8AAAFjAAAAvgAAAXYAAAHzAAABDgAA\n", + "ATYAAAFIAAAB9QAAAOIAAAD6AAABWgAAAbAAAADTAAAB8wAAAN4AAAH+AAABEAAAAOIAAAG2AAAC\n", + "DAAAAWUAAAGkAAABmgAAAckAAAEdAAABfQAAAPMAAAFxAAABIgAAAjYAAAEmAAAA5AAAAXoAAAH+\n", + "AAAA/wAAAT0AAAFnAAACAwAAARQAAAE3AAABTwAAAckAAADrAAACFwAAAP0AAAHzAAABIQAAAOAA\n", + "AAHKAAACOwAAAVQAAAHFAAABugAAAdQAAAD3AAABUgAAARIAAAFuAAABLwAAAhAAAAERAAAA9gAA\n", + "AZkAAAIqAAABIgAAAV0AAAGIAAACSgAAASgAAAFEAAABggAAAegAAAD+AAACCgAAASIAAAIdAAAB\n", + "KAAAAQcAAAHbAAACFgAAAT0AAAITAAAB2gAAAi8AAAEGAAABrQAAASoAAAF0AAABZgAAAl4AAAFU\n", + "AAAA+gAAAbYAAAHjAAABLwAAAZwAAAHBAAAB8QAAABRzdGNvAAAAAAAAAAEAAAAsAAAAYnVkdGEA\n", + "AABabWV0YQAAAAAAAAAhaGRscgAAAAAAAAAAbWRpcmFwcGwAAAAAAAAAAAAAAAAtaWxzdAAAACWp\n", + "dG9vAAAAHWRhdGEAAAABAAAAAExhdmY1Ny44My4xMDA=\n", + "\"\u003e\n", + " Your browser does not support the video tag.\n", + "\u003c/video\u003e" + ], + "text/plain": [ + "\u003cIPython.core.display.HTML at 0x7f84b2253b50\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "import time\n", + "import traceback\n", + "\n", + "from matplotlib import pyplot as plt\n", + "from matplotlib import animation as anim\n", + "import tensorflow as tf\n", + "from tensorflow.contrib import autograph as ag\n", + "from IPython import display\n", + "\n", + "\n", + "@ag.do_not_convert(ag.RunMode.PY_FUNC)\n", + "def render(boards):\n", + " fig = plt.figure()\n", + "\n", + " ims = []\n", + " for b in boards:\n", + " im = plt.imshow(b, interpolation='none')\n", + " im.axes.get_xaxis().set_visible(False)\n", + " im.axes.get_yaxis().set_visible(False)\n", + " ims.append([im])\n", + "\n", + " try:\n", + " ani = anim.ArtistAnimation(\n", + " fig, ims, interval=100, blit=True, repeat_delay=5000)\n", + " plt.close()\n", + "\n", + " display.display(display.HTML(ani.to_html5_video()))\n", + " except RuntimeError:\n", + " print('Coult not render animation:')\n", + " traceback.print_exc()\n", + "\n", + "\n", + "def gol_episode(board):\n", + " directions = tf.constant(\n", + " ((-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)))\n", + "\n", + " new_board = []\n", + " ag.set_element_type(new_board, tf.int32)\n", + "\n", + " for i in range(len(board)):\n", + " for j in range(len(board[i])):\n", + " num_neighbors = 0\n", + " for d in directions:\n", + " ni = i + d[0]\n", + " nj = j + d[1]\n", + " if ni \u003e= 0 and nj \u003e= 0 and ni \u003c len(board) and nj \u003c len(board[i]):\n", + " num_neighbors += board[ni][nj]\n", + " \n", + " new_cell = 0\n", + " if num_neighbors == 2:\n", + " new_cell = board[i][j]\n", + " elif num_neighbors == 3:\n", + " new_cell = 1\n", + " \n", + " new_board.append(new_cell)\n", + " final_board = ag.stack(new_board)\n", + " final_board = tf.reshape(final_board, board.shape)\n", + " return final_board\n", + " \n", + "\n", + "def gol(initial_board):\n", + " board = initial_board\n", + " boards = []\n", + " ag.set_element_type(boards, tf.int32)\n", + " # We are being explicit about tensor constants to ensure the loop\n", + " # is not unrolled in the graph. This may change in the future.\n", + " for i in range(tf.constant(NUM_STEPS)):\n", + " board = gol_episode(board)\n", + " boards.append(board)\n", + " boards = ag.stack(boards)\n", + " render(boards)\n", + " return tf.no_op()\n", + " \n", + "\n", + "with tf.Graph().as_default():\n", + " # Gosper glider gun\n", + " # Adapted from http://www.cplusplus.com/forum/lounge/75168/\n", + " _ = 0\n", + " initial_board = tf.constant((\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,1,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_,_,_,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,1,_,_,_,1,_,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_ ),\n", + " ( _,1,1,_,_,_,_,_,_,_,_,1,_,_,_,_,_,1,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,1,1,_,_,_,_,_,_,_,_,1,_,_,_,1,_,1,1,_,_,_,_,1,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,1,_,_,_,_,_,1,_,_,_,_,_,_,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,1,_,_,_,1,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ))\n", + " initial_board = tf.pad(initial_board, ((0, 20), (0, 10)))\n", + " \n", + " tf_gol = ag.to_graph(gol)\n", + " game_ops = tf_gol(initial_board)\n", + " with tf.Session() as sess:\n", + " sess.run(game_ops)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "7NgrSPCZxs3h" + }, + "source": [ + "#### Generated code" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 2323 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 753, + "status": "ok", + "timestamp": 1532101593840, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "hIGYeX0Cxs3i", + "outputId": "e0b62eb1-3e12-4e53-dc54-8a3fa56d823d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "from __future__ import print_function\n", + "import tensorflow as tf\n", + "\n", + "def tf__gol_episode(board):\n", + " try:\n", + " with tf.name_scope('gol_episode'):\n", + " directions = tf.constant(((-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1),\n", + " (1, -1), (1, 0), (1, 1)))\n", + " new_board = ag__.new_list([])\n", + "\n", + " def extra_test_2(new_board_2):\n", + " with tf.name_scope('extra_test_2'):\n", + " return True\n", + "\n", + " def loop_body_2(i, new_board_2):\n", + " with tf.name_scope('loop_body_2'):\n", + "\n", + " def extra_test_1(new_board_1):\n", + " with tf.name_scope('extra_test_1'):\n", + " return True\n", + "\n", + " def loop_body_1(j, new_board_1):\n", + " with tf.name_scope('loop_body_1'):\n", + " num_neighbors = 0\n", + "\n", + " def extra_test(num_neighbors_2):\n", + " with tf.name_scope('extra_test'):\n", + " return True\n", + "\n", + " def loop_body(d, num_neighbors_2):\n", + " with tf.name_scope('loop_body'):\n", + " ni = i + ag__.get_item(d, (0), opts=ag__.GetItemOpts(\n", + " element_dtype=None))\n", + " nj = j + ag__.get_item(d, (1), opts=ag__.GetItemOpts(\n", + " element_dtype=None))\n", + "\n", + " def if_true():\n", + " with tf.name_scope('if_true'):\n", + " num_neighbors_1, = num_neighbors_2,\n", + " num_neighbors_1 += ag__.get_item(ag__.get_item(board,\n", + " (ni), opts=ag__.GetItemOpts(element_dtype=None)),\n", + " (nj), opts=ag__.GetItemOpts(element_dtype=None))\n", + " return num_neighbors_1,\n", + "\n", + " def if_false():\n", + " with tf.name_scope('if_false'):\n", + " return num_neighbors_2,\n", + " num_neighbors_2 = ag__.utils.run_cond(tf.logical_and(tf.\n", + " greater_equal(ni, 0), tf.logical_and(tf.greater_equal\n", + " (nj, 0), tf.logical_and(tf.less(ni, ag__.utils.\n", + " dynamic_builtin(len, board)), tf.less(nj, ag__.utils.\n", + " dynamic_builtin(len, ag__.get_item(board, (i), opts=\n", + " ag__.GetItemOpts(element_dtype=None))))))), if_true,\n", + " if_false)\n", + " return num_neighbors_2,\n", + " num_neighbors = ag__.for_stmt(directions, extra_test,\n", + " loop_body, (num_neighbors,))\n", + " new_cell = 0\n", + "\n", + " def if_true_2():\n", + " with tf.name_scope('if_true_2'):\n", + " new_cell_2, = new_cell,\n", + " new_cell_2 = ag__.get_item(ag__.get_item(board, (i), opts\n", + " =ag__.GetItemOpts(element_dtype=None)), (j), opts=\n", + " ag__.GetItemOpts(element_dtype=None))\n", + " return new_cell_2,\n", + "\n", + " def if_false_2():\n", + " with tf.name_scope('if_false_2'):\n", + " new_cell_3, = new_cell,\n", + "\n", + " def if_true_1():\n", + " with tf.name_scope('if_true_1'):\n", + " new_cell_1, = new_cell_3,\n", + " new_cell_1 = 1\n", + " return new_cell_1,\n", + "\n", + " def if_false_1():\n", + " with tf.name_scope('if_false_1'):\n", + " return new_cell_3,\n", + " new_cell_3 = ag__.utils.run_cond(tf.equal(num_neighbors, \n", + " 3), if_true_1, if_false_1)\n", + " return new_cell_3,\n", + " new_cell = ag__.utils.run_cond(tf.equal(num_neighbors, 2),\n", + " if_true_2, if_false_2)\n", + " new_board_1 = ag__.list_append(new_board_1, new_cell)\n", + " return new_board_1,\n", + " new_board_2 = ag__.for_stmt(ag__.utils.dynamic_builtin(range,\n", + " ag__.utils.dynamic_builtin(len, ag__.get_item(board, (i),\n", + " opts=ag__.GetItemOpts(element_dtype=None)))), extra_test_1,\n", + " loop_body_1, (new_board_2,))\n", + " return new_board_2,\n", + " new_board = ag__.for_stmt(ag__.utils.dynamic_builtin(range, ag__.\n", + " utils.dynamic_builtin(len, board)), extra_test_2, loop_body_2, (\n", + " new_board,))\n", + " final_board = ag__.list_stack(new_board, opts=ag__.ListStackOpts(\n", + " element_dtype=tf.int32, original_call=ag.stack))\n", + " final_board = tf.reshape(final_board, board.shape)\n", + " return final_board\n", + " except:\n", + " ag__.rewrite_graph_construction_error(ag_source_map__)\n", + "\n", + "def tf__gol(initial_board):\n", + " try:\n", + " with tf.name_scope('gol'):\n", + " board = initial_board\n", + " boards = ag__.new_list([])\n", + "\n", + " def extra_test(board_1, boards_1):\n", + " with tf.name_scope('extra_test'):\n", + " return True\n", + "\n", + " def loop_body(i, board_1, boards_1):\n", + " with tf.name_scope('loop_body'):\n", + " board_1 = tf__gol_episode(board_1)\n", + " boards_1 = ag__.list_append(boards_1, board_1)\n", + " return board_1, boards_1\n", + " board, boards = ag__.for_stmt(ag__.utils.dynamic_builtin(range, tf.\n", + " constant(NUM_STEPS)), extra_test, loop_body, (board, boards))\n", + " boards = ag__.list_stack(boards, opts=ag__.ListStackOpts(\n", + " element_dtype=tf.int32, original_call=ag.stack))\n", + " with ag__.utils.control_dependency_on_returns(render(boards)):\n", + " boards_2 = ag__.utils.alias_tensors(boards)\n", + " return tf.no_op()\n", + " except:\n", + " ag__.rewrite_graph_construction_error(ag_source_map__)\n", + "\n" + ] + } + ], + "source": [ + "print(ag.to_code(gol))" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "p8zZyj-tq4K3", + "Lkq3DBGOv3fA", + "r8_0ioEuAI-a", + "7NgrSPCZxs3h" + ], + "default_view": {}, + "last_runtime": { + "build_target": "", + "kind": "local" + }, + "name": "Simple algorithms using AutoGraph", + "provenance": [ + { + "file_id": "19q8KdVF8Cb_fDd13i-WDOG_6n_QGNW5-", + "timestamp": 1528465909719 + } + ], + "version": "0.3.2", + "views": {} + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb index d62390494b78c415212ba91ac914cdfee324f971..7e9cc54d4cafa64e4cd3b48f9376b1b2b4d3575e 100644 --- a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb +++ b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb @@ -1,49 +1,20 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Dev Summit 2018 - Autograph", - "version": "0.3.2", - "views": {}, - "default_view": {}, - "provenance": [ - { - "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K", - "timestamp": 1522238054357 - }, - { - "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ", - "timestamp": 1521743157199 - }, - { - "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-", - "timestamp": 1520522344607 - } - ], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python2", - "display_name": "Python 2" - } - }, "cells": [ { + "cell_type": "markdown", "metadata": { - "id": "g7nGs4mzVUHP", - "colab_type": "text" + "colab_type": "text", + "id": "g7nGs4mzVUHP" }, - "cell_type": "markdown", "source": [ - "# Experimental: TF Autograph\n", + "# Experimental: TF AutoGraph\n", "**TensorFlow Dev Summit, 2018.**\n", "\n", - "This interactive notebook demonstrates **autograph**, an experimental source-code transformation library to automatically convert TF.Eager and Python code to TensorFlow graphs.\n", + "This interactive notebook demonstrates **AutoGraph**, an experimental source-code transformation library to automatically convert Python, TensorFlow and NumPy code to TensorFlow graphs.\n", "\n", "**Note: this is pre-alpha software!** The notebook works best with Python 2, for now.\n", "\n", - "> ![alt text](https://lh3.googleusercontent.com/QOvy0clmg7siaVKzwmSPAjicWWNQ0OeyaB16plDjSJMf35WD3vLjF6mz4CGrhSHw60HnlZPJjkyDCBzw5XOI0oBGSewyYw=s688)\n", + "\u003e ![alt text](https://lh3.googleusercontent.com/QOvy0clmg7siaVKzwmSPAjicWWNQ0OeyaB16plDjSJMf35WD3vLjF6mz4CGrhSHw60HnlZPJjkyDCBzw5XOI0oBGSewyYw=s688)\n", "\n", "### Table of Contents\n", "1. _Write Eager code that is fast and scalable._\n", @@ -53,37 +24,39 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "uFcgBENZqkB2", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "uFcgBENZqkB2" }, - "cell_type": "code", + "outputs": [], "source": [ "# Install TensorFlow; note that Colab notebooks run remotely, on virtual\n", "# instances provided by Google.\n", "!pip install -U -q tf-nightly" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "Pa2qpEmoVOGe", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "Pa2qpEmoVOGe" }, - "cell_type": "code", + "outputs": [], "source": [ "import os\n", "import time\n", @@ -96,170 +69,172 @@ "import six\n", "\n", "from google.colab import widgets" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "ZVKfj5ttVkqz", - "colab_type": "text" + "colab_type": "text", + "id": "ZVKfj5ttVkqz" }, - "cell_type": "markdown", "source": [ "# 1. Write Eager code that is fast and scalable\n", "\n", "TF.Eager gives you more flexibility while coding, but at the cost of losing the benefits of TensorFlow graphs. For example, Eager does not currently support distributed training, exporting models, and a variety of memory and computation optimizations.\n", "\n", - "Autograph gives you the best of both worlds: write your code in an Eager style, and we will automatically transform it into the equivalent TF graph code. The graph code can be executed eagerly (as a single op), included as part of a larger graph, or exported." + "AutoGraph gives you the best of both worlds: you can write your code in an Eager style, and we will automatically transform it into the equivalent TF graph code. The graph code can be executed eagerly (as a single op), included as part of a larger graph, or exported." ] }, { + "cell_type": "markdown", "metadata": { - "id": "snaZRFdWd9ym", - "colab_type": "text" + "colab_type": "text", + "id": "snaZRFdWd9ym" }, - "cell_type": "markdown", "source": [ - "For example, autograph can convert a function like this:" + "For example, AutoGraph can convert a function like this:" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "9__n8cSIeDnD", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "9__n8cSIeDnD" }, - "cell_type": "code", + "outputs": [], "source": [ "def g(x):\n", - " if x > 0:\n", + " if x \u003e 0:\n", " x = x * x\n", " else:\n", " x = 0\n", " return x" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "gq0eQcuReHET", - "colab_type": "text" + "colab_type": "text", + "id": "gq0eQcuReHET" }, - "cell_type": "markdown", "source": [ "... into a TF graph-building function:" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "sELSn599ePUF", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "output_extras": [ - {} - ], - "base_uri": "https://localhost:8080/", - "height": 413 + "height": 431 }, - "outputId": "bb0c7216-1ca3-4da1-d1fb-589902cdcd1a", + "colab_type": "code", "executionInfo": { + "elapsed": 69, "status": "ok", - "timestamp": 1522345737505, - "user_tz": 240, - "elapsed": 243, + "timestamp": 1531750911837, "user": { - "displayName": "Dan Moldovan", - "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", - "userId": "112023154726779574577" - } - } + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "sELSn599ePUF", + "outputId": "2858bde5-ae05-4c32-be01-7770ac914f02" }, - "cell_type": "code", - "source": [ - "print(autograph.to_code(g))" - ], - "execution_count": 0, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ "from __future__ import print_function\n", "import tensorflow as tf\n", - "from tensorflow.contrib.autograph.impl import api as autograph_api\n", - "from tensorflow.contrib.autograph import utils as autograph_utils\n", "\n", "def tf__g(x):\n", - " with tf.name_scope('g'):\n", + " try:\n", + " with tf.name_scope('g'):\n", "\n", - " def if_true():\n", - " with tf.name_scope('if_true'):\n", - " x_1, = x,\n", - " x_1 = x_1 * x_1\n", - " return x_1,\n", + " def if_true():\n", + " with tf.name_scope('if_true'):\n", + " x_1, = x,\n", + " x_1 = x_1 * x_1\n", + " return x_1,\n", "\n", - " def if_false():\n", - " with tf.name_scope('if_false'):\n", - " x_1, = x,\n", - " x_1 = 0\n", - " return x_1,\n", - " x = autograph_utils.run_cond(tf.greater(x, 0), if_true, if_false)\n", - " return x\n", + " def if_false():\n", + " with tf.name_scope('if_false'):\n", + " x_2, = x,\n", + " x_2 = 0\n", + " return x_2,\n", + " x = ag__.utils.run_cond(tf.greater(x, 0), if_true, if_false)\n", + " return x\n", + " except:\n", + " ag__.rewrite_graph_construction_error(ag_source_map__)\n", "\n" - ], - "name": "stdout" + ] } + ], + "source": [ + "print(autograph.to_code(g))" ] }, { + "cell_type": "markdown", "metadata": { - "id": "j74n-8hEe6dk", - "colab_type": "text" + "colab_type": "text", + "id": "j74n-8hEe6dk" }, - "cell_type": "markdown", "source": [ "You can then use the converted function as you would any regular TF op -- you can pass `Tensor` arguments and it will return `Tensor`s:" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "AkVaY0-dfEbH", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "output_extras": [ - {} - ], - "base_uri": "https://localhost:8080/", "height": 53 }, - "outputId": "4ffe3757-c44d-424c-c2a8-7ddc973bfcce", + "colab_type": "code", "executionInfo": { + "elapsed": 83, "status": "ok", - "timestamp": 1522345737841, - "user_tz": 240, - "elapsed": 257, + "timestamp": 1531750911965, "user": { - "displayName": "Dan Moldovan", - "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", - "userId": "112023154726779574577" - } - } + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "AkVaY0-dfEbH", + "outputId": "f04541ad-b1d3-4663-bf27-4d902648283d" }, - "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "g(9) = 81\n", + "tf_g(9) = 81\n" + ] + } + ], "source": [ "tf_g = autograph.to_graph(g)\n", "\n", @@ -272,77 +247,72 @@ "\n", " print('g(9) = %s' % g(9))\n", " print('tf_g(9) = %s' % tf_g_result)" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "g(9) = 81\n", - "tf_g(9) = 81\n" - ], - "name": "stdout" - } ] }, { + "cell_type": "markdown", "metadata": { - "id": "trrHQBM1VnD0", - "colab_type": "text" + "colab_type": "text", + "id": "trrHQBM1VnD0" }, - "cell_type": "markdown", "source": [ "# 2. Case study: complex control flow\n", "\n", - "Autograph can convert a large chunk of the Python language into graph-equivalent code, and we're adding new supported language features all the time. In this section, we'll give you a taste of some of the functionality in autograph.\n", - "Autograph will automatically convert most Python control flow statements into their correct graph equivalent.\n", + "Autograph can convert a large subset of the Python language into graph-equivalent code, and we're adding new supported language features all the time. In this section, we'll give you a taste of some of the functionality in AutoGraph.\n", + "AutoGraph will automatically convert most Python control flow statements into their graph equivalent.\n", " " ] }, { + "cell_type": "markdown", "metadata": { - "id": "u0YG3DPgZxoW", - "colab_type": "text" + "colab_type": "text", + "id": "u0YG3DPgZxoW" }, - "cell_type": "markdown", "source": [ "We support common statements like `while`, `for`, `if`, `break`, `return` and more. You can even nest them as much as you like. Imagine trying to write the graph version of this code by hand:" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "xJYDzOcrZ8pI", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "output_extras": [ - {} - ], - "base_uri": "https://localhost:8080/", "height": 35 }, - "outputId": "6c244ee4-b141-4ad6-eefa-cfffa71f33c6", + "colab_type": "code", "executionInfo": { + "elapsed": 169, "status": "ok", - "timestamp": 1522345738402, - "user_tz": 240, - "elapsed": 483, + "timestamp": 1531750912183, "user": { - "displayName": "Dan Moldovan", - "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", - "userId": "112023154726779574577" - } - } + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "xJYDzOcrZ8pI", + "outputId": "f392b475-bf87-4d90-919d-44f895ee9fc7" }, - "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sum of even numbers: 42\n" + ] + } + ], "source": [ "def sum_even(numbers):\n", " s = 0\n", " for n in numbers:\n", - " if n % 2 > 0:\n", + " if n % 2 \u003e 0:\n", " continue\n", " s += n\n", " return s\n", @@ -358,77 +328,74 @@ " \n", "# Uncomment the line below to print the generated graph code\n", "# print(autograph.to_code(sum_even))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Sum of even numbers: 42\n" - ], - "name": "stdout" - } ] }, { + "cell_type": "markdown", "metadata": { - "id": "_YXo4KOcbKrn", - "colab_type": "text" + "colab_type": "text", + "id": "_YXo4KOcbKrn" }, - "cell_type": "markdown", "source": [ "Try replacing the `continue` in the above code with `break` -- Autograph supports that as well!" ] }, { + "cell_type": "markdown", "metadata": { - "id": "xHmC0rBIavW_", - "colab_type": "text" + "colab_type": "text", + "id": "xHmC0rBIavW_" }, - "cell_type": "markdown", "source": [ "The Python code above is much more readable than the matching graph code. Autograph takes care of tediously converting every piece of Python code into the matching TensorFlow graph version for you, so that you can quickly write maintainable code, but still benefit from the optimizations and deployment benefits of graphs." ] }, { + "cell_type": "markdown", "metadata": { - "id": "UEHWGpBXbS7g", - "colab_type": "text" + "colab_type": "text", + "id": "UEHWGpBXbS7g" }, - "cell_type": "markdown", "source": [ "Let's try some other useful Python constructs, like `print` and `assert`. We automatically convert Python `assert` statements into the equivalent `tf.Assert` code. " ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "qUU57xlEbauI", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "output_extras": [ - {} - ], - "base_uri": "https://localhost:8080/", "height": 53 }, - "outputId": "add3db4a-2077-4dd5-f7a7-a5b5a4529c26", + "colab_type": "code", "executionInfo": { + "elapsed": 56, "status": "ok", - "timestamp": 1522345738697, - "user_tz": 240, - "elapsed": 253, + "timestamp": 1531750912292, "user": { - "displayName": "Dan Moldovan", - "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", - "userId": "112023154726779574577" - } - } + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "qUU57xlEbauI", + "outputId": "c9cd536a-4a95-4eb0-98c0-aafce5d79580" }, - "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Got error message: assertion failed: [Do not pass zero!]\n", + "\t [[{{node f/Assert/Assert}} = Assert[T=[DT_STRING], summarize=3, _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](f/NotEqual, f/Assert/Assert/data_0)]]\n" + ] + } + ], "source": [ "def f(x):\n", " assert x != 0, 'Do not pass zero!'\n", @@ -444,61 +411,35 @@ " \n", "# Uncomment the line below to print the generated graph code\n", "# print(autograph.to_code(f))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Got error message: assertion failed: [Do not pass zero!]\n", - "\t [[Node: f/Assert/Assert = Assert[T=[DT_STRING], summarize=3, _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](f/NotEqual, f/Assert/Assert/data_0)]]\n" - ], - "name": "stdout" - } ] }, { + "cell_type": "markdown", "metadata": { - "id": "w5hBZaVJbck4", - "colab_type": "text" + "colab_type": "text", + "id": "w5hBZaVJbck4" }, - "cell_type": "markdown", "source": [ "You can also use `print` functions in-graph:" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "6NdzRKLEboRv", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - }, - "output_extras": [ - {} - ], - "base_uri": "https://localhost:8080/", - "height": 35 - }, - "outputId": "fb82dfc3-790f-4127-87f6-361805be9e9b", - "executionInfo": { - "status": "ok", - "timestamp": 1522345739013, - "user_tz": 240, - "elapsed": 247, - "user": { - "displayName": "Dan Moldovan", - "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", - "userId": "112023154726779574577" } - } + }, + "colab_type": "code", + "id": "6NdzRKLEboRv" }, - "cell_type": "code", + "outputs": [], "source": [ "def print_sign(n):\n", - " if n >= 0:\n", + " if n \u003e= 0:\n", " print(n, 'is positive!')\n", " else:\n", " print(n, 'is negative!')\n", @@ -512,65 +453,61 @@ " \n", "# Uncomment the line below to print the generated graph code\n", "# print(autograph.to_code(print_sign))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "1 is positive!\n" - ], - "name": "stdout" - } ] }, { + "cell_type": "markdown", "metadata": { - "id": "9u_Z3i3AivLA", - "colab_type": "text" + "colab_type": "text", + "id": "9u_Z3i3AivLA" }, - "cell_type": "markdown", "source": [ - "We can convert lists to TensorArray, so appending to lists also works, with a few modifications:" + "Appending to lists also works, with a few modifications:" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "MjhCQJVuiTNR", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "output_extras": [ - {} - ], - "base_uri": "https://localhost:8080/", "height": 35 }, - "outputId": "dc320b87-595b-4392-d29c-994486fd8a0a", + "colab_type": "code", "executionInfo": { + "elapsed": 148, "status": "ok", - "timestamp": 1522345744470, - "user_tz": 240, - "elapsed": 5391, + "timestamp": 1531750912595, "user": { - "displayName": "Dan Moldovan", - "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", - "userId": "112023154726779574577" - } - } + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "MjhCQJVuiTNR", + "outputId": "96bf9131-c7c1-4359-ee82-9c38575e7ab4" }, - "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0 1 2 3 4]\n" + ] + } + ], "source": [ "def f(n):\n", " numbers = []\n", " # We ask you to tell us about the element dtype.\n", - " autograph.utils.set_element_type(numbers, tf.int32)\n", + " autograph.set_element_type(numbers, tf.int32)\n", " for i in range(n):\n", " numbers.append(i)\n", - " return numbers.stack() # Stack the list so that it can be used as a Tensor\n", + " return autograph.stack(numbers) # Stack the list so that it can be used as a Tensor\n", "\n", "\n", "tf_f = autograph.to_graph(f)\n", @@ -580,65 +517,62 @@ " \n", "# Uncomment the line below to print the generated graph code\n", "# print(autograph.to_code(f))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "[0 1 2 3 4]\n" - ], - "name": "stdout" - } ] }, { + "cell_type": "markdown", "metadata": { - "id": "UdG8ZFrkTAF2", - "colab_type": "text" + "colab_type": "text", + "id": "UdG8ZFrkTAF2" }, - "cell_type": "markdown", "source": [ "And all of these functionalities, and more, can be composed into more complicated code:\n" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "DVs6wt8NKaGQ", - "colab_type": "code", + "cellView": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "output_extras": [ - {} - ], - "base_uri": "https://localhost:8080/", "height": 53 }, - "cellView": "code", - "outputId": "0a4b8d08-8f65-4bbc-85ba-dc4c60563519", + "colab_type": "code", "executionInfo": { + "elapsed": 555, "status": "ok", - "timestamp": 1522345745186, - "user_tz": 240, - "elapsed": 658, + "timestamp": 1531750913176, "user": { - "displayName": "Dan Moldovan", - "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", - "userId": "112023154726779574577" - } - } + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "DVs6wt8NKaGQ", + "outputId": "8729229c-4f08-4640-d3a1-0d3f9c697a87" }, - "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The prime numbers less than 50 are:\n", + "[ 2 3 5 7 11 13 17 19 23 29 31 37 41 43 47]\n" + ] + } + ], "source": [ "def print_primes(n):\n", " \"\"\"Returns all the prime numbers less than n.\"\"\"\n", - " assert n > 0\n", + " assert n \u003e 0\n", " \n", " primes = []\n", - " autograph.utils.set_element_type(primes, tf.int32)\n", + " autograph.set_element_type(primes, tf.int32)\n", " for i in range(2, n):\n", " is_prime = True\n", " for k in range(2, i):\n", @@ -648,7 +582,7 @@ " if not is_prime:\n", " continue\n", " primes.append(i)\n", - " all_primes = primes.stack()\n", + " all_primes = autograph.stack(primes)\n", "\n", " print('The prime numbers less than', n, 'are:')\n", " print(all_primes)\n", @@ -663,45 +597,36 @@ " \n", "# Uncomment the line below to print the generated graph code\n", "# print(autograph.to_code(print_primes))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "The prime numbers less than 50 are:\n", - "[ 2 3 5 7 11 13 17 19 23 29 31 37 41 43 47]\n" - ], - "name": "stdout" - } ] }, { + "cell_type": "markdown", "metadata": { - "id": "JQ8kQT99VqDk", - "colab_type": "text" + "colab_type": "text", + "id": "JQ8kQT99VqDk" }, - "cell_type": "markdown", "source": [ "# 3. Case study: training MNIST with Keras\n", "\n", - "As we've seen, writing control flow in Autograph is easy. So running a training loop in graph should be easy as well!\n", + "As we've seen, writing control flow in AutoGraph is easy. So running a training loop in graph should be easy as well!\n", "\n", "Here, we show an example of such a training loop for a simple Keras model that trains on MNIST." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "0CrtGWgwuLJr", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "0CrtGWgwuLJr" }, - "cell_type": "code", + "outputs": [], "source": [ "import gzip\n", "import shutil\n", @@ -754,66 +679,67 @@ "\n", "def mnist_test(directory):\n", " return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "2zu1U9Nqir6L", - "colab_type": "text" + "colab_type": "text", + "id": "2zu1U9Nqir6L" }, - "cell_type": "markdown", "source": [ "First, we'll define a small three-layer neural network using the Keras API" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "x_MU13boiok2", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "x_MU13boiok2" }, - "cell_type": "code", + "outputs": [], "source": [ "def mlp_model(input_shape):\n", - " model = tf.keras.Sequential([\n", + " model = tf.keras.Sequential((\n", " tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),\n", " tf.keras.layers.Dense(100, activation='relu'),\n", - " tf.keras.layers.Dense(10, activation='softmax')])\n", + " tf.keras.layers.Dense(10, activation='softmax'),\n", + " ))\n", " model.build()\n", " return model" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "Wuqg3H8mi0Xj", - "colab_type": "text" + "colab_type": "text", + "id": "Wuqg3H8mi0Xj" }, - "cell_type": "markdown", "source": [ "Let's connect the model definition (here abbreviated as `m`) to a loss function, so that we can train our model." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "W51sfbONiz_5", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "W51sfbONiz_5" }, - "cell_type": "code", + "outputs": [], "source": [ "def predict(m, x, y):\n", " y_p = m(x)\n", @@ -822,63 +748,63 @@ " accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n", " accuracy = tf.reduce_mean(accuracies)\n", " return l, accuracy" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "035tNWQki9tr", - "colab_type": "text" + "colab_type": "text", + "id": "035tNWQki9tr" }, - "cell_type": "markdown", "source": [ "Now the final piece of the problem specification (before loading data, and clicking everything together) is backpropagating the loss through the model, and optimizing the weights using the gradient." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "CsAD0ajbi9iZ", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "CsAD0ajbi9iZ" }, - "cell_type": "code", + "outputs": [], "source": [ "def fit(m, x, y, opt):\n", " l, accuracy = predict(m, x, y)\n", " opt.minimize(l)\n", " return l, accuracy" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "PcVRIacKjSwb", - "colab_type": "text" + "colab_type": "text", + "id": "PcVRIacKjSwb" }, - "cell_type": "markdown", "source": [ "These are some utility functions to download data and generate batches for training" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "RVw57HdTjPzi", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "RVw57HdTjPzi" }, - "cell_type": "code", + "outputs": [], "source": [ "def setup_mnist_data(is_training, hp, batch_size):\n", " if is_training:\n", @@ -896,16 +822,14 @@ " x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))\n", " y = tf.one_hot(tf.squeeze(label), 10)\n", " return x, y" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "2zEJH5XNjgFz", - "colab_type": "text" + "colab_type": "text", + "id": "2zEJH5XNjgFz" }, - "cell_type": "markdown", "source": [ "This function specifies the main training loop. We instantiate the model (using the code above), instantiate an optimizer (here we'll use SGD with momentum, nothing too fancy), and we'll instantiate some lists to keep track of training and test loss and accuracy over time.\n", "\n", @@ -913,33 +837,35 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "UUI0566FjZPx", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "UUI0566FjZPx" }, - "cell_type": "code", + "outputs": [], "source": [ "def train(train_ds, test_ds, hp):\n", " m = mlp_model((28 * 28,))\n", " opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n", + "\n", " train_losses = []\n", - " train_losses = autograph.utils.set_element_type(train_losses, tf.float32)\n", + " autograph.set_element_type(train_losses, tf.float32)\n", " test_losses = []\n", - " test_losses = autograph.utils.set_element_type(test_losses, tf.float32)\n", + " autograph.set_element_type(test_losses, tf.float32)\n", " train_accuracies = []\n", - " train_accuracies = autograph.utils.set_element_type(train_accuracies,\n", - " tf.float32)\n", + " autograph.set_element_type(train_accuracies, tf.float32)\n", " test_accuracies = []\n", - " test_accuracies = autograph.utils.set_element_type(test_accuracies,\n", - " tf.float32)\n", - " i = tf.constant(0)\n", - " while i < hp.max_steps:\n", + " autograph.set_element_type(test_accuracies, tf.float32)\n", + "\n", + " i = 0\n", + " while i \u003c hp.max_steps:\n", " train_x, train_y = get_next_batch(train_ds)\n", " test_x, test_y = get_next_batch(test_ds)\n", " step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)\n", @@ -953,175 +879,147 @@ " train_accuracies.append(step_train_accuracy)\n", " test_accuracies.append(step_test_accuracy)\n", " i += 1\n", - " return (train_losses.stack(), test_losses.stack(), train_accuracies.stack(),\n", - " test_accuracies.stack())" - ], - "execution_count": 0, - "outputs": [] + " return (autograph.stack(train_losses), autograph.stack(test_losses),\n", + " autograph.stack(train_accuracies),\n", + " autograph.stack(test_accuracies))" + ] }, { + "cell_type": "markdown", "metadata": { - "id": "cYiUQ1ppkHzk", - "colab_type": "text" + "colab_type": "text", + "id": "cYiUQ1ppkHzk" }, - "cell_type": "markdown", "source": [ "Everything is ready to go, let's train the model and plot its performance!" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "K1m8TwOKjdNd", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "output_extras": [ - {}, - {}, - {} - ], - "base_uri": "https://localhost:8080/", - "height": 988 + "height": 585 }, - "outputId": "f9d3eef3-5bea-45c1-ddf9-4edee73e4436", + "colab_type": "code", "executionInfo": { + "elapsed": 17094, "status": "ok", - "timestamp": 1522345800262, - "user_tz": 240, - "elapsed": 52391, + "timestamp": 1531750930585, "user": { - "displayName": "Dan Moldovan", - "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", - "userId": "112023154726779574577" - } - } + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "K1m8TwOKjdNd", + "outputId": "9f63da19-c3bf-498b-cf00-29090bf3b4f0" }, - "cell_type": "code", - "source": [ - "with tf.Graph().as_default():\n", - " hp = tf.contrib.training.HParams(\n", - " learning_rate=0.05,\n", - " max_steps=500,\n", - " )\n", - " train_ds = setup_mnist_data(True, hp, 50)\n", - " test_ds = setup_mnist_data(False, hp, 1000)\n", - " tf_train = autograph.to_graph(train)\n", - " (train_losses, test_losses, train_accuracies,\n", - " test_accuracies) = tf_train(train_ds, test_ds, hp)\n", - "\n", - " with tf.Session() as sess:\n", - " sess.run(tf.global_variables_initializer())\n", - " (train_losses, test_losses, train_accuracies,\n", - " test_accuracies) = sess.run([train_losses, test_losses, train_accuracies,\n", - " test_accuracies])\n", - " plt.title('MNIST train/test losses')\n", - " plt.plot(train_losses, label='train loss')\n", - " plt.plot(test_losses, label='test loss')\n", - " plt.legend()\n", - " plt.xlabel('Training step')\n", - " plt.ylabel('Loss')\n", - " plt.show()\n", - " plt.title('MNIST train/test accuracies')\n", - " plt.plot(train_accuracies, label='train accuracy')\n", - " plt.plot(test_accuracies, label='test accuracy')\n", - " plt.legend(loc='lower right')\n", - " plt.xlabel('Training step')\n", - " plt.ylabel('Accuracy')\n", - " plt.show()" - ], - "execution_count": 0, "outputs": [ { - "output_type": "stream", - "text": [ - "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/autograph_mnist_data/train-images-idx3-ubyte.gz\n", - "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/autograph_mnist_data/train-labels-idx1-ubyte.gz\n", - "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/autograph_mnist_data/t10k-images-idx3-ubyte.gz\n", - "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/autograph_mnist_data/t10k-labels-idx1-ubyte.gz\n", - "Step 0 train loss: 2.244329 test loss: 2.2499208 train accuracy: 0.12 test accuracy: 0.161\n", - "Step 50 train loss: 0.64771986 test loss: 0.56013924 train accuracy: 0.82 test accuracy: 0.836\n", - "Step 100 train loss: 0.49011207 test loss: 0.42143965 train accuracy: 0.84 test accuracy: 0.879\n", - "Step 150 train loss: 0.3768609 test loss: 0.39319593 train accuracy: 0.88 test accuracy: 0.883\n", - "Step 200 train loss: 0.36007702 test loss: 0.37089333 train accuracy: 0.9 test accuracy: 0.881\n", - "Step 250 train loss: 0.182115 test loss: 0.28543878 train accuracy: 0.94 test accuracy: 0.915\n", - "Step 300 train loss: 0.2119576 test loss: 0.22305593 train accuracy: 0.92 test accuracy: 0.93\n", - "Step 350 train loss: 0.12932214 test loss: 0.29057172 train accuracy: 0.96 test accuracy: 0.906\n", - "Step 400 train loss: 0.22937602 test loss: 0.2200287 train accuracy: 0.92 test accuracy: 0.925\n", - "Step 450 train loss: 0.23444137 test loss: 0.19857481 train accuracy: 0.94 test accuracy: 0.94\n" - ], - "name": "stdout" - }, - { - "output_type": "display_data", "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAe8AAAFnCAYAAACPasF4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzs3XmAFNW9Pvynlt5mYdhmQMHggnGN\nS9zCD0ElKug1edUY9ZoQTYze3GuiRk1uYjRqRHNj4n5NrhKjiUYlbihGQFRUFDSoKIvgICAO6+xL\n711V5/2jlq7qZaZnpnumZ3g+/zjTXV1dXSP91PecU+dIQggBIiIiGjLkwT4AIiIi6h2GNxER0RDD\n8CYiIhpiGN5ERERDDMObiIhoiGF4ExERDTEMb6JeOOigg3DllVdmPf6rX/0KBx10kGe766+/3rPN\ne++9h9mzZwMAtm3bhkMPPdR57osvvsCPfvQjzJw5EzNnzsTZZ5+NV199FQBw0003YdasWZg1axYO\nO+wwnHLKKc7v4XDY8x7JZBLz58/v9edavXo1Lr300oK2XbBgAebMmdPn97J19/rZs2fjhRde6PO+\niYY7hjdRL3366aee0Ewmk1izZk3WditXrsQnn3xS0D6vu+46TJs2DYsXL8bixYtxyy234LrrrsPO\nnTtxyy23YNGiRVi0aBHGjRuH3//+987vVVVVnv188sknfQrUI444Ag8//HBB2y5fvhxTpkzp83vZ\n+vt6oj0Zw5uol0444QQsWbLE+f3tt9/GV77ylaztrrnmGtx+++0F7bO+vh5HHnmk8/uRRx6JxYsX\nY/z48QUfV3NzM3784x/jo48+wkUXXQTAbAF48MEHMXPmTOi6jlWrVuHcc8/FrFmzcOaZZ2L58uUA\nzFaB0047DQBw//334ze/+Q2uuOIKfP3rX8d5552HxsZG533ee+89HHzwwVnv9cEHH+Bb3/oWTjvt\nNJx//vloaGgAAOzevRsXX3wxzjzzTJx66qm4++67cx5rPu+99x7OOecczJo1C9/+9redC6Vc++3u\ncSEE/vd//xczZ87EKaecgjlz5kDXdQDAwoULcdZZZ+GMM87AN77xDbz33nsFn3eiwcDwJuqlM844\nAy+99JLz+z//+U/MmjUr53ZCCCxatKjHfU6fPh1XXnkl/va3v2HTpk0AgHHjxkGSpIKPa+zYsbjm\nmmtw1FFH4YknnnAeF0Jg8eLFUBQFv/71r3HppZdi0aJFuPzyy3HTTTfl3NeiRYtw/fXX49VXX8WY\nMWPw7LPPAgA2bdqE2tpaTJgwwfNe4XAY//mf/4lrrrkGS5Yswfe+9z1cddVVAIBHH30Uxx13HF5+\n+WUsWLAADQ0NMAwj57FmikQiuOqqq3DDDTdg0aJF+OEPf4jrrrsOhmHk3G9jY2Pex1944QUsWrQI\nzzzzDJYsWYKGhgY8+eSTAIBbbrkFDz74IBYuXIibbroJr7/+esHnnWgwMLyJeun444/Hxo0b0dLS\nglgshlWrVmHKlCk5t73++uvxhz/8AYlEott9/v73v8d3vvMdLFiwAGeddRZmzJjhBEt/nXzyyc7P\n8+fPxxlnnAEAOOaYY5zqONOxxx6LCRMmQJIkHHLIIdi5cycAYMWKFTk/6wcffIBx48Zh6tSpAICz\nzjoLX3zxBXbs2IExY8bg7bffxvvvvw+/34+77roLdXV1BR376tWrMX78eBxzzDEAgJkzZ6KtrQ3b\nt2/Pu998jy9duhTf+ta3UF1dDVVV8e1vfxuvvPIKAGDMmDF46qmnsH37dhx77LH45S9/WdjJJRok\n6mAfANFQoygKTj/9dCxcuBCjR4/GiSeeCFXN/U/psMMOw3HHHYdHHnkERx99dN59BgIBXHrppbj0\n0kvR2dmJRYsW4fbbb8fEiRMxbdq0fh3vyJEjnZ8XLFiAv/3tb4hEIjAMA/mWNqiurnZ+VhTFaV5+\n5513cMkll2Rt39nZiYaGBk8LhN/vR2trKy655BIYhoFbbrkFjY2N+M53voOf/OQnBR17a2srRowY\nkXVsLS0tefeb7/Guri48/PDDmDdvHgBA13WMHj0aAPCnP/0Jf/rTn3Duuedir732wvXXX4/jjz++\noGMkGgwMb6I+OPPMM3H33Xdj1KhRPfbZ/vSnP8W5556LiRMn5ny+tbUV69evd6rWESNG4Pzzz8ey\nZctQX1/f7/C27d69GzfccAOefvppHHLIIfj8888xc+bMgl+vaRrWrFmT8yKkrq4O+++/P5577rmc\nr7388stx+eWXY8uWLbjsssucSronY8aMQXt7u/O7EAIdHR0YM2YMVFXNud+pU6fmfLyurg4zZszA\nd7/73az3+dKXvoTf/va3MAwD8+fPx7XXXotly5YVeGaIBh6bzYn64Oijj0ZjYyM2btzYY4VWV1eH\n73znO7j//vtzPh+Px3HllVd6wmLr1q34+OOPceyxx/bquFRVRTgczllRt7a2oqKiAvvvvz80TXMq\n0EgkUtC+V69ejYMOOgh+vz/rvY488kg0NTXh448/BgA0NDTgZz/7GYQQ+PWvf4133nkHgBmSY8eO\nhSRJ3R6r7YgjjkBzczNWrVoFwBxfMH78eEycODHvfvM9/vWvfx0vvPACYrEYAOCpp57C888/j9bW\nVnz/+99HOByGLMs48sgjezXWgGgwsPIm6gNJknDaaachFotBlnu+Bv7BD36Ap59+Oudze++9N/70\npz/hvvvuw5w5cyCEQFVVFX75y196RqAX4phjjsEf/vAHTJs2DW+++abnuYMPPhjTp0/HzJkzMWbM\nGPziF7/Ahx9+iNmzZ+O///u/e9y3fYtYvve67777cOuttyISicDn8+Gqq66CJEm48MIL8etf/xq3\n3norhBCYMWMGpkyZgh07dnheryhK1ntWVFTgnnvuwa233opoNIrRo0fjrrvu6na/I0eOzPk4AGzc\nuBHnnHMOADPYb7vtNowePRrTpk3Dt771LSiKAp/Ph9tuu61X551ooElcz5uIiGhoYbM5ERHREMPw\nJiIiGmIY3kREREMMw5uIiGiIYXgTERENMUPmVrGmpq6i7m/UqAq0tUWLus89Ec9j//Ec9h/PYXHw\nPPZfsc9hbW11zsf32MpbVbPvKaXe43nsP57D/uM5LA6ex/4bqHO4x4Y3ERHRUMXwJiIiGmIY3kRE\nREMMw5uIiGiIYXgTERENMQxvIiKiIYbhTURENMQwvImIaNh6443XCt723nvvxI4d23vc7sMP38cN\nN/y8P4fVbwxvIiIalnbu3IFXX11c8PZXXXUt9t57QgmPqHiGzPSoREREvXHXXb/D+vXr8Mgjc2EY\nBnbs2I6dO3fgnnv+iN/+9jdoampELBbDD35wOaZOnYYf//hyXHPNz7F06WuIRML44out2L59G668\n8lpMmTI153u89toSzJv3dyiKgoMOOgS33XYL6us34M47fwefzwe/349bbvktdu7cnvVYdXXuqU8L\nsceGd0c4gfc3NOLYg+sG+1CIiIa9f7z+GVZuaCzqPo87uA7nz5ic9/l///fZeO65f+D7378MDz/8\nIDQthT/+8c9oa2vF8cd/DWeccRa2b9+GG2/8BaZOneZ5bWPjbvzhD/fh3XeX44UXns0Z3tFoFA89\n9AAeeeQJVFRU4Oc//yneffddvPzyyzjnnPMwa9a/4YMPVqK1tQUvv7wg6zGGdx9ceecbaO2M46ZL\njsOk8X0/gURENDQccshhAIDq6hFYv34dXnzxOUiSjM7OjqxtjzjiKABAXV0dwuFwzv01NHyBiRO/\nhIqKCgDA0Ucfg/Xr1+PEE0/CH/7wP2ho+AJf//ppmDRp35yP9cceGd5b23YiPOFNSMnD0dwRZ3gT\nEZXY+TMmd1slDwSfzwcAWLJkETo7O/HAA39GZ2cnfvjD2VnbKkp6gREhRM79SZL3OU1LQZJCOPbY\n4/HnP/8Ny5cvw5w5N+PHP74652Nf/eqxff4se2R4f7ztCyjVbTBG70RLZ3ywD4eIiEpAlmXoup71\neHt7O/baa2/Isow333wdqVSqT/vfZ59J2LbtC0SjEVRUVGLVqg9x1VU/xrPPzsOUKSfi9NPPgBAC\n9fUbsGXLpqzHGN69dPykA7G4CZArO9DSwfAmIhqOJk3aD59+ugH33XcnKiurnMdPPnkGfvGLa/DJ\nJ2vxb//2TdTV1eGRR+b2ev+hUAhXXHEVrr32J5AkGUcccRSOPfZY7NzZghtv/AWqqqrg8/lw/fU3\nob7+06zH+kMS+doDykxTU1dR93fjit+ipTOCQyLn4yfnHlHUfe9Jamuri/632dPwHPYfz2Fx8Dz2\nX7HPYW1t7m7dPfY+7y+P2Q+SL4mmcOtgHwoREVGv7LHhPbFmPACgLdk2yEdCRETUO3tseI8JjQIA\nxBFGPKkN8tEQEREVbs8N74rRAADJH+egNSIiGlL22PAeW2FW3pI/xtvFiIhoSNljw3uME96svImI\naGjZY8M75AvCLwcg+eNoZuVNRDQs9WZJUNtHH32ItjbvnUjlsAyo2x4b3gAwMlDDypuIaJjq7ZKg\ntn/+88Ws8C43e+QMa7a6ijFojDWiqSt7UnoiIhra3EuCXnDBRbj99lvQ1dUFXddx9dU/w+TJB+Lx\nxx/Fm28uhSzLmDp1Gg455FAsW/YGtmzZjDlz7sD48eOz9pu5DOjVV1/nLANaWRkCIJdkGVC3PTy8\nxwItQKfWPtiHQkQ0rD332UtY1bimqPs8uu4rOHfyWXmfdy8J+uijf8YJJ/w/fOMbZ2PLls24994/\n4J57/oinnnoc8+cvgqIomD//WRx33NcwefKXcc01P88Z3LmWAf3ww/fx1ltLcc4552H27AuxaNHr\nJVkG1G2PDu/a0FgAQAysvImIhrM1a1ajvb0Nixe/DABIJMzu0pNP/jquvvq/cNpps3D66bN63E+u\nZUDr6zc4S362tOzClCknlWQZULc9OrzrKszwTildMISALEmDfERERMPTuZPP6rZKLjWfT8VPf/oz\nHH64dy2L6677JbZu/Ryvv74EP/nJf+Chh/7a7X5yLQMaCAScJT/XrFlZsmVA3fboAWt25Y1gFNE4\nZ1kjIhpO3EuCHnro4XjrrTcAAFu2bMZTTz2OcDiMRx6Zi0mT9sX3v38ZqqtrEI1G8i4lCniXAQWA\nVas+xEEHHYpnn52Hzs4OfPOb38QFF1yE+voNzmOnn36G81ix7NGV96hgDSQhQw5EEYmnUBXyDfYh\nERFRkbiXBP3hD3+E2267Gf/1Xz+EYRi4+urrUFVVhfb2Nlx22fcQClXg8MOPwIgRNTjqqK/ihhv+\nG7/97Z3Yf/8DPPvMtQzokUcehVgsihtv/AVGjaoBIJdkGVC3PXZJUHvZtp++fgvicQM/O+pa7L/3\niKK+x56ASwj2H89h//EcFgfPY/9xSdABEpCCkNQUIvHUYB8KERFRQfb48A4qIUiqhs4oJ2ohIqKh\nYY8P7wo1BABoj4YH+UiIiIgKs8eHd6XPvFevIxEZ5CMhIiIqzB4f3iMClQCAjjjDm4iIhoY9PrxH\nVZgj+Xa1tw3ykRARERVmjw/v0RXm7WE7OjrQHk4M8tEQERH1bI8P70qfOWBNUpNYvallkI+GiIio\nZwxvn9nnDSWFpvbY4B4MERFRAUo6Peodd9yBDz74AJqm4T/+4z9w+umnO88tX74cd911FxRFwfTp\n03HFFVeU8lDysm8Vk9QUWjvZbE5EROWvZOH97rvvYuPGjZg3bx7a2tpwzjnneMJ7zpw5ePjhhzFu\n3Dh897vfxcyZMzF58uRSHU5eITVo/qBoaOviRC1ERFT+Shbexx13HI44wlx6bcSIEYjFYtB1HYqi\noKGhATU1Ndhrr70AACeddBJWrFgxKOHtV/wAAJ9foK2NlTcREZW/koW3oijOYuXPPPMMpk+fDkVR\nAABNTU0YPXq0s+3o0aPR0NDQ7f5GjaqAqipFPcba2mqM1M3K2+8XaI8kMXZsFSSu690r+SbOp8Lx\nHPYfz2Fx8Dz230Ccw5IvCfrqq6/imWeewV/+8pd+7aetLVqkIzLZK78IISBLMiTFQCKpY+u2NlQG\nuTRoobgKUf/xHPYfz2Fx8Dz237BYVWzZsmX4v//7P8ydOxfV1ekDqKurQ3Nzs/P77t27UVdXV8pD\nyUuSJPhlP2TVXHi9jYPWiIiozJUsvLu6unDHHXfgwQcfxMiRIz3PTZw4EeFwGNu2bYOmaVi6dCmm\nTp1aqkPpkV/xAbIZ3h2R5KAdBxERUSFK1mz+8ssvo62tDVdffbXz2AknnICDDjoIp512Gm6++WZc\ne+21AIAzzzwT++23X6kOpUd+xY9kyhxpHo5xXW8iIipvJQvvCy64ABdccEHe54877jjMmzevVG/f\nKwHFjw6YS4IyvImIqNzt8TOsAYBf9kMXZmhHGN5ERFTmGN4w+7wNGIBksPImIqKyx/BGeqIWyDrC\ncYY3ERGVN4Y3zD5vAGZ4s/ImIqIyx/CG2ecNAKrPYJ83ERGVPYY3rPu8AYRCEitvIiIqewxvpPu8\nQyEgHNMG+WiIiIi6x/BGus87GABiCQ26YQzyEREREeXH8Ea68g5YS3tH46y+iYiofDG8Afhls89b\nVc2KO57UB/NwiIiIusXwRrryllUBwGw6JyIiKlcMbwABJQAAzrKgrLyJiKicMbwBhFQzvCXVrLjj\nSVbeRERUvhjeAIKKNVJNNu/xjiVYeRMRUflieAMIWpW3IZkVd4yVNxERlTGGN4CQGgIAGJJZecdZ\neRMRURljeAMIWgPWdCQBsM+biIjKG8MbgCqrUCQFmhXe7PMmIqJyxvAGIEkSgmoAKWGFNytvIiIq\nYwxvS1AJImkkAABxTtJCRERljOFtCaoBJHQrvDlJCxERlTGGtyWkBpHQk1BkNpsTEVF5Y3hbgkoQ\nAgKBoOCtYkREVNYY3hZ7opZgCIiyz5uIiMoYw9sSVM0pUitCQCSWGuSjISIiyo/hbQlZ85sHQwJJ\nzUAixaZzIiIqTwxvi115B4IGAFbfRERUvhjeFrvP2+c3wzvM8CYiojLF8LbYzeYqw5uIiMocw9ti\nV96yz+zrZngTEVG5YnhbglblLavmbWIMbyIiKlcMb4tdeUNheBMRUXljeFtC1mhzQzJDm+FNRETl\niuFtCTK8iYhoiGB4W+w+b3tNb85vTkRE5YrhbfHJKmRJdtb0TunGIB8RERFRbgxviyRJCClBZ01v\nneFNRERliuHtElQDiGlxKLLEypuIiMoWw9slqAYR1xJQFRmaJgb7cIiIiHJieLsErWZzRQE0Vt5E\nRFSmGN4uITUAAQHVLxjeRERUthjeLj7Fb/5XNRjeRERUthjeLn7ZBwCQVYGUzj5vIiIqTwxvF5+s\nAgAU1YCm9b/ybutK4MEX16G5I9bvfREREdkY3i4+xay8FaU4fd5PvFqP9z7Zjb8u3NDvfREREdkY\n3i4+u9ncZ0ArQrN5PKl7/ktERFQMDG8Xu89bUQwYQsAw2O9NRETlh+HtYjebS4rZZM5Z1oiIqBwx\nvF2c0eayGdq8XYyIiMoRw9vF7vO2K+9i9HsTEREVG8PbxWk2tyvvItwuRkREVGwlDe/6+nqceuqp\nePzxx7OemzFjBi666CLMnj0bs2fPxu7du0t5KAWxK2/I5ujwfjebC1buRERUfGqpdhyNRnHrrbdi\nypQpebeZO3cuKisrS3UIvebPCG8OWCMionJUssrb7/dj7ty5qKurK9VbFF1Ws3mxwlsqzm6IiIiA\nElbeqqpCVbvf/U033YTt27fjmGOOwbXXXgtJGtyUs6dHFZLdbM5mbyIiKj8lC++eXHnllZg2bRpq\nampwxRVXYPHixZg1a1be7UeNqoCqKkU9htraas/vcf9IAIBqLi6Gqqpg1ja94fObp9enKv3aT7kb\nzp9toPAc9h/PYXHwPPbfQJzDQQvvs88+2/l5+vTpqK+v7za829qiRX3/2tpqNDV1eR4Lx1IAgJSW\nBAA0t4TRVBPo83ukkpq1Pz3rvYaLXOeReofnsP94DouD57H/in0O810IDMqtYl1dXbj00kuRTJoh\nuXLlShx44IGDcSge9mhzQ+KANSIiKl8lq7zXrl2L3/3ud9i+fTtUVcXixYsxY8YMTJw4Eaeddhqm\nT5+OCy64AIFAAIceemi3VfdA8St2n7dZMevs8yYiojJUsvA+/PDD8dhjj+V9/uKLL8bFF19cqrfv\nE6fyBitvIiIqX5xhzUWRFEiQYMCsvDnDGhERlSOGt4skSfApPqfy7uk+7x3hXXjsk38grsUH4vCI\niIgADOJo83Lll33QhTVKvIc+7/s+eghdyTDGVdTi9H1PGYjDIyIiYuWdKagEkDQSAAC9m8p7W2MY\nXckwACBpJAfk2IiIiACGd5bairGIGREEv/oqtic3593ulfcbnJ8lzn9KREQDiOGdYXyFORe7pGpY\nrb2af0N3i/ogT+tKRER7FoZ3hnGV6YVUVPjzbifAe8CJiGhwMLwzjK+oTf8iCquoZTabExHRAGJ4\nZxhfOc75OYEINEPLvaGn8GZ4ExHRwGF4Z6j2V+EHX/4h9I4xgCTQGm/r8TXs8iYiooHE8M5h/5pJ\nMLpGAQCaYq05t/GMV8tTebNXnIiISoHhnYOqSBApc7BaLJV7KVLhSmbeKkZERAOJ4Z2DqsiAYU4+\nl8g7AYsnvYmIiAYMwzsHVZEhDAUAkNBzh3chzeZERESlwPDOQVUkQDfDO5knvHuD4U5ERMXE8M5B\nkiQo1pot21o6cm8kvNsTERENFIZ3Hgp8AICV9TuwsyWS9TxHkhMR0WBheOdhhzdkHZ2R7pvO2SxO\nREQDieGdhyqlwzsX4bpXzBD5lw4lIiIqNoZ3HnZ4S0qe6VFd3EFORERUagzvPHyKz5yIRdaR1Lqv\nrA2w8iYiooHD8M5DlRXAUCApOpKp7KZzd7HNZnMiIhpIDO88fKp1r7esI5nqofJmszkREQ0ghnce\n5ixrKiRFQ0LLUXm7f2blTUREA4jhnYeqyAVX3nqe8GZBTkREpcDwzkOWJXN+c1lHIpljxLn7VjEO\nWCMiogHE8M7DMIQ5YE0WSGiprOe9zeY9lNicw4WIiIqI4Z2HYQhAN+c3j2mJ7rdlnzcREQ0ghnce\nuiGcZUFjqXj2Bp5bxdi5TUREA6eg8F67di2WLl0KALj77rtx8cUX4/333y/pgQ023RAQyQAAIKKH\nu92Wfd5ERDSQCgrvOXPmYL/99sP777+PNWvW4MYbb8R9991X6mMbVIYhIBIhAEBMdGU9z1vFiIho\nsBQU3oFAAPvuuy9ee+01nH/++Zg8eTJkeXi3uJuVdzfh7VmYhM3mREQ0cApK4FgshoULF+LVV1/F\niSeeiPb2dnR2dpb62AaVIQREMggASCLXet7pwK7f1pZzxDkXLCEiolIoKLyvueYaLFiwAD/96U9R\nVVWFxx57DJdcckmJD21w6a5m86Scq8873VS+uy2CpvZY9hZ2djPDiYioiNRCNvra176Gww8/HFVV\nVWhubsaUKVPw1a9+tdTHNqgMwwAMFUJToSvRrOfdlTckkQ5q9zZW5c0KnIiIiqmgyvvWW2/FwoUL\n0d7ejgsvvBCPP/44br755hIf2uD60rhqAIBIhKCrkawAzry3O3ezub1taY6RiIj2TAWF9yeffIJv\nf/vbWLhwIc455xzcc8892Lp1a6mPbVBdcsbB+N7Mg+DTqwFZR0cyo49fSieyxMqbiIgGUEHhbYfP\nG2+8gRkzZgAAkslk6Y6qDFQGfTj56AkIiBEAgMZok+d54b63WxI5VyGxA53ZTURExVRQeO+33344\n88wzEYlEcMghh2D+/Pmoqakp9bGVhZAwP+fOsDe8vROziJwBzcqbiIhKoaABa3PmzEF9fT0OOOAA\nAMDkyZNxxx13lPTAykW1MgotALZ37fY8nll557rXO+lrARSJfd5ERFRUBYV3PB7H66+/jnvvvReS\nJOGoo47C5MmTS31sZWGkbzSAXM3mmaPNvQm9qf1ztO31OvwVYyBaTy71YRIR0R6koGbzG2+8EeFw\nGBdeeCHOP/98NDc344Ybbij1sZWFmmAVhACi1uIkH3zaiBfe3gJkNJvruje869s2AQCUmhb2eRMR\nUVEVVHk3Nzfjrrvucn4/5ZRTMHv27JIdVDmpCKpAVIJm6ACAB55fCwA4cLLrukcS6EqGcdt7D+Oc\nyf+GQ8cchNZ4KwBApHzs8yYioqIqeHrUWCw9g1g0GkUi0f0a18NFZVAFhAxN1z2Pp9y/S8DqjlXY\nEdmFBz5+GADQEm8DAIhkiH3eRERUVAVV3hdccAHOOOMMHH744QCAdevW4aqrrirpgZWLiqAPEBI0\n4Q3vpK65fhMQGQndaoe3prLyJiKioioovM877zxMnToV69atgyRJuPHGG/HYY4+V+tjKgll5S9AN\n74xqKS0d3lLGgDUhhFN5QzYY3kREVFQFhTcA7LXXXthrr72c31evXl2SAyo3duWti+6azYUnoBN6\n0pk+VZJ1DlgjIqKi6vOi3HtKNVkZVCGEbC5U4pIy3GEunAFtABDX4+mnFH2POVdERDQw+hzekiQV\n8zjKVkVQBSBBhze8tYzKO2GkB/DFtHR4S7LOAWtERFRU3Tabn3TSSTlDWgiBtra2kh1UOamw+ryF\n8PZdp3QNAdd2yTzhzT5vIiIqtm7D+4knnhio4yhbiixDEjIMpKC5J2KR3TOsGXkrb7DPm4iIiqzb\n8J4wYcJAHUdZkyUJAgZSWrqpXJJdt4pJQMod3qmoazsDBpjeRERUPH3u8y5EfX09Tj31VDz++ONZ\nzy1fvhznnXceLrjgAjzwwAOlPIx+kyADEEhprn5v1R3eAkmRDu+2RIfn9ULSQEREVCwlC+9oNIpb\nb70VU6ZMyfn8nDlzcP/99+PJJ5/EO++8g88++6xUh9JviiRDSAaSrvD2VN4QSBnp9c2d8DbM0ysk\n721mRERE/VGy8Pb7/Zg7dy7q6uqynmtoaEBNTQ322msvyLKMk046CStWrCjVofSbLCkABOJJVwgr\n3klaUsIV3vF28wctCAAQYOVNRETFU7LwVlUVwWAw53NNTU0YPXq08/vo0aPR1NSUc9tyoMgyJFmg\nI5JuGpdUb+Wtwd1sboV3yhwf+2w3AAAgAElEQVSPzsqbiIiKqeAZ1gbbqFEVUFWlqPusra0uaDtV\nMU+TUFzXOlblLTQVkj/puQu8PWk1m1vhDVkv+L2GouH82QYKz2H/8RwWB89j/w3EORyU8K6rq0Nz\nc7Pz++7du3M2r7u1tUW7fb63amur0dTUVdC2spABCdi2M31vu2SHt+5zqvB9qvZGQ3gHuhJh87lU\nABIAQ9IKfq+hpjfnkXLjOew/nsPi4Hnsv2Kfw3wXAiUdbZ7PxIkTEQ6HsW3bNmiahqVLl2Lq1KmD\ncSgFUWTzNHVE0/3akDXz/m3NvP6pwlhMm+AdnCecZnP2eRMRUfGUrPJeu3Ytfve732H79u1QVRWL\nFy/GjBkzMHHiRJx22mm4+eabce211wIAzjzzTOy3336lOpR+UxUF0IHOqGvaU1UDdBWQzHu4fQhA\nldOn0y/7ENet5nb2eRMRURGVLLwPP/zwbpcNPe644zBv3rxSvX1RqbIV3rH0oDQoGoSuOn3fighA\nkdN98j7Fh6iumE0bDG8iIiqiQWk2H2pUK5S7oq7R5opZeUtOePuhSunwViU1fZ+3zGZzIiIqHoZ3\nAXyKGcrhuN3nLZzK2x6spghvs7kqqxCaFeYyK28iIioehncB/NatYl0xK7xlA5IkzD5v2A/5PM3m\nqqxA6NbvisaVxYiIqGgY3gXwWfeX68KqoJ3bxNzh7Tebyi2qpMIwrN9lnUuTEBFR0TC8C1Dh91k/\nmRHszGvuCm9J+Jy+cQBmFW5V3pJVeXdGk7j/2dXY1hgekOMmIqLhieFdAL/PCm/JmkdNsSpwwzXj\nm65mNJur6cpcMdf0/ufyrVi1sRn3Pbt6AI6aiIiGK4Z3ARTJOk2SgCSlK293szkMNavZ3A53STYr\nb3s98GSKA9iIiKjvGN4FUOxbwCSByqAPvoDVg+2qvCXd22yuuprNoegw2OlNRERFMmQWJhlMslV5\nS5JAZciHsGrAACB0BYmNR0EZ2QRZqvbcKqZIKgAZQpedypuIiKgYWHkXIN1sbuDEr4yHL2D1fRsq\njLbxSG35CoRhB7b9GsXZxu7zdkjSwBw4ERENSwzvAshWEP/7qZNx5tcmweczk9i5jxuArouMZnPV\n2UbKvM+bVTgREfUDw7sAduU9fmwIkiRB8dmVtyu8hchoNndV3jL7vImIqHgY3gWQrSVBDWGGtqxa\no8Vdo811XXjmNreb0IWuAIoGwzDSO2SzORER9QPDuwB2Fa1b4S1Z93m7m80NQzgD2wCkg1xXIUlA\nyuDiJEREVBwcbV4AO5S/6NyGz9o3A0rKfMJwVd6GAclVUctOs7n537iWXguceuetj3dgQm0lDti7\nZrAPhYioLDC8C2D3eS/e+joAQLZWGfMMWDPSk7CYr/Fuc+fqe3AELhqQ4x1OYgkNjy7cAAD4yy9m\nDPLREBGVBzabF0B29WUDgJCyJ2nRDYHHXql3frfDW1LNKj2hJyBghntnJIkHnlsDg6POe6TpRs8b\nERHtYRjeBVAk72kSMMy7vQxvn/fGhnbXa8zntN2T0tsgXZl/UN+Enc2REh0xERENZwzvAmSGNwAr\nuNN93PGk7unztkebG51jobWMN3+Gd05znfeP9YhniIgoG8O7ALKsZD+oe4cLRGIpz+8KXK+xKnQD\n3hHnDO+esWeBiCgbw7sAco7KWxgZ/eAwB1c5r5HdK45Z94lL3srbYHj3iOdoePvX+t247I6l2N0a\nHexDIRpSGN4FyN9s7hV2Vd+K69TaQS/YbN5rDO/h7c8vfQLdEFi2eudgHwrRkMLwLkDmaHMATjXt\n5g7j5vZk+glhbtssbYLkT1cYKY6k7hFH5A9v/PMS9Q3DuwC5Ku9RVaFuX7P4vW3pX6yg362uQ/Co\nt5yHUymGd08Y3kRE2RjeBcjV5z1+VBWqQj4AQCjQw1w3OZrYAVbehWCzORFRNoZ3AZQczeaqrDrL\nfFZX+Lp9vcjRxA4AyZSe83FKY3jvGbhWD1HvMLwLIOf4ZnEv/1kdyg7vQyeNxrUXHoWvHTouu/KW\nNQACb3e8jHe2vwcAWPDOFsxd8ElRj3s4YHYTEWVjeBcgksq+jcW9/Gd1hT/r+ZHVfhy272jzOeE9\nzZI/DskfxxfJT/HEp88CAJ5ftgUr1u0q8pEPnLWbW7BibfGPn5U3EVE2hncBJo3YBwDw5VGTncfM\nZnPz5xGV6fAWmlmFj6kYCQCQ5exmcykQg+RPrzJmrxMOwGmKz2f+ss34+LPmPnyK0rrrHx9j7kvF\nbznggDUiomwM7wJU+6vwwIw7cOa+pzqPqa5Z1/yq7Axei6+ZisTGo3DUhP0BABUBNavZXArEIAVi\nzu+NkXQYdxdWndEkXnznc9z7zOr+faAS6unio7fKObxffGeLs+IZEdFAYnj3guIKbFVWPfNu71NX\nBQCoCYzAdbPOwKTx1QCAiqAv655wSUlB8qfD+4GP/+KsEa7p3YR3JJn3uXKR1Io7gr6cm83nL9uC\ntz7eMdiHMaSV8bUZUVljePeCu59blVQ4y2ZIwMGTRgEAamtCzs8AUBlSs/q8IQlP5d2aaIW692YA\ngN7N7WPhaCrvc+Wi2CPoyzm8iYgGC8O7F7Iqbye7JZxxwpfwzan74rJvHOp5TWXQlzUPOmTDCe+v\n7zPd3F9tAyBrWZV3Uk9i4ZZX0Z7oQGe0/CvvRLHD23U6/ufvH6KxPZZ/40HCC4y+4y1iRH3D8O4F\n9/3e7j5vSQJURcbZ0/ZH7UjvzGuVOZrNIRmQ/HEoIoBzDzwLB4a+AknVIPnj0DIq73/Uv4CXtryC\nFzctQke4/MM7WeRZ49x93vUN7Zj32sai7r8YONlO37HZnKhvGN69oHbT551PZUjN7vOWDEi+JFQj\nCAAwdOt5SUBzVXHtiQ6s2LnS+b0jo897xY6VWLBpUS8/RWkVu/IWGVVtOS7mknnBRURUagzvXvBU\n3pKCQtK7MuiDENnN5lBSkI0AAEDT0o+7+7zvWzU3/RJJdgashQLm/h7f8DQWbX0dulE+M7UVu887\nM6zLsYk6VeRBekREPWF490L2aHMzSLrrtzNvFcucpCUBSQIk3QzvlDUOTZIM6Fafd2ckieZYi/Oa\nqBZzKu+qjBndolr59AMnSthsnuv3cqAxvPuNfd9EvcPw7gXPaHO5h8VILLIsZd/n7TMnaJE0c3IX\n3S5WJQOaYQbBtX98C7rQceDIAwAAsVQMHZEEAMCvKp77qbuS4d5/mCJyH0vxR5tn/l4e4a27Dox9\n3n0nCup8IqJMDO9eUFyBrcqq606xHsoG4X1e8pshbM/GJgzreUkgkdTxxqrt0GWzyq5UK+BX/Ihq\nMcQTZjDqhkBcT8/QFklF+vyZisHdtF30Pu/MyrtMwlvT0sfR3b35ROXglZUNWL+1bbAPg4qI4d0L\n7nW9PQPWemzyywhvnxnMwqq8Dd16Xjbw7Fub8bfFn0JSzI5wvxxEhRpCTIs5FZ4hBLqS6cAO55h7\nfSC5w9tdeacMDXd/+Ces2LEy18sKkt1s3uddFZW72uaANSpniZSOp17biN8/uWqwD4WKiOHdC1kD\n1iw9ZXfdqFDOx42kWXk74S0Z2N5kNoFLqtkRHrDCO6rFnYFRhiEQTrnDe5Arb1d4ufu8t3Y24LP2\nLXh8w9N933eZjjZ3BzYHrFE5K5fWKiouhncvSK5RNYprkpae3Pz94zC+60QkPj3G83gqYTbD233e\nkiTgXApY06X65QBCaghxLY6UZm5oCOFpKh/sZnMtT+WtGVquzXsl84unXAasuQepsc+7H6w/Z+bY\nBiLqHsO7j3yyAvf0qN0J+lVMCnwZRkcthKv/OxaVYRgCuqvytkmqGXw+KYAKXxACAklh9pUbRmaz\n+WBX3q4+by0d3rmWUu2tzLDOvO97sLgvWDjavP/K5aJsOOK5HZ4Y3n2UOT1qTxTFOtVGeluR8iMS\nT0HX0n3e6RdYlbcUQIVaYT1vPmYIIJxKjzAPJ7NDMvMfbCKl4911u5zqvZjczebJZPrnLtcxLnrv\niz7tO/N7Rx/gLyJNN7BkZQOice+88u7AZp93/7Fpt3TKpauJiovh3UfmwiSmQu5R9dnh7VqkROgq\nWjsTcPLUGm0OpPu8VRFASA1ab2pW45l93pnN5l/s7sIPf7cUb3603XnsuTc346EFn2D+si0FfT63\ndVta8doH2/I+7xlt7ro4CLtuYVv4Xu/fF8jRbD7AX0TPv7UZT762EX9fUu953N1Uzmbz/mN4l065\ntFZRcTG8+6jQ+7xtimIlvHuFMV3BLY+uRDhid3obTsVsh7cCPypUc8CbZFXjhiE8TdI7Irs8s6wt\nX7sLAPDU6585jzU0dgEANu3o7NVxA8Cd8z7C35fU521+y9fn7b7/3JD7tiJa1gxrA/w9tHFbBwCg\ntTPheZwD1orD/nOyabd0mN3DE8O7j1RZ6dWiCnblLQz7vxKc028FuiS5m83NKlsRfgTUgPVYesBa\nXDPv8z669itoT3Rgbct656V2FSP7EqhvMwPcp5qj4/vTbJ6vOvKMNk+6wtvVIiAUb/gV/J5Z93kP\nbFC2dZnHPbI64Hnc22wuEI2nsO7z1gE9tuGEAVM6bNUYnhjefaRIhU2PalNVO6itjQ1X5S7Sk7TY\n7CpbNgLpJnopfatYzArv0yedAgB4a9sK57VOv/A+a3DvqofwcdNa+K33T/ajSszXt5tvkhZ35d3X\n8M5s8hvoUcntYfO4R1T4PY+nXIP0NM3AnfM+xp1PfYT6hvYBPb6hzv6nM9AXZW5CCCz9cBt2tQ7u\nfAmlwlaN4Ynh3UeqrOLkoycAAA760qgCtvc2m8vCHd7Wn8E1YE3yJyAMCbLwwWc10UtyepKWmBaD\nLBQ89sIuHDhyf2xo24hdkd3m7uzAC5lBMn/Ty/D5zPdI9WPu8XwDX9yjzd39v+5BdYbct+VMM9/S\nEAKabmTNvFYq9mfOfD8tY5KWLTvN7oimMlxvvJw5zeaD2POweWcnHnulHr+a++7gHUQJsfIenhje\nfeSTFXzntC/jjh9NwWH7ju5xe6fytprNZeSqvN3hHYNIhqDrrv512Wo2N4CYFofQfdi8vRMnjDfv\nH/+0bZP5vN1vnjJHqTdGmyGpZgWZ7EezuZ5jGlAhhGeeb/e0oRHXKHhD6Wt4Zw9Yu/z3b2DO3z7o\n0/56w90FkDkoTcszYI0LbPRNb6vD9zc04sEX1xWlqozEzC6q4VqgsvIenhjefaTKKmRJwtiRuWdP\ny9o+Y7S5e7S63Q9uN5srqg7Jn4RIhKDpBnyKtYqY5K6844BuTtEaUioBAM+/vRHN7bF0FaOkB4nF\nVXOFMvfgqriW7hMvRGbl3RZvx/ee+ylWtb0PqGY4u0MtYbgCW+1bs3n2gDXzd7vSLaXWrvT88cmM\nFgv3eXT/LGekdyyhIZbo/2Q1A03TDWzd1TVg79fb6vCP89fivU92Y3cBTd2vf7itV5/FMATunPeR\n526NoYyV9/DE8O4j91SphVCt0eb2JC2q5FrW0x6wZjWLT5xo7lskQkhpRlazOWCGt6GZj8swt4+m\n4nhx+efpK20lHRqblXcANenp835iwzO4d9VD+KhpbUGfQc+oPjd1fI6ElsCyliUIHvkG4Is74W0I\nA5qhQYHVV9zHyjuzz3sgFwGx108Huq+8NU/l7Q3vK+5+C1fc/VaJjrB0HlrwCW55dOWA9eH3tTp0\n5k/Io7E9hsdfqcctj+afXz+ztaSxPYZ1W1rx10Wf9umYyg2ze3gqaXjffvvtuOCCC3DhhRdi9erV\nnudmzJiBiy66CLNnz8bs2bOxe/fuUh5K0Vz+le/hrP1metb2LoRdeUtWda1KKn71PWu61IxmcyVo\n9puKRAU03Ug3m9vN6rIBXegwUnZ4p5/fvKPTudIWcgp1FWMBAElEoY773FMl2iPUN1rN7T3pbrIH\nSTGgjGx0giypm8EXRJV1Avp2q1jSSHpaEIq95Gh3uqLp982cRU3zDFhzDTTsY7N5S6wVN7xzOza0\nbuzbDors/Q2NAFBQZVsMfa0OMy8oMxXy/0vmn2y4dX2w8h6eShbe//rXv7B161bMmzcPt912G267\n7basbebOnYvHHnsMjz32GMaNG1eqQymqI2sPxxn7fb3Xr3Oaza3wliDjgL1roMiS0w/ujDb3m1+Y\nduWdsjPErrytMBO6VZFb64VLio4dzRGzipEMQNYxOjAKFx30LfP5jACt9JnN7YVOr6plfAnYI95t\nysgmZxR2QrdmiDMqrffuW+W9UnseoWNegz20qbezRdW3fYZH1j2BVB/mWe+KuirvjLECqTxzm/e1\ngnz1izfRlmjH3DWPFbS9EAJPLKnHui2lvT2tKuTreaMi6Gu+9HSPfUE5LHX765DHPu/hqWThvWLF\nCpx66qkAgAMOOAAdHR0Ih8M9vGr4UuzR5s7tZVbftyJD2KPN7T5t1aq8k0GkdAML3ramFrUGrNnL\nhcIKbwjF83wkrjkBH1KDOGj0gZ7nbVU+c0BbOJk7vNvi7Xh03ZOQrIuJzCrHvtf8lJFnw4hXQK5u\ncyrUlNXfrYgghC47y6D2VhhmOMk1zX16/b2rHsL7uz/CxwV2Dbh1uirvzJDwDNJzN6FrfWz+tVpy\nNFHYRcb2pghe/WAb7pz3UZ/erzvukfWZF2yl4q4OOyNJrN3ckndbz/H11I3ShzJ6uGXdnlh5G0Lg\nd3//EP9c8flgH0rJ9G6asF5obm7GYYcd5vw+evRoNDU1oaqqynnspptuwvbt23HMMcfg2muvzeov\ndBs1qgKq2rum6p7U1lYXdX/dGdlsNT/azeaKitraavhUGYmUt887FJKBlFlZ+/wqWtpTwCjXJC5W\neAvdrIpGjrA+hxXOmiGchU1GVY/AXnXWrWzW/u3PXREIAl1AzIjmPBfLPnkbK3evQuAIGfH3T0f1\niJBnu+QXZrhVh6ogkgHIwSg0w0BtbTVi7eaAMkX2QST9gJrM+R7vbVuFUcEafHns/p7Ho/EUKoLp\nqk8Zux1GR61nm978/Xyh3v+9U64vPUOSPK/3B9LHJrv6XYMVfmc7dytBT+8dCJj/FDVDK+g4O+Lp\nC7Fi/3/c1pluUQm5Pk8pqT7FeZ9fPLQEja1R3Hftydhv75qsbSOx9EVVVXXQeV2u44y7rrnyfY6a\n1phnm5he+N9tKGgKpy+cC/k8w+Ezh6NJfNrQjk8b2nHJN78y4O8/IP9mSv4Olsz7ZK+88kpMmzYN\nNTU1uOKKK7B48WLMmjUr7+vb2orb91ZbW42mpoEbTdvVZX1BWAEqdKCpqQuyLOWYpMX6YhYyOrvi\nUCUFCddrneZva8BaW6s5ktsO/65Iup9Y0hR0tVnPWzO0NTZ2QpIkdMbMintXuAlf7GyCX/Z5+vI7\nwzFnv1IwjJaWCJpC5nsmkjpefOdTqOOARFQ4rQApI4mmpi7s6jAHOukpCdD8kIKRrPNtCAN3vvMQ\nAOCBGXc4j2/Y2oY7nlyF807e32yokAClphkpyfBML9ubv19LR5dn+22NYUgSMKG2yrPdZ9s7cO/T\nH+Pq849EY0u6RSIWT3le3+EKuIireb2tPepsF0+mq+jujvWtxmVY9Nkbzu/23ydTMqXj3U9247iD\n69DWnv73UKz/j1es24VdLVEctl/61sfWtuiA/DuJu85vo9XPvnFLC6p82Y2Dja576Ztawmiq9uf9\n99zSkm7ty/c5OjLOZVNzz68ZSlpb0/8f9/R5Bvp7sVQiroWEBvrzFPsc5rsQKFmzeV1dHZqb002d\njY2NqK1NV05nn302xowZA1VVMX36dNTX1+fazbDh3ELk6vMGkNHnbQ1YU+1FjmVougG/Yo3Ylg0E\nfIrTbG73ecPwNpvHk5qzTYUagk/2eZ5/7q3NeGfNTqfPOqEncd1bv8b1Cx/CZ9s7nGN2z58uBSOe\npuKuWNJpAZAMn3MsQtagG4bTbA5dgdB8kBQdCc3bdJ7Qc98+ttIaLLVw5RanA1JSNchVfR/5HE15\nJ0/59V/+hRsf/lfWds8s/QyRuIZnlm5yms1HVPg8zea6YeCL3el/nO4+b/e98O6R/d01Xc5bu8Dz\ne0TLfaG6YPnneHThBjz56saSNO3OXfAJFiz/HM0d6XM1UPO25+qXzddk7668e1qOtZAxEpnvM9xW\n4ervx+mIJPHBp03FOZgBMtz+hrmULLynTp2KxYsXAwDWrVuHuro6p8m8q6sLl156KZJJ88t85cqV\nOPDAA0t1KGUhff+vNe847D5vKV1NWuEtK1Z1bshIaQYCavo+74BPTt8CZjWb6xrM6t0K51hCd6rz\nkBqCIiuQhAzJev6fK7bi4X+uzxpwFg5twd3/SPehulcFU0Y1oiWRHhxlGMK5QJCFz6m8JUWDpgkk\nrNHmwlAgUubFR1vcezUaTaXf372wiv1FLqvmY0I3L07kEd5+UPMiQcsK5lw6k7mvhA1hYHc0/cVk\nT6aj6Qa6oklUhXwI+BVPiK1YuxtrXQPFtDyD19yz2el5phAzRPbjLbHcg9B2tZih/vmu0t7j3tKR\n/ruUMrzdrXG5Lm7yjST3hHcPo80LGayV+d7D7YvffQ76MjPh//z9Qzzw/JohNfVvrgmlhpuShfdX\nv/pVHHbYYbjwwgsxZ84c3HTTTXjuueewZMkSVFdXY/r06c5tZKNHj+62yXw4kK0Ba3a/tV15T6yt\nAmA1nctWVW6FNwwFKV0gqKbv8w74FSek7cldNN2AJBSn2Tye0JyAt5cTlaB41wuHQEJPoNJeKxyA\n0FTPVbq78lZrt+PvDQ8imdJR39BuTlpih7er8oaiIaUbSOr2iHgZ0Mzwrm/x3pIW1dKh61772/4y\ntS8OjC6zGVcOeQc8aprAw2sfw8+W3YRIKuoJccMQePHtLc5kOJ3J3IG3cMur+M27v3fudbfvCtB0\nga5oCtUVPvhUb3hv3tHh2UfKM2DNtba5PUJd0pHQcg9Ei2vZrQ/5Rv/b/w/phijpl1OLq0uglMud\nunMkZ3jnCdGwK7x7Or5CgjgrvHvY519eXo///r/lPe63XLg/X19Gntu3Cw6lqX/zXSwPJyXt877u\nuus8vx988MHOzxdffDEuvvjiUr59WXEqbzugrdHml5xxMPbfewdeiSsw7AFp9n+FDE0zYOj23Oe6\n2Wwup8MdsKoj4Qp1pG/NspcTlYXqHW2uaBAQ2K9mknO/t4hXpudgR+4Q+cvL6/Gv9Y046/9NgqRo\nELoC3YCn8tZ1A0nDWr5UV2AkzGOYt+kZHD/hSAStVdJirvDuSHRiZMAcnGR/v8jWoDsRr4DQFUhB\n7/GkdANrms1j//mymwEAPzjsIhwz7ij8a/1uzH97C0LHCEABOhPp4HdXa+/sMCfvWLV7DY6qPdwJ\n70RKRziWwoSxlYgndU94B/2utdxlAy0j/gUlPgJSMIIW3Q/AHHyXTBmAZCB45FuYv6kT3z3sW1nn\nM7P1AzAHreVi37FgGKLHirM/8lXeiZSOrmgSY2sKm1WwJ+4gyZWx+YI3Ek+fn54uYgoZaZ35Pj0F\n/turd1rbGVDk8p/nyhPehkAP89rkNZRaJIbSsfZV+f+fN0yMqQlaP3mbzasr/Pi3KftClc1wkkc2\nIpwKQ7Kq8ZRuIJGy/keUzD5vJ9ytyjulGea93q5wloLm1fLY0Bjzd6E4zeZAetWyoBLElV/5sfmg\nrHtmrAqnIhgd9C668q/1Zn/0xoYOc1CcrkI3BISRWXlbzea6DL3xSxBJM7Cjrv5cd3gv3rrUqdad\nudntixFdhYhXWp/JfZtQdoDtipjHZ37BC2cZ1Q5X5e2e6tReS317s9msbs+E12atJmZW3rInxGLW\nQLTbLjsBoTFtSFR/Dv8Bq+GbsAmrxItOU3hS0yH545D8CWzsyD0NbVw3gzKoBHDKPicCQM570qOp\nmPP31Q0BrciVhbs5tdm1drm7JeGOJ1bh539a4Zl5rj/0HirCfBcoUdd0sz1V3oWFd+ZtgIV98Wd2\nKQgh8PFnzWU3Ha773PYn1IbSLWdsNqeiGVUdwG2XneBUywq8k18okgJJ1RD48ofYEdllzaomIaUZ\nSKYEhCFDkg34fa6QFq5lPo2McA5EASFhTMhscpZyVN4AsGZjJ3738GdQjRAgG051J4RAOBVBlTWR\nS6akpluVt2p++en2RDEaNF044W1oMiBk6O3mYEU7oAHvILKPm9bipS3mGIms6V11FUas0hz17k+/\nRtMMyJL3f2FP8Lmmh+1IdDrv51621J4AJ2m9zl533V6UpLrSD58qw3AtwBK3ngv61Zwz7dmfMakZ\nkHxmOLfEWz2f3WZX3idNnIqJVXtnfwaYs9XNee9ObAq+CkBYK6sV98vJHUSeytsVjvZ88u5m6/5w\nh0GuUMn3GfU83RQ5t+1L5a27jyv//jOPb/naXbj3mdX466INPb7nQHJ/hP5c8w2lanYoXWj0FcN7\nAO01phIHpk6D3joOU+r+n+e5zBHGftkHVTFHmyeSulllywZURXaazYWr2VwYCiRfylkARA5GIeuh\n9LzoIqMyt5qk7XlzzIsD3ak8E3oSmqGhUq1AcrN5n6Q9hzoAJDTdDEddhaYJT5/3X15ej664GZS6\nZq+mZr7WDnXAW3kDwEeNa8xNnT5vb+UNAFIo3XSe0DSnYjyq1jzGlGEHp56ezAaAgMD7uz92nks/\nYU9ba430z2hTHFFhhjeQDji7sgr6FShq9rdh0khCNwxsbwxD8iec998VzZ4C2J7oJqQGnb9VKiPk\nP2hcjY5kJ8LyLsjVbVafd3Er77hrBTXPimnWZ3avsFasL3F3tZ85h735Prk/o2dq2j40my98dytW\nb2rJu02+VfIyZVbe9iDGTdtLv2hOb3gGBvbjNoVi/z9XSkPpQqOvGN4D7IpZU3H1cT/AtMO+1O12\nqqzCp5qVdyJlhndNtYq6kaH0wDO72Vw3zIFhAEJfXQqoCUj+BJRUFYQQ+OPzaxCNCUiygNPsbM8X\n7p5iVU734dn93RVqJTsCVIYAACAASURBVPTmCdA7R8OADsBuEk5Bks3Qjic1T5/3Z9s6sGqTGVS6\nZlXyVmVu94UDQDSjv7cl3obGaHO6/9OpvBWIlNns7p5mNZyMQkDgyNrD8c39Z5rnwqpaw7GU83q9\nrQ5CAG82LEdjW8QTRPY99kKyBt9l3F49wmo2B9Jf1vGEBglAwK84I+LdknoKTy/dhKde/wzwpZug\nd4R3ZW1rV95BNQjVuqVPM7zhvXLXh87PytjtVp934V9OWzq+wJ/XPo5wMpJ3tHE8zxzg9mduaDSv\n8tQJG/HytpcKfu/u9NRsnm+kuztce2w2z9hvNK7h6Tc24Z6nP05v002fd+b+3ecvc8rcZmtA11in\ni6w86D20cBS8nyE09Vyxu5XKEcN7gPl9Cg6eNKrb2eQAwCer8Cmy1WyuQ5FVCDkJQ04BkvWl4fR5\n604VDgDKSPP+eilZiVhCw/ufNmXdCy75zdCwQ1FYfeZ25b0zYgbNCL81QYDzeiu8EXFeH4lrnsob\nSFfYuqZ4Xp9wVd6fN5mVynXH/Bhn7GtOpdueaE9/Qcqu+9l17/EDQJc1rWuVr8IJPrvyjsRS6dHq\nsSroLXthV2wXfvXss3jh7S3pE23tLwnzizfzy7raVXl/vH0zVu/YjFhSRzCgQJYkyEqu8E5imTWo\nya68ge7DO6QE0pV3RrN5S7wN1b4qyEKFXNmRNWBtV2Q33tuZf33zf9TPx6rG1bjx5b/i9sdzbxfP\n009rn4+GRnNMgG/CJqxu/xDhWApPLKnvVxO6O0dyZUrmMqw276IwvWs2D8ey++u7azbPvIBwH1Pm\nc01Wd8OoEYFuj2mgGT3cklfwfoZQNbsn9HkP2Axr1DuqYt5fHEtqSGoGAlAQTnXhXflvgGwu4iKE\nq9lcl5wFFeRqs0lQSoUQtkfmWkHv+9IGpD4/DFLAXrnMHDls6GZzvH070spdq8ztIxMAtDmVM2Qd\nMFTEpU4oMEeoR42Uq/K2mrqtCwwtZVW2OZrN12zdBXUsMMJf5dzSFtPi6S8J2T52NT0JnSss7TnZ\nK32V8Cne4AvHNE+fubbjAKhjd0Ie0YpVG9OTB9n3w8eMMO54/37ElNEA9nKe19ROrK94BsrY/fBk\nwyLz/RJnOyPOJVflbUSrIFeEkdCTqAgoiCU0p88bAHZEssM77qq884V3OBnBmNAoSMlKdIR2Q0fK\nEzi3vncnAKAmMAL71UxCwJ7Uxz4uawBdomIbNm34ctYxAN5mczc7HKMJLf33ADD3pbVYs6kNmiHw\nvZkH5XxtT/L1ecuSBEOIvCuCefq8ezlgzT1ffa73BrxVW+b+3bPmuS/0hBDOQL5yC7nM0eZ9NZQC\nkc3mNGh8soqAT0VXxPyySfc3Cydw7C/7ZMqAcC2bKVeYVZIwZIStLys7PNW6bZBrmiFb4W3Ezfu8\nDatvWlENGMLA6uZPEMIIvPCKNWGIYc+/bo149pnNqCJe4am81boGyKN2Oc3Qeko2mxFzhLd7xLs7\nvJ2Kxj2TnD0Lnavytu9Dr/RVOLPI2f3F4VjK2b/QVAgtPdGNh7WNhhS2djag0f+x5+k2YxeSUgT+\n/dMLm8STOkKBdDcBACTWHwe9dbzzGYP281blXaFU5q68dbvPO2Teiw/vrWIpQ0Ncj6PaV4VqqRaS\nBBiBzpxNyvd/NBd3rLwPQgi8/uE2ayY2gdZ4m3ksqubchZApka/Z3AooTRee127aaf5/YfSjedId\nJO4+b7v1J6nluaDoplk7U+aXeFeOkfLdNptrmeGt53zOfftavhYDIQSefXOTZxbDgeAZbd6Ppu+B\nWqSmGPaE+7wZ3mXi9Emn4OBRBzr3OvtkFUG/4vzDE3L6S8eerGREyAy8aELzNM/KlWZ4G7qcbtbU\nXaOiJQEpEDOraWsCFfteclk2oBk6UkYKWjQEZ35Su9ncqnxl64vciFeiPZxIr3AGwDfhMwgrZFMp\nCaOqA1AlMzyT1rSpumFO8iKEOUNb0BXedsUl5HSfd2azPwBENKvyViucCxk7+CKxFKCmK3e4lk1N\nnwcjPSFOPkp2c3IslUDQnx5dDwBC8zvvsaO1AxV2ePviEJoPtYFx6Eh2eia+AbwD1pZ+YIb79tb0\ngCd7lrsqfyWqYN72JwKdeQcP7Yo2YmP7Jjz+Sj2WvN+AjmSXZzIcuTJ7MFVcS2Bly/KsVeeAdEBp\nugHZdZ99NGVdlAT7vmSonmcglWKHd54Q9FbehQ9YE0KgM9q7ZvPsyts1sM89IY/r4iffRcfnu7rw\nzxVbcftj+bs4SsGd17kGBvbEPb/AUFGs1oZyxvAuE//fAWfgJ0dfZt0iBqiy2WxuS0npL2A7qGsq\nrfCOa5B82TN1CUN2+vjcfeKSrEMKRK0mc8nZFgBk1XACUNcl1768fd72hCkiXoH2cDIdrjCb0u3K\nW+gKVEXGiKDZPG/fLhWOpszPofmh6cKpvONaHAnNACCQ8rUDwgxGu9nefTucHUqVvgrzVjtIzoC4\ncCzlDG4TKX+Oyl3Af9BKz2fPJQVr/vdPj4HeYYYnanYh4LcvqlyD6qxz8MTrG8zKXElBCkZhRKsw\n2mfeKpdZfcdc4b16o1khN7anJ5SxZ56r9lVBMsygFJLebRW0rvlT5+ftHebAwUlV5gBJKZQ9Teyz\nG1/EB13L4Nvn06zn7PBKaYZ5+6HFvmjpzz3NIk+zuT1oMpmnP7uvfd66Yc6alykz4LuvvHM3m7tb\nLqJGFz5sXJ31Pok8XROl1t8Ba3Z4r9ncUvKpeYuluwuw7vx10QZnEp5yx/AuM4p137JPVhH0eatl\nN2HIGGmFdyyhQW8dl7Uvs/K2vmxEOoilQBSSqkEkXTNluSpbu8/VHinuft4OTykQMydesSZnCfpV\nnDzCmkFMNiAk3ZqaVIJPlVFTYb5XJGmGVWs4BikQgxGvQDJleJrNkykdck0z9EA7KhP7mHO4Z1T+\nABC1Ku8qfyUkSYJPVhFJxLHgnS1WeNvN5n4AMoQhpcPfl4AywgxLvWW8s09ZT48U3m+vaucWPpEM\nQsTMufn9B6xG20irerIH1bmqe8jmjGxyVbvZzN01GmN85t9nc8fnnr+RHd6bGiLOHPbuPm+7X7/K\nX+lZtz39hZT9ZRxOpPvZd0fN/v0vjzBnN5QrO7NGnO+KmhPb5Ap2O7xSuuFtclfSa8c/99YmvPbB\ntqzX9iTfaPN05e0Nu2ff3ISXln+OsNTsdH/0ps9b13NX3lpGuOVbqx0AYnmazd2tBLvHvYSH1z6O\n19au97z2b1/8H/wHZy+GUyhDCM/FQ8Gv62cVav89GhrD+M2j72PTjg7c98xqRON9v3ArplxjI/py\nwZJM6Xjzox34y8vre964DDC8y4w9baoqKZ7KO4sho6bKbPKOJjSkPj8MX459A0YsPamK0CWn8naW\nEQUgWc3u9qxn9v4As9nZvlXJXXlnjVZXXCPMAVQEVewd2MfaRoMOzemHVhUZFX4zFCNWsOzobIYk\nCYh4BZKajpCSEd6VZr9gZXQ/81hzNJt3psyFEsYEzYrYp/iwszWM55dtMf/B+lyVN2BeaNjH75rn\nXa4I43jpAnNbpB8/cOJIp5lbaH7rIsDU5WswH5fTg+Ls1gF17834rKUBcrV5cWCER2Kczzw3G9q8\nM63t6mqF0GX88dkNCPnM/bv7vO1b9qp9Va7WAyNdfarZYbR5d5vzc7u1GEyNOhpGIgg51JX1ZaZI\n9oWZAXWfDVDqvjB/l1zN5prhad2RrM8djafw0vKt+PuSeuyM7Ma8T5/POV97Lkae+7ztqYTXbmnF\nR5+lBxf+c8VWvLBuGT6vfhnq3uY8+T32eXtmFzN6rLx1XXQ72txTeWu5K2+7p+mJ1z51LpQMYaAj\n1QZlRO5FZwpx3zOr8V93vdVta4emG57lMM337t993plTwP7+yVX46LNmvPnR9l7vq9iefXMTfnTn\nm9jZ4p06ubtBh24bt7Wjrcv8/zVfyLd0xEs6HXFfMbzLjF15GxCe8E6PJbcYCmoqrfCOpwChoEau\n9TRf64bkDFjzfPFat4l5mrqtn9uqV6fvv3Y1J2eFp6x7Xl8RUOFTFXMOckWH4QlvCRX/P3vfGW9H\nVa/9TN/19H5OzknvIR0SEjpEulIFiShYLyI2BEQR9PpD5aJX5d5XQbHAtYAIypULWABpIXRIg5De\nc0pO3XXKej+sMmv2npOQkJAE5vlAOHvKXrNm9jzr356/wfTMWax0xyDt5EUKSdiOhxjTYM+5eRSY\nJjgdgxEYq+w2H3D7YGkmKkxqERuqESB33eR9z/k51OD4GZztI2F5FehIjxDu/mRMxylzRvgxascA\nsf34rgGLndIG8VhnODZGNZaFPvlpkZvgZSphKnG0pVqwrm+DiGl7xENPoQskT5vT8LwDuXXqoBTz\nlhdQPO6rsAXKmMqRmBc/EwCwTYqZ9xeY7CuJg2QroJhF9GT7AxKnQo42MQijeQN0Rt5xU4fteMg5\nebylPwa10idSYXnnfCL58Su348mtS/D0tucwHAghIg8j2DDD30f+/Cf3Bd3PWgO18DU2Fsfx8Pra\nnmFL1uRzOR4JlXYtzXrfXZ33cAlroXFuhYhEtqCG/b4RAReWkRvHlOJbv34Bn//RUwGyGS488Xah\nlogfcC/Du9Uudnd4aMlGAMCK9cFFUVAlL/yaO3uz+O7/vIyb734RQDjJ9/Tn8dWfPov/vPe1sm0H\nGxF5H2Lgcp8e8QJu85OSl+CyKR8RWeeEqEjFDWiqItxXhq4G4rfE8RPWvLxvkYsabznWy+uwY9vx\n5zUPBT4D4Mufqi4AAqjB2vKERevS4eo0EU3xyds0NCTMIHl35ujLl1reXiDmXXRcEVsnpCRhTvXd\nxUNuH+pitejpz+P7v30Zjh20qDXTptYwczcTT2rqwv51drbD3dXC+qYbgOohldBx2xePRW1lDBk7\nA0MxAaIGLG9DYeSt+AI1gfkC/GQ3x4DjemhPt8EhDr551xMAgK5cD4jiwcumEDM1DGR4jbwjXLfC\n8jZTUtzft7x5XH989RjUaC1iO8cga4X6qwfXw8vSmv2/L1+BL972NJ5bSePv/YX+wHGinaylwXE9\nPLVlCfr1jZClCbjl3ZcpAIoL64h/iYXGa10rUIqubA8Gi0P4zSNv4KofP4Wt3ZlhNbdLX7YD2SJW\n71oLrWETVOba91gI47W1PfjRH1/Dd38d7o4W51IdvLTzFQxk/UUst4qD3+1hlf009DYa/y+zvAsS\nebthbnP/M0X1hFUnt9flHqF9xe6M561d9HmRPQHDLYzeLrRS5aJDEKW6GYFF2zBW86ad9J70MC3/\nsORH3tt+1cbesm0HGxF5H2KQyduSyLs53YA5jTNgKiwm66mIWzoqkqZI7DE0FfUVPkm7riLI29ky\nDsUNkwFIwiEy2Uj//0bvWwDoAsHfTv/fHLNMxHJlyzwRM6DrKrW8VQdEcaEpPB6uIcXc5isHX8fD\n6/+B3iJzKRcSKNgutnXmoCoqNnX3omh7Qq5UfAdhMWtOiEYBLhxk+k1c87MleHNzHwaGXMhtTxW9\n6LvMAboA4W5zVodOHJal7hIYGvMU6P6POGNnYaksN0Amb7Yw8VAU4QNSQt6KXqTKdVDgekR0U4Pq\nghCC7Sx5jeTSKBRdZLJ+jTx/6QvL20j6fd+lmDe3vFNGCiopDy2IVquOKch7xY4NAIA7HlwJQgj6\nCiWlSxqXf6WWt0tCrErNRUXSxMBQEUosCzVGX3KqomJd/4aApekRD//x4m345Yrf4cnXaDLQ+m0D\nw8a8S8l73dYB/PjV22GOXClkfVHiiXpdcq/L4C9xo/0N3Lf+TxhM+fFM/rIujXFvJstgtKwXf3O8\n2rUcAwU/L6DUba5YGcSP/Jv/5ap/H4ekKgO5MmRf8HZ6cpcuSDj25DbfuGMA/3X/soDrfTjyPpRy\nuEuH+HZi3p0lLU7DKjhMQyv77FBBRN6HGHi3MZd4Abd5OkGJQybvmKkFpBh1XUVrbYX4m3gaeofY\nKp9ocDtHiAYn/Bz+viGPQpjbHIA1eSn9n5KYNz1GBzQXRHVE85WYqSNp+eP86/q/+brmtonnV+7E\nt3/zIlxbw2CBveTYGF23xDvAPuelal2d0o+LaDSBTKUdxVy1ECBcriInn58vWlzPg6XSfXVDJu8M\nYiodu6gVB+ApLC9AsUXSXqAcDzSWbjHCdlxPiKcomgvXI9gyRInMy6XYi1AR96efuXeHbE7eKSGB\nC9XzrT5meadNSu6EoKScboiGDYgKkqXPhpyYtq2/F45EzqpnsnI6D6ahwnY9IfIiQzc8NFbHA+1n\nTxt5MmY3TAfgl8ABtAFNxsliTd86keCn6wrk05JhyAYA1m4rr4tWNV8NcHcQOvkshGEnt4ltnHxl\nK01uHSvvs7p3LX6+7C68YD8ItXon9Kb1cFwPHgsDFG0XetPGkkF6tIwS/n0EEBDuKRsv8fZIzm8n\nbC1n4e+N5X3lfzyOl1d34bkV5Tr85eMoP9ef1/zfbtX+9hV7Gnep5e0GLO/wY3mf8mSMl5mW73co\nl5lF5H2IgVvepJS848wFzcmbqIgZQfI2NEVYhAAATxUPKIXix39RalmHrDBD3OoyFATd5lTpTYei\nuVAU1gwFQNzUUBEP9oC2XZ6lreI1Fssjju5b1oxcbVt+80iWM3P9J9W0fz2uCkUliM/5BxQzD6J4\nAVc37bxGaDy9pDOb6xKYjFw1g373uv4NsD0HMS3BxufPnY0C8k4BLmxhvYfNkaHQc7oegcoFDVUX\nRdtFX44nDkpa2GyBMsAWXZt6ekA8BU7RD4koiivkTLmLO2kkqTEqhwYA5NwMVI/OPfdCFIlv+e0c\npLHCZHYU8svno4KwzHvdEfK8A5LL18vTcyUTCpK8xpuPQU+IOZRlcLn17xFPJPFpqor/emCZf97d\nSHgOZotlOR+q7ore67sDf4nzygoS8/MBuFXtegSKlYU17Um83hOMsfMXf3+BHpdVemGNewVG+5so\n2g4efHo9rvrxU1i5cRfUip7AsUrA8vYTqoazvPNOAf/+3K349crf7/aa9qTbrVZ24U3mPQPefsxb\nnvdk3F+YD7eYKP246Nr4+6YncNeqe3Y7vr2F63n42h1LcO9j4W11AZQ6YgJW9HCaCNt76LuxtoL+\n/sLc6283Uc0jBKs29r6jxi97i4i8DzHwhDW3JObNLW9D5a5XD5apo7bSJ0VdV8vIuxQyAQXIhoQ8\nCnsgd/m7EjFK3rL1yWO0MVNDMhaU7LSJLb6DJxEpngHD8jCiIeWXAknhQSK3PWX/5vNAR2MaJ8xs\nDYxXYdnqPMnMMjV/PjSnrDOb43rCbc47hf34lTsAUPUzOugY7E1UCtT2CljeQ12w3mBN4FwyTGbN\nP7NsOx54YpMYe9HxROa9fJ9UaIDioj9TxKadg+jNDQKOibVbB0RCG1RPlCxx8t64Nc+6z0neBcVD\nkRRgEDZ+fq3En1SejY5CEqZdA5PF8hXNgc403Tlx0fmk280YEd4WbnnHtLjwLshKev1539LnBNfd\nnwsmj9VtwD82/QuATzAXnzyOnstxxaLW3jyOnch5W+QtkvGYkp6iEtF5T7a89Za1UONZPLL1kcDx\nolQupMd63inikefpPX3hra1Q48GMZyhyzFsi72Es78c2P4nOXDde3PkqVvS8OSxp7l6m1IM14SX8\nz9q7sXGAVkS83WzzjTv9+/R2Er5K8wGyUmfEMG/NviKTc9DVl8fGnYPIOTlsGiwvS1R3Y3kPN/4u\nFs/mW12XAEYe8SMfwWObnsTTr2/H7//xVuixpXjxjU6ahf9WePjmQCAi70MMgZh3wG1OicVQfOvN\nKnWba6qwfADfsm6o9gl+WPIOURIjw7jNOUbUVfqHqwomtlcFysc42cRMXciJcnDxEz6GuKVjYmsD\nHGLjcxeNRnUFHWdBTiJ2Zbc3V3BTkU4YaKyOB65HJOUxy7syYYpriM96TKjQnbuQkoHjEphsMdLX\n9Dh6cr2iZGt2zZHivM6OUXD7a2ATGy/upPrvXBa11G0OACZbbK3fPuhnzGsOirYrVMrkudVVHVA9\nPL9qJ2761QtQDBq394h0P1TXLxdi5L1lexH5okv34d4JVmGQGeT3UQFxNbjwJ7UvT4nZKRiIW5oY\nLzQbpk7H1S/FefkCSTdcn7wNej5L8cm74PrWZU/Od3trjLxL1dOMjlV4YM1DcDwXhAAT26tw1GRa\nG19winCJC7evDs72MaKigTeMCUPfUAGPv7JVkJDcjc4cuRJqRXfA8pZDQNLFis5hgYQzBqphH1zA\naH3t0gLDldzmEnlb4eQtJ/r9v9fuxLLulaH7DVceV3RtaHV+WODxzc8AKPdqEELws9d/jf9dG1yo\n9PTnWRc8EiDm4cgvb7t4+LmNWL2ZlmxmbT+GvCvfF3rMvoCX5xUdF//96p34/gs/KRM7KvXWBGr3\nh1ns8NACfw4c1xNVDH9a81f86onnsKnTv++7s8K3sERBfr/fDUTkfYhBVcOzzbmVoTGZUUV1ETM0\n1ErkPba1UsiE0pPQ40c2paXPJHeYbJmXan5Lx7O9yzaPb6kVC4yBrI3KlIXjjmgX23lTkpipiZcc\nh6uxlxnLJm+pTeDoFkqSd6+6V4xHdpvLMWthgbsa0gkDDdWJwPWUlsNVJM3A9ei1NN5cnaIuccfz\nhCeBqDbuXf0AvcaqMWhPjQheuEv3W9e3EaYSA8nR+f33y+eXzRG3vAH4Cxtmeedt+sKvTib8/TUD\niurhjU19gOJC0VwQx4TrefA8iJg4J29uUaowaRmT7DYXjVmkBZur0wx5Bp4QZ+cNxC1dkLeiOeLe\n9hcly7tAnzdV84TbnBOXqcZD3ea9WXo8IUy6Vy+KEkYK/9nryXK3ugKTkXPe4wsxQ1wDNCcQ81ZT\nvbjnzT8LBb8f3PMq7n70TSxdxWK3hpSAVbMT1sQXBQm6HvGrGmRIuQWDdjl5F11byMNyD4ilxv3K\nDtUT3gWZvLmGQSl6MoMwSAIN8ToAwxPgcKpyj21+MqDBz/NKZPL9xV9XYe2urVjWvRKPbHxMfH7/\nk2tx9zPPIT7zceitawLqdjIxquke4bnY3p3BH59Yi+/9lraslWV4d2a7Qse4L8ixDP9C0cP6Aerp\nKG3y44sJ2VjbtwHLi0/AGE1DII7r4anXt+FP/1obOIaHRGQJYPkdEZu6BErCf/YzuxGl6WFW/HCS\nvgcCEXkfYjihbSEA4JSO42GZ5daAcFUzy7u1jr4oJnVUY1RzhbAeAQh3LHe5A74rm26XasI7R8Dp\nbIPT2eZvl9zQJJeCN1SJRNHvuGVqJs4/kVoZU0ZS17HIqAZQLPjkHbf0gIAMjAIjW7pPU20Ccxpn\nYFrdJKzr34AhbTs7h3TxcsyaK615GtIJk1qBsopcSUb9rPH1qEun/HMxxboYUzVzXSL01wFfxtXQ\nDJhG8GfCSSTjZGEp/uKptT6FUpiaLITj66sXbQ95FhNorvGTDGO6IVnOvshMNu/QlzBbwMiWNyGA\n6rG+6l65d0K27ImniVp2wE+kKuboPbKYWA50h1U7EAzZQ1DsONA5GvYW1pVMc0SiD0+aMxCDxa5X\nJm/umvcGqJiOVrsNg3JrTql0atsQJVtVVYVlXfAYKTAvCl3EObAMSU9/1HI8ufVZPLLhn3A9V5RM\n9bA2nUqImI3sNlfCyrcUV7yMB/dgeXMPSEz1PUCK6olqD368O1gFNZ4RLm0ZWTuHQk7DB0efAcDv\nA5ArOPjt31eL/Rw3PKntjV1BFy8PXcge7KGcjd89/1TZsX99diNyBiVEo3VtoFc5J38lNgRr0guw\nJtM6/lK1uqyUUf+/ax8WvyEZL+54Bc9L/elLUXSL5fr/kuXN8asVv8OPX75dhMe4Vfzwhn/ihy//\nP2zxVkKv2waoDhzPw6/+7w08tGRjwAshpH+55e2RMg+kKpP3btrfdrPnbLgGPwcCEXkfYphcOwE/\nOf67mNVwBCyj/PYYnGBUDzFTQ1XKwq1XHI0vf5hm+fK4LQBBvgH3ouwelC1vosHeMBVeVs5WD24v\nrJyPKUnfhWxqBj588nh89zPzMGMctRZiElnlmfEbs6jbvLBsIYobJ4rtinR+njRycvvx9OsUjyXE\nlMfdrSnPCsubeNTytgzNT3aDH1fk1xC3NMwd7y88eDcxUzOggCa1aFKHXMI8Dbyvugw5NGCqscC2\ns0afiqq+2eLvQHvOgHyqi6JXBCEKWup8z0jcNH0vCCccx0Qmb9MsbE+lMe+Cr3QH1wAhCLjNdU3y\nTngl91+aJy7/6hQMxE0NMS3OzmvT5iu6DZe4UPKViO+a5hOo6kiWN53LJ1/y+8bLMW+e8ObsGAli\nGzDa30Rfwbcqq6sly3Dlb2BNfRquPgRNVaEqCoqk3PImqitCSfLcPrrxMVz1xNegN9FSL9cjwoPh\nZYOLq0DCWgi5Q/WEBSqTN/GYfKtr+6ED9jwljARkHf3BnA3Xc7FhYBNMLyUWMLe8eBs2D/ou7oJt\nU8liR0c2xyxCRn6PLN0kyc8SbM9tx5cfvwlPrH8hMNwEy81wulqhQRM6/4E4t+KhR6WJX2rp619a\n5PXYdBHVP1TwNQXYgpiXBZaSmWx5bx7ahqU7yrPOf7Xy9/jNyj+Iv4u2G9B8v/Wl/8Y1T92EXfle\n/HrFH7Az0yme9VIZ1NV9a6HX00UQJ+A1fesD+6iJwYDbnH+XJ2nYc0+G63riPnLIev6lynUyIvKO\nAADQVPYjUspdedzyVlRPuNJrKmJCwjBgeTOrNpDMEaKqxlGZNANxW0teCDA0VfrkbmoGFEVBY7Xv\n9pXJm7+EYqbGXKBKILNazlavTtPjWpK+znhpkhx3hauJId8t7lLL2zTUACnpsWLgHKauBRc2jNhM\nzYSmqXC8oOXNyVtX9fLYqpQ3YKlWYNOpI09ERW6cv13zr5d7PbTa7XijawOKjg14KtobZcvbpN4F\nkDLL2/OISNrjpMFnvQAAIABJREFULwlFt0EcAx4jb3gaFIVlC3P3eWkSIRfaAc1GB2huQNzS/fun\nOYiZmp87UIhTS5yoILaBIjJSzLsI4mpYsqwbdz9MXZM524/r8nixl6mE09kORSEYcCh5X/Ghqaiu\nCU6vmhhCwaRuV8NQy8ibXoODdFJ+PoOWqNEuNVlhiwsu7AJQFz63vF3J8k5qKcQd2kRGUT3YbJ4H\n7SHEtBgaN58HZ/toAIBDbL8Gmn1Hykj4vyvFQ6Ho4q3e9cg5eaSdVnj9dWIMXUO+KtiWHuZKdw0M\nZei4RJMdRhp6y1rE5vwdD+/6LYrI4c8rngxc85Cdpde1fip01RALKNntrbe+BcegnhAufyyseOn3\ns9T5E17dvAFf+q9nxCLHigWJqbQ3Oq9l93rpb3ht34bAdtlb4HgObNfG13/+HP7th/8Sn29l5ZM3\nPPtdvLDzZfxr6xLkmOVdCHNJMw8cH2NNrDqwWUkMBkrF1vZuxuceuwbLu94Qn8ltb6EHr4mrJAI0\ncY4QgnvefACvdvnhCdvx0McSE4frQX8gEJH3IYzm2gROmNWKL15whPgskE0eAqOEcDVVCcgbkuEs\nbwC1lbEAoafiQWICgNYaP0lNjudyWJLb3M821/06TEk0hYu4AEBVih4XsFRLM+Cl97OaYGVWsuWt\ny+QddJsbuio0vGWYqgFNU6h2t7SY4Mk3pmqUkbec9CeTsxibNN+xEMtbjWXxt/7f04Q1TxPudtNQ\nYeq+Z0V0RXNMZPIOtSCY5e1fqA04BmzH893mAGJxlLnNP3bqBEHufFvOy9JnytMRs3RhvampPjqn\njLy9giU8EKQYR8YblFzGRX9O2Hf15XwrLONkaEzZMcR+3EqLWRo8zd83zix/T6X3z9BUOGD3UnwH\n/d5kXAqTlHTVcwerpG1sIWeb8DIVYpz8he95HqAX4RXi+FjHlTAddqzqsg531PJOm0nYDoSHpugV\nxQKAex8qrKT4XXHPx7Iu6vKOF5vhDVXD3joGAPDcm742+NZeupghjo6BQWZpMsubX6Wa3hUoAyxk\ng7+/jJ1hc6RAgyFCF778bT/05vVAMQEvkxZiQaVqfRxLNvtlc+NHVOH8U/zcDzW1C/IP8ub/eQld\ng9TFbO8YgYQex/qSJjy25xPj7ct+gy8/eQN2ubQpztbuTKjVammmkKQNI0au9Oc4dCxyxjtA3d6y\nbsCDq/8BAPjjW38Wn/FjHdcLvEMAQElI5J230Vvow5Nbl+Dny+4Sn+8azIuZiCzvCACoxfzRRRNw\nxBh/tT4hdQTcwSoU3pwdekwpucdMLaiQJJM3CZKZkDdlkBOpONpqq6T9yxcSlaZvRfKXWMyULT//\n/LpE3tzy1lTNJ9mSxUVx3RHwWMIUb0nKY96moQlXOOCX9nDi0jU1IBwiX4OuKtjUOYTfPbpOfM7L\no/RQ8vYXKLy0SoYWIG95MVOSw6C6UIkuLNiKhCnlNDjCCiCOQd3mhJM3LwVzoai0tj5XcFDgbnMA\nlim7zTUoACxDCyTNAUCB5JDQaC5CwtIRN+j86rU70KmsgWKypKd8TMwDKcThEgdEp5nJil70xXDY\ngi3LuscRQtDv7PLbz7Lvz7t0e6ezCTuTVNr0kxMvw1ktF9Ahs/71pqHCUYKVA2JRKB5PAhhFjEx3\n4OjmuXRqpC588iKosHIezGKNyDsAaLKiYhQB2/QXSACgUMvbIx6G7AzSZgq27Sc2Op4jkTf9jsp4\nWlLCo9v6cixbv0jnmeTpwAekBc6OPhZGcA08vIS6yGWyo99B5X7zrx1L50hx8M+XttA+5ZkiuocG\nxBzpii5i5tzw1iq7oShAYeN4EMeEogDZYkHEkkvj/tttP8FLU5WA0Iw1+XmYE18AJ/A1W/rx6rrt\nbJ4NjKrsQHd+FwaKfqWCHMte2fMmrftnv+MbfrEUP/uLb81yFN2i0DRwPSJCFv7AWNWJ68sJK0RF\n7oVF9JqsbMBtvnEbnfNdhV6YE16A3voWc6F71I3O3iFpvQLENpnbnC0M8g5eWeu3C/WIhy2dQ3h+\nVac/3ihhLcJwiGtxFFfNg9dfH7q9lLwtUyuxvOWENXr7501uxPc/O5+2/pMs7/GtJf5MMMuCn1sr\nt7wbEv5Cg1tIPGv5e5+Zh7PmjQ0dK7e8AYiM5dIMYJJP4SOTzg1+oWR5u92tZePxX8R+9q0MQzWk\nemH/+3gs2ND0snri0fX+NRoh5C3PtxJI6C8JA6gudMUQZXQVSRNtKRqX16q6AuSbZZa3r89OAuSe\nLTjCbQ4Ahkl8kvc0aBpLAJOS5gACm+TAeSNmakjr/uIrhz5R1uQWLDEPXoH1ZlcGac9yzRPhEL5Y\n4q1fu3O7YKMgLF5ueXsKJYo3B/3yqE1bbdz10Dq2ncdXM7Br3gKgCNLjC4BYjL2U9SIUBUjoSVw0\n4TwoTgyKbqOphu4vXP/FGEBUmAodP7f+O6uepIsgT4Preb4YDot5d2W74REP9fE62K4nFp02sSWl\nO2Z5mwkpt4Fu4yEEl1Vf8DnK2QUUbRe/fGgVVm1hjXpcXWyX8wb4dRLHFGI7ikoT2VZv7sMdf10B\nWymI+VWhS25ztsBgdegkmxZz2DkwhKLtQW9eC60mqKrW7/o1y6qqBKRhAdAOaZK1nnPZ78s10Jyg\nZX49OT80kA35/cmu+tfWlau6FdyicJtD8VhIyQfPc+GLqEwxA89mioKuCkVzgqV10m9Qq+wRTXgc\nhzDLm97Hj7Z/Bl6mIuClsl0Pv3/CL9+7c/n/4DtLbsMDT/niMZHlHWFYqHu4Y2aJNWwZWlD3N1Aq\nRh/kptoE6qvi0FQ1QO6TO2rF/08fU4svXzjdj8ejNL5OURvzCZ94Kl08MJd5Q3UCR030CVa4iAGk\nErIrmrfwLL/YMXV+TJx386pImNA1Bc7WsdQqIeUxfuIRVFoVpaeDqRmi5EgJqXU3VCMgvfjZD07B\n5R+YIf7Ww8hb2t+TpEcntNUFd9RtqIqOuGR5z2+eCxAFesNmP6Pe1ZHJ2zR2yaw6a+ozfptXx0Au\nHyTvbGIjI2h6nzVNoeTLXtqJlEvdpooHz6afxS0dST2Jwlv0+lzFFpa3V4gL0RauVpbxBkUSk8hl\nYN+/dscu/Owvy/HEm5ScSYaFW3jZGrdwTN+789yrg4J8XOY2R7IXUF2MVuaCFBOB73hsgJYUcosx\nriawoycLt6hDt1xRiREgb0BkxOftIjziIR9jFmM+IfIKAFoOt2pjL/64lNbzt6aaqQyqwlu32uVu\n81hSkD/XyOdVBbativtJPy/gsZe34ull27F5Fy2Rq02m/aQ/j7vNFYgcCFsqeWT3dzBrY8122mKX\ne0Dyeep2J4SAe43VWAbEU0AKcXGN37l7KdZvH4AxgvUzcHTkXliEpNMIBzZ4GZ+mKgErmkP+zdhM\nuY84hig5zEreroxdImID3+1N57A8abDgFkTCGkqSyVJGUhCrIyzvrK+q6OqA7gT7juvhSWe268F1\nCfNuqMgXJE8Zu8ai7QbG8GrXcmjpXvEbEfu8S4jI+z2GcLe5VPIVEvMOWJYSudek/Bfr/KlNmDra\nJ3MAAUEYDpnc4WkBlzk9p2+5VyV88RiZ8MQChBGVSGarS6IuLnkDPA26ptDEKkUBoIAUEtAhuarZ\nS8ojBIs6TkBV/wwa72OQyZk37pARqJsHVZKrkhYBCb085i27zT2pfj5VojKnKABcFQ3VCUwbXYu5\nExtQHauC5VZBiQ8FMuppqZhfh6omhkQmLHENDOVtFGzfbd6XXI7p09l99VToqsK6zrH5GPMMVNZb\nmj8TBduFoijwhmhoxEYOipmHQhTAlmLezPLut/tgxIPkzc/Vn83h+VWd+PsKSt6lljePLTpMMCb/\n+kLs7CmIuHavthHfff5HUHU6B5br51qIlynJQzFzIt5tKQlk8g6Ia8BVCjBYtUYpefNcjbybD1iD\nzrYxrByPC9FQ8n19K81gbk01U8ubu80JJe+EpQuXdtKK+d4PI0jeXPdAdPDzCsiKen3672lzx4rf\nnS2XWqksROKYoGI7qng+sgUnQJwAkMl6ICBwPMePeceyIIUEAFXyDri49wnfclR0ByAqFM8MzLWm\nKqHlcmD3UYkPQq3sogtqT4NC6Dhkb1fGLre8lVjWJz+jnFjzbkGUivE5cnc14gOpy2CqlvjMcT04\nnoO8mxeqisTVoaiOmGN6fSXfwd3ujkcXAJoNuAZts+zySgJ/n7LjpTkChkmqO0CIyPsww5566IZa\n3oGYd3mdNycbz/MCCWvyQqBUfrB0eygUglhJrbpsrVt6+PG8QQh/iR49tQkfP20irr5oBkzNpCtu\nNv50wixrSiCTN38RVqUs6KqO6sKkQHtUUzVEOdCYmjZ888jrMK5q9LDXaGhqoJZ9XHMdPnDkCNx0\n2Vz/slUFhVVzoearsKDZF27hgh4yKhMJ6JqKL104HfOnUq+CiQR9YRh+0h0tFSOB8h7enAWOgf6h\nYHY9APSxzm10kaMGLG8A0OtYwhT7bOG0ZurZYQRAyTEPnSQAKNA1BQumNWFEFQ3Z7Mr3wUqweHKJ\n5a0YBVimImWrJ6Brip/YptlQFF+qlQghGVVoxW8Z2ibmwHN8qV23zw8ZKUZRWN4WErSch32HbrJY\nrsVkMJnHIGXRf9/Y0oVfPkL7NDudbSDFOAtNcPJmI0pQi7Ml2URj3sxtTsnbRdzSka70UBmjrV1L\nyZ+7r4u8RxD7DRb1XRhwugHFhcZEgyqsBGrScRAiuc0VOW4vJe0x0ti4Y1C4r0lJ7kHRsyl560W6\nwGChB+Fh01ykkiE04AYXWSqzvHUvjuKGSbC3jaLbmSWqN9MFDsnR3vSKS8chk3e2pH4bAPSGLYjN\noNnmhuWTYFKpggIFBafot2FlY/EKcShOnC7CJLc5j6kHLG/NoUTMUGrdK6oHKJS4HZewcj0ahvIX\nOPQ7MnlbkLfT3QJnRzubA7o9bmmR5R1heBT3QN56iaVoGcGENeLJ2xXpv1wmskSqkyGsLWCY5Q0A\nY6voD5vYFsa0BF3VMtEamoFvf+JI3HrF0SXnpeTI5V1jpoZjp7eIuHiVVcmuRRMNW2QYEnlfcfZ0\nfP7caRjTSo/RNZW9YPh1aSJO1VSTQGOqBhWmbJkH57M0/p00E/jwiePQ3ugfo6kKvMFaJDedgOqY\n/7nc7IGjpabc2rcU+oJVOem4GnIFF7miCy3mZ1VziyWQxyBZ+l051vCFuc0NPRgWEYlVnobPnD0F\nNRUxen+IBuJqKHg5KLotXsS6ruITZ0zGNefT+9WT3wU9zsnbEucCAK1yF5qmv+W77l0dLbVJsVDQ\n67Yj1rLJF3MJkZYFAM9gCnBF+twkYjq83iZUDbIKDL0o9NKTajWyeceP+zJLTjGZNeZpuOXf5iPN\nyHv11h68vtFPsgJo8ppX4vZWrBwsNYZ7/rYJBP4C1CU0YU2zCsg4Q+iobGGeJhWEKFA1Rt6eDVM1\nYLOsZrHAqejBC7gPWsNmaJXsGow4aiuo0EvOLhey4fFuIkkFr9vZi9gUKpzCFy5y3NzziEgMEwtX\nISTjoChJ2Y6qYhnlJeENVSXoLw4grqThdnb4izUtaBUX3jgyMA5ZMjUjZYIHcmMAQPEEeRc3TMIM\n71xYmomiWxAxb1GD7eooFF0YqinKHh2X+Cp2IrFRh6J5yBSkeWTkrQ42wB1gZWWqC9vxqKdDs0Fc\ng3lwuOVNv38wa4v5cLvaUB1jybvsGY9behTzjjA89mR5lwovWKaGie30ITthVmvoS5KngHgEQQlR\nibiUEPIu7fTEccX0T6Bi2wkgmUqcefTIYcdqqDra6lOoqYiVfS6j1Hr33dYkKNTBj5dUz2pSScwc\n71tqhqYG6n0BX7iBZ30nDD9cUGZ5l2Sex8JKxbgbnhDELR1nzO/AvMmNWDituWxfSyuPmXPyVrhl\nzd2sRRfElPpCM3KXQx2Vdf6LWGQrexp0VaVubzlhx2KuVtfXntfZfSaOgSFniMqzshc5d5vH9TgS\nehy7cr2iJI+/zD+4YIw4f6eyRri947qJz35oqug0BwBoXYmCW4ACJZDMJ5frOBq1evM5Rt5snKpD\nCXj2tCTMup3wCnGkvUYqYcleuq/iL4DqUPJm46urjPueE83P6OcvfM8jovWqxvu6aw5yWQVLWJtM\ngy1aHTh0MZ2gNdod6XamSgfAU6FoLJud2DA1U2RNBxfQ/iKNz21N2gI8Ddtz23HHsrsgMvqlccLT\nxMKoM+tnO8ulcIBP3rwlqli4Sm5z3oa3Wm3CN46/KuCh4ffC1bLwiIdxDS04YVYrxrXUse22P5eA\nOG57Fx1vgLyZZXzVjE/jqhmfDswBvRd8gRLDUMaFpZk0YU3EvNn8cfJWTJFQ5rgeduWpp8kX86H/\nDhX8MSi6DS+TRmbVLH8Bwo7f5qyHogDeUGXAbS5yC3K2mA/iGJg7voWek2kimHpkeUfYDfbU67fU\nhZyMGWitT+G2Lx6DxaeMLy9XAkRrPyrmIFnGEonK33v6yJNRH68NxH5lWJqJq886ATd8bI7I+JXB\nm6+UeglKt/NyH8sILjgqmeWt6DbSyZBac0XOXA9+h6YpActbBpf7TOp+LH5PlndY9yQeo+eqcecd\nNwafPnsKmmuTqDQqA/uGldvFWekWfzEeN82vr1WkF7/O483SgqxdnVZ2PuL6lrdcIy7I3/W158e3\nV2H+lEY0pqtEaRBPaNOlhUtNrBo9+V7U1zOyZy/CU+YEdeCJ4oB4KmaOa0RTTQLfuuyowPa8W4Cl\nmaiuKF8EAUBRoyV7OUbeosENK9dzk53wFAfurkbs7M0hm7eF29RGHlrtdroAkcSBYixPQdHcMnf0\n8nW70DfAXMUa70jmBBZIwuOkuFTVLk5Jo6OizV9oen5M2iU2DCk8M5yXAQDqE3WoqYgJ1/1rXcuR\nJf2wJlBJUd/y1oU7l3sv7O0j4Q0wi5aR8zPbn4dLiBAb4QtX/syYY19FhtDx1+ktSFspGLoKj7e5\nZZamrdHjm5J1+OiiCWirpgaBYmVhjnsZWsUudk56n555lWaqb+rpFdfG3ea5IR1JI/heUHQbnsEs\nZ9tAf6YIS7NQCIl5wzGQLzq+qJLqoug6uH/NX+k1MhU7fo0i1q7QOm6xuJcaBdmOh60OFW5xu1tD\nLe+hrB2o8EjH6Hvig8eOwHc/PQ+WoUUx7wjDY97kJswaX4+vLZ4Vup1nVNdZ9Zg2uhZnL6Qu7GSM\nJWaFkTezvUvbBcqiJvKmM0Yvwk3zrw0mp5WgtjKGUc3h5M7JebiYOd8u9MdLkt7SJn0BKbqDdLyc\nvGXLu1RIxtBUv+SoBJwYZMtbLyFXTmAfnXQhxlaNwsiKkqYlAM6Y34HT5rXjU2dNLtv2hWlXBWr0\nwzL2E5rk1lc0jGzyCb+m+xhfI54n+kj3tEFvxw+O/ffgCT1VinlL99RgbnfPz3jXVBWfOmsK6lL+\nvXOKLAFLWrjUxqphezZ67R7qvuS92y3//GkjhYq0CsXTce6xNI9A04KLy135PliaiY+dOhFnzO8o\nmwvCwgCdPVRHnN8jTmI7MszqtC1s685Qy1tWA+WhBdvCSbOobn/CYG1Nx7wu+otzwn9pdZcIJ2ga\nK8nTnMACSSgPcnI26QJjRLoVlsmS+jwNHmhfe9uz0dMnJToNoxxYvXURLM2kSoeyVCk2+vMhW95C\n598RcyD2Y8f/c9OT6NPXQIlTD0ap5a1oHvQxNJuee5Fk8halWCo9vi5OibE6Sc9jtKyHVs3ugfQc\ncm8N11bo7suhM0sJ/Sf3vFn221fTPSC1G0CKFrxsBdZvH0BPn4OcUxAiLVzbgdgx5G0XGgwxxgIZ\nQme2G9PrpooWvdzyzhTZ74Qt1OIaj/v7mgeO6yGDXpCiBZJL0wx1fs/ZImkoZ0uuewOVTGggFgcq\nUxZMQ0XRdvdoYO0vhJs+EQ5ZWKaGK88tt644UkYS35p/HSrMVHhMOqS1J3/Zlbb+k6340pZ77wS+\n5R1O/pogb/riLiXvuJThHeY214gpvqd0gUHJV8XE+Cx0NFQFtnELf3duc+5Wntc8B/Oa54SOP27p\nuOD4saHbUrFYwPIPu0dJPQk4/vbqtH+9llcJe/0UWJOfB1GZFSBZhZahIaZbSOoJP8bo6dBUBTFL\ng9vTgqJuw+zw5SHh6mVd32TLyCkyYpeItyZO44V9hX5U6JXg7RsURUFh9UxY41+B7dlIWAZqrKQI\njWglnouck0M6UYdpo2sxbXQtHlqyEYWVR6FuyhoMEhazJ4Bjq+hoTQjPByfvXqaRbqoxbO/JImbq\nouc44MtbLpzSjounUNnadExanNWzpL1Aq1z6Ha45CKj1rCpAmmON11mzeD57oSeNhL/wJCo8uEhY\nOlzVLbG2gwsYTt5xRp7phBH4neaJn+XtDXBi8suYeLlVIJ9B+v+COgjVGoKXjwsPguyB4z9z/rsy\nNBW9/QRWo+/9Kap0DPUsVl2TKM/VCHw/I/8iyeP1tT348f8+g9gR6+EO1ACuIQiZQ6vugqIAxS3j\nAU9HvujCLChQzaLoC6ym+kEIdWsXii7i4HF5FzZT4RvoKxeEGsxnAZjQ0rS6okKvQhcQ0DywHQ8O\n8ZUCs3lb/K54eCJrboNVxWrfPRWVcbqIzjssVBUbgDbiDeSKp5XNzYFAZHm/B1EXrxk2mexblx9V\n9hmnZbIbgi61yt8JRFx+mFOqSnncXkZCcmvL7U55rJlnm4dZtfzlf3TNSTh7zKmh35PYjdv8ncLU\ntYALN8z7kNb9a7JUU7jhAZoMN29C0NqX3eomW4BUxZi1ThSAKNBUBcmYgfqqONydI6ES2UrSAhYz\nAD+jHxAvYpng5Xr+ilgSs8bXi0XloglzUUmakXcLyDm5gJiPripCHpQjVhL394aqMdM8xf/A1QEo\naK5Jipp8txict7pUBTp7cxjMFuF2t+KktuMB+PKWDekKUXWRNMt/G7LkbYJtH4qvg9FOFzky2QkJ\nYJ6Mp9iIaVbwuXV05L0cYpVDVMSmpLJDBifvGHvuUnED8o8jD2r1FlbPFORbEaf7ajXbpQ575RoO\nAPWsKUYRhCWr3fjxuThnQfnikqvrmYYq7jl3mxdUujyrZ5Z3Q0U5eZeqNxJPRc7JY9naHiEA43bS\nZ3XZup7gHHDPgOwV83QoCi+1I1CT/SC5FFRCyZ9b3takpbAVapWv3iDVkrt8AcF6rTdSQZZRFvOI\nSZZ10XHhoCg8BnLuBPds8GeBjRhxg3fQo+SdSa6F0bwBm3qD7UoPFCLyfp9hREMKN867BjcvuEF8\nxt08Lvt3XP+5uOGoqwPH7U/y5pa1S8KTO9QSy7s05t2UpOpN9bHaQO05fzlzyzuMGPnLP6xjG/+e\nZIjlffHJ4zB+RFVACW5fQL9fETHN/kJ5b2ceFgAASw+St6oquPC4oDu+QrIkeQ9s/pKloQefMK5f\nPBvzpjQibUrk7Oplcyxn3E8f2YxPnDEJx83wBXZ4xj9A5+vKc6dhFksMvPCEsRjZQL8/5+QDSXma\npsLZOg7FdVP9awxJ2otp/vg4cTbX+Za36ypI6v51N1VVwiME67cPQFVUnNi+AIBfTiff0zE1I/zs\neAauugbQOefQG1g3L4mYDFWjiyJmeXuqLeLoHPa2MSDwUGxaRj/gmvNmubdJMQsgRKEd5cDIW/N/\nG0VlqGwMHfXU82GOXAU11ReYJ/n7AKDIiJfY1LXb0ZTG+NagZgPgaxYQ4ru9eYJWERkYqiEWdfXp\n8pBYaSIeHAMFL49ETIfCeoDzkM+ytUHyVpmSn+w14Za8Yuap7oHmwstUwjI1arm7vuVcTG1mx0jN\nhQIxawIt3Y8RqTbUWLWB8Sqai6ydp78VtmjpzxQDx1tSCaCzg4Z3+D1/bPNTeHzz037mPQn3KO5v\nROT9PkRDog6VVvnKmbvGLSWFpmRDYFtIXtY+QxXkHX5SlcfaWcy71FoZXdmBzx7xcXxp9hXB47ik\nNCPv0pp3gFs1VIq0FMJtHmJ5nzJnBK67ZFawZn4foCgKrr5oBmbWURWz0sQdAEiYMVHrbGkmkjFd\nkLKmBUkLAJKmFONn19DMFjgcfOlVmbLw6bOmoDImddjytLJEx3qplOeoiW1YMK1ZzB2AwPOTCLsG\naQ5ly5vfS/klHUbeimv6nhNO3jVJcbzreUhLY2itoZ6ATN5BIqajwkoHqiHkMVZYaehvLkL+9YXi\nM/k+xIzdh5scj8BQLKiJQSjxAXhKOXl7fY1I6Wm4Zl/g+NLKCQFXF2JKybgRVC5TWaxXImdu9QF+\nXH+4RLiiTscwsq4ON1xKQz1hiZIJk96zwWxRsjqZ2xw5VJgp8ZzIz5x8DYCfsEk8DXllANvct0SN\nNo9Db9hRrtYG0JJD/qyLOTviaRH+IPkkYqaGfNHBUM6fI0/3NQ9Kx2N0rGJzRFATr0KMe5mE5e0K\nsR5O/rmCFPPWHDom1QOxTdibJtFxSc/tfW89KMJYils+twcCEXm/j/GlC6ejrT6JY46gJQ+cvGWC\nGsvqoxtq4uUn2EfwOHRYpjbgW+Ya++2kQmq5p9VNLluAcLe5RuiPKszyPml2G7568UyMaPDJ6+On\nTUR7QwpjWqk1EbS8939ayOSRNbjsiPNxycTzsajjxLLtluRaNzUqQiMatygKNFULvDiC5E3nroy8\nSxwnSSNoeZeiPu6Tt0zEHBVSA5rQBUiAvP2xcs+HHDoI08h3XCI8LPwl3taQFDK6iZiBasn676j3\n3fiJmA5VUZHQ/WssXfAYugqST6Hw5mwU101FXaU/3ngIecvEWSi6mKgfDUVzoTdtgIci4iELkLRR\nCaIEO7uVhif880slmpoaIG9P9ZOkOJRA1QDvXS/FsaUua45OiW/e+A7RwS6szDNlMNlbqVaeyt8S\n2MghLXljShd79BqY5rzJa8jpta8k//TVANl5t3aXS6USxwCIhoZqdi8kmeOWDno9Zx45HjGTajP0\nZf2ySRKKQnq/AAAbb0lEQVQbCJxfHo+i29CbN7BrTIjx8Xtijl6OdblV9CBXFzkAckKbodN7Egif\nlNxzy6I/suaaYEXJgUJE3u9jTBtdi29/4ihhhXLXuKym9qULp+Nri2dhTMv+eyC55T0cefPt6YSO\nb19+JCrfpqv6iDHUHdZWzVyKIdZFzNQxqaM68PI5dnoLbrr8SBh6iOUdco79AV3VcXTLkaFjNAzV\n713NSJeTN19Y8aoCUzUQtyTVOmF5S33R6ZkCf6UMyfIOJW/frZowysm7UnqRlxJj6THyNXLrkkus\nAggo1vE6esfxfCEPRkS1FTGcMa8DC6Y14XPnTEW15Sccjm32a/m5cE9ausbSaxC1+P31cLvbxLMD\nhJO3vMAp2C7aY+MB0HI7opAyyxsAKqQmL9yKa2sIL1NEqbWmlv825PvUW/Sbhvglf/52N6RxkRyO\naU+34YNjTgsQZMqS480avEwaWsUuaPVbQBQvcDwAjKroQEqpFp2+xjbV4rR57SK0YW8eL/bV0n1C\nOnU48FBGQzUTKUr6IaW8QRu3jKqvQ8LSkc07yG0ZAQyy6xS96/05mDLC9x7yTPWkkfS9H9K+y7LP\n0jE4BkawBQ4kt7mha9QL4egY1ZzGly+cXrbodPUcTM1EOvHOQmtvFxF5RxDglrfspo5bOsa1VQ13\nyD5B3UPMW1jm8IZ/2YXg46dNxFXnHYGFU6hs4R7lW4eBoRnCZbuv53gnMHVNvMy5vCTPOOfZ2jUx\npg6lKIhJ8WruNi9VsCq1vNvSkmBMyAtVJtQwy1te1ISRe3wYy1tkrEtWolySyMvRbMcT18jjoYqi\nIBEz8IkzJqO5NonqmL+gbK3zn9EjxtJrr5LIPVGywCgNxUwf689XmEuYuJqwVfNFF0nTAnE1EVMP\nI++URN7cyjv32NE4//gxZfsSV99ziZFENos6ThL/z/UQZOudZCqRe/HkQGy/QiJfRVGwqOMEJGzf\nQ5OWyRsKiutoAqLesLnseAC4es7nME+7UCwARjdX4YLjx4pyPrenFflXj5PGT3uNDwfujeGWt73N\nn6ch0Bh52kxhVEsFXI+gp9dD88CxwXOwRe8lp4zHF844DpW5CQAANUkt85SRRJznHYQtJFzd98rx\nksGqbrg1a2jioWvgklMmYOro2jLvQ09uV6gH5kAhIu8IAsfPpAlJs8aHtxvdXxhTORJAmHVIMb1u\nCh1P24K9Om/M1DFjXB3SZhKWZgaSrvYWPEZ6INzme4IpWd5claqmgvc7py8MTmxFtwhLiqNazHug\nqzqumvFptPfRspVSWhhd6ddUf2NxeQWCjHgIecsItbyHiXnLXp20Qq1dXv8L+Ja37bpoSNDnkBBg\nzsRgDgYAVPMFDILiOVzJri3V4o+x1PKWyPu4GS1oqfWvIWUl8K3510IfkhY4ro6PnEItyUVzR9De\n6LYpuqrFJaW9tnrqrm9IBpvoANSDcvq8Dnz9yC/jxBHHBM4vo3YoqONAXE2QCQBMrh+Lr5TkfJSF\nPzxdlNQBCP09aNLzLSc+AsBXPngs4GqC+NIhxxMoYlw8h+Wy0yYKWWRSjAnLXHZpA0DlllNwztgz\n/HOxPIiKBPME9jYF8hIA6k2Z2O7f95baCmiuf2/5d+SLDlRFRYfH+ruzkreUZHmHClY5BkYIqWMF\nbj99RovJbfQjVy+rfvHnItwDc6AQ1XlHEDh9XgcWTmt+227qfcWF4z+ICdVjMatxeuj2SbXj8b2F\n3wyWK+0FNFXDNXM+H3AN7y0Sehx9hf6DYnkbuubXm7JabRHzZqRTKxGX/DIxpSz6CTVjYbkZAD1l\n7D0i5WeOj24O96x8ceZnsHFwS5m7tBRhZYmyNR6WkAYA7eZErCg8g+aUb/0J8nY8zG2cgXV9GxDL\njMTZx00qO77GCo77psvmYiBbFHM1qmoEsCV8DHweZ42vx8dOnRgcu6WjLl6LOSNH47lupn3u6Zg/\npQknzaZCL6+s7mJSpdQzInsqvrZ4Nrr6cuhVN4nvT5oWvvrJo0TYoCXVhIUtR+GxzU/R87s6IHHC\nCGUatrxYg/icf9APSohGVYKJi6RE2lh87vj3Jox8PdZ61elpQoKFX266bC5Wb+7D5JE1UFdUwovv\nYseXPwfHTm/GP5dyOWBK3o01CXz90jlYuWEX/vPe10DsGHXtl1xDtVGHk9tnYkzlKNz94t+xsZMu\ndqaMqsH0TbUY0ZjCX5esByGK8C6kzBTGj/AXXifPbsPmt6rQyYVYuDgMlzw2LZCiJRZZKTOJmDK8\n5U1cHUdPbcIf/klbpBbfnIPYnL/BNQbE9pgxvOt/uGf9QCCyvCMIKIpywIkboC7Go5pn79aqTUuZ\nrfuCpmQjUua+kT9A3c5pI1VWc/5uwNJV1roRqGSJYaUxb9EUAcGyt9KSLz6DpIS9Dc3AhOqxGF9V\n7sLlGFc9Bie3Hzfsdo6w8IdcBx6WkAYAU5Nz8NkjPo6zR/v19kdNpkQ+sb0auqrjkknn47w5c0Q+\nggzZbQ4A7Y1pTB3lx65HVUqysiXPkio66ZW7qrkVP76+zf/Q1f0sZQCmqQUy5mXhoLilo70xHcgb\nOG56G1rqgs9jXPYGuDpOnOV/XzpBVdZ4YlmYlRhIOvR8adLW+iTOOYY1B5LqpsPCHzWDM+H2NqCt\nOE/McXtjGiczmVut6Lv+wyz3uso4KhL02r2S52DyyBpceMJYv1lKyTV091PCHVXZjqnG8aLne1NN\nAl+4YDqV2iWqyI8wVB2WZiIVN/CJMybhyxdOR3tjGpMa/C6AHz91Mlrrk2KR1TdUCCRHUstbCx0P\nAMAxkIobuP6js3Hy7DYACohtwVN5A52g5X3JxPOFpxAID58cKESWd4QIIfjY5ItQcIvvaAGxr9B1\nFW5nG2yjgCvPOx8AUJP21a+AYDKWXH5klpK3SJ0t/56rZn66/MO9wHnjzsIDax7CxJpxZdsaE37o\nZTjytkwd0+qCNevnHDMacyc2BKoBhgOPaZdm1nPwpD7e31lGW30SmzuHUFc1/MtWvi7iagGXv6Vr\nAZd02Eu7OdmI5mQjtmd2oqmy3LshW84nzhyB8SP8fToa0wAUmEocBZIVVutpR7ULyeOEEYcChS7M\nJCI6YWYrTpzVhhNnt+Hz/1WEmuqDSozQZ/mCo+bi0aWNWPyhCaFzoBerwIVd08N4ssZVjcZLna/5\n1QESqtIWyFZekkUt81HNVP50guT+lhdn3PuSZImH3kAN1FgWtudn4C+QmvzMbZyJf215BgBNPj12\nuh8uaapJYHl3HGqKJr8ljSRivN+BY+L85o/h3rcegJryLWuAVtmMba3EkZMace+W5diapS4U4hgB\nsaKjW47ElNqJeK2b9q1/N2PeEXlHiBACUzOHVak70DCYhKuzdRzqE9R6a29M4ZxjRokOaTweXB+v\nDVjbZoj4DDCsmN07wokjjgnGbSXIRCFaNZYgbKyqqgTaq+4Opmbg5gXfGHZxAAC5l06iL9sPBD//\n6AcmoK0hxayr0rHTfyulxjtfuXB2YB/TUAPkHQ/pLqcoCq6Z83m81rUC0+unlm3XVA03zbsWf1n3\nMOY1B88/aSQlNiPXgEJsA5QEJRdNU8X9VhUVcT2GrJMrkTv1O7DpJIbCiqNRlQ4nlTEtlbjinOHl\nluOZDgwZW2FU9pZpP3AsnnQBptVNxuyQMFgypkvtR6llfvyMFpxz7KhABYvrlmfX88WS21/ni+WE\nYGTFCDQmGtCSKs+h+dAxo5B5eRJeGKKqZykjAVPKjxhd3Q4vlxbkXZr1P7atEvW91YK8DcUs03pI\nmynoqg7HcyLLO0KE9zOSMQOXfmBCwPpUFAVnLRgl/q6NV+PauVehxqoWtdMAy1QPwbvUKyGARR0n\n4G8bH0d7upwggXIvwb6gcpjOdhy3f+kUhDlPYqaO044qb4RSio9Nvggv7HwFExpaA5/HTC0QTx7u\npW1qJuY2zRz2/PWJWnxy6uKyz6tSFlrrkujcUgN97Aa/R3XJjeTiIh111Th/8Ww8+vwmHD2Fkpii\nKHBcD4CCUU27n6fhoCkGiqtnY+KY6mFzH3Z3jRPaq1D7Vgp96BYKZJapBcIbANA7SGPSFSG9Cnij\nkXFVo8u2AfQ6bzjqK6GeBUPXcMmcU/DCE4/T79aswH6WqQXaKB83pTyMJHdPTJrhXRJrrCp05roD\nuQ8HGhF5R4hwCIJn/u8OYaSol3Tt8t9T7z57nz36VBzVNCvUnQr4mfEHEqX913eHay6eiQeeWheY\n+yObZuHIpvIOfg3VCZw4bQyeHqB61/EDYHGdd/wY/PTPWRQ3TII3SC3x0kXYpJrxWLVrNU4ffRLG\n1ldibFu4FT1hxL6Ve/L5U/YxPUpTVZw//Rj8YvkGuF10XsMWmJUpujiZPKqmbBscE1+c/CWMqKsu\n38awu/CWoer4ztHXI+fky/bTVQWktxVesg+jnWNw6YfKEyPlBWKlVU7eAK3+6Mx1v6sJaxF5R4jw\nHkLpy4n/fRAMbyiKMixxA8O7+A8WJnZU42sds/e8I8MHpx8JrO2GSzyMZuWP+xMzxtZhbGsVVm30\nPQSliYeXT/kICm4xkMAYhpHDtOfdEy47bRLufvRNXHRSeV7D28XMhmn45lFfxdeefx1AeF+Bs44e\nicqkhWOOaA58fvnpk/Dy6i6MaWh6R9LE1bEqhFG/rqtQMrUoLF+A6qnhYYFKSU2wKhHufeClm2Hh\nkwOFiLwjRHg/4GCw9x5Qmhl/uCFhxHHxxPMO6Hc01iSwamOv+LvU8k4YiVBteY5vXDoHb27uxbi2\nfVNIbKlL4tpLyj0Pe4vGZD14NnxYuMTQNZEhLmPhEc1YWELo+xO6poqSjPgwuvNT6yahVZuADVuK\n+MAJ4eWtnLwjt3mECBH2Cj/43AK4XnnSz26SzQ869kfM+72OxupgedecCeHW4XAY3VKB0S37ZnUf\nKBxKizZFAbhBP5znPWkkcN2xl9Me4lY4ZY6pot6RltSBW2iUIiLvCBHeA6geJpt4yqgavPRmF2aN\nrwvdfjBxqLnND0U0VvtW9e1XH79XMfxDFeYwCmXvJj555iS8sbGPlX3tObSkKsqwxA0A46vH4gfH\n/ntkeUeIEGH/4NjpLRjVVPG26qbfbZjvASI60JgyqhqTOqpx9NSm9wRxA76lezBx9NRmHD2VWsn7\nK6fz3SRuICLvCBHe01AVBR1N+67xfiDwqbMmY/32gVDVtAhBGLqGr148fKnZ4YTT53Xg2eXbUfUu\nqDjuDfzQ0qEYXBoeB3Qpd/PNN+PDH/4wLrroIrz++uuBbc8++yzOP/98fPjDH8Z///d/H8hhRIgQ\n4RDC/ClN+MjJ4/e8Y4T3FM4/fgx+eOXCQBOZQwEXn0wz6WVltsMBB8zyfv7557Fx40bcc889WLt2\nLa6//nrcc889Yvt3vvMd3HnnnWhsbMTixYvxgQ98AGPHjj1Qw4kQIUKECBHKILvQDyccsCXQkiVL\ncPLJJwMAxowZg/7+fgwNDQEANm/ejMrKSjQ3N0NVVRx33HFYsmTJgRpKhAgRIkSI8J7CAbO8u7u7\nMWWK322lpqYGXV1dSKVS6OrqQk1NTWDb5s2bd3u+6uoE9P0cI6uvP7RigYcronl854jm8J0jmsP9\ng2ge3znejTl81xLWSjV59xa9vdn9NBKK+vo0uroG9+s534+I5vGdI5rDd45oDvcPonl859jfczjc\nQuCAuc0bGhrQ3d0t/u7s7ER9fX3otp07d6KhYe/EByJEiBAhQoT3Kw4YeS9YsACPPvooAGDFihVo\naGhAKkVrTdva2jA0NIQtW7bAcRw8/vjjWLBgwYEaSoQIESJEiPCewgFzm8+aNQtTpkzBRRddBEVR\ncOONN+L+++9HOp3GKaecgptuuglf+cpXAACnn346Ro0atYczRogQIUKECBEAQCHvNBj9LmF/x2Gi\n2M7+QTSP7xzRHL5zRHO4fxDN4zvHYR/zjhAhQoQIESIcGETkHSFChAgRIhxmiMg7QoQIESJEOMwQ\nkXeECBEiRIhwmCEi7wgRIkSIEOEww2GTbR4hQoQIESJEoIgs7wgRIkSIEOEwQ0TeESJEiBAhwmGG\niLwjRIgQIUKEwwwReUeIECFChAiHGSLyjhAhQoQIEQ4zROQdIUKECBEiHGY4YF3FDmXcfPPNeO21\n16AoCq6//nocccQRB3tIhzRWr16NK664Ah//+MexePFibN++Hddccw1c10V9fT3+4z/+A6Zp4sEH\nH8RvfvMbqKqKCy+8EBdccMHBHvohg1tuuQUvvfQSHMfBZz7zGUybNi2aw71ALpfDddddh56eHhQK\nBVxxxRWYOHFiNIf7iHw+jzPPPBNXXHEF5s+fH83jXmDp0qX4whe+gHHjxgEAxo8fj09+8pPv/hyS\n9xmWLl1KPv3pTxNCCFmzZg258MILD/KIDm1kMhmyePFi8o1vfIPcfffdhBBCrrvuOvJ///d/hBBC\nfvCDH5Df/va3JJPJkEWLFpGBgQGSy+XIGWecQXp7ew/m0A8ZLFmyhHzyk58khBCya9cuctxxx0Vz\nuJd46KGHyB133EEIIWTLli1k0aJF0Ry+A/zwhz8k5557LvnTn/4UzeNe4rnnniOf//znA58djDl8\n37nNlyxZgpNPPhkAMGbMGPT392NoaOggj+rQhWma+PnPf46Ghgbx2dKlS3HSSScBAE444QQsWbIE\nr732GqZNm4Z0Oo1YLIZZs2bh5ZdfPljDPqQwd+5c/PjHPwYAVFRUIJfLRXO4lzj99NPxqU99CgCw\nfft2NDY2RnO4j1i7di3WrFmD448/HkD0e94fOBhz+L4j7+7ublRXV4u/a2pq0NXVdRBHdGhD13XE\nYrHAZ7lcDqZpAgBqa2vR1dWF7u5u1NTUiH2iefWhaRoSiQQA4L777sOxxx4bzeE+4qKLLsLVV1+N\n66+/PprDfcT3v/99XHfddeLvaB73HmvWrMFnP/tZXHzxxXjmmWcOyhy+L2PeMkikDvuOMNz8RfNa\njn/84x+477778Mtf/hKLFi0Sn0dz+Pbxhz/8AatWrcJXv/rVwPxEc/j28Oc//xkzZszAiBEjQrdH\n87hnjBw5EldeeSVOO+00bN68GZdeeilc1xXb3605fN+Rd0NDA7q7u8XfnZ2dqK+vP4gjOvyQSCSQ\nz+cRi8Wwc+dONDQ0hM7rjBkzDuIoDy089dRT+NnPfoZf/OIXSKfT0RzuJZYvX47a2lo0Nzdj0qRJ\ncF0XyWQymsO9xBNPPIHNmzfjiSeewI4dO2CaZvQs7iUaGxtx+umnAwDa29tRV1eHZcuWvetz+L5z\nmy9YsACPPvooAGDFihVoaGhAKpU6yKM6vHD00UeLOfzb3/6GY445BtOnT8eyZcswMDCATCaDl19+\nGXPmzDnIIz00MDg4iFtuuQW33347qqqqAERzuLd48cUX8ctf/hIADX1ls9loDvcBP/rRj/CnP/0J\n9957Ly644AJcccUV0TzuJR588EHceeedAICuri709PTg3HPPfdfn8H3ZVezWW2/Fiy++CEVRcOON\nN2LixIkHe0iHLJYvX47vf//72Lp1K3RdR2NjI2699VZcd911KBQKaGlpwXe/+10YhoFHHnkEd955\nJxRFweLFi3H22Wcf7OEfErjnnntw2223YdSoUeKz733ve/jGN74RzeHbRD6fx9e//nVs374d+Xwe\nV155JaZOnYprr702msN9xG233YbW1lYsXLgwmse9wNDQEK6++moMDAzAtm1ceeWVmDRp0rs+h+9L\n8o4QIUKECBEOZ7zv3OYRIkSIECHC4Y6IvCNEiBAhQoTDDBF5R4gQIUKECIcZIvKOECFChAgRDjNE\n5B0hQoQIESIcZnjfibREiHC44ZZbbsGyZctQKBSwcuVKzJw5EwBw3nnn4UMf+tDbOscdd9yB8ePH\nCz3rMHz0ox/Fr3/9a2iatj+GHcDOnTuxbt06zJ8/f7+fO0KE9yOiUrEIEQ4TbNmyBR/5yEfw5JNP\nHuyh7DUefPBBrF27Fl/60pcO9lAiRHhPILK8I0Q4jHHbbbdhy5Yt2LZtG6699lrk83nceuutME0T\n+XweN954I6ZMmYLrrrsOs2fPxvz58/Fv//ZvWLhwIV5//XVkMhncfvvtaGxsxIQJE7BixQr89Kc/\nRV9fH3bs2IGNGzfiqKOOwg033IBCoYBrr70WW7duRVNTEzRNw4IFCwI9ijOZDL7yla9gYGAAjuPg\nhBNOwJlnnokf/ehHIISgqqoKl1xyCb797W9j48aNyGQyOPPMM3H55Zfj/vvvx9///ncoioKdO3di\n9OjRuPnmm2EYxkGc4QgRDk1EMe8IEQ5zbNmyBXfddRemTp2Kvr4+3HTTTbjrrrtw6aWX4vbbby/b\nf+3atTj33HPx29/+FpMmTcLDDz9cts/KlSvxk5/8BPfddx/uv/9+9Pf348EHH4TjOPjjH/+Ib37z\nm3jmmWfKjnv22WfhOA5+97vf4Q9/+AMSiQRaW1txzjnn4Oyzz8Zll12Gu+66Cw0NDbj77rvxxz/+\nEQ899BDeeOMNAMCyZctw66234r777sO2bdsOSy9DhAjvBiLLO0KEwxzTp0+HoigAgLq6Otxyyy0o\nFAoYHBxEZWVl2f7V1dUYN24cAKClpQV9fX1l+8yePRuapkHTNFRXV6O/vx+rVq3CkUceCQCor6/H\n7Nmzy46bNWsWfvKTn+ALX/gCjjvuOFxwwQVQ1aCNsHTpUuzYsQMvvPACAKBYLGLTpk3ieN4+debM\nmVi7dq3okxwhQgQfEXlHiHCYQ3YrX3PNNfjWt76F+fPn4/HHHxfNPGSUJqSFpb2E7eN5XoCIS0kZ\noL2M//KXv+CVV17BP//5T5x33nl44IEHAvuYponPfe5zOPXUUwOf33///fA8b7fjihAhAkXkNo8Q\n4T2E7u5ujBs3Dq7r4pFHHkGxWNxv5x49ejReeeUVAEBPTw9eeun/t3eHOAoDYRTHHyGYJlwAMAjg\nAFROSC0STCWCIJCYBhwOwxEqegIkuqLBbRN0LQaBxkBZsdkaDJutmeb/05PJ517eZCbz9bYmSRLF\ncazhcKggCOQ4jm63m2q1mh6Ph6SfVv97VJ/nuXa7XdH+z+ez7ve7Xq+X0jTVYDAobX6gSmjeQIUs\nFgvNZjO1Wi3N53MFQaAoikrZezqdKo5j+b6vTqcj13XfGnq329V6vVYYhqrX6zLGqN1uy3VdrVYr\nNRoNLZdLZVkm3/f1fD7leV7xVWq/39dms9HlclGv15MxppTZgarhqRiAj1yvV6VpqvF4rDzPNZlM\ntN1ui3fn/3U4HHQ6nbTf70vZD6gymjeAjzSbTR2Px+J/4tFoVFpwA/gbmjcAAJbhwhoAAJYhvAEA\nsAzhDQCAZQhvAAAsQ3gDAGAZwhsAAMt8AxJ5C+54P8QOAAAAAElFTkSuQmCC\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEcCAYAAADUX4MJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsvXeAVNXd//++ZdrONsqyNBUECxZQRBHUoKLoE+lP0F+i\nxMT4tRDFWBKVxG7UJPaK8mBBE40lQBAVFAQE6bAU6WWBZXvf6bec3x+3zu7M7iw7w+4Onxd/MDO3\nnXtn9rzPp5zP4RhjDARBEATRAnx7N4AgCILoHJBgEARBEAlBgkEQBEEkBAkGQRAEkRAkGARBEERC\nkGAQBEEQCUGCQRDtzLp16zBq1KiE9n399dfxxz/+sc3nIYhjgQSDaHeuvPJKnHvuuaitrY36fMKE\nCTjzzDNRXFwMAHjooYdw5plnYtu2beY+hw8fxplnnmm+nzp1Kj7//HPz/cyZMzF69GgMHToUl19+\nOe677z4AwNixYzF06FAMHToUZ511FgYPHozzzz8fQ4cOxTvvvJPK240Jx3FJ2bc15yGI1iK2dwMI\nAgD69u2LhQsX4sYbbwQA7NmzB+FwOKoD5DgOubm5ePnllzF79uyoz2Mxd+5cLFiwAB988AH69u2L\nqqoqLF26FADw5ZdfmvtNnToVEydOxP/+7/+m4tYIIm0gC4PoEEyYMAFz584138+dOxeTJk1qst+k\nSZOwe/dubNiwocVzbt++HZdeein69u0LAOjWrRumTJkSc9+WCh68/vrruOeee/DHP/4RQ4cOxfjx\n41FYWIh33nkHI0eOxBVXXIEff/zR3L+8vBx33nknhg8fjmuuuQafffaZuS0cDuOhhx7CRRddhLFj\nx0ZZTMax06dPx4gRI3DVVVfhww8/bPFeY7F//35MnToVF154IcaNG2eKJQAsX74c1113HYYOHYpR\no0bhvffeAwDU1NTgjjvuwIUXXojhw4fjpptuOqZrE+kJCQbRIRgyZAj8fj8OHDgAVVXxzTffYPz4\n8U06crfbjTvuuAMvvvhiQuecN28eZs+eje3bt0NV1Ta1cdmyZZg0aRI2bNiAQYMG4Xe/+x0YY/jh\nhx8wbdo0PPLII+a+9913H3r37o2VK1filVdewYsvvog1a9YAAF577TUUFRVhyZIlmD17NubNm2ce\nxxjDHXfcgUGDBmHlypV4//33MWfOHKxatapVbZVlGXfeeScuu+wyrF69Gn/+85/xwAMPoLCwEADw\n5z//GU899RQ2bdqEL7/8EhdffDEA4L333kPPnj2xdu1a/Pjjj7j33nvb9MyI9IIEg+gwTJgwAfPm\nzcOqVatw6qmnokePHjH3u/7661FSUoIffvih2fONHz8ejzzyCFatWoWpU6di5MiRbYpPDBs2DCNH\njgTP87j22mtRU1OD2267DYIg4Oc//zmKi4vh8/lQUlKCzZs344EHHoDD4cCZZ56JKVOmYP78+QCA\nb775BnfeeSeysrKQn5+PqVOnmtfYunUramtrceedd0IQBPTt2xdTpkzBwoULW9XWgoICBAIB3Hbb\nbRBFERdffDGuuOIK0xXndDqxb98++Hw+ZGVlYdCgQQAAURRRUVGBoqIiCIKACy644JifF5F+kGAQ\nHYbx48fjyy+/xNy5czFhwoS4+zmdTkybNg2vvPJKi66ksWPH4t1338WGDRvwxBNP4NVXX231aN2g\nW7du5mu3240uXbqY8RO32w3GGPx+PyoqKpCTkwOPx2Pu37t3b5SXlwPQXE49e/aM2mZQXFyMsrIy\nXHTRRbjoootw4YUX4u2330Z1dXWr2lpeXo5evXpFfWZvw6uvvoply5bhyiuvxNSpU1FQUAAAuPXW\nW3HyySfjlltuwdVXX90uCQBEx4UEg+gw9O7dG3369MGKFSswZsyYZvedPHkyGhoa8O233yZ0bkEQ\ncM011+CMM87A3r17k9HcuPTo0QN1dXUIBALmZyUlJabFlJeXh5KSEnObkQUGAL169ULfvn2xbt06\nrFu3DuvXr8fGjRsxc+bMVrfBfg3jOkYbzjnnHLz55ptYvXo1Ro8ejT/84Q8AgIyMDDz44IP47rvv\nMHPmTLz//vumK40gSDCIDsUzzzyDDz74AG63u9n9BEHAXXfdhVmzZsXdZ+7cuVi+fDn8fj8YY1i+\nfDn279+PwYMHJ7vZUfTs2RPnn38+XnzxRUQiEezatQuff/45xo8fDwD4n//5H7z99tuor69HaWkp\nPvroI/PYwYMHIzMzE7NmzUI4HIaiKNi7d2+TwHhLDBkyBBkZGZg1axZkWcbatWuxbNkyjB07FpIk\nYcGCBfD5fBAEAV6vF4IgANDiNIcPHwagiYcgCOY2gqC0WqLdsafFnnTSSXG3NWbs2LF455130NDQ\nEHP/zMxMzJw5EwcOHICiKOjduzcef/xxDB06NOFrtAb7eV544QU89thjuOyyy5CTk4N77rkHI0aM\nAADcddddeOyxxzB69Gjk5+dj8uTJmDNnDgCA53nMnDkTzz33HEaPHg1JktC/f3/cc889rWqLw+HA\nW2+9hccffxxvv/02evbsib///e/o168fJEnC/Pnz8fTTT0NRFPTv3x/PP/88AKCwsBBPPvkkampq\nkJOTgxtvvBEXXnhhUp4P0fnhUrmAUmlpKf70pz+hsrISgiBgypQp+PWvfx21z7p16zBt2jSzo7j6\n6qsxbdq0VDWJIAiCOEZSamEIgoCHH34YgwYNgt/vx+TJk3HJJZdgwIABUfsNGzas1T5agiAI4viS\n0hhGXl6ema7n9XoxYMAAM0uDIAiC6Fwct6B3UVERdu3aFTPgWFBQgIkTJ+K2227Dvn37jleTCIIg\niFaQ0hiGgd/vx9SpUzFt2jRcddVVTbbxPA+Px4Ply5fjmWeewaJFi1LdJIIgCKKVpNzCkGUZ06dP\nx4QJE5qIBaC5qowJTqNGjYIkSU2qljbmOGgcQRAE0YiUp9XOmDEDAwcOxM033xxze2VlJbp37w5A\nK4sAALm5uc2ek+M4VFQ0NLvPiUJeXhY9Cx16Fhb0LCzoWVjk5WW16fiUCsbGjRuxYMECnH766Zg4\ncSI4jsO9996L4uJicByHG264AYsWLcLHH38MURThdrvx0ksvpbJJBEEQxDFyXGIYqYBGDBo0erKg\nZ2FBz8KCnoVFWy0MKg1CEARBJAQJBkEQBJEQJBgEQRBEQpBgEARBEAlBgkEQBEEkBAkGQRCEDZ/P\nh7lzPz+mY//0pz/A7/clvP+7776DTz75qOUdOwgkGARBEDYaGuoxd+5nMbepqtrssX//+8vwejNT\n0awOQadcQOmR715AqDIHd1w8CdleZ3s3hyCINGLmzNdRXHwUt9xyI4YNG44RIy7Be+/NQrdu3bFv\n3x58+OGnePjhB1BRUY5IJIwpU36JceMmAgCmTBmP2bM/RCAQwAMPTMe5556H7du3IC8vH8899wKc\nzvj91d69u/H8888hHA6jT58+ePjhx5CZmYnPPvsE8+f/B6Iool+//nj88b9i8+aNePXVF/RFuzi8\n8casqDXkU0WnFIzdVfsADvhg0SDcPTm1y20SBNF+fLp0H9bvatuSCILAQVGs+ckXntkD1185MO7+\nd955NwoLD+Ddd/8JANi8eSN27tyBDz/8FD179gQAzJjxGLKyshAOh/H//t+vMWrUlcjOzgZgrbpY\nVHQETzzxLB588M949NGHsWzZUowZc23c6z799OO4774HMWTIeZg9+2289947uPvu+/DPf36Azz9f\nAFEUTXfXJ598hPvvfwjnnDMYoVCoWSFKJp3aJVUfab5IIUEQRDI466yzTbEAgE8//Rd+85tf4fbb\nf4vy8nIUFR3Wt1jC1KtXbwwYoAnTGWecidLS4rjn9/t98Pt9GDLkPADAtddeh4KCzQCAgQNPw+OP\n/xmLF38NntfWVz/33CF49dUX8fnnn6ChoR48f3y68k5pYUQOnQnnKbvAvDXt3RSCIFLI9VcObNYa\nSIRklAZxu93m682bN2LTpg1455334XQ6cffdtyMSiTQ5xj7q53kh5j524lVp+sc/XkFBwSasXLkc\n77//f/joo89w002/wciRl2H16pW4/fbf4uWX38TJJ59yjHeXOJ3SwmD+HACA5CDBIAgiuWRkZCAQ\nCMTd7vf7kJWVBafTiUOHCvHTT9tj7teaMn1ebyays7OxdWsBAGDRoq9w3nlDAQBlZaU4//wLcOed\n0+H3+xAMBnD0aBFOPXUAbrzxZpxxxiAcPlyY+A22gU5pYagBrYCWH1Xt3BKCINKN7OwcnHvuENx8\n8/+H4cNHYsSIS6K2Dx8+EvPmfYHf/OZXOPnkU3DOOefatloxDC0gnTgzZjyO559/FuFwGL1798GM\nGY9BlmU8+eQj8Pv9ABhuuOFGeL2ZmDXrLWzatAGCIKBfv1Nx8cWXtHj+ZNApq9Xe/MQi+PsthuCQ\n8dpVT7X6i0knqBKnBT0LC3oWFvQsLE7IarXvPzoGQiQHTIigLlLf3s0hCII4IeiUgsFxHFyqppSV\nwep2bg1BEMSJQacUDADwIBsAUBGgOAZBEMTxoNMKRqagZUrtrYif20wQBEEkj04rGJeeoeVmbz96\npJ1bQhAEcWLQaQVjxOn9AAaEWOKVIQmCIIhjp9MKhsAL4FQnFK752ZMEQRCtoS3lzQHg008/Rjgc\njrnt7rtvx+7du4753O1NpxUMABCYE4yPQFaaLzlMEASRKM2VN0+Ezz77GOFwKIkt6jh0ypneBg7O\nBUkMoCEgoUuWq72bQxBEGtC4vPm0adPxr399iO+//xaSJONnP7sct9xyG0KhEB599CFUVJRDVVXc\nfPOtqK6uRGVlBe6++w7k5ubilVfeinudb7/9Bh999D4A4OKLL8Gdd94NVVXx3HNPYffunQA4XHfd\neFx//S9jljhvDzq1YLh4D4JcFap9fhIMgkhD/rPvS2wu39amcwg8B0W1Clqc3+NcTB44Nu7+jcub\nr1+/BkVFhzFr1hwwxvDgg/dhy5YC1NZWo3v3PPz97y8DAAIBPzIyvPj3vz/Ga6+9rZc7j01lZSVm\nznwd7733T2RmZuHee3+PlSuXIy8vHxUV5fjgg08AwCxnHqvEeXvQqV1SHkFbMKTCR7O9CYJIDevW\nrcX69etwyy034pZbbsThw4dQVHQYp546EBs2rMPMma9jy5YCZGR49SMY7GXOY7Fr108YOnQYsrNz\nwPM8rr76WhQUbEbv3n1QUlKMl19+HmvXrjbPGavEeXvQqS2MDIcHCAM17ai4BEGkjskDxzZrDSRC\nW2tJMcYwdepvMH78pCbbZs/+CKtXr8Lbb7+Oiy66GL/5za0JnzNWGb+srCy8//7HWLt2Nf7zn0+x\ndOm3ePjhR2OWOD9ea2DY6dQWRpZTU9/qABUWIwgiOTQubz58+MVYuPC/CAaDAIDKygrU1NSgsrIS\nLpcLY8Zci1/+8ibs2bNbP96rV5eNz1lnnYMtWzajvr4OiqLgu+8W4bzzhqKurhaqqmDUqCtw6613\nYu9e7ZyxSpy3B53awsh2e4EGoC7U/JdDEASRKI3Lm0+bNh2FhYW4447fAtAE5ZFHnkJR0RG88cYr\n4HkOoujAAw88DAAYP34iHnhgOrp3z2sS9DYqa3fr1h233/573H337QCAESMuxaWX/gz79u3FM888\nAcZUcByHO+64O26J8/agU5Y3B4CKigYs3rcK8w/PxymRS/Cnaye0d5PaBSrdbEHPwoKehQU9C4sT\nsry5QY+sLgCABpl+DARBEKmmUwtG36x8AECA0VKtBEEQqaZTC0ZXTxdAFRAWKK2WIAgi1XRqweA5\nHqKcBdXZAEVV2rs5BEEQaU2nFgwAcLMccLyKcl9tezeFIAgiren0gpHt1KL+e0vL2rklBEEQ6U2n\nF4zeOV0BAHtKy9u5JQRBEOlNSgWjtLQUv/71r/Hzn/8c48aNw5w5c2Lu9/TTT2PMmDGYMGECdu7c\n2apr9M/LAwAcqa5sc3sJgiCI+KR0prcgCHj44YcxaNAg+P1+TJ48GZdccgkGDBhg7rN8+XIcPnwY\nixcvxpYtW/DYY4/h008/TfgaeZm5AID6CNWTIgiCSCUptTDy8vIwaNAgAIDX68WAAQNQXh7tOlqy\nZAkmTpwIABgyZAgaGhpQWZm4tZDt1KbIB1U/1M45aZ0gCKJTcNxiGEVFRdi1axcGDx4c9Xl5eTl6\n9uxpvs/Pz0dZWeIBbCPozYQwGgJSchpLEARBNOG4FB/0+/2YPn06ZsyYAa/XG7UtVikro0BXcxg1\nUbqqGdoxjjCYwLe5Vkpn5ES853jQs7CgZ2FBzyI5pFwwZFnG9OnTMWHCBFx11VVNtufn56O0tNR8\nX1paih49erR4XnsxMSc8CDnC2H+oBrnuTl2At9VQYTULehYW9Cws6FlYdPjigzNmzMDAgQNx8803\nx9w+evRozJs3DwBQUFCA7OxsdO/evVXX8IpecI4I9hyhyXsEQRCpIqXD8Y0bN2LBggU4/fTTMXHi\nRHAch3vvvRfFxcXgOA433HADRo0aheXLl+Pqq6+Gx+PBs88+2+rr5GXmokauxNKNhzDh0n7IcDtS\ncDcEQRAnNikVjAsuuCCheRWPPvpom66T49LMLFUIo6o+TIJBEASRAjr9TG/AypSCI4x6f6R9G0MQ\nBJGmpIVgZOlzMThHBHX+cDu3hiAIIj1JC8EwLAzOEUIdWRgEQRApIS0EIy+jGwCA9/hR5yPBIAiC\nSAVpIRi9vb3AgQOfW47qAOVbEwRBpIK0EAy36AIDA+8O4oC4or2bQxAEkZakhWAAwPCeFwAAgq7i\ndm4JQRBEepI2gvGL08YBAHjJ28KeBEEQxLGQNoKR4cgAwplgnNzeTSEIgkhL0kYwAIBXBTCeBIMg\nCCIVpJVgcMwB8ApUprZ3UwiCINKOtBIMXi+NFVFoISWCIIhkk16CwbSig2GFJu8RBEEkm7QSDEG3\nMMIK1ZMiCIJINmkmGJqFEZRC7dwSgiCI9COtBEPkNMHwR0gwCIIgkk2aCYYTABCUyCVFEASRbNJM\nMDQLI0AuKYIgiKSTVoLh5HXBkMnCIAiCSDZpJRgO3SUVIguDIAgi6aSVYLh4FwAgKJNgEARBJJv0\nEgzBAwAISMF2bglBEET6kVaC4RbcAICgTIJBEASRbNJLMETNwggq5JIiCIJINmklGB7dwggpZGEQ\nBEEkm7QSDJdDBFMEqiVFEASRAtJKMESBA5MdCKvkkiIIgkg2aSUYDpEHFAcijASDIAgi2aSVYIgC\nDyaLkFiEVt0jCIJIMmklGE6HACh6iXOavEcQBJFU0kowsjwOMFkvcS4F2rk1BEEQ6UV6CUaGA0zS\n6kn5JH87t4YgCCK9SDPBcAKyLhgRXzu3hiAIIr1IK8FwOwVwClkYBEEQqSCtBIPjOHjEDACAL0KC\nQRAEkUxSKhgzZszAyJEjMW7cuJjb161bh2HDhmHSpEmYNGkS3nzzzTZf0yt4AQANErmkCIIgkomY\nypNPnjwZU6dOxZ/+9Ke4+wwbNgwzZ85M2jWznF7UAqgPk2AQBEEkk5RaGMOGDUN2dnYqL9GEHHcW\nAKAu1HBcr0sQBJHutHsMo6CgABMnTsRtt92Gffv2tfl8mW43mCKgLkKCQRAEkUxS6pJqibPPPhvf\nf/89PB4Pli9fjt///vdYtGhRm87pdTvA6jNQI1aDMQaO45LUWoIgiBObdhUMr9drvh41ahSeeOIJ\n1NbWIjc3t8Vj8/KyYn/eNROsPANSRgMcWQxdPMfXJdYexHsWJyL0LCzoWVjQs0gOKRcMxljcbZWV\nlejevTsAYOvWrQCQkFgAQEVFbJcTUxSooQwIAHYeKcRpXU5tXYM7GXl5WXGfxYkGPQsLehYW9Cws\n2iqcKRWM+++/H2vXrkVtbS0uv/xy3H333ZAkCRzH4YYbbsCiRYvw8ccfQxRFuN1uvPTSS22+ZoZL\nBAtplktFsDLtBYMgCOJ4kVLBeOGFF5rdfuONN+LGG29M6jW9btGsJ0UFCAmCIJJHu2dJJZsMtwNg\n2m3JqtLOrSEIgkgf0k4wPG7RFAyFyQkdwxhrNtZCEARBpKFgeN0imNo6C+Optc/jlc1vp7JZBEEQ\nnZ52TatNBR6nCDBt7oWcoIVRFqhAWaAilc0iCILo9KSdhcHzHDwOLeitUAyDIAgiaaSdYABAhtsF\nAJDVxCwMgiAIomXSUjCy3JqFISVgYahMTXVzCIIg0oI0FQw3ACAsSy3uS24rgiCIxEhLwcj2aC6p\nkNSyYMiMBIMgCCIR0lIwMj2ahREhC4MgCCJpJCQYX331FXw+bQW7V155Bb/73e+wffv2lDasLeRk\naBZGRG456J1o6i1BEMSJTkKC8dZbbyEzMxNbt27FypUrMXHiRDz99NOpbtsxk5OhWxhK6ywMCoAT\nBEHEJyHBEEVtft+qVaswZcoUjBs3DuFwOKUNawtZXieYykFKIK3WHsMg9xRBEER8EhIMjuPw3//+\nFwsXLsSIESMAAFICAeX2IlMvQJhIaRC7SCgUACcIgohLQoLxl7/8Bd988w2mTJmCk046CYWFhRg+\nfHiq23bMeN0ioPIJCYA9hqGQS4ogCCIuCdWSGjp0KN58803zfb9+/fDII4+krFFtxevRLIxEBIMs\nDIIgiMRIyMJ47rnn0NDQAFmW8atf/QrnnXce5s+fn+q2HTNupwAwHmoiFoZKMQyCIIhESEgwfvzx\nR2RlZWHlypXIz8/HokWL8O6776a6bccMx3HgwENFAhYGIwuDIAgiEVo1cW/9+vW4+uqrkZ+fD47j\nUtWmpMBDAEPLMQl7gUKyMAiCIOKTkGB069YNf/nLX/DVV1/hkksugSzLUJSO3bkKnADGqVBbWEkv\n2sKgoDdBEEQ8EhKMF154AQMHDsRLL72EnJwclJaW4re//W2q29YmBE4AOBWhcPPCJlPQmyAIIiES\nEoyuXbvipptugtfrxb59+9CzZ09Mnjw51W1rEyIvgOMZGgLNTzC0i0SiS7oSBEGciCSUVrtt2zZM\nnz4dTqcTjDHIsozXXnsNZ599dqrbd8w4BAcAoMYfQn5Xb9z9yMIgCIJIjIQE469//SueeeYZc5b3\nmjVr8NRTT+GTTz5JaePaglMQAQWo9Yea3U+xTdxLJA2XIAjiRCUhl1QwGDTFAgAuvvhiBIPBlDUq\nGTj1+ld1gebbGT0Pg4LeBEEQ8UhIMDweD9asWWO+X7duHTweT8oalQzcorZMa12gJQvDFsOgUucE\nQRBxScglNWPGDNxzzz1wOvW1siUJr776akob1lY8DkMwAs3uFx3DIAuDIAgiHgkJxuDBg7F48WIc\nPHgQjDH0798fY8aMwbJly1LcvGOniycbqAXqwr5m91No4h5BEERCJCQYAOBwOHD66aeb71kLE+La\nm64ZWQCA+hYEQ6bSIARBEAlxzGt6d/TSINkuTTBCavMuKYWKDxIEQSREsxbGvn374m6TE1gvuz3J\ncmYCACJoIYZBFgZBEERCNCsYt912W9xtLpcr6Y1JJlkOTTBkNJ8lJavWyoEU9CYIgohPs4KxdOnS\n49WOpJPp1GZ3K0LzpUHCSsR8TRYGQRBEfI45htHRyXRogsGEMBRVhSSrqGloKh4h2fqMYhgEQRDx\nSVvBEHkRHBPBCRJCEQUvf7YF97+xCrW+aNE4UFpjviYLgyAIIj5pKxgAIDAR4BWEIwp2HtKEoSEg\nRe3jC1ulQ+yLKREEQRDRpFQwZsyYgZEjR2LcuHFx93n66acxZswYTJgwATt37kzq9QXOAfAqghHL\nchD4RunAvLUtKDcfICcIgjiRSalgTJ48GbNnz467ffny5Th8+DAWL16MJ598Eo899lhSry9wIjhB\nwaqtJRC6lsB5xjqE5UjUPpxoWRUkGARBEPFJqWAMGzYM2dnZcbcvWbIEEydOBAAMGTIEDQ0NqKys\nTNr1Rc4B8Aq+WXcYzoFbIORUY1/DXgDAtxuO4I7nlwG8AjXsBgAE5I5dgZcgCKI9adcYRnl5OXr2\n7Gm+z8/PR1lZWdLO7+Qd4HgVgFXGpCpUDQD4+Lu9iMgqwMuA5AIYcNRXgtpwXdKuTxAEkU60q2DE\nqkeVzJIjDl5bdc8epygLldgupoLjGZgigmMOVIdq8OdVf03a9QmCINKJhIsPpoL8/HyUlpaa70tL\nS9GjR4+Ejs3Ly2pxnwyXG5AAzmW5mioj5fDk8HCduR5SST/tQ0UAVN6Uz0TO3ZHobO1NJfQsLOhZ\nWNCzSA4pF4zmqtqOHj0a//znP/Hzn/8cBQUFyM7ORvfu3RM6b0VFQ4v78EzQ/s+w9q0JV2PJzjXg\ns6vgyq7S2qiKYLyVbpvIuTsKeXlZnaq9qYSehQU9Cwt6FhZtFc6UCsb999+PtWvXora2Fpdffjnu\nvvtuSJIEjuNwww03YNSoUVi+fDmuvvpqeDwePPvss0m9vkvQFlHiXFYBQgaGf+76PHpHRQD42HWk\n/FIAHtENnkvrKSsEQRAtklLBeOGFF1rc59FHH03Z9Z2GYDi1dFk1kBVlbRgwVYh5fGWwGk+u+Qem\nnD4el/UZEXMfgiCIE4W0Hja7BC3obcQwlOr82DvGEYxSfxkUpqDEX56S9hEEQXQm0lswRK0EO+fU\nBaO2B/p5Tmu6oyqALxze5OMGyQ8ACMvNV7wlCII4EUhrwXCLRgxDn8EtO9DV0TQLi6k8uIYe6J99\nSlSswhfRlncNKyQYbUVVGRSV1hshiM5MmguGbmHoAW0mO83MqShUHqrKIPA8VKaamV0+3cIIkWC0\nmfvfWIXpr6xs72YQBNEG2nUeRqrxOJzWG1XQXE9cDMFgAhTGIOjbVKZC4ARTMMjCaDt1/kjLOxEE\n0aFJawvD47CWkeUUTTyatTB0wTDWxfBFdAsjRgxj/9E6vDF3G8IRWkODIIgTg7QWjAyn23zNq5p4\n8FxTo4rZXFKATTBMC6Pp6PivH27Ext0VWLmtpMk2giCIdCStBcMreszXAtMD4LEsDMZDsVkYsmpY\nGC0HvRWFArkEQZwYpLVguEXLwjAEg48VtlG1x8A3dklJ2gzx5oLe8QufEARBpBdpLRgem2A4OD2e\nwZresjHDQ/F7AAAgAElEQVTTW9HDEXuO1EBSZYQULR1XVmUoKsUqCII4sTlhBEPUBSNm0FsXkcMl\nWsziixV74dfjFwaUKUUQxIlOWguGyFvuJ5cuGDFjGLpLqs6nVazNyXKgIRItGDQXgyCIE520Fgw7\nTkGzNjjEF4yIpEUkvB6hiYURK7WWIAjiROKEEQw3r1sYMQoNMt3qYLpwhCXZzJDyihkArBTbpscm\nvakEQRAdkhNGMFy6haEqMZaA1YUCTNsWliWz8ODJ2X0BANWhmqhDhG7FcA9bBL+anmuAl9UEUOtL\nvlXV3IJaBEF0bE4YwRD1dNqv1xxtutEUDO3/kCybFsXJWbEFw3HqVnA8Q6G0LUUtbl8efnsN7nt9\nVdLPS3pBEJ2XE0YwOF63LNQYt8yiLYyILCMoaym1fTJ7AQCqQ7XRh4S1SYFBRks/tgaVFIMgOi1p\nLxhy2ckAgGxeXys8xjwMQBcTm2AoqgwA6JGhHdfYwjAEI8Dqk9zi9IZcUgTReUl7wegdvgjBddeg\nb9cu+icxYhgGuphEFBmyPtvbI7qR5chsIhjGefxqHVRG5UGawy4SKukFQXRa0rq8OQA89KsLUFUf\nQqhRVVnGAK6RdjDdwpAUGbKqvRZ5EV3dXXDUVwyVqeYCS5ygWSAyIjjqK8VJWb1TfCfHj2RbAXY3\nlEqKQRCdlrS3MFxOAb27e8Hb7jS48UqENl7VdGfdwuA4hoisTeITOAFO5oXMFDToqbbaBtl8uat6\nT0ra3l4kO85gPx25pAii85L2gmHA280JxQmoMYwr3cIApyKiaIIg8gJ27NUC4Ha3FCfIYLIIDjzW\nlW5Kq44w2Sup2q2KjmxgbNxdgcNllMRAEPE4cQSDbyZ2YWAExDkGSRcMgRPBIlqA2xAMSVYAQQYL\nZyCf749ifymK/aUpaXd7kGy3kV1LO2qWlKyoeGPuNjz+3vr2bgpBdFjSPoZhIMQQjOCmK6PTbG0W\nRlh3SfEcb2ZEVYdqMe+HA/jvqoPwXKRAVQR4oQXTG5cS6cwku1O3n491UBND6aDtIoiOxIljYTSO\ncAOAHO2aMkqDgGOIyDI4cJAkBiZpa2k0RHz476pCK36hiDAeoWwrf76zeg8KKran4jaOC8mPYXR8\nl1Q6uRQJIlWcOIIRw8JwORvVlTItDIaIIkHkRS27Sq8/JTNNKIwMKaaI4Fj0sq4A8HrB/2HWtjnJ\nvoXjht0llYyOVO0EQe9kx20IIh05cQQjhoXhaSIYepYUr0JSZIi8gFBENoVE1ifzRVkYrKmFYWDM\nFu9s2AUjGa6aqLTajioYHbRdBNGROHEEI4aF4XE1CuHYLAxZVSBwAgJh2YxzGIFwu4VxqFSLXZhi\nYqM23DkLE9o1QlHa3pGyTpAlRfNDCKJlTmjBcDtjC4bgrYfMFPgDCv46ZyMYixYMiFpAHLIT/oBm\nWRgzw+2zvjutYERZGG331US5pFrRMW/YVY7f/W0pSqpSn1BAFgZBtMyJIxgxYt4ZrmiXlOrPQaYj\nE3yXMiiiD4oSXbAwohoWhiYYTHaYLimj9lRYiZjn++7QcqhMBWMMSzcVobQ6kNR7ShX2zlNOhoVx\njC6pd7/aCcaAZZuL29yGlkimhaGoKpZtPor6QKTlnQmiE3HCCEastFp3Y5eUKmJkrwut92YV29gW\nBlMcZmaVYWGEbHGLXTV7sbN6Dw6U1OOjxXvwl1lrk3ErKSdeDONQaQO+WL6/1aNxtY1ZUrES3Foi\nLDWNKTVHMgXjhy0lmLNoN96c23kz5QgiFieMYHAxg95Np6FkOb3WG0MwDFFQNaHgTJeUPehtWBjR\niw4V+0oRCGnbVCh4o2A2NpVvPfYbOQ7YO3hFsVxSywuOYuHqQyipap2ldKwuqWPtwn/YWow7X1iO\nzXsrEj4mmS6pitogAOBAMVUyJtKLE0YwYloYjbOkAHgcGdYb1TiGB2OApGdCGYLBZKdtlT5NFEK6\nYAicdu4Sf5lZH5f31mFH9W7M3v5RW28npcSzMCKyJh6hSNMAf3Mcq0vqWBXj2/VFAICVW0sSPiap\nMW/zp0ZxESK9OGEEI1bQWxCafpYheszXzL52hsrb0mptMQzd+vAFNVdUSNYEY8wpV0DkRZT4y6zl\nNljT69U0hBFppfsk1dhFQra/1q2NcOTY3T0dNbaczJnenP6F2+91b1EtZi/cYT7D9mDVthJ8vmx/\nu12f6Pyc0ILBxVgbwy4YUYstMR6SEfQ2XFK2eRiBiPaZYWFkODzIEbvgaH05Xvz3Fu043upoGWOo\nrg/h/jdW4c15HcvXbe/o7C4pST42wYhXSyoQkrBxd0XcyXxMH6G3NobBmQLdijYmUzBiXP/ZjzZh\n1bZSbNlXlbTrtJbZC3fiqzWH2u36RGIEQhLmfLML5bprsyNx4ghGgr1OhsMuGBxGnN0TYy48CVB5\nax6GKIGpHKAKtpRbXTD0oLdbcKG8SoLMJJiuCcHqaINyEAdLtMqoW/e3XycSi6gYRpSFob1ubUA5\nnkvqrXnb8cbcbVi/qxwAsPtwDZ58fz3qfNFxoFjC3hzHECNPWgxj/9E6LFwdv1OWlPa3JimFuGPz\n5epDWFZQjDf/s629m9KElAvGihUrcO211+Kaa67BO++802T73LlzMWLECEyaNAmTJk3C559/nuom\nIcMl4vSTcmP2LN6oGAaPrAwHHCIPxmwuKVECZAcAzoxzGGJiWBhu0Q2oAjieAZw+UrZZGIX1R9Bg\nS7sMyiGsOroWSowZ45v2lmHBmn1tueVWERXDUJq6pEKtzUCKCnpbr38q1Kr/Fldq8yx2FNagsLTB\nFNLGIYBV20pQVpOa1ORkdaJ//XCj+Zp10BgGTVLs2BhJMnX+jpeWndJqtaqq4qmnnsL777+PHj16\n4Be/+AVGjx6NAQMGRO133XXX4S9/+UsqmxLFq3+4DDzH4bPvm3bCjV1STgevWSeqzSXFK2BG0ULT\nwtA60bAew3AJLr04IbRSIrIzatGlN7bMxmnCRQC6AgDm7luIVcVrURWqwfgB10a16Z0d70LIqsGo\nwFPIznA1afPhsgYcrfBjxDk9W/8wYqDEmbgnHWMMI56FwXGa28b4LCJr5/WHpOgTcEBRhQ+zF+4E\nALz70JXNX9B0CSXeMaakWm2sU3aAvlpRGMSm+R5EB8HwnneAn0oTUmphbN26Faeccgr69OkDh8OB\n6667DkuWLGmy3/EuSGe6p2JYGA7eYb5mjIfLIUAUeIDxCMu64vNWQUJTMHQx8ellzr0OD5i+j2lZ\n6P93dWsl0asjWtqnxyWgMqi5pbZX7WzSJiFLG4nvKyuLeT+Pv7ces77c0bSjPUZYnIl78jHGMKLK\nm9teG9+DoUkRSXvhD+pJBbZzGKOu+NdQ8fyqt7GudJMVdNa3lQcq8fKmmagKNl6X3SIVy7J3UL1I\nyuz9YyEQknCwhFKNW4QzkiY6wq8lmpQKRllZGXr16mW+z8/PR3l5eZP9Fi9ejAkTJuCee+5BaWnq\nFiK6dvjJ+OVVp8XdfuGZPaLna6gcnKIuGCoPcCq6ZLnA8aqZHWVO3NNdSUY5kFxXDqDooqILhSEc\nkwZeBwAI6gKU5XGaIlIeqIzbvgOVzT8bo8NtK/FKg5hZUq2OYcQ+t5GIYFoY+nl9IQm+oCV+HGIn\nLdgpD1RgXVEBPtjxSZNtH+38DHtrD+CLvf+Ne/xx8+u34jLBsIxguPUpzC11NO219seTH2zAUx9s\nMOepELExfukdUC9S65JKRCGvvPJKjB07Fg6HA5988gkefPBBfPDBBy0el5eX1er2/P7686Pee3X3\njihw+Ntdl6Ffr2w4HQLcQgZCiuYr79Y1AxFJBSvlwfEq8rt5cIhXwQwxMFJleRV5eVnwq35wHIeT\n8/NNC8N0RelB7755eQAAf1hzXzkcAjiH9qwkVcL+0F5c1Oc88Hy0nleEapq97+wcD/K6eWNuKyyp\nxyNv/4iHfn0hzj61W7PPqaja+oP2ZrrNaxrfJifwLT7/3YeqwXEcTj+5C6oCVuefneMxjxV4DhIA\nt9uhfSZo9/vlj4fw5Y9W4Dgjw4luXa37inXtBsGyHhwO7TxOp4i8vCwonL4YliP+76akzpqh35rf\nViAkweMSY04MjXWuzCztee46VI1Fqw9h2i+GwCHGHreNu38+AGDBCxMSaosvEMEvH/kaYy/tj9sn\nDY66tj2dN7eLF12z3QmdM5mU12i/KyYIx/T32xaO9/XagsejeTk4jutw7U6pYPTs2RPFxVYdoLKy\nMvTo0SNqn5ycHPP19ddfj+effz6hc1dUtH3t5WBQG+EzBnTxiKir1UQii89BSAmAc4YQDkmaC0bl\nwXGAIOqdfyOXVEiSUFHRgEpfNbIdWSgr85kxDI5XwABwvHZsxMfAgYfC6YHysIQan2Wqv/jjLNx1\n3q0Y1PV0RGy1qUrrK5rct/0HVVJWDyGOu+H9/25HbUMYr/17M566dXizz6XaFliuqQmY1wyFdSuq\nPhTVjj1HarFs81H89ueDzM7vgVd/AKDFG2qqY5/P6GN9/jAqKhrQ0Cg7yiAQjKC21jrHrP9swbkD\numFAb+u3c7TassxM11lYRkVFAyKS9pyliBL3d2O/55Z+W2t+KkVOpgsOkcczH27EL0efhqsvPCnm\nvo3PtftgFQ4eqcG8lQcBAKf3ycawM7W/CZWxmNl8if7W9xdr1u2XKw/i9kmDo46zWyrl5Q1Qwslx\nXx4LtbUBVFQ4j9v18vKyktJfHC9CumtZVdWkt7utApRSl9S5556Lw4cP4+jRo4hEIli4cCFGjx4d\ntU9FhVW+YcmSJRg4cGAqmxTFWf20gPMVQ/tEfZ7j1DoizhWEUxS0CX66MDhdWmdkWg+qVXyQMYa6\ncD1yXTnarGjTwlCi/j941A+m8KaLKhRRcKAsOrW2Pqz9UHy2pV+DsERlR2E1bnluKXYVVpuf2V1S\nQTmIz/bMR0PEpzWTGXMaWk46jSo+aA96c34ArIlL6rl/bsKaHWUo2NfUnSYratxaUlYMI9ol1aQ9\nKovK1vrvqkL8dc7GqH0CsmUVMS76PEYFYaWZQEVr5mG8s2AH/vHxZmzao/12P16yF8s2H03o2K/W\nHDLFArC+j6/XHsKtf/selbq7Jl7cpzliVTMwsD/b9ophdAa2H6jCii2pL3bZHJwZw2jXZsQkpRaG\nIAh45JFHcMstt4Axhl/84hcYMGAAXn31VZx77rm44oor8OGHH2Lp0qUQRRE5OTl49tlnU9mkKAad\n0gUv/P4S5GZGj3Z6ZuZhn38XOF6By8EjIvOmMBRlaCPnxhaGwhT4JD9kpsDDe3GkrMF0WxmWhRHL\neG/hPrjPFgBBQddsF6rrw3ApQXDMiQGu83CArTNX9/NL1sjX7zqChogPWc5MfLxkLwDgw+WrAUcI\nkNxRncJXB7/DsqJVOOorwR+G3mF2ynwCQwSmMkCMACpvdtQH6gohnfEtxOL+CEeiXVqc2w+hSxlU\n9awm56ppCMfNkrJiGNp7o/RIYxSFtdjJ2Z+TKuiuD92JVhfSRLesPn7QO9EYhr0dWR4rQWLOot24\n/Pw+sQ5pFpdT+0I++16bgb1lfxVGX9A3KsgvySqcjpbTmuwDhsaps2Hbs23v9cuPpZjk8eLFT7VJ\ntj8b0rvd2mDGMIz/GcPhMh/65Hm1eGo7klLBAICf/exn+NnPfhb12fTp083X9913H+67775UNyMu\nXbKapqlOPP1qLN+5B1JJfzjO4iEKvDlBz8fpFpEhGODAmCYYxmh++x4/Nh/aAqGbkVarB70NS0MR\nwFTNwuia5UZ1fRicIIHJTuzcE4HrNGteh3FOJjnAOSSsKdmAq0+5XAsK8zL2ur+B53wguO7aqJG/\nsWTsoYYi/b0uGLYMjHjWhqKq8AxdCjXshqKeAwDYXa11aI7eBxEqvihqf9fZP4ITFBRHDgKITu39\nfN9cfW6J1pnaR/JxLQxBgpBbAaWqFwAOispa7OQCNsFo6FIAFA0BGKCoCoKKH+CA2kj8DB1FjR7R\nx3s2stx0XkpbaDwp0RBXvy3oH4ooCQmGZBOFmobo1R6jLYz2FYyOOHJujKqyFhMtUkajWmTrd5Vj\n5vyfcOXQPrhpzBnt0yadE2amd2vwONyI7DsfzJ8LABB5DpwjehKNkR2lveGhMsWsVMv02AVrlCUF\n3hb/UAWAV3WfPwNEOao2VUTVrlcf0gRDLusHMA4bywoAaB1K4zbZR+heUZuAaMRAmGlhcFh86HvM\nWPV0lLvLjk/Wrsm7QmZpEJG3OqzGabWGEMpoOtFob8Mu7Pb9BHDaeaJcUk2ypLR9nP23wTlgK4S8\nIwA0AWtRMGwuqbCnBJxX8+f7pID5B6jy4ZgrIwLRa3o3Z21INpEoScL6Jo1Fx7i0L2QXjMQypYx5\nLABQ2qiisN36SMYqiolSWh1oMgGtvQUrEaQ41u7xoHEtst1HagEA63Y2zTA93pBgxOGZ2y7GpMv6\n47STciEIPDhno1RA1TbiU3koULG7qCp6m/4/7w4AnArOFdRFhNMFQ4HLof3PcQxQHOYxRqmRyqDW\n8alBL5i/C474iiGpsjY/QrT9IYqRqFFkWLW2qUw1O2qe4zB//9eojzRgdfH6mPdeI1nxFKP4oGHp\nAEBIit2BSUwTTCvXniHMQlChgvPolpIhDkoEkZNWg8+pMMXM6PB4fd4J79XOoygsZicXDMtm525Y\nGMPyz9MeR14RGLRYjp36SHQQsSZUi/21hdFus2Y6NPszLqlsWTBueW4pNu6O/4feeIEqNY6FkQj2\nTs5+vLbNOsfxSiFmjGHGO2tw72sroz5XErTMUj0PISwp2FFYbT4P+/Xs4nu8MWuR6e+NZnUEVx4J\nRhx6ds3AuEv6g+c4OAQOnKsZwWA8GBSs31MStc2wNMT8w3CctFsbsVf1wvCztJRbjlchijAtBfsK\nfuV1flTUBrFgk7Z2BgtmQY1o5zOsBrMIIgBnv59QHCwy39sXclpyeAVUPSbCc0A3txbs31a5I+a9\n19kEw+ioa8OWOyckWokKhrABgE9uwM5DNXjqgw3aB7b28Rna8cYf5+6afVAzy+E6Y6M54jRGwcw+\nQx7QXVJNO5nfv7QCr3+h1dvx68Iw/lRtljznCAGMNRGMunC0W+qx1X/Di5vehF+2xUD0Symq0qRU\ni6SoAKdC7HkQxbVWTCTbGz/rZ86i3XG3Nb4v08I4BsGwWxGSrKLWlnUWbgcLI968oHirODYEIman\nXV4bxO/+9j1WbYsuUb+vqM7MBmsrc77Zjec/KcCan0qbtPd4WBjltUE88OYq7C2qjfrcKl4ZLWSJ\nJKykGhKMBBAEHkp1tG+eRQkGh7As4XBFbfQ22z5Cdy2LRqnqhT7dvRh0UncAgCgycE6tc2cRt+nG\nWrPzKB6cuRp8Rj2YIoCFMsxyJIYY2F1SQtcyfF8zD/uOan9MQZtgzNv/FeoytE6L5zm4BK1zO1h/\nOGo/QAtu7w/9ZL43Rlp1YWtkLuVr1XUDIQlPrPmH+Xm1VInVRwrM95zNAuIztOMNwbB3xEfFjTjS\nUGyN6owZ8oarq5kYhpGZZVgYOa5s/YLasQH9/pik3XNdIwvDiPUEFMuCUvXJbw+seBSvFcyK2l+S\nVYh99sJx8m6gj/accjKdZipvLLo1M+ehqUvKsDAsKy5Rl5TdXbZ43SHc9/oqFJZqAtkeWVLxKg/E\niv3sOVKLe15dif+sOAAAWLNd68SNcjAGz3y0sUmGXCzqfGHUNMRO0zbYul/77Rws1n4TgXB0okGq\nWbDqIKrrw5g5/6eoz824lv6TtwSj9dc4WuHD58v2J+07J8FIgByvE9LBcxAptGUBKZYYaAFs1YpV\nGBZGyAsW0YLqnD5/g6kC3E4BXpfWiQgOm2CE3QAz4h4qwCngPH6ogSyYbiwADSFdMMRGMQNexfwf\ntD84Y10OA5nXOlSO40x/v8pU7KiyRr+MMby48S1Uy5YLxReQsLNqD/bVaUFvpa4buIw6bDlUhLte\n/gE1YWt0dCS0HxsjX4PTrQm7BcRn1gK8bJbgsLuGKt3b8fGuL6yRp25lGfenKGqLo+KAHESGwwOR\n10rOc4IClVkuKTWoTfybtW0OygNNV+Lzyz4Yf6EqYwjIQURUCXtrD5hpuYDWkQi52vGcU3vGWR5H\nVPpxY5oXjNguqZYsjFjuGskmCkfKtOdbU6+10e5ikY9TDMEXjC0YscR/2wHNqv16zeGkXPve11fh\n/jdWNbuPEUMzEkKOt2Bk6ll2TZ6TEXMzBUP7P9GK23aemrMBX605hI27E199sjlIMBIgv2sGHvv1\nCPx53HXWh3pw+qW7LwWLeMA5w+Yo+srzTtb2YTxCWy+LPpkqwO0U4eK1Ea8gKmZ8hEU85nnBK4Ao\ngeMYWETrcAzrwxcOAWBaJwwgvGeotl1yIcOt/QgbWw6GEPE8h5AcAqf/+2zPfEiKBEVVsXD9XjMV\nlelVeOsCIawqXqt/xkOp1WapL9tfgHhwDv3adgsjsw6us9aYHWLjWIKTt7l09OM4dwDgFCgqMztk\nvksphPzCJtf0SwFkOrVAP8e0uJCiqKZgsGCmue+yoqYdyfd1/4F76BKAU6GqDA229hX7rJIskqxq\n7QIA2QEOgNspQpaZfm/MDPAbZGVo34nj1K1wnBztBmw82pbMVQ2tDn7TngoU7LXmuDw6ex2mv/JD\nk3uwJz3U+bRnaIjDdxssd+Xxqlbrj1P/K5aFITRKgEgWzcVBGmfpBW3ttT/LzXsrzPkxycSr/602\nFifr+7EGMEDiFgZjDDsLqxGKyKabraVabIlCgpEgp/TMQj+9pAegFSn87f+ciRyvE0q5NstX6KGN\njjJdthGlKkIN299rFoZT0Kf/Cyo4l80lZRYsVK2ihYZPX9+29WAZ+NwKCF21YoQs5IUayAQnSvC6\ntX2DSgheMQMPXHCXdgpO74Q5hpAShlzfBVJlTzRIPtSG6/HV6kOYu0brzOSyk6HWafdaFwiA5/TM\nrf2DwfzapMafiovMjlGp64bIviHmLRqussZZXHyGDypjqAnV4utCrQilGtKqAwucNafBsEw4UYLn\nwm8R5GrNUanrtAI4T9llrnpoEJACyNTXY+f0DDRZZQhKIf06lmB0cWnZb/b4i3Y9GRAjUBlQbwvy\n76zeY7tOSLMmAUCMgEErLaMyBklS4Th5FzwXLgbfpRTiSbugTXRUAV6G2L0YYs/oEbSsMESUCBwD\nCsB5a82OKmxLLFi3sxyvfmGtA19U4YM/pE0U3by3An94bSWq6kJRnZzhglIUFTUNYRSWWgJ4vGIY\ngTguqVgWRpMU1iS565uL/5hZemosC0M7rrIuiNe+2IYH3159zG2org/h9f9sQ3mj0vweV+xZDcbz\nYY0sjETXhdm6vwr/+KQgytWVLCEmwThGpk0cjMv0yT0PTbgaALS1LwBkuT3RO0vWXA+m8nA7BTh0\nwdiozrO5pBpZGKaLS/9MtzCWFhwC77UCf0xyaIFiQTLrKIXkENyiG909WoBb5YxKu/ofsSzqa3oA\nYSWMo5V+0zLQhEs7T00giCMVesC6vqs1GVGQo1xwSnUvXOwZq20zBEO3FOQqK/7DGPCZrQhgeMcI\nvb36WuiCas1X0QnxNXonZ/3oxV4HIHTT4kKSIiGiSqZggIngeAWKwkz3GwtYgmEUioyVVszxShML\nY8XR1WY6bnXIcsFxDq3NxmSqyrogxJ5aDSzXaQVw9CoE5/Fh9U+lZsaXdpNWR6ooKjaUbYHYrRTu\ns9eYHX04gUKSEUnFa19sQ70/ghVbiqMyoczzq6xJxlQy/NmL1x/B9gPNL/zVGgujsbulpa4x0Tkw\nDXHcYgDA8wycK2B20AFbuRQjHmTEkuz9bWFpfdRaNi3xr+/2YtOeCnzwTXTyQ7y4XGMLsLUxjCPl\n2mDHvjBbsgw3EoxjxO7yObVHHpyC5VLJz8mE02E9WjPrBzBdUkbgWoEEzhnSOmhVNK0IIbcSYm+t\nhARTBW3tCHswWD+nQ83U0nEVBzgOCEohHK30IygH4RHd5voeKq9nYhnrkSsOK4iuhBGKKFHBd6Md\nlfV+lNQYCxoJWhsBbU6JLhiGiDiYdi2tI2Xgs7UfrFw8AEqdJlyyKkfFBKDPPTHmsLgztG1yVU8z\n/lNcXY/NeysAm8Xi6H0QzgHbAFiikOnMwIotxZAlDuAVyKrNJSW5ENl9AQCY14o5D0VQwBhDQ0Tb\nluPMRnWoBosOfQ8AqI1YmVG8JwDnoDWQnFp5lkdmr2tyOk7P9LILPOeyRpolVQFsO2C5mzbzX6Cg\nfJs114VTwGXoqdVqdCVae1B5wY+F2HbAKBOjwtF/G/icCpTXBPHOAm2kaYxo2zoPIiwp+GTJXrz4\n6ZZmR66+QMQaVNiIlSUVq6wJn10F8DIW/FioXVeWwLm076W5GIO9w20sllHt6LYP7iErUMNrIh8M\nW22V4gh2MCzjr3M24tMYa+nEw6jj1bjN9vRi+/ca18JIUDBifSfJckOSYLSS+4ZOw+DuZ2NI3tnm\nZxzHobueqgoAORkZeHX6ZfjddYO0D6LmbAhwOQUInPUZ5wrYgujWr0LsZqXpCjxvncdmfXRvGA6A\n01JyAazacQRPfPkJwkoEHtENgRfgFJxgumWh8LovVhHNa+4trsK2I0Vw9tdcUizianQt/Yet8qb4\ncYJiWQL6viLTXW9iBJzHByGnGkpdNy1+oAuNpESQ49QzmcKZ2v2qIsJ6qrA3Wxe0iBuRQ/rzE2Ts\nKKwB74ox74FXTMEoLAri/a93aRYZr0BWmCnsTHFA1WNBxrViWxgyFJuF8asz/xcCJ2BbxU6tVlgk\nuryIkFWLqowt+sEx3B+6YHA2weAzGsx9V24rwcb9tnRovhabyreas/Yd/XbAfY42XyUsKVEj68bB\nUmPlQj6rBmLeUbjO2IgFPxaiqEJfo8V9bILhC0pRKx3a/eH7iuKnuG4J/gDPsG8BRzjKqoll4TRO\nGdUBXG0AACAASURBVC1R9sN15no4T92GuXrm1OJD38M95AfwXUpRVF/S5BwG9ooHDXql5IgiYfb2\nj7CzynIvylnac68QdkNVWVSBRimOBdMQiEB116ImmHhRQEMMGmui/bu0W2OqKRjG/CRtv+bSarcf\nrDLvO5Y4JGtOCwlGKxmQ2w+3D745yqIAgG4eSzCcghNOh2CO6KIsDKa5pK7rP8b8iBOU6DTdxqgC\nnCJvm59gCUbEGHTLxjbJHNn3d2mi5hE8UFy14Lx1OJS1WGuGKpjn++KHPVEdGot4rJnsvKKvMMgB\n4CAw2xyJRllhX63UgsOcI2LGIVRfLgDOtEIiTMKuI1r7QruG6seL5qi/78n6XIxAllXt1yitos+F\nUX1WlVoICvy6NXDoqCEO2tK4ZTU+lNbVWc9Hb0NIDmPzngrUhzXTfdzJ43CaMNw8X3lNED/sKAQA\ndPd0g4Nz4VBFDZYXFKNe1s7X3WG52VzQ3F1mMNyGZWFYLinnqdvhOtcKvBvZVgYO3mH63oXuWiE8\nMf8wnl/xEd7+cru5X60vjltEjD2qNrJyYsUwGGPYWbUHSw+vaLLtyffX4+G315hCYff1Hy5rwJqd\nxZAVFXW+6LphRdDmyAi5ZVGxhFgWRuNRcbWi/Zb4HMv6MuYNuU4rwMvbXjMXHWuMPYXYsDCe+Wou\nNpVvxZtb3zW3CbL2vQWcxXhj0cqoCgZGsLix66vSXwf3OatxOGdhzGvHwhDoxnEa+3Owx3tMC6PR\n/cQTjK37K/Hiv7fgrXnabyOWNiQrz4EEI0nYLQxXIzGxp+ACgNspINPpxbhTr7HtE7+sF1MFnHda\nd2t+Aq+YgdeI3tf07qIFcjlRMjvy+Qsi+GFrMer8WkfrPtsWuJOtWeWcIJsBdqUhFyycYbuWNlHN\neJ/tdWvioQuJ0T7thQAmi5pLqpGYGBZGfTCA0toG/Tjt56fKvFkKJcRrI3g1mNVkAp8xW1wqOg1y\neV+z7XVhv3VP9mtyKg5X1gJMqzZsuOB2HKnAa//ZhrV7tNIjny85gp8O6B06r+CNudvMLK5sZxYU\nSQAnyPh+81EE9LIpl+Zei1OyTjKvA6Dp5E6j7YIE3h2AGrTWi+dt4mK4Ag3CShgRSdFcEJJ2T0Ju\nBUqFHdhSaQW/lx9dYQbW7fEd3tYOzuU3S9JogsFQK0V3tC99ugW/+9v3eH3L/+GLfV/CF4m2vCrr\nAnCduwKf7JoHILpU+vyCDfiw5GX88V+f497XV2LpJqtqL6enR/NZtVHHxJrpbc1jYfj20DL4mR4r\nsv1d5Dpzo46xx5PsRFkYumCUBLV2uQUrAYVj1t/lT8HVUccZFkYTwQhoAwZFaJSFqBOQAiitDuDl\nz7Y0qTzcuMO3p2LbRbixBRavirNBlb6WixGzUE2LxroeBb07GD0yojOoAFvVyUZi4HZqP9QshxWI\nRTMWxjmn5GHcyH4xXVL6GkxwGX8IeufPFB4Ah/e+2hXl+zeQy0+K7pANwSg7Obo9hktK/+PPyXAB\nqhgd9LYJourPBp/hM+cqGGJiWBhLtxy2soxsM+IlFobY8wB8rBo8eM2NZVoY+ig9s0ZbA9yfY2uf\njJqAXnZEFwxr4SpN1DgmAuDQv4fW4dQFtM66rF7vcGSneQ+cIOvxnDB4JsIjujWxE2R4XCIkpnVA\nua4c3HXerdrt81qasxEEz/QPtGacCzJ4XeiU2vwm3wOXUQehizbvxUiPNmJKOV5nk98ObDXMdsmr\n4ehVCPf5S+GyDQbswuUe8gOcp23W2uVxQMg/hK/r5mBd6Sb4ghIqa4PmHAiDssbzVBwR8J4ANlZr\nMZoo102u5i6K9NkA99Al+HanlW4tqppA8hn15loqQPTI+l+7Psc3hUvMDlroWoJ5+79CNbRsMuO7\nNObG2JHi1AWzz9iuD4Qwc/1HELtrLqwcwVZpmbfNE8poiBKMQCQIX8TfpIJybSh+ActN5Vvxxx8e\nxxNz/4ut+6uwYqt2TaP/53kOBeXb8O/d88BYdLmbYEiGrKhY81NplHvqQHE99uhuP9mW1LC35gD+\ntesLKKrSRIhCakBzC9s+TkahTOA4VKs9UejltToDI2X27P5dMbBvDnqc0RubGw6a242smkynJRjN\nuaT69chFZobDCi732W8GkcMh7Ufn5LSAM++t1zp42/mYypkZXAAwKOdsbLIF2DleAThj/oXuRtM7\nJtcZG8Fk69ouhwCmCOAzfHD029Gk7XLJAAg51RC66nMXDDExr2V3ZfFR2xwn70GY5SLT6YWf8dGC\nxqngM+t0V5XDFkuR8e/lO+E8xSbM9vsSFPD6SHLMsH74oASmxcIEW0kW0Qjm6+a/IwQnMsBxnGY1\neRR4XDwa5AjAAR6HC27RBQ4c6vijcJ3lg1KjLw5Wn49Tc06x2m6Ufom4oFTnm+nQAIOYd9Rsg1qb\nBzAtxhKWFGRlOBBq7F4y53hYHQDnkMA5bEvaNrJ0hBxNELxuh/m9fLjzU8gNXyOy7zwA0Vl9ZYEK\nDMjthyUbi7CjsNoUQkCb7GkXDHsRTk6UEcksQllNAHm5HiicdhznDkRlIBmdl0/yY1WxJkKXsf+n\nP/9GI2H9O919uBZ1IT+YIkAqOg3OU3aZKdONsXf8FeEybD9guf/s7WC8BMY4qHXdIORWwu+zLKvl\ngU/x1cpaBNddA3tc0V5aRlEVCLainKuOavOVkLcfqOhhDgzNET8PzNr+IQDg2n5XRnXitb4IPl6y\nF9/bLDQAeHrOBvO1PWj+8uaZAICzup0Bvy1lPBiWsZZ9As9QCcr+oWAKoNb2SNpERLIwkkSvTEsw\njHkLToeAGTddgFN7Wu6qHrkec0SQZROMU/O7xD232+HS4iH2UiM5WkZMSP9bzuP6AbILYs9CcKIU\n1YkbqasGP+3X/fpGp+uIxO3EAd3Npb93OQXzOMOtkuG0pQ0bM9sNq8ZMCbbiL11zdNcR49Et22Va\nLwAQVANwCa6oY3i3D5zbD45X9Vnv0efjmri/LOuI4xWAieA5Dh6XA1BEy5VmCIbkjC5HwqmAIwIH\n00bITI8PFVXVotrvB1M5uByi+T0D2sREo2MNBUS49ew0TpDBGWm0iojIgcHWvBxeMcUrvPsCaAkA\nDoSVMMKSApcTUUJg3C/QNO6hfWi4xmIXRfR6RG1yKLSOn8+sg5CnB9xtAfsVR3/E90dW4p/f7sHm\nvZVR82nW7i1sNAksenTr8zE8/PYarN5eAlUfwXO8ioKqzeY+isowf//X+GjnZ+ZnZufJR1sNxiDg\nHx9vRqWvQRN3Q0SKLMuoMliF8oAW77ALRkiJFpWGUNDswFVeAhQRql9LwqhnVrzEcInxuRXgnNbz\ntE84/XL38qhzu0Xtd8tn1kHIPwSfXpGBqQyctxZlmZZwheSQaWnxORX4v6VrmoiF7SmAzy2HxJp6\nCurD9VHJD5V1ISjQ3gsDNsF1+iYA8et6tRayMJJEpiP2WtoA4BasDvW5O6zO236M8WMDADWUEeXj\ndotO8BzX1D0BmCN4t+CCw5cPKfuwlhoatM7NAtlQ6rtAyNbjA4rW0ZnFEXseMjv6Jp2u2SjtmHNP\n7YbdpdGjwAyHC8afUeM2mi4po3JvVg04XgWTNZdZbpYLDbZ5CRKLWM/COMbbAOdA3dWhOKKuY3eN\n9evRBQfqbBaP7rrjVDdEgYPLoWea6Z1uhOkjcVs8B7wCzhEGxwG8rHWuiqRtq/b74OQVPQlBvy9b\n7MCwJP535CB4jHsQ5P+/vTMPr6LK8/631rvl3pt9D1khJEAgAcIWdmQTJGkWhRe1WxRFWxRwQXrU\nntHWmcYHp/vpx8exfbrtxWec0R573ufFcXoGX0VfEW1axBZwWFQSIAkhZM9dquq8f5yqU1X3XiAo\niCT1+QdS66lTt36/81vO77AEAKLfR+tJBu9qpsrOsHZCPvbM/dEQrVnl1kujdKRDSG4znxdm3EPr\nCYJP6mT3giLFVVY2FF6SR2Lns316jS2rUmjsPonG7pMAP4+6Hy0Wxnst/w/DBLNEjnWftX0ff3EK\nsHhc32l/E5BmAVE3FFXBn/Q0ZQNjBMydJ2BPrx21pXwrxDz28T3/AABQT1SiZkQGDKupX6XfEYnI\nAE+gIIrX/u8x1FXlQOOiIKoIEnXrx/YDoHEeA9eIv4BEJYQ+ngPALP0PAH86/SamFVYj2RXEZ1+0\nQ+bNb1guPISmaAaAcmiEwD3qA1jzqvpVY+kADa5yWhur/8OF9gcWopCKPgNUEWJmE5TmQrx3oBJ1\nVTnskLZQO3r7zcFmbyix1ZVojs7XwbEwLiO3jVqDtRWr4rbHBcF1rBaGWzSPCf91KrIj49jfXsk+\n4rahj84FgYdILItBxQp8y7lMoFpiD8aIlQniGEFg3Ke2IhMen/3H55XNQGJlfoZtH2LuJeV8gQ6t\nlSkgr0sCHxNjMRWsOXrlPb0xbTfdVUYW1aIpZbhr2Sh2bcMlBU2EKPCQJUGvEkyPjyIEmXMB4G0T\nEg1hzOmCRInw7F6Gu08U4z8do8+mVRSy4CqnB71pmyVb2zneuk93BSoCzhquEd1S0HqDCB80srgM\nhUGVgtKWixJ3JbuX26uAEzSo58z3YMR2fB7pvAI+UZzLmHhpVSZfRj/F7tC/xD0zQxf41gWtzOvR\nfbGjfoAuGMbJ5sRHdg6bx0EAga4Zw+IaMCbVmQJeGHYQn4TMkf8XZ6gVEvmqki5Cxit488MT+JsX\n99K5SYrI+j6il+ePrSLASVHW7/2qPSHgz6cP4L/+3IifvXYAB0/YYz/tfZ147vVPcbYr3hrsV0K0\nbIulb901u2wWlpjRCDGtGWImtQI5bxd+9QYtxmik5Z/oakJbxKz91hEyFZoJGdBE0IHgKIzLyPis\ncZiSMyFuO88JCY6mglHk6Y/VI1mFvYhUyfzgPYbLRxMR+nQa1G5rtggVqqLAsTgGEB8TIYawAphA\nNeIVNoyPsTdo20w0ARyoHzys2T94aymUZdOG29P6mIURG7wVWLtjS4hYra04Yhan4t294EX6MYwp\nzqJ+Y0MR6rW4NJWn/SPy9HyjbDofhgS97WxCogroQlDTrS6j76Sig3pCgZBwohnn6odP8kLgBfMZ\nBMUsPBmbxaUrO7ZGivFcutDo1Vd31HoDccv9GnEKEvGYmT+CAneSns4c8qH/41n6OVTo+twSIEbg\nhh/jhRvYOQDAiQlcXEaJllilYH1mfaARbRyuX4ee09qjB2rbzNGwoZxCoEJ3Zv5UjEmnc23CSgTy\nyA/Z+ZEv9Tk4hsIQFBrEVU3XrFEdIOEETMM9Z9QzUyQ9ecFUQLTvRWaBRbWw7ZlslxMj0DQN3Rq1\n9JRmGqP6n45jaGmn76IzRli3dnfhz5+fscV8DPqVEBRVs/32OTFqm+BpxBUNSJi6SOn6NvT5jnQc\nx1fBnWwA0R1O3BeOhXENcb4qkxzHsUwpm8IAkOo2zUyfvu+Bm8ZhTd14WyE9A1HgIXNm2mZsKq/N\nOlHjLQwDluranYolefXmDo2H1y0mXLbS6zKVkUcWY6wZe1mT2O0CzyGpc7Rtn0s4v8Jgqxhqpjst\nJZd+ZC5RhiQKTFm6hlM3lhrlIQg8C9hTIURAhAibbMgEsqufCS0lIlF/uP48gr+DChNNQDDB+he8\nKwS/TGMsIi9C4kXwvk5zBr0S605T6WjW1l8iOEEDQNBJ6MhR6w3aLCoA4APt+r4AzeQCVSYun24p\nhj1A1A2tJ8DOccs8OCkMGR5oUaOWma4wdMGlfFUJF9HnlRgKw514ZUYIUXBiFEJvJpTTpbRqs369\n3sAh1nZDmRjtCGnU+gjIfrgFOsiJkAh4t+lKU9vyqJvUSEKwVihg82ki6O6L4Me/j587wgpYWtyB\nRBV0a5SwthhVEgAgrIXAwVSQysnhiJ4u0q8XxeftxxHiu6C05SB6YiQIAXojvTQOZ7lXibfc9nei\n2e4hJQRVJXHK2FYRQE5sDfZG+0BAQIjlW9Sv0xOhfcsSMAA6V0nrwtuNF67eOxAchfEtYA2OxmLU\nP/LGKIwMrxko9+gun8qiVMwdn4/a4QVx1xF5Dh7eku0SY2EE3JaYBps3kaBdFuFV4LPcR+Ph0yd+\nzR1mX6PdWgbFJQuILYUCIP6jMUqgCDwy1QpEjlWxXdaYT13WjITnWe/Rqa/V7RJkyBIPrSPTdoqq\n6BaGJACKTOMT/nZwHEE0bM+sEoJn2WS5aJhHR3fY/mECyEsNskmZPxi1xrYvYEmVLg0WgxMVljbL\n+tbmTlPs/WVZ1pfz9FDfe9TNLLTcLBe23TIWQuAcDdZG3eZgQ1DBu43yLh7WT5yg4YHVVYiQMDie\nQIIHasS0cgBTKah9SVCbS+g2MQLe3w4huQ2C6mGVio1Kxka1ZD90a1iRqFXCaeCDbbS6cVuePd4E\noF9feyQgm8ouooXNwYD+PoxFxgBAyGw076G/q7AaxpHTbQgVvQ0AiDaOgHKWWjRs5C4amXAyJE42\n+1ZXQD7Zg1HDMlkbkv0uJqi1sAuwxHi+7KJtUM9lAeAAVaJl8I0MJDEKLezBkgK9phqz0CxlhCTq\nau6N9OsWhm6hnSzV+9SqMOh5N5Wstr0rowqB2poPtbnYdq8e3RWodadAOZPH+r1d/hyvHvl3fFMc\nhfEtkJdEf8RjM0bH7TMsDLfooi4TnaJMM188yWVXJiWZMXECUMHrEUylIPEinrl7KsaU0OvwmtUl\nZZYhCR+eyASAfR/g4k2LhRAeyUm0HfWli/HszCcROlCHaONw5LoKzXMkIa4UCkAtFhvEtDACXpmu\nBWJcw+KSWla6ALnuYeZpTOjGW0cuQYZLFEAiHuYyMNpAYxg8lDN0wp9r2BEAQJeRJWkR2oK+RGxP\nN/DBwRZz3ghrn2ldTMgahxJ+PPvbGpdaVW5aaETjLMrO4l4SFHbv5CSZKYZ5tdmQ3SqIItvbxysQ\nfL0Ap0HrplaoVzLjJWYBSSPuRd/7sFwP2iI0pdZDUqBEeXYOADZXROtPQl+v/nsQo+D0kv1ZkXGI\nHKmG1pcEcATJfpkpjGx3nv5cIjghimBArwLbka6P4O3W0alOah0F5CTmumtKfSMmRsfRd6wPNIz0\nY63Pb1EYUXzRaVb/JRE3KzLJSWHw/rMQ0/T0bkWCZJTQFxQ2CVRWk1i9NQgKAl6ZlaAhETfrf06M\nsNG7UUyUKCJCagjhqAo+0AbeFQJUAZnBII2b68rKmOXv66zAoQP0fby57zhaO/qZwtB6kvX0bdou\nztcBIaUVhAATc8aA0wSmFA6dpXWsSNQFrd/IxqP7enWFYc0mg6BAwYUXkxoojsL4Fkh1p+Dv6x7D\n7aPXxu0zBIwsyCxv+4l1tSjLM2MIse4qq1Ay6Asr8ImmgJ9YnoPUgNucvGONIVhy57WuNFbKnGIq\nD06zWwrpQSqYeI6n9alCSVBOl7J2A/a0W8AapJYQ+mSGZTttgyhwtFS3Ygphq4XhcYnwy5YMNEPo\nRjzQQubzcuAgCRJy032oHp6OCaWmdURUEQLPQxYFaJ3pVGD79NURozK7QviwPf5EVBH//t4XUFqG\n2awMq8IAAAHm3wHdJQUAGZ40UxmrElyyiCdvn2S+CykCjifI8Adw45wyOjlTVyZzJmYjrIbN2JOx\ndK/SiLcaqQtGC1NhkSSbAk8TdJeUIdQs5VBO9dMRcpKWgUiEp7Emyyx6EpUBRbYIySgTRElCgM6W\nD3vAccC9KyvY+i9FQdrXdD6LgtxMe+wn1sIw4i9BVwAe68zrGBcM0QRbDEMLeaC25bPf1LmeXvzH\nXjNIThTJJuDFLEspeSLArSefcLzKhLhXS4NsJCiICqKqBjFJXx2yz2/ri96IRRgDgCIhrPUjElUh\nDTvM3lPQ6wKnSbqA1yCX6bXGoh7WFz3hPpwJtUIqOKK/Lxkk5NNTogmkAlrZluMAt0tEwO2DKNP+\n+/cj/wWoEtS2XNbHctkBCKmn0aN0mX1h6XcFA6+ueyEchfEt4ZeTErqmDJcUVRj0BWemUEG4bvRa\nLCmeb5scBMTMENcJeCX4JGvWlV3J8Kop1OIC4uGYcuw6GiFmmwmf0G8PmAvBALDXvALsEwiVeCtH\n4HlMr8qlAslou+DC/SvH4o4lleA4DgGX5XmNaxMe4QPTWbaIJEjgOA48z+He5VUYV2h1pwkQBU6P\nv3BmCitAZ3kbh/WZwt7WXk1E1OYys/cDT8y/rcqc53hz8StFgqpqyE33MSEuF1IhU5CejAW1w+Bx\ni6zv2kPnQECQ4be3CaAzigGwa/tc9P2JGY1QjOKSMZbJrz57Gaf08hj97X5EIhqdk6KnJfPufmj9\nSRg5LBmV+XROESdGzPiHYMR6aJ8E/BxdD4RwyPYnY9a4XL1iMkEwzZ75Zfw7rzYbnLsHQkYTBM2N\nTG9GwgQHI1gPjdYE41x0Do4R9F08kbphQmrETEuOuOhgQFeUvByBqCtytZNa2Sw2JqjwpOjVAfqC\nOHUmRAcEQhSn2nrpipFRN6C4zHIzYoSt+24qQgkqVOw+0MQC1NHGEeB5DjxxUYUhmgFvviPPzJQT\nFPsSBRE3tLCXPq8cYuO26CnqHvRJXmhSH4TMr6DyIag9AZCI1/atyWWf4AuiL1+rSKYVLijQuPOn\nK18KjsK4yhiL+filJDZSFwX6a6nJrMKi4nlx5/gkc2RdNzYX96+swrQxObYJdMYo2Ai4c+r54xux\nghIApo3JRllekNXI4sRInMJ49NYJmFOTh+oR6Wwbx3F214LFmkm03e+VML48A//4w5mWtrtQVZqG\nKaNpgb9kt2VGfIzbIsVFLTGZtygjAMkuUwEZLin2p8UymVJRgBEFetaZItvjFVZLyaLs3DEWn6BZ\nFYa9L41ArtaTjPJhKXHXBQCvaFhunF6sEfiwmU64Ks1Kwy/un2GLE7E26QojoK+/wid1oVdoocrC\niE/pM6e/6mpEe+QsSFTG4S/68HljB+1L0eLGCnuQnuzBD+brylGMMjcIc9vo/dCn9FOFokjweiTc\nsnAkND1773/wLr0ey2jTkwYkDby3GxwH5ChVcAkyunrtgkwLeYGoG5nJHowtobEF99h3bf2W7NXf\nn2XiY7SxHADHLEZXcjdLj40co4t7ifpvhE86B9HTD6JIaDsLLJtabCpPMQIihaD20PdoXE/K/QLH\nO7+09QH7Tej127SwG1oX/RYkuMHJYXNNmDN5UBXebm3pbZd68gDFBaL/LqXSA+AFDUTlkROh5WKM\n9VjkokO2e9sGYRZIxLRmeF8HVC5qq5D9dXEUxlVmWm4tvl+5GuWpZSjM9iMvw3fBMsaAXWE8fMtE\nVJWmQxR4uCXLaD6m2JgxpwBAvMJIkHW17vpK8DzHsrU4Vz8CMQqjOCeAtfPLael1C7bgpW0mMGdZ\n4J6ek5mi+2CtzxxTHSLda5kFH9P2FDcVUrFzXZJdZuoxUQWmhAEzPREAJpTms7LfAMeCnDzHo7LQ\ndNVZLTSPZL9Xrs+sXBuIcRdGvhoJEpUw0lVL54ggPmXZrSuMmhEZmFkyDi7ehY9a6Mxoj+iB1y3i\nb2+rxTBSYzvPUBheyW4hcqpl5r3lXXRGOi0uOIDjCHhXP3PdkKgMTSPwy0m0/Iv/HBN4xj0MIdQX\npdlkRJFoZhyA22uXQiAy+jU9vVQVMboklQn6kBKCKOuz7BUZR5o68L8/OGJruzEq9rpFdITtpdON\nZ/G4JJqRxZsTH3m9ijLpC0DrS4LqPwWS1MbaAQBpoHEtKfc4iEDb3tMfxeiSNKT5/OBFBYvqqJIy\nrG4jeQAAQoSWJmHKWBfW1HWnmNYDAJ8eT5RLP2FtiCrUqiME4LzdZu2xvjL9nvR3KfjPgfN2IuD2\n4cHV1QCA1n5zFjq9d0xsy0L0VIm+3DPdJ+UdB/GcYwkG3wRHYVxlZEHGxOxq8ByP7y8aiR//YOJF\nz0lxJ2Pl8GXYMv5u+7VkgZWdOBemPvrKIipsq4vy2XFG/MDjElGSG7C7aGIo8NOApkeSUD0iPtie\nkPNU3r11YTmCxshfz5PPTDaFtzECCql2X3a2pU6XITSeuXsq/v7OyUyhSTEKI9Vtr2wqWCwM6yz4\nJMlnsz4Ml4ZHcCPoNa+Z7DUtB3dM2u+SCebaKF6LMgcAtaUIoY/nYFp5iem600SED9WyY9J0K04U\neNx8XSUKg6Y7zbA+slK8eHjuTfh+5Wrz4npbjbXMDXjVFAzKqVJInEWBxKz+CIBNliNRF8JRFSIv\nguvIB+8KQUg5oz+XBwGfzEa0Lx9+FZwUBYnKdAY9gNqKbGQlmckNRJFQlO3HjrtnAQDeP/0R5k2l\n7qHjjf14+vd/gdKab6tHZQT9PS4RTT2nbM/F0psFDhyvlzZJp242v5EFSMzEBgiKTcCnCFkYnlwG\nTg4jTPoARURuuk/vQw8kt4KRZW57PxEeoU+m256J/V+1KAxRQW5yEE+so++1NjgbRBHB+2g8ROZd\ntMAi4aGeyQfv7oOoZ34ZvycjC83AL/tYSfqCpFzbPkPx2+ZX6Rjl/+OKnl4gXX2gOArjOwTHcXGj\n9fMxq2AaSoJFtm1uSWCZM8aodU5NPratHY/66aXmgfoo/b4VVchJ8wKqBKGtDBVCTAorgMXF12F2\nQR22zbyDZkCdh4fXVONHN9NsIbUzsWLRNIKyZOp/NjJsMlLMEdyDE36IMemVmJpbazsv22dJk9U/\nghS/C5kpXqYwYl1SAi+gIf8maCEPtO5UiJb5I2qXmYHmO4/CUImKJI+pMFycKYRjlZMk8phTMB0i\nLyLHZ0/ppXA2C+p7M0qw1KJkioL2NGlrIUt3zKhwQtY4TM6egNrsGty3YizuvGEU0j1p4L+YxALs\ngmY5R5Uw1jeV/TkqP4cJoegXMVl7UReum0DbInXnsc1EFeCWJTyxrhar6uis8rMhfSEpRbYp46DF\nFUhUEa3n+pnSA4C3mvRZ2Cy+ISP8V7N9rNSNLGBa7iRb8wxlIvI8KyjJe6k147aU5bAtiWwRxpkc\nvgAAFtxJREFU8KLAI91jWqsF6SnYciNNcy0JFiGqKdjXoseHrNcIe01LzSqg9Wu7KmgBxaxAEHkZ\n1MIszciB0mJm6vkkD1vRT23T0131+AYb+SsyIkfNWJnVcrxr7A+QKZjXY0kiSoLBGYvL2U11NdGx\nl4ijMAYRLklA9MtRiDaVob50EQBaUrksP2gTikb8QBJ55iKSzoxCTdr4uGvKgoQVw29ga4Ofj/Jh\nKSjVM7u0zgxET5aiyluHm+ePYMeoGsGkbF2p6EI74DU/wAJ/Hu6q+n6c6Wyr06WPFg0BnHoelxQA\njEorR/jATJB+P7xuy8dicc/5ZR8k0bJuQBd9zpAaZusVGGuwGyNhI25i5XtlS/DszCfhERMnEFgd\nc0umFmFpbTn7Oy9m9GhVOt6Y63Ech5srV+HWypswtiwdkyqpciFdGQgfnIwcbTRSQxVI8ZsCLyCb\nQjw3OYWlWmvdabbFqB5YMYnFc1xRM8OLKBJ8bhF+r4yZJTU2F5zEuVj2HGBXGOlJSbhuYgGrZmDF\nNlK3JB4Ygxm3LGLliGUYlzHG3GfUPhN4/GjyJtv1rKnYVrebVcBLAm+LbWX6A6yfjBU0P2r5S/w1\nwLE4H5s5DqA4RtFb331mssfWt92WQlKGK9Fg0cQy9n/N4i61Zj0mu4IYhrHmSUb/kfhBnNG3WmeG\nrSpEW/s3D3w7CmMQ4ZIFQBOhnCpLKLhMU1VflIfnMFUPLC+ZWmRzD31TlJPDUe4ej9k1+Vh/QyUC\nPhkTK7JQmVaOB8bfg++PXoX1SysvGq8xmJozEZWp5Vh3fSVumFbEtpsuqXjT3Mg6AwC/165QNo3d\niDtG3wyP6LEpU8OdUZlWztYrCPpkyIKA0P6ZSGqahYnZ1XH34jguYRbcTXOHQxJ5jCy0VyPmOR4z\n86diTsF0SDECNccikD1SYgUUiyjyIH1BjOCn4oeLpuMnG6axfUGX6U4LyH4MyzLjLJolQ84q7GVB\nBukz/67Q2y8LMlaOWMa2zxtXauu/ZIsy3bC0GqW5QXAch3UxKeW2YK3NzWO4pARIvIiiQEHcPlHg\nkO3LtGULWgcZomYpkWMZVQsCh6ClfdZvpDhQaMbXABDF7r5RWwtsbQAAn5qNGcEl7G+rRZAacGFq\neTH7Oxw2r11XUQTrEGJ4bjqeXj+ZXt+iMGLffapo/i58kg8TR9KBhW0FSgAjcvQkFMIjcthirV9g\nkbaB4lSrHURcyGUEAOHDtXopCl1hCFSQPb9lJmRJQGfv5cnVZugW8eTKbEyuNH/sxcFCFMcP0i/I\n/6pYmXA7C3rz8RaG12X+vA03zMYVVTh2shNlafkAqHKwWV+qBOHQfKy7czr6iglOn+3FzQvK8S+7\njgCKC66o74Iz92OZP7EA8yfGz8wHgFUj6hNutwpJcYCZLRuXV+EP7xzDwknDEPDJyMgwlYTVOgzI\nfvgyLLP+LQLKOodEFHmo57LAJ3XSkicWhZtvsYh8MTGboMWasQrxmswq/IvkM+s+2Xzv1oQH+n9D\n2VvbZMYwaP+nulPQHaUuKTp5sRcuWcDT62bhRx/sjrsPz3M2C8Mq4CVBQoY3jZVIJxEXinMCKM0N\n4PjpLvzo5tlY/+vfQOtJxuiSVBz+6hyWTivCud4e7NZj81ZrkOM4rJo+Gvve/QN9HsENoxwjITy8\nfBL6tG4QjYfEi5CM9FuL8pRjftMLa4vxJz1hbNPy8Th4UMNHh1sRPjgFj24owzP7fgEAaD5jqV1F\neD1BQLvwMtADxLEwBhEu+SI/CE2wuWOIvtCvrCsawz1kdWd8HQwBWT4s+SJHfnMyPelYWDgHM/On\nxu2zpqL69WcbV5aO5TNLbcfZFAYAF+eFW3QhNeDGI2vHIz8jCafO0s89mPTNA4cXQ+RFXDdsFoB4\nd9X5KMkN4MHV1XGZbACQ6gli1Yh6ZHrSURwchorCFEwbnY07llSymFehv8Am4JfPLIHSRu9dmVJp\nu55X8uDGEfVIknwoCgyz7Svw03MkXrQLe1jWvdd4WMvSrJk3nC13y3vpxDPjt2i9htXCAOyJDR7Z\nFLRBj0UhWq0XYs+ei7XCc7xm7GjNrFF4aHU11lw3An9zywRwHAe1dRhIXwCFWX688OBsFOcEkOIz\nrZzYZA2rS2ndItO11tsfZen0RsxGZoM9jrmR+hR7xV+XJGBTzQbMyp+GAn8eJlZkQuA53LGkEmmW\n2MzyGcNt5xluViNu+E1wLIxBxMUsjFnjcvH2fjPzJLZEN8dx+NnGujgBeqncNHc4GqaXXFyBXQY4\njsPS0oXn3WcQ65KyYk25BZBwzkN3H7W+aisSBbUvP8tKF2FR8bzzlsa/FFySgJlZU21Kdd0SqgQq\ni+ohSEvhkz22/hpdnIZfbVmMzvB0c20PCzPyp2J63pQ4l2JxsBBPTfsbcBwXF7BPd6fiq65GWiLe\nQpJHwuphy7Fj33PoPkWVeWYqFbb5fovCVI15SnpKttda0ZkqBlUltjZpFrcaIQRZlnO8Me0zrpfm\nTsG88YWIZerobLz/12akBszz3LKA6MlSSHnH4jKZrO3IS0nB8PwuHGnqRG9/FHmBTJzsb2QLZFnL\nAkWOjkNwxP9gYdHcuDaUJRezxJH0oAe/fGg2ezaDuqocVgZ9XFk6XMHJOID/w9xq3wRHYQwiLqYw\nblk4Emvnl6OxtQefN3YgOzU+ZnEhwXpJbfkWlMWlYLikEiHFKM7Z1Xlxx2xcUYX9R9owZVR23L4r\nAcdxl0VZANbRazwXs5is8Y9Yzhd/ssZCrFRnVmFf6ydx271uEQX+dPy07m9x5wc0iypLz57z25Yx\npuLKr7/L2QV12NW4GwVJuZB66TtU9USFm8ob8Lu3Dphr1INaGLIl1hUbjJ+ZPxUCL2B2QV3C9n9/\n0UhUD09H9XBT6QR8MpSTZfBHCzB+9riE5wF0lvniyYX42WsHMG9CATwZSfiwZR/bb/0NPnvnXEji\nfHuixkXgOA4NZdfHpc5WFKVAVZOx973ZQPT838BAcRTGIOJiCgOgftzCbD8Ks88vCAYjfu/5Pxbr\n8pX/eG9dQrfO6OI0jC5Oi9t+LTCQ38W3wdiMUZiQNQ4dZ2R8atlupJJLotnOjGTTXfS3Ux7Gm4c/\nxFu99L2kJ5vK5KfTfwwOHP5jD11kyBhoT8+bgvTp5XjtnWP44nS3vo/uTPekoa3/LOy5azQetrRk\nwXnbLwo8xpfbLUyfW8LT66cgySslVKC3j74ZRzqOISAnYWyZn8ULVS0NxYFClKdQi4rjOFw/pRA5\nad6v7facN2xm3Dae4yDKvC3V+JvgKIxBhEsWsKF+NBudOZhcyMIw1qj2uMSEyuJaxyV/N0KVPMez\ncvDaTILb/4Eu0+pOYI1a3aLpnjRMzpiKt7APxTl268WwwhItaFVRlIpHi1Jx29+/BcBUJptrNuC9\nU3sxKbsm7pyvQ1YCS92gOnMMqjPN+IVh7Qm8gAcm3GM7Nja2djkQeA6CMLBMxIFwxRXG7t278dRT\nT4EQguXLl2P9+vW2/ZFIBA8//DA+++wzpKSk4Nlnn0Vu7sACfQ7xGKl2DpQlU4vw2RftbC2PRPSF\n9bURLsEFcC0x0Mmg3yY8x+HJ2yfh4JftKMk1lcC9y8fYStwYlOUHsfnGsRienziRYiBC0fDyB10B\nXF983ddq97UGz3PQNHLxAwfIFf1CNE3DE088gZdeegmZmZlYsWIF5s6di9JSU5O+9tprCAaD+NOf\n/oQ33ngD27dvx7PPPnslm+UwhPjejBJ8b0bJBY8xXDaJYjrXMo/eOoEF67+L5Kb7WGkOA2t8IJYL\nuQQvpBR9Hgm9/dG45IahgFsW0NN/eSrVAlc4rfbAgQMoLCxEXl4eJEnC9ddfj127dtmO2bVrFxoa\nGgAACxYswJ49e65kkxwc4lg5uwzXTSjA+htGXfzga4jinACqStMvfuAgIJFLyuAnd03FxJGZmFOT\nf95jBhuP3joBs6vzUDMigxX4HJ5/iZOfEnBFLYyWlhbk5JiLwGdlZeHTTz+1HdPa2orsbJp5IggC\nAoEAOjo6kJx85XP4HRwAGrtYPW/4xQ90+M7i8+iT+hIojtL8ZGyoj1/tcjBTnBNg8Z7RxWnYfONY\nlOR8xxVGbIntgRxDCBlwuQgHBwcHAJhQnonjE7pQNybn4gcPQS5Xht8VVRjZ2dk4dcqcKNbS0oLM\nzMy4Y5qbm5GVlQVVVdHT04Ng8OKa0Fr6YKjj9IWJ0xcmQ60v7lsdXzzTYKj1xZXiisYwxowZgxMn\nTuDkyZOIRCLYuXMn5s61z16cPXs2Xn/9dQDAm2++icmTJ1/JJjk4ODg4fE04MhC/0Tdg9+7d+MlP\nfgJCCFasWIH169fj5z//OcaMGYPZs2cjEongwQcfxKFDh5CcnIwdO3YgP3/oBKccHBwcrhWuuMJw\ncHBwcBgcfPdm9Dg4ODg4fCdxFIaDg4ODw4BwFIaDg4ODw4C45hTG7t27sXDhQixYsAAvvPDC1W7O\nFWfbtm2YOnUqli5dyrZ1dnbitttuw4IFC7Bu3Tp0WxYMfvLJJzF//nwsW7YMhw4duhpNviI0Nzfj\nlltuweLFi7F06VL89re/BTA0+yISiWDlypWor6/H0qVL8Ytf0JXWmpqasGrVKixYsACbN2+Goijs\n+E2bNmH+/Pm48cYbbanugwVN09DQ0IC77roLwNDtizlz5uCGG25AfX09VqxYAeAyfyPkGkJVVTJv\n3jzS1NREIpEIueGGG8jRo0evdrOuKB999BE5ePAgWbJkCdv205/+lLzwwguEEEL+6Z/+iWzfvp0Q\nQsjbb79N7rjjDkIIIfv37ycrV6789ht8hWhtbSUHDx4khBDS09ND5s+fT44ePTok+4IQQvr6+ggh\nhCiKQlauXEn2799P7rvvPvLGG28QQgh57LHHyD//8z8TQgh5+eWXyeOPP04IIWTnzp3k/vvvvypt\nvpL8+te/Jlu2bCF33nknIYQM2b6YM2cO6ejosG27nN/INWVhDKQ21WBjwoQJCATsJZ2t9bcaGhpY\nH+zatQv19XSd6LFjx6K7uxttbW3fboOvEBkZGaioqAAA+Hw+lJaWoqWlZUj2BQB4PLQ+UCQSgaIo\n4DgOe/fuxYIFdD2HhoYG/Pd//zeAwV+vrbm5Ge+88w5WrjTXff/ggw+GZF8QQqBp9hUNL+c3ck0p\njES1qVpbW69ii64O7e3tSE+nReUyMjLQ3t4OwF6XC6D909LSclXaeCVpamrC4cOHMXbsWJw9e3ZI\n9oWmaaivr8e0adMwbdo0FBQUIBAIgNertmZnZ7PnPV+9tsHCU089hYceeoiVFDp37hyCweCQ7AuO\n47Bu3TosX74cr776KgBc1m/kmloAgDhTRi5Iov4ZbHW5ent7sXHjRmzbtg0+n++8zzfY+4Lnefzx\nj39ET08P7rnnHhw7dizuGON5Y/uCDKJ6bW+//TbS09NRUVGBvXv3AqDPF/vMQ6EvAOCVV15hSuG2\n225DcXHxZf1GrimFMZDaVEOBtLQ0tLW1IT09HWfOnEFqaioAOkJobm5mxzU3Nw+q/lEUBRs3bsSy\nZcswb948AEO3LwySkpIwceJEfPLJJ+jq6oKmaeB53va8Rl9car22a4G//OUveOutt/DOO+8gHA6j\nt7cXTz31FLq7u4dcXwDUggCA1NRUzJs3DwcOHLis38g15ZIaSG2qwUjsSGDOnDn4t3/7NwDA66+/\nzvpg7ty5+OMf/wgA2L9/PwKBADNFBwPbtm1DWVkZbr31VrZtKPZFe3s7y3QJhULYs2cPysrKMGnS\nJLz55psA7H0xZ86cQVuvbfPmzXj77bexa9cu7NixA5MmTcIzzzwzJPuiv78fvb29AIC+vj689957\nGDFixGX9Rq650iCJalMNZrZs2YK9e/eio6MD6enpuPfeezFv3jzcd999OH36NHJzc/Gzn/2MBcb/\n7u/+Du+++y48Hg+efvppjBo1OBYF2rdvH9auXYsRI0aA4zhwHIdNmzahqqoK999//5Dqi88//xxb\nt26FpmnQNA2LFy/Ghg0b0NjYiM2bN6OrqwsVFRXYvn07JEkaMvXaPvzwQ/zqV7/C888/PyT7orGx\nET/84Q/BcRxUVcXSpUuxfv16dHR0XLZv5JpTGA4ODg4OV4dryiXl4ODg4HD1cBSGg4ODg8OAcBSG\ng4ODg8OAcBSGg4ODg8OAcBSGg4ODg8OAcBSGg4ODg8OAcBSGwzXNqlWr0NDQgOuvvx6jRo1CQ0MD\nGhoasG3btku+1u233z6gctePPPII9u/f/3Wae0kcPHgQ//mf/3nF7+PgMFCceRgOg4KTJ09ixYoV\nF6w+apSKuFZ49dVXsWfPHuzYseNqN8XBAcA1VkvKweFS2LNnD7Zv345x48bh4MGDuOeee9De3o6X\nX36ZLaizdetW1NbWAgBmzpyJl156CcXFxVizZg2qq6vx8ccfo7W1FUuWLMH9998PAFizZg3uvvtu\n1NXV4cEHH0RSUhKOHTuGlpYW1NTU4OmnnwZAa/M89NBDOHfuHAoKCqCqKubMmYMbb7zR1s62tjZs\n2bIF586dAwDU1dXh9ttvx3PPPYe+vj40NDRg0qRJ2Lp1Kz7++GPs2LED/f39AICNGzdixowZOHHi\nBNasWYMlS5Zg3759iEQiePzxx1FTU/Ot9LXDEOGbLNbh4PBdoampiUyePNm27f333yeVlZXk008/\nZdusi8scPXqUzJo1i/09Y8YMcvz4cUIIIatXryZbtmwhhBDS1dVFamtrSVNTE9v37rvvEkIIeeCB\nB8jatWtJNBol4XCYLFy4kOzdu5cQQsiGDRvIL3/5S0IIIY2NjaS6upq88sorcW1/8cUXyWOPPcb+\n7urqIoQQ8q//+q9k8+bNtrbX19eTs2fPEkIIaW5uJjNmzCA9PT3kq6++IuXl5WTnzp3s2WfNmkUU\nRRl4Jzo4XATHwnAY1JSUlGD06NHs7y+//BI///nP0draCkEQ0Nraio6ODiQnJ8edu2jRIgCA3+9H\ncXExTpw4gby8vLjjrrvuOogi/ZQqKytx4sQJ1NbWYu/evXjyyScBAPn5+cySiWXcuHH4/e9/j2ee\neQYTJ05EXV1dwuP27duHpqYmrFu3jhWkFAQBjY2N8Hq98Hg8WLx4MQBgypQpEAQBX375JUpLSwfa\nXQ4OF8RRGA6DGp/PZ/t706ZNePzxxzFz5kxomoaqqiqEw+GE57pcLvZ/nuehquolHTfQdRbGjx+P\n119/He+//z7+8Ic/4MUXX8Tvfve7uOMIIRg1ahReeumluH0nTpyI26Zp2qBa68Hh6nPtRAAdHC4C\nGUD+Rk9PD6tO+sorr5xXCVwOamtrWVnpkydP4sMPP0x4XFNTE5KSkrB48WJs3boVf/3rXwHQtS6M\nMuYAUFNTg6NHj+LPf/4z23bgwAH2//7+frzxxhsA6BKlAFBYWHh5H8phSONYGA6DhoGMprdt24b1\n69cjJycHkyZNgt/vT3h+7LXOt+9Cxz366KN4+OGHsXPnTpSUlKCmpsZ2P4M9e/bgt7/9LQRBACEE\nTzzxBABg2rRp+M1vfoP6+npMnjwZW7duxXPPPYft27eju7sb0WgUBQUFeP755wEA6enpOHLkCFau\nXIlIJIIdO3ZAEISL9omDw0Bx0modHK4Q4XAYkiSB53m0tLRg5cqVePnll1FQUHDZ72VkSb333nuX\n/doODgaOheHgcIU4fvw4HnnkERBCoGkaNm3adEWUhYPDt4VjYTg4ODg4DAgn6O3g4ODgMCAcheHg\n4ODgMCAcheHg4ODgMCAcheHg4ODgMCAcheHg4ODgMCAcheHg4ODgMCD+P4xSKOOE0RxSAAAAAElF\nTkSuQmCC\n", "text/plain": [ - "" + "\u003cmatplotlib.figure.Figure at 0x7f97f1e98d90\u003e" ] }, "metadata": { "tags": [] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAe8AAAFnCAYAAACPasF4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzsvXe8XVWZ///e5dTba3pCQiAJCSWE\nIJGmoSSgjsg4gmCb4Tf+dCwURUdEQXGs41gYFQvDiIyIiKIIJIAgEBJCgJBKertpt59z76m7fv9Y\nu55zboiQBCL783rllXt2WXvttfden6et55Fs27aJECFChAgRIhw1kF/vDkSIECFChAgR/jZE5B0h\nQoQIESIcZYjIO0KECBEiRDjKEJF3hAgRIkSIcJQhIu8IESJEiBDhKENE3hEiRIgQIcJRhoi8I7yp\nMW3aND796U9Xbf/iF7/ItGnTQsfdcMMNoWOWL1/OBz/4QQB2797NCSec4O3btWsXH/vYx1iwYAEL\nFizgkksu4bHHHgPgpptuYuHChSxcuJCZM2fy9re/3fudy+VC19A0jfvvv/9vvq/Vq1dz1VVXHdSx\nDzzwAF/72tde9bVcvNbz3wi46667+P73v/96dyNChFeE+np3IEKE1xsbN24kl8tRX18PCBJas2ZN\n1XErVqxg/fr1IZIeCZ/97Gd597vfzW233QbAqlWr+PCHP8zDDz/MV77yFe+4+fPn8+1vf5vTTjut\nZjvr16/n/vvv55JLLvmb7umkk07i9ttvP6hjly5dyvnnn/+qr+XitZ7/RsAHPvCB17sLESIcFCLN\nO8KbHm95y1t49NFHvd9LlizhxBNPrDruuuuu4+tf//pBtblp0yZOPvlk7/fJJ5/M4sWLGT169EH3\nq6+vj09+8pO89NJLXHHFFYCwAPz0pz9lwYIFmKbJypUrufTSS1m4cCEXX3wxS5cuBYRV4IILLgDg\n1ltv5atf/Sqf+MQnOO+883jve99LT0+Pd53ly5czffr0qmu98MIL/OM//iMXXHAB73vf++jq6gKg\nu7ubD3/4w1x88cWcf/75fO9736vZ18p7ueqqq1i4cCHz58/njjvu8PatXbuWSy+9lAULFvCBD3zA\nu85I26dNm8b+/fu9893fy5cv5/LLL+fqq6/mM5/5DAD33nsvF110ERdeeCFXXnkle/bsAcC2bb7x\njW8wf/58FixYwC9+8QtvrL74xS8CsH///pD15MknnwTAMAy++MUvsmDBAi644AI++clPVllMIkQ4\n3IjIO8KbHhdddBF//vOfvd8PPvggCxcurHmcbdssWrToFds855xz+PSnP82dd97J1q1bARg1ahSS\nJB10v9rb27nuuus45ZRT+PWvf+1tt22bxYsXoygKX/7yl7nqqqtYtGgRH/3oR7nppptqtrVo0SJu\nuOEGHnvsMdra2rjvvvsA2Lp1Kx0dHYwbNy50rVwux8c//nGuu+46Hn30UT70oQ9x9dVXA/C///u/\nzJ07l4ceeogHHniArq4uLMuq2VcXP/nJTxg/fjyLFi3il7/8Jd/97nfZt28fIISiq6++msWLF3P+\n+edzyy23HHD7gbB+/Xouv/xyvvvd79Lf389Xv/pV7rjjDh555BEmTpzIj3/8YwD+9Kc/sXr1ahYv\nXsx9993HXXfdxerVq0Ntff7zn2f69OksXryYn/3sZ3zuc59jcHCQJUuWsHv3bhYtWsQjjzzC1KlT\nWbly5Sv2LUKEQ4mIvCO86XH66aezefNm+vv7KRaLrFy5knnz5tU89oYbbuA///M/KZfLB2zzO9/5\nDldeeSUPPPAA73znO5k/fz533333Ienv2972Nu/v+++/n4suugiAOXPmeNppJU477TTGjRuHJEnM\nmDHDI85ly5bVvNcXXniBUaNGceaZZwLwzne+k127drF3717a2tpYsmQJzz//PPF4nP/6r/+is7Pz\ngH2+8cYb+dKXvgTAhAkT6OjoYPfu3Wzfvp3BwUHOPfdcQJitb7311hG3vxKSyaR3P21tbbzwwgue\nteO0007zxuepp55iwYIFxGIx6uvreeihh0LWlkKhwPLly/nIRz4CwKRJk5gzZw5PPvkkra2tbN26\nlUcffZRiscg111zD2Wef/Yp9ixDhUCLyeUd400NRFC688EIefvhhWltbOeuss1DV2p/GzJkzmTt3\nLnfccQezZ88esc1EIsFVV13FVVddxdDQEIsWLeLrX/8648ePf80TfXNzs/f3Aw88wJ133kk+n8ey\nLEYqVdDQ0OD9rSgKpmkC8Mwzz3gEFcTQ0BBdXV0hC0Q8HmdgYICPfOQjWJbFV77yFXp6erjyyiv5\n1Kc+dcA+r1mzxtO2ZVmmt7cXy7IYHBwM9U1VVVRVHXH7K6Gpqcn72zRNfvjDH/L4449jmib5fJ7J\nkycDMDg4SGNjo3dsOp0OtTM8PIxt21x++eXetkKhwBlnnMFJJ53EjTfeyK9+9Ss+//nPM3/+fG66\n6aZQexEiHG5E5B0hAnDxxRfzve99j5aWlpo+2yCuvfZaLr30UsaPH19z/8DAAC+//LKntTY2NvK+\n972Pp59+mk2bNh0yLa27u5sbb7yRe++9lxkzZrBjxw4WLFhw0OcbhsGaNWtqCiGdnZ1MmTKF3//+\n9zXP/ehHP8pHP/pRtm/fzr/+678yZ86cA17r+uuv58Mf/jDvf//7kSTJG4OWlhYymQyWZSHLMrqu\n093dPeL28ePHI8uyJ3xks9kRr/nQQw/x+OOPc9ddd9Ha2spvf/tbHnjgAe+6g4OD3rF9fX0kk0nv\nd1tbG4qicN9991FXV1fVtrs6IJPJcMMNN3D77bdz7bXXHnAMIkQ4lIjM5hEiALNnz6anp4fNmzdz\n+umnH/DYzs5OrrzyyhHNuKVSiU9/+tM8/fTT3radO3eyatWqEaPKR4KqquRyuZoa9cDAAOl0milT\npmAYBvfccw8A+Xz+oNpevXo106ZNIx6PV13r5JNPpre3l1WrVgHQ1dXF9ddfj23bfPnLX+aZZ54B\nYOLEibS3tyNJ0gH72t/fz6xZs5AkiT/84Q8Ui0UKhQLHHHMMo0eP5pFHHgHgd7/7HV/+8pdH3A7Q\n0dHBhg0bALjvvvuQ5drTWH9/P+PGjaO1tZXBwUEefvhhb2zmz5/Pgw8+iKZpFAoFrrjiCjZt2hQa\n93PPPZff/OY3ABSLRb7whS+wb98+7rvvPn70ox8BwgoyZcqUgxrvCBEOJSLyjhABkCSJCy64gLe+\n9a0jkkEQ//Iv/4Ku6zX3jR07lp/85CdeVPiFF17Itddeyxe+8IVQBPrBYM6cOfT09HD22Wd72qaL\n6dOnc84557BgwQIuu+wy5s+fzymnnOKtPX8lLF26NOTvDl4rFovxwx/+kFtuuYWLLrqIT3ziEyxc\nuBBJkrj88sv53ve+50W4z549m3nz5h2wr1dffTWf+MQneNe73kWhUOCyyy7jS1/6El1dXfzgBz/g\ntttu48ILL+TPf/4zN998M5Ik1dwOwvJx88038+53v5tUKuUt8avEO9/5TjKZDBdccAGf+cxnuOaa\na9i/fz/f/OY3ufjiiznrrLO48MILec973sN73/teTj311ND5N998MytWrGDhwoW85z3vYcKECYwZ\nM4bzzjuPdevWceGFF3LRRRexZcsW/vmf//mgxjxChEMFKarnHSFChAgRIhxdiDTvCBEiRIgQ4ShD\nRN4RIkSIECHCUYaIvCNEiBAhQoSjDBF5R4gQIUKECEcZIvKOECFChAgRjjIcNUlaenuHD2l7LS1p\nBgcLh7TNNyOicXztiMbwtSMaw0ODaBxfOw71GHZ0NNTc/qbVvFVVeb278HeBaBxfO6IxfO2IxvDQ\nIBrH144jNYZvWvKOECFChAgRjlZE5B0hQoQIESIcZYjIO0KECBEiRDjKEJF3hAgRIkSIcJQhIu8I\nESJEiBDhKENE3hEiRIgQIcJRhoi8I0SIECFChKMMEXlHiBAhQoQIRxkOK3lv2rSJ888/n7vuuqtq\n39KlS3nve9/LZZddxo9+9KPD2Y0IESJEiBDh7wqHjbwLhQK33HIL8+bNq7n/a1/7Grfeeit33303\nzzzzDFu2bDlcXYkQIUKECBH+rnDYyDsej/Pzn/+czs7Oqn1dXV00NTUxZswYZFnm3HPPZdmyZYer\nKxEivGmhGxZL1+6jWDZe76542NuXZ822/te7G0cNXtjYy879wyxduw/Lsl/v7rxq9GWKrN8x8Hp3\nA4D9AwVWbekDoKyZPPdyN7Y98tjmSzovbOw54DFHGoetMImqqqhq7eZ7e3tpbW31fre2ttLV1XXA\n9lpa0oc8Z+xICd8j/G2IxvG143CN4d2PbOTXizdw3twc11x+6mG5xt+Kf/nm4wDc/+13oSiHTn/4\ne3wP9/Tm+NEf1ni/48k4F8075rBe83CNo/vcf3XzQpobEoflGn9rX+79+jv4+d0vsmzNPmRV4aK3\nTq55/I9/8SzPv9zNdVecytvnTHjF9o/Eu3jUVBU71JVuOjoaDnmlsjcjonF87TicY7hhu9BwN+wY\neMM9p737syTjh2YK+nt9D7dWaKobt/dz2tS2w3a9IzGOXXsz6K3pw3qNg0V3zzArN/YAsGnnAKcd\n117zuA3Oc3hh/X5mTWw+YJuHegzfUFXFOjs76evr8353d3fXNK9HiBDhtcE180lIr3NPqqEZ1uvd\nhTc8SroZ+m2aR/+YvZFcOJZtY5jiG1EPYAVqrheWgsHh8hHp18HgdSHv8ePHk8vl2L17N4Zh8MQT\nT3DmmWe+Hl2JEOHvGq6LTnrjcTdGRN6viHIFeRtHsc/bRb6kv95d8GBaticQqcrIH0mLY+bP5N44\n5H3YzOZr167lW9/6Fnv27EFVVRYvXsz8+fMZP348F1xwATfffDOf+cxnALj44ouZPLm2ryFChAiv\nHW9E8tYj8n5FaHp4jEzz6CfvQun11byDQWeWZeP+UuSRddn6VAyAzAE072x5iKZE4yHp48HgsJH3\nrFmz+NWvfjXi/rlz53LPPfccrstHiPCGwf6BAo3pOOmk+Nx6MkXSCdWbEGqhe6BAQzpGOukf0z1Y\noLk+QSJWHbiZzZUxLZvWxmRou+Wazd+A7H0kzOYDQyUUWaKp/rUHSFm2TVd3jgmj6pEliZ7BAk11\nCRLx8PMoayZ9QyXGtde9pusVSjq7e3OhbYPDJbJ5jaa6uLetN1MkGVdoSMcrm6BYNtiyJ8u49rqq\ndwOEANWXLTKmrbqvA0Ml4jGF/mzJu+dK2LZNV0+Ose11ntnZtm329OUZ116H5IxTXeBdz5cM9vbl\n6WxJeedYts2W3VniMZljRjfSkynSmI6FYiJ2dQ8zpq2OmFqbZGudUwslzbdmmJZV9bdl2WzenSGV\nUEknVOJxxfuOhgo6fZkidakYqYR/naV7V/B/G+7lIye8n4s7zjng9Q8VjpqAtQgRjkaUNZMbfvYs\njXVxvv+pswD499uWIQG3//v8mufohsXNd6xg9nHtfPQfZgLQny1x48+X8455k7jk7ClV51z7388A\n8D8VbbpKhvw6cLdpmWzL7mBq85SawsOR0Lw/++OlQPW4vBosfm4X9z6xlcvPO45Tj2vn33/6LBOm\nZ7A7N3LdnH+jOdEEwDf/70V2dg/z7Y/No7059aqvd/MdK+jLlkLbNuzKcO2tS0L3c8Mf7sYup/nF\nxy6vauOexzezZNdKmuoVvnvlZVX7n1i5h3v+spmvXHU64zvqve2mZXljB3DlBcdz3pzxVeev2TbA\n9+9dxZknjuaqd5wAwOMv7uH/Ht3E+88/jtOmdfLvP32W9iZfcFi5uZdfLd7I208dxwcvnAbAqi19\n3HqfiKq//v2z+c7dK5k+sZnPXSFWSGzcNci3fr2SudM7+fgls6r6kc2V+ffbljF1XBM3fHBOjdH0\nEdT8g0vvXFJ/YVMvP7l/rbddSuZomLUSuXkaVqaTz922jHHtddzy/73FO+avu5cAsKJ7JRefeGTI\nO0qPGiHCYYRmiAlhKK8BYDj+tQMZP4tlg7Juhvxre/pymJb9N/vcfBPhkWfvezf/ie+v/CkrulfW\n3K8bZs3tbyT8cevDfGHJLWimxsrNIsh21ZY+9vYXIFair/FZ+kuDdA3v8c7Z2S0ijQcOMrhJt2qb\nkSuJuxZ2De0hPmkDieNfrLm/qzdH4riXKI15ofY1MkVsoKsnrOFXmrbXjrAuf9veLADPrNnvbXtx\nUy8AK17uYSivIaWHyE+7H7mpx2lLRG4/8aI/Zv2Be3VzAGzYlfHb3LsRKV5gxYaemv3oHiwCsGVP\ntuZ+0zIxLfG+BX3uZoC8NSe+oC9bDJ0rN/Wjy3kxxoo4d09fPnRM2RDPOqkcuSVwEXlHiHAYURlf\ndDDapghSssnGtzGsiUm1NyMmt6DPc1XvWnoLB0524pL366F5P71HJF7al+/2tlkBf+PR4PN+ZOcT\nDGnD9BbD45zJlVEa/WVcRaOaaJUDBEC52DW8m2v+egNL9jxb+wDJIjZlNXJzd2iz+1yf6FpywPZ7\nC/6qHsuuHm83IK43EyasSvIeyVRdy6Li3rdhWpiWTWzsVtHGpA0j9jN4vf4KoaW/OMCSwu+Jz3hu\nxPMrz6nEf734E7723HeBcLR7Tisi1WWIH/8Cw8ZwVV8AJFXz/pbragsHZVMck1CqXReHCxF5R4hw\nGFG5tEc/iKU+Zd1Ead9LpvU5/mfdrwF/cnU1hf35Hn625k7+47n/8jQGoCoDl6d315hkNw5sYU3f\n+oO+l1eLxri/TjVI2G/0pWLuhAygW+EI6d5MESnuE0ZBD5MfgHIQEtOjO/8KwIPbH625X2ndj9q+\nl8TxYeuFu7xpW3YnALZR7QEtaQZFxSfvWn0cibzzAQKT6rIMpNbXJH+5xj2qTuCXYdqifcVpq0Yf\na12vp6Ivq3qFCVtOjEzQwf6XtDD5dg3vYcfQLnoKfWim7l9L0fn+y98mOfNZlOZetiUfreoLgBQr\nB/7WqIWyeeSj0CPyjvC64I2UZvBwwqwgU10/OPKWG4RWtze3D/AnJ3epUG9RTMq6pYcmm0rh4EBW\n8x++9DNuW/2/r9ifkfBiz2q+8dz3KRrVpBCc6IPEFyTv16p5W7bFrSt/7hFg9X4bpW0vcsv+mvtf\nCbuGdnt/l4zw5NybKSLFAuRtVCeRMi2bIW2Y32/+MzktX7U/eI2JDeOq+g4gN4nnbNvhB1jWTWzb\nJlN2TMuKUUWufZkScr2vKeZr9NGNZu/LhImx4JiWpbosyZnL2Bd/gd3De6vOryWfuEuuTMuirJtI\nqmjLNqsDNF3BsxAwZe9zTNKphAgEfMkhb9saWRjakFvtPefKe1m+z3cZ5PScZzavJGJNyZLT86G+\nACEhDTV8zu83/5lfrL0LzXnHa1lgDhci8n6TwPUvWrZdMYGaNY97pW2vBat61/LJJz7PtuyOmvst\n2+Kl3rUU9dIBr23VEAAOVV9/uf433LT0m6/6fLcfQfI2TCtEriMJMJpmIsUFIbYmW4Cg2VycP1jy\n/YHByaaSED2zeeC3bdvkdJ9MKv3ouiGIwbKtmtqWi9vX3sXu3F5W9673zgHxXILm/JAGG+hfppzh\nzvX30F3oxbQsBofL2LZd9QxFX6rHqrfYz4bBzdy/9aGq4yzbxjBMYsesIzaxtrm2ss18SQ+ZTHcN\n++RdOSn3ZkpIcX/cCjUEGNO0+c7z/81fup7imb3Lvevphskftz7MZ578Mn0lIaTJkh+xXiwb7O4f\nRG7sQ2l0yLscDnzTdJMhLYdhi/5KEhR1v49lzWT7/iGkhE/Y+RoCxLDdR+yYtfQMD4W254o6iZlL\nSc70a04MlAZDxzyzdzm7zZeRm3tInb6IPbl92LbtRZAXk7tZMfCMr3lXQB2zleuXfImCXgwJoK5F\nJp2IYds2e/Ou8CXh2pJ0w2JgqESuqDNYyrIz/gyJ414C/JgDwzK4e+PveaFntX9fWt5/xnJ1v7Kl\noWqzeby25m1ZNn/peoqVgfaPJHlH0eZvAnQPFvjCT5/l4jMm8fLOQbbvG+J//n0+Dy7bwX1PbuPm\nf57LxFEN/PWlPdy5aCPXv382MyYJ0vjNXzbzyIouvvWxeXS8hsjZIO7fIibbv3Y9w5SmY6r2L937\nHHdv/D2N+jF0r5zOj649J7QsA2BTV4Zv/t+LfPySWcydLrLzPfp8F3c/tpkbPjgHPdVNS6KZ0XWv\nLnPfc/tFAJBhGajy3/aZvLxzkO/cvZIPLZzGceP9VIr5khEycRumTUyt1iZKuomUEGSwfVeZNW39\nXhCNaxbvCfgyQ5p3gBwf2LaY4eQQ0IYkSZR1k49/90nOOGEU557lB9Zc96On+MAFM5h/6niG8hrX\n3LqEc04eiz3xRV7u38R/nHUjsQOMgaZb/P//+SRnnTiGf3nHDD73k6VkpN0kRCBxyKToE7PFnwZv\nB6A50cSG5Z1s2JUhEVcoayafuewUZk5u5cWe1WzdrPDw091V0dvd+XDw0n/d8xLb9g3xb+85ke/+\n5iWuuGgikmKCbGFaJorsE+SqLX384Herue59JzNrShvb9w1xyy+fB+CbH5tHZ3OKVV27vON/8dBq\nxqnT/WsnXkJp9f3QtUzSmfKgR3j3P72de34Dn79iNt/69UpSpz8ROvalbfvZ2TlMc32cz922DOnY\n5SSm+wKQVKHxLVu3n98/v5LkTH9btpynLp6mpBlc/+Ol5EsGiVk+mQxr1Zp3X/ol1NR+8oqBbvhR\n0kPlAnJdmNAHy2F/76833AdAfLLQqH+y5EFSPafQ0SKeUXncc7yUAzlZ+x5iEzZj2LBpcAu5cnXf\nDNPi9kWrKDrmckm2QBZC4pduX06PE6TWcEwXeJ+5ze0PvsykUQ3stzZXxRIM6znyJdFfqYZQ8Z3f\nPUeb7FtB5MZ+5PQw2BJIdugeilr1+bWEuMOFSPN+E2CjE7X50LM72b5PfJCWbXPfk9sAPzr0waXC\nf7Z07T7v3EdWiIIxm7p8Te+1wtXmRlp77EbuZhFmuv6hamn2ryvFMfc+4ZeSve+vIjBm+cbd/PdL\nv+CW5f/5mvsa1BoPFktWi34/tGxnyOddKOkhzbsye5aLkmYguf492eLhZ3d6y1hcTb7HMZvHlXhI\nU3DJ0bAMFu34CwPNKwDxvIediPdn13fTlfMjfVEML3p2l6O1PLVqN893v0TeKJAp1Q7ScbEvI/Yv\nWbMPy7YZGCqHTI1lI0jezrNP+pO1bulelHDZuc+la/ezYWAzt6+9i8cHBUm8vCus+e2vIO91OwYp\nlk1+/dgG1NHbWbRGmFslSZivg1i0XBDzn5buAMS6eq/d/gLL973Attxmb5tml0NLBOzOTeJ/UwgE\ntczmPaVe729TEtaRxc+NUIBJ0dm0O0NXTw7dsFCawgFykmqA5L87f1ixKqQVA2RLeedehCY7arRN\nLOW/vzm9uo+u8UFp2093xo84HypWa+lBzTtoNZJi4t4yWYOd3TnvGVYimbb454unM2/mKN4+2yfI\nn6/9FXvG/L7KvVHWTZZt3hHaJsU0hgs6PYNFL9+BlvYtJA2NYoy27xtClqvzIeS0PANDzvtYg7yF\nZu5bshLTxfcjmXHn+v67nCtWZ4orRWbzCIcSlZGiUrzAQCHrmbfcCdUNsKn000LtwJRXCzenkSzV\nfv08U63j56t15ZST8CS47MM1t9nqa5N+g6biVxOIogdyJQfHslAyQj5vbQTyHtZySJK7QNsMCS8e\neRd8Yqg1BtlymKzcyF8Xe3P+RCkpJprmulWcyzb4wlqmXE3ewTEKEpebgSpE3gEBSDMsUDXkQKT2\nYDHnBWC5SMQVT4iT04JUKpPT7A1EsQ+X/D70JlcTm7iR4lh/nfJAKSx8uglz+lIr+dnqX4aEqoHC\nEHe+fA92zH+PJMW3mtiYge2Oz7aW5q0NBI4TRDE40lI/xaA3U/RiG1yhIISA1hcb5QsBVkEEBA47\nhNubKSI39DM0cTGmFCDvcjUhm5Lfn64B35ozXK6+n6DmXWt5m6aJ96w0AnnbisbZJ43lX981k7NO\nGlO1X2kJC2NlKUd8imOSdueCWJk9TuKaGcc0ITf1hPz6l54nhILebBHd9L8LN2ZgWM95Yyyp1fcg\nqToZR8gNCktSsQXbkpDrhjyXVqYQHs/GeEOkeUc4tIhVJNxPnvIUNy3/ukfq7oTvEnStmsG1siu9\nWnjBOCO8fqZDDF6QTsW1C3qRVervUNp3UyxXTxSGUjs46GAR/AC1V6F5uzm7K8k7XzIOSvMeChCv\npBj0Z/0J1jQtbNv2JlLN1MgXq33Kg+UwWelGONZhMKhNK4bXF89H3uATT7YGebtL2AAKZoA43Ykx\noKGEzeYWyROXED/Gj3LPFKsrMCVisqfpuYFKlQUt9hd88t7W7U/87uQaxEDRHw/TMulveB65foBy\n0xZW9a1DM/yJfqhUYwJWDIYdTasQ8C3bloRqJ8g774zhPV+bQS2gPTvrg0dapy8pOn2ZkhfbYBuB\n4C6XuFS/j5Ll77cKIrmKaxbvzRRDhKZaooJXvkLz3jS4FSvhH7c34z/zXMDEbmbbkGyZTDDOooal\nwRVkhgvV38yY5Hh0S/e+p2DSFu/8eFhrjU3Y4AluFB33k6qxs1tsax6TJTFNuLdkW4yHkhTj25sp\nUTQDz0lLOPeV9yPTlWrNWVI133IQ0MylvTPBjCHFyyROehqAgYL/DZzVeiGtyRaKRumIBeNG5P1m\nQ0CajKsVmrdyZDXvkczmXiCRM2lVLrfaNbybItmQ9haEJudqbg9ix9CuUDRxELkAMb0as7k7gcdU\nudpsHiDQkTSUIT3Qf8UMBVaZlk1eL2AENJ+hgJbktl+pLZtWOFguUwoICLJBWQ8njwlOpBkt7PuE\nsHBQNH1hyU0sEgzyqQxYq4zyHa6hESZiCrudSHsMYbKsDCQKmnF39PmWCKxqrTWoea8f2EivuoHE\nCc+BLO47GImdr0HekmJ4VoVcYLy1DXNRiFN0iNHVzmPHrmJDYZV/vqPlDeVqvE+2BIpBT7ZAr5sg\nJBCZHdNF/Elw3EzE3+ZQC9aQKBE6rPmad5D8k5YgviB59xT6+MHKn3r3D9CdC0SmOwKKOdSKtuUU\nVCsVGsOagVmOcDGU10LzjLZsY2voAAAgAElEQVRtFi0J0Qc3ULJWamC5gryDEfbGUKM3Bjv2i/dR\nSfnHj5VEPIIuFVBkib5MMWzCdsZzqJwj4zyDWj5vggKSs9/oHYdeSHrjLzlj1p0V42V0T6TdmEZK\nTWLaZkjjP5yIyPvvEAOlQTQzaEoNkETg5Yx55C32u2ZzzSpXSemHMsmHa3IdSZu3bLe/Yn/l8ifX\nZOx+XJU+tqLtE9NI0dI/X/Mrfvly7dz6w4Go3FdjNvfIW5FCVox8yXCehY0yagd7ctVLbwAKhk/e\nUkVErGnZVcQ8FCC/2uQttO6gmX446ANWzCrNO6g51zKbBzX3vOFf39e8S9imgm0qlMxqn3cQtZaa\nxWOyPz6KDtgVS+L00Du6dzAgyFnV01qwv7XKoxYD1oOc5vfX1WpRDE+wcTXvuvxUrFwrshX3rDWu\nEKS2+W4J28bT8mwIERtA0m5Ckm36snl6B4uOUO2/N0nTIe9gwJfzHWtbT8Z2hJu8JvrQmy2FiClh\nC7N60KKU06sF3IGCP0Yll7wHRoMZQzHTDGnDXpayWm4C1zIwVNA9rdUcGIXZN576mMidviWzXRxb\n49sXAl8wsMA/xsoJ8pcSBU/zjiXEOOp7jmVS8ngAslqWtqYkvZli2IK2XUT2DRQDgmhgjLRts5x7\nEGOcSqj+flMV30dIKLTodSL0bSNGb6ZIWhWBevkaY3M4EEWbH4XY1T3ML/68nk9eeiKdLeGi9nm9\nwJeWfoPx9WP5wunX8H+PbOIvL/oaZnACiFVq3g5Db2m+l889bdO89VLv2B/9YS3nzxnPFRccf1B9\nfOCZ7by8c5Dr3z/b+1DveXwz2ZxGLqWBAiDxg3tXsdkpSPCpfzyJyWMasQhr3pWlI7tdf6/zcXUP\nFkKBa3nTn4SeXtPFI8v3ceOHTvMi1otGkUw5S4MzET350h4ef3EPLQ0Jpk9soXNyteb98PKdrHi5\nhxs/dNorWiFcYUOpMJs//uJu9vUXkFI54pM2cHfXBo4ZfQ0dHdNC5xfMvC9WKz7hqopMMbGXb6y4\nN3R8vkLzfu7lbh7ZvAncVNWySV+2xH/+5iVQyySmr6Bo+WSlxk1PAHIzuAXJ+4k1W2nMdHHh3An8\n/qmtDA6VGTfL13SDxPenZ3aI8+MlbC2JpOrsGxxC003iMSUsSOL4CZ3+j++oZ3dvDqV1L08Ul1G2\nXQ1JRBkPFzRu+eXzDAyVuPyiseLWJBnLtugeGgScSHTJH3NbjyPFNJ7esI1ZiX5mTWmrGVQUFCCW\nb9hLYgZItoK2+VSSJz8VIkPN0kgACScVpmzF0S2DsqlVuUJsSwJLDWt5FRpfPpNAaRVBcXv6ZEa1\npukPHJM0WxlmK/Gpqyi91Iytpfz2jJj4B+zoHeCWX66gN1MiMcrC/WpiOLWoy4M8u24/v35sM23j\nstDqdlJEUvfkMlzzvb/S2ZRkz2CWeDNgim9GNtPY2GTKQ7SlWnjmZT8S37+voNbqru0W53em2wG4\nc/09nNwxq3a8i2yKNtzgMMdaUVpzJraewDZU1NE76Fk3FmjwrmFmOmlKNEFJBLt2NI9l3fYBVm7d\nBzEorT4Lu1SHjEJ/UVhrmuvj5J0xHNf/DrYMlGHKWi/4rqkuTlmvuIf++fR0OMl0VIP+vAZpMUZ9\nmRId44UrIK8XSODniT9ciDTvoxD//fs17O7Nc/+S7VX7smUhDe52tJYgcYP/UUE1eXuk5Ex++/rD\n2vdjL4i2Hti6iJuWfjNkuq3EH57ezoZdmdBktvi5Lp5d3+1V7zEsg1Vb+ymUDTI5jXXbhfZkOaTq\nkXeFGd9dJuVOYNv2DbFuh29CzZm+VnnnY2vZ11/w2ga8NchlS5DDLxdtpKsnx+qt/fz2iS0hs7mr\nzdz7xFZ27B9mYFhM/I/tepIvLf1GTbO6YYj+xlQ51Hd3PIPEuLXGWveiJTRZ24gJE51kkUooJOMK\nQx3LveMkXZBVIUBGmmFy2x/XMRQ0dct+pLrascf3IzqIxy3vOXmacUwTpk5bwlSKvLBR+JT/vHQn\nz6zd7yWPUWWVklXh/5QNpJiOrSWxTRXd0ti0W5hcS3p4vNpTbRhogM2oVnE/8amrKdhhbV+KldnX\nX2D7viGyeY2N+4VmO8FJbtJfEM9/zrQOkulAycfhZmxLwlIK/OB3Ivgp6Au1ikIjdCO1g+M1Sj8Z\nu5wS5tsa5OvmsVY1oRWu69/gCUFuxrPyunnYhhr2V1eQt2vilhQD07LpaEqScuSQs8fOo8mY6B3b\n0C76Kak6tiWDrXiad1e+i+37hsgVdU8rbUk0M8Y6EXO4ha58F89u3k6uqLN70DeB1xtCECrbBbbu\nzrJsXbfXx6ljBOm675rrLtm63/fne/0P3CMO8br7zh53BhMbxmNjM1QerhKgZEMoIak5j3uBeZKq\nYVsSdrEejDj67uORZBu5LksqoVK2xHc0a2In582ayrFNk9kwuJlp08XzH8yL91wkh5EYk5jAgN6L\nlMgzujXtPceGRJpPvWc2tiV5yk1bU9IXnB3yTlvtTFBO8PqWLfrfaaGkM7vzJE5sn0FnXTtHAhF5\nH4UYcgJC6pPVfiP7gCUvCJnN3UxIru9VqTRlSbVNzot2Pk5faaAqorkWatbudYSDsiHuo9NZu+ul\nAK2INh9Z8xb34i6BclG2AxOxc7/BW3Ozk2mmVm1Wl02yAZPycCk8ybhc/IctDzJQGqyZdcowLZAN\nCsmuqr7Hj3+exPTnvd+1AuLKtiBDu5T2+pSIKSiyhGwEAn1KQrovB8hINyykZA65MbBGOKC9h9Jo\n6qItJRYgb9MCbKRYmeZEI7aeQIqXqtJn7sntJ6HEmdQwAc0uhd6Vjsni2tZwC5gKyCb5ongPKrN8\n1cXS4n1QjFCZSxcJSbwbUkwjmw8kRTHFxDyrbTqqpNDPDuIxiX+7ZBbHT/K1HqtUJywA8ZJnBXH9\ntbGhiRyfOA2A/ny1sCNZCiCBWaE5O/uTagJFlogPC3J9dt/zvrAqmyT0Nuxio/C31iB/M9sqtErX\nv+1sb29OolllpjQdw+XT30NCrqO8+RQAzjhFVC5D1T2N2y6lMTPtKI0DyM3i25Ad8rzm1I+RUJKY\nvULI2Ws6Firn2zH6RzPJnOe0GXgXnb5ceubxpBMqaOJdcZMDlZx3Ttt6EqUX345VrAuR97hRTh4B\nh/iS8RjHNYtqeHkjXxWVPaV9lD+8DQNCaFUMMGOeddF2+iCpGsm44rXxrxefQjKhcsGkc0UDDb1M\n7KzHkp3+OO/8xLiwcClt+xjVmvb6m5QTzD6uQ5i9nW1j2tJV1gNVkZkxfpTTB92L1bDNGLppM731\nOD520j8TV0Yu9XsoEZH3UQg3pWFDuka6wVcIsAp+YO5yD9eXqCgyECAb+cDZygyrOjBDMzVRLEEO\ntx3KmuWQd8kh77FO3WOXICqXihkBn7duGV6gkrf8JlS9yaZsB5f4iD4G/es9gexfwSUvUqJA6rRH\nWbTjL962XLmCvKtyh9fI8mZaxI9byZ66p9haeDm0T2nuC/2u5VPXKGDbYDlZtSTFEOStSEiaX3fZ\ndLRGzfKfeX+5j8TMZUiq4S83CvrNAwFKdllMikrM94drugmqjiTbNMQbsEsppHiJTD6Q7U6y6Cn2\nMrZuNC1JQSZBa4LWuB3bkjB6JmBbigjGGhSknXeimG1TYczwOYK8Ee9lMmVWvXMtsrOkKFYmGxDS\nipYg77H1Yzix/QSM2BAtHWUkSfKIxb1HW0tCrOwJGG4msobisXTUif5nQwF8jvbs+DhtUw2Rr/ve\nJeQE8ZiCVaynM93OjuwuQd6ShSTbWIZ/vhCgrND5Vq4Fu9jgkYvSIN7rlqYYNjZJ1dHsFckjLtcq\nIyl6IChNwth/DBAonOFcI6UmUWQJa0jYyPP0e+MNYPaOp0FtAFsKPUM1bnrnx2Iylkve5QyWbVGS\nRTu2HgdkQdJObAKAGjP8sUMQn/usl+xZ7uVkd3FSu59tJnHcS6RmPYuk6khmjNaGROBaQEwnHlMo\nODEPKUX0rTMlNN5sOYuqytiyLtwWtqC5MbFjkWwFdew28vWbkRQD25KIqWIcG2PNSIkCclOvqG8e\n8Hm7z8H1a8tNfdhjnRUThloVVHskEJH3UYxaUeFBM26tJQth8hbHFsoOwcmSZ+6CEaIxAyjVIJ4l\ne5fzu81/Iu6UKHQTHlQm+we8oLr6VIzm+rgXqeznwq4OWOst9PmEqRiA7ZVelOoyJE56GjsogDj3\nIAX81K7mDWHylOurE9G4ZnMXlZp0Lf+pYfpJNoaMAye3qWV216WiiLB2J+eA5m1LgQxteYe8bf8e\n9pZ3ICkmetfxGN2TgLDmHXz+linGRFZNz/pSMjSSJz0FQL1aj1VOIUli+ZUXSZ7MY9kWY+pGe+lb\ng9HphprHLtWBkRBaqwQ9Q4Jsi6azpGr/MSQK47wJXU6UeEL/XxLTnwtZB1plx7edyntCK0DRsa40\nJxqZ2ijiMFIteeceAuRddDRvyRcw3ICidCzFqAZB3iUr8Jwd8rZ0550xYk6kcfC9EwlyknGFsm7S\nnmwjbxTIlYre+YYhuwMi/ne/rQpSsDVBCLGJGyFWoqlRXNclJUWWQBcEVrByoh+q7hEj+OlTvWVy\nAdO+KsvYWgoZGSvmm91BmHx1A5JyGjk9jNK2l8TMZ5AdQSKlJokpMmbJ1byz/HHrwxjNwuftWg2E\ni8dGaXfW5scCPnkH7rNetm8Fd738W4KYN2Yu7516iX8/ySFQdFRJCEiiLdcXrpGIyRSMIkkl4WXO\na3LqqWfKQyKHhaO5e5kiTJVYYTSSbLFOe1ospzNVYoo4/8KxFyFJoI7ewdi2tDf/ueMcU2XqnMC7\n2Litfl/N2EEVHDrUiMj7KEN/IUPylCeQW/bXXCccJCOj1gtVg7xdYrVtO+QTr6V5u/5qqC7WAHjL\nJJTGAZANr22fvG1PA3LN5om4Qkdziv6hEoZp+ZHyjoYeLIPZEyBeSRZtuZp3/NhVyE7mLjcgxtMw\nAm0EyTtoqQhOhi4KevgeNcMK+fprJWUIErxsB9s8sLAFIg7AUHJYpTS25ZyrGCTiCrIsYztadHnT\nqZ42plt+H12t08o3CpO1c76LEHk7yT0kxcS0bAzTIqsNeoFCti152rmUKLK/wmffnGyixZkw3XSu\nSBaWpHt+WDdCtzebc8bL0byNGLph0eFoS7HxIpuZXJ8N9bfDnoptg9wUWAoGDNvid0uiBVsT10qk\nxHlFowSWjLZtltBuXXOrI2C4qTjr4knGNAvhQx29E6VjlzceAGXNyXtQSiPJlne+uz+pCGIp6xYt\nSeH3HigP+uTtnO8SnEsGleZYs38Mck6YY+Vkgfp6cZ6vecvYegJsWJ9ZizJqJ5IEiuW7GWwthW0H\nnoOsE1fiKLLiLAGVSMuNSIkCx09o8oPLzBjDeY3ZTWcgqQbxY1cj1w1jJ7NOH5IidqMk+rJjaCeP\n7XrSfxCGK4CIMY5PWRsao+A35Uac10JKTXJK58zQNkm2ScpJLzmPa2lQO/YwNGoJBb1ISvXT5SbV\nBEklSaacJaZIQrM2VRJxcb5hWpT2jQ9dwzZjnvvw2JaJIj4hVqalIeG9h+51Fdm3HoTa0OO159rD\njIi8jzI8vnOZSBRw3Es10xAGyeBHf1hbtT84eWu2+LusmTy8fKfIchUMOlGq28+X/PaD5kkXwQpS\nUsqv4OOlHJRsQbrgVeJJxAR52zbc9sd1dGcdE6ZD8oWSwc/+tI7t+4ZYvlVIvHaAmAbcDGSBicI1\nobkfYFk3uOuRDdz57OPszPjpX3/1WKAkZg0ff7DYA8AdD7/MMxt2+PsDWt7zG3q4/+ltXoY1gBUv\nB7JG1bBkDBULWJbN/zz0Mt/9zUrW7u0CyRZBOs49SorQvFVZwpZ1UnIdVqbTm1SCfmQ3iMc2Yz75\nB4Uwl0D2noDpaeZim6abFAMa6JT6qb5Glyiw28ls5b5DxbzM48sdM6yreTt+U9fE6fZxy0AXS9fu\n8zVcI0ZXT47lS2JYxXqkej/gUJIgZTfz2TmfADOOlWsRVhGnbbmhnyFpPzNaj6cp0YBWEtey4jl+\nt/lPDJYzKHoDZt94QAqR992PbWbzfiG8NSTqGN/a4l03Ptl5FxzXgusxsYvChy4lnVgKR4BKqkkS\nMZmybnoWiKyWDWhsYdLxfMoBUhjVmgYkFIe8pXiJtMNHSVX0W5Yc06/zPONOXeykEqg1YMvCwuCQ\ntyXrYc0dyA6oSDGd3WN+i5zyg62GCjpnj52HbVbTQVJJEFcV8jlQibNreE9ovyuY6Hum+hslv5JY\ncL16LeKb0Xo8lxx7MZIk0ZRoqNqfjqX9zHqm6iXsKSf3M1jOkI6F6y00JxrZm9/PQMOLYi4zVS+W\n4p7Ht1AeaGFy9p3e8dZwi5dpsi6pYhtxJFVHVWTf8uhp3lLoHqxcE+WNp4KexDBtNnVl+PkD6w95\nIaeREJH3UQYvwtGSR9C8fXJdvbWvan9wOYcZINpHV3RhmFY4KrZm1R0/GKyW5h3URCXVr9Lk19AN\nrNX1yFtm1mThk3txU6+fwcohnadW7eXZ9d3c8svneXHnDsDRLBHLSTzLgeFrIp3pDq8PAC9s7OXJ\nrmdZXlgUShm5bmcgM1egb66/2fXLu9jTm+dXj/vJN4JLjH58/1r+9MyOUCYwSxJ/nzdnfDga10Hf\ncJ6+TJElq/exbscgK7YJ4cQq1vtai2z4ZnPZICY596m7+Zb9PnqWF1MNkH+15l3YO9Zr31svr1ue\nyVnfNY2J6SkhzXt3r/PsHRLq2qvRtdsJ7nK1UvcenWdh9ApNRx23lb+8sMeL5G9K1lPWTdZvz2L2\nj64al3pzFJObJqEbFtZwC5IEckoIdW5ynrdPOFvcS05MY7uNjSLeApDxScMn7yKPPt/FcKmIbcP0\n8e00pfzJ2DYVLjhtgvfeubGK9YrQqpV0nqb6uDdeKVVohZpmehaIgVLGF5Zcn7kTeCincsSOXYXa\nIVZttNbVMXl0g9NH8b6NHS15edBd8vVQIfx1NjaGfttaCjlRIjZlFSYaKZf8HfK2HdO7jS0sHDbU\nxVNccf5xjG2vJ2ZWL29SZMVZlSJR6vOjqMubT0HvOs57zu3pZibEnWWkqk5RFXPP5I4OLxVqXQ3N\n+x+mLOSCSW8T/ayxfGx0U6OnOYPvv3aRVivJ2zGdpzYiyRa2odJQkRBmzqQp3t9mpsPLQJlMqCTk\nJHJcFwKPa4Gq4bcHMHomYGU7aUzHMEyLp1ftZdm6/fRnj0x+84i8jzJ4ZGHEvIQQQYQCoCo0ye99\n8kzGjvJfZMM2mDymkUmjGiiUDQzDqliPWi0cZEv+MqNaPu9g6khJMTzy9uoDy7XIW+GMmaOZM80h\nXOe6biajYPUeOT0szLmOyTc4oU0d3eH9PaqCvLsHi6GsX36DZu2/HXPgLv1lnt+/MnxOILCndi7j\ngHnc6d+0Cc186rJpVUdqZjkkhA3oTgnIYr0n8UuqCNBRFAkUHQXXzxj0AYoJzrUE2IbqTToN9YHP\nXNHF0idLEe3bEpYsyLikGb7mbsQo66YnxMjJgle8xBUWTC0WIka3L+75AP/90X9gUuMElPpBhu0e\n9sbEWH78nbO5/v2zAbDyTVXjIluOS8AwPdLxVg441291TNXZrF1V79pSfWuEbz1wBQwDhRjzZo5B\nkiTU3aeKZUKKiT12LRPGiOuVyqKm9LX/cBYA8+c1M3/2uEAwWIJETMEGGmMueQ/6AW+u5u1o7kpr\nN2rbPuQ6IYTc8E9v83IPWI5PefrxKe+7cjVvL6eMHo7GnzVugvf3tz8+jwmdghzV9n0YUtkjb1fz\ndp+Vi3Qsxa1Xn8Ox45qIqTIzx4r2bEOl05jBW8ecDvhLSs0+EX9QrzZgDY5G6fdzPnz742+lNS2+\nydiY7QzYuzm+ZSpffN+5/MvFM4BqzbshVs/ExrAZ28qFBZJxLS2hnPZSxZyUrqHNB2FrqVBFwpnH\ntPD2U8czPjlZXG+oDdW5P1mSOG5MBzYW31/zQz/4L0DeDTFfwDH7x5KMKzTVJzBMS9R4l6Ct6dBU\nX3wlROR9lMElC3dyrUTIh1rxosdUORSJbdg6MVUmnVTRdKeggHJgzXsoQN7lGpp3KFtWgLx9zdvv\nk0fejmTtfaTudR3ydgPzpGQOuW4IK9vmE1egv2ogAdKYOmfpiUMm/dlSyKzuJVEIEHZQcLED2ZTu\nWH936B6DwVmVZnWlfTdKu798zG1TUYTJOwjbktCscMrUYVMEuNmlOo+0pFiZZFxBloVA4+ZxxlKw\nLRkppnnpJstWwIXg3INhB56pE8ErGEEiIdWhSeKZarqFZrmFMWLifdATyCjIyYDP2yFRvaSCGcM2\nFc9c6xKrazaPqTJtyRaQID9KVGhS+o9lcvNEjxRcK0pojB0TsW5YHmkpDYNI8aInILgTaX8mXPEL\nwFQC5K255F1EqsuKwCzbJ8J0cRJG9zEAPLN/GT3SJm98VUVmVLodWZJZ2buGPnmrFxOQiiW9dzet\nOMVB9Kz/jjvjbzlL+jwyQBBZc6LJS0lslcWzzpQy3mqKSh+xvP2tnNJxovd7Zsdx3t8xVWFWy6zQ\n8UmPvMU4G/smc0LsLG99eqXw3ZRwnoNkc4w1jytnvFcc5xatGWrj7JaFvHvM+wHoqCCphrgjPIze\nSVxOcOnUd4b2B8n7golv4wunX0MlyhtOp7R2nve7M91BIjYyTbkCigvTDs95drE+RN5u8NtFoy6l\n+OLbwYyhBoJZ61RxDz3FXuRkwUmyI85RFZn6eB0fPuFyJmbeAbZMQzqGqsjohk1vtkRrQ7KqENTh\nQkTebzDolsEL3atGXPJVMv3JtbbPO1A4voJ8Y6rirSEGQLZEBKVTYSlb0MKm3Rqad1+gwENNzTto\nNlcM8k4ke9Eh76CJ17CdJTfOB5WIK0h1Wc8n7loO3OVZSpvwVZt943yTslJtogY/o5N7P2U9LJgk\nqHP6GPQHB9ZDl/2JqXISDd5DZWrP+JS1xKesCbTpkLcsoVNhTrNUdFsLZR3TLFdzjnvFFKR4mURM\n8QOeLH+JkK3HQfXJW7c1Z3mM4pnNTcIJQkLEJTWiSQWQRIpUL3LdUB3BSyItNSKlh4jPWYSUzHkC\nUbkozKlWoQEplUNp24OccM+Pe/fdFHdIIZHHzHQwpnwasiT7BXMMv7b4jBZhnVBNJ5LesDxBQB29\nk+QpTwrLhy15/s7eTMl7Z+aOOhWAMaXT/HE2Y9iGitLc65fRDFilEjEFK+dr/xa+5qwqMnElzsJj\nzmNYy/F88REUZy11OuYHU6WoR5UUitKQJxDalkIqoYAR9zK9uVAlX5sDMHSFhBJnsJxlq5NCdErT\nJIKQyg28a8qF3u+JjX5lrrgqc+boed56cPDJ0hUQsFSmpWbzDqeNyY1+8hcARXIEViksCfnr6yVa\n9anETfE8O5rDxOmSN8C5o89hQsPY0H41UBP+H45d6AsLQVgqdqGJy6dcwTWzP8bcUbOrqskFMXfU\n7NDv9x1/ifftg1jn71aQA0g6wlZSjXvvnRog2/p4WJMXFj4xfm5g2+mjT0XVxfuSTsaIKRKGKQJn\nK8fkcCIi7zcYFu94nP9Z93/cv/Xhmvu9IDFbqql5B3OaV5KvLNuUrTC5xlWZdDIG2BhNO/xoVaiK\nNpcb+1jc/cfqvgQQIrMKn7c6ertXHxfAQiz18j5OtRSuUSyHydsNGDKHW3yTskNoqiJj4pN3a7KV\nhJIIpYN1NSZzqIVOyzFhBwQc19xp5ZrQd87wtlea+4Jthgs01Fiap7h542VPsDGHm/nQ8R/ANhVM\nWw/lHNesssiFbckhzTsekz1BwF0/DIARR1I16p01/yaaFyTkBqwFyRtVJyb5ZFkvi0lIbhhk7eAa\nj7xtM0beqaLVoDoR5bKN2tnlkVCx4JiF841IEsSPXYM6QQRTeZYRSQpN0ma2jQ4nKU88oFG17lnI\nl97yWT4y/QOUN84hVRJJRXTDCsUyiPHQUOwEsiRjWlaoZOqxzZP477d/i3GcGDonKIyBsxzPQTyu\nYA2OQt80h45Um3+Q5QsYFx9zPh+Y8T5/V66JZEz1a0obNu2pdqz4MI0Nzn2ZKumEeBZuJjcXHzxB\ntOUSgmHatCRb6C8OsGlwK82JJi8ILvhadaTamdQwgYXHnIcs++MXU2WSCVUkxnFwaufJ4hoBzTKm\nysyfcDafnfMJPjDjn0J9mtYqgs7M3rApOzPsv++9mZIXhNpWURmsMemblMc3d3IgjFQO2MVJ7TM4\nrmUKkiQRj/vvu7Z9JlY5yWzlHXz5LZ9leutxofPG1o/m+jmf8n7bxbqQ5u0+r2CKYzVQddHVvF2Y\nw63e30GN2nUDphNqiPzbm4+MyRwi8j6ieOz5Lrp6wqkphwoaDyz1g5y2O8kLdgyJZSuPrOjinsc3\ne8uhvCQNstCUlq3bz12PbOS3T2xhqKCFfd4V5Js3CuGkIgHNW27qJT55HWpnIA96KEDGJhYo4wjV\nAWuWbYfK5EmKEYo2V8eJ7E5mthUz60ySkuV9nHvl1aH2RE1rvw61p7kYcZ/YHD92XVL1lr5JG9/G\njq4yacXPmAR45KdvO9ExHVdq3k7U9daTwIxTeukcZDPBkBZ+ZkENasPgZm596o/c+9ctxBM1stsF\nNG+3kIax5zjmjj0RLAXN0vnLCr82s47uCCaSuE9bglhZWCVc4UMPrO/V40iK5UUoIxu+VcJdR+ya\n6yUTSbZIyP6k26AKv3HsmHX8pe8BhmNOX4yY9+wanWPASTiiashWjGLJoqUhUdNnbet+bEWQvO1S\n2iPvYKnaJI2MruskEVOxsh1YTlSxHtC8XcjJAoolnv/9T2/HtGxithCwOlLtSJJUpa25goxVrMPM\ntnGccoa3TxwrYWY7mBYkA0vxtFZJkpg35jRa4+K9NbonElMV7zovbupF1uqRFJN0Y9k737VqeX57\nQNtyMjNahb/YNWnbNqT0jnwAACAASURBVExrOZaSWaZgFJnaPLlm8Q5FVvjc3E/xrikLKrZLwrxs\nJLC1BAoqJ7YLAVRRwiQPMLlpkhfU6eLE9hOIbTsXfdf00PZgVbvebJGCM1e11CdCxzUHyLsj3Uot\n/MeZX+Rrb72h5r4g4oHnlwz8bfZOoLzqbYxPTmZUXW0BIRiBbpdTNc3mSlCgCZJ3haBuF/0I+CDJ\nu0pJXSoW2t4RkfffH/b05vj1Y5u56X+eC23/5cMb+MNT2/ijk6fc9dmokkJfpshv/rKZxc+JZTa6\nqfumV4e871y0kcdf3MOi5bt4fkNPyOddGdzh1om2yk6QkWx6Pu9ahelD/uBE0VtD7aLSbL5++wAl\no+RPtorOcN5J0lLWQRITsbb5VH8Nsmx6H+cwvdiWRGnVOZhZ5+OXA+StaiKBhy37EbxOn+rTMXRL\nx9bjFLJJfvC71aTUdCi5hr+EJ4ahy1X36Js7nWIMRh2y1iisCZIFWMSnP4fS0uMtWQHYYDzDw8/u\n8pa+ubBtKeTzdjNC2UYMWZbEGnDZ4PHnffK2EMk3Jo6qByTQ48LnHTCbG5r/2bpaaSJtihSwiu6T\ntvMcOjqcNe+uoBPQLprjzc44OlYBxe+jm9K0PuaTr6Tqjt88TqFk0NaYrE3eAW3ZM5sjfPluLedY\nYFJWHRLz/LPOulndsJDNMEkAyGaSkmbw4DIh7F7UcQVXTv8nprUI7TFoKhX37rgjSmm0jXOZnvTN\n6u77Z9swJu2n6cSWQxM7wPunXIm2YwZm/1hiqkzKuc79T29n5y4np3Z6nTjdUkgnVc49ZayXZAXC\na59PmSpMvP9w5jGcMdrv07njzwx0vur2PbQ42cckSfJIpLT2TK4c93FPuw1mF4yrI5ugAS47ay7Y\nMmec4I/DhXP9wLjeTNEjrvGdgqynTxTvUFOAvNuStcm7OdHkrYmvhfEd4t0MCl+1zOYH8oMDTJFP\nQ983GZBFeteKtoKat+dWIOxDNzPtmAP+OAQ17wWnC5fD204ZGyJvNxvckUBUVewIoViuvfZv/4CY\nLF3Tn0veiqzSE8gnPVzQ6Q/UL0a2KGtmyHReLBuUpaDmHSbkYUeDtMtpSJRANompCnXJcO7lifHj\n2aVtCpO/G6S07xhGG7Pon/DnqoC1fFlDUkyxbjemISkG/UMlLMumZJSRZJvpbZP5+GfO56tPbKef\nHpAt74PSEZnF7HLaXx8qW77ZPKb564fLKSRkSAhLREdTim5LCwWaqXZSRKzLplgj6yWmUCmXbVER\nyCHsKWMbSU2qZ1sOT7CoS6pCY0oBqoYkWSL5DAifbkX0uhtjEJfjDL94BvETlgc0b9kb/1s+IqKX\nZTuGpYhc4u4MLSkGthHnuHHNfPby2Vz/2FKkZE5MHE77Wjkwm7sJJGI6LQ0xioqF5Wb0slQScoJU\nyuCmj8zllj8+AAjTopsfqjXRChUp6t1o9JyjeadUn4ileAlUDb2QwrJt0kmVn3ziHfz48TgbSi+i\nNGRoUBspBsgqpHlrqZqatxfxK0tIkh+kqBsWsRqEI5kJT7iYPKaBK+efSl+fbyFprwim0ndNI3Hc\nS+h7BbknAqbYoJY3OqTNSSGTKMC4pk7MHuGLjqkyHQHTsV1hGscU39aHFkyjdetuFu0Sgkaw1vak\n0Q386NpzhGUFeNv4M2lPtYX93QcoV/Ctj83zAh49Td2Ih4g0SE6V91OJd501hVMmt4a01ffNn8q7\nz5rMt3+9kr39ec+d0tqQ4NZrziYVF8cGY0Nqrek+GHz5I3PRdCtErkGzubftAH5wgOPUuazrEgpR\nkLxdn/dImncwT4W2KRA3QVggPPeUsZw+YxTppMpTq/wA1WT8yFFqpHkfIVg1UpUCVaYxw3KLhMih\nYhCFkkFf0c/JLSt+SktX8ivr1oE1b6fghuf/U0zH5+1XParPT+WczvnO/mCqVD/pQn3MCc6p8Hkv\nyQo/va0lhd9WFVWSBoZLXtnIpkQ9qiLTXu8kvohp3sdZtovexGa7NZklV/O2QdWwveAmmTq5Ednx\ng7c3J0WQn+l/1JYernYkMi4JE2nZ4V3h47dJxhUM1zfsCACphIrlraUuh0zwthFnsiYKIXjJLZzx\nPnPs6SiWWKftEroiS2TKQ0hIjKpvce6gVhIVE0yVeFymPhUThUEUC1s2sJVgoJjTDyeoTZcLtLW4\n5nKfHBpiDWS1Idqakshp8fyntvqaVGstDckQZnt3kp7deiqzW+eIcUgNI8m2l9WsLqkSjym0SZPR\nd8yEgQl8aMpVBNXFUGCSLdf0eYeIXJFFwiBElbRa0buSmfDM+lPGNFV9R5WBQ9bgaN7ffg22YyUI\nam5BIvdWKQT6EkRdYAKPq3LIxxn0N4MTsJZUkSSJ9nTAOmGE1x2nEiqyJCFJEv90/Lt5+4Szqu53\nJKiKHCLaWvcUJKr4K5C3JElV7cnOto7mJLphsddZMphOxqhLxjyiDa7jrmXyPxioilxlNamteR+Y\nvINCyiuZzYPHnth+AnVqmium/2NVm3WBQlCSJHn9DL67ifiRo9SIvI8QauUZr7Xd07wlhd6MT475\nUrXm7aIuJV4iTTdDRSrCmrfN5sw28Zdjcg6bzZ3lN6UpXuIKKWg293Ihq6QSKkk1GdK8dctgW2GD\nc7AFZswj/L5Myat85Urnk5pEQJJclyURU9AtQ0RKuyZ3JxmDu9YbVRdm4YD/s0FpEfV3FZ2OppQg\n74DmrZXcQLhAZivHZFly5CK1bT9Kx26x3MPSPHJXFYlEXMEoOwJATAv5usHGGhjDhORkp9604Res\nUBNiQgsUtFBkiaw2RH28zsvFbLqme1dIkiyRWMJUfFOuQ85lCuiKIF+9GCAMxyeXNftobnL8pwGz\nbGO8kbxeYG+xy8vHPGOUr9U1JeqIyeIeZdstpOFkbnPMo+lEgg/NfC9WOemZ122nIlk66QpbNnax\nAXXvKbSkwlHESSXBuPix6LunoioyTfV+JLoLNaAdKrLkFXoQmncN8jYSnvm2crKH2r7HukCyjrBZ\n1m+/MR7O8hXsl/gd9h8Hr2NraYrPBXzRAZ93Q9zXhG0zTN6viFfBg0HNVKkIbHu1cO91pxO3U1dJ\nskqck9pnVvnjXytqEXWlUHWg/alEtQl+pIC1hng93z7nZs4c+5aqNmu9ZxAm/1cSKg4lIvI+QhiB\nu6vglsNUZYW+rK9590gb+O2m+/0DA8TqlgYtaWaIUIOat9zcy7J9ItLbKjnLpGJlYorsmM2dNddy\niua0Y/JSqoO9MEVO6sZ4AwOlQTQnA1le9zOvGb0TsE0VJSau35sp8v/Yu/P4qMqzf/yfs81MJpls\nkAAJ+yabICgo4i5Qt69WWxUXcKlaRVu1daFUpbUPuFT9Wbva1trqQ12hllddeLpp1YLWlcUVtAjI\nkkD2zHaW3x9nmXMmM5mQZCYZ5vP+h8xkZnLmJMx1rvu+7uuOWevL7eA9uXqMeVyl+/CrDx7Gqzv+\nbZ6npJ7Y/knrzKBmNUZxFy/ZhVRCoB0DyvxQDc0zbG533rKDrjkkbZ6rcLsrWJTvRWNwM3a173Iy\nd0kU4VckaBFX8PZUrsdR3xSBz96yUo45vxO/5IMkCOb6Z8mcKxdFc7cjuwMUAGiqfYFiPq+4OLGk\nx+nnbL3fiNaGmGiNnESCieYeVrOaFr0egZB1fK7gXR4wA+nKj58xHx8NYGBxYs4x4JfNoXMAJdoQ\nCLGg01TEzmz9PslsEqO5A5V5UWF/gNt75IiC0GGeWBAEnFJ9DtQvx6KqPODMwbqzMzkp8/YOm4uY\nV3YhYlumIbLpaMS3j4PSNNK5uEgOIgBQXtJx7tGdOaWbUxUEAQtGXYzohzM7HFcyRRZR2mFnP8HZ\nIcuI+Z2LG89FQYoe+r3NfUHiHjbvjeAdjWnmErqkQCUIAr459RKcMvLkbv+MVFLOb2e4oHFfdLmH\nsv0phs2TL9DSCaYY4QAS9RoAg/dByT1s/t6n9dANsxfuftd2llu/bEJMtTJcw0BdY9jJAPeXJ/aA\n1ttLrAIq8zXNDy8De/WtqI/sT/xQV/C1h5dlQYbeWG0uMSpuhqJ4h819QgClwSLo4SDEUKPzGu7M\ne2d9G0q1oYjpcTyx/jWs/2C3s7etumeY9fqJIri6pjBiVqFdsbWOcnRlrbmOdsBufNGyHau2/MU8\nUCt428PmghL3NOZwF0KVyFaLVCWC0lJ7eU7iP09Ts13oZm1VKCUqsdvaBGCH+SErVdShrug980nW\n/2NZMiuW7QsdqWq7N/OW4tjXHIEMa3jWmuMHzEzTzLytD3YljrgRQVxXUe4aQraXfCnDzKYgpSF7\nIwvZmUqwq5TbjVZExCYYutnD2mn5GPfDiPuwH19ik/oP8xQ0JqqI7S077SmX2JbEOmDAXPs/sMgM\n3pE2H9o3HO08xh42d9bhC4lhUfu47OBk/32bc9YdPwztAJuuGtcdJCVR8BSs+WQRgwODoe0fAqO9\nFOquMdBVn7MbXjDFvvbuzMo5Blfm7Q48/qQ51YmV46C3mFXlyRciboospXyvV0y5GPH35gGaL2Xm\nndziMxv8nmJAd/DufnAZ6JqKSHXBlC2ZsuxMz3FPz9gXAuky784Up/g7AwBZTrwWg/dByJ15P7Rq\nA155dyfuXvmOp9HK8sfeRn2zOTcc0+PY1xxFZWkAJQEFQsRd9GP9J7IztiIFUtUObCsyd/s5acDp\nAOBtB2oF4UvGLwIMEXpbGUR/GIZoNfiQzbaZPsmHoF+GVl8LQdQhVe72PB+agoaWKN57y/xD/vN7\nr+PXaz7Axm3m4+zgamgydMEMmvubo4gLVvC2Mm9REJ0Mz3OeUvTrdg9ZuzPvErnEeZ8lQSvwuTJv\nLe7q2Caaeyy7s57wl0M9VePun2suvZGgt1ZgQukkSKFGSFWJZXSxz6bCMIC4nZl7Mm+/uYey9f7E\nYDPaNHsLy0TWO7bYWspTuQcQVZSUJC5A7A+BgUHz8a/sfx5hcb815SGgusIOgmaTFA0xRIw2xHeM\nxfTBiTXqdvAGzPXtRpv5evaccNAvozpoBqq2Fsks7LOCS0u7+Tu3i3wG+BJroO3gbVfXjq01f85h\nYwc6owLuoFhZGoAAYGhVx9854B16lCUhkXlrZuadnDFqmp5YrpNuODPpQzlV4RLQ8QPXPUfaWYGX\nnbFVliay/AGl5haVMsy/02CK4J343XXN6CHm//3Dxg3M8MhERul+T6LYu5k3kH4IORvsn+WTRQyz\nKtyryjpvhuK+6PLMSSuJkbVU3+/KcSSTPXPeuQverDbPESczKWmAPPQTfLjTu7ymTdgLsbzOmeON\najFEYqq5jlY30KaJEACE6meiWbaWFllV1MGA7GzacP74ryKyx9zooXaIjMtPnAXdMPBKXQPW7/3M\nyXy11lKIZXUIi/sQ9I81s1PV3NtWFAV8e958/PKjTzF1qoh3/55ocFLiC6IZifWPdr/o3U1WW0+7\nGMfOOiUVMVWHjihEeCtSJ9YOxsdNiZaR5vPNDz17pACAOd/tt+daXQ1GrOB95IwifN76mfVzXWug\nnUYuGgTr/BiajOKAbA25CmZBmL9jsxnJNSw4o2w2Pmr+wNmJKbLhWHO/agDtLQJQYgV915y3JEWc\nJVRicTNaVfPnuzPvb596PJa/vBX75E8hKDFUlgWxwzpGe8574UmH4uebXI1rrO5XQwYU47yTxiIU\n9OGdLwdgXcM/UOoL4ZRDvo5h1SFcEDYb5OwT/us89fCRI3HWCWbryTsunYn9zebWh9VNZlCwh8JP\nnjEUf39nBzTdQHFAdoYd506agj98bG7KcsnJhyEkDMCU0WbWfszUIRhUUYTRNWaf7GWXzkSFK6hV\nlRfh9kuPwODK1FXInjlvSUQsrsEwDKfaPDljVDU9MSef5kP1vmuPxusbduGZl825fk+Rmiu4JQc0\nd5CXU2Tw9187B63huJN1L7t0JprazIs+ewcrO4ja2Zq7u9hti7xVzJnMnjIYA8sCGF2ToiNZkvsW\nH43G1ljSnHfXC9Y6M6DU3BfdMHIbvAM+GcsunYnykB+KJODLfe2oTXMRaHNfdCkpRlnSFay53X/t\nHLz18V488Tdzu9p0GXqqi4NcYPDOEbswTSzdB6m0Ac2tewAkrh63la6Fuyg3psUQi5vLqEQB2Cuo\nKJaLIDQMhVi1y1xcJOowAAQUGYJsZn2TB0zA3z7ZZ3bv8oedtZhavVWQ5jOvnu250jZhPwRBgCDH\nYaiJhgMTBtdC/FhE3JpntTPvimAJmvfHzbXWmuQUpe1vbzH/muxqcSdwxs0GNIpVze5aQlJRFAK8\nsdvJrGPbJiIweb35GnIMorVlpN6ayFxDivke3t3/Nt7d/7Z5pyvzdobQ5Rj8483vG9EihII+TB0z\nEOs273aCfak+BOMGV2P9f8zzJEuCk50FjUrobaUQi5s9xwgATc0wg7ccd4oI/ZIPoiA4PbvF4ia0\nqOZzy1xz3n6fhOpQGfaFzfqD8jLRXLalJ4bNq0PeavD4DrOJSHFAdrLYE8dNw4mY5nmcX5FQWQpU\nxkc79w0vH4RqK3sqtiqFAWDW4MPx7KsfQ9s/BJNHVngyQ3e2NW5AotBt+qihngsxURBwyPBEtfWI\nwR23dxw5OH3w6ThsbjhD58mZt98nQdONRJerNMOZpUGf50PeHdB8nmHljnP0smQeQ6oP9oqQ31lf\nDQChoA+hoLeRjP1+3BcCK+bcDkkUUaIcWMFa8rntTFmJH2VJ8/2pmrR0hyyJqAz5sa85mnYIOVvc\nf0/2KE9n5DRLwVIWrKW4QAPM3/PIFH/Hydw1BRw2Pwg5w+ZWT+WInmKHKxd7yZdfkcwPJ1GDIvoQ\njWuQkpYYybIASUlsU1jXGIERC6BdS6x7tVtzhvzmB669WUMM7TAMwwrePkSsHbxkUUaFvxx14X0Q\ny/dCHmAPi7u2WlQVZ/mUvduYMyft7GGsojUchVhsZuYhV+FOyrWg9rB7WzmiH1vLk5QoxFAj9EgR\nEHd1B/OluPp27fhlX0BIFXsgKHFIrYOh7jSDn7OUyPp9+EQ/Lp9yEcT95m5DdsEaYFZdq9aOSoYu\neLL7BmsBgFi6z5mXtzd+QDwAI+aHEGzGjlZzyL0maSlSib1LkRxDaYld7Z0YNncXOk0UToTeYI6q\ndDXzce+6lG7tbUD2I/blKECXMLC8KG27R3exXbHcvXW86aQqWItZ65cVyRu8Az4JqmZkHDYHktY4\np8mQUs2P2wVv7u1dD4Q9kuD+PZX5Qx365OdCb2XeQOJiLpeZd3d4Mu8U1eBdybzNx2U+X+6Lg1R/\nS9nC4J0jTsGatYGCs/uTLWlLQ7tHud9ndmkSJA2KYG4Dam9q4ARvSXSGtQNyAPWNYQhqAO1qO1Td\nvD+shiEKIoKKGbTsIdKI0WbuYiQYgKZ4CuiqigagOdYC//h3nPvawq41yZriFGm1WNXmRofMW0Wj\n8hnE4hZUxMd4AkiqYOLeYcoZQg81QJDjHdbRFrn28lWsYUnPPLp1DGKRtca8bZIzn+tklFa2LMIq\nHrP+I8qS4BS6tEfi0PbVmIFb9cFd6mq0l0JvLYNUXg+p2mxp65f8zsWa3h6C6I9g475NKJKLMCxU\n63kPIcVV+e/XneO2P2R8kmvNtpgYdTiQzOeiCeeiSC7C5AET0j4mZm0vW1bs82Qi7vXSgiDg4gnn\n4uyxp3d7HW86SoqlYnbzEZ8ieoJOQJGg6ZmHzYH0WZV7Pa6U4jF2Zt3Y0vlFdjr2h36uM9RUpG4U\nZ6VjX8wV+/v+fXXGezHoyox9nS8VS9aF2J2x8U22MHhnQTiqoqXduyuY0yXMyvTiRlJ3LtVbgOEE\nb8Xa9UtUAV1GLK5Bttbl2vPjih28NRkCBNQ1heEXzMBoN2Zpj4dRJAcgSaK5VCfuh2EIaFWb8ceP\nVgEAtP2DPB9WA4OuTRosEVenOMOqKPcrIiL2Bh3OnLcVOBUVMdnMumvh3bIwOXifNOxYs2DKZjVZ\nEcvMPa6T23C650EXTVqAG2dcA3XXqMTx6e75bxHlYiLrtT+c7WAfMMzXtocYJUl0/qO3RVRA9SH+\n3ymIb0/sya3IImCIiG01h6ztna38kh/2SgB7HXZYi2B8+egOGzKU+q3aASXmbCBiaHLK5TElUuLi\npegAMp+ja2bivuN+2GlbSltxkeL5MEquDp9dMxNzhx/f5Z/dZUnLxlQtfebt90nQNHPY3C4sTCdd\n5uS+v7Pg3dDazeCdIvPuK+5h855edOVL5q2kec+pMu/OCtbc2/Wm09MLou5i8M6C6x96Ddc/9Jrn\nPrt6Fk7w9gZ3Q9CgR4LmGlZRcTbZ8CsSSopkCJKOPfti5iYMgt061GroIgnOhhRtERXhqIZia3/h\npqg51xpWwwhamar5QWj2zd7eth0fNXyKQfIIaPW1GDwgEVA9OyxZhg8yX3dAqd8pShs62Oc0QnEy\nb6dtpwpVtIbsFe/8kXsI8VuHXdlh/99xQ8xWlfb/PfcmAYD3P+DQkhqMLR8FGCnmvGF1RLOqdodV\nlzgZROyzqYhvH4ehhrk8ys4AZVFwisbs4VmtvhbaPnP4fNSQUqdHtxEtcvrFA4Bf9jkdLY32xEhA\nqsy3LJAI3s7fhC55AtL/G30Kjhx8OAJS9pbq2PPcQyqDnkrsQTnaaMHdrEiWBOiG4azEUBTJO2yu\nmHPebRHVHJXqJCBJXVjDW2o1jXFn92NqzIu5IQO6N8ztVyQU+eU++2B3S3Vx0l2DrL+T0mJfhkf2\nrXS/d/v3ka63eTK7F7zYyd9YV9eJ97b+ffmUp+xCG8MwnA+WRPA2/9XQMXhDC0DdNQbV46LYGfkC\nsLbLnDV5IF54C04wcipXreCtSCIgxWFEfE5L1XJ/KRoANESbMArmnLddLFUe8mPP/nZz/tgXRZEc\nwHdnX463yxow3bUcZXz5mMTxqTK+MuJknHTUZLz7aT3iqo6nt5hrzysrJWyPxc1kU1NwzNQhEMsF\n/CeyCZKswfBFYBhCh0zbfXt8xRgIgoAbzp0Gnyxi9/52zJo4CDe/9ifnfert3jluWRKwaOL52NL4\neYcLjbISH5paDRiGGfyrS8px+lEjURr0YfaUwSgOKLj27EPx8z9thLprDIQae7g8kXnbRU32nuTj\nhpbhzGNGYV9TBDPGV+FnqzbA3GFcgN40EKK1I5tn2Nx1wTFj0FQkq7CaqMiDvsCGfebPOe+YyZ6i\no1NGmu1qX3rjC+e+AaW9u2/wLRdMx0dfNGDK6AGIxTV8/YQxB1Qo1VPupZR2sLH/litCfs8oix3I\nW9oT+5ink/yhe+uF0xGNe7OpMTVluPTUCc4GGwBw8hFD4fdJmDHeu/NWV10wd5wzrN/XejN4H35I\nFRbOH4+jJg/utdfMhuRs+vuLDkdTa+Iz1/130dn5GTE4hEtPnYBDhqcfteqrCzQG7yxStcSmCppm\nz3mbHxya4A3eEPREYxIkdtzy+yRnq0l7DbOdeTt7RUsCdCEOQwtiZ521UUdxJT4PA/sjDeZuZLrq\nZN5V5QEzeFsXEjXFQ1CsFOG4ad4sa3jpUAwOVmN3+15EP5yFY4+ag1DQh+Om1WD9B4lK7WDQgICw\ntbm9gLOPHY09cRn/ec8cNheUKBD3IVDi/aAt9lQrm+996hgzCE8YYfX/1v3QxXarUKxjRe+RVYfj\nyCGHdzj35SV+8z+rIQCCgepQGfyKhLlHJPp6H35IFYJ+Ge1R1cmU7Yst93CsnXkfOnoAJo9MVH+7\nq5zj28cDhoihlRVQRNmpcTDCJdAjRThhzHTPHL1znEWJC5KdrealwNETRnV4HODNEFJ1EOuJytIA\njp4yBIBZiX3aUSMyPKN3uTNve5jX3rSnqjzgyYrt77dFVFRXdF44l5xVpbsYOW5ajee2KAgd7jsQ\n44ZmnqLIld4M3pIo4sQZQzM/sI8lz0PbIympZJpKyPR30JWitmzo+zGdg5j7Cl/Tra+tYXNnj2Xz\nljlfavfztpc7WTtuOXt0W8HSZ+/HbDdOkTSn4GxHnVn1XVtqZtANkUa0W/PRRdY+t1X2jkuy+bo1\nJemvom8+4jpEP5wJI1zq3bQ+oCSGyJUwRH8EmpUZK7LobK0nyCoEXxRGzO/Zlxfo2s5DJa3mHLPe\n2DED6uxDKRS06wLMoFDiSz386QzJG4bntuyZ844797l55v00H+LbJmGYPsN6Qet+Q0R0w3E4b/xZ\nKX9+kc/nFA8CZuFdukpud/FVLqtac8Gdedvnede+xI5x7mFz9+890/RBbwaufCX1g6H7XMvlUHa6\nTaeyrfB+qzlkL7sCEsPmguBu2WmxN9+wMm87wxZE1QreiblQAAiIAc9rGFYWb2gydlrBe0SlOV+8\nP9Jo7kcNIGgFVLvNYeyzQzEiNAynjpyb9j0E5IDTKtL9HyIYkJ3g/Z+ItZtYuGPwhq/dXI8eD3To\nhdyV4F0ePgSRjXMQ+++UDt/r7EMped/idEt07Kvu5P9/dntUIJF5JweCVEU79lW49+VStwwFzGVP\n0Q+OcgJ4mb/jDlm2rhTP5Bv7nRquM2af50TmXeQ59+7fe6bCqT76XO1XCvECJpdD2U5ilmMcNu9F\numF45lI8mbfmLVhzb7cJwargtoKz0SHztrqLWXPefsneDMMM3ppgXQioMnZY2/UNG1AJn6hgQ/1m\nJxjYa4ZDRebws948ELfMPK/L7y8580bS7kh24xdFFlFkWM1gfFZns5i/Q1WwLMo4YegcDAqmn1eU\nZbFDoZqts0KT5Cvv9MHb/Dd52FwUEsHbnitLfs1Uy4DsD8p0u8h1PE4RRqzIXG5WuRdiJzsu2Mv4\netJoo78RBAGGYSRl3lbw3tcOvyIhFFSSNjFxZ96dz3l39fdwMGPwzi7nsz3HDp5PgT62e387rrjn\nn/j7267+1/HEsqrkgjVB6ph5G9awuW7vNuWLwO9zZ97mtZZTdWy9hu7KvJvaYigt9iHgk+GTzCD9\nft0mAMCospEA9OeBHQAAIABJREFUgPJQ9ypFk5tcOHtuW/RwCLIkmPv/Wpm37rOat8T9Kfv+njv+\nLBw39Oi0P7OzZRzJnbHcyoq9c8IBOXWBlz13XGQdWyITTPS/brcadSRn+ikzbyl1Jp+JfcEW19MX\nOdmvOXxQ560h84m9JMtd4S675rQHlgc6jES4f++ZMu+DbXqhOwrxHOTygqWvLqaZefeStz7aCwBY\n+ddPnPvcm444QyuiO/M2AAiJPautYFgSrwX8m6AM+wQ++WRnztvOvINyEIh3zLxrKspR7h/gVMi2\nurbpBIDRZWYR0qSRlTh99ghMH9e1Stprzz4UX9a3ej4EKkJ+nHroNLypbcWYkrF4+8P9MNpKofgT\nFfHmkjd7NzJft1oHdjZ3lep7t1wwHe9vrcfXjh8DUQT+ZT9WTP2zrz17Cl5Y/wVOn20VaLlesqqi\nCKOGhPD5LnP0IPkDIdV8qzNsbkVaWRJxwdxxad8DAFx51hSsa9iBrZFdnuHjZGcdMwqqpuOsY1IX\ntOWj75w/Df/3n+04+fBEEdSx02rQ0h6Hbhg4ekqiHuPCueMQ8MnY+mWip26m4D24MohTjhyOKaMq\nO33cwcyvSDhzzshO29MebIr8svmeh6R/zxfPH9/lTUk6M2N8FU6cXotjpw3p8WsdCAbvA7S18b+o\nC9fjqCHezQXswCYEWiGW1UPbMwLRlJm3GagF0XA2FnGG0q3g7YsOQoV/KBqKdwBSvMOcd5EcgBAX\nnNakLaq5DehXpo/F7JpEj+sLJ3wNL3z+NzRGm5znAeaQ8NeOTywDy+TwQ6pw+CEdA/3Xjp6Mq6uO\nwsdb67B+rbkft/sqNCgH0BSzh/SVbgZvMem22XMaSD1sPmFEhVOpfv5J4/Avc5dMKGLq4dXqiiAu\nPTWx/lpAYthbFARcd85UfPfnr6c8lmCKLlPJAf74aTU4cXpth8e5nXncGEzYfiZ+uWEfFhxydtrH\nBQMyFn7lkLTfz0dDBhTjklO869/H1pbh21/vuKzOXimwbXeLc1+mYXNBEHDeiWN74Ujz21ePHZ35\nQQeZTO/5pF6qmpclsU/+XzJ4H6AH3vkFAGDW4Bmebln2XHdgqtmcJdJa4Q3emrdgDYCZfetyIhu3\nhs1jcQ2KHgREQBOirszbqjZXJBTFi9BqBe9P2z+CKIiYPND7ITin5kjMqTkSb+95DwNTNFzpLe4P\nUPeSnqASRFMssZtXd7bLSw6YPlmCqplDy501TrAdWzsbr+5ch9HWlEEmiepz89+yksQUQ3Kmn7pg\nzTts3tWGVhWBciyddWPXHlzg3Bdt7o0/iAoJ//K7SdVVZ04ZAJKnWARRSxo2TypYgznsbcQDEKwm\nJPYcciSmQdB9ZvAWI4iq1lIxe523LCIoF6FNaoXgb8OeyJeYWDnes4mF2+GDDuvRe83Ep4hmP2rd\n8GTe7mpyo7uZd9J8kt8nOXPQXWn1eN74s/DVMachIHdvXbS3mYP3WFIOm9tz3vbwd+FNN2ad5ClY\n40cYFSYWrHWTmlRY1CGQGELSsLm9zjuRedubeiSGzc3gFo1rEKyGJHEjirC1TttemuWTRQSVIkCO\nQRpgNvaYOWh6z99UNwmC4HyIuueQPOuVXZttHIjkbPdAd0USBfGAAnfyum+3mKp5bqfaijIx523/\nfEbv3uYtWOvfG2QQZQuDdzfFda3zBwg6/vi3T7Hps30AXJm36FoTaFecJw2bb9vdgi92mvPcUSOM\ntri53tVu0qLIEkqUIATRgFS1A7IgY2rV5J6/qR6wP0QVxTtsbjNUxbOTU1clD5tne79cZ847xfda\n2uOe26kL1rpXbU5dx8ybiMG725Izb7ufucMKyA88/T6AdMPmKgZVFGFghbWHtWvplb0dZtQIO3tx\n25m3IotOYBT9EQwtHppoitJHjpo0CANKAzh8fLVzX1BJtAOVoXSrjaA7eI8cHMKF88b37EAzSZEo\nf+/iGZgwvByzJ3v34lZkEbMmVnsKopKHzZl4976JIyowqDKISSMrUFHau21iifIFL1u7SdW9WVgs\nufuVNY9tf3inK1i7Y9FMPP3uK3izHYAuosgvIRzVnK0129VE8LaboiiyiGI90XQk5Ov7db9nHjMK\nZyYtYXIPm/vl7q0td+/zfPslR2R9swdnnbfr1zRuaDluuXBGx8cKAq4+y+z89vQ/twBwVZsbicdQ\n7xo3tBx3XXVUXx8GUZ9i5t1NquEdNjdbV7rms63MuzJkZsTJvc0Bs1GLTxFRVGT9GgzRqdy2s+zW\nWBva4+3wiT5nWN0niyj3JdYvhtIUqvU1d8FadyrNgUTBmiSaLUaz3Xwh0XGte+Pe9pJBnfVqRJRF\nWc28V6xYgffffx+CIGDp0qWYOjWxdnPlypVYs2YNRFHElClT8P3vfz+bh9LrkofN46ruZNsAnK8H\nWMN6qea8BVmFJIoIBqzgrZutIOubIs6weVu8De1qGAEpALs1hSKLKJMSwbs0zaYbfc09593duWq7\nOMkO2tnecEBI7pd6gOSkJi3MvIkoG7KWeb/55pvYtm0bnnrqKSxfvhzLly93vtfa2opHHnkEK1eu\nxBNPPIGtW7fivffey9ahZEVyG8u4qnn7lVvBu9Rqv6m72qMaqnnNJPniePHzv6MdjQDMOe9ia39i\nSfdBgIDWeBva42FnO0/ALFgr8yeCd1mgf3ZOch9z8qYkXWVn3nZGm+3t9xLD5t2M3slLBhm7iSgL\nsvZJuG7dOsyda+5WNWbMGDQ1NaG11exzrSgKFEVBe3s7VFVFOBxGWVn6/Vb7Ql1jGI+t/djZDjJZ\nqszbvVOY0/LUCgLujUnsrFoYsAN/+XwtXtn5uvVY0Wk6IUkiipUgGqNNiGgRTxaryKI3ePeDOe9U\nZDExsNPtzFtK7K8N5K5Pc7eLxa0n9tU2gURUGLI2bF5fX4/JkxPLlyorK1FXV4eSkhL4/X5ce+21\nmDt3Lvx+P04//XSMGtV5v+aKiiBkuXeXCVVVpZ8rXrHyHWzZ3oiyUABXnNVxO8rikOJ5viCJiXXb\ngJN5y4qEqqoQJFkCYEAQAD3uAwLtHY+nrBhl1hy5IokYXl6DD+o+BQBUliSC9eDqEAJFiYA9fNAg\nVA3su3nvdOfRCNYC7wB6OIjSEn+n5zudygpzSkCWRef5AZ+E8cMruvV6mVx46kT86JE3cN68Q7r1\n+iWhAKqqQrj6nKn45aoNOHXO6C69TjbeS6HhOewdPI89l4tzmLNqc/cwZGtrKx5++GG89NJLKCkp\nwSWXXIKPPvoIEyZMSPv8hoaOwa4nqqpCqKtrSfv9fY1h69/2lI/b19CCOiVxf2tbzDNsPn54CB/s\nBMLhOOrqWtAeiSWK1TQZhi4msnPLVacfin+tM3+uKAoY5B+ED2AGb9lINKNoaQ5DiyZ+dVq72Ol7\nyabOzqMAH04oPh8vvl0HjDO6dYzhtqj1WnCe/7MbjoMgICvveVRVMX57y4kQRaFbr9/cHEZdXQtm\njhuIw7v4Opn+FikznsPewfPYc719DtNdCGRt2Ly6uhr19fXO7b1796KqytzcYuvWrRg2bBgqKyvh\n8/lwxBFHYNOmTdk6lG6xLzbSjdJ2GDbX9KRtPs3MW7NeJ2q0QvBHrBcXALXjdZMsys7wuiyJqA3V\nON9zL7tSJNFTCFWi9M9hcwCo8g0GNB/8Svf+1Ox13u4qc9GqPM+W3hqaL8StGIkoN7IWvOfMmYO1\na9cCADZv3ozq6mqUlJhBpra2Flu3bkUkYgazTZs2YeTIkdk6lG4xUqzTdY8edChYi2uA7B42N7Nq\nOxjvqFqDwNRXrRcSYcQ7NpdQRBmqtaRMEgUMK0kEb/dabrt/+IiQuctSd/t254IdfANK9wZ5ZDm3\nc91ERPkga8PmM2bMwOTJk7FgwQIIgoBly5Zh9erVCIVCmDdvHr7xjW9g0aJFkCQJ06dPxxFHHJH5\nRXPIvdRH1VX88aNVmO3aBlQ1OmbeYlFiqMQQzO87Veae1xZgtJVBLDYff+GEr+HThs9RVTQQmlYH\nwCxYqykZjONqZ0MWZcypmYUnsN45JgD47uGLu70eOVfsgjNfN1qjAole6WKWq8x7C+vUiCgXsjrn\nfdNNN3luu+e0FyxYgAULFmTzx/eIu8nGxvoP8cbut/HG7red76tJvc1jqg6xsinxfGgQBQGaYeCL\nvc3eFzdE6K3lQPUOAImtO4FEm1VZFCAKIs7vZH9nScxun+/eYGfe3a02l6zny3mSeff3iykiOjjk\nRzrTBxKZd+pGG8lz3jEtBiHYAr3NrArXoEIUBei6gR/8YZ33yboIvS310rhZE83+2ccfVpPy+/mm\nImQO6Q8o7V7v9UTm3b+D9xGHmPUcIwf3zzX3RHRwYW/zDARBgF/s2Jc7ntzbXG6EIBjQWiogBJuh\nG6qzx7Wn8xoAASLu/8ZXsOaLCMaVj/Z878hJgzC2tgyVKTZc+NkNxyY6teWJMbVluPvq2RhY1r3g\nbQ+79/fg/c2zJuO8ligGlhVlfjARUQ8xeKfhDJsLqbt6JQ+bq4K5lE2PBiHpkpN5R2Kad/03zD2m\ny0sCWDTp/JQ/e0CaQJevexdXl3c/oLl7m/dnkigycBNRznDYPI3EUjEBmqF3+H6HLUFFa9vOmB/Q\nRWiGBkkUsLehvUPmDZ2nvavsXuH9PfMmIsolRpE03FXDWlKWDXirzQ3DgC5Za7jjPhi6BNWIQxIF\nGAY6ZN727mCUmZ1550vBGhFRLjCKpJEp845riYC8e387oJidwIy4HzBEqIaayBalpODP4N1lSp7M\neRMR5RKjSBruOW/N6Dzz/vmfNkFQYgCAUn8I0BKZNwAIHQrWGIi6yqdIkCUBRT6WZxAR2Ri803A3\nadFTDZu75rwjMRWCEoUiKvjhJbNRO6AUcT0OwT67ycPmev9fn91fyJKI755/GM4/aWxfHwoRUb/B\ndCYDM/NOVbCWCOiabkDyx1DmC6G02I/yYDF2RXRIkpW+JxesaflZNd5XDhle0deHQETUrzDzTkN3\nNWlJOeftWuet6ToMKWoOmQPwS+YabdGa6xaS57w1XjMREVH3MXin47RHFVLPebuGzXUhBggGQtbu\nXgEreAv2RiVi0rC51rHpCxERUVcxeKdhrxRLV7AW01QnOzdEs9K8WDG37fRbu3wJaTNvDpsTEVH3\nMXhnIKYpWPt8dyN+9Zy5B7kmmpXmQTt4S1ZmbQftDpk3h82JiKj7GLy7INWcN0QNb31sbt9pWMG7\nWDaDtzNszsybiIiygME7A90wUg6bu7umGZJZvNZh2FxUARgQ/OGkF2XmTURE3cfgnYFupMm8reBt\nGAYMKXnY3NoRTFQhDdoGsbjZ05iFTVqIiKgnGLwzMAwj5Zy3ORRuQDcMCLKdeZu7StnD5pBUSKX7\nAQCXTlqQk+MlIqKDH4N3BuaweYrMGwAkFf/e+SaU2q0AgGDSnLchqBCKWmDEFVQFB+bkeImI6ODH\n4J2BYaRYKmavAZdUPPnpaufu5DlvTYpADIRhREKQRc5zExFR72DwzkDXOxasybDntL33FyctFYvI\n9eY3wqEO+38TERF1F4N3BobRcT9vUU/MaeuRIud+RTSXgNnD5mFxHwBAiIYwpHgQJMOH+M4xEFiv\nRkREPcDgnYFhGNCT5rwF3cysBUmFEU0Eb8GKyvawuV1ULmh++CQfZukLoe4cl/2DJiKigxqDdwYp\nC9bsJiuSCgjmBPjo8Hzn285SMYtgPV4wmHITEVHPMXin8NH+TyH4zMYqqQrWjLgVjK3gbRgCSvUa\n5/uKKENxFagJujeYExER9QSDd5KWWCt++t5v4J/2CgAz8/7vnibPY1S79kxSIQgGYAiQRG9WHfKF\nnK9FnbuIERFR72HwTtIWbwcAp6hM1XTsaWjzPMauX7MzbxgCxKQzGfKVOF+LOnuZExFR72HwThLT\nY57bcVV35rVtumadNsnsXW4Gb++pLHUFbwHmELoB7+sQERF1B4N3koga8dyOxXVAMAvWoh8dgaJw\nLbR6c37bnXlLSeu/Qkpi2Dz5e0RERD3B4J2kPSl4x1XNybz15gEI7T0aajQAABCUqBO8haQzWepP\nBG8hKXgn3yYiIjoQ7NmZJBz3bt8Zs4bNDQMABLRH44Dqg6EqEAJt1lruFAVrimvOW2SwJiKi3sPM\nO0lY6zhsLgg6YJinqj1ilprr4WIIgTAEUYNhCB0CtGepGDNtIiLqRQzeSbyZt4G4pjtD4wAQjpql\n5kakGIJgQPBFAUPskHlLouR8bX+L5WpERNQbGLyTeDJvwUAsrgGCDlmUMLqmFLo5fg4jXJx4nCFA\nTMquJw+YABgCYtsmdPgeERFRTzB4JwnHXcFbVJ05b1EQ4VcS2bQeDSYel2LYPOQrwYTGi6DtGclh\ncyIi6lUM3knCqmvYXNSdanMR3uAN3fV1ig5rAKwit8SwORERUW9g8E4Sdi0VEyTVWectQoLf5w7Y\n7lPXMfMG4AyxC4zeRETUixi8k3gzbw2abkBwhs1dp8u9Q1iKOW8gEbyd2M2KNSIi6gUM3kncTVoE\n0W5ibgZvn2vY3NATpy7VUjEAmDKyEgBw2NiBnvs5BU5ERD3BJi1JYpqrt7lkB28doiBBkdJn3qnm\nvOfOHIZDhldgWHVJh+8RERF1F4N3kqh7Y5KkzFv2BG9vIE+VeYuCgBGDQx3uJyIi6gkGbxfDMBDX\n4s7txLC5DkmQkoK3O1innvMmIiLKBs55u8R11bttp5TIvCVBhCy5h8q9WXiqYfNkrFcjIqLewODt\nYs93S4JZmCaIGiCqEATAJ/o9mbe7YC3dsHk6zNGJiKgnGLxdolbwDohW9zRRg+Azq89LlNABF6wl\nG1xpvu7omrLeOWAiIipInPN2iVvFakVSEG1aCyCp5sYjAEqVEGQxdcGakWadd7KTZtSiOCBj+riB\nGR9LRESUDoO3i5N5C2aGLEgaBMXMvEt9pZCTsm33110ZNpclEXMOHdJ7B0xERAWJw+YuMavS3O8M\nm6vOsHmZrzT9UrE07VGJiIiyIWPw3rp1ay6Oo1+IWcPmPqMIgJV5W8Pm5f4yyHLP5ryJiIh6Q8bg\n/e1vfxsXXHABVq1ahXA4nOnhec0eNldgBm9zztvMvCuKyrwFa8jc25yIiCgbMs55P//88/jkk0/w\n4osvYuHChZg4cSLOPfdcTJ06NRfHl1N2gxZBV2BoIgRJBXwRGLqIkFKMqBRJ/URD5LA5ERHlTJfm\nvMePH4/rr78eS5YswdatW7F48WJcdNFF+O9//5vlw8stO/OGIQG6DLG4GWKgHdq+wVCUpA5rbhw2\nJyKiHMqYee/cuRN/+tOf8Je//AVjx47F1VdfjWOPPRYbN27EzTffjGeeeSYXx5kT9py3oUkwNAmC\nYt6v7hwHRfL2Nvd0WwOYeRMRUc5kDN4LFy7E17/+dfzhD3/AoEGDnPunTp2aceh8xYoVeP/99yEI\nApYuXep5/K5du/Cd73wH8XgckyZNwp133tmDt9E77A5rhi4CmnlqDEOAEQtAlrztURXZvc5b5Jw3\nERHlTMZh8zVr1mDkyJFO4H7iiSfQ1tYGALj99tvTPu/NN9/Etm3b8NRTT2H58uVYvny55/t33303\nLr/8cjz77LOQJAlffvllT95Hr7CXihmqBEO39u6OKwAESJLgqTZP7rbGYXMiIsqVjMH7e9/7Hurr\n653bkUgEt9xyS8YXXrduHebOnQsAGDNmDJqamtDa2goA0HUdb7/9Nk466SQAwLJly1BTU9OtN9Cb\n7DlvXXNl3roMSTSryd0BO3nZGIfNiYgoVzIG78bGRixatMi5fdlll6G5uTnjC9fX16OiosK5XVlZ\nibq6OgDA/v37UVxcjLvuugsXXHAB7r///u4ce6+z57x1VYSzFEyTnEDtnvNOzrwZvImIKFcyznnH\n43Fs3boVY8aMAQBs2rQJ8Xg8w7M6MgzD8/WePXuwaNEi1NbW4qqrrsLLL7+ME044Ie3zKyqCkGXp\ngH9uZ6qqQp7bwqfmMUqiD7D28jZ0CQFFQlVVCEUlifddFFDgXMIYAgYOKO7weoWiUN93b+I57Dme\nw97B89hzuTiHGYP39773PSxevBgtLS3QNA2VlZW49957M75wdXW1Z7h97969qKqqAgBUVFSgpqYG\nw4cPBwDMnj0bn376aafBu6GhPePPPBBVVSHU1bV47mtpN39GuN2A4Lf28tYlSKKAuroWxOJa4sGu\nixEYApoa2+EvwOQ71XmkA8Nz2HM8h72D57HnevscprsQyDhsPm3aNKxduxbPP/881q5dixdffLFL\nmfecOXOwdu1aAMDmzZtRXV2NkpISAIAsyxg2bJizTnzz5s0YNWpUV99L1tjV5mocTuYNXXIqy93z\n3JJnqRg7rBERUe5kzLxbW1vx5z//GQ0NDQDMYfRVq1bhtdde6/R5M2bMwOTJk7FgwQIIgoBly5Zh\n9erVCIVCmDdvHpYuXYolS5bAMAyMHz/eKV7rS1E9BlmUoaqAPedtqDJ8VtB2B2jJ9bVhCDBARESU\nGxmD9w033ICamhq89tpr+MpXvoLXX38dP/jBD7r04jfddJPn9oQJE5yvR4wYgSeeeOLAjjbL4loc\nPlFBXNUhfjEdyvCPEd5+CJSqjgMU7gI1QTAYvImIKGcyDptHo1HceeedqK2txa233orHHnsML774\nYi6OLeeiWgw+yYeYqkNRy1C291hA9UNJUSgnJbdKNRi+iYgoNzIG73g8jvb2dui6joaGBpSXl2P7\n9u25OLaci2kx+CQz81ZkCbpuBmR3NzWbuymLJAuoCAVydpxERFTYMg6bn3XWWXj66adx7rnn4rTT\nTkNlZSVGjBiRi2PLuZgeQ7lYigZVQzCgQNV0AHDmvN3cwfvsY0alDPBERETZkDF42wVngLmka9++\nfZg4cWLWDyzXDMNATIvDJ/kQ13Qosoj2iAogc+bNGW8iIsqljOmiu7vaoEGDMGnSJCeYH0xUXYUB\nwwzeqg6fLELVzcxbSbEVqOgJ3kRERLmTMfOeOHEifvKTn2D69OlQFMW5f/bs2Vk9sFyLWq1RFVGB\nqhlQZBGaZs15KykK1kT3rmIM30RElDsZg/eHH34IAHjrrbec+wRBOOiCt92gRbY28VZkyZnzTpV5\ne4fN9RwcIRERkSlj8H788cdzcRx9zt4ONBG8RSd4y3IiUFeVB1DXGPEOmzPzJiKiHMoYvC+88MKU\nc9wrV67MygH1leTM2yeLUK1hc9k1RL78yqMQi2t4+p9bnftYsEZERLnUpQ5rtng8jvXr1yMYDGb1\noPqCvZe3CHN+293HXHb1MZcl0dkaVI8UQQyEUSQX5fBIiYio0GUM3rNmzfLcnjNnDq688sqsHVBf\nienmsLmExLC5rUM3Nfs5Hx+B4NCdOPb4g2v+n4iI+reMwTu5m9quXbvw+eefZ+2A+krMybzNU+Ju\nzCKLqZbGGTCixZD3TIFPUlJ8n4iIKDsyBu9LLrnE+VoQBJSUlOC6667L6kH1BSd4GzIA3ZN5y510\nTzv4VrwTEVF/lzF4/+Mf/4Cu6xCtoq14PO5Z732wiOnu4B3zbEYipxk2JyIi6gsZo9LatWuxePFi\n5/ZFF12El156KasH1RfsgjUY5ilxr+2WUgybc3UYERH1lYzB+9FHH8WPf/xj5/bvfvc7PProo1k9\nqL4Qt9Z5Q7fmvBV3wVr6wfGDsVUsERH1bxmDt2EYCIVCzu2SkpKDMmBFtCgAQLCCtzvzdq/ztjHx\nJiKivpJxznvKlCm44YYbMGvWLBiGgVdffRVTpkzJxbHllB287cxbUdzrvDnnTURE/UfG4H3bbbdh\nzZo12LBhAwRBwJlnnolTTjklF8eWU1HVCt6anXm7C9YOvpEGIiLKXxmDdzgchqIouP322wEATzzx\nBMLhMIqLi7N+cLlkZ96GZgZt91KxUNDX4fFVZQEAQG3VwXUeiIio/8s4Hnzrrbeivr7euR2JRHDL\nLbdk9aD6gp1566oZvH2yiOVXHolLTjkEIwaHOjz+lCOH44KTx+HKMybl9DiJiIgyBu/GxkYsWrTI\nuX3ZZZehubk5qwfVFyJaBIqoQNPM24osYsiAYhx/WG3KxyuyhHkzh6XMyomIiLIpY/COx+PYujWx\ng9bGjRsRj8ezelB9IaJFEZD8iKvWHt6ddFUjIiLqSxnnvL/3ve9h8eLFaGlpga7rqKiowL333puL\nY8upqBpFQPYjxuBNRET9XMYINW3aNKxduxarVq3CkiVLUF1djWuuuSYXx5ZTyZm3z9UelYiIqD/J\nmHm/9957WL16NV544QXouo4f/ehHmD9/fi6OLWd0Q0dUi8Ev+xHXmHkTEVH/ljZC/eY3v8Fpp52G\nG2+8EZWVlVi1ahWGDx+O008//aDbmMTuax6Q/IjHzYo1Bm8iIuqv0mbeDz74IMaOHYs77rgDRx11\nFICDt4931FrjHZADaGPmTURE/Vza4P3yyy/jT3/6E5YtWwZd13H22WcflFXmABCx1nj7JbNgTRBS\n7yRGRETUH6RNL6uqqnDVVVdh7dq1WLFiBb744gvs3LkTV199NV555ZVcHmPWOZm3VbDmk6WDdpSB\niIjyX5fGhmfOnIm7774br776Kk444QT8/Oc/z/Zx5VRYjQCAWbCm6hwyJyKifu2AolRJSQkWLFiA\np59+OlvH0ye8mbfG4E1ERP0aoxSA9ngYAFAkB9DcHkdx4OCqpiciooMLgzeAdtUM3qLuRzSmoao8\n0MdHRERElB6DN4D2eDsAIBoxT0dVeVFfHg4REVGnGLwBtFmZd6SNwZuIiPo/Bm8kMu+WVvM2gzcR\nEfVnDN5IzHk3NZnd1TjnTURE/RmDN4C2eDsUUUFLmxm8K0L+Pj4iIiKi9Bi8YQ6bFytBRGPmpiQ+\nhduBEhFR/8XgDXPYPCgXIRLX4FNEiGyNSkRE/VjBB2/d0BFWIwgqRYjFNfiZdRMRUT9X8ME7rEZg\nwECxHEQ6NoGoAAAYmElEQVSUwZuIiPIAg7ddad6so6E5Cr+PwZuIiPq3gg/eMc3co3zL9jYYADNv\nIiLq9wo+eMd1M3gbunkqGLyJiKi/Y/DWVfMLBm8iIsoTDN5W5g3dDNo+peBPCRER9XMFH6ni1pw3\nDPNUBFiwRkRE/RyDtzPnbWfeDN5ERNS/MXhzzpuIiPIMg7cz583gTURE+YHBW/MOmzN4ExFRf5fV\n4L1ixQqcf/75WLBgATZs2JDyMffffz8WLlyYzcPolDNsbhWsiSI3JSEiov4ta8H7zTffxLZt2/DU\nU09h+fLlWL58eYfHbNmyBf/5z3+ydQhdkrxUTNP0PjwaIiKizLIWvNetW4e5c+cCAMaMGYOmpia0\ntrZ6HnP33XfjxhtvzNYhdEksqcOapht9eThEREQZZS1419fXo6KiwrldWVmJuro65/bq1asxa9Ys\n1NbWZusQukR1qs3NzLu4SOnDoyEiIspMztUPMoxERtvY2IjVq1fj0UcfxZ49e7r0/IqKIGS5d4vJ\nqqpCkD63bugi5h85Al89aTwkznsfkKqqUF8fQt7jOew5nsPewfPYc7k4h1kL3tXV1aivr3du7927\nF1VVVQCA9evXY//+/bjooosQi8XwxRdfYMWKFVi6dGna12toaO/V46uqCqGurgXN7ebrGrqEuTNq\nsH9fa4Znkpt9Hqn7eA57juewd/A89lxvn8N0FwJZGzafM2cO1q5dCwDYvHkzqqurUVJSAgA45ZRT\n8MILL+Dpp5/Gz372M0yePLnTwJ1NqqvaXBILfuUcERHlgaxl3jNmzMDkyZOxYMECCIKAZcuWYfXq\n1QiFQpg3b162fuwBi7matMgSh8uJiKj/y+qc90033eS5PWHChA6PGTp0KB5//PFsHkannI1JdAmy\nxMybiIj6v4KPVqquWg1aBBaqERFRXij44B3T4xAMs4qdmTcREeWDgo9WcSt4CwJboxIRUX5g8NZU\nCKw0JyKiPFLwESuuxwFDYqU5ERHlDQZvPW4tEyv4U0FERHmioCOWYRiIaXFAl1lpTkREeaOgg7dq\naDBgwGCDFiIiyiMFHbxjWsz8QpcgcdiciIjyREFHLDt4G5rEYXMiIsobhR28rb7mhsaCNSIiyh8F\nHbHszFtn5k1ERHmkwIM3M28iIso/BR2xYnoi82a1ORER5YvCDt4sWCMiojxU4MHb3stb5FIxIiLK\nGwUdsRLrvGXOeRMRUd4o6IjlLBXTRQ6bExFR3ijs4O3qsMaCNSIiyhcM3gCgSdzPm4iI8kZBR6zE\nsLkERSnoU0FERHmkoCOWe9hcYcEaERHliYKOWFFnqZgEHzNvIiLKEwUdseJWhzWDmTcREeWRgo5Y\nMVfmrchS3x4MERFRFxV08I46c94iFLmgTwUREeWRgo5Yqq5CggRAgI/Bm4iI8kRBRyzVUCEKMgAw\n8yYiorxR0BErrschwpzrZvAmIqJ8UdARK66pruDNgjUiIsoPBR28VUOFYJingJk3ERHli4KOWKqu\nQrAybxasERFRvijoiBXXVQgG57yJiCi/FGzEMgzDzLw5bE5ERHmmYCOWqqvmF8y8iYgozxRsxIpr\nVvDW7cyb1eZERJQfCjd423t5W8PmLFgjIqJ8UbARy868DZ1z3kRElF8KNmLFrMwbmghBACRR6NsD\nIiIi6qKCDd6qlXnrugBFFiEIDN5ERJQfCjZ423t5G5oIRSrY00BERHmoYKOWXbCmaQJ8CivNiYgo\nfxRu8LaHzTWBmTcREeWVgo1acd0O3iIUpWBPAxER5aGCjVpxa85bUwXIYsGeBiIiykMFG7Xcw+ay\nzEpzIiLKH4UbvK2CNV0TITHzJiKiPFKwUcvpbW6IkCVm3kRElD8KN3jbvc11ETKrzYmIKI8UbNSy\nm7TAENkalYiI8krBBm9nP29dhMTMm4iI8kjBRq2Ys6uYBJmZNxER5ZGCDd5x97A5C9aIiCiPyNl8\n8RUrVuD999+HIAhYunQppk6d6nxv/fr1eOCBByCKIkaNGoXly5dDzOGSrYgaNb/QJBasERFRXsla\n1HrzzTexbds2PPXUU1i+fDmWL1/u+f4dd9yBhx56CE8++STa2trw6quvZutQUgqrEQCAocssWCMi\norySteC9bt06zJ07FwAwZswYNDU1obW11fn+6tWrMXjwYABAZWUlGhoasnUoKUXiZvCGJjPzJiKi\nvJK1qFVfX4+KigrndmVlJerq6pzbJSUlAIC9e/fi9ddfx/HHH5+tQ0kprEYhQLCqzZl5ExFR/sjq\nnLebYRgd7tu3bx+uvvpqLFu2zBPoU6moCEKWe2/f7XA8Ap/kRzsElJYEUFUV6rXXLjQ8dz3Hc9hz\nPIe9g+ex53JxDrMWvKurq1FfX+/c3rt3L6qqqpzbra2tuPLKK3HDDTfgmGOOyfh6DQ3tvXp8YTUC\nBQoAIBqNo66upVdfv1BUVYV47nqI57DneA57B89jz/X2OUx3IZC1YfM5c+Zg7dq1AIDNmzejurra\nGSoHgLvvvhuXXHIJjjvuuGwdQqci8QgU0QcALFgjIqK8krXMe8aMGZg8eTIWLFgAQRCwbNkyrF69\nGqFQCMcccwyee+45bNu2Dc8++ywA4IwzzsD555+frcPpIKxGUSGXAgAL1oiIKK9kdc77pptu8tye\nMGGC8/WmTZuy+aM7FddVqLoKQTffPoM3EVHfevnlv+OEE07u0mN/8pP7ce65C1BTU5vlo+q/CjJq\nRa0GLbv2xgBw2JyIqC/t2vUl/va3tV1+/PXXf7egAzeQw2rz/iSimcHb0MzqdS4VIyLqOw88cA8+\n/HAzHn30N9B1HV9+uRO7dn2JBx/8Be66607U1e1FOBzG5ZdfhTlzjsV1112F73znFvzzn39HW1sr\nvvhiG3bu3IFvf/u7mD17jvO6qqpi+fIfdHj+J598hPvvvweiKGDKlGm49trrU95n/5zRo8di1aqn\n0NjYiOnTD8eTT/4v2tvbcd11N+Ldd9/Gyy//HbquY/bsObj11u+ipaUFd955G9ra2lBSUoI77vgf\nXH75Rfj9759AMBjEhg3v4cknV2LFih93+5wVZPCOWsEb9rB5DtuyEhH1Z0//Ywv+89HeXn3NmROq\ncd5JY9N+/4ILFmL16qdx2WVX4pFHHoaqxvGLX/wWDQ37MWvWUTj11DOwc+cO3H77EsyZc6znuXv3\n7sF99z2E9ev/jT//eZUneLe0NKd8/oMP3oebb16KsWPH4Uc/ugO7d+9KeV86W7duwRNPrIbP58O7\n776NX/zitxBFEeeddxauvfabeOKJxzFr1myce+4CPPXUSrzzzls47rgT8dpr/8L8+afgtddewbx5\nX+nROS3I4G33NTc08+0z8yYi6j8mTpwMAAiFSvHhh5uxZs1qCIKI5uamDo+dOvUwAObyZHcXz86e\n/8UX2zB27DgAwO2335n2vnTGjh0Hn89crRQIBHDddVdBkiQ0NjaisbERn3zyEa644hoAwPnnXwQA\nqKmpxW9/+0vMn38K3n33bXzjG1cf+IlxKczgrSU2JQFYsEZEZDvvpLGdZsm5oChmD46//vUlNDc3\n4+c//y2am5txxRULOzxWkhLNu5KbgaV7fqpNsFLdJwiJxE5V1Q7Ht3v3Ljz11Er87ncrEQwGsXDh\nedZrSTAM3fNaY8eOw759+/Dhh5sxatQY+P3+zk9CBgUZtSL2piR25s2CNSKiPiOKIjRN63B/Y2Mj\nhgypgSiKeOWVfyAejx/Q66Z7/siRo7B5s7ni6a677sR///t5yvuKi4uxb5/ZbGzjxvdTvn5FRQWC\nwSA+/vgj7N69G/F4HBMnTsLbb/8HAPDcc6vw4ot/AQCcdNI8PPDAPZg375QDeh+pFGTwthlx88qH\nmTcRUd8ZMWIUPv74Izz00P2e+0844ST8+9+v4vrrr0FRURGqq6vx6KO/6fLrpnv+9dffhJ/97P/D\nNdd8A6FQKUaOHJXyvjPPPAf3338vbr75egwcWNXh9ceNG4+ioiCuueZy/P3v/4ezzjoHP/zhD3Hu\nuRdg06YNuO66q/Dvf7+G448/EQBw8snzsHfvXhx++MyenTAAgpGq6Xg/1Jvt5uJaHNf89hnojdWA\nIeKWC6ZjwojOe6tTamyn2HM8hz3Hc9g7eB57rrNz+Pzza7B79y584xvfPKDXS6Ug57wVSYHeMNi5\nzcybiIiy6Z57/gdffrkTd911X6+8XkEG72SsNiciomy69dbbevX1CjLl1HXvTAEL1oiIKJ8UZPCO\nxr1VjRw2JyKifFKQUSvWIXgz8yYiovxRkME7OfOW2B6ViIjySEFGrWjc2/mGmTcRUd96+eW/H/Bz\n3nvvHTQ07M/C0fR/hRm8Y0mZN+e8iYj6zIFuCWp7/vk1BRu8C3KpWMdhc2beRER9xb0l6PnnX4gV\nK36IlpYWaJqGG264GWPHjsP//u/v8cor/4Qoipgz51hMnDgJr776Mj7//DP8z//ci8GDzd4dfbEN\n6OWXX+VsAxqLReD3F2VlG1A3Bm+w2pyIyLZ6y1/w7t6Nvfqa06sPxTljz0j7ffeWoL///W9x5JFH\n4//9v6/i888/w09+ch8efPAXePLJ/8Vzz70ESZLw3HOrMHPmURg7djy+851bnMAN9M02oOeff6Gz\nDejixVfiZz/7VVa2AXVj8AabtBAR9RcbN25AY2MD1q59AQAQjZobSZ1wwsm44YbFmDfvFMyfn35j\nj77YBrS5uTkn24C6FWTwrgz54ZNF6IYBVTMgCgzeREQAcM7YMzrNkrNNUWTceOPNmDJlquf+m276\nHrZt+y/+8Y+/4lvf+iZ+/es/pHz+wbwNqOfYe+2V8sghwyvw1IrT8fBNJ+DXN5/Q14dDRFTQ3FuC\nTpo0Bf/618sAgM8//wxPPvm/aG1txaOP/gYjRozEZZddiVCoDO3tbSm3Ej2YtwH1nLNefbU8Iksi\nBEHgfDcRUR9zbwn69a+fj507t2Px4itwzz3/g8MOm4GSkhI0NjbgyisX4dvfvhqTJ09BaWkZDjts\nBm677VZ89tlW57X6YhvQ+++/x9kGdOHChVnbBtStILcEBbj1XW/heew5nsOe4znsHTyPPZd8Druz\nDWjy66VSkHPeRERE2dbb24C6MXgTERFlQW9vA+rGCV8iIqI8w+BNRESUZxi8iYiI8gyDNxERUZ5h\n8CYiIsozDN5ERER5hsGbiIgozzB4ExER5Zm8aY9KREREJmbeREREeYbBm4iIKM8weBMREeUZBm8i\nIqI8w+BNRESUZxi8iYiI8kxB7ue9YsUKvP/++xAEAUuXLsXUqVP7+pD6tU8++QSLFy/GpZdeiosv\nvhi7du3CLbfcAk3TUFVVhR//+Mfw+XxYs2YN/vCHP0AURZx33nk499xz+/rQ+417770Xb7/9NlRV\nxTe/+U0ceuihPIcHIBwOY8mSJdi3bx+i0SgWL16MCRMm8Bx2UyQSwRlnnIHFixdj9uzZPI8H4I03\n3sD111+PcePGAQDGjx+PK664Ivfn0Cgwb7zxhnHVVVcZhmEYW7ZsMc4777w+PqL+ra2tzbj44ouN\n2267zXj88ccNwzCMJUuWGC+88IJhGIZx//33GytXrjTa2tqM+fPnG83NzUY4HDZOP/10o6GhoS8P\nvd9Yt26dccUVVxiGYRj79+83jj/+eJ7DA/T8888bv/71rw3DMIwdO3YY8+fP5znsgQceeMA455xz\njFWrVvE8HqD169cb3/rWtzz39cU5LLhh83Xr1mHu3LkAgDFjxqCpqQmtra19fFT9l8/nw29+8xtU\nV1c7973xxhs4+eSTAQAnnngi1q1bh/fffx+HHnooQqEQAoEAZsyYgXfeeaevDrtfmTlzJn7yk58A\nAEpLSxEOh3kOD9Bpp52GK6+8EgCwa9cuDBo0iOewm7Zu3YotW7bghBNOAMD/z72hL85hwQXv+vp6\nVFRUOLcrKytRV1fXh0fUv8myjEAg4LkvHA7D5/MBAAYMGIC6ujrU19ejsrLSeQzPa4IkSQgGgwCA\nZ599FscddxzPYTctWLAAN910E5YuXcpz2E333HMPlixZ4tzmeTxwW7ZswdVXX40LLrgAr7/+ep+c\nw4Kc83Yz2B22R9KdP57Xjv72t7/h2Wefxe9+9zvMnz/fuZ/nsOuefPJJfPjhh7j55ps954fnsGue\ne+45HHbYYRg2bFjK7/M8ZjZy5Ehcd911OPXUU7F9+3YsWrQImqY538/VOSy44F1dXY36+nrn9t69\ne1FVVdWHR5R/gsEgIpEIAoEA9uzZg+rq6pTn9bDDDuvDo+xfXn31VfzqV7/Cb3/7W4RCIZ7DA7Rp\n0yYMGDAAQ4YMwcSJE6FpGoqLi3kOD9DLL7+M7du34+WXX8bu3bvh8/n4t3iABg0ahNNOOw0AMHz4\ncAwcOBAbN27M+TksuGHzOXPmYO3atQCAzZs3o7q6GiUlJX18VPnl6KOPds7h//3f/+HYY4/FtGnT\nsHHjRjQ3N6OtrQ3vvPMOjjjiiD4+0v6hpaUF9957Lx5++GGUl5cD4Dk8UG+99RZ+97vfATCnvtrb\n23kOu+HBBx/EqlWr8PTTT+Pcc8/F4sWLeR4P0Jo1a/DII48AAOrq6rBv3z6cc845OT+HBbmr2H33\n3Ye33noLgiBg2bJlmDBhQl8fUr+1adMm3HPPPdi5cydkWcagQYNw3333YcmSJYhGo6ipqcFdd90F\nRVHw0ksv4ZFHHoEgCLj44otx5pln9vXh9wtPPfUUfvrTn2LUqFHOfXfffTduu+02nsMuikQi+P73\nv49du3YhEonguuuuw5QpU3DrrbfyHHbTT3/6U9TW1uKYY47heTwAra2tuOmmm9Dc3Ix4PI7rrrsO\nEydOzPk5LMjgTURElM8KbticiIgo3zF4ExER5RkGbyIiojzD4E1ERJRnGLyJiIjyTME1aSHKN/fe\ney82btyIaDSKDz74ANOnTwcAfO1rX8NXv/rVLr3Gr3/9a4wfP97pZ53KwoUL8fvf/x6SJPXGYXvs\n2bMHn332GWbPnt3rr01UiLhUjChP7NixAxdeeCH+9a9/9fWhHLA1a9Zg69atuPHGG/v6UIgOCsy8\nifLYT3/6U+zYsQNffvklbr31VkQiEdx3333w+XyIRCJYtmwZJk+ejCVLluDwww/H7Nmzcc011+CY\nY47Bhg0b0NbWhocffhiDBg3CIYccgs2bN+OXv/wlGhsbsXv3bmzbtg1HHnkkbr/9dkSjUdx6663Y\nuXMnBg8eDEmSMGfOHM8exW1tbfjud7+L5uZmqKqKE088EWeccQYefPBBGIaB8vJyXHTRRbjzzjux\nbds2tLW14YwzzsDll1+O1atX469//SsEQcCePXswevRorFixAoqi9OEZJuqfOOdNlOd27NiBxx57\nDFOmTEFjYyN+8IMf4LHHHsOiRYvw8MMPd3j81q1bcc4552DlypWYOHEiXnzxxQ6P+eCDD/DQQw/h\n2WefxerVq9HU1IQ1a9ZAVVU888wzuOOOO/D66693eN6///1vqKqKP/7xj3jyyScRDAZRW1uLs88+\nG2eeeSYuu+wyPPbYY6iursbjjz+OZ555Bs8//zw++ugjAMDGjRv///bu2CW1MIzj+NcONQQRQi3W\nYnBsjDoSBFKNOVaEo0M4REO4HGyrKQin5ob+gDBaoiVyECEipakhWkKkQKFoiERPd5DOzYxLlysX\njvw+4+F5X97tx/PyHh7S6TSHh4eUy2VP3jKI/A/qvEU8bmJiAp/PB8DQ0BC7u7u8vb3x8vLC4OBg\nW73f78c0TQACgQBPT09tNZZlYRgGhmHg9/t5fn7m5uaG6elpAIaHh7Esq23d1NQUe3t7bGxsMDc3\nx8rKCj09rT3CxcUFDw8PXF5eAlCr1bi/v3fXf4xPnZyc5O7uzp2TLCK/KbxFPO7ztbJt22xvbzMz\nM8P5+bk7zOOzrw/Svnv28l2N4zgtQfw1lKE5y/j4+JhiscjZ2RnLy8scHR211PT19bG+vs7CwkLL\n90wmg+M4fzyXiDTp2lyki1QqFUzTpNFocHp6Sq1W69jeY2NjFItFAKrVKldXV201uVyObDaLZVnY\ntk1/fz/VahWfz0e9XgeaXf3HVb3jOOzs7Ljd//X1Na+vr7y/v1MoFBgfH+/Y+UW6iTpvkS6SSCSI\nx+MEAgFWV1exbZuDg4OO7L20tEQ2myUWizE6Oko4HG7r0IPBIKlUiv39fQzDIBKJMDIyQjgcJplM\n0tvby9raGre3t8RiMRqNBvPz8+6o1FAoxObmJqVSCdM0iUQiHTm7SLfRr2Ii8iOPj48UCgWi0SiO\n47C4uMjW1pb73/m/ymQy5PN50ul0R/YT6WbqvEXkRwYGBjg5OXHnE8/OznYsuEXk76jzFhER8Rg9\nWBMREfEYhbeIiIjHKLxFREQ8RuEtIiLiMQpvERERj1F4i4iIeMwvRph4T/csGFUAAAAASUVORK5C\nYII=\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEcCAYAAADUX4MJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsvXeAHMWZ/v/pNGlnc5S0ymmFUE6WEAgQ2UJkGxtjsMEG\nbDD+YnNwZ3PnH+fD2GcwnDFHMBmcwETLIiMJ5YByzqvd1eY0eTr9/uie7p7dlRACHQ7z/LM73dVV\n1dXd71NvqLcE0zRNcsghhxxyyOFjIH7eHcghhxxyyOHvAznCyCGHHHLI4ZiQI4wccsghhxyOCTnC\nyCGHHHLI4ZiQI4wccsghhxyOCTnCyCGHHHLI4ZiQI4wccviUWL16NXPmzDmmsg899BC33377Ce5R\nDjmcGOQII4dPjTPPPJNx48bR2dmZdfyiiy6ipqaGhoYGAO68805qamrYvHmzU6a2tpaamhrn99VX\nX81LL73k/H7kkUeYO3cukydP5vTTT+e2224DYN68eUyePJnJkydz0kknMX78eCZNmsTkyZN57LHH\nTuTt9glBEE5I2Rxy+FuC/Hl3IId/DFRXV7NgwQKuuuoqAHbt2kUqlcoSjoIgUFRUxAMPPMATTzyR\ndbwvvPLKK7zxxhs888wzVFdX09bWxvvvvw/AX/7yF6fc1VdfzcUXX8xll112Im7tnwa6riNJ0ufd\njRz+hpHTMHL4THDRRRfxyiuvOL9feeUVLrnkkl7lLrnkEnbu3MnatWs/ts4tW7Ywe/ZsqqurASgt\nLeWKK67os+zHJSx46KGHuPXWW7n99tuZPHky8+fP58CBAzz22GPMmjWLM844g+XLlzvlm5ubuemm\nm5gxYwbnnnsuL774onMulUpx5513Mn36dObNm5elMWWu/d73vsfMmTM566yzeO655z72XgG6u7u5\n8cYbmTlzJjNmzODGG2+kqanJOd/V1cW//uu/cuqppzJjxgxuvvlm59y7777LxRdfzJQpUzjnnHNY\nunQpYGl/K1asyBqHjEmsvr6empoaXnrpJc444wyuvfZaAG699VZmz57NtGnTuPrqq9mzZ0/Wvd97\n772ceeaZTJ06lauuuopUKsUNN9zACy+8kHU/8+fP57333jume8/h7wM5wsjhM8GECROIxWLs27cP\nwzB48803mT9/fi9BHggEuPHGG7n//vuPqc5XX32VJ554gi1btmAYxqfq46JFi7jkkktYu3YtY8aM\n4brrrsM0TT788EO+853vcNdddzllb7vtNvr378/SpUt58MEHuf/++1m5ciUAv/71r6mrq+O9997j\niSee4NVXX3WuM02TG2+8kTFjxrB06VKefvppnn32WZYtW/ax/TMMg8suu4zFixfzwQcfEAgEuPvu\nu53zt99+O6lUioULF7J8+XJHwG/atIk777yTO+64g3Xr1vH8888zYMCAI7bTU6Nbu3YtCxcudLS+\nOXPm8M4777B8+XJOOukkfvjDHzpl7733XrZt28Yf//hH1qxZw+23344oilx88cW89tprTrkdO3bQ\n3Nx8zL6dHP4+kCOMHD4zXHTRRbz66qssW7aMYcOGUVFR0We5L33pSxw+fJgPP/zwqPXNnz+fu+66\ni2XLlnH11Vcza9asT+WfmDp1KrNmzUIURc477zw6Ojr49re/jSRJXHDBBTQ0NBCNRjl8+DDr16/n\nhz/8IYqiUFNTwxVXXOEIxDfffJObbrqJ/Px8Kisrufrqq502Nm3aRGdnJzfddBOSJFFdXc0VV1zB\nggULPrZ/RUVFnH322fh8PkKhEDfccIOjiTU3N7N06VLuvvtuwuEwkiQxdepUAF566SUuv/xyZs6c\nCUBFRQVDhw49pjERBIFbbrmFQCCAz+cD4NJLLyUYDKIoCt/97nfZsWMH0WgU0zR5+eWX+fGPf0x5\neTmCIDBx4kQURWHu3LkcPHiQ2tpaAF577TUuuOACZDln9f5HQu5p5vCZYf78+Xzta1+jrq6Oiy66\n6IjlfD4f3/nOd3jwwQe57777jlrnvHnzmDdvHrqu8+677/KDH/yAsWPHcsopp3zi/pWWljr/BwIB\niouLndl2IBDANE1isRgtLS0UFhYSDAad8v3792fr1q2AJbyrqqqyzmXQ0NBAU1MT06dPByyNwzAM\npk2b9rH9SyaT3HPPPSxdupTu7m5M0yQej2OaJo2NjRQWFhIOh3td19jY+Klm8t57MQyD+++/n7fe\neouOjg4EQUAQBDo6Okin06TTaQYOHNirDp/Px/nnn8/rr7/Od7/7XRYsWMCvf/3r4+5TDn+byGkY\nOXxm6N+/PwMGDGDJkiWcc845Ry176aWXEolEeOedd46pbkmSOPfccxk9ejS7d+/+LLp7RFRUVNDV\n1UU8HneOHT582NGYysvLOXz4sHMuEwUG0K9fP6qrq1m9ejWrV69mzZo1rFu3jkceeeRj233yySc5\ncOAAL730EmvXrnV8AqZp0q9fP7q6uohGo72uq6qq4tChQ33WGQqFSCaTzu+WlpZeZbwmqjfeeIMP\nPviAZ555hrVr1/L+++87ZsXi4mL8fr+jRfTExRdfzOuvv86KFSsIBoNMmDDhY+85h78v5Agjh88U\n99xzD8888wyBQOCo5SRJ4uabb+bxxx8/YplXXnmFxYsXE4vFME2TxYsXs3fvXsaPH/9ZdzsLVVVV\nTJo0ifvvv590Os2OHTt46aWXmD9/PgDnn38+jz76KN3d3TQ2NvL88887144fP55wOMzjjz9OKpVC\n13V2797dyzHeF2KxGIFAgHA4TGdnZ9YMvby8nNNOO42f/OQndHd3o2maY666/PLLefnll1m5ciWm\nadLU1MS+ffsAqKmpYcGCBWiaxubNm3nrrbey2uzpY4rFYvh8PgoKCojH49x3330OoQiCwKWXXsq9\n995Lc3MzhmGwYcMGVFUFYOLEiQiCwL333ntUDTOHv1/kCCOHTw3vDHXgwIGMHTu2z3M9MW/ePCoq\nKnqF3mYQDod55JFHOPPMM5k2bRr33XcfP/nJT5g8efIR2/808NZz3333UVdXx6mnnsr3vvc9br31\nVsdHcPPNN9O/f3/mzp3L9ddfz8UXX+xcJ4oijzzyCDt27GDu3LnMmjWLu+66q0/NoCeuueYaEokE\nM2bM4Morr+xlZvrFL36BLMucf/75nHLKKTz77LOARVL33HMP99xzD1OmTOHrX/+6owHdeuut1NbW\nMn36dH7zm99w4YUXHvGewdIS+vXrx2mnnca8efOYNGlS1vk77riDUaNGcfnllzNjxgzuu+++LNK5\n+OKL2b17t0OuOfxjQTiRGyj927/9G4sWLaK0tJQ33nijzzI//elPWbJkCcFgkHvvvZcxY8acqO7k\nkEMOJxivvvoqL774Yq8Q2xz+MXBCNYxLL700a4FWTyxevJja2lrefvtt7r77bv7jP/7jRHYnhxxy\nOIFIJBL8/ve/58tf/vLn3ZUcThBOKGFMnTqVgoKCI55/7733HHV+woQJRCIRWltbT2SXcsghhxOA\npUuXMmvWLMrLy5k3b97n3Z0cThA+17DanuGJlZWVNDU1UVZW9jn2KocccvikmD17NuvXr/+8u5HD\nCcbn6vTuy32SS8yWQw455PC3ic9Vw6isrKSxsdH53djYeMTVwV6Yppkjlj6w+1AHtz2whLOmDeLW\nKycdsdyFP7BWLL/+y/m5cQSu+veFdMfSnDdzCN+9vO+1A5kxe+JHZ1NREvq/7N5x48ofLSAYUHjq\nrqOvifm/RmYsM3juJ+dRlO//VHXe98I6Fn1UR2VJiN/+6Ow+2/vpjbOYMLL8U7Xz94LMPT/6r3N5\n8vWtrNrayOCqfB66/cxPVe8JJ4yjBWHNnTuXF154gQsuuIANGzZQUFBwTOYoQRBoaYl8lt38u0V5\neb4zFoebrL+xeOqYxudwYzeK/I8TWe0di08Cv2KNQWdX4mOv37anBWFoyXH17/8S5eX5pFQD01T/\n5r+Vg3UdqGV5n6qOZMpaC2IYxhHvt+sYnu8/Glpbo6TTGgC6/ulyscEJJowf/OAHrFq1is7OTk4/\n/XRuueUWVFVFEAS+/OUvM2fOHBYvXszZZ59NMBjkZz/72Ynszj88kpkXwzi2SGlVM/6hCON4EfBZ\nn0EipX1s2cb2OGP/DgjDNE003cAwzL95jTyaUP9vGjpxKwj+ZqHrn+09n1DC+Lg8QQD//u//fiK7\n8HeFV5bsQ5FF5s0aclzXJ9M6AFqPl0TTDR59fSunnNyPiSNdDU7tMeMwTJPH39jG5FHlTKv5eNPg\nkbB6exMbdrdy/YUnIX4Ggmr7wQ4+3NjANy4Y84kJ7u3Vtai6wRdnDgFg4cqDxFMal80Z7pQJ+qw9\nIHYe6uRXf9rI9fPGkB/y9VlfY1u8z+PHCtM0eXrhDkZWFzF7fL9jvm7jnlbW7mzmG+ePQRSPPqYv\nLdpLSZGVB8swTVTNwKdIJFIaD7+yme64ypTR5Zw7bRCPvLaFc6cPomZwcVYde+q7+MvyA5w7fRDv\nravj+nljHGI9XpimydKGlQihbsy4Gz0ZswnDNE1eeGcXowYWMX1MZda1ndEUzyzcwVfPHkV5UZBn\n39zBkH4FnDahP0fC+l0trNja2Ov42wc/YEndCn404zaCspWRIJnWeGLBdubNHMLgqnyn7CtL9iHL\nIudOG8hv/7KNc6YPQhDg7dWH+OYXxyAKAk/+dTtnTBrAqIFFR+zLB+vr2X6gHd0wue6LJxHwSTy5\ncAszavpRVhTkpUV7+cYFYwgHlaOOoWGaPPfWTsYMLu41Ru3dSZ57aydfO2e0c+wXv19Pfsiq87Og\njlzywb8hvLH8AMBxE0bKJoyequfO2k7W7Wxh3c4WnrzTtWFqWna5uuYoq7Y1sWpbE9PuPH5b5yOv\nWUn65s8eStVnYO//799b0TcnDyth1snHLmQB/vC+tZdDhjBeXLQXIIswAn7rM0imdTbva2PBioNc\nOXdkVj2CYE1Qu2Kp47qHDJJpnQ83HebDTYc/EWE8+NImAM6ZNoiBFb0TEHrx15UHs9tUdXyKxKHm\nKFsPdADQHUtTkh9g4942Nu5ty3ovAH7+wkfohsmmvW0ArNhSzBmTq4+5v31hS9t2/rDzFQKjAyTW\nn+4cz2gY0YTK+x/V8/5H9VnCcEf7bl5eu5Y9eytpj2zmrmumsmhDA2xoOCJh7O7Yy0MLt2AmXOGv\n2xrGa3sXAlAfPcyIIiur73vr6li3s4Wt+9t5+DZ3hX3mmxxVXcjanS0E/DLLNh3GBEYMKCQUkFm9\nby8fRRfxmytvQBazReob+94iko7y7lsuIb+5+iCD+iusDz7HmlWDKY9PpaE1RuHivXz9vBqOhuaO\nBIs3NLB4Q0MvwvjtX7axo7YT5QN3/5JoQiWaUBH8cUw+/beYs0f8A8HRMHqYpOJHMLX01DAy139W\n+KyTCLR19y2sNUPrs62jte89F/Rnf+R9jYMsiUc890lgfMoxMY7R3OhFMq2zs30Pf2n4MwhW/6MJ\nFekomkrGrCnkdYKkEgpkz3y70xHePvgBunHs47Hy8DrrHymddTyaUDFMgx1te+hrHvzrDY9TL68H\nOUVje5y0mmnT4L9W3c/Le/6SVd4wDR5Y/yiBccuy6lM1g7Tumr/SejrrnFjYgjFkFRtatvTqQyRu\nXdfUHndqTKQ1UqqO/+RlyFUH2dDcO1/YmwfeY1nDqqx+mCbsi1i5vuSqg6iabtenO/3f3bG3z/e3\nuaNvDVc3dGqL30Cu3oUiZYt1sbCFwIQlJPx1fV77SZAjjL8RaJ+BQyp5BOdW0kMY3pdQ7aFhpNTP\nmjA+m3qCfstk1BntTRgN0UZuXfRvvFu7uNe5tOreX8/xTXvuvafZrOe4eMt8WsI4VptyR7KTrW07\nevUn84w/CZIpjf/Z8Bh7YzsR861913XDxMRECHXjG7GepfUre10nFjcSGLsSZeBOeloWf7n2N7y2\ndyEfNW/qdd2KhjV9Po+D3XZG3VS2hhRNqLy463We3v00UqmV+TcSt4R5d9p1Uguyagl9ezzEgg4a\nYo28V7skS1uOqwn3HvLbnf9VzeBQpN75HVOzha/cfy9SUSsv7sqO4gKIJlUEX5zD7VGHaFXNQNMM\nBNF+pj0GyUtOSO5zkyWRjnRH1m9w39E/7XqNB9Y/yrrmjb364TWJxpJu/S2JNkx/DKX/Pgrzss2p\nUqG1GDoldvWq75MiRxh/IzheQaRqOnsbupw6hGCEpsKlJDUrpfXuuk7W7Gy2CitJNh5yzRVb9reh\n6Qa76zqtvReSfQuj7liaw22xY+pPbVME5DRCsBvdMDnUHKUj0lvQ17VE+WhXizO7OhqKwn7kfnvZ\nl9rE4bYYLZ2uQFh52MrY+vq+N3td53Wm9iTDlGe89R47+Xn71BlN0dQedz7mhJo9O/44dMfTNLS6\nY9fXxKC2KdKLDH+8/B4e3vgknamurPtNqTrN8RY6U9kf//aWvexv7OxTgzkcdVOaC7Lb/2hCQ648\niFTSxO93vuwcb+uy3h2l2kojLxa0O+S7pXU7d3z4/9GWtASxZljvzKHmKA2tMZYcWM/zO17klT0L\n2LinlbSqs35XC3/e9i4dKYusDDF7DOtaoyypt7bHFQKWQFy5rQnTNNnZ7ppXBMW6LqNhSCWufyLz\n/bR0Jlm5y33HMwQJ1thn+g0QVd3nEjHakeyykXSUVdsbWbuj2Tm/uHYVgYlLSBbudrSvnYc62RV3\ntZGE5qaRB9hS56ac9457Q2uMpqSVHNLURYcwDhyO0NAa48N6a0vdPa3ZGoGmG/z10EL8J60ATJra\nrfdibeP6LHI50Nht/aOk8I34CKnEakslu3/Hg5wP428ExzNzBPjv59exYvNhbr9yIsm0hlxeRzJY\nx57O/QwMDONnz3/klA1OWsTjexYB5wHw4gd7WbDccgJ/4/yarFl3Bg3RRn76wROk9kzgkVvOR5Gl\nI/ZFNwx+8tQafKM2IxW1cChaw2//ZM0Ye9rI//2J1QB8/bzRnD7xyNuJAsgSyAP20KqH+NHjpVn1\ntdtCqMTf2+HoJYxkSifPY1ZJpjUK7JmYN6pMHrCLneHFrG5UmV41mcde30pDWxzdMJEqD9A+eAeH\nIoMZmJ9tO09oSVJ6iiJ/YdbxH/5mGZpu8ugPT0eRxSzCeLd2MZXBCn711OGse1IN913oSnXT2elu\n5BRPqfxy3aP0z6vi+5NvBGB72y4e2vxbzHSA6092d//LoDHu7gsu+JKOP+ZQpA653J1xx9UEISXI\n7f+7HDAQApZANVNB0jaJvrxnQZagTWgJuqIp7n56Dbphogzajmwnb3jwlbV8oWYgK3cdJDhpkdsH\nOTsqavPBRoKZva1MS3j+/t3dHPQvYX2LZ5ZtC92uWBowkIotwhBMKWvC9YclW/DXZNpyiVjVDKIe\noo157mOT9h7YCoJu6jz6149AzaToN2kOWxqYWNgKh4cBsC+6C7/hrm6PezSWw20xHv7ravwZ/7Os\ngt2VNTua8U9oQrSXnkiS1XBbd5IfP7GcoL3X1vvrGri8xnAIZdX2w6jFe61ZvpKmuSNOVbnCU9t+\nnzWeO2o7AAGl/16kEpf0eo778SCnYXzO2Nd1gOe2/4l46thnrpqh8dy2P7Gv6wArNlvCpqUraWkY\n9kee0JJEssIVvTNPj/qe0kDUeLnlCbZE1/Zq65FNTyHkdSH320c0cXRSy3y0UpE1o13ZvKJXmc5U\nF49vfg4hYKX7jnn6uKN9Ny9sf7GXXTxhRhBEE1Po3X5H0iKM4kBvwuiMuTPzbEI2Wd+yyTGneM1E\nctVBNCnBxpatmKbJwaYI3bE0CAa+wZaJ6EB3tlO5NdHGv3z4E/5r1f29+q77u5BKGhytRbXbEnwJ\nXtmzgEc2P9Wr33s79zv/d6W6s3xQjYkmYmqcxpgrCFoSrXadSZ7Y9Thi2DV3AFnaiOBLUhS2JNVH\nvJJVzjv7FgJxx8IiyKqjYfQPV2VdE1XjtEdSrs/D7wpNQVGpa445moFzXNJBcN9B7+wbyX0fssgC\nMIwonQdWEImriMXNCIodXSXoJNLudVnt2f/Xr36SSDRKV6o7q+8ZJOjEiBWgHh5i1eFzZ+OZb6on\npEIrIEA9NArINnF1x9IIfvf9y+qToCP4rHOCZCDJ7jsj5nVnXdNqa3sHuw+xtMPdbEzwWd93T7Oa\n1THrfTH17MndsEFH36PmWJAjjM8Z9617mJWH17KtfYdz7OOcxR81b2Jl41ruW/ewcywvoPQgjESW\nDd/7ASBlCzUx3ElajLBLX96rrbakJXxM1Z8l3PtCMmU77aJWyGRnuqNXmUc2PsWGls3I/Syh6J3d\n/3rD4yw/vIZaj50ZIClYH5Ep9m6/w+6fX8peKdyZ6uKJgw+gDLYitrwzUKmsntfrXubJLVYKbs0x\nSZkgWuW60910x9Ik7HvyCkJnKmpjY8tWDNMgriWIqO6+F2lVJ3DycnwjNtEcs4SLrhuI4Q4CE10b\nf0Z4uPW5Zo7udMTtu6RSn7DGLaJGHRu5LGY7pMViV6MQi5pZ3f2+p62ko1n1RGui3Xaqm/hGeyYP\nctrRMCQhW2TE1FiWJpcxKYFFBLGkmkUIercdLeSx6XuFqSBpDKrsOwrM0KN0HlxBJJ7GN3gbpiFg\nJKwFfwnNHXdBSXn+t+oeMP2byEowizAyGoZmaJiCjqkpmGlLqGa+F8GXIDB+qVufl0jsdvRua11O\nTHPvfWf3dnxDtmeNhXeMvO4OFY/PJeya0JDTNLZbdf732oeoVbdl9SMaV4lrvQnDuX+jhzVA+vQa\nRs4k9TeCtMdubpgm0lHWL/S0lYJlDoqnUwglCadMWnDr9M6SBEnD1F0hk0UmR4IufewCq2TGTyBb\nwiClZ9vldUPnUNQyUZmqJeD7cgKrRnY7qhSxZjaiYc1MbbOFZmh02U5Rb8QLQF2kAR0NufIQWtNg\nkmndIWKxwBLeJiYJLen2QUk5H3JXqtv5WCF7jBJq9nh5TRGdqS7HLNXs8T3s7NjDkNIqVN1ALGrO\nul4saEdvtcxyHzVvYkm9q5l1pSOItnYUnPIe2zy32ZnqpCJUTlLLJhzRQ27+UR9lnUNJUZTn46BH\ny0zvOxnfsC00x1qJhVWEvG5Ev0cwejSMuN3WndO+z71rHiCmxok6zl0zezYuq0SjKkKBdX6wOpO9\nyUNQ0IEgq5iaDyGvC/+Y1VltDanKp7ap94ZTzWtXoMbbePBn38McahKuGknzijUopSbJg1sYdOoP\nqV/zDIZQj2molM0cSPFJFqHse+9nbJ/4bbbVr2bfcxvJG1TE/ob1bB24ih/ffbfVgC5jpizzX6xp\nOy2rX8IUUygb0wy+fCxSyI9pxmnc+CeSnXWIqxNUnjEIfypE9+42/vT4U7znf5WioiLU+WEaP9iP\n5JMoP2UQgpzmwOL7GTD9GyhVm9jx65WEhxQTr+vGd9ZYmjavJdl5COQIReNLqDpjKIKSYv2GjTx6\n3wvsbtmHKIsMu2Yi+5/fSMXkcqLJycRUld2/XUf1haMJ2kQrKGnMZLYJSjDFLBPc8SJHGJ8Cj72+\nleqKMBd8YfDHll27o5n31tXx/740AZ/S2w+wfHMTYAmaX/xuPd+/YgKvr9iNIMCX5libSv158V40\n3aB0RG/BrWoGcbPLEXhr9zTQb6ibF0kIel4WMVvD8M4KAXx2qoy31rqmEUSjF2F8uLGBrQfa+fb8\nsTzxl+2I9uRTsGePqulKt39/YhVCqAsyoeO2SSKjYUTSroCIqjF+984u8kMKF8wcjOGLuqqwpIHm\n44/v70aVXVNLysgmjPakO1MT/AmSac1pSwy5kTc/enoRna3WjNtLCm2JLnYeyq7D6V86gW4YPPDi\nJmaMqaRRd/vRmewCe01aoydQYE/3Ps5ltkVOerZGkJkRNnXEeXbVe5AHp/afyYcNK+hOdRNIZZtw\nMuhIdlERKidhE7PW2h9faTOGp6+mJiPYBB4S84kF4hT4fM4MX2+vwIhZ791bm7YxavoUZzast1eC\npCEVtpHUUqRUnX2NbQiKyLLVUZIb57DCFFlt2DNpSSO1cY6nbQXDMKCznOSGOew2/OjmaLSG4Zj2\n7NdMB0husK6RShoRC1sZXFUAG9090zOonD2BZFOCC274Aet5lc51PhLNbQz68jTMhjkYUeh/6lxC\n47dgqDq7H11HwUjL1yQIAjvTFjGl2xMMu2ISgeIyzPc6efbFV2EwmLqMmbSEbv6IPIK+m5FKGoh0\nvUnj2+30O30cTWtWIPoGMmTObfgnLEZPpkltTlP3wQ5GXDeFSmEW3z3lbH665efW/dn3KfhS9t8E\nUkE7qbYEA848jeoL06jpA5TVnIevsgNl2Eb2PrmZxKgkgao2nv3lvdzz01/wh8hLpFIJ9JbhlExp\noWvPFnYWnsaqvRswdYNgZRhTlxAkHd+odSTXneU8Y/XQSIoGtmWZ4I4XOcI4TpimycptTbCt6ZgI\n4+FXLTPDtgMdWautM2iLxsgQxu66LhZtqGeJ8RSmIfIl7sU0TRassGznlw/pHVmk6gYJwRVcta0d\nxPpZAr60wE+XR033mgPAQLRnhULm5bZZ509LNxPIbJ8t6r0I46mFlhlt3qwh2StqbdVX8xBGXUsM\nqbKOjDFEsEkrE6FUH3UFRDQd4911ll3+zMnV2dqRqGHi463Vh5BKDuMbYR1P6T0JwzWHCZJGMq1b\nhCEYCEGXnLrT3UAZYKIM3OVWIBi8unwXYAn3LA0jneJwa5yt+9vZur+dorFNYKdC6vD4Cxq6XZ9A\nl+2c13TDGZ9Qqpq4v85x5m7Z105STyIBs8pO48OGFXSlIwiq3uOZ2fdo15mJiNOaBpNXHEf3R7F8\nVgJmOoAgW/dbKJYTV/bhl9IOqZu6jJkIY2oKCV8zb66qdWamelcpYoF1D3E1wZZ97cTVJAISopKZ\nHRhoug6IzkTE1BS7jky4qfXXMHCseYKoW85eUUQ3PSHO4U5KigVkScA0rUhVpWkcauVmRJsIu9Jd\n4ANT9REor8BXFCDVbOeSalzHzqUbQVNIR1Ko3d3IFbVWqpRoAUIgSaikADlcjq5ECBcNYMee/UiD\nsTSMZAgVe2coAAAgAElEQVQjGUIz66lb9Ti61g5SAtlfweDScvbu66BqwjzARPSlkM0COjtqCQ8p\nwlcUoN3YwMrdE62bMUFrGoSpg1RxyHr37O/QVxBEjp2BEVuBmNdN5PA6utd/CIKK2mmS6ugCOQ+l\nSCQiFqGhQrICvWUoRSfvo+n9j6gfGKH90DpKJvWjguHUbqsiMG4ZgqRbwQ32M9ZaqvEPitOuu+/j\n8SLnwzhO9Ey/cawwj7RAX8wWCM2dthAXrY/Jigyx0BrrnUBN1QxU0bXPIqnUtViC4sJThhIMecxT\nGR+GqBOY/D5SiW3zNi3C0DQDwzSz7bWSSxh7Ovez6NAypLI6lOEbaenwmEQE3YlL1/EIDUCyTUHW\nj8yqdOt8c9wN/Yx6/AD3b3jIcS5a11njJPgS+Ea4TtGeJikvYZAhDN1AGbTD0sJMwa7Hnvn540i2\ncDRVW+OwP24hrwtlwF6nuriazHqKpmchmtdG3p5w+xDXrdmdphvOhyx1DrTqtwV0bVPENhdKiIYP\nRZTpSnWRTGsIHvtzRdCacOzptBZ/ZQgDTYZ0yNIoMuYIu63/N/kmCgQr3Uta6XDMhugKgyrzMbqL\nEf0JyyeUaUtXQLPGIqHHiSbSlilJV/jSmSMITFxMYOJifKOsSKHAycvwjV6LGIoQmLgYuf9+EHQC\n45ZaPht/FLGgzbnu9usHMe2sJue3MmgngqTzSsPvqShREATQO8vpPjgAvbMcMa8bQdSJaDYpaz5E\nyeqff9RHjCnvJNVcx8hvT2XIWTcQrCjC0HR8Q2zbv01oAbEAIxFGEE1SZoqUaj3nYZWlgIDRVUb9\nwu0U10xi+BWXUX3haEzN4LSRYyw3V7iLqgofCAZjq/sxwJM4URAN3k09ab0XyUICYhCtfgSCaODz\nPBdBCFptRYtIdyTorHuHEdePY+T1c8grr0HrtFaoC5LGa13/a9Wn+UD1kxcIkz+8hGjzRiJ12yke\nV4nZWo2ZyEdtsFauC/6EO8nQFQrkEnxS376rT4IcYRwn+lrcdSS0JToITPwAsbAl67jXuS30cEQf\naM+2cze1xwETqayO2m7LKWwa7uPTdANN8URYSBqHbDuwTxadWSzgkJPgSzjmCsg4lU10w6S1K5nt\n2/BoGL/66H95cfdr+IZtQS49TG17b4FuNYBr/pJTiIWtGCnbqZjRMOwxaEm4dURSrkbRmMg2TWSE\nrVdLKPIX9iKMTMy/dY1KMq2hGabjP0jvH2t1K2Bf57H3aq22GcMmE9leTJZBQku6z05Oo/pcYsi0\na5om27QPneMZ56Smm05biYjlx8kQxsHGiDV+ukxKMygJFNOe7LSc3vZzUvQwd07/PmXBUtY1bSCu\nJhyflqnL6HEr/YMYtLQMQUkR0ssYUTSUAsEimqjQ7BCQqcvkh3yOWapN2YNUbI2RqfkcX1Pc7Ka2\nvc0aE122khp2WvZFQUkC1n2Zmg9Ts7UyOY3cf58zjqbmw+ioRK21Yk3rY41Zi+wyaEk24Q+a9jWW\nEURrGYDokzH0JHHDevam6sfUPUYSqRMpKCPKIsm2CPGG7KALQdIsTSqlOSlDkkYcVbfGNuy3828l\n8zBSOr5CGUFWad/QCKbA2LIa8keU0LlvFSm/pVH7NR9SuJrI7iRp22el2d+JEixGSDRhJMLEGyJE\nOzsQ82yys18fMxVET+mIPgnRL5FuDBJr2YnePJBgXn/UaJp4g/Vd63ER0zQZVjCUkmkVtGx/hdCA\nAqSgQn2T6tQHOBqGaYhgikzNP527Z97Za6w/KXImqeNEz7QaALXddeyvTzBr5HAUWcQwTdbvamWX\nsRTBl8I3YgO6foZTPsuM4o0aCUY4rLaR8XRs3NtMR0RFzG/HN2wLhzPWJU8UxKY9bZgFUTAES1BL\nGrV1liaiyFLWLNgRut64bDUAStIS8IZMXXO0F2HEEmqfC+0aOj3CWe5hOpE0MGTE/A4E0URrHYA4\nYK9DJJpmsGxLA+vbXHNQdyoKFPRpt3c0DJsAvzr6Mt7Zv4yo1sq2/W3UHe6iJXWY/V3uoilkjb31\n3fQvy0OQVYxYAUbMcjTkF+qk6tyxUA+NtGZyeKJNMuTWVerY8zNrVvwnrcAUVcsMI2mOZlPb1kZM\nsEjQSAVI+ZNEEylWbWtyxj/SKRMwcQiktjlKoFrDTPtZta2J4tJimuIttGh1jqCRuwdxoD7G2KKx\nLD68hHV1u0lmggt0mURHPkopSOWHEIuaEUSTZERm8YZ65GQZpilRl96NMtg1H+UXKZjddhK+4u3O\ne2dqCoYdAXQgtpe60IcgWOGaq3c0kd49Cd+kd0FWEfwJS3NTfZjpgKUlFbZhJj35izQFENC7SlGA\nDc2b2W1rST1hZhb3ZQhB9SGHFEJVFax78vcU1BTiZwiioZA+WINv8A4aBx4kujRG08OrkcVWAkWD\n7f6bCIL1jI2khKabTnRVVypCLKVRSMhJRGimglSeMYT6hS+jbA4QGlhAyhQp8hcy8NRx7PvrWjb+\n+SEQBUZdWU1Mq6by5MvZ/7s/ASZyno/h10wkXF1Dqm45+xc8R94IEX9pyF33YgduGKkgwUFhgv3C\n7HxoNUqgjGDxUECkIDKRwZcdpv4v2zE0A8HMY8DkGkYVj2Br/21IAZGSSVY+MkO1EwymLcKQ++2z\nfJP2+IV8PvJ9R89BdizIEcZxomfiPt3Q+fna/wFg+66vc+NFJ7N002GeXriD/hNbwQcYEgnPegBv\nDLXgIQwrB46LB1/5iEElJb3CYREM8kM+IvE0O+vbCFRGMOMFVtiepNFt57/xKyKm5PVhZMI0M06x\nUYihbqTSRkfAb2/Zh1LtWWVraxjrdmVrSQDN3RHIiJketnZB0jBVV8CbSVt9t4Xwxr1trGhfhNLf\nDQWNZKI5PGYYPVKElN/pjpMtZNvadRpbU0gFGnc89CEgIA/cgdLPQK0bgVK9B0HS2LCnlQ17mwlO\n0zE0BdNelCVlTFKOTV9xZtWCkrQCADIhlgemwIS3SeopVDsiTAzY8fSyiqiFaEtYZq3/+v1ylLGg\nNVdbZORP8r2H3wfVj2+MhmkIFuHrisfcZIKkYep5vLu2Dt+QBFIFNBa/j8+ORu3qMvn579ZTPcKA\nEvjj7j9jyvYs3RTRuotQALnM1czSCZln3twJgDK8nEhpI2JGjusy4aAbTpoFTcFI52OqCnKFu+pY\nDEb57RvbMUxLQImhCIEJS6wuaAoYMnpbP+sa2+8gdw90TJ5mMoxihhyyEBD4cuVNPL/5dacdof9u\n0HA0iAyJV593CmJBO4KskvxoNGd8fQ5rDlkmp6SQoPqrJxEQQ3SsPM1qt3wHcIARX/4acngHkhxm\nyJzbMFWLgMMnjySUCgHbCSlBQMNMhigcX05hjbvZ0ujpVwKQ5ytm0KUnOcdPH3cm+Wsklqo6pefN\n6hGZlMcPf/Rz/ufljwhOfTdraIecfROyJqLb79qgS6w6UzunYHRZ7QbS+YxWvkno238EQG0Yilbn\n46Sykfxpg/Xe5o8occcd8JsWKYghWwuzAz0CviMvuP0kyJmkjhM9NYxme/EUwOrtlkq/z07Z0RW3\nPhrTELPWA8Q0T+RSH07NDARRo7Y548j0HJd0CsLWiyLmdSGIJka0yI6McV9cRRYxPITRU8MwNcX5\nMDPn6hPZi9MyUVKxPhbvtXt8KkLPWO8+2hJN2THBdcfSjo9CShTjl3xE7Ygp12nXH1oH91lfdwQw\n7HmPTUKZ64yovZgv0yfJ7QOaYglYKZFdRpO5cJplMhF8KQZV5CMoKUxdol9xIaYukVCTpDSjlwYk\nqnl0pSOkdRVDsAlS87k+EZs0BVm1Z36C7SBOO/0XBNOZFRqpIL1gR1jV19rCV/aadATQ/Jjp7DUp\nGQIEMDqzd5zLmKRMtTdhmLZGYESzU5+bhuSkIDHV7Igv2bTqccYea11Ov9gpngpERhinuj8xmTVm\nEN+ddhWTS6xlzo3avqz7dbQ+XxLBH8eI53PV2aO5ft4Yrj9nalYf8hU3Q22GCH1DtyJIumvCsutD\ndgMA8ny2hpHuPe43zLOiP4JiftbxylAFXzt7FP9x7TR8PVwERrSIQRVh9/304AeXTWf2uH4Y8QKM\nuFun939ZEqgqDZHaPg0z7XdCrzd+uJbdj6+l6qxhDM0fglQ3BQyZorCP//z66VntZL6THGF8zujp\nw2iI9g4DzJCDKGUMltmE0Z3yOK+dUNc+nOJStiD0oiDfeoQZk4UeKbY+fiXtCEFJBiTVTQbo+BWs\n8yWhsPsh2W2kNDvssqvUmg3bGkYqbc+OPfDaoTMLpRw7dg+NwNR8Vqih2Nu0JR6YSZG/iC579ud1\nzJaGM05AeywywlfzuStae2hOVeHSrD44JKpbgtBn5qEKcfucq2GcNtZyHKKkGFQVRvAnMdMBqsvD\nln9BT6NqRtaCO61xMELamrZ/cOhDN4eRprjCTk4jD9yBGIw54ZboEoIvbZmQQvYCxcysug/CyNj0\nzXQAI9l3umovQQCOoAHQu0t7FBYI+qS+NQy7jxnzHcDs/jMYFHfNqpnV1hmU5RXafXClp5EMUxTO\nlqaFRjUPnH4PY0tr+GrNZUiiyPjhpVQXZffPebaagmlaa1YEAcx4mJrBxUiiyPQRgykLlCCLMpMq\nxnN6v9Pd6+PZAj4gWfdZErLfJ1l1vpOwL5R131747b1A8noQRnmwFJ8iMbgqH9MT7ZXeOx40n5OO\nRu8qRcKtt6a6wtJeDYnUllMY4B9ivZeeZyeJApUlIYxIKckNZzghv+ef/0V+/L8/Y9bsU7l50nVU\nYKXqLykIUBLOo+DwaaS2zbBuJWb199PuZZJBziR1nPDmBDJNk4Zo741aMoudRMkua4hZSe/aeoR+\nWoV7C1JB0qwgyZ4mKSAUykT72I7PVBCjuwQpvwMxvx2js5LX619yzgmBBMqAvWjNAx1tYES/MtbU\ntme1ldRS4AO1biT+0WudKKmkpiKIJnpnGXpbf3zDN3kcxiZyv/2YJugd5cjlDS7ZZcpoCromZt2n\noKQQ1TxSSYHSQDFN8WYQNQTbOW9qClXFhXR67tNLQJkP3Aq59Ttj6RdDCAh9ajkAQTGPLpoAN9QV\nXabAlw+m1a+64GKEtIqp+qgsCUKHhGqmSau6syq3Mj2eA7X9CA6tBaxEiMpAPP2zSF0qbUSusHwr\n7sI46/n5hm51H6pNGHpXH9sVO2s4BFJbZiGGO/HXrHVs8pBNGKmdUzBTHmJRA4wqqGFXtxUSbSbz\nrDDqPoRkpm+GZ0+Jr9RcxsPbNwMtWWOZQUAMuved6U86QFFpNokZpokiynxnwjezjs+pnoUoiAzM\nH8Db+5ew0dGIBNAUh6CMeIEVzAGIgsh/zPwXTNNEEiUi8TTPYq3OHhQeTO0u3VnAWJIXphsYUFrI\nLl1CUNKIpjVpKA4UAJZ14Cv9b+Cpv+5AGbKNEZXucwhIQSezzpSKCUiiO26Z8GC1YRh6mxU4kVnT\nlN45lSvPG41QVks0HUUUxKy8bFcN/xq/fnkzCVxLgCSJR9xPZu6g05g76DT7/oWsv/5UBUY0yuDO\neezYZ1kxPqudNXMaxnHCq2E89PJmlh305sI3eWt1LRv2WGaqhO7amJdsbOAHv1nGDx9exs5GTwqM\noxBG5pwv4LaZSbFQG7TTPkiuQDZsQZMJE21OWmTmnWlKth0YYOzAKlcjsDWE1qgdhaTLljARdeJJ\njZitTZi67NEi3BBOMRTB6C7FiJTY5/oQ1rpkCUw5TSaSRzaDpFSdPMme+fkTjoYRkP2UBSqsuPxQ\nt3WNrWHEYgJkZqGiq2GYhoBf8hOQ/S5ZeUgLICSFARPfiI1O/itTU5BECUH3I/iTHEpbfhwjUoxf\nkZBQMFDpSHXgG24984BZCAjEunrvlmZqiiOwM2QBONpeJmIo65oMKegKqe3TUBuGec65c7zivDyM\n7jLmyNeS2jrTLWObpExDdOzhXnxp6JepqL2MxPrTMZNhR/CqDUPRGgdz57Tvc8fU7znlja5S9O5i\nrj3pq9b9emar6r5xaI2DLU3UhALJ3r5W9RKGv1fK7SPt6xGQA5w9+HRqSkZy3divY6ZDVBRnk5Bp\nWuG2smffB1EQHeHtTTJZXZ6P0VnhaGahgFWmqiRkmQP9CcTCVoRUPpVhV7vpX1COmQ6S3jWF2cXn\nO8eDinsf3zz5qqy+OyHzpquBC4Jgj69AMmUwp3oWXxx2DuDuJW/1J8+a3HggScIxbUCW2aUv8zez\njsqnF4Lmd+r6LJDTMI4TXsLYcOgAgWKPM1jS+OP7vdMyI2pZi98+OlBL5vtyBGsfZqfMMcVnkAZS\n26chV1o+hiitWAI0EyapYCZsf4Q/AZjEtThiogitYTi+/vsxRZ3+VQqqItEFDC4rwUzYaQWCEaev\nAKP6l9AgKhiiRgroiEUhz1LfC0tLaAIKC0UKggaGqNEMmKmQa4bx234a2TaJ6bITIRKc/L6V7E0A\nn2l9GPFuxel7hoiCcoCg7MfszkMq6CAw5V3XBxLVPFFNacyEPV6GjF+WCEgBYpLttLZJZkBxMeWF\n5ZRVRjncsMddh4IrkAv9BXTq1jM1Yvmoh0YhjxbxC3kkpC7WR62QWb2rlCADgE6MVO+Pe9zgKjZu\nVjENwVmfotYPR++01kQU0p90ohgj6AkB9chSI1JKuTIQv5yiRavP0h5GDypi5dYmGpu1LDt55h6O\ntD+SLIkEfYqTjfULY6vYUduBblQwqDLfycT7pTMMJ6tvLDmKaVVDgGx7uJkOotZamQgQNQomlAP1\nWRqGqAcJBrJFzbFsJBXyy5wzbSBD+xVwsDHCYqx3yegu4dSThvQyczntiQLzZg2hrDDA5FHlmILA\nTsqI0UhC7GDK6GmcMq4fi1f5EPyWGfCU6slZRBgOevrr6eqXJp3G3sV7OXf4ab3a/ebYq3h++4uc\nVDwJIRxkUKU1+fnxNVN5fel+5kzMznA8fngZm/a2MbgqH38fPgZZFKksDvGFkyqtRcJHwNfOGYUg\nwFVnj7Lv3zpumCY3XXwym/a0UlHUhz/sOJAjjOOE1+mdlTAMPLmaTKSyekQ79YaXDARfAqmkCcGU\n0ZN+O1bd7B0JBc7MOUMY86aP5C9rTDd1sZy2BbLghiEaohUWKqvopo6oBQCBkeo57PIvJL9fC/u7\nLRNKcSjsONsy2TIzff32vAk8vHk1TVFLW+mIRyAPJg/vx1nDJnL3ync4eVQeqxvfdrprpv0Y3SWY\nhkigooVo/UhL+NtOVG+4rtLvAAB+wRK2sYgCPpswbNJSBL/luI8VIgZjDlmYhkB7JIUpZaKabHVe\n0jB1GZ8iEZQDCJLtRLcjdvoVFfLtOeN4v7Z3csSM9jG8tD/rmi3C0BqHgO5DlkTyKSNBA43GXkxV\nIb1zCsnBlmTuy+cwdXg1GzfWYkRKHOe+1jDcIc1pNRXUh8vYF3P74kSSYZkZ7vnWDDRjCk+/s5kV\nmvuuDanMZ+XWJg4191jImVmUeAT7gSQKWcJRkUW+deHYXuXOmzGoz+sD/iM4UL3OXc//kh7o5XQ1\njmEZkyAIzla5M06qpHHdaHZ27USrH8k3bh5z1GsvPc3Vyu78+jSW7/LxwPpHOGfw6cz4wjgg25x2\n/qjZ+D0+hnDQoyF5CCPPH+Cn59zUZ5tTKicwpXJCr+PV5WG+c8m4XscHVoT5169NOeI9SJKAKAp8\ne/5Y5k6p5r+eW9dnuZKCALdcNt75ndEwTNN6v6bVVByxjU+KnEnqOOENq80IIjFtxzlnVvKWHsY3\nzLPdo3fXrSorT5NsBDET1voAb8RGFjL1KVabhUHLFKE1WpFD/lEfWZu/2AIZBGQziOCP4x+zyuqj\nZs0mFXsmnyELsGbw6ApGKmjnWXKJK98fxC/50AXLfNSVtMgvpAQJyZaAbI67EWJg29AN2TJlKBGU\nYZusML+MIOuR7hogKFpj195qvZK+wdtRBlob+CiCD0US0Vtck1p6/1jSuyfRGU31zjBqL37zySJB\nOWiNn5J0MuTmS5Y5r6iPlOgZQd4vz90vORMlZJgmJZnNHsBe7CaSyKQf13rPeDOOVL3Ds/+y6X52\niizik+2oqGSIacp8x/4N1kxXEAQUScHXY0/monzL1JPZutbnmDhs34PZt1SWJfFTRc18UgeqpId6\nXXM8W9VeUHUuh1/194raOhaMLB7G7JbxTCw52TmWeV/KxGqK/IX4PTnegh5S/LTb6h4vvCY3fx/5\n546ETD6549nO9+PwT0UYa3c0H/POcQBbD7Q75ZfVr+LB9Y+x6NAyHlz/GCnVE29tE0Y6Ypt1MkK/\n15oETyimLTTDTV/AsGeUYjCWlSJEtGc8gqSBnKbbb4UaFoes8hlBKYatqCKvfdtPHoKiWnWCE32h\nkD0LvnnC9c6MxIgUIShppPI6pIJ2ezcw2UkpoAzZSjRtEUbYZxGGX/JxwEM+Vr9sG3pmEZG9JsBU\nA4693IuS5BiqJWvG2NSH5q0IPkvDiJSSPlhDet/J6C0DMboqSKV1d92Eo6VpmJpHwxCsPToE0UBr\nGkSRYtmqi/wFWe3oHRVkhG2VhzAymkNbV5Iqv0tahm3Gc/ercG1AWtMgpGgV+b48u+5K+xpXewCL\nMBz7sikwrGBYFqH0lagyg4BPptJj4y7Jt94Hvd0itUtGfLHP6yRJ+FSE8UmEF1ihtgGlp4bxyYWZ\npBl07l7/8QWPgJdf+hOplKvdavUj0CNFTAudC1imrAwET7boz5IvdP3Yd9b07rkuHmX/9Z5wNYwc\nYRw3uqIpHn51Cz96fNUxlVc1nfv+sMEp/7udf2ZXxx5e3P0auzr20Kl6NqXxxzFNHD+AQxR9PS8x\n21fR1OheJwSijrlletVkTi+4xDouq/g9+xP0K7EEXc9QSO8aiJ6z0YyJQBEUfPb+CacNmMWY0lFO\nkYxQcyN2rBfvi0MtJ51cUYdUbS0AG5jfD0mUmFHVW6XOxPR7yUnvLkatHU1Bno/UzsluWU1mgDrN\nEaz0CAkFCAhhZJto9KYh6K3V2e1lNAwlZa9lsO7Xp9gaBm5Kc729Etk28hb63N3xUtunkd7t9qsq\nz1Lj8+QQX5w5BIAR1YUUhcKOZpcJLvBubas1DsJI5KEeHENB8yn4MpEwqp/k5lmkdkzP6rsiiRT5\n8537KAhlayleQpgwIjtqKuCTGFLlRjBl9pEw4wXcPv7HzB14GtPHWPcxdog7K5dFkZKCQK/6jxVe\nshk+wHoX+9v5lE4aUuwIuglcSGr3RBRJ7mXGOnlYySdu9+knH0GNt1G37H94+GFrkezvfvcc3/rW\n17n22q/y5JOPAZBMJvmXf/k+3/jGV7nmmitZuHAhL730B1pbW7jllhu59VbLpKS39yO9/QsUBQp4\n+unf8q1vXcOBxffTtOnPAEwaWUY61sbzj/yEa6/9KtdddzUNDVagygsvPMM111zJN77xVR599DcA\n3HLLDezcaUWfdXV1csUV8wFYuPAv3HXXndxxx//jtttuIZFIcOut3+G6667mmmu+wtKl1t4oM0+u\npLtuHQcW/4qDSx5g6YLHiMfjXHHFRYSD1viVF4hcccX8oxKP6DFJfdb4p/FhJNVjZ3bg43eX09zw\nN9GfsNIhZGyits9h6thiNtmLuX1qCWmlHUHWkIUApcUSHYaAronki4VowMk1QbbssGZANcUj6U5b\nAk0IRrN24qosCvOf189gb9ce/njQTcAnKCrhoEI0oeLvQRhCOjOzFcj35dOWbGdcmWsHlkQBvbOc\nfMqJ2CGTGfIaXjSEfKmQiN6F6E+iNQxl9BlW7Pe4spOy9m8A+NcvzaIsXMi6Vh8v77MIRj00GjNW\nREF/H60NFWiNg5GrDmJqlvbgmizcmZR6aBR6exWhYWEUj3r+X9+awdqdLbyyxNK4fvSVmfxqxyLL\nxODJwqrIEtjpHjIRY0a8wJnRF/o9C6aS2TP/fnmVfHvcNQzKH0CRv5DZ4/tRWRwimlBRa2vQmgc6\ncfEZk9SVZ47gD+5eRUii4BAdgJnI1mgAZFnk4hEXsHhjA2r9CAKTJX51y2zASo8+sMJN5zBxRBnX\nnl/D03aW4IBP4tLThjFpZBmSKFKc73cWjQZ9fgRB4LovnsQVp49g+ZbDbD1g+UkkSeC8GYMYNbDo\nmKJweiLgk5EH7kAqaSSWH6BqqDWrrTIMXm1bReUpJqYJO9M6yiCVuLiLx3Z/gH+CvTo56OPt6Cre\n7rFf16SKcVw6Yt4R273pplvYu28vj//2BYJ+mTVrVlJXV8vjjz+LaZrcccdtbNy4gc7OdsrKyvnF\nLx6wxiIoMHWqyR//+Ht+/etHKSjIfg6yJHDZZV/m2muvJxJP88tf3M3y5Uu56eJZfLTgl3ztm9cx\ne/YcVFXFMAxWrlzO0qVLePzxZ/H5fEQivZOBWnDf5a1bN/Pss38kHA5jGAY/+9kvCYVCdHV1csMN\n32D27DlMGGDyx9YVjD3rFmKqzIzRhYRCISZPnsKm9av4+Y0z+fD9vzLo9LlI0pG1vIwyciI0jH8a\nwhCOsiFRX/i4zYISzqY1JigppEQxWo/V0uEwELdCJwdXF3KQ1Yj57ZSZZdZCsbi12tcvBtEAjZSb\nUVP2kySIqUvZ2VptDCjLo6L4JA7qk1m2qd6J9Mns+RASLPu8IsrcOP4b/O/WRkBFECxhmNbTjCwe\n7tTnUyQSKZPx+kWYFftY3vZ+VnsFvgIiiS7MtB+haYwznhUhd9Z77uAzKQ4UMaLKmtUWxz0fph1m\nmZlBqw3DEQJx1EOjUEb3bVM3U0HMVAhZErPiyPuV5hFQ3FTNAyvCBHblEVNSrjlQl/ErIrKc0T7S\nVuJDXXFsw7Loef378D9MKHcdwZXFlmDND1p+ogxZgDXmfkVi3PBS/uCJjpMlIYvo+oIii4SUIOoB\nq62AX3JCUHuGogIMKHeJLeCT8CkSowdZ2oNXQGSISpFFSgsDTsglWEQmCAKjBvbhwzkGeLUFr9lE\nsl7OvssAACAASURBVO9VEAQrd1MmlTnZ35/8KUI8JVEg6Lee2+rVq1izZjXf/OZVmKZJIpGkrq6W\n8eMn8pvfPMgjjzzEzJmzOeusU0kkbN9cH2q/KAqsW7ea3/3uOVKpJJFIhFEjRzJx4mS6OtuYPdva\nr0NRrDFcu3Y1X/zihfjspd35+fm96uyJadNmEA5b74xhGDz66ENs2LAeURRobW2ho6OdDRvWcdbc\ns9mWCoGaxh+w3rl58y7id797jtmz5/D22wu4444fH7Utx8T88cP5ifFPQxifFD0JwycqpD07waV0\n2xYqWInNArKPZIYwZBVREBDkTNK6MvoNHsRBczVSSSO0jUJHdXwOQTFIDEgZSQT7W/RLfuv1TuYh\n2NpFdbg/pw1w4+0VSeG2U77FB6/+GZ+S4ltf+CK/XW9F0fQXxnDR5AmUBIooCRQjmNZaDEEQuOak\nK1ENFcUjMH2ySCJl7fw3INhjNTBWfDxYkSU+z4KjYr8rdMaW1jC8aIjzO19xhWrGz+Bsh6r5SO+y\nzFmyLGaFFRrRAsRwt+MjUGSxVyoW78zdp0j4hSBxOe5ESpmqgiKJ+GSP2c7ug1fI5UtFdGtdWX6D\noyEv2HutBVhC0DsuVjtiVj/7Qk9C+Tj/gN/TRk9HcpZQFrPr9fb7k06eevVBkdAO1aAdquE/7zzz\niOUWbajn2Td30q8qn3/56iS+c7+Vb+rWG77gEPCngWmaXH31tcyff0mvc0888TwrVizj0UcfYteu\nzVxxxdVHrEfXNO6//xc8+eTzlJWV8+STj5FOW0EeR2oXeo+hJEnOam/rehfBoGuefeedN+ns7OSp\np15AFC0TUyqVdgg/U3OG/8eNm0Bj48/ZsOEjDMNg6NBhHA0nUsP4p/Fh6H1klz0avPtX17VEScaz\nBUU0kw7DdmIHFZ8zSxWVNGWFAYdUTE2hUClCNoIIgTiabpA0kgiGveLYFmppM+msp/BLfnyymJX+\n4bwhczllwIzenTVk0tu/wJTKCc5MMi/gY0TRUEoC1uwzbM/sg36JkBKksIfDt7TQ9jvIEiXB3qYT\nE3vmbkieaByyVro6foi+ftuJ5zIC0jtTVmSRoEf4pXZOJbVthpOCWpFEZ9V8xmneU9AGpCCC5K6+\nNhNhEARnbMFNV+FdxPSV6utJrj271/0eCflHIAxJElGU7D7JkpBFTn2h5wrcjyvvHfujOa5TPUyw\n4SP0+0Qi0z9ZErOIsCeZHStCoRDxuJuwc8aML7BgweskEta3aM3UO2htbcXv93POOefxla98jW3b\nttnX5xGL9Q56EdEQBCgoKCQej7No0XtO+YqKSj78cBH8/+3deXxU1f038M+9d2Yy2ReyEjBCEAWM\nAsomNMgiQcKSFKIsVm1Q3BGiCNIifUqr/YHlKTwqlmKlVV7Sal36M6htQUULYl0ALaKCYkggC4Ts\nyyz3PH/cmTuZbDMJmSQz+bxfr76aO3MzOXNk7ne+59zzPQCsVisaGxswdqz2d50T6FVV2he6pKRk\nHD+u/a133/1Xi7/jVFNTg+joGMiyjM8++wTFxdpNIddcMxbvvvsv2B03ljTUu9qakTELv/jFz5CZ\nOddjPzm/HIQEdf1/8z6TYXT0rozqJgHjnY8L9LBvLUqFMfkkyqqrAARr+0xD+7Y/69qh2Ft3CCkD\nTZgWNxifWBxbVzrWBBgagmBVamG122GxW2CUItEAINhkhFkxo1GthxxdqxW6C43HgCtM2HcmAWeh\nZQfNL/JOt994hT7xuiLnKrzz8WnccO1At3Puy74Sez76AZlt7A54X3Ya3vjwe/w4fTAMBoGYE8lI\n6+cakokLjcF31d9DrQ9zK2kAuLKv5uWTw4zux4C2O194iBGzr7sUj2w76Og7GUMHRmHssHiYTQom\npfXHZ9+U4e2PtbuvDAYZE0Yk4lRxNW64Vpvwbn6hHRAdjeLSH/RbcZ3lLNwDhpZhNL1gpQ2Kw/Uj\nB8JkkPGP/zQpid6GfpFmTB2djJIL9YgMNeHAl45V9HbV7ds/oF38w4KNuHHcJahtsKG8ugH9Isw4\nXlDh2N/EFfhWLRyJ4wUViPOwwKpp37dW7mHdbdfi0LESDLs0BufPu/YM6cqAMbh/BDLGDsS1l7d/\nf78zAzIoknv208kyFRERkUhLuxq33bYQ48Zdh3vvXY5Tp07h7rt/CkALKOvWbUBh4Wk8/fQWyLIE\ng8GIX/96AwBg7twsPPzwcsTGxmHLlm149JbROPhlMa4dMRBz5mTj1ltvRlJSfwwb5vp3//Of/x9s\n2vQ4duz4PYxGIzZs+A3GjZuAEye+wdKlt8JkMmL8+IlYtuxeLFq0BOvWPYp33nkL11wzps33MWPG\nTKxenYc777wVQ4ZcjpQUrXbZoEGDceutufjt//t/UIUEUTAE98y/1vE7N2LHjmcxffoMj/20cNpl\nMBpkZE1qPxPpjD4TMOwdDBhNh6TCgo1AowrVUacJySfRqDrKYjvmHAySAbPHXIG97wNh4SrGDkvA\nB582aIvpVAVGgwwDTIChEjZo6WqwYkY1AHOQAaHGYJxvOA85SCuJ7RwCmjUyDc99qU1sR5paHytN\nv9p1335yXBhyM1suakrqF4qlmcNbPO4UHR6E22+8Qj/eMOVBt+dzhs7FocNVsBamICjW/QP/iwmr\nUdFYqd+R5KSViwaCJDOcMz4hZgNuv9G9fUaDjBCzAXfPc90jP2RApB4wjI45jFszXGU0mo+qhAe5\nByfn4rembWotw5BlCbdmXI7PvynzKmBIkoRbZriX8zjwZTEaLPYWGYZzTD9nyhC3x785XYHf7NJq\nGzkvnsMujcGwSz3fOdS0nERrQ0uDkiIwKCmixW2YXRkwZEnCzVMv83hecJMMo6mLmcN47LENbsc5\nOQuRk7PQ7bH+/ZMxdux4/TguLhxlZdWYP/9mzJ9/s/74ZQOicNkAbUj1jjvuxh133N3i7w0YMBBb\ntmxr8fiSJbdhyZLb3B675JJL8ac/vaQfO1/vxhtn48YbXZP5kZFRePbZP7b6/mbOzMS/votCeVUj\nJqa51vwcOfI5rr9+GkJDPe9pERUW1O5n/WIwYLShpt41BilLEiRZdZQB1z54+hyGI8MwyAaYFCOC\nDWZUWbS7Jupt9XoZa5NRhkEKgiQJ2GXt22WoY4evIKOCEGOIXozQVurKDgaEuYJBRBsZRncINgQj\nvOpK1Kv1LdYGRAZFtJr9yJKMX123Fke/rcCfoN3RpLRykfNmYri5pkUcASDM6Brisl+Id5UfaSXD\naHUMv5PXMOeF2K4K/XZGp7aGl5rOJ3h67821ty6jPT0xJBXUZsDoMyPhF8dxyfrd7zbho48O4skn\nt/Rse9CXAkaTPbj3HzmDlIRwpCS2/MYuhMAbJ97BB7YPIIeNhloTjeo6K2CyA6pZ2zcZTSe9HUNS\njrUNEaYIVFq0Mc16W4Ne9MxkUGCStAuW3VAHBUBEkHaRM5sUfW2EWh8CUedaHxAb7PrW2XSSuic4\n747pSOXLaHMUgg1NbkFu5SLqaYiitQtM8zH60CYBo+l6ioSQJsX3HHNMVbUtV5o3L/zmrbYmwYG2\nA0bTi3dHq4h6muNoS2cDzcVoOiTV1MVkGH2B89+i84q1YsWqnmtMM30m1NubFK/Z+dZxvPHh962e\n91Hxp/jn6X2QjFZHZVSgqs4CSCqEKusZhr4VqbPOk+NiHmkKR621DjbVhjpbvV6O2qBISIrSAkGQ\no8rpgOh+UGQJ/WND9Q2Y1NpITBntWk0sSzIWXp6Nm4e2vBOkuzk3u4+NbGXvhHa43XrZygWvMxmG\ncyhh2mhtTsOstFzwd/nAKIQYQ/QsbdwQbf5mUFIrmVonr2FtTYK3J7RJIb6Oftt2ZkfhIZ3LGDob\ncDojMtQEk9G1SND576Z5JkbuenP39J0Mo9mQVFVdy2+ZAHDgzMf6zwYDYAdQWdsAqZ+AJBRXcT/F\nimnXDMC732jrAYyKI8NwLAYrrTsHi90CYXUsvpMkJEVG4kgFYIMFI+PSMOuydMxcriA4yIC3/q39\nK8m6ZjRmDnatvgaAHzW5lbYn/XTWMMwcl4Kkfh27JdKt5EJrAcPTraetPJ+SGI7fPTCp1QvnM3np\nsFhVRDjWMeRdcy++vXASI/pdgYUTbK0Oz3T2M9pehtHWIGjTINGZfQqeWpEOo6HjLX4mL73TmVRn\nBAcZsOme6/Ry448vG49Gq/2ib+vtK3qohFW7+mzAaG1hnipUFNWc0Y8jwhSUAaiu14afTIoBjfq2\nmlaMuiwW755wZhiOW9kM2sX0bK1294yzrpIkAcFG1wRs7ojF2i2pjv8Cd111Gw6e+Q+mX3pdr/1A\nGRTZbeWxtzxlGJ6+ZbeVgUQ0Wdg2Kj4NpxtO49p+Wplqc5M1b0GKCVc6VrV39Vh+cHs1mbz4wHcm\nYISYO/ex7apd1zoivEmpE4Mic/7CC66Pf++LGH0nYNjdO7+2WcAQQuBc3Xk02i0wS2FoEDVw7pVS\n3dAIAwBFcmyPaTVBMjZiQHyYPodhUrTnzAYtQJyt1Uo06HWOmq0JaLp+AQAujbgEl0a0Xk7a3zXN\nMFobjvC0uYs3m78YZAPuGrMEZWVtlWloX2djdHsXYW8+7ryAUlt6X7joSwGjWQH+2gYb7KoKxXFP\n/rbXv8TnZV/ANASIkhJQLGrgqAQAq2prGTDMtQgJUiDJzoDhvgjvrVPawh1nwDAbFZQ3qT/VlzS9\nM6q1DKM3jGl3tAKrU2hwOwHDizGFjlQhpb6hX4QZZRUNCA9ufYOonuTzgLF//348/vjjEEJg/vz5\nWLZsmdvzZ8+exerVq1FdXQ1VVZGXl4fJkyd3eTtau622tsGm1zb65OsyKLGOrVCFNrlrNDqW6jdZ\nawEAsJkgSY6tVx0BI8jgWEltcJ8QXnx9Gi6cicDll0QBF5IAoNUKr4HMLcNoLWB4uGh2xxDd0IFR\nmH1dCkYPbbmlaXuS+oUi+0eD9HpOD908Er/9y2GPv7d68Sh8W1jZar0o6tvumD0c//jPacydeGlP\nN6UFnwYMVVWxYcMG7Ny5E/Hx8ViwYAGmTZuG1FRX0btt27Zh1qxZWLhwIU6ePIk777wT+/bta+dV\nO6fVgFFvRUSIybUK3BEYZKF9iA3O4W5HUHBmGMFKCCwAqi01gKT9jp5hKO4BY2hiIpKHaIHiipjL\nsHrMciSHJnXZ+/IHngKGp3jQHd/BJUnCj9NTPZ/YijkTB+k/jxgUg4HxYThdWtPupOXll0TrQYao\nqZgIs77TYG/j0wHUo0ePIiUlBcnJyTAajcjMzMTevXvdzpEkCTU1WgmDqqoqJCQktPZSF635HAbg\nmvgur3Ls1OYIDJLzVliDM5Boj+sbGtm1eYoaa22TDEP7HXOTDCM18lL0D3Wt1gSAS8IHtJi/CHSe\n5jA8DUn1ghGrDvGz5hJ5zacZRklJCZKSXN+mExIS8MUXX7idc//99yM3NxcvvPACGhoa8Pzzz/uk\nLc3nMABXwCi+4Cho5pjA/vpULUypgKw4Aogji5BaCRh6kHEU12saMG4amtVr73jqTp7u/fc8h+Fn\nfehnzSXylk8DhjeTfvn5+Zg/fz5uv/12HD58GKtWrUJ+fr7H34uL81yDvqngkJYLuyRFQVxcOGzf\nOfZWcFz8nWXHTWatrr/z8eEpcTh9BLhueAr2lR2BMFmR0j8UZwAMHhCLuLhw1BtdpcFTkhIQE9yx\ndnZGR/uiu9VYXcG6aVunXjsQ+z45jauuSEBkWMv/PjdOuBRvHTyFMWlJ6BfZflG+1l6/pxgdE+hG\nk9Kj7ekNfdFbsC+6hk8DRmJiIs6cca1rKCkpQXy8e4XLV155Bc899xwAYOTIkWhsbER5eTliYtov\nxNbR2ycrK+tbPHa2rBplZdU478wwZOdeDY7yHxYLQs1G1Dke7xcWjGcfmoTvq7/HvjKguPw8RgyK\nxJkCQG1UUVZWjfoGV8mKxiqBsprO3ebpLWdhtd6sssJVkrppW5dMG4KbJg+Gpd6CsvqWCylzJg/G\nvOtSoFpsXr3H3tIXNpv276Wx0bt2+0Jv6YvegH3hcrGB06dzGGlpaSgoKEBRUREsFgvy8/Mxbdo0\nt3P69++PAwe0vRpPnjwJi8XiMVh0RmuT3s4hKYujLpHkGJISqgFCAHZhQ2iwUb9LyigbYTIqCHPs\n81BtrYXVsamSwbFwz9xk0tvQw7Wfeou27oKSJMljjaOeqIFERK3z6RVNURSsW7cOubm5EEJgwYIF\nSE1NxdatW5GWloYpU6Zg9erV+PnPf46dO3dClmX8z//8T5e2obSiHmFmY6tzGM7FexbHN0LnXVJQ\nZUBVYBM2hAUbcM6qfft1VkR17vNQY6lBiFFb2e2sJeVcuEcufW2tQfMd04gChc+/AqenpyM9Pd3t\nseXLl+s/p6am4qWXXmr+a13is2/K8NSrXyA4yICMMQNbPF9Tr627cGYYzklvqDIgZNiEHZEhJkj1\n2oK7CMd+FGGOIFFjrdVrSDkDhizJkCAhOaxv3TrbntZKmgeyhJgQnCqu7nCRRqLeLqDHTM5VarfL\n1jfaWpTDBoAaRwFC5/af+qS3UABVhk21Ys7ES1H+5ccoBRDpKCyoyApCDSGottSgzlYPo2xw26ti\n65QnfPiu/E9fyzCW3DAUybGhmOqopEsUKAI6YDTdx7uhlYBRGvI5vjwXAoujUrnzFlmoMoQqw6ra\ncGliBGKLJZSWa3tdOIWZQlFcp9WLuixqsNteFbLE+kBNdWdJ7d4gLFjbgpYo0AT0lc3WZKK7+Q5t\nMDagMeobbDv6vGsOo+mQlKrApmqRpNJShSDF5DY/Ue7YHQ8ALo/unasye4u+lmEQBaqADhhNM4wD\nXxa7PRfUpISPPofhzDCENofhvAOqylKtz184OSe+r44dgWmXuM/RkLu+lmEQBarAHpJqZx/vsHAJ\nztUBFpsKJeYslIhyCFUCIOlDUnbVjhpLLeIj3YvS3X3V7ThR+T0mJ/fe/St6C/YPUWAI6IBhs7e8\nldYpOFi4AobVDtOQIwAASXYEGVWGKlRcaKyAgEC0OdLt9weE98eA8P6+aHbAYYZBFBgCfEiq7QzD\nFOya09DnMJpy1IYqrdP22o4OiuraxvUhnMMgCgyBnWG0MSQlRxfjbJhrzwJLK3dQCUd5kMJqrbRJ\njJkBo7MYMIgCQ4BnGK0PSRmTvnc7Pnu+rsU5wqIVuztZqZ0bzYDRab1hRz0iuniBHTBayTDWLBmN\n/v3Cmj3qfp7JKGP2tcMAACcqTgHgkNTFemB+GtbfPqanm0FEFyGwh6RayTAGxIUhvCwIxU2315bd\nh6TSBvVDapwROAM02LXV4lHNJr2pY0Zd1rGtT4mo9wn4DMOQ9B2MqYfhzCIUWYLavCpcs4DRaLMj\nxuzaPlOChBCDd/sxEBEFqoAOGFa7DcaB38DQrxgwaHWjFEVCjbXG7TypWcCwWlW3gBFiCGa5DyLq\n8wL6KlituDZvks11SEkIhyJLqLI020xFsQNCm5iV6qKxaPplWikQx94WIUZmF0REAT2H0ShX6T9L\nQXVY/9MxsNitqLc1wCAZYLUCksGmD0n1D03E6uuX6xsfRZjC0FDfoO95QUTUlwV0hmEXrm0/JbN2\n62xFYyUAYHjUlbCVpGjPKTZAEgg3hbntkucsNhiscF8DIqLADhiSVf9ZCnIGjAoAQIw5ErBrq7kl\ng3ObVfeES5G051Vw6zQiosAOGHAFDNlchxMV3+O7ygIAQGxIDITq2C/aETCMzQOGrD1vd5Q5JyLq\nywJ6DkN1ZBjCrkAOq8T//Wyb/lxCWAx+lDYQh6q/ajPDMEjasV20XcSQiKivCOgMwxkw1NqIFs9F\nm6NwSZxjMZ7SesBw1o+KbLYXBhFRXxTgGYYNQgCiLhyIuOD2XHRQJIIUbRclSR+SMrqdM/+yOTAb\nzMhImdo9DSYi6sUCNmD885PTsAkLJFWB2hCqPy5Bwk1D58FsMMNs1AKEpGhzFAbHnIWT2WDG/Mvm\ndF+jiYh6sYAdknrpX98Cih2SasSg8MH64z8fl4f0AdcBAMwGxz6tbWQYRETkErABA3BkDnYFa3Mm\n64/FmGP0n4ONQY7zHHMYknuGQURELgE7JAUAUGxQHftaPDZ+FSobK2FSXFlEkJ5haENSzDCIiNoW\nwAFDhSSrUG1a1pAQEoeEEPcS284AoWcYSgB3BxHRRQrcISnHRDbUtoOAHjAcGYZz3QUREbUUsAFD\nMtcDAISl7TpQpmYZRfOV3kRE5BKwAUMO1kqYq3VtL7prPmfRfOEeERG5BG7ACNEChqhvvn+3iyIp\naLr5HjMMIqK2BdwV8u8ffo+TZ6ogBWu76rWXYUiSBKiKtoESmGEQEbUn4K6Qr3/4PQAgaLgNQpVw\n95yr2/8FVWbAICLyQsAOSUFSIUPB2GEJ7Z6mlzgH12EQEbUncAOGrEISXqzcVl1d0LyWFBERuQRu\nwJBUSN68PWYYRERe8XnA2L9/P2bOnImMjAxs37691XP27NmDzMxMzJkzBw8//HCX/F3JywxDbrJY\nj3MYRERt8+kVUlVVbNiwATt37kR8fDwWLFiAadOmITU1VT/nhx9+wI4dO/CXv/wFYWFhKC8v75o/\n7mWGkZoYjZNV2l4ZvK2WiKhtPs0wjh49ipSUFCQnJ8NoNCIzMxN79+51O+evf/0rFi9ejLAwbb1E\nTExMay/VcbI26e2Js2ItwAyDiKg9Pg0YJSUlSEpK0o8TEhJQWlrqds6pU6fw/fffY9GiRVi4cCE+\n+OCDrvnjkgrJi4BhNrgCBjMMIqK2ebxClpSUICGh/VtT2yKaLqNug91uR0FBAXbt2oUzZ85gyZIl\nyM/P1zOOzhGQZAFZ9RwwghRmGERE3vB4hZw/fz5GjRqFxYsXY8KECR168cTERJw5c0Y/LikpQXx8\nvNs5CQkJGDVqFGRZxoABAzBo0CCcOnUKV155ZbuvHRfX9gpuSFqgMshK++cBiC5yPZ8YH6Wt/vYz\nnt5jX8K+cGFfuLAvuobHgLFv3z7s2bMHv/vd77BhwwYsWbIE8+bN8yoDSEtLQ0FBAYqKihAXF4f8\n/Hxs3rzZ7Zzp06cjPz8fWVlZKC8vxw8//ICBAwd6fO2ysuq2n5RU7f+F3P55AFSLK0CcO1fj8e/2\nNnFx4R7fY1/BvnBhX7iwL1wuNnB6DBgmkwlZWVnIysrCZ599hry8PPz2t79FdnY27r33XvTr16/N\n31UUBevWrUNubi6EEFiwYAFSU1OxdetWpKWlYcqUKfjRj36Ef//738jMzISiKHjkkUcQGRnZ6Tck\nSYCQtYCheDPp3WRIioiI2ubVoH1RURF2796NN998ExMmTEBOTg4++ugjLF26FK+//nq7v5ueno70\n9HS3x5YvX+52vGbNGqxZs6aDTW9JCKFVn3VkGLIXGyIFGRgwiIi84fGKevfdd+Obb77BwoUL8eqr\nryI6OhoAMHr0aOzZs8fnDewIu6rNXUgdyDDMzDCIiLziMWDMmzcPM2bMgKK0vPi++eabPmlUZ9nt\njruy9AyjY3dJERFR2zyuw4iMjERdXZ1+XFVVhYMHD/q0UZ3lzDDgyDAMXgQMs6HtLVyJiMjFY8DY\nuHGj2x1RYWFh2Lhxo08b1Vl21XF3lCPDULyYw+CQFBGRdzwGDCGE29oEWZZht9t92qjOcs1haO1T\nOCRFRNRlPAaM0NBQHDlyRD8+cuQIQkJCfNqoztLnMJyT3l4EDO6BQUTkHY9jNqtWrcJ9992HIUOG\nAABOnDiBp556yucN6wzXkJQWOIQX5c0jTOEwK0EYkzjal00jIvJ7HgPGqFGjkJ+fj8OHD0MIgVGj\nRl3Uwjpf0ie99ZXenst8KLKCJ9N/6ZclQYiIupNXC/ciIyMxefJkX7flojVfh+HVFq0AgwURkRc8\nBozjx49j/fr1OH78OCwWi/74V1995dOGdYbdLgDFCslUrz2gBu4OtERE3c1jwPjFL36BFStW4Ikn\nnsCOHTuwa9cuhIaGdkfbOsyuCgSNOAjZrK0bEQwYRERdxuMV1WKxYMKECRBCID4+HitXruy6TY66\nmF1V9WABAAaVi/KIiLqKx4Ahy9opkZGROH78OC5cuICioiKfN6wzVNV9wyazLbaHWkJEFHg8Dkll\nZmbiwoULWLZsGRYtWgRVVVtUm+0tbM0CRlRocA+1hIgo8LQbMFRVxYQJExAdHY309HR8/PHHaGxs\nvMjtU33HbhcQqgxJVjFWmY/Z4y/t6SYREQWMdoekZFnGz372M/3YaDT22mABAFa7DZKsop+cjNsm\nj0NwEPfoJiLqKh7nMFJTU1FYWNgdbbloFrt2269RMvVwS4iIAo/Hr+Dl5eWYO3currnmGrcaUlu2\nbPFpwzqjwd4IADDKDBhERF3Nq0nvzMzM7mjLRWu0OTIM2djDLSEiCjweA0Z2dnZ3tKNLNNgbAAAm\nmSXLiYi6mseAsXz58lZrLfXGISmLygyDiMhXPAaMKVOm6D83NjbinXfeQWpqqk8b1VnOgBHEDIOI\nqMt1eEjqxz/+Me655x6fNehiOO+SMimc9CYi6modrs4nSVKvvc3Womp3SXGfbiKirtehOQwhBL7+\n+mtMmDDB5w3rjEbHbbVmA4sOEhF1tQ7NYSiKgtzcXIwcOdKnjeqsRlXbByPM2Dv3HCci8mcBdVtt\no3AGjN5bvoSIyF95nMNYtGgRKisr9eOKigosWbLEp43qrEahrcMIN/XODZ6IiPyZx4BRV1eHyMhI\n/TgqKgo1NTU+bVRnWVEPoUoIMXIOg4ioq3kMGKqqoq7OtYtdbW0t7Ha7TxvVWVbRANhMMBqUnm4K\nEVHA8TiHMXv2bOTm5mLRokUAgJdeeglz5871ecM6wyo1QNiCoCgtV6YTEdHF8Rgw7rrrLsTHx2Pf\nvn0QQmDhwoXIysrqjrZ1iF21Q5WsELZwGJUOLy8hIiIPvNphKDs7u9ffLVVcVwoAEFYTFAYMF7dU\nkAAAFDhJREFUIqIu5/HK+sADD6CiokI/vnDhAh588EGfNqoz3jq1FwBgP5/EDIOIyAc8XllPnz6N\nqKgo/Tg6OhoFBQU+bVRnlNaVQVKNUCviOYdBROQDHgOG3W53uyvKarXCYrH4tFGdUW9rgKwaIUGC\nIjNgEBF1NY8BY9KkSVi5ciU++eQTfPLJJ8jLy0N6errXf2D//v2YOXMmMjIysH379jbPe/vtt3HF\nFVfgv//9r9ev3VS9rQGSaoSiyK3u30FERBfH46R3Xl4efv/73+M3v/kNAK221Lhx47x6cVVVsWHD\nBuzcuRPx8fFYsGABpk2b1mI/jdraWrz44oudrlElhECDrQGKGgIDh6OIiHzCY4ZhNBpx//334+mn\nn8YNN9yAv//971i7dq1XL3706FGkpKQgOTkZRqMRmZmZ2Lt3b4vztmzZgjvvvBNGY+d2ymu0WyAg\nALsBBk54ExH5RLsZhs1mw759+/C3v/0Nhw8fhs1mw3PPPed1JlBSUoKkpCT9OCEhAV988YXbOV99\n9RWKi4sxefJk7NixoxNvwbWXN+xGZhhERD7S5tfxJ554Atdffz12796N2bNn4/3330dkZGSHho2E\nEB6ff/zxx7FmzRqvf6c1DTYtYAhmGEREPtNmhvHSSy9h1KhRWLZsGcaPHw8AHZ5MTkxMxJkzZ/Tj\nkpISxMfH68e1tbU4ceIEfvKTn0AIgXPnzuHee+/Ftm3bMGLEiHZfOy4uXP/5glSm/WA3IMhkcHuu\nL+hr77c97AsX9oUL+6JrtBkwPvzwQ/zv//4vNm7ciMrKSmRlZXW46GBaWhoKCgpQVFSEuLg45Ofn\nY/PmzfrzYWFhOHjwoH78k5/8BI8++iiGDx/u8bXLyqr1n8+eLwcA2K0KpGbPBbq4uPA+9X7bw75w\nYV+4sC9cLjZwtjl+ExERgSVLluDVV1/F008/jcrKSjQ0NGDJkiXYvXu3Vy+uKArWrVuH3NxczJ49\nG5mZmUhNTcXWrVvx7rvvtjhfkqTODUk5tmZVbQrnMIiIfEQSHbhCW61W/POf/8Rrr72GP/zhD75s\nl0dlZdUoqCrEy9/+HcNiLkP+9/+E5WQaBpmHY+1PrunRtnUnfntyYV+4sC9c2BcuF5theFV80Mlo\nNGLWrFmYNWvWRf3RrrL7m9fwQ9VpfFd5CoA26W2x9c69OoiI/J1f31IUFRTp/oDdgPOVDT3TGCKi\nAOfXASM2OMbtWNgNqG2w9VBriIgCm18HjBbsnVspTkREnvl1wLCr7vMVwm5AanJED7WGiCiwdWjS\nu7exqe7DT4umXIGJIwb0UGuIiAKbX2cYNuGeYQzpH4MQs1/HQCKiXsuvA0bzISmzSemhlhARBT6/\nDhjlNfVux2YTswsiIl/x64BRXe++5oIZBhGR7/h1wLA3m8MIMjJgEBH5SkAFDFlm4UEiIl/x64Ch\nwhUwbGXJPdgSIqLA59ezxKpQIVQZDZ9NBwSzCyIiX/LrgGEXNkDI2v+IiMin/PpKq0JlsCAi6iZ+\nfbUVsEOofv0WiIj8hl9fbbUMg3MXRETdwa/nMATsgDBg2ugBGDs8vqebQ0QU0Pw8YKiAKuOmqakw\nGrhoj4jIl/x6SEpI2qS3ovj12yAi8gt+faUVjrukZInzGEREvua3AUMVKiAJSP77FoiI/IrfXm2d\nu+1JXIdBRNQt/PZqa3NsniSBk91ERN3BbwOGs1Kt7L9vgYjIr/jt1VYfkmKGQUTULfw4YDgzDAYM\nIqLu4LcBwy60DINDUkRE3cNvr7ZWZ4YhMcMgIuoOfhswnHMYCgMGEVG38NuAYbE7h6QYMIiIuoP/\nBgybFQCgyAwYRETdwW8DRqMzYDDDICLqFn4ZMKpqLbDYtCEpAzMMIqJu4Zf7YSx57C0kDq4AYjnp\nTUTUXXyeYezfvx8zZ85ERkYGtm/f3uL5nTt3IjMzE/PmzcNPf/pTnD171qvXLausAwAYZL+MeURE\nfsenAUNVVWzYsAHPPfcc3nzzTeTn5+PkyZNu5wwfPhyvvvoq3njjDcyYMQMbN2707sUlFQADBhFR\nd/FpwDh69ChSUlKQnJwMo9GIzMxM7N271+2csWPHIigoCAAwcuRIlJSUePfishYwjAoDBhFRd/Bp\nwCgpKUFSUpJ+nJCQgNLS0jbPf+WVV5Cenu7di+sZBucwiIi6g0+/ngshvD73jTfewH//+1+88MIL\nXp0vSdprh4cEIy4uvFPtCxR9/f03xb5wYV+4sC+6hk8DRmJiIs6cOaMfl5SUID4+vsV5Bw4cwPbt\n2/Hiiy/CaDR69+KOISnVBpSVVXdJe/1RXFx4n37/TbEvXNgXLuwLl4sNnD4dkkpLS0NBQQGKiopg\nsViQn5+PadOmuZ1z7NgxrF+/Htu2bUN0dLT3L+4YkjJxDoOIqFv49GqrKArWrVuH3NxcCCGwYMEC\npKamYuvWrUhLS8OUKVOwadMm1NfX48EHH4QQAv3798czzzzj+cUl56Q35zCIiLqDz7+ep6ent5jI\nXr58uf7z888/36nXlWRnhuHlEBYREV0UvywNAgBwTHqbDAwYRETdwS8DRnR4EOcwiIi6mV8GDINB\n1u+SCvL2rioiIroo/hkwZFnPMIKYYRARdQv/DBgGSV+4Z2KGQUTULfwyYCiya0jKzElvIqJu4ZcB\nw6BI+pCUmRkGEVG38NOA0WQOw8A5DCKi7uCXAUNRZEiyCqHKMBm40puIqDv4ZcAwKjKg2AC7AqPB\nL98CEZHf8curraJIkIyNENYgBgwiom7il1dbWRGQDDYIaxAUxS/fAhGR3/HPq63SCAAQVhNkSerh\nxhAR9Q1+GTBUgzNgBPVwS4iI+g6/DBh2uR4AINkZMIiIuotfBgxVbgAAGEVwD7eEiKjv8MuAcVoc\nAQAEqRE93BIior7DLwNGg1QJ27n+CFHjeropRER9hl8GDABQq2IQEsSyIERE3cV/A0ZdOIIZMIiI\nuo1/BgwBiPowmE2sI0VE1F38MmAEl6cBQkFYMEubExF1F78MGOvmLsaVg2Iwd+Kgnm4KEVGf4ZeT\nAEMGRCHv5pE93Qwioj7FLzMMIiLqfgwYRETkFQYMIiLyCgMGERF5hQGDiIi8woBBREReYcAgIiKv\nMGAQEZFXGDCIiMgrDBhEROQVBgwiIvKKzwPG/v37MXPmTGRkZGD79u0tnrdYLFi5ciVmzJiBm2++\nGWfOnPF1k4iIqBN8GjBUVcWGDRvw3HPP4c0330R+fj5Onjzpds4rr7yCyMhI/OMf/8Btt92GTZs2\n+bJJRETUST4NGEePHkVKSgqSk5NhNBqRmZmJvXv3up2zd+9eZGdnAwAyMjJw8OBBXzaJiIg6yacB\no6SkBElJSfpxQkICSktL3c4pLS1FYmIiAEBRFERERKCiosKXzSIiok7wacAQQnT4HCEEJEnyVZOI\niKiTfLqBUmJiotskdklJCeLj41ucU1xcjISEBNjtdtTU1CAyMtLja8fFhXd5e/0V+8KFfeHCvnBh\nX3QNn2YYaWlpKCgoQFFRESwWC/Lz8zFt2jS3c6ZMmYLXXnsNAPD2229j/PjxvmwSERF1kiS8GTe6\nCPv378evf/1rCCGwYMECLFu2DFu3bkVaWhqmTJkCi8WCVatW4auvvkJUVBQ2b96MAQMG+LJJRETU\nCT4PGEREFBi40puIiLzCgEFERF5hwCAiIq/4XcDwVJsq0KxduxbXXXcd5syZoz9WWVmJ3NxcZGRk\nYOnSpaiurtaf+9WvfoUZM2Zg3rx5+Oqrr3qiyT5RXFyMW2+9FbNmzcKcOXPw5z//GUDf7AuLxYKc\nnBxkZWVhzpw5eOqppwAAhYWFuOmmm5CRkYG8vDzYbDb9/ECv16aqKrKzs3H33XcD6Lt9MXXqVMyd\nOxdZWVlYsGABgC7+jAg/YrfbxfTp00VhYaGwWCxi7ty54sSJEz3dLJ/6z3/+I44dOyZmz56tP7Zx\n40axfft2IYQQv//978WmTZuEEEK899574s477xRCCHH48GGRk5PT/Q32kdLSUnHs2DEhhBA1NTVi\nxowZ4sSJE32yL4QQoq6uTgghhM1mEzk5OeLw4cPiwQcfFHv27BFCCPHYY4+Jl156SQghxK5du8T6\n9euFEELk5+eLFStW9Eibfen5558XDz30kLjrrruEEKLP9sXUqVNFRUWF22Nd+RnxqwzDm9pUgeba\na69FRESE22NN629lZ2frfbB3715kZWUBAK6++mpUV1fj3Llz3dtgH4mLi8OwYcMAAKGhoUhNTUVJ\nSUmf7AsACA4OBqB9Y7bZbJAkCYcOHUJGRgYArS/+9a9/AQj8em3FxcV4//33kZOToz/20Ucf9cm+\nEEJAVVW3x7ryM+JXAcOb2lR9QXl5OWJjYwFoF9Ly8nIA7nW5AK1/SkpKeqSNvlRYWIjjx4/j6quv\nxvnz5/tkX6iqiqysLEycOBETJ07EwIEDERERAVnWPtKJiYn6+w30em2PP/44HnnkEb2k0IULFxAZ\nGdkn+0KSJCxduhTz58/Hyy+/DABd+hnxaWmQria4ZKRdrfVPoNXlqq2txfLly7F27VqEhoa2+f4C\nvS9kWcbrr7+Ompoa3HfffS22DQBc77d5X4gAqtf23nvvITY2FsOGDcOhQ4cAaO+v+XvuC30BALt3\n79aDQm5uLgYNGtSlnxG/Chje1KbqC/r164dz584hNjYWZWVliImJAaB9QyguLtbPKy4uDqj+sdls\nWL58OebNm4fp06cD6Lt94RQWFoYxY8bgyJEjqKqqgqqqkGXZ7f06+6Kj9dr8wWeffYZ9+/bh/fff\nR2NjI2pra/H444+jurq6z/UFoGUQABATE4Pp06fj6NGjXfoZ8ashKW9qUwWi5t8Epk6dildffRUA\n8Nprr+l9MG3aNLz++usAgMOHDyMiIkJPRQPB2rVrMWTIENx22236Y32xL8rLy/U7XRoaGnDw4EEM\nGTIE48aNw9tvvw3AvS+mTp0asPXa8vLy8N5772Hv3r3YvHkzxo0bhyeffLJP9kV9fT1qa2sBAHV1\ndfjwww8xdOjQLv2M+F1pkNZqUwWyhx56CIcOHUJFRQViY2PxwAMPYPr06XjwwQdx9uxZ9O/fH1u2\nbNEnxn/5y1/igw8+QHBwMJ544gmMGDGih99B1/j0009xyy23YOjQoZAkCZIkYeXKlbjqqquwYsWK\nPtUXX3/9NdasWQNVVaGqKmbNmoV77rkHp0+fRl5eHqqqqjBs2DBs2rQJRqOxz9Rr+/jjj/HHP/4R\nzz77bJ/si9OnT+P++++HJEmw2+2YM2cOli1bhoqKii77jPhdwCAiop7hV0NSRETUcxgwiIjIKwwY\nRETkFQYMIiLyCgMGERF5hQGDiIi8woBBfu2mm25CdnY2MjMzMWLECGRnZyM7Oxtr167t8Gvdcccd\nXpW7fvTRR3H48OHONLdDjh07hnfeecfnf4fIW1yHQQGhqKgICxYsaLf6qLNUhL94+eWXcfDgQWze\nvLmnm0IEwM9qSRF1xMGDB7Fp0yaMHDkSx44dw3333Yfy8nLs2rVL31BnzZo1GDt2LABg8uTJ2Llz\nJwYNGoTFixdj1KhR+Pzzz1FaWorZs2djxYoVAIDFixfj3nvvxaRJk7Bq1SqEhYXh5MmTKCkpwejR\no/HEE08A0GrzPPLII7hw4QIGDhwIu92OqVOn4uabb3Zr57lz5/DQQw/hwoULAIBJkybhjjvuwDPP\nPIO6ujpkZ2dj3LhxWLNmDT7//HNs3rwZ9fX1AIDly5cjPT0dBQUFWLx4MWbPno1PP/0UFosF69ev\nx+jRo7ulr6mPuJjNOoh6i8LCQjF+/Hi3xw4cOCCGDx8uvvjiC/2xppvLnDhxQlx//fX6cXp6uvju\nu++EEEIsWrRIPPTQQ0IIIaqqqsTYsWNFYWGh/twHH3wghBDi4YcfFrfccouwWq2isbFRzJw5Uxw6\ndEgIIcQ999wj/vCHPwghhDh9+rQYNWqU2L17d4u279ixQzz22GP6cVVVlRBCiL/+9a8iLy/Pre1Z\nWVni/PnzQgghiouLRXp6uqipqRE//PCDuPzyy0V+fr7+3q+//nphs9m870QiD5hhUEAbPHgwrrzy\nSv341KlT2Lp1K0pLS6EoCkpLS1FRUYGoqKgWv3vjjTcCAMLDwzFo0CAUFBQgOTm5xXk33HADDAbt\nozR8+HAUFBRg7NixOHToEH71q18BAAYMGKBnMs2NHDkSL774Ip588kmMGTMGkyZNavW8Tz/9FIWF\nhVi6dKlekFJRFJw+fRohISEIDg7GrFmzAAATJkyAoig4deoUUlNTve0uonYxYFBACw0NdTteuXIl\n1q9fj8mTJ0NVVVx11VVobGxs9XeDgoL0n2VZht1u79B53u6zcM011+C1117DgQMH8Le//Q07duzA\nCy+80OI8IQRGjBiBnTt3tniuoKCgxWOqqgbUXg/U8/xnBpDIA+HF/Rs1NTV6ddLdu3e3GQS6wtix\nY/Wy0kVFRfj4449bPa+wsBBhYWGYNWsW1qxZgy+//BKAtteFs4w5AIwePRonTpzAJ598oj929OhR\n/ef6+nrs2bMHgLZFKQCkpKR07ZuiPo0ZBgUMb75Nr127FsuWLUNSUhLGjRuH8PDwVn+/+Wu19Vx7\n561btw6rV69Gfn4+Bg8ejNGjR7v9PaeDBw/iz3/+MxRFgRACGzZsAABMnDgRf/rTn5CVlYXx48dj\nzZo1eOaZZ7Bp0yZUV1fDarVi4MCBePbZZwEAsbGx+Pbbb5GTkwOLxYLNmzdDURSPfULkLd5WS+Qj\njY2NMBqNkGUZJSUlyMnJwa5duzBw4MAu/1vOu6Q+/PDDLn9tIidmGEQ+8t133+HRRx+FEAKqqmLl\nypU+CRZE3YUZBhEReYWT3kRE5BUGDCIi8goDBhEReYUBg4iIvMKAQUREXmHAICIir/x/apbYj523\no60AAAAASUVORK5CYII=\n", "text/plain": [ - "" + "\u003cmatplotlib.figure.Figure at 0x7f97f1330850\u003e" ] }, "metadata": { "tags": [] - } + }, + "output_type": "display_data" } + ], + "source": [ + "def plot(train, test, label):\n", + " plt.title('MNIST model %s' % label)\n", + " plt.plot(train, label='train %s' % label)\n", + " plt.plot(test, label='test %s' % label)\n", + " plt.legend()\n", + " plt.xlabel('Training step')\n", + " plt.ylabel(label.capitalize())\n", + " plt.show()\n", + " \n", + "\n", + "with tf.Graph().as_default():\n", + " hp = tf.contrib.training.HParams(\n", + " learning_rate=0.05,\n", + " max_steps=tf.constant(500),\n", + " )\n", + " train_ds = setup_mnist_data(True, hp, 50)\n", + " test_ds = setup_mnist_data(False, hp, 1000)\n", + " tf_train = autograph.to_graph(train)\n", + " all_losses = tf_train(train_ds, test_ds, hp)\n", + "\n", + " with tf.Session() as sess:\n", + " sess.run(tf.global_variables_initializer())\n", + " (train_losses, test_losses, train_accuracies,\n", + " test_accuracies) = sess.run(all_losses)\n", + " \n", + " plot(train_losses, test_losses, 'loss')\n", + " plot(train_accuracies, test_accuracies, 'accuracy')" ] }, { + "cell_type": "markdown", "metadata": { - "id": "HNqUFL4deCsL", - "colab_type": "text" + "colab_type": "text", + "id": "HNqUFL4deCsL" }, - "cell_type": "markdown", "source": [ "# 4. Case study: building an RNN\n" ] }, { + "cell_type": "markdown", "metadata": { - "id": "YkC1k4HEQ7rw", - "colab_type": "text" + "colab_type": "text", + "id": "YkC1k4HEQ7rw" }, - "cell_type": "markdown", "source": [ "In this exercise we build and train a model similar to the RNNColorbot model that was used in the main Eager notebook. The model is adapted for converting and training in graph mode." ] }, { + "cell_type": "markdown", "metadata": { - "id": "7nkPDl5CTCNb", - "colab_type": "text" + "colab_type": "text", + "id": "7nkPDl5CTCNb" }, - "cell_type": "markdown", "source": [ "To get started, we load the colorbot dataset. The code is identical to that used in the other exercise and its details are unimportant." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "A0uREmVXCQEw", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "A0uREmVXCQEw" }, - "cell_type": "code", + "outputs": [], "source": [ "def parse(line):\n", " \"\"\"Parses a line from the colors dataset.\n", @@ -1136,7 +1034,7 @@ " A tuple of three tensors (rgb, chars, length), of shapes: (batch_size, 3),\n", " (batch_size, max_sequence_length, 256) and respectively (batch_size).\n", " \"\"\"\n", - " items = tf.string_split([line], \",\").values\n", + " items = tf.string_split(tf.expand_dims(line, 0), \",\").values\n", " rgb = tf.string_to_number(items[1:], out_type=tf.float32) / 255.0\n", " color_name = items[0]\n", " chars = tf.one_hot(tf.decode_raw(color_name, tf.uint8), depth=256)\n", @@ -1168,23 +1066,21 @@ " dataset = dataset.repeat()\n", " if training:\n", " dataset = dataset.shuffle(buffer_size=3000)\n", - " dataset = dataset.padded_batch(batch_size, padded_shapes=([None], [None, None], []))\n", + " dataset = dataset.padded_batch(batch_size, padded_shapes=((None,), (None, None), ()))\n", " return dataset\n", "\n", "\n", - "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv\"\n", - "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv\"\n", + "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/train.csv\"\n", + "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/test.csv\"\n", "data_dir = \"tmp/rnn/data\"" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "waZ89t3DTUla", - "colab_type": "text" + "colab_type": "text", + "id": "waZ89t3DTUla" }, - "cell_type": "markdown", "source": [ "Next, we set up the RNNColobot model, which is very similar to the one we used in the main exercise.\n", "\n", @@ -1192,17 +1088,19 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "9v8AJouiC44V", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "9v8AJouiC44V" }, - "cell_type": "code", + "outputs": [], "source": [ "def model_components():\n", " lower_cell = tf.contrib.rnn.LSTMBlockCell(256)\n", @@ -1226,17 +1124,18 @@ " Returns:\n", " A Tensor of shape (max_sequence_length, batch_size, output_size).\n", " \"\"\"\n", - " hidden_outputs = []\n", - " autograph.utils.set_element_type(hidden_outputs, tf.float32)\n", + " hidden_outputs = tf.TensorArray(tf.float32, size=0, dynamic_size=True)\n", " state, output = cell.zero_state(batch_size, tf.float32)\n", + " initial_state_shape = state.shape\n", + " initial_output_shape = output.shape\n", " n = tf.shape(chars)[0]\n", " i = 0\n", - " while i < n:\n", + " while i \u003c n:\n", " ch = chars[i]\n", " cell_output, (state, output) = cell.call(ch, (state, output))\n", " hidden_outputs.append(cell_output)\n", " i += 1\n", - " hidden_outputs = hidden_outputs.stack()\n", + " hidden_outputs = autograph.stack(hidden_outputs)\n", " if training:\n", " hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n", " return hidden_outputs\n", @@ -1260,50 +1159,51 @@ " A Tensor of shape (batch_size, 3) - the model predictions.\n", " \"\"\"\n", " (chars, length) = inputs\n", - " chars_time_major = tf.transpose(chars, [1, 0, 2])\n", + " chars_time_major = tf.transpose(chars, (1, 0, 2))\n", " chars_time_major.set_shape((None, batch_size, 256))\n", "\n", " hidden_outputs = rnn_layer(chars_time_major, lower_cell, batch_size, training)\n", " final_outputs = rnn_layer(hidden_outputs, upper_cell, batch_size, training)\n", "\n", " # Grab just the end-of-sequence from each output.\n", - " indices = tf.stack([length - 1, range(batch_size)], axis=1)\n", + " indices = tf.stack((length - 1, range(batch_size)), axis=1)\n", " sequence_ends = tf.gather_nd(final_outputs, indices)\n", + " sequence_ends.set_shape((batch_size, 128))\n", " return relu_layer(sequence_ends)\n", "\n", "def loss_fn(labels, predictions):\n", " return tf.reduce_mean((predictions - labels) ** 2)" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "JjK4gXFvFsf4", - "colab_type": "text" + "colab_type": "text", + "id": "JjK4gXFvFsf4" }, - "cell_type": "markdown", "source": [ "The train and test functions are also similar to the ones used in the Eager notebook. Since the network requires a fixed batch size, we'll train in a single shot, rather than by epoch." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "ZWQMExk0S6X6", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "ZWQMExk0S6X6" }, - "cell_type": "code", + "outputs": [], "source": [ "def train(optimizer, train_data, lower_cell, upper_cell, relu_layer, batch_size, num_steps):\n", " iterator = train_data.make_one_shot_iterator()\n", " step = 0\n", - " while step < num_steps:\n", + " while step \u003c num_steps:\n", " labels, chars, sequence_length = iterator.get_next()\n", " predictions = model((chars, sequence_length), lower_cell, upper_cell, relu_layer, batch_size, training=True)\n", " loss = loss_fn(labels, predictions)\n", @@ -1318,7 +1218,7 @@ " total_loss = 0.0\n", " iterator = eval_data.make_one_shot_iterator()\n", " step = 0\n", - " while step < num_steps:\n", + " while step \u003c num_steps:\n", " labels, chars, sequence_length = iterator.get_next()\n", " predictions = model((chars, sequence_length), lower_cell, upper_cell, relu_layer, batch_size, training=False)\n", " total_loss += loss_fn(labels, predictions)\n", @@ -1339,16 +1239,14 @@ " # Here, we create a no_op that will drive the execution of all other code in\n", " # this function. Autograph will add the necessary control dependencies.\n", " return tf.no_op()" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "iopcs5hXG2od", - "colab_type": "text" + "colab_type": "text", + "id": "iopcs5hXG2od" }, - "cell_type": "markdown", "source": [ "Finally, we add code to run inference on a single input, which we'll read from the input.\n", "\n", @@ -1356,17 +1254,19 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "DyU0wnnAFEYj", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "DyU0wnnAFEYj" }, - "cell_type": "code", + "outputs": [], "source": [ "@autograph.do_not_convert(run_as=autograph.RunMode.PY_FUNC)\n", "def draw_prediction(color_name, pred):\n", @@ -1388,16 +1288,14 @@ " draw_prediction(color_name, pred)\n", " # Create an op that will drive the entire function.\n", " return tf.no_op()" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "Nt0Kv5OCHip0", - "colab_type": "text" + "colab_type": "text", + "id": "Nt0Kv5OCHip0" }, - "cell_type": "markdown", "source": [ "Finally, we put everything together.\n", "\n", @@ -1405,218 +1303,132 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "-GmWa0GtYWdh", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "output_extras": [ - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {} - ], - "base_uri": "https://localhost:8080/", - "height": 668 + "height": 415 }, - "outputId": "61f4af1d-c81e-44db-9079-1a7b8ed8ce58", + "colab_type": "code", "executionInfo": { + "elapsed": 15536, "status": "ok", - "timestamp": 1522345877153, - "user_tz": 240, - "elapsed": 75500, + "timestamp": 1531750946373, "user": { - "displayName": "Dan Moldovan", - "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", - "userId": "112023154726779574577" - } - } + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "-GmWa0GtYWdh", + "outputId": "2e7a9856-9809-43a3-8b43-3c8514ea43e9" }, - "cell_type": "code", - "source": [ - "def run_input_loop(sess, inference_ops, color_name_placeholder):\n", - " \"\"\"Helper function that reads from input and calls the inference ops in a loop.\"\"\"\n", - "\n", - " tb = widgets.TabBar([\"RNN Colorbot\"])\n", - " while True:\n", - " with tb.output_to(0):\n", - " try:\n", - " color_name = six.moves.input(\"Give me a color name (or press 'enter' to exit): \")\n", - " except (EOFError, KeyboardInterrupt):\n", - " break\n", - " if not color_name:\n", - " break\n", - " with tb.output_to(0):\n", - " tb.clear_tab()\n", - " sess.run(inference_ops, {color_name_placeholder: color_name})\n", - " plt.show()\n", - "\n", - "with tf.Graph().as_default():\n", - " # Read the data.\n", - " batch_size = 64\n", - " train_data = load_dataset(data_dir, train_url, batch_size)\n", - " eval_data = load_dataset(data_dir, test_url, 50, training=False)\n", - " \n", - " # Create the model components.\n", - " lower_cell, upper_cell, relu_layer = model_components()\n", - " # Create the helper placeholder for inference.\n", - " color_name_placeholder = tf.placeholder(tf.string, shape=())\n", - " \n", - " # Compile the train / test code.\n", - " tf_train_model = autograph.to_graph(train_model)\n", - " train_model_ops = tf_train_model(\n", - " train_data, eval_data, batch_size, lower_cell, upper_cell, relu_layer, train_steps=100)\n", - " \n", - " # Compile the inference code.\n", - " tf_inference = autograph.to_graph(inference)\n", - " inference_ops = tf_inference(color_name_placeholder, lower_cell, upper_cell, relu_layer)\n", - " \n", - " with tf.Session() as sess:\n", - " sess.run(tf.global_variables_initializer())\n", - " \n", - " # Run training and testing.\n", - " sess.run(train_model_ops)\n", - " \n", - " # Run the inference loop.\n", - " run_input_loop(sess, inference_ops, color_name_placeholder)" - ], - "execution_count": 0, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ - "('Successfully downloaded', 'train.csv', 28010L, 'bytes.')\n", - "('Successfully downloaded', 'test.csv', 2414L, 'bytes.')\n", - "Step 0 train loss 0.37890616\n", - "Step 10 train loss 0.18515904\n", - "Step 20 train loss 0.0892782\n", - "Step 30 train loss 0.07883155\n", - "Step 40 train loss 0.08585831\n", - "Step 50 train loss 0.09302989\n", - "Step 60 train loss 0.089012615\n", - "Step 70 train loss 0.07275697\n", - "Step 80 train loss 0.06644974\n", - "Step 90 train loss 0.0854013\n", - "Test loss 0.13216865Colorbot is ready to generate colors!\n", - "\n", + "Test loss 0.138294\n", + "Colorbot is ready to generate colors!\n", "\n", "\n" - ], - "name": "stdout" + ] }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ - "" + "\u003clink rel=stylesheet type=text/css href='/nbextensions/google.colab/tabbar.css'\u003e\u003c/link\u003e" + ], + "text/plain": [ + "\u003cIPython.core.display.HTML at 0x7f97ee42bb90\u003e" ] }, "metadata": { "tags": [ "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ - "" + "\u003cscript src='/nbextensions/google.colab/tabbar_main.min.js'\u003e\u003c/script\u003e" + ], + "text/plain": [ + "\u003cIPython.core.display.HTML at 0x7f97ee42be10\u003e" ] }, "metadata": { "tags": [ "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ - "
" + "\u003cdiv id=\"id1\"\u003e\u003c/div\u003e" + ], + "text/plain": [ + "\u003cIPython.core.display.HTML at 0x7f97ee42bd90\u003e" ] }, "metadata": { "tags": [ "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"b102d936-3379-11e8-ac70-0242ac110002\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"borderColor\": [\"#a7a7a7\"], \"tabNames\": [\"RNN Colorbot\"], \"initialSelection\": 0, \"location\": \"top\", \"contentHeight\": [\"initial\"], \"elementId\": \"id1\"});\n", - "//# sourceURL=js_e223a56194" + "window[\"a6045494-8903-11e8-99f9-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"location\": \"top\", \"borderColor\": [\"#a7a7a7\"], \"initialSelection\": 0, \"elementId\": \"id1\", \"contentHeight\": [\"initial\"], \"contentBorder\": [\"0px\"], \"tabNames\": [\"RNN Colorbot\"]});\n", + "//# sourceURL=js_02f896cbda" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2ab810\u003e" ] }, "metadata": { "tags": [ "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"b103532a-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_b8c6a821fb" + "window[\"a6045495-8903-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_7e8f9f77a0" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2ab710\u003e" ] }, "metadata": { "tags": [ "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"b105b28c-3379-11e8-ac70-0242ac110002\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_44805e254b" + "window[\"a6045496-8903-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_5531553c2f" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2ab6d0\u003e" ] }, "metadata": { @@ -1624,17 +1436,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"b106197a-3379-11e8-ac70-0242ac110002\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_a63d3c6c47" + "window[\"a6045497-8903-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_d1f809ec17" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2ab990\u003e" ] }, "metadata": { @@ -1642,17 +1454,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"b1069f44-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"b106197a-3379-11e8-ac70-0242ac110002\"]);\n", - "//# sourceURL=js_7e203b8bce" + "window[\"a6045498-8903-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"a6045497-8903-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_3a3123cadb" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2aba50\u003e" ] }, "metadata": { @@ -1660,17 +1472,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"b1070f38-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_d53293d4a7" + "window[\"a6045499-8903-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_1a0e1f7d6f" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2ab890\u003e" ] }, "metadata": { @@ -1678,17 +1490,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c6d90d5c-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"b105b28c-3379-11e8-ac70-0242ac110002\"]);\n", - "//# sourceURL=js_3000dc2c05" + "window[\"a8e54762-8903-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"a6045496-8903-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_6213539615" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2abad0\u003e" ] }, "metadata": { @@ -1696,17 +1508,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c6da872c-3379-11e8-ac70-0242ac110002\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_4136f669a3" + "window[\"a8e54763-8903-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_0bd7f95c6e" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2ab950\u003e" ] }, "metadata": { @@ -1714,17 +1526,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c6dac868-3379-11e8-ac70-0242ac110002\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_2f70dd9aee" + "window[\"a8e54764-8903-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_215f004f6b" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2abb10\u003e" ] }, "metadata": { @@ -1732,17 +1544,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c6db07d8-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c6dac868-3379-11e8-ac70-0242ac110002\"]);\n", - "//# sourceURL=js_7226726048" + "window[\"a8e54765-8903-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"a8e54764-8903-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_a06186c8ad" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2aba90\u003e" ] }, "metadata": { @@ -1750,17 +1562,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c6dcc6fe-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_72e7709865" + "window[\"a8e54766-8903-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_383fbaae67" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2abc50\u003e" ] }, "metadata": { @@ -1768,14 +1580,14 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVQAAAFZCAYAAADHDNdrAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAB9JJREFUeJzt3E1Lle0ax+HTF4jeEAyMBhE0DawI\nwsCH0AIlaGBWNJBo0CDoA0TQhmDXuKAGDioiCA2KlEAlnl05FD9Co8BeaGCQoBDa2jPZsXt4Bvu/\n0+o4Rmvd1zW4rsmP84bFamo0Go0C4H/WvNYHAPhVCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKDy\nUxgeHq5Dhw7V4OBgPXz4sHp7e+vWrVt15cqVOnnyZN2/f78ajUbdvn27+vr6qqenp65du1YrKytV\nVfXhw4e6cOFC9fX1VV9fX01PT1dV1dzcXHV3d9eDBw/q+PHj9ccff9TExMRaXpWfWOtaHwD+zuvX\nr+vOnTs1MTFRbW1tdf78+dW16enpGh8fr/b29hobG6upqal6/Phxbdy4sS5evFgjIyM1NDRUly5d\nqv3799fw8HC9efOmTp8+XVNTU1VV9enTp2pubq5nz57V5ORk3bhxo44dO7ZW1+UnZkJl3Zudna2D\nBw9WR0dHbdiwoQYHB1fX9u7dW+3t7VVV9fLlyxocHKytW7dWa2trnTp1qp4/f16Li4s1MzNT586d\nq6qqXbt21YEDB1an1OXl5Tpx4kRVVe3Zs6fevXv3Yy/IL8OEyrr3+fPnamtrW/2+ffv21c//+Xxh\nYaHu3r1bjx49qqqqlZWVam9vr4WFhWo0GnXmzJnVvYuLi9XV1VVVVS0tLbVp06aqqmpubq6vX7/+\nX+/Dr0tQWfe2bNlSi4uLq98/fvz43X0dHR3V29tbQ0ND3zxfXl6ulpaWevLkSW3evPmbtbm5ufyB\n+W155Wfd6+zsrJmZmZqfn68vX77U2NjYd/cdOXKkxsfHa2lpqaqqRkdH6+nTp9Xa2lqHDx+u0dHR\nqqpaWlqqy5cv1/v373/YHfg9CCrrXmdnZw0MDNTAwECdPXu2enp6vrvv6NGj1dPTUwMDA9Xf318v\nXryo7u7uqqq6evVqzc7OVn9/fw0MDNTOnTtrx44dP/Ia/Aaa/B8qP4NGo1FNTU1VVfXq1au6efPm\nX06qsFZMqKx78/Pz1dXVVW/fvq1Go1GTk5O1b9++tT4W/BcTKj+FkZGRunfvXjU1NdXu3bvr+vXr\ntW3btrU+FnxDUAFCvPIDhAgqQMi6+WH/kX8eXesjAPytf/3jz79cM6EChAgqQIigAoQIKkCIoAKE\nCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQI\nKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgq\nQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpA\niKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCI\noAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIig\nAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAC\nhAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKE\nCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQI\nKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgq\nQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpA\niKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCI\noAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIig\nAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAC\nhAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKE\nCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQI\nKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgq\nQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpA\niKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkBI\nU6PRaKz1IQB+BSZUgBBBBQgRVIAQQQUIEVSAEEEFCBFUgBBBBQgRVIAQQQUIEVSAEEEFCBFUgBBB\nBQgRVIAQQQUIEVSAEEEFCBFUgBBBBQgRVIAQQQUIEVSAkH8D1Aj8lNhhe7QAAAAASUVORK5CYII=\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACL9JREFUeJzt3F+IlXUex/Gv2ziiBRGVOQaFd2JBzOg5aiH+IZGoJgmM\n/uhVGIlgFE0QEYHdFQaRGBJ10VX0D5TAi8jKomGmILsYjEAkmBwbRIxKGDV/e7G7w8ouux9jd911\nX6+rcx6e85zveS7e/J7zb0ZrrRVA4A8XewDgf4dgADHBAGKCAcQEA4gJBhATDC6Kp59+urrdbt13\n3301OjpaK1euvNgjERCMS9yaNWtqeHj4Yo9xnq+++qqGh4frs88+q7fffruqqmbMmHGRpyIhGPxH\n/fbbb/XDDz/U9ddfX7NmzbrY43CBBOMS9tRTT9XExERt2bKlBgYG6vXXX69vvvmm7r///up0OrV+\n/foaHR2d3n/Tpk318ssv1wMPPFADAwP18MMP18mTJ6uq6vTp0zU0NFRLly6tTqdTGzZsqBMnTlRV\n1eTkZG3ZsqWWLl1a69atq3feeWf6mDt37qxt27bV0NBQLVmypN5777169tln6+DBgzUwMFA7d+78\nm7kPHz5cmzZtqk6nU3fffXft37+/qqrGx8er0+lM7/fMM8/UrbfeOn1/aGio3nzzzX/tSeR8jUva\n6tWr2/DwcGuttWPHjrVut9sOHDjQWmvtiy++aN1ut504caK11trGjRvb2rVr2/fff9+mpqbaxo0b\n244dO1prrb311lvt0UcfbVNTU+3cuXNtbGys/fLLL6211h566KG2ffv2dvr06Xbo0KG2bNmy6ed8\n5ZVX2k033dQ++uij1lprU1NT7f33328PPvjg9IwjIyNt5cqVrbXWzpw509auXdt2797dzpw504aH\nh1t/f387cuTI9OsZGxtrrbW2bt26dvvtt7fDhw+31lpbtWpVO3To0L/rVNJas8L4P9D+/HOhvXv3\n1qpVq2rFihVVVbV8+fK6+eab69NPP53e9957760bbrihent764477qhDhw5VVVVPT0+dPHmyjhw5\nUjNmzKhFixbV5ZdfXseOHauvv/66nnzyyZo5c2YtXLiwNmzYUHv27Jk+Zn9/f61Zs6aqqnp7e//h\nrAcPHqxTp07VI488Uj09PbVs2bJavXp1ffDBB1VVtWTJkhodHa3jx49XVdW6devqyy+/rPHx8fr1\n119r4cKF/6Kzxt/Tc7EH4D/n6NGjtW/fvvr444+r6k8hOXv2bC1fvnx6n2uuuWb69uzZs+vUqVNV\nVXXPPffUsWPH6oknnqiff/65BgcH6/HHH6/Jycm68sora/bs2dOPmz9/fo2NjU3fnzdvXjzj5ORk\n9fX1nbdt/vz5NTk5WVVVnU6n9u/fX9ddd111u93qdru1Z8+e6u3trcWLF1/A2eD3EIxL3F9/+tDX\n11fr16+v7du3X/Bxenp6auvWrbV169Y6evRobd68uRYsWFC33XZb/fTTT3Xq1KmaM2dOVVVNTEzU\n3Llz/+4M/8zcuXNrYmLivG1Hjx6tBQsWVFVVt9utF198sfr6+qrT6dTAwEA999xz1dvbW91u94Jf\nFxfGJckl7tprr63x8fGqqhocHKz9+/fX559/XufOnaupqakaHR2tH3/88Z8eZ2RkpL777rs6d+5c\nzZkzp3p6euqyyy6refPmVX9/f7300kt1+vTp+vbbb+vdd9+twcHB3zXvLbfcUnPmzKnXXnutzp49\nWyMjI/XJJ5/UnXfeWVVVN954Y82aNav27t1bnU6nrrjiirr66qvrww8/PO8NUf49BOMSt3nz5tq1\na1d1u93at29f7dq1q3bv3l3Lly+v1atX1xtvvDH9Hsc/WgkcP368tm3bVosXL6677rqrli5dOh2F\nHTt21Pj4eK1YsaK2bdtWjz322HmXORdi5syZ9eqrr9aBAwdq2bJl9fzzz9cLL7wwvcKo+tMq46qr\nrpq+1PlLKBYtWvS7npPcjNb8gQ6QscIAYoIBxAQDiAkGEPuv/R7GxN7+iz0C/F/rG/z6b7ZZYQAx\nwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQE\nA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMM\nICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCA\nmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBi\nggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJ\nBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYY\nQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAA\nMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHE\nBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhAT\nDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEww\ngJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOI\nCQYQEwwgNqO11i72EMD/BisMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBi\nggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiP0RoqNMBlokHDIAAAAASUVORK5CYII=\n", "text/plain": [ - "" + "\u003cmatplotlib.figure.Figure at 0x7f97ee42bb90\u003e" ] }, "metadata": { @@ -1784,17 +1596,17 @@ "outputarea_id1", "user_output" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c70592aa-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c6da872c-3379-11e8-ac70-0242ac110002\"]);\n", - "//# sourceURL=js_25c3aaf79a" + "window[\"a8e54767-8903-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"a8e54763-8903-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_28bd08ac10" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ea9efc10\u003e" ] }, "metadata": { @@ -1802,17 +1614,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c70842c0-3379-11e8-ac70-0242ac110002\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_984c56b816" + "window[\"a8e54768-8903-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_ae2887f57d" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ea9efb50\u003e" ] }, "metadata": { @@ -1820,17 +1632,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c708dec4-3379-11e8-ac70-0242ac110002\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_e0451a1217" + "window[\"a8e54769-8903-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_608805a786" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ea9ef710\u003e" ] }, "metadata": { @@ -1838,17 +1650,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c7092726-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c708dec4-3379-11e8-ac70-0242ac110002\"]);\n", - "//# sourceURL=js_7aa23d7385" + "window[\"a8e5476a-8903-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"a8e54769-8903-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_3d87cf7d0f" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ea9efa90\u003e" ] }, "metadata": { @@ -1856,17 +1668,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c7099044-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_5722756ddb" + "window[\"a8e5476b-8903-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_5e91101199" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ea9efa50\u003e" ] }, "metadata": { @@ -1874,24 +1686,149 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "stream", - "text": [ - "Give me a color name (or press 'enter' to exit): \n" - ], - "name": "stdout" + "data": { + "text/html": [ + "\u003cdiv class=id_45185901 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e" + ], + "text/plain": [ + "\u003cIPython.core.display.HTML at 0x7f97ee42bd90\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"a8e5476c-8903-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_45185901 span\");\n", + "//# sourceURL=js_f43052a94e" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f97ea9ef750\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"a8e5476d-8903-11e8-99f9-c8d3ffb5fbe0\"] = window[\"a8e5476c-8903-11e8-99f9-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", + "//# sourceURL=js_bfc0fb76ce" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f97ea9efb10\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"a9e9b8b0-8903-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_45185901 input\");\n", + "//# sourceURL=js_7f167283fa" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f97ea9ef610\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"a9e9b8b1-8903-11e8-99f9-c8d3ffb5fbe0\"] = window[\"a9e9b8b0-8903-11e8-99f9-c8d3ffb5fbe0\"].remove();\n", + "//# sourceURL=js_016ae4bf21" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f97ea9ef250\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"a9e9b8b2-8903-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_45185901 span\");\n", + "//# sourceURL=js_e666f179bc" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f97ea9ef550\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"a9e9b8b3-8903-11e8-99f9-c8d3ffb5fbe0\"] = window[\"a9e9b8b2-8903-11e8-99f9-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", + "//# sourceURL=js_cbb9d14aec" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f97ea9ef1d0\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c7baac12-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c70842c0-3379-11e8-ac70-0242ac110002\"]);\n", - "//# sourceURL=js_cdd622e58f" + "window[\"a9e9b8b4-8903-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"a8e54768-8903-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_2967a79665" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ea9ef1d0\u003e" ] }, "metadata": { @@ -1899,21 +1836,98 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" } + ], + "source": [ + "def run_input_loop(sess, inference_ops, color_name_placeholder):\n", + " \"\"\"Helper function that reads from input and calls the inference ops in a loop.\"\"\"\n", + "\n", + " tb = widgets.TabBar([\"RNN Colorbot\"])\n", + " while True:\n", + " with tb.output_to(0):\n", + " try:\n", + " color_name = six.moves.input(\"Give me a color name (or press 'enter' to exit): \")\n", + " except (EOFError, KeyboardInterrupt):\n", + " break\n", + " if not color_name:\n", + " break\n", + " with tb.output_to(0):\n", + " tb.clear_tab()\n", + " sess.run(inference_ops, {color_name_placeholder: color_name})\n", + " plt.show()\n", + "\n", + "with tf.Graph().as_default():\n", + " # Read the data.\n", + " batch_size = 64\n", + " train_data = load_dataset(data_dir, train_url, batch_size)\n", + " eval_data = load_dataset(data_dir, test_url, 50, training=False)\n", + " \n", + " # Create the model components.\n", + " lower_cell, upper_cell, relu_layer = model_components()\n", + " # Create the helper placeholder for inference.\n", + " color_name_placeholder = tf.placeholder(tf.string, shape=())\n", + " \n", + " # Compile the train / test code.\n", + " tf_train_model = autograph.to_graph(train_model)\n", + " train_model_ops = tf_train_model(\n", + " train_data, eval_data, batch_size, lower_cell, upper_cell, relu_layer, train_steps=100)\n", + " \n", + " # Compile the inference code.\n", + " tf_inference = autograph.to_graph(inference)\n", + " inference_ops = tf_inference(color_name_placeholder, lower_cell, upper_cell, relu_layer)\n", + " \n", + " with tf.Session() as sess:\n", + " sess.run(tf.global_variables_initializer())\n", + " \n", + " # Run training and testing.\n", + " sess.run(train_model_ops)\n", + " \n", + " # Run the inference loop.\n", + " run_input_loop(sess, inference_ops, color_name_placeholder)" ] }, { + "cell_type": "markdown", "metadata": { - "id": "AHJ2c47U-A5W", - "colab_type": "text" + "colab_type": "text", + "id": "AHJ2c47U-A5W" }, - "cell_type": "markdown", "source": [ "# Where do we go next?\n", "\n", - "Autograph is available in tensorflow.contrib, but it's still in its early stages. We're excited about the possibilities it brings — write your machine learning code in the flexible Eager style, but still enjoy all the benefits that come with running in graph mode. A beta version will be available soon -- stay tuned!" + "AutoGraph is still in its early stages, but is available in [tensorflow.contrib](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/autograph). We're excited about the possibilities it brings. New versions will be available soon — stay tuned!" ] } - ] + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "Dev Summit 2018 - Autograph", + "provenance": [ + { + "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K", + "timestamp": 1522238054357 + }, + { + "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ", + "timestamp": 1521743157199 + }, + { + "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-", + "timestamp": 1520522344607 + } + ], + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 2", + "name": "python2" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/tensorflow/contrib/autograph/examples/notebooks/graph_vs_ag_vs_eager_sum_speed_test.ipynb b/tensorflow/contrib/autograph/examples/notebooks/graph_vs_ag_vs_eager_sum_speed_test.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..32742bec7ee4a412aabb6640b5a1329353ebfc9d --- /dev/null +++ b/tensorflow/contrib/autograph/examples/notebooks/graph_vs_ag_vs_eager_sum_speed_test.ipynb @@ -0,0 +1,519 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "moMkWaT_TTHi" + }, + "source": [ + "This Colab illustrates the differing overhead* between a custom, vectorized graph operation and a loop over a tensor\n", + "that computes the same function. The loop is implemented in TensorFlow Eager mode using Python syntax and control-flow, and using AutoGraph which takes a python function and converts it into graph mode. In AutoGraph the Python loop is converted into a tf.while_loop.\n", + "\n", + "The actual computation, summing a small number of scalar values, takes very little time to compute, so the graphs below are showing the overhead of the differing approaches. As such, this is more of a \"micro-benchmark\" than a representation of real-world performance of the three approaches.\n", + "\n", + "*Note the differing scales of the included plots" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "a0X_rfvuav98" + }, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "EdxWv4Vn0ync" + }, + "outputs": [], + "source": [ + "!pip install -U -q tf-nightly" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "erq3_S7QsjkU" + }, + "outputs": [], + "source": [ + "from __future__ import absolute_import\n", + "from __future__ import division\n", + "from __future__ import print_function\n", + "\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "import matplotlib.pyplot as plt\n", + "import math\n", + "import time\n", + "import random\n", + "from colabtools import adhoc_import\n", + "from tensorflow.contrib import autograph as ag\n", + "from tensorflow.python.framework import function" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "1JgnsXooa2RP" + }, + "source": [ + "### Testing boilerplate" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "UyD5LLjVZzny" + }, + "outputs": [], + "source": [ + "# Test-only parameters. Test checks successful completion not correctness. \n", + "burn_ins = 1\n", + "trials = 1\n", + "batches = 2\n", + "max_elements = 2" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4_NBL0RQa8gY" + }, + "source": [ + "### Speed comparison parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "Yq6daecyiJV5" + }, + "outputs": [], + "source": [ + "#@test {\"skip\": true} \n", + "burn_ins = 3 # Batches not counted in the average\n", + "trials = 10 # Batches run per vector-size (and averaged)\n", + "batches = 1000 # Number of random vectors summed over per trial\n", + "max_elements = 100 # Vectors of size 0 to this-1 will be executed and plotted" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "fiR8m13CbKH2" + }, + "source": [ + "### Random input" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "d8vrTlyNXuxc" + }, + "outputs": [], + "source": [ + "# Construct a random num x 1 tensor\n", + "def get_elements(num):\n", + " return tf.random_uniform(shape=(num, 1), maxval=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ILJ6SbF3bXFQ" + }, + "source": [ + "## Graph mode" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "vovRf597X55n" + }, + "outputs": [], + "source": [ + "def tf_sum(elements):\n", + " # Using custom vectorized op\n", + " return tf.reduce_sum(elements)\n", + "\n", + "def run_trial(num):\n", + " elements = get_elements(num)\n", + " return tf_sum(elements)\n", + "\n", + "\n", + "\n", + "graph_means = []\n", + "for num in range(max_elements):\n", + " with tf.Graph().as_default():\n", + " durations = []\n", + " foo = run_trial(num)\n", + " \n", + " with tf.Session() as sess:\n", + " \n", + " for _ in range(burn_ins):\n", + " for _ in range(batches):\n", + " sess.run(foo)\n", + " \n", + " for _ in range(trials):\n", + " \n", + " start = time.time()\n", + " for _ in range(batches):\n", + " sess.run(foo)\n", + " \n", + " duration = time.time() - start\n", + " durations.append(duration) \n", + " \n", + " graph_means.append(np.mean(durations)) " + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 301 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 278, + "status": "ok", + "timestamp": 1532447361278, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "Jm9Blkyx90Eq", + "outputId": "d83cd51f-7e56-4d73-f7df-bb157dee46df" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAa8AAAEcCAYAAABwNTvaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3WdgVFXegPFnZtI7kJ5QQwlCKIGERGroEoSAYZFVEFAR\ngV1XXHvbFWEtK6xlUVgRXMuLiqBSZQUh9E5ChxRIn/ReJjNz3g8hNxkzCUMJkHh+X2DmtnPPnXv+\n95R7ohJCCCRJkiSpGVHf6QRIkiRJ0vWSwUuSJElqdmTwkiRJkpodGbwkSZKkZkcGL0mSJKnZkcFL\nkiRJanaaJHitWLGCV199tSl23WINHz6cAwcONPlxXnzxRd5///0mP87dIikpiUmTJtGvXz++/PLL\nO52cFi0+Pp4HHnjgTifjmsaPH8+RI0du6T5/b/eVObcqX9966y3Wrl17zfWsbmTnffv2RaVSAVBe\nXo6NjQ1qtRqVSsUbb7zBE088cSO7vW5paWmMGDGCs2fPolY3n0rkiy++iLe3N0899dSdTspdoSmv\n46effsqAAQPYsGGD2eVbt27l888/5/z58/Tq1Yv//ve/JsvPnTvHyy+/TGJiIgEBASxevJjAwEBl\n+bvvvsu6detQqVQ88MADPPvssxZv29SmT5/OxIkTiY6Ovi3H++CDD3jsscea9BjDhw9n8eLFhIeH\n3/A+Nm3adAtT1LgVK1bwySefoFKp0Ov16PV67OzsEELg7+/Pxo0bCQwMxN7eHpVKhRACa2trDh8+\nfNvSeCPMlWG3Kl8fffRRpkyZQnR0NFZWDYeoGyopTpw4wfHjxzl+/Di+vr6sWLFC+W78+PE3nOjr\nJYRQLrjUfDXldUxPT6dz584NLndzc2PmzJnMmTOn3rKqqirmz59PVFQUR44cISoqinnz5qHX6wFY\nu3YtO3fuZOPGjfz000/s2rWLb775xqJtm4PruR7Z2dkcOnSIESNGNGGKbo7BYLjtx3ziiSeUsvHv\nf/87ffv25fjx45w4cYKNGzcCoFKp+Omnn5Tv73TguhP5VJeHhwcBAQHs3Lmz0fVu+jFXCFHvR/7R\nRx8pT6BpaWkEBgayfv16hg0bxoABA1i7di2nTp1iwoQJhIaGsmjRIpPt161bx7hx4xgwYACPPfYY\n6enpZo89ffp0APr3709wcDCxsbEIIVi+fDnDhw9n4MCBvPDCC5SUlJjdPj8/n7lz5xISEsKAAQN4\n+OGHlWWBgYGkpKQon+s2Cxw+fJihQ4fy6aefcu+99zJ48GB++eUXdu/ezZgxYxgwYAArVqwwe8xv\nv/2WjRs38umnnxIcHMyTTz6pLDt37hwTJkwgJCSEhQsXotPplGW//vorUVFRhISEMG3aNC5cuGB2\n/wAJCQnMnj2bAQMGcN9997F169YG121sv8OHD2fVqlVMmDCBvn378sorr5Cbm8vjjz9OcHAws2fP\npri4WFn/5MmTPPjgg4SEhBAVFWVyE06fPp3333+fadOmERwczKOPPkpBQYGyDEyvY3JyMtOnT6d/\n//6Eh4ezcOHCBs9hx44djB8/ntDQUGbMmEFiYiIAjzzyCIcOHeKNN94gODiYK1eu1Ns2PDycsWPH\n4uHhUW/Z4cOHMRgMzJgxA2tra6ZPn44QgoMHDwLwww8/MHv2bDw9PfH09GTWrFlKDe/QoUONblvX\nli1b6jW3rVmzhnnz5gGg0+l4++23iYiIYNCgQfztb38z+W388ssvREVF0a9fP0aPHs3evXtZtmwZ\nx44dY9GiRQQHB/Pmm28CcPz4caKjowkJCWHKlCmcOHHC5BotW7aMadOm0adPH1JTU1m/fj0jR44k\nODiYkSNHNvh0vW/fPnr06IGNjQ0AK1eu5M9//rPJOm+++SaLFy8GoKSkhJdffplBgwYxdOhQ/vWv\nf5mUI99++y3jxo0jODiY8ePHc+7cOZ577jkyMjJ48sknCQ4OZtWqVUD965+QkKDsZ/jw4fznP/9R\nfsMGg8GkiT4kJITg4GCCg4Pp27cvgYGBSnnT2L1x9uxZJk+eTL9+/Xj66aeprKw0my+WsPQhobGy\nraac/fbbbxk8eDCDBw9m9erVJtuuXLmSUaNGERYWxtNPP01RUZHJtuvWrSMiIoKZM2cC8NRTTzFo\n0CBCQkKYPn26kq8NlWF181Wn07F48WIGDx7MkCFDWLJkCVVVVUBt+bl69Wql/Fy/fr3JuYaEhLBr\n165rZshNiYiIEPv37zf57sMPPxTPPvusEEKI1NRU0a1bN/H666+LyspKsW/fPhEUFCTmz58v8vLy\nRGZmpggPDxdHjhwRQgjxv//9T4wePVokJiYKg8EgPv74YzF16lSzx05NTRWBgYHCaDQq33333Xdi\n9OjRIjU1VZSVlYkFCxYoafmt9957T7z++uvCYDAIvV4vjh49qiwLDAwUycnJyucXXnhB/Otf/xJC\nCHHo0CFxzz33iOXLlwu9Xi++/fZbERYWJp555hlRVlYmLl26JIKCgkRKSorZ49bdV918nDJlisjO\nzhaFhYXivvvuE2vXrhVCCHH69GkRHh4u4uLihNFoFBs2bBARERFCp9PV23dZWZkYOnSo2LBhgzAa\njeLs2bNiwIABIj4+vt6xr7XfiIgIMXXqVJGbmyu0Wq0IDw8XkyZNEufOnRM6nU7MmDFDfPTRR0II\nITIzM0VoaKiIiYkRQgixf/9+ERoaKvLy8oQQQjz88MNi1KhR4sqVK6KyslI8/PDD4r333mvwOi5c\nuFB88sknQgghKisrxbFjx8zmZWJioujTp4/Yv3+/0Ov14j//+Y8YNWqUqKqqUo773Xffmd22rm+/\n/VZMnz7d5LvVq1eLxx9/3OS7J554QqxevVoIIUS/fv1EbGyssuzUqVMiODjYom3rKi8vF8HBweLK\nlSvKdw888IDYsmWLEEKIN998Uzz55JOiqKhIlJaWirlz54qlS5cKIYSIjY0V/fr1U+5BrVYrEhMT\nzZ57QUGBCAkJET/99JMwGAxi06ZNIiQkRBQUFCjrR0REiPj4eGEwGERxcbEIDg4Wly9fFkIIkZ2d\nrfyOfuvtt98Wb7zxhvI5LS1N9OnTR5SUlAghhDAYDGLgwIFKfj355JPi9ddfFxUVFSI3N1dMmTJF\nfPPNN0IIIbZs2SKGDBkiTp8+LYQQIjk5WaSnpwshqn+TBw4cUI5zresfEREhoqKiRGZmpqisrFS+\n+22ZJYQQS5cuFQ8//LDQ6/WN3hs6nU5ERESIzz//XOj1erFt2zbRo0ePevf0b61fv1788Y9/rPd9\nt27dTMqahjRWttWUswsXLhQVFRXiwoULIiwsTDnP1atXi6lTpwqtVit0Op147bXXxMKFC022ff75\n50V5ebmST99//70oKysTOp1OLFmyREycOFFJS0NlWM3x/vWvf4mpU6eKvLw8kZeXJ6ZOnSref/99\nIURt+fnhhx8KvV4vdu3aJXr37i2KioqUfW3fvl1MmjSp0fy4LR1FKpWK+fPnY2Njw7333ou9vT2R\nkZG0atUKLy8v+vfvz9mzZwH45ptvmDNnDh07dkStVjNnzhzOnz9PRkZGYwFY+f+mTZuYOXMmfn5+\n2Nvbs3DhQrZs2YLRaKy3nZWVFdnZ2aSmpqLRaOjXr5/ZfZpjbW3N3Llz0Wg0jBs3jvz8fB555BHs\n7e3p3LkznTt3brR2ZM6MGTNwd3fHxcWFiIgIzp07B8B3333Hgw8+SFBQECqViqioKGxsbIiNja23\nj19//RV/f3+ioqJQqVR0796d0aNHs23btnrrWrLfhx9+mNatW+Pp6Un//v3p3bs3gYGBWFtbM2rU\nKCWNP/30E8OGDWPw4MFAdY2mZ8+e7N69W9nX5MmTadeuHTY2Ntx3333KtjXq5rmVlRVpaWlotVps\nbGwIDg42m2dbt25l2LBhhIeHo9FoePTRR6moqDCpUdyosrIynJ2dTb5zcnJSnnZ/u9zZ2ZmysjKL\ntq3Lzs6OESNGKLWay5cvk5SUpDTBrVu3jhdffBFnZ2ccHByYM2eOsu66deuIjo5W+oA8PT3p2LGj\n2fPZtWsXHTp04P7770etVhMZGUmnTp349ddflXUmTZpEQEAAarUajUaDRqPh4sWLVFZW4u7uTkBA\ngNl9FxcX4+joqHz29fXlnnvu4ZdffgHgwIEDODg40KtXL3JyctizZw8vvfQStra2tG7dmkceeYTN\nmzcr5/TYY4/Ro0cPANq2bYuPj4+y77q/E0uu/4wZM/Dy8lJqheZs2bKFTZs28eGHH6LRaBq9N2Jj\nY9Hr9cyYMQONRsOYMWPo2bNng/u2xKRJkwgJCSE0NFSpnf6WJWXbn/70J2xtbenatSuTJ09W8vTb\nb7/lL3/5C56enlhbWzN//nx+/vlnZVuVSsWf/vQn7OzslHyaPHky9vb2yvrnz59vsBXLXFrnz59P\nq1ataNWqFQsWLODHH39UlltbWzNv3jw0Gg1Dhw7FwcGBpKQkZbmjo6NJq445NzRg40a0adNG+b+d\nnR3u7u7KZ1tbW+WmT09PZ/Hixbz99ttAbX+IVqs1+QE3JCsrC19fX+Wzn58fer2enJwcPD09TdZ9\n7LHH+PDDD5k9ezYqlYopU6aY7fswx83NTRm0YmdnZ/Yca87JUnW3t7e3Jzs7G6jOkx9//FEZLSeE\nQK/Xk5WVVW8f6enpnDx5ktDQUGVdg8FAVFSU2XWvtd+6abK1ta33ue5127p1q1IQ1uyrbsd63Wtu\nb2/faP4899xz/Otf/yI6OlrplzI3ku2311ulUuHj44NWq21w35ZycHCod7OWlJTg5ORkdnlJSQkO\nDg4WbftbkZGRvPPOO8ybN49NmzYxcuRIbGxsyMvLo7y83OTcjUajUoBnZmYydOhQi87nt3kF1UGm\nbl55e3sr/7e3t2fZsmWsWrWKl156iX79+vHcc8/RqVOnevt2cXGhtLS03jlt3ryZiRMnsmnTJqU/\nPD09Hb1ez6BBg4Daroea+zszM5N27drd0DmZu/51z8mcs2fPsmjRIlavXo2bm5uSxsbuDS8vL5N9\n+Pn5WZTehmzYsIG2bds2uk5jZRtUn3vdc/X19eXSpUvK+SxYsEAZECWEwMrKStkWTPPJaDSydOlS\nfv75Z/Lz81GpVKhUKvLz8xv8DTeWVl9fX5Nyxc3NzWRwlp2dncnvp7S0tN7D32/dtuBlKW9vb558\n8kmLBn7UBI+6PD09TfrI0tLSsLKyMik4azg4OPD888/z/PPPk5CQwPTp0+nVqxdhYWHY29tTXl6u\nrJudnX3Nm6CpeHt7M3fuXItGcfr4+DBgwAClP+BW7deS40ZFRfHGG29c97bmrmObNm2UvtBjx44x\na9YsQkND693gnp6eyg1aIyMj45Zcqy5durBmzRqT7y5evKj00XXu3Jnz588TFBQEVPdZdunSpdFt\n6/ar1jVo0CBefPFFzp8/z+bNm3nppZcAaNWqFfb29mzatKnewxdUX8O6fbN1/TZfPT092b59u8l3\n6enpDBkypMFtBg4cyMCBA9HpdCxbtoxXX32Vr776qt6xunXrZvJkDTB27FjeeecdtFotv/zyizKY\nxcfHB1tbWw4dOmT22nt7e5OcnGzxOd3M9c/Ly2PBggW8/vrrJiNBG7s3jhw5Uu/hKD093eKAe6Ma\nK9syMjIQQpCRkaHUvDMyMpTfjI+PD0uWLKFv37719puWlgaY5u3GjRv59ddf+fzzz/H19aW4uJiQ\nkJDrSmtaWppSU09PTzf7+21IQkLCNUfm3pZmw2s1wdU1bdo0VqxYQXx8PFDdHGGuyQugdevWqNVq\nkx96ZGQka9asITU1ldLSUpYtW0ZkZKTZIdi7du1StnVwcFCaSaB6wMamTZswGo3ExMTc0vdC3N3d\nGyxwzPnDH/7A2rVriYuLA6qbpHbv3m225jJs2DCSkpL48ccf0ev1VFVVcerUKWUQw43u91omTJjA\nzp072bt3L0ajkcrKSg4fPmxRDcjcddy2bZuyrYuLC2q12uw1vO+++9i1axcHDx5Er9ezatUqbG1t\n6dOnj0XpNhqN6HQ69Hq9yf8BQkNDUavVfPHFF+h0OuUpfMCAAQBERUWxZs0atFotWq2WNWvWMHny\n5Ea3DQsLM5uOmuand955h6KiIgYOHAigtAgsWbKEvLw8ALRaLXv37gUgOjqa9evXc/DgQYQQaLVa\n5Vr/9nc2dOhQrly5wubNmzEYDGzZsoXExEQiIiLMpik3N5edO3dSXl6OlZWVco+YM3DgQM6cOWMy\nkKR169aEhITw4osv0rZtW6XG5uHhwcCBA1myZAklJSUIIUhJSVHusSlTpvDZZ59x5swZAJKTk5Vu\nA3d3d1JTU5Vj3Mz1NxgM/OlPf2LChAmMHTvWZFlj90afPn2wsrLiiy++wGAwsH37dk6dOnXN490s\nS8q25cuXU1FRwaVLl1i/fj2RkZEATJ06laVLlyrBLy8vjx07dijb/baMLi0txcbGBhcXF8rKynjv\nvfdMgtu1yrDIyEg+/vhj8vLyyMvLY/ny5UycONHicz1y5IjJQ5U5Nx28zD05XWudxj6PHDmSxx9/\nnKeffpr+/fszYcIE9uzZY3a/dnZ2zJ07l2nTphEaGkpcXBzR0dFMnDiRhx9+mFGjRmFvb88rr7xi\ndvvLly8zc+ZM+vbty7Rp03jooYeUp4uXX36ZnTt3EhISwubNmxk5cuRNnWNd0dHRxMfHExoayoIF\nC665fs+ePVm0aBFvvPEGoaGhjBkzpsH3lhwdHfnss8/YsmWLMurovffeMylULN3v9ZyTt7c3y5cv\nZ8WKFYSHhxMREcFnn32m3BSNbWvuOp46dYopU6YQHBzM/Pnzefnll802zXTs2JF3332XRYsWER4e\nzq5du/jkk0+U90Ou9fv88ccf6dWrF2+88QbHjh2jd+/eygv21tbWLF++nA0bNhAaGsr69etZvny5\nsu8HH3yQiIgIJkyYwIQJE4iIiOAPf/iDRduaExkZyYEDB7jvvvtMCqS//vWvtG/fnj/84Q/079+f\n2bNnc/nyZQB69erFkiVLWLJkCf369WPGjBlKQT9jxgy2bdvGgAEDWLx4MW5ubnzyySesWrWKsLAw\nVq1axYoVK3B1dTWbV0ajkdWrVzNkyBDCwsI4cuQIr7/+utm0t2nThrCwMKWPq8b48eM5cOAA999/\nv8n3b7/9NlVVVURGRhIaGspTTz2lNJOPHTuWuXPn8swzzyjXv7CwEIA5c+awfPlyQkNDWb169Q1d\n/5rvMjMzOX78OJ9//rky2jA4OJjMzMxG7w1ra2s+/PBD1q9fT2hoKNu2bWP06NENXtdrsaQMBSwq\n20JDQxk1ahSzZs3iscceU5rtH3nkEUaMGMHs2bPp168fDz74oBKYzaUhKioKHx8fhgwZwvjx4+vV\n2K5Vhs2bN4+ePXsyYcIEJk6cSM+ePZk7d65FeZCVlUVCQsK1y1xxPdUiSZKkBiQkJPDCCy/w3Xff\n3emk/O6kpaUxcuRIzpw506wmbDDn7bffpl27dkybNq3R9WTwkiRJauaa62xDN+P3cZaSJEktnKXN\njy2FrHlJkiRJzY6seUmSJEnNzl33ntfN0OsN5Odf/zDvlqhVKweZF1fJvKgl86KWzItaHh6NvxB8\nN2pRNS8rK/PvoPweybyoJfOilsyLWjIvmrcWFbwkSZKk3wcZvCSpBSsq1VFcVv8FdUlq7mTwkqQW\n7J9rT/LBurhrryhJzUyLGrAhSVIto1GQnlOKrY3s25FaHlnzkqQWqri8CqMQlFfq0Rvq/z07SWrO\nZPCSpBaqsKT2T9MXl1XdwZRI0q0ng5cktVAFJbUDNeSgDamlkcFLklqoujWvIhm8pBamyYNXTEwM\nY8eOZcyYMaxcubLe8qNHjzJ58mR69OhR76+8QvWfTh8yZAhvvvlmUydVklqUgtI6Na9S2WwotSxN\nGryMRiOLFi1i1apVbNq0ic2bN5OQkGCyjq+vL2+99Va9P1ZX4/333yc0NLQpkylJLZJpn5eseUkt\nS5MGr7i4ONq3b4+fnx/W1tZERkaa/OlpqA5eXbt2NTud/+nTp8nLy2PQoEFNmUxJapEK6/R5FckB\nG1IL06TBS6vV4uPjo3z28vIiKyvLom2FELz99ts899xzyL/aIknXr6BU1ryklqtJX1K+maDz9ddf\nM2zYMLy8vK5rX81xduSmIvOi1u8xL4rL9djbaiivNFCpF0oe/B7zoiEyL5qvJg1e3t7epKenK5+1\nWi2enp4WbXvixAmOHz/O119/TWlpKXq9HkdHRxYuXNjodtnZxTeV5pbCw8NZ5sVVv8e8EEKQV1iB\nn4cjqVkl5BaUkZ1d/LvMi4bIvKjVHIN4kwavoKAgkpOTSUtLw8PDg82bN7N06dIG169bu/rnP/+p\n/H/Dhg2cOXPmmoFLkqRqZVdn1WjlZEthSaUcKi+1OE3a56XRaHj11VeZPXs248ePJzIykoCAAD74\n4AN+/fVXAE6dOsXQoUPZtm0br7/+eoOjDiVJslxBcXV/l6uTDS4ONnLAhtTiqEQLGw0hmwGqySaR\nWr/HvDhzOY/31p5kwsAOJKQVcuZyPp88MxQ/X7ffXV405Pf4u2hIc2w2lDNsSFILVPOOl5uTLc6O\nNoCc31BqWWTwkqQWqOYdL1cnG5ztrwavctnvJbUcMnjdJpU6A2u2nufoecvec5Okm1EzKa+bky0u\njtYAFMkpoqQWRAav2+RkfA4xseks/+E0KzeeobRCFiRS0ym8+oKyq6MNzg41zYay5iW1HPIvKd8m\nlzOLAGjjYsfBM1ouJBfwxIQedG3rdodTdnuUlFdx8lIO+cUV5JfoqKjUM3FQR7xaO9zppLVIBSU6\nVICLow3ODtU1L9nnJbUkMnjdJpczilEBf58dwi9HU/lp32VWbz3PP+aE3emk3RZfbr/A4XOmTaat\nXGyZMqzzHUpRy1ZYUomzgzVWGjUuV2te8l0vqSWRzYa3gVEILmuL8XF3xMHOmgmDOtLZz4WsvDKq\n9IY7nbwmpzcYiUvIpY2LLU//oTcvPhwMQFp26R1OWctVUKrD1ckWoE7Nq+UEr5LyKjnn6e+cDF63\ngTavjEqdgQ7ete9SeLdxQADa/PI7l7Db5GJKARU6A327eBDUqQ1d/N1wdbIhNbvkTietRarQ6anU\nGXB1qq5x1fZ5tYxmw0upBfz5/T38c+1Jsgta/v0jmSeD121wObP6RUiT4NXaEYDM3LI7kqbbKTY+\nF4Dend2V7/w9nMgrqqRMDly55WqGybs5Vte87Gw0WGnULabmde5KvvLvq6sO8cvRFIyyFva7I4PX\nbXA5oyZ4uSjfeV8dqJCZ1/KDV1xCDrY2GpPBKf4e1cE7VTYd3nIFJbVTQwGoVCpcHK1bzFD5mt/M\nlIgArDVqvv7lEv/dduEOp0q63eSAjdvgcmYRKhW09XJSvvNu8/sIXpl5ZWjzy+nX1QNrq9pnJX+P\n6rxIyy753Yy4vF56g5HSCj2uV2fIsFRhae07XjWc7W3IyL27HxTKKqrYeiiZvKJKist1lJRVMaS3\nL8P6+pmsl5Zdgr2tFWND23FvTx+WfHGUg2cyeWhUF6ytNHco9dLtJmteTcxoFFzRFuPn7oitde2N\n5e5qh0atQtvCg1dsfA4AvTq3Mfm+JnjJmlfDvvrfRZ7/ZD9FpdfX3KdMylsn6Dk7WqPTG6mo1N/S\nNN5K2w4ns/nAFQ6cyeR0Yh6XM4v5+UiKyTpVegPavHL8PBxRqVS4OtrQr6snOr2R88kFdyjl0p0g\ng1cTy8gtRVdlNGkyBLDSqPFwsyczr6xFj5pSgleAu8n3Pm0cUKmQgzYaUFpRxf7TmeiqjCSmF13X\ntgVmal41w+VrmhTvNlV6I7tPpuNoZ8U/ngjj44VD6d6+Fdq8MpMX+tNzyjAKQVuP2laMoIDqB6O4\nq32rUjVtfhknLmbf6WQ0GRm8mpgyWMOn/qzN3q0dKK3QU1x+Y30Rd3vQK6uo4lJqIR19XOo1fdlY\na/Bq5UBqduldfx51ZeSWsnzDKfKLmzYIHDidSZXeCNS+4G6pwt/0eUHtcPnrrcXdLkfPZ1FcVsXg\n3r54tXLA1kZDJ9/qB76kOsG75mGnps8UoIu/K/a2GmITcprVb6mprd5yng/XnyIhvfBOJ6VJyODV\nxMwN1qihDNpoZMThtzvj+b9fLtX7Pq+ogqc+2Mv2w8m3KKW33umkPAxGQe+ANmaX+3s4Ul6pNwkE\nCWmFrN1xCYPReLuSeV02xCRy9EI2/zuaUm9ZUZmO9JybbwYVQhATm45GrQJqH4AsVTuvYW3wuttr\nXjuOp6ICIur0b9UEr8SM+sHLr07Ny0qjpkeH1uQUVpDxOxi9a4n84koupVQ3o/58uP5vtSWQwauJ\nXc4sQqNW0dbTsd6yaw3aiE8tZNvhZP53NKVeoRgTm05JeRUHz2pvfaJvQn5xJTmF5RSWVHL8apNF\n3SHyddX2e9U2Ha7deYntR1K4eBf2X+QUlnPs6jntP52J3lAbYIUQfPh9HH9bfZicm3z3KDG9iNTs\nUvp29aCNiy2XM4quq0ZRWKrD0c7KZPCCU03N6y4MXkkZRSSmF9G7szsebvbK9518XQFMmk1r+kjr\n1rygtlk6LkE2HQIcOZ+FADRqFccuZLXI9+Fk8GpCeoOR5KwS/DwczY6Camy4vBCC73cnKJ93nUxT\n/m8wGtkTlwHAFW3xLZvk1yjETTW7nLyUwzP/3sdzHx/g6Y/2cfhcFm5ONrSrM8qyLj9lxGF1gZSZ\nV0ZCWnVBdSn17mvq2Hk8DSGqB9sUleo4lVhbUF5MKSAhrQi9QbD54JWbOs7u2HQAhvb2pYO3C0Vl\nVdfVTFlYUqnMrlGjtuZ19zUb7jyWCsDwfqajCl0dbWjjYkdiem3wTs0uoY2LLQ521ibrKv1eCTm3\nIcV3vyPntKhUED0sACEw21LQ3Mng1YTSc0qp0tcfrFGjsWbDM5fzuJBSQI+OrXF1tGH/qUwqq6qn\nkjqVmEd+cSU21mqE4JbUUsor9Ty7fD///fnG3pcxCsH6mERUKgjr4UVIoCd9u7jz4IguqFQqs9v4\ne9a861WD0HifAAAgAElEQVRd89p/OlNZdjH17qp5VeoMxJxMx8XBmicm9gBg79UHCIBth6qbb53s\nrdkbl0FuYcUNHae8Us/hc1rcXe3o3qGV0ldqadNhld5gdnh9zSwbhXdZzauoTMehc1l4tXbgng6t\n6y3v5OtCSXkV2YUVFJfpKCzRmTQZ1nB1tKGDtzOXUgspq7gzIypv9kXp3zah36icgnIS0osIbNeK\nEf38aeVsy57YjBb3lyxk8GpCjQ3WgOpOdAdbq3o1r+paVyIAU4YFMLi3L2VXCzWAmJPVT+YPDAkA\n4Fxy/k2n9XRSdUDcfTKds5fzrnv7ExdzSM0uYcA9Xsy5vwdPRvXkTw/0IrS7V4PbeLjZY2OtJjW7\nFKMQHDidgZ2NBk83exLSiu6qfq/9pzMoq9QzrK8fAb6utPNyIjY+l8KSStJySolNyKWznysPjuiM\nwXjjta9DZ7XoqowM7u2LWqVSHnwsHbRRaKa/C8DlarPh3RS8Siuq+GFPEnqDkeHBfqjNPOQo/V7p\nhXWaDM3X5HsFtMFgFDf0+71Ze2LTmbd0NxdTbuyhq7xSz5v/Pcqrnx666Vlnjlz9m4ED7vHCSqNm\nZH9/KqsM7L5abrQUMnjdpLKKKtbuuGS2Tfn81WlsOjZQ81KpVHi3cSC7oNykoD52IZsrmcWEdvek\nnZczQ3v7olLBrhPp5BVVEJuQQwdvZ4b19cPaSq0c52bUHVL7xc8XrmvCYCEEG/cloQLGh3eweDu1\nSoWfuyMZuaWcu5xPblEl/QM96d6hFZVVBpK1t3YYfUZuKa+tOsSWRgJLld5ITGw6b311nP/+fIH8\n4kqMQvC/o6lo1CplQMHgXr7VAfeMlp+v1rrGDmjHgHu88Gxlz57Y6mt1PfQGI7+eSEOtUjEoyAeA\n9lenFKsZ+NMQo1EQl5DD59vOA9RrNqyted1cs2FZhf6mA2BiehErN55h4Uf72HUiDRdHGwb29DG7\nbm3wKjI70rCumr7V2NvcdGgUgi0Hr6CrMvLpprOUX+e7dEIIVm0+R0ZuGWWVemJiM669USMOn8tC\no1YR3NUDgKG9/bCz0fDL0RQycku5kJzP0fNZN9w6cLeQM2zcpB/2JvHL0VS0eWU8NaW38n1uYQVH\nzmfh08bBZGaN3/Ju7UBiehE5BRV4tXbAaBRs2JOIWqUianAnANq42tGrUxtiE3L5+pdLCAFD+/hi\nbaWmi78rZy/nU1Sqw+U6Z2KooTcYiU3IpY2LHX27uvPL0VS2HExm4qCOFm1/Mj6H5KwSQrt74utu\nvmBpiJ+HE0kZxayPqa5pDuzpTV5RdQ3wUkoBHX3MB/7rVVSqY9m3seQUVrBuVwKujjYMDKotMCt0\nenYcS+WXo6nKDBUXUwrYfyqD3p3dycwr496e3kpQGHCPF9/svMTO46nkF1fi1dqBPl3cUatUjA/v\nwGdbzrH54BWmj+5mUfqMQvDZ5nOkZFXXXls5Vx/Hyd4ad1c7LmcWI4Qw2wR75nIen289T87VwqiT\nr4vJuQHY2miwsVYrf6SyMWcu53HyYg4TB3fEyb62b6msooq/rT5CaYWev88Kwb3O4ApLZeWX8Y8v\nj2EwCrxa2TOkty/3BvngYGe+KGrv5YxGrSIpvQjd1Wbzhmpe7b2dcXG0ITY+l437krC11mBtrUGv\nN1JRZUBXZSCwfSt6mGmevBnnLuejzS/H0c6KnMIKvv01nkfGBlq8/ZaDVzh+MZvO/q6kaEv45VgK\no0L80ajN1y1yCsrJK640OzONNq+MK9pigjq1Ua6dg50VQ3r7sv1ICi//55Cybr+uHsyfHHSdZ3v3\naPKaV0xMDGPHjmXMmDGsXLmy3vKjR48yefJkevTowfbt25Xvz58/z4MPPsj999/PxIkT2bJlS1Mn\n9bplFZTz6/HqgRSxCbkmo6K2H0nBYBTcN6C92eaQGjX9XhlXmw73nsogI7eMgUHeyjJAmSLn+MVs\nbG00SnNc9/atADh/E02H55PzKa/U07erO5MGd6KVsy2bD1y2aOoqIQQ/7b2MCrh/oGXBri7/q8Eu\nKaMId1c7urR1o0vb6lFmF2/RoI0KnZ7318WRU1jB0D6+ONhasWbreaWJJz6tkL99doTvdydSWWVg\nbGg73pkbzsz7AnG0t1aaYUb1b6vs08nemr5dPMgprMBgFIwJbatc5/CeXni42bEnNp2kjPrNffFp\nhRw4k6k0Dwkh+HL7RQ6e1RLg58LM3xR8HXyq+33MPSnnFJTz8YbTFJRUMqS3D6/PDOGVGf3xM/MQ\n4eJgQ2EjfSpVeiPf7LzEe2tPsuN4Kh//cFoZUSmE4L8/XyCnsILySj2fbjqL0Xj9fTw198VDo7qy\nZE4Y94W1b3T6KxtrDf4eTlzRlnA5sxiNWqWM0v0ttUpFv64elJRXsWFPEmt3xvPFzxf4vx2X2BCT\nyOYDV1i+4dRN9f2cuZxXr1lv5/HqASd/eqAX/h5O7D6ZbvGoxzNJeayPSaSVsy0LJgUxMKj64e3Y\nBfMvF5+8lMNrnx3mra+Om32lpKZrIbS7p8n348LaE9bDi0G9fIgMb88fR3Zh6vDm/bf0mrTmZTQa\nWbRoEWvWrMHT05Po6GhGjBhBQECAso6vry9vvfUWn332mcm29vb2vPPOO7Rr146srCwmT57MkCFD\ncHJquBZzu22IScRgFAzr68euE2n8uDeJp//Qm5LyKmJi02nlbEtYj4b7fMB00EZlOwM/7EnExkqt\n1LpqBHVqQxsXO3KLKhjQ3Qt72+pLF6gEr4JG+5cac+JidTNLcBcP7G2tmDaiC8t/OM0XP1/gmQf7\nNBp8YxNyuaItJiTQ02yBeS1+nrXX896e3qhVKtxd7WntYsul1AKT2sbxi9moVSr6dDE/9N4cg9HI\ne18dIymjiHt7ejNjTDdCAz1Z+m0sH60/xb09vatHYgkYG9qO8fd2UGoBQ9zsCbvHi10n0hDUNuHV\nGNzLhyPns3BxsGZgT2/le41aTdSgTvxn01kWfX6UPp3dGX9vB/KLK9l2+IoyotLaSk2/bh7YWGmI\niU2nracTT0/pja2N6cjUjt7OHD2fxeXMYpPajt5g5OMfz1BWqWfWuEAG9/JtNC+cHayVl8Lr1uCM\nRkFSRhFf/HyB5KwSvFrZ08bVjrOX8/lmZzwPjerKgTOZHD6XRYCfC26Othy7mM22w8mMC2tv8bUo\nLtOxNy6DNi62DO3j2+BAnt/q5OvCFW0xydoS/D0csdI0/Mz94IguhPf0plJnoPJqbcvaSo2tjYaz\nl/PZdiiZnw+nMHlI7f1VXqlnQ0wirVxs6dGhNf6eTmZ/86cTc1n6bSyd/Vx57o99sdKoyS2s4GR8\nDu29neni78rj99/DG2uOsHrrOSYO6kh8aiHxqYU4OVgzc2wg/nV+76cTc1nx0xk0ahXzJvXExdGG\nUf3b8uvxNLYfSTG5n4UQbD5whQ0xiVhZqfFws2P7kRTSckqZO7EH1ho1567ks/dUBlYaFX27eJik\n3cXRhjn397Aov5uLJg1ecXFxtG/fHj+/6lpDZGQkO3bsqBe8gHo/5Pbta28KT09P2rRpQ15eXqPB\nq6lmD9AbjPxn41kc7ayYEtEZe1srkjKKOHRWS3tvZx4e3ZXM3FJOJeYSn1bI2aQ8KqsMRA3u2OiN\nBqbD5bcfTaGgRMf4e9srzUY11GoVYwe045ud8QwPrh1S3MHbGTsbjfJnIqC6U/7AuSw6eTnh1dr8\nU2oNoxCcuJSNk721UuPp182D3gHVzZQ/H07mvgHmC6hKnYFvdsYDcP/ADo0epyF1m4DC6wSALv5u\nHDqrJTOvDJ82jqTnlLJ8w2mMQjDzvkCG9DZfUFfqDOyJS+dCcgEZeWVk5ZehNwgC27kx875AVCoV\n3Tu0ZvqYbqzZep7tR1Jo42LHY+O7061dq3r7s7HWMDq0ndlj3dOhNSOC/Qls71bvVYjwnt44O1jz\n077LnIzP4WR8bT9Mn87udPB25sCZTA6eqX5S9mrtwDNT+9QbAg61f0onKbOI/oG1T9TrdiWQlFFE\neA8vpY+sMc4ONlTpiymr1KPNK+dMUi4XUwtJSCukQlfdJDe4lw/TRnZBCFjyxTF2HEvF0c6K7UdS\nsLPR8Pj9PXCwtSI+vZANMYn06NC6XlBvyK8n0tDpjYwKaXfN+6KuTr4u/HqiuoWjoSbDGtZWajr7\nuZpd1sXfjf2nM/nlaAqjQ9pSU7x/uf0iB85Uj3T9jgScHawZHdKWyN/032692rcZn1bI97sTmDq8\nC7tjq1+fGN7XD5VKRVtPJ6IGd+T73YnKTPf2thqyCspZ9N+jTBvRhUG9fPhhTxJbDl7BSqNi1n3d\nCbj6TptXawd6d3bnZHwO8WmFdPZzpbBUx5fbL3DsQjatXWz50+ReeLjZs3LjGeIScnn5P4eoqNSj\nuzojy5Devg02w7YkTXqGWq0WH5/am8rLy4tTp05d937i4uLQ6/W0a2e+EKnx0GtbmTaiC6NC2ja6\nXg0hBMnaElq52CrvwZizcd9lpenodFIecyb0YMPVPpo/DAtQ+qfe+uo463YlkJ5TiuPVduZr8Wxl\njwpISC/k8Dktzg7WDQaL4cF+DO7lg02dCX41ajVd27oRl5BLXlEFGo2at78+oTT5dW3rxuBePrT1\ndEKjVqFWq2jtbKc83SdlFFFQomNgkLfSxq5SqZg1rjuvrz7M97sS6eLvZrZA+L8dl9DmlTE6pO01\nC5WGuDra0M7TiVbOtni1qg20Xf1dOXRWy6XUQnzaOPLdr/EYhcDGWs3nW89jY6UmrEdtsCutqFL6\nrEquTrdlb6uhracTgR3bEBna1qTAHNLbl0qdgdyiCiYO6qjUZK+HWq3iodFdG1zes1MbenRszfkr\n+ew4noaTfXWhWNMveP/ADlxKLeR0Uh7D+vg22GdpbtDGiUvZbD+SgndrB6aP6WZRLaZmiqjnPt5P\neWXtgBzv1g508XelXzdPetWZDeVP0b1YtOYIP+27DMBj47vjebXm9+i47iz9NpaVG8/w2swQk0mn\nzdFVGdhxLBV7WysG97p2oK2rZtAGgF8DgzUsYWutYVxYe9buuMTPh5OZ2641h85qOXAmk44+zozs\n35azSXnEJuTy/e5EurVtRWf/6t/9lcxizl3Jp4u/K8VlVfx8OIWOPi7EXJ2PMfSe2lrSfQPao1ar\nsLPW0KWtG77ujsTG5/DZ5nP89+cL/Lg3icJSHZ5u9syN6lHvVZpRIW05GZ/DtkPJdGvnxg97Eimv\nNNC1rRvzonoqv5M/P9CLDXsS2XYouTroBbShd2f3BoN3S9OkwetWzDOWlZXFc889xzvvvHPNdVs5\n27J25yU6+LtxbyNNKAaj4NDpDNbviufClXzcnG15ZVYo3drX78i9mJzP5oNX8Gxlz+A+fqzfFc8/\nvjyGEBAc6MmQkOpA4+HhTJ/DKZy8VN1WPXVkV9r513+SN8ejtYPyou7MyHss3q5GSA9v4hJyOZ9W\nxC+Hk8nMK2NYP3/yCiuIi8+pN3zXyd6a52f0p09XT7ZcnTpmWP92eHjUPkF7eMDz00N45ZN9rNx4\nlg+eGaaMWAM4cCqdmNh0Ovq6MDe69039KYoPnx2uzAZQY0AvP77YfpHk7FLSCyqITcilZ0AbHpvQ\nk5c/3senm89hY2dNhc7A8QtZnE7IRVdlwMnemgdHdWP0gPa4u9k1Wqj/cdw9N5zm6+Hp6aL8Tswt\nGxh87YctH3dHkrNKcHd3YvfxVFZuPIuNlZqXZoXS1teywiqwozv7TmXiaGfN4D7+BHfzpGdAm3oj\nE2t4eDjz4sxQ/v7pQQb38WPCsNp39iI8nLmYVsSmfUms3nqBF2eG1KtNGY0C9dVruu3AZYrLqoge\n3uW6f99t2jjhaGdFaYWeHp09TH6n1yt6VDe2H0lmx7FURoZ14MvtF7Cz0fDCI6H4ejgxYRicTcrl\n+Y/28u2ueN57aihqtYrPf74IwEP3dcfd1Z6F78ew8qczGAVEDQ3A39d08MSM8T1NPo/2dKFvdx/+\n+dVRziblMaSPH/On9DZb03Z3d2Ld7gSOX8zm+MVsHO2tmTu5B2PDO5jcIwBzo/vw+OTe9b7/PWjS\n4OXt7U16eu27BVqtFk9Pz0a2MFVSUsLcuXNZuHAhvXr1uub6rz0Wxgsf7eWfXx3jWaPR7BNIfGoh\nqzafRZtfPbS9W1s3LqYW8MK/9zF7XKDJ07yuysA/vzyK0Sh4ZGwg3du3orOPM//ZdJbCEh0TwtuT\nnV37NDxuQDtOXsrG2kpNeHdPk2WN8XS1IyuvDK9W9gR3bmPxdjXaXu3A/vTH0wCMCPbnqWnB5OSU\nkFVQzqGzWopLdRiMAl2VgUPntLy+8iAPjujMvtg0bKzU+Le2r3dcb1dbJgzsyA97k3jn8yPMmXAP\ndjZW5BdX8v7aE1hbqZk9rjsF+bd+Pjk7DTjaWRF3KZv4q4NRJg/uiIuthqem9Oa9tSf56LtYZX1f\nd0cGBnkzrI9fdS1Krycnp3potYeH83Xn6d2mrYcjh3NK+ft/DnDsQjZ2NhrmTOyBk7Xa4nML7+7B\nkL5j0FfolCCkK9eRXd5wc7uvmx3LFgzC3laj5GeN+8Pbk5RWwOGzmbzz+REeHd8dtUpFUkYRq7ec\nI6ugnF4B7oQGevJ9TCIateq67ou6Ovq4cDopDxdbzU1fy7Gh7fj6l0s8/9FedFUGZt4XiDVC2a+H\nkw1h93hx8KyWDTsvck+HVuw5mYa/hyNtW9ujUqmYProrqzafA2BANw+L07RwSm+0+WXVk3IXV1Ba\nbH64emRYez7+4TThPb2JHhaAi4MNeblN9xcYbuaB4E5p0uAVFBREcnIyaWlpeHh4sHnzZpYuXdrg\n+nVralVVVcyfP5+oqChGjx5t0fE6+7vxZFRPPlgXxwfr4njuj31NmrMOn9Py6aZzGI2Cwb18GDug\nHT5tHDmVmMsnP55m5cazxKcV0qNja/zcHdlxLI2M3DJG9vNXRvV1a9eKxY+FUVymqzdUuLO/K9HD\nAnB1tLmuYettPZ04nZRH9LCA6+oLqOHvWftkOqS3D9NG1T4he7rZc/+9HUzWH9rHj482nOLrqxP+\nBnf1aLDZZ/y9HbiQUsDJ+BzmLY3BzckGlUpFaYWe6aO73tAgDUuoVSo6+7kSe3XUVngPb6V5pbOf\nKwun9mZPXAZd/F3p0aE1rV3smiQdd4sO3i4cPpfFsQvZtPV0Yt6knibNrJZQq1S0drEju/L6Rts1\n1H9ibaVm/uQg3lt7kgNnMnGws8LWWsPWQ1cQAtq42HH0fBZHrza5DwryqdeXa6lpI7uQll16S67z\n0D6+bD2UTH5xJcFdPcw2Y06J6MyJSzl8vzuB+FR3jEIwJrSdcl8NDPKpfgfQKK7Zr1yXWq3Cp821\n75ngrh6s+OswpeYq1acSTfw3BGJiYli8eDFCCKKjo5kzZw4ffPABQUFBREREcOrUKRYsWEBRURG2\ntrZ4eHiwceNGfvrpJ1566SW6dOmijI76xz/+QWBg4+9PZGcXExObzpqt51EBfbq4Mzqk7dVO1kTs\nbDTMi+pJz06mM52n55Tywbo4sn7zsrFXawf+Nuvabfo3o6xCT0pWsdkBA5baE5tOdmEFUYM6olar\nrlnbyCuq4IPv40jWljBnwj2E3ePd4LrFZTq2HLxCanYp2rwycgsrCO7qwbxJPS0eMXYjth68wne7\nErC2UvOPOWE3XHC1hJpXSlYJiz4/SngPLx4a1dWk3/N6NEVelJRX8fZXx0m7Onm0u6sds8Z1J7Cd\nGylZJRw+l8UVbTEzxnQzmXj3Tjp+MZtD57N4eGQXk+bwujbtv6y8f9jK2Za354bf0MNlc9Aca15N\nHrxut5ob89iFLDYfuGIyJ1wrZ1v+MqU3bT3NDy6o0Ok5dyWf9JxS0nJKyS+qZOqIzg3OTXg3s6SQ\nqqwycCWzmC7+rtcVhPQGIxq1qkkDF0BqVgmvrz7MxEEdmXAD75DVaAnBC6rz/WYLz6bKi/ziSlb+\ndIa2Xk5MHtIJO5u7f7TbtfKiSm/glU8PkV1QwZSIgAYHUrUEMnjdBer+GIUQJKQVsf1oCuWVemaP\n637DzRbNTUspsEvKq3C0s7qpQNlS8uJWkHlRy5K8iE8tZPfJNP44qusNjUhtLppj8Gq5V4PqId+d\n/V2V4a5S81N3eiJJut1k+XH3apkNuJIkSVKLJoOXJEmS1OzI4CVJkiQ1OzJ4SZIkSc2ODF6SJElS\nsyODlyRJktTsyOAlSZIkNTsyeEmSJEnNjgxekiRJUrMjg5ckSZLU7MjgJUmSJDU7MnhJkiRJzY4M\nXpIkSVKzI4OXJEmS1OzI4CVJkiQ1OzJ4SZIkSc2ODF6SJElSsyODlyRJktTsyOAlSZIkNTtNHrxi\nYmIYO3YsY8aMYeXKlfWWHz16lMmTJ9OjRw+2b99usmzDhg2MGTOGMWPG8MMPPzR1UiVJkqRmwupa\nK6SkpLBu3ToOHTpEZmYmtra2BAYGMmbMGEaPHo2VVcO7MBqNLFq0iDVr1uDp6Ul0dDQjRowgICBA\nWcfX15e33nqLzz77zGTbwsJC/v3vf7NhwwaEEEyePJkRI0bg7Ox8E6crSZIktQSNBq/XXnuNM2fO\nMHbsWP7617/i7u5OZWUlCQkJ7N27l5UrV/K3v/2NPn36mN0+Li6O9u3b4+fnB0BkZCQ7duyoF7wA\nVCqVybZ79+5l4MCBSrAaOHAge/bsYdy4cTd+tpIkSVKL0GjwGjFiBG+88Ua977t168a4ceMoKCgg\nJSWlwe21Wi0+Pj7KZy8vL06dOmVRwsxtq9VqLdpWkiRJatkaDV5Dhw5tdGM3Nzfc3NwaXC6EuLFU\nNbDtb2tn5nh4yGbFGjIvasm8qCXzopbMi+brmn1eAG+99Rbz58/H3t6eGTNmcPbsWf7+978zceLE\nRrfz9vYmPT1d+azVavH09LQoYd7e3hw6dEj5nJmZSVhY2DW3y84utmj/LZ2Hh7PMi6tkXtSSeVFL\n5kWt5hjELRptuH//fpydndm7dy9eXl78/PPP9QZYmBMUFERycjJpaWnodDo2b97MiBEjGly/bm1r\n0KBB7N+/n+LiYgoLC9m/fz+DBg2yJLmSJElSC2dRzavGkSNHGDVqFF5eXhY14Wk0Gl599VVmz56N\nEILo6GgCAgL44IMPCAoKIiIiglOnTrFgwQKKior49ddf+eijj9i4cSOurq7MmzePBx54AJVKxYIF\nC3BxcbnhE5UkSZJaDpWwoGNq1qxZ+Pn5sW/fPn744QccHR2ZNGkSGzduvB1pvC6yGaCabBKpJfOi\nlsyLWjIvarXYZsP33nuPzp07s2zZMlxdXcnMzGTWrFlNnTZJkiRJMsuiZsPWrVszc+ZM5bO/vz/+\n/v5NlSZJkiRJalSjwSssLKzRvq0DBw7c8gRJkiRJ0rU0Gry+//57ANatW0dBQQFTp05FCMH333+P\nl5fXbUmgJEmSJP1Wo8GrZlqnI0eO8OWXXyrfv/LKKzz88MM8/vjjTZs6SZIkSTLDogEbWVlZ5OXl\nKZ/z8vLIzs5uskRJkiRJUmMsGrDxyCOPEBUVxbBhwwDYvXs3TzzxRFOmS5IkSZIaZFHweuihh+jX\nrx9HjhxBCMFDDz1Et27dmjptkiRJkmSWxTNsBAYGEhgY2JRpkSRJkiSLWBS8jh8/zrvvvktKSgoG\ngwEhBCqVSg6VlyRJku4Ii4LXyy+/zLx58+jTpw9qtUVjPCRJkiSpyVgUvOzs7Lj//vubOi2SJEmS\nZBGLqlFDhgxh9+7dTZ0WSZIkSbKIRTWvb775hhUrVuDo6IiNjY3s85IkSZLuKIuCV800UZIkSZJ0\nN7AoePn5+aHX60lKSkKlUtGhQwesrK7r71hKkiRJ0i1jUQQ6deoUf/7zn5UmQ71ez4cffkiPHj2a\nOn2SJEmSVI9FwWvx4sUsWbKE8PBwAA4ePMiiRYtYu3ZtkyZOkiRJksyxaLRheXm5Erig+u98lZeX\nN1miJEmSJKkxFgUve3t7Dh48qHw+fPgw9vb2TZYoSZIkSWqMRc2GL730Ek899RQ2NjYAVFVV8cEH\nH1h0gJiYGJYsWYIQggceeIA5c+aYLNfpdDz//POcOXOGVq1asWzZMnx9fdHr9bzyyiucOXMGo9HI\nxIkT620rSZIk/T5ZFLx69erF9u3bSUpKQghBp06dsLa2vuZ2RqORRYsWsWbNGjw9PYmOjmbEiBEE\nBAQo66xbtw5XV1e2b9/Oli1bePfdd1m2bBnbtm2jqqqKjRs3UlFRwbhx4xg/fjy+vr43fraSJElS\ni2BRs+H+/fupqKiga9eudOvWjfLycoteUI6Li6N9+/b4+flhbW1NZGQkO3bsMFlnx44dTJo0CYAx\nY8YozZMqlYqysjIMBgPl5eXY2Njg5OR0vecnSZIktUAWBa933nnHJHA4OTnxzjvvXHM7rVaLj4+P\n8tnLy4usrCyTdbKysvD29gZAo9Hg7OxMQUEBY8aMwd7enkGDBjF8+HAeffRRXFxcLDopSZIkqWWz\nqNmwZjqoGmq1GoPBYNF217tOzbHi4uLQaDTs27ePgoIC/vjHPxIeHo6/v78lSZYkSZJaMIuCl6Oj\nI7GxsfTu3RuA2NhYHBwcrrmdt7c36enpymetVounp2e9dTIzM/Hy8sJgMFBSUoKrqyubNm1i8ODB\nqNVqWrduTXBwMKdPn75m8PLwcLbklH4XZF7UknlRS+ZFLZkXzZdFwevZZ59l/vz5dO7cGYD4+Hg+\n+uija24XFBREcnIyaWlpeHh4sHnzZpYuXWqyTkREBBs2bKB3795s27aNsLAwAHx8fDh48CATJkyg\nrKyM2NhYZs6cec1jZmcXW3JKLZ6Hh7PMi6tkXtSSeVFL5kWt5hjEVcKStj2gsLCQkydPIoSgb9++\nuLq6WnSAmJgYFi9ejBCC6Oho5syZwwcffEBQUBARERHodDqeffZZzp07h5ubG0uXLsXf35+ysjJe\nfNEYr50AABg9SURBVPFFEhISAHjggQeYNWvWNY8nf4zV5I1ZS+ZFLZkXtWRe1GrRwSspKYmEhARG\njhxJaWkpVVVVuLm5NXX6rpv8MVaTN2YtmRe1ZF7UknlRqzkGL4tGG27YsIEnn3ySf/zjH0B139Vf\n/vKXJk2YJEmSJDXEouD1+eef8/333+PsXB2dO3XqRE5OTpMmTJIkSZIaYlHwsra2xtHR0eQ7jUbT\nJAmSJEmSpGuxKHi5ubkpf4gS4Mcff1ReLJYkSZKk283iiXmfeeYZkpKSGD58OHZ2dnzyySdNnTZJ\nkiRJMsui4NWxY0e+++47Ll++jBCCjh07ymZDSZIk6Y6xqNkwKSkJvV5PQEAAGRkZrFq1isLCwqZO\nmyRJkiSZZVHw+stf/oJarSYlJYXXX3+dlJQUnn/++aZOmyRJkiSZZVHwUqvVWFtbs3v3bqZNm8ai\nRYvIyMho6rRJkiRJklkWBa/Kykq0Wi07d+5U5h60cGIOSZIkSbrlLApejzzyCJGRkTg6OhIUFERK\nSorywrIkSZIk3W4Wz21Yl8FgwGAwYGNj0xRpuilyrrJqct62WjIvasm8qCXzolaLm9vw9OnTZr/X\naDTY2Nig0+mUWd8lSZIk6XZp9D2vFStWUF5ezvjx4+nduzfu7u5UVlaSlJTEnj172L17Ny+88AIB\nAQG3K72SJEmS1Hjw+vDDD4mLi+Obb77h3//+N5mZmdjb29O1a1dGjhzJV199hZOT0+1KqyRJkiQB\nFsyw0atXL3r16nU70iJJkiRJFrFotKEkSZIk3U1k8JIkSZKaHRm8JEmSpGZHBi9JkiSp2bEoeOXm\n5vLXv/6Vhx56CIDz58/zf//3f02aMEmSJElqiEXB65VXXqFfv34UFRUB0KlTJ77++muLDhATE8PY\nsWMZM2YMK1eurLdcp9Px9NNPM3r0aKZOnUp6erqy7Pz58zz44IOMHz+eCRMmoNPpLDqmJEmS1LJZ\nFLy0Wi3Tpk1T/gCljY0NavW1NzUajSxatIhVq1axadMmNm/eXG9GjnXr1uHq6sr27dt55JFHePfd\nd4HqKaiee+453njjDTZt2sQXX3yBtbX19Z6fJEmS1AJZFLysrExfBysqKrJoVvm4uDjat2+Pn58f\n1tbWREZGsmPHDpN1duzYwaRJkwAYM2YMBw8eBGDv3r0EBgbStWtXAFxdXVGpVJYkV5IkSWrhLApe\no0eP5rXXXqO0tJT169cze/ZsHnjggWtup9Vq8fHxUT57eXmRlZVlsk5WVhbe3t5A9ZyJzs7OFBT8\nf3v3HhxVef9x/L1sAlJMgpiQRaS0JraQGqAzKsERIYBZIITsBiIMUsKlpdoBKqFYwck4crXGyUhk\nOhIBKzRMa4HIJRBSgxI6XGy1hZkCRUEn3JJwS5NgypLN8/sjP3YbgrBWNvEkn9df7Nlnz373yzN8\nOGfPPqeKL774AoAZM2aQlpbG6tWrA/1MIiLSxt12hQ2An/70p2zdupXq6mr27NnDT37yE1JTU2/7\nukCOzm4cY4zBZrPh9Xr55JNP2LRpE506dWLq1Kk89NBDvvuJiYhI+xVQeAGMHTuWsWPHfq2dOxyO\nJhdgVFRU0L1792ZjysvLiY6Oxuv1UltbS0REBA6Hg0ceeYSIiAgAnnjiCY4cOXLb8LLi0v7Bol74\nqRd+6oWfemFdAYXXxYsX+f3vf09ZWRn19fW+7StWrLjl6+Lj4ykrK+PMmTNERUVRWFhITk5OkzGJ\niYkUFBTQv39/ioqKfOH0+OOPs3r1aq5evYrdbuevf/0rU6dOvW2tuj9PI92ryE+98FMv/NQLPyuG\neEDh9Ytf/IK4uDgGDRrku+IwEHa7naysLKZPn44xhvHjxxMTE0Nubi7x8fEkJiaSnp7O/PnzSUpK\nomvXrr5wCw8PZ9q0aYwbNw6bzcbQoUMZMmTI//YpRUSkTQnoTspjx45l69atLVHPN6b/STXS/yr9\n1As/9cJPvfCz4pFXQFcb9u/fn3/961/BrkVERCQgAZ02nDhxIpMnT8bhcNCpUyff9o0bNwatMBER\nka8SUHjNnz+fZ555hri4uK/1nZeIiEgwBBRenTp1YsaMGcGuRUREJCABfec1ePBgSktLg12LiIhI\nQAI68nr33XfJy8ujS5cudOzY0bcKxv79+4Ndn4iISDMBhdemTZuCXYeIiEjAAgqvnj17BrsOERGR\ngN0yvObPn092drZvlYsb6VJ5ERFpDbcMr4yMDAB+/etft0gxIiIigbhleG3YsIFly5bx6KOPtlQ9\nIiIit3XLS+WPHj3aUnWIiIgELKDfeYmIiHyb3PK04fHjxxk0aFCz7fqdl4iItKZbhtf3vvc98vLy\nWqoWERGRgNwyvDp27KjfeImIyLfOLb/zCg0Nbak6REREAnbL8Hr33Xdbqg4REZGA6WpDERGxHIWX\niIhYjsJLREQsJ+jhVVpaysiRI3E6nTe97N7j8TB37lySkpKYMGECZ8+ebfL82bNn+fGPf8zbb78d\n7FJFRMQighpeDQ0NLF68mDVr1rB9+3YKCws5ceJEkzEbN24kIiKC4uJiMjIyyM7ObvL8K6+8wpAh\nQ4JZpoiIWExQw+vw4cP07t2bnj17EhoaSnJyMiUlJU3GlJSU4Ha7AXA6nU1W7Xj//ffp1asXsbGx\nwSxTREQsJqjhVVFRQY8ePXyPo6OjqaysbDKmsrISh8MBgN1uJzw8nKqqKurq6li9ejWzZs0KZoki\nImJBAd1J+X9ljPnaY66vm5ibm8vUqVPp3LlzwPsCiIoK+/qFtlHqhZ964ade+KkX1hXU8HI4HE0u\nwKioqKB79+7NxpSXlxMdHY3X66W2tpaIiAgOHz5McXEx2dnZVFdX06FDBzp16sTTTz99y/c8f74m\nKJ/FaqKiwtSL/6de+KkXfuqFnxVDPKjhFR8fT1lZGWfOnCEqKorCwkJycnKajElMTKSgoID+/ftT\nVFREQkICAPn5+b4xK1eupEuXLrcNLhERaR+CGl52u52srCymT5+OMYbx48cTExNDbm4u8fHxJCYm\nkp6ezvz580lKSqJr167Nwk1ERORGNhPol0kWodMAjXRKxE+98FMv/NQLPyueNtQKGyIiYjkKLxER\nsRyFl4iIWI7CS0RELEfhJSIilqPwEhERy1F4iYiI5Si8RETEchReIiJiOQovERGxHIWXiIhYjsJL\nREQsR+ElIiKWo/ASERHLUXiJiIjlKLxERMRyFF4iImI5Ci8REbEchZeIiFiOwktERCxH4SUiIpYT\n9PAqLS1l5MiROJ1O8vLymj3v8XiYO3cuSUlJTJgwgbNnzwKwb98+0tLSGDt2LOPGjePAgQPBLlVE\nRCwiqOHV0NDA4sWLWbNmDdu3b6ewsJATJ040GbNx40YiIiIoLi4mIyOD7OxsALp168aqVavYunUr\nr7zyCs8//3wwSxUREQsJangdPnyY3r1707NnT0JDQ0lOTqakpKTJmJKSEtxuNwBOp5P9+/cD0KdP\nH6KiogB48MEH8Xg8XLt2LZjlioiIRQQ1vCoqKujRo4fvcXR0NJWVlU3GVFZW4nA4ALDb7YSHh1NV\nVdVkTFFREXFxcYSGhgazXBERsYiQYO7cGPO1xxhjsNlsvseffvopOTk5rF27NqD3jIoK+3pFtmHq\nhZ964ade+KkX1hXU8HI4HL4LMKDxSKx79+7NxpSXlxMdHY3X66W2tpaIiAgAysvLmTVrFq+++ir3\n339/QO95/nzNnfsAFhYVFaZe/D/1wk+98FMv/KwY4kE9bRgfH09ZWRlnzpzB4/FQWFjI8OHDm4xJ\nTEykoKAAaDw9mJCQAEB1dTU///nP+dWvfsWAAQOCWaaIiFhMUMPLbreTlZXF9OnTGTNmDMnJycTE\nxJCbm8sHH3wAQHp6OpcvXyYpKYl33nmHefPmAZCfn09ZWRm//e1vcblcuN1uLl26FMxyRUTEImwm\nkC+mLESnARrplIifeuGnXvipF346bSgiItICFF4iImI5Ci8REbEchZeIiFiOwktERCxH4SUiIpaj\n8BIREctReImIiOUovERExHIUXiIiYjkKLxERsRyFl4iIWI7CS0RELEfhJSIilqPwEhERy1F4iYiI\n5Si8RETEchReIiJiOQovERGxHIWXiIhYTtDDq7S0lJEjR+J0OsnLy2v2vMfjYe7cuSQlJTFhwgTO\nnj3re27VqlUkJSUxatQo/vKXvwS7VBERsYighldDQwOLFy9mzZo1bN++ncLCQk6cONFkzMaNG4mI\niKC4uJiMjAyys7MB+Oyzz9i5cyc7duzgrbfe4uWXX8YYE8xyRUTEIoIaXocPH6Z379707NmT0NBQ\nkpOTKSkpaTKmpKQEt9sNgNPp5MCBAwDs3r2b0aNHExISwv3330/v3r05fPhwMMsVERGLCGp4VVRU\n0KNHD9/j6OhoKisrm4yprKzE4XAAYLfbCQsLo6qq6qavraioCGa5IiJiEUENr0BO891sjM1m+8rt\nIiIiIcHcucPhaHIBRkVFBd27d282pry8nOjoaLxeLzU1NUREROBwODh37pxvXHl5ebPX3kxUVNid\n+wAWp174qRd+6oWfemFdQT3yio+Pp6ysjDNnzuDxeCgsLGT48OFNxiQmJlJQUABAUVERCQkJAAwb\nNowdO3bg8Xg4deoUZWVl9OvXL5jlioiIRQT1yMtut5OVlcX06dMxxjB+/HhiYmLIzc0lPj6exMRE\n0tPTmT9/PklJSXTt2pWcnBwAYmNjGTVqFMnJyYSEhPDSSy/ptKGIiABgM7r+XERELEYrbIiIiOUo\nvERExHIUXiIiYjltJrxut4ZiW1ZeXs6UKVMYPXo0KSkprFu3DoB///vfTJ8+HafTyYwZM6ipqWnl\nSltOQ0MDbrebZ555BoDTp0/z1FNP4XQ6yczMpL6+vpUrbBk1NTXMmTPHd/HToUOH2u28+N3vfseY\nMWNISUlh3rx5eDyedjMvFi5cyGOPPUZKSopv263mwZIlS0hKSiI1NZWjR4+2Rsm31SbCK5A1FNsy\nu93OggUL2LFjB3/4wx/Iz8/nxIkT5OXlMWjQIHbt2sXAgQNZtWpVa5faYtatW0dMTIzv8Wuvvca0\nadPYtWsXYWFhbNy4sRWrazlLly5lyJAh7Ny5ky1btvDAAw+0y3lRUVHB+vXr2bx5M9u2bcPr9VJY\nWNhu5kVaWhpr1qxpsu2r5sGePXsoKyujuLiYRYsW8dJLL7VGybfVJsIrkDUU27KoqCj69u0LQJcu\nXYiJiaGioqLJupFut5v333+/NctsMeXl5ezZs4f09HTftgMHDuB0OoHGXvz5z39urfJaTG1tLX/7\n298YN24cACEhIYSFhbXbedHQ0EBdXR319fX85z//oXv37hw8eLBdzIuHH36Y8PDwJttunAfX/80s\nKSnB5XIB0L9/f2pqarhw4ULLFhyANhFegayh2F6cPn2aY8eO0b9/fy5evEhkZCTQGHCXL19u5epa\nxrJly3j++ed9vwu8fPkyERERdOjQON0dDke7mB+nT5/mnnvuYcGCBbjdbrKysqirq2uX8yI6Oppp\n06YxdOhQnnjiCcLCwoiLiyM8PLzdzYvrLl261GQeXLp0CWi63ix8e9eVbRPhpZ+qNbpy5Qpz5sxh\n4cKFdOnSpV3+qPvDDz8kMjKSvn37+uaFMabZHGkPvamvr+fIkSNMmjSJgoICOnfuTF5eXrv47Deq\nrq6mpKSEDz74gL1791JXV0dpaWmzce2xNzeyyrqyQV1ho6UEsoZiW1dfX8+cOXNITU1lxIgRANx7\n771cuHCByMhIzp8/T7du3Vq5yuD75JNP2L17N3v27OHq1atcuXKFZcuWUVNTQ0NDAx06dAh4nUyr\nczgcOBwO4uPjAUhKSuKtt95ql/Ni37599OrVi65duwIwYsQI/v73v1NdXd3u5sV1XzUPoqOjKS8v\n9437tvalTRx5BbKGYlu3cOFCYmNjycjI8G0bNmwYmzdvBqCgoKBd9CQzM5MPP/yQkpIScnJyGDhw\nIK+99hoDBw6kqKgIaD+9iIyMpEePHnz++edA4/d+sbGx7XJe3HfffRw6dIirV69ijOHAgQM8+OCD\n7Wpe3HhE9VXzYPjw4bz33nsA/OMf/yA8PNx3evHbpM0sD1VaWsrSpUt9ayjOnDmztUtqMR9//DGT\nJ0/mBz/4ATabDZvNxty5c+nXrx/PPfcc586d47777mPFihXNvrRtyz766CPWrl3Lm2++yalTp8jM\nzKS6upq+ffuSnZ1NaGhoa5cYdMeOHePFF1+kvr6eXr16sXz5crxeb7ucFytXrqSwsJCQkBDi4uJY\nsmQJ5eXl7WJezJs3j4MHD1JVVUVkZCSzZ89mxIgR/PKXv7zpPFi0aBF79+6lc+fOLF++nB/96Eet\n/AmaazPhJSIi7UebOG0oIiLti8JLREQsR+ElIiKWo/ASERHLUXiJiIjlKLxERMRyFF5iScOGDeOz\nzz5rkfdauXJlk1tlLFiwgPz8/G+83wULFpCSkkJmZuY33tetHDt2jJ07dwb1PURamsJL5DZWrlzJ\ntWvX7ug+L1y4QHFxMdu2bSMnJ+eO7vtGR44c+Z/Dq6Gh4Q5XI3JnKLykTfn888/52c9+Rnp6Oi6X\ny7f8DUCfPn1YtWoV48eP58knn6S4uNj33K5duxg1ahRpaWmsWrWKPn36UFdXx6JFi7DZbEycOBG3\n201tbS0Ax48fJyMjA6fTyQsvvPCV9bz33nukpKSQmprK7NmzuXTpEleuXCEjI4OrV6/idrt55513\nmrxmy5YtzJo1y/fY6/UyePBg3/qdq1ev5qmnniItLY1nn32WixcvAnDt2jV+85vfkJKSgsvlYvbs\n2VRVVfHGG29w4MAB3G43S5cuBRpXpHG73aSmpjJt2jROnToFNK5K4nK5WLJkCRMnTmTv3r3f5K9D\nJHiMiAUlJiaaTz/9tMm2+vp643a7zcmTJ40xxtTW1hqn0+l7/MMf/tDk5+cbY4z5+OOPzeDBg40x\nxly4cME8+uijpqyszBhjzNtvv2369OljvvzyS9/r6urqfO/zwgsvmEmTJhmPx2M8Ho9JTk42+/bt\na1bj8ePHzeOPP24uXLhgjDHm9ddfN88995wxxpjTp0+bhISEm362uro6k5CQYC5fvmyMMWb37t0m\nIyPDGGPMli1bTFZWlm/shg0bzLx584wxxrzxxhtm9uzZpr6+3hhjfK/fvHmzmTNnju81Fy9eNAkJ\nCebEiRPGGGP+9Kc/mfT0dGOMMQcPHjRxcXHm0KFDN61N5NtCR17SZnzxxRecPHmSzMxMXC4XTz/9\nNNeuXWtyV+3Ro0cDMGDAAM6fP4/H4+HQoUM89NBD9OrVC4Dx48c327e5YRW1ESNGEBoaSmhoKHFx\ncZSVlTV7zcGDBxk6dCj33nsvABMnTmTfvn23/Rx33XUXw4cPZ/v27UDjoqnXbyi5e/du9u/fj8vl\nwuVysWHDBs6dOwc03g5mypQp2O12AN8K6jc6dOgQffv25YEHHgBg3LhxHD16lC+//BKA3r17069f\nv9vWKdKa2sQtUUSgMWC6detGQUHBTZ+32Wx06tQJwHcDQq/X2yyYbnx8Mx07dvT92W63N7mg47/3\nc+N9kK6/7+24XC6WL1/OmDFj+Oijj8jOzvbt89lnnyUtLe2m7xeIm9X134+/853vBLQfkdakIy9p\nM77//e9z1113sWXLFt+2kydPcuXKFaD5P+7XHw8YMIB//vOfvu99/vt7MoC7776bmpqar13PoEGD\n2LNnj+87qT/+8Y889thjzd7/Zh5++GFqa2vJycnhySef9IXusGHD2LBhA9XV1QB4PB6OHTsGQGJi\nIuvWrfNdXHL9Dsl3332377u665/36NGjvlulbN68mbi4OIWWWIqOvMSSbDYbU6dOJSQkxHcksW3b\nNt58802WLl3K2rVr8Xq9REZG8vrrr/tec+M+oPGmfC+//DIzZ87knnvuYejQoYSEhNC5c2cApk2b\nxpQpU+jcuTPr168PuMbY2FgyMzOZOnUqHTp0oFevXixatKjZ+38Vl8tFbm4uGzZs8G1LTU2lqqqK\nyZMnY7PZaGhoYNKkSfTp04eZM2eSk5ODy+WiY8eOfPe732XFihUMGjSINWvW4HK5eOSRR3jxxRd5\n9dVXmTdvHl6vl27duvmO7ESsQrdEEQGuXLlCly5dgMYjkU2bNt2R33KJSHDoyEsEWL9+PUVFRXi9\nXrp27crixYtbuyQRuQUdeYmIiOXogg0REbEchZeIiFiOwktERCxH4SUiIpaj8BIREctReImIiOX8\nH4gzFtcS9o9MAAAAAElFTkSuQmCC\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7f47b20dd690\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(graph_means)\n", + "plt.ylabel('Time (seconds)')\n", + "plt.xlabel('Length of vector')\n", + "_ = plt.title('Time to sum the elements of 1000 vectors (vectorized TF operation)')\n", + "_ = plt.ylim(ymin=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4KZg2WXjbhg5" + }, + "source": [ + "## AutoGraph" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "UQJBQWbCbinm" + }, + "outputs": [], + "source": [ + "# Sum written using for loop and converted with AutoGraph\n", + "def sum_all(elements):\n", + " sum_ = 0.0\n", + " length = len(elements)\n", + " for i in tf.range(length): \n", + " sum_ += elements[i][0]\n", + " return sum_\n", + "\n", + "def run_trial(num):\n", + " elements = get_elements(num)\n", + " return sum_all(elements)\n", + " \n", + "ag_means = []\n", + "ag_run_trial = ag.to_graph(run_trial)\n", + "\n", + "for num in range(max_elements):\n", + " with tf.Graph().as_default():\n", + " durations = []\n", + " foo = ag_run_trial(num)\n", + " with tf.Session() as sess:\n", + " for _ in range(burn_ins):\n", + " for _ in range(batches):\n", + " sess.run(foo)\n", + " \n", + " for _ in range(trials):\n", + " start = time.time()\n", + " for _ in range(batches):\n", + " sess.run(foo)\n", + " \n", + " duration = time.time() - start\n", + " durations.append(duration)\n", + " ag_means.append(np.mean(durations))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 301 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 310, + "status": "ok", + "timestamp": 1532448438694, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "DLDOmrRW99v5", + "outputId": "ae0e0573-39db-4004-a064-efc618dbf867" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEcCAYAAADUX4MJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3XdYVFf++PH3DE1AinTEjgULioggNiTYjcZCoiZqjEmM\n6RuzcVc32exPE7O72dTNmmhi1u+abLJqLFFsib1E7GLDjkgbQUSKyDAz5/eHySARcYwMQ/m8nsfn\nce4999zPPXOHzz23nKtRSimEEEKIu9DaOgAhhBC1gyQMIYQQFpGEIYQQwiKSMIQQQlhEEoYQQgiL\nSMIQQghhkTqZMObPn88bb7xh6zBqlQceeICffvrJ6uuZOXMmH330kdXXU1NcuHCBUaNG0a1bN776\n6itbh1OnnT17ljFjxtg6jGpV2e9Jr9czZMgQcnNzq2x9tTJhdO3alfDwcMLDw2nfvj1dunQxT1uz\nZg3PPPMMc+bMsXoc6enphISEYDKZrL6uqlTf/mjfjTW/xy+++IKoqCgOHDjAhAkTbpu/bt06xo0b\nR1hYGJMmTbpt/smTJxk9ejRhYWGMGTOG5OTkcvPfffddoqKi6NGjB+++++49LWttEydOZNmyZdW2\nvo8//pinnnqqwjgiIyMpLS29p/pCQkK4dOnSPS3z1VdfMWLECMLCwujduzeTJk1i7dq191RHVXF0\ndCQ+Pp7PP/+8yuqslQnj0KFDHDx4kIMHD9K4cWPmz59vnvbggw9WWxxKKTQaDfLsY+1mze8xIyOD\n1q1b33G+p6cnkydPZurUqbfNKy0t5fnnn2fkyJHs27ePkSNH8txzz2EwGAD49ttv2bx5M6tXr+b7\n779n69at/O9//7No2drgXr6P7OxsEhMTiYuLKzc9PT2dAwcOoNFo2Lx58z2tX6PR3FP5OXPmsHjx\nYmbOnMnevXvZsWMHv/vd79ixY8cdl7H2344HH3yQFStW3HOyvJNamTBupZS6rdE/+eQTXnvtNaDs\n6HH58uX069ePqKgovv32W44ePcqIESOIjIy8rTeybNkyhg4dSlRUFE899RQZGRkVrnvixIkARERE\nEB4ezpEjR1BKMW/ePB544AF69erFH//4RwoLCytc/urVq0ybNo3u3bsTFRVV7gj010c3t/YK9u7d\nS0xMDF988QU9e/akT58+/Pjjj2zbto1BgwYRFRXF/PnzK1znkiVLWL16NV988QXh4eE8++yz5nkn\nT55kxIgRdO/enenTp6PX683ztmzZwsiRI+nevTvjx4/n1KlTFdYPcO7cOaZMmUJUVBRDhgxh3bp1\ndyxbWb0PPPAACxcuZMSIEXTt2pXXX3+dK1eu8PTTTxMeHs6UKVMoKCgwlz98+DDjxo2je/fujBw5\nkr1795rnTZw4kY8++ojx48cTHh7Ok08+SV5ennkelP8eU1NTmThxIhEREURHRzN9+vQ7bsOmTZt4\n8MEHiYyMZNKkSZw/fx6Axx9/nMTERGbPnk14eDgXL168bdno6GgGDx6Mr6/vbfP27t2L0Whk0qRJ\nODg4MHHiRJRS7NmzB4CVK1cyZcoU/Pz88PPz44knnmDFihUAJCYmVrrsrdauXXvbqZxFixbx3HPP\nATdPbfztb38jNjaW3r1785e//KXcvvHjjz8ycuRIunXrxsCBA9m5cycffPABBw4cYM6cOYSHh/PW\nW28BcPDgQeLj4+nevTsPP/wwhw4dKvcdffDBB4wfP56wsDDS0tJYvnw5/fv3Jzw8nP79+7NmzZoK\nv4Ndu3bRsWNHHB0dy01fuXIlYWFhjB492tw2t67v1h7QihUrePTRRwGYMGECSilGjBhBeHi4eR9e\nsmQJAwcOJCoqiueee47Lly8DN089fvPNN3zwwQdER0fj6OiIRqMhPDycd955567bOHToUMLDwxkw\nYIA56f+yD8TExDB//nx69OhBXFwcq1evLrcd165d45lnniE8PJyxY8eW+7vh7++Ph4cHR44cqbDd\n7pmq5WJjY9Xu3bvLTfvnP/+pXnvtNaWUUmlpaapdu3bqzTffVCUlJWrXrl0qNDRUPf/88yo3N1dl\nZWWp6OhotW/fPqWUUj/88IMaOHCgOn/+vDIajerTTz9VY8eOrXDdaWlpKiQkRJlMJvO0pUuXqoED\nB6q0tDR1/fp19cILL5hj+bX33ntPvfnmm8poNCqDwaD2799vnhcSEqJSU1PNn//4xz+qDz/8UCml\nVGJiourQoYOaN2+eMhgMasmSJapHjx7q1VdfVdevX1dnzpxRoaGh6tKlSxWu99a6bm3Hhx9+WGVn\nZ6tr166pIUOGqG+//VYppdSxY8dUdHS0SkpKUiaTSa1YsULFxsYqvV5/W93Xr19XMTExasWKFcpk\nMqkTJ06oqKgodfbs2dvWfbd6Y2Nj1dixY9WVK1eUTqdT0dHRatSoUerkyZNKr9erSZMmqU8++UQp\npVRWVpaKjIxU27dvV0optXv3bhUZGalyc3OVUkpNmDBBDRgwQF28eFGVlJSoCRMmqPfee++O3+P0\n6dPVZ599ppRSqqSkRB04cKDCtjx//rwKCwtTu3fvVgaDQX3++edqwIABqrS01LzepUuXVrjsrZYs\nWaImTpxYbtq///1v9fTTT5eb9swzz6h///vfSimlunXrpo4cOWKed/ToURUeHm7RsrcqLi5W4eHh\n6uLFi+ZpY8aMUWvXrlVKKfXWW2+pZ599VuXn56uioiI1bdo09f777yullDpy5Ijq1q2b+Teo0+nU\n+fPnK9z2vLw81b17d/X9998ro9Go1qxZo7p3767y8vLM5WNjY9XZs2eV0WhUBQUFKjw8XKWkpCil\nlMrOzjbvR7/2t7/9Tc2ePfu26QMGDFDffPONOnbsmOrYsaO6cuWKed6v41u+fLl69NFHzZ/btWtX\n7je4e/duFRUVZd7/5syZox577DGllFLffPONeuCBByqM7Va/3sbS0lK1detW82913759qkuXLurE\niRNKqbLf+l//+lel1+vV3r17VVhYmLpw4YJS6ubvKTIyUh09elQZjUb16quvqunTp5db57Rp09Ti\nxYvvGpslan0PwxIajYbnn38eR0dHevbsibOzM8OGDaNRo0b4+/sTERHBiRMnAPjf//7H1KlTadmy\nJVqtlqlTp5KcnExmZuYd61e39HDWrFnD5MmTCQoKwtnZmenTp7N27doKz4/b29uTnZ1NWloadnZ2\ndOvWrcI6K+Lg4MC0adOws7Nj6NChXL16lccffxxnZ2dat25N69atK+0FVGTSpEn4+Pjg7u5ObGws\nJ0+eBGDp0qWMGzeO0NBQNBoNI0eOxNHRscKjli1bttCkSRNGjhyJRqOhffv2DBw4kPXr199W1pJ6\nJ0yYgJeXF35+fkRERNClSxdCQkJwcHBgwIAB5hi///57+vXrR58+fYCbR+6dOnVi27Zt5rpGjx5N\ns2bNcHR0ZMiQIeZlf3Frm9vb25Oeno5Op8PR0ZHw8PAK22zdunX069eP6Oho7OzsePLJJ7lx40a5\nI+ff6vr167i5uZWb1rBhQ3OP9dfz3dzcuH79ukXL3qpBgwbExcWZj95TUlK4cOGC+fTOsmXLmDlz\nJm5ubri4uDB16lRz2WXLlhEfH090dDQAfn5+tGzZssLt2bp1Ky1atGD48OFotVqGDRtGq1at2LJl\ni7nMqFGjCA4ORqvVYmdnh52dHadPn6akpAQfHx+Cg4MrrLugoABXV9dy0/bv309GRgZDhgyhY8eO\nNGvW7Laj83uxZs0a4uPjzfvf9OnTOXz4MBkZGVy9evW2XmJMTAzdu3enc+fO5f5+3LqN9vb2xMTE\n0KRJE+BmL7dXr17s37/fXF6j0fC73/0OBwcHunfvTkxMTLle+8CBA+nUqRNarZbhw4fftl+7urqS\nn5//m7f7VvZVUkst4O3tbf5/gwYN8PHxMX92cnIy/9AyMjJ4++23+dvf/gaUnd/W6XQEBgbedT2X\nL1+mcePG5s9BQUEYDAZycnLw8/MrV/app57in//8J1OmTEGj0fDwww9XeC67Ip6enuZzrA0aNKhw\nG3/ZJkvduryzszPZ2dnAzTZZtWqV+S4fpRQGg8HcHb9VRkYGhw8fJjIy0lzWaDQycuTICsverd5b\nY3Jycrrt863f27p168x/fH6p65c/ZEC579zZ2bnS9pkxYwYffvgh8fHx5usMFd2B8+vvW6PREBgY\niE6nu2PdlnJxcbntD3xhYSENGzascH5hYSEuLi4WLftrw4YN4+9//zvPPfcca9asoX///jg6OpKb\nm0txcXG5bTeZTObkmpWVRUxMjEXb8+u2AmjcuHG5tgoICDD/39nZmQ8++ICFCxcya9YsunXrxowZ\nM2jVqtVtdbu7u1NUVFRu2qpVq+jduzceHh7mbVy5ciWPP/64RfFWFH/Hjh3Nn11cXPD09ESn0+Hp\n6Xnb72Hbtm0YjUY6depU7mDk1m38pdy8efNISUnBZDJx48YN2rVrV27bnJyczJ8bN25cbl1326+L\niopwd3f/Tdv8a/UmYVgqICCAZ5991qKL5xVdFPPz8yt3zSM9PR17e/tyX+ovXFxc+MMf/sAf/vAH\nzp07x8SJE+ncuTM9evTA2dmZ4uJic9ns7OzbdrTqEhAQwLRp03jmmWfuWjYwMJCoqCgWLlxYpfVa\nst6RI0cye/bse162ou/R29vbfG3rwIEDPPHEE0RGRtK0adNy5fz8/Dhz5ky5aZmZmVXyXbVp04ZF\nixaVm3b69GnzNZfWrVuTnJxMaGgocPMaVJs2bSpdtqI7tQB69+7NzJkzSU5OJiEhgVmzZgHQqFEj\nnJ2dWbNmzW0HPHDzO7zTnUS/blc/Pz82btxYblpGRgZ9+/a94zK9evWiV69e6PV6PvjgA9544w2+\n/vrr29bVrl07Vq1aZf5cUlLCunXrMJlM9O7dG7h5I0B+fj6nTp2iXbt2uLi4cOPGDfMyvxwg3cmv\nf9vXr18nLy8Pf39/PD09eeuttzh+/Hi5pAK3ny24dRv1ej0vv/wy7777LnFxcWi1Wp5//vlyy+Tn\n53Pjxg3zgWFmZiZt27atNNZbnT9/nieffNLi8pWpF6ek7nZ651bjx49n/vz5nD17FrjZ1a3odAqA\nl5cXWq2W1NRU87Rhw4axaNEi0tLSKCoq4oMPPmDYsGFotbc39datW83Luri4mLvgcPOi95o1azCZ\nTGzfvp19+/ZZvA134+Pjc0+3Cz7yyCN8++23JCUlATd/KNu2bavwCL1fv35cuHCBVatWYTAYKC0t\n5ejRo+YLwb+13rsZMWIEmzdvZufOnZhMJkpKSti7d69FR/oVfY/r1683L+vu7o5Wq63wOxwyZAhb\nt25lz549GAwGFi5ciJOTE2FhYRbFbTKZ0Ov1GAyGcv8HiIyMRKvVsnjxYvR6vbknFhUVBcDIkSNZ\ntGgROp0OnU7HokWLGD16dKXL9ujRo8I47OzsGDRoEH//+9/Jz8+nV69eAOae79y5c8338+t0Onbu\n3AlAfHw8y5cvZ8+ePSil0Ol05u/61/tZTEwMFy9eJCEhAaPRyNq1azl//jyxsbEVxnTlyhU2b95M\ncXEx9vb25t9IRXr16sXx48fNF+N/+OEH7OzsWLduHatWrWLVqlWsXbuWbt26sXLlSuDmb2zjxo3c\nuHGDixcv8t1335Wr89fxP/jggyxfvpzk5GT0ej3vv/8+Xbp0oXHjxrRs2ZKxY8cyffp0du/eTUlJ\nCSaTiYMHD1Z6t1VpaSmlpaU0atQIrVbLtm3b2LVrV7kySik+/vhjSktL2b9/P1u3bmXIkCF3rPNW\nOp2Oa9eu0aVLF4vK302tTxiW3Pr26zKVfe7fvz9PP/00r7zyChEREYwYMeKOt8U1aNCAadOmMX78\neCIjI0lKSiI+Pp6HHnqICRMmMGDAAJydnXn99dcrXD4lJYXJkyfTtWtXxo8fz2OPPUb37t0B+NOf\n/sTmzZvp3r07CQkJ9O/f/7628Vbx8fGcPXuWyMhIXnjhhbuW79SpE3PmzGH27NlERkYyaNCg2+44\n+YWrqytffvkla9eupU+fPvTp04f33nuv3F01ltZ7L9sUEBDAvHnzmD9/PtHR0cTGxvLll1+aDxYq\nW7ai7/Ho0aM8/PDDhIeH8/zzz/OnP/2JoKCg25Zt2bIl7777LnPmzCE6OpqtW7fy2WefYW9vf9f1\nws3TJp07d2b27NkcOHCALl26mB86dXBwYN68eaxYsYLIyEiWL1/OvHnzzHWPGzeO2NhYRowYwYgR\nI4iNjeWRRx6xaNmKDBs2jJ9++okhQ4aUS46///3vad68OY888ggRERFMmTKFlJQUADp37szcuXOZ\nO3cu3bp1Y9KkSebz9ZMmTWL9+vVERUXx9ttv4+npyWeffcbChQvp0aMHCxcuZP78+eZTRr9uK5PJ\nxL///W/69u1Ljx492LdvH2+++WaFsXt7e9OjRw82bdoE3Lw7asyYMfj7++Pt7W3+99hjj7F69WpM\nJhOTJ0/GwcGBXr16MXPmTIYPH16uzhdffJEZM2YQGRnJ+vXriY6O5uWXX+bFF1+kT58+pKWl8f77\n75vL//nPf2bixIm88847REVFERMTw8cff8yHH35oPhX36210dXXlT3/6Ey+//DKRkZGsXbv2tluD\nfX198fDwoE+fPsyYMYPZs2fTokWLO36Pt1q9ejWjRo3CwcHBovJ3o1H3cvh9j2bNmsXWrVvx9va+\n48WmxMRE3nnnHQwGA40aNWLx4sXWCkcIUYedO3eOP/7xjyxdutTWoVSZvXv3MmPGDLZu3XrPy+r1\nekaOHMlXX32Fl5dXlcRj1YSxf/9+XF1dmTFjRoUJo6CggHHjxvHll1/i7+9Pbm5ulW2YEELUdveT\nMKzBqqekIiIiKr06v3r1agYOHIi/vz+AJAshhKjBbHoNIyUlhWvXrjFx4kTGjBljvhglhBDi5s0L\nNaV3ATa+rdZoNHLixAn+7//+j+vXrzNu3Di6du1K8+bNbRmWEEKICtg0Yfj7+9OoUSOcnJxwcnIi\nIiKC5OTkuyaMXx6mE0IIUX2snjAqu6YeFxfHW2+9hdFoRK/Xk5SUxBNPPHHXOjUaDdnZBXctVx/4\n+rpJW/xM2qKMtEUZaYsyvr5udy9UCasmjFdffZXExETy8vLo168fL774IqWlpWg0GsaOHUtwcDC9\ne/dmxIgRaLVaHnnkkUqHghZCCGE7Vr2t1prkiOEmOXoqI21RRtqijLRFmfvtYdT6J72FEEJUD0kY\nQgghLCIJQwghhEUkYQghhLCIJAwhhBAWkYQhhBB1VFXfBCsJQwgh6qD07EJe+HAHe05kVVmdkjCE\nEKIOWrnjAsUlBlwbVM3Lk0AShhBC1GqLN5zi201nyp1+StUVcOB0Nq0au9OpZdW9NsKmgw8KIYT4\n7c6mX2PLoXQA/Bs5ExveBIBVOy8AMLJ3yyodqFUShhBC1FIb9qYC4Oig5ZtNZ2gR6I5Wo+HQmRyC\ng9zpWIW9C5BTUkIIUStdvnqdg6eyaR7gxgujQjEaFZ+uPMaSLWcBGNmnVZW/BkJ6GEIIUcMopZi3\n8hhXC0ro0MKLTi29aNXYHXu7smP8jfsuoYDBkc3o1MqbB3u2YPXuFHKu3aBNEw86NG9U5XFJwhBC\niBrmXEY+B05lA3A+I581u1Nwc3HgyWHt6RzsQ2FxKTuTMvF2dyIixBeAh3q35Gz6NU5evGqV3gVI\nwhBCiBpn19FMAJ4b2Qk7Ow3HLuSy40gmHy5NYlh0c+y0GvQGEwO6N8NOe7PXodVqeGlMZzKuFNEy\n0N0qcUnCEEIIG7mYVcD5jGv06xpk7hHoS43sPamjkZsT4W190Wo1dG3jS9/Ojfl05TESfroIgLOT\nPX06B5arz8nRzmrJAuSitxBC2IS+1Mi/Vhxl8cbTJJ7UmacfOpNDcYmRnp0C0GrLTis1D3Djz5O7\nE9725imouG5BODtV7zG/9DCEEMIGNuxNJefaDQCWbD5Ll2AfnJ3szaejenYKuG0Zlwb2PD+qE5cu\nF9LEt2G1xgvSwxBCiGqXm3+DhD0XcXdxYGD3puQV6lm9K4WrBSUcT8klOMidQG/XCpfVaDQ083cr\n1/uoLlZNGLNmzaJnz54MHz680nJJSUl06NCBjRs3WjMcIYSoEZZtO4e+1MSYmGBG922Fj0cDfth/\nieXbzqEU9AoNvHslNmDVhDF69GgWLlxYaRmTycR7771Hnz59rBmKEELUCGfTrrHnuI7mAW706hyI\no4Md4/u3wWhS7DqWhb2dlsgQP1uHWSGrJoyIiAjc3Su/Yr948WIGDRqEl1fVPsIuhBA1jUkpvtl0\nGoBH+7dB+/OdUWGtfegc7A1AeFsfXKpwhNmqZNNrGDqdjh9//JHx48fbMgwhhKgWe45ncSGzgMj2\nfrRp4mmertFomDCgLV2CvRkW3cJ2Ad6FTe+Smjt3Lq+99pr5/uOqfjuUEELUFCWlRr7bdh4Hey0P\n92t923wfT2defriLDSKznE0TxrFjx3jllVdQSnH16lW2b9+Ovb09cXFxd13W19etGiKsHaQtykhb\nlJG2KFMT2uLbH05xtaCEh+PaENLa19bh/CZWTxiV9Ro2bdpk/v/MmTOJjY21KFkAZGcX3HdsdYGv\nr5u0xc+kLcpIW5SpCW1xtaCEpZtO4+7iQL/OgTaL534Tp1UTxquvvkpiYiJ5eXn069ePF198kdLS\nUjQaDWPHjrXmqoUQosZYseM8+lIT4+PaVPvT2VXJqpG/9957Fpd95513rBiJEEJUv5JSI9sOpbMr\nKZMmvq706dzY1iHdl9qb6oQQwoYMRhMnUq4S0swTRwe7cvOKSwxsOZTOhr2pFFwvxdFBy4SB7Wzy\ndHZVkoQhhBC/wfJt51m/N5UALxeeHNae4CAPlFLsP5XNf388zbVCPc5OdjzYszkDIpri5uJo65Dv\nmyQMIYS4R5lXivhh/yWcnezR5V5n7lcH6N+tKZm5RRw7n4u9nZbhPVswKLJpjX0I77eQhCGEEPfo\n201nMZoUU4aG4ObiyJcJJ/lh/yUAOrb0YsLAtvg3crFxlFVPEoYQQtyDI2dzOHr+Cu2bNyK8rS8a\njYb/NyWSH/ZfIsDLhW7tfK3yetSaQBKGEEJYyGA08e2mM2g1Gsb3b2NODE6OdjzYs4Vtg6sG8j4M\nIYSwgMFoYvm28+iuFhPbNcgmLzCyNelhCCFEJZRSHDqTw9Kt59DlXsfD1ZGH+rS0dVg2IQlDCCF+\n5WpBCeczrnE+I58TF69yMasArUZDv65BPNS7JQ2d686dT/dCEoYQQvys1GDi281n2HIw3TxNA3Rt\n40N8v+A7vja1vpCEIYSol86k5eHsaE9jX1e0Gg05ecXMW3mMlKwCGvu4Et3Rn1aB7rQIdK/V4z9V\nJWkFIUS9s+NIBv9elwxAQ2cH2jTx4PSlPIpuGOjVKYAJg9rh9KvhPoQkDCFEPXMq9Sr/2XAK1wb2\ndGntw6nUqxw6k4O9nZbJQ0Lo0zmwzj5Hcb8kYQgh6g3d1et8svwoAM+PCiWkeSOUUuRcu4GTgx3u\nrrV/vCdrkoQhhKgXim6U8tHSJIpuGJg8JISQ5o2Am+/T9vV0tnF0tYM8uCeEqPOMJhOfrTxGVu51\nBkc2o2+X2v1eCluRhCGEqPOWbjnH8ZSrdA72Jr5fsK3DqbUkYQgh6rQf915k475LBHq78MyIjrX+\nJUa2JAlDCFGrXS0oYfPBNIpulN4270xaHv9aloRrA3teiu8sz1PcJ6u23qxZs9i6dSve3t6sXr36\ntvmrV6/m888/R6PR4OLiwl/+8hfatWtnzZCEEHVIid7I+0sOk55dxIrt5xkW3YK4bkHorhazds9F\nEk/o0Gg0TBvZqU6+n6K6aZRSylqV79+/H1dXV2bMmFFhwjh8+DDBwcG4ubmxfft2PvnkE5YsWWJR\n3dnZBVUdbq3k6+smbfEzaYsy9aEtlFIsWH2CxBM62jdvxMWsAq6XGGjo7EBh8c3eRhNfV6aM6EQL\n3/o9pMcvfH3d7mt5q/YwIiIiSE9Pv+P8sLCwcv/X6XTWDEcIUYf8eCCNxBM6goPceeWRLpSUGkn4\n6SJbDqUTHOTOsOgWdAn2xs/Pvc4nz+pSY07oLV26lL59+9o6DCFELXD6Uh5LNp/F3cWB50aGYm+n\nxd5OyyOxrXkktrWtw6uzakTC2LNnD8uXL+e///2vxcvcb9eqLpG2KCNtUaautsW5tDzmrTyGAv44\nOZK2rXzuukxdbYvqZvOEkZyczJ///Ge++OILPDw8LF5Oupg31Ydz1ZaStihTV9vibPo1PlhyhBsl\nBh4fEkKAu9Ndt7OutsVvUaOvYcDNC1N3kpGRwUsvvcTf//53mjVrZu1QhBC1WPLFq3y0LIlSg4mn\nhncgumOArUOqd6yaMF599VUSExPJy8ujX79+vPjii5SWlqLRaBg7dizz5s3j2rVr/L//9/9QSmFv\nb8+yZcusGZIQopZJyy7kx/1p7D6WiVLw7MhOdGvna+uw6iWr3lZrTdLFvEm622WkLcrU9rYwmkwk\nnb3CjwfSOHnxKgA+Hg2YNLgdnVp631Ndtb0tqlKNPyUlhBCWulpQwq6jmWw9nE5ufgkA7Zs3on+3\nJnRp7SPDetiYJAwhhM2YTIp9yZc5kZLL6Ut56K4WA+DkYEds1yBiuwbRxK+hjaMUv5CEIYSwCX2p\nkc9Xn+DA6WwAGjjaEdrKm87B3vTsFCDjPtVA8o0IIapdwXU9//zuKGfTrxHSzJNHHmhNMz83OeVU\nw0nCEEJUq8t5xXyw5Ai63Ov06ODPE0Pb42AvA2fXBpIwhBDVJj2niH98e4hrhXqG9mjO6JhWaDXS\nq6gtJGEIIapFqq6Af3x7mMLiUsY90JqBkfKwbm0jCUMIYTUmk6KguJSLWQUs+P44xSUGHh/cjpiw\nIFuHJn4DSRhCiCp3IiWX/2w4RU7eDUw/Pxus1WhkSI9aThKGEKJKHTiVzfzvjwHQqrE7Hg0d8XR1\nIrydL+2bN7JxdOJ+SMIQQlRqx5EMDp3JoW9YYzoHe1d6kXrHkQwWrU/G0d6OF8eE0qGFVzVGKqxN\nEoYQ4o5+EQoNAAAgAElEQVRKDSaWbj1HYXEph8/mEOjtQv9uTWjo4oi+1Ii+1Ej+9VKuFtzgyrUb\nHE+5imsDe155JIxWjd1tHb6oYpIwhBB3dOhMNoXFpUS298PeTkviCR2LN56+Y3n/Rs68MDqUIF8Z\nzqMukoQhhLijHUmZAIzo1ZLGPq6M7tuKg6ez0Wg0ONprcXDQ4ubiiJebE54NnWQ4jzpOvl0hRIVy\n8oo5cSGX1kEeNPZxBcDLvQH9I5raODJhK/I8vhCiQjuPZqKAPl0CbR2KqCEkYQghbmMyKXYezaSB\nox3dQ/xsHY6oISRhCCFuczwll9z8EiLb+9PAUc5ci5skYQghbrP9SAYAfbs0tnEkoiaxasKYNWsW\nPXv2ZPjw4Xcs89ZbbzFw4EAeeughTp48ac1whBB3kXmliM9WHePgqWyCfF1pGXh/74AWdYtV+5qj\nR49m4sSJzJgxo8L527ZtIzU1lY0bN3LkyBHefPNNlixZYs2QhKi3SvRGPlt1DEcHO8Ja+xAa7I1r\nA3uuXLtBSlYBh87ksOdEFkpBc383Jg8JQSNDj4tbWDVhREREkJ6efsf5mzZtYuTIkQB06dKFgoIC\ncnJy8PHxsWZYQtRLq3encOTcFQD2JV9Gq9Hg7GRH0Q2DuUwTX1ce6t2K8LY+kizEbWx6Nevy5csE\nBJSNXOnv749Op5OEIUQVy7xSxIa9qXi7N+D50Z04fiGXw2dzKLheSocWXrQIcKNloDttm3nKC43E\nHdk0Yaifhz2+laVHNb6+cm71F9IWZaQtyvzSFkopPvouCaNJMW1MZ7p3CqR7aP16H4XsF1XDpgnD\n39+frKws8+esrCz8/Cy75zs7u8BaYdUqvr5u0hY/k7Yoc2tb7D2p48iZHDoHe9PKz7XetZHsF2Xu\nN3Fa/bbainoRv4iLi2PlypUAHD58GHd3dzkdJUQVKiwu5dtNZ7C30/Jo/zZyXULcF6v2MF599VUS\nExPJy8ujX79+vPjii5SWlqLRaBg7diwxMTFs27aNAQMG4OzszDvvvGPNcISos0r0Rq6XGCg1mjAY\nTBy9mMeWfakcu3AFg1ExolcL/Bq52DpMUctpVGVdgBpMupg3SXe7TH1rC6UUpy/lseVQOgdOZWM0\n3f5TbuLrSlQHfwZFNsPern4+p1vf9ovK3O8pKXnmX4ha6PSlPBZvOEV6ThEAgd4uNPN3w95Og4Od\nlqaBHoQ0cSfQ29XGkYq6RBKGELXMoTPZfLryOCaTIrK9H7Fdg2jb1LPc9Qk5qhbWIAlDiFpk19FM\n/r02GXt7DS+N6UynVt62DknUI5IwhKihDEYTmw+kkX3tBiaT4nqJgcQTOlwb2PO7h7sQHORh6xBF\nPSMJQ4ga6rtt59iw91K5aY3cnHjlkS40kXdmCxuQhCFEDXToTDYb9l7C38uFaSM6Ym+vxU6rwdvd\nCQd7O1uHJ+qpuyaMS5cusWzZMhITE8nKysLJyYmQkBAGDRrEwIEDsbeXnCNEVcrJK2bhmpM42Gt5\nbmQnmvpJb0LUDJX+tf/zn//M8ePHGTx4ML///e/x8fGhpKSEc+fOsXPnThYsWMBf/vIXwsLCqite\nIeo0g9HEp6uOc73EwOQhIZIsRI1SacKIi4tj9uzZt01v164dQ4cOJS8vj0uXLlWwpBDiXhmMJhYm\nnORCZj7RHQPo0znQ1iEJUU6lCSMmJqbShT09PfH09KzSgISoj0oNRj5deZzDZ3MIDnJn4qC2Mu6T\nqHEsGivgr3/9KwUFBRgMBh599FHCwsJYtWqVtWMTol64oTfw4dIkDp/NoUOLRvx+bFcaOMq1QVHz\nWJQwdu/ejZubGzt37sTf358NGzbw5ZdfWjs2Ieq8VF0Bf/3qICcvXqVrGx9eju+Mk6PcBSVqpns6\njNm3bx8DBgzA399fustC3IdSg5Hvd6Wwbk8qJqWICWvMhIFtsdPWzwECRe1gUcLw9vbm9ddfZ9eu\nXUydOhWDwYDRaLR2bELUOfpSI3tO6Fi35yK6q8V4uzfg8cHtZIgPUStYlDDee+89vv/+e+Lj4/Hw\n8CAtLY0nnnjC2rEJUWcUFpeycV8qWw9lUFhcilajoX9EE0b3bSXXK0StYdGe6uXlxeTJk82fmzRp\nQpMmTawVkxA12rXCErRaDW4ujhaV1129zgdLjnD5ajGuDewZFt2c2K5BeLk3sHKkQlStShPGc889\nx7Rp0+jcufNt8woLC/nuu+9o0KABY8eOtVqAQtQkJXojbyzcyw29gZ6dAhkU2bTSd06cS7/GR8uS\nKCwuZUiPZozo1RInB7moLWqnShPGSy+9xHvvvUdKSgqdO3fG29ubkpISzp8/T3p6OuPGjWP8+PHV\nFasQNrc3WUdhcSmODlq2H8lgx5EMQoO96djSi3ZNPWni2xC9wYgut5jzGdf43+azlBpNTBrcjn5h\nQbYOX4j7UmnCCAkJ4fPPPyczM5O9e/ei0+lwcnJi8ODBdOvWDUdHy7rkQtQVO45kogFmPxlFalYB\n6xIvknTuCknnrgDgaK9FbzCZyzs6aHlpTGe6tPaxUcRCVB2LrmEEBgby0EMP/aYVbN++nblz56KU\nYsyYMUydOrXc/MzMTP7whz9QUFCAyWRi+vTpd33CXAhbSM8p4mz6NTq29MLP0xk/T2e6tfMl+9oN\nzlzK49SlPFIyC3B3dcDfy4WARi6EBnsT4OVi69CFqBIWJYwrV67wzjvvkJmZyddff01ycjKHDh26\n6+kok8nEnDlzWLRoEX5+fsTHxxMXF0dwcLC5zKeffsrQoUMZN24c586d4+mnn2bz5s33t1VCWMGO\nIxkA9O3S2DxNo9GYk0evUBn7SdRtFj0l9Prrr9OtWzfy8/MBaNWqFf/973/vulxSUhLNmzcnKCgI\nBwcHhg0bxqZNm8qV0Wg0FBYWApCfn4+/v/+9boMQVldqMLH7WBYNnR3o2kZOL4n6yaKEodPpGD9+\nPHZ2N+/ucHR0RGvBE6k6nY7AwLKjLn9/fy5fvlyuzAsvvMCqVauIiYlh2rRpvPHGG/cSvxDV4vDZ\nHAqLS+kVGoC9nTyNLeoni05J/folSfn5+Sil7rqcJWUSEhIYM2YMkydP5vDhw7z22mskJCTcdTlf\nX7e7lqkvpC3KWKst9iw/CsBD/drUmvauLXFWB2mLqmFRwhg4cCB//vOfKSoqYvny5fz3v/9lzJgx\nd10uICCAjIwM82edToefn1+5MsuWLWPhwoUAhIWFUVJSQm5uLl5eXpXWnZ1dYEnodZ6vr5u0xc+s\n1RapugIOn86mdRMPGmhrx74n+0UZaYsy95s4LepbP/XUU0RERNCxY0e2bdvGxIkTefzxx++6XGho\nKKmpqaSnp6PX60lISCAuLq5cmcaNG7N7924Azp07h16vv2uyEKI6KKXYcjCNtxcfQAEDIpraOiQh\nbEqjLDlvdB+2b9/O22+/jVKK+Ph4pk6dyscff0xoaCixsbGcO3eO119/nevXr6PVapkxYwbR0dF3\nrVeOGG6So6cyVdkW+UV6Fq1L5vDZHFwb2DN5SHu6tfOtkrqrg+wXZaQtytxvD8OihHHlyhW++uor\nUlNTMRgM5ukfffTRfa38fsgOcJP8GMpURVvkX9ezYW8qmw+kU1JqJKSZJ08P70gjN6cqirJ6yH5R\nRtqizP0mDIuuYTz33HN06NCB6Oho851SQtQl6dmF7EjKZNvhDEpKjXg0dCS+XzCxXYPQauXdL0KA\nhQmjuLiYN99809qxCFGtSkqNbD+cwe5jWVzU3TwC9WjoyJiYVsSENcbBXg6OhLiVRQmjS5cunDp1\ninbt2lk7HiGqRfLFqyxal8zlvGLstBq6BHvTMzSQsNbekiiEuAOLEsa4ceOYMGECAQEBODmVnctd\ntmyZ1QIToqqUGkwUXNejN5go0RvZejidbYcz0GhgUGRThkQ1x91VBtIU4m4sShivvfYa06ZNo0OH\nDnINQ9Qql/OK+etXB8gr1Jeb3sTXlSeGtqdloLuNIhOi9rEoYTg5OfHkk09aOxYhqpS+1Mi85UfJ\nK9TTtY0Pbi4OONrb4e/lQkxYYxniQ4h7ZFHC6NOnD9u3b6dv377WjkeIKqGUYvHGU6ReLqRvl8ZM\nHhJi65CEqPUsShhLlixhwYIFuLq64ujoiFIKjUbDTz/9ZO34hPhNth/JYNfRLJoHuPHYgDa2DkeI\nOsGihPHdd99ZOw4hqkROXjE/ndCxetcFXBvY8/yoTnLXkxBVxKKEERQk7yIWNVvSuRx+WHKE4+d/\nflWqg5ZnHuqIj4ezjSMTou6oNGG89tprvPvuu4wZMwaN5vanXeW2WmFrJqX4fucFvt+VAkBIM0+i\nOwbQrZ0fLg0sOh4SQlio0l/ULy87+sMf/lAtwQhxL4pLDCxMOMnB09n4eDTgjSd74OYodz4JYS2V\nJoxfXskaGRlZLcEIYYkSvZHEkzrWJ6aSlXudkGaePDuyE62CPGSQOSGsSPrsotbIKywh4aeL7D6W\nSXGJEY0G+ndrwiMPtJZnKoSoBpUmjNOnT1f4bgq5rVZUtwOnsvm/9ckUFpfi2dCRARFN6dulMV7u\nDWwdmhD1RqUJo0WLFixYsKC6YhHiNsUlBr758Qw7j2biYK/lsQFt5SltIWyk0oTh6Ogot9QKmzAY\nTWw7nMGa3SlcK9LT3N+NqSM6EOjtauvQhKi3Kk0YDg4O1RWHEMDN22R/OpbFqp0XyLl2AycHOx7q\n3ZJh0c2lVyGEjVWaMJYsWVJdcQhBenYh/7fhFGfTrmFvp2Vg96YM7SFDjwtRU1j9Lqnt27czd+5c\nlFKMGTOGqVOn3lZm7dq1/Otf/0Kr1dKuXTv+8Y9/WDssUYPkF+n5Yf8l1iemYjQpurXzZdwDbfD2\nkAvaQtQkVk0YJpOJOXPmsGjRIvz8/IiPjycuLo7g4GBzmYsXL/LFF1/wv//9j4YNG5Kbm2vNkISN\nGIwmVu9KQW8w4uxkj4uTPVfyb3Ai5SqXLhcC4O3uxGMD2hHWxsfG0QohKmLVhJGUlETz5s3NF86H\nDRvGpk2byiWMJUuW8Oijj9KwYUMAvLy8rBmSsJENe1NZvTvltun2dlo6tGhEp5be9OvamAaO8miQ\nEDWVVX+dOp2OwMBA82d/f3+OHj1arkxKSgoA48ePRynF888/T58+fawZlqhmV67dYPWuFNxdHHhh\ndGdKDEau3zDg2sCe1kEeODrIaLJC1AZWTRhKqbuWMRqNpKam8vXXX5ORkcFjjz1GQkKCucdxJ76+\nblUVZq1X09vi84ST6A0mnn+4C9Fdm1h1XTW9LaqTtEUZaYuqYdWEERAQQEZGhvmzTqfDz8+vXBl/\nf3+6du2KVqulSZMmtGzZkpSUFDp16lRp3TJm0E2+vm41ui2Onr/CT0czadPEg07NPK0aa01vi+ok\nbVFG2qLM/SZOq97YHhoaSmpqKunp6ej1ehISEoiLiytXpn///uzZsweA3NxcLl68SNOmTa0Zlqgm\npQYTX/9wGq1Gw4SB7SocIl8IUXtYtYdhZ2fHG2+8wZQpU1BKER8fT3BwMB9//DGhoaHExsbSp08f\ndu3axbBhw7Czs2PGjBl4eHhYMyxhRUopUnWFHDydzcHT2Vy+Wkz/iCY09av8FKMQoubTKEsuNNRA\n0sW8qSZ1t3OuFfOvFce4mHUzHns7LWGtvXliaHucnax/91NNagtbk7YoI21R5n5PSck9jKJKnEnL\n41/Lj5J/vZSubXyI7hhAp1ZecpusEHWI/JrFfTEpxc6kTBZvOIVS8NiAtjwQHiTXK4SogyRhiHtm\nMilOpl7l4Kmb1ymuFelxcbLn2VGd6NhCHrwUoq6ShCHuSeaVIr5Yc5ILmTdf39vQ2YHenQMZFt0c\n/0YuNo5OCGFNkjCERUxKselAGsu2nqPUYKJ7iB+xXYNo09QDO60MOy5EfSAJQ9xVzrVivkw4SXJq\nHg2dHXj6wQ5EhPjdfUEhRJ0iCUPckVKKHUmZfLvpDDf0RsJa+/D44HZ4NHSydWhCCBuQhFHPGYwm\ntBoNWm35u5oycopYsuUsSeeu4Oxkx5PD2tOzU4Dc/SREPSYJox7LL9Lz1n/2U1JqJCLEj6j2/rg2\nsGf17hT2nbyMAjq2aMQTQ9vj5S4vMxKivpOEUU+ZlOLzNSduvjfb0Y4tB9PZcjDdPL+ZX0OG92pJ\neFsf6VUIIQBJGPXW2p8ucvxCLqGtvHlxTCinUvNIPKEjr6iE2K5BhLWWRCGEKE8SRj10+lIeK3ac\np5GbE0892B57Oy0dW3rRsaU8dCeEuDNJGPWIUoqz6df4bNUxNGh4ZkRH3FwcbR2WEKKWkIRRD5Qa\njOw/lc0P+y6R8vNIsg/3C6ZtU08bRyaEqE0kYdRRpy/lcfB0NmfTr3ExqwCjSaEBurbxYWD3prRr\n1sjWIQohahlJGHXQ0fNX+HDJERRgp9XQzL8hIc0aEdM1CD9PZ1uHJ4SopSRh1DGX84pZ8P1x7Oy0\nPPtQRzq09MLJwc7WYQkh6gBJGHVIid7IJ98dpeiGgSeGhNC1ra+tQxJC1CEyzGgdoZRi0fpk0rIL\n6dc1iD5dGts6JCFEHWP1hLF9+3YGDx7MoEGDWLBgwR3LrV+/npCQEI4fP27tkOqU4hIDWw+n8+aX\n+0g8oSO4sTuP9m9j67CEEHWQVU9JmUwm5syZw6JFi/Dz8yM+Pp64uDiCg4PLlSsqKuKrr74iLCzM\nmuHUKSaTImHPRTbsTeX6DQN2Wg0RIX481r8N9nbScRRCVD2rJoykpCSaN29OUFAQAMOGDWPTpk23\nJYyPPvqIp59+mi+++MKa4dQZeYUlLPj+OMmpeXi6OdG/WxNiwoJo5CbDjgshrMeqh6I6nY7AwEDz\nZ39/fy5fvlyuzMmTJ8nKyiImJsaaodQZxy5c4S9f7iU5NY+w1j7Mm/EAI/u0kmQhhLA6q/YwlFJ3\nnT937lz+9re/WbzML3x93e4rttomJTOfxWtPsvdEFvZ2Gp5+qBPD+7RCo9HI8B63qG/7RWWkLcpI\nW1QNqyaMgIAAMjIyzJ91Oh1+fmWv9iwqKuLs2bNMnDgRpRQ5OTk899xzfPrpp3Ts2LHSurOzC6wW\nd02Sqitgw95U9hzXoYC2TTwY178NLQLcyckpxNfXrd60xd1IW5SRtigjbVHmfhOnVRNGaGgoqamp\npKen4+vrS0JCAu+//755fsOGDfnpp5/MnydOnMjMmTPp0KGDNcOq8fSlRvaevMzWw+mcz8gHbr6f\nYnRMMKGtvGTYcSGETVg1YdjZ2fHGG28wZcoUlFLEx8cTHBzMxx9/TGhoKLGxseXKazQai09J1UUm\nk2LXsUxWbD9PXqEeDdA52JuYsMZ0ae2DVhKFEMKGNKqW/oWua13Mkym5fLv5LJcuF+JoryUuogmx\nXYPw8ah87CfpbpeRtigjbVFG2qJMjT4lJe5OX2pkyZazbD6Yjgbo1SmAUX1byTu0hRA1jiQMG0q7\nXMj874+TnlNEkI8rTz7YnhYB7rYOSwghKiQJwwZMSrHpQBpLt5zDYDTxQHgQj8S2xlFGlRVC1GCS\nMKpZbv4Nvlx7khMpV2no7MATQzvStY2MKiuEqPkkYVQTfamR3ceyWLb1HNdLDHQO9uaJISF4NJQn\ntIUQtYMkDCvLyStm86F0dhzJoOiGAUcHLZMGtyOmS2N5nkIIUatIwrASk0mxLvEiK3dcwGhSuLk4\n8GDP5vQLC5I7oIQQtZIkDCu4cu0GX6w5walLeXg0dCQ+JpjI9n442MtFbSFE7SUJowoppdhzQsdX\nG09TXGIgvK0vjw9uJ4MDCiHqBEkYVSS/SM9/Npzi4OlsnBzsmDwkhD6dA+U6hRCizpCEcZ+UUuxL\nvsxXG09TWFxK26aeTBnWHj/Pyof0EEKI2kYSxn1Iu1zIf388TXJqHg72WsbFtaF/RBMZJFAIUSdJ\nwvgNCotLWbnjPFsOpaMUdAn2ZlxcG/y9XGwdmhBCWI0kjHtgMim2HU5n+fbzFN0wEODlwvj+bQht\n5W3r0IQQwuokYVjoVOpV/vvjGS5dLqSBox2PxLamf0QT7O2s+lp0IYSoMSRh3EVu/g2WbDnL3pOX\nAegdGsiYmFYypIcQot6RhHEHJXojG/elkrDnIvpSEy0D3Xh0QFuCG3vYOjQhhLAJSRi/YjCa2JmU\nyaqdF7hWpMfNxYHH+relV+dAuftJCFGvWT1hbN++nblz56KUYsyYMUydOrXc/EWLFrF06VLs7e3x\n8vJi7ty5BAYGWjus2xiMJhJP6Fjz00V0uddxdNDyYM8WDI5shksDyatCCGHVv4Qmk4k5c+awaNEi\n/Pz8iI+PJy4ujuDgYHOZDh06sHz5cpycnPjmm2/4+9//zgcffGDNsMopNZjYmZTBusRUcq7dwE6r\noV/XIEb0aoGnXKcQQggzqyaMpKQkmjdvTlBQEADDhg1j06ZN5RJGZGSk+f9hYWGsXr3amiGVk3ml\niM9WHefS5UIc7LXEhTdhcFQzvD1kNFkhhPg1qyYMnU5X7vSSv78/R48evWP5ZcuW0bdvX2uGZLbr\naCZfbTxNSamRPp0DGd1X7nwSQojKWDVhKKUsLrtq1SqOHz/O4sWLqzyOkxev8p/1yZQaTTg72qPR\naEjLLsTZyY5pD3Uksr1/la9TCCHqGqsmjICAADIyMsyfdTodfn5+t5XbvXs3CxYs4KuvvsLBwcGi\nun193Swqd+R0Nh8tS8JkMuHl4Uz+dT3Xbxjo2Mqb343rSoC3q2UbU4NZ2hb1gbRFGWmLMtIWVcOq\nCSM0NJTU1FTS09Px9fUlISGB999/v1yZEydO8Oabb7Jw4UIaNWpkcd3Z2QV3LXPswhX++d1RlFK8\nMDqUzsE+wM2ej0ajAZPJonpqMl9ft1q/DVVF2qKMtEUZaYsy95s4rZow7OzseOONN5gyZQpKKeLj\n4wkODubjjz8mNDSU2NhY3n33XYqLi3n55ZdRStG4cWPmzZt33+s+fCaHeSuPAfDimM7lxnuSd1QI\nIcS906h7udBQg2RnF1BSamRnUiYdWjQi8OdTS0op1iemsmzrOezttbw4JpROLevu4IBy9FRG2qKM\ntEUZaYsyNbqHYW3fbTvHj/vT0AARIX4MimzGjwcusee4jkZuTrwwOpSWge62DlMIIeqEWpswUnUF\nbDqQho9HA1ydHdiXfJl9yTcHCAxu7M7zo0PlwTshhKhCtTJhmEyKxRtPoRRMGtyOji28OHYhlw17\nU/Fv5MK4uDY42Muw40IIUZVqZcL4cV8q59LziQjxM1+fCG3lLS8yEkIIK6qVh+GL1pzAydGO8XFt\nbB2KEELUG7UyYRRc1zOyd0sauck1CiGEqC61MmE82Lslcd2a2DoMIYSoV2plwnhmVGd5l7YQQlQz\n+asrhBDCIpIwhBBCWEQShhBCCItIwhBCCGERSRhCCCEsIglDCCGERSRhCCGEsIgkDCGEEBaRhCGE\nEMIikjCEEEJYRBKGEEIIi1g9YWzfvp3BgwczaNAgFixYcNt8vV7PK6+8wsCBAxk7diwZGRnWDkkI\nIcRvYNWEYTKZmDNnDgsXLmTNmjUkJCRw7ty5cmWWLVuGh4cHGzdu5PHHH+fdd9+1ZkhCCCF+I6sm\njKSkJJo3b05QUBAODg4MGzaMTZs2lSuzadMmRo0aBcCgQYP46aefrBmSEEKI38iqCUOn0xEYGGj+\n7O/vz+XLl8uVuXz5MgEBAQDY2dnh7u5OXl6eNcMSQgjxG1g1YSil7rmMUgqNRmOtkIQQQvxG9tas\nPCAgoNxFbJ1Oh5+f321lsrKy8Pf3x2g0UlhYiIeHx13r9vV1q/J4aytpizLSFmWkLcpIW1QNq/Yw\nQkNDSU1NJT09Hb1eT0JCAnFxceXKxMbGsmLFCgDWr19Pjx49rBmSEEKI30ijLDlvdB+2b9/O22+/\njVKK+Ph4pk6dyscff0xoaCixsbHo9Xpee+01Tp48iaenJ++//z5Nmsj7uoUQoqaxesIQQghRN8iT\n3kIIISwiCUMIIYRFJGEIIYSwSK1LGHcbm6ouy8rKYtKkSQwdOpThw4fzn//8B4Br164xZcoUBg0a\nxJNPPklBQYGNI60eJpOJUaNGMW3aNADS0tJ45JFHGDRoENOnT8dgMNg4wupTUFDASy+9xJAhQxg2\nbBhHjhypl/vFokWLePDBBxk+fDivvvoqer2+Xu0Xs2bNomfPngwfPtw8rbL94K233mLgwIE89NBD\nnDx58q7116qEYcnYVHWZnZ0dM2fOZO3atXz77bd8/fXXnDt3jgULFhAdHc2GDRuIiopi/vz5tg61\nWvznP/8hODjY/Pkf//gHTzzxBBs2bMDNzY1ly5bZMLrq9fbbbxMTE8O6detYtWoVrVq1qnf7hU6n\nY/HixSxfvpzVq1djNBpJSEioV/vF6NGjWbhwYblpd9oPtm3bRmpqKhs3bmT27Nm8+eabd62/ViUM\nS8amqst8fX1p3749AK6urgQHB6PT6cqNxzVq1Ch+/PFHW4ZZLbKysti2bRsPP/ywedqePXsYNGgQ\ncLMdfvjhB1uFV60KCwvZv38/Y8aMAcDe3h43N7d6uV+YTCaKi4sxGAzcuHEDPz8/EhMT681+ERER\ngbu7e7lpv94PfvmbuWnTJkaOHAlAly5dKCgoICcnp9L6a1XCsGRsqvoiLS2N5ORkunTpwpUrV/Dx\n8QFuJpWrV6/aODrrmzt3LjNmzDAPI3P16lU8PDzQam/u0gEBAfVm30hLS6NRo0bMnDmTUaNG8cYb\nb1BcXFzv9gt/f3+eeOIJ+vXrR9++fXFzc6NDhw64u7vXy/3iF7m5ueX2g9zcXKD8OH5ws/10Ol2l\nddWqhCGPjNxUVFTESy+9xKxZs3B1da13Y29t3boVHx8f2rdvb94nlFK37R/1pV0MBgMnTpzg0Ucf\nZcWKFTg7O7NgwYJ6s/2/yM/PZ9OmTWzZsoUdO3ZQXFzM9u3bbytX39rlTir6e3q3trHqWFJVzZKx\nqcpAbb4AAAdwSURBVOo6g8HASy+9xEMPPUT//v0B8Pb2JicnBx8fH7Kzs/Hy8rJxlNZ18OBBNm/e\nzLZt2ygpKaGoqIi5c+dSUFCAyWRCq9WSlZVVb/aNgIAAAgICCA0NBWDgwIF8/vnn9W6/2L17N02b\nNsXT0xOA/v37c+jQIfLz8+vlfvGLO+0H/v7+ZGVlmctZ0ja1qodhydhUdd2sWbNo3bo1jz/+uHna\nAw88wPLlywFYsWJFnW+T6dOns3XrVjZt2sT7779PVFQU//jHP4iKimL9+vVA/WiHX/j4+BAYGMiF\nCxeAm9dyWrduXe/2i8aNG3PkyBFKSkpQSrFnzx7atGlT7/aLX/cc7rQfxMXFsXLlSgAOHz6Mu7u7\n+dTVndS6oUEqGpuqvjhw4AATJkygbdu2aDQaNBoNr7zyCp07d+Z3v/sdmZmZNG7cmI8++ui2C191\n1d69e/nyyy/57LPPuHTpEtOnTyc/P5/27dvz7rvv4uDgYOsQq0VycjJ/+tOfMBgMNG3alHfeeQej\n0Vjv9otPPvmEhIQE7O3t6dChA2+99RZZWVn1Zr949dVXSUxMJC8vDx8fH1588UX69+/Pyy+/XOF+\nMHv2bHbs2IGzszPvvPMOHTt2rLT+WpcwhBBC2EatOiUlhBDCdiRhCCGEsIgkDCGEEBaRhCGEEMIi\nkjCEEEJYRBKGEEIIi0jCEDXaAw88wNmzZ6tlXZ988km5oa9nzpzJ119/fd/1zpw5k+HDhzN9+vT7\nrqsyycnJrFu3zqrrEPWbJAwhfvbJJ59QWlpapXXm5OTw/9u7v5AmuziA49/ln7S8KOvWoghaI8KL\nihkJWon0R/Y8S2NYOL1IEFqE3gjRRZZEBcPyJqE/lDSIyBp2UV4IEVgGXeyiDKMVFnSRltTmaPr4\ney/Eh3KL9vYG7+vb73O182znnN/DYL+dHfY7fX199Pb2EgwGf+vYcz1//vyXE8b09PRvjkb9H2nC\nUPPS69evOXjwIDU1NRiGYZc+AHA6nXR1dVFdXU1FRQV9fX32c/fv32fnzp14vV66urpwOp0kEgna\n2tpwOBz4fD5M0yQWiwEwPDyM3++nsrKS1tbWH8Zz584dqqqq8Hg8BAIBPn78SDwex+/38/XrV0zT\n5OrVq9/1CYfDHDp0yG5blkVpaaldL+3ixYvs27cPr9dLU1MTY2NjAExOTnL69GmqqqowDINAIMD4\n+DidnZ08fvwY0zRpb28HZiojmKaJx+OhoaGBt2/fAjP/kDcMg5MnT+Lz+Xj48OE/eTvUn0KU+g8r\nLy+Xly9ffndtampKTNOUaDQqIiKxWEwqKyvt9tq1a+X69esiIvL06VMpLS0VEZHR0VHZvHmzjIyM\niIjIlStXxOl0ysTEhN0vkUjY87S2tkptba0kk0lJJpOye/duGRgYSIlxeHhYtm7dKqOjoyIi0tHR\nIUeOHBERkXfv3onb7U57b4lEQtxut3z69ElERPr7+8Xv94uISDgclmPHjtmvDYVC0tLSIiIinZ2d\nEggEZGpqSkTE7t/T0yOHDx+2+4yNjYnb7ZZXr16JiMjNmzelpqZGREQGBwfF5XJJJBJJG5tS6egK\nQ807b968IRqN0tzcjGEY7N+/n8nJye9OX9y1axcAxcXFfPjwgWQySSQSYf369RQVFQFQXV2dMrbM\nqZSzY8cOcnJyyMnJweVyMTIyktJncHCQsrIyli1bBoDP52NgYOCn95GXl8f27du5e/cuMFMYbvYQ\npP7+fh49eoRhGBiGQSgU4v3798BMefe6ujqysrIA7Oqsc0UiEdatW8fq1asB2Lt3L0NDQ0xMTACw\ncuVKNmzY8NM4lZo1r8qbKwUzH+qFhYXcvn077fMOh4OFCxcC2AfnWJaVkgzmttPJzc21H2dlZaU9\nD1pEUs4RmJ33ZwzD4NSpU+zZs4cnT55w9uxZe8ympia8Xm/a+TKRLq5v24sWLcpoHKVm6QpDzTur\nVq0iLy+PcDhsX4tGo8TjcSD1A3W2XVxczLNnz+zf8b/d9wAoKCjgy5cvfzuekpISHjx4YO8x3Lhx\ngy1btqTMn87GjRuJxWIEg0EqKirsRLdt2zZCoRCfP38GIJlM8uLFCwDKy8u5du2avUE/e5JeQUGB\nvfcye79DQ0N22fOenh5cLpcmCvXLdIWh/tMcDgf19fVkZ2fb35h7e3u5cOEC7e3tXL58GcuyWL58\nOR0dHXafuWPAzEEyx48fp7GxkaVLl1JWVkZ2djb5+fkANDQ0UFdXR35+Pt3d3RnHuGbNGpqbm6mv\nr2fBggUUFRXR1taWMv+PGIbB+fPnCYVC9jWPx8P4+DgHDhzA4XAwPT1NbW0tTqeTxsZGgsEghmGQ\nm5vLihUrOHfuHCUlJVy6dAnDMNi0aRNHjx7lzJkztLS0YFkWhYWF9gpGqV+h5c3VHyUej7N48WJg\n5hv3rVu3fst/LZT6E+gKQ/1Ruru7uXfvHpZlsWTJEk6cOPFvh6TUvKErDKWUUhnRTW+llFIZ0YSh\nlFIqI5owlFJKZUQThlJKqYxowlBKKZURTRhKKaUy8hf8CwfjbzhfpQAAAABJRU5ErkJggg==\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7f47b218dbd0\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(ag_means)\n", + "plt.ylabel('Time(s)')\n", + "plt.xlabel('Length of vector')\n", + "_ = plt.title('Time to sum the elements of 1000 vectors (AutoGraph)')\n", + "_ = plt.ylim(ymin=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "d7IAJ6Bwbk9t" + }, + "source": [ + "## Eager" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "XMu5-12yoOzY" + }, + "outputs": [], + "source": [ + "from tensorflow.python.eager import context" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "_vt9MzpyjQ4T" + }, + "outputs": [], + "source": [ + "# Sum written using for loop and run with tf.eager\n", + "def sum_all(elements):\n", + " sum_ = 0.0\n", + " length = elements.shape[0]\n", + " for i in tf.range(length): \n", + " sum_ += elements[i][0]\n", + " return sum_\n", + "\n", + "eager_means = []\n", + "for num in range(max_elements):\n", + " with context.eager_mode():\n", + " durations = []\n", + " for i in range(trials + burn_ins):\n", + " \n", + " start = time.time()\n", + " for _ in range(batches):\n", + " run_trial(num)\n", + " \n", + " if i \u003c burn_ins:\n", + " continue\n", + " \n", + " duration = time.time() - start\n", + " durations.append(duration)\n", + " eager_means.append(np.mean(durations))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 301 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 422, + "status": "ok", + "timestamp": 1532460024499, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "5gHVdMlD-A-T", + "outputId": "3b581cb7-7ef9-489c-92f1-3e52c0c2dc8a" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAEcCAYAAAAydkhNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3XlclNX+wPHPsC+CILviCiruhAiiqaAimqnglpZLWpkt\ntmh50+rXvZZ2q9veq5ulZYtlaZp7mQuaorjklkoqiwgKsojsy8yc3x9cBxFUVGBYvu+/mGfmPM93\nzjzMd85zznOORimlEEIIISphYuwAhBBC1F2SJIQQQtyQJAkhhBA3JElCCCHEDUmSEEIIcUOSJIQQ\nQtyQJAlg8eLFvPrqq8YOo14ZOHAge/furfHjzJs3jw8//LDGj1NXxMfHExERQc+ePfnuu++MHU6D\ndvbsWcaMGWPsMCo1btw4YmNjjR0G0EiSxD333IOfnx9+fn506tSJHj16GLZt2LCBxx9/nNdff73G\n40hOTsbHxwe9Xl/jx6pOje2L+lZq8nNcsmQJgYGBHDp0iEmTJlV4fvPmzUyYMAFfX1+mTJlS4flT\np04xevRofH19GTNmDDExMeWef+eddwgMDKR379688847t1W2pk2ePJlVq1bV2vE++ugjHn30UcPj\ngQMH0qNHD/z8/AzfD2+88UatxXOtRx55pM78zzWKJHH48GH+/PNP/vzzT5o3b87ixYsN2+6///5a\ni0MphUajQe5frN9q8nO8cOEC3t7eN3zewcGBhx9+mBkzZlR4rqSkhKeeeorw8HAOHDhAeHg4Tz75\nJFqtFoAVK1awfft21q9fz7p164iMjOTHH3+sUtn64HY+j7S0NKKjoxk0aFC57YsXL+bPP/80fD+8\n8sor1R3mTel0OqA0YUVHR5Oenl6rx69Mo0gS11JKVTiZPvnkE1588UWg7Ffi6tWrCQ4OJjAwkBUr\nVnD8+HFGjhxJQEBAhVbHqlWruO+++wgMDOTRRx/lwoULlR578uTJAPj7++Pn58fRo0dRSvHpp58y\ncOBA+vbty0svvURubm6l5S9fvszMmTPp1asXgYGB5X5p+vj4cP78ecPja3/979+/nwEDBrBkyRL6\n9OlDv3792Lp1Kzt37iQsLIzAwEAWL15c6TF/+ukn1q9fz5IlS/Dz8+OJJ54wPHfq1ClGjhxJr169\nmD17NsXFxYbnduzYQXh4OL169WLixIn8/fffle4fIDY2lunTpxMYGMiwYcPYvHnzDV97s/0OHDiQ\npUuXMnLkSO655x5eeeUVMjIyeOyxx/Dz82P69Onk5OQYXn/kyBEmTJhAr169CA8PZ//+/YbnJk+e\nzIcffsjEiRPx8/PjkUceISsry/AclP8cExMTmTx5Mv7+/gQFBTF79uwbvodt27Zx//33ExAQwJQp\nU4iLiwNg6tSpREdHs2DBAvz8/Dh37lyFskFBQQwdOhQXF5cKz+3fvx+dTseUKVMwNzdn8uTJKKXY\nt28fAL/88gvTp0/H1dUVV1dXpk2bxpo1awCIjo6+adlrbdq0qcJlmmXLlvHkk08CUFxczFtvvUVI\nSAj33nsv//znP8udG1u3biU8PJyePXsyZMgQdu/ezfvvv8+hQ4d4/fXXy/2C//PPPxk7diy9evVi\n3LhxHD58uNxn9P777zNx4kR8fX1JSkpi9erVDB48GD8/PwYPHsyGDRsq/Qz27NlDly5dsLCwKLf9\nRonm/PnzTJ06lcDAQIKCgnjhhRfK/Z+eOHHCcJnw2Wef5fnnny/XErjVefvFF18Yzlu9Xo+FhQVd\nunRh9+7dlcZTq1QjExISoqKiospt+/jjj9WLL76olFIqKSlJdezYUb322muqqKhI7dmzR3Xr1k09\n9dRTKjMzU6WkpKigoCB14MABpZRSv//+uxoyZIiKi4tTOp1O/fe//1UPPPBApcdOSkpSPj4+Sq/X\nG7atXLlSDRkyRCUlJan8/Hz19NNPG2K53rvvvqtee+01pdPplFarVQcPHjQ85+PjoxITEw2PX3rp\nJfXBBx8opZSKjo5WnTt3Vp9++qnSarXqp59+Ur1791Zz5sxR+fn56syZM6pbt27q/PnzlR732n1d\nW4/jxo1TaWlp6sqVK2rYsGFqxYoVSiml/vrrLxUUFKSOHTum9Hq9WrNmjQoJCVHFxcUV9p2fn68G\nDBig1qxZo/R6vTp58qQKDAxUZ8+erXDsW+03JCREPfDAAyojI0OlpqaqoKAgFRERoU6dOqWKi4vV\nlClT1CeffKKUUiolJUUFBASoXbt2KaWUioqKUgEBASozM1MppdSkSZNUaGioOnfunCoqKlKTJk1S\n77777g0/x9mzZ6vPPvtMKaVUUVGROnToUKV1GRcXp3x9fVVUVJTSarXqiy++UKGhoaqkpMRw3JUr\nV1Za9lo//fSTmjx5crltX331lXrsscfKbXv88cfVV199pZRSqmfPnuro0aOG544fP678/PyqVPZa\nBQUFys/PT507d86wbcyYMWrTpk1KKaXeeOMN9cQTT6js7GyVl5enZs6cqd577z2llFJHjx5VPXv2\nNPwPpqamqri4uErfe1ZWlurVq5dat26d0ul0asOGDapXr14qKyvL8PqQkBB19uxZpdPpVE5OjvLz\n81MJCQlKKaXS0tIM59H13nrrLbVgwYJy2yr7brjq3LlzKioqSpWUlKjMzEw1adIktWjRIqWUUsXF\nxSokJER9++23SqvVqi1btqguXbrc1nkbHh6uUlJSVFFRkeGYr7/+uvr3v/9daTy1qdG1JKpCo9Hw\n1FNPYWFhQZ8+fbC2tmb48OE4Ojri5uaGv78/J0+eBODHH39kxowZtG3bFhMTE2bMmEFMTAwXL168\n4f7VNb9WNmzYwMMPP0yLFi2wtrZm9uzZbNq0qdLr3WZmZqSlpZGUlISpqSk9e/asdJ+VMTc3Z+bM\nmZiamnLfffdx+fJlpk6dirW1Nd7e3nh7e9/0135lpkyZgrOzM/b29oSEhHDq1CkAVq5cyYQJE+jW\nrRsajYbw8HAsLCw4evRohX3s2LEDT09PwsPD0Wg0dOrUiSFDhvDrr79WeG1V9jtp0iSaNWuGq6sr\n/v7+9OjRAx8fH8zNzQkNDTXEuG7dOoKDg+nXrx9Q+gu9a9eu7Ny507Cv0aNH06pVKywsLBg2bJih\n7FXX1rmZmRnJycmkpqZiYWGBn59fpXW2efNmgoODCQoKwtTUlEceeYTCwsJyv5DvVH5+PnZ2duW2\nNWnSxPCL9/rn7ezsyM/Pr1LZa1lZWTFo0CDDr/SEhATi4+MNl25WrVrFvHnzsLOzw8bGhhkzZhhe\nu2rVKsaOHUtQUBAArq6utG3bttL3ExkZSZs2bRgxYgQmJiYMHz6cdu3asWPHDsNrIiIi8PLywsTE\nBFNTU0xNTTl9+jRFRUU4Ozvj5eVV6b5zcnKwtbWtsP2pp54iICCAXr16ERAQwMqVKwFo1aoVQUFB\nmJmZ4ejoyNSpUzlw4ABQ2iLV6XRMmjQJU1NTQkND6d69u2GfVTlvp0yZgpubW7mWja2tLdnZ2ZXG\nX5vMjB1AXeXk5GT428rKCmdnZ8NjS0tLwz/XhQsXWLhwIW+99RZQdr06NTUVDw+PWx7n0qVLNG/e\n3PC4RYsWaLVa0tPTcXV1LffaRx99lI8//pjp06ej0WgYN25cpdemK+Pg4IBGozG8n8re49X3VFXX\nlre2tiYtLQ0orZO1a9caRucopdBqtVy6dKnCPi5cuMCRI0cICAgwvFan0xEeHl7pa2+132tjsrS0\nrPD42s9t8+bNhi+cq/u6+uUFlPvMra2tb1o/c+fO5YMPPmDs2LGGfoPKRs5c/3lrNBo8PDxITU29\n4b6rysbGpsKXem5uLk2aNKn0+dzcXGxsbKpU9nrDhw/n7bff5sknn2TDhg0MHjwYCwsLMjMzKSgo\nKPfe9Xq9IaGmpKQwYMCAKr2f6+sKoHnz5uXqyt3d3fC3tbU177//PkuXLmX+/Pn07NmTuXPn0q5d\nuwr7tre3Jy8vr8L2Tz/9lN69e1fYnpmZyRtvvMHBgwfJz89Hp9Ph4OAAlPZvuLm5lXv9tf/7VTlv\nr30fV+Xl5WFvb19he22TJHGX3N3deeKJJ6rUAX71S/parq6u5fowkpOTMTMzK/cFdZWNjQ3/+Mc/\n+Mc//kFsbCyTJ0+me/fu9O7dG2trawoKCgyvTUtLq/TEqw3u7u7MnDmTxx9//Jav9fDwIDAwkKVL\nl1brfqty3PDwcBYsWHDbZSv7HJ2cnAx9VYcOHWLatGkEBATQsmXLcq9zdXXlzJkz5bZdvHixWj6r\n9u3bs2zZsnLbTp8+behD8fb2JiYmhm7dugGlfUrt27e/adnKRlgB3HvvvcybN4+YmBg2btzI/Pnz\nAXB0dMTa2poNGzZU+JEDpZ/htX1n17q+Xl1dXdmyZUu5bRcuXKB///43LNO3b1/69u1LcXEx77//\nPq+++irLly+vcKyOHTuydu3aCttv1CJ/99130Wg0bNiwAXt7e7Zu3WroN3FxcamQ5C9evEirVq0M\n7/lOztu4uDhGjhx5W2VqglxuqsStLt1ca+LEiSxevJizZ88Cpc3Yyi6VADRr1gwTExMSExMN24YP\nH86yZctISkoiLy+P999/n+HDh2NiUvGjiYyMNJS1sbExNK+htON6w4YN6PV6du3aZWgKVwdnZ+cb\n/mNXZvz48axYsYJjx44BpZcydu7cWekv8eDgYOLj41m7di1arZaSkhKOHz9u6My90/3eysiRI9m+\nfTu7d+9Gr9dTVFTE/v37q/SLvrLP8ddffzWUtbe3x8TEpNLPcNiwYURGRrJv3z60Wi1Lly7F0tIS\nX1/fKsWt1+spLi5Gq9WW+xsgICAAExMTvv32W4qLiw2/XAMDAwEIDw9n2bJlpKamkpqayrJlyxg9\nevRNy1b2qxrA1NSUsLAw3n77bbKzs+nbty+AoYW7aNEiMjMzAUhNTTV0wI4dO5bVq1ezb98+lFKk\npqYaPuvrz7MBAwZw7tw5Nm7ciE6nY9OmTcTFxRESElJpTBkZGWzfvp2CggLMzMwM/yOV6du3LydO\nnCjXoX4zeXl52Nra0qRJE1JTU8v9qPH19cXU1JTly5ej0+nYunWr4RyFOztvi4uLOXHihKFejanR\nJYnKfgXe6jU3ezx48GAee+wxnn/+efz9/Rk5ciR//PFHpfu1srJi5syZTJw4kYCAAI4dO8bYsWMZ\nNWoUkyZNIjQ0FGtr6xsOu0tISODhhx/mnnvuYeLEiTz00EP06tULgJdffpnt27fTq1cvNm7cyODB\ng+/qPV5r7NixnD17loCAAJ5++ulbvr5r1668/vrrLFiwgICAAMLCwgyjaK5na2vLl19+yaZNm+jX\nrx/9+vXj3XffrfSf91b7vZ335O7uzqeffsrixYsJCgoiJCSEL7/80vAD4WZlK/scjx8/zrhx4/Dz\n8+Opp57i5ZdfpkWLFhXKtm3blnfeeYfXX3+doKAgIiMj+eyzzzAzM7vlcQHWrl1L9+7dWbBgAYcO\nHaJHjx6GG0HNzc359NNPWbNmDQEBAaxevZpPP/3UsO8JEyYQEhLCyJEjGTlyJCEhIYwfP75KZSsz\nfPhw9u7dy7Bhw8olxBdeeIHWrVszfvx4/P39mT59OgkJCQB0796dRYsWsWjRInr27MmUKVMM/XdT\npkzh119/JTAwkIULF+Lg4MBnn33G0qVL6d27N0uXLmXx4sU0bdq00rrS6/V89dVX9O/fn969e3Pg\nwAFee+21SmN3cnKid+/ebN26tdz2J554wnBPlZ+fH7NmzQLg6aef5q+//sLf35+ZM2cSFhZmKGNu\nbs7HH3/MypUr6dWrFxs2bGDgwIGG/oXbPW+hdARcYGBgpaPYaptG3c7P5ts0f/58IiMjcXJyYv36\n9QDExMTw2muvUVRUhJmZGa+99pqh+SuEELUlNjaWl156ydA5XZ3Gjx/PxIkTiYiIuKPyDzzwAAsX\nLrzpPTO1pUaTxMGDB7G1tWXu3LmGJPHII48wbdo07r33Xnbu3MmSJUv49ttvayoEIYSocQcOHKBt\n27Y4Ojqybt06/vWvf7F169ZK+xbrmxrtuPb39yc5ObncNo1GY7ihKScnp8KoACGEqG/i4+N57rnn\nyM/Pp1WrVnz00UcNIkFADbckoHS0zsyZMw0tidjYWB599FHDnc8rVqyo0lBRIYQQta/WO65/+OEH\nXn75ZSIjI5k3b55h6JwQQoi6p9aTxC+//GIYeTN06NByQ8VupoYbPEIIUe9l5xVz/Gx6tX5f1vjN\ndNcH6+bmxv79+wkICGDv3r20adOmSvvRaDSkpeXc+oWNgIuLndTF/0hdlJG6KNPY6iL9SgFb9p9n\n17ELFJfo+ee0XrRyK51mxcXF7halb65Gk8ScOXOIjo4mKyuL4OBgZs2axeuvv84bb7yBXq/H0tKy\nVtZxEEKIhijjSiG/7I5j71+p6JWimb0lw4Jb09K18ulU7kSNd1xXp8b0y+BmGtuvpJuRuigjdVGm\noddFXmEJG/eeY+vBJLQ6Pc2dbRkW2IrAzm6YmZbvRajTLQkhhBDV61RCJv9de4LcghKa2VsS0a8d\nQV3cMTG59WwSd0KShBBC1BORR5JZvuU0AGODvRjc0xML88rnp6oukiSEEKKO0+n1/LQ9lt8PnqeJ\ntTlPj+5Gh5YOtXJsSRJCCFFHXczIY/fxi+z9K4Ws3GI8nGx4dlwPXB2say0GSRJCCFHHFBRp+WL9\nSY6cTQfA2tKMgX4tGN3fCxur2v3aliQhhBB1SGZ2IR+sPEZSWi7enk0Z5OfJPe2da7zv4UYkSQgh\nRB1x/lIuH6w8yuWcIkL8WvDQ4A41NmqpqiRJCCGEESmlOJt8hT3HU4g+mUpRiY7xId6EBbSs0iJp\nNU2ShBBCGMmhvy+xMjKWS5dL16d3tLNk+vBO9PKpuD64sUiSEEKIWqbXK9b8EcfGvecwNzMhqIsb\nfbp50KmVo9EvL11PkoQQQtSi/MISFq87yfG4DFwdrZk1uhstXKpvrqXqJklCCCFqScaVQt798Qgp\nmfl0bdeMx0d2wdbK3Nhh3ZQkCSGEqAUX0vN498cjXM4pYkivlowP8a5zl5YqI0lCCCFqWNyFbD5Y\neZTcghLGhXgxLLC1sUOqMkkSQghRA5RSJKTksPvYRfb8dZESrZ5pw3zo16O5sUO7LZIkhBCimh2L\nzWBV5FmS0vIAcGhiweQRHbmng4uRI7t9NZok5s+fT2RkJE5OTqxfv96w/dtvv2X58uWYm5szYMAA\nXnjhhZoMQwghas3Wg+f5YdsZTDQaenZ0oV93D7q0bYapicmtC9dBNZokRo8ezeTJk5k7d65hW3R0\nNDt27GDDhg2YmZmRmZlZkyEIIUSt0OsVK7afYevBJOxtLXh2bHfaetgbO6y7VqNJwt/fn+Tk5HLb\nfvjhBx577DHMzEoP3axZs5oMQQghaoRerzgWm0Fyei5pWYWcS83hXEoOzZ1teW5sd5xrcTrvmlTr\nfRIJCQkcPHiQ999/H0tLS+bOnUu3bt1qOwwhhLhjWp2epRtPEX0ytdz27l5OzBjRGZs6fu/D7aj1\nJKHT6cjOzuann37i2LFjPPfcc2zbtq1KZe92Qe+GROqijNRFGamLMjVVF0UlOv799QEOnkqlU5tm\njB3UHrdmNrg52mBl2fDGAtX6O3J3d2fIkCEAdO/eHRMTEy5fvoyjo+Mty6al5dR0ePWCi4ud1MX/\nSF2UkbooU1N1UVCk5eOfjxGTmEWXts14OqIblhal6zzkZBdQF2v/bpNljXe3K6XKPR48eDB79+4F\nID4+Hq1WW6UEIYQQxhR/MZvXvz5ITGIWPTu68MyY7oYE0ZDVaEtizpw5REdHk5WVRXBwMLNmzWLM\nmDHMmzePESNGYG5uzltvvVWTIQghxF3R6fVs3HuO9XsS0OkVQ3q1ZFyIV70d0nq7NOr6n/p1mDSl\nS8llhTJSF2WkLspUV12kXylg8boTxCZn42hnySPDO9G5Tf0akXm3l5saXi+LEEJUgyNn01m64SR5\nhVoCOrkyOaxjnZ+xtSZIkhBCiGsUlehY+0c8v+5PxMzUhKlDO9K/R/M6sZSoMUiSEEIIILeghO2H\nkth6KIncghLcmtnwxKgutHJr3MOKJUkIIRo1vV6xYW8Cm/ado7hEj62VGff3acOwwFZYN8D7Hm6X\n1IAQotG6klfM5+tOcOrcZRyaWDCmf2v69fDAykK+Gq+SmhBCNEp/J17ms7UnuJJXjK+3M4/c36lR\ndkzfiiQJIUSjUqLV88vuOH6NTkSDhnEhXgwNaNVoO6ZvRZKEEKLRSEzNYcmGkySl5eHc1IrHRnSm\nvaeDscOq0yRJCCEavBKtno17E9i49xw6vWKAb3PGh3hLx3QVSA0JIRq00+ez+PrXGC5m5ONoZ8nU\noT5093Iydlj1hiQJIUSDlH6lgLV/xLPnrxQ0wCA/T0YPaCeth9sktSWEaFCu5Baxenc8m6NKJ+Tz\ndLFlylAfvFs0NXZo9ZIkCSFEvafT6zkRn8nu4ykcOZOGVqdwcbAi/N52BHZ2w8RERi7dKUkSQoh6\nq6hYx/Y/k9hy8DxXcosBaO5sS/gAL3zbNcPMtHFM512TJEkIIeqdohIdO/5MZnP0OXLyS7C2NCPE\nrwX3dvOgjbsdrq72Mm16NZEkIYSoV1Iz8/lw1TFSMvOxtjRlZN82DOnVEhu5W7pG1GhbbP78+fTp\n04cRI0ZUeG7p0qX4+PiQlZVVkyEIIRqQUwmZvPHNQVIy8xnU05O3ZvYhvF87SRA1qEaTxOjRo1m6\ndGmF7SkpKURFRdG8efOaPLwQooEo0erYevA87/10lMJiHdPu8+Gh0A40sZbkUNNq9HKTv78/ycnJ\nFbYvWrSIuXPn8sQTT9Tk4YUQ9ZhSiuNxGUSfvMSRs2kUFOloYm3O06O70aGlTKVRW2q9T2L79u14\neHjQsWPH2j60EKIeWRUZy+boRACc7K0Y4NuCQX6eODW1MnJkjUutJonCwkI+++wzvvzyS8M2pVSV\ny9/tgt4NidRFGamLMg2lLk7EZfDr/kQ8nG2Z/aAfHVs53vYsrQ2lLoytVpNEYmIiycnJjBo1CqUU\nqampjBkzhpUrV+LkdOu5VGRIWykXFzupi/+RuijTUOqiqFjHe98fAgXThvrgZGNOenrube2jodRF\ndbjbZFnjSeLalkKHDh3Ys2eP4fHAgQNZs2YNTZvK7fJCiFKrdsZy6XIBQwNa4e0p3w3GVqOjm+bM\nmcOECROIj48nODiYn3/+udzzGo3mti43CSEatphzl9l2KAkPJxvC+7U1djiCGm5JvPvuuzd9ftu2\nbTV5eCFEPXL4TBpfbjyFRgOPDO+MhbmpsUMSyB3XQggj0+r0rIqMZcuB85ibmTD9vk60a25v7LDE\n/0iSEEIYTWJqDl//+jfxF7Nxb2bDE+FdaenaxNhhiWtIkhBC1LrcghLW7Ioj8kgySkFQFzcmh3XE\nykK+kuoa+USEELXm6uytG/cmkFeoxcPJhomD29O1rSwnWldJkhBC1LgSrY7IIxfYtPccV/KKsbY0\nY8JAbwb29JQ1H+o4SRJCiBqhV4qzSVfYdzKVA6dSySvUYmlhyv19WhMW0Apbmbm1XpAkIYSodhcz\n8vho1TFSLxcA0NTWgvt6tyYsoCV2NhZGjk7cDkkSQohqdTmniPd+PEJGdhG9u7jRt6sHnVo7yjrT\n9ZQkCSFEtckvLOG9n0oTRET/dozo08bYIYm7JD1GQohqUVSi46NVx0hOy2OQnyf3B7U2dkiiGkhL\nQghxx/ILtRw9m87hs+n8FZdBYbGOXj6uTBzc/ran9hZ1kyQJIcQdOZt8hU9+PkZ2fgkALg5WDOrp\nyci+baX/oQGRJCGEuG37TqTw5aYY9HrF/X1aE9jJjebOttJ6aIAkSQghqkwpxdrd8azbk4C1pSlP\nhHeTu6UbOEkSQogqUUqx/PfTbP8zGeemVjw7rgctnG2NHZaoYZIkhBC3pFeK5VtOs+NwMp4uTXhh\ngi/2tnJTXGNQo0li/vz5REZG4uTkxPr16wF4++232bFjBxYWFrRq1Yo333yTJk1kamAh6iq9Uny3\n5TSRh5Np6VqaIOSu6cZDo2pw/dCDBw9ia2vL3LlzDUkiKiqK3r17Y2Jiwn/+8x80Gg1z5syp0v5k\nYfNSssh7GamLMtVVF4XFWtbtTiD2whWy84rJzi+moEhHK9cmvDDxHppY1/05l+S8KOPiYndX5Wu0\nJeHv709ycnK5bX369DH87evry2+//VaTIQghbkNs8hW+WH+SS1kFaDRgZ22Ok70VLVya8FBoh3qR\nIET1MmqfxKpVqxg+fLgxQxBCAHq9Yn1UAuv3JKCUYlhgK8L7tcPcTCZlaOyMliT++9//Ym5uzogR\nI6pc5m6bTQ2J1EUZqYsyd1IX+YUlvPPdIQ6eSsXZwZrZE/3o5u1cA9HVLjkvqodRksSaNWvYuXMn\n33zzzW2Vk2uMpeR6axmpizJ3UheZ2YV8uOoY5y/l0qVtM2aO6oKtlXm9r1M5L8rU6T4JKB1bfa1d\nu3axZMkSvvvuOywsZISEEMZyLiWHD1cdJSu3mGDf5jwY2kFWiRMV1GiSmDNnDtHR0WRlZREcHMys\nWbNYvHgxJSUlTJ8+HYAePXrwz3/+sybDEEJc5+/Ey3y46hhFxTrGh3gTFtBSptQQlarRIbDVTZqP\npaQpXUbqokxV6+LI2XT++8tf6PWKx0Z0JqCTWy1EV7vkvChT5y83CSGMK7eghKycIrLzi0lMzWVV\nZCxmphqeGdudbu1k3iVxc5IkhGjA/jh6gW9++xudvuyCgY2lGc+O6057TwcjRibqC0kSQjRQx2Iz\n+PrXv7GxMqNXJ1fsbSywszGnWzsnXBysjR2eqCckSQjRACWkZPPfX/7C1FTDs2O749WiqbFDEvWU\njHcTooFJSsvlg5XHKC7RMWNEF0kQ4q5IS0KIBuDvxMus3h3PgRMppF4uAOCh0A707Ohi5MhEfSdJ\nQoh6rKhEx4/bzxJ5uHQiTUsLU+5p70xAJzcCOze8oa2i9kmSEKKeSrqUy2frTnAhPQ9PF1seH90d\nN3tLuWtXb7olAAAgAElEQVRaVCtJEkLUMyVaPb/tT2TdngS0Oj2D/DwZF+JFi+YOcgOZqHaSJISo\nR04mZPLdltOkZOZjb2PO1GFduKe99DuImiNJQoh6oLBYy/LfT7PneAoaYKBfC0b3b4eNlSwCJGqW\nJAkh6rjE1Bz++8tfpF4uoLWbHVOHdaSNu72xwxKNhCQJIeoopRTb/0zmx+1n0OoUQwNaMXpAO+mY\nFrVKkoQQdZBSip93xrFp3zmaWJvz6P2d6O5V/1eLE/WPJAkh6hi9UqzYeoath5Jwc7TmhQn34NTU\nythhiUbqlkni/PnzrFq1iujoaFJSUrC0tMTHx4ewsDCGDBmCmdmNdzF//nwiIyNxcnJi/fr1AFy5\ncoXnn3+e5ORkPD09+eCDD7Czk7VoReNUotVzObeIgkItCoVSsPNIMruOXqSFiy0vPOBL0yaWxg5T\nNGI3XXTo//7v/zhx4gRDhw7lnnvuwdnZmaKiImJjY9m9ezcnT57kn//8J76+vpWWP3jwILa2tsyd\nO9eQJN555x0cHBx47LHH+Pzzz8nOzuaFF16oUrAyBryULKhSpj7WRVpWAd/8GsP5S7lk55dU+prW\nbnbMmeBLE+uqj16qj3VRU6QuytTookODBg1iwYIFFbZ37NiR++67j6ysLM6fP3/D8v7+/iQnJ5fb\ntm3bNr777jsAIiIimDx5cpWThBD1XWzyFT76+Rg5+SW4OlrTwqUJjnaW2FiZYaLRoNGAjZU5g/xa\nyPBWUSfcNEkMGDDgpoUdHBxwcLi9hUsyMzNxdi7tgHNxceHy5cu3VV6I+upAzCWWbDiJTqeYPKQD\nIX6exg5JiFuq0li6f//73+Tk5KDVannwwQfx9fVl7dq1NR2bEPVadl4xO48k882vMfxr2YHS9R1M\nNDw7rrskCFFvVGl0U1RUFC+99BKRkZG4ubnx/vvvM2PGDEaNGnXbB3RyciI9PR1nZ2fS0tJo1qxZ\nlcve7bW1hkTqokxdrIviEh3zv9hHSkY+AGamJnRq04wnx/agjUfN3QhXF+vCWKQuqsdtDYE9cOAA\noaGhuLm5odFoqlTm+n7xgQMHsnr1ambMmMGaNWsYNGhQlY8vHVGlpFOuTF2ti1+jE0nJyCeoiztD\nerWkhYut4Sa4moq3rtaFMUhdlLnbZFmly01OTk688sorbNq0ib59+6LVatHpdLcsN2fOHCZMmEB8\nfDzBwcH8/PPPzJgxg6ioKMLCwti7dy8zZsy4qzcgRF2TW1DChqgEbK3MeDC0Pa3d7eQuaVFvVakl\n8e6777Ju3TrGjh1L06ZNSUpKYtq0aVUqV5lly5bdVpBC1CcbohLIL9LywEBvbGWEkqjnqpQkmjVr\nxsMPP2x47OnpiaendLwJcb20rAK2/5mEc1MrBkrntGgAbtoGfvLJJzl27Filz+Xm5vL111/z448/\n1khgQtRHa3bFodUpRvdvh7mZXGIS9d9NWxLPPPMM7777LgkJCXTv3h0nJyeKioqIi4sjOTmZCRMm\nMHHixNqKVYg6Sa9XHI1NZ+vBJE6du0xrdzsCZH1p0UDcNEn4+PjwxRdfcPHiRfbv309qaiqWlpYM\nHTqUnj17YmFhUVtxClEnnU26wpINJ7mUVQBAp9aOPBTaAZMqjv4Toq6rUp+Eh4fHHd0TIURDlltQ\nwqe/HCc7r4T+PTwY3LMlnq5NjB2WENWqShdNMzIyeOGFF3jooYcAiImJ4YcffqjRwISo677b8jdZ\nucWE92vLw8M6SYIQDVKVksQrr7xCz549yc7OBqBdu3Z8//33NRqYEHXZvpMp7D91Ce8WTRnWu5Wx\nwxGixlQpSaSmpjJx4kRMTU0BsLCwwMRERm6Ixikzu5DvfjuNpbkpj97fCVP5XxANWJXO7usXFsrO\nzq4w3YYQjUFqZj6frD5OfpGWCYO8cXW0MXZIQtSoKnVcDxkyhP/7v/8jLy+P1atX8/333zNmzJia\njk2IOkOr0/Pb/kTW7k5Aq9PTp6s7/Xs0N3ZYQtS4KiWJRx99lHXr1pGdnc3OnTuZPHmyjHYSjYJS\nimOxGfy8M46ktFzsbS14KLQD/h1dqjzJpRD1WZVngR05ciQjR46syViEqDZanf6uJtXT6vT8FZfJ\nuj3xJKSUziZ6b3cPmY9JNDpVShIZGRl89913JCYmotVqDds//PDDGgtMiDuVmJrDv5Yd4KmIbvh1\ncKlSGaUUv/wRz7HYDC7nFBrWntYAvXxcGdG3DZ4uMsRVND5VShJPPvkknTt3JigoyDDCSYi66kzS\nFZSCY7EZVU4S6/YksD4qAXMzE5rZWdLc2Ra3ZjYM6ukpyUE0alVKEgUFBbz22ms1HYsQ1eLS5dIp\nMs6lVm3RmX0nU1i7Ox7npla8MsUfe1uZbkaIq6qUJHr06MHff/9Nx44dazoeIe5a2v/mUUpOy63Q\nN/H7gfOsj0qgh7cTfbt6YGZmwpcbY7C2NOXZsd0lQQhxnSoliQkTJjBp0iTc3d2xtLQ0bF+1atUd\nH3jZsmWsWrUKjUZDhw4dePPNN2XCQFEtrk62p9UpLqTn0cqtbPnGvSdSyC0oYc/xFPYcTwFAo4Fn\nRvWghVxWEqKCKiWJF198kZkzZ9K5c+dq6ZNITU3l22+/ZfPmzVhYWPDcc8+xadMmwsPD73rfonHT\nK2VoSUDpJaerSaKwWEtiai5eze0ZG+zF7uMXOR6bQXj/dnRt52SskIWo06qUJCwtLXnkkUeq9cB6\nvZ6CggJMTEwoLCzE1dW1WvcvGqcrucWUaPU42VuRkV1IYmqu4bnYC9nolaJDSwc6tnKkYytHI0Yq\nRP1QpYHk/fr1Y9euXdV2UDc3N6ZNm0ZwcDD9+/fHzs6OPn36VNv+ReN16XI+APd0cMZEoynXeX3m\nfBYA7T0djBKbEPVRlVoSP/30E59//jm2trZYWFiglEKj0bB37947Omh2djbbtm1jx44d2NnZ8cwz\nz7B+/XpGjBhxR/sT4qqr/RGeLk3wcLLhfGoueqUw0Wg4k3QFAG/PpsYMUYh6pUpJ4ueff67Wg0ZF\nRdGyZUscHEp/0YWGhnL48OFbJgkXF7ubPt+YSF2UubYu8or1ALRv04xzl3JJPpRECRrcm9kSdzGb\nVu52tG3VzFih1jg5L8pIXVSPKiWJFi1aVOtBmzdvztGjRykqKsLCwoJ9+/bRrVu3W5ZLS6vauPeG\nzsXFTurif66vi4Tk0ktKFoCbgzUAh0+l4OZoQ1GxjnYe9g227uS8KCN1UeZuk+VNk8SLL77IO++8\nw5gxYyqdzOxOh8B2796dsLAwwsPDMTMzo3PnzowfP/6O9iXEtdKyCjAz1eBoZ0lrt9IhrYmpuVzJ\nLQagvVxqEuK23DRJXLp0CYB//OMf1X7gp59+mqeffrra9ysat7SsQpyaWmNioqGla+kvqHMpOVhb\nlp7qHaTTWojbctMkcXW50oCAgFoJRoi7kV+oJbeghLYe9gDYWJnh6mBNYmoOJiYamtlb4tTUyshR\nClG/yLqLosG4ehOd6//6IgBauduRV6glJ79EWhFC3IGbtiROnz5NUFBQhe13OwRWiJpwdfiri2NZ\nkmjt1oSDMaWXTaU/Qojbd9Mk0aZNGz7//PPaikWI25KTX4xdcdn6JldvpLu2JdH6mnmb2reUloQQ\nt+umScLCwqLah78KUR2KS3S8siQab08HZo0uHT6dVklL4uq8TbZWZjR3tq39QIWo526aJMzNZZlG\nUTfFJGaRk1/C4dNpnEvJobW7nWEdCZdrOqftbS3o29UdFwdrTGRNaiFu2007rn/66afaikOI23Is\nNt3w9+8HzwOlLQlHO0sszMvPVPzI/Z0ZeW/bWo1PiIZCRjeJekcpxbHYDKwtzWjh0oTok6mkXykg\nM7sIl2v6I4QQd0+ShKh3Lmbkk36lkC5tmzGqfzt0esWqyFgU4OIg90EIUZ0kSYh651hsBgA9vJwI\n6dkSWysz9p8qHebqKi0JIaqVJAlR71ztj+jazgkrSzMG+JaNwLt2ZJMQ4u5JkhD1Sn6hljNJV2jr\nYUdT29I10Qf6tcDUpHTkkquDjTHDE6LBkSQh6pWTCZno9Ipu16xJ3czeij5d3bG2NMPDSZKEENWp\nSutJCFFXGPojvJ3LbZ8ytCMPDGxvmO1VCFE95D9K1Bt6pTgel4G9jTmt3csvpGJqYoKNlTSMhahu\n8l8l6o0T8ZlcySumWzsnuXtaiFpitJZETk4OL7/8MmfOnMHExIRFixbRo0cPY4Uj6jCtTs+GqAQ2\nRJ1DAwR1dTd2SEI0GkZLEgsXLmTAgAF89NFHaLVaCgsLjRWKqMMupOexdONJ4i/m4GRvySPDO+PT\n2tHYYQnRaBglSeTm5nLw4EH+/e9/lwZhZkaTJk2MEYqoo/ILtazbE8+2Q0no9IqgLu48FNoBGyvp\nRhOiNhnlPy4pKQlHR0fmzZtHTEwMXbt25eWXX8bKSqZUaOwKirTsO5nK2j/iyM4vwbmpFRMHteee\nDi7GDk2IRkmjlFK1fdC//vqLBx54gBUrVtCtWzcWLlyInZ0dzzzzTG2HIuoApRSH/05j+8Hz7P3r\nIsUlOqwsTBk3qAPhA7wqzOoqhKg9RmlJuLu74+7uTrdupYvFhIWFsWTJkluWS0vLqenQ6gUXF7sG\nUxdZuUV8ufEUf8VnAuDmaE1QV3f6dW+Oo50lV7Lyb1q+IdXF3ZK6KCN1UcbFxe7WL7oJoyQJZ2dn\nPDw8iI+Pp23btuzbtw8vLy9jhCKM6PDpNL7aHENuQQld2zVjVN+2tGtuj0aGtwpRZxitF/CVV17h\nhRdeQKvV0rJlS958801jhSJqWVGxjhXbz7DzyAXMzUx4KLQDA/1aSHIQog4yWpLw8fHh559/Ntbh\nhZEkpGSzeN1JUjPz8XRpwuMjO9PCRUa2CVFXyXhCUSu0Oj2/7U/klz/i0ekVYQEtGd3fC3Mzuelf\niLpMkoSoUVeXGv1x+1lSMvNp2sSCR4d3pkvbZsYOTQhRBZIkRI1JTM1hVWQsf8VnotFAiF8LIvq1\no4m1ubFDE0JUkSQJUe0SU3NYtyeBP0+nAdCljSMPDGqPp/Q9CFHvSJIQ1SY9q4CVkbEciCldb9qr\nuT2j7m1Ll7bNZOSSEPWUJAlx1wqKtGzad47f9p9Hq9PT1sOO8H7t6CrJQYh6T5KEuGNKKfaeSGFl\nZCxXcotxtLNkbLAXgZ3dZL0HIRoISRLijpxLyWH576c5m3wFczMTRvZtw7DA1lhayDxLQjQkkiTE\nbSko0rJmVxzbDiWhgJ4dXHhgoDfODtbGDk0IUQMkSYgqOxabwbe/xZCRXYRbMxsmDelAlzZyv4MQ\nDZkkCXFLl3OK+GnHWaJPpmJqouH+Pm0Y0ac15mZyaUmIhk6ShLghrU7P1oNJrN0TT1GxjrYe9kwb\n5oOnq9zvIERjIUlCVKCU4sjZdH7eGceF9DyaWJszYag3/Xo0l1FLQjQykiREOafOXWb1zlhiL2Sj\n0cAA3+aMGeAlU2kI0UhJkhBA6Qpxy7ec5tD/ptLo2cGF8P7taOFsa+TIhBDGJEmikVNK8cexi/y4\n/SwFRVraezZlwqD2tPWwN3ZoQog6wKhJQq/XM2bMGNzc3Pjss8+MGUqjdC4lhxXbzvD3+SysLEyZ\nHNaRAb7S7yCEKGPUJPHNN9/g5eVFbm6uMcNodDKzC1mzK46ov1JQgK+3M5OGdKCZvZWxQxNC1DFG\nSxIpKSns3LmTmTNn8tVXXxkrjEalqETHb9GJbNp3jmKtHk+XJjwwyFtuiBNC3JDRksSiRYuYO3cu\nOTk5xgqh0VBKcSDmEj/tOEtmdhFNbS14KLQdfbt5YGIil5aEEDdmlCQRGRmJs7MznTp1Ijo6usrl\nXFzsajCq+qUqdaGUYv+JFFZsPc3Z81mYmZowdmB7xg1qj41VwxnSKudFGamLMlIX1UOjlFK1fdD3\n3nuPdevWYWpqSlFREXl5eYSGhvL222/ftFxamrQ6oPTkv1VdHDmTzpo/4jh/KRcN4O/jypgB7XB1\ntKmdIGtJVeqisZC6KCN1UeZuk6VRksS19u/fz5dfflml0U3yoZe62T/A5Zwilv9+mj9Pp6HRQGAn\nN4b3adNg73eQL4MyUhdlpC7K3G2SkPskGgi9Uuw8coFVkWcpKNLRoaUDU8I60ryBJgchRO0wepII\nCAggICDA2GHUa38nXuaHbWdITM3F2tKMqUM7yjxLQohqYfQkIe7cpawCVu44y6G/S6fSCOrixthg\nbxztLI0cmRCioZAkUQ/lF5awKjKWLQcS0eoUXi3smTioA+2ay1QaQojqJUmiHlFKEfVXCqt3xXE5\npwhHO0vGhXgR2MkNjVxaEkLUAEkS9URyeh7f/vY3p89nYWFuyqh72zI0sBWW5rI6nBCi5kiSqOOK\ninWsj0rgt/2J6PSKe9o78/QD96DR6owdmhCiEZAkUUcppYg+mcrKyFgu5xThZG/FQ6Ed8G3vjIuj\njYwBF0LUCkkSddC5lByW/36as8lXMDM14f4+bRjeuzWWFnJpSQhRuyRJ1CG5BSWs3hXHzsPJKKBn\nRxceCPHG2cHa2KEJIRopSRJ1gFanZ/exi6zeFUduQQkeTjY8FNqBzjKFtxDCyCRJGJFWpyfqrxQ2\nRCWQfqUQSwtTxod4M9jfEzNTE2OHJ4QQkiSM5fCZNH7Yeob0K4WYmZowuKcn9wW1xqGJ3C0thKg7\nJEnUsit5xXz/+2kOxFzC1ETDoJ6e3Ne7tUylIYSokyRJ1BK9XvHHsQusiowlr1CLVwt7Hh7WqcFO\n4S2EaBgkSdSC0+ez+H7raRJTc7G0MOWh0A6E+LWQWVqFEHWeJIkalJZVwM87Y9l/6hIAQV3cGRvs\nJZeWhBD1hiSJGpBXWMLGqHNsPXQerU7R1sOeBwe3x6tFU2OHJoQQt8UoSSIlJYW5c+eSnp6Oqakp\n48aNY8qUKcYIpVoppfjj2EVW7jhLXqEWJ3tLxgzwIqCzm1xaEkLUS0ZJEqampsybN49OnTqRl5fH\n6NGj6du3L15eXsYIp1qkXs7n680xxCRmYWVhyrhgLwb7e2JuJlNpCCHqL6MkCRcXF1xcXACwtbXF\ny8uLS5cu1cskUVCk5fcD59m47xwlWj2+3s5MGtKBZvZWxg5NCCHumtH7JJKSkoiJiaF79+7GDuW2\nlGh17Dh8gY17E8jJL8He1oJH7++Af0cXWQBICNFgaJRSylgHz8vLY/LkyTz55JMMHjzYWGHclqIS\nHb9Hn+PnHWdJzyrA2tKMiGBvRvVvh42VubHDE0KIamW0JKHVann88cfp378/U6dOrVIZY66hUFyi\nY/ufyfy6P5HsvGIszEwI8WvBfb1bY2djUauxuLjYyXoS/yN1UUbqoozURRkXF7u7Km+0y03z58/H\n29u7ygnCmOIvZrNkw0kuZuRjZWHK8KDWhPq3xN62dpODEELUNqMkiUOHDrF+/Xo6dOhAeHg4Go2G\n559/nv79+xsjnBvS6vSs25PApr3n0CvF4J6ejOrXFlu5rCSEaCSMkiR69uzJqVOnjHHoKjubdIVv\nfoshKS0PJ3srpg/vRKfWjsYOSwghapXRRzfVNbkFJayKPMuuoxcB6N+jOQ8M9MbaUqpKCNH4yDff\nNf48ncayzTHkFpTg6WLL5LCOtPd0MHZYQghhNJIkKL0h7oetZ9h9/CLmZiaMC/Ei1L+lrA4nhGj0\nGn2SOJmQybLNMaRfKaSVWxMeG9FF1ngQQoj/abRJIuNKIT9uP8PBv9PQAMODWjPq3rbSehBCiGs0\nuiRRotXz6/5ENkYlUKzV49XCnkmhHWntfnc3nAghREPUqJLEsdh0vt96hkuXC7C3tWBymBdBXd1l\nGm8hhLiBRpEk0rIKWLHtDIfPpGOi0RDq35JR97bFxqpRvH0hhLhjDfpbsqhEx+Z959i0LxGtTk+H\nlg5MCu2Ap2sTY4cmhBD1QoNMEkopDv2dxo/bz5KRXYhDEwvGD/QmsJObTOMthBC3ocEliYSUbFZs\nPcPppCuYmmgYFtiK+/u0kTumhRDiDjSYb84recX8HBnLnuMXUcA97Z0ZH+KNWzMbY4cmhBD1Vr1P\nElqdnu2Hkli7J56CIh2eLrZMGNSezm2aGTs0IYSo9+p1kvg78TLfbjnNhfQ8bK3MmDSkAwN8m2Nq\nIjfECSFEdaiXSSK3oISfdpxl97GLaIBg3+ZE9G9X6yvECSFEQ2e0JLFr1y4WLVqEUooxY8YwY8aM\nW5ZRShF9MpXvt54ht6CElq5NmDrUh3bN7WshYiGEaHyMkiT0ej2vv/46y5Ytw9XVlbFjxzJo0CC8\nvLxuWCbjSgEfrTrG0dgMLMxNGB/iTWgvT7m0JIQQNcgoSeLYsWO0bt2aFi1aADB8+HC2bdt20yTx\n1NvbySvU0qm1I1OH+eDqYF1b4QohRKNllCSRmpqKh4eH4bGbmxvHjx+/aRm9gqlDO9K/R3O5IU4I\nIWqJUZKEUuq2yyx9JZTCvKIaiEYIIcSNGCVJuLu7c+HCBcPj1NRUXF1db1rGzsZCRi9dw8VFpja/\nSuqijNRFGamL6mGUXt9u3bqRmJhIcnIyxcXFbNy4kUGDBhkjFCGEEDdhlJaEqakpr776KtOnT0cp\nxdixY2/aaS2EEMI4NOpOOgiEEEI0CnKTgRBCiBuSJCGEEOKGJEkIIYS4oTqfJHbt2sXQoUMJCwvj\n888/N3Y4tSolJYUpU6Zw3333MWLECL755hsArly5wvTp0wkLC+ORRx4hJyfHyJHWHr1eT0REBDNn\nzgQgKSmJ8ePHExYWxuzZs9FqtUaOsHbk5OTwzDPPMGzYMIYPH87Ro0cb7XmxbNky7r//fkaMGMGc\nOXMoLi5uNOfF/Pnz6dOnDyNGjDBsu9l58MYbbzBkyBBGjRrFqVOnqnSMOp0krs7xtHTpUjZs2MDG\njRuJjY01dli1xtTUlHnz5rFp0yZWrFjB8uXLiY2N5fPPPycoKIjffvuNwMBAFi9ebOxQa80333xT\nbiTcf/7zH6ZNm8Zvv/2GnZ0dq1atMmJ0tWfhwoUMGDCAzZs3s3btWtq1a9coz4vU1FS+/fZbVq9e\nzfr169HpdGzcuLHRnBejR49m6dKl5bbd6DzYuXMniYmJbNmyhQULFvDaa69V6Rh1OklcO8eTubm5\nYY6nxsLFxYVOnToBYGtri5eXF6mpqWzbto2IiAgAIiIi2Lp1qzHDrDUpKSns3LmTcePGGbbt27eP\nsLAwoLQufv/9d2OFV2tyc3M5ePAgY8aMAcDMzAw7O7tGe17o9XoKCgrQarUUFhbi6upKdHR0ozgv\n/P39sbcvPwv29efB1e/Mbdu2ER4eDkCPHj3IyckhPT39lseo00misjmeLl26ZMSIjCcpKYmYmBh6\n9OhBRkYGzs7OQGkiuXz5spGjqx2LFi1i7ty5hrm7Ll++TNOmTTH530zA7u7ujeL8SEpKwtHRkXnz\n5hEREcGrr75KQUFBozwv3NzcmDZtGsHBwfTv3x87Ozs6d+6Mvb19ozsvrsrMzCx3HmRmZgJw6dIl\n3N3dDa9zc3MjNTX1lvur00lCbuEolZeXxzPPPMP8+fOxtbVtlBMcRkZG4uzsTKdOnQznhVKqwjnS\nGOpGq9Vy8uRJHnzwQdasWYO1tTWff/55o3jv18vOzmbbtm3s2LGDP/74g4KCAnbt2lXhdY2xbq5X\n2fdpVeqlTq9MdydzPDU0Wq2WZ555hlGjRjF48GAAnJycSE9Px9nZmbS0NJo1a/jref/5559s376d\nnTt3UlRURF5eHosWLSInJwe9Xo+JiQkpKSmN4vxwd3fH3d2dbt26ATBkyBC++OKLRnleREVF0bJl\nSxwcHAAYPHgwhw8fJjs7u9GdF1fd6Dxwc3MjJSXF8Lqq1kudbknIHE+loxe8vb2ZOnWqYdvAgQNZ\nvXo1AGvWrGkUdTJ79mwiIyPZtm0b7733HoGBgfznP/8hMDCQX3/9FWg8deHs7IyHhwfx8fFAab+M\nt7d3ozwvmjdvztGjRykqKkIpxb59+2jfvn2jOi+ubyHc6DwYNGgQv/zyCwBHjhzB3t7ecFnqZur8\ntBy7du1i4cKFhjmeqrLMaUNx6NAhJk2aRIcOHdBoNGg0Gp5//nm6d+/Oc889x8WLF2nevDkffvhh\nhc6rhmz//v18+eWXfPbZZ5w/f57Zs2eTnZ1Np06deOeddzA3Nzd2iDUuJiaGl19+Ga1WS8uWLXnz\nzTfR6XSN8rz45JNP2LhxI2ZmZnTu3Jk33niDlJSURnFezJkzh+joaLKysnB2dmbWrFkMHjyYZ599\nttLzYMGCBfzxxx9YW1vz5ptv0qVLl1seo84nCSGEEMZTpy83CSGEMC5JEkIIIW5IkoQQQogbkiQh\nhBDihiRJCCGEuCFJEkIIIW5IkoSocwYOHMjZs2dr5ViffPJJuWmk582bx/Lly+96v/PmzWPEiBHM\nnj37rvd1MzExMWzevLlGjyEaN0kSolH75JNPKCkpqdZ9pqens2XLFtavX897771Xrfu+3smTJ+84\nSej1+mqORjREkiREvREfH89jjz3GuHHjCA8PN0w9AODj48PixYsZO3YsoaGhbNmyxfDcb7/9xrBh\nwxg9ejSLFy/Gx8eHgoICFixYgEajYcKECURERJCbmwvA6dOnmTp1KmFhYbz00ks3jOeXX35hxIgR\njBo1ilmzZpGZmUleXh5Tp06lqKiIiIgIvv7663Jl1q5dy9NPP214rNPp6Nevn2GOsiVLljB+/HhG\njx7NE088QUZGBgAlJSW89dZbjBgxgvDwcGbNmkVWVhYff/wx+/btIyIigoULFwKlsxREREQwatQo\npk2bxvnz54HSO9XDw8N54403mDBhAn/88cfdfByisVBC1DEhISHqzJkz5bZptVoVERGh4uLilFJK\n5ebmqrCwMMPjjh07quXLlyullDp06JDq16+fUkqp9PR0FRAQoBITE5VSSn311VfKx8dH5efnG8oV\nFGMkQHAAAAOxSURBVBQYjvPSSy+pBx98UBUXF6vi4mI1fPhwFRUVVSHG06dPq3vvvVelp6crpZT6\n4IMP1HPPPaeUUiopKUn17t270vdWUFCgevfurS5fvqyUUmr79u1q6tSpSiml1q5dq1599VXDa7//\n/ns1Z84cpZRSH3/8sZo1a5bSarVKKWUov3r1avXMM88YymRkZKjevXur2NhYpZRSK1euVOPGjVNK\nKRUdHa06d+6sjh49WmlsQlRGWhKiXkhISCAuLo7Zs2cTHh7OQw89RElJSbmVCu+77z4AfH19SUtL\no7i4mKNHj9K1a1datmwJwNixYyvsW103M83gwYMxNzfH3Nyczp07k5iYWKFMdHQ0wcHBODk5ATBh\nwgSioqJu+T6srKwYNGgQGzZsAEonYLu6eND27dvZu3cv4eHhhIeH8/3333Px4kWgdKr0KVOmYGpq\nCmCY9fR6R48epVOnTrRr1w6AMWPGcOrUKfLz8wFo3bo13bt3v2WcQlxVp6cKF+IqpRTNmjVjzZo1\nlT6v0WiwtLQEMCw2o9PpKiSA6x9XxsLCwvC3qalppesjK6UqzMV/9bi3Eh4ezptvvsn999/P/v37\neeeddwz7fOKJJxg9enSlx6uKyuK69rGNjU2V9iPEVdKSEPVC27ZtsbKyYu3atYZtcXFx5OXlARW/\nRK8+9vX15cSJE4br8tf2YwA0adKk3ELxVRUUFMTOnTsNfQY//vgjffr0qXD8yvj7+5Obm8t7771H\naGioIbkNHDiQ77//nuzsbACKi4uJiYkBICQkhG+++cbQyX511bkmTZoY+lKuvt9Tp04ZphFfvXo1\nnTt3luQg7pi0JESdo9FoePjhhzEzMzP8Ml6/fj2fffYZCxcu5Msvv0Sn0+Hs7MwHH3xgKHP9PqB0\nAZZ//etfzJgxA0dHR4KDgzEzM8Pa2hqAadOmMWXKFKytrfn222+rHKO3tzezZ8/m4YcfxsTEhJYt\nW7JgwYIKx7+R8PBwPvroo/9v5w5xGASiIAwPBoMhHADNBRCcgtUEzQWQSByChAOgSHB4joVBLqlo\ngnumadK0/T/51LrZyeat1nW9Z2VZ6jgO1XWtIAh0XZeqqlKWZWqaRuM4yjmnMAyVpqmmaVJRFJrn\nWc455Xmurus0DIPatpX3XkmS3E0FeAVfhePnneepKIokPW/W27a9ZRcC+Ac0Cfy8ZVm077u894rj\nWH3ff/pIwNegSQAATDxcAwBMhAQAwERIAABMhAQAwERIAABMhAQAwPQAVSnSA55bZkwAAAAASUVO\nRK5CYII=\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7f47b8e3bd90\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(eager_means)\n", + "plt.ylabel('Time(s)')\n", + "plt.xlabel('Length of vector')\n", + "_ = plt.title('Time to sum the elements of 1000 vectors (Eager)')\n", + "_ = plt.ylim(ymin=0)" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "Autograph vs. Eager vs Graph sum", + "provenance": [ + { + "file_id": "1olZkm32B7n7pQwlIAXR0_w8fZhRHCtkX", + "timestamp": 1531755808890 + } + ], + "version": "0.3.2", + "views": {} + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb index 324b23c24b5a7970d7f20ed955839ba1cf1774fc..44532cb078f9bd1578172f8a7d8a4b55cd21a7cb 100644 --- a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb +++ b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb @@ -190,7 +190,6 @@ " self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128)\n", " self.relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)\n", "\n", - "\n", " def _rnn_layer(self, chars, cell, batch_size, training):\n", " \"\"\"A single RNN layer.\n", "\n", @@ -203,13 +202,12 @@ " Returns:\n", " A Tensor of shape (max_sequence_length, batch_size, output_size).\n", " \"\"\"\n", - " hidden_outputs = []\n", - " autograph.utils.set_element_type(hidden_outputs, tf.float32)\n", + " hidden_outputs = tf.TensorArray(tf.float32, 0, True)\n", " state, output = cell.zero_state(batch_size, tf.float32)\n", " for ch in chars:\n", " cell_output, (state, output) = cell.call(ch, (state, output))\n", " hidden_outputs.append(cell_output)\n", - " hidden_outputs = hidden_outputs.stack()\n", + " hidden_outputs = autograph.stack(hidden_outputs)\n", " if training:\n", " hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n", " return hidden_outputs\n", @@ -223,7 +221,7 @@ "\n", "\n", " def call(self, inputs, training=False):\n", - " \"\"\"The RNN model code. Uses Eager and \n", + " \"\"\"The RNN model code. Uses Eager.\n", "\n", " The model consists of two RNN layers (made by lower_cell and upper_cell),\n", " followed by a fully connected layer with ReLU activation.\n", @@ -243,7 +241,8 @@ " seq = self._rnn_layer(seq, self.upper_cell, batch_size, training)\n", "\n", " # Grab just the end-of-sequence from each output.\n", - " indices = tf.stack([length - 1, range(batch_size)], axis=1)\n", + " indices = (length - 1, range(batch_size))\n", + " indices = tf.stack(indices, 1)\n", " sequence_ends = tf.gather_nd(seq, indices)\n", " return self.relu_layer(sequence_ends)\n", "\n", @@ -381,7 +380,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 107, "metadata": { "colab": { "autoexec": { @@ -392,9 +391,9 @@ }, "colab_type": "code", "executionInfo": { - "elapsed": 10604, + "elapsed": 5454, "status": "ok", - "timestamp": 1524095272039, + "timestamp": 1529952160455, "user": { "displayName": "", "photoUrl": "", @@ -403,7 +402,7 @@ "user_tz": 240 }, "id": "2pg1AfbxBJQq", - "outputId": "9c924b4f-06e1-4538-976c-a3e1ddac5660", + "outputId": "4aef3052-f7c7-4bb1-a0a2-73fef2e96efb", "slideshow": { "slide_type": "-" } @@ -413,7 +412,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Eval loss at step 100: 0.0674834\n" + "Eval loss at step 100: 0.0705221\n" ] } ], @@ -423,8 +422,8 @@ " 'learning_rate': 0.01,\n", "}\n", "\n", - "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv\"\n", - "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv\"\n", + "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/train.csv\"\n", + "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/test.csv\"\n", "data_dir = \"tmp/rnn/data\"\n", "\n", "regressor = tf.estimator.Estimator(\n", @@ -457,7 +456,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 108, "metadata": { "colab": { "autoexec": { @@ -468,9 +467,9 @@ }, "colab_type": "code", "executionInfo": { - "elapsed": 7990, + "elapsed": 3432, "status": "ok", - "timestamp": 1524095280105, + "timestamp": 1529952163923, "user": { "displayName": "", "photoUrl": "", @@ -479,7 +478,7 @@ "user_tz": 240 }, "id": "dxHex2tUN_10", - "outputId": "2b889e5a-b9ed-4645-bf03-d98f26c72101", + "outputId": "1ff438f2-b045-4f4e-86a0-4dae7503f6b2", "slideshow": { "slide_type": "slide" } @@ -491,12 +490,12 @@ "\u003clink rel=stylesheet type=text/css href='/nbextensions/google.colab/tabbar.css'\u003e\u003c/link\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f3f36aa6cd0\u003e" + "\u003cIPython.core.display.HTML at 0x7fcd7222a110\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -507,12 +506,12 @@ "\u003cscript src='/nbextensions/google.colab/tabbar_main.min.js'\u003e\u003c/script\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f3eca67f7d0\u003e" + "\u003cIPython.core.display.HTML at 0x7fcd7222a8d0\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -520,15 +519,15 @@ { "data": { "text/html": [ - "\u003cdiv id=\"id1\"\u003e\u003c/div\u003e" + "\u003cdiv id=\"id3\"\u003e\u003c/div\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f3eca67f8d0\u003e" + "\u003cIPython.core.display.HTML at 0x7fcd7222a050\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -536,16 +535,16 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa22-4362-11e8-91ec-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"elementId\": \"id1\", \"borderColor\": [\"#a7a7a7\"], \"contentHeight\": [\"initial\"], \"tabNames\": [\"RNN Colorbot\"], \"location\": \"top\", \"initialSelection\": 0});\n", - "//# sourceURL=js_71b9087b6d" + "window[\"8a03307e-78a7-11e8-99f9-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"elementId\": \"id3\", \"contentHeight\": [\"initial\"], \"tabNames\": [\"RNN Colorbot\"], \"location\": \"top\", \"initialSelection\": 0, \"borderColor\": [\"#a7a7a7\"]});\n", + "//# sourceURL=js_dc5d7f2784" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67f950\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222a190\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -553,16 +552,16 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa23-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_e390445f33" + "window[\"8a03307f-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_be7950150b" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67f990\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222ac90\u003e" ] }, "metadata": { "tags": [ - "outputarea_id1" + "outputarea_id3" ] }, "output_type": "display_data" @@ -570,17 +569,17 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa24-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_241dd76d85" + "window[\"8a033080-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_d0c3bd4eaa" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fc50\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222aad0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -588,17 +587,17 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa25-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_60c64e3d50" + "window[\"8a033081-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n", + "//# sourceURL=js_f10f6eba86" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fd90\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222aed0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -606,17 +605,17 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa26-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"e8ddfa25-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_14ea437cbd" + "window[\"8a033082-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8a033081-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_ff29697179" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fe10\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222abd0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -624,17 +623,17 @@ { "data": { "application/javascript": [ - "window[\"e8ddfa27-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_09294c2226" + "window[\"8a033083-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_ff85295dc7" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fcd0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222ab90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -642,17 +641,17 @@ { "data": { "application/javascript": [ - "window[\"ec965514-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"e8ddfa24-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_e5e8266997" + "window[\"8b18d8dc-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8a033080-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_ed7aabfedb" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fe10\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222a110\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -660,17 +659,17 @@ { "data": { "application/javascript": [ - "window[\"ec965515-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_07a097f0ee" + "window[\"8b18d8dd-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_c86f8feaf4" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fc90\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222acd0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -678,17 +677,17 @@ { "data": { "application/javascript": [ - "window[\"ec965516-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_790d669ca8" + "window[\"8b18d8de-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n", + "//# sourceURL=js_4d0fde6662" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67f8d0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222ae50\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -696,17 +695,17 @@ { "data": { "application/javascript": [ - "window[\"ec965517-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec965516-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_d30df771f0" + "window[\"8b18d8df-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8de-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_3f66d52720" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fd90\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222a210\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -714,32 +713,32 @@ { "data": { "application/javascript": [ - "window[\"ec965518-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_8a43a2da4b" + "window[\"8b18d8e0-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_375f5ae6d7" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fc50\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd7222a310\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" }, { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACMBJREFUeJzt3F+I1XX+x/G32zjiFERUpgaFd2JBzOg5joX4h0SiMgmM\n/uhVGIlgFBlERGB3hUEkhkRdtDfRP1ACL6KpLBqcguxCjEAkmGamQcSohFHzsxe7O6zssvsydtff\n+ns8rs758j3f8z7fiyef7/k3o7XWCiDwh4s9APC/QzCAmGAAMcEAYoIBxAQDiAkGF8XTTz9d3W63\n7rvvvhoZGakVK1Zc7JEICMYlbvXq1TU8PHyxxzjPV199VcPDw/XZZ5/V22+/XVVVM2bMuMhTkRAM\n/qt+++23+uGHH+r666+vWbNmXexxuECCcQl76qmnanx8vLZs2VIDAwP1+uuv1zfffFP3339/dTqd\nWr9+fY2MjEzvv2nTpnr55ZfrgQceqIGBgXr44Yfr5MmTVVV1+vTp2r59ey1durQ6nU5t2LChTpw4\nUVVVk5OTtWXLllq6dGmtXbu23nnnnelj7tq1q7Zt21bbt2+vJUuW1HvvvVfPPvtsHTp0qAYGBmrX\nrl1/N/fRo0dr06ZN1el06u67766hoaGqqhodHa1OpzO93zPPPFO33nrr9P3t27fXm2+++e89iZyv\ncUlbtWpVGx4ebq21NjEx0brdbjtw4EBrrbUvvviidbvdduLEidZaaxs3bmxr1qxp33//fZuammob\nN25sO3fubK219tZbb7VHH320TU1NtXPnzrXDhw+3X375pbXW2kMPPdR27NjRTp8+3Y4cOdIGBwen\nn/OVV15pN910U/voo49aa61NTU21999/vz344IPTMx48eLCtWLGitdbamTNn2po1a9qePXvamTNn\n2vDwcOvv72/Hjh2bfj2HDx9urbW2du3advvtt7ejR4+21lpbuXJlO3LkyH/qVNJas8L4f6D95edC\n+/btq5UrV9by5curqmrZsmV1880316effjq977333ls33HBD9fb21h133FFHjhypqqqenp46efJk\nHTt2rGbMmFGLFi2qyy+/vCYmJurrr7+uJ598smbOnFkLFy6sDRs21N69e6eP2d/fX6tXr66qqt7e\n3n8666FDh+rUqVP1yCOPVE9PTw0ODtaqVavqgw8+qKqqJUuW1MjISB0/fryqqtauXVtffvlljY6O\n1q+//loLFy78N501/pGeiz0A/z1jY2O1f//++vjjj6vqzyE5e/ZsLVu2bHqfa665Zvr27Nmz69Sp\nU1VVdc8999TExEQ98cQT9fPPP9e6devq8ccfr8nJybryyitr9uzZ04+bP39+HT58ePr+3Llz4xkn\nJydr3rx5522bP39+TU5OVlVVp9OpoaGhuu6666rb7Va32629e/dWb29vLV68+ALOBr+HYFzi/vbT\nh3nz5tX69etrx44dF3ycnp6e2rp1a23durXGxsZq8+bNtWDBgrrtttvqp59+qlOnTlVfX19VVY2P\nj9ecOXP+4Qz/ypw5c2p8fPy8bWNjY7VgwYKqqup2u/Xiiy/WvHnzqtPp1MDAQD333HPV29tb3W73\ngl8XF8YlySXu2muvrdHR0aqqWrduXQ0NDdXnn39e586dq6mpqRoZGakff/zxXx7n4MGD9d1339W5\nc+eqr6+venp66rLLLqu5c+dWf39/vfTSS3X69On69ttv6913361169b9rnlvueWW6uvrq9dee63O\nnj1bBw8erE8++aTuvPPOqqq68cYba9asWbVv377qdDp1xRVX1NVXX10ffvjheW+I8p8hGJe4zZs3\n1+7du6vb7db+/ftr9+7dtWfPnlq2bFmtWrWq3njjjen3OP7ZSuD48eO1bdu2Wrx4cd111121dOnS\n6Sjs3LmzRkdHa/ny5bVt27Z67LHHzrvMuRAzZ86sV199tQ4cOFCDg4P1/PPP1wsvvDC9wqj68yrj\nqquumr7U+WsoFi1a9Luek9yM1vyBDpCxwgBiggHEBAOICQYQ+z/7PYzjf/QRGVxM12z68u+2WWEA\nMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHE\nBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhAT\nDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEww\ngJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOI\nCQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAm\nGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhg\nADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIB\nxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQ\nEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBM\nMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHB\nAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQD\niAkGEBMMIDajtdYu9hDA/wYrDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4j9CY2LTAbbRbWuAAAAAElFTkSuQmCC\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAABTFJREFUeJzt3C+LV30eh/HP6EZvbP4ZJmkXDA6oQdZRMIhYLIKCMGVA\nyyaLT2ERLMqEDfoUFA2y3WpRrOKoSUSECePcYUEWdsN1OzfOyr5e8ZwT3unie34cfgvb29vbAxDs\n2e0BwK9DMIBMMIBMMIBMMIBMMIBMMPipXrx4MWfOnNntGfwgweCnW1hY2O0J/CDBYEe2trZ2ewI/\nkWDwh509e3bW19fn0qVLc/z48dnY2Jhbt27NyZMn59y5c/Pw4cPvz25ubs7t27dneXl5Ll68OC9f\nvtzF5ezUX3Z7AL+mJ0+ezPr6+uzfv3+uXr0658+fn7t3787GxsbcuHFjjhw5MqdPn5579+7N27dv\n5/nz5/P169dZXV3d7ensgBMGP+T69etz8ODBef369Xz69GnW1tZm7969s7S0NFeuXJnHjx/PzMzT\np09nbW1tfvvttzl48OBcu3Ztl5ezE04Y/JBDhw7NzMy7d+/mw4cPs7y8PDMz29vb8+3btzlx4sTM\nzHz8+PH7szMzi4uLP38sfxrBYEcOHz48S0tL8+zZs/96/8CBA7OxsTFHjx6dmX8Fhl+XVxJ25Nix\nY7Nv375ZX1+fzc3N2dramjdv3nz/cfPChQvz4MGD+fz587x//34ePXq0y4vZCcHgD/v37yj27Nkz\n9+/fn1evXs3KysqcOnVq7ty5M1++fJmZmZs3b87i4uKsrKzM6urqXL58ebdm8ydY8Ac6QOWEAWSC\nAWSCAWSCAWT/s99h/P3GX3d7Avxf+9s//vkf15wwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgGxhe3t7e7dHAL8GJwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwg\nEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwg+x1QoZHG4XIe4gAAAABJRU5ErkJggg==\n", "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0x7f3ecc00bf10\u003e" + "\u003cmatplotlib.figure.Figure at 0x7fcd0d02dc90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -748,17 +747,17 @@ { "data": { "application/javascript": [ - "window[\"ec965519-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec965515-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_893ad561f4" + "window[\"8b18d8e1-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8dd-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_34b0509660" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55c90\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e850\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -766,17 +765,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551a-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_2d99e0ac17" + "window[\"8b18d8e2-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_518a0f26fe" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67fe50\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6ec90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -784,17 +783,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551b-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_5c19462e32" + "window[\"8b18d8e3-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n", + "//# sourceURL=js_17eb3ff612" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55dd0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6eb50\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -802,17 +801,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551c-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec96551b-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_b9c8b7567b" + "window[\"8b18d8e4-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8e3-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_99da807c8e" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55a50\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6eb90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -820,17 +819,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551d-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_fd05186348" + "window[\"8b18d8e5-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_dee01cb4b6" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55810\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e610\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -838,16 +837,16 @@ { "data": { "text/html": [ - "\u003cdiv class=id_888646481 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e" + "\u003cdiv class=id_853612217 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f3f32414810\u003e" + "\u003cIPython.core.display.HTML at 0x7fcd7222aa10\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -856,17 +855,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551e-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 span\");\n", - "//# sourceURL=js_efef96e882" + "window[\"8b18d8e6-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 span\");\n", + "//# sourceURL=js_8c378be329" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55710\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e990\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -875,17 +874,17 @@ { "data": { "application/javascript": [ - "window[\"ec96551f-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ec96551e-4362-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", - "//# sourceURL=js_6eca889864" + "window[\"8b18d8e7-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8e6-78a7-11e8-99f9-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", + "//# sourceURL=js_f0b946600c" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3eca67f990\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e310\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -894,17 +893,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea972-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 input\");\n", - "//# sourceURL=js_f02070cc60" + "window[\"8b18d8e9-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 input\");\n", + "//# sourceURL=js_9e21b1373a" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b553d0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6ea90\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -913,17 +912,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea973-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ed8ea972-4362-11e8-91ec-c8d3ffb5fbe0\"].remove();\n", - "//# sourceURL=js_ed9faba660" + "window[\"8b18d8ea-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8e9-78a7-11e8-99f9-c8d3ffb5fbe0\"].remove();\n", + "//# sourceURL=js_a7764968c6" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31a95450\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e5d0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -932,17 +931,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea974-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 span\");\n", - "//# sourceURL=js_f3458d7074" + "window[\"8b18d8eb-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 span\");\n", + "//# sourceURL=js_74279d3ff0" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31a95250\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e890\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -951,17 +950,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea975-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ed8ea974-4362-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", - "//# sourceURL=js_3ffd97bd6f" + "window[\"8b18d8ec-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8eb-78a7-11e8-99f9-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", + "//# sourceURL=js_82b6c34cdb" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31a953d0\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e8d0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1", + "id3_content_0", + "outputarea_id3", "user_output" ] }, @@ -970,17 +969,17 @@ { "data": { "application/javascript": [ - "window[\"ed8ea976-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec96551a-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_7f73e8bcca" + "window[\"8b18d8ed-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8e2-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_ff6144734a" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f3f31b55710\u003e" + "\u003cIPython.core.display.Javascript at 0x7fcd08e6e8d0\u003e" ] }, "metadata": { "tags": [ - "id1_content_0", - "outputarea_id1" + "id3_content_0", + "outputarea_id3" ] }, "output_type": "display_data" @@ -1043,28 +1042,6 @@ "kind": "local" }, "name": "RNN Colorbot using Keras and Estimators", - "provenance": [ - { - "file_id": "1CtzefX39ffFibX_BqE6cRbT0UW_DdVKl", - "timestamp": 1523579810961 - }, - { - "file_id": "1DcfimonWU11tmyivKBGVrbpAl3BIOaRG", - "timestamp": 1523016192637 - }, - { - "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K", - "timestamp": 1522238054357 - }, - { - "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ", - "timestamp": 1521743157199 - }, - { - "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-", - "timestamp": 1520522344607 - } - ], "version": "0.3.2", "views": {} }, diff --git a/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb b/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..e7dfb13e15a8c30fd905f0ed9db9f0f67d9b6e88 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb @@ -0,0 +1,1129 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "u3B7Uh50lozN" + }, + "outputs": [], + "source": [ + "!pip install -U -q tf-nightly" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "qWUV0FYjDSKj" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "from tensorflow.contrib import autograph\n", + "\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "kGXS3UWBBNoc" + }, + "source": [ + "# 1. AutoGraph writes graph code for you\n", + "\n", + "[AutoGraph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/README.md) helps you write complicated graph code using just plain Python -- behind the scenes, AutoGraph automatically transforms your code into the equivalent TF graph code. We support a large chunk of the Python language, which is growing. [Please see this document for what we currently support, and what we're working on](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/LIMITATIONS.md).\n", + "\n", + "Here's a quick example of how it works:\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "aA3gOodCBkOw" + }, + "outputs": [], + "source": [ + "# Autograph can convert functions like this...\n", + "def g(x):\n", + " if x \u003e 0:\n", + " x = x * x\n", + " else:\n", + " x = 0.0\n", + " return x\n", + "\n", + "# ...into graph-building functions like this:\n", + "def tf_g(x):\n", + " with tf.name_scope('g'):\n", + "\n", + " def if_true():\n", + " with tf.name_scope('if_true'):\n", + " x_1, = x,\n", + " x_1 = x_1 * x_1\n", + " return x_1,\n", + "\n", + " def if_false():\n", + " with tf.name_scope('if_false'):\n", + " x_1, = x,\n", + " x_1 = 0.0\n", + " return x_1,\n", + "\n", + " x = autograph_utils.run_cond(tf.greater(x, 0), if_true, if_false)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "I1RtBvoKBxq5" + }, + "outputs": [], + "source": [ + "# You can run your plain-Python code in graph mode,\n", + "# and get the same results out, but with all the benfits of graphs:\n", + "print('Original value: %2.2f' % g(9.0))\n", + "\n", + "# Generate a graph-version of g and call it:\n", + "tf_g = autograph.to_graph(g)\n", + "\n", + "with tf.Graph().as_default():\n", + " # The result works like a regular op: takes tensors in, returns tensors.\n", + " # You can inspect the graph using tf.get_default_graph().as_graph_def()\n", + " g_ops = tf_g(tf.constant(9.0))\n", + " with tf.Session() as sess:\n", + " print('Autograph value: %2.2f\\n' % sess.run(g_ops))\n", + "\n", + "\n", + "# You can view, debug and tweak the generated code:\n", + "print(autograph.to_code(g))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "m-jWmsCmByyw" + }, + "source": [ + "#### Automatically converting complex control flow\n", + "\n", + "AutoGraph can convert a large chunk of the Python language into equivalent graph-construction code, and we're adding new supported language features all the time. In this section, we'll give you a taste of some of the functionality in AutoGraph.\n", + "AutoGraph will automatically convert most Python control flow statements into their correct graph equivalent. \n", + " \n", + "We support common statements like `while`, `for`, `if`, `break`, `return` and more. You can even nest them as much as you like. Imagine trying to write the graph version of this code by hand:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "toxKBOXbB1ro" + }, + "outputs": [], + "source": [ + "# Continue in a loop\n", + "def f(l):\n", + " s = 0\n", + " for c in l:\n", + " if c % 2 \u003e 0:\n", + " continue\n", + " s += c\n", + " return s\n", + "\n", + "print('Original value: %d' % f([10,12,15,20]))\n", + "\n", + "tf_f = autograph.to_graph(f)\n", + "with tf.Graph().as_default():\n", + " with tf.Session():\n", + " print('Graph value: %d\\n\\n' % tf_f(tf.constant([10,12,15,20])).eval())\n", + "\n", + "print(autograph.to_code(f))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "FUJJ-WTdCGeq" + }, + "source": [ + "Try replacing the `continue` in the above code with `break` -- AutoGraph supports that as well! \n", + " \n", + "Let's try some other useful Python constructs, like `print` and `assert`. We automatically convert Python `assert` statements into the equivalent `tf.Assert` code. " + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "IAOgh62zCPZ4" + }, + "outputs": [], + "source": [ + "def f(x):\n", + " assert x != 0, 'Do not pass zero!'\n", + " return x * x\n", + "\n", + "tf_f = autograph.to_graph(f)\n", + "with tf.Graph().as_default():\n", + " with tf.Session():\n", + " try:\n", + " print(tf_f(tf.constant(0)).eval())\n", + " except tf.errors.InvalidArgumentError as e:\n", + " print('Got error message:\\n%s' % e.message)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "KRu8iIPBCQr5" + }, + "source": [ + "You can also use plain Python `print` functions in in-graph" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "ySTsuxnqCTQi" + }, + "outputs": [], + "source": [ + "def f(n):\n", + " if n \u003e= 0:\n", + " while n \u003c 5:\n", + " n += 1\n", + " print(n)\n", + " return n\n", + "\n", + "tf_f = autograph.to_graph(f)\n", + "with tf.Graph().as_default():\n", + " with tf.Session():\n", + " tf_f(tf.constant(0)).eval()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "NqF0GT-VCVFh" + }, + "source": [ + "Appending to lists in loops also works (we create a tensor list ops behind the scenes)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "ABX070KwCczR" + }, + "outputs": [], + "source": [ + "def f(n):\n", + " z = []\n", + " # We ask you to tell us the element dtype of the list\n", + " autograph.set_element_type(z, tf.int32)\n", + " for i in range(n):\n", + " z.append(i)\n", + " # when you're done with the list, stack it\n", + " # (this is just like np.stack)\n", + " return autograph.stack(z)\n", + "\n", + "tf_f = autograph.to_graph(f)\n", + "with tf.Graph().as_default():\n", + " with tf.Session():\n", + " print(tf_f(tf.constant(3)).eval())\n", + "\n", + "print('\\n\\n'+autograph.to_code(f))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "iu5IF7n2Df7C" + }, + "outputs": [], + "source": [ + "def fizzbuzz(num):\n", + " if num % 3 == 0 and num % 5 == 0:\n", + " print('FizzBuzz')\n", + " elif num % 3 == 0:\n", + " print('Fizz')\n", + " elif num % 5 == 0:\n", + " print('Buzz')\n", + " else:\n", + " print(num)\n", + " return num" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "EExAjWuwDPpR" + }, + "outputs": [], + "source": [ + "tf_g = autograph.to_graph(fizzbuzz)\n", + "\n", + "with tf.Graph().as_default():\n", + " # The result works like a regular op: takes tensors in, returns tensors.\n", + " # You can inspect the graph using tf.get_default_graph().as_graph_def()\n", + " g_ops = tf_g(tf.constant(15))\n", + " with tf.Session() as sess:\n", + " sess.run(g_ops) \n", + " \n", + "# You can view, debug and tweak the generated code:\n", + "print('\\n')\n", + "print(autograph.to_code(fizzbuzz))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "SzpKGzVpBkph" + }, + "source": [ + "# De-graphify Exercises\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "8k23dxcSmmXq" + }, + "source": [ + "#### Easy print statements" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "dE1Vsmp-mlpK" + }, + "outputs": [], + "source": [ + "# See what happens when you turn AutoGraph off.\n", + "# Do you see the type or the value of x when you print it?\n", + "\n", + "# @autograph.convert()\n", + "def square_log(x):\n", + " x = x * x\n", + " print('Squared value of x =', x)\n", + " return x\n", + "\n", + "\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " print(sess.run(square_log(tf.constant(4))))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "_R-Q7BbxmkBF" + }, + "source": [ + "#### Convert the TensorFlow code into Python code for AutoGraph" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "SwA11tO-yCvg" + }, + "outputs": [], + "source": [ + "def square_if_positive(x):\n", + " x = tf.cond(tf.greater(x, 0), lambda: x * x, lambda: x)\n", + " return x\n", + "\n", + "with tf.Session() as sess:\n", + " print(sess.run(square_if_positive(tf.constant(4))))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "GPmx4CNhyPI_" + }, + "outputs": [], + "source": [ + "@autograph.convert()\n", + "def square_if_positive(x):\n", + "\n", + " pass # TODO: fill it in!\n", + "\n", + "\n", + "with tf.Session() as sess:\n", + " print(sess.run(square_if_positive(tf.constant(4))))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "qqsjik-QyA9R" + }, + "source": [ + "#### Uncollapse to see answer" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "DaSmaWUEvMRv" + }, + "outputs": [], + "source": [ + "# Simple cond\n", + "@autograph.convert()\n", + "def square_if_positive(x):\n", + " if x \u003e 0:\n", + " x = x * x\n", + " return x\n", + "\n", + "with tf.Graph().as_default(): \n", + " with tf.Session() as sess:\n", + " print(sess.run(square_if_positive(tf.constant(4))))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "qj7am2I_xvTJ" + }, + "source": [ + "#### Nested If statement" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "4yyNOf-Twr6s" + }, + "outputs": [], + "source": [ + "def nearest_odd_square(x):\n", + "\n", + " def if_positive():\n", + " x1 = x * x\n", + " x1 = tf.cond(tf.equal(x1 % 2, 0), lambda: x1 + 1, lambda: x1)\n", + " return x1,\n", + "\n", + " x = tf.cond(tf.greater(x, 0), if_positive, lambda: x)\n", + " return x\n", + "\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " print(sess.run(nearest_odd_square(tf.constant(4))))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "hqmh5b2VyU9w" + }, + "outputs": [], + "source": [ + "@autograph.convert()\n", + "def nearest_odd_square(x):\n", + "\n", + " pass # TODO: fill it in!\n", + "\n", + "\n", + "with tf.Session() as sess:\n", + " print(sess.run(nearest_odd_square(tf.constant(4))))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "b9AXIkNLxp6J" + }, + "source": [ + "#### Uncollapse to reveal answer" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "8RlCVEpNxD91" + }, + "outputs": [], + "source": [ + "@autograph.convert()\n", + "def nearest_odd_square(x):\n", + " if x \u003e 0:\n", + " x = x * x\n", + " if x % 2 == 0:\n", + " x = x + 1\n", + " return x\n", + "\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " print(sess.run(nearest_odd_square(tf.constant(4))))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "jXAxjeBr1qWK" + }, + "source": [ + "#### Convert a while loop" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "kWkv7anlxoee" + }, + "outputs": [], + "source": [ + "# Convert a while loop\n", + "def square_until_stop(x, y):\n", + " x = tf.while_loop(lambda x: tf.less(x, y), lambda x: x * x, [x])\n", + " return x\n", + "\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "zVUsc1eA1u2K" + }, + "outputs": [], + "source": [ + "@autograph.convert()\n", + "def square_until_stop(x, y):\n", + "\n", + " pass # TODO: fill it in!\n", + "\n", + "\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "L2psuzPI02S9" + }, + "source": [ + "#### Uncollapse for the answer\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "ucmZyQVL03bF" + }, + "outputs": [], + "source": [ + "@autograph.convert()\n", + "def square_until_stop(x, y):\n", + " while x \u003c y:\n", + " x = x * x\n", + " return x\n", + "\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "FXB0Zbwl13PY" + }, + "source": [ + "#### Nested loop and conditional" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "clGymxdf15Ig" + }, + "outputs": [], + "source": [ + "@autograph.convert()\n", + "def argwhere_cumsum(x, threshold):\n", + " current_sum = 0.0\n", + " idx = 0\n", + "\n", + " for i in range(len(x)):\n", + " idx = i\n", + " if current_sum \u003e= threshold:\n", + " break\n", + " current_sum += x[i]\n", + " return idx\n", + "\n", + "n = 10\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " idx = argwhere_cumsum(tf.ones(n), tf.constant(float(n / 2)))\n", + " print(sess.run(idx))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "i7PF-uId9lp5" + }, + "outputs": [], + "source": [ + "@autograph.convert()\n", + "def argwhere_cumsum(x, threshold):\n", + "\n", + " pass # TODO: fill it in!\n", + "\n", + "\n", + "n = 10\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " idx = argwhere_cumsum(tf.ones(n), tf.constant(float(n / 2)))\n", + " print(sess.run(idx))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "weKFXAb615Vp" + }, + "source": [ + "#### Uncollapse to see answer" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "1sjaFcL717Ig" + }, + "outputs": [], + "source": [ + "@autograph.convert()\n", + "def argwhere_cumsum(x, threshold):\n", + " current_sum = 0.0\n", + " idx = 0\n", + " for i in range(len(x)):\n", + " idx = i\n", + " if current_sum \u003e= threshold:\n", + " break\n", + " current_sum += x[i]\n", + " return idx\n", + "\n", + "n = 10\n", + "with tf.Graph().as_default(): \n", + " with tf.Session() as sess:\n", + " idx = argwhere_cumsum(tf.ones(n), tf.constant(float(n / 2)))\n", + " print(sess.run(idx))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4LfnJjm0Bm0B" + }, + "source": [ + "# 3. Training MNIST in-graph\n", + "\n", + "Writing control flow in AutoGraph is easy, so running a training loop in a TensorFlow graph should be easy as well! \n", + "\n", + "Here, we show an example of training a simple Keras model on MNIST, where the entire training process -- loading batches, calculating gradients, updating parameters, calculating validation accuracy, and repeating until convergence -- is done in-graph." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Em5dzSUOtLRP" + }, + "source": [ + "#### Download data" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "xqoxumv0ssQW" + }, + "outputs": [], + "source": [ + "import gzip\n", + "import os\n", + "import shutil\n", + "\n", + "from six.moves import urllib\n", + "\n", + "\n", + "def download(directory, filename):\n", + " filepath = os.path.join(directory, filename)\n", + " if tf.gfile.Exists(filepath):\n", + " return filepath\n", + " if not tf.gfile.Exists(directory):\n", + " tf.gfile.MakeDirs(directory)\n", + " url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n", + " zipped_filepath = filepath + '.gz'\n", + " print('Downloading %s to %s' % (url, zipped_filepath))\n", + " urllib.request.urlretrieve(url, zipped_filepath)\n", + " with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:\n", + " shutil.copyfileobj(f_in, f_out)\n", + " os.remove(zipped_filepath)\n", + " return filepath\n", + "\n", + "\n", + "def dataset(directory, images_file, labels_file):\n", + " images_file = download(directory, images_file)\n", + " labels_file = download(directory, labels_file)\n", + "\n", + " def decode_image(image):\n", + " # Normalize from [0, 255] to [0.0, 1.0]\n", + " image = tf.decode_raw(image, tf.uint8)\n", + " image = tf.cast(image, tf.float32)\n", + " image = tf.reshape(image, [784])\n", + " return image / 255.0\n", + "\n", + " def decode_label(label):\n", + " label = tf.decode_raw(label, tf.uint8)\n", + " label = tf.reshape(label, [])\n", + " return tf.to_int32(label)\n", + "\n", + " images = tf.data.FixedLengthRecordDataset(\n", + " images_file, 28 * 28, header_bytes=16).map(decode_image)\n", + " labels = tf.data.FixedLengthRecordDataset(\n", + " labels_file, 1, header_bytes=8).map(decode_label)\n", + " return tf.data.Dataset.zip((images, labels))\n", + "\n", + "\n", + "def mnist_train(directory):\n", + " return dataset(directory, 'train-images-idx3-ubyte',\n", + " 'train-labels-idx1-ubyte')\n", + "\n", + "def mnist_test(directory):\n", + " return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "znmy4l8ntMvW" + }, + "source": [ + "#### Define the model" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "Pe-erWQdBoC5" + }, + "outputs": [], + "source": [ + "def mlp_model(input_shape):\n", + " model = tf.keras.Sequential((\n", + " tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),\n", + " tf.keras.layers.Dense(100, activation='relu'),\n", + " tf.keras.layers.Dense(10, activation='softmax')))\n", + " model.build()\n", + " return model\n", + "\n", + "\n", + "def predict(m, x, y):\n", + " y_p = m(x)\n", + " losses = tf.keras.losses.categorical_crossentropy(y, y_p)\n", + " l = tf.reduce_mean(losses)\n", + " accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n", + " accuracy = tf.reduce_mean(accuracies)\n", + " return l, accuracy\n", + "\n", + "\n", + "def fit(m, x, y, opt):\n", + " l, accuracy = predict(m, x, y)\n", + " opt.minimize(l)\n", + " return l, accuracy\n", + "\n", + "\n", + "def setup_mnist_data(is_training, hp, batch_size):\n", + " if is_training:\n", + " ds = mnist_train('/tmp/autograph_mnist_data')\n", + " ds = ds.shuffle(batch_size * 10)\n", + " else:\n", + " ds = mnist_test('/tmp/autograph_mnist_data')\n", + " ds = ds.repeat()\n", + " ds = ds.batch(batch_size)\n", + " return ds\n", + "\n", + "\n", + "def get_next_batch(ds):\n", + " itr = ds.make_one_shot_iterator()\n", + " image, label = itr.get_next()\n", + " x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))\n", + " y = tf.one_hot(tf.squeeze(label), 10)\n", + " return x, y" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "oeYV6mKnJGMr" + }, + "source": [ + "#### Define the training loop" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "3xtg_MMhJETd" + }, + "outputs": [], + "source": [ + "def train(train_ds, test_ds, hp):\n", + " m = mlp_model((28 * 28,))\n", + " opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n", + "\n", + " # We'd like to save our losses to a list. In order for AutoGraph\n", + " # to convert these lists into their graph equivalent,\n", + " # we need to specify the element type of the lists.\n", + " train_losses = []\n", + " test_losses = []\n", + " train_accuracies = []\n", + " test_accuracies = []\n", + " autograph.set_element_type(train_losses, tf.float32)\n", + " autograph.set_element_type(test_losses, tf.float32)\n", + " autograph.set_element_type(train_accuracies, tf.float32)\n", + " autograph.set_element_type(test_accuracies, tf.float32)\n", + "\n", + " # This entire training loop will be run in-graph.\n", + " i = tf.constant(0)\n", + " while i \u003c hp.max_steps:\n", + " train_x, train_y = get_next_batch(train_ds)\n", + " test_x, test_y = get_next_batch(test_ds)\n", + "\n", + " step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)\n", + " step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n", + "\n", + " if i % (hp.max_steps // 10) == 0:\n", + " print('Step', i, 'train loss:', step_train_loss, 'test loss:',\n", + " step_test_loss, 'train accuracy:', step_train_accuracy,\n", + " 'test accuracy:', step_test_accuracy)\n", + "\n", + " train_losses.append(step_train_loss)\n", + " test_losses.append(step_test_loss)\n", + " train_accuracies.append(step_train_accuracy)\n", + " test_accuracies.append(step_test_accuracy)\n", + "\n", + " i += 1\n", + "\n", + " # We've recorded our loss values and accuracies\n", + " # to a list in a graph with AutoGraph's help.\n", + " # In order to return the values as a Tensor,\n", + " # we need to stack them before returning them.\n", + " return (\n", + " autograph.stack(train_losses),\n", + " autograph.stack(test_losses),\n", + " autograph.stack(train_accuracies),\n", + " autograph.stack(test_accuracies),\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "HYh6MSZyJOag" + }, + "outputs": [], + "source": [ + "with tf.Graph().as_default():\n", + " hp = tf.contrib.training.HParams(\n", + " learning_rate=0.05,\n", + " max_steps=500,\n", + " )\n", + " train_ds = setup_mnist_data(True, hp, 50)\n", + " test_ds = setup_mnist_data(False, hp, 1000)\n", + " tf_train = autograph.to_graph(train)\n", + " loss_tensors = tf_train(train_ds, test_ds, hp)\n", + "\n", + " with tf.Session() as sess:\n", + " sess.run(tf.global_variables_initializer())\n", + " (\n", + " train_losses,\n", + " test_losses,\n", + " train_accuracies,\n", + " test_accuracies\n", + " ) = sess.run(loss_tensors)\n", + "\n", + " plt.title('MNIST train/test losses')\n", + " plt.plot(train_losses, label='train loss')\n", + " plt.plot(test_losses, label='test loss')\n", + " plt.legend()\n", + " plt.xlabel('Training step')\n", + " plt.ylabel('Loss')\n", + " plt.show()\n", + " plt.title('MNIST train/test accuracies')\n", + " plt.plot(train_accuracies, label='train accuracy')\n", + " plt.plot(test_accuracies, label='test accuracy')\n", + " plt.legend(loc='lower right')\n", + " plt.xlabel('Training step')\n", + " plt.ylabel('Accuracy')\n", + " plt.show()" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "qqsjik-QyA9R", + "b9AXIkNLxp6J", + "L2psuzPI02S9", + "weKFXAb615Vp", + "Em5dzSUOtLRP" + ], + "default_view": {}, + "name": "AutoGraph Workshop.ipynb", + "provenance": [ + { + "file_id": "1kE2gz_zuwdYySL4K2HQSz13uLCYi-fYP", + "timestamp": 1530563781803 + } + ], + "version": "0.3.2", + "views": {} + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/autograph/impl/BUILD b/tensorflow/contrib/autograph/impl/BUILD index 02f16ae1875d6bd1fb87d19f8bfc5cae900391dd..a5438592c30021eac7183b65ccc10c36d220bc57 100644 --- a/tensorflow/contrib/autograph/impl/BUILD +++ b/tensorflow/contrib/autograph/impl/BUILD @@ -18,20 +18,19 @@ py_library( name = "impl", srcs = [ "api.py", - "config.py", "conversion.py", - "directives.py", - "naming.py", - "special_functions.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/contrib/autograph/converters", + "//tensorflow/contrib/autograph/core", "//tensorflow/contrib/autograph/operators", "//tensorflow/contrib/autograph/pyct", "//tensorflow/contrib/autograph/pyct/static_analysis", "//tensorflow/contrib/autograph/utils", + "//tensorflow/python:platform", + "//tensorflow/python:util", "@gast_archive//:gast", "@six_archive//:six", ], @@ -61,23 +60,3 @@ py_test( "@gast_archive//:gast", ], ) - -py_test( - name = "naming_test", - srcs = ["naming_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":impl", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "special_functions_test", - srcs = ["special_functions_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":impl", - "//tensorflow/python:client_testlib", - ], -) diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py index 24f87b2c14da4a3523f1e580d4362cbd3679a2cd..276a3871801da2c66fbfffc38ac1ea39704b5de1 100644 --- a/tensorflow/contrib/autograph/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Public API.""" +"""This module contains the user-facing API for AutoGraph.""" from __future__ import absolute_import from __future__ import division @@ -23,18 +23,18 @@ from functools import wraps from enum import Enum # pylint:disable=g-bad-import-order -import gast import six # pylint:enable=g-bad-import-order -from tensorflow.contrib.autograph.impl import config +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.impl import conversion from tensorflow.contrib.autograph.pyct import compiler from tensorflow.contrib.autograph.pyct import inspect_utils -from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.utils import builtins from tensorflow.contrib.autograph.utils import py_func from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect # TODO(mdan): Properly document the type hints. @@ -42,33 +42,32 @@ from tensorflow.python.util import tf_inspect # (currently we require (module + class name, type)) -def convert(recursive=False, verbose=False, arg_types=None): - """Decorator that compiles a function to graph mode. +# TODO(mdan): This should behave like to_graph (e.g. convert statically). +def convert(recursive=False, verbose=False): + """Decorator that compiles a function to use TensorFlow ops. - The decorator is dynamic - invoking compilation whenever the decorated - function is called. This means the parameter values are known at compilation. + The decorator is dynamic - it recompiles the target whenever the decorated + function is called. This means the parameter values are known at conversion. + It also means that repeated calls with different types of parameters will be + correctly processed. Args: - recursive: Whether to recursively convert any functions that the decorator - function may call. - verbose: Whether to output the compiled code in the logs. - arg_types: See to_graph. + recursive: bool, whether to recursively convert any functions or classes + that the converted function may use. + verbose: bool, whether to output the compiled code in the logs. Returns: - A decorator that compiles the given function to graph mode. - - Raises: - ValueError: If any of the arguments are illegal. + Callable, a decorator that converts the given function into an equivalent + function that uses TensorFlow ops. """ - if arg_types is None: - arg_types = {} - def decorator(f): """Decorator implementation.""" @wraps(f) def wrapper(*args, **kwargs): - return converted_call(f, recursive, verbose, arg_types, *args, **kwargs) + return converted_call(f, recursive, verbose, True, {}, *args, **kwargs) + + wrapper = tf_decorator.make_decorator(f, wrapper) # Sometimes the decorator is just desugared, making it impossible to detect. # This attribute makes detection easier. @@ -79,23 +78,36 @@ def convert(recursive=False, verbose=False, arg_types=None): class RunMode(Enum): + """Specifies the way a converted function or method should be executed in TF. + + The enum values have the following semantics: + + * GRAPH: Call this function directly, as-is. This is suitable for functions + that were already designed for TF graphs and contain ops. + * PY_FUNC: Wrap this function into a py_func op. This is suitable for code + that will only run correctly in Python, for example code that renders + to the display, reads keyboard input, etc. + """ GRAPH = 1 PY_FUNC = 2 def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): - """Decorator that suppresses compilation of a function. + """Decorator that suppresses the conversion of a function. + + See also: docs/pyfunc_dtypes.md Args: - run_as: RunMode value. Whether to run the function as-is, or wrap it into - a py_func. - return_dtypes: See autograph.utils.py_func.wrap_py_func. Setting to None or - empty list or tuple will create a dummy return value that can be used - to set control dependencies. + run_as: RunMode, specifies how to use the function in TensorFlow. + return_dtypes: Optional[Iterable[ + Union[tf.DType, utils.py_func.MatchDType]]], the return data types of + the converted function, if run_as is RunMode.PY_FUNC. Ignored otherwise. + May be set to None if the function has no return values. Returns: - A decorator that wraps the original function. + Callable, a decorator that wraps the original function. """ + def decorator(f): """Decorator implementation.""" @@ -106,8 +118,7 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): @wraps(f) def py_func_wrapper(*args, **kwargs): if kwargs: - raise NotImplementedError( - 'RunMode.PY_FUNC does not yet support kwargs') + raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs') # TODO(mdan): Add support for kwargs. return py_func.wrap_py_func( f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes) @@ -127,12 +138,13 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): return decorator -def converted_call(f, recursive, verbose, arg_types, *args, **kwargs): - """Compiles a function call inline.""" +# TODO(mdan): Move to a private, undocumented module. +def converted_call(f, recursive, verbose, force_conversion, arg_types, *args, + **kwargs): + """Compiles a function call inline. For internal use only.""" # TODO(mdan): This needs cleanup. # In particular, we may want to avoid renaming functions altogether. - - if conversion.is_whitelisted_for_graph(f): + if not force_conversion and conversion.is_whitelisted_for_graph(f): return f(*args, **kwargs) unknown_arg_value = object() # Sentinel for arguments of unknown value @@ -199,64 +211,86 @@ def converted_call(f, recursive, verbose, arg_types, *args, **kwargs): return converted_f(*effective_args, **kwargs) +# TODO(mdan): Rename: to_ops? +# TODO(mdan): Looki into overloading as function and decorator, like tfe.defun. +# TODO(mdan): Remove partial_types. def to_graph(e, recursive=True, verbose=False, arg_values=None, arg_types=None, partial_types=None): - """Compile a Python entity into equivalent TensorFlow code. + """Converts a Python entity into equivalent code that uses TensorFlow ops. - Currently supported entities: + Supported Python entities include: * functions * classes - Classes are handled by converting all their methods into a new class. + Classes are converted by converting all their methods into a new class. Args: - e: A Python entity. - recursive: Whether to recursively convert any functions that the decorator - function may call. - verbose: Whether to output the compiled code in the logs. - arg_values: A dict containing value hints for symbols like function - parameters. - arg_types: A dict containing type hints for symbols like function - parameters. - partial_types: A set of types (e.g. classes) that will not be converted - entirely. Calls to member functions for these types will be renamed - independently. + e: Union[Callable, Type], the Python entity to convert. + recursive: bool, whether to recursively convert any functions that the + converted function may call. + verbose: bool, whether to output the compiled code in the logs. + arg_values: Optional[Dict[Text, Any]], value hints for symbols including + function arguments. + arg_types: Optional[Dict[Text, Type]], type hints for symbols including + function arguments. + partial_types: Set[Type], reserved for internal use. Returns: - A function with a signature identical to `o`, but which when executed it - creates TF a graph that has the same functionality as the original entity. + Union[Callable, Type], the converted entity, which is the same kind as e + (that is, a function is e is a function, a class if e is a class, etc.) but + its code has been converted to use TF ops. + + Raises: + ValueError: If the entity could not be converted. """ - conversion_map = conversion.ConversionMap( + program_ctx = converter.ProgramContext( recursive=recursive, - nocompile_decorators=(convert, do_not_convert, converted_call), + autograph_decorators=(convert, do_not_convert, converted_call), partial_types=partial_types, - api_module=tf_inspect.getmodule(to_graph)) - _, name, namespace = conversion.entity_to_graph(e, conversion_map, arg_values, + autograph_module=tf_inspect.getmodule(to_graph), + uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) + _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) - module = gast.Module([]) - for import_line in config.COMPILED_IMPORT_STATEMENTS: - module.body.extend(parser.parse_str(import_line).body) - for dep in reversed(conversion_map.dependency_cache.values()): - module.body.append(dep) - compiled_node, compiled_src = compiler.ast_to_object(module) + nodes = [] + for dep in reversed(program_ctx.dependency_cache.values()): + nodes.extend(dep) + compiled_module, compiled_src = compiler.ast_to_object( + nodes, + source_prefix=program_ctx.required_imports, + include_source_map=True) # The compiled code should see everything the entry entity saw. # TODO(mdan): This might not work well if the call tree spans modules? for key, val in namespace.items(): # Avoid overwriting entities that have been transformed. - if key not in compiled_node.__dict__: - compiled_node.__dict__[key] = val - compiled_fn = getattr(compiled_node, name) + if key not in compiled_module.__dict__: + compiled_module.__dict__[key] = val + compiled = getattr(compiled_module, name) + + # Need this so the source_mapping attribute is available for the context + # manager to access for runtime errors. + # + # Note that compiler.ast_to_object attaches the source map 'ag_source_map__' + # symbol to the compiled module. + # TODO(mdan): Record this statically in the generated code. + # TODO(mdan): Rename this attribute to 'autograph_info__' + source_map_attribute_name = 'ag_source_map' + if getattr(compiled, source_map_attribute_name, None) is not None: + raise ValueError('cannot convert %s because is has an attribute ' + '"%s", which is reserved for AutoGraph.' % + (compiled, source_map_attribute_name)) + setattr(compiled, source_map_attribute_name, + compiled_module.__dict__['ag_source_map__']) if verbose: logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src) - return compiled_fn + return compiled def to_code(e, @@ -265,32 +299,34 @@ def to_code(e, arg_types=None, partial_types=None, indentation=' '): - """Return the equivalent of an entity in TensorFlow code. + """Returns the equivalent code that uses TensorFlow ops. - See `to_graph` for more details. + Also see: `to_graph`, `convert` Args: - e: A Python entity. - recursive: See to_graph. - arg_values: See to_graph. - arg_types: See to_graph. - partial_types: See to_graph. - indentation: String, when to use for each level of indentation. + e: Union[Callable, Type], the Python entity to convert. + recursive: bool, whether to recursively convert any functions that the + converted function may call. + arg_values: Optional[Dict[Text, Any]], value hints for symbols including + function arguments. + arg_types: Optional[Dict[Text, Type]], type hints for symbols including + function arguments. + partial_types: Set[Type], reserved for internal use. + indentation: Text, when to use for each level of indentation. Returns: - String. + Text, the converted code. """ - conversion_map = conversion.ConversionMap( + program_ctx = converter.ProgramContext( recursive=recursive, - nocompile_decorators=(convert, do_not_convert, converted_call), + autograph_decorators=(convert, do_not_convert, converted_call), partial_types=partial_types, - api_module=tf_inspect.getmodule(to_graph)) - conversion.entity_to_graph(e, conversion_map, arg_values, arg_types) + autograph_module=tf_inspect.getmodule(to_graph), + uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) + conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) - imports = '\n'.join(config.COMPILED_IMPORT_STATEMENTS) code = '\n'.join( compiler.ast_to_source(dep, indentation) - for dep in reversed(tuple( - six.itervalues(conversion_map.dependency_cache)))) + for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache)))) - return imports + '\n\n' + code + return program_ctx.required_imports + '\n\n' + code diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py index a7737b7f448131b1c54951efa719b481e1f4d0c9..803fde9089b1c004d9bfc0dfefd3d6b422752f0a 100644 --- a/tensorflow/contrib/autograph/impl/api_test.py +++ b/tensorflow/contrib/autograph/impl/api_test.py @@ -21,12 +21,13 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.core import config from tensorflow.contrib.autograph.impl import api -from tensorflow.contrib.autograph.impl import config from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.utils import py_func from tensorflow.python.framework import constant_op from tensorflow.python.platform import test +from tensorflow.python.util import tf_inspect tf = utils.fake_tf() @@ -154,6 +155,22 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) + def test_decorator_preserves_argspec(self): + + class TestClass(object): + + def called_member(self, a): + if a < 0: + a = -a + return a + + called_member_converted = api.convert()(called_member) + + tc = TestClass() + self.assertListEqual( + list(tf_inspect.getfullargspec(tc.called_member)), + list(tf_inspect.getfullargspec(tc.called_member_converted))) + def test_convert_call_site_decorator(self): class TestClass(object): @@ -166,8 +183,8 @@ class ApiTest(test.TestCase): @api.convert(recursive=True) def test_method(self, x, s, a): while tf.reduce_sum(x) > s: - x //= api.converted_call(self.called_member, False, False, {}, self, - a) + x //= api.converted_call(self.called_member, False, False, False, {}, + self, a) return x tc = TestClass() @@ -178,7 +195,7 @@ class ApiTest(test.TestCase): self.assertListEqual([0, 1], sess.run(x).tolist()) def test_converted_call_builtin(self): - x = api.converted_call(range, False, False, {}, 3) + x = api.converted_call(range, False, False, False, {}, 3) self.assertEqual((0, 1, 2), tuple(x)) def test_converted_call_function(self): @@ -189,8 +206,8 @@ class ApiTest(test.TestCase): return x with self.test_session() as sess: - x = api.converted_call( - test_fn, False, False, {}, constant_op.constant(-1)) + x = api.converted_call(test_fn, False, False, False, {}, + constant_op.constant(-1)) self.assertEqual(1, sess.run(x)) def test_converted_call_method(self): @@ -207,7 +224,7 @@ class ApiTest(test.TestCase): with self.test_session() as sess: tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(tc.test_method, False, False, {}, tc) + x = api.converted_call(tc.test_method, False, False, False, {}, tc) self.assertEqual(1, sess.run(x)) def test_converted_call_method_by_class(self): @@ -224,7 +241,7 @@ class ApiTest(test.TestCase): with self.test_session() as sess: tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(TestClass.test_method, False, False, {}, tc) + x = api.converted_call(TestClass.test_method, False, False, False, {}, tc) self.assertEqual(1, sess.run(x)) def test_converted_call_callable_object(self): @@ -241,7 +258,7 @@ class ApiTest(test.TestCase): with self.test_session() as sess: tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(tc, False, False, {}) + x = api.converted_call(tc, False, False, False, {}) self.assertEqual(1, sess.run(x)) def test_converted_call_constructor(self): @@ -257,12 +274,27 @@ class ApiTest(test.TestCase): return self.x with self.test_session() as sess: - tc = api.converted_call( - TestClass, False, False, {}, constant_op.constant(-1)) + tc = api.converted_call(TestClass, False, False, False, {}, + constant_op.constant(-1)) # tc is now a converted object. x = tc.test_method() self.assertEqual(1, sess.run(x)) + def test_converted_call_already_converted(self): + + def f(x): + return x == 0 + + with self.test_session() as sess: + x = api.converted_call(f, False, False, False, {}, + constant_op.constant(0)) + self.assertTrue(sess.run(x)) + + converted_f = api.to_graph(f) + x = api.converted_call(converted_f, False, False, False, {}, + constant_op.constant(0)) + self.assertTrue(sess.run(x)) + def test_to_graph_basic(self): def test_fn(x, s): @@ -288,6 +320,13 @@ class ApiTest(test.TestCase): # Just check that it is parseable Python code. self.assertIsNotNone(parser.parse_str(compiled_code)) + def test_source_map_attribute_present(self): + + def test_fn(y): + return y**2 + + self.assertTrue(hasattr(api.to_graph(test_fn), 'ag_source_map')) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py index 7802bbbe27ec5fed891440af2f589801918b3bdd..fc8a976d3f3ecdc9c6339995dd0dfc776824b90d 100644 --- a/tensorflow/contrib/autograph/impl/conversion.py +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""High level conversion support.""" +"""Core conversion logic, serves as main point of access.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import imp import gast @@ -29,87 +28,34 @@ from tensorflow.contrib.autograph.converters import asserts from tensorflow.contrib.autograph.converters import break_statements from tensorflow.contrib.autograph.converters import builtin_functions from tensorflow.contrib.autograph.converters import call_trees +from tensorflow.contrib.autograph.converters import conditional_expressions from tensorflow.contrib.autograph.converters import continue_statements from tensorflow.contrib.autograph.converters import control_flow from tensorflow.contrib.autograph.converters import decorators -from tensorflow.contrib.autograph.converters import ifexp +from tensorflow.contrib.autograph.converters import directives +from tensorflow.contrib.autograph.converters import error_handlers from tensorflow.contrib.autograph.converters import lists from tensorflow.contrib.autograph.converters import logical_expressions from tensorflow.contrib.autograph.converters import name_scopes +from tensorflow.contrib.autograph.converters import return_statements from tensorflow.contrib.autograph.converters import side_effect_guards -from tensorflow.contrib.autograph.converters import single_return from tensorflow.contrib.autograph.converters import slices -from tensorflow.contrib.autograph.impl import config -from tensorflow.contrib.autograph.impl import naming +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.core import errors from tensorflow.contrib.autograph.pyct import ast_util -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import inspect_utils +from tensorflow.contrib.autograph.pyct import origin_info from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names -from tensorflow.contrib.autograph.pyct.static_analysis import activity -from tensorflow.contrib.autograph.pyct.static_analysis import live_values -from tensorflow.contrib.autograph.pyct.static_analysis import type_info -from tensorflow.contrib.autograph.utils import type_hints +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect # TODO(mdan): Might we not need any renaming at all? -class ConversionMap(object): - """ConversionMap keeps track of converting function hierarchies. - - This object is mutable, and is updated as functions are converted. - - Attributes: - recursive: Whether to recursively convert any functions that the decorator - function may call. - nocompile_decorators: tuple of decorator functions that toggle compilation - off. - dependency_cache: dict[object]: ast; maps original entities to their - converted AST - additional_imports: set(object); additional entities which for any reason - cannot be attached after loading and need to be explicitly imported - in the generated code - name_map: dict[string]: string; maps original entities to the name of - their converted counterparts - api_module: A reference to the api module. The reference needs to be passed - to avoid circular dependencies. - """ - - # TODO(mdan): Rename to ConversionContext, and pull in additional flags. - - def __init__(self, recursive, nocompile_decorators, partial_types, - api_module): - self.recursive = recursive - self.nocompile_decorators = nocompile_decorators - self.partial_types = partial_types if partial_types else () - # Required to output dependencies in discovery order, which should match - # the reverse dependency order. - self.dependency_cache = collections.OrderedDict() - self.additional_imports = set() - self.name_map = {} - self.api_module = api_module - - def new_namer(self, namespace): - return naming.Namer(namespace, self.recursive, self.name_map, - self.partial_types) - - def update_name_map(self, namer): - for o, name in namer.renamed_calls.items(): - if o in self.name_map: - if self.name_map[o] != name: - raise ValueError( - 'Calls to %s were converted using multiple names (%s). This is ' - 'possible when an entity with one of these names already ' - 'existed. To fix, avoid using any of these names.') - else: - self.name_map[o] = name - - def add_to_cache(self, original_entity, converted_ast): - self.dependency_cache[original_entity] = converted_ast - - def is_whitelisted_for_graph(o): """Check whether an entity is whitelisted for use in graph mode. @@ -125,10 +71,12 @@ def is_whitelisted_for_graph(o): for prefix, in config.DEFAULT_UNCOMPILED_MODULES: if m.__name__.startswith(prefix): return True + if hasattr(o, 'autograph_info__'): + return True return False -def entity_to_graph(o, conversion_map, arg_values, arg_types): +def entity_to_graph(o, program_ctx, arg_values, arg_types): """Compile a Python entity into equivalent TensorFlow. The function will also recursively compile all the entities that `o` @@ -139,7 +87,7 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types): Args: o: A Python entity. - conversion_map: A ConversionMap object. + program_ctx: A ProgramContext object. arg_values: A dict containing value hints for symbols like function parameters. arg_types: A dict containing type hints for symbols like function @@ -157,7 +105,7 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types): ValueError: if the entity type is not supported. """ if tf_inspect.isclass(o): - node, name, ns = class_to_graph(o, conversion_map) + node, name, ns = class_to_graph(o, program_ctx) elif tf_inspect.isfunction(o): # TODO(mdan): This is not a reliable mechanism. # The most reliable way is to check the source code, the AST will contain @@ -167,36 +115,55 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types): 'lambda functions are not yet supported; declare the function' ' using def instead: %s' % o) else: - node, name, ns = function_to_graph(o, conversion_map, arg_values, - arg_types) + node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) elif tf_inspect.ismethod(o): - node, name, ns = function_to_graph(o, conversion_map, arg_values, arg_types) + node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) + # TODO(mdan,yashkatariya): Remove when object conversion is implemented. + elif hasattr(o, '__class__'): + raise NotImplementedError( + 'Object conversion is not yet supported. If you are ' + 'trying to convert code that uses an existing object, ' + 'try including the creation of that object in the ' + 'conversion. For example, instead of converting the method ' + 'of a class, try converting the entire class instead. ' + 'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/' + 'contrib/autograph/README.md#using-the-functional-api ' + 'for more information.') else: raise ValueError( 'Entity "%s" has unsupported type "%s". Only functions and classes are ' 'supported for now.' % (o, type(o))) - conversion_map.add_to_cache(o, node) - if conversion_map.recursive: + # TODO(mdan): This is temporary. it should be created using a converter. + # TODO(mdan): The attribute should be added with a helper, not directly. + # The helper can ensure there are no collisions. + template = ''' + entity.autograph_info__ = {} + ''' + node.extend(templates.replace(template, entity=name)) + + program_ctx.add_to_cache(o, node) + + if program_ctx.recursive: while True: candidate = None - for obj in conversion_map.name_map.keys(): - if obj not in conversion_map.dependency_cache: + for obj in program_ctx.name_map.keys(): + if obj not in program_ctx.dependency_cache: candidate = obj break if candidate is None: break if (hasattr(candidate, 'im_class') and - getattr(candidate, 'im_class') not in conversion_map.partial_types): + getattr(candidate, 'im_class') not in program_ctx.partial_types): # Class members are converted with their objects, unless they're # only converted partially. continue - entity_to_graph(candidate, conversion_map, {}, {}) + entity_to_graph(candidate, program_ctx, {}, {}) return node, name, ns -def class_to_graph(c, conversion_map): +def class_to_graph(c, program_ctx): """Specialization of `entity_to_graph` for classes.""" converted_members = {} method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m) @@ -211,29 +178,30 @@ def class_to_graph(c, conversion_map): continue node, _, namespace = function_to_graph( m, - conversion_map=conversion_map, + program_ctx=program_ctx, arg_values={}, arg_types={'self': (c.__name__, c)}, - owner_type=c) + owner_type=c, + rewrite_errors=False) if class_namespace is None: class_namespace = namespace else: class_namespace.update(namespace) - converted_members[m] = node - namer = conversion_map.new_namer(class_namespace) + converted_members[m] = node[0] + namer = program_ctx.new_namer(class_namespace) class_name = namer.compiled_class_name(c.__name__, c) # TODO(mdan): This needs to be explained more thoroughly. - # Process any base classes: if the sueprclass if of a whitelisted type, an + # Process any base classes: if the superclass if of a whitelisted type, an # absolute import line is generated. Otherwise, it is marked for conversion # (as a side effect of the call to namer.compiled_class_name() followed by - # conversion_map.update_name_map(namer)). + # program_ctx.update_name_map(namer)). output_nodes = [] renames = {} - bases = [] + base_names = [] for base in c.__bases__: if isinstance(object, base): - bases.append('object') + base_names.append('object') continue if is_whitelisted_for_graph(base): alias = namer.new_symbol(base.__name__, ()) @@ -245,28 +213,28 @@ def class_to_graph(c, conversion_map): else: # This will trigger a conversion into a class with this name. alias = namer.compiled_class_name(base.__name__, base) - bases.append(alias) + base_names.append(alias) renames[qual_names.QN(base.__name__)] = qual_names.QN(alias) - conversion_map.update_name_map(namer) + program_ctx.update_name_map(namer) # Generate the definition of the converted class. - output_nodes.append( - gast.ClassDef( - class_name, - bases=bases, - keywords=[], - body=list(converted_members.values()), - decorator_list=[])) - node = gast.Module(output_nodes) - + bases = [gast.Name(n, gast.Load(), None) for n in base_names] + class_def = gast.ClassDef( + class_name, + bases=bases, + keywords=[], + body=list(converted_members.values()), + decorator_list=[]) # Make a final pass to replace references to the class or its base classes. # Most commonly, this occurs when making super().__init__() calls. # TODO(mdan): Making direct references to superclass' superclass will fail. - node = qual_names.resolve(node) + class_def = qual_names.resolve(class_def) renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name) - node = ast_util.rename_symbols(node, renames) + class_def = ast_util.rename_symbols(class_def, renames) + + output_nodes.append(class_def) - return node, class_name, class_namespace + return output_nodes, class_name, class_namespace def _add_reserved_symbol(namespace, name, entity): @@ -279,15 +247,17 @@ def _add_reserved_symbol(namespace, name, entity): ag_internal = None -def _add_self_references(namespace, api_module): +def _add_self_references(namespace, autograph_module): """Adds namespace references to the module that exposes the api itself.""" global ag_internal if ag_internal is None: # Craft a module that exposes parts of the external API as well as certain # internal modules. ag_internal = imp.new_module('autograph') - ag_internal.converted_call = api_module.converted_call + ag_internal.converted_call = autograph_module.converted_call ag_internal.utils = utils + ag_internal.rewrite_graph_construction_error = ( + errors.rewrite_graph_construction_error) # TODO(mdan): Add safeguards against name clashes. # We don't want to create a submodule because we want the operators to be # accessible as ag__. @@ -296,59 +266,52 @@ def _add_self_references(namespace, api_module): _add_reserved_symbol(namespace, 'ag__', ag_internal) -def function_to_graph(f, conversion_map, arg_values, arg_types, - owner_type=None): +def function_to_graph(f, + program_ctx, + arg_values, + arg_types, + owner_type=None, + rewrite_errors=True): """Specialization of `entity_to_graph` for callable functions.""" + node, source = parser.parse_entity(f) node = node.body[0] - + origin_info.resolve(node, source, f) namespace = inspect_utils.getnamespace(f) - _add_self_references(namespace, conversion_map.api_module) - namer = conversion_map.new_namer(namespace) + _add_self_references(namespace, program_ctx.autograph_module) + namer = program_ctx.new_namer(namespace) - ctx = context.EntityContext( - namer=namer, + entity_info = transformer.EntityInfo( source_code=source, source_file='', namespace=namespace, arg_values=arg_values, arg_types=arg_types, - owner_type=owner_type, - recursive=conversion_map.recursive, - type_annotation_func=type_hints.set_element_type) - node, deps = node_to_graph(node, ctx, conversion_map.nocompile_decorators) + owner_type=owner_type) + context = converter.EntityContext(namer, entity_info, program_ctx) + node = node_to_graph(node, context, rewrite_errors=rewrite_errors) - # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py + # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type) if not did_rename: new_name = f.__name__ if node.name != f.__name__: raise NotImplementedError('Strange corner case. Send us offending code!') - node.name = new_name - conversion_map.update_name_map(namer) - # TODO(mdan): Use this at compilation. - conversion_map.additional_imports.update(deps) - - return node, new_name, namespace + program_ctx.update_name_map(namer) + # TODO(mdan): Use this at compilation. -def _static_analysis_pass(node, ctx): - node = qual_names.resolve(node) - node = activity.resolve(node, ctx, None) - node = live_values.resolve(node, ctx, config.PYTHON_LITERALS) - node = type_info.resolve(node, ctx) - return node + return [node], new_name, namespace -def node_to_graph(node, ctx, nocompile_decorators): +def node_to_graph(node, context, rewrite_errors=True): """Convert Python code to equivalent TF graph mode code. Args: - node: A Python AST node representing the code to convert. - ctx: An EntityContext object. - nocompile_decorators: A tuple containing decorators to be stripped from - functions during conversion. + node: AST, the code to convert. + context: converter.EntityContext + rewrite_errors: Boolean, whether or not to rewrite the error traceback. Returns: A tuple (node, deps): @@ -356,59 +319,33 @@ def node_to_graph(node, ctx, nocompile_decorators): * deps: A set of strings, the fully qualified names of entity dependencies that this node has. """ - # TODO(mdan): Verify arguments for correctness. - - # TODO(mdan): Factor out common elements. - # These include: - # * code move between blocks - # * visiting blocks in transformers - - # Certain steps, especially canonicalization, insert new symbols into the - # tree, which must be accounted. Although less efficient, it is most robust - # to re-run the analysis. - - node = _static_analysis_pass(node, ctx) - - # TODO(mdan): Clean this up. - # Some intermediate analyses are not required, and some comments got orphaned. - - # TODO(mdan): We may assume all converters require analysis to be re-done. + # TODO(mdan): Insert list_comprehensions somewhere. + node = converter.standard_analysis(node, context, is_initial=True) # Past this point, line numbers are no longer accurate so we ignore the # source. # TODO(mdan): Is it feasible to reconstruct intermediate source code? - ctx.source_code = None - node = ifexp.transform(node, ctx) - node, deps = decorators.transform(node, nocompile_decorators) - node = break_statements.transform(node, ctx) - node = _static_analysis_pass(node, ctx) - - node = asserts.transform(node, ctx) + context.info.source_code = None + node = converter.apply_(node, context, decorators) + node = converter.apply_(node, context, directives) + node = converter.apply_(node, context, break_statements) + node = converter.apply_(node, context, asserts) # Note: sequencing continue canonicalization before for loop one avoids # dealing with the extra loop increment operation that the for # canonicalization creates. - node = continue_statements.transform(node, ctx) - ctx.namespace['len'] = len - - node = _static_analysis_pass(node, ctx) - node = single_return.transform(node, ctx) - - node = _static_analysis_pass(node, ctx) - node = lists.transform(node, ctx) - node = _static_analysis_pass(node, ctx) - node = slices.transform(node, ctx) - node = builtin_functions.transform(node, ctx) - - node = _static_analysis_pass(node, ctx) - node = call_trees.transform(node, ctx, config.DEFAULT_UNCOMPILED_MODULES, - nocompile_decorators) - node = control_flow.transform(node, ctx) - - # control_flow may create new symbols and change scopes. - node = _static_analysis_pass(node, ctx) - node = logical_expressions.transform(node, ctx) - node = side_effect_guards.transform(node, ctx) - node = name_scopes.transform(node, ctx) - - return node, deps + node = converter.apply_(node, context, continue_statements) + context.info.namespace['len'] = len + node = converter.apply_(node, context, return_statements) + node = converter.apply_(node, context, lists) + node = converter.apply_(node, context, slices) + node = converter.apply_(node, context, builtin_functions) + node = converter.apply_(node, context, call_trees) + node = converter.apply_(node, context, control_flow) + node = converter.apply_(node, context, conditional_expressions) + node = converter.apply_(node, context, logical_expressions) + node = converter.apply_(node, context, side_effect_guards) + node = converter.apply_(node, context, name_scopes) + if rewrite_errors: + node = converter.apply_(node, context, error_handlers) + return node diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py index bc61498b5422f5e130bbfeef935d0a796b4f5922..86432573a719ea3f2b163746996dbf3301785a91 100644 --- a/tensorflow/contrib/autograph/impl/conversion_test.py +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -21,6 +21,8 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.impl import api from tensorflow.contrib.autograph.impl import conversion from tensorflow.python.framework import constant_op @@ -30,8 +32,13 @@ from tensorflow.python.platform import test class ConversionTest(test.TestCase): - def _simple_conversion_map(self): - return conversion.ConversionMap(True, (), (), api) + def _simple_program_ctx(self): + return converter.ProgramContext( + recursive=True, + autograph_decorators=(), + partial_types=(), + autograph_module=api, + uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) def test_is_whitelisted_for_graph(self): @@ -43,20 +50,21 @@ class ConversionTest(test.TestCase): self.assertTrue(conversion.is_whitelisted_for_graph(constant_op.constant)) def test_entity_to_graph_unsupported_types(self): - with self.assertRaises(ValueError): - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph('dummy', conversion_map, None, None) + with self.assertRaises(NotImplementedError): + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph('dummy', program_ctx, None, None) def test_entity_to_graph_callable(self): b = 2 def f(a): return a + b - conversion_map = self._simple_conversion_map() - ast, name, ns = conversion.entity_to_graph(f, conversion_map, None, None) - self.assertTrue(isinstance(ast, gast.FunctionDef), ast) + program_ctx = self._simple_program_ctx() + nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) + fn_node, _ = nodes + self.assertIsInstance(fn_node, gast.FunctionDef) self.assertEqual('tf__f', name) - self.assertTrue(ns['b'] is b) + self.assertIs(ns['b'], b) def test_entity_to_graph_call_tree(self): @@ -66,18 +74,16 @@ class ConversionTest(test.TestCase): def f(a): return g(a) - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph(f, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph(f, program_ctx, None, None) - self.assertTrue(f in conversion_map.dependency_cache) - self.assertTrue(g in conversion_map.dependency_cache) - self.assertEqual('tf__f', conversion_map.dependency_cache[f].name) - # need the extra .body[0] in order to step past the with tf.name_scope('f') - # that is added automatically - self.assertEqual( - 'tf__g', - conversion_map.dependency_cache[f].body[0].body[0].value.func.id) - self.assertEqual('tf__g', conversion_map.dependency_cache[g].name) + self.assertTrue(f in program_ctx.dependency_cache) + self.assertTrue(g in program_ctx.dependency_cache) + f_node = program_ctx.dependency_cache[f][0] + g_node = program_ctx.dependency_cache[g][0] + self.assertEqual('tf__f', f_node.name) + self.assertEqual('tf__g', f_node.body[0].body[0].body[0].value.func.id) + self.assertEqual('tf__g', g_node.name) def test_entity_to_graph_class_hierarchy(self): @@ -104,16 +110,17 @@ class ConversionTest(test.TestCase): def baz(self): return self.y - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph(TestSubclass, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph(TestSubclass, program_ctx, None, None) - self.assertTrue(TestBase in conversion_map.dependency_cache) - self.assertTrue(TestSubclass in conversion_map.dependency_cache) + self.assertTrue(TestBase in program_ctx.dependency_cache) + self.assertTrue(TestSubclass in program_ctx.dependency_cache) + # The returned nodes will include: + # , , self.assertEqual('TfTestBase', - conversion_map.dependency_cache[TestBase].body[-1].name) - self.assertEqual( - 'TfTestSubclass', - conversion_map.dependency_cache[TestSubclass].body[-1].name) + program_ctx.dependency_cache[TestBase][-2].name) + self.assertEqual('TfTestSubclass', + program_ctx.dependency_cache[TestSubclass][-2].name) def test_entity_to_graph_class_hierarchy_whitelisted(self): @@ -126,24 +133,24 @@ class ConversionTest(test.TestCase): def call(self, x): return 3 * x - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph(TestSubclass, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph(TestSubclass, program_ctx, None, None) - self.assertTrue(TestSubclass in conversion_map.dependency_cache) - self.assertFalse(training.Model in conversion_map.dependency_cache) - self.assertEqual( - 'Model', - conversion_map.dependency_cache[TestSubclass].body[0].names[0].name) + self.assertTrue(TestSubclass in program_ctx.dependency_cache) + self.assertFalse(training.Model in program_ctx.dependency_cache) self.assertEqual( - 'TfTestSubclass', - conversion_map.dependency_cache[TestSubclass].body[-1].name) + 'Model', program_ctx.dependency_cache[TestSubclass][0].names[0].name) + # The returned nodes will include: + # , , + self.assertEqual('TfTestSubclass', + program_ctx.dependency_cache[TestSubclass][-2].name) def test_entity_to_graph_lambda(self): f = lambda a: a with self.assertRaises(NotImplementedError): - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph(f, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph(f, program_ctx, None, None) def test_ag_module_cached(self): def callee(): @@ -152,11 +159,11 @@ class ConversionTest(test.TestCase): def caller(a): return a() - conversion_map = self._simple_conversion_map() - _, _, callee_ns = conversion.entity_to_graph( - callee, conversion_map, None, None) - _, _, caller_ns = conversion.entity_to_graph( - caller, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + _, _, callee_ns = conversion.entity_to_graph(callee, program_ctx, None, + None) + _, _, caller_ns = conversion.entity_to_graph(caller, program_ctx, None, + None) self.assertTrue(callee_ns['ag__'] is caller_ns['ag__']) diff --git a/tensorflow/contrib/autograph/impl/special_functions.py b/tensorflow/contrib/autograph/impl/special_functions.py deleted file mode 100644 index b7a8177c44c88217560fb7f72c77d3ac1aa0c9ec..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/impl/special_functions.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Special functions that only make sense for AutoGraph. - -These functions are meant to ensure feature parity between Python and AutoGraph, -so that the exact same code works in both modes. In general, AutoGraph will -replace these calls. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.autograph.operators import data_structures - - -def stack(list_or_tensor, element_dtype=None): - """Stacks the input, if it admits the notion of stacking. No-op otherwise. - - For example, a list of tensors can be stacked into a larger tensor. This - function is similar to tf.stack, but it accepts non-lists and lists of - non-tensors as arguments. In the latter case, the function does nothing. - - Args: - list_or_tensor: Any entity. - element_dtype: Optional dtype for the elements in the list. Required if the - input is stackable, and the list is untyped. - - Returns: - If the input is stackable, a new object representing the stacked inputs. - Otherwise it returns list_or_tensor unchanged. - """ - return data_structures.list_stack( - list_or_tensor, - data_structures.ListStackOpts( - element_dtype=element_dtype, original_call=lambda x: x)) diff --git a/tensorflow/contrib/autograph/lang/BUILD b/tensorflow/contrib/autograph/lang/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..77a2184e229003a3403cbe3bf116ad2570274a1b --- /dev/null +++ b/tensorflow/contrib/autograph/lang/BUILD @@ -0,0 +1,40 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "lang", + srcs = [ + "directives.py", + "special_functions.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/contrib/autograph/operators", + ], +) + +py_test( + name = "special_functions_test", + srcs = ["special_functions_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":lang", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/impl/directives.py b/tensorflow/contrib/autograph/lang/directives.py similarity index 100% rename from tensorflow/contrib/autograph/impl/directives.py rename to tensorflow/contrib/autograph/lang/directives.py diff --git a/tensorflow/contrib/autograph/lang/special_functions.py b/tensorflow/contrib/autograph/lang/special_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..6149cbbd6c9214fb6989bdcae430286445b1db28 --- /dev/null +++ b/tensorflow/contrib/autograph/lang/special_functions.py @@ -0,0 +1,96 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Special functions that only make sense for AutoGraph. + +These functions are meant to ensure feature parity between Python and AutoGraph, +so that the exact same code works in both modes. In general, AutoGraph will +replace these calls. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators import data_structures + + +def tensor_list(elements, + element_dtype=None, + element_shape=None, + use_tensor_array=False): + """Creates an tensor list and populates it with the given elements. + + This function provides a more uniform access to tensor lists and tensor + arrays, and allows optional initialization. + + Note: this function is a simplified wrapper. If you need greater control, + it is recommended to use the underlying implementation directly. + + Args: + elements: Iterable[tf.Tensor, ...], the elements to initially fill the list + with + element_dtype: Optional[tf.DType], data type for the elements in the list; + required if the list is empty + element_shape: Optional[tf.TensorShape], shape for the elements in the list; + required if the list is empty + use_tensor_array: bool, whether to use the more compatible but restrictive + tf.TensorArray implementation + Returns: + Union[tf.Tensor, tf.TensorArray], the new list. + Raises: + ValueError: for invalid arguments + """ + if not (elements or (element_dtype and element_shape)): + raise ValueError( + 'element_dtype and element_shape are required for empty lists') + if use_tensor_array: + return data_structures.tf_tensor_array_new(elements, element_dtype, + element_shape) + else: + return data_structures.tf_tensor_list_new(elements, element_dtype, + element_shape) + + +def stack(list_or_tensor, element_dtype=None, strict=True): + """Stacks the input, if it admits the notion of stacking. + + For example, a list of tensors can be stacked into a larger tensor. This + function is similar to tf.stack, but it accepts non-lists and lists of + non-tensors as arguments. In the latter case, the function does nothing. + + Args: + list_or_tensor: Any + element_dtype: tf.DType, optional dtypedtype for the elements in the list. + Required if the input is stackable, and the list is untyped. + strict: bool, if True an error is raised if the input is not stackable. + Otherwise the function is a no-op. + + Returns: + Any, if the input is stackable, the result will be a tf.Tensor. Otherwise, + if strict=False, the result will be list_or_tensor. + + Raises: + ValueError: if strict=True and the input is not stackable. + """ + if strict: + def raise_error(x): + raise ValueError('%s must be stackable when strict=True' % x) + original_call = raise_error + else: + original_call = lambda x: x + return data_structures.list_stack( + list_or_tensor, + data_structures.ListStackOpts( + element_dtype=element_dtype, original_call=original_call)) diff --git a/tensorflow/contrib/autograph/impl/special_functions_test.py b/tensorflow/contrib/autograph/lang/special_functions_test.py similarity index 60% rename from tensorflow/contrib/autograph/impl/special_functions_test.py rename to tensorflow/contrib/autograph/lang/special_functions_test.py index 9b52d2a59b5a3e3c92f11343197379c773ecc828..db492cc5c689155bf7d426cbfee320130f4bda9f 100644 --- a/tensorflow/contrib/autograph/impl/special_functions_test.py +++ b/tensorflow/contrib/autograph/lang/special_functions_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.impl import special_functions +from tensorflow.contrib.autograph.lang import special_functions from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util @@ -28,15 +28,35 @@ from tensorflow.python.platform import test class SpecialFunctionsTest(test.TestCase): - def test_basic(self): - self.assertEqual(special_functions.stack(1), 1) - self.assertListEqual(special_functions.stack([1, 2, 3]), [1, 2, 3]) + def test_tensor_list_from_elements(self): + elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])] + + l = special_functions.tensor_list(elements) + sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) + with self.test_session() as sess: + self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]]) + + def test_tensor_list_array_from_elements(self): + elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])] + + l = special_functions.tensor_list(elements, use_tensor_array=True) + sl = l.stack() + with self.test_session() as sess: + self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]]) + + def test_stack(self): + self.assertEqual(special_functions.stack(1, strict=False), 1) + self.assertListEqual( + special_functions.stack([1, 2, 3], strict=False), [1, 2, 3]) # TODO(mdan): This should probably forward to tf.stack. self.assertTrue( isinstance( special_functions.stack( [constant_op.constant(1), - constant_op.constant(2)]), list)) + constant_op.constant(2)], strict=False), list)) + + with self.assertRaises(ValueError): + special_functions.stack([1, 2, 3]) t = constant_op.constant([1.0, 2.0]) l = list_ops.tensor_list_from_tensor( diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD index 0c6ab65505ee03e19588adae73d3134399a34b65..332d5dab19e7ade1531b564fbdef2fa0dc2d09d5 100644 --- a/tensorflow/contrib/autograph/operators/BUILD +++ b/tensorflow/contrib/autograph/operators/BUILD @@ -28,7 +28,15 @@ py_library( visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/contrib/autograph/utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:list_ops", "//tensorflow/python:tensor_array_ops", + "//tensorflow/python:tensor_util", + "//tensorflow/python:variables", "//tensorflow/python/data/ops:dataset_ops", ], ) diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py index c900fd6af2ea5dfb419f731ee8d8822d68424b27..392cb60bcc44c0f554defcddc50c4afbdaa25067 100644 --- a/tensorflow/contrib/autograph/operators/__init__.py +++ b/tensorflow/contrib/autograph/operators/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""This module implements operators that we overload. +"""This module implements operators that AutoGraph overloads. Note that "operator" is used loosely here, and includes control structures like conditionals and loops, implemented in functional form, using for example diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py index 671c9ccc13eaa887522cfc248a6d56d7ab9719ca..9909e521644a7a901653dc09853222167828c75c 100644 --- a/tensorflow/contrib/autograph/operators/control_flow.py +++ b/tensorflow/contrib/autograph/operators/control_flow.py @@ -51,7 +51,7 @@ def for_stmt(iter_, extra_test, body, init_state): Args: iter_: The entity being iterated over. extra_test: Callable with the state as arguments, and boolean return type. - An additionnal loop condition. + An additional loop condition. body: Callable with the iterate and the state as arguments, and state as return type. The actual loop body. init_state: Tuple containing the initial state. @@ -141,7 +141,7 @@ def _dataset_for_stmt(ds, extra_test, body, init_state): while_body, init_state=(epoch_number, iterate) + init_state, extra_deps=()) - # Dropping the epoch number and iterate because they are not not syntactically + # Dropping the epoch number and iterate because they are not syntactically # visible. results = results[2:] @@ -212,12 +212,12 @@ def if_stmt(cond, body, orelse): Tuple containing the statement outputs. """ if tensor_util.is_tensor(cond): - return _tf_if_stmt(cond, body, orelse) + return tf_if_stmt(cond, body, orelse) else: return _py_if_stmt(cond, body, orelse) -def _tf_if_stmt(cond, body, orelse): +def tf_if_stmt(cond, body, orelse): """Overload of if_stmt that stages a TF cond.""" return control_flow_ops.cond(cond, body, orelse) diff --git a/tensorflow/contrib/autograph/operators/data_structures.py b/tensorflow/contrib/autograph/operators/data_structures.py index 06d8727b0fcc30b532b3f11281cd1a83c51ac8bc..cc0a3c35448980945f2975f829f9d9421afdb65d 100644 --- a/tensorflow/contrib/autograph/operators/data_structures.py +++ b/tensorflow/contrib/autograph/operators/data_structures.py @@ -28,7 +28,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops -from tensorflow.python.ops import variables # TODO(mdan): Once control flow supports objects, repackage as a class. @@ -48,29 +47,101 @@ def new_list(iterable=None): else: elements = () - # TODO(mdan): Extend these criteria. - if any(isinstance(el, variables.Variable) for el in elements): + if elements: + # When the list contains elements, it is assumed to be a "Python" lvalue + # list. return _py_list_new(elements) - return _tf_tensor_list_new(elements) + return tf_tensor_list_new(elements) -def _tf_tensor_list_new(elements): +def tf_tensor_array_new(elements, element_dtype=None, element_shape=None): """Overload of new_list that stages a Tensor list creation.""" elements = tuple(ops.convert_to_tensor(el) for el in elements) + + all_dtypes = set(el.dtype for el in elements) + if len(all_dtypes) == 1: + inferred_dtype, = tuple(all_dtypes) + if element_dtype is not None and element_dtype != inferred_dtype: + raise ValueError( + 'incompatible dtype; specified: {}, inferred from {}: {}'.format( + element_dtype, elements, inferred_dtype)) + elif len(all_dtypes) > 1: + raise ValueError( + 'TensorArray requires all elements to have the same dtype:' + ' {}'.format(elements)) + else: + if element_dtype is None: + raise ValueError('dtype is required to create an empty TensorArray') + + all_shapes = set(tuple(el.shape.as_list()) for el in elements) + if len(all_shapes) == 1: + inferred_shape, = tuple(all_shapes) + if element_shape is not None and element_shape != inferred_shape: + raise ValueError( + 'incompatible shape; specified: {}, inferred from {}: {}'.format( + element_shape, elements, inferred_shape)) + elif len(all_shapes) > 1: + raise ValueError( + 'TensorArray requires all elements to have the same shape:' + ' {}'.format(elements)) + # TODO(mdan): We may want to allow different shapes with infer_shape=False. + else: + inferred_shape = None + + if element_dtype is None: + element_dtype = inferred_dtype + if element_shape is None: + element_shape = inferred_shape + + l = tensor_array_ops.TensorArray( + dtype=element_dtype, + size=len(elements), + dynamic_size=True, + infer_shape=(element_shape is None), + element_shape=element_shape) + for i, el in enumerate(elements): + l = l.write(i, el) + return l + + +def tf_tensor_list_new(elements, element_dtype=None, element_shape=None): + """Overload of new_list that stages a Tensor list creation.""" + elements = tuple(ops.convert_to_tensor(el) for el in elements) + all_dtypes = set(el.dtype for el in elements) if len(all_dtypes) == 1: - element_dtype = tuple(all_dtypes)[0] + inferred_dtype = tuple(all_dtypes)[0] + if element_dtype is not None and element_dtype != inferred_dtype: + raise ValueError( + 'incompatible dtype; specified: {}, inferred from {}: {}'.format( + element_dtype, elements, inferred_dtype)) else: # Heterogeneous lists are ok. - element_dtype = dtypes.variant + if element_dtype is not None: + raise ValueError( + 'specified dtype {} is inconsistent with that of elements {}'.format( + element_dtype, elements)) + inferred_dtype = dtypes.variant - # TODO(mdan): This may fail for elements of variable shapes. all_shapes = set(tuple(el.shape.as_list()) for el in elements) if len(all_shapes) == 1: - element_shape = array_ops.shape(elements[0]) + inferred_shape = array_ops.shape(elements[0]) + if element_shape is not None and element_shape != inferred_shape: + raise ValueError( + 'incompatible shape; specified: {}, inferred from {}: {}'.format( + element_shape, elements, inferred_shape)) else: # Heterogeneous lists are ok. - element_shape = constant_op.constant(-1) # unknown shape, by convention + if element_shape is not None: + raise ValueError( + 'specified shape {} is inconsistent with that of elements {}'.format( + element_shape, elements)) + inferred_shape = constant_op.constant(-1) # unknown shape, by convention + + if element_dtype is None: + element_dtype = inferred_dtype + if element_shape is None: + element_shape = inferred_shape l = list_ops.empty_tensor_list( element_shape=element_shape, element_dtype=element_dtype) diff --git a/tensorflow/contrib/autograph/operators/data_structures_test.py b/tensorflow/contrib/autograph/operators/data_structures_test.py index 8bbb52d6c10b241ec754c7dea599fa15a869595f..7ea11a839b6070f6c6dfdd8a8f7939923a7d9eaa 100644 --- a/tensorflow/contrib/autograph/operators/data_structures_test.py +++ b/tensorflow/contrib/autograph/operators/data_structures_test.py @@ -37,10 +37,51 @@ class ListTest(test.TestCase): def test_new_list_tensor(self): l = data_structures.new_list([3, 4, 5]) + self.assertAllEqual(l, [3, 4, 5]) + + def test_tf_tensor_list_new(self): + l = data_structures.tf_tensor_list_new([3, 4, 5]) t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) with self.test_session() as sess: self.assertAllEqual(sess.run(t), [3, 4, 5]) + def test_tf_tensor_list_new_illegal_input(self): + with self.assertRaises(ValueError): + data_structures.tf_tensor_list_new([3, 4.0]) + # TODO(mdan): It might make more sense to type cast in this case. + with self.assertRaises(ValueError): + data_structures.tf_tensor_list_new([3, 4], element_dtype=dtypes.float32) + # Tensor lists do support heterogeneous lists. + self.assertIsNot(data_structures.tf_tensor_list_new([3, [4, 5]]), None) + with self.assertRaises(ValueError): + data_structures.tf_tensor_list_new([3, 4], element_shape=(2,)) + with self.assertRaises(ValueError): + data_structures.tf_tensor_list_new([], element_shape=(2,)) + with self.assertRaises(ValueError): + data_structures.tf_tensor_list_new([], element_dtype=dtypes.float32) + + def test_tf_tensor_array_new(self): + l = data_structures.tf_tensor_array_new([3, 4, 5]) + t = l.stack() + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [3, 4, 5]) + + def test_tf_tensor_array_new_illegal_input(self): + with self.assertRaises(ValueError): + data_structures.tf_tensor_array_new([3, 4.0]) + with self.assertRaises(ValueError): + data_structures.tf_tensor_array_new([3, 4], element_dtype=dtypes.float32) + with self.assertRaises(ValueError): + data_structures.tf_tensor_array_new([3, [4, 5]]) + with self.assertRaises(ValueError): + data_structures.tf_tensor_array_new([3, 4], element_shape=(2,)) + with self.assertRaises(ValueError): + data_structures.tf_tensor_array_new([], element_shape=(2,)) + # TAs can infer the shape. + self.assertIsNot( + data_structures.tf_tensor_array_new([], element_dtype=dtypes.float32), + None) + def test_append_tensor_list(self): l = data_structures.new_list() x = constant_op.constant([1, 2, 3]) diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD index 989b821e53a5cefbe39095e669f9a9e0bec65b8a..ddadc6b96e8eb5417bfa1676ae304f7cbdedd92b 100644 --- a/tensorflow/contrib/autograph/pyct/BUILD +++ b/tensorflow/contrib/autograph/pyct/BUILD @@ -22,9 +22,10 @@ py_library( "__init__.py", "anno.py", "ast_util.py", + "cfg.py", "compiler.py", - "context.py", "inspect_utils.py", + "origin_info.py", "parser.py", "pretty_printer.py", "qual_names.py", @@ -38,6 +39,8 @@ py_library( "@gast_archive//:gast", "@six_archive//:six", "@termcolor_archive//:termcolor", + # TODO(mdan): Remove this dependency. + "//tensorflow/python:util", ], ) @@ -62,6 +65,17 @@ py_test( ], ) +py_test( + name = "cfg_test", + srcs = ["cfg_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + py_test( name = "compiler_test", srcs = ["compiler_test.py"], @@ -85,6 +99,16 @@ py_test( ], ) +py_test( + name = "origin_info_test", + srcs = ["origin_info_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "parser_test", srcs = ["parser_test.py"], diff --git a/tensorflow/contrib/autograph/pyct/anno.py b/tensorflow/contrib/autograph/pyct/anno.py index ae861627fd65cca057e7bf1af41424e605d4b7a1..1a52110ef36bbc0888e03cc25b3717822cb75c16 100644 --- a/tensorflow/contrib/autograph/pyct/anno.py +++ b/tensorflow/contrib/autograph/pyct/anno.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Handling annotations on AST nodes. +"""AST node annotation support. Adapted from Tangent. """ @@ -21,37 +21,90 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from enum import Enum +import enum +# pylint:disable=g-bad-import-order +import gast +# pylint:enable=g-bad-import-order -class NoValue(Enum): + +# TODO(mdan): Shorten the names. +# These names are heavily used, and anno.blaa +# TODO(mdan): Replace the attr-dict mechanism with a more typed solution. + + +class NoValue(enum.Enum): def __repr__(self): return self.name class Basic(NoValue): - """Container for annotation keys. + """Container for basic annotation keys. The enum values are used strictly for documentation purposes. """ - QN = 'Qualified name, as it appeared in the code.' + QN = 'Qualified name, as it appeared in the code. See qual_names.py.' SKIP_PROCESSING = ( 'This node should be preserved as is and not processed any further.') INDENT_BLOCK_REMAINDER = ( - 'When a node is annotated with this, the remainder of the block should ' - 'be indented below it. The annotation contains a tuple ' - '(new_body, name_map), where `new_body` is the new indented block and ' - '`name_map` allows renaming symbols.') + 'When a node is annotated with this, the remainder of the block should' + ' be indented below it. The annotation contains a tuple' + ' (new_body, name_map), where `new_body` is the new indented block and' + ' `name_map` allows renaming symbols.') + ORIGIN = ('Information about the source code that converted code originated' + ' from. See origin_information.py.') + + +class Static(NoValue): + """Container for static analysis annotation keys. + + The enum values are used strictly for documentation purposes. + """ + + # Deprecated - use reaching definitions instead. + # Symbols + # These flags are boolean. + IS_LOCAL = 'Symbol is local to the function scope being analyzed.' + IS_PARAM = 'Symbol is a parameter to the function being analyzed.' + + # Scopes + # Scopes are represented by objects of type activity.Scope. + SCOPE = 'The scope for the annotated node. See activity.py.' + # TODO(mdan): Drop these in favor of accessing the child's SCOPE. + ARGS_SCOPE = 'The scope for the argument list of a function call.' + COND_SCOPE = 'The scope for the test node of a conditional statement.' + BODY_SCOPE = ( + 'The scope for the main body of a statement (True branch for if ' + 'statements, main body for loops).') + ORELSE_SCOPE = ( + 'The scope for the orelse body of a statement (False branch for if ' + 'statements, orelse body for loops).') + + # Static analysis annotations. + DEFINITIONS = ( + 'Reaching definition information. See reaching_definitions.py.') + ORIG_DEFINITIONS = ( + 'The value of DEFINITIONS that applied to the original code before any' + ' conversion.') + DEFINED_VARS_IN = ( + 'Symbols defined when entering the node. See reaching_definitions.py.') + LIVE_VARS_OUT = ('Symbols live when exiting the node. See liveness.py.') FAIL = object() +def keys(node, field_name='___pyct_anno'): + if not hasattr(node, field_name): + return frozenset() + return frozenset(getattr(node, field_name).keys()) + + def getanno(node, key, default=FAIL, field_name='___pyct_anno'): - if (default is FAIL or - (hasattr(node, field_name) and (key in getattr(node, field_name)))): + if (default is FAIL or (hasattr(node, field_name) and + (key in getattr(node, field_name)))): return getattr(node, field_name)[key] else: return default @@ -86,3 +139,19 @@ def copyanno(from_node, to_node, key, field_name='___pyct_anno'): key, getanno(from_node, key, field_name=field_name), field_name=field_name) + + +def dup(node, copy_map, field_name='___pyct_anno'): + """Recursively copies annotations in an AST tree. + + Args: + node: ast.AST + copy_map: Dict[Hashable, Hashable], maps a source anno key to a destination + key. All annotations with the source key will be copied to identical + annotations with the destination key. + field_name: str + """ + for n in gast.walk(node): + for k in copy_map: + if hasanno(n, k, field_name): + setanno(n, copy_map[k], getanno(n, k, field_name), field_name) diff --git a/tensorflow/contrib/autograph/pyct/anno_test.py b/tensorflow/contrib/autograph/pyct/anno_test.py index f2c0c8cf05ca4b3671eb653ce56f6da61de54aee..5ef4da61a3627f9c0bc615ce5cb56052a28c64d1 100644 --- a/tensorflow/contrib/autograph/pyct/anno_test.py +++ b/tensorflow/contrib/autograph/pyct/anno_test.py @@ -32,22 +32,27 @@ class AnnoTest(test.TestCase): def test_basic(self): node = ast.Name() + self.assertEqual(anno.keys(node), set()) self.assertFalse(anno.hasanno(node, 'foo')) with self.assertRaises(AttributeError): anno.getanno(node, 'foo') anno.setanno(node, 'foo', 3) + + self.assertEqual(anno.keys(node), {'foo'}) self.assertTrue(anno.hasanno(node, 'foo')) self.assertEqual(anno.getanno(node, 'foo'), 3) self.assertEqual(anno.getanno(node, 'bar', default=7), 7) anno.delanno(node, 'foo') + + self.assertEqual(anno.keys(node), set()) self.assertFalse(anno.hasanno(node, 'foo')) with self.assertRaises(AttributeError): anno.getanno(node, 'foo') self.assertIsNone(anno.getanno(node, 'foo', default=None)) - def test_copyanno(self): + def test_copy(self): node_1 = ast.Name() anno.setanno(node_1, 'foo', 3) @@ -58,6 +63,22 @@ class AnnoTest(test.TestCase): self.assertTrue(anno.hasanno(node_2, 'foo')) self.assertFalse(anno.hasanno(node_2, 'bar')) + def test_duplicate(self): + node = ast.If( + test=ast.Num(1), + body=[ast.Expr(ast.Name('bar', ast.Load()))], + orelse=[]) + anno.setanno(node, 'spam', 1) + anno.setanno(node, 'ham', 1) + anno.setanno(node.body[0], 'ham', 1) + + anno.dup(node, {'spam': 'eggs'}) + + self.assertTrue(anno.hasanno(node, 'spam')) + self.assertTrue(anno.hasanno(node, 'ham')) + self.assertTrue(anno.hasanno(node, 'eggs')) + self.assertFalse(anno.hasanno(node.body[0], 'eggs')) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/ast_util.py b/tensorflow/contrib/autograph/pyct/ast_util.py index c4f82d11708393a6029d3f17be428b47eb9342ff..d7453b078197cd6f1c0521b861e96dd28d287cab 100644 --- a/tensorflow/contrib/autograph/pyct/ast_util.py +++ b/tensorflow/contrib/autograph/pyct/ast_util.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Copy an AST tree, discarding annotations.""" +"""AST manipulation utilities.""" from __future__ import absolute_import from __future__ import division @@ -26,47 +26,53 @@ from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import parser -class CleanCopier(gast.NodeVisitor): - """Copies AST nodes. +class CleanCopier(object): + """NodeTransformer-like visitor that copies an AST.""" - The copied nodes will ignore almost all fields that are prefixed by '__'. - Exceptions make some annotations. - """ + def __init__(self, preserve_annos): + super(CleanCopier, self).__init__() + self.preserve_annos = preserve_annos - # TODO(mdan): Parametrize which annotations get carried over. + def copy(self, node): + """Returns a deep copy of node (excluding some fields, see copy_clean).""" + + if isinstance(node, list): + return [self.copy(n) for n in node] + elif isinstance(node, tuple): + return tuple(self.copy(n) for n in node) + elif not isinstance(node, (gast.AST, ast.AST)): + # Assuming everything that's not an AST, list or tuple is a value type + # and may simply be assigned. + return node + + assert isinstance(node, (gast.AST, ast.AST)) - def generic_visit(self, node): new_fields = {} for f in node._fields: - if f.startswith('__'): - continue - if not hasattr(node, f): - continue - v = getattr(node, f) - if isinstance(v, list): - v = [self.generic_visit(n) for n in v] - elif isinstance(v, tuple): - v = tuple(self.generic_visit(n) for n in v) - elif isinstance(v, (gast.AST, ast.AST)): - v = self.generic_visit(v) - else: - # Assume everything else is a value type. - pass - new_fields[f] = v + if not f.startswith('__') and hasattr(node, f): + new_fields[f] = self.copy(getattr(node, f)) new_node = type(node)(**new_fields) - if anno.hasanno(node, anno.Basic.SKIP_PROCESSING): - anno.setanno(new_node, anno.Basic.SKIP_PROCESSING, True) + + if self.preserve_annos: + for k in self.preserve_annos: + anno.copyanno(node, new_node, k) return new_node -def copy_clean(node): - copier = CleanCopier() - if isinstance(node, list): - return [copier.visit(n) for n in node] - elif isinstance(node, tuple): - return tuple(copier.visit(n) for n in node) - else: - return copier.visit(node) +def copy_clean(node, preserve_annos=None): + """Creates a deep copy of an AST. + + The copy will not include fields that are prefixed by '__', with the + exception of user-specified annotations. + + Args: + node: ast.AST + preserve_annos: Optional[Set[Hashable]], annotation keys to include in the + copy + Returns: + ast.AST + """ + return CleanCopier(preserve_annos).copy(node) class SymbolRenamer(gast.NodeTransformer): @@ -78,7 +84,11 @@ class SymbolRenamer(gast.NodeTransformer): def _process(self, node): qn = anno.getanno(node, anno.Basic.QN) if qn in self.name_map: - return gast.Name(str(self.name_map[qn]), node.ctx, None) + new_node = gast.Name(str(self.name_map[qn]), node.ctx, None) + # All annotations get carried over. + for k in anno.keys(node): + anno.copyanno(node, new_node, k) + return new_node return self.generic_visit(node) def visit_Name(self, node): @@ -92,6 +102,7 @@ class SymbolRenamer(gast.NodeTransformer): def rename_symbols(node, name_map): + """Renames symbols in an AST. Requires qual_names annotations.""" renamer = SymbolRenamer(name_map) if isinstance(node, list): return [renamer.visit(n) for n in node] @@ -101,6 +112,7 @@ def rename_symbols(node, name_map): def keywords_to_dict(keywords): + """Converts a list of ast.keyword objects to a dict.""" keys = [] values = [] for kw in keywords: @@ -110,10 +122,7 @@ def keywords_to_dict(keywords): class PatternMatcher(gast.NodeVisitor): - """Matches a node against a pattern represented by a node. - - The pattern may contain wildcards represented by the symbol '_'. - """ + """Matches a node against a pattern represented by a node.""" def __init__(self, pattern): self.pattern = pattern @@ -177,9 +186,128 @@ class PatternMatcher(gast.NodeVisitor): def matches(node, pattern): + """Basic pattern matcher for AST. + + The pattern may contain wildcards represented by the symbol '_'. A node + matches a pattern if for every node in the tree, either there is a node of + the same type in pattern, or a Name node with id='_'. + + Args: + node: ast.AST + pattern: ast.AST + Returns: + bool + """ if isinstance(pattern, str): pattern = parser.parse_expression(pattern) matcher = PatternMatcher(pattern) matcher.visit(node) return matcher.matches + +# TODO(mdan): Once we have error tracing, we may be able to just go to SSA. +def apply_to_single_assignments(targets, values, apply_fn): + """Applies a function to each individual assignment. + + This function can process a possibly-unpacked (e.g. a, b = c, d) assignment. + It tries to break down the unpacking if possible. In effect, it has the same + effect as passing the assigned values in SSA form to apply_fn. + + Examples: + + The following will result in apply_fn(a, c), apply_fn(b, d): + + a, b = c, d + + The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]): + + a, b = c + + The following will result in apply_fn(a, (b, c)): + + a = b, c + + It uses the visitor pattern to allow subclasses to process single + assignments individually. + + Args: + targets: Union[List[ast.AST, ...], Tuple[ast.AST, ...], ast.AST, should be + used with the targets field of an ast.Assign node + values: ast.AST + apply_fn: Callable[[ast.AST, ast.AST], None], called with the + respective nodes of each single assignment + """ + if not isinstance(targets, (list, tuple)): + targets = (targets,) + for target in targets: + if isinstance(target, (gast.Tuple, gast.List)): + for i in range(len(target.elts)): + target_el = target.elts[i] + if isinstance(values, (gast.Tuple, gast.List)): + value_el = values.elts[i] + else: + idx = parser.parse_expression(str(i)) + value_el = gast.Subscript(values, gast.Index(idx), ctx=gast.Load()) + apply_to_single_assignments(target_el, value_el, apply_fn) + else: + apply_fn(target, values) + + +def parallel_walk(node, other): + """Walks two ASTs in parallel. + + The two trees must have identical structure. + + Args: + node: Union[ast.AST, Iterable[ast.AST]] + other: Union[ast.AST, Iterable[ast.AST]] + Yields: + Tuple[ast.AST, ast.AST] + Raises: + ValueError: if the two trees don't have identical structure. + """ + if isinstance(node, (list, tuple)): + node_stack = list(node) + else: + node_stack = [node] + + if isinstance(other, (list, tuple)): + other_stack = list(other) + else: + other_stack = [other] + + while node_stack and other_stack: + assert len(node_stack) == len(other_stack) + n = node_stack.pop() + o = other_stack.pop() + + if (not isinstance(n, (ast.AST, gast.AST)) or + not isinstance(o, (ast.AST, gast.AST)) or + n.__class__.__name__ != o.__class__.__name__): + raise ValueError('inconsistent nodes: {} and {}'.format(n, o)) + + yield n, o + + for f in n._fields: + n_child = getattr(n, f, None) + o_child = getattr(o, f, None) + if f.startswith('__') or n_child is None or o_child is None: + continue + + if isinstance(n_child, (list, tuple)): + if (not isinstance(o_child, (list, tuple)) or + len(n_child) != len(o_child)): + raise ValueError( + 'inconsistent values for field {}: {} and {}'.format( + f, n_child, o_child)) + node_stack.extend(n_child) + other_stack.extend(o_child) + + elif isinstance(n_child, (gast.AST, ast.AST)): + node_stack.append(n_child) + other_stack.append(o_child) + + elif n_child != o_child: + raise ValueError( + 'inconsistent values for field {}: {} and {}'.format( + f, n_child, o_child)) diff --git a/tensorflow/contrib/autograph/pyct/ast_util_test.py b/tensorflow/contrib/autograph/pyct/ast_util_test.py index 3afa04a50685d19c90944c14ed39f9d3ad35e486..2293c89720a54f7495670c6f28b00f716cad70db 100644 --- a/tensorflow/contrib/autograph/pyct/ast_util_test.py +++ b/tensorflow/contrib/autograph/pyct/ast_util_test.py @@ -19,7 +19,10 @@ from __future__ import division from __future__ import print_function import ast +import collections +import textwrap +from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import compiler from tensorflow.contrib.autograph.pyct import parser @@ -29,62 +32,75 @@ from tensorflow.python.platform import test class AstUtilTest(test.TestCase): - def test_rename_symbols(self): - node = ast.Tuple([ - ast.Name('a', ast.Load()), - ast.Name('b', ast.Load()), - ast.Attribute(ast.Name('b', None), 'c', ast.Store()), - ast.Attribute( - ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd', None) - ], None) + def setUp(self): + super(AstUtilTest, self).setUp() + self._invocation_counts = collections.defaultdict(lambda: 0) + + def test_rename_symbols_basic(self): + node = parser.parse_str('a + b') node = qual_names.resolve(node) + node = ast_util.rename_symbols( - node, { - qual_names.QN('a'): - qual_names.QN('renamed_a'), - qual_names.QN(qual_names.QN('b'), attr='c'): - qual_names.QN('renamed_b_c'), - }) - - self.assertEqual(node.elts[0].id, 'renamed_a') - self.assertTrue(isinstance(node.elts[0].ctx, ast.Load)) - self.assertEqual(node.elts[1].id, 'b') - self.assertEqual(node.elts[2].id, 'renamed_b_c') - self.assertTrue(isinstance(node.elts[2].ctx, ast.Store)) - self.assertEqual(node.elts[3].value.id, 'renamed_b_c') - self.assertTrue(isinstance(node.elts[3].value.ctx, ast.Load)) + node, {qual_names.QN('a'): qual_names.QN('renamed_a')}) + + self.assertIsInstance(node.body[0].value.left.id, str) + source = compiler.ast_to_source(node) + self.assertEqual(source.strip(), 'renamed_a + b') + + def test_rename_symbols_attributes(self): + node = parser.parse_str('b.c = b.c.d') + node = qual_names.resolve(node) + + node = ast_util.rename_symbols( + node, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}) + + source = compiler.ast_to_source(node) + self.assertEqual(source.strip(), 'renamed_b_c = renamed_b_c.d') + + def test_rename_symbols_annotations(self): + node = parser.parse_str('a[i]') + node = qual_names.resolve(node) + anno.setanno(node, 'foo', 'bar') + orig_anno = anno.getanno(node, 'foo') + + node = ast_util.rename_symbols(node, + {qual_names.QN('a'): qual_names.QN('b')}) + + self.assertIs(anno.getanno(node, 'foo'), orig_anno) def test_copy_clean(self): - ret = ast.Return( - ast.BinOp( - op=ast.Add(), - left=ast.Name(id='a', ctx=ast.Load()), - right=ast.Num(1))) - setattr(ret, '__foo', 'bar') - node = ast.FunctionDef( - name='f', - args=ast.arguments( - args=[ast.Name(id='a', ctx=ast.Param())], - vararg=None, - kwarg=None, - defaults=[]), - body=[ret], - decorator_list=[], - returns=None) + node = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + 1 + """)) + setattr(node.body[0], '__foo', 'bar') new_node = ast_util.copy_clean(node) - self.assertFalse(node is new_node) - self.assertFalse(ret is new_node.body[0]) + self.assertIsNot(new_node, node) + self.assertIsNot(new_node.body[0], node.body[0]) self.assertFalse(hasattr(new_node.body[0], '__foo')) + def test_copy_clean_preserves_annotations(self): + node = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + 1 + """)) + anno.setanno(node.body[0], 'foo', 'bar') + anno.setanno(node.body[0], 'baz', 1) + new_node = ast_util.copy_clean(node, preserve_annos={'foo'}) + self.assertEqual(anno.getanno(new_node.body[0], 'foo'), 'bar') + self.assertFalse(anno.hasanno(new_node.body[0], 'baz')) + def test_keywords_to_dict(self): keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords d = ast_util.keywords_to_dict(keywords) # Make sure we generate a usable dict node by attaching it to a variable and # compiling everything. - output = parser.parse_str('b = 3') - output.body += (ast.Assign([ast.Name(id='d', ctx=ast.Store())], d),) - result, _ = compiler.ast_to_object(output) - self.assertDictEqual(result.d, {'a': 3, 'c': 1, 'd': 'e'}) + node = parser.parse_str('def f(b): pass').body[0] + node.body.append(ast.Return(d)) + result, _ = compiler.ast_to_object(node) + self.assertDictEqual(result.f(3), {'a': 3, 'c': 1, 'd': 'e'}) def assertMatch(self, target_str, pattern_str): node = parser.parse_expression(target_str) @@ -113,6 +129,68 @@ class AstUtilTest(test.TestCase): self.assertNoMatch('super(Foo, self).__init__()', 'super(Bar, _).__init__(_)') + def _mock_apply_fn(self, target, source): + target = compiler.ast_to_source(target) + source = compiler.ast_to_source(source) + self._invocation_counts[(target.strip(), source.strip())] += 1 + + def test_apply_to_single_assignments_dynamic_unpack(self): + node = parser.parse_str('a, b, c = d') + node = node.body[0] + ast_util.apply_to_single_assignments(node.targets, node.value, + self._mock_apply_fn) + self.assertDictEqual(self._invocation_counts, { + ('a', 'd[0]'): 1, + ('b', 'd[1]'): 1, + ('c', 'd[2]'): 1, + }) + + def test_apply_to_single_assignments_static_unpack(self): + node = parser.parse_str('a, b, c = d, e, f') + node = node.body[0] + ast_util.apply_to_single_assignments(node.targets, node.value, + self._mock_apply_fn) + self.assertDictEqual(self._invocation_counts, { + ('a', 'd'): 1, + ('b', 'e'): 1, + ('c', 'f'): 1, + }) + + def test_parallel_walk(self): + node = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + 1 + """)) + for child_a, child_b in ast_util.parallel_walk(node, node): + self.assertEqual(child_a, child_b) + + def test_parallel_walk_inconsistent_trees(self): + node_1 = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + 1 + """)) + node_2 = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + (a * 2) + """)) + node_3 = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + 2 + """)) + with self.assertRaises(ValueError): + for _ in ast_util.parallel_walk(node_1, node_2): + pass + # There is not particular reason to reject trees that differ only in the + # value of a constant. + # TODO(mdan): This should probably be allowed. + with self.assertRaises(ValueError): + for _ in ast_util.parallel_walk(node_1, node_3): + pass + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/cfg.py b/tensorflow/contrib/autograph/pyct/cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..ba51dcf285036220e01b89e8beeb9aec8ffe36be --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/cfg.py @@ -0,0 +1,815 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Control flow graph (CFG) structure for Python AST representation. + +The CFG is a digraph with edges representing valid control flow. Each +node is associated with exactly one AST node, but not all AST nodes may have +a corresponding CFG counterpart. + +Once built, the CFG itself is immutable, but the values it holds need not be; +they are usually annotated with information extracted by walking the graph. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +from enum import Enum + +# pylint:disable=g-bad-import-order +import gast +# pylint:enable=g-bad-import-order + +from tensorflow.contrib.autograph.pyct import compiler + + +class Node(object): + """A node in the CFG. + + Although new instances of this class are mutable, the objects that a user + finds in the CFG are typically not. + + The nodes represent edges in the CFG graph, and maintain pointers to allow + efficient walking in both forward and reverse order. The following property + holds for all nodes: "child in node.next" iff "node in child.prev". + + Attributes: + next: FrozenSet[Node, ...], the nodes that follow this node, in control + flow order + prev: FrozenSet[Node, ...], the nodes that precede this node, in reverse + control flow order + ast_node: ast.AST, the AST node corresponding to this CFG node + """ + + def __init__(self, next_, prev, ast_node): + self.next = next_ + self.prev = prev + self.ast_node = ast_node + + def freeze(self): + self.next = frozenset(self.next) + self.prev = frozenset(self.prev) + + def __repr__(self): + if isinstance(self.ast_node, gast.FunctionDef): + return 'def %s' % self.ast_node.name + elif isinstance(self.ast_node, gast.withitem): + return compiler.ast_to_source(self.ast_node.context_expr).strip() + return compiler.ast_to_source(self.ast_node).strip() + + +class Graph( + collections.namedtuple( + 'Graph', + ['entry', 'exit', 'error', 'index', 'stmt_prev', 'stmt_next'])): + """A Control Flow Graph. + + The CFG maintains an index to allow looking up a CFG node by the AST node to + which it is associated. The index can also be enumerated in top-down, depth + first order. + + Walking the graph in forward or reverse order is supported by double + parent-child links. + + Note: the error nodes are not wired to their corresponding finally guards, + because these are shared, and wiring them would create a reverse path from + normal control flow into the error nodes, which we want to avoid. + + The graph also maintains edges corresponding to higher level statements + like for-else loops. A node is considered successor of a statement if there + is an edge from a node that is lexically a child of that statement to a node + that is not. Statement predecessors are analogously defined. + + Attributes: + entry: Node, the entry node + exit: FrozenSet[Node, ...], the exit nodes + error: FrozenSet[Node, ...], nodes that exit due to an explicitly raised + error (errors propagated from function calls are not accounted) + index: Dict[ast.Node, Node], mapping AST nodes to the respective CFG + node + stmt_prev: Dict[ast.Node, FrozenSet[Node, ...]], mapping statement AST + nodes to their predecessor CFG nodes + stmt_next: Dict[ast.Node, FrozenSet[Node, ...]], mapping statement AST + nodes to their successor CFG nodes + """ + + def __repr__(self): + result = 'digraph CFG {\n' + for node in self.index.values(): + result += ' %s [label="%s"];\n' % (id(node), node) + for node in self.index.values(): + for next_ in node.next: + result += ' %s -> %s;\n' % (id(node), id(next_)) + result += '}' + return result + + +class _WalkMode(Enum): + FORWARD = 1 + REVERSE = 2 + + +# TODO(mdan): Rename to DataFlowAnalyzer. +# TODO(mdan): Consider specializations that use gen/kill/transfer abstractions. +class GraphVisitor(object): + """Base class for a CFG visitors. + + This implementation is not thread safe. + + The visitor has some facilities to simplify dataflow analyses. In particular, + it allows revisiting the nodes at the decision of the subclass. This can be + used to visit the graph until the state reaches a fixed point. + + For more details on dataflow analysis, see + https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec02-Dataflow.pdf + + Note: the literature generally suggests visiting successor nodes only when the + state of the current node changed, regardless of whether that successor has + ever been visited. This implementation visits every successor at least once. + + Attributes: + graph: Graph + in_: Dict[Node, Any], stores node-keyed state during a visit + out: Dict[Node, Any], stores node-keyed state during a visit + """ + + def __init__(self, graph): + self.graph = graph + self.reset() + + def init_state(self, node): + """State initialization function. Optional to overload. + + An in/out state slot will be created for each node in the graph. Subclasses + must overload this to control what that is initialized to. + + Args: + node: Node + """ + raise NotImplementedError('Subclasses must implement this.') + + # TODO(mdan): Rename to flow? + def visit_node(self, node): + """Visitor function. + + Args: + node: Node + Returns: + bool, whether the node should be revisited; subclasses can visit every + reachable node exactly once by always returning False + """ + raise NotImplementedError('Subclasses must implement this.') + + def reset(self): + self.in_ = { + node: self.init_state(node) for node in self.graph.index.values() + } + self.out = { + node: self.init_state(node) for node in self.graph.index.values() + } + + def _visit_internal(self, mode): + """Visits the CFG, depth-first.""" + assert mode in (_WalkMode.FORWARD, _WalkMode.REVERSE) + if mode == _WalkMode.FORWARD: + open_ = [self.graph.entry] + elif mode == _WalkMode.REVERSE: + open_ = list(self.graph.exit) + closed = set() + + while open_: + node = open_.pop(0) + closed.add(node) + + should_revisit = self.visit_node(node) + + if mode == _WalkMode.FORWARD: + children = node.next + elif mode == _WalkMode.REVERSE: + children = node.prev + + for next_ in children: + if should_revisit or next_ not in closed: + open_.append(next_) + + def visit_forward(self): + self._visit_internal(_WalkMode.FORWARD) + + def visit_reverse(self): + self._visit_internal(_WalkMode.REVERSE) + + +class GraphBuilder(object): + """Builder that constructs a CFG from a given AST. + + This GraphBuilder facilitates constructing the DAG that forms the CFG when + nodes + are supplied in lexical order (i.e., top-down, depth first). Under these + conditions, it supports building patterns found in typical structured + programs. + + This builder ignores the flow generated by exceptions, which are assumed to + always be catastrophic and present purely for diagnostic purposes (e.g. to + print debug information). Statements like raise and try/catch sections are + allowed and will generate control flow edges, but ordinaty statements are + assumed not to raise exceptions. + + Finally sections are also correctly interleaved between break/continue/return + nodes and their subsequent statements. + + Important concepts: + * nodes - nodes refer refer to CFG nodes; AST nodes are qualified explicitly + * leaf set - since the graph is constructed gradually, a leaf set maintains + the CFG nodes that will precede the node that the builder expects to + receive next; when an ordinary node is added, it is connected to the + existing leaves and it in turn becomes the new leaf + * jump nodes - nodes that should generate edges other than what + ordinary nodes would; these correspond to break, continue and return + statements + * sections - logical delimiters for subgraphs that require special + edges; there are various types of nodes, each admitting various + types of jump nodes; sections are identified by their corresponding AST + node + """ + + # TODO(mdan): Perhaps detail this in a markdown doc. + # TODO(mdan): Add exception support. + + def __init__(self, parent_ast_node): + self.reset() + self.parent = parent_ast_node + + def reset(self): + """Resets the state of this factory.""" + self.head = None + self.errors = set() + self.node_index = collections.OrderedDict() + + # TODO(mdan): Too many primitives. Use classes. + self.leaves = set() + + # Note: This mechanism requires that nodes are added in lexical order (top + # to bottom, depth first). + self.active_stmts = set() + self.owners = {} # type: Set[any] + self.forward_edges = set() # type: Tuple[Node, Node] # (from, to) + + self.finally_sections = {} + # Dict values represent (entry, exits) + self.finally_section_subgraphs = { + } # type: Dict[ast.AST, Tuple[Node, Set[Node]]] + # Whether the guard section can be reached from the statement that precedes + # it. + self.finally_section_has_direct_flow = {} + # Finally sections that await their first node. + self.pending_finally_sections = set() + + # Exit jumps keyed by the section they affect. + self.exits = {} + + # The entry of loop sections, keyed by the section. + self.section_entry = {} + # Continue jumps keyed by the section they affect. + self.continues = {} + + # The entry of conditional sections, keyed by the section. + self.cond_entry = {} + # Lists of leaf nodes corresponding to each branch in the section. + self.cond_leaves = {} + + def _connect_nodes(self, first, second): + """Connects nodes to signify that control flows from first to second. + + Args: + first: Union[Set[Node, ...], Node] + second: Node + """ + if isinstance(first, Node): + first.next.add(second) + second.prev.add(first) + self.forward_edges.add((first, second)) + else: + for node in first: + self._connect_nodes(node, second) + + def _add_new_node(self, ast_node): + """Grows the graph by adding a CFG node following the current leaves.""" + if ast_node is self.node_index: + raise ValueError('%s added twice' % ast_node) + node = Node(next_=set(), prev=set(), ast_node=ast_node) + self.node_index[ast_node] = node + self.owners[node] = frozenset(self.active_stmts) + + if self.head is None: + self.head = node + + for leaf in self.leaves: + self._connect_nodes(leaf, node) + + # If any finally section awaits its first node, populate it. + for section_id in self.pending_finally_sections: + self.finally_section_subgraphs[section_id][0] = node + self.pending_finally_sections = set() + + return node + + def begin_statement(self, stmt): + """Marks the beginning of a statement. + + Args: + stmt: Hashable, a key by which the statement can be identified in + the CFG's stmt_prev and stmt_next attributes + """ + self.active_stmts.add(stmt) + + def end_statement(self, stmt): + """Marks the end of a statement. + + Args: + stmt: Hashable, a key by which the statement can be identified in + the CFG's stmt_prev and stmt_next attributes; must match a key + previously passed to begin_statement. + """ + self.active_stmts.remove(stmt) + + def add_ordinary_node(self, ast_node): + """Grows the graph by adding an ordinary CFG node. + + Ordinary nodes are followed by the next node, in lexical order, that is, + they become the new leaf set. + + Args: + ast_node: ast.AST + Returns: + Node + """ + node = self._add_new_node(ast_node) + self.leaves = set((node,)) + return node + + def _add_jump_node(self, ast_node, guards): + """Grows the graph by adding a jump node. + + Jump nodes are added to the current leaf set, and the leaf set becomes + empty. If the jump node is the last in a cond section, then it may be added + back to the leaf set by a separate mechanism. + + Args: + ast_node: ast.AST + guards: Tuple[ast.AST, ...], the finally sections active for this node + Returns: + Node + """ + node = self._add_new_node(ast_node) + self.leaves = set() + # The guards themselves may not yet be complete, and will be wired later. + self.finally_sections[node] = guards + return node + + def _connect_jump_to_finally_sections(self, node): + """Connects a jump node to the finally sections protecting it.""" + cursor = set((node,)) + for guard_section_id in self.finally_sections[node]: + guard_begin, guard_ends = self.finally_section_subgraphs[guard_section_id] + self._connect_nodes(cursor, guard_begin) + cursor = guard_ends + del self.finally_sections[node] + # TODO(mdan): Should garbage-collect finally_section_subgraphs. + return cursor + + def add_exit_node(self, ast_node, section_id, guards): + """Grows the graph by adding an exit node. + + This node becomes an exit for the current section. + + Args: + ast_node: ast.AST + section_id: Hashable, the node for which ast_node should be considered + to be an exit node + guards: Tuple[ast.AST, ...], the finally sections that guard ast_node + """ + node = self._add_jump_node(ast_node, guards) + self.exits[section_id].add(node) + + def add_continue_node(self, ast_node, section_id, guards): + """Grows the graph by adding a reentry node. + + This node causes control flow to go back to the loop section's entry. + + Args: + ast_node: ast.AST + section_id: Hashable, the node for which ast_node should be considered + to be an exit node + guards: Tuple[ast.AST, ...], the finally sections that guard ast_node + """ + node = self._add_jump_node(ast_node, guards) + self.continues[section_id].add(node) + + def add_error_node(self, ast_node, guards): + """Grows the graph by adding an error node. + + This node becomes an exit for the entire graph. + + Args: + ast_node: ast.AST + guards: Tuple[ast.AST, ...], the finally sections that guard ast_node + """ + node = self._add_jump_node(ast_node, guards) + self.errors.add(node) + self.leaves = set() + + def enter_section(self, section_id): + """Enters a regular section. + + Regular sections admit exit jumps, which end the section. + + Args: + section_id: Hashable, the same node that will be used in calls to the + ast_node arg passed to add_exit_node + """ + assert section_id not in self.exits + self.exits[section_id] = set() + + def exit_section(self, section_id): + """Exits a regular section.""" + + # Exits are jump nodes, which may be protected. + for exit_ in self.exits[section_id]: + self.leaves |= self._connect_jump_to_finally_sections(exit_) + + del self.exits[section_id] + + def enter_loop_section(self, section_id, entry_node): + """Enters a loop section. + + Loop sections define an entry node. The end of the section always flows back + to the entry node. These admit continue jump nodes which also flow to the + entry node. + + Args: + section_id: Hashable, the same node that will be used in calls to the + ast_node arg passed to add_continue_node + entry_node: ast.AST, the entry node into the loop (e.g. the test node + for while loops) + """ + assert section_id not in self.section_entry + assert section_id not in self.continues + self.continues[section_id] = set() + node = self.add_ordinary_node(entry_node) + self.section_entry[section_id] = node + + def exit_loop_section(self, section_id): + """Exits a loop section.""" + self._connect_nodes(self.leaves, self.section_entry[section_id]) + + # continues are jump nodes, which may be protected. + for reentry in self.continues[section_id]: + guard_ends = self._connect_jump_to_finally_sections(reentry) + self._connect_nodes(guard_ends, self.section_entry[section_id]) + + # Loop nodes always loop back. + self.leaves = set((self.section_entry[section_id],)) + + del self.continues[section_id] + del self.section_entry[section_id] + + def enter_cond_section(self, section_id): + """Enters a conditional section. + + Conditional sections define an entry node, and one or more branches. + + Args: + section_id: Hashable, the same node that will be used in calls to the + section_id arg passed to new_cond_branch + """ + + assert section_id not in self.cond_entry + assert section_id not in self.cond_leaves + self.cond_leaves[section_id] = [] + + def new_cond_branch(self, section_id): + """Begins a new branch in a cond section.""" + assert section_id in self.cond_leaves + + if section_id in self.cond_entry: + # Subsequent splits move back to the split point, and memorize the + # current leaves. + self.cond_leaves[section_id].append(self.leaves) + self.leaves = self.cond_entry[section_id] + else: + # If this is the first time we split a section, just remember the split + # point. + self.cond_entry[section_id] = self.leaves + + def exit_cond_section(self, section_id): + """Exits a conditional section.""" + for split in self.cond_leaves[section_id]: + self.leaves |= split + del self.cond_entry[section_id] + del self.cond_leaves[section_id] + + def enter_finally_section(self, section_id): + """Enters a finally section.""" + # TODO(mdan): This, not the caller, should track the active sections. + self.finally_section_subgraphs[section_id] = [None, None] + if self.leaves: + self.finally_section_has_direct_flow[section_id] = True + else: + self.finally_section_has_direct_flow[section_id] = False + self.pending_finally_sections.add(section_id) + + def exit_finally_section(self, section_id): + """Exits a finally section.""" + assert section_id not in self.pending_finally_sections, 'Empty finally?' + self.finally_section_subgraphs[section_id][1] = self.leaves + # If the guard can only be reached by a jump, then it will not flow + # into the statement that follows it. + if not self.finally_section_has_direct_flow[section_id]: + self.leaves = set() + del self.finally_section_has_direct_flow[section_id] + + def build(self): + """Returns the CFG accumulated so far and resets the builder. + + Returns: + Graph + """ + # Freeze the nodes. + for node in self.node_index.values(): + node.freeze() + + # Build the statement edges. + stmt_next = {} + stmt_prev = {} + for node, _ in self.forward_edges: + for stmt in self.owners[node]: + if stmt not in stmt_next: + stmt_next[stmt] = set() + if stmt not in stmt_prev: + stmt_prev[stmt] = set() + for first, second in self.forward_edges: + stmts_exited = self.owners[first] - self.owners[second] + for stmt in stmts_exited: + stmt_next[stmt].add(second) + stmts_entered = self.owners[second] - self.owners[first] + for stmt in stmts_entered: + stmt_prev[stmt].add(first) + for stmt in stmt_next: + stmt_next[stmt] = frozenset(stmt_next[stmt]) + for stmt in stmt_prev: + stmt_prev[stmt] = frozenset(stmt_prev[stmt]) + + # Construct the final graph object. + result = Graph( + entry=self.head, + exit=self.leaves, + error=self.errors, + index=self.node_index, + stmt_prev=stmt_prev, + stmt_next=stmt_next) + + # Reset the state. + self.reset() + + return result + + +class AstToCfg(gast.NodeVisitor): + """Converts an AST to CFGs. + + A separate CFG will be constructed for each function. + """ + + def __init__(self): + super(AstToCfg, self).__init__() + + self.builder_stack = [] + self.builder = None + self.cfgs = {} + + self.lexical_scopes = [] + + def _enter_lexical_scope(self, node): + self.lexical_scopes.append(node) + + def _exit_lexical_scope(self, node): + leaving_node = self.lexical_scopes.pop() + assert node == leaving_node + + def _get_enclosing_scopes(self, include, stop_at): + included = [] + for node in reversed(self.lexical_scopes): + if isinstance(node, include): + included.append(node) + if isinstance(node, stop_at): + return node, included + return None, included + + def _process_basic_statement(self, node): + self.generic_visit(node) + self.builder.add_ordinary_node(node) + + def _process_exit_statement(self, node, *exits_nodes_of_type): + # Note: this is safe because we process functions separately. + try_node, guards = self._get_enclosing_scopes( + include=(gast.Try,), + stop_at=tuple(exits_nodes_of_type), + ) + if try_node is None: + raise ValueError( + '%s that is not enclosed by any of %s' % (node, exits_nodes_of_type)) + self.builder.add_exit_node(node, try_node, guards) + + def _process_continue_statement(self, node, *loops_to_nodes_of_type): + # Note: this is safe because we process functions separately. + try_node, guards = self._get_enclosing_scopes( + include=(gast.Try,), + stop_at=tuple(loops_to_nodes_of_type), + ) + if try_node is None: + raise ValueError('%s that is not enclosed by any of %s' % + (node, loops_to_nodes_of_type)) + self.builder.add_continue_node(node, try_node, guards) + + def visit_FunctionDef(self, node): + # We also keep the FunctionDef node in the CFG. This allows us to determine + # things like reaching definitions via closure. Note that the function body + # will be stored in a separate graph, because function definitions are not + # the same as function calls. + if self.builder is not None: + self.builder.add_ordinary_node(node) + + self.builder_stack.append(self.builder) + self.builder = GraphBuilder(node) + + self._enter_lexical_scope(node) + self.builder.enter_section(node) + + self._process_basic_statement(node.args) + for stmt in node.body: + self.visit(stmt) + + self.builder.exit_section(node) + self._exit_lexical_scope(node) + + self.cfgs[node] = self.builder.build() + self.builder = self.builder_stack.pop() + + def visit_Lambda(self, node): + # TODO(mdan): Treat like FunctionDef? That would be a separate CFG. + raise NotImplementedError() + + def visit_Return(self, node): + self._process_exit_statement(node, gast.FunctionDef) + + def visit_Expr(self, node): + self._process_basic_statement(node) + + def visit_Assign(self, node): + self._process_basic_statement(node) + + def visit_AnnAssign(self, node): + self._process_basic_statement(node) + + def visit_AugAssign(self, node): + self._process_basic_statement(node) + + def visit_Print(self, node): + self._process_basic_statement(node) + + def visit_Raise(self, node): + try_node, guards = self._get_enclosing_scopes( + include=(gast.Try,), + stop_at=(gast.FunctionDef,), + ) + if try_node is None: + raise ValueError('%s that is not enclosed by any FunctionDef' % node) + self.builder.add_error_node(node, guards) + + def visit_Assert(self, node): + # Ignoring the effect of exceptions. + self._process_basic_statement(node) + + def visit_Delete(self, node): + self._process_basic_statement(node) + + def visit_If(self, node): + # No need to track ifs as lexical scopes, for now. + # Lexical scopes are generally tracked in order to be able to resolve the + # targets of jump statements like break/continue/etc. Since there is no + # statement that can interrupt a conditional, we don't need to track their + # lexical scope. That may change in the future. + self.builder.begin_statement(node) + + self.builder.enter_cond_section(node) + self._process_basic_statement(node.test) + + self.builder.new_cond_branch(node) + for stmt in node.body: + self.visit(stmt) + + self.builder.new_cond_branch(node) + for stmt in node.orelse: + self.visit(stmt) + + self.builder.exit_cond_section(node) + self.builder.end_statement(node) + + def visit_While(self, node): + self.builder.begin_statement(node) + self._enter_lexical_scope(node) + + self.builder.enter_section(node) + + self.builder.enter_loop_section(node, node.test) + for stmt in node.body: + self.visit(stmt) + self.builder.exit_loop_section(node) + + # Note: although the orelse is technically part of the loop node, + # the statements inside it don't affect the loop itself. For example, a + # break in the loop's orelse will not affect the loop itself. + self._exit_lexical_scope(node) + + for stmt in node.orelse: + self.visit(stmt) + + self.builder.exit_section(node) + self.builder.end_statement(node) + + def visit_For(self, node): + self.builder.begin_statement(node) + self._enter_lexical_scope(node) + + self.builder.enter_section(node) + + # TODO(mdan): Strictly speaking, this should be node.target + node.iter. + # A blind dataflow analysis would have to process both node.target and + # node.iter to properly process read and write access. + self.builder.enter_loop_section(node, node.iter) + for stmt in node.body: + self.visit(stmt) + self.builder.exit_loop_section(node) + + # Note: although the orelse is technically part of the loop node, + # they don't count as loop bodies. For example, a break in the loop's + # orelse will affect the parent loop, not the current one. + self._exit_lexical_scope(node) + + for stmt in node.orelse: + self.visit(stmt) + + self.builder.exit_section(node) + self.builder.end_statement(node) + + def visit_Break(self, node): + self._process_exit_statement(node, gast.While, gast.For) + + def visit_Continue(self, node): + self._process_continue_statement(node, gast.While, gast.For) + + def visit_Try(self, node): + self._enter_lexical_scope(node) + + for stmt in node.body: + self.visit(stmt) + # Unlike loops, the orelse is a simple continuation of the body. + for stmt in node.orelse: + self.visit(stmt) + + if node.handlers: + # TODO(mdan): Should we still support bare try/except? Might be confusing. + raise NotImplementedError('exceptions are not yet supported') + + self._exit_lexical_scope(node) + + self.builder.enter_finally_section(node) + for stmt in node.finalbody: + self.visit(stmt) + self.builder.exit_finally_section(node) + + def visit_With(self, node): + # TODO(mdan): Mark the context manager's exit call as exit guard. + for item in node.items: + self._process_basic_statement(item) + for stmt in node.body: + self.visit(stmt) + + +def build(node): + visitor = AstToCfg() + visitor.visit(node) + return visitor.cfgs diff --git a/tensorflow/contrib/autograph/pyct/cfg_test.py b/tensorflow/contrib/autograph/pyct/cfg_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9d0a85d615cc5a7dcebf405aebdbfe409be0b5cf --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/cfg_test.py @@ -0,0 +1,969 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for cfg module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import cfg +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.python.platform import test + + +class CountingVisitor(cfg.GraphVisitor): + + def __init__(self, graph): + super(CountingVisitor, self).__init__(graph) + self.counts = {} + + def init_state(self, _): + return None + + def visit_node(self, node): + self.counts[node.ast_node] = self.counts.get(node.ast_node, 0) + 1 + return False # visit only once + + +class GraphVisitorTest(test.TestCase): + + def _build_cfg(self, fn): + node, _ = parser.parse_entity(fn) + cfgs = cfg.build(node) + return cfgs, node + + def test_basic_coverage_forward(self): + + def test_fn(a): + while a > 0: + a = 1 + break + return a # pylint:disable=unreachable + a = 2 + + graphs, node = self._build_cfg(test_fn) + graph, = graphs.values() + visitor = CountingVisitor(graph) + visitor.visit_forward() + fn_node = node.body[0] + + self.assertEqual(visitor.counts[fn_node.args], 1) + self.assertEqual(visitor.counts[fn_node.body[0].test], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1) + # The return node should be unreachable in forward direction. + self.assertTrue(fn_node.body[0].body[2] not in visitor.counts) + self.assertEqual(visitor.counts[fn_node.body[1]], 1) + + def test_basic_coverage_reverse(self): + + def test_fn(a): + while a > 0: + a = 1 + break + return a # pylint:disable=unreachable + a = 2 + + graphs, node = self._build_cfg(test_fn) + graph, = graphs.values() + visitor = CountingVisitor(graph) + visitor.visit_reverse() + fn_node = node.body[0] + + self.assertEqual(visitor.counts[fn_node.args], 1) + self.assertEqual(visitor.counts[fn_node.body[0].test], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1) + self.assertTrue(visitor.counts[fn_node.body[0].body[2]], 1) + self.assertEqual(visitor.counts[fn_node.body[1]], 1) + + +class AstToCfgTest(test.TestCase): + + def _build_cfg(self, fn): + node, _ = parser.parse_entity(fn) + cfgs = cfg.build(node) + return cfgs + + def _repr_set(self, node_set): + return frozenset(repr(n) for n in node_set) + + def _as_set(self, elements): + if elements is None: + return frozenset() + elif isinstance(elements, str): + return frozenset((elements,)) + else: + return frozenset(elements) + + def assertGraphMatches(self, graph, edges): + """Tests whether the CFG contains the specified edges.""" + for prev, node_repr, next_ in edges: + matched = False + for cfg_node in graph.index.values(): + if repr(cfg_node) == node_repr: + if (self._as_set(prev) == frozenset(map(repr, cfg_node.prev)) and + self._as_set(next_) == frozenset(map(repr, cfg_node.next))): + matched = True + break + if not matched: + self.fail( + 'match failed for node "%s" in graph:\n%s' % (node_repr, graph)) + + def assertStatementEdges(self, graph, edges): + """Tests whether the CFG contains the specified statement edges.""" + for prev_node_reprs, node_repr, next_node_reprs in edges: + matched = False + partial_matches = [] + self.assertSetEqual( + frozenset(graph.stmt_next.keys()), frozenset(graph.stmt_prev.keys())) + for stmt_ast_node in graph.stmt_next: + ast_repr = '%s:%s' % (stmt_ast_node.__class__.__name__, + stmt_ast_node.lineno) + if ast_repr == node_repr: + actual_next = frozenset(map(repr, graph.stmt_next[stmt_ast_node])) + actual_prev = frozenset(map(repr, graph.stmt_prev[stmt_ast_node])) + partial_matches.append((actual_prev, node_repr, actual_next)) + if (self._as_set(prev_node_reprs) == actual_prev and + self._as_set(next_node_reprs) == actual_next): + matched = True + break + if not matched: + self.fail('edges mismatch for %s: %s' % (node_repr, partial_matches)) + + def test_straightline(self): + + def test_fn(a): + a += 1 + a = 2 + a = 3 + return + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', 'a += 1'), + ('a += 1', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', 'return'), + ('a = 3', 'return', None), + ), + ) + + def test_straightline_no_return(self): + + def test_fn(a, b): + a = b + 1 + a += max(a) + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a, b', 'a = b + 1'), + ('a = b + 1', 'a += max(a)', None), + ), + ) + + def test_unreachable_code(self): + + def test_fn(a): + return + a += 1 # pylint:disable=unreachable + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', 'return'), + ('a', 'return', None), + (None, 'a += 1', None), + ), + ) + + def test_if_straightline(self): + + def test_fn(a): + if a > 0: + a = 1 + else: + a += -1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', '(a > 0)'), + ('(a > 0)', 'a = 1', None), + ('(a > 0)', 'a += -1', None), + ), + ) + self.assertStatementEdges( + graph, + (('a', 'If:2', None),), + ) + + def test_branch_nested(self): + + def test_fn(a): + if a > 0: + if a > 1: + a = 1 + else: + a = 2 + else: + if a > 2: + a = 3 + else: + a = 4 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', '(a > 0)'), + ('a', '(a > 0)', ('(a > 1)', '(a > 2)')), + ('(a > 0)', '(a > 1)', ('a = 1', 'a = 2')), + ('(a > 1)', 'a = 1', None), + ('(a > 1)', 'a = 2', None), + ('(a > 0)', '(a > 2)', ('a = 3', 'a = 4')), + ('(a > 2)', 'a = 3', None), + ('(a > 2)', 'a = 4', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'If:2', None), + ('(a > 0)', 'If:3', None), + ('(a > 0)', 'If:8', None), + ), + ) + + def test_branch_straightline_semi(self): + + def test_fn(a): + if a > 0: + a = 1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', '(a > 0)'), + ('a', '(a > 0)', 'a = 1'), + ('(a > 0)', 'a = 1', None), + ), + ) + self.assertStatementEdges( + graph, + (('a', 'If:2', None),), + ) + + def test_branch_return(self): + + def test_fn(a): + if a > 0: + return + else: + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', '(a > 0)', ('return', 'a = 1')), + ('(a > 0)', 'a = 1', 'a = 2'), + ('(a > 0)', 'return', None), + ('a = 1', 'a = 2', None), + ), + ) + self.assertStatementEdges( + graph, + (('a', 'If:2', 'a = 2'),), + ) + + def test_branch_return_minimal(self): + + def test_fn(a): + if a > 0: + return + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', '(a > 0)', 'return'), + ('(a > 0)', 'return', None), + ), + ) + self.assertStatementEdges( + graph, + (('a', 'If:2', None),), + ) + + def test_while_straightline(self): + + def test_fn(a): + while a > 0: + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('a = 1', 'a = 2')), + ('(a > 0)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', None), + ), + ) + self.assertStatementEdges( + graph, + (('a', 'While:2', 'a = 2'),), + ) + + def test_while_else_straightline(self): + + def test_fn(a): + while a > 0: + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('a = 1', 'a = 2')), + ('(a > 0)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + (('a', 'While:2', 'a = 3'),), + ) + + def test_while_else_continue(self): + + def test_fn(a): + while a > 0: + if a > 1: + continue + else: + a = 0 + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'continue', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')), + ('(a > 0)', '(a > 1)', ('continue', 'a = 0')), + ('(a > 1)', 'continue', '(a > 0)'), + ('a = 0', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'While:2', 'a = 3'), + ('(a > 0)', 'If:3', ('a = 1', '(a > 0)')), + ), + ) + + def test_while_else_break(self): + + def test_fn(a): + while a > 0: + if a > 1: + break + a = 1 + else: + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')), + ('(a > 0)', '(a > 1)', ('break', 'a = 1')), + ('(a > 1)', 'break', 'a = 3'), + ('(a > 1)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + (('break', 'a = 2'), 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'While:2', 'a = 3'), + ('(a > 0)', 'If:3', ('a = 1', 'a = 3')), + ), + ) + + def test_while_else_return(self): + + def test_fn(a): + while a > 0: + if a > 1: + return + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')), + ('(a > 0)', '(a > 1)', ('return', 'a = 1')), + ('(a > 1)', 'return', None), + ('(a > 1)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'While:2', 'a = 3'), + ('(a > 0)', 'If:3', 'a = 1'), + ), + ) + + def test_while_nested_straightline(self): + + def test_fn(a): + while a > 0: + while a > 1: + a = 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')), + (('(a > 0)', 'a = 1'), '(a > 1)', ('a = 1', 'a = 2')), + ('(a > 1)', 'a = 1', '(a > 1)'), + ('(a > 1)', 'a = 2', '(a > 0)'), + ('(a > 0)', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'While:2', 'a = 3'), + ('(a > 0)', 'While:3', 'a = 2'), + ), + ) + + def test_while_nested_continue(self): + + def test_fn(a): + while a > 0: + while a > 1: + if a > 3: + continue + a = 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')), + (('(a > 0)', 'continue', 'a = 1'), '(a > 1)', ('(a > 3)', 'a = 2')), + ('(a > 1)', '(a > 3)', ('continue', 'a = 1')), + ('(a > 3)', 'continue', '(a > 1)'), + ('(a > 3)', 'a = 1', '(a > 1)'), + ('(a > 1)', 'a = 2', '(a > 0)'), + ('(a > 0)', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'While:2', 'a = 3'), + ('(a > 0)', 'While:3', 'a = 2'), + ('(a > 1)', 'If:4', ('a = 1', '(a > 1)')), + ), + ) + + def test_while_nested_break(self): + + def test_fn(a): + while a > 0: + while a > 1: + if a > 2: + break + a = 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches(graph, ( + (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')), + (('(a > 0)', 'a = 1'), '(a > 1)', ('(a > 2)', 'a = 2')), + ('(a > 1)', '(a > 2)', ('break', 'a = 1')), + ('(a > 2)', 'break', 'a = 2'), + ('(a > 2)', 'a = 1', '(a > 1)'), + (('(a > 1)', 'break'), 'a = 2', '(a > 0)'), + ('(a > 0)', 'a = 3', None), + )) + self.assertStatementEdges( + graph, + ( + ('a', 'While:2', 'a = 3'), + ('(a > 0)', 'While:3', 'a = 2'), + ('(a > 1)', 'If:4', ('a = 1', 'a = 2')), + ), + ) + + def test_for_straightline(self): + + def test_fn(a): + for a in range(0, a): + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('a = 1', 'a = 2')), + ('range(0, a)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', None), + ), + ) + self.assertStatementEdges( + graph, + (('a', 'For:2', 'a = 2'),), + ) + + def test_for_else_straightline(self): + + def test_fn(a): + for a in range(0, a): + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('a = 1', 'a = 2')), + ('range(0, a)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + (('a', 'For:2', 'a = 3'),), + ) + + def test_for_else_continue(self): + + def test_fn(a): + for a in range(0, a): + if a > 1: + continue + else: + a = 0 + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'continue', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')), + ('range(0, a)', '(a > 1)', ('continue', 'a = 0')), + ('(a > 1)', 'continue', 'range(0, a)'), + ('(a > 1)', 'a = 0', 'a = 1'), + ('a = 0', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'For:2', 'a = 3'), + ('range(0, a)', 'If:3', ('a = 1', 'range(0, a)')), + ), + ) + + def test_for_else_break(self): + + def test_fn(a): + for a in range(0, a): + if a > 1: + break + a = 1 + else: + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')), + ('range(0, a)', '(a > 1)', ('break', 'a = 1')), + ('(a > 1)', 'break', 'a = 3'), + ('(a > 1)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + (('break', 'a = 2'), 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'For:2', 'a = 3'), + ('range(0, a)', 'If:3', ('a = 1', 'a = 3')), + ), + ) + + def test_for_else_return(self): + + def test_fn(a): + for a in range(0, a): + if a > 1: + return + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')), + ('range(0, a)', '(a > 1)', ('return', 'a = 1')), + ('(a > 1)', 'return', None), + ('(a > 1)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'For:2', 'a = 3'), + ('range(0, a)', 'If:3', 'a = 1'), + ), + ) + + def test_for_nested_straightline(self): + + def test_fn(a): + for a in range(0, a): + for b in range(1, a): + b += 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')), + (('range(0, a)', 'b += 1'), 'range(1, a)', ('b += 1', 'a = 2')), + ('range(1, a)', 'b += 1', 'range(1, a)'), + ('range(1, a)', 'a = 2', 'range(0, a)'), + ('range(0, a)', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'For:2', 'a = 3'), + ('range(0, a)', 'For:3', 'a = 2'), + ), + ) + + def test_for_nested_continue(self): + + def test_fn(a): + for a in range(0, a): + for b in range(1, a): + if a > 3: + continue + b += 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')), + (('range(0, a)', 'continue', 'b += 1'), 'range(1, a)', + ('(a > 3)', 'a = 2')), + ('range(1, a)', '(a > 3)', ('continue', 'b += 1')), + ('(a > 3)', 'continue', 'range(1, a)'), + ('(a > 3)', 'b += 1', 'range(1, a)'), + ('range(1, a)', 'a = 2', 'range(0, a)'), + ('range(0, a)', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'For:2', 'a = 3'), + ('range(0, a)', 'For:3', 'a = 2'), + ('range(1, a)', 'If:4', ('b += 1', 'range(1, a)')), + ), + ) + + def test_for_nested_break(self): + + def test_fn(a): + for a in range(0, a): + for b in range(1, a): + if a > 2: + break + b += 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')), + (('range(0, a)', 'b += 1'), 'range(1, a)', ('(a > 2)', 'a = 2')), + ('range(1, a)', '(a > 2)', ('break', 'b += 1')), + ('(a > 2)', 'break', 'a = 2'), + ('(a > 2)', 'b += 1', 'range(1, a)'), + (('range(1, a)', 'break'), 'a = 2', 'range(0, a)'), + ('range(0, a)', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'For:2', 'a = 3'), + ('range(0, a)', 'For:3', 'a = 2'), + ('range(1, a)', 'If:4', ('b += 1', 'a = 2')), + ), + ) + + def test_complex(self): + + def test_fn(a): + b = 0 + while a > 0: + for b in range(0, a): + if a > 2: + break + if a > 3: + if a > 4: + continue + else: + max(a) + break + b += 1 + else: # for b in range(0, a): + return a + a = 2 + for a in range(1, a): + return b + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('b = 0', 'a = 2'), '(a > 0)', ('range(0, a)', 'range(1, a)')), + ( + ('(a > 0)', 'continue', 'b += 1'), + 'range(0, a)', + ('(a > 2)', 'return a'), + ), + ('range(0, a)', '(a > 2)', ('(a > 3)', 'break')), + ('(a > 2)', 'break', 'a = 2'), + ('(a > 2)', '(a > 3)', ('(a > 4)', 'b += 1')), + ('(a > 3)', '(a > 4)', ('continue', 'max(a)')), + ('(a > 4)', 'max(a)', 'break'), + ('max(a)', 'break', 'a = 2'), + ('(a > 4)', 'continue', 'range(0, a)'), + ('(a > 3)', 'b += 1', 'range(0, a)'), + ('range(0, a)', 'return a', None), + ('break', 'a = 2', '(a > 0)'), + ('(a > 0)', 'range(1, a)', ('return b', 'a = 3')), + ('range(1, a)', 'return b', None), + ('range(1, a)', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('b = 0', 'While:3', 'range(1, a)'), + ('(a > 0)', 'For:4', 'a = 2'), + ('range(0, a)', 'If:5', ('(a > 3)', 'a = 2')), + ('(a > 2)', 'If:7', ('b += 1', 'a = 2', 'range(0, a)')), + ('(a > 3)', 'If:8', ('a = 2', 'range(0, a)')), + ('(a > 0)', 'For:17', 'a = 3'), + ), + ) + + def test_finally_straightline(self): + + def test_fn(a): + try: + a += 1 + finally: + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', 'a += 1', 'a = 2'), + ('a += 1', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_return_finally(self): + + def test_fn(a): + try: + return a + finally: + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', 'return a', 'a = 1'), + ('return a', 'a = 1', None), + (None, 'a = 2', None), + ), + ) + + def test_break_finally(self): + + def test_fn(a): + while a > 0: + try: + break + finally: + a = 1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', '(a > 0)', 'break'), + ('(a > 0)', 'break', 'a = 1'), + ('break', 'a = 1', None), + ), + ) + + def test_continue_finally(self): + + def test_fn(a): + while a > 0: + try: + continue + finally: + a = 1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', 'continue'), + ('(a > 0)', 'continue', 'a = 1'), + ('continue', 'a = 1', '(a > 0)'), + ), + ) + + def test_with_straightline(self): + + def test_fn(a): + with max(a) as b: + a = 0 + return b + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', 'max(a)', 'a = 0'), + ('max(a)', 'a = 0', 'return b'), + ('a = 0', 'return b', None), + ), + ) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a0938b3e5f0e52532f63fea6fb4c3e478fc51d93 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD @@ -0,0 +1,39 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "common_transformers", + srcs = [ + "anf.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/autograph/pyct", + "@gast_archive//:gast", + "@six_archive//:six", + ], +) + +py_test( + name = "anf_test", + srcs = ["anf_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":common_transformers", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py new file mode 100644 index 0000000000000000000000000000000000000000..e42f679cfe31f919e10f7baf409247014b3cf386 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py @@ -0,0 +1,418 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Conversion to A-normal form. + +The general idea of A-normal form is that every intermediate value is +explicitly named with a variable. For more, see +https://en.wikipedia.org/wiki/A-normal_form. + +The specific converters used here are based on Python AST semantics as +documented at https://greentreesnakes.readthedocs.io/en/latest/. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast +import six + +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer + + +class DummyGensym(object): + """A dumb gensym that suffixes a stem by sequential numbers from 1000.""" + + def __init__(self, entity_info): + del entity_info + # A proper implementation needs to account for: + # * entity_info.namespace + # * all the symbols defined in the AST + # * the symbols generated so far + self._idx = 0 + + def new_name(self, stem='tmp'): + self._idx += 1 + return stem + '_' + str(1000 + self._idx) + + +class AnfTransformer(transformer.Base): + """Performs the conversion to A-normal form (ANF).""" + + # The algorithm is a postorder recursive tree walk. Any given node A may, in + # general, require creation of a series B of Assign statements, which compute + # and explicitly name the intermediate values needed to compute the value of + # A. If A was already a statement, it can be replaced with the sequence B + + # [A]. If A was an expression, B needs to be propagated up the tree until a + # statement is encountered. Since the `ast.NodeTransformer` framework makes + # no provision for subtraversals returning side information, this class + # accumulates the sequence B in an instance variable. + + # The only other subtlety is that some Python statements (like `if`) have both + # expression fields (`test`) and statement list fields (`body` and `orelse`). + # Any additional assignments needed to name all the intermediate values in the + # `test` can be prepended to the `if` node, but assignments produced by + # processing the `body` and the `orelse` need to be kept together with them, + # and not accidentally lifted out of the `if`. + + def __init__(self, entity_info, gensym_source=None): + """Creates an ANF transformer. + + Args: + entity_info: transformer.EntityInfo + gensym_source: An optional object with the same interface as `DummyGensym` + for generating unique names + """ + super(AnfTransformer, self).__init__(entity_info) + if gensym_source is None: + self._gensym = DummyGensym(entity_info) + else: + self._gensym = gensym_source(entity_info) + self._pending_statements = [] + + def _consume_pending_statements(self): + ans = self._pending_statements + self._pending_statements = [] + return ans + + def _add_pending_statement(self, stmt): + self._pending_statements.append(stmt) + + _trivial_nodes = ( + # Non-nodes that show up as AST fields + bool, six.string_types, + # Leaf nodes that are already in A-normal form + gast.expr_context, gast.Name, gast.Num, gast.Str, gast.Bytes, + gast.NameConstant, gast.Ellipsis, + # Binary operators + gast.Add, gast.Sub, gast.Mult, gast.Div, gast.Mod, gast.Pow, gast.LShift, + gast.RShift, gast.BitOr, gast.BitXor, gast.BitAnd, gast.FloorDiv, + # Unary operators + gast.Invert, gast.Not, gast.UAdd, gast.USub, + # Comparison operators + gast.Eq, gast.NotEq, gast.Lt, gast.LtE, gast.Gt, gast.GtE, + gast.Is, gast.IsNot, gast.In, gast.NotIn, + ) + + def _is_node_trivial(self, node): + if node is None: + return True + elif isinstance(node, self._trivial_nodes): + return True + elif isinstance(node, gast.keyword): + return self._is_node_trivial(node.value) + elif isinstance(node, (gast.Starred, gast.withitem, gast.slice)): + return self._are_children_trivial(node) + return False + + def _are_children_trivial(self, node): + for field in node._fields: + if not field.startswith('__'): + if not self._is_node_trivial(getattr(node, field)): + return False + return True + + def _ensure_node_is_trivial(self, node): + if node is None: + return node + elif isinstance(node, self._trivial_nodes): + return node + elif isinstance(node, list): + # If something's field was actually a list, e.g., variadic arguments. + return [self._ensure_node_is_trivial(n) for n in node] + elif isinstance(node, gast.keyword): + node.value = self._ensure_node_is_trivial(node.value) + return node + elif isinstance(node, (gast.Starred, gast.withitem, gast.slice)): + return self._ensure_fields_trivial(node) + elif isinstance(node, gast.expr): + temp_name = self._gensym.new_name() + temp_assign = templates.replace( + 'temp_name = expr', temp_name=temp_name, expr=node)[0] + self._add_pending_statement(temp_assign) + answer = templates.replace('temp_name', temp_name=temp_name)[0] + return answer + else: + raise ValueError('Do not know how to treat {}'.format(node)) + + def _ensure_fields_trivial(self, node): + for field in node._fields: + if field.startswith('__'): + continue + setattr(node, field, self._ensure_node_is_trivial(getattr(node, field))) + return node + + def _visit_strict_statement(self, node, trivialize_children=True): + assert not self._pending_statements + node = self.generic_visit(node) + if trivialize_children: + self._ensure_fields_trivial(node) + results = self._consume_pending_statements() + results.append(node) + return results + + def _visit_strict_expression(self, node): + node = self.generic_visit(node) + self._ensure_fields_trivial(node) + return node + + # Note on code order: These are listed in the same order as the grammar + # elements on https://github.com/serge-sans-paille/gast + + # FunctionDef, AsyncFunctionDef, and ClassDef should be correct by default. + + def visit_Return(self, node): + return self._visit_strict_statement(node) + + def visit_Delete(self, node): + return self._visit_strict_statement(node, trivialize_children=False) + + def visit_Assign(self, node): + return self._visit_strict_statement(node, trivialize_children=False) + + def visit_AugAssign(self, node): + return self._visit_strict_statement(node, trivialize_children=False) + + def visit_Print(self, node): + return self._visit_strict_statement(node) + + def visit_For(self, node): + assert not self._pending_statements + # It's important to visit node.iter first, because any statements created + # thereby need to live outside the body. + self.visit(node.iter) + node.iter = self._ensure_node_is_trivial(node.iter) + iter_stmts = self._consume_pending_statements() + # This generic_visit will revisit node.iter, but that is both correct and + # cheap because by this point node.iter is trivial. + node = self.generic_visit(node) + assert not self._pending_statements + iter_stmts.append(node) + return iter_stmts + + def visit_AsyncFor(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial AsyncFor nodes not supported yet ' + '(need to think through the semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_While(self, node): + if not self._is_node_trivial(node.test): + msg = ('While with nontrivial test not supported yet ' + '(need to avoid precomputing the test).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_If(self, node): + assert not self._pending_statements + # It's important to visit node.test first, because any statements created + # thereby need to live outside the body. + self.visit(node.test) + node.test = self._ensure_node_is_trivial(node.test) + condition_stmts = self._consume_pending_statements() + # This generic_visit will revisit node.test, but that is both correct and + # cheap because by this point node.test is trivial. + node = self.generic_visit(node) + assert not self._pending_statements + condition_stmts.append(node) + return condition_stmts + + def visit_With(self, node): + assert not self._pending_statements + # It's important to visit node.items first, because any statements created + # thereby need to live outside the body. + for item in node.items: + self.visit(item) + node.items = [self._ensure_node_is_trivial(n) for n in node.items] + contexts_stmts = self._consume_pending_statements() + # This generic_visit will revisit node.items, but that is both correct and + # cheap because by this point node.items is trivial. + node = self.generic_visit(node) + assert not self._pending_statements + contexts_stmts.append(node) + return contexts_stmts + + def visit_AsyncWith(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial AsyncWith nodes not supported yet ' + '(need to think through the semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Raise(self, node): + return self._visit_strict_statement(node) + + # Try should be correct by default. + + def visit_Assert(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial Assert nodes not supported yet ' + '(need to avoid computing the test when assertions are off, and ' + 'avoid computing the irritant when the assertion does not fire).') + raise ValueError(msg) + return self.generic_visit(node) + + # Import and ImportFrom should be correct by default. + + def visit_Exec(self, node): + return self._visit_strict_statement(node) + + # Global and Nonlocal should be correct by default. + + def visit_Expr(self, node): + return self._visit_strict_statement(node, trivialize_children=False) + + # Pass, Break, and Continue should be correct by default. + + def visit_BoolOp(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial BoolOp nodes not supported yet ' + '(need to preserve short-circuiting semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_BinOp(self, node): + return self._visit_strict_expression(node) + + def visit_UnaryOp(self, node): + return self._visit_strict_expression(node) + + def visit_Lambda(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial Lambda nodes not supported ' + '(cannot insert statements into lambda bodies).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_IfExp(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial IfExp nodes not supported yet ' + '(need to convert to If statement, to evaluate branches lazily ' + 'and insert statements into them).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Dict(self, node): + return self._visit_strict_expression(node) + + def visit_Set(self, node): + return self._visit_strict_expression(node) + + def visit_ListComp(self, node): + msg = ('ListComp nodes not supported ' + '(need to convert to a form that tolerates ' + 'assignment statements in clause bodies).') + raise ValueError(msg) + + def visit_SetComp(self, node): + msg = ('SetComp nodes not supported ' + '(need to convert to a form that tolerates ' + 'assignment statements in clause bodies).') + raise ValueError(msg) + + def visit_DictComp(self, node): + msg = ('DictComp nodes not supported ' + '(need to convert to a form that tolerates ' + 'assignment statements in clause bodies).') + raise ValueError(msg) + + def visit_GeneratorExp(self, node): + msg = ('GeneratorExp nodes not supported ' + '(need to convert to a form that tolerates ' + 'assignment statements in clause bodies).') + raise ValueError(msg) + + def visit_Await(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial Await nodes not supported yet ' + '(need to think through the semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Yield(self, node): + return self._visit_strict_expression(node) + + def visit_YieldFrom(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial YieldFrom nodes not supported yet ' + '(need to unit-test them in Python 2).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Compare(self, node): + if len(node.ops) > 1: + msg = ('Multi-ary compare nodes not supported yet ' + '(need to preserve short-circuiting semantics).') + raise ValueError(msg) + return self._visit_strict_expression(node) + + def visit_Call(self, node): + return self._visit_strict_expression(node) + + def visit_Repr(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial Repr nodes not supported yet ' + '(need to research their syntax and semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_FormattedValue(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial FormattedValue nodes not supported yet ' + '(need to unit-test them in Python 2).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_JoinedStr(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial JoinedStr nodes not supported yet ' + '(need to unit-test them in Python 2).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Attribute(self, node): + return self._visit_strict_expression(node) + + def visit_Subscript(self, node): + return self._visit_strict_expression(node) + + # Starred and Name are correct by default, because the right thing to do is to + # just recur. + + def visit_List(self, node): + return self._visit_strict_expression(node) + + def visit_Tuple(self, node): + return self._visit_strict_expression(node) + + +def transform(node, entity_info, gensym_source=None): + """Converts the given node to A-normal form (ANF). + + The general idea of A-normal form: https://en.wikipedia.org/wiki/A-normal_form + + The specific converters used here are based on Python AST semantics as + documented at https://greentreesnakes.readthedocs.io/en/latest/. + + Args: + node: The node to transform. + entity_info: transformer.EntityInfo. TODO(mdan): What information does this + argument provide? + gensym_source: An optional object with the same interface as `DummyGensym` + for generating unique names. + """ + return AnfTransformer(entity_info, gensym_source=gensym_source).visit(node) diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py new file mode 100644 index 0000000000000000000000000000000000000000..951974820c784974cb5bb2320adbb2b07f9332df --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py @@ -0,0 +1,403 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for anf module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import textwrap + +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.common_transformers import anf +from tensorflow.python.platform import test + + +class DummyGensym(object): + """A dumb gensym that suffixes a stem by sequential numbers from 1000.""" + + def __init__(self, entity_info): + del entity_info + # A proper implementation needs to account for: + # * entity_info.namespace + # * all the symbols defined in the AST + # * the symbols generated so far + self._idx = 0 + + def new_name(self, stem='tmp'): + self._idx += 1 + return stem + '_' + str(1000 + self._idx) + + +class AnfTransformerTest(test.TestCase): + + def _simple_source_info(self): + return transformer.EntityInfo( + source_code=None, + source_file=None, + namespace=None, + arg_values=None, + arg_types=None, + owner_type=None) + + def test_basic(self): + def test_function(): + a = 0 + return a + node, _ = parser.parse_entity(test_function) + node = anf.transform(node.body[0], self._simple_source_info()) + result, _ = compiler.ast_to_object(node) + self.assertEqual(test_function(), result.test_function()) + + def assert_same_ast(self, expected_node, node, msg=None): + expected_source = compiler.ast_to_source(expected_node, indentation=' ') + expected_str = textwrap.dedent(expected_source).strip() + got_source = compiler.ast_to_source(node, indentation=' ') + got_str = textwrap.dedent(got_source).strip() + self.assertEqual(expected_str, got_str, msg=msg) + + def assert_body_anfs_as_expected(self, expected_fn, test_fn): + # Testing the code bodies only. Wrapping them in functions so the + # syntax highlights nicely, but Python doesn't try to execute the + # statements. + exp_node, _ = parser.parse_entity(expected_fn) + node, _ = parser.parse_entity(test_fn) + node = anf.transform( + node, self._simple_source_info(), gensym_source=DummyGensym) + exp_name = exp_node.body[0].name + # Ignoring the function names in the result because they can't be + # the same (because both functions have to exist in the same scope + # at the same time). + node.body[0].name = exp_name + self.assert_same_ast(exp_node, node) + # Check that ANF is idempotent + node_repeated = anf.transform( + node, self._simple_source_info(), gensym_source=DummyGensym) + self.assert_same_ast(node_repeated, node) + + def test_binop_basic(self): + + def test_function(x, y, z): + a = x + y + z + return a + + def expected_result(x, y, z): + tmp_1001 = x + y + a = tmp_1001 + z + return a + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_if_basic(self): + + def test_function(a, b, c, e, f, g): + if a + b + c: + d = e + f + g + return d + + def expected_result(a, b, c, e, f, g): + tmp_1001 = a + b + tmp_1002 = tmp_1001 + c + if tmp_1002: + tmp_1003 = e + f + d = tmp_1003 + g + return d + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_nested_binop_and_return(self): + + def test_function(b, c, d, e): + return (2 * b + c) + (d + e) + + def expected_result(b, c, d, e): + tmp_1001 = 2 * b + tmp_1002 = tmp_1001 + c + tmp_1003 = d + e + tmp_1004 = tmp_1002 + tmp_1003 + return tmp_1004 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_function_call_and_expr(self): + + def test_function(call_something, a, b, y, z, c, d, e, f, g, h, i): + call_something(a + b, y * z, kwarg=c + d, *(e + f), **(g + h + i)) + + def expected_result(call_something, a, b, y, z, c, d, e, f, g, h, i): + tmp_1001 = g + h + tmp_1002 = a + b + tmp_1003 = y * z + tmp_1004 = e + f + tmp_1005 = c + d + tmp_1006 = tmp_1001 + i + call_something(tmp_1002, tmp_1003, kwarg=tmp_1005, *tmp_1004, **tmp_1006) + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_with_and_print(self): + + def test_function(a, b, c): + with a + b + c as d: + print(2 * d + 1) + + def expected_result(a, b, c): + tmp_1001 = a + b + tmp_1002 = tmp_1001 + c + with tmp_1002 as d: + tmp_1003 = 2 * d + tmp_1004 = tmp_1003 + 1 + print(tmp_1004) + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_local_definition_and_binary_compare(self): + + def test_function(): + def foo(a, b): + return 2 * a < b + return foo + + def expected_result(): + def foo(a, b): + tmp_1001 = 2 * a + tmp_1002 = tmp_1001 < b + return tmp_1002 + return foo + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_list_literal(self): + + def test_function(a, b, c, d, e, f): + return [a + b, c + d, e + f] + + def expected_result(a, b, c, d, e, f): + tmp_1001 = a + b + tmp_1002 = c + d + tmp_1003 = e + f + tmp_1004 = [tmp_1001, tmp_1002, tmp_1003] + return tmp_1004 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_tuple_literal_and_unary(self): + + def test_function(a, b, c, d, e, f): + return (a + b, -(c + d), e + f) + + def expected_result(a, b, c, d, e, f): + tmp_1001 = c + d + tmp_1002 = a + b + tmp_1003 = -tmp_1001 + tmp_1004 = e + f + tmp_1005 = (tmp_1002, tmp_1003, tmp_1004) + return tmp_1005 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_set_literal(self): + + def test_function(a, b, c, d, e, f): + return set(a + b, c + d, e + f) + + def expected_result(a, b, c, d, e, f): + tmp_1001 = a + b + tmp_1002 = c + d + tmp_1003 = e + f + tmp_1004 = set(tmp_1001, tmp_1002, tmp_1003) + return tmp_1004 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_dict_literal_and_repr(self): + + def test_function(foo, bar, baz): + return repr({foo + bar + baz: 7 | 8}) + + def expected_result(foo, bar, baz): + tmp_1001 = foo + bar + tmp_1002 = tmp_1001 + baz + tmp_1003 = 7 | 8 + tmp_1004 = {tmp_1002: tmp_1003} + tmp_1005 = repr(tmp_1004) + return tmp_1005 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_field_read_and_write(self): + + def test_function(a, d): + a.b.c = d.e.f + 3 + + def expected_result(a, d): + tmp_1001 = a.b + tmp_1002 = d.e + tmp_1003 = tmp_1002.f + tmp_1001.c = tmp_1003 + 3 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_subscript_read_and_write(self): + + def test_function(a, b, c, d, e, f): + a[b][c] = d[e][f] + 3 + + def expected_result(a, b, c, d, e, f): + tmp_1001 = a[b] + tmp_1002 = d[e] + tmp_1003 = tmp_1002[f] + tmp_1001[c] = tmp_1003 + 3 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_augassign_and_delete(self): + + def test_function(a, x, y, z): + a += x + y + z + del a + del z[y][x] + + def expected_result(a, x, y, z): + tmp_1001 = x + y + a += tmp_1001 + z + del a + tmp_1002 = z[y] + del tmp_1002[x] + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_raise_yield_and_raise(self): + + def test_function(a, c, some_computed, exception): + yield a ** c + raise some_computed('complicated' + exception) + + def expected_result(a, c, some_computed, exception): + tmp_1001 = a ** c + yield tmp_1001 + tmp_1002 = 'complicated' + exception + tmp_1003 = some_computed(tmp_1002) + raise tmp_1003 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_with_and_if_with_expressions(self): + + def test_function(foo, bar, function, quux, quozzle, w, x, y, z): + with foo + bar: + function(x + y) + if quux + quozzle: + function(z / w) + + def expected_result(foo, bar, function, quux, quozzle, w, x, y, z): + tmp_1001 = foo + bar + with tmp_1001: + tmp_1002 = x + y + function(tmp_1002) + tmp_1003 = quux + quozzle + if tmp_1003: + tmp_1004 = z / w + function(tmp_1004) + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_exec(self): + + def test_function(): + # The point is to test A-normal form conversion of exec + # pylint: disable=exec-used + exec('computed' + 5 + 'stuff', globals(), locals()) + + def expected_result(): + # pylint: disable=exec-used + tmp_1001 = 'computed' + 5 + tmp_1002 = tmp_1001 + 'stuff' + tmp_1003 = globals() + tmp_1004 = locals() + exec(tmp_1002, tmp_1003, tmp_1004) + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_simple_while_and_assert(self): + + def test_function(foo, quux): + while foo: + assert quux + foo = foo + 1 * 3 + + def expected_result(foo, quux): + while foo: + assert quux + tmp_1001 = 1 * 3 + foo = foo + tmp_1001 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_for(self): + + def test_function(compute, something, complicated, foo): + for foo in compute(something + complicated): + bar = foo + 1 * 3 + return bar + + def expected_result(compute, something, complicated, foo): + tmp_1001 = something + complicated + tmp_1002 = compute(tmp_1001) + for foo in tmp_1002: + tmp_1003 = 1 * 3 + bar = foo + tmp_1003 + return bar + + self.assert_body_anfs_as_expected(expected_result, test_function) + + # This test collects several examples where the definition of A-normal form + # implemented by this transformer is questionable. Mostly it's here to spell + # out what the definition is in these cases. + def test_controversial(self): + + def test_function(b, c, d, f): + a = c + d + a.b = c + d + a[b] = c + d + a += c + d + a, b = c + a, b = c, d + a = f(c) + a = f(c + d) + a[b + d] = f.e(c + d) + + def expected_result(b, c, d, f): + a = c + d + a.b = c + d # Should be a.b = tmp? (Definitely not tmp = c + d) + a[b] = c + d # Should be a[b] = tmp? (Definitely not tmp = c + d) + a += c + d # Should be a += tmp? (Definitely not tmp = c + d) + a, b = c # Should be a = c[0], b = c[1]? Or not? + a, b = c, d # Should be a = c, b = d? Or not? + a = f(c) + tmp_1001 = c + d + a = f(tmp_1001) + tmp_1002 = b + d + tmp_1003 = f.e + tmp_1004 = c + d + a[tmp_1002] = tmp_1003(tmp_1004) # Or should be a[tmp1] = tmp2? + + self.assert_body_anfs_as_expected(expected_result, test_function) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/compiler.py b/tensorflow/contrib/autograph/pyct/compiler.py index 24c4517afa89147101f80af3ef60237132c1144c..f9cee109624dafd4da4a0981c5f8fda0a5d8a5e7 100644 --- a/tensorflow/contrib/autograph/pyct/compiler.py +++ b/tensorflow/contrib/autograph/pyct/compiler.py @@ -30,46 +30,112 @@ import tempfile import astor import gast +from tensorflow.contrib.autograph.pyct import origin_info + def ast_to_source(node, indentation=' '): - """Return the source code of given AST.""" - if isinstance(node, gast.AST): - node = gast.gast_to_ast(node) + """Return the source code of given AST. + + Args: + node: The code to compile, as an AST object. + indentation: The string to use for indentation. + + Returns: + code: The source code generated from the AST object + source_mapping: A mapping between the user and AutoGraph generated code. + """ + if not isinstance(node, (list, tuple)): + node = (node,) generator = astor.codegen.SourceGenerator(indentation, False, astor.string_repr.pretty_string) - generator.visit(node) - generator.result.append('\n') + + for n in node: + if isinstance(n, gast.AST): + n = gast.gast_to_ast(n) + generator.visit(n) + generator.result.append('\n') + # In some versions of Python, literals may appear as actual values. This # ensures everything is string. code = map(str, generator.result) - return astor.source_repr.pretty_source(code).lstrip() + code = astor.source_repr.pretty_source(code).lstrip() + return code -def ast_to_object( - node, indentation=' ', source_prefix=None, delete_on_exit=True): + +def ast_to_object(nodes, + indentation=' ', + include_source_map=False, + source_prefix=None, + delete_on_exit=True): """Return the Python objects represented by given AST. Compiling the AST code this way ensures that the source code is readable by e.g. `pdb` or `inspect`. Args: - node: The code to compile, as an AST object. - indentation: The string to use for indentation. - source_prefix: Optional string to print as-is into the source file. - delete_on_exit: Whether to delete the temporary file used for compilation - on exit. + nodes: Union[ast.AST, Iterable[ast.AST]], the code to compile, as an AST + object. + indentation: Text, the string to use for indentation. + include_source_map: bool, whether to attach a source map to the compiled + object. Also see origin_info.py. + source_prefix: Optional[Text], string to print as-is into the source file. + delete_on_exit: bool, whether to delete the temporary file used for + compilation on exit. Returns: - A module object containing the compiled source code. + compiled_nodes: A module object containing the compiled source code. + source: The source code of the compiled object + Raises: + ValueError: If ag_source_map__ is already in the namespace of the compiled + nodes. """ - source = ast_to_source(node, indentation) + if not isinstance(nodes, (list, tuple)): + nodes = (nodes,) + + source = ast_to_source(nodes, indentation=indentation) + + if source_prefix: + source = source_prefix + '\n' + source with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: module_name = os.path.basename(f.name[:-3]) - if source_prefix: - f.write(source_prefix) - f.write('\n') f.write(source) + + if isinstance(nodes, (list, tuple)): + indices = range(-len(nodes), 0) + else: + indices = (-1,) + + if include_source_map: + source_map = origin_info.source_map(nodes, source, f.name, indices) + + # TODO(mdan): Try flush() and delete=False instead. if delete_on_exit: atexit.register(lambda: os.remove(f.name)) - return imp.load_source(module_name, f.name), source + compiled_nodes = imp.load_source(module_name, f.name) + + # TODO(znado): Clean this up so we don't need to attach it to the namespace. + # TODO(znado): This does not work for classes because their methods share a + # namespace. + # This attaches the source map which is needed for error handling. Note that + # api.to_graph copies this source map into an attribute of the function. + # + # We need this so the ag_source_map__ variable is available to the call to + # rewrite_graph_construction_error in the except block inside each function + # that handles graph construction errors. + # + # We cannot get the rewritten function name until it is too late so templating + # is hard, and this cleanly fixes the + # issues encountered with nested functions because this is attached to the + # outermost one. + if include_source_map: + # TODO(mdan): This name should be decided by the caller. + source_map_name = 'ag_source_map__' + if source_map_name in compiled_nodes.__dict__: + raise ValueError('cannot convert %s because is has namespace attribute ' + '"%s", which is reserved for AutoGraph.' % + (compiled_nodes, source_map_name)) + compiled_nodes.__dict__[source_map_name] = source_map + + return compiled_nodes, source diff --git a/tensorflow/contrib/autograph/pyct/compiler_test.py b/tensorflow/contrib/autograph/pyct/compiler_test.py index 98cdc1506b6aced603df99662f1468687a55f92c..cf783da6a3e540c6901a5fe9a5e4afdb6b1cfc03 100644 --- a/tensorflow/contrib/autograph/pyct/compiler_test.py +++ b/tensorflow/contrib/autograph/pyct/compiler_test.py @@ -59,14 +59,14 @@ class CompilerTest(test.TestCase): value=gast.Str('c')) ]) + source = compiler.ast_to_source(node, indentation=' ') self.assertEqual( textwrap.dedent(""" if 1: a = b else: a = 'c' - """).strip(), - compiler.ast_to_source(node, indentation=' ').strip()) + """).strip(), source.strip()) def test_ast_to_object(self): node = gast.FunctionDef( diff --git a/tensorflow/contrib/autograph/pyct/context.py b/tensorflow/contrib/autograph/pyct/context.py deleted file mode 100644 index b34015cfd2888f0dbeb6492b9e7335d561bf4763..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/pyct/context.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Conversion context containers.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -class EntityContext(object): - """Contains information about an entity, like source code. - - In general, objects of this class should be considered immutable. - - Attributes: - namer: Namer that matches the contract of all converters. - source_code: The entity's source code. - source_file: The entity's source file. - namespace: Dict[str->*], containing symbols visible to the entity - (excluding parameters). - arg_values: Dict[str->*], containing parameter values, if known. - arg_types: Dict[str->*], containing parameter types, if known. - owner_type: The surrounding class type of the function, if present. - """ - - # TODO(mdan): Remove the default and update tests. - def __init__(self, namer, source_code, source_file, namespace, arg_values, - arg_types, owner_type, recursive, type_annotation_func=None): - self.namer = namer - self.source_code = source_code - self.source_file = source_file - self.namespace = namespace - self.arg_values = {} if arg_values is None else arg_values - self.arg_types = {} if arg_types is None else arg_types - self.owner_type = owner_type - self.recursive = recursive - self.type_annotation_func = type_annotation_func diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/contrib/autograph/pyct/origin_info.py new file mode 100644 index 0000000000000000000000000000000000000000..b60651a30e342dabe40cbcef1486826e16c2e2c7 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/origin_info.py @@ -0,0 +1,186 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Container for origin source code information before AutoGraph compilation.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import tokenize + +import gast +import six + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.python.util import tf_inspect + + +class LineLocation( + collections.namedtuple('LineLocation', ('filename', 'lineno'))): + """Similar to Location, but without column information. + + Attributes: + filename: Text + lineno: int, 1-based + """ + pass + + +class Location( + collections.namedtuple('Location', ('filename', 'lineno', 'col_offset'))): + """Encodes code location information. + + Attributes: + filename: Text + lineno: int, 1-based + col_offset: int + """ + + @property + def line_loc(self): + return LineLocation(self.filename, self.lineno) + + +class OriginInfo( + collections.namedtuple( + 'OriginInfo', + ('loc', 'function_name', 'source_code_line', 'comment'))): + """Container for information about the source code before conversion. + + Attributes: + loc: Location + function_name: Optional[Text] + source_code_line: Text + comment: Optional[Text] + """ + + def as_frame(self): + """Returns a 4-tuple consistent with the return of traceback.extract_tb.""" + return (self.loc.filename, self.loc.lineno, self.function_name, + self.source_code_line) + + +# TODO(mdan): This source map should be a class - easier to refer to. +def source_map(nodes, code, filename, indices_in_code): + """Creates a source map between an annotated AST and the code it compiles to. + + Args: + nodes: Iterable[ast.AST, ...] + code: Text + filename: Optional[Text] + indices_in_code: Union[int, Iterable[int, ...]], the positions at which + nodes appear in code. The parser always returns a module when parsing + code. This argument indicates the position in that module's body at + which the corresponding of node should appear. + + Returns: + Dict[CodeLocation, OriginInfo], mapping locations in code to locations + indicated by origin annotations in node. + """ + reparsed_nodes = parser.parse_str(code) + reparsed_nodes = [reparsed_nodes.body[i] for i in indices_in_code] + + resolve(reparsed_nodes, code) + result = {} + + for before, after in ast_util.parallel_walk(nodes, reparsed_nodes): + # Note: generated code might not be mapped back to its origin. + # TODO(mdan): Generated code should always be mapped to something. + origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None) + final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None) + if origin_info is None or final_info is None: + continue + + line_loc = LineLocation(filename, final_info.loc.lineno) + + existing_origin = result.get(line_loc) + if existing_origin is not None: + # Overlaps may exist because of child nodes, but almost never to + # different line locations. Exception make decorated functions, where + # both lines are mapped to the same line in the AST. + + # Line overlaps: keep bottom node. + if existing_origin.loc.line_loc == origin_info.loc.line_loc: + if existing_origin.loc.lineno >= origin_info.loc.lineno: + continue + + # In case of overlaps, keep the leftmost node. + if existing_origin.loc.col_offset <= origin_info.loc.col_offset: + continue + + result[line_loc] = origin_info + + return result + + +# TODO(znado): Consider refactoring this into a Visitor. +# TODO(mdan): Does this work correctly with inner functions? +def resolve(nodes, source, function=None): + """Adds an origin information to all nodes inside the body of function. + + Args: + nodes: Union[ast.AST, Iterable[ast.AST, ...]] + source: Text, the source code string for the function whose body nodes will + be annotated. + function: Callable, the function that will have all nodes inside of it + annotation with an OriginInfo annotation with key anno.Basic.ORIGIN. If + it is None then only the line numbers and column offset will be set in the + annotation, with the rest of the information being None. + + Returns: + A tuple of the AST node for function and a String containing its source + code. + """ + if not isinstance(nodes, (list, tuple)): + nodes = (nodes,) + + if function: + _, function_lineno = tf_inspect.getsourcelines(function) + function_filepath = tf_inspect.getsourcefile(function) + else: + function_lineno = None + function_filepath = None + + # TODO(mdan): Pull this to a separate utility. + code_reader = six.StringIO(source) + comment_map = {} + for token in tokenize.generate_tokens(code_reader.readline): + tok_type, tok_string, loc, _, _ = token + srow, _ = loc + if tok_type == tokenize.COMMENT: + comment_map[srow] = tok_string.strip()[1:].strip() + + source_lines = source.split('\n') + for node in nodes: + for n in gast.walk(node): + if not hasattr(n, 'lineno'): + continue + + lineno_in_body = n.lineno + + source_code_line = source_lines[lineno_in_body - 1] + if function: + source_lineno = function_lineno + lineno_in_body + function_name = function.__name__ + else: + source_lineno = lineno_in_body + function_name = None + + location = Location(function_filepath, source_lineno, n.col_offset) + origin = OriginInfo(location, function_name, + source_code_line, comment_map.get(source_lineno)) + anno.setanno(n, anno.Basic.ORIGIN, origin) diff --git a/tensorflow/contrib/autograph/pyct/origin_info_test.py b/tensorflow/contrib/autograph/pyct/origin_info_test.py new file mode 100644 index 0000000000000000000000000000000000000000..eeaa13007ea0ae331293c216a76352956c0ee9ec --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/origin_info_test.py @@ -0,0 +1,104 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for origin_info module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import origin_info +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.python.platform import test + + +class OriginInfoTest(test.TestCase): + + def test_source_map(self): + + def test_fn(x): + if x > 0: + x += 1 + return x + + node, source = parser.parse_entity(test_fn) + fn_node = node.body[0] + origin_info.resolve(fn_node, source) + + # Insert a traced line. + new_node = parser.parse_str('x = abs(x)').body[0] + anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN) + fn_node.body.insert(0, new_node) + + # Insert an untraced line. + fn_node.body.insert(0, parser.parse_str('x = 0').body[0]) + + modified_source = compiler.ast_to_source(fn_node) + + source_map = origin_info.source_map(fn_node, modified_source, + 'test_filename', [0]) + + loc = origin_info.LineLocation('test_filename', 1) + origin = source_map[loc] + self.assertEqual(origin.source_code_line, 'def test_fn(x):') + self.assertEqual(origin.loc.lineno, 1) + + # The untraced line, inserted second. + loc = origin_info.LineLocation('test_filename', 2) + self.assertFalse(loc in source_map) + + # The traced line, inserted first. + loc = origin_info.LineLocation('test_filename', 3) + origin = source_map[loc] + self.assertEqual(origin.source_code_line, ' if x > 0:') + self.assertEqual(origin.loc.lineno, 2) + + loc = origin_info.LineLocation('test_filename', 4) + origin = source_map[loc] + self.assertEqual(origin.source_code_line, ' if x > 0:') + self.assertEqual(origin.loc.lineno, 2) + + def test_resolve(self): + + def test_fn(x): + """Docstring.""" + return x # comment + + node, source = parser.parse_entity(test_fn) + fn_node = node.body[0] + origin_info.resolve(fn_node, source) + + origin = anno.getanno(fn_node, anno.Basic.ORIGIN) + self.assertEqual(origin.loc.lineno, 1) + self.assertEqual(origin.loc.col_offset, 0) + self.assertEqual(origin.source_code_line, 'def test_fn(x):') + self.assertIsNone(origin.comment) + + origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN) + self.assertEqual(origin.loc.lineno, 2) + self.assertEqual(origin.loc.col_offset, 2) + self.assertEqual(origin.source_code_line, ' """Docstring."""') + self.assertIsNone(origin.comment) + + origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN) + self.assertEqual(origin.loc.lineno, 3) + self.assertEqual(origin.loc.col_offset, 2) + self.assertEqual(origin.source_code_line, ' return x # comment') + self.assertEqual(origin.comment, 'comment') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/parser.py b/tensorflow/contrib/autograph/pyct/parser.py index c961efa892df6a21804dae8f52ef64bf99cd409e..112ed46a1e487a7904e79267c1ce7db0ad914552 100644 --- a/tensorflow/contrib/autograph/pyct/parser.py +++ b/tensorflow/contrib/autograph/pyct/parser.py @@ -37,6 +37,7 @@ def parse_entity(entity): def parse_str(src): """Returns the AST of given piece of code.""" + # TODO(mdan): This should exclude the module things are autowrapped in. return gast.parse(src) diff --git a/tensorflow/contrib/autograph/pyct/qual_names.py b/tensorflow/contrib/autograph/pyct/qual_names.py index da07013cf4f4309c0e24adda3017575d942861b7..fb81404edc1994309f5108fc7e7ba368a1ea3ccb 100644 --- a/tensorflow/contrib/autograph/pyct/qual_names.py +++ b/tensorflow/contrib/autograph/pyct/qual_names.py @@ -30,6 +30,7 @@ import collections import gast from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser class Symbol(collections.namedtuple('Symbol', ['name'])): @@ -89,7 +90,8 @@ class QN(object): if not isinstance(base, (str, StringLiteral, NumberLiteral)): # TODO(mdan): Require Symbol instead of string. raise ValueError( - 'For simple QNs, base must be a string or a Literal object.') + 'for simple QNs, base must be a string or a Literal object;' + ' got instead "%s"' % type(base)) assert '.' not in base and '[' not in base and ']' not in base self._parent = None self.qn = (base,) @@ -112,6 +114,22 @@ class QN(object): raise ValueError('Cannot get parent of simple name "%s".' % self.qn[0]) return self._parent + @property + def owner_set(self): + """Returns all the symbols (simple or composite) that own this QN. + + In other words, if this symbol was modified, the symbols in the owner set + may also be affected. + + Examples: + 'a.b[c.d]' has two owners, 'a' and 'a.b' + """ + owners = set() + if self.has_attr() or self.has_subscript(): + owners.add(self.parent) + owners.update(self.parent.owner_set) + return owners + @property def support_set(self): """Returns the set of simple symbols that this QN relies on. @@ -122,7 +140,7 @@ class QN(object): Examples: 'a.b' has only one support symbol, 'a' - 'a[i]' has two roots, 'a' and 'i' + 'a[i]' has two support symbols, 'a' and 'i' """ # TODO(mdan): This might be the set of Name nodes in the AST. Track those? roots = set() @@ -231,3 +249,9 @@ class QnResolver(gast.NodeTransformer): def resolve(node): return QnResolver().visit(node) + + +def from_str(qn_str): + node = parser.parse_expression(qn_str) + node = resolve(node) + return anno.getanno(node, anno.Basic.QN) diff --git a/tensorflow/contrib/autograph/pyct/qual_names_test.py b/tensorflow/contrib/autograph/pyct/qual_names_test.py index 264afd508cdb847315c486806b531dc1483ef622..c793c2bb39df19f1af9b74f33323dbd4c985ee0d 100644 --- a/tensorflow/contrib/autograph/pyct/qual_names_test.py +++ b/tensorflow/contrib/autograph/pyct/qual_names_test.py @@ -30,6 +30,15 @@ from tensorflow.python.platform import test class QNTest(test.TestCase): + def test_from_str(self): + a = QN('a') + b = QN('b') + a_dot_b = QN(a, attr='b') + a_sub_b = QN(a, subscript=b) + self.assertEqual(qual_names.from_str('a.b'), a_dot_b) + self.assertEqual(qual_names.from_str('a'), a) + self.assertEqual(qual_names.from_str('a[b]'), a_sub_b) + def test_basic(self): a = QN('a') self.assertEqual(a.qn, ('a',)) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD index 8064a967cd389e88d3febbeb21cac87b0fef9e18..92eacba3fd53602ce238dfd7115ff0c3da9b1fc8 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD +++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD @@ -19,14 +19,17 @@ py_library( srcs = [ "activity.py", "annos.py", - "cfg.py", "live_values.py", + "liveness.py", + "reaching_definitions.py", "type_info.py", ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/utils", + "//tensorflow/python:util", "@gast_archive//:gast", ], ) @@ -45,23 +48,32 @@ py_test( ) py_test( - name = "cfg_test", - srcs = ["cfg_test.py"], + name = "live_values_test", + srcs = ["live_values_test.py"], srcs_version = "PY2AND3", tags = ["no_windows"], deps = [ ":static_analysis", "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", - "@gast_archive//:gast", ], ) py_test( - name = "live_values_test", - srcs = ["live_values_test.py"], + name = "liveness_test", + srcs = ["liveness_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":static_analysis", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "reaching_definitions_test", + srcs = ["reaching_definitions_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], deps = [ ":static_analysis", "//tensorflow/contrib/autograph/pyct", diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py b/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py index c325e19f28376da3be6db4b00b9f664eac047af2..9a82de735dc663f6a824488e4c5864943cecc3d4 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py @@ -18,10 +18,14 @@ This module contains utilities to help annotate AST nodes with as much runtime information as can be possibly extracted without actually executing the code, under that assumption that the context in which the code will run is known. -Note: It's a fair bet that this analysis cannot be reused across contexts -without re-running it. In most cases, the context usually means referenced -modules, which should be static enough to allow reuse, but that is not being -reliably verified. +Overall, the different analyses have the functions listed below: + + * activity: inventories symbols read, written to, params, etc. at different + levels + * liveness, reaching_definitions: dataflow analyses based on the program's CFG + and using the symbol information gathered by activity analysis + * live_values, type_info: type and value inference based on dataflow + analysis and context information """ from __future__ import absolute_import diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py index 4d7b0cbb7b8f6ee5bd64553644dc3ec9b8bca95b..a0182da9d132f50f290f4ba4896484815efb1286 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Activity analysis.""" +"""Activity analysis. + +Requires qualified name annotations (see qual_names.py). +""" from __future__ import absolute_import from __future__ import division @@ -59,9 +62,10 @@ class Scope(object): self.parent = parent self.add_unknown_symbols = add_unknown_symbols self.modified = set() + # TODO(mdan): Completely remove this. self.created = set() self.used = set() - self.params = set() + self.params = {} self.returned = set() # TODO(mdan): Rename to `locals` @@ -106,37 +110,23 @@ class Scope(object): self.modified |= other.modified self.created |= other.created self.used |= other.used - self.params |= other.params + self.params.update(other.params) self.returned |= other.returned def has(self, name): - if name in self.modified or name in self.params: + if name in self.modified: return True elif self.parent is not None: return self.parent.has(name) return False - def is_modified_since_entry(self, name): - if name in self.modified: - return True - elif self.parent is not None and not self.isolated: - return self.parent.is_modified_since_entry(name) - return False - - def is_param(self, name): - if name in self.params: - return True - elif self.parent is not None and not self.isolated: - return self.parent.is_param(name) - return False - def mark_read(self, name): self.used.add(name) if self.parent is not None and name not in self.created: self.parent.mark_read(name) - def mark_param(self, name): - self.params.add(name) + def mark_param(self, name, owner): + self.params[name] = owner def mark_creation(self, name, writes_create_symbol=False): """Mark a qualified name as created.""" @@ -226,37 +216,56 @@ class ActivityAnalyzer(transformer.Base): elif isinstance(node.ctx, gast.Param): # Param contexts appear in function defs, so they have the meaning of # defining a variable. - # TODO(mdan): This may be incorrect with nested functions. - # For nested functions, we'll have to add the notion of hiding args from - # the parent scope, not writing to them. - self.scope.mark_creation(qn) - self.scope.mark_param(qn) + self.scope.mark_write(qn) + self.scope.mark_param(qn, self.enclosing_entities[-1]) else: raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn)) anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn)) - anno.setanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY, - self.scope.is_modified_since_entry(qn)) - anno.setanno(node, NodeAnno.IS_PARAM, self.scope.is_param(qn)) if self._in_return_statement: self.scope.mark_returned(qn) + def _enter_scope(self, isolated): + self.scope = Scope(self.scope, isolated=isolated) + + def _exit_scope(self): + self.scope = self.scope.parent + + def _process_statement(self, node): + self._enter_scope(False) + node = self.generic_visit(node) + anno.setanno(node, anno.Static.SCOPE, self.scope) + self._exit_scope() + return node + + def visit_Expr(self, node): + return self._process_statement(node) + + def visit_Return(self, node): + self._in_return_statement = True + node = self._process_statement(node) + self._in_return_statement = False + return node + + def visit_Assign(self, node): + return self._process_statement(node) + def visit_AugAssign(self, node): # Special rules for AugAssign. In Assign, the target is only written, # but in AugAssig (e.g. a += b), the target is both read and written. self._in_aug_assign = True - self.generic_visit(node) + node = self._process_statement(node) self._in_aug_assign = False return node def visit_Name(self, node): - self.generic_visit(node) + node = self.generic_visit(node) self._track_symbol(node) return node def visit_Attribute(self, node): - self.generic_visit(node) + node = self.generic_visit(node) if self._in_constructor and self._node_sets_self_attribute(node): self._track_symbol( node, composite_writes_alter_parent=True, writes_create_symbol=True) @@ -265,44 +274,38 @@ class ActivityAnalyzer(transformer.Base): return node def visit_Subscript(self, node): - self.generic_visit(node) + node = self.generic_visit(node) # Subscript writes (e.g. a[b] = "value") are considered to modify # both the element itself (a[b]) and its parent (a). - self._track_symbol(node, composite_writes_alter_parent=True) + self._track_symbol(node) return node def visit_Print(self, node): - current_scope = self.scope - args_scope = Scope(current_scope) - self.scope = args_scope - for n in node.values: - self.visit(n) - anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope) - self.scope = current_scope + self._enter_scope(False) + node.values = self.visit_block(node.values) + anno.setanno(node, anno.Static.SCOPE, self.scope) + anno.setanno(node, NodeAnno.ARGS_SCOPE, self.scope) + self._exit_scope() return node + def visit_Assert(self, node): + return self._process_statement(node) + def visit_Call(self, node): - current_scope = self.scope - args_scope = Scope(current_scope, isolated=False) - self.scope = args_scope - for n in node.args: - self.visit(n) + self._enter_scope(False) + node.args = self.visit_block(node.args) + node.keywords = self.visit_block(node.keywords) # TODO(mdan): Account starargs, kwargs - for n in node.keywords: - self.visit(n) - anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope) - self.scope = current_scope - self.visit(node.func) + anno.setanno(node, NodeAnno.ARGS_SCOPE, self.scope) + self._exit_scope() + node.func = self.visit(node.func) return node def _process_block_node(self, node, block, scope_name): - current_scope = self.scope - block_scope = Scope(current_scope, isolated=False) - self.scope = block_scope - for n in block: - self.visit(n) - anno.setanno(node, scope_name, block_scope) - self.scope = current_scope + self._enter_scope(False) + block = self.visit_block(block) + anno.setanno(node, scope_name, self.scope) + self._exit_scope() return node def _process_parallel_blocks(self, parent, children): @@ -321,94 +324,75 @@ class ActivityAnalyzer(transformer.Base): self.scope.merge_from(after_child) return parent + def visit_arguments(self, node): + return self._process_statement(node) + def visit_FunctionDef(self, node): - if self.scope: - qn = qual_names.QN(node.name) - self.scope.mark_write(qn) - current_scope = self.scope - body_scope = Scope(current_scope, isolated=True) - self.scope = body_scope - self.generic_visit(node) - anno.setanno(node, NodeAnno.BODY_SCOPE, body_scope) - self.scope = current_scope + # The FunctionDef node itself has a Scope object that tracks the creation + # of its name, along with the usage of any decorator accompany it. + self._enter_scope(False) + node.decorator_list = self.visit_block(node.decorator_list) + self.scope.mark_write(qual_names.QN(node.name)) + anno.setanno(node, anno.Static.SCOPE, self.scope) + self._exit_scope() + + # A separate Scope tracks the actual function definition. + self._enter_scope(True) + node.args = self.visit(node.args) + + # Track the body separately. This is for compatibility reasons, it may not + # be strictly needed. + self._enter_scope(False) + node.body = self.visit_block(node.body) + anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope) + self._exit_scope() + + self._exit_scope() return node def visit_With(self, node): - current_scope = self.scope - with_scope = Scope(current_scope, isolated=False) - self.scope = with_scope - self.generic_visit(node) - anno.setanno(node, NodeAnno.BODY_SCOPE, with_scope) - self.scope = current_scope + self._enter_scope(False) + node = self.generic_visit(node) + anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope) + self._exit_scope() return node - def visit_If(self, node): - current_scope = self.scope - cond_scope = Scope(current_scope, isolated=False) - self.scope = cond_scope - self.visit(node.test) - anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope) - self.scope = current_scope + def visit_withitem(self, node): + return self._process_statement(node) + def visit_If(self, node): + self._enter_scope(False) + node.test = self.visit(node.test) + anno.setanno(node, NodeAnno.COND_SCOPE, self.scope) + anno.setanno(node.test, anno.Static.SCOPE, self.scope) + self._exit_scope() node = self._process_parallel_blocks(node, ((node.body, NodeAnno.BODY_SCOPE), (node.orelse, NodeAnno.ORELSE_SCOPE))) return node def visit_For(self, node): - self.visit(node.target) - self.visit(node.iter) + self._enter_scope(False) + node.target = self.visit(node.target) + node.iter = self.visit(node.iter) + anno.setanno(node.iter, anno.Static.SCOPE, self.scope) + self._exit_scope() node = self._process_parallel_blocks(node, ((node.body, NodeAnno.BODY_SCOPE), (node.orelse, NodeAnno.ORELSE_SCOPE))) return node def visit_While(self, node): - current_scope = self.scope - cond_scope = Scope(current_scope, isolated=False) - self.scope = cond_scope - self.visit(node.test) - anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope) - self.scope = current_scope - + self._enter_scope(False) + node.test = self.visit(node.test) + anno.setanno(node, NodeAnno.COND_SCOPE, self.scope) + anno.setanno(node.test, anno.Static.SCOPE, self.scope) + self._exit_scope() node = self._process_parallel_blocks(node, ((node.body, NodeAnno.BODY_SCOPE), (node.orelse, NodeAnno.ORELSE_SCOPE))) return node - def visit_Return(self, node): - self._in_return_statement = True - node = self.generic_visit(node) - self._in_return_statement = False - return node - - -def get_read(node, context): - """Return the variable names as QNs (qual_names.py) read by this statement.""" - analyzer = ActivityAnalyzer(context, None, True) - analyzer.visit(node) - return analyzer.scope.used - - -def get_updated(node, context): - """Return the variable names created or mutated by this statement. - - This function considers assign statements, augmented assign statements, and - the targets of for loops, as well as function arguments. - For example, `x[0] = 2` will return `x`, `x, y = 3, 4` will return `x` and - `y`, `for i in range(x)` will return `i`, etc. - Args: - node: An AST node - context: An EntityContext instance - - Returns: - A set of variable names (QNs, see qual_names.py) of all the variables - created or mutated. - """ - analyzer = ActivityAnalyzer(context, None, True) - analyzer.visit(node) - return analyzer.scope.created | analyzer.scope.modified - def resolve(node, context, parent_scope=None): return ActivityAnalyzer(context, parent_scope).visit(node) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py index fdbd349af9d3325af114a7206d89617134278f14..e940516190182a905f5747ffdd66533567bac76b 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py @@ -21,9 +21,9 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.qual_names import QN from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -52,18 +52,18 @@ class ScopeTest(test.TestCase): other = activity.Scope(None) other.copy_from(scope) - self.assertTrue(QN('foo') in other.created) + self.assertTrue(QN('foo') in other.modified) scope.mark_write(QN('bar')) scope.copy_from(other) - self.assertFalse(QN('bar') in scope.created) + self.assertFalse(QN('bar') in scope.modified) scope.mark_write(QN('bar')) scope.merge_from(other) - self.assertTrue(QN('bar') in scope.created) - self.assertFalse(QN('bar') in other.created) + self.assertTrue(QN('bar') in scope.modified) + self.assertFalse(QN('bar') in other.modified) def test_copy_of(self): scope = activity.Scope(None) @@ -112,18 +112,16 @@ class ActivityAnalyzerTest(test.TestCase): def _parse_and_analyze(self, test_fn): node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=None, + entity_info = transformer.EntityInfo( source_code=source, source_file=None, namespace={}, arg_values=None, arg_types=None, - owner_type=None, - recursive=True) + owner_type=None) node = qual_names.resolve(node) - node = activity.resolve(node, ctx) - return node, ctx + node = activity.resolve(node, entity_info) + return node, entity_info def test_local_markers(self): @@ -159,7 +157,8 @@ class ActivityAnalyzerTest(test.TestCase): """Assert the scope contains specific used, modified & created variables.""" self.assertSymbolSetsAre(used, scope.used, 'read') self.assertSymbolSetsAre(modified, scope.modified, 'modified') - self.assertSymbolSetsAre(created, scope.created, 'created') + # Created is deprecated, we're no longer verifying it. + # self.assertSymbolSetsAre(created, scope.created, 'created') def test_print_statement(self): @@ -217,12 +216,6 @@ class ActivityAnalyzerTest(test.TestCase): (), (), ) - self.assertScopeIsRmc( - anno.getanno(call_node, NodeAnno.ARGS_SCOPE).parent, - ('a', 'a.b', 'a.c', 'a.d', 'foo'), - ('a.c',), - ('a',), - ) def test_call_args_subscripts(self): @@ -243,12 +236,6 @@ class ActivityAnalyzerTest(test.TestCase): (), (), ) - self.assertScopeIsRmc( - anno.getanno(call_node, NodeAnno.ARGS_SCOPE).parent, - ('a', 'a[0]', 'a[b]', 'a[c]', 'b', 'c', 'foo'), - ('b', 'c'), - ('a', 'b', 'c'), - ) def test_while(self): @@ -364,20 +351,20 @@ class ActivityAnalyzerTest(test.TestCase): self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a', 'b', 'c', 'a[c]'), - ('a', 'a[b]', 'd'), + ('a[b]', 'd'), ('d',), ) # TODO(mdan): Should subscript writes (a[0] = 1) be considered to read "a"? self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('a', 'e'), - ('a', 'a[0]', 'd'), + ('a[0]', 'd'), ('d',), ) self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('a', 'b', 'c', 'd', 'e', 'a[c]'), - ('a', 'd', 'a[b]', 'a[0]'), + ('d', 'a[b]', 'a[0]'), ('a', 'b', 'c', 'd', 'e'), ) @@ -417,10 +404,6 @@ class ActivityAnalyzerTest(test.TestCase): node, _ = self._parse_and_analyze(test_fn) fn_def_node = node.body[0].body[0] - self.assertScopeIsRmc( - anno.getanno(fn_def_node, - NodeAnno.BODY_SCOPE).parent, ('b', 'i', 'f', 'c', 'a'), - ('f', 'b', 'c', 'i'), ('f', 'a', 'b', 'c', 'i')) self.assertScopeIsRmc( anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',), ( 'x', @@ -454,7 +437,7 @@ class ActivityAnalyzerTest(test.TestCase): self.assertScopeIsRmc( anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('a', 'a[0]'), - ('a', 'a[0]'), + ('a[0]',), ('a',), ) @@ -520,47 +503,6 @@ class ActivityAnalyzerTest(test.TestCase): anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('b',), (('')), (('a', 'b'))) - def test_get_read(self): - - def test_fn(x, y): - z = test_fn(x, y) - return z - - node, ctx = self._parse_and_analyze(test_fn) - node = node.body[0].body[0] - read_vars = activity.get_read(node, ctx) - self.assertEqual(read_vars, set(map(qual_names.QN, ('test_fn', 'x', 'y')))) - - def test_fn2(x, y, z): - z += test_fn2(x, y, z) - return z - - node, ctx = self._parse_and_analyze(test_fn2) - node = node.body[0].body[0] - read_vars = activity.get_read(node, ctx) - self.assertEqual(read_vars, - set(map(qual_names.QN, ('test_fn2', 'x', 'y', 'z')))) - - def test_get_updated(self): - - def test_fn(x, y): - z = test_fn(x, y) - return z - - node, ctx = self._parse_and_analyze(test_fn) - node = node.body[0].body[0] - updated_vars = activity.get_updated(node, ctx) - self.assertEqual(updated_vars, set(map(qual_names.QN, ('z')))) - - def test_fn2(x, y, z): - z += test_fn2(x, y, z) - return z - - node, ctx = self._parse_and_analyze(test_fn2) - node = node.body[0].body[0] - updated_vars = activity.get_updated(node, ctx) - self.assertEqual(updated_vars, set(map(qual_names.QN, ('z')))) - if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py index b929b35b79200b0968c9c4f26b10cda28763773a..5eefecf278992f73464817585a3498de4c031978 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py @@ -21,6 +21,9 @@ from __future__ import print_function from enum import Enum +# TODO(mdan): Remove. + + class NoValue(Enum): def __repr__(self): @@ -50,10 +53,3 @@ class NodeAnno(NoValue): ORELSE_SCOPE = ( 'The scope for the orelse body of a statement (False branch for if ' 'statements, orelse body for loops).') - - # Type and Value annotations - # Type annotations are represented by objects of type type_info.Type. - STATIC_INFO = ( - 'The type or value information that should be asserted about the entity ' - 'referenced by the symbol holding this annotation, irrespective of the ' - 'execution context.') diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py deleted file mode 100644 index ad97fdfa8e78d1fd4c38724612d83519c6609cce..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py +++ /dev/null @@ -1,445 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Control flow graph analysis. - -Given a Python AST we construct a control flow graph, with edges both to the -next and previous statements (so it can easily walk the graph both ways). Its -nodes contain the AST of the statements. It can then perform forward or backward -analysis on this CFG. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from collections import namedtuple -import functools -import operator - -import gast - -from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct.static_analysis import activity - - -class CfgNode(object): - """A node in the CFG.""" - __slots__ = ['next', 'value', 'prev'] - - def __init__(self, value): - self.next = set() - self.prev = set() - self.value = value - - -class Cfg(namedtuple('Cfg', ['entry', 'exit'])): - """A Control Flow Graph. - - Each statement is represented as a node. For control flow statements such - as conditionals and loops the conditional itself is a node which either - branches or cycles, respectively. - Attributes: - entry: The entry node, which contains the `gast.arguments` node of the - function definition. - exit: The exit node. This node is special because it has no value (i.e. no - corresponding AST node). This is because Python functions can have - multiple return statements. - """ - pass - - -class CfgBuilder(gast.NodeVisitor): - """Construct a control flow graph. - - Construct a CFG starting from a FunctionDef node. - Usage: - cfg_obj = CfgBuilder().build_cfg(fndef_node) - """ - - def __init__(self): - # The current leaves of the CFG - self.current_leaves = [] - # TODO(alexbw): generalize to break, return, continue, yield, etc. - # A stack of lists, tracking continue statements - self.continue_ = [] - # A stack of lists tracking break nodes - self.break_ = [] - - def set_current_leaves(self, cfg_node): - """Link this cfg_node to the current leaves. - - This is the central function for building the CFG. It links the current - head cfg_nodes to the passed cfg_node. It then resets the head to the - passed cfg_node. - - Args: - cfg_node: A CfgNode instance. - """ - for head in self.current_leaves: - head.next.add(cfg_node) - # While we're linking the CFG forward, add backlinks - cfg_node.prev.add(head) - self.current_leaves = [cfg_node] - - def build_cfg(self, node): - """Build a CFG for a function. - - Implementation of building a CFG for dataflow analysis. See, e.g.: - https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec02-Dataflow.pdf - - Args: - node: A function definition the body of which to analyze. - Returns: - A CFG object. - Raises: - TypeError: If the input is not a function definition. - """ - if not isinstance(node, gast.FunctionDef): - raise TypeError('input must be a function definition') - entry_cfg_node = CfgNode(node.args) - self.current_leaves = [entry_cfg_node] - self.visit_statements(node.body) - exit_cfg_node = CfgNode(None) - self.set_current_leaves(exit_cfg_node) - return Cfg(entry_cfg_node, exit_cfg_node) - - def visit_statements(self, nodes): - for node in nodes: - # Check for control flow - if isinstance(node, (gast.For, gast.While, gast.If, gast.Try, gast.Break, - gast.Continue, gast.With)): - self.visit(node) - else: - expr = CfgNode(node) - self.set_current_leaves(expr) - - def generic_visit(self, node): - raise ValueError('unknown control flow') - - def visit_If(self, node): - # TODO(alexbw): change this to use immutable tuples instead of lists - # The current head will hold the conditional - test = CfgNode(node.test) - self.set_current_leaves(test) - # Handle the body - self.visit_statements(node.body) - body_exit = self.current_leaves - self.current_leaves = [test] - # Handle the orelse - self.visit_statements(node.orelse) - self.current_leaves.extend(body_exit) - - def visit_While(self, node): - test = CfgNode(node.test) - self.set_current_leaves(test) - # Start a new level of nesting - self.break_.append([]) - self.continue_.append([]) - # Handle the body - self.visit_statements(node.body) - body_exit = self.current_leaves - self.current_leaves.extend(self.continue_.pop()) - self.set_current_leaves(test) - # Handle the orelse - self.visit_statements(node.orelse) - # The break statements and the test go to the next node - self.current_leaves.extend(self.break_.pop()) - # Body and orelse statements can reach out of the loop - self.current_leaves.extend(body_exit) - - def visit_For(self, node): - iter_ = CfgNode(node.iter) - self.set_current_leaves(iter_) - self.break_.append([]) - self.continue_.append([]) - self.visit_statements(node.body) - body_exit = self.current_leaves - self.current_leaves.extend(self.continue_.pop()) - self.set_current_leaves(iter_) - # Handle the orelse - self.visit_statements(node.orelse) - # The break statements and the test go to the next node - self.current_leaves.extend(self.break_.pop()) - # Body and orelse statements can reach out of the loop - self.current_leaves.extend(body_exit) - - def visit_Break(self, node): - self.break_[-1].extend(self.current_leaves) - self.current_leaves[:] = [] - - def visit_Continue(self, node): - self.continue_[-1].extend(self.current_leaves) - self.current_leaves[:] = [] - - def visit_Try(self, node): - self.visit_statements(node.body) - body = self.current_leaves - handlers = [] - for handler in node.handlers: - self.current_leaves = body[:] - self.visit_statements(handler.body) - handlers.extend(self.current_leaves) - self.current_leaves = body - self.visit_statements(node.orelse) - self.current_leaves = handlers + self.current_leaves - self.visit_statements(node.finalbody) - - def visit_With(self, node): - for item in node.items: - self.set_current_leaves(CfgNode(item)) - self.visit_statements(node.body) - - -# TODO(alexbw): once CFG analysis occurs at a block level, -# this extra class will not be necessary -class PropagateAnalysis(gast.NodeVisitor): - """Port analysis annotations from statements to their enclosing blocks.""" - - def __init__(self, analysis): - self.transfer_fn = analysis.transfer_fn - self.in_label = analysis.in_label - self.out_label = analysis.out_label - super(PropagateAnalysis, self).__init__() - - def visit_If(self, node): - # Depth-first. - self.generic_visit(node) - incoming = anno.getanno(node.body[0], self.in_label) - incoming |= anno.getanno(node.test, self.in_label) - outgoing = anno.getanno(node.body[-1], self.out_label) - outgoing |= anno.getanno(node.test, self.out_label) - if node.orelse: - orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) - outgoing = self.transfer_fn(outgoing, orelse_outgoing) - anno.setanno(node, self.in_label, incoming) - anno.setanno(node, self.out_label, outgoing) - - def visit_For(self, node): - self.generic_visit(node) - incoming = set(anno.getanno(node.body[0], self.in_label)) - incoming -= set((anno.getanno(node.target, anno.Basic.QN),)) - outgoing = anno.getanno(node.body[-1], self.out_label) - if node.orelse: - orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) - outgoing = self.transfer_fn(outgoing, orelse_outgoing) - anno.setanno(node, self.in_label, frozenset(incoming)) - anno.setanno(node, self.out_label, outgoing) - - def visit_While(self, node): - self.generic_visit(node) - incoming = anno.getanno(node.body[0], self.in_label) - incoming |= anno.getanno(node.test, self.in_label) - outgoing = anno.getanno(node.body[-1], self.out_label) - if node.orelse: - orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) - outgoing = self.transfer_fn(outgoing, orelse_outgoing) - anno.setanno(node, self.in_label, incoming) - anno.setanno(node, self.out_label, outgoing) - - def visit_With(self, node): - self.generic_visit(node) - incoming = anno.getanno(node.body[0], self.in_label) - for item in node.items: - incoming |= anno.getanno(item, self.in_label) - outgoing = anno.getanno(node.body[-1], self.out_label) - anno.setanno(node, self.in_label, incoming) - anno.setanno(node, self.out_label, outgoing) - - -# TODO(alexbw): Abstract the CFG walking machinery into a superclass -# which is parameterized on which fields it selects when walking. -# TODO(alexbw): Abstract the application of dataflow analysis -class Forward(object): - """Forward analysis on CFG. - - Args: - label: A name for this analysis e.g. 'active' for activity analysis. The AST - nodes in the CFG will be given annotations 'name_in', 'name_out', - 'name_gen' and 'name_kill' which contain the incoming values, outgoing - values, values generated by the statement, and values deleted by the - statement respectively. - transfer_fn: Either the AND or OR operator. If the AND operator is used it - turns into forward must analysis (i.e. a value will only be carried - forward if it appears on all incoming paths). The OR operator means that - forward may analysis is done (i.e. the union of incoming values will be - taken). - """ - - def __init__(self, label, context, transfer_fn=operator.or_): - self.transfer_fn = transfer_fn - self.context = context - self.out_label = label + '_out' - self.in_label = label + '_in' - self.gen_label = label + '_gen' - self.kill_label = label + '_kill' - - # TODO(alexbw): see if we can simplify by visiting breadth-first - def visit(self, node): - """Depth-first walking the CFG, applying dataflow information propagtion.""" - # node.value is None only for the exit CfgNode. - if not node.value: - return - - if anno.hasanno(node.value, self.out_label): - before = hash(anno.getanno(node.value, self.out_label)) - else: - before = None - preds = [ - anno.getanno(pred.value, self.out_label) - for pred in node.prev - if anno.hasanno(pred.value, self.out_label) - ] - if preds: - incoming = functools.reduce(self.transfer_fn, preds[1:], preds[0]) - else: - incoming = frozenset() - anno.setanno(node.value, self.in_label, incoming) - gen, kill = self.get_gen_kill(node, incoming) - anno.setanno(node.value, self.gen_label, gen) - anno.setanno(node.value, self.kill_label, kill) - anno.setanno(node.value, self.out_label, (incoming - kill) | gen) - - if hash(anno.getanno(node.value, self.out_label)) != before: - for succ in node.next: - self.visit(succ) - - def get_gen_kill(self, cfg_node, incoming): - """Calculate Gen and Kill properties of a CFG node in dataflow analysis. - - A function which takes the CFG node as well as a set of incoming - values. It must return a set of newly generated values by the statement as - well as a set of deleted (killed) values. - - Args: - cfg_node: A CfgNode instance. - incoming: - """ - raise NotImplementedError() - - -class Backward(Forward): - """Backward analysis on CFG.""" - - def visit(self, cfg_node): - # cfg_node.value is None for the exit node, which will be visited only once - if not cfg_node.value: - for pred in cfg_node.prev: - self.visit(pred) - return - - if anno.hasanno(cfg_node.value, self.in_label): - before = hash(anno.getanno(cfg_node.value, self.in_label)) - else: - before = None - succs = [ - anno.getanno(succ.value, self.in_label) - for succ in cfg_node.next - if anno.hasanno(succ.value, self.in_label) - ] - if succs: - incoming = functools.reduce(self.transfer_fn, succs[1:], succs[0]) - else: - incoming = frozenset() - anno.setanno(cfg_node.value, self.out_label, incoming) - gen, kill = self.get_gen_kill(cfg_node, incoming) - anno.setanno(cfg_node.value, self.gen_label, gen) - anno.setanno(cfg_node.value, self.kill_label, kill) - anno.setanno(cfg_node.value, self.in_label, (incoming - kill) | gen) - if hash(anno.getanno(cfg_node.value, self.in_label)) != before: - for pred in cfg_node.prev: - self.visit(pred) - - -def run_analyses(node, analyses): - """Perform dataflow analysis on all functions within an AST. - - Args: - node: An AST node on which to run dataflow analysis. - analyses: Either an instance of the Forward or Backward dataflow analysis - class, or a list or tuple of them. - - Returns: - node: The node, but now with annotations on the AST nodes containing the - results of the dataflow analyses. - """ - if not isinstance(analyses, (tuple, list)): - analyses = (analyses,) - for analysis in analyses: - if not isinstance(analysis, (Forward, Backward)): - raise TypeError('not a valid forward analysis object') - - for child_node in gast.walk(node): - if isinstance(child_node, gast.FunctionDef): - cfg_obj = CfgBuilder().build_cfg(child_node) - for analysis in analyses: - if isinstance(analysis, Backward): - analysis.visit(cfg_obj.exit) - elif isinstance(analysis, Forward): - analysis.visit(cfg_obj.entry) - for analysis in analyses: - PropagateAnalysis(analysis).visit(node) - return node - - -class Liveness(Backward): - """Perform a liveness analysis. - - Each statement is annotated with a set of variables that may be used - later in the program. - """ - - def __init__(self, context): - super(Liveness, self).__init__('live', context) - - def get_gen_kill(self, node, _): - # A variable's parents are live if it is live - # e.g. x is live if x.y is live. This means gen needs to return - # all parents of a variable (if it's an Attribute or Subscript). - # This doesn't apply to kill (e.g. del x.y doesn't affect liveness of x) - gen = activity.get_read(node.value, self.context) - gen = functools.reduce(lambda left, right: left | right.support_set, gen, - gen) - kill = activity.get_updated(node.value, self.context) - return gen, kill - - -class ReachingDefinitions(Forward): - """Perform reaching definition analysis. - - Each statement is annotated with a set of (variable, definition) pairs. - """ - - def __init__(self, context): - super(ReachingDefinitions, self).__init__('definitions', context) - - def get_gen_kill(self, node, incoming): - definitions = activity.get_updated(node.value, self.context) - gen = frozenset((id_, node.value) for id_ in definitions) - kill = frozenset(def_ for def_ in incoming if def_[0] in definitions) - return gen, kill - - -class Defined(Forward): - """Perform defined variable analysis. - - Each statement is annotated with a set of variables which are guaranteed to - be defined at that point. - """ - - def __init__(self, context): - super(Defined, self).__init__('defined', context, transfer_fn=operator.and_) - - def get_gen_kill(self, node, _): - gen = activity.get_updated(node.value, self.context) - return gen, frozenset() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py deleted file mode 100644 index fc07fa3447b23c0595a5893329de8a2d7055ca15..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py +++ /dev/null @@ -1,306 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for cfg module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools - -import gast - -from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context -from tensorflow.contrib.autograph.pyct import parser -from tensorflow.contrib.autograph.pyct import qual_names -from tensorflow.contrib.autograph.pyct.static_analysis import cfg -from tensorflow.python.platform import test - - -class CFGTest(test.TestCase): - - def _parse_and_analyze(self, test_fn, namespace, arg_types=None): - arg_types = arg_types or {} - node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=None, - source_code=source, - source_file=None, - namespace=namespace, - arg_values=None, - arg_types=arg_types, - owner_type=None, - recursive=True) - node = qual_names.resolve(node) - return node, ctx - - def _check_anno_matches(self, node, anno_name, var_names): - if isinstance(var_names, str): - var_names = (var_names,) - qual_vars = set() - for var_name in var_names: - if isinstance(var_name, str): - if '[' in var_name or ']' in var_name: - raise ValueError('Annotation matching not supported with subscript.') - if '.' not in var_name: - qual_vars.add(qual_names.QN(var_name)) - else: - attrs = var_name.split('.') - this_qn = functools.reduce(qual_names.QN, attrs[1:], - qual_names.QN(attrs[0])) - qual_vars.add(this_qn) - self.assertEqual(anno.getanno(node, anno_name), qual_vars) - - def test_reaching(self): - - def f(x): - print(x) - while True: - x = x - x = x - return x - - node, ctx = self._parse_and_analyze(f, {}) - cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) - body = node.body[0].body - # Only the argument reaches the expression - def_in = anno.getanno(body[0], 'definitions_in') - # One element, x, from arguments - self.assertEqual(set(type(d[1]) for d in def_in), set((gast.arguments,))) - - while_body = body[1].body - def_in = anno.getanno(while_body[0], 'definitions_in') - # One definition, two possible sources. - # - One from an assignment (if the loop is entered) - # - The other from the arguments (if loop is not entered) - self.assertEqual( - set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign))) - - def_in = anno.getanno(while_body[1], 'definitions_in') - # If we've reached this line, the only reaching definition of x is the - # Assign node in previous line - self.assertEqual(set(type(d[1]) for d in def_in), set((gast.Assign,))) - - def_in = anno.getanno(body[2], 'definitions_in') - # Same situation as while_body[0] - self.assertEqual( - set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign))) - - def test_defined(self): - - def f(x): - if x: - y = 2 # pylint: disable=unused-variable - return x - - node, ctx = self._parse_and_analyze(f, {}) - cfg.run_analyses(node, cfg.Defined(ctx)) - body = node.body[0].body - # only x is for sure defined at the end - self._check_anno_matches(body[1], 'defined_in', 'x') - # at the end of the if body both x and y are defined - if_body = body[0].body - self._check_anno_matches(if_body[0], 'defined_out', ('x', 'y')) - - def _get_live_annotated_fnbody(self, f): - node, ctx = self._parse_and_analyze(f, {}) - cfg.run_analyses(node, cfg.Liveness(ctx)) - body = node.body[0].body - return body - - def test_live_straightline(self): - - def f1(x): - a = g(x) # pylint: disable=undefined-variable - b = h(a) # pylint: disable=undefined-variable, unused-variable - return x - - body = self._get_live_annotated_fnbody(f1) - self._check_anno_matches(body[1], 'live_in', ('a', 'h', 'x')) - self._check_anno_matches(body[2], 'live_in', ('x')) - self._check_anno_matches(body[0], 'live_in', ('g', 'h', 'x')) - self._check_anno_matches(body[2], 'live_out', ()) - - def test_live_stacked_conds_with_else(self): - - def f2(x, a): # pylint: disable=unused-argument - if a > 0: # x should not be live - x = 0 - if a > 1: - x = 1 - else: - x = 2 - - body = self._get_live_annotated_fnbody(f2) - self._check_anno_matches(body[0], 'live_in', ('a')) - self._check_anno_matches(body[1], 'live_in', ('a')) - - def test_live_stacked_conds(self): - - def f3(x, a): - if a > 0: # x and a should be live - x = 0 - if a > 1: # x and a should be live_in - x = 1 - return x # x should be live - - body = self._get_live_annotated_fnbody(f3) - self._check_anno_matches(body[0], 'live_in', ('a', 'x')) - self._check_anno_matches(body[1], 'live_in', ('a', 'x')) - self._check_anno_matches(body[2], 'live_in', ('x')) - - def test_live_possibly_unused_cond(self): - - def f4(x, a): - if a > 0: # x should be live - x = 0 - x += 1 - - body = self._get_live_annotated_fnbody(f4) - self._check_anno_matches(body[0], 'live_in', ('x', 'a')) - self._check_anno_matches(body[1], 'live_in', ('x')) - - def test_live_attribute_in_cond(self): - - def f5(x, a): - if a > 0: # x.y should be live - x.y = 0 - return x.y - - body = self._get_live_annotated_fnbody(f5) - self._check_anno_matches(body[0], 'live_in', ('x', 'x.y', 'a')) - - def test_live_noop(self): - - def f6(x): - return x # should this cause x.* to be live? - - body = self._get_live_annotated_fnbody(f6) - self._check_anno_matches(body[0], 'live_in', ('x')) - - def test_live_loop(self): - - def f7(x, n): - for i in range(n): - x += i - return x - - body = self._get_live_annotated_fnbody(f7) - self._check_anno_matches(body[0], 'live_in', ('x', 'n', 'range')) - self._check_anno_matches(body[1], 'live_in', ('x')) - - def test_live_context_manager(self): - - def f8(x, f): - with f: - x += 1 - - body = self._get_live_annotated_fnbody(f8) - self._check_anno_matches(body[0], 'live_in', ('f', 'x')) - - def test_node_equality(self): - node_a = gast.parse('y = x').body[0] - node_b = gast.parse('y = x').body[0] - self.assertNotEqual(node_a, node_b) - - def test_nested_functions_defined(self): - - def f(x): - y = x * 2 - - def g(z): - return z + y - - return g(x) - - node, ctx = self._parse_and_analyze(f, {}) - cfg.run_analyses(node, cfg.Defined(ctx)) - - body = node.body[0].body - self.assertEqual( - anno.getanno(body[2], 'defined_in'), - frozenset(map(qual_names.QN, ('g', 'x', 'y')))) - - # TODO(alexbw): CFG analysis doesn't currently cross FunctionDef boundaries. - # NOTE: 'z' is easy to find, but 'y' is not identified as - # defined, because CFG analysis is applied with each function separately. - # fndef_body = body[1].body - # self.assertEqual( - # anno.getanno(fndef_body[0], 'defined_in'), - # frozenset(map(qual_names.QN, ('z', 'y')))) - - def test_nested_functions_dont_leak_definitions(self): - - def f(x): - print(x) - - def g(): - y = 2 - return y - - return g() # y is not defined here - - node, ctx = self._parse_and_analyze(f, {}) - cfg.run_analyses(node, cfg.Defined(ctx)) - body = node.body[0].body - self.assertEqual( - anno.getanno(body[2], 'defined_in'), - frozenset(map(qual_names.QN, ('x', 'g')))) - - def test_loop_else(self): - - # Disabling useless-else-on-loop error, because 'break' and 'continue' - # canonicalization are a separate analysis pass, and here we test - # the CFG analysis in isolation. - def for_orelse(x): - y = 0 - for i in range(len(x)): - x += i - else: # pylint: disable=useless-else-on-loop - y = 1 - return x, y - - def while_orelse(x, i): - y = 0 - while x < 10: - x += i - else: # pylint: disable=useless-else-on-loop - y = 1 - return x, y - - for f in (for_orelse, while_orelse): - node, ctx = self._parse_and_analyze(f, {}) - cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) - body = node.body[0].body - return_node = body[-1] - reaching_defs = anno.getanno(return_node, 'definitions_in') - - # Y could be defined by Assign(Num(0)) or Assign(Num(1)) - # X could be defined as an argument or an AugAssign. - y_defs = [node for var, node in reaching_defs if str(var) == 'y'] - x_defs = [node for var, node in reaching_defs if str(var) == 'x'] - - self.assertEqual(set((gast.Assign,)), set(type(def_) for def_ in y_defs)) - self.assertEqual(set((0, 1)), set(def_.value.n for def_ in y_defs)) - self.assertEqual(len(y_defs), 2) - self.assertEqual( - set((gast.arguments, gast.AugAssign)), - set(type(def_) for def_ in x_defs)) - self.assertEqual(len(x_defs), 2) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py index 53ae15459097baff918432a493edd7360ebf209d..2d8f922a4589e45ab7e4f20f800e0ffef3d7f0a5 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py @@ -16,7 +16,7 @@ Live values are extracted from the known execution context. -Requires activity analysis annotations. +Requires activity and reaching definitions analyses. """ from __future__ import absolute_import @@ -39,24 +39,22 @@ class LiveValueResolver(transformer.Base): def visit_ClassDef(self, node): self.generic_visit(node) - anno.setanno(node, 'live_val', self.context.namespace[node.name]) + anno.setanno(node, 'live_val', self.entity_info.namespace[node.name]) return node def visit_Name(self, node): self.generic_visit(node) if isinstance(node.ctx, gast.Load): - assert anno.hasanno(node, NodeAnno.IS_LOCAL), node - symbol_is_local = anno.getanno(node, NodeAnno.IS_LOCAL) - assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node - symbol_is_modified = anno.getanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY) - assert anno.hasanno(node, NodeAnno.IS_PARAM), node - symbol_is_param = anno.getanno(node, NodeAnno.IS_PARAM) - - if not symbol_is_local and not symbol_is_param: + defs = anno.getanno(node, anno.Static.DEFINITIONS, ()) + + is_defined = bool(defs) + has_single_def = len(defs) == 1 + + if not is_defined: if node.id in self.literals: anno.setanno(node, 'live_val', self.literals[node.id]) - elif node.id in self.context.namespace: - obj = self.context.namespace[node.id] + elif node.id in self.entity_info.namespace: + obj = self.entity_info.namespace[node.id] anno.setanno(node, 'live_val', obj) if hasattr(obj, '__name__'): anno.setanno(node, 'fqn', (obj.__name__,)) @@ -79,11 +77,13 @@ class LiveValueResolver(transformer.Base): # TODO(mdan): Attempt to trace its value through the local chain. # TODO(mdan): Use type annotations as fallback. - if not symbol_is_modified: - if node.id in self.context.arg_values: - obj = self.context.arg_values[node.id] - anno.setanno(node, 'live_val', obj) - anno.setanno(node, 'fqn', (obj.__class__.__name__,)) + if has_single_def: + def_, = defs + if def_.param_of is self.enclosing_entities[0]: + if node.id in self.entity_info.arg_values: + obj = self.entity_info.arg_values[node.id] + anno.setanno(node, 'live_val', obj) + anno.setanno(node, 'fqn', (obj.__class__.__name__,)) return node def visit_Attribute(self, node): @@ -91,12 +91,20 @@ class LiveValueResolver(transformer.Base): if anno.hasanno(node.value, 'live_val'): assert anno.hasanno(node.value, 'fqn') parent_object = anno.getanno(node.value, 'live_val') - if not hasattr(parent_object, node.attr): - raise AttributeError('%s has no attribute %s' % (parent_object, - node.attr)) + anno.setanno(node, 'parent_type', type(parent_object)) - anno.setanno(node, 'live_val', getattr(parent_object, node.attr)) anno.setanno(node, 'fqn', anno.getanno(node.value, 'fqn') + (node.attr,)) + if hasattr(parent_object, node.attr): + # This can happen when the attribute's creation and use depend on the + # same static condition, for example: + # + # if cond: + # foo.bar = baz + # if cond: + # x = foo.bar + # + anno.setanno(node, 'live_val', getattr(parent_object, node.attr)) + # TODO(mdan): Investigate the role built-in annotations can play here. elif anno.hasanno(node.value, 'type'): parent_type = anno.getanno(node.value, 'type') diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py index 69e428bde109ed43c3cdda1a94970a832dc47852..fe3051179cd93ddd2627802dd2536bb50f17fb7f 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py @@ -21,11 +21,13 @@ from __future__ import print_function import six from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import cfg from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.framework import constant_op from tensorflow.python.platform import test @@ -39,22 +41,22 @@ class LiveValuesResolverTest(test.TestCase): literals=None, arg_types=None): literals = literals or {} - arg_types = arg_types or {} node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=None, + entity_info = transformer.EntityInfo( source_code=source, source_file=None, namespace=namespace, arg_values=None, arg_types=arg_types, - owner_type=None, - recursive=True) + owner_type=None) node = qual_names.resolve(node) - node = activity.resolve(node, ctx) - node = live_values.resolve(node, ctx, literals) - node = type_info.resolve(node, ctx) - node = live_values.resolve(node, ctx, literals) + graphs = cfg.build(node) + node = activity.resolve(node, entity_info) + node = reaching_definitions.resolve(node, entity_info, graphs, + reaching_definitions.Definition) + node = live_values.resolve(node, entity_info, literals) + node = type_info.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, literals) return node def test_literals(self): diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/liveness.py b/tensorflow/contrib/autograph/pyct/static_analysis/liveness.py new file mode 100644 index 0000000000000000000000000000000000000000..bf29d868a2e4d2a4c7dd1057c0ed93e54d01d750 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/liveness.py @@ -0,0 +1,200 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Live variable analysis. + +This analysis attaches a set containing the live symbols that are live at the +exit of control flow statements. + +Requires activity analysis. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import cfg +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis import annos + + +class Analyzer(cfg.GraphVisitor): + """CFG visitor that performs liveness analysis at statement level.""" + + def __init__(self, graph): + super(Analyzer, self).__init__(graph) + # This allows communicating that nodes generate extra symbols, + # e.g. those that a function definition closes over. + self.extra_gen = {} + + def init_state(self, _): + return set() + + def visit_node(self, node): + prev_live_in = self.in_[node] + + if anno.hasanno(node.ast_node, anno.Static.SCOPE): + node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE) + + gen = node_scope.used | self.extra_gen.get(node.ast_node, frozenset()) + # TODO(mdan): verify whether composites' parents need to be added. + # E.g. if x.y is live whether x needs to be added. Theoretically the + # activity analysis should have both so that wouldn't be needed. + kill = node_scope.modified + + live_out = set() + for n in node.next: + live_out |= self.in_[n] + live_in = gen | (live_out - kill) + + else: + # Nodes that don't have a scope annotation are assumed not to touch any + # symbols. + # This Name node below is a literal name, e.g. False + assert isinstance(node.ast_node, + (gast.Name, gast.Continue, gast.Break)), type( + node.ast_node) + live_in = prev_live_in + live_out = live_in + + self.in_[node] = live_in + self.out[node] = live_out + + # TODO(mdan): Move this to the superclass? + return prev_live_in != live_in + + +class WholeTreeAnalyzer(transformer.Base): + """Runs liveness analysis on each of the functions defined in the AST. + + If a function defined other local functions, those will have separate CFGs. + However, dataflow analysis needs to tie up these CFGs to properly emulate the + effect of closures. In the case of liveness, the parent function's live + variables must account for the variables that are live at the entry of each + subfunction. For example: + + def foo(): + # baz is live here + def bar(): + print(baz) + + This analyzer runs liveness analysis on each individual function, accounting + for the effect above. + """ + + def __init__(self, source_info, graphs): + super(WholeTreeAnalyzer, self).__init__(source_info) + self.graphs = graphs + self.current_analyzer = None + self.analyzers = {} + + def visit_FunctionDef(self, node): + parent_analyzer = self.current_analyzer + subgraph = self.graphs[node] + + # Postorder tree processing makes this a bit complicated: + # 1. construct an analyzer object and put it on stack + # 2. recursively walk the subtree; this will initialize the analyzer's + # in_ state properly (done in a block below) + # 3. run the final analysis + analyzer = Analyzer(subgraph) + self.current_analyzer = analyzer + node = self.generic_visit(node) + analyzer.visit_reverse() + + if parent_analyzer is not None: + # Wire the state between the two subgraphs' analyzers. + child_in_state = analyzer.in_[subgraph.entry] + # Exception: symbols modified in the child function are local to it + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + for qn in body_scope.modified: + # Note: a function modifying the symbol doesn't make that symbol + # live at the function's entry. In fact when that happens it is + # probably a case of undefined assignment, like this: + # + # bar = 0 + # def foo(): + # print(bar) # bar is undefined here! + # bar = 1 + # + # Hence we use discard and not remove below. + child_in_state.discard(qn) + parent_analyzer.extra_gen[node] = frozenset(child_in_state,) + + self.analyzers[node] = analyzer + self.current_analyzer = parent_analyzer + return node + + def visit_nonlocal(self, node): + raise NotImplementedError() + + def visit_global(self, node): + raise NotImplementedError() + + +class Annotator(transformer.Base): + """AST visitor that annotates each control flow block with live symbols.""" + + # Note: additional nodes may be added as needed. + + def __init__(self, source_info, cross_function_analyzer): + super(Annotator, self).__init__(source_info) + self.cross_function_analyzer = cross_function_analyzer + self.current_analyzer = None + + def visit_FunctionDef(self, node): + parent_analyzer = self.current_analyzer + self.current_analyzer = self.cross_function_analyzer.analyzers[node] + + node = self.generic_visit(node) + self.current_analyzer = parent_analyzer + return node + + def _aggregate_successors_live_in(self, node): + successors = self.current_analyzer.graph.stmt_next[node] + node_live_out = set() + for s in successors: + node_live_out.update(self.current_analyzer.in_[s]) + anno.setanno(node, anno.Static.LIVE_VARS_OUT, frozenset(node_live_out)) + node = self.generic_visit(node) + return node + + def visit_If(self, node): + return self._aggregate_successors_live_in(node) + + def visit_For(self, node): + return self._aggregate_successors_live_in(node) + + def visit_While(self, node): + return self._aggregate_successors_live_in(node) + + +def resolve(node, source_info, graphs): + """Resolves the live symbols at the exit of control flow statements. + + Args: + node: ast.AST + source_info: transformer.SourceInfo + graphs: Dict[ast.FunctionDef, cfg.Graph] + Returns: + ast.AST + """ + cross_function_analyzer = WholeTreeAnalyzer(source_info, graphs) + node = cross_function_analyzer.visit(node) + visitor = Annotator(source_info, cross_function_analyzer) + node = visitor.visit(node) + return node diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d53adb28af03f0de14f319f642ee82928a480e3a --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py @@ -0,0 +1,149 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for liveness module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import cfg +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import liveness +from tensorflow.python.platform import test + + +class LivenessTest(test.TestCase): + + def _parse_and_analyze(self, test_fn): + node, source = parser.parse_entity(test_fn) + entity_info = transformer.EntityInfo( + source_code=source, + source_file=None, + namespace={}, + arg_values=None, + arg_types=None, + owner_type=None) + node = qual_names.resolve(node) + node = activity.resolve(node, entity_info) + graphs = cfg.build(node) + liveness.resolve(node, entity_info, graphs) + return node + + def assertHasLiveOut(self, node, expected): + live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) + live_out_str = set(str(v) for v in live_out) + if not expected: + expected = () + if not isinstance(expected, tuple): + expected = (expected,) + self.assertSetEqual(live_out_str, set(expected)) + + def test_stacked_if(self): + + def test_fn(x, a): + if a > 0: + x = 0 + if a > 1: + x = 1 + return x + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveOut(fn_body[0], ('a', 'x')) + self.assertHasLiveOut(fn_body[1], 'x') + + def test_stacked_if_else(self): + + def test_fn(x, a): + if a > 0: + x = 0 + if a > 1: + x = 1 + else: + x = 2 + return x + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveOut(fn_body[0], 'a') + self.assertHasLiveOut(fn_body[1], 'x') + + def test_for_basic(self): + + def test_fn(x, a): + for i in range(a): + x += i + return x + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveOut(fn_body[0], 'x') + + def test_attributes(self): + + def test_fn(x, a): + if a > 0: + x.y = 0 + return x.y + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveOut(fn_body[0], ('x.y', 'x')) + + def test_nested_functions(self): + + def test_fn(a, b): + if b: + a = [] + + def foo(): + return a + + foo() + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveOut(fn_body[0], 'a') + + def test_nested_functions_isolation(self): + + def test_fn(b): + if b: + a = 0 # pylint:disable=unused-variable + + def child(): + max(a) # pylint:disable=used-before-assignment + a = 1 + return a + + child() + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveOut(fn_body[0], 'max') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py new file mode 100644 index 0000000000000000000000000000000000000000..7f2b379d3de236020f1ec2b8a4972cc67b10b060 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py @@ -0,0 +1,301 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reaching definition analysis. + +This analysis attaches a set of a Definition objects to each symbol, one +for each distinct definition that may reach it. The Definition objects are +mutable and may be used by subsequent analyses to further annotate data like +static type and value information. +The analysis also attaches the set of the symbols defined at the entry of +control flow statements. + +Requires activity analysis. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import cfg +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis import annos + + +class Definition(object): + """Definition objects describe a unique definition of a variable. + + Subclasses of this may be used by passing an appropriate factory function to + resolve. + + Attributes: + param_of: Optional[ast.AST] + """ + + def __init__(self): + self.param_of = None + + def __repr__(self): + return '%s[%d]' % (self.__class__.__name__, id(self)) + + +class _NodeState(object): + """Abstraction for the state of the CFG walk for reaching definition analysis. + + This is a value type. Only implements the strictly necessary operators. + + Attributes: + value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and + their possible definitions + """ + + def __init__(self, init_from=None): + if init_from: + if isinstance(init_from, _NodeState): + self.value = { + s: set(other_infos) for s, other_infos in init_from.value.items() + } + elif isinstance(init_from, dict): + self.value = {s: set((init_from[s],)) for s in init_from} + else: + assert False, init_from + else: + self.value = {} + + def __eq__(self, other): + if frozenset(self.value.keys()) != frozenset(other.value.keys()): + return False + ret = all(self.value[s] == other.value[s] for s in self.value) + return ret + + def __ne__(self, other): + return not self.__eq__(other) + + def __or__(self, other): + assert isinstance(other, _NodeState) + result = _NodeState(self) + for s, other_infos in other.value.items(): + if s in result.value: + result.value[s].update(other_infos) + else: + result.value[s] = set(other_infos) + return result + + def __sub__(self, other): + assert isinstance(other, set) + result = _NodeState(self) + for s in other: + result.value.pop(s, None) + return result + + def __repr__(self): + return 'NodeState[%s]=%s' % (id(self), repr(self.value)) + + +class Analyzer(cfg.GraphVisitor): + """CFG visitor that determines reaching definitions at statement level.""" + + def __init__(self, graph, definition_factory): + self._definition_factory = definition_factory + super(Analyzer, self).__init__(graph) + # This allows communicating that nodes have extra reaching definitions, + # e.g. those that a function closes over. + self.extra_in = {} + + self.gen_map = {} + + def init_state(self, _): + return _NodeState() + + def visit_node(self, node): + prev_defs_out = self.out[node] + + defs_in = _NodeState(self.extra_in.get(node.ast_node, None)) + for n in node.prev: + defs_in |= self.out[n] + + if anno.hasanno(node.ast_node, anno.Static.SCOPE): + node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE) + # The definition objects created by each node must be singletons because + # their ids are used in equality checks. + if node not in self.gen_map: + node_symbols = {} + for s in node_scope.modified: + def_ = self._definition_factory() + if s in node_scope.params: + def_.param_of = node_scope.params[s] + node_symbols[s] = def_ + self.gen_map[node] = _NodeState(node_symbols) + + gen = self.gen_map[node] + kill = node_scope.modified + defs_out = gen | (defs_in - kill) + + else: + # Nodes that don't have a scope annotation are assumed not to touch any + # symbols. + # This Name node below is a literal name, e.g. False + # This can also happen if activity.py forgot to annotate the node with a + # scope object. + assert isinstance( + node.ast_node, + (gast.Name, gast.Break, gast.Continue, gast.Raise)), (node.ast_node, + node) + defs_out = defs_in + + self.in_[node] = defs_in + self.out[node] = defs_out + + # TODO(mdan): Move this to the superclass? + return prev_defs_out != defs_out + + +class TreeAnnotator(transformer.Base): + """AST visitor that annotates each symbol name with its reaching definitions. + + Simultaneously, the visitor runs the dataflow analysis on each function node, + accounting for the effect of closures. For example: + + def foo(): + bar = 1 + def baz(): + # bar = 1 reaches here + """ + + def __init__(self, source_info, graphs, definition_factory): + super(TreeAnnotator, self).__init__(source_info) + self.definition_factory = definition_factory + self.graphs = graphs + self.current_analyzer = None + self.current_cfg_node = None + + def visit_FunctionDef(self, node): + parent_analyzer = self.current_analyzer + subgraph = self.graphs[node] + + # Preorder tree processing: + # 1. if this is a child function, the parent was already analyzed and it + # has the proper state value for the subgraph's entry + # 2. analyze the current function body + # 2. recursively walk the subtree; child functions will be processed + analyzer = Analyzer(subgraph, self.definition_factory) + if parent_analyzer is not None: + # Wire the state between the two subgraphs' analyzers. + parent_out_state = parent_analyzer.out[parent_analyzer.graph.index[node]] + # Exception: symbols modified in the child function are local to it + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + parent_out_state -= body_scope.modified + analyzer.extra_in[node.args] = parent_out_state + + # Complete the analysis for the local function and annotate its body. + analyzer.visit_forward() + + # Recursively process any remaining subfunctions. + self.current_analyzer = analyzer + # Note: not visiting name, decorator_list and returns because they don't + # apply to this anlysis. + # TODO(mdan): Should we still process the function name? + node.args = self.visit(node.args) + node.body = self.visit_block(node.body) + self.current_analyzer = parent_analyzer + + return node + + def visit_nonlocal(self, node): + raise NotImplementedError() + + def visit_global(self, node): + raise NotImplementedError() + + def visit_Name(self, node): + if self.current_analyzer is None: + # Names may appear outside function defs - for example in class + # definitions. + return node + + analyzer = self.current_analyzer + cfg_node = self.current_cfg_node + + assert cfg_node is not None, 'name node outside of any statement?' + + qn = anno.getanno(node, anno.Basic.QN) + if isinstance(node.ctx, gast.Load): + anno.setanno(node, anno.Static.DEFINITIONS, + tuple(analyzer.in_[cfg_node].value.get(qn, ()))) + else: + anno.setanno(node, anno.Static.DEFINITIONS, + tuple(analyzer.out[cfg_node].value.get(qn, ()))) + + return node + + def _aggregate_predecessors_defined_in(self, node): + preds = self.current_analyzer.graph.stmt_prev[node] + node_defined_in = set() + for p in preds: + node_defined_in |= set(self.current_analyzer.out[p].value.keys()) + anno.setanno(node, anno.Static.DEFINED_VARS_IN, frozenset(node_defined_in)) + + def visit_If(self, node): + self._aggregate_predecessors_defined_in(node) + return self.generic_visit(node) + + def visit_For(self, node): + self._aggregate_predecessors_defined_in(node) + + # Manually accounting for the shortcoming described in + # cfg.AstToCfg.visit_For. + parent = self.current_cfg_node + self.current_cfg_node = self.current_analyzer.graph.index[node.iter] + node.target = self.visit(node.target) + self.current_cfg_node = parent + + node.iter = self.visit(node.iter) + node.body = self.visit_block(node.body) + node.orelse = self.visit_block(node.orelse) + + return node + + def visit_While(self, node): + self._aggregate_predecessors_defined_in(node) + return self.generic_visit(node) + + def visit(self, node): + parent = self.current_cfg_node + + if (self.current_analyzer is not None and + node in self.current_analyzer.graph.index): + self.current_cfg_node = self.current_analyzer.graph.index[node] + node = super(TreeAnnotator, self).visit(node) + + self.current_cfg_node = parent + return node + + +def resolve(node, source_info, graphs, definition_factory): + """Resolves reaching definitions for each symbol. + + Args: + node: ast.AST + source_info: transformer.SourceInfo + graphs: Dict[ast.FunctionDef, cfg.Graph] + definition_factory: Callable[[], Definition] + Returns: + ast.AST + """ + visitor = TreeAnnotator(source_info, graphs, definition_factory) + node = visitor.visit(node) + return node diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py new file mode 100644 index 0000000000000000000000000000000000000000..243fe804b229686f33a4964b16c987c673a97c4b --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py @@ -0,0 +1,263 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for reaching_definitions module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import cfg +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions +from tensorflow.python.platform import test + + +class DefinitionInfoTest(test.TestCase): + + def _parse_and_analyze(self, test_fn): + node, source = parser.parse_entity(test_fn) + entity_info = transformer.EntityInfo( + source_code=source, + source_file=None, + namespace={}, + arg_values=None, + arg_types=None, + owner_type=None) + node = qual_names.resolve(node) + node = activity.resolve(node, entity_info) + graphs = cfg.build(node) + node = reaching_definitions.resolve(node, entity_info, graphs, + reaching_definitions.Definition) + return node + + def assertHasDefs(self, node, num): + defs = anno.getanno(node, anno.Static.DEFINITIONS) + self.assertEqual(len(defs), num) + for r in defs: + self.assertIsInstance(r, reaching_definitions.Definition) + + def assertHasDefinedIn(self, node, expected): + defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN) + defined_in_str = set(str(v) for v in defined_in) + if not expected: + expected = () + if not isinstance(expected, tuple): + expected = (expected,) + self.assertSetEqual(defined_in_str, set(expected)) + + def assertSameDef(self, first, second): + self.assertHasDefs(first, 1) + self.assertHasDefs(second, 1) + self.assertIs( + anno.getanno(first, anno.Static.DEFINITIONS)[0], + anno.getanno(second, anno.Static.DEFINITIONS)[0]) + + def assertNotSameDef(self, first, second): + self.assertHasDefs(first, 1) + self.assertHasDefs(second, 1) + self.assertIsNot( + anno.getanno(first, anno.Static.DEFINITIONS)[0], + anno.getanno(second, anno.Static.DEFINITIONS)[0]) + + def test_conditional(self): + + def test_fn(a, b): + a = [] + if b: + a = [] + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasDefs(fn_body[0].targets[0], 1) + self.assertHasDefs(fn_body[1].test, 1) + self.assertHasDefs(fn_body[1].body[0].targets[0], 1) + self.assertHasDefs(fn_body[2].value, 2) + + self.assertHasDefinedIn(fn_body[1], ('a', 'b')) + + def test_while(self): + + def test_fn(a): + max(a) + while True: + a = a + a = a + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasDefs(fn_body[0].value.args[0], 1) + self.assertHasDefs(fn_body[1].body[0].targets[0], 1) + self.assertHasDefs(fn_body[1].body[1].targets[0], 1) + self.assertHasDefs(fn_body[1].body[1].value, 1) + # The loop does have an invariant test, but the CFG doesn't know that. + self.assertHasDefs(fn_body[1].body[0].value, 2) + self.assertHasDefs(fn_body[2].value, 2) + + def test_while_else(self): + + def test_fn(x, i): + y = 0 + while x: + x += i + if i: + break + else: + y = 1 + return x, y + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasDefs(fn_body[0].targets[0], 1) + self.assertHasDefs(fn_body[1].test, 2) + self.assertHasDefs(fn_body[1].body[0].target, 1) + self.assertHasDefs(fn_body[1].body[1].test, 1) + self.assertHasDefs(fn_body[1].orelse[0].targets[0], 1) + self.assertHasDefs(fn_body[2].value.elts[0], 2) + self.assertHasDefs(fn_body[2].value.elts[1], 2) + + def test_for_else(self): + + def test_fn(x, i): + y = 0 + for i in x: + x += i + if i: + break + else: + continue + else: + y = 1 + return x, y + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasDefs(fn_body[0].targets[0], 1) + self.assertHasDefs(fn_body[1].target, 1) + self.assertHasDefs(fn_body[1].body[0].target, 1) + self.assertHasDefs(fn_body[1].body[1].test, 1) + self.assertHasDefs(fn_body[1].orelse[0].targets[0], 1) + self.assertHasDefs(fn_body[2].value.elts[0], 2) + self.assertHasDefs(fn_body[2].value.elts[1], 2) + + def test_nested_functions(self): + + def test_fn(a, b): + a = [] + if b: + a = [] + + def foo(): + return a + + foo() + + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + def_of_a_in_if = fn_body[1].body[0].targets[0] + + self.assertHasDefs(fn_body[0].targets[0], 1) + self.assertHasDefs(fn_body[1].test, 1) + self.assertHasDefs(def_of_a_in_if, 1) + self.assertHasDefs(fn_body[2].value, 2) + + inner_fn_body = fn_body[1].body[1].body + self.assertSameDef(inner_fn_body[0].value, def_of_a_in_if) + + def test_nested_functions_isolation(self): + + def test_fn(a): + a = 0 + + def child(): + a = 1 + return a + + child() + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + parent_return = fn_body[3] + child_return = fn_body[1].body[1] + # The assignment `a = 1` makes `a` local to `child`. + self.assertNotSameDef(parent_return.value, child_return.value) + + def test_function_call_in_with(self): + + def foo(_): + pass + + def test_fn(a): + with foo(a): + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasDefs(fn_body[0].items[0].context_expr.func, 0) + self.assertHasDefs(fn_body[0].items[0].context_expr.args[0], 1) + + def test_mutation_subscript(self): + + def test_fn(a): + l = [] + l[0] = a + return l + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + creation = fn_body[0].targets[0] + mutation = fn_body[1].targets[0].value + use = fn_body[2].value + self.assertSameDef(creation, mutation) + self.assertSameDef(creation, use) + + def test_replacement(self): + + def foo(a): + return a + + def test_fn(a): + a = foo(a) + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + param = node.body[0].args.args[0] + source = fn_body[0].value.args[0] + target = fn_body[0].targets[0] + retval = fn_body[1].value + self.assertSameDef(param, source) + self.assertNotSameDef(source, target) + self.assertSameDef(target, retval) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py index 7d1e65c958d7787ef5ed707d4822d14a83092975..835d5199fa1a5c145e29a413d4d23b4138a3c1cd 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py @@ -44,7 +44,7 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect @@ -52,6 +52,7 @@ from tensorflow.python.util import tf_inspect # TODO(mdan): Remove the duplication between this and activity.py. # In particular, the symbol definitions we track here could as well be tracked # there because they follow the same rules for visibility. +# TODO(mdan): Use a CFG based Defined analysis instead. class Scope(object): """Tracks symbol value references. @@ -135,35 +136,39 @@ class TypeInfoResolver(transformer.Base): node.orelse = self._visit_block(node.orelse) return node - def _process_function_arg(self, arg_name): - str_name = str(arg_name) - type_holder = arg_name.ast() - self.scope.setval(arg_name, type_holder) - if len(self.enclosing_entities) == 1 and str_name in self.context.arg_types: + def _process_function_arg(self, arg_node): + qn = anno.getanno(arg_node, anno.Basic.QN) + arg_name = str(qn) + self.scope.setval(qn, arg_node) + if (len(self.enclosing_entities) == 1 and + arg_name in self.entity_info.arg_types): # Forge a node to hold the type information, so that method calls on # it can resolve the type. - type_string, type_obj = self.context.arg_types[str_name] - anno.setanno(type_holder, 'type', type_obj) - anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.'))) + type_string, type_obj = self.entity_info.arg_types[arg_name] + anno.setanno(arg_node, 'type', type_obj) + anno.setanno(arg_node, 'type_fqn', tuple(type_string.split('.'))) def visit_arg(self, node): - self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN)) + self._process_function_arg(node.arg) return node def visit_Name(self, node): self.generic_visit(node) - qn = anno.getanno(node, anno.Basic.QN) if isinstance(node.ctx, gast.Param): - self._process_function_arg(qn) - elif isinstance(node.ctx, gast.Load) and self.scope.hasval(qn): - # E.g. if we had - # a = b - # then for future references to `a` we should have definition = `b` - definition = self.scope.getval(qn) - anno.copyanno(definition, node, 'type') - anno.copyanno(definition, node, 'type_fqn') - anno.copyanno(definition, node, 'element_type') - anno.copyanno(definition, node, 'element_shape') + self._process_function_arg(node) + elif isinstance(node.ctx, gast.Load): + qn = anno.getanno(node, anno.Basic.QN) + if self.scope.hasval(qn): + # E.g. if we had + # a = b + # then for future references to `a` we should have definition = `b` + definition = self.scope.getval(qn) + anno.copyanno(definition, node, 'type') + anno.copyanno(definition, node, 'type_fqn') + + # TODO(mdan): Remove this when the directives module is in. + anno.copyanno(definition, node, 'element_type') + anno.copyanno(definition, node, 'element_shape') return node def _process_variable_assignment(self, target, value): @@ -191,52 +196,18 @@ class TypeInfoResolver(transformer.Base): def visit_With(self, node): for item in node.items: if item.optional_vars is not None: - self.apply_to_single_assignments((item.optional_vars,), - item.context_expr, - self._process_variable_assignment) + ast_util.apply_to_single_assignments((item.optional_vars,), + item.context_expr, + self._process_variable_assignment) self.generic_visit(node) return node def visit_Assign(self, node): self.generic_visit(node) - self.apply_to_single_assignments( - node.targets, node.value, self._process_variable_assignment) + ast_util.apply_to_single_assignments(node.targets, node.value, + self._process_variable_assignment) return node - def visit_Call(self, node): - if anno.hasanno(node.func, 'live_val'): - # Symbols targeted by the "set_type" marker function are assigned the data - # type that it specified. - if (anno.getanno(node.func, 'live_val') is - self.context.type_annotation_func): - - if len(node.args) < 2 or len(node.args) > 3: - raise ValueError('"%s" must have either two or three parameters' - % self.context.type_annotation_func) - if len(node.args) == 2: - target_arg, type_arg = node.args - shape_arg = parser.parse_expression('None') - else: - target_arg, type_arg, shape_arg = node.args - if not anno.hasanno(target_arg, anno.Basic.QN): - raise ValueError('the first argument of "%s" must by a symbol' - % self.context.type_annotation_func) - # TODO(mdan): This is vulnerable to symbol renaming. - element_type = type_arg - element_shape = shape_arg - - target_symbol = anno.getanno(target_arg, anno.Basic.QN) - # Find the definition of this symbol and annotate it with the given - # data type. That in turn will cause future uses of the symbol - # to receive the same type annotation. - definition = self.scope.getval(target_symbol) - anno.setanno(node, 'element_type', element_type) - anno.setanno(node, 'element_shape', element_shape) - anno.setanno(definition, 'element_type', element_type) - anno.setanno(definition, 'element_shape', element_shape) - # TODO(mdan): Should we update references between definition and here? - return self.generic_visit(node) - def resolve(node, context): return TypeInfoResolver(context).visit(node) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py index 484562f294bb53a63feeca965b8f94c58aa2a685..404311ba242cf0359cf5695dfe3eeaf9cb858eb8 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py @@ -18,13 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import cfg from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.client import session from tensorflow.python.platform import test @@ -62,21 +63,21 @@ class TypeInfoResolverTest(test.TestCase): namespace, arg_types=None): node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=None, + entity_info = transformer.EntityInfo( source_code=source, source_file=None, namespace=namespace, arg_values=None, arg_types=arg_types, - owner_type=None, - recursive=True, - type_annotation_func=utils.set_element_type) + owner_type=None) node = qual_names.resolve(node) - node = activity.resolve(node, ctx) - node = live_values.resolve(node, ctx, {}) - node = type_info.resolve(node, ctx) - node = live_values.resolve(node, ctx, {}) + graphs = cfg.build(node) + node = activity.resolve(node, entity_info) + node = reaching_definitions.resolve(node, entity_info, graphs, + reaching_definitions.Definition) + node = live_values.resolve(node, entity_info, {}) + node = type_info.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, {}) return node def test_constructor_detection(self): @@ -147,7 +148,7 @@ class TypeInfoResolverTest(test.TestCase): opt.minimize(0) node = self._parse_and_analyze( - test_fn, {'training': training}, + test_fn, {}, arg_types={ 'opt': (training.GradientDescentOptimizer.__name__, training.GradientDescentOptimizer) @@ -180,35 +181,6 @@ class TypeInfoResolverTest(test.TestCase): method_call = node.body[0].body[1].value.func self.assertFalse(anno.hasanno(method_call, 'live_val')) - def test_type_annotation(self): - - class Foo(object): - pass - - def test_fn(): - f = [] - f = utils.set_element_type(f, Foo, (1, 2, 3)) - return f - - node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) - f_def = node.body[0].body[0].value - self.assertEqual(anno.getanno(f_def, 'element_type').id, 'Foo') - f_ref = node.body[0].body[1].value - self.assertEqual(anno.getanno(f_ref, 'element_type').id, 'Foo') - - def test_type_annotation_args(self): - - class Foo(object): - pass - - def test_fn(f): - utils.set_element_type(f, Foo) - return f - - node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) - f_ref = node.body[0].body[1].value - self.assertEqual(anno.getanno(f_ref, 'element_type').id, 'Foo') - def test_nested_unpacking(self): class Foo(object): @@ -230,25 +202,6 @@ class TypeInfoResolverTest(test.TestCase): self.assertFalse(anno.hasanno(b, 'live_val')) self.assertFalse(anno.hasanno(c, 'live_val')) - def test_inner_scope(self): - - def test_fn(): - a = [] - utils.set_element_type(a, 1) - for _ in a: - b = [] - utils.set_element_type(b, 2) - return a, b - - node = self._parse_and_analyze(test_fn, {'utils': utils}) - a, b = node.body[0].body[2].body[2].value.elts - self.assertEquals(anno.getanno(a, 'element_type').n, 1) - self.assertEquals(anno.getanno(b, 'element_type').n, 2) - self.assertFalse(anno.hasanno(a, 'type')) - self.assertFalse(anno.hasanno(b, 'type')) - self.assertFalse(anno.hasanno(a, 'live_val')) - self.assertFalse(anno.hasanno(b, 'live_val')) - if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py index 9c479ebc2fa83d27dc363ae306daedb556734a1f..5831d57ceb58d4b291a4f52bbf4282e107104219 100644 --- a/tensorflow/contrib/autograph/pyct/templates.py +++ b/tensorflow/contrib/autograph/pyct/templates.py @@ -26,6 +26,7 @@ import textwrap import gast +from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names @@ -43,39 +44,65 @@ class ReplaceTransformer(gast.NodeTransformer): """ self.replacements = replacements self.in_replacements = False + self.preserved_annos = { + anno.Basic.ORIGIN, + anno.Basic.SKIP_PROCESSING, + anno.Static.ORIG_DEFINITIONS, + } + + def _prepare_replacement(self, replaced, key): + """Prepares a replacement AST that's safe to swap in for a node. + + Args: + replaced: ast.AST, the node being replaced + key: Hashable, the key of the replacement AST + Returns: + ast.AST, the replacement AST + """ + repl = self.replacements[key] + + new_nodes = ast_util.copy_clean(repl, preserve_annos=self.preserved_annos) + if isinstance(new_nodes, gast.AST): + new_nodes = [new_nodes] + + return new_nodes def visit_Expr(self, node): - if (isinstance(node.value, gast.Name) and - node.value.id in self.replacements): - return self.visit(node.value) - self.generic_visit(node) - return node + # When replacing a placeholder with an entire statement, the replacement + # must stand on its own and not be wrapped in an Expr. + new_value = self.visit(node.value) + if new_value is node.value: + return node + return new_value def visit_keyword(self, node): - if node.arg in self.replacements: - repl = self.replacements[node.arg] - if isinstance(repl, gast.keyword): - return repl - elif (isinstance(repl, (list, tuple)) and repl and - all(isinstance(r, gast.keyword) for r in repl)): - return repl - # TODO(mdan): We may allow replacing with a string as well. - # For example, if one wanted to replace foo with bar in foo=baz, then - # we could allow changing just node arg, so that we end up with bar=baz. - raise ValueError( - 'a keyword argument may only be replaced by another keyword or a ' - 'non-empty list of keywords. Found: %s' % repl) - return self.generic_visit(node) + if node.arg not in self.replacements: + return self.generic_visit(node) + + repl = self._prepare_replacement(node, node.arg) + if isinstance(repl, gast.keyword): + return repl + elif (repl and isinstance(repl, (list, tuple)) and + all(isinstance(r, gast.keyword) for r in repl)): + return repl + # TODO(mdan): We may allow replacing with a string as well. + # For example, if one wanted to replace foo with bar in foo=baz, then + # we could allow changing just node arg, so that we end up with bar=baz. + raise ValueError( + 'a keyword argument may only be replaced by another keyword or a ' + 'non-empty list of keywords. Found: %s' % repl) def visit_FunctionDef(self, node): node = self.generic_visit(node) - if node.name in self.replacements: - repl = self.replacements[node.name] - if not isinstance(repl, (gast.Name, ast.Name)): - raise ValueError( - 'a function name can only be replaced by a Name node. Found: %s' % - repl) - node.name = repl.id + if node.name not in self.replacements: + return node + + repl = self.replacements[node.name] + if not isinstance(repl, (gast.Name, ast.Name)): + raise ValueError( + 'a function name can only be replaced by a Name node. Found: %s' % + repl) + node.name = repl.id return node def _check_has_context(self, node): @@ -113,8 +140,8 @@ class ReplaceTransformer(gast.NodeTransformer): def _set_inner_child_context(self, node, ctx): if isinstance(node, gast.Attribute): - self._set_inner_child_context(node.value, ctx) - node.ctx = gast.Load() + self._set_inner_child_context(node.value, gast.Load()) + node.ctx = ctx elif isinstance(node, gast.Tuple): for e in node.elts: self._set_inner_child_context(e, ctx) @@ -148,6 +175,7 @@ class ReplaceTransformer(gast.NodeTransformer): node = self.generic_visit(node) if node.attr not in self.replacements: return node + repl = self.replacements[node.attr] if not isinstance(repl, gast.Name): raise ValueError( @@ -159,9 +187,7 @@ class ReplaceTransformer(gast.NodeTransformer): if node.id not in self.replacements: return node - new_nodes = ast_util.copy_clean(self.replacements[node.id]) - if isinstance(new_nodes, gast.AST): - new_nodes = [new_nodes] + new_nodes = self._prepare_replacement(node, node.id) # Preserve the target context. for n in new_nodes: @@ -182,7 +208,7 @@ class ReplaceTransformer(gast.NodeTransformer): def _convert_to_ast(n): - """Convert from a known data type to AST.""" + """Converts from a known data type to AST.""" if isinstance(n, str): # Note: the node will receive the ctx value from the template, see # ReplaceTransformer.visit_Name. @@ -197,7 +223,7 @@ def _convert_to_ast(n): def replace(template, **replacements): - """Replace placeholders in a Python template. + """Replaces placeholders in a Python template. AST Name and Tuple nodes always receive the context that inferred from the template. However, when replacing more complex nodes (that can potentially diff --git a/tensorflow/contrib/autograph/pyct/templates_test.py b/tensorflow/contrib/autograph/pyct/templates_test.py index a01f8bf04c4faa6ec1779e0fb306155d99f5bd09..77e8ff62fd8665e095cfb410a2aa418e9f9bd52b 100644 --- a/tensorflow/contrib/autograph/pyct/templates_test.py +++ b/tensorflow/contrib/autograph/pyct/templates_test.py @@ -97,6 +97,19 @@ class TemplatesTest(test.TestCase): with self.assertRaises(ValueError): templates.replace(template, foo=1) + def test_replace_attribute_context(self): + template = """ + def test_fn(foo): + foo = 0 + """ + + node = templates.replace( + template, + foo=parser.parse_expression('a.b.c'))[0] + self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store) + self.assertIsInstance(node.body[0].targets[0].value.ctx, gast.Load) + self.assertIsInstance(node.body[0].targets[0].value.value.ctx, gast.Load) + def test_replace_call_keyword(self): template = """ def test_fn(): @@ -151,17 +164,13 @@ class TemplatesTest(test.TestCase): self.assertEqual(node.func.id, 'bar') self.assertEqual(node.func.args[0].id, 'baz') - def replace_as_expression_restrictions(self): + def test_replace_as_expression_restrictions(self): template = """ foo(a) bar(b) """ with self.assertRaises(ValueError): templates.replace_as_expression(template) - with self.assertRaises(ValueError): - templates.replace('') - with self.assertRaises(ValueError): - templates.replace('a = b') if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/pyct/testing/BUILD b/tensorflow/contrib/autograph/pyct/testing/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..9ef1ac9663eac8febffd697d7164425716b65d9d --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/testing/BUILD @@ -0,0 +1,46 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "testing", + srcs = [ + "codegen.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/utils", + "@gast_archive//:gast", + ], +) + +py_test( + name = "codegen_test", + size = "large", + srcs = ["codegen_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_windows", + "nomsan", + ], + deps = [ + ":testing", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) diff --git a/tensorflow/contrib/autograph/pyct/testing/codegen.py b/tensorflow/contrib/autograph/pyct/testing/codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..279e7c09dc6449184e2029ad65fc3f71d94db8b4 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/testing/codegen.py @@ -0,0 +1,234 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Random code generation for testing/fuzzing.""" +# pylint: disable=invalid-name +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import string + +import gast +import numpy as np + +from tensorflow.contrib.autograph.pyct import templates + + +class NodeSampler(object): + sample_map = None + + def sample(self): + nodes, magnitudes = zip(*self.sample_map.items()) + return np.random.choice( + nodes, p=np.array(magnitudes, dtype='float32') / np.sum(magnitudes)) + + +class StatementSampler(NodeSampler): + sample_map = dict(( + (gast.Assign, 10), + (gast.Print, 1), + (gast.If, 2), + (gast.While, 2), + (gast.For, 0), + )) + + +class ExpressionSampler(NodeSampler): + sample_map = dict(( + (gast.UnaryOp, 1), + (gast.BinOp, 8), + (gast.Name, 1), + (gast.Call, 0), + )) + + +class CompareSampler(NodeSampler): + sample_map = dict(( + (gast.Eq, 1), + (gast.NotEq, 1), + (gast.Lt, 1), + (gast.LtE, 1), + (gast.Gt, 1), + (gast.GtE, 1), + (gast.Is, 1), + (gast.IsNot, 1), + )) + + +class BinaryOpSampler(NodeSampler): + sample_map = dict(( + (gast.Add, 1), + (gast.Sub, 1), + (gast.Mult, 1), + (gast.Div, 1), + (gast.FloorDiv, 1), + (gast.Mod, 1), + (gast.Pow, 1), + )) + + +class UnaryOpSampler(NodeSampler): + sample_map = dict(((gast.USub, 1), (gast.UAdd, 0))) + + +class NameSampler(NodeSampler): + sample_map = dict(( + ('new', 1), + ('existing', 1), + )) + + +N_CONTROLFLOW_STATEMENTS = 10 +N_FUNCTIONDEF_STATEMENTS = 10 + + +class CodeGenerator(object): + """Generate random syntactically-valid Python ASTs.""" + + def __init__(self, max_depth=3, depth=0): + self.max_depth = max_depth + self.depth = depth + + def generate_statement(self): + """Generate a statement node, dispatching to the correct class method.""" + desired_node = StatementSampler().sample() + self.depth += 1 + + # Enforce some constraints on generating statements. + # E.g., if statements need at least 3 readable variables. + # If we fail to satisfy our constraints, draw another sample. + if desired_node in (gast.While, gast.For, gast.If): + if self.depth > self.max_depth: + return self.generate_statement() + + # Go get the generator method and run it + method = 'generate_' + desired_node.__name__ + visitor = getattr(self, method) + node = visitor() + self.depth -= 1 + return node + + def sample_node_list(self, low, high, generator): + """Generate a list of statements of random length. + + Args: + low: Fewest number of statements to generate. + high: Highest number of statements to generate. + generator: Function to call to generate nodes. + + Returns: + A list of statements. + """ + statements = [] + for _ in range(np.random.randint(low, high)): + statements.append(generator()) + return statements + + def generate_Name(self, ctx=gast.Load()): + variable_name = '_' + ''.join( + random.choice(string.ascii_lowercase) for _ in range(4)) + return gast.Name(variable_name, ctx=ctx, annotation=None) + + def generate_BinOp(self): + # TODO(alexbw): convert to generate_expression when we get to limit + # expression depth. + op = BinaryOpSampler().sample()() + return gast.BinOp(self.generate_Name(), op, self.generate_Name()) + + def generate_Compare(self): + op = CompareSampler().sample()() + return gast.Compare(self.generate_Name(), [op], [self.generate_Name()]) + + def generate_UnaryOp(self): + operand = self.generate_Name() + op = UnaryOpSampler().sample()() + return gast.UnaryOp(op, operand) + + def generate_expression(self): + desired_node = ExpressionSampler().sample() + # Go get the generator method and run it + method = 'generate_' + desired_node.__name__ + generator = getattr(self, method) + return generator() + + def generate_Assign(self): + """Generate an Assign node.""" + # Generate left-hand side + target_node = self.generate_Name(gast.Store()) + # Generate right-hand side + value_node = self.generate_expression() + # Put it all together + node = gast.Assign(targets=[target_node], value=value_node) + return node + + def generate_If(self): + """Generate an If node.""" + test = self.generate_Compare() + + # Generate true branch statements + body = self.sample_node_list( + low=1, + high=N_CONTROLFLOW_STATEMENTS // 2, + generator=self.generate_statement) + + # Generate false branch statements + orelse = self.sample_node_list( + low=1, + high=N_CONTROLFLOW_STATEMENTS // 2, + generator=self.generate_statement) + + node = gast.If(test, body, orelse) + return node + + def generate_While(self): + """Generate a While node.""" + + test = self.generate_Compare() + body = self.sample_node_list( + low=1, high=N_CONTROLFLOW_STATEMENTS, generator=self.generate_statement) + orelse = [] # not generating else statements + + node = gast.While(test, body, orelse) + return node + + def generate_Call(self): + raise NotImplementedError + + def generate_Return(self): + return gast.Return(self.generate_expression()) + + def generate_Print(self): + return templates.replace('print(x)', x=self.generate_expression())[0] + + def generate_FunctionDef(self): + """Generate a FunctionDef node.""" + + # Generate the arguments, register them as available + arg_vars = self.sample_node_list( + low=2, high=10, generator=lambda: self.generate_Name(gast.Param())) + args = gast.arguments(arg_vars, None, [], [], None, []) + + # Generate the function body + body = self.sample_node_list( + low=1, high=N_FUNCTIONDEF_STATEMENTS, generator=self.generate_statement) + body.append(self.generate_Return()) + fn_name = self.generate_Name().id + node = gast.FunctionDef(fn_name, args, body, (), None) + return node + + +def generate_random_functiondef(): + return CodeGenerator().generate_FunctionDef() diff --git a/tensorflow/contrib/autograph/pyct/testing/codegen_test.py b/tensorflow/contrib/autograph/pyct/testing/codegen_test.py new file mode 100644 index 0000000000000000000000000000000000000000..255c3b2a2edc65ab978d8c32682fafd8ce00f5ac --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/testing/codegen_test.py @@ -0,0 +1,40 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for type_info module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct.testing import codegen +from tensorflow.python.platform import test + + +class CodeGenTest(test.TestCase): + + def test_codegen_gens(self): + np.random.seed(0) + for _ in range(1000): + node = codegen.generate_random_functiondef() + fn = compiler.ast_to_object(node) + self.assertIsNotNone( + fn, 'Generated invalid AST that could not convert to source.') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py index 60bca8b38dcf62b4e997379d075cfc45511a894f..969ca12244148b346ba3160fba124384a9641a05 100644 --- a/tensorflow/contrib/autograph/pyct/transformer.py +++ b/tensorflow/contrib/autograph/pyct/transformer.py @@ -32,31 +32,176 @@ class AutographParseError(SyntaxError): pass -def try_ast_to_source(node): - try: - return compiler.ast_to_source(node) - except AssertionError: - return '' +# TODO(mdan): Use namedtuple. +class EntityInfo(object): + """Contains information about a Python entity. Immutable. + + Examples of entities include functions and classes. + + Attributes: + source_code: The entity's source code. + source_file: The entity's source file. + namespace: Dict[str, ], containing symbols visible to the entity + (excluding parameters). + arg_values: dict[str->*], containing parameter values, if known. + arg_types: dict[str->*], containing parameter types, if known. + owner_type: The surrounding class type of the function, if present. + """ + + # TODO(mdan): Remove the default and update tests. + def __init__(self, source_code, source_file, namespace, arg_values, arg_types, + owner_type): + self.source_code = source_code + self.source_file = source_file + self.namespace = namespace + self.arg_values = {} if arg_values is None else arg_values + self.arg_types = {} if arg_types is None else arg_types + self.owner_type = owner_type + + +class _StateStack(object): + """Typed stack abstraction. + + This class provides syntactic sugar for a stack of objects of known + type. It allows accessing attributes of the object at the top of the stack + directly against this object, which allows for very terse syntax. + + For example, this code: + + stack = _StateStack(Foo) + stack.enter() + stack.bar + + Is equivalent to: + + stack = [] + stack.append(Foo()) + foo = stack[-1] + foo.bar + + See _State for more on how this is used. + + Attributes: + type: Any, the type of objects that this stack holds + level: int, the current stack depth + value: Any, the instance of the object at the top of the stack + """ + + def __init__(self, type_): + # Because we override __setattr__, we need to attach these attributes using + # the superclass' setattr. + object.__setattr__(self, 'type', type_) + object.__setattr__(self, '_stack', []) + self.enter() + + def enter(self): + self._stack.append(self.type()) + + def exit(self): + return self._stack.pop() + + @property + def level(self): + return len(self._stack) + + @property + def value(self): + return self._stack[-1] + + def __getattr__(self, key): + return getattr(self._stack[-1], key) + + def __setattr__(self, key, value): + setattr(self._stack[-1], key, value) + + +class _State(object): + """Supporting class for nested scope variable space for converter.Base. + + This structure offers syntactic sugar over a dict of stacks of objects + of known type. These structures are useful to keep state during AST walks. + Multiple different scopes can be tracked in parallel. For example: + + s = _State() + + s[foo].enter() + s[bar].enter() # this will not affect s[foo] + + Element access has special semantics: + * keys are a data type + * element values are _StateStack(type=key) objects + * missing elements are automatically added, similarly to defaultdict + + For example, the following block : + + _State s + s[Foo] + + Is equivalent to: + + s = {} + if Foo not in s: + s[Foo] = Foo() + s[Foo] + + See Base for how it's used. + """ + + def __init__(self): + self._value = {} + + def __getitem__(self, key): + if key not in self._value: + self._value[key] = _StateStack(key) + return self._value[key] class Base(gast.NodeTransformer): - """Base class for specialized transformers. + """Base class for general-purpose code transformers transformers. + + This is an extension of ast.NodeTransformer that provides a few additional + functions, like state tracking within the scope of arbitrary node, helpers + for processing code blocks, debugging, mapping of transformed code to + original code, and others. Scope-local state tracking: to keep state across nodes, at the level of (possibly nested) scopes, use enter/exit_local_scope and set/get_local. You must call enter/exit_local_scope manually, but the transformer detects when they are not properly paired. + + The transformer allows keeping state across calls to visit_* that is local to + arbitrary nodes and their descendants, using the self.state attribute. + Multiple independent scopes are allowed and automatically constructed. + + For example, to keep track of the If node that encloses any Name node, one can + write: + + class FooType(object): + + def __init__(self): + self.foo_property = None + + class DummyTransformer(Base): + + def visit_If(self, node): + self.state[FooType].enter() + self.state[FooType].foo_property = node + + def visit_Name(self, node): + self.state[FooType].foo_property # will hold the innermost enclosing if """ - def __init__(self, context): + # TODO(mdan): Document all extra features. + + def __init__(self, entity_info): """Initialize the transformer. Subclasses should call this. Args: - context: An EntityContext. + entity_info: An EntityInfo object. """ self._lineno = 0 self._col_offset = 0 - self.context = context + self.entity_info = entity_info self._enclosing_entities = [] # A stack that allows keeping mutable, scope-local state where scopes may be @@ -65,6 +210,12 @@ class Base(gast.NodeTransformer): self._local_scope_state = [] self.enter_local_scope() + # Allows scoping of local variables to keep state across calls to visit_* + # methods. Multiple scope hierchies may exist and are keyed by tag. A scope + # is valid at one or more nodes and all its children. Scopes created in + # child nodes supersede their parent. Scopes are isolated from one another. + self.state = _State() + @property def enclosing_entities(self): return tuple(self._enclosing_entities) @@ -74,7 +225,9 @@ class Base(gast.NodeTransformer): return len(self._local_scope_state) def enter_local_scope(self, inherit=None): - """Marks entry into a new local scope. + """Deprecated. Use self.state instead. + + Marks entry into a new local scope. Args: inherit: Optional enumerable of variable names to copy from the @@ -89,7 +242,9 @@ class Base(gast.NodeTransformer): self._local_scope_state.append(scope_entered) def exit_local_scope(self, keep=None): - """Marks exit from the current local scope. + """Deprecated. Use self.state instead. + + Marks exit from the current local scope. Args: keep: Optional enumerable of variable names to copy into the @@ -106,9 +261,11 @@ class Base(gast.NodeTransformer): return scope_left def set_local(self, name, value): + """Deprecated. Use self.state instead.""" self._local_scope_state[-1][name] = value def get_local(self, name, default=None): + """Deprecated. Use self.state instead.""" return self._local_scope_state[-1].get(name, default) def debug_print(self, node): @@ -189,9 +346,9 @@ class Base(gast.NodeTransformer): node_destination = new_destination return results - # TODO(mdan): Once we have error tracing, we may be able to just go to SSA. + # TODO(mdan): Remove. def apply_to_single_assignments(self, targets, values, apply_fn): - """Applies a fuction to each individual assignment. + """Applies a function to each individual assignment. This function can process a possibly-unpacked (e.g. a, b = c, d) assignment. It tries to break down the unpacking if possible. In effect, it has the same @@ -219,7 +376,7 @@ class Base(gast.NodeTransformer): targets field of an ast.Assign node. values: an AST node. apply_fn: a function of a single argument, which will be called with the - respective nodes of each single assignment. The signaure is + respective nodes of each single assignment. The signature is apply_fn(target, value), no return value. """ if not isinstance(targets, (list, tuple)): @@ -237,15 +394,40 @@ class Base(gast.NodeTransformer): # TODO(mdan): Look into allowing to rewrite the AST here. apply_fn(target, values) + def _get_source(self, node): + try: + source, _ = compiler.ast_to_source(node) + return source + # pylint: disable=broad-except + # This function is used for error reporting. If an exception occurs here, + # it should be suppressed, in favor of emitting as informative a message + # about the original error as possible. + except Exception: + return '' + def visit(self, node): - source_code = self.context.source_code - source_file = self.context.source_file + if not isinstance(node, gast.AST): + # This is not that uncommon a mistake: various node bodies are lists, for + # example, posing a land mine for transformers that need to recursively + # call `visit`. The error needs to be raised before the exception handler + # below is installed, because said handler will mess up if `node` is not, + # in fact, a node. + msg = ( + 'invalid value for "node": expected "ast.AST", got "{}"; to' + ' visit lists of nodes, use "visit_block" instead').format(type(node)) + raise ValueError(msg) + + source_code = self.entity_info.source_code + source_file = self.entity_info.source_file did_enter_function = False local_scope_size_at_entry = len(self._local_scope_state) + processing_expr_node = False try: if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)): did_enter_function = True + elif isinstance(node, gast.Expr): + processing_expr_node = True if did_enter_function: self._enclosing_entities.append(node) @@ -254,9 +436,23 @@ class Base(gast.NodeTransformer): self._lineno = node.lineno self._col_offset = node.col_offset + if processing_expr_node: + entry_expr_value = node.value + if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING): result = super(Base, self).visit(node) + # Adjust for consistency: replacing the value of an Expr with + # an Assign node removes the need for the Expr node. + if processing_expr_node: + if isinstance(result, gast.Expr) and result.value != entry_expr_value: + # When the replacement is a list, it is assumed that the list came + # from a template that contained a number of statements, which + # themselves are standalone and don't require an enclosing Expr. + if isinstance(result.value, + (list, tuple, gast.Assign, gast.AugAssign)): + result = result.value + # On exception, the local scope integrity is not guaranteed. if did_enter_function: self._enclosing_entities.pop() @@ -275,7 +471,7 @@ class Base(gast.NodeTransformer): except (ValueError, AttributeError, KeyError, NotImplementedError) as e: msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % ( - e.__class__.__name__, str(e), try_ast_to_source(node), + e.__class__.__name__, str(e), self._get_source(node), pretty_printer.fmt(node, color=False)) if source_code: line = source_code.splitlines()[self._lineno - 1] diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py index f110e79605945e908e8a49112cf758ec29fa1b11..a37e922a1de902106dd3a11f20a14ddde8f6675e 100644 --- a/tensorflow/contrib/autograph/pyct/transformer_test.py +++ b/tensorflow/contrib/autograph/pyct/transformer_test.py @@ -21,7 +21,6 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.platform import test @@ -29,16 +28,14 @@ from tensorflow.python.platform import test class TransformerTest(test.TestCase): - def _context_for_testing(self): - return context.EntityContext( - namer=None, + def _simple_source_info(self): + return transformer.EntityInfo( source_code=None, source_file=None, namespace=None, arg_values=None, arg_types=None, - owner_type=None, - recursive=False) + owner_type=None) def test_entity_scope_tracking(self): @@ -55,7 +52,7 @@ class TransformerTest(test.TestCase): anno.setanno(node, 'enclosing_entities', self.enclosing_entities) return self.generic_visit(node) - tr = TestTransformer(self._context_for_testing()) + tr = TestTransformer(self._simple_source_info()) def test_function(): a = 0 @@ -96,6 +93,83 @@ class TransformerTest(test.TestCase): inner_function, lambda_node), anno.getanno(lambda_expr, 'enclosing_entities')) + def assertSameAnno(self, first, second, key): + self.assertIs(anno.getanno(first, key), anno.getanno(second, key)) + + def assertDifferentAnno(self, first, second, key): + self.assertIsNot(anno.getanno(first, key), anno.getanno(second, key)) + + def test_state_tracking(self): + + class LoopState(object): + pass + + class CondState(object): + pass + + class TestTransformer(transformer.Base): + + def visit(self, node): + anno.setanno(node, 'loop_state', self.state[LoopState].value) + anno.setanno(node, 'cond_state', self.state[CondState].value) + return super(TestTransformer, self).visit(node) + + def visit_While(self, node): + self.state[LoopState].enter() + node = self.generic_visit(node) + self.state[LoopState].exit() + return node + + def visit_If(self, node): + self.state[CondState].enter() + node = self.generic_visit(node) + self.state[CondState].exit() + return node + + tr = TestTransformer(self._simple_source_info()) + + def test_function(a): + a = 1 + while a: + _ = 'a' + if a > 2: + _ = 'b' + while True: + raise '1' + if a > 3: + _ = 'c' + while True: + raise '1' + + node, _ = parser.parse_entity(test_function) + node = tr.visit(node) + + fn_body = node.body[0].body + outer_while_body = fn_body[1].body + self.assertSameAnno(fn_body[0], outer_while_body[0], 'cond_state') + self.assertDifferentAnno(fn_body[0], outer_while_body[0], 'loop_state') + + first_if_body = outer_while_body[1].body + self.assertDifferentAnno(outer_while_body[0], first_if_body[0], + 'cond_state') + self.assertSameAnno(outer_while_body[0], first_if_body[0], 'loop_state') + + first_inner_while_body = first_if_body[1].body + self.assertSameAnno(first_if_body[0], first_inner_while_body[0], + 'cond_state') + self.assertDifferentAnno(first_if_body[0], first_inner_while_body[0], + 'loop_state') + + second_if_body = outer_while_body[2].body + self.assertDifferentAnno(first_if_body[0], second_if_body[0], 'cond_state') + self.assertSameAnno(first_if_body[0], second_if_body[0], 'loop_state') + + second_inner_while_body = second_if_body[1].body + self.assertDifferentAnno(first_inner_while_body[0], + second_inner_while_body[0], 'cond_state') + self.assertDifferentAnno(first_inner_while_body[0], + second_inner_while_body[0], 'loop_state') + def test_local_scope_info_stack(self): class TestTransformer(transformer.Base): @@ -118,7 +192,7 @@ class TransformerTest(test.TestCase): def visit_For(self, node): return self._annotate_result(node) - tr = TestTransformer(self._context_for_testing()) + tr = TestTransformer(self._simple_source_info()) def test_function(a): """Docstring.""" @@ -157,7 +231,7 @@ class TransformerTest(test.TestCase): self.exit_local_scope() return node - tr = TestTransformer(self._context_for_testing()) + tr = TestTransformer(self._simple_source_info()) def no_exit(a): if a > 0: @@ -196,7 +270,7 @@ class TransformerTest(test.TestCase): z = y return z - tr = TestTransformer(self._context_for_testing()) + tr = TestTransformer(self._simple_source_info()) node, _ = parser.parse_entity(test_function) node = tr.visit(node) @@ -208,6 +282,88 @@ class TransformerTest(test.TestCase): self.assertTrue(isinstance(node.body[1].body[0], gast.Assign)) self.assertTrue(isinstance(node.body[1].body[1], gast.Return)) + def test_robust_error_on_list_visit(self): + + class BrokenTransformer(transformer.Base): + + def visit_If(self, node): + # This is broken because visit expects a single node, not a list, and + # the body of an if is a list. + # Importantly, the default error handling in visit also expects a single + # node. Therefore, mistakes like this need to trigger a type error + # before the visit called here installs its error handler. + # That type error can then be caught by the enclosing call to visit, + # and correctly blame the If node. + self.visit(node.body) + return node + + def test_function(x): + if x > 0: + return x + + tr = BrokenTransformer(self._simple_source_info()) + + node, _ = parser.parse_entity(test_function) + with self.assertRaises(transformer.AutographParseError) as cm: + node = tr.visit(node) + obtained_message = str(cm.exception) + expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"' + self.assertRegexpMatches(obtained_message, expected_message) + # The exception should point at the if statement, not any place else. Could + # also check the stack trace. + self.assertTrue( + 'Occurred at node:\nIf' in obtained_message, obtained_message) + self.assertTrue( + 'Occurred at node:\nFunctionDef' not in obtained_message, + obtained_message) + self.assertTrue( + 'Occurred at node:\nReturn' not in obtained_message, obtained_message) + + def test_robust_error_on_ast_corruption(self): + # A child class should not be able to be so broken that it causes the error + # handling in `transformer.Base` to raise an exception. Why not? Because + # then the original error location is dropped, and an error handler higher + # up in the call stack gives misleading information. + + # Here we test that the error handling in `visit` completes, and blames the + # correct original exception, even if the AST gets corrupted. + + class NotANode(object): + pass + + class BrokenTransformer(transformer.Base): + + def visit_If(self, node): + node.body = NotANode() + raise ValueError('I blew up') + + def test_function(x): + if x > 0: + return x + + tr = BrokenTransformer(self._simple_source_info()) + + node, _ = parser.parse_entity(test_function) + with self.assertRaises(transformer.AutographParseError) as cm: + node = tr.visit(node) + obtained_message = str(cm.exception) + # The message should reference the exception actually raised, not anything + # from the exception handler. + expected_substring = 'I blew up' + self.assertTrue(expected_substring in obtained_message, obtained_message) + # Expect the exception to have failed to parse the corrupted AST + self.assertTrue( + '' in obtained_message, + obtained_message) + # The exception should point at the if statement, not any place else. Could + # also check the stack trace. + self.assertTrue( + 'Occurred at node:\nIf' in obtained_message, obtained_message) + self.assertTrue( + 'Occurred at node:\nFunctionDef' not in obtained_message, + obtained_message) + self.assertTrue( + 'Occurred at node:\nReturn' not in obtained_message, obtained_message) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD index d82c17bf2afd01aedf4344f983b02c09abcb9bad..d2b399f19b63bfaa20d334df78ae60d50f6ca6e7 100644 --- a/tensorflow/contrib/autograph/utils/BUILD +++ b/tensorflow/contrib/autograph/utils/BUILD @@ -28,7 +28,6 @@ py_library( "tensor_list.py", "testing.py", "type_check.py", - "type_hints.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], diff --git a/tensorflow/contrib/autograph/utils/__init__.py b/tensorflow/contrib/autograph/utils/__init__.py index 817d4126d106487e1fea3e442712a69bbfccd7f3..57b5f747417613a5dd5bce08e4a9e9ef98442cf6 100644 --- a/tensorflow/contrib/autograph/utils/__init__.py +++ b/tensorflow/contrib/autograph/utils/__init__.py @@ -30,4 +30,3 @@ from tensorflow.contrib.autograph.utils.py_func import wrap_py_func from tensorflow.contrib.autograph.utils.tensor_list import dynamic_list_append from tensorflow.contrib.autograph.utils.testing import fake_tf from tensorflow.contrib.autograph.utils.type_check import is_tensor -from tensorflow.contrib.autograph.utils.type_hints import set_element_type diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py index 998087e056c2cd264399982220d6e0528aab9edb..4dd440ef197b7e24b901bc9e30794b0182378a32 100644 --- a/tensorflow/contrib/autograph/utils/builtins.py +++ b/tensorflow/contrib/autograph/utils/builtins.py @@ -27,6 +27,7 @@ from tensorflow.contrib.autograph.utils import type_check from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops @@ -43,6 +44,8 @@ def dynamic_builtin(f, *args, **kwargs): return dynamic_int(*args, **kwargs) if f is float: return dynamic_float(*args, **kwargs) + if f is abs: + return dynamic_abs(*args, **kwargs) raise NotImplementedError( 'The "%s" builtin is not yet supported.' % f.__name__) @@ -50,15 +53,22 @@ def dynamic_builtin(f, *args, **kwargs): def dynamic_len(list_or_tensor): """Implementation of len using dynamic dispatch.""" - if tensor_util.is_tensor(list_or_tensor): + if _is_tensor_list(list_or_tensor): + return list_ops.tensor_list_length(list_or_tensor) + elif tensor_util.is_tensor(list_or_tensor): shape = list_or_tensor.shape - if not shape: + if not shape.ndims: raise ValueError( 'len requires non-zero rank for tensor "%s"' % list_or_tensor) return array_ops.shape(list_or_tensor)[0] return len(list_or_tensor) +def _is_tensor_list(list_or_tensor): + return (tensor_util.is_tensor(list_or_tensor) + and list_or_tensor.dtype == dtypes.variant) + + def dynamic_int(num_or_tensor, **kwargs): """Implementation of int() using dynamic dispatch.""" if tensor_util.is_tensor(num_or_tensor): @@ -73,6 +83,13 @@ def dynamic_float(num_or_tensor, **kwargs): return float(num_or_tensor) +def dynamic_abs(num_or_tensor, **kwargs): + if tensor_util.is_tensor(num_or_tensor): + return math_ops.abs(num_or_tensor, **kwargs) + else: + return abs(num_or_tensor, **kwargs) + + def dynamic_range(start_or_stop, stop=None, step=None): """Implementation of range using dynamic dispatch.""" if type_check.is_tensor(start_or_stop, stop, step): diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py index 0c2312178a921037fa419818bf309d671c33914d..b1cd5253bc3ffb1e67d89ef79cf56eaeb65fae07 100644 --- a/tensorflow/contrib/autograph/utils/builtins_test.py +++ b/tensorflow/contrib/autograph/utils/builtins_test.py @@ -33,7 +33,8 @@ class BuiltinsTest(test.TestCase): def test_dynamic_len_tf_scalar(self): a = constant_op.constant(1) - with self.assertRaises(ValueError): + with self.assertRaisesRegexp(ValueError, + 'len requires non-zero rank for tensor.*'): with self.test_session() as sess: sess.run(builtins.dynamic_builtin(len, a)) @@ -43,6 +44,23 @@ class BuiltinsTest(test.TestCase): with self.test_session() as sess: self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a))) + def test_dynamic_abs_tf_scalar(self): + a = constant_op.constant(-1) + + with self.test_session() as sess: + self.assertEqual(1, sess.run(builtins.dynamic_builtin(abs, a))) + + def test_dynamic_abs_tf_array(self): + a = constant_op.constant([-1, 2, -3]) + + with self.test_session() as sess: + self.assertListEqual([1, 2, 3], + list(sess.run(builtins.dynamic_builtin(abs, a)))) + + def test_dynamic_abs_py_scalar(self): + a = -1 + self.assertEqual(1, builtins.dynamic_builtin(abs, a)) + def test_dynamic_len_tf_matrix(self): a = constant_op.constant([[1, 2], [3, 4]]) diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py index 012a51f71101471850d312033c41dcbc4805d44c..55faad983f2bcf2f3fa633669bd371608e2e925b 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops.py @@ -58,8 +58,6 @@ def batch_function(num_batch_threads, max_batch_size, batch_timeout_micros, allowed_batch_sizes=None, - grad_timeout_micros=60 * 1000 * 1000, - unbatch_timeout_micros=60 * 1000 * 1000, max_enqueued_batches=10): """Batches the computation done by the decorated function. @@ -94,10 +92,6 @@ def batch_function(num_batch_threads, does nothing. Otherwise, supplies a list of batch sizes, causing the op to pad batches up to one of those sizes. The entries must increase monotonically, and the final entry must equal max_batch_size. - grad_timeout_micros: The timeout to use for the gradient. See the - documentation of the unbatch op for more details. Defaults to 60s. - unbatch_timeout_micros: The timeout to use for unbatching. See the - documentation of the unbatch op for more details. Defaults to 60s. max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10. Returns: @@ -119,10 +113,6 @@ def batch_function(num_batch_threads, raise ValueError("All arguments to functions decorated with " "`batch_function` are supposed to be Tensors; " "found %s" % repr(a)) - for inp in computation.captured_inputs: - print("inp: %s" % inp) - for op in inp.consumers(): - print("op: %s" % op) return gen_batch_ops.batch_function( num_batch_threads=num_batch_threads, max_batch_size=max_batch_size, diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py index d9e23646d8334014f1bef0d0744df9310b59909f..9e6a146f67796466202cc5074ddd25e4c2b083a6 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py @@ -29,7 +29,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib -from tensorflow.python.ops.distributions import gamma as gamma_lib from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.platform import test @@ -256,50 +255,6 @@ class ExpectationTest(test.TestCase): gradq_approx_kl_normal_normal_, rtol=0.01, atol=0.) - def test_docstring_example_gamma(self): - with self.test_session() as sess: - num_draws = int(1e5) - concentration_p = constant_op.constant(1.) - concentration_q = constant_op.constant(2.) - p = gamma_lib.Gamma(concentration=concentration_p, rate=1.) - q = gamma_lib.Gamma(concentration=concentration_q, rate=3.) - approx_kl_gamma_gamma = monte_carlo_lib.expectation( - f=lambda x: p.log_prob(x) - q.log_prob(x), - samples=p.sample(num_draws, seed=42), - log_prob=p.log_prob, - use_reparametrization=(p.reparameterization_type - == distribution_lib.FULLY_REPARAMETERIZED)) - exact_kl_gamma_gamma = kullback_leibler.kl_divergence(p, q) - [exact_kl_gamma_gamma_, approx_kl_gamma_gamma_] = sess.run([ - exact_kl_gamma_gamma, approx_kl_gamma_gamma]) - self.assertEqual( - False, - p.reparameterization_type == distribution_lib.FULLY_REPARAMETERIZED) - self.assertAllClose(exact_kl_gamma_gamma_, approx_kl_gamma_gamma_, - rtol=0.01, atol=0.) - - # Compare gradients. (Not present in `docstring`.) - gradp = lambda fp: gradients_impl.gradients(fp, concentration_p)[0] - gradq = lambda fq: gradients_impl.gradients(fq, concentration_q)[0] - [ - gradp_exact_kl_gamma_gamma_, - gradq_exact_kl_gamma_gamma_, - gradp_approx_kl_gamma_gamma_, - gradq_approx_kl_gamma_gamma_, - ] = sess.run([ - gradp(exact_kl_gamma_gamma), - gradq(exact_kl_gamma_gamma), - gradp(approx_kl_gamma_gamma), - gradq(approx_kl_gamma_gamma), - ]) - # Notice that variance (i.e., `rtol`) is higher when using score-trick. - self.assertAllClose(gradp_exact_kl_gamma_gamma_, - gradp_approx_kl_gamma_gamma_, - rtol=0.05, atol=0.) - self.assertAllClose(gradq_exact_kl_gamma_gamma_, - gradq_approx_kl_gamma_gamma_, - rtol=0.03, atol=0.) - if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py index 032b859d469ee5039e08e4af4c2f4ebf35c2ff19..68ead2f7609ca987180fe8973cf902f1e56b8388 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py @@ -192,7 +192,7 @@ def _logspace_mean(log_values): def expectation(f, samples, log_prob=None, use_reparametrization=True, axis=0, keep_dims=False, name=None): - """Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\). + r"""Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\). This function computes the Monte-Carlo approximation of an expectation, i.e., diff --git a/tensorflow/contrib/bigtable/BUILD b/tensorflow/contrib/bigtable/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..71538e0770dcb436c8ff1571c22e950336328357 --- /dev/null +++ b/tensorflow/contrib/bigtable/BUILD @@ -0,0 +1,213 @@ +# Cloud Bigtable client for TensorFlow + +package( + default_visibility = ["//tensorflow:internal"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load( + "//tensorflow:tensorflow.bzl", + "tf_copts", + "tf_custom_op_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_kernel_library", + "tf_cc_test", + "tf_py_test", +) + +tf_custom_op_py_library( + name = "bigtable", + srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + dso = [ + ":python/ops/_bigtable.so", + ], + kernels = [ + ":bigtable_kernels", + ":bigtable_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":bigtable_ops", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:util", + "//tensorflow/python/data", + ], +) + +KERNEL_FILES = [ + "kernels/bigtable_kernels.cc", + "kernels/bigtable_lookup_dataset_op.cc", + "kernels/bigtable_prefix_key_dataset_op.cc", + "kernels/bigtable_range_key_dataset_op.cc", + "kernels/bigtable_sample_keys_dataset_op.cc", + "kernels/bigtable_sample_key_pairs_dataset_op.cc", + "kernels/bigtable_scan_dataset_op.cc", +] + +tf_custom_op_library( + name = "python/ops/_bigtable.so", + srcs = KERNEL_FILES + [ + "ops/bigtable_ops.cc", + ], + deps = [ + ":bigtable_lib_cc", + ":bigtable_range_helpers", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +tf_gen_op_wrapper_py( + name = "bigtable_ops", + deps = [":bigtable_ops_op_lib"], +) + +tf_gen_op_libs( + op_lib_names = [ + "bigtable_ops", + "bigtable_test_ops", + ], +) + +tf_kernel_library( + name = "bigtable_kernels", + srcs = KERNEL_FILES, + deps = [ + ":bigtable_lib_cc", + ":bigtable_range_helpers", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +# A library for use in the bigtable kernels. +cc_library( + name = "bigtable_lib_cc", + srcs = ["kernels/bigtable_lib.cc"], + hdrs = ["kernels/bigtable_lib.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +cc_library( + name = "bigtable_range_helpers", + srcs = ["kernels/bigtable_range_helpers.cc"], + hdrs = ["kernels/bigtable_range_helpers.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + ], +) + +cc_library( + name = "bigtable_test_client", + srcs = ["kernels/test_kernels/bigtable_test_client.cc"], + hdrs = ["kernels/test_kernels/bigtable_test_client.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "@com_github_googleapis_googleapis//:bigtable_protos", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + "@com_googlesource_code_re2//:re2", + ], +) + +tf_cc_test( + name = "bigtable_test_client_test", + srcs = ["kernels/test_kernels/bigtable_test_client_test.cc"], + tags = ["manual"], + deps = [ + ":bigtable_test_client", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +tf_cc_test( + name = "bigtable_range_helpers_test", + size = "small", + srcs = ["kernels/bigtable_range_helpers_test.cc"], + deps = [ + ":bigtable_range_helpers", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_gen_op_wrapper_py( + name = "bigtable_test_ops", + deps = [":bigtable_test_ops_op_lib"], +) + +tf_custom_op_library( + name = "python/kernel_tests/_bigtable_test.so", + srcs = [ + "kernels/test_kernels/bigtable_test_client_op.cc", + "ops/bigtable_test_ops.cc", + ], + deps = [ + ":bigtable_lib_cc", + ":bigtable_test_client", + "@com_googlesource_code_re2//:re2", + ], +) + +# Don't use tf_kernel_library because it prevents access to strings/stringprintf.h +cc_library( + name = "bigtable_test_kernels", + srcs = [ + "kernels/test_kernels/bigtable_test_client_op.cc", + ], + copts = tf_copts(), + linkstatic = 1, + deps = [ + ":bigtable_lib_cc", + ":bigtable_test_client", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@com_googlesource_code_re2//:re2", + ], + alwayslink = 1, +) + +tf_custom_op_py_library( + name = "bigtable_test_py", + dso = [ + ":python/kernel_tests/_bigtable_test.so", + ], + kernels = [ + ":bigtable_test_kernels", + ":bigtable_test_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":bigtable_test_ops", + ], +) + +tf_py_test( + name = "bigtable_ops_test", + size = "small", + srcs = ["python/kernel_tests/bigtable_ops_test.py"], + additional_deps = [ + ":bigtable", + ":bigtable_test_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], + tags = ["manual"], +) diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b9abfa8295f9013cd8e92f87466a73952ccceb10 --- /dev/null +++ b/tensorflow/contrib/bigtable/README.md @@ -0,0 +1,347 @@ +# Google Cloud Bigtable + +[Cloud Bigtable](https://cloud.google.com/bigtable/) is a high +performance storage system that can store and serve training data. This contrib +package contains an experimental integration with TensorFlow. + +> **Status: Highly experimental.** The current implementation is very much in +> flux. Please use at your own risk! :-) + +The TensorFlow integration with Cloud Bigtable is optimized for common +TensorFlow usage and workloads. It is currently optimized for reading from Cloud +Bigtable at high speed, in particular to feed modern accelerators. For +general-purpose Cloud Bigtable +APIs, see the [official Cloud Bigtable client library documentation][clientdoc]. + +[clientdoc]: https://cloud.google.com/bigtable/docs/reference/libraries + +## Sample Use + +There are three main reading styles supported by the `BigtableTable` class: + + 1. **Reading keys**: Read only the row keys in a table. Keys are returned in + sorted order from the table. Most key reading operations retrieve all keys + in a contiguous range, however the `sample_keys` operation skips keys, and + operates on the whole table (and not a contiguous subset). + 2. **Retrieving a row's values**: Given a row key, look up the data associated + with a defined set of columns. This operation takes advantage of Cloud + Bigtable's low-latency and excellent support for random access. + 3. **Scanning ranges**: Given a contiguous range of rows retrieve both the row + key and the data associated with a fixed set of columns. This operation + takes advantage of Cloud Bigtable's high throughput scans, and is the most + efficient way to read data. + +When using the Cloud Bigtable API, the workflow is: + + 1. Create a `BigtableClient` object. + 2. Use the `BigtableClient` to create `BigtableTable` objects corresponding to + each table in the Cloud Bigtable instance you would like to access. + 3. Call methods on the `BigtableTable` object to create `tf.data.Dataset`s to + retrieve data. + +The following is an example for how to read all row keys with the prefix +`train-`. + +```python +import tensorflow as tf + +GCP_PROJECT_ID = '' +BIGTABLE_INSTANCE_ID = '' +BIGTABLE_TABLE_NAME = '' +PREFIX = 'train-' + +def main(): + client = tf.contrib.cloud.BigtableClient(GCP_PROJECT_ID, BIGTABLE_INSTANCE_ID) + table = client.table(BIGTABLE_TABLE_NAME) + dataset = table.keys_by_prefix_dataset(PREFIX) + iterator = dataset.make_initializable_iterator() + get_next_op = iterator.get_next() + + with tf.Session() as sess: + print('Initializing the iterator.') + sess.run(iterator.initializer) + print('Retrieving rows:') + row_index = 0 + while True: + try: + row_key = sess.run(get_next_op) + print('Row key %d: %s' % (row_index, row_key)) + row_index += 1 + except tf.errors.OutOfRangeError: + print('Finished reading data!') + break + +if __name__ == '__main__': + main() + +``` + +### Reading row keys + +Read only the row keys in a table. Keys are returned in sorted order from the +table. Most key reading operations retrieve all keys in a contiguous range, +however the `sample_keys` operation skips keys, and operates on the whole table +(and not a contiguous subset). + +There are 3 methods to retrieve row keys: + + - `table.keys_by_range_dataset(start, end)`: Retrieve row keys starting with + `start`, and ending with `end`. The range is "half-open", and thus it + includes `start` if `start` is present in the table. It does not include + `end`. + - `table.keys_by_prefix_dataset(prefix)`: Retrieves all row keys that start + with `prefix`. It includes the row key `prefix` if present in the table. + - `table.sample_keys()`: Retrieves a sampling of keys from the underlying + table. This is often useful in conjunction with parallel scans. + +### Reading cell values given a row key + +Given a dataset producing row keys, you can use the `table.lookup_columns` +transformation to retrieve values. Example: + +```python +key_dataset = tf.data.Dataset.from_tensor_slices([ + 'row_key_1', + 'other_row_key', + 'final_row_key', +]) +values_dataset = key_dataset.apply( + table.lookup_columns(('my_column_family', 'column_name'), + ('other_cf', 'col'))) +training_data = values_dataset.map(my_parsing_function) # ... +``` + +### Scanning ranges +Given a contiguous range of rows retrieve both the row key and the data +associated with a fixed set of columns. Scanning is the most efficient way to +retrieve data from Cloud Bigtable and is thus a very common API for high +performance data pipelines. To construct a scanning `tf.data.Dataset` from a +`BigtableTable` object, call one of the following methods: + + - `table.scan_prefix(prefix, ...)` + - `table.scan_range(start, end, ...)` + - `table.parallel_scan_prefix(prefix, ...)` + - `table.parallel_scan_range(start, end, ...)` + +Aside from the specification of the contiguous range of rows, they all take the +following arguments: + + - `probability`: (Optional.) A float between 0 (exclusive) and 1 (inclusive). + A non-1 value indicates to probabilistically sample rows with the + provided probability. + - `columns`: The columns to read. (See below.) + - `**kwargs`: The columns to read. (See below.) + +In addition the two parallel operations accept the following optional argument: +`num_parallel_scans` which configures the number of parallel Cloud Bigtable scan +operations to run. A reasonable default is automatically chosen for small +Cloud Bigtable clusters. If you have a large cluster, or an extremely demanding +workload, you can tune this value to optimize performance. + +#### Specifying columns to read when scanning + +All of the scan operations allow you to specify the column family and columns +in the same ways. + +##### Using `columns` + +The first way to specify the data to read is via the `columns` parameter. The +value should be a tuple (or list of tuples) of strings. The first string in the +tuple is the column family, and the second string in the tuple is the column +qualifier. + +##### Using `**kwargs` + +The second way to specify the data to read is via the `**kwargs` parameter, +which you can use to specify keyword arguments corresponding to the columns that +you want to read. The keyword to use is the column family name, and the argument +value should be either a string, or a tuple of strings, specifying the column +qualifiers (column names). + +Although using `**kwargs` has the advantage of requiring less typing, it is not +future-proof in all cases. (If we add a new parameter to the scan functions that +has the same name as your column family, your code will break.) + +##### Examples + +Below are two equivalent snippets for how to specify which columns to read: + +```python +ds1 = table.scan_range("row_start", "row_end", columns=[("cfa", "c1"), + ("cfa", "c2"), + ("cfb", "c3")]) +ds2 = table.scan_range("row_start", "row_end", cfa=["c1", "c2"], cfb="c3") +``` + +In this example, we are reading 3 columns from a total of 2 column families. +From the `cfa` column family, we are reading columns `c1`, and `c2`. From the +second column family (`cfb`), we are reading `c3`. Both `ds1` and `ds2` will +output elements of the following types (`tf.string`, `tf.string`, `tf.string`, +`tf.string`). The first `tf.string` is the row key, the second `tf.string` is +the latest data in cell `cfa:c1`, the third corresponds to `cfa:c2`, and the +final one is `cfb:c3`. + +#### Determinism when scanning + +While the non-parallel scan operations are fully deterministic, the parallel +scan operations are not. If you would like to scan in parallel without losing +determinism, you can build up the `parallel_interleave` yourself. As an example, +say we wanted to scan all rows between `training_data_00000`, and +`training_data_90000`, we can use the following code snippet: + +```python +table = # ... +columns = [('cf1', 'col1'), ('cf1', 'col2')] +NUM_PARALLEL_READS = # ... +ds = tf.data.Dataset.range(9).shuffle(10) +def interleave_fn(index): + # Given a starting index, create 2 strings to be the start and end + start_idx = index + end_idx = index + 1 + start_idx_str = tf.as_string(start_idx * 10000, width=5, fill='0') + end_idx_str = tf.as_string(end_idx * 10000, width=5, fill='0') + start = tf.string_join(['training_data_', start_idx_str]) + end = tf.string_join(['training_data_', end_idx_str]) + return table.scan_range(start_idx, end_idx, columns=columns) +ds = ds.apply(tf.contrib.data.parallel_interleave( + interleave_fn, cycle_length=NUM_PARALLEL_READS, prefetch_input_elements=1)) +``` + +> Note: you should divide up the key range into more sub-ranges for increased +> parallelism. + +## Writing to Cloud Bigtable + +In order to simplify getting started, this package provides basic support for +writing data into Cloud Bigtable. + +> Note: The implementation is not optimized for performance! Please consider +> using alternative frameworks such as Apache Beam / Cloud Dataflow for +> production workloads. + +Below is an example for how to write a trivial dataset into Cloud Bigtable. + +```python +import tensorflow as tf + +GCP_PROJECT_ID = '' +BIGTABLE_INSTANCE_ID = '' +BIGTABLE_TABLE_NAME = '' +COLUMN_FAMILY = '' +COLUMN_QUALIFIER = '' + +def make_dataset(): + """Makes a dataset to write to Cloud Bigtable.""" + return tf.data.Dataset.from_tensor_slices([ + 'training_data_1', + 'training_data_2', + 'training_data_3', + ]) + +def make_row_key_dataset(): + """Makes a dataset of strings used for row keys. + + The strings are of the form: `fake-data-` followed by a sequential counter. + For example, this dataset would contain the following elements: + + - fake-data-00000001 + - fake-data-00000002 + - ... + - fake-data-23498103 + """ + counter_dataset = tf.contrib.data.Counter() + width = 8 + row_key_prefix = 'fake-data-' + ds = counter_dataset.map(lambda index: tf.as_string(index, + width=width, + fill='0')) + ds = ds.map(lambda idx_str: tf.string_join([row_key_prefix, idx_str])) + return ds + + +def main(): + client = tf.contrib.cloud.BigtableClient(GCP_PROJECT_ID, BIGTABLE_INSTANCE_ID) + table = client.table(BIGTABLE_TABLE_NAME) + dataset = make_dataset() + index_dataset = make_row_key_dataset() + aggregate_dataset = tf.data.Dataset.zip((index_dataset, dataset)) + write_op = table.write(aggregate_dataset, column_families=[COLUMN_FAMILY], + columns=[COLUMN_QUALIFIER]) + + with tf.Session() as sess: + print('Starting transfer.') + sess.run(write_op) + print('Transfer complete.') + +if __name__ == '__main__': + main() +``` + +## Sample applications and architectures + +While most machine learning applications are well suited by a high performance +distributed file system, there are certain applications where using Cloud +Bigtable works extremely well. + +### Perfect Shuffling + +Normally, training data is stored in flat files, and a combination of +(1) `tf.data.Dataset.interleave` (or `parallel_interleave`), (2) +`tf.data.Dataset.shuffle`, and (3) writing the data in an unsorted order in the +data files in the first place, provides enough randomization to ensure models +train efficiently. However, if you would like perfect shuffling, you can use +Cloud Bigtable's low-latency random access capabilities. Create a +`tf.data.Dataset` that generates the keys in a perfectly random order (or read +all the keys into memory and use a shuffle buffer sized to fit all of them for a +perfect random shuffle using `tf.data.Dataset.shuffle`), and then use +`lookup_columns` to retrieve the training data. + +### Distributed Reinforcement Learning + +Sophisticated reinforcement learning algorithms are commonly trained across a +distributed cluster. (See [IMPALA by DeepMind][impala].) One part of the cluster +runs self-play, while the other part of the cluster learns a new version of the +model based on the training data generated by self-play. The new model version +is then distributed to the self-play half of the cluster, and new training data +is generated to continue the cycle. + +In such a configuration, because there is value in training on the freshest +examples, a storage service like Cloud Bigtable can be used to store and +serve the generated training data. When using Cloud Bigtable, there is no need +to aggregate the examples into large batch files, but the examples can instead +be written as soon as they are generated, and then retrieved at high speed. + +[impala]: https://arxiv.org/abs/1802.01561 + +## Common Gotchas! + +### gRPC Certificates + +If you encounter a log line that includes the following: + +``` +"description":"Failed to load file", [...], +"filename":"/usr/share/grpc/roots.pem" +``` + +you likely need to copy the [gRPC `roots.pem` file][grpcPem] to +`/usr/share/grpc/roots.pem` on your local machine. + +[grpcPem]: https://github.com/grpc/grpc/blob/master/etc/roots.pem + +### Permission denied errors + +The TensorFlow Cloud Bigtable client will search for credentials to use in the +process's environment. It will use the first credentials it finds if multiple +are available. + + - **Compute Engine**: When running on Compute Engine, the client will often use + the service account from the virtual machine's metadata service. Be sure to + authorize your Compute Engine VM to have access to the Cloud Bigtable service + when creating your VM, or [update the VM's scopes][update-vm-scopes] on a + running VM if you run into this issue. + - **Cloud TPU**: Your Cloud TPUs run with the designated Cloud TPU service + account dedicated to your GCP project. Ensure the service account has been + authorized via the Cloud Console to access your Cloud Bigtable instances. + +[update-vm-scopes]: https://cloud.google.com/compute/docs/access/create-enable-service-accounts-for-instances#changeserviceaccountandscopes diff --git a/tensorflow/contrib/bigtable/__init__.py b/tensorflow/contrib/bigtable/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7d89c98420ab3ac1465bba718f8257ce2312467 --- /dev/null +++ b/tensorflow/contrib/bigtable/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Cloud Bigtable Client for TensorFlow. + +This contrib package allows TensorFlow to interface directly with Cloud Bigtable +for high-speed data loading. + +@@BigtableClient +@@BigtableTable + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableClient +from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableTable + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'BigtableClient', + 'BigtableTable', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc new file mode 100644 index 0000000000000000000000000000000000000000..a25a641cdb4608dee6d6c1bd18697860cc1f5613 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc @@ -0,0 +1,358 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/threadpool.h" + +namespace tensorflow { + +namespace { + +class BigtableClientOp : public OpKernel { + public: + explicit BigtableClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("project_id", &project_id_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("instance_id", &instance_id_)); + OP_REQUIRES(ctx, !project_id_.empty(), + errors::InvalidArgument("project_id must be non-empty")); + OP_REQUIRES(ctx, !instance_id_.empty(), + errors::InvalidArgument("instance_id must be non-empty")); + + OP_REQUIRES_OK( + ctx, ctx->GetAttr("connection_pool_size", &connection_pool_size_)); + // If left unset by the client code, set it to a default of 100. Note: the + // cloud-cpp default of 4 concurrent connections is far too low for high + // performance streaming. + if (connection_pool_size_ == -1) { + connection_pool_size_ = 100; + } + + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_receive_message_size", + &max_receive_message_size_)); + // If left unset by the client code, set it to a default of 100. Note: the + // cloud-cpp default of 4 concurrent connections is far too low for high + // performance streaming. + if (max_receive_message_size_ == -1) { + max_receive_message_size_ = 1 << 24; // 16 MBytes + } + OP_REQUIRES(ctx, max_receive_message_size_ > 0, + errors::InvalidArgument("connection_pool_size must be > 0")); + } + + ~BigtableClientOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + BigtableClientResource* resource; + OP_REQUIRES_OK( + ctx, + mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this, ctx]( + BigtableClientResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + auto client_options = + google::cloud::bigtable::ClientOptions() + .set_connection_pool_size(connection_pool_size_) + .set_data_endpoint("batch-bigtable.googleapis.com"); + auto channel_args = client_options.channel_arguments(); + channel_args.SetMaxReceiveMessageSize( + max_receive_message_size_); + channel_args.SetUserAgentPrefix("tensorflow"); + channel_args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 0); + channel_args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, 60 * 1000); + client_options.set_channel_arguments(channel_args); + std::shared_ptr client = + google::cloud::bigtable::CreateDefaultDataClient( + project_id_, instance_id_, std::move(client_options)); + *ret = new BigtableClientResource(project_id_, instance_id_, + std::move(client)); + return Status::OK(); + })); + core::ScopedUnref resource_cleanup(resource); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } + + private: + string project_id_; + string instance_id_; + int64 connection_pool_size_; + int32 max_receive_message_size_; + + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableClient").Device(DEVICE_CPU), + BigtableClientOp); + +class BigtableTableOp : public OpKernel { + public: + explicit BigtableTableOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_)); + OP_REQUIRES(ctx, !table_.empty(), + errors::InvalidArgument("table_name must be non-empty")); + } + + ~BigtableTableOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + + BigtableClientResource* client_resource; + OP_REQUIRES_OK( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource)); + core::ScopedUnref unref_client(client_resource); + + BigtableTableResource* resource; + OP_REQUIRES_OK( + ctx, mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this, client_resource](BigtableTableResource** ret) { + *ret = new BigtableTableResource(client_resource, table_); + return Status::OK(); + })); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } + + private: + string table_; // Note: this is const after construction. + + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableTable").Device(DEVICE_CPU), + BigtableTableOp); + +class ToBigtableOp : public AsyncOpKernel { + public: + explicit ToBigtableOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), + thread_pool_(new thread::ThreadPool( + ctx->env(), ThreadOptions(), + strings::StrCat("to_bigtable_op_", SanitizeThreadSuffix(name())), + /* num_threads = */ 1, /* low_latency_hint = */ false)) {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + // The call to `iterator->GetNext()` may block and depend on an + // inter-op thread pool thread, so we issue the call from the + // owned thread pool. + thread_pool_->Schedule([this, ctx, done]() { + const Tensor* column_families_tensor; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->input("column_families", &column_families_tensor), done); + OP_REQUIRES_ASYNC( + ctx, column_families_tensor->dims() == 1, + errors::InvalidArgument("`column_families` must be a vector."), done); + + const Tensor* columns_tensor; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input("columns", &columns_tensor), done); + OP_REQUIRES_ASYNC(ctx, columns_tensor->dims() == 1, + errors::InvalidArgument("`columns` must be a vector."), + done); + OP_REQUIRES_ASYNC( + ctx, + columns_tensor->NumElements() == + column_families_tensor->NumElements(), + errors::InvalidArgument("len(column_families) != len(columns)"), + done); + + std::vector column_families; + column_families.reserve(column_families_tensor->NumElements()); + std::vector columns; + columns.reserve(column_families_tensor->NumElements()); + for (uint64 i = 0; i < column_families_tensor->NumElements(); ++i) { + column_families.push_back(column_families_tensor->flat()(i)); + columns.push_back(columns_tensor->flat()(i)); + } + + DatasetBase* dataset; + OP_REQUIRES_OK_ASYNC( + ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done); + + std::unique_ptr iterator; + OP_REQUIRES_OK_ASYNC( + ctx, + dataset->MakeIterator(IteratorContext(ctx), "ToBigtableOpIterator", + &iterator), + done); + + int64 timestamp_int; + OP_REQUIRES_OK_ASYNC( + ctx, ParseScalarArgument(ctx, "timestamp", ×tamp_int), + done); + OP_REQUIRES_ASYNC(ctx, timestamp_int >= -1, + errors::InvalidArgument("timestamp must be >= -1"), + done); + + BigtableTableResource* resource; + OP_REQUIRES_OK_ASYNC( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource), done); + core::ScopedUnref resource_cleanup(resource); + + std::vector components; + components.reserve(dataset->output_dtypes().size()); + bool end_of_sequence = false; + do { + ::google::cloud::bigtable::BulkMutation mutation; + // TODO(saeta): Make # of mutations configurable. + for (uint64 i = 0; i < 100 && !end_of_sequence; ++i) { + OP_REQUIRES_OK_ASYNC(ctx, + iterator->GetNext(IteratorContext(ctx), + &components, &end_of_sequence), + done); + if (!end_of_sequence) { + OP_REQUIRES_OK_ASYNC( + ctx, + CreateMutation(std::move(components), column_families, columns, + timestamp_int, &mutation), + done); + } + components.clear(); + } + grpc::Status mutation_status; + std::vector<::google::cloud::bigtable::FailedMutation> failures = + resource->table().BulkApply(std::move(mutation), mutation_status); + if (!mutation_status.ok()) { + LOG(ERROR) << "Failure applying mutation: " + << mutation_status.error_code() << " - " + << mutation_status.error_message() << " (" + << mutation_status.error_details() << ")."; + } + if (!failures.empty()) { + for (const auto& failure : failures) { + LOG(ERROR) << "Failure applying mutation on row (" + << failure.original_index() + << "): " << failure.mutation().row_key() + << " - error: " << failure.status().error_message() + << " (Details: " << failure.status().error_details() + << ")."; + } + } + OP_REQUIRES_ASYNC( + ctx, failures.empty() && mutation_status.ok(), + errors::Unknown("Failure while writing to Cloud Bigtable: ", + mutation_status.error_code(), " - ", + mutation_status.error_message(), " (", + mutation_status.error_details(), + "), # of mutation failures: ", failures.size(), + ". See the log for the specific error details."), + done); + } while (!end_of_sequence); + done(); + }); + } + + private: + static string SanitizeThreadSuffix(string suffix) { + string clean; + for (int i = 0; i < suffix.size(); ++i) { + const char ch = suffix[i]; + if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || + (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') { + clean += ch; + } else { + clean += '_'; + } + } + return clean; + } + + Status CreateMutation( + std::vector tensors, const std::vector& column_families, + const std::vector& columns, int64 timestamp_int, + ::google::cloud::bigtable::BulkMutation* bulk_mutation) { + if (tensors.size() != column_families.size() + 1) { + return errors::InvalidArgument( + "Iterator produced a set of Tensors shorter than expected"); + } + ::google::cloud::bigtable::SingleRowMutation mutation( + std::move(tensors[0].scalar()())); + std::chrono::milliseconds timestamp(timestamp_int); + for (size_t i = 1; i < tensors.size(); ++i) { + if (!TensorShapeUtils::IsScalar(tensors[i].shape())) { + return errors::Internal("Output tensor ", i, " was not a scalar"); + } + if (timestamp_int == -1) { + mutation.emplace_back(::google::cloud::bigtable::SetCell( + column_families[i - 1], columns[i - 1], + std::move(tensors[i].scalar()()))); + } else { + mutation.emplace_back(::google::cloud::bigtable::SetCell( + column_families[i - 1], columns[i - 1], timestamp, + std::move(tensors[i].scalar()()))); + } + } + bulk_mutation->emplace_back(std::move(mutation)); + return Status::OK(); + } + + template + Status ParseScalarArgument(OpKernelContext* ctx, + const StringPiece& argument_name, T* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsScalar(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a scalar"); + } + *output = argument_t->scalar()(); + return Status::OK(); + } + + std::unique_ptr thread_pool_; +}; + +REGISTER_KERNEL_BUILDER(Name("DatasetToBigtable").Device(DEVICE_CPU), + ToBigtableOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc new file mode 100644 index 0000000000000000000000000000000000000000..67bf14c17646cff81af707405b66c9fba2ded0bd --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" + +namespace tensorflow { + +Status GrpcStatusToTfStatus(const ::grpc::Status& status) { + if (status.ok()) { + return Status::OK(); + } + auto grpc_code = status.error_code(); + if (status.error_code() == ::grpc::StatusCode::ABORTED || + status.error_code() == ::grpc::StatusCode::UNAVAILABLE || + status.error_code() == ::grpc::StatusCode::OUT_OF_RANGE) { + grpc_code = ::grpc::StatusCode::INTERNAL; + } + return Status(static_cast<::tensorflow::error::Code>(status.error_code()), + strings::StrCat("Error reading from Cloud Bigtable: ", + status.error_message(), + " (Details: ", status.error_details(), ")")); +} + +string RegexFromStringSet(const std::vector& strs) { + CHECK(!strs.empty()) << "The list of strings to turn into a regex was empty."; + std::unordered_set uniq(strs.begin(), strs.end()); + if (uniq.size() == 1) { + return *uniq.begin(); + } + return str_util::Join(uniq, "|"); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h new file mode 100644 index 0000000000000000000000000000000000000000..a2a5df1037a00ccfdff1910dd950d7b012e684e2 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h @@ -0,0 +1,143 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ +#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ + +// Note: we use bigtable/client/internal/table.h as this is the no-exception API + +#include "google/cloud/bigtable/data_client.h" +#include "google/cloud/bigtable/internal/table.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/resource_mgr.h" + +namespace tensorflow { + +Status GrpcStatusToTfStatus(const ::grpc::Status& status); + +string RegexFromStringSet(const std::vector& strs); + +class BigtableClientResource : public ResourceBase { + public: + BigtableClientResource( + string project_id, string instance_id, + std::shared_ptr client) + : project_id_(std::move(project_id)), + instance_id_(std::move(instance_id)), + client_(std::move(client)) {} + + std::shared_ptr get_client() { + return client_; + } + + string DebugString() override { + return strings::StrCat("BigtableClientResource(project_id: ", project_id_, + ", instance_id: ", instance_id_, ")"); + } + + private: + const string project_id_; + const string instance_id_; + std::shared_ptr client_; +}; + +class BigtableTableResource : public ResourceBase { + public: + BigtableTableResource(BigtableClientResource* client, string table_name) + : client_(client), + table_name_(std::move(table_name)), + table_(client->get_client(), table_name_, + google::cloud::bigtable::AlwaysRetryMutationPolicy()) { + client_->Ref(); + } + + ~BigtableTableResource() override { client_->Unref(); } + + ::google::cloud::bigtable::noex::Table& table() { return table_; } + + string DebugString() override { + return strings::StrCat( + "BigtableTableResource(client: ", client_->DebugString(), + ", table: ", table_name_, ")"); + } + + private: + BigtableClientResource* client_; // Ownes one ref. + const string table_name_; + ::google::cloud::bigtable::noex::Table table_; +}; + +// BigtableReaderDatasetIterator is an abstract class for iterators from +// datasets that are "readers" (source datasets, not transformation datasets) +// that read from Bigtable. +template +class BigtableReaderDatasetIterator : public DatasetIterator { + public: + explicit BigtableReaderDatasetIterator( + const typename DatasetIterator::Params& params) + : DatasetIterator(params), iterator_(nullptr, false) {} + + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(EnsureIteratorInitialized()); + if (iterator_ == reader_->end()) { + grpc::Status status = reader_->Finish(); + if (status.ok()) { + *end_of_sequence = true; + return Status::OK(); + } + return GrpcStatusToTfStatus(status); + } + *end_of_sequence = false; + google::cloud::bigtable::Row& row = *iterator_; + Status s = ParseRow(ctx, row, out_tensors); + // Ensure we always advance. + ++iterator_; + return s; + } + + protected: + virtual ::google::cloud::bigtable::RowRange MakeRowRange() = 0; + virtual ::google::cloud::bigtable::Filter MakeFilter() = 0; + virtual Status ParseRow(IteratorContext* ctx, + const ::google::cloud::bigtable::Row& row, + std::vector* out_tensors) = 0; + + private: + Status EnsureIteratorInitialized() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (reader_) { + return Status::OK(); + } + + auto rows = MakeRowRange(); + auto filter = MakeFilter(); + + // Note: the this in `this->dataset()` below is necessary due to namespace + // name conflicts. + reader_.reset(new ::google::cloud::bigtable::RowReader( + this->dataset()->table()->table().ReadRows(rows, filter))); + iterator_ = reader_->begin(); + return Status::OK(); + } + + mutex mu_; + std::unique_ptr<::google::cloud::bigtable::RowReader> reader_ GUARDED_BY(mu_); + ::google::cloud::bigtable::RowReader::iterator iterator_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..bd32672aa99d7bf70c44a264f488482c4f213a0b --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc @@ -0,0 +1,229 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { + public: + using UnaryDatasetOpKernel::UnaryDatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + BigtableTableResource* table; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table)); + + std::vector column_families; + std::vector columns; + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "column_families", + &column_families)); + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "columns", &columns)); + OP_REQUIRES( + ctx, column_families.size() == columns.size(), + errors::InvalidArgument("len(columns) != len(column_families)")); + + const uint64 num_outputs = columns.size() + 1; + std::vector output_shapes; + output_shapes.reserve(num_outputs); + DataTypeVector output_types; + output_types.reserve(num_outputs); + for (uint64 i = 0; i < num_outputs; ++i) { + output_shapes.push_back({}); + output_types.push_back(DT_STRING); + } + + *output = + new Dataset(ctx, input, table, std::move(column_families), + std::move(columns), output_types, std::move(output_shapes)); + } + + private: + class Dataset : public DatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, + BigtableTableResource* table, + std::vector column_families, + std::vector columns, + const DataTypeVector& output_types, + std::vector output_shapes) + : DatasetBase(DatasetContext(ctx)), + input_(input), + table_(table), + column_families_(std::move(column_families)), + columns_(std::move(columns)), + output_types_(output_types), + output_shapes_(std::move(output_shapes)), + filter_(MakeFilter(column_families_, columns_)) { + table_->Ref(); + input_->Ref(); + } + + ~Dataset() override { + table_->Unref(); + input_->Unref(); + } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::BigtableLookup")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "BigtableLookupDatasetOp::Dataset"; + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + + private: + static ::google::cloud::bigtable::Filter MakeFilter( + const std::vector& column_families, + const std::vector& columns) { + string column_family_regex = RegexFromStringSet(column_families); + string column_regex = RegexFromStringSet(columns); + + return ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1), + ::google::cloud::bigtable::Filter::FamilyRegex(column_family_regex), + ::google::cloud::bigtable::Filter::ColumnRegex(column_regex)); + } + + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); // Sequence requests. + std::vector input_tensors; + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, &input_tensors, end_of_sequence)); + if (*end_of_sequence) { + return Status::OK(); + } + if (input_tensors.size() != 1) { + return errors::InvalidArgument( + "Upstream iterator (", dataset()->input_->DebugString(), + ") did not produce a single `tf.string` `tf.Tensor`. It " + "produced ", + input_tensors.size(), " tensors."); + } + if (input_tensors[0].NumElements() == 0) { + return errors::InvalidArgument("Upstream iterator (", + dataset()->input_->DebugString(), + ") return an empty set of keys."); + } + if (input_tensors[0].NumElements() == 1) { + // Single key lookup. + ::grpc::Status status; + auto pair = dataset()->table_->table().ReadRow( + input_tensors[0].scalar()(), dataset()->filter_, status); + if (!status.ok()) { + return GrpcStatusToTfStatus(status); + } + if (!pair.first) { + return errors::DataLoss("Row key '", + input_tensors[0].scalar()(), + "' not found."); + } + TF_RETURN_IF_ERROR(ParseRow(ctx, pair.second, out_tensors)); + } else { + // Batched get. + return errors::Unimplemented( + "BigtableLookupDataset doesn't yet support batched retrieval."); + } + return Status::OK(); + } + + private: + Status ParseRow(IteratorContext* ctx, + const ::google::cloud::bigtable::Row& row, + std::vector* out_tensors) { + out_tensors->reserve(dataset()->columns_.size() + 1); + Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {}); + row_key_tensor.scalar()() = string(row.row_key()); + out_tensors->emplace_back(std::move(row_key_tensor)); + + if (row.cells().size() > 2 * dataset()->columns_.size()) { + LOG(WARNING) << "An excessive number of columns (" + << row.cells().size() + << ") were retrieved when reading row: " + << row.row_key(); + } + + for (uint64 i = 0; i < dataset()->columns_.size(); ++i) { + Tensor col_tensor(ctx->allocator({}), DT_STRING, {}); + bool found_column = false; + for (auto cell_itr = row.cells().begin(); + !found_column && cell_itr != row.cells().end(); ++cell_itr) { + if (cell_itr->family_name() == dataset()->column_families_[i] && + string(cell_itr->column_qualifier()) == + dataset()->columns_[i]) { + col_tensor.scalar()() = string(cell_itr->value()); + found_column = true; + } + } + if (!found_column) { + return errors::DataLoss("Column ", dataset()->column_families_[i], + ":", dataset()->columns_[i], + " not found in row: ", row.row_key()); + } + out_tensors->emplace_back(std::move(col_tensor)); + } + return Status::OK(); + } + + mutex mu_; + std::unique_ptr input_impl_ GUARDED_BY(mu_); + }; + + const DatasetBase* const input_; + BigtableTableResource* table_; + const std::vector column_families_; + const std::vector columns_; + const DataTypeVector output_types_; + const std::vector output_shapes_; + const ::google::cloud::bigtable::Filter filter_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableLookupDataset").Device(DEVICE_CPU), + BigtableLookupDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a803fdcb49604ef4e596b64d62c7278c69764c15 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc @@ -0,0 +1,114 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefix", &prefix)); + + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + + *output = new Dataset(ctx, resource, std::move(prefix)); + } + + private: + class Dataset : public DatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, + string prefix) + : DatasetBase(DatasetContext(ctx)), + table_(table), + prefix_(std::move(prefix)) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::BigtablePrefixKey")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}}); + return *shapes; + } + + string DebugString() const override { + return "BigtablePrefixKeyDatasetOp::Dataset"; + } + + BigtableTableResource* table() const { return table_; } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + + private: + class Iterator : public BigtableReaderDatasetIterator { + public: + explicit Iterator(const Params& params) + : BigtableReaderDatasetIterator(params) {} + + ::google::cloud::bigtable::RowRange MakeRowRange() override { + return ::google::cloud::bigtable::RowRange::Prefix(dataset()->prefix_); + } + ::google::cloud::bigtable::Filter MakeFilter() override { + return ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::CellsRowLimit(1), + ::google::cloud::bigtable::Filter::StripValueTransformer()); + } + Status ParseRow(IteratorContext* ctx, + const ::google::cloud::bigtable::Row& row, + std::vector* out_tensors) override { + Tensor output_tensor(ctx->allocator({}), DT_STRING, {}); + output_tensor.scalar()() = string(row.row_key()); + out_tensors->emplace_back(std::move(output_tensor)); + return Status::OK(); + } + }; + + BigtableTableResource* const table_; + const string prefix_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtablePrefixKeyDataset").Device(DEVICE_CPU), + BigtablePrefixKeyDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.cc new file mode 100644 index 0000000000000000000000000000000000000000..51965f6214413c08453473e71c30eecbd8925a64 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.cc @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h" + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +namespace { + +string MakePrefixEndKey(const string& prefix) { + string end = prefix; + while (true) { + if (end.empty()) { + return end; + } + ++end[end.size() - 1]; + if (end[end.size() - 1] == 0) { + // Handle wraparound case. + end = end.substr(0, end.size() - 1); + } else { + return end; + } + } +} + +} // namespace + +/* static */ MultiModeKeyRange MultiModeKeyRange::FromPrefix(string prefix) { + string end = MakePrefixEndKey(prefix); + VLOG(1) << "Creating MultiModeKeyRange from Prefix: " << prefix + << ", with end key: " << end; + return MultiModeKeyRange(std::move(prefix), std::move(end)); +} + +/* static */ MultiModeKeyRange MultiModeKeyRange::FromRange(string begin, + string end) { + return MultiModeKeyRange(std::move(begin), std::move(end)); +} + +const string& MultiModeKeyRange::begin_key() const { return begin_; } + +const string& MultiModeKeyRange::end_key() const { return end_; } + +bool MultiModeKeyRange::contains_key(StringPiece key) const { + if (StringPiece(begin_) > key) { + return false; + } + if (StringPiece(end_) <= key && !end_.empty()) { + return false; + } + return true; +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..44c628e366c26b88011642f1e8e8d8e74b4698fd --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_ +#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_ + +#include + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Represents a continuous range of keys defined by either a prefix or a range. +// +// Ranges are represented as "half-open", where the beginning key is included +// in the range, and the end_key is the first excluded key after the range. +// +// The range of keys can be specified either by a key prefix, or by an explicit +// begin key and end key. All methods on this class are valid no matter which +// way the range was specified. +// +// Example: +// MultiModeKeyRange range = MultiModeKeyRange::FromPrefix("myPrefix"); +// if (range.contains_key("myPrefixedKey")) { +// LOG(INFO) << "range from " << range.begin_key() << " to " +// << range.end_key() << "contains \"myPrefixedKey\""; +// } +// if (!range.contains_key("randomKey")) { +// LOG(INFO) << "range does not contain \"randomKey\""; +// } +// range = MultiModeKeyRange::FromRange("a_start_key", "z_end_key"); +class MultiModeKeyRange { + public: + static MultiModeKeyRange FromPrefix(string prefix); + static MultiModeKeyRange FromRange(string begin, string end); + + // The first valid key in the range. + const string& begin_key() const; + // The first invalid key after the valid range. + const string& end_key() const; + // Returns true if the provided key is a part of the range, false otherwise. + bool contains_key(StringPiece key) const; + + private: + MultiModeKeyRange(string begin, string end) + : begin_(std::move(begin)), end_(std::move(end)) {} + + const string begin_; + const string end_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_ diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers_test.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1bfc547271d5e58a9145b73356b2b558dc1af9f1 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers_test.cc @@ -0,0 +1,107 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(MultiModeKeyRangeTest, SimplePrefix) { + MultiModeKeyRange r = MultiModeKeyRange::FromPrefix("prefix"); + EXPECT_EQ("prefix", r.begin_key()); + EXPECT_EQ("prefiy", r.end_key()); + EXPECT_TRUE(r.contains_key("prefixed_key")); + EXPECT_FALSE(r.contains_key("not-prefixed-key")); + EXPECT_FALSE(r.contains_key("prefi")); + EXPECT_FALSE(r.contains_key("prefiy")); + EXPECT_FALSE(r.contains_key("early")); + EXPECT_FALSE(r.contains_key("")); +} + +TEST(MultiModeKeyRangeTest, Range) { + MultiModeKeyRange r = MultiModeKeyRange::FromRange("a", "b"); + EXPECT_EQ("a", r.begin_key()); + EXPECT_EQ("b", r.end_key()); + EXPECT_TRUE(r.contains_key("a")); + EXPECT_TRUE(r.contains_key("ab")); + EXPECT_FALSE(r.contains_key("b")); + EXPECT_FALSE(r.contains_key("bc")); + EXPECT_FALSE(r.contains_key("A")); + EXPECT_FALSE(r.contains_key("B")); + EXPECT_FALSE(r.contains_key("")); +} + +TEST(MultiModeKeyRangeTest, InvertedRange) { + MultiModeKeyRange r = MultiModeKeyRange::FromRange("b", "a"); + EXPECT_FALSE(r.contains_key("a")); + EXPECT_FALSE(r.contains_key("b")); + EXPECT_FALSE(r.contains_key("")); +} + +TEST(MultiModeKeyRangeTest, EmptyPrefix) { + MultiModeKeyRange r = MultiModeKeyRange::FromPrefix(""); + EXPECT_EQ("", r.begin_key()); + EXPECT_EQ("", r.end_key()); + EXPECT_TRUE(r.contains_key("")); + EXPECT_TRUE(r.contains_key("a")); + EXPECT_TRUE(r.contains_key("z")); + EXPECT_TRUE(r.contains_key("A")); + EXPECT_TRUE(r.contains_key("ZZZZZZ")); +} + +TEST(MultiModeKeyRangeTest, HalfRange) { + MultiModeKeyRange r = MultiModeKeyRange::FromRange("start", ""); + EXPECT_EQ("start", r.begin_key()); + EXPECT_EQ("", r.end_key()); + EXPECT_TRUE(r.contains_key("start")); + EXPECT_TRUE(r.contains_key("starting")); + EXPECT_TRUE(r.contains_key("z-end")); + EXPECT_FALSE(r.contains_key("")); + EXPECT_FALSE(r.contains_key("early")); +} + +TEST(MultiModeKeyRangeTest, PrefixWrapAround) { + string prefix = "abc\xff"; + MultiModeKeyRange r = MultiModeKeyRange::FromPrefix(prefix); + EXPECT_EQ(prefix, r.begin_key()); + EXPECT_EQ("abd", r.end_key()); + + EXPECT_TRUE(r.contains_key("abc\xff\x07")); + EXPECT_TRUE(r.contains_key("abc\xff\x15")); + EXPECT_TRUE(r.contains_key("abc\xff\x61")); + EXPECT_TRUE(r.contains_key("abc\xff\xff")); + EXPECT_FALSE(r.contains_key("abc\0")); + EXPECT_FALSE(r.contains_key("abd")); +} + +TEST(MultiModeKeyRangeTest, PrefixSignedWrapAround) { + string prefix = "abc\x7f"; + MultiModeKeyRange r = MultiModeKeyRange::FromPrefix(prefix); + EXPECT_EQ(prefix, r.begin_key()); + EXPECT_EQ("abc\x80", r.end_key()); + + EXPECT_TRUE(r.contains_key("abc\x7f\x07")); + EXPECT_TRUE(r.contains_key("abc\x7f\x15")); + EXPECT_TRUE(r.contains_key("abc\x7f\x61")); + EXPECT_TRUE(r.contains_key("abc\x7f\xff")); + EXPECT_FALSE(r.contains_key("abc\0")); + EXPECT_FALSE(r.contains_key("abc\x01")); + EXPECT_FALSE(r.contains_key("abd")); + EXPECT_FALSE(r.contains_key("ab\x80")); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..5cd0371c79f7eded9303b81dd388df8d306dff80 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableRangeKeyDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string start_key; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "start_key", &start_key)); + string end_key; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "end_key", &end_key)); + + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + + *output = + new Dataset(ctx, resource, std::move(start_key), std::move(end_key)); + } + + private: + class Dataset : public DatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, + string start_key, string end_key) + : DatasetBase(DatasetContext(ctx)), + table_(table), + start_key_(std::move(start_key)), + end_key_(std::move(end_key)) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::BigtableRangeKey")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}}); + return *shapes; + } + + string DebugString() const override { + return "BigtableRangeKeyDatasetOp::Dataset"; + } + + BigtableTableResource* table() const { return table_; } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + + private: + class Iterator : public BigtableReaderDatasetIterator { + public: + explicit Iterator(const Params& params) + : BigtableReaderDatasetIterator(params) {} + + ::google::cloud::bigtable::RowRange MakeRowRange() override { + return ::google::cloud::bigtable::RowRange::Range(dataset()->start_key_, + dataset()->end_key_); + } + ::google::cloud::bigtable::Filter MakeFilter() override { + return ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::CellsRowLimit(1), + ::google::cloud::bigtable::Filter::StripValueTransformer()); + } + Status ParseRow(IteratorContext* ctx, + const ::google::cloud::bigtable::Row& row, + std::vector* out_tensors) override { + Tensor output_tensor(ctx->allocator({}), DT_STRING, {}); + output_tensor.scalar()() = string(row.row_key()); + out_tensors->emplace_back(std::move(output_tensor)); + return Status::OK(); + } + }; + + BigtableTableResource* const table_; + const string start_key_; + const string end_key_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableRangeKeyDataset").Device(DEVICE_CPU), + BigtableRangeKeyDatasetOp); +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6928d9423c84f7504fea3ac1abd929357da034a5 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc @@ -0,0 +1,208 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefix", &prefix)); + + string start_key; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "start_key", &start_key)); + string end_key; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "end_key", &end_key)); + + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + + OP_REQUIRES(ctx, prefix.empty() || start_key.empty(), + errors::InvalidArgument( + "Only one of prefix and start_key can be provided")); + if (!prefix.empty()) { + OP_REQUIRES(ctx, end_key.empty(), + errors::InvalidArgument( + "If prefix is specified, end_key must be empty.")); + } + + *output = new Dataset(ctx, resource, std::move(prefix), + std::move(start_key), std::move(end_key)); + } + + private: + class Dataset : public DatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, + string prefix, string start_key, string end_key) + : DatasetBase(DatasetContext(ctx)), + table_(table), + key_range_(MakeMultiModeKeyRange( + std::move(prefix), std::move(start_key), std::move(end_key))) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::BigtableSampleKeyPairs")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = + new DataTypeVector({DT_STRING, DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}, {}}); + return *shapes; + } + + string DebugString() const override { + return "BigtableSampleKeyPairsDatasetOp::Dataset"; + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + + private: + static MultiModeKeyRange MakeMultiModeKeyRange(string prefix, + string start_key, + string end_key) { + if (!start_key.empty()) { + return MultiModeKeyRange::FromRange(std::move(start_key), + std::move(end_key)); + } + return MultiModeKeyRange::FromPrefix(std::move(prefix)); + } + + BigtableTableResource& table() const { return *table_; } + + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + // Computes split points (`keys_`) to use when scanning the table. + // + // Initialize first retrieves the sample keys from the table (`row_keys`), + // as these often form good split points within the table. We then iterate + // over them, and copy them to `keys_` if they fall within the requested + // range to scan (`dataset()->key_range_`). Because the requested range + // might start between elements of the sampled keys list, care is taken to + // ensure we don't accidentally miss any subsets of the requested range by + // including `begin_key()` and `end_key()` as appropriate. + Status Initialize(IteratorContext* ctx) override { + grpc::Status status; + std::vector row_keys = + dataset()->table().table().SampleRows(status); + if (!status.ok()) { + return GrpcStatusToTfStatus(status); + } + + for (size_t i = 0; i < row_keys.size(); ++i) { + string row_key(row_keys[i].row_key); + if (dataset()->key_range_.contains_key(row_key)) { + // First key: check to see if we need to add the begin_key. + if (keys_.empty() && dataset()->key_range_.begin_key() != row_key) { + keys_.push_back(dataset()->key_range_.begin_key()); + } + keys_.push_back(std::move(row_key)); + } else if (!keys_.empty()) { + // If !keys_.empty(), then we have found at least one element of + // `row_keys` that is within our requested range + // (`dataset()->key_range_`). Because `row_keys` is sorted, if we + // have found an element that's not within our key range, then we + // are after our requested range (ranges are contiguous) and can end + // iteration early. + break; + } + } + + // Handle the case where we skip over the selected range entirely. + if (keys_.empty()) { + keys_.push_back(dataset()->key_range_.begin_key()); + } + + // Last key: check to see if we need to add the end_key. + if (keys_.back() != dataset()->key_range_.end_key()) { + keys_.push_back(dataset()->key_range_.end_key()); + } + return Status::OK(); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (index_ > keys_.size() - 2) { + *end_of_sequence = true; + return Status::OK(); + } + + *end_of_sequence = false; + out_tensors->emplace_back(ctx->allocator({}), DT_STRING, + TensorShape({})); + out_tensors->back().scalar()() = keys_[index_]; + + out_tensors->emplace_back(ctx->allocator({}), DT_STRING, + TensorShape({})); + out_tensors->back().scalar()() = keys_[index_ + 1]; + ++index_; + + return Status::OK(); + } + + private: + mutex mu_; + size_t index_ GUARDED_BY(mu_) = 0; + // Note: we store the keys_ on the iterator instead of the dataset + // because we want to re-sample the row keys in case there have been + // tablet rebalancing operations since the dataset was created. + // + // Note: keys_ is readonly after Initialize, and thus does not need a + // guarding lock. + std::vector keys_; + }; + + BigtableTableResource* const table_; + const MultiModeKeyRange key_range_; + }; +}; + +REGISTER_KERNEL_BUILDER( + Name("BigtableSampleKeyPairsDataset").Device(DEVICE_CPU), + BigtableSampleKeyPairsDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a759fb5063900199325304ccf83c52f3bdd7d702 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc @@ -0,0 +1,121 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableSampleKeysDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + *output = new Dataset(ctx, resource); + } + + private: + class Dataset : public DatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table) + : DatasetBase(DatasetContext(ctx)), table_(table) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::BigtableSampleKeys")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}}); + return *shapes; + } + + string DebugString() const override { + return "BigtableRangeKeyDatasetOp::Dataset"; + } + + BigtableTableResource* table() const { return table_; } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + ::grpc::Status status; + row_keys_ = dataset()->table()->table().SampleRows(status); + if (!status.ok()) { + row_keys_.clear(); + return GrpcStatusToTfStatus(status); + } + return Status::OK(); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (index_ < row_keys_.size()) { + out_tensors->emplace_back(ctx->allocator({}), DT_STRING, + TensorShape({})); + out_tensors->back().scalar()() = + string(row_keys_[index_].row_key); + *end_of_sequence = false; + index_++; + } else { + *end_of_sequence = true; + } + return Status::OK(); + } + + private: + mutex mu_; + size_t index_ = 0; + std::vector<::google::cloud::bigtable::RowKeySample> row_keys_; + }; + + BigtableTableResource* const table_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableSampleKeysDataset").Device(DEVICE_CPU), + BigtableSampleKeysDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..78a920b077680980a209ad8c30c09409a6f4ebf5 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc @@ -0,0 +1,227 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableScanDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefix", &prefix)); + string start_key; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "start_key", &start_key)); + string end_key; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "end_key", &end_key)); + + OP_REQUIRES(ctx, !(prefix.empty() && start_key.empty()), + errors::InvalidArgument( + "Either prefix or start_key must be specified")); + OP_REQUIRES(ctx, prefix.empty() || start_key.empty(), + errors::InvalidArgument( + "Only one of prefix and start_key can be provided")); + if (!prefix.empty()) { + OP_REQUIRES(ctx, end_key.empty(), + errors::InvalidArgument( + "If prefix is specified, end_key must be empty.")); + } + + std::vector column_families; + std::vector columns; + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "column_families", + &column_families)); + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "columns", &columns)); + OP_REQUIRES( + ctx, column_families.size() == columns.size(), + errors::InvalidArgument("len(columns) != len(column_families)")); + OP_REQUIRES(ctx, !column_families.empty(), + errors::InvalidArgument("`column_families` is empty")); + + float probability = 0; + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, "probability", &probability)); + OP_REQUIRES( + ctx, probability > 0 && probability <= 1, + errors::InvalidArgument( + "Probability outside the range of (0, 1]. Got: ", probability)); + + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + + const uint64 num_outputs = columns.size() + 1; + std::vector output_shapes; + output_shapes.reserve(num_outputs); + DataTypeVector output_types; + output_types.reserve(num_outputs); + for (uint64 i = 0; i < num_outputs; ++i) { + output_shapes.push_back({}); + output_types.push_back(DT_STRING); + } + + *output = new Dataset(ctx, resource, std::move(prefix), + std::move(start_key), std::move(end_key), + std::move(column_families), std::move(columns), + probability, output_types, std::move(output_shapes)); + } + + private: + class Dataset : public DatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, + string prefix, string start_key, string end_key, + std::vector column_families, + std::vector columns, float probability, + const DataTypeVector& output_types, + std::vector output_shapes) + : DatasetBase(DatasetContext(ctx)), + table_(table), + prefix_(std::move(prefix)), + start_key_(std::move(start_key)), + end_key_(std::move(end_key)), + column_families_(std::move(column_families)), + columns_(std::move(columns)), + column_family_regex_(RegexFromStringSet(column_families_)), + column_regex_(RegexFromStringSet(columns_)), + probability_(probability), + output_types_(output_types), + output_shapes_(std::move(output_shapes)) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::BigtableScan")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "BigtableScanDatasetOp::Dataset"; + } + + BigtableTableResource* table() const { return table_; } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + + private: + class Iterator : public BigtableReaderDatasetIterator { + public: + explicit Iterator(const Params& params) + : BigtableReaderDatasetIterator(params) {} + + ::google::cloud::bigtable::RowRange MakeRowRange() override { + if (!dataset()->prefix_.empty()) { + DCHECK(dataset()->start_key_.empty()); + return ::google::cloud::bigtable::RowRange::Prefix( + dataset()->prefix_); + } else { + DCHECK(!dataset()->start_key_.empty()) + << "Both prefix and start_key were empty!"; + return ::google::cloud::bigtable::RowRange::Range( + dataset()->start_key_, dataset()->end_key_); + } + } + ::google::cloud::bigtable::Filter MakeFilter() override { + // TODO(saeta): Investigate optimal ordering here. + return ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1), + ::google::cloud::bigtable::Filter::FamilyRegex( + dataset()->column_family_regex_), + ::google::cloud::bigtable::Filter::ColumnRegex( + dataset()->column_regex_), + dataset()->probability_ != 1.0 + ? ::google::cloud::bigtable::Filter::RowSample( + dataset()->probability_) + : ::google::cloud::bigtable::Filter::PassAllFilter()); + } + Status ParseRow(IteratorContext* ctx, + const ::google::cloud::bigtable::Row& row, + std::vector* out_tensors) override { + out_tensors->reserve(dataset()->columns_.size() + 1); + Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {}); + row_key_tensor.scalar()() = string(row.row_key()); + out_tensors->emplace_back(std::move(row_key_tensor)); + + if (row.cells().size() > 2 * dataset()->columns_.size()) { + LOG(WARNING) << "An excessive number of columns (" + << row.cells().size() + << ") were retrieved when reading row: " + << row.row_key(); + } + + for (uint64 i = 0; i < dataset()->columns_.size(); ++i) { + Tensor col_tensor(ctx->allocator({}), DT_STRING, {}); + bool found_column = false; + for (auto cell_itr = row.cells().begin(); + !found_column && cell_itr != row.cells().end(); ++cell_itr) { + if (cell_itr->family_name() == dataset()->column_families_[i] && + string(cell_itr->column_qualifier()) == + dataset()->columns_[i]) { + col_tensor.scalar()() = string(cell_itr->value()); + found_column = true; + } + } + if (!found_column) { + return errors::InvalidArgument( + "Column ", dataset()->column_families_[i], ":", + dataset()->columns_[i], " not found in row: ", row.row_key()); + } + out_tensors->emplace_back(std::move(col_tensor)); + } + return Status::OK(); + } + }; + + BigtableTableResource* table_; + const string prefix_; + const string start_key_; + const string end_key_; + const std::vector column_families_; + const std::vector columns_; + const string column_family_regex_; + const string column_regex_; + const float probability_; + const DataTypeVector output_types_; + const std::vector output_shapes_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableScanDataset").Device(DEVICE_CPU), + BigtableScanDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc new file mode 100644 index 0000000000000000000000000000000000000000..f083ce6f44b3c2a83d9b5d3235056eb94c4be4a8 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -0,0 +1,374 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" + +#include "google/bigtable/v2/data.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "re2/re2.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/util/ptr_util.h" +// #include "util/task/codes.pb.h" + +namespace tensorflow { +namespace { + +void UpdateRow(const ::google::bigtable::v2::Mutation& mut, + std::map* row) { + if (mut.has_set_cell()) { + CHECK(mut.set_cell().timestamp_micros() >= -1) + << "Timestamp_micros: " << mut.set_cell().timestamp_micros(); + auto col = + strings::Printf("%s:%s", mut.set_cell().family_name().c_str(), + string(mut.set_cell().column_qualifier()).c_str()); + (*row)[col] = string(mut.set_cell().value()); + } else if (mut.has_delete_from_column()) { + auto col = strings::Printf( + "%s:%s", mut.delete_from_column().family_name().c_str(), + string(mut.delete_from_column().column_qualifier()).c_str()); + row->erase(col); + } else if (mut.has_delete_from_family()) { + auto itr = row->lower_bound(mut.delete_from_family().family_name()); + auto prefix = + strings::Printf("%s:", mut.delete_from_family().family_name().c_str()); + while (itr != row->end() && itr->first.substr(0, prefix.size()) == prefix) { + row->erase(itr); + } + } else if (mut.has_delete_from_row()) { + row->clear(); + } else { + LOG(ERROR) << "Unknown mutation: " << mut.ShortDebugString(); + } +} + +} // namespace + +class SampleRowKeysResponse : public grpc::ClientReaderInterface< + google::bigtable::v2::SampleRowKeysResponse> { + public: + explicit SampleRowKeysResponse(BigtableTestClient* client) + : client_(client) {} + + bool NextMessageSize(uint32_t* sz) override { + mutex_lock l(mu_); + mutex_lock l2(client_->mu_); + if (num_messages_sent_ * 2 < client_->table_.rows.size()) { + *sz = 10000; // A sufficiently high enough value to not worry about. + return true; + } + return false; + } + + bool Read(google::bigtable::v2::SampleRowKeysResponse* resp) override { + // Send every other key from the table. + mutex_lock l(mu_); + mutex_lock l2(client_->mu_); + *resp = google::bigtable::v2::SampleRowKeysResponse(); + auto itr = client_->table_.rows.begin(); + for (uint64 i = 0; i < 2 * num_messages_sent_; ++i) { + ++itr; + if (itr == client_->table_.rows.end()) { + return false; + } + } + resp->set_row_key(itr->first); + resp->set_offset_bytes(100 * num_messages_sent_); + num_messages_sent_++; + return true; + } + + grpc::Status Finish() override { return grpc::Status::OK; } + + void WaitForInitialMetadata() override {} // Do nothing. + + private: + mutex mu_; + int64 num_messages_sent_ GUARDED_BY(mu_) = 0; + BigtableTestClient* client_; // Not owned. +}; + +class ReadRowsResponse : public grpc::ClientReaderInterface< + google::bigtable::v2::ReadRowsResponse> { + public: + ReadRowsResponse(BigtableTestClient* client, + google::bigtable::v2::ReadRowsRequest const& request) + : client_(client), request_(request) {} + + bool NextMessageSize(uint32_t* sz) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + *sz = 10000000; // A sufficiently high enough value to not worry about. + return true; + } + + bool Read(google::bigtable::v2::ReadRowsResponse* resp) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + sent_first_message_ = true; + RowFilter filter = MakeRowFilter(); + + mutex_lock l2(client_->mu_); + *resp = google::bigtable::v2::ReadRowsResponse(); + // Send all contents in first response. + for (auto itr = client_->table_.rows.begin(); + itr != client_->table_.rows.end(); ++itr) { + if (filter.AllowRow(itr->first)) { + ::google::bigtable::v2::ReadRowsResponse_CellChunk* chunk = nullptr; + bool sent_first = false; + for (auto col_itr = itr->second.columns.begin(); + col_itr != itr->second.columns.end(); ++col_itr) { + if (filter.AllowColumn(col_itr->first)) { + chunk = resp->add_chunks(); + if (!sent_first) { + sent_first = true; + chunk->set_row_key(itr->first); + } + auto colon_idx = col_itr->first.find(":"); + CHECK(colon_idx != string::npos) + << "No ':' found in: " << col_itr->first; + chunk->mutable_family_name()->set_value( + string(col_itr->first, 0, colon_idx)); + chunk->mutable_qualifier()->set_value( + string(col_itr->first, ++colon_idx)); + if (!filter.strip_values) { + chunk->set_value(col_itr->second); + } + if (filter.only_one_column) { + break; + } + } + } + if (sent_first) { + // We are sending this row, so set the commit flag on the last chunk. + chunk->set_commit_row(true); + } + } + } + return true; + } + + grpc::Status Finish() override { return grpc::Status::OK; } + + void WaitForInitialMetadata() override {} // Do nothing. + + private: + struct RowFilter { + std::set row_set; + std::vector> row_ranges; + double row_sample = 0.0; // Note: currently ignored. + std::unique_ptr col_filter; + bool strip_values = false; + bool only_one_column = false; + + bool AllowRow(const string& row) { + if (row_set.find(row) != row_set.end()) { + return true; + } + for (const auto& range : row_ranges) { + if (range.first <= row && range.second > row) { + return true; + } + } + return false; + } + + bool AllowColumn(const string& col) { + if (col_filter) { + return RE2::FullMatch(col, *col_filter); + } else { + return true; + } + } + }; + + RowFilter MakeRowFilter() { + RowFilter filter; + for (auto i = request_.rows().row_keys().begin(); + i != request_.rows().row_keys().end(); ++i) { + filter.row_set.insert(string(*i)); + } + for (auto i = request_.rows().row_ranges().begin(); + i != request_.rows().row_ranges().end(); ++i) { + if (i->start_key_case() != + google::bigtable::v2::RowRange::kStartKeyClosed || + i->end_key_case() != google::bigtable::v2::RowRange::kEndKeyOpen) { + LOG(WARNING) << "Skipping row range that cannot be processed: " + << i->ShortDebugString(); + continue; + } + filter.row_ranges.emplace_back(std::make_pair( + string(i->start_key_closed()), string(i->end_key_open()))); + } + if (request_.filter().has_chain()) { + string family_filter; + string qualifier_filter; + for (auto i = request_.filter().chain().filters().begin(); + i != request_.filter().chain().filters().end(); ++i) { + switch (i->filter_case()) { + case google::bigtable::v2::RowFilter::kFamilyNameRegexFilter: + family_filter = i->family_name_regex_filter(); + break; + case google::bigtable::v2::RowFilter::kColumnQualifierRegexFilter: + qualifier_filter = i->column_qualifier_regex_filter(); + break; + case google::bigtable::v2::RowFilter::kCellsPerColumnLimitFilter: + if (i->cells_per_column_limit_filter() != 1) { + LOG(ERROR) << "Unexpected cells_per_column_limit_filter: " + << i->cells_per_column_limit_filter(); + } + break; + case google::bigtable::v2::RowFilter::kStripValueTransformer: + filter.strip_values = i->strip_value_transformer(); + break; + case google::bigtable::v2::RowFilter::kRowSampleFilter: + LOG(INFO) << "Ignoring row sample directive."; + break; + case google::bigtable::v2::RowFilter::kPassAllFilter: + break; + case google::bigtable::v2::RowFilter::kCellsPerRowLimitFilter: + filter.only_one_column = true; + break; + default: + LOG(WARNING) << "Ignoring unknown filter type: " + << i->ShortDebugString(); + } + } + if (family_filter.empty() || qualifier_filter.empty()) { + LOG(WARNING) << "Missing regex!"; + } else { + string regex = strings::Printf("%s:%s", family_filter.c_str(), + qualifier_filter.c_str()); + filter.col_filter.reset(new RE2(regex)); + } + } else { + LOG(WARNING) << "Read request did not have a filter chain specified: " + << request_.filter().DebugString(); + } + return filter; + } + + mutex mu_; + bool sent_first_message_ GUARDED_BY(mu_) = false; + BigtableTestClient* client_; // Not owned. + const google::bigtable::v2::ReadRowsRequest request_; +}; + +class MutateRowsResponse : public grpc::ClientReaderInterface< + google::bigtable::v2::MutateRowsResponse> { + public: + explicit MutateRowsResponse(size_t num_successes) + : num_successes_(num_successes) {} + + bool NextMessageSize(uint32_t* sz) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + *sz = 10000000; // A sufficiently high enough value to not worry about. + return true; + } + + bool Read(google::bigtable::v2::MutateRowsResponse* resp) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + sent_first_message_ = true; + *resp = google::bigtable::v2::MutateRowsResponse(); + for (size_t i = 0; i < num_successes_; ++i) { + auto entry = resp->add_entries(); + entry->set_index(i); + } + return true; + } + + grpc::Status Finish() override { return grpc::Status::OK; } + + void WaitForInitialMetadata() override {} // Do nothing. + + private: + const size_t num_successes_; + + mutex mu_; + bool sent_first_message_ = false; +}; + +grpc::Status BigtableTestClient::MutateRow( + grpc::ClientContext* context, + google::bigtable::v2::MutateRowRequest const& request, + google::bigtable::v2::MutateRowResponse* response) { + mutex_lock l(mu_); + auto* row = &table_.rows[string(request.row_key())]; + for (int i = 0; i < request.mutations_size(); ++i) { + UpdateRow(request.mutations(i), &row->columns); + } + *response = google::bigtable::v2::MutateRowResponse(); + return grpc::Status::OK; +} +grpc::Status BigtableTestClient::CheckAndMutateRow( + grpc::ClientContext* context, + google::bigtable::v2::CheckAndMutateRowRequest const& request, + google::bigtable::v2::CheckAndMutateRowResponse* response) { + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, + "CheckAndMutateRow not implemented."); +} +grpc::Status BigtableTestClient::ReadModifyWriteRow( + grpc::ClientContext* context, + google::bigtable::v2::ReadModifyWriteRowRequest const& request, + google::bigtable::v2::ReadModifyWriteRowResponse* response) { + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, + "ReadModifyWriteRow not implemented."); +} +std::unique_ptr< + grpc::ClientReaderInterface> +BigtableTestClient::ReadRows( + grpc::ClientContext* context, + google::bigtable::v2::ReadRowsRequest const& request) { + return MakeUnique(this, request); +} + +std::unique_ptr< + grpc::ClientReaderInterface> +BigtableTestClient::SampleRowKeys( + grpc::ClientContext* context, + google::bigtable::v2::SampleRowKeysRequest const& request) { + return MakeUnique(this); +} +std::unique_ptr< + grpc::ClientReaderInterface> +BigtableTestClient::MutateRows( + grpc::ClientContext* context, + google::bigtable::v2::MutateRowsRequest const& request) { + mutex_lock l(mu_); + for (auto i = request.entries().begin(); i != request.entries().end(); ++i) { + auto* row = &table_.rows[string(i->row_key())]; + for (auto mut = i->mutations().begin(); mut != i->mutations().end(); + ++mut) { + UpdateRow(*mut, &row->columns); + } + } + return MakeUnique(request.entries_size()); +} + +std::shared_ptr BigtableTestClient::Channel() { + LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely " + "cause a crash!"; + return nullptr; +} +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h new file mode 100644 index 0000000000000000000000000000000000000000..dac2b16a216d26f02684c7401ed2ddaa4b7baddb --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h @@ -0,0 +1,87 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_ +#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_ + +#include "google/cloud/bigtable/data_client.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +class BigtableTestClient : public ::google::cloud::bigtable::DataClient { + public: + std::string const& project_id() const override { return project_id_; } + std::string const& instance_id() const override { return instance_id_; } + void reset() override { + mutex_lock l(mu_); + table_ = Table(); + } + + grpc::Status MutateRow( + grpc::ClientContext* context, + google::bigtable::v2::MutateRowRequest const& request, + google::bigtable::v2::MutateRowResponse* response) override; + + grpc::Status CheckAndMutateRow( + grpc::ClientContext* context, + google::bigtable::v2::CheckAndMutateRowRequest const& request, + google::bigtable::v2::CheckAndMutateRowResponse* response) override; + + grpc::Status ReadModifyWriteRow( + grpc::ClientContext* context, + google::bigtable::v2::ReadModifyWriteRowRequest const& request, + google::bigtable::v2::ReadModifyWriteRowResponse* response) override; + + std::unique_ptr< + grpc::ClientReaderInterface> + ReadRows(grpc::ClientContext* context, + google::bigtable::v2::ReadRowsRequest const& request) override; + std::unique_ptr< + grpc::ClientReaderInterface> + SampleRowKeys( + grpc::ClientContext* context, + google::bigtable::v2::SampleRowKeysRequest const& request) override; + + std::unique_ptr< + grpc::ClientReaderInterface> + MutateRows(grpc::ClientContext* context, + google::bigtable::v2::MutateRowsRequest const& request) override; + + std::shared_ptr Channel() override; + + private: + friend class SampleRowKeysResponse; + friend class ReadRowsResponse; + friend class MutateRowsResponse; + + struct Row { + string row_key; + std::map columns; + }; + struct Table { + std::map rows; + }; + + mutex mu_; + const std::string project_id_ = "testproject"; + const std::string instance_id_ = "testinstance"; + Table table_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_ diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..fa3e587b90147bd519586eef0cfb5e048b1b75be --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc @@ -0,0 +1,78 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace tensorflow { + +namespace { + +class BigtableTestClientOp : public OpKernel { + public: + explicit BigtableTestClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + ~BigtableTestClientOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + BigtableClientResource* resource; + OP_REQUIRES_OK( + ctx, + mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this, ctx](BigtableClientResource** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::shared_ptr client( + new BigtableTestClient()); + // Note: must make explicit copies to sequence + // them before the move of client. + string project_id = client->project_id(); + string instance_id = client->instance_id(); + *ret = new BigtableClientResource(std::move(project_id), + std::move(instance_id), + std::move(client)); + return Status::OK(); + })); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } + + private: + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableTestClient").Device(DEVICE_CPU), + BigtableTestClientOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..32611e2590d9a81f46d0b9dfc09fe7e0068e9671 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc @@ -0,0 +1,345 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" +#include "google/cloud/bigtable/internal/table.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +void WriteCell(const string& row, const string& family, const string& column, + const string& value, + ::google::cloud::bigtable::noex::Table* table) { + ::google::cloud::bigtable::SingleRowMutation mut(row); + mut.emplace_back(::google::cloud::bigtable::SetCell(family, column, value)); + table->Apply(std::move(mut)); +} + +TEST(BigtableTestClientTest, EmptyRowRead) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + ::google::cloud::bigtable::RowSet rowset; + rowset.Append("r1"); + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + EXPECT_EQ(rows.begin(), rows.end()) << "Some rows were returned in response!"; + EXPECT_TRUE(rows.Finish().ok()) << "Error reading rows."; +} + +TEST(BigtableTestClientTest, SingleRowWriteAndRead) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + + ::google::cloud::bigtable::RowSet rowset("r1"); + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + auto itr = rows.begin(); + EXPECT_NE(itr, rows.end()) << "No rows were returned in response!"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + EXPECT_EQ(itr, rows.end()); + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, MultiRowWriteAndSingleRowRead) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + ::google::cloud::bigtable::RowSet rowset("r1"); + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, MultiRowWriteAndRead) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + ::google::cloud::bigtable::RowSet rowset("r1", "r2", "r3"); + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v2"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v3"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, MultiRowWriteAndPrefixRead) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1)); + auto rows = + table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v2"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v3"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, ColumnFiltering) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + // Extra cells + WriteCell("r1", "f2", "c1", "v1", &table); + WriteCell("r2", "f2", "c1", "v2", &table); + WriteCell("r3", "f1", "c2", "v3", &table); + + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1), + ::google::cloud::bigtable::Filter::FamilyRegex("f1"), + ::google::cloud::bigtable::Filter::ColumnRegex("c1")); + auto rows = + table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v2"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v3"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, RowKeys) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + // Extra cells + WriteCell("r1", "f2", "c1", "v1", &table); + WriteCell("r2", "f2", "c1", "v2", &table); + WriteCell("r3", "f1", "c2", "v3", &table); + + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1), + ::google::cloud::bigtable::Filter::CellsRowLimit(1), + ::google::cloud::bigtable::Filter::StripValueTransformer()); + auto rows = + table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter); + auto itr = rows.begin(); + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), ""); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), ""); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), ""); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, SampleKeys) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + WriteCell("r4", "f1", "c1", "v4", &table); + WriteCell("r5", "f1", "c1", "v5", &table); + + grpc::Status status; + auto resp = table.SampleRows(status); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(3, resp.size()); + EXPECT_EQ("r1", string(resp[0].row_key)); + EXPECT_EQ(0, resp[0].offset_bytes); + EXPECT_EQ("r3", string(resp[1].row_key)); + EXPECT_EQ(100, resp[1].offset_bytes); + EXPECT_EQ("r5", string(resp[2].row_key)); + EXPECT_EQ(200, resp[2].offset_bytes); +} + +TEST(BigtableTestClientTest, SampleKeysShort) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + + grpc::Status status; + auto resp = table.SampleRows(status); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(1, resp.size()); + EXPECT_EQ("r1", string(resp[0].row_key)); +} + +TEST(BigtableTestClientTest, SampleKeysEvenNumber) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + WriteCell("r4", "f1", "c1", "v4", &table); + + grpc::Status status; + auto resp = table.SampleRows(status); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(2, resp.size()); + EXPECT_EQ("r1", string(resp[0].row_key)); + EXPECT_EQ("r3", string(resp[1].row_key)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..416b719e30aa5f2504449d151a48e95c9105c68b --- /dev/null +++ b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc @@ -0,0 +1,107 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +// TODO(saeta): Add support for setting ClientOptions values. +REGISTER_OP("BigtableClient") + .Attr("project_id: string") + .Attr("instance_id: string") + .Attr("connection_pool_size: int") + .Attr("max_receive_message_size: int = -1") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Output("client: resource") + .SetShapeFn(shape_inference::ScalarShape); + +// TODO(saeta): Add support for Application Profiles. +// See https://cloud.google.com/bigtable/docs/app-profiles for more info. +REGISTER_OP("BigtableTable") + .Input("client: resource") + .Attr("table_name: string") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Output("table: resource") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("DatasetToBigtable") + .Input("table: resource") + .Input("input_dataset: variant") + .Input("column_families: string") + .Input("columns: string") + .Input("timestamp: int64") + .SetShapeFn(shape_inference::NoOutputs); + +REGISTER_OP("BigtableLookupDataset") + .Input("keys_dataset: variant") + .Input("table: resource") + .Input("column_families: string") + .Input("columns: string") + .Output("handle: variant") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("BigtablePrefixKeyDataset") + .Input("table: resource") + .Input("prefix: string") + .Output("handle: variant") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("BigtableRangeKeyDataset") + .Input("table: resource") + .Input("start_key: string") + .Input("end_key: string") + .Output("handle: variant") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("BigtableSampleKeysDataset") + .Input("table: resource") + .Output("handle: variant") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("BigtableSampleKeyPairsDataset") + .Input("table: resource") + .Input("prefix: string") + .Input("start_key: string") + .Input("end_key: string") + .Output("handle: variant") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + +// TODO(saeta): Support continuing despite bad data (e.g. empty string, or +// skip incomplete row.) +REGISTER_OP("BigtableScanDataset") + .Input("table: resource") + .Input("prefix: string") + .Input("start_key: string") + .Input("end_key: string") + .Input("column_families: string") + .Input("columns: string") + .Input("probability: float") + .Output("handle: variant") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..f7d02458f63d547000f00b184b3d5e3c5007fb72 --- /dev/null +++ b/tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc @@ -0,0 +1,27 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("BigtableTestClient") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Output("client: resource") + .SetShapeFn(shape_inference::ScalarShape); + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py b/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..292d8f4e51abbbd89d68b47febd86b7297bb8ed2 --- /dev/null +++ b/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""This module contains tests for the bigtable integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e36f7f32c61b50047c0d9137427f2a24462b1c9a --- /dev/null +++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py @@ -0,0 +1,272 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Bigtable Ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import bigtable +from tensorflow.contrib.bigtable.ops import gen_bigtable_ops +from tensorflow.contrib.bigtable.ops import gen_bigtable_test_ops +from tensorflow.contrib.bigtable.python.ops import bigtable_api +from tensorflow.contrib.util import loader +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test +from tensorflow.python.util import compat + +_bigtable_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_bigtable_test.so")) + + +def _ListOfTuplesOfStringsToBytes(values): + return [(compat.as_bytes(i[0]), compat.as_bytes(i[1])) for i in values] + + +class BigtableOpsTest(test.TestCase): + COMMON_ROW_KEYS = ["r1", "r2", "r3"] + COMMON_VALUES = ["v1", "v2", "v3"] + + def setUp(self): + self._client = gen_bigtable_test_ops.bigtable_test_client() + table = gen_bigtable_ops.bigtable_table(self._client, "testtable") + self._table = bigtable.BigtableTable("testtable", None, table) + + def _makeSimpleDataset(self): + output_rows = dataset_ops.Dataset.from_tensor_slices(self.COMMON_ROW_KEYS) + output_values = dataset_ops.Dataset.from_tensor_slices(self.COMMON_VALUES) + return dataset_ops.Dataset.zip((output_rows, output_values)) + + def _writeCommonValues(self, sess): + output_ds = self._makeSimpleDataset() + write_op = self._table.write(output_ds, ["cf1"], ["c1"]) + sess.run(write_op) + + def runReadKeyTest(self, read_ds): + itr = read_ds.make_initializable_iterator() + n = itr.get_next() + expected = list(self.COMMON_ROW_KEYS) + expected.reverse() + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + for i in range(3): + output = sess.run(n) + want = expected.pop() + self.assertEqual( + compat.as_bytes(want), compat.as_bytes(output), + "Unequal at step %d: want: %s, got: %s" % (i, want, output)) + + def testReadPrefixKeys(self): + self.runReadKeyTest(self._table.keys_by_prefix_dataset("r")) + + def testReadRangeKeys(self): + self.runReadKeyTest(self._table.keys_by_range_dataset("r1", "r4")) + + def runScanTest(self, read_ds): + itr = read_ds.make_initializable_iterator() + n = itr.get_next() + expected_keys = list(self.COMMON_ROW_KEYS) + expected_keys.reverse() + expected_values = list(self.COMMON_VALUES) + expected_values.reverse() + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + for i in range(3): + output = sess.run(n) + want = expected_keys.pop() + self.assertEqual( + compat.as_bytes(want), compat.as_bytes(output[0]), + "Unequal keys at step %d: want: %s, got: %s" % (i, want, output[0])) + want = expected_values.pop() + self.assertEqual( + compat.as_bytes(want), compat.as_bytes(output[1]), + "Unequal values at step: %d: want: %s, got: %s" % (i, want, + output[1])) + + def testScanPrefixStringCol(self): + self.runScanTest(self._table.scan_prefix("r", cf1="c1")) + + def testScanPrefixListCol(self): + self.runScanTest(self._table.scan_prefix("r", cf1=["c1"])) + + def testScanPrefixTupleCol(self): + self.runScanTest(self._table.scan_prefix("r", columns=("cf1", "c1"))) + + def testScanRangeStringCol(self): + self.runScanTest(self._table.scan_range("r1", "r4", cf1="c1")) + + def testScanRangeListCol(self): + self.runScanTest(self._table.scan_range("r1", "r4", cf1=["c1"])) + + def testScanRangeTupleCol(self): + self.runScanTest(self._table.scan_range("r1", "r4", columns=("cf1", "c1"))) + + def testLookup(self): + ds = self._table.keys_by_prefix_dataset("r") + ds = ds.apply(self._table.lookup_columns(cf1="c1")) + itr = ds.make_initializable_iterator() + n = itr.get_next() + expected_keys = list(self.COMMON_ROW_KEYS) + expected_values = list(self.COMMON_VALUES) + expected_tuples = zip(expected_keys, expected_values) + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + for i, elem in enumerate(expected_tuples): + output = sess.run(n) + self.assertEqual( + compat.as_bytes(elem[0]), compat.as_bytes(output[0]), + "Unequal keys at step %d: want: %s, got: %s" % + (i, compat.as_bytes(elem[0]), compat.as_bytes(output[0]))) + self.assertEqual( + compat.as_bytes(elem[1]), compat.as_bytes(output[1]), + "Unequal values at step %d: want: %s, got: %s" % + (i, compat.as_bytes(elem[1]), compat.as_bytes(output[1]))) + + def testSampleKeys(self): + ds = self._table.sample_keys() + itr = ds.make_initializable_iterator() + n = itr.get_next() + expected_key = self.COMMON_ROW_KEYS[0] + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + output = sess.run(n) + self.assertEqual( + compat.as_bytes(self.COMMON_ROW_KEYS[0]), compat.as_bytes(output), + "Unequal keys: want: %s, got: %s" % (compat.as_bytes( + self.COMMON_ROW_KEYS[0]), compat.as_bytes(output))) + output = sess.run(n) + self.assertEqual( + compat.as_bytes(self.COMMON_ROW_KEYS[2]), compat.as_bytes(output), + "Unequal keys: want: %s, got: %s" % (compat.as_bytes( + self.COMMON_ROW_KEYS[2]), compat.as_bytes(output))) + with self.assertRaises(errors.OutOfRangeError): + sess.run(n) + + def runSampleKeyPairsTest(self, ds, expected_key_pairs): + itr = ds.make_initializable_iterator() + n = itr.get_next() + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + for i, elems in enumerate(expected_key_pairs): + output = sess.run(n) + self.assertEqual( + compat.as_bytes(elems[0]), compat.as_bytes(output[0]), + "Unequal key pair (first element) at step %d; want: %s, got %s" % + (i, compat.as_bytes(elems[0]), compat.as_bytes(output[0]))) + self.assertEqual( + compat.as_bytes(elems[1]), compat.as_bytes(output[1]), + "Unequal key pair (second element) at step %d; want: %s, got %s" % + (i, compat.as_bytes(elems[1]), compat.as_bytes(output[1]))) + with self.assertRaises(errors.OutOfRangeError): + sess.run(n) + + def testSampleKeyPairsSimplePrefix(self): + ds = bigtable_api._BigtableSampleKeyPairsDataset( + self._table, prefix="r", start="", end="") + expected_key_pairs = [("r", "r1"), ("r1", "r3"), ("r3", "s")] + self.runSampleKeyPairsTest(ds, expected_key_pairs) + + def testSampleKeyPairsSimpleRange(self): + ds = bigtable_api._BigtableSampleKeyPairsDataset( + self._table, prefix="", start="r1", end="r3") + expected_key_pairs = [("r1", "r3")] + self.runSampleKeyPairsTest(ds, expected_key_pairs) + + def testSampleKeyPairsSkipRangePrefix(self): + ds = bigtable_api._BigtableSampleKeyPairsDataset( + self._table, prefix="r2", start="", end="") + expected_key_pairs = [("r2", "r3")] + self.runSampleKeyPairsTest(ds, expected_key_pairs) + + def testSampleKeyPairsSkipRangeRange(self): + ds = bigtable_api._BigtableSampleKeyPairsDataset( + self._table, prefix="", start="r2", end="r3") + expected_key_pairs = [("r2", "r3")] + self.runSampleKeyPairsTest(ds, expected_key_pairs) + + def testSampleKeyPairsOffsetRanges(self): + ds = bigtable_api._BigtableSampleKeyPairsDataset( + self._table, prefix="", start="r2", end="r4") + expected_key_pairs = [("r2", "r3"), ("r3", "r4")] + self.runSampleKeyPairsTest(ds, expected_key_pairs) + + def testSampleKeyPairEverything(self): + ds = bigtable_api._BigtableSampleKeyPairsDataset( + self._table, prefix="", start="", end="") + expected_key_pairs = [("", "r1"), ("r1", "r3"), ("r3", "")] + self.runSampleKeyPairsTest(ds, expected_key_pairs) + + def testSampleKeyPairsPrefixAndStartKey(self): + ds = bigtable_api._BigtableSampleKeyPairsDataset( + self._table, prefix="r", start="r1", end="") + itr = ds.make_initializable_iterator() + with self.test_session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(itr.initializer) + + def testSampleKeyPairsPrefixAndEndKey(self): + ds = bigtable_api._BigtableSampleKeyPairsDataset( + self._table, prefix="r", start="", end="r3") + itr = ds.make_initializable_iterator() + with self.test_session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(itr.initializer) + + def testParallelScanPrefix(self): + ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1") + itr = ds.make_initializable_iterator() + n = itr.get_next() + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES)) + actual_values = [] + for _ in range(len(expected_values)): + output = sess.run(n) + actual_values.append(output) + with self.assertRaises(errors.OutOfRangeError): + sess.run(n) + self.assertItemsEqual( + _ListOfTuplesOfStringsToBytes(expected_values), + _ListOfTuplesOfStringsToBytes(actual_values)) + + def testParallelScanRange(self): + ds = self._table.parallel_scan_range(start="r1", end="r4", cf1="c1") + itr = ds.make_initializable_iterator() + n = itr.get_next() + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES)) + actual_values = [] + for _ in range(len(expected_values)): + output = sess.run(n) + actual_values.append(output) + with self.assertRaises(errors.OutOfRangeError): + sess.run(n) + self.assertItemsEqual( + _ListOfTuplesOfStringsToBytes(expected_values), + _ListOfTuplesOfStringsToBytes(actual_values)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bigtable/python/ops/__init__.py b/tensorflow/contrib/bigtable/python/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36d75b0d7068a650347a5e17f4727a5432d8752f --- /dev/null +++ b/tensorflow/contrib/bigtable/python/ops/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""This module contains the Python API for the Cloud Bigtable integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py new file mode 100644 index 0000000000000000000000000000000000000000..3e1b6228673fbdcb5a228a11532d29e6b2c817dc --- /dev/null +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -0,0 +1,746 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Python API for TensorFlow's Cloud Bigtable integration. + +TensorFlow has support for reading from and writing to Cloud Bigtable. To use +TensorFlow + Cloud Bigtable integration, first create a BigtableClient to +configure your connection to Cloud Bigtable, and then create a BigtableTable +object to allow you to create numerous `tf.data.Dataset`s to read data, or +write a `tf.data.Dataset` object to the underlying Cloud Bigtable table. + +For background on Cloud Bigtable, see: https://cloud.google.com/bigtable . +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six import iteritems +from six import string_types + +from tensorflow.contrib.bigtable.ops import gen_bigtable_ops +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.contrib.util import loader +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.platform import resource_loader + +_bigtable_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_bigtable.so")) + + +class BigtableClient(object): + """BigtableClient is the entrypoint for interacting with Cloud Bigtable in TF. + + BigtableClient encapsulates a connection to Cloud Bigtable, and exposes the + `table` method to open a Bigtable table. + """ + + def __init__(self, + project_id, + instance_id, + connection_pool_size=None, + max_receive_message_size=None): + """Creates a BigtableClient that can be used to open connections to tables. + + Args: + project_id: A string representing the GCP project id to connect to. + instance_id: A string representing the Bigtable instance to connect to. + connection_pool_size: (Optional.) A number representing the number of + concurrent connections to the Cloud Bigtable service to make. + max_receive_message_size: (Optional.) The maximum bytes received in a + single gRPC response. + + Raises: + ValueError: if the arguments are invalid (e.g. wrong type, or out of + expected ranges (e.g. negative).) + """ + if not isinstance(project_id, str): + raise ValueError("`project_id` must be a string") + self._project_id = project_id + + if not isinstance(instance_id, str): + raise ValueError("`instance_id` must be a string") + self._instance_id = instance_id + + if connection_pool_size is None: + connection_pool_size = -1 + elif connection_pool_size < 1: + raise ValueError("`connection_pool_size` must be positive") + + if max_receive_message_size is None: + max_receive_message_size = -1 + elif max_receive_message_size < 1: + raise ValueError("`max_receive_message_size` must be positive") + + self._connection_pool_size = connection_pool_size + + self._resource = gen_bigtable_ops.bigtable_client( + project_id, instance_id, connection_pool_size, max_receive_message_size) + + def table(self, name, snapshot=None): + """Opens a table and returns a `tf.contrib.bigtable.BigtableTable` object. + + Args: + name: A `tf.string` `tf.Tensor` name of the table to open. + snapshot: Either a `tf.string` `tf.Tensor` snapshot id, or `True` to + request the creation of a snapshot. (Note: currently unimplemented.) + + Returns: + A `tf.contrib.bigtable.BigtableTable` Python object representing the + operations available on the table. + """ + # TODO(saeta): Implement snapshot functionality. + table = gen_bigtable_ops.bigtable_table(self._resource, name) + return BigtableTable(name, snapshot, table) + + +class BigtableTable(object): + """BigtableTable is the entrypoint for reading and writing data in Cloud + Bigtable. + + This BigtableTable class is the Python representation of the Cloud Bigtable + table within TensorFlow. Methods on this class allow data to be read from and + written to the Cloud Bigtable service in flexible and high performance + manners. + """ + + # TODO(saeta): Investigate implementing tf.contrib.lookup.LookupInterface. + # TODO(saeta): Consider variant tensors instead of resources (while supporting + # connection pooling). + + def __init__(self, name, snapshot, resource): + self._name = name + self._snapshot = snapshot + self._resource = resource + + def lookup_columns(self, *args, **kwargs): + """Retrieves the values of columns for a dataset of keys. + + Example usage: + + ```python + table = bigtable_client.table("my_table") + key_dataset = table.get_keys_prefix("imagenet") + images = key_dataset.apply(table.lookup_columns(("cf1", "image"), + ("cf2", "label"), + ("cf2", "boundingbox"))) + training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128) + ``` + + Alternatively, you can use keyword arguments to specify the columns to + capture. Example (same as above, rewritten): + + ```python + table = bigtable_client.table("my_table") + key_dataset = table.get_keys_prefix("imagenet") + images = key_dataset.apply(table.lookup_columns( + cf1="image", cf2=("label", "boundingbox"))) + training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128) + ``` + + Note: certain `kwargs` keys are reserved, and thus, some column families + cannot be identified using the `kwargs` syntax. Instead, please use the + `args` syntax. This list includes: + + - 'name' + + Note: this list can change at any time. + + Args: + *args: A list of tuples containing (column family, column name) pairs. + **kwargs: Column families (keys) and column qualifiers (values). + + Returns: + A function that can be passed to `tf.data.Dataset.apply` to retrieve the + values of columns for the rows. + """ + table = self # Capture self + normalized = args + if normalized is None: + normalized = [] + if isinstance(normalized, tuple): + normalized = list(normalized) + for key, value in iteritems(kwargs): + if key == "name": + continue + if isinstance(value, str): + normalized.append((key, value)) + continue + for col in value: + normalized.append((key, col)) + + def _apply_fn(dataset): + # TODO(saeta): Verify dataset's types are correct! + return _BigtableLookupDataset(dataset, table, normalized) + + return _apply_fn + + def keys_by_range_dataset(self, start, end): + """Retrieves all row keys between start and end. + + Note: it does NOT retrieve the values of columns. + + Args: + start: The start row key. The row keys for rows after start (inclusive) + will be retrieved. + end: (Optional.) The end row key. Rows up to (but not including) end will + be retrieved. If end is None, all subsequent row keys will be retrieved. + + Returns: + A `tf.data.Dataset` containing `tf.string` Tensors corresponding to all + of the row keys between `start` and `end`. + """ + # TODO(saeta): Make inclusive / exclusive configurable? + if end is None: + end = "" + return _BigtableRangeKeyDataset(self, start, end) + + def keys_by_prefix_dataset(self, prefix): + """Retrieves the row keys matching a given prefix. + + Args: + prefix: All row keys that begin with `prefix` in the table will be + retrieved. + + Returns: + A `tf.data.Dataset`. containing `tf.string` Tensors corresponding to all + of the row keys matching that prefix. + """ + return _BigtablePrefixKeyDataset(self, prefix) + + def sample_keys(self): + """Retrieves a sampling of row keys from the Bigtable table. + + This dataset is most often used in conjunction with + `tf.contrib.data.parallel_interleave` to construct a set of ranges for + scanning in parallel. + + Returns: + A `tf.data.Dataset` returning string row keys. + """ + return _BigtableSampleKeysDataset(self) + + def scan_prefix(self, prefix, probability=None, columns=None, **kwargs): + """Retrieves row (including values) from the Bigtable service. + + Rows with row-key prefixed by `prefix` will be retrieved. + + Specifying the columns to retrieve for each row is done by either using + kwargs or in the columns parameter. To retrieve values of the columns "c1", + and "c2" from the column family "cfa", and the value of the column "c3" + from column family "cfb", the following datasets (`ds1`, and `ds2`) are + equivalent: + + ``` + table = # ... + ds1 = table.scan_prefix("row_prefix", columns=[("cfa", "c1"), + ("cfa", "c2"), + ("cfb", "c3")]) + ds2 = table.scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3") + ``` + + Note: only the latest value of a cell will be retrieved. + + Args: + prefix: The prefix all row keys must match to be retrieved for prefix- + based scans. + probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive). + A non-1 value indicates to probabilistically sample rows with the + provided probability. + columns: The columns to read. Note: most commonly, they are expressed as + kwargs. Use the columns value if you are using column families that are + reserved. The value of columns and kwargs are merged. Columns is a list + of tuples of strings ("column_family", "column_qualifier"). + **kwargs: The column families and columns to read. Keys are treated as + column_families, and values can be either lists of strings, or strings + that are treated as the column qualifier (column name). + + Returns: + A `tf.data.Dataset` returning the row keys and the cell contents. + + Raises: + ValueError: If the configured probability is unexpected. + """ + probability = _normalize_probability(probability) + normalized = _normalize_columns(columns, kwargs) + return _BigtableScanDataset(self, prefix, "", "", normalized, probability) + + def scan_range(self, start, end, probability=None, columns=None, **kwargs): + """Retrieves rows (including values) from the Bigtable service. + + Rows with row-keys between `start` and `end` will be retrieved. + + Specifying the columns to retrieve for each row is done by either using + kwargs or in the columns parameter. To retrieve values of the columns "c1", + and "c2" from the column family "cfa", and the value of the column "c3" + from column family "cfb", the following datasets (`ds1`, and `ds2`) are + equivalent: + + ``` + table = # ... + ds1 = table.scan_range("row_start", "row_end", columns=[("cfa", "c1"), + ("cfa", "c2"), + ("cfb", "c3")]) + ds2 = table.scan_range("row_start", "row_end", cfa=["c1", "c2"], cfb="c3") + ``` + + Note: only the latest value of a cell will be retrieved. + + Args: + start: The start of the range when scanning by range. + end: (Optional.) The end of the range when scanning by range. + probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive). + A non-1 value indicates to probabilistically sample rows with the + provided probability. + columns: The columns to read. Note: most commonly, they are expressed as + kwargs. Use the columns value if you are using column families that are + reserved. The value of columns and kwargs are merged. Columns is a list + of tuples of strings ("column_family", "column_qualifier"). + **kwargs: The column families and columns to read. Keys are treated as + column_families, and values can be either lists of strings, or strings + that are treated as the column qualifier (column name). + + Returns: + A `tf.data.Dataset` returning the row keys and the cell contents. + + Raises: + ValueError: If the configured probability is unexpected. + """ + probability = _normalize_probability(probability) + normalized = _normalize_columns(columns, kwargs) + return _BigtableScanDataset(self, "", start, end, normalized, probability) + + def parallel_scan_prefix(self, + prefix, + num_parallel_scans=None, + probability=None, + columns=None, + **kwargs): + """Retrieves row (including values) from the Bigtable service at high speed. + + Rows with row-key prefixed by `prefix` will be retrieved. This method is + similar to `scan_prefix`, but by contrast performs multiple sub-scans in + parallel in order to achieve higher performance. + + Note: The dataset produced by this method is not deterministic! + + Specifying the columns to retrieve for each row is done by either using + kwargs or in the columns parameter. To retrieve values of the columns "c1", + and "c2" from the column family "cfa", and the value of the column "c3" + from column family "cfb", the following datasets (`ds1`, and `ds2`) are + equivalent: + + ``` + table = # ... + ds1 = table.parallel_scan_prefix("row_prefix", columns=[("cfa", "c1"), + ("cfa", "c2"), + ("cfb", "c3")]) + ds2 = table.parallel_scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3") + ``` + + Note: only the latest value of a cell will be retrieved. + + Args: + prefix: The prefix all row keys must match to be retrieved for prefix- + based scans. + num_parallel_scans: (Optional.) The number of concurrent scans against the + Cloud Bigtable instance. + probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive). + A non-1 value indicates to probabilistically sample rows with the + provided probability. + columns: The columns to read. Note: most commonly, they are expressed as + kwargs. Use the columns value if you are using column families that are + reserved. The value of columns and kwargs are merged. Columns is a list + of tuples of strings ("column_family", "column_qualifier"). + **kwargs: The column families and columns to read. Keys are treated as + column_families, and values can be either lists of strings, or strings + that are treated as the column qualifier (column name). + + Returns: + A `tf.data.Dataset` returning the row keys and the cell contents. + + Raises: + ValueError: If the configured probability is unexpected. + """ + probability = _normalize_probability(probability) + normalized = _normalize_columns(columns, kwargs) + ds = _BigtableSampleKeyPairsDataset(self, prefix, "", "") + return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability, + normalized) + + def parallel_scan_range(self, + start, + end, + num_parallel_scans=None, + probability=None, + columns=None, + **kwargs): + """Retrieves rows (including values) from the Bigtable service. + + Rows with row-keys between `start` and `end` will be retrieved. This method + is similar to `scan_range`, but by contrast performs multiple sub-scans in + parallel in order to achieve higher performance. + + Note: The dataset produced by this method is not deterministic! + + Specifying the columns to retrieve for each row is done by either using + kwargs or in the columns parameter. To retrieve values of the columns "c1", + and "c2" from the column family "cfa", and the value of the column "c3" + from column family "cfb", the following datasets (`ds1`, and `ds2`) are + equivalent: + + ``` + table = # ... + ds1 = table.parallel_scan_range("row_start", + "row_end", + columns=[("cfa", "c1"), + ("cfa", "c2"), + ("cfb", "c3")]) + ds2 = table.parallel_scan_range("row_start", "row_end", + cfa=["c1", "c2"], cfb="c3") + ``` + + Note: only the latest value of a cell will be retrieved. + + Args: + start: The start of the range when scanning by range. + end: (Optional.) The end of the range when scanning by range. + num_parallel_scans: (Optional.) The number of concurrent scans against the + Cloud Bigtable instance. + probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive). + A non-1 value indicates to probabilistically sample rows with the + provided probability. + columns: The columns to read. Note: most commonly, they are expressed as + kwargs. Use the columns value if you are using column families that are + reserved. The value of columns and kwargs are merged. Columns is a list + of tuples of strings ("column_family", "column_qualifier"). + **kwargs: The column families and columns to read. Keys are treated as + column_families, and values can be either lists of strings, or strings + that are treated as the column qualifier (column name). + + Returns: + A `tf.data.Dataset` returning the row keys and the cell contents. + + Raises: + ValueError: If the configured probability is unexpected. + """ + probability = _normalize_probability(probability) + normalized = _normalize_columns(columns, kwargs) + ds = _BigtableSampleKeyPairsDataset(self, "", start, end) + return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability, + normalized) + + def write(self, dataset, column_families, columns, timestamp=None): + """Writes a dataset to the table. + + Args: + dataset: A `tf.data.Dataset` to be written to this table. It must produce + a list of number-of-columns+1 elements, all of which must be strings. + The first value will be used as the row key, and subsequent values will + be used as cell values for the corresponding columns from the + corresponding column_families and columns entries. + column_families: A `tf.Tensor` of `tf.string`s corresponding to the + column names to store the dataset's elements into. + columns: A `tf.Tensor` of `tf.string`s corresponding to the column names + to store the dataset's elements into. + timestamp: (Optional.) An int64 timestamp to write all the values at. + Leave as None to use server-provided timestamps. + + Returns: + A `tf.Operation` that can be run to perform the write. + + Raises: + ValueError: If there are unexpected or incompatible types, or if the + number of columns and column_families does not match the output of + `dataset`. + """ + if timestamp is None: + timestamp = -1 # Bigtable server provided timestamp. + for tensor_type in nest.flatten(dataset.output_types): + if tensor_type != dtypes.string: + raise ValueError("Not all elements of the dataset were `tf.string`") + for shape in nest.flatten(dataset.output_shapes): + if not shape.is_compatible_with(tensor_shape.scalar()): + raise ValueError("Not all elements of the dataset were scalars") + if len(column_families) != len(columns): + raise ValueError("len(column_families) != len(columns)") + if len(nest.flatten(dataset.output_types)) != len(columns) + 1: + raise ValueError("A column name must be specified for every component of " + "the dataset elements. (e.g.: len(columns) != " + "len(dataset.output_types))") + return gen_bigtable_ops.dataset_to_bigtable( + self._resource, + dataset._as_variant_tensor(), # pylint: disable=protected-access + column_families, + columns, + timestamp) + + def _make_parallel_scan_dataset(self, ds, num_parallel_scans, + normalized_probability, normalized_columns): + """Builds a parallel dataset from a given range. + + Args: + ds: A `_BigtableSampleKeyPairsDataset` returning ranges of keys to use. + num_parallel_scans: The number of concurrent parallel scans to use. + normalized_probability: A number between 0 and 1 for the keep probability. + normalized_columns: The column families and column qualifiers to retrieve. + + Returns: + A `tf.data.Dataset` representing the result of the parallel scan. + """ + if num_parallel_scans is None: + num_parallel_scans = 50 + + ds = ds.shuffle(buffer_size=10000) # TODO(saeta): Make configurable. + + def _interleave_fn(start, end): + return _BigtableScanDataset( + self, + prefix="", + start=start, + end=end, + normalized=normalized_columns, + probability=normalized_probability) + + # Note prefetch_input_elements must be set in order to avoid rpc timeouts. + ds = ds.apply( + interleave_ops.parallel_interleave( + _interleave_fn, + cycle_length=num_parallel_scans, + sloppy=True, + prefetch_input_elements=1)) + return ds + + +def _normalize_probability(probability): + if probability is None: + probability = 1.0 + if isinstance(probability, float) and (probability <= 0.0 or + probability > 1.0): + raise ValueError("probability must be in the range (0, 1].") + return probability + + +def _normalize_columns(columns, provided_kwargs): + """Converts arguments (columns, and kwargs dict) to C++ representation. + + Args: + columns: a datastructure containing the column families and qualifier to + retrieve. Valid types include (1) None, (2) list of tuples, (3) a tuple of + strings. + provided_kwargs: a dictionary containing the column families and qualifiers + to retrieve + + Returns: + A list of pairs of column family+qualifier to retrieve. + + Raises: + ValueError: If there are no cells to retrieve or the columns are in an + incorrect format. + """ + normalized = columns + if normalized is None: + normalized = [] + if isinstance(normalized, tuple): + if len(normalized) == 2: + normalized = [normalized] + else: + raise ValueError("columns was a tuple of inappropriate length") + for key, value in iteritems(provided_kwargs): + if key == "name": + continue + if isinstance(value, string_types): + normalized.append((key, value)) + continue + for col in value: + normalized.append((key, col)) + if not normalized: + raise ValueError("At least one column + column family must be specified.") + return normalized + + +class _BigtableKeyDataset(dataset_ops.Dataset): + """_BigtableKeyDataset is an abstract class representing the keys of a table. + """ + + def __init__(self, table): + """Constructs a _BigtableKeyDataset. + + Args: + table: a Bigtable class. + """ + super(_BigtableKeyDataset, self).__init__() + self._table = table + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.TensorShape([]) + + @property + def output_types(self): + return dtypes.string + + +class _BigtablePrefixKeyDataset(_BigtableKeyDataset): + """_BigtablePrefixKeyDataset represents looking up keys by prefix. + """ + + def __init__(self, table, prefix): + super(_BigtablePrefixKeyDataset, self).__init__(table) + self._prefix = prefix + + def _as_variant_tensor(self): + return gen_bigtable_ops.bigtable_prefix_key_dataset( + table=self._table._resource, # pylint: disable=protected-access + prefix=self._prefix) + + +class _BigtableRangeKeyDataset(_BigtableKeyDataset): + """_BigtableRangeKeyDataset represents looking up keys by range. + """ + + def __init__(self, table, start, end): + super(_BigtableRangeKeyDataset, self).__init__(table) + self._start = start + self._end = end + + def _as_variant_tensor(self): + return gen_bigtable_ops.bigtable_range_key_dataset( + table=self._table._resource, # pylint: disable=protected-access + start_key=self._start, + end_key=self._end) + + +class _BigtableSampleKeysDataset(_BigtableKeyDataset): + """_BigtableSampleKeysDataset represents a sampling of row keys. + """ + + # TODO(saeta): Expose the data size offsets into the keys. + + def __init__(self, table): + super(_BigtableSampleKeysDataset, self).__init__(table) + + def _as_variant_tensor(self): + return gen_bigtable_ops.bigtable_sample_keys_dataset( + table=self._table._resource) # pylint: disable=protected-access + + +class _BigtableLookupDataset(dataset_ops.Dataset): + """_BigtableLookupDataset represents a dataset that retrieves values for keys. + """ + + def __init__(self, dataset, table, normalized): + self._num_outputs = len(normalized) + 1 # 1 for row key + self._dataset = dataset + self._table = table + self._normalized = normalized + self._column_families = [i[0] for i in normalized] + self._columns = [i[1] for i in normalized] + + @property + def output_classes(self): + return tuple([ops.Tensor] * self._num_outputs) + + @property + def output_shapes(self): + return tuple([tensor_shape.TensorShape([])] * self._num_outputs) + + @property + def output_types(self): + return tuple([dtypes.string] * self._num_outputs) + + def _as_variant_tensor(self): + # pylint: disable=protected-access + return gen_bigtable_ops.bigtable_lookup_dataset( + keys_dataset=self._dataset._as_variant_tensor(), + table=self._table._resource, + column_families=self._column_families, + columns=self._columns) + + +class _BigtableScanDataset(dataset_ops.Dataset): + """_BigtableScanDataset represents a dataset that retrieves keys and values. + """ + + def __init__(self, table, prefix, start, end, normalized, probability): + self._table = table + self._prefix = prefix + self._start = start + self._end = end + self._column_families = [i[0] for i in normalized] + self._columns = [i[1] for i in normalized] + self._probability = probability + self._num_outputs = len(normalized) + 1 # 1 for row key + + @property + def output_classes(self): + return tuple([ops.Tensor] * self._num_outputs) + + @property + def output_shapes(self): + return tuple([tensor_shape.TensorShape([])] * self._num_outputs) + + @property + def output_types(self): + return tuple([dtypes.string] * self._num_outputs) + + def _as_variant_tensor(self): + return gen_bigtable_ops.bigtable_scan_dataset( + table=self._table._resource, # pylint: disable=protected-access + prefix=self._prefix, + start_key=self._start, + end_key=self._end, + column_families=self._column_families, + columns=self._columns, + probability=self._probability) + + +class _BigtableSampleKeyPairsDataset(dataset_ops.Dataset): + """_BigtableSampleKeyPairsDataset returns key pairs from a Bigtable table. + """ + + def __init__(self, table, prefix, start, end): + self._table = table + self._prefix = prefix + self._start = start + self._end = end + + @property + def output_classes(self): + return (ops.Tensor, ops.Tensor) + + @property + def output_shapes(self): + return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) + + @property + def output_types(self): + return (dtypes.string, dtypes.string) + + def _as_variant_tensor(self): + # pylint: disable=protected-access + return gen_bigtable_ops.bigtable_sample_key_pairs_dataset( + table=self._table._resource, + prefix=self._prefix, + start_key=self._start, + end_key=self._end) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index 8cff1a3bb1d11aff6a264636291a7149b40de516..5fcb19a47aac492d49b0d8e99af5699bae2ad9f0 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -15,8 +15,9 @@ py_library( srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ - "custom_export_strategy", + ":custom_export_strategy", ":custom_loss_head", + ":distillation_loss", ":estimator", ":model", ":trainer_hooks", @@ -144,7 +145,9 @@ py_library( srcs = ["dnn_tree_combined_estimator.py"], srcs_version = "PY2AND3", deps = [ + ":distillation_loss", ":estimator_utils", + ":model", ":trainer_hooks", "//tensorflow/contrib/boosted_trees:gbdt_batch", "//tensorflow/contrib/boosted_trees:model_ops_py", @@ -156,6 +159,17 @@ py_library( ], ) +py_library( + name = "distillation_loss", + srcs = ["distillation_loss.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/learn", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + ], +) + py_test( name = "dnn_tree_combined_estimator_test", size = "medium", @@ -177,7 +191,7 @@ py_test( py_test( name = "estimator_test", - size = "medium", + size = "large", srcs = ["estimator_test.py"], srcs_version = "PY2AND3", tags = [ diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index 62f1f4122b05b56a708823df4246d618bd3fa5d4..78232fa0a6e2311c13d4f35acffc3486a9a28803 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import ops from tensorflow.python.platform import gfile from tensorflow.python.saved_model import loader as saved_model_loader from tensorflow.python.saved_model import tag_constants +from tensorflow.python.util import compat _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE = "%s_%d" @@ -88,10 +89,12 @@ def make_custom_export_strategy(name, len(sparse_float_indices), len(sparse_int_indices)) sorted_by_importance = sorted( feature_importances.items(), key=lambda x: -x[1]) - assets_dir = os.path.join(result_dir, "assets.extra") + assets_dir = os.path.join( + compat.as_bytes(result_dir), compat.as_bytes("assets.extra")) gfile.MakeDirs(assets_dir) - with gfile.GFile(os.path.join(assets_dir, "feature_importances"), - "w") as f: + with gfile.GFile(os.path.join( + compat.as_bytes(assets_dir), + compat.as_bytes("feature_importances")), "w") as f: f.write("\n".join("%s, %f" % (k, v) for k, v in sorted_by_importance)) return result_dir diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/distillation_loss.py b/tensorflow/contrib/boosted_trees/estimator_batch/distillation_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..9aacc5534329d1302b25dcfab678f9adb8f773f6 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/estimator_batch/distillation_loss.py @@ -0,0 +1,75 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utill functions for distillation loss. + +The distillation loss_fn will be called with the following: + +Args: + dnn_logits: Tensor of logits from the dnn, treated as the "target". This will + be the output of a call to tf.stop_gradient(). + tree_logits: Tensor of logits from the tree, treated as the "predictions". + example_weights: Tensor of example weights, or a single scalar. + +Returns: + A scalar indicating the reduced loss for that batch of examples. + +Note: we calls the loss_fn defined in contrib head, which is computing two +losses, first one for training and second one for reporting. We only take the +first one here. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn + + +def _logits_to_label_for_tree(logits, n_classes): + if n_classes == 2: + return math_ops.sigmoid(logits) + else: + return nn.softmax(logits) + + +def create_dnn_to_tree_squared_loss_fn(n_classes): + """Returns a squared loss function for dnn to tree distillation.""" + + def _dnn_to_tree_squared_loss(dnn_logits, tree_logits, example_weights): + return head_lib._mean_squared_loss( # pylint: disable=protected-access + labels=_logits_to_label_for_tree(dnn_logits, n_classes), + logits=_logits_to_label_for_tree(tree_logits, n_classes), + weights=example_weights)[0] + + return _dnn_to_tree_squared_loss + + +def create_dnn_to_tree_cross_entropy_loss_fn(n_classes): + """Returns a cross entropy loss function for dnn to tree distillation.""" + + def _dnn_to_tree_cross_entropy_loss(dnn_logits, tree_logits, example_weights): + if n_classes == 2: + return head_lib._log_loss_with_two_classes( # pylint: disable=protected-access + labels=_logits_to_label_for_tree(dnn_logits, n_classes), + logits=tree_logits, + weights=example_weights)[0] + else: + return head_lib._softmax_cross_entropy_loss( # pylint: disable=protected-access + labels=_logits_to_label_for_tree(dnn_logits, n_classes), + logits=tree_logits, + weights=example_weights)[0] + + return _dnn_to_tree_cross_entropy_loss diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index 911d87fa10570382ee5f03edfc1bfd1d116c8360..194a5c8754cb0ab2db299e3fb5c998c0f27f8435 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -24,7 +24,10 @@ from __future__ import division from __future__ import print_function import six + from tensorflow.contrib import layers +from tensorflow.contrib.boosted_trees.estimator_batch import model +from tensorflow.contrib.boosted_trees.estimator_batch import distillation_loss from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks from tensorflow.contrib.boosted_trees.python.ops import model_ops @@ -32,14 +35,17 @@ from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batc from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.python.estimator import estimator as core_estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.python.feature_column import feature_column as feature_column_lib from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import nn from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import training_util @@ -58,26 +64,30 @@ def _add_hidden_layer_summary(value, tag): summary.histogram("%s_activation" % tag, value) -def _dnn_tree_combined_model_fn(features, - labels, - mode, - head, - dnn_hidden_units, - dnn_feature_columns, - tree_learner_config, - num_trees, - tree_examples_per_layer, - config=None, - dnn_optimizer="Adagrad", - dnn_activation_fn=nn.relu, - dnn_dropout=None, - dnn_input_layer_partitioner=None, - dnn_input_layer_to_tree=True, - dnn_steps_to_train=10000, - predict_with_tree_only=False, - tree_feature_columns=None, - tree_center_bias=False, - use_core_versions=False): +def _dnn_tree_combined_model_fn( + features, + labels, + mode, + head, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + config=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + predict_with_tree_only=False, + tree_feature_columns=None, + tree_center_bias=False, + dnn_to_tree_distillation_param=None, + use_core_versions=False, + output_type=model.ModelBuilderOutputType.MODEL_FN_OPS, + override_global_step_value=None): """DNN and GBDT combined model_fn. Args: @@ -117,8 +127,21 @@ def _dnn_tree_combined_model_fn(features, set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the + float defines the weight of the distillation loss, and the loss_fn, for + computing distillation loss, takes dnn_logits, tree_logits and weight + tensor. If the entire tuple is None, no distillation will be applied. If + only the loss_fn is None, we will take the sigmoid/softmax cross entropy + loss be default. When distillation is applied, `predict_with_tree_only` + will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec + (new interface). + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. Returns: A `ModelFnOps` object. @@ -132,12 +155,22 @@ def _dnn_tree_combined_model_fn(features, if not dnn_feature_columns: raise ValueError("dnn_feature_columns must be specified") + if dnn_to_tree_distillation_param: + if not predict_with_tree_only: + logging.warning("update predict_with_tree_only to True since distillation" + "is specified.") + predict_with_tree_only = True + # Build DNN Logits. dnn_parent_scope = "dnn" dnn_partitioner = dnn_input_layer_partitioner or ( partitioned_variables.min_max_variable_partitioner( max_partitions=config.num_ps_replicas, min_slice_size=64 << 20)) + if (output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC and + not use_core_versions): + raise ValueError("You must use core versions with Estimator Spec") + with variable_scope.variable_scope( dnn_parent_scope, values=tuple(six.itervalues(features)), @@ -217,7 +250,8 @@ def _dnn_tree_combined_model_fn(features, learner_config=tree_learner_config, feature_columns=tree_feature_columns, logits_dimension=head.logits_dimension, - features=tree_features) + features=tree_features, + use_core_columns=use_core_versions) with ops.name_scope("gbdt"): predictions_dict = gbdt_model.predict(mode) @@ -225,6 +259,25 @@ def _dnn_tree_combined_model_fn(features, def _tree_train_op_fn(loss): """Returns the op to optimize the loss.""" + if dnn_to_tree_distillation_param: + loss_weight, loss_fn = dnn_to_tree_distillation_param + weight_tensor = head_lib._weight_tensor( # pylint: disable=protected-access + features, head.weight_column_name) + dnn_logits_fixed = array_ops.stop_gradient(dnn_logits) + + if loss_fn is None: + # we create the loss_fn similar to the head loss_fn for + # multi_class_head used previously as the default one. + n_classes = 2 if head.logits_dimension == 1 else head.logits_dimension + loss_fn = distillation_loss.create_dnn_to_tree_cross_entropy_loss_fn( + n_classes) + + dnn_to_tree_distillation_loss = loss_weight * loss_fn( + dnn_logits_fixed, tree_logits, weight_tensor) + summary.scalar("dnn_to_tree_distillation_loss", + dnn_to_tree_distillation_loss) + loss += dnn_to_tree_distillation_loss + update_op = gbdt_model.train(loss, predictions_dict, labels) with ops.control_dependencies( [update_op]), (ops.colocate_with(global_step)): @@ -232,7 +285,7 @@ def _dnn_tree_combined_model_fn(features, return update_op if predict_with_tree_only: - if mode == model_fn.ModeKeys.TRAIN or mode == model_fn.ModeKeys.PREDICT: + if mode == model_fn.ModeKeys.TRAIN or mode == model_fn.ModeKeys.INFER: tree_train_logits = tree_logits else: tree_train_logits = control_flow_ops.cond( @@ -247,63 +300,98 @@ def _dnn_tree_combined_model_fn(features, del loss return control_flow_ops.no_op() - if use_core_versions: - model_fn_ops = head.create_estimator_spec( - features=features, - mode=mode, - labels=labels, - train_op_fn=_no_train_op_fn, - logits=tree_train_logits) - dnn_train_op = head.create_estimator_spec( - features=features, - mode=mode, - labels=labels, - train_op_fn=_dnn_train_op_fn, - logits=dnn_logits) - dnn_train_op = estimator_utils.estimator_spec_to_model_fn_ops( - dnn_train_op).train_op + if tree_center_bias: + num_trees += 1 + finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() - tree_train_op = head.create_estimator_spec( - features=tree_features, - mode=mode, - labels=labels, - train_op_fn=_tree_train_op_fn, - logits=tree_train_logits) - tree_train_op = estimator_utils.estimator_spec_to_model_fn_ops( - tree_train_op).train_op + if output_type == model.ModelBuilderOutputType.MODEL_FN_OPS: + if use_core_versions: + model_fn_ops = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits) + dnn_train_op = estimator_utils.estimator_spec_to_model_fn_ops( + dnn_train_op).train_op - model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops) - else: - model_fn_ops = head.create_model_fn_ops( + tree_train_op = head.create_estimator_spec( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits) + tree_train_op = estimator_utils.estimator_spec_to_model_fn_ops( + tree_train_op).train_op + + model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops( + model_fn_ops) + else: + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits).train_op + tree_train_op = head.create_model_fn_ops( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits).train_op + + # Add the hooks + model_fn_ops.training_hooks.extend([ + trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train, + tree_train_op), + trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, + finalized_trees, + override_global_step_value) + ]) + return model_fn_ops + + elif output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC: + fusion_spec = head.create_estimator_spec( features=features, mode=mode, labels=labels, train_op_fn=_no_train_op_fn, logits=tree_train_logits) - dnn_train_op = head.create_model_fn_ops( + dnn_spec = head.create_estimator_spec( features=features, mode=mode, labels=labels, train_op_fn=_dnn_train_op_fn, - logits=dnn_logits).train_op - tree_train_op = head.create_model_fn_ops( + logits=dnn_logits) + tree_spec = head.create_estimator_spec( features=tree_features, mode=mode, labels=labels, train_op_fn=_tree_train_op_fn, - logits=tree_train_logits).train_op - - if tree_center_bias: - num_trees += 1 - finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() - - model_fn_ops.training_hooks.extend([ - trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train, - tree_train_op), - trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, finalized_trees) - ]) + logits=tree_train_logits) - return model_fn_ops + training_hooks = [ + trainer_hooks.SwitchTrainOp(dnn_spec.train_op, dnn_steps_to_train, + tree_spec.train_op), + trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, + finalized_trees, + override_global_step_value) + ] + fusion_spec = fusion_spec._replace(training_hooks=training_hooks + + list(fusion_spec.training_hooks)) + return fusion_spec class DNNBoostedTreeCombinedClassifier(estimator.Estimator): @@ -331,7 +419,9 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, - use_core_versions=False): + dnn_to_tree_distillation_param=None, + use_core_versions=False, + override_global_step_value=None): """Initializes a DNNBoostedTreeCombinedClassifier instance. Args: @@ -378,8 +468,19 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the + float defines the weight of the distillation loss, and the loss_fn, for + computing distillation loss, takes dnn_logits, tree_logits and weight + tensor. If the entire tuple is None, no distillation will be applied. If + only the loss_fn is None, we will take the sigmoid/softmax cross entropy + loss be default. When distillation is applied, `predict_with_tree_only` + will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. """ head = head_lib.multi_class_head( n_classes=n_classes, @@ -409,7 +510,9 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): predict_with_tree_only=predict_with_tree_only, tree_feature_columns=tree_feature_columns, tree_center_bias=tree_center_bias, - use_core_versions=use_core_versions) + dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, + use_core_versions=use_core_versions, + override_global_step_value=override_global_step_value) super(DNNBoostedTreeCombinedClassifier, self).__init__( model_fn=_model_fn, @@ -442,7 +545,9 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, - use_core_versions=False): + dnn_to_tree_distillation_param=None, + use_core_versions=False, + override_global_step_value=None): """Initializes a DNNBoostedTreeCombinedRegressor instance. Args: @@ -489,8 +594,19 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the + float defines the weight of the distillation loss, and the loss_fn, for + computing distillation loss, takes dnn_logits, tree_logits and weight + tensor. If the entire tuple is None, no distillation will be applied. If + only the loss_fn is None, we will take the sigmoid/softmax cross entropy + loss be default. When distillation is applied, `predict_with_tree_only` + will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. """ head = head_lib.regression_head( label_name=label_name, @@ -525,7 +641,9 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): predict_with_tree_only=predict_with_tree_only, tree_feature_columns=tree_feature_columns, tree_center_bias=tree_center_bias, - use_core_versions=use_core_versions) + dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, + use_core_versions=use_core_versions, + override_global_step_value=override_global_step_value) super(DNNBoostedTreeCombinedRegressor, self).__init__( model_fn=_model_fn, @@ -559,7 +677,9 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, - use_core_versions=False): + dnn_to_tree_distillation_param=None, + use_core_versions=False, + override_global_step_value=None): """Initializes a DNNBoostedTreeCombinedEstimator instance. Args: @@ -601,8 +721,19 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the + float defines the weight of the distillation loss, and the loss_fn, for + computing distillation loss, takes dnn_logits, tree_logits and weight + tensor. If the entire tuple is None, no distillation will be applied. If + only the loss_fn is None, we will take the sigmoid/softmax cross entropy + loss be default. When distillation is applied, `predict_with_tree_only` + will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. """ def _model_fn(features, labels, mode, config): @@ -626,10 +757,110 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): predict_with_tree_only=predict_with_tree_only, tree_feature_columns=tree_feature_columns, tree_center_bias=tree_center_bias, - use_core_versions=use_core_versions) + dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, + use_core_versions=use_core_versions, + override_global_step_value=override_global_step_value) super(DNNBoostedTreeCombinedEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn) + + +class CoreDNNBoostedTreeCombinedEstimator(core_estimator.Estimator): + """Initializes a core version of DNNBoostedTreeCombinedEstimator. + + Args: + dnn_hidden_units: List of hidden units per layer for DNN. + dnn_feature_columns: An iterable containing all the feature columns + used by the model's DNN. + tree_learner_config: A config for the tree learner. + num_trees: Number of trees to grow model to after training DNN. + tree_examples_per_layer: Number of examples to accumulate before + growing the tree a layer. This value has a big impact on model + quality and should be set equal to the number of examples in + training dataset if possible. It can also be a function that computes + the number of examples based on the depth of the layer that's + being built. + head: `Head` instance. + model_dir: Directory for model exports. + config: `RunConfig` of the estimator. + dnn_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the DNN. If `None`, will use the Adagrad + optimizer with default learning rate. + dnn_activation_fn: Activation function applied to each layer of the DNN. + If `None`, will use `tf.nn.relu`. + dnn_dropout: When not `None`, the probability to drop out a given + unit in the DNN. + dnn_input_layer_partitioner: Partitioner for input layer of the DNN. + Defaults to `min_max_variable_partitioner` with `min_slice_size` + 64 << 20. + dnn_input_layer_to_tree: Whether to provide the DNN's input layer + as a feature to the tree. + dnn_steps_to_train: Number of steps to train dnn for before switching + to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. + tree_feature_columns: An iterable containing all the feature columns + used by the model's boosted trees. If dnn_input_layer_to_tree is + set to True, these features are in addition to dnn_feature_columns. + tree_center_bias: Whether a separate tree should be created for + first fitting the bias. + dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the + float defines the weight of the distillation loss, and the loss_fn, for + computing distillation loss, takes dnn_logits, tree_logits and weight + tensor. If the entire tuple is None, no distillation will be applied. If + only the loss_fn is None, we will take the sigmoid/softmax cross entropy + loss be default. When distillation is applied, `predict_with_tree_only` + will be set to True. + """ + + def __init__(self, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + head, + model_dir=None, + config=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + predict_with_tree_only=False, + tree_feature_columns=None, + tree_center_bias=False, + dnn_to_tree_distillation_param=None): + + def _model_fn(features, labels, mode, config): + return _dnn_tree_combined_model_fn( + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, + output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC, + use_core_versions=True, + override_global_step_value=None) + + super(CoreDNNBoostedTreeCombinedEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py index f495edc62f0909880c170ccb4cf5d11e3f20f55c..839eedd3a87ccaa1faecd1966fe5907d682cac02 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py @@ -28,10 +28,11 @@ from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops.losses import losses from tensorflow.python.platform import googletest - +from tensorflow.python.training import checkpoint_utils def _train_input_fn(): features = { @@ -131,6 +132,97 @@ class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): classifier.fit(input_fn=_train_input_fn, steps=15) classifier.evaluate(input_fn=_eval_input_fn, steps=1) + def testFitAndEvaluateWithDistillation(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.DNNBoostedTreeCombinedClassifier( + dnn_hidden_units=[1], + dnn_feature_columns=[feature_column.real_valued_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + n_classes=2, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=False, + tree_feature_columns=[feature_column.real_valued_column("x")], + dnn_to_tree_distillation_param=(1, None)) + + classifier.fit(input_fn=_train_input_fn, steps=15) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + + +class CoreDNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): + + def _assert_checkpoint(self, model_dir, global_step): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) + + def testTrainEvaluateInferDoesNotThrowErrorWithNoDnnInput(self): + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + est = estimator.CoreDNNBoostedTreeCombinedEstimator( + head=head_fn, + dnn_hidden_units=[1], + dnn_feature_columns=[core_feature_column.numeric_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=False, + tree_feature_columns=[core_feature_column.numeric_column("x")]) + + # Train for a few steps. + est.train(input_fn=_train_input_fn, steps=1000) + # 10 steps for dnn, 3 for 1 tree of depth 3 + 1 after the tree finished + self._assert_checkpoint(est.model_dir, global_step=14) + res = est.evaluate(input_fn=_eval_input_fn, steps=1) + self.assertLess(0.5, res["auc"]) + est.predict(input_fn=_eval_input_fn) + + def testTrainEvaluateInferDoesNotThrowErrorWithDnnInput(self): + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + est = estimator.CoreDNNBoostedTreeCombinedEstimator( + head=head_fn, + dnn_hidden_units=[1], + dnn_feature_columns=[core_feature_column.numeric_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=True, + tree_feature_columns=[]) + + # Train for a few steps. + est.train(input_fn=_train_input_fn, steps=1000) + res = est.evaluate(input_fn=_eval_input_fn, steps=1) + self.assertLess(0.5, res["auc"]) + est.predict(input_fn=_eval_input_fn) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 9c36c302210185bc390751a0229a61f2f8cd91b8..870ce2442bb5e98db7615c43054c9c827b8c88f0 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -22,7 +22,16 @@ from tensorflow.contrib.boosted_trees.estimator_batch import model from tensorflow.contrib.boosted_trees.python.utils import losses from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.python.estimator.canned import head as core_head_lib +from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.ops import math_ops +from tensorflow.python.ops.losses import losses as core_losses + + +# ================== Old estimator interface=================================== +# The estimators below were designed for old feature columns and old estimator +# interface. They can be used with new feature columns and losses by setting +# use_core_libs = True. class GradientBoostedDecisionTreeClassifier(estimator.Estimator): @@ -42,7 +51,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): logits_modifier_function=None, center_bias=True, use_core_libs=False, - output_leaf_index=False): + output_leaf_index=False, + override_global_step_value=None): """Initializes a GradientBoostedDecisionTreeClassifier estimator instance. Args: @@ -76,6 +86,14 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): for result_dict in result_iter: # access leaf index list by result_dict["leaf_index"] # which contains one leaf index per tree + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. Raises: ValueError: If learner_config is not valid. @@ -116,6 +134,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'use_core_libs': use_core_libs, 'output_leaf_index': output_leaf_index, + 'override_global_step_value': override_global_step_value }, model_dir=model_dir, config=config, @@ -139,7 +158,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): logits_modifier_function=None, center_bias=True, use_core_libs=False, - output_leaf_index=False): + output_leaf_index=False, + override_global_step_value=None): """Initializes a GradientBoostedDecisionTreeRegressor estimator instance. Args: @@ -173,6 +193,14 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): for example_prediction_result in result_dict: # access leaf index list by example_prediction_result["leaf_index"] # which contains one leaf index per tree + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. """ head = head_lib.regression_head( label_name=label_name, @@ -196,6 +224,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): 'center_bias': center_bias, 'use_core_libs': use_core_libs, 'output_leaf_index': False, + 'override_global_step_value': override_global_step_value }, model_dir=model_dir, config=config, @@ -221,7 +250,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): logits_modifier_function=None, center_bias=True, use_core_libs=False, - output_leaf_index=False): + output_leaf_index=False, + override_global_step_value=None): """Initializes a GradientBoostedDecisionTreeEstimator estimator instance. Args: @@ -251,6 +281,14 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): for example_prediction_result in result_dict: # access leaf index list by example_prediction_result["leaf_index"] # which contains one leaf index per tree + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. """ super(GradientBoostedDecisionTreeEstimator, self).__init__( model_fn=model.model_builder, @@ -265,7 +303,282 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): 'center_bias': center_bias, 'use_core_libs': use_core_libs, 'output_leaf_index': False, + 'override_global_step_value': override_global_step_value + }, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) + + +class GradientBoostedDecisionTreeRanker(estimator.Estimator): + """A ranking estimator using gradient boosted decision trees.""" + + def __init__(self, + learner_config, + examples_per_layer, + head, + ranking_model_pair_keys, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + label_keys=None, + feature_engineering_fn=None, + logits_modifier_function=None, + center_bias=False, + use_core_libs=False, + output_leaf_index=False, + override_global_step_value=None): + """Initializes a GradientBoostedDecisionTreeRanker instance. + + This is an estimator that can be trained off the pairwise data and can be + used for inference on non-paired data. This is essentially LambdaMart. + Args: + learner_config: A config for the learner. + examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + head: `Head` instance. + ranking_model_pair_keys: Keys to distinguish between features + for left and right part of the training pairs for ranking. For example, + for an Example with features "a.f1" and "b.f1", the keys would be + ("a", "b"). + num_trees: An int, number of trees to build. + feature_columns: A list of feature columns. + weight_column_name: Name of the column for weights, or None if not + weighted. + model_dir: Directory for model exports, etc. + config: `RunConfig` object to configure the runtime settings. + label_keys: Optional list of strings with size `[n_classes]` defining the + label vocabulary. Only supported for `n_classes` > 2. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + logits_modifier_function: A modifier function for the logits. + center_bias: Whether a separate tree should be created for first fitting + the bias. + use_core_libs: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is + [batch_size, num_trees]. + For example, + result_iter = classifier.predict(...) + for result_dict in result_iter: + # access leaf index list by result_dict["leaf_index"] + # which contains one leaf index per tree + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. + Raises: + ValueError: If learner_config is not valid. + """ + super(GradientBoostedDecisionTreeRanker, self).__init__( + model_fn=model.ranking_model_builder, + params={ + 'head': head, + 'n_classes': 2, + 'feature_columns': feature_columns, + 'learner_config': learner_config, + 'num_trees': num_trees, + 'weight_column_name': weight_column_name, + 'examples_per_layer': examples_per_layer, + 'center_bias': center_bias, + 'logits_modifier_function': logits_modifier_function, + 'use_core_libs': use_core_libs, + 'output_leaf_index': output_leaf_index, + 'ranking_model_pair_keys': ranking_model_pair_keys, + 'override_global_step_value': override_global_step_value }, model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn) + +# ================== New Estimator interface=================================== +# The estimators below use new core Estimator interface and must be used with +# new feature columns and heads. + +# For multiclass classification, use the following head since it uses loss +# that is twice differentiable. +def core_multiclass_head(n_classes): + """Core head for multiclass problems.""" + + def loss_fn(labels, logits): + result = losses.per_example_maxent_loss( + labels=labels, logits=logits, weights=None, num_classes=n_classes) + return result[0] + + # pylint:disable=protected-access + head_fn = core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=n_classes, + loss_fn=loss_fn, + loss_reduction=core_losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + # pylint:enable=protected-access + + return head_fn + + +class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): + """An estimator using gradient boosted decision trees. + + Useful for training with user specified `Head`. + """ + + def __init__(self, + learner_config, + examples_per_layer, + head, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + label_keys=None, + feature_engineering_fn=None, + logits_modifier_function=None, + center_bias=True, + output_leaf_index=False): + """Initializes a core version of GradientBoostedDecisionTreeEstimator. + + Args: + learner_config: A config for the learner. + examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + head: `Head` instance. + num_trees: An int, number of trees to build. + feature_columns: A list of feature columns. + weight_column_name: Name of the column for weights, or None if not + weighted. + model_dir: Directory for model exports, etc. + config: `RunConfig` object to configure the runtime settings. + label_keys: Optional list of strings with size `[n_classes]` defining the + label vocabulary. Only supported for `n_classes` > 2. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + logits_modifier_function: A modifier function for the logits. + center_bias: Whether a separate tree should be created for first fitting + the bias. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree + """ + + def _model_fn(features, labels, mode, config): + return model.model_builder( + features=features, + labels=labels, + mode=mode, + config=config, + params={ + 'head': head, + 'feature_columns': feature_columns, + 'learner_config': learner_config, + 'num_trees': num_trees, + 'weight_column_name': weight_column_name, + 'examples_per_layer': examples_per_layer, + 'center_bias': center_bias, + 'logits_modifier_function': logits_modifier_function, + 'use_core_libs': True, + 'output_leaf_index': output_leaf_index, + 'override_global_step_value': None + }, + output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) + + super(CoreGradientBoostedDecisionTreeEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) + + +class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): + """A ranking estimator using gradient boosted decision trees.""" + + def __init__(self, + learner_config, + examples_per_layer, + head, + ranking_model_pair_keys, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + label_keys=None, + logits_modifier_function=None, + center_bias=False, + output_leaf_index=False): + """Initializes a GradientBoostedDecisionTreeRanker instance. + + This is an estimator that can be trained off the pairwise data and can be + used for inference on non-paired data. This is essentially LambdaMart. + Args: + learner_config: A config for the learner. + examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + head: `Head` instance. + ranking_model_pair_keys: Keys to distinguish between features + for left and right part of the training pairs for ranking. For example, + for an Example with features "a.f1" and "b.f1", the keys would be + ("a", "b"). + num_trees: An int, number of trees to build. + feature_columns: A list of feature columns. + weight_column_name: Name of the column for weights, or None if not + weighted. + model_dir: Directory for model exports, etc. + config: `RunConfig` object to configure the runtime settings. + label_keys: Optional list of strings with size `[n_classes]` defining the + label vocabulary. Only supported for `n_classes` > 2. + logits_modifier_function: A modifier function for the logits. + center_bias: Whether a separate tree should be created for first fitting + the bias. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is + [batch_size, num_trees]. + For example, + result_iter = classifier.predict(...) + for result_dict in result_iter: + # access leaf index list by result_dict["leaf_index"] + # which contains one leaf index per tree + + Raises: + ValueError: If learner_config is not valid. + """ + + def _model_fn(features, labels, mode, config): + return model.ranking_model_builder( + features=features, + labels=labels, + mode=mode, + config=config, + params={ + 'head': head, + 'n_classes': 2, + 'feature_columns': feature_columns, + 'learner_config': learner_config, + 'num_trees': num_trees, + 'weight_column_name': weight_column_name, + 'examples_per_layer': examples_per_layer, + 'center_bias': center_bias, + 'logits_modifier_function': logits_modifier_function, + 'use_core_libs': True, + 'output_leaf_index': output_leaf_index, + 'ranking_model_pair_keys': ranking_model_pair_keys, + 'override_global_step_value': None + }, + output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) + + super(CoreGradientBoostedDecisionTreeRanker, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index 75ef1b050028b6462b255827c06e836e5c481844..68d710d713770a3a4a623b9447bb6a6b93569cac 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -25,10 +25,12 @@ from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import googletest +from tensorflow.python.training import checkpoint_utils def _train_input_fn(): @@ -37,18 +39,50 @@ def _train_input_fn(): return features, label +def _multiclass_train_input_fn(): + features = { + "x": constant_op.constant([[2.], [1.], [1.], [5.], [3.5], [4.6], [3.5]]) + } + label = constant_op.constant( + [[1], [0], [0], [2], [2], [0], [1]], dtype=dtypes.int32) + return features, label + + +def _ranking_train_input_fn(): + features = { + "a.f1": constant_op.constant([[3.], [0.3], [1.]]), + "a.f2": constant_op.constant([[0.1], [3.], [1.]]), + "b.f1": constant_op.constant([[13.], [0.4], [5.]]), + "b.f2": constant_op.constant([[1.], [3.], [0.01]]), + } + label = constant_op.constant([[0], [0], [1]], dtype=dtypes.int32) + return features, label + + def _eval_input_fn(): features = {"x": constant_op.constant([[1.], [2.], [2.]])} label = constant_op.constant([[0], [1], [1]], dtype=dtypes.int32) return features, label +def _infer_ranking_train_input_fn(): + features = { + "f1": constant_op.constant([[3.], [2], [1.]]), + "f2": constant_op.constant([[0.1], [3.], [1.]]) + } + return features, None + + class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): def setUp(self): self._export_dir_base = tempfile.mkdtemp() + "export/" gfile.MkDir(self._export_dir_base) + def _assert_checkpoint(self, model_dir, global_step): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) + def testFitAndEvaluateDontThrowException(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 @@ -155,6 +189,290 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): regressor.evaluate(input_fn=_eval_input_fn, steps=1) regressor.export(self._export_dir_base) + def testRankingDontThrowExceptionForForEstimator(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + model = estimator.GradientBoostedDecisionTreeRanker( + head=head_fn, + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + use_core_libs=True, + feature_columns=[ + core_feature_column.numeric_column("f1"), + core_feature_column.numeric_column("f2") + ], + ranking_model_pair_keys=("a", "b")) + + model.fit(input_fn=_ranking_train_input_fn, steps=1000) + model.evaluate(input_fn=_ranking_train_input_fn, steps=1) + model.predict(input_fn=_infer_ranking_train_input_fn) + + def testDoesNotOverrideGlobalSteps(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 2 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")], + output_leaf_index=False) + + classifier.fit(input_fn=_train_input_fn, steps=15) + # When no override of global steps, 5 steps were used. + self._assert_checkpoint(classifier.model_dir, global_step=5) + + def testOverridesGlobalSteps(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 2 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")], + output_leaf_index=False, + override_global_step_value=10000000) + + classifier.fit(input_fn=_train_input_fn, steps=15) + self._assert_checkpoint(classifier.model_dir, global_step=10000000) + + def testFitAndEvaluateMultiClassTreePerClassDontThrowException(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 3 + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.TREE_PER_CLASS) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + n_classes=learner_config.num_classes, + num_trees=1, + examples_per_layer=7, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")]) + + classifier.fit(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + classifier.export(self._export_dir_base) + result_iter = classifier.predict(input_fn=_eval_input_fn) + for prediction_dict in result_iter: + self.assertTrue("classes" in prediction_dict) + + def testFitAndEvaluateMultiClassDiagonalDontThrowException(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 3 + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.DIAGONAL_HESSIAN) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + n_classes=learner_config.num_classes, + num_trees=1, + examples_per_layer=7, + model_dir=model_dir, + config=config, + center_bias=False, + feature_columns=[contrib_feature_column.real_valued_column("x")]) + + classifier.fit(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + classifier.export(self._export_dir_base) + result_iter = classifier.predict(input_fn=_eval_input_fn) + for prediction_dict in result_iter: + self.assertTrue("classes" in prediction_dict) + + def testFitAndEvaluateMultiClassFullDontThrowException(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 3 + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.FULL_HESSIAN) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + n_classes=learner_config.num_classes, + num_trees=1, + examples_per_layer=7, + model_dir=model_dir, + config=config, + center_bias=False, + feature_columns=[contrib_feature_column.real_valued_column("x")]) + + classifier.fit(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + classifier.export(self._export_dir_base) + result_iter = classifier.predict(input_fn=_eval_input_fn) + for prediction_dict in result_iter: + self.assertTrue("classes" in prediction_dict) + + +class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): + + def testTrainEvaluateInferDoesNotThrowError(self): + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + est = estimator.CoreGradientBoostedDecisionTreeEstimator( + head=head_fn, + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")]) + + # Train for a few steps. + est.train(input_fn=_train_input_fn, steps=1000) + est.evaluate(input_fn=_eval_input_fn, steps=1) + est.predict(input_fn=_eval_input_fn) + + def testRankingDontThrowExceptionForForEstimator(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + est = estimator.CoreGradientBoostedDecisionTreeRanker( + head=head_fn, + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[ + core_feature_column.numeric_column("f1"), + core_feature_column.numeric_column("f2") + ], + ranking_model_pair_keys=("a", "b")) + + # Train for a few steps. + est.train(input_fn=_ranking_train_input_fn, steps=1000) + est.evaluate(input_fn=_ranking_train_input_fn, steps=1) + est.predict(input_fn=_infer_ranking_train_input_fn) + + def testFitAndEvaluateMultiClassTreePerClasssDontThrowException(self): + n_classes = 3 + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = n_classes + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.TREE_PER_CLASS) + + head_fn = estimator.core_multiclass_head(n_classes=n_classes) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.CoreGradientBoostedDecisionTreeEstimator( + learner_config=learner_config, + head=head_fn, + num_trees=1, + center_bias=False, + examples_per_layer=7, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")]) + + classifier.train(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1) + classifier.predict(input_fn=_eval_input_fn) + + def testFitAndEvaluateMultiClassDiagonalDontThrowException(self): + n_classes = 3 + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = n_classes + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.DIAGONAL_HESSIAN) + + head_fn = estimator.core_multiclass_head(n_classes=n_classes) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.CoreGradientBoostedDecisionTreeEstimator( + learner_config=learner_config, + head=head_fn, + num_trees=1, + center_bias=False, + examples_per_layer=7, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")]) + + classifier.train(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1) + classifier.predict(input_fn=_eval_input_fn) + + def testFitAndEvaluateMultiClassFullDontThrowException(self): + n_classes = 3 + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = n_classes + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.FULL_HESSIAN) + + head_fn = estimator.core_multiclass_head(n_classes=n_classes) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.CoreGradientBoostedDecisionTreeEstimator( + learner_config=learner_config, + head=head_fn, + num_trees=1, + center_bias=False, + examples_per_layer=7, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")]) + + classifier.train(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1) + classifier.predict(input_fn=_eval_input_fn) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 1ee891198939e53fc5913104b2c2e65dc977823f..04b46c3483fa25286078b88c2776b76e4f3c0bcf 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -20,6 +20,7 @@ from __future__ import print_function import copy +from tensorflow.contrib import learn from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks from tensorflow.contrib.boosted_trees.python.ops import model_ops @@ -28,8 +29,17 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import state_ops from tensorflow.python.training import training_util +class ModelBuilderOutputType(object): + MODEL_FN_OPS = 0 + ESTIMATOR_SPEC = 1 -def model_builder(features, labels, mode, params, config): + +def model_builder(features, + labels, + mode, + params, + config, + output_type=ModelBuilderOutputType.MODEL_FN_OPS): """Multi-machine batch gradient descent tree model. Args: @@ -48,7 +58,13 @@ def model_builder(features, labels, mode, params, config): * weight_column_name: The name of weight column. * center_bias: Whether a separate tree should be created for first fitting the bias. + * override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. config: `RunConfig` of the estimator. + output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec + (new interface). Returns: A `ModelFnOps` object. @@ -64,6 +80,7 @@ def model_builder(features, labels, mode, params, config): use_core_libs = params["use_core_libs"] logits_modifier_function = params["logits_modifier_function"] output_leaf_index = params["output_leaf_index"] + override_global_step_value = params.get("override_global_step_value", None) if features is None: raise ValueError("At least one feature must be specified.") @@ -115,29 +132,271 @@ def model_builder(features, labels, mode, params, config): return update_op create_estimator_spec_op = getattr(head, "create_estimator_spec", None) - if use_core_libs and callable(create_estimator_spec_op): - model_fn_ops = head.create_estimator_spec( + + training_hooks = [] + if num_trees: + if center_bias: + num_trees += 1 + + finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() + training_hooks.append( + trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, + finalized_trees, + override_global_step_value)) + + if output_type == ModelBuilderOutputType.MODEL_FN_OPS: + if use_core_libs and callable(create_estimator_spec_op): + model_fn_ops = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops( + model_fn_ops) + else: + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + + if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: + model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ + gbdt_batch.LEAF_INDEX] + + model_fn_ops.training_hooks.extend(training_hooks) + return model_fn_ops + elif output_type == ModelBuilderOutputType.ESTIMATOR_SPEC: + assert callable(create_estimator_spec_op) + estimator_spec = head.create_estimator_spec( features=features, mode=mode, labels=labels, train_op_fn=_train_op_fn, logits=logits) - model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops) + + estimator_spec = estimator_spec._replace( + training_hooks=training_hooks + list(estimator_spec.training_hooks)) + return estimator_spec + + return model_fn_ops + + +def ranking_model_builder(features, + labels, + mode, + params, + config, + output_type=ModelBuilderOutputType.MODEL_FN_OPS): + """Multi-machine batch gradient descent tree model for ranking. + + Args: + features: `Tensor` or `dict` of `Tensor` objects. + labels: Labels used to train on. + mode: Mode we are in. (TRAIN/EVAL/INFER) + params: A dict of hyperparameters. + The following hyperparameters are expected: + * head: A `Head` instance. + * learner_config: A config for the learner. + * feature_columns: An iterable containing all the feature columns used by + the model. + * examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + * weight_column_name: The name of weight column. + * center_bias: Whether a separate tree should be created for first fitting + the bias. + * ranking_model_pair_keys (Optional): Keys to distinguish between features + for left and right part of the training pairs for ranking. For example, + for an Example with features "a.f1" and "b.f1", the keys would be + ("a", "b"). + * override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. + config: `RunConfig` of the estimator. + output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec + (new interface). + + + Returns: + A `ModelFnOps` object. + Raises: + ValueError: if inputs are not valid. + """ + head = params["head"] + learner_config = params["learner_config"] + examples_per_layer = params["examples_per_layer"] + feature_columns = params["feature_columns"] + weight_column_name = params["weight_column_name"] + num_trees = params["num_trees"] + use_core_libs = params["use_core_libs"] + logits_modifier_function = params["logits_modifier_function"] + output_leaf_index = params["output_leaf_index"] + ranking_model_pair_keys = params["ranking_model_pair_keys"] + override_global_step_value = params.get("override_global_step_value", None) + + if features is None: + raise ValueError("At least one feature must be specified.") + + if config is None: + raise ValueError("Missing estimator RunConfig.") + + center_bias = params["center_bias"] + + if isinstance(features, ops.Tensor): + features = {features.name: features} + + # Make a shallow copy of features to ensure downstream usage + # is unaffected by modifications in the model function. + training_features = copy.copy(features) + training_features.pop(weight_column_name, None) + global_step = training_util.get_global_step() + with ops.device(global_step.device): + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config="", # Initialize an empty ensemble. + name="ensemble_model") + + # Extract the features. + if mode == learn.ModeKeys.TRAIN or mode == learn.ModeKeys.EVAL: + # For ranking pairwise training, we extract two sets of features. + if len(ranking_model_pair_keys) != 2: + raise ValueError("You must provide keys for ranking.") + left_pair_key = ranking_model_pair_keys[0] + right_pair_key = ranking_model_pair_keys[1] + if left_pair_key is None or right_pair_key is None: + raise ValueError("Both pair keys should be provided for ranking.") + + features_1 = {} + features_2 = {} + for name in training_features: + feature = training_features[name] + new_name = name[2:] + if name.startswith(left_pair_key + "."): + features_1[new_name] = feature + else: + assert name.startswith(right_pair_key + ".") + features_2[new_name] = feature + + main_features = features_1 + supplementary_features = features_2 else: - model_fn_ops = head.create_model_fn_ops( + # For non-ranking or inference ranking, we have only 1 set of features. + main_features = training_features + + # Create GBDT model. + gbdt_model_main = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=config.is_chief, + num_ps_replicas=config.num_ps_replicas, + ensemble_handle=ensemble_handle, + center_bias=center_bias, + examples_per_layer=examples_per_layer, + learner_config=learner_config, + feature_columns=feature_columns, + logits_dimension=head.logits_dimension, + features=main_features, + use_core_columns=use_core_libs, + output_leaf_index=output_leaf_index) + + with ops.name_scope("gbdt", "gbdt_optimizer"): + # Logits for inference. + if mode == learn.ModeKeys.INFER: + predictions_dict = gbdt_model_main.predict(mode) + logits = predictions_dict[gbdt_batch.PREDICTIONS] + if logits_modifier_function: + logits = logits_modifier_function(logits, features, mode) + else: + gbdt_model_supplementary = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=config.is_chief, + num_ps_replicas=config.num_ps_replicas, + ensemble_handle=ensemble_handle, + center_bias=center_bias, + examples_per_layer=examples_per_layer, + learner_config=learner_config, + feature_columns=feature_columns, + logits_dimension=head.logits_dimension, + features=supplementary_features, + use_core_columns=use_core_libs, + output_leaf_index=output_leaf_index) + + # Logits for train and eval. + if not supplementary_features: + raise ValueError("Features for ranking must be specified.") + + predictions_dict_1 = gbdt_model_main.predict(mode) + predictions_1 = predictions_dict_1[gbdt_batch.PREDICTIONS] + + predictions_dict_2 = gbdt_model_supplementary.predict(mode) + predictions_2 = predictions_dict_2[gbdt_batch.PREDICTIONS] + + logits = predictions_1 - predictions_2 + if logits_modifier_function: + logits = logits_modifier_function(logits, features, mode) + + predictions_dict = predictions_dict_1 + predictions_dict[gbdt_batch.PREDICTIONS] = logits + + def _train_op_fn(loss): + """Returns the op to optimize the loss.""" + update_op = gbdt_model_main.train(loss, predictions_dict, labels) + with ops.control_dependencies( + [update_op]), (ops.colocate_with(global_step)): + update_op = state_ops.assign_add(global_step, 1).op + return update_op + + create_estimator_spec_op = getattr(head, "create_estimator_spec", None) + + training_hooks = [] + if num_trees: + if center_bias: + num_trees += 1 + + finalized_trees, attempted_trees = ( + gbdt_model_main.get_number_of_trees_tensor()) + training_hooks.append( + trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, + finalized_trees, + override_global_step_value)) + + if output_type == ModelBuilderOutputType.MODEL_FN_OPS: + if use_core_libs and callable(create_estimator_spec_op): + model_fn_ops = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops( + model_fn_ops) + else: + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + + if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: + model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ + gbdt_batch.LEAF_INDEX] + + model_fn_ops.training_hooks.extend(training_hooks) + return model_fn_ops + + elif output_type == ModelBuilderOutputType.ESTIMATOR_SPEC: + assert callable(create_estimator_spec_op) + estimator_spec = head.create_estimator_spec( features=features, mode=mode, labels=labels, train_op_fn=_train_op_fn, logits=logits) - if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: - model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ - gbdt_batch.LEAF_INDEX] - if num_trees: - if center_bias: - num_trees += 1 - finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() - model_fn_ops.training_hooks.append( - trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, - finalized_trees)) + + estimator_spec = estimator_spec._replace( + training_hooks=training_hooks + list(estimator_spec.training_hooks)) + return estimator_spec + return model_fn_ops diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py b/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py index 2e4151cac40f770e2bece70d752122eb7f34dd40..f137ada35524bf2467314f4a284ea35a82f06825 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py @@ -25,6 +25,7 @@ from tensorflow.contrib.learn.python.learn.session_run_hook import SessionRunArg from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import training_util from tensorflow.python.training.summary_io import SummaryWriterCache @@ -150,12 +151,23 @@ class FeedFnHook(session_run_hook.SessionRunHook): class StopAfterNTrees(session_run_hook.SessionRunHook): """Stop training after building N full trees.""" - def __init__(self, n, num_attempted_trees_tensor, num_finalized_trees_tensor): + def __init__(self, n, num_attempted_trees_tensor, num_finalized_trees_tensor, + override_global_step_value=None): self._num_trees = n # num_attempted_trees_tensor and num_finalized_trees_tensor are both # tensors. self._num_attempted_trees_tensor = num_attempted_trees_tensor self._num_finalized_trees_tensor = num_finalized_trees_tensor + self._override_global_step_value = override_global_step_value + + def begin(self): + self._global_step_tensor = training_util.get_global_step() + if self._global_step_tensor is None: + raise RuntimeError("Global step should be created.") + + if self._override_global_step_value is not None: + self._override_global_step_op = state_ops.assign( + self._global_step_tensor, self._override_global_step_value) def before_run(self, run_context): del run_context # unused by StopTrainingAfterNTrees. @@ -175,6 +187,9 @@ class StopAfterNTrees(session_run_hook.SessionRunHook): num_attempted_trees > 2 * self._num_trees): logging.info("Requesting stop since we have reached %d trees.", num_finalized_trees) + if self._override_global_step_value is not None: + logging.info("Overriding global steps value.") + run_context.session.run(self._override_global_step_op) run_context.request_stop() diff --git a/tensorflow/contrib/boosted_trees/examples/boston.py b/tensorflow/contrib/boosted_trees/examples/boston.py index e9dbdb0fd784052eeb36ac1aa9342165ef2ac0a7..54c4ff059e3408d2cb8fc689a9ae877f57485f58 100644 --- a/tensorflow/contrib/boosted_trees/examples/boston.py +++ b/tensorflow/contrib/boosted_trees/examples/boston.py @@ -45,6 +45,7 @@ from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientB from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn import learn_runner +from tensorflow.python.util import compat _BOSTON_NUM_FEATURES = 13 @@ -79,7 +80,8 @@ def _convert_fn(dtec, sorted_feature_names, num_dense, num_sparse_float, num_sparse_int, export_dir, unused_eval_result): universal_format = custom_export_strategy.convert_to_universal_format( dtec, sorted_feature_names, num_dense, num_sparse_float, num_sparse_int) - with tf.gfile.GFile(os.path.join(export_dir, "tree_proto"), "w") as f: + with tf.gfile.GFile(os.path.join( + compat.as_bytes(export_dir), compat.as_bytes("tree_proto")), "w") as f: f.write(str(universal_format)) diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index 0b28f81e7ca9a1228adc5bde19c429265e0aa9b8..1375fddf2bea1a8f856c35d756c38a8beb14a53f 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -125,6 +125,8 @@ void QuantizeFeatures( auto flat_values = values_tensor.flat(); for (int64 instance = 0; instance < num_values; ++instance) { const float value = flat_values(instance); + CHECK(!buckets_vector.empty()) + << "Got empty buckets for feature " << feature_index; auto bucket_iter = std::lower_bound(buckets_vector.begin(), buckets_vector.end(), value); if (bucket_iter == buckets_vector.end()) { @@ -241,6 +243,11 @@ class CreateQuantileAccumulatorOp : public OpKernel { // other exceptions. If one already exists, it unrefs the new one. const Tensor* stamp_token_t; OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); + // An epsilon value of zero could cause perfoamance issues and is therefore, + // disallowed. + OP_REQUIRES( + context, epsilon_ > 0, + errors::InvalidArgument("An epsilon value of zero is not allowed.")); auto result = new QuantileStreamResource(epsilon_, num_quantiles_, max_elements_, generate_quantiles_, stamp_token_t->scalar()()); diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 401bec84a20a0fefcddbfa1039a117e65f853633..d9e7a0f4660470a0c79ad7a832db233481161770 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -34,7 +34,9 @@ namespace tensorflow { +using boosted_trees::learner::LearnerConfig; using boosted_trees::learner::LearnerConfig_MultiClassStrategy; +using boosted_trees::learner::ObliviousSplitInfo; using boosted_trees::learner::SplitInfo; using boosted_trees::learner::stochastic::GradientStats; using boosted_trees::learner::stochastic::NodeStats; @@ -158,6 +160,11 @@ class BuildDenseInequalitySplitsOp : public OpKernel { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); + const Tensor* weak_learner_type_t; + OP_REQUIRES_OK(context, + context->input("weak_learner_type", &weak_learner_type_t)); + const int32 weak_learner_type = weak_learner_type_t->scalar()(); + // Find the number of unique partitions before we allocate the output. std::vector partition_boundaries; partition_boundaries.push_back(0); @@ -188,20 +195,59 @@ class BuildDenseInequalitySplitsOp : public OpKernel { tensorflow::TTypes::Vec output_partition_ids = output_partition_ids_t->vec(); - Tensor* gains_t = nullptr; - OP_REQUIRES_OK( - context, context->allocate_output("gains", TensorShape({num_elements}), - &gains_t)); + // For a normal tree, we output a split per partition. For an oblivious + // tree, we output one split for all partitions of the layer + int32 size_output = num_elements; + if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE && + num_elements > 0) { + size_output = 1; + } + Tensor* gains_t = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + "gains", TensorShape({size_output}), &gains_t)); tensorflow::TTypes::Vec gains = gains_t->vec(); Tensor* output_splits_t = nullptr; - OP_REQUIRES_OK(context, context->allocate_output( - "split_infos", TensorShape({num_elements}), - &output_splits_t)); + OP_REQUIRES_OK(context, context->allocate_output("split_infos", + TensorShape({size_output}), + &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + + if (num_elements == 0) { + return; + } SplitBuilderState state(context); + switch (weak_learner_type) { + case LearnerConfig::NORMAL_DECISION_TREE: { + ComputeNormalDecisionTree( + &state, normalizer_ratio, num_elements, partition_boundaries, + bucket_boundaries, partition_ids, bucket_ids, gradients_t, + hessians_t, &output_partition_ids, &gains, &output_splits); + break; + } + case LearnerConfig::OBLIVIOUS_DECISION_TREE: { + ComputeObliviousDecisionTree( + &state, normalizer_ratio, num_elements, partition_boundaries, + bucket_boundaries, partition_ids, bucket_ids, gradients_t, + hessians_t, &output_partition_ids, &gains, &output_splits); + break; + } + } + } + + private: + void ComputeNormalDecisionTree( + SplitBuilderState* state, const float normalizer_ratio, + const int num_elements, const std::vector& partition_boundaries, + const tensorflow::TTypes::ConstVec& bucket_boundaries, + const tensorflow::TTypes::ConstVec& partition_ids, + const tensorflow::TTypes::ConstMatrix& bucket_ids, + const Tensor* gradients_t, const Tensor* hessians_t, + tensorflow::TTypes::Vec* output_partition_ids, + tensorflow::TTypes::Vec* gains, + tensorflow::TTypes::Vec* output_splits) { for (int root_idx = 0; root_idx < num_elements; ++root_idx) { float best_gain = std::numeric_limits::lowest(); int start_index = partition_boundaries[root_idx]; @@ -213,7 +259,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel { GradientStats(*gradients_t, *hessians_t, bucket_idx); } root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats); int32 best_bucket_idx = 0; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); @@ -223,10 +269,10 @@ class BuildDenseInequalitySplitsOp : public OpKernel { GradientStats g(*gradients_t, *hessians_t, bucket_idx); g *= normalizer_ratio; left_gradient_stats += g; - NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats); + NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats); GradientStats right_gradient_stats = root_gradient_stats - left_gradient_stats; - NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats); + NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats); if (left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; @@ -237,20 +283,124 @@ class BuildDenseInequalitySplitsOp : public OpKernel { SplitInfo split_info; auto* dense_split = split_info.mutable_split_node()->mutable_dense_float_binary_split(); - dense_split->set_feature_column(state.feature_column_group_id()); + dense_split->set_feature_column(state->feature_column_group_id()); dense_split->set_threshold( bucket_boundaries(bucket_ids(best_bucket_idx, 0))); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - state.FillLeaf(best_left_node_stats, left_child); - state.FillLeaf(best_right_node_stats, right_child); - split_info.SerializeToString(&output_splits(root_idx)); - gains(root_idx) = - best_gain - root_stats.gain - state.tree_complexity_regularization(); - output_partition_ids(root_idx) = partition_ids(start_index); + state->FillLeaf(best_left_node_stats, left_child); + state->FillLeaf(best_right_node_stats, right_child); + split_info.SerializeToString(&(*output_splits)(root_idx)); + (*gains)(root_idx) = + best_gain - root_stats.gain - state->tree_complexity_regularization(); + (*output_partition_ids)(root_idx) = partition_ids(start_index); + } + } + void ComputeObliviousDecisionTree( + SplitBuilderState* state, const float normalizer_ratio, + const int num_elements, const std::vector& partition_boundaries, + const tensorflow::TTypes::ConstVec& bucket_boundaries, + const tensorflow::TTypes::ConstVec& partition_ids, + const tensorflow::TTypes::ConstMatrix& bucket_ids, + const Tensor* gradients_t, const Tensor* hessians_t, + tensorflow::TTypes::Vec* output_partition_ids, + tensorflow::TTypes::Vec* gains, + tensorflow::TTypes::Vec* output_splits) { + // Holds the root stats per each node to be split. + std::vector current_layer_stats; + current_layer_stats.reserve(num_elements); + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + const int start_index = partition_boundaries[root_idx]; + const int end_index = partition_boundaries[root_idx + 1]; + GradientStats root_gradient_stats; + for (int64 bucket_idx = start_index; bucket_idx < end_index; + ++bucket_idx) { + root_gradient_stats += + GradientStats(*gradients_t, *hessians_t, bucket_idx); + } + root_gradient_stats *= normalizer_ratio; + current_layer_stats.push_back(root_gradient_stats); + } + + float best_gain = std::numeric_limits::lowest(); + int64 best_bucket_idx = 0; + std::vector best_right_node_stats(num_elements, NodeStats(0)); + std::vector best_left_node_stats(num_elements, NodeStats(0)); + std::vector current_left_node_stats(num_elements, NodeStats(0)); + std::vector current_right_node_stats(num_elements, NodeStats(0)); + int64 current_bucket_id = 0; + int64 last_bucket_id = -1; + // Indexes offsets for each of the partitions that can be used to access + // gradients of a partition for a current bucket we consider. + std::vector current_layer_offsets(num_elements, 0); + std::vector left_gradient_stats(num_elements); + // The idea is to try every bucket id in increasing order. In each iteration + // we calculate the gain of the layer using the current bucket id as split + // value, and we also obtain the following bucket id to try. + while (current_bucket_id > last_bucket_id) { + last_bucket_id = current_bucket_id; + int64 next_bucket_id = -1; + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + int idx = + current_layer_offsets[root_idx] + partition_boundaries[root_idx]; + const int end_index = partition_boundaries[root_idx + 1]; + if (idx < end_index && bucket_ids(idx, 0) == current_bucket_id) { + GradientStats g(*gradients_t, *hessians_t, idx); + g *= normalizer_ratio; + left_gradient_stats[root_idx] += g; + current_layer_offsets[root_idx]++; + idx++; + } + if (idx < end_index && + (bucket_ids(idx, 0) < next_bucket_id || next_bucket_id == -1)) { + next_bucket_id = bucket_ids(idx, 0); + } + } + float gain_of_split = 0.0; + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + GradientStats right_gradient_stats = + current_layer_stats[root_idx] - left_gradient_stats[root_idx]; + NodeStats left_stat = + state->ComputeNodeStats(left_gradient_stats[root_idx]); + NodeStats right_stat = state->ComputeNodeStats(right_gradient_stats); + gain_of_split += left_stat.gain + right_stat.gain; + current_left_node_stats[root_idx] = left_stat; + current_right_node_stats[root_idx] = right_stat; + } + if (gain_of_split > best_gain) { + best_gain = gain_of_split; + best_left_node_stats = current_left_node_stats; + best_right_node_stats = current_right_node_stats; + } + current_bucket_id = next_bucket_id; + } + + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + best_gain -= state->ComputeNodeStats(current_layer_stats[root_idx]).gain; + } + best_gain -= num_elements * state->tree_complexity_regularization(); + + ObliviousSplitInfo oblivious_split_info; + auto* oblivious_dense_split = oblivious_split_info.mutable_split_node() + ->mutable_dense_float_binary_split(); + oblivious_dense_split->set_feature_column(state->feature_column_group_id()); + oblivious_dense_split->set_threshold( + bucket_boundaries(bucket_ids(best_bucket_idx, 0))); + (*gains)(0) = best_gain; + + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + auto* left_children = oblivious_split_info.add_children_leaves(); + auto* right_children = oblivious_split_info.add_children_leaves(); + + state->FillLeaf(best_left_node_stats[root_idx], left_children); + state->FillLeaf(best_right_node_stats[root_idx], right_children); + + const int start_index = partition_boundaries[root_idx]; + (*output_partition_ids)(root_idx) = partition_ids(start_index); } + oblivious_split_info.SerializeToString(&(*output_splits)(0)); } }; REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU), diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index 1bfeed306641111718984b2097512e5ec3fa8630..6d9a6ee5a0d05465459393c4339558f1ca38d417 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -372,12 +372,18 @@ class GrowTreeEnsembleOp : public OpKernel { return; } + // Get the max tree depth. + const Tensor* max_tree_depth_t; + OP_REQUIRES_OK(context, + context->input("max_tree_depth", &max_tree_depth_t)); + const int32 max_tree_depth = max_tree_depth_t->scalar()(); + // Update and retrieve the growable tree. // If the tree is fully built and dropout was applied, it also adjusts the // weights of dropped and the last tree. boosted_trees::trees::DecisionTreeConfig* const tree_config = UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate, - dropout_seed); + dropout_seed, max_tree_depth); // Split tree nodes. for (auto& split_entry : best_splits) { @@ -494,7 +500,8 @@ class GrowTreeEnsembleOp : public OpKernel { boosted_trees::trees::DecisionTreeConfig* UpdateAndRetrieveGrowableTree( boosted_trees::models::DecisionTreeEnsembleResource* const ensemble_resource, - const float learning_rate, const uint64 dropout_seed) { + const float learning_rate, const uint64 dropout_seed, + const int32 max_tree_depth) { const auto num_trees = ensemble_resource->num_trees(); if (num_trees <= 0 || ensemble_resource->LastTreeMetadata()->is_finalized()) { @@ -506,8 +513,7 @@ class GrowTreeEnsembleOp : public OpKernel { tree_config->add_nodes()->mutable_leaf(); boosted_trees::trees::DecisionTreeMetadata* const tree_metadata = ensemble_resource->LastTreeMetadata(); - tree_metadata->set_is_finalized( - learner_config_.constraints().max_tree_depth() <= 1); + tree_metadata->set_is_finalized(max_tree_depth <= 1); tree_metadata->set_num_tree_weight_updates(1); } else { // The growable tree is by definition the last tree in the ensemble. @@ -518,8 +524,7 @@ class GrowTreeEnsembleOp : public OpKernel { << num_trees - 1 << " of ensemble of " << num_trees << " trees."; // Update growable tree metadata. tree_metadata->set_num_layers_grown(new_num_layers); - tree_metadata->set_is_finalized( - new_num_layers >= learner_config_.constraints().max_tree_depth()); + tree_metadata->set_is_finalized(new_num_layers >= max_tree_depth); } UpdateTreeWeightsIfDropout(ensemble_resource, dropout_seed); return ensemble_resource->LastTree(); diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py index 56ff00b39062d57c813633c98c765e077dd4c262..5d4819b0f1cb598cfbe146f569aecd7883186339 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py @@ -37,6 +37,7 @@ class BaseSplitHandler(object): gradient_shape, hessian_shape, multiclass_strategy, + loss_uses_sum_reduction=False, name=None): """Constructor for BaseSplitHandler. @@ -51,6 +52,8 @@ class BaseSplitHandler(object): gradient_shape: A TensorShape, containing shape of gradients. hessian_shape: A TensorShape, containing shape of hessians. multiclass_strategy: Strategy describing how to treat multiclass problems. + loss_uses_sum_reduction: A scalar boolean tensor that specifies whether + SUM or MEAN reduction was used for the loss. name: An optional handler name. """ self._l1_regularization = l1_regularization @@ -62,6 +65,7 @@ class BaseSplitHandler(object): self._multiclass_strategy = multiclass_strategy self._hessian_shape = hessian_shape self._gradient_shape = gradient_shape + self._loss_uses_sum_reduction = loss_uses_sum_reduction def scheduled_reads(self): """Returns the list of `ScheduledOp`s required for update_stats.""" @@ -127,6 +131,10 @@ class BaseSplitHandler(object): }, stamp_token, None) return control_flow_ops.group(update_1, *update_2[self]) + @abc.abstractmethod + def reset(self, stamp_token, next_stamp_token): + """Resets the state maintained by the handler.""" + @abc.abstractmethod def make_splits(self, stamp_token, next_stamp_token, class_id): """Create the best split using the accumulated stats and flush the state. diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index 9f78ab20242800fd8af7ad049d5970fbe26ec0ea..efe29216c2a7d8aa985da54cdbb839b9e6f69078 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -23,6 +23,7 @@ from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -44,6 +45,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): hessian_shape, multiclass_strategy, init_stamp_token=0, + loss_uses_sum_reduction=False, name=None): """Initialize the internal state for this split handler. @@ -62,6 +64,8 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): multiclass_strategy: Strategy describing how to treat multiclass problems. init_stamp_token: A tensor containing an scalar for initial stamp of the stamped objects. + loss_uses_sum_reduction: A scalar boolean tensor that specifies whether + SUM or MEAN reduction was used for the loss. name: An optional handler name. """ super(EqualitySplitHandler, self).__init__( @@ -73,6 +77,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): gradient_shape=gradient_shape, hessian_shape=hessian_shape, multiclass_strategy=multiclass_strategy, + loss_uses_sum_reduction=loss_uses_sum_reduction, name=name) self._stats_accumulator = stats_accumulator_ops.StatsAccumulator( init_stamp_token, @@ -173,6 +178,11 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): # pair. num_minibatches, partition_ids, feature_ids, gradients, hessians = ( self._stats_accumulator.flush(stamp_token, next_stamp_token)) + # For sum_reduction, we don't need to divide by number of minibatches. + + num_minibatches = control_flow_ops.cond( + ops.convert_to_tensor(self._loss_uses_sum_reduction), + lambda: math_ops.to_int64(1), lambda: num_minibatches) partition_ids, gains, split_infos = ( split_handler_ops.build_categorical_equality_splits( num_minibatches=num_minibatches, @@ -187,8 +197,12 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): tree_complexity_regularization=self._tree_complexity_regularization, min_node_weight=self._min_node_weight, bias_feature_id=_BIAS_FEATURE_ID, - multiclass_strategy=self._multiclass_strategy,)) + multiclass_strategy=self._multiclass_strategy)) # There are no warm-up rounds needed in the equality column handler. So we # always return ready. are_splits_ready = constant_op.constant(True) return (are_splits_ready, partition_ids, gains, split_infos) + + def reset(self, stamp_token, next_stamp_token): + reset = self._stats_accumulator.flush(stamp_token, next_stamp_token) + return reset diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py index 0b65eba2a76273a81f1464ed7639f0c0760e0050..ef253e7cec4e8a96b360ced32b59398c2e2c9680 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py @@ -90,7 +90,17 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): empty_hessians, example_weights, is_active=array_ops.constant([True, True])) - with ops.control_dependencies([update_1]): + update_2 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + + with ops.control_dependencies([update_1, update_2]): are_splits_ready, partitions, gains, splits = ( split_handler.make_splits(0, 1, class_id)) are_splits_ready, partitions, gains, splits = (sess.run( @@ -159,6 +169,129 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(1, split_node.feature_id) + def testGenerateFeatureSplitCandidatesSumReduction(self): + with self.test_session() as sess: + # The data looks like the following: + # Example | Gradients | Partition | Feature ID | + # i0 | (0.2, 0.12) | 0 | 1,2 | + # i1 | (-0.5, 0.07) | 0 | | + # i2 | (1.2, 0.2) | 0 | 2 | + # i3 | (4.0, 0.13) | 1 | 1 | + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + partition_ids = [0, 0, 0, 1] + indices = [[0, 0], [0, 1], [2, 0], [3, 0]] + values = array_ops.constant([1, 2, 2, 1], dtype=dtypes.int64) + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + class_id = -1 + + split_handler = categorical_split_handler.EqualitySplitHandler( + l1_regularization=0.1, + l2_regularization=1, + tree_complexity_regularization=0, + min_node_weight=0, + sparse_int_column=sparse_tensor.SparseTensor(indices, values, [4, 1]), + feature_column_group_id=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + init_stamp_token=0, + loss_uses_sum_reduction=True) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + update_2 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1, update_2]): + are_splits_ready, partitions, gains, splits = ( + split_handler.make_splits(0, 1, class_id)) + are_splits_ready, partitions, gains, splits = ( + sess.run([are_splits_ready, partitions, gains, splits])) + self.assertTrue(are_splits_ready) + self.assertAllEqual([0, 1], partitions) + + # Check the split on partition 0. + # -(0.4 + 2.4 - 0.1) / (0.24 + 0.4 + 1) + expected_left_weight = -1.6463414634146338 + + # (0.4 + 2.4 - 0.1) ** 2 / (0.24 + 0.4 + 1) + expected_left_gain = 4.445121951219511 + + # -(-1 + 0.1) / (0.14 + 1) + expected_right_weight = 0.789473684211 + + # (-1 + 0.1) ** 2 / (0.14 + 1) + expected_right_gain = 0.710526315789 + + # (0.4 + -1 + 2.4 - 0.1) ** 2 / (0.24 + 0.14 + 0.4 + 1) + expected_bias_gain = 1.6235955056179772 + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.categorical_id_binary_split + + self.assertEqual(0, split_node.feature_column) + + self.assertEqual(2, split_node.feature_id) + + self.assertAllClose( + expected_left_gain + expected_right_gain - expected_bias_gain, gains[0], + 0.00001) + + self.assertAllClose([expected_left_weight], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight], right_child.value, 0.00001) + + # Check the split on partition 1. + # (-8 + 0.1) / (0.26 + 1) + expected_left_weight = -6.26984126984 + # (-8 + 0.1) ** 2 / (0.26 + 1) + expected_left_gain = 49.5317460317 + expected_right_weight = 0 + expected_right_gain = 0 + # (-8 + 0.1) ** 2 / (0.26 + 1) + expected_bias_gain = 49.5317460317 + + # Verify candidate for partition 1, there's only one active feature here + # so zero gain is expected. + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[1]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.categorical_id_binary_split + self.assertAllClose(0.0, gains[1], 0.00001) + + self.assertAllClose([expected_left_weight], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight], right_child.value, 0.00001) + + self.assertEqual(0, split_node.feature_column) + + self.assertEqual(1, split_node.feature_id) + def testGenerateFeatureSplitCandidatesMulticlass(self): with self.test_session() as sess: # Batch size is 4, 2 gradients per each instance. diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 409a2d8f46c331c13aec10542c4967d50575e94a..f45010ec26ed25127ca78b97f4d6fd7ebd6467ae 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -64,6 +64,7 @@ from __future__ import print_function import re from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler +from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops from tensorflow.contrib.boosted_trees.python.ops import quantile_ops @@ -79,6 +80,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops + _BIAS_FEATURE_ID = -1 # Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") @@ -99,6 +101,7 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): hessian_shape, multiclass_strategy, init_stamp_token=0, + loss_uses_sum_reduction=False, name=None): """Initialize the internal state for this split handler. @@ -117,6 +120,8 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): multiclass_strategy: Strategy describing how to treat multiclass problems. init_stamp_token: A tensor containing an scalar for initial stamp of the stamped objects. + loss_uses_sum_reduction: A scalar boolean tensor that specifies whether + SUM or MEAN reduction was used for the loss. name: An optional handler name. """ super(InequalitySplitHandler, self).__init__( @@ -128,7 +133,8 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): feature_column_group_id=feature_column_group_id, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=multiclass_strategy) + multiclass_strategy=multiclass_strategy, + loss_uses_sum_reduction=loss_uses_sum_reduction) self._stats_accumulator = stats_accumulator_ops.StatsAccumulator( init_stamp_token, gradient_shape, @@ -143,6 +149,11 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): num_quantiles=num_quantiles, name="QuantileAccumulator/{}".format(self._name)) + def reset(self, stamp_token, next_stamp_token): + reset_1 = self._stats_accumulator.flush(stamp_token, next_stamp_token) + reset_2 = self._quantile_accumulator.flush(stamp_token, next_stamp_token) + return control_flow_ops.group([reset_1, reset_2]) + class DenseSplitHandler(InequalitySplitHandler): """Computes stats and finds the best inequality splits on dense columns.""" @@ -160,6 +171,8 @@ class DenseSplitHandler(InequalitySplitHandler): hessian_shape, multiclass_strategy, init_stamp_token=0, + loss_uses_sum_reduction=False, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE, name=None): """Initialize the internal state for this split handler. @@ -179,6 +192,9 @@ class DenseSplitHandler(InequalitySplitHandler): multiclass_strategy: Strategy describing how to treat multiclass problems. init_stamp_token: A tensor containing an scalar for initial stamp of the stamped objects. + loss_uses_sum_reduction: A scalar boolean tensor that specifies whether + SUM or MEAN reduction was used for the loss. + weak_learner_type: Specifies the type of weak learner to use. name: An optional handler name. """ super(DenseSplitHandler, self).__init__( @@ -193,8 +209,10 @@ class DenseSplitHandler(InequalitySplitHandler): name=name, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=multiclass_strategy) + multiclass_strategy=multiclass_strategy, + loss_uses_sum_reduction=loss_uses_sum_reduction) self._dense_float_column = dense_float_column + self._weak_learner_type = weak_learner_type # Register dense_make_stats_update function as an Op to the graph. g = ops.get_default_graph() dense_make_stats_update.add_to_graph(g) @@ -255,7 +273,8 @@ class DenseSplitHandler(InequalitySplitHandler): next_stamp_token, self._multiclass_strategy, class_id, self._feature_column_group_id, self._l1_regularization, self._l2_regularization, self._tree_complexity_regularization, - self._min_node_weight)) + self._min_node_weight, self._loss_uses_sum_reduction, + self._weak_learner_type)) return are_splits_ready, partition_ids, gains, split_infos @@ -263,7 +282,8 @@ def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle, stamp_token, next_stamp_token, multiclass_strategy, class_id, feature_column_id, l1_regularization, l2_regularization, tree_complexity_regularization, - min_node_weight, is_multi_dimentional): + min_node_weight, is_multi_dimentional, + loss_uses_sum_reduction, weak_learner_type): """Function that builds splits for a dense feature column.""" # Get the bucket boundaries are_splits_ready, buckets = ( @@ -291,7 +311,10 @@ def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle, num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( gen_stats_accumulator_ops.stats_accumulator_scalar_flush( stats_accumulator_handle, stamp_token, next_stamp_token)) - + # For sum_reduction, we don't need to divide by number of minibatches. + num_minibatches = control_flow_ops.cond(loss_uses_sum_reduction, + lambda: math_ops.to_int64(1), + lambda: num_minibatches) # Put quantile and stats accumulator flushing in the dependency path. with ops.control_dependencies([flush_quantiles, partition_ids]): are_splits_ready = array_ops.identity(are_splits_ready) @@ -309,7 +332,8 @@ def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle, l2_regularization=l2_regularization, tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, - multiclass_strategy=multiclass_strategy)) + multiclass_strategy=multiclass_strategy, + weak_learner_type=weak_learner_type)) return are_splits_ready, partition_ids, gains, split_infos @@ -329,6 +353,7 @@ class SparseSplitHandler(InequalitySplitHandler): hessian_shape, multiclass_strategy, init_stamp_token=0, + loss_uses_sum_reduction=False, name=None): """Initialize the internal state for this split handler. @@ -348,6 +373,8 @@ class SparseSplitHandler(InequalitySplitHandler): multiclass_strategy: Strategy describing how to treat multiclass problems. init_stamp_token: A tensor containing an scalar for initial stamp of the stamped objects. + loss_uses_sum_reduction: A scalar boolean tensor that specifies whether + SUM or MEAN reduction was used for the loss. name: An optional handler name. """ super(SparseSplitHandler, self).__init__( @@ -362,6 +389,7 @@ class SparseSplitHandler(InequalitySplitHandler): hessian_shape=hessian_shape, multiclass_strategy=multiclass_strategy, init_stamp_token=init_stamp_token, + loss_uses_sum_reduction=loss_uses_sum_reduction, name=name) self._sparse_float_column = sparse_float_column @@ -424,15 +452,15 @@ class SparseSplitHandler(InequalitySplitHandler): next_stamp_token, self._multiclass_strategy, class_id, self._feature_column_group_id, self._l1_regularization, self._l2_regularization, self._tree_complexity_regularization, - self._min_node_weight)) + self._min_node_weight, self._loss_uses_sum_reduction)) return are_splits_ready, partition_ids, gains, split_infos -def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle, - stamp_token, next_stamp_token, multiclass_strategy, - class_id, feature_column_id, l1_regularization, - l2_regularization, tree_complexity_regularization, - min_node_weight, is_multi_dimentional): +def _make_sparse_split( + quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional, loss_uses_sum_reduction): """Function that builds splits for a sparse feature column.""" # Get the bucket boundaries are_splits_ready, buckets = ( @@ -460,7 +488,9 @@ def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle, num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( gen_stats_accumulator_ops.stats_accumulator_scalar_flush( stats_accumulator_handle, stamp_token, next_stamp_token)) - + num_minibatches = control_flow_ops.cond(loss_uses_sum_reduction, + lambda: math_ops.to_int64(1), + lambda: num_minibatches) # Put quantile and stats accumulator flushing in the dependency path. with ops.control_dependencies([flush_quantiles, partition_ids]): are_splits_ready = array_ops.identity(are_splits_ready) @@ -483,7 +513,7 @@ def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle, return are_splits_ready, partition_ids, gains, split_infos -def _specialize_make_split(func, is_multi_dimentional): +def _specialize_make_split_dense(func, is_multi_dimentional): """Builds a specialized version of the function.""" @function.Defun( @@ -498,29 +528,65 @@ def _specialize_make_split(func, is_multi_dimentional): dtypes.float32, dtypes.float32, dtypes.float32, + dtypes.bool, + dtypes.int32, noinline=True) def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token, next_stamp_token, multiclass_strategy, class_id, feature_column_id, l1_regularization, l2_regularization, tree_complexity_regularization, - min_node_weight): + min_node_weight, loss_uses_sum_reduction, weak_learner_type): """Function that builds splits for a sparse feature column.""" - return func( - quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + return func(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, class_id, + feature_column_id, l1_regularization, l2_regularization, + tree_complexity_regularization, min_node_weight, + is_multi_dimentional, loss_uses_sum_reduction, + weak_learner_type) + + return f + + +def _specialize_make_split_sparse(func, is_multi_dimentional): + """Builds a specialized version of the function.""" + + @function.Defun( + dtypes.resource, + dtypes.resource, + dtypes.int64, + dtypes.int64, + dtypes.int32, + dtypes.int32, + dtypes.int32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.bool, + noinline=True) + def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token, next_stamp_token, multiclass_strategy, class_id, feature_column_id, l1_regularization, l2_regularization, tree_complexity_regularization, - min_node_weight, is_multi_dimentional) + min_node_weight, loss_uses_sum_reduction): + """Function that builds splits for a sparse feature column.""" + return func(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, class_id, + feature_column_id, l1_regularization, l2_regularization, + tree_complexity_regularization, min_node_weight, + is_multi_dimentional, loss_uses_sum_reduction) return f -make_dense_split_scalar = _specialize_make_split(_make_dense_split, - is_multi_dimentional=False) -make_dense_split_tensor = _specialize_make_split(_make_dense_split, - is_multi_dimentional=True) -make_sparse_split_scalar = _specialize_make_split(_make_sparse_split, - is_multi_dimentional=False) -make_sparse_split_tensor = _specialize_make_split(_make_sparse_split, - is_multi_dimentional=True) +make_dense_split_scalar = _specialize_make_split_dense( + _make_dense_split, is_multi_dimentional=False) + +make_dense_split_tensor = _specialize_make_split_dense( + _make_dense_split, is_multi_dimentional=True) + +make_sparse_split_scalar = _specialize_make_split_sparse( + _make_sparse_split, is_multi_dimentional=False) +make_sparse_split_tensor = _specialize_make_split_sparse( + _make_sparse_split, is_multi_dimentional=True) @function.Defun( @@ -561,8 +627,10 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column, example_partition_ids, feature_ids, gradients, hessians = ( control_flow_ops.cond( - math_ops.logical_and(are_buckets_ready, is_active[0]), - ready_inputs_fn, not_ready_inputs_fn)) + math_ops.logical_and( + math_ops.logical_and(are_buckets_ready, + array_ops.size(quantile_buckets) > 0), + is_active[0]), ready_inputs_fn, not_ready_inputs_fn)) return (quantile_values, quantile_weights, example_partition_ids, feature_ids, gradients, hessians) @@ -656,8 +724,10 @@ def sparse_make_stats_update( lambda: handler_not_active)) example_partition_ids, feature_ids, gradients, hessians = ( - control_flow_ops.cond(are_buckets_ready, quantiles_ready, - quantiles_not_ready)) + control_flow_ops.cond( + math_ops.logical_and(are_buckets_ready, + array_ops.size(quantile_buckets) > 0), + quantiles_ready, quantiles_not_ready)) return (quantile_indices, quantile_values, quantile_shape, quantile_weights, example_partition_ids, feature_ids, gradients, hessians) diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index 2f2c2302113bf59d6a065d5005c934dc76c2148d..6572f2f414b5d6741f43ec9f79ac7f6ab0f22deb 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -182,6 +182,271 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.threshold, 0.00001) + def testObliviousFeatureSplitGeneration(self): + with self.test_session() as sess: + # The data looks like the following: + # Example | Gradients | Partition | Dense Quantile | + # i0 | (0.2, 0.12) | 0 | 2 | + # i1 | (-0.5, 0.07) | 0 | 2 | + # i2 | (1.2, 0.2) | 0 | 0 | + # i3 | (4.0, 0.13) | 1 | 1 | + dense_column = array_ops.constant([0.62, 0.62, 0.3, 0.52]) + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) + class_id = -1 + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + split_handler = ordinal_split_handler.DenseSplitHandler( + l1_regularization=0.1, + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., + epsilon=0.001, + num_quantiles=10, + feature_column_group_id=0, + dense_float_column=dense_column, + init_stamp_token=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + + with ops.control_dependencies([are_splits_ready]): + update_2 = split_handler.update_stats_sync( + 1, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_2]): + are_splits_ready2, partitions, gains, splits = ( + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) + are_splits_ready, are_splits_ready2, partitions, gains, splits = ( + sess.run([ + are_splits_ready, are_splits_ready2, partitions, gains, splits + ])) + + # During the first iteration, inequality split handlers are not going to + # have any splits. Make sure that we return not_ready in that case. + self.assertFalse(are_splits_ready) + self.assertTrue(are_splits_ready2) + + self.assertAllEqual([0, 1], partitions) + + oblivious_split_info = split_info_pb2.ObliviousSplitInfo() + oblivious_split_info.ParseFromString(splits[0]) + split_node = oblivious_split_info.split_node.dense_float_binary_split + + self.assertAllClose(0.3, split_node.threshold, 0.00001) + self.assertEqual(0, split_node.feature_column) + + # Check the split on partition 0. + # -(1.2 - 0.1) / (0.2 + 1) + expected_left_weight_0 = -0.9166666666666666 + + # expected_left_weight_0 * -(1.2 - 0.1) + expected_left_gain_0 = 1.008333333333333 + + # (-0.5 + 0.2 + 0.1) / (0.19 + 1) + expected_right_weight_0 = 0.1680672 + + # expected_right_weight_0 * -(-0.5 + 0.2 + 0.1)) + expected_right_gain_0 = 0.033613445378151252 + + # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1) + expected_bias_gain_0 = 0.46043165467625896 + + left_child = oblivious_split_info.children_leaves[0].vector + right_child = oblivious_split_info.children_leaves[1].vector + + self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight_0], right_child.value, 0.00001) + + # Check the split on partition 1. + expected_left_weight_1 = 0 + expected_left_gain_1 = 0 + # -(4 - 0.1) / (0.13 + 1) + expected_right_weight_1 = -3.4513274336283186 + # expected_right_weight_1 * -(4 - 0.1) + expected_right_gain_1 = 13.460176991150442 + # (-4 + 0.1) ** 2 / (0.13 + 1) + expected_bias_gain_1 = 13.460176991150442 + + left_child = oblivious_split_info.children_leaves[2].vector + right_child = oblivious_split_info.children_leaves[3].vector + + self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight_1], right_child.value, 0.00001) + + # The layer gain is the sum of the gains of each partition + layer_gain = ( + expected_left_gain_0 + expected_right_gain_0 - expected_bias_gain_0) + ( + expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1) + self.assertAllClose(layer_gain, gains[0], 0.00001) + + def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self): + with self.test_session() as sess: + # The data looks like the following: + # Example | Gradients | Partition | Dense Quantile | + # i0 | (0.2, 0.12) | 0 | 1 | + # i1 | (-0.5, 0.07) | 0 | 1 | + # i2 | (1.2, 0.2) | 0 | 0 | + # i3 | (4.0, 0.13) | 1 | 1 | + dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52]) + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) + class_id = -1 + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + split_handler = ordinal_split_handler.DenseSplitHandler( + l1_regularization=0.2, + l2_regularization=2., + tree_complexity_regularization=0., + min_node_weight=0., + epsilon=0.001, + num_quantiles=10, + feature_column_group_id=0, + dense_float_column=dense_column, + init_stamp_token=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + loss_uses_sum_reduction=True) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + + with ops.control_dependencies([are_splits_ready]): + update_2 = split_handler.update_stats_sync( + 1, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + update_3 = split_handler.update_stats_sync( + 1, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_2, update_3]): + are_splits_ready2, partitions, gains, splits = ( + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) + are_splits_ready, are_splits_ready2, partitions, gains, splits = ( + sess.run([ + are_splits_ready, are_splits_ready2, partitions, gains, splits + ])) + + # During the first iteration, inequality split handlers are not going to + # have any splits. Make sure that we return not_ready in that case. + self.assertFalse(are_splits_ready) + self.assertTrue(are_splits_ready2) + + self.assertAllEqual([0, 1], partitions) + + # Check the split on partition 0. + # -(2.4 - 0.2) / (0.4 + 2) + expected_left_weight = -0.91666 + + # expected_left_weight * -(2.4 - 0.2) + expected_left_gain = 2.016666666666666 + + # -(-1 + 0.4 + 0.2) / (0.38 + 2) + expected_right_weight = 0.1680672 + + # expected_right_weight * -(-1 + 0.4 + 0.2) + expected_right_gain = 0.0672268907563025 + + # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1) + expected_bias_gain = 0.9208633093525178 + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.dense_float_binary_split + self.assertAllClose( + expected_left_gain + expected_right_gain - expected_bias_gain, gains[0], + 0.00001) + + self.assertAllClose([expected_left_weight], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight], right_child.value, 0.00001) + + self.assertEqual(0, split_node.feature_column) + + self.assertAllClose(0.3, split_node.threshold, 0.00001) + + # Check the split on partition 1. + # (-8 + 0.2) / (0.26 + 2) + expected_left_weight = -3.4513274336283186 + expected_right_weight = 0 + + # Verify candidate for partition 1, there's only one active bucket here + # so zero gain is expected. + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[1]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.dense_float_binary_split + self.assertAllClose(0.0, gains[1], 0.00001) + + self.assertAllClose([expected_left_weight], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight], right_child.value, 0.00001) + + self.assertEqual(0, split_node.feature_column) + + self.assertAllClose(0.52, split_node.threshold, 0.00001) + def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): with self.test_session() as sess: dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52]) @@ -798,11 +1063,144 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.split.threshold) + def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self): + with self.test_session() as sess: + # The data looks like the following: + # Example | Gradients | Partition | Sparse Quantile | + # i0 | (0.2, 0.12) | 0 | 1 | + # i1 | (-0.5, 0.07) | 0 | N/A | + # i2 | (1.2, 0.2) | 0 | 0 | + # i3 | (4.0, 0.13) | 1 | 1 | + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + example_partitions = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) + indices = array_ops.constant([[0, 0], [2, 0], [3, 0]], dtype=dtypes.int64) + values = array_ops.constant([0.52, 0.3, 0.52]) + sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1]) + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + class_id = -1 + + split_handler = ordinal_split_handler.SparseSplitHandler( + l1_regularization=0.0, + l2_regularization=4.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, + epsilon=0.01, + num_quantiles=2, + feature_column_group_id=0, + sparse_float_column=sparse_column, + init_stamp_token=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + loss_uses_sum_reduction=True) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + example_partitions, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + with ops.control_dependencies([are_splits_ready]): + update_2 = split_handler.update_stats_sync( + 1, + example_partitions, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + update_3 = split_handler.update_stats_sync( + 1, + example_partitions, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_2, update_3]): + are_splits_ready2, partitions, gains, splits = ( + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) + are_splits_ready, are_splits_ready2, partitions, gains, splits = ( + sess.run([ + are_splits_ready, are_splits_ready2, partitions, gains, splits + ])) + + # During the first iteration, inequality split handlers are not going to + # have any splits. Make sure that we return not_ready in that case. + self.assertFalse(are_splits_ready) + self.assertTrue(are_splits_ready2) + + self.assertAllEqual([0, 1], partitions) + # Check the split on partition 0. + # -(0.4 + 2.4) / (0.24 + 0.4 + 4) + expected_left_weight = -0.603448275862069 + # (0.4 + 2.4) ** 2 / (0.24 + 0.4 + 4) + expected_left_gain = 1.689655172413793 + # 1 / (0.14 + 4) + expected_right_weight = 0.24154589371980678 + # 1 ** 2 / (0.14 + 4) + expected_right_gain = 0.24154589371980678 + # (0.4 + 2.4 - 1) ** 2 / (0.24 + 0.4 + 0.14 + 4) + expected_bias_gain = 0.6778242677824265 + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.sparse_float_binary_split_default_right + self.assertAllClose( + expected_left_gain + expected_right_gain - expected_bias_gain, gains[0]) + + self.assertAllClose([expected_left_weight], left_child.value) + + self.assertAllClose([expected_right_weight], right_child.value) + + self.assertEqual(0, split_node.split.feature_column) + + self.assertAllClose(0.52, split_node.split.threshold) + + # Check the split on partition 1. + expected_left_weight = -1.8779342723004695 + expected_right_weight = 0 + + # Verify candidate for partition 1, there's only one active bucket here + # so zero gain is expected. + split_info.ParseFromString(splits[1]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.sparse_float_binary_split_default_left + + self.assertAllClose(0.0, gains[1]) + + self.assertAllClose([expected_left_weight], left_child.value) + + self.assertAllClose([expected_right_weight], right_child.value) + + self.assertEqual(0, split_node.split.feature_column) + + self.assertAllClose(0.52, split_node.split.threshold) + def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): with self.test_session() as sess: # Batch is 4, 2 classes - gradients = array_ops.constant( - [[0.2, 1.4], [-0.5, 0.1], [1.2, 3], [4.0, -3]]) + gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3], + [4.0, -3]]) # 2x2 matrix for each instance hessian_0 = [[0.12, 0.02], [0.3, 0.11]] hessian_1 = [[0.07, -0.2], [-0.5, 0.2]] @@ -896,8 +1294,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self): with self.test_session() as sess: # Batch is 4, 2 classes - gradients = array_ops.constant( - [[0.2, 1.4], [-0.5, 0.1], [1.2, 3], [4.0, -3]]) + gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3], + [4.0, -3]]) # Each hessian is a diagonal from a full hessian matrix. hessian_0 = [0.12, 0.11] hessian_1 = [0.07, 0.2] @@ -1135,6 +1533,100 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(gains), 0) self.assertEqual(len(splits), 0) + def testEmptyBuckets(self): + """Test that reproduces the case when quantile buckets were empty.""" + with self.test_session() as sess: + sparse_column = array_ops.sparse_placeholder(dtypes.float32) + + # We have two batches - at first, a sparse feature is empty. + empty_indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2]) + empty_values = array_ops.constant([], dtype=dtypes.float32) + empty_sparse_column = sparse_tensor.SparseTensor(empty_indices, + empty_values, [4, 2]) + empty_sparse_column = empty_sparse_column.eval(session=sess) + + # For the second batch, the sparse feature is not empty. + non_empty_indices = array_ops.constant( + [[0, 0], [2, 1], [3, 2]], dtype=dtypes.int64, shape=[3, 2]) + non_empty_values = array_ops.constant( + [0.52, 0.3, 0.52], dtype=dtypes.float32) + non_empty_sparse_column = sparse_tensor.SparseTensor( + non_empty_indices, non_empty_values, [4, 2]) + non_empty_sparse_column = non_empty_sparse_column.eval(session=sess) + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + class_id = -1 + + split_handler = ordinal_split_handler.SparseSplitHandler( + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, + epsilon=0.01, + num_quantiles=2, + feature_column_group_id=0, + sparse_float_column=sparse_column, + init_stamp_token=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS) + resources.initialize_resources(resources.shared_resources()).run() + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + + # First, calculate quantiles and try to update on an empty data for a + # feature. + are_splits_ready = ( + sess.run( + are_splits_ready, + feed_dict={sparse_column: empty_sparse_column})) + self.assertFalse(are_splits_ready) + + update_2 = split_handler.update_stats_sync( + 1, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_2]): + are_splits_ready2, partitions, gains, splits = ( + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) + + # Now the feature in the second batch is not empty, but buckets + # calculated on the first batch are empty. + are_splits_ready2, partitions, gains, splits = ( + sess.run( + [are_splits_ready2, partitions, gains, splits], + feed_dict={sparse_column: non_empty_sparse_column})) + self.assertFalse(are_splits_ready) + self.assertTrue(are_splits_ready2) + # Since the buckets were empty, we can't calculate the splits. + self.assertEqual(len(partitions), 0) + self.assertEqual(len(gains), 0) + self.assertEqual(len(splits), 0) + def testDegenerativeCase(self): with self.test_session() as sess: # One data example only, one leaf and thus one quantile bucket.The same diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h index c120dd8a6c156ec9eb7ba0b6c552f5138bd21a16..f19e5116f5865777ab65e1add2777ac41105acc0 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h @@ -58,6 +58,8 @@ namespace quantiles { // Compute: O(n * log(1/eps * log(eps * n))). // Memory: O(1/eps * log^2(eps * n)) <- for one worker streaming through the // entire dataset. +// An epsilon value of zero would make the algorithm extremely inefficent and +// therefore, is disallowed. template > class WeightedQuantilesStream { @@ -69,6 +71,9 @@ class WeightedQuantilesStream { explicit WeightedQuantilesStream(double eps, int64 max_elements) : eps_(eps), buffer_(1LL, 2LL), finalized_(false) { + // See the class documentation. An epsilon value of zero could cause + // perfoamance issues. + QCHECK(eps > 0) << "An epsilon value of zero is not allowed."; std::tie(max_levels_, block_size_) = GetQuantileSpecs(eps, max_elements); buffer_ = Buffer(block_size_, max_elements); summary_levels_.reserve(max_levels_); diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h index a7e7bfc13cadcea4d29d33e0dbd955bdad6ffcb9..69bb8fd4ada861a42a0ccc3f287a47d91be5c879 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h @@ -51,7 +51,7 @@ class WeightedQuantilesSummary { SummaryEntry() { memset(this, 0, sizeof(*this)); - value = 0; + value = ValueType(); weight = 0; min_rank = 0; max_rank = 0; diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc index 35b059f3496dbc8fb2b3d4fe6ec6b55a9d73dd0c..4fab2b0b7deb6ff2e353d758dc068aa28d44d5ae 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc @@ -16,6 +16,7 @@ #include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h" #include "tensorflow/contrib/boosted_trees/lib/utils/macros.h" #include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { namespace boosted_trees { @@ -96,9 +97,11 @@ Status BatchFeatures::Initialize( "Sparse float feature shape incompatible with batch size.")); auto tensor_shape = TensorShape({shape_flat(0), shape_flat(1)}); auto order_dims = sparse::SparseTensor::VarDimArray({0, 1}); - sparse_float_feature_columns_.emplace_back(sparse_float_feature_indices, - sparse_float_feature_values, - tensor_shape, order_dims); + sparse::SparseTensor sparse_tensor; + TF_RETURN_IF_ERROR(sparse::SparseTensor::Create( + sparse_float_feature_indices, sparse_float_feature_values, tensor_shape, + order_dims, &sparse_tensor)); + sparse_float_feature_columns_.push_back(std::move(sparse_tensor)); } // Read sparse int features. @@ -136,9 +139,11 @@ Status BatchFeatures::Initialize( "Sparse int feature shape incompatible with batch size.")); auto tensor_shape = TensorShape({shape_flat(0), shape_flat(1)}); auto order_dims = sparse::SparseTensor::VarDimArray({0, 1}); - sparse_int_feature_columns_.emplace_back(sparse_int_feature_indices, - sparse_int_feature_values, - tensor_shape, order_dims); + sparse::SparseTensor sparse_tensor; + TF_RETURN_IF_ERROR(sparse::SparseTensor::Create( + sparse_int_feature_indices, sparse_int_feature_values, tensor_shape, + order_dims, &sparse_tensor)); + sparse_int_feature_columns_.push_back(std::move(sparse_tensor)); } return Status::OK(); } diff --git a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc index d8a608864834b17886313a368221fbf94e31c98e..30c37435fe16ef29a9e29202850501098e9ac7f8 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc @@ -43,27 +43,35 @@ TEST_F(ExamplesIterableTest, Iterate) { test::AsTensor({0, 0, 2, 0, 3, 0, 4, 0}, {4, 2}); auto sparse_float_values1 = test::AsTensor({-3.0f, 0.0f, 5.0f, 0.0f}); auto sparse_float_shape1 = TensorShape({8, 1}); - sparse::SparseTensor sparse_float_tensor1( - sparse_float_indices1, sparse_float_values1, sparse_float_shape1); + sparse::SparseTensor sparse_float_tensor1; + TF_ASSERT_OK( + sparse::SparseTensor::Create(sparse_float_indices1, sparse_float_values1, + sparse_float_shape1, &sparse_float_tensor1)); auto sparse_float_indices2 = test::AsTensor( {0, 1, 1, 0, 2, 1, 3, 0, 4, 1, 5, 0, 5, 1, 7, 0}, {8, 2}); auto sparse_float_values2 = test::AsTensor({1.f, 4.0f, 3.f, 7.0f, 4.3f, 9.0f, 0.8f, -4.0f}); auto sparse_float_shape2 = TensorShape({8, 2}); - sparse::SparseTensor sparse_float_tensor2( - sparse_float_indices2, sparse_float_values2, sparse_float_shape2); + sparse::SparseTensor sparse_float_tensor2; + TF_ASSERT_OK( + sparse::SparseTensor::Create(sparse_float_indices2, sparse_float_values2, + sparse_float_shape2, &sparse_float_tensor2)); auto sparse_int_indices1 = test::AsTensor({0, 0, 0, 1, 1, 0, 3, 0, 3, 1, 7, 0}, {6, 2}); auto sparse_int_values1 = test::AsTensor({1, 8, 0, 2, 0, 5}); auto sparse_int_shape1 = TensorShape({8, 2}); - sparse::SparseTensor sparse_int_tensor1( - sparse_int_indices1, sparse_int_values1, sparse_int_shape1); + sparse::SparseTensor sparse_int_tensor1; + TF_ASSERT_OK( + sparse::SparseTensor::Create(sparse_int_indices1, sparse_int_values1, + sparse_int_shape1, &sparse_int_tensor1)); auto sparse_int_indices2 = test::AsTensor({1, 0, 2, 0, 3, 0, 4, 0}, {4, 2}); auto sparse_int_values2 = test::AsTensor({7, 13, 4, 0}); auto sparse_int_shape2 = TensorShape({8, 1}); - sparse::SparseTensor sparse_int_tensor2( - sparse_int_indices2, sparse_int_values2, sparse_int_shape2); + sparse::SparseTensor sparse_int_tensor2; + TF_ASSERT_OK( + sparse::SparseTensor::Create(sparse_int_indices2, sparse_int_values2, + sparse_int_shape2, &sparse_int_tensor2)); auto validate_example_features = [](int64 example_idx, const Example& example) { diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc index ca5c7f3d8c78a543c63fbfa9f7eb7c3d348f11b8..9b68a9de96ec8f6c7679410ca8a468978f2149e6 100644 --- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc @@ -36,6 +36,7 @@ REGISTER_OP("BuildDenseInequalitySplits") .Input("tree_complexity_regularization: float") .Input("min_node_weight: float") .Input("multiclass_strategy: int32") + .Input("weak_learner_type: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -84,6 +85,8 @@ min_node_weight: A scalar, minimum sum of example hessian needed in a child. be considered. multiclass_strategy: A scalar, specifying the multiclass handling strategy. See LearnerConfig.MultiClassStrategy for valid values. +weak_learner_type: A scalar, specifying the weak learner type to use. + See LearnerConfig.WeakLearnerType for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. diff --git a/tensorflow/contrib/boosted_trees/ops/training_ops.cc b/tensorflow/contrib/boosted_trees/ops/training_ops.cc index f63c199ad6146c23c22437ffe2287a77ee91ca44..22ac9edb72ea91ecef6fd1dff9f399b3c9020083 100644 --- a/tensorflow/contrib/boosted_trees/ops/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/training_ops.cc @@ -56,6 +56,7 @@ REGISTER_OP("GrowTreeEnsemble") .Input("next_stamp_token: int64") .Input("learning_rate: float") .Input("dropout_seed: int64") + .Input("max_tree_depth: int32") .Input("partition_ids: num_handlers * int32") .Input("gains: num_handlers * float") .Input("splits: num_handlers * string") @@ -67,6 +68,8 @@ REGISTER_OP("GrowTreeEnsemble") TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_input)); // Dropout seed. TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_input)); + // Maximum tree depth. + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_input)); return Status::OK(); }) .Doc(R"doc( diff --git a/tensorflow/contrib/boosted_trees/proto/learner.proto b/tensorflow/contrib/boosted_trees/proto/learner.proto index d84ba7438e7f03685d5bafca52ff8283f0fce898..c49cb48cdea6d8c85588f4c3c2bda6faf7e125db 100644 --- a/tensorflow/contrib/boosted_trees/proto/learner.proto +++ b/tensorflow/contrib/boosted_trees/proto/learner.proto @@ -108,6 +108,11 @@ message LearnerConfig { DIAGONAL_HESSIAN = 3; } + enum WeakLearnerType { + NORMAL_DECISION_TREE = 0; + OBLIVIOUS_DECISION_TREE = 1; + } + // Number of classes. uint32 num_classes = 1; @@ -141,4 +146,7 @@ message LearnerConfig { // If you want to average the ensembles (for regularization), provide the // config below. AveragingConfig averaging_config = 11; + + // By default we use NORMAL_DECISION_TREE as weak learner. + WeakLearnerType weak_learner_type = 12; } diff --git a/tensorflow/contrib/boosted_trees/proto/split_info.proto b/tensorflow/contrib/boosted_trees/proto/split_info.proto index a300c24c8ec507dea0af662b2361d408a2085237..850340f5c2096ca674616254de45d96b84200a64 100644 --- a/tensorflow/contrib/boosted_trees/proto/split_info.proto +++ b/tensorflow/contrib/boosted_trees/proto/split_info.proto @@ -17,3 +17,10 @@ message SplitInfo { // Right Leaf node. tensorflow.boosted_trees.trees.Leaf right_child = 3; } + +message ObliviousSplitInfo { + // The split node with the feature_column and threshold defined. + tensorflow.boosted_trees.trees.TreeNode split_node = 1; + // The new leaves of the tree. + repeated tensorflow.boosted_trees.trees.Leaf children_leaves = 2; +} diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py index 5cd37ec67ec3bdefb6ea19049a7a12249162d45a..2589504762787deaf598777650b8372320824c22 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py @@ -59,7 +59,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): min_node_weight=0, class_id=-1, feature_column_group_id=0, - multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)) partitions, gains, splits = sess.run([partitions, gains, splits]) self.assertAllEqual([0, 1], partitions) @@ -132,7 +133,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): min_node_weight=0, class_id=-1, feature_column_group_id=0, - multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN)) + multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)) partitions, gains, splits = sess.run([partitions, gains, splits]) self.assertAllEqual([0, 1], partitions) @@ -171,7 +173,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): min_node_weight=0, class_id=-1, feature_column_group_id=0, - multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)) partitions, gains, splits = sess.run([partitions, gains, splits]) # .assertEmpty doesn't exist on ubuntu-contrib self.assertEqual(0, len(partitions)) diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index 3e524efbeac74ff754d63cae92b3e194411cb2de..e39e1de8d1954c7f4dcab87d7727a64affa13c8c 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -296,7 +296,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, # Dropout does not change anything here, tree is not finalized. - dropout_probability=0.5).SerializeToString() + dropout_probability=0.5) # Prepare handler inputs. # Note that handlers 1 & 3 have the same gain but different splits. @@ -321,9 +321,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the simpler split from handler 1 to be chosen. @@ -443,7 +444,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, # Dropout does not change anything here - tree is not finalized. - dropout_probability=0.5).SerializeToString() + dropout_probability=0.5) # Prepare handler inputs. # Handler 1 only has a candidate for partition 1, handler 2 has candidates @@ -472,9 +473,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the split for partition 1 to be chosen from handler 1 and @@ -632,8 +634,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=1, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) @@ -657,9 +658,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect a new tree to be added with the split from handler 1. @@ -773,8 +775,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=1, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. # All handlers have negative gain. @@ -794,9 +795,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions, handler2_partitions], gains=[handler1_gains, handler2_gains], splits=[handler1_split, handler2_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the ensemble to be empty. @@ -839,8 +841,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=1, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. # Note that handlers 1 & 3 have the same gain but different splits. @@ -865,9 +866,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the simpler split from handler 1 to be chosen. @@ -946,8 +948,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=2, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. # All handlers have negative gain. @@ -967,9 +968,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions, handler2_partitions], gains=[handler1_gains, handler2_gains], splits=[handler1_split, handler2_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the split from handler 2 to be chosen despite the negative gain. @@ -1048,9 +1050,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions], gains=[handler1_gains], splits=[handler1_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the ensemble to be empty as post-pruning will prune @@ -1094,8 +1097,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=2, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. # Second handler has positive gain. @@ -1115,9 +1117,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions, handler2_partitions], gains=[handler1_gains, handler2_gains], splits=[handler1_split, handler2_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the split from handler 2 to be chosen despite the negative gain. @@ -1194,9 +1197,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions], gains=[handler1_gains], splits=[handler1_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the negative gain split of partition 1 to be pruned and the @@ -1335,7 +1339,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER, # Dropout will have no effect, since the tree will not be fully grown. - dropout_probability=1.0).SerializeToString() + dropout_probability=1.0) # Prepare handler inputs. # Handler 1 only has a candidate for partition 1, handler 2 has candidates @@ -1364,9 +1368,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the split for partition 1 to be chosen from handler 1 and @@ -1543,7 +1548,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, - dropout_probability=1.0).SerializeToString() + dropout_probability=1.0) # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) @@ -1567,9 +1572,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect a new tree to be added with the split from handler 1. @@ -1669,7 +1675,6 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) learner_config.constraints.max_number_of_unique_feature_columns = 3 - learner_config = learner_config.SerializeToString() # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) handler1_gains = np.array([7.62], dtype=np.float32) @@ -1692,9 +1697,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) _, serialized = session.run( diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 47698d45c81478f2b694aaadc603f742c44d5351..d0d1249bd6afc9cdbf6d88298c5024a4a54a5073 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -46,10 +46,12 @@ from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import device_setter + # Key names for prediction dict. ENSEMBLE_STAMP = "ensemble_stamp" PREDICTIONS = "predictions" @@ -61,6 +63,13 @@ USED_HANDLERS_MASK = "used_handlers_mask" LEAF_INDEX = "leaf_index" _FEATURE_NAME_TEMPLATE = "%s_%d" +# Keys in Training state. +GBDTTrainingState = collections.namedtuple("GBDTTrainingState", [ + "num_layer_examples", "num_layer_steps", "num_layers", "active_tree", + "active_layer", "continue_centering", "bias_stats_accumulator", + "steps_accumulator", "handlers" +]) + def _get_column_by_index(tensor, indices): """Returns columns from a 2-D tensor by index.""" @@ -276,9 +285,11 @@ class GradientBoostedDecisionTreeModel(object): learner_config, features, logits_dimension, + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS, feature_columns=None, use_core_columns=False, - output_leaf_index=False): + output_leaf_index=False, + output_leaf_index_modes=None): """Construct a new GradientBoostedDecisionTreeModel function. Args: @@ -292,9 +303,15 @@ class GradientBoostedDecisionTreeModel(object): learner_config: A learner config. features: `dict` of `Tensor` objects. logits_dimension: An int, the dimension of logits. + loss_reduction: Either `SUM_OVER_NONZERO_WEIGHTS` (mean) or `SUM`. feature_columns: A list of feature columns. + use_core_columns: A boolean specifying whether core feature columns are + used. output_leaf_index: A boolean variable indicating whether to output leaf index into predictions dictionary. + output_leaf_index_modes: A list of modes from (TRAIN, EVAL, INFER) which + dictates when leaf indices will be outputted. By default, leaf indices + are only outputted in INFER mode. Raises: ValueError: if inputs are not valid. @@ -315,6 +332,13 @@ class GradientBoostedDecisionTreeModel(object): self._center_bias = center_bias self._examples_per_layer = examples_per_layer + # Check loss reduction value. + if (loss_reduction != losses.Reduction.SUM and + loss_reduction != losses.Reduction.SUM_OVER_NONZERO_WEIGHTS): + raise ValueError( + "Invalid loss reduction is provided: %s." % loss_reduction) + self._loss_reduction = loss_reduction + # Fill in the defaults. if (learner_config.multi_class_strategy == learner_pb2.LearnerConfig.MULTI_CLASS_STRATEGY_UNSPECIFIED): @@ -325,6 +349,22 @@ class GradientBoostedDecisionTreeModel(object): learner_config.multi_class_strategy = ( learner_pb2.LearnerConfig.DIAGONAL_HESSIAN) + if logits_dimension == 1 or learner_config.multi_class_strategy == ( + learner_pb2.LearnerConfig.TREE_PER_CLASS): + self._gradient_shape = tensor_shape.scalar() + self._hessian_shape = tensor_shape.scalar() + else: + if center_bias: + raise ValueError("Center bias should be False for multiclass.") + + self._gradient_shape = tensor_shape.TensorShape([logits_dimension]) + if (learner_config.multi_class_strategy == + learner_pb2.LearnerConfig.FULL_HESSIAN): + self._hessian_shape = tensor_shape.TensorShape( + ([logits_dimension, logits_dimension])) + else: + # Diagonal hessian strategy. + self._hessian_shape = tensor_shape.TensorShape(([logits_dimension])) if (learner_config.growing_mode == learner_pb2.LearnerConfig.GROWING_MODE_UNSPECIFIED): learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER @@ -344,6 +384,8 @@ class GradientBoostedDecisionTreeModel(object): self._learner_config = learner_config self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() + self._max_tree_depth = variables.Variable( + initial_value=self._learner_config.constraints.max_tree_depth) self._attempted_trees = variables.Variable( initial_value=array_ops.zeros([], dtypes.int64), trainable=False, @@ -359,6 +401,7 @@ class GradientBoostedDecisionTreeModel(object): sparse_int_values, sparse_int_shapes) = extract_features( features, self._feature_columns, use_core_columns) logging.info("Active Feature Columns: " + str(fc_names)) + logging.info("Learner config: " + str(learner_config)) self._fc_names = fc_names self._dense_floats = dense_floats self._sparse_float_indices = sparse_float_indices @@ -371,7 +414,16 @@ class GradientBoostedDecisionTreeModel(object): self._learner_config.multi_class_strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS and learner_config.num_classes == 2) + + if output_leaf_index_modes is None: + output_leaf_index_modes = [learn.ModeKeys.INFER] + elif not all( + mode in (learn.ModeKeys.TRAIN, learn.ModeKeys.EVAL, + learn.ModeKeys.INFER) for mode in output_leaf_index_modes): + raise ValueError("output_leaf_index_modes should only contain ModeKeys.") + self._output_leaf_index = output_leaf_index + self._output_leaf_index_modes = output_leaf_index_modes def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode): """Runs prediction and returns a dictionary of the prediction results. @@ -402,8 +454,7 @@ class GradientBoostedDecisionTreeModel(object): # the right stamp. with ops.control_dependencies(ensemble_stats): leaf_index = None - # Only used in infer (predict), not used in train and eval. - if self._output_leaf_index and mode == learn.ModeKeys.INFER: + if self._output_leaf_index and mode in self._output_leaf_index_modes: predictions, _, leaf_index = ( prediction_ops).gradient_trees_prediction_verbose( ensemble_handle, @@ -475,9 +526,6 @@ class GradientBoostedDecisionTreeModel(object): if not input_deps: raise ValueError("No input tensors for prediction.") - if any(i.device != input_deps[0].device for i in input_deps): - raise ValueError("All input tensors should be on the same device.") - # Get most current model stamp. ensemble_stamp = model_ops.tree_ensemble_stamp_token(self._ensemble_handle) @@ -522,17 +570,30 @@ class GradientBoostedDecisionTreeModel(object): return self._predict_and_return_dict(self._ensemble_handle, ensemble_stamp, mode) - def train(self, loss, predictions_dict, labels): - """Grows a new tree and adds it to the ensemble. + def _get_class_id(self, predictions_dict): + # Handle different multiclass strategies. + if (self._learner_config.multi_class_strategy == + learner_pb2.LearnerConfig.TREE_PER_CLASS and + self._logits_dimension != 1): + # Choose the class for which the tree is built (one vs rest). + return math_ops.to_int32( + predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension) + return constant_op.constant(-1, dtype=dtypes.int32) + + def update_stats(self, loss, predictions_dict): + """Update the accumulators with stats from this batch. Args: loss: A scalar tensor representing average loss of examples. predictions_dict: Dictionary of Rank 2 `Tensor` representing information about predictions per example. - labels: Rank 2 `Tensor` representing labels per example. Returns: - An op that adds a new tree to the ensemble. + Three values: + - An op that adds a new tree to the ensemble, and + - An op that increments the stamp but removes all the trees and resets + the handlers. This can be used to reset the state of the ensemble. + - A dict containing the training state. Raises: ValueError: if inputs are not valid. @@ -556,13 +617,10 @@ class GradientBoostedDecisionTreeModel(object): aggregation_method=None)[0] strategy = self._learner_config.multi_class_strategy - class_id = constant_op.constant(-1, dtype=dtypes.int32) + class_id = self._get_class_id(predictions_dict) # Handle different multiclass strategies. if strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS: # We build one vs rest trees. - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() - if self._logits_dimension == 1: # We have only 1 score, gradients is of shape [batch, 1]. hessians = gradients_impl.gradients( @@ -579,11 +637,6 @@ class GradientBoostedDecisionTreeModel(object): hessian_list = self._diagonal_hessian(gradients, predictions) # Assemble hessian list into a tensor. hessians = array_ops.stack(hessian_list, axis=1) - - # Choose the class for which the tree is built (one vs rest). - class_id = math_ops.to_int32( - predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension) - # Use class id tensor to get the column with that index from gradients # and hessians. squeezed_gradients = array_ops.squeeze( @@ -592,15 +645,10 @@ class GradientBoostedDecisionTreeModel(object): _get_column_by_index(hessians, class_id)) else: # Other multiclass strategies. - gradient_shape = tensor_shape.TensorShape([self._logits_dimension]) - if strategy == learner_pb2.LearnerConfig.FULL_HESSIAN: - hessian_shape = tensor_shape.TensorShape( - ([self._logits_dimension, self._logits_dimension])) hessian_list = self._full_hessian(gradients, predictions) else: # Diagonal hessian strategy. - hessian_shape = tensor_shape.TensorShape(([self._logits_dimension])) hessian_list = self._diagonal_hessian(gradients, predictions) squeezed_gradients = gradients @@ -608,7 +656,7 @@ class GradientBoostedDecisionTreeModel(object): squeezed_hessians = hessians # Get the weights for each example for quantiles calculation, - weights = self._get_weights(hessian_shape, squeezed_hessians) + weights = self._get_weights(self._hessian_shape, squeezed_hessians) # Create all handlers ensuring resources are evenly allocated across PS. fc_name_idx = 0 @@ -622,6 +670,8 @@ class GradientBoostedDecisionTreeModel(object): self._learner_config.regularization.tree_complexity, dtypes.float32) min_node_weight = constant_op.constant( self._learner_config.constraints.min_node_weight, dtypes.float32) + loss_uses_sum_reduction = self._loss_reduction == losses.Reduction.SUM + loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction) epsilon = 0.01 num_quantiles = 100 strategy_tensor = constant_op.constant(strategy) @@ -635,15 +685,18 @@ class GradientBoostedDecisionTreeModel(object): l2_regularization=l2_regularization, tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, - feature_column_group_id=dense_float_column_idx, + feature_column_group_id=constant_op.constant( + dense_float_column_idx), epsilon=epsilon, num_quantiles=num_quantiles, dense_float_column=self._dense_floats[dense_float_column_idx], name=fc_name, - gradient_shape=gradient_shape, - hessian_shape=hessian_shape, + gradient_shape=self._gradient_shape, + hessian_shape=self._hessian_shape, multiclass_strategy=strategy_tensor, - init_stamp_token=init_stamp_token)) + init_stamp_token=init_stamp_token, + loss_uses_sum_reduction=loss_uses_sum_reduction, + )) fc_name_idx += 1 # Create handlers for sparse float columns. @@ -655,7 +708,8 @@ class GradientBoostedDecisionTreeModel(object): l2_regularization=l2_regularization, tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, - feature_column_group_id=sparse_float_column_idx, + feature_column_group_id=constant_op.constant( + sparse_float_column_idx), epsilon=epsilon, num_quantiles=num_quantiles, sparse_float_column=sparse_tensor.SparseTensor( @@ -663,10 +717,11 @@ class GradientBoostedDecisionTreeModel(object): self._sparse_float_values[sparse_float_column_idx], self._sparse_float_shapes[sparse_float_column_idx]), name=fc_name, - gradient_shape=gradient_shape, - hessian_shape=hessian_shape, + gradient_shape=self._gradient_shape, + hessian_shape=self._hessian_shape, multiclass_strategy=strategy_tensor, - init_stamp_token=init_stamp_token)) + init_stamp_token=init_stamp_token, + loss_uses_sum_reduction=loss_uses_sum_reduction)) fc_name_idx += 1 # Create handlers for sparse int columns. @@ -678,32 +733,20 @@ class GradientBoostedDecisionTreeModel(object): l2_regularization=l2_regularization, tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, - feature_column_group_id=sparse_int_column_idx, + feature_column_group_id=constant_op.constant( + sparse_int_column_idx), sparse_int_column=sparse_tensor.SparseTensor( self._sparse_int_indices[sparse_int_column_idx], self._sparse_int_values[sparse_int_column_idx], self._sparse_int_shapes[sparse_int_column_idx]), name=fc_name, - gradient_shape=gradient_shape, - hessian_shape=hessian_shape, + gradient_shape=self._gradient_shape, + hessian_shape=self._hessian_shape, multiclass_strategy=strategy_tensor, - init_stamp_token=init_stamp_token)) + init_stamp_token=init_stamp_token, + loss_uses_sum_reduction=loss_uses_sum_reduction)) fc_name_idx += 1 - # Create steps accumulator. - steps_accumulator = stats_accumulator_ops.StatsAccumulator( - stamp_token=0, - gradient_shape=tensor_shape.scalar(), - hessian_shape=tensor_shape.scalar(), - name="StepsAccumulator") - - # Create bias stats accumulator. - bias_stats_accumulator = stats_accumulator_ops.StatsAccumulator( - stamp_token=0, - gradient_shape=gradient_shape, - hessian_shape=hessian_shape, - name="BiasAccumulator") - # Create ensemble stats variables. num_layer_examples = variables.Variable( initial_value=array_ops.zeros([], dtypes.int64), @@ -725,7 +768,23 @@ class GradientBoostedDecisionTreeModel(object): initial_value=array_ops.zeros([], dtypes.int64), name="active_layer", trainable=False) - + # Variable that becomes false once bias centering is done. + continue_centering = variables.Variable( + initial_value=self._center_bias, + name="continue_centering", + trainable=False) + # Create bias stats accumulator. + bias_stats_accumulator = stats_accumulator_ops.StatsAccumulator( + stamp_token=0, + gradient_shape=self._gradient_shape, + hessian_shape=self._hessian_shape, + name="BiasAccumulator") + # Create steps accumulator. + steps_accumulator = stats_accumulator_ops.StatsAccumulator( + stamp_token=0, + gradient_shape=tensor_shape.scalar(), + hessian_shape=tensor_shape.scalar(), + name="StepsAccumulator") # Create ensemble stats summaries. summary.scalar("layer_stats/num_examples", num_layer_examples) summary.scalar("layer_stats/num_steps", num_layer_steps) @@ -734,16 +793,13 @@ class GradientBoostedDecisionTreeModel(object): # Update bias stats. stats_update_ops = [] - continue_centering = variables.Variable( - initial_value=self._center_bias, - name="continue_centering", - trainable=False) + stats_update_ops.append( control_flow_ops.cond( continue_centering, - self._make_update_bias_stats_fn(ensemble_stamp, predictions, - gradients, bias_stats_accumulator), - control_flow_ops.no_op)) + self._make_update_bias_stats_fn( + ensemble_stamp, predictions, gradients, + bias_stats_accumulator), control_flow_ops.no_op)) # Update handler stats. handler_reads = collections.OrderedDict() @@ -800,8 +856,8 @@ class GradientBoostedDecisionTreeModel(object): lambda: active_handlers)) # Prepare empty gradients and hessians when handlers are not ready. - empty_hess_shape = [1] + hessian_shape.as_list() - empty_grad_shape = [1] + gradient_shape.as_list() + empty_hess_shape = [1] + self._hessian_shape.as_list() + empty_grad_shape = [1] + self._gradient_shape.as_list() empty_gradients = constant_op.constant( [], dtype=dtypes.float32, shape=empty_grad_shape) @@ -823,34 +879,86 @@ class GradientBoostedDecisionTreeModel(object): per_handler_updates, ensemble_stamp, worker_device) for update in update_results.values(): stats_update_ops += update + + training_state = GBDTTrainingState( + num_layer_examples=num_layer_examples, + num_layer_steps=num_layer_steps, + num_layers=num_layers, + active_tree=active_tree, + active_layer=active_layer, + continue_centering=continue_centering, + bias_stats_accumulator=bias_stats_accumulator, + steps_accumulator=steps_accumulator, + handlers=handlers) + + reset_op = control_flow_ops.no_op() + if self._is_chief: + # Advance the ensemble stamp to throw away staggered workers. + stamp_token, _ = model_ops.tree_ensemble_serialize(self._ensemble_handle) + next_stamp_token = stamp_token + 1 + + reset_ops = [] + for handler in handlers: + reset_ops.append(handler.reset(stamp_token, next_stamp_token)) + if self._center_bias: + reset_ops.append( + bias_stats_accumulator.flush(stamp_token, next_stamp_token)) + reset_ops.append(steps_accumulator.flush(stamp_token, next_stamp_token)) + reset_ops.append(self._finalized_trees.assign(0).op) + reset_ops.append(self._attempted_trees.assign(0).op) + reset_ops.append( + model_ops.tree_ensemble_deserialize( + self._ensemble_handle, + stamp_token=next_stamp_token, + tree_ensemble_config="", + name="reset_gbdt")) + + reset_op = control_flow_ops.group([reset_ops]) + + return stats_update_ops, reset_op, training_state + + def increment_step_counter_and_maybe_update_ensemble(self, predictions_dict, + training_state): + """Increments number of visited examples and grows the ensemble. + + If the number of visited examples reaches the target examples_per_layer, + ensemble is updated. + + Args: + predictions_dict: Dictionary of Rank 2 `Tensor` representing information + about predictions per example. + training_state: `dict` returned by update_stats. + + Returns: + An op that updates the counters and potientially grows the ensemble. + """ + batch_size = math_ops.cast( + array_ops.shape(predictions_dict[PREDICTIONS])[0], dtypes.float32) + ensemble_stamp = predictions_dict[ENSEMBLE_STAMP] # Accumulate a step after updating stats. - batch_size = math_ops.cast(array_ops.shape(labels)[0], dtypes.float32) - with ops.control_dependencies(stats_update_ops): - add_step_op = steps_accumulator.add(ensemble_stamp, [0], [[0, 0]], - [batch_size], [1.0]) - # Determine learning rate. - learning_rate_tuner = self._learner_config.learning_rate_tuner.WhichOneof( - "tuner") - if learning_rate_tuner == "fixed" or learning_rate_tuner == "dropout": - tuner = getattr(self._learner_config.learning_rate_tuner, - learning_rate_tuner) - learning_rate = tuner.learning_rate - else: - # TODO(nponomareva, soroush) do the line search. - raise ValueError("Line search learning rate is not yet supported.") + steps_accumulator = training_state.steps_accumulator + num_layer_examples = training_state.num_layer_examples + num_layer_steps = training_state.num_layer_steps + active_layer = training_state.active_layer + add_step_op = steps_accumulator.add( + ensemble_stamp, [0], [[0, 0]], [batch_size], [1.0]) # After adding the step, decide if further processing is needed. ensemble_update_ops = [add_step_op] + class_id = self._get_class_id(predictions_dict) + with ops.control_dependencies([add_step_op]): if self._is_chief: dropout_seed = predictions_dict[NUM_TREES_ATTEMPTED] # Get accumulated steps and examples for the current layer. - _, _, _, _, acc_examples, acc_steps = steps_accumulator.serialize() + _, _, _, _, acc_examples, acc_steps = ( + steps_accumulator.serialize()) acc_examples = math_ops.cast(acc_examples[0], dtypes.int64) acc_steps = math_ops.cast(acc_steps[0], dtypes.int64) - ensemble_update_ops.append(num_layer_examples.assign(acc_examples)) + ensemble_update_ops.append( + num_layer_examples.assign(acc_examples)) ensemble_update_ops.append(num_layer_steps.assign(acc_steps)) # Determine whether we need to update tree ensemble. examples_per_layer = self._examples_per_layer @@ -859,18 +967,177 @@ class GradientBoostedDecisionTreeModel(object): ensemble_update_ops.append( control_flow_ops.cond( acc_examples >= examples_per_layer, - self._make_update_ensemble_fn( - ensemble_stamp, steps_accumulator, bias_stats_accumulator, - continue_centering, learning_rate, handlers, num_layers, - active_tree, active_layer, dropout_seed, class_id), + self.make_update_ensemble_fn(ensemble_stamp, training_state, + dropout_seed, class_id), control_flow_ops.no_op)) - # Calculate the loss to be reported. # Note, the loss is calculated from the prediction considering dropouts, so # that the value might look staggering over steps when the dropout ratio is # high. eval_loss might be referred instead in the aspect of convergence. return control_flow_ops.group(*ensemble_update_ops) + def make_update_ensemble_fn(self, ensemble_stamp, training_state, + dropout_seed, class_id): + """A method to create the function which updates the tree ensemble.""" + # Determine learning rate. + learning_rate_tuner = self._learner_config.learning_rate_tuner.WhichOneof( + "tuner") + if learning_rate_tuner == "fixed" or learning_rate_tuner == "dropout": + tuner = getattr(self._learner_config.learning_rate_tuner, + learning_rate_tuner) + learning_rate = tuner.learning_rate + else: + # TODO(nponomareva, soroush) do the line search. + raise ValueError("Line search learning rate is not yet supported.") + + def _update_ensemble(): + """A method to update the tree ensemble.""" + # Get next stamp token. + next_ensemble_stamp = ensemble_stamp + 1 + # Finalize bias stats. + _, _, _, bias_grads, bias_hess = ( + training_state.bias_stats_accumulator.flush(ensemble_stamp, + next_ensemble_stamp)) + + # Finalize handler splits. + are_splits_ready_list = [] + partition_ids_list = [] + gains_list = [] + split_info_list = [] + + for handler in training_state.handlers: + (are_splits_ready, + partition_ids, gains, split_info) = handler.make_splits( + ensemble_stamp, next_ensemble_stamp, class_id) + are_splits_ready_list.append(are_splits_ready) + partition_ids_list.append(partition_ids) + gains_list.append(gains) + split_info_list.append(split_info) + # Stack all the inputs to one tensor per type. + # This is a workaround for the slowness of graph building in tf.cond. + # See (b/36554864). + split_sizes = array_ops.reshape( + array_ops.shape_n(partition_ids_list), [len(partition_ids_list)]) + partition_ids = array_ops.concat(partition_ids_list, axis=0) + gains = array_ops.concat(gains_list, axis=0) + split_infos = array_ops.concat(split_info_list, axis=0) + + # Determine if all splits are ready. + are_all_splits_ready = math_ops.reduce_all( + array_ops.stack( + are_splits_ready_list, axis=0, name="stack_handler_readiness")) + + # Define bias centering update operation. + def _center_bias_fn(): + # Center tree ensemble bias. + delta_updates = array_ops.where(bias_hess > 0, -bias_grads / bias_hess, + array_ops.zeros_like(bias_grads)) + center_bias = training_ops.center_tree_ensemble_bias( + tree_ensemble_handle=self._ensemble_handle, + stamp_token=ensemble_stamp, + next_stamp_token=next_ensemble_stamp, + delta_updates=delta_updates, + learner_config=self._learner_config_serialized) + return training_state.continue_centering.assign(center_bias) + + # Define ensemble growing operations. + def _grow_ensemble_ready_fn(): + # Grow the ensemble given the current candidates. + sizes = array_ops.unstack(split_sizes) + partition_ids_list = list(array_ops.split(partition_ids, sizes, axis=0)) + gains_list = list(array_ops.split(gains, sizes, axis=0)) + split_info_list = list(array_ops.split(split_infos, sizes, axis=0)) + return training_ops.grow_tree_ensemble( + tree_ensemble_handle=self._ensemble_handle, + stamp_token=ensemble_stamp, + next_stamp_token=next_ensemble_stamp, + learning_rate=learning_rate, + partition_ids=partition_ids_list, + gains=gains_list, + splits=split_info_list, + learner_config=self._learner_config_serialized, + dropout_seed=dropout_seed, + center_bias=self._center_bias, + max_tree_depth=self._max_tree_depth) + + def _grow_ensemble_not_ready_fn(): + # Don't grow the ensemble, just update the stamp. + return training_ops.grow_tree_ensemble( + tree_ensemble_handle=self._ensemble_handle, + stamp_token=ensemble_stamp, + next_stamp_token=next_ensemble_stamp, + learning_rate=0, + partition_ids=[], + gains=[], + splits=[], + learner_config=self._learner_config_serialized, + dropout_seed=dropout_seed, + center_bias=self._center_bias, + max_tree_depth=self._max_tree_depth) + + def _grow_ensemble_fn(): + # Conditionally grow an ensemble depending on whether the splits + # from all the handlers are ready. + return control_flow_ops.cond(are_all_splits_ready, + _grow_ensemble_ready_fn, + _grow_ensemble_not_ready_fn) + + # Update ensemble. + update_ops = [are_all_splits_ready] + if self._center_bias: + update_model = control_flow_ops.cond(training_state.continue_centering, + _center_bias_fn, _grow_ensemble_fn) + else: + update_model = _grow_ensemble_fn() + update_ops.append(update_model) + + # Update ensemble stats. + with ops.control_dependencies([update_model]): + stats = training_ops.tree_ensemble_stats( + self._ensemble_handle, stamp_token=next_ensemble_stamp) + update_ops.append(self._finalized_trees.assign(stats.num_trees)) + update_ops.append(self._attempted_trees.assign(stats.attempted_trees)) + update_ops.append(training_state.num_layers.assign(stats.num_layers)) + update_ops.append(training_state.active_tree.assign(stats.active_tree)) + update_ops.append( + training_state.active_layer.assign(stats.active_layer)) + + # Flush step stats. + update_ops.extend( + training_state.steps_accumulator.flush(ensemble_stamp, + next_ensemble_stamp)) + return control_flow_ops.group(*update_ops, name="update_ensemble") + + return _update_ensemble + + def get_number_of_trees_tensor(self): + return self._finalized_trees, self._attempted_trees + + def get_max_tree_depth(self): + return self._max_tree_depth + + def train(self, loss, predictions_dict, labels): + """Updates the accumalator stats and grows the ensemble. + + Args: + loss: A scalar tensor representing average loss of examples. + predictions_dict: Dictionary of Rank 2 `Tensor` representing information + about predictions per example. + labels: Rank 2 `Tensor` representing labels per example. Has no effect + on the training and is only kept for backward compatibility. + + Returns: + An op that adds a new tree to the ensemble. + + Raises: + ValueError: if inputs are not valid. + """ + del labels # unused; kept for backward compatibility. + update_op, _, training_state = self.update_stats(loss, predictions_dict) + with ops.control_dependencies(update_op): + return self.increment_step_counter_and_maybe_update_ensemble( + predictions_dict, training_state) + def _get_weights(self, hessian_shape, hessians): """Derives weights to be used based on hessians and multiclass strategy.""" if hessian_shape == tensor_shape.scalar(): @@ -986,127 +1253,3 @@ class GradientBoostedDecisionTreeModel(object): return control_flow_ops.group(*[add_stats_op], name="update_bias_stats") return _update_bias_stats - - def _make_update_ensemble_fn(self, ensemble_stamp, steps_accumulator, - bias_stats_accumulator, continue_centering, - learning_rate, handlers, num_layers, active_tree, - active_layer, dropout_seed, class_id): - """A method to create the function which updates the tree ensemble.""" - - def _update_ensemble(): - """A method to update the tree ensemble.""" - # Get next stamp token. - next_ensemble_stamp = ensemble_stamp + 1 - # Finalize bias stats. - _, _, _, bias_grads, bias_hess = bias_stats_accumulator.flush( - ensemble_stamp, next_ensemble_stamp) - - # Finalize handler splits. - are_splits_ready_list = [] - partition_ids_list = [] - gains_list = [] - split_info_list = [] - - for handler in handlers: - (are_splits_ready, - partition_ids, gains, split_info) = handler.make_splits( - ensemble_stamp, next_ensemble_stamp, class_id) - are_splits_ready_list.append(are_splits_ready) - partition_ids_list.append(partition_ids) - gains_list.append(gains) - split_info_list.append(split_info) - # Stack all the inputs to one tensor per type. - # This is a workaround for the slowness of graph building in tf.cond. - # See (b/36554864). - split_sizes = array_ops.reshape( - array_ops.shape_n(partition_ids_list), [len(partition_ids_list)]) - partition_ids = array_ops.concat(partition_ids_list, axis=0) - gains = array_ops.concat(gains_list, axis=0) - split_infos = array_ops.concat(split_info_list, axis=0) - - # Determine if all splits are ready. - are_all_splits_ready = math_ops.reduce_all( - array_ops.stack( - are_splits_ready_list, axis=0, name="stack_handler_readiness")) - - # Define bias centering update operation. - def _center_bias_fn(): - # Center tree ensemble bias. - delta_updates = array_ops.where(bias_hess > 0, -bias_grads / bias_hess, - array_ops.zeros_like(bias_grads)) - center_bias = training_ops.center_tree_ensemble_bias( - tree_ensemble_handle=self._ensemble_handle, - stamp_token=ensemble_stamp, - next_stamp_token=next_ensemble_stamp, - delta_updates=delta_updates, - learner_config=self._learner_config_serialized) - return continue_centering.assign(center_bias) - - # Define ensemble growing operations. - def _grow_ensemble_ready_fn(): - # Grow the ensemble given the current candidates. - sizes = array_ops.unstack(split_sizes) - partition_ids_list = list(array_ops.split(partition_ids, sizes, axis=0)) - gains_list = list(array_ops.split(gains, sizes, axis=0)) - split_info_list = list(array_ops.split(split_infos, sizes, axis=0)) - return training_ops.grow_tree_ensemble( - tree_ensemble_handle=self._ensemble_handle, - stamp_token=ensemble_stamp, - next_stamp_token=next_ensemble_stamp, - learning_rate=learning_rate, - partition_ids=partition_ids_list, - gains=gains_list, - splits=split_info_list, - learner_config=self._learner_config_serialized, - dropout_seed=dropout_seed, - center_bias=self._center_bias) - - def _grow_ensemble_not_ready_fn(): - # Don't grow the ensemble, just update the stamp. - return training_ops.grow_tree_ensemble( - tree_ensemble_handle=self._ensemble_handle, - stamp_token=ensemble_stamp, - next_stamp_token=next_ensemble_stamp, - learning_rate=0, - partition_ids=[], - gains=[], - splits=[], - learner_config=self._learner_config_serialized, - dropout_seed=dropout_seed, - center_bias=self._center_bias) - - def _grow_ensemble_fn(): - # Conditionally grow an ensemble depending on whether the splits - # from all the handlers are ready. - return control_flow_ops.cond(are_all_splits_ready, - _grow_ensemble_ready_fn, - _grow_ensemble_not_ready_fn) - - # Update ensemble. - update_ops = [are_all_splits_ready] - if self._center_bias: - update_model = control_flow_ops.cond(continue_centering, - _center_bias_fn, _grow_ensemble_fn) - else: - update_model = _grow_ensemble_fn() - update_ops.append(update_model) - - # Update ensemble stats. - with ops.control_dependencies([update_model]): - stats = training_ops.tree_ensemble_stats( - self._ensemble_handle, stamp_token=next_ensemble_stamp) - update_ops.append(self._finalized_trees.assign(stats.num_trees)) - update_ops.append(self._attempted_trees.assign(stats.attempted_trees)) - update_ops.append(num_layers.assign(stats.num_layers)) - update_ops.append(active_tree.assign(stats.active_tree)) - update_ops.append(active_layer.assign(stats.active_layer)) - - # Flush step stats. - update_ops.extend( - steps_accumulator.flush(ensemble_stamp, next_ensemble_stamp)) - return control_flow_ops.group(*update_ops, name="update_ensemble") - - return _update_ensemble - - def get_number_of_trees_tensor(self): - return self._finalized_trees, self._attempted_trees diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index e3d4397fadcbaf148f7f6cfaca13e850639786cf..f7867d882d6813a8701065ad0ce8d27f8bb9c301 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -29,6 +29,7 @@ from tensorflow.contrib.layers.python.layers import feature_column as feature_co from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -1560,6 +1561,301 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertEquals(output.growing_metadata.num_layers_attempted, 2) + def testResetModelBeforeAndAfterSplit(self): + """Tests whether resetting works.""" + with self.test_session(): + # First build a small tree and train it to verify training works. + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + features = {} + features["dense_float"] = array_ops.ones([4, 1], dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = model_ops.tree_ensemble_stamp_token(ensemble_handle) + + predictions_dict = { + "predictions": predictions, + "predictions_no_dropout": predictions, + "partition_ids": partition_ids, + "ensemble_stamp": ensemble_stamp, + "num_trees": 12, + "max_tree_depth": 4, + } + + labels = array_ops.ones([4, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + loss = math_ops.reduce_mean(_squared_loss(labels, weights, predictions)) + + # Create train op. + update_op, reset_op, training_state = gbdt_model.update_stats( + loss, predictions_dict) + with ops.control_dependencies(update_op): + train_op = gbdt_model.increment_step_counter_and_maybe_update_ensemble( + predictions_dict, training_state) + + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + original_stamp = ensemble_stamp.eval() + expected_tree = """ + nodes { + dense_float_binary_split { + threshold: 1.0 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 0 + } + } + nodes { + leaf { + vector { + value: 0.25 + } + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + }""" + + def _train_once_and_check(expect_split): + stamp = ensemble_stamp.eval() + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(stamp_token.eval(), stamp + 1) + if expect_split: + # State of the ensemble after a split occurs. + self.assertEquals(len(output.trees), 1) + self.assertProtoEquals(expected_tree, output.trees[0]) + else: + # State of the ensemble after a single accumulation but before any + # splitting occurs + self.assertEquals(len(output.trees), 0) + self.assertProtoEquals(""" + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 1 + }""", output) + + def _run_reset(): + stamp_before_reset = ensemble_stamp.eval() + reset_op.run() + stamp_after_reset = ensemble_stamp.eval() + self.assertNotEquals(stamp_after_reset, stamp_before_reset) + + _, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertProtoEquals("", output) + + return stamp_after_reset + + # Exit after one train_op, so no new layer are created but the handlers + # contain enough information to split on the next call to train. + _train_once_and_check(expect_split=False) + self.assertEquals(ensemble_stamp.eval(), original_stamp + 1) + + # Reset the handlers so it still requires two training calls to split. + stamp_after_reset = _run_reset() + + _train_once_and_check(expect_split=False) + _train_once_and_check(expect_split=True) + self.assertEquals(ensemble_stamp.eval(), stamp_after_reset + 2) + + # This time, test that the reset_op works right after splitting. + stamp_after_reset = _run_reset() + + # Test that after resetting, the tree can be trained as normal. + _train_once_and_check(expect_split=False) + _train_once_and_check(expect_split=True) + self.assertEquals(ensemble_stamp.eval(), stamp_after_reset + 2) + + def testResetModelNonChief(self): + """Tests the reset function on a non-chief worker.""" + with self.test_session(): + # Create ensemble with one bias node. + ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + text_format.Merge( + """ + trees { + nodes { + leaf { + vector { + value: 0.25 + } + } + } + } + tree_weights: 1.0 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + is_finalized: false + }""", ensemble_config) + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=ensemble_config.SerializeToString(), + name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + features = {} + features["dense_float"] = array_ops.ones([4, 1], dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=False, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = model_ops.tree_ensemble_stamp_token(ensemble_handle) + + predictions_dict = { + "predictions": predictions, + "predictions_no_dropout": predictions, + "partition_ids": partition_ids, + "ensemble_stamp": ensemble_stamp + } + + labels = array_ops.ones([4, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + loss = math_ops.reduce_mean(_squared_loss(labels, weights, predictions)) + + # Create reset op. + _, reset_op, _ = gbdt_model.update_stats( + loss, predictions_dict) + + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # Reset op doesn't do anything because this is a non-chief worker. + reset_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 1) + self.assertEquals(len(output.tree_weights), 1) + self.assertEquals(stamp_token.eval(), 0) + + def testResetModelWithCenterBias(self): + """Tests the reset function running on chief with bias centering.""" + with self.test_session(): + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float"] = array_ops.ones([4, 1], dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=True, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = model_ops.tree_ensemble_stamp_token(ensemble_handle) + + predictions_dict = { + "predictions": predictions, + "predictions_no_dropout": predictions, + "partition_ids": partition_ids, + "ensemble_stamp": ensemble_stamp, + "num_trees": 12, + } + + labels = array_ops.ones([4, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + loss = math_ops.reduce_mean(_squared_loss(labels, weights, predictions)) + + # Create train op. + update_op, reset_op, training_state = gbdt_model.update_stats( + loss, predictions_dict) + with ops.control_dependencies(update_op): + train_op = gbdt_model.increment_step_counter_and_maybe_update_ensemble( + predictions_dict, training_state) + + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect bias to be centered. + def train_and_check(): + train_op.run() + _, serialized = model_ops.tree_ensemble_serialize(ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + expected_tree = """ + nodes { + leaf { + vector { + value: 0.25 + } + } + }""" + self.assertEquals(len(output.trees), 1) + self.assertAllEqual(output.tree_weights, [1.0]) + self.assertProtoEquals(expected_tree, output.trees[0]) + + train_and_check() + self.assertEquals(ensemble_stamp.eval(), 1) + + reset_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 0) + self.assertEquals(len(output.tree_weights), 0) + self.assertEquals(stamp_token.eval(), 2) + + train_and_check() + self.assertEquals(ensemble_stamp.eval(), 3) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py index ab7ac2aba605db22a8ed370049b27d55cf1d413a..b5ebaf1999519f65110e8164fa20bace5ecc3ef6 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py @@ -23,6 +23,12 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn +from tensorflow.python.ops.losses import losses + + +def per_example_squared_hinge_loss(labels, weights, predictions): + loss = losses.hinge_loss(labels=labels, logits=predictions, weights=weights) + return math_ops.square(loss), control_flow_ops.no_op() def per_example_logistic_loss(labels, weights, predictions): @@ -126,7 +132,7 @@ def per_example_squared_loss(labels, weights, predictions): def per_example_exp_loss(labels, weights, predictions, name=None, eps=0.1): - """Exponential loss given labels, example weights and predictions. + """Trimmed exponential loss given labels, example weights and predictions. Note that this is only for binary classification. If logistic loss tries to make sure that the classifier is certain of its @@ -211,3 +217,62 @@ def per_example_exp_loss(labels, weights, predictions, name=None, eps=0.1): unweighted_loss = exp_with_logits( name=name, eps=eps, labels=labels, logits=predictions) return unweighted_loss * weights, control_flow_ops.no_op() + + +def per_example_full_exp_loss(labels, weights, predictions, name=None): + """Full exponential loss given labels, example weights and predictions. + + Note that this is only for binary classification. + The loss returns is exp(-targets*logits), where targets are converted to -1 + and 1. + + Args: + labels: Rank 2 (N, D) tensor of per-example labels. + weights: Rank 2 (N, 1) tensor of per-example weights. + predictions: Rank 2 (N, D) tensor of per-example predictions. + name: A name for the operation (optional). + + Returns: + loss: A Rank 2 (N, 1) tensor of per-example exp loss + update_op: An update operation to update the loss's internal state. + """ + + def full_exp_with_logits(name, labels=None, logits=None): + """Computes exponential loss given `logits`. + + Args: + name: A name for the operation (optional). + labels: A `Tensor` of the same type and shape as `logits`. + logits: A `Tensor` of type `float32` or `float64`. + + Returns: + A `Tensor` of the same shape as `logits` with the componentwise + exponential losses. + + Raises: + ValueError: If `logits` and `labels` do not have the same shape. + """ + with ops.name_scope(name, "exp_loss", [logits, labels]) as name: + logits = ops.convert_to_tensor(logits, name="logits") + labels = ops.convert_to_tensor(labels, name="labels") + try: + labels.get_shape().merge_with(logits.get_shape()) + except ValueError: + raise ValueError("logits and labels must have the same shape (%s vs %s)" + % (logits.get_shape(), labels.get_shape())) + + # Default threshold of 0 to switch between classes + zeros = array_ops.zeros_like(logits, dtype=logits.dtype) + ones = array_ops.ones_like(logits, dtype=logits.dtype) + neg_ones = -array_ops.ones_like(logits, dtype=logits.dtype) + + # Convert labels to 1 and -1 + cond_labels = (labels > zeros) + labels_converted = array_ops.where(cond_labels, ones, neg_ones) + + return math_ops.exp(-1.0 * logits * labels_converted) + + labels = math_ops.to_float(labels) + unweighted_loss = full_exp_with_logits( + name=name, labels=labels, logits=predictions) + return unweighted_loss * weights, control_flow_ops.no_op() diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 8ae493ba998bd882b5ef946f927ec1882d91f61d..e92f0bb841ac6dc57547874881af8bd10c47474f 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -16,10 +16,13 @@ Visualization and inspection: @@dot_graph_from_checkpoint +@@list_objects @@object_metadata Managing dependencies: +@@capture_dependencies @@Checkpointable +@@CheckpointableBase @@CheckpointableObjectGraph @@NoDependency @@split_dependency @@ -28,6 +31,9 @@ Checkpointable data structures: @@List @@Mapping @@UniqueNameTracker + +Checkpoint management: +@@CheckpointManager """ from __future__ import absolute_import @@ -38,13 +44,16 @@ from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph -from tensorflow.python.training.checkpointable.base import Checkpointable -from tensorflow.python.training.checkpointable.base import NoDependency +from tensorflow.python.training.checkpoint_management import CheckpointManager +from tensorflow.python.training.checkpointable.base import CheckpointableBase from tensorflow.python.training.checkpointable.data_structures import List from tensorflow.python.training.checkpointable.data_structures import Mapping +from tensorflow.python.training.checkpointable.data_structures import NoDependency +from tensorflow.python.training.checkpointable.tracking import Checkpointable +from tensorflow.python.training.checkpointable.util import capture_dependencies +from tensorflow.python.training.checkpointable.util import list_objects from tensorflow.python.training.checkpointable.util import object_metadata from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(module_name=__name__) - diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py index 4d3d5312993740636709cb732c0b8e3e2626262d..242c1e8ba45e0b2f6f9a1a51695b824546382666 100644 --- a/tensorflow/contrib/checkpoint/python/containers.py +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -35,9 +35,9 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure): self.slotdeps = tf.contrib.checkpoint.UniqueNameTracker() slotdeps = self.slotdeps slots = [] - slots.append(slotdeps.track(tfe.Variable(3.), "x")) # Named "x" - slots.append(slotdeps.track(tfe.Variable(4.), "y")) - slots.append(slotdeps.track(tfe.Variable(5.), "x")) # Named "x_1" + slots.append(slotdeps.track(tf.Variable(3.), "x")) # Named "x" + slots.append(slotdeps.track(tf.Variable(4.), "y")) + slots.append(slotdeps.track(tf.Variable(5.), "x")) # Named "x_1" ``` """ diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py index 3717d7f583ffdc205a279d45df60cddbc5cbf08e..ac85c7be803cd4c2f8ba19d3ef887a3c65a15933 100644 --- a/tensorflow/contrib/checkpoint/python/containers_test.py +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -26,13 +26,14 @@ from tensorflow.python.keras import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.checkpointable import util class UniqueNameTrackerTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNames(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") @@ -48,11 +49,11 @@ class UniqueNameTrackerTests(test.TestCase): slots.track(y, "y") self.evaluate((x1.initializer, x2.initializer, x3.initializer, y.initializer)) - save_root = checkpointable_utils.Checkpoint(slots=slots) + save_root = util.Checkpoint(slots=slots) save_path = save_root.save(checkpoint_prefix) - restore_slots = checkpointable.Checkpointable() - restore_root = checkpointable_utils.Checkpoint( + restore_slots = tracking.Checkpointable() + restore_root = util.Checkpoint( slots=restore_slots) status = restore_root.restore(save_path) restore_slots.x = resource_variable_ops.ResourceVariable(0.) @@ -65,9 +66,9 @@ class UniqueNameTrackerTests(test.TestCase): self.assertEqual(4., self.evaluate(restore_slots.x_1_1)) self.assertEqual(5., self.evaluate(restore_slots.y)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testExample(self): - class SlotManager(checkpointable.Checkpointable): + class SlotManager(tracking.Checkpointable): def __init__(self): self.slotdeps = containers.UniqueNameTracker() @@ -79,15 +80,15 @@ class UniqueNameTrackerTests(test.TestCase): resource_variable_ops.ResourceVariable(4.), "y")) slots.append(slotdeps.track( resource_variable_ops.ResourceVariable(5.), "x")) - self.slots = slots + self.slots = data_structures.NoDependency(slots) manager = SlotManager() self.evaluate([v.initializer for v in manager.slots]) - checkpoint = checkpointable_utils.Checkpoint(slot_manager=manager) + checkpoint = util.Checkpoint(slot_manager=manager) checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") save_path = checkpoint.save(checkpoint_prefix) - metadata = checkpointable_utils.object_metadata(save_path) + metadata = util.object_metadata(save_path) dependency_names = [] for node in metadata.nodes: for child in node.children: @@ -97,7 +98,7 @@ class UniqueNameTrackerTests(test.TestCase): dependency_names, ["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLayers(self): tracker = containers.UniqueNameTracker() tracker.track(layers.Dense(3), "dense") diff --git a/tensorflow/contrib/checkpoint/python/split_dependency_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py index 69dc0b9be2d5548852c37552a64a0d31c9557b43..00a805af25d5d0ea723db5d015fb12bf45c53857 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency_test.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py @@ -23,8 +23,9 @@ from tensorflow.python.eager import test from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.checkpointable import util def _split_variable_closure(variable): @@ -43,7 +44,7 @@ def _combine_variable_closure(variable): return _consume_restore_buffer_fn -class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase): +class SaveTensorSlicesAsDeps(base.CheckpointableBase): def __init__(self): self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.]) @@ -58,14 +59,14 @@ class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase): self._track_checkpointable(dep, name=name) -class HasRegularDeps(checkpointable.Checkpointable): +class HasRegularDeps(tracking.Checkpointable): def __init__(self): self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) self.second_half = resource_variable_ops.ResourceVariable([0., 0.]) -class OnlyOneDep(checkpointable.Checkpointable): +class OnlyOneDep(tracking.Checkpointable): def __init__(self): self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) @@ -73,9 +74,9 @@ class OnlyOneDep(checkpointable.Checkpointable): class SplitTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveRestoreSplitDep(self): - save_checkpoint = checkpointable_utils.Checkpoint( + save_checkpoint = util.Checkpoint( dep=SaveTensorSlicesAsDeps()) self.evaluate(save_checkpoint.dep.combined.assign([1., 2., 3., 4.])) checkpoint_directory = self.get_temp_dir() @@ -83,7 +84,7 @@ class SplitTests(test.TestCase): save_path = save_checkpoint.save(checkpoint_prefix) regular_deps = HasRegularDeps() - regular_restore_checkpoint = checkpointable_utils.Checkpoint( + regular_restore_checkpoint = util.Checkpoint( dep=regular_deps) regular_restore_checkpoint.restore( save_path).assert_consumed().run_restore_ops() @@ -91,7 +92,7 @@ class SplitTests(test.TestCase): self.assertAllEqual([3., 4.], self.evaluate(regular_deps.second_half)) one_dep = OnlyOneDep() - one_dep_restore_checkpoint = checkpointable_utils.Checkpoint(dep=one_dep) + one_dep_restore_checkpoint = util.Checkpoint(dep=one_dep) status = one_dep_restore_checkpoint.restore(save_path) with self.assertRaises(AssertionError): # Missing the second dependency. @@ -99,7 +100,7 @@ class SplitTests(test.TestCase): status.run_restore_ops() self.assertAllEqual([1., 2.], self.evaluate(one_dep.first_half)) - restore_checkpoint = checkpointable_utils.Checkpoint() + restore_checkpoint = util.Checkpoint() status = restore_checkpoint.restore(save_path) restore_checkpoint.dep = SaveTensorSlicesAsDeps() status.assert_consumed().run_restore_ops() diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index 42ba368531468b789a87429f88ca84937f9b909d..523a9efcf05f5d32589f6e1734f866bf8b4b9cdc 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -50,6 +50,7 @@ py_library( deps = [ ":gen_bigquery_reader_ops", ":gen_gcs_config_ops", + "//tensorflow/contrib/bigtable", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:io_ops", "//tensorflow/python:util", @@ -74,3 +75,14 @@ tf_py_test( ], tags = ["manual"], ) + +tf_py_test( + name = "gcs_config_ops_test", + size = "small", + srcs = ["python/ops/gcs_config_ops_test.py"], + additional_deps = [ + ":cloud_py", + "//tensorflow/python:client_testlib", + ], + tags = ["manual"], +) diff --git a/tensorflow/contrib/cloud/README.md b/tensorflow/contrib/cloud/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a80d8965f3b562cadaff8caad8d40c7b98afa78f --- /dev/null +++ b/tensorflow/contrib/cloud/README.md @@ -0,0 +1,18 @@ +# Cloud # + +## Cloud Bigtable ## + +[Google Cloud Bigtable](https://cloud.google.com/bigtable/) is a high +performance storage system that can store and serve training data. This contrib +package contains an experimental integration with TensorFlow. + +> **Status: Highly experimental.** The current implementation is very much in +> flux. Please use at your own risk! :-) + + + +## Cloud Storage (GCS) ## + +The Google Cloud Storage ops allow the user to configure the GCS File System. + + diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py index a6e13ea3ae938444b9ead0772e52fb8797a847da..8efd259946b7696e66b83a3b0aa451543c107467 100644 --- a/tensorflow/contrib/cloud/__init__.py +++ b/tensorflow/contrib/cloud/__init__.py @@ -18,17 +18,27 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=line-too-long,wildcard-import +import os + +# pylint: disable=line-too-long,wildcard-import,g-import-not-at-top from tensorflow.contrib.cloud.python.ops.bigquery_reader_ops import * from tensorflow.contrib.cloud.python.ops.gcs_config_ops import * -# pylint: enable=line-too-long,wildcard-import + +if os.name != 'nt': + from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableClient + from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableTable + +del os from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'BigQueryReader', - 'ConfigureColabSession', - 'ConfigureGcs', + 'BigtableClient', + 'BigtableTable', + 'BlockCacheParams', + 'configure_colab_session', + 'configure_gcs', 'ConfigureGcsHook', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index 40160706f70e8fa8323005dd183770ed51c8c415..1311063ec023bdaa2588d6f1c826bf900f7dea09 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -79,6 +79,7 @@ tf_kernel_library( srcs = ["gcs_config_ops.cc"], visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/contrib/cloud:gcs_config_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/platform/cloud:curl_http_request", diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc index 1bfd27305d569668a0bd67d876e59eec082296b3..58fadffce32f9a8fec047d1e99f9f4eb5a710d91 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc @@ -85,7 +85,7 @@ Status BigQueryTableAccessor::New( int64 timestamp_millis, int64 row_buffer_size, const string& end_point, const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, - std::unique_ptr http_request_factory, + std::shared_ptr http_request_factory, std::unique_ptr* accessor) { if (timestamp_millis <= 0) { return errors::InvalidArgument( @@ -94,29 +94,19 @@ Status BigQueryTableAccessor::New( const string& big_query_end_point = end_point.empty() ? kBigQueryEndPoint : end_point; if (auth_provider == nullptr && http_request_factory == nullptr) { - accessor->reset(new BigQueryTableAccessor( - project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - big_query_end_point, columns, partition)); - } else { - accessor->reset(new BigQueryTableAccessor( - project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - big_query_end_point, columns, partition, std::move(auth_provider), - std::move(http_request_factory))); + http_request_factory = std::make_shared(); + auto compute_engine_metadata_client = + std::make_shared(http_request_factory); + auth_provider = std::unique_ptr( + new GoogleAuthProvider(compute_engine_metadata_client)); } - return (*accessor)->ReadSchema(); -} -BigQueryTableAccessor::BigQueryTableAccessor( - const string& project_id, const string& dataset_id, const string& table_id, - int64 timestamp_millis, int64 row_buffer_size, const string& end_point, - const std::vector& columns, const BigQueryTablePartition& partition) - : BigQueryTableAccessor( - project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - end_point, columns, partition, - std::unique_ptr(new GoogleAuthProvider()), - std::unique_ptr( - new CurlHttpRequest::Factory())) { - row_buffer_.resize(row_buffer_size); + accessor->reset(new BigQueryTableAccessor( + project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, + big_query_end_point, columns, partition, std::move(auth_provider), + std::move(http_request_factory))); + + return (*accessor)->ReadSchema(); } BigQueryTableAccessor::BigQueryTableAccessor( @@ -124,7 +114,7 @@ BigQueryTableAccessor::BigQueryTableAccessor( int64 timestamp_millis, int64 row_buffer_size, const string& end_point, const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, - std::unique_ptr http_request_factory) + std::shared_ptr http_request_factory) : project_id_(project_id), dataset_id_(dataset_id), table_id_(table_id), diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h index b349063715c903c982cfe2fb116b6525e35ff63b..1af43a3e1070d466bb50019f12b22a060c1e6ab1 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h @@ -109,24 +109,17 @@ class BigQueryTableAccessor { const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, - std::unique_ptr http_request_factory, + std::shared_ptr http_request_factory, std::unique_ptr* accessor); /// \brief Constructs an object for a given table and partition. - BigQueryTableAccessor(const string& project_id, const string& dataset_id, - const string& table_id, int64 timestamp_millis, - int64 row_buffer_size, const string& end_point, - const std::vector& columns, - const BigQueryTablePartition& partition); - - /// Used for unit testing. BigQueryTableAccessor( const string& project_id, const string& dataset_id, const string& table_id, int64 timestamp_millis, int64 row_buffer_size, const string& end_point, const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, - std::unique_ptr http_request_factory); + std::shared_ptr http_request_factory); /// \brief Parses column values for a given row. Status ParseColumnValues(const Json::Value& value, @@ -199,7 +192,7 @@ class BigQueryTableAccessor { SchemaNode schema_root_; std::unique_ptr auth_provider_; - std::unique_ptr http_request_factory_; + std::shared_ptr http_request_factory_; TF_DISALLOW_COPY_AND_ASSIGN(BigQueryTableAccessor); }; diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py index 8c8c5acb31af69b4f738a13c6548cdd31947d71a..95e7e744d34391a511cdba7702aad369b8d9d9c0 100644 --- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py @@ -120,13 +120,18 @@ class ConfigureGcsHook(training.SessionRunHook): def begin(self): if self._credentials: self._credentials_placeholder = array_ops.placeholder(dtypes.string) - self._credentials_ops = gen_gcs_config_ops.gcs_configure_credentials( + self._credentials_op = gen_gcs_config_ops.gcs_configure_credentials( self._credentials_placeholder) + else: + self._credentials_op = None + if self._block_cache: self._block_cache_op = gen_gcs_config_ops.gcs_configure_block_cache( max_cache_size=self._block_cache.max_bytes, block_size=self._block_cache.block_size, max_staleness=self._block_cache.max_staleness) + else: + self._block_cache_op = None def after_create_session(self, session, coord): del coord diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9b6c056d6c8adfa50b95aefb8e9740631327a572 --- /dev/null +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py @@ -0,0 +1,44 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the gcs_config_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.cloud.python.ops import gcs_config_ops +from tensorflow.python.platform import test + + +class GcsConfigOpsTest(test.TestCase): + + def testSetBlockCache(self): + cfg = gcs_config_ops.BlockCacheParams(max_bytes=1024*1024*1024) + with self.test_session() as sess: + gcs_config_ops.configure_gcs(sess, block_cache=cfg) + + def testConfigureGcsHook(self): + creds = {'client_id': 'fake_client', + 'refresh_token': 'fake_token', + 'client_secret': 'fake_secret', + 'type': 'authorized_user'} + hook = gcs_config_ops.ConfigureGcsHook(credentials=creds) + hook.begin() + with self.test_session() as sess: + sess.run = lambda _, feed_dict=None, options=None, run_metadata=None: None + hook.after_create_session(sess, None) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD index c239e6f8f960910cee14e1df7c4678c643496f54..707f6211846ca0310bde297603928e9ec5bb471c 100644 --- a/tensorflow/contrib/cluster_resolver/BUILD +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -12,6 +12,15 @@ licenses(["notice"]) # Apache 2.0 py_library( name = "cluster_resolver_pip", + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":cluster_resolver_py", + ], +) + +py_library( + name = "cluster_resolver_py", srcs = [ "__init__.py", "python/training/__init__.py", @@ -19,7 +28,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - ":cluster_resolver_py", + ":base_cluster_resolver_py", ":gce_cluster_resolver_py", ":tpu_cluster_resolver_py", "//tensorflow/python:util", @@ -27,7 +36,7 @@ py_library( ) py_library( - name = "cluster_resolver_py", + name = "base_cluster_resolver_py", srcs = ["python/training/cluster_resolver.py"], srcs_version = "PY2AND3", deps = [ @@ -40,7 +49,7 @@ py_library( srcs = ["python/training/gce_cluster_resolver.py"], srcs_version = "PY2AND3", deps = [ - ":cluster_resolver_py", + ":base_cluster_resolver_py", "//tensorflow/python:training", ], ) @@ -50,13 +59,13 @@ py_library( srcs = ["python/training/tpu_cluster_resolver.py"], srcs_version = "PY2AND3", deps = [ - ":cluster_resolver_py", + ":base_cluster_resolver_py", "//tensorflow/python:training", ], ) tf_py_test( - name = "cluster_resolver_py_test", + name = "base_cluster_resolver_py_test", srcs = ["python/training/cluster_resolver_test.py"], additional_deps = [ ":cluster_resolver_py", diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 8f521ffee4d31e090c13bac98290656d6e1d330e..1ab150d74ac00c5f9acf3c9399880708b2f62b1e 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -148,6 +148,9 @@ class TPUClusterResolver(ClusterResolver): else: tpu = self._envVarFallback() + if tpu is None: + raise ValueError('Please provide a TPU Name to connect to.') + self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes self._job_name = job_name self._credentials = credentials @@ -259,11 +262,11 @@ class TPUClusterResolver(ClusterResolver): if 'state' in response and response['state'] != 'READY': raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % - (self._tpu, response['state'])) + (compat.as_text(self._tpu), response['state'])) if 'health' in response and response['health'] != 'HEALTHY': - raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu, - response['health'])) + raise RuntimeError('TPU "%s" is unhealthy: "%s"' % + (compat.as_text(self._tpu), response['health'])) if 'networkEndpoints' in response: worker_list = [ diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index e524e9e7437b19e0d117fe7b85042e8154773a02..f6c928e2be62e7292c6feaa3bb26fd463320158b 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -145,26 +145,41 @@ if(WIN32) # temporary fix for #18241 add_definitions(-DEIGEN_DEFAULT_DENSE_INDEX_TYPE=std::int64_t) endif() - add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00 -DLANG_CXX11) - add_definitions(-DWIN32 -DOS_WIN -D_MBCS -DWIN32_LEAN_AND_MEAN -DNOGDI -DPLATFORM_WINDOWS) + add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00) + add_definitions(-DWIN32_LEAN_AND_MEAN -DNOGDI -DPLATFORM_WINDOWS) add_definitions(-DTENSORFLOW_USE_EIGEN_THREADPOOL -DEIGEN_HAS_C99_MATH) add_definitions(-DTF_COMPILE_LIBRARY) - add_definitions(/bigobj /nologo /EHsc /GF /MP /Gm-) + add_compile_options(/bigobj /GF /MP /Gm-) # Suppress warnings to reduce build log size. - add_definitions(/wd4267 /wd4244 /wd4800 /wd4503 /wd4554 /wd4996 /wd4348 /wd4018) - add_definitions(/wd4099 /wd4146 /wd4267 /wd4305 /wd4307) - add_definitions(/wd4715 /wd4722 /wd4723 /wd4838 /wd4309 /wd4334) - add_definitions(/wd4003 /wd4244 /wd4267 /wd4503 /wd4506 /wd4800 /wd4996) + add_compile_options(/wd4267 /wd4244 /wd4800 /wd4503 /wd4554 /wd4996 /wd4348 /wd4018) + add_compile_options(/wd4099 /wd4146 /wd4267 /wd4305 /wd4307) + add_compile_options(/wd4715 /wd4722 /wd4723 /wd4838 /wd4309 /wd4334) + add_compile_options(/wd4003 /wd4244 /wd4267 /wd4503 /wd4506 /wd4800 /wd4996) # Suppress linker warnings. set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} /ignore:4049 /ignore:4197 /ignore:4217 /ignore:4221") set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} /ignore:4049 /ignore:4197 /ignore:4217 /ignore:4221") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /ignore:4049 /ignore:4197 /ignore:4217 /ignore:4221") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") set(CMAKE_CXX_FLAGS_DEBUG "/D_DEBUG /MDd /Ob2") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /D_ITERATOR_DEBUG_LEVEL=0") set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} /D_ITERATOR_DEBUG_LEVEL=0") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /D_ITERATOR_DEBUG_LEVEL=0") + set(compiler_flags + CMAKE_CXX_FLAGS + CMAKE_CXX_FLAGS_DEBUG + CMAKE_CXX_FLAGS_RELEASE + CMAKE_C_FLAGS + CMAKE_C_FLAGS_DEBUG + CMAKE_C_FLAGS_RELEASE + ) + # No exception + foreach(flag ${compiler_flags}) + string(REPLACE "/EHsc" "/EHs-c-" ${flag} "${${flag}}") + endforeach() + add_definitions(/D_HAS_EXCEPTIONS=0) + # Suppress 'noexcept used with no exception handling mode specified' warning + add_compile_options(/wd4577) + # Try to avoid flaky failures due to failed generation of generate.stamp files. set(CMAKE_SUPPRESS_REGENERATION ON) endif() @@ -299,17 +314,20 @@ include_directories( ${double_conversion_INCLUDE_DIR} ) -if(tensorflow_ENABLE_SSL_SUPPORT) - include(boringssl) - list(APPEND tensorflow_EXTERNAL_LIBRARIES ${boringssl_STATIC_LIBRARIES}) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES boringssl) - include_directories(${boringssl_INCLUDE_DIR}) -endif() if(tensorflow_ENABLE_GRPC_SUPPORT) + if(tensorflow_ENABLE_SSL_SUPPORT) + include(boringssl) + include_directories(${boringssl_INCLUDE_DIR}) + endif() include(grpc) + include_directories(${GRPC_INCLUDE_DIRS}) + # Place boringssl after grpc as grpc depends on boringssl. list(APPEND tensorflow_EXTERNAL_LIBRARIES ${grpc_STATIC_LIBRARIES}) list(APPEND tensorflow_EXTERNAL_DEPENDENCIES grpc) - include_directories(${GRPC_INCLUDE_DIRS}) + if(tensorflow_ENABLE_SSL_SUPPORT) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${boringssl_STATIC_LIBRARIES}) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES boringssl) + endif() endif() if(tensorflow_ENABLE_JEMALLOC_SUPPORT) include(jemalloc) @@ -336,40 +354,14 @@ endif() # MKL Support if (tensorflow_ENABLE_MKL_SUPPORT) add_definitions(-DINTEL_MKL -DEIGEN_USE_VML) - if (WIN32) - find_path(MKL_HOME_PLATFORM mkl - PATHS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../ - $ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../ - PATH_SUFFIXES windows) - set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include) - set(MKL_LINK_DIRS - ${MKL_HOME_PLATFORM}/mkl/lib/intel64 - ${MKL_HOME_PLATFORM}/tbb/lib/intel64/vc_mt - ${MKL_HOME_PLATFORM}/compiler/lib/intel64 - ${MKL_HOME_PLATFORM}/mkl/tools/builder/lib) - set(MKL_REDIST_DLL_DIRS - ${MKL_HOME_PLATFORM}/redist/intel64/mkl - ${MKL_HOME_PLATFORM}/redist/intel64/tbb/vc_mt - ${MKL_HOME_PLATFORM}/redist/intel64/compiler) - list(APPEND tensorflow_EXTERNAL_LIBRARIES - mkl_intel_lp64_dll mkl_sequential_dll mkl_core_dll mkl_rt mkl_cdll_intel64) - endif() - if (UNIX) - # Fix me: complete the path on linux - find_path(MKL_HOME_PLATFORM mkl - HINTS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../ - $ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../ - PATH_SUFFIXES linux) - set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include) - set(MKL_LINK_DIRS) # incompleted - set(MKL_REDIST_SO_DIRS) # incompleted - endif() - include_directories(${MKL_INCLUDE_DIRS}) - link_directories(${MKL_LINK_DIRS}) + include(mkl) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkl_STATIC_LIBRARIES}) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkl_copy_shared_to_destination) + include_directories(${mkl_INCLUDE_DIRS}) if (tensorflow_ENABLE_MKLDNN_SUPPORT) include(mkldnn) list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkldnn_STATIC_LIBRARIES}) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn_copy_shared_to_destination) include_directories(${mkldnn_INCLUDE_DIRS}) else (tensorflow_ENABLE_MKLDNN_SUPPORT) add_definitions(-DINTEL_MKL_ML) @@ -402,16 +394,20 @@ if (tensorflow_ENABLE_GPU) # by default we assume compute cabability 3.5 and 5.2. If you change this change it in # CUDA_NVCC_FLAGS and cuda_config.h below - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_30,code=\"sm_30,compute_30\";-gencode arch=compute_35,code=\"sm_35,compute_35\";-gencode arch=compute_52,code=\"sm_52,compute_52\") + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_37,code=\"sm_37,compute_37\") + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_52,code=\"sm_52,compute_52\") + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_60,code=\"sm_60,compute_60\") + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_61,code=\"sm_61,compute_61\") + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_70,code=\"sm_70,compute_70\") set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};--include-path ${PROJECT_BINARY_DIR}/$\{build_configuration\};--expt-relaxed-constexpr) set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-ftz=true) # Flush denormals to zero set(CUDA_INCLUDE ${CUDA_TOOLKIT_TARGET_DIR} ${CUDA_TOOLKIT_TARGET_DIR}/extras/CUPTI/include) include_directories(${CUDA_INCLUDE}) if (WIN32) - add_definitions(-DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=3.0,3.5,5.2) + add_definitions(-DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=3.7,5.2,6.0,6.1,7.0) else (WIN32) - # Without these double quotes, cmake in Linux makes it "-DTF_EXTRA_CUDA_CAPABILITIES=3.0, -D3.5, -D5.2" for cc, which incurs build breaks - add_definitions(-DGOOGLE_CUDA=1 -D"TF_EXTRA_CUDA_CAPABILITIES=3.0,3.5,5.2") + # Without these double quotes, cmake in Linux makes it "-DTF_EXTRA_CUDA_CAPABILITIES=3.7, -D5.2, ..." for cc, which incurs build breaks + add_definitions(-DGOOGLE_CUDA=1 -D"TF_EXTRA_CUDA_CAPABILITIES=3.7,5.2,6.0,6.1,7.0") endif (WIN32) if (WIN32) @@ -460,7 +456,7 @@ if (tensorflow_ENABLE_GPU) FILE(WRITE ${tensorflow_source_dir}/third_party/gpus/cuda/cuda_config.h "#ifndef CUDA_CUDA_CONFIG_H_\n" "#define CUDA_CUDA_CONFIG_H_\n" - "#define TF_CUDA_CAPABILITIES CudaVersion(\"3.0\"),CudaVersion(\"3.5\"),CudaVersion(\"5.2\")\n" + "#define TF_CUDA_CAPABILITIES CudaVersion(\"3.7\"),CudaVersion(\"5.2\"),CudaVersion(\"6.0\"),CudaVersion(\"6.1\"),CudaVersion(\"7.0\")\n" "#define TF_CUDA_VERSION \"64_${short_CUDA_VER}\"\n" "#define TF_CUDNN_VERSION \"64_${tensorflow_CUDNN_VERSION}\"\n" "#define TF_CUDA_TOOLKIT_PATH \"${CUDA_TOOLKIT_ROOT_DIR}\"\n" @@ -475,7 +471,6 @@ if (tensorflow_ENABLE_GPU) ${CUDA_TOOLKIT_TARGET_DIR}/include/cuComplex.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cublas_v2.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_fp16.h ${CUDA_TOOLKIT_TARGET_DIR}/include/device_functions.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cufft.h ${CUDA_TOOLKIT_TARGET_DIR}/include/curand.h diff --git a/tensorflow/contrib/cmake/external/boringssl.cmake b/tensorflow/contrib/cmake/external/boringssl.cmake index 3c4bb01e24fd121c9d0fc3594cc25de37af0e8a1..fbb14b2515a656f1dfc0e3f63ac367e9b7738a23 100644 --- a/tensorflow/contrib/cmake/external/boringssl.cmake +++ b/tensorflow/contrib/cmake/external/boringssl.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(boringssl_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/boringssl/src/boringssl/include) #set(boringssl_EXTRA_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/boringssl/src) set(boringssl_URL https://boringssl.googlesource.com/boringssl) -set(boringssl_TAG ee7aa02) +set(boringssl_TAG 7f8c553d7f4db0a6ce727f2986d41bf8fe8ec4bf) set(boringssl_BUILD ${CMAKE_BINARY_DIR}/boringssl/src/boringssl-build) #set(boringssl_LIBRARIES ${boringssl_BUILD}/obj/so/libboringssl.so) set(boringssl_STATIC_LIBRARIES diff --git a/tensorflow/contrib/cmake/external/eigen.cmake b/tensorflow/contrib/cmake/external/eigen.cmake index 45a0096085cc2a6332c82e1ea284812acdd45152..33bb31148d2e5b7ca177d7c30b7781e8f620c3cb 100644 --- a/tensorflow/contrib/cmake/external/eigen.cmake +++ b/tensorflow/contrib/cmake/external/eigen.cmake @@ -19,6 +19,12 @@ # build_file = "eigen.BUILD", #) +option(eigen_PATCH_FILE "Patch file to apply to eigen" OFF) +set(eigen_PATCH_COMMAND "") +if(eigen_PATCH_FILE) + set(eigen_PATCH_COMMAND PATCH_COMMAND patch -p0 -i "${eigen_PATCH_FILE}") +endif(eigen_PATCH_FILE) + include (ExternalProject) # We parse the current Eigen version and archive hash from the bazel configuration @@ -45,6 +51,7 @@ ExternalProject_Add(eigen URL ${eigen_URL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" INSTALL_DIR "${eigen_INSTALL}" + ${eigen_PATCH_COMMAND} CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF diff --git a/tensorflow/contrib/cmake/external/highwayhash.cmake b/tensorflow/contrib/cmake/external/highwayhash.cmake index a6e8a38d8c2ee3deb5453c264e0c5eb23248301f..7d260b85f21e7e56e153daf550c81155e4b68777 100644 --- a/tensorflow/contrib/cmake/external/highwayhash.cmake +++ b/tensorflow/contrib/cmake/external/highwayhash.cmake @@ -20,14 +20,6 @@ set(highwayhash_TAG be5edafc2e1a455768e260ccd68ae7317b6690ee) set(highwayhash_BUILD ${CMAKE_CURRENT_BINARY_DIR}/highwayhash/src/highwayhash) set(highwayhash_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/highwayhash/install) -# put highwayhash includes in the directory where they are expected -add_custom_target(highwayhash_create_destination_dir - COMMAND ${CMAKE_COMMAND} -E make_directory ${highwayhash_INCLUDE_DIR}/highwayhash - DEPENDS highwayhash) - -add_custom_target(highwayhash_copy_headers_to_destination - DEPENDS highwayhash_create_destination_dir) - if(WIN32) set(highwayhash_HEADERS "${highwayhash_BUILD}/highwayhash/*.h") set(highwayhash_STATIC_LIBRARIES ${highwayhash_INSTALL}/lib/highwayhash.lib) @@ -36,6 +28,20 @@ else() set(highwayhash_STATIC_LIBRARIES ${highwayhash_INSTALL}/lib/libhighwayhash.a) endif() +set(highwayhash_HEADERS + "${highwayhash_INSTALL}/include/code_annotation.h" + "${highwayhash_INSTALL}/include/highway_tree_hash.h" + "${highwayhash_INSTALL}/include/scalar_highway_tree_hash.h" + "${highwayhash_INSTALL}/include/scalar_sip_tree_hash.h" + "${highwayhash_INSTALL}/include/sip_hash.h" + "${highwayhash_INSTALL}/include/sip_tree_hash.h" + "${highwayhash_INSTALL}/include/sse41_highway_tree_hash.h" + "${highwayhash_INSTALL}/include/state_helpers.h" + "${highwayhash_INSTALL}/include/types.h" + "${highwayhash_INSTALL}/include/vec.h" + "${highwayhash_INSTALL}/include/vec2.h" +) + ExternalProject_Add(highwayhash PREFIX highwayhash GIT_REPOSITORY ${highwayhash_URL} @@ -50,5 +56,15 @@ ExternalProject_Add(highwayhash -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${highwayhash_INSTALL}) -add_custom_command(TARGET highwayhash_copy_headers_to_destination PRE_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_directory ${highwayhash_INSTALL}/include/ ${highwayhash_INCLUDE_DIR}/highwayhash) +# put highwayhash includes in the directory where they are expected +add_custom_target(highwayhash_create_destination_dir + COMMAND ${CMAKE_COMMAND} -E make_directory ${highwayhash_INCLUDE_DIR}/highwayhash + DEPENDS highwayhash) + +add_custom_target(highwayhash_copy_headers_to_destination + DEPENDS highwayhash_create_destination_dir) + +foreach(header_file ${highwayhash_HEADERS}) + add_custom_command(TARGET highwayhash_copy_headers_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${highwayhash_INCLUDE_DIR}/highwayhash/) +endforeach() diff --git a/tensorflow/contrib/cmake/external/mkl.cmake b/tensorflow/contrib/cmake/external/mkl.cmake new file mode 100644 index 0000000000000000000000000000000000000000..a172e3a41a283359b9a8c823ddcb2b1973b5b3cc --- /dev/null +++ b/tensorflow/contrib/cmake/external/mkl.cmake @@ -0,0 +1,68 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +include (ExternalProject) + +# NOTE: Different from mkldnn.cmake, this file is meant to download mkl libraries +set(mkl_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include) +set(mkl_BIN_DIRS ${CMAKE_CURRENT_BINARY_DIR}/mkl/bin) +set(mkl_WIN mklml_win_2018.0.3.20180406.zip) # match for v0.14 +set(mkl_MAC mklml_mac_2018.0.3.20180406.tgz) +set(mkl_LNX mklml_lnx_2018.0.3.20180406.tgz) +set(mkl_TAG v0.14) +set(mkl_URL https://github.com/intel/mkl-dnn/releases) + +if (WIN32) + set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_WIN}) + list(APPEND mkl_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/mklml.lib) + list(APPEND mkl_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5md.lib) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/mklml.dll) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5md.dll) +elseif (UNIX) + set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_LNX}) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5.so) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libmklml_gnu.so) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libmklml_intel.so) +elseif (APPLE) + set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_MAC}) + #TODO need more information +endif () + +ExternalProject_Add(mkl + PREFIX mkl + URL ${mkl_DOWNLOAD_URL} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "") + +# put mkl dynamic libraries in one bin directory +add_custom_target(mkl_create_destination_dir + COMMAND ${CMAKE_COMMAND} -E make_directory ${mkl_BIN_DIRS} + DEPENDS mkl) + +add_custom_target(mkl_copy_shared_to_destination DEPENDS mkl_create_destination_dir) + +foreach(dll_file ${mkl_SHARED_LIBRARIES}) + add_custom_command(TARGET mkl_copy_shared_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dll_file} ${mkl_BIN_DIRS}) +endforeach() diff --git a/tensorflow/contrib/cmake/external/mkldnn.cmake b/tensorflow/contrib/cmake/external/mkldnn.cmake index a639fdee367f060d4c8a79267803da6ffe3dc503..8123ee1f393ab8e3a52f13915ea2a65decc188d9 100644 --- a/tensorflow/contrib/cmake/external/mkldnn.cmake +++ b/tensorflow/contrib/cmake/external/mkldnn.cmake @@ -22,8 +22,11 @@ set(mkldnn_TAG 3063b2e4c943983f6bf5f2fb9a490d4a998cd291) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release/mkldnn.lib) + set(mkldnn_SHARED_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release/mkldnn.dll) + set(mkldnn_BUILD ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release) else() set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/mkldnn.lib) + set(mkldnn_SHARED_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/mkldnn.dll) endif() else() set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/libmkldnn.a) @@ -31,6 +34,7 @@ endif() ExternalProject_Add(mkldnn PREFIX mkldnn + DEPENDS mkl GIT_REPOSITORY ${mkldnn_URL} GIT_TAG ${mkldnn_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" @@ -40,5 +44,11 @@ ExternalProject_Add(mkldnn CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF - -DMKLINC:STRING=${MKL_INCLUDE_DIRS} + -DMKLINC:STRING=${mkl_INCLUDE_DIRS} ) + +# since mkldnn depends on mkl, copy the mkldnn.dll together with mklml.dll to mkl_bin_dirs +add_custom_target(mkldnn_copy_shared_to_destination DEPENDS mkldnn) + +add_custom_command(TARGET mkldnn_copy_shared_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${mkldnn_SHARED_LIBRARIES} ${mkl_BIN_DIRS}) diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index b9d1dd88d4c2d3c9141ba56e14911e06b4d33f7c..1d638e64023c7e2706d8d97ff8679677b6cd289d 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,18 +16,10 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 0559ce013feac8db639ee1bf776aca0325d28777) +set(nsync_TAG 1.20.0) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) -# put nsync includes in the directory where they are expected -add_custom_target(nsync_create_destination_dir - COMMAND ${CMAKE_COMMAND} -E make_directory ${nsync_INCLUDE_DIR} - DEPENDS nsync) - -add_custom_target(nsync_copy_headers_to_destination - DEPENDS nsync_create_destination_dir) - if(WIN32) set(nsync_HEADERS "${nsync_BUILD}/public/*.h") set(nsync_STATIC_LIBRARIES ${nsync_INSTALL}/lib/nsync.lib) @@ -49,7 +41,35 @@ ExternalProject_Add(nsync -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${nsync_INSTALL} - -DNSYNC_LANGUAGE:STRING=c++11) + -DNSYNC_LANGUAGE:STRING=c++11) + +set(nsync_HEADERS + "${nsync_INSTALL}/include/nsync.h" + "${nsync_INSTALL}/include/nsync_atomic.h" + "${nsync_INSTALL}/include/nsync_counter.h" + "${nsync_INSTALL}/include/nsync_cpp.h" + "${nsync_INSTALL}/include/nsync_cv.h" + "${nsync_INSTALL}/include/nsync_debug.h" + "${nsync_INSTALL}/include/nsync_mu.h" + "${nsync_INSTALL}/include/nsync_mu_wait.h" + "${nsync_INSTALL}/include/nsync_note.h" + "${nsync_INSTALL}/include/nsync_once.h" + "${nsync_INSTALL}/include/nsync_time.h" + "${nsync_INSTALL}/include/nsync_time_internal.h" + "${nsync_INSTALL}/include/nsync_waiter.h" +) + +# put nsync includes in the directory where they are expected +add_custom_target(nsync_create_destination_dir + COMMAND ${CMAKE_COMMAND} -E make_directory ${nsync_INCLUDE_DIR} + DEPENDS nsync) + +add_custom_target(nsync_copy_headers_to_destination + DEPENDS nsync_create_destination_dir) + +foreach(header_file ${nsync_HEADERS}) + add_custom_command(TARGET nsync_copy_headers_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${nsync_INCLUDE_DIR}/) +endforeach() + -add_custom_command(TARGET nsync_copy_headers_to_destination PRE_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_directory ${nsync_INSTALL}/include/ ${nsync_INCLUDE_DIR}/) diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index ab464bc99a43138130bb2758ae28ecef29805c31..f56fb35a0f71250f00b84e5cf94a24682bda6c82 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src) set(PROTOBUF_URL https://github.com/google/protobuf.git) -set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9) +set(PROTOBUF_TAG v3.6.0) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 015cb73bbd93bb77f6748a364b263d99eb305c27..a5a947f7261559b6d25c452efe35097258d5625c 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -14,6 +14,7 @@ tensorflow/examples/tutorials tensorflow/examples/tutorials/mnist tensorflow/python tensorflow/python/client +tensorflow/python/compat tensorflow/python/data tensorflow/python/data/ops tensorflow/python/data/util @@ -35,6 +36,7 @@ tensorflow/python/keras tensorflow/python/keras/applications tensorflow/python/keras/datasets tensorflow/python/keras/engine +tensorflow/python/keras/estimator tensorflow/python/keras/layers tensorflow/python/keras/preprocessing tensorflow/python/keras/utils @@ -60,6 +62,8 @@ tensorflow/python/saved_model tensorflow/python/summary tensorflow/python/summary/writer tensorflow/python/tools +tensorflow/python/tools/api +tensorflow/python/tools/api/generator tensorflow/python/training tensorflow/python/training/checkpointable tensorflow/python/user_ops @@ -67,7 +71,6 @@ tensorflow/python/util tensorflow/python/util/protobuf tensorflow/tools tensorflow/tools/api -tensorflow/tools/api/generator tensorflow/tools/graph_transforms tensorflow/contrib tensorflow/contrib/all_reduce @@ -85,6 +88,8 @@ tensorflow/contrib/batching/python/ops tensorflow/contrib/bayesflow tensorflow/contrib/bayesflow/python tensorflow/contrib/bayesflow/python/ops +# tensorflow/contrib/bigtable/python +# tensorflow/contrib/bigtable/python/ops tensorflow/contrib/boosted_trees tensorflow/contrib/boosted_trees/estimator_batch tensorflow/contrib/boosted_trees/kernels @@ -110,13 +115,10 @@ tensorflow/contrib/coder tensorflow/contrib/coder/kernels tensorflow/contrib/coder/ops tensorflow/contrib/coder/python -tensorflow/contrib/coder/python/layers tensorflow/contrib/coder/python/ops tensorflow/contrib/compiler tensorflow/contrib/constrained_optimization tensorflow/contrib/constrained_optimization/python -tensorflow/contrib/control_flow -tensorflow/contrib/control_flow/python tensorflow/contrib/copy_graph tensorflow/contrib/copy_graph/python tensorflow/contrib/copy_graph/python/util @@ -131,6 +133,7 @@ tensorflow/contrib/data tensorflow/contrib/data/kernels tensorflow/contrib/data/python tensorflow/contrib/data/python/kernel_tests +tensorflow/contrib/data/python/kernel_tests/serialization tensorflow/contrib/data/python/ops tensorflow/contrib/decision_trees tensorflow/contrib/decision_trees/proto @@ -183,6 +186,8 @@ tensorflow/contrib/graph_editor/examples tensorflow/contrib/grid_rnn tensorflow/contrib/grid_rnn/python tensorflow/contrib/grid_rnn/python/ops +tensorflow/contrib/hadoop/python +tensorflow/contrib/hadoop/python/ops tensorflow/contrib/hooks tensorflow/contrib/hooks/python tensorflow/contrib/image @@ -238,6 +243,8 @@ tensorflow/contrib/keras/api/keras/wrappers/scikit_learn tensorflow/contrib/kernel_methods tensorflow/contrib/kernel_methods/python tensorflow/contrib/kernel_methods/python/mappers +tensorflow/contrib/kinesis/python +tensorflow/contrib/kinesis/python/ops tensorflow/contrib/kfac tensorflow/contrib/kfac/examples tensorflow/contrib/kfac/python diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index 2e0a2fcef4cbdc50f0521296c4a25a864dbd8b77..7a30eb94f54b18a2a517615a315e23e09e1170d0 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -36,16 +36,3 @@ add_dependencies( tf_cc_while_loop tf_core_lib tf_protos_cc) - -if(tensorflow_BUILD_PYTHON_BINDINGS) - add_library(tf_c_python_api OBJECT - "${tensorflow_source_dir}/tensorflow/c/python_api.cc" - "${tensorflow_source_dir}/tensorflow/c/python_api.h" - ) - add_dependencies( - tf_c_python_api - tf_c - tf_core_lib - tf_core_framework - tf_protos_cc) -endif() diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index dac84ccb0dbf4848329e35a6e9bcf6213d8c0e55..067c299a71cd4ac96878bcf27b4453466785e4ba 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -125,6 +125,7 @@ endfunction() file(GLOB_RECURSE tf_protos_cc_srcs RELATIVE ${tensorflow_source_dir} "${tensorflow_source_dir}/tensorflow/core/*.proto" + "${tensorflow_source_dir}/tensorflow/compiler/xla/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/tpu/proto/*.proto" ) @@ -233,15 +234,6 @@ if(WIN32) list(APPEND tf_core_lib_srcs ${tf_core_platform_windows_srcs}) endif(WIN32) -if(tensorflow_ENABLE_SSL_SUPPORT) - # Cloud libraries require boringssl. - file(GLOB tf_core_platform_cloud_srcs - "${tensorflow_source_dir}/tensorflow/core/platform/cloud/*.h" - "${tensorflow_source_dir}/tensorflow/core/platform/cloud/*.cc" - ) - list(APPEND tf_core_lib_srcs ${tf_core_platform_cloud_srcs}) -endif() - if (tensorflow_ENABLE_HDFS_SUPPORT) list(APPEND tf_core_platform_hdfs_srcs "${tensorflow_source_dir}/tensorflow/core/platform/hadoop/hadoop_file_system.cc" diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 2d76bf530a2100b2afa80a16a5d64b6ec51ffc68..7b892ba248bc43cd885f295288c677ac97efaa06 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -68,6 +68,7 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/csv_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc" @@ -134,14 +135,13 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) list(APPEND tf_core_kernels_srcs ${tf_contrib_kernels_srcs}) endif(tensorflow_BUILD_CONTRIB_KERNELS) -if(NOT tensorflow_ENABLE_SSL_SUPPORT) - # Cloud libraries require boringssl. - file(GLOB tf_core_kernels_cloud_srcs - "${tensorflow_source_dir}/tensorflow/contrib/cloud/kernels/*.h" - "${tensorflow_source_dir}/tensorflow/contrib/cloud/kernels/*.cc" - ) +# Cloud libraries require curl and boringssl. +# Curl is not supported yet anyway so we remove for now. +file(GLOB tf_core_kernels_cloud_srcs + "${tensorflow_source_dir}/tensorflow/contrib/cloud/kernels/*.h" + "${tensorflow_source_dir}/tensorflow/contrib/cloud/kernels/*.cc" +) list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_cloud_srcs}) -endif() file(GLOB_RECURSE tf_core_kernels_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/kernels/*test*.h" diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 92446044892127284ecb8753a250b77cb2a5743a..6d86daf5f174a3238ab92e5bba6085c904766766 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -198,7 +198,7 @@ function(add_python_module MODULE_NAME) # so we currently add explicit commands to include those files # later on in this script. if (NOT "${script}" MATCHES "_test\.py$") - add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD + add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/${script} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/${script}) endif() endforeach() @@ -297,7 +297,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) ) target_link_libraries(${tf_python_op_lib_name}_gen_python PRIVATE tf_protos_cc - tf_python_protos_cc + tf_python_protos_cc ${tensorflow_EXTERNAL_LIBRARIES} ) @@ -456,6 +456,18 @@ add_custom_command( COMMENT "Running SWIG to generate Python wrappers" VERBATIM ) +add_library(tf_c_python_api OBJECT + "${tensorflow_source_dir}/tensorflow/c/python_api.cc" + "${tensorflow_source_dir}/tensorflow/c/python_api.h" +) +add_dependencies( + tf_c_python_api + tf_c + tf_core_lib + tf_core_framework + tf_protos_cc + tf_python_protos_cc) + set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.h" "${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.cc" @@ -537,15 +549,15 @@ if(WIN32) ${NUMPY_INCLUDE_DIR} ) #target_link_libraries(pywrap_tensorflow_internal_static - # tf_protos_cc - # tf_python_protos_cc + # tf_protos_cc + # tf_python_protos_cc #) add_dependencies(pywrap_tensorflow_internal_static tf_protos_cc tf_python_protos_cc) set(pywrap_tensorflow_internal_static_dependencies $ $ $ - ${nsync_STATIC_LIBRARIES} + ${nsync_STATIC_LIBRARIES} ) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") @@ -724,8 +736,8 @@ endif() # Generate API __init__.py files. ######################################################## -# Parse tensorflow/tools/api/generator/BUILD to get list of generated files. -FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text) +# Parse tensorflow/python/tools/api/generator/BUILD to get list of generated files. +FILE(READ ${tensorflow_source_dir}/tensorflow/python/tools/api/generator/api_init_files.bzl api_generator_BUILD_text) STRING(REGEX MATCH "# BEGIN GENERATED FILES.*# END GENERATED FILES" api_init_files_text ${api_generator_BUILD_text}) string(REPLACE "# BEGIN GENERATED FILES" "" api_init_files_text ${api_init_files_text}) string(REPLACE "# END GENERATED FILES" "" api_init_files_text ${api_init_files_text}) @@ -743,6 +755,27 @@ set(api_init_list_file "${tensorflow_source_dir}/api_init_files_list.txt") file(WRITE "${api_init_list_file}" "${api_init_files}") # Run create_python_api.py to generate __init__.py files. + +### TODO +# In order to download and compile MKL/MKL-DNN automatically in cmake script, mkl-built libraries should be added to system path +# to be loaded by python executor. However `add_custom_command` has an issue with `COMMAND ${CMAKE_COMMAND} -E env PATH=`, where +# arguments of multiple paths (such as D:/;D:/mkl) will be parsed in to seperate string without semicolon and that command fail to +# recongnize paths. As CUDA isn't built with MKL, the MKL built directory is the only path to this command to work around that issue. +# To not override the CUDA and system path in other circumstances, `if-else` branch used here to handle this problem, +# and should be removed if the path issue can be resolved. +# UPDATE: Below block appears to handle multiple items in PATH correctly, but risks command line limits if PATH is large. +# If you have issues, try `set(PY_RUNTIME_ENV "PATH=${mkl_BIN_DIRS}")` instead. +### + +set(PY_RUNTIME_ENV "") +if(tensorflow_ENABLE_MKL_SUPPORT) + # add mkl dist dlls to system path for python + file(TO_CMAKE_PATH "$ENV{PATH}" PY_RUNTIME_ENV) + set(PY_RUNTIME_ENV ${mkl_BIN_DIRS} ${PY_RUNTIME_ENV}) + file(TO_NATIVE_PATH "${PY_RUNTIME_ENV}" PY_RUNTIME_ENV) + set(PY_RUNTIME_ENV "PATH=${PY_RUNTIME_ENV}") +endif(tensorflow_ENABLE_MKL_SUPPORT) + add_custom_command( OUTPUT ${api_init_files} DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops @@ -752,8 +785,8 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py # Run create_python_api.py to generate API init files. - COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python "${PY_RUNTIME_ENV}" ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/tools/api/generator/create_python_api.py" "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" "--package=tensorflow.python" @@ -762,6 +795,7 @@ add_custom_command( COMMENT "Generating __init__.py files for Python API." WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" + VERBATIM ) add_custom_target(tf_python_api SOURCES ${api_init_files}) @@ -773,8 +807,8 @@ add_dependencies(tf_python_api tf_python_ops) # Generate API __init__.py files for tf.estimator. ######################################################## -# Parse tensorflow/tools/api/generator/BUILD to get list of generated files. -FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text) +# Parse tensorflow/python/tools/api/generator/BUILD to get list of generated files. +FILE(READ ${tensorflow_source_dir}/tensorflow/python/tools/api/generator/api_gen.bzl api_generator_BUILD_text) STRING(REGEX MATCH "# BEGIN GENERATED ESTIMATOR FILES.*# END GENERATED ESTIMATOR FILES" api_init_files_text ${api_generator_BUILD_text}) string(REPLACE "# BEGIN GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) string(REPLACE "# END GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) @@ -797,11 +831,12 @@ add_custom_command( DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops # Run create_python_api.py to generate API init files. - COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python "${PY_RUNTIME_ENV}" ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/tools/api/generator/create_python_api.py" "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api" "--package=tensorflow.python.estimator" "--apiname=estimator" + "--output_package=tensorflow.python.estimator.api" "${estimator_api_init_list_file}" COMMENT "Generating __init__.py files for Python API." diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index 38f40452b533fdc0dba6ac686a0ff43a2ef13cb8..fdf522f1fd90ffc64acbe82381ef57a389645d61 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -145,3 +145,8 @@ install(DIRECTORY ${tensorflow_source_dir}/third_party/eigen3/ # unsupported Eigen directory install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/ DESTINATION include/unsupported/Eigen) +# mkl +if (tensorflow_ENABLE_MKL_SUPPORT) + install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include/ + DESTINATION include/mkl) +endif (tensorflow_ENABLE_MKL_SUPPORT) diff --git a/tensorflow/contrib/cmake/tf_stream_executor.cmake b/tensorflow/contrib/cmake/tf_stream_executor.cmake index 9a37b681194d4ef82b27a0160dd969f733ecad67..6d634cb1709910f366c7ca538d28bd802b2a7c63 100644 --- a/tensorflow/contrib/cmake/tf_stream_executor.cmake +++ b/tensorflow/contrib/cmake/tf_stream_executor.cmake @@ -64,8 +64,6 @@ file(GLOB tf_stream_executor_srcs if (tensorflow_ENABLE_GPU) file(GLOB tf_stream_executor_gpu_srcs "${tensorflow_source_dir}/tensorflow/stream_executor/cuda/*.cc" - "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.h" - "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.cc" ) if (NOT tensorflow_BUILD_CC_TESTS) file(GLOB tf_stream_executor_gpu_tests @@ -76,11 +74,11 @@ if (tensorflow_ENABLE_GPU) list(APPEND tf_stream_executor_srcs ${tf_stream_executor_gpu_srcs}) endif() -#file(GLOB_RECURSE tf_stream_executor_test_srcs -# "${tensorflow_source_dir}/tensorflow/stream_executor/*_test.cc" -# "${tensorflow_source_dir}/tensorflow/stream_executor/*_test.h" -#) -#list(REMOVE_ITEM tf_stream_executor_srcs ${tf_stream_executor_test_srcs}) +file(GLOB_RECURSE tf_stream_executor_test_srcs + "${tensorflow_source_dir}/tensorflow/stream_executor/*test.cc" + "${tensorflow_source_dir}/tensorflow/stream_executor/lib/*test.h" +) +list(REMOVE_ITEM tf_stream_executor_srcs ${tf_stream_executor_test_srcs}) if (NOT WIN32) set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lgomp") diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index eb9482dc25f2be8ce46cc38bf3dd28889b09a9d4..2c878c17167c662d10a8c7dabf41687efdbf65d8 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -122,6 +122,17 @@ function(AddPythonTests) endforeach() endfunction(AddPythonTests) +# +# ensure that every element is an existing file +# +function(CheckExists TYPE SOURCES) + foreach(source ${SOURCES}) + if(NOT EXISTS ${source}) + message(SEND_ERROR "${TYPE} not found: ${source}") + endif() + endforeach(source) +endfunction(CheckExists) + if (tensorflow_BUILD_PYTHON_TESTS) # # python tests. This assumes that the tensorflow wheel is @@ -145,7 +156,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/debug/wrappers/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/estimator/python/estimator/*_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/*.py" - "${tensorflow_source_dir}/tensorflow/python/meta_graph_transform/*_test.py" "${tensorflow_source_dir}/tensorflow/python/ops/quantized_conv_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/ops/quantized_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/platform/build_info_test.py" @@ -193,11 +203,11 @@ if (tensorflow_BUILD_PYTHON_TESTS) # flaky test "${tensorflow_source_dir}/tensorflow/python/profiler/internal/run_metadata_test.py" "${tensorflow_source_dir}/tensorflow/python/profiler/model_analyzer_test.py" + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/map_dataset_op_test.py" # Fails because uses data dependencies with bazel "${tensorflow_source_dir}/tensorflow/python/saved_model/saved_model_test.py" "${tensorflow_source_dir}/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py" # requires scipy - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/preprocessing/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py" "${tensorflow_source_dir}/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py" # Takes very long to run without sharding (defined in bazel build file). @@ -216,7 +226,8 @@ if (tensorflow_BUILD_PYTHON_TESTS) ${tensorflow_source_dir}/tensorflow/python/kernel_tests/duplicate_op_test.py ${tensorflow_source_dir}/tensorflow/python/kernel_tests/invalid_op_test.py ${tensorflow_source_dir}/tensorflow/python/kernel_tests/ackermann_test.py - + # Tests too large to run. + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py ) if (WIN32) set(tf_test_src_py_exclude @@ -254,10 +265,9 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Flaky because of local cluster creation. "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" - "${tensorflow_source_dir}tensorflow/python/training/localhost_cluster_performance_test.py" + "${tensorflow_source_dir}/tensorflow/python/training/localhost_cluster_performance_test.py" "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Type error in testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU. "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py" @@ -327,6 +337,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/io_utils_test.py" # b/72894325 ) endif() + CheckExists(${tf_test_src_py_exclude}) list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude}) AddPythonTests( @@ -478,6 +489,7 @@ if (tensorflow_BUILD_CC_TESTS) "${tensorflow_source_dir}/tensorflow/cc/saved_model/*_test.cc" ) + CheckExists(${tf_test_src_simple_exclude}) list(REMOVE_ITEM tf_test_src_simple ${tf_test_src_simple_exclude} ${tf_cc_saved_model_test_srcs} @@ -492,6 +504,7 @@ if (tensorflow_BUILD_CC_TESTS) ${tf_core_profiler_test_srcs} ) + CheckExists(${tf_src_testlib}) set(tf_test_lib tf_test_lib) add_library(${tf_test_lib} STATIC ${tf_src_testlib}) diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD index a2c6e413039ee3b5af3cb53d1af3325037536d36..855c824ead2f7de4c37db2d2a3648a9ee00fb9e9 100644 --- a/tensorflow/contrib/coder/BUILD +++ b/tensorflow/contrib/coder/BUILD @@ -1,5 +1,5 @@ # Description: -# Contains tools related to data compression. +# Contains ops related to data compression. package(default_visibility = [ "//learning/brain:__subpackages__", @@ -168,7 +168,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":coder_ops_py", - ":entropybottleneck_py", ], ) @@ -205,44 +204,3 @@ tf_py_test( ], main = "python/ops/coder_ops_test.py", ) - -py_library( - name = "entropybottleneck_py", - srcs = [ - "python/layers/entropybottleneck.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":coder_ops_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:functional_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn", - "//tensorflow/python:ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", - "//tensorflow/python/keras:engine", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "entropybottleneck_py_test", - srcs = [ - "python/layers/entropybottleneck_test.py", - ], - additional_deps = [ - ":entropybottleneck_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:variables", - "//tensorflow/python:training", - ], - main = "python/layers/entropybottleneck_test.py", -) diff --git a/tensorflow/contrib/coder/README.md b/tensorflow/contrib/coder/README.md deleted file mode 100644 index c6c379c458893551b765327c0c1cbfff7f24f9c3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/coder/README.md +++ /dev/null @@ -1,73 +0,0 @@ -# Entropy coder - -This module contains range encoder and range decoder which can encode integer -data into string with cumulative distribution functions (CDF). - -## Data and CDF values - -The data to be encoded should be non-negative integers in half-open interval -`[0, m)`. Then a CDF is represented as an integral vector of length `m + 1` -where `CDF(i) = f(Pr(X < i) * 2^precision)` for i = 0,1,...,m, and `precision` -is an attribute in range `0 < precision <= 16`. The function `f` maps real -values into integers, e.g., round or floor. It is important that to encode a -number `i`, `CDF(i + 1) - CDF(i)` cannot be zero. - -Note that we used `Pr(X < i)` not `Pr(X <= i)`, and therefore CDF(0) = 0 always. - -## RangeEncode: data shapes and CDF shapes - -For each data element, its CDF has to be provided. Therefore if the shape of CDF -should be `data.shape + (m + 1,)` in NumPy-like notation. For example, if `data` -is a 2-D tensor of shape (10, 10) and its elements are in `[0, 64)`, then the -CDF tensor should have shape (10, 10, 65). - -This may make CDF tensor too large, and in many applications all data elements -may have the same probability distribution. To handle this, `RangeEncode` -supports limited broadcasting CDF into data. Broadcasting is limited in the -following sense: - -- All CDF axes but the last one is broadcasted into data but not the other way - around, -- The number of CDF axes does not extend, i.e., `CDF.ndim == data.ndim + 1`. - -In the previous example where data has shape (10, 10), the following are -acceptable CDF shapes: - -- (10, 10, 65) -- (1, 10, 65) -- (10, 1, 65) -- (1, 1, 65) - -## RangeDecode - -`RangeEncode` encodes neither data shape nor termination character. Therefore -the decoder should know how many characters are encoded into the string, and -`RangeDecode` takes the encoded data shape as the second argument. The same -shape restrictions as `RangeEncode` inputs apply here. - -## Example - -```python -data = tf.random_uniform((128, 128), 0, 10, dtype=tf.int32) - -histogram = tf.bincount(data, minlength=10, maxlength=10) -cdf = tf.cumsum(histogram, exclusive=False) -# CDF should have length m + 1. -cdf = tf.pad(cdf, [[1, 0]]) -# CDF axis count must be one more than data. -cdf = tf.reshape(cdf, [1, 1, -1]) - -# Note that data has 2^14 elements, and therefore the sum of CDF is 2^14. -data = tf.cast(data, tf.int16) -encoded = coder.range_encode(data, cdf, precision=14) -decoded = coder.range_decode(encoded, tf.shape(data), cdf, precision=14) - -# data and decoded should be the same. -sess = tf.Session() -x, y = sess.run((data, decoded)) -assert np.all(x == y) -``` - -## Authors -Sung Jin Hwang (github: [ssjhv](https://github.com/ssjhv)) and Nick Johnston -(github: [nmjohn](https://github.com/nmjohn)) diff --git a/tensorflow/contrib/coder/__init__.py b/tensorflow/contrib/coder/__init__.py index 99b8ac7595ec632b2918e6b7ca22c06dd7f0a8b3..8897312046c63c42d85e7fba5b62d2ed908dd6e9 100644 --- a/tensorflow/contrib/coder/__init__.py +++ b/tensorflow/contrib/coder/__init__.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Data compression tools.""" +"""Data compression ops.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import -from tensorflow.contrib.coder.python.layers.entropybottleneck import * from tensorflow.contrib.coder.python.ops.coder_ops import * # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck.py b/tensorflow/contrib/coder/python/layers/entropybottleneck.py deleted file mode 100644 index 0fbe3081af0b4de7f116918b3f49efe91a2d83bd..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/coder/python/layers/entropybottleneck.py +++ /dev/null @@ -1,697 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Entropy bottleneck layer.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.coder.python.ops import coder_ops - -from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras import engine -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.summary import summary - - -class EntropyBottleneck(engine.Layer): - """Entropy bottleneck layer. - - This layer can be used to model the entropy (the amount of information - conveyed) of the tensor passing through it. During training, this can be used - to impose a (soft) entropy constraint on its activations, limiting the amount - of information flowing through the layer. Note that this is distinct from - other types of bottlenecks, which reduce the dimensionality of the space, for - example. Dimensionality reduction does not limit the amount of information, - and does not enable efficient data compression per se. - - After training, this layer can be used to compress any input tensor to a - string, which may be written to a file, and to decompress a file which it - previously generated back to a reconstructed tensor (possibly on a different - machine having access to the same model checkpoint). The entropies estimated - during training or evaluation are approximately equal to the average length of - the strings in bits. - - The layer implements a flexible probability density model to estimate entropy, - which is described in the appendix of the paper (please cite the paper if you - use this code for scientific work): - - "Variational image compression with a scale hyperprior" - - Johannes Ballé, David Minnen, Saurabh Singh, Sung Jin Hwang, Nick Johnston - - https://arxiv.org/abs/1802.01436 - - The layer assumes that the input tensor is at least 2D, with a batch dimension - at the beginning and a channel dimension as specified by `data_format`. The - layer trains an independent probability density model for each channel, but - assumes that across all other dimensions, the inputs are i.i.d. (independent - and identically distributed). Because the entropy (and hence, average - codelength) is a function of the densities, this assumption may have a direct - effect on the compression performance. - - Because data compression always involves discretization, the outputs of the - layer are generally only approximations of its inputs. During training, - discretization is modeled using additive uniform noise to ensure - differentiability. The entropies computed during training are differential - entropies. During evaluation, the data is actually quantized, and the - entropies are discrete (Shannon entropies). To make sure the approximated - tensor values are good enough for practical purposes, the training phase must - be used to balance the quality of the approximation with the entropy, by - adding an entropy term to the training loss, as in the following example. - - Here, we use the entropy bottleneck to compress the latent representation of - an autoencoder. The data vectors `x` in this case are 4D tensors in - `'channels_last'` format (for example, 16x16 pixel grayscale images). - - The layer always produces exactly one auxiliary loss and one update op which - are only significant for compression and decompression. To use the compression - feature, the auxiliary loss must be minimized during or after training. After - that, the update op must be executed at least once. Here, we simply attach - them to the main training step. - - Training: - ``` - # Build autoencoder. - x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1]) - y = forward_transform(x) - entropy_bottleneck = EntropyBottleneck() - y_, likelihoods = entropy_bottleneck(y, training=True) - x_ = backward_transform(y_) - - # Information content (= predicted codelength) in bits of each batch element - # (note that taking the natural logarithm and dividing by `log(2)` is - # equivalent to taking base-2 logarithms): - bits = tf.reduce_sum(tf.log(likelihoods), axis=(1, 2, 3)) / -np.log(2) - - # Squared difference of each batch element: - squared_error = tf.reduce_sum(tf.squared_difference(x, x_), axis=(1, 2, 3)) - - # The loss is a weighted sum of mean squared error and entropy (average - # information content), where the weight controls the trade-off between - # approximation error and entropy. - main_loss = 0.5 * tf.reduce_mean(squared_error) + tf.reduce_mean(bits) - - # Minimize loss and auxiliary loss, and execute update op. - main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4) - main_step = optimizer.minimize(main_loss) - # 1e-2 is a good starting point for the learning rate of the auxiliary loss, - # assuming Adam is used. - aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-2) - aux_step = optimizer.minimize(entropy_bottleneck.losses[0]) - step = tf.group(main_step, aux_step, entropy_bottleneck.updates[0]) - ``` - - Evaluation: - ``` - # Build autoencoder. - x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1]) - y = forward_transform(x) - y_, likelihoods = EntropyBottleneck()(y, training=False) - x_ = backward_transform(y_) - - # Information content (= predicted codelength) in bits of each batch element: - bits = tf.reduce_sum(tf.log(likelihoods), axis=(1, 2, 3)) / -np.log(2) - - # Squared difference of each batch element: - squared_error = tf.reduce_sum(tf.squared_difference(x, x_), axis=(1, 2, 3)) - - # The loss is a weighted sum of mean squared error and entropy (average - # information content), where the weight controls the trade-off between - # approximation error and entropy. - loss = 0.5 * tf.reduce_mean(squared_error) + tf.reduce_mean(bits) - ``` - - To be able to compress the bottleneck tensor and decompress it in a different - session, or on a different machine, you need three items: - - The compressed representations stored as strings. - - The shape of the bottleneck for these string representations as a `Tensor`, - as well as the number of channels of the bottleneck at graph construction - time. - - The checkpoint of the trained model that was used for compression. Note: - It is crucial that the auxiliary loss produced by this layer is minimized - during or after training, and that the update op is run after training and - minimization of the auxiliary loss, but *before* the checkpoint is saved. - - Compression: - ``` - x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1]) - y = forward_transform(x) - strings = EntropyBottleneck().compress(y) - shape = tf.shape(y)[1:] - ``` - - Decompression: - ``` - strings = tf.placeholder(tf.string, shape=[None]) - shape = tf.placeholder(tf.int32, shape=[3]) - entropy_bottleneck = EntropyBottleneck(dtype=tf.float32) - y_ = entropy_bottleneck.decompress(strings, shape, channels=5) - x_ = backward_transform(y_) - ``` - Here, we assumed that the tensor produced by the forward transform has 5 - channels. - - The above four use cases can also be implemented within the same session (i.e. - on the same `EntropyBottleneck` instance), for testing purposes, etc., by - calling the object more than once. - - Arguments: - init_scale: Float. A scaling factor determining the initial width of the - probability densities. This should be chosen big enough so that the - range of values of the layer inputs roughly falls within the interval - [`-init_scale`, `init_scale`] at the beginning of training. - filters: An iterable of ints, giving the number of filters at each layer of - the density model. Generally, the more filters and layers, the more - expressive is the density model in terms of modeling more complicated - distributions of the layer inputs. For details, refer to the paper - referenced above. The default is `[3, 3, 3]`, which should be sufficient - for most practical purposes. - tail_mass: Float, between 0 and 1. The bottleneck layer automatically - determines the range of input values that should be represented based on - their frequency of occurrence. Values occurring in the tails of the - distributions will be clipped to that range during compression. - `tail_mass` determines the amount of probability mass in the tails which - is cut off in the worst case. For example, the default value of `1e-9` - means that at most 1 in a billion input samples will be clipped to the - range. - optimize_integer_offset: Boolean. Typically, the input values of this layer - are floats, which means that quantization during evaluation can be - performed with an arbitrary offset. By default, the layer determines that - offset automatically. In special situations, such as when it is known that - the layer will receive only full integer values during evaluation, it can - be desirable to set this argument to `False` instead, in order to always - quantize to full integer values. - likelihood_bound: Float. If positive, the returned likelihood values are - ensured to be greater than or equal to this value. This prevents very - large gradients with a typical entropy loss (defaults to 1e-9). - range_coder_precision: Integer, between 1 and 16. The precision of the range - coder used for compression and decompression. This trades off computation - speed with compression efficiency, where 16 is the slowest but most - efficient setting. Choosing lower values may increase the average - codelength slightly compared to the estimated entropies. - data_format: Either `'channels_first'` or `'channels_last'` (default). - trainable: Boolean. Whether the layer should be trained. - name: String. The name of the layer. - dtype: Default dtype of the layer's parameters (default of `None` means use - the type of the first input). - - Read-only properties: - init_scale: See above. - filters: See above. - tail_mass: See above. - optimize_integer_offset: See above. - likelihood_bound: See above. - range_coder_precision: See above. - data_format: See above. - name: String. See above. - dtype: See above. - trainable_variables: List of trainable variables. - non_trainable_variables: List of non-trainable variables. - variables: List of all variables of this layer, trainable and non-trainable. - updates: List of update ops of this layer. Always contains exactly one - update op, which must be run once after the last training step, before - `compress` or `decompress` is used. - losses: List of losses added by this layer. Always contains exactly one - auxiliary loss, which must be added to the training loss. - - Mutable properties: - trainable: Boolean. Whether the layer should be trained. - input_spec: Optional `InputSpec` object specifying the constraints on inputs - that can be accepted by the layer. - """ - - def __init__(self, init_scale=10, filters=(3, 3, 3), tail_mass=1e-9, - optimize_integer_offset=True, likelihood_bound=1e-9, - range_coder_precision=16, data_format="channels_last", **kwargs): - super(EntropyBottleneck, self).__init__(**kwargs) - self._init_scale = float(init_scale) - self._filters = tuple(int(f) for f in filters) - self._tail_mass = float(tail_mass) - if not 0 < self.tail_mass < 1: - raise ValueError( - "`tail_mass` must be between 0 and 1, got {}.".format(self.tail_mass)) - self._optimize_integer_offset = bool(optimize_integer_offset) - self._likelihood_bound = float(likelihood_bound) - self._range_coder_precision = int(range_coder_precision) - self._data_format = data_format - self._channel_axis(2) # trigger ValueError early - self.input_spec = engine.InputSpec(min_ndim=2) - - @property - def init_scale(self): - return self._init_scale - - @property - def filters(self): - return self._filters - - @property - def tail_mass(self): - return self._tail_mass - - @property - def optimize_integer_offset(self): - return self._optimize_integer_offset - - @property - def likelihood_bound(self): - return self._likelihood_bound - - @property - def range_coder_precision(self): - return self._range_coder_precision - - @property - def data_format(self): - return self._data_format - - def _channel_axis(self, ndim): - try: - return {"channels_first": 1, "channels_last": ndim - 1}[self.data_format] - except KeyError: - raise ValueError("Unsupported `data_format` for {} layer: {}.".format( - self.__class__.__name__, self.data_format)) - - def _logits_cumulative(self, inputs, stop_gradient): - """Evaluate logits of the cumulative densities. - - Args: - inputs: The values at which to evaluate the cumulative densities, expected - to be a `Tensor` of shape `(channels, 1, batch)`. - stop_gradient: Boolean. Whether to add `array_ops.stop_gradient` calls so - that the gradient of the output with respect to the density model - parameters is disconnected (the gradient with respect to `inputs` is - left untouched). - - Returns: - A `Tensor` of the same shape as `inputs`, containing the logits of the - cumulative densities evaluated at the given inputs. - """ - logits = inputs - - for i in range(len(self.filters) + 1): - matrix = self._matrices[i] - if stop_gradient: - matrix = array_ops.stop_gradient(matrix) - logits = math_ops.matmul(matrix, logits) - - bias = self._biases[i] - if stop_gradient: - bias = array_ops.stop_gradient(bias) - logits += bias - - if i < len(self._factors): - factor = self._factors[i] - if stop_gradient: - factor = array_ops.stop_gradient(factor) - logits += factor * math_ops.tanh(logits) - - return logits - - def build(self, input_shape): - """Builds the layer. - - Creates the variables for the network modeling the densities, creates the - auxiliary loss estimating the median and tail quantiles of the densities, - and then uses that to create the probability mass functions and the update - op that produces the discrete cumulative density functions used by the range - coder. - - Args: - input_shape: Shape of the input tensor, used to get the number of - channels. - - Raises: - ValueError: if `input_shape` doesn't specify the length of the channel - dimension. - """ - input_shape = tensor_shape.TensorShape(input_shape) - channel_axis = self._channel_axis(input_shape.ndims) - channels = input_shape[channel_axis].value - if channels is None: - raise ValueError("The channel dimension of the inputs must be defined.") - self.input_spec = engine.InputSpec( - ndim=input_shape.ndims, axes={channel_axis: channels}) - filters = (1,) + self.filters + (1,) - scale = self.init_scale ** (1 / (len(self.filters) + 1)) - - # Create variables. - self._matrices = [] - self._biases = [] - self._factors = [] - for i in range(len(self.filters) + 1): - init = np.log(np.expm1(1 / scale / filters[i + 1])) - matrix = self.add_variable( - "matrix_{}".format(i), dtype=self.dtype, - shape=(channels, filters[i + 1], filters[i]), - initializer=init_ops.Constant(init)) - matrix = nn.softplus(matrix) - self._matrices.append(matrix) - - bias = self.add_variable( - "bias_{}".format(i), dtype=self.dtype, - shape=(channels, filters[i + 1], 1), - initializer=init_ops.RandomUniform(-.5, .5)) - self._biases.append(bias) - - if i < len(self.filters): - factor = self.add_variable( - "factor_{}".format(i), dtype=self.dtype, - shape=(channels, filters[i + 1], 1), - initializer=init_ops.Zeros()) - factor = math_ops.tanh(factor) - self._factors.append(factor) - - # To figure out what range of the densities to sample, we need to compute - # the quantiles given by `tail_mass / 2` and `1 - tail_mass / 2`. Since we - # can't take inverses of the cumulative directly, we make it an optimization - # problem: - # `quantiles = argmin(|logit(cumulative) - target|)` - # where `target` is `logit(tail_mass / 2)` or `logit(1 - tail_mass / 2)`. - # Taking the logit (inverse of sigmoid) of the cumulative makes the - # representation of the right target more numerically stable. - - # Numerically stable way of computing logits of `tail_mass / 2` - # and `1 - tail_mass / 2`. - target = np.log(2 / self.tail_mass - 1) - # Compute lower and upper tail quantile as well as median. - target = constant_op.constant([-target, 0, target], dtype=self.dtype) - - def quantiles_initializer(shape, dtype=None, partition_info=None): - del partition_info # unused - assert tuple(shape[1:]) == (1, 3) - init = constant_op.constant( - [[[-self.init_scale, 0, self.init_scale]]], dtype=dtype) - return array_ops.tile(init, (shape[0], 1, 1)) - - quantiles = self.add_variable( - "quantiles", shape=(channels, 1, 3), dtype=self.dtype, - initializer=quantiles_initializer) - logits = self._logits_cumulative(quantiles, stop_gradient=True) - loss = math_ops.reduce_sum(abs(logits - target)) - self.add_loss(loss, inputs=None) - - # Save medians for `call`, `compress`, and `decompress`. - self._medians = quantiles[:, :, 1:2] - if not self.optimize_integer_offset: - self._medians = math_ops.round(self._medians) - - # Largest distance observed between lower tail quantile and median, - # or between median and upper tail quantile. - minima = math_ops.reduce_max(self._medians - quantiles[:, :, 0:1]) - maxima = math_ops.reduce_max(quantiles[:, :, 2:3] - self._medians) - minmax = math_ops.maximum(minima, maxima) - minmax = math_ops.ceil(minmax) - minmax = math_ops.maximum(minmax, 1) - - # Sample the density up to `minmax` around the median. - samples = math_ops.range(-minmax, minmax + 1, dtype=self.dtype) - samples += self._medians - - half = constant_op.constant(.5, dtype=self.dtype) - # We strip the sigmoid from the end here, so we can use the special rule - # below to only compute differences in the left tail of the sigmoid. - # This increases numerical stability (see explanation in `call`). - lower = self._logits_cumulative(samples - half, stop_gradient=True) - upper = self._logits_cumulative(samples + half, stop_gradient=True) - # Flip signs if we can move more towards the left tail of the sigmoid. - sign = -math_ops.sign(math_ops.add_n([lower, upper])) - pmf = abs(math_ops.sigmoid(sign * upper) - math_ops.sigmoid(sign * lower)) - # Add tail masses to first and last bin of pmf, as we clip values for - # compression, meaning that out-of-range values get mapped to these bins. - pmf = array_ops.concat([ - math_ops.add_n([pmf[:, 0, :1], math_ops.sigmoid(lower[:, 0, :1])]), - pmf[:, 0, 1:-1], - math_ops.add_n([pmf[:, 0, -1:], math_ops.sigmoid(-upper[:, 0, -1:])]), - ], axis=-1) - self._pmf = pmf - - cdf = coder_ops.pmf_to_quantized_cdf( - pmf, precision=self.range_coder_precision) - def cdf_getter(*args, **kwargs): - del args, kwargs # ignored - return variable_scope.get_variable( - "quantized_cdf", dtype=dtypes.int32, initializer=cdf, - trainable=False, validate_shape=False, collections=()) - # Need to provide a fake shape here since add_variable insists on it. - self._quantized_cdf = self.add_variable( - "quantized_cdf", shape=(channels, 1), dtype=dtypes.int32, - getter=cdf_getter, trainable=False) - - update_op = state_ops.assign( - self._quantized_cdf, cdf, validate_shape=False) - self.add_update(update_op, inputs=None) - - super(EntropyBottleneck, self).build(input_shape) - - def call(self, inputs, training): - """Pass a tensor through the bottleneck. - - Args: - inputs: The tensor to be passed through the bottleneck. - training: Boolean. If `True`, returns a differentiable approximation of - the inputs, and their likelihoods under the modeled probability - densities. If `False`, returns the quantized inputs and their - likelihoods under the corresponding probability mass function. These - quantities can't be used for training, as they are not differentiable, - but represent actual compression more closely. - - Returns: - values: `Tensor` with the same shape as `inputs` containing the perturbed - or quantized input values. - likelihood: `Tensor` with the same shape as `inputs` containing the - likelihood of `values` under the modeled probability distributions. - - Raises: - ValueError: if `inputs` has different `dtype` or number of channels than - a previous set of inputs the model was invoked with earlier. - """ - inputs = ops.convert_to_tensor(inputs) - ndim = self.input_spec.ndim - channel_axis = self._channel_axis(ndim) - half = constant_op.constant(.5, dtype=self.dtype) - - # Convert to (channels, 1, batch) format by commuting channels to front - # and then collapsing. - order = list(range(ndim)) - order.pop(channel_axis) - order.insert(0, channel_axis) - values = array_ops.transpose(inputs, order) - shape = array_ops.shape(values) - values = array_ops.reshape(values, (shape[0], 1, -1)) - - # Add noise or quantize. - if training: - noise = random_ops.random_uniform(array_ops.shape(values), -half, half) - values = math_ops.add_n([values, noise]) - elif self.optimize_integer_offset: - values = math_ops.round(values - self._medians) + self._medians - else: - values = math_ops.round(values) - - # Evaluate densities. - # We can use the special rule below to only compute differences in the left - # tail of the sigmoid. This increases numerical stability: sigmoid(x) is 1 - # for large x, 0 for small x. Subtracting two numbers close to 0 can be done - # with much higher precision than subtracting two numbers close to 1. - lower = self._logits_cumulative(values - half, stop_gradient=False) - upper = self._logits_cumulative(values + half, stop_gradient=False) - # Flip signs if we can move more towards the left tail of the sigmoid. - sign = -math_ops.sign(math_ops.add_n([lower, upper])) - sign = array_ops.stop_gradient(sign) - likelihood = abs( - math_ops.sigmoid(sign * upper) - math_ops.sigmoid(sign * lower)) - if self.likelihood_bound > 0: - likelihood_bound = constant_op.constant( - self.likelihood_bound, dtype=self.dtype) - # TODO(jballe): Override gradients. - likelihood = math_ops.maximum(likelihood, likelihood_bound) - - # Convert back to input tensor shape. - order = list(range(1, ndim)) - order.insert(channel_axis, 0) - values = array_ops.reshape(values, shape) - values = array_ops.transpose(values, order) - likelihood = array_ops.reshape(likelihood, shape) - likelihood = array_ops.transpose(likelihood, order) - - if not context.executing_eagerly(): - values_shape, likelihood_shape = self.compute_output_shape(inputs.shape) - values.set_shape(values_shape) - likelihood.set_shape(likelihood_shape) - - return values, likelihood - - def compress(self, inputs): - """Compress inputs and store their binary representations into strings. - - Args: - inputs: `Tensor` with values to be compressed. - - Returns: - String `Tensor` vector containing the compressed representation of each - batch element of `inputs`. - """ - with ops.name_scope(self._name_scope()): - inputs = ops.convert_to_tensor(inputs) - if not self.built: - # Check input assumptions set before layer building, e.g. input rank. - self._assert_input_compatibility(inputs) - if self.dtype is None: - self._dtype = inputs.dtype.base_dtype.name - self.build(inputs.shape) - - # Check input assumptions set after layer building, e.g. input shape. - if not context.executing_eagerly(): - self._assert_input_compatibility(inputs) - - ndim = self.input_spec.ndim - channel_axis = self._channel_axis(ndim) - # Tuple of slices for expanding dimensions of tensors below. - slices = ndim * [None] + [slice(None)] - slices[channel_axis] = slice(None) - slices = tuple(slices) - - # Expand dimensions of CDF to input dimensions, keeping the channels along - # the right dimension. - cdf = self._quantized_cdf[slices[1:]] - num_levels = array_ops.shape(cdf)[-1] - 1 - - # Bring inputs to the right range by centering the range on the medians. - half = constant_op.constant(.5, dtype=self.dtype) - medians = array_ops.squeeze(self._medians, [1, 2]) - offsets = (math_ops.cast(num_levels // 2, self.dtype) + half) - medians - # Expand offsets to input dimensions and add to inputs. - values = inputs + offsets[slices[:-1]] - - # Clip to range and cast to integers. Because we have added .5 above, and - # all values are positive, the cast effectively implements rounding. - values = math_ops.maximum(values, half) - values = math_ops.minimum( - values, math_ops.cast(num_levels, self.dtype) - half) - values = math_ops.cast(values, dtypes.int16) - - def loop_body(tensor): - return coder_ops.range_encode( - tensor, cdf, precision=self.range_coder_precision) - strings = functional_ops.map_fn( - loop_body, values, dtype=dtypes.string, back_prop=False) - - if not context.executing_eagerly(): - strings.set_shape(inputs.shape[:1]) - - return strings - - def decompress(self, strings, shape, channels=None): - """Decompress values from their compressed string representations. - - Args: - strings: A string `Tensor` vector containing the compressed data. - shape: A `Tensor` vector of int32 type. Contains the shape of the tensor - to be decompressed, excluding the batch dimension. - channels: Integer. Specifies the number of channels statically. Needs only - be set if the layer hasn't been built yet (i.e., this is the first input - it receives). - - Returns: - The decompressed `Tensor`. Its shape will be equal to `shape` prepended - with the batch dimension from `strings`. - - Raises: - ValueError: If the length of `shape` isn't available at graph construction - time. - """ - with ops.name_scope(self._name_scope()): - strings = ops.convert_to_tensor(strings) - shape = ops.convert_to_tensor(shape) - if self.built: - ndim = self.input_spec.ndim - channel_axis = self._channel_axis(ndim) - if channels is None: - channels = self.input_spec.axes[channel_axis] - else: - if not (shape.shape.is_fully_defined() and shape.shape.ndims == 1): - raise ValueError("`shape` must be a vector with known length.") - ndim = shape.shape[0].value + 1 - channel_axis = self._channel_axis(ndim) - input_shape = ndim * [None] - input_shape[channel_axis] = channels - self.build(input_shape) - - # Tuple of slices for expanding dimensions of tensors below. - slices = ndim * [None] + [slice(None)] - slices[channel_axis] = slice(None) - slices = tuple(slices) - - # Expand dimensions of CDF to input dimensions, keeping the channels along - # the right dimension. - cdf = self._quantized_cdf[slices[1:]] - num_levels = array_ops.shape(cdf)[-1] - 1 - - def loop_body(string): - return coder_ops.range_decode( - string, shape, cdf, precision=self.range_coder_precision) - outputs = functional_ops.map_fn( - loop_body, strings, dtype=dtypes.int16, back_prop=False) - outputs = math_ops.cast(outputs, self.dtype) - - medians = array_ops.squeeze(self._medians, [1, 2]) - offsets = math_ops.cast(num_levels // 2, self.dtype) - medians - outputs -= offsets[slices[:-1]] - - if not context.executing_eagerly(): - outputs_shape = ndim * [None] - outputs_shape[0] = strings.shape[0] - outputs_shape[channel_axis] = channels - outputs.set_shape(outputs_shape) - - return outputs - - def visualize(self): - """Multi-channel visualization of densities as images. - - Creates and returns an image summary visualizing the current probabilty - density estimates. The image contains one row for each channel. Within each - row, the pixel intensities are proportional to probability values, and each - row is centered on the median of the corresponding distribution. - - Returns: - The created image summary. - """ - with ops.name_scope(self._name_scope()): - image = self._pmf - image *= 255 / math_ops.reduce_max(image, axis=1, keepdims=True) - image = math_ops.cast(image + .5, dtypes.uint8) - image = image[None, :, :, None] - return summary.image("pmf", image, max_outputs=1) - - def compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape) - return input_shape, input_shape diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py b/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py deleted file mode 100644 index 798b0234ebcce7df108a0da65d1305502ce0253a..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py +++ /dev/null @@ -1,315 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests of EntropyBottleneck class.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.coder.python.layers import entropybottleneck - -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test -from tensorflow.python.training import gradient_descent - - -class EntropyBottleneckTest(test.TestCase): - - def test_noise(self): - # Tests that the noise added is uniform noise between -0.5 and 0.5. - inputs = array_ops.placeholder(dtypes.float32, (None, 1)) - layer = entropybottleneck.EntropyBottleneck() - noisy, _ = layer(inputs, training=True) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - values = np.linspace(-50, 50, 100)[:, None] - noisy, = sess.run([noisy], {inputs: values}) - self.assertFalse(np.allclose(values, noisy, rtol=0, atol=.49)) - self.assertAllClose(values, noisy, rtol=0, atol=.5) - - def test_quantization(self): - # Tests that inputs are quantized to full integer values, even after - # quantiles have been updated. - inputs = array_ops.placeholder(dtypes.float32, (None, 1)) - layer = entropybottleneck.EntropyBottleneck(optimize_integer_offset=False) - quantized, _ = layer(inputs, training=False) - opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) - self.assertTrue(len(layer.losses) == 1) - step = opt.minimize(layer.losses[0]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(step) - values = np.linspace(-50, 50, 100)[:, None] - quantized, = sess.run([quantized], {inputs: values}) - self.assertAllClose(np.around(values), quantized, rtol=0, atol=1e-6) - - def test_quantization_optimized_offset(self): - # Tests that inputs are not quantized to full integer values after quantiles - # have been updated. However, the difference between input and output should - # be between -0.5 and 0.5, and the offset must be consistent. - inputs = array_ops.placeholder(dtypes.float32, (None, 1)) - layer = entropybottleneck.EntropyBottleneck(optimize_integer_offset=True) - quantized, _ = layer(inputs, training=False) - opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) - self.assertTrue(len(layer.losses) == 1) - step = opt.minimize(layer.losses[0]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(step) - values = np.linspace(-50, 50, 100)[:, None] - quantized, = sess.run([quantized], {inputs: values}) - self.assertAllClose(values, quantized, rtol=0, atol=.5) - diff = np.ravel(np.around(values) - quantized) % 1 - self.assertAllClose(diff, np.full_like(diff, diff[0]), rtol=0, atol=5e-6) - self.assertNotEqual(diff[0], 0) - - def test_codec(self): - # Tests that inputs are compressed and decompressed correctly, and quantized - # to full integer values, even after quantiles have been updated. - inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_last", init_scale=60, - optimize_integer_offset=False) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) - self.assertTrue(len(layer.losses) == 1) - step = opt.minimize(layer.losses[0]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(step) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = np.linspace(-50, 50, 100)[None, :, None] - decoded, = sess.run([decoded], {inputs: values}) - self.assertAllClose(np.around(values), decoded, rtol=0, atol=1e-6) - - def test_codec_optimized_offset(self): - # Tests that inputs are compressed and decompressed correctly, and not - # quantized to full integer values after quantiles have been updated. - # However, the difference between input and output should be between -0.5 - # and 0.5, and the offset must be consistent. - inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_last", init_scale=60, - optimize_integer_offset=True) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) - self.assertTrue(len(layer.losses) == 1) - step = opt.minimize(layer.losses[0]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(step) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = np.linspace(-50, 50, 100)[None, :, None] - decoded, = sess.run([decoded], {inputs: values}) - self.assertAllClose(values, decoded, rtol=0, atol=.5) - diff = np.ravel(np.around(values) - decoded) % 1 - self.assertAllClose(diff, np.full_like(diff, diff[0]), rtol=0, atol=5e-6) - self.assertNotEqual(diff[0], 0) - - def test_codec_clipping(self): - # Tests that inputs are compressed and decompressed correctly, and clipped - # to the expected range. - inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_last", init_scale=40) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = np.linspace(-50, 50, 100)[None, :, None] - decoded, = sess.run([decoded], {inputs: values}) - expected = np.clip(np.around(values), -40, 40) - self.assertAllClose(expected, decoded, rtol=0, atol=1e-6) - - def test_channels_last(self): - # Test the layer with more than one channel and multiple input dimensions, - # with the channels in the last dimension. - inputs = array_ops.placeholder(dtypes.float32, (None, None, None, 2)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_last", init_scale=50) - noisy, _ = layer(inputs, training=True) - quantized, _ = layer(inputs, training=False) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = 5 * np.random.normal(size=(7, 5, 3, 2)) - noisy, quantized, decoded = sess.run( - [noisy, quantized, decoded], {inputs: values}) - self.assertAllClose(values, noisy, rtol=0, atol=.5) - self.assertAllClose(values, quantized, rtol=0, atol=.5) - self.assertAllClose(values, decoded, rtol=0, atol=.5) - - def test_channels_first(self): - # Test the layer with more than one channel and multiple input dimensions, - # with the channel dimension right after the batch dimension. - inputs = array_ops.placeholder(dtypes.float32, (None, 3, None, None)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_first", init_scale=50) - noisy, _ = layer(inputs, training=True) - quantized, _ = layer(inputs, training=False) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = 5 * np.random.normal(size=(2, 3, 5, 7)) - noisy, quantized, decoded = sess.run( - [noisy, quantized, decoded], {inputs: values}) - self.assertAllClose(values, noisy, rtol=0, atol=.5) - self.assertAllClose(values, quantized, rtol=0, atol=.5) - self.assertAllClose(values, decoded, rtol=0, atol=.5) - - def test_compress(self): - # Test compression and decompression, and produce test data for - # `test_decompress`. If you set the constant at the end to `True`, this test - # will fail and the log will contain the new test data. - inputs = array_ops.placeholder(dtypes.float32, (2, 3, 10)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_first", filters=(), init_scale=2) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = 5 * np.random.uniform(size=(2, 3, 10)) - 2.5 - bitstrings, quantized_cdf, decoded = sess.run( - [bitstrings, layer._quantized_cdf, decoded], {inputs: values}) - self.assertAllClose(values, decoded, rtol=0, atol=.5) - # Set this constant to `True` to log new test data for `test_decompress`. - if False: # pylint:disable=using-constant-test - assert False, (bitstrings, quantized_cdf, decoded) - - # Data generated by `test_compress`. - # pylint:disable=g-inconsistent-quotes,bad-whitespace - bitstrings = np.array([ - b'\x1e\xbag}\xc2\xdaN\x8b\xbd.', - b'\x8dF\xf0%\x1cv\xccllW' - ], dtype=object) - - quantized_cdf = np.array([ - [ 0, 15636, 22324, 30145, 38278, 65536], - [ 0, 19482, 26927, 35052, 42904, 65535], - [ 0, 21093, 28769, 36919, 44578, 65536] - ], dtype=np.int32) - - expected = np.array([ - [[-2., 1., 0., -2., -1., -2., -2., -2., 2., -1.], - [ 1., 2., 1., 0., -2., -2., 1., 2., 0., 1.], - [ 2., 0., -2., 2., 0., -1., -2., 0., 2., 0.]], - [[ 1., 2., 0., -1., 1., 2., 1., 1., 2., -2.], - [ 2., -1., -1., 0., -1., 2., 0., 2., -2., 2.], - [ 2., -2., -2., -1., -2., 1., -2., 0., 0., 0.]] - ], dtype=np.float32) - # pylint:enable=g-inconsistent-quotes,bad-whitespace - - def test_decompress(self): - # Test that decompression of values compressed with a previous version - # works, i.e. that the file format doesn't change across revisions. - bitstrings = array_ops.placeholder(dtypes.string) - input_shape = array_ops.placeholder(dtypes.int32) - quantized_cdf = array_ops.placeholder(dtypes.int32) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_first", filters=(), dtype=dtypes.float32) - layer.build(self.expected.shape) - layer._quantized_cdf = quantized_cdf - decoded = layer.decompress(bitstrings, input_shape[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - decoded, = sess.run([decoded], { - bitstrings: self.bitstrings, input_shape: self.expected.shape, - quantized_cdf: self.quantized_cdf}) - self.assertAllClose(self.expected, decoded, rtol=0, atol=1e-6) - - def test_build_decompress(self): - # Test that layer can be built when `decompress` is the first call to it. - bitstrings = array_ops.placeholder(dtypes.string) - input_shape = array_ops.placeholder(dtypes.int32, shape=[3]) - layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32) - layer.decompress(bitstrings, input_shape[1:], channels=5) - self.assertTrue(layer.built) - - def test_pmf_normalization(self): - # Test that probability mass functions are normalized correctly. - layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32) - layer.build((None, 10)) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - pmf, = sess.run([layer._pmf]) - self.assertAllClose(np.ones(10), np.sum(pmf, axis=-1), rtol=0, atol=1e-6) - - def test_visualize(self): - # Test that summary op can be constructed. - layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32) - layer.build((None, 10)) - summary = layer.visualize() - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run([summary]) - - def test_normalization(self): - # Test that densities are normalized correctly. - inputs = array_ops.placeholder(dtypes.float32, (None, 1)) - layer = entropybottleneck.EntropyBottleneck(filters=(2,)) - _, likelihood = layer(inputs, training=True) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - x = np.repeat(np.arange(-200, 201), 1000)[:, None] - likelihood, = sess.run([likelihood], {inputs: x}) - self.assertEqual(x.shape, likelihood.shape) - integral = np.sum(likelihood) * .001 - self.assertAllClose(1, integral, rtol=0, atol=1e-4) - - def test_entropy_estimates(self): - # Test that entropy estimates match actual range coding. - inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) - layer = entropybottleneck.EntropyBottleneck( - filters=(2, 3), data_format="channels_last") - _, likelihood = layer(inputs, training=True) - diff_entropy = math_ops.reduce_sum(math_ops.log(likelihood)) / -np.log(2) - _, likelihood = layer(inputs, training=False) - disc_entropy = math_ops.reduce_sum(math_ops.log(likelihood)) / -np.log(2) - bitstrings = layer.compress(inputs) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - diff_entropy, disc_entropy, bitstrings = sess.run( - [diff_entropy, disc_entropy, bitstrings], - {inputs: np.random.normal(size=(1, 10000, 1))}) - codelength = 8 * sum(len(bitstring) for bitstring in bitstrings) - self.assertAllClose(diff_entropy, disc_entropy, rtol=5e-3, atol=0) - self.assertAllClose(disc_entropy, codelength, rtol=5e-3, atol=0) - self.assertGreater(codelength, disc_entropy) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/constrained_optimization/README.md b/tensorflow/contrib/constrained_optimization/README.md index c65a150464efc1e77419040f66f36fc6756325aa..cb1dd7d836ae11700b2ffaaff4fda5b7f943f87d 100644 --- a/tensorflow/contrib/constrained_optimization/README.md +++ b/tensorflow/contrib/constrained_optimization/README.md @@ -46,7 +46,7 @@ document. Imagine that we want to constrain the recall of a binary classifier to be at least 90%. Since the recall is proportional to the number of true positive classifications, which itself is a sum of indicator functions, this constraint -is non-differentible, and therefore cannot be used in a problem that will be +is non-differentiable, and therefore cannot be used in a problem that will be optimized using a (stochastic) gradient-based algorithm. For this and similar problems, TFCO supports so-called *proxy constraints*, diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py index 04014ab4aebd6d9cd70653c53f9361320e803329..ff846b191a34e3f3b4aa35671ca22b96b963db80 100644 --- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py @@ -150,7 +150,7 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix): "matrix must be two dimensional (instead is %d-dimensional)" % matrix_shape.ndims) if matrix_shape[0] != matrix_shape[1]: - raise ValueError("matrix must be be square (instead has shape (%d,%d))" % + raise ValueError("matrix must be square (instead has shape (%d,%d))" % (matrix_shape[0], matrix_shape[1])) dimension = matrix_shape[0].value if dimension is None: @@ -169,8 +169,8 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix): del old_inactive # Needed by the condition, but not the body. iteration += 1 scale = (1.0 - standard_ops.reduce_sum( - matrix, axis=0, keep_dims=True)) / standard_ops.maximum( - 1.0, standard_ops.reduce_sum(inactive, axis=0, keep_dims=True)) + matrix, axis=0, keepdims=True)) / standard_ops.maximum( + 1.0, standard_ops.reduce_sum(inactive, axis=0, keepdims=True)) matrix += scale * inactive new_inactive = standard_ops.to_float(matrix > 0) matrix *= new_inactive @@ -206,10 +206,10 @@ def _project_log_stochastic_matrix_wrt_kl_divergence(log_matrix): # For numerical reasons, make sure that the largest matrix element is zero # before exponentiating. - log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keep_dims=True) + log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keepdims=True) log_matrix -= standard_ops.log( standard_ops.reduce_sum( - standard_ops.exp(log_matrix), axis=0, keep_dims=True)) + standard_ops.exp(log_matrix), axis=0, keepdims=True)) return log_matrix diff --git a/tensorflow/contrib/control_flow/BUILD b/tensorflow/contrib/control_flow/BUILD deleted file mode 100644 index e8036d63aeeac224b226899c036416a06b4ffe65..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/control_flow/BUILD +++ /dev/null @@ -1,53 +0,0 @@ -# New implementations of control flow ops - -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//visibility:public"]) - -load("//tensorflow:tensorflow.bzl", "tf_py_test") - -py_library( - name = "control_flow", - srcs = ["__init__.py"], - srcs_version = "PY2AND3", - deps = [ - ":cond_v2", - ], -) - -py_library( - name = "cond_v2", - srcs = ["python/cond_v2.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:c_api_util", - "//tensorflow/python:framework_ops", - "//tensorflow/python:function", - "//tensorflow/python:function_def_to_graph", - "//tensorflow/python:functional_ops_gen", - "//tensorflow/python:gradients", - "//tensorflow/python:pywrap_tensorflow", - "//tensorflow/python:util", - ], -) - -tf_py_test( - name = "cond_v2_test", - size = "small", - srcs = ["python/cond_v2_test.py"], - additional_deps = [ - ":cond_v2", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework", - "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:training", - ], - grpc_enabled = True, -) diff --git a/tensorflow/contrib/control_flow/python/cond_v2_test.py b/tensorflow/contrib/control_flow/python/cond_v2_test.py deleted file mode 100644 index 338601aa2c5ee8ffc97aa2e07ff05a2d17531936..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/control_flow/python/cond_v2_test.py +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Tests for cond_v2.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.control_flow.python import cond_v2 -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import math_ops -from tensorflow.python.platform import test -from tensorflow.python.training import saver - - -class NewCondTest(test.TestCase): - - def _testCond(self, true_fn, false_fn, train_vals): - pred = array_ops.placeholder(dtypes.bool, name="pred") - - expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected") - actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual") - - expected_grad = gradients_impl.gradients(expected, train_vals) - actual_grad = gradients_impl.gradients(actual, train_vals) - - with self.test_session() as sess: - expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run( - (expected, actual, expected_grad, actual_grad), {pred: True}) - self.assertEqual(expected_val, actual_val) - self.assertEqual(expected_grad_val, actual_grad_val) - - expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run( - (expected, actual, expected_grad, actual_grad), {pred: False}) - self.assertEqual(expected_val, actual_val) - self.assertEqual(expected_grad_val, actual_grad_val) - - def testBasic(self): - x = constant_op.constant(1.0, name="x") - y = constant_op.constant(2.0, name="y") - - def true_fn(): - return x * 2.0 - - def false_fn(): - return y * 3.0 - - self._testCond(true_fn, false_fn, [x]) - self._testCond(true_fn, false_fn, [x, y]) - self._testCond(true_fn, false_fn, [y]) - - def testBasic2(self): - x = constant_op.constant(1.0, name="x") - y = constant_op.constant(2.0, name="y") - - def true_fn(): - return x * y * 2.0 - - def false_fn(): - return 2.0 - - self._testCond(true_fn, false_fn, [x]) - self._testCond(true_fn, false_fn, [x, y]) - self._testCond(true_fn, false_fn, [y]) - - def testNoInputs(self): - pred = array_ops.placeholder(dtypes.bool, name="pred") - - def true_fn(): - return constant_op.constant(1.0) - - def false_fn(): - return constant_op.constant(2.0) - - out = cond_v2.cond_v2(pred, true_fn, false_fn) - - with self.test_session() as sess: - self.assertEqual(sess.run(out, {pred: True}), [1.0]) - self.assertEqual(sess.run(out, {pred: False}), [2.0]) - - def testSecondDerivative(self): - pred = array_ops.placeholder(dtypes.bool, name="pred") - x = constant_op.constant(3.0, name="x") - - def true_fn(): - return math_ops.pow(x, 3) - - def false_fn(): - return x - - cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond") - cond_grad = gradients_impl.gradients(cond, [x]) - cond_grad_grad = gradients_impl.gradients(cond_grad, [x]) - - with self.test_session() as sess: - # d[x^3]/dx = 3x^2 - true_val = sess.run(cond_grad, {pred: True}) - self.assertEqual(true_val, [27.0]) - # d[x]/dx = 1 - false_val = sess.run(cond_grad, {pred: False}) - self.assertEqual(false_val, [1.0]) - - true_val = sess.run(cond_grad_grad, {pred: True}) - # d2[x^3]/dx2 = 6x - self.assertEqual(true_val, [18.0]) - false_val = sess.run(cond_grad_grad, {pred: False}) - # d2[x]/dx2 = 0 - self.assertEqual(false_val, [0.0]) - - def testGradientOfDeserializedCond(self): - with ops.Graph().as_default(): - pred = array_ops.placeholder(dtypes.bool, name="pred") - x = constant_op.constant(3.0, name="x") - ops.add_to_collection("x", x) - - def true_fn(): - return math_ops.pow(x, 3) - - def false_fn(): - return x - - ops.add_to_collection("pred", pred) - cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond") - for c in cond: - ops.add_to_collection("cond", c) - meta_graph = saver.export_meta_graph() - - with ops.Graph().as_default() as g: - saver.import_meta_graph(meta_graph) - x = ops.get_collection("x")[0] - pred = ops.get_collection("pred")[0] - cond = ops.get_collection("cond") - cond_grad = gradients_impl.gradients(cond, [x], name="cond_grad") - cond_grad_grad = gradients_impl.gradients( - cond_grad, [x], name="cond_grad_grad") - with self.test_session(graph=g) as sess: - # d[x^3]/dx = 3x^2 - true_val = sess.run(cond_grad, {pred: True}) - self.assertEqual(true_val, [27.0]) - # d[x]/dx = 1 - false_val = sess.run(cond_grad, {pred: False}) - self.assertEqual(false_val, [1.0]) - - true_val = sess.run(cond_grad_grad, {pred: True}) - # d2[x^3]/dx2 = 6x - self.assertEqual(true_val, [18.0]) - false_val = sess.run(cond_grad_grad, {pred: False}) - # d2[x]/dx2 = 0 - self.assertEqual(false_val, [0.0]) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index a0dd3881a86c19e47ccb65f84a2477a55626b81c..6c9ab6aeb87fd39b22ab4f28d69b432b15899a13 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -18,7 +18,7 @@ These functions allow for recursive copying of elements (ops and variables) from one graph to another. The copied elements are initialized inside a user-specified scope in the other graph. There are separate functions to copy ops and variables. -There is also a function to retrive the copied version of an op from the +There is also a function to retrieve the copied version of an op from the first graph inside a scope in the second graph. @@copy_op_to_graph @@ -77,7 +77,7 @@ def copy_variable_to_graph(org_instance, to_graph, scope=''): else: collections.append(scope + '/' + name) - #See if its trainable. + #See if it's trainable. trainable = ( org_instance in org_instance.graph.get_collection( ops.GraphKeys.TRAINABLE_VARIABLES)) @@ -162,7 +162,7 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''): if isinstance(org_instance, ops.Tensor): - #If its a Tensor, it is one of the outputs of the underlying + #If it's a Tensor, it is one of the outputs of the underlying #op. Therefore, copy the op itself and return the appropriate #output. op = org_instance.op @@ -219,8 +219,10 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''): op_def) #Use Graph's hidden methods to add the op to_graph._record_op_seen_by_control_dependencies(new_op) - for device_function in reversed(to_graph._device_function_stack): + # pylint: disable=protected-access + for device_function in to_graph._device_functions_outer_to_inner: new_op._set_device(device_function(new_op)) + # pylint: enable=protected-access return new_op diff --git a/tensorflow/contrib/crf/__init__.py b/tensorflow/contrib/crf/__init__.py index 046c509626bc2eb20a65c0b38495ff37c294e0e1..615e62b16f1906dafa22a12cc7275a2335e8df88 100644 --- a/tensorflow/contrib/crf/__init__.py +++ b/tensorflow/contrib/crf/__init__.py @@ -20,6 +20,7 @@ See the @{$python/contrib.crf} guide. @@crf_decode @@crf_log_likelihood @@crf_log_norm +@@crf_multitag_sequence_score @@crf_sequence_score @@crf_unary_score @@CrfDecodeBackwardRnnCell @@ -36,6 +37,7 @@ from tensorflow.contrib.crf.python.ops.crf import crf_binary_score from tensorflow.contrib.crf.python.ops.crf import crf_decode from tensorflow.contrib.crf.python.ops.crf import crf_log_likelihood from tensorflow.contrib.crf.python.ops.crf import crf_log_norm +from tensorflow.contrib.crf.python.ops.crf import crf_multitag_sequence_score from tensorflow.contrib.crf.python.ops.crf import crf_sequence_score from tensorflow.contrib.crf.python.ops.crf import crf_unary_score from tensorflow.contrib.crf.python.ops.crf import CrfDecodeBackwardRnnCell diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index 74f2ec22ffaab1654e5cd38169258fb87d307ad4..8cfe14205927bf7763cf36fa31012ab10fce995c 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -31,6 +31,15 @@ from tensorflow.python.platform import test class CrfTest(test.TestCase): + def calculateSequenceScore(self, inputs, transition_params, tag_indices, + sequence_lengths): + expected_unary_score = sum( + inputs[i][tag_indices[i]] for i in range(sequence_lengths)) + expected_binary_score = sum( + transition_params[tag_indices[i], tag_indices[i + 1]] + for i in range(sequence_lengths - 1)) + return expected_unary_score + expected_binary_score + def testCrfSequenceScore(self): transition_params = np.array( [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) @@ -60,14 +69,55 @@ class CrfTest(test.TestCase): transition_params=constant_op.constant(transition_params)) sequence_score = array_ops.squeeze(sequence_score, [0]) tf_sequence_score = sess.run(sequence_score) - expected_unary_score = sum(inputs[i][tag_indices[i]] - for i in range(sequence_lengths)) - expected_binary_score = sum( - transition_params[tag_indices[i], tag_indices[i + 1]] - for i in range(sequence_lengths - 1)) - expected_sequence_score = expected_unary_score + expected_binary_score + expected_sequence_score = self.calculateSequenceScore( + inputs, transition_params, tag_indices, sequence_lengths) self.assertAllClose(tf_sequence_score, expected_sequence_score) + def testCrfMultiTagSequenceScore(self): + transition_params = np.array( + [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[4, 5, -3]], + dtype=np.float32), + ] + tag_bitmap_list = [ + np.array( + [[True, True, False], [True, False, True], [False, True, True], + [True, False, True]], + dtype=np.bool), + np.array([[True, True, False]], dtype=np.bool) + ] + for sequence_lengths, inputs, tag_bitmap in zip( + sequence_lengths_list, inputs_list, tag_bitmap_list): + with self.test_session() as sess: + sequence_score = crf.crf_multitag_sequence_score( + inputs=array_ops.expand_dims(inputs, 0), + tag_bitmap=array_ops.expand_dims(tag_bitmap, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params)) + sequence_score = array_ops.squeeze(sequence_score, [0]) + tf_sum_sequence_score = sess.run(sequence_score) + all_indices_list = [ + single_index_bitmap.nonzero()[0] + for single_index_bitmap in tag_bitmap[:sequence_lengths] + ] + expected_sequence_scores = [ + self.calculateSequenceScore(inputs, transition_params, indices, + sequence_lengths) + for indices in itertools.product(*all_indices_list) + ] + expected_log_sum_exp_sequence_scores = np.logaddexp.reduce( + expected_sequence_scores) + self.assertAllClose(tf_sum_sequence_score, + expected_log_sum_exp_sequence_scores) + def testCrfUnaryScore(self): inputs = np.array( [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) @@ -108,7 +158,7 @@ class CrfTest(test.TestCase): # Test both the length-1 and regular cases. sequence_lengths_list = [ np.array(3, dtype=np.int32), - np.array(1, dtype=np.int32) + np.array(1, dtype=np.int64) ] inputs_list = [ np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], @@ -241,7 +291,7 @@ class CrfTest(test.TestCase): # Test both the length-1 and regular cases. sequence_lengths_list = [ np.array(3, dtype=np.int32), - np.array(1, dtype=np.int32) + np.array(1, dtype=np.int64) ] inputs_list = [ np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 2d2cbdc1990ed9d8e58c0032cbc141a52271838f..2a91dcb63a80016e62d10d1310ca57e3e54434c5 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -67,7 +67,7 @@ __all__ = [ "crf_sequence_score", "crf_log_norm", "crf_log_likelihood", "crf_unary_score", "crf_binary_score", "CrfForwardRnnCell", "viterbi_decode", "crf_decode", "CrfDecodeForwardRnnCell", - "CrfDecodeBackwardRnnCell" + "CrfDecodeBackwardRnnCell", "crf_multitag_sequence_score" ] @@ -114,6 +114,56 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths, false_fn=_multi_seq_fn) +def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths, + transition_params): + """Computes the unnormalized score of all tag sequences matching tag_bitmap. + + tag_bitmap enables more than one tag to be considered correct at each time + step. This is useful when an observed output at a given time step is + consistent with more than one tag, and thus the log likelihood of that + observation must take into account all possible consistent tags. + + Using one-hot vectors in tag_bitmap gives results identical to + crf_sequence_score. + + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor + representing all active tags at each index for which to calculate the + unnormalized score. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] transition matrix. + Returns: + sequence_scores: A [batch_size] vector of unnormalized sequence scores. + """ + + # If max_seq_len is 1, we skip the score calculation and simply gather the + # unary potentials of all active tags. + def _single_seq_fn(): + filtered_inputs = array_ops.where( + tag_bitmap, inputs, + array_ops.fill(array_ops.shape(inputs), float("-inf"))) + return math_ops.reduce_logsumexp( + filtered_inputs, axis=[1, 2], keepdims=False) + + def _multi_seq_fn(): + # Compute the logsumexp of all scores of sequences matching the given tags. + filtered_inputs = array_ops.where( + tag_bitmap, inputs, + array_ops.fill(array_ops.shape(inputs), float("-inf"))) + return crf_log_norm( + inputs=filtered_inputs, + sequence_lengths=sequence_lengths, + transition_params=transition_params) + + return utils.smart_cond( + pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1], + 1), + true_fn=_single_seq_fn, + false_fn=_multi_seq_fn) + + def crf_log_norm(inputs, sequence_lengths, transition_params): """Computes the normalization for a CRF. @@ -498,7 +548,9 @@ def crf_decode(potentials, transition_params, sequence_length): initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] # Sequence length is not allowed to be less than zero. - sequence_length_less_one = math_ops.maximum(0, sequence_length - 1) + sequence_length_less_one = math_ops.maximum( + constant_op.constant(0, dtype=sequence_length.dtype), + sequence_length - 1) backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O] crf_fwd_cell, inputs=inputs, diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 8285ea04926d3a24e9c22bd6d69eb7a48f5e3a85..252ea1560d7f5be3799686d6d91ae9a6d262ac0a 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -768,7 +768,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLSTMCheckpointableSingleLayer(self): num_units = 2 direction = CUDNN_RNN_UNIDIRECTION @@ -781,7 +781,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGRUCheckpointableSingleLayer(self): num_units = 2 direction = CUDNN_RNN_UNIDIRECTION @@ -826,7 +826,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCudnnCompatibleLSTMCheckpointablMultiLayer(self): num_units = 2 num_layers = 3 diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index d58198faf353aab68430d2fa153a18de359112de..e26d56c8579e110d61c73c6154b82f47f0093687 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -56,7 +56,7 @@ class _CudnnRNN(base_layer.Layer): Cudnn RNNs have two major differences from other platform-independent RNNs tf provides: * Cudnn LSTM and GRU are mathematically different from their tf counterparts. - (e.g. @{tf.contrib.rnn.LSTMBlockCell} and @{tf.nn.rnn_cell.GRUCell}. + (e.g. `tf.contrib.rnn.LSTMBlockCell` and `tf.nn.rnn_cell.GRUCell`. * Cudnn-trained checkpoints are not directly compatible with tf RNNs: * They use a single opaque parameter buffer for the entire (possibly) multi-layer multi-directional RNN; Whereas tf RNN weights are per-cell and @@ -182,7 +182,7 @@ class _CudnnRNN(base_layer.Layer): dropout: dropout rate, a number between [0, 1]. Dropout is applied between each layer (no dropout is applied for a model with a single layer). When set to 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. dtype: tf.float16, tf.float32 or tf.float64 kernel_initializer: starting value to initialize the weight. diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 8822a7523f6b168f569e29970c9c29f2eb3614fc..2c92f31788378c2a9f01183bc04b035668b59b59 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -33,7 +33,7 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import base as checkpointable_lib +from tensorflow.python.training.checkpointable import tracking as checkpointable_lib CUDNN_RNN_UNIDIRECTION = "unidirectional" CUDNN_RNN_BIDIRECTION = "bidirectional" @@ -61,8 +61,8 @@ _WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): """Cudnn Compatible LSTMCell. - A simple wrapper around @{tf.contrib.rnn.LSTMBlockCell} to use along with - @{tf.contrib.cudnn_rnn.CudnnLSTM}. The latter's params can be used by + A simple wrapper around `tf.contrib.rnn.LSTMBlockCell` to use along with + `tf.contrib.cudnn_rnn.CudnnLSTM`. The latter's params can be used by this cell seamlessly. """ @@ -76,8 +76,8 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): """Cudnn Compatible GRUCell. - A GRU impl akin to @{tf.nn.rnn_cell.GRUCell} to use along with - @{tf.contrib.cudnn_rnn.CudnnGRU}. The latter's params can be used by + A GRU impl akin to `tf.nn.rnn_cell.GRUCell` to use along with + `tf.contrib.cudnn_rnn.CudnnGRU`. The latter's params can be used by it seamlessly. It differs from platform-independent GRUs in how the new memory gate is @@ -97,7 +97,7 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): $$h_t = (1 - u_t) .* h'_t + u_t .* h_t-1$$ ``` - Other GRU (see @{tf.nn.rnn_cell.GRUCell} and @{tf.contrib.rnn.GRUBlockCell}): + Other GRU (see `tf.nn.rnn_cell.GRUCell` and `tf.contrib.rnn.GRUBlockCell`): ```python # new memory gate \\(h'_t = tanh(x_t * W_h + (r_t .* h_t-1) * R_h + b_{Wh})\\) @@ -891,7 +891,7 @@ def _cudnn_rnn(inputs, direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. name: name of the operation. Returns: @@ -957,7 +957,7 @@ def cudnn_lstm(inputs, direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. name: name of the operation. Returns: @@ -998,7 +998,7 @@ def _cudnn_rnn_no_input_c(inputs, direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. name: name of the operation. Returns: @@ -1040,7 +1040,7 @@ def cudnn_gru(inputs, direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. name: name of the operation. Returns: @@ -1079,7 +1079,7 @@ def cudnn_rnn_relu(inputs, direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. name: name of the operation. Returns: @@ -1119,7 +1119,7 @@ def cudnn_rnn_tanh(inputs, direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. name: name of the operation. Returns: @@ -1161,7 +1161,7 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode, direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. name: name of the operation. Returns: @@ -1224,7 +1224,7 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode, direction: the direction model that the model operates. Could be either 'unidirectional' or 'bidirectional' dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. name: name of the operation. Returns: @@ -1282,7 +1282,7 @@ def cudnn_rnn_opaque_params_size(rnn_mode, 'unidirectional' or 'bidirectional' dtype: one of tf.float32 or tf.float64. dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. name: name of the operation. Returns: @@ -1349,7 +1349,7 @@ class _CudnnRNN(object): 'unidirectional' or 'bidirectional' dtype: dtype of params, tf.float32 or tf.float64. dropout: whether to enable dropout. With it is 0, dropout is disabled. - seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + seed: the op seed used for initializing dropout. See `tf.set_random_seed` for behavior. Raises: ValueError: if direction is invalid. diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 1af1ed08b53ee04367eb316d5c9caa0216f2e88d..dbfff9b4f86065de9736eed72de173bc1bef35d6 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -15,43 +15,51 @@ """Experimental API for building input pipelines. This module contains experimental `Dataset` sources and transformations that can -be used in conjunction with the @{tf.data.Dataset} API. Note that the +be used in conjunction with the `tf.data.Dataset` API. Note that the `tf.contrib.data` API is not subject to the same backwards compatibility guarantees as `tf.data`, but we will provide deprecation advice in advance of removing existing functionality. -See the @{$datasets$Importing Data} Programmer's Guide for an overview. +See @{$guide/datasets$Importing Data} for an overview. @@Counter @@CheckpointInputPipelineHook @@CsvDataset +@@RandomDataset +@@Reducer @@SqlDataset +@@TFRecordWriter @@assert_element_shape @@batch_and_drop_remainder @@bucket_by_sequence_length @@choose_from_datasets +@@copy_to_device @@dense_to_sparse_batch @@enumerate_dataset + +@@get_single_element +@@group_by_reducer @@group_by_window @@ignore_errors @@make_batched_features_dataset @@make_csv_dataset @@make_saveable_from_iterator + @@map_and_batch @@padded_batch_and_drop_remainder @@parallel_interleave @@prefetch_to_device @@read_batch_features @@rejection_resample +@@reduce_dataset @@sample_from_datasets @@scan @@shuffle_and_repeat @@sliding_window_batch @@sloppy_interleave @@unbatch - -@@get_single_element +@@unique """ from __future__ import absolute_import @@ -70,14 +78,20 @@ from tensorflow.contrib.data.python.ops.counter import Counter from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset from tensorflow.contrib.data.python.ops.error_ops import ignore_errors from tensorflow.contrib.data.python.ops.get_single_element import get_single_element +from tensorflow.contrib.data.python.ops.get_single_element import reduce_dataset from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length +from tensorflow.contrib.data.python.ops.grouping import group_by_reducer from tensorflow.contrib.data.python.ops.grouping import group_by_window +from tensorflow.contrib.data.python.ops.grouping import Reducer +from tensorflow.contrib.data.python.ops.interleave_ops import choose_from_datasets from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator +from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device +from tensorflow.contrib.data.python.ops.random_ops import RandomDataset from tensorflow.contrib.data.python.ops.readers import CsvDataset from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset from tensorflow.contrib.data.python.ops.readers import make_csv_dataset @@ -87,6 +101,8 @@ from tensorflow.contrib.data.python.ops.resampling import rejection_resample from tensorflow.contrib.data.python.ops.scan_ops import scan from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch +from tensorflow.contrib.data.python.ops.unique import unique +from tensorflow.contrib.data.python.ops.writers import TFRecordWriter # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD index 7b69e10441eba3e38c979d5715c16699ac2710ed..2e249f5c14ab111ae412ff3288acc25de8d7aa11 100644 --- a/tensorflow/contrib/data/kernels/BUILD +++ b/tensorflow/contrib/data/kernels/BUILD @@ -37,6 +37,7 @@ cc_library( "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], + alwayslink = 1, ) cc_library( @@ -58,6 +59,7 @@ cc_library( "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], + alwayslink = 1, ) cc_library( @@ -68,11 +70,24 @@ cc_library( "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], + alwayslink = 1, +) + +cc_library( + name = "assert_next_dataset_op", + srcs = ["assert_next_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, ) cc_library( name = "dataset_kernels", deps = [ + ":assert_next_dataset_op", ":csv_dataset_op", ":directed_interleave_dataset_op", ":ignore_errors_dataset_op", diff --git a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e36c9c0634235022362b59a6699b4d550d6d0eee --- /dev/null +++ b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc @@ -0,0 +1,153 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. +class AssertNextDatasetOp : public UnaryDatasetOpKernel { + public: + explicit AssertNextDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + std::vector transformations; + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "transformations", + &transformations)); + *output = + new Dataset(ctx, input, transformations, output_types_, output_shapes_); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const std::vector& transformations, + const DataTypeVector& output_types, + const std::vector& output_shapes) + : DatasetBase(DatasetContext(ctx)), + input_(input), + transformations_(transformations), + output_types_(output_types), + output_shapes_(output_shapes) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::Assert")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "AssertNextDatasetOp::Dataset"; + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + Node* transformations_node = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, {input_graph_node, transformations_node}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + std::vector tokens = + str_util::Split(prefix(), ':', str_util::SkipEmpty()); + if (dataset()->transformations_.size() > tokens.size() - 2) { + return errors::InvalidArgument( + "Asserted next ", dataset()->transformations_.size(), + " transformations but encountered only ", tokens.size() - 2, "."); + } + int n = tokens.size(); + for (size_t i = 0; i < dataset()->transformations_.size(); ++i) { + if (dataset()->transformations_[i] != tokens[n - 2 - i]) { + return errors::InvalidArgument( + "Asserted ", dataset()->transformations_[i], + " transformation at offset ", i, " but encountered ", + tokens[n - 2 - i], " transformation instead."); + } + } + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + return Status::OK(); + } + + private: + std::unique_ptr input_impl_; + }; + + const DatasetBase* input_; + const std::vector transformations_; + const DataTypeVector output_types_; + const std::vector output_shapes_; + }; + + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU), + AssertNextDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index 4657807785d58727d34f37172bd30c56a5b7cde6..d242cfdf4911ee43051b8aa2f7b960916b40374a 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -18,7 +18,10 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/io/inputstream_interface.h" #include "tensorflow/core/lib/io/random_inputstream.h" +#include "tensorflow/core/lib/io/zlib_compression_options.h" +#include "tensorflow/core/lib/io/zlib_inputstream.h" namespace tensorflow { namespace { @@ -37,6 +40,10 @@ class CSVDatasetOp : public DatasetOpKernel { ctx, filenames_tensor->dims() <= 1, errors::InvalidArgument("`filenames` must be a scalar or a vector.")); + string compression_type; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "compression_type", + &compression_type)); + OpInputList record_defaults_list; OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults_list)); @@ -86,6 +93,19 @@ class CSVDatasetOp : public DatasetOpKernel { filenames.push_back(filenames_tensor->flat()(i)); } + io::ZlibCompressionOptions zlib_compression_options = + io::ZlibCompressionOptions::DEFAULT(); + if (compression_type == "ZLIB") { + zlib_compression_options = io::ZlibCompressionOptions::DEFAULT(); + } else if (compression_type == "GZIP") { + zlib_compression_options = io::ZlibCompressionOptions::GZIP(); + } else { + OP_REQUIRES(ctx, compression_type.empty(), + errors::InvalidArgument( + "Unsupported compression_type: ", compression_type, ".")); + } + zlib_compression_options.input_buffer_size = buffer_size; + std::vector select_cols; select_cols.reserve(select_cols_tensor->NumElements()); for (int i = 0; i < select_cols_tensor->NumElements(); ++i) { @@ -103,31 +123,35 @@ class CSVDatasetOp : public DatasetOpKernel { ctx, select_cols.empty() || select_cols.front() >= 0, errors::InvalidArgument("select_cols should be non-negative indices")); - *output = new Dataset(ctx, std::move(filenames), header, buffer_size, + *output = new Dataset(ctx, std::move(filenames), header, + std::move(compression_type), zlib_compression_options, output_types_, output_shapes_, std::move(record_defaults), std::move(select_cols), use_quote_delim, delim[0], std::move(na_value)); } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, std::vector filenames, bool header, - int64 buffer_size, const DataTypeVector& output_types, + string compression_type, io::ZlibCompressionOptions options, + const DataTypeVector& output_types, const std::vector& output_shapes, std::vector record_defaults, std::vector select_cols, bool use_quote_delim, char delim, string na_value) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), filenames_(std::move(filenames)), header_(header), - buffer_size_(buffer_size), out_type_(output_types), output_shapes_(output_shapes), record_defaults_(std::move(record_defaults)), select_cols_(std::move(select_cols)), use_quote_delim_(use_quote_delim), delim_(delim), - na_value_(std::move(na_value)) {} + na_value_(std::move(na_value)), + use_compression_(!compression_type.empty()), + compression_type_(std::move(compression_type)), + options_(options) {} std::unique_ptr MakeIteratorInternal( const string& prefix) const override { @@ -144,12 +168,48 @@ class CSVDatasetOp : public DatasetOpKernel { string DebugString() const override { return "CSVDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { - // TODO(rachelim): Implement this - std::vector input_tensors; - TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output)); - return errors::Unimplemented("CSVDataset: AsGraphDefInternal"); + Node* filenames = nullptr; + Node* compression_type = nullptr; + Node* buffer_size = nullptr; + Node* header = nullptr; + Node* delim = nullptr; + Node* use_quote_delim = nullptr; + Node* na_value = nullptr; + Node* select_cols = nullptr; + + std::vector record_defaults; + record_defaults.reserve(record_defaults_.size()); + for (const Tensor& t : record_defaults_) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + record_defaults.emplace_back(node); + } + + TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); + TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type)); + TF_RETURN_IF_ERROR( + b->AddScalar(options_.input_buffer_size, &buffer_size)); + TF_RETURN_IF_ERROR(b->AddScalar(header_, &header)); + + string delim_string(1, delim_); + TF_RETURN_IF_ERROR(b->AddScalar(delim_string, &delim)); + TF_RETURN_IF_ERROR(b->AddScalar(use_quote_delim_, &use_quote_delim)); + TF_RETURN_IF_ERROR(b->AddScalar(na_value_, &na_value)); + TF_RETURN_IF_ERROR(b->AddVector(select_cols_, &select_cols)); + + TF_RETURN_IF_ERROR(b->AddDataset( + this, + {std::make_pair(0, filenames), std::make_pair(1, compression_type), + std::make_pair(2, buffer_size), std::make_pair(3, header), + std::make_pair(4, delim), std::make_pair(5, use_quote_delim), + std::make_pair(6, na_value), + std::make_pair(7, select_cols)}, // Single tensor inputs + {std::make_pair(8, record_defaults)}, // Tensor list inputs + {}, output)); + return Status::OK(); } private: @@ -201,14 +261,58 @@ class CSVDatasetOp : public DatasetOpKernel { protected: Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); - // TODO(rachelim): Implement save - return errors::Unimplemented("CSVDataset: SaveInternal"); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"), + current_file_index_)); + // `input_stream_` is empty if + // 1. GetNext has not been called even once. + // 2. All files have been read and the iterator has been exhausted. + if (input_stream_ && num_buffer_reads_ > 0) { + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("pos"), pos_)); + // If num_buffer_reads_ == 0, the buffer hasn't been filled even once. + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_buffer_reads"), + num_buffer_reads_)); + } + return Status::OK(); } + Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { mutex_lock l(mu_); - // TODO(rachelim): Implement restore - return errors::Unimplemented("CSVDataset: RestoreInternal"); + ResetStreamsLocked(); + int64 current_file_index; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"), + ¤t_file_index)); + current_file_index_ = size_t(current_file_index); + // The keys "pos" and "num_buffer_reads" are written only if + // the iterator was saved with an open, partially read file. + if (reader->Contains(full_name("pos"))) { + int64 pos, num_buffer_reads; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("pos"), &pos)); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_buffer_reads"), + &num_buffer_reads)); + + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + + num_buffer_reads_ = size_t(num_buffer_reads - 1); + + // Restores the most recently held buffer + Status s = input_stream_->SkipNBytes( + num_buffer_reads_ * dataset()->options_.input_buffer_size); + if (!s.ok() && !errors::IsOutOfRange(s)) { + // We might get out of range error here if the size of the file + // is not an exact multiple of the buffer size, and the last buffer + // read is < buffer_size. This is valid and we do not surface the + // error. + return s; + } + + Status s2 = FillBuffer(&buffer_); + if (!s2.ok() && !errors::IsOutOfRange(s2)) { + return s2; + } + pos_ = size_t(pos); + } + return Status::OK(); } private: @@ -510,7 +614,9 @@ class CSVDatasetOp : public DatasetOpKernel { Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) { result->clear(); - Status s = input_stream_->ReadNBytes(dataset()->buffer_size_, result); + ++num_buffer_reads_; + Status s = input_stream_->ReadNBytes( + dataset()->options_.input_buffer_size, result); if (errors::IsOutOfRange(s) && !result->empty()) { // Ignore OutOfRange error when ReadNBytes read < N bytes. @@ -675,10 +781,20 @@ class CSVDatasetOp : public DatasetOpKernel { // Actually move on to next file. TF_RETURN_IF_ERROR(env->NewRandomAccessFile( dataset()->filenames_[current_file_index_], &file_)); - input_stream_.reset( - new io::RandomAccessInputStream(file_.get(), false)); + random_access_input_stream_ = + std::make_shared(file_.get(), false); + + if (dataset()->use_compression_) { + input_stream_ = std::make_shared( + random_access_input_stream_.get(), + dataset()->options_.input_buffer_size, + dataset()->options_.input_buffer_size, dataset()->options_); + } else { + input_stream_ = random_access_input_stream_; + } buffer_.clear(); pos_ = 0; + num_buffer_reads_ = 0; if (dataset()->header_) { // Read one line, but don't include it. Pass nullptrs as dummy // pointers to objects that shouldn't be invoked anyway @@ -704,8 +820,10 @@ class CSVDatasetOp : public DatasetOpKernel { string buffer_ GUARDED_BY(mu_); // Maintain our own buffer size_t pos_ GUARDED_BY( mu_); // Index into the buffer must be maintained between iters - std::unique_ptr input_stream_ + size_t num_buffer_reads_ GUARDED_BY(mu_); + std::shared_ptr random_access_input_stream_ GUARDED_BY(mu_); + std::shared_ptr input_stream_ GUARDED_BY(mu_); size_t current_file_index_ GUARDED_BY(mu_) = 0; std::unique_ptr file_ GUARDED_BY(mu_); // must outlive input_stream_ @@ -713,7 +831,6 @@ class CSVDatasetOp : public DatasetOpKernel { const std::vector filenames_; const bool header_; - const int64 buffer_size_; const DataTypeVector out_type_; const std::vector output_shapes_; const std::vector record_defaults_; @@ -721,6 +838,9 @@ class CSVDatasetOp : public DatasetOpKernel { const bool use_quote_delim_; const char delim_; const string na_value_; + const bool use_compression_; + const string compression_type_; + const io::ZlibCompressionOptions options_; }; // class Dataset DataTypeVector output_types_; diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc index 6a12ca06f4d6cc2096aaf8191a01a899881b43db..ccf7ec1f842f5a1ad9b304c904f046ad49ed1757 100644 --- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc @@ -63,11 +63,11 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* selector_input, std::vector data_inputs) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), selector_input_(selector_input), data_inputs_(std::move(data_inputs)) { selector_input_->Ref(); @@ -110,15 +110,16 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* selector_input_node; TF_RETURN_IF_ERROR( - b->AddParentDataset(ctx, selector_input_, &selector_input_node)); + b->AddInputDataset(ctx, selector_input_, &selector_input_node)); std::vector data_input_nodes(data_inputs_.size()); for (size_t i = 0; i < data_inputs_.size(); ++i) { TF_RETURN_IF_ERROR( - b->AddParentDataset(ctx, data_inputs_[i], &data_input_nodes[i])); + b->AddInputDataset(ctx, data_inputs_[i], &data_input_nodes[i])); } TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}}, {{1, data_input_nodes}}, {}, output)); @@ -204,7 +205,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (selector_input_impl_) { - TF_RETURN_IF_ERROR(SaveParent(writer, selector_input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(writer, selector_input_impl_)); } else { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("selector_input_impl_empty"), "")); @@ -212,7 +213,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { for (size_t i = 0; i < data_input_impls_.size(); ++i) { const auto& data_input_impl = data_input_impls_[i]; if (data_input_impl) { - TF_RETURN_IF_ERROR(SaveParent(writer, data_input_impl)); + TF_RETURN_IF_ERROR(SaveInput(writer, data_input_impl)); } else { TF_RETURN_IF_ERROR(writer->WriteScalar( full_name(strings::StrCat("data_input_impl_empty[", i, "]")), @@ -226,15 +227,14 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { IteratorStateReader* reader) override { mutex_lock l(mu_); if (!reader->Contains(full_name("selector_input_impl_empty"))) { - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, selector_input_impl_)); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_)); } else { selector_input_impl_.reset(); } for (size_t i = 0; i < data_input_impls_.size(); ++i) { if (!reader->Contains(full_name( strings::StrCat("data_input_impl_empty[", i, "]")))) { - TF_RETURN_IF_ERROR( - RestoreParent(ctx, reader, data_input_impls_[i])); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i])); } else { data_input_impls_[i].reset(); } diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc index bbec50681c6f5decec5a3b5fbf09cc3011a21199..db24e608463224f05159b57eb721718afd7cbb20 100644 --- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc @@ -35,10 +35,10 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, const DatasetBase* input) - : GraphDatasetBase(ctx), input_(input) { + : DatasetBase(DatasetContext(ctx)), input_(input) { input_->Ref(); } @@ -62,10 +62,11 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); return Status::OK(); } @@ -106,7 +107,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (input_impl_) - TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); else TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_impls_empty"), "")); @@ -119,7 +120,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { if (reader->Contains(full_name("input_impls_empty"))) input_impl_.reset(); else - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); return Status::OK(); } diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index a2bfce03620a1482f5b21cbf23c66833bc5cd480..74df1e42a8fbca9b6a65aa4800424d27aa90de24 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_op_kernel.h" @@ -23,6 +24,7 @@ limitations under the License. #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { +namespace { struct BufferElement { // The producer sets `status` if getting the input element fails. @@ -40,7 +42,8 @@ class FunctionBufferingResource : public ResourceBase { const NameAttrList& func, int64 buffer_size, const string& source_device, const string& target_device, - const std::vector& func_args) + const std::vector& func_args, + const DataTypeVector& output_types) : lib_(lib), pflr_(std::move(pflr)), func_(func), @@ -48,6 +51,7 @@ class FunctionBufferingResource : public ResourceBase { source_device_(source_device), target_device_(target_device), func_args_(func_args), + output_types_(output_types), handle_(kInvalidHandle), is_buffering_(false), end_of_sequence_(false), @@ -176,6 +180,13 @@ class FunctionBufferingResource : public ResourceBase { AllocatorAttributes arg_alloc_attr; arg_alloc_attr.set_on_host(true); opts.args_alloc_attrs.push_back(arg_alloc_attr); + for (const auto& dtype : output_types_) { + AllocatorAttributes ret_alloc_attrs; + if (DataTypeAlwaysOnHost(dtype)) { + ret_alloc_attrs.set_on_host(true); + } + opts.rets_alloc_attrs.push_back(ret_alloc_attrs); + } if (opts.source_device != target_device_) { opts.remote_execution = true; } @@ -233,6 +244,7 @@ class FunctionBufferingResource : public ResourceBase { const string source_device_; const string target_device_; const std::vector func_args_; + const DataTypeVector output_types_; FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_); std::deque buffer_ GUARDED_BY(mu_); std::deque requests_ GUARDED_BY(mu_); @@ -250,6 +262,7 @@ class FunctionBufferResourceHandleOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); } ~FunctionBufferResourceHandleOp() override { @@ -269,18 +282,20 @@ class FunctionBufferResourceHandleOp : public OpKernel { std::vector func_args; func_args.push_back(*string_arg); + const string& source_device = ctx->device()->name(); + // Obtain and canonicalize target_device. const Tensor* target_arg; OP_REQUIRES_OK(ctx, ctx->input("target_device", &target_arg)); - const string& target_device = - DeviceNameUtils::CanonicalizeDeviceName(target_arg->scalar()()); + string target_device; + OP_REQUIRES_OK(ctx, DeviceNameUtils::CanonicalizeDeviceName( + target_arg->scalar()(), source_device, + &target_device)); FunctionLibraryRuntime* lib = ctx->function_library(); OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library is provided.")); - const string& source_device = ctx->device()->name(); - mutex_lock l(mu_); if (!initialized_) { OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def())); @@ -297,7 +312,7 @@ class FunctionBufferResourceHandleOp : public OpKernel { this](FunctionBufferingResource** ptr) { *ptr = new FunctionBufferingResource( clone_lib, std::move(pflr), func_, buffer_size_, - source_device, target_device, func_args); + source_device, target_device, func_args, output_types_); return Status::OK(); })); core::ScopedUnref s(buffer); @@ -319,6 +334,7 @@ class FunctionBufferResourceHandleOp : public OpKernel { int64 buffer_size_; string container_; string name_; + DataTypeVector output_types_; }; REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource") @@ -459,4 +475,635 @@ class IteratorGetDeviceOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("IteratorGetDevice").Device(DEVICE_CPU), IteratorGetDeviceOp); +Status VerifyTypesMatch(const DataTypeVector& expected, + const DataTypeVector& received) { + if (expected.size() != received.size()) { + return errors::InvalidArgument( + "Number of components does not match: expected ", expected.size(), + " types but got ", received.size(), "."); + } + for (size_t i = 0; i < expected.size(); ++i) { + if (expected[i] != received[i]) { + return errors::InvalidArgument("Data type mismatch at component ", i, + ": expected ", DataTypeString(expected[i]), + " but got ", DataTypeString(received[i]), + "."); + } + } + return Status::OK(); +} + +Status VerifyShapesCompatible(const std::vector& expected, + const std::vector& received) { + if (expected.size() != received.size()) { + return errors::InvalidArgument( + "Number of components does not match: expected ", expected.size(), + " shapes but got ", received.size(), "."); + } + for (size_t i = 0; i < expected.size(); ++i) { + if (!expected[i].IsCompatibleWith(received[i])) { + return errors::InvalidArgument("Incompatible shapes at component ", i, + ": expected ", expected[i].DebugString(), + " but got ", received[i].DebugString(), + "."); + } + } + + return Status::OK(); +} + +string SanitizeThreadSuffix(string suffix) { + string clean; + for (int i = 0; i < suffix.size(); ++i) { + const char ch = suffix[i]; + if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || + (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') { + clean += ch; + } else { + clean += '_'; + } + } + return clean; +} + +struct HostBufferElement { + Status status; + bool end_of_sequence; + std::vector value; +}; + +using MultiDeviceIteratorCallback = + std::function; + +class MultiDeviceIterator : public ResourceBase { + public: + MultiDeviceIterator(const DataTypeVector& output_types, + const std::vector& output_shapes, + const std::vector& devices, + std::unique_ptr flib_def, + std::unique_ptr pflr, + FunctionLibraryRuntime* lib) + : output_types_(output_types), + output_shapes_(output_shapes), + devices_(devices), + flib_def_(std::move(flib_def)), + pflr_(std::move(pflr)), + lib_(lib) {} + + string DebugString() override { + return strings::StrCat("MultiDeviceIterator for ", devices_.size(), + " devices"); + } + + Status Init(std::unique_ptr iterator, int64 max_buffer_size, + int64* incarnation_id) { + if (iterator) { + TF_RETURN_IF_ERROR( + VerifyTypesMatch(output_types_, iterator->output_dtypes())); + TF_RETURN_IF_ERROR( + VerifyShapesCompatible(output_shapes_, iterator->output_shapes())); + } + + mutex_lock l(mu_); + if (multi_device_buffer_) { + multi_device_buffer_->Reset(); + } + + ++incarnation_id_; + *incarnation_id = incarnation_id_; + + multi_device_buffer_.reset( + new MultiDeviceBuffer(devices_.size(), max_buffer_size, incarnation_id_, + std::move(iterator))); + return Status::OK(); + } + + void GetNextFromShard(IteratorContext* ctx, int shard_num, + int64 incarnation_id, + MultiDeviceIteratorCallback callback) { + if (lib_ != nullptr) { + ctx->set_lib(lib_); + } + tf_shared_lock l(mu_); + multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id, + std::move(callback)); + } + + const DataTypeVector& output_types() const { return output_types_; } + + const std::vector& output_shapes() const { + return output_shapes_; + } + + std::shared_ptr function_library() { + tf_shared_lock l(mu_); + return lib_def_; + } + + private: + // A private class that uses a background thread to keep a per device buffer + // full. + class MultiDeviceBuffer { + public: + MultiDeviceBuffer(size_t size, int64 max_buffer_size, int64 incarnation_id, + std::unique_ptr host_iterator) + : buffer_(size), + size_(size), + max_buffer_size_(max_buffer_size), + incarnation_id_(incarnation_id), + host_iterator_(std::move(host_iterator)) {} + + ~MultiDeviceBuffer() { Reset(); } + + void Reset() LOCKS_EXCLUDED(mu_) { + { + mutex_lock l(mu_); + if (background_thread_finished_) { + return; + } + + cancelled_ = true; + // Wake up the background thread. + for (int i = 0; i < size_; ++i) { + buffer_[i].cond_var.notify_all(); + } + + // Make sure background thread has finished first. + while (!background_thread_finished_) { + shutdown_cond_var_.wait(l); + } + } + RunPendingCallbacks(); + } + + void GetNextFromShard(IteratorContext* ctx, int shard_num, + int64 incarnation_id, + MultiDeviceIteratorCallback callback) { + HostBufferElement elem; + if (incarnation_id_ != incarnation_id) { + elem.status = errors::InvalidArgument("Invalid incarnation id"); + callback(elem); + return; + } + + bool produced_output = false; + { + mutex_lock l(mu_); + if (cancelled_) { + elem.status = errors::Cancelled("Cancelled Multidevice iterator"); + callback(elem); + return; + } + + EnsureBackgroundThreadStarted(ctx); + + if (!buffer_[shard_num].data.empty()) { + produced_output = true; + std::swap(elem, buffer_[shard_num].data.front()); + buffer_[shard_num].data.pop_front(); + // Wake up background thread if it is blocked on this element. + if (buffer_[shard_num].data.size() == max_buffer_size_ - 1) { + buffer_[shard_num].cond_var.notify_all(); + } + } else { + if (background_thread_finished_) { + produced_output = true; + elem.end_of_sequence = true; + } else { + buffer_[shard_num].callbacks.push_back(std::move(callback)); + callback = nullptr; + } + } + } + + if (produced_output) { + callback(elem); + } + } + + private: + void EnsureBackgroundThreadStarted(IteratorContext* ctx) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!background_thread_) { + background_thread_.reset(ctx->env()->StartThread( + {}, "multi_device_iterator_background_thread", + std::bind(&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread, + this, new IteratorContext(*ctx)))); + } + } + + void RunPendingCallbacks() LOCKS_EXCLUDED(mu_) { + // Run all remaining callbacks. + std::vector cancellation_callbacks; + std::vector cancellation_elements; + { + mutex_lock l(mu_); + + for (int i = 0; i < size_; ++i) { + while (!buffer_[i].callbacks.empty()) { + if (buffer_[i].data.empty()) { + HostBufferElement elem; + elem.status = + errors::Cancelled("Cancelled and buffer not filled."); + cancellation_elements.push_back(std::move(elem)); + } else { + cancellation_elements.push_back( + std::move(buffer_[i].data.front())); + buffer_[i].data.pop_front(); + } + cancellation_callbacks.push_back( + std::move(buffer_[i].callbacks.front())); + buffer_[i].callbacks.pop_front(); + } + } + } + for (int i = 0; i < cancellation_callbacks.size(); ++i) { + cancellation_callbacks[i](cancellation_elements[i]); + } + } + + void BackgroundThread(IteratorContext* ctx) { + std::unique_ptr cleanup(ctx); + int shard_to_fetch = 0; + while (true) { + HostBufferElement elem; + MultiDeviceIteratorCallback callback = nullptr; + bool end_of_iterator = false; + + { + mutex_lock l(mu_); + while (!cancelled_ && + buffer_[shard_to_fetch].data.size() >= max_buffer_size_) { + buffer_[shard_to_fetch].cond_var.wait(l); + } + + if (cancelled_) { + background_thread_finished_ = true; + shutdown_cond_var_.notify_all(); + return; + } + } + + elem.status = + host_iterator_->GetNext(ctx, &elem.value, &elem.end_of_sequence); + + if (elem.status.ok() && elem.end_of_sequence) { + end_of_iterator = true; + } + + { + mutex_lock l(mu_); + // Try to find a callback, else just push stuff into buffer. + if (!buffer_[shard_to_fetch].callbacks.empty()) { + callback = buffer_[shard_to_fetch].callbacks.front(); + buffer_[shard_to_fetch].callbacks.pop_front(); + } else { + buffer_[shard_to_fetch].data.push_back(std::move(elem)); + elem = HostBufferElement(); + } + } + + if (callback) { + (*ctx->runner())(std::bind(std::move(callback), std::move(elem))); + } + + // Finish off the thread if we reach the end of the iterator. Runs + // pending callbacks. + if (end_of_iterator) { + { + mutex_lock l(mu_); + background_thread_finished_ = true; + shutdown_cond_var_.notify_all(); + } + RunPendingCallbacks(); + return; + } + shard_to_fetch = (shard_to_fetch + 1) % size_; + } + } + + struct HostBuffer { + condition_variable cond_var; + std::deque data; + std::deque callbacks; + }; + + mutex mu_; + std::unique_ptr background_thread_ GUARDED_BY(mu_); + bool background_thread_finished_ GUARDED_BY(mu_) = false; + bool cancelled_ GUARDED_BY(mu_) = false; + condition_variable shutdown_cond_var_ GUARDED_BY(mu_); + + std::vector buffer_; + + const size_t size_; + const int64 max_buffer_size_; + const int64 incarnation_id_; + const std::unique_ptr host_iterator_; + }; + + mutex mu_; + const DataTypeVector output_types_; + const std::vector output_shapes_; + const std::vector devices_; + const std::unique_ptr flib_def_; + const std::unique_ptr pflr_; + FunctionLibraryRuntime* const lib_ = nullptr; // not owned. + std::shared_ptr lib_def_ GUARDED_BY(mu_); + + int64 incarnation_id_ GUARDED_BY(mu_) = 0; + std::unique_ptr multi_device_buffer_ GUARDED_BY(mu_); +}; + +// Just creates a MultiDeviceIterator and returns it. +class MultiDeviceIteratorHandleOp : public OpKernel { + public: + explicit MultiDeviceIteratorHandleOp(OpKernelConstruction* ctx) + : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("devices", &devices_)); + } + + // The resource is deleted from the resource manager only when it is private + // to kernel. + ~MultiDeviceIteratorHandleOp() override { + if (resource_ != nullptr) { + resource_->Unref(); + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->template Delete(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + } + + void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) { + { + mutex_lock l(mu_); + if (resource_ == nullptr) { + FunctionLibraryRuntime* lib; + std::unique_ptr flib_def(nullptr); + std::unique_ptr pflr(nullptr); + OP_REQUIRES_OK(context, context->function_library()->Clone( + &flib_def, &pflr, &lib)); + ResourceMgr* mgr = context->resource_manager(); + OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); + + MultiDeviceIterator* resource; + OP_REQUIRES_OK( + context, + mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this, lib, &flib_def, &pflr](MultiDeviceIterator** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + *ret = new MultiDeviceIterator( + output_types_, output_shapes_, devices_, + std::move(flib_def), std::move(pflr), lib); + return Status::OK(); + })); + + Status s = VerifyResource(resource); + if (TF_PREDICT_FALSE(!s.ok())) { + resource->Unref(); + context->SetStatus(s); + return; + } + + resource_ = resource; + } + } + OP_REQUIRES_OK(context, MakeResourceHandleToOutput( + context, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } + + private: + // During the first Compute(), resource is either created or looked up using + // shared_name. In the latter case, the resource found should be verified if + // it is compatible with this op's configuration. The verification may fail in + // cases such as two graphs asking queues of the same shared name to have + // inconsistent capacities. + Status VerifyResource(MultiDeviceIterator* resource) { + TF_RETURN_IF_ERROR( + VerifyTypesMatch(output_types_, resource->output_types())); + TF_RETURN_IF_ERROR( + VerifyShapesCompatible(output_shapes_, resource->output_shapes())); + return Status::OK(); + } + + mutex mu_; + ContainerInfo cinfo_; // Written once under mu_ then constant afterwards. + MultiDeviceIterator* resource_ GUARDED_BY(mu_) = nullptr; + DataTypeVector output_types_; + std::vector output_shapes_; + const int graph_def_version_; + string name_; + string container_; + std::vector devices_; +}; + +REGISTER_KERNEL_BUILDER(Name("MultiDeviceIterator").Device(DEVICE_CPU), + MultiDeviceIteratorHandleOp); + +// Calls init on the MultiDeviceIterator. +class MultiDeviceIteratorInitOp : public OpKernel { + public: + explicit MultiDeviceIteratorInitOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor* tensor_max_buffer_size; + OP_REQUIRES_OK(ctx, ctx->input("max_buffer_size", &tensor_max_buffer_size)); + int64 max_buffer_size = tensor_max_buffer_size->scalar()(); + + DatasetBase* dataset; + OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset)); + MultiDeviceIterator* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 1), &resource)); + core::ScopedUnref unref(resource); + + std::unique_ptr iterator; + OP_REQUIRES_OK(ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator", + &iterator)); + int64 incarnation_id; + OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size, + &incarnation_id)); + Tensor tensor_incarnation_id(DT_INT64, TensorShape({})); + tensor_incarnation_id.scalar()() = incarnation_id; + OP_REQUIRES_OK(ctx, + ctx->set_output("incarnation_id", tensor_incarnation_id)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("MultiDeviceIteratorInit").Device(DEVICE_CPU), + MultiDeviceIteratorInitOp); + +// Calls GetNextFromShard(shard) and returns a vector of Tensors as output. +// TODO(rohanj): Implement using BackgroundWorker that Derek built? +class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel { + public: + explicit MultiDeviceIteratorGetNextFromShardOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), + thread_pool_(new thread::ThreadPool( + ctx->env(), ThreadOptions(), + strings::StrCat("multi_device_iterator_get_next_thread_", + SanitizeThreadSuffix(name())), + 1 /* num_threads */, false /* low_latency_hint */)) {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + const Tensor* tensor_shard_num; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input("shard_num", &tensor_shard_num), done); + int32 shard_num = tensor_shard_num->scalar()(); + + const Tensor* tensor_incarnation_id; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done); + int64 incarnation_id = tensor_incarnation_id->scalar()(); + + MultiDeviceIterator* iterator; + OP_REQUIRES_OK_ASYNC( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done); + thread_pool_->Schedule(std::bind( + [ctx, iterator, shard_num, incarnation_id](DoneCallback done) { + IteratorContext::Params params; + params.env = ctx->env(); + params.runner = *(ctx->runner()); + params.function_library = iterator->function_library(); + DeviceBase* device = ctx->function_library()->device(); + params.allocator_getter = [device](AllocatorAttributes attrs) { + return device->GetAllocator(attrs); + }; + IteratorContext iter_ctx(std::move(params)); + + MultiDeviceIteratorCallback callback = std::bind( + [ctx](const HostBufferElement& elem, DoneCallback done) { + // iterator->Unref(); + Status s = elem.status; + if (!s.ok()) { + ctx->SetStatus(s); + } else if (elem.end_of_sequence) { + ctx->SetStatus(errors::OutOfRange("End of sequence")); + } else { + for (int i = 0; i < elem.value.size(); ++i) { + ctx->set_output(i, elem.value[i]); + } + } + done(); + }, + std::placeholders::_1, std::move(done)); + + iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id, + callback); + iterator->Unref(); + }, + std::move(done))); + } + + private: + std::unique_ptr thread_pool_; +}; + +REGISTER_KERNEL_BUILDER( + Name("MultiDeviceIteratorGetNextFromShard").Device(DEVICE_CPU), + MultiDeviceIteratorGetNextFromShardOp); + +class MultiDeviceIteratorToStringHandleOp : public OpKernel { + public: + explicit MultiDeviceIteratorToStringHandleOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& resource_handle_t = ctx->input(0); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), + errors::InvalidArgument("resource_handle must be a scalar")); + + // Validate that the handle corresponds to a real resource, and + // that it is an MultiDeviceIterator. + MultiDeviceIterator* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + resource->Unref(); + + Tensor* string_handle_t; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({}), &string_handle_t)); + string_handle_t->scalar()() = + resource_handle_t.scalar()().SerializeAsString(); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("MultiDeviceIteratorToStringHandle").Device(DEVICE_CPU), + MultiDeviceIteratorToStringHandleOp); + +class MultiDeviceIteratorFromStringHandleOp : public OpKernel { + public: + explicit MultiDeviceIteratorFromStringHandleOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES( + ctx, + output_types_.empty() || output_shapes_.empty() || + output_types_.size() == output_shapes_.size(), + errors::InvalidArgument("If both 'output_types' and 'output_shapes' " + "are set, they must have the same length.")); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& string_handle_t = ctx->input(0); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()), + errors::InvalidArgument("string_handle must be a scalar")); + + ResourceHandle resource_handle; + OP_REQUIRES( + ctx, + resource_handle.ParseFromString(string_handle_t.scalar()()), + errors::InvalidArgument( + "Could not parse string_handle as a valid ResourceHandle")); + + OP_REQUIRES( + ctx, resource_handle.device() == ctx->device()->attributes().name(), + errors::InvalidArgument("Attempted create an iterator on device \"", + ctx->device()->attributes().name(), + "\" from handle defined on device \"", + resource_handle.device(), "\"")); + + // Validate that the handle corresponds to a real resource, and + // that it is an MultiDeviceIterator. + MultiDeviceIterator* resource; + OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource)); + core::ScopedUnref unref_iterator(resource); + if (!output_types_.empty()) { + OP_REQUIRES_OK(ctx, + VerifyTypesMatch(output_types_, resource->output_types())); + } + if (!output_shapes_.empty()) { + OP_REQUIRES_OK(ctx, VerifyShapesCompatible(output_shapes_, + resource->output_shapes())); + } + + Tensor* resource_handle_t; + OP_REQUIRES_OK( + ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t)); + resource_handle_t->scalar()() = resource_handle; + } + + private: + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +REGISTER_KERNEL_BUILDER( + Name("MultiDeviceIteratorFromStringHandle").Device(DEVICE_CPU), + MultiDeviceIteratorFromStringHandleOp); + +} // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc index 3dfc3741c2b040dd5be3223c24f0715ba3be4248..ab584504a05369105d080df73750974af9fc70bb 100644 --- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/util/work_sharder.h" namespace tensorflow { namespace { @@ -24,19 +25,32 @@ namespace { class ThreadPoolResource : public ResourceBase { public: ThreadPoolResource(Env* env, const ThreadOptions& thread_options, - const string& name, int num_threads, bool low_latency_hint) - : thread_pool_(env, thread_options, name, num_threads, low_latency_hint) { - } + const string& name, int num_threads, bool low_latency_hint, + int max_intra_op_parallelism) + : thread_pool_(env, thread_options, name, num_threads, low_latency_hint), + max_intra_op_parallelism_(max_intra_op_parallelism) {} // Schedules fn() for execution in the pool of threads. void Schedule(std::function fn) { - thread_pool_.Schedule(std::move(fn)); + if (max_intra_op_parallelism_ < 0) { + thread_pool_.Schedule(std::move(fn)); + } else { + thread_pool_.Schedule(std::bind( + [this](std::function bound_fn) { + // TODO(mrry): Consider moving this thread-local configuration to + // the threads themselves. + ScopedPerThreadMaxParallelism scope(max_intra_op_parallelism_); + bound_fn(); + }, + std::move(fn))); + } } string DebugString() override { return "ThreadPoolResource"; } private: thread::ThreadPool thread_pool_; + const int max_intra_op_parallelism_; }; // Creates a handle to a ThreadPool resource. Note that we don't use @@ -48,6 +62,8 @@ class ThreadPoolHandleOp : public OpKernel { explicit ThreadPoolHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("display_name", &display_name_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism", + &max_intra_op_parallelism_)); OP_REQUIRES( ctx, num_threads_ > 0, errors::InvalidArgument("`num_threads` must be greater than zero.")); @@ -78,7 +94,7 @@ class ThreadPoolHandleOp : public OpKernel { EXCLUSIVE_LOCKS_REQUIRED(mu_) { *ret = new ThreadPoolResource( ctx->env(), {}, display_name_, - num_threads_, + num_threads_, max_intra_op_parallelism_, false /* low_latency_hint */); return Status::OK(); })); @@ -95,6 +111,7 @@ class ThreadPoolHandleOp : public OpKernel { bool initialized_ GUARDED_BY(mu_) = false; string display_name_; int num_threads_; + int max_intra_op_parallelism_; }; class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { @@ -113,11 +130,13 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, ThreadPoolResource* threadpool) - : GraphDatasetBase(ctx), input_(input), threadpool_(threadpool) { + : DatasetBase(DatasetContext(ctx)), + input_(input), + threadpool_(threadpool) { input_->Ref(); threadpool_->Ref(); } @@ -145,11 +164,11 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { - return errors::Unimplemented( - "Cannot currently serialize the thread pool for a " - "ThreadPoolDataset."); + return errors::Unimplemented("%s does not support serialization", + DebugString()); } private: diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc index 67c237799c10a2724f18bb0df99e4bf8f5cd2b8a..6fbf5d2ebb598132a7e8433608e67436a172b615 100644 --- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc @@ -47,10 +47,10 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input) - : GraphDatasetBase(ctx), input_(input) { + : DatasetBase(DatasetContext(ctx)), input_(input) { input_->Ref(); } @@ -75,10 +75,11 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); return Status::OK(); } @@ -116,7 +117,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); if (input_impl_) { - TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); } else { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_impl_empty"), "")); @@ -135,7 +136,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { IteratorStateReader* reader) override { mutex_lock l(mu_); if (!reader->Contains(full_name("input_impl_empty"))) { - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); } else { input_impl_.reset(); } diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index f271d269ab1b9339de4657e459dcbbd462890f0a..cc5e250ea15bf89be2db9aba14e3b29b72512a73 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -36,6 +36,7 @@ data_input_datasets: `N` datasets with the same type that will be interleaved REGISTER_OP("CSVDataset") .Input("filenames: string") + .Input("compression_type: string") .Input("buffer_size: int64") .Input("header: bool") .Input("field_delim: string") @@ -52,17 +53,18 @@ REGISTER_OP("CSVDataset") shape_inference::ShapeHandle unused; // `filenames` must be a scalar or a vector. TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); - // `buffer_size`, `header`, `field_delim`, `use_quote_delim`, - // `na_value` must be scalars + // `compression_type`, `buffer_size`, `header`, `field_delim`, + // `use_quote_delim`, `na_value` must be scalars TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // `select_cols` must be a vector - TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &unused)); - // `record_defaults` must be a list of scalars...? - for (size_t i = 7; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused)); + // `record_defaults` must be lists of scalars + for (size_t i = 8; i < c->num_inputs(); ++i) { TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &unused)); } return shape_inference::ScalarShape(c); @@ -104,6 +106,7 @@ REGISTER_OP("FunctionBufferingResource") .Attr("container: string") .Attr("f: func") .Attr("buffer_size: int") + .Attr("output_types: list(type)") .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Creates a resource that fills up a buffer by making function calls. @@ -117,6 +120,7 @@ container: If non-empty, this resource is placed in the given container. Otherwise, a default container is used. shared_name: If non-empty, this resource will be shared under the given name across multiple sessions. +output_types: The type list for the return values. )doc"); REGISTER_OP("FunctionBufferingResourceGetNext") @@ -141,6 +145,82 @@ Resets the FunctionBufferingResource. function_buffer_resource: The FunctionBufferingResource handle. )doc"); +REGISTER_OP("MultiDeviceIterator") + .Output("handle: resource") + .Attr("devices: list(string) >= 1") + .Attr("shared_name: string") + .Attr("container: string") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .Doc(R"doc( +Creates a MultiDeviceIterator resource. + +handle: Handle to the resource created. +devices: A list of devices the iterator works across. +shared_name: If non-empty, this resource will be shared under the given name + across multiple sessions. +container: If non-empty, this resource is placed in the given container. + Otherwise, a default container is used. +output_types: The type list for the return values. +output_shapes: The list of shapes being produced. +)doc"); + +REGISTER_OP("MultiDeviceIteratorInit") + .Input("dataset: variant") + .Input("multi_device_iterator: resource") + .Input("max_buffer_size: int64") + .Output("incarnation_id: int64") + .Doc(R"doc( +Initializes the multi device iterator with the given dataset. +max_buffer_size: The maximum size of the host side per device buffer to keep. +incarnation_id: An int64 indicating which incarnation of the MultiDeviceIterator + is running. +dataset: Dataset to be iterated upon. +multi_device_iterator: A MultiDeviceIteratorResource. +)doc"); + +REGISTER_OP("MultiDeviceIteratorGetNextFromShard") + .Input("multi_device_iterator: resource") + .Input("shard_num: int32") + .Input("incarnation_id: int64") + .Output("components: output_types") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .Doc(R"doc( +Gets next element for the provided shard number. + +multi_device_iterator: A MultiDeviceIterator resource. +shard_num: Integer representing which shard to fetch data for. +incarnation_id: Which incarnation of the MultiDeviceIterator is running. +components: Result of the get_next on the dataset. +output_types: The type list for the return values. +output_shapes: The list of shapes being produced. +)doc"); + +REGISTER_OP("MultiDeviceIteratorToStringHandle") + .Input("multi_device_iterator: resource") + .Output("string_handle: string") + .Doc(R"doc( +Produces a string handle for the given MultiDeviceIterator. + +multi_device_iterator: A MultiDeviceIterator resource. +string_handle: A string representing the resource. +)doc"); + +REGISTER_OP("MultiDeviceIteratorFromStringHandle") + .Input("string_handle: string") + .Output("multi_device_iterator: resource") + .Attr("output_types: list(type) >= 0 = []") + .Attr("output_shapes: list(shape) >= 0 = []") + .Doc(R"doc( +Generates a MultiDeviceIterator resource from its provided string handle. + +string_handle: String representing the resource. +multi_device_iterator: A MultiDeviceIterator resource. +output_types: The type list for the return values. +output_shapes: The list of shapes being produced. +)doc"); + REGISTER_OP("ThreadPoolDataset") .Input("input_dataset: variant") .Input("thread_pool: resource") @@ -158,6 +238,7 @@ REGISTER_OP("ThreadPoolHandle") .Output("handle: resource") .SetShapeFn(shape_inference::ScalarShape) .Attr("num_threads: int") + .Attr("max_intra_op_parallelism: int = 1") .Attr("display_name: string") .Attr("container: string = ''") .Attr("shared_name: string = ''") @@ -166,8 +247,23 @@ Creates a custom thread pool with the given number of threads. handle: A resource that can be consumed by one or more ThreadPoolDataset ops. num_threads: The number of threads in the thread pool. +max_intra_op_parallelism: The maximum degree of parallelism to use within + operations that execute on this threadpool. display_name: A human-readable name for the threads that may be visible in some visualizations. )doc"); +REGISTER_OP("AssertNextDataset") + .Input("input_dataset: variant") + .Input("transformations: string") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // transformations should be a vector. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + return shape_inference::ScalarShape(c); + }); + } // namespace tensorflow diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 0dfd249ec27d96d6f1a4ae65d623df456db9991f..2b75aa2ca54509b42f431db2dd39261cf025588a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -4,7 +4,7 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test") py_test( name = "batch_dataset_op_test", @@ -16,20 +16,23 @@ py_test( "no_pip", ], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:script_ops", + "//tensorflow/python:session", "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -39,7 +42,6 @@ py_test( srcs = ["bucketing_test.py"], srcs_version = "PY2AND3", deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:grouping", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -48,37 +50,33 @@ py_test( "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", - "//third_party/py/numpy", - ], -) - -py_test( - name = "cache_dataset_op_test", - size = "small", - srcs = ["cache_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":dataset_serialization_test", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", ], ) py_test( - name = "concatenate_dataset_op_test", - size = "small", - srcs = ["concatenate_dataset_op_test.py"], + name = "csv_dataset_op_test", + size = "medium", + srcs = ["csv_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ - ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:error_ops", + "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:tensor_shape", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + "//tensorflow/python/data/ops:readers", "//third_party/py/numpy", ], ) @@ -93,104 +91,46 @@ py_test( "nomac", # b/62040583 ], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:session", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", - "//third_party/py/numpy", ], ) -py_library( - name = "dataset_serialization_test", - srcs = [ - "dataset_serialization_test_base.py", - ], +py_test( + name = "directed_interleave_dataset_test", + size = "medium", + srcs = ["directed_interleave_dataset_test.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:interleave_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:iterator_ops", - "//third_party/py/numpy", - ], -) - -py_test( - name = "csv_dataset_op_test", - size = "small", - srcs = ["csv_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:error_ops", - "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/python:random_seed", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) py_test( - name = "filter_dataset_op_test", + name = "get_single_element_test", size = "small", - srcs = ["filter_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - "optonly", - ], + srcs = ["get_single_element_test.py"], deps = [ - ":dataset_serialization_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:functional_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "flat_map_dataset_op_test", - size = "medium", - srcs = ["flat_map_dataset_op_test.py"], - additional_deps = [ - ":dataset_serialization_test", - "//third_party/py/numpy", - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:get_single_element", + "//tensorflow/contrib/data/python/ops:grouping", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:function", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:session", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", ], - grpc_enabled = True, - tags = ["no_pip"], ) py_test( @@ -205,10 +145,8 @@ py_test( "notap", ], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:interleave_ops", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", @@ -216,43 +154,28 @@ py_test( "//tensorflow/python:script_ops", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", + "@six_archive//:six", ], ) py_test( - name = "directed_interleave_dataset_test", - size = "medium", - srcs = ["directed_interleave_dataset_test.py"], + name = "iterator_ops_test", + size = "small", + srcs = ["iterator_ops_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ - ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:interleave_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:training", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "get_single_element_test", - size = "small", - srcs = ["get_single_element_test.py"], - additional_deps = [ - "//third_party/py/numpy", - "//tensorflow/contrib/data/python/ops:get_single_element", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python:array_ops", + "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:estimator_py", ], ) @@ -267,55 +190,82 @@ py_test( "optonly", ], deps = [ - ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/contrib/data/python/ops:error_ops", + "//tensorflow/contrib/data/python/ops:optimization", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", "//tensorflow/python:io_ops", - "//tensorflow/python:lookup_ops", "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:script_ops", - "//tensorflow/python:sparse_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:string_ops", "//tensorflow/python:util", - "//tensorflow/python:variable_scope", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) +py_test( + name = "map_defun_op_test", + size = "small", + srcs = ["map_defun_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/contrib/data/python/ops:map_defun", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + ], +) + py_test( name = "optimize_dataset_op_test", size = "small", srcs = ["optimize_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ - ":dataset_serialization_test", + ":stats_dataset_test_base", "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/python:platform", + "//tensorflow/contrib/data/python/ops:stats_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", ], ) -py_test( - name = "prefetch_dataset_op_test", +cuda_py_test( + name = "prefetching_ops_test", size = "small", - srcs = ["prefetch_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":dataset_serialization_test", - "//tensorflow/python:platform", + srcs = ["prefetching_ops_test.py"], + additional_deps = [ + "//tensorflow/contrib/data/python/ops:prefetching_ops", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:function", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/compat:compat", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + ], + tags = [ + "manual", + "no_oss", + "no_windows_gpu", + "notap", ], ) @@ -325,20 +275,13 @@ py_test( srcs = ["range_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:counter", "//tensorflow/contrib/data/python/ops:enumerate_ops", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", - "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:io_ops", - "//tensorflow/python:parsing_ops", "//tensorflow/python:tensor_shape", - "//tensorflow/python:variables", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -350,15 +293,21 @@ py_library( "reader_dataset_ops_test_base.py", ], srcs_version = "PY2AND3", - visibility = ["//visibility:private"], + visibility = [ + "//tensorflow/contrib/data/python/kernel_tests:__pkg__", + "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__", + ], deps = [ "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", "//tensorflow/python:util", + "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/ops:readers", ], ) @@ -367,24 +316,18 @@ py_test( name = "reader_dataset_ops_test", size = "medium", srcs = ["reader_dataset_ops_test.py"], - shard_count = 4, srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":dataset_serialization_test", ":reader_dataset_ops_test_base", "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", "//tensorflow/python:string_ops", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/ops:readers", "//third_party/py/numpy", ], @@ -412,6 +355,7 @@ py_test( "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", + "@six_archive//:six", ], ) @@ -422,13 +366,14 @@ py_test( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:scan_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/eager:context", "//third_party/py/numpy", @@ -436,57 +381,56 @@ py_test( ) py_test( - name = "sequence_dataset_op_test", + name = "shuffle_dataset_op_test", size = "medium", - srcs = ["sequence_dataset_op_test.py"], + srcs = ["shuffle_dataset_op_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_pip", + "optonly", + ], deps = [ - ":dataset_serialization_test", - "//tensorflow/python:array_ops", + "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) py_test( - name = "serialization_integration_test", + name = "slide_dataset_op_test", size = "small", - srcs = ["serialization_integration_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], + srcs = ["slide_dataset_op_test.py"], deps = [ - "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:sliding", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) -py_test( - name = "shuffle_dataset_op_test", - size = "medium", - srcs = ["shuffle_dataset_op_test.py"], +py_library( + name = "sql_dataset_op_test_base", + srcs = ["sql_dataset_op_test_base.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + visibility = [ + "//tensorflow/contrib/data/python/kernel_tests:__pkg__", + "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__", + ], deps = [ - ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:iterator_ops", - "//tensorflow/contrib/data/python/ops:shuffle_ops", + "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//third_party/py/numpy", + "@org_sqlite//:python", ], ) @@ -495,14 +439,12 @@ py_test( size = "small", srcs = ["sql_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ - ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/python:array_ops", + ":sql_dataset_op_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "@org_sqlite//:python", ], ) @@ -513,10 +455,9 @@ py_test( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":dataset_serialization_test", ":reader_dataset_ops_test_base", + ":stats_dataset_test_base", "//tensorflow/contrib/data/python/ops:stats_ops", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", @@ -526,109 +467,85 @@ py_test( ], ) -py_test( - name = "threadpool_dataset_ops_test", - size = "small", - srcs = ["threadpool_dataset_ops_test.py"], +py_library( + name = "stats_dataset_test_base", + srcs = ["stats_dataset_test_base.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], deps = [ - "//tensorflow/contrib/data/python/ops:threadpool", - "//tensorflow/contrib/data/python/ops:unique", + "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python/data/ops:dataset_ops", ], ) py_test( - name = "unique_dataset_op_test", + name = "threadpool_dataset_ops_test", size = "small", - srcs = ["unique_dataset_op_test.py"], + srcs = ["threadpool_dataset_ops_test.py"], srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:threadpool", "//tensorflow/contrib/data/python/ops:unique", - "//tensorflow/contrib/stateless", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:script_ops", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) py_test( - name = "zip_dataset_op_test", + name = "unique_dataset_op_test", size = "small", - srcs = ["zip_dataset_op_test.py"], + srcs = ["unique_dataset_op_test.py"], srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":dataset_serialization_test", - "//tensorflow/python:array_ops", + "//tensorflow/contrib/data/python/ops:unique", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", ], ) -cuda_py_test( - name = "prefetching_ops_test", - size = "small", - srcs = ["prefetching_ops_test.py"], - additional_deps = [ - "//tensorflow/contrib/data/python/ops:prefetching_ops", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:function", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", +py_test( + name = "window_dataset_op_test", + size = "medium", + srcs = ["window_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", ], -) - -tf_py_test( - name = "slide_dataset_op_test", - size = "small", - srcs = ["slide_dataset_op_test.py"], - additional_deps = [ - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:sliding", + deps = [ + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/contrib/data/python/ops:grouping", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", "//tensorflow/python:math_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) -tf_py_test( +py_test( name = "writer_ops_test", size = "small", srcs = ["writer_ops_test.py"], - additional_deps = [ + deps = [ "//tensorflow/contrib/data/python/ops:writers", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:io_ops", "//tensorflow/python:lib", - "//tensorflow/python:tensor_shape", "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index b5fbc45ad3d8d262c1c79b5723ffeb38ff6a34c2..42adfd17f07e508f25d8b351c791fa519eca8bd9 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -20,9 +20,9 @@ from __future__ import print_function import math import time +from absl.testing import parameterized import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import batching from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops @@ -40,7 +40,7 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -class BatchDatasetTest(test.TestCase): +class BatchDatasetTest(test.TestCase, parameterized.TestCase): def assertSparseValuesEqual(self, a, b): self.assertAllEqual(a.indices, b.indices) @@ -293,7 +293,7 @@ class BatchDatasetTest(test.TestCase): ph2: np.arange(8).astype(np.int32) }) with self.assertRaises(errors.InvalidArgumentError): - print(sess.run(next_element)) + sess.run(next_element) # No 0th dimension (i.e. scalar value) for one component. sess.run( @@ -303,7 +303,7 @@ class BatchDatasetTest(test.TestCase): ph2: 7 }) with self.assertRaises(errors.InvalidArgumentError): - print(sess.run(next_element)) + sess.run(next_element) def testBatchAndDropRemainder(self): components = (np.arange(7), @@ -427,9 +427,13 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def _testMapAndBatchDatasetHelper(self, - num_parallel_calls=None, - num_parallel_batches=None): + @parameterized.named_parameters( + ("default", None, None), + ("sequential_calls", 1, None), + ("parallel_calls", 2, None), + ("parallel_batches", None, 10), + ) + def testMapAndBatch(self, num_parallel_calls, num_parallel_batches): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). @@ -500,19 +504,11 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - def testMapAndBatch(self): - return self._testMapAndBatchDatasetHelper() - - def testMapAndBatchWithParallelBatches(self): - return self._testMapAndBatchDatasetHelper(num_parallel_batches=10) - - def testMapAndBatchWithSequentialCalls(self): - return self._testMapAndBatchDatasetHelper(num_parallel_calls=1) - - def testMapAndBatchWithParallelCalls(self): - return self._testMapAndBatchDatasetHelper(num_parallel_calls=2) - - def _testMapAndBatchPartialBatchHelper(self, drop_remainder=False): + @parameterized.named_parameters( + ("even", False), + ("uneven", True), + ) + def testMapAndBatchPartialBatch(self, drop_remainder): iterator = ( dataset_ops.Dataset.range(10).apply( batching.map_and_batch( @@ -532,12 +528,6 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) - def testMapAndBatchPartialBatch(self): - return self._testMapAndBatchPartialBatchHelper() - - def testMapAndBatchPartialBatchDropRemainder(self): - return self._testMapAndBatchPartialBatchHelper(drop_remainder=True) - def testMapAndBatchYieldsPartialBatch(self): iterator = (dataset_ops.Dataset.range(10) .apply(batching.map_and_batch( @@ -614,7 +604,7 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testMapAndBatchDatasetFails(self): + def testMapAndBatchFails(self): """Test a dataset that maps a TF function across its input elements.""" dataset = dataset_ops.Dataset.from_tensors( array_ops.check_numerics( @@ -628,7 +618,7 @@ class BatchDatasetTest(test.TestCase): with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(init_op, feed_dict={batch_size: 14}) - def testMapAndBatchDatasetShapeMismatch(self): + def testMapAndBatchShapeMismatch(self): """Test a dataset that maps a TF function across its input elements.""" def generator(): @@ -651,173 +641,79 @@ class BatchDatasetTest(test.TestCase): "number of elements does not match"): sess.run(get_next) + def testMapAndBatchImplicitDispose(self): + # Tests whether a map and batch dataset will be cleaned up correctly when + # the pipeline does not run it until exhaustion. + # The pipeline is TensorSliceDataset -> RepeatDataset(1000) -> + # MapAndBatchDataset(f=square_3, batch_size=100). + components = (np.arange(1000), + np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis], + np.array(37.0) * np.arange(1000)) -class BatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): - components = ( - np.arange(tensor_slice_len), - np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(tensor_slice_len)) + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) - return dataset_ops.Dataset.from_tensor_slices(components).batch(batch_size) + dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat( + 1000).apply(batching.map_and_batch(_map_fn, batch_size=100)) + dataset = dataset.prefetch(5) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() - def testCore(self): - tensor_slice_len = 8 - batch_size = 2 - num_outputs = tensor_slice_len // batch_size - self.run_core_tests( - lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), - lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), - num_outputs) + with self.test_session() as sess: + for _ in range(3): + sess.run(get_next) - def _build_dataset_dense_to_sparse(self, components): - return dataset_ops.Dataset.from_tensor_slices(components).map( - lambda x: array_ops.fill([x], x)).apply( - batching.dense_to_sparse_batch(4, [12])) + @parameterized.parameters(0, 5, 10, 90, 95, 99) + def testMapAndBatchOutOfRangeError(self, threshold): - def testDenseToSparseBatchDatasetCore(self): - components = np.random.randint(5, size=(40,)).astype(np.int32) - diff_comp = np.random.randint(2, size=(100,)).astype(np.int32) - - num_outputs = len(components) // 4 - self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components), - lambda: self._build_dataset_dense_to_sparse(diff_comp), - num_outputs) - - def _sparse(self, i): - return sparse_tensor.SparseTensorValue( - indices=[[0]], values=(i * [1]), dense_shape=[1]) + def raising_py_fn(i): + if i >= threshold: + raise StopIteration() + else: + return i - def _build_dataset_sparse(self, batch_size=5): - return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size) - - def testSparseCore(self): - self.run_core_tests(self._build_dataset_sparse, - lambda: self._build_dataset_sparse(2), 2) - - def _build_dataset_nested_sparse(self): - return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2) - - def testNestedSparseCore(self): - self.run_core_tests(self._build_dataset_nested_sparse, None, 1) + iterator = ( + dataset_ops.Dataset.range(100).apply( + batching.map_and_batch( + lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64), + batch_size=10)).make_one_shot_iterator()) + get_next = iterator.get_next() + with self.test_session() as sess: + for i in range(threshold // 10): + self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next)) + if threshold % 10 != 0: + self.assertAllEqual( + [threshold // 10 * 10 + j for j in range(threshold % 10)], + sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) -class UnbatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + @parameterized.parameters( + (False, dtypes.bool), + (-42, dtypes.int8), + (-42, dtypes.int16), + (-42, dtypes.int32), + (-42, dtypes.int64), + (42, dtypes.uint8), + (42, dtypes.uint16), + (42.0, dtypes.float16), + (42.0, dtypes.float32), + (42.0, dtypes.float64), + (b"hello", dtypes.string), + ) + def testMapAndBatchTypes(self, element, dtype): + def gen(): + yield element + + dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply( + batching.map_and_batch(lambda x: x, batch_size=10)) + + get_next = dataset.make_one_shot_iterator().get_next() - def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): - components = ( - np.arange(tensor_slice_len), - np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(tensor_slice_len)) - - return dataset_ops.Dataset.from_tensor_slices(components).batch( - batch_size).apply(batching.unbatch()) - - def testCore(self): - tensor_slice_len = 8 - batch_size = 2 - num_outputs = tensor_slice_len - self.run_core_tests( - lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), - lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), - num_outputs) - - -class MapAndBatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def testNumParallelBatches(self): - range_size = 11 - num_repeats = 2 - batch_size = 5 - total_outputs = range_size * num_repeats - num_outputs_drop_remainder = total_outputs // batch_size - num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) - num_parallel_batches = 2 - - def build_ds(range_start, drop_remainder=False): - - def _map_fn(x): - return math_ops.square(x) - - return dataset_ops.Dataset.range( - range_start, range_start + range_size).repeat(num_repeats).apply( - batching.map_and_batch( - map_func=_map_fn, - batch_size=batch_size, - num_parallel_batches=num_parallel_batches, - drop_remainder=drop_remainder)) - - self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), - num_outputs_keep_remainder) - self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), - num_outputs_drop_remainder) - - def testNumParallelCalls(self): - range_size = 11 - num_repeats = 2 - batch_size = 5 - total_outputs = range_size * num_repeats - num_outputs_drop_remainder = total_outputs // batch_size - num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) - num_parallel_calls = 7 - - def build_ds(range_start, drop_remainder=False): - - def _map_fn(x): - return math_ops.square(x) - - return dataset_ops.Dataset.range( - range_start, range_start + range_size).repeat(num_repeats).apply( - batching.map_and_batch( - map_func=_map_fn, - batch_size=batch_size, - num_parallel_calls=num_parallel_calls, - drop_remainder=drop_remainder)) - - self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), - num_outputs_keep_remainder) - self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), - num_outputs_drop_remainder) - - -class PaddedBatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def testPaddedBatch(self): - - def build_dataset(seq_lens): - return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( - lambda x: array_ops.fill([x], x)).padded_batch( - 4, padded_shapes=[-1]) - - seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) - seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) - self.run_core_tests(lambda: build_dataset(seq_lens1), - lambda: build_dataset(seq_lens2), 8) - - def testPaddedBatchNonDefaultPadding(self): - - def build_dataset(seq_lens): - - def fill_tuple(x): - filled = array_ops.fill([x], x) - return (filled, string_ops.as_string(filled)) - - padded_shape = [-1] - return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( - fill_tuple).padded_batch( - 4, - padded_shapes=(padded_shape, padded_shape), - padding_values=(-1, "")) - - seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) - seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) - self.run_core_tests(lambda: build_dataset(seq_lens1), - lambda: build_dataset(seq_lens2), 8) + with self.test_session() as sess: + for _ in range(10): + self.assertAllEqual([element for _ in range(10)], sess.run(get_next)) class RestructuredDatasetTest(test.TestCase): diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index bd3e034211c4aa454e4f8f6b09f14935d7a3b35c..2022c1f2bdd09cdf43a993b3666335ce468a40ba 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -21,7 +21,6 @@ import random import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import grouping from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op @@ -68,7 +67,7 @@ class GroupByReducerTest(test.TestCase): reducer = grouping.Reducer( init_func=lambda _: (0.0, 0.0), reduce_func=reduce_fn, - finalize_func=lambda x: x[0]) + finalize_func=lambda x, _: x) for i in range(1, 11): dataset = dataset_ops.Dataset.range(2 * i).apply( grouping.group_by_reducer( @@ -121,7 +120,7 @@ class GroupByReducerTest(test.TestCase): reducer = grouping.Reducer( init_func=lambda x: ([0], 1), reduce_func=reduce_fn, - finalize_func=lambda x: x) + finalize_func=lambda x, y: (x, y)) for i in range(1, 11): dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply( @@ -176,37 +175,27 @@ class GroupByReducerTest(test.TestCase): dataset.apply( grouping.group_by_reducer(lambda _: "wrong", reducer)) + def testTuple(self): + def init_fn(_): + return np.array([], dtype=np.int64), np.int64(0) -class GroupByReducerSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + def reduce_fn(state, value): + s1, s2 = state + v1, v2 = value + return array_ops.concat([s1, [v1]], 0), s2 + v2 - def _build_dataset(self, components): - reducer = grouping.Reducer( - init_func=lambda _: np.int64(0), - reduce_func=lambda x, y: x + y, - finalize_func=lambda x: x) + def finalize_fn(s1, s2): + return s1, s2 - return dataset_ops.Dataset.from_tensor_slices(components).apply( - grouping.group_by_reducer(lambda x: x % 5, reducer)) - - def testCoreGroupByReducer(self): - components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64) - self.verify_unused_iterator( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - self.verify_init_before_restore( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - self.verify_multiple_breaks( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - self.verify_reset_restored_iterator( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - self.verify_restore_in_empty_graph( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64) - self.verify_restore_in_modified_graph( - lambda: self._build_dataset(components), - lambda: self._build_dataset(diff_components), - 5, - verify_exhausted=True) + reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn) + dataset = dataset_ops.Dataset.zip( + (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply( + grouping.group_by_reducer(lambda x, y: np.int64(0), reducer)) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + x, y = sess.run(get_next) + self.assertAllEqual(x, np.asarray([x for x in range(10)])) + self.assertEqual(y, 45) class GroupByWindowTest(test.TestCase): @@ -353,34 +342,6 @@ class GroupByWindowTest(test.TestCase): self.assertEqual(len(components), sum(counts)) -class GroupByWindowSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, components): - return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply( - grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4)) - - def testCoreGroupByWindow(self): - components = np.array( - [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64) - self.verify_unused_iterator( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - self.verify_init_before_restore( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - self.verify_multiple_breaks( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - self.verify_reset_restored_iterator( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - self.verify_restore_in_empty_graph( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64) - self.verify_restore_in_modified_graph( - lambda: self._build_dataset(components), - lambda: self._build_dataset(diff_components), - 12, - verify_exhausted=False) - - # NOTE(mrry): These tests are based on the tests in bucket_ops_test.py. # Currently, they use a constant batch size, though should be made to use a # different batch size per key. @@ -655,7 +616,44 @@ class BucketBySequenceLength(test.TestCase): batch_sizes = batch_sizes[:-1] self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) - self.assertEqual(sorted(boundaries), sorted(lengths_val)) + self.assertEqual([boundary - 1 for boundary in sorted(boundaries)], + sorted(lengths_val)) + + def testPadToBoundaryNoExtraneousPadding(self): + + boundaries = [3, 7, 11] + batch_sizes = [2, 2, 2, 2] + lengths = range(1, 11) + + def element_gen(): + for length in lengths: + yield ([1] * length,) + + element_len = lambda element: array_ops.shape(element)[0] + dataset = dataset_ops.Dataset.from_generator( + element_gen, (dtypes.int64,), ([None],)).apply( + grouping.bucket_by_sequence_length( + element_len, boundaries, batch_sizes, + pad_to_bucket_boundary=True)) + batch, = dataset.make_one_shot_iterator().get_next() + + with self.test_session() as sess: + batches = [] + for _ in range(5): + batches.append(sess.run(batch)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(batch) + + self.assertAllEqual(batches[0], [[1, 0], + [1, 1]]) + self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0]]) + self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1]]) + self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) def testTupleElements(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py index 97b5e9416521dcad9ee5047a8275f8fd0142e338..2a0e64caeb61c5a7d45669783ace4588746c19e3 100644 --- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -18,10 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gzip import os import string import tempfile import time +import zlib import numpy as np @@ -33,7 +35,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_parsing_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import googletest from tensorflow.python.platform import test @@ -62,21 +64,32 @@ class CsvDatasetOpTest(test.TestCase): op2 = sess.run(next2) self.assertAllEqual(op1, op2) - def setup_files(self, inputs, linebreak='\n'): + def _setup_files(self, inputs, linebreak='\n', compression_type=None): filenames = [] for i, ip in enumerate(inputs): fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i) - with open(fn, 'wb') as f: - f.write(linebreak.join(ip).encode('utf-8')) + contents = linebreak.join(ip).encode('utf-8') + if compression_type is None: + with open(fn, 'wb') as f: + f.write(contents) + elif compression_type == 'GZIP': + with gzip.GzipFile(fn, 'wb') as f: + f.write(contents) + elif compression_type == 'ZLIB': + contents = zlib.compress(contents) + with open(fn, 'wb') as f: + f.write(contents) + else: + raise ValueError('Unsupported compression_type', compression_type) filenames.append(fn) return filenames def _make_test_datasets(self, inputs, **kwargs): # Test by comparing its output to what we could get with map->decode_csv - filenames = self.setup_files(inputs) + filenames = self._setup_files(inputs) dataset_expected = core_readers.TextLineDataset(filenames) dataset_expected = dataset_expected.map( - lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) + lambda l: parsing_ops.decode_csv(l, **kwargs)) dataset_actual = readers.CsvDataset(filenames, **kwargs) return (dataset_actual, dataset_expected) @@ -112,15 +125,18 @@ class CsvDatasetOpTest(test.TestCase): except errors.OutOfRangeError: break - def _test_dataset(self, - inputs, - expected_output=None, - expected_err_re=None, - linebreak='\n', - **kwargs): + def _test_dataset( + self, + inputs, + expected_output=None, + expected_err_re=None, + linebreak='\n', + compression_type=None, # Used for both setup and parsing + **kwargs): """Checks that elements produced by CsvDataset match expected output.""" # Convert str type because py3 tf strings are bytestrings - filenames = self.setup_files(inputs, linebreak) + filenames = self._setup_files(inputs, linebreak, compression_type) + kwargs['compression_type'] = compression_type with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = readers.CsvDataset(filenames, **kwargs) @@ -174,7 +190,7 @@ class CsvDatasetOpTest(test.TestCase): def testCsvDataset_ignoreErrWithUnescapedQuotes(self): record_defaults = [['']] * 3 inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']] - filenames = self.setup_files(inputs) + filenames = self._setup_files(inputs) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) @@ -184,7 +200,7 @@ class CsvDatasetOpTest(test.TestCase): def testCsvDataset_ignoreErrWithUnquotedQuotes(self): record_defaults = [['']] * 3 inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']] - filenames = self.setup_files(inputs) + filenames = self._setup_files(inputs) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) @@ -355,7 +371,7 @@ class CsvDatasetOpTest(test.TestCase): '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19', '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19' ]] - file_path = self.setup_files(data) + file_path = self._setup_files(data) with ops.Graph().as_default() as g: ds = readers.make_csv_dataset( @@ -432,14 +448,29 @@ class CsvDatasetOpTest(test.TestCase): record_defaults=record_defaults, buffer_size=0) - def testCsvDataset_withBufferSize(self): + def _test_dataset_on_buffer_sizes(self, + inputs, + expected, + linebreak, + record_defaults, + compression_type=None, + num_sizes_to_test=20): + # Testing reading with a range of buffer sizes that should all work. + for i in list(range(1, 1 + num_sizes_to_test)) + [None]: + self._test_dataset( + inputs, + expected, + linebreak=linebreak, + compression_type=compression_type, + record_defaults=record_defaults, + buffer_size=i) + + def testCsvDataset_withLF(self): record_defaults = [['NA']] * 3 inputs = [['abc,def,ghi', '0,1,2', ',,']] expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, expected, record_defaults=record_defaults, buffer_size=i + 1) + self._test_dataset_on_buffer_sizes( + inputs, expected, linebreak='\n', record_defaults=record_defaults) def testCsvDataset_withCR(self): # Test that when the line separator is '\r', parsing works with all buffer @@ -447,14 +478,8 @@ class CsvDatasetOpTest(test.TestCase): record_defaults = [['NA']] * 3 inputs = [['abc,def,ghi', '0,1,2', ',,']] expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\r', - record_defaults=record_defaults, - buffer_size=i + 1) + self._test_dataset_on_buffer_sizes( + inputs, expected, linebreak='\r', record_defaults=record_defaults) def testCsvDataset_withCRLF(self): # Test that when the line separator is '\r\n', parsing works with all buffer @@ -462,29 +487,15 @@ class CsvDatasetOpTest(test.TestCase): record_defaults = [['NA']] * 3 inputs = [['abc,def,ghi', '0,1,2', ',,']] expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\r\n', - record_defaults=record_defaults, - buffer_size=i + 1) + self._test_dataset_on_buffer_sizes( + inputs, expected, linebreak='\r\n', record_defaults=record_defaults) def testCsvDataset_withBufferSizeAndQuoted(self): record_defaults = [['NA']] * 3 inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\n', - record_defaults=record_defaults, - buffer_size=i + 1) - self._test_dataset( + self._test_dataset_on_buffer_sizes( inputs, expected, linebreak='\n', record_defaults=record_defaults) def testCsvDataset_withCRAndQuoted(self): @@ -494,15 +505,7 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\r', - record_defaults=record_defaults, - buffer_size=i + 1) - self._test_dataset( + self._test_dataset_on_buffer_sizes( inputs, expected, linebreak='\r', record_defaults=record_defaults) def testCsvDataset_withCRLFAndQuoted(self): @@ -512,17 +515,33 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\r\n', - record_defaults=record_defaults, - buffer_size=i + 1) - self._test_dataset( + self._test_dataset_on_buffer_sizes( inputs, expected, linebreak='\r\n', record_defaults=record_defaults) + def testCsvDataset_withGzipCompressionType(self): + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + self._test_dataset_on_buffer_sizes( + inputs, + expected, + linebreak='\r\n', + compression_type='GZIP', + record_defaults=record_defaults) + + def testCsvDataset_withZlibCompressionType(self): + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + self._test_dataset_on_buffer_sizes( + inputs, + expected, + linebreak='\r\n', + compression_type='ZLIB', + record_defaults=record_defaults) + class CsvDatasetBenchmark(test.Benchmark): """Benchmarks for the various ways of creating a dataset from CSV files. @@ -581,7 +600,7 @@ class CsvDatasetBenchmark(test.Benchmark): num_cols = self._num_cols[i] kwargs = {'record_defaults': [[0.0]] * num_cols} dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() - dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop self._runBenchmark(dataset, num_cols, 'csv_float_map_decode_csv') self._tearDown() @@ -591,7 +610,7 @@ class CsvDatasetBenchmark(test.Benchmark): num_cols = self._num_cols[i] kwargs = {'record_defaults': [['']] * num_cols} dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() - dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv') self._tearDown() diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py index a842502cc6fe3605dde0be5f50cf46e3e37d7ed4..a2ab3de52e8e512e3cba399f7a1725e5570cfd01 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py @@ -17,14 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import batching from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes -from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -70,63 +66,5 @@ class DatasetConstructorTest(test.TestCase): # pylint: enable=protected-access -class DatasetConstructorSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_tensor_dataset(self, variable_array): - components = (variable_array, np.array([1, 2, 3]), np.array(37.0)) - - return dataset_ops.Dataset.from_tensors(components) - - def testFromTensorsCore(self): - # Equal length components - arr = np.array(1) - num_outputs = 1 - diff_arr = np.array(2) - self.run_core_tests(lambda: self._build_tensor_dataset(arr), - lambda: self._build_tensor_dataset(diff_arr), - num_outputs) - - def _build_tensor_slices_dataset(self, components): - return dataset_ops.Dataset.from_tensor_slices(components) - - def testFromTensorSlicesCore(self): - # Equal length components - components = (np.tile(np.array([[1], [2], [3], [4]]), 20), - np.tile(np.array([[12], [13], [14], [15]]), 22), - np.array([37.0, 38.0, 39.0, 40.0])) - - diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20), - np.tile(np.array([[5], [6], [7], [8]]), 22), - np.array([1.0, 2.0, 3.0, 4.0])) - - dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} - - self.run_core_tests(lambda: self._build_tensor_slices_dataset(components), - lambda: self._build_tensor_slices_dataset(diff_comp), 4) - self.run_core_tests( - lambda: self._build_tensor_slices_dataset(dict_components), None, 3) - - def _build_sparse_tensor_slice_dataset(self, slices): - indices = np.array( - [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))], - dtype=np.int64) - values = np.array([val for s in slices for val in s], dtype=np.float64) - dense_shape = np.array( - [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64) - sparse_components = sparse_tensor.SparseTensor(indices, values, dense_shape) - return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components) - - def testFromSparseTensorSlicesCore(self): - slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []] - diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []] - - self.run_core_tests( - lambda: self._build_sparse_tensor_slice_dataset(slices), - lambda: self._build_sparse_tensor_slice_dataset(diff_slices), - 9, - sparse_tensors=True) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py index 34b6a080c0aae7dfc228746139acc52cea4e6f28..9b1857de1a96c8f71788a1bf5085ef0605417fe7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py @@ -19,7 +19,6 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors @@ -34,8 +33,8 @@ class DirectedInterleaveDatasetTest(test.TestCase): input_datasets = [ dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10) ] - dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset, - input_datasets) + dataset = interleave_ops._DirectedInterleaveDataset(selector_dataset, + input_datasets) iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() @@ -144,24 +143,5 @@ class DirectedInterleaveDatasetTest(test.TestCase): ], choice_dataset=dataset_ops.Dataset.from_tensors([1.0])) -class SampleFromDatasetsSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, probs, num_samples): - dataset = interleave_ops.sample_from_datasets( - [ - dataset_ops.Dataset.from_tensors(i).repeat(None) - for i in range(len(probs)) - ], - probs, - seed=1813) - return dataset.take(num_samples) - - def testSerializationCore(self): - self.run_core_tests( - lambda: self._build_dataset([0.5, 0.5], 100), - lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py index 87b7c6ddb7afcbaaf8fe97cd8be87e6f5af8cd4d..e6883d53e02c0f96d966a52abfe2f9b4118f2e12 100644 --- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py @@ -17,9 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized +import numpy as np + from tensorflow.contrib.data.python.ops import get_single_element +from tensorflow.contrib.data.python.ops import grouping from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor @@ -27,40 +30,69 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class GetSingleElementTest(test.TestCase): +class GetSingleElementTest(test.TestCase, parameterized.TestCase): - def testGetSingleElement(self): - skip_value = array_ops.placeholder(dtypes.int64, shape=[]) - take_value = array_ops.placeholder_with_default( - constant_op.constant(1, dtype=dtypes.int64), shape=[]) + @parameterized.named_parameters( + ("Zero", 0, 1), + ("Five", 5, 1), + ("Ten", 10, 1), + ("Empty", 100, 1, errors.InvalidArgumentError, "Dataset was empty."), + ("MoreThanOne", 0, 2, errors.InvalidArgumentError, + "Dataset had more than one element."), + ) + def testGetSingleElement(self, skip, take, error=None, error_msg=None): + skip_t = array_ops.placeholder(dtypes.int64, shape=[]) + take_t = array_ops.placeholder(dtypes.int64, shape=[]) def make_sparse(x): x_1d = array_ops.reshape(x, [1]) x_2d = array_ops.reshape(x, [1, 1]) return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d) - dataset = (dataset_ops.Dataset.range(100) - .skip(skip_value) - .map(lambda x: (x * x, make_sparse(x))) - .take(take_value)) - + dataset = dataset_ops.Dataset.range(100).skip(skip_t).map( + lambda x: (x * x, make_sparse(x))).take(take_t) element = get_single_element.get_single_element(dataset) with self.test_session() as sess: - for x in [0, 5, 10]: - dense_val, sparse_val = sess.run(element, feed_dict={skip_value: x}) - self.assertEqual(x * x, dense_val) - self.assertAllEqual([[x]], sparse_val.indices) - self.assertAllEqual([x], sparse_val.values) - self.assertAllEqual([x], sparse_val.dense_shape) - - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "Dataset was empty."): - sess.run(element, feed_dict={skip_value: 100}) - - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "Dataset had more than one element."): - sess.run(element, feed_dict={skip_value: 0, take_value: 2}) + if error is None: + dense_val, sparse_val = sess.run( + element, feed_dict={ + skip_t: skip, + take_t: take + }) + self.assertEqual(skip * skip, dense_val) + self.assertAllEqual([[skip]], sparse_val.indices) + self.assertAllEqual([skip], sparse_val.values) + self.assertAllEqual([skip], sparse_val.dense_shape) + else: + with self.assertRaisesRegexp(error, error_msg): + sess.run(element, feed_dict={skip_t: skip, take_t: take}) + + @parameterized.named_parameters( + ("SumZero", 0), + ("SumOne", 1), + ("SumFive", 5), + ("SumTen", 10), + ) + def testReduceDataset(self, stop): + def init_fn(_): + return np.int64(0) + + def reduce_fn(state, value): + return state + value + + def finalize_fn(state): + return state + + sum_reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn) + + stop_t = array_ops.placeholder(dtypes.int64, shape=[]) + dataset = dataset_ops.Dataset.range(stop_t) + element = get_single_element.reduce_dataset(dataset, sum_reducer) + + with self.test_session() as sess: + value = sess.run(element, feed_dict={stop_t: stop}) + self.assertEqual(stop * (stop - 1) / 2, value) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index bee561e3e23a2ab6f314894caa21785347e6ca8b..44c3325a3db84bb844b7f860a7c925982f1e3d6a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -22,10 +22,8 @@ import math import threading import time -import numpy as np from six.moves import zip_longest -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes @@ -38,132 +36,6 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -class InterleaveDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_iterator_graph(self, input_values, cycle_length, block_length): - repeat_count = 2 - return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( - repeat_count).interleave( - lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), - cycle_length, block_length) - - def testSerializationCore(self): - input_values = np.array([4, 5, 6], dtype=np.int64) - num_outputs = np.sum(input_values) * 2 - # cycle_length > 1, block_length > 1 - cycle_length = 2 - block_length = 3 - # pylint: disable=g-long-lambda - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - lambda: self._build_iterator_graph( - input_values, cycle_length * 2, block_length * 1), - num_outputs) - # cycle_length = 1 - cycle_length = 1 - block_length = 3 - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - None, num_outputs) - # block_length = 1 - cycle_length = 2 - block_length = 1 - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - None, num_outputs) - # pylint: enable=g-long-lambda - - def testSparseCore(self): - - def _map_fn(i): - return sparse_tensor.SparseTensorValue( - indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) - - def _interleave_fn(x): - return dataset_ops.Dataset.from_tensor_slices( - sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) - - def _build_dataset(): - return dataset_ops.Dataset.range(10).map(_map_fn).interleave( - _interleave_fn, cycle_length=1) - - self.run_core_tests(_build_dataset, None, 20) - - -class ParallelInterleaveDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def setUp(self): - self.input_values = np.array([4, 5, 6], dtype=np.int64) - self.num_repeats = 2 - self.num_outputs = np.sum(self.input_values) * 2 - - def _build_ds(self, cycle_length, block_length, sloppy=False): - return (dataset_ops.Dataset.from_tensor_slices( - self.input_values).repeat(self.num_repeats).apply( - interleave_ops.parallel_interleave( - lambda x: dataset_ops.Dataset.range(10 * x, 11 * x), - cycle_length, block_length, sloppy))) - - def testSerializationCore(self): - # cycle_length > 1, block_length > 1 - cycle_length = 2 - block_length = 3 - self.run_core_tests( - lambda: self._build_ds(cycle_length, block_length), - lambda: self._build_ds(cycle_length * 2, block_length * 1), - self.num_outputs) - # cycle_length = 1 - cycle_length = 1 - block_length = 3 - self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), - None, self.num_outputs) - # block_length = 1 - cycle_length = 2 - block_length = 1 - self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), - None, self.num_outputs) - - def testSerializationWithSloppy(self): - break_points = self.gen_break_points(self.num_outputs, 10) - expected_outputs = np.repeat( - np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]), - self.num_repeats).tolist() - - def run_test(cycle_length, block_length): - actual = self.gen_outputs( - lambda: self._build_ds(cycle_length, block_length, True), - break_points, self.num_outputs) - self.assertSequenceEqual(sorted(actual), expected_outputs) - - # cycle_length > 1, block_length > 1 - run_test(2, 3) - # cycle_length = 1 - run_test(1, 3) - # block_length = 1 - run_test(2, 1) - - def testSparseCore(self): - - def _map_fn(i): - return sparse_tensor.SparseTensorValue( - indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) - - def _interleave_fn(x): - return dataset_ops.Dataset.from_tensor_slices( - sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) - - def _build_dataset(): - return dataset_ops.Dataset.range(10).map(_map_fn).apply( - interleave_ops.parallel_interleave(_interleave_fn, 1)) - - self.run_core_tests(_build_dataset, None, 20) - - class ParallelInterleaveDatasetTest(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/data/python/ops/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py similarity index 97% rename from tensorflow/contrib/data/python/ops/iterator_ops_test.py rename to tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index 30a993b1f7056b9726f524b2279131339c80c5eb..77148aceec7fa90f927a9c009671c2939460877b 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util @@ -55,7 +56,7 @@ class CheckpointInputPipelineHookTest(test.TestCase): def _read_vars(self, model_dir): """Returns (global_step, latest_feature).""" with ops.Graph().as_default() as g: - ckpt_path = saver_lib.latest_checkpoint(model_dir) + ckpt_path = checkpoint_management.latest_checkpoint(model_dir) meta_filename = ckpt_path + '.meta' saver_lib.import_meta_graph(meta_filename) saver = saver_lib.Saver() diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index 8d4042927970cab2f5a518fc0da49b38444dbcdf..009e21a34c8df86af6abbb7599dbcfa23ddf90a7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -17,27 +17,29 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import hashlib +import itertools import os +import time import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.util import compat +_NUMPY_RANDOM_SEED = 42 + class MapDatasetTest(test.TestCase): @@ -78,18 +80,21 @@ class MapDatasetTest(test.TestCase): sess.run(get_next) def testReadFileIgnoreError(self): + def write_string_to_file(value, filename): with open(filename, "w") as f: f.write(value) - filenames = [os.path.join(self.get_temp_dir(), "file_%d.txt" % i) - for i in range(5)] + + filenames = [ + os.path.join(self.get_temp_dir(), "file_%d.txt" % i) for i in range(5) + ] for filename in filenames: write_string_to_file(filename, filename) dataset = ( dataset_ops.Dataset.from_tensor_slices(filenames).map( - io_ops.read_file, num_parallel_calls=2).prefetch(2).apply( - error_ops.ignore_errors())) + io_ops.read_file, + num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors())) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -143,228 +148,210 @@ class MapDatasetTest(test.TestCase): sess.run(get_next) -class MapDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def setUp(self): - self._tensor_slice_len = 7 - self._num_epochs = 14 - self._num_outputs = self._tensor_slice_len * self._num_epochs - - def _build_ds(self, multiplier=37.0): - components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * - np.arange(self._tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(self._tensor_slice_len)) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - return ( - dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(self._num_epochs)) - - def testSaveRestoreCore(self): - self.run_core_tests( - self._build_ds, - lambda: self._build_ds(multiplier=15.0), - self._num_outputs) - - def testSaveStatefulFunction(self): - - def _build_ds(): - - def _map_fn(x): - return random_ops.random_uniform( - (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) - - return dataset_ops.Dataset.range(100).map(_map_fn) - - self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) - - def testCaptureVariableInMapFn(self): - - def _build_ds(): - counter_var = variable_scope.get_variable( - "counter", (), dtypes.int32, use_resource=True) - return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda _: counter_var.assign_add(1))) - - self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) - - def testCaptureConstantInMapFn(self): - - def _build_ds(): - constant_var = constant_op.constant(5) - return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda x: x + constant_var)) - - self.run_core_tests(_build_ds, None, 10) - - def testCaptureDefunInMapFn(self): - num_outputs = 100 - - def _build_ds(): - - @function.Defun(dtypes.int64) - def defun_fn(x): - return constant_op.constant(1000) + math_ops.to_int32(x) - - return dataset_ops.Dataset.range(num_outputs).map(defun_fn) - - self.run_core_tests(_build_ds, None, num_outputs) - - def testBuildDefunInMapFn(self): - num_outputs = 100 - - def _build_ds(): - - @function.Defun(dtypes.int64) - def defun_fn(x): - - @function.Defun(dtypes.int32) - def defun_fn_deep(x): - return constant_op.constant(1000) + math_ops.to_int32(x) - - return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) - - return dataset_ops.Dataset.range(num_outputs).map(defun_fn) - - self.run_core_tests(_build_ds, None, num_outputs) - - def testSparseCore(self): - - def _sparse(i): - return sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0]]), - values=(i * np.array([1])), - dense_shape=np.array([1, 1])) - - def _build_ds(num_outputs): - return dataset_ops.Dataset.range(num_outputs).map(_sparse) - - num_outputs = 10 - self.run_core_tests(lambda: _build_ds(num_outputs), - lambda: _build_ds(int(num_outputs / 2)), num_outputs) - - -class ParallelMapDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def setUp(self): - self._tensor_slice_len = 7 - self._num_epochs = 1 - self._num_outputs = self._tensor_slice_len * self._num_epochs - - def _build_ds(self, multiplier=37.0): - components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * - np.arange(self._tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(self._tensor_slice_len)) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - return (dataset_ops.Dataset.from_tensor_slices(components).map( - _map_fn, num_parallel_calls=3).repeat(self._num_epochs)) - - def _build_ds_with_prefetch(self, multiplier=37.0): - components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * - np.arange(self._tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(self._tensor_slice_len)) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - return (dataset_ops.Dataset.from_tensor_slices(components).map( - _map_fn, num_parallel_calls=3).repeat(self._num_epochs).prefetch(5)) - - def testSaveRestoreCore(self): - for ds_fn in [self._build_ds, self._build_ds_with_prefetch]: - self.run_core_tests( - ds_fn, - lambda: ds_fn(multiplier=15.0), - self._num_outputs) - - def testSaveStatefulFunction(self): - - def _build_ds(): - - def _map_fn(x): - return random_ops.random_uniform( - (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) - - return dataset_ops.Dataset.range(100).map( - _map_fn, num_parallel_calls=2).prefetch(2) - - self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) - - def testCaptureVariableInMapFn(self): - - def _build_ds(): - counter_var = variable_scope.get_variable( - "counter", (), dtypes.int32, use_resource=True) - return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda _: counter_var.assign_add(1), - num_parallel_calls=2).prefetch(2)) - - self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) - - def testCaptureConstantInMapFn(self): - - def _build_ds(): - constant_var = constant_op.constant(5) - return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda x: x + constant_var, num_parallel_calls=2).prefetch(2)) - - self.run_core_tests(_build_ds, None, 10) - - def testCaptureDefunInMapFn(self): - num_outputs = 100 - - def _build_ds(): - - @function.Defun(dtypes.int64) - def defun_fn(x): - return constant_op.constant(1000) + math_ops.to_int32(x) - - return dataset_ops.Dataset.range(num_outputs).map( - defun_fn, num_parallel_calls=2).prefetch(2) - - self.run_core_tests(_build_ds, None, num_outputs) - - def testBuildDefunInMapFn(self): - num_outputs = 100 - - def _build_ds(): - - @function.Defun(dtypes.int64) - def defun_fn(x): - - @function.Defun(dtypes.int32) - def defun_fn_deep(x): - return constant_op.constant(1000) + math_ops.to_int32(x) - - return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) - - return dataset_ops.Dataset.range(num_outputs).map( - defun_fn, num_parallel_calls=2).prefetch(2) - - self.run_core_tests(_build_ds, None, num_outputs) - - -class IgnoreErrorsSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_ds(self, components): - return dataset_ops.Dataset.from_tensor_slices(components).map( - lambda x: array_ops.check_numerics(x, "message")).apply( - error_ops.ignore_errors()) - - def testIgnoreErrorsCore(self): - components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32) - num_outputs = 4 - self.run_core_tests(lambda: self._build_ds(components), - lambda: self._build_ds(diff_components), num_outputs) +class MapDatasetBenchmark(test.Benchmark): + + # The purpose of this benchmark is to compare the performance of chaining vs + # fusing of the map and batch transformations across various configurations. + # + # NOTE: It is recommended to build the benchmark with + # `-c opt --copt=-mavx --copt=-mavx2 --copt=-mfma --copt=-gmlt` + # and execute it on a machine with at least 32 CPU cores. + def benchmarkMapAndBatch(self): + + # Sequential pipeline configurations. + seq_elem_size_series = itertools.product([1], [1], [1, 2, 4, 8], [16]) + seq_batch_size_series = itertools.product([1], [1], [1], [8, 16, 32, 64]) + + # Parallel pipeline configuration. + par_elem_size_series = itertools.product([32], [32], [1, 2, 4, 8], [256]) + par_batch_size_series = itertools.product([32], [32], [1], + [128, 256, 512, 1024]) + par_num_calls_series = itertools.product([8, 16, 32, 64], [32], [1], [512]) + par_inter_op_series = itertools.product([32], [8, 16, 32, 64], [1], [512]) + + def name(method, label, num_calls, inter_op, element_size, batch_size): + return ("%s_id_%s_num_calls_%d_inter_op_%d_elem_size_%d_batch_size_%d" % ( + method, + hashlib.sha1(label).hexdigest(), + num_calls, + inter_op, + element_size, + batch_size, + )) + + def benchmark(label, series): + + print("%s:" % label) + for num_calls, inter_op, element_size, batch_size in series: + + num_iters = 1024 // ( + (element_size * batch_size) // min(num_calls, inter_op)) + k = 1024 * 1024 + dataset = dataset_ops.Dataset.from_tensors((np.random.rand( + element_size, 4 * k), np.random.rand(4 * k, 1))).repeat() + + chained_dataset = dataset.map( + math_ops.matmul, + num_parallel_calls=num_calls).batch(batch_size=batch_size) + chained_iterator = chained_dataset.make_one_shot_iterator() + chained_get_next = chained_iterator.get_next() + + chained_deltas = [] + with session.Session( + config=config_pb2.ConfigProto( + inter_op_parallelism_threads=inter_op, + use_per_session_threads=True)) as sess: + for _ in range(5): + sess.run(chained_get_next.op) + for _ in range(num_iters): + start = time.time() + sess.run(chained_get_next.op) + end = time.time() + chained_deltas.append(end - start) + + fused_dataset = dataset = dataset.apply( + batching.map_and_batch( + math_ops.matmul, + num_parallel_calls=num_calls, + batch_size=batch_size)) + fused_iterator = fused_dataset.make_one_shot_iterator() + fused_get_next = fused_iterator.get_next() + + fused_deltas = [] + with session.Session( + config=config_pb2.ConfigProto( + inter_op_parallelism_threads=inter_op, + use_per_session_threads=True)) as sess: + + for _ in range(5): + sess.run(fused_get_next.op) + for _ in range(num_iters): + start = time.time() + sess.run(fused_get_next.op) + end = time.time() + fused_deltas.append(end - start) + + print( + "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, " + "element size: %d, num iters: %d\nchained wall time: %f (median), " + "%f (mean), %f (stddev), %f (min), %f (max)\n fused wall time: " + "%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n " + "chained/fused: %.2fx (median), %.2fx (mean)" % + (batch_size, num_calls, inter_op, element_size, num_iters, + np.median(chained_deltas), np.mean(chained_deltas), + np.std(chained_deltas), np.min(chained_deltas), + np.max(chained_deltas), np.median(fused_deltas), + np.mean(fused_deltas), np.std(fused_deltas), np.min(fused_deltas), + np.max(fused_deltas), + np.median(chained_deltas) / np.median(fused_deltas), + np.mean(chained_deltas) / np.mean(fused_deltas))) + + self.report_benchmark( + iters=num_iters, + wall_time=np.median(chained_deltas), + name=name("chained", label, num_calls, inter_op, element_size, + batch_size)) + + self.report_benchmark( + iters=num_iters, + wall_time=np.median(fused_deltas), + name=name("fused", label, num_calls, inter_op, element_size, + batch_size)) + + print("") + + np.random.seed(_NUMPY_RANDOM_SEED) + benchmark("Sequential element size evaluation", seq_elem_size_series) + benchmark("Sequential batch size evaluation", seq_batch_size_series) + benchmark("Parallel element size evaluation", par_elem_size_series) + benchmark("Parallel batch size evaluation", par_batch_size_series) + benchmark("Transformation parallelism evaluation", par_num_calls_series) + benchmark("Threadpool size evaluation", par_inter_op_series) + + # This benchmark compares the performance of pipeline with multiple chained + # maps with and without map fusion. + def benchmarkChainOfMaps(self): + chain_lengths = [0, 1, 2, 5, 10, 20, 50] + for chain_length in chain_lengths: + self._benchmarkChainOfMaps(chain_length, False) + self._benchmarkChainOfMaps(chain_length, True) + + def _benchmarkChainOfMaps(self, chain_length, optimize_dataset): + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors(0).repeat(None) + for _ in range(chain_length): + dataset = dataset.map(lambda x: x) + if optimize_dataset: + dataset = dataset.apply(optimization.optimize(["map_fusion"])) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for _ in range(5): + sess.run(next_element.op) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100 + opt_mark = "opt" if optimize_dataset else "no-opt" + print("Map dataset {} chain length: {} Median wall time: {}".format( + opt_mark, chain_length, median_wall_time)) + self.report_benchmark( + iters=1000, + wall_time=median_wall_time, + name="benchmark_map_dataset_chain_latency_{}_{}".format( + opt_mark, chain_length)) + + +class MapAndFilterBenchmark(test.Benchmark): + + # This benchmark compares the performance of pipeline with multiple chained + # map + filter with and without map fusion. + def benchmarkMapAndFilter(self): + chain_lengths = [0, 1, 2, 5, 10, 20, 50] + for chain_length in chain_lengths: + self._benchmarkMapAndFilter(chain_length, False) + self._benchmarkMapAndFilter(chain_length, True) + + def _benchmarkMapAndFilter(self, chain_length, optimize_dataset): + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors(0).repeat(None) + for _ in range(chain_length): + dataset = dataset.map(lambda x: x + 5).filter( + lambda x: math_ops.greater_equal(x - 5, 0)) + if optimize_dataset: + dataset = dataset.apply( + optimization.optimize(["map_and_filter_fusion"])) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for _ in range(10): + sess.run(next_element.op) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100 + opt_mark = "opt" if optimize_dataset else "no-opt" + print("Map and filter dataset {} chain length: {} Median wall time: {}". + format(opt_mark, chain_length, median_wall_time)) + self.report_benchmark( + iters=1000, + wall_time=median_wall_time, + name="benchmark_map_and_filter_dataset_chain_latency_{}_{}".format( + opt_mark, chain_length)) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a711325daed12f45e4e533f18ee81adc7dec93be --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py @@ -0,0 +1,126 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MapDefunOp.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import map_defun +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class MapDefunTest(test.TestCase): + + def testMapDefun_Simple(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + 3 + + with self.test_session(): + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0] + expected = elems * 2 + 3 + self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) + + def testMapDefun_MismatchedTypes(self): + + @function.Defun(dtypes.int32) + def fn(x): + return math_ops.cast(x, dtypes.float64) + + with self.test_session(): + nums = [1, 2, 3, 4, 5, 6] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0] + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(r) + + def testMapDefun_MultipleOutputs(self): + + @function.Defun(dtypes.int32) + def fn(x): + return (x, math_ops.cast(x * 2 + 3, dtypes.float64)) + + with self.test_session(): + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64], + [(2,), (2,)]) + expected = [elems, elems * 2 + 3] + self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) + + def testMapDefun_ShapeInference(self): + + @function.Defun(dtypes.int32) + def fn(x): + return x + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0] + self.assertEqual(result.get_shape(), (3, 2)) + + def testMapDefun_PartialShapeInference(self): + + @function.Defun(dtypes.int32) + def fn(x): + return x + + elems = array_ops.placeholder(dtypes.int64, (None, 2)) + result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)]) + self.assertEqual(result[0].get_shape().as_list(), [None, 2]) + + def testMapDefun_RaisesErrorOnRuntimeShapeMismatch(self): + + @function.Defun(dtypes.int32, dtypes.int32) + def fn(x, y): + return x, y + + elems1 = array_ops.placeholder(dtypes.int32) + elems2 = array_ops.placeholder(dtypes.int32) + result = map_defun.map_defun(fn, [elems1, elems2], + [dtypes.int32, dtypes.int32], [(), ()]) + with self.test_session() as sess: + with self.assertRaisesWithPredicateMatch( + errors.InvalidArgumentError, + "All inputs must have the same dimension 0."): + sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]}) + + def testMapDefun_RaisesDefunError(self): + + @function.Defun(dtypes.int32) + def fn(x): + with ops.control_dependencies([check_ops.assert_equal(x, 0)]): + return array_ops.identity(x) + + elems = constant_op.constant([0, 0, 0, 37, 0]) + result = map_defun.map_defun(fn, [elems], [dtypes.int32], [()]) + with self.test_session(): + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(result) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py index 30f1847dcddbfaf379ef2b09185f7a8db4aaeae2..ae147b4fa79c5fc8e63e1860f45036709ecc9777 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -17,72 +17,264 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from absl.testing import parameterized + +from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base from tensorflow.contrib.data.python.ops import optimization -from tensorflow.core.framework import graph_pb2 +from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class OptimizeDatasetTest(test.TestCase): +class OptimizeDatasetTest(test.TestCase, parameterized.TestCase): + + def testAssertSuffix(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Map"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + self.assertEqual(0, sess.run(get_next)) + + def testAssertSuffixInvalid(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Whoops"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Asserted Whoops transformation at offset 0 but encountered " + "Map transformation instead."): + sess.run(get_next) + + def testAssertSuffixShort(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Map", "Whoops"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Asserted next 2 transformations but encountered only 1."): + sess.run(get_next) def testDefaultOptimizations(self): - dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( - 10).apply(optimization.optimize()) + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply( + optimization.optimize()) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() with self.test_session() as sess: - graph = graph_pb2.GraphDef().FromString( - sess.run(dataset._as_serialized_graph())) - self.assertTrue( - all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) def testEmptyOptimizations(self): - dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( - 10).apply(optimization.optimize([])) + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply( + optimization.optimize([])) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() with self.test_session() as sess: - graph = graph_pb2.GraphDef().FromString( - sess.run(dataset._as_serialized_graph())) - self.assertTrue( - all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) def testOptimization(self): - dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( - 10).apply(optimization.optimize(["map_and_batch_fusion"])) + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["MapAndBatch"])).map(lambda x: x * x).batch(10).apply( + optimization.optimize(["map_and_batch_fusion"])) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() with self.test_session() as sess: - graph = graph_pb2.GraphDef().FromString( - sess.run(dataset._as_serialized_graph())) - self.assertTrue( - any([node.op == "MapAndBatchDatasetV2" for node in graph.node])) self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testFunctionLibraryDefinitionModification(self): + dataset = dataset_ops.Dataset.from_tensors(0).map(lambda x: x).apply( + optimization.optimize(["_test_only_function_rename"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaisesRegexp(errors.NotFoundError, + "Function .* is not defined."): + sess.run(get_next) + + @staticmethod + def map_functions(): + identity = lambda x: x + increment = lambda x: x + 1 + + def increment_and_square(x): + y = x + 1 + return y * y + + functions = [identity, increment, increment_and_square] + tests = [] + for i, fun1 in enumerate(functions): + for j, fun2 in enumerate(functions): + tests.append(( + "test_{}_{}".format(i, j), + [fun1, fun2], + )) + for k, fun3 in enumerate(functions): + tests.append(( + "test_{}_{}_{}".format(i, j, k), + [fun1, fun2, fun3], + )) -class OptimizeDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + swap = lambda x, n: (n, x) + tests.append(( + "swap1", + [lambda x: (x, 42), swap], + )) + tests.append(( + "swap2", + [lambda x: (x, 42), swap, swap], + )) + return tuple(tests) - def testCore(self): + @parameterized.named_parameters(*map_functions.__func__()) + def testMapFusion(self, functions): + dataset = dataset_ops.Dataset.range(5).apply( + optimization.assert_next(["Map", "Prefetch"])) + for function in functions: + dataset = dataset.map(function) - def build_dataset(num_elements, batch_size): - return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch( - batch_size).apply(optimization.optimize(["map_and_batch_fusion"])) + dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.test_session() as sess: + for x in range(5): + result = sess.run(get_next) + r = x + for function in functions: + if isinstance(r, tuple): + r = function(*r) # Pass tuple as multiple arguments. + else: + r = function(r) + self.assertAllEqual(r, result) - self.run_core_tests(lambda: build_dataset(200, 10), None, 20) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + @staticmethod + def map_and_filter_functions(): + identity = lambda x: x + increment = lambda x: x + 1 + minus_five = lambda x: x - 5 + + def increment_and_square(x): + y = x + 1 + return y * y + + take_all = lambda x: constant_op.constant(True) + is_zero = lambda x: math_ops.equal(x, 0) + is_odd = lambda x: math_ops.equal(x % 2, 0) + greater = lambda x: math_ops.greater(x + 5, 0) + + functions = [identity, increment, minus_five, increment_and_square] + filters = [take_all, is_zero, is_odd, greater] + tests = [] + + for x, fun in enumerate(functions): + for y, predicate in enumerate(filters): + tests.append(("mixed_{}_{}".format(x, y), fun, predicate)) + + # Multi output + tests.append(("multiOne", lambda x: (x, x), + lambda x, y: constant_op.constant(True))) + tests.append( + ("multiTwo", lambda x: (x, 2), + lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0))) + return tuple(tests) + + @parameterized.named_parameters(*map_and_filter_functions.__func__()) + def testMapFilterFusion(self, function, predicate): + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", + "FilterByLastComponent"])).map(function).filter(predicate).apply( + optimization.optimize(["map_and_filter_fusion"])) + self._testMapAndFilter(dataset, function, predicate) + + def _testMapAndFilter(self, dataset, function, predicate): + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.test_session() as sess: + for x in range(10): + r = function(x) + if isinstance(r, tuple): + b = predicate(*r) # Pass tuple as multiple arguments. + else: + b = predicate(r) + if sess.run(b): + result = sess.run(get_next) + self.assertAllEqual(r, result) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testAdditionalInputs(self): + a = constant_op.constant(3, dtype=dtypes.int64) + b = constant_op.constant(4, dtype=dtypes.int64) + some_tensor = math_ops.mul(a, b) + function = lambda x: x * x + + def predicate(y): + return math_ops.less(math_ops.cast(y, dtypes.int64), some_tensor) + + # We are currently not supporting functions with additional inputs. + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", "Filter"])).map(function).filter(predicate).apply( + optimization.optimize(["map_and_filter_fusion"])) + + self._testMapAndFilter(dataset, function, predicate) + + +class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): + + def testLatencyStatsOptimization(self): + + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.from_tensors(1).apply( + optimization.assert_next( + ["LatencyStats", "Map", "LatencyStats", "Prefetch", + "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply( + optimization.optimize(["latency_all_edges"])).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + get_next = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(iterator.initializer) + self.assertEqual(1 * 1, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, + "record_latency_TensorDataset/_1", 1) + self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4", + 1) + self._assertSummaryHasCount(summary_str, + "record_latency_PrefetchDataset/_6", 1) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index b08132cd72254326d965907a1fdafb8a820926a1..361fe0dd39bb3f855c3b0b11281a9909fd601232 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -21,6 +21,7 @@ import threading from tensorflow.contrib.data.python.ops import prefetching_ops from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.compat import compat from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op @@ -30,6 +31,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test @@ -68,6 +70,7 @@ class PrefetchingKernelsOpsTest(test.TestCase): with ops.device(device1): buffer_resource_handle = prefetching_ops.function_buffering_resource( f=_remote_fn, + output_types=[dtypes.float32], target_device=target, string_arg=ds_iterator_handle, buffer_size=3, @@ -85,8 +88,7 @@ class PrefetchingKernelsOpsTest(test.TestCase): return (prefetch_op, reset_op, destroy_op) def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False) prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name, @@ -125,8 +127,7 @@ class PrefetchingKernelsOpsTest(test.TestCase): "/job:localhost/replica:0/task:0/gpu:0") def testReinitialization(self): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) device0 = "/job:localhost/replica:0/task:0/cpu:0" device1 = "/job:localhost/replica:0/task:0/cpu:1" @@ -166,8 +167,7 @@ class PrefetchingKernelsOpsTest(test.TestCase): sess.run(destroy_op) def testReinitializationOutOfRange(self): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) device0 = "/job:localhost/replica:0/task:0/cpu:0" device1 = "/job:localhost/replica:0/task:0/cpu:1" @@ -201,6 +201,49 @@ class PrefetchingKernelsOpsTest(test.TestCase): sess.run(destroy_op) + def testStringsGPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + device0 = "/job:localhost/replica:0/task:0/cpu:0" + device1 = "/job:localhost/replica:0/task:0/gpu:0" + + ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"]) + ds_iterator = ds.make_one_shot_iterator() + ds_iterator_handle = ds_iterator.string_handle() + + @function.Defun(dtypes.string) + def _remote_fn(h): + remote_iterator = iterator_ops.Iterator.from_string_handle( + h, ds.output_types, ds.output_shapes) + return remote_iterator.get_next() + + target = constant_op.constant(device0) + with ops.device(device1): + buffer_resource_handle = prefetching_ops.function_buffering_resource( + f=_remote_fn, + output_types=[dtypes.string], + target_device=target, + string_arg=ds_iterator_handle, + buffer_size=3, + shared_name="strings") + + with ops.device(device1): + prefetch_op = prefetching_ops.function_buffering_resource_get_next( + function_buffer_resource=buffer_resource_handle, + output_types=[dtypes.string]) + destroy_op = resource_variable_ops.destroy_resource_op( + buffer_resource_handle, ignore_lookup_error=True) + + with self.test_session() as sess: + self.assertEqual([b"a"], sess.run(prefetch_op)) + self.assertEqual([b"b"], sess.run(prefetch_op)) + self.assertEqual([b"c"], sess.run(prefetch_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + + sess.run(destroy_op) + class PrefetchToDeviceTest(test.TestCase): @@ -227,14 +270,43 @@ class PrefetchToDeviceTest(test.TestCase): self.assertEqual(dtypes.int64, next_element.dtype) self.assertEqual([], next_element.shape) - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) with self.test_session(config=worker_config) as sess: for i in range(10): self.assertEqual(i, sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testPrefetchToSameDevice(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device( + "/job:localhost/replica:0/task:0/device:CPU:0")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + with self.test_session() as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + def testPrefetchDictToDevice(self): host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) device_dataset = host_dataset.apply( @@ -258,8 +330,7 @@ class PrefetchToDeviceTest(test.TestCase): self.assertEqual(dtypes.int64, next_element["a"].dtype) self.assertEqual([], next_element["a"].shape) - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) with self.test_session(config=worker_config) as sess: for i in range(10): self.assertEqual({"a": i}, sess.run(next_element)) @@ -292,8 +363,7 @@ class PrefetchToDeviceTest(test.TestCase): next_element = iterator.get_next() self.assertEqual(dtypes.int64, next_element.dtype) - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) with self.test_session(config=worker_config) as sess: for i in range(10): actual = sess.run(next_element) @@ -343,8 +413,7 @@ class PrefetchToDeviceTest(test.TestCase): self.assertEqual(dtypes.int64, next_element.dtype) self.assertEqual([], next_element.shape) - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) with self.test_session(config=worker_config) as sess: sess.run(iterator.initializer) for i in range(5): @@ -377,5 +446,653 @@ class PrefetchToDeviceTest(test.TestCase): sess.run(next_element) +class CopyToDeviceTest(test.TestCase): + + def testCopyToDevice(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceInt32(self): + host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int32, next_element.dtype) + self.assertEqual((4,), next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToSameDevice(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:0")) + + with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceWithPrefetch(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")).prefetch(1) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyDictToDevice(self): + host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element["a"].dtype) + self.assertEqual([], next_element["a"].shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual({"a": i}, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyDictToDeviceWithPrefetch(self): + host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")).prefetch(1) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element["a"].dtype) + self.assertEqual([], next_element["a"].shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual({"a": i}, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopySparseTensorsToDevice(self): + + def make_tensor(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2]) + + host_dataset = dataset_ops.Dataset.range(10).map(make_tensor) + + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + actual = sess.run(next_element) + self.assertAllEqual([i], actual.values) + self.assertAllEqual([[0, 0]], actual.indices) + self.assertAllEqual([2, 2], actual.dense_shape) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopySparseTensorsToDeviceWithPrefetch(self): + + def make_tensor(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2]) + + host_dataset = dataset_ops.Dataset.range(10).map(make_tensor) + + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")).prefetch(1) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + actual = sess.run(next_element) + self.assertAllEqual([i], actual.values) + self.assertAllEqual([[0, 0]], actual.indices) + self.assertAllEqual([2, 2], actual.dense_shape) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpu(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuWithPrefetch(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")).prefetch(1) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuInt32(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuInt32AndPrefetch(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")).prefetch(1) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuStrings(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"]) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuStringsAndPrefetch(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"]) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDevicePingPongCPUGPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + with compat.forward_compatibility_horizon(2018, 8, 4): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0", source_device="/cpu:0")) + back_to_cpu_dataset = device_dataset.apply( + prefetching_ops.copy_to_device("/cpu:0", source_device="/gpu:0")) + + with ops.device("/cpu:0"): + iterator = back_to_cpu_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceWithReInit(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceWithReInitAndPrefetch(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")).prefetch(1) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuWithReInit(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuWithReInitAndPrefetch(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")).prefetch(1) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testIteratorGetNextAsOptionalOnGPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(3) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_elem = iterator_ops.get_next_as_optional(iterator) + elem_has_value_t = next_elem.has_value() + elem_value_t = next_elem.get_value() + + with self.test_session() as sess: + # Before initializing the iterator, evaluating the optional fails with + # a FailedPreconditionError. + with self.assertRaises(errors.FailedPreconditionError): + sess.run(elem_has_value_t) + with self.assertRaises(errors.FailedPreconditionError): + sess.run(elem_value_t) + + # For each element of the dataset, assert that the optional evaluates to + # the expected value. + sess.run(iterator.initializer) + for i in range(3): + elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t]) + self.assertTrue(elem_has_value) + self.assertEqual(i, elem_value) + + # After exhausting the iterator, `next_elem.has_value()` will evaluate to + # false, and attempting to get the value will fail. + for _ in range(2): + self.assertFalse(sess.run(elem_has_value_t)) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(elem_value_t) + + +class MultiDeviceIteratorTest(test.TestCase): + + def testBasic(self): + dataset = dataset_ops.Dataset.range(10) + multi_device_iterator = prefetching_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/cpu:2"]) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + + config = config_pb2.ConfigProto(device_count={"CPU": 3}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 10, 2): + self.assertEqual(i, sess.run(elem_on_1)) + self.assertEqual(i + 1, sess.run(elem_on_2)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + + def testOneOnSameDevice(self): + with ops.device("/cpu:0"): + dataset = dataset_ops.Dataset.range(10) + multi_device_iterator = prefetching_ops.MultiDeviceIterator( + dataset, ["/cpu:0", "/cpu:1"]) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + + config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 10, 2): + self.assertEqual(i, sess.run(elem_on_1)) + self.assertEqual(i + 1, sess.run(elem_on_2)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + + def testRepeatDevices(self): + with ops.device("/cpu:0"): + dataset = dataset_ops.Dataset.range(20) + multi_device_iterator = prefetching_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/cpu:2", "/cpu:1", "/cpu:2"]) + elements = multi_device_iterator.get_next() + elem_on_1, elem_on_2, elem_on_3, elem_on_4 = elements + + config = config_pb2.ConfigProto(device_count={"CPU": 3}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 20, 4): + self.assertEqual(i, sess.run(elem_on_1)) + self.assertEqual(i + 1, sess.run(elem_on_2)) + self.assertEqual(i + 2, sess.run(elem_on_3)) + self.assertEqual(i + 3, sess.run(elem_on_4)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + sess.run(elem_on_3) + sess.run(elem_on_4) + + def testNotFullyDivisible(self): + dataset = dataset_ops.Dataset.range(9) + multi_device_iterator = prefetching_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/cpu:2"]) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + + config = config_pb2.ConfigProto(device_count={"CPU": 3}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 8, 2): + self.assertEqual(i, sess.run(elem_on_1)) + self.assertEqual(i + 1, sess.run(elem_on_2)) + self.assertEqual(8, sess.run(elem_on_1)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + + def testUneven(self): + dataset = dataset_ops.Dataset.range(10) + multi_device_iterator = prefetching_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/cpu:2"], max_buffer_size=4) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + + config = config_pb2.ConfigProto(device_count={"CPU": 3}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 10, 2): + self.assertEqual(i, sess.run(elem_on_1)) + for i in range(0, 10, 2): + self.assertEqual(i + 1, sess.run(elem_on_2)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + + def testMultipleInitializations(self): + with ops.device("/cpu:0"): + epoch = array_ops.placeholder(dtypes.int64, shape=[]) + dataset1 = dataset_ops.Dataset.from_tensors(epoch).repeat(1000) + dataset2 = dataset_ops.Dataset.range(1000) + dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) + multi_device_iterator = prefetching_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + init_op = multi_device_iterator.initializer + + config = config_pb2.ConfigProto(device_count={"CPU": 3}) + with self.test_session(config=config) as sess: + for i in range(1000): + sess.run(init_op, feed_dict={epoch: i}) + self.assertEqual([(i, 0), (i, 1)], sess.run([elem_on_1, elem_on_2])) + + def testBasicGpu(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + with compat.forward_compatibility_horizon(2018, 8, 4): + dataset = dataset_ops.Dataset.range(10) + multi_device_iterator = prefetching_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/gpu:0"]) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + + config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 10, 2): + self.assertEqual(i, sess.run(elem_on_1)) + self.assertEqual(i + 1, sess.run(elem_on_2)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + + def testUnevenGpu(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + with compat.forward_compatibility_horizon(2018, 8, 4): + dataset = dataset_ops.Dataset.range(10) + multi_device_iterator = prefetching_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/gpu:0"], max_buffer_size=4) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + + config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 10, 2): + self.assertEqual(i, sess.run(elem_on_1)) + for i in range(0, 10, 2): + self.assertEqual(i + 1, sess.run(elem_on_2)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index 80e1cb0041024b68bd5268b5de5d69c88c839896..592642da0cfd84e50cb20d9b2e534411faf927e8 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -17,21 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import counter from tensorflow.contrib.data.python.ops import enumerate_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import io_ops -from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -81,88 +73,5 @@ class RangeDatasetTest(test.TestCase): self.assertEqual(-2, sess.run(negative_get_next)) -class RangeDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _iterator_checkpoint_prefix_local(self): - return os.path.join(self.get_temp_dir(), "iterator") - - def _save_op(self, iterator_resource): - iterator_state_variant = gen_dataset_ops.serialize_iterator( - iterator_resource) - save_op = io_ops.write_file( - self._iterator_checkpoint_prefix_local(), - parsing_ops.serialize_tensor(iterator_state_variant)) - return save_op - - def _restore_op(self, iterator_resource): - iterator_state_variant = parsing_ops.parse_tensor( - io_ops.read_file(self._iterator_checkpoint_prefix_local()), - dtypes.variant) - restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, - iterator_state_variant) - return restore_op - - def testSaveRestore(self): - - def _build_graph(start, stop): - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - # Saving and restoring in different sessions. - start = 2 - stop = 10 - break_point = 5 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next, _, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(init_op) - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Saving and restoring in same session. - with ops.Graph().as_default() as g: - init_op, get_next, save_op, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def _build_range_dataset(self, start, stop): - return dataset_ops.Dataset.range(start, stop) - - def testRangeCore(self): - start = 2 - stop = 10 - stop_1 = 8 - self.run_core_tests(lambda: self._build_range_dataset(start, stop), - lambda: self._build_range_dataset(start, stop_1), - stop - start) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 3b07ef290bc38daa37472ef8919f3350851fe370..15b342d30f85a05b3827998565ba5f84021ac885 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -23,260 +23,17 @@ import zlib import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base from tensorflow.contrib.data.python.ops import readers -from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.data.util import nest from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.lib.io import python_io -from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test -from tensorflow.python.util import compat - - -class TextLineDatasetTestBase(test.TestCase): - - def _lineText(self, f, l): - return compat.as_bytes("%d: %d" % (f, l)) - - def _createFiles(self, - num_files, - num_lines, - crlf=False, - compression_type=None): - filenames = [] - for i in range(num_files): - fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i) - filenames.append(fn) - contents = [] - for j in range(num_lines): - contents.append(self._lineText(i, j)) - # Always include a newline after the record unless it is - # at the end of the file, in which case we include it - if j + 1 != num_lines or i == 0: - contents.append(b"\r\n" if crlf else b"\n") - contents = b"".join(contents) - - if not compression_type: - with open(fn, "wb") as f: - f.write(contents) - elif compression_type == "GZIP": - with gzip.GzipFile(fn, "wb") as f: - f.write(contents) - elif compression_type == "ZLIB": - contents = zlib.compress(contents) - with open(fn, "wb") as f: - f.write(contents) - else: - raise ValueError("Unsupported compression_type", compression_type) - - return filenames - - -class TextLineDatasetSerializationTest( - TextLineDatasetTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_iterator_graph(self, test_filenames, compression_type=None): - return core_readers.TextLineDataset( - test_filenames, compression_type=compression_type, buffer_size=10) - - def testTextLineCore(self): - compression_types = [None, "GZIP", "ZLIB"] - num_files = 5 - lines_per_file = 5 - num_outputs = num_files * lines_per_file - for compression_type in compression_types: - test_filenames = self._createFiles( - num_files, - lines_per_file, - crlf=True, - compression_type=compression_type) - # pylint: disable=cell-var-from-loop - self.run_core_tests( - lambda: self._build_iterator_graph(test_filenames, compression_type), - lambda: self._build_iterator_graph(test_filenames), num_outputs) - # pylint: enable=cell-var-from-loop - - -class FixedLengthRecordReaderTestBase(test.TestCase): - - def setUp(self): - super(FixedLengthRecordReaderTestBase, self).setUp() - self._num_files = 2 - self._num_records = 7 - self._header_bytes = 5 - self._record_bytes = 3 - self._footer_bytes = 2 - - def _record(self, f, r): - return compat.as_bytes(str(f * 2 + r) * self._record_bytes) - - def _createFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) - filenames.append(fn) - with open(fn, "wb") as f: - f.write(b"H" * self._header_bytes) - for j in range(self._num_records): - f.write(self._record(i, j)) - f.write(b"F" * self._footer_bytes) - return filenames - - -class FixedLengthRecordDatasetSerializationTest( - FixedLengthRecordReaderTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_iterator_graph(self, num_epochs, compression_type=None): - filenames = self._createFiles() - return core_readers.FixedLengthRecordDataset( - filenames, self._record_bytes, self._header_bytes, - self._footer_bytes).repeat(num_epochs) - - def testFixedLengthRecordCore(self): - num_epochs = 5 - num_outputs = num_epochs * self._num_files * self._num_records - self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), - lambda: self._build_iterator_graph(num_epochs * 2), - num_outputs) - - -class TFRecordDatasetTestBase(test.TestCase): - - def setUp(self): - super(TFRecordDatasetTestBase, self).setUp() - self._num_files = 2 - self._num_records = 7 - - self.test_filenames = self._createFiles() - - self.filenames = array_ops.placeholder(dtypes.string, shape=[None]) - self.num_epochs = array_ops.placeholder_with_default( - constant_op.constant(1, dtypes.int64), shape=[]) - self.compression_type = array_ops.placeholder_with_default("", shape=[]) - self.batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = core_readers.TFRecordDataset( - self.filenames, self.compression_type).repeat(self.num_epochs) - batch_dataset = repeat_dataset.batch(self.batch_size) - - iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) - self.init_op = iterator.make_initializer(repeat_dataset) - self.init_batch_op = iterator.make_initializer(batch_dataset) - self.get_next = iterator.get_next() - - def _record(self, f, r): - return compat.as_bytes("Record %d of file %d" % (r, f)) - - def _createFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) - filenames.append(fn) - writer = python_io.TFRecordWriter(fn) - for j in range(self._num_records): - writer.write(self._record(i, j)) - writer.close() - return filenames - - -class TFRecordDatasetSerializationTest( - TFRecordDatasetTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_iterator_graph(self, - num_epochs, - batch_size=1, - compression_type=None, - buffer_size=None): - filenames = self._createFiles() - if compression_type is "ZLIB": - zlib_files = [] - for i, fn in enumerate(filenames): - with open(fn, "rb") as f: - cdata = zlib.compress(f.read()) - zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i) - with open(zfn, "wb") as f: - f.write(cdata) - zlib_files.append(zfn) - filenames = zlib_files - - elif compression_type is "GZIP": - gzip_files = [] - for i, fn in enumerate(self.test_filenames): - with open(fn, "rb") as f: - gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i) - with gzip.GzipFile(gzfn, "wb") as gzf: - gzf.write(f.read()) - gzip_files.append(gzfn) - filenames = gzip_files - - return core_readers.TFRecordDataset( - filenames, compression_type, - buffer_size=buffer_size).repeat(num_epochs).batch(batch_size) - - def testTFRecordWithoutBufferCore(self): - num_epochs = 5 - batch_size = num_epochs - num_outputs = num_epochs * self._num_files * self._num_records // batch_size - # pylint: disable=g-long-lambda - self.run_core_tests( - lambda: self._build_iterator_graph(num_epochs, batch_size, - buffer_size=0), - lambda: self._build_iterator_graph(num_epochs * 2, batch_size), - num_outputs) - self.run_core_tests( - lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None, - num_outputs * batch_size) - # pylint: enable=g-long-lambda - - def testTFRecordWithBufferCore(self): - num_epochs = 5 - num_outputs = num_epochs * self._num_files * self._num_records - self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), - lambda: self._build_iterator_graph(num_epochs * 2), - num_outputs) - - def testTFRecordWithCompressionCore(self): - num_epochs = 5 - num_outputs = num_epochs * self._num_files * self._num_records - self.run_core_tests( - lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"), - lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) - self.run_core_tests( - lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"), - lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) - - -def _interleave(iterators, cycle_length): - pending_iterators = iterators - open_iterators = [] - num_open = 0 - for i in range(cycle_length): - if pending_iterators: - open_iterators.append(pending_iterators.pop(0)) - num_open += 1 - - while num_open: - for i in range(min(cycle_length, len(open_iterators))): - if open_iterators[i] is None: - continue - try: - yield next(open_iterators[i]) - except StopIteration: - if pending_iterators: - open_iterators[i] = pending_iterators.pop(0) - else: - open_iterators[i] = None - num_open -= 1 class ReadBatchFeaturesTest( @@ -416,276 +173,383 @@ class ReadBatchFeaturesTest( for num_epochs in [1, 10]: with ops.Graph().as_default(): # Basic test: read from file 0. - self.outputs = self.make_batch_feature( + outputs = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, drop_final_batch=True).make_one_shot_iterator().get_next() - for _, tensor in self.outputs.items(): + for _, tensor in outputs.items(): if isinstance(tensor, ops.Tensor): # Guard against SparseTensor. self.assertEqual(tensor.shape[0], batch_size) + def testIndefiniteRepeatShapeInference(self): + dataset = self.make_batch_feature( + filenames=self.test_filenames[0], num_epochs=None, batch_size=32) + for shape, clazz in zip(nest.flatten(dataset.output_shapes), + nest.flatten(dataset.output_classes)): + if issubclass(clazz, ops.Tensor): + self.assertEqual(32, shape[0]) + class MakeCsvDatasetTest(test.TestCase): - COLUMN_TYPES = [ - dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string - ] - COLUMNS = ["col%d" % i for i in range(len(COLUMN_TYPES))] - DEFAULT_VALS = [[], [], [], [], ["NULL"]] - DEFAULTS = [ - constant_op.constant([], dtype=dtypes.int32), - constant_op.constant([], dtype=dtypes.int64), - constant_op.constant([], dtype=dtypes.float32), - constant_op.constant([], dtype=dtypes.float64), - constant_op.constant(["NULL"], dtype=dtypes.string) - ] - LABEL = COLUMNS[0] - - def setUp(self): - super(MakeCsvDatasetTest, self).setUp() - self._num_files = 2 - self._num_records = 11 - self._test_filenames = self._create_files() - - def _csv_values(self, fileno, recordno): - return [ - fileno, - recordno, - fileno * recordno * 0.5, - fileno * recordno + 0.5, - "record %d" % recordno if recordno % 2 == 1 else "", - ] + def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs): + return readers.make_csv_dataset( + filenames, batch_size=batch_size, num_epochs=num_epochs, **kwargs) - def _write_file(self, filename, rows): - for i in range(len(rows)): - if isinstance(rows[i], list): - rows[i] = ",".join(str(v) if v is not None else "" for v in rows[i]) - fn = os.path.join(self.get_temp_dir(), filename) - f = open(fn, "w") - f.write("\n".join(rows)) - f.close() - return fn - - def _create_file(self, fileno, header=True): - rows = [] - if header: - rows.append(self.COLUMNS) - for recno in range(self._num_records): - rows.append(self._csv_values(fileno, recno)) - return self._write_file("csv_file%d.csv" % fileno, rows) - - def _create_files(self): + def _setup_files(self, inputs, linebreak="\n", compression_type=None): filenames = [] - for i in range(self._num_files): - filenames.append(self._create_file(i)) + for i, ip in enumerate(inputs): + fn = os.path.join(self.get_temp_dir(), "temp_%d.csv" % i) + contents = linebreak.join(ip).encode("utf-8") + if compression_type is None: + with open(fn, "wb") as f: + f.write(contents) + elif compression_type == "GZIP": + with gzip.GzipFile(fn, "wb") as f: + f.write(contents) + elif compression_type == "ZLIB": + contents = zlib.compress(contents) + with open(fn, "wb") as f: + f.write(contents) + else: + raise ValueError("Unsupported compression_type", compression_type) + filenames.append(fn) return filenames - def _make_csv_dataset( - self, - filenames, - defaults, - column_names=COLUMNS, - label_name=LABEL, - select_cols=None, - batch_size=1, - num_epochs=1, - shuffle=False, - shuffle_seed=None, - header=True, - na_value="", - ): - return readers.make_csv_dataset( - filenames, - batch_size=batch_size, - column_names=column_names, - column_defaults=defaults, - label_name=label_name, - num_epochs=num_epochs, - shuffle=shuffle, - shuffle_seed=shuffle_seed, - header=header, - na_value=na_value, - select_columns=select_cols, - ) - - def _next_actual_batch(self, file_indices, batch_size, num_epochs, defaults): - features = {col: list() for col in self.COLUMNS} + def _next_expected_batch(self, expected_output, expected_keys, batch_size, + num_epochs): + features = {k: [] for k in expected_keys} for _ in range(num_epochs): - for i in file_indices: - for j in range(self._num_records): - values = self._csv_values(i, j) - for n, v in enumerate(values): - if v == "": # pylint: disable=g-explicit-bool-comparison - values[n] = defaults[n][0] - values[-1] = values[-1].encode("utf-8") - - # Regroup lists by column instead of row - for n, col in enumerate(self.COLUMNS): - features[col].append(values[n]) - if len(list(features.values())[0]) == batch_size: - yield features - features = {col: list() for col in self.COLUMNS} - - def _run_actual_batch(self, outputs, sess): - features, labels = sess.run(outputs) - batch = [features[k] for k in self.COLUMNS if k != self.LABEL] - batch.append(labels) - return batch - - def _verify_records( + for values in expected_output: + for n, key in enumerate(expected_keys): + features[key].append(values[n]) + if len(features[expected_keys[0]]) == batch_size: + yield features + features = {k: [] for k in expected_keys} + if features[expected_keys[0]]: # Leftover from the last batch + yield features + + def _verify_output( self, sess, dataset, - file_indices, - defaults=tuple(DEFAULT_VALS), - label_name=LABEL, - batch_size=1, - num_epochs=1, + batch_size, + num_epochs, + label_name, + expected_output, + expected_keys, ): - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() + nxt = dataset.make_one_shot_iterator().get_next() - for expected_features in self._next_actual_batch(file_indices, batch_size, - num_epochs, defaults): - actual_features = sess.run(get_next) + for expected_features in self._next_expected_batch( + expected_output, + expected_keys, + batch_size, + num_epochs, + ): + actual_features = sess.run(nxt) if label_name is not None: expected_labels = expected_features.pop(label_name) - # Compare labels self.assertAllEqual(expected_labels, actual_features[1]) - actual_features = actual_features[0] # Extract features dict from tuple + actual_features = actual_features[0] for k in expected_features.keys(): # Compare features self.assertAllEqual(expected_features[k], actual_features[k]) with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testMakeCSVDataset(self): - defaults = self.DEFAULTS - - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - # Basic test: read from file 0. - dataset = self._make_csv_dataset(self._test_filenames[0], defaults) - self._verify_records(sess, dataset, [0]) - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - # Basic test: read from file 1. - dataset = self._make_csv_dataset(self._test_filenames[1], defaults) - self._verify_records(sess, dataset, [1]) - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - # Read from both files. - dataset = self._make_csv_dataset(self._test_filenames, defaults) - self._verify_records(sess, dataset, range(self._num_files)) + sess.run(nxt) + + def _test_dataset(self, + inputs, + expected_output, + expected_keys, + batch_size=1, + num_epochs=1, + label_name=None, + **kwargs): + """Checks that elements produced by CsvDataset match expected output.""" + # Convert str type because py3 tf strings are bytestrings + filenames = self._setup_files( + inputs, compression_type=kwargs.get("compression_type", None)) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: - # Read from both files. Exercise the `batch` and `num_epochs` parameters - # of make_csv_dataset and make sure they work. dataset = self._make_csv_dataset( - self._test_filenames, defaults, batch_size=2, num_epochs=10) - self._verify_records( - sess, dataset, range(self._num_files), batch_size=2, num_epochs=10) + filenames, + batch_size=batch_size, + num_epochs=num_epochs, + label_name=label_name, + **kwargs) + self._verify_output(sess, dataset, batch_size, num_epochs, label_name, + expected_output, expected_keys) + + def testMakeCSVDataset(self): + """Tests making a CSV dataset with keys and defaults provided.""" + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + column_defaults=record_defaults, + ) + + def testMakeCSVDataset_withBatchSizeAndEpochs(self): + """Tests making a CSV dataset with keys and defaults provided.""" + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=3, + num_epochs=10, + shuffle=False, + header=True, + column_defaults=record_defaults, + ) - def testMakeCSVDataset_withBadColumns(self): + def testMakeCSVDataset_withCompressionType(self): + """Tests `compression_type` argument.""" + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + label = "col0" + + for compression_type in ("GZIP", "ZLIB"): + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + column_defaults=record_defaults, + compression_type=compression_type, + ) + + def testMakeCSVDataset_withBadInputs(self): """Tests that exception is raised when input is malformed. """ - dupe_columns = self.COLUMNS[:-1] + self.COLUMNS[:1] - defaults = self.DEFAULTS + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + filenames = self._setup_files(inputs) # Duplicate column names with self.assertRaises(ValueError): self._make_csv_dataset( - self._test_filenames, defaults, column_names=dupe_columns) + filenames, + batch_size=1, + column_defaults=record_defaults, + label_name="col0", + column_names=column_names * 2) # Label key not one of column names with self.assertRaises(ValueError): self._make_csv_dataset( - self._test_filenames, defaults, label_name="not_a_real_label") + filenames, + batch_size=1, + column_defaults=record_defaults, + label_name="not_a_real_label", + column_names=column_names) def testMakeCSVDataset_withNoLabel(self): - """Tests that CSV datasets can be created when no label is specified. - """ - defaults = self.DEFAULTS - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - # Read from both files. Make sure this works with no label key supplied. - dataset = self._make_csv_dataset( - self._test_filenames, - defaults, - batch_size=2, - num_epochs=10, - label_name=None) - self._verify_records( - sess, - dataset, - range(self._num_files), - batch_size=2, - num_epochs=10, - label_name=None) + """Tests making a CSV dataset with no label provided.""" + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + column_defaults=record_defaults, + ) def testMakeCSVDataset_withNoHeader(self): """Tests that datasets can be created from CSV files with no header line. """ - defaults = self.DEFAULTS - file_without_header = self._create_file( - len(self._test_filenames), header=False) - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - file_without_header, - defaults, - batch_size=2, - num_epochs=10, - header=False, - ) - self._verify_records( - sess, - dataset, - [len(self._test_filenames)], - batch_size=2, - num_epochs=10, - ) + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [["0,1,2,3,4", "5,6,7,8,9"], ["10,11,12,13,14", "15,16,17,18,19"]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=False, + column_defaults=record_defaults, + ) def testMakeCSVDataset_withTypes(self): """Tests that defaults can be a dtype instead of a Tensor for required vals. """ - defaults = [d for d in self.COLUMN_TYPES[:-1]] - defaults.append(constant_op.constant(["NULL"], dtype=dtypes.string)) - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset(self._test_filenames, defaults) - self._verify_records(sess, dataset, range(self._num_files)) + record_defaults = [ + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, + dtypes.string + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x[0] for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], + [ + ",".join(x[0] for x in column_names), "10,11,12,13,14", + "15,16,17,18,19" + ]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + column_defaults=record_defaults, + ) def testMakeCSVDataset_withNoColNames(self): """Tests that datasets can be created when column names are not specified. In that case, we should infer the column names from the header lines. """ - defaults = self.DEFAULTS - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - # Read from both files. Exercise the `batch` and `num_epochs` parameters - # of make_csv_dataset and make sure they work. - dataset = self._make_csv_dataset( - self._test_filenames, - defaults, - column_names=None, - batch_size=2, - num_epochs=10) - self._verify_records( - sess, dataset, range(self._num_files), batch_size=2, num_epochs=10) + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + column_defaults=record_defaults, + ) def testMakeCSVDataset_withTypeInferenceMismatch(self): # Test that error is thrown when num fields doesn't match columns + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + filenames = self._setup_files(inputs) with self.assertRaises(ValueError): self._make_csv_dataset( - self._test_filenames, - column_names=self.COLUMNS + ["extra_name"], - defaults=None, + filenames, + column_names=column_names + ["extra_name"], + column_defaults=None, batch_size=2, num_epochs=10) @@ -694,197 +558,215 @@ class MakeCsvDatasetTest(test.TestCase): In that case, we should infer the types from the first N records. """ - # Test that it works with standard test files (with header, etc) - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - self._test_filenames, defaults=None, batch_size=2, num_epochs=10) - self._verify_records( - sess, - dataset, - range(self._num_files), - batch_size=2, - num_epochs=10, - defaults=[[], [], [], [], [""]]) - - def testMakeCSVDataset_withTypeInferenceTricky(self): - # Test on a deliberately tricky file (type changes as we read more rows, and - # there are null values) - fn = os.path.join(self.get_temp_dir(), "file.csv") - expected_dtypes = [ - dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float32, - dtypes.string, dtypes.string - ] - col_names = ["col%d" % i for i in range(len(expected_dtypes))] - rows = [[None, None, None, "NAN", "", - "a"], [1, 2**31 + 1, 2**64, 123, "NAN", ""], - ['"123"', 2, 2**64, 123.4, "NAN", '"cd,efg"']] - expected = [[0, 0, 0, 0, "", "a"], [1, 2**31 + 1, 2**64, 123, "", ""], - [123, 2, 2**64, 123.4, "", "cd,efg"]] - for row in expected: - row[-1] = row[-1].encode("utf-8") # py3 expects byte strings - row[-2] = row[-2].encode("utf-8") # py3 expects byte strings - self._write_file("file.csv", [col_names] + rows) + column_names = ["col%d" % i for i in range(5)] + str_int32_max = str(2**33) + inputs = [[ + ",".join(x for x in column_names), + "0,%s,2.0,3e50,rabbit" % str_int32_max + ]] + expected_output = [[0, 2**33, 2.0, 3e50, b"rabbit"]] + label = "col0" - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - fn, - defaults=None, - column_names=None, - label_name=None, - na_value="NAN", - ) - features = dataset.make_one_shot_iterator().get_next() - # Check that types match - for i in range(len(expected_dtypes)): - print(features["col%d" % i].dtype, expected_dtypes[i]) - assert features["col%d" % i].dtype == expected_dtypes[i] - for i in range(len(rows)): - assert sess.run(features) == dict(zip(col_names, expected[i])) - - def testMakeCSVDataset_withTypeInferenceAllTypes(self): - # Test that we make the correct inference for all types with fallthrough - fn = os.path.join(self.get_temp_dir(), "file.csv") - expected_dtypes = [ - dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, - dtypes.string, dtypes.string + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + ) + + def testMakeCSVDataset_withTypeInferenceFallthrough(self): + """Tests that datasets can be created when no defaults are specified. + + Tests on a deliberately tricky file. + """ + column_names = ["col%d" % i for i in range(5)] + str_int32_max = str(2**33) + inputs = [[ + ",".join(x for x in column_names), + ",,,,", + "0,0,0.0,0.0,0.0", + "0,%s,2.0,3e50,rabbit" % str_int32_max, + ",,,,", + ]] + expected_output = [[0, 0, 0, 0, b""], [0, 0, 0, 0, b"0.0"], + [0, 2**33, 2.0, 3e50, b"rabbit"], [0, 0, 0, 0, b""]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + ) + + def testMakeCSVDataset_withSelectCols(self): + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) ] - col_names = ["col%d" % i for i in range(len(expected_dtypes))] - rows = [[1, 2**31 + 1, 1.0, 4e40, "abc", ""]] - expected = [[ - 1, 2**31 + 1, 1.0, 4e40, "abc".encode("utf-8"), "".encode("utf-8") + column_names = ["col%d" % i for i in range(5)] + str_int32_max = str(2**33) + inputs = [[ + ",".join(x for x in column_names), + "0,%s,2.0,3e50,rabbit" % str_int32_max ]] - self._write_file("file.csv", [col_names] + rows) + expected_output = [[0, 2**33, 2.0, 3e50, b"rabbit"]] - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - fn, - defaults=None, - column_names=None, - label_name=None, - na_value="NAN", - ) - features = dataset.make_one_shot_iterator().get_next() - # Check that types match - for i in range(len(expected_dtypes)): - self.assertAllEqual(features["col%d" % i].dtype, expected_dtypes[i]) - for i in range(len(rows)): - self.assertAllEqual( - sess.run(features), dict(zip(col_names, expected[i]))) + select_cols = [1, 3, 4] + self._test_dataset( + inputs, + expected_output=[[x[i] for i in select_cols] for x in expected_output], + expected_keys=[column_names[i] for i in select_cols], + column_names=column_names, + column_defaults=[record_defaults[i] for i in select_cols], + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + select_columns=select_cols, + ) + + # Can still do inference without provided defaults + self._test_dataset( + inputs, + expected_output=[[x[i] for i in select_cols] for x in expected_output], + expected_keys=[column_names[i] for i in select_cols], + column_names=column_names, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + select_columns=select_cols, + ) + + # Can still do column name inference + self._test_dataset( + inputs, + expected_output=[[x[i] for i in select_cols] for x in expected_output], + expected_keys=[column_names[i] for i in select_cols], + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + select_columns=select_cols, + ) + + # Can specify column names instead of indices + self._test_dataset( + inputs, + expected_output=[[x[i] for i in select_cols] for x in expected_output], + expected_keys=[column_names[i] for i in select_cols], + column_names=column_names, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + select_columns=[column_names[i] for i in select_cols], + ) def testMakeCSVDataset_withSelectColsError(self): - data = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] - col_names = ["col%d" % i for i in range(5)] - fn = self._write_file("file.csv", [col_names] + data) + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + column_names = ["col%d" % i for i in range(5)] + str_int32_max = str(2**33) + inputs = [[ + ",".join(x for x in column_names), + "0,%s,2.0,3e50,rabbit" % str_int32_max + ]] + + select_cols = [1, 3, 4] + filenames = self._setup_files(inputs) + with self.assertRaises(ValueError): # Mismatch in number of defaults and number of columns selected, # should raise an error self._make_csv_dataset( - fn, - defaults=[[0]] * 5, - column_names=col_names, - label_name=None, - select_cols=[1, 3]) + filenames, + batch_size=1, + column_defaults=record_defaults, + column_names=column_names, + select_columns=select_cols) + with self.assertRaises(ValueError): # Invalid column name should raise an error self._make_csv_dataset( - fn, - defaults=[[0]], - column_names=col_names, + filenames, + batch_size=1, + column_defaults=[[0]], + column_names=column_names, label_name=None, - select_cols=["invalid_col_name"]) - - def testMakeCSVDataset_withSelectCols(self): - data = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] - col_names = ["col%d" % i for i in range(5)] - fn = self._write_file("file.csv", [col_names] + data) - # If select_cols is specified, should only yield a subset of columns - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - fn, - defaults=[[0], [0]], - column_names=col_names, - label_name=None, - select_cols=[1, 3]) - expected = [[1, 3], [6, 8]] - features = dataset.make_one_shot_iterator().get_next() - for i in range(len(data)): - self.assertAllEqual( - sess.run(features), - dict(zip([col_names[1], col_names[3]], expected[i]))) - # Can still do default inference with select_cols - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - fn, - defaults=None, - column_names=col_names, - label_name=None, - select_cols=[1, 3]) - expected = [[1, 3], [6, 8]] - features = dataset.make_one_shot_iterator().get_next() - for i in range(len(data)): - self.assertAllEqual( - sess.run(features), - dict(zip([col_names[1], col_names[3]], expected[i]))) - # Can still do column name inference - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - fn, - defaults=None, - column_names=None, - label_name=None, - select_cols=[1, 3]) - expected = [[1, 3], [6, 8]] - features = dataset.make_one_shot_iterator().get_next() - for i in range(len(data)): - self.assertAllEqual( - sess.run(features), - dict(zip([col_names[1], col_names[3]], expected[i]))) - # Can specify column names instead of indices - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - fn, - defaults=None, - column_names=None, - label_name=None, - select_cols=[col_names[1], col_names[3]]) - expected = [[1, 3], [6, 8]] - features = dataset.make_one_shot_iterator().get_next() - for i in range(len(data)): - self.assertAllEqual( - sess.run(features), - dict(zip([col_names[1], col_names[3]], expected[i]))) + select_columns=["invalid_col_name"]) def testMakeCSVDataset_withShuffle(self): - total_records = self._num_files * self._num_records - defaults = self.DEFAULTS + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + def str_series(st): + return ",".join(str(i) for i in range(st, st + 5)) + + column_names = ["col%d" % i for i in range(5)] + inputs = [ + [",".join(x for x in column_names) + ] + [str_series(5 * i) for i in range(15)], + [",".join(x for x in column_names)] + + [str_series(5 * i) for i in range(15, 20)], + ] + + filenames = self._setup_files(inputs) + + total_records = 20 for batch_size in [1, 2]: with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Test that shuffling with the same seed produces the same result dataset1 = self._make_csv_dataset( - self._test_filenames, - defaults, + filenames, + column_defaults=record_defaults, + column_names=column_names, batch_size=batch_size, + header=True, shuffle=True, - shuffle_seed=5) + shuffle_seed=5, + num_epochs=2, + ) dataset2 = self._make_csv_dataset( - self._test_filenames, - defaults, + filenames, + column_defaults=record_defaults, + column_names=column_names, batch_size=batch_size, + header=True, shuffle=True, - shuffle_seed=5) + shuffle_seed=5, + num_epochs=2, + ) outputs1 = dataset1.make_one_shot_iterator().get_next() outputs2 = dataset2.make_one_shot_iterator().get_next() for _ in range(total_records // batch_size): - batch1 = self._run_actual_batch(outputs1, sess) - batch2 = self._run_actual_batch(outputs2, sess) + batch1 = nest.flatten(sess.run(outputs1)) + batch2 = nest.flatten(sess.run(outputs2)) for i in range(len(batch1)): self.assertAllEqual(batch1[i], batch2[i]) @@ -892,29 +774,70 @@ class MakeCsvDatasetTest(test.TestCase): with self.test_session(graph=g) as sess: # Test that shuffling with a different seed produces different results dataset1 = self._make_csv_dataset( - self._test_filenames, - defaults, + filenames, + column_defaults=record_defaults, + column_names=column_names, batch_size=batch_size, + header=True, shuffle=True, - shuffle_seed=5) + shuffle_seed=5, + num_epochs=2, + ) dataset2 = self._make_csv_dataset( - self._test_filenames, - defaults, + filenames, + column_defaults=record_defaults, + column_names=column_names, batch_size=batch_size, + header=True, shuffle=True, - shuffle_seed=6) + shuffle_seed=6, + num_epochs=2, + ) outputs1 = dataset1.make_one_shot_iterator().get_next() outputs2 = dataset2.make_one_shot_iterator().get_next() all_equal = False for _ in range(total_records // batch_size): - batch1 = self._run_actual_batch(outputs1, sess) - batch2 = self._run_actual_batch(outputs2, sess) + batch1 = nest.flatten(sess.run(outputs1)) + batch2 = nest.flatten(sess.run(outputs2)) for i in range(len(batch1)): all_equal = all_equal and np.array_equal(batch1[i], batch2[i]) self.assertFalse(all_equal) - -class MakeTFRecordDatasetTest(TFRecordDatasetTestBase): + def testIndefiniteRepeatShapeInference(self): + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + filenames = self._setup_files(inputs) + dataset = self._make_csv_dataset(filenames, batch_size=32, num_epochs=None) + for shape in nest.flatten(dataset.output_shapes): + self.assertEqual(32, shape[0]) + + +class MakeTFRecordDatasetTest( + reader_dataset_ops_test_base.TFRecordDatasetTestBase): + + def _interleave(self, iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 def _next_expected_batch(self, file_indices, @@ -930,8 +853,8 @@ class MakeTFRecordDatasetTest(TFRecordDatasetTestBase): yield j, i def _next_record_interleaved(file_indices, cycle_length): - return _interleave([_next_record([i]) for i in file_indices], - cycle_length) + return self._interleave([_next_record([i]) for i in file_indices], + cycle_length) record_batch = [] batch_index = 0 @@ -1097,6 +1020,12 @@ class MakeTFRecordDatasetTest(TFRecordDatasetTestBase): self._shuffle_test(batch_size, num_epochs, num_parallel_reads, seed=21345) + def testIndefiniteRepeatShapeInference(self): + dataset = readers.make_tf_record_dataset( + file_pattern=self.test_filenames, num_epochs=None, batch_size=32) + for shape in nest.flatten(dataset.output_shapes): + self.assertEqual(32, shape[0]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py index 805a7c7b7384d53cc166a48ba243502ef8643280..e63bc4c72049c61aa40314ffebe5c4366a818d46 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py @@ -12,24 +12,57 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Base class for testing reader datasets.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gzip import os +import zlib from tensorflow.contrib.data.python.ops import readers from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.lib.io import python_io +from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test from tensorflow.python.util import compat +class FixedLengthRecordDatasetTestBase(test.TestCase): + """Base class for setting up and testing FixedLengthRecordDataset.""" + + def setUp(self): + super(FixedLengthRecordDatasetTestBase, self).setUp() + self._num_files = 2 + self._num_records = 7 + self._header_bytes = 5 + self._record_bytes = 3 + self._footer_bytes = 2 + + def _record(self, f, r): + return compat.as_bytes(str(f * 2 + r) * self._record_bytes) + + def _createFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) + filenames.append(fn) + with open(fn, "wb") as f: + f.write(b"H" * self._header_bytes) + for j in range(self._num_records): + f.write(self._record(i, j)) + f.write(b"F" * self._footer_bytes) + return filenames + + class ReadBatchFeaturesTestBase(test.TestCase): """Base class for setting up and testing `make_batched_feature_dataset`.""" @@ -216,3 +249,83 @@ class ReadBatchFeaturesTestBase(test.TestCase): actual_batch = self._next_actual_batch(sess) for i in range(len(expected_batch)): self.assertAllEqual(expected_batch[i], actual_batch[i]) + + +class TextLineDatasetTestBase(test.TestCase): + """Base class for setting up and testing TextLineDataset.""" + + def _lineText(self, f, l): + return compat.as_bytes("%d: %d" % (f, l)) + + def _createFiles(self, + num_files, + num_lines, + crlf=False, + compression_type=None): + filenames = [] + for i in range(num_files): + fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i) + filenames.append(fn) + contents = [] + for j in range(num_lines): + contents.append(self._lineText(i, j)) + # Always include a newline after the record unless it is + # at the end of the file, in which case we include it + if j + 1 != num_lines or i == 0: + contents.append(b"\r\n" if crlf else b"\n") + contents = b"".join(contents) + + if not compression_type: + with open(fn, "wb") as f: + f.write(contents) + elif compression_type == "GZIP": + with gzip.GzipFile(fn, "wb") as f: + f.write(contents) + elif compression_type == "ZLIB": + contents = zlib.compress(contents) + with open(fn, "wb") as f: + f.write(contents) + else: + raise ValueError("Unsupported compression_type", compression_type) + + return filenames + + +class TFRecordDatasetTestBase(test.TestCase): + """Base class for setting up and testing TFRecordDataset.""" + + def setUp(self): + super(TFRecordDatasetTestBase, self).setUp() + self._num_files = 2 + self._num_records = 7 + + self.test_filenames = self._createFiles() + + self.filenames = array_ops.placeholder(dtypes.string, shape=[None]) + self.num_epochs = array_ops.placeholder_with_default( + constant_op.constant(1, dtypes.int64), shape=[]) + self.compression_type = array_ops.placeholder_with_default("", shape=[]) + self.batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = core_readers.TFRecordDataset( + self.filenames, self.compression_type).repeat(self.num_epochs) + batch_dataset = repeat_dataset.batch(self.batch_size) + + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) + self.init_op = iterator.make_initializer(repeat_dataset) + self.init_batch_op = iterator.make_initializer(batch_dataset) + self.get_next = iterator.get_next() + + def _record(self, f, r): + return compat.as_bytes("Record %d of file %d" % (r, f)) + + def _createFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) + filenames.append(fn) + writer = python_io.TFRecordWriter(fn) + for j in range(self._num_records): + writer.write(self._record(i, j)) + writer.close() + return filenames diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index bdc003a8a5bd646e1d5c598befa2694da512d0a9..c5cfddb72b56a1bcffc80c0dd34994def3ee45cd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -17,10 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin import time + from absl.testing import parameterized +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.data.python.ops import resampling from tensorflow.python.data.ops import dataset_ops diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index eb2ceff893543f710d4f0246adf4e6367a2deeb0..42cada0b97bcd9ab755297e8b1f0667766f7999e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -21,7 +21,6 @@ import itertools import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import scan_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context @@ -64,7 +63,7 @@ class ScanDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFibonacci(self): iterator = dataset_ops.Dataset.from_tensors(1).repeat(None).apply( scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])) @@ -168,18 +167,5 @@ class ScanDatasetTest(test.TestCase): scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) -class ScanDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, num_elements): - return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply( - scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))) - - def testScanCore(self): - num_output = 5 - self.run_core_tests(lambda: self._build_dataset(num_output), - lambda: self._build_dataset(2), num_output) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..7b9ea191a4524891d1b589e1e228e29241fda7f8 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD @@ -0,0 +1,541 @@ +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "dataset_serialization_test_base", + srcs = [ + "dataset_serialization_test_base.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:lookup_ops", + "//tensorflow/python:platform", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:iterator_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "batch_dataset_serialization_test", + size = "medium", + srcs = ["batch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "cache_dataset_serialization_test", + size = "small", + srcs = ["cache_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "concatenate_dataset_serialization_test", + size = "small", + srcs = ["concatenate_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "csv_dataset_serialization_test", + size = "small", + srcs = ["csv_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + ], +) + +py_test( + name = "dataset_constructor_serialization_test", + size = "medium", + srcs = ["dataset_constructor_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "filter_dataset_serialization_test", + size = "medium", + srcs = ["filter_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "fixed_length_record_dataset_serialization_test", + size = "medium", + srcs = ["fixed_length_record_dataset_serialization_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:readers", + ], +) + +py_test( + name = "flat_map_dataset_serialization_test", + size = "medium", + srcs = ["flat_map_dataset_serialization_test.py"], + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "group_by_reducer_serialization_test", + size = "medium", + srcs = ["group_by_reducer_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:grouping", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "group_by_window_serialization_test", + size = "medium", + srcs = ["group_by_window_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:grouping", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "ignore_errors_serialization_test", + size = "small", + srcs = ["ignore_errors_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:error_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "interleave_dataset_serialization_test", + size = "medium", + srcs = ["interleave_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "map_and_batch_dataset_serialization_test", + size = "medium", + srcs = ["map_and_batch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:client_testlib", + "//tensorflow/python:math_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "map_dataset_serialization_test", + size = "medium", + srcs = ["map_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "optimize_dataset_serialization_test", + size = "small", + srcs = ["optimize_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "padded_batch_dataset_serialization_test", + size = "medium", + srcs = ["padded_batch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:string_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "parallel_interleave_dataset_serialization_test", + size = "medium", + srcs = ["parallel_interleave_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "parallel_map_dataset_serialization_test", + size = "medium", + srcs = ["parallel_map_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "prefetch_dataset_serialization_test", + size = "small", + srcs = ["prefetch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "range_dataset_serialization_test", + size = "small", + srcs = ["range_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "sample_from_datasets_serialization_test", + size = "medium", + srcs = ["sample_from_datasets_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "scan_dataset_serialization_test", + size = "small", + srcs = ["scan_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:scan_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "sequence_dataset_serialization_test", + size = "medium", + srcs = ["sequence_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "serialization_integration_test", + size = "small", + srcs = ["serialization_integration_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "shuffle_and_repeat_dataset_serialization_test", + size = "medium", + srcs = ["shuffle_and_repeat_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:shuffle_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "shuffle_dataset_serialization_test", + size = "medium", + srcs = ["shuffle_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "sql_dataset_serialization_test", + size = "small", + srcs = ["sql_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:sql_dataset_op_test_base", + "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + ], +) + +py_test( + name = "stats_dataset_serialization_test", + size = "medium", + srcs = ["stats_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:stats_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "textline_dataset_serialization_test", + size = "medium", + srcs = ["textline_dataset_serialization_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:readers", + ], +) + +py_test( + name = "tf_record_dataset_serialization_test", + size = "medium", + srcs = ["tf_record_dataset_serialization_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:readers", + ], +) + +py_test( + name = "unbatch_dataset_serialization_test", + size = "medium", + srcs = ["unbatch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "unique_dataset_serialization_test", + size = "small", + srcs = ["unique_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:unique", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "zip_dataset_serialization_test", + size = "small", + srcs = ["zip_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..af87d8b6083de268fafd4346d2871f14e0f4e7c9 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py @@ -0,0 +1,83 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the BatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class BatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): + components = ( + np.arange(tensor_slice_len), + np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(tensor_slice_len)) + + return dataset_ops.Dataset.from_tensor_slices(components).batch(batch_size) + + def testCore(self): + tensor_slice_len = 8 + batch_size = 2 + num_outputs = tensor_slice_len // batch_size + self.run_core_tests( + lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), + lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), + num_outputs) + + def _build_dataset_dense_to_sparse(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.fill([x], x)).apply( + batching.dense_to_sparse_batch(4, [12])) + + def testDenseToSparseBatchDatasetCore(self): + components = np.random.randint(5, size=(40,)).astype(np.int32) + diff_comp = np.random.randint(2, size=(100,)).astype(np.int32) + + num_outputs = len(components) // 4 + self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components), + lambda: self._build_dataset_dense_to_sparse(diff_comp), + num_outputs) + + def _sparse(self, i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + def _build_dataset_sparse(self, batch_size=5): + return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size) + + def testSparseCore(self): + self.run_core_tests(self._build_dataset_sparse, + lambda: self._build_dataset_sparse(2), 2) + + def _build_dataset_nested_sparse(self): + return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2) + + def testNestedSparseCore(self): + self.run_core_tests(self._build_dataset_nested_sparse, None, 1) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py similarity index 50% rename from tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py index f08216a303e2d7dee155ccadcdb9f42f1b24ea0f..1b6059ccbcc81937696e1b0ebb269f213adbb976 100644 --- a/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py @@ -12,21 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental features of CacheDataset.""" +"""Tests for the CacheDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import os -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from absl.testing import parameterized + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.platform import test -class CacheToFileDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): +class CacheDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase, + parameterized.TestCase): def setUp(self): self.range_size = 10 @@ -34,88 +37,123 @@ class CacheToFileDatasetSerializationTest( self.num_outputs = self.range_size * self.num_repeats self.cache_file_prefix = 'test' - def ds_fn(self): - return dataset_ops.Dataset.range(self.range_size).cache( - os.path.join(self.get_temp_dir(), - self.cache_file_prefix)).repeat(self.num_repeats) + def make_dataset_fn(self, is_memory): + if is_memory: + filename = '' + else: + filename = os.path.join(self.get_temp_dir(), self.cache_file_prefix) + + def ds_fn(): + return dataset_ops.Dataset.range(self.range_size).cache(filename).repeat( + self.num_repeats) + + return ds_fn def expected_outputs(self): return list(range(self.range_size)) * self.num_repeats - def testCheckpointBeforeOneEpoch(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointBeforeOneEpoch(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Generate 5 entries from iterator and save checkpoint. - outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False) self.assertSequenceEqual(outputs, range(5)) # Restore from checkpoint and produce the rest of the elements from the # iterator. outputs.extend( self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 5, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, self.expected_outputs()) - def testCheckpointBeforeOneEpochThenRunFewSteps(self): - # Generate 8 entries from iterator but save checkpoint after producing - # 5. + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointBeforeOneEpochThenRunFewSteps(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Generate 8 entries from iterator but save checkpoint after producing 5. outputs = self.gen_outputs( - self.ds_fn, [5], - 8, - verify_exhausted=False, - save_checkpoint_at_end=False) + ds_fn, [5], 8, verify_exhausted=False, save_checkpoint_at_end=False) self.assertSequenceEqual(outputs, range(8)) - # Restoring from checkpoint and running GetNext should return a - # `AlreadExistsError` now because the lockfile already exists. - with self.assertRaises(errors.AlreadyExistsError): - self.gen_outputs( - self.ds_fn, [], - self.num_outputs - 5, - ckpt_saved=True, - verify_exhausted=False) + if is_memory: + outputs = outputs[:5] + outputs.extend( + self.gen_outputs( + ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, self.expected_outputs()) + else: + # Restoring from checkpoint and running GetNext should return + # `AlreadExistsError` now because the lockfile already exists. + with self.assertRaises(errors.AlreadyExistsError): + self.gen_outputs( + ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False) + + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointAfterOneEpoch(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) - def testCheckpointAfterOneEpoch(self): # Generate 15 entries from iterator and save checkpoint. - outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) # Restore from checkpoint and produce the rest of the elements from the # iterator. outputs.extend( self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 15, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, self.expected_outputs()) - def testCheckpointAfterOneEpochThenRunFewSteps(self): - # Generate 18 entries from iterator but save checkpoint after producing - # 15. + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointAfterOneEpochThenRunFewSteps(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Generate 18 entries from iterator but save checkpoint after producing 15. outputs = self.gen_outputs( - self.ds_fn, [15], - 18, - verify_exhausted=False, - save_checkpoint_at_end=False) + ds_fn, [15], 18, verify_exhausted=False, save_checkpoint_at_end=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(8))) outputs = list(range(10)) + list(range(5)) + self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 15, ckpt_saved=True, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testCheckpointBeforeOneEpochButRunCompleteEpoch(self): - # Generate 13 entries from iterator but save checkpoint after producing - # 5. + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointBeforeOneEpochButRunCompleteEpoch(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Generate 13 entries from iterator but save checkpoint after producing 5. outputs = self.gen_outputs( - self.ds_fn, [5], - 13, - verify_exhausted=False, - save_checkpoint_at_end=False) + ds_fn, [5], 13, verify_exhausted=False, save_checkpoint_at_end=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(3))) # Since we ran for more than one epoch, the cache was completely written. @@ -124,65 +162,90 @@ class CacheToFileDatasetSerializationTest( # been completely written. outputs = list(range(5)) + self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 5, ckpt_saved=True, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testCheckpointUnusedWriterIterator(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointUnusedWriterIterator(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Checkpoint before get_next is called even once. - outputs = self.gen_outputs(self.ds_fn, [], 0, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 0, verify_exhausted=False) self.assertSequenceEqual(outputs, []) outputs = self.gen_outputs( - self.ds_fn, [], - self.num_outputs, - ckpt_saved=True, - verify_exhausted=False) + ds_fn, [], self.num_outputs, ckpt_saved=True, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testCheckpointUnusedMidwayWriterIterator(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointUnusedMidwayWriterIterator(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Produce 5 elements and checkpoint. - outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False) self.assertSequenceEqual(outputs, range(5)) # Restore from checkpoint, then produce no elements and checkpoint. outputs.extend( - self.gen_outputs( - self.ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False)) + self.gen_outputs(ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, range(5)) # Restore from checkpoint and produce rest of the elements. outputs.extend( self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 5, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testUnusedCheckpointError(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testUnusedCheckpointError(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Produce 5 elements and save ckpt. - outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False) self.assertSequenceEqual(outputs, range(5)) - # Since the complete cache has not been written, a new iterator which does - # not restore the checkpoint will throw an error since there is a partial - # cache shard. - with self.assertRaises(errors.AlreadyExistsError): + if is_memory: outputs = self.gen_outputs( - self.ds_fn, [], self.num_outputs, verify_exhausted=False) + ds_fn, [], self.num_outputs, verify_exhausted=False) + self.assertSequenceEqual(outputs, self.expected_outputs()) + else: + # Since the complete cache has not been written, a new iterator which does + # not restore the checkpoint will throw an error since there is a partial + # cache shard. + with self.assertRaises(errors.AlreadyExistsError): + outputs = self.gen_outputs( + ds_fn, [], self.num_outputs, verify_exhausted=False) + + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testIgnoreCheckpointIfCacheWritten(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) - def testIgnoreCheckpointIfCacheWritten(self): # Produce 15 elements and save ckpt. This will write the complete cache. - outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) # Build the iterator again but do not restore from ckpt. Since the cache # has already been written we should be able to use it. outputs = self.gen_outputs( - self.ds_fn, [], self.num_outputs, verify_exhausted=False) + ds_fn, [], self.num_outputs, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) diff --git a/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py similarity index 92% rename from tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py index 17f2980157ddd0350dafd1d745cbb9b64e65f7c5..96f13d75a31b6762b35062e6cf8c0cdb4d61d2c5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the ConcatenateDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..247f2046ea313f97bdbda1674765f12406258509 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py @@ -0,0 +1,73 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the CsvDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.platform import test + + +class CsvDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self._num_cols = 7 + self._num_rows = 10 + self._num_epochs = 14 + self._num_outputs = self._num_rows * self._num_epochs + + inputs = [ + ",".join(str(self._num_cols * j + i) + for i in range(self._num_cols)) + for j in range(self._num_rows) + ] + contents = "\n".join(inputs).encode("utf-8") + + self._filename = os.path.join(self.get_temp_dir(), "file.csv") + self._compressed = os.path.join(self.get_temp_dir(), + "comp.csv") # GZip compressed + + with open(self._filename, "wb") as f: + f.write(contents) + with gzip.GzipFile(self._compressed, "wb") as f: + f.write(contents) + + def ds_func(self, **kwargs): + compression_type = kwargs.get("compression_type", None) + if compression_type == "GZIP": + filename = self._compressed + elif compression_type is None: + filename = self._filename + else: + raise ValueError("Invalid compression type:", compression_type) + + return readers.CsvDataset(filename, **kwargs).repeat(self._num_epochs) + + def testSerializationCore(self): + defs = [[0]] * self._num_cols + self.run_core_tests( + lambda: self.ds_func(record_defaults=defs, buffer_size=2), + lambda: self.ds_func(record_defaults=defs, buffer_size=12), + self._num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2139b5c33db69a7ffbdebee74e5824928004b407 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py @@ -0,0 +1,95 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the dataset constructors serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.platform import test + + +class FromTensorsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_tensor_dataset(self, variable_array): + components = (variable_array, np.array([1, 2, 3]), np.array(37.0)) + + return dataset_ops.Dataset.from_tensors(components) + + def testFromTensorsCore(self): + # Equal length components + arr = np.array(1) + num_outputs = 1 + diff_arr = np.array(2) + self.run_core_tests(lambda: self._build_tensor_dataset(arr), + lambda: self._build_tensor_dataset(diff_arr), + num_outputs) + + +class FromTensorSlicesSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_tensor_slices_dataset(self, components): + return dataset_ops.Dataset.from_tensor_slices(components) + + def testFromTensorSlicesCore(self): + # Equal length components + components = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[12], [13], [14], [15]]), 22), + np.array([37.0, 38.0, 39.0, 40.0])) + + diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[5], [6], [7], [8]]), 22), + np.array([1.0, 2.0, 3.0, 4.0])) + + dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} + + self.run_core_tests(lambda: self._build_tensor_slices_dataset(components), + lambda: self._build_tensor_slices_dataset(diff_comp), 4) + self.run_core_tests( + lambda: self._build_tensor_slices_dataset(dict_components), None, 3) + + +class FromSparseTensorSlicesSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_sparse_tensor_slice_dataset(self, slices): + indices = np.array( + [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))], + dtype=np.int64) + values = np.array([val for s in slices for val in s], dtype=np.float64) + dense_shape = np.array( + [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64) + sparse_components = sparse_tensor.SparseTensor(indices, values, dense_shape) + return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components) + + def testFromSparseTensorSlicesCore(self): + slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []] + diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []] + + self.run_core_tests( + lambda: self._build_sparse_tensor_slice_dataset(slices), + lambda: self._build_sparse_tensor_slice_dataset(diff_slices), + 9, + sparse_tensors=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py similarity index 99% rename from tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py index 393f08850b1865180a8b94e9209b2445b54c8b69..3ed4dfb7295ca77c78ce5318bf31e16a354e16a8 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import nest @@ -655,7 +656,7 @@ class DatasetSerializationTestBase(test.TestCase): return os.path.join(self.get_temp_dir(), "iterator") def _latest_ckpt(self): - return saver_lib.latest_checkpoint(self.get_temp_dir()) + return checkpoint_management.latest_checkpoint(self.get_temp_dir()) def _save(self, sess, saver): saver.save(sess, self._ckpt_path()) diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py similarity index 91% rename from tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py index b572d6ed770fc0fe0f852359baf343c55966eddd..7c170078a11aadce9e5730437e4c25209bd58edb 100644 --- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the FilterDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops @@ -35,7 +35,7 @@ class FilterDatasetSerializationTest( def testFilterCore(self): div = 3 - num_outputs = np.sum([x % 3 is not 2 for x in range(100)]) + num_outputs = np.sum([x % 3 != 2 for x in range(100)]) self.run_core_tests(lambda: self._build_filter_range_graph(div), lambda: self._build_filter_range_graph(div * 2), num_outputs) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..34392d88d4505175c4562e23d5f0c4116e00b022 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py @@ -0,0 +1,45 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the FixedLengthRecordDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.platform import test + + +class FixedLengthRecordDatasetSerializationTest( + reader_dataset_ops_test_base.FixedLengthRecordDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, num_epochs, compression_type=None): + filenames = self._createFiles() + return core_readers.FixedLengthRecordDataset( + filenames, self._record_bytes, self._header_bytes, + self._footer_bytes).repeat(num_epochs) + + def testFixedLengthRecordCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), + lambda: self._build_iterator_graph(num_epochs * 2), + num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py similarity index 96% rename from tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py index f3feecef32e587045be25056815315136a883ca7..16051ffd3fd1e1e7ff419f28109df7bc1f165257 100644 --- a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the FlatMapDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..571e0899bbc1f856d66f85c4f6f3ac78aa0b1368 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py @@ -0,0 +1,61 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the GroupByReducer serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class GroupByReducerSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, components): + reducer = grouping.Reducer( + init_func=lambda _: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + return dataset_ops.Dataset.from_tensor_slices(components).apply( + grouping.group_by_reducer(lambda x: x % 5, reducer)) + + def testCoreGroupByReducer(self): + components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64) + self.verify_unused_iterator( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_init_before_restore( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_multiple_breaks( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_reset_restored_iterator( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_restore_in_empty_graph( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64) + self.verify_restore_in_modified_graph( + lambda: self._build_dataset(components), + lambda: self._build_dataset(diff_components), + 5, + verify_exhausted=True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f86af4084ef61c2f20dbe2fb388a20287676f39d --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py @@ -0,0 +1,57 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the GroupByWindow serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class GroupByWindowSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply( + grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4)) + + def testCoreGroupByWindow(self): + components = np.array( + [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64) + self.verify_unused_iterator( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_init_before_restore( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_multiple_breaks( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_reset_restored_iterator( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_restore_in_empty_graph( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64) + self.verify_restore_in_modified_graph( + lambda: self._build_dataset(components), + lambda: self._build_dataset(diff_components), + 12, + verify_exhausted=False) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..65ae9923b8f64dddcd54afc53e2fa67bc770fc2a --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py @@ -0,0 +1,46 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the IgnoreErrors input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class IgnoreErrorsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.check_numerics(x, "message")).apply( + error_ops.ignore_errors()) + + def testIgnoreErrorsCore(self): + components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) + diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32) + num_outputs = 4 + self.run_core_tests(lambda: self._build_ds(components), + lambda: self._build_ds(diff_components), num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ac3892fe81a1c0d325ddc5f501c2caed4b53f5d5 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py @@ -0,0 +1,86 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the InterleaveDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class InterleaveDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, input_values, cycle_length, block_length): + repeat_count = 2 + return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( + repeat_count).interleave( + lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), + cycle_length, block_length) + + def testSerializationCore(self): + input_values = np.array([4, 5, 6], dtype=np.int64) + num_outputs = np.sum(input_values) * 2 + # cycle_length > 1, block_length > 1 + cycle_length = 2 + block_length = 3 + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + lambda: self._build_iterator_graph( + input_values, cycle_length * 2, block_length * 1), + num_outputs) + # cycle_length = 1 + cycle_length = 1 + block_length = 3 + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + None, num_outputs) + # block_length = 1 + cycle_length = 2 + block_length = 1 + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + None, num_outputs) + # pylint: enable=g-long-lambda + + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_dataset(): + return dataset_ops.Dataset.range(10).map(_map_fn).interleave( + _interleave_fn, cycle_length=1) + + self.run_core_tests(_build_dataset, None, 20) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c9cd211328fa595c0ce0efe3509e8ba9dc06af80 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py @@ -0,0 +1,88 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the MapAndBatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class MapAndBatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testNumParallelBatches(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_batches = 2 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_batches=num_parallel_batches, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + + def testNumParallelCalls(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_calls = 7 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_calls=num_parallel_calls, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ab783e5cce95ed63fe64c273abb3846121c7a274 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py @@ -0,0 +1,140 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the MapDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +class MapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self._tensor_slice_len = 7 + self._num_epochs = 14 + self._num_outputs = self._tensor_slice_len * self._num_epochs + + def _build_ds(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return ( + dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(self._num_epochs)) + + def testSaveRestoreCore(self): + self.run_core_tests( + self._build_ds, + lambda: self._build_ds(multiplier=15.0), + self._num_outputs) + + def testSaveStatefulFunction(self): + + def _build_ds(): + + def _map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return dataset_ops.Dataset.range(100).map(_map_fn) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureVariableInMapFn(self): + + def _build_ds(): + counter_var = variable_scope.get_variable( + "counter", (), dtypes.int32, use_resource=True) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda _: counter_var.assign_add(1))) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureConstantInMapFn(self): + + def _build_ds(): + constant_var = constant_op.constant(5) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda x: x + constant_var)) + + self.run_core_tests(_build_ds, None, 10) + + def testCaptureDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + self.run_core_tests(_build_ds, None, num_outputs) + + def testBuildDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + + @function.Defun(dtypes.int32) + def defun_fn_deep(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) + + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + self.run_core_tests(_build_ds, None, num_outputs) + + def testSparseCore(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + def _build_ds(num_outputs): + return dataset_ops.Dataset.range(num_outputs).map(_sparse) + + num_outputs = 10 + self.run_core_tests(lambda: _build_ds(num_outputs), + lambda: _build_ds(int(num_outputs / 2)), num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d5c03495e34e73018bf9832bf77cdcf038449488 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py @@ -0,0 +1,39 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the OptimizeDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class OptimizeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testCore(self): + + def build_dataset(num_elements, batch_size): + return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch( + batch_size).apply(optimization.optimize(["map_and_batch_fusion"])) + + self.run_core_tests(lambda: build_dataset(200, 10), None, 20) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9ac42a461afcb6803a0e033892e74fb84d1e5e58 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py @@ -0,0 +1,66 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the PaddedBatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import string_ops +from tensorflow.python.platform import test + + +class PaddedBatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testPaddedBatch(self): + + def build_dataset(seq_lens): + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + lambda x: array_ops.fill([x], x)).padded_batch( + 4, padded_shapes=[-1]) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + def testPaddedBatchNonDefaultPadding(self): + + def build_dataset(seq_lens): + + def fill_tuple(x): + filled = array_ops.fill([x], x) + return (filled, string_ops.as_string(filled)) + + padded_shape = [-1] + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + fill_tuple).padded_batch( + 4, + padded_shapes=(padded_shape, padded_shape), + padding_values=(-1, "")) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1f8a584df902180aa7ab020b47ecc749912a3a3a --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py @@ -0,0 +1,101 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ParallelInterleaveDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class ParallelInterleaveDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self.input_values = np.array([4, 5, 6], dtype=np.int64) + self.num_repeats = 2 + self.num_outputs = np.sum(self.input_values) * 2 + + def _build_ds(self, cycle_length, block_length, sloppy=False): + return (dataset_ops.Dataset.from_tensor_slices( + self.input_values).repeat(self.num_repeats).apply( + interleave_ops.parallel_interleave( + lambda x: dataset_ops.Dataset.range(10 * x, 11 * x), + cycle_length, block_length, sloppy))) + + def testSerializationCore(self): + # cycle_length > 1, block_length > 1 + cycle_length = 2 + block_length = 3 + self.run_core_tests( + lambda: self._build_ds(cycle_length, block_length), + lambda: self._build_ds(cycle_length * 2, block_length * 1), + self.num_outputs) + # cycle_length = 1 + cycle_length = 1 + block_length = 3 + self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), + None, self.num_outputs) + # block_length = 1 + cycle_length = 2 + block_length = 1 + self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), + None, self.num_outputs) + + def testSerializationWithSloppy(self): + break_points = self.gen_break_points(self.num_outputs, 10) + expected_outputs = np.repeat( + np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]), + self.num_repeats).tolist() + + def run_test(cycle_length, block_length): + actual = self.gen_outputs( + lambda: self._build_ds(cycle_length, block_length, True), + break_points, self.num_outputs) + self.assertSequenceEqual(sorted(actual), expected_outputs) + + # cycle_length > 1, block_length > 1 + run_test(2, 3) + # cycle_length = 1 + run_test(1, 3) + # block_length = 1 + run_test(2, 1) + + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_dataset(): + return dataset_ops.Dataset.range(10).map(_map_fn).apply( + interleave_ops.parallel_interleave(_interleave_fn, 1)) + + self.run_core_tests(_build_dataset, None, 20) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb7605be1f230cef4cdae30aa672842a678edf7 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py @@ -0,0 +1,139 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ParallelMapDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +class ParallelMapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self._tensor_slice_len = 7 + self._num_epochs = 1 + self._num_outputs = self._tensor_slice_len * self._num_epochs + + def _build_ds(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return (dataset_ops.Dataset.from_tensor_slices(components).map( + _map_fn, num_parallel_calls=3).repeat(self._num_epochs)) + + def _build_ds_with_prefetch(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return (dataset_ops.Dataset.from_tensor_slices(components).map( + _map_fn, num_parallel_calls=3).repeat(self._num_epochs).prefetch(5)) + + def testSaveRestoreCore(self): + for ds_fn in [self._build_ds, self._build_ds_with_prefetch]: + self.run_core_tests( + ds_fn, + lambda: ds_fn(multiplier=15.0), + self._num_outputs) + + def testSaveStatefulFunction(self): + + def _build_ds(): + + def _map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return dataset_ops.Dataset.range(100).map( + _map_fn, num_parallel_calls=2).prefetch(2) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureVariableInMapFn(self): + + def _build_ds(): + counter_var = variable_scope.get_variable( + "counter", (), dtypes.int32, use_resource=True) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda _: counter_var.assign_add(1), + num_parallel_calls=2).prefetch(2)) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureConstantInMapFn(self): + + def _build_ds(): + constant_var = constant_op.constant(5) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda x: x + constant_var, num_parallel_calls=2).prefetch(2)) + + self.run_core_tests(_build_ds, None, 10) + + def testCaptureDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return dataset_ops.Dataset.range(num_outputs).map( + defun_fn, num_parallel_calls=2).prefetch(2) + + self.run_core_tests(_build_ds, None, num_outputs) + + def testBuildDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + + @function.Defun(dtypes.int32) + def defun_fn_deep(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) + + return dataset_ops.Dataset.range(num_outputs).map( + defun_fn, num_parallel_calls=2).prefetch(2) + + self.run_core_tests(_build_ds, None, num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py similarity index 90% rename from tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py index 3d120a3071ef730f21221e3291d8c84385b51aa3..c802402461216de33e7d3232ba38063c27f33557 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the PrefetchDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f5b6cf5db788ad2fd09b7e93d0ae5ebb530a11 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py @@ -0,0 +1,118 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the RangeDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class RangeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _iterator_checkpoint_prefix_local(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_prefix_local(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_prefix_local()), + dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + + def testSaveRestore(self): + + def _build_graph(start, stop): + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + # Saving and restoring in different sessions. + start = 2 + stop = 10 + break_point = 5 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, _, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Saving and restoring in same session. + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def _build_range_dataset(self, start, stop): + return dataset_ops.Dataset.range(start, stop) + + def testRangeCore(self): + start = 2 + stop = 10 + stop_1 = 8 + self.run_core_tests(lambda: self._build_range_dataset(start, stop), + lambda: self._build_range_dataset(start, stop_1), + stop - start) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fdb35ea624c22ad0a9561d774c86247119c4c837 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py @@ -0,0 +1,46 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the SampleFromDatasets serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class SampleFromDatasetsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, probs, num_samples): + dataset = interleave_ops.sample_from_datasets( + [ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(len(probs)) + ], + probs, + seed=1813) + return dataset.take(num_samples) + + def testSerializationCore(self): + self.run_core_tests( + lambda: self._build_dataset([0.5, 0.5], 100), + lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..af9ef48c0f3b92f61c097410ef4dfd787292e76a --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py @@ -0,0 +1,40 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ScanDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import scan_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class ScanDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_elements): + return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply( + scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))) + + def testScanCore(self): + num_output = 5 + self.run_core_tests(lambda: self._build_dataset(num_output), + lambda: self._build_dataset(2), num_output) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py similarity index 91% rename from tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py index d0cb203a3afd2775756c8542a1e86faedc5cee53..2afebca0f5849c640044830fff05ebff131e0875 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the sequence datasets serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test -class SequenceDatasetSerializationTest( +class SkipDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): def _build_skip_dataset(self, count): @@ -52,6 +52,10 @@ class SequenceDatasetSerializationTest( 'Shape must be rank 0 but is rank 1'): self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), None, 0) + +class TakeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + def _build_take_dataset(self, count): components = (np.arange(10),) return dataset_ops.Dataset.from_tensor_slices(components).take(count) @@ -79,6 +83,10 @@ class SequenceDatasetSerializationTest( 'Shape must be rank 0 but is rank 1'): self.run_core_tests(lambda: self._build_take_dataset([1, 2]), None, 0) + +class RepeatDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + def _build_repeat_dataset(self, count, take_count=3): components = (np.arange(10),) return dataset_ops.Dataset.from_tensor_slices(components).take( @@ -117,5 +125,5 @@ class SequenceDatasetSerializationTest( None, 0) -if __name__ == "__main__": +if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py similarity index 96% rename from tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py index 0a6b74dc3eb80a6168117beed06935737198cecb..992d996a485de94ad55305552e42c7fbc92ec64b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Integration test for input pipeline serialization.""" +"""Integration test for dataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -26,7 +26,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import saver as saver_lib -class MultipleInputPipelinesTest(test.TestCase): +class SerializationIntegrationTest(test.TestCase): def _build_input_pipeline(self, name, num_outputs): with ops.name_scope(name): diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f199ec835ef1c72e2c3f8b3b1cc4f5fe6ea0b6f4 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py @@ -0,0 +1,39 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ShuffleAndRepeatDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import shuffle_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class ShuffleAndRepeatSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, seed): + return dataset_ops.Dataset.range(20).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed)) + + def testCore(self): + self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20), + 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d46c762aaaadc4314a10acc5aeb7ace7df5002a8 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py @@ -0,0 +1,148 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ShuffleDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib + + +class ShuffleDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_shuffle_dataset( + self, + range_limit=10, + num_repeats=5, + buffer_size=5, + seed=None, + reshuffle_each_iteration=None, + ): + return dataset_ops.Dataset.range(range_limit).shuffle( + buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats) + + def testShuffleCore(self): + + seed = 55 + range_limit = 5 + num_repeats = 2 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 5, 8, 10] + # pylint: disable=cell-var-from-loop + # pylint: disable=g-long-lambda + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + self.run_core_tests( + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration), + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=10, + reshuffle_each_iteration=reshuffle_each_iteration), + num_outputs) + # pylint: enable=cell-var-from-loop + # pylint: enable=g-long-lambda + + def testNonDeterministicSeeding(self): + + range_limit = 5 + num_repeats = 2 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 5, 8, 10] + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + + def ds_fn(): + # pylint: disable=cell-var-from-loop + return self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=None, # Iterator seeds are generated non-deterministically. + reshuffle_each_iteration=reshuffle_each_iteration) + # pylint: enable=cell-var-from-loop + + # We checkpoint the initial state of the Dataset so that we can restore + # the seeds in the next run. Since the seeding is non-deterministic + # the dataset gets initialized with different seeds each time. + expected = self.gen_outputs( + ds_fn, + break_points=[0], + num_outputs=num_outputs, + ckpt_saved=False, + verify_exhausted=False, + save_checkpoint_at_end=False) + actual = self.gen_outputs( + ds_fn, + break_points=self.gen_break_points(num_outputs), + num_outputs=num_outputs, + ckpt_saved=True, + verify_exhausted=False) + self.match(expected, actual) + + def testMultipleIterators(self): + range_limit = 5 + num_repeats = 2 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 5, 8, 10] + + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + + def ds_fn(): + # pylint: disable=cell-var-from-loop + return self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=None, # Iterator seeds are generated non-deterministically. + reshuffle_each_iteration=reshuffle_each_iteration) + # pylint: enable=cell-var-from-loop + + with ops.Graph().as_default() as g: + ds = ds_fn() + iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()] + get_next_ops = [it.get_next() for it in iterators] + saveables = [ + contrib_iterator_ops.make_saveable_from_iterator(it) + for it in iterators + ] + for saveable in saveables: + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + saver = saver_lib.Saver(allow_empty=True) + with self.test_session(graph=g) as sess: + self._save(sess, saver) + expected = [sess.run(get_next_ops) for _ in range(num_outputs)] + self._restore(saver, sess) + actual = [sess.run(get_next_ops) for _ in range(num_outputs)] + self.match(expected, actual) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..93b26ed58a065de2074906528a0f49d696a813ff --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py @@ -0,0 +1,53 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the SqlDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.kernel_tests import sql_dataset_op_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class SqlDatasetSerializationTest( + sql_dataset_op_test_base.SqlDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_repeats): + data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") + driver_name = array_ops.placeholder_with_default( + array_ops.constant("sqlite", dtypes.string), shape=[]) + query = ("SELECT first_name, last_name, motto FROM students ORDER BY " + "first_name DESC") + output_types = (dtypes.string, dtypes.string, dtypes.string) + return readers.SqlDataset(driver_name, data_source_name, query, + output_types).repeat(num_repeats) + + def testSQLSaveable(self): + num_repeats = 4 + num_outputs = num_repeats * 2 + self.run_core_tests(lambda: self._build_dataset(num_repeats), + lambda: self._build_dataset(num_repeats // 2), + num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..14cd3e9c4a72cc7832f9bb1cb49c72a8a7cb2dcd --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py @@ -0,0 +1,95 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the StatsDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import stats_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +# TODO(shivaniagrawal): Can not checkpoint input_pipeline with the +# transformation `stats_ops.set_stats_aggregator`, since we don't support +# serializing StatsAggregator yet. +class StatsDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset_bytes_stats(self, num_elements): + return dataset_ops.Dataset.range(num_elements).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( + stats_ops.bytes_produced_stats("bytes_produced")) + + def test_bytes_produced_stats_invalid_tag_shape(self): + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1"): + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: dataset_ops.Dataset.range(100).apply( + stats_ops.bytes_produced_stats(["bytes_produced"])), + None, 100) + # pylint: enable=g-long-lambda + + def testBytesStatsDatasetSaveableCore(self): + num_outputs = 100 + self.run_core_tests( + lambda: self._build_dataset_bytes_stats(num_outputs), + lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs) + + def _build_dataset_latency_stats(self, num_elements, tag="record_latency"): + return dataset_ops.Dataset.range(num_elements).apply( + stats_ops.latency_stats(tag)) + + def _build_dataset_multiple_tags(self, + num_elements, + tag1="record_latency", + tag2="record_latency_2"): + return dataset_ops.Dataset.range(num_elements).apply( + stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2)) + + def test_latency_stats_invalid_tag_shape(self): + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1"): + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats(["record_latency", "record_latency_2"])), + None, 100) + # pylint: enable=g-long-lambda + + def testLatencyStatsDatasetSaveableCore(self): + num_outputs = 100 + + self.run_core_tests( + lambda: self._build_dataset_latency_stats(num_outputs), + lambda: self._build_dataset_latency_stats(num_outputs // 10), + num_outputs) + + self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs), + None, num_outputs) + + tag1 = "record_latency" + tag2 = "record_latency" + self.run_core_tests( + lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2), + None, num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2483787f44f913199e3f2aa46d181d609a4a9a8f --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the TextLineDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.platform import test + + +class TextLineDatasetSerializationTest( + reader_dataset_ops_test_base.TextLineDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, test_filenames, compression_type=None): + return core_readers.TextLineDataset( + test_filenames, compression_type=compression_type, buffer_size=10) + + def testTextLineCore(self): + compression_types = [None, "GZIP", "ZLIB"] + num_files = 5 + lines_per_file = 5 + num_outputs = num_files * lines_per_file + for compression_type in compression_types: + test_filenames = self._createFiles( + num_files, + lines_per_file, + crlf=True, + compression_type=compression_type) + # pylint: disable=cell-var-from-loop + self.run_core_tests( + lambda: self._build_iterator_graph(test_filenames, compression_type), + lambda: self._build_iterator_graph(test_filenames), num_outputs) + # pylint: enable=cell-var-from-loop + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..55a6257a274cd7f78e3818943627cfa09a185fd7 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py @@ -0,0 +1,99 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the TFRecordDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os +import zlib + +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.platform import test + + +class TFRecordDatasetSerializationTest( + reader_dataset_ops_test_base.TFRecordDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, + num_epochs, + batch_size=1, + compression_type=None, + buffer_size=None): + filenames = self._createFiles() + if compression_type == "ZLIB": + zlib_files = [] + for i, fn in enumerate(filenames): + with open(fn, "rb") as f: + cdata = zlib.compress(f.read()) + zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i) + with open(zfn, "wb") as f: + f.write(cdata) + zlib_files.append(zfn) + filenames = zlib_files + + elif compression_type == "GZIP": + gzip_files = [] + for i, fn in enumerate(self.test_filenames): + with open(fn, "rb") as f: + gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i) + with gzip.GzipFile(gzfn, "wb") as gzf: + gzf.write(f.read()) + gzip_files.append(gzfn) + filenames = gzip_files + + return core_readers.TFRecordDataset( + filenames, compression_type, + buffer_size=buffer_size).repeat(num_epochs).batch(batch_size) + + def testTFRecordWithoutBufferCore(self): + num_epochs = 5 + batch_size = num_epochs + num_outputs = num_epochs * self._num_files * self._num_records // batch_size + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, batch_size, + buffer_size=0), + lambda: self._build_iterator_graph(num_epochs * 2, batch_size), + num_outputs) + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None, + num_outputs * batch_size) + # pylint: enable=g-long-lambda + + def testTFRecordWithBufferCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), + lambda: self._build_iterator_graph(num_epochs * 2), + num_outputs) + + def testTFRecordWithCompressionCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"), + lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"), + lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a5a8a20dd7a9f891b07351570006636ca34bd0 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py @@ -0,0 +1,51 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the UnbatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class UnbatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): + components = ( + np.arange(tensor_slice_len), + np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(tensor_slice_len)) + + return dataset_ops.Dataset.from_tensor_slices(components).batch( + batch_size).apply(batching.unbatch()) + + def testCore(self): + tensor_slice_len = 8 + batch_size = 2 + num_outputs = tensor_slice_len + self.run_core_tests( + lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), + lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), + num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..22f15b88464a770207dc7c6f0387d73ea3d5c2e4 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py @@ -0,0 +1,40 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the UniqueDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import unique +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class UniqueDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testUnique(self): + + def build_dataset(num_elements, unique_elem_range): + return dataset_ops.Dataset.range(num_elements).map( + lambda x: x % unique_elem_range).apply(unique.unique()) + + self.run_core_tests(lambda: build_dataset(200, 100), + lambda: build_dataset(40, 100), 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py similarity index 92% rename from tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py index e39fa957f0bbb9d3671274d5f58b993e8399814b..340a6ff72e6813c3743d3d83a72ac12d4a392b66 100644 --- a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the ZipDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index 1b67a33f04b0c2ac80402e163005123a4b3e4400..3c11d7a97fc9a4b2b8b19a8e82ad5e9037d6bbcd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -19,144 +19,32 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import test -from tensorflow.python.training import saver as saver_lib - - -class ShuffleDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_shuffle_dataset( - self, - range_limit=10, - num_repeats=5, - buffer_size=5, - seed=None, - reshuffle_each_iteration=None, - ): - return dataset_ops.Dataset.range(range_limit).shuffle( - buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats) - - def testShuffleCore(self): - - seed = 55 - range_limit = 10 - num_repeats = 5 - num_outputs = range_limit * num_repeats - buffer_sizes = [1, 3, 8, 10, 25, 50] - # pylint: disable=cell-var-from-loop - # pylint: disable=g-long-lambda - for reshuffle_each_iteration in [True, False]: - for buffer_size in buffer_sizes: - self.run_core_tests( - lambda: self._build_shuffle_dataset( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration), - lambda: self._build_shuffle_dataset( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=10, - reshuffle_each_iteration=reshuffle_each_iteration), - num_outputs) - # pylint: enable=cell-var-from-loop - # pylint: enable=g-long-lambda - - def testNonDeterministicSeeding(self): - - range_limit = 10 - num_repeats = 5 - num_outputs = range_limit * num_repeats - buffer_sizes = [1, 3, 8, 10, 25, 50] - for reshuffle_each_iteration in [True, False]: - for buffer_size in buffer_sizes: - - def ds_fn(): - # pylint: disable=cell-var-from-loop - return self._build_shuffle_dataset( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=None, # Iterator seeds are generated non-deterministically. - reshuffle_each_iteration=reshuffle_each_iteration) - # pylint: enable=cell-var-from-loop - - # We checkpoint the initial state of the Dataset so that we can restore - # the seeds in the next run. Since the seeding is non-deterministic - # the dataset gets initialized with different seeds each time. - expected = self.gen_outputs( - ds_fn, - break_points=[0], - num_outputs=num_outputs, - ckpt_saved=False, - verify_exhausted=False, - save_checkpoint_at_end=False) - actual = self.gen_outputs( - ds_fn, - break_points=self.gen_break_points(num_outputs), - num_outputs=num_outputs, - ckpt_saved=True, - verify_exhausted=False) - self.match(expected, actual) - - def testMultipleIterators(self): - range_limit = 10 - num_repeats = 5 - num_outputs = range_limit * num_repeats - buffer_sizes = [1, 3, 8, 10, 25, 50] - - for reshuffle_each_iteration in [True, False]: - for buffer_size in buffer_sizes: - - def ds_fn(): - # pylint: disable=cell-var-from-loop - return self._build_shuffle_dataset( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=None, # Iterator seeds are generated non-deterministically. - reshuffle_each_iteration=reshuffle_each_iteration) - # pylint: enable=cell-var-from-loop - - with ops.Graph().as_default() as g: - ds = ds_fn() - iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()] - get_next_ops = [it.get_next() for it in iterators] - saveables = [ - contrib_iterator_ops.make_saveable_from_iterator(it) - for it in iterators - ] - for saveable in saveables: - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - saver = saver_lib.Saver(allow_empty=True) - with self.test_session(graph=g) as sess: - self._save(sess, saver) - expected = [sess.run(get_next_ops) for _ in range(num_outputs)] - self._restore(saver, sess) - actual = [sess.run(get_next_ops) for _ in range(num_outputs)] - self.match(expected, actual) - - -class ShuffleAndRepeatTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + + +class ShuffleAndRepeatTest(test.TestCase): def _build_ds(self, seed, count=5, num_elements=20): return dataset_ops.Dataset.range(num_elements).apply( shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed)) + def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True): + get_next = ds_fn().make_one_shot_iterator().get_next() + outputs = [] + with self.test_session() as sess: + for _ in range(num_outputs): + outputs.append(sess.run(get_next)) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + return outputs + def testCorrectOutput(self): - output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output = self._gen_outputs(lambda: self._build_ds(10), 100) self.assertSequenceEqual( sorted(output), sorted( np.array([range(20) for _ in range(5)]).flatten())) @@ -165,53 +53,53 @@ class ShuffleAndRepeatTest( def testReshuffling(self): # Check that the output orders of different epochs are indeed different. - output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output = self._gen_outputs(lambda: self._build_ds(10), 100) for i in range(4): epoch1 = output[i * 20:(i + 1) * 20] epoch2 = output[(i + 1) * 20:(i + 2) * 20] self.assertNotEqual(epoch1, epoch2) def testSameOrderForSameSeeds(self): - output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) - output2 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output1 = self._gen_outputs(lambda: self._build_ds(10), 100) + output2 = self._gen_outputs(lambda: self._build_ds(10), 100) self.assertEqual(output1, output2) def testDifferentOrderForDifferentSeeds(self): - output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) - output2 = self.gen_outputs(lambda: self._build_ds(20), [], 100) + output1 = self._gen_outputs(lambda: self._build_ds(10), 100) + output2 = self._gen_outputs(lambda: self._build_ds(20), 100) self.assertNotEqual(output1, output2) self.assertEqual(sorted(output1), sorted(output2)) def testCountNone(self): - output1 = self.gen_outputs( - lambda: self._build_ds(10, count=None), [], 100, verify_exhausted=False) - output2 = self.gen_outputs( - lambda: self._build_ds(20, count=None), [], 100, verify_exhausted=False) + output1 = self._gen_outputs( + lambda: self._build_ds(10, count=None), 100, verify_exhausted=False) + output2 = self._gen_outputs( + lambda: self._build_ds(20, count=None), 100, verify_exhausted=False) self.assertNotEqual(output1, output2) self.assertEqual(sorted(output1), sorted(output2)) def testCountMinusOne(self): - output1 = self.gen_outputs( - lambda: self._build_ds(10, count=-1), [], 100, verify_exhausted=False) - output2 = self.gen_outputs( - lambda: self._build_ds(20, count=-1), [], 100, verify_exhausted=False) + output1 = self._gen_outputs( + lambda: self._build_ds(10, count=-1), 100, verify_exhausted=False) + output2 = self._gen_outputs( + lambda: self._build_ds(20, count=-1), 100, verify_exhausted=False) self.assertNotEqual(output1, output2) self.assertEqual(sorted(output1), sorted(output2)) def testInfiniteOutputs(self): # Asserting the iterator is exhausted after producing 100 items should fail. with self.assertRaises(AssertionError): - self.gen_outputs(lambda: self._build_ds(10, count=None), [], 100) + self._gen_outputs(lambda: self._build_ds(10, count=None), 100) with self.assertRaises(AssertionError): - self.gen_outputs(lambda: self._build_ds(10, count=-1), [], 100) + self._gen_outputs(lambda: self._build_ds(10, count=-1), 100) def testInfiniteEmpty(self): with self.assertRaises(errors.OutOfRangeError): - self.gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0), - [], 100) + self._gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0), + 100) with self.assertRaises(errors.OutOfRangeError): - self.gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0), [], - 100) + self._gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0), + 100) def testLargeBufferSize(self): with ops.Graph().as_default() as g: @@ -222,17 +110,5 @@ class ShuffleAndRepeatTest( sess.run(get_next_op) -class ShuffleAndRepeatSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_ds(self, seed): - return dataset_ops.Dataset.range(20).apply( - shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed)) - - def testCore(self): - self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20), - 100) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index 33c48e20bea53b88d69a59e715af38b22dd2cbd4..8b2f84649486e35e1067f5f9cbe4a7abec71e080 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.contrib.data.python.ops import sliding @@ -29,28 +30,45 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class SlideDatasetTest(test.TestCase): - - def testSlideDataset(self): - """Test an dataset that maps a TF function across its input elements.""" +class SlideDatasetTest(test.TestCase, parameterized.TestCase): + + @parameterized.parameters( + (20, 14, 7, 1), + (20, 17, 9, 1), + (20, 14, 14, 1), + (20, 10, 14, 1), + (20, 14, 19, 1), + (20, 4, 1, 2), + (20, 2, 1, 6), + (20, 4, 7, 2), + (20, 2, 7, 6), + (1, 10, 4, 1), + (0, 10, 4, 1), + ) + def testSlideDataset(self, count, window_size, window_shift, window_stride): + """Tests a dataset that slides a window its input elements.""" components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) - count = array_ops.placeholder(dtypes.int64, shape=[]) - window_size = array_ops.placeholder(dtypes.int64, shape=[]) - stride = array_ops.placeholder(dtypes.int64, shape=[]) + count_t = array_ops.placeholder(dtypes.int64, shape=[]) + window_size_t = array_ops.placeholder(dtypes.int64, shape=[]) + window_shift_t = array_ops.placeholder(dtypes.int64, shape=[]) + window_stride_t = array_ops.placeholder(dtypes.int64, shape=[]) def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> - # RepeatDataset(count) -> _SlideDataset(window_size, stride). - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .map(_map_fn) - .repeat(count) - .apply(sliding.sliding_window_batch(window_size, stride)) - .make_initializable_iterator()) + # RepeatDataset(count) -> + # _SlideDataset(window_size, window_shift, window_stride). + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(count).apply( + sliding.sliding_window_batch( + window_size=window_size_t, + window_shift=window_shift_t, + window_stride=window_stride_t)).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -58,60 +76,126 @@ class SlideDatasetTest(test.TestCase): [t.shape.as_list() for t in get_next]) with self.test_session() as sess: - # Slide over a finite input, where the window_size divides the - # total number of elements. - sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 7}) - # Same formula with convolution layer. - num_batches = (20 * 7 - 14) // 7 + 1 + sess.run( + init_op, + feed_dict={ + count_t: count, + window_size_t: window_size, + window_shift_t: window_shift, + window_stride_t: window_stride + }) + num_batches = (count * 7 - ( + (window_size - 1) * window_stride + 1)) // window_shift + 1 for i in range(num_batches): result = sess.run(get_next) for component, result_component in zip(components, result): - for j in range(14): - self.assertAllEqual(component[(i*7 + j) % 7]**2, - result_component[j]) + for j in range(window_size): + self.assertAllEqual( + component[(i * window_shift + j * window_stride) % 7]**2, + result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - # Slide over a finite input, where the window_size does not - # divide the total number of elements. - sess.run(init_op, feed_dict={count: 20, window_size: 17, stride: 9}) + @parameterized.parameters( + (20, 14, 7, 1), + (20, 17, 9, 1), + (20, 14, 14, 1), + (20, 10, 14, 1), + (20, 14, 19, 1), + (20, 4, 1, 2), + (20, 2, 1, 6), + (20, 4, 7, 2), + (20, 2, 7, 6), + (1, 10, 4, 1), + (0, 10, 4, 1), + ) + def testSlideDatasetDeprecated(self, count, window_size, stride, + window_stride): + """Tests a dataset that slides a window its input elements.""" + components = (np.arange(7), + np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], + np.array(37.0) * np.arange(7)) - num_batches = (20 * 7 - 17) // 9 + 1 - for i in range(num_batches): - result = sess.run(get_next) - for component, result_component in zip(components, result): - for j in range(17): - self.assertAllEqual(component[(i*9 + j) % 7]**2, - result_component[j]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + count_t = array_ops.placeholder(dtypes.int64, shape=[]) + window_size_t = array_ops.placeholder(dtypes.int64, shape=[]) + stride_t = array_ops.placeholder(dtypes.int64, shape=[]) + window_stride_t = array_ops.placeholder(dtypes.int64, shape=[]) - # Slide over a finite input, which is less than window_size, - # should fail straight away. - sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 4}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) - sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 8}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> + # RepeatDataset(count) -> _SlideDataset(window_size, stride, window_stride). + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(count).apply( + sliding.sliding_window_batch( + window_size=window_size_t, + stride=stride_t, + window_stride=window_stride_t)).make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual([[None] + list(c.shape[1:]) for c in components], + [t.shape.as_list() for t in get_next]) - # Slide over an empty input should fail straight away. - sess.run(init_op, feed_dict={count: 0, window_size: 8, stride: 4}) + with self.test_session() as sess: + sess.run( + init_op, + feed_dict={ + count_t: count, + window_size_t: window_size, + stride_t: stride, + window_stride_t: window_stride + }) + num_batches = (count * 7 - ( + (window_size - 1) * window_stride + 1)) // stride + 1 + for i in range(num_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(window_size): + self.assertAllEqual( + component[(i * stride + j * window_stride) % 7]**2, + result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - # Empty window_size should be an initialization time error. - with self.assertRaises(errors.InvalidArgumentError): - sess.run(init_op, feed_dict={count: 14, window_size: 0, stride: 0}) + @parameterized.parameters( + (14, 0, 3, 1), + (14, 3, 0, 1), + (14, 3, 3, 0), + ) + def testSlideDatasetInvalid(self, count, window_size, window_shift, + window_stride): + count_t = array_ops.placeholder(dtypes.int64, shape=[]) + window_size_t = array_ops.placeholder(dtypes.int64, shape=[]) + window_shift_t = array_ops.placeholder(dtypes.int64, shape=[]) + window_stride_t = array_ops.placeholder(dtypes.int64, shape=[]) + + iterator = ( + dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count_t).apply( + sliding.sliding_window_batch( + window_size=window_size_t, + window_shift=window_shift_t, + window_stride=window_stride_t)).make_initializable_iterator()) + init_op = iterator.initializer - # Invalid stride should be an initialization time error. - with self.assertRaises(errors.InvalidArgumentError): - sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 0}) - with self.assertRaises(errors.InvalidArgumentError): - sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 3}) + with self.test_session() as sess: with self.assertRaises(errors.InvalidArgumentError): - sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 5}) + sess.run( + init_op, + feed_dict={ + count_t: count, + window_size_t: window_size, + window_shift_t: window_shift, + window_stride_t: window_stride + }) + + def testSlideDatasetValueError(self): + with self.assertRaises(ValueError): + dataset_ops.Dataset.range(10).map(lambda x: x).apply( + sliding.sliding_window_batch( + window_size=1, stride=1, window_shift=1, window_stride=1)) def assertSparseValuesEqual(self, a, b): self.assertAllEqual(a.indices, b.indices) @@ -125,7 +209,8 @@ class SlideDatasetTest(test.TestCase): indices=[[0]], values=(i * [1]), dense_shape=[1]) iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( - sliding.sliding_window_batch(5, 3)).make_initializable_iterator() + sliding.sliding_window_batch( + window_size=5, window_shift=3)).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -153,7 +238,8 @@ class SlideDatasetTest(test.TestCase): dense_shape=[i]) iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( - sliding.sliding_window_batch(5, 3)).make_initializable_iterator() + sliding.sliding_window_batch( + window_size=5, window_shift=3)).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -183,11 +269,11 @@ class SlideDatasetTest(test.TestCase): return sparse_tensor.SparseTensorValue( indices=[[0]], values=(i * [1]), dense_shape=[1]) - iterator = (dataset_ops.Dataset.range(10) - .map(_sparse) - .apply(sliding.sliding_window_batch(4, 2)) - .apply(sliding.sliding_window_batch(3, 1)) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(window_size=4, window_shift=2)).apply( + sliding.sliding_window_batch(window_size=3, window_shift=1)) + .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -196,9 +282,9 @@ class SlideDatasetTest(test.TestCase): # Slide: 1st batch. actual = sess.run(get_next) expected = sparse_tensor.SparseTensorValue( - indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], - [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], - [2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]], + indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0], + [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0], + [2, 2, 0], [2, 3, 0]], values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7], dense_shape=[3, 4, 1]) self.assertTrue(sparse_tensor.is_sparse(actual)) @@ -206,9 +292,9 @@ class SlideDatasetTest(test.TestCase): # Slide: 2nd batch. actual = sess.run(get_next) expected = sparse_tensor.SparseTensorValue( - indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], - [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], - [2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]], + indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0], + [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0], + [2, 2, 0], [2, 3, 0]], values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9], dense_shape=[3, 4, 1]) self.assertTrue(sparse_tensor.is_sparse(actual)) @@ -223,10 +309,11 @@ class SlideDatasetTest(test.TestCase): yield [4.0, 5.0, 6.0] yield [7.0, 8.0, 9.0, 10.0] - iterator = (dataset_ops.Dataset.from_generator(generator, dtypes.float32, - output_shapes=[None]) - .apply(sliding.sliding_window_batch(3, 1)) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_generator( + generator, dtypes.float32, output_shapes=[None]).apply( + sliding.sliding_window_batch(window_size=3, window_shift=1)) + .make_initializable_iterator()) next_element = iterator.get_next() with self.test_session() as sess: diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index 4148addf2878c99f47ebe1454edf69ad7f38dfbc..2c2cfbebff5d3eba00f120467102b4185d81ab24 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -18,83 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - -import sqlite3 - -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import readers +from tensorflow.contrib.data.python.kernel_tests import sql_dataset_op_test_base from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class SqlDatasetTestBase(test.TestCase): - - def _createSqlDataset(self, output_types, num_repeats=1): - dataset = readers.SqlDataset(self.driver_name, self.data_source_name, - self.query, output_types).repeat(num_repeats) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - return init_op, get_next - - def setUp(self): - self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") - self.driver_name = array_ops.placeholder_with_default( - array_ops.constant("sqlite", dtypes.string), shape=[]) - self.query = array_ops.placeholder(dtypes.string, shape=[]) - - conn = sqlite3.connect(self.data_source_name) - c = conn.cursor() - c.execute("DROP TABLE IF EXISTS students") - c.execute("DROP TABLE IF EXISTS people") - c.execute("DROP TABLE IF EXISTS townspeople") - c.execute( - "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, " - "first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), " - "school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), " - "desk_number INTEGER, income INTEGER, favorite_number INTEGER, " - "favorite_big_number INTEGER, favorite_negative_number INTEGER, " - "favorite_medium_sized_number INTEGER, brownie_points INTEGER, " - "account_balance INTEGER, registration_complete INTEGER)") - c.executemany( - "INSERT INTO students (first_name, last_name, motto, school_id, " - "favorite_nonsense_word, desk_number, income, favorite_number, " - "favorite_big_number, favorite_negative_number, " - "favorite_medium_sized_number, brownie_points, account_balance, " - "registration_complete) " - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647, - 9223372036854775807, -2, 32767, 0, 0, 1), - ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000, - -2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)]) - c.execute( - "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, " - "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))") - c.executemany( - "INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)", - [("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe", - "California")]) - c.execute( - "CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY " - "KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories " - "FLOAT, accolades FLOAT, triumphs FLOAT)") - c.executemany( - "INSERT INTO townspeople (first_name, last_name, victories, " - "accolades, triumphs) VALUES (?, ?, ?, ?, ?)", - [("George", "Washington", 20.00, - 1331241.321342132321324589798264627463827647382647382643874, - 9007199254740991.0), - ("John", "Adams", -19.95, - 1331241321342132321324589798264627463827647382647382643874.0, - 9007199254740992.0)]) - conn.commit() - conn.close() - - -class SqlDatasetTest(SqlDatasetTestBase): +class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # Test that SqlDataset can read from a database table. def testReadResultSet(self): @@ -656,27 +586,5 @@ class SqlDatasetTest(SqlDatasetTestBase): sess.run(get_next) -class SqlDatasetSerializationTest( - SqlDatasetTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, num_repeats): - data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") - driver_name = array_ops.placeholder_with_default( - array_ops.constant("sqlite", dtypes.string), shape=[]) - query = ("SELECT first_name, last_name, motto FROM students ORDER BY " - "first_name DESC") - output_types = (dtypes.string, dtypes.string, dtypes.string) - return readers.SqlDataset(driver_name, data_source_name, query, - output_types).repeat(num_repeats) - - def testSQLSaveable(self): - num_repeats = 4 - num_outputs = num_repeats * 2 - self.run_core_tests(lambda: self._build_dataset(num_repeats), - lambda: self._build_dataset(num_repeats // 2), - num_outputs) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..1f5c725a9269e80311f3e73c51c28ab80e7c4815 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py @@ -0,0 +1,96 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for testing SqlDataset.""" + + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import sqlite3 + +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class SqlDatasetTestBase(test.TestCase): + """Base class for setting up and testing SqlDataset.""" + + def _createSqlDataset(self, output_types, num_repeats=1): + dataset = readers.SqlDataset(self.driver_name, self.data_source_name, + self.query, output_types).repeat(num_repeats) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + return init_op, get_next + + def setUp(self): + self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") + self.driver_name = array_ops.placeholder_with_default( + array_ops.constant("sqlite", dtypes.string), shape=[]) + self.query = array_ops.placeholder(dtypes.string, shape=[]) + + conn = sqlite3.connect(self.data_source_name) + c = conn.cursor() + c.execute("DROP TABLE IF EXISTS students") + c.execute("DROP TABLE IF EXISTS people") + c.execute("DROP TABLE IF EXISTS townspeople") + c.execute( + "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, " + "first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), " + "school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), " + "desk_number INTEGER, income INTEGER, favorite_number INTEGER, " + "favorite_big_number INTEGER, favorite_negative_number INTEGER, " + "favorite_medium_sized_number INTEGER, brownie_points INTEGER, " + "account_balance INTEGER, registration_complete INTEGER)") + c.executemany( + "INSERT INTO students (first_name, last_name, motto, school_id, " + "favorite_nonsense_word, desk_number, income, favorite_number, " + "favorite_big_number, favorite_negative_number, " + "favorite_medium_sized_number, brownie_points, account_balance, " + "registration_complete) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647, + 9223372036854775807, -2, 32767, 0, 0, 1), + ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000, + -2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)]) + c.execute( + "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, " + "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))") + c.executemany( + "INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)", + [("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe", + "California")]) + c.execute( + "CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY " + "KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories " + "FLOAT, accolades FLOAT, triumphs FLOAT)") + c.executemany( + "INSERT INTO townspeople (first_name, last_name, victories, " + "accolades, triumphs) VALUES (?, ?, ?, ?, ?)", + [("George", "Washington", 20.00, + 1331241.321342132321324589798264627463827647382647382643874, + 9007199254740991.0), + ("John", "Adams", -19.95, + 1331241321342132321324589798264627463827647382647382643874.0, + 9007199254740992.0)]) + conn.commit() + conn.close() + + diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 17b6644759e53f84b23e070a71267aa15dcffe49..a41d21f8c14ed6bec7626599a5aa7f365765ce8b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -19,10 +19,9 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base from tensorflow.contrib.data.python.ops import stats_ops -from tensorflow.core.framework import summary_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -30,28 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class StatsDatasetTestBase(test.TestCase): - - def _assertSummaryHasCount(self, summary_str, tag, expected_value): - summary_proto = summary_pb2.Summary() - summary_proto.ParseFromString(summary_str) - for value in summary_proto.value: - if tag == value.tag: - self.assertEqual(expected_value, value.histo.num) - return - self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) - - def _assertSummaryHasSum(self, summary_str, tag, expected_value): - summary_proto = summary_pb2.Summary() - summary_proto.ParseFromString(summary_str) - for value in summary_proto.value: - if tag == value.tag: - self.assertEqual(expected_value, value.histo.sum) - return - self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) - - -class StatsDatasetTest(StatsDatasetTestBase): +class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): def testBytesProduced(self): stats_aggregator = stats_ops.StatsAggregator() @@ -198,7 +176,7 @@ class StatsDatasetTest(StatsDatasetTestBase): class FeatureStatsDatasetTest( - StatsDatasetTestBase, + stats_dataset_test_base.StatsDatasetTestBase, reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): def testFeaturesStats(self): @@ -236,68 +214,5 @@ class FeatureStatsDatasetTest( self._sum_keywords(1) * num_epochs + 2 * total_records) -class StatsDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset_bytes_stats(self, num_elements): - return dataset_ops.Dataset.range(num_elements).map( - lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( - stats_ops.bytes_produced_stats("bytes_produced")) - - def test_bytes_produced_stats_invalid_tag_shape(self): - with self.assertRaisesRegexp( - ValueError, 'Shape must be rank 0 but is rank 1'): - self.run_core_tests( - lambda: dataset_ops.Dataset.range(100).apply( - stats_ops.bytes_produced_stats(["bytes_produced"])), - None, 100) - - def testBytesStatsDatasetSaveableCore(self): - num_outputs = 100 - self.run_core_tests( - lambda: self._build_dataset_bytes_stats(num_outputs), - lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs) - - def _build_dataset_latency_stats(self, num_elements, tag="record_latency"): - return dataset_ops.Dataset.range(num_elements).apply( - stats_ops.latency_stats(tag)) - - def _build_dataset_multiple_tags(self, - num_elements, - tag1="record_latency", - tag2="record_latency_2"): - return dataset_ops.Dataset.range(num_elements).apply( - stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2)) - - def test_latency_stats_invalid_tag_shape(self): - with self.assertRaisesRegexp( - ValueError, 'Shape must be rank 0 but is rank 1'): - self.run_core_tests( - lambda: dataset_ops.Dataset.range(100).apply( - stats_ops.latency_stats(["record_latency", "record_latency_2"])), - None, 100) - - def testLatencyStatsDatasetSaveableCore(self): - num_outputs = 100 - - self.run_core_tests( - lambda: self._build_dataset_latency_stats(num_outputs), - lambda: self._build_dataset_latency_stats(num_outputs // 10), - num_outputs) - - self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs), - None, num_outputs) - - tag1 = "record_latency" - tag2 = "record_latency" - self.run_core_tests( - lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2), - None, num_outputs) - - -# TODO(shivaniagrawal): Can not checkpoint input_pipeline with the -# transformation `stats_ops.set_stats_aggregator`, since we don't support -# serializing StatsAggregator yet. - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..9a13acf8f0ac6690cad8847873768562da795496 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py @@ -0,0 +1,44 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for testing the input pipeline statistics gathering ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.core.framework import summary_pb2 +from tensorflow.python.platform import test + + +class StatsDatasetTestBase(test.TestCase): + """Base class for testing statistics gathered in `StatsAggregator`.""" + + def _assertSummaryHasCount(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.histo.num) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + + def _assertSummaryHasSum(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.histo.sum) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py index 9167cb3379bba5cb1ba76a96549395c45dca9e35..0486e2bce20e9dcf81dcb5ac49fe5b397e44bf0c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import threading +from absl.testing import parameterized import numpy as np from tensorflow.contrib.data.python.ops import threadpool @@ -30,9 +31,11 @@ from tensorflow.python.ops import script_ops from tensorflow.python.platform import test -class OverrideThreadpoolDatasetTest(test.TestCase): +class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase): - def testNumThreads(self): + @parameterized.parameters((1, None), (2, None), (4, None), (8, None), + (16, None), (4, -1), (4, 0), (4, 1), (4, 4)) + def testNumThreads(self, num_threads, max_intra_op_parallelism): def get_thread_id(_): # Python creates a dummy thread object to represent the current @@ -42,35 +45,35 @@ class OverrideThreadpoolDatasetTest(test.TestCase): # identifier that maps one-to-one with the underlying OS thread. return np.array(threading.current_thread().ident).astype(np.int64) - for num_threads in [1, 2, 4, 8, 16]: + dataset = ( + dataset_ops.Dataset.range(1000).map( + lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64), + num_parallel_calls=32).apply(unique.unique())) - dataset = ( - dataset_ops.Dataset.range(1000).map( - lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64), - num_parallel_calls=32).apply(unique.unique())) + dataset = threadpool.override_threadpool( + dataset, + threadpool.PrivateThreadPool( + num_threads, + max_intra_op_parallelism=max_intra_op_parallelism, + display_name="private_thread_pool_%d" % num_threads)) - dataset = threadpool.override_threadpool( - dataset, - threadpool.PrivateThreadPool( - num_threads, display_name="private_thread_pool_%d" % num_threads)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() - iterator = dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with self.test_session() as sess: - sess.run(iterator.initializer) - thread_ids = [] - try: - while True: - thread_ids.append(sess.run(next_element)) - except errors.OutOfRangeError: - pass - self.assertEqual(len(thread_ids), len(set(thread_ids))) - self.assertGreater(len(thread_ids), 0) - # NOTE(mrry): We don't control the thread pool scheduling, and - # so cannot guarantee that all of the threads in the pool will - # perform work. - self.assertLessEqual(len(thread_ids), num_threads) + with self.test_session() as sess: + sess.run(iterator.initializer) + thread_ids = [] + try: + while True: + thread_ids.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + self.assertEqual(len(thread_ids), len(set(thread_ids))) + self.assertGreater(len(thread_ids), 0) + # NOTE(mrry): We don't control the thread pool scheduling, and + # so cannot guarantee that all of the threads in the pool will + # perform work. + self.assertLessEqual(len(thread_ids), num_threads) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py index 3c436f7a0b45a13109960e87dd97ca56b10bb871..d79a842e7a5d816e2e6a52fc83acbd6b260cf64b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import unique from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes @@ -79,18 +78,5 @@ class UniqueDatasetTest(test.TestCase): ]) -class UniqueSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def testUnique(self): - - def build_dataset(num_elements, unique_elem_range): - return dataset_ops.Dataset.range(num_elements).map( - lambda x: x % unique_elem_range).apply(unique.unique()) - - self.run_core_tests(lambda: build_dataset(200, 100), - lambda: build_dataset(40, 100), 100) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..33d95d67549e1c8d1d9af578fcebbb4f939c418a --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py @@ -0,0 +1,523 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class WindowDatasetTest(test.TestCase, parameterized.TestCase): + + def _structuredDataset(self, structure, shape, dtype): + if structure is None: + return dataset_ops.Dataset.from_tensors( + array_ops.zeros(shape, dtype=dtype)) + else: + return dataset_ops.Dataset.zip( + tuple([ + self._structuredDataset(substructure, shape, dtype) + for substructure in structure + ])) + + def _structuredElement(self, structure, shape, dtype): + if structure is None: + return array_ops.zeros(shape, dtype=dtype) + else: + return tuple([ + self._structuredElement(substructure, shape, dtype) + for substructure in structure + ]) + + def _assertEqual(self, xs, ys): + self.assertEqual(type(xs), type(ys)) + if isinstance(xs, tuple) and isinstance(ys, tuple): + self.assertEqual(len(xs), len(ys)) + for x, y in zip(xs, ys): + self._assertEqual(x, y) + elif isinstance(xs, np.ndarray) and isinstance(ys, np.ndarray): + self.assertAllEqual(xs, ys) + else: + self.assertEqual(xs, ys) + + @parameterized.parameters( + (None, np.int32([]), dtypes.bool), + (None, np.int32([]), dtypes.int32), + (None, np.int32([]), dtypes.float32), + (None, np.int32([]), dtypes.string), + (None, np.int32([2]), dtypes.int32), + (None, np.int32([2, 2]), dtypes.int32), + ((None, None, None), np.int32([]), dtypes.int32), + ((None, (None, None)), np.int32([]), dtypes.int32), + ) + def testWindowDatasetFlatMap(self, structure, shape, dtype): + """Tests windowing by chaining it with flat map. + + Args: + structure: the input structure + shape: the input shape + dtype: the input data type + """ + + def fn(*args): + if len(args) == 1 and not isinstance(args[0], tuple): + return args[0] + return dataset_ops.Dataset.zip( + tuple([fn(*arg) if isinstance(arg, tuple) else arg for arg in args])) + + dataset = self._structuredDataset(structure, shape, dtype).apply( + grouping.window_dataset(5)).flat_map(fn) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + expected = sess.run(self._structuredElement(structure, shape, dtype)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (None, np.int32([]), dtypes.bool), + (None, np.int32([]), dtypes.int32), + (None, np.int32([]), dtypes.float32), + (None, np.int32([]), dtypes.string), + (None, np.int32([2]), dtypes.int32), + (None, np.int32([2, 2]), dtypes.int32), + ((None, None, None), np.int32([]), dtypes.int32), + ((None, (None, None)), np.int32([]), dtypes.int32), + ) + def testWindowDatasetBatchDense(self, structure, shape, dtype): + """Tests batching of dense tensor windows. + + Args: + structure: the input structure + shape: the input shape + dtype: the input data type + """ + + def fn(*args): + if len(args) == 1 and not isinstance(args[0], tuple): + return batching.batch_window(args[0]) + + return tuple([ + fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg) + for arg in args + ]) + + dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply( + grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn)) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + expected = sess.run( + self._structuredElement(structure, np.concatenate( + ([5], shape), axis=0), dtype)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (np.int32([]),), + (np.int32([1]),), + (np.int32([1, 2, 3]),), + ) + def testWindowDatasetBatchDenseDynamicShape(self, shape): + """Tests batching of dynamically shaped dense tensor windows. + + Args: + shape: the input shape + """ + + shape_t = array_ops.placeholder(dtypes.int32) + dataset = dataset_ops.Dataset.from_tensors( + array_ops.zeros(shape_t)).repeat(5).apply( + grouping.window_dataset(5)).apply( + grouping._map_x_dataset(batching.batch_window)) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op, {shape_t: shape}) + expected = sess.run( + self._structuredElement(None, np.concatenate(([5], shape), axis=0), + dtypes.int32)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + def _make_dense_to_sparse_fn(self, is_scalar): + + def dense_to_sparse_scalar(tensor): + indices = [[]] + values = array_ops.expand_dims(tensor, 0) + shape = [] + return sparse_tensor.SparseTensorValue(indices, values, shape) + + def dense_to_sparse_non_scalar(tensor): + indices = array_ops.where(array_ops.ones_like(tensor, dtype=dtypes.bool)) + values = array_ops.gather_nd(tensor, indices) + shape = array_ops.shape(tensor, out_type=dtypes.int64) + return sparse_tensor.SparseTensorValue(indices, values, shape) + + if is_scalar: + return dense_to_sparse_scalar + return dense_to_sparse_non_scalar + + def _structuredSparseDataset(self, structure, shape, dtype): + dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test + if structure is None: + return dataset_ops.Dataset.from_tensors( + dense_to_sparse(array_ops.zeros(shape, dtype=dtype))) + else: + return dataset_ops.Dataset.zip( + tuple([ + self._structuredSparseDataset(substructure, shape, dtype) + for substructure in structure + ])) + + def _structuredSparseElement(self, structure, shape, dtype): + dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test + if structure is None: + return dense_to_sparse(array_ops.zeros(shape, dtype=dtype)) + else: + return tuple([ + self._structuredSparseElement(substructure, shape, dtype) + for substructure in structure + ]) + + @parameterized.parameters( + (None, np.int32([]), dtypes.bool), + (None, np.int32([]), dtypes.int32), + (None, np.int32([]), dtypes.float32), + (None, np.int32([]), dtypes.string), + (None, np.int32([2]), dtypes.int32), + (None, np.int32([2, 2]), dtypes.int32), + ((None, None, None), np.int32([]), dtypes.int32), + ((None, (None, None)), np.int32([]), dtypes.int32), + ) + def testWindowDatasetBatchSparse(self, structure, shape, dtype): + """Tests batching of sparse tensor windows. + + Args: + structure: the input structure + shape: the input shape + dtype: the input data type + """ + + def fn(*args): + if len(args) == 1 and not isinstance(args[0], tuple): + return batching.batch_window(args[0]) + + return tuple([ + fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg) + for arg in args + ]) + + dataset = self._structuredSparseDataset( + structure, shape, dtype).repeat(5).apply( + grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn)) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + expected = sess.run( + self._structuredSparseElement(structure, + np.concatenate(([5], shape), axis=0), + dtype)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (np.int32([]),), + (np.int32([1]),), + (np.int32([1, 2, 3]),), + ) + def testWindowDatasetBatchSparseDynamicShape(self, shape): + """Tests batching of dynamically shaped sparse tensor windows. + + Args: + shape: the input shape + """ + + shape_t = array_ops.placeholder(dtypes.int32) + dataset = dataset_ops.Dataset.from_tensors(array_ops.zeros(shape_t)).map( + self._make_dense_to_sparse_fn(len(shape) == 0)).repeat(5).apply( # pylint: disable=g-explicit-length-test + grouping.window_dataset(5)).apply( + grouping._map_x_dataset(batching.batch_window)) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op, {shape_t: shape}) + expected = sess.run( + self._structuredSparseElement(None, + np.concatenate(([5], shape), axis=0), + dtypes.int32)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + def _structuredRaggedDataset(self, structure, shapes, dtype): + + if structure is None: + return dataset_ops.Dataset.from_tensor_slices(shapes).map( + lambda shape: array_ops.zeros(shape, dtype=dtype)) + else: + return dataset_ops.Dataset.zip( + tuple([ + self._structuredRaggedDataset(substructure, shapes, dtype) + for substructure in structure + ])) + + @parameterized.parameters( + (None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]), + (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + (None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]), + (None, np.int32([[1], [2], [3]]), dtypes.string, [-1]), + (None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), + (None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]), + ((None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + ((None, (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + (None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])), + ) + def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype, + padded_shape): + """Tests padded batching of dense tensor windows. + + Args: + structure: the input structure + shapes: the input shapes + dtype: the input data type + padded_shape: the shape to pad the output to + """ + + def fn(*args): + if len(args) == 1 and not isinstance(args[0], tuple): + return batching.padded_batch_window(args[0], padded_shape) + + return tuple([ + fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window( + arg, padded_shape) for arg in args + ]) + + dataset = self._structuredRaggedDataset(structure, shapes, dtype).apply( + grouping.window_dataset(len(shapes))).apply( + grouping._map_x_dataset(fn)) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) + expected = sess.run( + self._structuredElement( + structure, + np.concatenate((np.int32([len(shapes)]), expected_shape)), dtype)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (np.int32([[1], [2], [3]]), [-1]), + (np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]), + (np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), + ) + def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape): + """Tests padded batching of dynamically shaped dense tensor windows. + + Args: + shapes: the input shapes + padded_shape: the shape to pad the output to + """ + + shapes_t = array_ops.placeholder(dtypes.int32) + dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map( + lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply( + grouping.window_dataset(len(shapes))).apply( + grouping._map_x_dataset( + lambda x: batching.padded_batch_window(x, padded_shape))) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op, {shapes_t: shapes}) + expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) + expected = sess.run( + self._structuredElement( + None, np.concatenate((np.int32([len(shapes)]), expected_shape)), + dtypes.int32)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (np.int32([[1]]), np.int32([0])), + (np.int32([[10], [20]]), np.int32([15])), + ) + def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape): + """Tests invalid padded batching of dense tensor windows. + + Args: + shapes: the input shapes + padded_shape: the shape to pad the output to + """ + + dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map( + lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply( + grouping.window_dataset(len(shapes))).apply( + grouping._map_x_dataset( + lambda x: batching.padded_batch_window(x, padded_shape))) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + + def _structuredRaggedSparseDataset(self, structure, shapes, dtype): + + def map_fn(shape): + dense_to_sparse = self._make_dense_to_sparse_fn(False) + return dense_to_sparse(array_ops.zeros(shape, dtype=dtype)) + + if structure is None: + return dataset_ops.Dataset.from_tensor_slices(shapes).map(map_fn) + else: + return dataset_ops.Dataset.zip( + tuple([ + self._structuredRaggedSparseDataset(substructure, shapes, dtype) + for substructure in structure + ])) + + def _structuredRaggedSparseElement(self, structure, shapes, dtype, + padded_shape): + if structure is None: + dense_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) + values = [] + for shape in shapes: + dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test + sparse = dense_to_sparse(array_ops.zeros(shape, dtype=dtype)) + padded_sparse = sparse_tensor.SparseTensor(sparse.indices, + sparse.values, dense_shape) + reshaped_sparse = sparse_ops.sparse_reshape( + padded_sparse, + array_ops.concat([np.array([1], dtype=np.int64), dense_shape], 0)) + values.append(reshaped_sparse) + return sparse_ops.sparse_concat(0, values) + else: + return tuple([ + self._structuredRaggedSparseElement(substructure, shapes, dtype, + padded_shape) + for substructure in structure + ]) + + @parameterized.parameters( + (None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]), + (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + (None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]), + (None, np.int64([[1], [2], [3]]), dtypes.string, [-1]), + (None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), + (None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]), + ((None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + ((None, (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + (None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])), + ) + def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype, + padded_shape): + """Tests padded batching of sparse tensor windows. + + Args: + structure: the input structure + shapes: the input shapes + dtype: the input data type + padded_shape: the shape to pad the output to + """ + + def fn(*args): + if len(args) == 1 and not isinstance(args[0], tuple): + return batching.padded_batch_window(args[0], padded_shape) + + return tuple([ + fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window( + arg, padded_shape) for arg in args + ]) + + dataset = self._structuredRaggedSparseDataset( + structure, shapes, dtype).apply(grouping.window_dataset( + len(shapes))).apply(grouping._map_x_dataset(fn)) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + expected = sess.run( + self._structuredRaggedSparseElement(structure, shapes, dtype, + padded_shape)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (np.int64([[1], [2], [3]]), [-1]), + (np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]), + (np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), + ) + def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes, + padded_shape): + """Tests padded batching of dynamically shaped sparse tensor windows. + + Args: + shapes: the input shapes + padded_shape: the shape to pad the output to + """ + + shapes_t = array_ops.placeholder(dtypes.int32) + dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map( + lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map( + self._make_dense_to_sparse_fn(False) + ).apply(grouping.window_dataset(len(shapes))).apply( + grouping._map_x_dataset( + lambda x: batching.padded_batch_window(x, padded_shape))) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op, {shapes_t: shapes}) + expected = sess.run( + self._structuredRaggedSparseElement(None, shapes, dtypes.int32, + padded_shape)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (np.int64([[1]]), [0]), + (np.int64([[10], [20]]), [15]), + ) + def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape): + """Tests invalid padded batching of sparse tensor windows. + + Args: + shapes: the input shapes + padded_shape: the shape to pad the output to + """ + + dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map( + lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map( + self._make_dense_to_sparse_fn(False) + ).apply(grouping.window_dataset(len(shapes))).apply( + grouping._map_x_dataset( + lambda x: batching.padded_batch_window(x, padded_shape))) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 33b7a75046cf2acfa3d787833b907aa2b28dbdca..ad9378dfb9d938c826f994da9bbb89101cfbd872 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -28,10 +28,12 @@ py_library( srcs = ["get_single_element.py"], srcs_version = "PY2AND3", deps = [ + ":grouping", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", + "//third_party/py/numpy", ], ) @@ -49,26 +51,6 @@ py_library( ], ) -py_test( - name = "iterator_ops_test", - size = "small", - srcs = ["iterator_ops_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":iterator_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:model_fn", - ], -) - py_library( name = "random_ops", srcs = [ @@ -135,6 +117,8 @@ py_library( srcs = ["batching.py"], srcs_version = "PY2AND3", deps = [ + ":get_single_element", + ":grouping", "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:array_ops", "//tensorflow/python:dataset_ops_gen", @@ -147,6 +131,7 @@ py_library( "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", + "//third_party/py/numpy", ], ) @@ -225,6 +210,17 @@ py_library( ], ) +py_library( + name = "map_defun", + srcs = ["map_defun.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python:tensor_shape", + ], +) + py_library( name = "resampling", srcs = ["resampling.py"], @@ -385,6 +381,7 @@ py_library( ":get_single_element", ":grouping", ":interleave_ops", + ":map_defun", ":optimization", ":prefetching_ops", ":readers", diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 052618e08c8f204613db5a20d42e078f17f12840..9f059942a65177186132164531237f838ecd63a2 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -17,22 +17,134 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + +from tensorflow.contrib.data.python.ops import get_single_element +from tensorflow.contrib.data.python.ops import grouping from tensorflow.contrib.framework import with_shape from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.util import deprecation +def batch_window(dataset): + """Batches a window of tensors. + + Args: + dataset: the input dataset. + + Returns: + A `Tensor` representing the batch of the entire input dataset. + """ + if isinstance(dataset.output_classes, tuple): + raise TypeError("Input dataset expected to have a single component") + if dataset.output_classes is ops.Tensor: + return _batch_dense_window(dataset) + elif dataset.output_classes is sparse_tensor.SparseTensor: + return _batch_sparse_window(dataset) + else: + raise TypeError("Unsupported dataset type: %s" % dataset.output_classes) + + +def _batch_dense_window(dataset): + """Batches a window of dense tensors.""" + + def key_fn(_): + return np.int64(0) + + def shape_init_fn(_): + return array_ops.shape(first_element) + + def shape_reduce_fn(state, value): + check_ops.assert_equal(state, array_ops.shape(value)) + return state + + def finalize_fn(state): + return state + + if dataset.output_shapes.is_fully_defined(): + shape = dataset.output_shapes + else: + first_element = get_single_element.get_single_element(dataset.take(1)) + shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn, + finalize_fn) + shape = get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer))) + + def batch_init_fn(_): + batch_shape = array_ops.concat([[0], shape], 0) + return gen_array_ops.empty(batch_shape, dtype=dataset.output_types) + + def batch_reduce_fn(state, value): + return array_ops.concat([state, [value]], 0) + + batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) + return get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, batch_reducer))) + + +def _batch_sparse_window(dataset): + """Batches a window of sparse tensors.""" + + def key_fn(_): + return np.int64(0) + + def shape_init_fn(_): + return first_element.dense_shape + + def shape_reduce_fn(state, value): + check_ops.assert_equal(state, value.dense_shape) + return state + + def finalize_fn(state): + return state + + if dataset.output_shapes.is_fully_defined(): + shape = dataset.output_shapes + else: + first_element = get_single_element.get_single_element(dataset.take(1)) + shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn, + finalize_fn) + shape = get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer))) + + def batch_init_fn(_): + indices_shape = array_ops.concat([[0], [array_ops.size(shape) + 1]], 0) + return sparse_tensor.SparseTensor( + indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64), + values=constant_op.constant([], shape=[0], dtype=dataset.output_types), + dense_shape=array_ops.concat( + [np.array([0], dtype=np.int64), + math_ops.cast(shape, dtypes.int64)], 0)) + + def batch_reduce_fn(state, value): + return sparse_ops.sparse_concat(0, [state, value]) + + def reshape_fn(value): + return sparse_ops.sparse_reshape( + value, + array_ops.concat([np.array([1], dtype=np.int64), value.dense_shape], 0)) + + batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) + return get_single_element.get_single_element( + dataset.map(reshape_fn).apply( + grouping.group_by_reducer(key_fn, batch_reducer))) + + def dense_to_sparse_batch(batch_size, row_shape): """A transformation that batches ragged elements into `tf.SparseTensor`s. @@ -73,21 +185,172 @@ def dense_to_sparse_batch(batch_size, row_shape): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): - return DenseToSparseBatchDataset(dataset, batch_size, row_shape) + return _DenseToSparseBatchDataset(dataset, batch_size, row_shape) return _apply_fn -class UnbatchDataset(dataset_ops.Dataset): +def padded_batch_window(dataset, padded_shape, padding_value=None): + """Batches a window of tensors with padding. + + Args: + dataset: the input dataset. + padded_shape: (Optional.) `tf.TensorShape` or `tf.int64` vector tensor-like + object representing the shape to which the input elements should be padded + prior to batching. Any unknown dimensions (e.g. `tf.Dimension(None)` in a + `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the + maximum size of that dimension in each batch. + padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the + padding value to use. Defaults are `0` for numeric types and the empty + string for string types. If `dataset` contains `tf.SparseTensor`, this + value is ignored. + + Returns: + A `Tensor` representing the batch of the entire input dataset. + + Raises: + ValueError: if invalid arguments are provided. + """ + if not issubclass(dataset.output_classes, + (ops.Tensor, sparse_tensor.SparseTensor)): + raise TypeError("Input dataset expected to have a single tensor component") + if issubclass(dataset.output_classes, (ops.Tensor)): + return _padded_batch_dense_window(dataset, padded_shape, padding_value) + elif issubclass(dataset.output_classes, (sparse_tensor.SparseTensor)): + if padding_value is not None: + raise ValueError("Padding value not allowed for sparse tensors") + return _padded_batch_sparse_window(dataset, padded_shape) + else: + raise TypeError("Unsupported dataset type: %s" % dataset.output_classes) + + +def _padded_batch_dense_window(dataset, padded_shape, padding_value=None): + """Batches a window of dense tensors with padding.""" + + padded_shape = math_ops.cast( + convert.partial_shape_to_tensor(padded_shape), dtypes.int32) + + def key_fn(_): + return np.int64(0) + + def max_init_fn(_): + return padded_shape + + def max_reduce_fn(state, value): + """Computes the maximum shape to pad to.""" + condition = math_ops.reduce_all( + math_ops.logical_or( + math_ops.less_equal(array_ops.shape(value), padded_shape), + math_ops.equal(padded_shape, -1))) + assert_op = control_flow_ops.Assert(condition, [ + "Actual shape greater than padded shape: ", + array_ops.shape(value), padded_shape + ]) + with ops.control_dependencies([assert_op]): + return math_ops.maximum(state, array_ops.shape(value)) + + def finalize_fn(state): + return state + + # Compute the padded shape. + max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn) + padded_shape = get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, max_reducer))) + + if padding_value is None: + if dataset.output_types == dtypes.string: + padding_value = "" + elif dataset.output_types == dtypes.bool: + padding_value = False + elif dataset.output_types == dtypes.variant: + raise TypeError("Unable to create padding for field of type 'variant'") + else: + padding_value = 0 + + def batch_init_fn(_): + return array_ops.fill( + array_ops.concat([np.array([0], dtype=np.int32), padded_shape], 0), + constant_op.constant(padding_value, dtype=dataset.output_types)) + + def batch_reduce_fn(state, value): + return array_ops.concat([state, [value]], 0) + + def pad_fn(value): + shape = array_ops.shape(value) + left = array_ops.zeros_like(shape) + right = padded_shape - shape + return array_ops.pad( + value, array_ops.stack([left, right], 1), constant_values=padding_value) + + batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) + return get_single_element.get_single_element( + dataset.map(pad_fn).apply( + grouping.group_by_reducer(key_fn, batch_reducer))) + + +def _padded_batch_sparse_window(dataset, padded_shape): + """Batches a window of sparse tensors with padding.""" + + def key_fn(_): + return np.int64(0) + + def max_init_fn(_): + return convert.partial_shape_to_tensor(padded_shape) + + def max_reduce_fn(state, value): + """Computes the maximum shape to pad to.""" + condition = math_ops.reduce_all( + math_ops.logical_or( + math_ops.less_equal(value.dense_shape, padded_shape), + math_ops.equal(padded_shape, -1))) + assert_op = control_flow_ops.Assert(condition, [ + "Actual shape greater than padded shape: ", value.dense_shape, + padded_shape + ]) + with ops.control_dependencies([assert_op]): + return math_ops.maximum(state, value.dense_shape) + + def finalize_fn(state): + return state + + # Compute the padded shape. + max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn) + padded_shape = get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, max_reducer))) + + def batch_init_fn(_): + indices_shape = array_ops.concat([[0], [array_ops.size(padded_shape) + 1]], + 0) + return sparse_tensor.SparseTensor( + indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64), + values=constant_op.constant([], shape=[0], dtype=dataset.output_types), + dense_shape=array_ops.concat( + [np.array([0], dtype=np.int64), padded_shape], 0)) + + def batch_reduce_fn(state, value): + padded_value = sparse_tensor.SparseTensor( + indices=value.indices, values=value.values, dense_shape=padded_shape) + reshaped_value = sparse_ops.sparse_reshape( + padded_value, + array_ops.concat( + [np.array([1], dtype=np.int64), padded_value.dense_shape], 0)) + return sparse_ops.sparse_concat(0, [state, reshaped_value]) + + reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) + return get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, reducer))) + + +class _UnbatchDataset(dataset_ops.Dataset): """A dataset that splits the elements of its input into multiple elements.""" def __init__(self, input_dataset): """See `unbatch()` for more details.""" - super(UnbatchDataset, self).__init__() + super(_UnbatchDataset, self).__init__() flat_shapes = nest.flatten(input_dataset.output_shapes) if any(s.ndims == 0 for s in flat_shapes): raise ValueError("Cannot unbatch an input with scalar components.") @@ -138,13 +401,13 @@ def unbatch(): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" if not sparse.any_sparse(dataset.output_classes): - return UnbatchDataset(dataset) + return _UnbatchDataset(dataset) # NOTE(mrry): We must ensure that any SparseTensors in `dataset` # are normalized to the rank-1 dense representation, so that the @@ -170,49 +433,7 @@ def unbatch(): dataset.output_shapes, dataset.output_classes, allow_unsafe_cast=True) - return UnbatchDataset(restructured_dataset) - - return _apply_fn - - -def filter_irregular_batches(batch_size): - """Transformation that filters out batches that are not of size batch_size.""" - - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - tensor_batch_size = ops.convert_to_tensor( - batch_size, dtype=dtypes.int64, name="batch_size") - - flattened = _RestructuredDataset( - dataset, - tuple(nest.flatten(dataset.output_types)), - output_classes=tuple(nest.flatten(dataset.output_classes))) - - def _predicate(*xs): - """Return `True` if this element is a full batch.""" - # Extract the dynamic batch size from the first component of the flattened - # batched element. - first_component = xs[0] - first_component_batch_size = array_ops.shape( - first_component, out_type=dtypes.int64)[0] - - return math_ops.equal(first_component_batch_size, tensor_batch_size) - - filtered = flattened.filter(_predicate) - - maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size) - - def _set_first_dimension(shape): - return shape.merge_with( - tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:])) - - known_shapes = nest.map_structure(_set_first_dimension, - dataset.output_shapes) - return _RestructuredDataset( - filtered, - dataset.output_types, - known_shapes, - output_classes=dataset.output_classes) + return _UnbatchDataset(restructured_dataset) return _apply_fn @@ -222,7 +443,7 @@ def filter_irregular_batches(batch_size): def batch_and_drop_remainder(batch_size): """A batching transformation that omits the final small batch (if present). - Like @{tf.data.Dataset.batch}, this transformation combines + Like `tf.data.Dataset.batch`, this transformation combines consecutive elements of this dataset into batches. However, if the batch size does not evenly divide the input dataset size, this transformation will drop the final smaller element. @@ -246,15 +467,12 @@ def batch_and_drop_remainder(batch_size): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply} + `tf.data.Dataset.apply` """ def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - # TODO(jsimsa): Switch to using `batch(..., drop_remainder=True)` any time - # after 6/30/2018. - batched = dataset.batch(batch_size) - return filter_irregular_batches(batch_size)(batched) + return dataset.batch(batch_size, drop_remainder=True) return _apply_fn @@ -266,44 +484,42 @@ def padded_batch_and_drop_remainder(batch_size, padding_values=None): """A batching and padding transformation that omits the final small batch. - Like @{tf.data.Dataset.padded_batch}, this transformation combines + Like `tf.data.Dataset.padded_batch`, this transformation combines consecutive elements of this dataset into batches. However, if the batch size does not evenly divide the input dataset size, this transformation will drop the final smaller element. - See `@{tf.contrib.data.batch_and_drop_remainder}` for more details. + See `tf.contrib.data.batch_and_drop_remainder` for more details. Args: batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of consecutive elements of this dataset to combine in a single batch. padded_shapes: A nested structure of `tf.TensorShape` or `tf.int64` vector tensor-like objects. See - @{tf.data.Dataset.padded_batch} for details. + `tf.data.Dataset.padded_batch` for details. padding_values: (Optional.) A nested structure of scalar-shaped - `tf.Tensor`. See @{tf.data.Dataset.padded_batch} for details. + `tf.Tensor`. See `tf.data.Dataset.padded_batch` for details. Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply} + `tf.data.Dataset.apply` """ def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - # TODO(jsimsa): Switch to using `padded_batch(..., drop_remainder=True)` - # any time after 6/30/2018. - batched = dataset.padded_batch( - batch_size, padded_shapes=padded_shapes, padding_values=padding_values) - return filter_irregular_batches(batch_size)(batched) + return dataset.padded_batch( + batch_size, padded_shapes=padded_shapes, padding_values=padding_values, + drop_remainder=True) return _apply_fn -class DenseToSparseBatchDataset(dataset_ops.Dataset): +class _DenseToSparseBatchDataset(dataset_ops.Dataset): """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s.""" def __init__(self, input_dataset, batch_size, row_shape): """See `Dataset.dense_to_sparse_batch()` for more details.""" - super(DenseToSparseBatchDataset, self).__init__() + super(_DenseToSparseBatchDataset, self).__init__() if not isinstance(input_dataset.output_types, dtypes.DType): raise TypeError("DenseToSparseDataset requires an input whose elements " "have a single component, whereas the input has %r." % @@ -445,7 +661,7 @@ def assert_element_shape(expected_shapes): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply} + `tf.data.Dataset.apply` """ def _check_shape(*elements): @@ -544,7 +760,7 @@ def map_and_batch(map_func, Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. Raises: ValueError: If both `num_parallel_batches` and `num_parallel_calls` are diff --git a/tensorflow/contrib/data/python/ops/enumerate_ops.py b/tensorflow/contrib/data/python/ops/enumerate_ops.py index ac2b386b81532b801139baa00fd5edd4ecd6ef0a..490281e0d2da7a454a2f63f95753c7c436b87a76 100644 --- a/tensorflow/contrib/data/python/ops/enumerate_ops.py +++ b/tensorflow/contrib/data/python/ops/enumerate_ops.py @@ -47,7 +47,7 @@ def enumerate_dataset(start=0): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index 5f5513849cb29a18b86ba8bcee1ab6c9c60674cb..b4a7521e0875089c39ac7aa8b7b49e44feb2b4ad 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -42,21 +42,21 @@ def ignore_errors(): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): - return IgnoreErrorsDataset(dataset) + return _IgnoreErrorsDataset(dataset) return _apply_fn -class IgnoreErrorsDataset(dataset_ops.Dataset): +class _IgnoreErrorsDataset(dataset_ops.Dataset): """A `Dataset` that silently ignores errors when computing its input.""" def __init__(self, input_dataset): """See `Dataset.ignore_errors()` for details.""" - super(IgnoreErrorsDataset, self).__init__() + super(_IgnoreErrorsDataset, self).__init__() self._input_dataset = input_dataset def _as_variant_tensor(self): diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py index 0f4cd8e20c5727a5bcfa1dce4dadbfa8f90bd551..a6713b017afa315edec9389d0a6c1c7135e6aeb9 100644 --- a/tensorflow/contrib/data/python/ops/get_single_element.py +++ b/tensorflow/contrib/data/python/ops/get_single_element.py @@ -17,6 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + +from tensorflow.contrib.data.python.ops import grouping from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse @@ -26,8 +29,8 @@ from tensorflow.python.ops import gen_dataset_ops def get_single_element(dataset): """Returns the single element in `dataset` as a nested structure of tensors. - This function enables you to use a @{tf.data.Dataset} in a stateless - "tensor-in tensor-out" expression, without creating a @{tf.data.Iterator}. + This function enables you to use a `tf.data.Dataset` in a stateless + "tensor-in tensor-out" expression, without creating a `tf.data.Iterator`. This can be useful when your preprocessing transformations are expressed as a `Dataset`, and you want to use the transformation at serving time. For example: @@ -47,10 +50,10 @@ def get_single_element(dataset): ``` Args: - dataset: A @{tf.data.Dataset} object containing a single element. + dataset: A `tf.data.Dataset` object containing a single element. Returns: - A nested structure of @{tf.Tensor} objects, corresponding to the single + A nested structure of `tf.Tensor` objects, corresponding to the single element of `dataset`. Raises: @@ -68,3 +71,30 @@ def get_single_element(dataset): return sparse.deserialize_sparse_tensors( nested_ret, dataset.output_types, dataset.output_shapes, dataset.output_classes) + + +def reduce_dataset(dataset, reducer): + """Returns the result of reducing the `dataset` using `reducer`. + + Args: + dataset: A `tf.data.Dataset` object. + reducer: A `tf.contrib.data.Reducer` object representing the reduce logic. + + Returns: + A nested structure of `tf.Tensor` objects, corresponding to the result + of reducing `dataset` using `reducer`. + + Raises: + TypeError: if `dataset` is not a `tf.data.Dataset` object. + """ + if not isinstance(dataset, dataset_ops.Dataset): + raise TypeError("`dataset` must be a `tf.data.Dataset` object.") + + # The sentinel dataset is used in case the reduced dataset is empty. + sentinel_dataset = dataset_ops.Dataset.from_tensors( + reducer.finalize_func(reducer.init_func(np.int64(0)))) + reduced_dataset = dataset.apply( + grouping.group_by_reducer(lambda x: np.int64(0), reducer)) + + return get_single_element( + reduced_dataset.concatenate(sentinel_dataset).take(1)) diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index f9f25e6a0687fe7167525847c64743f52a551fb0..6edc1d79902c571b34b6a0a108c4d62cb6097ccb 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -21,12 +21,9 @@ import numpy as np from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -53,12 +50,12 @@ def group_by_reducer(key_func, reducer): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - return GroupByReducerDataset(dataset, key_func, reducer) + return _GroupByReducerDataset(dataset, key_func, reducer) return _apply_fn @@ -95,7 +92,7 @@ def group_by_window(key_func, Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. Raises: ValueError: if neither or both of {`window_size`, `window_size_func`} are @@ -116,8 +113,8 @@ def group_by_window(key_func, def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - return GroupByWindowDataset(dataset, key_func, reduce_func, - window_size_func) + return _GroupByWindowDataset(dataset, key_func, reduce_func, + window_size_func) return _apply_fn @@ -145,20 +142,20 @@ def bucket_by_sequence_length(element_length_func, bucket_batch_sizes: `list`, batch size per bucket. Length should be `len(bucket_boundaries) + 1`. padded_shapes: Nested structure of `tf.TensorShape` to pass to - @{tf.data.Dataset.padded_batch}. If not provided, will use + `tf.data.Dataset.padded_batch`. If not provided, will use `dataset.output_shapes`, which will result in variable length dimensions being padded out to the maximum length in each batch. padding_values: Values to pad with, passed to - @{tf.data.Dataset.padded_batch}. Defaults to padding with 0. + `tf.data.Dataset.padded_batch`. Defaults to padding with 0. pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown size to maximum length in batch. If `True`, will pad dimensions with - unknown size to bucket boundary, and caller must ensure that the source - `Dataset` does not contain any elements with length longer than - `max(bucket_boundaries)`. + unknown size to bucket boundary minus 1 (i.e., the maximum length in each + bucket), and caller must ensure that the source `Dataset` does not contain + any elements with length longer than `max(bucket_boundaries)`. Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. Raises: ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`. @@ -206,7 +203,7 @@ def bucket_by_sequence_length(element_length_func, none_filler = None if pad_to_bucket_boundary: err_msg = ("When pad_to_bucket_boundary=True, elements must have " - "length <= max(bucket_boundaries).") + "length < max(bucket_boundaries).") check = check_ops.assert_less( bucket_id, constant_op.constant(len(bucket_batch_sizes) - 1, @@ -216,7 +213,7 @@ def bucket_by_sequence_length(element_length_func, boundaries = constant_op.constant(bucket_boundaries, dtype=dtypes.int64) bucket_boundary = boundaries[bucket_id] - none_filler = bucket_boundary + none_filler = bucket_boundary - 1 shapes = make_padded_shapes( padded_shapes or grouped_dataset.output_shapes, none_filler=none_filler) @@ -230,39 +227,56 @@ def bucket_by_sequence_length(element_length_func, return _apply_fn -class _VariantDataset(dataset_ops.Dataset): - """A Dataset wrapper for a tf.variant-typed function argument.""" +def _map_x_dataset(map_func): + """A transformation that maps `map_func` across its input. - def __init__(self, dataset_variant, output_types, output_shapes, - output_classes): - super(_VariantDataset, self).__init__() - self._dataset_variant = dataset_variant - self._output_types = output_types - self._output_shapes = output_shapes - self._output_classes = output_classes + This transformation is similar to `tf.data.Dataset.map`, but in addition to + supporting dense and sparse tensor inputs, it also supports dataset inputs. - def _as_variant_tensor(self): - return self._dataset_variant + Args: + map_func: A function mapping a nested structure of tensors and/or datasets + (having shapes and types defined by `self.output_shapes` and + `self.output_types`) to another nested structure of tensors and/or + datasets. - @property - def output_classes(self): - return self._output_classes + Returns: + Dataset: A `Dataset`. + """ - @property - def output_shapes(self): - return self._output_shapes + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return _MapXDataset(dataset, map_func) + + return _apply_fn - @property - def output_types(self): - return self._output_types +def window_dataset(window_size): + """A transformation that creates window datasets from the input dataset. -class GroupByReducerDataset(dataset_ops.Dataset): + The resulting datasets will contain `window_size` elements (or + `N % window_size` for the last dataset if `window_size` does not divide the + number of input elements `N` evenly). + + Args: + window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + consecutive elements of the input dataset to combine into a window. + + Returns: + Dataset: A `Dataset`. + """ + + def _apply_fn(dataset): + return _WindowDataset(dataset, window_size) + + return _apply_fn + + +class _GroupByReducerDataset(dataset_ops.Dataset): """A `Dataset` that groups its input and performs a reduction.""" def __init__(self, input_dataset, key_func, reducer): """See `group_by_reducer()` for details.""" - super(GroupByReducerDataset, self).__init__() + super(_GroupByReducerDataset, self).__init__() self._input_dataset = input_dataset @@ -273,70 +287,27 @@ class GroupByReducerDataset(dataset_ops.Dataset): def _make_key_func(self, key_func, input_dataset): """Make wrapping Defun for key_func.""" - - @function.Defun(*nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes))) - def tf_key_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) - # pylint: disable=protected-access - if dataset_ops._should_unpack_args(nested_args): - ret = key_func(*nested_args) - # pylint: enable=protected-access - else: - ret = key_func(nested_args) - ret = ops.convert_to_tensor(ret) - if ret.dtype != dtypes.int64 or ret.get_shape() != tensor_shape.scalar(): - raise ValueError( - "`key_func` must return a single tf.int64 tensor. " - "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape())) - dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access - return ret - - self._key_func = tf_key_func - self._key_func.add_to_graph(ops.get_default_graph()) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + key_func, "tf.contrib.data.group_by_reducer()", input_dataset) + if not ( + wrapped_func.output_types == dtypes.int64 and + wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): + raise ValueError( + "`key_func` must return a single tf.int64 tensor. " + "Got type=%s and shape=%s" + % (wrapped_func.output_types, wrapped_func.output_shapes)) + self._key_func = wrapped_func.function def _make_init_func(self, init_func): """Make wrapping Defun for init_func.""" - - @function.Defun(dtypes.int64) - def tf_init_func(key): - """A wrapper for Defun that facilitates shape inference.""" - key.set_shape([]) - ret = init_func(key) - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - - self._state_classes = sparse.get_classes(ret) - self._state_shapes = nest.pack_sequence_as( - ret, [t.get_shape() for t in nest.flatten(ret)]) - self._state_types = nest.pack_sequence_as( - ret, [t.dtype for t in nest.flatten(ret)]) - - dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access - - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) - - self._init_func = tf_init_func - self._init_func.add_to_graph(ops.get_default_graph()) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + init_func, "tf.contrib.data.group_by_reducer()", + input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(), + input_types=dtypes.int64) + self._init_func = wrapped_func.function + self._state_classes = wrapped_func.output_classes + self._state_shapes = wrapped_func.output_shapes + self._state_types = wrapped_func.output_types def _make_reduce_func(self, reduce_func, input_dataset): """Make wrapping Defun for reduce_func.""" @@ -346,85 +317,47 @@ class GroupByReducerDataset(dataset_ops.Dataset): need_to_rerun = True while need_to_rerun: - # Create a list in which `tf_reduce_func` will store the new shapes. - flat_new_state_shapes = [] - - @function.Defun(*(nest.flatten( - sparse.as_dense_types( - self._state_types, self._state_classes)) + nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes)))) - def tf_reduce_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - for arg, shape in zip( - args, - nest.flatten( - sparse.as_dense_shapes(self._state_shapes, self._state_classes)) - + nest.flatten( - sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes))): - arg.set_shape(shape) - - pivot = len(nest.flatten(self._state_shapes)) - nested_state_args = nest.pack_sequence_as(self._state_types, - args[:pivot]) - nested_state_args = sparse.deserialize_sparse_tensors( - nested_state_args, self._state_types, self._state_shapes, - self._state_classes) - nested_input_args = nest.pack_sequence_as(input_dataset.output_types, - args[pivot:]) - nested_input_args = sparse.deserialize_sparse_tensors( - nested_input_args, input_dataset.output_types, - input_dataset.output_shapes, input_dataset.output_classes) - - ret = reduce_func(nested_state_args, nested_input_args) - - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - - # Extract shape information from the returned values. - flat_new_state = nest.flatten(ret) - flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state]) - - # Extract and validate type information from the returned values. - for t, dtype in zip(flat_new_state, nest.flatten(self._state_types)): - if t.dtype != dtype: - raise TypeError( - "The element types for the new state must match the initial " - "state. Expected %s; got %s." % - (self._state_types, - nest.pack_sequence_as(self._state_types, - [t.dtype for t in flat_new_state]))) - - dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access - - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, - [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) - - # Use the private method that will execute `tf_reduce_func` but delay - # adding it to the graph in case we need to rerun the function. - tf_reduce_func._create_definition_if_needed() # pylint: disable=protected-access - + wrapped_func = dataset_ops.StructuredFunctionWrapper( + reduce_func, "tf.contrib.data.group_by_reducer()", + input_classes=(self._state_classes, input_dataset.output_classes), + input_shapes=(self._state_shapes, input_dataset.output_shapes), + input_types=(self._state_types, input_dataset.output_types), + add_to_graph=False) + + # Extract and validate class information from the returned values. + for new_state_class, state_class in zip( + nest.flatten(wrapped_func.output_classes), + nest.flatten(self._state_classes)): + if not issubclass(new_state_class, state_class): + raise TypeError( + "The element classes for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_classes, wrapped_func.output_classes)) + + # Extract and validate type information from the returned values. + for new_state_type, state_type in zip( + nest.flatten(wrapped_func.output_types), + nest.flatten(self._state_types)): + if new_state_type != state_type: + raise TypeError( + "The element types for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_types, wrapped_func.output_types)) + + # Extract shape information from the returned values. flat_state_shapes = nest.flatten(self._state_shapes) + flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes) weakened_state_shapes = [ - old.most_specific_compatible_shape(new) - for old, new in zip(flat_state_shapes, flat_new_state_shapes) + original.most_specific_compatible_shape(new) + for original, new in zip(flat_state_shapes, flat_new_state_shapes) ] need_to_rerun = False - for old_shape, weakened_shape in zip(flat_state_shapes, - weakened_state_shapes): - if old_shape.ndims is not None and ( + for original_shape, weakened_shape in zip(flat_state_shapes, + weakened_state_shapes): + if original_shape.ndims is not None and ( weakened_shape.ndims is None or - old_shape.as_list() != weakened_shape.as_list()): + original_shape.as_list() != weakened_shape.as_list()): need_to_rerun = True break @@ -432,52 +365,19 @@ class GroupByReducerDataset(dataset_ops.Dataset): self._state_shapes = nest.pack_sequence_as(self._state_shapes, weakened_state_shapes) - self._reduce_func = tf_reduce_func + self._reduce_func = wrapped_func.function self._reduce_func.add_to_graph(ops.get_default_graph()) def _make_finalize_func(self, finalize_func): """Make wrapping Defun for finalize_func.""" - - @function.Defun(*(nest.flatten( - sparse.as_dense_types(self._state_types, self._state_classes)))) - def tf_finalize_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - for arg, shape in zip( - args, - nest.flatten( - sparse.as_dense_shapes(self._state_shapes, self._state_classes))): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(self._state_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, self._state_types, self._state_shapes, - self._state_classes) - - ret = finalize_func(nested_args) - - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - - self._output_classes = sparse.get_classes(ret) - self._output_shapes = nest.pack_sequence_as( - ret, [t.get_shape() for t in nest.flatten(ret)]) - self._output_types = nest.pack_sequence_as( - ret, [t.dtype for t in nest.flatten(ret)]) - - dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access - - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) - - self._finalize_func = tf_finalize_func - self._finalize_func.add_to_graph(ops.get_default_graph()) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + finalize_func, "tf.contrib.data.group_by_reducer()", + input_classes=self._state_classes, input_shapes=self._state_shapes, + input_types=self._state_types) + self._finalize_func = wrapped_func.function + self._output_classes = wrapped_func.output_classes + self._output_shapes = wrapped_func.output_shapes + self._output_types = wrapped_func.output_types @property def output_classes(self): @@ -505,12 +405,12 @@ class GroupByReducerDataset(dataset_ops.Dataset): **dataset_ops.flat_structure(self)) -class GroupByWindowDataset(dataset_ops.Dataset): +class _GroupByWindowDataset(dataset_ops.Dataset): """A `Dataset` that groups its input and performs a windowed reduction.""" def __init__(self, input_dataset, key_func, reduce_func, window_size_func): """See `group_by_window()` for details.""" - super(GroupByWindowDataset, self).__init__() + super(_GroupByWindowDataset, self).__init__() self._input_dataset = input_dataset @@ -520,77 +420,48 @@ class GroupByWindowDataset(dataset_ops.Dataset): def _make_window_size_func(self, window_size_func): """Make wrapping Defun for window_size_func.""" - - @function.Defun(dtypes.int64) - def tf_window_size_func(key): - key.set_shape([]) - window_size = ops.convert_to_tensor( - window_size_func(key), dtype=dtypes.int64) - if window_size.dtype != dtypes.int64: - raise ValueError( - "`window_size_func` must return a single tf.int64 tensor.") - dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()") # pylint: disable=protected-access - return window_size - - self._window_size_func = tf_window_size_func - self._window_size_func.add_to_graph(ops.get_default_graph()) + def window_size_func_wrapper(key): + return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + window_size_func_wrapper, "tf.contrib.data.group_by_window()", + input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(), + input_types=dtypes.int64) + if not ( + wrapped_func.output_types == dtypes.int64 and + wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): + raise ValueError( + "`window_size_func` must return a single tf.int64 scalar tensor.") + self._window_size_func = wrapped_func.function def _make_key_func(self, key_func, input_dataset): """Make wrapping Defun for key_func.""" - - @function.Defun(*nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes))) - def tf_key_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) - # pylint: disable=protected-access - if dataset_ops._should_unpack_args(nested_args): - ret = key_func(*nested_args) - # pylint: enable=protected-access - else: - ret = key_func(nested_args) - ret = ops.convert_to_tensor(ret, dtype=dtypes.int64) - if ret.dtype != dtypes.int64: - raise ValueError("`key_func` must return a single tf.int64 tensor.") - dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()") # pylint: disable=protected-access - return ret - - self._key_func = tf_key_func - self._key_func.add_to_graph(ops.get_default_graph()) + def key_func_wrapper(*args): + return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + key_func_wrapper, "tf.contrib.data.group_by_window()", input_dataset) + if not ( + wrapped_func.output_types == dtypes.int64 and + wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): + raise ValueError( + "`key_func` must return a single tf.int64 scalar tensor.") + self._key_func = wrapped_func.function def _make_reduce_func(self, reduce_func, input_dataset): """Make wrapping Defun for reduce_func.""" - - @function.Defun(dtypes.int64, dtypes.variant) - def tf_reduce_func(key, window_dataset_variant): - """A wrapper for Defun that facilitates shape inference.""" - key.set_shape([]) - window_dataset = _VariantDataset( - window_dataset_variant, input_dataset.output_types, - input_dataset.output_shapes, input_dataset.output_classes) - if not isinstance(window_dataset, dataset_ops.Dataset): - raise TypeError("`window_dataset` must return a `Dataset` object.") - output_dataset = reduce_func(key, window_dataset) - if not isinstance(output_dataset, dataset_ops.Dataset): - raise TypeError("`reduce_func` must return a `Dataset` object.") - self._output_classes = output_dataset.output_classes - self._output_types = output_dataset.output_types - self._output_shapes = output_dataset.output_shapes - dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()") # pylint: disable=protected-access - return output_dataset._as_variant_tensor() # pylint: disable=protected-access - - self._reduce_func = tf_reduce_func - self._reduce_func.add_to_graph(ops.get_default_graph()) + nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access + wrapped_func = dataset_ops.StructuredFunctionWrapper( + reduce_func, "tf.contrib.data.reduce_by_window()", + input_classes=(ops.Tensor, nested_dataset), + input_shapes=(tensor_shape.scalar(), nested_dataset), + input_types=(dtypes.int64, nested_dataset), + experimental_nested_dataset_support=True) + if not isinstance( + wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access + raise TypeError("`reduce_func` must return a `Dataset` object.") + self._output_classes = wrapped_func.output_classes.output_classes + self._output_types = wrapped_func.output_types.output_types + self._output_shapes = wrapped_func.output_shapes.output_shapes + self._reduce_func = wrapped_func.function @property def output_classes(self): @@ -641,3 +512,85 @@ class Reducer(object): @property def finalize_func(self): return self._finalize_func + + +class _MapXDataset(dataset_ops.Dataset): + """A `Dataset` that maps a function over elements in its input.""" + + def __init__(self, input_dataset, map_func): + """See `map_x_dataset()` for details.""" + super(_MapXDataset, self).__init__() + self._input_dataset = input_dataset + + wrapped_func = dataset_ops.StructuredFunctionWrapper( + map_func, + "tf.contrib.data.map_x_dataset()", + input_dataset, + experimental_nested_dataset_support=True) + self._output_classes = wrapped_func.output_classes + self._output_shapes = wrapped_func.output_shapes + self._output_types = wrapped_func.output_types + self._map_func = wrapped_func.function + + def _as_variant_tensor(self): + input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access + return gen_dataset_ops.map_dataset( + input_t, + self._map_func.captured_inputs, + f=self._map_func, + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + +class _WindowDataset(dataset_ops.Dataset): + """A dataset that creates window datasets from the input elements.""" + + def __init__(self, input_dataset, window_size): + """See `window_dataset()` for more details.""" + super(_WindowDataset, self).__init__() + self._input_dataset = input_dataset + self._window_size = ops.convert_to_tensor( + window_size, dtype=dtypes.int64, name="window_size") + self._output_classes = nest.pack_sequence_as( + input_dataset.output_classes, + [ + dataset_ops._NestedDatasetComponent( # pylint: disable=protected-access + output_classes=output_class, + output_shapes=output_shape, + output_types=output_type) + for output_class, output_shape, output_type in zip( + nest.flatten(input_dataset.output_classes), + nest.flatten(input_dataset.output_shapes), + nest.flatten(input_dataset.output_types)) + ]) + self._output_shapes = self._output_classes + self._output_types = self._output_classes + + def _as_variant_tensor(self): + return gen_dataset_ops.window_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._window_size, + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 70153ac575758f16beff373941dfefb32bd342cf..5a1a35199abecc3890d5733ddf678af8d4098f33 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -42,7 +42,7 @@ def parallel_interleave(map_func, `parallel_interleave()` maps `map_func` across its input to produce nested datasets, and outputs their elements interleaved. Unlike - @{tf.data.Dataset.interleave}, it gets elements from `cycle_length` nested + `tf.data.Dataset.interleave`, it gets elements from `cycle_length` nested datasets in parallel, which increases the throughput, especially in the presence of stragglers. Furthermore, the `sloppy` argument can be used to improve performance, by relaxing the requirement that the outputs are produced @@ -79,7 +79,7 @@ def parallel_interleave(map_func, Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): return readers.ParallelInterleaveDataset( @@ -138,7 +138,7 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): return readers.ParallelInterleaveDataset( @@ -153,7 +153,7 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): return _apply_fn -class DirectedInterleaveDataset(dataset_ops.Dataset): +class _DirectedInterleaveDataset(dataset_ops.Dataset): """A substitute for `Dataset.interleave()` on a fixed list of datasets.""" def __init__(self, selector_input, data_inputs): @@ -196,15 +196,15 @@ def sample_from_datasets(datasets, weights=None, seed=None): """Samples elements at random from the datasets in `datasets`. Args: - datasets: A list of @{tf.data.Dataset} objects with compatible structure. + datasets: A list of `tf.data.Dataset` objects with compatible structure. weights: (Optional.) A list of `len(datasets)` floating-point values where `weights[i]` represents the probability with which an element should be - sampled from `datasets[i]`, or a @{tf.data.Dataset} object where each + sampled from `datasets[i]`, or a `tf.data.Dataset` object where each element is such a list. Defaults to a uniform distribution across `datasets`. seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random seed that will be used to create the distribution. See - @{tf.set_random_seed} for behavior. + `tf.set_random_seed` for behavior. Returns: A dataset that interleaves elements from `datasets` at random, according to @@ -236,7 +236,7 @@ def sample_from_datasets(datasets, weights=None, seed=None): selector_input = dataset_ops.Dataset.zip( (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset) - return DirectedInterleaveDataset(selector_input, datasets) + return _DirectedInterleaveDataset(selector_input, datasets) def choose_from_datasets(datasets, choice_dataset): @@ -262,8 +262,8 @@ def choose_from_datasets(datasets, choice_dataset): ``` Args: - datasets: A list of @{tf.data.Dataset} objects with compatible structure. - choice_dataset: A @{tf.data.Dataset} of scalar `tf.int64` tensors between + datasets: A list of `tf.data.Dataset` objects with compatible structure. + choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between `0` and `len(datasets) - 1`. Returns: @@ -280,4 +280,4 @@ def choose_from_datasets(datasets, choice_dataset): and choice_dataset.output_classes == ops.Tensor): raise TypeError("`choice_dataset` must be a dataset of scalar " "`tf.int64` tensors.") - return DirectedInterleaveDataset(choice_dataset, datasets) + return _DirectedInterleaveDataset(choice_dataset, datasets) diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index 0d71be66018eeebe60de9deff24ceb6854d209d9..18515e21edfe0449514ab4f21683a600eaf48910 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -20,6 +20,7 @@ from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import session_run_hook @@ -117,7 +118,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): pipeline. For saving the input pipeline checkpoint alongside the model weights use - @{tf.contrib.data.make_saveable_from_iterator} directly to create a + `tf.contrib.data.make_saveable_from_iterator` directly to create a `SaveableObject` and add to the `SAVEABLE_OBJECTS` collection. Note, however, that you will need to be careful not to restore the training iterator during eval. You can do that by not adding the iterator to the SAVEABLE_OBJECTS @@ -206,7 +207,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): # Check if there is an existing checkpoint. If so, restore from it. # pylint: disable=protected-access - latest_checkpoint_path = saver_lib.latest_checkpoint( + latest_checkpoint_path = checkpoint_management.latest_checkpoint( self._checkpoint_saver_hook._checkpoint_dir, latest_filename=self._latest_filename) if latest_checkpoint_path: diff --git a/tensorflow/contrib/data/python/ops/map_defun.py b/tensorflow/contrib/data/python/ops/map_defun.py new file mode 100644 index 0000000000000000000000000000000000000000..54d5cd6da068fa5471b7beafcc66d76b5972e7d5 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/map_defun.py @@ -0,0 +1,58 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for optimizing `tf.data` pipelines.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops + + +def map_defun(fn, elems, output_dtypes, output_shapes): + """Map a function on the list of tensors unpacked from `elems` on dimension 0. + + Args: + fn: A function (`function.Defun`) that takes a list of tensors and returns + another list of tensors. The output list has the same types as + output_dtypes. The elements of the output list have the same dimension 0 + as `elems`, and the remaining dimensions correspond to those of + `fn_output_shapes`. + elems: A list of tensors. + output_dtypes: A list of dtypes corresponding to the output types of the + function. + output_shapes: A list of `TensorShape`s corresponding to the output + shapes from each invocation of the function on slices of inputs. + + Raises: + ValueError: if any of the inputs are malformed. + + Returns: + A list of `Tensor` objects with the same types as `output_dtypes`. + """ + if not isinstance(elems, list): + raise ValueError("`elems` must be a list of tensors.") + if not isinstance(output_dtypes, list): + raise ValueError("`output_dtypes` must be a list of tensors.") + if not isinstance(output_shapes, list): + raise ValueError("`output_shapes` must be a list of tensors.") + + elems = [ops.convert_to_tensor(e) for e in elems] + output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes] + if not all(s.is_fully_defined() for s in output_shapes): + raise ValueError("All fn output shapes must be fully defined.") + return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn) diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py index 9612ac5ae910f8ee08d4b3ed9097a5c80266fcfd..fa1b851ad74bcf2cff69d42bce3eaa38822cd663 100644 --- a/tensorflow/contrib/data/python/ops/optimization.py +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -18,12 +18,34 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops +# TODO(jsimsa): Support RE matching for both individual transformation (e.g. to +# account for indexing) and transformation sequence. +def assert_next(transformations): + """A transformation that asserts which transformations happen next. + + Args: + transformations: A `tf.string` vector `tf.Tensor` identifying the + transformations that are expected to happen next. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return _AssertNextDataset(dataset, transformations) + + return _apply_fn + + def optimize(optimizations=None): """A transformation that applies optimizations. @@ -34,22 +56,53 @@ def optimize(optimizations=None): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - return OptimizeDataset(dataset, optimizations) + return _OptimizeDataset(dataset, optimizations) return _apply_fn -class OptimizeDataset(dataset_ops.Dataset): +class _AssertNextDataset(dataset_ops.Dataset): + """A `Dataset` that asserts which transformations happen next.""" + + def __init__(self, input_dataset, transformations): + """See `assert_next()` for details.""" + super(_AssertNextDataset, self).__init__() + self._input_dataset = input_dataset + if transformations is None: + raise ValueError("At least one transformation should be specified") + self._transformations = ops.convert_to_tensor( + transformations, dtype=dtypes.string, name="transformations") + + def _as_variant_tensor(self): + return contrib_gen_dataset_ops.assert_next_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._transformations, + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + +class _OptimizeDataset(dataset_ops.Dataset): """A `Dataset` that acts as an identity, and applies optimizations.""" def __init__(self, input_dataset, optimizations): """See `optimize()` for details.""" - super(OptimizeDataset, self).__init__() + super(_OptimizeDataset, self).__init__() self._input_dataset = input_dataset if optimizations is None: optimizations = [] @@ -61,6 +114,7 @@ class OptimizeDataset(dataset_ops.Dataset): self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._optimizations, **dataset_ops.flat_structure(self)) + @property def output_classes(self): return self._input_dataset.output_classes diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index e4c9f8b58a2a4390004b0ad318163526b443d44f..5222011d045efd9a64b4e89b248303cffbcb0b37 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -26,21 +26,43 @@ from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.eager import context +from tensorflow.python.framework import device as framework_device from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_dataset_ops as core_gen_dataset_ops +from tensorflow.python.ops import resource_variable_ops -# TODO(rohanj): Add a python class that constructs resource in the __init__ -# method and provides a get_next() that calls the prefetch op. def function_buffering_resource(string_arg, target_device, f, buffer_size, + output_types, container="", shared_name=None, name=None): + """Creates a FunctionBufferingResource. + + A FunctionBufferingResource fills up a buffer by calling a function `f` on + `target_device`. `f` should take in only a single string argument as input. + + Args: + string_arg: The single string argument to the function. + target_device: The device to run `f` on. + f: The function to be executed. + buffer_size: Size of the buffer to be populated. + output_types: The output types generated by the function. + container: (Optional) string. Defaults to "". + shared_name: (Optional) string. + name: (Optional) string to name the op. + + Returns: + Handle to a FunctionBufferingResource. + """ if shared_name is None: shared_name = "" return gen_dataset_ops.function_buffering_resource( @@ -50,7 +72,8 @@ def function_buffering_resource(string_arg, f=f, buffer_size=buffer_size, container=container, - name=name) + name=name, + output_types=output_types) def function_buffering_resource_get_next(function_buffer_resource, @@ -69,7 +92,7 @@ def function_buffering_resource_reset(function_buffer_resource, name=None): # pylint: disable=protected-access class _PrefetchToDeviceIterator(object): - """A replacement for @{tf.data.Iterator} that prefetches to another device. + """A replacement for `tf.data.Iterator` that prefetches to another device. Args: input_dataset: The input dataset @@ -123,7 +146,10 @@ class _PrefetchToDeviceIterator(object): target_device=iterator_device, string_arg=input_iterator_handle, buffer_size=buffer_size, - shared_name=shared_name) + shared_name=shared_name, + output_types=nest.flatten( + sparse.as_dense_types(self._input_dataset.output_types, + self._input_dataset.output_classes))) if not self._one_shot: reset_op = function_buffering_resource_reset(self._buffering_resource) @@ -132,7 +158,7 @@ class _PrefetchToDeviceIterator(object): self._input_dataset) def get_next(self, name=None): - """See @{tf.data.Iterator.get_next}.""" + """See `tf.data.Iterator.get_next`.""" self._get_next_call_count += 1 if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) @@ -173,7 +199,7 @@ class _PrefetchToDeviceIterator(object): class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator): - """A replacement for @{tf.data.Iterator} that prefetches to another device. + """A replacement for `tf.data.Iterator` that prefetches to another device. Args: input_dataset: The input dataset @@ -212,6 +238,7 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator): with ops.device(device): self._buffering_resource = function_buffering_resource( f=_prefetch_fn, + output_types=self._flat_output_types, target_device=gen_dataset_ops.iterator_get_device(self._resource), string_arg=input_iterator_handle, buffer_size=buffer_size, @@ -307,7 +334,7 @@ class _PrefetchToDeviceDataset(dataset_ops.Dataset): def prefetch_to_device(device, buffer_size=None): """A transformation that prefetches dataset values to the given `device`. - NOTE: Although the transformation creates a @{tf.data.Dataset}, the + NOTE: Although the transformation creates a `tf.data.Dataset`, the transformation must be the final `Dataset` in the input pipeline. Args: @@ -317,9 +344,367 @@ def prefetch_to_device(device, buffer_size=None): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): return _PrefetchToDeviceDataset(dataset, device, buffer_size) return _apply_fn + + +def copy_to_device(target_device, source_device="/cpu:0"): + """A transformation that copies dataset elements to the given `target_device`. + + Args: + target_device: The name of a device to which elements will be copied. + source_device: The original device on which `input_dataset` will be placed. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + """ + + def _apply_fn(dataset): + return _CopyToDeviceDataset( + dataset, target_device=target_device, source_device=source_device) + + return _apply_fn + + +# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate +# all inputs to the Op are in host memory, thereby avoiding some unnecessary +# Sends and Recvs. +class _CopyToDeviceDataset(dataset_ops.Dataset): + """A `Dataset` that copies elements to another device.""" + + def __init__(self, input_dataset, target_device, source_device="/cpu:0"): + """Constructs a _CopyToDeviceDataset. + + Args: + input_dataset: `Dataset` to be copied + target_device: The name of the device to which elements would be copied. + source_device: Device where input_dataset would be placed. + """ + self._input_dataset = input_dataset + self._target_device = target_device + spec = framework_device.DeviceSpec().from_string(self._target_device) + self._is_gpu_target = (spec.device_type == "GPU") + self._source_device_string = source_device + self._source_device = ops.convert_to_tensor(source_device) + + self._flat_output_shapes = nest.flatten( + sparse.as_dense_shapes(self._input_dataset.output_shapes, + self._input_dataset.output_classes)) + self._flat_output_types = nest.flatten( + sparse.as_dense_types(self._input_dataset.output_types, + self._input_dataset.output_classes)) + + @function.Defun() + def _init_func(): + """Creates an iterator for the input dataset. + + Returns: + A `string` tensor that encapsulates the iterator created. + """ + # pylint: disable=protected-access + ds_variant = self._input_dataset._as_variant_tensor() + resource = core_gen_dataset_ops.anonymous_iterator( + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + with ops.control_dependencies( + [core_gen_dataset_ops.make_iterator(ds_variant, resource)]): + return core_gen_dataset_ops.iterator_to_string_handle(resource) + + @function.Defun() + def _remote_init_func(): + return functional_ops.remote_call( + target=self._source_device, + args=_init_func.captured_inputs, + Tout=[dtypes.string], + f=_init_func) + + self._init_func = _remote_init_func + self._init_captured_args = _remote_init_func.captured_inputs + + @function.Defun(dtypes.string) + def _next_func(string_handle): + """Calls get_next for created iterator. + + Args: + string_handle: An iterator string handle created by _init_func + Returns: + The elements generated from `input_dataset` + """ + with ops.device(self._source_device_string): + iterator = iterator_ops.Iterator.from_string_handle( + string_handle, self.output_types, self.output_shapes, + self.output_classes) + ret = iterator.get_next() + return nest.flatten(sparse.serialize_sparse_tensors(ret)) + + @function.Defun(dtypes.string) + def _remote_next_func(string_handle): + return functional_ops.remote_call( + target=self._source_device, + args=[string_handle] + _next_func.captured_inputs, + Tout=self._flat_output_types, + f=_next_func) + + self._next_func = _remote_next_func + self._next_captured_args = _remote_next_func.captured_inputs + + @function.Defun(dtypes.string) + def _finalize_func(string_handle): + """Destroys the iterator resource created. + + Args: + string_handle: An iterator string handle created by _init_func + Returns: + Tensor constant 0 + """ + iterator_resource = core_gen_dataset_ops.iterator_from_string_handle_v2( + string_handle, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + with ops.control_dependencies([ + resource_variable_ops.destroy_resource_op( + iterator_resource, ignore_lookup_error=True)]): + return array_ops.constant(0, dtypes.int64) + + @function.Defun(dtypes.string) + def _remote_finalize_func(string_handle): + return functional_ops.remote_call( + target=self._source_device, + args=[string_handle] + _finalize_func.captured_inputs, + Tout=[dtypes.int64], + f=_finalize_func) + + self._finalize_func = _remote_finalize_func + self._finalize_captured_args = _remote_finalize_func.captured_inputs + + g = ops.get_default_graph() + _remote_init_func.add_to_graph(g) + _remote_next_func.add_to_graph(g) + _remote_finalize_func.add_to_graph(g) + # pylint: enable=protected-scope + + # The one_shot_iterator implementation needs a 0 arg _make_dataset function + # that thereby captures all the inputs required to create the dataset. Since + # there are strings that are inputs to the GeneratorDataset which can't be + # placed on a GPU, this fails for the GPU case. Therefore, disabling it for + # GPU + def make_one_shot_iterator(self): + if self._is_gpu_target: + raise ValueError("Cannot create a one shot iterator when using " + "`tf.contrib.data.copy_to_device()` on GPU. Please use " + "`Dataset.make_initializable_iterator()` instead.") + else: + return super(_CopyToDeviceDataset, self).make_one_shot_iterator() + + def _as_variant_tensor(self): + with ops.device(self._target_device): + return core_gen_dataset_ops.generator_dataset( + self._init_captured_args, + self._next_captured_args, + self._finalize_captured_args, + init_func=self._init_func, + next_func=self._next_func, + finalize_func=self._finalize_func, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +class _PerDeviceGenerator(dataset_ops.Dataset): + """A `dummy` generator dataset.""" + + def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id, + source_device, target_device, output_shapes, output_types, + output_classes): + self._target_device = target_device + self._output_types = output_types + self._output_shapes = output_shapes + self._output_classes = output_classes + self._flat_output_shapes = nest.flatten( + sparse.as_dense_shapes(self._output_shapes, self._output_classes)) + self._flat_output_types = nest.flatten( + sparse.as_dense_types(self._output_types, self._output_classes)) + + multi_device_iterator_string_handle = ( + gen_dataset_ops.multi_device_iterator_to_string_handle( + multi_device_iterator_resource)) + + @function.Defun() + def _init_func(): + return multi_device_iterator_string_handle + + @function.Defun() + def _remote_init_func(): + return functional_ops.remote_call( + target=source_device, + args=_init_func.captured_inputs, + Tout=[dtypes.string], + f=_init_func) + + self._init_func = _remote_init_func + self._init_captured_args = _remote_init_func.captured_inputs + + @function.Defun(dtypes.string) + def _next_func(string_handle): + multi_device_iterator = ( + gen_dataset_ops.multi_device_iterator_from_string_handle( + string_handle=string_handle, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes)) + return gen_dataset_ops.multi_device_iterator_get_next_from_shard( + multi_device_iterator=multi_device_iterator, + shard_num=shard_num, + incarnation_id=incarnation_id, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + + @function.Defun(dtypes.string) + def _remote_next_func(string_handle): + return functional_ops.remote_call( + target=source_device, + args=[string_handle] + _next_func.captured_inputs, + Tout=self._flat_output_types, + f=_next_func) + + self._next_func = _remote_next_func + self._next_captured_args = _remote_next_func.captured_inputs + + @function.Defun(dtypes.string) + def _finalize_func(unused_string_handle): + return array_ops.constant(0, dtypes.int64) + + @function.Defun(dtypes.string) + def _remote_finalize_func(string_handle): + return functional_ops.remote_call( + target=source_device, + args=[string_handle] + _finalize_func.captured_inputs, + Tout=[dtypes.int64], + f=_finalize_func) + + self._finalize_func = _remote_finalize_func + self._finalize_captured_args = _remote_finalize_func.captured_inputs + + def _as_variant_tensor(self): + with ops.device(self._target_device): + return core_gen_dataset_ops.generator_dataset( + self._init_captured_args, + self._next_captured_args, + self._finalize_captured_args, + init_func=self._init_func, + next_func=self._next_func, + finalize_func=self._finalize_func, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + + @property + def output_types(self): + return self._output_types + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_classes(self): + return self._output_classes + + +class MultiDeviceIterator(object): + """An iterator over multiple devices.""" + + def __init__(self, + dataset, + devices, + max_buffer_size=1, + prefetch_buffer_size=1, + source_device="/cpu:0"): + """Constructs a MultiDeviceIterator. + + Args: + dataset: The input dataset to be iterated over. + devices: The list of devices to fetch data to. + max_buffer_size: Maximum size of the host side per device buffer to keep. + prefetch_buffer_size: if > 1, then we setup a buffer on each device + to prefetch into. + source_device: The host device to place the `dataset` on. + """ + self._dataset = dataset + self._devices = devices + self._source_device = source_device + self._source_device_tensor = ops.convert_to_tensor(source_device) + + self._flat_output_shapes = nest.flatten( + sparse.as_dense_shapes(self._dataset.output_shapes, + self._dataset.output_classes)) + self._flat_output_types = nest.flatten( + sparse.as_dense_types(self._dataset.output_types, + self._dataset.output_classes)) + + # Create the MultiDeviceIterator. + with ops.device(self._source_device): + self._multi_device_iterator_resource = ( + gen_dataset_ops.multi_device_iterator( + devices=self._devices, + shared_name="", + container="", + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes)) + + # The incarnation ID is used to ensure consistency between the per-device + # iterators and the multi-device iterator. + self._incarnation_id = gen_dataset_ops.multi_device_iterator_init( + self._dataset._as_variant_tensor(), # pylint: disable=protected-access + self._multi_device_iterator_resource, + max_buffer_size=max_buffer_size) + + # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to + # initialize the device side of the pipeline. This would allow the + # MultiDeviceIterator to choose, for example, to move some transformations + # into the device side from its input. It might be useful in rewriting. + # Create the per device iterators. + self._device_iterators = [] + i = 0 + for device in self._devices: + ds = _PerDeviceGenerator( + i, self._multi_device_iterator_resource, self._incarnation_id, + self._source_device_tensor, device, self._dataset.output_shapes, + self._dataset.output_types, self._dataset.output_classes) + if prefetch_buffer_size > 0: + ds = ds.prefetch(prefetch_buffer_size) + with ops.device(device): + self._device_iterators.append(ds.make_initializable_iterator()) + i += 1 + + device_iterator_initializers = [ + iterator.initializer for iterator in self._device_iterators + ] + self._initializer = control_flow_ops.group(*device_iterator_initializers) + + def get_next(self): + result = [] + i = 0 + for device in self._devices: + with ops.device(device): + result.append(self._device_iterators[i].get_next()) + i += 1 + return result + + @property + def initializer(self): + return self._initializer diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 83095c7ba1c6465d18490e5197f71bf7f1fe2497..3882d4bfdbe899c2ce92f829cb331b32d3d50398 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -234,7 +234,7 @@ def make_tf_record_dataset( Args: file_pattern: List of files or patterns of TFRecord file paths. - See @{tf.gfile.Glob} for pattern rules. + See `tf.gfile.Glob` for pattern rules. batch_size: An int representing the number of records to combine in a single batch. parser_fn: (Optional.) A function accepting string input to parse @@ -286,11 +286,14 @@ def make_tf_record_dataset( dataset = _maybe_shuffle_and_repeat( dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to + # improve the shape inference, because it makes the batch dimension static. + # It is safe to do this because in that case we are repeating the input + # indefinitely, and all batches will be full-sized. + drop_final_batch = drop_final_batch or num_epochs is None + if parser_fn is None: - if drop_final_batch: - dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) - else: - dataset = dataset.batch(batch_size) + dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch) else: # TODO(josh11b): if num_parallel_parser_calls is None, use some function # of num cores instead of map_and_batch's default behavior of one batch. @@ -326,6 +329,7 @@ def make_csv_dataset( num_parallel_parser_calls=2, sloppy=False, num_rows_for_inference=100, + compression_type=None, ): """Reads CSV files into a dataset. @@ -336,7 +340,7 @@ def make_csv_dataset( Args: file_pattern: List of files or patterns of file paths containing CSV - records. See @{tf.gfile.Glob} for pattern rules. + records. See `tf.gfile.Glob` for pattern rules. batch_size: An int representing the number of records to combine in a single batch. column_names: An optional list of strings that corresponds to the CSV @@ -399,6 +403,8 @@ def make_csv_dataset( num_rows_for_inference: Number of rows of a file to use for type inference if record_defaults is not provided. If None, reads all the rows of all the files. Defaults to 100. + compression_type: (Optional.) A `tf.string` scalar evaluating to one of + `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no compression. Returns: A dataset, where each element is a (features, labels) tuple that corresponds @@ -461,7 +467,9 @@ def make_csv_dataset( use_quote_delim=use_quote_delim, na_value=na_value, select_cols=select_columns, - header=header) + header=header, + compression_type=compression_type, + ) def map_fn(*columns): """Organizes columns into a features dictionary. @@ -488,8 +496,13 @@ def make_csv_dataset( dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) # Apply batch before map for perf, because map has high overhead relative - # to the size of the computation in each map - dataset = dataset.batch(batch_size=batch_size) + # to the size of the computation in each map. + # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to + # improve the shape inference, because it makes the batch dimension static. + # It is safe to do this because in that case we are repeating the input + # indefinitely, and all batches will be full-sized. + dataset = dataset.batch(batch_size=batch_size, + drop_remainder=num_epochs is None) dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_parser_calls) dataset = dataset.prefetch(prefetch_buffer_size) @@ -505,6 +518,7 @@ class CsvDataset(dataset_ops.Dataset): def __init__(self, filenames, record_defaults, + compression_type=None, buffer_size=None, header=False, field_delim=",", @@ -540,11 +554,11 @@ class CsvDataset(dataset_ops.Dataset): The expected output of its iterations is: ```python - next = dataset.make_one_shot_iterator().get_next() + next_element = dataset.make_one_shot_iterator().get_next() with tf.Session() as sess: while True: try: - print(sess.run(nxt)) + print(sess.run(next_element)) except tf.errors.OutOfRangeError: break @@ -562,6 +576,9 @@ class CsvDataset(dataset_ops.Dataset): both this and `select_columns` are specified, these must have the same lengths, and `column_defaults` is assumed to be sorted in order of increasing column index. + compression_type: (Optional.) A `tf.string` scalar evaluating to one of + `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no + compression. buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes to buffer while reading files. Defaults to 4MB. header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s) @@ -581,6 +598,11 @@ class CsvDataset(dataset_ops.Dataset): super(CsvDataset, self).__init__() self._filenames = ops.convert_to_tensor( filenames, dtype=dtypes.string, name="filenames") + self._compression_type = convert.optional_param_to_tensor( + "compression_type", + compression_type, + argument_default="", + argument_dtype=dtypes.string) record_defaults = [ constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x for x in record_defaults @@ -621,6 +643,7 @@ class CsvDataset(dataset_ops.Dataset): use_quote_delim=self._use_quote_delim, na_value=self._na_value, select_cols=self._select_cols, + compression_type=self._compression_type, ) @property @@ -757,10 +780,12 @@ def make_batched_features_dataset(file_pattern, dataset = dataset.apply(stats_ops.feature_stats("record_stats")) - if drop_final_batch: - dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) - else: - dataset = dataset.batch(batch_size) + # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to + # improve the shape inference, because it makes the batch dimension static. + # It is safe to do this because in that case we are repeating the input + # indefinitely, and all batches will be full-sized. + dataset = dataset.batch( + batch_size, drop_remainder=drop_final_batch or num_epochs is None) # Parse `Example` tensors to a dictionary of `Feature` tensors. dataset = dataset.map( diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py index 182a5c6ff36fcda8c9e2c522cce07bed0c2daec9..75642f143e19c3d77e675384362c4dab94e10932 100644 --- a/tensorflow/contrib/data/python/ops/resampling.py +++ b/tensorflow/contrib/data/python/ops/resampling.py @@ -50,7 +50,7 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index 67eede981cb8f685ba4840d1f3c12bfea54c7646..6b002b4a533669dd0f5e82a00aa29224a83a7e57 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -22,7 +22,6 @@ import collections from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse -from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import gen_dataset_ops @@ -67,104 +66,45 @@ class _ScanDataset(dataset_ops.Dataset): need_to_rerun = True while need_to_rerun: - # Create a list in which `tf_scan_func` will store the new shapes. - flat_new_state_shapes = [] - - @function.Defun(*(nest.flatten( - sparse.as_dense_types( - self._state_types, self._state_classes)) + nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes)))) - def tf_scan_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the state and input_dataset. - for arg, shape in zip( - args, - nest.flatten( - sparse.as_dense_shapes(self._state_shapes, self._state_classes)) - + nest.flatten( - sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes))): - arg.set_shape(shape) - - pivot = len(nest.flatten(self._state_shapes)) - print(self._state_classes) - nested_state_args = nest.pack_sequence_as(self._state_types, - args[:pivot]) - nested_state_args = sparse.deserialize_sparse_tensors( - nested_state_args, self._state_types, self._state_shapes, - self._state_classes) - print(input_dataset.output_classes) - nested_input_args = nest.pack_sequence_as(input_dataset.output_types, - args[pivot:]) - nested_input_args = sparse.deserialize_sparse_tensors( - nested_input_args, input_dataset.output_types, - input_dataset.output_shapes, input_dataset.output_classes) - - ret = scan_func(nested_state_args, nested_input_args) - if not isinstance(ret, collections.Sequence) or len(ret) != 2: - raise TypeError("The scan function must return a pair comprising the " - "new state and the output value.") - - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - new_state, output_value = ret - - # Extract and validate class information from the returned values. - for t, clazz in zip( - nest.flatten(new_state), nest.flatten(self._state_classes)): - if not isinstance(t, clazz): - raise TypeError( - "The element classes for the new state must match the initial " - "state. Expected %s; got %s." % - (self._state_classes, - nest.pack_sequence_as( - self._state_types, - [type(t) for t in nest.flatten(new_state)]))) - self._output_classes = sparse.get_classes(output_value) - - # Extract shape information from the returned values. - flat_new_state_shapes.extend( - [t.get_shape() for t in nest.flatten(new_state)]) - self._output_shapes = nest.pack_sequence_as( - output_value, [t.get_shape() for t in nest.flatten(output_value)]) - - # Extract and validate type information from the returned values. - for t, dtype in zip( - nest.flatten(new_state), nest.flatten(self._state_types)): - if t.dtype != dtype: - raise TypeError( - "The element types for the new state must match the initial " - "state. Expected %s; got %s." % - (self._state_types, - nest.pack_sequence_as( - self._state_types, - [t.dtype for t in nest.flatten(new_state)]))) - self._output_types = nest.pack_sequence_as( - output_value, [t.dtype for t in nest.flatten(output_value)]) - - dataset_ops._warn_if_collections("tf.contrib.data.scan()") # pylint: disable=protected-access - - # Serialize any sparse tensors. - new_state = nest.pack_sequence_as(new_state, [ - t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state)) - ]) - output_value = nest.pack_sequence_as(output_value, [ - t for t in nest.flatten( - sparse.serialize_sparse_tensors(output_value)) - ]) - return nest.flatten(new_state) + nest.flatten(output_value) - - # Use the private method that will execute `tf_scan_func` but delay - # adding it to the graph in case we need to rerun the function. - tf_scan_func._create_definition_if_needed() # pylint: disable=protected-access + wrapped_func = dataset_ops.StructuredFunctionWrapper( + scan_func, "tf.contrib.data.scan()", + input_classes=(self._state_classes, input_dataset.output_classes), + input_shapes=(self._state_shapes, input_dataset.output_shapes), + input_types=(self._state_types, input_dataset.output_types), + add_to_graph=False) + if not ( + isinstance(wrapped_func.output_types, collections.Sequence) and + len(wrapped_func.output_types) == 2): + raise TypeError("The scan function must return a pair comprising the " + "new state and the output value.") + + new_state_classes, self._output_classes = wrapped_func.output_classes + + # Extract and validate class information from the returned values. + for new_state_class, state_class in zip( + nest.flatten(new_state_classes), + nest.flatten(self._state_classes)): + if not issubclass(new_state_class, state_class): + raise TypeError( + "The element classes for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_classes, new_state_classes)) + + # Extract and validate type information from the returned values. + new_state_types, self._output_types = wrapped_func.output_types + for new_state_type, state_type in zip( + nest.flatten(new_state_types), nest.flatten(self._state_types)): + if new_state_type != state_type: + raise TypeError( + "The element types for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_types, new_state_types)) + + # Extract shape information from the returned values. + new_state_shapes, self._output_shapes = wrapped_func.output_shapes flat_state_shapes = nest.flatten(self._state_shapes) + flat_new_state_shapes = nest.flatten(new_state_shapes) weakened_state_shapes = [ original.most_specific_compatible_shape(new) for original, new in zip(flat_state_shapes, flat_new_state_shapes) @@ -180,12 +120,10 @@ class _ScanDataset(dataset_ops.Dataset): break if need_to_rerun: - # NOTE(mrry): `self._output_shapes` will be overwritten when we rerun - # `tf_scan_func`. self._state_shapes = nest.pack_sequence_as(self._state_shapes, weakened_state_shapes) - self._scan_func = tf_scan_func + self._scan_func = wrapped_func.function self._scan_func.add_to_graph(ops.get_default_graph()) def _as_variant_tensor(self): @@ -213,7 +151,7 @@ class _ScanDataset(dataset_ops.Dataset): def scan(initial_state, scan_func): """A transformation that scans a function across an input dataset. - This transformation is a stateful relative of @{tf.data.Dataset.map}. + This transformation is a stateful relative of `tf.data.Dataset.map`. In addition to mapping `scan_func` across the elements of the input dataset, `scan()` accumulates one or more state tensors, whose initial values are `initial_state`. @@ -228,7 +166,7 @@ def scan(initial_state, scan_func): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): return _ScanDataset(dataset, initial_state, scan_func) diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py index d7f8a73fe3d67bb83e44e962832ce34c116aef66..4356721704046199e8ef2938bde6d7d8bce68cc1 100644 --- a/tensorflow/contrib/data/python/ops/shuffle_ops.py +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -92,11 +92,11 @@ def shuffle_and_repeat(buffer_size, count=None, seed=None): indefinitely. seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random seed that will be used to create the distribution. See - @{tf.set_random_seed} for behavior. + `tf.set_random_seed` for behavior. Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): # pylint: disable=missing-docstring diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index f935beb1a9e85d4901857e7781a5ed8473838fa5..8025dcdd16b0180aeb951a31de21e22b8e8c31c7 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -23,25 +23,29 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.util import deprecation class _SlideDataset(dataset_ops.Dataset): """A `Dataset` that passes a sliding window over its input.""" - def __init__(self, input_dataset, window_size, stride=1): + def __init__(self, input_dataset, window_size, window_shift, window_stride): """See `sliding_window_batch` for details.""" super(_SlideDataset, self).__init__() self._input_dataset = input_dataset self._window_size = ops.convert_to_tensor( - window_size, dtype=dtypes.int64, name="window_size") - self._stride = ops.convert_to_tensor( - stride, dtype=dtypes.int64, name="stride") + window_size, dtype=dtypes.int64, name="window_stride") + self._window_stride = ops.convert_to_tensor( + window_stride, dtype=dtypes.int64, name="window_stride") + self._window_shift = ops.convert_to_tensor( + window_shift, dtype=dtypes.int64, name="window_shift") def _as_variant_tensor(self): return gen_dataset_ops.slide_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access window_size=self._window_size, - stride=self._stride, + window_shift=self._window_shift, + window_stride=self._window_stride, **dataset_ops.flat_structure(self)) @property @@ -61,38 +65,63 @@ class _SlideDataset(dataset_ops.Dataset): return self._input_dataset.output_types -def sliding_window_batch(window_size, stride=1): - """A sliding window with size of `window_size` and step of `stride`. +@deprecation.deprecated_args( + None, "stride is deprecated, use window_shift instead", "stride") +def sliding_window_batch(window_size, + stride=None, + window_shift=None, + window_stride=1): + """A sliding window over a dataset. - This transformation passes a sliding window over this dataset. The - window size is `window_size` and step size is `stride`. If the left - elements cannot fill up the sliding window, this transformation will - drop the final smaller element. For example: + This transformation passes a sliding window over this dataset. The window size + is `window_size`, the stride of the input elements is `window_stride`, and the + shift between consecutive windows is `window_shift`. If the remaining elements + cannot fill up the sliding window, this transformation will drop the final + smaller element. For example: ```python # NOTE: The following examples use `{ ... }` to represent the # contents of a dataset. a = { [1], [2], [3], [4], [5], [6] } - a.apply(tf.contrib.data.sliding_window_batch(window_size=3, stride=2)) == - { - [[1], [2], [3]], - [[3], [4], [5]], - } + a.apply(sliding_window_batch(window_size=3)) == + { [[1], [2], [3]], [[2], [3], [4]], [[3], [4], [5]], [[4], [5], [6]] } + + a.apply(sliding_window_batch(window_size=3, window_shift=2)) == + { [[1], [2], [3]], [[3], [4], [5]] } + + a.apply(sliding_window_batch(window_size=3, window_stride=2)) == + { [[1], [3], [5]], [[2], [4], [6]] } ``` Args: window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of - elements in the sliding window. + elements in the sliding window. It must be positive. stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the - steps moving the sliding window forward for one iteration. The default - is `1`. It must be in `[1, window_size)`. + forward shift of the sliding window in each iteration. The default is `1`. + It must be positive. Deprecated alias for `window_shift`. + window_shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + forward shift of the sliding window in each iteration. The default is `1`. + It must be positive. + window_stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + stride of the input elements in the sliding window. The default is `1`. + It must be positive. Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. + + Raises: + ValueError: if invalid arguments are provided. """ + if stride is None and window_shift is None: + window_shift = 1 + elif stride is not None and window_shift is None: + window_shift = stride + elif stride is not None and window_shift is not None: + raise ValueError("Cannot specify both `stride` and `window_shift`") + def _apply_fn(dataset): - return _SlideDataset(dataset, window_size, stride) + return _SlideDataset(dataset, window_size, window_shift, window_stride) return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py index 3c82a03df1745d855b2d3f918f7bbde113600556..3b4e98140234af0bf2128ac32f95dbdbf183cb54 100644 --- a/tensorflow/contrib/data/python/ops/stats_ops.py +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -23,11 +23,13 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. class StatsAggregator(object): """A stateful resource that aggregates statistics from one or more iterators. To record statistics, use one of the custom transformation functions defined - in this module when defining your @{tf.data.Dataset}. All statistics will be + in this module when defining your `tf.data.Dataset`. All statistics will be aggregated by the `StatsAggregator` that is associated with a particular iterator (see below). For example, to record the total number of bytes produced by iterating over a dataset: @@ -37,7 +39,7 @@ class StatsAggregator(object): dataset = dataset.apply(stats_ops.bytes_produced_stats("total_bytes")) ``` - To associate a `StatsAggregator` with a @{tf.data.Iterator} object, use + To associate a `StatsAggregator` with a `tf.data.Iterator` object, use the following pattern: ```python @@ -53,7 +55,7 @@ class StatsAggregator(object): To get a protocol buffer summary of the currently aggregated statistics, use the `StatsAggregator.get_summary()` tensor. The easiest way to do this - is to add the returned tensor to the @{tf.GraphKeys.SUMMARIES} collection, + is to add the returned tensor to the `tf.GraphKeys.SUMMARIES` collection, so that the summaries will be included with any existing summaries. ```python @@ -72,13 +74,13 @@ class StatsAggregator(object): self._resource = gen_dataset_ops.stats_aggregator_handle() def get_summary(self): - """Returns a string @{tf.Tensor} that summarizes the aggregated statistics. + """Returns a string `tf.Tensor` that summarizes the aggregated statistics. - The returned tensor will contain a serialized @{tf.summary.Summary} protocol + The returned tensor will contain a serialized `tf.summary.Summary` protocol buffer, which can be used with the standard TensorBoard logging facilities. Returns: - A scalar string @{tf.Tensor} that summarizes the aggregated statistics. + A scalar string `tf.Tensor` that summarizes the aggregated statistics. """ return gen_dataset_ops.stats_aggregator_summary(self._resource) @@ -110,7 +112,8 @@ class _SetStatsAggregatorDataset(dataset_ops.Dataset): return self._input_dataset.output_classes -# TODO(shivaniagrawal): Expose these methods in `tf.contrib.data`. +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. def set_stats_aggregator(stats_aggregator): """Set the given stats_aggregator for aggregating the input dataset stats. @@ -119,7 +122,7 @@ def set_stats_aggregator(stats_aggregator): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): @@ -128,6 +131,8 @@ def set_stats_aggregator(stats_aggregator): return _apply_fn +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. def bytes_produced_stats(tag): """Records the number of bytes produced by each element of the input dataset. @@ -140,7 +145,7 @@ def bytes_produced_stats(tag): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): @@ -150,6 +155,8 @@ def bytes_produced_stats(tag): return _apply_fn +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. def latency_stats(tag): """Records the latency of producing each element of the input dataset. @@ -162,7 +169,7 @@ def latency_stats(tag): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): @@ -171,6 +178,8 @@ def latency_stats(tag): return _apply_fn +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. def feature_stats(tag): """Records the features stats from `Example` records of the input dataset. @@ -183,7 +192,7 @@ def feature_stats(tag): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py index bb49604d4de90d726418684124608438aa33e6cf..dc67accdcfbc2692cbe0c961521897a316f40647 100644 --- a/tensorflow/contrib/data/python/ops/threadpool.py +++ b/tensorflow/contrib/data/python/ops/threadpool.py @@ -37,22 +37,28 @@ def _generate_shared_name(prefix): return "{}{}".format(prefix, uid) +# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. class PrivateThreadPool(object): """A stateful resource that represents a private thread pool.""" - def __init__(self, num_threads, display_name=None): + def __init__(self, num_threads, display_name=None, + max_intra_op_parallelism=1): """Creates a `PrivateThreadPool` with the given number of threads.""" if context.executing_eagerly(): shared_name = _generate_shared_name("privatethreadpool") self._resource = gen_dataset_ops.thread_pool_handle( num_threads=num_threads, + max_intra_op_parallelism=max_intra_op_parallelism, display_name=display_name, shared_name=shared_name) self._resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._resource, handle_device=context.context().device_name) else: self._resource = gen_dataset_ops.thread_pool_handle( - num_threads=num_threads, display_name=display_name) + num_threads=num_threads, + max_intra_op_parallelism=max_intra_op_parallelism, + display_name=display_name) class _ThreadPoolDataset(dataset_ops.Dataset): @@ -82,6 +88,8 @@ class _ThreadPoolDataset(dataset_ops.Dataset): return self._input_dataset.output_classes +# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. def override_threadpool(dataset, thread_pool): """Returns a new dataset that uses the given thread pool for its operations. @@ -92,6 +100,6 @@ def override_threadpool(dataset, thread_pool): Returns: A dataset containing the same values as `dataset`, but which uses `thread_pool` to compute any of its parallel operations (such as - @{tf.data.Dataset.map}). + `tf.data.Dataset.map`). """ return _ThreadPoolDataset(dataset, thread_pool) diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py index 4ce6ddede8350735636fd152fdc9df0319265990..e0d606311c4f2f678970113c1faa578dbf44b2ba 100644 --- a/tensorflow/contrib/data/python/ops/unique.py +++ b/tensorflow/contrib/data/python/ops/unique.py @@ -38,21 +38,21 @@ def unique(): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): - return UniqueDataset(dataset) + return _UniqueDataset(dataset) return _apply_fn -class UniqueDataset(dataset_ops.Dataset): +class _UniqueDataset(dataset_ops.Dataset): """A `Dataset` contains the unique elements from its input.""" def __init__(self, input_dataset): """See `unique()` for details.""" - super(UniqueDataset, self).__init__() + super(_UniqueDataset, self).__init__() self._input_dataset = input_dataset if input_dataset.output_types not in (dtypes.int32, dtypes.int64, dtypes.string): diff --git a/tensorflow/contrib/data/python/ops/writers.py b/tensorflow/contrib/data/python/ops/writers.py index f53bd3f7383950d6cfdb35e12811fb1daf24b320..c455fdcba673853079ff0d162c4799e72bc8e627 100644 --- a/tensorflow/contrib/data/python/ops/writers.py +++ b/tensorflow/contrib/data/python/ops/writers.py @@ -38,13 +38,13 @@ class TFRecordWriter(object): argument_dtype=dtypes.string) def write(self, dataset): - """Returns a @{tf.Operation} to write a dataset to a file. + """Returns a `tf.Operation` to write a dataset to a file. Args: - dataset: a @{tf.data.Dataset} whose elements are to be written to a file + dataset: a `tf.data.Dataset` whose elements are to be written to a file Returns: - A @{tf.Operation} that, when run, writes contents of `dataset` to a file. + A `tf.Operation` that, when run, writes contents of `dataset` to a file. """ if not isinstance(dataset, dataset_ops.Dataset): raise TypeError("`dataset` must be a `tf.data.Dataset` object.") diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD index 74b2cd90a187159fd2da8ce236c14e813cc43c49..d3628d480d31017f835b39f750df40cafa2cc0db 100644 --- a/tensorflow/contrib/distribute/BUILD +++ b/tensorflow/contrib/distribute/BUILD @@ -25,11 +25,15 @@ py_library( srcs = ["__init__.py"], visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/contrib/distribute/python:collective_all_reduce_strategy", "//tensorflow/contrib/distribute/python:cross_tower_ops", "//tensorflow/contrib/distribute/python:mirrored_strategy", "//tensorflow/contrib/distribute/python:monitor", + "//tensorflow/contrib/distribute/python:multi_worker_strategy", "//tensorflow/contrib/distribute/python:one_device_strategy", + "//tensorflow/contrib/distribute/python:parameter_server_strategy", "//tensorflow/contrib/distribute/python:step_fn", + "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:training", "//tensorflow/python:util", ], diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index 44a4481021c380e72b535cf0aca39df2bf04d3b7..2f5dd10550d0771d0cd3c2501d0456dc95077386 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -116,8 +116,6 @@ in the input function gives a solid boost in performance. When using ## Caveats This feature is in early stages and there are a lot of improvements forthcoming: -* Metrics are not yet supported during distributed training. They are still -supported during the evaluation. * Summaries are only computed in the first tower in `MirroredStrategy`. * Evaluation is not yet distributed. * Eager support is in the works; performance can be more challenging with eager diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index 76711baf3a11c8978fbb5770ec173ff74a153158..2c93ce92cebff0a7c907093eae5f63470d34af1d 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -19,34 +19,44 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import CollectiveAllReduceStrategy from tensorflow.contrib.distribute.python.cross_tower_ops import * from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy from tensorflow.contrib.distribute.python.monitor import Monitor +from tensorflow.contrib.distribute.python.multi_worker_strategy import MultiWorkerMirroredStrategy from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy +from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.step_fn import * +from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy from tensorflow.python.training.distribute import * +from tensorflow.python.training.distribution_strategy_context import * from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'AllReduceCrossTowerOps', + 'CollectiveAllReduceStrategy', 'CrossTowerOps', 'DistributionStrategy', 'MirroredStrategy', + 'MultiWorkerMirroredStrategy', 'Monitor', 'OneDeviceStrategy', + 'ParameterServerStrategy', 'ReductionToOneDeviceCrossTowerOps', 'Step', 'StandardInputStep', 'StandardSingleLossStep', 'TowerContext', + 'TPUStrategy', 'get_cross_tower_context', 'get_distribution_strategy', 'get_loss_reduction', 'get_tower_context', 'has_distribution_strategy', 'require_tower_context', + 'UpdateContext', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 9dfb8552f1b0f058b44f8ed09c2ed681367293d5..40a1c1707cfdeaf5f5097ce661fa5f0613f804d0 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -57,7 +57,7 @@ cuda_py_test( "//tensorflow/python/eager:context", "//tensorflow/python:device_util", "//tensorflow/python/eager:test", - "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:estimator_py", ], tags = [ "no_pip", @@ -100,6 +100,23 @@ py_library( ], ) +py_library( + name = "parameter_server_strategy", + srcs = ["parameter_server_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":cross_tower_ops", + ":mirrored_strategy", + ":values", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + ], +) + py_library( name = "one_device_strategy", srcs = ["one_device_strategy.py"], @@ -116,6 +133,24 @@ py_library( ], ) +py_library( + name = "collective_all_reduce_strategy", + srcs = ["collective_all_reduce_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":cross_tower_ops", + ":cross_tower_utils", + ":mirrored_strategy", + ":values", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:collective_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/eager:context", + ], +) + py_library( name = "strategy_test_lib", testonly = 1, @@ -152,6 +187,7 @@ py_library( ":multi_worker_strategy", ":one_device_strategy", ":tpu_strategy", + "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/optimizer_v2:training", "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", @@ -207,6 +243,35 @@ py_test( ], ) +py_test( + name = "parameter_server_strategy_test", + srcs = ["parameter_server_strategy_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":combinations", + ":multi_worker_test_base", + ":parameter_server_strategy", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:layers", + "//tensorflow/python:session", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/estimator:estimator_py", + "@absl_py//absl/testing:parameterized", + ], +) + cuda_py_test( name = "mirrored_strategy_multigpu_test", srcs = ["mirrored_strategy_multigpu_test.py"], @@ -247,11 +312,11 @@ py_library( ], deps = [ "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", "//tensorflow/python:distributed_framework_test_lib", - "//tensorflow/python:platform", "//tensorflow/python:session", - "//tensorflow/python:training", - "//tensorflow/python/eager:test", + "//tensorflow/python/estimator:estimator_py", + "//third_party/py/numpy", ], ) @@ -272,8 +337,7 @@ py_library( deps = [ ":one_device_strategy", ":values", - "//tensorflow/contrib/tpu", - "//tensorflow/contrib/tpu:tpu_py", + "//tensorflow/contrib/tpu:tpu_lib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", @@ -281,6 +345,37 @@ py_library( ], ) +py_test( + name = "collective_all_reduce_strategy_test", + srcs = ["collective_all_reduce_strategy_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":collective_all_reduce_strategy", + ":combinations", + ":cross_tower_utils", + ":multi_worker_test_base", + ":strategy_test_lib", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/estimator:estimator_py", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + py_library( name = "minimize_loss_test_lib", testonly = 1, @@ -345,11 +440,7 @@ cuda_py_test( "//tensorflow/contrib/optimizer_v2:training", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/eager:test", - "//tensorflow/python/estimator:dnn_linear_combined", - "//tensorflow/python/estimator:export_export", - "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/estimator:run_config", + "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/feature_column", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", @@ -375,17 +466,27 @@ py_library( ], ) -cuda_py_test( - name = "step_fn_test", +py_library( + name = "step_fn_test_lib", + testonly = 1, srcs = ["step_fn_test.py"], - additional_deps = [ - ":single_loss_example", + deps = [ ":combinations", - "@absl_py//absl/testing:parameterized", - "//third_party/py/numpy", + ":single_loss_example", + "//tensorflow/contrib/tpu:tpu_lib", "//tensorflow/python:variables", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +cuda_py_test( + name = "step_fn_test", + srcs = ["step_fn_test.py"], + additional_deps = [ + ":step_fn_test_lib", ], tags = [ "multi_and_single_gpu", @@ -451,8 +552,11 @@ py_library( "//tensorflow/contrib/all_reduce:all_reduce_py", "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/python:array_ops", + "//tensorflow/python:collective_ops", + "//tensorflow/python:device", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", "//tensorflow/python:math_ops", ], ) @@ -487,7 +591,9 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", + "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", "@six_archive//:six", ], @@ -495,6 +601,7 @@ py_library( cuda_py_test( name = "cross_tower_ops_test", + size = "large", srcs = ["cross_tower_ops_test.py"], additional_deps = [ ":combinations", @@ -509,7 +616,6 @@ cuda_py_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], - shard_count = 15, tags = [ "multi_and_single_gpu", "no_pip", @@ -581,13 +687,12 @@ cuda_py_test( "//tensorflow/contrib/distribute/python:mirrored_strategy", "//tensorflow/python:client_testlib", "//tensorflow/python:training", - "//tensorflow/python/estimator:keras", - "//tensorflow/python/estimator:run_config", + "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/keras", ], tags = [ "multi_and_single_gpu", - "noguitar", + "no_windows_gpu", "notsan", ], ) @@ -610,3 +715,40 @@ cuda_py_test( "no_pip", ], ) + +cuda_py_test( + name = "warm_starting_util_test", + size = "medium", + srcs = ["warm_starting_util_test.py"], + additional_deps = [ + ":combinations", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +cuda_py_test( + name = "checkpoint_utils_test", + size = "medium", + srcs = ["checkpoint_utils_test.py"], + additional_deps = [ + ":combinations", + "//tensorflow/python:client_testlib", + "//tensorflow/python:checkpoint_utils_test", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bcb977f64073b1d15ef5c872eb0d6b09d5307b54 --- /dev/null +++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py @@ -0,0 +1,78 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for checkpoint_utils.init_from_checkpoint with Distribution Strategy. + +These tests are located here instead of as part of +`python.training.CheckpointsTest` because they need access to distribution +strategies which are only present in contrib right now. +TODO(priyag): Move the tests to core `python.training.CheckpointsTest` when +distribution strategy moves out of contrib. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.framework import ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import checkpoint_utils_test + + +class CheckpointUtilsWithDistributionStrategyTest( + test.TestCase, parameterized.TestCase): + + @combinations.generate(combinations.combine( + distribution=[combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus], + in_tower_mode=[True, False], + mode=["graph"])) + def testInitFromCheckpoint(self, distribution, in_tower_mode): + checkpoint_dir = self.get_temp_dir() + with self.test_session() as session: + v1_value, v2_value, _, _ = checkpoint_utils_test._create_checkpoints( + session, checkpoint_dir) + + def init_and_verify(g): + v1 = variable_scope.get_variable("new_var1", [1, 10]) + v2 = variable_scope.get_variable( + "new_var2", [10, 10], + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.MEAN) + checkpoint_utils.init_from_checkpoint(checkpoint_dir, { + "var1": "new_var1", + "var2": "new_var2" + }) + with self.test_session(graph=g) as session: + session.run(variables.global_variables_initializer()) + self.assertAllEqual(v1_value, self.evaluate(v1)) + self.assertAllEqual(v2_value, self.evaluate(v2)) + + with ops.Graph().as_default() as g, distribution.scope(): + if in_tower_mode: + distribution.call_for_each_tower(init_and_verify, g) + else: + init_and_verify(g) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..9afcaecf78844b011a9dbc30bb95fa3bfeda8470 --- /dev/null +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -0,0 +1,205 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class CollectiveAllReduceStrategy implementing DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import os + +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import values +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import collective_ops +from tensorflow.python.training import server_lib + + +# TODO(yuefengz): move this function to a common util file. +def _normalize_cluster_spec(cluster_spec): + if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)): + return server_lib.ClusterSpec(cluster_spec) + elif not isinstance(cluster_spec, server_lib.ClusterSpec): + raise ValueError( + "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " + "`tf.train.ClusterDef` object") + return cluster_spec + + +# TODO(yuefengz): shard the dataset. +# TODO(yuefengz): support in-graph replication. +# TODO(yuefengz): it only works with a cluster without a chief node, maybe +# support chief node? +class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): + """Distribution strategy that uses collective ops for all-reduce. + + It is similar to the MirroredStrategy but it uses collective ops for + reduction. It currently only works for between-graph replication and its + reduction will reduce across all workers. + """ + + def __init__(self, + num_gpus_per_worker=0, + cluster_spec=None, + task_type="worker", + task_id=0): + """Initializes the object. + + Args: + num_gpus_per_worker: number of local GPUs or GPUs per worker. + cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the + cluster configurations. + task_type: the current task type, such as "worker". + task_id: the current task id. + + Raises: + ValueError: if `task_type` is not in the `cluster_spec`. + """ + self._num_gpus_per_worker = num_gpus_per_worker + self._initialize(cluster_spec, task_type, task_id) + + def _initialize(self, cluster_spec, task_type, task_id): + if task_type not in ["chief", "worker"]: + raise ValueError( + "Unrecognized task_type: %r, valid task types are: \"chief\", " + "\"worker\"." % task_type) + if cluster_spec: + self._cluster_spec = _normalize_cluster_spec(cluster_spec) + worker_device = "/job:%s/task:%d" % (task_type, task_id) + num_workers = len(self._cluster_spec.as_dict().get(task_type, [])) + if "chief" in self._cluster_spec.as_dict(): + num_workers += 1 + if not num_workers: + raise ValueError("`task_type` shoud be in `cluster_spec`.") + + # TODO(yuefengz): create a utility to infer chief. + if "chief" in self._cluster_spec.as_dict() and task_type == "chief": + assert task_id == 0 + self._is_chief = True + else: + assert task_type == "worker" + self._is_chief = task_id == 0 + else: + self._cluster_spec = None + self._is_chief = True + worker_device = "" + num_workers = 1 + self._num_workers = num_workers + + if self._num_gpus_per_worker: + local_devices = [ + "%s/device:GPU:%d" % (worker_device, i) + for i in range(self._num_gpus_per_worker) + ] + else: + local_devices = [worker_device] + + self._collective_keys = cross_tower_utils.CollectiveKeys() + super(CollectiveAllReduceStrategy, self).__init__( + devices=local_devices, + cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce( + num_workers=num_workers, + num_gpus_per_worker=self._num_gpus_per_worker, + collective_keys=self._collective_keys)) + + # Add a default device so that ops without specified devices will not end up + # on other workers. + if cluster_spec: + self._default_device = "/job:%s/replica:0/task:%d" % (task_type, task_id) + + def _create_variable(self, next_creator, *args, **kwargs): + colocate_with = kwargs.pop("colocate_with", None) + devices = self._get_devices_from(colocate_with) + group_size = len(devices) * self._num_workers + group_key = self._collective_keys.get_group_key(self._devices) + + def _real_mirrored_creator(devices, *args, **kwargs): + """Creates one MirroredVariable on the current worker.""" + index = {} + collective_instance_key = self._collective_keys.get_instance_key( + key_id=kwargs["name"]) + if "initial_value" not in kwargs: + raise ValueError("Initial value must be specified.") + initial_value = kwargs["initial_value"] + if callable(initial_value): + initial_value_fn = initial_value + else: + initial_value_fn = lambda: initial_value + + for i, d in enumerate(devices): + with ops.device(d): + if i > 0: + # Give replicas meaningful distinct names: + var0name = index[devices[0]].name.split(":")[0] + # We append a / to variable names created on towers with id > 0 to + # ensure that we ignore the name scope and instead use the given + # name as the absolute name of the variable. + kwargs["name"] = "%s/replica_%d/" % (var0name, i) + + # The initial value fn makes sure variables all initialized to + # same values. The first device of the chief worker will send their + # variable values to other devices and other workers. + def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring + with ops.device(device): + initial_value = initial_value_fn() + assert not callable(initial_value) + initial_value = ops.convert_to_tensor(initial_value) + + if self._is_chief and index == 0: + bcast_send = collective_ops.broadcast_send( + initial_value, initial_value.shape, initial_value.dtype, + group_size, group_key, collective_instance_key) + with ops.control_dependencies([bcast_send]): + return array_ops.identity(initial_value) + else: + return collective_ops.broadcast_recv( + initial_value.shape, initial_value.dtype, group_size, + group_key, collective_instance_key) + + kwargs["initial_value"] = _overridden_initial_value_fn + + with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): + v = next_creator(*args, **kwargs) + + assert not isinstance(v, values.DistributedVariable) + index[d] = v + return index + + # pylint: disable=protected-access + return mirrored_strategy._create_mirrored_variable( + devices, _real_mirrored_creator, *args, **kwargs) + + def configure(self, session_config=None): + # Use TF_CONFIG to get the cluster spec and the current job. + if not self._cluster_spec: + tf_config = json.loads(os.environ.get("TF_CONFIG", "{}")) + cluster_spec = _normalize_cluster_spec(tf_config.get("cluster", {})) + + task_env = tf_config.get("task", {}) + if task_env: + task_type = task_env.get("type", "worker") + task_id = int(task_env.get("index", "0")) + else: + task_type = "worker" + task_id = 0 + + if cluster_spec: + self._initialize(cluster_spec, task_type, task_id) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e54e3b7d7156e87731e6f79aa66262d127232c --- /dev/null +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -0,0 +1,217 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for CollectiveAllReduceStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.distribute.python import collective_all_reduce_strategy +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.eager import context +from tensorflow.python.estimator import run_config +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class DistributedCollectiveAllReduceStrategyTest( + multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase): + + collective_key_base = 0 + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers.""" + cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=0) + cls._cluster_spec = { + run_config.TaskType.WORKER: [ + 'fake_worker_0', 'fake_worker_1', 'fake_worker_2' + ] + } + + def setUp(self): + self._run_options = config_pb2.RunOptions() + self._run_options.experimental.collective_graph_key = 6 + + self._sess_config = config_pb2.ConfigProto() + self._sess_config.experimental.collective_group_leader = ( + '/job:worker/replica:0/task:0') + + # We use a different key_base for each test so that collective keys won't be + # reused. + # TODO(yuefengz, tucker): enable it to reuse collective keys in different + # tests. + DistributedCollectiveAllReduceStrategyTest.collective_key_base += 100000 + super(DistributedCollectiveAllReduceStrategyTest, self).setUp() + + def _get_test_object(self, task_type, task_id, num_gpus=0): + distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus, + cluster_spec=self._cluster_spec, + task_type=task_type, + task_id=task_id) + collective_keys = cross_tower_utils.CollectiveKeys( + group_key_start=10 * num_gpus + + DistributedCollectiveAllReduceStrategyTest.collective_key_base, + instance_key_start=num_gpus * 100 + + DistributedCollectiveAllReduceStrategyTest.collective_key_base, + instance_key_with_id_start=num_gpus * 10000 + + DistributedCollectiveAllReduceStrategyTest.collective_key_base) + distribution._collective_keys = collective_keys + distribution._cross_tower_ops._collective_keys = collective_keys + return distribution, self._workers[task_id].target + + def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): + d, master_target = self._get_test_object(task_type, task_id, num_gpus) + with ops.Graph().as_default(), \ + self.test_session(config=self._sess_config, + target=master_target) as sess, \ + d.scope(): + l = core.Dense(1, use_bias=False, name='gpu_%d' % d._num_gpus_per_worker) + + def loss_fn(x): + y = array_ops.reshape(l(x), []) - constant_op.constant(1.) + return y * y + + # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for + # multiple graphs (b/111216820). + def grad_fn(x): + loss = loss_fn(x) + var_list = ( + variables.trainable_variables() + ops.get_collection( + ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) + grads = gradients.gradients(loss, var_list) + ret = list(zip(grads, var_list)) + return ret + + def update(v, g): + return v.assign_sub(0.05 * g, use_locking=True) + + one = d.broadcast(constant_op.constant([[1.]])) + + def step(): + """Perform one optimization step.""" + # Run forward & backward to get gradients, variables list. + g_v = d.call_for_each_tower(grad_fn, one) + # Update the variables using the gradients and the update() function. + before_list = [] + after_list = [] + for g, v in g_v: + fetched = d.read_var(v) + before_list.append(fetched) + with ops.control_dependencies([fetched]): + # TODO(yuefengz): support non-Mirrored variable as destinations. + g = d.reduce( + variable_scope.VariableAggregation.SUM, g, destinations=v) + with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + after_list.append(d.read_var(v)) + return before_list, after_list + + before_out, after_out = step() + + if context.num_gpus() < d._num_gpus_per_worker: + return True + + sess.run( + variables.global_variables_initializer(), options=self._run_options) + + for i in range(10): + b, a = sess.run((before_out, after_out), options=self._run_options) + if i == 0: + before, = b + after, = a + + error_before = abs(before - 1) + error_after = abs(after - 1) + # Error should go down + self.assertLess(error_after, error_before) + return error_after < error_before + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testMinimizeLossGraph(self, num_gpus): + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) + + def _test_variable_initialization(self, task_type, task_id, num_gpus): + distribution, master_target = self._get_test_object(task_type, task_id, + num_gpus) + with ops.Graph().as_default(), \ + self.test_session(config=self._sess_config, + target=master_target) as sess, \ + distribution.scope(): + + def model_fn(): + x = variable_scope.get_variable( + 'x', + shape=(2, 3), + initializer=init_ops.random_uniform_initializer( + 1.0, 10.0, dtype=dtypes.float32)) + return array_ops.identity(x) + + x = distribution.call_for_each_tower(model_fn) + reduced_x = distribution.unwrap( + distribution.reduce( + variable_scope.VariableAggregation.MEAN, x, + destinations='/cpu:0'))[0] + + sess.run( + variables.global_variables_initializer(), options=self._run_options) + x_value, reduced_x_value = sess.run( + [x, reduced_x], options=self._run_options) + self.assertTrue(np.array_equal(x_value, reduced_x_value)) + return np.array_equal(x_value, reduced_x_value) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testVariableInitialization(self, num_gpus): + if context.num_gpus() < num_gpus: + return + self._run_between_graph_clients( + self._test_variable_initialization, + self._cluster_spec, + num_gpus=num_gpus) + + +class LocalCollectiveAllReduceStrategy(strategy_test_lib.DistributionTestBase, + parameterized.TestCase): + + def testMinimizeLossGraph(self, num_gpus=2): + # Collective ops doesn't support strategy with one device. + if context.num_gpus() < num_gpus: + return + distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus) + self._test_minimize_loss_graph(distribution) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index ba03b14deb9a3897dae29382ce601c0319f84735..aeec9c44d723cb4eedb6e1abc4c6fbcd64f14481 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -46,6 +46,7 @@ import unittest from absl.testing import parameterized import six +from tensorflow.contrib.cluster_resolver import TPUClusterResolver from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib from tensorflow.contrib.distribute.python import multi_worker_strategy from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib @@ -55,7 +56,7 @@ from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.training import adam -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import gradient_descent from tensorflow.python.util import tf_inspect @@ -144,7 +145,7 @@ def _augment_with_special_arguments(test_method): """A wrapped test method that treats some arguments in a special way.""" mode = kwargs.pop("mode", "graph") - distribution = kwargs.pop("distribution", None) + distribution = kwargs.get("distribution", None) required_tpu = kwargs.pop("required_tpu", False) required_gpus = kwargs.pop("required_gpus", None) @@ -153,7 +154,6 @@ def _augment_with_special_arguments(test_method): "Do not use `required_gpus` and `distribution` together.") assert required_tpu is False, ( "Do not use `required_tpu` and `distribution` together.") - kwargs["distribution"] = distribution.strategy required_gpus = distribution.required_gpus required_tpu = distribution.required_tpu @@ -189,9 +189,13 @@ def _augment_with_special_arguments(test_method): if mode == "eager": with ops.Graph().as_default(), context.eager_mode(): + if distribution: + kwargs_to_pass["distribution"] = distribution.strategy test_method(**kwargs_to_pass) elif mode == "graph": with ops.Graph().as_default(), context.graph_mode(): + if distribution: + kwargs_to_pass["distribution"] = distribution.strategy test_method(**kwargs_to_pass) else: raise ValueError( @@ -316,16 +320,15 @@ class NamedDistribution(object): # pylint: disable=g-long-lambda default_strategy = NamedDistribution( "Default", - lambda: distribute_lib._default_distribution_strategy, # pylint: disable=protected-access + distribution_strategy_context._get_default_distribution_strategy, # pylint: disable=protected-access required_gpus=None) one_device_strategy = NamedDistribution( "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), required_gpus=None) -tpu_strategy_single_iteration = NamedDistribution( - "TPUSingleIteration", - lambda: tpu_lib.TPUStrategy(iterations_per_step=1), +tpu_strategy = NamedDistribution( + "TPU", lambda: tpu_lib.TPUStrategy( + TPUClusterResolver(""), steps_per_run=5), required_tpu=True) -tpu_strategy = NamedDistribution("TPU", tpu_lib.TPUStrategy, required_tpu=True) # Note that we disable prefetching for testing since prefetching makes # the input non-deterministic. mirrored_strategy_with_gpu_and_cpu = NamedDistribution( @@ -371,12 +374,14 @@ adam_optimizer_v1_fn = NamedObject( "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1)) gradient_descent_optimizer_v1_fn = NamedObject( "GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2)) +optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn] adam_optimizer_v2_fn = NamedObject( "AdamV2", lambda: adam_v2.AdamOptimizer(0.2, epsilon=1)) gradient_descent_optimizer_v2_fn = NamedObject( "GradientDescentV2", lambda: gradient_descent_v2.GradientDescentOptimizer(0.2)) +optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn] graph_and_eager_modes = ["graph", "eager"] @@ -388,7 +393,7 @@ def distributions_and_v1_optimizers(): one_device_strategy, mirrored_strategy_with_gpu_and_cpu, mirrored_strategy_with_two_gpus ], - optimizer_fn=[adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn]) + optimizer_fn=optimizers_v1) def distributions_and_v2_optimizers(): @@ -398,4 +403,4 @@ def distributions_and_v2_optimizers(): one_device_strategy, mirrored_strategy_with_gpu_and_cpu, mirrored_strategy_with_two_gpus ], - optimizer_fn=[adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn]) + optimizer_fn=optimizers_v2) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py index f8ae8b9712c392fa948c8598dd123cdea01d9866..3a7addf2215d403cd94601f143d16a18d92b65af 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -28,17 +28,37 @@ from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import device_util -def _validate_destinations(destinations): - if not isinstance(destinations, - (value_lib.DistributedValues, six.string_types, list)): +def check_destinations(destinations): + """Checks whether `destinations` is not None and not empty. + + Args: + destinations: a DistributedValues, Variable, string or a list of strings. + + Returns: + Boolean indicating whether `destinations` is not None and not empty. + """ + # Calling bool() on a ResourceVariable is not allowed. + if isinstance(destinations, resource_variable_ops.ResourceVariable): + return bool(destinations.device) + return bool(destinations) + + +def validate_destinations(destinations): + if not isinstance( + destinations, + (value_lib.DistributedValues, resource_variable_ops.ResourceVariable, + six.string_types, list)): raise ValueError("destinations must be one of a `DistributedValues` object," - " a device string, a list of device strings or None") + " a tf.Variable object, a device string, a list of device " + "strings or None") - if not destinations: + if not check_destinations(destinations): raise ValueError("destinations can not be empty") @@ -55,9 +75,11 @@ def _validate_value_destination_pairs(value_destination_pairs): # TODO(yuefengz): consider calling this function in the caller of CrossTowerOps. -def _get_devices_from(destinations): +def get_devices_from(destinations): if isinstance(destinations, value_lib.DistributedValues): return list(destinations.devices) + elif isinstance(destinations, resource_variable_ops.ResourceVariable): + return [destinations.device] elif isinstance(destinations, six.string_types): return [device_util.resolve(destinations)] else: @@ -65,7 +87,7 @@ def _get_devices_from(destinations): def _devices_match(left, right): - return set(_get_devices_from(left)) == set(_get_devices_from(right)) + return set(get_devices_from(left)) == set(get_devices_from(right)) def _all_devices_match(value_destination_pairs): @@ -80,7 +102,7 @@ def _all_devices_match(value_destination_pairs): def _simple_broadcast(value, destinations): index = {} - devices = _get_devices_from(destinations) + devices = get_devices_from(destinations) for d in devices: index[d] = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( value, d) @@ -88,7 +110,7 @@ def _simple_broadcast(value, destinations): def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, - method_string): + aggregation): # pylint: disable=g-missing-docstring all_values = [] count = 0 @@ -112,11 +134,12 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices( all_values, accumulation_fn) - if method_string == "mean": + if aggregation == vs.VariableAggregation.MEAN: reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices( reduced, count) - elif method_string != "sum": - raise ValueError("`method_string` must be 'sum' or 'mean'") + elif aggregation != vs.VariableAggregation.SUM: + raise ValueError("`aggregation` must be VariableAggregation.SUM " + "or VariableAggregation.MEAN.") return reduced @@ -126,14 +149,15 @@ class CrossTowerOps(object): def __init__(self): pass - def reduce(self, method_string, per_device_value, destinations=None): + def reduce(self, aggregation, per_device_value, destinations=None): """Reduce `per_device_value` to `destinations`. - It runs the reduction operation defined by `method_string` and put the + It runs the reduction operation defined by `aggregation` and put the result on `destinations`. Args: - method_string: either 'sum' or 'mean' specifying the reduction method. + aggregation: Indicates how a variable will be aggregated. Accepted values + are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. per_device_value: a PerDevice object. destinations: the reduction destinations. @@ -146,17 +170,18 @@ class CrossTowerOps(object): if not isinstance(per_device_value, value_lib.PerDevice): raise ValueError("`per_device_value` must be a `PerDevice` object.") if destinations is not None: - _validate_destinations(destinations) - return self._reduce(method_string, per_device_value, destinations) + validate_destinations(destinations) + return self._reduce(aggregation, per_device_value, destinations) - def batch_reduce(self, method_string, value_destination_pairs): + def batch_reduce(self, aggregation, value_destination_pairs): """Reduce PerDevice objects in a batch. Reduce each first element in `value_destination_pairs` to each second element which indicates the destinations. Args: - method_string: either 'sum' or 'mean' specifying the reduction method. + aggregation: Indicates how a variable will be aggregated. Accepted values + are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. value_destination_pairs: a list or a tuple of tuples of PerDevice objects and destinations. If a destination is None, then the destinations are set to match the devices of the input PerDevice object. @@ -173,9 +198,9 @@ class CrossTowerOps(object): "tuples of PerDevice objects and destinations") for _, d in value_destination_pairs: if d is not None: - _validate_destinations(d) + validate_destinations(d) - return self._batch_reduce(method_string, value_destination_pairs) + return self._batch_reduce(aggregation, value_destination_pairs) def broadcast(self, tensor, destinations): """Broadcast the `tensor` to destinations. @@ -187,14 +212,14 @@ class CrossTowerOps(object): Returns: a Mirrored object. """ - _validate_destinations(destinations) + validate_destinations(destinations) return self._broadcast(tensor, destinations) - def _reduce(self, method_string, per_device_value, destinations): + def _reduce(self, aggregation, per_device_value, destinations): raise NotImplementedError( "_reduce method must be implemented in descendants.") - def _batch_reduce(self, method_string, value_destination_pairs): + def _batch_reduce(self, aggregation, value_destination_pairs): raise NotImplementedError( "_batch_reduce method must be implemented in descendants.") @@ -220,16 +245,21 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps): self.accumulation_fn = accumulation_fn super(ReductionToOneDeviceCrossTowerOps, self).__init__() - def _reduce(self, method_string, per_device_value, destinations): - devices = _get_devices_from(destinations or per_device_value) + def _reduce(self, aggregation, per_device_value, destinations): + if check_destinations(destinations): + devices = get_devices_from(destinations) + else: + devices = get_devices_from(per_device_value) reduce_to_device = self.reduce_to_device or devices[0] reduced = _simple_reduce(per_device_value, reduce_to_device, - self.accumulation_fn, method_string) + self.accumulation_fn, aggregation) return self.broadcast(reduced, devices) - def _batch_reduce(self, method_string, value_destination_pairs): - return [self._reduce(method_string, t, destinations=v) - for t, v in value_destination_pairs] + def _batch_reduce(self, aggregation, value_destination_pairs): + return [ + self._reduce(aggregation, t, destinations=v) + for t, v in value_destination_pairs + ] def _group_value_by_device(per_device_values): @@ -237,9 +267,9 @@ def _group_value_by_device(per_device_values): This grouping is needed to call the all-reduce library because it expects a list of the following form: - [(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ... - (grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ... - (grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ... + [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...], + [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...], + [(grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ...], ... ] @@ -260,18 +290,24 @@ def _group_value_by_device(per_device_values): return grouped -def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string): +def _ungroup_and_make_mirrored(grouped_reduced, + destinations, + aggregation, + num_between_graph_workers=1): """Ungroup results from all-reduce and make Mirrored objects. Each all-reduce result will be divided by the number of destinations before - Mirrored objects are created if method_string is "mean". + Mirrored objects are created if aggregation is "mean". Args: grouped_reduced: a list of lists, each sublist has components for each device, paired with a None. It is the result from cross_tower_utils.aggregate_gradients_using*. destinations: a list of device strings for returned Mirrored objects. - method_string: "mean" or "sum". + aggregation: Indicates how a variable will be aggregated. Accepted values + are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. + num_between_graph_workers: number of workers in the between-graph + replication. Returns: a list of Mirrored objects. @@ -279,8 +315,9 @@ def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string): index = [{} for _ in range(len(grouped_reduced[0]))] for d, per_device_reduced in enumerate(grouped_reduced): for i, (v, _) in enumerate(per_device_reduced): - if method_string == "mean": - index[i][destinations[d]] = v / len(destinations) + if aggregation == vs.VariableAggregation.MEAN: + index[i][destinations[d]] = v / ( + len(destinations) * num_between_graph_workers) else: index[i][destinations[d]] = v return [value_lib.Mirrored(v) for v in index] @@ -488,32 +525,35 @@ class AllReduceCrossTowerOps(CrossTowerOps): self._agg_small_grads_max_group = agg_small_grads_max_group super(AllReduceCrossTowerOps, self).__init__() - def _reduce(self, method_string, per_device_value, destinations): + def _reduce(self, aggregation, per_device_value, destinations): contains_indexed_slices = cross_tower_utils.contains_indexed_slices( per_device_value) if ((destinations is None or _devices_match(per_device_value, destinations)) and not context.executing_eagerly() and not contains_indexed_slices): - return self._batch_all_reduce(method_string, [per_device_value])[0] + return self._batch_all_reduce(aggregation, [per_device_value])[0] else: if contains_indexed_slices: logging.log_first_n( logging.WARN, "Efficient allreduce is not supported for IndexedSlices.", 10) - devices = _get_devices_from(destinations or per_device_value) + if check_destinations(destinations): + devices = get_devices_from(destinations) + else: + devices = get_devices_from(per_device_value) reduce_to_device = devices[0] reduced = _simple_reduce(per_device_value, reduce_to_device, - math_ops.add_n, method_string) + math_ops.add_n, aggregation) return self.broadcast(reduced, devices) - def _batch_reduce(self, method_string, value_destination_pairs): + def _batch_reduce(self, aggregation, value_destination_pairs): all_devices_match = _all_devices_match(value_destination_pairs) contains_indexed_slices = cross_tower_utils.contains_indexed_slices( value_destination_pairs) if (all_devices_match and not context.executing_eagerly() and not contains_indexed_slices): - return self._batch_all_reduce(method_string, + return self._batch_all_reduce(aggregation, [v[0] for v in value_destination_pairs]) else: if not all_devices_match: @@ -521,22 +561,22 @@ class AllReduceCrossTowerOps(CrossTowerOps): "destinations are different.") return [ - self._reduce(method_string, t, destinations=v) + self._reduce(aggregation, t, destinations=v) for t, v in value_destination_pairs ] - def _batch_all_reduce(self, method_string, per_device_values): + def _batch_all_reduce(self, aggregation, per_device_values): """All reduce algorithm in a batch.""" - logging.info( - "batch_all_reduce invoked for batches size = %d with " + logging.log_first_n( + logging.INFO, "batch_all_reduce invoked for batches size = %d with " "algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and " - "agg_small_grads_max_group = %d", len(per_device_values), - self._all_reduce_alg, self._num_packs, self._agg_small_grads_max_bytes, - self._agg_small_grads_max_group) + "agg_small_grads_max_group = %d" % + (len(per_device_values), self._all_reduce_alg, self._num_packs, + self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10) destinations = per_device_values[0].devices grouped = _group_value_by_device(per_device_values) - device_grad_packs, self._tensor_packer = _pack_tensors( + device_grad_packs, tensor_packer = _pack_tensors( grouped, self._num_packs, self._agg_small_grads_max_bytes, self._agg_small_grads_max_group) @@ -554,9 +594,9 @@ class AllReduceCrossTowerOps(CrossTowerOps): cross_tower_utils.aggregate_gradients_using_hierarchical_copy( destinations, device_grad_packs)) - reduced = _unpack_tensors(reduced, self._tensor_packer) + reduced = _unpack_tensors(reduced, tensor_packer) return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices, - method_string) + aggregation) AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple", @@ -635,14 +675,15 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps): validate_and_complete_spec(spec) for spec in all_reduce_spec ] - def _batch_all_reduce(self, method_string, per_device_values): + def _batch_all_reduce(self, aggregation, per_device_values): """All reduce algorithm in a batch.""" - logging.info( + logging.log_first_n( + logging.INFO, "distributed batch_all_reduce invoked for batches size = %d with " "allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d " - "and agg_small_grads_max_group = %d", len(per_device_values), - self._all_reduce_spec, self._num_packs, self._agg_small_grads_max_bytes, - self._agg_small_grads_max_group) + "and agg_small_grads_max_group = %d" % + (len(per_device_values), self._all_reduce_spec, self._num_packs, + self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10) destinations = sorted(per_device_values[0].devices) device_grads = _group_value_by_device(per_device_values) @@ -665,13 +706,13 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps): (this_grads, remaining_grads) = cross_tower_utils.split_grads_by_size( spec_tuple.limit, remaining_grads) if this_grads: - device_grad_packs, self._tensor_packer = _pack_tensors( + device_grad_packs, tensor_packer = _pack_tensors( this_grads, self._num_packs, self._agg_small_grads_max_bytes, self._agg_small_grads_max_group) range_agg_grads = cross_tower_utils.sum_gradients_all_reduce( self._worker_devices, device_grad_packs, len(self._worker_devices), spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker)) - range_agg_grads = _unpack_tensors(range_agg_grads, self._tensor_packer) + range_agg_grads = _unpack_tensors(range_agg_grads, tensor_packer) if not aggregated_grads: aggregated_grads = range_agg_grads @@ -682,7 +723,103 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps): assert not remaining_grads return _ungroup_and_make_mirrored(aggregated_grads, destinations, - method_string) + aggregation) + + +# TODO(yuefengz): support in-graph collective all-reduce. +class CollectiveAllReduce(CrossTowerOps): + """All-reduce cross tower ops using collective ops. + + In the between-graph replicated training, it will still do all-reduces across + all workers and then put results on the right destinations. + """ + + def __init__(self, + num_workers=1, + num_gpus_per_worker=0, + all_reduce_merge_scope=1, + collective_keys=None): + """Initializes the object. + + Args: + num_workers: number of workers in the between-graph replicated training. + num_gpus_per_worker: number of GPUs per worker. + all_reduce_merge_scope: size of groups into which to partition consecutive + gradients grouped under a common 'allreduce' name scope. This is useful + for some optimization of collective ops. + collective_keys: an optional CollectiveKey object. + """ + self._num_workers = num_workers + self._num_gpus_per_worker = num_gpus_per_worker + self._all_reduce_merge_scope = all_reduce_merge_scope + self._collective_keys = collective_keys or cross_tower_utils.CollectiveKeys( + ) + super(CollectiveAllReduce, self).__init__() + + # TODO(yuefengz, tucker): is index slices supported by collective ops? + def _reduce(self, aggregation, per_device_value, destinations): + all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0] + if destinations is None or _devices_match(per_device_value, destinations): + return all_reduced + else: + index = {} + for d in get_devices_from(destinations): + # pylint: disable=protected-access + if d in all_reduced._index: + index[d] = all_reduced._index[d] + else: + with ops.device(d): + index[d] = array_ops.identity(list(all_reduced._index.values())[0]) + return value_lib.Mirrored(index) + + def _batch_reduce(self, aggregation, value_destination_pairs): + return [ + self._reduce(aggregation, t, destinations=v) + for t, v in value_destination_pairs + ] + + def _batch_all_reduce(self, aggregation, per_device_values): + """All-reduce across all workers in a batch.""" + if context.executing_eagerly(): + raise ValueError("Eager mode with collective ops is not supported yet.") + + logging.log_first_n( + logging.INFO, "Collective All-reduce invoked with batches size = %d, " + "num_workers = %d" % (len(per_device_values), self._num_workers), 10) + + grouped_by_tower = _group_value_by_device(per_device_values) + + grouped_by_var = list(zip(*grouped_by_tower)) + # grouped_by_var is grouped by variables and takes the following format: + # [((grad0_gpu0, v0_gpu0), (grad0_gpu1, v0_gpu1), (grad0_gpu2, v0_gpu2) ..), + # ((grad1_gpu0, v1_gpu0), (grad1_gpu1, v1_gpu1), (grad1_gpu0, v1_gpu2) ..), + # ((grad2_gpu0, v2_gpu0), (grad2_gpu1, v2_gpu1), (grad2_gpu0, v2_gpu2) ..), + # ... + # ] + chunked_gv = [ + grouped_by_var[x:x + self._all_reduce_merge_scope] + for x in range(0, len(grouped_by_var), self._all_reduce_merge_scope) + ] + + reduced_gv_list = [] + for chunk in chunked_gv: + with ops.name_scope("allreduce"): + for grad_and_vars in chunk: + scaled_grads = [g for g, _ in grad_and_vars] + collective_reduced = cross_tower_utils.build_collective_reduce( + scaled_grads, self._num_workers, self._collective_keys, "Add", + "Id") + result = [] + for (_, v), g in zip(grad_and_vars, collective_reduced): + result.append([g, v]) + reduced_gv_list.append(result) + + new_tower_grads = [list(x) for x in zip(*reduced_gv_list)] + return _ungroup_and_make_mirrored( + new_tower_grads, + per_device_values[0].devices, + aggregation, + num_between_graph_workers=self._num_workers) _dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py index fed5505d92ef2544215069736c166a67d6141708..aec53b01d7a089fec08eec6ea43373a2cd8267d6 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -21,22 +21,27 @@ from __future__ import print_function import itertools from absl.testing import parameterized +import numpy as np from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import test +from tensorflow.python.estimator import run_config from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import device_util def _make_per_device(values, devices): - devices = cross_tower_ops_lib._get_devices_from(devices) + devices = cross_tower_ops_lib.get_devices_from(devices) assert len(values) == len(devices) index = {} for d, v in zip(devices, values): @@ -53,7 +58,7 @@ def _fake_mirrored(value, devices): All components of the returned Mirrored have the same objects, which is not true in reality. """ - devices = cross_tower_ops_lib._get_devices_from(devices) + devices = cross_tower_ops_lib.get_devices_from(devices) return value_lib.Mirrored( {d: v for d, v in zip(devices, [value] * len(devices))}) @@ -93,7 +98,7 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase): self._assert_values_equal(l, r) else: self.assertEqual(type(left), type(right)) - self.assertEqual(left.devices, right.devices) + self.assertEqual(set(left.devices), set(right.devices)) if isinstance(list(left._index.values())[0], ops.IndexedSlices): for (d, v) in left._index.items(): self._assert_indexed_slices_equal(v, right._index[d]) @@ -129,32 +134,45 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase): # test reduce() for destinations in all_destinations: self._assert_values_equal( - cross_tower_ops.reduce("mean", per_device, destinations=destinations), + cross_tower_ops.reduce( + vs.VariableAggregation.MEAN, + per_device, + destinations=destinations), _fake_mirrored(mean, destinations or per_device)) self._assert_values_equal( cross_tower_ops.reduce( - "mean", per_device_2, destinations=destinations), + vs.VariableAggregation.MEAN, + per_device_2, + destinations=destinations), _fake_mirrored(mean_2, destinations or per_device)) self._assert_values_equal( - cross_tower_ops.reduce("sum", per_device, destinations=destinations), + cross_tower_ops.reduce( + vs.VariableAggregation.SUM, per_device, + destinations=destinations), _fake_mirrored(mean * len(devices), destinations or per_device)) self._assert_values_equal( cross_tower_ops.reduce( - "sum", per_device_2, destinations=destinations), + vs.VariableAggregation.SUM, + per_device_2, + destinations=destinations), _fake_mirrored(mean_2 * len(devices), destinations or per_device)) # test batch_reduce() for d1, d2 in itertools.product(all_destinations, all_destinations): self._assert_values_equal( - cross_tower_ops.batch_reduce( - "mean", [(per_device, d1), (per_device_2, d2)]), - [_fake_mirrored(mean, d1 or per_device), - _fake_mirrored(mean_2, d2 or per_device_2)]) + cross_tower_ops.batch_reduce(vs.VariableAggregation.MEAN, + [(per_device, d1), (per_device_2, d2)]), + [ + _fake_mirrored(mean, d1 or per_device), + _fake_mirrored(mean_2, d2 or per_device_2) + ]) self._assert_values_equal( - cross_tower_ops.batch_reduce( - "sum", [(per_device, d1), (per_device_2, d2)]), - [_fake_mirrored(mean * len(devices), d1 or per_device), - _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)]) + cross_tower_ops.batch_reduce(vs.VariableAggregation.SUM, + [(per_device, d1), (per_device_2, d2)]), + [ + _fake_mirrored(mean * len(devices), d1 or per_device), + _fake_mirrored(mean_2 * len(devices), d2 or per_device_2) + ]) # test broadcast() for destinations in all_destinations: @@ -255,8 +273,8 @@ class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase): t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0]) t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1]) per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) - result = cross_tower_ops_lib._simple_reduce(per_device, devices[0], - math_ops.add_n, "sum") + result = cross_tower_ops_lib._simple_reduce( + per_device, devices[0], math_ops.add_n, vs.VariableAggregation.SUM) # Test that the result is semantically equal to both the concatenated # IndexedSlices with and without duplicate indices. @@ -267,21 +285,22 @@ class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase): self._assert_indexed_slices_equal(total_with_dups, result) self._assert_indexed_slices_equal(total_without_dups, result) - @combinations.generate(combinations.combine( - cross_tower_ops_instance=[ - combinations.NamedObject( - "ReductionToOneDeviceCrossTowerOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), - combinations.NamedObject( - "AllReduceCrossTowerOps", - cross_tower_ops_lib.AllReduceCrossTowerOps()) - ], - method_string=["sum", "mean"], - batch_reduce=[True, False], - mode=["graph", "eager"], - required_gpus=1)) - def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, - method_string, batch_reduce): + @combinations.generate( + combinations.combine( + cross_tower_ops_instance=[ + combinations.NamedObject( + "ReductionToOneDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), + combinations.NamedObject( + "AllReduceCrossTowerOps", + cross_tower_ops_lib.AllReduceCrossTowerOps()) + ], + aggregation=[vs.VariableAggregation.SUM, vs.VariableAggregation.MEAN], + batch_reduce=[True, False], + mode=["graph", "eager"], + required_gpus=1)) + def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, aggregation, + batch_reduce): devices = ["/cpu:0", "/gpu:0"] dense_shape = [5, 2] t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0]) @@ -290,20 +309,19 @@ class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase): per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) if batch_reduce: - result = cross_tower_ops_instance.batch_reduce(method_string, + result = cross_tower_ops_instance.batch_reduce(aggregation, [(per_device, devices)]) else: - result = cross_tower_ops_instance.reduce(method_string, per_device, - devices) + result = cross_tower_ops_instance.reduce(aggregation, per_device, devices) total_indices_with_dups = [1, 1, 3] total_indices_without_dups = [1, 3] - if method_string == "sum": + if aggregation == vs.VariableAggregation.SUM: total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]] total_values_without_dups = [[4., 6.], [5., 6.]] else: - assert method_string == "mean" + assert aggregation == vs.VariableAggregation.MEAN total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]] total_values_without_dups = [[2., 3.], [2.5, 3.]] @@ -362,5 +380,166 @@ class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase, self._testReductionAndBroadcast(cross_tower_ops, distribution) +class MultiWorkerCollectiveAllReduceTest( + multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase): + + collective_key_base = 100000 + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers.""" + cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=0) + cls._cluster_spec = { + run_config.TaskType.WORKER: [ + "fake_worker_0", "fake_worker_1", "fake_worker_2" + ] + } + + def setUp(self): + super(MultiWorkerCollectiveAllReduceTest, self).setUp() + # Reusing keys are not supported well. So we have to give a different + # collective key base for different tests. + MultiWorkerCollectiveAllReduceTest.collective_key_base += 100000 + + def _get_test_objects(self, task_type, task_id, num_gpus=0, local_mode=False): + collective_keys = cross_tower_utils.CollectiveKeys( + group_key_start=10 * num_gpus + + MultiWorkerCollectiveAllReduceTest.collective_key_base, + instance_key_start=num_gpus * 100 + + MultiWorkerCollectiveAllReduceTest.collective_key_base, + instance_key_with_id_start=num_gpus * 10000 + + MultiWorkerCollectiveAllReduceTest.collective_key_base) + if local_mode: + collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce( + 1, num_gpus, collective_keys=collective_keys) + if num_gpus: + devices = ["/device:GPU:%d" % i for i in range(num_gpus)] + else: + devices = ["/device:CPU:0"] + return collective_all_reduce_ops, devices, "local" + else: + collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce( + 3, num_gpus, collective_keys=collective_keys) + if num_gpus: + devices = [ + "/job:%s/task:%d/device:GPU:%d" % (task_type, task_id, i) + for i in range(num_gpus) + ] + else: + devices = ["/job:%s/task:%d" % (task_type, task_id)] + return collective_all_reduce_ops, devices, self._workers[task_id].target + + def _assert_values_equal(self, left, right, sess): + if isinstance(left, list): + for l, r in zip(left, right): + self._assert_values_equal(l, r, sess) + else: + self.assertEqual(type(left), type(right)) + self.assertEqual(set(left.devices), set(right.devices)) + + run_options = config_pb2.RunOptions() + run_options.experimental.collective_graph_key = 6 + + left_values = np.array( + sess.run(list(left._index.values()), options=run_options)).flatten() + right_values = np.array(list(right._index.values())).flatten() + self.assertEqual(len(left_values), len(right_values)) + for l, r in zip(left_values, right_values): + self.assertEqual(l, r) + + def _test_reduction(self, task_type, task_id, num_gpus, local_mode=False): + collective_all_reduce, devices, master_target = self._get_test_objects( + task_type, task_id, num_gpus, local_mode=local_mode) + if local_mode: + num_workers = 1 + worker_device = None + else: + num_workers = len(self._workers) + worker_device = "/job:%s/task:%d" % (task_type, task_id) + with ops.Graph().as_default(), \ + ops.device(worker_device), \ + self.test_session(target=master_target) as sess: + # Collective ops doesn't support scalar tensors, so we have to construct + # 1-d tensors. + values = [constant_op.constant([float(d)]) for d in range(len(devices))] + per_device = _make_per_device(values, devices) + mean = np.array([(len(devices) - 1.) / 2.]) + + values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))] + per_device_2 = _make_per_device(values_2, devices) + mean_2 = np.array([mean[0] + 1.]) + + destination_mirrored = _fake_mirrored(1., devices) + destination_different = _fake_mirrored(1., _cpu_device) + destination_str = _cpu_device + destination_list = devices + + all_destinations = [ + None, destination_mirrored, destination_different, destination_str, + destination_list + ] + + # test reduce() + for destinations in all_destinations: + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.MEAN, + per_device, + destinations=destinations), + _fake_mirrored(mean, destinations or per_device), sess) + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.MEAN, + per_device_2, + destinations=destinations), + _fake_mirrored(mean_2, destinations or per_device), sess) + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.SUM, + per_device, + destinations=destinations), + _fake_mirrored(mean * len(devices) * num_workers, destinations or + per_device), sess) + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.SUM, + per_device_2, + destinations=destinations), + _fake_mirrored(mean_2 * len(devices) * num_workers, destinations or + per_device), sess) + + # test batch_reduce() + for d1, d2 in itertools.product(all_destinations, all_destinations): + self._assert_values_equal( + collective_all_reduce.batch_reduce(vs.VariableAggregation.MEAN, + [(per_device, d1), + (per_device_2, d2)]), + [ + _fake_mirrored(mean, d1 or per_device), + _fake_mirrored(mean_2, d2 or per_device_2) + ], sess) + self._assert_values_equal( + collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM, + [(per_device, d1), + (per_device_2, d2)]), + [ + _fake_mirrored(mean * len(devices) * num_workers, d1 or + per_device), + _fake_mirrored(mean_2 * len(devices) * num_workers, d2 or + per_device_2) + ], sess) + + return True + + @combinations.generate( + combinations.combine(mode=["graph"], num_gpus=[0, 1, 2])) + def testReductionDistributed(self, num_gpus): + if context.num_gpus() < num_gpus: + return + self._run_between_graph_clients(self._test_reduction, self._cluster_spec, + num_gpus) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py index 2bb088e704c584598b863b1b836166af2a5bb12c..24cb08fb48f832572da5ae2113e6c224557c6a81 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -19,13 +19,16 @@ from __future__ import division from __future__ import print_function import collections as pycoll +import threading from tensorflow.contrib import nccl from tensorflow.contrib.all_reduce.python import all_reduce from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import collective_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops @@ -218,6 +221,146 @@ def split_grads_by_size(threshold_size, device_grads): return small_grads, large_grads +# threading.Lock() cannot be pickled and therefore cannot be a field of +# CollectiveKeys. +_lock = threading.Lock() + + +# TODO(yuefengz): use random key starts to avoid reusing keys? +class CollectiveKeys(object): + """Class that manages collective keys. + + We need to manage three different keys for collective: + + *Group key*: an integer key to identify the set of cooperative devices. + Collective ops work under the same set of devices must using the same group + key. + + *Instance key*: an integer key to identify the set of same counterpart of + tensors on different devices in a device group that need to be all-reduced. + + "Graph key": an integer key that is unique key graph. This is used to support + multiple graphs per client session. It must be non-zero and set in the + `config` argument of each call to `session.run`. + """ + + def __init__(self, + group_key_start=1, + instance_key_start=100, + instance_key_with_id_start=10000): + """Initializes the object. + + Args: + group_key_start: the starting integer of group key. + instance_key_start: the starting integer of instance key. + instance_key_with_id_start: the starting integer of instance key that is + recorded with an id. + """ + self._group_key = group_key_start + self._group_key_table = dict() + + # For instance keys with ids + self._instance_key_id_to_key_table = dict() + self._instance_key_with_id_counter = instance_key_with_id_start + + # For instance keys without ids + self._instance_key_start = instance_key_start + + self._thread_local = threading.local() + + def _get_thread_local_object(self): + # We make instance key without key ids thread local so that it will work + # with MirroredStrategy and distribute coordinator. + if not hasattr(self._thread_local, 'instance_key'): + self._thread_local.instance_key = self._instance_key_start + return self._thread_local + + def get_group_key(self, devices): + """Returns a group key for the set of devices. + + Args: + devices: list of strings naming devices in a collective group. + + Returns: + int key uniquely identifying the set of device names. + """ + parsed = [pydev.DeviceSpec.from_string(d) for d in devices] + # In the between-graph replicated training, different workers need to get + # the same device key. So we remove the task_type and task_id from the + # devices. + # TODO(yuefengz): in the in-graph replicated training, we need to include + # task_type and task_id. + names = sorted(['%s:%d' % (d.device_type, d.device_index) for d in parsed]) + key_id = ','.join(names) + with _lock: + if key_id not in self._group_key_table: + new_key = self._group_key + self._group_key += 1 + self._group_key_table[key_id] = new_key + return self._group_key_table[key_id] + + def get_instance_key(self, key_id=None): + """Returns a new instance key for use in defining a collective op. + + Args: + key_id: optional string. If set, key will be recorded and the same key + will be returned when the same key_id is provided. If not, an increasing + instance key will be returned. + """ + if key_id: + with _lock: + if key_id not in self._instance_key_id_to_key_table: + self._instance_key_with_id_counter += 1 + self._instance_key_id_to_key_table[key_id] = ( + self._instance_key_with_id_counter) + return self._instance_key_id_to_key_table[key_id] + else: + v = self._get_thread_local_object().instance_key + self._get_thread_local_object().instance_key += 1 + return v + + +def build_collective_reduce(input_tensors, + num_workers, + collective_keys, + reduction_op='Add', + unary_op='Id'): + """Build a subgraph that does one full all-reduce, using the collective Op. + + Args: + input_tensors: tensors within a single worker graph that are to be reduced + together; must be one per device. + num_workers: total number of workers with identical independent graphs that + will be doing this same reduction. The reduction will actually include + the corresponding tensors at all these workers. + collective_keys: a CollectiveKeys object. + reduction_op: string naming the reduction op. + unary_op: string naming the unary final op. + + Returns: + An array of final tensors, one per device, computed by the full reduction. + + Raises: + ValueError: There must be at least two tensors over all the workers. + """ + group_size = len(input_tensors) * num_workers + if group_size < 2: + raise ValueError('num_workers * len(input_tensors) must be 2 or greater') + devices = [t.device for t in input_tensors] + num_devices = len(devices) + group_key = collective_keys.get_group_key(devices) + instance_key = collective_keys.get_instance_key() + out_tensors = [] + subdiv_offsets = [0] # TODO(tucker): maybe support non-default subdiv spec + for d in range(num_devices): + with ops.device(devices[d]): + reduce_op = collective_ops.all_reduce( + input_tensors[d], group_size, group_key, instance_key, reduction_op, + unary_op, subdiv_offsets) + out_tensors.append(reduce_op) + return out_tensors + + def sum_grad_and_var_all_reduce(grad_and_vars, num_workers, alg, @@ -253,10 +396,10 @@ def sum_grad_and_var_all_reduce(grad_and_vars, else: raise ValueError('unsupported all_reduce alg: ', alg) - result = [] - for (_, v), g in zip(grad_and_vars, summed_grads): - result.append([g, v]) - return result + result = [] + for (_, v), g in zip(grad_and_vars, summed_grads): + result.append([g, v]) + return result def sum_gradients_all_reduce(dev_prefixes, tower_grads, num_workers, alg, diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py index 4ef8db681503dcef8c72f641455dbb999cef05cf..d25964fa41adc7b1c9164a4ffe49c4c5532f76ac 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py @@ -38,7 +38,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): self.evaluate(ops.convert_to_tensor(left)), self.evaluate(ops.convert_to_tensor(right))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAggregateTensors(self): t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]]) @@ -46,7 +46,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) self._assert_values_equal(total, result) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAggregateIndexedSlices(self): t0 = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) @@ -57,7 +57,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): self.assertIsInstance(result, ops.IndexedSlices) self._assert_values_equal(total, result) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDivideTensor(self): t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) n = 2 @@ -65,7 +65,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) self._assert_values_equal(expected, result) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDivideIndexedSlices(self): t = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) @@ -75,13 +75,13 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): self.assertIsInstance(result, ops.IndexedSlices) self._assert_values_equal(expected, result) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testIsIndexedSlices(self): t = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) self.assertTrue(cross_tower_utils.contains_indexed_slices(t)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testContainsIndexedSlices_List(self): t0 = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) @@ -89,7 +89,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) self.assertTrue(cross_tower_utils.contains_indexed_slices([t0, t1])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testContainsIndexedSlices_Tuple(self): t0 = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) @@ -97,7 +97,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testContainsIndexedSlices_PerDevice(self): t0 = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) @@ -106,7 +106,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): per_device = value_lib.PerDevice({"/gpu:0": t0, "/cpu:0": t1}) self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testContainsIndexedSlices_PerDeviceMapOutput(self): t0 = math_ops._as_indexed_slices( constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py index 34410a6470185ac2821bc6a59de9230ff478aeb6..cc626c33bf8e282736f8e6e0c151e5a3d3f3244b 100644 --- a/tensorflow/contrib/distribute/python/estimator_integration_test.py +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -29,6 +29,7 @@ from tensorflow.contrib.optimizer_v2 import adagrad from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import test from tensorflow.python.estimator import run_config +from tensorflow.python.estimator import training from tensorflow.python.estimator.canned import dnn_linear_combined from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export @@ -63,8 +64,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus - ])) - def test_complete_flow_with_mode(self, distribution): + ], + use_train_and_evaluate=[True, False])) + def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): label_dimension = 2 input_dimension = label_dimension batch_size = 10 @@ -75,8 +77,11 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, y=data, batch_size=batch_size // len(distribution.worker_devices), shuffle=True) - eval_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, y=data, batch_size=batch_size, shuffle=False) + eval_input_fn = self.dataset_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size // len(distribution.worker_devices), + shuffle=False) predict_input_fn = numpy_io.numpy_input_fn( x={'x': data}, batch_size=batch_size, shuffle=False) @@ -96,12 +101,19 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, # TODO(isaprykin): Work around the colocate_with error. dnn_optimizer=adagrad.AdagradOptimizer(0.001), linear_optimizer=adagrad.AdagradOptimizer(0.001), - config=run_config.RunConfig(train_distribute=distribution)) + config=run_config.RunConfig( + train_distribute=distribution, eval_distribute=distribution)) num_steps = 10 - estimator.train(train_input_fn, steps=num_steps) + if use_train_and_evaluate: + scores, _ = training.train_and_evaluate( + estimator, + training.TrainSpec(train_input_fn, max_steps=num_steps), + training.EvalSpec(eval_input_fn)) + else: + estimator.train(train_input_fn, steps=num_steps) + scores = estimator.evaluate(eval_input_fn) - scores = estimator.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) self.assertIn('loss', six.iterkeys(scores)) diff --git a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py index 00c25c7a2482a559c8b94ff3be86c4961dfb439f..44a69ed23a4e00ab81d5b51ae0c14550bd493f14 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py @@ -59,7 +59,8 @@ def build_model_fn_optimizer(): def main(_): distribution = tf.contrib.distribute.MirroredStrategy( ["/device:GPU:0", "/device:GPU:1"]) - config = tf.estimator.RunConfig(train_distribute=distribution) + config = tf.estimator.RunConfig(train_distribute=distribution, + eval_distribute=distribution) def input_fn(): features = tf.data.Dataset.from_tensors([[1.]]).repeat(10) @@ -70,7 +71,7 @@ def main(_): model_fn=build_model_fn_optimizer(), config=config) estimator.train(input_fn=input_fn, steps=10) - eval_result = estimator.evaluate(input_fn=input_fn) + eval_result = estimator.evaluate(input_fn=input_fn, steps=10) print("Eval result: {}".format(eval_result)) def predict_input_fn(): diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py index 2b05884b9b93470ef9a764cbedbc91bd3912c611..518ec9c4232465c3ecd0e4161f707dac499430c7 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py @@ -57,7 +57,8 @@ def main(args): # tf.Estimator that utilizes the DistributionStrategy. strategy = tf.contrib.distribute.MirroredStrategy( ['/device:GPU:0', '/device:GPU:1']) - config = tf.estimator.RunConfig(train_distribute=strategy) + config = tf.estimator.RunConfig( + train_distribute=strategy, eval_distribute=strategy) keras_estimator = tf.keras.estimator.model_to_estimator( keras_model=model, config=config, model_dir=model_dir) diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 75ecd90dcffa7a786b78238ef453c4c8e4346afa..4facd72d12680a53cc3f5e2ded2585bc9716ea3c 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -12,33 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for Keras Sequential and Functional models.""" +"""Tests for tf.keras models using DistributionStrategy.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import os - import numpy as np from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import values from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import keras as keras_lib from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.engine import distributed_training_utils from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import gradient_descent from tensorflow.python.training import rmsprop + _RANDOM_SEED = 1337 _TRAIN_SIZE = 200 _INPUT_SIZE = (10,) _NUM_CLASS = 2 +# TODO(anjalisridhar): Add a decorator that will allow us to run these tests as +# part of the tf.keras unit tests suite. def simple_sequential_model(): model = keras.models.Sequential() model.add(keras.layers.Dense(16, activation='relu', input_shape=_INPUT_SIZE)) @@ -84,7 +91,7 @@ def get_ds_test_input_fn(): return dataset -class TestKerasDistributionStrategy(test_util.TensorFlowTestCase): +class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): def setUp(self): self._base_dir = os.path.join(self.get_temp_dir(), @@ -107,7 +114,8 @@ class TestKerasDistributionStrategy(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=dist) + train_distribute=dist, + eval_distribute=dist) with self.test_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=config) @@ -144,5 +152,457 @@ class TestKerasDistributionStrategy(test_util.TensorFlowTestCase): writer_cache.FileWriterCache.clear() gfile.DeleteRecursively(self._config.model_dir) + def test_keras_optimizer_with_distribution_strategy(self): + dist = mirrored_strategy.MirroredStrategy( + devices=['/device:GPU:0', '/device:GPU:1']) + keras_model = simple_sequential_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer=keras.optimizers.rmsprop(lr=0.01)) + + config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=dist) + with self.test_session(): + est_keras = keras_lib.model_to_estimator(keras_model=keras_model, + config=config) + with self.assertRaisesRegexp(ValueError, + 'Only TensorFlow native optimizers are ' + 'supported with DistributionStrategy.'): + est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + + +class TestWithDistributionStrategy(test.TestCase): + + def test_validating_dataset_input_tensors_with_shape_mismatch(self): + with self.test_session(): + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + a = constant_op.constant([1, 2], shape=(1, 2)) + b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) + x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) + y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) + with strategy.scope(): + # Removed device and input tensor shape details from the error message + # since the order of the device and the corresponding input tensor shape + # is not deterministic over different runs. + with self.assertRaisesRegexp(ValueError, + 'Input tensor shapes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + distributed_training_utils.validate_distributed_dataset_inputs( + strategy, x, y) + + def test_validating_dataset_input_tensors_with_dtype_mismatch(self): + with self.test_session(): + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) + b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) + x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) + y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) + with strategy.scope(): + # Removed device and input tensor dtype details from the error message + # since the order of the device and the corresponding input tensor dtype + # is not deterministic over different runs. + with self.assertRaisesRegexp(ValueError, + 'Input tensor dtypes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + distributed_training_utils.validate_distributed_dataset_inputs( + strategy, x, y) + + def test_calling_model_on_same_dataset(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + # Call fit with validation data + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + model.predict(dataset, steps=2) + + def test_fit_with_tuple_and_dict_dataset_inputs(self): + with self.test_session(): + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(3,), name='input_b') + + dense = keras.layers.Dense(4, name='dense') + c = dense(a) + d = dense(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + + model = keras.models.Model([a, b], [d, e]) + + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + input_a_np = np.random.random((10, 3)) + input_b_np = np.random.random((10, 3)) + output_d_np = np.random.random((10, 4)) + output_e_np = np.random.random((10, 4)) + + # Test with tuples + dataset_tuple = dataset_ops.Dataset.from_tensor_slices(( + (input_a_np, input_b_np), (output_d_np, output_e_np))) + dataset_tuple = dataset_tuple.repeat(100) + dataset_tuple = dataset_tuple.batch(10) + + model.fit(dataset_tuple, epochs=1, steps_per_epoch=2, verbose=1) + + # Test with dict + dataset_dict = dataset_ops.Dataset.from_tensor_slices(( + {'input_a': input_a_np, 'input_b': input_b_np}, + (output_d_np, output_e_np))) + dataset_dict = dataset_dict.repeat(100) + dataset_dict = dataset_dict.batch(10) + + model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) + + def test_fit_eval_and_predict_methods_on_dataset(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(dataset, steps=2) + # Test with validation data + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + + def test_raise_error_for_stateful_metrics(self): + + class ExampleStatefulMetric(keras.layers.Layer): + + def __init__(self, name='true_positives', **kwargs): + super(ExampleStatefulMetric, self).__init__(name=name, **kwargs) + self.stateful = True + + def __call__(self, y_true, y_pred): + return y_pred - y_true + + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', ExampleStatefulMetric()] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + with self.assertRaisesRegexp( + NotImplementedError, 'Stateful metrics are not supported with ' + 'DistributionStrategy.'): + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + def test_unsupported_features(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + # Test with validation split + with self.assertRaisesRegexp( + ValueError, '`validation_split` argument is not ' + 'supported when input `x` is a dataset or a ' + 'dataset iterator.+'): + model.fit(dataset, + epochs=1, steps_per_epoch=2, verbose=0, + validation_split=0.5, validation_steps=2) + + # Test with sample weight. + sample_weight = np.random.random((10,)) + with self.assertRaisesRegexp( + NotImplementedError, 'sample_weight is currently not supported when ' + 'using DistributionStrategy.'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + sample_weight=sample_weight) + + # Test with not specifying the `steps` argument. + with self.assertRaisesRegexp( + ValueError, 'you should specify the `steps_per_epoch` argument'): + model.fit(dataset, epochs=1, verbose=0) + with self.assertRaisesRegexp(ValueError, + 'you should specify the `steps` argument'): + model.evaluate(dataset, verbose=0) + + with self.assertRaisesRegexp(ValueError, + 'you should specify the `steps` argument'): + model.predict(dataset, verbose=0) + + def test_calling_with_unsupported_predefined_callbacks(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + def schedule(_): + return 0.001 + with self.assertRaisesRegexp(ValueError, + 'LearningRateScheduler callback is not ' + 'supported with DistributionStrategy.'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) + + with self.assertRaisesRegexp(ValueError, + 'ReduceLROnPlateau callback is not ' + 'supported with DistributionStrategy.'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.ReduceLROnPlateau()]) + with self.assertRaisesRegexp(ValueError, + 'histogram_freq in the TensorBoard callback ' + 'is not supported when using ' + 'DistributionStrategy.'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)]) + + def test_dataset_input_shape_validation(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + + model.compile(optimizer, loss, distribute=strategy) + + # User forgets to batch the dataset + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + + with self.assertRaisesRegexp(ValueError, + 'expected input to have 2 dimensions'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + # Wrong input shape + inputs = np.zeros((10, 5), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + with self.assertRaisesRegexp(ValueError, + 'expected input to have shape'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + def test_learning_phase_value(self): + # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare + # meaningful values. Currently we don't pass the learning phase if the + # Lambda layer uses the learning phase. + with self.test_session(): + x = keras.layers.Input(shape=(16,), name='input') + y = keras.layers.Dense(16)(x) + z = keras.layers.Dropout(0.9999)(y) + model = keras.Model(x, z) + + optimizer = gradient_descent.GradientDescentOptimizer(0.005) + loss = 'mse' + metrics = ['acc'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.random.rand(10, 16) + targets = np.ones((10, 16), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(8) + + hist = model.fit(dataset, epochs=5, steps_per_epoch=20, verbose=1) + self.assertEqual(hist.history['acc'][0], 1) + + evaluate_output = model.evaluate(dataset, steps=20) + self.assertEqual(evaluate_output[1], 0) + + predict_output = model.predict(dataset, steps=1) + self.assertNotEqual(np.mean(predict_output), 0) + + +class LossMaskingWithDistributionStrategyTest(test.TestCase): + + def test_masking(self): + with self.test_session(): + np.random.seed(1337) + x = np.array([[[1], [1]], [[0], [0]]]) + model = keras.models.Sequential() + model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(1, kernel_initializer='one'))) + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + + model.compile(loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01), + distribute=strategy) + y = np.array([[[1], [1]], [[1], [1]]]) + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + hist = model.fit(x=dataset, epochs=1, steps_per_epoch=2) + self.assertEqual(hist.history['loss'][0], 0) + + +class NormalizationLayerWithDistributionStrategyTest(test.TestCase): + + def test_batchnorm_correctness(self): + with self.test_session(): + model = keras.models.Sequential() + norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8) + model.add(norm) + strategy = mirrored_strategy.MirroredStrategy(['/device:CPU:0', + '/device:GPU:0']) + model.compile(loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01), + distribute=strategy) + + # centered on 5.0, variance 10.0 + x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10)) + dataset = dataset_ops.Dataset.from_tensor_slices((x, x)) + dataset = dataset.repeat(100) + dataset = dataset.batch(32) + + model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10) + out = model.predict(dataset, steps=2) + out -= keras.backend.eval(norm.beta) + out /= keras.backend.eval(norm.gamma) + np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1) + np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) + + +class CorrectnessWithDistributionStrategyTest(test.TestCase): + + def test_correctness(self): + with self.test_session(): + keras.backend.set_image_data_format('channels_last') + num_samples = 10000 + x_train = np.random.rand(num_samples, 1) + y_train = 3 * x_train + x_train = x_train.astype('float32') + y_train = y_train.astype('float32') + + model = keras.Sequential() + model.add(keras.layers.Dense(1, input_shape=(1,))) + + # With DistributionStrategy + dataset_with = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + dataset_with = dataset_with.batch(32) + strategy = mirrored_strategy.MirroredStrategy(devices=['/device:CPU:0', + '/device:GPU:0'], + prefetch_on_device=False) + + model.compile(loss=keras.losses.mean_squared_error, + optimizer=gradient_descent.GradientDescentOptimizer(0.5), + distribute=strategy) + model.fit(x=dataset_with, epochs=1, steps_per_epoch=310) + wts_with_ds = model.get_weights() + + x_predict = [[1], [2], [3], [4]] + predict_dataset_with = dataset_ops.Dataset.from_tensor_slices((x_predict, + x_predict)) + predict_dataset_with = predict_dataset_with.batch(2) + predict_with_ds = model.predict(predict_dataset_with, steps=1) + predict_with_ds = np.reshape(predict_with_ds, (4, 1)) + + # Without DistributionStrategy + dataset_without = dataset_ops.Dataset.from_tensor_slices((x_train, + y_train)) + dataset_without = dataset_without.batch(64) + + model.compile(loss=keras.losses.mean_squared_error, + optimizer=gradient_descent.GradientDescentOptimizer(0.5)) + model.fit(x=dataset_without, epochs=1, steps_per_epoch=310) + wts_without_ds = model.get_weights() + + x_predict = [[1], [2], [3], [4]] + predict_dataset_without = dataset_ops.Dataset.from_tensor_slices(( + x_predict, x_predict)) + predict_dataset_without = predict_dataset_without.batch(4) + predict_without_ds = model.predict(predict_dataset_without, steps=1) + + # Verify that the weights are the same within some limits of tolerance. + np.testing.assert_allclose(wts_with_ds[0], wts_without_ds[0], rtol=1e-3) + # Verify that the predicted outputs are the same within some limits of + # tolerance. + np.testing.assert_allclose(predict_with_ds, predict_without_ds, rtol=1e-3) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index 6c6bf143098c1bba64d47efce1bfface7682683d..2f3d6bdd3f4e4bc7352d7b378ed40b930608ef08 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -19,7 +19,6 @@ from __future__ import print_function from absl.testing import parameterized -from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.distribute.python import combinations from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import test @@ -183,7 +182,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): def _dataset_fn(): dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float) # Want to produce a fixed, known shape, so drop remainder when batching. - dataset = dataset.apply(batching.batch_and_drop_remainder(4)) + dataset = dataset.batch(4, drop_remainder=True) return dataset def _expected_fn(num_batches): diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index 5c056a7c73def2f1fb4bbe0df4d3f82fdabda3df..aa7a61bb3b24df64dfc2a118611e96242a72b025 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -25,11 +25,13 @@ from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example -from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope @@ -43,28 +45,60 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): combinations.times( combinations.distributions_and_v1_optimizers(), combinations.combine(mode=["graph"], use_callable_loss=[True, False]) - + combinations.combine(mode=["eager"], use_callable_loss=[True]), - combinations.combine(is_tpu=[False])) + combinations.combine( - distribution=[combinations.tpu_strategy], - optimizer_fn=[ - combinations.adam_optimizer_v1_fn, - # TODO(isaprykin): Make Adam v2 work with while_loops - # and TPUs. - ], - mode=["graph"], - use_callable_loss=[False], - is_tpu=[True])) - def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss, - is_tpu): + + combinations.combine(mode=["eager"], use_callable_loss=[True])) + + combinations.combine( + distribution=[combinations.tpu_strategy], + optimizer_fn=combinations.optimizers_v1, + mode=["graph"], + use_callable_loss=[True, False])) + def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss): + with distribution.scope(): + model_fn, dataset_fn, layer = minimize_loss_example( + optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) + + def step_fn(ctx, inputs): + del ctx # Unused + return distribution.group( + distribution.call_for_each_tower( + model_fn, inputs, run_concurrently=layer.built)) + + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + + def run_step(): + return distribution.run_steps_on_dataset( + step_fn, iterator, iterations=2).run_op + + self.evaluate(distribution.initialize()) + if not context.executing_eagerly(): + with self.test_session() as sess: + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + weights, biases = [], [] + for _ in range(5): + run_step() + + weights.append(self.evaluate(layer.kernel)) + biases.append(self.evaluate(layer.bias)) + + self.evaluate(distribution.finalize()) + + error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) + is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) + self.assertTrue(is_not_increasing) + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine(mode=["graph"], use_callable_loss=[True, False]) + + combinations.combine(mode=["eager"], use_callable_loss=[True]))) + def testTrainNetworkByCallForEachTower(self, distribution, optimizer_fn, + use_callable_loss): with distribution.scope(): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - # TODO(isaprykin): Eliminate `is_tpu`. Probably add a - # `DistributionStrategy.create_monitor` so that each DistributionStrategy - # could influence its training loop. That method would return an instance - # of Monitor. TPUMonitor would execute tpu.initialize_system() and - # tpu.shutdown_system(). iterator = distribution.distribute_dataset( dataset_fn).make_one_shot_iterator() @@ -75,8 +109,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): if not context.executing_eagerly(): with self.test_session() as sess: - if is_tpu: - sess.run(tpu.initialize_system()) run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) @@ -84,12 +116,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): for _ in range(10): run_step() - weights.append(self.evaluate(distribution.fetch(layer.kernel))) - biases.append(self.evaluate(distribution.fetch(layer.bias))) - - if is_tpu: - with self.test_session() as sess: - sess.run(tpu.shutdown_system()) + weights.append(self.evaluate(layer.kernel)) + biases.append(self.evaluate(layer.bias)) error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) @@ -99,18 +127,12 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): combinations.times( combinations.distributions_and_v1_optimizers() + combinations.distributions_and_v2_optimizers(), - combinations.combine(mode=["graph", "eager"], is_tpu=[False])) + + combinations.combine(mode=["graph", "eager"])) + combinations.combine( distribution=[combinations.tpu_strategy], - optimizer_fn=[ - combinations.adam_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v2_fn, - ], - mode=["graph"], - is_tpu=[True])) - - def testOptimizerInsideModelFn(self, distribution, optimizer_fn, is_tpu): + optimizer_fn=combinations.optimizers_v1+combinations.optimizers_v2, + mode=["graph"])) + def testOptimizerInsideModelFn(self, distribution, optimizer_fn): created_variables = [] trainable_variables = [] @@ -131,26 +153,28 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): use_callable_loss=True, create_optimizer_inside_model_fn=True) + def step_fn(ctx, inputs): + del ctx # Unused + return distribution.group( + distribution.call_for_each_tower( + model_fn, inputs, run_concurrently=layer.built)) + iterator = distribution.distribute_dataset( dataset_fn).make_one_shot_iterator() def run_step(): - return distribution.group( - distribution.call_for_each_tower( - model_fn, iterator.get_next(), run_concurrently=layer.built)) + return distribution.run_steps_on_dataset( + step_fn, iterator, iterations=1).run_op + self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.test_session() as sess: - if is_tpu: - sess.run(tpu.initialize_system()) run_step = sess.make_callable(run_step()) - self.evaluate(variables_lib.global_variables_initializer()) + self.evaluate(variables_lib.global_variables_initializer()) run_step() - if is_tpu: - with self.test_session() as sess: - sess.run(tpu.shutdown_system()) + self.evaluate(distribution.finalize()) def get_expected_variables(optimizer_fn, num_parameter_devices): variables_map = { @@ -181,22 +205,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): combinations.distributions_and_v1_optimizers(), combinations.combine( mode=["graph", "eager"], - is_tpu=[False], # TODO(isaprykin): Allow False here. Currently subsequent # towers will re-execute UPDATE_OPS of previous towers. update_ops_in_cross_tower_mode=[True])) + combinations.combine( - distribution=[combinations.tpu_strategy_single_iteration], - optimizer_fn=[ - combinations.gradient_descent_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v2_fn - ], + distribution=[combinations.tpu_strategy], + optimizer_fn=combinations.optimizers_v1, mode=["graph"], - is_tpu=[True], update_ops_in_cross_tower_mode=[False]))) def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum, - renorm, is_tpu, - update_ops_in_cross_tower_mode): + renorm, update_ops_in_cross_tower_mode): """Verifies that moving mean updates are reduced across towers.""" with distribution.scope(): num_towers = len(distribution.worker_devices) @@ -212,24 +230,28 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # this test relies on specific input being on each device. if isinstance(distribution, mirrored_strategy.MirroredStrategy): self.assertFalse(distribution._prefetch_on_device) - iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() - def run_step(): + def step_fn(ctx, inputs): + del ctx # Unused fetches = distribution.unwrap( distribution.call_for_each_tower( - model_fn, iterator.get_next(), - run_concurrently=batchnorm.built)) + model_fn, inputs, run_concurrently=batchnorm.built)) if update_ops_in_cross_tower_mode: fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) return control_flow_ops.group(fetches) + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + + def run_step(): + return distribution.run_steps_on_dataset( + step_fn, iterator, iterations=1).run_op + + self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.test_session() as sess: - if is_tpu: - sess.run(tpu.initialize_system()) run_step = sess.make_callable(run_step()) - self.evaluate(variables_lib.global_variables_initializer()) + self.evaluate(variables_lib.global_variables_initializer()) expected_moving_means = [0.] * 8 @@ -242,7 +264,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): for _ in range(10): run_step() - moving_means = self.evaluate(distribution.fetch(batchnorm.moving_mean)) + moving_means = self.evaluate(batchnorm.moving_mean) # We make sure that the moving_mean is updated as if the sample mean is # calculated over all towers. @@ -251,9 +273,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum)) self.assertNear(expected_moving_means[i], moving_means[i], 0.0001) - if is_tpu: - with self.test_session() as sess: - sess.run(tpu.shutdown_system()) + self.evaluate(distribution.finalize()) @combinations.generate( combinations.times( @@ -273,18 +293,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, combinations.mirrored_strategy_with_two_gpus - ], - is_tpu=[False]), + ]), combinations.combine( mode=["graph"], use_callable_loss=[True, False]) + combinations.combine(mode=["eager"], use_callable_loss=[True])) + combinations.combine( - distribution=[combinations.tpu_strategy_single_iteration], - is_tpu=[True], + distribution=[combinations.tpu_strategy], mode=["graph"], use_callable_loss=[True, False]))) def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction, - use_callable_loss, is_tpu): + use_callable_loss): with distribution.scope(): all_vars = [] @@ -310,26 +328,31 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): labels = dataset_ops.Dataset.from_tensors([[6.], [21.]]) return dataset_ops.Dataset.zip((features, labels)).repeat() + def step_fn(ctx, inputs): + del ctx # Unused + x, y = inputs + return distribution.group( + distribution.call_for_each_tower( + model_fn, x, y, run_concurrently=False)) + iterator = distribution.distribute_dataset( dataset_fn).make_one_shot_iterator() def run_step(): - return distribution.group( - distribution.call_for_each_tower( - model_fn, *iterator.get_next(), run_concurrently=False)) + return distribution.run_steps_on_dataset( + step_fn, iterator, iterations=1).run_op + self.evaluate(distribution.initialize()) if not context.executing_eagerly(): with self.test_session() as sess: - if is_tpu: - sess.run(tpu.initialize_system()) run_step = sess.make_callable(run_step()) - self.evaluate(variables_lib.global_variables_initializer()) + self.evaluate(variables_lib.global_variables_initializer()) run_step() v = all_vars[0] self.assertTrue(all([v is vi for vi in all_vars[1:]])) - weight = numpy.squeeze(self.evaluate(distribution.fetch(v))) + weight = numpy.squeeze(self.evaluate(v)) # Our model is: # predict = x * w # loss = (predict - y)^2 @@ -353,10 +376,132 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # One of the mean loss reductions. self.assertNear(weight, 2 + 10.6, 0.0001) - if is_tpu: - with self.test_session() as sess: - sess.run(tpu.shutdown_system()) + self.evaluate(distribution.finalize()) + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine(mode=["graph", "eager"]), + combinations.combine(is_tpu=[False])) + + combinations.combine( + distribution=[combinations.tpu_strategy], + optimizer_fn=combinations.optimizers_v1, + mode=["graph"], + is_tpu=[True])) + def testRunStepsWithOutputContext(self, distribution, optimizer_fn, is_tpu): + with distribution.scope(): + def dataset_fn(): + dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat() + # TODO(priyag): batch with drop_remainder=True causes shapes to be + # fully defined for TPU. Remove this when XLA supports dynamic shapes. + return dataset.batch(batch_size=1, drop_remainder=True) + + optimizer = optimizer_fn() + layer = core.Dense(1, use_bias=True) + key1 = "foo" + value1 = "bar" + + def model_fn(output_context, x): + """A very simple model written by the user.""" + def loss_fn(): + y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) + return y * y + + train_op = optimizer.minimize(loss_fn) + loss = loss_fn() + output_context.set_last_step_output( + name="tower_loss_agg", + output=loss, + aggregation=variables_lib.VariableAggregation.MEAN) + output_context.set_non_tensor_output(key1, value1) + return (train_op, loss) + + def step_fn(output_context, inputs): + (train_op, loss) = distribution.call_for_each_tower( + model_fn, output_context, inputs, run_concurrently=False) + output_context.set_last_step_output( + name="cross_tower_loss_agg", + output=loss, + aggregation=variables_lib.VariableAggregation.MEAN) + output_context.set_last_step_output( + name="cross_tower_loss_noagg", + output=loss) + return distribution.group(train_op) + + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + + def run_step(): + initial_loss = lambda: constant_op.constant(1e7) + # Initial values corresponding to aggregated losses are just single + # tensors. But for non aggregated losses, we need to have initial + # values that are of the same structure as non reduced losses. In + # MirroredStrategy, this will be a list of losses, in TPUStrategy + # it will be single tensor. Using `broadcast` followed by `unwrap` + # gives us the desired initial value structure. + initial_loop_values = { + "tower_loss_agg": initial_loss(), + "cross_tower_loss_agg": initial_loss(), + "cross_tower_loss_noagg": + distribution.unwrap(distribution.broadcast(initial_loss())) + } + ctx = distribution.run_steps_on_dataset( + step_fn, iterator, iterations=2, + initial_loop_values=initial_loop_values) + + self.assertEqual({key1: [value1]}, ctx.non_tensor_outputs) + self._verify_loss_output( + initial_loss(), + loss_output=ctx.last_step_outputs["tower_loss_agg"], + aggregated=True, distribution=distribution) + self._verify_loss_output( + initial_loss(), + loss_output=ctx.last_step_outputs["cross_tower_loss_agg"], + aggregated=True, distribution=distribution) + self._verify_loss_output( + initial_loss(), + loss_output=ctx.last_step_outputs["cross_tower_loss_noagg"], + aggregated=False, distribution=distribution) + return (ctx.run_op, ctx.last_step_outputs["tower_loss_agg"]) + + self.evaluate(distribution.initialize()) + if not context.executing_eagerly(): + with self.test_session() as sess: + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + weights, biases, losses = [], [], [] + for _ in range(5): + _, loss = run_step() + losses.append(loss) + weights.append(self.evaluate(layer.kernel)) + biases.append(self.evaluate(layer.bias)) + + self.evaluate(distribution.finalize()) + + loss_is_not_increasing = all(y <= x for x, y in zip(losses, losses[1:])) + self.assertTrue(loss_is_not_increasing) + + error = abs( + numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) + error_is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) + self.assertTrue(error_is_not_increasing) + + def _verify_loss_output(self, initial_loss, loss_output, aggregated, + distribution): + if not aggregated: + self.assertEqual(distribution.num_towers, + len(distribution.unwrap(loss_output))) + loss_output = distribution.reduce( + aggregation=variables_lib.VariableAggregation.MEAN, + value=loss_output, destinations="/device:CPU:0") + + unwrapped_output = distribution.unwrap(loss_output) + self.assertEqual(1, len(unwrapped_output)) + loss_tensor = unwrapped_output[0] + self.assertEqual(initial_loss.dtype, loss_tensor.dtype) + self.assertEqual(initial_loss.shape, loss_tensor.shape) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 900aa10e93e8881aa236bac8a2873d5c5531c6f6..e3376a06368e8ef5efcda5bb69de66b7ec3390e1 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -20,7 +20,6 @@ from __future__ import print_function import contextlib import threading -import six from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib from tensorflow.contrib.distribute.python import shared_variable_creator @@ -28,14 +27,17 @@ from tensorflow.contrib.distribute.python import values from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import tape +from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import coordinator from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.util import nest # TODO(josh11b): Replace asserts in this file with if ...: raise ... @@ -61,6 +63,233 @@ class _RequestedStop(Exception): pass +# Make _call_for_each_tower and _reduce_non_distributed_value not members of +# MirroredStrategy so that they are generally not allowed to use anything +# specific to MirroredStrategy and thus can be shared with other distribution +# strategies. + + +# TODO(yuefengz): maybe create a common class for those who need to call this +# _call_for_each_tower. +def _call_for_each_tower(distribution, fn, *args, **kwargs): + """Run `fn` in separate threads, once per tower/worker device. + + Args: + distribution: the DistributionStrategy object. + fn: function to run (will be run once per device, each in its own thread). + *args: positional arguments for `fn` + **kwargs: keyword arguments for `fn`. + `"run_concurrently"`: Boolean indicating whether executions of `fn` + can be run concurrently (under eager execution only), defaults to + `True`. + + Returns: + Merged return value of `fn` across all towers. + + Raises: + RuntimeError: If fn() calls get_tower_context().merge_call() a different + number of times from the available devices. + """ + run_concurrently = kwargs.pop("run_concurrently", True) + if not context.executing_eagerly(): + # Lots of TF library code isn't thread-safe in graph mode, and + # there is little to be gained by turning on multithreading when + # constructing a graph. + run_concurrently = False + # Needed for per-thread device, etc. contexts in graph mode. + ops.get_default_graph().switch_to_thread_local() + elif run_concurrently is None: + run_concurrently = True + + coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,)) + + shared_variable_store = {} + + # TODO(isaprykin): Create these threads once instead of during every run() + # call. + threads = [] + for index, d in enumerate(distribution.worker_devices): + variable_creator_fn = shared_variable_creator.make_fn( + shared_variable_store, index) + t = MirroredStrategy._MirroredTowerThread( # pylint: disable=protected-access + distribution, coord, d, variable_creator_fn, fn, + *values.select_device(d, args), **values.select_device(d, kwargs)) + threads.append(t) + + for t in threads: + t.start() + + # When `fn` starts `should_run` event is set on _MirroredTowerThread + # (`MTT`) threads. The execution waits until + # `MTT.has_paused` is set, which indicates that either `fn` is + # complete or a `get_tower_context().merge_call()` is called. If `fn` is + # complete, then `MTT.done` is set to True. Otherwise, arguments + # of `get_tower_context().merge_call` from all paused threads are grouped + # and the `merge_fn` is performed. Results of the + # `get_tower_context().merge_call` are then set to `MTT.merge_result`. + # Each such `get_tower_context().merge_call` call returns the + # `MTT.merge_result` for that thread when `MTT.should_run` event + # is reset again. Execution of `fn` resumes. + + try: + with coord.stop_on_exception(): + all_done = False + while not all_done and not coord.should_stop(): + done = [] + if run_concurrently: + for t in threads: + t.should_run.set() + for t in threads: + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + else: + for t in threads: + t.should_run.set() + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + if coord.should_stop(): + return None + all_done = all(done) + if not all_done: + if any(done): + raise RuntimeError("Some towers made a different number of " + "tower_context().merge_call() calls.") + # get_tower_context().merge_call() case + merge_args = values.regroup({t.device: t.merge_args for t in threads}) + merge_kwargs = values.regroup( + {t.device: t.merge_kwargs for t in threads}) + # We capture the name_scope of the MTT when we call merge_fn + # to ensure that if we have opened a name scope in the MTT, + # it will be respected when executing the merge function. We only + # capture the name_scope from the first MTT and assume it is + # the same for all other MTTs. + mtt_captured_name_scope = threads[0].captured_name_scope + with ops.name_scope(mtt_captured_name_scope): + merge_result = threads[0].merge_fn(distribution, *merge_args, + **merge_kwargs) + for t in threads: + t.merge_result = values.select_device(t.device, merge_result) + finally: + for t in threads: + t.should_run.set() + coord.join(threads) + + return values.regroup({t.device: t.main_result for t in threads}) + + +def _reduce_non_distributed_value(distribution, aggregation, value, + destinations): + """Reduce a non-DistributedValue `value` to `destinations`.""" + if isinstance(value, values.DistributedValues): + raise ValueError("You are passing a `DistributedValue` to " + "`_reduce_non_distributed_value`, which is not allowed.") + + # If the same value is present on all towers then the PerDevice value will + # be a single value. We also handle the case when `value` is a single value + # and equal to 0. + if value == 0: + return 0 + # If the aggregation type is MEAN, then this essentially means that the same + # value should be on all destinations. + if aggregation == variable_scope.VariableAggregation.MEAN: + return distribution.broadcast(value, destinations) + + cross_tower_ops_lib.validate_destinations(destinations) + # We do not support an aggregation type of SUM if the value is the same across + # all towers. We call this as part of assign functions for MirroredVariables + # and summing up identical values across towers is not clearly defined. + if (len(distribution.worker_devices) != 1 or + not cross_tower_ops_lib.check_destinations(destinations)): + raise ValueError("A non-DistributedValues value cannot be reduced with the " + "given aggregation.") + # TODO(anjalisridhar): Moves these methods to a device utility file? + devices = cross_tower_ops_lib.get_devices_from(destinations) + if len(devices) == 1: + with ops.device(devices[0]): + return array_ops.identity(value) + else: + value_updates = {} + for d in devices: + with ops.device(d): + value_updates[d] = array_ops.identity(value) + return values.Mirrored(value_updates) + + +def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): # pylint: disable=g-missing-docstring + # Figure out what collections this variable should be added to. + # We'll add the MirroredVariable to those collections instead. + collections = kwargs.pop("collections", None) + if collections is None: + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + kwargs["collections"] = [] + + # Get synchronization value + synchronization = kwargs.get("synchronization", + variable_scope.VariableSynchronization.ON_WRITE) + if synchronization == variable_scope.VariableSynchronization.NONE: + raise ValueError("`NONE` variable synchronization mode is not " + "supported with `Mirrored` distribution strategy. Please" + " change the `synchronization` for variable: " + + kwargs["name"]) + elif synchronization == variable_scope.VariableSynchronization.ON_READ: + # Variables that are to be synced on read are tower local. + is_tower_local = True + kwargs["trainable"] = False + elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or + synchronization == variable_scope.VariableSynchronization.AUTO): + # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`. + is_tower_local = False + else: + raise ValueError("Invalid variable synchronization mode: " + + synchronization + " for variable: " + kwargs["name"]) + + # Get aggregation value + aggregation = kwargs.pop("aggregation", + variable_scope.VariableAggregation.NONE) + if aggregation not in [ + variable_scope.VariableAggregation.NONE, + variable_scope.VariableAggregation.SUM, + variable_scope.VariableAggregation.MEAN + ]: + raise ValueError("Invalid variable aggregation mode: " + aggregation + + " for variable: " + kwargs["name"]) + + # Ignore user-specified caching device, not needed for mirrored variables. + kwargs.pop("caching_device", None) + + # TODO(josh11b,apassos): It would be better if variable initialization + # was never recorded on the tape instead of having to do this manually + # here. + with tape.stop_recording(): + index = real_mirrored_creator(devices, *args, **kwargs) + + if is_tower_local: + result = values.TowerLocalVariable(index, index[devices[0]], aggregation) + else: + result = values.MirroredVariable(index, index[devices[0]], aggregation) + + if not context.executing_eagerly(): + g = ops.get_default_graph() + # If "trainable" is True, next_creator() will add the member variables + # to the TRAINABLE_VARIABLES collection, so we manually remove + # them and replace with the MirroredVariable. We can't set + # "trainable" to False for next_creator() since that causes functions + # like implicit_gradients to skip those variables. + if kwargs.get("trainable", True): + collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) + l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) + for v in index.values(): + l.remove(v) + g.add_to_collections(collections, result) + return result + + class MirroredStrategy(distribute_lib.DistributionStrategy): """Mirrors vars to distribute across multiple devices on a single machine. @@ -88,31 +317,17 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): self._devices = [device_util.resolve(d) for d in devices] self._canonical_device_set = set(self._devices) self._device_index = values.PerDevice( - dict((d, i) for i, d in enumerate(devices))) + {d: i for i, d in enumerate(devices)}) self._cross_tower_ops = cross_tower_ops self._prefetch_on_device = prefetch_on_device # TODO(yuefengz): consider setting the default device. def _create_variable(self, next_creator, *args, **kwargs): """Create a mirrored variable. See `DistributionStrategy.scope`.""" - # Figure out what collections this variable should be added to. - # We'll add the MirroredVariable to those collections instead. - collections = kwargs.pop("collections", None) - if collections is None: - collections = [ops.GraphKeys.GLOBAL_VARIABLES] - kwargs["collections"] = [] - colocate_with = kwargs.pop("colocate_with", None) devices = self._get_devices_from(colocate_with) - tower_local = kwargs.pop("tower_local_reduce_method", None) - if tower_local is not None: - kwargs["trainable"] = False - - # TODO(josh11b,apassos): It would be better if variable initialization - # was never recorded on the tape instead of having to do this manually - # here. - with tape.stop_recording(): + def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring index = {} for i, d in enumerate(devices): with ops.device(d): @@ -136,155 +351,76 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): v = next_creator(*args, **kwargs) assert not isinstance(v, values.DistributedVariable) index[d] = v + return index - if tower_local is None: - result = values.MirroredVariable(index, index[devices[0]]) - else: - result = values.TowerLocalVariable( - index, index[devices[0]], tower_local) - - if not context.executing_eagerly(): - g = ops.get_default_graph() - # If "trainable" is True, next_creator() will add the member variables - # to the TRAINABLE_VARIABLES collection, so we manually remove - # them and replace with the MirroredVariable. We can't set - # "trainable" to False for next_creator() since that causes functions - # like implicit_gradients to skip those variables. - if kwargs.get("trainable", True): - collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) - l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) - for v in index.values(): - l.remove(v) - g.add_to_collections(collections, result) - return result + return _create_mirrored_variable(devices, _real_mirrored_creator, *args, + **kwargs) def distribute_dataset(self, dataset_fn): return values.PerDeviceDataset( self._call_dataset_fn(dataset_fn), self._devices, self._prefetch_on_device) + # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. + def _run_steps_on_dataset(self, fn, iterator, iterations, + initial_loop_values=None): + if initial_loop_values is None: + initial_loop_values = {} + initial_loop_values = nest.flatten(initial_loop_values) + + ctx = values.MultiStepContext() + def body(i, *args): + """A wrapper around `fn` to create the while loop body.""" + del args + fn_result = fn(ctx, iterator.get_next()) + for (name, output) in ctx.last_step_outputs.items(): + # Convert all outputs to tensors, potentially from `DistributedValues`. + ctx.last_step_outputs[name] = self.unwrap(output) + flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) + with ops.control_dependencies([fn_result]): + return [i + 1] + flat_last_step_outputs + + cond = lambda i, *args: i < iterations + i = constant_op.constant(0) + loop_result = control_flow_ops.while_loop( + cond, body, [i] + initial_loop_values, name="", + parallel_iterations=1, back_prop=False, swap_memory=False, + return_same_structure=True) + + ctx.run_op = control_flow_ops.group(loop_result) + + # Convert the last_step_outputs from a list to the original dict structure + # of last_step_outputs. + last_step_tensor_outputs = loop_result[1:] + last_step_tensor_outputs_dict = nest.pack_sequence_as( + ctx.last_step_outputs, last_step_tensor_outputs) + + for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access + output = last_step_tensor_outputs_dict[name] + # For outputs that have already been aggregated, wrap them in a Mirrored + # container, else in a PerDevice container. + if aggregation is variables_lib.VariableAggregation.NONE: + last_step_tensor_outputs_dict[name] = values.regroup( + {d: t for d, t in zip(self._devices, output)}, values.PerDevice) + else: + assert len(output) == 1 + last_step_tensor_outputs_dict[name] = output[0] + + ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access + return ctx + def _broadcast(self, tensor, destinations): # TODO(josh11b): In eager mode, use one thread per device, or async mode. return self._get_cross_tower_ops().broadcast(tensor, destinations or self._devices) def _call_for_each_tower(self, fn, *args, **kwargs): - """Run `fn` in separate threads, once per tower/worker device. - - Args: - fn: function to run (will be run once per device, each in its own thread). - *args: positional arguments for `fn` - **kwargs: keyword arguments for `fn`. - `"run_concurrently"`: Boolean indicating whether executions of `fn` - can be run concurrently (under eager execution only), defaults to - `True`. - - Returns: - Merged return value of `fn` across all towers. - - Raises: - RuntimeError: If fn() calls get_tower_context().merge_call() a different - number of times for when called for different devices. - """ - run_concurrently = kwargs.pop("run_concurrently", True) - if not context.executing_eagerly(): - # Lots of TF library code isn't thread-safe in graph mode, and - # there is little to be gained by turning on multithreading when - # constructing a graph. - run_concurrently = False - # Needed for per-thread device, etc. contexts in graph mode. - ops.get_default_graph().switch_to_thread_local() - elif run_concurrently is None: - run_concurrently = True - - coord = coordinator.Coordinator( - clean_stop_exception_types=(_RequestedStop,)) - - shared_variable_store = {} - - # TODO(isaprykin): Create these threads once instead of during every run() - # call. - threads = [] - for index, d in enumerate(self._devices): - variable_creator_fn = shared_variable_creator.make_fn( - shared_variable_store, index) - t = MirroredStrategy._MirroredTowerThread( - self, coord, d, variable_creator_fn, fn, - *values.select_device(d, args), **values.select_device(d, kwargs)) - threads.append(t) - - for t in threads: - t.start() - - # When `fn` starts `should_run` event is set on _MirroredTowerThread - # (`MTT`) threads. The execution waits until - # `MTT.has_paused` is set, which indicates that either `fn` is - # complete or a `get_tower_context().merge_call()` is called. If `fn` is - # complete, then `MTT.done` is set to True. Otherwise, arguments - # of `get_tower_context().merge_call` from all paused threads are grouped - # and the `merge_fn` is performed. Results of the - # `get_tower_context().merge_call` are then set to `MTT.merge_result`. - # Each such `get_tower_context().merge_call` call returns the - # `MTT.merge_result` for that thread when `MTT.should_run` event - # is reset again. Execution of `fn` resumes. - - try: - with coord.stop_on_exception(): - all_done = False - while not all_done and not coord.should_stop(): - done = [] - if run_concurrently: - for t in threads: - t.should_run.set() - for t in threads: - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - else: - for t in threads: - t.should_run.set() - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - if coord.should_stop(): - return None - all_done = all(done) - if not all_done: - if any(done): - raise RuntimeError("Some towers made a different number of " - "tower_context().merge_call() calls.") - # get_tower_context().merge_call() case - merge_args = values.regroup( - {t.device: t.merge_args for t in threads}) - merge_kwargs = values.regroup( - {t.device: t.merge_kwargs for t in threads}) - # We capture the name_scope of the MTT when we call merge_fn - # to ensure that if we have opened a name scope in the MTT, - # it will be respected when executing the merge function. We only - # capture the name_scope from the first MTT and assume it is - # the same for all other MTTs. - mtt_captured_name_scope = threads[0].captured_name_scope - with ops.name_scope(mtt_captured_name_scope): - merge_result = threads[0].merge_fn( - self, *merge_args, **merge_kwargs) - for t in threads: - t.merge_result = values.select_device(t.device, merge_result) - finally: - for t in threads: - t.should_run.set() - coord.join(threads) - - return values.regroup({t.device: t.main_result for t in threads}) + return _call_for_each_tower(self, fn, *args, **kwargs) def map(self, map_over, fn, *args, **kwargs): # TODO(josh11b): In eager mode, use one thread per device. index = {} - i = 0 - for m in map_over: + for i, m in enumerate(map_over): d = self._devices[i % len(self._devices)] with ops.device(d): l = index.get(d, []) @@ -307,27 +443,29 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()) return self._cross_tower_ops - def _reduce(self, method_string, value, destinations): - if len(self._devices) == 1 and not isinstance(value, values.PerDevice): - value = values.PerDevice({self._devices[0]: value}) - assert isinstance(value, values.PerDevice) - + def _reduce(self, aggregation, value, destinations): + assert not isinstance(value, values.Mirrored) + if not isinstance(value, values.DistributedValues): + # This function handles reducing values that are not PerDevice or Mirrored + # values. For example, the same value could be present on all towers in + # which case `value` would be a single value or value could be 0. + return _reduce_non_distributed_value(self, aggregation, value, + destinations) return self._get_cross_tower_ops().reduce( - method_string, value, destinations=destinations) + aggregation, value, destinations=destinations) - def _batch_reduce(self, method_string, value_destination_pairs): - return self._get_cross_tower_ops().batch_reduce(method_string, + def _batch_reduce(self, aggregation, value_destination_pairs): + return self._get_cross_tower_ops().batch_reduce(aggregation, value_destination_pairs) def _update(self, var, fn, *args, **kwargs): - # TODO(josh11b): Also support TowerLocalVariables here? If so, args and - # kwargs don't need to be mirrored. - assert isinstance(var, values.MirroredVariable) # TODO(josh11b): In eager mode, use one thread per device. + assert isinstance(var, values.DistributedVariable) updates = {} for d, v in var._index.items(): # pylint: disable=protected-access name = "update_%d" % self._device_index.get(d) with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): + # If args and kwargs are not mirrored, the value is returned as is. updates[d] = fn(v, *values.select_device_mirrored(d, args), **values.select_device_mirrored(d, kwargs)) @@ -347,37 +485,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def read_var(self, tower_local_var): """Read the aggregate value of a tower-local variable.""" if isinstance(tower_local_var, values.TowerLocalVariable): - return math_ops.add_n(self.unwrap(tower_local_var)) + return tower_local_var._get_cross_tower() # pylint: disable=protected-access assert isinstance(tower_local_var, values.Mirrored) return array_ops.identity(tower_local_var.get()) - def _fetch(self, val, destination, fn): - """Return a copy of `val` or `fn(val)` on `destination`.""" - if isinstance(val, values.TowerLocalVariable): - val = self.reduce(val.reduce_method, val, destinations=destination) - with ops.device(destination): - return fn(self.unwrap(val)[0]) - - assert isinstance(val, values.Mirrored), ( - "val = %s (type %s)" % (val, val.__class__.__name__)) - if val.on_device(destination): - with ops.device(destination): - # Use an identity here to make sure we are returning a tensor - # instead of e.g. a variable object. - return array_ops.identity(fn(val.get(destination))) - device = None - for d in self._devices: - if val.on_device(d): - device = d - break - assert device is not None, ( - "Could not find destination %s in list of devices %s." % - (destination, val.devices)) - with ops.device(device): - v = fn(val.get(device)) - with ops.device(destination): - return array_ops.identity(v) - def _unwrap(self, val): if isinstance(val, values.DistributedValues): # Return in a deterministic order. @@ -386,6 +497,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): return [val.get(device=d) for d in sorted(val.devices)] return [val] + def value_container(self, val): + return values.value_container(val) + @property def is_single_tower(self): return len(self._devices) == 1 @@ -413,15 +527,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def _get_devices_from(self, colocate_with=None): if colocate_with is None: return self._devices - elif isinstance(colocate_with, values.DistributedValues): - # pylint: disable=protected-access - return list(colocate_with._index.keys()) - elif isinstance(colocate_with, six.string_types): - return [device_util.resolve(colocate_with)] - elif isinstance(colocate_with, list): - return [device_util.resolve(d) for d in colocate_with] else: - return colocate_with + return cross_tower_ops_lib.get_devices_from(colocate_with) class _MirroredTowerThread(threading.Thread): """A thread that runs() a function on a device.""" diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index bccd278847e3c87080af3cb15665e7a0d802d8fb..9a4cc0a8975c39cf82e474d660968afc17991db0 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -25,18 +25,23 @@ from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.layers import core +from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import device_util +from tensorflow.python.training import distribution_strategy_context + GPU_TEST = "test_gpu" in sys.argv[0] @@ -83,13 +88,13 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): self.skipTest("Not GPU test") self.assertEqual(2, self._get_distribution_strategy().num_towers) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCallAndMergeExceptions(self): if not GPU_TEST: self.skipTest("Not GPU test") self._test_call_and_merge_exceptions(self._get_distribution_strategy()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRunRegroupError(self): def run_fn(device_id): @@ -101,7 +106,7 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): with dist.scope(), self.assertRaises(AssertionError): dist.call_for_each_tower(run_fn, dist.worker_device_index) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testReduceToCpu(self): if not GPU_TEST: self.skipTest("Not GPU test") @@ -112,12 +117,35 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): dist = self._get_distribution_strategy() with dist.scope(): result = dist.call_for_each_tower(run_fn, dist.worker_device_index) - reduced = dist.reduce("sum", result, destinations="/device:CPU:0") + reduced = dist.reduce( + variable_scope.VariableAggregation.SUM, + result, + destinations="/device:CPU:0") unwrapped = dist.unwrap(reduced) self.assertEqual(1, len(unwrapped)) expected = sum(range(len(dist.worker_devices))) self.assertEqual(expected, self.evaluate(unwrapped[0])) + @test_util.run_in_graph_and_eager_modes() + def testReduceToMultipleDestinations(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + + devices = ["/device:GPU:0"] + if GPU_TEST: + self.assertGreater(context.num_gpus(), 0) + print(self.id().split(".")[-1], "devices:", ", ".join(devices)) + + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + reduced = dist.reduce( + variable_scope.VariableAggregation.SUM, + 1.0, + destinations=["/device:CPU:0", "/device:GPU:0"]) + unwrapped = dist.unwrap(reduced) + self.assertEqual(2, len(unwrapped)) + self.assertEqual(1.0, self.evaluate(unwrapped[0])) + class MirroredStrategyVariableCreationTest(test.TestCase): @@ -136,7 +164,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): # This variable should be created only once across the threads because of # special variable_creator functions used by `dist.call_for_each_tower`. v = variable_scope.variable(1.0, name="foo") - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -153,7 +181,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): v = variable_scope.variable(1.0) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -173,7 +201,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): vs = [] for i in range(5): vs.append(variable_scope.variable(1.0, name="foo" + str(i))) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return vs dist = mirrored_strategy.MirroredStrategy( @@ -195,7 +223,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): vs.append(variable_scope.variable(1.0, name="foo_1/bar")) vs.append(variable_scope.variable(1.0, name="foo_1/bar_1")) vs.append(variable_scope.variable(1.0, name="foo/bar_1")) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return vs dist = mirrored_strategy.MirroredStrategy( @@ -217,7 +245,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(device_id): v = variable_scope.variable(1.0, name="foo_" + str(device_id)) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -240,7 +268,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): layer2 = core.Dense(1) layer2(features) # This will pause the current thread, and execute the other thread. - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) layer3 = core.Dense(1) layer3(features) return [(layer1.kernel, layer1.bias), @@ -264,18 +293,70 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertEquals("common/dense" + suffix + "/bias:0", bias.name) @test_util.run_in_graph_and_eager_modes(config=config) - def testWithGetVariableAndVariableScope(self): + def testWithVariableAndVariableScope(self): self._skip_eager_if_gpus_less_than(1) def model_fn(): - v0 = variable_scope.get_variable("var-thread0", [1]) + v0 = variable_scope.variable(1.0, name="var0", aggregation=None) with variable_scope.variable_scope("common"): - v1 = variable_scope.get_variable("var-thread1", [1]) + v1 = variable_scope.variable(1.0, name="var1") # This will pause the current thread, and execute the other thread. - distribute_lib.get_tower_context().merge_call(lambda _: _) - v2 = variable_scope.get_variable("var-thread2", [1]) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) + v2 = variable_scope.variable( + 1.0, + name="var2", + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + v3 = variable_scope.variable( + 1.0, + name="var3", + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation=variable_scope.VariableAggregation.MEAN) + + return v0, v1, v2, v3 - return v0, v1, v2 + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + v = variable_scope.variable(1.0, name="var-main0") + self.assertEquals("var-main0:0", v.name) + + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertEquals(4, len(result)) + v0, v1, v2, v3 = result + self.assertIsInstance(v0, values.MirroredVariable) + self.assertEquals("var0:0", v0.name) + self.assertIsInstance(v1, values.MirroredVariable) + self.assertEquals("common/var1:0", v1.name) + self.assertIsInstance(v2, values.TowerLocalVariable) + self.assertEquals("common/var2:0", v2.name) + self.assertEquals(variable_scope.VariableAggregation.SUM, v2.aggregation) + self.assertIsInstance(v3, values.MirroredVariable) + self.assertEquals("common/var3:0", v3.name) + self.assertEquals(variable_scope.VariableAggregation.MEAN, v3.aggregation) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testWithGetVariableAndVariableScope(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + v0 = variable_scope.get_variable("var0", [1]) + with variable_scope.variable_scope("common"): + v1 = variable_scope.get_variable("var1", [1]) + # This will pause the current thread, and execute the other thread. + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) + v2 = variable_scope.get_variable( + "var2", [1], + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + v3 = variable_scope.get_variable( + "var3", [1], + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation=variable_scope.VariableAggregation.MEAN) + + return v0, v1, v2, v3 devices = ["/device:CPU:0", "/device:GPU:0"] dist = mirrored_strategy.MirroredStrategy(devices) @@ -285,14 +366,89 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertEquals("main/var-main0:0", v.name) result = dist.call_for_each_tower(model_fn, run_concurrently=False) - self.assertEquals(3, len(result)) - v0, v1, v2 = result + self.assertEquals(4, len(result)) + v0, v1, v2, v3 = result self.assertIsInstance(v0, values.MirroredVariable) - self.assertEquals("main/var-thread0:0", v0.name) + self.assertEquals("main/var0:0", v0.name) self.assertIsInstance(v1, values.MirroredVariable) - self.assertEquals("main/common/var-thread1:0", v1.name) - self.assertIsInstance(v2, values.MirroredVariable) - self.assertEquals("main/common/var-thread2:0", v2.name) + self.assertEquals("main/common/var1:0", v1.name) + self.assertIsInstance(v2, values.TowerLocalVariable) + self.assertEquals("main/common/var2:0", v2.name) + self.assertEquals(variable_scope.VariableAggregation.SUM, + v2.aggregation) + self.assertIsInstance(v3, values.MirroredVariable) + self.assertEquals("main/common/var3:0", v3.name) + self.assertEquals(variable_scope.VariableAggregation.MEAN, + v3.aggregation) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testNoneSynchronizationWithGetVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "`NONE` variable synchronization mode is not " + "supported with `Mirrored` distribution strategy. Please change " + "the `synchronization` for variable: v"): + variable_scope.get_variable( + "v", [1], + synchronization=variable_scope.VariableSynchronization.NONE) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testNoneSynchronizationWithVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "`NONE` variable synchronization mode is not " + "supported with `Mirrored` distribution strategy. Please change " + "the `synchronization` for variable: v"): + variable_scope.variable( + 1.0, + name="v", + synchronization=variable_scope.VariableSynchronization.NONE) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testInvalidSynchronizationWithVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "Invalid variable synchronization mode: Invalid for " + "variable: v"): + variable_scope.variable(1.0, name="v", synchronization="Invalid") + + @test_util.run_in_graph_and_eager_modes(config=config) + def testInvalidAggregationWithGetVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "Invalid variable aggregation mode: invalid for " + "variable: v"): + variable_scope.get_variable( + "v", [1], + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation="invalid") + + @test_util.run_in_graph_and_eager_modes(config=config) + def testInvalidAggregationWithVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "Invalid variable aggregation mode: invalid for " + "variable: v"): + variable_scope.variable( + 1.0, + name="v", + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation="invalid") @test_util.run_in_graph_and_eager_modes(config=config) def testThreeDevices(self): @@ -300,7 +456,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): v = variable_scope.variable(1.0, name="foo") - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -317,7 +473,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(name): v = variable_scope.variable(1.0, name=name) - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call(lambda _: _) return v dist = mirrored_strategy.MirroredStrategy( @@ -337,34 +493,51 @@ class MirroredStrategyVariableCreationTest(test.TestCase): all_v_sum = {} all_v_mean = {} + components_sum = {} + components_mean = {} def model_fn(device_id): - tower_context = distribute_lib.get_tower_context() - with tower_context.tower_local_var_scope("sum"): - v_sum = variable_scope.variable(1.0) - with tower_context.tower_local_var_scope("mean"): - v_mean = variable_scope.variable(4.0) + v_sum = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + v_mean = variable_scope.variable( + 4.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.MEAN) self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) self.assertTrue(isinstance(v_mean, values.TowerLocalVariable)) updates = [v_sum.assign_add(2.0 + device_id), v_mean.assign(6.0 * device_id)] all_v_sum[device_id] = v_sum all_v_mean[device_id] = v_mean - return updates, v_sum, v_mean + c_sum = v_sum.get() + c_mean = v_mean.get() + components_sum[device_id] = c_sum + components_mean[device_id] = c_mean + self.assertIsNot(v_sum, c_sum) + self.assertIsNot(v_mean, c_mean) + return updates, v_sum, v_mean, c_sum, c_mean dist = mirrored_strategy.MirroredStrategy( ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): # Create "sum" and "mean" versions of TowerLocalVariables. - ret_ops, ret_v_sum, ret_v_mean = dist.call_for_each_tower( - model_fn, dist.worker_device_index, run_concurrently=False) + ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = ( + dist.call_for_each_tower( + model_fn, dist.worker_device_index, run_concurrently=False)) # Should see the same wrapping instance in all towers. self.assertIs(all_v_sum[0], ret_v_sum) self.assertIs(all_v_mean[0], ret_v_mean) - for i in range(1, dist.num_towers): - self.assertIs(all_v_sum[0], all_v_sum[1]) - self.assertIs(all_v_mean[0], all_v_mean[1]) + self.assertIs(all_v_sum[0], all_v_sum[1]) + self.assertIs(all_v_mean[0], all_v_mean[1]) + + # Regroup should recover the same wrapper. + self.assertIs(ret_v_sum, regrouped_sum) + self.assertIs(ret_v_mean, regrouped_mean) + self.assertIsNot(components_sum[0], components_sum[1]) + self.assertIsNot(components_mean[0], components_mean[1]) # Apply updates self.evaluate(variables.global_variables_initializer()) @@ -385,14 +558,13 @@ class MirroredStrategyVariableCreationTest(test.TestCase): # Without get(device), should return the value you get by # applying the reduction across all towers (whether you use - # fetch(), get(), or nothing). - self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum))) - self.assertEqual(expected_mean, self.evaluate(dist.fetch(ret_v_mean))) + # read_var(), get(), or nothing). + self.assertEqual(expected_sum, self.evaluate(dist.read_var(ret_v_sum))) + self.assertEqual(expected_mean, self.evaluate(dist.read_var(ret_v_mean))) self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get())) self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get())) - if not context.executing_eagerly(): - self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) - self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) + self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) + self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not # testing this in eager mode. @@ -401,7 +573,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): with ops.name_scope("foo"): a = constant_op.constant(1.0, name="a") - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) b = constant_op.constant(1.0, name="b") return a, b @@ -422,7 +595,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): with ops.name_scope(None, "foo"): a = constant_op.constant(1.0, name="a") - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) b = constant_op.constant(2.0, name="b") return a, b @@ -450,7 +624,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): b = variable_scope.variable(1.0, name="b") with ops.name_scope("foo"): - c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + c = distribution_strategy_context.get_tower_context().merge_call( + in_cross_tower) return b, c dist = mirrored_strategy.MirroredStrategy( @@ -482,7 +657,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): def model_fn(): b = variable_scope.get_variable("b", [1]) with ops.name_scope("foo"): - c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + c = distribution_strategy_context.get_tower_context().merge_call( + in_cross_tower) return b, c dist = mirrored_strategy.MirroredStrategy( @@ -530,6 +706,543 @@ class MirroredStrategyVariableCreationTest(test.TestCase): _, v1 = dist.unwrap(v) self.assertStartsWith(v1.name, "tower_1/") + @test_util.run_in_graph_and_eager_modes(config=config) + def testTowerLocalVariableUpdate(self): + with context.graph_mode(): + + def model_fn(): + v_sum = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) + return v_sum + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]) + + def update(var, value): + return var.assign(value) + + with dist.scope(): + ret_v_sum = dist.call_for_each_tower(model_fn, run_concurrently=False) + update_ops = dist.unwrap(dist.update(ret_v_sum, update, 5.0)) + + # Initialize variables. + self.evaluate(variables.global_variables_initializer()) + # Assert that the aggregated value of the tower local vars is the sum of + # the individual values before running the update ops. + self.assertEquals(1.0, self.evaluate( + ret_v_sum.get(dist._devices[0]).read_value())) + self.assertEquals(2.0, self.evaluate(ret_v_sum)) + + # Apply updates. + self.evaluate(update_ops) + # Assert that the aggregated value of the tower local vars is the sum of + # the individual values after running the update ops. + self.assertEquals(5.0, self.evaluate( + ret_v_sum.get(dist._devices[0]).read_value())) + self.assertEquals(10.0, self.evaluate(ret_v_sum)) + + +class MirroredVariableUpdateTest(test.TestCase): + # The following tests check assign, assign_add and assign_sub on Mirrored + # variables in tower and cross tower context. + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + def _skip_eager_if_gpus_less_than(self, num_gpus): + if context.num_gpus() < num_gpus and context.executing_eagerly(): + self.skipTest("Enough GPUs not available for this test in eager mode.") + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignMirroredVarTowerContextWithoutAggregationType(self): + # Test that we always have an aggregation type set on the mirrored variable + # if we assign to it in tower mode. + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + v = variable_scope.variable(1.0, name="foo") + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + + def model_fn(): + return mirrored_var.assign(5.0) + + with self.assertRaisesRegexp( + ValueError, "You must specify an aggregation method to update a " + "MirroredVariable in Tower Context."): + self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn))) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignMirroredVarTowerContextWithSum(self): + # Test that we don't reduce a non-per-device value with the "sum" + # aggregation type. + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + v = variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.SUM) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + + def model_fn(): + return mirrored_var.assign(5.0) + + with self.assertRaisesRegexp( + ValueError, "A non-DistributedValues value cannot be reduced with " + "the given aggregation."): + self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn))) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignMirroredVarCrossTowerContext(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable(1.0, name="foo") + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + mirrored_var_result = self.evaluate(mirrored_var.assign(6.0)) + self.assertEquals(6.0, mirrored_var_result) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignMirroredVarTowerContext(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + + def model_fn(): + value = math_ops.cast( + distribution_strategy_context.get_tower_context().tower_id, + mirrored_var.dtype) + return mirrored_var.assign(value) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(0.5, self.evaluate(mirrored_var)) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignMirroredVarTowerContextWithSingleValue(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + + def model_fn(): + return mirrored_var.assign(5.0) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(5.0, self.evaluate(mirrored_var)) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignAddMirroredVarCrossTowerContext(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable(1.0, name="foo") + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + mirrored_var_result = self.evaluate(mirrored_var.assign_add(6.0)) + self.assertEquals(7.0, mirrored_var_result) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignAddMirroredVarTowerContext(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + + def model_fn(): + value = math_ops.cast( + distribution_strategy_context.get_tower_context().tower_id, + mirrored_var.dtype) + return mirrored_var.assign_add(value) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(1.5, self.evaluate(mirrored_var)) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignAddMirroredVarTowerContextWithSingleValue(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + + def model_fn(): + return mirrored_var.assign_add(5.0) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(6.0, self.evaluate(mirrored_var)) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignSubMirroredVarCrossTowerContext(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable(5.0, name="foo") + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(5.0, self.evaluate(mirrored_var)) + mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0)) + self.assertEquals(3.0, mirrored_var_result) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignSubMirroredVarTowerContext(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable( + 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(5.0, self.evaluate(mirrored_var)) + + def model_fn(): + value = math_ops.cast( + distribution_strategy_context.get_tower_context().tower_id, + mirrored_var.dtype) + return mirrored_var.assign_sub(value) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(4.5, self.evaluate(mirrored_var)) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignSubMirroredVarTowerContextWithSingleValue(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable( + 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(5.0, self.evaluate(mirrored_var)) + + def model_fn(): + return mirrored_var.assign_sub(1.0) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(4.0, self.evaluate(mirrored_var)) + + +class MirroredAndTowerLocalVariableInitializerTest(test.TestCase): + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + def testAssignMirroredVarInitializer(self): + # This test is not eager compatible since in eager variables are initialized + # upon construction instead of once the initialization op is run. + with context.graph_mode(): + def var_fn(): + v = variable_scope.variable(1.0, name="foo") + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.assertFalse(self.evaluate(mirrored_var.is_initialized())) + self.evaluate(mirrored_var.initializer) + self.assertTrue(self.evaluate(mirrored_var.is_initialized())) + + def testAssignTowerLocalVarInitializer(self): + # This test is not eager compatible since in eager variables are initialized + # upon construction instead of once the initialization op is run. + with context.graph_mode(): + def model_fn(): + v_sum = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) + return v_sum + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + tower_local_var = dist.call_for_each_tower(model_fn) + self.assertTrue(isinstance(tower_local_var, values.TowerLocalVariable)) + self.assertFalse(self.evaluate(tower_local_var.is_initialized())) + self.evaluate(tower_local_var.initializer) + self.assertTrue(self.evaluate(tower_local_var.is_initialized())) + + +class TowerLocalVariableAssignTest(test.TestCase): + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + def _skip_eager_if_gpus_less_than(self, num_gpus): + if context.num_gpus() < num_gpus and context.executing_eagerly(): + self.skipTest("Not enough GPUs available for this test in eager mode.") + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignTowerLocalVarSumAggregation(self): + self._skip_eager_if_gpus_less_than(1) + def model_fn(): + v_sum = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + return v_sum + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + tower_local_var = dist.call_for_each_tower(model_fn, + run_concurrently=False) + self.assertTrue(isinstance(tower_local_var, values.TowerLocalVariable)) + self.evaluate(variables.global_variables_initializer()) + # Each tower has a value of 1.0 assigned to it in tower context. + # When we read the value using `read_var` we should see the SUM of each of + # values on each of the towers. + self.assertEqual(2.0, self.evaluate(dist.read_var(tower_local_var))) + # Assigning 6.0 in cross tower context will assign a value of + # 6.0/num_towers to each tower. + tlv_ops = tower_local_var.assign(6.0) + self.evaluate(tlv_ops) + # On reading the tower local var we should get the assigned value back. + # The value on all the towers are added before being returned by + # `read_var`. + self.assertEqual(6.0, self.evaluate(dist.read_var(tower_local_var))) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignTowerLocalVarMeanAggregation(self): + self._skip_eager_if_gpus_less_than(1) + def model_fn(): + v_sum = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.MEAN) + return v_sum + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + tower_local_var = dist.call_for_each_tower(model_fn, + run_concurrently=False) + self.assertTrue(isinstance(tower_local_var, values.TowerLocalVariable)) + self.evaluate(variables.global_variables_initializer()) + # Each tower has a value of 1.0 assigned to it in tower context. + # When we read the value using `read_var` we should see the MEAN of values + # on all towers which is the value assigned in tower context. + self.assertEqual(1.0, self.evaluate(dist.read_var(tower_local_var))) + tlv_ops = tower_local_var.assign(6.0) + self.evaluate(tlv_ops) + # On reading the tower local var we should get the MEAN of all values + # which is equal to the value assigned. + self.assertEqual(6.0, self.evaluate(dist.read_var(tower_local_var))) + + +class MockModel(object): + + def __init__(self, two_variables=False): + self.variables = [] + self.variables.append(variable_scope.variable(1.25, name="dummy_var1")) + if two_variables: + self.variables.append(variable_scope.variable(2.0, name="dummy_var2")) + + def __call__(self, factor=2): + x = factor * self.variables[0] + if len(self.variables) > 1: + x += self.variables[1] + return x + + +class MirroredStrategyDefunTest(test.TestCase): + + def _skip_eager_if_gpus_less_than(self, num_gpus): + if context.num_gpus() < num_gpus and context.executing_eagerly(): + self.skipTest("Not enough GPUs available for this test in eager mode.") + + def _call_and_check(self, model_fn, inputs, expected_result, defuns, + two_variables=False): + cpu_dev = device_util.canonicalize("CPU:0") + gpu_dev = device_util.canonicalize("GPU:0") + devices = [cpu_dev, gpu_dev] + dist = mirrored_strategy.MirroredStrategy(devices) + + with dist.scope(): + mock_model = MockModel(two_variables) + self.evaluate(variables.global_variables_initializer()) + + result = dist.call_for_each_tower(model_fn, mock_model, *inputs, + run_concurrently=False) + for device in devices: + device_result = values.select_device(device, result) + device_expected_result = values.select_device(device, expected_result) + self.assertAllClose(device_expected_result, + self.evaluate(device_result)) + + for defun in defuns: + self.assertEqual(set(mock_model.variables), set(defun.variables)) + + @test_util.run_in_graph_and_eager_modes() + def testVariableInDefun(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def times_two(mock_model): + return mock_model() + + def model_fn(mock_model): + return times_two(mock_model) + + self._call_and_check(model_fn, [], 2.5, [times_two]) + + @test_util.run_in_graph_and_eager_modes() + def testVariableInNestedDefun(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def times_two(mock_model): + return mock_model() + + @function.defun + def two_x_plus_one(mock_model): + return times_two(mock_model) + 1 + + def model_fn(mock_model): + return two_x_plus_one(mock_model) + + self._call_and_check(model_fn, [], 3.5, [times_two, two_x_plus_one]) + + @test_util.run_in_graph_and_eager_modes() + def testTwoVariablesInNestedDefun(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def fn1(mock_model): + return mock_model() + + @function.defun + def fn2(mock_model): + return fn1(mock_model) + 1 + + def model_fn(mock_model): + return fn2(mock_model) + + self._call_and_check(model_fn, [], 5.5, [fn1, fn2], two_variables=True) + + @test_util.run_in_graph_and_eager_modes() + def testGradientTapeOverNestedDefuns(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def fn1(mock_model): + return mock_model() + + @function.defun + def fn2(mock_model): + return fn1(mock_model) + 1 + + def model_fn(mock_model): + with backprop.GradientTape(persistent=True) as gtape: + result = fn2(mock_model) + grads = gtape.gradient(result, + [v.get() for v in mock_model.variables]) + return grads + + self._call_and_check(model_fn, [], [2.0, 1.0], [fn1, fn2], + two_variables=True) + + @test_util.run_in_graph_and_eager_modes() + def testPassPerDevice(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def fn1(mock_model, factor): + return mock_model(factor) + + factors = values.PerDevice({"CPU:0": 5.0, "GPU:0": 3.0}) + expected_result = values.PerDevice({"CPU:0": 5.0 * 1.25, + "GPU:0": 3.0 * 1.25}) + self._call_and_check(fn1, [factors], expected_result, [fn1]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py index 61cbe6df813bb28bf8baa83d9e28ffafc4f0cbb8..5db2fff2390ea943a73e5cee6fabc4ae92644b42 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -24,7 +24,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import test_util from tensorflow.python.ops import variable_scope -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): @@ -47,7 +47,7 @@ class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): def testTowerId(self): self._test_tower_id(self._get_distribution_strategy()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCallAndMergeExceptions(self): self._test_call_and_merge_exceptions(self._get_distribution_strategy()) @@ -68,7 +68,8 @@ class VariableCreatorStackTest(test.TestCase): v = variable_scope.variable(1.0) # This will pause the current thread, and execute the other thread. - distribute_lib.get_tower_context().merge_call(lambda _: _) + distribution_strategy_context.get_tower_context().merge_call( + lambda _: _) return v def main_thread_creator(next_creator, *args, **kwargs): diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py index 4fdb9bf69b4f6ad76b79fd298f5303f24a1bd455..2892ce439494320a115b8eae0025a132841c4a8f 100644 --- a/tensorflow/contrib/distribute/python/monitor_test.py +++ b/tensorflow/contrib/distribute/python/monitor_test.py @@ -52,11 +52,11 @@ class MonitorTest(test.TestCase, parameterized.TestCase): self.assertEqual(1, len(layer.trainable_variables)) mirrored_weight_variable = layer.trainable_variables[0] - start_error = self.evaluate(distribution.fetch(mirrored_weight_variable)) + start_error = self.evaluate(mirrored_weight_variable) start_error = abs(numpy.array(start_error) - 1) monitor.run_steps(9) - end_error = self.evaluate(distribution.fetch(mirrored_weight_variable)) + end_error = self.evaluate(mirrored_weight_variable) end_error = abs(numpy.array(end_error) - 1) self.assertGreaterEqual(start_error, end_error) diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py index 0f21a427320510635279f80c11711e81715ec37c..cbfe5df61d1ee6fa1eb9275b715b0721d678a46f 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_strategy.py +++ b/tensorflow/contrib/distribute/python/multi_worker_strategy.py @@ -46,7 +46,7 @@ class MultiWorkerMirroredStrategy(MirroredStrategy): * **In-graph replication**: the `client` creates a single `tf.Graph` that specifies tasks for devices on all workers. The `client` then creates a client session which will talk to the `master` service of a `worker`. Then - the `master` will parition the graph and distribute the work to all + the `master` will partition the graph and distribute the work to all participating workers. * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one physical machine. We will have multiple `worker`s with different `task` diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index f659be5f42594b275af06435cb0c228e5d594ac9..249de01f0880b02d603687db99692088480f7136 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -20,35 +20,68 @@ from __future__ import print_function import contextlib import copy +import threading +import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session -from tensorflow.python.eager import test +from tensorflow.python.estimator import run_config +from tensorflow.python.platform import test from tensorflow.python.framework import test_util +def create_in_process_cluster(num_workers, num_ps): + """Create an in-process cluster that consists of only standard server.""" + # Leave some memory for cuda runtime. + gpu_mem_frac = 0.7 / num_workers + worker_config = config_pb2.ConfigProto() + worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac + + # Enable collective ops which has no impact on non-collective ops. + # TODO(yuefengz, tucker): removing this after we move the initialization of + # collective mgr to the session level. + worker_config.experimental.collective_group_leader = ( + '/job:worker/replica:0/task:0') + + ps_config = config_pb2.ConfigProto() + ps_config.device_count['GPU'] = 0 + + # Create in-process servers. Once an in-process tensorflow server is created, + # there is no way to terminate it. So we create one cluster per test process. + # We could've started the server in another process, we could then kill that + # process to terminate the server. The reasons why we don't want multiple + # processes are + # 1) it is more difficult to manage these processes; + # 2) there is something global in CUDA such that if we initialize CUDA in the + # parent process, the child process cannot initialize it again and thus cannot + # use GPUs (https://stackoverflow.com/questions/22950047). + return test_util.create_local_cluster( + num_workers, + num_ps=num_ps, + worker_config=worker_config, + ps_config=ps_config, + protocol='grpc') + + class MultiWorkerTestBase(test.TestCase): """Base class for testing multi node strategy and dataset.""" @classmethod def setUpClass(cls): """Create a local cluster with 2 workers.""" - num_workers = 2 - # Leave some memory for cuda runtime. - gpu_mem_frac = 0.7 / num_workers - default_config = config_pb2.ConfigProto() - default_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac - - # The local cluster takes some portion of the local GPUs and there is no way - # for the cluster to terminate unless using multiple processes. Therefore, - # we have to only create only one cluster throughout a test process. - workers, _ = test_util.create_local_cluster( - num_workers, num_ps=0, worker_config=default_config) - cls._master_target = workers[0].target + cls._workers, cls._ps = create_in_process_cluster(num_workers=2, num_ps=0) + + def setUp(self): + # We only cache the session in one test because another test may have a + # different session config or master target. + self._thread_local = threading.local() + self._thread_local.cached_session = None + self._result = 0 + self._lock = threading.Lock() @contextlib.contextmanager - def test_session(self, graph=None, config=None): + def test_session(self, graph=None, config=None, target=None): """Create a test session with master target set to the testing cluster. This overrides the base class' method, removes arguments that are not needed @@ -59,6 +92,7 @@ class MultiWorkerTestBase(test.TestCase): graph: Optional graph to use during the returned session. config: An optional config_pb2.ConfigProto to use to configure the session. + target: the target of session to connect to. Yields: A Session object that should be used as a context manager to surround @@ -78,13 +112,46 @@ class MultiWorkerTestBase(test.TestCase): rewriter_config_pb2.RewriterConfig.OFF) if graph is None: - if self._cached_session is None: # pylint: disable=access-member-before-definition - self._cached_session = session.Session( - graph=None, config=config, target=self._master_target) - sess = self._cached_session + if getattr(self._thread_local, 'cached_session', None) is None: + self._thread_local.cached_session = session.Session( + graph=None, config=config, target=target or self._workers[0].target) + sess = self._thread_local.cached_session with sess.graph.as_default(), sess.as_default(): yield sess else: with session.Session( - graph=graph, config=config, target=self._master_target) as sess: + graph=graph, config=config, target=target or + self._workers[0].target) as sess: yield sess + + def _run_client(self, client_fn, task_type, task_id, num_gpus, *args, + **kwargs): + result = client_fn(task_type, task_id, num_gpus, *args, **kwargs) + if np.all(result): + with self._lock: + self._result += 1 + + def _run_between_graph_clients(self, client_fn, cluster_spec, num_gpus, *args, + **kwargs): + """Runs several clients for between-graph replication. + + Args: + client_fn: a function that needs to accept `task_type`, `task_id`, + `num_gpus` and returns True if it succeeds. + cluster_spec: a dict specifying jobs in a cluster. + num_gpus: number of GPUs per worker. + *args: will be passed to `client_fn`. + **kwargs: will be passed to `client_fn`. + """ + threads = [] + for task_type in [run_config.TaskType.CHIEF, run_config.TaskType.WORKER]: + for task_id in range(len(cluster_spec.get(task_type, []))): + t = threading.Thread( + target=self._run_client, + args=(client_fn, task_type, task_id, num_gpus) + args, + kwargs=kwargs) + t.start() + threads.append(t) + for t in threads: + t.join() + self.assertEqual(self._result, len(threads)) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 7f4bab9d93814eb70a2a1586fc291a16b2766b90..016978cdb3a152bbba0a2e63df1dea4035e32789 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -21,10 +21,14 @@ from __future__ import print_function import six from tensorflow.contrib.distribute.python import values +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.util import nest # TODO(josh11b): Replace asserts in this file with if ...: raise ... @@ -43,11 +47,6 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): self._default_device = device def _create_variable(self, next_creator, *args, **kwargs): - # No need to distinguish tower-local variables when not mirroring, - # we just enforce that they are not trainable. - if kwargs.pop("tower_local_reduce_method", None) is not None: - kwargs["trainable"] = False - colocate_with = kwargs.pop("colocate_with", None) if colocate_with is None: with ops.device(self._device): @@ -70,6 +69,41 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): def _broadcast(self, tensor, destinations): return tensor + # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. + def _run_steps_on_dataset(self, fn, iterator, iterations, + initial_loop_values=None): + if initial_loop_values is None: + initial_loop_values = {} + initial_loop_values = nest.flatten(initial_loop_values) + + ctx = values.MultiStepContext() + def body(i, *args): + """A wrapper around `fn` to create the while loop body.""" + del args + fn_result = fn(ctx, iterator.get_next()) + flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) + with ops.control_dependencies([fn_result]): + return [i + 1] + flat_last_step_outputs + + cond = lambda i, *args: i < iterations + i = constant_op.constant(0) + # TODO(priyag): Use max_iterations instead of an explicit counter. + loop_result = control_flow_ops.while_loop( + cond, body, [i] + initial_loop_values, name="", + parallel_iterations=1, back_prop=False, swap_memory=False, + return_same_structure=True) + + ctx.run_op = control_flow_ops.group(loop_result) + + # Convert the last_step_outputs from a list to the original dict structure + # of last_step_outputs. + last_step_tensor_outputs = loop_result[1:] + last_step_tensor_outputs_dict = nest.pack_sequence_as( + ctx.last_step_outputs, last_step_tensor_outputs) + + ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access + return ctx + def _call_for_each_tower(self, fn, *args, **kwargs): # We don't run `fn` in multiple threads in OneDeviceStrategy. kwargs.pop("run_concurrently", None) @@ -80,15 +114,15 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): with ops.device(self._device): return values.MapOutput([fn(m, *args, **kwargs) for m in map_over]) - def _reduce(self, method_string, value, destinations): + def _reduce(self, aggregation, value, destinations): if not isinstance(value, values.MapOutput): return value l = value.get() assert l with ops.device(self._device): - if method_string == "sum": + if aggregation == vs.VariableAggregation.SUM: return math_ops.add_n(l) - elif method_string == "mean": + elif aggregation == vs.VariableAggregation.MEAN: return math_ops.add_n(l) / len(l) else: assert False @@ -106,16 +140,12 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): """Read the aggregate value of a tower-local variable.""" return array_ops.identity(tower_local_var) - def _fetch(self, val, destination, fn): - """Return a copy of `val` or `fn(val)` on `destination`.""" - with ops.device(self._device): - v = fn(val) - with ops.device(destination): - return array_ops.identity(v) - def _unwrap(self, value): return [value] + def value_container(self, value): + return value + @property def is_single_tower(self): return True diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index 7aad8a953cbedd30b48739416e74b3dc164dc4cd..4fdc0f72e6745b7ef25c591157955f214e0b2c79 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -44,7 +44,7 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): def testTowerId(self): self._test_tower_id(self._get_distribution_strategy()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCallAndMergeExceptions(self): self._test_call_and_merge_exceptions(self._get_distribution_strategy()) diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index abd3a65ac4e19ece6b69b9834f4218fde55b60c2..a2d736e42271ab1627240949b99088ed3f0746f6 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -59,8 +59,8 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): for _ in range(10): run_step() - weights.append(self.evaluate(distribution.fetch(layer.kernel))) - biases.append(self.evaluate(distribution.fetch(layer.bias))) + weights.append(self.evaluate(layer.kernel)) + biases.append(self.evaluate(layer.bias)) error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..407c78df95ded5ef6f3ad973392a4d4a21d07735 --- /dev/null +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -0,0 +1,358 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Classes implementing a multi-worker ps DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import os + +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import values +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.framework import device as tf_device +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training import device_setter +from tensorflow.python.training import device_util +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import server_lib +from tensorflow.python.util import nest + +_LOCAL_CPU = "/device:CPU:0" +_LOCAL_GPU_0 = "/device:GPU:0" + + +def _normalize_cluster_spec(cluster_spec): + """Makes `cluster_spec` into a `ClusterSpec` object.""" + if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)): + return server_lib.ClusterSpec(cluster_spec) + elif not isinstance(cluster_spec, server_lib.ClusterSpec): + raise ValueError( + "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " + "`tf.train.ClusterDef` object") + return cluster_spec + + +# TODO(yuefengz): maybe cache variables on local CPU. +# TODO(yuefengz): we may want to set session options to disallow communication +# between workers. +class ParameterServerStrategy(distribute_lib.DistributionStrategy): + """A parameter server DistributionStrategy. + + This strategy class works for both local training and between-graph replicated + training for multiple workers. If `cluster_spec` is specified, either passed + in to __init__() method or parsed from the + ["TF_CONFIG" environment + variable](https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig), + variables and updates to those variables are assigned to parameter servers and + other operations are assigned to workers. If `cluster_spec` is not set, it + becomes local training where variables are assigned to local CPU or the only + GPU. When each worker has more than one GPU, operations will be replicated on + these GPUs. In both cases, operations are replicated but variables are not and + these workers share a common view for which paramater server a variable is + assigned to. + + This class assumes between-graph replication will be used and works on a graph + for a particular worker. + + It is expected to call `call_for_each_tower(fn, *args, **kwargs)` for any + operations which potentially can be replicated across towers (i.e. multiple + GPUs) even if there is only CPU or one GPU. When defining the `fn`, extra + caution needs to be taken: + + 1) Always use `tf.get_variable` instead of `tf.Variable` which is not able + to refer to the same variable on different towers. + + 2) It is generally not recommended to open a device scope under the strategy's + scope. A device scope (i.e. calling `tf.device`) will be merged with or + override the device for operations but will not change the device for + variables. + + 3) It is also not recommended to open a colocation scope (i.e. calling + `tf.colocate_with`) under the strategy's scope. For colocating variables, + use `distribution.colocate_vars_with` instead. Colocation of ops will possibly + create conflicts of device assignement. + """ + + def __init__(self, + num_gpus_per_worker=0, + cluster_spec=None, + task_type=None, + task_id=None): + """Initiailizes this strategy. + + Args: + num_gpus_per_worker: number of local GPUs or GPUs per worker. + cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the + cluster configurations. + task_type: the current task type. + task_id: the current task id. + """ + super(ParameterServerStrategy, self).__init__() + self._num_gpus_per_worker = num_gpus_per_worker + if cluster_spec: + cluster_spec = _normalize_cluster_spec(cluster_spec) + self._cluster_spec = cluster_spec + + # We typically don't need to do all-reduce in this strategy. + self._cross_tower_ops = ( + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + reduce_to_device=_LOCAL_CPU)) + + self._initialize_devices(num_gpus_per_worker, cluster_spec, task_type, + task_id) + + def _initialize_devices(self, num_gpus_per_worker, cluster_spec, task_type, + task_id): + """Initialize internal devices. + + It creates variable devices and compute devices. Variables and operations + will be assigned to them respectively. We have one compute device per tower. + The variable device is a device function or device string. The default + variable device assigns variables to parameter servers in a round-robin + fashion. + + Args: + num_gpus_per_worker: number of local GPUs or GPUs per worker. + cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the + cluster configurations. + task_type: the current task type. + task_id: the current task id. + + Raises: + ValueError: if the cluster_spec doesn't have ps jobs. + """ + self._task_type = task_type or "worker" + self._task_id = task_id or 0 + self._worker_device = "/job:%s/task:%d" % (self._task_type, self._task_id) + + # TODO(yuefengz): maybe clearer to split it into two classes, one for + # the distribuetd case and one for the local case, once we have the factory + # class/method. + + # Define compute devices which is a list of device strings and one for each + # tower. When there are GPUs, replicate operations on these GPUs. Otherwise, + # place operations on CPU. + if cluster_spec is None: + # Local mode. + if num_gpus_per_worker > 0: + self._compute_devices = list( + map("/device:GPU:{}".format, range(num_gpus_per_worker))) + else: + self._compute_devices = [_LOCAL_CPU] + else: + # Distributed mode. + if num_gpus_per_worker > 0: + self._compute_devices = [ + "%s/device:GPU:%d" % (self._worker_device, i) + for i in range(num_gpus_per_worker) + ] + else: + self._compute_devices = [self._worker_device] + + self._compute_devices = list( + map(device_util.resolve, self._compute_devices)) + self._canonical_compute_device_set = set(self._compute_devices) + + # Define variable device which is a device string in the local case and a + # device function in the distributed case. It is used to open a device scope + # where varibles are defined. + # The `_parameter_devices` is needed for the `parameter_devices` property + # and is a list of all variable devices. + if cluster_spec is None: + # Local mode. If there is only one GPU, put everything on that GPU. + # Otherwise, place variables on CPU. + if num_gpus_per_worker == 1: + assert len(list(self._compute_devices)) == 1 + self._variable_device = _LOCAL_GPU_0 + self._parameter_devices = [_LOCAL_GPU_0] + else: + self._variable_device = _LOCAL_CPU + self._parameter_devices = [_LOCAL_CPU] + else: + # Distributed mode. Place variables on ps jobs in a round-robin fashion. + # Note that devices returned from `replica_device_setter` are not + # canonical and therefore we don't canonicalize all variable devices to + # make them consistent. + # TODO(yuefengz): support passing a strategy object to control variable + # assignment. + # TODO(yuefengz): merge the logic of replica_device_setter into this + # class. + num_ps_replicas = len(cluster_spec.as_dict().get("ps", [])) + if num_ps_replicas == 0: + raise ValueError("The cluster spec needs to have `ps` jobs.") + self._variable_device = device_setter.replica_device_setter( + ps_tasks=num_ps_replicas, + worker_device=self._worker_device, + merge_devices=True, + cluster=cluster_spec) + + # Parameter devices are all tasks of the "ps" job. + self._parameter_devices = map("/job:ps/task:{}".format, + range(num_ps_replicas)) + + # Define the default device in cross-tower mode. In the distributed case, we + # set the default device to the corresponding worker to prevent these ops + # from being placed on other workers. + if cluster_spec is None: + self._default_device = None + else: + self._default_device = self._worker_device + + def distribute_dataset(self, dataset_fn): + """Distributes the dataset to each local GPU.""" + return values.PerDeviceDataset( + self._call_dataset_fn(dataset_fn), self._compute_devices, True) + + def _broadcast(self, tensor, destinations): + if not cross_tower_ops_lib.check_destinations(destinations): + destinations = self._compute_devices + return self._cross_tower_ops.broadcast(tensor, destinations) + + # TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through + # this creator, such as "MutableHashTable". + def _create_variable(self, next_creator, *args, **kwargs): + if "colocate_with" in kwargs: + with ops.device(None): + with ops.colocate_with(kwargs["colocate_with"]): + return next_creator(*args, **kwargs) + + with ops.colocate_with(None, ignore_existing=True): + with ops.device(self._variable_device): + return next_creator(*args, **kwargs) + + def _call_for_each_tower(self, fn, *args, **kwargs): + # pylint: disable=protected-access + return mirrored_strategy._call_for_each_tower(self, fn, *args, **kwargs) + + def _verify_destinations_not_different_worker(self, destinations): + if destinations is None: + return + for d in cross_tower_ops_lib.get_devices_from(destinations): + d_spec = tf_device.DeviceSpec.from_string(d) + if d_spec.job == self._task_type and d_spec.task != self._task_id: + raise ValueError( + "Cannot reduce to another worker: %r, current worker is %r" % + (d, self._worker_device)) + + def _reduce(self, aggregation, value, destinations): + self._verify_destinations_not_different_worker(destinations) + if not isinstance(value, values.DistributedValues): + # pylint: disable=protected-access + return mirrored_strategy._reduce_non_distributed_value( + self, aggregation, value, destinations) + + return self._cross_tower_ops.reduce( + aggregation, value, destinations=destinations) + + def _batch_reduce(self, aggregation, value_destination_pairs): + for _, destinations in value_destination_pairs: + self._verify_destinations_not_different_worker(destinations) + return self._cross_tower_ops.batch_reduce(aggregation, + value_destination_pairs) + + def _select_single_value(self, structured): + """Select any single values in `structured`.""" + + def _select_fn(x): # pylint: disable=g-missing-docstring + if isinstance(x, values.Mirrored): + if len(x.devices) == 1: + return list(x._index.values())[0] # pylint: disable=protected-access + else: + raise ValueError( + "You cannot update variable with a Mirrored object with multiple " + "components %r when using ParameterServerStrategy. You must " + "specify a single value or a Mirrored with a single value." % x) + elif isinstance(x, values.PerDevice): + raise ValueError( + "You cannot update variable with a PerDevice object %r when using " + "ParameterServerStrategy. You must specify a single value or a " + "Mirrored with a single value" % x) + else: + return x + + return nest.map_structure(_select_fn, structured) + + def _update(self, var, fn, *args, **kwargs): + if not isinstance(var, resource_variable_ops.ResourceVariable): + raise ValueError( + "You can not update `var` %r. It must be a Variable." % var) + with ops.colocate_with(var), distribute_lib.UpdateContext(var.device): + return fn(var, *self._select_single_value(args), + **self._select_single_value(kwargs)) + + # TODO(yuefengz): does it need to call _select_single_value? + def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + with ops.device( + colocate_with.device), distribute_lib.UpdateContext(colocate_with): + return fn(*args, **kwargs) + + def _unwrap(self, val): + if isinstance(val, values.DistributedValues): + # Return in a deterministic order. + if set(val.devices) == self._canonical_compute_device_set: + return [val.get(device=d) for d in self._compute_devices] + return [val.get(device=d) for d in sorted(val.devices)] + return [val] + + def value_container(self, val): + return values.value_container(val) + + def read_var(self, var): + # No need to distinguish between normal variables and tower-local variables. + return array_ops.identity(var) + + def configure(self, session_config=None): + del session_config + + # Use TF_CONFIG to get the cluster spec and the current job. + tf_config = json.loads(os.environ.get("TF_CONFIG", "{}")) + cluster_spec = _normalize_cluster_spec(tf_config.get("cluster", {})) + + task_env = tf_config.get("task", {}) + if task_env: + task_type = task_env.get("type", "worker") + task_id = int(task_env.get("index", "0")) + else: + task_type = "worker" + task_id = None + + # Set the devices if cluster_spec is defined in TF_CONFIG but not passed in + # the constructor. + if not self._cluster_spec and cluster_spec: + self._cluster_spec = cluster_spec + self._initialize_devices(self._num_gpus_per_worker, cluster_spec, + task_type, task_id) + + @property + def num_towers(self): + return len(self._compute_devices) + + @property + def worker_devices(self): + # Make a copy to prevent users from accidentally mutating our copy. + return list(self._compute_devices) + + @property + def parameter_devices(self): + return list(self._parameter_devices) + + def non_slot_devices(self, var_list): + return min(var_list, key=lambda x: x.name) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..02eb68227dfc51b10e7b3c78f54d0bd779f44f03 --- /dev/null +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -0,0 +1,433 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ParameterServerStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import threading +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import parameter_server_strategy +from tensorflow.python.eager import context +from tensorflow.python.estimator import run_config +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import device_util +from tensorflow.python.training import distribution_strategy_context + + +class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, + parameterized.TestCase): + + @classmethod + def setUpClass(cls): + cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=2) + cls._cluster_spec = { + run_config.TaskType.WORKER: [ + 'fake_worker_0', 'fake_worker_1', 'fake_worker_2' + ], + run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] + } + + def setUp(self): + self._result = 0 + self._lock = threading.Lock() + self._init_condition = threading.Condition() + self._init_reached = 0 + self._finish_condition = threading.Condition() + self._finish_reached = 0 + super(ParameterServerStrategyTest, self).setUp() + + def _get_test_objects(self, task_type, task_id, num_gpus): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=num_gpus) + if not task_type: + return distribution, '' + + tf_config = { + 'cluster': self._cluster_spec, + 'task': { + 'type': task_type, + 'index': task_id + } + } + with self._lock: + # Accessing environment variables should be protected by locks because + # environment variables are shared by all threads. + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + distribution.configure() + return distribution, self._workers[task_id].target + + def _test_device_assignment_distributed(self, task_type, task_id, num_gpus): + worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id) + d, _ = self._get_test_objects(task_type, task_id, num_gpus) + with ops.Graph().as_default(), \ + self.test_session(target=self._workers[0].target) as sess, \ + d.scope(): + + # Define a variable outside the call_for_each_tower scope. This is not + # recommended. + n = variable_scope.get_variable('n', initializer=10.0) + self.assertEqual(n.device, '/job:ps/task:0') + + def model_fn(): + if num_gpus == 0: + last_part_device = 'device:CPU:0' + else: + last_part_device = ( + 'device:GPU:%d' % + distribution_strategy_context.get_tower_context().tower_id) + + a = constant_op.constant(1.0) + b = constant_op.constant(2.0) + c = a + b + self.assertEqual(a.device, worker_device + '/' + last_part_device) + self.assertEqual(b.device, worker_device + '/' + last_part_device) + self.assertEqual(c.device, worker_device + '/' + last_part_device) + + # The device scope is ignored for variables but not for normal ops. + with ops.device('/job:worker/task:0'): + x = variable_scope.get_variable('x', initializer=10.0) + x_add = x.assign_add(c) + e = a + c + # The variable x is on the task 1 since the device_function has been + # called once before the model_fn. + self.assertEqual(x.device, '/job:ps/task:1') + self.assertEqual(x_add.device, x.device) + self.assertEqual(e.device, + '/job:worker/replica:0/task:0/%s' % last_part_device) + + # The colocate_vars_with can override the distribution's device. + with d.colocate_vars_with(x): + y = variable_scope.get_variable('y', initializer=20.0) + y_add = y.assign_add(x_add) + self.assertEqual(y.device, '/job:ps/task:1') + self.assertEqual(y_add.device, y.device) + self.assertEqual(y.device, x.device) + + z = variable_scope.get_variable('z', initializer=10.0) + self.assertEqual(z.device, '/job:ps/task:0') + self.assertNotEqual(z.device, x.device) + + with ops.control_dependencies([y_add]): + z_add = z.assign_add(y) + with ops.control_dependencies([z_add]): + f = z + c + self.assertEqual(f.device, worker_device + '/' + last_part_device) + + # The device scope would merge with the default worker device. + with ops.device('/CPU:1'): + g = e + 1.0 + self.assertEqual(g.device, worker_device + '/device:CPU:1') + + # Ths ops.colocate_with will be ignored when defining a variale but not + # for a normal tensor. + with ops.colocate_with(x): + u = variable_scope.get_variable('u', initializer=30.0) + v = variable_scope.get_variable('v', initializer=30.0) + h = f + 1.0 + self.assertIn('/job:ps/', u.device) + self.assertIn('/job:ps/', v.device) + # u and v are on different parameter servers. + self.assertTrue(u.device != x.device or v.device != x.device) + self.assertTrue(u.device == x.device or v.device == x.device) + # Here h is not on one worker. Note h.device is canonical while x.device + # is not but. + self.assertIn('/job:ps/', h.device) + return y_add, z_add, f + + y, z, f = d.call_for_each_tower(model_fn) + self.assertNotEqual(y, None) + self.assertNotEqual(z, None) + self.assertNotEqual(f, None) + + if context.num_gpus() >= 1 and num_gpus <= 1: + variables.global_variables_initializer().run() + y_val, z_val, f_val = sess.run([y, z, f]) + self.assertEqual(y_val, 33.0) + self.assertEqual(z_val, 43.0) + self.assertEqual(f_val, 46.0) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testDeviceAssignmentDistributed(self, num_gpus): + self._test_device_assignment_distributed('worker', 1, num_gpus) + + def _test_device_assignment_local(self, + d, + compute_device='CPU', + variable_device='CPU', + num_gpus=0): + with ops.Graph().as_default(), \ + self.test_session(target=self._workers[0].target) as sess, \ + d.scope(): + + def model_fn(): + if 'CPU' in compute_device: + tower_compute_device = '/device:CPU:0' + else: + tower_compute_device = ( + '/device:GPU:%d' % + distribution_strategy_context.get_tower_context().tower_id) + tower_compute_device = device_util.canonicalize(tower_compute_device) + + if 'CPU' in variable_device: + tower_variable_device = '/device:CPU:0' + else: + tower_variable_device = ( + '/device:GPU:%d' % + distribution_strategy_context.get_tower_context().tower_id) + tower_variable_device = device_util.canonicalize(tower_variable_device) + + a = constant_op.constant(1.0) + b = constant_op.constant(2.0) + c = a + b + self.assertEqual(a.device, tower_compute_device) + self.assertEqual(b.device, tower_compute_device) + self.assertEqual(c.device, tower_compute_device) + + # The device scope is ignored for variables but not for normal ops. + with ops.device('/device:GPU:2'): + x = variable_scope.get_variable('x', initializer=10.0) + x_add = x.assign_add(c) + e = a + c + self.assertEqual( + device_util.canonicalize(x.device), tower_variable_device) + self.assertEqual(x_add.device, x.device) + self.assertEqual(e.device, device_util.canonicalize('/device:GPU:2')) + + # The colocate_vars_with can override the distribution's device. + with d.colocate_vars_with(x): + y = variable_scope.get_variable('y', initializer=20.0) + y_add = y.assign_add(x_add) + self.assertEqual( + device_util.canonicalize(y.device), tower_variable_device) + self.assertEqual(y_add.device, y.device) + self.assertEqual(y.device, x.device) + + z = variable_scope.get_variable('z', initializer=10.0) + self.assertEqual( + device_util.canonicalize(z.device), tower_variable_device) + + with ops.control_dependencies([y_add]): + z_add = z.assign_add(y) + with ops.control_dependencies([z_add]): + f = z + c + self.assertEqual(f.device, tower_compute_device) + + # The device scope would merge with the default worker device. + with ops.device('/CPU:1'): + g = e + 1.0 + self.assertEqual(g.device, device_util.canonicalize('/device:CPU:1')) + + # Ths ops.colocate_with will be ignored when defining a variale but not + # for a normal tensor. + with ops.colocate_with(x): + u = variable_scope.get_variable('u', initializer=30.0) + h = f + 1.0 + self.assertEqual( + device_util.canonicalize(u.device), tower_variable_device) + self.assertEqual(device_util.canonicalize(x.device), h.device) + return y_add, z_add, f + + y, z, f = d.call_for_each_tower(model_fn) + self.assertNotEqual(y, None) + self.assertNotEqual(z, None) + self.assertNotEqual(f, None) + + if context.num_gpus() >= 1 and num_gpus <= 1: + variables.global_variables_initializer().run() + y_val, z_val, f_val = sess.run([y, z, f]) + self.assertEqual(y_val, 33.0) + self.assertEqual(z_val, 43.0) + self.assertEqual(f_val, 46.0) + + def testDeviceAssignmentLocalCPU(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=0) + self._test_device_assignment_local( + distribution, compute_device='CPU', variable_device='CPU', num_gpus=0) + + def testDeviceAssignmentLocalOneGPU(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=1) + self._test_device_assignment_local( + distribution, compute_device='GPU', variable_device='GPU', num_gpus=1) + + def testDeviceAssignmentLocalTwoGPUs(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_device_assignment_local( + distribution, compute_device='GPU', variable_device='CPU', num_gpus=2) + + def _test_simple_increment(self, task_type, task_id, num_gpus): + d, master_target = self._get_test_objects(task_type, task_id, num_gpus) + if hasattr(d, '_cluster_spec') and d._cluster_spec: + num_workers = len(d._cluster_spec.as_dict().get('worker', + ['dummy_worker'])) + else: + num_workers = 1 + with ops.Graph().as_default(), \ + self.test_session(target=master_target) as sess, \ + d.scope(): + + def model_fn(): + x = variable_scope.get_variable('x', initializer=10.0) + y = variable_scope.get_variable('y', initializer=20.0) + + x_add = x.assign_add(1.0, use_locking=True) + y_add = y.assign_add(1.0, use_locking=True) + + train_op = control_flow_ops.group([x_add, y_add]) + return x, y, train_op + + x, y, train_op = d.call_for_each_tower(model_fn) + train_op = d.group(d.unwrap(train_op)) + + if context.num_gpus() < d._num_gpus_per_worker: + return True + + if task_id == 0: + variables.global_variables_initializer().run() + + # Workers waiting for chief worker's initializing variables. + self._init_condition.acquire() + self._init_reached += 1 + while self._init_reached != num_workers: + self._init_condition.wait() + self._init_condition.notify_all() + self._init_condition.release() + + sess.run(train_op) + + # Wait for other workers to finish training. + self._finish_condition.acquire() + self._finish_reached += 1 + while self._finish_reached != num_workers: + self._finish_condition.wait() + self._finish_condition.notify_all() + self._finish_condition.release() + + x_val, y_val = sess.run([x, y]) + self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_towers) + self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_towers) + return (x_val == 10.0 + 1.0 * num_workers * d.num_towers and + y_val == 20.0 + 1.0 * num_workers * d.num_towers) + + def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): + d, master_target = self._get_test_objects(task_type, task_id, num_gpus) + with ops.Graph().as_default(), \ + self.test_session(target=master_target) as sess, \ + d.scope(): + l = core.Dense(1, use_bias=False) + + def loss_fn(x): + y = array_ops.reshape(l(x), []) - constant_op.constant(1.) + return y * y + + # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for + # multiple graphs (b/111216820). + def grad_fn(x): + loss = loss_fn(x) + var_list = ( + variables.trainable_variables() + ops.get_collection( + ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) + grads = gradients.gradients(loss, var_list) + ret = list(zip(grads, var_list)) + return ret + + def update(v, g): + return v.assign_sub(0.05 * g, use_locking=True) + + one = d.broadcast(constant_op.constant([[1.]])) + + def step(): + """Perform one optimization step.""" + # Run forward & backward to get gradients, variables list. + g_v = d.call_for_each_tower(grad_fn, one) + # Update the variables using the gradients and the update() function. + before_list = [] + after_list = [] + for g, v in g_v: + fetched = d.read_var(v) + before_list.append(fetched) + with ops.control_dependencies([fetched]): + # TODO(yuefengz): support non-Mirrored variable as destinations. + g = d.reduce( + variable_scope.VariableAggregation.SUM, g, destinations=v) + with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + after_list.append(d.read_var(v)) + return before_list, after_list + + before_out, after_out = step() + + if context.num_gpus() < d._num_gpus_per_worker: + return True + + if task_id == 0: + variables.global_variables_initializer().run() + + # Workers waiting for chief worker's initializing variables. + self._init_condition.acquire() + self._init_reached += 1 + while self._init_reached != 3: + self._init_condition.wait() + self._init_condition.notify_all() + self._init_condition.release() + + for i in range(10): + b, a = sess.run((before_out, after_out)) + if i == 0: + before, = b + after, = a + + error_before = abs(before - 1) + error_after = abs(after - 1) + # Error should go down + self.assertLess(error_after, error_before) + return error_after < error_before + + def testSimpleBetweenGraph(self): + self._run_between_graph_clients(self._test_simple_increment, + self._cluster_spec, 0) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testLocalSimpleIncrement(self, num_gpus): + self._test_simple_increment(None, 0, num_gpus) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testMinimizeLossGraph(self, num_gpus): + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py index 7b3670b45aba801cf8c18e04bfea03e23eb67184..1ff60c076226299a89060a295c1cc0c50817b861 100644 --- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py @@ -35,7 +35,7 @@ from tensorflow.python.util import nest # pylint: disable=protected-access class _PrefetchToDeviceIterator(object): - """A replacement for @{tf.data.Iterator} that prefetches to another device. + """A replacement for `tf.data.Iterator` that prefetches to another device. Args: input_dataset: The input dataset. @@ -89,6 +89,9 @@ class _PrefetchToDeviceIterator(object): with ops.device(device): buffer_resource_handle = prefetching_ops.function_buffering_resource( f=_prefetch_fn, + output_types=data_nest.flatten( + sparse.as_dense_types(self._input_dataset.output_types, + self._input_dataset.output_classes)), target_device=target_device, string_arg=input_iterator_handle, buffer_size=buffer_size, @@ -105,7 +108,7 @@ class _PrefetchToDeviceIterator(object): self._input_dataset) def get_next(self, name=None): - """See @{tf.data.Iterator.get_next}.""" + """See `tf.data.Iterator.get_next`.""" self._get_next_call_count += 1 if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) @@ -206,7 +209,7 @@ class _PrefetchToDeviceDataset(dataset_ops.Dataset): def prefetch_to_devices(devices, buffer_size=None): """A transformation that prefetches dataset values to the given `devices`. - NOTE: Although the transformation creates a @{tf.data.Dataset}, the + NOTE: Although the transformation creates a `tf.data.Dataset`, the transformation must be the final `Dataset` in the input pipeline. Args: @@ -217,7 +220,7 @@ def prefetch_to_devices(devices, buffer_size=None): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): return _PrefetchToDeviceDataset(dataset, devices, buffer_size) diff --git a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py index a0b452fc2d445d1cf7dbf5e8fe0e29edef516207..2a9ab51fcfd29a8ae5b37b5c513415af29b277dc 100644 --- a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py +++ b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py @@ -46,7 +46,7 @@ class CanonicalizeVariableNameTest(test.TestCase): class SharedVariableCreatorTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSharedVariable(self): shared_variable_store = {} diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py index d1fdb3279cf2a7cba6e2282d58eedccf38bd38a3..5aa19cf6a9f8411120ed929cecaf93dda6c9edf2 100644 --- a/tensorflow/contrib/distribute/python/single_loss_example.py +++ b/tensorflow/contrib/distribute/python/single_loss_example.py @@ -29,7 +29,8 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops -def single_loss_example(optimizer_fn, distribution, use_bias=False): +def single_loss_example(optimizer_fn, distribution, use_bias=False, + iterations_per_step=1): """Build a very simple network to use in tests and examples.""" def dataset_fn(): @@ -38,12 +39,13 @@ def single_loss_example(optimizer_fn, distribution, use_bias=False): optimizer = optimizer_fn() layer = core.Dense(1, use_bias=use_bias) - def loss_fn(x): + def loss_fn(ctx, x): + del ctx y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) return y * y - single_loss_step = step_fn.StandardSingleLossStep(dataset_fn, loss_fn, - optimizer, distribution) + single_loss_step = step_fn.StandardSingleLossStep( + dataset_fn, loss_fn, optimizer, distribution, iterations_per_step) # Layer is returned for inspecting the kernels in tests. return single_loss_step, layer diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py index d1910622b38c748fc5a814f9e83c2294850d5d12..d3611570b472078bb5f154e9bcb8823c31d39c24 100644 --- a/tensorflow/contrib/distribute/python/step_fn.py +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -34,15 +34,9 @@ class Step(object): def __call__(self): """Perform one step of this training algorithm.""" - return self.step(self.inputs()) - - def inputs(self): - """For the generating the input to be passed to `step()`.""" raise NotImplementedError("must be implemented in descendants") - def step(self, inputs): - """Perform the main computation of this training algorithm.""" - raise NotImplementedError("must be implemented in descendants") + # TODO(priyag): Add an method to access initialization and finalize ops. class StandardInputStep(Step): @@ -54,12 +48,9 @@ class StandardInputStep(Step): """ def __init__(self, dataset_fn, distribution): - Step.__init__(self, distribution) - self._distributed_input = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() - - def inputs(self): - return self._distributed_input.get_next() + super(StandardInputStep, self).__init__(distribution) + self._distributed_input = distribution.distribute_dataset(dataset_fn) + self._iterator = self._distributed_input.make_one_shot_iterator() class StandardSingleLossStep(StandardInputStep): @@ -69,8 +60,8 @@ class StandardSingleLossStep(StandardInputStep): ```python ... - step = step_fn.StandardSingleLossStep(dataset, loss_fn, optimizer) - step.initialize(distribution) + step = step_fn.StandardSingleLossStep( + dataset, loss_fn, optimizer, distribution) # Run a single training step on a given DistributionStrategy: step(distribution) @@ -80,27 +71,43 @@ class StandardSingleLossStep(StandardInputStep): Args: dataset_fn: a function that returns a tf.data Dataset that produces the input for the model. - loss_fn: a function that returns loss. + loss_fn: a function that takes a context and inputs as arguments. It returns + the loss for those inputs. `context` is an instance of + `values.MultiStepContext` that will be passed when `loss_fn` is run. + `context` can be used to specify the outputs to be returned from + `loss_fn`, among other things. optimizer: an optimizer that implements an update rule. distribution: a `DistributionStrategy` object. """ - def __init__(self, dataset_fn, loss_fn, optimizer, distribution): - StandardInputStep.__init__(self, dataset_fn, distribution) + def __init__(self, dataset_fn, loss_fn, optimizer, distribution, + iterations_per_step=1): + super(StandardSingleLossStep, self).__init__(dataset_fn, distribution) self._loss_fn = loss_fn self._optimizer = optimizer self._is_run_concurrently = False + self._iterations_per_step = iterations_per_step - def step(self, inputs): + def __call__(self): with self._distribution.scope(): - gradients_fn = backprop.implicit_grad(self._loss_fn) - gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) - - grads_and_vars = self.distribution.call_for_each_tower( - gradients_fn, inputs, run_concurrently=self._is_run_concurrently) - # If threads use layers, then we need to run the first step sequentially, - # so that layers.build() is not executed in parallel. Otherwise, multiple - # sets of mirrored variables are going to be created. - self._is_run_concurrently = True - return self._optimizer._distributed_apply( # pylint: disable=protected-access - self.distribution, grads_and_vars) + def step_fn(ctx, inputs): + """Function to run one iteration with one input.""" + gradients_fn = backprop.implicit_grad(self._loss_fn) + gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) + + grads_and_vars = self.distribution.call_for_each_tower( + gradients_fn, + ctx, inputs, + run_concurrently=self._is_run_concurrently) + # If threads use layers, then we need to run the first step + # sequentially, so that layers.build() is not executed in parallel. + # Otherwise, multiple sets of mirrored variables are going to be + # created. + self._is_run_concurrently = True + return self._optimizer._distributed_apply( # pylint: disable=protected-access + self.distribution, grads_and_vars) + + # TODO(priyag): Return the outputs, context, etc as well. + ctx = self.distribution.run_steps_on_dataset( + step_fn, self._iterator, self._iterations_per_step) + return ctx.run_op diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py index 75c5ec9659d193e77d219ba79977615d58841d64..8605ab1f7daeb81e778577ad3c4a18b39c57d743 100644 --- a/tensorflow/contrib/distribute/python/step_fn_test.py +++ b/tensorflow/contrib/distribute/python/step_fn_test.py @@ -33,12 +33,19 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase): @combinations.generate( combinations.times( combinations.distributions_and_v1_optimizers(), - combinations.combine(mode=combinations.graph_and_eager_modes))) - def testTrainNetwork(self, distribution, optimizer_fn): + combinations.combine(mode=combinations.graph_and_eager_modes), + combinations.combine(is_tpu=[False])) + + combinations.combine( + distribution=[combinations.tpu_strategy], + optimizer_fn=combinations.optimizers_v1, + mode=["graph"], + is_tpu=[True])) + def testTrainNetwork(self, distribution, optimizer_fn, is_tpu): with distribution.scope(): single_loss_step, layer = single_loss_example( - optimizer_fn, distribution, use_bias=True) + optimizer_fn, distribution, use_bias=True, iterations_per_step=2) + self.evaluate(distribution.initialize()) if context.executing_eagerly(): run_step = single_loss_step else: @@ -47,11 +54,13 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase): self.evaluate(variables.global_variables_initializer()) weights, biases = [], [] - for _ in range(10): + for _ in range(5): run_step() - weights.append(self.evaluate(distribution.fetch(layer.kernel))) - biases.append(self.evaluate(distribution.fetch(layer.bias))) + weights.append(self.evaluate(layer.kernel)) + biases.append(self.evaluate(layer.bias)) + + self.evaluate(distribution.finalize()) error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 2b4ad9f146bc1d6a987fbeecbb05122946137154..371b97ba96a826194a6469ba63e485fc67639585 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -26,8 +26,9 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import optimizer @@ -44,7 +45,8 @@ def _raise_exception_fn(_=None): # Must be the argument to a distribution.call_for_each_tower() call, calls a # get_tower_context().merge_call() that raises an exception. def _merge_raises_fn(): - distribute_lib.get_tower_context().merge_call(_raise_exception_fn) + distribution_strategy_context.get_tower_context().merge_call( + _raise_exception_fn) # Must be the argument to a get_tower_context().merge_call() call, calls @@ -57,7 +59,7 @@ def _call_raises_fn(dist): # calls a get_tower_context().merge_call() that calls a # call_for_each_tower() that raises an exception. def _merge_call_raises_fn(): - distribute_lib.get_tower_context().merge_call(_call_raises_fn) + distribution_strategy_context.get_tower_context().merge_call(_call_raises_fn) # Must be the argument to a get_tower_context().merge_call() call, calls @@ -71,7 +73,8 @@ def _call_merge_raises_fn(dist): # get_tower_context().merge_call() that calls a call_for_each_tower() that # calls a get_tower_context().merge_call() that raises an exception. def _merge_call_merge_raises_fn(): - distribute_lib.get_tower_context().merge_call(_call_merge_raises_fn) + distribution_strategy_context.get_tower_context().merge_call( + _call_merge_raises_fn) class DistributionTestBase(test.TestCase): @@ -106,13 +109,14 @@ class DistributionTestBase(test.TestCase): before_list = [] after_list = [] for g, v in g_v: - fetched = d.fetch(v) + fetched = d.read_var(v) before_list.append(fetched) # control_dependencies irrelevant but harmless in eager execution with ops.control_dependencies([fetched]): - g = d.reduce("sum", g, destinations=v) + g = d.reduce( + variable_scope.VariableAggregation.SUM, g, destinations=v) with ops.control_dependencies(d.unwrap(d.update(v, update, g))): - after_list.append(d.fetch(v)) + after_list.append(d.read_var(v)) return before_list, after_list for i in range(10): @@ -159,12 +163,13 @@ class DistributionTestBase(test.TestCase): before_list = [] after_list = [] for g, v in g_v: - fetched = d.fetch(v) + fetched = d.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): - g = d.reduce("sum", g, destinations=v) + g = d.reduce( + variable_scope.VariableAggregation.SUM, g, destinations=v) with ops.control_dependencies(d.unwrap(d.update(v, update, g))): - after_list.append(d.fetch(v)) + after_list.append(d.read_var(v)) return before_list, after_list before_out, after_out = step() @@ -184,7 +189,7 @@ class DistributionTestBase(test.TestCase): with d.scope(): map_in = [constant_op.constant(i) for i in range(10)] map_out = d.map(map_in, lambda x, y: x * y, 2) - observed = d.fetch(d.reduce("sum", map_out)) + observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out) expected = 90 # 2 * (0 + 1 + ... + 9) self.assertEqual(expected, observed.numpy()) @@ -205,7 +210,7 @@ class DistributionTestBase(test.TestCase): expected_devices = [False] * len(d.worker_devices) def mark_devices_fn(): - tower_id = distribute_lib.get_tower_context().tower_id + tower_id = distribution_strategy_context.get_tower_context().tower_id self.assertLess(tower_id, len(d.worker_devices)) self.assertFalse(expected_devices[tower_id]) expected_devices[tower_id] = True diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 75441786a615fc0d87b4c4b0b45b9384d678c1d3..b510fdb888dafe9f18805bc60e9fb670710521ab 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -21,108 +21,229 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools - -from tensorflow.contrib import tpu +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python import values from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import tpu +from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib +from tensorflow.contrib.tpu.python.tpu import training_loop +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.training import device_util +from tensorflow.python.training import server_lib from tensorflow.python.util import nest +def get_tpu_system_metadata(tpu_cluster_resolver): + """Retrieves TPU system metadata given a TPUClusterResolver.""" + master = tpu_cluster_resolver.master() + + # pylint: disable=protected-access + cluster_def = (tpu_cluster_resolver.cluster_spec() + or server_lib.ClusterSpec({})).as_cluster_def() + tpu_system_metadata = ( + tpu_system_metadata_lib._query_tpu_system_metadata( + master, + cluster_def=cluster_def, + query_topology=True)) + + return tpu_system_metadata + + class TPUStrategy(one_device_strategy.OneDeviceStrategy): """Experimental TPU distribution strategy implementation.""" - def __init__(self, - num_cores_per_host=2, - iterations_per_step=2): + def __init__(self, tpu_cluster_resolver, steps_per_run): + """Initializes the TPUStrategy object. + + Args: + tpu_cluster_resolver: A tf.contrib.cluster_resolver.TPUClusterResolver, + which provides information about the TPU cluster. + steps_per_run: Number of steps to run on device before returning to the + host. Note that this can have side-effects on performance, hooks, + metrics, summaries etc. + This parameter is only used when Distribution Strategy is used with + estimator or keras. + """ # TODO(isaprykin): Generalize the defaults. They are currently tailored for # the unit test. - super(TPUStrategy, self).__init__('/cpu:0') - # TODO(isaprykin): Auto-detect number of cores and hosts. - self._num_cores_per_host = num_cores_per_host - # TODO(isaprykin): This might have to be per-call. - self._iterations_per_step = iterations_per_step + super(TPUStrategy, self).__init__('/device:CPU:0') - def distribute_dataset(self, dataset_fn): - return values.PerIterationDataset( - self._call_dataset_fn(dataset_fn), self._iterations_per_step, - self._num_cores_per_host) + self._tpu_cluster_resolver = tpu_cluster_resolver + self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) - def _call_for_each_tower(self, fn, *args, **kwargs): - kwargs.pop('run_concurrently', None) + # TODO(priyag): This should not be hardcoded here. + self._host = '/device:CPU:0' + # TODO(sourabhbajaj): Remove this once performance of running one step + # at a time is comparable to multiple steps. + self.steps_per_run = steps_per_run - inputs = {'args': args, 'kwargs': kwargs} - flat_inputs = nest.flatten(inputs) + def distribute_dataset(self, dataset_fn): + # TODO(priyag): Perhaps distribute across cores here. + return self._call_dataset_fn(dataset_fn) - feed_mask = [isinstance(f, values.PerIteration) for f in flat_inputs] + # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. + # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have + # a mechanism to infer the outputs of `fn`. Pending b/110550782. + def _run_steps_on_dataset(self, fn, iterator, iterations, + initial_loop_values=None): - feeds = lambda: itertools.compress(flat_inputs, feed_mask) - shapes = [f.get_shape() for f in feeds()] + shapes = nest.flatten(iterator.output_shapes) if any([not s.is_fully_defined() for s in shapes]): raise ValueError( 'TPU currently requires fully defined shapes. Either use ' 'set_shape() on the input tensors or use ' 'dataset.apply(map_and_batch(..., drop_remainder=True)).') - types = [f.get_dtype() for f in feeds()] - - def infeed_input(i): - """Get input, split it and then enqueue.""" - iteration_inputs = [f.get(i) for f in feeds()] - infeed_inputs = [[inputs_per_core[core_id] - for inputs_per_core in iteration_inputs] - for core_id in range(self._num_cores_per_host)] - - infeed_ops = [] - for core_id, infeed_input in enumerate(infeed_inputs): - infeed_ops.append( + types = nest.flatten(iterator.output_types) + + def enqueue_ops_fn(): + """Enqueue ops for one iteration.""" + control_deps = [] + sharded_inputs = [] + with ops.device(self._host): + for _ in range(self.num_towers): + # Use control dependencies to ensure a deterministic ordering. + with ops.control_dependencies(control_deps): + inputs = nest.flatten(iterator.get_next()) + control_deps.extend(inputs) + sharded_inputs.append(inputs) + + enqueue_ops = [] + for core_id, shard_input in enumerate(sharded_inputs): + enqueue_ops.append( tpu_ops.infeed_enqueue_tuple( - inputs=infeed_input, shapes=shapes, device_ordinal=core_id)) + inputs=shard_input, shapes=shapes, device_ordinal=core_id)) + return enqueue_ops - with ops.control_dependencies(infeed_ops): + def enqueue_ops_loop_body(i): + with ops.control_dependencies(enqueue_ops_fn()): return i + 1 - with ops.device('/task:0/device:CPU:0'): + with ops.device(self._host): enqueue_ops = control_flow_ops.while_loop( - lambda i: i < self._iterations_per_step, - infeed_input, [constant_op.constant(0)], + lambda i: i < iterations, + enqueue_ops_loop_body, + [constant_op.constant(0)], parallel_iterations=1) - def dequeueing_fn(*args, **kwargs): - """Dequeue input arguments and supply them to `fn`.""" - del args, kwargs - dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes) - dequeued = iter(dequeued) - - fn_inputs = [] - for inp, is_feed in zip(flat_inputs, feed_mask): - if is_feed: - fn_inputs.append(next(dequeued)) - else: - fn_inputs.append(inp) - - fn_inputs = nest.pack_sequence_as(inputs, fn_inputs) - return fn(*fn_inputs['args'], **fn_inputs['kwargs']) + def dequeue_fn(): + dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes) + return nest.pack_sequence_as(iterator.output_shapes, dequeued) + # Wrap `fn` for repeat. + if initial_loop_values is None: + initial_loop_values = {} + initial_loop_values = nest.flatten(initial_loop_values) + ctx = values.MultiStepContext() + def run_fn(*args, **kwargs): + del args, kwargs + fn_result = fn(ctx, dequeue_fn()) + flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) + if flat_last_step_outputs: + with ops.control_dependencies([fn_result]): + return [array_ops.identity(f) for f in flat_last_step_outputs] + else: + return fn_result + + # TODO(sourabhbajaj): The input to while loop should be based on the output + # type of the step_fn def iterate_on_tpu(): - return tpu.repeat(self._iterations_per_step, dequeueing_fn, []) + return training_loop.repeat(iterations, run_fn, initial_loop_values) + + replicate_inputs = [[]] * self.num_towers + replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs) + ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) + + # Filter out any ops from the outputs, typically this would be the case + # when there were no tensor outputs. + last_step_tensor_outputs = [x for x in replicate_outputs + if not isinstance(x, ops.Operation)] + + # Outputs are currently of the structure (grouped by device) + # [[output0_device0, output1_device0, output2_device0], + # [output0_device1, output1_device1, output2_device1]] + # Convert this to the following structure instead: (grouped by output) + # [[output0_device0, output0_device1], + # [output1_device0, output1_device1], + # [output2_device0, output2_device1]] + last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)] + + # Convert replicate_outputs to the original dict structure of + # last_step_outputs. + last_step_tensor_outputs_dict = nest.pack_sequence_as( + ctx.last_step_outputs, last_step_tensor_outputs) + + for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access + output = last_step_tensor_outputs_dict[name] + # For outputs that have already been aggregated, take the first value + # from the list as each value should be the same. Else return the full + # list of values. + if aggregation is not variables_lib.VariableAggregation.NONE: + # TODO(priyag): Should this return the element or a list with 1 element + last_step_tensor_outputs_dict[name] = output[0] + ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access + + return ctx + def _call_for_each_tower(self, fn, *args, **kwargs): + kwargs.pop('run_concurrently', None) with one_device_strategy._OneDeviceTowerContext(self): # pylint: disable=protected-access - tpu_result = tpu.batch_parallel( - iterate_on_tpu, [], num_shards=self._num_cores_per_host) - - return control_flow_ops.group(tpu_result, enqueue_ops) - - def _reduce(self, method_string, value, destinations): - del destinations # TPU is graph mode only. Rely on implicit Send/Recv. - if method_string == 'mean': - # TODO(jhseu): Revisit once we support model-parallelism. - value *= (1. / self._num_cores_per_host) - return tpu_ops.cross_replica_sum(value) + return fn(*args, **kwargs) + + def initialize(self): + if context.executing_eagerly(): + # TODO(priyag): Add appopriate call here when eager is supported for TPUs. + raise NotImplementedError('Eager mode not supported in TPUStrategy.') + else: + return [tpu.initialize_system()] + + def finalize(self): + if context.executing_eagerly(): + # TODO(priyag): Add appopriate call here when eager is supported for TPUs. + raise NotImplementedError('Eager mode not supported in TPUStrategy.') + else: + return [tpu.shutdown_system()] + + def _reduce(self, aggregation, value, destinations): + graph = ops.get_default_graph() + cf_context = graph._get_control_flow_context() # pylint: disable=protected-access + # If we're inside the ReplicateContext, reduction should be done using + # CrossReplicaSum while outside we can directly use an add_n op. + while cf_context: + if isinstance(cf_context, tpu.TPUReplicateContext): + if aggregation == vs.VariableAggregation.MEAN: + # TODO(jhseu): Revisit once we support model-parallelism. + value *= (1. / self.num_towers) + return tpu_ops.cross_replica_sum(value) + cf_context = cf_context.outer_context + + # Validate that the destination is same as the host device + # Note we don't do this when in replicate context as the reduction is + # performed on the TPU device itself. + devices = cross_tower_ops_lib.get_devices_from(destinations) + if len(devices) == 1: + assert device_util.canonicalize(devices[0]) == device_util.canonicalize( + self._host) + else: + raise ValueError('Multiple devices are not supported for TPUStrategy') + + output = math_ops.add_n(value) + if aggregation == vs.VariableAggregation.MEAN: + return output * (1. / len(value)) + return output + + def _unwrap(self, value): + if isinstance(value, list): + return value + return [value] @property def num_towers(self): - return self._num_cores_per_host + return self._tpu_metadata.num_of_cores_per_host diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 9572ade8e497fa13a7ca0746399d3e0237ee79fd..8548a864210a4720e4094873b6470be8d6b26e3c 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -23,27 +23,29 @@ from __future__ import print_function import collections import weakref - import six -from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.distribute.python import input_ops from tensorflow.contrib.distribute.python import prefetching_ops_v2 from tensorflow.python.eager import context from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import saver from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util import nest # pylint: disable=line-too-long -# TODO(josh11b): Should device values be strings or DeviceSpec objects +# TODO(josh11b): Should device values be strings or DeviceSpec objects? # Not sure DeviceSpec objects are usable as a dict key. class DistributedValues(object): """Holds a map from device to values. Either PerDevice or Mirrored.""" @@ -55,7 +57,7 @@ class DistributedValues(object): def get(self, device=None): """Returns the value for the current device or raises a ValueError.""" if device is None: - tower_context = distribute_lib.get_tower_context() + tower_context = distribution_strategy_context.get_tower_context() if tower_context: device = tower_context.device else: @@ -78,6 +80,13 @@ class DistributedValues(object): def devices(self): return list(self._index.keys()) + @property + def is_tensor_like(self): + for v in self._index.values(): + if not tensor_util.is_tensor(v): + return False + return True + def __str__(self): return "%s:%s" % (self.__class__.__name__, self._index) @@ -163,9 +172,16 @@ class PerDevice(DistributedValues): pass -class Mirrored(DistributedValues): +# Note that unlike PerDevice, Mirrored values inherit from +# DistributedDelegate and so can be used directly in cross-tower mode. +class Mirrored(DistributedDelegate): """Holds a map from device to values which are kept in sync.""" - pass + + def _get_cross_tower(self): + device = device_util.canonicalize(device_util.current()) + if device in self._index: + return self._index[device] + return list(self._index.values())[0] def _assign_on_device(device, variable, tensor): @@ -186,11 +202,58 @@ class DistributedVariable(DistributedDelegate): # Child class must set self._primary_var before calling # super(...).__init__(index). self._common_name = self._primary_var.name.split(":")[0] + # Use a weakref to make it easy to map from the contained values + # to the container without introducing a reference cycle. + for v in six.itervalues(index): + v._distributed_container = weakref.ref(self) # pylint: disable=protected-access + # tf.keras keeps track of variables initialized using this attribute. When + # tf.keras gets the default session, it initializes all uninitialized vars. + # We need to make _keras_initialized a member of DistributedVariable because + # without this it will use `__getattr__` which will delegate to a component + # variable. + self._keras_initialized = False + # Typically, a `DistributedVariable`'s initializer is composed of the + # initializers of the components variables. However, in some cases, such as + # when restoring from a checkpoint, we may set the _initializer_op + # property on the entire `DistributedVariable`. + self._initializer_op = None super(DistributedVariable, self).__init__(index) + def is_initialized(self, name=None): + """Identifies if all the component variables are initialized. + + Args: + name: Name of the final `logical_and` op. + + Returns: + The op that evaluates to True or False depending on if all the + component variables are initialized. + """ + # We have to cast the self._index.values() to a `list` because when we + # use `model_to_estimator` to run tf.keras models, self._index.values() is + # of type `dict_values` and not `list`. + values_list = list(self._index.values()) + result = values_list[0].is_initialized() + # We iterate through the list of values except the last one to allow us to + # name the final `logical_and` op the same name that is passed by the user + # to the `is_initialized` op. For distributed variables, the + # `is_initialized` op is a `logical_and` op. + for v in values_list[1:-1]: + result = math_ops.logical_and(result, v.is_initialized()) + result = math_ops.logical_and(result, values_list[-1].is_initialized(), + name=name) + return result + @property def initializer(self): - return control_flow_ops.group([v.initializer for v in self._index.values()]) + if self._initializer_op: + init_op = self._initializer_op + else: + # return grouped ops of all the var initializations of component values of + # the mirrored variable + init_op = control_flow_ops.group( + [v.initializer for v in self._index.values()]) + return init_op @property def graph(self): @@ -227,46 +290,24 @@ class DistributedVariable(DistributedDelegate): # We want cross-tower code that does some var.op.X calls # to work (even if the current device isn't in self.devices), but # other uses of var.op in a cross-tower context to fail. - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): return DistributedVarOp(self._primary_var.op.name, self._primary_var.op.graph, self._primary_var.op.type) return self.get().op + def read_value(self): + return distribution_strategy_context.get_distribution_strategy().read_var( + self) + def _should_act_as_resource_variable(self): """Pass resource_variable_ops.is_resource_variable check.""" pass -# Register a conversion function which reads the value of the variable, -# allowing instances of the class to be used as tensors. -def _tensor_conversion(var, dtype=None, name=None, as_ref=False): - # Try to avoid assignments to and other mutations of MirroredVariable - # state except through a DistributionStrategy.update() call. - assert not as_ref - return ops.internal_convert_to_tensor( - var.get(), dtype=dtype, name=name, as_ref=as_ref) - - -ops.register_tensor_conversion_function(DistributedVariable, _tensor_conversion) ops.register_dense_tensor_like_type(DistributedVariable) -class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable): - """Class for defining how to restore a MirroredVariable.""" - - def __init__(self, mirrored_variable, primary_variable, name): - self._mirrored_variable = mirrored_variable - super(_MirroredSaveable, self).__init__(primary_variable, "", name) - - def restore(self, restored_tensors, restored_shapes): - """Restore the same value into all variables.""" - tensor, = restored_tensors - return control_flow_ops.group([ - _assign_on_device(d, v, tensor) - for d, v in six.iteritems(self._mirrored_variable._index)]) # pylint: disable=protected-access - - def _get_update_device(): """Validate we are in update/update_non_slot() and return current device. @@ -287,34 +328,90 @@ def _get_update_device(): return device +class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable): + """Class for defining how to restore a MirroredVariable.""" + + def __init__(self, mirrored_variable, primary_variable, name): + self._mirrored_variable = mirrored_variable + super(_MirroredSaveable, self).__init__(primary_variable, "", name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into all variables.""" + tensor, = restored_tensors + return control_flow_ops.group([ + _assign_on_device(d, v, tensor) + for d, v in six.iteritems(self._mirrored_variable._index)]) # pylint: disable=protected-access + + class MirroredVariable(DistributedVariable, Mirrored, checkpointable.CheckpointableBase): """Holds a map from device to variables whose values are kept in sync.""" - def __init__(self, index, primary_var): + def __init__(self, index, primary_var, aggregation): # Use a weakref to make it easy to map from the contained values # to the container without introducing a reference cycle. for v in six.itervalues(index): v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access self._primary_var = primary_var + self._aggregation = aggregation super(MirroredVariable, self).__init__(index) - # We use _get_update_device() for the assign* methods to enforce - # that we are in an update() function. The arguments to update() are - # automatically unwrapped so the update() function would normally - # see regular variables, not MirroredVariables. However, the update - # function can still operate on wrapped MirroredVariables through - # object members, captured arguments, etc. This is more likely in an + # The arguments to update() are automatically unwrapped so the update() + # function would normally see regular variables, not MirroredVariables. + # However, the update function can still operate on wrapped MirroredVariables + # through object members, captured arguments, etc. This is more likely in an # update_non_slot() function (like OptimizerV2._finish), which can # update several non-slot variables in one call. + def _assign_func(self, *args, **kwargs): + f = kwargs.pop("f") + if distribution_strategy_context.get_cross_tower_context(): + update_device = distribute_lib.get_update_device() + # We are calling update on the mirrored variable in cross tower context. + if update_device is not None: + # We are calling an assign function on the mirrored variable in cross + # tower context. + v = self.get(device=update_device) + return f(v, *args, **kwargs) + + return distribution_strategy_context.get_distribution_strategy().update( + self, f, *args, **kwargs) + else: + _assert_tower_context() + # We are calling an assign function on the mirrored variable in tower + # context. + # We reduce the value we want to assign/add/sub. More details about how we + # handle the different use cases can be found in the _reduce method. + # We call the function on each of the mirrored variables with the reduced + # value. + if self._aggregation == vs.VariableAggregation.NONE: + raise ValueError("You must specify an aggregation method to update a " + "MirroredVariable in Tower Context.") + + def merge_fn(strategy, value, *other_args, **other_kwargs): + return strategy.update( + self, f, + strategy.reduce( + aggregation=self._aggregation, value=value, destinations=self), + *other_args, **other_kwargs) + + return distribution_strategy_context.get_tower_context().merge_call( + merge_fn, *args, **kwargs) + def assign_sub(self, *args, **kwargs): - return self.get(device=_get_update_device()).assign_sub(*args, **kwargs) + assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) + return self._assign_func(f=assign_sub_fn, *args, **kwargs) def assign_add(self, *args, **kwargs): - return self.get(device=_get_update_device()).assign_add(*args, **kwargs) + assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw) + return self._assign_func(f=assign_add_fn, *args, **kwargs) def assign(self, *args, **kwargs): - return self.get(device=_get_update_device()).assign(*args, **kwargs) + assign_fn = lambda var, *a, **kw: var.assign(*a, **kw) + return self._assign_func(f=assign_fn, *args, **kwargs) + + @property + def aggregation(self): + return self._aggregation def _get_cross_tower(self): device = device_util.canonicalize(device_util.current()) @@ -324,7 +421,7 @@ class MirroredVariable(DistributedVariable, Mirrored, def _as_graph_element(self): # pylint: disable=protected-access - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): return self._primary_var._as_graph_element() return self.get()._as_graph_element() @@ -342,6 +439,20 @@ class MirroredVariable(DistributedVariable, Mirrored, return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} +# Register a conversion function which reads the value of the variable, +# allowing instances of the class to be used as tensors. +def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False): + # Try to avoid assignments to and other mutations of MirroredVariable + # state except through a DistributionStrategy.update() call. + assert not as_ref + return ops.internal_convert_to_tensor( + var.get(), dtype=dtype, name=name, as_ref=as_ref) + + +ops.register_tensor_conversion_function(MirroredVariable, + _tensor_conversion_mirrored) + + class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): """Class for defining how to restore a TowerLocalVariable.""" @@ -350,7 +461,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): # We use a callable so that we don't have to evaluate this expression # in the case where we are trying to restore instead of save. def tensor(): - return distribute_lib.get_distribution_strategy().fetch( + return distribution_strategy_context.get_distribution_strategy().read_var( tower_local_variable) spec = saver.BaseSaverBuilder.SaveSpec( tensor=tensor, @@ -362,18 +473,11 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): def restore(self, restored_tensors, restored_shapes): """Restore the same value into all variables.""" tensor, = restored_tensors - # To preserve the sum across save and restore, we have to divide the - # total across all devices when restoring a variable that was summed - # when saving. - if self._tower_local_variable.reduce_method == "sum": - tensor *= 1. / len(self._tower_local_variable.devices) - return control_flow_ops.group([ - _assign_on_device(d, v, tensor) - for d, v in six.iteritems(self._tower_local_variable._index)]) # pylint: disable=protected-access + return self._tower_local_variable.assign(tensor) def _assert_tower_context(): - if not distribute_lib.get_tower_context(): + if not distribution_strategy_context.get_tower_context(): raise RuntimeError( "Tower-local variables may only be assigned in a tower context.") @@ -382,9 +486,9 @@ class TowerLocalVariable(DistributedVariable, PerDevice, checkpointable.CheckpointableBase): """Holds a map from device to variables whose values are reduced on save.""" - def __init__(self, index, primary_var, reduce_method): + def __init__(self, index, primary_var, aggregation): self._primary_var = primary_var - self._reduce_method = reduce_method + self._aggregation = aggregation super(TowerLocalVariable, self).__init__(index) def assign_sub(self, *args, **kwargs): @@ -396,24 +500,35 @@ class TowerLocalVariable(DistributedVariable, PerDevice, return self.get().assign_add(*args, **kwargs) def assign(self, *args, **kwargs): - _assert_tower_context() - return self.get().assign(*args, **kwargs) + if distribution_strategy_context.get_cross_tower_context(): + # To preserve the sum across save and restore, we have to divide the + # total across all devices when restoring a variable that was summed + # when saving. + tensor = args[0] + if self._aggregation == vs.VariableAggregation.SUM: + tensor *= 1. / len(self.devices) + return control_flow_ops.group( + [_assign_on_device(d, v, tensor) + for d, v in six.iteritems(self._index)]) + else: + _assert_tower_context() + return self.get().assign(*args, **kwargs) @property - def reduce_method(self): - return self._reduce_method + def aggregation(self): + return self._aggregation def _get_cross_tower(self): all_components = tuple(self._index.values()) # TODO(josh11b): Use a strategy-specific method. total = math_ops.add_n(all_components) - if self._reduce_method == "mean": + if self._aggregation == vs.VariableAggregation.MEAN: return total * (1./ len(all_components)) return total def _as_graph_element(self): # pylint: disable=protected-access - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): return self._get_cross_tower() return self.get()._as_graph_element() @@ -431,6 +546,17 @@ class TowerLocalVariable(DistributedVariable, PerDevice, return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} +# Register a conversion function for TowerLocalVariable which allows as_ref to +# be true. +def _tensor_conversion_tower_local(var, dtype=None, name=None, as_ref=False): + return ops.internal_convert_to_tensor( + var.get(), dtype=dtype, name=name, as_ref=as_ref) + + +ops.register_tensor_conversion_function(TowerLocalVariable, + _tensor_conversion_tower_local) + + def _devices_match(d1, d2): return device_util.canonicalize(d1) == device_util.canonicalize(d2) @@ -478,40 +604,40 @@ def regroup(per_device, wrap_class=PerDevice): same_id = False break # Consider three cases where same_id is true: - # * If v0 is a MirroredVariable (and same_id means it is the same - # across all devices), we want to return it. We check - # MirroredVariable specifically since it can look like it - # has a _mirrored_container member since its members do. - # * If v0 is a member of a mirrored variable, in which case - # hasattr(v0, "_mirrored_container") is true, we want to - # return the MirroredVariable that contains it using the - # _mirrored_container logic below. This case can trigger + # * If v0 is a DistributedVariable (a MirroredVariable or + # TowerLocalVariable, and same_id means it is the same across all + # devices), we want to return it. We check DistributedVariable + # specifically since it can look like it has a + # _distributed_container member since its members do. + # * If v0 is a member of a distributed variable, in which case + # hasattr(v0, "_distributed_container") is true, we want to + # return the DistributedVariable that contains it using the + # _distributed_container logic below. This case can trigger # same_id when there is only one device. # * In any other situation, same_id means we return v0. - if same_id and (isinstance(v0, MirroredVariable) or - not hasattr(v0, "_mirrored_container")): + if same_id and (isinstance(v0, DistributedVariable) or + not hasattr(v0, "_distributed_container")): return v0 # Detect the case where each device has a parallel component of the - # same MirroredVariable. In this case we want to return the - # containing MirroredVariable, after a bunch of sanity checking. - # In particular, each component should have the same container, - # and the devices of the variables should match the keys of the - # per-device dictionary. - # TODO(josh11b): Do we need similar logic for TowerLocalVariables? - if hasattr(v0, "_mirrored_container"): + # same MirroredVariable (or TowerLocalVariable). In this case we + # want to return the containing MirroredVariable, after a bunch of + # sanity checking. In particular, each component should have the + # same container, and the devices of the variables should match the + # keys of the per-device dictionary. + if hasattr(v0, "_distributed_container"): # pylint: disable=protected-access assert not isinstance(v0, MirroredVariable), ( "ids = %s, items = %s" % ([id(v[1]) for v in items], items)) assert _devices_match(v0.device, items[0][0]), ( "v0.device = %s, items = %s" % (v0.device, items)) - mirrored_container = v0._mirrored_container() - assert mirrored_container is not None + distributed_container = v0._distributed_container() + assert distributed_container is not None for d, v in items[1:]: assert _devices_match(v.device, d), ( "v.device = %s, d = %s, items = %s" % (v.device, d, items)) - assert mirrored_container is v._mirrored_container() - return mirrored_container + assert distributed_container is v._distributed_container() + return distributed_container # pylint: enable=protected-access return wrap_class(per_device) @@ -593,8 +719,7 @@ class PerDeviceDataset(object): # TODO(priyag): If dropping remainder is not appropriate, find another # approach to distributing the dataset when not possible to divide evenly. # Possibly not an issue when we start using PartitionedDataset. - self._dataset = dataset.apply( - batching.batch_and_drop_remainder(len(devices))) + self._dataset = dataset.batch(len(devices), drop_remainder=True) def make_one_shot_iterator(self): """Get a one time use iterator for the distributed PerDeviceDataset.""" @@ -805,3 +930,130 @@ class MapOutput(object): def get(self): return self._l + + +class MultiStepContext(object): + """A context object that can be used to capture things when running steps. + + This context object is useful when running multiple steps at a time using the + `run_steps_on_dataset` API. For e.g. it allows the user's step function to + specify which outputs to emit at what frequency. Currently it supports + capturing output from the last step, as well as capturing non tensor outputs. + In the future it will be augmented to support other use cases such as output + each N steps. + """ + + def __init__(self): + """Initializes an output context. + + Returns: + A context object. + """ + self._last_step_outputs = {} + self._last_step_outputs_aggregations = {} + self._non_tensor_outputs = {} + + @property + def last_step_outputs(self): + """A dictionary consisting of outputs to be captured on last step. + + Keys in the dictionary are names of tensors to be captured, as specified + when `set_last_step_output` is called. + Values in the dictionary are the tensors themselves. If + `set_last_step_output` was called with an `aggregation` for this output, + then the value is the aggregated value. + + Returns: + A dictionary with last step outputs. + """ + return self._last_step_outputs + + def _set_last_step_outputs(self, outputs): + """Replace the entire dictionary of last step outputs.""" + if not isinstance(outputs, dict): + raise ValueError("Need a dictionary to set last_step_outputs.") + self._last_step_outputs = outputs + + def set_last_step_output(self, name, output, + aggregation=variables_lib.VariableAggregation.NONE): + """Set `output` with `name` to be outputted from the last step. + + Args: + name: String, name to identify the output. Doesn't need to match tensor + name. + output: The tensors that should be outputted with `name`. See below for + actual types supported. + aggregation: Aggregation method to use to aggregate outputs from multiple + towers. Required if `set_last_step_output` is called in a tower context. + Optional in cross_tower_context. + When present, the outputs from all the towers are aggregated using the + current distribution strategy's `reduce` method. Hence, the type of + `output` must be what's supported by the corresponding `reduce` method. + For e.g. if using MirroredStrategy and aggregation is set, output + must be a `PerDevice` value. + The aggregation method is also recorded in a dictionary + `_last_step_outputs_aggregations` for later interpreting of the + outputs as already reduced or not. + + """ + if distribution_strategy_context.get_cross_tower_context(): + self._last_step_outputs_aggregations[name] = aggregation + if aggregation is variables_lib.VariableAggregation.NONE: + self._last_step_outputs[name] = output + else: + distribution = distribution_strategy_context.get_distribution_strategy() + self._last_step_outputs[name] = distribution.reduce( + aggregation, output, destinations="/device:CPU:0") + else: + assert aggregation is not variables_lib.VariableAggregation.NONE + def merge_fn(distribution, value): + self._last_step_outputs[name] = distribution.reduce( + aggregation, value, destinations="/device:CPU:0") + # Setting this inside the `merge_fn` because all towers share the same + # context object, so it's more robust to set it only once (even if all + # the towers are trying to set the same value). + self._last_step_outputs_aggregations[name] = aggregation + + distribution_strategy_context.get_tower_context().merge_call( + merge_fn, output) + + @property + def non_tensor_outputs(self): + """A dictionary consisting of any non tensor outputs to be captured.""" + return self._non_tensor_outputs + + def set_non_tensor_output(self, name, output): + """Set `output` with `name` to be captured as a non tensor output.""" + if distribution_strategy_context.get_cross_tower_context(): + self._non_tensor_outputs[name] = output + else: + def merge_fn(distribution, value): + # NOTE(priyag): For non tensor outputs, we simply return all the values + # in a list as aggregation doesn't make sense on non tensors. + self._non_tensor_outputs[name] = distribution.unwrap(value) + distribution_strategy_context.get_tower_context().merge_call( + merge_fn, output) + + +def value_container(val): + """Returns the container that this per-device `value` belongs to. + + Args: + val: A value returned by `call_for_each_tower()` or a variable + created in `scope()`. + + Returns: + A container that `value` belongs to. + If value does not belong to any container (including the case of + container having been destroyed), returns the value itself. + """ + # pylint: disable=protected-access + if (hasattr(val, "_distributed_container") and + # DistributedVariable has _distributed_container defined + # but we don't want to return it. + not isinstance(val, DistributedVariable)): + container = val._distributed_container() + # pylint: disable=protected-access + if container is not None: + return container + return val diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 1c95758d96aba47e9581dde6411763e98b99a968..91a43d499933c77de846085e0f12abf3064b0499 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -32,6 +32,7 @@ from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops @@ -79,10 +80,34 @@ class DistributedValuesTest(test.TestCase): with self.assertRaises(AssertionError): v = values.DistributedValues({"/device:cpu:0": 42}) + def testIsTensorLike(self): + with context.graph_mode(), \ + ops.Graph().as_default(), \ + ops.device("/device:CPU:0"): + one = constant_op.constant(1) + two = constant_op.constant(2) + v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + self.assertEqual(two, v.get("/device:GPU:0")) + self.assertEqual(one, v.get()) + self.assertTrue(v.is_tensor_like) + self.assertTrue(tensor_util.is_tensor(v)) + + def testIsTensorLikeWithAConstant(self): + with context.graph_mode(), \ + ops.Graph().as_default(), \ + ops.device("/device:CPU:0"): + one = constant_op.constant(1) + two = 2.0 + v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + self.assertEqual(two, v.get("/device:GPU:0")) + self.assertEqual(one, v.get()) + self.assertFalse(v.is_tensor_like) + self.assertFalse(tensor_util.is_tensor(v)) + class DistributedDelegateTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetAttr(self): with ops.device("/device:CPU:0"): @@ -97,7 +122,7 @@ class DistributedDelegateTest(test.TestCase): with self.assertRaises(AttributeError): _ = v.y - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testOperatorOverride(self): with ops.device("/device:CPU:0"): v = values.DistributedDelegate({"/device:CPU:0": 7, "/device:GPU:0": 8}) @@ -158,7 +183,8 @@ def _make_mirrored(): v.append(variable_scope.get_variable( name=n, initializer=init, use_resource=True)) index[d] = v[-1] - mirrored = values.MirroredVariable(index, v[0]) + mirrored = values.MirroredVariable(index, v[0], + variable_scope.VariableAggregation.SUM) return v, devices, mirrored @@ -277,7 +303,8 @@ class RegroupAndSelectDeviceTest(test.TestCase): v = variable_scope.get_variable( name="v", initializer=1., use_resource=True) index = {d: v} - mirrored = values.MirroredVariable(index, v) + mirrored = values.MirroredVariable(index, v, + variable_scope.VariableAggregation.SUM) result = values.regroup(index) self.assertIs(mirrored, result) @@ -363,7 +390,7 @@ class PerDeviceDatasetTest(test.TestCase): self._test_iterator_no_prefetch(devices, dataset, expected_values) self._test_iterator_with_prefetch(devices, dataset, expected_values) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testOneDevice(self): devices = ["/device:CPU:0"] dataset = dataset_ops.Dataset.range(10) @@ -581,7 +608,8 @@ class MirroredVariableTest(test.TestCase): v = variable_scope.get_variable( name="v", initializer=[1.], use_resource=True) index = {"/job:foo/device:CPU:0": v} - mirrored = values.MirroredVariable(index, v) + mirrored = values.MirroredVariable(index, v, + variable_scope.VariableAggregation.MEAN) self.assertEquals(v.name, mirrored.name) self.assertEquals(v.dtype, mirrored.dtype) @@ -716,7 +744,9 @@ class MirroredVariableTest(test.TestCase): with ops.device("/device:GPU:0"): v = variable_scope.get_variable( name="v", initializer=1., use_resource=True) - mirrored = values.MirroredVariable({"/device:GPU:0": v}, v) + mirrored = values.MirroredVariable({ + "/device:GPU:0": v + }, v, variable_scope.VariableAggregation.MEAN) sess.run(variables_lib.global_variables_initializer()) sess.run({"complicated": mirrored}) @@ -746,24 +776,27 @@ class TowerLocalVariableTest(test.TestCase): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - v, tower_local = _make_tower_local("sum") + v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM) self.assertEquals(v[0].name, tower_local.name) self.assertEquals(v[0].dtype, tower_local.dtype) self.assertEquals(v[0].shape, tower_local.shape) - self.assertEquals("sum", tower_local.reduce_method) + self.assertEquals(variable_scope.VariableAggregation.SUM, + tower_local.aggregation) @test_util.run_in_graph_and_eager_modes(config=config) def testVariableOnAnotherDevice(self): v = variable_scope.get_variable( name="v", initializer=[1.], use_resource=True) index = {"/job:foo/device:CPU:0": v} - tower_local = values.TowerLocalVariable(index, v, "mean") + tower_local = values.TowerLocalVariable( + index, v, variable_scope.VariableAggregation.MEAN) self.assertEquals(v.name, tower_local.name) self.assertEquals(v.dtype, tower_local.dtype) self.assertEquals(v.shape, tower_local.shape) - self.assertEquals("mean", tower_local.reduce_method) + self.assertEquals(variable_scope.VariableAggregation.MEAN, + tower_local.aggregation) def _assign_tower_local(self, devices, v, new): for d, var, n in zip(devices, v, new): @@ -789,7 +822,7 @@ class TowerLocalVariableTest(test.TestCase): self.skipTest("A GPU is not available for this test in eager mode.") with self.test_session() as sess: - v, tower_local = _make_tower_local("sum") + v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM) # Overwrite the initial values. self._assign_tower_local(_devices, v, [3., 4.]) @@ -812,7 +845,8 @@ class TowerLocalVariableTest(test.TestCase): self.skipTest("A GPU is not available for this test in eager mode.") with self.test_session() as sess: - v, tower_local = _make_tower_local("mean") + v, tower_local = _make_tower_local( + variable_scope.VariableAggregation.MEAN) # Overwrite the initial values. self._assign_tower_local(_devices, v, [3., 4.]) @@ -831,7 +865,8 @@ class TowerLocalVariableTest(test.TestCase): def _save_tower_local_mean(self): """Save variables with mirroring, returns save_path.""" with self.test_session(graph=ops.Graph()) as sess: - v, tower_local = _make_tower_local("mean") + v, tower_local = _make_tower_local( + variable_scope.VariableAggregation.MEAN) # Overwrite the initial values. self._assign_tower_local(_devices, v, [3., 4.]) @@ -893,7 +928,8 @@ class TowerLocalVariableTest(test.TestCase): def _restore_tower_local_mean(self, save_path): """Restore to variables with mirroring in a fresh graph.""" with self.test_session(graph=ops.Graph()) as sess: - v, tower_local = _make_tower_local("mean") + v, tower_local = _make_tower_local( + variable_scope.VariableAggregation.MEAN) # Overwrite the initial values. self._assign_tower_local(_devices, v, [7., 8.]) @@ -907,7 +943,7 @@ class TowerLocalVariableTest(test.TestCase): def _restore_tower_local_sum(self, save_path): """Restore to variables with mirroring in a fresh graph.""" with self.test_session(graph=ops.Graph()) as sess: - v, tower_local = _make_tower_local("sum") + v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM) # Overwrite the initial values. self._assign_tower_local(_devices, v, [7., 8.]) @@ -966,6 +1002,18 @@ class TowerLocalVariableTest(test.TestCase): save_path = self._save_normal() self._restore_tower_local_sum(save_path) + def testTensorConversion(self): + with context.graph_mode(): + _, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM) + converted = ops.internal_convert_to_tensor(tower_local, as_ref=False) + self.assertIsInstance(converted, ops.Tensor) + self.assertEqual(converted.dtype, tower_local.dtype) + + converted = ops.internal_convert_to_tensor(tower_local, as_ref=True) + # Resources variable are converted to tensors as well when as_ref is True. + self.assertIsInstance(converted, ops.Tensor) + self.assertEqual(converted.dtype, tower_local.dtype) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/warm_starting_util_test.py b/tensorflow/contrib/distribute/python/warm_starting_util_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d8bacdb338d93a169a26a55d8ee5f5f9f0d59fce --- /dev/null +++ b/tensorflow/contrib/distribute/python/warm_starting_util_test.py @@ -0,0 +1,97 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for warm_starting_util with Distribution Strategy. + +These tests are located here instead of as part of `WarmStartingUtilTest` +because they need access to distribution strategies which are only present in +contrib right now. +TODO(priyag): Move the tests to core `WarmStartingUtilTest` when distribution +strategy moves out of contrib. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.framework import ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import warm_starting_util as ws_util + + +class WarmStartingUtilWithDistributionStrategyTest( + test.TestCase, parameterized.TestCase): + + @combinations.generate(combinations.combine( + distribution=[combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus], + save_with_distribution=[True, False], + restore_with_distribution=[True, False], + mode=["graph"])) + def testWarmStart(self, distribution, save_with_distribution, + restore_with_distribution): + + var_name = "v" + original_value = [[1., 2.], [3., 4.]] + + # Create variable and save checkpoint from which to warm-start. + def create_var(g): + with self.test_session(graph=g) as sess: + var = variable_scope.get_variable(var_name, initializer=original_value) + sess.run(variables.global_variables_initializer()) + saver = saver_lib.Saver() + ckpt_prefix = os.path.join(self.get_temp_dir(), "model") + saver.save(sess, ckpt_prefix, global_step=0) + return var, sess.run(var) + + if save_with_distribution: + with ops.Graph().as_default() as g, distribution.scope(): + _, prev_init_val = create_var(g) + else: + with ops.Graph().as_default() as g: + _, prev_init_val = create_var(g) + + # Verify we initialized the values correctly. + self.assertAllEqual(original_value, prev_init_val) + + def warm_start(g): + with self.test_session(graph=g) as sess: + # Initialize with zeros. + var = variable_scope.get_variable( + var_name, initializer=[[0., 0.], [0., 0.]]) + ws_util.warm_start(self.get_temp_dir()) + sess.run(variables.global_variables_initializer()) + # Verify weights were correctly warm-started to previous values. + self.assertAllEqual(original_value, self.evaluate(var)) + + # Warm start in a new graph. + if restore_with_distribution: + with ops.Graph().as_default() as g, distribution.scope(): + warm_start(g) + else: + with ops.Graph().as_default() as g: + warm_start(g) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index ad00d1734dd14ed846522a33d888a5387cb25cc6..a8d0d493abcd7de540799f6b94c3cdb9ce9dafae 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -124,7 +124,7 @@ cuda_py_test( cuda_py_test( name = "conditional_distribution_test", - size = "small", + size = "medium", srcs = [ "python/kernel_tests/conditional_distribution_test.py", "python/kernel_tests/distribution_test.py", diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py index e281e81bdf0698c1f7b2f60fb27783dd1351773f..d1ce273499c8a646c0757844c91a785fa8d56ce4 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py @@ -61,6 +61,28 @@ class CholeskyOuterProductBijectorTest(test.TestCase): atol=0., rtol=1e-7) + def testNoBatchStaticJacobian(self): + x = np.eye(2) + bijector = bijectors.CholeskyOuterProduct() + + # The Jacobian matrix is 2 * tf.eye(2), which has jacobian determinant 4. + self.assertAllClose( + np.log(4), + self.evaluate(bijector.forward_log_det_jacobian(x, event_ndims=2))) + + def testNoBatchDynamicJacobian(self): + x = np.eye(2) + bijector = bijectors.CholeskyOuterProduct() + x_pl = array_ops.placeholder(dtypes.float32) + + with self.test_session(): + log_det_jacobian = bijector.forward_log_det_jacobian(x_pl, event_ndims=2) + + # The Jacobian matrix is 2 * tf.eye(2), which has jacobian determinant 4. + self.assertAllClose( + np.log(4), + log_det_jacobian.eval({x_pl: x})) + def testNoBatchStatic(self): x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y) y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py index caeaf2a0c6e4fff28c0edd82cb09ca0bcee85fc3..3530e142e4d1545e80a3b1bf1e8ddbf7819ba58a 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py @@ -31,7 +31,7 @@ from tensorflow.python.platform import test class FillTriangularBijectorTest(test.TestCase): """Tests the correctness of the FillTriangular bijector.""" - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijector(self): x = np.float32(np.array([1., 2., 3.])) y = np.float32(np.array([[3., 0.], @@ -51,7 +51,7 @@ class FillTriangularBijectorTest(test.TestCase): ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) self.assertAllClose(ildj, 0.) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testShape(self): x_shape = tensor_shape.TensorShape([5, 4, 6]) y_shape = tensor_shape.TensorShape([5, 4, 3, 3]) @@ -76,7 +76,7 @@ class FillTriangularBijectorTest(test.TestCase): b.inverse_event_shape_tensor(y_shape.as_list())) self.assertAllEqual(x_shape_tensor, x_shape.as_list()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testShapeError(self): b = bijectors.FillTriangular(validate_args=True) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py index 18397035571561731698b06d90e20dc74e3cf83c..49a9afe3f6debe048369c52328fb5534946ab9e5 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py @@ -29,7 +29,18 @@ from tensorflow.python.platform import test class MatrixInverseTriLBijectorTest(test.TestCase): """Tests the correctness of the Y = inv(tril) transformation.""" - @test_util.run_in_graph_and_eager_modes() + #The inverse of 0 is undefined, as the numbers above the main + #diagonal must be zero, we zero out these numbers after running inverse. + #See: https://github.com/numpy/numpy/issues/11445 + def _inv(self, x): + y = np.linalg.inv(x) + #triu_indices only works on 2d arrays + #need to iterate over all the 2d arrays in a x-dimensional array. + for idx in np.ndindex(y.shape[0:-2]): + y[idx][np.triu_indices(y[idx].shape[-1], 1)] = 0 + return y + + @test_util.run_in_graph_and_eager_modes def testComputesCorrectValues(self): inv = bijectors.MatrixInverseTriL(validate_args=True) self.assertEqual("matrix_inverse_tril", inv.name) @@ -51,7 +62,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase): self.assertNear(expected_fldj_, fldj_, err=1e-3) self.assertNear(-expected_fldj_, ildj_, err=1e-3) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testOneByOneMatrix(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.array([[5.]], dtype=np.float32) @@ -70,7 +81,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase): self.assertNear(expected_fldj_, fldj_, err=1e-3) self.assertNear(-expected_fldj_, ildj_, err=1e-3) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testZeroByZeroMatrix(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.eye(0, dtype=np.float32) @@ -89,7 +100,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase): self.assertNear(expected_fldj_, fldj_, err=1e-3) self.assertNear(-expected_fldj_, ildj_, err=1e-3) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBatch(self): # Test batch computation with input shape (2, 1, 2, 2), i.e. batch shape # (2, 1). @@ -98,7 +109,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase): [2., 3.]]], [[[4., 0.], [5., -6.]]]], dtype=np.float32) - x_inv_ = np.linalg.inv(x_) + x_inv_ = self._inv(x_) expected_fldj_ = -4. * np.sum( np.log(np.abs(np.diagonal(x_, axis1=-2, axis2=-1))), axis=-1) @@ -114,7 +125,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase): self.assertAllClose(expected_fldj_, fldj_, atol=0., rtol=1e-3) self.assertAllClose(-expected_fldj_, ildj_, atol=0., rtol=1e-3) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testErrorOnInputRankTooLow(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.array([0.1], dtype=np.float32) @@ -149,7 +160,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase): ## square_error_msg): ## inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testErrorOnInputNotLowerTriangular(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.array([[1., 2.], @@ -169,7 +180,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase): triangular_error_msg): inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testErrorOnInputSingular(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.array([[1., 0.], diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py index a5f5219588fb3be67beb797ba68ed8148e9e9fd2..cb42331a21a6acdd5244c311a7def5359bb6c574 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py @@ -36,7 +36,7 @@ class OrderedBijectorTest(test.TestCase): def setUp(self): self._rng = np.random.RandomState(42) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijectorVector(self): with self.test_session(): ordered = Ordered() @@ -82,7 +82,7 @@ class OrderedBijectorTest(test.TestCase): atol=0., rtol=1e-7) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testShapeGetters(self): with self.test_session(): x = tensor_shape.TensorShape([4]) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py index 566a7b3dff9b5d97a1cb143e0b32fc15984c3a02..d5b3367f9a31a9c602e0b138e617db68834b8229 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py @@ -46,7 +46,7 @@ class ScaleTriLBijectorTest(test.TestCase): x_ = self.evaluate(b.inverse(y)) self.assertAllClose(x, x_) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInvertible(self): # Generate random inputs from an unconstrained space, with diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py index 2ac06fce55b448a5f3da7ccb7f8766b5b1404ad7..d0098c3c105626da1da5855710169069ebeffbd9 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py @@ -40,7 +40,7 @@ class SoftsignBijectorTest(test.TestCase): def setUp(self): self._rng = np.random.RandomState(42) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijectorBounds(self): bijector = Softsign(validate_args=True) with self.test_session(): @@ -54,7 +54,7 @@ class SoftsignBijectorTest(test.TestCase): with self.assertRaisesOpError("less than 1"): bijector.inverse_log_det_jacobian(3., event_ndims=0).eval() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijectorForwardInverse(self): bijector = Softsign(validate_args=True) self.assertEqual("softsign", bijector.name) @@ -64,7 +64,7 @@ class SoftsignBijectorTest(test.TestCase): self.assertAllClose(y, self.evaluate(bijector.forward(x))) self.assertAllClose(x, self.evaluate(bijector.inverse(y))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijectorLogDetJacobianEventDimsZero(self): bijector = Softsign(validate_args=True) y = self._rng.rand(2, 10) @@ -74,7 +74,7 @@ class SoftsignBijectorTest(test.TestCase): self.assertAllClose(ildj, self.evaluate( bijector.inverse_log_det_jacobian(y, event_ndims=0))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijectorForwardInverseEventDimsOne(self): bijector = Softsign(validate_args=True) self.assertEqual("softsign", bijector.name) @@ -83,7 +83,7 @@ class SoftsignBijectorTest(test.TestCase): self.assertAllClose(y, self.evaluate(bijector.forward(x))) self.assertAllClose(x, self.evaluate(bijector.inverse(y))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijectorLogDetJacobianEventDimsOne(self): bijector = Softsign(validate_args=True) y = self._rng.rand(2, 10) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py index 6428a68702274fae384ae3de6d03f7ca126e2346..efc9f266d1fb6bcc53ae318e218b0697825c0155 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py @@ -31,7 +31,7 @@ class TransformDiagonalBijectorTest(test.TestCase): def setUp(self): self._rng = np.random.RandomState(42) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBijector(self): x = np.float32(np.random.randn(3, 4, 4)) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py b/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py index 90910f3839b1a4e882debf396b90955a42762794..200310bc414b6703d0683ce9f81b0aa5441f677d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py @@ -173,6 +173,13 @@ class DeterministicTest(test.TestCase): self.assertAllClose( np.zeros(sample_shape_ + (2,)).astype(np.float32), sample_) + def testEntropy(self): + loc = np.array([-0.1, -3.2, 7.]) + deterministic = deterministic_lib.Deterministic(loc=loc) + with self.test_session() as sess: + entropy_ = sess.run(deterministic.entropy()) + self.assertAllEqual(np.zeros(3), entropy_) + class VectorDeterministicTest(test.TestCase): @@ -290,6 +297,13 @@ class VectorDeterministicTest(test.TestCase): self.assertAllClose( np.zeros(sample_shape_ + (2, 1)).astype(np.float32), sample_) + def testEntropy(self): + loc = np.array([[8.3, 1.2, 3.3], [-0.1, -3.2, 7.]]) + deterministic = deterministic_lib.VectorDeterministic(loc=loc) + with self.test_session() as sess: + entropy_ = sess.run(deterministic.entropy()) + self.assertAllEqual(np.zeros(2), entropy_) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index bbbec2103aefd3f38a9b734bcd3f2e15fc8bb683..181c46d2e52552e641bc59c0fe94743f1af42845 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -544,7 +544,7 @@ class PadDynamicTest(_PadTest, test.TestCase): class TestMoveDimension(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_move_dimension_static_shape(self): x = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) @@ -561,7 +561,7 @@ class TestMoveDimension(test.TestCase): x_perm = distribution_util.move_dimension(x, 4, 2) self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 6, 4, 1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_move_dimension_dynamic_shape(self): x_ = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD index 03e26b198ea02ad1bef8bcd2f6076078ecd7df0b..42ecea034d77430924bd6f597bf42ec3f64fec92 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD @@ -34,7 +34,10 @@ py_test( name = "correlation_matrix_volumes_test", size = "medium", srcs = ["correlation_matrix_volumes_test.py"], - tags = ["no_pip"], + tags = [ + "no_pip", + "optonly", + ], deps = [ ":correlation_matrix_volumes_py", # For statistical testing diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index 8267ee7df89f69f8d610e9507e0cca9f4a5d4323..3e1e4fc82971b71792d193ea8518dd402e4a4d9d 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -182,7 +182,20 @@ class CholeskyOuterProduct(bijector.Bijector): axis=-1) fldj = p_float * np.log(2.) + sum_weighted_log_diag - return fldj + # We finally need to undo adding an extra column in non-scalar cases + # where there is a single matrix as input. + if x.get_shape().ndims is not None: + if x.get_shape().ndims == 2: + fldj = array_ops.squeeze(fldj, axis=-1) + return fldj + + shape = array_ops.shape(fldj) + maybe_squeeze_shape = array_ops.concat([ + shape[:-1], + distribution_util.pick_vector( + math_ops.equal(array_ops.rank(x), 2), + np.array([], dtype=np.int32), shape[-1:])], 0) + return array_ops.reshape(fldj, maybe_squeeze_shape) def _make_columnar(self, x): """Ensures non-scalar input has at least one column. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index b8f2a4b2c731bdaee78692c036fb9f2fba4e3760..296e66f2b24fecf2142066727b5b12ee5cbd0379 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -514,9 +514,8 @@ def masked_autoregressive_default_template( Masked Autoencoder for Distribution Estimation. In _International Conference on Machine Learning_, 2015. https://arxiv.org/abs/1502.03509 """ - - with ops.name_scope(name, "masked_autoregressive_default_template", - values=[log_scale_min_clip, log_scale_max_clip]): + name = name or "masked_autoregressive_default_template" + with ops.name_scope(name, values=[log_scale_min_clip, log_scale_max_clip]): def _fn(x): """MADE parameterized via `masked_autoregressive_default_template`.""" # TODO(b/67594795): Better support of dynamic shape. @@ -552,8 +551,7 @@ def masked_autoregressive_default_template( else _clip_by_value_preserve_grad) log_scale = which_clip(log_scale, log_scale_min_clip, log_scale_max_clip) return shift, log_scale - return template_ops.make_template( - "masked_autoregressive_default_template", _fn) + return template_ops.make_template(name, _fn) @deprecation.deprecated( diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index ad853ee293f86565c1af601214522f53d936b70a..affc64a14f6fe9ae6e08ceff2298bc99ee7caa43 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -152,6 +152,9 @@ class _BaseDeterministic(distribution.Distribution): """Relative tolerance for comparing points to `self.loc`.""" return self._rtol + def _entropy(self): + return array_ops.zeros(self.batch_shape_tensor(), dtype=self.dtype) + def _mean(self): return array_ops.identity(self.loc) diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py index 0c762f17c9b770ecada57b6ce60a4825ba374dd9..214c6dca4a7f2b4cd6242e1b7ca78be9eeffb851 100644 --- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py @@ -235,7 +235,7 @@ class OneHotCategorical(distribution.Distribution): return x return control_flow_ops.with_dependencies([ check_ops.assert_non_positive(x), - distribution_util.assert_close( + check_ops.assert_near( array_ops.zeros([], dtype=self.dtype), math_ops.reduce_logsumexp(x, axis=[-1])), ], x) diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py index ef3bdfa75fcaa8df17db1238ceadadf788601356..18a0f754e6e618f240db109f593a80dec57e200b 100644 --- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py @@ -326,6 +326,21 @@ class QuantizedDistribution(distributions.Distribution): graph_parents=graph_parents, name=name) + @property + def distribution(self): + """Base distribution, p(x).""" + return self._dist + + @property + def low(self): + """Lowest value that quantization returns.""" + return self._low + + @property + def high(self): + """Highest value that quantization returns.""" + return self._high + def _batch_shape_tensor(self): return self.distribution.batch_shape_tensor() @@ -569,8 +584,3 @@ class QuantizedDistribution(distributions.Distribution): dependencies = [distribution_util.assert_integer_form( value, message="value has non-integer components.")] return control_flow_ops.with_dependencies(dependencies, value) - - @property - def distribution(self): - """Base distribution, p(x).""" - return self._dist diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index 9b5bd7576f2a3c364e21da76dd3905a8c6e35829..25aaac379a7c54c832bdcf962e16f339522d61fc 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -299,7 +299,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): return x return control_flow_ops.with_dependencies([ check_ops.assert_non_positive(x), - distribution_util.assert_close( + check_ops.assert_near( array_ops.zeros([], dtype=self.dtype), math_ops.reduce_logsumexp(x, axis=[-1])), ], x) diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py b/tensorflow/contrib/distributions/python/ops/sample_stats.py index f5aaa5cf34abde3ea4d25de1ecf3adaef3f2a770..aa680a92be64cf0f099acd335369f2a1610c5953 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -134,7 +134,7 @@ def auto_correlation( x_len = util.prefer_static_shape(x_rotated)[-1] # TODO(langmore) Investigate whether this zero padding helps or hurts. At - # the moment is is necessary so that all FFT implementations work. + # the moment is necessary so that all FFT implementations work. # Zero pad to the next power of 2 greater than 2 * x_len, which equals # 2**(ceil(Log_2(2 * x_len))). Note: Log_2(X) = Log_e(X) / Log_e(2). x_len_float64 = math_ops.cast(x_len, np.float64) @@ -198,7 +198,7 @@ def auto_correlation( # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]). The # other terms were zeros arising only due to zero padding. # `denominator = (N / 2 - m)` (defined below) is the proper term to - # divide by by to make this an unbiased estimate of the expectation + # divide by to make this an unbiased estimate of the expectation # E[X[n] Conj(X[n - m])]. x_len = math_ops.cast(x_len, dtype.real_dtype) max_lags = math_ops.cast(max_lags, dtype.real_dtype) diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md index 4384431e7b9c3e6ef259391fa9efa5a35d23c86a..86d203452e24d6d73f3ebb17b989867905a61382 100644 --- a/tensorflow/contrib/eager/README.md +++ b/tensorflow/contrib/eager/README.md @@ -44,7 +44,7 @@ Installation instructions at https://www.tensorflow.org/install/ For an introduction to eager execution in TensorFlow, see: -- [User Guide](https://www.tensorflow.org/programmers_guide/eager) ([source](../../docs_src/programmers_guide/eager.md)) +- [User Guide](https://www.tensorflow.org/guide/eager) ([source](../../docs_src/guide/eager.md)) - Notebook: [Basic Usage](python/examples/notebooks/1_basics.ipynb) - Notebook: [Gradients](python/examples/notebooks/2_gradients.ipynb) - Notebook: [Importing Data](python/examples/notebooks/3_datasets.ipynb) diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 0cc764d2208c5b061b7b836bdf57a035f52c6fcf..f7933639a086483b8dc044837276ce0e76840319 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -199,7 +199,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", - "//tensorflow/python/estimator:util", + "//tensorflow/python/estimator:estimator_py", ], ) @@ -223,3 +223,17 @@ py_test( "//tensorflow/python/eager:test", ], ) + +py_test( + name = "remote_test", + srcs = ["remote_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/eager/python:tfe", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python/eager:function", + ], +) diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index adf92c27ea0a27c5741bcdd175b277462cb28d02..135095a97980da8988b976948fb18492526e390c 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -18,39 +18,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import threading - from tensorflow.contrib.data.python.ops import prefetching_ops from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.saver import BaseSaverBuilder - -_uid_counter = 0 -_uid_lock = threading.Lock() - - -def _generate_shared_name(prefix): - with _uid_lock: - global _uid_counter - uid = _uid_counter - _uid_counter += 1 - return "{}{}".format(prefix, uid) -class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): +class Iterator(iterator_ops.EagerIterator): """An iterator producing tf.Tensor objects from a tf.data.Dataset. NOTE: Unlike the iterator created by the - @{tf.data.Dataset.make_one_shot_iterator} method, this class enables + `tf.data.Dataset.make_one_shot_iterator` method, this class enables additional experimental functionality, such as prefetching to the GPU. """ @@ -80,37 +58,18 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): "`tf.contrib.eager.Iterator`. Use `for ... in dataset:` to iterate " "over the dataset instead.") - super(Iterator, self).__init__(dataset) if not context.context().device_spec.device_type: is_remote_device = False else: is_remote_device = context.context().device_spec.device_type != "CPU" - self._buffer_resource_handle = None if is_remote_device: - with ops.device("/device:CPU:0"): - iter_string_handle = gen_dataset_ops.iterator_to_string_handle( - self._resource) - - @function.Defun(dtypes.string) - def remote_fn(h): - remote_iterator = iterator_ops.Iterator.from_string_handle( - h, self.output_types, self.output_shapes, self.output_classes) - return remote_iterator.get_next() - - remote_fn.add_to_graph(None) - target = constant_op.constant("/device:CPU:0") - with ops.device(self._device): - self._buffer_resource_handle = prefetching_ops.function_buffering_resource( # pylint: disable=line-too-long - string_arg=iter_string_handle, - f=remote_fn, - target_device=target, - buffer_size=10, - container="", - shared_name=_generate_shared_name( - "contrib_eager_iterator_function_buffer_resource")) - self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( # pylint: disable=line-too-long - handle=self._buffer_resource_handle, - handle_device=self._device) + with ops.device(None): + # Let the placer figure out where to place the various functions etc. + # created by the CopyToDeviceDataset. + dataset = dataset.apply(prefetching_ops.copy_to_device( + context.context().device_name)) + dataset = dataset.prefetch(1) + super(Iterator, self).__init__(dataset) def _next_internal(self): """Returns a nested structure of `tf.Tensor`s containing the next element. @@ -119,40 +78,4 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): # that there is no more data to iterate over. # TODO(b/77291417): Fix with context.execution_mode(context.SYNC): - if self._buffer_resource_handle is not None: - with ops.device(self._device): - ret = prefetching_ops.function_buffering_resource_get_next( - function_buffer_resource=self._buffer_resource_handle, - output_types=self._flat_output_types) - return sparse.deserialize_sparse_tensors( - nest.pack_sequence_as(self._output_types, ret), self._output_types, - self._output_shapes, self._output_classes) - else: - return super(Iterator, self)._next_internal() - - # TODO(shivaniagrawal): Expose checkpointable stateful objects from dataset - # attributes(potential). - - class _Saveable(BaseSaverBuilder.SaveableObject): - """SaveableObject for saving/restoring iterator state.""" - - def __init__(self, iterator_resource, name): - serialized_iterator = gen_dataset_ops.serialize_iterator( - iterator_resource) - specs = [ - BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE") - ] - # pylint: disable=protected-access - super(Iterator._Saveable, self).__init__(iterator_resource, specs, name) - - def restore(self, restored_tensors, restored_shapes): - with ops.colocate_with(self.op): - return gen_dataset_ops.deserialize_iterator(self.op, - restored_tensors[0]) - - def _gather_saveables_for_checkpoint(self): - - def _saveable_factory(name): - return self._Saveable(self._resource, name) - - return {"ITERATOR": _saveable_factory} + return super(Iterator, self)._next_internal() diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index 68bec9aee894edd60a025ac1cf87ca3e010db842..a753d77580758af9de8410de4a08f7ea278c4c79 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -37,6 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops +from tensorflow.python.training import checkpoint_management from tensorflow.python.training.checkpointable import util as checkpointable_utils @@ -193,6 +194,20 @@ class IteratorTest(test.TestCase): x = math_ops.add(x, x) self.assertAllEqual([0., 2.], x.numpy()) + def testGpuTensor(self): + ds = Dataset.from_tensors([0., 1.]) + with ops.device(test.gpu_device_name()): + for x in ds: + y = math_ops.add(x, x) + self.assertAllEqual([0., 2.], y.numpy()) + + def testGpuDefinedDataset(self): + with ops.device(test.gpu_device_name()): + ds = Dataset.from_tensors([0., 1.]) + for x in ds: + y = math_ops.add(x, x) + self.assertAllEqual([0., 2.], y.numpy()) + def testTensorsExplicitPrefetchToDevice(self): ds = Dataset.from_tensor_slices([0., 1.]) ds = ds.apply(prefetching_ops.prefetch_to_device(test.gpu_device_name())) @@ -292,6 +307,19 @@ class IteratorTest(test.TestCase): checkpoint.restore(save_path) self.assertEqual(2, iterator.get_next().numpy()) + def testRestoreInReconstructedIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + dataset = Dataset.range(10) + for i in range(5): + iterator = datasets.Iterator(dataset) + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + checkpoint.restore(checkpoint_management.latest_checkpoint( + checkpoint_directory)) + for j in range(2): + self.assertEqual(i * 2 + j, iterator.get_next().numpy()) + checkpoint.save(file_prefix=checkpoint_prefix) + class DatasetConstructorBenchmark(test.Benchmark): diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index 1d9371c7ac405dbf0ec40210270b90f2cf9b9a25..6f02c90368d966b8cf8d0dee09f9d2a5013c90c1 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -11,6 +11,8 @@ py_library( "//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets", "//tensorflow/contrib/eager/python/examples/linear_regression", "//tensorflow/contrib/eager/python/examples/resnet50", + "//tensorflow/contrib/eager/python/examples/revnet", + "//tensorflow/contrib/eager/python/examples/revnet:config", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", "//tensorflow/contrib/eager/python/examples/rnn_ptb", "//tensorflow/contrib/eager/python/examples/spinn:data", diff --git a/tensorflow/contrib/eager/python/examples/densenet/BUILD b/tensorflow/contrib/eager/python/examples/densenet/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..2dc196f550a10367066730f6f042c4ed69533ec3 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/densenet/BUILD @@ -0,0 +1,48 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +py_binary( + name = "densenet", + srcs = ["densenet.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + ], +) + +cuda_py_test( + name = "densenet_test", + size = "large", + srcs = ["densenet_test.py"], + additional_deps = [ + ":densenet", + "//tensorflow/contrib/eager/python:tfe", + "//tensorflow:tensorflow_py", + ], + tags = [ + "no_pip", + "optonly", + ], +) + +cuda_py_test( + name = "densenet_graph_test", + size = "large", + srcs = ["densenet_graph_test.py"], + additional_deps = [ + ":densenet", + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + ], + tags = [ + "no_pip", + "noasan", + "nomsan", + "notsan", + "optonly", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet.py b/tensorflow/contrib/eager/python/examples/densenet/densenet.py new file mode 100644 index 0000000000000000000000000000000000000000..6de4e6940094849b5cf6f977e351aef525c77cc2 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/densenet/densenet.py @@ -0,0 +1,296 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Densely Connected Convolutional Networks. + +Reference [ +Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993) + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +l2 = tf.keras.regularizers.l2 + + +class ConvBlock(tf.keras.Model): + """Convolutional Block consisting of (batchnorm->relu->conv). + + Arguments: + num_filters: number of filters passed to a convolutional layer. + data_format: "channels_first" or "channels_last" + bottleneck: if True, then a 1x1 Conv is performed followed by 3x3 Conv. + weight_decay: weight decay + dropout_rate: dropout rate. + """ + + def __init__(self, num_filters, data_format, bottleneck, weight_decay=1e-4, + dropout_rate=0): + super(ConvBlock, self).__init__() + self.bottleneck = bottleneck + + axis = -1 if data_format == "channels_last" else 1 + inter_filter = num_filters * 4 + # don't forget to set use_bias=False when using batchnorm + self.conv2 = tf.keras.layers.Conv2D(num_filters, + (3, 3), + padding="same", + use_bias=False, + data_format=data_format, + kernel_initializer="he_normal", + kernel_regularizer=l2(weight_decay)) + self.batchnorm1 = tf.keras.layers.BatchNormalization(axis=axis) + self.dropout = tf.keras.layers.Dropout(dropout_rate) + + if self.bottleneck: + self.conv1 = tf.keras.layers.Conv2D(inter_filter, + (1, 1), + padding="same", + use_bias=False, + data_format=data_format, + kernel_initializer="he_normal", + kernel_regularizer=l2(weight_decay)) + self.batchnorm2 = tf.keras.layers.BatchNormalization(axis=axis) + + def call(self, x, training=True): + output = self.batchnorm1(x, training=training) + + if self.bottleneck: + output = self.conv1(tf.nn.relu(output)) + output = self.batchnorm2(output, training=training) + + output = self.conv2(tf.nn.relu(output)) + output = self.dropout(output, training=training) + + return output + + +class TransitionBlock(tf.keras.Model): + """Transition Block to reduce the number of features. + + Arguments: + num_filters: number of filters passed to a convolutional layer. + data_format: "channels_first" or "channels_last" + weight_decay: weight decay + dropout_rate: dropout rate. + """ + + def __init__(self, num_filters, data_format, + weight_decay=1e-4, dropout_rate=0): + super(TransitionBlock, self).__init__() + axis = -1 if data_format == "channels_last" else 1 + + self.batchnorm = tf.keras.layers.BatchNormalization(axis=axis) + self.conv = tf.keras.layers.Conv2D(num_filters, + (1, 1), + padding="same", + use_bias=False, + data_format=data_format, + kernel_initializer="he_normal", + kernel_regularizer=l2(weight_decay)) + self.avg_pool = tf.keras.layers.AveragePooling2D(data_format=data_format) + + def call(self, x, training=True): + output = self.batchnorm(x, training=training) + output = self.conv(tf.nn.relu(output)) + output = self.avg_pool(output) + return output + + +class DenseBlock(tf.keras.Model): + """Dense Block consisting of ConvBlocks where each block's + output is concatenated with its input. + + Arguments: + num_layers: Number of layers in each block. + growth_rate: number of filters to add per conv block. + data_format: "channels_first" or "channels_last" + bottleneck: boolean, that decides which part of ConvBlock to call. + weight_decay: weight decay + dropout_rate: dropout rate. + """ + + def __init__(self, num_layers, growth_rate, data_format, bottleneck, + weight_decay=1e-4, dropout_rate=0): + super(DenseBlock, self).__init__() + self.num_layers = num_layers + self.axis = -1 if data_format == "channels_last" else 1 + + self.blocks = [] + for _ in range(int(self.num_layers)): + self.blocks.append(ConvBlock(growth_rate, + data_format, + bottleneck, + weight_decay, + dropout_rate)) + + def call(self, x, training=True): + for i in range(int(self.num_layers)): + output = self.blocks[i](x, training=training) + x = tf.concat([x, output], axis=self.axis) + + return x + + +class DenseNet(tf.keras.Model): + """Creating the Densenet Architecture. + + Arguments: + depth_of_model: number of layers in the model. + growth_rate: number of filters to add per conv block. + num_of_blocks: number of dense blocks. + output_classes: number of output classes. + num_layers_in_each_block: number of layers in each block. + If -1, then we calculate this by (depth-3)/4. + If positive integer, then the it is used as the + number of layers per block. + If list or tuple, then this list is used directly. + data_format: "channels_first" or "channels_last" + bottleneck: boolean, to decide which part of conv block to call. + compression: reducing the number of inputs(filters) to the transition block. + weight_decay: weight decay + rate: dropout rate. + pool_initial: If True add a 7x7 conv with stride 2 followed by 3x3 maxpool + else, do a 3x3 conv with stride 1. + include_top: If true, GlobalAveragePooling Layer and Dense layer are + included. + """ + + def __init__(self, depth_of_model, growth_rate, num_of_blocks, + output_classes, num_layers_in_each_block, data_format, + bottleneck=True, compression=0.5, weight_decay=1e-4, + dropout_rate=0, pool_initial=False, include_top=True): + super(DenseNet, self).__init__() + self.depth_of_model = depth_of_model + self.growth_rate = growth_rate + self.num_of_blocks = num_of_blocks + self.output_classes = output_classes + self.num_layers_in_each_block = num_layers_in_each_block + self.data_format = data_format + self.bottleneck = bottleneck + self.compression = compression + self.weight_decay = weight_decay + self.dropout_rate = dropout_rate + self.pool_initial = pool_initial + self.include_top = include_top + + # deciding on number of layers in each block + if isinstance(self.num_layers_in_each_block, list) or isinstance( + self.num_layers_in_each_block, tuple): + self.num_layers_in_each_block = list(self.num_layers_in_each_block) + else: + if self.num_layers_in_each_block == -1: + if self.num_of_blocks != 3: + raise ValueError( + "Number of blocks must be 3 if num_layers_in_each_block is -1") + if (self.depth_of_model - 4) % 3 == 0: + num_layers = (self.depth_of_model - 4) / 3 + if self.bottleneck: + num_layers //= 2 + self.num_layers_in_each_block = [num_layers] * self.num_of_blocks + else: + raise ValueError("Depth must be 3N+4 if num_layer_in_each_block=-1") + else: + self.num_layers_in_each_block = [ + self.num_layers_in_each_block] * self.num_of_blocks + + axis = -1 if self.data_format == "channels_last" else 1 + + # setting the filters and stride of the initial covn layer. + if self.pool_initial: + init_filters = (7, 7) + stride = (2, 2) + else: + init_filters = (3, 3) + stride = (1, 1) + + self.num_filters = 2 * self.growth_rate + + # first conv and pool layer + self.conv1 = tf.keras.layers.Conv2D(self.num_filters, + init_filters, + strides=stride, + padding="same", + use_bias=False, + data_format=self.data_format, + kernel_initializer="he_normal", + kernel_regularizer=l2( + self.weight_decay)) + if self.pool_initial: + self.pool1 = tf.keras.layers.MaxPooling2D(pool_size=(3, 3), + strides=(2, 2), + padding="same", + data_format=self.data_format) + self.batchnorm1 = tf.keras.layers.BatchNormalization(axis=axis) + + self.batchnorm2 = tf.keras.layers.BatchNormalization(axis=axis) + + # last pooling and fc layer + if self.include_top: + self.last_pool = tf.keras.layers.GlobalAveragePooling2D( + data_format=self.data_format) + self.classifier = tf.keras.layers.Dense(self.output_classes) + + # calculating the number of filters after each block + num_filters_after_each_block = [self.num_filters] + for i in range(1, self.num_of_blocks): + temp_num_filters = num_filters_after_each_block[i-1] + ( + self.growth_rate * self.num_layers_in_each_block[i-1]) + # using compression to reduce the number of inputs to the + # transition block + temp_num_filters = int(temp_num_filters * compression) + num_filters_after_each_block.append(temp_num_filters) + + # dense block initialization + self.dense_blocks = [] + self.transition_blocks = [] + for i in range(self.num_of_blocks): + self.dense_blocks.append(DenseBlock(self.num_layers_in_each_block[i], + self.growth_rate, + self.data_format, + self.bottleneck, + self.weight_decay, + self.dropout_rate)) + if i+1 < self.num_of_blocks: + self.transition_blocks.append( + TransitionBlock(num_filters_after_each_block[i+1], + self.data_format, + self.weight_decay, + self.dropout_rate)) + + def call(self, x, training=True): + output = self.conv1(x) + + if self.pool_initial: + output = self.batchnorm1(output, training=training) + output = tf.nn.relu(output) + output = self.pool1(output) + + for i in range(self.num_of_blocks - 1): + output = self.dense_blocks[i](output, training=training) + output = self.transition_blocks[i](output, training=training) + + output = self.dense_blocks[ + self.num_of_blocks - 1](output, training=training) + output = self.batchnorm2(output, training=training) + output = tf.nn.relu(output) + + if self.include_top: + output = self.last_pool(output) + output = self.classifier(output) + + return output diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4b3cb624bc947a1d1956eff6accb6d4da3bf3b87 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py @@ -0,0 +1,151 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests and Benchmarks for Densenet model under graph execution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.eager.python.examples.densenet import densenet + + +def data_format(): + return 'channels_first' if tf.test.is_gpu_available() else 'channels_last' + + +def image_shape(batch_size): + if data_format() == 'channels_first': + return [batch_size, 3, 224, 224] + return [batch_size, 224, 224, 3] + + +def random_batch(batch_size): + images = np.random.rand(*image_shape(batch_size)).astype(np.float32) + num_classes = 1000 + labels = np.random.randint( + low=0, high=num_classes, size=[batch_size]).astype(np.int32) + one_hot = np.zeros((batch_size, num_classes)).astype(np.float32) + one_hot[np.arange(batch_size), labels] = 1. + return images, one_hot + + +class DensenetGraphTest(tf.test.TestCase): + + def testApply(self): + depth = 7 + growth_rate = 2 + num_blocks = 3 + output_classes = 10 + num_layers_in_each_block = -1 + batch_size = 1 + with tf.Graph().as_default(): + images = tf.placeholder(tf.float32, image_shape(None)) + model = densenet.DenseNet(depth, growth_rate, num_blocks, + output_classes, num_layers_in_each_block, + data_format(), bottleneck=True, compression=0.5, + weight_decay=1e-4, dropout_rate=0, + pool_initial=False, include_top=True) + predictions = model(images, training=False) + + init = tf.global_variables_initializer() + + with tf.Session() as sess: + sess.run(init) + np_images, _ = random_batch(batch_size) + out = sess.run(predictions, feed_dict={images: np_images}) + self.assertAllEqual([batch_size, output_classes], out.shape) + + +class DensenetBenchmark(tf.test.Benchmark): + + def __init__(self): + self.depth = 121 + self.growth_rate = 32 + self.num_blocks = 4 + self.output_classes = 1000 + self.num_layers_in_each_block = [6, 12, 24, 16] + + def _report(self, label, start, num_iters, batch_size): + avg_time = (time.time() - start) / num_iters + dev = 'gpu' if tf.test.is_gpu_available() else 'cpu' + name = 'graph_%s_%s_batch_%d_%s' % (label, dev, batch_size, data_format()) + extras = {'examples_per_sec': batch_size / avg_time} + self.report_benchmark( + iters=num_iters, wall_time=avg_time, name=name, extras=extras) + + def benchmark_graph_apply(self): + with tf.Graph().as_default(): + images = tf.placeholder(tf.float32, image_shape(None)) + model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks, + self.output_classes, + self.num_layers_in_each_block, data_format(), + bottleneck=True, compression=0.5, + weight_decay=1e-4, dropout_rate=0, + pool_initial=True, include_top=True) + predictions = model(images, training=False) + + init = tf.global_variables_initializer() + + batch_size = 64 + with tf.Session() as sess: + sess.run(init) + np_images, _ = random_batch(batch_size) + num_burn, num_iters = (3, 30) + for _ in range(num_burn): + sess.run(predictions, feed_dict={images: np_images}) + start = time.time() + for _ in range(num_iters): + sess.run(predictions, feed_dict={images: np_images}) + self._report('apply', start, num_iters, batch_size) + + def benchmark_graph_train(self): + for batch_size in [16, 32, 64]: + with tf.Graph().as_default(): + np_images, np_labels = random_batch(batch_size) + dataset = tf.data.Dataset.from_tensors((np_images, np_labels)).repeat() + (images, labels) = dataset.make_one_shot_iterator().get_next() + + model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks, + self.output_classes, + self.num_layers_in_each_block, data_format(), + bottleneck=True, compression=0.5, + weight_decay=1e-4, dropout_rate=0, + pool_initial=True, include_top=True) + logits = model(images, training=True) + cross_ent = tf.losses.softmax_cross_entropy( + logits=logits, onehot_labels=labels) + regularization = tf.add_n(model.losses) + loss = cross_ent + regularization + optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) + train_op = optimizer.minimize(loss) + + init = tf.global_variables_initializer() + with tf.Session() as sess: + sess.run(init) + (num_burn, num_iters) = (5, 10) + for _ in range(num_burn): + sess.run(train_op) + start = time.time() + for _ in range(num_iters): + sess.run(train_op) + self._report('train', start, num_iters, batch_size) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0736ed02b7437240e5da4dd529ad9ba9a5a15042 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py @@ -0,0 +1,350 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests and Benchmarks for Densenet model.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gc +import time +import tensorflow as tf +import tensorflow.contrib.eager as tfe + +from tensorflow.contrib.eager.python.examples.densenet import densenet +from tensorflow.python.client import device_lib + + +class DensenetTest(tf.test.TestCase): + + def test_bottleneck_true(self): + depth = 7 + growth_rate = 2 + num_blocks = 3 + output_classes = 10 + num_layers_in_each_block = -1 + batch_size = 1 + data_format = ('channels_first') if tf.test.is_gpu_available() else ( + 'channels_last') + + model = densenet.DenseNet(depth, growth_rate, num_blocks, + output_classes, num_layers_in_each_block, + data_format, bottleneck=True, compression=0.5, + weight_decay=1e-4, dropout_rate=0, + pool_initial=False, include_top=True) + + if data_format == 'channels_last': + rand_input = tf.random_uniform((batch_size, 32, 32, 3)) + else: + rand_input = tf.random_uniform((batch_size, 3, 32, 32)) + output_shape = model(rand_input).shape + self.assertEqual(output_shape, (batch_size, output_classes)) + + def test_bottleneck_false(self): + depth = 7 + growth_rate = 2 + num_blocks = 3 + output_classes = 10 + num_layers_in_each_block = -1 + batch_size = 1 + data_format = ('channels_first') if tf.test.is_gpu_available() else ( + 'channels_last') + + model = densenet.DenseNet(depth, growth_rate, num_blocks, + output_classes, num_layers_in_each_block, + data_format, bottleneck=False, compression=0.5, + weight_decay=1e-4, dropout_rate=0, + pool_initial=False, include_top=True) + + if data_format == 'channels_last': + rand_input = tf.random_uniform((batch_size, 32, 32, 3)) + else: + rand_input = tf.random_uniform((batch_size, 3, 32, 32)) + output_shape = model(rand_input).shape + self.assertEqual(output_shape, (batch_size, output_classes)) + + def test_pool_initial_true(self): + depth = 7 + growth_rate = 2 + num_blocks = 4 + output_classes = 10 + num_layers_in_each_block = [1, 2, 2, 1] + batch_size = 1 + data_format = ('channels_first') if tf.test.is_gpu_available() else ( + 'channels_last') + + model = densenet.DenseNet(depth, growth_rate, num_blocks, + output_classes, num_layers_in_each_block, + data_format, bottleneck=True, compression=0.5, + weight_decay=1e-4, dropout_rate=0, + pool_initial=True, include_top=True) + + if data_format == 'channels_last': + rand_input = tf.random_uniform((batch_size, 32, 32, 3)) + else: + rand_input = tf.random_uniform((batch_size, 3, 32, 32)) + output_shape = model(rand_input).shape + self.assertEqual(output_shape, (batch_size, output_classes)) + + def test_regularization(self): + if tf.test.is_gpu_available(): + rand_input = tf.random_uniform((10, 3, 32, 32)) + data_format = 'channels_first' + else: + rand_input = tf.random_uniform((10, 32, 32, 3)) + data_format = 'channels_last' + weight_decay = 1e-4 + + conv = tf.keras.layers.Conv2D( + 3, (3, 3), + padding='same', + use_bias=False, + data_format=data_format, + kernel_regularizer=tf.keras.regularizers.l2(weight_decay)) + optimizer = tf.train.GradientDescentOptimizer(0.1) + conv(rand_input) # Initialize the variables in the layer + + def compute_true_l2(vs, wd): + return tf.reduce_sum(tf.square(vs)) * wd + + true_l2 = compute_true_l2(conv.variables, weight_decay) + keras_l2 = tf.add_n(conv.losses) + self.assertAllClose(true_l2, keras_l2) + + with tf.GradientTape() as tape_true, tf.GradientTape() as tape_keras: + loss = tf.reduce_sum(conv(rand_input)) + loss_with_true_l2 = loss + compute_true_l2(conv.variables, weight_decay) + loss_with_keras_l2 = loss + tf.add_n(conv.losses) + + true_grads = tape_true.gradient(loss_with_true_l2, conv.variables) + keras_grads = tape_keras.gradient(loss_with_keras_l2, conv.variables) + self.assertAllClose(true_grads, keras_grads) + + optimizer.apply_gradients(zip(keras_grads, conv.variables)) + keras_l2_after_update = tf.add_n(conv.losses) + self.assertNotAllClose(keras_l2, keras_l2_after_update) + + +def compute_gradients(model, images, labels): + with tf.GradientTape() as tape: + logits = model(images, training=True) + cross_ent = tf.losses.softmax_cross_entropy( + logits=logits, onehot_labels=labels) + regularization = tf.add_n(model.losses) + loss = cross_ent + regularization + tf.contrib.summary.scalar(name='loss', tensor=loss) + return tape.gradient(loss, model.variables) + + +def apply_gradients(model, optimizer, gradients): + optimizer.apply_gradients(zip(gradients, model.variables)) + + +def device_and_data_format(): + return ('/gpu:0', + 'channels_first') if tf.test.is_gpu_available() else ('/cpu:0', + 'channels_last') + + +def random_batch(batch_size, data_format): + shape = (3, 224, 224) if data_format == 'channels_first' else (224, 224, 3) + shape = (batch_size,) + shape + + num_classes = 1000 + images = tf.random_uniform(shape) + labels = tf.random_uniform( + [batch_size], minval=0, maxval=num_classes, dtype=tf.int32) + one_hot = tf.one_hot(labels, num_classes) + + return images, one_hot + + +class MockIterator(object): + + def __init__(self, tensors): + self._tensors = [tf.identity(x) for x in tensors] + + def next(self): + return self._tensors + + +class DensenetBenchmark(tf.test.Benchmark): + + def __init__(self): + self.depth = 121 + self.growth_rate = 32 + self.num_blocks = 4 + self.output_classes = 1000 + self.num_layers_in_each_block = [6, 12, 24, 16] + + def _train_batch_sizes(self): + """Choose batch sizes based on GPU capability.""" + for device in device_lib.list_local_devices(): + if tf.DeviceSpec.from_string(device.name).device_type == 'GPU': + if 'K20' in device.physical_device_desc: + return (16,) + if 'P100' in device.physical_device_desc: + return (16, 32, 64) + + if tf.DeviceSpec.from_string(device.name).device_type == 'TPU': + return (32,) + return (16, 32) + + def _report(self, label, start, num_iters, device, batch_size, data_format): + avg_time = (time.time() - start) / num_iters + dev = tf.DeviceSpec.from_string(device).device_type.lower() + name = '%s_%s_batch_%d_%s' % (label, dev, batch_size, data_format) + extras = {'examples_per_sec': batch_size / avg_time} + self.report_benchmark( + iters=num_iters, wall_time=avg_time, name=name, extras=extras) + + def _force_device_sync(self): + # If this function is called in the context of a non-CPU device + # (e.g., inside a 'with tf.device("/gpu:0")' block) + # then this will force a copy from CPU->NON_CPU_DEVICE->CPU, + # which forces a sync. This is a roundabout way, yes. + tf.constant(1.).cpu() + + def _benchmark_eager_apply(self, label, device_and_format, defun=False, + execution_mode=None, compiled=False): + with tfe.execution_mode(execution_mode): + device, data_format = device_and_format + model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks, + self.output_classes, + self.num_layers_in_each_block, data_format, + bottleneck=True, compression=0.5, + weight_decay=1e-4, dropout_rate=0, + pool_initial=True, include_top=True) + if defun: + model.call = tfe.defun(model.call, compiled=compiled) + batch_size = 64 + num_burn = 5 + num_iters = 30 + with tf.device(device): + images, _ = random_batch(batch_size, data_format) + for _ in xrange(num_burn): + model(images, training=False).cpu() + if execution_mode: + tfe.async_wait() + gc.collect() + start = time.time() + for _ in xrange(num_iters): + model(images, training=False).cpu() + if execution_mode: + tfe.async_wait() + self._report(label, start, num_iters, device, batch_size, data_format) + + def benchmark_eager_apply_sync(self): + self._benchmark_eager_apply('eager_apply', device_and_data_format(), + defun=False) + + def benchmark_eager_apply_async(self): + self._benchmark_eager_apply( + 'eager_apply_async', device_and_data_format(), defun=False, + execution_mode=tfe.ASYNC) + + def benchmark_eager_apply_with_defun(self): + self._benchmark_eager_apply('eager_apply_with_defun', + device_and_data_format(), defun=True) + + def _benchmark_eager_train(self, + label, + make_iterator, + device_and_format, + defun=False, + execution_mode=None, + compiled=False): + with tfe.execution_mode(execution_mode): + device, data_format = device_and_format + for batch_size in self._train_batch_sizes(): + (images, labels) = random_batch(batch_size, data_format) + model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks, + self.output_classes, + self.num_layers_in_each_block, data_format, + bottleneck=True, compression=0.5, + weight_decay=1e-4, dropout_rate=0, + pool_initial=True, include_top=True) + optimizer = tf.train.GradientDescentOptimizer(0.1) + apply_grads = apply_gradients + if defun: + model.call = tfe.defun(model.call, compiled=compiled) + apply_grads = tfe.defun(apply_gradients, compiled=compiled) + + num_burn = 3 + num_iters = 10 + with tf.device(device): + iterator = make_iterator((images, labels)) + for _ in xrange(num_burn): + (images, labels) = iterator.next() + apply_grads(model, optimizer, + compute_gradients(model, images, labels)) + if execution_mode: + tfe.async_wait() + self._force_device_sync() + gc.collect() + + start = time.time() + for _ in xrange(num_iters): + (images, labels) = iterator.next() + apply_grads(model, optimizer, + compute_gradients(model, images, labels)) + if execution_mode: + tfe.async_wait() + self._force_device_sync() + self._report(label, start, num_iters, device, batch_size, data_format) + + def benchmark_eager_train_sync(self): + self._benchmark_eager_train('eager_train', MockIterator, + device_and_data_format(), defun=False) + + def benchmark_eager_train_async(self): + self._benchmark_eager_train( + 'eager_train_async', + MockIterator, + device_and_data_format(), + defun=False, + execution_mode=tfe.ASYNC) + + def benchmark_eager_train_with_defun(self): + self._benchmark_eager_train( + 'eager_train_with_defun', MockIterator, + device_and_data_format(), defun=True) + + def benchmark_eager_train_datasets(self): + + def make_iterator(tensors): + with tf.device('/device:CPU:0'): + ds = tf.data.Dataset.from_tensors(tensors).repeat() + return tfe.Iterator(ds) + + self._benchmark_eager_train( + 'eager_train_dataset', make_iterator, + device_and_data_format(), defun=False) + + def benchmark_eager_train_datasets_with_defun(self): + + def make_iterator(tensors): + with tf.device('/device:CPU:0'): + ds = tf.data.Dataset.from_tensors(tensors).repeat() + return tfe.Iterator(ds) + + self._benchmark_eager_train( + 'eager_train_dataset_with_defun', make_iterator, + device_and_data_format(), defun=True) + + +if __name__ == '__main__': + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py index cc9cf53410f641cc3303b4450e9eaa1301904a64..9a4217929916c258b7e8f2e5b3add2905d20d1da 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py @@ -29,7 +29,6 @@ import time import tensorflow as tf -import tensorflow.contrib.eager as tfe from tensorflow.examples.tutorials.mnist import input_data layers = tf.keras.layers @@ -214,7 +213,7 @@ def train_one_epoch(generator, discriminator, generator_optimizer, total_generator_loss = 0.0 total_discriminator_loss = 0.0 - for (batch_index, images) in enumerate(tfe.Iterator(dataset)): + for (batch_index, images) in enumerate(dataset): with tf.device('/cpu:0'): tf.assign_add(step_counter, 1) @@ -227,7 +226,10 @@ def train_one_epoch(generator, discriminator, generator_optimizer, maxval=1., seed=batch_index) - with tf.GradientTape(persistent=True) as g: + # we can use 2 tapes or a single persistent tape. + # Using two tapes is memory efficient since intermediate tensors can be + # released between the two .gradient() calls below + with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: generated_images = generator(noise) tf.contrib.summary.image( 'generated_images', @@ -243,9 +245,10 @@ def train_one_epoch(generator, discriminator, generator_optimizer, generator_loss_val = generator_loss(discriminator_gen_outputs) total_generator_loss += generator_loss_val - generator_grad = g.gradient(generator_loss_val, generator.variables) - discriminator_grad = g.gradient(discriminator_loss_val, - discriminator.variables) + generator_grad = gen_tape.gradient(generator_loss_val, + generator.variables) + discriminator_grad = disc_tape.gradient(discriminator_loss_val, + discriminator.variables) generator_optimizer.apply_gradients( zip(generator_grad, generator.variables)) @@ -261,7 +264,7 @@ def train_one_epoch(generator, discriminator, generator_optimizer, def main(_): (device, data_format) = ('/gpu:0', 'channels_first') - if FLAGS.no_gpu or tfe.num_gpus() <= 0: + if FLAGS.no_gpu or tf.contrib.eager.num_gpus() <= 0: (device, data_format) = ('/cpu:0', 'channels_last') print('Using device %s, and data format %s.' % (device, data_format)) @@ -287,7 +290,7 @@ def main(_): latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) if latest_cpkt: print('Using latest checkpoint at ' + latest_cpkt) - checkpoint = tfe.Checkpoint(**model_objects) + checkpoint = tf.train.Checkpoint(**model_objects) # Restore variables on creation if a checkpoint exists. checkpoint.restore(latest_cpkt) diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..ca27a85a229d41a85fa26ecdc982da478fe9e202 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb @@ -0,0 +1,649 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0TD5ZrvEMbhZ" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n", + "\n", + "# Convolutional VAE: An example with tf.keras and eager\n", + "\n", + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ITZuApL56Mny" + }, + "source": [ + "![evolution of output during training](https://tensorflow.org/images/autoencoders/cvae.gif)\n", + "\n", + "This notebook demonstrates how to generate images of handwritten digits using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager) by training a Variational Autoencoder. (VAE, [[1]](https://arxiv.org/abs/1312.6114), [[2]](https://arxiv.org/abs/1401.4082)).\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "P-JuIu2N_SQf" + }, + "outputs": [], + "source": [ + "# to generate gifs\n", + "!pip install imageio" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "e1_Y75QXJS6h" + }, + "source": [ + "## Import TensorFlow and enable Eager execution" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "YfIk2es3hJEd" + }, + "outputs": [], + "source": [ + "from __future__ import absolute_import, division, print_function\n", + "\n", + "# Import TensorFlow \u003e= 1.9 and enable eager execution\n", + "import tensorflow as tf\n", + "tfe = tf.contrib.eager\n", + "tf.enable_eager_execution()\n", + "\n", + "import os\n", + "import time\n", + "import numpy as np\n", + "import glob\n", + "import matplotlib.pyplot as plt\n", + "import PIL\n", + "import imageio\n", + "from IPython import display" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "iYn4MdZnKCey" + }, + "source": [ + "## Load the MNIST dataset\n", + "Each MNIST image is originally a vector of 784 integers, each of which is between 0-255 and represents the intensity of a pixel. We model each pixel with a Bernoulli distribution in our model, and we statically binarize the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "a4fYMGxGhrna" + }, + "outputs": [], + "source": [ + "(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "NFC2ghIdiZYE" + }, + "outputs": [], + "source": [ + "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", + "test_images = test_images.reshape(test_images.shape[0], 28, 28, 1).astype('float32')\n", + "\n", + "# Normalizing the images to the range of [0., 1.]\n", + "train_images /= 255.\n", + "test_images /= 255.\n", + "\n", + "# Binarization\n", + "train_images[train_images \u003e= .5] = 1.\n", + "train_images[train_images \u003c .5] = 0.\n", + "test_images[test_images \u003e= .5] = 1.\n", + "test_images[test_images \u003c .5] = 0." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "S4PIDhoDLbsZ" + }, + "outputs": [], + "source": [ + "TRAIN_BUF = 60000\n", + "BATCH_SIZE = 100\n", + "\n", + "TEST_BUF = 10000" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PIGN6ouoQxt3" + }, + "source": [ + "## Use *tf.data* to create batches and shuffle the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "-yKCCQOoJ7cn" + }, + "outputs": [], + "source": [ + "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(TRAIN_BUF).batch(BATCH_SIZE)\n", + "test_dataset = tf.data.Dataset.from_tensor_slices(test_images).shuffle(TEST_BUF).batch(BATCH_SIZE)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "THY-sZMiQ4UV" + }, + "source": [ + "## Wire up the generative and inference network with *tf.keras.Sequential*\n", + "\n", + "In our VAE example, we use two small ConvNets for the generative and inference network. Since these neural nets are small, we use `tf.keras.Sequential` to simplify our code. Let $x$ and $z$ denote the observation and latent variable respectively in the following descriptions. \n", + "\n", + "### Generative Network\n", + "This defines the generative model which takes a latent encoding as input, and outputs the parameters for a conditional distribution of the observation, i.e. $p(x|z)$. Additionally, we use a unit Gaussian prior $p(z)$ for the latent variable.\n", + "\n", + "### Inference Network\n", + "This defines an approximate posterior distribution $q(z|x)$, which takes as input an observation and outputs a set of parameters for the conditional distribution of the latent representation. In this example, we simply model this distribution as a diagonal Gaussian. In this case, the inference network outputs the mean and log-variance parameters of a factorized Gaussian (log-variance instead of the variance directly is for numerical stability).\n", + "\n", + "### Reparameterization Trick\n", + "During optimization, we can sample from $q(z|x)$ by first sampling from a unit Gaussian, and then multiplying by the standard deviation and adding the mean. This ensures the gradients could pass through the sample to the inference network parameters.\n", + "\n", + "### Network architecture\n", + "For the inference network, we use two convolutional layers followed by a fully-connected layer. In the generative network, we mirror this architecture by using a fully-connected layer followed by three convolution transpose layers (a.k.a. deconvolutional layers in some contexts). Note, it's common practice to avoid using batch normalization when training VAEs, since the additional stochasticity due to using mini-batches may aggravate instability on top of the stochasticity from sampling." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "VGLbvBEmjK0a" + }, + "outputs": [], + "source": [ + "class CVAE(tf.keras.Model):\n", + " def __init__(self, latent_dim):\n", + " super(CVAE, self).__init__()\n", + " self.latent_dim = latent_dim\n", + " self.inference_net = tf.keras.Sequential(\n", + " [\n", + " tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),\n", + " tf.keras.layers.Conv2D(\n", + " filters=32, kernel_size=3, strides=(2, 2), activation=tf.nn.relu),\n", + " tf.keras.layers.Conv2D(\n", + " filters=64, kernel_size=3, strides=(2, 2), activation=tf.nn.relu),\n", + " tf.keras.layers.Flatten(),\n", + " # No activation\n", + " tf.keras.layers.Dense(latent_dim + latent_dim),\n", + " ]\n", + " )\n", + "\n", + " self.generative_net = tf.keras.Sequential(\n", + " [\n", + " tf.keras.layers.InputLayer(input_shape=(latent_dim,)),\n", + " tf.keras.layers.Dense(units=7*7*32, activation=tf.nn.relu),\n", + " tf.keras.layers.Reshape(target_shape=(7, 7, 32)),\n", + " tf.keras.layers.Conv2DTranspose(\n", + " filters=64,\n", + " kernel_size=3,\n", + " strides=(2, 2),\n", + " padding=\"SAME\",\n", + " activation=tf.nn.relu),\n", + " tf.keras.layers.Conv2DTranspose(\n", + " filters=32,\n", + " kernel_size=3,\n", + " strides=(2, 2),\n", + " padding=\"SAME\",\n", + " activation=tf.nn.relu),\n", + " # No activation\n", + " tf.keras.layers.Conv2DTranspose(\n", + " filters=1, kernel_size=3, strides=(1, 1), padding=\"SAME\"),\n", + " ]\n", + " )\n", + "\n", + " def sample(self, eps=None):\n", + " if eps is None:\n", + " eps = tf.random_normal(shape=(100, self.latent_dim))\n", + " return self.decode(eps, apply_sigmoid=True)\n", + "\n", + " def encode(self, x):\n", + " mean, logvar = tf.split(self.inference_net(x), num_or_size_splits=2, axis=1)\n", + " return mean, logvar\n", + "\n", + " def reparameterize(self, mean, logvar):\n", + " eps = tf.random_normal(shape=mean.shape)\n", + " return eps * tf.exp(logvar * .5) + mean\n", + "\n", + " def decode(self, z, apply_sigmoid=False):\n", + " logits = self.generative_net(z)\n", + " if apply_sigmoid:\n", + " probs = tf.sigmoid(logits)\n", + " return probs\n", + "\n", + " return logits" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0FMYgY_mPfTi" + }, + "source": [ + "## Define the loss function and the optimizer\n", + "\n", + "VAEs train by maximizing the evidence lower bound (ELBO) on the marginal log-likelihood:\n", + "\n", + "$$\\log p(x) \\ge \\text{ELBO} = \\mathbb{E}_{q(z|x)}\\left[\\log \\frac{p(x, z)}{q(z|x)}\\right].$$\n", + "\n", + "In practice, we optimize the single sample Monte Carlo estimate of this expectation:\n", + "\n", + "$$\\log p(x| z) + \\log p(z) - \\log q(z|x),$$\n", + "where $z$ is sampled from $q(z|x)$.\n", + "\n", + "**Note**: we could also analytically compute the KL term, but here we incorporate all three terms in the Monte Carlo estimator for simplicity." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "iWCn_PVdEJZ7" + }, + "outputs": [], + "source": [ + "def log_normal_pdf(sample, mean, logvar, raxis=1):\n", + " log2pi = tf.log(2. * np.pi)\n", + " return tf.reduce_sum(\n", + " -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),\n", + " axis=raxis)\n", + "\n", + "def compute_loss(model, x):\n", + " mean, logvar = model.encode(x)\n", + " z = model.reparameterize(mean, logvar)\n", + " x_logit = model.decode(z)\n", + "\n", + " cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)\n", + " logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])\n", + " logpz = log_normal_pdf(z, 0., 0.)\n", + " logqz_x = log_normal_pdf(z, mean, logvar)\n", + " return -tf.reduce_mean(logpx_z + logpz - logqz_x)\n", + "\n", + "def compute_gradients(model, x):\n", + " with tf.GradientTape() as tape:\n", + " loss = compute_loss(model, x)\n", + " return tape.gradient(loss, model.trainable_variables), loss\n", + "\n", + "optimizer = tf.train.AdamOptimizer(1e-4)\n", + "def apply_gradients(optimizer, gradients, variables, global_step=None):\n", + " optimizer.apply_gradients(zip(gradients, variables), global_step=global_step)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Rw1fkAczTQYh" + }, + "source": [ + "## Training\n", + "\n", + "* We start by iterating over the dataset\n", + "* During each iteration, we pass the image to the encoder to obtain a set of mean and log-variance parameters of the approximate posterior $q(z|x)$\n", + "* We then apply the *reparameterization trick* to sample from $q(z|x)$\n", + "* Finally, we pass the reparameterized samples to the decoder to obtain the logits of the generative distribution $p(x|z)$\n", + "* **Note:** Since we use the dataset loaded by keras with 60k datapoints in the training set and 10k datapoints in the test set, our resulting ELBO on the test set is slightly higher than reported results in the literature which uses dynamic binarization of Larochelle's MNIST.\n", + "\n", + "## Generate Images\n", + "\n", + "* After training, it is time to generate some images\n", + "* We start by sampling a set of latent vectors from the unit Gaussian prior distribution $p(z)$\n", + "* The generator will then convert the latent sample $z$ to logits of the observation, giving a distribution $p(x|z)$\n", + "* Here we plot the probabilities of Bernoulli distributions\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "NS2GWywBbAWo" + }, + "outputs": [], + "source": [ + "epochs = 100\n", + "latent_dim = 50\n", + "num_examples_to_generate = 16\n", + "\n", + "# keeping the random vector constant for generation (prediction) so\n", + "# it will be easier to see the improvement.\n", + "random_vector_for_generation = tf.random_normal(\n", + " shape=[num_examples_to_generate, latent_dim])\n", + "model = CVAE(latent_dim)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "RmdVsmvhPxyy" + }, + "outputs": [], + "source": [ + "def generate_and_save_images(model, epoch, test_input):\n", + " predictions = model.sample(test_input)\n", + " fig = plt.figure(figsize=(4,4))\n", + "\n", + " for i in range(predictions.shape[0]):\n", + " plt.subplot(4, 4, i+1)\n", + " plt.imshow(predictions[i, :, :, 0], cmap='gray')\n", + " plt.axis('off')\n", + "\n", + " # tight_layout minimizes the overlap between 2 sub-plots\n", + " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "2M7LmLtGEMQJ" + }, + "outputs": [], + "source": [ + "generate_and_save_images(model, 0, random_vector_for_generation)\n", + "\n", + "for epoch in range(1, epochs + 1):\n", + " start_time = time.time()\n", + " for train_x in train_dataset:\n", + " gradients, loss = compute_gradients(model, train_x)\n", + " apply_gradients(optimizer, gradients, model.trainable_variables)\n", + " end_time = time.time()\n", + "\n", + " if epoch % 1 == 0:\n", + " loss = tfe.metrics.Mean()\n", + " for test_x in test_dataset.make_one_shot_iterator():\n", + " loss(compute_loss(model, test_x))\n", + " elbo = -loss.result()\n", + " display.clear_output(wait=False)\n", + " print('Epoch: {}, Test set ELBO: {}, '\n", + " 'time elapse for current epoch {}'.format(epoch,\n", + " elbo,\n", + " end_time - start_time))\n", + " generate_and_save_images(\n", + " model, epoch, random_vector_for_generation)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "P4M_vIbUi7c0" + }, + "source": [ + "### Display an image using the epoch number" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "WfO5wCdclHGL" + }, + "outputs": [], + "source": [ + "def display_image(epoch_no):\n", + " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "5x3q9_Oe5q0A" + }, + "outputs": [], + "source": [ + "display_image(epochs) # Display images" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "NywiH3nL8guF" + }, + "source": [ + "### Generate a GIF of all the saved images." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "IGKQgENQ8lEI" + }, + "outputs": [], + "source": [ + "with imageio.get_writer('cvae.gif', mode='I') as writer:\n", + " filenames = glob.glob('image*.png')\n", + " filenames = sorted(filenames)\n", + " last = -1\n", + " for i,filename in enumerate(filenames):\n", + " frame = 2*(i**0.5)\n", + " if round(frame) \u003e round(last):\n", + " last = frame\n", + " else:\n", + " continue\n", + " image = imageio.imread(filename)\n", + " writer.append_data(image)\n", + " image = imageio.imread(filename)\n", + " writer.append_data(image)\n", + " \n", + "# this is a hack to display the gif inside the notebook\n", + "os.system('cp cvae.gif cvae.gif.png')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "uV0yiKpzNP1b" + }, + "outputs": [], + "source": [ + "display.Image(filename=\"cvae.gif.png\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "yQXO_dlXkKsT" + }, + "source": [ + "To downlod the animation from Colab uncomment the code below:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "4fSJS3m5HLFM" + }, + "outputs": [], + "source": [ + "#from google.colab import files\n", + "#files.download('cvae.gif')" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "cvae.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1eb0NOTQapkYs3X0v-zL1x5_LFKgDISnp", + "timestamp": 1527173385672 + } + ], + "toc_visible": true, + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..5621d6a358e8969ea1a6663c1c770987de41ce0c --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb @@ -0,0 +1,694 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0TD5ZrvEMbhZ" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n", + "\n", + "# DCGAN: An example with tf.keras and eager\n", + "\n", + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ITZuApL56Mny" + }, + "source": [ + "This notebook demonstrates how to generate images of handwritten digits using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). To do so, we use Deep Convolutional Generative Adverserial Networks ([DCGAN](https://arxiv.org/pdf/1511.06434.pdf)).\n", + "\n", + "This model takes about ~30 seconds per epoch (using tf.contrib.eager.defun to create graph functions) to train on a single Tesla K80 on Colab, as of July 2018.\n", + "\n", + "Below is the output generated after training the generator and discriminator models for 150 epochs.\n", + "\n", + "![sample output](https://tensorflow.org/images/gan/dcgan.gif)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "u_2z-B3piVsw" + }, + "outputs": [], + "source": [ + "# to generate gifs\n", + "!pip install imageio" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "e1_Y75QXJS6h" + }, + "source": [ + "## Import TensorFlow and enable eager execution" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "YfIk2es3hJEd" + }, + "outputs": [], + "source": [ + "from __future__ import absolute_import, division, print_function\n", + "\n", + "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "\n", + "import os\n", + "import time\n", + "import numpy as np\n", + "import glob\n", + "import matplotlib.pyplot as plt\n", + "import PIL\n", + "import imageio\n", + "from IPython import display" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "iYn4MdZnKCey" + }, + "source": [ + "## Load the dataset\n", + "\n", + "We are going to use the MNIST dataset to train the generator and the discriminator. The generator will then generate handwritten digits." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "a4fYMGxGhrna" + }, + "outputs": [], + "source": [ + "(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "NFC2ghIdiZYE" + }, + "outputs": [], + "source": [ + "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", + "# We are normalizing the images to the range of [-1, 1]\n", + "train_images = (train_images - 127.5) / 127.5" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "S4PIDhoDLbsZ" + }, + "outputs": [], + "source": [ + "BUFFER_SIZE = 60000\n", + "BATCH_SIZE = 256" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PIGN6ouoQxt3" + }, + "source": [ + "## Use tf.data to create batches and shuffle the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "-yKCCQOoJ7cn" + }, + "outputs": [], + "source": [ + "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "THY-sZMiQ4UV" + }, + "source": [ + "## Write the generator and discriminator models\n", + "\n", + "* **Generator** \n", + " * It is responsible for **creating convincing images that are good enough to fool the discriminator**.\n", + " * It consists of Conv2DTranspose (Upsampling) layers. We start with a fully connected layer and upsample the image 2 times so as to reach the desired image size (mnist image size) which is (28, 28, 1). \n", + " * We use **leaky relu** activation except for the **last layer** which uses **tanh** activation.\n", + " \n", + "* **Discriminator**\n", + " * **The discriminator is responsible for classifying the fake images from the real images.**\n", + " * In other words, the discriminator is given generated images (from the generator) and the real MNIST images. The job of the discriminator is to classify these images into fake (generated) and real (MNIST images).\n", + " * **Basically the generator should be good enough to fool the discriminator that the generated images are real**." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "VGLbvBEmjK0a" + }, + "outputs": [], + "source": [ + "class Generator(tf.keras.Model):\n", + " def __init__(self):\n", + " super(Generator, self).__init__()\n", + " self.fc1 = tf.keras.layers.Dense(7*7*64, use_bias=False)\n", + " self.batchnorm1 = tf.keras.layers.BatchNormalization()\n", + " \n", + " self.conv1 = tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(1, 1), padding='same', use_bias=False)\n", + " self.batchnorm2 = tf.keras.layers.BatchNormalization()\n", + " \n", + " self.conv2 = tf.keras.layers.Conv2DTranspose(32, (5, 5), strides=(2, 2), padding='same', use_bias=False)\n", + " self.batchnorm3 = tf.keras.layers.BatchNormalization()\n", + " \n", + " self.conv3 = tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False)\n", + "\n", + " def call(self, x, training=True):\n", + " x = self.fc1(x)\n", + " x = self.batchnorm1(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = tf.reshape(x, shape=(-1, 7, 7, 64))\n", + "\n", + " x = self.conv1(x)\n", + " x = self.batchnorm2(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2(x)\n", + " x = self.batchnorm3(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = tf.nn.tanh(self.conv3(x)) \n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "bkOfJxk5j5Hi" + }, + "outputs": [], + "source": [ + "class Discriminator(tf.keras.Model):\n", + " def __init__(self):\n", + " super(Discriminator, self).__init__()\n", + " self.conv1 = tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')\n", + " self.conv2 = tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')\n", + " self.dropout = tf.keras.layers.Dropout(0.3)\n", + " self.flatten = tf.keras.layers.Flatten()\n", + " self.fc1 = tf.keras.layers.Dense(1)\n", + "\n", + " def call(self, x, training=True):\n", + " x = tf.nn.leaky_relu(self.conv1(x))\n", + " x = self.dropout(x, training=training)\n", + " x = tf.nn.leaky_relu(self.conv2(x))\n", + " x = self.dropout(x, training=training)\n", + " x = self.flatten(x)\n", + " x = self.fc1(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "gDkA05NE6QMs" + }, + "outputs": [], + "source": [ + "generator = Generator()\n", + "discriminator = Discriminator()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "k1HpMSLImuRi" + }, + "outputs": [], + "source": [ + "# Defun gives 10 secs/epoch performance boost\n", + "generator.call = tf.contrib.eager.defun(generator.call)\n", + "discriminator.call = tf.contrib.eager.defun(discriminator.call)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0FMYgY_mPfTi" + }, + "source": [ + "## Define the loss functions and the optimizer\n", + "\n", + "* **Discriminator loss**\n", + " * The discriminator loss function takes 2 inputs; **real images, generated images**\n", + " * real_loss is a sigmoid cross entropy loss of the **real images** and an **array of ones (since these are the real images)**\n", + " * generated_loss is a sigmoid cross entropy loss of the **generated images** and an **array of zeros (since these are the fake images)**\n", + " * Then the total_loss is the sum of real_loss and the generated_loss\n", + " \n", + "* **Generator loss**\n", + " * It is a sigmoid cross entropy loss of the generated images and an **array of ones**\n", + " \n", + "\n", + "* The discriminator and the generator optimizers are different since we will train them separately." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "wkMNfBWlT-PV" + }, + "outputs": [], + "source": [ + "def discriminator_loss(real_output, generated_output):\n", + " # [1,1,...,1] with real output since it is true and we want\n", + " # our generated examples to look like it\n", + " real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.ones_like(real_output), logits=real_output)\n", + "\n", + " # [0,0,...,0] with generated images since they are fake\n", + " generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.zeros_like(generated_output), logits=generated_output)\n", + "\n", + " total_loss = real_loss + generated_loss\n", + "\n", + " return total_loss" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "90BIcCKcDMxz" + }, + "outputs": [], + "source": [ + "def generator_loss(generated_output):\n", + " return tf.losses.sigmoid_cross_entropy(tf.ones_like(generated_output), generated_output)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "iWCn_PVdEJZ7" + }, + "outputs": [], + "source": [ + "discriminator_optimizer = tf.train.AdamOptimizer(1e-4)\n", + "generator_optimizer = tf.train.AdamOptimizer(1e-4)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "mWtinsGDPJlV" + }, + "source": [ + "## Checkpoints (Object-based saving)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "CA1w-7s2POEy" + }, + "outputs": [], + "source": [ + "checkpoint_dir = './training_checkpoints'\n", + "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", + "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n", + " discriminator_optimizer=discriminator_optimizer,\n", + " generator=generator,\n", + " discriminator=discriminator)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Rw1fkAczTQYh" + }, + "source": [ + "## Training\n", + "\n", + "* We start by iterating over the dataset\n", + "* The generator is given **noise as an input** which when passed through the generator model will output a image looking like a handwritten digit\n", + "* The discriminator is given the **real MNIST images as well as the generated images (from the generator)**.\n", + "* Next, we calculate the generator and the discriminator loss.\n", + "* Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables (inputs) and apply those to the optimizer.\n", + "\n", + "## Generate Images\n", + "\n", + "* After training, its time to generate some images!\n", + "* We start by creating noise array as an input to the generator\n", + "* The generator will then convert the noise into handwritten images.\n", + "* Last step is to plot the predictions and **voila!**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "NS2GWywBbAWo" + }, + "outputs": [], + "source": [ + "EPOCHS = 150\n", + "noise_dim = 100\n", + "num_examples_to_generate = 16\n", + "\n", + "# keeping the random vector constant for generation (prediction) so\n", + "# it will be easier to see the improvement of the gan.\n", + "random_vector_for_generation = tf.random_normal([num_examples_to_generate,\n", + " noise_dim])" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "RmdVsmvhPxyy" + }, + "outputs": [], + "source": [ + "def generate_and_save_images(model, epoch, test_input):\n", + " # make sure the training parameter is set to False because we\n", + " # don't want to train the batchnorm layer when doing inference.\n", + " predictions = model(test_input, training=False)\n", + "\n", + " fig = plt.figure(figsize=(4,4))\n", + " \n", + " for i in range(predictions.shape[0]):\n", + " plt.subplot(4, 4, i+1)\n", + " plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')\n", + " plt.axis('off')\n", + " \n", + " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "2M7LmLtGEMQJ" + }, + "outputs": [], + "source": [ + "def train(dataset, epochs, noise_dim): \n", + " for epoch in range(epochs):\n", + " start = time.time()\n", + " \n", + " for images in dataset:\n", + " # generating noise from a uniform distribution\n", + " noise = tf.random_normal([BATCH_SIZE, noise_dim])\n", + " \n", + " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", + " generated_images = generator(noise, training=True)\n", + " \n", + " real_output = discriminator(images, training=True)\n", + " generated_output = discriminator(generated_images, training=True)\n", + " \n", + " gen_loss = generator_loss(generated_output)\n", + " disc_loss = discriminator_loss(real_output, generated_output)\n", + " \n", + " gradients_of_generator = gen_tape.gradient(gen_loss, generator.variables)\n", + " gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.variables)\n", + " \n", + " generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.variables))\n", + " discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))\n", + "\n", + " \n", + " if epoch % 1 == 0:\n", + " display.clear_output(wait=True)\n", + " generate_and_save_images(generator,\n", + " epoch + 1,\n", + " random_vector_for_generation)\n", + " \n", + " # saving (checkpoint) the model every 15 epochs\n", + " if (epoch + 1) % 15 == 0:\n", + " checkpoint.save(file_prefix = checkpoint_prefix)\n", + " \n", + " print ('Time taken for epoch {} is {} sec'.format(epoch + 1,\n", + " time.time()-start))\n", + " # generating after the final epoch\n", + " display.clear_output(wait=True)\n", + " generate_and_save_images(generator,\n", + " epochs,\n", + " random_vector_for_generation)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Ly3UN0SLLY2l" + }, + "outputs": [], + "source": [ + "train(train_dataset, EPOCHS, noise_dim)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "rfM4YcPVPkNO" + }, + "source": [ + "## Restore the latest checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "XhXsd0srPo8c" + }, + "outputs": [], + "source": [ + "# restoring the latest checkpoint in checkpoint_dir\n", + "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "P4M_vIbUi7c0" + }, + "source": [ + "## Display an image using the epoch number" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "WfO5wCdclHGL" + }, + "outputs": [], + "source": [ + "def display_image(epoch_no):\n", + " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "5x3q9_Oe5q0A" + }, + "outputs": [], + "source": [ + "display_image(EPOCHS)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "NywiH3nL8guF" + }, + "source": [ + "## Generate a GIF of all the saved images." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xmO0Dmu2WICn" + }, + "source": [ + "\u003c!-- TODO(markdaoust): Remove the hack when Ipython version is updated --\u003e\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "IGKQgENQ8lEI" + }, + "outputs": [], + "source": [ + "with imageio.get_writer('dcgan.gif', mode='I') as writer:\n", + " filenames = glob.glob('image*.png')\n", + " filenames = sorted(filenames)\n", + " last = -1\n", + " for i,filename in enumerate(filenames):\n", + " frame = 2*(i**0.5)\n", + " if round(frame) \u003e round(last):\n", + " last = frame\n", + " else:\n", + " continue\n", + " image = imageio.imread(filename)\n", + " writer.append_data(image)\n", + " image = imageio.imread(filename)\n", + " writer.append_data(image)\n", + " \n", + "# this is a hack to display the gif inside the notebook\n", + "os.system('cp dcgan.gif dcgan.gif.png')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "uV0yiKpzNP1b" + }, + "outputs": [], + "source": [ + "display.Image(filename=\"dcgan.gif.png\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "6EEG-wePkmJQ" + }, + "source": [ + "To downlod the animation from Colab uncomment the code below:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "4UJjSnIMOzOJ" + }, + "outputs": [], + "source": [ + "#from google.colab import files\n", + "#files.download('dcgan.gif')" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "dcgan.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1eb0NOTQapkYs3X0v-zL1x5_LFKgDISnp", + "timestamp": 1527173385672 + } + ], + "toc_visible": true, + "version": "0.3.2" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..1a5a186e7a3e456cc43f8091370d3eeb795d5e0e --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb @@ -0,0 +1,1184 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "image_captioning_with_attention.ipynb", + "version": "0.3.2", + "views": {}, + "default_view": {}, + "provenance": [ + { + "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg", + "timestamp": 1530222436922 + } + ], + "private_outputs": true, + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "metadata": { + "id": "K2s1A9eLRPEj", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n" + ] + }, + { + "metadata": { + "id": "Cffg2i257iMS", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Image Captioning with Attention\n", + "\n", + "
\n", + "\n", + " Run in Google Colab \n", + "\n", + "View source on GitHub
" + ] + }, + { + "metadata": { + "id": "QASbY_HGo4Lq", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Image captioning is the task of generating a caption for an image. Given an image like this:\n", + "\n", + "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n", + "\n", + "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", + "\n", + "Our goal is generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n", + "\n", + "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n", + "\n", + "This model architecture below is similar to [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044). \n", + "\n", + "The code uses [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager), which you can learn more about in the linked guides.\n", + "\n", + "This notebook is an end-to-end example. If you run it, it will download the [MS-COCO](http://cocodataset.org/#home) dataset, preprocess and cache a subset of the images using Inception V3, train an encoder-decoder model, and use it to generate captions on new images.\n", + "\n", + "The code requires TensorFlow version >=1.9. If you're running this in [Colab]()\n", + "\n", + "In this example, we're training on a relatively small amount of data as an example. On a single P100 GPU, this example will take about ~2 hours to train. We train on the first 30,000 captions (corresponding to about ~20,000 images depending on shuffling, as there are multiple captions per image in the dataset)\n" + ] + }, + { + "metadata": { + "id": "U8l4RJ0XRPEm", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Import TensorFlow and enable eager execution\n", + "# This code requires TensorFlow version >=1.9\n", + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "\n", + "# We'll generate plots of attention in order to see which parts of an image\n", + "# our model focuses on during captioning\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Scikit-learn includes many helpful utilities\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.utils import shuffle\n", + "\n", + "import re\n", + "import numpy as np\n", + "import os\n", + "import time\n", + "import json\n", + "from glob import glob\n", + "from PIL import Image\n", + "import pickle" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "b6qbGw8MRPE5", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Download and prepare the MS-COCO dataset\n", + "\n", + "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code code below will download and extract the dataset automatically. \n", + "\n", + "**Caution: large download ahead**. We'll use the training set, it's a 13GB file." + ] + }, + { + "metadata": { + "id": "krQuPYTtRPE7", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "annotation_zip = tf.keras.utils.get_file('captions.zip', \n", + " cache_subdir=os.path.abspath('.'),\n", + " origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n", + " extract = True)\n", + "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n", + "\n", + "name_of_zip = 'train2014.zip'\n", + "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n", + " image_zip = tf.keras.utils.get_file(name_of_zip, \n", + " cache_subdir=os.path.abspath('.'),\n", + " origin = 'http://images.cocodataset.org/zips/train2014.zip',\n", + " extract = True)\n", + " PATH = os.path.dirname(image_zip)+'/train2014/'\n", + "else:\n", + " PATH = os.path.abspath('.')+'/train2014/'" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "aANEzb5WwSzg", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Optionally, limit the size of the training set for faster training\n", + "For this example, we'll select a subset of 30,000 captions and use these and the corresponding images to train our model. As always, captioning quality will improve if you choose to use more data." + ] + }, + { + "metadata": { + "id": "4G3b8x8_RPFD", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# read the json file\n", + "with open(annotation_file, 'r') as f:\n", + " annotations = json.load(f)\n", + "\n", + "# storing the captions and the image name in vectors\n", + "all_captions = []\n", + "all_img_name_vector = []\n", + "\n", + "for annot in annotations['annotations']:\n", + " caption = ' ' + annot['caption'] + ' '\n", + " image_id = annot['image_id']\n", + " full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n", + " \n", + " all_img_name_vector.append(full_coco_image_path)\n", + " all_captions.append(caption)\n", + "\n", + "# shuffling the captions and image_names together\n", + "# setting a random state\n", + "train_captions, img_name_vector = shuffle(all_captions,\n", + " all_img_name_vector,\n", + " random_state=1)\n", + "\n", + "# selecting the first 30000 captions from the shuffled set\n", + "num_examples = 30000\n", + "train_captions = train_captions[:num_examples]\n", + "img_name_vector = img_name_vector[:num_examples]" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "mPBMgK34RPFL", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "len(train_captions), len(all_captions)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "8cSW4u-ORPFQ", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Preprocess the images using InceptionV3\n", + "Next, we will use InceptionV3 (pretrained on Imagenet) to classify each image. We will extract features from the last convolutional layer. \n", + "\n", + "First, we will need to convert the images into the format inceptionV3 expects by:\n", + "* Resizing the image to (299, 299)\n", + "* Using the [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) method to place the pixels in the range of -1 to 1 (to match the format of the images used to train InceptionV3)." + ] + }, + { + "metadata": { + "id": "zXR0217aRPFR", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def load_image(image_path):\n", + " img = tf.read_file(image_path)\n", + " img = tf.image.decode_jpeg(img, channels=3)\n", + " img = tf.image.resize_images(img, (299, 299))\n", + " img = tf.keras.applications.inception_v3.preprocess_input(img)\n", + " return img, image_path" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "MDvIu4sXRPFV", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Initialize InceptionV3 and load the pretrained Imagenet weights\n", + "\n", + "To do so, we'll create a tf.keras model where the output layer is the last convolutional layer in the InceptionV3 architecture. \n", + "* Each image is forwarded through the network and the vector that we get at the end is stored in a dictionary (image_name --> feature_vector). \n", + "* We use the last convolutional layer because we are using attention in this example. The shape of the output of this layer is ```8x8x2048```. \n", + "* We avoid doing this during training so it does not become a bottleneck. \n", + "* After all the images are passed through the network, we pickle the dictionary and save it to disk." + ] + }, + { + "metadata": { + "id": "RD3vW4SsRPFW", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "image_model = tf.keras.applications.InceptionV3(include_top=False, \n", + " weights='imagenet')\n", + "new_input = image_model.input\n", + "hidden_layer = image_model.layers[-1].output\n", + "\n", + "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "rERqlR3WRPGO", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Caching the features extracted from InceptionV3\n", + "\n", + "We will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but memory intensive, requiring 8 \\* 8 \\* 2048 floats per image. At the time of writing, this would exceed the memory limitations of Colab (although these may change, an instance appears to have about 12GB of memory currently). \n", + "\n", + "Performance could be improved with a more sophisticated caching strategy (e.g., by sharding the images to reduce random access disk I/O) at the cost of more code.\n", + "\n", + "This will take about 10 minutes to run in Colab with a GPU. If you'd like to see a progress bar, you could: install [tqdm](https://github.com/tqdm/tqdm) (```!pip install tqdm```), then change this line: \n", + "\n", + "```for img, path in image_dataset:``` \n", + "\n", + "to:\n", + "\n", + "```for img, path in tqdm(image_dataset):```." + ] + }, + { + "metadata": { + "id": "Dx_fvbVgRPGQ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# getting the unique images\n", + "encode_train = sorted(set(img_name_vector))\n", + "\n", + "# feel free to change the batch_size according to your system configuration\n", + "image_dataset = tf.data.Dataset.from_tensor_slices(\n", + " encode_train).map(load_image).batch(16)\n", + "\n", + "for img, path in image_dataset:\n", + " batch_features = image_features_extract_model(img)\n", + " batch_features = tf.reshape(batch_features, \n", + " (batch_features.shape[0], -1, batch_features.shape[3]))\n", + "\n", + " for bf, p in zip(batch_features, path):\n", + " path_of_feature = p.numpy().decode(\"utf-8\")\n", + " np.save(path_of_feature, bf.numpy())" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "nyqH3zFwRPFi", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Preprocess and tokenize the captions\n", + "\n", + "* First, we'll tokenize the captions (e.g., by splitting on spaces). This will give us a vocabulary of all the unique words in the data (e.g., \"surfing\", \"football\", etc).\n", + "* Next, we'll limit the vocabulary size to the top 5,000 words to save memory. We'll replace all other words with the token \"UNK\" (for unknown).\n", + "* Finally, we create a word --> index mapping and vice-versa.\n", + "* We will then pad all sequences to the be same length as the longest one. " + ] + }, + { + "metadata": { + "id": "HZfK8RhQRPFj", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# This will find the maximum length of any caption in our dataset\n", + "def calc_max_length(tensor):\n", + " return max(len(t) for t in tensor)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "oJGE34aiRPFo", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# The steps above is a general process of dealing with text processing\n", + "\n", + "# choosing the top 5000 words from the vocabulary\n", + "top_k = 5000\n", + "tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k, \n", + " oov_token=\"\", \n", + " filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n", + "tokenizer.fit_on_texts(train_captions)\n", + "train_seqs = tokenizer.texts_to_sequences(train_captions)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "8Q44tNQVRPFt", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "tokenizer.word_index = {key:value for key, value in tokenizer.word_index.items() if value <= top_k}\n", + "# putting token in the word2idx dictionary\n", + "tokenizer.word_index[tokenizer.oov_token] = top_k + 1\n", + "tokenizer.word_index[''] = 0" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "0fpJb5ojRPFv", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# creating the tokenized vectors\n", + "train_seqs = tokenizer.texts_to_sequences(train_captions)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "olQArbgbRPF1", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# creating a reverse mapping (index -> word)\n", + "index_word = {value:key for key, value in tokenizer.word_index.items()}" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "AidglIZVRPF4", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# padding each vector to the max_length of the captions\n", + "# if the max_length parameter is not provided, pad_sequences calculates that automatically\n", + "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "gL0wkttkRPGA", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# calculating the max_length \n", + "# used to store the attention weights\n", + "max_length = calc_max_length(train_seqs)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "M3CD75nDpvTI", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Split the data into training and testing" + ] + }, + { + "metadata": { + "id": "iS7DDMszRPGF", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Create training and validation sets using 80-20 split\n", + "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector, \n", + " cap_vector, \n", + " test_size=0.2, \n", + " random_state=0)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "XmViPkRFRPGH", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "uEWM9xrYcg45", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Our images and captions are ready! Next, let's create a tf.data dataset to use for training our model.\n", + "\n" + ] + }, + { + "metadata": { + "id": "Q3TnZ1ToRPGV", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# feel free to change these parameters according to your system's configuration\n", + "\n", + "BATCH_SIZE = 64\n", + "BUFFER_SIZE = 1000\n", + "embedding_dim = 256\n", + "units = 512\n", + "vocab_size = len(tokenizer.word_index)\n", + "# shape of the vector extracted from InceptionV3 is (64, 2048)\n", + "# these two variables represent that\n", + "features_shape = 2048\n", + "attention_features_shape = 64" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "SmZS2N0bXG3T", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# loading the numpy files \n", + "def map_func(img_name, cap):\n", + " img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n", + " return img_tensor, cap" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "FDF_Nm3tRPGZ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n", + "\n", + "# using map to load the numpy files in parallel\n", + "# NOTE: Be sure to set num_parallel_calls to the number of CPU cores you have\n", + "# https://www.tensorflow.org/api_docs/python/tf/py_func\n", + "dataset = dataset.map(lambda item1, item2: tf.py_func(\n", + " map_func, [item1, item2], [tf.float32, tf.int32]), num_parallel_calls=8)\n", + "\n", + "# shuffling and batching\n", + "dataset = dataset.shuffle(BUFFER_SIZE)\n", + "# https://www.tensorflow.org/api_docs/python/tf/contrib/data/batch_and_drop_remainder\n", + "dataset = dataset.batch(BATCH_SIZE)\n", + "dataset = dataset.prefetch(1)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "nrvoDphgRPGd", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Model\n", + "\n", + "Fun fact, the decoder below is identical to the one in the example for [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", + "\n", + "The model architecture is inspired by the [Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) paper.\n", + "\n", + "* In this example, we extract the features from the lower convolutional layer of InceptionV3 giving us a vector of shape (8, 8, 2048). \n", + "* We squash that to a shape of (64, 2048).\n", + "* This vector is then passed through the CNN Encoder(which consists of a single Fully connected layer).\n", + "* The RNN(here GRU) attends over the image to predict the next word." + ] + }, + { + "metadata": { + "id": "AAppCGLKRPGd", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def gru(units):\n", + " # If you have a GPU, we recommend using the CuDNNGRU layer (it provides a \n", + " # significant speedup).\n", + " if tf.test.is_gpu_available():\n", + " return tf.keras.layers.CuDNNGRU(units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_initializer='glorot_uniform')\n", + " else:\n", + " return tf.keras.layers.GRU(units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_activation='sigmoid', \n", + " recurrent_initializer='glorot_uniform')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "ja2LFTMSdeV3", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class BahdanauAttention(tf.keras.Model):\n", + " def __init__(self, units):\n", + " super(BahdanauAttention, self).__init__()\n", + " self.W1 = tf.keras.layers.Dense(units)\n", + " self.W2 = tf.keras.layers.Dense(units)\n", + " self.V = tf.keras.layers.Dense(1)\n", + " \n", + " def call(self, features, hidden):\n", + " # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)\n", + " \n", + " # hidden shape == (batch_size, hidden_size)\n", + " # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n", + " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", + " \n", + " # score shape == (batch_size, 64, hidden_size)\n", + " score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))\n", + " \n", + " # attention_weights shape == (batch_size, 64, 1)\n", + " # we get 1 at the last axis because we are applying score to self.V\n", + " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", + " \n", + " # context_vector shape after sum == (batch_size, hidden_size)\n", + " context_vector = attention_weights * features\n", + " context_vector = tf.reduce_sum(context_vector, axis=1)\n", + " \n", + " return context_vector, attention_weights" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "AZ7R1RxHRPGf", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class CNN_Encoder(tf.keras.Model):\n", + " # Since we have already extracted the features and dumped it using pickle\n", + " # This encoder passes those features through a Fully connected layer\n", + " def __init__(self, embedding_dim):\n", + " super(CNN_Encoder, self).__init__()\n", + " # shape after fc == (batch_size, 64, embedding_dim)\n", + " self.fc = tf.keras.layers.Dense(embedding_dim)\n", + " \n", + " def call(self, x):\n", + " x = self.fc(x)\n", + " x = tf.nn.relu(x)\n", + " return x" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "V9UbGQmERPGi", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class RNN_Decoder(tf.keras.Model):\n", + " def __init__(self, embedding_dim, units, vocab_size):\n", + " super(RNN_Decoder, self).__init__()\n", + " self.units = units\n", + "\n", + " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", + " self.gru = gru(self.units)\n", + " self.fc1 = tf.keras.layers.Dense(self.units)\n", + " self.fc2 = tf.keras.layers.Dense(vocab_size)\n", + " \n", + " self.attention = BahdanauAttention(self.units)\n", + " \n", + " def call(self, x, features, hidden):\n", + " # defining attention as a separate model\n", + " context_vector, attention_weights = self.attention(features, hidden)\n", + " \n", + " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n", + " x = self.embedding(x)\n", + " \n", + " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n", + " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", + " \n", + " # passing the concatenated vector to the GRU\n", + " output, state = self.gru(x)\n", + " \n", + " # shape == (batch_size, max_length, hidden_size)\n", + " x = self.fc1(output)\n", + " \n", + " # x shape == (batch_size * max_length, hidden_size)\n", + " x = tf.reshape(x, (-1, x.shape[2]))\n", + " \n", + " # output shape == (batch_size * max_length, vocab)\n", + " x = self.fc2(x)\n", + "\n", + " return x, state, attention_weights\n", + "\n", + " def reset_state(self, batch_size):\n", + " return tf.zeros((batch_size, self.units))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Qs_Sr03wRPGk", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "encoder = CNN_Encoder(embedding_dim)\n", + "decoder = RNN_Decoder(embedding_dim, units, vocab_size)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "-bYN7xA0RPGl", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "optimizer = tf.train.AdamOptimizer()\n", + "\n", + "# We are masking the loss calculated for padding\n", + "def loss_function(real, pred):\n", + " mask = 1 - np.equal(real, 0)\n", + " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", + " return tf.reduce_mean(loss_)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "PHod7t72RPGn", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Training\n", + "\n", + "* We extract the features stored in the respective `.npy` files and then pass those features through the encoder.\n", + "* The encoder output, hidden state(initialized to 0) and the decoder input (which is the start token) is passed to the decoder.\n", + "* The decoder returns the predictions and the decoder hidden state.\n", + "* The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n", + "* Use teacher forcing to decide the next input to the decoder.\n", + "* Teacher forcing is the technique where the target word is passed as the next input to the decoder.\n", + "* The final step is to calculate the gradients and apply it to the optimizer and backpropagate.\n" + ] + }, + { + "metadata": { + "id": "Vt4WZ5mhJE-E", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# adding this in a separate cell because if you run the training cell \n", + "# many times, the loss_plot array will be reset\n", + "loss_plot = []" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "UlA4VIQpRPGo", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "EPOCHS = 20\n", + "\n", + "for epoch in range(EPOCHS):\n", + " start = time.time()\n", + " total_loss = 0\n", + " \n", + " for (batch, (img_tensor, target)) in enumerate(dataset):\n", + " loss = 0\n", + " \n", + " # initializing the hidden state for each batch\n", + " # because the captions are not related from image to image\n", + " hidden = decoder.reset_state(batch_size=target.shape[0])\n", + "\n", + " dec_input = tf.expand_dims([tokenizer.word_index['']] * BATCH_SIZE, 1)\n", + " \n", + " with tf.GradientTape() as tape:\n", + " features = encoder(img_tensor)\n", + " \n", + " for i in range(1, target.shape[1]):\n", + " # passing the features through the decoder\n", + " predictions, hidden, _ = decoder(dec_input, features, hidden)\n", + "\n", + " loss += loss_function(target[:, i], predictions)\n", + " \n", + " # using teacher forcing\n", + " dec_input = tf.expand_dims(target[:, i], 1)\n", + " \n", + " total_loss += (loss / int(target.shape[1]))\n", + " \n", + " variables = encoder.variables + decoder.variables\n", + " \n", + " gradients = tape.gradient(loss, variables) \n", + " \n", + " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", + " \n", + " if batch % 100 == 0:\n", + " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, \n", + " batch, \n", + " loss.numpy() / int(target.shape[1])))\n", + " # storing the epoch end loss value to plot later\n", + " loss_plot.append(total_loss / len(cap_vector))\n", + " \n", + " print ('Epoch {} Loss {:.6f}'.format(epoch + 1, \n", + " total_loss/len(cap_vector)))\n", + " print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "1Wm83G-ZBPcC", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "plt.plot(loss_plot)\n", + "plt.xlabel('Epochs')\n", + "plt.ylabel('Loss')\n", + "plt.title('Loss Plot')\n", + "plt.show()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "xGvOcLQKghXN", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Caption!\n", + "\n", + "* The evaluate function is similar to the training loop, except we don't use teacher forcing here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n", + "* Stop predicting when the model predicts the end token.\n", + "* And store the attention weights for every time step." + ] + }, + { + "metadata": { + "id": "RCWpDtyNRPGs", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def evaluate(image):\n", + " attention_plot = np.zeros((max_length, attention_features_shape))\n", + "\n", + " hidden = decoder.reset_state(batch_size=1)\n", + "\n", + " temp_input = tf.expand_dims(load_image(image)[0], 0)\n", + " img_tensor_val = image_features_extract_model(temp_input)\n", + " img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n", + "\n", + " features = encoder(img_tensor_val)\n", + "\n", + " dec_input = tf.expand_dims([tokenizer.word_index['']], 0)\n", + " result = []\n", + "\n", + " for i in range(max_length):\n", + " predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n", + "\n", + " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", + "\n", + " predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n", + " result.append(index_word[predicted_id])\n", + "\n", + " if index_word[predicted_id] == '':\n", + " return result, attention_plot\n", + "\n", + " dec_input = tf.expand_dims([predicted_id], 0)\n", + "\n", + " attention_plot = attention_plot[:len(result), :]\n", + " return result, attention_plot" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "fD_y7PD6RPGt", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def plot_attention(image, result, attention_plot):\n", + " temp_image = np.array(Image.open(image))\n", + "\n", + " fig = plt.figure(figsize=(10, 10))\n", + " \n", + " len_result = len(result)\n", + " for l in range(len_result):\n", + " temp_att = np.resize(attention_plot[l], (8, 8))\n", + " ax = fig.add_subplot(len_result//2, len_result//2, l+1)\n", + " ax.set_title(result[l])\n", + " img = ax.imshow(temp_image)\n", + " ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n", + "\n", + " plt.tight_layout()\n", + " plt.show()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "io7ws3ReRPGv", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# captions on the validation set\n", + "rid = np.random.randint(0, len(img_name_val))\n", + "image = img_name_val[rid]\n", + "real_caption = ' '.join([index_word[i] for i in cap_val[rid] if i not in [0]])\n", + "result, attention_plot = evaluate(image)\n", + "\n", + "print ('Real Caption:', real_caption)\n", + "print ('Prediction Caption:', ' '.join(result))\n", + "plot_attention(image, result, attention_plot)\n", + "# opening the image\n", + "Image.open(img_name_val[rid])" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Rprk3HEvZuxb", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Try it on your own images\n", + "For fun, below we've provided a method you can use to caption your own images with the model we've just trained. Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for weird results!)\n" + ] + }, + { + "metadata": { + "id": "9Psd1quzaAWg", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "image_url = 'https://tensorflow.org/images/surf.jpg'\n", + "image_extension = image_url[-4:]\n", + "image_path = tf.keras.utils.get_file('image'+image_extension, \n", + " origin=image_url)\n", + "\n", + "result, attention_plot = evaluate(image_path)\n", + "print ('Prediction Caption:', ' '.join(result))\n", + "plot_attention(image_path, result, attention_plot)\n", + "# opening the image\n", + "Image.open(image_path)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "VJZXyJco6uLO", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Next steps\n", + "\n", + "Congrats! You've just trained an image captioning model with attention. Next, we recommend taking a look at this example [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb). It uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the code in this notebook on a different dataset." + ] + } + ] +} diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..027097908f2c62724830c556d72b6b6bee218eec --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb @@ -0,0 +1,671 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "hcD2nPQvPOFM" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n", + "\n", + "# Text Generation using a RNN\n", + "\n", + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on Github\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BwpJ5IffzRG6" + }, + "source": [ + "This notebook demonstrates how to generate text using an RNN using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). If you like, you can write a similar [model](https://github.com/fchollet/deep-learning-with-python-notebooks/blob/master/8.1-text-generation-with-lstm.ipynb) using less code. Here, we show a lower-level impementation that's useful to understand as prework before diving in to deeper examples in a similar, like [Neural Machine Translation with Attention](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", + "\n", + "This notebook is an end-to-end example. When you run it, it will download a dataset of Shakespeare's writing. We'll use a collection of plays, borrowed from Andrej Karpathy's excellent [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/). The notebook will train a model, and use it to generate sample output.\n", + " \n", + "Here is the output(with start string='w') after training a single layer GRU for 30 epochs with the default settings below:\n", + "\n", + "```\n", + "were to the death of him\n", + "And nothing of the field in the view of hell,\n", + "When I said, banish him, I will not burn thee that would live.\n", + "\n", + "HENRY BOLINGBROKE:\n", + "My gracious uncle--\n", + "\n", + "DUKE OF YORK:\n", + "As much disgraced to the court, the gods them speak,\n", + "And now in peace himself excuse thee in the world.\n", + "\n", + "HORTENSIO:\n", + "Madam, 'tis not the cause of the counterfeit of the earth,\n", + "And leave me to the sun that set them on the earth\n", + "And leave the world and are revenged for thee.\n", + "\n", + "GLOUCESTER:\n", + "I would they were talking with the very name of means\n", + "To make a puppet of a guest, and therefore, good Grumio,\n", + "Nor arm'd to prison, o' the clouds, of the whole field,\n", + "With the admire\n", + "With the feeding of thy chair, and we have heard it so,\n", + "I thank you, sir, he is a visor friendship with your silly your bed.\n", + "\n", + "SAMPSON:\n", + "I do desire to live, I pray: some stand of the minds, make thee remedies\n", + "With the enemies of my soul.\n", + "\n", + "MENENIUS:\n", + "I'll keep the cause of my mistress.\n", + "\n", + "POLIXENES:\n", + "My brother Marcius!\n", + "\n", + "Second Servant:\n", + "Will't ple\n", + "```\n", + "\n", + "Of course, while some of the sentences are grammatical, most do not make sense. But, consider:\n", + "\n", + "* Our model is character based (when we began training, it did not yet know how to spell a valid English word, or that words were even a unit of text).\n", + "\n", + "* The structure of the output resembles a play (blocks begin with a speaker name, in all caps similar to the original text). Sentences generally end with a period. If you look at the text from a distance (or don't read the invididual words too closely, it appears as if it's an excerpt from a play).\n", + "\n", + "As a next step, you can experiment training the model on a different dataset - any large text file(ASCII) will do, and you can modify a single line of code below to make that change. Have fun!\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "R3p22DBDsaCA" + }, + "source": [ + "## Install unidecode library\n", + "A helpful library to convert unicode to ASCII." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "wZ6LOM12wKGH" + }, + "outputs": [], + "source": [ + "!pip install unidecode" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "WGyKZj3bzf9p" + }, + "source": [ + "## Import tensorflow and enable eager execution." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "yG_n40gFzf9s" + }, + "outputs": [], + "source": [ + "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", + "import tensorflow as tf\n", + "\n", + "# Note: Once you enable eager execution, it cannot be disabled. \n", + "tf.enable_eager_execution()\n", + "\n", + "import numpy as np\n", + "import os\n", + "import re\n", + "import random\n", + "import unidecode\n", + "import time" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "EHDoRoc5PKWz" + }, + "source": [ + "## Download the dataset\n", + "\n", + "In this example, we will use the [shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt). You can use any other dataset that you like.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "pD_55cOxLkAb" + }, + "outputs": [], + "source": [ + "path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "UHjdCjDuSvX_" + }, + "source": [ + "## Read the dataset\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "-E5JvY3wzf94" + }, + "outputs": [], + "source": [ + "text = unidecode.unidecode(open(path_to_file).read())\n", + "# length of text is the number of characters in it\n", + "print (len(text))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Il9ww98izf-D" + }, + "source": [ + "Creating dictionaries to map from characters to their indices and vice-versa, which will be used to vectorize the inputs" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "IalZLbvOzf-F" + }, + "outputs": [], + "source": [ + "# unique contains all the unique characters in the file\n", + "unique = sorted(set(text))\n", + "\n", + "# creating a mapping from unique characters to indices\n", + "char2idx = {u:i for i, u in enumerate(unique)}\n", + "idx2char = {i:u for i, u in enumerate(unique)}" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "1v_qUYfAzf-I" + }, + "outputs": [], + "source": [ + "# setting the maximum length sentence we want for a single input in characters\n", + "max_length = 100\n", + "\n", + "# length of the vocabulary in chars\n", + "vocab_size = len(unique)\n", + "\n", + "# the embedding dimension \n", + "embedding_dim = 256\n", + "\n", + "# number of RNN (here GRU) units\n", + "units = 1024\n", + "\n", + "# batch size \n", + "BATCH_SIZE = 64\n", + "\n", + "# buffer size to shuffle our dataset\n", + "BUFFER_SIZE = 10000" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "LFjSVAlWzf-N" + }, + "source": [ + "## Creating the input and output tensors\n", + "\n", + "Vectorizing the input and the target text because our model cannot understand strings only numbers.\n", + "\n", + "But first, we need to create the input and output vectors.\n", + "Remember the max_length we set above, we will use it here. We are creating **max_length** chunks of input, where each input vector is all the characters in that chunk except the last and the target vector is all the characters in that chunk except the first.\n", + "\n", + "For example, consider that the string = 'tensorflow' and the max_length is 9\n", + "\n", + "So, the `input = 'tensorflo'` and `output = 'ensorflow'`\n", + "\n", + "After creating the vectors, we convert each character into numbers using the **char2idx** dictionary we created above." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "0UHJDA39zf-O" + }, + "outputs": [], + "source": [ + "input_text = []\n", + "target_text = []\n", + "\n", + "for f in range(0, len(text)-max_length, max_length):\n", + " inps = text[f:f+max_length]\n", + " targ = text[f+1:f+1+max_length]\n", + "\n", + " input_text.append([char2idx[i] for i in inps])\n", + " target_text.append([char2idx[t] for t in targ])\n", + " \n", + "print (np.array(input_text).shape)\n", + "print (np.array(target_text).shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "MJdfPmdqzf-R" + }, + "source": [ + "## Creating batches and shuffling them using tf.data" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "p2pGotuNzf-S" + }, + "outputs": [], + "source": [ + "dataset = tf.data.Dataset.from_tensor_slices((input_text, target_text)).shuffle(BUFFER_SIZE)\n", + "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "m8gPwEjRzf-Z" + }, + "source": [ + "## Creating the model\n", + "\n", + "We use the Model Subclassing API which gives us full flexibility to create the model and change it however we like. We use 3 layers to define our model.\n", + "\n", + "* Embedding layer\n", + "* GRU layer (you can use an LSTM layer here)\n", + "* Fully connected layer" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "P3KTiiInzf-a" + }, + "outputs": [], + "source": [ + "class Model(tf.keras.Model):\n", + " def __init__(self, vocab_size, embedding_dim, units, batch_size):\n", + " super(Model, self).__init__()\n", + " self.units = units\n", + " self.batch_sz = batch_size\n", + "\n", + " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", + "\n", + " if tf.test.is_gpu_available():\n", + " self.gru = tf.keras.layers.CuDNNGRU(self.units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_initializer='glorot_uniform')\n", + " else:\n", + " self.gru = tf.keras.layers.GRU(self.units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_activation='sigmoid', \n", + " recurrent_initializer='glorot_uniform')\n", + "\n", + " self.fc = tf.keras.layers.Dense(vocab_size)\n", + " \n", + " def call(self, x, hidden):\n", + " x = self.embedding(x)\n", + "\n", + " # output shape == (batch_size, max_length, hidden_size) \n", + " # states shape == (batch_size, hidden_size)\n", + "\n", + " # states variable to preserve the state of the model\n", + " # this will be used to pass at every step to the model while training\n", + " output, states = self.gru(x, initial_state=hidden)\n", + "\n", + "\n", + " # reshaping the output so that we can pass it to the Dense layer\n", + " # after reshaping the shape is (batch_size * max_length, hidden_size)\n", + " output = tf.reshape(output, (-1, output.shape[2]))\n", + "\n", + " # The dense layer will output predictions for every time_steps(max_length)\n", + " # output shape after the dense layer == (max_length * batch_size, vocab_size)\n", + " x = self.fc(output)\n", + "\n", + " return x, states" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "trpqTWyvk0nr" + }, + "source": [ + "## Call the model and set the optimizer and the loss function" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "7t2XrzEOzf-e" + }, + "outputs": [], + "source": [ + "model = Model(vocab_size, embedding_dim, units, BATCH_SIZE)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "dkjWIATszf-h" + }, + "outputs": [], + "source": [ + "optimizer = tf.train.AdamOptimizer()\n", + "\n", + "# using sparse_softmax_cross_entropy so that we don't have to create one-hot vectors\n", + "def loss_function(real, preds):\n", + " return tf.losses.sparse_softmax_cross_entropy(labels=real, logits=preds)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "3K6s6F79P7za" + }, + "source": [ + "## Checkpoints (Object-based saving)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "oAGisDdfP9rL" + }, + "outputs": [], + "source": [ + "checkpoint_dir = './training_checkpoints'\n", + "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", + "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n", + " model=model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "lPrP0XMUzf-p" + }, + "source": [ + "## Train the model\n", + "\n", + "Here we will use a custom training loop with the help of GradientTape()\n", + "\n", + "* We initialize the hidden state of the model with zeros and shape == (batch_size, number of rnn units). We do this by calling the function defined while creating the model.\n", + "\n", + "* Next, we iterate over the dataset(batch by batch) and calculate the **predictions and the hidden states** associated with that input.\n", + "\n", + "* There are a lot of interesting things happening here.\n", + " * The model gets hidden state(initialized with 0), lets call that **H0** and the first batch of input, lets call that **I0**.\n", + " * The model then returns the predictions **P1** and **H1**.\n", + " * For the next batch of input, the model receives **I1** and **H1**.\n", + " * The interesting thing here is that we pass **H1** to the model with **I1** which is how the model learns. The context learned from batch to batch is contained in the **hidden state**.\n", + " * We continue doing this until the dataset is exhausted and then we start a new epoch and repeat this.\n", + "\n", + "* After calculating the predictions, we calculate the **loss** using the loss function defined above. Then we calculate the gradients of the loss with respect to the model variables(input)\n", + "\n", + "* Finally, we take a step in that direction with the help of the optimizer using the apply_gradients function.\n", + "\n", + "Note:- If you are running this notebook in Colab which has a **Tesla K80 GPU** it takes about 23 seconds per epoch.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "d4tSNwymzf-q" + }, + "outputs": [], + "source": [ + "# Training step\n", + "\n", + "EPOCHS = 20\n", + "\n", + "for epoch in range(EPOCHS):\n", + " start = time.time()\n", + " \n", + " # initializing the hidden state at the start of every epoch\n", + " hidden = model.reset_states()\n", + " \n", + " for (batch, (inp, target)) in enumerate(dataset):\n", + " with tf.GradientTape() as tape:\n", + " # feeding the hidden state back into the model\n", + " # This is the interesting step\n", + " predictions, hidden = model(inp, hidden)\n", + " \n", + " # reshaping the target because that's how the \n", + " # loss function expects it\n", + " target = tf.reshape(target, (-1,))\n", + " loss = loss_function(target, predictions)\n", + " \n", + " grads = tape.gradient(loss, model.variables)\n", + " optimizer.apply_gradients(zip(grads, model.variables))\n", + "\n", + " if batch % 100 == 0:\n", + " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch+1,\n", + " batch,\n", + " loss))\n", + " # saving (checkpoint) the model every 5 epochs\n", + " if (epoch + 1) % 5 == 0:\n", + " checkpoint.save(file_prefix = checkpoint_prefix)\n", + "\n", + " print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))\n", + " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "01AR9vpNQMFF" + }, + "source": [ + "## Restore the latest checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "tyvpYomYQQkF" + }, + "outputs": [], + "source": [ + "# restoring the latest checkpoint in checkpoint_dir\n", + "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "DjGz1tDkzf-u" + }, + "source": [ + "## Predicting using our trained model\n", + "\n", + "The below code block is used to generated the text\n", + "\n", + "* We start by choosing a start string and initializing the hidden state and setting the number of characters we want to generate.\n", + "\n", + "* We get predictions using the start_string and the hidden state\n", + "\n", + "* Then we use a multinomial distribution to calculate the index of the predicted word. **We use this predicted word as our next input to the model**\n", + "\n", + "* **The hidden state returned by the model is fed back into the model so that it now has more context rather than just one word.** After we predict the next word, the modified hidden states are again fed back into the model, which is how it learns as it gets more context from the previously predicted words.\n", + "\n", + "* If you see the predictions, the model knows when to capitalize, make paragraphs and the text follows a shakespeare style of writing which is pretty awesome!" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "WvuwZBX5Ogfd" + }, + "outputs": [], + "source": [ + "# Evaluation step(generating text using the model learned)\n", + "\n", + "# number of characters to generate\n", + "num_generate = 1000\n", + "\n", + "# You can change the start string to experiment\n", + "start_string = 'Q'\n", + "# converting our start string to numbers(vectorizing!) \n", + "input_eval = [char2idx[s] for s in start_string]\n", + "input_eval = tf.expand_dims(input_eval, 0)\n", + "\n", + "# empty string to store our results\n", + "text_generated = ''\n", + "\n", + "# low temperatures results in more predictable text.\n", + "# higher temperatures results in more surprising text\n", + "# experiment to find the best setting\n", + "temperature = 1.0\n", + "\n", + "# hidden state shape == (batch_size, number of rnn units); here batch size == 1\n", + "hidden = [tf.zeros((1, units))]\n", + "for i in range(num_generate):\n", + " predictions, hidden = model(input_eval, hidden)\n", + "\n", + " # using a multinomial distribution to predict the word returned by the model\n", + " predictions = predictions / temperature\n", + " predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n", + " \n", + " # We pass the predicted word as the next input to the model\n", + " # along with the previous hidden state\n", + " input_eval = tf.expand_dims([predicted_id], 0)\n", + " \n", + " text_generated += idx2char[predicted_id]\n", + "\n", + "print (start_string + text_generated)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "AM2Uma_-yVIq" + }, + "source": [ + "## Next steps\n", + "\n", + "* Change the start string to a different character, or the start of a sentence.\n", + "* Experiment with training on a different, or with different parameters. [Project Gutenberg](http://www.gutenberg.org/ebooks/100), for example, contains a large collection of books.\n", + "* Experiment with the temperature parameter.\n", + "* Add another RNN layer.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "gtEd86sX5cB2" + }, + "outputs": [], + "source": [ + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "text_generation.ipynb", + "private_outputs": true, + "provenance": [], + "toc_visible": true, + "version": "0.3.2" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/README.md b/tensorflow/contrib/eager/python/examples/l2hmc/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f171806e379da7213b6ee33e0d454056068fe7a5 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/README.md @@ -0,0 +1,53 @@ +# L2HMC with TensorFlow eager execution + +This folder contains an implementation of [L2HMC](https://arxiv.org/pdf/1711.09268.pdf) adapted from the released implementation by the authors. The presented implementation runs in both eager and graph mode. +With eager execution enabled, longer sample chains can be handled compared to graph mode, since no graph is explicitly stored. Moreover, with eager execution enabled, there is no need to use a `tf.while_loop`. + +## What is L2HMC? +L2HMC is an adaptive Markov Chain Monte Carlo (MCMC) algorithm that learns a non-volume preserving transformation +for a Hamiltonian Monte Carlo (HMC) sampling algorithm. More specifically, the non-volume preserving +transformation is learned with neural nets instantiated within Normalizing Flows +(real-NVPs). + +## Content + +- `l2hmc.py`: Dynamics definitions and example energy functions, +including the 2D strongly correlated Gaussian and the rough well energy function, +- `l2hmc_test.py`: Unit tests and benchmarks for training a sampler on the energy functions in both eager and graph mode. +- `neural_nets.py`: The neural net for learning the kernel on the 2D strongly correlated example. +- `main.py`: Run to train a samplers on 2D energy landscapes. + +## To run +- Make sure you have installed TensorFlow 1.9+ or the latest `tf-nightly` or `tf-nightly-gpu` pip package. +- Execute the command + +```bash +python main.py --train_dir ${PWD}/dump --use_defun +``` + +Specifying the optional argument `train_dir` will store event files for +tensorboard and a plot of sampled chain from the trained sampler. + +Specifying the optional argument `use_defun` will let the program use compiled +graphs when running specific sections and improve the overall speed. + +## Boosting Performance with `tfe.defun` +Currently, some models may experience increased overhead with eager execution enabled. +To improve performance, we could wrap certain functions with the decorator `@tfe.defun`. +For example, we could wrap the function that does the sampling step: + +```python +@tfe.defun +def apply_transition(old_sample): + new_sample = ... + return new_sample +``` + +We could also explicitly wrap the desired function with `tfe.defun`: + +```python +apply_transition = tfe.defun(apply_transition) +``` + +## Reference +Generalizing Hamiltonian Monte Carlo with Neural Networks. Levy, Daniel, Hoffman, Matthew D, and Sohl-Dickstein, Jascha. International Conference on Learning Representations (ICLR), 2018. diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py index 729d8525fab31ee214178ca1bcb18dbd069f767a..14b8324e488a864cb23ff2507fab1c53c0583bc0 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py +++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py @@ -32,20 +32,28 @@ from tensorflow.contrib.eager.python.examples.l2hmc import neural_nets class Dynamics(tf.keras.Model): - """Dynamics engine of naive L2HMC sampler. - - Args: - x_dim: dimensionality of observed data - loglikelihood_fn: log-likelihood function of conditional probability - n_steps: number of leapfrog steps within each transition - eps: initial value learnable scale of step size - """ - - def __init__(self, x_dim, loglikelihood_fn, n_steps=25, eps=.1): + """Dynamics engine of naive L2HMC sampler.""" + + def __init__(self, + x_dim, + minus_loglikelihood_fn, + n_steps=25, + eps=.1, + np_seed=1): + """Initialization. + + Args: + x_dim: dimensionality of observed data + minus_loglikelihood_fn: log-likelihood function of conditional probability + n_steps: number of leapfrog steps within each transition + eps: initial value learnable scale of step size + np_seed: Random seed for numpy; used to control sampled masks. + """ super(Dynamics, self).__init__() + npr.seed(np_seed) self.x_dim = x_dim - self.potential = loglikelihood_fn + self.potential = minus_loglikelihood_fn self.n_steps = n_steps self._construct_time() @@ -54,7 +62,7 @@ class Dynamics(tf.keras.Model): self.position_fn = neural_nets.GenericNet(x_dim, factor=2.) self.momentum_fn = neural_nets.GenericNet(x_dim, factor=1.) - self.eps = tfe.Variable( + self.eps = tf.Variable( initial_value=eps, name="eps", dtype=tf.float32, trainable=True) def apply_transition(self, position): @@ -68,8 +76,8 @@ class Dynamics(tf.keras.Model): position, forward=False) # Decide direction uniformly - forward_mask = tf.cast( - tf.random_uniform(shape=[tf.shape(position)[0]]) > .5, tf.float32) + batch_size = tf.shape(position)[0] + forward_mask = tf.cast(tf.random_uniform((batch_size,)) > .5, tf.float32) backward_mask = 1. - forward_mask # Obtain proposed states @@ -108,7 +116,6 @@ class Dynamics(tf.keras.Model): position_post, momentum_post, logdet = lf_fn(position_post, momentum_post, i) sumlogdet += logdet - accept_prob = self._compute_accept_prob(position, momentum, position_post, momentum_post, sumlogdet) @@ -125,17 +132,17 @@ class Dynamics(tf.keras.Model): sumlogdet += logdet position, logdet = self._update_position_forward(position, momentum, t, - mask) + mask, mask_inv) sumlogdet += logdet position, logdet = self._update_position_forward(position, momentum, t, - mask_inv) + mask_inv, mask) sumlogdet += logdet momentum, logdet = self._update_momentum_forward(position, momentum, t) sumlogdet += logdet - return position, momentum, tf.reduce_sum(sumlogdet, axis=1) + return position, momentum, sumlogdet def _backward_lf(self, position, momentum, i): """One backward augmented leapfrog step. See Appendix A in paper.""" @@ -149,17 +156,17 @@ class Dynamics(tf.keras.Model): sumlogdet += logdet position, logdet = self._update_position_backward(position, momentum, t, - mask) + mask_inv, mask) sumlogdet += logdet position, logdet = self._update_position_backward(position, momentum, t, - mask_inv) + mask, mask_inv) sumlogdet += logdet momentum, logdet = self._update_momentum_backward(position, momentum, t) sumlogdet += logdet - return position, momentum, tf.reduce_sum(sumlogdet, axis=1) + return position, momentum, sumlogdet def _update_momentum_forward(self, position, momentum, t): """Update v in the forward leapfrog step.""" @@ -172,12 +179,11 @@ class Dynamics(tf.keras.Model): momentum * tf.exp(scale) - .5 * self.eps * (tf.exp(transformed) * grad - translation)) - return momentum, scale + return momentum, tf.reduce_sum(scale, axis=1) - def _update_position_forward(self, position, momentum, t, mask): + def _update_position_forward(self, position, momentum, t, mask, mask_inv): """Update x in the forward leapfrog step.""" - mask_inv = 1. - mask scale, translation, transformed = self.position_fn( [momentum, mask * position, t]) scale *= self.eps @@ -186,8 +192,7 @@ class Dynamics(tf.keras.Model): mask * position + mask_inv * (position * tf.exp(scale) + self.eps * (tf.exp(transformed) * momentum + translation))) - - return position, mask_inv * scale + return position, tf.reduce_sum(mask_inv * scale, axis=1) def _update_momentum_backward(self, position, momentum, t): """Update v in the backward leapfrog step. Inverting the forward update.""" @@ -200,21 +205,20 @@ class Dynamics(tf.keras.Model): tf.exp(scale) * (momentum + .5 * self.eps * (tf.exp(transformed) * grad - translation))) - return momentum, scale + return momentum, tf.reduce_sum(scale, axis=1) - def _update_position_backward(self, position, momentum, t, mask): + def _update_position_backward(self, position, momentum, t, mask, mask_inv): """Update x in the backward leapfrog step. Inverting the forward update.""" - mask_inv = 1. - mask scale, translation, transformed = self.position_fn( - [momentum, mask_inv * position, t]) + [momentum, mask * position, t]) scale *= -self.eps transformed *= self.eps position = ( - mask_inv * position + mask * tf.exp(scale) * - (position - self.eps * tf.exp(transformed) * momentum + translation)) + mask * position + mask_inv * tf.exp(scale) * + (position - self.eps * (tf.exp(transformed) * momentum + translation))) - return position, mask * scale + return position, tf.reduce_sum(mask_inv * scale, axis=1) def _compute_accept_prob(self, position, momentum, position_post, momentum_post, sumlogdet): @@ -222,8 +226,10 @@ class Dynamics(tf.keras.Model): old_hamil = self.hamiltonian(position, momentum) new_hamil = self.hamiltonian(position_post, momentum_post) + prob = tf.exp(tf.minimum(old_hamil - new_hamil + sumlogdet, 0.)) - return tf.exp(tf.minimum(old_hamil - new_hamil + sumlogdet, 0.)) + # Ensure numerical stability as well as correct gradients + return tf.where(tf.is_finite(prob), prob, tf.zeros_like(prob)) def _construct_time(self): """Convert leapfrog step index into sinusoidal time.""" @@ -248,6 +254,8 @@ class Dynamics(tf.keras.Model): self.masks = [] for _ in range(self.n_steps): + # Need to use npr here because tf would generated different random + # values across different `sess.run` idx = npr.permutation(np.arange(self.x_dim))[:self.x_dim // 2] mask = np.zeros((self.x_dim,)) mask[idx] = 1. @@ -273,19 +281,15 @@ class Dynamics(tf.keras.Model): def grad_potential(self, position, check_numerics=True): """Get gradient of potential function at current location.""" - if not tf.executing_eagerly(): - # TODO(lxuechen): Change this to tfe.gradients_function when it works - grad = tf.gradients(self.potential(position), position)[0] - else: + if tf.executing_eagerly(): grad = tfe.gradients_function(self.potential)(position)[0] - - if check_numerics: - return tf.check_numerics(grad, message="gradient of potential") + else: + grad = tf.gradients(self.potential(position), position)[0] return grad -# Examples of unnormalized log density/probabilities +# Examples of unnormalized log densities def get_scg_energy_fn(): """Get energy function for 2d strongly correlated Gaussian.""" @@ -295,32 +299,53 @@ def get_scg_energy_fn(): sigma_inv = tf.matrix_inverse(sigma) def energy(x): - """Unnormalized log density/energy of 2d strongly correlated Gaussian.""" + """Unnormalized minus log density of 2d strongly correlated Gaussian.""" xmmu = x - mu return .5 * tf.diag_part( tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu))) - return energy + return energy, mu, sigma -def get_multivariate_gaussian_energy_fn(x_dim=2): - """Get energy function for 2d strongly correlated Gaussian.""" - - mu = tf.random_normal(shape=[x_dim]) - # Lower triangularize and positive diagonal - l = tf.sigmoid( - tf.matrix_band_part(tf.random_normal(shape=[x_dim, x_dim]), -1, 0)) - # Exploit Cholesky decomposition - sigma = tf.matmul(l, tf.transpose(l)) - sigma *= 100. # Small covariance causes extreme numerical instability - sigma_inv = tf.matrix_inverse(sigma) +def get_rw_energy_fn(): + """Get energy function for rough well distribution.""" + # For small eta, the density underlying the rough-well energy is very close to + # a unit Gaussian; however, the gradient is greatly affected by the small + # cosine perturbations + eta = 1e-2 + mu = tf.constant([0., 0.]) + sigma = tf.constant([[1., 0.], [0., 1.]]) def energy(x): - """Unnormalized log density/energy of 2d strongly correlated Gaussian.""" + ip = tf.reduce_sum(x**2., axis=1) + return .5 * ip + eta * tf.reduce_sum(tf.cos(x / eta), axis=1) - xmmu = x - mu - return .5 * tf.diag_part( - tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu))) + return energy, mu, sigma + + +# Loss function +def compute_loss(dynamics, x, scale=.1, eps=1e-4): + """Compute loss defined in equation (8).""" + + z = tf.random_normal(tf.shape(x)) # Auxiliary variable + x_, _, x_accept_prob, x_out = dynamics.apply_transition(x) + z_, _, z_accept_prob, _ = dynamics.apply_transition(z) + + # Add eps for numerical stability; following released impl + x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps + z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps + + loss = tf.reduce_mean( + (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0) + + return loss, x_out, x_accept_prob + + +def loss_and_grads(dynamics, x, loss_fn=compute_loss): + """Obtain loss value and gradients.""" + with tf.GradientTape() as tape: + loss_val, out, accept_prob = loss_fn(dynamics, x) + grads = tape.gradient(loss_val, dynamics.trainable_variables) - return energy + return loss_val, grads, out, accept_prob diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py index e33b4cae4c73388dfd78542c9907953f137ad710..955747988536bd21d52df66a35af4aa31b3f7688 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py +++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py @@ -37,63 +37,37 @@ def get_default_hparams(): n_warmup_iters=3) -# Relevant functions for benchmarking -def compute_loss(dynamics, x, scale=.1, eps=1e-4): - """Compute loss defined in equation (8).""" - - z = tf.random_normal(tf.shape(x)) - x_, _, x_accept_prob, x_out = dynamics.apply_transition(x) - z_, _, z_accept_prob, _ = dynamics.apply_transition(z) - - # Add eps for numerical stability; following released impl - x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps - z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps - - loss = tf.reduce_mean( - (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0) - - return loss, x_out - - -def loss_and_grads(dynamics, x, loss_fn=compute_loss): - """Obtain loss value and gradients.""" - - with tf.GradientTape() as tape: - loss_val, x_out = loss_fn(dynamics, x) - grads = tape.gradient(loss_val, dynamics.variables) - - return loss_val, grads, x_out - - -def warmup(dynamics, optimizer, n_iters=1, n_samples=200, loss_fn=compute_loss): +def warmup(dynamics, + optimizer, + n_iters=1, + n_samples=200, + loss_fn=l2hmc.compute_loss): """Warmup optimization to reduce overhead.""" samples = tf.random_normal( shape=[n_samples, dynamics.x_dim], dtype=tf.float32) for _ in range(n_iters): - _, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn) + _, grads, samples, _ = l2hmc.loss_and_grads( + dynamics, samples, loss_fn=loss_fn) optimizer.apply_gradients(zip(grads, dynamics.variables)) def fit(dynamics, samples, optimizer, - loss_fn=compute_loss, + loss_fn=l2hmc.compute_loss, n_iters=5000, verbose=True, - logdir=None, - decay_lr=True): + logdir=None): """Fit L2HMC sampler with given log-likelihood function.""" if logdir: summary_writer = tf.contrib.summary.create_file_writer(logdir) for i in range(n_iters): - loss, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn) - # TODO(lxuechen): Proper learning rate decay - if decay_lr: - grads = [grad * .96**(i // 1000) for grad in grads] + loss, grads, samples, _ = l2hmc.loss_and_grads( + dynamics, samples, loss_fn=loss_fn) optimizer.apply_gradients(zip(grads, dynamics.variables)) if verbose: print("Iteration %d: loss %.4f" % (i, loss)) @@ -112,9 +86,10 @@ class L2hmcTest(tf.test.TestCase): # Eager mode testing hparams = get_default_hparams() + energy_fn, _, _ = l2hmc.get_scg_energy_fn() dynamics = l2hmc.Dynamics( x_dim=hparams.x_dim, - loglikelihood_fn=l2hmc.get_scg_energy_fn(), + minus_loglikelihood_fn=energy_fn, n_steps=hparams.n_steps, eps=hparams.eps) samples = tf.random_normal(shape=[hparams.n_samples, hparams.x_dim]) @@ -127,9 +102,10 @@ class L2hmcTest(tf.test.TestCase): # Graph mode testing with tf.Graph().as_default(): + energy_fn, _, _ = l2hmc.get_scg_energy_fn() dynamics = l2hmc.Dynamics( x_dim=hparams.x_dim, - loglikelihood_fn=l2hmc.get_scg_energy_fn(), + minus_loglikelihood_fn=energy_fn, n_steps=hparams.n_steps, eps=hparams.eps) x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim]) @@ -150,32 +126,20 @@ class L2hmcTest(tf.test.TestCase): class L2hmcBenchmark(tf.test.Benchmark): """Eager and graph benchmarks for l2hmc.""" - def _get_energy_fn(self): - """Get specific energy function according to FLAGS.""" - - if FLAGS.energy_fn == "scg": - energy_fn = l2hmc.get_scg_energy_fn() - elif FLAGS.energy_fn == "multivariate_gaussian": - energy_fn = l2hmc.get_multivariate_gaussian_energy_fn(x_dim=FLAGS.x_dim) - else: - raise ValueError("No such energy function %s" % FLAGS.energy_fn) - - return energy_fn - def benchmark_graph(self): """Benchmark Graph performance.""" hparams = get_default_hparams() tf.reset_default_graph() with tf.Graph().as_default(): - energy_fn = self._get_energy_fn() + energy_fn, _, _ = l2hmc.get_scg_energy_fn() dynamics = l2hmc.Dynamics( x_dim=hparams.x_dim, - loglikelihood_fn=energy_fn, + minus_loglikelihood_fn=energy_fn, n_steps=hparams.n_steps, eps=hparams.eps) x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim]) - loss, x_out = compute_loss(dynamics, x) + loss, x_out, _ = l2hmc.compute_loss(dynamics, x) global_step = tf.Variable(0., name="global_step", trainable=False) learning_rate = tf.train.exponential_decay( @@ -183,7 +147,11 @@ class L2hmcBenchmark(tf.test.Benchmark): optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) train_op = optimizer.minimize(loss, global_step=global_step) - with tf.Session() as sess: + # Single thread; fairer comparison against eager + session_conf = tf.ConfigProto( + intra_op_parallelism_threads=1, inter_op_parallelism_threads=1) + + with tf.Session(config=session_conf) as sess: sess.run(tf.global_variables_initializer()) # Warmup to reduce initialization effect when timing @@ -218,14 +186,14 @@ class L2hmcBenchmark(tf.test.Benchmark): """Benchmark Eager performance.""" hparams = get_default_hparams() - energy_fn = self._get_energy_fn() + energy_fn, _, _ = l2hmc.get_scg_energy_fn() dynamics = l2hmc.Dynamics( x_dim=hparams.x_dim, - loglikelihood_fn=energy_fn, + minus_loglikelihood_fn=energy_fn, n_steps=hparams.n_steps, eps=hparams.eps) optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate) - loss_fn = tfe.defun(compute_loss) if defun else compute_loss + loss_fn = tfe.defun(l2hmc.compute_loss) if defun else l2hmc.compute_loss # Warmup to reduce initialization effect when timing warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, loss_fn=loss_fn) @@ -234,12 +202,7 @@ class L2hmcBenchmark(tf.test.Benchmark): samples = tf.random_normal( shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32) start_time = time.time() - fit(dynamics, - samples, - optimizer, - loss_fn=loss_fn, - n_iters=hparams.n_iters, - decay_lr=True) + fit(dynamics, samples, optimizer, loss_fn=loss_fn, n_iters=hparams.n_iters) wall_time = time.time() - start_time examples_per_sec = hparams.n_samples / wall_time @@ -251,14 +214,8 @@ class L2hmcBenchmark(tf.test.Benchmark): wall_time=wall_time) del dynamics - del loss_fn if __name__ == "__main__": - tf.flags.DEFINE_string("energy_fn", "scg", - ("The energy function/unnormalized log-probability. " - "Either be `scg` or `multivariate_gaussian`")) - tf.flags.DEFINE_integer("x_dim", 2, "Dimensionality of observation space.") - FLAGS = tf.flags.FLAGS tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/main.py b/tensorflow/contrib/eager/python/examples/l2hmc/main.py new file mode 100644 index 0000000000000000000000000000000000000000..45e1f98429f48749d374c2aefd8874690c3830ad --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/main.py @@ -0,0 +1,235 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""L2HMC on simple Gaussian mixture model with TensorFlow eager.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +from absl import flags +import numpy as np +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.l2hmc import l2hmc +try: + import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top + HAS_MATPLOTLIB = True +except ImportError: + HAS_MATPLOTLIB = False +tfe = tf.contrib.eager + + +def main(_): + tf.enable_eager_execution() + global_step = tf.train.get_or_create_global_step() + global_step.assign(1) + + energy_fn, mean, covar = { + "scg": l2hmc.get_scg_energy_fn(), + "rw": l2hmc.get_rw_energy_fn() + }[FLAGS.energy_fn] + + x_dim = 2 + train_iters = 5000 + eval_iters = 2000 + eps = 0.1 + n_steps = 10 # Chain length + n_samples = 200 + record_loss_every = 100 + + dynamics = l2hmc.Dynamics( + x_dim=x_dim, minus_loglikelihood_fn=energy_fn, n_steps=n_steps, eps=eps) + learning_rate = tf.train.exponential_decay( + 1e-3, global_step, 1000, 0.96, staircase=True) + optimizer = tf.train.AdamOptimizer(learning_rate) + checkpointer = tf.train.Checkpoint( + optimizer=optimizer, dynamics=dynamics, global_step=global_step) + + if FLAGS.train_dir: + summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir) + if FLAGS.restore: + latest_path = tf.train.latest_checkpoint(FLAGS.train_dir) + checkpointer.restore(latest_path) + print("Restored latest checkpoint at path:\"{}\" ".format(latest_path)) + sys.stdout.flush() + + if not FLAGS.restore: + # Training + if FLAGS.use_defun: + # Use `tfe.deun` to boost performance when there are lots of small ops + loss_fn = tfe.defun(l2hmc.compute_loss) + else: + loss_fn = l2hmc.compute_loss + + samples = tf.random_normal(shape=[n_samples, x_dim]) + for i in range(1, train_iters + 1): + loss, samples, accept_prob = train_one_iter( + dynamics, + samples, + optimizer, + loss_fn=loss_fn, + global_step=global_step) + + if i % record_loss_every == 0: + print("Iteration {}, loss {:.4f}, x_accept_prob {:.4f}".format( + i, loss.numpy(), + accept_prob.numpy().mean())) + if FLAGS.train_dir: + with summary_writer.as_default(): + with tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar("Training loss", loss, step=global_step) + print("Training complete.") + sys.stdout.flush() + + if FLAGS.train_dir: + saved_path = checkpointer.save( + file_prefix=os.path.join(FLAGS.train_dir, "ckpt")) + print("Saved checkpoint at path: \"{}\" ".format(saved_path)) + sys.stdout.flush() + + # Evaluation + if FLAGS.use_defun: + # Use tfe.deun to boost performance when there are lots of small ops + apply_transition = tfe.defun(dynamics.apply_transition) + else: + apply_transition = dynamics.apply_transition + + samples = tf.random_normal(shape=[n_samples, x_dim]) + samples_history = [] + for i in range(eval_iters): + samples_history.append(samples.numpy()) + _, _, _, samples = apply_transition(samples) + samples_history = np.array(samples_history) + print("Sampling complete.") + sys.stdout.flush() + + # Mean and covariance of target distribution + mean = mean.numpy() + covar = covar.numpy() + ac_spectrum = compute_ac_spectrum(samples_history, mean, covar) + print("First 25 entries of the auto-correlation spectrum: {}".format( + ac_spectrum[:25])) + ess = compute_ess(ac_spectrum) + print("Effective sample size per Metropolis-Hastings step: {}".format(ess)) + sys.stdout.flush() + + if FLAGS.train_dir: + # Plot autocorrelation spectrum in tensorboard + plot_step = tfe.Variable(1, trainable=False, dtype=tf.int64) + + for ac in ac_spectrum: + with summary_writer.as_default(): + with tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar("Autocorrelation", ac, step=plot_step) + plot_step.assign(plot_step + n_steps) + + if HAS_MATPLOTLIB: + # Choose a single chain and plot the trajectory + single_chain = samples_history[:, 0, :] + xs = single_chain[:100, 0] + ys = single_chain[:100, 1] + plt.figure() + plt.plot(xs, ys, color="orange", marker="o", alpha=0.6) # Trained chain + plt.savefig(os.path.join(FLAGS.train_dir, "single_chain.png")) + + +def train_one_iter(dynamics, + x, + optimizer, + loss_fn=l2hmc.compute_loss, + global_step=None): + """Train the sampler for one iteration.""" + loss, grads, out, accept_prob = l2hmc.loss_and_grads( + dynamics, x, loss_fn=loss_fn) + optimizer.apply_gradients( + zip(grads, dynamics.trainable_variables), global_step=global_step) + + return loss, out, accept_prob + + +def compute_ac_spectrum(samples_history, target_mean, target_covar): + """Compute autocorrelation spectrum. + + Follows equation 15 from the L2HMC paper. + + Args: + samples_history: Numpy array of shape [T, B, D], where T is the total + number of time steps, B is the batch size, and D is the dimensionality + of sample space. + target_mean: 1D Numpy array of the mean of target(true) distribution. + target_covar: 2D Numpy array representing a symmetric matrix for variance. + Returns: + Autocorrelation spectrum, Numpy array of shape [T-1]. + """ + + # Using numpy here since eager is a bit slow due to the loop + time_steps = samples_history.shape[0] + trace = np.trace(target_covar) + + rhos = [] + for t in range(time_steps - 1): + rho_t = 0. + for tau in range(time_steps - t): + v_tau = samples_history[tau, :, :] - target_mean + v_tau_plus_t = samples_history[tau + t, :, :] - target_mean + # Take dot product over observation dims and take mean over batch dims + rho_t += np.mean(np.sum(v_tau * v_tau_plus_t, axis=1)) + + rho_t /= trace * (time_steps - t) + rhos.append(rho_t) + + return np.array(rhos) + + +def compute_ess(ac_spectrum): + """Compute the effective sample size based on autocorrelation spectrum. + + This follows equation 16 from the L2HMC paper. + + Args: + ac_spectrum: Autocorrelation spectrum + Returns: + The effective sample size + """ + # Cutoff from the first value less than 0.05 + cutoff = np.argmax(ac_spectrum[1:] < .05) + if cutoff == 0: + cutoff = len(ac_spectrum) + ess = 1. / (1. + 2. * np.sum(ac_spectrum[1:cutoff])) + return ess + + +if __name__ == "__main__": + flags.DEFINE_string( + "train_dir", + default=None, + help="[Optional] Directory to store the training information") + flags.DEFINE_boolean( + "restore", + default=False, + help="[Optional] Restore the latest checkpoint from `train_dir` if True") + flags.DEFINE_boolean( + "use_defun", + default=False, + help="[Optional] Use `tfe.defun` to boost performance") + flags.DEFINE_string( + "energy_fn", + default="scg", + help="[Optional] The energy function used for experimentation" + "Other options include `rw`") + FLAGS = flags.FLAGS + tf.app.run(main) diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py index e230ad5e259df5b450897bd815e901e3934cd293..68e0bc31239007e3b1b8451cf1d6e7592c6ca030 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py +++ b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py @@ -25,7 +25,6 @@ from __future__ import division from __future__ import print_function import tensorflow as tf -import tensorflow.contrib.eager as tfe class GenericNet(tf.keras.Model): @@ -47,13 +46,13 @@ class GenericNet(tf.keras.Model): # Scale self.scale_layer = _custom_dense(x_dim, .001) - self.coeff_scale = tfe.Variable( + self.coeff_scale = tf.Variable( initial_value=tf.zeros([1, x_dim]), name='coeff_scale', trainable=True) # Translation self.translation_layer = _custom_dense(x_dim, factor=.001) # Transformation self.transformation_layer = _custom_dense(x_dim, .001) - self.coeff_transformation = tfe.Variable( + self.coeff_transformation = tf.Variable( initial_value=tf.zeros([1, x_dim]), name='coeff_transformation', trainable=True) diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..08d8364978f6a9b4e8e15b5caac7db14c1d721b4 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -0,0 +1,854 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "AOpGoE2T-YXS" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n", + "\n", + "# Neural Machine Translation with Attention\n", + "\n", + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "CiwtNgENbx2g" + }, + "source": [ + "This notebook trains a sequence to sequence (seq2seq) model for Spanish to English translation using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). This is an advanced example that assumes some knowledge of sequence to sequence models.\n", + "\n", + "After training the model in this notebook, you will be able to input a Spanish sentence, such as *\"¿todavia estan en casa?\"*, and return the English translation: *\"are you still at home?\"*\n", + "\n", + "The translation quality is reasonable for a toy example, but the generated attention plot is perhaps more interesting. This shows which parts of the input sentence has the model's attention while translating:\n", + "\n", + "\u003cimg src=\"https://tensorflow.org/images/spanish-english.png\" alt=\"spanish-english attention plot\"\u003e\n", + "\n", + "Note: This example takes approximately 10 mintues to run on a single P100 GPU." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "tnxXKDjq3jEL" + }, + "outputs": [], + "source": [ + "from __future__ import absolute_import, division, print_function\n", + "\n", + "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", + "import tensorflow as tf\n", + "\n", + "tf.enable_eager_execution()\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "import unicodedata\n", + "import re\n", + "import numpy as np\n", + "import os\n", + "import time\n", + "\n", + "print(tf.__version__)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "wfodePkj3jEa" + }, + "source": [ + "## Download and prepare the dataset\n", + "\n", + "We'll use a language dataset provided by http://www.manythings.org/anki/. This dataset contains language translation pairs in the format:\n", + "\n", + "```\n", + "May I borrow this book?\t¿Puedo tomar prestado este libro?\n", + "```\n", + "\n", + "There are a variety of languages available, but we'll use the English-Spanish dataset. For convenience, we've hosted a copy of this dataset on Google Cloud, but you can also download your own copy. After downloading the dataset, here are the steps we'll take to prepare the data:\n", + "\n", + "1. Add a *start* and *end* token to each sentence.\n", + "2. Clean the sentences by removing special characters.\n", + "3. Create a word index and reverse word index (dictionaries mapping from word → id and id → word).\n", + "4. Pad each sentence to a maximum length." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "kRVATYOgJs1b" + }, + "outputs": [], + "source": [ + "# Download the file\n", + "path_to_zip = tf.keras.utils.get_file(\n", + " 'spa-eng.zip', origin='http://download.tensorflow.org/data/spa-eng.zip', \n", + " extract=True)\n", + "\n", + "path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\"" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "rd0jw-eC3jEh" + }, + "outputs": [], + "source": [ + "# Converts the unicode file to ascii\n", + "def unicode_to_ascii(s):\n", + " return ''.join(c for c in unicodedata.normalize('NFD', s)\n", + " if unicodedata.category(c) != 'Mn')\n", + "\n", + "\n", + "def preprocess_sentence(w):\n", + " w = unicode_to_ascii(w.lower().strip())\n", + " \n", + " # creating a space between a word and the punctuation following it\n", + " # eg: \"he is a boy.\" =\u003e \"he is a boy .\" \n", + " # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n", + " w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n", + " w = re.sub(r'[\" \"]+', \" \", w)\n", + " \n", + " # replacing everything with space except (a-z, A-Z, \".\", \"?\", \"!\", \",\")\n", + " w = re.sub(r\"[^a-zA-Z?.!,¿]+\", \" \", w)\n", + " \n", + " w = w.rstrip().strip()\n", + " \n", + " # adding a start and an end token to the sentence\n", + " # so that the model know when to start and stop predicting.\n", + " w = '\u003cstart\u003e ' + w + ' \u003cend\u003e'\n", + " return w" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "OHn4Dct23jEm" + }, + "outputs": [], + "source": [ + "# 1. Remove the accents\n", + "# 2. Clean the sentences\n", + "# 3. Return word pairs in the format: [ENGLISH, SPANISH]\n", + "def create_dataset(path, num_examples):\n", + " lines = open(path, encoding='UTF-8').read().strip().split('\\n')\n", + " \n", + " word_pairs = [[preprocess_sentence(w) for w in l.split('\\t')] for l in lines[:num_examples]]\n", + " \n", + " return word_pairs" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "9xbqO7Iie9bb" + }, + "outputs": [], + "source": [ + "# This class creates a word -\u003e index mapping (e.g,. \"dad\" -\u003e 5) and vice-versa \n", + "# (e.g., 5 -\u003e \"dad\") for each language,\n", + "class LanguageIndex():\n", + " def __init__(self, lang):\n", + " self.lang = lang\n", + " self.word2idx = {}\n", + " self.idx2word = {}\n", + " self.vocab = set()\n", + " \n", + " self.create_index()\n", + " \n", + " def create_index(self):\n", + " for phrase in self.lang:\n", + " self.vocab.update(phrase.split(' '))\n", + " \n", + " self.vocab = sorted(self.vocab)\n", + " \n", + " self.word2idx['\u003cpad\u003e'] = 0\n", + " for index, word in enumerate(self.vocab):\n", + " self.word2idx[word] = index + 1\n", + " \n", + " for word, index in self.word2idx.items():\n", + " self.idx2word[index] = word" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "eAY9k49G3jE_" + }, + "outputs": [], + "source": [ + "def max_length(tensor):\n", + " return max(len(t) for t in tensor)\n", + "\n", + "\n", + "def load_dataset(path, num_examples):\n", + " # creating cleaned input, output pairs\n", + " pairs = create_dataset(path, num_examples)\n", + "\n", + " # index language using the class defined above \n", + " inp_lang = LanguageIndex(sp for en, sp in pairs)\n", + " targ_lang = LanguageIndex(en for en, sp in pairs)\n", + " \n", + " # Vectorize the input and target languages\n", + " \n", + " # Spanish sentences\n", + " input_tensor = [[inp_lang.word2idx[s] for s in sp.split(' ')] for en, sp in pairs]\n", + " \n", + " # English sentences\n", + " target_tensor = [[targ_lang.word2idx[s] for s in en.split(' ')] for en, sp in pairs]\n", + " \n", + " # Calculate max_length of input and output tensor\n", + " # Here, we'll set those to the longest sentence in the dataset\n", + " max_length_inp, max_length_tar = max_length(input_tensor), max_length(target_tensor)\n", + " \n", + " # Padding the input and output tensor to the maximum length\n", + " input_tensor = tf.keras.preprocessing.sequence.pad_sequences(input_tensor, \n", + " maxlen=max_length_inp,\n", + " padding='post')\n", + " \n", + " target_tensor = tf.keras.preprocessing.sequence.pad_sequences(target_tensor, \n", + " maxlen=max_length_tar, \n", + " padding='post')\n", + " \n", + " return input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_tar" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "GOi42V79Ydlr" + }, + "source": [ + "### Limit the size of the dataset to experiment faster (optional)\n", + "\n", + "Training on the complete dataset of \u003e100,000 sentences will take a long time. To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "cnxC7q-j3jFD" + }, + "outputs": [], + "source": [ + "# Try experimenting with the size of that dataset\n", + "num_examples = 30000\n", + "input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_targ = load_dataset(path_to_file, num_examples)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "4QILQkOs3jFG" + }, + "outputs": [], + "source": [ + "# Creating training and validation sets using an 80-20 split\n", + "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)\n", + "\n", + "# Show length\n", + "len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "rgCLkfv5uO3d" + }, + "source": [ + "### Create a tf.data dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "TqHsArVZ3jFS" + }, + "outputs": [], + "source": [ + "BUFFER_SIZE = len(input_tensor_train)\n", + "BATCH_SIZE = 64\n", + "N_BATCH = BUFFER_SIZE//BATCH_SIZE\n", + "embedding_dim = 256\n", + "units = 1024\n", + "vocab_inp_size = len(inp_lang.word2idx)\n", + "vocab_tar_size = len(targ_lang.word2idx)\n", + "\n", + "dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n", + "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "TNfHIF71ulLu" + }, + "source": [ + "## Write the encoder and decoder model\n", + "\n", + "Here, we'll implement an encoder-decoder model with attention which you can read about in the TensorFlow [Neural Machine Translation (seq2seq) tutorial](https://www.tensorflow.org/tutorials/seq2seq). This example uses a more recent set of APIs. This notebook implements the [attention equations](https://www.tensorflow.org/tutorials/seq2seq#background_on_the_attention_mechanism) from the seq2seq tutorial. The following diagram shows that each input words is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence.\n", + "\n", + "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_mechanism.jpg\" width=\"500\" alt=\"attention mechanism\"\u003e\n", + "\n", + "The input is put through an encoder model which gives us the encoder output of shape *(batch_size, max_length, hidden_size)* and the encoder hidden state of shape *(batch_size, hidden_size)*. \n", + "\n", + "Here are the equations that are implemented:\n", + "\n", + "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_0.jpg\" alt=\"attention equation 0\" width=\"800\"\u003e\n", + "\u003cimg src=\"https://www.tensorflow.org/images/seq2seq/attention_equation_1.jpg\" alt=\"attention equation 1\" width=\"800\"\u003e\n", + "\n", + "We're using *Bahdanau attention*. Lets decide on notation before writing the simplified form:\n", + "\n", + "* FC = Fully connected (dense) layer\n", + "* EO = Encoder output\n", + "* H = hidden state\n", + "* X = input to the decoder\n", + "\n", + "And the pseudo-code:\n", + "\n", + "* `score = FC(tanh(FC(EO) + FC(H)))`\n", + "* `attention weights = softmax(score, axis = 1)`. Softmax by default is applied on the last axis but here we want to apply it on the *1st axis*, since the shape of score is *(batch_size, max_length, hidden_size)*. `Max_length` is the length of our input. Since we are trying to assign a weight to each input, softmax should be applied on that axis.\n", + "* `context vector = sum(attention weights * EO, axis = 1)`. Same reason as above for choosing axis as 1.\n", + "* `embedding output` = The input to the decoder X is passed through an embedding layer.\n", + "* `merged vector = concat(embedding output, context vector)`\n", + "* This merged vector is then given to the GRU\n", + " \n", + "The shapes of all the vectors at each step have been specified in the comments in the code:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "avyJ_4VIUoHb" + }, + "outputs": [], + "source": [ + "def gru(units):\n", + " # If you have a GPU, we recommend using CuDNNGRU(provides a 3x speedup than GRU)\n", + " # the code automatically does that.\n", + " if tf.test.is_gpu_available():\n", + " return tf.keras.layers.CuDNNGRU(units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_initializer='glorot_uniform')\n", + " else:\n", + " return tf.keras.layers.GRU(units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_activation='sigmoid', \n", + " recurrent_initializer='glorot_uniform')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "nZ2rI24i3jFg" + }, + "outputs": [], + "source": [ + "class Encoder(tf.keras.Model):\n", + " def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n", + " super(Encoder, self).__init__()\n", + " self.batch_sz = batch_sz\n", + " self.enc_units = enc_units\n", + " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", + " self.gru = gru(self.enc_units)\n", + " \n", + " def call(self, x, hidden):\n", + " x = self.embedding(x)\n", + " output, state = self.gru(x, initial_state = hidden) \n", + " return output, state\n", + " \n", + " def initialize_hidden_state(self):\n", + " return tf.zeros((self.batch_sz, self.enc_units))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "yJ_B3mhW3jFk" + }, + "outputs": [], + "source": [ + "class Decoder(tf.keras.Model):\n", + " def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n", + " super(Decoder, self).__init__()\n", + " self.batch_sz = batch_sz\n", + " self.dec_units = dec_units\n", + " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", + " self.gru = gru(self.dec_units)\n", + " self.fc = tf.keras.layers.Dense(vocab_size)\n", + " \n", + " # used for attention\n", + " self.W1 = tf.keras.layers.Dense(self.dec_units)\n", + " self.W2 = tf.keras.layers.Dense(self.dec_units)\n", + " self.V = tf.keras.layers.Dense(1)\n", + " \n", + " def call(self, x, hidden, enc_output):\n", + " # enc_output shape == (batch_size, max_length, hidden_size)\n", + " \n", + " # hidden shape == (batch_size, hidden size)\n", + " # hidden_with_time_axis shape == (batch_size, 1, hidden size)\n", + " # we are doing this to perform addition to calculate the score\n", + " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", + " \n", + " # score shape == (batch_size, max_length, hidden_size)\n", + " score = tf.nn.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis))\n", + " \n", + " # attention_weights shape == (batch_size, max_length, 1)\n", + " # we get 1 at the last axis because we are applying score to self.V\n", + " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", + " \n", + " # context_vector shape after sum == (batch_size, hidden_size)\n", + " context_vector = attention_weights * enc_output\n", + " context_vector = tf.reduce_sum(context_vector, axis=1)\n", + " \n", + " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n", + " x = self.embedding(x)\n", + " \n", + " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n", + " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", + " \n", + " # passing the concatenated vector to the GRU\n", + " output, state = self.gru(x)\n", + " \n", + " # output shape == (batch_size * max_length, hidden_size)\n", + " output = tf.reshape(output, (-1, output.shape[2]))\n", + " \n", + " # output shape == (batch_size * max_length, vocab)\n", + " x = self.fc(output)\n", + " \n", + " return x, state, attention_weights\n", + " \n", + " def initialize_hidden_state(self):\n", + " return tf.zeros((self.batch_sz, self.dec_units))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "P5UY8wko3jFp" + }, + "outputs": [], + "source": [ + "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n", + "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "_ch_71VbIRfK" + }, + "source": [ + "## Define the optimizer and the loss function" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "WmTHr5iV3jFr" + }, + "outputs": [], + "source": [ + "optimizer = tf.train.AdamOptimizer()\n", + "\n", + "\n", + "def loss_function(real, pred):\n", + " mask = 1 - np.equal(real, 0)\n", + " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", + " return tf.reduce_mean(loss_)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "DMVWzzsfNl4e" + }, + "source": [ + "## Checkpoints (Object-based saving)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Zj8bXQTgNwrF" + }, + "outputs": [], + "source": [ + "checkpoint_dir = './training_checkpoints'\n", + "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", + "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n", + " encoder=encoder,\n", + " decoder=decoder)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "hpObfY22IddU" + }, + "source": [ + "## Training\n", + "\n", + "1. Pass the *input* through the *encoder* which return *encoder output* and the *encoder hidden state*.\n", + "2. The encoder output, encoder hidden state and the decoder input (which is the *start token*) is passed to the decoder.\n", + "3. The decoder returns the *predictions* and the *decoder hidden state*.\n", + "4. The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n", + "5. Use *teacher forcing* to decide the next input to the decoder.\n", + "6. *Teacher forcing* is the technique where the *target word* is passed as the *next input* to the decoder.\n", + "7. The final step is to calculate the gradients and apply it to the optimizer and backpropagate." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "ddefjBMa3jF0" + }, + "outputs": [], + "source": [ + "EPOCHS = 10\n", + "\n", + "for epoch in range(EPOCHS):\n", + " start = time.time()\n", + " \n", + " hidden = encoder.initialize_hidden_state()\n", + " total_loss = 0\n", + " \n", + " for (batch, (inp, targ)) in enumerate(dataset):\n", + " loss = 0\n", + " \n", + " with tf.GradientTape() as tape:\n", + " enc_output, enc_hidden = encoder(inp, hidden)\n", + " \n", + " dec_hidden = enc_hidden\n", + " \n", + " dec_input = tf.expand_dims([targ_lang.word2idx['\u003cstart\u003e']] * BATCH_SIZE, 1) \n", + " \n", + " # Teacher forcing - feeding the target as the next input\n", + " for t in range(1, targ.shape[1]):\n", + " # passing enc_output to the decoder\n", + " predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)\n", + " \n", + " loss += loss_function(targ[:, t], predictions)\n", + " \n", + " # using teacher forcing\n", + " dec_input = tf.expand_dims(targ[:, t], 1)\n", + " \n", + " batch_loss = (loss / int(targ.shape[1]))\n", + " \n", + " total_loss += batch_loss\n", + " \n", + " variables = encoder.variables + decoder.variables\n", + " \n", + " gradients = tape.gradient(loss, variables)\n", + " \n", + " optimizer.apply_gradients(zip(gradients, variables))\n", + " \n", + " if batch % 100 == 0:\n", + " print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n", + " batch,\n", + " batch_loss.numpy()))\n", + " # saving (checkpoint) the model every 2 epochs\n", + " if (epoch + 1) % 2 == 0:\n", + " checkpoint.save(file_prefix = checkpoint_prefix)\n", + " \n", + " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n", + " total_loss / N_BATCH))\n", + " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "mU3Ce8M6I3rz" + }, + "source": [ + "## Translate\n", + "\n", + "* The evaluate function is similar to the training loop, except we don't use *teacher forcing* here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n", + "* Stop predicting when the model predicts the *end token*.\n", + "* And store the *attention weights for every time step*.\n", + "\n", + "Note: The encoder output is calculated only once for one input." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "EbQpyYs13jF_" + }, + "outputs": [], + "source": [ + "def evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n", + " attention_plot = np.zeros((max_length_targ, max_length_inp))\n", + " \n", + " sentence = preprocess_sentence(sentence)\n", + "\n", + " inputs = [inp_lang.word2idx[i] for i in sentence.split(' ')]\n", + " inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs], maxlen=max_length_inp, padding='post')\n", + " inputs = tf.convert_to_tensor(inputs)\n", + " \n", + " result = ''\n", + "\n", + " hidden = [tf.zeros((1, units))]\n", + " enc_out, enc_hidden = encoder(inputs, hidden)\n", + "\n", + " dec_hidden = enc_hidden\n", + " dec_input = tf.expand_dims([targ_lang.word2idx['\u003cstart\u003e']], 0)\n", + "\n", + " for t in range(max_length_targ):\n", + " predictions, dec_hidden, attention_weights = decoder(dec_input, dec_hidden, enc_out)\n", + " \n", + " # storing the attention weigths to plot later on\n", + " attention_weights = tf.reshape(attention_weights, (-1, ))\n", + " attention_plot[t] = attention_weights.numpy()\n", + "\n", + " predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n", + "\n", + " result += targ_lang.idx2word[predicted_id] + ' '\n", + "\n", + " if targ_lang.idx2word[predicted_id] == '\u003cend\u003e':\n", + " return result, sentence, attention_plot\n", + " \n", + " # the predicted ID is fed back into the model\n", + " dec_input = tf.expand_dims([predicted_id], 0)\n", + "\n", + " return result, sentence, attention_plot" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "s5hQWlbN3jGF" + }, + "outputs": [], + "source": [ + "# function for plotting the attention weights\n", + "def plot_attention(attention, sentence, predicted_sentence):\n", + " fig = plt.figure(figsize=(10,10))\n", + " ax = fig.add_subplot(1, 1, 1)\n", + " ax.matshow(attention, cmap='viridis')\n", + " \n", + " fontdict = {'fontsize': 14}\n", + " \n", + " ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)\n", + " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n", + "\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "sl9zUHzg3jGI" + }, + "outputs": [], + "source": [ + "def translate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n", + " result, sentence, attention_plot = evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)\n", + " \n", + " print('Input: {}'.format(sentence))\n", + " print('Predicted translation: {}'.format(result))\n", + " \n", + " attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]\n", + " plot_attention(attention_plot, sentence.split(' '), result.split(' '))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "n250XbnjOaqP" + }, + "source": [ + "## Restore the latest checkpoint and test" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "UJpT9D5_OgP6" + }, + "outputs": [], + "source": [ + "# restoring the latest checkpoint in checkpoint_dir\n", + "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "WrAM0FDomq3E" + }, + "outputs": [], + "source": [ + "translate('hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "zSx2iM36EZQZ" + }, + "outputs": [], + "source": [ + "translate('esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "A3LLCx3ZE0Ls" + }, + "outputs": [], + "source": [ + "translate('¿todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "DUQVLVqUE1YW" + }, + "outputs": [], + "source": [ + "# wrong translation\n", + "translate('trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "RTe5P5ioMJwN" + }, + "source": [ + "## Next steps\n", + "\n", + "* [Download a different dataset](http://www.manythings.org/anki/) to experiment with translations, for example, English to German, or English to French.\n", + "* Experiment with training on a larger dataset, or using more epochs\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "nmt_with_attention.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1C4fpM7_7IL8ZzF7Gc5abywqQjeQNS2-U", + "timestamp": 1527858391290 + }, + { + "file_id": "1pExo6aUuw0S6MISFWoinfJv0Ftm9V4qv", + "timestamp": 1527776041613 + } + ], + "toc_visible": true, + "version": "0.3.2" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb deleted file mode 100644 index 9c1af9c2084bac7ae6369babeaa13720e6199097..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb +++ /dev/null @@ -1,323 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "vDJ4XzMqodTy" - }, - "source": [ - "# Automatic Differentiation\n", - "\n", - "In the previous tutorial we introduced `Tensor`s and operations on them. In this tutorial we will cover [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation), a key technique for optimizing machine learning models." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "GQJysDM__Qb0" - }, - "source": [ - "## Setup\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "OiMPZStlibBv" - }, - "outputs": [], - "source": [ - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "\n", - "tfe = tf.contrib.eager # Shorthand for some symbols" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "1CLWJl0QliB0" - }, - "source": [ - "## Derivatives of a function\n", - "\n", - "TensorFlow provides APIs for automatic differentiation - computing the derivative of a function. The way that more closely mimics the math is to encapsulate the computation in a Python function, say `f`, and use `tfe.gradients_function` to create a function that computes the derivatives of `f` with respect to its arguments. If you're familiar with [autograd](https://github.com/HIPS/autograd) for differentiating numpy functions, this will be familiar. For example: " - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "9FViq92UX7P8" - }, - "outputs": [], - "source": [ - "from math import pi\n", - "\n", - "def f(x):\n", - " return tf.square(tf.sin(x))\n", - "\n", - "assert f(pi/2).numpy() == 1.0\n", - "\n", - "\n", - "# grad_f will return a list of derivatives of f\n", - "# with respect to its arguments. Since f() has a single argument,\n", - "# grad_f will return a list with a single element.\n", - "grad_f = tfe.gradients_function(f)\n", - "assert tf.abs(grad_f(pi/2)[0]).numpy() \u003c 1e-7" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "v9fPs8RyopCf" - }, - "source": [ - "### Higher-order gradients\n", - "\n", - "The same API can be used to differentiate as many times as you like:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 276 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 730, - "status": "ok", - "timestamp": 1527005655565, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "3D0ZvnGYo0rW", - "outputId": "e23f8cc6-6813-4944-f20f-825b8a03c2ff" - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAEDCAYAAAAhsS8XAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXd0HNX5sJ/ZXrTq3ZLV3IvcDdgGGwOm2WCbHhJa6C2B\nUBISQioBfoQPkjhACA4QCIQSDITQbGMbsHHvVbZ6s7q0vc18f4xmJVltJa0q+5zDOXhn9s7dqzvv\nfe/briBJkkSYMGHChBkxqAa7A2HChAkTJrSEBXuYMGHCjDDCgj1MmDBhRhhhwR4mTJgwI4ywYA8T\nJkyYEUZYsIcJEybMCCNkgl0URVasWMHtt98eqibDhAkTJkwvCJlgf+2118jJyQlVc2HChAkTppeE\nRLBXVlayceNGrrjiilA0FyZMmDBh+kBIBPvjjz/OQw89hCAIoWguTJgwYcL0gT4L9g0bNhAfH8/E\niRMJVycIEyZMmMFH6GutmGeeeYYPP/wQtVqN2+3Gbrdz3nnn8dRTT3X6HUmSwtp9CKittvH8UxsQ\nxZY/4aXXTGfa7PRB7NXAU1dj5y9PrIfmYUgeFcnya2aQmBI5uB0bYE5WNPHS/9uE6JcHYukVucw8\nPWOQezXw7NhcyCfvH0Bqfi+uumkO4ycnD3KvBpY+C/bWbNu2jdWrV/PCCy90e291tTVUj+03EhIs\nQ7qfWzfls2tzMTNPH01UrJEv/3eU5LRIVnx/5mB3rUP6azw3fnaMQ7vLOX1RNrVVNvIOVZGeFcPS\nq6YNmT6GmlP7KYoi/3ltF9WVNhacO4btXxfi9fi5+Mpc0jJjhkw/+5t9O0r5Zu1xDEYtpy/KZuOn\nR4mOM3HlTbNRqTo3UAynv3swhOPYhymSJJF3sAqtTs35l05mQm4K6VkxVJY2UVdtH+zuDRgOu4ej\n+yqIjDYwbW4a514yiYTkCMqKGnC7vIPdvQFjz9YSqittjJuSxNTZaVywcgoAX3xwCL9PHOTeDRyH\ndpej0ai47PqZTJyWwoTcFOprHBzdf3KwuzaghFSwz507NyhtPUzfOVnehLXRRdbYeLQ6DQATp6UC\ncGhv+WB2bUA5sLMMv19i2pz0gEaWNS4BUZQoOlE3yL0bGDxuHzu+LsRk1jH/nDEApI6OZtL0VFxO\nLyfLmwa5hwNDU4OT+loHozJiiIw2AjB7QSYajYrtXxXg9foHuYcDR1hjH6bkHawCYOzkxMBnmWPj\nMJq1HDtwEt93YBJ7PT4O7CrDYNQwPrfFhpo1Lh6AgmPVg9W1AaWyrBG/X2JCbjIGozbweXqWbIIp\nLawfrK4NKMX58kI+Oic28FmERc/UOWnYbR7yDn53tPawYB+GiKLI8SNVGEzaNvZTtVrFhKkpuF0+\n8o+OfKGWd7gKt8vHlJmj0GrVgc9j4kxExRopzq/7Tixw5cWNAKSkR7f5PHV0NIIApUXfEcHevEMb\nnR3b5vPxU5IAqChpHPA+DRZhwT4MKS2sx+XwMmZCYjuH0MRpKQAc3lsxGF0bUJQXNWdiYpvPBUEg\ne1w8Pq9IyXdAW60oaUAQ5Gig1uj0GhJTIqkqb8Lj9g1S7wYGn89PWXE90XGmgBlGITrWhN6gobIs\nLNjDDGE6MsMoRMUYSUiOoLKsacQ7zaoqrGh1amLiTO2uZY1LAKDgWM1Ad2tA8Xr9VFVYSUi2oNNr\n2l1Py4xBkqC8pGEQejdwVJQ04vOKZJyirYO80CeNiqSpwYXD5h6E3g08YcE+zJAkiZKCOswWHUmp\nHcdpJyRbEEWJupqRGx3jdvloqHWQmGLpMCciMcWCOUJH0fEaRHHkLnBV5U2IokRKelSH10dlyOaZ\nkW5nD5hhcuI6vJ48Sh6fyrLvhiM5LNiHGQ67B6fDS2JyZKdJXgnJcqxrdeXQj8vtLcpv6ywJSRAE\nMsfF43L6RvTLXF4sa+Kn2tcVkkdFodGoKCvqu8b+zjtv8f3vX8Fvf/ton9sKNUX5tWi0KlLSOl7g\nFDPVSJ4LrWm/dwszpKk5aQMgLimi03u+C4JdCeFLTOk8YSMlLYqDu8qpOWkjtRPBN9wpb/YzpHai\nsas1KlLSoygpqMdhc2OK0Pf6WWvWvMsf//hnkpNTet1Gf9BY76Sxzknm2DjUmo51VXlnBye/I3b2\nsMY+zKitkgV7fGLngj02wYxKLVBdaRuobg04VRWyYO/MHAUQlyCPkTJmIw2/X+RkeRNxCWb0Bm2n\n943KaA577IPW/vTTf6C8vIyHH76ft99+s9ft9AeKUzQto/MMW61OQ1xiBFWV1hHve4Kwxj7sUDT2\n+C40drVaRVyCmdpqG36/iFo9stZvSZKoKrditugwWzrXQKNijahUwojLxH17/XF25VXj8fhx+Hzo\nGh1s/+vmTu8X/SJ2RA59egTDxhMd3jNnQiJXLh7TaRsPPPAztm79lj//+UUiI4dWDZ76Zl9SXBfK\nDshmqZqTNqpPWgM295HKyHrjvwPUVNnQGzRERHa9pU5ItiD6pREn1ADsVjcOu6fbIl9qtYqYeBN1\nNfYRWXlU0Ty7W7hVzddFf181VYlApbUhhDLHYxPMXd6XnCbPl5PfATt7WGMfRng9PhrrnM2JJ11X\nx5Tt7BVUn7QGbO4jhZPliuO0+98VlxBBbZWdpgYnUTHtwyKHI1cuHsNdV83g1b9upuhELdf/cG63\ntvN/v7ydpgYnN99xxoirrFpX48Bk1rXJuu2IlsiYRqYxsiughjX2YURts2bSlRlGocWBOvLsy4p9\nPZiyvLGJshZXWzXydi71tXYMJm1QDtHYeBM+r4itaWTFcXs9PqyNLmLiu1+0IyL1mCN0VJY2jcgd\nXGvCgn0YEbCvd2NLBIiNN6NSCdSMwMiYqoqeaeww8hyoPq9fFmixwe1CouPkBa6+ti8L3NDT9Otq\nHED3ZhiQQ2ATUyNx2D3YbZ7+7tqgEhbsw4hgHKcKao2K2AQztVWyA3WkIIoS1ZVWYuJNHWZankpc\n8wtfO8J8DbLfAKI7yLrtiNhmjba+WRD2hnfe+YDIyKHldAzY1+O7F+xAIEu5sa734zAcCAv2YURt\nlQ2VWgj6ZU5ItuD3S4GogZFAU4MTr8dPQlJwfgNThA6DUTPinMi11fIi31E5hY5Q5kx97cgSaMrc\nDkZjB7luDEBDnbPf+jQUCAv2YYIoitRW24mNNwcdvjgS7eyN9fILGR1r7OZOGUEQiE2IoLFeXhBG\nCjXNpqXoYE0xMSYEoa+mmKGHUjYjJi44wR7VPG/CGnuYIUFDnRO/TwzKDKOg3DuS7MuNzZpWVJAC\nDVrMMSOpdk5AsAepsas1KiJjjNTXOEaU47Cuxo7ZokdvCC7AL6yxhxlS9MRxqqBotU0NI2cSN9bL\nmlZUTHAaO7Qkrijmi5GAYpazRBmC/k5snBm3y4fTMTKODHS7vNitnqDNMAAGoxaDUUNDWGMPMxRQ\nbMTdZde1Rm/QojdoaGxw9Ve3BhzFFNMTwa68+HUjJORRkiRqquxExciZtcESHXCgjoxxCETEBBHq\n2JqoWBNNDc4RFVRwKn0W7B6PhyuuuILly5ezbNky/vKXv4SiX2FOQdG6eyLQlPubGpyI4sjYfjfW\nOzGatEFFxCgoERMjJTLGYfPgcfuCdpwqxI4wB2rAcRpkRIxCdIwRSQJr48hReE6lz4Jdp9Px2muv\nsWbNGtasWcOmTZvYt29fKPoWphVNDU7UGhWmCF2PvhcZbUT0S9itwz8xxe8XsTa6Ag6wYNHq1ETF\nGKkbIaYYRTAHa19XiGkWgL0NeWxdtvebb77ijTdeDfq7lZUVfPHFp0Hd+/jjv2bjxvXd3te6lMCa\nNe/x2Wf/C6r9qICdXR6HTz75L9XVLUdJPvnk7ykqKgyqraFKSEoKGI3yi+bxePD5RvYRXINFU4OL\nyChDj9PBFQ2/sd7ZI3vsUMTa6EKS6FVpgMgYIyX5TjxuX4+0/aGIIpCCTU5SUByHvY2MObVs7/z5\nZ7a7x+/3o1ar231eXl7GF198xnnnXdCrZ3eE4gyPjDawfPllQX9PGQfFEf+//33EzJlTSUrKAODh\nh38esj4OFiGZ4aIosnLlSoqLi7n22mvJzc0NRbNhmnG7vLhdvnZnWgZDZLQszGVTTudlTYcDgYiY\nHpqjACKjlHFw9SiyaCjS0EuNXatTY4nU98oU07ps78UXX4LFYuHIkUPcd99DPP74r7FYIsnLO8r4\n8ROZP/9MnnvuaQRBQKvV8OyzL/Dii6soKirkppuu5YILlnLllde0af+ZZ55k9+6dpKSktonaOXr0\nCH/+8zO4XC6ioqL5+c8fIzY2jnvuuQ3RFUt1XSGxH5Zht9sxmUycccYCfve7x3jpJXk3UVlZwcMP\n38+rr77JK6/8nW+++QqHw4lWSmTS9HvYsGEdR44c5sEHH0Sj0fL886t54IF7ufvu+zh8+ADl5eXc\neee9gKzZHz16hB//+AE+//wT3nnnLfx+H5MmTeEnP/npkKrBExLBrlKpWLNmDTabjTvvvJPjx48z\nZkznJUDD9IymZufnqYf0BoMiBEdCZExDLyJiFJQFztroHPaC/Vv315ROK+RP+ZsRCnomTJxjPfh8\nInnfbGgjiGYkTmXlmKWdfu/Usr2ffPLfNt8vLS3mT396AYCHH76Pn/zkp0yZkktEhIamJg+33343\nb731Ok8++f/atb1x45eUlpbwz3++TU1NDd///hUsXXopPp+PZ599iieeeIaoqGjWrfuCF19cxc9+\n9kskScLpsHPdlT9l6VXTWL36bwBkZGTi9/uoqCgnJSWVdes+55xzzgPgssuu4oYbbsbn9XPj9+9k\n566t3P/Idbz33tv88pe/ICGhbWGwRYvO5fbbbwwI9nXrPuf6639IUVEh69Z9zgsvrEatVvPHPz7J\n559/wvnnX9Sjv0V/EtI9aUREBHPnzuWrr77qVrAnJAyPioNDoZ/VzdUMU9OiO+1PZ58b9HLFO5fD\nNyR+S1/64HHKCUaZ2fE9bidttLxb8fukbr87FMapK9xuH6oIAU0v6uyrNSp8PhEkUKtbBLPJqOv2\nd6tUEBdnJjragsViwNj8HYNBy8KFSwPfP/30uTz//HMsW7aMJUuWkJSURHS0CZ1O0+Ezjh07wIoV\nl5KQYCEhwcK8eWcQGWnEZquhoCCfBx+8F0mSEEWRxMREEhIsCAhkpE4nMTmShAQLZrMes9lAQoKF\npUsvZuvWTdxyyy1s2rSeZ599loQEC7t2bebll1/G6XRSVV9FWXkaCQkWtFo1ktQyL7RaNTExJsaO\nTSczM4OKigJGjx5NeXkpixcv4I033uD48WPccceNSJKE2+0mLS15SM2bPgv2uro6tFotFosFl8vF\nli1buPXWW7v9XnX10C9OlZBgGRL9LC2WDyJWaYQO+9NVPyVJQqNVUV1pHfTf0tfxPFkhn5QjIva4\nHalZhlWUNnb53aHyN+8Mr9dPbN5YZo45gwvPn9Lj7x/aU87GT49x9sUTmDA1uc217n63KErU1trw\netVYrS6cTg/V1VZcLi8+X8vcXLHiGqZNm8uWLV9z5ZVX8swzq2hocODx+Dp8htPpwWZzB6653V6a\nmpzU1dnIysrm+edXt+uny+VFY9Gh0amorrZit7uRJDXV1VZOO+0sHn30p8yaNQ+/X8JojKGsrJZf\n/erXrF79OvHxCTx8/2+wNjgpL6vH6/W3+f1er5/6egfV1VYWLDibd99dQ0ZGJvPnL6S62orV6mTJ\nkou47ba7ejR+oSDYxaPPUTHV1dVcd911XHrppVxxxRUsWLCAhQsX9rXZMK1QzCi9McUIgkBktJHG\nBuewzzhsqHNiMut65fxsbYoZziip8PGJPQvxU1Ac6LZ+DPUrKyslOzuHa6+9nilTplBcXIjJZMZu\n79hpO23aTNau/RxRFKmpqWHXrp0AjB6dSX19AwcO7AfA5/NRUJAPtBwy0tE7MWpUGmq1ilde+TuL\nF8tmGI/HgyBAZGQUDoeD44W7geY5ZTJhs3UcMbVw4WK++mpDG5POrFlz2bBhHfX1ssLV1NREZWVl\nr8aqv+izxj5+/Hjef//9UPQlTCcoSTmW6N5FtcihfnacDi8mc8/CJYcKfr+IrcnV6yPN9AY59r1p\nmMcuK6nwSjninqIIdmtTb8YhOHv+O++8ya5dO1Cr1YwfP47TT58PgFqt4cYbv8eFFy5r4zxduPBs\ndu3azvXXX016egYzZswCQKPR8LvfPcmzz/4fNpsNUfRz5ZXXkJWVjd8vtfk9p7J48RKef/5P3HLL\nnYBsJl62bAXXXXcVKSmpZGeNw14vv1sXXbSMxx57DK1Wx/PPr27jO7BYLGRmZlNcXMiECZMAyMzM\n4pZb7uT+++9CFCW0Wi333/8QycnJHfZlMBCkQVLjhvJ2V2GobMtff/5b/H6R6++e1+H17vq5ef0J\n9m4rYcX3Z5CcNnhlV/synvW1dt56aTsTpiZz9sUTetXGO//YQUOtg5t/cmanEQxD5W/eGbu/Lebb\nDflcddOcwCEiPcHn8/PS018xKiOaS66Z3g89bEt/jeen/zlAwbEarr9nXq+UleL8Wj5+ez9zFmQy\ne0HmkP+7KwyYKSZM/6JoqpG91NahVSz7MI6MCZQS6GFyUmssUQZ8PhGnffgesqBkS0b38pg/jUaN\nyawb9lmX1kYXGo0Ko6nr4/A6I1AMrH5kZOGeSliwD3FsTW4kCSKjei/QomLkRUERjsORvsSwKyj2\n2OFsjrE1m1D6Mg4RUfrmeTV8fS7WRheWXiTsKUREGlCpBJrqh+9c6IqwYB/iBBynoRBoI0Fj78OB\n1C3JWsP3ZbY2udHp1d0e3NwVkVEGRFEatsfDuV0+3C5fnzKpVSoBc4QOm3X4zoWuCAv2IU5LclLv\nJ7GinQxrjT0g2Hs/DgHH4TBd4CRJwtroIiKyb6UhlO/3Z2RMf6LsWvpaIiMi0oDd6hmRVR7Dgn2I\n05dQRwWVSsASbRjW205rkwuDSYtW1/tAroDGPkwFmsftw+vxY4nU96kdRSAO13FQ+t3bKDGFiCh5\nHEdCgbxTCQv2IU6LYO/bJI6KNuJyyjVnhhuSJGFvchNhCZFAG6amGGujLIAi+qipBmLZexXyOPhY\nlV1sCDR2kP1YI42wYB/iNDXI3v++xp8PZzu72+XD5xOJ6KOmqtGoMUcM34gQJfbc0kdTjPL94TYO\nu3fv5KGH7gv0uzNTzD333MbRo0e6bU/Z+diaXPzpT39i587tverX22+/idvdsjg89NCPsdsHt0R0\nWLAPYSRJoqnBiSW6995/BWXbaRuG207lRY6w9L3ssCXagK3JNSztqoqG3dcFztI8F4abYAcQBLoV\n7MGiaOyNDU7uvfdeZs2a06t23nnnTdzulrF86qlnMZsHt9Dc8C5MPcJxu3x43H5S0ntvX1dQzBjD\n0Z6oLEZ9FWggh41WljZht7r75LcYDBRTTF8FmlanwWDU9Eiwu1wufvnLn1JdXYUoilx//c0sXnxu\np2V1y8pK+b//exybrQlJEvjtb58gNXUUq1Y9x9atmxEEFddddxPnnHMeu3fvZPXqvxEVFU1BwQkm\nTJjIo4/+FoBvv93Mn//8DNHRMYwdOx6ApkYnGq0qEBnkdrt5/PFfU1RUSEZGBh5PS7TP9u3f8vLL\nf8Pr9TJqVBqPPPIYBoOBK664hLMXXcAXm7/Er1vGV9veZNas09HrDfzvfx/xm9/8AZB3Cf/+9xs8\n8cQzPP30Exw9egi3282iRedw00238u67b1FTU80999xOdHQ0zz33PFdccQkvv/xP3njjNZKTU1ix\n4nIAVq/+G2azmauuupZ//euffPnlF3i9Ps46axE33dR9fa2eEBbsQxjlxeurLRFaBPtwtCfam0In\n2C2tQh6Hm2BXNHb/hv+yY9WePu065tg8iH6J/IffBsAyew4JV1zd6f1bt24mPj6Bp556FgCHw95l\nWd1f//oXXHfdjaxYsZTy8jpEUWTjxvWcOJHHa6/9m/r6Om6++TpmzJgJQF7eMV5//R3i4uK4444f\nsn//XsaPn8hTT/2eP//5RUaNSuOXv/wZ0D6Gfc2adzEajbzyyr84ceI4N910LQCNjQ28+upqnnvu\nr+j1Bt5441Xeeut1brjhZvk3R5pZMu8uRqfHcrSkVB6XOafx9NN/wO12odcbWLfuCxYvXgLAbbfd\nhcViQRRFfvSjO8jPP87ll1/Nv//9ZqCcsYzcr3PPXcJzz/0xINjXr1/LM8/8me3bv6W0tJiXXnoN\nSZJ4+OH72bt3D9OmhS4TOCzYhzC2EAo087DW2BUTRN8XuMCBG43D7+ARa5MLlUpAq1XTVxe4oBKQ\n/CKSJJs3uiM7ewyrVj3HCy/8hTPOWMC0adPJzz9Bfv4J7rvvruayuhLx8Qk4HA5qaqpZsEAuBqjV\nypr1vn17OPfc8wGIiYllxoxZHD58CJPJxKRJk4mPjwdgzJhxVFRUYDAYSU0dxahRaQAsWXIhH3zw\nH3kXm9ayKO/Zs5srmhelnJwxjBkzDoCDBw9QWJjPHXf8EEmS8Pl8TJkyLfC9JUvO57//ymuj7KjV\nak477Qy+/vorFi1azJYtX3PXXT8CYN26z/jwwzX4/X7q6mopKCggO3sMIDX/pyD//9ix42loaKC2\ntob6+noiIyNJTEzinXfeYvv2bdx007VyXXmni9LS4rBg/66gCGFzH6NBWrcxHCMhrMoCF4JxaHEi\nD79xsDW6MVv0JF55NQl33dKn2ibfrDvOvu2lrLxuJkmp3Z/MlZ4+mpdffp0tW77hxRf/wty5p3PW\nWYvIzs5pV1bX4ei4iuOpma6t/60IfwC1WoXf3/HS5fPKu5RTzVGtfVBKu5IkMWfO6Tz22O86bMto\nNBIRaWj3TixefB7/+c/bREZamDhxMkajkYqKct566w1efvmfmM0RPP74r/F4uleSzj77HL78ci21\ntbWcc86SQL9+8IMbuOSSFd1+v7eEnadDmIBtOQQCTa2WD8Iejs5TW5MbQQCzpe+VKZXdj32YmaT8\nPhGH3ROyc2t7GvJYU1ODXq9nyZILuOaa73Ps2NFOy+qaTGYSE5P46qsNAHi9XtxuF9OmzWTdui8Q\nRZH6+nr27dvDpEmTO31mRkYmlZUVlJeXAbB27Wf4mmuntx6H6dNn8PnnnwCQn3+cEyfyAJg8eSr7\n9++lrEw2s7jdLkpKits8IyJSj8ftlw8faWbGjFkcO3aUDz9cEyjVa7fbMRqNmExm6upq+fbbzYH7\nuypJvHjxeaxb9zkbN67n7LPPAeC0007n448/xOl0No9tdaAEcKgIa+xDmFBq7CAvEDVVNiRJGlLn\nM3aHvcmFKUKHStV3PSSwcxlmC5xijuprcpKCEvIYbJJSfv5xVq16DpVKQKPR8sADP+uyrO4vfvFr\n/u//HueVV15CENT89rdPsHDh2Rw8uI8bbrgGQVBx5533EhMTS2FhQZtnKXNTp9Px4IOP8OCDPyI6\nOobc3OmcrKiT+99KsC9ffjmPP/5rbrjhe4wdO45Jk+QDSKKjo3nkkcf41a8ewePxIggCt9xyB+np\no1Hs4Ip5T1kwQD7qc968BXzyycf84he/BmDMmLGMHTueH/zgKlJTR5Gb22LSueSS5TzwwL3Exyfw\n3HPP07q8cVZWNg6Hg4SEJGJj4wCYM+d0iooKuf32GwEwmUw8+uhviYkJnWkwXLa3Cwa7lOcH/9pD\neXEDtz54FuoujkELtp99LXXaV3oznqIo8dLTm0hIsbDyBzND0o9X/vQNOr2G7912Wkj6OBCUFtbz\n0Vt7mTUvg7lnZfW5nzUnrbzzj51MmZnKmUvGhbCnbQn1eH79RR77d5Zx+Q2zSEju+1F0u7YUsXVj\nAVf/cC4xCb2vQzRQhMv2jgDsVjdGs7ZLod4ThmPIo8PuQRSlkJijFMwWPXbr8KpuGKr6KAqBujnD\nLJY9lKGvcjtKlNTwS9zrirBgH6JIkoTN2vc0+tZERA6/kMdQJeW0JsKix+cTh1V5hUCSVojGQT5R\nSh1wTA8X7FY3KrXQp+qWrVHGczifVdARYcE+RHG7fPh9Ysjs69CqNsYwKlVqDziQQ6OpwvAM/VQW\n41Bp7CDb2a2NrmG1c7Fb3Zgj9CHzEQV8DcO48mlH9FmwV1ZWct1113HRRRexbNkyXnvttVD06zuP\nLYQhfgrDWaCFUmMfjg5UpU5MSOdDpB6vx4/X4+/+5iGAKMqRQaGIjlIwRegQhJGnsfc5KkatVvOz\nn/2MiRMnYrfbWblyJfPnzycnJycU/fvOEuqIGBie2afWfjDFBBY42zAah0YXRpMWjVYdsjbNES0L\nvU4/9APkHHYvktTS71AghwHrh/VZBR3RZ409ISGBiRMnAmA2m8nJyaGqqqrPHfuuE8oYdgVThKzp\nDCfB3qKxh84EEXAiD5NxCPhbQjgGMPwWuP5QdkAOIW1qdCGKw8ck1R0htbGXlpZy5MgRcnNzQ9ls\nv2I/sA9nfv5gd6Md/TGJ1WpV83Fg7V9kT2UFjsOHQvasUKE4y3p7aHFHKFv51uMgSRKi14vo9SL5\nhpZT1enwIvqlkO5aAMzNC73d2lI0S3S7se3Zjd/WtuyszWbj/fffDfxbKaHbEU8++XuKigq7fX5X\nbbRGKcMbeCeC0NhffvnFoMvwRkQakEQJR/MC9/bbb+JyOrHu2IansmJIlOHtKSHbf9ntdu69914e\neeQRzGZzt/cHG4/Z3xT+8xU8dfWkLL2IjB9ci1rfdtIMVj+V1OnRmXHExoduPKNiTVSWNRIfFwGS\nSPlHH1P15QYchUUApF9zFaOvvrL3HQ9RPxUcNg9R0UYSE7tPew+WSItcVsDr8ZOQYEH0ejn8uz/Q\nsGcvxwFUKjK+/z3SLuu/lO+eUOlpBCA+IaLN+PV1bqamRcv/I8lt+Z1ODj3zJE2HDiOo1URNyyV1\n6UXEzJqJ293IRx/9h1tvlZNqoqNN6PWaDvvw9NNPtPm3co8oim2SzLpqozVarZqYGBP2etlhmjIq\nqsvviKLIT3/6QPcD0ExisoXjh6vQqNUkJFh49+03mJF/HOn4CZIvWMI//vFy0G0NFUIi2H0+H/fe\ney+XXnoGQVm+AAAgAElEQVQp5557blDfGSpJIEm33U3l6r9R8dHH1Gzbwagf/wRdQiIwuMkqNVXy\nc90eb7d96Ek/DUYNol+iuKgW19frqHn3bVCrMU+bjqesjJI3/43D4SFu2aV9/g196SfIafQ2q5vU\n9KiQ/x10ejX1tQ6qq61UvfUGDXv2ohuVhikhDmt+AUX/fANfXDLmyVNC+tzeUFoip5urNEJgHEIx\nN33N1SGrKps4WVpD2XPP4Dx2FOOEiYgOBw27dtOwZy8Zj/2Gx//2V4qLi1m27BJmzz6NM86YT0ND\nE7fddme7Urv33HMbd999H+PHT2DJkrO46qpr2bbtW+6++8fY7fY2ZXg9Hl+733FqGV673Ul9vYP6\nCh8V1cf42aOrEVRSuzK8F198Cdu3b2XlyivZunUz8+efGVQZ3sYGG3GWCZQUTeTdF/8fVSdP8ugX\nnxEdHcOq85exaNHZg16GVyHYxTwkgv2RRx5hzJgxXH/99aFobkAxZmeT8cvfUPOfd2hY+wU1775N\n6h13D3a3sFvdGIyhdZZBS9hgQ1k19o8+QB1hIePXv0MTFYW3tpbS/3uC2g/eR9DpiD3/wpA+u6co\ntt9Q25ahJUnJumsnDWu/QJeSyuhHHiUpLZ6SbXspfuL3VP79RTIe+y2a6OiQP78nKONgajZBbF5/\ngsK8GsQ+HhaimJSP7KvEsXsHWXlHiZg9h5RbbkdQq7Ht3kX5qj9R9cY/uf32uykszGf16jcAWUB2\nVGp36tRpbZ7hdDrJyRnDD394Gx6Ph6uvXtGuDO+pdFaGt7qqhgN5a3nxpb+RkBTdrgyvTqdn1aqX\nALnMMARXhvfEkZM89PC9HDt8mNOLi3lPq+WPj/6G1IVnN4dVDn4Z3p7SZxv7zp07+eijj/j2229Z\nvnw5K1asYNOmTaHo24Ch0ulIuOp76DOzsO3cgfuUQkEDTX8kJykoNvvyz9Yjud3EX34lmqgoALRx\ncaQ9+DDqqGhqP3i/nZ11oOmPUEeFCIset8tH+auvIOh0pNx+F6pmM5whK5uEK67Cb7VS8dILSOLg\nnrak2MAVm3ioUELBRb+Ir64Oc+40Um6+DUEtKxMRM2ZinjET57Gj2Hbvavd9pdSuIAiBUrunotFo\nWLhwMQBFRYXtyvB2xJ49uwPXWpfhPX7iCI22kzz007u48cbv8emnH3Py5MnA95SCXa1pXYbX7/ez\nZcvXnHmmXE543brPuOmm7/PL395Do/Ukx3ftQLTZEIwmLDNntYqVb1+G9/jxvEAZ3m3btgbK8N50\n07UUFxdRWjq4MqTPGvusWbM4fPhwKPoyqAiCQPzyFZQ9+wy1H35A6l33DFpfPG4fPm9ok5MUApl2\nReUk5owhct78Nte1cfHELDmfmnf+TeNXm4i98KKQ9yFY+iPrVEFxwDk9kPW9a9GPGtXmevQ55+E4\nchj7nt3Y9+4hYkZo6tT0BsWpp8yHeYtzuPSq6SExT73+/Ld4GxsZW7uDhB8/jqBpKxISr7qGwoMH\nqPv4w3YLXDCldnU6Xa+SiToqw+tyekhLnsA//vFih98xGjs+OKW7MrySX8Odt/4Ya1UtKrMZVSft\nwOCV4e0p4czTVpgmT8WQnYNt905cxUWD1g8lWsPcLwJN1vpcmggSr/0BQgcVE6POPAtBr6dh/dpB\njRCx2xRNNfTjYDLJWqkvbhSR889sd10QBOIvXQlA49eDuwPtL40dwKgRcUtajJNz0aWktruujU8g\n9qKlaB0ObLW1PW6/dVZrR2V4O6KzMrwW4yiq6gq6LMPbEd2V4XW6rZRXH8EraIg9/0LM5oghV4a3\np4QFeysEQSDuUnnVrf1wzaD1w94PMewK2qZqAPyJ6RhGZ3R4j9pkJmr+Anz1dR1uwQeK/opbBlDX\nyMJFGJ/b4eIGoE9PR5+ZhX3/PnwNDSHvQ7DYbW40GlW/JBFprDVIggrDWZ0HPcScfwGRlkhydDqu\nu+5q/vrXP7W7p7WG3dn/63Q6Hnro5zz44I+4665bSOlgIQG5DK/D4eCGG77Hm2++zqRJU/B6fKgF\nI8vOv5lf/eoRrr/+Gm677SaKAwpY57sCpQzv1q1bmDdPXsRbl+F96onfkBSdgU+tJ3rxOYEyvD/6\n0R3t2u6sDO95553P7bffyPXXX82jjz6M0+notD8DQbhs7ylIkkTJE7/HdeI4s/72PFbVwJ+LeWhv\nORs/OcbZF09gwtTkbu/vSYTEybfe5D8FSSTGaLns9vaaqoLnZCWFP/8phpwxjP7ZL4Lue6j6CfDZ\n+wfJP1rNdXefEXKtffsfVrFDmMycuUnMXjyx0z42bFhP1euvEb/ycmIvWhrSPgTLK3/+Bp2ubZnh\nkETFNNTz6ZNvURI1kcuun0liSuchpZWvrKbp602kPfAwpgkTO73vVEIVWVZfY+etv29n4rQUFl04\nvs/ttWl7/Vo+3tSI0xTLzQ8uGtJnFYTL9vYSQRCInLcAgLqt2walD/Z+qBMDIIki9p3b0IsunGLX\n2p8uKRlz7jRcJ47jzD8R0n4Ei8Mun5xkNIXWBOEuL0dVIv8mp6/rqCPL3NMRtFoav/lqUIpl+f0i\nTrs3kDUcShq+XI/eKzvIFbNXZ0SedjoA1m1bQ96PYLDb+m/3Ztu1E53fgU8Uhk3dnO4IC/YOiJg+\nAwSB2i3fDsrzbf1kgnDmHcNXX49Rr8Jh93QrqKIXy9tza6tjwAYSu9WDyaxDpQqtBtX09Sb0Pnvg\nGV2hNpmImDUb78mTOPOOhbQfweC094+fQZIkmrZuwaCSfSjdFYYzjp+AOioK687tg+J3USKkQlkA\nDMBvteI8dpSICNkRPFzKK3RHWLB3gCYqCuOYsTQdPoKvqWnAn99iYw/tJLZukxeqiPhI/H4Jj7vr\nF9Q0YSIqgwH7/n0Drq1KkoTD7gm5pir5fDRt+Qa9SYtaLQRV4TFqwVkANH61MaR9CYYWB3Jox8FT\nUY6vpoao0SmAnOHbFYJKhWX2XES7HfuhgyHtSzD0lyPdtncPiCIxaXJSYncL/XAhLNg7IWLGTJAk\n7Ht2D/izbVY3Or0arS50zjLJ58O6cwfqqCgik2SnT3eTWNBoME2egre6Gm9l+xjl/sTjluvRm0L8\nIjvzjuG3Womae7qcpBSEhmYcPwFNfDz23bsGXFvtLweyfd9eAGInjmnznK6wzJVt/IqCMJD0V0CB\nbfdOABInZAItoaXDnbBg74SIGbOAlj/8QOKweUKumdgPHUS02bDMnoup+eVw2LufxObmTEJbsyAY\nKPorxM9+8IDcbm4uZoseh82Dv5sMTkEQME/JRXS5cBUMbME4RZMO9c7Fvm8vCALxM+SSCcEscIbs\nHLTxCdh270Z0D6wA7I8FTnS5cBw8gG5UGjHpSfJzutm5DBfCgr0TtAkJmLMycRw+hN85cLWa/c1H\ntoX6Rbbtkhcoy9zTWqr6BTGJzVOnyvfu3xfS/nSHsuiEWmN3HDyAoNFgHDs+ICQUO3ZXmCdPBhhw\nM0TAaRjCcfA77DiP52HIysIQG41OrwlqLgiCgGXuaUhu14BXArXb3KjVAnpD6Hax9gP7kXw+ImbM\nDJykFLaxfweIPf00JJ8P+/6B01Zb6oKEVrA7jxxGZTJhyMoOtN2dXRVAExWNPjNLNmEM4ALXHxq7\nr6kJd0kxxrHjUOn1LQePBGGGMI6fCCoVjmaNf6DoD03VcfAgiGJgN2a26II+VcvUXBTNcWSABbvV\ng9kSuiPxoGU3HjFzVuDIwWDeieFAWLB3QdzpcwGwD2CSjqNZezSZQ/cie2uq8dZUYxw3HkGlajk5\nJ0jtxDw1F/x+HIcGTqgFxiGEgt1xWNa2TZNk4WQyB7/AqU0mDNk5uAry8XeSldgf2PvBFKPY1825\nzYI9Qq6b4/N2H+pnyM5B0GpxHDkSsv50h9/ffCReCHctkt+Pfd9eNHFx6NNHN5+jGjbFfCcwZWSg\njo7GcfTIgEWFOPohCkJ5CZXEkp4INICIZgFg3zdw5pieHKoQLIq2bWo2q/Rk5wLIJXwlaUC1VbtN\nPrZOG6Iqn5IoYj+wD3VUNPrmzOOWk5S6HweVVotxzFg8pSX4rAMTMRYI+QzhrsVdXITodGKeMhVB\nEFCpBExmXdh5+l1AEARM48bjb2rC26qKXH/SH84yx1G5SJsi2I1mbY+0E31GJmpLJPb9ewes0mGo\nw/wkScJ+8CBqiwV9Wnpz280CLQgnMoBpkrwgOAbQzu6whfbwZldhAX6rFfPU3IBZo+UkpeDGwdg8\nj5xHj4asX13RktcRwnfimNx347gJgc9MEXrstu7zO4YDYcHeDcaxcvqy89jATGJFyChadV+RJAnn\nkcOoLRZ0qXIFQ5VKhdEUvHYiqFSYp0zF39SEp6wsJP3qDiXr1BCirFNPeRn+xgZMkyYHasP0VGM3\nZGahMhqxHzwwIC+/z+vH7fKFdtfSvCgpTnHo+dmnioLgODIwVV0d/RDDrrzPxrHjAp+ZI3T4fWK3\n+R3DgbBg7wbjOFmwO/IGRrAHJnGItp3eqpNytun4CW2KXZkidEFlnyooL4DzeF5I+tUdoc46DZhh\nJrWciNRTk5SgVmOaOAlfTQ3eATiwvT+Sk5S/n6KwyO03C/Ygk3MMGZkIegPOgRLsIfa3SKKIM+8Y\n2oQEtLGxgc+VMOCRkKQUFuzdoEtJQRURMWAae8AUEyKNXdGqTi3cZI7Q4fOKeNzB1cYwjh0LDIxg\n74+sUyVMUQlbBNDpNWi0qh5FQgSiQgbAkRyIkArRIi+JIq4Tx9EmJaGJbCn4pZg4gtXY5XDRcXgq\nK/A19H952lC/E56yMkSHo83iBq1MUiPAzh4W7N0gqFQYx47DV1uLt7am35/nsHnQaENXotXZiWBX\n4sODSVIC0CY3L3An+l+whzrrVBJFXMfz0CYno4mOaXPNZNYFbWMHMI1vti/n9f84hNqR7ikvQ3Q6\nMeaMbfO5orH3xHFomiDbph1H+z865tSjAfuKsvtWduMKph7kdwx1woI9CEwBO3v/F4Gy290hsyVK\nkoTjyBHU0dFok9qW/+2xGUIQMOaMwVdT0+9aWqhj2D3lZYguF8bsMe2umSL0uBxeRDE4k5Q2KUle\n4PKPh6RvXRHqyKCAGWZMW8FuNMsFsHq0c5kwSf7OAJye5rSHVmMP2NfHnaqx93yBG6qEBXsQKBPA\n2c92dlFsLtEaqi1nRTl+axOm8RPbJXa0bL+Df5kVgdDf5phQZ50qZYcNOe0FuzlChySB09GDBS47\nR17gGvv38I1Ql6pV/m6GUwS77EzXYg8iA1dBP3o0KpMJ59H+F+x2m6f5oJG+h3xKkoTz2FFZ2UlI\naHOtJToorLED8MgjjzBv3jyWLVsWiuaGHPr0dFQGQyBEqr9w2r1A6JxErhOyVqnYx1ujJED1RDsZ\nKMEeao3ddUIW7MacnHbXerpzATlJB8DVz3XqQ+08dR0/jspsRpfc/vAWU4QuqNIKCoJKhTFnDN7q\n6n6vgKr4W0KRdeo9eRJ/UxOmcePbtWfqYeLeUCYkgn3lypW8/PLLoWhqSCKo1RjGjMVbWYmvsbHf\nnhNq779SsEoRRK3paagfgD4zE0GjwXm8f80QIR+H/BOoDIZAuGdrejMOxmbN33mifwW70idjCHZw\nvoYGOfs4Z0yHRwGazDo8bj/eILJPFQxZ2QD9WhhNFCWcdk/ozTCnOE4BjCYtKpUwIsoKhESwz549\nm8jIzo/VGgmYAuaY/rOzh7rgkzM/H0GnQz8qrd21nhQCU1BpdegzMuWsvX6s7hdK27LfbsdTUY4h\nK7tjgdbDJCUAQ1YWCEJgR9RfOOweDEYtanXfX9PO7OsKyjj0RGs3ZCuCvf8WOJfTiySFbpFX3l/j\nuHHtrgmCIIcBjwCNPfSn445QAtvvgnwss+cAYPPa2V65G5WgIjMyndSIFLSq4IZUkiRKqmwcKqzH\n6/Oj16pxVcs1SEKhnYhuN56yUoxjxiKo29smjQETRM8msXHMGFwnjuMqyA9E2tS7GjhQewS7186M\nxFySTAndtNKCy+OjqNJKQYUVm9NLXJSB6pPyGZmhMEEoQqejXUvrZ/RES1MZjOhSR+EqKkTy+RA0\nGlw+F2W2SmpddVg9NqbGTySxB+NQ3eCk+KSVmkYXjXYPybEmbFY3lsj+ta8rKHPObvMQGR3cOb+G\nTEWwFwQ+a/JYOVhzBLVKjU6tY5pxLALB/4ayGjvHShpweXx4vCKG5jyLUNVOchacQGU0ouvkIG1T\nhI6aShuSJA3ps0+7Y9AEe7CHsg42Sj995qmUCgL+smI0ESJrDn/G+vxvcPtbBIJereOHs65mUdYZ\nnbbncvt4d30ea7cXU9voanMtFRiFig0HK7GkRzNtbPCC4dTxbDxYDJJEzKTxnY61KUKH2+Xr0d9C\nNWsa9Z99iqqiGPuUFJ7f/k8K6ksC1z/K/4zxcdlcPuVipiVP6rSftY1O3vz8KGu3FeM/JSJlAgIR\nCHyxr4JLzxpDQkzvDxR3VpYCkDRzKrEd/E7RKz9b8kuBvgUzHo1TJnLys1JM9joKLV6e3foyTW5b\n4PoHJ/7HOTkLuHzyxUQbOt7NSpLEoYI63t9wnG2HKmmdKyYAs1FRWu9k8+EqLpqXiVbTdoHuyd+t\nvCgfQaMhbfZU1Pr2QjIxSW5Lq1YF326ChbKUZNyFBcTGGvmycAtv7H0fu7elCqjmoIZrc5dz4biz\nUQkd7zx8fpEvthbxxbZi8kraOqQjgfGo2Ha8moRJSSyYntprgeuz2zlWWUlU7lQSk6La/5wECzGx\nZqrKrUSY9CEvGT2QDJpgD8XJ5f3NqSes65JTsB4/zs8+e4I6dwMx+miWZi3BrDVT2FTCjpO7+eu2\n1yisKueirPPaTEBJktiTV8O/1h6jtsmN2aDh9MlJ5GbHYTHpcHv9HPy2GFu5lX2FdWx9YTNn5qZw\n9TljMXYT097RSfB1u+WEHCk5vdOxNpq0NDW4evS38CXIduqSndt5Ub0Zl8/FxNhxTImbiElrZGvF\nTo7WHucPm1Zx69TrmBrfItwTEixUVDby4TcFfLatBK9PJCnWxPQxcWSlRBJl1lHX5GbfF3l4PH4+\n2JTPf78uYMVZ2Vxw2mhUvXiha/fLBbs8cakd/k63V3ZY19bYqa62djiWHZI6GoAvv3iP12MLUQkq\nFqbNJ9mUiFqlYm3RRj4/vomvCrfx4xm3k2ZpqyHaXV5Wf3yY3XlybkRWSiRzJiQSH2Ug0qyjsKSB\nE5sK8UgSf//gAO9/mcfV54xj1viEwFgG+3cT3W5s+QUYMjKpa/IA7XcnIvKqUlHeSHxK8AuGdnQW\nrq1b+MM7v2efUIlBrWdZ9gVEaE04vE6+LPuKV/e8y9aivVw/+WoidW3bLqux8/f/HqKo0oogQG5O\nHLPGJWAx6dBpVRzeW0HV4WqqrW6een0H//06hmvPG0dKnDnoPiooNeRVqe3fCWU8NVp58SkuriMu\nIaLHz+hvgl10QybYR0LhnO7QjE7HU1GOVFXDBdPO56LMc1GrZC3qtJRZLEybx1/3ruZ/hWupdzdy\n7YTLEQQBUZR4Y+0xvtxVhlolcPEZGSw9IxO9rq0GdnJfJTbgnqum8eaXJ/hqXwWHCuu5fflkclLb\naxhdoURsKHbQjjBF6KmtsuP1+II+hk9jiUSKjcZZkI97ZgLXTbqKuckzA9fnJs8krz6fv+59mb/v\n/ye35t7A5DjZP9Fk9/D/3t7L4aJ6Yix6li/IYt7UZNStbN+SJLH/02MkJ0bww9mjeG/jCd7dcIJj\nJQ3cvHQSEUZt0GMgiSKu/BNok5JQR3T8khqMvXOYGZtNO1VH9hC5KI2bp/6A7KjMwPXTk2ezsWwz\n7+V9xPP7/sFDs+8hSi9r7gUVTTy/5gA1jS7GpUez8qxsxqZFtVEEIlUCJyhk/oxR5Khh3c4yVr2/\nn0vmZ3LJgqwe9dVdUgx+f9dzQTHN9cDGDqDPysK6dQuewgJy58zmqvHLida3zNWLpy7iua//wcHa\nI/xt32vcN/P2wDuzbmcp/15/HJ9fZP6UZFYuzCHmlNBOZ7mVqsPVfO+C8aw7Ws3+/FoeW72dW5dN\nYvaExB711VUom4wMWZ2PnzIOTrsHgt8wDzlC4jz9yU9+wtVXX01BQQGLFi3ivffeC0WzQwqv38s2\nXSUAC8VMlmYtCUxQhWRzIg/MvovRllFsqdjOloodeLx+Vr2/ny93lZGWEMGvb5rLZQtz2gl1kF8q\ntVpgfGYsj14/m6XzMqizunj6zT0cKqzrUX9dBfmoLZFoYuM6vcds7rkDtdZZzwmLG6Nb5Oa0S9oI\ndYWxMdncnnsjgiDw0v5XKWgspqzGzk+e28jhonpmjI3ndzefxpnTUtsIdWjJOjVb9MyfmsKvbpzL\n5KxY9p2o5TevbKemIfjDPjyVFXKmZQeJSQqCIGDsRbnWfK0Nl05gVK3Iw3N+3EaoA6hVahann8ml\n2RfS4G7kxX2v4vF72Hm0isf/uZPaRheXzM/koWtmMC49up15QVlooqMMXLV4LI/dMJuEaAMfflPI\nX98/gKsHhapcRYUAGDK6EGi98DUA7NLLO46JNjO3TP1BG6EOEG2I5I7cG5mdNJ2CpiI+yP8ESZJ4\nb+MJ3vjiGEa9mrtXTuWHSye1E+rQ4sxNSbLw4ytyuXP5FNRqgefXHGDtjpJ293dFQLBndqXs9G4c\nhhohEex//OMf+frrrzlw4AAbNmzgsssuC0WzQ4qPC77ggFGO1811xXRq54vUWbhl6nUY1HrezfuQ\nJ975ht15NUzMiOGn184kNb7zLaTdJod1CYKARq1i5Vk53L1iKn5R5Nl39rEnL7iSBr6GBnx1dRiy\ns7u0R5osPZvEkiTx5tH3qIyRp022tXPn5vjYMdw85Qd4RR+vHnybJ/+1g8paB8vmZXLXyqmdmpdO\nTaOPNOu478ppLJ2XSU2ji6fe3E1NY3DCPbBr6SB+vTXmCB32HhREs3nsvHbk31TGa7FYvZjdnX/v\nvIxFnJ48myJrCX/Z9i9e+OAgGo2K+66axvIzszstcnZqyOeohAgevX4OE0ZHs+tYNb//xza8vuBC\nE92FhYBcfrkzeqOx76s+yIeu3fhVkNOk79SGLggC14xfSaIpnnXFm1i1di0fbykiMcbIo9fPZua4\nzlXj1rH8giAwe0IiP/3eTCLNOv61No/3NgYfkeMqKGhWdmI7vUcJKuhJstZQJJx5GgQn7VWsL/kK\nX1IcqFS4iwq6vD/WEMOKMctw+92UGzczd1Ii9105DVMX5zVKUnO87ikOmxnjEvjRFdNQqWDV+/vZ\nd6K22/4G4tezOtdMAMzmniVkbKvcxeG6YwGNx11U1OX9U+InMit+FtWuKlxRedxxWS4rzsru0lau\nCJbWsdsqQWDlWdmsODNLFu7/Ck64K9Ea3Y2DyaxD9Eu4Xd1rwZIk8caRd2n0WIkeI0cFKZpgRwiC\nwDUTVhKvTeaE8xCaqHruv3IaU7I630lBx4WvIoxa7r9qOtPHxLMnr5oXPjiIr5uDuEHW2AW9ocPE\nJIWeFkRz+py8ceRdVFodmrQ0vKWliN7Ov2vQGPjh5O+jktQckjaQkqziZ9fOJD6qa8d4R+WbM5It\n/PwHs0iKMfLxliI+3VrcbX99TU346moxZGV1qewoCoUzrLGPbCRJ4p28D/FLflZOvBR9Wjru4mIk\nX+dCQJQkDuw04a9PQB1Vx4QZTWi6iUV2OeV6JR3F607OjOX+K6ejUgk8/8EBik927TQLVrD3ZNvZ\n5LHybt6H6NU6zjvjGvk5XQg0gEa7h6Nbk5G8OvTp+czO7d7x4+iiLsiy+Vksbxbu/+/tvTi6EcTu\n4iJQqzuM429NT8Zhb81B9tUcZGx0NhNzF7Y8pwsKym1U7pPNIElTCsgZ1X3OR2dJWhq1ijuWT2ba\n2Hh259Xwj/8d7nKnIbrdchz/6NEdxvG3xmTWBa2xf160AZvXzgWZ5xA1Zjz4/biLuxawhw77cBWN\nQ9B4mTCnhqggok4cNg9GU/vyzfHRRh64egbRETre/vI4Ww5UdtmOMle72rVA730NQ42wYO+GvTUH\nOVx3jImx45iWMAVDZhaSz4e7vPMDJ9798gTbDlUxyn0GBrWeTwq/aBMW2RHdnZw0Lj2aW5ZOwuPx\n8+w7e6lrcnV4H7QW7F072QICLYhJvOb4/3D4nFyacxEJcaloE5PkOO5OhIrXJ7Lq/f1U1/qZrJ+P\niI+Xd/272+d0V6L1kvlZLJmTTkWtg+c/OIC/kxOdJJ8Pd0kx+lFpCJquHcPBVroUJZGP8j9DQDYt\nKDZrxYbdETUNTv7yn/34bdGMi5hMtfsk31bs6PI50PU4aDVqfn7jaeSkRrLl4En+u7nz57uL5bBX\nfWb3DldThB6n3dNtQbQ6Vz1flnxFtD6KxekLMGQpOR6dL/Q7j1bx7/XHMTtyiNPHsa1qBycd1V0+\np7vyzXFRBu6/ajomvYbV/zvMwS78UO4gHKcARlNYsI94vH4v7+V9hFpQc8XYSxAEAUPzC9LZJN5y\nsJJPtxWTEmfivhWncXb6mdi8djaVbu7yWQFbYhfJSbMnJHLl4jE02Dw89+4+3B2kf0uShKuwAG1S\nMmpT1yFhwdZJqXLUsK1yF6nmZM4cdToAhsxMRLsdX017u78kSbzxxVGOlzYyd2Iid5x1PuNixrC7\n4gDHG7rW8oMpJ3Dl2WOYlhPHwYI63lrbcfanp6ICyefDkJnZ5fOgbXJOV+w4uYdK+0lOS5lFkjkR\nTXQ06sjITk1STreP597bh9Xh5drzxnL9tOXoVFo+PPEpTl/nCzO0ONI7K99s1Gu457Jc4iL1vP9V\nAbuPdSwkXUWKwzCzy+eBPA6SJO8eu+Kj/M/wij4uyb4AnVrXUlqgsOPSAkWVVv720SF0WjX3XT6D\n5WMvDCySXeH1+PF5xS7nQlpCBPdenosgwAtrDlDdiXM9GMcpgFqjQm/QhJ2nI5mvirZT56rnrLQz\nSN51YmQAACAASURBVDLLoVXKit/RJC6qtPLqJ0cw6tXcc1kuEUYti9PPxKgxsLZ4Iy5f5xqhI8ia\n00vmpLNoeiolVTZe+7T9IdvemmpEpxNDN1tOCH7b+VnReiQkLsg8J+AgU7a0rg78Det3lbFpbwUZ\nSRZuvGgiKpWKZdnny20Vru/yWcEcqqBSCdx6yWRGJZhZt6uUTXvL292jaNHKgc1dEczOxS/6+Tj/\nc9SCmosyzwVk+7l+dCa+ulr81rbmMUmS+Pt/D1FWbeecWWmcPTONaH0USzIWY/XaWF+8qcs+OZr9\nLV3ZgyPNOu65LBedVsXf/nuI0mpbu3sCAi2I+dCShdv5PC2xlrG9cjdpEanMSZ4BgDYxEUFv6NAU\nY3V4WPX+frw+kdsumUxGsoUZCVPJsKSzu2ofRU2dR7Z0ZZZrzbj0aK49bxx2l49V/9nfTuGRJAlX\nQQGa2Lg2B4x0hnK62HAmLNg7QZREPjwiv8jnjl4Y+FyXOgpBpwts7RRsTi+r3t+Pxydy89JJJMea\nADBpjUFp7cEWvhIEgWvOHUd28zb8y91tTUKK9qgfPbrb36jRqtHp1V1O4hpnHdsqd5FkSmRGYss5\nmYqgcDVHXCjklTbw5to8Ik1a7rlsKnqtHNaZHZXB5MRxHKo7SnFTaafPC/ZlNuo1/OiyXMwGDa9/\nfoyiyraC1V0s90s/OrPLdiC4sgJbKrZT46pjfuppxBlboioMGfLC4TrFzv7p1uJANNTV57SEWy4e\nfSZmjYmNZZvxdGKeCzjSgygtMTrJwg8vnoTb42fVf/bjPCUM0l1UhMpgQJuY1G1bxiAW+k8L5UV+\n+ZiLAou8oFJhGD0aT0V5mxpCoiTx9Bs7qWkO7Zw+Nl6+XxC4NOdCAD488Wmnz+rJwe4Lp49i4fRU\niqtsvHqKwuOrq8NvberWDKNgMssZ2X7fwBzc3h+EBXsn7Ks5RLn1JHOTZ7aJzRXUavTpo3GXlSF6\n5IknShIvfXQoMIFnnFIKYHH6AowaY7PW3vEWvCfHf2k1Ku5cPoUIo5Y31+ZxpJVtUXHkBaOhKc/r\n6kX+vOhLREnkgszFbcLZFE3Y3cq+3OTw8MIHB5GQuGP5FGIjDW3aWjHxAgA+K/qy0+c57B50ejUa\nbfe1t+Ojjdy8dBI+v8hf1+zH4WoxIbiKikClQp/eteMUujdJ+UU/nxauR6vSckHm4jbXlJ1L63E4\nWlzPuxtPEB2h47ZLJreJ1derdZyZdgZ2r6NTW3tXjvSOmDMhkQtPG83Jeif/+KRFqIkuJ57KCvSj\nM7p1nEL3C1yNs5a91QcYbRnFhJi2NWf0ozNAknCXtSzaH35dwK4jVUzJjm2XVDU+dgzjonM4Up9H\nqbX9jgtaFcULsk7M984dR05qJN8ePMmGVgpPixkmSMHeA9/TUCUs2DtAkiQ+L/oSAaGNtq5gyMgA\nUcRdKk/iT74tYn9+LZOz2k9gAKPGyDnpZ2L3Odhcsb3DZ/a0VG1spIHbL52MKEk8+c8d2Jrtoorm\nqE/vXmMHWai5HF78HYTN1bsa+LZiB4nGeGYlTmtzTW0yoU1KDjhQlcWt3upm5VnZjB8d0669qUkT\nyLCks7f6AJX2kx32x9HDEq3TxsSzdF4G1Q0u/v5fOUJEEkXcJcXoUkeh0nbfVncF0fbWHKTe3cAZ\nKbMD2aMKp2rsjTY3z39wEAGB2y+dQmQHv2Vh2jw0Kg3rSr5ClNqPe7C7ltasaM5e3XGkivW7ypr7\nJDtOgxVoxm58DV+WfI2ExOL0s9qZiJQdorJjPFhQx0ffFJIYY+TWZZM7DHFdPPpMud3Srzt8Xkeh\nr12h1ai4Q1F41uUFdnGKshOMWQ5GRmRMWLB3QF7DCYqaSpgzahrJ5vZpywFttaSIYyUNvL+pgBiL\nnluWTeo0RvvMUWegUWnYVLq5y5fZaAo+ZX5SZizLF2RR0+Dk7/89hF8UcRcVoYmL6zSF/lSUhcTl\naO8w21S2Bb/k57yMRe2ybEHeFYgOB97qaj7eXMjBgjpyc+K48PSOXyBBEDg/82wkJD4v2tDuut8v\n4nL0/ASp5QuymZgRw57jNXy2rQRPZQWSxxP0rkWtVmHo4gShDSXfALAwbX67a5rYOFQREbiLChFF\niRc/PEiT3cPli3IYlx7dYXuROgunJc9s1oAPtrvem8ObNWoVt186BYtJy1vr8sgvbwoqMak1gRju\nDsbB4ZWVkmh9FDMTc9tdNzSbvNwlRdRb3fzto4OoVAIPXzen0zIQk+MmkGiMZ0flbqye9v6B3pz5\nGhtpaN7FSc27OF/LLjZowa4cQhMW7COKdc2OrUsnLunwuiLYrScKeOED+bT62y6ZTKSp8wkYoTMz\nO3E61c5aDte1P4HIYfc0F/rv2Z/k4jMymT4ugX0nalm34SB+a1PQmgl0blf1ij42l2/DrDExO2lG\nh99VIi3yt+9nzdcFxEbquXlp54sbwNT4SSQa49lZtReb197mmtOhnCDVs6p6ijM1yqzjvY0nKN4j\nH9emzwh+HMzmjk8QKrGWcaKxgImx4zpc5AVBwDA6A+//Z++9oyS560PfT3WOk3ty3JyjNiqsJAQS\nCiRjHgbDRRhjHDg8Xb/jc1+wr6/TxX6PCxiuMRgso4vBZIQQKGu1knalzTnvTs6xezqHqvdHdfX0\nzHRPV3XXzG6P+nMO54jpqq7f/vpX39/3942jo/zqlYtc7pli++oaHtzdsuDz3tVyDwICL/W8Ns8B\nnm+jkUq3lc8+thFRlPjGL87jV8JeVUTEwMLRQW8OHCWaiHJv850ZN3lLQ4Ncvri7i2/98gLTwRgf\nuX8VazKc3BQMgoF7W+4iLiV4vf/IvM+12NjT2bKymkf2yae4J399iXBvD6bKKoxudQW0SqaYZch4\naIIL41foKGtldXXmI6y1sQmMRgbOX2HKH+VDB1Zk1c7SOdCyH4BDfW/O+yzfLjEGg8Cffmwn5S4L\nxw+eBtRrJpDdvnxq5Cz+WIC9jXdgMWbWuJQN5PQbZzAIAn/4/k05i3QZBAN3Ne0lLsbn2ZgLaVpc\n7rTw2ffJpqmzb+QxD65kB6HobOfjweRvdW8GbV1BmYdTh85QU27j04/M7zE7lzpnLZtq1tPl66Fr\nTmRIPqYYhY0dVTx2ZzvjvjCjF6/JjlOPumJZNocFQZgv0BJigoN9b2I1WrizcU/GewWTCUtzC6He\nPq71TLBzjYcHdub2b+yp34ndZONQ/xFi4uy5L2QePnB3B2tbKrh0sYfE1JSqYAIFRw7TXDFQEuxz\nODxwFAmJu5Lx2pkQTCZC5R5c02NsX1HFQ3vULZpWdzMdZW1cGL/CaHCmNEAsliAaSeTdJabCbeVz\n79tIXVj+zkRt5iYCmZhJzpn9Mh/qO4KAwN2N2WvLm5plrbQiMMZv37eKlU3qKlDubbgDs8HE6/1v\nzTJL5auhKaxvq+T9d3VQ4RtBQsDctLDWnI5ycvGnNTKejvo5PnyaWnsNG6rnt1JTiNfKpYwbouP8\n4Qc24bSpM6cdaJI3+jcH3p7190Ln4X13drCp2YUjMEmgok6V4xRkJcHumF8Q7ezYRaYiXvY27MJh\nzl4CIFBei0FMsMYa5vGH16mqm24zWdnfuJvpqJ+Tw2dmfabFkT4Xo8HAH7x/Ix2CbGcPVOSOClIo\n2diXGQkxwZuDR7Gb7OyY4yxM5/zNca7HnJilBJ/YVampTviB5v1ISBzqnwl9DGl0EmVibWslO8rk\nF/IHF0JZMzLnkmkR90730+nrZn31GjyO7DVNfnFsiCmTi6b4FA/snN9PNBtOs4OdtdsYC41zZWIm\nwUirsywTj+xppSE2yZiljGeOZY62yIQjJdhnopYODxwlLsY50Hxn1gJXsbjIDy7K9+yqiNHRoL5F\n5NqqVVTbqjgxfJpQfCaxphBNFWQB/cntZRiQuBiyc6l7UvW9mWK4lY3nrizaOsDIZJBDI7IA/vB6\nKw6VmxvIG5yAwBsDb836e9BfWK/TCpeVRzrkMT3fk8AXVCeoS6aYZcaZsQtMR/3srd+Z1fwwPBHk\nn5++wIhdFniG4eylBTKxvXYzZRa3XNI3IduU83GWZaJ8eoSIxcGZ4Rg/Oaiu6l0mU8yhPtneqWiU\nmXjr4hDPvd2D112DNRpE1Nip/u5m+USUblstVKABJMZGMSViTLk8PHO4S3VFzJR9OdlvVZREDg8c\nxWIws6dhfmlihe+/dJXzkxA3WamYHtE0VoNg4M7G3UTFGMeGTqX+rkcTa9PYIAAjtiq59rvKcscO\np4V4TCSajIcfD01weeIaK8rbaHRlLiIWiSX4p5+fp9con9hck5kjnrJRba9iXdVqbnq7GUxGSyUS\nIuFQrOAuRmU+OSP3pljGN35+XlXRNKvNJNfoLwn25cGb/UnNpCmzZhIMx/jqT84SjMTZcfc2gJyF\nj+ZiMpjY23AHoXiI06Pn5O/N01mWTsLvJz4+TvmqFdRVO3n+aC+vn82tsc7VTkLxMMeHT1Ftq8xq\nfugemubffn0Zm8XI6js2AvMTdHLR5m6hxd3E2bGLTIbldmip7NsCBFqkV/491u/ZjNlk4F9+dYHB\n8UCOu2bmwZ8U7NcmbzIWnmB77Rbspszmh4On+3nt9ACtdW6cHe3ERoY1N/ne27ALg2DgjYG3U05U\nPZpYK+ty54Ht+EMxvp4hIzMTc9fD4cFjSEhZbeuiJPHtZy7SM+Jn3a4NcvXTXm3vBMD+xt3y8waO\nAoX5W9KJ9HZjcDhZtbGdK71TfO+FqznLM880tS4J9qJnJDjG5clrrKrooN453x4nihL//MsLDE0E\neXB3C3fcK0eKaBVoAPsa5GbYR5LOQz00VeVlcrS384UPyxmZTz13Jecx3GY3z3KYnRw5Q1SMsa9h\nd0bzw4QvzNd/dpZoXOSzj22kZu2qWc9XiyAI3N20FwmJI8nYfj02OGUc9RtW86n3riMUSfA/fniG\nqRyOsJQpxidfd3hQFjCKwJnL2RtjfO/5q7jsZrm+fFurnKDTp635Q7nVzZaajfT7B1NO1KA/e+Er\ntUR65cqW++/blsrI/PavLuYs8JV+gkuICY4MHMNusmUMcQT46cEbnLg6yrrWCn7noY1Y6hsI9/Qg\nqTQFKmyp2YDL7OTtoRPExLgu74QYDhEbGcHa2spnHt1Ia62LQ2cGePlE9sxnBSVxr1g7w5UEexJF\nuNzVON9pKkoS//aby5y/OcHmFdX89r2rMNrtmGvr5BK+Gn/8WkcNqyo6uDp5nbHQuC6mmHBaEkZ9\nlYM/+ZCc/v8/f3aOgbHsGqviMFM0pCMDxxEQ2Nuwc961/lCM//GjM4z7IvzWgRVsW12TSoTKVbo2\nEztrt2IxmHlr8ASiJBIMROXa2xra380lPUFr38Z6Pnh3B+O+MF/58Zl56fbppNvYg7Egp0fPU+fw\nsHJOZySQW9v90y/OYzQKfOHDW/BU2LG2JHMbNJ7gYMZ2/cbAW8RjCaKReEFrQUomz1kbGzGYzXz8\n3WtY21LBiSujPPX8lQXXa7rGfmH8Mt6oj11127EY54/ntdP9/ObtHuqqHPzRBzdjMhqwtrUhRcLE\nRrSZY0wGE3sadhKIBTk7ekGnTb5PTtBqacVqkes3lTkt/ODlaxy9tPD4lBr9UQ2dqm4nSoId2Z76\n9uAJ7CYbWz2bZn0mSRLfe+Eqb5wbpL3ezR+8b2OqNrS1pQUxGCA+kbv5xVz2N8ia4JHB4/os4jkZ\np2tbK/nUe9cRjMT5hx+con8B4a5oJ0OBYTp93ayrWk2lbXb4Zjga58s/OsPAWID37Grh4WQSkqmq\nCoPTSaRXm6YKcvOFHbVbGQ9PcH3qplx72zm/9rYWIr09mKpmErQe3d/OPVsb6Rn28z9/fo5INLM5\nIt0Uc3T4FHExzr6GXfMiO/pH/Xz1x2eIxUU+976NqUggm5J5mYcZQnaiVnJy5CyTPjlRpxDBHh0a\nQopGU2vBZDTw+d/aQmudrLH+7FDmKozpzw0GoryZNItkMsO8fnaAp567gtNm4n//7S2pMFdbARuc\n8k4cHjiqi8Ye7p1dN6m63MYXPrwFq9nIvzxzMWtFTCj+FnklwQ5cmriKN+pjZ922WU5TUZT4wUvX\nOHiqn5Zal1z7Oa0LUioDNQ9tdXvtZmxGK28NHk/VAS/UFCPHLM/UqblzcwMff/cafIEo//D9k/SN\nzM/uA7C7LMSiCd7slU1DiqlIwReM8qUfnqZz0Medm+r5yP2rUgJPEASsLa3ERoZJhNT3I1XY23AH\nMLPBFTIHce8UCa93VsyyIAh84sE1bFtVw8WuSf6/H55KlV9Ix2I1YTAK+H0RDg8cxSAY2DPn1HKj\n38sX//0kvmCM333PWrantXSzNDSC0ZiXYDcIBvbU7ySaiHKm/zKgjzkqPVHNYTPxnz+yLdV16Iev\nXEPMoLkr8z/pnebixBVa3U00u2eHzx483c+Tv76Mw2bi//joduoqHanPlLkP5zEP9c5aVpZ3cHny\nGmNTXnk8eig7afPQ0VDGEx/Zislo4J9+cT6rcz1XeYXbHV0E+6FDh3jooYd48MEH+da3vqXHVy4p\niq17X1LIAATCMf76X9/mpRN9NNY4+dOPbpuXfKMkwITz0E4sRgs767YxFfEy4Z1esPZ2LhKRCNHB\nQawt87vkvGtnM598cC3TwRh///2TnL0xfyErNeBP9VzAaXKwxbMx9dnAWIC/+e5xbvT72Luxjk89\nvG5eeKcyD1GN9mWAVRUdeOzVnB68KNfeLmhzk58/t06O0WDgjz64ib0b67jR7+Pv//0k497ZxdgE\nQcDhtOD1Buj3D7K5ej1llplMxbM3xvh//+MUoUiC33tkPfdtnx3eKZhMWBubiPT1IiXU9SJNZ09y\n7V3ol7OSC5qHLPWCypwW/vSj22iodvD80V65xO2cE4wiSPvGhxElkb1pm3xCFPnF6zd56rkruB1m\n/uxjO2irn53NaU3mNuSzwQHsa5Sf1z0qO/4Lm4ceBLMZS33DrL+vbq7gCx/egtEg8LWfnuVXh7vm\n+R6cRR7yWLBgF0WRv/7rv+Y73/kOv/rVr3j22We5cUN9g9lbTSAW5NzoBeqddbS55UV5Y8DLX/3b\nMY5fGmZjRxX/5eM7MpYLSBU+ykNjB9ifXMTT06FUE+t8CPb0yl1yWjIn5Ny7vYnfe2Q9kViCr/z4\nLN9/8eqsRsj25CKOhBLsqt+O2WAiIYocPNXP3/6vmbKrv//ohlmVChUUAZKPI1kQBFlrj8jfq4um\nmqEAmslo4DOPbuCBO5rpHwvwF//6NgdP9c/SWh1Oi6yhSTMCJhiO893nLvOVH59FkuBPPrSZOzc3\nzPt+5blSLEZ0WJt9GaDGXsWaipWMTckRQos1DzXldv6vT+xkXWsFp66N8bf/6wRXe6dSnyuCdGzK\ni0kwckfdtuT/D/H3/36KX77ZRXWZjT/72A5aaufXIzK6XJiqqvMW7Ns9m7EYLYwm5yHfkE8pHic6\n0I+lqRnBOD/BaV1bJX/2sR1UuK387NBN/vt3j+JNc7Ar85CpzEQxkJ+KmMbZs2dpa2ujqUnWYB55\n5BFefvllVuboDH+7cGz4FHEpwd76ndwY8PGrw12phtEfeWAN79nRlNXmayqvwFhenpd9GeSQv3pH\nHVLEgLUy/58ikOzmtFBFxzs3N9BS6+Kbv7zASyf6OHVtjPt2NHHXlobUIjbFrGyr3s6JKyM8/UYn\nfaMBrBYjv//oBvZtyt4I2VqAfRnktPJXzsjOaz00VVuWeTAIAr/zrtU0e1z88JXrPPX8FY5cGOL+\nHc1sXVWN3WkGUaDcUEGdqY3nj/bw/NEepvxRmj1OHn94/YIJSNbWVjgsR6RYG9Vn/yrsbbiD586f\nBPKfB0mSiPT2YK7xYHQ4Ml7jtMlNsb//4lUOnh7gi/9+kl3rannPrhbaG9wYTQKJMGz2bGRiUuSn\np65w5PwQkViC3etr+eSDaxdMQLK2thI4fYq4dwo86uqzKNhMVnZ4tjB8TijIkR4dHJA7aC1QVmJF\nYxn/9VO7+Oenz/PW+SFOXh7hwLYmHtzdoqo2/e1MwYJ9eHiYhoYZDaauro5z584V+rVLxuHXbuK2\n1fLTX0SIhk4AsKa5nA/cvYK772hldHThxtHWllaC58+R8PtVV1RUEASBXVU76ZQgYgzm/W8I3OyS\nx5KjNkprnZu/+NQufn7oJgdP9/OTgzf4+aGbNDsEagFLqIIvfvs6kgQCcNeWBj50zwoqciSJWOrl\nAlD5OMwAKm0VtFrlscfN+dfnCPf2YLDbMdXUZL1GEATu2drI5hXVfO+FK5y6Nsa1Pi9mk4GV9ghu\nrIhDTfyXf5ZzGkxGgQ/e3cF797blbEg+43PpgT3ZSzFkY1vtZl6NyzZ2m4Yqn+nEp6ZITE9jX71m\nwetMRgOffGgd+zc38IOXrnHs8gjHLo9gtRhZb4hgilk5fzzB4SHZgVpdZuV337OG/Zvqc54srS2y\nYI/09sIq9WUdFPY27OTZ2GWwJPJ2pCvm0VzlqxXz1MkbE/zwxSu8eLyXF4/3Umkzsgq41jfIPopD\nSU2nYMGeb5ynR+NOvliU9zVisVThrK6mbUMZ79ndxuZVM4Ih1zgDa1cRPH8O2/QYFR2Zj+gLsT+4\nk05OMMl43nMyeLMTwWikactaDJbcmt7nP7qDx9+/mVeO9/DayT68kWvgr0OYqmJ9exWbV9Zw59ZG\nOhrV1X4BGGxvI9DVTXWFDYM5u1DK9m9cX76Wq/gYZgCPJ3Ps+EIkwmGuDg9TtnEDtbW50/o9Hjd/\n9bkauod8vHlmgDfODBCMD+CmkcRoLVtX13Dn1ib2b26gXGX2Y9yxnj5AGh7I+7f0mGqJAn7HOOs8\nC6+nTM+Y6L4KQNW61arG4PG42bOliWMXhzhxeYSzN4eJ+idwBMqxhl3csb6Sh/a2cceGeowqhaxh\n01omngHT+FDWcS5Edc0WXox3EbIFcFeYsZltuW+aw/SY/Oy6LesoU/H8h+vKeffuVl4+1suxi8Pc\nnOwkOi4QIXbbyCotFCzY6+vrGRiYyXAcHh6mtjZ3NblcmvBS4XY5KBMcfOKTM45TZWwejzvnOMUa\n+eUbOXeZWEO75uf7RuQ42QlxnDOd17KmbWdDEkUC3d2Y6xsY90YA9RrvvnW17F3r4YsHj8BwHfva\nV/LgYzPt77T8RoaGJqTrNxg4dzWrlrTQfNojZYCPs5PnGRrOXP99IUI3roMkYahv1DRuh1Hg3Tua\n2L3RzZd+fhTGG/n9d+1gzUY5SS0aijIaUn8cN9d4mL5xk5ERX14+E1vcSVgI8XLnm7Q5s5/Ass3l\n+DlZ449X1WmahxV1Lvl/66Z58RdhhEAl//UTu1MmoYnxzBFVmYiVy9FCE5ev0Yz2dz0WjSMkjMRM\nYV64eDjl79DC1JVrIAiEnFVEVDzf43EzNRlk56pqdq6q5t8vXeTwwDH+eNunbxtZBeo3yYKdp5s3\nb6anp4f+/n6i0SjPPvss73rXuwr92iXD6ZKTc/I9eaQch3nalxUbXswS4a2hzK3SFiI2MoIYDmsq\nS5pOr7+fgbhc7yYRzj/LrpAIIYBIQN7gvMIklyauar9/AYehGo4NnyJqliNlCnGYWVtaSUxPk/BO\n5b44A2JIQLLEOTt2nmBMe/io1m5BczkyeCxlDss3httUXYPBbi/4nYibI6mINS2k/Ax1dRhs2rX9\nSCLKyZEzVNrKWVe1OvcNtyEFC3aj0cif//mf8+lPf5pHH32URx55pGgcpyB73RMFZJjJHdqteduX\nlZfHZIWjQydJiNpC5RSBls1hmIu3Bk8gGuIYjIU5itK7SuVD+sv81tAJzfcXItglSeLI4HEkc2zW\nWPIhFcedx3qQJEmO5XdZiIlxToyc1vwdkd4ejC43psrsDS6yMRme4vLENdxuuTZOvvOQym0YHiYR\nztzjdyGUd6LM7eCGt5ORYPZEokzEx8cQQ6G834nTI+cIJyLsadiZtarn7Y4uo77nnnt4/vnneeGF\nF/jsZz+rx1cuGWo61C+EYDBgbW6RO7THtH+H8vKsbmhjOurn4sQVTfcXItBiYpzjQ6dwW1w4XbaC\nsuysTc0gCPlvcIEoJrOB2rIazo1eIBDT5kyO9PSA0Sg3QdFIl6+HocAwq+rlzamgeSigxEIkHEcU\nJWoqKhAQNGuriWSbQmtLa15moLcGTyAhsaJWbpBR8AYnSQS7ta8H5bntHvm31DoPhZ7elAYwe+vv\nyHHl7Utxbkc6okdYk7W1FUSRaL/6+t8KyrH/jhbZtq2kcatFrfc/E+fGLhKIB9lVvx2nq7CiRwab\nDXNdHZFe7bVzYKb29r6GO4hLCY4Pq9dWpUSCSF8v1sYmBJN2t5FSUXBfm1yeVw+NPZ/QT+W55W4H\nG6rX0u3rZcA/pPr+mYxT7WtBlETeGjyGxWBmbf0KoHCTFID/Zqfme5WNdVVdG3aTnbcHj2s6yabe\niTzMUWOhCa5O3ZAT5xboRXC7844X7Hp0S0lpaXmYIZTnrqxrodXdxIXxy0xFvKrvj/T2YPXUaA61\nhJkyxXc27sbutCBJEM6Qbq8WW2sbYihEbEzb0VkUJULBKA6XlV11OzAIBt5MK2Obi+jQEFIslteL\nHI6HOT5yhipbJRtqV2O1mQpaC6bKKowud14ae3oxOKWsw9z2gQtRiGC/PtWZKlNcWe6aNZ58UHwu\ngc4uzfcq819WZmdX3Ta80WlNJ9lCNPa3FW29QbvD9naiJNiz9PzUQiGOQ7n9lwmTycj+xt2pgmRq\niHu9JLxTODsy92ZdiLHQOJcnr7GyvJ16Z50uRY9mzBDa5iEciiFJ8m9RbnWzuWYD/f5BeqZzl1eV\nn9clP19D82qFkyNniSai7G24A4NgwOW2FiTYBUHA2tpKbHSURDB3Hfh00ovBba5Zj9Ps4O2hE6q1\n1ZlSAtrnQaluuq9hly6nWKV2TiAfjT2tAJgSEXNk4Jjq+yM9PRjLyzGVqw/XBbmD2uHBY1iNw/az\nyAAAIABJREFUFrZ7Nue+4TamJNh1qAlhaWzKu8FAeu3tO+q2YTaYOTxwdFYv0GwoJwRnR7vm5x5O\nvihK5T5dTi55OlDnNthQxvRG/1tZ70lH2VBteQi0wwPHEBBSdYJcZTbCwRgJFZ12sjErUUkDMxq7\nFZPBxO76HfhjAc6MXVB1f7inB8FiwVKvLWQ2FA9xauQcHns1qyo6sCeTowra4JK1c4Ld3Zpr56QL\n9hZXE02uBs6NX8IXzR12mPD7iU+M56WtX5y4wlTEy676HdhMhXVuutWUBLsOGrvBYsFS30Ckt1dT\ng4FU+6/kGOReq1sYC09wbTJ7aVWFcHdSsK9coWm8CTHBW4PHsJvsbE82ULiVGvvcssXrq1ZTZavk\n+PBpQvHcURWRnm4QhKy1crIxmFamuMomR5G43PILHQ7mb5KaqSFU2DwovQFe7zuS9R4FMRYjOjiA\ntblZdfNqhbcGTxATY+xv2I0gCBiNBmx287ym1lqxtrUhRqNEhwY13Rf0y450s8WIIAjsb1B/kk1F\nieVhlns9qUjcnaEnQ7FREuw6VXGztrbKDQZG1duXQ0nh4XDOZGoq2qrSwWchlKO3a4U2wX5+/BLe\n6DS767enyhTrobGbysowVlRoPrnMbTSS3gv0+PCphW6diVmu1R6zrPgY0rskKYK9kHlImeY0nlzm\ntoOrd9aypmIlV6duMBRYuLBYdKAfEgnNZhhJkni9/wgmwTgrEShTU2utKPMQ6dY+D+lF8XbXb8ds\nMPN6/5GcJ9l87eujgXEujl+ho6x1XpniYuQdL9hNJiMWa2EOM8jPgTpjgpg59q0ob6PeUcupkXN4\nIwsfPSM93Rhdbiw12rz3mRooOJNp84U2FrC1thGfnCQ+rb65daZGI/uUXqD9CztR42NjiMFgqtGF\nWsLxMEcGj1NucbOlZkPq704dBLu5tg7BastbY7enbfR3N8s1Z17PYZbK13F6ZfI6w8FRttduxW2Z\nccA7nBaikQRxFX1Ss2FtawcgnPSBqCEVy59WBM1hdrC7fjvj4UkujF9e8P5wlpLFuXj55htyb9em\n4tfWoSTYAXRpXGtTFnFXl+p7UppqmkATBIEDzXeSkBK80Z/9CJ4IBOSY5bY2TTHLI8HRlGbS5Jqp\nRTLjMCvw+J2HGSJTa8ByaxmbazbQ5x9I9QLNhCI0tEbEvDV0gnAizN1N+zAZZkIkXW7brDHlg2Aw\nYG1J5jZE1X9PwB9JOdIVttZspMzi5u2hE0QS2b8rX8fpoeQaO9A8u2iZLj6X5hbZ96RBY1cc6XPL\n9R5ovhOAg71vLnh/pLtbDr1VUdZEISEmePnmYewmOzuz9HYtNkqCHXkRh0M6Ocw0LOJsLfH2NOzE\nbrJzqP8IsURmW2+mLjlqeLVX1kzua7l71t9TDrMCN7h8en9mm4d7mmRh82rv61nvjeQRsyxKIq/1\nvYlJMHLXHA3NVVa4xg7JVnnJ3qNqCQWiqYQ5BaNBjpYKxcOcWCC2P9zTI/sZmptVP28yPMXZ0Qu0\nuBppL5ut4ephojRYrdibGjU1t86k7AA0uRpYVSF3VxoKjGS8VwyHiQ4NYm1t0+RnOD16Dm/Yx976\nnRl7uxYjJcGOPkX1jQ4H5to6wt1dquOvszWxthot3NW4B38skDVRJ9zdBYBNQ4ifPxbgyOBxqmyV\nbJvT29VoNGBzmAno4GsArSappAliTqnatZWraHY1cnLkLGOhiYz3ztRGUX/0vjRxjZHgGDvrts0y\nP0CaYC/UcagxQkh2pMczNpa4q3EPAgIH+97MuLYkUSTS24uloUFVdU+FN/rfQkLinub98059ejWa\ncK1ckWxunVkYz2WhXqeK1n4oy0k20tsjN5xJnp7VIEkSL/a8hoDAPc3aSy3frpQEO/o5UG1tbXJz\n67HMfRTnEsiiqQIcaN6PQTDwat8bGV/mGYHWrnp8b/S/TUyMcV/znRmrJzqdloJfZHONB4PDkYrY\nUUMwEMXuNGOYo2UJgsC7Wu9BQuKVLFp7uKcHU2UVJnfuUr0KB/veAODepKBIJ2WKKXiD09YPN7TA\nWqi0VbCzbiv9/kHOj1+a93lsZBgpEtZkhgnHw7ze/xYOkz3VJSkdvRpNOJOOfbV29oUE+9aajZRb\nynh78HjGaKl8lJ0rk9fpne5nT/N2ah2e3DcUCSXBjj72REhzFiUXWC5CSU3VmaHed6Wtgu2ezfT7\nB7k2Nb/VYKS7G4PdPqt59ULExDiv9b2JzWhjX2PmeucOl+wwixXgMBMEAVtbO7HhIRJBdfVeFmpi\nvbN2K5XWCo4MHMUfm53wIzevntKkrQ/4h7g4foUV5e20ls03WzidFgRBB5NUY5Pc3FqlSWohgQbw\nnrb7AHi+65V5G324S04CsmlIVDvUf4RAPMj9LXdnND/oEQYMssYO6k2UC82D0WDkQPN+wolIRlv7\njGBvVz2+F7pfBeD969+j+p5ioCTY0U+w2zQK9kAggsEgYLVlrm9yX8tdAPxmzssshsNEh4dkW6JK\nx+nx4dP4otNy+QBT5rBAvY7fyganRluNRePEogkcWZpZGA1G7mu5i6gY4/W+2ZEh+djXn+18AYD3\ntN2b8XPBIMz0Pi0AwWTC2tSsurl1LsHe5Gpgc80GOn098zZ6xWFva1Mn2COJKC/3HMJmtKXMG3PR\n6xSrJM+p3uCy2NgVDjTvx2ly8HLvoXlljSPd3QhWG+Y6dQla3b5erkxeZ23lKlZW5Vfm+HalJNjR\nJzkH0h2oXaquV7JOswnnjvI2NlSt5erkdS5PXEv9PdIrN69Wm4QRS8T4deeLmAQj97ZkfpFhZh4K\nFWqK5hjuzJ1OHgwosfzZbcPKZnSw7w3CaUfwlIamUmPv8fVxevQ87WWtbKpen/U6R4EF0RSsrW1y\nc+vB3MXhsvlb0nmw7X4Anu96ddbfw12dsuNU5Ty80f8W/liA+1ruxGG2Z7xGL43d5HTKvqcedb6n\nXPNgM9l4oPUAoXiIV5MmNQAxEiE6OICttVW14/TF7oPAzGloOVES7OinsRudTswejyoHaqZ43Uy8\nb+V7AXj6xq9TyRlhjbVRXus/zER4kgPNd6YyLDOhxNMXHPrZnhTs3SoE+5xyAhm/z2Tj/pa78ccC\nPN89I9RmTBDqErSeufk8AI+teHDBk47DaSURF4lG8jdJyePqmDXOhcgWGZROR3kraytXcXnyGlfH\n5MxkKZEg0tONpbEJgzV3Gnw0EePFnoNYjZZ5kVHpWG0mDEZBl2bO1tY2xECA+MR4zmuV9ZDJiaxw\nT/N+XGYnr/a+ntLatTpOu329nB49T6u7ibWVq1TdU0yUBDv6aewgmyHEQID4+MIO1Eg4jpiQcgr2\nFncjd9Rto9c/wMmRs/K93eodp/5YgOe6XsZhsvNQ+/0LXjtz/C4sIsRUVS1XOFQR069GoAE80HqA\nSmsFr/S+zlhoAkmSCHfexFhRgakid1OJ61OdXJy4wpqKlTm74ug1D6kNrjN3eYhcphiFhzveDcC/\nnvwhoiQSHRpEikZTz8rFq72vMx31c6D5TpxmR9brBEE2Sekh2BVnphqHeiAQxeYwY1ygcbjNZE1q\n7WFe6T2U/O6u5LPacz5DlER+dPVpJCQ+uOqRvGrX3+6UBDtgs5sRhMJty6Dezp7LlpjOYysexCgY\neebm88TFOOGebtXFnp7rfJlQPMx729+FY4EXGfQ7uQiCgLW9ndjYKAn/wr0y1ZggACxGCx9Y9TBx\nMc7Prz9LfHKShNerSlsXJZFfXP81AI+tfDDn9XqZIaxNzQgmkzqTlMr1sKqig931O7g52cOhviMz\np5b29pzPGA6O8uuul3CbXTzQeiDn9Ypg18MkBRBRc3LxR3HmWAsga+1ui4sXe15jKDCcMn+q0djf\nGjxOl6+HnbVbWbMMtXUoCXZgRjsp1LYMaY7DHNrJjKaa+/hcY6/m7qa9jIXGefbys0T7+7C1tee0\nJfb7BznUf4QaWxV3N+/P+Rw9Ty6KoMm5wanUVEGOkFlR3s7p0XN0npdjme0qBPvzXa/S6etmR+0W\nVpS357xeL1+DYDJhbWsn0tebMwM1FIgiCLKSkYsPrXoUp8XBMzefw3dD7g9rzeE4FSWRf7/0E+Ji\nnI+s/cCC2rqCw2VBTEhEwvm1jVRIKTs5NrhYNJF0pOdeC1ajhY+u+SBxMc5TF39EuKsLwWrNqewE\nY0GevvEbLEYLH1z1iOp/Q7FREuxJHAU2tVZIFYDKqbHnti2n89iKB6m113DhzKuy4zRH4a9ALMi3\nzn6XhJTgw2veh9mQu7OQXho7gK09Gb+cQ0tTa4oBeQP+8OrHEBA4d+ol+Tk5BHunt5tfd71IhbWc\nj679kJqhF9wuMR1be4ecgZqjMJrib1FjFnBbXHx8ywcJJyIMXz0jtwTMUdnyzYG3ueHtZKtnk+pa\n44rSESgwWcvocmGuqyfcdXPBDFTF9KVG2QHYVruZXXU76J/sITI4ILcEXEDZkSSJH1/7Jf5YgIfb\nH6DSVqHtH1JEFCTYn3vuOR599FHWr1/PhQvqakbfrjicFuJxkVi0MIeZ0eXCVFOT04G6UHJSJmwm\nG7+36XdpnJDHF2/OrpmIksi/XfgBY+EJHmq7n81pRa4WwmwxYjIb9NXYcwl2laYYhbayFt634iEq\nRmQTj9CcvRJfOB7m3y78AEmS+E8bPqpKS4UZwVKojR3SI4Sy29klSSLojy7oMJzL/Sv2s8rVimPE\nR6imbMGWgDemuvj59Wexm2z8b2s+oNqmrOcGZ1+xEjEUWrCEb0CDeVLhI2veR7vfgiBJRBqqFrz2\nmZvPc3ToJK3uplQo8XKlIMG+Zs0avv71r7NrV3G3kQL9Mu1A1lZFv3/BEr4zyUnqF3Gzu5GdIbmS\n43+E3s7YvT0hJvjZtV9xceIKG6rX8sgK9YkXejrMTBWVGMsrcjpQlSbWFqv6XqUPtNxDw6TERJmR\n73U9k7HD0Hhokq+c+iZj4Qne3XYvaypXqv5+XU8uyRPFQoI9GkkQj4ua1oJBMPCJqvswiXDdHeLZ\nzhczXnd54hpfP/0vxMQ4v7vutym3qs/Q1cskBaROmOGb2edB2UDU2NgVHGYHDxnWAfBi4irnxi5m\nvO5g75s83/0KHns1f7T192YVfluOFCTYV6xYQXt7e8Hmi9sBPe3L9lWyQyZ841rWawIabMsKkiTh\nGJwk6rJxnXH++7Gv8ubA28QSMSRJotPbzd8f/0de7XsDj72axzf8DgZB20+smKREsfDf1NbeTnxy\ngrh3Kus1akI+5xIfGcYUjROsr+T06Dn+7uiXOTt6QY6UiYc5N3aRvz/+VXqn+9nXsItHO7RlFerl\nPAW5hK/B4Vjw5JIyy6k0QSiYBuSNPVBXwW+6XuKpiz+k0ys3E58IT/JKzyG+ceZfEZH47OZPsq1W\nW7s3p1OfujkAthXyxhq+OT+LWkFLQEE65UNyiejBWgvfPPtdXuh+lfHQJJIk0Tc9wJMXvs9Prv0S\nt8XFn2z7zLz6QMuR5b1taSC1iHXQ0uwrZcEeun6dsn2ZE4JSha80CLX4xAQJr5eqHTt5fOPd/MeV\nn/H9yz/l+5d/ikEwpOLc72zczftXPpwzCiYTDqc11dRaq8Cdi629g8CZ04S7unBtnV+PRBQlQoEo\ndU3qtUiYMe9s2v4uhhoDHB44xjfPfReb0Uo4IQsho2Dkd9Z+iDsb92gOZzOaDHJTax0EuyAI2No7\nCF68QMLvz9h0PJDH6Q1InYYevPPjdE78hreHTvD20AncZhfTMdlUZTGY+YMtn8oZ4pkJXcOAm5oR\nzGbCndkFeyCPDU6SJELXr2EsL+f37v5j/vncv/H0jd/w9I3f4DQ5CMTlshaNznr+04aPUmPX1rug\nWMkp2B9//HHGMhS1euKJJ7j//oXjohfC43Hnfe9iUN8oCxdBmj22fMYpVmygz2Ih1n0z6/3RcByH\n00J9vfqGu2NXzwFQvXkDWzfdza6Ojfzowq+YCE4RSUSxGM18eOPDrPdof4kVqmuc3LwyitVsKvg3\nMm3byPjTP8cw1IvnATkZJv07/dMRJAkqq5yanjU9JJfCbb1jO19Ys5rf8j3ED889Q59vkFpnNR5H\nNfd27GNVdXte4/Z43JRV2Jn2hnVZp8GN6whevIB1apjKjoZ5nw/2eAGoayjT9LxY900MFgsb9uzm\nHw17OTt8iYOdR7gwcpXtDRvZ0bCZXU1bqXLk5yS0W+UInXhMLGgelHuHV6/Cd/kKVS4TRvv8jFcx\nLp8SW1orqax2qvruyOgoiakpqvbuYf2qjaxs/L95vfsoNya6uTnZTXtVM4+tfTfbGzbm3OBvN5lU\nCDkF+5NPPrkoDx4dzd2YdimJJ731I8PTqbF5PO68x2ltayd4/RpDPSMZF7HPG8JVZtP0/aOnzgOQ\nqGtO3mfmtzs+OG+chcytYJQXf3/fJEZLYUFTiepGEATGz5zH8eD0vHGODcv/bTQZNI158uIVMBoJ\nuqoJj05jxcUn1/zO7IvE/OZBGaPVZmJ0KMbgwBQm8/xKmFoQa5sAGD59gXjzfFv/0IAs2EVJUj3m\nSrtAsKcX+9p1jE/K2ZdNplY+vroV0vb1RABGA/mtB1GUEASYnAjkvabSf3NjcxtcvETf8XM41s0v\n6TAxLhd5C0diqp83ffQMAIaW9uQ9RvbX7GN/zewSvGNjC+dTFPKuLyVqNx/dwh2L3c7u1Cm0S8G2\nchUksyPnEo8liEYSmk0doZs3wGDQVL1OK3ral40OB9bmFsKdNxFj8xuGaAl1VJDicSK9PVhbWjGY\nc8d858tSOlDzsS37Ll8BScK+Kv/TWS4MSkG06cLnANLs7FnmIdVBSsNGGrpxHZgxf5aQKUiwv/TS\nSxw4cIAzZ87wuc99js985jN6jWvJ0VOgAakXLpxceOnkLdB6urE2NauqCZIvelX1U7CvXo0Ui2Us\njKY11BFk+7oUj2PX2MBbK3rOg6miAlNVNaEb1zPGcSvKRKbyzdnwXZCjP+yr1xQ8voXQqyAazETG\nhLI4UIP++R2kchG6cT2ZCLa8qjMWSkHO0wceeIAHHnhAr7HcUowmAza7WZfQLgDbSlk7CV2fHxmT\nj0CL9PUixWIprWex0H2DW72WqVdeJnTtKuzbMeuzlNPQrX4eglfkZsb2Net0GV829HQcAtjXrmX6\nyGGiA/1yL9A0gn456zS9iXUufJcugyBgX7nI68FlZXTITzQSx2or7IRkqqzCWFFB+OYNJEmaZfNO\nxEUi4Tg1deojVsRIhEhPN7aOFRjMy6OlnV6UMk/TcLosuoR2AZjcZZjr6uRFPEdLyycRQ9FycmWc\nFopTx9hlkDV2QBbsc8hHUw1dvSJ/75q1OowuO8qY9BLsjrXyRhRMjj+dgD+C3WGZ10EqG2Isiv/a\ndaytbRhsmcvu6oWe60EQBOwdK0l4vfMqPeZzig13dYIolswwGSgJ9jQcbqvcQShaWG0MBfuKVXK2\n3eDsbDslo1GTQFM01UW0qQLYHMkOQjpkXYKcqGT2eAhdn2+GCE5re5mleJzQtatYGpswlWkLkdSK\ncnIJ6DQP9qRgV35HBSXrVJNA60yao1Yv7lqAxTjBJTf6K7M3uFSoo1P9O6GYOW2LfGopRkqCPQ29\ntVVbMlEpNCdRSUvhK5CbFQcvX8JUVY25tk6XsWXDYBCwOy26vcgA9lVrEIMBgr19s/4e8EcwGAVV\nha9Arr8jRaPY1y6utg76m2LMNR5MlVWErlyZZa/OJ+s0nDTvLbZ9HcDp1i9JCcCxXi5vEbw0O0M0\nmEcsf8lxmp2SYE9D7+O3suDC1+YIdo2mmEhPD2IggGPDhiWpHe10yZUu9Yp0UgSQ7+Lslzngj+J0\nWVX/mxRtVzFrLCZ6a6qCIGBfu5aEf5rowExHpXyyToNXZbOWfdXiC/bUyUWnebA0NWN0uwlcujBr\nfWk1xUiiSOjGdUzV1arq8b/TKAn2NGZqY+ijnVgam+RFfPH87EWs0XmqaDeKtrPYOF3WlDNLD5Tj\nt+/ijBlCFCWC/ogmDW2pHKdAMuzOoFt0EMxsSKErl1J/05p1Koki4RvXsDU2YCpXn9yWL3qfXASD\nAce69SSmpoilFQTT+k5EeroR/f4leyeKjZJgTyNlitEpblcwGHBs3ETC6yXa15v6e9AvF74yW9TF\n6wYvyZUzHeuWSLAnj9+BaX02OHN9A0aXG9/FGYEWDkaRJPWaqhSPE7p+DUtD46Lb1xWcLqu+Jqk1\n8x2oWjX2aH8fYihE2frsPVv1xKljpUsFx/qNAATSzDFaywkEzstZ2M5N2urfvFMoCfY0UuVaddLY\nYWbhKQsRwO+PqDZBiLGoLNCampdEQwP9fQ2CIGBfs4bo2BjRoaFZ36021DHc3YUUiaSckEuBw2kh\nFNSnIBqAubYWU2XlLDu7Vo1dOb2VbVwawa6EYOql7EBmO7tWG3vg/DkQhNQmUWI2JcGeht4CDcCx\ncRMIQkqwJ+Ii4WAspRXnInzjBlI0uqRHTr01dgDnlq0A+M+ckr9bY6ijEua4FPZ1BYfLgiRBKKjn\nBreOxLQvFSml1d/iP30KBIHKnTtyX6wDBoMBu9Osq0nK7PHIkVKXLyEl5JLLWk6xiUCA8I3r2Fas\nxOhUV1PmnUZJsKeRqsmuo8ZucpdhbWsndP0aYjg0I9BUaqpLbV8H/TrnpOPcvFXe4M6clr97WqOm\nelk24yx2/Ho6etuXIS2e/bL8u2rZ4BJ+P6Hr17CtWImlYum6/zhdVgL+iK5lQxzrNyCGQqkG14FA\nRHUHqeClCyBJODdv0W08y42SYE/DaJS1Ez01dkiaYxIJgpcupR291WmqwUsXwGDAsQQhfgrKpqPn\nPJjKy3GvWU3o+jUSfr8mm2oiECB4+RLW1rYlM0dBWv0gHU8ujqRpzn/yhPzdGrJOA+fPgihmLIG8\nmDhcFuKxwruLzfrOpAkleOkCoigSCsRK9nUdKQn2OSyGdpJuZ1eEhBpTTCIYINzZKadML3KGYTqu\nRTDFAFTt3gWiSODc2Rmbqop58J8+CYkE7juWtlNXyiSl48nFXFWFbeUqQlcuE/f5CGrIOvWflk87\nzq3bdRuPGmYK5OnoSF6XPLlcukgoEEs+J/fpTZIkAufPYXS5sbaW6sNkoyTY5+BcBO3E1rECg8NB\n4MI5/Elh6VIj0E6evCVHTovVhNFk0NUkBVC1+w5AtrPPmCByv8z+48cAcO1cWsGu/EZ+nTc4985d\nIElMnzyhOutUiscJnj+L2ePB0pi9z+ti4FgkE6WtYwWhq1fwDcvlBdTMQ7S/j8TUFI6NmxZsXP1O\npzQzc1gM+7JgNOLYsJH42BjTQxOAOk3Vd+RNAMr27Mtxpb4IgiAnKekYCQFgb2nB7PEQPH+OgC+C\n2WLM2es0EQwQuHgBa0srlrrFzbqdy4wTWd95cN0hb3BTx0+qzjoNXrmMGA7j3Lp9SZLU0tGz92k6\n7n37QRQZPymH86oxTwbOlcwwaigJ9jnoHcuu4Eoen729Q7Oek43Y+BihK5exr1mL2ePRdSxqcLqt\nBANREon5ZWbzRRAEnFu3I4bDBLxBVRqa/9QpSCRwLbEZBtJ8DTpr7OaqamwrVjJ1U85tUGNbDiSj\niVzbltYMA/pnZCuU7doDRiMTV+X67K6yhedBkiR8bx0GoxHHpk26jmW5URLsc9C7NoaCa+cdGBxO\npsenEYTcx07fW0cAKNu7X9dxqEV5mUM6hrmBLJhEDISjkioNzX9CNsMstX0dwGQyYrObdDfFgLwe\nIkbZb5Jrk5dEEf/p0xgcjkUvApeJmdr0+s6D0e3GuXkLfp86v1P4+jWi/X24tu/E5F6aJLVipSTY\n57BYx06DxULZnXcRESzYzCzoLJMkCd+RNxFMpluiqcLiRMaAXJ0yXlkLgMO+cMxyIhggcOE81pYW\nLHX1uo5DLU63VXeNHeSNShHsuTT2wNkzxCfGcW3fiWBa+v7zi3WKBSjbt5+ISW66nsvvNHXwFQAq\n7r1P93EsN0qCfQ56t8hLp/yee4kYnViiC/dfjHR1EhsawrV9B0aHQ/dxqGExQv0ABJMJy54DAJgm\nBhe8dvrYMdkMs8RO03ScbiuxaIJoRJ+6OQrm6hrE2mYArGSfY0mSmPj1rwCofM9Duo5BLQ6XFUEA\n/3RY9+92btlGxCr38XQ4sod8xqd9+E8cx9LQuKTZx8VKSbDPYTGSUhSkihpEgxGzf4LIQH/W6xSn\nqXvfrTHDwOKE+ikIa2THl3TzEmI08zyLkQgTv3oawWymbP9duo9BLYsV+gkgJRtbx46+kfWa0LWr\nhG/ewLltO9amJt3HoAaDQcDhshLw6T8HBrOZmKMKczxE5NrlrNf53ngdKR6n/MB9S+48LkZKgn0O\n9mSjCb1NEEDKlmiNB/AefDXjNdGhQbyvH8JYXoFzw61zEC3m8Tuc/EqzfwLf4cxCbfKlF4hPTlL5\n7gcxV1XpPga1KCeXxbCzx9zVAMRPHSHS25PxmolfPwtA1Xsf0f35WnCVWQn49auboyBJEiEs2OIB\nxn72E6T4/JORJIp4XzuIYLFQtv/WKTvFREGC/R/+4R9473vfy/vf/34+//nP4/cvbGIoBpTO7Ho7\nT2FG+7WbJbxvHCLS2zvrc0kUGXryO0ixGLUf+/gtsacqLKbGrmyaNqJMPv+bVL0QhbjPx+RvnsXo\nclP50MO6P18Li1E3R8E/HUEQwBIPMfqTH837PNLbQ/D8Wexr1t7yZhIutxVRlHR3pkfCcRIJCWeZ\njUh3FxPP/XreNVMvv0hsbBT37r0YHaXaMGooSLDfddddPPvsszz99NO0tbXxzW9+U69x3VIcLquu\njSYUFOHg2b0NKRql/2tfJj41lfp88sXnCd+4jnvXbjmJ5RaSciIvgkBTNovqbZuIjY4y+sMfzGqb\nN/7M04jhMFXve/8t8zEoLKZgD/giuMpsODdsIHjh/KwKoHHvFEPffRK49do6LF6yljJpmH0fAAAa\nGklEQVSvVWtXYKyoYPyZp4mklbj2nznN6I/+A2N5OdXve7+uz17OFCTY9+/fn4ru2LZtG0PJkqzF\njtNlIREXCQVjun6vsohrNq+j5kMfJj4xQf/XvkLg/Dkmnv8N47/4GUa3m9qPfULX5+aDEuq3GCYp\nZR4aH3svloZGpl55iYGvfYXg1Sv0f+0reF99GXNdHRX33Kv7s7WSEmg6z0MiIRLwR3G5rdR8+CMg\nCAz809cY/fF/ELx0kZ6//SsiXZ2U7b8zVV/mVuJMxpj7dbazKxuFu8pJ3Sc/BYkEg//yTbyHXmP6\nxDEGv/UNBLOZpj/5Auaqal2fvZzR7az/k5/8hEceufWahR64ymwA+KZCGC36uSHSC4BVvPcRosPD\n+N58nf6vfEm+QBCo/cSnMLrduj2zEBwuK36f/pEQQX8Um92EraaKlv/z/2Hwm/9E4NxZAufOAnIr\nvdqPfeKWmqIUUmGfOgs0xTnvKrNia22j/vd+n7Gf/pjJ559j8vnnAKj+4G9R9fCjt4WzcLGcyOm1\nk1ybtlF+zwG8h15j+KknU9c0/OGfYOtYoetzlzs535zHH3+csbGxeX9/4oknuP/++wH4xje+gdls\n5rHHHlP9YI/n9hBemahvLOP8yX68UyHWbtQvfjoWkW3JbR3VWG1map74Y3qb5DR5Z1srrlUrsdXn\n97zFmM/KagcTowHKy+w5U//V4vG4CQailFfak2N2U/fXf0H3975P4GYnTR98P+Vbt9xSYZY+l5Ik\nYbYYiYRius5xKOmU9tSV4fG48Tz2IB0P3sfwiy8x+tobNH3wfVTv26t6nItNJCg7NRNxUfNzF7pe\nTMjmzqaWSjweNzX/+fP4H32IYE8Pwd4+XCtX4jlwd/4D12mcxUbOt/XJJ59c8POf//znvPbaazz1\n1FOaHjw6Oq3p+qVEMMpCxTcZ0nWck+MBzBYjvukwJGOCHe95FAAJmAam83iex+NelPlUmh50dY5T\nWV24rdvjcTPQP0UkHMdqM80as/PhD+AEYsDY2K1zwmeaS4fLwtSUvmuht2cSAKNZmPW9pt1307D7\nbkQWfkcW6zfPRizp4B4dntb03FzjHB2SP4snEjPXVTVgqGrAtW2PfM0S/DuXej7zRe3mU5Cd4dCh\nQ3z729/mG9/4BhaL+qbEtzvKsdM7FdL1ewMamzffapyL0CpwOmnaUcxdxYDLbSUcjJGI61c3J6Ch\nyuftgMNpwWAQdDfF+DWUsS6hnoLO13/zN39DLBbj05/+NABbt27lL//yL/UY1y1FKUbkndRPsMfj\nCcKhONW1Lt2+c7FZjIgQxWbvLi8ewZ6ejVxWoU9dfH9qgysOgSYnKVkWJSrGZjdhNqtr7F5CHQUJ\n9hdeeEGvcdxWKCnUemrsWhpL3C4ojkM9X+Zpb1JTLRKBBmkRIdN6CnZlHopng3O5rQwP+BBFCYNB\nHx+IPKfFMwfFQinzNAMGg4DLbcWno8ZejEdOd1Lo6Bnipphi3MUk0Bahbo5/OoLJbMBqu/WRP2px\nlVmRJHRrbB2NxIlFE0VjjiomSoI9C64yG9O+sG71yFM2VZV9HW8HFG1y2qtfyGNRmmIWySTlcltv\ni1BGtSjzoFcIbDEqO8VCSbBnwVWe1E50SkyZidctHuep1WbCYjWltGw9mPZGVNWjv53Q2yQVi8n+\nlmIywwC43PJ49drgis2BXEyUBHsWUtqqTkJN0XqLSVMFKCu3Me0N61Zewe8L43RbMRqLZ+nNJOfo\nu8kXm0Bz6Zx9qrbBRgntFM/btcS4dV7ExSrYXeVW4jGRcKjw8gpiQiQwHSk6TdWuc6hfSqAVkQMZ\n0kwxemvsRTYPxUBJsGfBlXIc6qOx+7xhLFYjVlv2ZgK3I8pGpMcG5/OGkaSZTbNYUJp769Vowl+E\nDmSYEcC6bXAlG/uiURLsWdDz2ClJEtPeMGXl+oTKLSXuVN2cwoWaEj7qKrJTC8gbXGA6qkuS0kyo\nY3EJNLtDPrnodYpN+Z2KKKCgWCgJ9iwojiI9NPZwKEY8JhadGQbSNXYdBHsyfLTYNHYgFb+uh8/F\nX6Q2doNB55PLdASL1ahbHaISM5QEexasNhNWm4lpHbSTYrWvw8yY9Qh59E4GgeJKylFwV+h3cim2\nrNN0nGU2gv4ooljYyUWSJDnkswjXQjFQEuwLUF5h10VTTQn2Isyw01ewh2Z9ZzFRVj5TyrlQ/NMR\nrDYTZkvxaaoutz5hwJFwnGgkUco6XSRKgn0ByirtRCMJIuHCOtQrWl4xCjRZABl1MUEUsynGrZhi\nCtzgZE01UnRmGAW9fE/KWtCrREOJ2ZQE+wKUJxddoTZFRRiUFaFgFwQBV5lVN429WDXVMp1MMak0\n+iLc3CDNmV7gelBOPiWNfXEoCfYFKK9MCvYCtZNitrGDvCEVenKRJAnvVKho58DhtGA0GZj2FmaK\nmYlhL855KEu+E4XWUVI2yJLGvjiUBPsCpDT2As0QPm84lZ5fjLh0iIwJh2JFrakKgoC73Fawxp4K\ndSxSU0x5pbwWCq18OqOxlwT7YlAS7AugaCeFRMYoMezFqqmCPsdvRaAVW1JOOmXlNiLheEEnF8W2\nrJwGiw1XmQ1B0FFjL+L34namJNgXQA+NPZTsvFPMtsRULHsBgl0xRxVzeJvyGxZijil2wW40GnCX\n23TR2F1lVoymkghaDEqzugDu8qR2UsDxu9jt66BPyGOqDnt5cZogANzJzOFC1oMSy1+sgh3ksYcC\nMaKR/E4uibiI3xcpaeuLSEmwL4DRaKCswo53In/tRLElLgvBXsDJxZ/snFTM86BHZIx3MoTdaS5a\nfwvM2MXznQdlHZXs64tHSbDnoKLKTjgUy7u64XLQ2O0OczIiJH9fQzE2sZ7LzMklv40+kRCZ9oYp\nr3ToOawlRzlt5NsTuBTquPgUpDZ89atf5eWXX8ZgMFBdXc0Xv/hFPB6PXmO7LSivcsCNCbyTIWx2\n7ZUZZ2LYi1c7EQQBd4Gx7FMTQSxWE3ZHcVW3TCelqeY5D9PJ6pbFbIaBdI09T8E+mXwninwebmcK\n0tg/85nP8Mtf/pJf/OIX3HvvvXz961/Xa1y3DRVV8uKbGg/mdf+Mxl68tmWQtVUlZFEroijhnQxR\nU+sqqlZwc0nVD8rTBKGY9IpesCshjwVr7MU9D7czBQl2p9OZ+u9QKITBsPwsOxVV8rF5ajI/we7z\nhrHZzUWZbZmOklKfz8s87Q0jJiSqa525L77NcZfbknXltXeUUtaQoiwUKwVr7KnkpJIpZrEoWNp8\n+ctf5umnn8btdvPUU0/pMabbivKkYM/HgSpJEn5vmOpal97DWnIqq+V5mBwPUFOn7d+jnHZqlsE8\nlFXYGBv2EwxENdcR9y2T+ihmsxGny5J3LLtvKoTZYszLtFlCHTkF++OPP87Y2Ni8vz/xxBPcf//9\nPPHEEzzxxBN861vf4nvf+x6f//znVT3Y43FrH+0toL2jGrPFiN8X0TxmnzdEIiFRU+ta9H/vYn9/\n+4oa3uQ60VBC87OuXxgBWJJ50IOFxljXUM7NK2MYMWj+twT9sgN+5eparLbCT3C3ci6ra130dE5Q\nWenAZDIueG36OCVJwucNU1XjpLa2bLGHqYliWJtqybm6nnzySVVf9Oijj/IHf/AHqgX76Oi0qutu\nJR6Pm7ExP+UVdsZH/YyM+DTZiHs7JwCwuyyL+u/1eNyLPp8Gs/zv7uuZ1Pysvp5JAKprF3+chZJr\nLk0W2dzY0zWOzaVN4xwdnsbhtOCbDkGB07AUv/lCOJwWkODm9bHUaS4Tc8cZDESJRRM4Fvmd0Mqt\nnk+1qN18CjKKd3d3p/775ZdfZsWKFYV83W1LeZWdeEzU3OtxYjQAQLWn+G3LTpcFs8XI5HhA871T\n40EEAapqijvMD/KPZU8kRPy+cNE7ThXyLQZWcpwuDQWdB7/0pS/R2dmJwWCgsbGR//bf/pte47qt\nSDlQJ0Ka4rAnxmQhWFlT/IJdEAQqaxyMDfkRRVGTo3xyIoi73JbzyF4MKGtB6wbnm1oeoY4KqVh2\njQ7UkuN0aShIsP/jP/6jXuO4rSlXQh4ngjS3V6q+b2IsgMEgLJuXubLaycjANN7J8ILH73TCoRjh\nYIy6huVhv3SX2zBbjIyPaBPsqVICRR4Ro1Be0thva5ZffOIiUJFHZIwkSUyOBamodmA0Lo9pTkXG\njKkXalMTyRA/lRvB7Y4gCFTXOpmaCBKPq4/pXy4x7AqKxq1VY1fWTrGHfN7uLA+Js8ikkpQ0xLL7\nfRFi0QRVy8AMo1BZo5gh1M+DEuqobI7LgSqPC0mCyTH186AIwOUi2K02M1abSXNew9iwH4vVVNQl\nNoqBkmBXgdVmxuYwa9LYFcfpcnAYKlRWy5uUFvuysgksF40dZpzhym+shuWmsYN8gvNNhojH1J1c\nYtEEUxMhauqKOwO5GCgJdpVUVNnxTYVIJERV1yuO06plEBGj4C63YTQZNGmqisau1iZfDCgJZ+Oj\nftX3eCdDOFyWos9ATsdT70aSYFzlBqfM13JIVLvdKQl2lVRUOpAk9WFuyykiRsFgEKiosjM1HlSd\nUj81EcRqMy2rLEPFvKbWgRoJx5n2qnc4Fws19bJDfHRIXfz32LAs2Ks1Zi6X0E5JsKskPTJGDROj\nAYwmw7Lz/lfWOInHRVWVHhMJEd9UmIpqx7I6elttJtxlVtWmGEXw1TbcXpmWheKplwW0WsE+PlLS\n2JeKkmBXiWJSGVOxiEVRYmo8SGW1A4Nh+Qg0SK8Zk3uD802FEEWJymXkOFWoqnURDEQJBaM5rx0Z\n9AFQu0xCPhUqqx2YTAZNGrvBIKSc8CUWj5JgV0ldo6xtDfX7cl477Q0Rj4vLKiJGIeVAVWFnV65Z\nTo5TBcWBqsYcMzwgrxllDS0XDAYD1XUuJsdyh36Kosj4aICqGueyCf+9nSnNsErsDgsVVXaGB3yI\n4sL25YlRWaAtJ8epwkzIo3qBphzZlxNqHaiSJDEyMI3TbcHpLu6a/Jnw1LkRRSnnBuedCJGIiyX7\n+hJREuwaqG8qJxZN5EzQmXGcLj9NtbzSjsEgpOylCzHQO4XBIFDXWL4EI1taUiGPOQRaYDpCMBBd\ndvZ1BbV29jHFvl4S7EtCSbBroK5ZMcd4F7wuFeq4DE0xRqOB2sYyxob9RMLZ+8DGonHGhvx46t2Y\nLcVfI2Yu5VV2jEYhZ6jfyKDiOF1e9nUFj8rIGCUipuQ4XRpKgl0D9U2y5jnUl93OLkkSw33eZZ1d\n19xWgSTBQM9U1muG+mWTVUPL8tPWQbYvV9Y4mRgLLGiaW672dYXKGtmBOja08AkuFepYEuxLQkmw\na6Cy2oHFalpQY58cDzLti9C6onJZhfiloxRC6+uazHrNYK88R40tFUsypltBtcdJIi4u6G9QNHZF\ns11uGAwGqmtdTIwFsjpQJUlibMSPu9ymS4ORErkpCXYNCIJAfVMZvqkwwUDmMLeeG+MAtKyoXsqh\nLSm1jWWYzAb6urNr7AO98mf1zctTYwdobJM3uO7r4xk/F0WJ0aFpKmtkhWC54ql3IYpS1rj+oD9K\nOBgr2deXkJJg10h9k3ykHs4S9thzU+6a1LqiasnGtNQYjQYaWyuYGg/iz9B8JB5PMDLgo6bOtaw1\ntPZV1QgCdF6d3zoS5HIKsWhi2TpOFXLZ2ZV3Yrmao25HSoJdI3WKnT2DOSYaiTPY68VT75Jbhy1j\nmpPaan8Gc8zIwDSJxPK1ryvY7GYaWysYGZzOuMEp9vXl6jhV8CT/fdlMc9cuDgOwan3tko3pnU5J\nsGukrtGNIGROVOrrmkQUJVqXsRlGoSkp2Pu657/Mg0kzzHK2ryt0rKkBoCuD1t6f7PW63DXVqhon\nVR4nXdfG55kofd4Q/d1T1DeXL9tggtuRkmDXiNliorrWxeigj3BodrhfygyzcvmaYRSqa53YHGb6\nuybnFQQbSDpOl7vGDtCxWhbsN6+Ozvq7fzrCjUujVFTZl71tWRAENmxrQBQlrpwbmvXZhdMDAKze\nUNLWlxJdBPt3vvMd1q1bx9RUdmfacmLNxjoSCYmTR2aaeUuSRM/NcWx207K3qYL8Mje3VRDwR2cV\nRgsGogz1e6msdmB3LG9zFICrzEZtg5uBnqlZG/25432IosTWPS3LNjoqnTUb6zCaDFw6Mzhroz9/\nsh+DQWDlOs8tHN07j4IF+9DQEIcPH6axsVGP8RQFm3Y04S6zcu5Ef6rK4cRogMB0lJaOqmVX+Csb\nTcmwx7PH+1N/e/2Fa8RjIhu3v3PWQ8eaGiRpJjomEo5z8fQADqeFNRvrbvHolgarzcyqdR68k7Lp\nBeTQ38E+Ly0dle+ITf52omDB/nd/93f82Z/9mR5jKRqMJgO77ulATEgcfb0T72SIF56+CEB78mj+\nTmD1+lqqPE4unhrg3Ik+blwe4eaVUeqby9m0s+lWD2/J6Fgja6PnT8kb/cUzA0QjCTbf0YTJtPyy\nbrOxYZu8mV86M5A0ywwCsGrDO2Nzu50oKBbtlVdeoaGhgbVr1+o1nqJhzca6/7+9u4tpMkvjAP6v\ntIDDOKaK06DD6CwOG4gFRhPdgURtbeSjVlFRboymDUZvrCB+hKJGA8aAqJekxAjRZDTK2myI0Wym\nWiEIIsYFN6Q6bHAcjAVRMhSj9OvZC9dO2NJqzOgp5fndnSYn+acfT09P3/c56Or4DY/+PYBfe19g\n7I0H6Uu/mVI/OWXRUuQVKPH3c/fQ+nMvZNFSREmnQZX31ymx/fCOfPYXSPxOjt/6hvGT+Q6ipNMg\ni46aUr9aAEAx7yvI47/Af+zP8fiXFng8Psiio/Dd95F/MUG4eW9h1+v1GBoK/Me/uLgYZrMZZ8+e\n9T/2oafqRAKJRIK/rfwLrl56ALfLixU5yf4Vy1QyY2Yscjcq8Y+f/gXXmAc/qpIi6uDqD5W3KQ29\nPQPobP0Vvw+/RvrSRMTERs6pUR9CIpFg8Y/z0fLPX/DVzFjMmhOHH5Z+G1HHAU4WEvrIavzo0SPo\n9XrExsa+7Y8yMACFQoHLly9j9mz+hmaMMVE+urD/P7VaDYvFgpkzI/8SN8YYC2d/2nXsEolkSm3F\nMMZYuPrTVuyMMcbCA995yhhjEYYLO2OMRRgu7IwxFmGEFXa73Y7CwkLk5+ejoKAADx48EBXlvc6f\nP4+cnBzodDrU1NSIjhNUuPfsqa6uRm5uLtatW4ddu3ZhdPT9B2J/Ts3NzcjJyUF2djbq6upEx5mQ\nw+HA1q1bkZeXB51Oh3PnzomOFJTP58P69euxc+dO0VGCcjqdMBqNyM3NhVarRVdXl+hIE2poaMCa\nNWug0+lQWloKl2vig378SBCDwUAtLS1ERGSz2WjLli2iooTU3t5Oer2e3G43ERG9ePFCcKKJPXv2\njAwGA6lUKhoeHhYdZ0Ktra3k9XqJiOjEiRNUU1MjONEfvF4vaTQa6u/vJ5fLRWvXrqXe3l7RsQIM\nDg5ST08PERGNjo7S6tWrwzInEVF9fT2VlpbSjh07REcJ6sCBA9TY2EhERG63m5xOp+BEgRwOB6nV\nahobGyMiot27d5PFYgk5R9iKXSKRwOl8e+KK0+mEQhGe/SQuXLiA7du3Qyp9e/fcrFnh2ZJ3MvTs\nyczMxLRpb99yGRkZcDgc75nx+XR3d2P+/PmYN28eZDIZtFotrFar6FgB5syZg5SUFABAXFwckpKS\nMDg4KDhVIIfDgVu3bmHTpk2iowQ1OjqKzs5ObNy4EQAglUrx5Zfh2WLZ5/Ph9evX8Hg8ePPmDb7+\nOnQbZGH3+paVlaGoqAhVVVUgIly8eFFUlJAeP36Mzs5OnD59GjExMdi/fz+USqXoWONMxp49jY2N\n0Gq1omP4DQwMICEhwT9WKBRhvT0IAP39/bDb7UhLSxMdJcC7hca7xVs46u/vh1wuR1lZGex2OxYt\nWoTy8nLExobXgSAKhQJ6vR4rV67E9OnTkZWVhczMzJBzPmlhD9ZnpqSkBLdv30Z5eTk0Gg2uX78O\nk8mE+vr6TxknqFD9cLxeL0ZGRnDp0iV0d3ejuLhYyEpusvTsCfWaq9VqAEBtbS1kMhl0Ot3njheU\nyOfsY7x69QpGoxEmkwlxcXGi44xjs9kQHx+PlJQU3LlzR3ScoDweD3p6enD48GEolUocO3YMdXV1\nMBqNoqONMzIyAqvVips3b2LGjBkwGo1oamoK/fn55BtEQSxZsmTcePHixYKShFZUVEQdHR3+sUaj\noZcvXwpMNN7Dhw8pMzOT1Go1qVQqSk1NJZVKRUNDQ6KjTejKlStUWFjo3y8MF/fv3yeDweAfm81m\nMpvNAhMF53a7yWAwUENDg+goEzp58iStWLGC1Go1ZWVlUUZGBu3bt090rADPnz8ntVrtH9+9ezcs\n/w+4du0alZeX+8cWi4WOHj0aco6wPXaFQoGOjg4AQFtbGxYsWCAqSkgajQZtbW0AgL6+Png8Hsjl\ncsGp/pCcnIzW1lZYrVbcuHEDCoUCFoslLBuxNTc348yZM6itrUV0dHgdvKBUKvHkyRM8ffoULpcL\nV69exapVq0THmpDJZMLChQuxbds20VEmtGfPHthsNlitVpw6dQrLli1DdXW16FgB4uPjkZCQgL6+\nPgBAe3s7kpKSBKcKNHfuXHR1dWFsbAxE9EE5he2xV1RUoLKyEj6fDzExMaioqBAVJaQNGzbAZDJB\np9NBJpOhqqpKdKSQwrlnT2VlJdxuNwwGAwAgPT0dR44cERvqf6KionDo0CEYDAYQEQoKCsLyQ37v\n3j00NTUhOTkZ+fn5kEgkKCkpwfLly0VHm5QOHjyIvXv3wuPxIDExEcePHxcdKUBaWhqys7ORn58P\nqVSK1NRUbN68OeQc7hXDGGMRhu88ZYyxCMOFnTHGIgwXdsYYizBc2BljLMJwYWeMsQjDhZ0xxiIM\nF3bGGIswXNgZYyzC/Be68EGj7hfMcwAAAABJRU5ErkJggg==\n", - "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0x7f385e198650\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "def f(x):\n", - " return tf.square(tf.sin(x))\n", - "\n", - "def grad(f):\n", - " return lambda x: tfe.gradients_function(f)(x)[0]\n", - "\n", - "x = tf.lin_space(-2*pi, 2*pi, 100) # 100 points between -2π and +2π\n", - "\n", - "import matplotlib.pyplot as plt\n", - "\n", - "plt.plot(x, f(x), label=\"f\")\n", - "plt.plot(x, grad(f)(x), label=\"first derivative\")\n", - "plt.plot(x, grad(grad(f))(x), label=\"second derivative\")\n", - "plt.plot(x, grad(grad(grad(f)))(x), label=\"third derivative\")\n", - "plt.legend()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "-39gouo7mtgu" - }, - "source": [ - "## Gradient tapes\n", - "\n", - "Every differentiable TensorFlow operation has an associated gradient function. For example, the gradient function of `tf.square(x)` would be a function that returns `2.0 * x`. To compute the gradient of a user-defined function (like `f(x)` in the example above), TensorFlow first \"records\" all the operations applied to compute the output of the function. We call this record a \"tape\". It then uses that tape and the gradients functions associated with each primitive operation to compute the gradients of the user-defined function using [reverse mode differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation).\n", - "\n", - "Since operations are recorded as they are executed, Python control flow (using `if`s and `while`s for example) is naturally handled:\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "MH0UfjympWf7" - }, - "outputs": [], - "source": [ - "def f(x, y):\n", - " output = 1\n", - " for i in range(y):\n", - " output = tf.multiply(output, x)\n", - " return output\n", - "\n", - "def g(x, y):\n", - " # Return the gradient of `f` with respect to it's first parameter\n", - " return tfe.gradients_function(f)(x, y)[0]\n", - "\n", - "assert f(3.0, 2).numpy() == 9.0 # f(x, 2) is essentially x * x\n", - "assert g(3.0, 2).numpy() == 6.0 # And its gradient will be 2 * x\n", - "assert f(4.0, 3).numpy() == 64.0 # f(x, 3) is essentially x * x * x\n", - "assert g(4.0, 3).numpy() == 48.0 # And its gradient will be 3 * x * x" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "aNmR5-jhpX2t" - }, - "source": [ - "At times it may be inconvenient to encapsulate computation of interest into a function. For example, if you want the gradient of the output with respect to intermediate values computed in the function. In such cases, the slightly more verbose but explicit [tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context is useful. All computation inside the context of a `tf.GradientTape` is \"recorded\".\n", - "\n", - "For example:" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "bAFeIE8EuVIq" - }, - "outputs": [], - "source": [ - "x = tf.ones((2, 2))\n", - " \n", - "# TODO(b/78880779): Remove the 'persistent=True' argument and use\n", - "# a single t.gradient() call when the bug is resolved.\n", - "with tf.GradientTape(persistent=True) as t:\n", - " # TODO(ashankar): Explain with \"watch\" argument better?\n", - " t.watch(x)\n", - " y = tf.reduce_sum(x)\n", - " z = tf.multiply(y, y)\n", - "\n", - "# Use the same tape to compute the derivative of z with respect to the\n", - "# intermediate value y.\n", - "dz_dy = t.gradient(z, y)\n", - "assert dz_dy.numpy() == 8.0\n", - "\n", - "# Derivative of z with respect to the original input tensor x\n", - "dz_dx = t.gradient(z, x)\n", - "for i in [0, 1]:\n", - " for j in [0, 1]:\n", - " assert dz_dx[i][j].numpy() == 8.0" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "DK05KXrAAld3" - }, - "source": [ - "### Higher-order gradients\n", - "\n", - "Operations inside of the `GradientTape` context manager are recorded for automatic differentiation. If gradients are computed in that context, then the gradient computation is recorded as well. As a result, the exact same API works for higher-order gradients as well. For example:" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "cPQgthZ7ugRJ" - }, - "outputs": [], - "source": [ - "# TODO(ashankar): Should we use the persistent tape here instead? Follow up on Tom and Alex's discussion\n", - "\n", - "x = tf.constant(1.0) # Convert the Python 1.0 to a Tensor object\n", - "\n", - "with tf.GradientTape() as t:\n", - " with tf.GradientTape() as t2:\n", - " t2.watch(x)\n", - " y = x * x * x\n", - " # Compute the gradient inside the 't' context manager\n", - " # which means the gradient computation is differentiable as well.\n", - " dy_dx = t2.gradient(y, x)\n", - "d2y_dx2 = t.gradient(dy_dx, x)\n", - "\n", - "assert dy_dx.numpy() == 3.0\n", - "assert d2y_dx2.numpy() == 6.0" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "4U1KKzUpNl58" - }, - "source": [ - "## Next Steps\n", - "\n", - "In this tutorial we covered gradient computation in TensorFlow. With that we have enough of the primitives required to build an train neural networks, which we will cover in the [next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/3_neural_networks.ipynb)." - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "default_view": {}, - "name": "Automatic Differentiation", - "provenance": [], - "version": "0.3.2", - "views": {} - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb deleted file mode 100644 index bfcc7feb075c403d024772e0d715339d58877a51..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb +++ /dev/null @@ -1,209 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "U9i2Dsh-ziXr" - }, - "source": [ - "# Eager Execution Tutorial: Importing Data\n", - "\n", - "This notebook demonstrates the use of the [`tf.data.Dataset` API](https://www.tensorflow.org/programmers_guide/datasets) to build pipelines to feed data to your program. It covers:\n", - "\n", - "* Creating a `Dataset`.\n", - "* Iteration over a `Dataset` with eager execution enabled.\n", - "\n", - "We recommend using the `Dataset`s API for building performant, complex input pipelines from simple, re-usable pieces that will feed your model's training or evaluation loops.\n", - "\n", - "If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly simpler.\n", - "You can use Python iteration over the `tf.data.Dataset` object and do not need to explicitly create an `tf.data.Iterator` object.\n", - "As a result, the discussion on iterators in the [Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is not relevant when eager execution is enabled." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "z1JcS5iBXMRO" - }, - "source": [ - "# Setup: Enable eager execution\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "RlIWhyeLoYnG" - }, - "outputs": [], - "source": [ - "# Import TensorFlow.\n", - "import tensorflow as tf\n", - "\n", - "# Enable eager execution\n", - "tf.enable_eager_execution()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "H9UySOPLXdaw" - }, - "source": [ - "# Step 1: Create a source `Dataset`\n", - "\n", - "Create a _source_ dataset using one of the factory functions like [`Dataset.from_tensors`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensors), [`Dataset.from_tensor_slices`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensor_slices) or using objects that read from files like [`TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TextLineDataset) or [`TFRecordDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset). See the [Programmer's Guide](https://www.google.com/url?sa=D\u0026q=https%3A%2F%2Fwww.tensorflow.org%2Fprogrammers_guide%2Fdatasets%23reading_input_data) for more information." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "WPTUfGq6kJ5w" - }, - "outputs": [], - "source": [ - "ds_tensors = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])\n", - "\n", - "# Create a CSV file\n", - "import tempfile\n", - "_, filename = tempfile.mkstemp()\n", - "with open(filename, 'w') as f:\n", - " f.write(\"\"\"Line 1\n", - "Line 2\n", - "Line 3\n", - " \"\"\")\n", - "ds_file = tf.data.TextLineDataset(filename)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "twBfWd5xyu_d" - }, - "source": [ - "# Step 2: Apply transformations\n", - "\n", - "Use the transformations functions like [`map`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map), [`batch`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch), [`shuffle`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle) etc. to apply transformations to the records of the dataset. See the [API documentation for `tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) for details." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "ngUe237Wt48W" - }, - "outputs": [], - "source": [ - "ds_tensors = ds_tensors.map(tf.square).shuffle(2).batch(2)\n", - "ds_file = ds_file.batch(2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "IDY4WsYRhP81" - }, - "source": [ - "# Step 3: Iterate\n", - "\n", - "When eager execution is enabled `Dataset` objects support iteration.\n", - "If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that there is no need for calls to `Dataset.make_one_shot_iterator()` or `get_next()` calls." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 153 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 388, - "status": "ok", - "timestamp": 1525154629129, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "lCUWzso6mbqR", - "outputId": "8e4b0298-d27d-4ac7-e26a-ef94af0594ec" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Elements of ds_tensors:\n", - "tf.Tensor([1 9], shape=(2,), dtype=int32)\n", - "tf.Tensor([16 25], shape=(2,), dtype=int32)\n", - "tf.Tensor([ 4 36], shape=(2,), dtype=int32)\n", - "\n", - "Elements in ds_file:\n", - "tf.Tensor(['Line 1' 'Line 2'], shape=(2,), dtype=string)\n", - "tf.Tensor(['Line 3' ' '], shape=(2,), dtype=string)\n" - ] - } - ], - "source": [ - "print('Elements of ds_tensors:')\n", - "for x in ds_tensors:\n", - " print(x)\n", - "\n", - "print('\\nElements in ds_file:')\n", - "for x in ds_file:\n", - " print(x)" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "default_view": {}, - "name": "Eager Execution Tutorial: Importing Data", - "provenance": [], - "version": "0.3.2", - "views": {} - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb deleted file mode 100644 index 84f1d031d40604ae029e8a8347474950ee01b38a..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb +++ /dev/null @@ -1,485 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "k2o3TTG4TFpt" - }, - "source": [ - "# Training Models\n", - "\n", - "In the previous tutorial we covered the TensorFlow APIs for automatic differentiation, a basic building block for machine learning.\n", - "In this tutorial we will use the TensorFlow primitives introduced in the prior tutorials to do some simple machine learning.\n", - "\n", - "TensorFlow also includes a higher-level neural networks API (`tf.keras`) which provides useful abstractions to reduce boilerplate. We strongly recommend those higher level APIs for people working with neural networks. However, in this short tutorial we cover neural network training from first principles to establish a strong foundation." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "3LXMVuV0VhDr" - }, - "source": [ - "## Setup" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "PJ64L90aVir3" - }, - "outputs": [], - "source": [ - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "tfe = tf.contrib.eager # Shorthand for some symbols" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "eMAWbDJFVmMk" - }, - "source": [ - "## Variables\n", - "\n", - "Tensors in TensorFlow are immutable stateless objects. Machine learning models, however, need to have changing state: as your model trains, the same code to compute predictions should behave differently over time (hopefully with a lower loss!). To represent this state which needs to change over the course of your computation, you can choose to rely on the fact that Python is a stateful programming language:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "VkJwtLS_Jbn8" - }, - "outputs": [], - "source": [ - "# Using python state\n", - "x = tf.zeros([10, 10])\n", - "x += 2 # This is equivalent to x = x + 2, which does not mutate the original\n", - " # value of x\n", - "print(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "wfneTXy7JcUz" - }, - "source": [ - "TensorFlow, however, has stateful operations built in, and these are often more pleasant to use than low-level Python representations of your state. To represent weights in a model, for example, it's often convenient and efficient to use TensorFlow variables.\n", - "\n", - "A Variable is an object which stores a value and, when used in a TensorFlow computation, will implicitly read from this stored value. There are operations (`tf.assign_sub`, `tf.scatter_update`, etc) which manipulate the value stored in a TensorFlow variable." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "itxmrMil6DQi" - }, - "outputs": [], - "source": [ - "v = tfe.Variable(1.0)\n", - "assert v.numpy() == 1.0\n", - "\n", - "# Re-assign the value\n", - "v.assign(3.0)\n", - "assert v.numpy() == 3.0\n", - "\n", - "# Use `v` in a TensorFlow operation like tf.square() and reassign\n", - "v.assign(tf.square(v))\n", - "assert v.numpy() == 9.0" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "-paSaeq1JzwC" - }, - "source": [ - "Computations using Variables are automatically traced when computing gradients. For Variables representing embeddings TensorFlow will do sparse updates by default, which are more computation and memory efficient.\n", - "\n", - "Using Variables is also a way to quickly let a reader of your code know that this piece of state is mutable." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "BMiFcDzE7Qu3" - }, - "source": [ - "## Example: Fitting a linear model\n", - "\n", - "Let's now put the few concepts we have so far ---`Tensor`, `GradientTape`, `Variable` --- to build and train a simple model. This typically involves a few steps:\n", - "\n", - "1. Define the model.\n", - "2. Define a loss function.\n", - "3. Obtain training data.\n", - "4. Run through the training data and use an \"optimizer\" to adjust the variables to fit the data.\n", - "\n", - "In this tutorial, we'll walk through a trivial example of a simple linear model: `f(x) = x * W + b`, which has two variables - `W` and `b`. Furthermore, we'll synthesize data such that a well trained model would have `W = 3.0` and `b = 2.0`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "gFzH64Jn9PIm" - }, - "source": [ - "### Define the model\n", - "\n", - "Let's define a simple class to encapsulate the variables and the computation." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "_WRu7Pze7wk8" - }, - "outputs": [], - "source": [ - "class Model(object):\n", - " def __init__(self):\n", - " # Initialize variable to (5.0, 0.0)\n", - " # In practice, these should be initialized to random values.\n", - " self.W = tfe.Variable(5.0)\n", - " self.b = tfe.Variable(0.0)\n", - " \n", - " def __call__(self, x):\n", - " return self.W * x + self.b\n", - " \n", - "model = Model()\n", - "\n", - "assert model(3.0).numpy() == 15.0" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "xa6j_yXa-j79" - }, - "source": [ - "### Define a loss function\n", - "\n", - "A loss function measures how well the output of a model for a given input matches the desired output. Let's use the standard L2 loss." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "Y0ysUFGY924U" - }, - "outputs": [], - "source": [ - "def loss(predicted_y, desired_y):\n", - " return tf.reduce_mean(tf.square(predicted_y - desired_y))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "qutT_fkl_CBc" - }, - "source": [ - "### Obtain training data\n", - "\n", - "Let's synthesize the training data with some noise." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "gxPTb-kt_N5m" - }, - "outputs": [], - "source": [ - "TRUE_W = 3.0\n", - "TRUE_b = 2.0\n", - "NUM_EXAMPLES = 1000\n", - "\n", - "inputs = tf.random_normal(shape=[NUM_EXAMPLES])\n", - "noise = tf.random_normal(shape=[NUM_EXAMPLES])\n", - "outputs = inputs * TRUE_W + TRUE_b + noise" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "-50nq-wPBsAW" - }, - "source": [ - "Before we train the model let's visualize where the model stands right now. We'll plot the model's predictions in red and the training data in blue." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 293 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 1210, - "status": "ok", - "timestamp": 1527005898290, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "_eb83LtrB4nt", - "outputId": "3873f508-72fb-41e7-a7f5-3f513deefe38" - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEDCAYAAAA2k7/eAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXlgU1X2xz/pAhRautCWUsCwWVlcUHHGBUFQcSg7uM8P\nFLUICo4VpygObihI3UdmUHBB0IGZQbEgFNGCqKgMolV2pKylCy1pukDp+n5/3LxmaUsDTUjSns8/\nbZKXd09C+b7zvvfccw2apmkIgiAITR4/TwcgCIIgnB9E8AVBEJoJIviCIAjNBBF8QRCEZoIIviAI\nQjNBBF8QBKGZENDYE+Tk5JCUlER+fj7+/v7cdtttTJgwgcLCQhITEzl27BidOnXijTfeICQkxBUx\nC4IgCOeAobF1+Hl5eeTn59OrVy9OnjzJ2LFj+ec//8mnn35KWFgYCQkJLFy4kKKiIh5//HFXxS0I\ngiCcJY22dKKioujVqxcAbdq0oXv37uTm5pKWlsaYMWMAGDNmDF999VVjhxIEQRAagUs9/MzMTPbs\n2cNll13GiRMniIyMBNRFoaCgwJVDCYIgCGeJywT/5MmTPPLII8ycOZM2bdpgMBhcdWpBEATBBbhE\n8CsrK3nkkUcYNWoUN910EwDt2rUjPz8fUD5/REREg+eRtj6CIAjuo9FVOgAzZ86kR48e3HPPPTXP\nDR48mE8//ZRJkyaxcuVKbrzxxgbPYzAYyMsrdkVIbiUqKkTidCESp2vxhTh9IUbwrTidodGCv23b\nNlavXk1cXByjR4/GYDCQmJhIQkICjz76KJ988gmxsbG8+eabjR1KEARBaASNFvwrr7yS3bt31/na\n4sWLG3t6QRAEwUXISltBEIRmggi+IAhCM0EEXxAEoZkggi8IgtBMEMEXBEFoJojgC4IgNBNE8AVB\nEJoJIviCIAjNBBF8QRCEZoIIviAIQjNBBF8QBKGZIIIvCILQTBDBFwRBaCaI4AuCIDQTRPAFQRCa\nCSL4giAIzQQRfEEQhLOk0GTi84R7+XbIDXyecA+FBSZPh+QULtnTVhAEoTnx7YzHuDflUwyAlv4z\nizEwfNFiT4fVIJLhC4IgnCWhhw9hsPxusDz2BVwi+DNnzuTaa69lxIgRNc/Nnz+fAQMGMGbMGMaM\nGcM333zjiqEEQRA8TqHRiGb5XQMKjV08GI3zuMTSGTt2LOPHjycpKcnu+YkTJzJx4kRXDCEIguA1\nXJ/8OosxEHr4EIXGLlyf/JqnQ3IKlwh+v379OHbsWK3nNU2r42hBEATfJjQ8wic8e0fc6uF//PHH\njBo1iqeeeori4mJ3DiUIgiA0gNsE/+677+arr74iJSWFyMhI5s6d666hBEEQXMLRjAwW9u3FWmN7\nFvbtxeGMDE+H5FLcVpYZERFR8/vtt9/O5MmTnXpfVFSIu0JyKRKna5E4XYsvxOmNMb53xQhmZh1T\n5Zalx5h3ww08cfSop8NyGS4TfEe/Pi8vj6ioKAC+/PJL4uLinDpPXp73Wz9RUSESpwuROF2LL8Tp\nTTEWmkx8O+MxQg8fIjory67cMtZk8po4z4SzF0+XCP706dPZsmULZrOZG264gWnTprFlyxZ2796N\nn58fHTt25Pnnn3fFUIIgCC7FdhHVXFSZpcHyM8vGqWgKuETwX3311VrPjRs3zhWnFgRBcCu2i6ju\nBp4JDKR7QACZ4RH839dfezAy1yMrbQVBaNbYLqK6AOgaP4L4w7lMSt+NsXt3T4bmcqSXjiAIzRpf\nXUR1LojgC4LQrPHVRVTnglg6giA0WXy1jbG7kAxfEIQmi6+2MXYXIviCIDQZbGvqC41Ggg9k+GQb\nY3chgi8IQpPBMaOfE9vRrq7eV9oYuwsRfEEQfB49s/dbn2qX0V8QEcHiq/7odAWOyWRmxoyNHD7c\nFqOxkPffHwX4uzv884YIviAIPk2hycS/B1/HJVnH2In9StnK7heelWc/Y8ZGUlLGAwbS0zWmTFnO\n/PnD3RK3JxDBFwTBJzmakUHquOFE5mTTpbqaAcD1wDygQ1AQ1UOGnnVN/eHDbcHmHuHgwWDXBu1h\nRPAFQfBJUscNt3a2BJYDdwG9gRNDhp5TNY7RWEh6uvUeoWvXEhdG7HlE8AVB8BkKTSY2PDqVgB+/\nI8ZstvPrg1HCvz22I3ec42rZ5OTBwFKLh1/EggUjqapyTezegAi+IAhejz4pq23aQBuzmWHAAuz9\n+t8CA8mPH8Edya8RGn5uXS7tu7w3vS1aRfAFQfBq9ElZR/vmbuBZwGgwkN0hlqEr19C5a7dGjdXU\nJ22ltYIgCF7NtzMe4xKL2IPVvrkAuAgwjBzDpPTdjRZ7aPqTtiL4giB4NaGHD1GC1WDR7ZvktqFs\nbH85r2eMJiHhUwoKzI0ey2gstBtJJm0FQRDciF5u2anARGZ4BC169eIBlI3TBsuk7MbNPJ60Sdkv\nuQa279CApSxaNOasx7NdbNWhw0mGDn2P7OxImbQVBEFwF/rEbNba1cysqKjZSPyF6mo+GzWW0MOH\nOGHsUjMp62i/HD7cttZK2eTkwYSHh51xXEffftSopaxffyMAERHes/euKxDBFwTBK9D74HwO9u0R\nCs3E11FT71gzbzQW1RJvZ7L+ui4cTRWXCP7MmTP5+uuvadeuHatXrwagsLCQxMREjh07RqdOnXjj\njTcICXFuZ3VBEJo+u7Zt48vRQ+ladpqDBgNRrVtjAIqxL7fMrKfE0rFmPjl5EHfcsY2zFe+6LhxN\nFZcI/tixYxk/fjxJSUk1zy1cuJBrrrmGhIQEFi5cyDvvvMPjjz/uiuEEQWgCfDkmntllp5XMahpP\nnzyJBsQDy4BC/MhoFcawxcvqfH94eFit7P1cxLuuC0dTxSWC369fP44dO2b3XFpaGh999BEAY8aM\nYfz48SL4giCwa9s20sbG0+10qZ110x14JSwMf8LZZO7HKt6G0+Hs/8dSFi3q69S5z0W867pwNMS5\nzBV4A27z8E0mE5GRkQBERUVRUFDgrqEEQfAQZyN8+qTs6VUruUjTOIS9dZMNxAwczN8Pjyc9fXTN\n+87GUz8X8T4XzmWuwBvwuknbqCjf8PklTtcicbqW8xXn1Kmf2wlfy5bL+fe/76p1nPnECVbc1J/e\nmZmUAEOBpajOltHAAYOBsBtvZMz7i1g3ZZ2dLRMXV4qfXxUPPZTKwYPBdO1azIIF8UREnJ+Muq7v\nMisrHNu5gqyscJ/423Cb4Ldr1478/HwiIyPJy8sjIsK53ha+UAIVFeUbpVoSp2uROGuzb18QtsK3\nb18QeXnFtTL/+PIVzMjMtGuN0BUYDsxqFcRfjuQCUFEFs2dfT1mZ1ZaZPXsQ99+/qubCsnWrRlnZ\nUubNG+R2W6W+7zI21oTt/UlsbIFH/zacvdi4TPA1+65DDB48mE8//ZRJkyaxcuVKbrzxRlcNJQiC\nl1DfJKlueRgwcUH6FPBbb+fXtwF+A7a0CuLmVevszlmXLVNX6aQnbRVfneh1ieBPnz6dLVu2YDab\nueGGG5g2bRqTJk3iL3/5C5988gmxsbG8+eabrhhKEAQvoi7hKzSZCN34D17hUfIoYS4VLKu29+t3\nderEHWnfOd3Vsq4Liyfr58/XXIGrcYngv/rqq3U+v3jxYlecXhAEL8VW+ApNJj57YAKl337NDSgp\n7mD5GY+yccotO1FNfn8RuXknSUhY6ZQlU9eFJSlpQ7Opn3cVXjdpKwiCb/LtjMeI/fZr7sKayb9k\n+RkG3AkstuxEFRYRwr33fe60JaNfWPS5gTvu2Far742v2CqeRARfEASnqasMs9BUwOJxCVyanU4+\nUIgSeAOqffFLQPuwMAwDB3N98muYTGamTv2c9evB1pLJyPBvMOOvr++NyWQmKcn36uLPNyL4giA4\nja3g/pqeT+DqP3BJ9UHmY83ql6E2J9GAn4Hc9pezLCqRblRzHX4251iGrbNvMh1mx44nOVPGX59v\n76t18ecbEXxBEJxGF1wDJ5jM5fyjOrNWs7NC4G3gd/zZenUS3/74ol0LY6toK2c/KKiCIUPgwIE4\nsrLOPAnrOHl7/PguDhzowaZNucDnqE488U26AVpjEMEXBKEW9a2g7djhGMb0eG5mHa3R6mx2to6r\nWcUPwCrC9uRbXjEDqaxfD+HhO4GBNe8oKytl69Z8evUKtjuTPglr36++nPbtnyY39yrgJFlZUxg7\ndgFm85PY3mMYjZVn/BzNFRF8QRBqoSySEcA60tPD2bp1CY/cW0nv1GfpDeQBuWDX7Kwc2A2s4l+W\nV04C+ZbfU4E7KS01UFqqERT0DKWlLYGZVFcbyMrSqK5+gVGjate2O9o1YWGvACNrYi0o6ITtPUZY\n2GmSk2+u873N3eoRwReEZsDZZrrKElkH3Ik/WxmSNYaCOdX0B0qAB4BVwNNAJ9QFIB9Y1vZeKNoO\n/Aj8iWuuWU6LFktZvx5KS62iXFraDyXS1ucKC411irGjbw/tsL0TCA8/Smmp9fHAgQE1n6059bp3\nBhF8QWgG1JXpnqk1gdFYyK/pp4nnEv7ATiKBUGCA5edyIALoDOwliNfIJCzsM7744g/MmfOz5Zyr\nSU4eTnh4GAkJn5KSYmv8nLT8tBXuzDpjd/Ttr7mmmhYtrHcCM2eOYs6cule9Nqde984ggi8ITRDH\njP7AgTY01JrgxImX+emnYsrKuhKifcxf2EAw0BdqGp6lAnehWiMUA7tpxRtsB8Ixm1vx7LPf0aJF\na8s41nYrtgunjh/fRVbWFEs8y/DzKyYm5gQrV1ptGltqL7q6pdbdyaJFRiff27xr9UXwBaEJ4ijm\nsbFzcJwQdbQ7Nm8uAm0ioxjFzeykEHgCa06+HNCnVf8H/MKlrGUssMvyTDybNy+kqOhBHD1z2xW5\nBQVXMmvWOvbtC8JorCQ5Of6M9lJj2hj4agsEdyGCLwhNEEcxj4jowlVXnbk1gb92kkR6MM/yzCrs\nnfM2wK/Atxh4mXuB14A1qJ6X6hynToXSkGceHh7Gv/99l090Hm1qiOALQhNEedcFqInXlvz++06O\nHGmFn18nOnSoAuztjuh2+7k07Q36Y5XrEuzLLb8DXmYDMAhQ1TLV1aUUFS1BOfoltG5toqiocZ65\nlFK6DxF8QfBRMjIOM27cKgoKOhEefpSVK0fRtavyspOTB7N16wKyslR9elnZGMrKlgHxpKauZfPm\nLwgOziW8bTsuP/4SxvTddMZe5IcCs4COwC78mc8e4D+Wo4rp1CmW7t0rSUmZgC7w/fq9xZ49cy0x\nZTJzZm1fXm+toCyd2oIupZTuQwRfEFzI+cxOx41bVSPopaUaI0c+w9VX9yArK5zYWBMREUa7lasQ\nhFoDO4PiIhN9i/7ENVk/0Qb4G/AKcDvKq28DbAPKgHXczKqaupyLgRGAxv79s3jvvTuxnRQtLw+0\ni2nOnKW1JlQbEnQppXQfIviC4EKczU6dvTDUd5zJZCYnR8O2nUBeXsuasdUuTHOxN2X2AH3wYz/j\nuYgYNHoDpZYjwlBVOCGWM5qBv/Moyqu3LacEMHD6tCrBtP18Q4akoZorpALBbNqUQ0GB2e6z2Qt6\nIZs25TJkSFrN55NSSvchgi8ILsTZ7NTZC4PjcWVl7wGQlpZLdfVMbNsJaFqE3dgFBbG0ajWL06fb\no0Q4GH/SmcooWgNXo8wZ/Qy3AWuBTGB/y1DSus7B//dDVFUtAyqBY8Bky/mV+Dt+PiXWa8HSJNls\nHk5S0lK71saHDlUCHwPDgLWYzY+Tnm79HqSU0n2I4AuCC3E2O63vwmCb0cfE5PH99/arUX/80Q+z\neSLUallWTnCwieJi69iqdcFsYmJe4FTO10xhA3HAPuBFrEK/BJiLMmwOGAL5vPsLxPVpz6fJg3n0\n0dWkpgLkoMR+Hcrw2QU8SEzMJ3YtjWfOvJJNm/6H2Xzmjpb6pC+0q3WslFK6DxF8QXAhzmanHTpk\nk57+L5SBUkSHDrZ7waoeNtAeVd9uFfGiohzL744ty/IpKTlOy5azKC/vhqYFoaZdDbSqzOMeNjCX\nusstw1E9cF7yv4viqo9hv4Hd+zVSU5+ksjIEg8FAy5a5lJXNQ9PiwLLQKiZmPgZDJCkp92N7pzJw\noL/dqlr9oud4kevS5UKMxsI6jxXcgwi+ILgQ57PTQLDbG+o9TCazpc1viuX1AcD1wDwgBmhBdXWU\n5fjrUHl5NKqTzd1o2mbKyu5CtTK7k0BWkMjtdM2HFtRfbvk9MI+HoeoKm6MKKS9vA1wClHD69KXA\nBJt3LeP06WNkZ3fA8U7l3/++krouenXd/Yh9c35xu+APHjyY4OBg/Pz8CAgIYMWKFe4eUhC8nuzs\nSGyFMjs7khkzNmI2P255vgBVUdMH8ENJdjzqYvAekAHMwX4dbIjl8XX48xBTeJvLLM9uwb7c8kng\nAiCdIBaRALyB/YYka1G1O/r5P8T+viAEaFeniNd30bMV97i4UmbPHiT2zXnG7YJvMBhYunQpoaGh\n7h5KEHwG+4VRbTh+fCdVVRdhFdV1wAzL4+G0aJFEefkBVG/KAqAH9gIcDBThzxb+j6uJRTnt+j1E\nf+ApoCfKfd9LIPN4CUjEKub6VuOlqEla2/PnYX9fUMw111Rb2hA7l6HbintUVIistPUAbhd8TdOo\nrq529zCC4HHOpgbfcWFUVtYITKZZwDisjQysgtuyZSTl5UlYBVffHlx/vJcW/EoiH2EE2qLuC/Qz\nhANGlNjPYxgwGnXXsAw4hf1W48ss77I9fy7qziIAP78sbrklnDfeGC4Zuo9xXjL8+++/H4PBwB13\n3MHtt9/u7iEFwSMkJq4hNbUt4E96egDl5Z/z4Yf/V+ex4eFhREf3tlsYdfp0F5RfHw38jvLvwwGN\n4uJg7DPuzsALQCcCWMOdfEJHqJmYreuSsA94jf0og8d2/mA2yh5S7REgwTKOyvZbtjxAWdlUoAug\nMWKErHz1Vdwu+MuXLycqKgqTycTEiRPp1q0b/fr1q/f4qKgQd4fkEiRO1+LNcZ44Yeahh1I5eDCY\nrl2LWbAgnoiIsFrHfPVVNqA6RYLGjz++WvO5Tpww88ADKWzapAF5DBgQRocO2PnfaguRGTaP56E8\n/BIgG3v5Pgp0pxWf8hc+IRi4FPtLQk9Urm4GdhDCAn4BuqPyfNsjL0c1QHseqEB1vDcAd9Kp0zx+\n/fVxpkxJZd++X8jP38vhw0amTl1d5/dwNnjzv7ktvhKnM7hd8KOiogCIiIjg5ptvZvv27WcUfF/w\n9XzFf5Q4XUNCwqqaUsmtW4P57rt/sHHjBMLDw2r62eTktKO6ugu2QlpcHMK+fUctG4Cssus5k5Ky\nBPgNeBWIRAl4H+yFuDd6GwNYgLVBcQkBHGIqM7kQ5ei3pnb1zS6gCEjmJ9Qq226Wcxc5HKkvv7rC\n8rt1nLCwzlRV+TN//nASElaSnj6DzEx9Edi5Z/re/m+u40txOoNbBb+0tJTq6mratGnDqVOn+O67\n75g6dao7hxSEs8IZ3z0jwx94B1XXspOsrF4MGrSEjRsn2PWzUatHrUJaWdmKQYOWEh3d27K61FbM\nNVT3Guvyp8DAX6ioGIO9ZBuAdJSFcyd+7Gc0/ehOUU0bYw1l7tyLtQ/O98AxgviI7aisvrvlqF7A\nXtQU7oVAS9RkrS781cDdNWfu3n1pzfcgPW58H7cKfn5+PlOnTsVgMFBVVcWIESPo37+/O4cUhLPC\nmRYHJtNhVCHjcvQtQbKyNJKSllJQYFuHPgyVscehWo/dR1bWf8nK8kdNehage/JwANs+OFBJUFB7\n2rV7kZycSNS0613AZtQdwAECuZ1EVnARqqmZ7eWjI+oeIBz4BQMvk4Ta+1XP6kNRtf0VwHPAEWAp\nEAXMx8+vkN69+9K5cxHwHtnZkbJdYBPErYLfuXNnUlJS3DmEIDQKZ7LWdu3iLJOr9hOnq1ZVomm7\nsWb1eulxgeXnRqADavJ1OJCEqoTRd4PNRU3QLgBKKSp63tJL/hmU4P8XmI4BE4MZSD920hvV0cYP\ne1PGhDJqnuIOVOb+PKp/zjKgHFWRU2zzGb5HZfnqDDExc9mwwdrKWL/zueOObTV3PrJIyveRlbZC\ns8aZrLVbt5Ns365qz21lVrUviEOJahCqQUEgyjL5K9a+MwuAKajFSvYNz2AkyqefZxnNgLJdDgNt\n8eN1pjKDi6gkFHUPEYqa2l2GtbNlJvAmM1C1OXrzhDCUPbPaMsYCwsJ2YzYPx/Hi1a5dnN1nru/O\nR6pzfBsRfKFZo2etGRmtMZn2kZFhJCHhUzsv33qMH/v3P83p00bUQqRhwHpU24NiVLZ+P8qqWYeq\naTegxHYJyj5xXK2q/x6MkvA2qJLMSIJYwiNs4VpqbyLeFdiBWob1Bd1YxR2oOwioPX2rHgcGmtiy\nZQJJSUvZtCnHIvzqmG7dTtl9L+LXN01E8IVmjb5w6J57PmbHji5kZYWwY0cOP/zwHuXlFwD5XHNN\nMG+8MYJZs75jx47nsbY+eB3ohxL7oSj/Xq+jz0ZZKmGWn0eAVjiuVlVoKKPmYcBAC/L4Cw/QAlUh\nb9s8Qd9E/DCQg4G5bEU1MzuFaocQgrJwnkDdfRxH1c8vY8CAkJrPW1BgJimpfntG/PqmiQi+0OQ4\n212nTCYzX32VhbWG/l8cP/4Mutilpi6jRYuNZGWFY93cIwNV6a7L8XxUhYttHf0ylKWi96XRPfVi\nlOtegrJgwlGZfSEB/I9HeYCXUFOqtvcDbVDSvhmYxzxURq8Bn6CqbabYjP0CMBbdVoqN3cE//zm+\n5jM3tEJW/PqmiQi+0OQ42z1RZ8zYSEVFP6zyGoKj9ZKSchz4BiW5T6KyedvVqi+gJktt31cIvI+q\njLH11FehBD8Y/QLhzxbGE04HrF1yjlG7q2UJLfgHPwArLec5iTJ42tuN3bZtB6677hNLtY2Z5OTx\ntS56Z7owSsuEpokIvtDkcPSfMzL87TbpePLJK5k792ebTUayUTX2+i5MjguTTKiKmj4o774QVSrp\n2Opgj8P7TqKmWG1ragpQ9fVRwDEMHORmRtGXHXRB1ebonW3uRuX/eqOFDYRyqtfrxBauo23bzhQV\n7aBduzgOHy6nqGgnaq5AjT1oUIsGBVs2C29+iOALTQ6r/1wIrGXv3kPs2KGqY9LTNdasmUVl5eya\n15V4ZwMXofZvLUJZNsrDV4+fw75LDdiLeybKdHkS1aYsFHjA8vM9VK+aXqhFVOpcLXmTKXSnDfZe\n/RKsPStPAj8Bb/MhsbGZpG+6tdbn7dv3LYqKpqAvuwoK+onk5IRaxzkiE7PNDz9PByAIrsRkMlNe\nXkFY2AcEBLwCDKWiwr7LTGVlZ8vjVNRkaxFqknMsSozboiY6W6BE+yrss/k+KL9cL4FcjppwDUDZ\nQS1R+XmY5fho1H+1/6F8/+W04s88yqNcBfzB4ewRqPqeA0ApLXibn4AJhIZ2r/Mzq5LKcJTFNJKe\nPS8/45yFjtFYiLrEgEzMNg8kwxe8nrOZhH300S9Yt05tuWetbdFw3A5Q/QxGTWr2xl5y+6Hq4/X3\nV1PbqgkDLkbZKDp6q4IdDsf/BDwGrCKA9fyFJfRE4xDKvsHh6L3Ad0AyM7Dtf1lYmFHnZ7auE1DH\nXXjh6TN9nTXIxGzzQwRf8HrOxmv+8UfbLvB6bcsAVHVMIcq6Kbc8PoaycRzr1k/avL8Y5bnvRmX9\nB1CLqqC21/+75Zi7gVmoydTjwP34cZw7uI8LqKpZLfsAyux5DGsPnP8BvxDLWraj7gqWoyZ9Aykp\naUtBgbnWxc5RuBcsGElVVcPfq0zMNj9E8AWv5+y8Zj1710V4p817v0cJcg9U1h2E6lLZBiW90ZZj\nJlse630oo4CHULZJAaqRWm/LuZdg7SMfClwLfInK9A1AJa14iGmspSWq4YFt82Mj8E/LmX/AwFv8\nRHT0ajiul4BqqN2n/CkqaklS0sZaIu0o3BERvtHhUTj/iOALXo+zi4BMJjMtWxZjbTlcSfv2p+jQ\noYqYmFOsWxeNmjgNQVXbPIG1cuYN1H+HauADVNOx6Vjl+VUgFtXorA/KytmL/cbeT6IuAN2Bv+HP\nVhKYRDBVhKHW49ree8SiMv1iYBVhFF74MqN676CkJJy0tGVAlkMMS2RiVWgUIviC1+Ho2c+ceSWO\nXrPtMR06ZAOB/PCDH2ZzT/SOM4GBc7jiigt4440refTRNahM3LZ2XpffdcCzNs/PcXjdgLqABKP6\n4kRZXg/HWk+Ti6rq6QQYaM10pvI6ccAh1KXAcQeqXagcftfVf+XzVbNqPv+QIWmoLQhXO8QQjtFo\nbvwXLDRbRPAFr8Pesy9g69YFREf3tpuwTUhYaXPMv7AX8uXAXVRUXEpqan9++WW+peVwFUpiQdkx\noGrsq7AX1hhqL3tqgbXR2WzUHMCtKBtHb5u8kAC+ZzjzuAhqvPo4y1nuxtp4YR/wHZEcaD+Jrz+c\nbPf5rXc09s3aYmN3kJw8HkE4V0TwBa/D3rNfR1bWk2RlWSds580bxKZNucBnqLr2AOBDy/GjUR0r\n56JWn75ETs5L2Fe5t8Bq59S1+2suKovXNwjMBfoC/0JZOlGoOv3lqHmAUYCBEN5iCjsJwbbxMDxt\n+WlEratNAl7hCoYOfYCvLRuB22Jt1uaPyTSXdu3i6NbtVJ2rZQXhbBDBF7wG3aZRu0Ppq17bYJt9\n792rcdllb1NW9kfURGknVL2LrdeeidqnNQJrWwMsP09TO6PvgrU12V7L43hUnX4R9nbPMlRWfxKV\nvz+PH/uJJ5o+VDAX1SvT9uzdUReAjqgan9dYQlBQFR9+OK7O70GqZwR3IYIveA22Vg5otG37EuXl\nJzl9Wm8ZUMC+fXuprn4Re4G3ldeLUNOhQ1HevGPVTjFqktb2Ob2RgV4FvxtlEd2Ftbe8fv5y1F3E\nt0B/gunNFPbQBWtdTrHD2fcAJ4C5PI1a2KURGvqi6744QXASEXzBa3AsvywpiaK6uhxYCORjMBRS\nXd0fewFuR+3e7yFY+9G/i/1WIUUoF/0pVO69H2XLrEb5+WGoCdqXUP89CrHtUaNkPRQD2VxPF66h\niDhUBX6X0UVuAAAgAElEQVRLyxHxWHtiHgR+IJhv+NYS0xLgd/r0EWtGOP9IawXB45w4YSYhYSUH\nDuzFdql/dfUhVOVLS+AhNK071kVSYF3s9CyqNn4Z1lYJuhV0G9bSS/189wDXAPcB/ijhH2455/2o\n7cCfQOXl01F2zyrURSKHAJ5kEg9yDUX0Rjn8D6LuJf6Gala8C7WI6oer/8qivbsIC/sSVc4ZCEzn\nxIm62yQIgjtxe4b/zTffMGfOHDRNY9y4cUyaNMndQwpegG3ZZExMHgZDJdnZHepsjfDQQ6kWK8ex\nX/x0rJt+L0fVzt+OypL1TUNOowTeH+XNL0CJ/SFUZh6GyvR160evrNmCEnRQUv0qtdsi2/aoAX/2\ncieP0Qkl246LqK6wRP078D3h7G8/hc/evIvw8DAGDowmJcW6w5T0rRE8gVsFv7q6mtmzZ7N48WKi\no6O59dZbufHGG+neXbKbpo6jH6+EfDTp6Rrl5e/QokXrmjr7I0d0K0fvF78ENXG6DjWRql8AilCZ\nfABqQrYUJdIXUbsscwLwIsryse1cqS+gMqIyeX3B1KWoUk1be2h/zWMD+dxOEp1Qa2mzqb2Iapcl\nor/zHPA05GrMmbOURYuM0rdG8ArcKvi//fYbRqORjh07AjBs2DDS0tJE8JsBGRn+WCtfirGVx82b\n8ygq6g74k54eQIcOv6AmQnWhPWY51rZ0ch5KmF8GrkZZO9NRG4E4ZubBKHHvbvndtsGZXrnTEmuZ\nZXeUXD+OtavNVmASBhYxhme4kZyada/hqBoix0VUR4BlrETdbahY9JWxUnkjeANuFfzc3Fw6dOhQ\n87h9+/Zs377dnUMKHka3cvbu3U99PeSLiqqwzchPnnyRsLBXMJs7oCpkOlN7pWs0ag9Z2wqd5aiL\nQ0vs5fc3lGUzHXVn8aHl+TxUM7OHgR9Qwr4AlZf/EVv7BvJoxTKmMZN5DiPehSoYnYOq9P8dSOav\nQDLWuxn1WcW6EbwJtwq+pmkNH+RAVFSIGyJxPRJn3Uyd+rnFyvkMW8E2GNqiaUtQAh2DdW/YYIqK\nqhk6tA2pqX6orQIN1M6hg1Btix07YZajBFvvn3MUlbFnoCyhLOy3F1mG6pXzrOW5EahGadZiSj/2\n8Sce4BKUfeM4Iqj7h2LgZ+C+las5tKyYgwdX07GjCU2rICtrNV27lrBgwUgiIs7/34ov/H36Qozg\nO3E6g1sFPyYmhqysrJrHubm5REdHn/E9vtDlLyrKN7oReiLOffuCUNKod3pUQqtpLVFTnX1Q2fda\nrFn+cNLSnqBt21YUFenyOgwl4hEosR9qeY/tRWAr1lYJlUCZ5fl01DKnO1HzAbaSHYK6IDjePagW\nyi1ZwZ9ZSRRK7B0bJ+9EXbIOA/P4KzCPqsX1t2uuqjr/f9O+8PfpCzGCb8XpDG4V/EsuuYQjR45w\n7NgxoqKiWLNmDa+99po7hxRsUOWOq5zaOMRVqD4wBajVrh+ibJTTqHLIO1HSeT3wH2xFt7z8QgwG\n6ySpyqFjUcuWdGtoKMoaCkNV1kSiBL81+mbg1sVYRcBbqAzfceGVfZ8cP78f8av+HxN5kWiUQXQZ\nSuyHYu/qlwDbCWAZe1AXDqSDpeAzuFXw/f39mTVrFvfddx+apnHrrbfKhO15xFrueP42qU5OHszW\nrQvIyrLtJvMSyh/XbZxAVAXMIpS9UwSUU1b2V1Stew+sE7ftUNXtF6JEvhR1Z6B78Ccs4ziutu1v\nGddoOacR5d/HoCqBVJ+cmBgThTmnmcrrNVO8rbGK/TqsG5OUAIb7JnHqxLWQ0s0ynvj0gu/g9jr8\nAQMGMGDAAHcPI9TBwYPBOL9xyJlxdpvB8PAwoqN7k5VlK8DtgV9RkqnbOONQojsCdVF4EXVR6Im6\nIPwN6wVjBipTvxh157Ac1YuyBEgE3qb2att1KMG3nW69HVXW+RtgwJ9D9Mx5mb5Qk9lXAL9YzqqL\n/fdArrErr//8ExVVgRQUmJESS8EXkdYKTZiuXYvZurXhjUOcwXGbQb2WPiOjNSbTXiIiutC9eyXJ\nyYOJicnDXoBboeri11HbT9d/jwIWo5oRnLa8pxRVNtkVtSh8JKoLpm255nLUBWW25Ryhlvd84zBW\nBaq0cwYQTkve5EFeJghVgW9bxT8Pa9u0X4G+H3zMjcNGEGbZSUpKLAVfRQS/CbNgQTxlZWfORPXM\nvS7hts3gHfvc/PBDMWbzg+gymZX1Pjt2BLFmzVpURfoLqJbCe1GLnlRFTm0/HcvvIVgbmC1Bib6+\n4YgJlYNrqDsAx7qZwyg//3eU7/8ZyhKy9sDx89tD9+4XcOD3KQzj31yE2tMqHzUlbHvGCNQ9QC6w\ns/f/MX2YbR2/IPguIvhNmIiIhjNRxxWxWVnL2bFjZK1NRxy3GSwpsb0AFKJE/lkqKx27wFdYfgaj\nJmuXo7L3LagFSgtQ3vpfLOcyoKpt9K0D9bJJtayp9sYkO1Fi74eaatVQ62BNGAwLMBgKCAgopLz8\nSTJ/f5XH+DctsC/UdOyGvxdY3fmv9L7iYj4Su0ZoQojgN3McM3clzKvsNh3ZsuUFIiO70bLlLMrK\nugIFVFaWohqSrUMJdBeH8/RC9YzvgzJJ/FENyu5CyeoW1MpWfd1qqOW9GsqnL0RV46iWxMHBwbRu\nncHJkyGcPDkLuBJ1FzAFa0Y/E1sZ17TOaFoorQK2MLG8KxEUcrHlXbaRhqIWUYUDu6KiefS7//FE\neIQLvl1B8C5E8Jsp1s1Gcqg94Wm/yjUnpx05OX7AH1AZ9RSsbvdc6l4odQRrqeQIm2MvRpVq9gJS\nULtP9QdmWc5/EjVluhblxa8FWtO2bQGXXRZJaupk9L48+lixsVmUlMTY1PAb0Dcab8F7TDr1Fn1Q\nS7JKLKPbRhqGarV22YrV3DZgoAu+XUHwTkTwmxG2lTbHj+8kK+shlOwtIyTkFOXlBygr80ctWtJ3\nnApFudlTsIq33mCgN9YLg75QKhrlpV+OfR6t7yk7EiXYek2+vvq1o+U1nWLgH+hZe1aWRnb2U8BS\nlGe/gKCgIMLDs4mIMFJdfYCiIpvaevZzEy25iHIuR80QBAKnLL+/iJrizQUyW7Zk8jdb6Ny1G4LQ\nlBHBb0bY+/WjgPctrxRQXNwG5YPbNv3VO0teQG3bR29yZrtQqhglpzEoJ9w2j24DVGP1823PV4i6\nSNger0u09ThNuxp1UVCWTXi4ucZ6ggJiY+cSHd2bnN8/5s8nV9ADlc3bVuC8iqrSH44ylHq/9TZT\n7ri7MV+rIPgMIvhNGMeVthkZAdgLbQFK0O+zPLbfzi8oKJLQ0AxycsB+Zep2AgK+oby8M0pC26EE\nXpU8qvLKu7GuUf0NdRFohVoEFYSSXF2GQ1H5ti7HJSg7Zz72F4GdqBYIYfj5taekRP8cAOFEtA1j\n4P4JtDhZXNPw7ANq32f8AmwA/mgptxSE5oIIfhPmgQdSSElR1S7p6RrR0S9gK6ABAcFUVtq2Frbv\nHBMenkV09CXk5NyAEu8yoAXV1X+mvPxfqIlafU3qfJTYg8r030bZNDstj5+ynONFVEbvKO6foTJ6\n2wtBEQbDU5bM/iQwGVXeeSctWhykqKhnTbxBPMWf9syhB8qmsZ3yrVXT87fnmPlIoku+Y0HwJUTw\nmzCbNtlPvppMkSi/3AAcoqoqFNiBVWSHoiZXewO7aN26DXv2bANyUPZNN5QL/jFK7HeiRPtt7Gvs\ny7BfHDUHqxVkQElxLPbibsDffzdVVXrXSwNt215ARUUwpaW23n4psbFzadu2M3v2DMOfF3iAp4lE\nTfmWoNbTrkXdY4xCFYgagX1Az7feZqRYOEIzRQTfx6ivxUFdzzvWo1RXF6AmX5cBT6BpytYJCHgG\n6EBl5WGUtbINuICMjN/RtBmo0kt9kdW/UBuRLMde1Gdh3SzcfkMSgyHC0iq72CaeocTEvMjp07EY\nDCauvroNYCQ19YGacw4atJStW49SWmr9DLGxOaSnTyMh4VP279nCgzxNR2qvvT2NMpbyLaMa3nqb\nv4rQC80cEXwfo74WB5s2VWI2twRuID09FFjK1Ve3JDX1JZS1coyIiALy8x0nTcMJDu6C2TwRtcI1\nEHgMNUmql17G2hyvi7njxGtfVG/6LNRCKqtI33ijgV275pKV1cVyvjhiY/ewceM9hIeH1bSgLSgw\n06KF/cpgs7mQMWPmUlDQifDwTFauHMnRjAx6bnyEKyiqKcB0XHt7CDUzkN82lAlfbpIKHEFABN/n\naKjFgV4yefhwW7p00YBpNa/17fsOO3a8YKmpt9opRUU5qOy8LepPwlY+9SZluoAXYW2LYOuOH0Jd\nWKpQ3S5fom3bKAYNakFy8jAAkpI2cvhwT4uYj6/VfK2uHjXh4WGkp0+refzh669wdO7zRKNaIDiu\nHNCAHwEzEP3W20yXrF4QahDB9zEcWxyoCpnaJZNGYxHHjkXYvZafH0OfPhXk5LRCTZqGAK2orn4I\nlQ8/i6qksfXWT6ImVZfTqlUZoaEZ5OYuQS2YeslyjmKUp1+NukNQq2mvu+49OwFvTMOxXdu28Wn8\nYDprGiGWT51nGd22Z/1m4GhQax7/+nvJ6gXBARF8H0H36A8caENs7BxLk7MqysurSE21XgDCwvYw\ncGABTz55Bbfeuhpb8TYai9i06TQwFaugv4+qfAFlyVyCtY/8LtS+sGHAnUREzKWg4EJUnxtFQMBz\nVFY+jaqLWYtqoaA2B8/OjnTJZ1/98VL2JD7MGyhhn24T/VxUfVBHVGfLuLfe5nHJ6gWhTkTwfQTH\nJmeXXvoe0IKjR1sTGzuXdu3i6NbtFMnJdxIeHkZCwkoyMyej574xMb9SXh5JUVE7lH0TjxLyAlQd\n/nKs1TS6NdQH+CcBAZFERBwnK2sq6uKgoQt8dXVny/kqsDY8U6tnjcZKwPle+nXxzovPU/LmK/S0\njOLY2bIT6rK0u20od4lXLwhnRATfwzgrho7e/Y8/+mE2Wy8AV11l3c3KZDKzaVMlqi7+LgCOH99h\n6UNjK+h3Uv8kbAVwjKFDI/jww7sZMiSN48f15z9E7Vg1nerqcMv5PrR7f1jYaZKTbwZqTzQ7s/NW\nWspnbEmYQBuse8sOpfZWJzuB61es5o/SA0cQGkQE38M4K4a1vXt9az8A+92sZszYiNlchbJWQoAi\nqqvD7I4PCqrA3382JSX+qBW2O7H37gNRxY7v2Yy/FvssXu+pY8DP7xjV1db4Bg4MqLlwOV6sGtp5\n6z8LF3D0bzO4Cvu2CMtRl6VZqLqhA35+jFi3kd59Lz/j+QRBUIjgexhnxVDV1VtLFsvL29h590Zj\nUc3dwvr1oNab2u4rOwfb3HjIEPjii3KsneGvB55BbczdApVPG2p8+OTkwWza9CVms2MBJIDGLbdE\n1Cqp1HG8WNW389aWDRv45s7RdEXtcWVfza9G247aB6vlW28zQ7x6QTgr3Cb48+fP5z//+Q/t2rUD\nIDExUfa2rQNnxdCxZLGumvWkJFuf374vjiqvXEZY2GkGDgwgOXkQ69dX2RwTDlyFqrixdrLU4wkP\nD2PgQH9SUmwXQe0gOrqamJh8gHptKceLVV07b+kWzlUood9B7f2xvgcKWrfmwY1SgSMI54JbM/yJ\nEycyceJEdw7h8zgjhnWhXwD0rP6OO7ZZetvrXWRKsG5Q0gb4BX//Mq65xkhy8gjCw8MID8+yW8Wq\nO+V610nHeGrHOr5mgjgl5X7qs6XOtAfsrm3bWDlyCK0qKojCauH0x9pBPwzVmu1SaYsgCI3CrYKv\nVmoKZ+JMYujMhO6jj37BunVKbK37wd4DDMVgeBlNe9Hy2giqql4lNXUKLVooQV65cpRlFWsHNO0A\nXbv2IC5udZ2LovRY580bVBNTUtIGkpMHn7VHD8q++e6uMcRpGq1QbdG+xv5+oyewBwiZ+wp/u39S\ng+cUBOHMuFXwP/74Y1JSUrj44ot54oknCAkJcedwTQZd6Otql+B4cfjxRz9sxTYg4DQXX/yZpea+\nu4PnrpqS6YLctavRbhWrM9Q1yWw0ak7ZUjqrP17KvsSHa/bK0uvpY7G3cHYAwdNncKeIvSC4hEYJ\n/sSJE8nPz6/1fGJiInfffTcPP/wwBoOB119/nblz5zJnzpwGzxkV5RsXBXfFeeKEmZtu+pjMTH17\nQGs1TFZWeK1xDYYT2MpkSEgxv/zyIACjRy+289z1n3FxpWcV/4kTZh56KJWDB4PZv78a2wtMVlY4\n69Zdz5Qpyzl4MJiuXUtYsGAkERG1z3/49995+7rr0PLyuBb7GYYY1KaFy1BtEQ4FBjLhhx+49Mor\nnY7zfNDc/z5diS/ECL4TpzM0SvA/+OADp467/fbbmTx5slPH5uUVNyak84Le7MsdJCSsIjPTdutA\na7uE2NiCWuNefXUbUlP1LpXFXHGFH6NHL7HYQBUMHfoemZlhnDixj4gII927L2X27EHk5RU7vQYg\nIWGVzWSw/d61sbEFVFX5M3/+8Jrjq6pq/ztu2bCBNXeOph2qyfJOVF2QXsV/AHVZy42MYsSaL7nN\nMinrTX8P7vx3dyW+EKcvxAi+FaczuM3SycvLIyoqCoAvv/ySuLg4dw3VpFB2i307ML1dgu0Eqi7W\nmZnRxMbutWm10NbOchk1ailpabcAt9Qay9k1APYe/TDCwl6hS5cLnZpkLjSZ+PD2MZz67Rcisd9A\nUe+8vxXVxrjrW2/zkEzKCoLbcJvgv/zyy+zevRs/Pz86duzI888/766hmhQxMXnArVhbIvzGpk33\n1Mq8HVst6CtthwxJw9kJVGcnW+1LR0MZOLA9ixbd2OBnKTSZePe6fgSeyKcDEIf9fUsUkI5aQnaD\nbDcoCG7HbYKfnJzsrlM3aQyGSlS/GmXRXH55O6daLehi7Wxd/9kce7alo4UmEx/eOoLAHdvpgloC\n1prabYx/B4au38TAmwf4xG2zIPg6stLWy8jO7oCavtQff1bncfWJta04x8WVMnt2/eLsrJCfqXTU\nkS0bNrDqztG0R7VAsF3nexfWNsbfA/1XrJa2CIJwHhHB9zIam3XbinNDE05nI+QNUWgy8cn/3U7Z\nT/8jGrVm19a+CQcWoBZRHbj8Sh5Y/gmh4REuGVsQBOcQwT+POFMV8+STV7J1q76l31FmzhxV57lc\nKdaNpdBk4tUr+hB56iRdgQyUjWNr32QBtGzFdau/kKxeEDyECP55xJmqmLlzfyYr60nAQGmpxpw5\nS1m0yOiJcJ1CX0TVDvsKnGewt286zn1FFlAJgodp8oJfV1ataZzzhhyNwZmqmHNpU+AJ0lI+Iz1h\nAj1QmyN2wN7CuQDV1fJXoK9U4AiCV9DkBb+urBo46w05XIEz/vzZVNl4gkKTifWJD3MkdY1da4SZ\n2Fs4+1FCP12EXhC8hiYv+PVnzOc/i3amKuZcu2eeD45mZLDwun6EVVfVqqmPRO2EG4US+/C/PSdZ\nvSB4GU1e8OvOmM+u2ZercGai1ZsmY3UKTSa+eHgSpWnrCUNtObgT1XxZb41QBJQBmZf2JfG/n0kF\njiB4IU1e8OvPmL0zi/Y2jmZk8J8Bf2RuRTnLgenozZZVa4RoVEYfcNUfeOCj/4jQC4IX0+QFv76M\n2duyaG9k17ZtpA4dRFfq3ua8N/CjwY/79hwQoRcEH6DJC75w9hzNyODzUX/C73guc4FXULZNMbW3\nHBz6xUYRe0HwEUTwBTv0rL43qtfNEdTq2GUooX8JaAvkRbfn9tVfyN6yguBDiOALgEXoRw7hgooK\nLgGGoerrXwKmAGtRk7SFLVpy7efrZbWsIPggIvjNHL0C50Taeru6erXHFrRHZfeHUZ0tbxOhFwSf\nRQTfDTi7k5SnSUv5jG8SJhAFGKlrjy3YB1RFRnHXmi/FvhEEH0cEvwHqEu+GthNzdicpT1JoMvFz\nwgQ6A0+gsnjbCdm9wGZUVi/2jSA0DUTwG6Au8f7sswlnfI8398PRWyPkfLWei4BAVKTxKBvnJJCN\n2nLwZulXLwhNCj9PB+DtnIt4G42FqDwZvKkfztGMDN699CJOpK7huYoKgoBjqEjDgDtRi6hCbhzC\ntL2H+OOAgZ4MVxAEFyMZfgOcSzMzb+yHU2gy8emga7m2vAwT1qx+CfA00BXYHxjI7d9tFa9eEJoo\njRL8devWMX/+fDIyMlixYgV9+vSpee2dd97hk08+wd/fn6eeeor+/fs3OlhPcC7i7S39cMwnTvDf\ne+6hcPO3VBYXM1vTMAAfY83qpwFzAgOpvOkW7ntjviyiEoQmTKMEPy4ujvnz5/P000/bPZ+RkUFq\naipr164lJyeHiRMnsn79egwGQz1n8l68RbzPlqMZGSy89goiNI0eqEVU24FLUTX2r6I6XO5v2Yp7\nf9sjQi8IzYBGCX63burWX9M0u+fT0tKIj48nICCATp06YTQa+e2337jssssaM5zgJLp9E6VpdrtQ\nPY0S/FDADJyIiua2z9eL2AtCM8EtHn5ubi59+/atedy+fXtyc3PdMZTgQKHJxL8HX0e306VUYF9b\n3xVYDByL7ci9GzeL0AtCM6NBwZ84cSL5+fm1nk9MTGTw4MF1vscx4wectnMaqnH3FrwtTvOJE6Q8\n8ACZa9Yws6LCzqvXM/x9QI9Ro3j4/fcJi/Ausfe277M+JE7X4Qsxgu/E6QwNCv4HH3xw1ieNiYkh\nOzu75nFOTg7R0dFOvTcvr/isxzvfREWFeE2cu7Zt48sx8XQ9Xcpx7FfMDgNeADqixF7fW7aiyru+\nZ2/6Ps+ExOk6fCFG8K04ncFldfi2Wf3gwYNZu3Yt5eXlHD16lCNHjnDppZe6aijBhi/HxDP7dCn3\no1bM7sW6AiAU8IvtyIC9h5h+vEi2HBSEZk6jPPyvvvqK2bNnU1BQwOTJk+nZsyfvvvsuPXr0YOjQ\noQwbNoyAgACeeeYZn6zQ8WaOZmSQOm443U6X2vn03VFtEsqBnE6duCPtO/HqBUEAwKDVZbh7EF+5\nffJUnIUmE9/OeIyDa1fzXEUFy1BdLXWffhYQGhZG62uu488fLaGiKtAjcZ4NvnTbLHG6Bl+IEXwr\nTmeQlbY+gi702qYNtDSb6YZ9D5xS4ECrIG5eta6m/01YhG/8sQqCcH4QwfcRvp3xGPemfGqXydv2\nwJkT25G/pO/2ZIiCIHg5IvhejJ7Vhx4+hHbogJ1X3xO1G1V7Pz+yYzowdOUazwUqCIJPIILvxdhm\n9Y419dlhYcQMHMz1ya/JpKwgCE4hgu9l7Nq2jS9G/wljWRm5wALgblRN/SthYXTv0o1CYxfGiNAL\ngnCWiOB7GV+OiefFsrKaTH4ZkIry6SMHDub6RYs9GZ4gCD6MCL6X0a3stJ1XHwKYgoJYPGQo1ye/\n5sHIBEHwdUTwPYztxGyh0cjuwBZo5dYMvxioHjKU4ZLZC4LQSETwPYxduWX6z/x94CCe+vF7jGVl\nHDcYaHX9QMZIZi8IggsQwfcwoYcP2Vk4nQsLuftonidDEgShiSKbmJ9HCk0mPk+4l2+H3MDnCfdQ\nWGCi0Gi02e4cCo1dPBihIAhNGcnwzyOO9s1iDFyf/DqLMVg8/C4yMSsIgtsQwT+PONo3oYcPERoe\nIROygiCcF8TSOY+IfSMIgieRDN8NOJZaXp/8OqHhEWLfCILgUUTw3UBdXv3wRYvFvhEEwaOIpeMG\n6vLqBUEQPI0IvhsQr14QBG9ELB03IF69IAjeSKMEf926dcyfP5+MjAxWrFhBnz59ADh27Bjx8fF0\n69YNgMsuu4xnn3220cH6CuLVC4LgjTRK8OPi4pg/fz5PP/10rdcuuOACVq5c2ZjTC4IgCC6kUYKv\nZ/CapjVwpCAIguBp3DZpm5mZydixYxk/fjw//fSTu4YRBEEQnKTBDH/ixInk5+fXej4xMZHBgwfX\n+Z7o6Gi+/vprQkND2blzJw8//DBr1qyhTZs2DQYUFRXiRNjnD/OJE6Q+9BDBBw9S3LUr8QsWAN4X\nZ31InK5F4nQdvhAj+E6cztCg4H/wwQdnfdLAwEBCQ0MB6NOnD507d+bQoUM1k7pnIi+v+KzHcyef\nJ0yyLqLaupXFZZVM/OwTr4uzLqKiQiROFyJxug5fiBF8K05ncJmlY+vjm0wmqqurATh69ChHjhyh\nc+fOrhrqvCKLqARBaCo0atL2q6++Yvbs2RQUFDB58mR69uzJu+++y08//cTf//53AgIC8PPz4/nn\nn6dt27auivm8Umg0oqX/XLPloCyiEgTBV2mU4N90003cdNNNtZ4fMmQIQ4YMacypvQZZRCUIQlNB\nVto2gCyiEgShqSC9dARBEJoJzVLw69pbVhAEoanTLC2d+vrVC4IgNGWaZYYvpZaCIDRHmqXgS796\nQRCaI03e0qlrf1kptRQEoTnS5AW/Pr9ePHtBEJobTd7SEb9eEARB0eQFX/x6QRAERZO3dMSvFwRB\nUDR5wZfWCIIgCIomb+kIgiAIChF8QRCEZoIIviAIQjNBBF8QBKGZIIIvCILQTBDBFwRBaCY0SvCT\nk5MZOnQoo0aNYtq0aZSUlNS89s477zBkyBCGDh3Kd9991+hABUEQhMbRKMHv378/a9asISUlBaPR\nyDvvvAPA/v37SU1NZe3atSxatIjnnnsOTdMaOJsgCILgThol+Ndeey1+fuoUffv2JScnB4ANGzYQ\nHx9PQEAAnTp1wmg08ttvvzU+WkEQBOGccZmHv2LFCgYOHAhAbm4uHTp0qHmtffv25ObmumooQRAE\n4RxosLXCxIkTyc/Pr/V8YmIigwcPBmDBggUEBgYyfPhwgDrtG4PBUOs5QRAE4fzRoOB/8MEHZ3x9\n5cqVbNq0iSVLltQ8FxMTQ3Z2ds3jnJwcoqOjnQooKirEqeM8jcTpWiRO1+ILcfpCjOA7cTpDoyyd\nb775hnfffZcFCxbQokWLmucHDx7M2rVrKS8v5+jRoxw5coRLL7200cEKgiAI545Ba0T5zJAhQ6io\nqIMzjrUAAATvSURBVCAsLAyAyy67jGeffRZQZZkrVqwgICCAp556iv79+7skYEEQBOHcaJTgC4Ig\nCL6DrLQVBEFoJojgC4IgNBNE8AVBEJoJXiv47733Hj179sRsNns6lDp58803GTlyJKNHj+b+++8n\nLy/P0yHVyZn6HXkT69atY/jw4fTq1YudO3d6Ohw7vvnmG/70pz9xyy23sHDhQk+HUy8zZ87k2muv\nZcSIEZ4OpV5ycnKYMGEC8fHxjBgxwq6c25soLy/ntttuY/To0YwYMYL58+d7OqR6qa6uZsyYMUye\nPLnhgzUvJDs7W7vvvvu0QYMGaQUFBZ4Op05KSkpqfl+yZIn29NNPezCa+tm8ebNWVVWlaZqmvfzy\ny9orr7zi4YjqJiMjQzt48KA2fvx4bceOHZ4Op4aqqirtpptu0jIzM7Xy8nJt5MiR2v79+z0dVp1s\n3bpV27VrlzZ8+HBPh1Ivx48f13bt2qVpmvo/NGTIEK/9Pk+dOqVpmqZVVlZqt912m/brr796OKK6\n+eCDD7Tp06drDz74YIPHemWGP2fOHJKSkjwdxhlp06ZNze+lpaU1PYW8jfr6HXkb3bp1o0uXLl7X\nZO+3337DaDTSsWNHAgMDGTZsGGlpaZ4Oq0769etH27ZtPR3GGYmKiqJXr16A+j/UvXt3jh8/7uGo\n6iYoKAhQ2X5lZaWHo6mbnJwcNm3axG233ebU8Q2utD3fbNiwgQ4dOnDRRRd5OpQGef3110lJSSEk\nJMRrb01tWbFiBcOGDfN0GD5FXX2htm/f7sGImg6ZmZns2bPHaxdlVldXM3bsWI4cOcKf//xnr4xT\nT46Li4udOt4jgl9ff55HH32Ud955h/fff7/mOU9mfA31EUpMTCQxMZGFCxfy0UcfMW3aNA9EeXb9\njjzp7zoTp7fhbXccTYWTJ0/yyCOPMHPmTLu7ZW/Cz8+Pzz77jJKSEh566CH2799Pjx49PB1WDV9/\n/TWRkZH06tWLLVu2OPUejwh+ff159u3bx7Fjxxg1ahSappGbm8u4ceP473//S7t27c5zlA33EdIZ\nPnw4Dz74oMcE/1z6HXkCZ79PbyImJoasrKyax7m5uU73hRLqprKykkceeYRRo0Zx0003eTqcBgkO\nDuYPf/gD3377rVcJ/s8//8yGDRvYtGkTZWVlnDx5kqSkJJKTk+t9j1cZz3FxcWzevJm0tDQ2bNhA\n+/btWblypUfEviEOHz5c83taWhrdunXzYDT1U1+/I2/Gm7LqSy65hCNHjnDs2DHKy8tZs2YNN954\no6fDqhdv+u7qY+bMmfTo0YN77rnH06HUi8lkqrFJTp8+zQ8//OB1/8cfe+wxvv76a9LS0njttdf4\n4x//eEaxBy/08G0xGAxe+wf86quvcvDgQfz8/IiNjeW5557zdEh18sILL1BRUcF9990H2Pc78ia+\n+uorZs+eTUFBAZMnT6Znz568++67ng4Lf39/Zs2axX333Yemadx66610797d02HVyfTp09myZQtm\ns5kbbriBadOmMW7cOE+HZce2bdtYvXo1cXFxjB49GoPBQGJiIgMGDPB0aHbk5eXxxBNPUF1dTXV1\nNfHx8TX7ffgy0ktHEAShmeBVlo4gCILgPkTwBUEQmgki+IIgCM0EEXxBEIRmggi+IAhCM0EEXxAE\noZkggi8IgtBMEMEXBEFoJvw//5K32R/vBHAAAAAASUVORK5CYII=\n", - "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0x7f5be3c99f50\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Current loss: 9.48636\n" - ] - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "plt.scatter(inputs, outputs, c='b')\n", - "plt.scatter(inputs, model(inputs), c='r')\n", - "plt.show()\n", - "\n", - "print('Current loss: '),\n", - "print(loss(model(inputs), outputs).numpy())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "sSDP-yeq_4jE" - }, - "source": [ - "### Define a training loop\n", - "\n", - "We now have our network and our training data. Let's train it, i.e., use the training data to update the model's variables (`W` and `b`) so that the loss goes down using [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent). There are many variants of the gradient descent scheme that are captured in `tf.train.Optimizer` implementations. We'd highly recommend using those implementations, but in the spirit of building from first principles, in this particular example we will implement the basic math ourselves." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "MBIACgdnA55X" - }, - "outputs": [], - "source": [ - "def train(model, inputs, outputs, learning_rate):\n", - " with tf.GradientTape() as t:\n", - " current_loss = loss(model(inputs), outputs)\n", - " dW, db = t.gradient(current_loss, [model.W, model.b])\n", - " model.W.assign_sub(learning_rate * dW)\n", - " model.b.assign_sub(learning_rate * db)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "RwWPaJryD2aN" - }, - "source": [ - "Finally, let's repeatedly run through the training data and see how `W` and `b` evolve." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 446 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 569, - "status": "ok", - "timestamp": 1527005915434, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "XdfkR223D9dW", - "outputId": "c43591ae-d5ac-4f2b-a8e7-bfce607e0919" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 0: W=5.00 b=0.00, loss=9.48636\n", - "Epoch 1: W=4.58 b=0.42, loss=6.28101\n", - "Epoch 2: W=4.24 b=0.76, loss=4.29357\n", - "Epoch 3: W=3.98 b=1.02, loss=3.06128\n", - "Epoch 4: W=3.78 b=1.23, loss=2.29721\n", - "Epoch 5: W=3.61 b=1.39, loss=1.82345\n", - "Epoch 6: W=3.49 b=1.52, loss=1.52970\n", - "Epoch 7: W=3.38 b=1.62, loss=1.34756\n", - "Epoch 8: W=3.30 b=1.70, loss=1.23463\n", - "Epoch 9: W=3.24 b=1.76, loss=1.16460\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW0AAAEDCAYAAAD+/1UIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl4VOXdPvD7zJZ9XwmELQkQIAELsiTsi6xiEBGXAiIW\nbV8WBY2K0tLa4lbsr283qxURtIoioAi8SpFNg6whi0FJKAoJBgLZt5k5c87vj5OZLIRkgEnOGXJ/\nritXJsmZyT0sN1+enPOMIMuyDCIicgs6tQMQEZHzWNpERG6EpU1E5EZY2kREboSlTUTkRljaRERu\nxODMQePGjYOvry90Oh0MBgM2b97c1rmIiKgZTpW2IAjYuHEjAgIC2joPERG1wKnlEVmWIUlSW2ch\nIqJWCM5cETl+/HgEBARAEATMmTMH9957b3tkIyKiJpxaHvnggw8QFhaG4uJiLFiwAD179sTgwYPb\nOhsRETXh1PJIWFgYACA4OBgTJ05EVlZWi8fL3t6AIADdugFvvglYrTeflIiIWl8eqampgSRJ8PHx\nQXV1NR5++GEsXrwYI0aMuPadCgtRvfoFeL2zDkJtLWxdu6NqRSrMs+8DDE4N9y4XFuaHoqIKVb73\ntTCTc7SYCdBmLmZyjlYzOaPVSfvy5ct44IEHkJKSgjlz5mDcuHEtFzYAREai6oWXUHwkA9WPPApd\n4QX4L/sVgpMGwWPTvwFRdCocERE15tQPIm9Ew3/FdBcK4P3ntfB89x0IVivEmFhUP/kMzCmzAL2+\nLb79VbT6LysztU6LmQBt5mIm52g1kzPa5YpIKaozKl9+DcWHT6Jm7gLof/wB/r98BEGjh8Fj28cA\nTyckInJKu17GLnWJRuXaP6P40AnUPDgP+jN58F+0AEFjhsO0fRvLm4ioFarsPSJ1647KP/0VxWnH\nUTvnAehPf4+AhfMQNG4ETDs/A/hiOkREzVJ1wyipR09U/OV1lHx9FLX3zIH+uxwEPPQAAieMgunz\nXSxvIqImNLHLny0mDhV/fxMlB4+g9u57YMjORMDcOQicNAamPV+wvImI6miitO1scb1Q8fo6lOz/\nBrUzZsJ4Mh0B99+DwKnjYdy7h+VNRNftL395DR999IHj4+XLl2DVqlWOj//61/+HDz/8txrRboim\nStvO1iceFf96B8V702CeNgPG48cQOGcmAu+cBOOBfSxvInJa//6JyM7OAKBsfldWVorc3FzH17Oz\nM5GQMECteNdNk6VtZ+vXH+Vvv4uSPQdhnjwVxiPfIPCeGQhImQpj2ldqxyMiN5CQMBBZWZkAgLNn\nz6Bnzxj4+PigsrISVqsVP/74A+Liequc0nnqXFN+ncSEASjf8AEMJ0/A+9UX4bH7c5hSpsIycjSq\nnloJcdhwtSMSkRN8Vj8Pj+3bXPqY5jtTULX699f8emhoKPR6Ay5duoisrEz075+I6uoyZGdnwsfH\nBzExsTCotL3GjdD0pN2UOPBnKH/vI5Ts2gPL2PEwHdyPoBmTEDD7LhiOHlY7HhFpVGJiIrKyMpCd\nrZT2gAEDkJWVgaws91oaAdxk0m5KHHQ7yjZtheHIYfi8sgam/Xth2r8X5vETUZ26EuJtg9SOSETN\nqFr9+xan4rbSr18isrIy8d//KssjHh4y/vnPf8HX1wfTpt3V7nluhltN2k2JQ4aibPMnKP1kFyzJ\nI+GxZzeCJo2F/8/vhSHzpNrxiEgjEhIGIC3tIPz9/SEIAgICAlBZWYHs7Cz075+gdrzr4talbWcd\nnoyyrTtQuuUzWIYlweOL/0PQhFHwn3c/9HU/gCCijismJhbl5WXo3z+x0ef8/Pzg7+9er33bLrv8\ntStZhvHAPvi8/AcYjx0BAJin3wWPXz+Hom69lRdn0Ait7jTGTM7RYi5mco5WMznjlpi0GxEEWEeP\nRemO3Sj9YAusPxsEj88+AYYMQeCEUfBc/xaEinK1UxIR3ZBbr7TtBAHWcRNQuutLlG7aCsycCUNO\nNvxSn0BIQm/4Ll8CQ/pxXqhDRG7l1i1tO0GAdex4YMsWFJ88hapnV0EKCYHXu+8gaNJYBI4fCc+3\n/wWhvEztpERErbr1S7sBKSIS1U88heKjmSj9YAvM02bAcOpb+D29HCGJveH7xGIYThzj9E1EmtWh\nSttBp4N13ASUv/2uMn2v/DWk0DB4vbcBQZPHIWjcCHiue5PTNxFpTscs7QakiEhUP/4kio9koHTT\nVpinzYD++1Pwe2aFMn0//j8wHD/K6ZuINKHDl7aDTgfr2PHK9J2eg8rnfgMpNBxe/96IoCnjETQ2\nmdM3kZsqLPwJ8+bNUTuGS7C0myFFRKJm2QoUHzmpTN/T74L+9HfK9J3QC77LfgXDsSOcvonciKCh\nazRuBku7Jfbpe91GXEk/hcrnV0MKj4DX++8iaOoEZfp+6w0IZaVqJyWiVoiiiD/8YTXmz78fy5Yt\ng9lsVjvSDbn1roi8BpddASVJMB7YB6+N62Ha9RkEUYTs5QXzXXejZu5DEAcPcfqqS61elcVMztFi\nLq1nWr3aA9u3u3afujvvFLF6dcsFXFj4E2bPnoF//GMd+vdPwJ/+9CI6dYrGfff93KVZbkbHvSKy\nrel0sI4Zh/K3NuDKye9Q+fxvIYVHwPOD9xA0bSKCxiTB861/cvom0piIiEjH5lAzZsxAZmaGyolu\njFtuzaoVcng4apY+gZrFy2A8uB+eG9fDY+d2+D37FHx/92uYZ8xEzbwF1zV9E93KVq82tzoVt5Wm\na9ru+leSk7Yr6HSwjh6Lin+9o0zfq34HKSISnpv+XTd9D4fnv16HUFqidlKiDquw8Cd8+202AGDH\njh1ITByocqIbw9J2MTk8HDVLHkfxN+ko3fwpau+6G/q8XPitTEVIYm/4LXkMhiOHeeYJUTvr3r0H\ndu36DPPn34+ysjKkpNyjdqQbwh9EtgOhqAiem/4Nz41vw3D2vwAAsU88DAsfxpVREyH16KlKruZo\n/QdZWqLFXMzkHK1mcgYn7XYgh4WhZvEylBw6gdKPt6M25W7oz+QBTz2FkKEDETR6OLxf/gMMWRmc\nwImoRfxBZHvS6WAdORrWkaNReeUKQr/eA/Omj2A6sA8+a1+Gz9qXYYvuCvOUabBMvRPWIcMAN3qV\naCJqe2wElcghIcDChSifcS+EygoYv/wPPHZ+BtPuz+H9xj/g/cY/IAUHwzxpKixTpsMyeizg5aV2\nbCJSGUtbA2RfP1hmzIRlxkzAYoHx64NKgf/fDni9/y683n8Xsrc3LGMnwDx1OiwTJ0EODFI7NhGp\ngKWtNSYTrGPHKy/c8PJaGE4cg8euHTDt3A6PHZ/CY8enkA0GWJNGKgU+ZRqkTlFqpyaidsLS1jKd\nDuLgIRAHD0HV86uhP/09PHZ9BtPO7TAd2AvTgb3AMytg/dkgmKdMh2XqnbDF9VI7NRG1IZ494i4E\nAbbefVD9+JMo/WI/rqTnoOLFV2EZOQaGjJPw/cNvEZw8GEFJg+Dz+9XKHuCSpHZqItVVVlZi69bN\nbfb406dPQGVlJQDgypXLGDnydmRlZTT4+kSUl7vuxcSdLm1JkjBz5kw89thjLvvmdOOkzl1Qu/BR\nlH38Ka7knEH5X/8J89Q7oS/Ih/f/voagKeMRPDAevqlPwLjvS8BiUTsykSoqKsqxdetHzX5NcsFg\n07dvArKzMwEA2dmZ6NWrD7KylI/PnfsRgYFB8Pf3v+nvY+d0aW/YsAExMTEu+8bkOnJQMMz33o/y\n9e/h8qmzKHvnfdTOeQCCuRZe699C4L0pCOkbA79fPgLT9m1A3VRA1BG8/vpfceFCAR5++EH8/e//\ni/T045g3bx5++9vnMX/+fVe9QML777+Lt99+EwBQUJCPFSuW4pFH5mHx4kU4d+7Hqx4/ISHRUdpZ\nWZmYM+dBfPttfYknJCS69Pk4taZdWFiI/fv347HHHsPbb7/t0gDkYt7esEyZBsuUaYAowvhNGky7\nPoPHzs/g+fGH8Pz4Q8geHrCMGQfLlOkw3zEFcmio2qmpAwke1L/Zzxcfz3bJ8U398pdL8MMP/8W6\nde8BANLTjyMrKwsbNnyIyMhIFBb+dM0XSHjllTVITV2Jzp27ICcnG2vXvoQ///kfjY7p3z8R69e/\nBQA4depbPPLIY/joo38DUEo8IWGAUzmd5VRpr1mzBqmpqaio0NZln9QKgwHWEaNgHTEKVb9/GYas\nDOUslF074PH5Lnh8vgu+Oh2sQ4fDMnU6zFOmA2HN/wUhupUkJiYiMjKyxWNqamqQnZ2BVauehn23\nD1EUrzqub99+yM39HrW1tbDZbPD09ERUVGcUFOQjOzsD99/v2j27Wy3tffv2ITQ0FPHx8Th8+LDT\nD+zsdfTtqcNnGj9SeVv7CpCbC3zyCYStW2E6lAbToa/hu+pZICEBYWPGAGPGAKNGARqZwrX4ewdo\nM5fmMzWzxAAAYde68/Ue34TFUg69XufIEBjoDS8vL8fHklQNQajPaDQCOp0JwcHeCAgIwPbtn7by\nHfzQvXs37N//OQYMSEBYmB+GDBmMrKxjKC8vw6Br/E/hRrVa2idOnMCXX36J/fv3w2w2o6qqCqmp\nqXjllVdavJ8WN2NhpgYCI4H5jwLzH4Vw8SI8Pt8Jj53bYUr7CsjKAv7yFwCAGN8X1uHJsCSPhHVY\nMuQwZ/+quI4Wf+8AbeZipqvV1sqoqKh0ZCgtrQZQ31GSZMLly1dw5kwBPD09sXv3HgwbloSaGhkR\nEZ3w4YdbMXbsBABAXl4uYmPjrvoeffr0w7p1b2PhwkdRVFSBbt164YUXViE+vp/Tz93Zf2xbLe3l\ny5dj+fLlAIAjR45g3bp1rRY2uRc5IgK18xagdt4ChPmbUPrFPhi/Pghj2tcwHjsMw6kceK1TfjAj\n9u4D6/BkWJNHwjJ8BOTwcJXTE7XM3z8ACQkDMH/+fRg6NAnDhyc3+rrBYMCCBY9g0aL5iIrqjG7d\nuju+9utfv4A//vElvPPOOthsIsaPv6PZ0k5IGIDNmzehXz/llXF69+6DoqIizJgx0+XP57q2ZrWX\n9uuvv97qsfzXvnVukcligSH9BEyHvlKK/OhhCNXVji+Lcb1gHT4C1uQRsCaNgBTR8jqhSzJphBZz\nMZNztJrJGdxPW0VumclqheHkCRgPfQ3T1wdhOHIYuqr6UwjFmFhYk0Y43lxxib0Wf50AbeZiJudo\nNZMzeBk7XR+jEeLtQyHePhQ1S5crJZ55UllKSTsI4+Fv4LVxPbw2rgcA2Lr3UNbD65ZUpM5d1M1P\n5OZY2nRzjEaIg26HOOh21Cx5HBBFGLIylBI/9BWMh9Lg9d4GeL23AQBg69odluQR9SUe3VXlJ0Dk\nXlja5FoGA8TbBkG8bRBq/mcpYLPB8G0WjF9/VV/iddvNAoAtuiusSSNgsS+ndO3mvi+TTdQOWNrU\ntvR6iIkDISYORM0vFwM2G/Q538KUdtAxjXtu+jc8NylXkNk6d3Gsh1uSRkDq3kPlJ0CkLSxtal96\nPWwJiahJSETNo/8DSBL0p3Ial/hHH8Dzow8AALZOUcCY0fDqkwAxcQDEhETI/gEqPwki9bC0SV06\nHWz9+qOmX3/U/OKXSol//x2MaV/BlKYsqeD99+GL9x13EXv0VKb3hAF1RT5Aefk2og6ApU3aotPB\nFt8Xtvi+qF24CJBlhJUWonx/GgyZGcpb1kl4frIF+GSL4262LtH1JZ44AGLiwDY5Z5zcT2VlJXbv\n/j/MnHlPm32PNWt+i+TkkRg9elybfQ87ljZpmyAAvXrBHNQJ5pRZyudkGbr8844CN2RmwJhxEh67\nPoPHrs8cd7WFR9SXeMJAiIkDIHWJ5g86Oxj7ftpNS1uSJOh07vc6MCxtcj+CACm6KyzRXWGZdqfj\n07qLhTBknmwwkWfA4z9fwOM/XziOkYKCHAVuf7N17wm44V9edzVokE+znz9+vMolxzfVcD9tvV4P\nLy9vREVF4ttvc/Dqq39Gaurj2LBhEwBlL+3a2hosWPALFBTk47XXXkFZWSk8PT2Rmvocunbtds3v\nc/ToYXz44fsoKSnG4sVPIClphFP5rhdLm24ZUkQkLBMnwzJxsuNzwpUrMGTVl7gh82T962va7+fr\nBzEh0bE+LiYOhC02DjDwr8etoOF+2unpx5Ga+gTWrn0VRqPfTe+l3VBh4U/429/eRH7+eSxd+hg2\nbdoGo9Ho8ufDP5V0S5NDQmAdMw7WMfVrjUJ5GQzZWfVTeVYGjIcPwXTo6/r7eXlB7NvfsT4uJg6A\n2DseMJnUeBq3FGcn5Bs9vjV9+/ZDVFRUi5exO7uXdkPjxk0EAHTpEo2oqM748ccfmt1c6maxtKnD\nkf0DHOeCO1RVwZCT3WAiz4AhIx3G40fr72c0QozvpxR4/0Rg2CAIIZ2VnQ65Tu42PD09Hbf1ej1s\ntvrXibRYzAAAWZbg5+fveLUbZzSd2K81wd8sljYRAPj4OPZUcTCbYfgup9FZK4Zvs2HMPOk4JBSA\n5B8AW2wsbDFxsMX1glj33tajJ+Dh0f7PhRrx9vZGdd3OlE33xwsKCkZpaQnKy8vh6emJtLSvMGxY\nEry9fdCpUxT27v1Pq3tp2+3d+x9MnjwNFy4U4MKFghbXv28GS5voWjw8IA64DeKA2+o/Z7VCn3sa\nhqwM+F/4EeaMbOjP5MKQlQnjieON7i7rdJC6doMYG+codFtsHMTYXsqLSXA6bxcN99M2mTwQHBzs\n+Jor9tK2i47uhsWLF6GkpBhPPbWyTdazAW7Nqipmco4WMwFNcokidOd+hOFMLvS5udCfyVXKPS8X\nustFV93XMZ3H1he5LTbupqdzLf5aMZNzuDUrUXsyGCD1jIGlZwzQ4OwVABBKS6DPy4U+LxeGuvf6\nvNOtT+f2Iq9bcuF0TgBLm6jNyYFBEAcPgTh4CMwNvyCK0J/7oa7E86DPO11X7KeVc8sbnF8O1E3n\nccpSixjXS1lyccF0Ts7bsGEd9u79DwRBgCzLEAQBY8dOwNy5C9otA5dHVMRMztFiJqBtc101neee\nVpZczv4XgtXa6FjHdB7XCx59eqEyJBK26GhIXaJh69IVcmioqhO6Fn//tJrJGZy0iTTIqem8bu3c\nUFfoHrs/B3Z/Dt+mj+XlBVvnLkqJR3eFFN0VtrpCl6KjIUV2AvT6dnx2dDNY2kTuxGCArWcsbD1j\ngTumNPqSUFKM0MorKMv8Dvr8c9Dln4f+/Pm69z/CkJfb7EPKBgOkqM6wdbFP59GOYpeio2HrHM3l\nFw1haRPdIuSgYKBXN1iir3FaWmUl9PnnlUI/fx76/PPQ5Z9zFLvx0NcQrrFaaguPUAo8uiukLg0K\nvW5al32d+6893TyWNlFH4esLW5942PrEN/91sxm6CwV1ZX4e+vPn6m+fOwdDxkkYjx9r9q5SYKBS\n4F2i69bT64sdiX0A2YNLMC7C0iYihYcHpB49IfXo2fzXbTboLhbWTen1yy/224b/5kHIzmz2rqF6\nPaTQMEjhEZAiIhq/D49s9DG8vdvwSbo/ljYROUevhxTVGVJUZ4hDh139dVmGUFzcYPlFKXPv4iKI\n5wuUrXPP5ELIymjx20h+/pDCwyFFRCrvHcVu/1wEpIhIyMHBHXJLXZY2EbmGIEAOCYEYEgI0uPTf\nO8wPpQ1OrxMqK6C7dBG6ixfr3hdCd+lS3fv6z+v/e+aaa+xA3Q9Qw8KbTO0RjlJvWPJosEmUu2Np\nE1G7kn39YPP1U86AaYnVCt2Vy1eVeePCvwjD96cgZKS3+FBSQGD91B4RAXTtAm9PX0jBIZCCgyEH\nBUMKCoYcEgIpKFjTJc/SJiJtMhohRXZSziNviSxDqChvMq03md7r3gy5px13a/71cOoe0ttbKfSg\nukIPaVDswcH1X6u7LQcHQ/bxbZeLmFjaROTeBAGyfwBs/gHKKw61xGKB7nIRQsQqlJ45D11JMYSS\nYuiuXGl0Wygpga6kGIYzeRCqnXsRBtlobDSty0H1hS4FBSsTfXDj4pcDAq97XZ6lTUQdh8kEKaoz\nEOYHa9dezt3HbFYKvbgYuuIrSrHbbxcX15e9/eOfLsBwKseph5Z1OsiBgZCCQ4AG/wtoCUubiKgl\nHh7KEk1kJ9icvY8oQigtVQq9boq/ZvHX3XYWS5uIyNUMBsihobCFhgJOvkxkmJMP3fFOciQicmMs\nbSIiN8LSJiJyIyxtIiI30uoPIi0WCx588EFYrVbYbDZMmjQJixcvbo9sRETURKulbTKZsGHDBnh5\necFms+H+++/HqFGjkJiY2B75iIioAaeWR7y8vAAoU7coim0aiIiIrs2p0pYkCSkpKUhOTkZycjKn\nbCIilTh1cY1Op8O2bdtQWVmJX/3qV8jLy0NsbAs7dHXvjmDp6i0Vi49nN3t48KD+zX7epcfrhKsy\nqZoHuCqT6nmaZNJEngaZNJPH7tyPmsrD42+N41tzXVdE+vr6YsiQITh48GDLpQ1Ar7t6t6trvkR8\nM8e2xfFNM6mdp2kmLeRpmEkreeyZtJSnxfuolMd+/FX3UznPVffVQJ5GH2skj7MEWW5hl3EAxcXF\nMBqN8PPzQ21tLRYuXIhFixZh9OjRLT5wUYNNz7UgLMyPmZzATM7TYi5mco5WMzmj1Um7qKgIzzzz\nDCRJgiRJmDp1aquFTUREbaPV0u7duze2bt3aHlmIiKgVvCKSiMiNsLSJiNwIS5uIyI2wtImI3AhL\nm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uI\nyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiN\nsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0\niYjciKG1AwoLC5GamorLly9Dr9dj9uzZmDdvXntkIyKiJlotbb1ej2effRbx8fGoqqrC3XffjeTk\nZMTExLRHPiIiaqDV5ZGwsDDEx8cDAHx8fBATE4NLly61eTAiIrrada1p5+fn47vvvkNiYmJb5SEi\noha0ujxiV1VVhaVLl2LlypXw8fFp8dju3QFJuvqY48ermj1+0KDmH8+Vx+t0V2dSMw+AqzKpnadp\nJi3kaZhJK3nszp1r9tOq5eHxt8bxrXGqtEVRxNKlS3HXXXdhwoQJTj2wTnf1EB8W5neNY5t/DFcf\n3zST2nmaZtJCnoaZtJLHnklLeVq6j1p57Mc3vZ/aeZre1kKehh9rJY+zBFmW5dYOSk1NRVBQEJ59\n9lmnH7ioqOKGArWVsDA/ZnICMzlPi7mYyTlazeSMVte0jx8/ju3bt+Obb75BSkoKZs6ciQMHDtx0\nQCIiun6tLo8MGjQIp06dao8sRETUCl4RSUTkRljaRERuhKVNRORGWNpERG6EpU1E5EacviKSiIiu\nnyQBZWVASYmA4uL6t5ISwfG5khIBn37q3OOxtImInGSxoFHRNizgpkWsfAyUlgqQJMFlGVjaRNTh\nyDJQWYmrCvdaU7D981VVzpWvXi8jKEhGaKiMuDgJQUEygoOVt6Ag1L2XHe+DgmQAvk49NkubiG4Z\nNTVAUZGAixcFXLqkw6VLyu2iIuVj5fMCLl8GLBbnLhv39lZKtUcPqVHR1pdw4/INCZHh5wcIrhuu\nG2FpE5GmSZIyEV+6JDhK2F7IDd8uXtShvLzlpvTwkBERIWPgQMDfX2yxfO23vbza6Yk6iaVNRKqo\nqUGjwm1cwvVTcVGRAFFsuYxDQiR07izhtttkhIfLiIiQEB5uvy3X3Zbg769MwMqGUTXt9Exdi6VN\nRC4likBhoYD8fB3OnxdQUQGcPevRYEpWStn5qVhCeLjUqIAblnJYmAyjsZ2enAawtInoutTUABcu\nCDh/Xof8fB3y8+23laK+cEGAzda0kE2OW9eaiusnYuVzbbku7M5Y2kTUSFkZGpVw49sCLl9u/po8\nQZARGSnjZz+T0KWL/U1GfLwnPD2rEBGhnE3RkabitsDSJupAZFlZR25Ywsq0XH+7oqL58dZolNG5\ns4z4eBFdusjo0kVCdLTkuB0VJcNkuvp+YWGeKCqS2viZdRwsbaJbiNUKnDvXtJDrlzIKCgSYzc2X\nso+P3KiEu3SxfywhOlpZtmjppdeofbC0idyMLAM//SQgL0+H3FwdzpzRIS9PeV9QAEhS8xdphIZK\niI9X1pPrC7m+mAMDuYbsDljaRBpVXQ2cOaOUccNyzsvTobr66naNiJCQlARERFgbTczR0TI6d5bg\n7a3CkyCXY2kTqcg+Nefm1heyfWrOz796LcLTU0bPnhJiYxu/xcQoZ1so5x/XqvBMqL2wtInagX1q\nbljK9um5uak5MlLCyJEiYmIal3OXLlxX7uhY2kQuIsvK+csNJ2b7W0HBtafmuDjJUc72277O7R1E\nHRBLm+g6WSzA99/rcOkScOKEqdH03NzU3KmTMjU3XMqIi5PQuTOnZrp+LG2iFtTUADk5OmRm6pGV\npbw/dUoHq9Vezh4AAC+va681c2omV2JpE9WprASys/XIzKwv6dOndY0uyfbwkJGQIKF/fxsGDzYh\nIqIasbGcmqn9sLSpQyopAbKylIJW3utx5kzj1vX2ljF4sA2JiRISEpT3cXGS4zLssDATiopsKqSn\njoylTbe8S5cEx9KGvaTPnWtc0AEBMkaOFJGQICEx0YbERBt69uT0TNrD0qZbhv3sjYblnJmpQ2Fh\n4+YNDZUwbpyIxESbo6S7dpV5NSC5BZY2uSVZBn74QXAUs30N+sqVxgUdFSVh8mRrgwlaQmQkC5rc\nF0ubNM9mA06f1jUq56ws/VWb6HfrJiEpyepYg05IkBAWJquUmqhtsLRJc8xmID1dj6+/1iMtTY/j\nx4Hqah/H1wVBeYXrCRPqp+f+/W0IDFQxNFE7YWmT6mprgRMnlIJOS9Pj2DE9amvrp+h+/YCEBKtj\nDbpfPxvPfaYOi6VN7a6mBjh+vL6kjx/XO/Z4FgQZfftKSE62YfhwG4YPF9G7NzdBIrJjaVObq64G\njh2rL+lQSYIFAAANpklEQVQTJ/SwWOpLun9/CUlJNiQl2TBsmIigIJUDE2kYS5tcrqrq6pK2X/at\n09WXdHKyiKFDuRZNdD1Y2nTTKiuBo0ftJW1AeroOolhf0omJ9klaKemAAJUDE7kxljZdt8pK4MgR\n+9kdBmRk1Je0Xi9jwAAJw4crk/SQITb4+6scmOgW0mppr1y5Evv27UNISAi2b9/eHplIYyoqgMOH\n6yfpjIz6TZT0ehkDB0pIShKRnGzDkCE8s4OoLbVa2nfffTfmzp2L1NTU9shDGlBeDnzzjVLQaWnK\nFYeSpJS0wSDjttskJCeLGD6cJU3U3lot7cGDB6OgoKA9spBKZBk4eVKHXbsMOHgQSE/3dZS00ajs\ndGc/u+P2223w8WnlAYmozXBNu4OyWoFDh/TYtcuAXbsMuHBB2bPDaARuv92G5GSlpAcPtvFVvIk0\npM1KOyzMr60e+oZ19EzV1cAXXwBbtwLbtyt7SgNAYCAwdy6QkgJMmgT4+BigtX/Ptfh7B2gzFzM5\nR4uZnNFmfzOLiira6qFvSFiYX4fMVFICfPGFATt3GrBvnwE1NcqyR2SkhAULREydKiIpyebY2N/H\np2P+Ot0ILeZiJudoNZMznCptWeZOae7kwgUBu3YpRZ2Wpnec6REba8PUqUpRDxwocYN/IjfUammv\nWLEChw8fRmlpKcaMGYMlS5Zg1qxZ7ZGNrkNurg47dypFnZ6ud3z+ttuUop4yRUSvXpKKCYnIFVot\n7bVr17ZHDrpOkqSc8WEv6rw8paj1euVls+xFHRXF/yUR3Uq09dMmapHVCqSl1Z/x8dNPyvqGl5eM\nKVOsmDpVxB13cMMlolsZS1vjqquBvXuVaXr3bgNKS5X16cBAGffeqxT1mDEiT8sj6iBY2hpUUgJ8\n/rlS1Pv315/xERUlYdYspaiHDas/44OIOg6WtkYUFAiOZY+GZ3z06lX/g8SBAyW+IC1RB8fSVtGp\nU8C775qwc6cBJ0/Wn/Hxs5/ZT82zIjaWP0gkonos7XZWWgp89JER775rxKlTAOABg0HGqFH1Z3x0\n6sSiJqLmsbTbgSwDR4/qsGGDCZ9+akBtrQCjUcbMmcCECTWYOFHkq7cQkVNY2m2orEyZqjduNOLU\nKWX5o0cPCXPnmjFnjoi+fX1RVCSqnJKI3AlL28VkGTh2TIeNG0345BPlzA+jUcZdd1kxd64VI0bY\nePk4Ed0wlraLlJUBmzcbsWFD/VTdrZuEuXMtuP9+K8LCuE5NRDePpX0TZBk4cUJZq962TZmqDQYZ\nM2YoU/XIkZyqici1WNo3oLy8fqrOyWk8Vd93nxXh4ZyqiahtsLSdJMtAeroOGzYYsW2bEdXVylQ9\nfboV8+ZZMWoUp2oianss7VZUVChT9caNRmRnK1N11671U3VEBKdqImo/LO1m2F/oduNGI7ZsUaZq\nvV7GtGnKVD16NKdqIlIHS7uBysr6qTorq36q/vnPlTNAOFUTkdpY2gAyMpS16o8/rp+qp05Vpuox\nYzhVE5F2dNjSrqwEtmxRzgDJzFSm6i5dJCxdasEDD1gRGcmpmoi0p8OVdmamDu+8o6xVV1UpU/Xk\nyVbMn69M1Xp9649BRKSWDlHalZXAtm3A3//u7dgCtXNnCYsXK1M1d9UjIndxS5d2eTnwxhsmvP66\nCeXlgE6nw+TJylr12LGcqonI/dySpV1ZCbz1lgl/+5sJpaUCQkIkrF4tICWliq9OTkRu7ZYq7epq\nYN06I/72NxOuXNEhMFDGc8+ZsXChBT16+KGoiIVNRO7tlijt2lpgwwYj/vxnE4qKdPD3l5Gaasai\nRRb4+6udjojIddy6tM1m4N13lbIuLNTBx0fG8uVmPPaYha8EQ0S3JLcsbasVeP99I/70JxMKCnTw\n9paxZIkZv/qVFSEhXAIholuXW5W2KAIffWTA2rUeOHdOB09PGY89ZsGSJRa+yAARdQhuUdo2G7Bl\niwF//KMHzp7VwWSS8cgjFixbZuF+IETUoWi6tCUJ+PRTA1591YTcXD2MRhkPPWTB449beOoeEXVI\nmixtSQJ27lTK+tQpPfR6GT//uVLWXbuyrImo49JUacsy8MUXerz8sgeys/XQ6WTMmWPF8uVm9OjB\nsiYi0kRpyzKwd69S1unpegiCjLvvtuLJJ82IjWVZExHZqVrasgwcPKiU9dGjykYgM2ZY8eSTFvTp\nI6kZjYhIk1Qr7UOH9HjpJRMOHVIiTJlixVNPWdC/P8uaiOha2r20jx7V4aWXPHDwoPKtJ04UkZpq\nxoABLGsiotY49UJaBw4cwOTJkzFp0iS88cYbN/SNTpzQ4b77vDBtmg8OHjRgzBgRu3ZV4b33aljY\nREROanXSliQJL7zwAtavX4/w8HDcc889GD9+PGJiYpz6BllZOrzyigc+/1z5ViNGiEhNtWDYMNvN\nJSci6oBaLe3MzEx069YNnTt3BgBMmzYNe/bsabW0c3J0ePVVE3bsMAIAhg4V8fTTFowYwbImIrpR\nrZb2xYsX0alTJ8fHERERyMrKavE+990HfPihN2RZwKBBNjz9tBmjR9sgCDcfmIioI2u1tGX5+s+T\n3rQJGDBAwtNPmzF+PMuaiMhVWi3tyMhIXLhwwfHxxYsXER4e3uJ9lJ7XA/C+yXiuFRbmp3aEqzCT\nc7SYCdBmLmZyjhYzOaPVs0cSEhJw7tw5FBQUwGKxYMeOHRg/fnx7ZCMioiZanbT1ej1WrVqFhx9+\nGLIs45577nH6zBEiInItQb6RRWsiIlKFUxfXEBGRNrC0iYjcCEubiMiNuHTDqAMHDmDNmjWQZRmz\nZs3CokWLXPnwN2TlypXYt28fQkJCsH37drXjAAAKCwuRmpqKy5cvQ6/XY/bs2Zg3b56qmSwWCx58\n8EFYrVbYbDZMmjQJixcvVjWTnSRJmDVrFiIiIvD666+rHQfjxo2Dr68vdDodDAYDNm/erHYkVFRU\n4LnnnkNubi50Oh3WrFmDAQMGqJrp7NmzeOKJJyAIAmRZxvnz57Fs2TLV/6yvX78emzdvhiAI6NWr\nF1588UWYTCZVM73zzjuOP0et9oHsIjabTZ4wYYKcn58vWywWecaMGXJeXp6rHv6GHT16VM7JyZGn\nT5+udhSHS5cuyTk5ObIsy3JlZaV8xx13aOLXqrq6WpZlWRZFUZ49e7ackZGhciLF22+/La9YsUJ+\n9NFH1Y4iy7Isjxs3Ti4tLVU7RiNPP/20vHnzZlmWZdlqtcoVFRUqJ2rMZrPJycnJ8oULF1TNUVhY\nKI8bN042m82yLMvysmXL5K1bt6qa6fTp0/L06dNls9ksi6IoP/TQQ/KPP/54zeNdtjzScI8So9Ho\n2KNEbYMHD4a/v7/aMRoJCwtDfHw8AMDHxwcxMTG4dOmSyqkALy8vAMrULYqiymkUhYWF2L9/P2bP\nnq12FAdZliFJ2tmZsrKyEseOHcOsWbMAAAaDAb6+viqnaiwtLQ1du3ZttCWGWiRJQk1NDURRRG1t\nbasXC7a1M2fOYODAgTCZTNDr9bj99tuxe/fuax7vstJubo8SLRSR1uXn5+O7775DYmKi2lEgSRJS\nUlKQnJyM5ORkTWRas2YNUlNTIWhoLwRBELBw4ULMmjULH374odpxkJ+fj6CgIDz77LOYOXMmVq1a\nhdraWrVjNbJz505MmzZN7RiIiIjAggULMGbMGIwaNQp+fn5ISkpSNVNcXByOHj2KsrIy1NTU4MCB\nA/jpp5+uebzLSlvm6d7XraqqCkuXLsXKlSvh4+OjdhzodDps27YNBw4cQEZGBvLy8lTNs2/fPoSG\nhiI+Pl5Tf74++OADbNmyBW+++Sbee+89HDt2TNU8oigiJycHDzzwALZu3QpPT88b3ve+LVitVnz5\n5ZeYMmWK2lFQXl6OPXv2YO/evTh48CCqq6tV/1lXTEwMfvGLX2DBggVYtGgR+vTpA4Ph2j9udFlp\n38geJR2ZKIpYunQp7rrrLkyYMEHtOI34+vpiyJAhOHjwoKo5Tpw4gS+//BLjx4/HihUrcPjwYaSm\npqqaCVCWtwAgODgYEydObHXXy7YWGRmJyMhIJCQkAAAmTZqEnJwcVTM1dODAAfTr1w/BwcFqR0Fa\nWhqio6MRGBgIvV6PiRMnIj09Xe1YmDVrFrZs2YKNGzciICAA3bp1u+axLittLe9RoqUpzW7lypWI\njY3F/Pnz1Y4CACguLkZFRQUAoLa2FocOHULPnj1VzbR8+XLs27cPe/bswWuvvYahQ4filVdeUTVT\nTU0NqqqqAADV1dX46quvEBcXp2qm0NBQdOrUCWfPngUAfPPNN5raamLHjh2YPn262jEAAFFRUcjI\nyIDZbIYsy5r5tSouLgYAXLhwAbt3727x18tlp/xpdY8S+4RWWlqKMWPGYMmSJY4f2Kjl+PHj2L59\nO3r16oWUlBQIgoAnnngCo0aNUi1TUVERnnnmGUiSBEmSMHXqVIwePVq1PFp1+fJlLF68GIIgwGaz\n4c4778SIESPUjoXnn38eTz75JERRRHR0NF588UW1IwFQBoC0tDT87ne/UzsKACAxMRGTJk1CSkoK\nDAYD+vbti3vvvVftWFiyZAnKyspgMBjwm9/8Bn5+196BkHuPEBG5EV4RSUTkRljaRERuhKVNRORG\nWNpERG6EpU1E5EZY2kREboSlTUTkRljaRERu5P8D+7Wym3BFpegAAAAASUVORK5CYII=\n", - "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0x7f5be4b8ec50\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "model = Model()\n", - "\n", - "# Collect the history of W-values and b-values to plot later\n", - "Ws, bs = [], []\n", - "epochs = range(10)\n", - "for epoch in epochs:\n", - " Ws.append(model.W.numpy())\n", - " bs.append(model.b.numpy())\n", - " current_loss = loss(model(inputs), outputs)\n", - "\n", - " train(model, inputs, outputs, learning_rate=0.1)\n", - " print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %\n", - " (epoch, Ws[-1], bs[-1], current_loss))\n", - "\n", - "# Let's plot it all\n", - "plt.plot(epochs, Ws, 'r',\n", - " epochs, bs, 'b')\n", - "plt.plot([TRUE_W] * len(epochs), 'r--',\n", - " [TRUE_b] * len(epochs), 'b--')\n", - "plt.legend(['W', 'b', 'true W', 'true_b'])\n", - "plt.show()\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "vPnIVuaSJwWz" - }, - "source": [ - "## Next Steps\n", - "\n", - "In this tutorial we covered `Variable`s and built and trained a simple linear model using the TensorFlow primitives discussed so far.\n", - "\n", - "In theory, this is pretty much all you need to use TensorFlow for your machine learning research.\n", - "In practice, particularly for neural networks, the higher level APIs like `tf.keras` will be much more convenient since it provides higher level building blocks (called \"layers\"), utilities to save and restore state, a suite of loss functions, a suite of optimization strategies etc. \n", - "\n", - "The [next tutorial](TODO) will cover these higher level APIs." - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "default_view": {}, - "name": "Training Models", - "provenance": [], - "version": "0.3.2", - "views": {} - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb deleted file mode 100644 index 5749f22ac58e0a012ed7e3fec4dfe2913d3f8273..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb +++ /dev/null @@ -1,551 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "pwX7Fii1rwsJ" - }, - "outputs": [], - "source": [ - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "tfe = tf.contrib.eager\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "UEu3q4jmpKVT" - }, - "source": [ - "# High level API\n", - "\n", - "We recommend using `tf.keras` as a high-level API for building neural networks. That said, most TensorFlow APIs are usable with eager execution.\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "zSFfVVjkrrsI" - }, - "source": [ - "## Layers: common sets of useful operations\n", - "\n", - "Most of the time when writing code for machine learning models you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n", - "\n", - "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as a well as easy ways for you to write your own application-specific layers either from scratch or as the composition of existing layers.\n", - "\n", - "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "8PyXlPl-4TzQ" - }, - "outputs": [], - "source": [ - "# In the tf.keras.layers package, layers are objects. To construct a layer,\n", - "# simply construct the object. Most layers take as a first argument the number\n", - "# of output dimensions / channels.\n", - "layer = tf.keras.layers.Dense(100)\n", - "# The number of input dimensions is often unnecessary, as it can be inferred\n", - "# the first time the layer is used, but it can be provided if you want to \n", - "# specify it manually, which is useful in some complex models.\n", - "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Fn69xxPO5Psr" - }, - "source": [ - "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). It includes Dense (a fully-connected layer),\n", - "Conv2D, LSTM, BatchNormalization, Dropout, and many others." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 204 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 244, - "status": "ok", - "timestamp": 1527783641557, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "E3XKNknP5Mhb", - "outputId": "c5d52434-d980-4488-efa7-5660819d0207" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\u003ctf.Tensor: id=30, shape=(10, 10), dtype=float32, numpy=\n", - "array([[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)\u003e" - ] - }, - "execution_count": 3, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "# To use a layer, simply call it.\n", - "layer(tf.zeros([10, 5]))" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 221 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 320, - "status": "ok", - "timestamp": 1527783642457, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "Wt_Nsv-L5t2s", - "outputId": "f0d96dce-0128-4080-bfe2-0ee6fbc0ad90" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[\u003ctf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", - " array([[ 0.43788117, -0.62099844, -0.30525017, -0.59352523, 0.1783089 ,\n", - " 0.47078604, -0.23620895, -0.30482283, 0.01366901, -0.1288507 ],\n", - " [ 0.18407935, -0.56550485, 0.54180616, -0.42254075, 0.3702994 ,\n", - " 0.36705834, -0.29678228, 0.36660975, 0.36717761, 0.46269661],\n", - " [ 0.1709305 , -0.11529458, 0.32710236, 0.46300393, -0.62802851,\n", - " 0.51641601, 0.39624029, 0.26918125, -0.25196898, 0.21353298],\n", - " [ 0.35752094, 0.44161648, 0.61500639, -0.12653333, 0.41629118,\n", - " 0.36193585, 0.066082 , -0.59253877, 0.47318751, 0.17115968],\n", - " [-0.22554061, -0.17727301, 0.5525015 , 0.3678053 , -0.00454676,\n", - " 0.24066836, -0.53640735, 0.13792562, -0.10727292, 0.59708995]], dtype=float32)\u003e,\n", - " \u003ctf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)\u003e]" - ] - }, - "execution_count": 4, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "# Layers have many useful methods. For example, you can inspect all variables\n", - "# in a layer by calling layer.variables. In this case a fully-connected layer\n", - "# will have variables for weights and biases.\n", - "layer.variables" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 221 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 226, - "status": "ok", - "timestamp": 1527783643252, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "6ilvKjz8_4MQ", - "outputId": "f647fced-c2d7-41a3-c237-242036784665" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(\u003ctf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", - " array([[ 0.43788117, -0.62099844, -0.30525017, -0.59352523, 0.1783089 ,\n", - " 0.47078604, -0.23620895, -0.30482283, 0.01366901, -0.1288507 ],\n", - " [ 0.18407935, -0.56550485, 0.54180616, -0.42254075, 0.3702994 ,\n", - " 0.36705834, -0.29678228, 0.36660975, 0.36717761, 0.46269661],\n", - " [ 0.1709305 , -0.11529458, 0.32710236, 0.46300393, -0.62802851,\n", - " 0.51641601, 0.39624029, 0.26918125, -0.25196898, 0.21353298],\n", - " [ 0.35752094, 0.44161648, 0.61500639, -0.12653333, 0.41629118,\n", - " 0.36193585, 0.066082 , -0.59253877, 0.47318751, 0.17115968],\n", - " [-0.22554061, -0.17727301, 0.5525015 , 0.3678053 , -0.00454676,\n", - " 0.24066836, -0.53640735, 0.13792562, -0.10727292, 0.59708995]], dtype=float32)\u003e,\n", - " \u003ctf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)\u003e)" - ] - }, - "execution_count": 5, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "# The variables are also accessible through nice accessors\n", - "layer.kernel, layer.bias" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "O0kDbE54-5VS" - }, - "source": [ - "## Implementing custom layers\n", - "The best way to implement your own layer is extending the tf.keras.Layer class and implementing:\n", - " * `__init__` , where you can do all input-independent initialization\n", - " * `build`, where you know the shapes of the input tensors and can do the rest of the initialization\n", - " * `call`, where you do the forward computation\n", - "\n", - "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes required to create the variables will need to be explicitly specified." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 391 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 251, - "status": "ok", - "timestamp": 1527783661512, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "5Byl3n1k5kIy", - "outputId": "6e7f9285-649a-4132-82ce-73ea92f15862" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tf.Tensor(\n", - "[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(10, 10), dtype=float32)\n", - "[\u003ctf.Variable 'my_dense_layer_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", - "array([[-0.4011991 , 0.22458655, -0.33237562, -0.25117266, 0.33528614,\n", - " -0.01392961, 0.58580834, -0.16346583, 0.28465688, -0.47191954],\n", - " [-0.52922136, 0.22416979, -0.58209574, -0.60914612, 0.05226624,\n", - " -0.18325993, 0.5591442 , -0.24718609, 0.37148207, 0.40475875],\n", - " [ 0.16912812, -0.47618777, -0.38989353, 0.30105609, -0.08085585,\n", - " 0.44758242, 0.545829 , 0.51421839, 0.11063248, 0.20159996],\n", - " [ 0.34073615, -0.59835428, 0.06498981, -0.44489855, -0.34302285,\n", - " 0.20969599, 0.35527444, -0.03173476, -0.22227573, 0.09303057],\n", - " [ 0.41764337, -0.06435019, -0.52509922, -0.39957345, 0.56811184,\n", - " 0.23481232, -0.61666459, 0.31144124, -0.11532354, -0.42421889]], dtype=float32)\u003e]\n" - ] - } - ], - "source": [ - "class MyDenseLayer(tf.keras.layers.Layer):\n", - " def __init__(self, num_outputs):\n", - " super(MyDenseLayer, self).__init__()\n", - " self.num_outputs = num_outputs\n", - " \n", - " def build(self, input_shape):\n", - " self.kernel = self.add_variable(\"kernel\", \n", - " shape=[input_shape[-1].value, \n", - " self.num_outputs])\n", - " \n", - " def call(self, input):\n", - " return tf.matmul(input, self.kernel)\n", - " \n", - "layer = MyDenseLayer(10)\n", - "print(layer(tf.zeros([10, 5])))\n", - "print(layer.variables)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "tk8E2vY0-z4Z" - }, - "source": [ - "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`.\n", - "\n", - "Overall code is easier to read and maintain if it uses standard layers whenever possible, as other readers will be familiar with the behavior of standard layers. If you want to use a layer which is not present in tf.keras.layers or tf.contrib.layers, consider filing a [github issue](http://github.com/tensorflow/tensorflow/issues/new) or, even better, sending us a pull request!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Qhg4KlbKrs3G" - }, - "source": [ - "## Models: composing layers\n", - "\n", - "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a resnet is a composition of convolutions, batch normalizations, and a shortcut.\n", - "\n", - "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model. Implementing one is done by inheriting from tf.keras.Model." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 190 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 420, - "status": "ok", - "timestamp": 1527783698512, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "N30DTXiRASlb", - "outputId": "a8b23a8e-5cf9-4bbf-f93b-6c763d74e2b3" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tf.Tensor(\n", - "[[[[ 0. 0. 0.]\n", - " [ 0. 0. 0.]\n", - " [ 0. 0. 0.]]\n", - "\n", - " [[ 0. 0. 0.]\n", - " [ 0. 0. 0.]\n", - " [ 0. 0. 0.]]]], shape=(1, 2, 3, 3), dtype=float32)\n", - "['resnet_identity_block_1/conv2d_3/kernel:0', 'resnet_identity_block_1/conv2d_3/bias:0', 'resnet_identity_block_1/batch_normalization_3/gamma:0', 'resnet_identity_block_1/batch_normalization_3/beta:0', 'resnet_identity_block_1/conv2d_4/kernel:0', 'resnet_identity_block_1/conv2d_4/bias:0', 'resnet_identity_block_1/batch_normalization_4/gamma:0', 'resnet_identity_block_1/batch_normalization_4/beta:0', 'resnet_identity_block_1/conv2d_5/kernel:0', 'resnet_identity_block_1/conv2d_5/bias:0', 'resnet_identity_block_1/batch_normalization_5/gamma:0', 'resnet_identity_block_1/batch_normalization_5/beta:0', 'resnet_identity_block_1/batch_normalization_3/moving_mean:0', 'resnet_identity_block_1/batch_normalization_3/moving_variance:0', 'resnet_identity_block_1/batch_normalization_4/moving_mean:0', 'resnet_identity_block_1/batch_normalization_4/moving_variance:0', 'resnet_identity_block_1/batch_normalization_5/moving_mean:0', 'resnet_identity_block_1/batch_normalization_5/moving_variance:0']\n" - ] - } - ], - "source": [ - "class ResnetIdentityBlock(tf.keras.Model):\n", - " def __init__(self, kernel_size, filters):\n", - " super(ResnetIdentityBlock, self).__init__(name='')\n", - " filters1, filters2, filters3 = filters\n", - "\n", - " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n", - " self.bn2a = tf.keras.layers.BatchNormalization()\n", - "\n", - " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n", - " self.bn2b = tf.keras.layers.BatchNormalization()\n", - "\n", - " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n", - " self.bn2c = tf.keras.layers.BatchNormalization()\n", - "\n", - " def call(self, input_tensor, training=False):\n", - " x = self.conv2a(input_tensor)\n", - " x = self.bn2a(x, training=training)\n", - " x = tf.nn.relu(x)\n", - "\n", - " x = self.conv2b(x)\n", - " x = self.bn2b(x, training=training)\n", - " x = tf.nn.relu(x)\n", - "\n", - " x = self.conv2c(x)\n", - " x = self.bn2c(x, training=training)\n", - "\n", - " x += input_tensor\n", - " return tf.nn.relu(x)\n", - "\n", - " \n", - "block = ResnetIdentityBlock(1, [1, 2, 3])\n", - "print(block(tf.zeros([1, 2, 3, 3])))\n", - "print([x.name for x in block.variables])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "wYfucVw65PMj" - }, - "source": [ - "Much of the time, however, models which compose many layers simply call one layer after the other. This can be done in very little code using tf.keras.Sequential" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 153 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 361, - "status": "ok", - "timestamp": 1526674830777, - "user": { - "displayName": "Alexandre Passos", - "photoUrl": "//lh4.googleusercontent.com/-kmTTWXEgAPw/AAAAAAAAAAI/AAAAAAAAAC0/q_DoOzKGwds/s50-c-k-no/photo.jpg", - "userId": "108023195365833072773" - }, - "user_tz": 420 - }, - "id": "L9frk7Ur4uvJ", - "outputId": "882e9076-b6d9-4380-bb1e-7c6b57d54c39" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\u003ctf.Tensor: id=1423, shape=(1, 2, 3, 3), dtype=float32, numpy=\n", - "array([[[[0., 0., 0.],\n", - " [0., 0., 0.],\n", - " [0., 0., 0.]],\n", - "\n", - " [[0., 0., 0.],\n", - " [0., 0., 0.],\n", - " [0., 0., 0.]]]], dtype=float32)\u003e" - ] - }, - "execution_count": 26, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - " my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n", - " tf.keras.layers.BatchNormalization(),\n", - " tf.keras.layers.Conv2D(2, 1, \n", - " padding='same'),\n", - " tf.keras.layers.BatchNormalization(),\n", - " tf.keras.layers.Conv2D(3, (1, 1)),\n", - " tf.keras.layers.BatchNormalization()])\n", - "my_seq(tf.zeros([1, 2, 3, 3]))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "c5YwYcnuK-wc" - }, - "source": [ - "# Next steps\n", - "\n", - "Now you can go back to the previous notebook and adapt the linear regression example to use layers and models to be better structured." - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "default_view": {}, - "name": "4 - High level API - TensorFlow Eager.ipynb", - "provenance": [], - "version": "0.3.2", - "views": {} - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/README.md b/tensorflow/contrib/eager/python/examples/notebooks/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0d5ed848946d1eee643a57bf8c341520268c56b1 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/README.md @@ -0,0 +1,11 @@ +## Research and experimentation + +Eager execution provides an imperative, define-by-run interface for advanced +operations. Write custom layers, forward passes, and training loops with auto +differentiation. Start with these notebooks, then read the +[eager execution guide](https://www.tensorflow.org/guide/eager). + +1. [Eager execution basics](./eager_basics.ipynb) +2. [Automatic differentiation and gradient tapes](./automatic_differentiation.ipynb) +3. [Custom training: basics](./custom_training.ipynb) +4. [Custom layers](./custom_layers.ipynb) diff --git a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..51b7ffc4de0cee31f7a907ae7bf90f17056f9bcf --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb @@ -0,0 +1,366 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "t09eeeR5prIJ" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "GCCk8_dHpuNf" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xh8WkEwWpnm7" + }, + "source": [ + "# Automatic differentiation and gradient tape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "idv0bPeCp325" + }, + "source": [ + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "vDJ4XzMqodTy" + }, + "source": [ + "In the previous tutorial we introduced `Tensor`s and operations on them. In this tutorial we will cover [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation), a key technique for optimizing machine learning models." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "GQJysDM__Qb0" + }, + "source": [ + "## Setup\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "OiMPZStlibBv" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "\n", + "tfe = tf.contrib.eager # Shorthand for some symbols" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "1CLWJl0QliB0" + }, + "source": [ + "## Derivatives of a function\n", + "\n", + "TensorFlow provides APIs for automatic differentiation - computing the derivative of a function. The way that more closely mimics the math is to encapsulate the computation in a Python function, say `f`, and use `tfe.gradients_function` to create a function that computes the derivatives of `f` with respect to its arguments. If you're familiar with [autograd](https://github.com/HIPS/autograd) for differentiating numpy functions, this will be familiar. For example: " + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "9FViq92UX7P8" + }, + "outputs": [], + "source": [ + "from math import pi\n", + "\n", + "def f(x):\n", + " return tf.square(tf.sin(x))\n", + "\n", + "assert f(pi/2).numpy() == 1.0\n", + "\n", + "\n", + "# grad_f will return a list of derivatives of f\n", + "# with respect to its arguments. Since f() has a single argument,\n", + "# grad_f will return a list with a single element.\n", + "grad_f = tfe.gradients_function(f)\n", + "assert tf.abs(grad_f(pi/2)[0]).numpy() \u003c 1e-7" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "v9fPs8RyopCf" + }, + "source": [ + "### Higher-order gradients\n", + "\n", + "The same API can be used to differentiate as many times as you like:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "3D0ZvnGYo0rW" + }, + "outputs": [], + "source": [ + "def f(x):\n", + " return tf.square(tf.sin(x))\n", + "\n", + "def grad(f):\n", + " return lambda x: tfe.gradients_function(f)(x)[0]\n", + "\n", + "x = tf.lin_space(-2*pi, 2*pi, 100) # 100 points between -2π and +2π\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "plt.plot(x, f(x), label=\"f\")\n", + "plt.plot(x, grad(f)(x), label=\"first derivative\")\n", + "plt.plot(x, grad(grad(f))(x), label=\"second derivative\")\n", + "plt.plot(x, grad(grad(grad(f)))(x), label=\"third derivative\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-39gouo7mtgu" + }, + "source": [ + "## Gradient tapes\n", + "\n", + "Every differentiable TensorFlow operation has an associated gradient function. For example, the gradient function of `tf.square(x)` would be a function that returns `2.0 * x`. To compute the gradient of a user-defined function (like `f(x)` in the example above), TensorFlow first \"records\" all the operations applied to compute the output of the function. We call this record a \"tape\". It then uses that tape and the gradients functions associated with each primitive operation to compute the gradients of the user-defined function using [reverse mode differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation).\n", + "\n", + "Since operations are recorded as they are executed, Python control flow (using `if`s and `while`s for example) is naturally handled:\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "MH0UfjympWf7" + }, + "outputs": [], + "source": [ + "def f(x, y):\n", + " output = 1\n", + " # Must use range(int(y)) instead of range(y) in Python 3 when\n", + " # using TensorFlow 1.10 and earlier. Can use range(y) in 1.11+\n", + " for i in range(int(y)):\n", + " output = tf.multiply(output, x)\n", + " return output\n", + "\n", + "def g(x, y):\n", + " # Return the gradient of `f` with respect to it's first parameter\n", + " return tfe.gradients_function(f)(x, y)[0]\n", + "\n", + "assert f(3.0, 2).numpy() == 9.0 # f(x, 2) is essentially x * x\n", + "assert g(3.0, 2).numpy() == 6.0 # And its gradient will be 2 * x\n", + "assert f(4.0, 3).numpy() == 64.0 # f(x, 3) is essentially x * x * x\n", + "assert g(4.0, 3).numpy() == 48.0 # And its gradient will be 3 * x * x" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "aNmR5-jhpX2t" + }, + "source": [ + "At times it may be inconvenient to encapsulate computation of interest into a function. For example, if you want the gradient of the output with respect to intermediate values computed in the function. In such cases, the slightly more verbose but explicit [tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context is useful. All computation inside the context of a `tf.GradientTape` is \"recorded\".\n", + "\n", + "For example:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "bAFeIE8EuVIq" + }, + "outputs": [], + "source": [ + "x = tf.ones((2, 2))\n", + " \n", + "# TODO(b/78880779): Remove the 'persistent=True' argument and use\n", + "# a single t.gradient() call when the bug is resolved.\n", + "with tf.GradientTape(persistent=True) as t:\n", + " # TODO(ashankar): Explain with \"watch\" argument better?\n", + " t.watch(x)\n", + " y = tf.reduce_sum(x)\n", + " z = tf.multiply(y, y)\n", + "\n", + "# Use the same tape to compute the derivative of z with respect to the\n", + "# intermediate value y.\n", + "dz_dy = t.gradient(z, y)\n", + "assert dz_dy.numpy() == 8.0\n", + "\n", + "# Derivative of z with respect to the original input tensor x\n", + "dz_dx = t.gradient(z, x)\n", + "for i in [0, 1]:\n", + " for j in [0, 1]:\n", + " assert dz_dx[i][j].numpy() == 8.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "DK05KXrAAld3" + }, + "source": [ + "### Higher-order gradients\n", + "\n", + "Operations inside of the `GradientTape` context manager are recorded for automatic differentiation. If gradients are computed in that context, then the gradient computation is recorded as well. As a result, the exact same API works for higher-order gradients as well. For example:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "cPQgthZ7ugRJ" + }, + "outputs": [], + "source": [ + "# TODO(ashankar): Should we use the persistent tape here instead? Follow up on Tom and Alex's discussion\n", + "\n", + "x = tf.constant(1.0) # Convert the Python 1.0 to a Tensor object\n", + "\n", + "with tf.GradientTape() as t:\n", + " with tf.GradientTape() as t2:\n", + " t2.watch(x)\n", + " y = x * x * x\n", + " # Compute the gradient inside the 't' context manager\n", + " # which means the gradient computation is differentiable as well.\n", + " dy_dx = t2.gradient(y, x)\n", + "d2y_dx2 = t.gradient(dy_dx, x)\n", + "\n", + "assert dy_dx.numpy() == 3.0\n", + "assert d2y_dx2.numpy() == 6.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4U1KKzUpNl58" + }, + "source": [ + "## Next Steps\n", + "\n", + "In this tutorial we covered gradient computation in TensorFlow. With that we have enough of the primitives required to build an train neural networks, which we will cover in the [next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/3_neural_networks.ipynb)." + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "automatic_differentiation.ipynb", + "private_outputs": true, + "provenance": [], + "toc_visible": true, + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..a0bbbb612381c5eb386b04fd7bb9914eb01f4c8e --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb @@ -0,0 +1,399 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "custom_layers.ipynb", + "version": "0.3.2", + "views": {}, + "default_view": {}, + "provenance": [], + "private_outputs": true, + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "cells": [ + { + "metadata": { + "id": "tDnwEv8FtJm7", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "##### Copyright 2018 The TensorFlow Authors." + ] + }, + { + "metadata": { + "id": "JlknJBWQtKkI", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "cellView": "form" + }, + "cell_type": "code", + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "60RdWsg1tETW", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Custom layers" + ] + }, + { + "metadata": { + "id": "BcJg7Enms86w", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "
\n", + "\n", + " Run in Google Colab\n", + "\n", + "View source on GitHub
" + ] + }, + { + "metadata": { + "id": "UEu3q4jmpKVT", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "We recommend using `tf.keras` as a high-level API for building neural networks. That said, most TensorFlow APIs are usable with eager execution.\n" + ] + }, + { + "metadata": { + "id": "pwX7Fii1rwsJ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "tfe = tf.contrib.eager\n", + "\n", + "tf.enable_eager_execution()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "zSFfVVjkrrsI", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Layers: common sets of useful operations\n", + "\n", + "Most of the time when writing code for machine learning models you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n", + "\n", + "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as a well as easy ways for you to write your own application-specific layers either from scratch or as the composition of existing layers.\n", + "\n", + "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n" + ] + }, + { + "metadata": { + "id": "8PyXlPl-4TzQ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# In the tf.keras.layers package, layers are objects. To construct a layer,\n", + "# simply construct the object. Most layers take as a first argument the number\n", + "# of output dimensions / channels.\n", + "layer = tf.keras.layers.Dense(100)\n", + "# The number of input dimensions is often unnecessary, as it can be inferred\n", + "# the first time the layer is used, but it can be provided if you want to \n", + "# specify it manually, which is useful in some complex models.\n", + "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Fn69xxPO5Psr", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). It includes Dense (a fully-connected layer),\n", + "Conv2D, LSTM, BatchNormalization, Dropout, and many others." + ] + }, + { + "metadata": { + "id": "E3XKNknP5Mhb", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# To use a layer, simply call it.\n", + "layer(tf.zeros([10, 5]))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Wt_Nsv-L5t2s", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Layers have many useful methods. For example, you can inspect all variables\n", + "# in a layer by calling layer.variables. In this case a fully-connected layer\n", + "# will have variables for weights and biases.\n", + "layer.variables" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "6ilvKjz8_4MQ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# The variables are also accessible through nice accessors\n", + "layer.kernel, layer.bias" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "O0kDbE54-5VS", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Implementing custom layers\n", + "The best way to implement your own layer is extending the tf.keras.Layer class and implementing:\n", + " * `__init__` , where you can do all input-independent initialization\n", + " * `build`, where you know the shapes of the input tensors and can do the rest of the initialization\n", + " * `call`, where you do the forward computation\n", + "\n", + "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes required to create the variables will need to be explicitly specified." + ] + }, + { + "metadata": { + "id": "5Byl3n1k5kIy", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class MyDenseLayer(tf.keras.layers.Layer):\n", + " def __init__(self, num_outputs):\n", + " super(MyDenseLayer, self).__init__()\n", + " self.num_outputs = num_outputs\n", + " \n", + " def build(self, input_shape):\n", + " self.kernel = self.add_variable(\"kernel\", \n", + " shape=[input_shape[-1].value, \n", + " self.num_outputs])\n", + " \n", + " def call(self, input):\n", + " return tf.matmul(input, self.kernel)\n", + " \n", + "layer = MyDenseLayer(10)\n", + "print(layer(tf.zeros([10, 5])))\n", + "print(layer.variables)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "tk8E2vY0-z4Z", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`.\n", + "\n", + "Overall code is easier to read and maintain if it uses standard layers whenever possible, as other readers will be familiar with the behavior of standard layers. If you want to use a layer which is not present in tf.keras.layers or tf.contrib.layers, consider filing a [github issue](http://github.com/tensorflow/tensorflow/issues/new) or, even better, sending us a pull request!" + ] + }, + { + "metadata": { + "id": "Qhg4KlbKrs3G", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Models: composing layers\n", + "\n", + "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a resnet is a composition of convolutions, batch normalizations, and a shortcut.\n", + "\n", + "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model. Implementing one is done by inheriting from tf.keras.Model." + ] + }, + { + "metadata": { + "id": "N30DTXiRASlb", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class ResnetIdentityBlock(tf.keras.Model):\n", + " def __init__(self, kernel_size, filters):\n", + " super(ResnetIdentityBlock, self).__init__(name='')\n", + " filters1, filters2, filters3 = filters\n", + "\n", + " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n", + " self.bn2a = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n", + " self.bn2b = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n", + " self.bn2c = tf.keras.layers.BatchNormalization()\n", + "\n", + " def call(self, input_tensor, training=False):\n", + " x = self.conv2a(input_tensor)\n", + " x = self.bn2a(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2b(x)\n", + " x = self.bn2b(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2c(x)\n", + " x = self.bn2c(x, training=training)\n", + "\n", + " x += input_tensor\n", + " return tf.nn.relu(x)\n", + "\n", + " \n", + "block = ResnetIdentityBlock(1, [1, 2, 3])\n", + "print(block(tf.zeros([1, 2, 3, 3])))\n", + "print([x.name for x in block.variables])" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "wYfucVw65PMj", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Much of the time, however, models which compose many layers simply call one layer after the other. This can be done in very little code using tf.keras.Sequential" + ] + }, + { + "metadata": { + "id": "L9frk7Ur4uvJ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + " my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(2, 1, \n", + " padding='same'),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(3, (1, 1)),\n", + " tf.keras.layers.BatchNormalization()])\n", + "my_seq(tf.zeros([1, 2, 3, 3]))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "c5YwYcnuK-wc", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Next steps\n", + "\n", + "Now you can go back to the previous notebook and adapt the linear regression example to use layers and models to be better structured." + ] + } + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..5f1b48fa0d4aea06adab19a0e561923e1f557e50 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb @@ -0,0 +1,477 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Custom training: basics", + "version": "0.3.2", + "views": {}, + "default_view": {}, + "provenance": [], + "private_outputs": true, + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "metadata": { + "id": "5rmpybwysXGV", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "##### Copyright 2018 The TensorFlow Authors." + ] + }, + { + "metadata": { + "id": "m8y3rGtQsYP2", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "cellView": "form" + }, + "cell_type": "code", + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "hrXv0rU9sIma", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Custom training: basics" + ] + }, + { + "metadata": { + "id": "7S0BwJ_8sLu7", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "
\n", + "\n", + " Run in Google Colab\n", + "\n", + "View source on GitHub
" + ] + }, + { + "metadata": { + "id": "k2o3TTG4TFpt", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "In the previous tutorial we covered the TensorFlow APIs for automatic differentiation, a basic building block for machine learning.\n", + "In this tutorial we will use the TensorFlow primitives introduced in the prior tutorials to do some simple machine learning.\n", + "\n", + "TensorFlow also includes a higher-level neural networks API (`tf.keras`) which provides useful abstractions to reduce boilerplate. We strongly recommend those higher level APIs for people working with neural networks. However, in this short tutorial we cover neural network training from first principles to establish a strong foundation." + ] + }, + { + "metadata": { + "id": "3LXMVuV0VhDr", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Setup" + ] + }, + { + "metadata": { + "id": "PJ64L90aVir3", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "\n", + "tf.enable_eager_execution()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "eMAWbDJFVmMk", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Variables\n", + "\n", + "Tensors in TensorFlow are immutable stateless objects. Machine learning models, however, need to have changing state: as your model trains, the same code to compute predictions should behave differently over time (hopefully with a lower loss!). To represent this state which needs to change over the course of your computation, you can choose to rely on the fact that Python is a stateful programming language:\n" + ] + }, + { + "metadata": { + "id": "VkJwtLS_Jbn8", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Using python state\n", + "x = tf.zeros([10, 10])\n", + "x += 2 # This is equivalent to x = x + 2, which does not mutate the original\n", + " # value of x\n", + "print(x)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "wfneTXy7JcUz", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "TensorFlow, however, has stateful operations built in, and these are often more pleasant to use than low-level Python representations of your state. To represent weights in a model, for example, it's often convenient and efficient to use TensorFlow variables.\n", + "\n", + "A Variable is an object which stores a value and, when used in a TensorFlow computation, will implicitly read from this stored value. There are operations (`tf.assign_sub`, `tf.scatter_update`, etc) which manipulate the value stored in a TensorFlow variable." + ] + }, + { + "metadata": { + "id": "itxmrMil6DQi", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "v = tf.Variable(1.0)\n", + "assert v.numpy() == 1.0\n", + "\n", + "# Re-assign the value\n", + "v.assign(3.0)\n", + "assert v.numpy() == 3.0\n", + "\n", + "# Use `v` in a TensorFlow operation like tf.square() and reassign\n", + "v.assign(tf.square(v))\n", + "assert v.numpy() == 9.0" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "-paSaeq1JzwC", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Computations using Variables are automatically traced when computing gradients. For Variables representing embeddings TensorFlow will do sparse updates by default, which are more computation and memory efficient.\n", + "\n", + "Using Variables is also a way to quickly let a reader of your code know that this piece of state is mutable." + ] + }, + { + "metadata": { + "id": "BMiFcDzE7Qu3", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Example: Fitting a linear model\n", + "\n", + "Let's now put the few concepts we have so far ---`Tensor`, `GradientTape`, `Variable` --- to build and train a simple model. This typically involves a few steps:\n", + "\n", + "1. Define the model.\n", + "2. Define a loss function.\n", + "3. Obtain training data.\n", + "4. Run through the training data and use an \"optimizer\" to adjust the variables to fit the data.\n", + "\n", + "In this tutorial, we'll walk through a trivial example of a simple linear model: `f(x) = x * W + b`, which has two variables - `W` and `b`. Furthermore, we'll synthesize data such that a well trained model would have `W = 3.0` and `b = 2.0`." + ] + }, + { + "metadata": { + "id": "gFzH64Jn9PIm", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Define the model\n", + "\n", + "Let's define a simple class to encapsulate the variables and the computation." + ] + }, + { + "metadata": { + "id": "_WRu7Pze7wk8", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class Model(object):\n", + " def __init__(self):\n", + " # Initialize variable to (5.0, 0.0)\n", + " # In practice, these should be initialized to random values.\n", + " self.W = tf.Variable(5.0)\n", + " self.b = tf.Variable(0.0)\n", + " \n", + " def __call__(self, x):\n", + " return self.W * x + self.b\n", + " \n", + "model = Model()\n", + "\n", + "assert model(3.0).numpy() == 15.0" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "xa6j_yXa-j79", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Define a loss function\n", + "\n", + "A loss function measures how well the output of a model for a given input matches the desired output. Let's use the standard L2 loss." + ] + }, + { + "metadata": { + "id": "Y0ysUFGY924U", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def loss(predicted_y, desired_y):\n", + " return tf.reduce_mean(tf.square(predicted_y - desired_y))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "qutT_fkl_CBc", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Obtain training data\n", + "\n", + "Let's synthesize the training data with some noise." + ] + }, + { + "metadata": { + "id": "gxPTb-kt_N5m", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "TRUE_W = 3.0\n", + "TRUE_b = 2.0\n", + "NUM_EXAMPLES = 1000\n", + "\n", + "inputs = tf.random_normal(shape=[NUM_EXAMPLES])\n", + "noise = tf.random_normal(shape=[NUM_EXAMPLES])\n", + "outputs = inputs * TRUE_W + TRUE_b + noise" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "-50nq-wPBsAW", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Before we train the model let's visualize where the model stands right now. We'll plot the model's predictions in red and the training data in blue." + ] + }, + { + "metadata": { + "id": "_eb83LtrB4nt", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "plt.scatter(inputs, outputs, c='b')\n", + "plt.scatter(inputs, model(inputs), c='r')\n", + "plt.show()\n", + "\n", + "print('Current loss: '),\n", + "print(loss(model(inputs), outputs).numpy())" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "sSDP-yeq_4jE", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Define a training loop\n", + "\n", + "We now have our network and our training data. Let's train it, i.e., use the training data to update the model's variables (`W` and `b`) so that the loss goes down using [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent). There are many variants of the gradient descent scheme that are captured in `tf.train.Optimizer` implementations. We'd highly recommend using those implementations, but in the spirit of building from first principles, in this particular example we will implement the basic math ourselves." + ] + }, + { + "metadata": { + "id": "MBIACgdnA55X", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def train(model, inputs, outputs, learning_rate):\n", + " with tf.GradientTape() as t:\n", + " current_loss = loss(model(inputs), outputs)\n", + " dW, db = t.gradient(current_loss, [model.W, model.b])\n", + " model.W.assign_sub(learning_rate * dW)\n", + " model.b.assign_sub(learning_rate * db)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "RwWPaJryD2aN", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Finally, let's repeatedly run through the training data and see how `W` and `b` evolve." + ] + }, + { + "metadata": { + "id": "XdfkR223D9dW", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "model = Model()\n", + "\n", + "# Collect the history of W-values and b-values to plot later\n", + "Ws, bs = [], []\n", + "epochs = range(10)\n", + "for epoch in epochs:\n", + " Ws.append(model.W.numpy())\n", + " bs.append(model.b.numpy())\n", + " current_loss = loss(model(inputs), outputs)\n", + "\n", + " train(model, inputs, outputs, learning_rate=0.1)\n", + " print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %\n", + " (epoch, Ws[-1], bs[-1], current_loss))\n", + "\n", + "# Let's plot it all\n", + "plt.plot(epochs, Ws, 'r',\n", + " epochs, bs, 'b')\n", + "plt.plot([TRUE_W] * len(epochs), 'r--',\n", + " [TRUE_b] * len(epochs), 'b--')\n", + "plt.legend(['W', 'b', 'true W', 'true_b'])\n", + "plt.show()\n", + " " + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "vPnIVuaSJwWz", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Next Steps\n", + "\n", + "In this tutorial we covered `Variable`s and built and trained a simple linear model using the TensorFlow primitives discussed so far.\n", + "\n", + "In theory, this is pretty much all you need to use TensorFlow for your machine learning research.\n", + "In practice, particularly for neural networks, the higher level APIs like `tf.keras` will be much more convenient since it provides higher level building blocks (called \"layers\"), utilities to save and restore state, a suite of loss functions, a suite of optimization strategies etc. \n", + "\n", + "The [next tutorial](TODO) will cover these higher level APIs." + ] + } + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb similarity index 50% rename from tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb rename to tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb index 51d10a778413cfbb574b4e22e8adcb18bd731dee..f1e13de5dec2fbda126caeb355494875317e3373 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb @@ -1,27 +1,107 @@ { + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "eager_basics.ipynb", + "version": "0.3.2", + "views": {}, + "default_view": {}, + "provenance": [], + "private_outputs": true, + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, "cells": [ { + "metadata": { + "id": "iPpI7RaYoZuE", + "colab_type": "text" + }, "cell_type": "markdown", + "source": [ + "##### Copyright 2018 The TensorFlow Authors." + ] + }, + { "metadata": { - "colab_type": "text", - "id": "U9i2Dsh-ziXr" + "id": "hro2InpHobKk", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "cellView": "form" }, + "cell_type": "code", + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "U9i2Dsh-ziXr", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Eager execution basics" + ] + }, + { + "metadata": { + "id": "Hndw-YcxoOJK", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "
\n", + "\n", + " Run in Google Colab\n", + "\n", + "View source on GitHub
" + ] + }, + { + "metadata": { + "id": "6sILUVbHoSgH", + "colab_type": "text" + }, + "cell_type": "markdown", "source": [ - "# An introduction to TensorFlow\n", - "\n", "This is an introductory tutorial for using TensorFlow. It will cover:\n", "\n", "* Importing required packages\n", "* Creating and using Tensors\n", - "* Using GPU acceleration\n" + "* Using GPU acceleration\n", + "* Datasets" ] }, { - "cell_type": "markdown", "metadata": { - "colab_type": "text", - "id": "z1JcS5iBXMRO" + "id": "z1JcS5iBXMRO", + "colab_type": "text" }, + "cell_type": "markdown", "source": [ "## Import TensorFlow\n", "\n", @@ -30,32 +110,32 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "cellView": "code", + "id": "RlIWhyeLoYnG", + "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, - "colab_type": "code", - "id": "RlIWhyeLoYnG" + "cellView": "code" }, - "outputs": [], + "cell_type": "code", "source": [ "import tensorflow as tf\n", "\n", "tf.enable_eager_execution()" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { - "colab_type": "text", - "id": "H9UySOPLXdaw" + "id": "H9UySOPLXdaw", + "colab_type": "text" }, + "cell_type": "markdown", "source": [ "## Tensors\n", "\n", @@ -63,46 +143,18 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "cellView": "code", + "id": "ngUe237Wt48W", + "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - }, - "height": 125 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 320, - "status": "ok", - "timestamp": 1526420535530, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 + } }, - "id": "ngUe237Wt48W", - "outputId": "b1a1cd60-4eb3-443d-cd6b-68406390784e" + "cellView": "code" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tf.Tensor(3, shape=(), dtype=int32)\n", - "tf.Tensor([4 6], shape=(2,), dtype=int32)\n", - "tf.Tensor(25, shape=(), dtype=int32)\n", - "tf.Tensor(6, shape=(), dtype=int32)\n", - "tf.Tensor(aGVsbG8gd29ybGQ, shape=(), dtype=string)\n", - "tf.Tensor(13, shape=(), dtype=int32)\n" - ] - } - ], + "cell_type": "code", "source": [ "print(tf.add(1, 2))\n", "print(tf.add([1, 2], [3, 4]))\n", @@ -112,66 +164,46 @@ "\n", "# Operator overloading is also supported\n", "print(tf.square(2) + tf.square(3))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { - "colab_type": "text", - "id": "IDY4WsYRhP81" + "id": "IDY4WsYRhP81", + "colab_type": "text" }, + "cell_type": "markdown", "source": [ "Each Tensor has a shape and a datatype" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { + "id": "srYWH1MdJNG7", + "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - }, - "height": 53 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 215, - "status": "ok", - "timestamp": 1526420538162, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "srYWH1MdJNG7", - "outputId": "5e4ac41c-5115-4e50-eba0-42e249c16561" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(1, 2)\n", - "\u003cdtype: 'int32'\u003e\n" - ] + } } - ], + }, + "cell_type": "code", "source": [ "x = tf.matmul([[1]], [[2, 3]])\n", "print(x.shape)\n", "print(x.dtype)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { - "colab_type": "text", - "id": "eBPw8e8vrsom" + "id": "eBPw8e8vrsom", + "colab_type": "text" }, + "cell_type": "markdown", "source": [ "The most obvious differences between NumPy arrays and TensorFlow Tensors are:\n", "\n", @@ -180,11 +212,11 @@ ] }, { - "cell_type": "markdown", "metadata": { - "colab_type": "text", - "id": "Dwi1tdW3JBw6" + "id": "Dwi1tdW3JBw6", + "colab_type": "text" }, + "cell_type": "markdown", "source": [ "### NumPy Compatibility\n", "\n", @@ -197,52 +229,17 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { + "id": "lCUWzso6mbqR", + "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - }, - "height": 251 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 238, - "status": "ok", - "timestamp": 1526420540562, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "lCUWzso6mbqR", - "outputId": "fd0a22bc-8249-49dd-fcbd-63161cc47e46" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TensorFlow operations convert numpy arrays to Tensors automatically\n", - "tf.Tensor(\n", - "[[ 42. 42. 42.]\n", - " [ 42. 42. 42.]\n", - " [ 42. 42. 42.]], shape=(3, 3), dtype=float64)\n", - "And NumPy operations convert Tensors to numpy arrays automatically\n", - "[[ 43. 43. 43.]\n", - " [ 43. 43. 43.]\n", - " [ 43. 43. 43.]]\n", - "The .numpy() method explicitly converts a Tensor to a numpy array\n", - "[[ 42. 42. 42.]\n", - " [ 42. 42. 42.]\n", - " [ 42. 42. 42.]]\n" - ] + } } - ], + }, + "cell_type": "code", "source": [ "import numpy as np\n", "\n", @@ -258,14 +255,16 @@ "\n", "print(\"The .numpy() method explicitly converts a Tensor to a numpy array\")\n", "print(tensor.numpy())" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { - "colab_type": "text", - "id": "PBNP8yTRfu_X" + "id": "PBNP8yTRfu_X", + "colab_type": "text" }, + "cell_type": "markdown", "source": [ "## GPU acceleration\n", "\n", @@ -273,42 +272,18 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "cellView": "code", + "id": "3Twf_Rw-gQFM", + "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - }, - "height": 53 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 340, - "status": "ok", - "timestamp": 1526420543562, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 + } }, - "id": "3Twf_Rw-gQFM", - "outputId": "2239ae2b-adf3-4895-b1f3-464cf5361d1b" + "cellView": "code" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Is there a GPU available: False\n", - "Is the Tensor on GPU #0: False\n" - ] - } - ], + "cell_type": "code", "source": [ "x = tf.random_uniform([3, 3])\n", "\n", @@ -317,26 +292,28 @@ "\n", "print(\"Is the Tensor on GPU #0: \"),\n", "print(x.device.endswith('GPU:0'))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { - "colab_type": "text", - "id": "vpgYzgVXW2Ud" + "id": "vpgYzgVXW2Ud", + "colab_type": "text" }, + "cell_type": "markdown", "source": [ "### Device Names\n", "\n", - "The `Tensor.device` property provides a fully qualified string name of the device hosting the contents of the Tensor. This name encodes a bunch of details, such as an identifier of the network address of the host on which this program is executing and the device within that host. This is required for distributed execution of TensorFlow programs, but we'll skip that for now. The string will end with `GPU:\u003cN\u003e` if the tensor is placed on the `N`-th tensor on the host." + "The `Tensor.device` property provides a fully qualified string name of the device hosting the contents of the Tensor. This name encodes a bunch of details, such as an identifier of the network address of the host on which this program is executing and the device within that host. This is required for distributed execution of TensorFlow programs, but we'll skip that for now. The string will end with `GPU:` if the tensor is placed on the `N`-th tensor on the host." ] }, { - "cell_type": "markdown", "metadata": { - "colab_type": "text", - "id": "ZWZQCimzuqyP" + "id": "ZWZQCimzuqyP", + "colab_type": "text" }, + "cell_type": "markdown", "source": [ "\n", "\n", @@ -346,41 +323,17 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { + "id": "RjkNZTuauy-Q", + "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - }, - "height": 53 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 1762, - "status": "ok", - "timestamp": 1526420547562, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "RjkNZTuauy-Q", - "outputId": "2e613293-ccac-4db2-b793-8ceb5b5adcfd" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "On CPU:\n", - "10 loops, best of 3: 35.8 ms per loop\n" - ] + } } - ], + }, + "cell_type": "code", "source": [ "def time_matmul(x):\n", " %timeit tf.matmul(x, x)\n", @@ -398,32 +351,141 @@ " x = tf.random_uniform([1000, 1000])\n", " assert x.device.endswith(\"GPU:0\")\n", " time_matmul(x)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "o1K4dlhhHtQj", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Datasets\n", + "\n", + "This section demonstrates the use of the [`tf.data.Dataset` API](https://www.tensorflow.org/guide/datasets) to build pipelines to feed data to your model. It covers:\n", + "\n", + "* Creating a `Dataset`.\n", + "* Iteration over a `Dataset` with eager execution enabled.\n", + "\n", + "We recommend using the `Dataset`s API for building performant, complex input pipelines from simple, re-usable pieces that will feed your model's training or evaluation loops.\n", + "\n", + "If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly simpler.\n", + "You can use Python iteration over the `tf.data.Dataset` object and do not need to explicitly create an `tf.data.Iterator` object.\n", + "As a result, the discussion on iterators in the [TensorFlow Guide](https://www.tensorflow.org/guide/datasets) is not relevant when eager execution is enabled." ] }, { + "metadata": { + "id": "zI0fmOynH-Ne", + "colab_type": "text" + }, "cell_type": "markdown", + "source": [ + "### Create a source `Dataset`\n", + "\n", + "Create a _source_ dataset using one of the factory functions like [`Dataset.from_tensors`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensors), [`Dataset.from_tensor_slices`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensor_slices) or using objects that read from files like [`TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TextLineDataset) or [`TFRecordDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset). See the [TensorFlow Guide](https://www.tensorflow.org/guide/datasets#reading_input_data) for more information." + ] + }, + { "metadata": { - "colab_type": "text", - "id": "YEOJTNiOvnpQ" + "id": "F04fVOHQIBiG", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } }, + "cell_type": "code", "source": [ - "## Next Steps\n", + "ds_tensors = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])\n", "\n", - "In this tutorial we covered the most fundamental concepts in TensorFlow - `Tensor`s, operations, and devices.\n", - "In [the next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/2_gradients.ipynb) we will cover automatic differentiation - a building block required for training many machine learning models like neural networks." + "# Create a CSV file\n", + "import tempfile\n", + "_, filename = tempfile.mkstemp()\n", + "\n", + "with open(filename, 'w') as f:\n", + " f.write(\"\"\"Line 1\n", + "Line 2\n", + "Line 3\n", + " \"\"\")\n", + "\n", + "ds_file = tf.data.TextLineDataset(filename)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "vbxIhC-5IPdf", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Apply transformations\n", + "\n", + "Use the transformations functions like [`map`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map), [`batch`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch), [`shuffle`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle) etc. to apply transformations to the records of the dataset. See the [API documentation for `tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) for details." ] + }, + { + "metadata": { + "id": "uXSDZWE-ISsd", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "ds_tensors = ds_tensors.map(tf.square).shuffle(2).batch(2)\n", + "\n", + "ds_file = ds_file.batch(2)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "A8X1GNfoIZKJ", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Iterate\n", + "\n", + "When eager execution is enabled `Dataset` objects support iteration.\n", + "If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that there is no need for calls to `Dataset.make_one_shot_iterator()` or `get_next()` calls." + ] + }, + { + "metadata": { + "id": "ws-WKRk5Ic6-", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "print('Elements of ds_tensors:')\n", + "for x in ds_tensors:\n", + " print(x)\n", + "\n", + "print('\\nElements in ds_file:')\n", + "for x in ds_file:\n", + " print(x)" + ], + "execution_count": 0, + "outputs": [] } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "default_view": {}, - "name": "TensorFlow: An introduction", - "provenance": [], - "version": "0.3.2", - "views": {} - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..ee25d25b52a2e06d9f99bdbe295afd228a3c6ce1 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb @@ -0,0 +1,810 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0TD5ZrvEMbhZ" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n", + "\n", + "# Pix2Pix: An example with tf.keras and eager\n", + "\n", + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ITZuApL56Mny" + }, + "source": [ + "This notebook demonstrates image to image translation using conditional GAN's, as described in [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004). Using this technique we can colorize black and white photos, convert google maps to google earth, etc. Here, we convert building facades to real buildings. We use [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager) to achieve this.\n", + "\n", + "In example, we will use the [CMP Facade Database](http://cmp.felk.cvut.cz/~tylecr1/facade/), helpfully provided by the [Center for Machine Perception](http://cmp.felk.cvut.cz/) at the [Czech Technical University in Prague](https://www.cvut.cz/). To keep our example short, we will use a preprocessed [copy](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/) of this dataset, created by the authors of the [paper](https://arxiv.org/abs/1611.07004) above.\n", + "\n", + "Each epoch takes around 58 seconds on a single P100 GPU.\n", + "\n", + "Below is the output generated after training the model for 200 epochs.\n", + "\n", + "\n", + "![sample output_1](https://www.tensorflow.org/images/gan/pix2pix_1.png)\n", + "![sample output_2](https://www.tensorflow.org/images/gan/pix2pix_2.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "e1_Y75QXJS6h" + }, + "source": [ + "## Import TensorFlow and enable eager execution" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "YfIk2es3hJEd" + }, + "outputs": [], + "source": [ + "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "\n", + "import os\n", + "import time\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import PIL\n", + "from IPython.display import clear_output" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "iYn4MdZnKCey" + }, + "source": [ + "## Load the dataset\n", + "\n", + "You can download this dataset and similar datasets from [here](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets). As mentioned in the [paper](https://arxiv.org/abs/1611.07004) we apply random jittering and mirroring to the training dataset.\n", + "* In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`\n", + "* In random mirroring, the image is randomly flipped horizontally i.e left to right." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Kn-k8kTXuAlv" + }, + "outputs": [], + "source": [ + "path_to_zip = tf.keras.utils.get_file('facades.tar.gz',\n", + " cache_subdir=os.path.abspath('.'),\n", + " origin='https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz', \n", + " extract=True)\n", + "\n", + "PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "2CbTEt448b4R" + }, + "outputs": [], + "source": [ + "BUFFER_SIZE = 400\n", + "BATCH_SIZE = 1\n", + "IMG_WIDTH = 256\n", + "IMG_HEIGHT = 256" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "tyaP4hLJ8b4W" + }, + "outputs": [], + "source": [ + "def load_image(image_file, is_train):\n", + " image = tf.read_file(image_file)\n", + " image = tf.image.decode_jpeg(image)\n", + "\n", + " w = tf.shape(image)[1]\n", + "\n", + " w = w // 2\n", + " real_image = image[:, :w, :]\n", + " input_image = image[:, w:, :]\n", + "\n", + " input_image = tf.cast(input_image, tf.float32)\n", + " real_image = tf.cast(real_image, tf.float32)\n", + "\n", + " if is_train:\n", + " # random jittering\n", + " \n", + " # resizing to 286 x 286 x 3\n", + " # method = 2 indicates using \"ResizeMethod.NEAREST_NEIGHBOR\"\n", + " input_image = tf.image.resize_images(input_image, [286, 286], \n", + " align_corners=True, method=2)\n", + " real_image = tf.image.resize_images(real_image, [286, 286], \n", + " align_corners=True, method=2)\n", + " \n", + " # randomly cropping to 256 x 256 x 3\n", + " stacked_image = tf.stack([input_image, real_image], axis=0)\n", + " cropped_image = tf.random_crop(stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])\n", + " input_image, real_image = cropped_image[0], cropped_image[1]\n", + "\n", + " if np.random.random() \u003e 0.5:\n", + " # random mirroring\n", + " input_image = tf.image.flip_left_right(input_image)\n", + " real_image = tf.image.flip_left_right(real_image)\n", + " else:\n", + " input_image = tf.image.resize_images(input_image, size=[IMG_HEIGHT, IMG_WIDTH], \n", + " align_corners=True, method=2)\n", + " real_image = tf.image.resize_images(real_image, size=[IMG_HEIGHT, IMG_WIDTH], \n", + " align_corners=True, method=2)\n", + " \n", + " # normalizing the images to [-1, 1]\n", + " input_image = (input_image / 127.5) - 1\n", + " real_image = (real_image / 127.5) - 1\n", + "\n", + " return input_image, real_image" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PIGN6ouoQxt3" + }, + "source": [ + "## Use tf.data to create batches, map(do preprocessing) and shuffle the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "SQHmYSmk8b4b" + }, + "outputs": [], + "source": [ + "train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')\n", + "train_dataset = train_dataset.shuffle(BUFFER_SIZE)\n", + "train_dataset = train_dataset.map(lambda x: load_image(x, True))\n", + "train_dataset = train_dataset.batch(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "MS9J0yA58b4g" + }, + "outputs": [], + "source": [ + "test_dataset = tf.data.Dataset.list_files(PATH+'test/*.jpg')\n", + "test_dataset = test_dataset.map(lambda x: load_image(x, False))\n", + "test_dataset = test_dataset.batch(1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "THY-sZMiQ4UV" + }, + "source": [ + "## Write the generator and discriminator models\n", + "\n", + "* **Generator** \n", + " * The architecture of generator is a modified U-Net.\n", + " * Each block in the encoder is (Conv -\u003e Batchnorm -\u003e Leaky ReLU)\n", + " * Each block in the decoder is (Transposed Conv -\u003e Batchnorm -\u003e Dropout(applied to the first 3 blocks) -\u003e ReLU)\n", + " * There are skip connections between the encoder and decoder (as in U-Net).\n", + " \n", + "* **Discriminator**\n", + " * The Discriminator is a PatchGAN.\n", + " * Each block in the discriminator is (Conv -\u003e BatchNorm -\u003e Leaky ReLU)\n", + " * The shape of the output after the last layer is (batch_size, 30, 30, 1)\n", + " * Each 30x30 patch of the output classifies a 70x70 portion of the input image (such an architecture is called a PatchGAN).\n", + " * Discriminator receives 2 inputs.\n", + " * Input image and the target image, which it should classify as real.\n", + " * Input image and the generated image (output of generator), which it should classify as fake. \n", + " * We concatenate these 2 inputs together in the code (`tf.concat([inp, tar], axis=-1)`)\n", + "\n", + "* Shape of the input travelling through the generator and the discriminator is in the comments in the code.\n", + "\n", + "To learn more about the architecture and the hyperparameters you can refer the [paper](https://arxiv.org/abs/1611.07004).\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "tqqvWxlw8b4l" + }, + "outputs": [], + "source": [ + "OUTPUT_CHANNELS = 3" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "lFPI4Nu-8b4q" + }, + "outputs": [], + "source": [ + "class Downsample(tf.keras.Model):\n", + " \n", + " def __init__(self, filters, size, apply_batchnorm=True):\n", + " super(Downsample, self).__init__()\n", + " self.apply_batchnorm = apply_batchnorm\n", + " initializer = tf.random_normal_initializer(0., 0.02)\n", + "\n", + " self.conv1 = tf.keras.layers.Conv2D(filters, \n", + " (size, size), \n", + " strides=2, \n", + " padding='same',\n", + " kernel_initializer=initializer,\n", + " use_bias=False)\n", + " if self.apply_batchnorm:\n", + " self.batchnorm = tf.keras.layers.BatchNormalization()\n", + " \n", + " def call(self, x, training):\n", + " x = self.conv1(x)\n", + " if self.apply_batchnorm:\n", + " x = self.batchnorm(x, training=training)\n", + " x = tf.nn.leaky_relu(x)\n", + " return x \n", + "\n", + "\n", + "class Upsample(tf.keras.Model):\n", + " \n", + " def __init__(self, filters, size, apply_dropout=False):\n", + " super(Upsample, self).__init__()\n", + " self.apply_dropout = apply_dropout\n", + " initializer = tf.random_normal_initializer(0., 0.02)\n", + "\n", + " self.up_conv = tf.keras.layers.Conv2DTranspose(filters, \n", + " (size, size), \n", + " strides=2, \n", + " padding='same',\n", + " kernel_initializer=initializer,\n", + " use_bias=False)\n", + " self.batchnorm = tf.keras.layers.BatchNormalization()\n", + " if self.apply_dropout:\n", + " self.dropout = tf.keras.layers.Dropout(0.5)\n", + "\n", + " def call(self, x1, x2, training):\n", + " x = self.up_conv(x1)\n", + " x = self.batchnorm(x, training=training)\n", + " if self.apply_dropout:\n", + " x = self.dropout(x, training=training)\n", + " x = tf.nn.relu(x)\n", + " x = tf.concat([x, x2], axis=-1)\n", + " return x\n", + "\n", + "\n", + "class Generator(tf.keras.Model):\n", + " \n", + " def __init__(self):\n", + " super(Generator, self).__init__()\n", + " initializer = tf.random_normal_initializer(0., 0.02)\n", + " \n", + " self.down1 = Downsample(64, 4, apply_batchnorm=False)\n", + " self.down2 = Downsample(128, 4)\n", + " self.down3 = Downsample(256, 4)\n", + " self.down4 = Downsample(512, 4)\n", + " self.down5 = Downsample(512, 4)\n", + " self.down6 = Downsample(512, 4)\n", + " self.down7 = Downsample(512, 4)\n", + " self.down8 = Downsample(512, 4)\n", + "\n", + " self.up1 = Upsample(512, 4, apply_dropout=True)\n", + " self.up2 = Upsample(512, 4, apply_dropout=True)\n", + " self.up3 = Upsample(512, 4, apply_dropout=True)\n", + " self.up4 = Upsample(512, 4)\n", + " self.up5 = Upsample(256, 4)\n", + " self.up6 = Upsample(128, 4)\n", + " self.up7 = Upsample(64, 4)\n", + "\n", + " self.last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, \n", + " (4, 4), \n", + " strides=2, \n", + " padding='same',\n", + " kernel_initializer=initializer)\n", + " \n", + " @tf.contrib.eager.defun\n", + " def call(self, x, training):\n", + " # x shape == (bs, 256, 256, 3) \n", + " x1 = self.down1(x, training=training) # (bs, 128, 128, 64)\n", + " x2 = self.down2(x1, training=training) # (bs, 64, 64, 128)\n", + " x3 = self.down3(x2, training=training) # (bs, 32, 32, 256)\n", + " x4 = self.down4(x3, training=training) # (bs, 16, 16, 512)\n", + " x5 = self.down5(x4, training=training) # (bs, 8, 8, 512)\n", + " x6 = self.down6(x5, training=training) # (bs, 4, 4, 512)\n", + " x7 = self.down7(x6, training=training) # (bs, 2, 2, 512)\n", + " x8 = self.down8(x7, training=training) # (bs, 1, 1, 512)\n", + "\n", + " x9 = self.up1(x8, x7, training=training) # (bs, 2, 2, 1024)\n", + " x10 = self.up2(x9, x6, training=training) # (bs, 4, 4, 1024)\n", + " x11 = self.up3(x10, x5, training=training) # (bs, 8, 8, 1024)\n", + " x12 = self.up4(x11, x4, training=training) # (bs, 16, 16, 1024)\n", + " x13 = self.up5(x12, x3, training=training) # (bs, 32, 32, 512)\n", + " x14 = self.up6(x13, x2, training=training) # (bs, 64, 64, 256)\n", + " x15 = self.up7(x14, x1, training=training) # (bs, 128, 128, 128)\n", + "\n", + " x16 = self.last(x15) # (bs, 256, 256, 3)\n", + " x16 = tf.nn.tanh(x16)\n", + "\n", + " return x16" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "ll6aNeQx8b4v" + }, + "outputs": [], + "source": [ + "class DiscDownsample(tf.keras.Model):\n", + " \n", + " def __init__(self, filters, size, apply_batchnorm=True):\n", + " super(DiscDownsample, self).__init__()\n", + " self.apply_batchnorm = apply_batchnorm\n", + " initializer = tf.random_normal_initializer(0., 0.02)\n", + "\n", + " self.conv1 = tf.keras.layers.Conv2D(filters, \n", + " (size, size), \n", + " strides=2, \n", + " padding='same',\n", + " kernel_initializer=initializer,\n", + " use_bias=False)\n", + " if self.apply_batchnorm:\n", + " self.batchnorm = tf.keras.layers.BatchNormalization()\n", + " \n", + " def call(self, x, training):\n", + " x = self.conv1(x)\n", + " if self.apply_batchnorm:\n", + " x = self.batchnorm(x, training=training)\n", + " x = tf.nn.leaky_relu(x)\n", + " return x \n", + "\n", + "class Discriminator(tf.keras.Model):\n", + " \n", + " def __init__(self):\n", + " super(Discriminator, self).__init__()\n", + " initializer = tf.random_normal_initializer(0., 0.02)\n", + " \n", + " self.down1 = DiscDownsample(64, 4, False)\n", + " self.down2 = DiscDownsample(128, 4)\n", + " self.down3 = DiscDownsample(256, 4)\n", + " \n", + " # we are zero padding here with 1 because we need our shape to \n", + " # go from (batch_size, 32, 32, 256) to (batch_size, 31, 31, 512)\n", + " self.zero_pad1 = tf.keras.layers.ZeroPadding2D()\n", + " self.conv = tf.keras.layers.Conv2D(512, \n", + " (4, 4), \n", + " strides=1, \n", + " kernel_initializer=initializer, \n", + " use_bias=False)\n", + " self.batchnorm1 = tf.keras.layers.BatchNormalization()\n", + " \n", + " # shape change from (batch_size, 31, 31, 512) to (batch_size, 30, 30, 1)\n", + " self.zero_pad2 = tf.keras.layers.ZeroPadding2D()\n", + " self.last = tf.keras.layers.Conv2D(1, \n", + " (4, 4), \n", + " strides=1,\n", + " kernel_initializer=initializer)\n", + " \n", + " @tf.contrib.eager.defun\n", + " def call(self, inp, tar, training):\n", + " # concatenating the input and the target\n", + " x = tf.concat([inp, tar], axis=-1) # (bs, 256, 256, channels*2)\n", + " x = self.down1(x, training=training) # (bs, 128, 128, 64)\n", + " x = self.down2(x, training=training) # (bs, 64, 64, 128)\n", + " x = self.down3(x, training=training) # (bs, 32, 32, 256)\n", + "\n", + " x = self.zero_pad1(x) # (bs, 34, 34, 256)\n", + " x = self.conv(x) # (bs, 31, 31, 512)\n", + " x = self.batchnorm1(x, training=training)\n", + " x = tf.nn.leaky_relu(x)\n", + " \n", + " x = self.zero_pad2(x) # (bs, 33, 33, 512)\n", + " # don't add a sigmoid activation here since\n", + " # the loss function expects raw logits.\n", + " x = self.last(x) # (bs, 30, 30, 1)\n", + "\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "gDkA05NE6QMs" + }, + "outputs": [], + "source": [ + "# The call function of Generator and Discriminator have been decorated\n", + "# with tf.contrib.eager.defun()\n", + "# We get a performance speedup if defun is used (~25 seconds per epoch)\n", + "generator = Generator()\n", + "discriminator = Discriminator()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0FMYgY_mPfTi" + }, + "source": [ + "## Define the loss functions and the optimizer\n", + "\n", + "* **Discriminator loss**\n", + " * The discriminator loss function takes 2 inputs; **real images, generated images**\n", + " * real_loss is a sigmoid cross entropy loss of the **real images** and an **array of ones(since these are the real images)**\n", + " * generated_loss is a sigmoid cross entropy loss of the **generated images** and an **array of zeros(since these are the fake images)**\n", + " * Then the total_loss is the sum of real_loss and the generated_loss\n", + " \n", + "* **Generator loss**\n", + " * It is a sigmoid cross entropy loss of the generated images and an **array of ones**.\n", + " * The [paper](https://arxiv.org/abs/1611.07004) also includes L1 loss which is MAE (mean absolute error) between the generated image and the target image.\n", + " * This allows the generated image to become structurally similar to the target image.\n", + " * The formula to calculate the total generator loss = gan_loss + LAMBDA * l1_loss, where LAMBDA = 100. This value was decided by the authors of the [paper](https://arxiv.org/abs/1611.07004)." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "cyhxTuvJyIHV" + }, + "outputs": [], + "source": [ + "LAMBDA = 100" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "wkMNfBWlT-PV" + }, + "outputs": [], + "source": [ + "def discriminator_loss(disc_real_output, disc_generated_output):\n", + " real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.ones_like(disc_real_output), \n", + " logits = disc_real_output)\n", + " generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.zeros_like(disc_generated_output), \n", + " logits = disc_generated_output)\n", + "\n", + " total_disc_loss = real_loss + generated_loss\n", + "\n", + " return total_disc_loss" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "90BIcCKcDMxz" + }, + "outputs": [], + "source": [ + "def generator_loss(disc_generated_output, gen_output, target):\n", + " gan_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.ones_like(disc_generated_output),\n", + " logits = disc_generated_output) \n", + " # mean absolute error\n", + " l1_loss = tf.reduce_mean(tf.abs(target - gen_output))\n", + "\n", + " total_gen_loss = gan_loss + (LAMBDA * l1_loss)\n", + "\n", + " return total_gen_loss" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "iWCn_PVdEJZ7" + }, + "outputs": [], + "source": [ + "generator_optimizer = tf.train.AdamOptimizer(2e-4, beta1=0.5)\n", + "discriminator_optimizer = tf.train.AdamOptimizer(2e-4, beta1=0.5)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "aKUZnDiqQrAh" + }, + "source": [ + "## Checkpoints (Object-based saving)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "WJnftd5sQsv6" + }, + "outputs": [], + "source": [ + "checkpoint_dir = './training_checkpoints'\n", + "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", + "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n", + " discriminator_optimizer=discriminator_optimizer,\n", + " generator=generator,\n", + " discriminator=discriminator)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Rw1fkAczTQYh" + }, + "source": [ + "## Training\n", + "\n", + "* We start by iterating over the dataset\n", + "* The generator gets the input image and we get a generated output.\n", + "* The discriminator receives the input_image and the generated image as the first input. The second input is the input_image and the target_image.\n", + "* Next, we calculate the generator and the discriminator loss.\n", + "* Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables(inputs) and apply those to the optimizer.\n", + "\n", + "## Generate Images\n", + "\n", + "* After training, its time to generate some images!\n", + "* We pass images from the test dataset to the generator.\n", + "* The generator will then translate the input image into the output we expect.\n", + "* Last step is to plot the predictions and **voila!**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "NS2GWywBbAWo" + }, + "outputs": [], + "source": [ + "EPOCHS = 200" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "RmdVsmvhPxyy" + }, + "outputs": [], + "source": [ + "def generate_images(model, test_input, tar):\n", + " # the training=True is intentional here since\n", + " # we want the batch statistics while running the model\n", + " # on the test dataset. If we use training=False, we will get \n", + " # the accumulated statistics learned from the training dataset\n", + " # (which we don't want)\n", + " prediction = model(test_input, training=True)\n", + " plt.figure(figsize=(15,15))\n", + "\n", + " display_list = [test_input[0], tar[0], prediction[0]]\n", + " title = ['Input Image', 'Ground Truth', 'Predicted Image']\n", + "\n", + " for i in range(3):\n", + " plt.subplot(1, 3, i+1)\n", + " plt.title(title[i])\n", + " # getting the pixel values between [0, 1] to plot it.\n", + " plt.imshow(display_list[i] * 0.5 + 0.5)\n", + " plt.axis('off')\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "2M7LmLtGEMQJ" + }, + "outputs": [], + "source": [ + "def train(dataset, epochs): \n", + " for epoch in range(epochs):\n", + " start = time.time()\n", + "\n", + " for input_image, target in dataset:\n", + "\n", + " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", + " gen_output = generator(input_image, training=True)\n", + "\n", + " disc_real_output = discriminator(input_image, target, training=True)\n", + " disc_generated_output = discriminator(input_image, gen_output, training=True)\n", + "\n", + " gen_loss = generator_loss(disc_generated_output, gen_output, target)\n", + " disc_loss = discriminator_loss(disc_real_output, disc_generated_output)\n", + "\n", + " generator_gradients = gen_tape.gradient(gen_loss, \n", + " generator.variables)\n", + " discriminator_gradients = disc_tape.gradient(disc_loss, \n", + " discriminator.variables)\n", + "\n", + " generator_optimizer.apply_gradients(zip(generator_gradients, \n", + " generator.variables))\n", + " discriminator_optimizer.apply_gradients(zip(discriminator_gradients, \n", + " discriminator.variables))\n", + "\n", + " if epoch % 1 == 0:\n", + " clear_output(wait=True)\n", + " for inp, tar in test_dataset.take(1):\n", + " generate_images(generator, inp, tar)\n", + " \n", + " # saving (checkpoint) the model every 20 epochs\n", + " if (epoch + 1) % 20 == 0:\n", + " checkpoint.save(file_prefix = checkpoint_prefix)\n", + "\n", + " print ('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n", + " time.time()-start))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "a1zZmKmvOH85" + }, + "outputs": [], + "source": [ + "train(train_dataset, EPOCHS)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "kz80bY3aQ1VZ" + }, + "source": [ + "## Restore the latest checkpoint and test" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "4t4x69adQ5xb" + }, + "outputs": [], + "source": [ + "# restoring the latest checkpoint in checkpoint_dir\n", + "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "1RGysMU_BZhx" + }, + "source": [ + "## Testing on the entire test dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "KUgSnmy2nqSP" + }, + "outputs": [], + "source": [ + "# Run the trained model on the entire test dataset\n", + "for inp, tar in test_dataset:\n", + " generate_images(generator, inp, tar)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "3AJXOByaZVOf" + }, + "outputs": [], + "source": [ + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "pix2pix_eager.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1eb0NOTQapkYs3X0v-zL1x5_LFKgDISnp", + "timestamp": 1527173385672 + } + ], + "toc_visible": true, + "version": "0.3.2" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/resnet50/BUILD b/tensorflow/contrib/eager/python/examples/resnet50/BUILD index 0c0e28dd95c68dc300384a128eb5aa2208f63a0d..68a84d5fbb4f13e4ebe0d71e3f5caebe97e2101c 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/BUILD +++ b/tensorflow/contrib/eager/python/examples/resnet50/BUILD @@ -51,5 +51,6 @@ cuda_py_test( "noasan", "nomsan", "notsan", + "optonly", ], ) diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index b14ef1df8ff4c660b9b6f2abfd5df6572d10b1e8..07d8788882c2d831dfb041fe7409af51857190bf 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -29,6 +29,7 @@ import tensorflow.contrib.eager as tfe from tensorflow.contrib.eager.python.examples.resnet50 import resnet50 from tensorflow.contrib.summary import summary_test_util from tensorflow.python.client import device_lib +from tensorflow.python.eager import tape def device_and_data_format(): @@ -49,13 +50,21 @@ def random_batch(batch_size, data_format): return images, one_hot -def compute_gradients(model, images, labels): - with tf.GradientTape() as tape: +def compute_gradients(model, images, labels, num_replicas=1): + with tf.GradientTape() as grad_tape: logits = model(images, training=True) loss = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=labels) tf.contrib.summary.scalar(name='loss', tensor=loss) - return tape.gradient(loss, model.variables) + if num_replicas != 1: + loss /= num_replicas + + # TODO(b/110991947): We can mistakenly trace the gradient call in + # multi-threaded environment. Explicitly disable recording until + # this is fixed. + with tape.stop_recording(): + grads = grad_tape.gradient(loss, model.variables) + return grads def apply_gradients(model, optimizer, gradients): @@ -188,11 +197,14 @@ class ResNet50Benchmarks(tf.test.Benchmark): return (32,) return (16, 32) - def _report(self, label, start, num_iters, device, batch_size, data_format): + def _report(self, label, start, num_iters, device, batch_size, data_format, + num_replicas=1): avg_time = (time.time() - start) / num_iters dev = tf.DeviceSpec.from_string(device).device_type.lower() - name = '%s_%s_batch_%d_%s' % (label, dev, batch_size, data_format) - extras = {'examples_per_sec': batch_size / avg_time} + replica_str = '' if num_replicas == 1 else 'replicas_%d_' % num_replicas + name = '%s_%s_batch_%d_%s%s' % (label, dev, batch_size, + replica_str, data_format) + extras = {'examples_per_sec': (num_replicas * batch_size) / avg_time} self.report_benchmark( iters=num_iters, wall_time=avg_time, name=name, extras=extras) diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..4f0d46b1bae3760a63b2abe871034bdedf258f07 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD @@ -0,0 +1,172 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +# Model +py_library( + name = "ops", + srcs = ["ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "config", + srcs = ["config.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "blocks", + srcs = ["blocks.py"], + srcs_version = "PY2AND3", + deps = [ + ":ops", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "revnet", + srcs = ["revnet.py"], + srcs_version = "PY2AND3", + deps = [ + ":blocks", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "resnet_preprocessing", + srcs = ["resnet_preprocessing.py"], + srcs_version = "PY2AND3", + tags = ["local"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "imagenet_input", + srcs = ["imagenet_input.py"], + srcs_version = "PY2AND3", + tags = ["local"], + deps = [ + ":resnet_preprocessing", + "//tensorflow:tensorflow_py", + ], +) + +# Tests +cuda_py_test( + name = "ops_test", + size = "large", + srcs = ["ops_test.py"], + additional_deps = [ + ":ops", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "blocks_test", + size = "large", + srcs = ["blocks_test.py"], + additional_deps = [ + ":blocks", + "//tensorflow:tensorflow_py", + ], + tags = [ + "optonly", + ], +) + +cuda_py_test( + name = "revnet_test", + size = "large", + srcs = ["revnet_test.py"], + additional_deps = [ + ":blocks_test", + ":config", + ":revnet", + "//tensorflow:tensorflow_py", + ], + tags = [ + "no_pip", # depends on blocks_test, which is not available in pip package + "optonly", + ], +) + +# Training +py_library( + name = "cifar_input", + srcs = ["cifar_input.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "cifar_tfrecords", + srcs = ["cifar_tfrecords.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "main", + srcs = ["main.py"], + srcs_version = "PY2AND3", + deps = [ + ":cifar_input", + ":config", + ":revnet", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "main_estimator", + srcs = ["main_estimator.py"], + srcs_version = "PY2AND3", + deps = [ + ":cifar_input", + ":main", + ":revnet", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "main_estimator_lib", + srcs = ["main_estimator.py"], + srcs_version = "PY2AND3", + deps = [ + ":cifar_input", + ":main", + ":revnet", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "main_estimator_tpu_lib", + srcs = ["main_estimator_tpu.py"], + srcs_version = "PY2AND3", + deps = [ + ":cifar_input", + ":main", + ":revnet", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/revnet/README.md b/tensorflow/contrib/eager/python/examples/revnet/README.md new file mode 100644 index 0000000000000000000000000000000000000000..822d86e9c7a7e620da3b84ded9af98b1c1d4b701 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/README.md @@ -0,0 +1,112 @@ +# RevNet with TensorFlow eager execution + +This folder contains a TensorFlow eager implementation of the [Reversible Residual Network](https://arxiv.org/pdf/1707.04585.pdf) adapted from the released implementation by the authors. The presented implementation can be ran with both eager and graph execution. The code is considerably simplified with `tf.GradientTape`. Moreover, we reduce the a redundant forward pass in the implementation by the authors. This saves us from using `tf.stop_gradient` and makes the model run faster. + +## Content + +- `revnet.py`: The RevNet model. +- `blocks.py`: The relevant reversible blocks. +- `ops.py`: Auxiliary downsampling operation. +- `cifar_tfrecords.py`: Script to generate the TFRecords for both CIFAR-10 and CIFAR-100. +- `cifar_input.py`: Script to read from TFRecords and generate dataset objects with the `tf.data` API. +- `config.py`: Configuration file for network architectures and training hyperparameters. +- `main.py`: Main training and evaluation script. +- `main_estimator.py`: Script to train RevNet models on CIFAR-10 and CIFAR-100 with the `tf.estimator` API. +- `main_estimator_tpu.py`: Script to train RevNet models on ImageNet with TPU estimators on Cloud TPUs. +- `resnet_preprocessing.py`, `imagenet_input.py`: Boilerplate to read ImageNet data from TFRecords. + +## Train on CIFAR-10/CIFAR-100 +- Make sure you have installed TensorFlow 1.10+ or the latest `tf-nightly` +or `tf-nightly-gpu` pip package in order to access the eager execution feature. + +- First run + +```bash +python cifar_tfrecords.py --data_dir ${PWD}/cifar +``` +to download the cifar dataset and convert them +to TFRecords. This produces TFRecord files for both CIFAR-10 and CIFAR-100. + +- To train a model, run + +```bash +python main.py --data_dir ${PWD}/cifar +``` + +- Optional arguments for `main.py` include + - `train_dir`: Directory to store eventfiles and checkpoints. + - `restore`: Restore the latest checkpoint. + - `validate`: Use validation set for training monitoring. + - `dataset`: Use either `cifar-10` or `cifar-100`. + - `config`: RevNet configuration. + - `use_defun`: Use `tfe.defun` to boost performance. + +- To train a model with estimators in graph execution, run + +```bash +python main_estimator.py --data_dir ${PWD}/cifar +``` +To ensure our code works properly when using the Keras model in an estimator, +`tf-nightly` or `tf-nightly-gpu` is highly recommended as of August 2018. + +- Optional arguments for `main.py` include + - `model_dir`: Directory to store eventfiles and checkpoints. + - `dataset`: Use either `cifar-10` or `cifar-100`. + - `config`: RevNet configuration. + - `export`: Export the model for serving if True. + +## Speed up with `tfe.defun` +To ensure that `tf.contrib.eager.defun` in our code works properly with all +part of the model during training, the latest `tf-nightly` or `tf-nightly-gpu` +is highly recommended as of August 2018. + +Even though the speed difference between pure eager execution and graph execution is noticeable, +the difference between fully "defunned" model training and graph +training is negligible. + +## Train on ImageNet with Cloud TPUs +The standard way to train models on Cloud TPUs is via TPU estimators and graph +execution. Models built with the `tf.keras` API are fully compatible with TPU estimators. +To ensure our code works properly in this setting, +`tf-nightly` or `tf-nightly-gpu` is highly recommended as of August 2018. + +### Setup a Google Cloud project + +Follow the instructions at the [Quickstart Guide](https://cloud.google.com/tpu/docs/quickstart) +to get a GCE VM with access to Cloud TPU. + +To run this model, you will need: + +* A GCE VM instance with an associated Cloud TPU resource +* A GCS bucket to store your training checkpoints +* (Optional): The ImageNet training and validation data preprocessed into + TFRecord format, and stored in GCS. + +### Format the data + +The data is expected to be formatted in TFRecord format, as generated by [this +script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py). + +If you do not have ImageNet dataset prepared, you can use a randomly generated +fake dataset to test the model. It is located at +`gs://cloud-tpu-test-datasets/fake_imagenet`. + +### Start training + +Train the model by executing the following command (substituting the appropriate +values): + +```bash +python main_estimator_tpu.py \ + --tpu=$TPU_NAME \ + --data_dir=$DATA_DIR \ + --model_dir=$MODEL_DIR +``` + +## Performance +- RevNet-38 achieves >92% and >71% accuracy on CIFAR-10 and CIFAR-100 respectively. +- RevNet-56 achieves <26% top-1 error rate on ImageNet. + +## Reference +The Reversible Residual Network: Backpropagation Without Storing Activations. +Aidan N. Gomez, Mengye Ren, Raquel Urtasun, Roger B. Grosse. Neural Information Processing Systems (NIPS), 2017. diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..f61354bc38a9fcb941f186cac4eac8097eea742d --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py @@ -0,0 +1,504 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reversible residual network compatible with eager execution. + +Building blocks with manual backward gradient computation. + +Reference [The Reversible Residual Network: Backpropagation +Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import operator + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import ops + + +class RevBlock(tf.keras.Model): + """Single reversible block containing several `_Residual` blocks. + + Each `_Residual` block in turn contains two _ResidualInner blocks, + corresponding to the `F`/`G` functions in the paper. + """ + + def __init__(self, + n_res, + filters, + strides, + input_shape, + batch_norm_first=False, + data_format="channels_first", + bottleneck=False, + fused=True, + dtype=tf.float32): + """Initialization. + + Args: + n_res: number of residual blocks + filters: list/tuple of integers for output filter sizes of each residual + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC" + bottleneck: use bottleneck residual if True + fused: use fused batch normalization if True + dtype: float16, float32, or float64 + """ + super(RevBlock, self).__init__() + self.blocks = tf.contrib.checkpoint.List() + for i in range(n_res): + curr_batch_norm_first = batch_norm_first and i == 0 + curr_strides = strides if i == 0 else (1, 1) + block = _Residual( + filters, + curr_strides, + input_shape, + batch_norm_first=curr_batch_norm_first, + data_format=data_format, + bottleneck=bottleneck, + fused=fused, + dtype=dtype) + self.blocks.append(block) + + if data_format == "channels_first": + input_shape = (filters, input_shape[1] // curr_strides[0], + input_shape[2] // curr_strides[1]) + else: + input_shape = (input_shape[0] // curr_strides[0], + input_shape[1] // curr_strides[1], filters) + + def call(self, h, training=True): + """Apply reversible block to inputs.""" + + for block in self.blocks: + h = block(h, training=training) + return h + + def backward_grads(self, x, y, dy, training=True): + """Apply reversible block backward to outputs.""" + + grads_all = [] + for i in reversed(range(len(self.blocks))): + block = self.blocks[i] + if i == 0: + # First block usually contains downsampling that can't be reversed + dy, grads = block.backward_grads_with_downsample( + x, y, dy, training=True) + else: + y, dy, grads = block.backward_grads(y, dy, training=training) + grads_all = grads + grads_all + + return dy, grads_all + + +class _Residual(tf.keras.Model): + """Single residual block contained in a _RevBlock. Each `_Residual` object has + two _ResidualInner objects, corresponding to the `F` and `G` functions in the + paper. + """ + + def __init__(self, + filters, + strides, + input_shape, + batch_norm_first=True, + data_format="channels_first", + bottleneck=False, + fused=True, + dtype=tf.float32): + """Initialization. + + Args: + filters: output filter size + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC", + bottleneck: use bottleneck residual if True + fused: use fused batch normalization if True + dtype: float16, float32, or float64 + """ + super(_Residual, self).__init__() + + self.filters = filters + self.strides = strides + self.axis = 1 if data_format == "channels_first" else 3 + if data_format == "channels_first": + f_input_shape = (input_shape[0] // 2,) + input_shape[1:] + g_input_shape = (filters // 2, input_shape[1] // strides[0], + input_shape[2] // strides[1]) + else: + f_input_shape = input_shape[:2] + (input_shape[2] // 2,) + g_input_shape = (input_shape[0] // strides[0], + input_shape[1] // strides[1], filters // 2) + + factory = _BottleneckResidualInner if bottleneck else _ResidualInner + self.f = factory( + filters=filters // 2, + strides=strides, + input_shape=f_input_shape, + batch_norm_first=batch_norm_first, + data_format=data_format, + fused=fused, + dtype=dtype) + self.g = factory( + filters=filters // 2, + strides=(1, 1), + input_shape=g_input_shape, + batch_norm_first=batch_norm_first, + data_format=data_format, + fused=fused, + dtype=dtype) + + def call(self, x, training=True): + """Apply residual block to inputs.""" + x1, x2 = x + f_x2 = self.f(x2, training=training) + x1_down = ops.downsample( + x1, self.filters // 2, self.strides, axis=self.axis) + x2_down = ops.downsample( + x2, self.filters // 2, self.strides, axis=self.axis) + y1 = f_x2 + x1_down + g_y1 = self.g(y1, training=training) + y2 = g_y1 + x2_down + + return y1, y2 + + def backward_grads(self, y, dy, training=True): + """Manually compute backward gradients given input and output grads.""" + dy1, dy2 = dy + y1, y2 = y + + with tf.GradientTape() as gtape: + gtape.watch(y1) + gy1 = self.g(y1, training=training) + grads_combined = gtape.gradient( + gy1, [y1] + self.g.trainable_variables, output_gradients=dy2) + dg = grads_combined[1:] + dx1 = dy1 + grads_combined[0] + # This doesn't affect eager execution, but improves memory efficiency with + # graphs + with tf.control_dependencies(dg + [dx1]): + x2 = y2 - gy1 + + with tf.GradientTape() as ftape: + ftape.watch(x2) + fx2 = self.f(x2, training=training) + grads_combined = ftape.gradient( + fx2, [x2] + self.f.trainable_variables, output_gradients=dx1) + df = grads_combined[1:] + dx2 = dy2 + grads_combined[0] + # Same behavior as above + with tf.control_dependencies(df + [dx2]): + x1 = y1 - fx2 + + x = x1, x2 + dx = dx1, dx2 + grads = df + dg + + return x, dx, grads + + def backward_grads_with_downsample(self, x, y, dy, training=True): + """Manually compute backward gradients given input and output grads.""" + # Splitting this from `backward_grads` for better readability + x1, x2 = x + y1, _ = y + dy1, dy2 = dy + + with tf.GradientTape() as gtape: + gtape.watch(y1) + gy1 = self.g(y1, training=training) + grads_combined = gtape.gradient( + gy1, [y1] + self.g.trainable_variables, output_gradients=dy2) + dg = grads_combined[1:] + dz1 = dy1 + grads_combined[0] + + # dx1 need one more step to backprop through downsample + with tf.GradientTape() as x1tape: + x1tape.watch(x1) + z1 = ops.downsample(x1, self.filters // 2, self.strides, axis=self.axis) + dx1 = x1tape.gradient(z1, x1, output_gradients=dz1) + + with tf.GradientTape() as ftape: + ftape.watch(x2) + fx2 = self.f(x2, training=training) + grads_combined = ftape.gradient( + fx2, [x2] + self.f.trainable_variables, output_gradients=dz1) + dx2, df = grads_combined[0], grads_combined[1:] + + # dx2 need one more step to backprop through downsample + with tf.GradientTape() as x2tape: + x2tape.watch(x2) + z2 = ops.downsample(x2, self.filters // 2, self.strides, axis=self.axis) + dx2 += x2tape.gradient(z2, x2, output_gradients=dy2) + + dx = dx1, dx2 + grads = df + dg + + return dx, grads + + +# Ideally, the following should be wrapped in `tf.keras.Sequential`, however +# there are subtle issues with its placeholder insertion policy and batch norm +class _BottleneckResidualInner(tf.keras.Model): + """Single bottleneck residual inner function contained in _Resdual. + + Corresponds to the `F`/`G` functions in the paper. + Suitable for training on ImageNet dataset. + """ + + def __init__(self, + filters, + strides, + input_shape, + batch_norm_first=True, + data_format="channels_first", + fused=True, + dtype=tf.float32): + """Initialization. + + Args: + filters: output filter size + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC" + fused: use fused batch normalization if True + dtype: float16, float32, or float64 + """ + super(_BottleneckResidualInner, self).__init__() + axis = 1 if data_format == "channels_first" else 3 + if batch_norm_first: + self.batch_norm_0 = tf.keras.layers.BatchNormalization( + axis=axis, input_shape=input_shape, fused=fused, dtype=dtype) + self.conv2d_1 = tf.keras.layers.Conv2D( + filters=filters // 4, + kernel_size=1, + strides=strides, + input_shape=input_shape, + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype) + + self.batch_norm_1 = tf.keras.layers.BatchNormalization( + axis=axis, fused=fused, dtype=dtype) + self.conv2d_2 = tf.keras.layers.Conv2D( + filters=filters // 4, + kernel_size=3, + strides=(1, 1), + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype) + + self.batch_norm_2 = tf.keras.layers.BatchNormalization( + axis=axis, fused=fused, dtype=dtype) + self.conv2d_3 = tf.keras.layers.Conv2D( + filters=filters, + kernel_size=1, + strides=(1, 1), + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype) + + self.batch_norm_first = batch_norm_first + + def call(self, x, training=True): + net = x + if self.batch_norm_first: + net = self.batch_norm_0(net, training=training) + net = tf.nn.relu(net) + net = self.conv2d_1(net) + + net = self.batch_norm_1(net, training=training) + net = tf.nn.relu(net) + net = self.conv2d_2(net) + + net = self.batch_norm_2(net, training=training) + net = tf.nn.relu(net) + net = self.conv2d_3(net) + + return net + + +class _ResidualInner(tf.keras.Model): + """Single residual inner function contained in _ResdualBlock. + + Corresponds to the `F`/`G` functions in the paper. + """ + + def __init__(self, + filters, + strides, + input_shape, + batch_norm_first=True, + data_format="channels_first", + fused=True, + dtype=tf.float32): + """Initialization. + + Args: + filters: output filter size + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC" + fused: use fused batch normalization if True + dtype: float16, float32, or float64 + """ + super(_ResidualInner, self).__init__() + axis = 1 if data_format == "channels_first" else 3 + if batch_norm_first: + self.batch_norm_0 = tf.keras.layers.BatchNormalization( + axis=axis, input_shape=input_shape, fused=fused, dtype=dtype) + self.conv2d_1 = tf.keras.layers.Conv2D( + filters=filters, + kernel_size=3, + strides=strides, + input_shape=input_shape, + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype) + + self.batch_norm_1 = tf.keras.layers.BatchNormalization( + axis=axis, fused=fused, dtype=dtype) + self.conv2d_2 = tf.keras.layers.Conv2D( + filters=filters, + kernel_size=3, + strides=(1, 1), + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype) + + self.batch_norm_first = batch_norm_first + + def call(self, x, training=True): + net = x + if self.batch_norm_first: + net = self.batch_norm_0(net, training=training) + net = tf.nn.relu(net) + net = self.conv2d_1(net) + + net = self.batch_norm_1(net, training=training) + net = tf.nn.relu(net) + net = self.conv2d_2(net) + + return net + + +class InitBlock(tf.keras.Model): + """Initial block of RevNet.""" + + def __init__(self, config): + """Initialization. + + Args: + config: tf.contrib.training.HParams object; specifies hyperparameters + """ + super(InitBlock, self).__init__() + self.config = config + self.axis = 1 if self.config.data_format == "channels_first" else 3 + self.conv2d = tf.keras.layers.Conv2D( + filters=self.config.init_filters, + kernel_size=self.config.init_kernel, + strides=(self.config.init_stride, self.config.init_stride), + data_format=self.config.data_format, + use_bias=False, + padding="SAME", + input_shape=self.config.input_shape, + dtype=self.config.dtype) + self.batch_norm = tf.keras.layers.BatchNormalization( + axis=self.axis, fused=self.config.fused, dtype=self.config.dtype) + self.activation = tf.keras.layers.Activation("relu") + + if self.config.init_max_pool: + self.max_pool = tf.keras.layers.MaxPooling2D( + pool_size=(3, 3), + strides=(2, 2), + padding="SAME", + data_format=self.config.data_format, + dtype=self.config.dtype) + + def call(self, x, training=True): + net = x + net = self.conv2d(net) + net = self.batch_norm(net, training=training) + net = self.activation(net) + + if self.config.init_max_pool: + net = self.max_pool(net) + + return tf.split(net, num_or_size_splits=2, axis=self.axis) + + +class FinalBlock(tf.keras.Model): + """Final block of RevNet.""" + + def __init__(self, config): + """Initialization. + + Args: + config: tf.contrib.training.HParams object; specifies hyperparameters + + Raises: + ValueError: Unsupported data format + """ + super(FinalBlock, self).__init__() + self.config = config + self.axis = 1 if self.config.data_format == "channels_first" else 3 + + f = self.config.filters[-1] # Number of filters + r = functools.reduce(operator.mul, self.config.strides, 1) # Reduce ratio + r *= self.config.init_stride + if self.config.init_max_pool: + r *= 2 + + if self.config.data_format == "channels_first": + w, h = self.config.input_shape[1], self.config.input_shape[2] + input_shape = (f, w // r, h // r) + elif self.config.data_format == "channels_last": + w, h = self.config.input_shape[0], self.config.input_shape[1] + input_shape = (w // r, h // r, f) + else: + raise ValueError("Data format should be either `channels_first`" + " or `channels_last`") + self.batch_norm = tf.keras.layers.BatchNormalization( + axis=self.axis, + input_shape=input_shape, + fused=self.config.fused, + dtype=self.config.dtype) + self.activation = tf.keras.layers.Activation("relu") + self.global_avg_pool = tf.keras.layers.GlobalAveragePooling2D( + data_format=self.config.data_format, dtype=self.config.dtype) + self.dense = tf.keras.layers.Dense( + self.config.n_classes, dtype=self.config.dtype) + + def call(self, x, training=True): + net = tf.concat(x, axis=self.axis) + net = self.batch_norm(net, training=training) + net = self.activation(net) + net = self.global_avg_pool(net) + net = self.dense(net) + + return net diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff6b605b912772a92ab9e07a0ba5b9325030e43 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py @@ -0,0 +1,288 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for basic building blocks used in eager mode RevNet.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import blocks + + +def compute_degree(g1, g2, eps=1e-7): + """Compute the degree between two vectors using their usual inner product.""" + + def _dot(u, v): + return tf.reduce_sum(u * v) + + g1_norm = tf.sqrt(_dot(g1, g1)) + g2_norm = tf.sqrt(_dot(g2, g2)) + if g1_norm.numpy() == 0 and g2_norm.numpy() == 0: + cosine = 1. - eps + else: + g1_norm = 1. if g1_norm.numpy() == 0 else g1_norm + g2_norm = 1. if g2_norm.numpy() == 0 else g2_norm + cosine = _dot(g1, g2) / g1_norm / g2_norm + # Restrict to arccos range + cosine = tf.minimum(tf.maximum(cosine, eps - 1.), 1. - eps) + degree = tf.acos(cosine) * 180. / 3.141592653589793 + + return degree + + +def _validate_block_call_channels_last(block_factory, test): + """Generic testing function for `channels_last` data format. + + Completes a set of tests varying data format, stride, and batch normalization + configured train vs test time. + Args: + block_factory: constructor of one of blocks.InitBlock, blocks.FinalBlock, + blocks._ResidualInner + test: tf.test.TestCase object + """ + with tf.device("/cpu:0"): # NHWC format + input_shape = (8, 8, 128) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + + # Stride 1 + block = block_factory( + filters=128, + strides=(1, 1), + input_shape=input_shape, + data_format="channels_last") + y_tr, y_ev = block(x, training=True), block(x, training=False) + test.assertEqual(y_tr.shape, y_ev.shape) + test.assertEqual(y_ev.shape, (16, 8, 8, 128)) + test.assertNotAllClose(y_tr, y_ev) + + # Stride of 2 + block = block_factory( + filters=128, + strides=(2, 2), + input_shape=input_shape, + data_format="channels_last") + y_tr, y_ev = block(x, training=True), block(x, training=False) + test.assertEqual(y_tr.shape, y_ev.shape) + test.assertEqual(y_ev.shape, (16, 4, 4, 128)) + test.assertNotAllClose(y_tr, y_ev) + + +def _validate_block_call_channels_first(block_factory, test): + """Generic testing function for `channels_first` data format. + + Completes a set of tests varying data format, stride, and batch normalization + configured train vs test time. + Args: + block_factory: constructor of one of blocks.InitBlock, blocks.FinalBlock, + blocks._ResidualInner + test: tf.test.TestCase object + """ + if not tf.test.is_gpu_available(): + test.skipTest("GPU not available") + + with tf.device("/gpu:0"): # Default NCHW format + input_shape = (128, 8, 8) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + + # Stride of 1 + block = block_factory(filters=128, strides=(1, 1), input_shape=input_shape) + y_tr, y_ev = block(x, training=True), block(x, training=False) + test.assertEqual(y_tr.shape, y_ev.shape) + test.assertEqual(y_ev.shape, (16, 128, 8, 8)) + test.assertNotAllClose(y_tr, y_ev) + + # Stride of 2 + block = block_factory(filters=128, strides=(2, 2), input_shape=input_shape) + y_tr, y_ev = block(x, training=True), block(x, training=False) + test.assertEqual(y_tr.shape, y_ev.shape) + test.assertEqual(y_ev.shape, (16, 128, 4, 4)) + test.assertNotAllClose(y_tr, y_ev) + + +class RevBlockTest(tf.test.TestCase): + + def _check_grad_angle(self, grads, grads_true, atol=1e0): + """Check the angle between two list of vectors are all close.""" + for g1, g2 in zip(grads, grads_true): + degree = compute_degree(g1, g2) + self.assertLessEqual(degree, atol) + + def test_backward_grads_channels_first(self): + """Test `backward` function with `channels_first` data format.""" + if not tf.test.is_gpu_available(): + self.skipTest("GPU not available") + + with tf.device("/gpu:0"): # Default NCHW format + # Stride 1 + input_shape = (128, 8, 8) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1) + block = blocks.RevBlock( + n_res=3, + filters=128, + strides=(1, 1), + input_shape=input_shape, + fused=False, + dtype=tf.float64) + with tf.GradientTape() as tape: + tape.watch(x) + x1, x2 = tf.split(x, num_or_size_splits=2, axis=1) + y1, y2 = block((x1, x2), training=True) + y = tf.concat((y1, y2), axis=1) + # Compute grads from reconstruction + (dx1, dx2), dw = block.backward_grads( + x=(x1, x2), y=(y1, y2), dy=(dy1, dy2), training=True) + dx = tf.concat((dx1, dx2), axis=1) + vars_ = block.trainable_variables + # Compute true grads + grads = tape.gradient(y, [x] + vars_, output_gradients=dy) + dx_true, dw_true = grads[0], grads[1:] + self.assertAllClose(dx_true, dx) + self.assertAllClose(dw_true, dw) + self._check_grad_angle(dx_true, dx) + self._check_grad_angle(dw_true, dw) + + # Stride 2 + x = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy = tf.random_normal(shape=(16, 128, 4, 4), dtype=tf.float64) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1) + block = blocks.RevBlock( + n_res=3, + filters=128, + strides=(2, 2), + input_shape=input_shape, + fused=False, + dtype=tf.float64) + with tf.GradientTape() as tape: + tape.watch(x) + x1, x2 = tf.split(x, num_or_size_splits=2, axis=1) + y1, y2 = block((x1, x2), training=True) + y = tf.concat((y1, y2), axis=1) + # Compute grads from reconstruction + (dx1, dx2), dw = block.backward_grads( + x=(x1, x2), y=(y1, y2), dy=(dy1, dy2), training=True) + dx = tf.concat((dx1, dx2), axis=1) + vars_ = block.trainable_variables + # Compute true grads + grads = tape.gradient(y, [x] + vars_, output_gradients=dy) + dx_true, dw_true = grads[0], grads[1:] + self.assertAllClose(dx_true, dx) + self.assertAllClose(dw_true, dw) + self._check_grad_angle(dx_true, dx) + self._check_grad_angle(dw_true, dw) + + def test_backward_grads_with_nativepy(self): + if not tf.test.is_gpu_available(): + self.skipTest("GPU not available") + + input_shape = (128, 8, 8) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1) + block = blocks.RevBlock( + n_res=3, + filters=128, + strides=(1, 1), + input_shape=input_shape, + fused=False, + dtype=tf.float64) + with tf.GradientTape() as tape: + tape.watch(x) + x1, x2 = tf.split(x, num_or_size_splits=2, axis=1) + y1, y2 = block((x1, x2), training=True) + y = tf.concat((y1, y2), axis=1) + + # Compute true grads + dx_true = tape.gradient(y, x, output_gradients=dy) + + # Compute grads from reconstruction + (dx1, dx2), _ = block.backward_grads( + x=(x1, x2), y=(y1, y2), dy=(dy1, dy2), training=True) + dx = tf.concat((dx1, dx2), axis=1) + + thres = 1e-5 + diff_abs = tf.reshape(abs(dx - dx_true), [-1]) + assert all(diff_abs < thres) + + +class _ResidualTest(tf.test.TestCase): + + def test_backward_grads_channels_first(self): + """Test `backward_grads` function with `channels_first` data format.""" + if not tf.test.is_gpu_available(): + self.skipTest("GPU not available") + + with tf.device("/gpu:0"): # Default NCHW format + input_shape = (128, 8, 8) + data_shape = (16,) + input_shape + # Use double precision for testing + x_true = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1) + residual = blocks._Residual( + filters=128, + strides=(1, 1), + input_shape=input_shape, + fused=False, + dtype=tf.float64) + + with tf.GradientTape() as tape: + tape.watch(x_true) + x1_true, x2_true = tf.split(x_true, num_or_size_splits=2, axis=1) + y1, y2 = residual((x1_true, x2_true), training=True) + y = tf.concat((y1, y2), axis=1) + + # Gradients computed due to reversibility + (x1, x2), (dx1, dx2), dw = residual.backward_grads( + y=(y1, y2), dy=(dy1, dy2), training=True) + x = tf.concat((x1, x2), axis=1) + dx = tf.concat((dx1, dx2), axis=1) + # True gradients computed by the tape + grads = tape.gradient( + y, [x_true] + residual.trainable_variables, output_gradients=dy) + dx_true, dw_true = grads[0], grads[1:] + + self.assertAllClose(x_true, x) + self.assertAllClose(dx_true, dx) + self.assertAllClose(dw_true, dw) + + +class _ResidualInnerTest(tf.test.TestCase): + + def test_call(self): + """Test `call` function.""" + + _validate_block_call_channels_first(blocks._ResidualInner, self) + _validate_block_call_channels_last(blocks._ResidualInner, self) + + +class _BottleneckResidualInner(tf.test.TestCase): + + def test_call(self): + """Test `call` function.""" + + _validate_block_call_channels_first(blocks._BottleneckResidualInner, self) + _validate_block_call_channels_last(blocks._BottleneckResidualInner, self) + + +if __name__ == "__main__": + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py new file mode 100644 index 0000000000000000000000000000000000000000..e9672f13e1587c96cea0fc7dd58b66ef256296cd --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py @@ -0,0 +1,116 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Script for reading and loading CIFAR-10.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import tensorflow as tf + +# Global constants describing the CIFAR data set. +IMAGE_HEIGHT = 32 +IMAGE_WIDTH = 32 +NUM_CHANNEL = 3 + + +def get_ds_from_tfrecords(data_dir, + split, + data_aug=True, + batch_size=100, + epochs=None, + shuffle=True, + data_format="channels_first", + num_parallel_calls=12, + prefetch=0, + div255=True, + dtype=tf.float32): + """Returns a tf.train.Dataset object from reading tfrecords. + + Args: + data_dir: Directory of tfrecords + split: "train", "validation", or "test" + data_aug: Apply data augmentation if True + batch_size: Batch size of dataset object + epochs: Number of epochs to repeat the dataset; default `None` means + repeating indefinitely + shuffle: Shuffle the dataset if True + data_format: `channels_first` or `channels_last` + num_parallel_calls: Number of threads for dataset preprocess + prefetch: Buffer size for prefetch + div255: Divide the images by 255 if True + dtype: Data type of images + Returns: + A tf.train.Dataset object + + Raises: + ValueError: Unknown split + """ + + if split not in ["train", "validation", "test", "train_all"]: + raise ValueError("Unknown split {}".format(split)) + + def _parser(serialized_example): + """Parses a single tf.Example into image and label tensors.""" + features = tf.parse_single_example( + serialized_example, + features={ + "image": tf.FixedLenFeature([], tf.string), + "label": tf.FixedLenFeature([], tf.int64), + }) + image = tf.decode_raw(features["image"], tf.uint8) + # Initially reshaping to [H, W, C] does not work + image = tf.reshape(image, [NUM_CHANNEL, IMAGE_HEIGHT, IMAGE_WIDTH]) + # This is needed for `tf.image.resize_image_with_crop_or_pad` + image = tf.transpose(image, [1, 2, 0]) + + image = tf.cast(image, dtype) + label = tf.cast(features["label"], tf.int32) + + if data_aug: + image = tf.image.resize_image_with_crop_or_pad(image, IMAGE_HEIGHT + 4, + IMAGE_WIDTH + 4) + image = tf.random_crop(image, [IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNEL]) + image = tf.image.random_flip_left_right(image) + + if data_format == "channels_first": + image = tf.transpose(image, [2, 0, 1]) + + if div255: + image /= 255. + + return image, label + + filename = os.path.join(data_dir, split + ".tfrecords") + dataset = tf.data.TFRecordDataset(filename) + dataset = dataset.repeat(epochs) + dataset = dataset.map(_parser, num_parallel_calls=num_parallel_calls) + dataset = dataset.prefetch(prefetch) + + if shuffle: + # Find the right size according to the split + size = { + "train": 40000, + "validation": 10000, + "test": 10000, + "train_all": 50000 + }[split] + dataset = dataset.shuffle(size) + + dataset = dataset.batch(batch_size, drop_remainder=True) + + return dataset diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py new file mode 100644 index 0000000000000000000000000000000000000000..377844ad8fbca92629a4d71f5df2aab67b570c3c --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py @@ -0,0 +1,154 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Read CIFAR data from pickled numpy arrays and writes TFRecords. + +Generates tf.train.Example protos and writes them to TFRecord files from the +python version of the CIFAR dataset downloaded from +https://www.cs.toronto.edu/~kriz/cifar.html. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import tarfile + +from absl import flags +from six.moves import cPickle as pickle +from six.moves import urllib +import tensorflow as tf + +BASE_URL = 'https://www.cs.toronto.edu/~kriz/' +CIFAR_FILE_NAMES = ['cifar-10-python.tar.gz', 'cifar-100-python.tar.gz'] +CIFAR_DOWNLOAD_URLS = [BASE_URL + name for name in CIFAR_FILE_NAMES] +CIFAR_LOCAL_FOLDERS = ['cifar-10', 'cifar-100'] +EXTRACT_FOLDERS = ['cifar-10-batches-py', 'cifar-100-python'] + + +def download_and_extract(data_dir, file_name, url): + """Download CIFAR if not already downloaded.""" + filepath = os.path.join(data_dir, file_name) + if tf.gfile.Exists(filepath): + return filepath + if not tf.gfile.Exists(data_dir): + tf.gfile.MakeDirs(data_dir) + + urllib.request.urlretrieve(url, filepath) + tarfile.open(os.path.join(filepath), 'r:gz').extractall(data_dir) + return filepath + + +def _int64_feature(value): + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + + +def _bytes_feature(value): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def _get_file_names(folder): + """Returns the file names expected to exist in the input_dir.""" + assert folder in ['cifar-10', 'cifar-100'] + + file_names = {} + if folder == 'cifar-10': + file_names['train'] = ['data_batch_%d' % i for i in range(1, 5)] + file_names['validation'] = ['data_batch_5'] + file_names['train_all'] = ['data_batch_%d' % i for i in range(1, 6)] + file_names['test'] = ['test_batch'] + else: + file_names['train_all'] = ['train'] + file_names['test'] = ['test'] + # Split in `convert_to_tfrecord` function + file_names['train'] = ['train'] + file_names['validation'] = ['train'] + return file_names + + +def read_pickle_from_file(filename): + with tf.gfile.Open(filename, 'rb') as f: + if sys.version_info >= (3, 0): + data_dict = pickle.load(f, encoding='bytes') + else: + data_dict = pickle.load(f) + return data_dict + + +def convert_to_tfrecord(input_files, output_file, folder): + """Converts files with pickled data to TFRecords.""" + assert folder in ['cifar-10', 'cifar-100'] + + print('Generating %s' % output_file) + with tf.python_io.TFRecordWriter(output_file) as record_writer: + for input_file in input_files: + data_dict = read_pickle_from_file(input_file) + data = data_dict[b'data'] + try: + labels = data_dict[b'labels'] + except KeyError: + labels = data_dict[b'fine_labels'] + + if folder == 'cifar-100' and input_file.endswith('train.tfrecords'): + data = data[:40000] + labels = labels[:40000] + elif folder == 'cifar-100' and input_file.endswith( + 'validation.tfrecords'): + data = data[40000:] + labels = labels[40000:] + + num_entries_in_batch = len(labels) + + for i in range(num_entries_in_batch): + example = tf.train.Example( + features=tf.train.Features( + feature={ + 'image': _bytes_feature(data[i].tobytes()), + 'label': _int64_feature(labels[i]) + })) + record_writer.write(example.SerializeToString()) + + +def main(_): + for file_name, url, folder, extract_folder in zip( + CIFAR_FILE_NAMES, CIFAR_DOWNLOAD_URLS, CIFAR_LOCAL_FOLDERS, + EXTRACT_FOLDERS): + print('Download from {} and extract.'.format(url)) + data_dir = os.path.join(FLAGS.data_dir, folder) + download_and_extract(data_dir, file_name, url) + file_names = _get_file_names(folder) + input_dir = os.path.join(data_dir, extract_folder) + + for mode, files in file_names.items(): + input_files = [os.path.join(input_dir, f) for f in files] + output_file = os.path.join(data_dir, mode + '.tfrecords') + try: + os.remove(output_file) + except OSError: + pass + convert_to_tfrecord(input_files, output_file, folder) + + print('Done!') + + +if __name__ == '__main__': + FLAGS = flags.FLAGS + flags.DEFINE_string( + 'data_dir', + default=None, + help='Directory to download, extract and store TFRecords.') + + tf.app.run(main) diff --git a/tensorflow/contrib/eager/python/examples/revnet/config.py b/tensorflow/contrib/eager/python/examples/revnet/config.py new file mode 100644 index 0000000000000000000000000000000000000000..29f1db0e0367515757413c8e47f7b7280fc4cfbb --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/config.py @@ -0,0 +1,175 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reversible residual network compatible with eager execution. + +Configuration in format of tf.contrib.training.HParams. +Supports CIFAR-10, CIFAR-100, and ImageNet datasets. + +Reference [The Reversible Residual Network: Backpropagation +Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf) + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +def get_hparams_cifar_38(): + """RevNet-38 configurations for CIFAR-10/CIFAR-100.""" + + config = tf.contrib.training.HParams() + config.add_hparam("num_train_images", 50000) + config.add_hparam("num_eval_images", 10000) + config.add_hparam("init_filters", 32) + config.add_hparam("init_kernel", 3) + config.add_hparam("init_stride", 1) + config.add_hparam("n_rev_blocks", 3) + config.add_hparam("n_res", [3, 3, 3]) + config.add_hparam("filters", [32, 64, 112]) + config.add_hparam("strides", [1, 2, 2]) + config.add_hparam("batch_size", 100) + config.add_hparam("bottleneck", False) + config.add_hparam("fused", True) + config.add_hparam("init_max_pool", False) + if tf.test.is_gpu_available(): + config.add_hparam("input_shape", (3, 32, 32)) + config.add_hparam("data_format", "channels_first") + else: + config.add_hparam("input_shape", (32, 32, 3)) + config.add_hparam("data_format", "channels_last") + + # Training details + config.add_hparam("weight_decay", 2e-4) + config.add_hparam("momentum", .9) + config.add_hparam("lr_decay_steps", [40000, 60000]) + config.add_hparam("lr_list", [1e-1, 1e-2, 1e-3]) + config.add_hparam("max_train_iter", 80000) + config.add_hparam("seed", 1234) + config.add_hparam("shuffle", True) + config.add_hparam("log_every", 500) + config.add_hparam("save_every", 500) + config.add_hparam("dtype", tf.float32) + config.add_hparam("eval_batch_size", 1000) + config.add_hparam("div255", True) + # This is imprecise, when training with validation set, + # we only have 40k images in training data + config.add_hparam("iters_per_epoch", + config.num_train_images // config.batch_size) + config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch) + + # Customized TPU hyperparameters due to differing batch size caused by + # TPU architecture specifics + # Suggested batch sizes to reduce overhead from excessive tensor padding + # https://cloud.google.com/tpu/docs/troubleshooting + config.add_hparam("tpu_batch_size", 1024) + config.add_hparam("tpu_eval_batch_size", 1024) + config.add_hparam("tpu_iters_per_epoch", + config.num_train_images // config.tpu_batch_size) + config.add_hparam("tpu_epochs", + config.max_train_iter // config.tpu_iters_per_epoch) + config.add_hparam("tpu_eval_steps", + config.num_eval_images // config.tpu_eval_batch_size) + return config + + +def get_hparams_cifar_110(): + config = get_hparams_cifar_38() + config.filters = [32, 64, 128] + config.n_res = [9, 9, 9] + + return config + + +def get_hparams_cifar_164(): + config = get_hparams_cifar_38() + config.filters = [32, 64, 128] + config.n_res = [9, 9, 9] + config.use_bottleneck = True + # Due to bottleneck residual blocks + filters = [f * 4 for f in config.filters] + config.filters = filters + + return config + + +def get_hparams_imagenet_56(): + """RevNet-56 configurations for ImageNet.""" + + config = tf.contrib.training.HParams() + config.add_hparam("n_classes", 1000) + config.add_hparam("dataset", "ImageNet") + config.add_hparam("num_train_images", 1281167) + config.add_hparam("num_eval_images", 50000) + config.add_hparam("init_filters", 128) + config.add_hparam("init_kernel", 7) + config.add_hparam("init_stride", 2) + config.add_hparam("n_rev_blocks", 4) + config.add_hparam("n_res", [2, 2, 2, 2]) + config.add_hparam("filters", [128, 256, 512, 832]) + config.add_hparam("strides", [1, 2, 2, 2]) + config.add_hparam("batch_size", 256) + config.add_hparam("bottleneck", True) + config.add_hparam("fused", True) + config.add_hparam("init_max_pool", True) + if tf.test.is_gpu_available(): + config.add_hparam("input_shape", (3, 224, 224)) + config.add_hparam("data_format", "channels_first") + else: + config.add_hparam("input_shape", (224, 224, 3)) + config.add_hparam("data_format", "channels_last") + # Due to bottleneck residual blocks + filters = [f * 4 for f in config.filters] + config.filters = filters + + # Training details + config.add_hparam("weight_decay", 1e-4) + config.add_hparam("momentum", .9) + config.add_hparam("lr_decay_steps", [160000, 320000, 480000]) + config.add_hparam("lr_list", [1e-1, 1e-2, 1e-3, 1e-4]) + config.add_hparam("max_train_iter", 600000) + config.add_hparam("seed", 1234) + config.add_hparam("shuffle", True) + config.add_hparam("log_every", 500) + config.add_hparam("save_every", 500) + config.add_hparam("dtype", tf.float32) + config.add_hparam("eval_batch_size", 256) + config.add_hparam("div255", True) + config.add_hparam("iters_per_epoch", + config.num_train_images // config.batch_size) + config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch) + + # Customized TPU hyperparameters due to differing batch size caused by + # TPU architecture specifics + # Suggested batch sizes to reduce overhead from excessive tensor padding + # https://cloud.google.com/tpu/docs/troubleshooting + config.add_hparam("tpu_batch_size", 1024) + config.add_hparam("tpu_eval_batch_size", 1024) + config.add_hparam("tpu_iters_per_epoch", + config.num_train_images // config.tpu_batch_size) + config.add_hparam("tpu_epochs", + config.max_train_iter // config.tpu_iters_per_epoch) + config.add_hparam("tpu_eval_steps", + config.num_eval_images // config.tpu_eval_batch_size) + return config + + +def get_hparams_imagenet_104(): + config = get_hparams_imagenet_56() + config.n_res = [2, 2, 11, 2] + + return config diff --git a/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py new file mode 100644 index 0000000000000000000000000000000000000000..34a9984b0ecc527ad1991c28146246b716e96c98 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py @@ -0,0 +1,229 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Efficient ImageNet input pipeline using tf.data.Dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os + +import tensorflow as tf + +from tensorflow.contrib.eager.python.examples.revnet import resnet_preprocessing + + +def image_serving_input_fn(): + """Serving input fn for raw images.""" + + def _preprocess_image(image_bytes): + """Preprocess a single raw image.""" + image = resnet_preprocessing.preprocess_image( + image_bytes=image_bytes, is_training=False) + return image + + image_bytes_list = tf.placeholder( + shape=[None], + dtype=tf.string, + ) + images = tf.map_fn( + _preprocess_image, image_bytes_list, back_prop=False, dtype=tf.float32) + return tf.estimator.export.ServingInputReceiver( + images, {'image_bytes': image_bytes_list}) + + +class ImageNetInput(object): + """Generates ImageNet input_fn for training or evaluation. + + The training data is assumed to be in TFRecord format with keys as specified + in the dataset_parser below, sharded across 1024 files, named sequentially: + train-00000-of-01024 + train-00001-of-01024 + ... + train-01023-of-01024 + + The validation data is in the same format but sharded in 128 files. + + The format of the data required is created by the script at: + https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py + + Args: + is_training: `bool` for whether the input is for training + data_dir: `str` for the directory of the training and validation data; + if 'null' (the literal string 'null', not None), then construct a null + pipeline, consisting of empty images. + use_bfloat16: If True, use bfloat16 precision; else use float32. + transpose_input: 'bool' for whether to use the double transpose trick + num_cores: `int` for the number of TPU cores + """ + + def __init__(self, is_training, + use_bfloat16, + data_dir, + num_cores=8, + num_parallel_calls=64, + image_size=224, + transpose_input=False, + cache=False): + self.image_preprocessing_fn = resnet_preprocessing.preprocess_image + self.is_training = is_training + self.use_bfloat16 = use_bfloat16 + self.data_dir = data_dir + self.num_cores = num_cores + self.num_parallel_calls = num_parallel_calls + if self.data_dir == 'null' or self.data_dir == '': + self.data_dir = None + self.transpose_input = transpose_input + self.image_size = image_size + self.cache = cache + + def set_shapes(self, batch_size, images, labels): + """Statically set the batch_size dimension.""" + if self.transpose_input: + images.set_shape(images.get_shape().merge_with( + tf.TensorShape([None, None, None, batch_size]))) + labels.set_shape(labels.get_shape().merge_with( + tf.TensorShape([batch_size]))) + else: + images.set_shape(images.get_shape().merge_with( + tf.TensorShape([batch_size, None, None, None]))) + labels.set_shape(labels.get_shape().merge_with( + tf.TensorShape([batch_size]))) + + return images, labels + + def dataset_parser(self, value): + """Parse an ImageNet record from a serialized string Tensor.""" + keys_to_features = { + 'image/encoded': tf.FixedLenFeature((), tf.string, ''), + 'image/format': tf.FixedLenFeature((), tf.string, 'jpeg'), + 'image/class/label': tf.FixedLenFeature([], tf.int64, -1), + 'image/class/text': tf.FixedLenFeature([], tf.string, ''), + 'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32), + 'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32), + 'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32), + 'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32), + 'image/object/class/label': tf.VarLenFeature(dtype=tf.int64), + } + + parsed = tf.parse_single_example(value, keys_to_features) + image_bytes = tf.reshape(parsed['image/encoded'], shape=[]) + + image = self.image_preprocessing_fn( + image_bytes=image_bytes, + is_training=self.is_training, + image_size=self.image_size, + use_bfloat16=self.use_bfloat16) + + # Subtract one so that labels are in [0, 1000). + label = tf.cast( + tf.reshape(parsed['image/class/label'], shape=[]), dtype=tf.int32) - 1 + + return image, label + + def input_fn(self, params): + """Input function which provides a single batch for train or eval. + + Args: + params: `dict` of parameters passed from the `TPUEstimator`. + `params['batch_size']` is always provided and should be used as the + effective batch size. + + Returns: + A `tf.data.Dataset` object. + """ + if self.data_dir is None: + tf.logging.info('Using fake input.') + return self.input_fn_null(params) + + # Retrieves the batch size for the current shard. The # of shards is + # computed according to the input pipeline deployment. See + # tf.contrib.tpu.RunConfig for details. + batch_size = params['batch_size'] + + # Shuffle the filenames to ensure better randomization. + file_pattern = os.path.join( + self.data_dir, 'train-*' if self.is_training else 'validation-*') + dataset = tf.data.Dataset.list_files(file_pattern, shuffle=self.is_training) + + if self.is_training and not self.cache: + dataset = dataset.repeat() + + def fetch_dataset(filename): + buffer_size = 8 * 1024 * 1024 # 8 MiB per file + dataset = tf.data.TFRecordDataset(filename, buffer_size=buffer_size) + return dataset + + # Read the data from disk in parallel + dataset = dataset.apply( + tf.contrib.data.parallel_interleave( + fetch_dataset, cycle_length=self.num_parallel_calls, sloppy=True)) + if self.cache: + dataset = dataset.cache().apply( + tf.contrib.data.shuffle_and_repeat(1024 * 16)) + else: + dataset = dataset.shuffle(1024) + + # Use the fused map-and-batch operation. + # + # For XLA, we must used fixed shapes. Because we repeat the source training + # dataset indefinitely, we can use `drop_remainder=True` to get fixed-size + # batches without dropping any training examples. + # + # When evaluating, `drop_remainder=True` prevents accidentally evaluating + # the same image twice by dropping the final batch if it is less than a full + # batch size. As long as this validation is done with consistent batch size, + # exactly the same images will be used. + dataset = dataset.apply( + tf.contrib.data.map_and_batch( + self.dataset_parser, batch_size=batch_size, + num_parallel_batches=self.num_cores, drop_remainder=True)) + + # Transpose for performance on TPU + if self.transpose_input: + dataset = dataset.map( + lambda images, labels: (tf.transpose(images, [1, 2, 3, 0]), labels), + num_parallel_calls=self.num_cores) + + # Assign static batch size dimension + dataset = dataset.map(functools.partial(self.set_shapes, batch_size)) + + # Prefetch overlaps in-feed with training + dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE) + return dataset + + def input_fn_null(self, params): + """Input function which provides null (black) images.""" + batch_size = params['batch_size'] + dataset = tf.data.Dataset.range(1).repeat().map(self._get_null_input) + dataset = dataset.prefetch(batch_size) + + dataset = dataset.batch(batch_size, drop_remainder=True) + if self.transpose_input: + dataset = dataset.map( + lambda images, labels: (tf.transpose(images, [1, 2, 3, 0]), labels), + num_parallel_calls=8) + + dataset = dataset.map(functools.partial(self.set_shapes, batch_size)) + + dataset = dataset.prefetch(32) # Prefetch overlaps in-feed with training + tf.logging.info('Input dataset: %s', str(dataset)) + return dataset + + def _get_null_input(self, _): + null_image = tf.zeros([224, 224, 3], tf.bfloat16 + if self.use_bfloat16 else tf.float32) + return (null_image, tf.constant(0, tf.int32)) diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py new file mode 100644 index 0000000000000000000000000000000000000000..b702e91f92220c2a9003a1b82411131332012a9e --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/main.py @@ -0,0 +1,265 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Eager execution workflow with RevNet train on CIFAR-10.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +from absl import flags +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import cifar_input +from tensorflow.contrib.eager.python.examples.revnet import config as config_ +from tensorflow.contrib.eager.python.examples.revnet import revnet +tfe = tf.contrib.eager + + +def apply_gradients(optimizer, grads, vars_, global_step=None): + """Functional style apply_grads for `tfe.defun`.""" + optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + + +def main(_): + """Eager execution workflow with RevNet trained on CIFAR-10.""" + tf.enable_eager_execution() + + config = get_config(config_name=FLAGS.config, dataset=FLAGS.dataset) + ds_train, ds_train_one_shot, ds_validation, ds_test = get_datasets( + data_dir=FLAGS.data_dir, config=config) + model = revnet.RevNet(config=config) + global_step = tf.train.get_or_create_global_step() # Ensure correct summary + global_step.assign(1) + learning_rate = tf.train.piecewise_constant( + global_step, config.lr_decay_steps, config.lr_list) + optimizer = tf.train.MomentumOptimizer( + learning_rate, momentum=config.momentum) + checkpointer = tf.train.Checkpoint( + optimizer=optimizer, model=model, optimizer_step=global_step) + + if FLAGS.use_defun: + model.call = tfe.defun(model.call) + model.compute_gradients = tfe.defun(model.compute_gradients) + model.get_moving_stats = tfe.defun(model.get_moving_stats) + model.restore_moving_stats = tfe.defun(model.restore_moving_stats) + global apply_gradients # pylint:disable=global-variable-undefined + apply_gradients = tfe.defun(apply_gradients) + + if FLAGS.train_dir: + summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir) + if FLAGS.restore: + latest_path = tf.train.latest_checkpoint(FLAGS.train_dir) + checkpointer.restore(latest_path) + print("Restored latest checkpoint at path:\"{}\" " + "with global_step: {}".format(latest_path, global_step.numpy())) + sys.stdout.flush() + + for x, y in ds_train: + train_one_iter(model, x, y, optimizer, global_step=global_step) + + if global_step.numpy() % config.log_every == 0: + it_test = ds_test.make_one_shot_iterator() + acc_test, loss_test = evaluate(model, it_test) + + if FLAGS.validate: + it_train = ds_train_one_shot.make_one_shot_iterator() + it_validation = ds_validation.make_one_shot_iterator() + acc_train, loss_train = evaluate(model, it_train) + acc_validation, loss_validation = evaluate(model, it_validation) + print("Iter {}, " + "training set accuracy {:.4f}, loss {:.4f}; " + "validation set accuracy {:.4f}, loss {:.4f}; " + "test accuracy {:.4f}, loss {:.4f}".format( + global_step.numpy(), acc_train, loss_train, acc_validation, + loss_validation, acc_test, loss_test)) + else: + print("Iter {}, test accuracy {:.4f}, loss {:.4f}".format( + global_step.numpy(), acc_test, loss_test)) + sys.stdout.flush() + + if FLAGS.train_dir: + with summary_writer.as_default(): + with tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar("Test accuracy", acc_test) + tf.contrib.summary.scalar("Test loss", loss_test) + if FLAGS.validate: + tf.contrib.summary.scalar("Training accuracy", acc_train) + tf.contrib.summary.scalar("Training loss", loss_train) + tf.contrib.summary.scalar("Validation accuracy", acc_validation) + tf.contrib.summary.scalar("Validation loss", loss_validation) + + if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir: + saved_path = checkpointer.save( + file_prefix=os.path.join(FLAGS.train_dir, "ckpt")) + print("Saved checkpoint at path: \"{}\" " + "with global_step: {}".format(saved_path, global_step.numpy())) + sys.stdout.flush() + + +def get_config(config_name="revnet-38", dataset="cifar-10"): + """Return configuration.""" + print("Config: {}".format(config_name)) + sys.stdout.flush() + config = { + "revnet-38": config_.get_hparams_cifar_38(), + "revnet-110": config_.get_hparams_cifar_110(), + "revnet-164": config_.get_hparams_cifar_164(), + }[config_name] + + if dataset == "cifar-10": + config.add_hparam("n_classes", 10) + config.add_hparam("dataset", "cifar-10") + else: + config.add_hparam("n_classes", 100) + config.add_hparam("dataset", "cifar-100") + + return config + + +def get_datasets(data_dir, config): + """Return dataset.""" + if data_dir is None: + raise ValueError("No supplied data directory") + if not os.path.exists(data_dir): + raise ValueError("Data directory {} does not exist".format(data_dir)) + if config.dataset not in ["cifar-10", "cifar-100"]: + raise ValueError("Unknown dataset {}".format(config.dataset)) + + print("Training on {} dataset.".format(config.dataset)) + sys.stdout.flush() + data_dir = os.path.join(data_dir, config.dataset) + if FLAGS.validate: + # 40k Training set + ds_train = cifar_input.get_ds_from_tfrecords( + data_dir=data_dir, + split="train", + data_aug=True, + batch_size=config.batch_size, + epochs=config.epochs, + shuffle=config.shuffle, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.batch_size) + # 10k Training set + ds_validation = cifar_input.get_ds_from_tfrecords( + data_dir=data_dir, + split="validation", + data_aug=False, + batch_size=config.eval_batch_size, + epochs=1, + shuffle=False, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.eval_batch_size) + else: + # 50k Training set + ds_train = cifar_input.get_ds_from_tfrecords( + data_dir=data_dir, + split="train_all", + data_aug=True, + batch_size=config.batch_size, + epochs=config.epochs, + shuffle=config.shuffle, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.batch_size) + ds_validation = None + + # Always compute loss and accuracy on whole test set + ds_train_one_shot = cifar_input.get_ds_from_tfrecords( + data_dir=data_dir, + split="train_all", + data_aug=False, + batch_size=config.eval_batch_size, + epochs=1, + shuffle=False, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.eval_batch_size) + + ds_test = cifar_input.get_ds_from_tfrecords( + data_dir=data_dir, + split="test", + data_aug=False, + batch_size=config.eval_batch_size, + epochs=1, + shuffle=False, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.eval_batch_size) + + return ds_train, ds_train_one_shot, ds_validation, ds_test + + +def train_one_iter(model, inputs, labels, optimizer, global_step=None): + """Train for one iteration.""" + logits, saved_hiddens = model(inputs, training=True) + values = model.get_moving_stats() + grads, loss = model.compute_gradients(saved_hiddens, labels) + # Restore moving averages when executing eagerly to avoid updating twice + model.restore_moving_stats(values) + apply_gradients( + optimizer, grads, model.trainable_variables, global_step=global_step) + + return logits, loss + + +def evaluate(model, iterator): + """Compute accuracy with the given dataset iterator.""" + mean_loss = tfe.metrics.Mean() + accuracy = tfe.metrics.Accuracy() + for x, y in iterator: + logits, _ = model(x, training=False) + loss = model.compute_loss(logits=logits, labels=y) + accuracy( + labels=tf.cast(y, tf.int64), + predictions=tf.argmax(logits, axis=1, output_type=tf.int64)) + mean_loss(loss) + + return accuracy.result().numpy(), mean_loss.result().numpy() + + +if __name__ == "__main__": + flags.DEFINE_string( + "data_dir", default=None, help="Directory to load tfrecords") + flags.DEFINE_string( + "train_dir", + default=None, + help="[Optional] Directory to store the training information") + flags.DEFINE_boolean( + "restore", + default=False, + help="[Optional] Restore the latest checkpoint from `train_dir` if True") + flags.DEFINE_boolean( + "validate", + default=False, + help="[Optional] Use the validation set or not for hyperparameter search") + flags.DEFINE_string( + "dataset", + default="cifar-10", + help="[Optional] The dataset used; either `cifar-10` or `cifar-100`") + flags.DEFINE_string( + "config", + default="revnet-38", + help="[Optional] Architecture of network. " + "Other options include `revnet-110` and `revnet-164`") + flags.DEFINE_boolean( + "use_defun", + default=False, + help="[Optional] Use `tfe.defun` to boost performance.") + FLAGS = flags.FLAGS + tf.app.run(main) diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..3a17eb30da3b989acb0b33f2fcb730da76546c18 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py @@ -0,0 +1,200 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Estimator workflow with RevNet train on CIFAR-10.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from absl import flags +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import cifar_input +from tensorflow.contrib.eager.python.examples.revnet import main as main_ +from tensorflow.contrib.eager.python.examples.revnet import revnet + + +def model_fn(features, labels, mode, params): + """Function specifying the model that is required by the `tf.estimator` API. + + Args: + features: Input images + labels: Labels of images + mode: One of `ModeKeys.TRAIN`, `ModeKeys.EVAL` or 'ModeKeys.PREDICT' + params: A dictionary of extra parameter that might be passed + + Returns: + An instance of `tf.estimator.EstimatorSpec` + """ + + inputs = features + if isinstance(inputs, dict): + inputs = features["image"] + + config = params["config"] + model = revnet.RevNet(config=config) + + if mode == tf.estimator.ModeKeys.TRAIN: + global_step = tf.train.get_or_create_global_step() + learning_rate = tf.train.piecewise_constant( + global_step, config.lr_decay_steps, config.lr_list) + optimizer = tf.train.MomentumOptimizer( + learning_rate, momentum=config.momentum) + logits, saved_hidden = model(inputs, training=True) + grads, loss = model.compute_gradients(saved_hidden, labels, training=True) + with tf.control_dependencies(model.get_updates_for(inputs)): + train_op = optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) + + return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) + else: + logits, _ = model(inputs, training=False) + predictions = tf.argmax(logits, axis=1) + probabilities = tf.nn.softmax(logits) + + if mode == tf.estimator.ModeKeys.EVAL: + loss = model.compute_loss(labels=labels, logits=logits) + return tf.estimator.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops={ + "accuracy": + tf.metrics.accuracy(labels=labels, predictions=predictions) + }) + + else: # mode == tf.estimator.ModeKeys.PREDICT + result = { + "classes": predictions, + "probabilities": probabilities, + } + + return tf.estimator.EstimatorSpec( + mode=mode, + predictions=predictions, + export_outputs={ + "classify": tf.estimator.export.PredictOutput(result) + }) + + +def get_input_fn(config, data_dir, split): + """Get the input function that is required by the `tf.estimator` API. + + Args: + config: Customized hyperparameters + data_dir: Directory where the data is stored + split: One of `train`, `validation`, `train_all`, and `test` + + Returns: + Input function required by the `tf.estimator` API + """ + + data_dir = os.path.join(data_dir, config.dataset) + # Fix split-dependent hyperparameters + if split == "train_all" or split == "train": + data_aug = True + batch_size = config.batch_size + epochs = config.epochs + shuffle = True + prefetch = config.batch_size + else: + data_aug = False + batch_size = config.eval_batch_size + epochs = 1 + shuffle = False + prefetch = config.eval_batch_size + + def input_fn(): + """Input function required by the `tf.estimator.Estimator` API.""" + return cifar_input.get_ds_from_tfrecords( + data_dir=data_dir, + split=split, + data_aug=data_aug, + batch_size=batch_size, + epochs=epochs, + shuffle=shuffle, + prefetch=prefetch, + data_format=config.data_format) + + return input_fn + + +def main(_): + tf.logging.set_verbosity(tf.logging.INFO) + + # RevNet specific configuration + config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset) + + # Estimator specific configuration + run_config = tf.estimator.RunConfig( + model_dir=FLAGS.model_dir, # Directory for storing checkpoints + tf_random_seed=config.seed, + save_summary_steps=config.log_every, + save_checkpoints_steps=config.log_every, + session_config=None, # Using default + keep_checkpoint_max=100, + keep_checkpoint_every_n_hours=10000, # Using default + log_step_count_steps=config.log_every, + train_distribute=None # Default not use distribution strategy + ) + + # Construct estimator + revnet_estimator = tf.estimator.Estimator( + model_fn=model_fn, + model_dir=FLAGS.model_dir, + config=run_config, + params={"config": config}) + + # Construct input functions + train_input_fn = get_input_fn( + config=config, data_dir=FLAGS.data_dir, split="train_all") + eval_input_fn = get_input_fn( + config=config, data_dir=FLAGS.data_dir, split="test") + + # Train and evaluate estimator + revnet_estimator.train(input_fn=train_input_fn) + revnet_estimator.evaluate(input_fn=eval_input_fn) + + if FLAGS.export: + input_shape = (None,) + config.input_shape + inputs = tf.placeholder(tf.float32, shape=input_shape) + input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({ + "image": inputs + }) + revnet_estimator.export_savedmodel(FLAGS.model_dir, input_fn) + + +if __name__ == "__main__": + flags.DEFINE_string( + "data_dir", default=None, help="Directory to load tfrecords") + flags.DEFINE_string( + "model_dir", + default=None, + help="[Optional] Directory to store the training information") + flags.DEFINE_string( + "dataset", + default="cifar-10", + help="[Optional] The dataset used; either `cifar-10` or `cifar-100`") + flags.DEFINE_boolean( + "export", + default=False, + help="[Optional] Export the model for serving if True") + flags.DEFINE_string( + "config", + default="revnet-38", + help="[Optional] Architecture of network. " + "Other options include `revnet-110` and `revnet-164`") + FLAGS = flags.FLAGS + tf.app.run() diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py new file mode 100644 index 0000000000000000000000000000000000000000..8520cf5b71af503be35d5415707a283fb363a476 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py @@ -0,0 +1,394 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Cloud TPU Estimator workflow with RevNet train on ImageNet.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from absl import flags +import tensorflow as tf +from tensorflow.contrib import summary +from tensorflow.contrib.eager.python.examples.revnet import config as config_ +from tensorflow.contrib.eager.python.examples.revnet import imagenet_input +from tensorflow.contrib.eager.python.examples.revnet import revnet +from tensorflow.contrib.training.python.training import evaluation +from tensorflow.python.estimator import estimator + +MEAN_RGB = [0.485, 0.456, 0.406] +STDDEV_RGB = [0.229, 0.224, 0.225] + + +def _host_call_fn(gs, loss, lr): + """Training host call. + + Creates scalar summaries for training metrics. + + This function is executed on the CPU and should not directly reference + any Tensors in the rest of the `model_fn`. To pass Tensors from the + model to the `metric_fn`, provide as part of the `host_call`. See + https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec + for more information. + + Arguments should match the list of `Tensor` objects passed as the second + element in the tuple passed to `host_call`. + + Args: + gs: `Tensor with shape `[batch]` for the global_step + loss: `Tensor` with shape `[batch]` for the training loss. + lr: `Tensor` with shape `[batch]` for the learning_rate. + + Returns: + List of summary ops to run on the CPU host. + """ + # Host call fns are executed FLAGS.iterations_per_loop times after one + # TPU loop is finished, setting max_queue value to the same as number of + # iterations will make the summary writer only flush the data to storage + # once per loop. + gs = gs[0] + with summary.create_file_writer( + FLAGS.model_dir, max_queue=FLAGS.iterations_per_loop).as_default(): + with summary.always_record_summaries(): + summary.scalar("loss", loss[0], step=gs) + summary.scalar("learning_rate", lr[0], step=gs) + return summary.all_summary_ops() + + +def _metric_fn(labels, logits): + """Evaluation metric function. Evaluates accuracy. + + This function is executed on the CPU and should not directly reference + any Tensors in the rest of the `model_fn`. To pass Tensors from the model + to the `metric_fn`, provide as part of the `eval_metrics`. See + https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec + for more information. + + Arguments should match the list of `Tensor` objects passed as the second + element in the tuple passed to `eval_metrics`. + + Args: + labels: `Tensor` with shape `[batch]`. + logits: `Tensor` with shape `[batch, num_classes]`. + + Returns: + A dict of the metrics to return from evaluation. + """ + predictions = tf.argmax(logits, axis=1) + top_1_accuracy = tf.metrics.accuracy(labels, predictions) + in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32) + top_5_accuracy = tf.metrics.mean(in_top_5) + + return { + "top_1_accuracy": top_1_accuracy, + "top_5_accuracy": top_5_accuracy, + } + + +def model_fn(features, labels, mode, params): + """Model function required by the `tf.contrib.tpu.TPUEstimator` API. + + Args: + features: Input images + labels: Labels of images + mode: One of `ModeKeys.TRAIN`, `ModeKeys.EVAL` or 'ModeKeys.PREDICT' + params: A dictionary of extra parameter that might be passed + + Returns: + An instance of `tf.contrib.tpu.TPUEstimatorSpec` + """ + revnet_config = params["revnet_config"] + model = revnet.RevNet(config=revnet_config) + + inputs = features + if isinstance(inputs, dict): + inputs = features["image"] + + if revnet_config.data_format == "channels_first": + assert not FLAGS.transpose_input # channels_first only for GPU + inputs = tf.transpose(inputs, [0, 3, 1, 2]) + + if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT: + inputs = tf.transpose(inputs, [3, 0, 1, 2]) # HWCN to NHWC + + # Normalize the image to zero mean and unit variance. + inputs -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=inputs.dtype) + inputs /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=inputs.dtype) + + if mode == tf.estimator.ModeKeys.TRAIN: + global_step = tf.train.get_or_create_global_step() + learning_rate = tf.train.piecewise_constant( + global_step, revnet_config.lr_decay_steps, revnet_config.lr_list) + optimizer = tf.train.MomentumOptimizer(learning_rate, + revnet_config.momentum) + if FLAGS.use_tpu: + optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) + + logits, saved_hidden = model(inputs, training=True) + grads, loss = model.compute_gradients(saved_hidden, labels, training=True) + with tf.control_dependencies(model.get_updates_for(inputs)): + train_op = optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) + if not FLAGS.skip_host_call: + # To log the loss, current learning rate, and epoch for Tensorboard, the + # summary op needs to be run on the host CPU via host_call. host_call + # expects [batch_size, ...] Tensors, thus reshape to introduce a batch + # dimension. These Tensors are implicitly concatenated to + # [params['batch_size']]. + gs_t = tf.reshape(global_step, [1]) + loss_t = tf.reshape(loss, [1]) + lr_t = tf.reshape(learning_rate, [1]) + host_call = (_host_call_fn, [gs_t, loss_t, lr_t]) + + return tf.contrib.tpu.TPUEstimatorSpec( + mode=mode, loss=loss, train_op=train_op, host_call=host_call) + + elif mode == tf.estimator.ModeKeys.EVAL: + logits, _ = model(inputs, training=False) + loss = model.compute_loss(labels=labels, logits=logits) + + return tf.contrib.tpu.TPUEstimatorSpec( + mode=mode, loss=loss, eval_metrics=(_metric_fn, [labels, logits])) + + else: # Predict or export + logits, _ = model(inputs, training=False) + predictions = { + "classes": tf.argmax(logits, axis=1), + "probabilities": tf.nn.softmax(logits), + } + + return tf.contrib.tpu.TPUEstimatorSpec( + mode=mode, + predictions=predictions, + export_outputs={ + "classify": tf.estimator.export.PredictOutput(predictions) + }) + + +def main(_): + tf.logging.set_verbosity(tf.logging.INFO) + + # RevNet specific configuration + revnet_config = { + "revnet-56": config_.get_hparams_imagenet_56(), + "revnet-104": config_.get_hparams_imagenet_104() + }[FLAGS.revnet_config] + + if FLAGS.use_tpu: + revnet_config.data_format = "channels_last" + + tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( + FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) + + # Estimator specific configuration + config = tf.contrib.tpu.RunConfig( + cluster=tpu_cluster_resolver, + model_dir=FLAGS.model_dir, + session_config=tf.ConfigProto( + allow_soft_placement=True, log_device_placement=True), + tpu_config=tf.contrib.tpu.TPUConfig( + iterations_per_loop=FLAGS.iterations_per_loop, + num_shards=FLAGS.num_shards, + per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig. + PER_HOST_V2), + ) + + # Input pipelines are slightly different (with regards to shuffling and + # preprocessing) between training and evaluation. + imagenet_train, imagenet_eval = [ + imagenet_input.ImageNetInput( + is_training=is_training, + data_dir=FLAGS.data_dir, + transpose_input=FLAGS.transpose_input, + use_bfloat16=False) for is_training in [True, False] + ] + + revnet_classifier = tf.contrib.tpu.TPUEstimator( + model_fn=model_fn, + use_tpu=FLAGS.use_tpu, + train_batch_size=revnet_config.tpu_batch_size, + eval_batch_size=revnet_config.tpu_eval_batch_size, + config=config, + export_to_tpu=False, + params={"revnet_config": revnet_config}) + + steps_per_epoch = revnet_config.tpu_iters_per_epoch + eval_steps = revnet_config.tpu_eval_steps + + # pylint: disable=protected-access + if FLAGS.mode == "eval": + # Run evaluation when there's a new checkpoint + for ckpt in evaluation.checkpoints_iterator( + FLAGS.model_dir, timeout=FLAGS.eval_timeout): + tf.logging.info("Starting to evaluate.") + try: + start_timestamp = time.time() # This time will include compilation time + eval_results = revnet_classifier.evaluate( + input_fn=imagenet_eval.input_fn, + steps=eval_steps, + checkpoint_path=ckpt) + elapsed_time = int(time.time() - start_timestamp) + tf.logging.info("Eval results: %s. Elapsed seconds: %d" % + (eval_results, elapsed_time)) + + # Terminate eval job when final checkpoint is reached + current_step = int(os.path.basename(ckpt).split("-")[1]) + if current_step >= revnet_config.max_train_iter: + tf.logging.info( + "Evaluation finished after training step %d" % current_step) + break + + except tf.errors.NotFoundError: + # Since the coordinator is on a different job than the TPU worker, + # sometimes the TPU worker does not finish initializing until long after + # the CPU job tells it to start evaluating. In this case, the checkpoint + # file could have been deleted already. + tf.logging.info( + "Checkpoint %s no longer exists, skipping checkpoint" % ckpt) + + else: # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval' + current_step = estimator._load_global_step_from_checkpoint_dir( + FLAGS.model_dir) + + tf.logging.info( + "Training for %d steps (%.2f epochs in total). Current" + " step %d." % (revnet_config.max_train_iter, + revnet_config.max_train_iter / steps_per_epoch, + current_step)) + + start_timestamp = time.time() # This time will include compilation time + + if FLAGS.mode == "train": + revnet_classifier.train( + input_fn=imagenet_train.input_fn, + max_steps=revnet_config.max_train_iter) + + else: + assert FLAGS.mode == "train_and_eval" + while current_step < revnet_config.max_train_iter: + # Train for up to steps_per_eval number of steps. + # At the end of training, a checkpoint will be written to --model_dir. + next_checkpoint = min(current_step + FLAGS.steps_per_eval, + revnet_config.max_train_iter) + revnet_classifier.train( + input_fn=imagenet_train.input_fn, max_steps=next_checkpoint) + current_step = next_checkpoint + + tf.logging.info("Finished training up to step %d. Elapsed seconds %d." % + (next_checkpoint, int(time.time() - start_timestamp))) + + # Evaluate the model on the most recent model in --model_dir. + # Since evaluation happens in batches of --eval_batch_size, some images + # may be excluded modulo the batch size. As long as the batch size is + # consistent, the evaluated images are also consistent. + tf.logging.info("Starting to evaluate.") + eval_results = revnet_classifier.evaluate( + input_fn=imagenet_eval.input_fn, steps=eval_steps) + tf.logging.info("Eval results: %s" % eval_results) + + elapsed_time = int(time.time() - start_timestamp) + tf.logging.info("Finished training up to step %d. Elapsed seconds %d." % + (revnet_config.max_train_iter, elapsed_time)) + + if FLAGS.export_dir is not None: + # The guide to serve an exported TensorFlow model is at: + # https://www.tensorflow.org/serving/serving_basic + tf.logging.info("Starting to export model.") + revnet_classifier.export_savedmodel( + export_dir_base=FLAGS.export_dir, + serving_input_receiver_fn=imagenet_input.image_serving_input_fn) + + +if __name__ == "__main__": + # Cloud TPU Cluster Resolver flags + flags.DEFINE_string( + "tpu", + default=None, + help="The Cloud TPU to use for training. This should be either the name " + "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " + "url.") + flags.DEFINE_string( + "tpu_zone", + default=None, + help="[Optional] GCE zone where the Cloud TPU is located in. If not " + "specified, we will attempt to automatically detect the GCE project from " + "metadata.") + flags.DEFINE_string( + "gcp_project", + default=None, + help="[Optional] Project name for the Cloud TPU-enabled project. If not " + "specified, we will attempt to automatically detect the GCE project from " + "metadata.") + + # Model specific parameters + flags.DEFINE_string( + "data_dir", default=None, help="Directory to load tfrecords") + flags.DEFINE_string( + "model_dir", + default=None, + help="[Optional] Directory to store the model information") + flags.DEFINE_string( + "revnet_config", + default="revnet-56", + help="[Optional] Architecture of network. " + "Other options include `revnet-104`") + flags.DEFINE_boolean( + "use_tpu", default=True, help="[Optional] Whether to use TPU") + flags.DEFINE_integer( + "num_shards", default=8, help="Number of shards (TPU chips).") + flags.DEFINE_integer( + "iterations_per_loop", + default=100, + help=( + "Number of steps to run on TPU before feeding metrics to the CPU." + " If the number of iterations in the loop would exceed the number of" + " train steps, the loop will exit before reaching" + " --iterations_per_loop. The larger this value is, the higher the" + " utilization on the TPU.")) + flags.DEFINE_integer( + "eval_timeout", + default=None, + help="Maximum seconds between checkpoints before evaluation terminates.") + flags.DEFINE_integer( + "steps_per_eval", + default=5000, + help=( + "Controls how often evaluation is performed. Since evaluation is" + " fairly expensive, it is advised to evaluate as infrequently as" + " possible (i.e. up to --train_steps, which evaluates the model only" + " after finishing the entire training regime).")) + flags.DEFINE_bool( + "transpose_input", + default=True, + help="Use TPU double transpose optimization") + flags.DEFINE_string( + "export_dir", + default=None, + help=("The directory where the exported SavedModel will be stored.")) + flags.DEFINE_bool( + "skip_host_call", + default=False, + help=("Skip the host_call which is executed every training step. This is" + " generally used for generating training summaries (train loss," + " learning rate, etc...). When --skip_host_call=false, there could" + " be a performance drop if host_call function is slow and cannot" + " keep up with the TPU-side computation.")) + flags.DEFINE_string( + "mode", + default="train_and_eval", + help='One of {"train_and_eval", "train", "eval"}.') + FLAGS = flags.FLAGS + tf.app.run() diff --git a/tensorflow/contrib/eager/python/examples/revnet/ops.py b/tensorflow/contrib/eager/python/examples/revnet/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9ed5d363e6c8bffd817357c006abee7ac0d1dbba --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/ops.py @@ -0,0 +1,70 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reversible residual network compatible with eager execution. + +Customized basic operations. + +Reference [The Reversible Residual Network: Backpropagation +Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf) +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +def downsample(x, filters, strides, axis=1): + """Downsample feature map with avg pooling, if filter size doesn't match.""" + + def pad_strides(strides, axis=1): + """Convert length 2 to length 4 strides. + + Needed since `tf.layers.Conv2D` uses length 2 strides, whereas operations + such as `tf.nn.avg_pool` use length 4 strides. + + Args: + strides: length 2 list/tuple strides for height and width + axis: integer specifying feature dimension according to data format + Returns: + length 4 strides padded with 1 on batch and channel dimension + """ + + assert len(strides) == 2 + + if axis == 1: + return [1, 1, strides[0], strides[1]] + return [1, strides[0], strides[1], 1] + + assert len(x.shape) == 4 and (axis == 1 or axis == 3) + + data_format = "NCHW" if axis == 1 else "NHWC" + strides_ = pad_strides(strides, axis=axis) + + if strides[0] > 1: + x = tf.nn.avg_pool( + x, strides_, strides_, padding="VALID", data_format=data_format) + + in_filter = x.shape[axis] + out_filter = filters + + if in_filter < out_filter: + pad_size = [(out_filter - in_filter) // 2, (out_filter - in_filter) // 2] + if axis == 1: + x = tf.pad(x, [[0, 0], pad_size, [0, 0], [0, 0]]) + else: + x = tf.pad(x, [[0, 0], [0, 0], [0, 0], pad_size]) + # In case `tape.gradient(x, [x])` produces a list of `None` + return x + 0. diff --git a/tensorflow/contrib/eager/python/examples/revnet/ops_test.py b/tensorflow/contrib/eager/python/examples/revnet/ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5bc2641faf5a5d26262de683e52e36b1f42b3a7b --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/ops_test.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for basic ops used in eager mode RevNet.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import ops +tfe = tf.contrib.eager + + +class OpsTest(tf.test.TestCase): + + def test_downsample(self): + """Test `possible_down_sample` function with mock object.""" + + batch_size = 100 + # NHWC format + x = tf.random_normal(shape=[batch_size, 32, 32, 3]) + # HW doesn't change but number of features increased + y = ops.downsample(x, filters=5, strides=(1, 1), axis=3) + self.assertEqual(y.shape, [batch_size, 32, 32, 5]) + # Feature map doesn't change but HW reduced + y = ops.downsample(x, filters=3, strides=(2, 2), axis=3) + self.assertEqual(y.shape, [batch_size, 16, 16, 3]) + # Number of feature increased and HW reduced + y = ops.downsample(x, filters=5, strides=(2, 2), axis=3) + self.assertEqual(y.shape, [batch_size, 16, 16, 5]) + + # Test gradient flow + x = tf.random_normal(shape=[batch_size, 32, 32, 3]) + with tfe.GradientTape() as tape: + tape.watch(x) + y = ops.downsample(x, filters=3, strides=(1, 1)) + self.assertEqual(y.shape, x.shape) + dy = tf.random_normal(shape=[batch_size, 3, 32, 32]) + grad, = tape.gradient(y, [x], output_gradients=[dy]) + self.assertEqual(grad.shape, x.shape) + + # Default NCHW format + if tf.test.is_gpu_available(): + x = tf.random_normal(shape=[batch_size, 3, 32, 32]) + # HW doesn't change but feature map reduced + y = ops.downsample(x, filters=5, strides=(1, 1)) + self.assertEqual(y.shape, [batch_size, 5, 32, 32]) + # Feature map doesn't change but HW reduced + y = ops.downsample(x, filters=3, strides=(2, 2)) + self.assertEqual(y.shape, [batch_size, 3, 16, 16]) + # Both feature map and HW reduced + y = ops.downsample(x, filters=5, strides=(2, 2)) + self.assertEqual(y.shape, [batch_size, 5, 16, 16]) + + # Test gradient flow + x = tf.random_normal(shape=[batch_size, 3, 32, 32]) + with tfe.GradientTape() as tape: + tape.watch(x) + y = ops.downsample(x, filters=3, strides=(1, 1)) + self.assertEqual(y.shape, x.shape) + dy = tf.random_normal(shape=[batch_size, 3, 32, 32]) + grad, = tape.gradient(y, [x], output_gradients=[dy]) + self.assertEqual(grad.shape, x.shape) + + +if __name__ == '__main__': + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/revnet/resnet_preprocessing.py b/tensorflow/contrib/eager/python/examples/revnet/resnet_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..21a1ab85d46cde11453e1f693cc4aabbbf3c90ed --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/resnet_preprocessing.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ImageNet preprocessing for ResNet.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +IMAGE_SIZE = 224 +CROP_PADDING = 32 + + +def distorted_bounding_box_crop(image_bytes, + bbox, + min_object_covered=0.1, + aspect_ratio_range=(0.75, 1.33), + area_range=(0.05, 1.0), + max_attempts=100, + scope=None): + """Generates cropped_image using one of the bboxes randomly distorted. + + See `tf.image.sample_distorted_bounding_box` for more documentation. + + Args: + image_bytes: `Tensor` of binary image data. + bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` + where each coordinate is [0, 1) and the coordinates are arranged + as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole + image. + min_object_covered: An optional `float`. Defaults to `0.1`. The cropped + area of the image must contain at least this fraction of any bounding + box supplied. + aspect_ratio_range: An optional list of `float`s. The cropped area of the + image must have an aspect ratio = width / height within this range. + area_range: An optional list of `float`s. The cropped area of the image + must contain a fraction of the supplied image within in this range. + max_attempts: An optional `int`. Number of attempts at generating a cropped + region of the image of the specified constraints. After `max_attempts` + failures, return the entire image. + scope: Optional `str` for name scope. + Returns: + cropped image `Tensor` + """ + with tf.name_scope(scope, 'distorted_bounding_box_crop', [image_bytes, bbox]): + shape = tf.image.extract_jpeg_shape(image_bytes) + sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( + shape, + bounding_boxes=bbox, + min_object_covered=min_object_covered, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + max_attempts=max_attempts, + use_image_if_no_bounding_boxes=True) + bbox_begin, bbox_size, _ = sample_distorted_bounding_box + + # Crop the image to the specified bounding box. + offset_y, offset_x, _ = tf.unstack(bbox_begin) + target_height, target_width, _ = tf.unstack(bbox_size) + crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) + image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) + + return image + + +def _at_least_x_are_equal(a, b, x): + """At least `x` of `a` and `b` `Tensors` are equal.""" + match = tf.equal(a, b) + match = tf.cast(match, tf.int32) + return tf.greater_equal(tf.reduce_sum(match), x) + + +def _decode_and_random_crop(image_bytes, image_size): + """Make a random crop of image_size.""" + bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) + image = distorted_bounding_box_crop( + image_bytes, + bbox, + min_object_covered=0.1, + aspect_ratio_range=(3. / 4, 4. / 3.), + area_range=(0.08, 1.0), + max_attempts=10, + scope=None) + original_shape = tf.image.extract_jpeg_shape(image_bytes) + bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3) + + image = tf.cond( + bad, + lambda: _decode_and_center_crop(image_bytes, image_size), + lambda: tf.image.resize_bicubic([image], # pylint: disable=g-long-lambda + [image_size, image_size])[0]) + + return image + + +def _decode_and_center_crop(image_bytes, image_size): + """Crops to center of image with padding then scales image_size.""" + shape = tf.image.extract_jpeg_shape(image_bytes) + image_height = shape[0] + image_width = shape[1] + + padded_center_crop_size = tf.cast( + ((image_size / (image_size + CROP_PADDING)) * + tf.cast(tf.minimum(image_height, image_width), tf.float32)), + tf.int32) + + offset_height = ((image_height - padded_center_crop_size) + 1) // 2 + offset_width = ((image_width - padded_center_crop_size) + 1) // 2 + crop_window = tf.stack([offset_height, offset_width, + padded_center_crop_size, padded_center_crop_size]) + image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) + image = tf.image.resize_bicubic([image], [image_size, image_size])[0] + + return image + + +def _flip(image): + """Random horizontal image flip.""" + image = tf.image.random_flip_left_right(image) + return image + + +def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE): + """Preprocesses the given image for evaluation. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + use_bfloat16: `bool` for whether to use bfloat16. + image_size: image size. + + Returns: + A preprocessed image `Tensor`. + """ + image = _decode_and_random_crop(image_bytes, image_size) + image = _flip(image) + image = tf.reshape(image, [image_size, image_size, 3]) + image = tf.image.convert_image_dtype( + image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32) + return image + + +def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE): + """Preprocesses the given image for evaluation. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + use_bfloat16: `bool` for whether to use bfloat16. + image_size: image size. + + Returns: + A preprocessed image `Tensor`. + """ + image = _decode_and_center_crop(image_bytes, image_size) + image = tf.reshape(image, [image_size, image_size, 3]) + image = tf.image.convert_image_dtype( + image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32) + return image + + +def preprocess_image(image_bytes, + is_training=False, + use_bfloat16=False, + image_size=IMAGE_SIZE): + """Preprocesses the given image. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + is_training: `bool` for whether the preprocessing is for training. + use_bfloat16: `bool` for whether to use bfloat16. + image_size: image size. + + Returns: + A preprocessed image `Tensor`. + """ + if is_training: + return preprocess_for_train(image_bytes, use_bfloat16, image_size) + else: + return preprocess_for_eval(image_bytes, use_bfloat16, image_size) diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1f2cb14972f0b92d29489adff8f94e790e1ec4ed --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py @@ -0,0 +1,217 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reversible residual network compatible with eager execution. + +Code for main model. + +Reference [The Reversible Residual Network: Backpropagation +Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import blocks + + +class RevNet(tf.keras.Model): + """RevNet that depends on all the blocks.""" + + def __init__(self, config): + """Initialize RevNet with building blocks. + + Args: + config: tf.contrib.training.HParams object; specifies hyperparameters + """ + super(RevNet, self).__init__() + self.axis = 1 if config.data_format == "channels_first" else 3 + self.config = config + + self._init_block = blocks.InitBlock(config=self.config) + self._final_block = blocks.FinalBlock(config=self.config) + self._block_list = self._construct_intermediate_blocks() + self._moving_average_variables = [] + + def _construct_intermediate_blocks(self): + # Precompute input shape after initial block + stride = self.config.init_stride + if self.config.init_max_pool: + stride *= 2 + if self.config.data_format == "channels_first": + w, h = self.config.input_shape[1], self.config.input_shape[2] + input_shape = (self.config.init_filters, w // stride, h // stride) + else: + w, h = self.config.input_shape[0], self.config.input_shape[1] + input_shape = (w // stride, h // stride, self.config.init_filters) + + # Aggregate intermediate blocks + block_list = tf.contrib.checkpoint.List() + for i in range(self.config.n_rev_blocks): + # RevBlock configurations + n_res = self.config.n_res[i] + filters = self.config.filters[i] + if filters % 2 != 0: + raise ValueError("Number of output filters must be even to ensure" + "correct partitioning of channels") + stride = self.config.strides[i] + strides = (self.config.strides[i], self.config.strides[i]) + + # Add block + rev_block = blocks.RevBlock( + n_res, + filters, + strides, + input_shape, + batch_norm_first=(i != 0), # Only skip on first block + data_format=self.config.data_format, + bottleneck=self.config.bottleneck, + fused=self.config.fused, + dtype=self.config.dtype) + block_list.append(rev_block) + + # Precompute input shape for the next block + if self.config.data_format == "channels_first": + w, h = input_shape[1], input_shape[2] + input_shape = (filters, w // stride, h // stride) + else: + w, h = input_shape[0], input_shape[1] + input_shape = (w // stride, h // stride, filters) + + return block_list + + def call(self, inputs, training=True): + """Forward pass.""" + + if training: + saved_hidden = [inputs] + + h = self._init_block(inputs, training=training) + if training: + saved_hidden.append(h) + + for block in self._block_list: + h = block(h, training=training) + if training: + saved_hidden.append(h) + + logits = self._final_block(h, training=training) + + return (logits, saved_hidden) if training else (logits, None) + + def compute_loss(self, logits, labels): + """Compute cross entropy loss.""" + + if self.config.dtype == tf.float32 or self.config.dtype == tf.float16: + cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=labels) + else: + # `sparse_softmax_cross_entropy_with_logits` does not have a GPU kernel + # for float64, int32 pairs + labels = tf.one_hot( + labels, depth=self.config.n_classes, axis=1, dtype=self.config.dtype) + cross_ent = tf.nn.softmax_cross_entropy_with_logits( + logits=logits, labels=labels) + + return tf.reduce_mean(cross_ent) + + def compute_gradients(self, saved_hidden, labels, training=True, l2_reg=True): + """Manually computes gradients. + + This method silently updates the running averages of batch normalization. + + Args: + saved_hidden: List of hidden states Tensors + labels: One-hot labels for classification + training: Use the mini-batch stats in batch norm if set to True + l2_reg: Apply l2 regularization + + Returns: + A tuple with the first entry being a list of all gradients and the second + being the loss + """ + + def _defunable_pop(l): + """Functional style list pop that works with `tfe.defun`.""" + t, l = l[-1], l[:-1] + return t, l + + # Backprop through last block + x = saved_hidden[-1] + with tf.GradientTape() as tape: + tape.watch(x) + logits = self._final_block(x, training=training) + loss = self.compute_loss(logits, labels) + grads_combined = tape.gradient(loss, + [x] + self._final_block.trainable_variables) + dy, final_grads = grads_combined[0], grads_combined[1:] + + # Backprop through intermediate blocks + intermediate_grads = [] + for block in reversed(self._block_list): + y, saved_hidden = _defunable_pop(saved_hidden) + x = saved_hidden[-1] + dy, grads = block.backward_grads(x, y, dy, training=training) + intermediate_grads = grads + intermediate_grads + + # Backprop through first block + _, saved_hidden = _defunable_pop(saved_hidden) + x, saved_hidden = _defunable_pop(saved_hidden) + assert not saved_hidden + with tf.GradientTape() as tape: + y = self._init_block(x, training=training) + init_grads = tape.gradient( + y, self._init_block.trainable_variables, output_gradients=dy) + + # Ordering match up with `model.trainable_variables` + grads_all = init_grads + final_grads + intermediate_grads + if l2_reg: + grads_all = self._apply_weight_decay(grads_all) + + return grads_all, loss + + def _apply_weight_decay(self, grads): + """Update gradients to reflect weight decay.""" + return [ + g + self.config.weight_decay * v if v.name.endswith("kernel:0") else g + for g, v in zip(grads, self.trainable_variables) + ] + + def get_moving_stats(self): + """Get moving averages of batch normalization.""" + device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0" + with tf.device(device): + return [v.read_value() for v in self.moving_average_variables] + + def restore_moving_stats(self, values): + """Restore moving averages of batch normalization.""" + device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0" + with tf.device(device): + for var_, val in zip(self.moving_average_variables, values): + var_.assign(val) + + @property + def moving_average_variables(self): + """Get all variables that are batch norm moving averages.""" + + def _is_moving_avg(v): + n = v.name + return n.endswith("moving_mean:0") or n.endswith("moving_variance:0") + + if not self._moving_average_variables: + self._moving_average_variables = filter(_is_moving_avg, self.variables) + + return self._moving_average_variables diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py new file mode 100644 index 0000000000000000000000000000000000000000..84b2ddf0de0739936d458ae1bce832cfbb167d64 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py @@ -0,0 +1,339 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for basic building blocks used in eager mode RevNet.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gc +import time + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import blocks_test +from tensorflow.contrib.eager.python.examples.revnet import config as config_ +from tensorflow.contrib.eager.python.examples.revnet import revnet +from tensorflow.python.client import device_lib +tfe = tf.contrib.eager + + +def train_one_iter(model, inputs, labels, optimizer, global_step=None): + """Train for one iteration.""" + logits, saved_hidden = model(inputs) + grads, loss = model.compute_gradients( + saved_hidden=saved_hidden, labels=labels) + optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) + + return logits, loss + + +class RevNetTest(tf.test.TestCase): + + def setUp(self): + super(RevNetTest, self).setUp() + config = config_.get_hparams_cifar_38() + config.add_hparam("n_classes", 10) + config.add_hparam("dataset", "cifar-10") + # Reconstruction could cause numerical error, use double precision for tests + config.dtype = tf.float64 + config.fused = False # Fused batch norm does not support tf.float64 + shape = (config.batch_size,) + config.input_shape + self.model = revnet.RevNet(config=config) + self.x = tf.random_normal(shape=shape, dtype=tf.float64) + self.t = tf.random_uniform( + shape=[config.batch_size], + minval=0, + maxval=config.n_classes, + dtype=tf.int64) + self.config = config + + def tearDown(self): + del self.model + del self.x + del self.t + del self.config + super(RevNetTest, self).tearDown() + + def test_call(self): + """Test `call` function.""" + + y, _ = self.model(self.x, training=False) + self.assertEqual(y.shape, [self.config.batch_size, self.config.n_classes]) + + def _check_grad_angle_combined(self, grads, grads_true): + """Verify that the reconstructed gradients has correct direction. + + Due to numerical imprecision, the magnitude may be slightly different. + Yet according to the paper, the angle should be roughly the same. + + Args: + grads: list of gradients from reconstruction + grads_true: list of true gradients + """ + + def _combine(gs): + return [tf.reshape(g, [-1]) for g in gs] + + g1_all = tf.concat(_combine(grads), axis=0) + g2_all = tf.concat(_combine(grads_true), axis=0) + + self.assertEqual(len(g1_all.shape), 1) + self.assertEqual(len(g2_all.shape), 1) + + degree = blocks_test.compute_degree(g1_all, g2_all) + self.assertLessEqual(degree, 1e0) + + def test_compute_gradients(self): + """Test `compute_gradients` function.""" + _, saved_hidden = self.model(self.x) # Initialize model + grads, loss = self.model.compute_gradients( + saved_hidden=saved_hidden, labels=self.t) + vars_ = self.model.trainable_variables + self.assertTrue(isinstance(grads, list)) + self.assertTrue(isinstance(vars_, list)) + self.assertEqual(len(grads), len(vars_)) + for grad, var in zip(grads, vars_): + self.assertEqual(grad.shape, var.shape) + + # Compare against the true gradient computed by the tape + with tf.GradientTape() as tape: + logits, _ = self.model(self.x) + loss_true = self.model.compute_loss(logits=logits, labels=self.t) + grads_true = tape.gradient(loss_true, vars_) + self.assertAllClose(loss, loss_true) + self.assertAllClose(grads, grads_true, rtol=1e-4, atol=1e-4) + self._check_grad_angle_combined(grads, grads_true) + + def test_call_defun(self): + """Test `call` function with defun.""" + y, _ = tfe.defun(self.model.call)(self.x, training=False) + self.assertEqual(y.shape, [self.config.batch_size, self.config.n_classes]) + + def test_compute_gradients_defun(self): + """Test `compute_gradients` function with defun.""" + compute_gradients = tfe.defun(self.model.compute_gradients) + _, saved_hidden = self.model(self.x) + grads, _ = compute_gradients(saved_hidden=saved_hidden, labels=self.t) + vars_ = self.model.trainable_variables + self.assertTrue(isinstance(grads, list)) + self.assertTrue(isinstance(vars_, list)) + self.assertEqual(len(grads), len(vars_)) + for grad, var in zip(grads, vars_): + if grad is not None: + self.assertEqual(grad.shape, var.shape) + + def test_training_graph(self): + """Test model training in graph mode.""" + with tf.Graph().as_default(): + config = config_.get_hparams_cifar_38() + config.add_hparam("n_classes", 10) + config.add_hparam("dataset", "cifar-10") + + x = tf.random_normal( + shape=(self.config.batch_size,) + self.config.input_shape) + t = tf.random_uniform( + shape=(self.config.batch_size,), + minval=0, + maxval=self.config.n_classes, + dtype=tf.int32) + global_step = tf.Variable(0., trainable=False) + model = revnet.RevNet(config=config) + _, saved_hidden = model(x) + grads, _ = model.compute_gradients(saved_hidden=saved_hidden, labels=t) + optimizer = tf.train.AdamOptimizer(learning_rate=1e-3) + train_op = optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + for _ in range(1): + sess.run(train_op) + + +# Benchmark related +def device_and_data_format(): + return ("/gpu:0", + "channels_first") if tf.test.is_gpu_available() else ("/cpu:0", + "channels_last") + + +def random_batch(batch_size, config): + shape = (batch_size,) + config.input_shape + images = tf.random_uniform(shape) + labels = tf.random_uniform( + [batch_size], minval=0, maxval=config.n_classes, dtype=tf.int32) + + return images, labels + + +class MockIterator(object): + + def __init__(self, tensors): + self._tensors = [tf.identity(x) for x in tensors] + + def next(self): + return self._tensors + + +class RevNetBenchmark(tf.test.Benchmark): + """Eager and graph benchmarks for RevNet.""" + + def _train_batch_sizes(self): + """Shamelessly copied from `resnet50_test.py`. + + Note: This is targeted towards ImageNet. CIFAR-10 should allow more + aggressive batch sizes. + + Returns: + A tuple of possible batch sizes + """ + for device in device_lib.list_local_devices(): + if tf.DeviceSpec.from_string(device.name).device_type == "GPU": + if "K20" in device.physical_device_desc: + return (16,) + if "P100" in device.physical_device_desc: + return (16, 32, 64) + if tf.DeviceSpec.from_string(device.name).device_type == "TPU": + return (32,) + return (16, 32) + + def _force_device_sync(self): + """Shamelessly copied from `resnet50_test.py`.""" + tf.constant(1.).cpu() + + def _report(self, label, start, num_iters, device, batch_size, data_format): + avg_time = (time.time() - start) / num_iters + dev = tf.DeviceSpec.from_string(device).device_type.lower() + name = "%s_%s_batch_%d_%s" % (label, dev, batch_size, data_format) + extras = {"examples_per_sec": batch_size / avg_time} + self.report_benchmark( + iters=num_iters, wall_time=avg_time, name=name, extras=extras) + + def _benchmark_eager_apply(self, + label, + device_and_format, + defun=False, + execution_mode=None, + compiled=False): + config = config_.get_hparams_imagenet_56() + with tfe.execution_mode(execution_mode): + device, data_format = device_and_format + model = revnet.RevNet(config=config) + if defun: + model.call = tfe.defun(model.call, compiled=compiled) + batch_size = 64 + num_burn = 5 + num_iters = 10 + with tf.device(device): + images, _ = random_batch(batch_size, config) + for _ in range(num_burn): + model(images, training=False) + if execution_mode: + tfe.async_wait() + gc.collect() + start = time.time() + for _ in range(num_iters): + model(images, training=False) + if execution_mode: + tfe.async_wait() + self._report(label, start, num_iters, device, batch_size, data_format) + + def benchmark_eager_apply_sync(self): + self._benchmark_eager_apply( + "eager_apply_sync", device_and_data_format(), defun=False) + + def benchmark_eager_apply_async(self): + self._benchmark_eager_apply( + "eager_apply_async", + device_and_data_format(), + defun=False, + execution_mode=tfe.ASYNC) + + def benchmark_eager_call_defun(self): + self._benchmark_eager_apply( + "eager_apply_with_defun", device_and_data_format(), defun=True) + + def _benchmark_eager_train(self, + label, + make_iterator, + device_and_format, + defun=False, + execution_mode=None, + compiled=False): + config = config_.get_hparams_imagenet_56() + with tfe.execution_mode(execution_mode): + device, data_format = device_and_format + for batch_size in self._train_batch_sizes(): + (images, labels) = random_batch(batch_size, config) + model = revnet.RevNet(config=config) + optimizer = tf.train.GradientDescentOptimizer(0.1) + if defun: + model.call = tfe.defun(model.call) + + num_burn = 3 + num_iters = 10 + with tf.device(device): + iterator = make_iterator((images, labels)) + for _ in range(num_burn): + (images, labels) = iterator.next() + train_one_iter(model, images, labels, optimizer) + if execution_mode: + tfe.async_wait() + self._force_device_sync() + gc.collect() + + start = time.time() + for _ in range(num_iters): + (images, labels) = iterator.next() + train_one_iter(model, images, labels, optimizer) + if execution_mode: + tfe.async_wait() + self._force_device_sync() + self._report(label, start, num_iters, device, batch_size, data_format) + + def benchmark_eager_train_sync(self): + self._benchmark_eager_train( + "eager_train_sync", MockIterator, device_and_data_format(), defun=False) + + def benchmark_eager_train_async(self): + self._benchmark_eager_train( + "eager_train_async", + MockIterator, + device_and_data_format(), + defun=False, + execution_mode=tfe.ASYNC) + + def benchmark_eager_train_defun(self): + self._benchmark_eager_train( + "eager_train", MockIterator, device_and_data_format(), defun=False) + + def benchmark_eager_train_datasets_with_defun(self): + + def make_iterator(tensors): + with tf.device("/device:CPU:0"): + ds = tf.data.Dataset.from_tensors(tensors).repeat() + return tfe.Iterator(ds) + + self._benchmark_eager_train( + "eager_train_dataset_with_defun", + make_iterator, + device_and_data_format(), + defun=True) + + +if __name__ == "__main__": + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index c2340a293a80924f2dfa90e2fb23134b0f1feb6b..15776c694e92825895437a4c1547699f6d9269fb 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -310,12 +310,12 @@ def main(_): with tf.device("/device:GPU:0" if have_gpu else None): # Make learning_rate a Variable so it can be included in the checkpoint # and we can resume training with the last saved learning_rate. - learning_rate = tfe.Variable(20.0, name="learning_rate") + learning_rate = tf.Variable(20.0, name="learning_rate") model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim, FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout, use_cudnn_rnn) optimizer = tf.train.GradientDescentOptimizer(learning_rate) - checkpoint = tfe.Checkpoint( + checkpoint = tf.train.Checkpoint( learning_rate=learning_rate, model=model, # GradientDescentOptimizer has no state to checkpoint, but noting it # here lets us swap in an optimizer that does. diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index 8ac553e0ae71382966d03d9ef4429adf5137b369..d18a097063c7d25947af3e2e2959ce574edd553f 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -36,7 +36,7 @@ from third_party.examples.eager.spinn import spinn from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import test from tensorflow.python.framework import test_util -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: enable=g-bad-import-order @@ -422,7 +422,7 @@ class SpinnTest(test_util.TensorFlowTestCase): # 5. Verify that checkpoints exist and contains all the expected variables. self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) object_graph = checkpointable_utils.object_metadata( - saver.latest_checkpoint(config.logdir)) + checkpoint_management.latest_checkpoint(config.logdir)) ckpt_variable_names = set() for node in object_graph.nodes: for attribute in node.attributes: diff --git a/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb b/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..75cb3f8227fe90223734f422e458f15810b8089a --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb @@ -0,0 +1,282 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "TFE Workshop: control flow", + "version": "0.3.2", + "provenance": [], + "include_colab_link": true + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "[View in Colaboratory](https://colab.research.google.com/gist/alextp/664b2f8700485ff6801f4d26293bd567/tfe-workshop-control-flow.ipynb)" + ] + }, + { + "metadata": { + "id": "9BpQzh9BvJlj", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 37 + }, + "outputId": "0b336886-8204-4815-89fa-5291a49d5784" + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "import numpy as np\n", + "tf.enable_eager_execution()" + ], + "execution_count": 1, + "outputs": [] + }, + { + "metadata": { + "id": "0roIB19GvOjI", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Eager execution basics\n", + "\n", + "When eager execution is enabled TensorFlow immediately executes operations, and Tensors are always available. " + ] + }, + { + "metadata": { + "id": "jeO8F-V-vN24", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + }, + "outputId": "aeb3bdec-50b7-440d-93d8-5a171f091081" + }, + "cell_type": "code", + "source": [ + "t = tf.constant([[1, 2], [3, 4]])\n", + "t" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 2 + } + ] + }, + { + "metadata": { + "id": "Y17RwSFxvlDL", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + }, + "outputId": "cfcc10c7-707b-4997-99b3-a5f382c5166b" + }, + "cell_type": "code", + "source": [ + "tf.matmul(t, t)" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 3 + } + ] + }, + { + "metadata": { + "id": "Dab1bS3TvmRE", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "8a624f3d-a658-4359-c586-1c5f6bf4c8b7" + }, + "cell_type": "code", + "source": [ + "# It's also possible to have Python control flow which depends on the value of tensors.\n", + "if t[0, 0] > 0.5:\n", + " print(\"T is bigger\")\n", + "else:\n", + " print(\"T is smaller\")" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "text": [ + "T is bigger\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "dPgptJcGwIon", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "c4f27f2b-0848-4475-dde5-2534dac65a5c" + }, + "cell_type": "code", + "source": [ + "# Tensors are also usable as numpy arrays\n", + "np.prod(t)" + ], + "execution_count": 6, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "24" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 6 + } + ] + }, + { + "metadata": { + "id": "p3DTfQXnwXzj", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Exercise\n", + "\n", + "The algorithm for bisecting line search is a pretty simple way to find a zero of a continuous scalar function in an interval [a,b] where f(a) and f(b) have different signs. Simply evaluate f((a+b)/2), and narrow the interval by replacing either a or b with (a+b)/2 such that the function when applied on the boundary of the interval still has different signs.\n", + "\n", + "Implement a python function `bisecting_line_search(f, a, b, epsilon)` which returns a value such that `tf.abs(f(value)) < epsilon`.\n", + "\n", + "One thing to keep in mind: python's `==` opertor is not overloaded on Tensors, so you need to use `tf.equal` to compare for equality." + ] + }, + { + "metadata": { + "id": "6eq0YuI6ykm5", + "colab_type": "code", + "colab": {} + }, + "cell_type": "code", + "source": [ + "# Example test harness to get you going\n", + "\n", + "def test_f(x):\n", + " return x - 0.1234\n", + "def bisecting_line_search(f, a, b, epsilon):\n", + " # Return x such that f(x) <= epsilon.\n", + " pass\n", + "a = tf.constant(0.0)\n", + "b = tf.constant(1.0)\n", + "epsilon = tf.constant(0.001)\n", + "x = bisecting_line_search(test_f, a, b, epsilon)\n" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "LcMmEfd_xvej", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 170 + }, + "outputId": "f402aa50-8ce3-4416-f755-8bbcd1af7809" + }, + "cell_type": "code", + "source": [ + "#@title Double-click to see the solution\n", + "\n", + "def bisecting_line_search(f, a, b, epsilon):\n", + " f_a = f(a)\n", + " f_b = f(b)\n", + " probe = (a + b) / 2\n", + " f_probe = f(probe)\n", + " while tf.abs(f_probe) > epsilon:\n", + " if tf.equal(tf.sign(f_probe), tf.sign(f_a)):\n", + " a = probe\n", + " f_a = f_probe\n", + " else:\n", + " b = probe\n", + " f_b = f_probe\n", + " probe = (a + b) / 2\n", + " f_probe = f(probe)\n", + " print(\"new probe\", probe)\n", + " return probe\n", + "\n", + "bisecting_line_search(test_f, 0., 1., 0.001)" + ], + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "text": [ + "('new probe', 0.25)\n", + "('new probe', 0.125)\n", + "('new probe', 0.0625)\n", + "('new probe', 0.09375)\n", + "('new probe', 0.109375)\n", + "('new probe', 0.1171875)\n", + "('new probe', 0.12109375)\n", + "('new probe', 0.123046875)\n" + ], + "name": "stdout" + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.123046875" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 8 + } + ] + } + ] +} diff --git a/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb b/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..f3a65f5aab1fe683565caf21dcfa8054045fd759 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb @@ -0,0 +1,1018 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "TFE Workshop: Models.ipynb", + "version": "0.3.2", + "provenance": [], + "collapsed_sections": [], + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "[View in Colaboratory](https://colab.research.google.com/gist/alextp/5cfcffd408bd5103f5ae747bc97ab0b5/tfe-workshop-models.ipynb)" + ] + }, + { + "metadata": { + "id": "BMxv1O6Q0SJL", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "8be9c556-ac7f-4142-e35e-19dc2b097121" + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "tfe = tf.contrib.eager" + ], + "execution_count": 1, + "outputs": [] + }, + { + "metadata": { + "id": "lE1vJhxp0WR9", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Variables\n", + "\n", + "TensorFlow variables are useful to store the state in your program. They are integrated with other parts of the API (taking gradients, checkpointing, graph functions)." + ] + }, + { + "metadata": { + "id": "C4ztQNgc0VpW", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "8b63ae1f-2670-49c0-a31b-8cf7fc4194a1" + }, + "cell_type": "code", + "source": [ + "# Creating variables\n", + "v = tf.Variable(1.0)\n", + "v" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 2 + } + ] + }, + { + "metadata": { + "id": "H0daItGg1IAp", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "e47d5aab-16a1-4e29-c27d-7fbc0b94b5d3" + }, + "cell_type": "code", + "source": [ + "v.assign_add(1.0)\n", + "v" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 3 + } + ] + }, + { + "metadata": { + "id": "BJvBzcIG1hyK", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Layers: common sets of useful operations\n", + "\n", + "Most of the time when writing code for machine learning models you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n", + "\n", + "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as a well as easy ways for you to write your own application-specific layers either from scratch or as the composition of existing layers.\n", + "\n", + "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n" + ] + }, + { + "metadata": { + "id": "iSQTS3QW1YQQ", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "c5d8aa10-dcad-44f7-f0eb-0faf5249fd7e" + }, + "cell_type": "code", + "source": [ + "# In the tf.keras.layers package, layers are objects. To construct a layer,\n", + "# simply construct the object. Most layers take as a first argument the number\n", + "# of output dimensions / channels.\n", + "layer = tf.keras.layers.Dense(100)\n", + "\n", + "# The number of input dimensions is often unnecessary, as it can be inferred\n", + "# the first time the layer is used, but it can be provided if you want to \n", + "# specify it manually, which is useful in some complex models.\n", + "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))\n" + ], + "execution_count": 4, + "outputs": [] + }, + { + "metadata": { + "id": "nRuUogoS1liV", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + }, + "outputId": "c352ce79-d519-45e4-a12e-1eaba76871a2" + }, + "cell_type": "code", + "source": [ + "layer(tf.zeros([2, 2]))" + ], + "execution_count": 5, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 5 + } + ] + }, + { + "metadata": { + "id": "JH4Kf4ka1mht", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 136 + }, + "outputId": "c34e2378-f83d-42c5-d30a-ebe55620368a" + }, + "cell_type": "code", + "source": [ + "layer.variables" + ], + "execution_count": 6, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[,\n", + " ]" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 6 + } + ] + }, + { + "metadata": { + "id": "DSI4NF0_1vn-", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). It includes Dense (a fully-connected layer),\n", + "Conv2D, LSTM, BatchNormalization, Dropout, and many others." + ] + }, + { + "metadata": { + "id": "hMgDBftJ12Bp", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Models: composing layers\n", + "\n", + "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a resnet is a composition of convolutions, batch normalizations, and a shortcut.\n", + "\n", + "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model. Implementing one is done by inheriting from tf.keras.Model.\n" + ] + }, + { + "metadata": { + "id": "K3gVY6gj1nbe", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 190 + }, + "outputId": "6e9be0c4-960e-46c2-cdd9-7e94ad09d46b" + }, + "cell_type": "code", + "source": [ + "class ResnetIdentityBlock(tf.keras.Model):\n", + " def __init__(self, kernel_size, filters):\n", + " super(ResnetIdentityBlock, self).__init__(name='')\n", + " filters1, filters2, filters3 = filters\n", + "\n", + " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n", + " self.bn2a = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n", + " self.bn2b = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n", + " self.bn2c = tf.keras.layers.BatchNormalization()\n", + "\n", + " def call(self, input_tensor, training=False):\n", + " x = self.conv2a(input_tensor)\n", + " x = self.bn2a(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2b(x)\n", + " x = self.bn2b(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2c(x)\n", + " x = self.bn2c(x, training=training)\n", + "\n", + " x += input_tensor\n", + " return tf.nn.relu(x)\n", + " \n", + "block = ResnetIdentityBlock(1, [1, 2, 3])\n", + "print(block(tf.zeros([1, 2, 3, 3])))\n", + "print([x.name for x in block.variables])" + ], + "execution_count": 7, + "outputs": [ + { + "output_type": "stream", + "text": [ + "tf.Tensor(\n", + "[[[[0. 0. 0.]\n", + " [0. 0. 0.]\n", + " [0. 0. 0.]]\n", + "\n", + " [[0. 0. 0.]\n", + " [0. 0. 0.]\n", + " [0. 0. 0.]]]], shape=(1, 2, 3, 3), dtype=float32)\n", + "['resnet_identity_block/conv2d/kernel:0', 'resnet_identity_block/conv2d/bias:0', 'resnet_identity_block/batch_normalization/gamma:0', 'resnet_identity_block/batch_normalization/beta:0', 'resnet_identity_block/conv2d_1/kernel:0', 'resnet_identity_block/conv2d_1/bias:0', 'resnet_identity_block/batch_normalization_1/gamma:0', 'resnet_identity_block/batch_normalization_1/beta:0', 'resnet_identity_block/conv2d_2/kernel:0', 'resnet_identity_block/conv2d_2/bias:0', 'resnet_identity_block/batch_normalization_2/gamma:0', 'resnet_identity_block/batch_normalization_2/beta:0', 'resnet_identity_block/batch_normalization/moving_mean:0', 'resnet_identity_block/batch_normalization/moving_variance:0', 'resnet_identity_block/batch_normalization_1/moving_mean:0', 'resnet_identity_block/batch_normalization_1/moving_variance:0', 'resnet_identity_block/batch_normalization_2/moving_mean:0', 'resnet_identity_block/batch_normalization_2/moving_variance:0']\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "LPXhHUIc1-sO", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Much of the time, however, models which compose many layers simply call one layer after the other. This can be done in very little code using tf.keras.Sequential" + ] + }, + { + "metadata": { + "id": "5pXgzNAU17xk", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 173 + }, + "outputId": "03b7eaf8-9b35-482b-bcf0-a99af6c2c6a4" + }, + "cell_type": "code", + "source": [ + " my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(2, 1, \n", + " padding='same'),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(3, (1, 1)),\n", + " tf.keras.layers.BatchNormalization()])\n", + "my_seq(tf.zeros([1, 2, 3, 3]))\n" + ], + "execution_count": 8, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 8 + } + ] + }, + { + "metadata": { + "id": "MZrns6p22GEQ", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Exercise!\n", + "\n", + "Make a simple convolutional neural network model, useful for things such as MNIST which don't need too many parameters. A sequence of two or three convolutions with small output channels (say, 32 and 64) plus one or two fully connected layers is probably enough.\n", + "\n", + "The input shape should be [batch_size, 28, 28, 1]." + ] + }, + { + "metadata": { + "id": "8CAUa3KNN916", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "97c0ff3c-c962-4c13-eee8-406101465761" + }, + "cell_type": "code", + "source": [ + "# TODO: Implement a convolutional model as described above, and assign it to\n", + "# model.\n", + "model = tf.keras.Sequential([\n", + " \n", + "])" + ], + "execution_count": 9, + "outputs": [] + }, + { + "metadata": { + "id": "vLDDduR32E82", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "09bb1d43-b4c6-44b5-916e-0d2903d10cf4" + }, + "cell_type": "code", + "source": [ + "#@title Click to see the answer\n", + "\n", + "max_pool = tf.keras.layers.MaxPooling2D(\n", + " (2, 2), (2, 2), padding='same')\n", + " # The model consists of a sequential chain of layers, so tf.keras.Sequential\n", + " # (a subclass of tf.keras.Model) makes for a compact description.\n", + "model = tf.keras.Sequential(\n", + " [\n", + " tf.keras.layers.Conv2D(\n", + " 32,\n", + " 5,\n", + " padding='same',\n", + " activation=tf.nn.relu),\n", + " max_pool,\n", + " tf.keras.layers.Conv2D(\n", + " 64,\n", + " 5,\n", + " padding='same',\n", + " activation=tf.nn.relu),\n", + " max_pool,\n", + " tf.keras.layers.Flatten(),\n", + " tf.keras.layers.Dense(1024, activation=tf.nn.relu),\n", + " tf.keras.layers.Dropout(0.4),\n", + " tf.keras.layers.Dense(10)\n", + " ])\n", + "\n", + "model(tf.zeros([1, 28, 28, 1]))" + ], + "execution_count": 10, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 10 + } + ] + }, + { + "metadata": { + "id": "H_CKVBroik4M", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Stop here for now" + ] + }, + { + "metadata": { + "id": "_yRwuE6MMmzC", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Training\n", + "\n", + "When eager execution is enabled, you can write Pythonic training loops. Simply\n", + "\n", + "1. load your data into a `tf.data.Dataset`, which lets you construct functional pipelines for processing, shuffling, and batching your data,\n", + "2. iterate over the dataset using a Python `for` loop, and\n", + "3. perform an optimization step in the body of your `for` loop.\n", + "\n", + "This workflow is exemplified in the following exercise." + ] + }, + { + "metadata": { + "id": "gj0-EkTc_Xt1", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "\n", + "\n", + "## Exercise!\n", + "\n", + "In this exercise, you'll train the convolutional model you implemented for the previous exericse on the MNIST dataset. " + ] + }, + { + "metadata": { + "id": "WOGm9HHn_byR", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "bbccc7ad-33cd-446e-bcda-f358c7547e1b" + }, + "cell_type": "code", + "source": [ + "#@title Utilities for downloading MNIST data (double-click to show code)\n", + "import gzip\n", + "import os\n", + "import tempfile\n", + "from six.moves import urllib\n", + "import shutil\n", + "\n", + "import numpy as np\n", + "\n", + "def read32(bytestream):\n", + " \"\"\"Read 4 bytes from bytestream as an unsigned 32-bit integer.\"\"\"\n", + " dt = np.dtype(np.uint32).newbyteorder('>')\n", + " return np.frombuffer(bytestream.read(4), dtype=dt)[0]\n", + "\n", + "\n", + "def check_image_file_header(filename):\n", + " \"\"\"Validate that filename corresponds to images for the MNIST dataset.\"\"\"\n", + " with tf.gfile.Open(filename, 'rb') as f:\n", + " magic = read32(f)\n", + " read32(f) # num_images, unused\n", + " rows = read32(f)\n", + " cols = read32(f)\n", + " if magic != 2051:\n", + " raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,\n", + " f.name))\n", + " if rows != 28 or cols != 28:\n", + " raise ValueError(\n", + " 'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %\n", + " (f.name, rows, cols))\n", + "\n", + "\n", + "def check_labels_file_header(filename):\n", + " \"\"\"Validate that filename corresponds to labels for the MNIST dataset.\"\"\"\n", + " with tf.gfile.Open(filename, 'rb') as f:\n", + " magic = read32(f)\n", + " read32(f) # num_items, unused\n", + " if magic != 2049:\n", + " raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,\n", + " f.name))\n", + " \n", + "def download(directory, filename):\n", + " \"\"\"Download (and unzip) a file from the MNIST dataset if not already done.\"\"\"\n", + " filepath = os.path.join(directory, filename)\n", + " if tf.gfile.Exists(filepath):\n", + " return filepath\n", + " if not tf.gfile.Exists(directory):\n", + " tf.gfile.MakeDirs(directory)\n", + " # CVDF mirror of http://yann.lecun.com/exdb/mnist/\n", + " url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n", + " _, zipped_filepath = tempfile.mkstemp(suffix='.gz')\n", + " print('Downloading %s to %s' % (url, zipped_filepath))\n", + " urllib.request.urlretrieve(url, zipped_filepath)\n", + " with gzip.open(zipped_filepath, 'rb') as f_in, \\\n", + " tf.gfile.Open(filepath, 'wb') as f_out:\n", + " shutil.copyfileobj(f_in, f_out)\n", + " os.remove(zipped_filepath)\n", + " return filepath\n", + "\n", + "\n", + "def dataset(directory, images_file, labels_file):\n", + " \"\"\"Download and parse MNIST dataset.\"\"\"\n", + "\n", + " images_file = download(directory, images_file)\n", + " labels_file = download(directory, labels_file)\n", + "\n", + " check_image_file_header(images_file)\n", + " check_labels_file_header(labels_file)\n", + "\n", + " def decode_image(image):\n", + " # Normalize from [0, 255] to [0.0, 1.0]\n", + " image = tf.decode_raw(image, tf.uint8)\n", + " image = tf.cast(image, tf.float32)\n", + " image = tf.reshape(image, [28, 28, 1])\n", + " return image / 255.0\n", + "\n", + " def decode_label(label):\n", + " label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8]\n", + " label = tf.reshape(label, []) # label is a scalar\n", + " return tf.to_int32(label)\n", + "\n", + " images = tf.data.FixedLengthRecordDataset(\n", + " images_file, 28 * 28, header_bytes=16).map(decode_image)\n", + " labels = tf.data.FixedLengthRecordDataset(\n", + " labels_file, 1, header_bytes=8).map(decode_label)\n", + " return tf.data.Dataset.zip((images, labels))\n", + "\n", + "\n", + "def get_training_data(directory):\n", + " \"\"\"tf.data.Dataset object for MNIST training data.\"\"\"\n", + " return dataset(directory, 'train-images-idx3-ubyte',\n", + " 'train-labels-idx1-ubyte').take(1024)\n", + "\n", + "def get_test_data(directory):\n", + " \"\"\"tf.data.Dataset object for MNIST test data.\"\"\"\n", + " return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')" + ], + "execution_count": 11, + "outputs": [] + }, + { + "metadata": { + "id": "4ejmJ2dv_f0R", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 85 + }, + "outputId": "274c0381-e505-4e69-f910-3def6f8572a7" + }, + "cell_type": "code", + "source": [ + "# Don't forget to run the cell above!\n", + "training_data = get_training_data(\"/tmp/mnist/train\")\n", + "test_data = get_test_data(\"/tmp/mnist/test\")" + ], + "execution_count": 12, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/tmp4ull1xwa.gz\n", + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/tmp1eikhj1v.gz\n", + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/tmpcp8xah9c.gz\n", + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/tmpqww_1e74.gz\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "TANpFS6GKLMC", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Fill in the implementation of `train_one_epoch` below and run the cell to train your model. " + ] + }, + { + "metadata": { + "id": "btKL0Ss9_rmC", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 102 + }, + "outputId": "56858516-86fc-424a-f00d-6f088f98bf9b" + }, + "cell_type": "code", + "source": [ + "EPOCHS = 5\n", + "optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.5)\n", + "\n", + "def loss_fn(logits, labels):\n", + " return tf.reduce_mean(\n", + " tf.nn.sparse_softmax_cross_entropy_with_logits(\n", + " logits=tf.squeeze(logits), labels=labels))\n", + "\n", + "def train_one_epoch(model, training_data, optimizer):\n", + " # TODO: Implement an optimization step and return the average loss.\n", + " #\n", + " # Hint: Use `tf.GradientTape` to compute the gradient of the loss, and use\n", + " # `optimizer.apply_gradients` to update the model's variables, which are\n", + " # accessible as `model.variables`\n", + " average_loss = tfe.metrics.Mean('loss')\n", + " for images, labels in training_data.shuffle(buffer_size=10000).batch(64):\n", + " pass\n", + " return average_loss.result()\n", + "\n", + "for epoch in range(EPOCHS):\n", + " loss = train_one_epoch(model, training_data, optimizer)\n", + " print(\"Average loss after epoch %d: %.4f\" % (epoch, loss))" + ], + "execution_count": 14, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Average loss after epoch 0: 2.2847\n", + "Average loss after epoch 1: 2.2305\n", + "Average loss after epoch 2: 2.1334\n", + "Average loss after epoch 3: 1.9115\n", + "Average loss after epoch 4: 1.4285\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "yAOFupJN_htg", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 102 + }, + "outputId": "67e711e4-76c9-4e3f-bb49-a14955dba03a" + }, + "cell_type": "code", + "source": [ + "#@title Double-click to see a solution.\n", + "EPOCHS = 5\n", + "optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.5)\n", + "\n", + "def _loss_fn(logits, labels):\n", + " return tf.reduce_mean(\n", + " tf.nn.sparse_softmax_cross_entropy_with_logits(\n", + " logits=tf.squeeze(logits), labels=labels))\n", + "\n", + "def _train_one_epoch(model, training_data):\n", + " average_loss = tfe.metrics.Mean(\"loss\")\n", + " for images, labels in training_data.shuffle(buffer_size=10000).batch(64):\n", + " with tf.GradientTape() as tape:\n", + " logits = model(images, training=True)\n", + " loss = _loss_fn(logits, labels)\n", + " average_loss(loss)\n", + " gradients = tape.gradient(loss, model.variables)\n", + " optimizer.apply_gradients(zip(gradients, model.variables))\n", + " return average_loss.result()\n", + " \n", + "for epoch in range(EPOCHS):\n", + " loss = _train_one_epoch(model, training_data)\n", + " print(\"Average loss after epoch %d: %.4f\" % (epoch, loss))" + ], + "execution_count": 15, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Average loss after epoch 0: 1.0563\n", + "Average loss after epoch 1: 0.8013\n", + "Average loss after epoch 2: 0.6306\n", + "Average loss after epoch 3: 0.5543\n", + "Average loss after epoch 4: 0.5037\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "uDy1DrYA_2Jz", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Run the below cell to qualitatively evaluate your model. Note how eager execution interoperates seamlessly with `matplotlib`." + ] + }, + { + "metadata": { + "id": "vR7rMtpu_3nB", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1752 + }, + "outputId": "b212aefa-f4b3-425c-f34d-2491429fa521" + }, + "cell_type": "code", + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "sampled_data = test_data.batch(1).shuffle(buffer_size=10000).take(5)\n", + "for image, label in sampled_data:\n", + " plt.figure()\n", + " plt.imshow(tf.reshape(image, (28, 28)))\n", + " plt.show()\n", + " logits = model(image, training=False)\n", + " prediction = tf.argmax(logits, axis=1, output_type=tf.int64)\n", + " print(\"Prediction: %d\" % prediction)" + ], + "execution_count": 16, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEwpJREFUeJzt3X1Ilff/x/HXmScxV2GZOmLVohXK\nKmLQjbUsy+pbI7rbaEm1IFhRSU1aE+kO3LqxCGrBMlsNkq0zZIM2Cu1mUTg1itXQbVnBQqKZNtcN\n2d3J3x9ffpLrNN/ndM65jn6fj7/m5cfrvI9XPHedc7zOcTU3NzcLAPCvXnJ6AABoD4glABgQSwAw\nIJYAYEAsAcCAWAKAAbEEAANiCQAG7kB/cOPGjbpw4YJcLpdyc3M1ZMiQYM4FABEloFieOXNGV69e\nlcfj0ZUrV5SbmyuPxxPs2QAgYgT0MLy8vFwZGRmSpP79++vWrVu6e/duUAcDgEgSUCwbGhrUvXv3\nlq979Oih+vr6oA0FAJEmKC/w8F4cADq6gGKZmJiohoaGlq9v3LihhISEoA0FAJEmoFiOHj1aJSUl\nkqTq6molJiaqS5cuQR0MACJJQK+Gv/nmm3rjjTf03nvvyeVyaf369cGeCwAiios3/wWAtnEFDwAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkA\nBsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgC\ngAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMHA7\nPQAQiAcPHpjX3rlzx+f2nj17qqGhodW2kydPmvb566+/mm//xx9/NK+13r4kjRgx4pltFRUVGjly\nZKttP/30k3mfL73E+dPz8JsBAIOAziwrKyu1YsUKDRgwQJI0cOBArV27NqiDAUAkCfhh+PDhw7Vz\n585gzgIAEYuH4QBgEHAsL1++rCVLlmju3LkqKysL5kwAEHFczc3Nzf7+UF1dnc6dO6cpU6aotrZW\nCxYsUGlpqaKjo0MxIwA4LqDnLJOSkjR16lRJUp8+fdSzZ0/V1dWpd+/eQR0OeB7+dIg/HQq3gH4z\nhw4d0hdffCFJqq+v182bN5WUlBTUwQAgkgR0Zjl+/HitWrVKx48f16NHj7RhwwYeggPo0AKKZZcu\nXbR79+5gzwIAESugF3gAf1RVVZnXfvfdd6Z1hw8fNu/zzJkzPrd7vV5FRUWZ99Me+LpPDx8+NP98\nR/t9BBPP5gKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAAM+3RGtPO/qV5fL\n1ep7BQUF5n1mZWWZ1z558sS8NhRcLpdpnT9vZebPJYT9+vUzry0pKfG5/Y8//mj1NW+7Fhz8FgHA\ngFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgCt40MrBgwd9bp87d26r7y1btsy8z1de\necW89q233jKte//99837/Dfff/99q68TExNNP/fqq6+ab8Of+x8MvXv3Duvt/a/gzBIADIglABgQ\nSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4mp/3CVXoMB49emRe+/rrr/vcfvXqVfXt\n27fl68zMTPM+P/74Y/PauLg481ognDizBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGx\nBAADYgkABny6YztVX19vXjthwgTz2oEDB5q+l5eXZ96n223/Z/b48WPTuuvXr5v3efz4cZ/bFy5c\nqC+//NK8n0CNHTvWvLZfv34hnAQvwnRmWVNTo4yMDBUVFUn67z/U+fPnKzMzUytWrNDDhw9DOiQA\nOK3NWN67d095eXlKTU1t2bZz505lZmbqq6++Ut++fVVcXBzSIQHAaW3GMjo6WoWFha0+fL6ysrLl\noV16errKy8tDNyEARIA2n0xyu93PPOfU1NSk6OhoSVJ8fLxfz58BQHv0wi/w8HaYzkhISDCv/eWX\nX4Jym0ePHg3Kfv6N9cWg3r17m/e5cOHCgL4HPC2gWMbGxur+/fuKiYlRXV1dq4foCI9QvRqelJTk\nc/vRo0c1ceLElq+PHDli3ievhvNqeEcQ0N9Zjho1SiUlJZKk0tJSjRkzJqhDAUCkafN/+VVVVdqy\nZYuuXbsmt9utkpISbdu2TTk5OfJ4POrVq5dmzJgRjlkBwDFtxnLQoEE6cODAM9v3798fkoEAIBLx\ngWXt1A8//GBeO3v2bPPa572Ik5aWplOnTrV8ff78efM+J02aZF5rnfX333837/N5vF6voqKiAvrZ\nd99917x20KBB5rWrVq0yr42JiTGvxYvj2nAAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAY\nEEsAMCCWAGDA5Y7tlD+X23377bcvfHv/vDTQn7cS8+ft1NLS0kzr/Ln/o0aN8rk9OTn5mcsmO3Xq\nZNrn7du3zbc/YsQI89q9e/ea1y5YsMC8Fi+OM0sAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyI\nJQAYEEsAMCCWAGDQ5kfhIjItXrzYvHb06NHmtRcvXnzu9z744IOW//bnUruhQ4ea11ovN3S7g/NP\nNzk5OaCfe/qTLtvi9XrNa/351E4udwwvziwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIAreNqpjIyMkKz9N59//nlQ9tMRPHjwwOkREGacWQKAAbEEAANiCQAGxBIADIglABgQSwAw\nIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADEyxrKmpUUZGhoqKiiRJOTk5mjZtmubPn6/58+fr5MmT\noZwRABzX5rsO3bt3T3l5eUpNTW21PTs7W+np6SEbDAAiSZtnltHR0SosLFRiYmI45gGAiNTmmaXb\n7Zbb/eyyoqIi7d+/X/Hx8Vq7dq169OgRkgGBSDRx4kTzWq/XG8JJEC4Bvfnv9OnTFRcXp5SUFO3Z\ns0e7du3SunXrgj0bELGOHj1qXvuf//zHvHb27Nnmtd988415LV5cQK+Gp6amKiUlRZI0fvx41dTU\nBHUoAIg0AcUyKytLtbW1kqTKykoNGDAgqEMBQKRp82F4VVWVtmzZomvXrsntdqukpETz5s3TypUr\n1blzZ8XGxmrTpk3hmBUAHNNmLAcNGqQDBw48s33y5MkhGQgAIhGf7ggEgAsx/vdwuSMAGBBLADAg\nlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADDgckcgAKdPnw7JfqdNmxaS/eLFcWYJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAZcwQM85dSpU6Z1P//8s3mfL7/8snntuHHj\nzGsRXpxZAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAAy53RIf3999/+9we\nFxf3zPcyMjJM+/R6vebbP3jwoHlt7969zWsRXpxZAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIBYAoABsQQAAy53DIMnT56Y1+bm5prWbdiwwbzPmJgY89r24u7du+a1b7/9ts/tZWVlz3zP\nehnjO++8Y7792bNnm9cicplimZ+fr3Pnzunx48davHixBg8erNWrV8vr9SohIUFbt25VdHR0qGcF\nAMe0GcuKigpdunRJHo9HjY2NmjlzplJTU5WZmakpU6Zo+/btKi4uVmZmZjjmBQBHtPmc5bBhw7Rj\nxw5JUrdu3dTU1KTKykpNmDBBkpSenq7y8vLQTgkADmszllFRUYqNjZUkFRcXKy0tTU1NTS0Pu+Pj\n41VfXx/aKQHAYeYXeI4dO6bi4mLt27dPkyZNatne3NwcksE6kpdesv/RwebNm0M4ScfRpUsX89qy\nsrKAvgc8zRTL06dPa/fu3dq7d6+6du2q2NhY3b9/XzExMaqrq1NiYmKo52zXeDU8+Px5NXzy5Mk+\nt5eVlWn06NGttlVUVJj26c+r4V9//bV5rT//Y0V4tXlk7ty5o/z8fBUUFCguLk6SNGrUKJWUlEiS\nSktLNWbMmNBOCQAOa/PM8vDhw2psbNTKlStbtm3evFlr1qyRx+NRr169NGPGjJAOCQBOazOWc+bM\n0Zw5c57Zvn///pAMBACRyNXMKzQh58+HW1n/uP/TTz817zM7Ozvotx8qv/32m2nd0qVLzfs8deqU\nz+1er1dRUVHm/TyturravDY5OTmg20Bk4dlkADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUA\nGBBLADAglgBgwOWOYeDP5Y4JCQmmdbdu3TLvc+LEiea148aN87k9Jycn4PfavH//vnntJ598Ylrn\nzz/bbt26+dze2Nio7t27t9p28eJF0z6tx0mSXC6XeS0iF2eWAGBALAHAgFgCgAGxBAADYgkABsQS\nAAyIJQAYEEsAMCCWAGBALAHAgMsdI0xxcbFp3bJly8z7bGhoCHScFi/ySYj++Oflh88zefJk8z4/\n+ugjn9uHDh2q8+fPP7MN8IUzSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAw4Aqe\ndqqmpsa8Njs727z2yJEjPre/yBU8q1evNq8dPHiwaV1mZmZAswCB4swSAAyIJQAYEEsAMCCWAGBA\nLAHAgFgCgAGxBAADYgkABsQSAAyIJQAYcLkjABi4LYvy8/N17tw5PX78WIsXL9aJEydUXV2tuLg4\nSdKiRYs0bty4UM4JAI5qM5YVFRW6dOmSPB6PGhsbNXPmTI0cOVLZ2dlKT08Px4wA4Lg2Yzls2DAN\nGTJEktStWzc1NTXJ6/WGfDAAiCR+PWfp8Xh09uxZRUVFqb6+Xo8ePVJ8fLzWrl2rHj16hHJOAHCU\nOZbHjh1TQUGB9u3bp6qqKsXFxSklJUV79uzRn3/+qXXr1oV6VgBwjOlPh06fPq3du3ersLBQXbt2\nVWpqqlJSUiRJ48eP9+uNaAGgPWozlnfu3FF+fr4KCgpaXv3OyspSbW2tJKmyslIDBgwI7ZQA4LA2\nX+A5fPiwGhsbtXLlypZts2bN0sqVK9W5c2fFxsZq06ZNIR0SAJzGH6UDgAGXOwKAAbEEAANiCQAG\nxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKA\nAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4nbjRjRs3\n6sKFC3K5XMrNzdWQIUOcGCOoKisrtWLFCg0YMECSNHDgQK1du9bhqQJXU1OjpUuXauHChZo3b56u\nX7+u1atXy+v1KiEhQVu3blV0dLTTY/rln/cpJydH1dXViouLkyQtWrRI48aNc3ZIP+Xn5+vcuXN6\n/PixFi9erMGDB7f74yQ9e79OnDjh+LEKeyzPnDmjq1evyuPx6MqVK8rNzZXH4wn3GCExfPhw7dy5\n0+kxXti9e/eUl5en1NTUlm07d+5UZmampkyZou3bt6u4uFiZmZkOTukfX/dJkrKzs5Wenu7QVC+m\noqJCly5dksfjUWNjo2bOnKnU1NR2fZwk3/dr5MiRjh+rsD8MLy8vV0ZGhiSpf//+unXrlu7evRvu\nMfAvoqOjVVhYqMTExJZtlZWVmjBhgiQpPT1d5eXlTo0XEF/3qb0bNmyYduzYIUnq1q2bmpqa2v1x\nknzfL6/X6/BUDsSyoaFB3bt3b/m6R48eqq+vD/cYIXH58mUtWbJEc+fOVVlZmdPjBMztdismJqbV\ntqamppaHc/Hx8e3umPm6T5JUVFSkBQsW6MMPP9Rff/3lwGSBi4qKUmxsrCSpuLhYaWlp7f44Sb7v\nV1RUlOPHypHnLJ/W3Nzs9AhB8dprr2n58uWaMmWKamtrtWDBApWWlrbL54va0lGO2fTp0xUXF6eU\nlBTt2bNHu3bt0rp165wey2/Hjh1TcXGx9u3bp0mTJrVsb+/H6en7VVVV5fixCvuZZWJiohoaGlq+\nvnHjhhISEsI9RtAlJSVp6tSpcrlc6tOnj3r27Km6ujqnxwqa2NhY3b9/X5JUV1fXIR7OpqamKiUl\nRZI0fvx41dTUODyR/06fPq3du3ersLBQXbt27TDH6Z/3KxKOVdhjOXr0aJWUlEiSqqurlZiYqC5d\nuoR7jKA7dOiQvvjiC0lSfX29bt68qaSkJIenCp5Ro0a1HLfS0lKNGTPG4YleXFZWlmprayX99znZ\n//9Lhvbizp07ys/PV0FBQcurxB3hOPm6X5FwrFzNDpyrb9u2TWfPnpXL5dL69euVnJwc7hGC7u7d\nu1q1apVu376tR48eafny5Ro7dqzTYwWkqqpKW7Zs0bVr1+R2u5WUlKRt27YpJydHDx48UK9evbRp\n0yZ16tTJ6VHNfN2nefPmac+ePercubNiY2O1adMmxcfHOz2qmcfj0WeffaZ+/fq1bNu8ebPWrFnT\nbo+T5Pt+zZo1S0VFRY4eK0diCQDtDVfwAIABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwOD/\nAKCzFeFbFn4BAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Prediction: 5\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEQ1JREFUeJzt3W9Ilff/x/HXSSd2VmKaRwiqjTBy\nq9gfap2iliaFQfRvsCXW1rpRRJGTCJG0MSHLIpbF8M9qN3L7cjZvNQiOVAQt7LQcBLqB1Y0QaXYs\naUa2mZ3fjS9ff7Vcvj2ec65jez7ueZ1P57wPlzy7Li8vjysUCoUEAHihcU4PAABjAbEEAANiCQAG\nxBIADIglABgQSwAwIJYAYEAsAcAgMdx/uH//fl27dk0ul0ulpaWaO3duJOcCgLgSViyvXLmiW7du\nyefz6ebNmyotLZXP54v0bAAQN8I6DW9ublZeXp4kacaMGbp//74ePHgQ0cEAIJ6EFcvu7m5NmjRp\n8Ou0tDQFg8GIDQUA8SYiF3j4WxwAXnZhxdLj8ai7u3vw6zt37igjIyNiQwFAvAkrlosWLZLf75ck\ntbW1yePxaMKECREdDADiSVhXw9955x29+eab+uijj+RyubRv375IzwUAccXFH/8FgOFxBw8AGBBL\nADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMEp0eAIgnP/30k2nd+vXrzc+Zl5dnXvvtt9+a1yK2OLIEAANiCQAGxBIADIgl\nABgQSwAwIJYAYEAsAcCAWAKAAbEEAAPu4AGecuzYMdO6YDBofk6XyxXuOIgjHFkCgEFYR5aBQEC7\ndu1SVlaWJGnmzJkqKyuL6GAAEE/CPg2fP3++qqurIzkLAMQtTsMBwCDsWN64cUPbtm3Thg0bdOnS\npUjOBABxxxUKhUIj/UddXV1qaWlRfn6+Ojo6tGnTJjU1NSkpKSkaMwKA48L6mWVmZqZWrlwpSZo2\nbZomT56srq4uTZ06NaLDAbH24Ycfmtb98MMP5ucsKCgwr21oaDCvRWyFdRp++vRpnThxQtJ/f9/s\n7t27yszMjOhgABBPwjqyzM3N1e7du3Xu3Dn19/fr888/5xQcwEstrFhOmDBBNTU1kZ4FAOIWtzsC\nT7lw4ULEn3PVqlURf07EHr9nCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIA\nDLjdES89v98/5PYVK1Y899hIPrXRqre3N+LPidjjyBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEE\nAANiCQAGxBIADLiDB2NSKBQyr21oaBhy+4oVK/7xsUh6++23o/4aiD6OLAHAgFgCgAGxBAADYgkA\nBsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgIErNJL7xoA40dnZaV47derUIbc/efJE48aFd7zw7rvv\nmtf+/PPPYb0G4gtHlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBPd8SY\nVFlZ6ejrb9682dHXR+yZjizb29uVl5c3+LGht2/f1saNG1VQUKBdu3bpr7/+iuqQAOC0YWP58OFD\nVVRUyOv1Dm6rrq5WQUGBvvvuO02fPl2NjY1RHRIAnDZsLJOSklRfXy+PxzO4LRAIaNmyZZKknJwc\nNTc3R29CAIgDw/7MMjExUYmJzy7r6+tTUlKSJCk9PV3BYDA60wFAnBj1BR7+HCaccPz48YisffLk\nSSTGwb9AWLF0u9169OiRkpOT1dXV9cwpOhALO3bsMK/96quvhtw+mj/+O5JYb9++PazXQHwJ6ztl\n4cKF8vv9kqSmpiYtXrw4okMBQLwZ9siytbVVBw8eVGdnpxITE+X3+3X48GGVlJTI5/NpypQpWrNm\nTSxmBQDHDBvL2bNn69SpU89t/+abb6IyEADEI+7gQVyxXnCJ1oeAWX/+XlhYGJXXR/zi3nAAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDA7Y6IKxUVFaZ10brd8dVXXzWt6+3t\nNT9nSkpKuOMgjnBkCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADLjdEXHl\nyy+/dPT1BwYGTOv8fr/5OT/99NNwx0Ec4cgSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkA\nBsQSAAy4gwdR99tvv5nXjuSDwKzcbrf5sV9++cX0nGlpaaOaCWMPR5YAYEAsAcCAWAKAAbEEAANi\nCQAGxBIADIglABgQSwAwIJYAYEAsAcCA2x0RFusHe0kj+xCyJ0+ehDPOC507d878GLcx4p9wZAkA\nBqZYtre3Ky8vTw0NDZKkkpISrVq1Shs3btTGjRt14cKFaM4IAI4b9jT84cOHqqiokNfrfWZ7cXGx\ncnJyojYYAMSTYY8sk5KSVF9fL4/HE4t5ACAuuUKhUMiy8NixY5o0aZIKCwtVUlKiYDCo/v5+paen\nq6ysjB+MA3iphXU1fPXq1UpNTVV2drbq6up0/PhxlZeXR3o2xLGRXA3fvn27eW19fX0447xQc3Pz\nkNvfe+89BQKB57YBQwnrarjX61V2drYkKTc3V+3t7REdCgDiTVix3Llzpzo6OiRJgUBAWVlZER0K\nAOLNsKfhra2tOnjwoDo7O5WYmCi/36/CwkIVFRVp/PjxcrvdqqysjMWsAOCYYWM5e/ZsnTp16rnt\nK1asiMpAABCPzFfDgafdu3fPvHby5MkRf/0PPvjAvPY///nPkNsTEhKeu1CVkJAwqrnw8uJ2RwAw\nIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYMCnO+IZ//TpiuPGjXvmsc2bN0fl\n9V0ul2ndF198YX7OF93CyO2NsOLIEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAM\nuIMHz/jf58H/3fTp05957Mcff4zK6xcWFprWzZo1KyqvD/wTjiwBwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBgQCwBwIBYAoABtzviGRcuXBhy+8cff/zMY6FQKCqvX15eHpXnBUaLI0sAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDgCkXrvjXEjV9//dW8ds6cOUNuHxgY\nUEJCwuDXI/m2Wb9+vXmtz+czrRs3jv/nEVume8OrqqrU0tKix48fa+vWrZozZ4727NmjgYEBZWRk\n6NChQ0pKSor2rADgmGFjefnyZV2/fl0+n089PT1au3atvF6vCgoKlJ+fryNHjqixsVEFBQWxmBcA\nHDHsucy8efN09OhRSVJKSor6+voUCAS0bNkySVJOTo6am5ujOyUAOGzYWCYkJMjtdkuSGhsbtWTJ\nEvX19Q2edqenpysYDEZ3SgBwmPnvWZ49e1aNjY06efKkli9fPrid60Px74033jCvHRgYCOsx4GVn\niuXFixdVU1Ojr7/+WhMnTpTb7dajR4+UnJysrq4ueTyeaM+JUeBqODB6w37H9fb2qqqqSrW1tUpN\nTZUkLVy4UH6/X5LU1NSkxYsXR3dKAHDYsEeWZ86cUU9Pj4qKiga3HThwQHv37pXP59OUKVO0Zs2a\nqA4JAE7jl9L/BTgNB0aPDyz7F7AGSHpxBJ9+LCUlxfycJ06cMK8lgohXfGcCgAGxBAADYgkABsQS\nAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADbnf8F7hx44Z5rfV2x+TkZPNzjuTWSCBecWQJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMuN3xX6C4uNi89vvvv//HxxIT///b\n5a233hrVTMBYw5ElABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4Qi/6hCoAgCSO\nLAHAhFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAAD06c7VlVVqaWlRY8f\nP9bWrVt1/vx5tbW1KTU1VZK0ZcsWLV26NJpzAoCjho3l5cuXdf36dfl8PvX09Gjt2rVasGCBiouL\nlZOTE4sZAcBxw8Zy3rx5mjt3riQpJSVFfX19GhgYiPpgABBPRvQn2nw+n65evaqEhAQFg0H19/cr\nPT1dZWVlSktLi+acAOAocyzPnj2r2tpanTx5Uq2trUpNTVV2drbq6ur0+++/q7y8PNqzAoBjTFfD\nL168qJqaGtXX12vixInyer3Kzs6WJOXm5qq9vT2qQwKA04aNZW9vr6qqqlRbWzt49Xvnzp3q6OiQ\nJAUCAWVlZUV3SgBw2LAXeM6cOaOenh4VFRUNblu3bp2Kioo0fvx4ud1uVVZWRnVIAHAan8EDAAbc\nwQMABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHA\ngFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsA\nMCCWAGCQ6MSL7t+/X9euXZPL5VJpaanmzp3rxBgRFQgEtGvXLmVlZUmSZs6cqbKyMoenCl97e7u2\nb9+uTz75RIWFhbp9+7b27NmjgYEBZWRk6NChQ0pKSnJ6zBH5+3sqKSlRW1ubUlNTJUlbtmzR0qVL\nnR1yhKqqqtTS0qLHjx9r69atmjNnzpjfT9Lz7+v8+fOO76uYx/LKlSu6deuWfD6fbt68qdLSUvl8\nvliPERXz589XdXW102OM2sOHD1VRUSGv1zu4rbq6WgUFBcrPz9eRI0fU2NiogoICB6ccmaHekyQV\nFxcrJyfHoalG5/Lly7p+/bp8Pp96enq0du1aeb3eMb2fpKHf14IFCxzfVzE/DW9ublZeXp4kacaM\nGbp//74ePHgQ6zHwAklJSaqvr5fH4xncFggEtGzZMklSTk6OmpubnRovLEO9p7Fu3rx5Onr0qCQp\nJSVFfX19Y34/SUO/r4GBAYenciCW3d3dmjRp0uDXaWlpCgaDsR4jKm7cuKFt27Zpw4YNunTpktPj\nhC0xMVHJycnPbOvr6xs8nUtPTx9z+2yo9yRJDQ0N2rRpkz777DPdu3fPgcnCl5CQILfbLUlqbGzU\nkiVLxvx+koZ+XwkJCY7vK0d+Zvm0UCjk9AgR8dprr2nHjh3Kz89XR0eHNm3apKampjH586LhvCz7\nbPXq1UpNTVV2drbq6up0/PhxlZeXOz3WiJ09e1aNjY06efKkli9fPrh9rO+np99Xa2ur4/sq5keW\nHo9H3d3dg1/fuXNHGRkZsR4j4jIzM7Vy5Uq5XC5NmzZNkydPVldXl9NjRYzb7dajR48kSV1dXS/F\n6azX61V2drYkKTc3V+3t7Q5PNHIXL15UTU2N6uvrNXHixJdmP/39fcXDvop5LBctWiS/3y9Jamtr\nk8fj0YQJE2I9RsSdPn1aJ06ckCQFg0HdvXtXmZmZDk8VOQsXLhzcb01NTVq8eLHDE43ezp071dHR\nIem/P5P9328yjBW9vb2qqqpSbW3t4FXil2E/DfW+4mFfuUIOHKsfPnxYV69elcvl0r59+zRr1qxY\njxBxDx480O7du/XHH3+ov79fO3bs0Pvvv+/0WGFpbW3VwYMH1dnZqcTERGVmZurw4cMqKSnRn3/+\nqSlTpqiyslKvvPKK06OaDfWeCgsLVVdXp/Hjx8vtdquyslLp6elOj2rm8/l07Ngxvf7664PbDhw4\noL17947Z/SQN/b7WrVunhoYGR/eVI7EEgLGGO3gAwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAg\nlgBg8H/nb4OLnfGqVAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Prediction: 1\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAE1ZJREFUeJzt3X1olfX/x/HXccc1DyrLuY1GaRGL\nRqZSaE7zZmqKgnhDsVwqkYGRE29QW8tp4M102solNJ03fzSqgyPoBmFDIlg1Jw0xNsrZDbKGranD\nG5x3x33/+NF+rp153js751znrOfjv13n43Xex4NPrrPL61yujo6ODgEA7muA0wMAQCwglgBgQCwB\nwIBYAoABsQQAA2IJAAbEEgAMiCUAGLiD/YM7duzQ6dOn5XK5lJ+fr9GjR4dyLgCIKkHF8uTJkzp3\n7py8Xq9+++035efny+v1hno2AIgaQX0Mr6mp0cyZMyVJjz/+uC5fvqxr166FdDAAiCZBxfLChQt6\n8MEHO38eNmyYWltbQzYUAESbkJzg4bs4APR3QcUyJSVFFy5c6Pz577//VnJycsiGAoBoE1QsJ02a\npMrKSklSQ0ODUlJSNHjw4JAOBgDRJKiz4c8884yeeuopvfzyy3K5XNqyZUuo5wKAqOLiy38BIDCu\n4AEAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUA\nGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJ\nAAbEEgAMiCUAGLiD+UO1tbVavXq10tPTJUlPPPGECgoKQjoYAESToGIpSePHj1dJSUkoZwGAqMXH\ncAAwCDqWv/76q9544w0tXrxY33//fShnAoCo4+ro6Ojo7R9qaWlRXV2d5syZo6amJi1btkxVVVWK\nj48Px4wA4LigjixTU1M1d+5cuVwujRgxQsOHD1dLS0uoZwOAqBFULL/88ksdOnRIktTa2qqLFy8q\nNTU1pIMBQDQJ6mP4tWvXtH79el25ckW3b99Wbm6upk6dGo75ACAqBBVLAPivCfr/WQL90alTp0zr\nSktLzfssKysLdpz78nec09HRIZfL1WVbbm6ueZ+9+b/T/36e/o7/ZwkABsQSAAyIJQAYEEsAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAy4Nhz93tmzZ/1uT09P7/bY4sWLTfu0XhYZaT6fT3FxcUH/\n+Vu3bpnX9uV5YhFHlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgwA3LEHa9uUjs\nzJkzpnXz588377Opqcnv9uvXr2vMmDFdtt28edO8Xyu32/7PrKCgwLw2Pj7e7/bCwsIuPz/77LPm\nfQ4YwPFTT/ibAQADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABtywDEG5ffu2\nee1bb71lXrt3795gxgmKv5t7PfTQQ6Y/u3r1avPzLF++3Lz2yJEj5rW5ubndtj3wwAPdLtl84IEH\nzPtEzziyBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABtzdEV3cvXvX7/YB\nAwZ0eSwvL8+8z0hewujPokWLzI999NFHpn16PB7z8y9evNi89uuvvzavbW5u7ratuLhYb7/9drdt\n6DvTkWVjY6Nmzpyp8vJySdL58+e1dOlS5eTkaPXq1bp161ZYhwQApwWM5fXr17V161ZlZmZ2bisp\nKVFOTo4++eQTjRw5UhUVFWEdEgCcFjCW8fHxKisrU0pKSue22tpazZgxQ5KUlZWlmpqa8E0IAFEg\n4O8s3W633O6uy9rb2xUfHy9JSkpKUmtra3imA4Ao0ecTPHwdZv8yYEDPHzbufey9994z77M3ayPt\n6NGjYX+OL774IuzPcS9O6IRHULH0eDy6ceOGEhIS1NLS0uUjOmKb9Wz4hg0bzPv84IMP+jxXX/R0\nNvzo0aN66aWXumyLpbPh/r6AuLi4WOvWreu2DX0X1P+znDhxoiorKyVJVVVVmjx5ckiHAoBoE/DI\nsr6+Xrt27VJzc7PcbrcqKyu1Z88e5eXlyev1Ki0tTQsWLIjErADgmICxHDVqlD7++ONu23tzrxAA\niHVcwfMf8Ndff5nXzpo1y+/2n376SWPHju38uaGhoc9z+TN06FDTutLSUvM+X3zxxR4f++yzz7r8\nfL8TXPf69NNPzc/fm99D9kZaWlqvtqNvuDYcAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkA\nBsQSAAyIJQAYuDr4QsqYdPXqVfPaUaNGmdf++eeffrf7fD7FxcWZ93Ovf75V3+LQoUOmdY888khQ\nswRivUVKdnZ2WJ7/ny/Vtjh16lS3bU8++aR++eWXbtvQdxxZAoABsQQAA2IJAAbEEgAMiCUAGBBL\nADAglgBgQCwBwIBYAoABsQQAA+7uGKPKy8vNa3u6hLEvlixZYl67Z88e89rk5GTTupaWFvM+X3/9\ndb/bv/rqK82bN6/LtsrKSvN+w6E3d43s6TJGLm8MD44sAcCAWAKAAbEEAANiCQAGxBIADIglABgQ\nSwAwIJYAYEAsAcCAG5ZFmbt375rWvfDCC+Z9fvvtt+a1Pd0wq729XYMGDer8ubGx0bzPtLQ089qf\nf/7ZtG7Dhg3mfVZVVfnd3pebsIXLjRs3zGsHDhwYxknwbxxZAoABsQQAA2IJAAbEEgAMiCUAGBBL\nADAglgBgQCwBwIBYAoABsQQAA25YFmWsV5/25hLG3vD5fKbHiouLzfv8448/zGu/+uor89pYsWDB\nAvPaaLv8Ev+PI0sAMDDFsrGxUTNnzuy8/WpeXp7mzZunpUuXaunSpWE7ygGAaBHwY/j169e1detW\nZWZmdtm+bt06ZWVlhW0wAIgmAY8s4+PjVVZWppSUlEjMAwBRKeCRpdvtltvdfVl5ebmOHDmipKQk\nFRQUaNiwYWEZ8L/G+gv++52ICZdbt25F/DnDzYm/R8SmoM6Gz58/X4mJicrIyNCBAwe0b98+bd68\nOdSz/SdZ//H29CW9fdVTrG/dutXlOVeuXGneZ7SeDY/Ul//25mz40aNHzWsHDOD8bCQF9bedmZmp\njIwMSdL06dN79a3ZABCLgorlqlWr1NTUJEmqra1Venp6SIcCgGgT8GN4fX29du3apebmZrndblVW\nVmrJkiVas2aNBg0aJI/Ho8LCwkjMCgCOCRjLUaNG6eOPP+62ffbs2WEZCACiEZc7ogvr5Y4lJSWR\nGKdf6M0JHk7aRC/eGQAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYMDljlHG\nernbsWPHzPvszeV24fiC3958MfT69etN6/Lz84MdJyS2bdtmXvvKK6+EcRJECkeWAGBALAHAgFgC\ngAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDAFTxRxuVymdb15u6ap06dMq+9dOlSj49VV1eb93Ov\nsWPHmtfW1dUF9RyhMmbMGNO6lStXmvfJTcj6B95FADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBg4Oro6Ohwegj0b21tbea1kyZNMq07c+ZMsON08vl8iouL67Lthx9+MP3Z5557\nrs/Pj9jCkSUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADDg7o4Iu5MnT5rX\nhuIyxn/Ly8szPzZ+/PiQPz/6B1Msi4qKVFdXpzt37mjFihV6+umntXHjRvl8PiUnJ2v37t2Kj48P\n96wA4JiAsTxx4oTOnj0rr9ertrY2LVy4UJmZmcrJydGcOXNUXFysiooK5eTkRGJeAHBEwN9Zjhs3\nTnv37pUkDR06VO3t7aqtrdWMGTMkSVlZWaqpqQnvlADgsICxjIuLk8fjkSRVVFRoypQpam9v7/zY\nnZSUpNbW1vBOCQAOM5/gOX78uCoqKnT48GHNmjWrcztfh4lAZs+ebV7r8/nCOEl327dvj+jzIXaZ\nYlldXa3S0lIdPHhQQ4YMkcfj0Y0bN5SQkKCWlhalpKSEe07EsMrKSvPauXPnhvz5ezobvn37dr3z\nzjtdtm3bts20T5fL1ee5EFsCfgy/evWqioqKtH//fiUmJkqSJk6c2PkPoKqqSpMnTw7vlADgsIBH\nlseOHVNbW5vWrFnTuW3nzp3atGmTvF6v0tLStGDBgrAOCQBOCxjL7OxsZWdnd9t+5MiRsAwEANGI\nG5YhKL25CVlGRoZ5bTj+Z8Xvv//ud/vIkSN17ty5btsAf7g2HAAMiCUAGBBLADAglgBgQCwBwIBY\nAoABsQQAA2IJAAbEEgAMiCUAGHDDMgSlrKzMvDYclzDm5uaa16alpQX1GHAvjiwBwIBYAoABsQQA\nA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABlzuiizt37vjd7na7uzz2+eefh+X5V61aZVr3\n/vvvm/fpcrl6fGzgwIHm/eC/jSNLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADBw\ndXR0dDg9BKLHd99953f7888/3+WxqVOnmvf58MMPm9eeOXPGtC4hIcG8TyAUOLIEAANiCQAGxBIA\nDIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAG3LAMXQwZMiSox+5ny5Yt5rVcxohoZYplUVGR\n6urqdOfOHa1YsULffPONGhoalJiYKElavny5pk2bFs45AcBRAWN54sQJnT17Vl6vV21tbVq4cKEm\nTJigdevWKSsrKxIzAoDjAsZy3LhxGj16tCRp6NCham9vl8/nC/tgABBNAp7giYuLk8fjkSRVVFRo\nypQpiouLU3l5uZYtW6a1a9fq0qVLYR8UAJxk/j7L48ePa//+/Tp8+LDq6+uVmJiojIwMHThwQH/9\n9Zc2b94c7lkBwDGmEzzV1dUqLS3VwYMHNWTIEGVmZnY+Nn36dL377rvhmg8Rdvr0ab/bx4wZ0+Wx\nZ555xrzPsrIy89rXXnvNvBaIpIAfw69evaqioiLt37+/8+z3qlWr1NTUJEmqra1Venp6eKcEAIcF\nPLI8duyY2tratGbNms5tixYt0po1azRo0CB5PB4VFhaGdUgAcFrAWGZnZys7O7vb9oULF4ZlIACI\nRlzuCAAG3N0RAAw4sgQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQA\nA2IJAAbEEgAM3E486Y4dO3T69Gm5XC7l5+dr9OjRTowRUrW1tVq9erXS09MlSU888YQKCgocnip4\njY2NevPNN/Xqq69qyZIlOn/+vDZu3Cifz6fk5GTt3r1b8fHxTo/ZK/9+TXl5eWpoaFBiYqIkafny\n5Zo2bZqzQ/ZSUVGR6urqdOfOHa1YsUJPP/10zL9PUvfX9c033zj+XkU8lidPntS5c+fk9Xr122+/\nKT8/X16vN9JjhMX48eNVUlLi9Bh9dv36dW3dulWZmZmd20pKSpSTk6M5c+aouLhYFRUVysnJcXDK\n3vH3miRp3bp1ysrKcmiqvjlx4oTOnj0rr9ertrY2LVy4UJmZmTH9Pkn+X9eECRMcf68i/jG8pqZG\nM2fOlCQ9/vjjunz5sq5duxbpMXAf8fHxKisrU0pKSue22tpazZgxQ5KUlZWlmpoap8YLir/XFOvG\njRunvXv3SpKGDh2q9vb2mH+fJP+vy+fzOTyVA7G8cOGCHnzwwc6fhw0bptbW1kiPERa//vqr3njj\nDS1evFjff/+90+MEze12KyEhocu29vb2zo9zSUlJMfee+XtNklReXq5ly5Zp7dq1unTpkgOTBS8u\nLk4ej0eSVFFRoSlTpsT8+yT5f11xcXGOv1eO/M7yXh0dHU6PEBKPPvqocnNzNWfOHDU1NWnZsmWq\nqqqKyd8XBdJf3rP58+crMTFRGRkZOnDggPbt26fNmzc7PVavHT9+XBUVFTp8+LBmzZrVuT3W36d7\nX1d9fb3j71XEjyxTUlJ04cKFzp///vtvJScnR3qMkEtNTdXcuXPlcrk0YsQIDR8+XC0tLU6PFTIe\nj0c3btyQJLW0tPSLj7OZmZnKyMiQJE2fPl2NjY0OT9R71dXVKi0tVVlZmYYMGdJv3qd/v65oeK8i\nHstJkyapsrJSktTQ0KCUlBQNHjw40mOE3JdffqlDhw5JklpbW3Xx4kWlpqY6PFXoTJw4sfN9q6qq\n0uTJkx2eqO9WrVqlpqYmSf/3O9l//idDrLh69aqKioq0f//+zrPE/eF98ve6ouG9cnU4cKy+Z88e\n/fjjj3K5XNqyZYuefPLJSI8QcteuXdP69et15coV3b59W7m5uZo6darTYwWlvr5eu3btUnNzs9xu\nt1JTU7Vnzx7l5eXp5s2bSktLU2FhoQYOHOj0qGb+XtOSJUt04MABDRo0SB6PR4WFhUpKSnJ6VDOv\n16sPP/xQjz32WOe2nTt3atOmTTH7Pkn+X9eiRYtUXl7u6HvlSCwBINZwBQ8AGBBLADAglgBgQCwB\nwIBYAoABsQQAA2IJAAbEEgAM/gepgR0uaefKmwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Prediction: 4\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEelJREFUeJzt3W9MlfX/x/HXEWJyhg5BIG1ZfR0u\nKr3hhopOE2Q23FxiN0xCdNmGa5pG6hhTtNn8g85NtI0/aS1Z29moG96wILM2dYDKDRu0hrpyzCkC\nkUocDeH8brQfk8R4czyH64DPx624+Hid99nFnl2H61wHl8/n8wkA8J/GOD0AAIwExBIADIglABgQ\nSwAwIJYAYEAsAcCAWAKAAbEEAINwf//h7t27denSJblcLhUUFGjGjBmBnAsAQopfsTx//ryuXbsm\nj8ejq1evqqCgQB6PJ9CzAUDI8OtleE1NjdLT0yVJU6dO1e3bt9XZ2RnQwQAglPgVy7a2Nk2YMKHv\n65iYGLW2tgZsKAAINQG5wMNncQAY7fyKZXx8vNra2vq+vnXrluLi4gI2FACEGr9iOW/ePFVVVUmS\nGhsbFR8fr6ioqIAOBgChxK+r4TNnztSrr76qt99+Wy6XSzt27Aj0XAAQUlx8+C8ADI47eADAgFgC\ngAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAz8+lO4AJz3yy+/PLLtlVdeeWT777//bt7ne++9Z147f/58\n0zqPx2PeZyjjzBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4fD6fz+kh\ngNHsr7/+Mq+tr683r33rrbce2dba2qq4uLh+29rb2837XL16tXntp59+alrndrvN+wxlnFkCgAGx\nBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAF/sAzww/37981rMzMzzWtPnTplXvu4O2O8\nXm+/rysrK837XLJkiXnt2LFjzWtHA84sAcDArzPLuro6bdy4UYmJiZKkadOmafv27QEdDABCid8v\nw2fNmqXi4uJAzgIAIYuX4QBg4Hcsr1y5onXr1mnlypU6d+5cIGcCgJDj1+dZtrS0qL6+XhkZGWpu\nblZOTo6qq6sVERERjBkBwHF+/c4yISGh7y0GU6ZM0cSJE9XS0qLnn38+oMMBoWoobx1aunSpee2T\nvnWos7NTUVFR/bZ9+eWX5n3y1qHH8+tl+IkTJ3T06FFJ/3wyc3t7uxISEgI6GACEEr/OLNPS0rR5\n82b98MMP6u7u1s6dO3kJDmBU8yuWUVFRKikpCfQsABCyuN0ReIj1vcNbtmwx77O7u9u8dii/9//x\nxx8H3P7zzz/3+/p///ufeZ94PN5nCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAG\nxBIADPz6PEvAaT09Pea1x48fH3D7mjVr9MUXX/Tblpuba9pnb2+v+fE/+eQT89qcnBzz2kmTJpnX\n4slxZgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABtzBgxHpcXflDGT16tUDbu/t\n7dWYMf6dL+zcudO8trCw0K/HQGjhzBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIA\nDIglABhwuyNCSnFxsWndRx99ZN7n4/642UC3O77zzjumff77D539l7CwMPNahC7OLAHAgFgCgAGx\nBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAG3OyLovF6vee2kSZNM6+7cuePvOH0Gut2x\npqbG9G9nz579xI+PkcV0ZtnU1KT09HRVVFRIkm7cuKFVq1YpKytLGzdu1N9//x3UIQHAaYPGsqur\nS7t27VJKSkrftuLiYmVlZemrr77SCy+8oMrKyqAOCQBOGzSWERERKi8vV3x8fN+2uro6LVq0SJKU\nmppqfukCACNV+KALwsMVHt5/mdfrVUREhCQpNjZWra2twZkOAELEoLEcDNeHMJjIyEjz2j///DOI\nkzyqt7d3WB8PI5dfsXS73bp3757Gjh2rlpaWfi/RgX/jajhGA7/eZzl37lxVVVVJkqqrqzV//vyA\nDgUAoWbQM8uGhgbt27dP169fV3h4uKqqqnTgwAHl5+fL4/Fo8uTJWrZs2XDMCgCO4U3pCDpehmM0\neOILPHg6ffvtt+a1hw4dMq8NRASfRElJiWkdsXz6cG84ABgQSwAwIJYAYEAsAcCAWAKAAbEEAANi\nCQAGxBIADIglABgQSwAw4HZH+MV6W6D0zydTWU2ZMsW07v79++Z9trS0mNcCj8OZJQAYEEsAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMOB2R/Rz4cKFAbcnJyf3+15tbW1QHv/77783\nrRvKX4FMTk72dxygD2eWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDAHTzoZ8GC\nBQNu93q9/b43lD8YNhTWP1jm9XqD8vjA43BmCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEE\nAANiCQAGxBIADLjd8Slw5MgR89r/uo3R31scZ8yYYV7rcrn8eoxAuXnzpmldV1eXeZ9ut9vfcRBC\nOLMEAANTLJuampSenq6KigpJUn5+vpYuXapVq1Zp1apV+umnn4I5IwA4btCX4V1dXdq1a5dSUlL6\nbc/Ly1NqamrQBgOAUDLomWVERITKy8sVHx8/HPMAQEhy+Xw+n2Xh4cOHNWHCBGVnZys/P1+tra3q\n7u5WbGystm/frpiYmGDPCgCO8etq+Jtvvqno6GglJSWprKxMR44cUWFhYaBnQ4AM5Wr4Bx98MOD2\n3t5ejRnj3/XAoVwNP3/+vGndUK5GP+5/5AM9pzfeeMO0z6+//tr8+FwNHx38+ulPSUlRUlKSJCkt\nLU1NTU0BHQoAQo1fsdywYYOam5slSXV1dUpMTAzoUAAQagZ9Gd7Q0KB9+/bp+vXrCg8PV1VVlbKz\ns7Vp0yZFRkbK7XZrz549wzErADhm0Fi+9tprOn78+CPbrb/bAYDRgNsdnwLt7e2OPv6WLVvMayMi\nIkzrhnKBZyiqqqpM63799VfzPmfOnOnvOAgh3O4IAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMuN0RfomNjTWvTU5ODvjjnz17NuD7lNT30YODee6554Ly+AhdnFkCgAGxBAAD\nYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAF38MAv48ePN6999tlnA/74FRUVAd+nJM2aNcu0\nLiEhISiPj9DFmSUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADDgdkf45bff\nfjOv/eabb8xrs7OzTet6e3vN+/T5fH59D3gYZ5YAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQ\nSwAwIJYAYEAsAcCA2x0RdO+++25Q1lq5XC6/vgc8zBTLoqIi1dfX68GDB8rNzdX06dO1detW9fT0\nKC4uTvv371dERESwZwUAxwway9raWl2+fFkej0cdHR3KzMxUSkqKsrKylJGRoYMHD6qyslJZWVnD\nMS8AOGLQ31kmJyfr0KFDkqTx48fL6/Wqrq5OixYtkiSlpqaqpqYmuFMCgMMGjWVYWJjcbrckqbKy\nUgsWLJDX6+172R0bG6vW1tbgTgkADjNf4Dl16pQqKyt17NgxLV68uG87nwcY+nbs2BGQtUP5DMmR\nYjQ+JwSHKZZnzpxRSUmJPvvsM40bN05ut1v37t3T2LFj1dLSovj4+GDPiSfw8ccfP/Ha3t5ejRkz\nut5pNtBzWr16tenffv7558EYCSFs0J/+u3fvqqioSKWlpYqOjpYkzZ07V1VVVZKk6upqzZ8/P7hT\nAoDDBj2zPHnypDo6OrRp06a+bXv37tW2bdvk8Xg0efJkLVu2LKhDAoDTBo3lihUrtGLFike28zIE\nwNOEO3ieAnl5eea1Fy5ceOz3lixZ0vffZ8+eNe/zzp075rVAqBpdv7EHgCAhlgBgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg4PLxgZTww3fffWde+/Btkk543I+4z+d75A+W1dbWmvY5\ne/bsJ54LIwtnlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIDbHQHAgDNL\nADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAg3DLoqKiItXX1+vBgwfKzc3V6dOn1djYqOjoaEnS\n2rVrtXDhwmDOCQCOGjSWtbW1unz5sjwejzo6OpSZmak5c+YoLy9PqampwzEjADhu0FgmJydrxowZ\nkqTx48fL6/Wqp6cn6IMBQChx+Xw+n3Wxx+PRxYsXFRYWptbWVnV3dys2Nlbbt29XTExMMOcEAEeZ\nY3nq1CmVlpbq2LFjamhoUHR0tJKSklRWVqabN2+qsLAw2LMCgGNMV8PPnDmjkpISlZeXa9y4cUpJ\nSVFSUpIkKS0tTU1NTUEdEgCcNmgs7969q6KiIpWWlvZd/d6wYYOam5slSXV1dUpMTAzulADgsEEv\n8Jw8eVIdHR3atGlT37bly5dr06ZNioyMlNvt1p49e4I6JAA4bUgXeADgacUdPABgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbhTjzo7t27\ndenSJblcLhUUFGjGjBlOjBFQdXV12rhxoxITEyVJ06ZN0/bt2x2eyn9NTU16//33tWbNGmVnZ+vG\njRvaunWrenp6FBcXp/379ysiIsLpMYfk388pPz9fjY2Nio6OliStXbtWCxcudHbIISoqKlJ9fb0e\nPHig3NxcTZ8+fcQfJ+nR53X69GnHj9Wwx/L8+fO6du2aPB6Prl69qoKCAnk8nuEeIyhmzZql4uJi\np8d4Yl1dXdq1a5dSUlL6thUXFysrK0sZGRk6ePCgKisrlZWV5eCUQzPQc5KkvLw8paamOjTVk6mt\nrdXly5fl8XjU0dGhzMxMpaSkjOjjJA38vObMmeP4sRr2l+E1NTVKT0+XJE2dOlW3b99WZ2fncI+B\n/xAREaHy8nLFx8f3baurq9OiRYskSampqaqpqXFqPL8M9JxGuuTkZB06dEiSNH78eHm93hF/nKSB\nn1dPT4/DUzkQy7a2Nk2YMKHv65iYGLW2tg73GEFx5coVrVu3TitXrtS5c+ecHsdv4eHhGjt2bL9t\nXq+37+VcbGzsiDtmAz0nSaqoqFBOTo4+/PBD/fHHHw5M5r+wsDC53W5JUmVlpRYsWDDij5M08PMK\nCwtz/Fg58jvLh/l8PqdHCIgXX3xR69evV0ZGhpqbm5WTk6Pq6uoR+fuiwYyWY/bmm28qOjpaSUlJ\nKisr05EjR1RYWOj0WEN26tQpVVZW6tixY1q8eHHf9pF+nB5+Xg0NDY4fq2E/s4yPj1dbW1vf17du\n3VJcXNxwjxFwCQkJWrJkiVwul6ZMmaKJEyeqpaXF6bECxu126969e5KklpaWUfFyNiUlRUlJSZKk\ntLQ0NTU1OTzR0J05c0YlJSUqLy/XuHHjRs1x+vfzCoVjNeyxnDdvnqqqqiRJjY2Nio+PV1RU1HCP\nEXAnTpzQ0aNHJUmtra1qb29XQkKCw1MFzty5c/uOW3V1tebPn+/wRE9uw4YNam5ulvTP72T//50M\nI8Xdu3dVVFSk0tLSvqvEo+E4DfS8QuFYuXwOnKsfOHBAFy9elMvl0o4dO/Tyyy8P9wgB19nZqc2b\nN+vOnTvq7u7W+vXr9frrrzs9ll8aGhq0b98+Xb9+XeHh4UpISNCBAweUn5+v+/fva/LkydqzZ4+e\neeYZp0c1G+g5ZWdnq6ysTJGRkXK73dqzZ49iY2OdHtXM4/Ho8OHDeumll/q27d27V9u2bRuxx0ka\n+HktX75cFRUVjh4rR2IJACMNd/AAgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHA4P8ALqDX\nN3rmU3AAAAAASUVORK5CYII=\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Prediction: 1\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEqVJREFUeJzt3W9Ilff/x/HX+eWkpMQ0dQRrZdgm\nq24Miiz6Y0nrFKPVjZqiMgiW/SMX0ZxlDYJMiyALZrnqRlKc4a1u5B9cjIWZUbDA7ljWQqJMm1iR\nbSbne2P8/H7NY77P8Ryvoz0f97y8us777BpPrnMuP+e4vF6vVwCAd/o/pwcAgNGAWAKAAbEEAANi\nCQAGxBIADIglABgQSwAwIJYAYBAR6D88dOiQbt++LZfLpYKCAs2dOzeYcwFAWAkoljdu3NDDhw/l\n8XjU0tKigoICeTyeYM8GAGEjoJfhDQ0NSk9PlyTNnDlTXV1devnyZVAHA4BwElAsOzo6NHny5L6f\nY2Nj1d7eHrShACDcBOUGD5/FAWCsCyiWCQkJ6ujo6Pv56dOnio+PD9pQABBuAorlokWLVFNTI0m6\nc+eOEhISNHHixKAOBgDhJKC74Z9//rk+++wzff3113K5XDpw4ECw5wKAsOLiw38BYGis4AEAA2IJ\nAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBY\nAoABsQQAA2IJAAbEEgAMiCUAGBBLADAI6KtwgVC5ePGiab+9e/eaj/ngwQOf271er1wul/k4gWpp\naTHvm5SUFMJJMBxcWQKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAAOWOyIg\n9+/fD8lxMzMzTfutWrXKfMzBljv6MmPGjKAf88mTJ+Z9We4YvriyBAADYgkABsQSAAyIJQAYEEsA\nMCCWAGBALAHAgFgCgAGxBAADVvAgIOnp6eZ9/VntYrV06VLzvh6PZ9DfdXV19fs5OjradMwtW7aY\nH3/27NnmfRG+uLIEAIOAriwbGxu1c+dOJScnS5JmzZqlwsLCoA4GAOEk4Jfh8+fPV2lpaTBnAYCw\nxctwADAIOJb37t1Tbm6uMjIyVF9fH8yZACDsuLxer9fff9TW1qZbt27J7XartbVVOTk5qq2tVWRk\nZChmBADHBfSeZWJiolavXi1JmjZtmqZMmaK2tjZ99NFHQR0O4cufD6kNxZ8OFRUVmffdunWrz+3R\n0dF6/vz5gG0W/vzpUHFxsXlf6+Nj5AX0MvzSpUs6c+aMJKm9vV3Pnj1TYmJiUAcDgHAS0JXl8uXL\ntXv3bv3666/q6enRjz/+yEtwAGNaQLGcOHGiysrKgj0LAIStgG7wYHR5+325d9m4caPP7VVVVXK7\n3X0/V1dXD3suX6zvRebn54fk8YHB8HeWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCW\nAGBALAHAgOWO7wF/Pk5ssDX/Xq9XLpcroMf35+PUWMaIcMWVJQAYEEsAMCCWAGBALAHAgFgCgAGx\nBAADYgkABsQSAAyIJQAYsIJnlLp27Zp530WLFg378d5ewXPhwgXzv83IyBj24wNO48oSAAyIJQAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYRDg9APp7/vy5ab9gLGH0JTc31/Q7ljDi\nfcOVJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMODbHcOM2+027VddXW0+\n5qpVq8z7ejwen9ujo6P7LcWMjo42HxMYC0xXls3NzUpPT1dFRYUk6fHjx8rOzlZmZqZ27typf/75\nJ6RDAoDThozlq1evdPDgQaWmpvZtKy0tVWZmpi5cuKCPP/5YlZWVIR0SAJw2ZCwjIyNVXl6uhISE\nvm2NjY1asWKFJCktLU0NDQ2hmxAAwsCQH9EWERGhiIj+u3V3dysyMlKSFBcXp/b29tBMBwBhYtif\nZ8n9oeCqqqpyeoRBcVMH77OAYhkVFaXXr19r/Pjxamtr6/cSHcPD3XAgPAX0d5YLFy5UTU2NJKm2\ntlaLFy8O6lAAEG6GvLJsampScXGxHj16pIiICNXU1Ojo0aPKz8+Xx+PR1KlT9dVXX43ErADgmCFj\nOXv2bJ0/f37A9nPnzoVkIAAIR6zgGQH379837ztz5sygP35LS4t536SkpKA/PjAWsDYcAAyIJQAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYDPvzLDG0I0eOBP2Yubm55n1ZwggMH1eW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgOWOI6Cmpibox8zOzg76Mceq\nwb5dMykpacDvrEtT//zzT/PjT58+3byvP/+vfPLJJwO2VVVVye1299uWk5NjPuaaNWvM+0ZHR5v3\nHQu4sgQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA5fX6/U6PcRY588Xhj148MC0\nX0tLS0ge32kXL1407bd3717zMQf7b+r1euVyuczHGQ2G+5xWrVpl3tfj8Zj2GysrfbiyBAADYgkA\nBsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABix3HAFbtmwx71tWVmbabzSdtlAs9wyG\n4SwN9GdZYHV1dUCPEYiRXMJpXXI7mpbbvgtXlgBgYIplc3Oz0tPTVVFRIUnKz8/Xl19+qezsbGVn\nZ+u3334L5YwA4Lghvzf81atXOnjwoFJTU/tt37Vrl9LS0kI2GACEkyGvLCMjI1VeXq6EhISRmAcA\nwpL5Bs+JEyc0efJkZWVlKT8/X+3t7erp6VFcXJwKCwsVGxsb6lkBwDFDvgz3Ze3atYqJiVFKSopO\nnz6tkydPav/+/cGebczgbjh3w0cKd8NDJ6C74ampqUpJSZEkLV++XM3NzUEdCgDCTUCx3LFjh1pb\nWyVJjY2NSk5ODupQABBuhnwZ3tTUpOLiYj169EgRERGqqalRVlaW8vLyNGHCBEVFRamoqGgkZgUA\nxwwZy9mzZ+v8+fMDtn/xxRchGQgAwlFAN3gAt9tt3jcUN238eTWzYcOGQX/39k2KKVOmBDzTYEL1\n7YbPnz/3ub2rq6vfz99//735mNYbjJK0bds2035VVVXmY4YzljsCgAGxBAADYgkABsQSAAyIJQAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADljuOUteuXTPvu3DhwmEfd+HChf1+F6rPaKyvrzft589zepfR\n/FmLgy2jfHv7Tz/9ZD6mP8sd3zdcWQKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKA\nASt4RkBxcbF535qaGtN+WVlZ5mP+8ccf5n19fZOn9O+KmcF+NxR/vlwsWCtz8F/+rPbyR2FhYUiO\nG664sgQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAYur9frdXoI/Jd1adqi\nRYtCPEl/Xq9XLpcroH/b1dVl3newL+HCQBcvXhywLSMjY8D2zMxM8zEvXLhg3nfNmjWm/cbKOeXK\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGLDccZTy5xv7grE0cjjLHf35\ndseHDx+a9svOzjYf88MPP/S5PSkpSffv3++37ZdffjEdc8mSJebH98fBgwfN+1ZXVw/YNpzzJEn1\n9fXmfd+3b+I0fRVuSUmJbt26pTdv3mjz5s2aM2eO9uzZo97eXsXHx+vIkSOKjIwM9awA4JghY3n9\n+nXdvXtXHo9HnZ2dWrdunVJTU5WZmSm3261jx46psrLSr8X6ADDaDPme5bx583T8+HFJ/356SHd3\ntxobG7VixQpJUlpamhoaGkI7JQA4bMhYjhs3TlFRUZKkyspKLVmyRN3d3X0vu+Pi4tTe3h7aKQHA\nYab3LCWprq5OlZWVOnv2rFauXNm3nftDzvDnzfVgnaOxeK6TkpL6/Zyfn+/QJP+qqqoa9jHG4nkK\nB6ZYXr16VWVlZfr55581adIkRUVF6fXr1xo/frza2tqUkJAQ6jnxFu6Gczecu+Eja8iX4S9evFBJ\nSYlOnTqlmJgYSf/+R6qpqZEk1dbWavHixaGdEgAcNuSV5eXLl9XZ2am8vLy+bYcPH9a+ffvk8Xg0\ndepUffXVVyEdEgCcNmQsN27cqI0bNw7Yfu7cuZAMBADhiBU874G335d7l23btvncXlVVJbfb3fez\nr/fLRpvhvr/ntBkzZgzYdv/+/QE3rerq6szHnDJlinnfsfJFZFasDQcAA2IJAAbEEgAMiCUAGBBL\nADAglgBgQCwBwIBYAoABsQQAA2IJAAYsd0RA/PmIuPPnz5v3tX702u+//24+5g8//OBzu6/ljr6W\nEPry7bffmh9/w4YN5n398fayRoQWV5YAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYA\nYEAsAcCA5Y4AYMCVJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBA\nLAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgEGEZaeSkhLdunVLb9680ebN\nm3XlyhXduXNHMTExkqRNmzZp2bJloZwTABw1ZCyvX7+uu3fvyuPxqLOzU+vWrdOCBQu0a9cupaWl\njcSMAOC4IWM5b948zZ07V5IUHR2t7u5u9fb2hnwwAAgnLq/X67Xu7PF4dPPmTY0bN07t7e3q6elR\nXFycCgsLFRsbG8o5AcBR5ljW1dXp1KlTOnv2rJqamhQTE6OUlBSdPn1aT5480f79+0M9KwA4xnQ3\n/OrVqyorK1N5ebkmTZqk1NRUpaSkSJKWL1+u5ubmkA4JAE4bMpYvXrxQSUmJTp061Xf3e8eOHWpt\nbZUkNTY2Kjk5ObRTAoDDhrzBc/nyZXV2diovL69v2/r165WXl6cJEyYoKipKRUVFIR0SAJzm1w0e\nAHhfsYIHAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHA\ngFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsA\nMCCWAGBALAHAIMKJBz106JBu374tl8ulgoICzZ0714kxgqqxsVE7d+5UcnKyJGnWrFkqLCx0eKrA\nNTc3a+vWrfrmm2+UlZWlx48fa8+ePert7VV8fLyOHDmiyMhIp8f0y9vPKT8/X3fu3FFMTIwkadOm\nTVq2bJmzQ/qppKREt27d0ps3b7R582bNmTNn1J8naeDzunLliuPnasRjeePGDT18+FAej0ctLS0q\nKCiQx+MZ6TFCYv78+SotLXV6jGF79eqVDh48qNTU1L5tpaWlyszMlNvt1rFjx1RZWanMzEwHp/SP\nr+ckSbt27VJaWppDUw3P9evXdffuXXk8HnV2dmrdunVKTU0d1edJ8v28FixY4Pi5GvGX4Q0NDUpP\nT5ckzZw5U11dXXr58uVIj4F3iIyMVHl5uRISEvq2NTY2asWKFZKktLQ0NTQ0ODVeQHw9p9Fu3rx5\nOn78uCQpOjpa3d3do/48Sb6fV29vr8NTORDLjo4OTZ48ue/n2NhYtbe3j/QYIXHv3j3l5uYqIyND\n9fX1To8TsIiICI0fP77ftu7u7r6Xc3FxcaPunPl6TpJUUVGhnJwcfffdd/rrr78cmCxw48aNU1RU\nlCSpsrJSS5YsGfXnSfL9vMaNG+f4uXLkPcv/5fV6nR4hKKZPn67t27fL7XartbVVOTk5qq2tHZXv\nFw1lrJyztWvXKiYmRikpKTp9+rROnjyp/fv3Oz2W3+rq6lRZWamzZ89q5cqVfdtH+3n63+fV1NTk\n+Lka8SvLhIQEdXR09P389OlTxcfHj/QYQZeYmKjVq1fL5XJp2rRpmjJlitra2pweK2iioqL0+vVr\nSVJbW9uYeDmbmpqqlJQUSdLy5cvV3Nzs8ET+u3r1qsrKylReXq5JkyaNmfP09vMKh3M14rFctGiR\nampqJEl37txRQkKCJk6cONJjBN2lS5d05swZSVJ7e7uePXumxMREh6cKnoULF/adt9raWi1evNjh\niYZvx44dam1tlfTve7L//5cMo8WLFy9UUlKiU6dO9d0lHgvnydfzCodz5fI6cK1+9OhR3bx5Uy6X\nSwcOHNCnn3460iME3cuXL7V79249f/5cPT092r59u5YuXer0WAFpampScXGxHj16pIiICCUmJuro\n0aPKz8/X33//ralTp6qoqEgffPCB06Oa+XpOWVlZOn36tCZMmKCoqCgVFRUpLi7O6VHNPB6PTpw4\noRkzZvRtO3z4sPbt2zdqz5Pk+3mtX79eFRUVjp4rR2IJAKMNK3gAwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBg8B9OkjtgR8VvdgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Prediction: 6\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "4SJizeJtNaAs", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Profiling\n", + "\n", + "If you want to drill down into the performance characteristics of your code, you can use native Python profilers like [`cProfile`](https://docs.python.org/3/library/profile.html). In the next exercise, you'll do just that." + ] + }, + { + "metadata": { + "id": "_2v0QnG8__PJ", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Exercise!\n", + "\n", + "This exercise does not require coding. If you have not completed the training exercise, replace `train_one_epoch` below with `_train_one_epoch`.\n", + "\n", + "Run the below cell and inspect the printed profiles. What parts of the code appear to be hotspots or\n", + "bottlenecks? How does sorting the profile by total time compare to sorting it\n", + "by cumulative time?\n", + "\n" + ] + }, + { + "metadata": { + "id": "IFypaYbG_9fB", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 714 + }, + "outputId": "d9c3596b-a165-4edd-fc6b-53ccd0d01d19" + }, + "cell_type": "code", + "source": [ + "import cProfile\n", + "import pstats\n", + "\n", + "cProfile.run(\"train_one_epoch(model, training_data, optimizer)\", \"training_profile\")\n", + "\n", + "stats = pstats.Stats(\"training_profile\").strip_dirs().sort_stats(\"tottime\")\n", + "stats.print_stats(10)\n", + "\n", + "stats.sort_stats(\"cumtime\").print_stats(10)" + ], + "execution_count": 17, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Thu Jun 7 12:25:04 2018 training_profile\n", + "\n", + " 92209 function calls (91817 primitive calls) in 3.446 seconds\n", + "\n", + " Ordered by: internal time\n", + " List reduced from 672 to 10 due to restriction <10>\n", + "\n", + " ncalls tottime percall cumtime percall filename:lineno(function)\n", + " 1080 2.552 0.002 2.552 0.002 {built-in method _pywrap_tensorflow_internal.TFE_Py_FastPathExecute}\n", + " 83 0.753 0.009 0.753 0.009 {built-in method _pywrap_tensorflow_internal.TFE_Py_Execute}\n", + " 16 0.006 0.000 1.019 0.064 network.py:736(_run_internal_graph)\n", + " 16 0.005 0.000 2.253 0.141 {built-in method _pywrap_tensorflow_internal.TFE_Py_TapeGradient}\n", + " 2321 0.004 0.000 0.007 0.000 abc.py:178(__instancecheck__)\n", + " 288 0.004 0.000 0.009 0.000 inspect.py:2092(_signature_from_function)\n", + " 878 0.004 0.000 0.005 0.000 ops.py:5936(__enter__)\n", + " 288 0.004 0.000 0.016 0.000 inspect.py:1079(getfullargspec)\n", + " 11006 0.003 0.000 0.005 0.000 {built-in method builtins.isinstance}\n", + " 768 0.003 0.000 0.008 0.000 {built-in method _pywrap_tensorflow_internal.Flatten}\n", + "\n", + "\n", + "Thu Jun 7 12:25:04 2018 training_profile\n", + "\n", + " 92209 function calls (91817 primitive calls) in 3.446 seconds\n", + "\n", + " Ordered by: cumulative time\n", + " List reduced from 672 to 10 due to restriction <10>\n", + "\n", + " ncalls tottime percall cumtime percall filename:lineno(function)\n", + " 1 0.000 0.000 3.446 3.446 {built-in method builtins.exec}\n", + " 1 0.000 0.000 3.446 3.446 :1()\n", + " 1 0.001 0.001 3.446 3.446 :9(train_one_epoch)\n", + " 1080 2.552 0.002 2.552 0.002 {built-in method _pywrap_tensorflow_internal.TFE_Py_FastPathExecute}\n", + " 16 0.000 0.000 2.255 0.141 backprop.py:739(gradient)\n", + " 16 0.000 0.000 2.253 0.141 imperative_grad.py:31(imperative_grad)\n", + " 16 0.005 0.000 2.253 0.141 {built-in method _pywrap_tensorflow_internal.TFE_Py_TapeGradient}\n", + " 400 0.002 0.000 2.246 0.006 backprop.py:145(grad_fn)\n", + " 400 0.002 0.000 2.239 0.006 backprop.py:95(_magic_gradient_function)\n", + " 32 0.001 0.000 1.601 0.050 nn_grad.py:497(_Conv2DGrad)\n", + "\n", + "\n" + ], + "name": "stdout" + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 17 + } + ] + }, + { + "metadata": { + "id": "8ixpnyCNNTI4", + "colab_type": "code", + "colab": {} + }, + "cell_type": "code", + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/workshop/3_inspecting.ipynb b/tensorflow/contrib/eager/python/examples/workshop/3_inspecting.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..64d19ec5c9bfccd07eabb21ce8fbb62b21f23efa --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/workshop/3_inspecting.ipynb @@ -0,0 +1,443 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Debugging \"graph-first\" models with eager execution", + "version": "0.3.2", + "provenance": [], + "include_colab_link": true + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "[View in Colaboratory](https://colab.research.google.com/gist/alextp/9568ab40f6ed6f9a3ba4736f6aef6127/debugging-graph-first-models-with-eager-execution.ipynb)" + ] + }, + { + "metadata": { + "id": "mm-t0GuIu1Dt", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "This colab uses eager execution and the Python debugger to modify the execution of a translation model. This combination lets you quickly explore counterfactuals when researching and designing modifications to a model.\n", + "\n", + "The model, Transformer from [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor), was originally written with graph building in mind. Executing it eagerly can still be helpful!" + ] + }, + { + "metadata": { + "id": "gxb1DvIDg4sv", + "colab_type": "code", + "colab": {} + }, + "cell_type": "code", + "source": [ + "#@title License (double click to show)\n", + "# Copyright 2018 The TensorFlow Authors.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Gx3HA9N1ui64", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 37 + }, + "outputId": "f6986f34-f3e1-44e1-c902-2eb33081acad" + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "import pdb\n", + "tfe = tf.contrib.eager\n", + "\n", + "tf.enable_eager_execution()" + ], + "execution_count": 1, + "outputs": [] + }, + { + "metadata": { + "id": "3LkOm2ct-Lmc", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 37 + }, + "outputId": "2edc74d9-6bc0-4e78-ab4e-83bf96099ef4" + }, + "cell_type": "code", + "source": [ + "!pip install -q -U tensor2tensor\n", + "from tensor2tensor.models import transformer" + ], + "execution_count": 2, + "outputs": [] + }, + { + "metadata": { + "id": "1Z3oMsqV0zB6", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 170 + }, + "outputId": "0a8186ee-c688-457f-c9f6-9a6c1477a93b" + }, + "cell_type": "code", + "source": [ + "#@title Create a tensor2tensor translation model, fetch a checkpoint (double click to show)\n", + "from tensor2tensor import problems\n", + "from tensor2tensor.utils import trainer_lib\n", + "from tensor2tensor.utils import registry\n", + "\n", + "import numpy as np\n", + "import os\n", + "\n", + "# Setup some directories\n", + "data_dir = os.path.expanduser(\"~/t2t/data\")\n", + "tmp_dir = os.path.expanduser(\"~/t2t/tmp\")\n", + "train_dir = os.path.expanduser(\"~/t2t/train\")\n", + "checkpoint_dir = os.path.expanduser(\"~/t2t/checkpoints\")\n", + "tf.gfile.MakeDirs(data_dir)\n", + "tf.gfile.MakeDirs(tmp_dir)\n", + "tf.gfile.MakeDirs(train_dir)\n", + "tf.gfile.MakeDirs(checkpoint_dir)\n", + "gs_data_dir = \"gs://tensor2tensor-data\"\n", + "gs_ckpt_dir = \"gs://tensor2tensor-checkpoints/\"\n", + "\n", + "# Fetch the problem\n", + "ende_problem = problems.problem(\"translate_ende_wmt32k\")\n", + "\n", + "# Copy the vocab file locally so we can encode inputs and decode model outputs\n", + "# All vocabs are stored on GCS\n", + "vocab_name = \"vocab.ende.32768\"\n", + "vocab_file = os.path.join(gs_data_dir, vocab_name)\n", + "!gsutil cp {vocab_file} {data_dir}\n", + "\n", + "# Get the encoders from the problem\n", + "encoders = ende_problem.feature_encoders(data_dir)\n", + "\n", + "# Setup helper functions for encoding and decoding\n", + "def encode(input_str, output_str=None):\n", + " \"\"\"Input str to features dict, ready for inference\"\"\"\n", + " inputs = encoders[\"inputs\"].encode(input_str) + [1] # add EOS id\n", + " batch_inputs = tf.reshape(inputs, [1, -1, 1]) # Make it 3D.\n", + " return {\"inputs\": batch_inputs}\n", + "\n", + "def decode(integers):\n", + " \"\"\"List of ints to str\"\"\"\n", + " integers = list(np.squeeze(integers))\n", + " if 1 in integers:\n", + " integers = integers[:integers.index(1)]\n", + " return encoders[\"inputs\"].decode(np.squeeze(integers))\n", + "\n", + "# Copy the pretrained checkpoint locally\n", + "ckpt_name = \"transformer_ende_test\"\n", + "gs_ckpt = os.path.join(gs_ckpt_dir, ckpt_name)\n", + "!gsutil -q cp -R {gs_ckpt} {checkpoint_dir}\n", + "checkpoint_path = tf.train.latest_checkpoint(\n", + " os.path.join(checkpoint_dir, ckpt_name))\n", + "\n", + "# Create hparams and the model\n", + "model_name = \"transformer\"\n", + "hparams_set = \"transformer_base\"\n", + "\n", + "hparams = trainer_lib.create_hparams(hparams_set, data_dir=data_dir, problem_name=\"translate_ende_wmt32k\")\n", + "\n", + "# NOTE: Only create the model once when restoring from a checkpoint; it's a\n", + "# Layer and so subsequent instantiations will have different variable scopes\n", + "# that will not match the checkpoint.\n", + "translate_model = registry.model(model_name)(hparams, tf.estimator.ModeKeys.EVAL)" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Copying gs://tensor2tensor-data/vocab.ende.32768...\n", + "/ [1 files][316.4 KiB/316.4 KiB] \n", + "Operation completed over 1 objects/316.4 KiB. \n", + "INFO:tensorflow:Setting T2TModel mode to 'eval'\n", + "INFO:tensorflow:Setting hparams.layer_prepostprocess_dropout to 0.0\n", + "INFO:tensorflow:Setting hparams.symbol_dropout to 0.0\n", + "INFO:tensorflow:Setting hparams.attention_dropout to 0.0\n", + "INFO:tensorflow:Setting hparams.dropout to 0.0\n", + "INFO:tensorflow:Setting hparams.relu_dropout to 0.0\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "4IblPXLGjuCl", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "We've created a Transformer model and fetched an existing training checkpoint. It hasn't created variables yet, and we want to load them from the checkpoint before they're used (restore-on-create) so the first run of the model outputs the correct value. The `tfe.restore_variables_on_create` API looks up variables by name on creation and restores their values." + ] + }, + { + "metadata": { + "id": "o3MWxcAqJoqG", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + }, + "outputId": "fbc1b1bf-ffbe-4621-b3cb-5eb855fec3a8" + }, + "cell_type": "code", + "source": [ + "with tfe.restore_variables_on_create(checkpoint_path):\n", + " model_output = translate_model.infer(encode(\"Eager execution\"))\n", + "print(decode(model_output[\"outputs\"]))" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "text": [ + "INFO:tensorflow:Greedy Decoding\n", + "Hinrichtung\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "xk5HV9Hhu9zO", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Using global variable names can get somewhat fragile, so for new code we recommend the object-based `tf.keras.Model.save_weights` or `tf.train.Checkpoint`. However, these require some small code changes to work with existing graph building code.\n", + "\n", + "The Transformer model translates \"Eager execution\" in English to \"Hinrichtung\" in German, which refers to capital punishment rather than getting things done. Transformer first encodes the English, then decodes to German. We'll add a debugging hook at the start of the decode phase (once the encodings have been finalized) and see if we can correct the translation." + ] + }, + { + "metadata": { + "id": "GUGwbYvXZ9-7", + "colab_type": "code", + "colab": {} + }, + "cell_type": "code", + "source": [ + "previous_fast_decode = transformer.fast_decode\n", + "def debug_fn(*args, **kwargs):\n", + " pdb.set_trace()\n", + " return previous_fast_decode(*args, **kwargs) # \"step\" in pdb to step in\n", + "transformer.fast_decode = debug_fn # Add our debugging hook to Transformer" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "f61HlvECxJn0", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Now that we've \"monkey patched\" the model, we'll drop into a debugger just before decoding starts. In most cases it'd be simpler to add the `pdb.set_trace()` call to the code directly, but in this case we're working with prepackaged library code.\n", + "\n", + "First, let's find an encoding which represents the correct sense of \"execution\". Then we'll patch part of that encoding into the encoding of \"Eager execution\" to fix the translation. Feel free to poke around with the debugger (e.g. print a Tensor's value), but your main task is to save the encodings by assigning them to an attribute of the function:\n", + "\n", + "```\n", + "(running the next cell drops you into a pdb shell)\n", + "step\n", + "fast_decode.previous_encoding = encoder_output\n", + "continue\n", + "\n", + "```\n", + "\n", + "You can type `next` (or `n`) a few times before `continue` to watch the decoding ops run." + ] + }, + { + "metadata": { + "id": "dX4CPOGSpZrb", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 179 + }, + "outputId": "6de38c31-836f-40ef-b701-e42908172619" + }, + "cell_type": "code", + "source": [ + "model_output = translate_model.infer(encode(\"Immediate running\"))\n", + "print(decode(model_output[\"outputs\"]))" + ], + "execution_count": 7, + "outputs": [ + { + "output_type": "stream", + "text": [ + "> (4)debug_fn()\n", + "-> return previous_fast_decode(*args, **kwargs) # \"step\" in pdb to step in\n", + "(Pdb) step\n", + "--Call--\n", + "> /usr/local/lib/python2.7/dist-packages/tensor2tensor/models/transformer.py(427)fast_decode()\n", + "-> def fast_decode(encoder_output,\n", + "(Pdb) fast_decode.previous_encoding = encoder_output\n", + "(Pdb) continue\n", + "Sofortige Durchführung\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "-ZEZciV4FpLo", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Now we have an encoding saved which gets the correct sense for \"execution\"." + ] + }, + { + "metadata": { + "id": "QeC_oDVqHD_v", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 179 + }, + "outputId": "253c9af1-003e-46bd-8bf5-db968cf6a8cf" + }, + "cell_type": "code", + "source": [ + "# Assumes you followed the pdb instructions above!\n", + "transformer.fast_decode.previous_encoding" + ], + "execution_count": 8, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 8 + } + ] + }, + { + "metadata": { + "id": "bC9JjeDcHEav", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Let's replace part of the encoding for \"Eager execution\" with the encoding of \"Immediate running\".\n", + "\n", + "Again we'll drop into a pdb shell. This time we'll run some TensorFlow operations to patch the encodings while the model is running.\n", + "\n", + "```\n", + "(running the next cell again drops you into a pdb shell)\n", + "step\n", + "encoder_output = tf.concat([fast_decode.previous_encoding[:, :3], encoder_output[:, 3:]], axis=1)\n", + "continue\n", + "```" + ] + }, + { + "metadata": { + "id": "t2as_Kn1h65G", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 179 + }, + "outputId": "5b4e546e-3bb4-4761-c545-467b631e3ffe" + }, + "cell_type": "code", + "source": [ + "model_output = translate_model.infer(encode(\"Eager execution\"))\n", + "print(decode(model_output[\"outputs\"]))" + ], + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "text": [ + "> (4)debug_fn()\n", + "-> return previous_fast_decode(*args, **kwargs) # \"step\" in pdb to step in\n", + "(Pdb) step\n", + "--Call--\n", + "> /usr/local/lib/python2.7/dist-packages/tensor2tensor/models/transformer.py(427)fast_decode()\n", + "-> def fast_decode(encoder_output,\n", + "(Pdb) encoder_output = tf.concat([fast_decode.previous_encoding[:, :3], encoder_output[:, 3:]], axis=1)\n", + "(Pdb) continue\n", + "sofortige Ausführung\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "rK6tYZ23I2cm", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "We get a different decoding, with the correct sense of \"execution\". Likely we're keeping just the encoding of \"tion\" from \"Eager execution\", so no great breakthrough in translation modeling.\n", + "\n", + "Similarly it's possible to modify attention vectors, or change words during decoding to help debug a beam search." + ] + }, + { + "metadata": { + "id": "Nb-4ipYNRWxA", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "This colab was adapted from the [Tensor2Tensor colab](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb). Credit to Ankur Taly for its concept." + ] + } + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md index 2d2aba6908b168e0bf63f4706b6344cbb4ca82bd..23f33d0230b0b9fa906636a9df4e046c6873d90b 100644 --- a/tensorflow/contrib/eager/python/g3doc/guide.md +++ b/tensorflow/contrib/eager/python/g3doc/guide.md @@ -4,8 +4,8 @@ Eager execution is a feature that makes TensorFlow execute operations immediately: concrete values are returned, instead of creating a computational graph that is executed later. -A user guide is available: https://www.tensorflow.org/programmers_guide/eager -([source file](../../../../docs_src/programmers_guide/eager.md)) +A user guide is available: https://www.tensorflow.org/guide/eager +([source file](../../../../docs_src/guide/eager.md)) We welcome feedback through [GitHub issues](https://github.com/tensorflow/tensorflow/labels/comp:eager). diff --git a/tensorflow/contrib/eager/python/metrics.py b/tensorflow/contrib/eager/python/metrics.py index 3e3100427376ddd480b50d967cf53e7831aaefb2..04b7b1165e19612be2fa878f83effbe814fc5c46 100644 --- a/tensorflow/contrib/eager/python/metrics.py +++ b/tensorflow/contrib/eager/python/metrics.py @@ -22,5 +22,6 @@ from __future__ import print_function from tensorflow.contrib.eager.python.metrics_impl import * from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['Accuracy', 'Mean', 'Metric'] +_allowed_symbols = ['Accuracy', 'Mean', 'Metric', 'CategoricalAccuracy', + 'BinaryAccuracy', 'SparseAccuracy'] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index c947ed9dcc415670a820f8a5cd9eaaf07334cfc3..6efafccd6b93ad58da395e0b2e1e647809af62ad 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -291,8 +291,6 @@ class Metric(checkpointable.CheckpointableBase): class Mean(Metric): """Computes the (weighted) mean of the given values.""" - # TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64? - # Or defaults to type of the input if it is tf.float32, else tf.float64? def __init__(self, name=None, dtype=dtypes.float64, use_global_variables=False): @@ -345,9 +343,14 @@ class Mean(Metric): class Accuracy(Mean): - """Calculates how often `predictions` matches `labels`.""" + """Calculates how often `predictions` matches `labels`. + Attributes: + name: name of the accuracy object + dtype: data type of the tensor + """ def __init__(self, name=None, dtype=dtypes.float64): + """Inits Accuracy class with name and dtype.""" super(Accuracy, self).__init__(name=name, dtype=dtype) def call(self, labels, predictions, weights=None): @@ -372,8 +375,151 @@ class Accuracy(Mean): array_ops.shape(labels), array_ops.shape(predictions), message="Shapes of labels and predictions are unequal") matches = math_ops.equal(labels, predictions) - matches = math_ops.cast(matches, dtypes.float64) + matches = math_ops.cast(matches, self.dtype) super(Accuracy, self).call(matches, weights=weights) if weights is None: return labels, predictions return labels, predictions, weights + + +class CategoricalAccuracy(Mean): + """Calculates how often `predictions` matches `labels`. + + This class is compatible with `tf.keras.losses.categorical_crossentropy`, + `tf.nn.softmax_cross_entropy_with_logits_v2`, + `tf.losses.softmax_cross_entropy`. + + Attributes: + name: name of the accuracy object. + dtype: data type of tensor. + """ + + def __init__(self, name=None, dtype=dtypes.float64): + """Inits CategoricalAccuracy with name and dtype.""" + super(CategoricalAccuracy, self).__init__(name=name, dtype=dtype) + + def call(self, labels, predictions, weights=None): + """Accumulate accuracy statistics. + + `labels` and `predictions` should have the same shape. + As argmax is being done here, labels and predictions type + can be different. + + Args: + labels: One-hot Tensor. + predictions: Tensor with the logits or probabilities for each example. + weights: Optional weighting of each example. Defaults to 1. + + Returns: + The arguments, for easy chaining. + """ + check_ops.assert_equal( + array_ops.shape(labels), array_ops.shape(predictions), + message="Shapes of labels and predictions are unequal") + labels = math_ops.argmax(labels, axis=-1) + predictions = math_ops.argmax(predictions, axis=-1) + matches = math_ops.equal(labels, predictions) + matches = math_ops.cast(matches, self.dtype) + super(CategoricalAccuracy, self).call(matches, weights=weights) + if weights is None: + return labels, predictions + return labels, predictions, weights + + +class BinaryAccuracy(Mean): + """Calculates how often `predictions` matches `labels`. + + This class is compatible with `tf.keras.losses.binary_crossentropy`, + `tf.losses.sigmoid_cross_entropy`, + `tf.nn.sigmoid_cross_entropy_with_logits`. + If there is more than one label, this will become multi-label classification. + + Attributes: + name: name of the accuracy object. + threshold: Used for rounding off the predictions. + If the predictions are, + 1. probabilities then set the threshold to 0.5. + 2. logits then set the threshold to 0. + You can set the threshold appropriately, + to trade off with precision and recall. + dtype: data type of tensor. + """ + + def __init__(self, threshold, name=None, dtype=dtypes.float64): + """Inits BinaryAccuracy with name, threshold and dtype.""" + + super(BinaryAccuracy, self).__init__(name=name, dtype=dtype) + self.threshold = threshold + + def call(self, labels, predictions, weights=None): + """Accumulate accuracy statistics. + + `labels` and `predictions` should have the same shape and type. + + Args: + labels: Binary Tensor(containing 0 or 1). + predictions: Tensor with probabilities or logits. + weights: Optional weighting of each example. Defaults to 1. + + Returns: + The arguments, for easy chaining. + """ + check_ops.assert_equal( + array_ops.shape(labels), array_ops.shape(predictions), + message="Shapes of labels and predictions are unequal") + predictions = ops.convert_to_tensor(predictions) + predictions = predictions > self.threshold + matches = math_ops.equal(labels, predictions) + matches = math_ops.cast(matches, self.dtype) + super(BinaryAccuracy, self).call(matches, weights=weights) + if weights is None: + return labels, predictions + return labels, predictions, weights + + +class SparseAccuracy(Mean): + """Calculates how often `predictions` matches `labels`. + + This class is compatible with + `tf.keras.losses.sparse_categorical_crossentropy`, + `tf.nn.sparse_softmax_cross_entropy_with_logits`, + `tf.losses.sparse_softmax_cross_entropy`. + + Attributes: + name: name of the accuracy object + dtype: data type of tensor. + """ + + def __init__(self, name=None, dtype=dtypes.float64): + """Inits SparseAccuracy with name and dtype.""" + + super(SparseAccuracy, self).__init__(name=name, dtype=dtype) + + def call(self, labels, predictions, weights=None): + """Accumulate accuracy statistics. + + `labels` and `predictions` should have the same shape except the + predictions must have one additional trailing dimension equal to the + number of classes(you want to predict). + + Type of labels and predictions can be different. + + Args: + labels: Tensor of shape (batch_size, ) containing integers + predictions: Tensor with the logits or probabilities for each example. + weights: Optional weighting of each example. Defaults to 1. + + Returns: + The arguments, for easy chaining. + """ + check_ops.assert_equal( + array_ops.shape(labels), array_ops.shape(predictions)[0], + message="First axis of labels and predictions is unequal") + predictions = math_ops.argmax(predictions, axis=-1) + labels = math_ops.cast(labels, dtypes.int64) + matches = math_ops.equal(labels, predictions) + matches = math_ops.cast(matches, self.dtype) + super(SparseAccuracy, self).call(matches, weights=weights) + if weights is None: + return labels, predictions + return labels, predictions, weights diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 02ee05487515b81bfae70d02c1dfdb6d816b77c7..20d938d492bf78fab852c638ba675d7ee6ed9073 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -118,6 +118,39 @@ class MetricsTest(test.TestCase): self.assertEqual(dtypes.float64, m.dtype) self.assertEqual(dtypes.float64, m.result().dtype) + def testCategoricalAccuracy(self): + m = metrics.CategoricalAccuracy() + m([[1, 0, 0, 0], [0, 1, 0, 0]], + [[0.6, 0.1, 0.25, 0.05], [0.4, 0.05, 0.45, 0.0]]) # 1/2 correct + m([[0, 0, 0, 1]], [[0.25, 0.95, 0.25, 0.0]]) # 0/1 correct + m([[1, 0, 0, 0], [0, 1, 0, 0]], + [[0.99, 0.01, 0.0, 0.0], [0.35, 0.35, 0.3, 0.0]]) # 1/2 correct + self.assertEqual(2.0/5, m.result().numpy()) + self.assertEqual(dtypes.float64, m.dtype) + self.assertEqual(dtypes.float64, m.result().dtype) + + def testBinaryAccuracy(self): + m = metrics.BinaryAccuracy(threshold=0) + # as threshold is 0 hence the predictions are logits + m([[0, 0, 0, 0]], + [[-4.2, 4.5, 1.2, -1.1]]) # 2/4 correct + m([[0, 1]], [[-5.3, 11.65]]) # 2/2 correct + m([[0, 1], [1, 1]], + [[-5.3, 11.65], [-10.32, 56.38]]) # 3/4 correct + self.assertEqual(7.0/10, m.result().numpy()) + self.assertEqual(dtypes.float64, m.dtype) + self.assertEqual(dtypes.float64, m.result().dtype) + + def testSparseAccuracy(self): + m = metrics.SparseAccuracy() + m([0, 2], + [[0.6, 0.1, 0.25, 0.05], [0.4, 0.05, 0.45, 0.0]]) # 2/2 correct + m([1], [[0.25, 0.95, 0.25, 0.0]]) # 1/1 correct + m([0, 3], [[0.99, 0.01, 0.0, 0.0], [0.35, 0.35, 0.3, 0.0]]) # 1/2 correct + self.assertEqual(4.0/5, m.result().numpy()) + self.assertEqual(dtypes.float64, m.dtype) + self.assertEqual(dtypes.float64, m.result().dtype) + def testAccuracyDifferentShapes(self): m = metrics.Accuracy() with self.assertRaises(errors.InvalidArgumentError): @@ -173,7 +206,7 @@ class MetricsTest(test.TestCase): sess.run(accumulate, feed_dict={p: 7}) self.assertAllEqual(m.result().eval(), 7) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGraphAndEagerTensor(self): m = metrics.Mean() inputs = ops.convert_to_tensor([1.0, 2.0]) @@ -221,7 +254,7 @@ class MetricsTest(test.TestCase): self.assertAllEqual(m2.result().eval(), 2.0) self.assertAllEqual(m1.result().eval(), 1.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveRestore(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index c92bd15b253b67a3301cd562046a4467e1bf877d..240f213c602395b8589d39c3ecd90b602ffa9848 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -126,7 +126,7 @@ class NetworkTest(test.TestCase): self.assertAllEqual([[17.0], [34.0]], self.evaluate(result)) # TODO(allenl): This test creates garbage in some Python versions - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNetworkSaveRestoreAlreadyBuilt(self): net = MyNetwork(name="abcd") with self.assertRaisesRegexp( @@ -138,7 +138,7 @@ class NetworkTest(test.TestCase): self._save_modify_load_network_built(net, global_step=10) # TODO(allenl): This test creates garbage in some Python versions - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveRestoreDefaultGlobalStep(self): net = MyNetwork(name="abcd") net(constant_op.constant([[2.0]])) @@ -149,7 +149,7 @@ class NetworkTest(test.TestCase): self.assertIn("abcd-4242", save_path) # TODO(allenl): This test creates garbage in some Python versions - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNetworkSaveAndRestoreIntoUnbuilt(self): save_dir = self.get_temp_dir() net1 = MyNetwork() @@ -166,7 +166,7 @@ class NetworkTest(test.TestCase): self.assertAllEqual(self.evaluate(net1.variables[0]), self.evaluate(net2.variables[0])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNetworkMatchesLayerVariableNames(self): zero = constant_op.constant([[0.]]) layer_one = core.Dense(1, use_bias=False) @@ -193,7 +193,7 @@ class NetworkTest(test.TestCase): self.assertEqual("two_layer_net/" + layer_two.variables[0].name, net.second.variables[0].name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLoadIntoUnbuiltSharedLayer(self): class Owner(network.Network): @@ -272,7 +272,7 @@ class NetworkTest(test.TestCase): network.restore_network_checkpoint( load_into, save_path, map_func=_restore_map_func) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRestoreIntoSubNetwork(self): class Parent(network.Network): @@ -327,7 +327,7 @@ class NetworkTest(test.TestCase): # The checkpoint is incompatible. network.restore_network_checkpoint(save_into_parent, checkpoint) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCustomMapCollisionErrors(self): class Parent(network.Network): @@ -372,7 +372,7 @@ class NetworkTest(test.TestCase): network.restore_network_checkpoint( loader, checkpoint, map_func=lambda n: "foo") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDefaultMapCollisionErrors(self): one = constant_op.constant([[1.]]) @@ -571,7 +571,7 @@ class NetworkTest(test.TestCase): expected_start="my_network_1/dense/", actual=outside_net_after.trainable_weights[0].name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVariableScopeStripping(self): with variable_scope.variable_scope("scope1"): with variable_scope.variable_scope("scope2"): @@ -596,7 +596,7 @@ class NetworkTest(test.TestCase): self.assertAllEqual([[42.]], self.evaluate(restore_net.variables[0])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLayerNamesRespected(self): class ParentNetwork(network.Network): @@ -677,7 +677,7 @@ class NetworkTest(test.TestCase): self.assertStartsWith(expected_start="my_network_1/dense/", actual=net2.trainable_weights[0].name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestableAnonymous(self): # The case where no explicit names are specified. We make up unique names, @@ -721,7 +721,7 @@ class NetworkTest(test.TestCase): self.assertEqual("my_network", net2.first.name) self.assertEqual("my_network_1", net2.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestableExplicit(self): # We have explicit network names and everything is globally unique. @@ -750,7 +750,7 @@ class NetworkTest(test.TestCase): self.assertEqual("first_unique_child_name", net.first.name) self.assertEqual("second_unique_child_name", net.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLayerNetworkNameInteractions(self): # Same base name as core.Dense; Networks and non-Network Layers with the @@ -801,7 +801,7 @@ class NetworkTest(test.TestCase): actual=net.trainable_weights[4].name) self.assertEqual("mixed_layer_network", net.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestableExplicitCollisions(self): # We have explicit network names and they are unique within the layer @@ -831,7 +831,7 @@ class NetworkTest(test.TestCase): self.assertEqual("nonunique_name", net.first.name) self.assertEqual("second_unique_child_name", net.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestableExplicitWithAnonymousParent(self): # A parent network is instantiated multiple times with explicitly named @@ -873,7 +873,7 @@ class NetworkTest(test.TestCase): self.assertEqual("first_unique_child_name", net2.first.name) self.assertEqual("second_unique_child_name", net2.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestableExplicitSameLayerCollisions(self): # We have explicit network names and they are _not_ unique within the layer @@ -891,7 +891,7 @@ class NetworkTest(test.TestCase): with self.assertRaisesRegexp(ValueError, "nonunique_name"): ParentNetwork() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAnonymousVariableSharing(self): # Two "owned" Networks @@ -989,7 +989,7 @@ class NetworkTest(test.TestCase): self.assertEqual("my_network", net4.first.name) self.assertEqual("my_network", net4.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRecursiveLayerRenaming(self): core.Dense(1) # Under default Layer naming, would change subsequent names. @@ -1041,7 +1041,7 @@ class NetworkTest(test.TestCase): self.assertEqual("dense", net.second.first.name) self.assertEqual("dense_1", net.second.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCallInDifferentOrderThanConstruct(self): shared_network = MyNetwork() @@ -1091,7 +1091,7 @@ class NetworkTest(test.TestCase): self.assertTrue(net2.first is net1.first) self.assertEqual("my_network", net2.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLayerCallInDifferentOrderThanConstruct(self): # Same idea as testCallInDifferentOrderThanConstruct, but this time with a # non-Network Layer shared between two Networks rather than a @@ -1144,7 +1144,7 @@ class NetworkTest(test.TestCase): self.assertTrue(net2.first is net1.first) self.assertEqual("dense", net2.second.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLayerAlreadyBuilt(self): one = constant_op.constant([[1.]]) core.Dense(1, use_bias=False) # pre-built layers use global naming diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py new file mode 100644 index 0000000000000000000000000000000000000000..76f48eeb1cab9d1f014adeafe4827cb5d3a8c77d --- /dev/null +++ b/tensorflow/contrib/eager/python/remote_test.py @@ -0,0 +1,178 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for remote eager execution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os + +import numpy as np + +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.core.protobuf import tensorflow_server_pb2 +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test +from tensorflow.python.training import server_lib + +JOB_NAME = "remote_device" +ALT_JOB_NAME = "alt_remote_device" + + +def run_sync_and_async(f): + """Execute all test methods in the given class in sync and async modes.""" + + @functools.wraps(f) + def decorator(self, *args, **kwargs): + with context.execution_mode(context.ASYNC): + f(self, *args, **kwargs) + + with context.execution_mode(context.SYNC): + f(self, *args, **kwargs) + + return decorator + + +def get_server_def(job_name, local_server_port, remote_server_addresses, + task_index): + """Returns a server def with a single job + multiple tasks.""" + cluster_def = cluster_pb2.ClusterDef() + job_def = cluster_def.job.add() + job_def.name = job_name + job_def.tasks[0] = "localhost:%d" % local_server_port + + for i, remote_server_address in enumerate(remote_server_addresses, start=1): + job_def.tasks[i] = remote_server_address + + server_def = tensorflow_server_pb2.ServerDef( + cluster=cluster_def, + job_name=job_name, + task_index=task_index, + protocol="grpc") + + return server_def + + +class RemoteExecutionTest(test.TestCase): + + def __init__(self, methodName="runTest"): # pylint: disable=invalid-name + super(RemoteExecutionTest, self).__init__(methodName) + self._cached_server1 = server_lib.Server.create_local_server() + self._cached_server2 = server_lib.Server.create_local_server() + + os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1" + + self._cached_server1_target = self._cached_server1.target[len("grpc://"):] + self._cached_server2_target = self._cached_server2.target[len("grpc://"):] + + # Start the local server. + context.set_server_def( + server_def=get_server_def( + JOB_NAME, + local_server_port=0, + remote_server_addresses=[ + self._cached_server1_target, self._cached_server2_target + ], + task_index=0)) + + @run_sync_and_async + def testDefunMatmul(self): + """Basic remote eager execution with defun.""" + + mm_defun = function.defun(math_ops.matmul) + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME): + x1 = array_ops.ones([2, 2]) + with ops.device("job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME): + x2 = array_ops.ones([2, 2]) + y = mm_defun(x1, x2) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + @run_sync_and_async + def testSimpleMatmul(self): + """Basic remote eager execution.""" + + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME): + x1 = array_ops.ones([2, 2]) + with ops.device("job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME): + x2 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x2) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + @run_sync_and_async + def testSimpleWeightRead(self): + """Basic remote eager weight read.""" + + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME): + w = resource_variable_ops.ResourceVariable([[2.0]]) + loss = w * w + np.testing.assert_array_equal([[4.0]], loss.numpy()) + + @run_sync_and_async + def testTapeWeightRead(self): + """Remote eager weight read in a tape.""" + + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME): + w = resource_variable_ops.ResourceVariable([[3.0]]) + with backprop.GradientTape() as tape: + loss = w * w + + grad = tape.gradient(loss, w) + np.testing.assert_array_equal([[9.0]], loss.numpy()) + np.testing.assert_array_equal([[6.0]], grad.numpy()) + + @run_sync_and_async + def testServerDefChanged(self): + """Update server def, and run ops on new cluster.""" + context.set_server_def( + server_def=get_server_def( + ALT_JOB_NAME, + local_server_port=0, + remote_server_addresses=[ + self._cached_server1_target, self._cached_server2_target + ], + task_index=0)) + + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % ALT_JOB_NAME): + x1 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x1) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + # Set the server def back to JOB_NAME + context.set_server_def( + server_def=get_server_def( + JOB_NAME, + local_server_port=0, + remote_server_addresses=[ + self._cached_server1_target, self._cached_server2_target + ], + task_index=0)) + + with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME): + x1 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x1) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + +if __name__ == "__main__": + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py index fdaca90fd13576e6ca8a3408aaf528dbc2384b0c..f9c716360c5755ee1902b576545d776725f9966f 100644 --- a/tensorflow/contrib/eager/python/saver.py +++ b/tensorflow/contrib/eager/python/saver.py @@ -125,8 +125,8 @@ class Saver(object): Args: var_list: The list of variables that will be saved and restored. Either a - list of `tfe.Variable` objects, or a dictionary mapping names to - `tfe.Variable` objects. + list of `tf.Variable` objects, or a dictionary mapping names to + `tf.Variable` objects. Raises: RuntimeError: if invoked when eager execution has not been enabled. @@ -161,7 +161,7 @@ class Saver(object): Args: file_prefix: Path prefix where parameters were previously saved. Typically obtained from a previous `save()` call, or from - @{tf.train.latest_checkpoint}. + `tf.train.latest_checkpoint`. """ with ops.device("/device:CPU:0"): self._saver.restore(None, file_prefix) diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index fee9db46fa4f79d7dd613436726e8ddad51faf1c..de11d00a1a0a34372467eedb02d790c920e7f449 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -34,6 +34,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@run @@enable_eager_execution +@@enable_remote_eager_execution @@custom_gradient @@ -68,6 +69,9 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@async_clear_error @@run_test_in_graph_and_eager_modes +@@run_all_tests_in_graph_and_eager_modes + +@@TensorSpec @@DEVICE_PLACEMENT_EXPLICIT @@DEVICE_PLACEMENT_WARN @@ -112,7 +116,9 @@ from tensorflow.python.eager.execution_callbacks import inf_callback from tensorflow.python.eager.execution_callbacks import inf_nan_callback from tensorflow.python.eager.execution_callbacks import nan_callback from tensorflow.python.eager.execution_callbacks import seterr +from tensorflow.python.framework.tensor_spec import TensorSpec from tensorflow.python.framework.ops import enable_eager_execution +from tensorflow.python.framework.ops import enable_eager_execution_internal as enable_remote_eager_execution from tensorflow.python.framework.ops import eager_run as run from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes from tensorflow.python.framework.test_util import run_all_in_graph_and_eager_modes as run_all_tests_in_graph_and_eager_modes @@ -121,7 +127,7 @@ from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Vari from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops import script_ops from tensorflow.python.ops import template -from tensorflow.python.training.checkpointable.base import Checkpointable +from tensorflow.python.training.checkpointable.tracking import Checkpointable from tensorflow.python.training.checkpointable.util import CheckpointableSaver from tensorflow.python.training.checkpointable.util import Checkpoint from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index db50b33af2e4f1cc6575d4b0d416d6d2669b5c35..4454abfb9667f824b9de0100bb81bae24ad5f7a6 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -27,7 +27,6 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import numerics -from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.summary import summary from tensorflow.python.summary.writer import writer @@ -45,12 +44,6 @@ class TFETest(test_util.TensorFlowTestCase): r'indices = 7 is not in \[0, 3\)'): array_ops.gather([0, 1, 2], 7) - def testVariableError(self): - with self.assertRaisesRegexp( - RuntimeError, - r'Variable not supported when eager execution is enabled'): - variables.Variable(initial_value=1.0) - def testGradients(self): def square(x): diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 1937ffb583bc727df76470d072b35fb3c9acaa88..82272bf1207c9b85243bb1c2d92a2c6704a2761e 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -18,6 +18,7 @@ py_library( ":boosted_trees", ":dnn", ":dnn_linear_combined", + ":early_stopping", ":export", ":extenders", ":head", @@ -27,7 +28,8 @@ py_library( ":multi_head", ":replicate_model_fn", ":rnn", - "//tensorflow/python:util", + ":saved_model_estimator", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -53,22 +55,10 @@ py_test( deps = [ ":baseline", ":head", - "//tensorflow/python:check_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:session", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variables", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:export_export", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -95,11 +85,8 @@ py_test( ], deps = [ ":boosted_trees", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/feature_column", "//third_party/py/numpy", ], ) @@ -109,7 +96,7 @@ py_library( srcs = ["python/estimator/dnn.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:nn", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator", "//tensorflow/python/estimator:dnn", ], @@ -117,7 +104,7 @@ py_library( py_test( name = "dnn_test", - size = "small", + size = "medium", srcs = ["python/estimator/dnn_test.py"], srcs_version = "PY2AND3", tags = [ @@ -128,16 +115,11 @@ py_test( deps = [ ":dnn", ":head", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:dnn_testing_utils", "//tensorflow/python/estimator:export_export", "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -148,7 +130,7 @@ py_library( srcs = ["python/estimator/dnn_linear_combined.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:nn", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator", "//tensorflow/python/estimator:dnn_linear_combined", ], @@ -167,18 +149,12 @@ py_test( deps = [ ":dnn_linear_combined", ":head", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:nn", - "//tensorflow/python:platform", - "//tensorflow/python:summary", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:dnn_testing_utils", "//tensorflow/python/estimator:export_export", "//tensorflow/python/estimator:linear_testing_utils", "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -191,10 +167,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:clip_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:util", @@ -210,18 +183,11 @@ py_test( tags = ["notsan"], # b/62863147 deps = [ ":extenders", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/predictor", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:framework_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", - "//tensorflow/python:variables", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/estimator:linear", - "//tensorflow/python/feature_column", "//third_party/py/numpy", ], ) @@ -245,21 +211,11 @@ py_test( tags = ["notsan"], # b/62863147 deps = [ ":export", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:metrics", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:session", - "//tensorflow/python:state_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variables", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator", "//tensorflow/python/estimator:export_export", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/saved_model:loader", - "//tensorflow/python/saved_model:tag_constants", ], ) @@ -270,25 +226,12 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:nn", - "//tensorflow/python:sparse_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:summary", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model:signature_constants", ], ) @@ -299,25 +242,10 @@ py_test( srcs_version = "PY2AND3", deps = [ ":head", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:string_ops", - "//tensorflow/python:training", - "//tensorflow/python:variables", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model:signature_constants", "//third_party/py/numpy", "@six_archive//:six", ], @@ -330,8 +258,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:estimator_py", ], ) @@ -344,10 +271,7 @@ py_test( tags = ["notsan"], deps = [ ":hooks", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:estimator_py", "//third_party/py/numpy", "@six_archive//:six", @@ -376,16 +300,11 @@ py_test( deps = [ ":head", ":linear", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:export_export", "//tensorflow/python/estimator:linear_testing_utils", "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -398,8 +317,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:framework_ops", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:dnn", "//tensorflow/python/estimator:linear", ], @@ -412,9 +330,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":logit_fns", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:session", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:model_fn", ], ) @@ -426,18 +342,11 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:summary", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/saved_model:signature_constants", "@six_archive//:six", ], ) @@ -450,15 +359,10 @@ py_test( deps = [ ":head", ":multi_head", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:framework_ops", - "//tensorflow/python:string_ops", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/saved_model:signature_constants", "//third_party/py/numpy", "@six_archive//:six", ], @@ -471,24 +375,10 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:device", - "//tensorflow/python:device_lib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:sparse_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:state_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:util", - "//tensorflow/python/ops/losses", "@six_archive//:six", ], ) @@ -499,6 +389,7 @@ cuda_py_test( srcs = ["python/estimator/replicate_model_fn_test.py"], additional_deps = [ "@absl_py//absl/testing:parameterized", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator", "//tensorflow/python/estimator:dnn", "//tensorflow/python/estimator:export_export", @@ -507,21 +398,6 @@ cuda_py_test( "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:optimizers", "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model:signature_constants", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", ":replicate_model_fn", ], tags = [ @@ -537,22 +413,11 @@ py_library( srcs_version = "PY2AND3", deps = [ ":extenders", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/contrib/feature_column:feature_column_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:rnn", - "//tensorflow/python:rnn_cell", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", "//tensorflow/python/estimator", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:optimizers", - "//tensorflow/python/feature_column", "@six_archive//:six", ], ) @@ -571,22 +436,76 @@ py_test( deps = [ ":head", ":rnn", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/contrib/data", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:check_ops", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:parsing_utils", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_library( + name = "early_stopping", + srcs = ["python/estimator/early_stopping.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py_no_contrib", + "//tensorflow/python/estimator", + ], +) + +py_test( + name = "early_stopping_test", + srcs = ["python/estimator/early_stopping_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":early_stopping", + "//tensorflow:tensorflow_py_no_contrib", + "//tensorflow/python/estimator", + "@absl_py//absl/testing:parameterized", + ], +) + +py_library( + name = "saved_model_estimator", + srcs = ["python/estimator/saved_model_estimator.py"], + deps = [ + ":export", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:export", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/saved_model", + ], +) + +py_test( + name = "saved_model_estimator_test", + size = "medium", + srcs = ["python/estimator/saved_model_estimator_test.py"], + srcs_version = "PY2AND3", + tags = [ + "notsan", + ], + deps = [ + ":export", + ":saved_model_estimator", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", - "//tensorflow/python:lib", - "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:platform", "//tensorflow/python:state_ops", - "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:variables", - "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/estimator:parsing_utils", - "//tensorflow/python/feature_column", - "//third_party/py/numpy", - "@six_archive//:six", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", ], ) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 788ac5ca7046d6dd30a3d5520b243944532622fa..6ad3a4a604901aec8390774f12a039a220e072a2 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -23,6 +23,7 @@ from tensorflow.contrib.estimator.python.estimator.baseline import * from tensorflow.contrib.estimator.python.estimator.boosted_trees import * from tensorflow.contrib.estimator.python.estimator.dnn import * from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * +from tensorflow.contrib.estimator.python.estimator.early_stopping import * from tensorflow.contrib.estimator.python.estimator.export import * from tensorflow.contrib.estimator.python.estimator.extenders import * from tensorflow.contrib.estimator.python.estimator.head import * @@ -32,6 +33,8 @@ from tensorflow.contrib.estimator.python.estimator.logit_fns import * from tensorflow.contrib.estimator.python.estimator.multi_head import * from tensorflow.contrib.estimator.python.estimator.replicate_model_fn import * from tensorflow.contrib.estimator.python.estimator.rnn import * +from tensorflow.contrib.estimator.python.estimator.saved_model_estimator import * +from tensorflow.python.estimator.export.export import * from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import @@ -42,6 +45,7 @@ _allowed_symbols = [ 'clip_gradients_by_norm', 'forward_features', 'InMemoryEvaluatorHook', + 'StopAtCheckpointStepHook', 'logistic_regression_head', 'multi_class_head', 'multi_head', @@ -63,6 +67,15 @@ _allowed_symbols = [ 'RNNEstimator', 'export_saved_model_for_mode', 'export_all_saved_models', + 'make_early_stopping_hook', + 'read_eval_metrics', + 'stop_if_lower_hook', + 'stop_if_higher_hook', + 'stop_if_no_increase_hook', + 'stop_if_no_decrease_hook', + 'build_raw_supervised_input_receiver_fn', + 'build_supervised_input_receiver_fn_from_input_fn', + 'SavedModelEstimator' ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py index d0e3e670f7332811c1bfdaea65b0308ce59ade59..505c94e97192afdd4e2ce9af2abb9825320751f2 100644 --- a/tensorflow/contrib/estimator/python/estimator/baseline_test.py +++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py @@ -113,6 +113,8 @@ class BaselineEstimatorEvaluationTest(test.TestCase): self.assertDictEqual({ metric_keys.MetricKeys.LOSS: 18., metric_keys.MetricKeys.LOSS_MEAN: 9., + metric_keys.MetricKeys.PREDICTION_MEAN: 13., + metric_keys.MetricKeys.LABEL_MEAN: 10., ops.GraphKeys.GLOBAL_STEP: 100 }, eval_metrics) @@ -141,6 +143,8 @@ class BaselineEstimatorEvaluationTest(test.TestCase): self.assertDictEqual({ metric_keys.MetricKeys.LOSS: 27., metric_keys.MetricKeys.LOSS_MEAN: 9., + metric_keys.MetricKeys.PREDICTION_MEAN: 13., + metric_keys.MetricKeys.LABEL_MEAN: 10., ops.GraphKeys.GLOBAL_STEP: 100 }, eval_metrics) @@ -166,7 +170,9 @@ class BaselineEstimatorEvaluationTest(test.TestCase): self.assertItemsEqual( (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN, - ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys()) + metric_keys.MetricKeys.PREDICTION_MEAN, + metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP), + eval_metrics.keys()) # Logit is bias which is [46, 58] self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS]) diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py index bd641014e9eec6623d66574bccd08ff03ebc28ac..7ed77bcce6f00ed13e9952951800f1017d582f19 100644 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py @@ -49,7 +49,9 @@ class _BoostedTreesEstimator(estimator.Estimator): l2_regularization=0., tree_complexity=0., min_node_weight=0., - config=None): + config=None, + center_bias=False, + pruning_mode='none'): """Initializes a `BoostedTreesEstimator` instance. Args: @@ -82,17 +84,35 @@ class _BoostedTreesEstimator(estimator.Estimator): considered. The value will be compared with sum(leaf_hessian)/ (batch_size * n_batches_per_layer). config: `RunConfig` object to configure the runtime settings. + center_bias: Whether bias centering needs to occur. Bias centering refers + to the first node in the very first tree returning the prediction that + is aligned with the original labels distribution. For example, for + regression problems, the first node will return the mean of the labels. + For binary classification problems, it will return a logit for a prior + probability of label 1. + pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre- + pruning (do not split a node if not enough gain is observed) and post + pruning (build the tree up to a max depth and then prune branches with + negative gain). For pre and post pruning, you MUST provide + tree_complexity >0. + """ # pylint:disable=protected-access # HParams for the model. tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight) + tree_complexity, min_node_weight, center_bias, pruning_mode) def _model_fn(features, labels, mode, config): return canned_boosted_trees._bt_model_fn( - features, labels, mode, head, feature_columns, tree_hparams, - n_batches_per_layer, config) + features, + labels, + mode, + head, + feature_columns, + tree_hparams, + n_batches_per_layer, + config=config) super(_BoostedTreesEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) @@ -114,7 +134,9 @@ def boosted_trees_classifier_train_in_memory( tree_complexity=0., min_node_weight=0., config=None, - train_hooks=None): + train_hooks=None, + center_bias=False, + pruning_mode='none'): """Trains a boosted tree classifier with in memory dataset. Example: @@ -186,7 +208,18 @@ def boosted_trees_classifier_train_in_memory( considered. The value will be compared with sum(leaf_hessian)/ (batch_size * n_batches_per_layer). config: `RunConfig` object to configure the runtime settings. - train_hooks: a list of Hook instances to be passed to estimator.train(). + train_hooks: a list of Hook instances to be passed to estimator.train() + center_bias: Whether bias centering needs to occur. Bias centering refers + to the first node in the very first tree returning the prediction that + is aligned with the original labels distribution. For example, for + regression problems, the first node will return the mean of the labels. + For binary classification problems, it will return a logit for a prior + probability of label 1. + pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre- + pruning (do not split a node if not enough gain is observed) and post + pruning (build the tree up to a max depth and then prune branches with + negative gain). For pre and post pruning, you MUST provide + tree_complexity >0. Returns: a `BoostedTreesClassifier` instance created with the given arguments and @@ -207,7 +240,7 @@ def boosted_trees_classifier_train_in_memory( # HParams for the model. tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight) + tree_complexity, min_node_weight, center_bias, pruning_mode) def _model_fn(features, labels, mode, config): return canned_boosted_trees._bt_model_fn( @@ -247,7 +280,9 @@ def boosted_trees_regressor_train_in_memory( tree_complexity=0., min_node_weight=0., config=None, - train_hooks=None): + train_hooks=None, + center_bias=False, + pruning_mode='none'): """Trains a boosted tree regressor with in memory dataset. Example: @@ -313,6 +348,17 @@ def boosted_trees_regressor_train_in_memory( (batch_size * n_batches_per_layer). config: `RunConfig` object to configure the runtime settings. train_hooks: a list of Hook instances to be passed to estimator.train(). + center_bias: Whether bias centering needs to occur. Bias centering refers + to the first node in the very first tree returning the prediction that + is aligned with the original labels distribution. For example, for + regression problems, the first node will return the mean of the labels. + For binary classification problems, it will return a logit for a prior + probability of label 1. + pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre- + pruning (do not split a node if not enough gain is observed) and post + pruning (build the tree up to a max depth and then prune branches with + negative gain). For pre and post pruning, you MUST provide + tree_complexity >0. Returns: a `BoostedTreesClassifier` instance created with the given arguments and @@ -332,7 +378,7 @@ def boosted_trees_regressor_train_in_memory( # HParams for the model. tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight) + tree_complexity, min_node_weight, center_bias, pruning_mode) def _model_fn(features, labels, mode, config): return canned_boosted_trees._bt_model_fn( diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py index 76cbefe5e94502188388df6fc2816d130ac896d5..b1581f37509b5dc2bec98942e88c024905f25d93 100644 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py @@ -115,6 +115,70 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): eval_res = est.evaluate(input_fn=input_fn, steps=1) self.assertAllClose(eval_res['average_loss'], 1.008551) + def testTrainAndEvaluateEstimatorWithCenterBias(self): + input_fn = _make_train_input_fn(is_classification=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=2, + head=self._head, + max_depth=5, + center_bias=True) + + # It will stop after 11 steps because of the max depth and num trees. + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(input_fn, steps=num_steps) + # 10 steps for training and 2 step for bias centering. + self._assert_checkpoint( + est.model_dir, global_step=12, finalized_trees=2, attempted_layers=10) + eval_res = est.evaluate(input_fn=input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 0.614642) + + def testTrainAndEvaluateEstimatorWithPrePruning(self): + input_fn = _make_train_input_fn(is_classification=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=2, + head=self._head, + max_depth=5, + tree_complexity=0.001, + pruning_mode='pre') + + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(input_fn, steps=num_steps) + # We stop actually after 2*depth*n_trees steps (via a hook) because we still + # could not grow 2 trees of depth 5 (due to pre-pruning). + self._assert_checkpoint( + est.model_dir, global_step=21, finalized_trees=0, attempted_layers=21) + eval_res = est.evaluate(input_fn=input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 3.83943) + + def testTrainAndEvaluateEstimatorWithPostPruning(self): + input_fn = _make_train_input_fn(is_classification=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=2, + head=self._head, + max_depth=5, + tree_complexity=0.001, + pruning_mode='post') + + # It will stop after 10 steps because of the max depth and num trees. + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(input_fn, steps=num_steps) + self._assert_checkpoint( + est.model_dir, global_step=10, finalized_trees=2, attempted_layers=10) + eval_res = est.evaluate(input_fn=input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 2.37652) + def testInferEstimator(self): train_input_fn = _make_train_input_fn(is_classification=False) predict_input_fn = numpy_io.numpy_input_fn( @@ -139,6 +203,33 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]], [pred['predictions'] for pred in predictions]) + def testInferEstimatorWithCenterBias(self): + train_input_fn = _make_train_input_fn(is_classification=False) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=1, + max_depth=5, + center_bias=True, + head=self._head) + + # It will stop after 6 steps because of the max depth and num trees (5 for + # training and 2 for bias centering). + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(train_input_fn, steps=num_steps) + self._assert_checkpoint( + est.model_dir, global_step=7, finalized_trees=1, attempted_layers=5) + # Validate predictions. + predictions = list(est.predict(input_fn=predict_input_fn)) + + self.assertAllClose( + [[1.634501], [1.325703], [1.187431], [2.019683], [2.832683]], + [pred['predictions'] for pred in predictions]) + def testBinaryClassifierTrainInMemoryAndEvalAndInfer(self): train_input_fn = _make_train_input_fn(is_classification=True) predict_input_fn = numpy_io.numpy_input_fn( @@ -159,14 +250,65 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): self.assertAllClose([[0], [1], [1], [0], [0]], [pred['class_ids'] for pred in predictions]) + def testBinaryClassifierTrainInMemoryAndEvalAndInferWithCenterBias(self): + train_input_fn = _make_train_input_fn(is_classification=True) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_classifier_train_in_memory( + train_input_fn=train_input_fn, + feature_columns=self._feature_columns, + n_trees=1, + max_depth=5, + center_bias=True) + # It will stop after 5 steps + 3 for bias, because of the max depth and num + # trees. + self._assert_checkpoint( + est.model_dir, global_step=8, finalized_trees=1, attempted_layers=5) + + # Check evaluate and predict. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['accuracy'], 1.0) + # Validate predictions. + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0], [1], [1], [0], [0]], + [pred['class_ids'] for pred in predictions]) + + def testBinaryClassifierTrainInMemoryAndEvalAndInferWithPrePruning(self): + train_input_fn = _make_train_input_fn(is_classification=True) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_classifier_train_in_memory( + train_input_fn=train_input_fn, + feature_columns=self._feature_columns, + n_trees=1, + max_depth=5, + pruning_mode='pre', + tree_complexity=0.01) + # We stop actually after 2*depth*n_trees steps (via a hook) because we still + # could not grow 1 trees of depth 5 (due to pre-pruning). + self._assert_checkpoint( + est.model_dir, global_step=11, finalized_trees=0, attempted_layers=11) + + # Check evaluate and predict. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['accuracy'], 1.0) + # Validate predictions. + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0], [1], [1], [0], [0]], + [pred['class_ids'] for pred in predictions]) + def testBinaryClassifierTrainInMemoryWithDataset(self): train_input_fn = _make_train_input_fn_dataset(is_classification=True) predict_input_fn = numpy_io.numpy_input_fn( x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) est = boosted_trees.boosted_trees_classifier_train_in_memory( - train_input_fn=train_input_fn, feature_columns=self._feature_columns, - n_trees=1, max_depth=5) + train_input_fn=train_input_fn, + feature_columns=self._feature_columns, + n_trees=1, + max_depth=5) # It will stop after 5 steps because of the max depth and num trees. self._assert_checkpoint( est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py index 7ff25b95c079c7e06d29e874bcaa0d2c13e7167e..9efa8f474d865a36788cba40a15404bf0b30a17e 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn.py @@ -53,6 +53,25 @@ class DNNEstimator(estimator.Estimator): l1_regularization_strength=0.001 )) + # Or estimator using an optimizer with a learning rate decay. + estimator = DNNEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb], + hidden_units=[1024, 512, 256], + optimizer=lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + + # Or estimator with warm-starting from a previous checkpoint. + estimator = DNNEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb], + hidden_units=[1024, 512, 256], + warm_start_from="/path/to/checkpoint/dir") + # Input builders def input_fn_train: # returns x, y pass @@ -92,7 +111,9 @@ class DNNEstimator(estimator.Estimator): activation_fn=nn.relu, dropout=None, input_layer_partitioner=None, - config=None): + config=None, + warm_start_from=None, + batch_norm=False): """Initializes a `DNNEstimator` instance. Args: @@ -107,8 +128,9 @@ class DNNEstimator(estimator.Estimator): model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to Adagrad optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to Adagrad optimizer. activation_fn: Activation function applied to each layer. If `None`, will use `tf.nn.relu`. dropout: When not `None`, the probability we will drop out a given @@ -116,6 +138,12 @@ class DNNEstimator(estimator.Estimator): input_layer_partitioner: Optional. Partitioner for input layer. Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. config: `RunConfig` object to configure the runtime settings. + warm_start_from: A string filepath to a checkpoint to warm-start from, or + a `WarmStartSettings` object to fully configure warm-starting. If the + string filepath is provided instead of a `WarmStartSettings`, then all + weights are warm-started, and it is assumed that vocabularies and Tensor + names are unchanged. + batch_norm: Whether to use batch normalization after each hidden layer. """ def _model_fn(features, labels, mode, config): return dnn_lib._dnn_model_fn( # pylint: disable=protected-access @@ -129,6 +157,8 @@ class DNNEstimator(estimator.Estimator): activation_fn=activation_fn, dropout=dropout, input_layer_partitioner=input_layer_partitioner, - config=config) + config=config, + batch_norm=batch_norm) super(DNNEstimator, self).__init__( - model_fn=_model_fn, model_dir=model_dir, config=config) + model_fn=_model_fn, model_dir=model_dir, config=config, + warm_start_from=warm_start_from) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py index ccaf1128bf23af734f7a5722a4dd8c1f0304fab7..724bc2c82f8289bbaa19a1dbbc1dc81b6e158e02 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py @@ -53,12 +53,19 @@ class DNNLinearCombinedEstimator(estimator.Estimator): dnn_hidden_units=[1000, 500, 100], dnn_optimizer=tf.train.ProximalAdagradOptimizer(...)) - # To apply L1 and L2 regularization, you can set optimizers as follows: + # To apply L1 and L2 regularization, you can set dnn_optimizer to: tf.train.ProximalAdagradOptimizer( learning_rate=0.1, l1_regularization_strength=0.001, l2_regularization_strength=0.001) - # It is same for FtrlOptimizer. + # To apply learning rate decay, you can set dnn_optimizer to a callable: + lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96) + # It is the same for linear_optimizer. # Input builders def input_fn_train: # returns x, y @@ -103,7 +110,8 @@ class DNNLinearCombinedEstimator(estimator.Estimator): dnn_activation_fn=nn.relu, dnn_dropout=None, input_layer_partitioner=None, - config=None): + config=None, + linear_sparse_combiner='sum'): """Initializes a DNNLinearCombinedEstimator instance. Args: @@ -116,12 +124,16 @@ class DNNLinearCombinedEstimator(estimator.Estimator): used by linear part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the linear part of the model. Defaults to FTRL optimizer. + the linear part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL + optimizer. dnn_feature_columns: An iterable containing all the feature columns used by deep part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the deep part of the model. Defaults to Adagrad optimizer. + the deep part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad + optimizer. dnn_hidden_units: List of hidden units per layer. All layers are fully connected. dnn_activation_fn: Activation function applied to each layer. If None, @@ -131,6 +143,11 @@ class DNNLinearCombinedEstimator(estimator.Estimator): input_layer_partitioner: Partitioner for input layer. Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. config: RunConfig object to configure the runtime settings. + linear_sparse_combiner: A string specifying how to reduce the linear model + if a categorical column is multivalent. One of "mean", "sqrtn", and + "sum" -- these are effectively different ways to do example-level + normalization, which can be useful for bag-of-words features. For more + details, see `tf.feature_column.linear_model`. Raises: ValueError: If both linear_feature_columns and dnn_features_columns are @@ -158,7 +175,8 @@ class DNNLinearCombinedEstimator(estimator.Estimator): dnn_activation_fn=dnn_activation_fn, dnn_dropout=dnn_dropout, input_layer_partitioner=input_layer_partitioner, - config=config) + config=config, + linear_sparse_combiner=linear_sparse_combiner) super(DNNLinearCombinedEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py index dd009a6753f3231638f93e50fc8f19eae8820139..51b9ce7005cec3910ba73db62a674e4628ca30a2 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py @@ -100,7 +100,8 @@ def _linear_only_estimator_fn( weight_column=None, optimizer='Ftrl', config=None, - partitioner=None): + partitioner=None, + sparse_combiner='sum'): return dnn_linear_combined.DNNLinearCombinedEstimator( head=head_lib.regression_head( weight_column=weight_column, label_dimension=label_dimension, @@ -110,7 +111,8 @@ def _linear_only_estimator_fn( linear_feature_columns=feature_columns, linear_optimizer=optimizer, input_layer_partitioner=partitioner, - config=config) + config=config, + linear_sparse_combiner=sparse_combiner) class LinearOnlyEstimatorEvaluateTest( diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_test.py index 75e3107670d658e55ce23d983e47311f1c180104..050b0428bf7b685229e12561cfb0682d931299d2 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_test.py @@ -38,7 +38,7 @@ from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache -def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): +def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg """Returns a DNNEstimator that uses regression_head.""" return dnn.DNNEstimator( head=head_lib.regression_head( @@ -48,6 +48,12 @@ def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): *args, **kwargs) +def _dnn_estimator_classifier_fn(n_classes=3, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg + """Returns a DNNEstimator that uses multi_class_head.""" + return dnn.DNNEstimator(head=head_lib.multi_class_head(n_classes=n_classes), + *args, **kwargs) + + class DNNEstimatorEvaluateTest( dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase): @@ -75,6 +81,15 @@ class DNNEstimatorTrainTest( self, _dnn_estimator_fn) +class DNNEstimatorWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest, + test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + dnn_testing_utils.BaseDNNWarmStartingTest.__init__( + self, _dnn_estimator_classifier_fn, _dnn_estimator_fn) + + class DNNEstimatorIntegrationTest(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/estimator/python/estimator/early_stopping.py b/tensorflow/contrib/estimator/python/estimator/early_stopping.py new file mode 100644 index 0000000000000000000000000000000000000000..3eab21d5acaf26f14a73e7fa8e9c50fffc22fe9c --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/early_stopping.py @@ -0,0 +1,469 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for early stopping.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import operator +import os + +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import gfile +from tensorflow.python.platform import tf_logging +from tensorflow.python.summary import summary_iterator +from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training_util + +_EVENT_FILE_GLOB_PATTERN = 'events.out.tfevents.*' + + +def make_early_stopping_hook(estimator, + should_stop_fn, + run_every_secs=60, + run_every_steps=None): + """Creates early-stopping hook. + + Returns a `SessionRunHook` that stops training when `should_stop_fn` returns + `True`. + + Usage example: + + ```python + estimator = ... + hook = early_stopping.make_early_stopping_hook( + estimator, should_stop_fn=make_stop_fn(...)) + train_spec = tf.estimator.TrainSpec(..., hooks=[hook]) + tf.estimator.train_and_evaluate(estimator, train_spec, ...) + ``` + + Args: + estimator: A `tf.estimator.Estimator` instance. + should_stop_fn: `callable`, function that takes no arguments and returns a + `bool`. If the function returns `True`, stopping will be initiated by the + chief. + run_every_secs: If specified, calls `should_stop_fn` at an interval of + `run_every_secs` seconds. Defaults to 60 seconds. Either this or + `run_every_steps` must be set. + run_every_steps: If specified, calls `should_stop_fn` every + `run_every_steps` steps. Either this or `run_every_secs` must be set. + + Returns: + A `SessionRunHook` that periodically executes `should_stop_fn` and initiates + early stopping if the function returns `True`. + + Raises: + TypeError: If `estimator` is not of type `tf.estimator.Estimator`. + ValueError: If both `run_every_secs` and `run_every_steps` are set. + """ + if not isinstance(estimator, estimator_lib.Estimator): + raise TypeError('`estimator` must have type `tf.estimator.Estimator`. ' + 'Got: {}'.format(type(estimator))) + + if run_every_secs is not None and run_every_steps is not None: + raise ValueError('Only one of `run_every_secs` and `run_every_steps` must ' + 'be set.') + + if estimator.config.is_chief: + return _StopOnPredicateHook(should_stop_fn, run_every_secs, run_every_steps) + else: + return _CheckForStoppingHook() + + +def stop_if_higher_hook(estimator, + metric_name, + threshold, + eval_dir=None, + min_steps=0, + run_every_secs=60, + run_every_steps=None): + """Creates hook to stop if the given metric is higher than the threshold. + + Usage example: + + ```python + estimator = ... + # Hook to stop training if accuracy becomes higher than 0.9. + hook = early_stopping.stop_if_higher_hook(estimator, "accuracy", 0.9) + train_spec = tf.estimator.TrainSpec(..., hooks=[hook]) + tf.estimator.train_and_evaluate(estimator, train_spec, ...) + ``` + + Args: + estimator: A `tf.estimator.Estimator` instance. + metric_name: `str`, metric to track. "loss", "accuracy", etc. + threshold: Numeric threshold for the given metric. + eval_dir: If set, directory containing summary files with eval metrics. By + default, `estimator.eval_dir()` will be used. + min_steps: `int`, stop is never requested if global step is less than this + value. Defaults to 0. + run_every_secs: If specified, calls `should_stop_fn` at an interval of + `run_every_secs` seconds. Defaults to 60 seconds. Either this or + `run_every_steps` must be set. + run_every_steps: If specified, calls `should_stop_fn` every + `run_every_steps` steps. Either this or `run_every_secs` must be set. + + Returns: + An early-stopping hook of type `SessionRunHook` that periodically checks + if the given metric is higher than specified threshold and initiates + early stopping if true. + """ + return _stop_if_threshold_crossed_hook( + estimator=estimator, + metric_name=metric_name, + threshold=threshold, + higher_is_better=True, + eval_dir=eval_dir, + min_steps=min_steps, + run_every_secs=run_every_secs, + run_every_steps=run_every_steps) + + +def stop_if_lower_hook(estimator, + metric_name, + threshold, + eval_dir=None, + min_steps=0, + run_every_secs=60, + run_every_steps=None): + """Creates hook to stop if the given metric is lower than the threshold. + + Usage example: + + ```python + estimator = ... + # Hook to stop training if loss becomes lower than 100. + hook = early_stopping.stop_if_lower_hook(estimator, "loss", 100) + train_spec = tf.estimator.TrainSpec(..., hooks=[hook]) + tf.estimator.train_and_evaluate(estimator, train_spec, ...) + ``` + + Args: + estimator: A `tf.estimator.Estimator` instance. + metric_name: `str`, metric to track. "loss", "accuracy", etc. + threshold: Numeric threshold for the given metric. + eval_dir: If set, directory containing summary files with eval metrics. By + default, `estimator.eval_dir()` will be used. + min_steps: `int`, stop is never requested if global step is less than this + value. Defaults to 0. + run_every_secs: If specified, calls `should_stop_fn` at an interval of + `run_every_secs` seconds. Defaults to 60 seconds. Either this or + `run_every_steps` must be set. + run_every_steps: If specified, calls `should_stop_fn` every + `run_every_steps` steps. Either this or `run_every_secs` must be set. + + Returns: + An early-stopping hook of type `SessionRunHook` that periodically checks + if the given metric is lower than specified threshold and initiates + early stopping if true. + """ + return _stop_if_threshold_crossed_hook( + estimator=estimator, + metric_name=metric_name, + threshold=threshold, + higher_is_better=False, + eval_dir=eval_dir, + min_steps=min_steps, + run_every_secs=run_every_secs, + run_every_steps=run_every_steps) + + +def stop_if_no_increase_hook(estimator, + metric_name, + max_steps_without_increase, + eval_dir=None, + min_steps=0, + run_every_secs=60, + run_every_steps=None): + """Creates hook to stop if metric does not increase within given max steps. + + Usage example: + + ```python + estimator = ... + # Hook to stop training if accuracy does not increase in over 100000 steps. + hook = early_stopping.stop_if_no_increase_hook(estimator, "accuracy", 100000) + train_spec = tf.estimator.TrainSpec(..., hooks=[hook]) + tf.estimator.train_and_evaluate(estimator, train_spec, ...) + ``` + + Args: + estimator: A `tf.estimator.Estimator` instance. + metric_name: `str`, metric to track. "loss", "accuracy", etc. + max_steps_without_increase: `int`, maximum number of training steps with no + increase in the given metric. + eval_dir: If set, directory containing summary files with eval metrics. By + default, `estimator.eval_dir()` will be used. + min_steps: `int`, stop is never requested if global step is less than this + value. Defaults to 0. + run_every_secs: If specified, calls `should_stop_fn` at an interval of + `run_every_secs` seconds. Defaults to 60 seconds. Either this or + `run_every_steps` must be set. + run_every_steps: If specified, calls `should_stop_fn` every + `run_every_steps` steps. Either this or `run_every_secs` must be set. + + Returns: + An early-stopping hook of type `SessionRunHook` that periodically checks + if the given metric shows no increase over given maximum number of + training steps, and initiates early stopping if true. + """ + return _stop_if_no_metric_improvement_hook( + estimator=estimator, + metric_name=metric_name, + max_steps_without_improvement=max_steps_without_increase, + higher_is_better=True, + eval_dir=eval_dir, + min_steps=min_steps, + run_every_secs=run_every_secs, + run_every_steps=run_every_steps) + + +def stop_if_no_decrease_hook(estimator, + metric_name, + max_steps_without_decrease, + eval_dir=None, + min_steps=0, + run_every_secs=60, + run_every_steps=None): + """Creates hook to stop if metric does not decrease within given max steps. + + Usage example: + + ```python + estimator = ... + # Hook to stop training if loss does not decrease in over 100000 steps. + hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 100000) + train_spec = tf.estimator.TrainSpec(..., hooks=[hook]) + tf.estimator.train_and_evaluate(estimator, train_spec, ...) + ``` + + Args: + estimator: A `tf.estimator.Estimator` instance. + metric_name: `str`, metric to track. "loss", "accuracy", etc. + max_steps_without_decrease: `int`, maximum number of training steps with no + decrease in the given metric. + eval_dir: If set, directory containing summary files with eval metrics. By + default, `estimator.eval_dir()` will be used. + min_steps: `int`, stop is never requested if global step is less than this + value. Defaults to 0. + run_every_secs: If specified, calls `should_stop_fn` at an interval of + `run_every_secs` seconds. Defaults to 60 seconds. Either this or + `run_every_steps` must be set. + run_every_steps: If specified, calls `should_stop_fn` every + `run_every_steps` steps. Either this or `run_every_secs` must be set. + + Returns: + An early-stopping hook of type `SessionRunHook` that periodically checks + if the given metric shows no decrease over given maximum number of + training steps, and initiates early stopping if true. + """ + return _stop_if_no_metric_improvement_hook( + estimator=estimator, + metric_name=metric_name, + max_steps_without_improvement=max_steps_without_decrease, + higher_is_better=False, + eval_dir=eval_dir, + min_steps=min_steps, + run_every_secs=run_every_secs, + run_every_steps=run_every_steps) + + +def read_eval_metrics(eval_dir): + """Helper to read eval metrics from eval summary files. + + Args: + eval_dir: Directory containing summary files with eval metrics. + + Returns: + A `dict` with global steps mapping to `dict` of metric names and values. + """ + eval_metrics_dict = {} + for event in _summaries(eval_dir): + if not event.HasField('summary'): + continue + metrics = {} + for value in event.summary.value: + if value.HasField('simple_value'): + metrics[value.tag] = value.simple_value + if metrics: + eval_metrics_dict[event.step] = metrics + return eval_metrics_dict + + +def _stop_if_threshold_crossed_hook(estimator, metric_name, threshold, + higher_is_better, eval_dir, min_steps, + run_every_secs, run_every_steps): + """Creates early-stopping hook to stop training if threshold is crossed.""" + + if eval_dir is None: + eval_dir = estimator.eval_dir() + + is_lhs_better = operator.gt if higher_is_better else operator.lt + greater_or_lesser = 'greater than' if higher_is_better else 'less than' + + def stop_if_threshold_crossed_fn(): + """Returns `True` if the given metric crosses specified threshold.""" + + eval_results = read_eval_metrics(eval_dir) + + for step, metrics in eval_results.items(): + if step < min_steps: + continue + val = metrics[metric_name] + if is_lhs_better(val, threshold): + tf_logging.info( + 'At step %s, metric "%s" has value %s which is %s the configured ' + 'threshold (%s) for early stopping.', step, metric_name, val, + greater_or_lesser, threshold) + return True + return False + + return make_early_stopping_hook( + estimator=estimator, + should_stop_fn=stop_if_threshold_crossed_fn, + run_every_secs=run_every_secs, + run_every_steps=run_every_steps) + + +def _stop_if_no_metric_improvement_hook( + estimator, metric_name, max_steps_without_improvement, higher_is_better, + eval_dir, min_steps, run_every_secs, run_every_steps): + """Returns hook to stop training if given metric shows no improvement.""" + + if eval_dir is None: + eval_dir = estimator.eval_dir() + + is_lhs_better = operator.gt if higher_is_better else operator.lt + increase_or_decrease = 'increase' if higher_is_better else 'decrease' + + def stop_if_no_metric_improvement_fn(): + """Returns `True` if metric does not improve within max steps.""" + + eval_results = read_eval_metrics(eval_dir) + + best_val = None + best_val_step = None + for step, metrics in eval_results.items(): + if step < min_steps: + continue + val = metrics[metric_name] + if best_val is None or is_lhs_better(val, best_val): + best_val = val + best_val_step = step + if step - best_val_step >= max_steps_without_improvement: + tf_logging.info( + 'No %s in metric "%s" for %s steps, which is greater than or equal ' + 'to max steps (%s) configured for early stopping.', + increase_or_decrease, metric_name, step - best_val_step, + max_steps_without_improvement) + return True + return False + + return make_early_stopping_hook( + estimator=estimator, + should_stop_fn=stop_if_no_metric_improvement_fn, + run_every_secs=run_every_secs, + run_every_steps=run_every_steps) + + +def _summaries(eval_dir): + """Yields `tensorflow.Event` protos from event files in the eval dir. + + Args: + eval_dir: Directory containing summary files with eval metrics. + + Yields: + `tensorflow.Event` object read from the event files. + """ + if gfile.Exists(eval_dir): + for event_file in gfile.Glob( + os.path.join(eval_dir, _EVENT_FILE_GLOB_PATTERN)): + for event in summary_iterator.summary_iterator(event_file): + yield event + + +def _get_or_create_stop_var(): + with variable_scope.variable_scope( + name_or_scope='signal_early_stopping', + values=[], + reuse=variable_scope.AUTO_REUSE): + return variable_scope.get_variable( + name='STOP', + shape=[], + dtype=dtypes.bool, + initializer=init_ops.constant_initializer(False), + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + trainable=False) + + +class _StopOnPredicateHook(session_run_hook.SessionRunHook): + """Hook that requests stop when `should_stop_fn` returns `True`.""" + + def __init__(self, should_stop_fn, run_every_secs=60, run_every_steps=None): + if not callable(should_stop_fn): + raise TypeError('`should_stop_fn` must be callable.') + + self._should_stop_fn = should_stop_fn + self._timer = basic_session_run_hooks.SecondOrStepTimer( + every_secs=run_every_secs, every_steps=run_every_steps) + self._global_step_tensor = None + self._stop_var = None + self._stop_op = None + + def begin(self): + self._global_step_tensor = training_util.get_global_step() + self._stop_var = _get_or_create_stop_var() + self._stop_op = state_ops.assign(self._stop_var, True) + + def before_run(self, run_context): + del run_context + return session_run_hook.SessionRunArgs(self._global_step_tensor) + + def after_run(self, run_context, run_values): + global_step = run_values.results + if self._timer.should_trigger_for_step(global_step): + self._timer.update_last_triggered_step(global_step) + if self._should_stop_fn(): + tf_logging.info('Requesting early stopping at global step %d', + global_step) + run_context.session.run(self._stop_op) + run_context.request_stop() + + +class _CheckForStoppingHook(session_run_hook.SessionRunHook): + """Hook that requests stop if stop is requested by `_StopOnPredicateHook`.""" + + def __init__(self): + self._stop_var = None + + def begin(self): + self._stop_var = _get_or_create_stop_var() + + def before_run(self, run_context): + del run_context + return session_run_hook.SessionRunArgs(self._stop_var) + + def after_run(self, run_context, run_values): + should_early_stop = run_values.results + if should_early_stop: + tf_logging.info('Early stopping requested, suspending run.') + run_context.request_stop() diff --git a/tensorflow/contrib/estimator/python/estimator/early_stopping_test.py b/tensorflow/contrib/estimator/python/estimator/early_stopping_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e4bfd4b446b9413bd1627ef6904ff2dc9f1a9120 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/early_stopping_test.py @@ -0,0 +1,246 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for early_stopping.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +from absl.testing import parameterized +from tensorflow.contrib.estimator.python.estimator import early_stopping +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import run_config +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.platform import test +from tensorflow.python.training import monitored_session +from tensorflow.python.training import training_util + + +class _FakeRunConfig(run_config.RunConfig): + + def __init__(self, is_chief): + super(_FakeRunConfig, self).__init__() + self._is_chief = is_chief + + @property + def is_chief(self): + return self._is_chief + + +def _dummy_model_fn(features, labels, params): + _, _, _ = features, labels, params + + +class _FakeEstimator(estimator.Estimator): + """Fake estimator for testing.""" + + def __init__(self, config): + super(_FakeEstimator, self).__init__( + model_fn=_dummy_model_fn, config=config) + + +def _write_events(eval_dir, params): + """Test helper to write events to summary files.""" + for steps, loss, accuracy in params: + estimator._write_dict_to_summary(eval_dir, { + 'loss': loss, + 'accuracy': accuracy, + }, steps) + + +class ReadEvalMetricsTest(test.TestCase): + + def test_read_eval_metrics(self): + eval_dir = tempfile.mkdtemp() + _write_events( + eval_dir, + [ + # steps, loss, accuracy + (1000, 1, 2), + (2000, 3, 4), + (3000, 5, 6), + ]) + self.assertEqual({ + 1000: { + 'loss': 1, + 'accuracy': 2 + }, + 2000: { + 'loss': 3, + 'accuracy': 4 + }, + 3000: { + 'loss': 5, + 'accuracy': 6 + }, + }, early_stopping.read_eval_metrics(eval_dir)) + + def test_read_eval_metrics_when_no_events(self): + eval_dir = tempfile.mkdtemp() + self.assertTrue(os.path.exists(eval_dir)) + + # No error should be raised when eval directory exists with no event files. + self.assertEqual({}, early_stopping.read_eval_metrics(eval_dir)) + + os.rmdir(eval_dir) + self.assertFalse(os.path.exists(eval_dir)) + + # No error should be raised when eval directory does not exist. + self.assertEqual({}, early_stopping.read_eval_metrics(eval_dir)) + + +class EarlyStoppingHooksTest(test.TestCase, parameterized.TestCase): + + def setUp(self): + config = _FakeRunConfig(is_chief=True) + self._estimator = _FakeEstimator(config=config) + eval_dir = self._estimator.eval_dir() + os.makedirs(eval_dir) + _write_events( + eval_dir, + [ + # steps, loss, accuracy + (1000, 0.8, 0.5), + (2000, 0.7, 0.6), + (3000, 0.4, 0.7), + (3500, 0.41, 0.68), + ]) + + def run_session(self, hooks, should_stop): + hooks = hooks if isinstance(hooks, list) else [hooks] + with ops.Graph().as_default(): + training_util.create_global_step() + no_op = control_flow_ops.no_op() + with monitored_session.SingularMonitoredSession(hooks=hooks) as mon_sess: + mon_sess.run(no_op) + self.assertEqual(mon_sess.should_stop(), should_stop) + + @parameterized.parameters((0.8, 0, False), (0.6, 4000, False), (0.6, 0, True)) + def test_stop_if_higher_hook(self, threshold, min_steps, should_stop): + self.run_session( + early_stopping.stop_if_higher_hook( + self._estimator, + metric_name='accuracy', + threshold=threshold, + min_steps=min_steps), should_stop) + + @parameterized.parameters((0.3, 0, False), (0.5, 4000, False), (0.5, 0, True)) + def test_stop_if_lower_hook(self, threshold, min_steps, should_stop): + self.run_session( + early_stopping.stop_if_lower_hook( + self._estimator, + metric_name='loss', + threshold=threshold, + min_steps=min_steps), should_stop) + + @parameterized.parameters((1500, 0, False), (500, 4000, False), + (500, 0, True)) + def test_stop_if_no_increase_hook(self, max_steps, min_steps, should_stop): + self.run_session( + early_stopping.stop_if_no_increase_hook( + self._estimator, + metric_name='accuracy', + max_steps_without_increase=max_steps, + min_steps=min_steps), should_stop) + + @parameterized.parameters((1500, 0, False), (500, 4000, False), + (500, 0, True)) + def test_stop_if_no_decrease_hook(self, max_steps, min_steps, should_stop): + self.run_session( + early_stopping.stop_if_no_decrease_hook( + self._estimator, + metric_name='loss', + max_steps_without_decrease=max_steps, + min_steps=min_steps), should_stop) + + @parameterized.parameters((1500, 0.3, False), (1500, 0.5, True), + (500, 0.3, True)) + def test_multiple_hooks(self, max_steps, loss_threshold, should_stop): + self.run_session([ + early_stopping.stop_if_no_decrease_hook( + self._estimator, + metric_name='loss', + max_steps_without_decrease=max_steps), + early_stopping.stop_if_lower_hook( + self._estimator, metric_name='loss', threshold=loss_threshold) + ], should_stop) + + @parameterized.parameters(False, True) + def test_make_early_stopping_hook(self, should_stop): + self.run_session([ + early_stopping.make_early_stopping_hook( + self._estimator, should_stop_fn=lambda: should_stop) + ], should_stop) + + def test_make_early_stopping_hook_typeerror(self): + with self.assertRaises(TypeError): + early_stopping.make_early_stopping_hook( + estimator=object(), should_stop_fn=lambda: True) + + def test_make_early_stopping_hook_valueerror(self): + with self.assertRaises(ValueError): + early_stopping.make_early_stopping_hook( + self._estimator, + should_stop_fn=lambda: True, + run_every_secs=60, + run_every_steps=100) + + +class StopOnPredicateHookTest(test.TestCase): + + def test_stop(self): + hook = early_stopping._StopOnPredicateHook( + should_stop_fn=lambda: False, run_every_secs=0) + with ops.Graph().as_default(): + training_util.create_global_step() + no_op = control_flow_ops.no_op() + with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess: + mon_sess.run(no_op) + self.assertFalse(mon_sess.should_stop()) + self.assertFalse(mon_sess.raw_session().run(hook._stop_var)) + + hook = early_stopping._StopOnPredicateHook( + should_stop_fn=lambda: True, run_every_secs=0) + with ops.Graph().as_default(): + training_util.create_global_step() + no_op = control_flow_ops.no_op() + with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess: + mon_sess.run(no_op) + self.assertTrue(mon_sess.should_stop()) + self.assertTrue(mon_sess.raw_session().run(hook._stop_var)) + + +class CheckForStoppingHookTest(test.TestCase): + + def test_stop(self): + hook = early_stopping._CheckForStoppingHook() + with ops.Graph().as_default(): + no_op = control_flow_ops.no_op() + assign_op = state_ops.assign(early_stopping._get_or_create_stop_var(), + True) + with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess: + mon_sess.run(no_op) + self.assertFalse(mon_sess.should_stop()) + mon_sess.run(assign_op) + self.assertTrue(mon_sess.should_stop()) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py index bf08be09e7baf63e507a6a4db6a91e7b6bb20b74..26449b46516fe1d8c93a8e3567f93801c689a65a 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders.py @@ -34,7 +34,7 @@ _VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config']) def add_metrics(estimator, metric_fn): - """Creates a new @{tf.estimator.Estimator} which has given metrics. + """Creates a new `tf.estimator.Estimator` which has given metrics. Example: @@ -61,7 +61,7 @@ def add_metrics(estimator, metric_fn): ``` Args: - estimator: A @{tf.estimator.Estimator} object. + estimator: A `tf.estimator.Estimator` object. metric_fn: A function which should obey the following signature: - Args: can only have following four arguments in any order: * predictions: Predictions `Tensor` or dict of `Tensor` created by given @@ -79,7 +79,7 @@ def add_metrics(estimator, metric_fn): function, namely a `(metric_tensor, update_op)` tuple. Returns: - A new @{tf.estimator.Estimator} which has a union of original metrics with + A new `tf.estimator.Estimator` which has a union of original metrics with given ones. """ _verify_metric_fn_args(metric_fn) @@ -165,14 +165,14 @@ def forward_features(estimator, keys=None): ``` Args: - estimator: A @{tf.estimator.Estimator} object. + estimator: A `tf.estimator.Estimator` object. keys: a `string` or a `list` of `string`. If it is `None`, all of the `features` in `dict` is forwarded to the `predictions`. If it is a `string`, only given key is forwarded. If it is a `list` of strings, all the given `keys` are forwarded. Returns: - A new @{tf.estimator.Estimator} which forwards features to predictions. + A new `tf.estimator.Estimator` which forwards features to predictions. Raises: ValueError: diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 9594e5132fd20dadea118fd1dd6768feb7fd7fff..34f765d56546d3cd10fcde5ac444a221c73602cd 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -534,7 +534,8 @@ def multi_label_head(n_classes, * An integer `SparseTensor` of class indices. The `dense_shape` must be `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`. * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape` - must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`. + must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary` or a + multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`. If `weight_column` is specified, weights must be of shape `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`. @@ -942,20 +943,30 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access class_probabilities = array_ops.slice( probabilities, begin=begin, size=size) class_labels = array_ops.slice(labels, begin=begin, size=size) - prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id + if self._label_vocabulary is None: + prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id + else: + prob_key = ( + keys.PROBABILITY_MEAN_AT_NAME % self._label_vocabulary[class_id]) metric_ops[head_lib._summary_key(self._name, prob_key)] = ( # pylint:disable=protected-access head_lib._predictions_mean( # pylint:disable=protected-access predictions=class_probabilities, weights=weights, name=prob_key)) - auc_key = keys.AUC_AT_CLASS % class_id + if self._label_vocabulary is None: + auc_key = keys.AUC_AT_CLASS % class_id + else: + auc_key = keys.AUC_AT_NAME % self._label_vocabulary[class_id] metric_ops[head_lib._summary_key(self._name, auc_key)] = ( # pylint:disable=protected-access head_lib._auc( # pylint:disable=protected-access labels=class_labels, predictions=class_probabilities, weights=weights, name=auc_key)) - auc_pr_key = keys.AUC_PR_AT_CLASS % class_id + if self._label_vocabulary is None: + auc_pr_key = keys.AUC_PR_AT_CLASS % class_id + else: + auc_pr_key = keys.AUC_PR_AT_NAME % self._label_vocabulary[class_id] metric_ops[head_lib._summary_key(self._name, auc_pr_key)] = ( # pylint:disable=protected-access head_lib._auc( # pylint:disable=protected-access labels=class_labels, diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index b2b57fa06ba818d4455871fe57dde5ce287b39a2..2d367adb47080a630d1d2ef5ecfd4e8d5d0377d9 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -568,6 +568,33 @@ class MultiLabelHead(test.TestCase): expected_loss=expected_loss, expected_metrics=expected_metrics) + def test_eval_with_label_vocabulary_with_multi_hot_input(self): + n_classes = 2 + head = head_lib.multi_label_head( + n_classes, label_vocabulary=['class0', 'class1']) + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) + keys = metric_keys.MetricKeys + expected_metrics = { + # Average loss over examples. + keys.LOSS_MEAN: expected_loss, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. + keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + } + self._test_eval( + head=head, + logits=logits, + labels=labels_multi_hot, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + def test_eval_with_thresholds(self): n_classes = 2 thresholds = [0.25, 0.5, 0.75] @@ -667,12 +694,14 @@ class MultiLabelHead(test.TestCase): # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, keys.AUC_PR: 0.7639, - keys.PROBABILITY_MEAN_AT_CLASS % 0: np.sum(_sigmoid(logits[:, 0])) / 2., - keys.AUC_AT_CLASS % 0: 0., - keys.AUC_PR_AT_CLASS % 0: 1., - keys.PROBABILITY_MEAN_AT_CLASS % 1: np.sum(_sigmoid(logits[:, 1])) / 2., - keys.AUC_AT_CLASS % 1: 1., - keys.AUC_PR_AT_CLASS % 1: 1., + keys.PROBABILITY_MEAN_AT_NAME % 'a': + np.sum(_sigmoid(logits[:, 0])) / 2., + keys.AUC_AT_NAME % 'a': 0., + keys.AUC_PR_AT_NAME % 'a': 1., + keys.PROBABILITY_MEAN_AT_NAME % 'b': + np.sum(_sigmoid(logits[:, 1])) / 2., + keys.AUC_AT_NAME % 'b': 1., + keys.AUC_PR_AT_NAME % 'b': 1., } self._test_eval( diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py index ddd6aa442f82bad2d4714dbcdc85b20b34773068..faefda7c4898f140864908da33b1f19f93b2dd47 100644 --- a/tensorflow/contrib/estimator/python/estimator/hooks.py +++ b/tensorflow/contrib/estimator/python/estimator/hooks.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os +import time from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.framework import ops @@ -26,6 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import state_ops from tensorflow.python.training import training +from tensorflow.python.training import training_util # pylint: disable=protected-access @@ -189,7 +191,7 @@ class InMemoryEvaluatorHook(training.SessionRunHook): init_fn=feed_variables, copy_from_scaffold=self._scaffold) with self._graph.as_default(): - return self._estimator._evaluate_run( + self._estimator._evaluate_run( checkpoint_path=None, scaffold=scaffold, update_op=self._update_op, @@ -210,4 +212,55 @@ class InMemoryEvaluatorHook(training.SessionRunHook): self._evaluate(session) +class StopAtCheckpointStepHook(training.SessionRunHook): + """Hook that requests stop at a specified step based on checkpoint.""" + + def __init__(self, model_dir, last_step, + wait_after_file_check_secs=30): + """Initializes a `StopAtCheckpointStepHook`. + + This hook requests stop after a last step has been reached. It checks latest + checkpoint to verify last step is written on disk or not. + + Args: + model_dir: Directory to read global step from latest checkpoint. + last_step: Step after which to stop. + wait_after_file_check_secs: Reading same file by many workers may create + I/O issues. To throttle that we will wait given secs after each read of + the file. + + Raises: + ValueError: If one of the arguments is invalid. + """ + if last_step is None: + raise ValueError('last_step must be specified.') + if model_dir is None: + raise ValueError('model_dir must be specified.') + + self._model_dir = model_dir + self._last_step = last_step + self._wait_after_file_check_secs = wait_after_file_check_secs + + def begin(self): + self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access + if self._global_step_tensor is None: + raise RuntimeError( + 'Global step should be created to use StopAtCheckpointStepHook.') + + def before_run(self, run_context): # pylint: disable=unused-argument + return training.SessionRunArgs(self._global_step_tensor) + + def after_run(self, run_context, run_values): + global_step = run_values.results + 1 + if global_step >= self._last_step: + # Check latest global step in the checkpoint to ensure that the targeted + # last step is written on disk. + + step = estimator_lib._load_global_step_from_checkpoint_dir( + self._model_dir) + if step >= self._last_step: + run_context.request_stop() + else: + time.sleep(self._wait_after_file_check_secs) + # pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py index 95ae971852ee6dffb6174fc243686721c30ef685..42352aa3ffbc087d148d16d69dfef5e5057de2e6 100644 --- a/tensorflow/contrib/estimator/python/estimator/hooks_test.py +++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py @@ -21,8 +21,11 @@ from __future__ import print_function import glob import json import os +import tempfile +import time from tensorflow.contrib.estimator.python.estimator import hooks as hooks_lib +from tensorflow.python.client import session as tf_session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator import run_config as run_config_lib @@ -102,6 +105,7 @@ class InMemoryEvaluatorHookTest(test.TestCase): self.assertTrue(os.path.isdir(estimator.eval_dir())) step_keyword_to_value = summary_step_keyword_to_value_mapping( estimator.eval_dir()) + # 4.5 = sum(range(10))/10 # before training self.assertEqual(4.5, step_keyword_to_value[0]['mean_of_features']) @@ -110,6 +114,7 @@ class InMemoryEvaluatorHookTest(test.TestCase): self.assertEqual(4.5, step_keyword_to_value[8]['mean_of_features']) # end self.assertEqual(4.5, step_keyword_to_value[10]['mean_of_features']) + self.assertEqual(set([0, 4, 8, 10]), set(step_keyword_to_value.keys())) def test_uses_latest_variable_value(self): @@ -314,5 +319,59 @@ class InMemoryEvaluatorHookTest(test.TestCase): estimator.train(input_fn, hooks=[evaluator]) +class StopAtCheckpointStepHookTest(test.TestCase): + + def test_do_not_stop_if_checkpoint_is_not_there(self): + with ops.Graph().as_default(): + step = training.create_global_step() + assign_ten = step.assign(10) + no_op = control_flow_ops.no_op() + hook = hooks_lib.StopAtCheckpointStepHook( + model_dir=tempfile.mkdtemp(), last_step=10) + with training.SingularMonitoredSession(hooks=[hook]) as mon_sess: + mon_sess.raw_session().run(assign_ten) + with test.mock.patch.object(time, 'sleep') as mock_sleep: + mon_sess.run(no_op) + self.assertTrue(mock_sleep.called) + self.assertFalse(mon_sess.should_stop()) + + def test_do_not_stop_if_checkpoint_step_is_smaller(self): + model_dir = tempfile.mkdtemp() + with ops.Graph().as_default(): + step = training.create_global_step() + assign_nine = step.assign(9) + assign_ten = step.assign(10) + no_op = control_flow_ops.no_op() + hook = hooks_lib.StopAtCheckpointStepHook( + model_dir=model_dir, last_step=10) + with tf_session.Session() as sess: + sess.run(assign_nine) + training.Saver().save(sess, os.path.join(model_dir, 'model.ckpt')) + with training.SingularMonitoredSession(hooks=[hook]) as mon_sess: + mon_sess.raw_session().run(assign_ten) + with test.mock.patch.object(time, 'sleep') as mock_sleep: + mon_sess.run(no_op) + self.assertTrue(mock_sleep.called) + self.assertFalse(mon_sess.should_stop()) + + def test_stop_if_checkpoint_step_is_laststep(self): + model_dir = tempfile.mkdtemp() + with ops.Graph().as_default(): + step = training.create_global_step() + assign_ten = step.assign(10) + no_op = control_flow_ops.no_op() + hook = hooks_lib.StopAtCheckpointStepHook( + model_dir=model_dir, last_step=10) + with tf_session.Session() as sess: + sess.run(assign_ten) + training.Saver().save(sess, os.path.join(model_dir, 'model.ckpt')) + with training.SingularMonitoredSession(hooks=[hook]) as mon_sess: + mon_sess.raw_session().run(assign_ten) + with test.mock.patch.object(time, 'sleep') as mock_sleep: + mon_sess.run(no_op) + self.assertFalse(mock_sleep.called) + self.assertTrue(mon_sess.should_stop()) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/linear.py b/tensorflow/contrib/estimator/python/estimator/linear.py index 3bf4abe83d54504d55de73b63f369cceaf149dd2..2b68f24eb2d4c528bc1cb87e7d858014f66c0433 100644 --- a/tensorflow/contrib/estimator/python/estimator/linear.py +++ b/tensorflow/contrib/estimator/python/estimator/linear.py @@ -39,6 +39,18 @@ class LinearEstimator(estimator.Estimator): feature_columns=[categorical_column_a, categorical_feature_a_x_categorical_feature_b]) + # Or estimator using an optimizer with a learning rate decay. + estimator = LinearEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b], + optimizer=lambda: tf.train.FtrlOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator using the FTRL optimizer with regularization. estimator = LinearEstimator( head=tf.contrib.estimator.multi_label_head(n_classes=3), @@ -87,7 +99,8 @@ class LinearEstimator(estimator.Estimator): model_dir=None, optimizer='Ftrl', config=None, - partitioner=None): + partitioner=None, + sparse_combiner='sum'): """Initializes a `LinearEstimator` instance. Args: @@ -99,10 +112,16 @@ class LinearEstimator(estimator.Estimator): model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to FTRL optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to FTRL optimizer. config: `RunConfig` object to configure the runtime settings. partitioner: Optional. Partitioner for input layer. + sparse_combiner: A string specifying how to reduce if a categorical column + is multivalent. One of "mean", "sqrtn", and "sum" -- these are + effectively different ways to do example-level normalization, which can + be useful for bag-of-words features. for more details, see + `tf.feature_column.linear_model`. """ def _model_fn(features, labels, mode, config): return linear_lib._linear_model_fn( # pylint: disable=protected-access @@ -113,6 +132,7 @@ class LinearEstimator(estimator.Estimator): feature_columns=tuple(feature_columns or []), optimizer=optimizer, partitioner=partitioner, - config=config) + config=config, + sparse_combiner=sparse_combiner) super(LinearEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..ce98e9987ec728fadf170e56fe4bfe24fc9a0105 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py @@ -0,0 +1,449 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class that creates an Estimator from a SavedModel.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export as export_lib +from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import loader_impl +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import monitored_session +from tensorflow.python.training import training_util + + +class SavedModelEstimator(estimator_lib.Estimator): + """Create an Estimator from a SavedModel. + + Only SavedModels exported with + `tf.contrib.estimator.export_all_saved_models()` or + `tf.estimator.Estimator.export_savedmodel()` are supported for this class. + + Example with `tf.estimator.DNNClassifier`: + + **Step 1: Create and train DNNClassifier.** + + ```python + feature1 = tf.feature_column.embedding_column( + tf.feature_column.categorical_column_with_vocabulary_list( + key='feature1', vocabulary_list=('green', 'yellow')), dimension=1) + feature2 = tf.feature_column.numeric_column(key='feature2', default_value=0.0) + + classifier = tf.estimator.DNNClassifier( + hidden_units=[4,2], feature_columns=[feature1, feature2]) + + def input_fn(): + features = {'feature1': tf.constant(['green', 'green', 'yellow']), + 'feature2': tf.constant([3.5, 4.2, 6.1])} + label = tf.constant([1., 0., 0.]) + return tf.data.Dataset.from_tensors((features, label)).repeat() + + classifier.train(input_fn=input_fn, steps=10) + ``` + + **Step 2: Export classifier.** + First, build functions that specify the expected inputs. + + ```python + # During train and evaluation, both the features and labels should be defined. + supervised_input_receiver_fn = ( + tf.contrib.estimator.build_raw_supervised_input_receiver_fn( + {'feature1': tf.placeholder(dtype=tf.string, shape=[None]), + 'feature2': tf.placeholder(dtype=tf.float32, shape=[None])}, + tf.placeholder(dtype=tf.float32, shape=[None]))) + + # During predict mode, expect to receive a `tf.Example` proto, so a parsing + # function is used. + serving_input_receiver_fn = ( + tf.estimator.export.build_parsing_serving_input_receiver_fn( + tf.feature_column.make_parse_example_spec([feature1, feature2]))) + ``` + + Next, export the model as a SavedModel. A timestamped directory will be + created (for example `/tmp/export_all/1234567890`). + + ```python + # Option 1: Save all modes (train, eval, predict) + export_dir = tf.contrib.estimator.export_all_saved_models( + classifier, '/tmp/export_all', + {tf.estimator.ModeKeys.TRAIN: supervised_input_receiver_fn, + tf.estimator.ModeKeys.EVAL: supervised_input_receiver_fn, + tf.estimator.ModeKeys.PREDICT: serving_input_receiver_fn}) + + # Option 2: Only export predict mode + export_dir = classifier.export_savedmodel( + '/tmp/export_predict', serving_input_receiver_fn) + ``` + + **Step 3: Create a SavedModelEstimator from the exported SavedModel.** + + ```python + est = tf.contrib.estimator.SavedModelEstimator(export_dir) + + # If all modes were exported, you can immediately evaluate and predict, or + # continue training. Otherwise only predict is available. + eval_results = est.evaluate(input_fn=input_fn, steps=1) + print(eval_results) + + est.train(input_fn=input_fn, steps=20) + + def predict_input_fn(): + example = tf.train.Example() + example.features.feature['feature1'].bytes_list.value.extend(['yellow']) + example.features.feature['feature2'].float_list.value.extend([1.]) + return {'inputs':tf.constant([example.SerializeToString()])} + + predictions = est.predict(predict_input_fn) + print(next(predictions)) + ``` + """ + + def __init__(self, saved_model_dir, model_dir=None): + """Initialize a SavedModelEstimator. + + The SavedModelEstimator loads its model function and variable values from + the graphs defined in the SavedModel. There is no option to pass in + `RunConfig` or `params` arguments, because the model function graph is + defined statically in the SavedModel. + + Args: + saved_model_dir: Directory containing SavedModel protobuf and subfolders. + model_dir: Directory to save new checkpoints during training. + + Raises: + NotImplementedError: If a DistributionStrategy is defined in the config. + Unless the SavedModelEstimator is subclassed, this shouldn't happen. + """ + checkpoint = estimator_lib._get_saved_model_ckpt(saved_model_dir) # pylint: disable=protected-access + vars_to_warm_start = [name for name, _ in + checkpoint_utils.list_variables(checkpoint)] + warm_start_settings = estimator_lib.WarmStartSettings( + ckpt_to_initialize_from=checkpoint, + vars_to_warm_start=vars_to_warm_start) + + super(SavedModelEstimator, self).__init__( + model_fn=self._model_fn_from_saved_model, model_dir=model_dir, + warm_start_from=warm_start_settings) + if self._train_distribution or self._eval_distribution: + raise NotImplementedError( + 'SavedModelEstimator currently does not support ' + 'DistributionStrategy.') + self.saved_model_dir = saved_model_dir + self.saved_model_loader = loader_impl.SavedModelLoader(saved_model_dir) + self._available_modes = self._extract_available_modes() + + def _extract_available_modes(self): + """Return list of modes found in SavedModel.""" + available_modes = [] + logging.info('Checking available modes for SavedModelEstimator.') + for mode in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL, + model_fn_lib.ModeKeys.PREDICT]: + try: + self._get_meta_graph_def_for_mode(mode) + except RuntimeError: + logging.warning('%s mode not found in SavedModel.' % mode) + continue + + if self._get_signature_def_for_mode(mode) is not None: + available_modes.append(mode) + + logging.info('Available modes for Estimator: %s' % available_modes) + return available_modes + + def _validate_mode(self, mode): + """Make sure that mode can be run using the SavedModel.""" + if mode not in self._available_modes: + raise RuntimeError('%s mode is not available in the SavedModel. Use ' + 'saved_model_cli to check that the Metagraph for this ' + 'mode has been exported.' % mode) + + def _get_meta_graph_def_for_mode(self, mode): + tags = model_fn_lib.EXPORT_TAG_MAP[mode] + return self.saved_model_loader.get_meta_graph_def_from_tags(tags) + + def _get_signature_def_for_mode(self, mode): + meta_graph_def = self._get_meta_graph_def_for_mode(mode) + sig_def_key = (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + if mode == model_fn_lib.ModeKeys.PREDICT else mode) + if sig_def_key not in meta_graph_def.signature_def: + logging.warning('Metagraph for mode %s was found, but SignatureDef with' + ' key \"%s\" is missing.' % (mode, sig_def_key)) + return None + return meta_graph_def.signature_def[sig_def_key] + + def _create_and_assert_global_step(self, graph): + # Do nothing here. The global step variable will be created/loaded from the + # SavedModel. If a global step variable were created here, the result + # will be two duplicate global step variables, causing issues during + # the warm-start phase. + # Due to the global variable being created in the model function, this may + # cause issues when running DistributionStrategy. Thus, DistributionStrategy + # is not yet supported with SavedModelEstimator. + return None + + def _model_fn_from_saved_model(self, features, labels, mode): + """Load a SavedModel graph and return an EstimatorSpec.""" + # TODO(kathywu): Model function loads placeholders from the graph. Calling + # export_all_saved_models creates another placeholder for the inputs, on top + # of the original placeholders. There should be a way to avoid this. + self._validate_mode(mode) + + g = ops.get_default_graph() + if training_util.get_global_step(g) is not None: + raise RuntimeError( + 'Graph must not contain a global step tensor before the SavedModel is' + ' loaded. Please make sure that the input function does not create a ' + 'global step.') + + # Extract SignatureDef for information about the input and output tensors. + signature_def = self._get_signature_def_for_mode(mode) + + # Generate input map for replacing the inputs in the SavedModel graph with + # the provided features and labels. + input_map = _generate_input_map(signature_def, features, labels) + + # Create a list of the names of output tensors. When the graph is loaded, + # names of the output tensors may be remapped. This ensures that the correct + # tensors are returned in the EstimatorSpec. + output_tensor_names = [ + value.name for value in six.itervalues(signature_def.outputs)] + + # Load the graph. `output_tensors` contains output `Tensors` in the same + # same order as the `output_tensor_names` list. + tags = model_fn_lib.EXPORT_TAG_MAP[mode] + _, output_tensors = self.saved_model_loader.load_graph( + g, tags, input_map=input_map, return_elements=output_tensor_names) + + # Create a scaffold from the MetaGraphDef that contains ops to initialize + # the graph. This should mirror the steps from _add_meta_graph_for_mode(), + # which creates a MetaGraphDef from the EstimatorSpec's scaffold. + scaffold = monitored_session.Scaffold( + local_init_op=loader_impl._get_main_op_tensor( # pylint: disable=protected-access + self._get_meta_graph_def_for_mode(mode))) + + # Ensure that a global step tensor has been created. + global_step_tensor = training_util.get_global_step(g) + training_util.assert_global_step(global_step_tensor) + + # Extract values to return in the EstimatorSpec. + output_map = dict(zip(output_tensor_names, output_tensors)) + outputs = {key: output_map[value.name] + for key, value in six.iteritems(signature_def.outputs)} + + loss, predictions, metrics = _validate_and_extract_outputs( + mode, outputs, signature_def.method_name) + + train_op = ops.get_collection(constants.TRAIN_OP_KEY) + if len(train_op) > 1: + raise RuntimeError('Multiple ops found in the train_op collection.') + train_op = None if not train_op else train_op[0] + + _clear_saved_model_collections() + return model_fn_lib.EstimatorSpec( + scaffold=scaffold, + mode=mode, + loss=loss, + train_op=train_op, + predictions=predictions, + eval_metric_ops=metrics) + + +def _clear_saved_model_collections(): + """Clear collections that are expected empty when exporting a SavedModel. + + The SavedModel builder uses these collections to track ops necessary to + restore the graph state. These collections are expected to be empty before + MetaGraphs are added to the builder. + """ + del ops.get_collection_ref(constants.ASSETS_KEY)[:] + del ops.get_collection_ref(constants.LEGACY_INIT_OP_KEY)[:] + del ops.get_collection_ref(constants.MAIN_OP_KEY)[:] + del ops.get_collection_ref(constants.TRAIN_OP_KEY)[:] + + +def _generate_input_map(signature_def, features, labels): + """Return dict mapping an input tensor name to a feature or label tensor. + + Args: + signature_def: SignatureDef loaded from SavedModel + features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or + `SparseTensor`, specifying the features to be passed to the model. + labels: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or + `SparseTensor`, specifying the labels to be passed to the model. May be + `None`. + + Returns: + dict mapping string names of inputs to features or labels tensors + + Raises: + ValueError: if SignatureDef inputs are not completely mapped by the input + features and labels. + """ + # pylint: disable=protected-access + if not isinstance(features, dict): + features = {export_lib._SINGLE_FEATURE_DEFAULT_NAME: features} + if labels is not None and not isinstance(labels, dict): + labels = {export_lib._SINGLE_LABEL_DEFAULT_NAME: labels} + # pylint: enable=protected-access + + inputs = signature_def.inputs + input_map = {} + for key, tensor_info in six.iteritems(inputs): + input_name = tensor_info.name + if ':' in input_name: + input_name = input_name[:input_name.find(':')] + + # When tensors are used as control inputs for operations, their names are + # prepended with a '^' character in the GraphDef. To handle possible control + # flow edge cases, control input names must be included in the input map. + control_dependency_name = '^' + input_name + + if key in features: + _check_same_dtype_and_shape(features[key], tensor_info, key) + input_map[input_name] = input_map[control_dependency_name] = features[key] + elif labels is not None and key in labels: + _check_same_dtype_and_shape(labels[key], tensor_info, key) + input_map[input_name] = input_map[control_dependency_name] = labels[key] + else: + raise ValueError( + 'Key \"%s\" not found in features or labels passed in to the model ' + 'function. All required keys: %s' % (key, inputs.keys())) + + return input_map + + +def _check_same_dtype_and_shape(tensor, tensor_info, name): + """Validate that tensor has the same properties as the TensorInfo proto. + + Args: + tensor: a `Tensor` object. + tensor_info: a `TensorInfo` proto. + name: Name of the input (to identify Tensor if an error is raised). + + Raises: + ValueError: If the tensor shape or dtype don't match the TensorInfo + """ + dtype_error = (tensor.dtype != dtypes.DType(tensor_info.dtype)) + shape_error = not tensor.shape.is_compatible_with(tensor_info.tensor_shape) + + if dtype_error or shape_error: + msg = 'Tensor shape and/or dtype validation failed for input %s:' % name + if dtype_error: + msg += ('\n\tExpected dtype: %s, Got: %s' + % (dtypes.DType(tensor_info.dtype), tensor.dtype)) + if shape_error: + msg += ('\n\tExpected shape: %s, Got: %s' + % (tensor_shape.TensorShape(tensor_info.tensor_shape), + tensor.shape)) + + raise ValueError(msg) + + +def _extract_eval_metrics(output_dict): + """Return a eval metric dict extracted from the output_dict. + + Eval metrics consist of a value tensor and an update op. Both must be in the + passed-in tensor dictionary for an eval metric to be added to the returned + dictionary. + + Args: + output_dict: a dict that maps strings to tensors. + + Returns: + dict mapping strings to (value, update_op) tuples. + """ + # pylint: disable=protected-access + metric_ops = {} + separator_char = export_output._SupervisedOutput._SEPARATOR_CHAR + + for key, tensor in six.iteritems(output_dict): + split_key = key.split(separator_char) + + # The metric name may contain the separator character, so recreate its name. + metric_name = separator_char.join(split_key[:-1]) + + if split_key[0] == export_output._SupervisedOutput.METRICS_NAME: + # If the key ends with the value suffix, and there is a corresponding + # key ending with the update_op suffix, then add tensors to metrics dict. + if split_key[-1] == export_output._SupervisedOutput.METRIC_VALUE_SUFFIX: + update_op = ''.join( + [metric_name, separator_char, + export_output._SupervisedOutput.METRIC_UPDATE_SUFFIX]) + if update_op in output_dict: + update_op_tensor = output_dict[update_op] + metric_ops[metric_name] = (tensor, update_op_tensor) + + # pylint: enable=protected-access + return metric_ops + + +def _validate_and_extract_outputs(mode, output_dict, method_name): + """Extract values from SignatureDef output dictionary. + + Args: + mode: One of the modes enumerated in `tf.estimator.ModeKeys`. + output_dict: dict of string SignatureDef keys to `Tensor`. + method_name: Method name of the SignatureDef as a string. + + Returns: + Tuple of ( + loss: `Tensor` object, + predictions: dictionary mapping string keys to `Tensor` objects, + metrics: dictionary mapping string keys to a tuple of two `Tensor` objects + ) + + Raises: + RuntimeError: raised if SignatureDef has an invalid method name for the mode + """ + # pylint: disable=protected-access + loss, predictions, metrics = None, None, None + + if mode == model_fn_lib.ModeKeys.PREDICT: + predictions = output_dict + else: + # Validate that the SignatureDef's method name matches the expected name for + # the given mode. + expected_method_name = signature_constants.SUPERVISED_TRAIN_METHOD_NAME + if mode == model_fn_lib.ModeKeys.EVAL: + expected_method_name = signature_constants.SUPERVISED_EVAL_METHOD_NAME + if method_name != expected_method_name: + raise RuntimeError( + 'Invalid SignatureDef method name for mode %s.\n\tExpected: %s\n\t' + 'Got: %s\nPlease ensure that the SavedModel was exported with ' + '`tf.contrib.estimator.export_all_saved_models()`.' % + (mode, expected_method_name, method_name)) + + # Extract loss, metrics and predictions from the output dict. + loss = output_dict[export_output._SupervisedOutput.LOSS_NAME] + metrics = _extract_eval_metrics(output_dict) + predictions = { + key: value for key, value in six.iteritems(output_dict) + if key.split(export_output._SupervisedOutput._SEPARATOR_CHAR)[0] == ( + export_output._SupervisedOutput.PREDICTIONS_NAME)} + + # pylint: enable=protected-access + return loss, predictions, metrics diff --git a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..718da1367ce69285f37269c5631fa0be2b050c97 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py @@ -0,0 +1,369 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SavedModelEstimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile + +from tensorflow.contrib.estimator.python.estimator import export as contrib_export +from tensorflow.contrib.estimator.python.estimator import saved_model_estimator +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import monitored_session +from tensorflow.python.training import training + + +def dummy_input_fn(): + return dataset_ops.Dataset.from_tensors(( + {'x': constant_op.constant([[1], [-2]], dtype=dtypes.int64)}, + constant_op.constant([[4], [-3]], dtype=dtypes.float32))).repeat() + + +def dummy_input_fn_features_only(): + return dataset_ops.Dataset.from_tensors( + {'x': constant_op.constant([[5], [6]], dtype=dtypes.int64)}).repeat() + + +def dummy_supervised_receiver_fn(): + feature_spec = { + 'x': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_x'), + } + label_spec = array_ops.placeholder( + dtype=dtypes.float32, shape=[2, 1], name='truth') + return export.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + +def dummy_serving_receiver_fn(): + feature_spec = {'x': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_x'),} + return export.build_raw_serving_input_receiver_fn(feature_spec) + + +def model_fn_diff_modes(features, labels, mode): + _, _ = features, labels + v = variables.Variable(21, name='some_var') + train_op = None + loss = constant_op.constant(104) + if mode == model_fn_lib.ModeKeys.TRAIN: + loss = constant_op.constant(105) + predictions = constant_op.constant([501]) + train_op = control_flow_ops.group( + state_ops.assign_add(training.get_global_step(), 1), + state_ops.assign_add(v, 3)) + elif mode == model_fn_lib.ModeKeys.EVAL: + loss = constant_op.constant(106) + predictions = constant_op.constant([502]) + else: + loss = constant_op.constant(107) + predictions = constant_op.constant([503]) + return model_fn_lib.EstimatorSpec( + mode, + loss=loss, + train_op=train_op, + eval_metric_ops={ + 'abs_err': metrics_lib.mean_absolute_error( + constant_op.constant(0), predictions)}, + predictions=predictions) + + +class SavedModelEstimatorTest(test.TestCase): + + def setUp(self): + self.tmpdirs = [] + + def tearDown(self): + for tmpdir in self.tmpdirs: + # gfile.DeleteRecursively fails in the windows cmake test, so use shutil. + shutil.rmtree(tmpdir, ignore_errors=True) + self.tmpdirs = [] + + def _get_tmp_dir(self): + tmpdir = tempfile.mkdtemp() + self.tmpdirs.append(tmpdir) + return tmpdir + + def _export_estimator(self, train=True, evaluate=True, predict=True, + model_fn=model_fn_diff_modes): + est = estimator.Estimator(model_fn, self._get_tmp_dir()) + est.train(input_fn=dummy_input_fn, steps=10) + + input_receiver_fn_map = {} + if train: + input_receiver_fn_map[model_fn_lib.ModeKeys.TRAIN] = ( + dummy_supervised_receiver_fn()) + if evaluate: + input_receiver_fn_map[model_fn_lib.ModeKeys.EVAL] = ( + dummy_supervised_receiver_fn()) + if predict: + input_receiver_fn_map[model_fn_lib.ModeKeys.PREDICT] = ( + dummy_serving_receiver_fn()) + + export_base_path = self._get_tmp_dir() + export_dir = contrib_export.export_all_saved_models( + est, export_base_path, input_receiver_fn_map) + return export_dir + + def test_load_all_modes(self): + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(), self._get_tmp_dir()) + sme.train(input_fn=dummy_input_fn, steps=1) + sme.train(input_fn=dummy_input_fn, steps=2) + self.assertEqual(13, sme.get_variable_value('global_step')) + self.assertEqual(60, sme.get_variable_value('some_var')) + + eval_results = sme.evaluate(dummy_input_fn, steps=5) + + self.assertEqual(13, eval_results['global_step']) + self.assertEqual(106, eval_results['loss']) + self.assertEqual(502, eval_results['metrics/abs_err']) + + predictions = next(sme.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'output': 503}, predictions) + + def test_load_all_modes_no_train(self): + """Ensure that all functions can be used without requiring a ckpt.""" + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(), self._get_tmp_dir()) + eval_results = sme.evaluate(dummy_input_fn, steps=5) + self.assertEqual(10, eval_results['global_step']) + self.assertEqual(106, eval_results['loss']) + self.assertEqual(502, eval_results['metrics/abs_err']) + + predictions = next(sme.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'output': 503}, predictions) + + def test_partial_exported_estimator(self): + sme1 = saved_model_estimator.SavedModelEstimator( + self._export_estimator(train=False, predict=False), self._get_tmp_dir()) + sme1.evaluate(dummy_input_fn, steps=5) + with self.assertRaisesRegexp(RuntimeError, 'train mode is not available'): + sme1.train(input_fn=dummy_input_fn, steps=1) + with self.assertRaisesRegexp(RuntimeError, 'infer mode is not available'): + next(sme1.predict(dummy_input_fn_features_only)) + + sme2 = saved_model_estimator.SavedModelEstimator( + self._export_estimator(evaluate=False), self._get_tmp_dir()) + sme2.train(input_fn=dummy_input_fn, steps=1) + next(sme2.predict(dummy_input_fn_features_only)) + with self.assertRaisesRegexp(RuntimeError, 'eval mode is not available'): + sme2.evaluate(dummy_input_fn, steps=5) + + def test_with_incorrect_input(self): + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(), self._get_tmp_dir()) + + def bad_shape_input_fn(): + return dataset_ops.Dataset.from_tensors(( + {'x': constant_op.constant([1, 2], dtype=dtypes.int64)}, + constant_op.constant([1, 2], dtype=dtypes.float32))) + + with self.assertRaisesRegexp(ValueError, 'Expected shape'): + sme.train(bad_shape_input_fn, steps=1) + + def bad_dtype_input_fn(): + return dataset_ops.Dataset.from_tensors(( + {'x': constant_op.constant([[1], [1]], dtype=dtypes.int32)}, + constant_op.constant([[1], [1]], dtype=dtypes.int64))) + + with self.assertRaisesRegexp(ValueError, 'Expected dtype'): + sme.train(bad_dtype_input_fn, steps=1) + + def test_input_fn_with_global_step(self): + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(), self._get_tmp_dir()) + + def bad_input_fn(): + training.get_or_create_global_step() + return dataset_ops.Dataset.from_tensors(( + {'x': constant_op.constant([[1], [1]], dtype=dtypes.int64)}, + constant_op.constant([[1], [1]], dtype=dtypes.float32))) + + with self.assertRaisesRegexp(RuntimeError, + 'Graph must not contain a global step tensor'): + sme.train(bad_input_fn, steps=1) + + def test_re_export_saved_model_serving_only(self): + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(), self._get_tmp_dir()) + sme.train(dummy_input_fn, steps=3) + self.assertEqual(13, sme.get_variable_value('global_step')) + self.assertEqual(60, sme.get_variable_value('some_var')) + + predictions = next(sme.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'output': 503}, predictions) + + # Export SavedModel, and test that the variable and prediction values are + # the same. + sme_export_dir = sme.export_savedmodel( + self._get_tmp_dir(), dummy_serving_receiver_fn()) + + sme2 = saved_model_estimator.SavedModelEstimator( + sme_export_dir, self._get_tmp_dir()) + self.assertEqual(60, sme.get_variable_value('some_var')) + self.assertEqual(13, sme.get_variable_value('global_step')) + + predictions = next(sme2.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'output': 503}, predictions) + + def test_re_export_saved_model(self): + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(), self._get_tmp_dir()) + self.assertDictEqual( + {'loss': 106, 'metrics/abs_err': 502, 'global_step': 10}, + sme.evaluate(dummy_input_fn, steps=1)) + + sme.train(dummy_input_fn, steps=3) + self.assertDictEqual( + {'loss': 106, 'metrics/abs_err': 502, 'global_step': 13}, + sme.evaluate(dummy_input_fn, steps=1)) + self.assertEqual(60, sme.get_variable_value('some_var')) + + predictions = next(sme.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'output': 503}, predictions) + + # Export SavedModel for all modes + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: dummy_supervised_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: dummy_supervised_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: dummy_serving_receiver_fn()} + sme_export_dir = contrib_export.export_all_saved_models( + sme, self._get_tmp_dir(), input_receiver_fn_map) + + sme2 = saved_model_estimator.SavedModelEstimator( + sme_export_dir, self._get_tmp_dir()) + self.assertDictEqual( + {'loss': 106, 'metrics/abs_err': 502, 'global_step': 13}, + sme.evaluate(dummy_input_fn, steps=1)) + self.assertEqual(60, sme.get_variable_value('some_var')) + + sme.train(dummy_input_fn, steps=7) + self.assertEqual(20, sme.get_variable_value('global_step')) + + predictions = next(sme2.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'output': 503}, predictions) + + def test_load_saved_model_from_serving_only(self): + def model_fn(features, labels, mode): + _, _ = features, labels + return model_fn_lib.EstimatorSpec( + mode, + loss=constant_op.constant([103]), + train_op=state_ops.assign_add(training.get_global_step(), 1), + predictions=constant_op.constant([502]), + export_outputs={'test': export_output.ClassificationOutput( + constant_op.constant([[32.]]))}) + + est = estimator.Estimator(model_fn, self._get_tmp_dir()) + est.train(input_fn=dummy_input_fn, steps=10) + + def serving_input_receiver_fn(): + return export.ServingInputReceiver( + {'test-features': constant_op.constant([[1], [1]])}, + array_ops.placeholder(dtype=dtypes.string)) + + export_dir = est.export_savedmodel( + self._get_tmp_dir(), serving_input_receiver_fn) + + sme = saved_model_estimator.SavedModelEstimator( + export_dir, self._get_tmp_dir()) + + def input_fn(): + return {'inputs': constant_op.constant('someinputstr')} + + prediction = next(sme.predict(input_fn)) + self.assertDictEqual({'scores': 32}, prediction) + + def test_with_local_init_op(self): + def model_fn(features, labels, mode): + _, _ = features, labels + v = variables.Variable(21, name='some_var') + scaffold = monitored_session.Scaffold( + local_init_op=state_ops.assign_add(v, -3).op + ) + return model_fn_lib.EstimatorSpec( + mode, + scaffold=scaffold, + train_op=state_ops.assign_add(training.get_global_step(), 1), + loss=array_ops.identity(v)) + export_dir = self._export_estimator(predict=False, model_fn=model_fn) + sme = saved_model_estimator.SavedModelEstimator( + export_dir, self._get_tmp_dir()) + + eval_results1 = sme.evaluate(dummy_input_fn, steps=2) + self.assertEqual(15, eval_results1['loss']) + + sme.train(dummy_input_fn, steps=1) + self.assertEqual(15, sme.get_variable_value('some_var')) + + eval_results2 = sme.evaluate(dummy_input_fn, steps=5) + self.assertEqual(12, eval_results2['loss']) + + def test_with_working_input_fn(self): + def model_fn(features, labels, mode): + loss = None + if labels is not None: + loss = labels[0][0] + labels[1][0] + return model_fn_lib.EstimatorSpec( + mode, + loss=loss, + train_op=state_ops.assign_add(training.get_global_step(), 1), + predictions={'features_0': array_ops.identity([features['x'][0][0]]), + 'features_1': array_ops.identity([features['x'][1][0]])}) + + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(model_fn=model_fn), self._get_tmp_dir()) + eval_results = sme.evaluate(dummy_input_fn, steps=1) + self.assertEqual(1, eval_results['loss']) + + predictions = next(sme.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'features_0': 5, 'features_1': 6}, predictions) + + def test_control_dependency(self): + # Control dependencies are saved with "^" appended to the start of the input + # name. The input map must include control dependencies as well. + def model_fn(features, labels, mode): + _ = labels + with ops.control_dependencies([features['x']]): + loss = features['x'][1][0] + return model_fn_lib.EstimatorSpec( + mode, + loss=loss, + train_op=state_ops.assign_add(training.get_global_step(), 1)) + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(train=False, predict=False, model_fn=model_fn), + self._get_tmp_dir()) + sme.evaluate(dummy_input_fn, steps=1) # Should run without error + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index effec42f028fe472593a8d06e15a0831346d6f50..9e1f14f9905d584287864c15d9b6f9c152d17787 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -65,7 +65,7 @@ tf_custom_op_py_library( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python/estimator", - "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", ], @@ -242,7 +242,7 @@ py_test( "//tensorflow/python:platform_benchmark", "//tensorflow/python:random_ops", "//tensorflow/python:training", - "//tensorflow/python/estimator:run_config", + "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", ], diff --git a/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc b/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc index bb9b835889b1b5e36d6f470b51834d4c6bb3d493..7fcae5ad8e1536530e2d039e1d14df4e192c4fa3 100644 --- a/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc +++ b/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc @@ -62,10 +62,11 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { public: explicit WALSComputePartialLhsAndRhsOp(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, context->MatchSignature( - {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, - DT_INT64, DT_FLOAT, DT_INT64, DT_BOOL}, - {DT_FLOAT, DT_FLOAT})); + OP_REQUIRES_OK(context, + context->MatchSignature( + {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64, + DT_FLOAT, DT_FLOAT, DT_INT64, DT_BOOL}, + {DT_FLOAT, DT_FLOAT})); } void Compute(OpKernelContext* context) override { @@ -75,8 +76,9 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { const Tensor& input_weights = context->input(3); const Tensor& input_indices = context->input(4); const Tensor& input_values = context->input(5); - const Tensor& input_block_size = context->input(6); - const Tensor& input_is_transpose = context->input(7); + const Tensor& entry_weights = context->input(6); + const Tensor& input_block_size = context->input(7); + const Tensor& input_is_transpose = context->input(8); OP_REQUIRES(context, TensorShapeUtils::IsMatrix(factors.shape()), InvalidArgument("Input factors should be a matrix.")); @@ -89,13 +91,33 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { InvalidArgument("Input input_weights should be a vector.")); OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices.shape()), InvalidArgument("Input input_indices should be a matrix.")); + OP_REQUIRES( + context, input_indices.dim_size(1) == 2, + InvalidArgument("Input input_indices should have shape (?, 2).")); OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values.shape()), InvalidArgument("Input input_values should be a vector")); + OP_REQUIRES(context, TensorShapeUtils::IsVector(entry_weights.shape()), + InvalidArgument("Input entry_weights should be a vector")); + OP_REQUIRES(context, input_indices.dim_size(0) == input_values.dim_size(0), + InvalidArgument("Input input_values' length should match the " + "first dimension of Input input_indices ")); OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_block_size.shape()), InvalidArgument("Input input_block_size should be a scalar.")); OP_REQUIRES( context, TensorShapeUtils::IsScalar(input_is_transpose.shape()), InvalidArgument("Input input_is_transpose should be a scalar.")); + OP_REQUIRES( + context, + ((input_weights.dim_size(0) > 0 && + factor_weights.dim_size(0) == factors.dim_size(0) && + entry_weights.dim_size(0) == 0) || + (input_weights.dim_size(0) == 0 && factor_weights.dim_size(0) == 0 && + entry_weights.dim_size(0) == input_indices.dim_size(0))), + InvalidArgument("To specify the weights for observed entries, either " + "(1) entry_weights must be set or (2) input_weights " + "and factor_weights must be set, but not both.")); + // TODO(yifanchen): Deprecate the support of input_weights and + // factor_weights. const int64 factor_dim = factors.dim_size(1); const int64 factors_size = factors.dim_size(0); @@ -105,6 +127,7 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { const auto& input_weights_vec = input_weights.vec(); const float w_0 = unobserved_weights.scalar()(); const auto& input_values_vec = input_values.vec(); + const auto& entry_weights_vec = entry_weights.vec(); ConstEigenMatrixFloatMap factors_mat(factors.matrix().data(), factor_dim, factors_size); @@ -134,6 +157,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { return is_transpose ? indices_mat(0, i) : indices_mat(1, i); }; + const bool use_entry_weights = entry_weights_vec.size() > 0; + // TODO(rmlarsen): In principle, we should be using the SparseTensor class // and machinery for iterating over groups, but the fact that class // SparseTensor makes a complete copy of the matrix makes me reluctant to @@ -195,6 +220,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { // map using the hash of the thread id as the key. // // TODO(jpoulson): Switch to try_emplace once C++17 is supported + // TODO(b/72952120): Check whether the 3 lock-unlock pairs can be + // consolidated into just one. map_mutex.lock(); const auto key_count = factor_batch_map.count(id_hash); map_mutex.unlock(); @@ -213,6 +240,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { CHECK_LE(shard.second, perm.size()); CHECK_LE(shard.first, shard.second); const int64 input_index = get_input_index(perm[shard.first]); + const float input_weight = + use_entry_weights ? 1.0 : input_weights_vec(input_index); // Accumulate the rhs and lhs terms in the normal equations // for the non-zero elements in the row or column of the sparse matrix // corresponding to input_index. @@ -228,7 +257,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel { const int64 factor_index = get_factor_index(i); const float input_value = input_values_vec(i); const float weight = - input_weights_vec(input_index) * factor_weights_vec(factor_index); + use_entry_weights ? entry_weights_vec(i) + : input_weight * factor_weights_vec(factor_index); CHECK_GE(weight, 0); factor_batch.col(num_batched) = factors_mat.col(factor_index) * std::sqrt(weight); diff --git a/tensorflow/contrib/factorization/ops/factorization_ops.cc b/tensorflow/contrib/factorization/ops/factorization_ops.cc index 11ea36946e92769cd6901eb998a20148250ef7ce..1d31bd38c824f24e9a70c0f69da129f5ddc18985 100644 --- a/tensorflow/contrib/factorization/ops/factorization_ops.cc +++ b/tensorflow/contrib/factorization/ops/factorization_ops.cc @@ -25,20 +25,33 @@ REGISTER_OP("WALSComputePartialLhsAndRhs") .Input("input_weights: float32") .Input("input_indices: int64") .Input("input_values: float32") + .Input("entry_weights: float32") .Input("input_block_size: int64") .Input("input_is_transpose: bool") .Output("partial_lhs: float32") .Output("partial_rhs: float32") .SetShapeFn(shape_inference::UnknownShape) .Doc(R"( -Computes the partial left-hand side and right-hand side of WALS update. +Computes the partial left-hand side and right-hand side of WALS update. For +observed entry input_indices[i]=[m, n] with value input_values[i]=v, the weight +should be specified either through (1) entry_weights[i] or (2) through +input_weights[m] * factor_weights[n] (if input_is_transpose is false) or +input_weights[n] * factor_weights[m] (if input_is_transpose is true). Note it is +not allowed to have both (1) and (2) specified at the same time: when one +approach is used, the input tensors related to the other approach must be kept +completely empty. factors: Matrix of size m * k. -factor_weights: Vector of size m. Corresponds to column weights +factor_weights: Vector of size m. Corresponds to column weights. Should be empty + if entry_weights is used. unobserved_weights: Scalar. Weight for unobserved input entries. -input_weights: Vector of size n. Corresponds to row weights. +input_weights: Vector of size n. Corresponds to row weights. Should be empty if + entry_weights is used. input_indices: Indices for the input SparseTensor. input_values: Values for the input SparseTensor. +entry_weights: If not empty, this must be same length as input_vaues and is used + as the per-entry non-zero weight. If this is used, input_weights and + factor_weights must be empty. input_block_size: Scalar. Number of rows spanned by input. input_is_transpose: If true, logically transposes the input for processing. partial_lhs: 3-D tensor with size input_block_size x k x k. diff --git a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py index ba30fd997700f461b6afffa13cf371c598d3332e..6c2f1d46084d701beac1e3a99e3ad66bae57eda5 100644 --- a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py +++ b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py @@ -55,7 +55,41 @@ class WalsSolverOpsTest(test.TestCase): rhs_matrix] = gen_factorization_ops.wals_compute_partial_lhs_and_rhs( self._column_factors, self._column_weights, self._unobserved_weights, self._row_weights, sparse_block.indices, sparse_block.values, - sparse_block.dense_shape[0], False) + [], + input_block_size=sparse_block.dense_shape[0], + input_is_transpose=False) + self.assertAllClose(lhs_tensor.eval(), [[ + [0.014800, 0.017000, 0.019200], + [0.017000, 0.019600, 0.022200], + [0.019200, 0.022200, 0.025200], + ], [ + [0.0064000, 0.0080000, 0.0096000], + [0.0080000, 0.0100000, 0.0120000], + [0.0096000, 0.0120000, 0.0144000], + ], [ + [0.0099000, 0.0126000, 0.0153000], + [0.0126000, 0.0162000, 0.0198000], + [0.0153000, 0.0198000, 0.0243000], + ], [ + [0.058800, 0.067200, 0.075600], + [0.067200, 0.076800, 0.086400], + [0.075600, 0.086400, 0.097200], + ]]) + self.assertAllClose(rhs_matrix.eval(), [[0.019300, 0.023000, 0.026700], + [0.061600, 0.077000, 0.092400], + [0.160400, 0.220000, 0.279600], + [0.492800, 0.563200, 0.633600]]) + + def testWalsSolverLhsEntryWeights(self): + sparse_block = SparseBlock3x3() + with self.test_session(): + [lhs_tensor, + rhs_matrix] = gen_factorization_ops.wals_compute_partial_lhs_and_rhs( + self._column_factors, [], self._unobserved_weights, + [], sparse_block.indices, sparse_block.values, + [0.01, 0.03, 0.04, 0.03, 0.06, 0.12], + input_block_size=sparse_block.dense_shape[0], + input_is_transpose=False) self.assertAllClose(lhs_tensor.eval(), [[ [0.014800, 0.017000, 0.019200], [0.017000, 0.019600, 0.022200], diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py index 8f73274c2a0ebbdc41ce6a647a8a5650694c9a23..7ab70fbcfd7324961b61526a08daab7e393630e9 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py @@ -943,6 +943,7 @@ class WALSModel(object): row_weights_slice, new_sp_input.indices, new_sp_input.values, + [], num_rows, transpose_input, name="wals_compute_partial_lhs_rhs")) diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index 9ffdd3ba5e8ac496533d0207f2b6846dbc92bc89..f384d761a8430074f022c973d7ec3d46cd90f70b 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -158,12 +158,12 @@ class _ModelFn(object): return either `features` or, equivalently, `(features, None)`. Args: - features: The input points. See @{tf.estimator.Estimator}. - mode: See @{tf.estimator.Estimator}. - config: See @{tf.estimator.Estimator}. + features: The input points. See `tf.estimator.Estimator`. + mode: See `tf.estimator.Estimator`. + config: See `tf.estimator.Estimator`. Returns: - A @{tf.estimator.EstimatorSpec} (see @{tf.estimator.Estimator}) specifying + A `tf.estimator.EstimatorSpec` (see `tf.estimator.Estimator`) specifying this behavior: * `train_op`: Execute one mini-batch or full-batch run of Lloyd's algorithm. @@ -188,7 +188,6 @@ class _ModelFn(object): # center. # is_initialized: scalar indicating whether the initial cluster centers # have been chosen; see init_op. - # cluster_centers_var: a Variable containing the cluster centers. # init_op: an op to choose the initial cluster centers. A single worker # repeatedly executes init_op until is_initialized becomes True. # training_op: an op that runs an iteration of training, either an entire @@ -394,7 +393,7 @@ class KMeansClustering(estimator.Estimator): relative_tolerance: A relative tolerance of change in the loss between iterations. Stops learning if the loss changes less than this amount. This may not work correctly if `use_mini_batch=True`. - config: See @{tf.estimator.Estimator}. + config: See `tf.estimator.Estimator`. feature_columns: An optionable iterable containing all the feature columns used by the model. All items in the set should be feature column instances that can be passed to `tf.feature_column.input_layer`. If this @@ -431,7 +430,7 @@ class KMeansClustering(estimator.Estimator): """Finds the index of the closest cluster center to each input point. Args: - input_fn: Input points. See @{tf.estimator.Estimator.predict}. + input_fn: Input points. See `tf.estimator.Estimator.predict`. Yields: The index of the closest cluster center for each input point. @@ -447,7 +446,7 @@ class KMeansClustering(estimator.Estimator): which returns the negative sum. Args: - input_fn: Input points. See @{tf.estimator.Estimator.evaluate}. Only one + input_fn: Input points. See `tf.estimator.Estimator.evaluate`. Only one batch is retrieved. Returns: @@ -465,7 +464,7 @@ class KMeansClustering(estimator.Estimator): sklearn function returns the Euclidean distance. Args: - input_fn: Input points. See @{tf.estimator.Estimator.predict}. + input_fn: Input points. See `tf.estimator.Estimator.predict`. Yields: The distances from each input point to each cluster center. diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index b588f75efe9d0bbf8213a89978a627c0a0ccf554..05bcdac2caa77062f9a8a44a948d2897b439ea1f 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -95,7 +95,7 @@ def sequence_input_layer( Raises: ValueError: If any of the `feature_columns` is the wrong type. """ - feature_columns = fc._clean_feature_columns(feature_columns) + feature_columns = fc._normalize_feature_columns(feature_columns) for c in feature_columns: if not isinstance(c, fc._SequenceDenseColumn): raise ValueError( diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index 89b5f4c4137f6c42417f539a578fd8b11f8b235d..45d7b740462ca21139e2e93e34b43668f1e08a94 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -110,7 +110,7 @@ class SequenceInputLayerTest(test.TestCase): expected_sequence_length, sequence_length.eval(session=sess)) def test_embedding_column_with_non_sequence_categorical(self): - """Tests that error is raised for non-sequence categorical column.""" + """Tests that error is raised for non-sequence embedding column.""" vocabulary_size = 3 sparse_input = sparse_tensor.SparseTensorValue( # example 0, ids [2] @@ -132,6 +132,107 @@ class SequenceInputLayerTest(test.TestCase): features={'aaa': sparse_input}, feature_columns=[embedding_column_a]) + def test_shared_embedding_column(self): + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [2, 0] + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 4.), # id 1 + (5., 6.) # id 2 + ) + + def _get_initializer(embedding_dimension, embedding_values): + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + return _initializer + + expected_input_layer = [ + # example 0, ids_a [2], ids_b [1] + [[5., 6., 3., 4.], [0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [2, 0] + [[1., 2., 5., 6.], [3., 4., 1., 2.]], + ] + expected_sequence_length = [1, 2] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + # Test that columns are reordered alphabetically. + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_b, categorical_column_a], + dimension=embedding_dimension, + initializer=_get_initializer(embedding_dimension, embedding_values)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + feature_columns=shared_embedding_columns) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('sequence_input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), + tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_shared_embedding_column_with_non_sequence_categorical(self): + """Tests that error is raised for non-sequence shared embedding column.""" + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + + with self.assertRaisesRegexp( + ValueError, + r'In embedding_column: aaa_shared_embedding\. categorical_column must ' + r'be of type _SequenceCategoricalColumn to use sequence_input_layer\.'): + _, _ = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b + }, + feature_columns=shared_embedding_columns) + def test_indicator_column(self): vocabulary_size_a = 3 sparse_input_a = sparse_tensor.SparseTensorValue( @@ -578,6 +679,182 @@ class SequenceEmbeddingColumnTest(test.TestCase): expected_sequence_length, sequence_length.eval(session=sess)) +class SequenceSharedEmbeddingColumnTest(test.TestCase): + + def test_get_sequence_dense_tensor(self): + vocabulary_size = 3 + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 1), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [0, 2] + # example 2, ids [0] + # example 3, ids [] + indices=((0, 0), (1, 0), (1, 1), (2, 0)), + values=(1, 0, 2, 0), + dense_shape=(4, 2)) + + expected_lookups_a = [ + # example 0, ids [2] + [[7., 11.], [0., 0.]], + # example 1, ids [0, 1] + [[1., 2.], [3., 5.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [1] + [[3., 5.], [0., 0.]], + ] + + expected_lookups_b = [ + # example 0, ids [1] + [[3., 5.], [0., 0.]], + # example 1, ids [0, 2] + [[1., 2.], [7., 11.]], + # example 2, ids [0] + [[1., 2.], [0., 0.]], + # example 3, ids [] + [[0., 0.], [0., 0.]], + ] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer) + + embedding_lookup_a = shared_embedding_columns[0]._get_sequence_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input_a + }))[0] + embedding_lookup_b = shared_embedding_columns[1]._get_sequence_dense_tensor( + _LazyBuilder({ + 'bbb': sparse_input_b + }))[0] + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual(('embedding_weights:0',), + tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual( + expected_lookups_a, embedding_lookup_a.eval(session=sess)) + self.assertAllEqual( + expected_lookups_b, embedding_lookup_b.eval(session=sess)) + + def test_sequence_length(self): + vocabulary_size = 3 + + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + expected_sequence_length_a = [1, 2] + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [0, 2] + # example 1, ids [1] + indices=((0, 0), (0, 1), (1, 0)), + values=(0, 2, 1), + dense_shape=(2, 2)) + expected_sequence_length_b = [2, 1] + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + + sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input_a + }))[1] + sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( + _LazyBuilder({ + 'bbb': sparse_input_b + }))[1] + + with monitored_session.MonitoredSession() as sess: + sequence_length_a = sess.run(sequence_length_a) + self.assertAllEqual(expected_sequence_length_a, sequence_length_a) + self.assertEqual(np.int64, sequence_length_a.dtype) + sequence_length_b = sess.run(sequence_length_b) + self.assertAllEqual(expected_sequence_length_b, sequence_length_b) + self.assertEqual(np.int64, sequence_length_b.dtype) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length_a = [0, 1, 2, 0, 1, 0] + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [] + # example 2, ids [] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [0, 1] + indices=((0, 0), (4, 0), (5, 0), (5, 1)), + values=(2, 1, 0, 1), + dense_shape=(6, 2)) + expected_sequence_length_b = [1, 0, 0, 0, 1, 2] + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + + sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input_a + }))[1] + sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( + _LazyBuilder({ + 'bbb': sparse_input_b + }))[1] + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length_a, sequence_length_a.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length_b, sequence_length_b.eval(session=sess)) + + class SequenceIndicatorColumnTest(test.TestCase): def test_get_sequence_dense_tensor(self): diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index dc49383c5c300e82839c478e097074b3e8776b3b..20d099fe5d49dac0caec4a28801f09e7bee4f2e2 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -100,6 +100,8 @@ See the @{$python/contrib.framework} guide. @@BoundedTensorSpec @@TensorSpec + +@@RecordInput """ from __future__ import absolute_import @@ -119,6 +121,7 @@ from tensorflow.python.framework.smart_cond import smart_cond from tensorflow.python.framework.smart_cond import smart_constant_value from tensorflow.python.framework.tensor_spec import BoundedTensorSpec from tensorflow.python.framework.tensor_spec import TensorSpec +from tensorflow.python.ops.data_flow_ops import RecordInput from tensorflow.python.ops.init_ops import convolutional_delta_orthogonal from tensorflow.python.ops.init_ops import convolutional_orthogonal_1d from tensorflow.python.ops.init_ops import convolutional_orthogonal_2d @@ -133,6 +136,7 @@ _nest_allowed_symbols = [ 'flatten_dict_items', 'pack_sequence_as', 'map_structure', + 'map_structure_with_paths', 'assert_shallow_structure', 'flatten_up_to', 'map_structure_up_to', diff --git a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py index 9e356dd96562c28adec7fc28fe144394e1c2ed38..e7184a01fbf57319399fc6dd287b7387138b4058 100644 --- a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py +++ b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py @@ -27,7 +27,7 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import training as train __all__ = [ @@ -40,7 +40,7 @@ __all__ = [ def _get_checkpoint_filename(filepattern): """Returns checkpoint filename given directory or specific filepattern.""" if gfile.IsDirectory(filepattern): - return saver.latest_checkpoint(filepattern) + return checkpoint_management.latest_checkpoint(filepattern) return filepattern diff --git a/tensorflow/contrib/framework/python/ops/arg_scope.py b/tensorflow/contrib/framework/python/ops/arg_scope.py index 5b150339953f961c756c0909dd1795341159b9cd..0a02e76a265c8ad25d978e7d610fb50fc0fdfdb1 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope.py @@ -103,9 +103,8 @@ def _kwarg_names(func): def _add_op(op): - key = arg_scope_func_key(op) - if key not in _DECORATED_OPS: - _DECORATED_OPS[key] = _kwarg_names(op) + key_op = arg_scope_func_key(op) + _DECORATED_OPS[key_op] = _kwarg_names(op) @tf_contextlib.contextmanager diff --git a/tensorflow/contrib/framework/python/ops/arg_scope_test.py b/tensorflow/contrib/framework/python/ops/arg_scope_test.py index 4c3879d4fc08b53ea8be5f1256a830a64fb39af6..bcafc1a3280ba0435f655eacb8173e4e97051154 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope_test.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope_test.py @@ -38,6 +38,12 @@ def func3(args, a=None, b=1, c=2): """Some cool doc string.""" return (args, a, b, c) +@add_arg_scope +def func4(x='x', y='y'): + if x: + pass + if y: + pass def _key_op(op): return getattr(op, '_key_op', str(op)) @@ -231,6 +237,15 @@ class ArgScopeTest(test.TestCase): self.assertTupleEqual(args, func2_args) self.assertDictEqual(kwargs, func2_kwargs) + def testAddArgScopeRaceCondition(self): + func4_kwargs = ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h') + for i in range(4): + # redefine the function with different args + @add_arg_scope + def func4(a=1, b=2, c=3, d=4, e=5, f=6, g=7, h=8): + pass + self.assertTupleEqual(arg_scoped_arguments(func4), func4_kwargs) + def testDocString(self): self.assertEqual(func3.__doc__, 'Some cool doc string.') diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py index df7d7e9dae80722569efccbc9cc0d1b75e90cf03..34fd5018af125335845540dedfdffc984ba02313 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_test.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py @@ -34,7 +34,7 @@ from tensorflow.python.platform import tf_logging as logging class CriticalSectionTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCreateCriticalSection(self): cs = critical_section_ops.CriticalSection(shared_name="cs") v = resource_variable_ops.ResourceVariable(0.0, name="v") @@ -53,7 +53,7 @@ class CriticalSectionTest(test.TestCase): self.assertAllClose([2.0 * i for i in range(num_concurrent)], sorted(r_value)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCriticalSectionWithControlFlow(self): for outer_cond in [False, True]: for inner_cond in [False, True]: @@ -109,7 +109,7 @@ class CriticalSectionTest(test.TestCase): with self.assertRaisesOpError("Error"): self.evaluate(r) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCreateCriticalSectionFnReturnsOp(self): cs = critical_section_ops.CriticalSection(shared_name="cs") v = resource_variable_ops.ResourceVariable(0.0, name="v") @@ -332,7 +332,7 @@ class CriticalSectionTest(test.TestCase): self.evaluate(v.initializer) self.assertEqual(10, self.evaluate(out)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInsideFunction(self): cs = critical_section_ops.CriticalSection() v = resource_variable_ops.ResourceVariable(1) diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index 40ae01bfcce1dde580e6a5f6d9c8ec1aa1abb83f..a7acae804a0c71cc19757a48d47fd9cf9022b0e2 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -34,6 +34,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables from tensorflow.python.platform import resource_loader from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver as tf_saver @@ -199,10 +200,20 @@ def global_variable(initial_value, @contrib_add_arg_scope -def variable(name, shape=None, dtype=None, initializer=None, - regularizer=None, trainable=True, collections=None, - caching_device=None, device=None, - partitioner=None, custom_getter=None, use_resource=None): +def variable(name, + shape=None, + dtype=None, + initializer=None, + regularizer=None, + trainable=True, + collections=None, + caching_device=None, + device=None, + partitioner=None, + custom_getter=None, + use_resource=None, + synchronization=variables.VariableSynchronization.AUTO, + aggregation=variables.VariableAggregation.NONE): """Gets an existing variable with these parameters or creates a new one. Args: @@ -228,6 +239,15 @@ def variable(name, shape=None, dtype=None, initializer=None, custom_getter: Callable that allows overwriting the internal get_variable method and has to have the same signature. use_resource: If `True` use a ResourceVariable instead of a Variable. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + `tf.VariableSynchronization`. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. If `synchronization` is set to `ON_READ`, + `trainable` must not be set to `True`. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + `tf.VariableAggregation`. Returns: The created or existing variable. @@ -242,21 +262,36 @@ def variable(name, shape=None, dtype=None, initializer=None, getter = functools.partial(custom_getter, reuse=variable_scope.get_variable_scope().reuse) with ops.device(device or ''): - return getter(name, shape=shape, dtype=dtype, - initializer=initializer, - regularizer=regularizer, - trainable=trainable, - collections=collections, - caching_device=caching_device, - partitioner=partitioner, - use_resource=use_resource) + return getter( + name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + trainable=trainable, + collections=collections, + caching_device=caching_device, + partitioner=partitioner, + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation) @contrib_add_arg_scope -def model_variable(name, shape=None, dtype=dtypes.float32, initializer=None, - regularizer=None, trainable=True, collections=None, - caching_device=None, device=None, partitioner=None, - custom_getter=None, use_resource=None): +def model_variable(name, + shape=None, + dtype=dtypes.float32, + initializer=None, + regularizer=None, + trainable=True, + collections=None, + caching_device=None, + device=None, + partitioner=None, + custom_getter=None, + use_resource=None, + synchronization=variables.VariableSynchronization.AUTO, + aggregation=variables.VariableAggregation.NONE): """Gets an existing model variable with these parameters or creates a new one. Args: @@ -283,18 +318,36 @@ def model_variable(name, shape=None, dtype=dtypes.float32, initializer=None, custom_getter: Callable that allows overwriting the internal get_variable method and has to have the same signature. use_resource: If `True` use a ResourceVariable instead of a Variable. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + `tf.VariableSynchronization`. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. If `synchronization` is set to `ON_READ`, + `trainable` must not be set to `True`. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + `tf.VariableAggregation`. Returns: The created or existing variable. """ collections = list(collections or []) collections += [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES] - var = variable(name, shape=shape, dtype=dtype, - initializer=initializer, regularizer=regularizer, - trainable=trainable, collections=collections, - caching_device=caching_device, device=device, - partitioner=partitioner, custom_getter=custom_getter, - use_resource=use_resource) + var = variable( + name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + trainable=trainable, + collections=collections, + caching_device=caching_device, + device=device, + partitioner=partitioner, + custom_getter=custom_getter, + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation) return var @@ -712,7 +765,8 @@ class VariableDeviceChooser(object): num_tasks=0, job_name='ps', device_type='CPU', - device_index=0): + device_index=0, + replica=None): """Initialize VariableDeviceChooser. Usage: @@ -733,12 +787,15 @@ class VariableDeviceChooser(object): self._job_name = job_name self._device_type = device_type self._device_index = device_index + self._replica = replica self._num_tasks = num_tasks self._next_task_id = 0 def __call__(self, op): - device_spec = tf_device.DeviceSpec(device_type=self._device_type, - device_index=self._device_index) + device_spec = tf_device.DeviceSpec( + replica=self._replica, + device_type=self._device_type, + device_index=self._device_index) if self._num_tasks > 0: task_id = self._next_task_id self._next_task_id = (self._next_task_id + 1) % self._num_tasks diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index 37ea6eb12aba7d25656f19cbbc86475c1228d916..3c44630a51deb8a468165e8da458600665d0ada1 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -106,8 +106,9 @@ class LocalVariableTest(test.TestCase): def testResourceVariable(self): a = variables_lib2.local_variable(0) b = variables_lib2.local_variable(0, use_resource=True) - self.assertEqual(type(a), variables_lib.Variable) - self.assertEqual(type(b), resource_variable_ops.ResourceVariable) + self.assertTrue(isinstance(a, variables_lib.Variable)) + self.assertFalse(isinstance(a, resource_variable_ops.ResourceVariable)) + self.assertTrue(isinstance(b, resource_variable_ops.ResourceVariable)) class GlobalVariableTest(test.TestCase): @@ -176,8 +177,9 @@ class GlobalVariableTest(test.TestCase): def testResourceVariable(self): a = variables_lib2.global_variable(0) b = variables_lib2.global_variable(0, use_resource=True) - self.assertEqual(type(a), variables_lib.Variable) - self.assertEqual(type(b), resource_variable_ops.ResourceVariable) + self.assertTrue(isinstance(a, variables_lib.Variable)) + self.assertFalse(isinstance(a, resource_variable_ops.ResourceVariable)) + self.assertTrue(isinstance(b, resource_variable_ops.ResourceVariable)) class GlobalStepTest(test.TestCase): @@ -506,6 +508,35 @@ class VariablesTest(test.TestCase): self.assertDeviceEqual(e.device, '/job:ps/task:1/cpu:0') self.assertDeviceEqual(e.initial_value.device, '/cpu:99') + def testVariableWithVariableDeviceChooserWithReplica(self): + + with ops.Graph().as_default(): + device_fn = variables_lib2.VariableDeviceChooser(replica=3, num_tasks=2) + with arg_scope([variables_lib2.variable], device=device_fn): + a = variables_lib2.variable('a', []) + b = variables_lib2.variable('b', []) + c = variables_lib2.variable('c', [], device='cpu:12') + d = variables_lib2.variable('d', []) + with ops.device('cpu:99'): + e_init = constant_op.constant(12) + e = variables_lib2.variable('e', initializer=e_init) + # The values below highlight how the VariableDeviceChooser puts initial + # values on the same device as the variable job. + self.assertDeviceEqual(a.device, '/job:ps/replica:3/task:0/cpu:0') + self.assertEqual(a.initial_value.op.colocation_groups(), + a.op.colocation_groups()) + self.assertDeviceEqual(b.device, '/job:ps/replica:3/task:1/cpu:0') + self.assertEqual(b.initial_value.op.colocation_groups(), + b.op.colocation_groups()) + self.assertDeviceEqual(c.device, '/cpu:12') + self.assertEqual(c.initial_value.op.colocation_groups(), + c.op.colocation_groups()) + self.assertDeviceEqual(d.device, '/job:ps/replica:3/task:0/cpu:0') + self.assertEqual(d.initial_value.op.colocation_groups(), + d.op.colocation_groups()) + self.assertDeviceEqual(e.device, '/job:ps/replica:3/task:1/cpu:0') + self.assertDeviceEqual(e.initial_value.device, '/cpu:99') + def testVariableGPUPlacement(self): with ops.Graph().as_default(): @@ -930,8 +961,8 @@ class AssignFromCheckpointTest(test.TestCase): return saver.save(sess, checkpoint_dir, global_step=global_step) def testLoadExistingVariables(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'load_existing_variables')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'load_existing_variables')) init_value0 = 10.0 init_value1 = 20.0 @@ -944,8 +975,8 @@ class AssignFromCheckpointTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1} - op, feed_dict = variables_lib2.assign_from_checkpoint(model_path, - vars_to_restore) + op, feed_dict = variables_lib2.assign_from_checkpoint( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -960,8 +991,8 @@ class AssignFromCheckpointTest(test.TestCase): # Tests restoring PartitionedVariables and tests using a dictionary # of lists as the assign_from_checkpoint() var_list param. def testLoadPartitionedVariables(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'load_partitioned_variables')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'load_partitioned_variables')) init_value0 = np.array([[10.0, 11.0], [12.0, 13.0]]) init_value1 = np.array([20.0]) # Partitioned into 1 part, edge case. @@ -974,15 +1005,14 @@ class AssignFromCheckpointTest(test.TestCase): partitioner = partitioned_variables.variable_axis_size_partitioner(2) var0 = variables_lib2.variable( 'var0', shape=init_value0.shape, partitioner=partitioner) - var0full = variables_lib2.variable( - 'var0full', shape=init_value0.shape) + var0full = variables_lib2.variable('var0full', shape=init_value0.shape) var1 = variables_lib2.variable( 'var1', shape=init_value1.shape, partitioner=partitioner) # Convert var0 and var1 into a list of underlying variables. vars_to_restore = {'var0': list(var0) + [var0full], 'var1': list(var1)} - op, feed_dict = variables_lib2.assign_from_checkpoint(model_path, - vars_to_restore) + op, feed_dict = variables_lib2.assign_from_checkpoint( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -992,16 +1022,18 @@ class AssignFromCheckpointTest(test.TestCase): # Request and test the variable values. PartitionedVariables can't # be evaled so we wrap them in an identity. - self.assertTrue(np.array_equal( - init_value0, array_ops.identity(var0).eval())) - self.assertTrue(np.array_equal( - init_value0, var0full.eval())) - self.assertTrue(np.array_equal( - init_value1, array_ops.identity(var1).eval())) + self.assertTrue( + np.array_equal(init_value0, + array_ops.identity(var0).eval())) + self.assertTrue(np.array_equal(init_value0, var0full.eval())) + self.assertTrue( + np.array_equal(init_value1, + array_ops.identity(var1).eval())) def testRaisesValueErrorIfAVariableIsntFound(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'raises_value_error_if_var_isnt_found')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), + 'raises_value_error_if_var_isnt_found')) init_value0 = 10.0 init_value1 = 20.0 @@ -1019,8 +1051,9 @@ class AssignFromCheckpointTest(test.TestCase): variables_lib2.assign_from_checkpoint(model_path, vars_to_restore) def testInitFromCheckpointWithScopes(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'init_from_checkpoint_with_scopes')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), + 'init_from_checkpoint_with_scopes')) init_value0 = np.asarray( [1.0, 3.0, 9.0], dtype=np.float32).reshape((1, 3, 1)) @@ -1038,8 +1071,8 @@ class AssignFromCheckpointTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=init_value1.shape) vars_to_restore = {'layer0/v0': var0, 'layer1/v1': var1} - op, feed_dict = variables_lib2.assign_from_checkpoint(model_path, - vars_to_restore) + op, feed_dict = variables_lib2.assign_from_checkpoint( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -1081,8 +1114,8 @@ class AssignFromCheckpointFnTest(test.TestCase): return saver.save(sess, checkpoint_dir, global_step=global_step) def testLoadExistingVariables(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'load_existing_variables')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'load_existing_variables')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1097,8 +1130,8 @@ class AssignFromCheckpointFnTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1} - init_fn = variables_lib2.assign_from_checkpoint_fn(model_path, - vars_to_restore) + init_fn = variables_lib2.assign_from_checkpoint_fn( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -1111,8 +1144,9 @@ class AssignFromCheckpointFnTest(test.TestCase): self.assertEqual(init_value1, var1.eval()) def testLoadExistingVariablesDifferentShapeDefaultDoesNotAllowReshape(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'load_existing_vars_no_reshape')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), + 'load_existing_vars_no_reshape')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1127,8 +1161,8 @@ class AssignFromCheckpointFnTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1} - init_fn = variables_lib2.assign_from_checkpoint_fn(model_path, - vars_to_restore) + init_fn = variables_lib2.assign_from_checkpoint_fn( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -1138,9 +1172,10 @@ class AssignFromCheckpointFnTest(test.TestCase): init_fn(sess) def testLoadExistingVariablesDifferentShapeAllowReshape(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), - 'load_existing_variables_different_shape_allow_reshape')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join( + self.get_temp_dir(), + 'load_existing_variables_different_shape_allow_reshape')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1169,8 +1204,8 @@ class AssignFromCheckpointFnTest(test.TestCase): self.assertEqual(init_value1, var1.eval()) def testNotFoundError(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'not_found_error')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'not_found_error')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1186,8 +1221,8 @@ class AssignFromCheckpointFnTest(test.TestCase): var2 = variables_lib2.variable('my_var2', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1, 'v2': var2} - init_fn = variables_lib2.assign_from_checkpoint_fn(model_path, - vars_to_restore) + init_fn = variables_lib2.assign_from_checkpoint_fn( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -1197,8 +1232,8 @@ class AssignFromCheckpointFnTest(test.TestCase): init_fn(sess) def testMissingVariablesList(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'missing_variables_list')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'missing_variables_list')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1228,8 +1263,8 @@ class AssignFromCheckpointFnTest(test.TestCase): self.assertEqual(init_value1, var1.eval()) def testMissingVariablesDict(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'missing_variables_dict')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'missing_variables_dict')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1279,9 +1314,8 @@ class ZeroInitializerOpTest(test.TestCase): def testZeroInitializer(self): for dtype in (dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64): for use_init in (False, True): - self._testZeroInitializer( - [10, 20], array_ops.ones( - [10, 20], dtype=dtype), use_init) + self._testZeroInitializer([10, 20], array_ops.ones( + [10, 20], dtype=dtype), use_init) class ZeroVarInitializerOpTest(test.TestCase): diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 2458f7554afdc12709571c551a8323cda7fa5c17..0ccb4583ab653bc2ef6c5c810c902a9332e82df9 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -135,9 +135,12 @@ class FusedConv2DBiasActivationOp : public OpKernel { context->GetAttr("activation_mode", &activation_mode_str)); OP_REQUIRES_OK(context, GetActivationModeFromString(activation_mode_str, &activation_mode_)); - OP_REQUIRES(context, activation_mode_ == ActivationMode::RELU, - errors::InvalidArgument("Current implementation only supports " - "RELU as the activation function.")); + OP_REQUIRES(context, + activation_mode_ == ActivationMode::RELU || + activation_mode_ == ActivationMode::NONE, + errors::InvalidArgument( + "Current implementation only supports RELU or NONE " + "as the activation function.")); cudnn_use_autotune_ = CudnnUseAutotune(); } @@ -440,6 +443,8 @@ void LaunchFusedConv2DBiasActivationOp:: : dnn::DataLayout::kBatchDepthYX; constexpr auto filter_layout = is_int8x4 ? dnn::FilterLayout::kOutputInputYX4 : dnn::FilterLayout::kOutputInputYX; + constexpr auto compute_data_format = + is_int8x4 ? FORMAT_NCHW_VECT_C : FORMAT_NCHW; dnn::BatchDescriptor conv_input_desc; conv_input_desc.set_count(batch_size) @@ -526,6 +531,7 @@ void LaunchFusedConv2DBiasActivationOp:: batch_size, conv_input_depth, {{conv_input_rows, conv_input_cols}}, + compute_data_format, output_depth, {{filter_rows, filter_cols}}, // TODO(yangzihao): Add support for arbitrary dilations for fused conv. @@ -538,6 +544,18 @@ void LaunchFusedConv2DBiasActivationOp:: activation_mode, }; + dnn::ActivationMode dnn_activation_mode; + switch (activation_mode) { + case ActivationMode::NONE: + dnn_activation_mode = dnn::ActivationMode::kNone; + break; + case ActivationMode::RELU: + dnn_activation_mode = dnn::ActivationMode::kRelu; + break; + default: + LOG(FATAL) << "Activation mode " << activation_mode << " not supported"; + } + dnn::AlgorithmConfig algorithm_config; if (cudnn_use_autotune && !AutoTuneConvBiasActivation::GetInstance()->Find( fused_conv_parameters, &algorithm_config)) { @@ -558,10 +576,9 @@ void LaunchFusedConv2DBiasActivationOp:: ->ThenFusedConvolveWithAlgorithm( conv_input_desc, conv_input_ptr, conv_input_scale, filter_desc, filter_ptr, conv_desc, side_input_ptr, - side_input_scale, bias_desc, bias_ptr, - dnn::ActivationMode::kRelu, output_desc, &output_ptr, - &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm), - &profile_result) + side_input_scale, bias_desc, bias_ptr, dnn_activation_mode, + output_desc, &output_ptr, &scratch_allocator, + dnn::AlgorithmConfig(profile_algorithm), &profile_result) .ok(); if (cudnn_launch_status) { if (profile_result.is_valid()) { @@ -597,7 +614,7 @@ void LaunchFusedConv2DBiasActivationOp:: ->ThenFusedConvolveWithAlgorithm( conv_input_desc, conv_input_ptr, conv_input_scale, filter_desc, filter_ptr, conv_desc, side_input_ptr, side_input_scale, - bias_desc, bias_ptr, dnn::ActivationMode::kRelu, output_desc, + bias_desc, bias_ptr, dnn_activation_mode, output_desc, &output_ptr, &scratch_allocator, algorithm_config, /*output_profile_result=*/nullptr) .ok(); diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h index ba52697679dafc239b1dac5562573b3589877a8c..b9c131a2e91469c52931080d8a5af90247bd16f0 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h @@ -29,13 +29,13 @@ namespace tensorflow { class FusedConvParameters : public ConvParameters { public: FusedConvParameters(int64 batch, int64 in_depths, const SpatialArray& in, - int64 out_depths, const SpatialArray& filter, - const SpatialArray& dilation, const SpatialArray& stride, - const SpatialArray& padding, DataType dtype, - int device_id, bool has_side_input, + TensorFormat data_format, int64 out_depths, + const SpatialArray& filter, const SpatialArray& dilation, + const SpatialArray& stride, const SpatialArray& padding, + DataType dtype, int device_id, bool has_side_input, ActivationMode activation_mode) - : ConvParameters(batch, in_depths, in, out_depths, filter, dilation, - stride, padding, dtype, device_id), + : ConvParameters(batch, in_depths, in, data_format, out_depths, filter, + dilation, stride, padding, dtype, device_id), activation_mode_(activation_mode), has_side_input_(has_side_input) { hash_code_ = Hash64Combine(hash_code_, has_side_input); diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc index bafd1d59418f0ba47ebbdaabbf06f8e5471fc1a1..410571f3783263152fda93980580182eb666886d 100644 --- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc @@ -44,7 +44,7 @@ REGISTER_OP("FusedConv2DBiasActivation") .Attr(GetPaddingAttrString()) .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") .Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'") - .Attr("activation_mode: {'Relu'} = 'Relu'") + .Attr("activation_mode: {'Relu', 'None'} = 'Relu'") .Attr("dilations: list(int) = [1, 1, 1, 1]") .SetShapeFn([](shape_inference::InferenceContext* c) { using shape_inference::ShapeHandle; @@ -144,7 +144,7 @@ REGISTER_OP("FusedConv2DBiasActivation") `qint8 [ output_channels, input_channels / 4, kernel_height, kernel_width, input_channels % 4 ]` activation_mode: The activation applied to the output. - Currently must be "Relu". + Must be "Relu" or "None". dilations: 1-D tensor of length 4. The dilation factor for each dimension of `input`. If set to k > 1, there will be k-1 skipped cells between each filter element on that dimension. The dimension order is determined diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py index 983b6dc8e5a1512ba81ecbc8d5ca5adaea09afe4..cdc07b935dcc42ce3c0cef6bb8f4a126fe82c883 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py @@ -66,8 +66,10 @@ def fused_conv2d_bias_activation(conv_input, This is optional and defaults to 0. side_input: A `Tensor` of the format specified by `data_format`. This is useful for implementing ResNet blocks. - activation_mode: (optional) currently must be the default "Relu". - Note that in qint8 mode, it also clips to 127, so acts like ReluX. + activation_mode: (optional) currently supports the default "Relu", or + "None" activation function. + Note: in qint8 mode, "None" actually clips to the range [-128, 127], + while "Relu" clips to the range [0, 127]. data_format: Specifies the data format. Possible values are: "NHWC" float [batch, height, width, channels] diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py index a955e21b72e765f751318c7927f9644481fe7933..0185ef662c2ed05b1ceaf0e3e8071bad4c0d1a0a 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py @@ -21,8 +21,6 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op -from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -35,13 +33,6 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging -def NoMemoryOptimizationConfig(): - config = config_pb2.ConfigProto() - config.graph_options.rewrite_options.memory_optimization = ( - rewriter_config_pb2.RewriterConfig.OFF) - return config - - def GetShrunkInceptionShapes(shrink=10): """Iterator for smaller versions of convolution shapes in 2015 Inception. @@ -202,8 +193,7 @@ class FusedConv2DBiasActivationTest(test.TestCase): # This is to guarantee that there is always negative values after # bias add so that we can test whether relu works correctly. x3 = bias - # TODO(b/79323979): re-enable memory optimization after this bug is fixed. - with self.test_session(use_gpu=True, config=NoMemoryOptimizationConfig()): + with self.test_session(use_gpu=True): t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype) t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype) fused_t2 = t2 @@ -251,9 +241,7 @@ class FusedConv2DBiasActivationTest(test.TestCase): x3 = np.random.rand(*[filter_in_sizes[-1]]).astype(np.float32) def _SetupVal(data_format, use_gpu): - # TODO(b/79323979): re-enable memory optimization after this bug is fixed. - with self.test_session( - use_gpu=use_gpu, config=NoMemoryOptimizationConfig()): + with self.test_session(use_gpu=use_gpu): t1 = constant_op.constant(x1, shape=tensor_in_sizes) t2 = constant_op.constant(x2, shape=filter_in_sizes) t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]]) @@ -634,7 +622,7 @@ def HwioToOihw(in_tensor): def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel, padding, strides, side_input_scale, - side_input, biases): + side_input, biases, apply_relu): """Simulates the int8 fused 2-D convolution op using separate float ops. The arguments and return values have the same format, meanings and @@ -648,6 +636,9 @@ def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel, side_input_scale: A scalar 'float'. side_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout. biases: A `Tensor` of type `float32` in NCHW layout. + apply_relu: A boolean to specify whether to apply "Relu" activation function + that clips outputs to the range [0, 127], or "None" activation that clips + to the range [-128, 127]. Returns: A `Tensor` of type `qint8` in NCHW_VECT_C layout. """ @@ -661,10 +652,12 @@ def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel, conv_and_side_inputs = conv_result + side_input_scale * NchwVectCToNchw( gen_array_ops.dequantize(side_input, -128, 127)) - logit = nn_ops.bias_add(conv_and_side_inputs, biases, data_format="NCHW") + output = nn_ops.bias_add(conv_and_side_inputs, biases, data_format="NCHW") + if apply_relu: + output = nn_ops.relu(output) result, _, _ = gen_array_ops.quantize_v2( - NchwToNchwVectC(nn_ops.relu(logit)), -128, 127, dtypes.qint8) + NchwToNchwVectC(output), -128, 127, dtypes.qint8) return result @@ -807,7 +800,7 @@ class FusedConvInt8Tests(test.TestCase): }, ] - def runTest(self, test_param): + def runTest(self, test_param, apply_relu): batch_size = test_param["batch_size"] input_channels = test_param["input_channels"] output_channels = test_param["output_channels"] @@ -843,8 +836,8 @@ class FusedConvInt8Tests(test.TestCase): vertical_stride, padding_type) output_width = CalculateConvolvedOutputDim(input_width, filter_width, horizontal_stride, padding_type) - tf_logging.info("output_height=", output_height, ", output_width=", - output_width) + tf_logging.info("output_height=", output_height, ", output_width=", + output_width) side_input, _, _ = gen_array_ops.quantize_v2( random_ops.random_uniform( @@ -870,16 +863,15 @@ class FusedConvInt8Tests(test.TestCase): conv_input_scale=conv_input_scale, side_input_scale=side_input_scale, side_input=side_input, + activation_mode="Relu" if apply_relu else "None", data_format="NCHW_VECT_C", filter_format="OIHW_VECT_I") expected = SimulateFusedConv2dBiasActivationInt8( conv_input_scale, conv_input, kernel, padding_type, strides, - side_input_scale, side_input, biases) + side_input_scale, side_input, biases, apply_relu) - # TODO(b/79323979): re-enable memory optimization after this bug is fixed. - with self.test_session( - use_gpu=True, config=NoMemoryOptimizationConfig()) as sess: + with self.test_session(use_gpu=True) as sess: actual_y, expected_y = sess.run([actual, expected]) tf_logging.info("actual_y = ", actual_y) tf_logging.info("expected_y = ", expected_y) @@ -891,8 +883,9 @@ class FusedConvInt8Tests(test.TestCase): tf_logging.info("int8 test skipped because not run with --config=cuda or " "no GPUs with compute capability >= 6.1 are available.") return - for test_param in self._test_params: - self.runTest(test_param) + for apply_relu in [True, False]: + for test_param in self._test_params: + self.runTest(test_param, apply_relu) if __name__ == "__main__": diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index b305f37791d71f5a6edeada2bb710a2e5f23087d..9866fccfba3562221ea7fe845e860ab470e238a0 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -42,9 +42,12 @@ py_library( "//tensorflow/contrib/training:training_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", + "//tensorflow/python:random_ops", "//tensorflow/python:training", + "//tensorflow/python:training_util", "//tensorflow/python:variable_scope", "//tensorflow/python/ops/distributions", "//tensorflow/python/ops/losses", @@ -54,26 +57,31 @@ py_library( py_test( name = "train_test", srcs = ["python/train_test.py"], + shard_count = 50, srcs_version = "PY2AND3", tags = ["notsan"], deps = [ - ":features", ":namedtuples", + ":random_tensor_pool", ":train", "//tensorflow/contrib/framework:framework_py", + "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/slim:learning", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:random_seed", "//tensorflow/python:training", + "//tensorflow/python:training_util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python/ops/distributions", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -188,10 +196,16 @@ py_test( srcs = ["python/losses/python/tuple_losses_test.py"], srcs_version = "PY2AND3", deps = [ + ":losses_impl", + ":namedtuples", ":tuple_losses", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", + "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//third_party/py/numpy", ], @@ -248,12 +262,15 @@ py_library( py_test( name = "random_tensor_pool_test", srcs = ["python/features/python/random_tensor_pool_test.py"], + shard_count = 6, srcs_version = "PY2AND3", deps = [ ":random_tensor_pool", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//third_party/py/numpy", ], ) @@ -344,9 +361,11 @@ py_library( "//tensorflow/python:image_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:platform", "//tensorflow/python:util", + "@six_archive//:six", ], ) @@ -405,9 +424,11 @@ py_library( ":namedtuples", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", + "//tensorflow/python:functional_ops", "//tensorflow/python:math_ops", "//tensorflow/python:summary", "//tensorflow/python:util", + "//tensorflow/python:variable_scope", "//tensorflow/python/ops/losses", ], ) @@ -440,8 +461,7 @@ py_library( ":train", "//tensorflow/python:framework_ops", "//tensorflow/python:util", - "//tensorflow/python/estimator:head", - "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:estimator_py", ], ) @@ -458,7 +478,7 @@ py_test( "//tensorflow/python:math_ops", "//tensorflow/python:training", "//tensorflow/python:variable_scope", - "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:estimator_py", ], ) @@ -470,16 +490,15 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - ":head", ":namedtuples", ":summaries", ":train", "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:framework_ops", + "//tensorflow/python:metrics", "//tensorflow/python:util", "//tensorflow/python:variable_scope", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:estimator_py", ], ) @@ -498,16 +517,18 @@ py_test( "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", "//tensorflow/python:parsing_ops", "//tensorflow/python:summary", "//tensorflow/python:training", - "//tensorflow/python/estimator:head", - "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python:training_util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:estimator_py", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 4092b320042162e4eb4c5f4879c2c3ea5dc14fc9..ab9886580d1648852e08f64cb3e9b51f679c25de 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -24,11 +24,11 @@ import enum from tensorflow.contrib.framework.python.ops import variables as variable_lib from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples from tensorflow.contrib.gan.python import train as tfgan_train -from tensorflow.contrib.gan.python.estimator.python import head as head_lib from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import variable_scope from tensorflow.python.util import tf_inspect as inspect @@ -53,9 +53,6 @@ _summary_type_map = { } -# TODO(joelshor): For now, this only supports 1:1 generator:discriminator -# training sequentially. Find a nice way to expose options to the user without -# exposing internals. class GANEstimator(estimator.Estimator): """An estimator for Generative Adversarial Networks (GANs). @@ -154,94 +151,93 @@ class GANEstimator(estimator.Estimator): use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. config: `RunConfig` object to configure the runtime settings. + + Raises: + ValueError: If loss functions aren't callable. + ValueError: If `use_loss_summaries` isn't boolean or `None`. + ValueError: If `get_hooks_fn` isn't callable or `None`. """ - # TODO(joelshor): Explicitly validate inputs. + if not callable(generator_loss_fn): + raise ValueError('generator_loss_fn must be callable.') + if not callable(discriminator_loss_fn): + raise ValueError('discriminator_loss_fn must be callable.') + if use_loss_summaries not in [True, False, None]: + raise ValueError('use_loss_summaries must be True, False or None.') + if get_hooks_fn is not None and not callable(get_hooks_fn): + raise TypeError('get_hooks_fn must be callable.') def _model_fn(features, labels, mode): - gopt = (generator_optimizer() if callable(generator_optimizer) else - generator_optimizer) - dopt = (discriminator_optimizer() if callable(discriminator_optimizer) - else discriminator_optimizer) - gan_head = head_lib.gan_head( - generator_loss_fn, discriminator_loss_fn, gopt, dopt, - use_loss_summaries, get_hooks_fn=get_hooks_fn, - get_eval_metric_ops_fn=get_eval_metric_ops_fn) - return _gan_model_fn( - features, labels, mode, generator_fn, discriminator_fn, gan_head, + """GANEstimator model function.""" + if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL, + model_fn_lib.ModeKeys.PREDICT]: + raise ValueError('Mode not recognized: %s' % mode) + real_data = labels # rename inputs for clarity + generator_inputs = features # rename inputs for clarity + + # Make GANModel, which encapsulates the GAN model architectures. + gan_model = _get_gan_model( + mode, generator_fn, discriminator_fn, real_data, generator_inputs, add_summaries) + # Make the EstimatorSpec, which incorporates the GANModel, losses, eval + # metrics, and optimizers (if required). + return _get_estimator_spec( + mode, gan_model, generator_loss_fn, discriminator_loss_fn, + get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, + get_hooks_fn) + super(GANEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) -def _gan_model_fn( - features, - labels, - mode, - generator_fn, - discriminator_fn, - head, - add_summaries=None, - generator_scope_name='Generator'): - """The `model_fn` for the GAN estimator. - - We make the following convention: - features -> TFGAN's `generator_inputs` - labels -> TFGAN's `real_data` - - Args: - features: A dictionary to feed to generator. In the unconditional case, - this might be just `noise`. In the conditional GAN case, this - might be the generator's conditioning. The `generator_fn` determines - what the required keys are. - labels: Real data. Can be any structure, as long as `discriminator_fn` - can accept it for the first argument. - mode: Defines whether this is training, evaluation or prediction. - See `ModeKeys`. - generator_fn: A python lambda that takes `generator_inputs` as inputs and - returns the outputs of the GAN generator. - discriminator_fn: A python lambda that takes `real_data`/`generated data` - and `generator_inputs`. Outputs a Tensor in the range [-inf, inf]. - head: A `Head` instance suitable for GANs. - add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. - generator_scope_name: The name of the generator scope. We need this to be - the same for GANModels produced by TFGAN's `train.gan_model` and the - manually constructed ones for predictions. - - Returns: - `ModelFnOps` - - Raises: - ValueError: If `labels` isn't `None` during prediction. - """ - real_data = labels - generator_inputs = features - - if mode == model_fn_lib.ModeKeys.TRAIN: - gan_model = _make_train_gan_model( - generator_fn, discriminator_fn, real_data, generator_inputs, - generator_scope_name, add_summaries) - elif mode == model_fn_lib.ModeKeys.EVAL: - gan_model = _make_eval_gan_model( - generator_fn, discriminator_fn, real_data, generator_inputs, - generator_scope_name, add_summaries) - else: +def _get_gan_model( + mode, generator_fn, discriminator_fn, real_data, generator_inputs, + add_summaries, generator_scope='Generator'): + """Makes the GANModel tuple, which encapsulates the GAN model architecture.""" + if mode == model_fn_lib.ModeKeys.PREDICT: if real_data is not None: raise ValueError('`labels` must be `None` when mode is `predict`. ' 'Instead, found %s' % real_data) gan_model = _make_prediction_gan_model( - generator_inputs, generator_fn, generator_scope_name) + generator_inputs, generator_fn, generator_scope) + else: # model_fn_lib.ModeKeys.TRAIN or model_fn_lib.ModeKeys.EVAL + gan_model = _make_gan_model( + generator_fn, discriminator_fn, real_data, generator_inputs, + generator_scope, add_summaries, mode) - return head.create_estimator_spec( - features=None, - mode=mode, - logits=gan_model, - labels=None) + return gan_model + + +def _get_estimator_spec( + mode, gan_model, generator_loss_fn, discriminator_loss_fn, + get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, + get_hooks_fn=None): + """Get the EstimatorSpec for the current mode.""" + if mode == model_fn_lib.ModeKeys.PREDICT: + estimator_spec = model_fn_lib.EstimatorSpec( + mode=mode, predictions=gan_model.generated_data) + else: + gan_loss = tfgan_tuples.GANLoss( + generator_loss=generator_loss_fn(gan_model), + discriminator_loss=discriminator_loss_fn(gan_model)) + if mode == model_fn_lib.ModeKeys.EVAL: + estimator_spec = _get_eval_estimator_spec( + gan_model, gan_loss, get_eval_metric_ops_fn) + else: # model_fn_lib.ModeKeys.TRAIN: + gopt = (generator_optimizer() if callable(generator_optimizer) else + generator_optimizer) + dopt = (discriminator_optimizer() if callable(discriminator_optimizer) + else discriminator_optimizer) + get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks() + estimator_spec = _get_train_estimator_spec( + gan_model, gan_loss, gopt, dopt, get_hooks_fn) + + return estimator_spec def _make_gan_model(generator_fn, discriminator_fn, real_data, generator_inputs, generator_scope, add_summaries, mode): - """Make a `GANModel`, and optionally pass in `mode`.""" + """Construct a `GANModel`, and optionally pass in `mode`.""" # If network functions have an argument `mode`, pass mode to it. if 'mode' in inspect.getargspec(generator_fn).args: generator_fn = functools.partial(generator_fn, mode=mode) @@ -264,22 +260,6 @@ def _make_gan_model(generator_fn, discriminator_fn, real_data, return gan_model -def _make_train_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries): - """Make a `GANModel` for training.""" - return _make_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries, - model_fn_lib.ModeKeys.TRAIN) - - -def _make_eval_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries): - """Make a `GANModel` for evaluation.""" - return _make_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries, - model_fn_lib.ModeKeys.EVAL) - - def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope): """Make a `GANModel` from just the generator.""" # If `generator_fn` has an argument `mode`, pass mode to it. @@ -303,3 +283,46 @@ def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope): discriminator_variables=None, discriminator_scope=None, discriminator_fn=None) + + +def _get_eval_estimator_spec(gan_model, gan_loss, get_eval_metric_ops_fn=None, + name=None): + """Return an EstimatorSpec for the eval case.""" + scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + with ops.name_scope(None, 'metrics', + [gan_loss.generator_loss, + gan_loss.discriminator_loss]): + def _summary_key(head_name, val): + return '%s/%s' % (val, head_name) if head_name else val + eval_metric_ops = { + _summary_key(name, 'generator_loss'): + metrics_lib.mean(gan_loss.generator_loss), + _summary_key(name, 'discriminator_loss'): + metrics_lib.mean(gan_loss.discriminator_loss) + } + if get_eval_metric_ops_fn is not None: + custom_eval_metric_ops = get_eval_metric_ops_fn(gan_model) + if not isinstance(custom_eval_metric_ops, dict): + raise TypeError('get_eval_metric_ops_fn must return a dict, ' + 'received: {}'.format(custom_eval_metric_ops)) + eval_metric_ops.update(custom_eval_metric_ops) + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.EVAL, + predictions=gan_model.generated_data, + loss=scalar_loss, + eval_metric_ops=eval_metric_ops) + + +def _get_train_estimator_spec( + gan_model, gan_loss, generator_optimizer, discriminator_optimizer, + get_hooks_fn, train_op_fn=tfgan_train.gan_train_ops): + """Return an EstimatorSpec for the train case.""" + scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + train_ops = train_op_fn(gan_model, gan_loss, generator_optimizer, + discriminator_optimizer) + training_hooks = get_hooks_fn(train_ops) + return model_fn_lib.EstimatorSpec( + loss=scalar_loss, + mode=model_fn_lib.ModeKeys.TRAIN, + train_op=train_ops.global_step_inc_op, + training_hooks=training_hooks) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 955482599b372be3f0d0cbc81451c514958d0eb1..9ac9c6ca9ca86a8a9abe9c0f6ebc4cdf5dd2cfb1 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -21,30 +21,30 @@ from __future__ import print_function import shutil import tempfile +from absl.testing import parameterized import numpy as np import six from tensorflow.contrib import layers -from tensorflow.contrib.gan.python import namedtuples +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl as estimator from tensorflow.contrib.gan.python.losses.python import tuple_losses as losses from tensorflow.contrib.learn.python.learn.learn_io import graph_io from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import input as input_lib from tensorflow.python.training import learning_rate_decay -from tensorflow.python.training import monitored_session from tensorflow.python.training import training from tensorflow.python.training import training_util @@ -60,120 +60,109 @@ def discriminator_fn(data, unused_conditioning, mode): return layers.fully_connected(data, 1) -def mock_head(testcase, expected_generator_inputs, expected_real_data, - generator_scope_name): - """Returns a mock head that validates logits values and variable names.""" - discriminator_scope_name = 'Discriminator' # comes from TFGAN defaults - generator_var_names = set([ - '%s/fully_connected/weights:0' % generator_scope_name, - '%s/fully_connected/biases:0' % generator_scope_name]) - discriminator_var_names = set([ - '%s/fully_connected/weights:0' % discriminator_scope_name, - '%s/fully_connected/biases:0' % discriminator_scope_name]) - - def _create_estimator_spec(features, mode, logits, labels): - gan_model = logits # renaming for clarity - is_predict = mode == model_fn_lib.ModeKeys.PREDICT - testcase.assertIsNone(features) - testcase.assertIsNone(labels) - testcase.assertIsInstance(gan_model, namedtuples.GANModel) - - trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - expected_var_names = (generator_var_names if is_predict else - generator_var_names | discriminator_var_names) - testcase.assertItemsEqual(expected_var_names, - [var.name for var in trainable_vars]) - - assertions = [] - def _or_none(x): - return None if is_predict else x - testcase.assertEqual(expected_generator_inputs, gan_model.generator_inputs) - # TODO(joelshor): Add check on `generated_data`. - testcase.assertItemsEqual( - generator_var_names, - set([x.name for x in gan_model.generator_variables])) - testcase.assertEqual(generator_scope_name, gan_model.generator_scope.name) - testcase.assertEqual(_or_none(expected_real_data), gan_model.real_data) - # TODO(joelshor): Add check on `discriminator_real_outputs`. - # TODO(joelshor): Add check on `discriminator_gen_outputs`. - if is_predict: - testcase.assertIsNone(gan_model.discriminator_scope) - else: - testcase.assertEqual(discriminator_scope_name, - gan_model.discriminator_scope.name) - - with ops.control_dependencies(assertions): - if mode == model_fn_lib.ModeKeys.TRAIN: - return model_fn_lib.EstimatorSpec( - mode=mode, loss=array_ops.zeros([]), - train_op=control_flow_ops.no_op(), training_hooks=[]) - elif mode == model_fn_lib.ModeKeys.EVAL: - return model_fn_lib.EstimatorSpec( - mode=mode, predictions=gan_model.generated_data, - loss=array_ops.zeros([])) - elif mode == model_fn_lib.ModeKeys.PREDICT: - return model_fn_lib.EstimatorSpec( - mode=mode, predictions=gan_model.generated_data) - else: - testcase.fail('Invalid mode: {}'.format(mode)) - - head = test.mock.NonCallableMagicMock(spec=head_lib._Head) - head.create_estimator_spec = test.mock.MagicMock( - wraps=_create_estimator_spec) - - return head - - -class GANModelFnTest(test.TestCase): - """Tests that _gan_model_fn passes expected logits to mock head.""" - - def setUp(self): - self._model_dir = tempfile.mkdtemp() - - def tearDown(self): - if self._model_dir: - writer_cache.FileWriterCache.clear() - shutil.rmtree(self._model_dir) +class GetGANModelTest(test.TestCase, parameterized.TestCase): + """Tests that `GetGANModel` produces the correct model.""" - def _test_logits_helper(self, mode): - """Tests that the expected logits are passed to mock head.""" + @parameterized.named_parameters( + ('train', model_fn_lib.ModeKeys.TRAIN), + ('eval', model_fn_lib.ModeKeys.EVAL), + ('predict', model_fn_lib.ModeKeys.PREDICT)) + def test_get_gan_model(self, mode): with ops.Graph().as_default(): - training_util.get_or_create_global_step() - generator_inputs = {'x': array_ops.zeros([5, 4])} - real_data = (None if mode == model_fn_lib.ModeKeys.PREDICT else - array_ops.zeros([5, 4])) - generator_scope_name = 'generator' - head = mock_head(self, - expected_generator_inputs=generator_inputs, - expected_real_data=real_data, - generator_scope_name=generator_scope_name) - estimator_spec = estimator._gan_model_fn( - features=generator_inputs, - labels=real_data, - mode=mode, - generator_fn=generator_fn, - discriminator_fn=discriminator_fn, - generator_scope_name=generator_scope_name, - head=head) - with monitored_session.MonitoredTrainingSession( - checkpoint_dir=self._model_dir) as sess: - if mode == model_fn_lib.ModeKeys.TRAIN: - sess.run(estimator_spec.train_op) - elif mode == model_fn_lib.ModeKeys.EVAL: - sess.run(estimator_spec.loss) - elif mode == model_fn_lib.ModeKeys.PREDICT: - sess.run(estimator_spec.predictions) - else: - self.fail('Invalid mode: {}'.format(mode)) - - def test_logits_predict(self): - self._test_logits_helper(model_fn_lib.ModeKeys.PREDICT) - - def test_logits_eval(self): - self._test_logits_helper(model_fn_lib.ModeKeys.EVAL) - - def test_logits_train(self): - self._test_logits_helper(model_fn_lib.ModeKeys.TRAIN) + generator_inputs = {'x': array_ops.ones([3, 4])} + real_data = (array_ops.zeros([3, 4]) if + mode != model_fn_lib.ModeKeys.PREDICT else None) + gan_model = estimator._get_gan_model( + mode, generator_fn, discriminator_fn, real_data, generator_inputs, + add_summaries=False) + + self.assertEqual(generator_inputs, gan_model.generator_inputs) + self.assertIsNotNone(gan_model.generated_data) + self.assertEqual(2, len(gan_model.generator_variables)) # 1 FC layer + self.assertIsNotNone(gan_model.generator_fn) + if mode == model_fn_lib.ModeKeys.PREDICT: + self.assertIsNone(gan_model.real_data) + self.assertIsNone(gan_model.discriminator_real_outputs) + self.assertIsNone(gan_model.discriminator_gen_outputs) + self.assertIsNone(gan_model.discriminator_variables) + self.assertIsNone(gan_model.discriminator_scope) + self.assertIsNone(gan_model.discriminator_fn) + else: + self.assertIsNotNone(gan_model.real_data) + self.assertIsNotNone(gan_model.discriminator_real_outputs) + self.assertIsNotNone(gan_model.discriminator_gen_outputs) + self.assertEqual(2, len(gan_model.discriminator_variables)) # 1 FC layer + self.assertIsNotNone(gan_model.discriminator_scope) + self.assertIsNotNone(gan_model.discriminator_fn) + + +def get_dummy_gan_model(): + # TODO(joelshor): Find a better way of creating a variable scope. + with variable_scope.variable_scope('generator') as gen_scope: + gen_var = variable_scope.get_variable('dummy_var', initializer=0.0) + with variable_scope.variable_scope('discriminator') as dis_scope: + dis_var = variable_scope.get_variable('dummy_var', initializer=0.0) + return tfgan_tuples.GANModel( + generator_inputs=None, + generated_data=array_ops.ones([3, 4]), + generator_variables=[gen_var], + generator_scope=gen_scope, + generator_fn=None, + real_data=array_ops.zeros([3, 4]), + discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var, + discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var, + discriminator_variables=[dis_var], + discriminator_scope=dis_scope, + discriminator_fn=None) + + +def dummy_loss_fn(gan_model): + return math_ops.reduce_sum(gan_model.discriminator_real_outputs - + gan_model.discriminator_gen_outputs) + + +def get_metrics(gan_model): + return { + 'mse_custom_metric': metrics_lib.mean_squared_error( + gan_model.real_data, gan_model.generated_data) + } + + +class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase): + """Tests that the EstimatorSpec is constructed appropriately.""" + + @classmethod + def setUpClass(cls): + cls._generator_optimizer = training.GradientDescentOptimizer(1.0) + cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0) + + @parameterized.named_parameters( + ('train', model_fn_lib.ModeKeys.TRAIN), + ('eval', model_fn_lib.ModeKeys.EVAL), + ('predict', model_fn_lib.ModeKeys.PREDICT)) + def test_get_estimator_spec(self, mode): + with ops.Graph().as_default(): + self._gan_model = get_dummy_gan_model() + spec = estimator._get_estimator_spec( + mode, + self._gan_model, + generator_loss_fn=dummy_loss_fn, + discriminator_loss_fn=dummy_loss_fn, + get_eval_metric_ops_fn=get_metrics, + generator_optimizer=self._generator_optimizer, + discriminator_optimizer=self._discriminator_optimizer) + + self.assertEqual(mode, spec.mode) + if mode == model_fn_lib.ModeKeys.PREDICT: + self.assertEqual(self._gan_model.generated_data, spec.predictions) + elif mode == model_fn_lib.ModeKeys.TRAIN: + self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar + self.assertIsNotNone(spec.train_op) + self.assertIsNotNone(spec.training_hooks) + elif mode == model_fn_lib.ModeKeys.EVAL: + self.assertEqual(self._gan_model.generated_data, spec.predictions) + self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar + self.assertIsNotNone(spec.eval_metric_ops) # TODO(joelshor): Add pandas test. @@ -195,12 +184,6 @@ class GANEstimatorIntegrationTest(test.TestCase): lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9) return training.GradientDescentOptimizer(lr) - def get_metrics(gan_model): - return { - 'mse_custom_metric': metrics_lib.mean_squared_error( - gan_model.real_data, gan_model.generated_data) - } - gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) est = estimator.GANEstimator( diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index ff903a78cc36c1965b7655aa902501b1943637a8..1a0ee6dfc498eb6dc8c97411589d9e35bc352062 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -24,18 +24,24 @@ from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples from tensorflow.contrib.gan.python import train as tfgan_train from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.canned import head +from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import ops from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.util import deprecation __all__ = [ 'GANHead', 'gan_head', ] + def _summary_key(head_name, val): return '%s/%s' % (val, head_name) if head_name else val +@deprecation.deprecated( + None, 'Please use tf.contrib.gan.GANEstimator without explicitly making a ' + 'GANHead.') def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer, use_loss_summaries=True, get_hooks_fn=tfgan_train.get_sequential_train_hooks(), @@ -76,6 +82,9 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, class GANHead(head._Head): # pylint: disable=protected-access """`Head` for a GAN.""" + @deprecation.deprecated( + None, 'Please use tf.contrib.gan.GANEstimator without explicitly making ' + 'a GANHead.') def __init__(self, generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer, use_loss_summaries=True, @@ -102,9 +111,20 @@ class GANHead(head._Head): # pylint: disable=protected-access name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. """ + + if not callable(generator_loss_fn): + raise TypeError('generator_loss_fn must be callable.') + if not callable(discriminator_loss_fn): + raise TypeError('discriminator_loss_fn must be callable.') + if use_loss_summaries not in [True, False, None]: + raise ValueError('use_loss_summaries must be True, False or None.') + if get_hooks_fn is not None and not callable(get_hooks_fn): + raise TypeError('get_hooks_fn must be callable.') + if name is not None and not isinstance(name, str): + raise TypeError('name must be string.') + if get_hooks_fn is None: get_hooks_fn = tfgan_train.get_sequential_train_hooks() - # TODO(joelshor): Validate inputs. if use_loss_summaries in [True, False]: generator_loss_fn = functools.partial( @@ -182,7 +202,10 @@ class GANHead(head._Head): # pylint: disable=protected-access if mode == model_fn_lib.ModeKeys.PREDICT: return model_fn_lib.EstimatorSpec( mode=model_fn_lib.ModeKeys.PREDICT, - predictions=gan_model.generated_data) + predictions=gan_model.generated_data, + export_outputs={ + 'predict': export_output.PredictOutput(gan_model.generated_data) + }) elif mode == model_fn_lib.ModeKeys.EVAL: gan_loss = self.create_loss( features=None, mode=mode, logits=gan_model, labels=None) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py index 6587f1fc600b94d27f7c12b44ca2136d0be5a8c5..8205bc889dc01c8680e2139393d65723280cfbd0 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py @@ -26,8 +26,11 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test +from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import training +_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + def dummy_loss(gan_model, add_summaries=True): # pylint:disable=unused-argument return math_ops.reduce_sum(gan_model.discriminator_real_outputs - @@ -64,20 +67,22 @@ class GANHeadTest(test.TestCase): generator_optimizer=training.GradientDescentOptimizer(1.0), discriminator_optimizer=training.GradientDescentOptimizer(1.0), get_eval_metric_ops_fn=self.get_metrics) - self.assertTrue(isinstance(self.gan_head, head.GANHead)) + self.assertIsInstance(self.gan_head, head.GANHead) def get_metrics(self, gan_model): self.assertTrue(isinstance(gan_model, tfgan_tuples.GANModel)) return {} def _test_modes_helper(self, mode): - self.gan_head.create_estimator_spec( + return self.gan_head.create_estimator_spec( features=None, mode=mode, logits=get_gan_model()) def test_modes_predict(self): - self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT) + spec = self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT) + self.assertItemsEqual((_DEFAULT_SERVING_KEY, 'predict'), + spec.export_outputs.keys()) def test_modes_eval(self): self._test_modes_helper(model_fn_lib.ModeKeys.EVAL) diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py index 508f487722fba89cc8391a340f73673a526e86c4..f9995bb19d0d09eaf6fd96d039b0bba1d3a7055c 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py @@ -22,7 +22,9 @@ from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python.eval.python import eval_utils from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import util as loss_util from tensorflow.python.summary import summary @@ -32,6 +34,7 @@ __all__ = [ 'add_gan_model_summaries', 'add_regularization_loss_summaries', 'add_cyclegan_image_summaries', + 'add_stargan_image_summaries' ] @@ -179,6 +182,94 @@ def add_image_comparison_summaries(gan_model, num_comparisons=2, max_outputs=1) +def add_stargan_image_summaries(stargan_model, + num_images=2, + display_diffs=False): + """Adds image summaries to see StarGAN image results. + + If display_diffs is True, each image result has `2` rows and `num_domains + 1` + columns. + The first row looks like: + [original_image, transformed_to_domain_0, transformed_to_domain_1, ...] + The second row looks like: + [no_modification_baseline, transformed_to_domain_0-original_image, ...] + If display_diffs is False, only the first row is shown. + + IMPORTANT: + Since the model originally does not transformed the image to every domains, + we will transform them on-the-fly within this function in parallel. + + Args: + stargan_model: A StarGANModel tuple. + num_images: The number of examples/images to be transformed and shown. + display_diffs: Also display the difference between generated and target. + + Raises: + ValueError: If input_data is not images. + ValueError: If input_data_domain_label is not rank 2. + ValueError: If dimension 2 of input_data_domain_label is not fully defined. + """ + + _assert_is_image(stargan_model.input_data) + stargan_model.input_data_domain_label.shape.assert_has_rank(2) + stargan_model.input_data_domain_label.shape[1:].assert_is_fully_defined() + + num_domains = stargan_model.input_data_domain_label.get_shape().as_list()[-1] + + def _build_image(image): + """Helper function to create a result for each image on the fly.""" + + # Expand the first dimension as batch_size = 1. + images = array_ops.expand_dims(image, axis=0) + + # Tile the image num_domains times, so we can get all transformed together. + images = array_ops.tile(images, [num_domains, 1, 1, 1]) + + # Create the targets to 0, 1, 2, ..., num_domains-1. + targets = array_ops.one_hot(list(range(num_domains)), num_domains) + + with variable_scope.variable_scope( + stargan_model.generator_scope, reuse=True): + + # Add the original image. + output_images_list = [image] + + # Generate the image and add to the list. + gen_images = stargan_model.generator_fn(images, targets) + gen_images_list = array_ops.split(gen_images, num_domains) + gen_images_list = [ + array_ops.squeeze(img, axis=0) for img in gen_images_list + ] + output_images_list.extend(gen_images_list) + + # Display diffs. + if display_diffs: + diff_images = gen_images - images + diff_images_list = array_ops.split(diff_images, num_domains) + diff_images_list = [ + array_ops.squeeze(img, axis=0) for img in diff_images_list + ] + output_images_list.append(array_ops.zeros_like(image)) + output_images_list.extend(diff_images_list) + + # Create the final image. + final_image = eval_utils.image_reshaper( + output_images_list, num_cols=num_domains + 1) + + # Reduce the first rank. + return array_ops.squeeze(final_image, axis=0) + + summary.image( + 'stargan_image_generation', + functional_ops.map_fn( + _build_image, + stargan_model.input_data[:num_images], + parallel_iterations=num_images, + back_prop=False, + swap_memory=True), + max_outputs=num_images) + + def add_gan_model_summaries(gan_model): """Adds typical GANModel summaries. diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py index 33d51bfc218ab93fb52439b1eefed98a4568c4a1..54a6f8d4d9086ad7fc8db31032677628561e48e8 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python.eval.python import summaries_impl as summaries from tensorflow.python.framework import ops @@ -37,6 +36,10 @@ def discriminator_model(inputs, _): return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs +def stargan_generator_model(inputs, _): + return generator_model(inputs) + + def get_gan_model(): # TODO(joelshor): Find a better way of creating a variable scope. with variable_scope.variable_scope('generator') as gen_scope: @@ -57,6 +60,31 @@ def get_gan_model(): discriminator_fn=discriminator_model) +def get_stargan_model(): + """Similar to get_gan_model().""" + # TODO(joelshor): Find a better way of creating a variable scope. + with variable_scope.variable_scope('discriminator') as dis_scope: + pass + with variable_scope.variable_scope('generator') as gen_scope: + return namedtuples.StarGANModel( + input_data=array_ops.ones([1, 2, 2, 3]), + input_data_domain_label=array_ops.ones([1, 2]), + generated_data=stargan_generator_model( + array_ops.ones([1, 2, 2, 3]), None), + generated_data_domain_target=array_ops.ones([1, 2]), + reconstructed_data=array_ops.ones([1, 2, 2, 3]), + discriminator_input_data_source_predication=array_ops.ones([1]), + discriminator_generated_data_source_predication=array_ops.ones([1]), + discriminator_input_data_domain_predication=array_ops.ones([1, 2]), + discriminator_generated_data_domain_predication=array_ops.ones([1, 2]), + generator_variables=None, + generator_scope=gen_scope, + generator_fn=stargan_generator_model, + discriminator_variables=None, + discriminator_scope=dis_scope, + discriminator_fn=discriminator_model) + + def get_cyclegan_model(): with variable_scope.variable_scope('x2y'): model_x2y = get_gan_model() @@ -143,6 +171,16 @@ class SummariesTest(test.TestCase): with self.test_session(use_gpu=True): summary.merge_all().eval() + def test_add_image_comparison_summaries_for_stargan(self): + + summaries.add_stargan_image_summaries(get_stargan_model()) + + self.assertEquals(1, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + + with self.test_session(use_gpu=True) as sess: + sess.run(variables.global_variables_initializer()) + summary.merge_all().eval() + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py index 9e4ec59e7098443efc53506a4ba159e84b5c1618..ca2d724b49db25191b5744e10b48c66b6bdeb120 100644 --- a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py @@ -36,16 +36,15 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import random_ops +from tensorflow.python.util import nest __all__ = [ 'tensor_pool', ] -def _to_tuple(x): - if isinstance(x, (list, tuple)): - return tuple(x) - return (x,) +def _to_list(x): + return [x] if isinstance(x, ops.Tensor) else list(x) def tensor_pool(input_values, @@ -63,8 +62,8 @@ def tensor_pool(input_values, `pool_size` = 0 or `pooling_probability` = 0. Args: - input_values: A `Tensor`, or a list or tuple of `Tensor`s from which to read - values to be pooled. + input_values: An arbitrarily nested structure of `tf.Tensors`, from which to + read values to be pooled. pool_size: An integer specifying the maximum size of the pool. Defaults to 50. pooling_probability: A float `Tensor` specifying the probability of getting @@ -72,9 +71,10 @@ def tensor_pool(input_values, name: A string prefix for the name scope for all tensorflow ops. Returns: - A `Tensor`, or a list or tuple of `Tensor`s (according to the type ofx - `input_values`) which is with given probability either the `input_values` or - a randomly chosen sample that was previously inserted in the pool. + A nested structure of `Tensor` objects with the same structure as + `input_values`. With the given probability, the Tensor values are either the + same as in `input_values` or a randomly chosen sample that was previously + inserted in the pool. Raises: ValueError: If `pool_size` is negative. @@ -86,11 +86,10 @@ def tensor_pool(input_values, return input_values original_input_values = input_values - input_values = _to_tuple(input_values) + input_values = nest.flatten(input_values) - with ops.name_scope( - '{}_pool_queue'.format(name), - values=input_values + (pooling_probability,)): + with ops.name_scope('{}_pool_queue'.format(name), + values=input_values + [pooling_probability]): pool_queue = data_flow_ops.RandomShuffleQueue( capacity=pool_size, min_after_dequeue=0, @@ -112,10 +111,10 @@ def tensor_pool(input_values, def _get_input_value_pooled(): enqueue_op = pool_queue.enqueue(input_values) with ops.control_dependencies([enqueue_op]): - return tuple(array_ops.identity(v) for v in input_values) + return [array_ops.identity(v) for v in input_values] def _get_random_pool_value_and_enqueue_input(): - dequeue_values = _to_tuple(pool_queue.dequeue()) + dequeue_values = _to_list(pool_queue.dequeue()) with ops.control_dependencies(dequeue_values): enqueue_op = pool_queue.enqueue(input_values) with ops.control_dependencies([enqueue_op]): @@ -124,7 +123,7 @@ def tensor_pool(input_values, return control_flow_ops.cond(prob, lambda: dequeue_values, lambda: input_values) - output_values = _to_tuple(control_flow_ops.cond( + output_values = _to_list(control_flow_ops.cond( pool_queue.size() < pool_size, _get_input_value_pooled, _get_random_pool_value_and_enqueue_input)) @@ -132,8 +131,4 @@ def tensor_pool(input_values, for input_value, output_value in zip(input_values, output_values): output_value.set_shape(input_value.shape) - if isinstance(original_input_values, list): - return list(output_values) - elif isinstance(original_input_values, tuple): - return output_values - return output_values[0] + return nest.pack_sequence_as(original_input_values, output_values) diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py index d8cf549cf71838178c9da01df462d41d81595fe5..08584dcd656e3e7a079a3fa36f44742b5eac1178 100644 --- a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py @@ -21,7 +21,9 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.gan.python.features.python.random_tensor_pool_impl import tensor_pool +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -111,6 +113,23 @@ class TensorPoolTest(test.TestCase): self.assertEqual(len(outs), len(input_values)) self.assertEqual(outs[1] - outs[0], 1) + def test_pool_preserves_shape(self): + t = constant_op.constant(1) + input_values = [[t, t, t], (t, t), t] + output_values = tensor_pool(input_values, pool_size=5) + print('stuff: ', output_values) + # Overall shape. + self.assertIsInstance(output_values, list) + self.assertEqual(3, len(output_values)) + # Shape of first element. + self.assertIsInstance(output_values[0], list) + self.assertEqual(3, len(output_values[0])) + # Shape of second element. + self.assertIsInstance(output_values[1], tuple) + self.assertEqual(2, len(output_values[1])) + # Shape of third element. + self.assertIsInstance(output_values[2], ops.Tensor) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index 1ba3a641671c7f2a411a0c5f99228ca16eee1080..d3897483740faafa62befbaf873886139f1482d2 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -949,6 +949,11 @@ def cycle_consistency_loss(data_x, * loss = (loss_x2x + loss_y2y) / 2 where `loss` is the final result. + For the L1-norm, we follow the original implementation: + https://github.com/junyanz/CycleGAN/blob/master/models/cycle_gan_model.lua + we use L1-norm of pixel-wise error normalized by data size such that + `cycle_loss_weight` can be specified independent of image size. + See https://arxiv.org/abs/1703.10593 for more details. Args: @@ -965,19 +970,12 @@ def cycle_consistency_loss(data_x, A scalar `Tensor` of cycle consistency loss. """ - def _partial_cycle_consistency_loss(data, reconstructed_data): - # Following the original implementation - # https://github.com/junyanz/CycleGAN/blob/master/models/cycle_gan_model.lua - # use L1-norm of pixel-wise error normalized by data size so that - # `cycle_loss_weight` can be specified independent of image size. - return math_ops.reduce_mean(math_ops.abs(data - reconstructed_data)) - with ops.name_scope( scope, 'cycle_consistency_loss', values=[data_x, reconstructed_data_x, data_y, reconstructed_data_y]): - loss_x2x = _partial_cycle_consistency_loss(data_x, reconstructed_data_x) - loss_y2y = _partial_cycle_consistency_loss(data_y, reconstructed_data_y) + loss_x2x = losses.absolute_difference(data_x, reconstructed_data_x) + loss_y2y = losses.absolute_difference(data_y, reconstructed_data_y) loss = (loss_x2x + loss_y2y) / 2.0 if add_summaries: summary.scalar('cycle_consistency_loss_x2x', loss_x2x) diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py index dcc3f94c2d6b9e5e44036e7cc1a9d1bb39104fb5..221c70c38bd432a6be7f6cda9c6700aa2255821f 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py @@ -80,6 +80,9 @@ __all__ = [ 'mutual_information_penalty', 'combine_adversarial_loss', 'cycle_consistency_loss', + 'stargan_generator_loss_wrapper', + 'stargan_discriminator_loss_wrapper', + 'stargan_gradient_penalty_wrapper' ] @@ -277,3 +280,86 @@ def cycle_consistency_loss(cyclegan_model, scope=None, add_summaries=False): cyclegan_model.model_x2y.generator_inputs, cyclegan_model.reconstructed_x, cyclegan_model.model_y2x.generator_inputs, cyclegan_model.reconstructed_y, scope, add_summaries) + + +def stargan_generator_loss_wrapper(loss_fn): + """Convert a generator loss function to take a StarGANModel. + + The new function has the same name as the original one. + + Args: + loss_fn: A python function taking Discriminator's real/fake prediction for + generated data. + + Returns: + A new function that takes a StarGANModel namedtuple and returns the same + loss. + """ + + def new_loss_fn(stargan_model, **kwargs): + return loss_fn( + stargan_model.discriminator_generated_data_source_predication, **kwargs) + + new_docstring = """The stargan_model version of %s.""" % loss_fn.__name__ + new_loss_fn.__docstring__ = new_docstring + new_loss_fn.__name__ = loss_fn.__name__ + new_loss_fn.__module__ = loss_fn.__module__ + return new_loss_fn + + +def stargan_discriminator_loss_wrapper(loss_fn): + """Convert a discriminator loss function to take a StarGANModel. + + The new function has the same name as the original one. + + Args: + loss_fn: A python function taking Discriminator's real/fake prediction for + real data and generated data. + + Returns: + A new function that takes a StarGANModel namedtuple and returns the same + loss. + """ + + def new_loss_fn(stargan_model, **kwargs): + return loss_fn( + stargan_model.discriminator_input_data_source_predication, + stargan_model.discriminator_generated_data_source_predication, **kwargs) + + new_docstring = """The stargan_model version of %s.""" % loss_fn.__name__ + new_loss_fn.__docstring__ = new_docstring + new_loss_fn.__name__ = loss_fn.__name__ + new_loss_fn.__module__ = loss_fn.__module__ + return new_loss_fn + + +def stargan_gradient_penalty_wrapper(loss_fn): + """Convert a gradient penalty function to take a StarGANModel. + + The new function has the same name as the original one. + + Args: + loss_fn: A python function taking real_data, generated_data, + generator_inputs for Discriminator's condition (i.e. number of domains), + discriminator_fn, and discriminator_scope. + + Returns: + A new function that takes a StarGANModel namedtuple and returns the same + loss. + """ + + def new_loss_fn(stargan_model, **kwargs): + num_domains = stargan_model.input_data_domain_label.shape.as_list()[-1] + return loss_fn( + real_data=stargan_model.input_data, + generated_data=stargan_model.generated_data, + generator_inputs=num_domains, + discriminator_fn=stargan_model.discriminator_fn, + discriminator_scope=stargan_model.discriminator_scope, + **kwargs) + + new_docstring = """The stargan_model version of %s.""" % loss_fn.__name__ + new_loss_fn.__docstring__ = new_docstring + new_loss_fn.__name__ = loss_fn.__name__ + new_loss_fn.__module__ = loss_fn.__module__ + return new_loss_fn diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py index aa1ef11172dee6799994b87f70a3883cd67fd15b..a559bbfa11367afd7dfe6a72d2ce2cc9d7ba1f16 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py @@ -22,10 +22,15 @@ import collections import numpy as np +from tensorflow.contrib import layers from tensorflow.contrib.gan.python import namedtuples +from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses_impl from tensorflow.contrib.gan.python.losses.python import tuple_losses_impl as tfgan_losses from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -129,6 +134,9 @@ manual_tests = [ 'mutual_information_penalty', 'wasserstein_gradient_penalty', 'cycle_consistency_loss', + 'stargan_generator_loss_wrapper', + 'stargan_discriminator_loss_wrapper', + 'stargan_gradient_penalty_wrapper' ] discriminator_keyword_args = { @@ -175,6 +183,112 @@ class CycleConsistencyLossTest(test.TestCase): self.assertNear(5.0, loss.eval(), 1e-5) +class StarGANLossWrapperTest(test.TestCase): + + def setUp(self): + + super(StarGANLossWrapperTest, self).setUp() + + self.input_data = array_ops.ones([1, 2, 2, 3]) + self.input_data_domain_label = constant_op.constant([[0, 1]]) + self.generated_data = array_ops.ones([1, 2, 2, 3]) + self.discriminator_input_data_source_predication = array_ops.ones([1]) + self.discriminator_generated_data_source_predication = array_ops.ones([1]) + + def _discriminator_fn(inputs, num_domains): + """Differentiable dummy discriminator for StarGAN.""" + hidden = layers.flatten(inputs) + output_src = math_ops.reduce_mean(hidden, axis=1) + output_cls = layers.fully_connected( + inputs=hidden, + num_outputs=num_domains, + activation_fn=None, + normalizer_fn=None, + biases_initializer=None) + return output_src, output_cls + + with variable_scope.variable_scope('discriminator') as dis_scope: + pass + + self.model = namedtuples.StarGANModel( + input_data=self.input_data, + input_data_domain_label=self.input_data_domain_label, + generated_data=self.generated_data, + generated_data_domain_target=None, + reconstructed_data=None, + discriminator_input_data_source_predication=self. + discriminator_input_data_source_predication, + discriminator_generated_data_source_predication=self. + discriminator_generated_data_source_predication, + discriminator_input_data_domain_predication=None, + discriminator_generated_data_domain_predication=None, + generator_variables=None, + generator_scope=None, + generator_fn=None, + discriminator_variables=None, + discriminator_scope=dis_scope, + discriminator_fn=_discriminator_fn) + + self.discriminator_fn = _discriminator_fn + self.discriminator_scope = dis_scope + + def test_stargan_generator_loss_wrapper(self): + """Test StarGAN generator loss wrapper.""" + loss_fn = tfgan_losses_impl.wasserstein_generator_loss + wrapped_loss_fn = tfgan_losses.stargan_generator_loss_wrapper(loss_fn) + + loss_result_tensor = loss_fn( + self.discriminator_generated_data_source_predication) + wrapped_loss_result_tensor = wrapped_loss_fn(self.model) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + loss_result, wrapped_loss_result = sess.run( + [loss_result_tensor, wrapped_loss_result_tensor]) + self.assertAlmostEqual(loss_result, wrapped_loss_result) + + def test_stargan_discriminator_loss_wrapper(self): + """Test StarGAN discriminator loss wrapper.""" + loss_fn = tfgan_losses_impl.wasserstein_discriminator_loss + wrapped_loss_fn = tfgan_losses.stargan_discriminator_loss_wrapper(loss_fn) + + loss_result_tensor = loss_fn( + self.discriminator_generated_data_source_predication, + self.discriminator_generated_data_source_predication) + wrapped_loss_result_tensor = wrapped_loss_fn(self.model) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + loss_result, wrapped_loss_result = sess.run( + [loss_result_tensor, wrapped_loss_result_tensor]) + self.assertAlmostEqual(loss_result, wrapped_loss_result) + + def test_stargan_gradient_penalty_wrapper(self): + """Test StaGAN gradient penalty wrapper. + + Notes: + The random interpolates are handled by given setting the reconstruction to + be the same as the input. + + """ + loss_fn = tfgan_losses_impl.wasserstein_gradient_penalty + wrapped_loss_fn = tfgan_losses.stargan_gradient_penalty_wrapper(loss_fn) + + loss_result_tensor = loss_fn( + real_data=self.input_data, + generated_data=self.generated_data, + generator_inputs=self.input_data_domain_label.shape.as_list()[-1], + discriminator_fn=self.discriminator_fn, + discriminator_scope=self.discriminator_scope) + wrapped_loss_result_tensor = wrapped_loss_fn(self.model) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + loss_result, wrapped_loss_result = sess.run( + [loss_result_tensor, wrapped_loss_result_tensor]) + self.assertAlmostEqual(loss_result, wrapped_loss_result) + + if __name__ == '__main__': for loss_name in tfgan_losses.__all__: if loss_name in manual_tests: continue diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index 25cfeafeec9000b0dc3849ebe646e59c1b4d1cc3..a462b68e28be989eee04fe4ec5ee902d75e5d909 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -25,12 +25,12 @@ from __future__ import print_function import collections - __all__ = [ 'GANModel', 'InfoGANModel', 'ACGANModel', 'CycleGANModel', + 'StarGANModel', 'GANLoss', 'CycleGANLoss', 'GANTrainOps', @@ -136,6 +136,54 @@ class CycleGANModel( """ +class StarGANModel( + collections.namedtuple('StarGANModel', ( + 'input_data', + 'input_data_domain_label', + 'generated_data', + 'generated_data_domain_target', + 'reconstructed_data', + 'discriminator_input_data_source_predication', + 'discriminator_generated_data_source_predication', + 'discriminator_input_data_domain_predication', + 'discriminator_generated_data_domain_predication', + 'generator_variables', + 'generator_scope', + 'generator_fn', + 'discriminator_variables', + 'discriminator_scope', + 'discriminator_fn', + ))): + """A StarGANModel contains all the pieces needed for StarGAN training. + + Args: + input_data: The real images that need to be transferred by the generator. + input_data_domain_label: The real domain labels associated with the real + images. + generated_data: The generated images produced by the generator. It has the + same shape as the input_data. + generated_data_domain_target: The target domain that the generated images + belong to. It has the same shape as the input_data_domain_label. + reconstructed_data: The reconstructed images produced by the G(enerator). + reconstructed_data = G(G(input_data, generated_data_domain_target), + input_data_domain_label). + discriminator_input_data_source: The discriminator's output for predicting + the source (real/generated) of input_data. + discriminator_generated_data_source: The discriminator's output for + predicting the source (real/generated) of generated_data. + discriminator_input_data_domain_predication: The discriminator's output for + predicting the domain_label for the input_data. + discriminator_generated_data_domain_predication: The discriminatorr's output + for predicting the domain_target for the generated_data. + generator_variables: A list of all generator variables. + generator_scope: Variable scope all generator variables live in. + generator_fn: The generator function. + discriminator_variables: A list of all discriminator variables. + discriminator_scope: Variable scope all discriminator variables live in. + discriminator_fn: The discriminator function. + """ + + class GANLoss( collections.namedtuple('GANLoss', ( 'generator_loss', diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 6fa43059f3125daea080f780210223363d0a89f9..9e5aea1498a7e9d47480af18cad9f80ede84c0f9 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -34,27 +34,33 @@ from __future__ import print_function from tensorflow.contrib.framework.python.ops import variables as variables_lib from tensorflow.contrib.gan.python import losses as tfgan_losses from tensorflow.contrib.gan.python import namedtuples +from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses_impl from tensorflow.contrib.slim.python.slim import learning as slim_learning from tensorflow.contrib.training.python.training import training +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops.distributions import distribution as ds from tensorflow.python.ops.losses import losses +from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook from tensorflow.python.training import sync_replicas_optimizer from tensorflow.python.training import training_util - __all__ = [ 'gan_model', 'infogan_model', 'acgan_model', 'cyclegan_model', + 'stargan_model', 'gan_loss', 'cyclegan_loss', + 'stargan_loss', 'gan_train_ops', 'gan_train', 'get_sequential_train_hooks', @@ -123,16 +129,9 @@ def gan_model( discriminator_variables = variables_lib.get_trainable_variables(dis_scope) return namedtuples.GANModel( - generator_inputs, - generated_data, - generator_variables, - gen_scope, - generator_fn, - real_data, - discriminator_real_outputs, - discriminator_gen_outputs, - discriminator_variables, - dis_scope, + generator_inputs, generated_data, generator_variables, gen_scope, + generator_fn, real_data, discriminator_real_outputs, + discriminator_gen_outputs, discriminator_variables, dis_scope, discriminator_fn) @@ -201,8 +200,7 @@ def infogan_model( # Get model-specific variables. generator_variables = variables_lib.get_trainable_variables(gen_scope) - discriminator_variables = variables_lib.get_trainable_variables( - disc_scope) + discriminator_variables = variables_lib.get_trainable_variables(disc_scope) return namedtuples.InfoGANModel( generator_inputs, @@ -279,12 +277,12 @@ def acgan_model( generator_inputs = _convert_tensor_or_l_or_d(generator_inputs) generated_data = generator_fn(generator_inputs) with variable_scope.variable_scope(discriminator_scope) as dis_scope: - with ops.name_scope(dis_scope.name+'/generated/'): + with ops.name_scope(dis_scope.name + '/generated/'): (discriminator_gen_outputs, discriminator_gen_classification_logits ) = _validate_acgan_discriminator_outputs( discriminator_fn(generated_data, generator_inputs)) with variable_scope.variable_scope(dis_scope, reuse=True): - with ops.name_scope(dis_scope.name+'/real/'): + with ops.name_scope(dis_scope.name + '/real/'): real_data = ops.convert_to_tensor(real_data) (discriminator_real_outputs, discriminator_real_classification_logits ) = _validate_acgan_discriminator_outputs( @@ -297,8 +295,7 @@ def acgan_model( # Get model-specific variables. generator_variables = variables_lib.get_trainable_variables(gen_scope) - discriminator_variables = variables_lib.get_trainable_variables( - dis_scope) + discriminator_variables = variables_lib.get_trainable_variables(dis_scope) return namedtuples.ACGANModel( generator_inputs, generated_data, generator_variables, gen_scope, @@ -379,6 +376,108 @@ def cyclegan_model( reconstructed_y) +def stargan_model(generator_fn, + discriminator_fn, + input_data, + input_data_domain_label, + generator_scope='Generator', + discriminator_scope='Discriminator'): + """Returns a StarGAN model outputs and variables. + + See https://arxiv.org/abs/1711.09020 for more details. + + Args: + generator_fn: A python lambda that takes `inputs` and `targets` as inputs + and returns 'generated_data' as the transformed version of `input` based + on the `target`. `input` has shape (n, h, w, c), `targets` has shape (n, + num_domains), and `generated_data` has the same shape as `input`. + discriminator_fn: A python lambda that takes `inputs` and `num_domains` as + inputs and returns a tuple (`source_prediction`, `domain_prediction`). + `source_prediction` represents the source(real/generated) prediction by + the discriminator, and `domain_prediction` represents the domain + prediction/classification by the discriminator. `source_prediction` has + shape (n) and `domain_prediction` has shape (n, num_domains). + input_data: Tensor or a list of tensor of shape (n, h, w, c) representing + the real input images. + input_data_domain_label: Tensor or a list of tensor of shape (batch_size, + num_domains) representing the domain label associated with the real + images. + generator_scope: Optional generator variable scope. Useful if you want to + reuse a subgraph that has already been created. + discriminator_scope: Optional discriminator variable scope. Useful if you + want to reuse a subgraph that has already been created. + + Returns: + StarGANModel nametuple return the tensor that are needed to compute the + loss. + + Raises: + ValueError: If the shape of `input_data_domain_label` is not rank 2 or fully + defined in every dimensions. + """ + + # Convert to tensor. + input_data = _convert_tensor_or_l_or_d(input_data) + input_data_domain_label = _convert_tensor_or_l_or_d(input_data_domain_label) + + # Convert list of tensor to a single tensor if applicable. + if isinstance(input_data, (list, tuple)): + input_data = array_ops.concat( + [ops.convert_to_tensor(x) for x in input_data], 0) + if isinstance(input_data_domain_label, (list, tuple)): + input_data_domain_label = array_ops.concat( + [ops.convert_to_tensor(x) for x in input_data_domain_label], 0) + + # Get batch_size, num_domains from the labels. + input_data_domain_label.shape.assert_has_rank(2) + input_data_domain_label.shape.assert_is_fully_defined() + batch_size, num_domains = input_data_domain_label.shape.as_list() + + # Transform input_data to random target domains. + with variable_scope.variable_scope(generator_scope) as generator_scope: + generated_data_domain_target = _generate_stargan_random_domain_target( + batch_size, num_domains) + generated_data = generator_fn(input_data, generated_data_domain_target) + + # Transform generated_data back to the original input_data domain. + with variable_scope.variable_scope(generator_scope, reuse=True): + reconstructed_data = generator_fn(generated_data, input_data_domain_label) + + # Predict source and domain for the generated_data using the discriminator. + with variable_scope.variable_scope( + discriminator_scope) as discriminator_scope: + disc_gen_data_source_pred, disc_gen_data_domain_pred = discriminator_fn( + generated_data, num_domains) + + # Predict source and domain for the input_data using the discriminator. + with variable_scope.variable_scope(discriminator_scope, reuse=True): + disc_input_data_source_pred, disc_input_data_domain_pred = discriminator_fn( + input_data, num_domains) + + # Collect trainable variables from the neural networks. + generator_variables = variables_lib.get_trainable_variables(generator_scope) + discriminator_variables = variables_lib.get_trainable_variables( + discriminator_scope) + + # Create the StarGANModel namedtuple. + return namedtuples.StarGANModel( + input_data=input_data, + input_data_domain_label=input_data_domain_label, + generated_data=generated_data, + generated_data_domain_target=generated_data_domain_target, + reconstructed_data=reconstructed_data, + discriminator_input_data_source_predication=disc_input_data_source_pred, + discriminator_generated_data_source_predication=disc_gen_data_source_pred, + discriminator_input_data_domain_predication=disc_input_data_domain_pred, + discriminator_generated_data_domain_predication=disc_gen_data_domain_pred, + generator_variables=generator_variables, + generator_scope=generator_scope, + generator_fn=generator_fn, + discriminator_variables=discriminator_variables, + discriminator_scope=discriminator_scope, + discriminator_fn=discriminator_fn) + + def _validate_aux_loss_weight(aux_loss_weight, name='aux_loss_weight'): if isinstance(aux_loss_weight, ops.Tensor): aux_loss_weight.shape.assert_is_compatible_with([]) @@ -419,33 +518,42 @@ def _tensor_pool_adjusted_model(model, tensor_pool_fn): Raises: ValueError: If tensor pool does not support the `model`. """ - if tensor_pool_fn is None: - return model - - pooled_generated_data, pooled_generator_inputs = tensor_pool_fn( - (model.generated_data, model.generator_inputs)) - if isinstance(model, namedtuples.GANModel): + pooled_generator_inputs, pooled_generated_data = tensor_pool_fn( + (model.generator_inputs, model.generated_data)) with variable_scope.variable_scope(model.discriminator_scope, reuse=True): dis_gen_outputs = model.discriminator_fn(pooled_generated_data, pooled_generator_inputs) - return model._replace(discriminator_gen_outputs=dis_gen_outputs) + return model._replace( + generator_inputs=pooled_generator_inputs, + generated_data=pooled_generated_data, + discriminator_gen_outputs=dis_gen_outputs) elif isinstance(model, namedtuples.ACGANModel): + pooled_generator_inputs, pooled_generated_data = tensor_pool_fn( + (model.generator_inputs, model.generated_data)) with variable_scope.variable_scope(model.discriminator_scope, reuse=True): - (dis_pooled_gen_outputs, - dis_pooled_gen_classification_logits) = model.discriminator_fn( + (pooled_discriminator_gen_outputs, + pooled_discriminator_gen_classification_logits) = model.discriminator_fn( pooled_generated_data, pooled_generator_inputs) return model._replace( - discriminator_gen_outputs=dis_pooled_gen_outputs, + generator_inputs=pooled_generator_inputs, + generated_data=pooled_generated_data, + discriminator_gen_outputs=pooled_discriminator_gen_outputs, discriminator_gen_classification_logits= - dis_pooled_gen_classification_logits) + pooled_discriminator_gen_classification_logits) elif isinstance(model, namedtuples.InfoGANModel): + pooled_generator_inputs, pooled_generated_data, pooled_structured_input = ( + tensor_pool_fn((model.generator_inputs, model.generated_data, + model.structured_generator_inputs))) with variable_scope.variable_scope(model.discriminator_scope, reuse=True): - (dis_pooled_gen_outputs, + (pooled_discriminator_gen_outputs, pooled_predicted_distributions) = model.discriminator_and_aux_fn( pooled_generated_data, pooled_generator_inputs) return model._replace( - discriminator_gen_outputs=dis_pooled_gen_outputs, + generator_inputs=pooled_generator_inputs, + generated_data=pooled_generated_data, + structured_generator_inputs=pooled_structured_input, + discriminator_gen_outputs=pooled_discriminator_gen_outputs, predicted_distributions=pooled_predicted_distributions) else: raise ValueError('Tensor pool does not support `model`: %s.' % type(model)) @@ -512,8 +620,8 @@ def gan_loss( `model` isn't an `InfoGANModel`. """ # Validate arguments. - gradient_penalty_weight = _validate_aux_loss_weight(gradient_penalty_weight, - 'gradient_penalty_weight') + gradient_penalty_weight = _validate_aux_loss_weight( + gradient_penalty_weight, 'gradient_penalty_weight') mutual_information_penalty_weight = _validate_aux_loss_weight( mutual_information_penalty_weight, 'infogan_weight') aux_cond_generator_weight = _validate_aux_loss_weight( @@ -537,33 +645,40 @@ def gan_loss( 'is provided, `model` must be an `ACGANModel`. Instead, was %s.' % type(model)) + # Optionally create pooled model. + pooled_model = ( + _tensor_pool_adjusted_model(model, tensor_pool_fn) + if tensor_pool_fn else model) + # Create standard losses. gen_loss = generator_loss_fn(model, add_summaries=add_summaries) - dis_loss = discriminator_loss_fn( - _tensor_pool_adjusted_model(model, tensor_pool_fn), - add_summaries=add_summaries) + dis_loss = discriminator_loss_fn(pooled_model, add_summaries=add_summaries) # Add optional extra losses. if _use_aux_loss(gradient_penalty_weight): gp_loss = tfgan_losses.wasserstein_gradient_penalty( - model, + pooled_model, epsilon=gradient_penalty_epsilon, target=gradient_penalty_target, one_sided=gradient_penalty_one_sided, add_summaries=add_summaries) dis_loss += gradient_penalty_weight * gp_loss if _use_aux_loss(mutual_information_penalty_weight): - info_loss = tfgan_losses.mutual_information_penalty( + gen_info_loss = tfgan_losses.mutual_information_penalty( model, add_summaries=add_summaries) - dis_loss += mutual_information_penalty_weight * info_loss - gen_loss += mutual_information_penalty_weight * info_loss + dis_info_loss = ( + gen_info_loss + if tensor_pool_fn is None else tfgan_losses.mutual_information_penalty( + pooled_model, add_summaries=add_summaries)) + gen_loss += mutual_information_penalty_weight * gen_info_loss + dis_loss += mutual_information_penalty_weight * dis_info_loss if _use_aux_loss(aux_cond_generator_weight): ac_gen_loss = tfgan_losses.acgan_generator_loss( model, add_summaries=add_summaries) gen_loss += aux_cond_generator_weight * ac_gen_loss if _use_aux_loss(aux_cond_discriminator_weight): ac_disc_loss = tfgan_losses.acgan_discriminator_loss( - model, add_summaries=add_summaries) + pooled_model, add_summaries=add_summaries) dis_loss += aux_cond_discriminator_weight * ac_disc_loss # Gathers auxiliary losses. if model.generator_scope: @@ -631,8 +746,8 @@ def cyclegan_loss( generator_loss_fn=generator_loss_fn, discriminator_loss_fn=discriminator_loss_fn, **kwargs) - return partial_loss._replace( - generator_loss=partial_loss.generator_loss + aux_loss) + return partial_loss._replace(generator_loss=partial_loss.generator_loss + + aux_loss) with ops.name_scope('cyclegan_loss_x2y'): loss_x2y = _partial_loss(model.model_x2y) @@ -642,6 +757,130 @@ def cyclegan_loss( return namedtuples.CycleGANLoss(loss_x2y, loss_y2x) +def stargan_loss( + model, + generator_loss_fn=tfgan_losses.stargan_generator_loss_wrapper( + tfgan_losses_impl.wasserstein_generator_loss), + discriminator_loss_fn=tfgan_losses.stargan_discriminator_loss_wrapper( + tfgan_losses_impl.wasserstein_discriminator_loss), + gradient_penalty_weight=10.0, + gradient_penalty_epsilon=1e-10, + gradient_penalty_target=1.0, + gradient_penalty_one_sided=False, + reconstruction_loss_fn=losses.absolute_difference, + reconstruction_loss_weight=10.0, + classification_loss_fn=losses.softmax_cross_entropy, + classification_loss_weight=1.0, + classification_one_hot=True, + add_summaries=True): + """StarGAN Loss. + + The four major part can be found here: http://screen/tMRMBAohDYG. + + Args: + model: (StarGAN) Model output of the stargan_model() function call. + generator_loss_fn: The loss function on the generator. Takes a + `StarGANModel` named tuple. + discriminator_loss_fn: The loss function on the discriminator. Takes a + `StarGANModel` namedtuple. + gradient_penalty_weight: (float) Gradient penalty weight. Default to 10 per + the original paper https://arxiv.org/abs/1711.09020. Set to 0 or None to + turn off gradient penalty. + gradient_penalty_epsilon: (float) A small positive number added for + numerical stability when computing the gradient norm. + gradient_penalty_target: (float, or tf.float `Tensor`) The target value of + gradient norm. Defaults to 1.0. + gradient_penalty_one_sided: (bool) If `True`, penalty proposed in + https://arxiv.org/abs/1709.08894 is used. Defaults to `False`. + reconstruction_loss_fn: The reconstruction loss function. Default to L1-norm + and the function must conform to the `tf.losses` API. + reconstruction_loss_weight: Reconstruction loss weight. Default to 10.0. + classification_loss_fn: The loss function on the discriminator's ability to + classify domain of the input. Default to one-hot softmax cross entropy + loss, and the function must conform to the `tf.losses` API. + classification_loss_weight: (float) Classification loss weight. Default to + 1.0. + classification_one_hot: (bool) If the label is one hot representation. + Default to True. If False, classification classification_loss_fn need to + be sigmoid cross entropy loss instead. + add_summaries: (bool) Add the loss to the summary + + Returns: + GANLoss namedtuple where we have generator loss and discriminator loss. + + Raises: + ValueError: If input StarGANModel.input_data_domain_label does not have rank + 2, or dimension 2 is not defined. + """ + + def _classification_loss_helper(true_labels, predict_logits, scope_name): + """Classification Loss Function Helper. + + Args: + true_labels: Tensor of shape [batch_size, num_domains] representing the + label where each row is an one-hot vector. + predict_logits: Tensor of shape [batch_size, num_domains] representing the + predicted label logit, which is UNSCALED output from the NN. + scope_name: (string) Name scope of the loss component. + + Returns: + Single scalar tensor representing the classification loss. + """ + + with ops.name_scope(scope_name, values=(true_labels, predict_logits)): + + loss = classification_loss_fn( + onehot_labels=true_labels, logits=predict_logits) + + if not classification_one_hot: + loss = math_ops.reduce_sum(loss, axis=1) + loss = math_ops.reduce_mean(loss) + + if add_summaries: + summary.scalar(scope_name, loss) + + return loss + + # Check input shape. + model.input_data_domain_label.shape.assert_has_rank(2) + model.input_data_domain_label.shape[1:].assert_is_fully_defined() + + # Adversarial Loss. + generator_loss = generator_loss_fn(model, add_summaries=add_summaries) + discriminator_loss = discriminator_loss_fn(model, add_summaries=add_summaries) + + # Gradient Penalty. + if _use_aux_loss(gradient_penalty_weight): + gradient_penalty_fn = tfgan_losses.stargan_gradient_penalty_wrapper( + tfgan_losses_impl.wasserstein_gradient_penalty) + discriminator_loss += gradient_penalty_fn( + model, + epsilon=gradient_penalty_epsilon, + target=gradient_penalty_target, + one_sided=gradient_penalty_one_sided, + add_summaries=add_summaries) * gradient_penalty_weight + + # Reconstruction Loss. + reconstruction_loss = reconstruction_loss_fn(model.input_data, + model.reconstructed_data) + generator_loss += reconstruction_loss * reconstruction_loss_weight + if add_summaries: + summary.scalar('reconstruction_loss', reconstruction_loss) + + # Classification Loss. + generator_loss += _classification_loss_helper( + true_labels=model.generated_data_domain_target, + predict_logits=model.discriminator_generated_data_domain_predication, + scope_name='generator_classification_loss') * classification_loss_weight + discriminator_loss += _classification_loss_helper( + true_labels=model.input_data_domain_label, + predict_logits=model.discriminator_input_data_domain_predication, + scope_name='discriminator_classification_loss' + ) * classification_loss_weight + + return namedtuples.GANLoss(generator_loss, discriminator_loss) + + def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True): """Gets generator and discriminator update ops. @@ -822,12 +1061,14 @@ def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): Returns: A function that takes a GANTrainOps tuple and returns a list of hooks. """ + def get_hooks(train_ops): generator_hook = RunTrainOpsHook(train_ops.generator_train_op, train_steps.generator_train_steps) discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op, train_steps.discriminator_train_steps) return [generator_hook, discriminator_hook] + return get_hooks @@ -881,23 +1122,23 @@ def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): d_hook = RunTrainOpsHook(d_op, num_d_steps) return [joint_hook, g_hook, d_hook] + return get_hooks # TODO(joelshor): This function currently returns the global step. Find a # good way for it to return the generator, discriminator, and final losses. -def gan_train( - train_ops, - logdir, - get_hooks_fn=get_sequential_train_hooks(), - master='', - is_chief=True, - scaffold=None, - hooks=None, - chief_only_hooks=None, - save_checkpoint_secs=600, - save_summaries_steps=100, - config=None): +def gan_train(train_ops, + logdir, + get_hooks_fn=get_sequential_train_hooks(), + master='', + is_chief=True, + scaffold=None, + hooks=None, + chief_only_hooks=None, + save_checkpoint_secs=600, + save_summaries_steps=100, + config=None): """A wrapper around `contrib.training.train` that uses GAN hooks. Args: @@ -943,8 +1184,7 @@ def gan_train( config=config) -def get_sequential_train_steps( - train_steps=namedtuples.GANTrainSteps(1, 1)): +def get_sequential_train_steps(train_steps=namedtuples.GANTrainSteps(1, 1)): """Returns a thin wrapper around slim.learning.train_step, for GANs. This function is to provide support for the Supervisor. For new code, please @@ -1042,3 +1282,19 @@ def _validate_acgan_discriminator_outputs(discriminator_output): 'A discriminator function for ACGAN must output a tuple ' 'consisting of (discrimination logits, classification logits).') return a, b + + +def _generate_stargan_random_domain_target(batch_size, num_domains): + """Generate random domain label. + + Args: + batch_size: (int) Number of random domain label. + num_domains: (int) Number of domains representing with the label. + + Returns: + Tensor of shape (batch_size, num_domains) representing random label. + """ + domain_idx = random_ops.random_uniform( + [batch_size], minval=0, maxval=num_domains, dtype=dtypes.int32) + + return array_ops.one_hot(domain_idx, num_domains) diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index 3ebbe55d059e5e72607bc4efdbf95a6c96d99f11..58f348034fdcaadd8d738517aef2a7e2f0172c13 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -18,8 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np +from tensorflow.contrib import layers from tensorflow.contrib.framework.python.ops import variables as variables_lib from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python import train @@ -30,6 +32,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -84,19 +87,59 @@ class InfoGANDiscriminator(object): def acgan_discriminator_model(inputs, _, num_classes=10): - return (discriminator_model(inputs, _), array_ops.one_hot( - # TODO(haeusser): infer batch size from input - random_ops.random_uniform([3], maxval=num_classes, dtype=dtypes.int32), - num_classes)) + return ( + discriminator_model(inputs, _), + array_ops.one_hot( + # TODO(haeusser): infer batch size from input + random_ops.random_uniform( + [3], maxval=num_classes, dtype=dtypes.int32), + num_classes)) class ACGANDiscriminator(object): def __call__(self, inputs, _, num_classes=10): - return (discriminator_model(inputs, _), array_ops.one_hot( - # TODO(haeusser): infer batch size from input - random_ops.random_uniform([3], maxval=num_classes, dtype=dtypes.int32), - num_classes)) + return ( + discriminator_model(inputs, _), + array_ops.one_hot( + # TODO(haeusser): infer batch size from input + random_ops.random_uniform( + [3], maxval=num_classes, dtype=dtypes.int32), + num_classes)) + + +def stargan_generator_model(inputs, _): + """Dummy generator for StarGAN.""" + + return variable_scope.get_variable('dummy_g', initializer=0.5) * inputs + + +class StarGANGenerator(object): + + def __call__(self, inputs, _): + return stargan_generator_model(inputs, _) + + +def stargan_discriminator_model(inputs, num_domains): + """Differentiable dummy discriminator for StarGAN.""" + + hidden = layers.flatten(inputs) + + output_src = math_ops.reduce_mean(hidden, axis=1) + + output_cls = layers.fully_connected( + inputs=hidden, + num_outputs=num_domains, + activation_fn=None, + normalizer_fn=None, + biases_initializer=None) + return output_src, output_cls + + +class StarGANDiscriminator(object): + + def __call__(self, inputs, num_domains): + return stargan_discriminator_model(inputs, num_domains) def get_gan_model(): @@ -122,8 +165,7 @@ def get_gan_model(): def get_callable_gan_model(): ganmodel = get_gan_model() return ganmodel._replace( - generator_fn=Generator(), - discriminator_fn=Discriminator()) + generator_fn=Generator(), discriminator_fn=Discriminator()) def create_gan_model(): @@ -242,69 +284,84 @@ def create_callable_cyclegan_model(): data_y=array_ops.ones([1, 2])) -def get_sync_optimizer(): - return sync_replicas_optimizer.SyncReplicasOptimizer( - gradient_descent.GradientDescentOptimizer(learning_rate=1.0), - replicas_to_aggregate=1) +def get_stargan_model(): + """Similar to get_gan_model().""" + # TODO(joelshor): Find a better way of creating a variable scope. + with variable_scope.variable_scope('generator') as gen_scope: + pass + with variable_scope.variable_scope('discriminator') as dis_scope: + pass + return namedtuples.StarGANModel( + input_data=array_ops.ones([1, 2, 2, 3]), + input_data_domain_label=array_ops.ones([1, 2]), + generated_data=array_ops.ones([1, 2, 2, 3]), + generated_data_domain_target=array_ops.ones([1, 2]), + reconstructed_data=array_ops.ones([1, 2, 2, 3]), + discriminator_input_data_source_predication=array_ops.ones([1]), + discriminator_generated_data_source_predication=array_ops.ones([1]), + discriminator_input_data_domain_predication=array_ops.ones([1, 2]), + discriminator_generated_data_domain_predication=array_ops.ones([1, 2]), + generator_variables=None, + generator_scope=gen_scope, + generator_fn=stargan_generator_model, + discriminator_variables=None, + discriminator_scope=dis_scope, + discriminator_fn=stargan_discriminator_model) -def get_tensor_pool_fn(pool_size): +def get_callable_stargan_model(): + model = get_stargan_model() + return model._replace( + generator_fn=StarGANGenerator(), discriminator_fn=StarGANDiscriminator()) - def tensor_pool_fn_impl(input_values): - return random_tensor_pool.tensor_pool(input_values, pool_size=pool_size) - return tensor_pool_fn_impl +def create_stargan_model(): + return train.stargan_model( + stargan_generator_model, stargan_discriminator_model, + array_ops.ones([1, 2, 2, 3]), array_ops.ones([1, 2])) -def get_tensor_pool_fn_for_infogan(pool_size): +def create_callable_stargan_model(): + return train.stargan_model(StarGANGenerator(), StarGANDiscriminator(), + array_ops.ones([1, 2, 2, 3]), + array_ops.ones([1, 2])) - def tensor_pool_fn_impl(input_values): - generated_data, generator_inputs = input_values - output_values = random_tensor_pool.tensor_pool( - [generated_data] + generator_inputs, pool_size=pool_size) - return output_values[0], output_values[1:] - return tensor_pool_fn_impl +def get_sync_optimizer(): + return sync_replicas_optimizer.SyncReplicasOptimizer( + gradient_descent.GradientDescentOptimizer(learning_rate=1.0), + replicas_to_aggregate=1) -class GANModelTest(test.TestCase): +class GANModelTest(test.TestCase, parameterized.TestCase): """Tests for `gan_model`.""" - def _test_output_type_helper(self, create_fn, tuple_type): - self.assertTrue(isinstance(create_fn(), tuple_type)) - - def test_output_type_gan(self): - self._test_output_type_helper(get_gan_model, namedtuples.GANModel) - - def test_output_type_callable_gan(self): - self._test_output_type_helper(get_callable_gan_model, namedtuples.GANModel) - - def test_output_type_infogan(self): - self._test_output_type_helper(get_infogan_model, namedtuples.InfoGANModel) - - def test_output_type_callable_infogan(self): - self._test_output_type_helper( - get_callable_infogan_model, namedtuples.InfoGANModel) - - def test_output_type_acgan(self): - self._test_output_type_helper(get_acgan_model, namedtuples.ACGANModel) - - def test_output_type_callable_acgan(self): - self._test_output_type_helper( - get_callable_acgan_model, namedtuples.ACGANModel) - - def test_output_type_cyclegan(self): - self._test_output_type_helper(get_cyclegan_model, namedtuples.CycleGANModel) - - def test_output_type_callable_cyclegan(self): - self._test_output_type_helper(get_callable_cyclegan_model, - namedtuples.CycleGANModel) + @parameterized.named_parameters( + ('gan', get_gan_model, namedtuples.GANModel), + ('callable_gan', get_callable_gan_model, namedtuples.GANModel), + ('infogan', get_infogan_model, namedtuples.InfoGANModel), + ('callable_infogan', get_callable_infogan_model, + namedtuples.InfoGANModel), + ('acgan', get_acgan_model, namedtuples.ACGANModel), + ('callable_acgan', get_callable_acgan_model, namedtuples.ACGANModel), + ('cyclegan', get_cyclegan_model, namedtuples.CycleGANModel), + ('callable_cyclegan', get_callable_cyclegan_model, + namedtuples.CycleGANModel), + ('stargan', get_stargan_model, namedtuples.StarGANModel), + ('callabel_stargan', get_callable_stargan_model, namedtuples.StarGANModel) + ) + def test_output_type(self, create_fn, expected_tuple_type): + """Test that output type is as expected.""" + self.assertIsInstance(create_fn(), expected_tuple_type) def test_no_shape_check(self): + def dummy_generator_model(_): return (None, None) + def dummy_discriminator_model(data, conditioning): # pylint: disable=unused-argument return 1 + with self.assertRaisesRegexp(AttributeError, 'object has no attribute'): train.gan_model( dummy_generator_model, @@ -320,52 +377,182 @@ class GANModelTest(test.TestCase): check_shapes=False) -class GANLossTest(test.TestCase): - """Tests for `gan_loss`.""" +class StarGANModelTest(test.TestCase): + """Tests for `stargan_model`.""" + + @staticmethod + def create_input_and_label_tensor(batch_size, img_size, c_size, num_domains): + input_tensor_list = [] + label_tensor_list = [] + for _ in range(num_domains): + input_tensor_list.append( + random_ops.random_uniform((batch_size, img_size, img_size, c_size))) + domain_idx = random_ops.random_uniform( + [batch_size], minval=0, maxval=num_domains, dtype=dtypes.int32) + label_tensor_list.append(array_ops.one_hot(domain_idx, num_domains)) + return input_tensor_list, label_tensor_list + + def test_generate_stargan_random_domain_target(self): + batch_size = 8 + domain_numbers = 3 + + target_tensor = train._generate_stargan_random_domain_target( + batch_size, domain_numbers) + + with self.test_session() as sess: + targets = sess.run(target_tensor) + self.assertTupleEqual((batch_size, domain_numbers), targets.shape) + for target in targets: + self.assertEqual(1, np.sum(target)) + self.assertEqual(1, np.max(target)) + + def test_stargan_model_output_type(self): + batch_size = 2 + img_size = 16 + c_size = 3 + num_domains = 5 + + input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor( + batch_size, img_size, c_size, num_domains) + model = train.stargan_model( + generator_fn=stargan_generator_model, + discriminator_fn=stargan_discriminator_model, + input_data=input_tensor, + input_data_domain_label=label_tensor) + + self.assertIsInstance(model, namedtuples.StarGANModel) + self.assertTrue(isinstance(model.discriminator_variables, list)) + self.assertTrue(isinstance(model.generator_variables, list)) + self.assertIsInstance(model.discriminator_scope, + variable_scope.VariableScope) + self.assertTrue(model.generator_scope, variable_scope.VariableScope) + self.assertTrue(callable(model.discriminator_fn)) + self.assertTrue(callable(model.generator_fn)) + + def test_stargan_model_generator_output(self): + batch_size = 2 + img_size = 16 + c_size = 3 + num_domains = 5 + + input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor( + batch_size, img_size, c_size, num_domains) + model = train.stargan_model( + generator_fn=stargan_generator_model, + discriminator_fn=stargan_discriminator_model, + input_data=input_tensor, + input_data_domain_label=label_tensor) - # Test output type. - def _test_output_type_helper(self, get_gan_model_fn): - loss = train.gan_loss(get_gan_model_fn(), add_summaries=True) - self.assertTrue(isinstance(loss, namedtuples.GANLoss)) - self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) - - def test_output_type_gan(self): - self._test_output_type_helper(get_gan_model) + with self.test_session(use_gpu=True) as sess: - def test_output_type_callable_gan(self): - self._test_output_type_helper(get_callable_gan_model) + sess.run(variables.global_variables_initializer()) - def test_output_type_infogan(self): - self._test_output_type_helper(get_infogan_model) + input_data, generated_data, reconstructed_data = sess.run( + [model.input_data, model.generated_data, model.reconstructed_data]) + self.assertTupleEqual( + (batch_size * num_domains, img_size, img_size, c_size), + input_data.shape) + self.assertTupleEqual( + (batch_size * num_domains, img_size, img_size, c_size), + generated_data.shape) + self.assertTupleEqual( + (batch_size * num_domains, img_size, img_size, c_size), + reconstructed_data.shape) + + def test_stargan_model_discriminator_output(self): + batch_size = 2 + img_size = 16 + c_size = 3 + num_domains = 5 + + input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor( + batch_size, img_size, c_size, num_domains) + model = train.stargan_model( + generator_fn=stargan_generator_model, + discriminator_fn=stargan_discriminator_model, + input_data=input_tensor, + input_data_domain_label=label_tensor) - def test_output_type_callable_infogan(self): - self._test_output_type_helper(get_callable_infogan_model) + with self.test_session(use_gpu=True) as sess: - def test_output_type_acgan(self): - self._test_output_type_helper(get_acgan_model) + sess.run(variables.global_variables_initializer()) - def test_output_type_callable_acgan(self): - self._test_output_type_helper(get_callable_acgan_model) + disc_input_data_source_pred, disc_gen_data_source_pred = sess.run([ + model.discriminator_input_data_source_predication, + model.discriminator_generated_data_source_predication + ]) + self.assertEqual(1, len(disc_input_data_source_pred.shape)) + self.assertEqual(batch_size * num_domains, + disc_input_data_source_pred.shape[0]) + self.assertEqual(1, len(disc_gen_data_source_pred.shape)) + self.assertEqual(batch_size * num_domains, + disc_gen_data_source_pred.shape[0]) + + input_label, disc_input_label, gen_label, disc_gen_label = sess.run([ + model.input_data_domain_label, + model.discriminator_input_data_domain_predication, + model.generated_data_domain_target, + model.discriminator_generated_data_domain_predication + ]) + self.assertTupleEqual((batch_size * num_domains, num_domains), + input_label.shape) + self.assertTupleEqual((batch_size * num_domains, num_domains), + disc_input_label.shape) + self.assertTupleEqual((batch_size * num_domains, num_domains), + gen_label.shape) + self.assertTupleEqual((batch_size * num_domains, num_domains), + disc_gen_label.shape) + + +class GANLossTest(test.TestCase, parameterized.TestCase): + """Tests for `gan_loss`.""" - def test_output_type_cyclegan(self): - loss = train.cyclegan_loss(create_cyclegan_model(), add_summaries=True) - self.assertIsInstance(loss, namedtuples.CycleGANLoss) + @parameterized.named_parameters( + ('gan', get_gan_model), + ('callable_gan', get_callable_gan_model), + ('infogan', get_infogan_model), + ('callable_infogan', get_callable_infogan_model), + ('acgan', get_acgan_model), + ('callable_acgan', get_callable_acgan_model), + ) + def test_output_type(self, get_gan_model_fn): + """Test output type.""" + loss = train.gan_loss(get_gan_model_fn(), add_summaries=True) + self.assertIsInstance(loss, namedtuples.GANLoss) self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) - def test_output_type_callable_cyclegan(self): - loss = train.cyclegan_loss( - create_callable_cyclegan_model(), add_summaries=True) + @parameterized.named_parameters( + ('cyclegan', create_cyclegan_model), + ('callable_cyclegan', create_callable_cyclegan_model), + ) + def test_cyclegan_output_type(self, get_gan_model_fn): + loss = train.cyclegan_loss(get_gan_model_fn(), add_summaries=True) self.assertIsInstance(loss, namedtuples.CycleGANLoss) self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) - # Test gradient penalty option. - def _test_grad_penalty_helper(self, create_gan_model_fn, one_sided=False): + @parameterized.named_parameters( + ('gan', create_gan_model, False), + ('gan_one_sided', create_gan_model, True), + ('callable_gan', create_callable_gan_model, False), + ('callable_gan_one_sided', create_callable_gan_model, True), + ('infogan', create_infogan_model, False), + ('infogan_one_sided', create_infogan_model, True), + ('callable_infogan', create_callable_infogan_model, False), + ('callable_infogan_one_sided', create_callable_infogan_model, True), + ('acgan', create_acgan_model, False), + ('acgan_one_sided', create_acgan_model, True), + ('callable_acgan', create_callable_acgan_model, False), + ('callable_acgan_one_sided', create_callable_acgan_model, True), + ) + def test_grad_penalty(self, create_gan_model_fn, one_sided): + """Test gradient penalty option.""" model = create_gan_model_fn() loss = train.gan_loss(model) - loss_gp = train.gan_loss(model, - gradient_penalty_weight=1.0, - gradient_penalty_one_sided=one_sided) - self.assertTrue(isinstance(loss_gp, namedtuples.GANLoss)) + loss_gp = train.gan_loss( + model, + gradient_penalty_weight=1.0, + gradient_penalty_one_sided=one_sided) + self.assertIsInstance(loss_gp, namedtuples.GANLoss) # Check values. with self.test_session(use_gpu=True) as sess: @@ -376,58 +563,28 @@ class GANLossTest(test.TestCase): [loss.discriminator_loss, loss_gp.discriminator_loss]) self.assertEqual(loss_gen_np, loss_gen_gp_np) - self.assertTrue(loss_dis_np < loss_dis_gp_np) - - def test_grad_penalty_gan(self): - self._test_grad_penalty_helper(create_gan_model) - - def test_grad_penalty_callable_gan(self): - self._test_grad_penalty_helper(create_callable_gan_model) - - def test_grad_penalty_infogan(self): - self._test_grad_penalty_helper(create_infogan_model) - - def test_grad_penalty_callable_infogan(self): - self._test_grad_penalty_helper(create_callable_infogan_model) - - def test_grad_penalty_acgan(self): - self._test_grad_penalty_helper(create_acgan_model) - - def test_grad_penalty_callable_acgan(self): - self._test_grad_penalty_helper(create_callable_acgan_model) - - def test_grad_penalty_one_sided_gan(self): - self._test_grad_penalty_helper(create_gan_model, one_sided=True) - - def test_grad_penalty_one_sided_callable_gan(self): - self._test_grad_penalty_helper(create_callable_gan_model, one_sided=True) - - def test_grad_penalty_one_sided_infogan(self): - self._test_grad_penalty_helper(create_infogan_model, one_sided=True) - - def test_grad_penalty_one_sided_callable_infogan(self): - self._test_grad_penalty_helper( - create_callable_infogan_model, one_sided=True) - - def test_grad_penalty_one_sided_acgan(self): - self._test_grad_penalty_helper(create_acgan_model, one_sided=True) - - def test_grad_penalty_one_sided_callable_acgan(self): - self._test_grad_penalty_helper(create_callable_acgan_model, one_sided=True) - - # Test mutual information penalty option. - def _test_mutual_info_penalty_helper(self, create_gan_model_fn): - train.gan_loss(create_gan_model_fn(), - mutual_information_penalty_weight=constant_op.constant(1.0)) - - def test_mutual_info_penalty_infogan(self): - self._test_mutual_info_penalty_helper(get_infogan_model) - - def test_mutual_info_penalty_callable_infogan(self): - self._test_mutual_info_penalty_helper(get_callable_infogan_model) - - # Test regularization loss. - def _test_regularization_helper(self, get_gan_model_fn): + self.assertLess(loss_dis_np, loss_dis_gp_np) + + @parameterized.named_parameters( + ('infogan', get_infogan_model), + ('callable_infogan', get_callable_infogan_model), + ) + def test_mutual_info_penalty(self, create_gan_model_fn): + """Test mutual information penalty option.""" + train.gan_loss( + create_gan_model_fn(), + mutual_information_penalty_weight=constant_op.constant(1.0)) + + @parameterized.named_parameters( + ('gan', get_gan_model), + ('callable_gan', get_callable_gan_model), + ('infogan', get_infogan_model), + ('callable_infogan', get_callable_infogan_model), + ('acgan', get_acgan_model), + ('callable_acgan', get_callable_acgan_model), + ) + def test_regularization_helper(self, get_gan_model_fn): + """Test regularization loss.""" # Evaluate losses without regularization. no_reg_loss = train.gan_loss(get_gan_model_fn()) with self.test_session(use_gpu=True): @@ -435,11 +592,11 @@ class GANLossTest(test.TestCase): no_reg_loss_dis_np = no_reg_loss.discriminator_loss.eval() with ops.name_scope(get_gan_model_fn().generator_scope.name): - ops.add_to_collection( - ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(3.0)) + ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, + constant_op.constant(3.0)) with ops.name_scope(get_gan_model_fn().discriminator_scope.name): - ops.add_to_collection( - ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(2.0)) + ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, + constant_op.constant(2.0)) # Check that losses now include the correct regularization values. reg_loss = train.gan_loss(get_gan_model_fn()) @@ -447,63 +604,47 @@ class GANLossTest(test.TestCase): reg_loss_gen_np = reg_loss.generator_loss.eval() reg_loss_dis_np = reg_loss.discriminator_loss.eval() - self.assertTrue(3.0, reg_loss_gen_np - no_reg_loss_gen_np) - self.assertTrue(3.0, reg_loss_dis_np - no_reg_loss_dis_np) - - def test_regularization_gan(self): - self._test_regularization_helper(get_gan_model) + self.assertEqual(3.0, reg_loss_gen_np - no_reg_loss_gen_np) + self.assertEqual(2.0, reg_loss_dis_np - no_reg_loss_dis_np) - def test_regularization_callable_gan(self): - self._test_regularization_helper(get_callable_gan_model) - - def test_regularization_infogan(self): - self._test_regularization_helper(get_infogan_model) - - def test_regularization_callable_infogan(self): - self._test_regularization_helper(get_callable_infogan_model) - - def test_regularization_acgan(self): - self._test_regularization_helper(get_acgan_model) - - def test_regularization_callable_acgan(self): - self._test_regularization_helper(get_callable_acgan_model) - - # Test that ACGan models work. - def _test_acgan_helper(self, create_gan_model_fn): + @parameterized.named_parameters( + ('notcallable', create_acgan_model), + ('callable', create_callable_acgan_model), + ) + def test_acgan(self, create_gan_model_fn): + """Test that ACGAN models work.""" model = create_gan_model_fn() loss = train.gan_loss(model) loss_ac_gen = train.gan_loss(model, aux_cond_generator_weight=1.0) loss_ac_dis = train.gan_loss(model, aux_cond_discriminator_weight=1.0) - self.assertTrue(isinstance(loss, namedtuples.GANLoss)) - self.assertTrue(isinstance(loss_ac_gen, namedtuples.GANLoss)) - self.assertTrue(isinstance(loss_ac_dis, namedtuples.GANLoss)) + self.assertIsInstance(loss, namedtuples.GANLoss) + self.assertIsInstance(loss_ac_gen, namedtuples.GANLoss) + self.assertIsInstance(loss_ac_dis, namedtuples.GANLoss) # Check values. with self.test_session(use_gpu=True) as sess: variables.global_variables_initializer().run() - loss_gen_np, loss_ac_gen_gen_np, loss_ac_dis_gen_np = sess.run( - [loss.generator_loss, - loss_ac_gen.generator_loss, - loss_ac_dis.generator_loss]) - loss_dis_np, loss_ac_gen_dis_np, loss_ac_dis_dis_np = sess.run( - [loss.discriminator_loss, - loss_ac_gen.discriminator_loss, - loss_ac_dis.discriminator_loss]) - - self.assertTrue(loss_gen_np < loss_dis_np) + loss_gen_np, loss_ac_gen_gen_np, loss_ac_dis_gen_np = sess.run([ + loss.generator_loss, loss_ac_gen.generator_loss, + loss_ac_dis.generator_loss + ]) + loss_dis_np, loss_ac_gen_dis_np, loss_ac_dis_dis_np = sess.run([ + loss.discriminator_loss, loss_ac_gen.discriminator_loss, + loss_ac_dis.discriminator_loss + ]) + + self.assertLess(loss_gen_np, loss_dis_np) self.assertTrue(np.isscalar(loss_ac_gen_gen_np)) self.assertTrue(np.isscalar(loss_ac_dis_gen_np)) self.assertTrue(np.isscalar(loss_ac_gen_dis_np)) self.assertTrue(np.isscalar(loss_ac_dis_dis_np)) - def test_acgan(self): - self._test_acgan_helper(create_acgan_model) - - def test_callable_acgan(self): - self._test_acgan_helper(create_callable_acgan_model) - - # Test that CycleGan models work. - def _test_cyclegan_helper(self, create_gan_model_fn): + @parameterized.named_parameters( + ('notcallable', create_cyclegan_model), + ('callable', create_callable_cyclegan_model), + ) + def test_cyclegan(self, create_gan_model_fn): + """Test that CycleGan models work.""" model = create_gan_model_fn() loss = train.cyclegan_loss(model) self.assertIsInstance(loss, namedtuples.CycleGANLoss) @@ -524,14 +665,86 @@ class GANLossTest(test.TestCase): self.assertTrue(np.isscalar(loss_y2x_gen_np)) self.assertTrue(np.isscalar(loss_y2x_dis_np)) - def test_cyclegan(self): - self._test_cyclegan_helper(create_cyclegan_model) + @parameterized.named_parameters( + ('notcallable', create_stargan_model), + ('callable', create_callable_stargan_model), + ) + def test_stargan(self, create_gan_model_fn): + + model = create_gan_model_fn() + model_loss = train.stargan_loss(model) + + self.assertIsInstance(model_loss, namedtuples.GANLoss) + + with self.test_session() as sess: + + sess.run(variables.global_variables_initializer()) + + gen_loss, disc_loss = sess.run( + [model_loss.generator_loss, model_loss.discriminator_loss]) + + self.assertTrue(np.isscalar(gen_loss)) + self.assertTrue(np.isscalar(disc_loss)) + + @parameterized.named_parameters( + ('gan', create_gan_model), + ('callable_gan', create_callable_gan_model), + ('infogan', create_infogan_model), + ('callable_infogan', create_callable_infogan_model), + ('acgan', create_acgan_model), + ('callable_acgan', create_callable_acgan_model), + ) + def test_tensor_pool(self, create_gan_model_fn): + """Test tensor pool option.""" + model = create_gan_model_fn() + tensor_pool_fn = lambda x: random_tensor_pool.tensor_pool(x, pool_size=5) + loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn) + self.assertIsInstance(loss, namedtuples.GANLoss) + + # Check values. + with self.test_session(use_gpu=True) as sess: + variables.global_variables_initializer().run() + for _ in range(10): + sess.run([loss.generator_loss, loss.discriminator_loss]) + + def test_discriminator_only_sees_pool(self): + """Checks that discriminator only sees pooled values.""" + def checker_gen_fn(_): + return constant_op.constant(0.0) + model = train.gan_model( + checker_gen_fn, + discriminator_model, + real_data=array_ops.zeros([]), + generator_inputs=random_ops.random_normal([])) + def tensor_pool_fn(_): + return (random_ops.random_uniform([]), random_ops.random_uniform([])) + def checker_dis_fn(inputs, _): + """Discriminator that checks that it only sees pooled Tensors.""" + self.assertFalse(constant_op.is_constant(inputs)) + return inputs + model = model._replace( + discriminator_fn=checker_dis_fn) + train.gan_loss(model, tensor_pool_fn=tensor_pool_fn) + + def test_doesnt_crash_when_in_nested_scope(self): + with variable_scope.variable_scope('outer_scope'): + gan_model = train.gan_model( + generator_model, + discriminator_model, + real_data=array_ops.zeros([1, 2]), + generator_inputs=random_ops.random_normal([1, 2])) + + # This should work inside a scope. + train.gan_loss(gan_model, gradient_penalty_weight=1.0) - def test_callable_cyclegan(self): - self._test_cyclegan_helper(create_callable_cyclegan_model) + # This should also work outside a scope. + train.gan_loss(gan_model, gradient_penalty_weight=1.0) - def _check_tensor_pool_adjusted_model_outputs(self, tensor1, tensor2, - pool_size): + +class TensorPoolAdjusteModelTest(test.TestCase): + + def _check_tensor_pool_adjusted_model_outputs( + self, tensor1, tensor2, pool_size): history_values = [] with self.test_session(use_gpu=True) as sess: variables.global_variables_initializer().run() @@ -548,115 +761,66 @@ class GANLossTest(test.TestCase): # pool). self.assertTrue(any([(v == t2).all() for v in history_values])) - # Test `_tensor_pool_adjusted_model` for gan model. - def test_tensor_pool_adjusted_model_gan(self): - model = create_gan_model() - - new_model = train._tensor_pool_adjusted_model(model, None) + def _make_new_model_and_check(self, model, pool_size): + pool_fn = lambda x: random_tensor_pool.tensor_pool(x, pool_size=pool_size) + new_model = train._tensor_pool_adjusted_model(model, pool_fn) # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES))) - self.assertIs(new_model.discriminator_gen_outputs, - model.discriminator_gen_outputs) - - pool_size = 5 - new_model = train._tensor_pool_adjusted_model( - model, get_tensor_pool_fn(pool_size=pool_size)) self.assertIsNot(new_model.discriminator_gen_outputs, model.discriminator_gen_outputs) + + return new_model + + def test_tensor_pool_adjusted_model_gan(self): + """Test `_tensor_pool_adjusted_model` for gan model.""" + pool_size = 5 + model = create_gan_model() + new_model = self._make_new_model_and_check(model, pool_size) + # Check values. self._check_tensor_pool_adjusted_model_outputs( model.discriminator_gen_outputs, new_model.discriminator_gen_outputs, pool_size) - # Test _tensor_pool_adjusted_model for infogan model. def test_tensor_pool_adjusted_model_infogan(self): + """Test _tensor_pool_adjusted_model for infogan model.""" + pool_size = 5 model = create_infogan_model() + new_model = self._make_new_model_and_check(model, pool_size) - pool_size = 5 - new_model = train._tensor_pool_adjusted_model( - model, get_tensor_pool_fn_for_infogan(pool_size=pool_size)) - # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' - self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES))) - self.assertIsNot(new_model.discriminator_gen_outputs, - model.discriminator_gen_outputs) + # Check values. self.assertIsNot(new_model.predicted_distributions, model.predicted_distributions) - # Check values. self._check_tensor_pool_adjusted_model_outputs( model.discriminator_gen_outputs, new_model.discriminator_gen_outputs, pool_size) - # Test _tensor_pool_adjusted_model for acgan model. def test_tensor_pool_adjusted_model_acgan(self): + """Test _tensor_pool_adjusted_model for acgan model.""" + pool_size = 5 model = create_acgan_model() + new_model = self._make_new_model_and_check(model, pool_size) - pool_size = 5 - new_model = train._tensor_pool_adjusted_model( - model, get_tensor_pool_fn(pool_size=pool_size)) - # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' - self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES))) - self.assertIsNot(new_model.discriminator_gen_outputs, - model.discriminator_gen_outputs) + # Check values. self.assertIsNot(new_model.discriminator_gen_classification_logits, model.discriminator_gen_classification_logits) - # Check values. self._check_tensor_pool_adjusted_model_outputs( model.discriminator_gen_outputs, new_model.discriminator_gen_outputs, pool_size) - # Test tensor pool. - def _test_tensor_pool_helper(self, create_gan_model_fn): - model = create_gan_model_fn() - if isinstance(model, namedtuples.InfoGANModel): - tensor_pool_fn = get_tensor_pool_fn_for_infogan(pool_size=5) - else: - tensor_pool_fn = get_tensor_pool_fn(pool_size=5) - loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn) - self.assertTrue(isinstance(loss, namedtuples.GANLoss)) - - # Check values. - with self.test_session(use_gpu=True) as sess: - variables.global_variables_initializer().run() - for _ in range(10): - sess.run([loss.generator_loss, loss.discriminator_loss]) - - def test_tensor_pool_gan(self): - self._test_tensor_pool_helper(create_gan_model) - - def test_tensor_pool_callable_gan(self): - self._test_tensor_pool_helper(create_callable_gan_model) - - def test_tensor_pool_infogan(self): - self._test_tensor_pool_helper(create_infogan_model) - - def test_tensor_pool_callable_infogan(self): - self._test_tensor_pool_helper(create_callable_infogan_model) - - def test_tensor_pool_acgan(self): - self._test_tensor_pool_helper(create_acgan_model) - - def test_tensor_pool_callable_acgan(self): - self._test_tensor_pool_helper(create_callable_acgan_model) - - def test_doesnt_crash_when_in_nested_scope(self): - with variable_scope.variable_scope('outer_scope'): - gan_model = train.gan_model( - generator_model, - discriminator_model, - real_data=array_ops.zeros([1, 2]), - generator_inputs=random_ops.random_normal([1, 2])) - - # This should work inside a scope. - train.gan_loss(gan_model, gradient_penalty_weight=1.0) - # This should also work outside a scope. - train.gan_loss(gan_model, gradient_penalty_weight=1.0) - - -class GANTrainOpsTest(test.TestCase): +class GANTrainOpsTest(test.TestCase, parameterized.TestCase): """Tests for `gan_train_ops`.""" - def _test_output_type_helper(self, create_gan_model_fn): + @parameterized.named_parameters( + ('gan', create_gan_model), + ('callable_gan', create_callable_gan_model), + ('infogan', create_infogan_model), + ('callable_infogan', create_callable_infogan_model), + ('acgan', create_acgan_model), + ('callable_acgan', create_callable_acgan_model), + ) + def test_output_type(self, create_gan_model_fn): model = create_gan_model_fn() loss = train.gan_loss(model) @@ -670,28 +834,24 @@ class GANTrainOpsTest(test.TestCase): summarize_gradients=True, colocate_gradients_with_ops=True) - self.assertTrue(isinstance(train_ops, namedtuples.GANTrainOps)) - - def test_output_type_gan(self): - self._test_output_type_helper(create_gan_model) - - def test_output_type_callable_gan(self): - self._test_output_type_helper(create_callable_gan_model) - - def test_output_type_infogan(self): - self._test_output_type_helper(create_infogan_model) - - def test_output_type_callable_infogan(self): - self._test_output_type_helper(create_callable_infogan_model) - - def test_output_type_acgan(self): - self._test_output_type_helper(create_acgan_model) - - def test_output_type_callable_acgan(self): - self._test_output_type_helper(create_callable_acgan_model) + self.assertIsInstance(train_ops, namedtuples.GANTrainOps) # TODO(joelshor): Add a test to check that custom update op is run. - def _test_unused_update_ops(self, create_gan_model_fn, provide_update_ops): + @parameterized.named_parameters( + ('gan', create_gan_model, False), + ('gan_provideupdates', create_gan_model, True), + ('callable_gan', create_callable_gan_model, False), + ('callable_gan_provideupdates', create_callable_gan_model, True), + ('infogan', create_infogan_model, False), + ('infogan_provideupdates', create_infogan_model, True), + ('callable_infogan', create_callable_infogan_model, False), + ('callable_infogan_provideupdates', create_callable_infogan_model, True), + ('acgan', create_acgan_model, False), + ('acgan_provideupdates', create_acgan_model, True), + ('callable_acgan', create_callable_acgan_model, False), + ('callable_acgan_provideupdates', create_callable_acgan_model, True), + ) + def test_unused_update_ops(self, create_gan_model_fn, provide_update_ops): model = create_gan_model_fn() loss = train.gan_loss(model) @@ -707,8 +867,11 @@ class GANTrainOpsTest(test.TestCase): # Add an update op outside the generator and discriminator scopes. if provide_update_ops: - kwargs = {'update_ops': - [constant_op.constant(1.0), gen_update_op, dis_update_op]} + kwargs = { + 'update_ops': [ + constant_op.constant(1.0), gen_update_op, dis_update_op + ] + } else: ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, constant_op.constant(1.0)) kwargs = {} @@ -717,8 +880,8 @@ class GANTrainOpsTest(test.TestCase): d_opt = gradient_descent.GradientDescentOptimizer(1.0) with self.assertRaisesRegexp(ValueError, 'There are unused update ops:'): - train.gan_train_ops(model, loss, g_opt, d_opt, - check_for_unused_update_ops=True, **kwargs) + train.gan_train_ops( + model, loss, g_opt, d_opt, check_for_unused_update_ops=True, **kwargs) train_ops = train.gan_train_ops( model, loss, g_opt, d_opt, check_for_unused_update_ops=False, **kwargs) @@ -735,44 +898,16 @@ class GANTrainOpsTest(test.TestCase): self.assertEqual(1, gen_update_count.eval()) self.assertEqual(1, dis_update_count.eval()) - def test_unused_update_ops_gan(self): - self._test_unused_update_ops(create_gan_model, False) - - def test_unused_update_ops_gan_provideupdates(self): - self._test_unused_update_ops(create_gan_model, True) - - def test_unused_update_ops_callable_gan(self): - self._test_unused_update_ops(create_callable_gan_model, False) - - def test_unused_update_ops_callable_gan_provideupdates(self): - self._test_unused_update_ops(create_callable_gan_model, True) - - def test_unused_update_ops_infogan(self): - self._test_unused_update_ops(create_infogan_model, False) - - def test_unused_update_ops_infogan_provideupdates(self): - self._test_unused_update_ops(create_infogan_model, True) - - def test_unused_update_ops_callable_infogan(self): - self._test_unused_update_ops(create_callable_infogan_model, False) - - def test_unused_update_ops_callable_infogan_provideupdates(self): - self._test_unused_update_ops(create_callable_infogan_model, True) - - def test_unused_update_ops_acgan(self): - self._test_unused_update_ops(create_acgan_model, False) - - def test_unused_update_ops_acgan_provideupdates(self): - self._test_unused_update_ops(create_acgan_model, True) - - def test_unused_update_ops_callable_acgan(self): - self._test_unused_update_ops(create_callable_acgan_model, False) - - def test_unused_update_ops_callable_acgan_provideupdates(self): - self._test_unused_update_ops(create_callable_acgan_model, True) - - def _test_sync_replicas_helper( - self, create_gan_model_fn, create_global_step=False): + @parameterized.named_parameters( + ('gan', create_gan_model, False), + ('callable_gan', create_callable_gan_model, False), + ('infogan', create_infogan_model, False), + ('callable_infogan', create_callable_infogan_model, False), + ('acgan', create_acgan_model, False), + ('callable_acgan', create_callable_acgan_model, False), + ('gan_canbeint32', create_gan_model, True), + ) + def test_sync_replicas(self, create_gan_model_fn, create_global_step): model = create_gan_model_fn() loss = train.gan_loss(model) num_trainable_vars = len(variables_lib.get_trainable_variables()) @@ -785,11 +920,8 @@ class GANTrainOpsTest(test.TestCase): g_opt = get_sync_optimizer() d_opt = get_sync_optimizer() train_ops = train.gan_train_ops( - model, - loss, - generator_optimizer=g_opt, - discriminator_optimizer=d_opt) - self.assertTrue(isinstance(train_ops, namedtuples.GANTrainOps)) + model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt) + self.assertIsInstance(train_ops, namedtuples.GANTrainOps) # No new trainable variables should have been added. self.assertEqual(num_trainable_vars, len(variables_lib.get_trainable_variables())) @@ -827,29 +959,8 @@ class GANTrainOpsTest(test.TestCase): coord.request_stop() coord.join(g_threads + d_threads) - def test_sync_replicas_gan(self): - self._test_sync_replicas_helper(create_gan_model) - - def test_sync_replicas_callable_gan(self): - self._test_sync_replicas_helper(create_callable_gan_model) - - def test_sync_replicas_infogan(self): - self._test_sync_replicas_helper(create_infogan_model) - def test_sync_replicas_callable_infogan(self): - self._test_sync_replicas_helper(create_callable_infogan_model) - - def test_sync_replicas_acgan(self): - self._test_sync_replicas_helper(create_acgan_model) - - def test_sync_replicas_callable_acgan(self): - self._test_sync_replicas_helper(create_callable_acgan_model) - - def test_global_step_can_be_int32(self): - self._test_sync_replicas_helper(create_gan_model, create_global_step=True) - - -class GANTrainTest(test.TestCase): +class GANTrainTest(test.TestCase, parameterized.TestCase): """Tests for `gan_train`.""" def _gan_train_ops(self, generator_add, discriminator_add): @@ -860,12 +971,20 @@ class GANTrainTest(test.TestCase): # joint training. train_ops = namedtuples.GANTrainOps( generator_train_op=step.assign_add(generator_add, use_locking=True), - discriminator_train_op=step.assign_add(discriminator_add, - use_locking=True), + discriminator_train_op=step.assign_add( + discriminator_add, use_locking=True), global_step_inc_op=step.assign_add(1)) return train_ops - def _test_run_helper(self, create_gan_model_fn): + @parameterized.named_parameters( + ('gan', create_gan_model), + ('callable_gan', create_callable_gan_model), + ('infogan', create_infogan_model), + ('callable_infogan', create_callable_infogan_model), + ('acgan', create_acgan_model), + ('callable_acgan', create_callable_acgan_model), + ) + def test_run_helper(self, create_gan_model_fn): random_seed.set_random_seed(1234) model = create_gan_model_fn() loss = train.gan_loss(model) @@ -881,30 +1000,15 @@ class GANTrainTest(test.TestCase): self.assertTrue(np.isscalar(final_step)) self.assertEqual(2, final_step) - def test_run_gan(self): - self._test_run_helper(create_gan_model) - - def test_run_callable_gan(self): - self._test_run_helper(create_callable_gan_model) - - def test_run_infogan(self): - self._test_run_helper(create_infogan_model) - - def test_run_callable_infogan(self): - self._test_run_helper(create_callable_infogan_model) - - def test_run_acgan(self): - self._test_run_helper(create_acgan_model) - - def test_run_callable_acgan(self): - self._test_run_helper(create_callable_acgan_model) - - # Test multiple train steps. - def _test_multiple_steps_helper(self, get_hooks_fn_fn): + @parameterized.named_parameters( + ('seq_train_steps', train.get_sequential_train_hooks), + ('efficient_seq_train_steps', train.get_joint_train_hooks), + ) + def test_multiple_steps(self, get_hooks_fn_fn): + """Test multiple train steps.""" train_ops = self._gan_train_ops(generator_add=10, discriminator_add=100) train_steps = namedtuples.GANTrainSteps( - generator_train_steps=3, - discriminator_train_steps=4) + generator_train_steps=3, discriminator_train_steps=4) final_step = train.gan_train( train_ops, get_hooks_fn=get_hooks_fn_fn(train_steps), @@ -914,12 +1018,6 @@ class GANTrainTest(test.TestCase): self.assertTrue(np.isscalar(final_step)) self.assertEqual(1 + 3 * 10 + 4 * 100, final_step) - def test_multiple_steps_seq_train_steps(self): - self._test_multiple_steps_helper(train.get_sequential_train_hooks) - - def test_multiple_steps_efficient_seq_train_steps(self): - self._test_multiple_steps_helper(train.get_joint_train_hooks) - def test_supervisor_run_gan_model_train_ops_multiple_steps(self): step = training_util.create_global_step() train_ops = namedtuples.GANTrainOps( @@ -927,8 +1025,7 @@ class GANTrainTest(test.TestCase): discriminator_train_op=constant_op.constant(2.0), global_step_inc_op=step.assign_add(1)) train_steps = namedtuples.GANTrainSteps( - generator_train_steps=3, - discriminator_train_steps=4) + generator_train_steps=3, discriminator_train_steps=4) final_loss = slim_learning.train( train_op=train_ops, @@ -940,10 +1037,18 @@ class GANTrainTest(test.TestCase): self.assertEqual(17.0, final_loss) -class PatchGANTest(test.TestCase): +class PatchGANTest(test.TestCase, parameterized.TestCase): """Tests that functions work on PatchGAN style output.""" - def _test_patchgan_helper(self, create_gan_model_fn): + @parameterized.named_parameters( + ('gan', create_gan_model), + ('callable_gan', create_callable_gan_model), + ('infogan', create_infogan_model), + ('callable_infogan', create_callable_infogan_model), + ('acgan', create_acgan_model), + ('callable_acgan', create_callable_acgan_model), + ) + def test_patchgan(self, create_gan_model_fn): """Ensure that patch-based discriminators work end-to-end.""" random_seed.set_random_seed(1234) model = create_gan_model_fn() @@ -960,24 +1065,6 @@ class PatchGANTest(test.TestCase): self.assertTrue(np.isscalar(final_step)) self.assertEqual(2, final_step) - def test_patchgan_gan(self): - self._test_patchgan_helper(create_gan_model) - - def test_patchgan_callable_gan(self): - self._test_patchgan_helper(create_callable_gan_model) - - def test_patchgan_infogan(self): - self._test_patchgan_helper(create_infogan_model) - - def test_patchgan_callable_infogan(self): - self._test_patchgan_helper(create_callable_infogan_model) - - def test_patchgan_acgan(self): - self._test_patchgan_helper(create_acgan_model) - - def test_patchgan_callable_acgan(self): - self._test_patchgan_helper(create_callable_acgan_model) - if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc index 81e70ae30a4c72dbcedd1aabfe758ecca4c8b366..7e6a0f14f6f5e467801fef39ebb597565b3d7e98 100644 --- a/tensorflow/contrib/gdr/gdr_memory_manager.cc +++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc @@ -33,9 +33,11 @@ limitations under the License. #include "tensorflow/core/common_runtime/bfc_allocator.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/pool_allocator.h" +#include "tensorflow/core/common_runtime/process_state.h" #if GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" #include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#include "tensorflow/core/common_runtime/gpu/process_state.h" #endif // GOOGLE_CUDA #include "tensorflow/core/framework/allocator_registry.h" #include "tensorflow/core/lib/core/status.h" @@ -172,7 +174,7 @@ class GdrMemoryManager : public RemoteMemoryManager { // Client side endpoints mutex client_mu_; std::map, RdmaEndpointPtr> clients_ - GUARDED_BY(cient_mu_); + GUARDED_BY(client_mu_); // Managed memory regions mutex alloc_mu_; @@ -181,28 +183,25 @@ class GdrMemoryManager : public RemoteMemoryManager { TF_DISALLOW_COPY_AND_ASSIGN(GdrMemoryManager); }; -// TODO(byronyi): remove this class duplicated from the one in -// common/runtime/gpu/pool_allocator.h when it is available in common_runtime -class BasicCPUAllocator : public SubAllocator { - public: - ~BasicCPUAllocator() override {} - - void* Alloc(size_t alignment, size_t num_bytes) override { - return port::AlignedMalloc(num_bytes, alignment); - } - void Free(void* ptr, size_t) override { port::AlignedFree(ptr); } -}; - // TODO(byronyi): remove this class and its registration when the default -// cpu_allocator() returns visitable allocator +// cpu_allocator() returns visitable allocator, or cpu_allocator() is no +// longer in use. class BFCRdmaAllocator : public BFCAllocator { public: BFCRdmaAllocator() - : BFCAllocator(new BasicCPUAllocator(), 1LL << 36, true, "cpu_rdma_bfc") { + : BFCAllocator(new BasicCPUAllocator(port::kNUMANoAffinity), 1LL << 36, + true, "cpu_rdma_bfc") {} +}; +class BFCRdmaAllocatorFactory : public AllocatorFactory { + public: + Allocator* CreateAllocator() override { return new BFCRdmaAllocator; } + + virtual SubAllocator* CreateSubAllocator(int numa_node) { + return new BasicCPUAllocator(numa_node); } }; -REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocator); +REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocatorFactory); GdrMemoryManager::GdrMemoryManager(const string& host, const string& port) : host_(host), @@ -274,9 +273,9 @@ Status GdrMemoryManager::Init() { Allocator* allocators[] = { #if GOOGLE_CUDA - ProcessState::singleton()->GetCUDAHostAllocator(0), - ProcessState::singleton()->GetCPUAllocator(0), + GPUProcessState::singleton()->GetCUDAHostAllocator(0), #endif // GOOGLE_CUDA + ProcessState::singleton()->GetCPUAllocator(0), cpu_allocator(), }; @@ -308,7 +307,8 @@ Status GdrMemoryManager::Init() { if (IsGDRAvailable()) { // Note we don't free allocated GPU memory so there is no free visitor int32_t bus_id = TryToReadNumaNode(listening_->verbs->device) + 1; - ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor); + GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id, + cuda_alloc_visitor); LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id; } #endif // GOOGLE_CUDA @@ -430,7 +430,7 @@ void GdrMemoryManager::TransportOptionsFromTensor( #if GOOGLE_CUDA if (!on_host) { - Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); + Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); Tensor* host_copy = new Tensor(alloc, tensor.dtype(), tensor.shape()); GPUUtil::CopyGPUTensorToCPU( device, device_context, &tensor, host_copy, @@ -532,7 +532,7 @@ void GdrMemoryManager::TensorFromTransportOptions( Tensor host_copy; #if GOOGLE_CUDA if (mr == nullptr && !on_host) { - Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); + Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); host_copy = Tensor(alloc, tensor->dtype(), tensor->shape()); buffer = DMAHelper::buffer(&host_copy); addr = buffer->data(); diff --git a/tensorflow/contrib/graph_editor/reroute.py b/tensorflow/contrib/graph_editor/reroute.py index 95c02a64d47c26e731ef2628fb551529e9bc3f4d..d42e0c01f455f861e9ccdbfb79aefab762e61abe 100644 --- a/tensorflow/contrib/graph_editor/reroute.py +++ b/tensorflow/contrib/graph_editor/reroute.py @@ -208,9 +208,9 @@ def _reroute_ts(ts0, ts1, mode, can_modify=None, cannot_modify=None): def swap_ts(ts0, ts1, can_modify=None, cannot_modify=None): """For each tensor's pair, swap the end of (t0,t1). - B0 B1 B0 B1 - | | => X - A0 A1 A0 A1 + B0 B1 B0 B1 + | | => X + A0 A1 A0 A1 Args: ts0: an object convertible to a list of `tf.Tensor`. @@ -233,9 +233,9 @@ def swap_ts(ts0, ts1, can_modify=None, cannot_modify=None): def reroute_ts(ts0, ts1, can_modify=None, cannot_modify=None): """For each tensor's pair, replace the end of t1 by the end of t0. - B0 B1 B0 B1 - | | => |/ - A0 A1 A0 A1 + B0 B1 B0 B1 + | | => |/ + A0 A1 A0 A1 The end of the tensors in ts1 are left dangling. diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 026a3d1200033400472c4fd763a244c04b284a9b..e79ccd8da1f8952758ae322d3a92dec34910a9db 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -129,7 +129,7 @@ def transform_op_if_inside_handler(info, op, keep_if_possible=True): return None -def copy_op_handler(info, op, new_inputs, copy_shape=True, nodedef_fn=None): +def copy_op_handler(info, op, new_inputs, copy_shape=False, nodedef_fn=None): """Copy a `tf.Operation`. Args: diff --git a/tensorflow/contrib/hadoop/BUILD b/tensorflow/contrib/hadoop/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..ccad31efa1dba92d954ff1cb455b6c9c784b29bc --- /dev/null +++ b/tensorflow/contrib/hadoop/BUILD @@ -0,0 +1,117 @@ +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_custom_op_py_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_kernel_library", + "tf_py_test", +) + +filegroup( + name = "test_data", + srcs = glob(["python/kernel_tests/testdata/*"]), +) + +py_library( + name = "hadoop", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_ops", + ], +) + +tf_custom_op_library( + name = "_dataset_ops.so", + srcs = ["ops/dataset_ops.cc"], + deps = [ + ":dataset_kernels", + ], +) + +tf_gen_op_libs( + op_lib_names = ["dataset_ops"], +) + +cc_library( + name = "dataset_kernels", + srcs = ["kernels/hadoop_dataset_ops.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, +) + +py_library( + name = "dataset_ops", + srcs = [ + "python/ops/hadoop_dataset_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":hadoop_op_loader", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_dataset_ops", + out = "python/ops/gen_dataset_ops.py", + deps = ["//tensorflow/contrib/hadoop:dataset_ops_op_lib"], +) + +tf_kernel_library( + name = "dataset_ops_kernels", + deps = [ + ":dataset_kernels", + "//tensorflow/core:framework", + ], + alwayslink = 1, +) + +tf_custom_op_py_library( + name = "hadoop_op_loader", + srcs = ["python/ops/hadoop_op_loader.py"], + dso = ["//tensorflow/contrib/hadoop:_dataset_ops.so"], + kernels = [ + ":dataset_ops_kernels", + "//tensorflow/contrib/hadoop:dataset_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_dataset_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + ], +) + +tf_py_test( + name = "hadoop_test", + srcs = ["python/kernel_tests/hadoop_test.py"], + additional_deps = [ + ":hadoop", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + data = [ + ":test_data", + ], + tags = [ + "notap", + ], +) diff --git a/tensorflow/contrib/control_flow/__init__.py b/tensorflow/contrib/hadoop/__init__.py similarity index 82% rename from tensorflow/contrib/control_flow/__init__.py rename to tensorflow/contrib/hadoop/__init__.py index 582af2cf10a3d92dd8611b0f2826625e3acfb099..abf8cd4845f9713ebd8a647af191000061e01ad1 100644 --- a/tensorflow/contrib/control_flow/__init__.py +++ b/tensorflow/contrib/hadoop/__init__.py @@ -12,20 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Sequence File Dataset. -"""New implementations of TF control flow ops. - -@@cond_v2 +@@SequenceFileDataset """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=unused-import -from tensorflow.contrib.control_flow.python.cond_v2 import cond_v2 -# pylint: enable=unused-import +from tensorflow.contrib.hadoop.python.ops.hadoop_dataset_ops import SequenceFileDataset from tensorflow.python.util.all_util import remove_undocumented +_allowed_symbols = [ + "SequenceFileDataset", +] + remove_undocumented(__name__) diff --git a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..80b2d3e08b6745b776aa7b4073e841145defd3c4 --- /dev/null +++ b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc @@ -0,0 +1,340 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/lib/io/buffered_inputstream.h" +#include "tensorflow/core/platform/file_system.h" + +namespace tensorflow { +namespace { + +static const size_t kSyncMarkerSize = 16; +static const size_t kSequenceFileBufferSize = 1024 * 1024; + +class SequenceFileReader { + public: + explicit SequenceFileReader(RandomAccessFile* file) + : input_stream_( + new io::BufferedInputStream(file, kSequenceFileBufferSize)) {} + + Status ReadHeader() { + string version; + TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(4, &version)); + if (version.substr(0, 3) != "SEQ" || version[3] != 6) { + return errors::InvalidArgument( + "sequence file header must starts with `SEQ6`, received \"", + version.substr(0, 3), static_cast(version[3]), "\""); + } + TF_RETURN_IF_ERROR(ReadString(&key_class_name_)); + TF_RETURN_IF_ERROR(ReadString(&value_class_name_)); + + // At the moment we only support `org.apache.hadoop.io.Text` for key/value. + // TODO (yongtang): Add more class name support. + if (key_class_name_ != "org.apache.hadoop.io.Text" || + value_class_name_ != "org.apache.hadoop.io.Text") { + return errors::Unimplemented("key/value of '", key_class_name_, "/", + value_class_name_, + "' is currently not supported"); + } + + string buffer; + TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(2, &buffer)); + compression_ = buffer[0]; + block_compression_ = buffer[1]; + if (compression_ || block_compression_) { + TF_RETURN_IF_ERROR(ReadString(&compression_codec_class_name_)); + } + + // At the moment no compression is supported. + // TODO (yongtang): Add compression support. + if (compression_ || block_compression_) { + return errors::Unimplemented("compression is currently not supported"); + } + + // Not interested in metadata for now. + uint32 num_metadata_pairs = 0; + TF_RETURN_IF_ERROR(ReadUInt32(&num_metadata_pairs)); + if (num_metadata_pairs > 1024) { + return errors::InvalidArgument( + "sequence file metadata should have key value pairs < 1024, " + "received ", + num_metadata_pairs); + } + for (int i = 0; i < num_metadata_pairs; i++) { + TF_RETURN_IF_ERROR(ReadString(nullptr)); + TF_RETURN_IF_ERROR(ReadString(nullptr)); + } + + TF_RETURN_IF_ERROR( + input_stream_->ReadNBytes(kSyncMarkerSize, &sync_marker_)); + + return Status::OK(); + } + + Status ReadRecord(string* key, string* value) { + uint32 length = 0; + TF_RETURN_IF_ERROR(ReadUInt32(&length)); + if (length == static_cast(-1)) { + // Sync marker. + string sync_marker; + TF_RETURN_IF_ERROR( + input_stream_->ReadNBytes(kSyncMarkerSize, &sync_marker)); + if (sync_marker != sync_marker_) { + return errors::InvalidArgument( + "sequence file should have sync marker \"", sync_marker_, + "\" at pos ", input_stream_->Tell() - kSyncMarkerSize, + ", received \"", sync_marker, "\""); + } + return ReadRecord(key, value); + } + uint32 key_length = 0; + TF_RETURN_IF_ERROR(ReadUInt32(&key_length)); + if (key_length > length) { + return errors::InvalidArgument("key length (", key_length, + ") should be < record length (", length, + ")"); + } + // At the moment we only support `org.apache.hadoop.io.Text` for key/value. + // TODO (yongtang): Expand supported format. + TF_RETURN_IF_ERROR(ReadString(key)); + TF_RETURN_IF_ERROR(ReadString(value)); + return Status::OK(); + } + + Status ReadString(string* value) { + int64 length = 0; + TF_RETURN_IF_ERROR(ReadVInt(&length)); + if (value == nullptr) { + return input_stream_->SkipNBytes(length); + } + return input_stream_->ReadNBytes(length, value); + } + + Status ReadUInt32(uint32* value) { + string buffer; + TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(4, &buffer)); + *value = ((static_cast(buffer[0]) << 24) | + static_cast(buffer[1]) << 16) | + (static_cast(buffer[2]) << 8) | + static_cast(buffer[3]); + return Status::OK(); + } + + Status ReadVInt(int64* value) { + string buffer; + TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(1, &buffer)); + if (buffer[0] >= -112) { + *value = static_cast(buffer[0]); + return Status::OK(); + } + + int64 remaining = 0; + bool negative = false; + if (buffer[0] >= -120) { + remaining = static_cast(-112) - static_cast(buffer[0]); + } else { + remaining = static_cast(-120) - static_cast(buffer[0]); + negative = true; + } + buffer.clear(); + TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(remaining, &buffer)); + + uint64 v = 0; + for (int i = 0; i < buffer.size(); i++) { + v = (v << 8) | static_cast(buffer[i]); + } + if (negative) { + v = ~v; + } + *value = static_cast(v); + return Status::OK(); + } + + virtual ~SequenceFileReader() = default; + + private: + std::unique_ptr input_stream_; + string key_class_name_; + string value_class_name_; + string sync_marker_; + bool compression_; + bool block_compression_; + string compression_codec_class_name_; + TF_DISALLOW_COPY_AND_ASSIGN(SequenceFileReader); +}; +class SequenceFileDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + explicit SequenceFileDatasetOp(OpKernelConstruction* ctx) + : DatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + for (const DataType& dt : output_types_) { + OP_REQUIRES(ctx, dt == DT_STRING, + errors::InvalidArgument( + "Each element of `output_types_` must be one of: " + "DT_STRING")); + } + } + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + const Tensor* filenames_tensor; + OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); + OP_REQUIRES( + ctx, filenames_tensor->dims() <= 1, + errors::InvalidArgument("`filenames` must be a scalar or a vector.")); + + std::vector filenames; + filenames.reserve(filenames_tensor->NumElements()); + for (int i = 0; i < filenames_tensor->NumElements(); ++i) { + filenames.push_back(filenames_tensor->flat()(i)); + } + + *output = new Dataset(ctx, filenames, output_types_); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const std::vector& filenames, + const DataTypeVector& output_types) + : DatasetBase(DatasetContext(ctx)), + filenames_(filenames), + output_types_(output_types) {} + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::SequenceFile")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}, {}}); + return *shapes; + } + + string DebugString() const override { + return "SequenceFileDatasetOp::Dataset"; + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* filenames = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + do { + // We are currently processing a file, so try to read the next record. + if (reader_) { + string key, value; + Status status = reader_->ReadRecord(&key, &value); + if (!errors::IsOutOfRange(status)) { + TF_RETURN_IF_ERROR(status); + + Tensor key_tensor(ctx->allocator({}), DT_STRING, {}); + key_tensor.scalar()() = key; + out_tensors->emplace_back(std::move(key_tensor)); + + Tensor value_tensor(ctx->allocator({}), DT_STRING, {}); + value_tensor.scalar()() = value; + out_tensors->emplace_back(std::move(value_tensor)); + + *end_of_sequence = false; + return Status::OK(); + } + // We have reached the end of the current file, so maybe + // move on to next file. + ResetStreamsLocked(); + ++current_file_index_; + } + + // Iteration ends when there are no more files to process. + if (current_file_index_ == dataset()->filenames_.size()) { + *end_of_sequence = true; + return Status::OK(); + } + + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + } while (true); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + return errors::Unimplemented("SaveInternal is currently not supported"); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return errors::Unimplemented( + "RestoreInternal is currently not supported"); + } + + private: + // Sets up SequenceFile streams to read from the topic at + // `current_file_index_`. + Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (current_file_index_ >= dataset()->filenames_.size()) { + return errors::InvalidArgument( + "current_file_index_:", current_file_index_, + " >= filenames_.size():", dataset()->filenames_.size()); + } + + // Actually move on to next file. + const string& filename = dataset()->filenames_[current_file_index_]; + TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file_)); + reader_.reset(new SequenceFileReader(file_.get())); + return reader_->ReadHeader(); + } + + // Resets all Hadoop SequenceFile streams. + void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + reader_.reset(); + file_.reset(); + } + + mutex mu_; + size_t current_file_index_ GUARDED_BY(mu_) = 0; + std::unique_ptr file_ GUARDED_BY(mu_); + std::unique_ptr reader_ GUARDED_BY(mu_); + }; + + const std::vector filenames_; + const DataTypeVector output_types_; + }; + DataTypeVector output_types_; +}; +} // namespace + +REGISTER_KERNEL_BUILDER(Name("SequenceFileDataset").Device(DEVICE_CPU), + SequenceFileDatasetOp); + +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc b/tensorflow/contrib/hadoop/ops/dataset_ops.cc similarity index 52% rename from tensorflow/contrib/tensorrt/ops/trt_calib_op.cc rename to tensorflow/contrib/hadoop/ops/dataset_ops.cc index 4835e5065068ec7a59995eb7f6126b31aecf6704..66ad549b4756028a45c1ce76db4a2367517f81a5 100644 --- a/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc +++ b/tensorflow/contrib/hadoop/ops/dataset_ops.cc @@ -13,25 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" + namespace tensorflow { -REGISTER_OP("TRTCalibOp") - .Attr("segment_nodes: list(string)") // names of the ops in segment - .Attr("segment_output_names: list(string)") // names of the output ops in - // segment - .Attr("input_names: list(string)") // names of the inputs for - // passing into tensorrt - .Attr("resource_name: string") - .Attr("InT: list({int8, float16, float32})") - .Input("in_tensor: InT") - .Output("out_tensor: InT") - .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) { - for (int i = 0; i < c->num_inputs(); i++) { - c->set_output(i, c->input(i)); - } - return Status::OK(); - }); +REGISTER_OP("SequenceFileDataset") + .Input("filenames: string") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape); } // namespace tensorflow diff --git a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d796e43d877e463fa4398741748013b2eb661155 --- /dev/null +++ b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py @@ -0,0 +1,66 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================== +"""Tests for SequenceFileDataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.hadoop.python.ops import hadoop_dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test + + +class SequenceFileDatasetTest(test.TestCase): + + def test_sequence_file_dataset(self): + """Test case for SequenceFileDataset. + + The file is generated with `org.apache.hadoop.io.Text` for key/value. + There are 25 records in the file with the format of: + key = XXX + value = VALUEXXX + where XXX is replaced as the line number (starts with 001). + """ + filename = os.path.join(resource_loader.get_data_files_path(), + "testdata", "string.seq") + + filenames = constant_op.constant([filename], dtypes.string) + num_repeats = 2 + + dataset = hadoop_dataset_ops.SequenceFileDataset(filenames).repeat( + num_repeats) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for _ in range(num_repeats): # Dataset is repeated. + for i in range(25): # 25 records. + v0 = b"%03d" % (i + 1) + v1 = b"VALUE%03d" % (i + 1) + self.assertEqual((v0, v1), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/hadoop/python/kernel_tests/testdata/string.seq b/tensorflow/contrib/hadoop/python/kernel_tests/testdata/string.seq new file mode 100755 index 0000000000000000000000000000000000000000..b7175338af3417a8858d66082ab5a616f87cb234 Binary files /dev/null and b/tensorflow/contrib/hadoop/python/kernel_tests/testdata/string.seq differ diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6e0e628655fbc32a43fad2dc4883b26c6ad57c48 --- /dev/null +++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py @@ -0,0 +1,75 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SequenceFile Dataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.hadoop.python.ops import gen_dataset_ops +from tensorflow.contrib.hadoop.python.ops import hadoop_op_loader # pylint: disable=unused-import +from tensorflow.python.data.ops.dataset_ops import Dataset +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape + + +class SequenceFileDataset(Dataset): + """A Sequence File Dataset that reads the sequence file.""" + + def __init__(self, filenames): + """Create a `SequenceFileDataset`. + + `SequenceFileDataset` allows a user to read data from a hadoop sequence + file. A sequence file consists of (key value) pairs sequentially. At + the moment, `org.apache.hadoop.io.Text` is the only serialization type + being supported, and there is no compression support. + + For example: + + ```python + dataset = tf.contrib.hadoop.SequenceFileDataset("/foo/bar.seq") + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + # Prints the (key, value) pairs inside a hadoop sequence file. + while True: + try: + print(sess.run(next_element)) + except tf.errors.OutOfRangeError: + break + ``` + + Args: + filenames: A `tf.string` tensor containing one or more filenames. + """ + super(SequenceFileDataset, self).__init__() + self._filenames = ops.convert_to_tensor( + filenames, dtype=dtypes.string, name="filenames") + + def _as_variant_tensor(self): + return gen_dataset_ops.sequence_file_dataset( + self._filenames, nest.flatten(self.output_types)) + + @property + def output_classes(self): + return ops.Tensor, ops.Tensor + + @property + def output_shapes(self): + return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) + + @property + def output_types(self): + return dtypes.string, dtypes.string diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_op_loader.py b/tensorflow/contrib/hadoop/python/ops/hadoop_op_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..6dbf1253f3f746de0da9664b4262cb208bee9c98 --- /dev/null +++ b/tensorflow/contrib/hadoop/python/ops/hadoop_op_loader.py @@ -0,0 +1,24 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python helper for loading hadoop ops and kernels.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader + +_dataset_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("../../_dataset_ops.so")) diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc index c2e32da133b32c8fe169302668031af8bace2c22..693724b45751b82f942bc9416e6fb2ba33b30e22 100644 --- a/tensorflow/contrib/image/kernels/image_ops.cc +++ b/tensorflow/contrib/image/kernels/image_ops.cc @@ -35,6 +35,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; template struct FillProjectiveTransform; template struct FillProjectiveTransform; template struct FillProjectiveTransform; +template struct FillProjectiveTransform; template struct FillProjectiveTransform; template struct FillProjectiveTransform; @@ -70,6 +71,7 @@ class ImageProjectiveTransform : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor& images_t = ctx->input(0); const Tensor& transform_t = ctx->input(1); + const Tensor& shape_t = ctx->input(2); OP_REQUIRES(ctx, images_t.shape().dims() == 4, errors::InvalidArgument("Input images must have rank 4")); OP_REQUIRES(ctx, @@ -80,11 +82,28 @@ class ImageProjectiveTransform : public OpKernel { ProjectiveGenerator::kNumParameters), errors::InvalidArgument( "Input transform should be num_images x 8 or 1 x 8")); - auto images = images_t.tensor(); - auto transform = transform_t.matrix(); + OP_REQUIRES(ctx, shape_t.dims() == 1, + errors::InvalidArgument("output shape must be 1-dimensional", + shape_t.shape().DebugString())); + OP_REQUIRES(ctx, shape_t.NumElements() == 2, + errors::InvalidArgument("output shape must have two elements", + shape_t.shape().DebugString())); + auto shape_vec = shape_t.vec(); + int32 out_height = shape_vec(0); + int32 out_width = shape_vec(1); + OP_REQUIRES(ctx, out_height > 0 && out_width > 0, + errors::InvalidArgument("output dimensions must be positive")); + Tensor* output_t; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t)); + OP_REQUIRES_OK(ctx, ctx->allocate_output( + 0, + TensorShape({images_t.dim_size(0), out_height, + out_width, images_t.dim_size(3)}), + &output_t)); auto output = output_t->tensor(); + auto images = images_t.tensor(); + auto transform = transform_t.matrix(); + (FillProjectiveTransform(interpolation_))( ctx->eigen_device(), &output, images, transform); } @@ -99,6 +118,7 @@ class ImageProjectiveTransform : public OpKernel { TF_CALL_uint8(REGISTER); TF_CALL_int32(REGISTER); TF_CALL_int64(REGISTER); +TF_CALL_half(REGISTER); TF_CALL_float(REGISTER); TF_CALL_double(REGISTER); @@ -127,10 +147,11 @@ TF_CALL_double(DECLARE_FUNCTOR); } // end namespace functor -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("dtype"), \ +#define REGISTER(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("dtype") \ + .HostMemory("output_shape"), \ ImageProjectiveTransform) TF_CALL_uint8(REGISTER); diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h index ad501330617be89c87a0e94ab6e8773a6e1eecf6..6b63eed1303accc330293b3a44cdb9def7881666 100644 --- a/tensorflow/contrib/image/kernels/image_ops.h +++ b/tensorflow/contrib/image/kernels/image_ops.h @@ -21,6 +21,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/types.h" @@ -58,6 +59,11 @@ class ProjectiveGenerator { ? transforms_.data() : &transforms_.data()[transforms_.dimension(1) * coords[0]]; float projection = transform[6] * output_x + transform[7] * output_y + 1.f; + if (projection == 0) { + // Return the fill value (0) for infinite coordinates, + // which are outside the input image + return T(0); + } const float input_x = (transform[0] * output_x + transform[1] * output_y + transform[2]) / projection; @@ -105,21 +111,21 @@ class ProjectiveGenerator { // f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor) // + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor) const float value_yfloor = - (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_floor), - DenseIndex(x_floor), channel, - fill_value) + - (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_floor), - DenseIndex(x_ceil), channel, - fill_value); + (x_ceil - x) * static_cast(read_with_fill_value( + batch, DenseIndex(y_floor), DenseIndex(x_floor), + channel, fill_value)) + + (x - x_floor) * static_cast(read_with_fill_value( + batch, DenseIndex(y_floor), DenseIndex(x_ceil), + channel, fill_value)); // f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil) // + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil) const float value_yceil = - (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_ceil), - DenseIndex(x_floor), channel, - fill_value) + - (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_ceil), - DenseIndex(x_ceil), channel, - fill_value); + (x_ceil - x) * static_cast(read_with_fill_value( + batch, DenseIndex(y_ceil), DenseIndex(x_floor), + channel, fill_value)) + + (x - x_floor) * static_cast(read_with_fill_value( + batch, DenseIndex(y_ceil), DenseIndex(x_ceil), + channel, fill_value)); // f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor) // + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil) return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil); @@ -161,7 +167,7 @@ struct FillProjectiveTransform { void operator()(const Device& device, OutputType* output, const InputType& images, const TransformsType& transform) const { - output->device(device) = images.generate( + output->device(device) = output->generate( ProjectiveGenerator(images, transform, interpolation_)); } }; diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc index ebdcaea7abae2a967786831b62b331897aa3f6a3..4969ac58f96c8c0b829828ad7617a0bb5520cd6a 100644 --- a/tensorflow/contrib/image/ops/image_ops.cc +++ b/tensorflow/contrib/image/ops/image_ops.cc @@ -19,23 +19,66 @@ limitations under the License. namespace tensorflow { +using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; +namespace { + +// Sets output[0] to shape [batch_dim,height,width,channel_dim], where +// height and width come from the size_tensor. +Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim, + int size_input_idx, DimensionHandle channel_dim) { + // Verify shape of size input. + ShapeHandle size; + TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size)); + DimensionHandle unused; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused)); + + // Get size values from the size tensor. + const Tensor* size_tensor = c->input_tensor(size_input_idx); + DimensionHandle width; + DimensionHandle height; + if (size_tensor == nullptr) { + width = c->UnknownDim(); + height = c->UnknownDim(); + } else { + // TODO(petewarden) - Remove once we have constant evaluation in C++ only. + if (size_tensor->dtype() != DT_INT32) { + return errors::InvalidArgument( + "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 " + "but got ", + DataTypeString(size_tensor->dtype()), " for input #", size_input_idx, + " in ", c->DebugString()); + } + auto vec = size_tensor->vec(); + height = c->MakeDim(vec(0)); + width = c->MakeDim(vec(1)); + } + c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim})); + return Status::OK(); +} + +// TODO(qyu): Move this to core/framework/common_shape_fns.h +Status ResizeShapeFn(InferenceContext* c) { + ShapeHandle input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); + return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */, + c->Dim(input, 3)); +} + +} // namespace + // TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc. // TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0). -// TODO(ringwalt): Add an "output_shape" argument. This is sufficient to -// implement "same" and "valid" modes in the Python function. REGISTER_OP("ImageProjectiveTransform") .Input("images: dtype") .Input("transforms: float32") - .Attr("dtype: {uint8, int32, int64, float32, float64}") + .Input("output_shape: int32") + .Attr("dtype: {uint8, int32, int64, float16, float32, float64}") .Attr("interpolation: string") .Output("transformed_images: dtype") - .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); - return Status::OK(); - }) + .SetShapeFn(ResizeShapeFn) .Doc(R"doc( Applies the given transform to each of the images. @@ -49,7 +92,7 @@ If one row of `transforms` is `[a0, a1, a2, b0, b1, b2, c0, c1]`, then it maps the *output* point `(x, y)` to a transformed *input* point `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, where `k = c0 x + c1 y + 1`. If the transformed point lays outside of the input -image, the output pixel is set to 0. The output is the same size as the input, +image, the output pixel is set to 0. images: 4D `Tensor`, input image(s) in NHWC format. transforms: 2D `Tensor`, projective transform(s) to apply to the image(s). diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py index b50177ae5651fbc15f292e11031411c2074357ec..f588eae923f403f07c7f502821db4ef6acad71d5 100644 --- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py @@ -27,10 +27,12 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.platform import googletest _DTYPES = set( - [dtypes.uint8, dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64]) + [dtypes.uint8, dtypes.int32, dtypes.int64, + dtypes.float16, dtypes.float32, dtypes.float64]) class ImageOpsTest(test_util.TensorFlowTestCase): @@ -127,6 +129,23 @@ class ImageOpsTest(test_util.TensorFlowTestCase): [0, 1, 0, 1], [0, 1, 1, 1]]) + def test_extreme_projective_transform(self): + for dtype in _DTYPES: + with self.test_session(): + image = constant_op.constant( + [[1, 0, 1, 0], + [0, 1, 0, 1], + [1, 0, 1, 0], + [0, 1, 0, 1]], dtype=dtype) + transformation = constant_op.constant([1, 0, 0, 0, 1, 0, -1, 0], + dtypes.float32) + image_transformed = image_ops.transform(image, transformation) + self.assertAllEqual(image_transformed.eval(), + [[1, 0, 0, 0], + [0, 0, 0, 0], + [1, 0, 0, 0], + [0, 0, 0, 0]]) + def test_bilinear(self): with self.test_session(): image = constant_op.constant( @@ -176,6 +195,19 @@ class ImageOpsTest(test_util.TensorFlowTestCase): [0.0, 149, 233, 149, 0.0], [0.0, 0.0, 87., 0.0, 0.0]]) + def test_rotate_static_shape(self): + image = array_ops.diag([1., 2., 3.]) + result = image_ops.rotate( + image, random_ops.random_uniform((), -1, 1), interpolation="BILINEAR") + self.assertEqual(image.get_shape(), result.get_shape()) + + def test_transform_static_output_shape(self): + image = constant_op.constant([[1., 2.], [3., 4.]]) + result = image_ops.transform( + image, random_ops.random_uniform([8], -1, 1), + output_shape=constant_op.constant([3, 5])) + self.assertAllEqual([3, 5], result.get_shape()) + def _test_grad(self, shape_to_test): with self.test_session(): test_image_shape = shape_to_test @@ -195,10 +227,40 @@ class ImageOpsTest(test_util.TensorFlowTestCase): x_init_value=test_image) self.assertLess(left_err, 1e-10) + def _test_grad_different_shape(self, input_shape, output_shape): + with self.test_session(): + test_image_shape = input_shape + test_image = np.random.randn(*test_image_shape) + test_image_tensor = constant_op.constant( + test_image, shape=test_image_shape) + test_transform = image_ops.angles_to_projective_transforms( + np.pi / 2, 4, 4) + + if len(output_shape) == 2: + resize_shape = output_shape + elif len(output_shape) == 3: + resize_shape = output_shape[0:2] + elif len(output_shape) == 4: + resize_shape = output_shape[1:3] + output = image_ops.transform( + images=test_image_tensor, + transforms=test_transform, + output_shape=resize_shape) + left_err = gradient_checker.compute_gradient_error( + test_image_tensor, + test_image_shape, + output, + output_shape, + x_init_value=test_image) + self.assertLess(left_err, 1e-10) + def test_grad(self): self._test_grad([16, 16]) self._test_grad([4, 12, 12]) self._test_grad([3, 4, 12, 12]) + self._test_grad_different_shape([16, 16], [8, 8]) + self._test_grad_different_shape([4, 12, 3], [8, 24, 3]) + self._test_grad_different_shape([3, 4, 12, 3], [3, 8, 24, 3]) class BipartiteMatchTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index cd984c80543886be1f682933e2e003bd3374e425..e7a09041adb33981df0a8c8238bc5b9358f14180 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops @@ -33,12 +34,16 @@ _image_ops_so = loader.load_op_library( resource_loader.get_path_to_datafile("_image_ops.so")) _IMAGE_DTYPES = set( - [dtypes.uint8, dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64]) + [dtypes.uint8, dtypes.int32, dtypes.int64, + dtypes.float16, dtypes.float32, dtypes.float64]) ops.RegisterShape("ImageConnectedComponents")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn) +# TODO(ringwalt): Support a "reshape" (name used by SciPy) or "expand" (name +# used by PIL, maybe more readable) mode, which determines the correct +# output_shape and translation for the transform. def rotate(images, angles, interpolation="NEAREST", name=None): """Rotate image(s) counterclockwise by the passed angle(s) in radians. @@ -212,7 +217,11 @@ def translations_to_projective_transforms(translations, name=None): axis=1) -def transform(images, transforms, interpolation="NEAREST", name=None): +def transform(images, + transforms, + interpolation="NEAREST", + output_shape=None, + name=None): """Applies the given transform(s) to the image(s). Args: @@ -229,6 +238,10 @@ def transform(images, transforms, interpolation="NEAREST", name=None): the transform mapping input points to output points. Note that gradients are not backpropagated into transformation parameters. interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". + output_shape: Output dimesion after the transform, [height, width]. + If None, output is the same size as input image. + + name: The name of the op. Returns: Image(s) with the same type and shape as `images`, with the given @@ -237,6 +250,7 @@ def transform(images, transforms, interpolation="NEAREST", name=None): Raises: TypeError: If `image` is an invalid type. + ValueError: If output shape is not 1-D int32 Tensor. """ with ops.name_scope(name, "transform"): image_or_images = ops.convert_to_tensor(images, name="images") @@ -255,6 +269,17 @@ def transform(images, transforms, interpolation="NEAREST", name=None): else: raise TypeError("Images should have rank between 2 and 4.") + if output_shape is None: + output_shape = tensor_util.constant_value( + array_ops.shape(images)[1:3]) or array_ops.shape(images)[1:3] + + output_shape = ops.convert_to_tensor( + output_shape, dtypes.int32, name="output_shape") + + if not output_shape.get_shape().is_compatible_with([2]): + raise ValueError("output_shape must be a 1-D Tensor of 2 elements: " + "new_height, new_width") + if len(transform_or_transforms.get_shape()) == 1: transforms = transform_or_transforms[None] elif transform_or_transforms.get_shape().ndims is None: @@ -264,8 +289,12 @@ def transform(images, transforms, interpolation="NEAREST", name=None): transforms = transform_or_transforms else: raise TypeError("Transforms should have rank 1 or 2.") + output = gen_image_ops.image_projective_transform( - images, transforms, interpolation=interpolation.upper()) + images, + output_shape=output_shape, + transforms=transforms, + interpolation=interpolation.upper()) if len(image_or_images.get_shape()) == 2: return output[0, :, :, 0] elif len(image_or_images.get_shape()) == 3: @@ -375,14 +404,6 @@ def _image_projective_transform_grad(op, grad): if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: raise TypeError("Invalid dtype %s." % image_or_images.dtype) - if len(image_or_images.get_shape()) == 2: - images = image_or_images[None, :, :, None] - elif len(image_or_images.get_shape()) == 3: - images = image_or_images[None, :, :, :] - elif len(image_or_images.get_shape()) == 4: - images = image_or_images - else: - raise TypeError("Images should have rank between 2 and 4") if len(transform_or_transforms.get_shape()) == 1: transforms = transform_or_transforms[None] elif len(transform_or_transforms.get_shape()) == 2: @@ -395,13 +416,11 @@ def _image_projective_transform_grad(op, grad): inverse = linalg_ops.matrix_inverse(transforms) transforms = matrices_to_flat_transforms(inverse) output = gen_image_ops.image_projective_transform( - grad, transforms, interpolation=interpolation) - if len(image_or_images.get_shape()) == 2: - return [output[0, :, :, 0], None] - elif len(image_or_images.get_shape()) == 3: - return [output[0, :, :, :], None] - else: - return [output, None] + images=grad, + transforms=transforms, + output_shape=array_ops.shape(image_or_images)[1:3], + interpolation=interpolation) + return [output, None, None] def bipartite_match(distance_mat, diff --git a/tensorflow/contrib/image/python/ops/sparse_image_warp.py b/tensorflow/contrib/image/python/ops/sparse_image_warp.py index 54a215d6db6ded56a1a4a018a7e176f35fe6397e..1ea8f705b7e6f522281de6384de0d42efab6a406 100644 --- a/tensorflow/contrib/image/python/ops/sparse_image_warp.py +++ b/tensorflow/contrib/image/python/ops/sparse_image_warp.py @@ -112,10 +112,10 @@ def sparse_image_warp(image, Apply a non-linear warp to the image, where the warp is specified by the source and destination locations of a (potentially small) number of control points. First, we use a polyharmonic spline - (@{tf.contrib.image.interpolate_spline}) to interpolate the displacements + (`tf.contrib.image.interpolate_spline`) to interpolate the displacements between the corresponding control points to a dense flow field. Then, we warp the image using this dense flow field - (@{tf.contrib.image.dense_image_warp}). + (`tf.contrib.image.dense_image_warp`). Let t index our control points. For regularization_weight=0, we have: warped_image[b, dest_control_point_locations[b, t, 0], @@ -126,7 +126,7 @@ def sparse_image_warp(image, For regularization_weight > 0, this condition is met approximately, since regularized interpolation trades off smoothness of the interpolant vs. reconstruction of the interpolant at the control points. - See @{tf.contrib.image.interpolate_spline} for further documentation of the + See `tf.contrib.image.interpolate_spline` for further documentation of the interpolation_order and regularization_weight arguments. diff --git a/tensorflow/contrib/integrate/python/ops/odes.py b/tensorflow/contrib/integrate/python/ops/odes.py index 61f78febfc07bb4e677259366a81c16b2b585244..7b7ac4f347e30d20eb2f4889e0cae5669c975e4f 100644 --- a/tensorflow/contrib/integrate/python/ops/odes.py +++ b/tensorflow/contrib/integrate/python/ops/odes.py @@ -73,7 +73,7 @@ def _scaled_dot_product(scale, xs, ys, name=None): # _possibly_nonzero lets us avoid wasted computation. return math_ops.add_n( [(scale * x) * y for x, y in zip(xs, ys) - if _possibly_nonzero(x) or _possibly_nonzero(y)], + if _possibly_nonzero(x) and _possibly_nonzero(y)], name=scope) @@ -122,7 +122,7 @@ def _runge_kutta_step(func, yi = y0 + _scaled_dot_product(dt_cast, beta_i, k) k.append(func(yi, ti)) - if not (tableau.c_sol[-1] == 0 and tableau.c_sol == tableau.beta[-1]): + if not (tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]): # This property (true for Dormand-Prince) lets us save a few FLOPs. yi = y0 + _scaled_dot_product(dt_cast, tableau.c_sol, k) diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc index 2638b25ec424b5b4ef556ff769e94e64da32fec2..d0ea961473c7d6a07b152d1450b0ca2fdf1dc11f 100644 --- a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc +++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" -#include "src-cpp/rdkafkacpp.h" +#include "rdkafkacpp.h" namespace tensorflow { @@ -52,12 +52,12 @@ class KafkaDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, std::vector topics, const string& servers, const string& group, const bool eof, const int64 timeout) - : GraphDatasetBase(ctx), + : DatasetBase(DatasetContext(ctx)), topics_(std::move(topics)), servers_(servers), group_(group), @@ -84,7 +84,8 @@ class KafkaDatasetOp : public DatasetOpKernel { string DebugString() const override { return "KafkaDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, Node** output) const override { Node* topics = nullptr; TF_RETURN_IF_ERROR(b->AddVector(topics_, &topics)); diff --git a/tensorflow/contrib/kafka/ops/kafka_ops.cc b/tensorflow/contrib/kafka/ops/kafka_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..8cdf16103bab2b22d51c144d21a589e1e39f2f0b --- /dev/null +++ b/tensorflow/contrib/kafka/ops/kafka_ops.cc @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("KafkaDataset") + .Input("topics: string") + .Input("servers: string") + .Input("group: string") + .Input("eof: bool") + .Input("timeout: int64") + .Output("handle: variant") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that emits the messages of one or more Kafka topics. + +topics: A `tf.string` tensor containing one or more subscriptions, + in the format of [topic:partition:offset:length], + by default length is -1 for unlimited. +servers: A list of bootstrap servers. +group: The consumer group id. +eof: If True, the kafka reader will stop on EOF. +timeout: The timeout value for the Kafka Consumer to wait + (in millisecond). +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/keras/__init__.py b/tensorflow/contrib/keras/__init__.py index a162f0cb584038b8df7d1ee6fe8237160ad8f695..cecf1ddcdb1c6e1b6a6f895b83a6c4f2a2aae1f7 100644 --- a/tensorflow/contrib/keras/__init__.py +++ b/tensorflow/contrib/keras/__init__.py @@ -15,7 +15,7 @@ # ============================================================================== """Implementation of the Keras API meant to be a high-level API for TensorFlow. -This module an alias for @{tf.keras}, for backwards compatibility. +This module an alias for `tf.keras`, for backwards compatibility. Detailed documentation and user guides are also available at [keras.io](https://keras.io). diff --git a/tensorflow/contrib/keras/api/keras/layers/__init__.py b/tensorflow/contrib/keras/api/keras/layers/__init__.py index 938c881fcbe18623fa18c21c112375f9914f887b..3327a9f9a613bfb56e6a25af0fe1c0ca18609035 100644 --- a/tensorflow/contrib/keras/api/keras/layers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/layers/__init__.py @@ -20,10 +20,10 @@ from __future__ import print_function # Generic layers. # pylint: disable=g-bad-import-order -from tensorflow.python.keras.engine import Input -from tensorflow.python.keras.engine import InputLayer -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.engine.input_layer import Input +from tensorflow.python.keras.engine.input_layer import InputLayer # Advanced activations. from tensorflow.python.keras.layers.advanced_activations import LeakyReLU diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py index 1f9e82b41bf09b235e93fa512a50ea4c3047c01b..cb649a37510c301cb3df997f844617e9a4e6c7be 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py @@ -18,10 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras.preprocessing.image import apply_transform from tensorflow.python.keras.preprocessing.image import array_to_img from tensorflow.python.keras.preprocessing.image import DirectoryIterator -from tensorflow.python.keras.preprocessing.image import flip_axis from tensorflow.python.keras.preprocessing.image import ImageDataGenerator from tensorflow.python.keras.preprocessing.image import img_to_array from tensorflow.python.keras.preprocessing.image import Iterator diff --git a/tensorflow/contrib/kernel_methods/README.md b/tensorflow/contrib/kernel_methods/README.md index 44ed9670a09ece8fb11e79a3e58725e2a54e513b..1bce3277ff46ac91a8de118db17041a0e424ebc0 100644 --- a/tensorflow/contrib/kernel_methods/README.md +++ b/tensorflow/contrib/kernel_methods/README.md @@ -21,13 +21,15 @@ Currently, there is a [RandomFourierFeatureMapper](https://www.tensorflow.org/co output. More mappers are on the way. ## Kernel-based Estimators -These are estimators inheriting from the @{tf.contrib.learn.Estimator} class and -use kernel mappers internally to discover non-linearities in the data. These -canned estimators map their input features using kernel mapper Ops and then -apply linear models to the mapped features. Combining kernel mappers with linear -models and different loss functions leads to a variety of models: linear and -non-linear SVMs, linear regression (with and without kernels) and (multinomial) -logistic regression (with and without kernels). + +These estimators inherit from the +[`tf.contrib.learn.Estimator`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/estimator.py) +class and use kernel mappers internally to discover non-linearities in the +data. These canned estimators map their input features using kernel mapper +Ops and then apply linear models to the mapped features. Combining kernel +mappers with linear models and different loss functions leads to a variety of +models: linear and non-linear SVMs, linear regression (with and without +kernels) and (multinomial) logistic regression (with and without kernels). Currently there is a [KernelLinearClassifier](https://www.tensorflow.org/code/tensorflow/contrib/kernel_methods/python/kernel_estimators.py) implemented but more pre-packaged estimators are on the way. diff --git a/tensorflow/contrib/kinesis/BUILD b/tensorflow/contrib/kinesis/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..25443d0ad47aa7d503f905eb34000488b62f22c6 --- /dev/null +++ b/tensorflow/contrib/kinesis/BUILD @@ -0,0 +1,113 @@ +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_custom_op_py_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_kernel_library", + "tf_py_test", +) + +py_library( + name = "kinesis", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_ops", + ], +) + +tf_custom_op_library( + name = "_dataset_ops.so", + srcs = ["ops/dataset_ops.cc"], + deps = [":dataset_kernels"], +) + +tf_gen_op_libs( + op_lib_names = ["dataset_ops"], +) + +cc_library( + name = "dataset_kernels", + srcs = [ + "kernels/kinesis_dataset_ops.cc", + ], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core/platform/s3:aws_crypto", + "//third_party/eigen3", + "@aws", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, +) + +py_library( + name = "dataset_ops", + srcs = [ + "python/ops/kinesis_dataset_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":kinesis_op_loader", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_dataset_ops", + out = "python/ops/gen_dataset_ops.py", + deps = ["//tensorflow/contrib/kinesis:dataset_ops_op_lib"], +) + +tf_kernel_library( + name = "dataset_ops_kernels", + deps = [ + ":dataset_kernels", + "//tensorflow/core:framework", + ], + alwayslink = 1, +) + +tf_custom_op_py_library( + name = "kinesis_op_loader", + srcs = ["python/ops/kinesis_op_loader.py"], + dso = ["//tensorflow/contrib/kinesis:_dataset_ops.so"], + kernels = [ + ":dataset_ops_kernels", + "//tensorflow/contrib/kinesis:dataset_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_dataset_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + ], +) + +tf_py_test( + name = "kinesis_test", + srcs = ["python/kernel_tests/kinesis_test.py"], + additional_deps = [ + ":kinesis", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + tags = [ + "manual", + "no_windows", + "notap", + ], +) diff --git a/tensorflow/python/training/checkpointable/data_structures_base.py b/tensorflow/contrib/kinesis/__init__.py similarity index 71% rename from tensorflow/python/training/checkpointable/data_structures_base.py rename to tensorflow/contrib/kinesis/__init__.py index f1b2cf105b81490ea12e0a667f53fb02d45135c9..3824b8ae7532ab97a5ebf01ab66ece6476c87d42 100644 --- a/tensorflow/python/training/checkpointable/data_structures_base.py +++ b/tensorflow/contrib/kinesis/__init__.py @@ -1,4 +1,3 @@ -"""A trivial base class to avoid circular imports for isinstance checks.""" # Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,15 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Kinesis Dataset. + +@@KinesisDataset +""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.kinesis.python.ops.kinesis_dataset_ops import KinesisDataset -from tensorflow.python.training.checkpointable import base as checkpointable_lib - +from tensorflow.python.util.all_util import remove_undocumented -class CheckpointableDataStructureBase(checkpointable_lib.CheckpointableBase): - """Base class for data structures which contain checkpointable objects.""" +_allowed_symbols = [ + "KinesisDataset", +] - pass +remove_undocumented(__name__) diff --git a/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..95c7001371a9b43f2e6c0c66245cc4f1fafc486d --- /dev/null +++ b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc @@ -0,0 +1,360 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/platform/s3/aws_crypto.h" + +namespace tensorflow { +namespace { + +Aws::Client::ClientConfiguration* InitializeDefaultClientConfig() { + static Aws::Client::ClientConfiguration config; + const char* endpoint = getenv("KINESIS_ENDPOINT"); + if (endpoint) { + config.endpointOverride = Aws::String(endpoint); + } + const char* region = getenv("AWS_REGION"); + if (region) { + config.region = Aws::String(region); + } else { + // Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG + // is set with a truthy value. + const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG"); + string load_config = + load_config_env ? str_util::Lowercase(load_config_env) : ""; + if (load_config == "true" || load_config == "1") { + Aws::String config_file; + // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config. + const char* config_file_env = getenv("AWS_CONFIG_FILE"); + if (config_file_env) { + config_file = config_file_env; + } else { + const char* home_env = getenv("HOME"); + if (home_env) { + config_file = home_env; + config_file += "/.aws/config"; + } + } + Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file); + // Load the configuration. If successful, get the region. + // If the load is not successful, then generate a warning. + if (loader.Load()) { + auto profiles = loader.GetProfiles(); + if (!profiles["default"].GetRegion().empty()) { + config.region = profiles["default"].GetRegion(); + } + } else { + LOG(WARNING) << "Failed to load the profile in " << config_file << "."; + } + } + } + const char* use_https = getenv("KINESIS_USE_HTTPS"); + if (use_https) { + if (use_https[0] == '0') { + config.scheme = Aws::Http::Scheme::HTTP; + } else { + config.scheme = Aws::Http::Scheme::HTTPS; + } + } + const char* verify_ssl = getenv("KINESIS_VERIFY_SSL"); + if (verify_ssl) { + if (verify_ssl[0] == '0') { + config.verifySSL = false; + } else { + config.verifySSL = true; + } + } + const char* connect_timeout = getenv("KINESIS_CONNECT_TIMEOUT_MSEC"); + if (connect_timeout) { + int64 timeout; + + if (strings::safe_strto64(connect_timeout, &timeout)) { + config.connectTimeoutMs = timeout; + } + } + const char* request_timeout = getenv("KINESIS_REQUEST_TIMEOUT_MSEC"); + if (request_timeout) { + int64 timeout; + + if (strings::safe_strto64(request_timeout, &timeout)) { + config.requestTimeoutMs = timeout; + } + } + + return &config; +} + +Aws::Client::ClientConfiguration& GetDefaultClientConfig() { + static Aws::Client::ClientConfiguration* config = + InitializeDefaultClientConfig(); + return *config; +} + +static mutex mu(LINKER_INITIALIZED); +static unsigned count(0); +void AwsInitAPI() { + mutex_lock lock(mu); + count++; + if (count == 1) { + Aws::SDKOptions options; + options.cryptoOptions.sha256Factory_create_fn = []() { + return Aws::MakeShared(AWSCryptoAllocationTag); + }; + options.cryptoOptions.sha256HMACFactory_create_fn = []() { + return Aws::MakeShared(AWSCryptoAllocationTag); + }; + Aws::InitAPI(options); + } +} +void AwsShutdownAPI() { + mutex_lock lock(mu); + count--; + if (count == 0) { + Aws::SDKOptions options; + Aws::ShutdownAPI(options); + } +} +void ShutdownClient(Aws::Kinesis::KinesisClient* client) { + if (client != nullptr) { + delete client; + AwsShutdownAPI(); + } +} +} +class KinesisDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + std::string stream = ""; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "stream", &stream)); + std::string shard = ""; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "shard", &shard)); + bool read_indefinitely = true; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "read_indefinitely", + &read_indefinitely)); + int64 interval = -1; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "interval", &interval)); + OP_REQUIRES(ctx, (interval > 0), + errors::InvalidArgument( + "Interval value should be large than 0, got ", interval)); + *output = new Dataset(ctx, stream, shard, read_indefinitely, interval); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const string& stream, const string& shard, + const bool read_indefinitely, const int64 interval) + : DatasetBase(DatasetContext(ctx)), + stream_(stream), + shard_(shard), + read_indefinitely_(read_indefinitely), + interval_(interval) {} + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::Kinesis")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}}); + return *shapes; + } + + string DebugString() const override { return "KinesisDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* stream = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(stream_, &stream)); + Node* shard = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(shard_, &shard)); + Node* read_indefinitely = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(read_indefinitely_, &read_indefinitely)); + Node* interval = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(interval_, &interval)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, {stream, shard, read_indefinitely, interval}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params), + client_(nullptr, ShutdownClient) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (iterator_ == "") { + TF_RETURN_IF_ERROR(SetupStreamsLocked()); + } + do { + Aws::Kinesis::Model::GetRecordsRequest request; + auto outcome = client_->GetRecords( + request.WithShardIterator(iterator_).WithLimit(1)); + if (!outcome.IsSuccess()) { + return errors::Unknown(outcome.GetError().GetExceptionName(), ": ", + outcome.GetError().GetMessage()); + } + if (outcome.GetResult().GetRecords().size() == 0) { + // If no records were returned then nothing is available at the + // moment. + if (!dataset()->read_indefinitely_) { + *end_of_sequence = true; + return Status::OK(); + } + // Continue the loop after a period of time. + ctx->env()->SleepForMicroseconds(dataset()->interval_); + continue; + } + if (outcome.GetResult().GetRecords().size() != 1) { + return errors::Unknown("invalid number of records ", + outcome.GetResult().GetRecords().size(), + " returned"); + } + + iterator_ = outcome.GetResult().GetNextShardIterator(); + + const auto& data = outcome.GetResult().GetRecords()[0].GetData(); + StringPiece value( + reinterpret_cast(data.GetUnderlyingData()), + data.GetLength()); + Tensor value_tensor(ctx->allocator({}), DT_STRING, {}); + value_tensor.scalar()() = std::string(value); + out_tensors->emplace_back(std::move(value_tensor)); + + *end_of_sequence = false; + return Status::OK(); + } while (true); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + return errors::Unimplemented("SaveInternal is currently not supported"); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return errors::Unimplemented( + "RestoreInternal is currently not supported"); + } + + private: + // Sets up Kinesis streams to read from. + Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + AwsInitAPI(); + client_.reset( + new Aws::Kinesis::KinesisClient(GetDefaultClientConfig())); + + Aws::Kinesis::Model::DescribeStreamRequest request; + auto outcome = client_->DescribeStream( + request.WithStreamName(dataset()->stream_.c_str())); + if (!outcome.IsSuccess()) { + return errors::Unknown(outcome.GetError().GetExceptionName(), ": ", + outcome.GetError().GetMessage()); + } + Aws::String shard; + Aws::String sequence; + if (dataset()->shard_ == "") { + if (outcome.GetResult().GetStreamDescription().GetShards().size() != + 1) { + return errors::InvalidArgument( + "shard has to be provided unless the stream only have one " + "shard, there are ", + outcome.GetResult().GetStreamDescription().GetShards().size(), + " shards in stream ", dataset()->stream_); + } + shard = outcome.GetResult() + .GetStreamDescription() + .GetShards()[0] + .GetShardId(); + sequence = outcome.GetResult() + .GetStreamDescription() + .GetShards()[0] + .GetSequenceNumberRange() + .GetStartingSequenceNumber(); + } else { + for (const auto& entry : + outcome.GetResult().GetStreamDescription().GetShards()) { + if (entry.GetShardId() == dataset()->shard_.c_str()) { + shard = entry.GetShardId(); + sequence = + entry.GetSequenceNumberRange().GetStartingSequenceNumber(); + break; + } + } + if (shard == "") { + return errors::InvalidArgument("no shard ", dataset()->shard_, + " in stream ", dataset()->stream_); + } + } + + Aws::Kinesis::Model::GetShardIteratorRequest iterator_request; + auto iterator_outcome = client_->GetShardIterator( + iterator_request.WithStreamName(dataset()->stream_.c_str()) + .WithShardId(shard) + .WithShardIteratorType( + Aws::Kinesis::Model::ShardIteratorType::AT_SEQUENCE_NUMBER) + .WithStartingSequenceNumber(sequence)); + if (!iterator_outcome.IsSuccess()) { + return errors::Unknown(iterator_outcome.GetError().GetExceptionName(), + ": ", + iterator_outcome.GetError().GetMessage()); + } + iterator_ = iterator_outcome.GetResult().GetShardIterator(); + return Status::OK(); + } + + mutex mu_; + Aws::String iterator_ GUARDED_BY(mu_); + std::unique_ptr + client_ GUARDED_BY(mu_); + }; + + const std::string stream_; + const std::string shard_; + const bool read_indefinitely_; + const int64 interval_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("KinesisDataset").Device(DEVICE_CPU), + KinesisDatasetOp); + +} // namespace tensorflow diff --git a/tensorflow/contrib/kinesis/ops/dataset_ops.cc b/tensorflow/contrib/kinesis/ops/dataset_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..54204513cf22519ecfb5fa45748250ee0f4aac7a --- /dev/null +++ b/tensorflow/contrib/kinesis/ops/dataset_ops.cc @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("KinesisDataset") + .Input("stream: string") + .Input("shard: string") + .Input("read_indefinitely: bool") + .Input("interval: int64") + .Output("handle: variant") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that emits the messages of one or more Kinesis topics. + +stream: A `tf.string` tensor containing the name of the stream. +shard: A `tf.string` tensor containing the id of the shard. +read_indefinitely: If `True`, the Kinesis dataset will keep retry + again on `EOF` after the `interval` period. If `False`, then + the dataset will stop on `EOF`. The default value is `True`. +interval: The interval for the Kinesis Client to wait before + it tries to get records again (in millisecond). +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7289b45c50fa92455b4c317b8a039ca414fa585e --- /dev/null +++ b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py @@ -0,0 +1,139 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================== +"""Tests for KinesisDataset. +NOTE: boto3 is needed and the test has to be invoked manually: +``` +$ bazel test -s --verbose_failures --config=opt \ + --action_env=AWS_ACCESS_KEY_ID=XXXXXX \ + --action_env=AWS_SECRET_ACCESS_KEY=XXXXXX \ + //tensorflow/contrib/kinesis:kinesis_test +``` +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import boto3 + +from tensorflow.contrib.kinesis.python.ops import kinesis_dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class KinesisDatasetTest(test.TestCase): + + def testKinesisDatasetOneShard(self): + client = boto3.client('kinesis', region_name='us-east-1') + + # Setup the Kinesis with 1 shard. + stream_name = "tf_kinesis_test_1" + client.create_stream(StreamName=stream_name, ShardCount=1) + # Wait until stream exists, default is 10 * 18 seconds. + client.get_waiter('stream_exists').wait(StreamName=stream_name) + for i in range(10): + data = "D" + str(i) + client.put_record( + StreamName=stream_name, Data=data, PartitionKey="TensorFlow" + str(i)) + + stream = array_ops.placeholder(dtypes.string, shape=[]) + num_epochs = array_ops.placeholder(dtypes.int64, shape=[]) + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kinesis_dataset_ops.KinesisDataset( + stream, read_indefinitely=False).repeat(num_epochs) + batch_dataset = repeat_dataset.batch(batch_size) + + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + init_batch_op = iterator.make_initializer(batch_dataset) + get_next = iterator.get_next() + + with self.test_session() as sess: + # Basic test: read from shard 0 of stream 1. + sess.run(init_op, feed_dict={stream: stream_name, num_epochs: 1}) + for i in range(10): + self.assertEqual("D" + str(i), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + client.delete_stream(StreamName=stream_name) + # Wait until stream deleted, default is 10 * 18 seconds. + client.get_waiter('stream_not_exists').wait(StreamName=stream_name) + + def testKinesisDatasetTwoShards(self): + client = boto3.client('kinesis', region_name='us-east-1') + + # Setup the Kinesis with 2 shards. + stream_name = "tf_kinesis_test_2" + client.create_stream(StreamName=stream_name, ShardCount=2) + # Wait until stream exists, default is 10 * 18 seconds. + client.get_waiter('stream_exists').wait(StreamName=stream_name) + + for i in range(10): + data = "D" + str(i) + client.put_record( + StreamName=stream_name, Data=data, PartitionKey="TensorFlow" + str(i)) + response = client.describe_stream(StreamName=stream_name) + shard_id_0 = response["StreamDescription"]["Shards"][0]["ShardId"] + shard_id_1 = response["StreamDescription"]["Shards"][1]["ShardId"] + + stream = array_ops.placeholder(dtypes.string, shape=[]) + shard = array_ops.placeholder(dtypes.string, shape=[]) + num_epochs = array_ops.placeholder(dtypes.int64, shape=[]) + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kinesis_dataset_ops.KinesisDataset( + stream, shard, read_indefinitely=False).repeat(num_epochs) + batch_dataset = repeat_dataset.batch(batch_size) + + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + init_batch_op = iterator.make_initializer(batch_dataset) + get_next = iterator.get_next() + + data = list() + with self.test_session() as sess: + # Basic test: read from shard 0 of stream 2. + sess.run( + init_op, feed_dict={ + stream: stream_name, shard: shard_id_0, num_epochs: 1}) + with self.assertRaises(errors.OutOfRangeError): + # Use range(11) to guarantee the OutOfRangeError. + for i in range(11): + data.append(sess.run(get_next)) + + # Basic test: read from shard 1 of stream 2. + sess.run( + init_op, feed_dict={ + stream: stream_name, shard: shard_id_1, num_epochs: 1}) + with self.assertRaises(errors.OutOfRangeError): + # Use range(11) to guarantee the OutOfRangeError. + for i in range(11): + data.append(sess.run(get_next)) + + data.sort() + self.assertEqual(data, ["D" + str(i) for i in range(10)]) + + client.delete_stream(StreamName=stream_name) + # Wait until stream deleted, default is 10 * 18 seconds. + client.get_waiter('stream_not_exists').wait(StreamName=stream_name) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ca2df95ba4f20ec5fa58ff13530096e6e065f4fe --- /dev/null +++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py @@ -0,0 +1,96 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Kinesis Dataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader # pylint: disable=unused-import +from tensorflow.contrib.kinesis.python.ops import gen_dataset_ops +from tensorflow.python.data.ops.dataset_ops import Dataset +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape + + +class KinesisDataset(Dataset): + """A Kinesis Dataset that consumes the message. + + Kinesis is a managed service provided by AWS for data streaming. + This dataset reads messages from Kinesis with each message presented + as a `tf.string`. + + For example, we can construct and use the KinesisDataset as follows: + ```python + dataset = tf.contrib.kinesis.KinesisDataset( + "kinesis_stream_name", read_indefinitely=False) + next = dataset.make_one_shot_iterator().get_next() + with tf.Session() as sess: + while True: + try: + print(sess.run(nxt)) + except tf.errors.OutOfRangeError: + break + ``` + + Since Kinesis is a data streaming service, data may not be available + at the time it is being read. The argument `read_indefinitely` is + used to control the behavior in this situation. If `read_indefinitely` + is `True`, then `KinesisDataset` will keep retrying to retrieve data + from the stream. If `read_indefinitely` is `False`, an `OutOfRangeError` + is returned immediately instead. + """ + + def __init__(self, + stream, + shard="", + read_indefinitely=True, + interval=100000): + """Create a KinesisDataset. + + Args: + stream: A `tf.string` tensor containing the name of the stream. + shard: A `tf.string` tensor containing the id of the shard. + read_indefinitely: If `True`, the Kinesis dataset will keep retry + again on `EOF` after the `interval` period. If `False`, then + the dataset will stop on `EOF`. The default value is `True`. + interval: The interval for the Kinesis Client to wait before + it tries to get records again (in millisecond). + """ + super(KinesisDataset, self).__init__() + self._stream = ops.convert_to_tensor( + stream, dtype=dtypes.string, name="stream") + self._shard = ops.convert_to_tensor( + shard, dtype=dtypes.string, name="shard") + self._read_indefinitely = ops.convert_to_tensor( + read_indefinitely, dtype=dtypes.bool, name="read_indefinitely") + self._interval = ops.convert_to_tensor( + interval, dtype=dtypes.int64, name="interval") + + def _as_variant_tensor(self): + return gen_dataset_ops.kinesis_dataset( + self._stream, self._shard, self._read_indefinitely, self._interval) + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.scalar() + + @property + def output_types(self): + return dtypes.string diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py b/tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..c9ce9f3646200a777cdbdf34b37626154ca730bb --- /dev/null +++ b/tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py @@ -0,0 +1,24 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python helper for loading kinesis ops and kernels.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader + +_dataset_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("../../_dataset_ops.so")) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py index 3ba1026383ef146adb32197ae41b5c251155bf46..2ede5daee74223e812cc29e9708b1989b698fb4e 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py @@ -652,7 +652,8 @@ def map_fn(fn, labeled_tensor, name=None): tensor_lt = core.LabeledTensor(tensor, original_axes) return fn(tensor_lt).tensor - map_op = functional_ops.map_fn(tf_fn, labeled_tensor.tensor) + map_op = functional_ops.map_fn( + tf_fn, labeled_tensor.tensor, dtype=first_map_lt.dtype) map_lt = core.LabeledTensor(map_op, final_axes) return core.identity(map_lt, name=scope) diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index bc3359693562deb1229a78a2db5c256c76f7fd8d..a7b41b714ffaa062e2eba8caf9b4fa033c7633cd 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -121,6 +121,7 @@ from tensorflow.contrib.layers.python.layers import * from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = ['bias_add', + 'conv1d', 'conv2d', 'conv3d', 'elu', diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py index dd2395f8c9748dadbecfe47df5511874d5f848ea..7ede193029d2d95fa4953b4c417a1e86ebb4a42e 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py @@ -21,7 +21,6 @@ from __future__ import print_function import itertools import math -import sys import numpy as np diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py index 06060b99e7e58787994f20f037ffa451abbc7459..a85cff4f7098e9a5eedca1b0c8c0cb42e172d90a 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py @@ -683,11 +683,12 @@ def parse_feature_columns_from_sequence_examples( the serialized proto. Returns: - A tuple consisting of: - context_features: a dict mapping `FeatureColumns` from - `context_feature_columns` to their parsed `Tensors`/`SparseTensor`s. - sequence_features: a dict mapping `FeatureColumns` from - `sequence_feature_columns` to their parsed `Tensors`/`SparseTensor`s. + A tuple consisting of (context_features, sequence_features) + + * context_features: a dict mapping `FeatureColumns` from + `context_feature_columns` to their parsed `Tensors`/`SparseTensor`s. + * sequence_features: a dict mapping `FeatureColumns` from + `sequence_feature_columns` to their parsed `Tensors`/`SparseTensor`s. """ # Sequence example parsing requires a single (scalar) example. try: diff --git a/tensorflow/contrib/layers/python/layers/initializers.py b/tensorflow/contrib/layers/python/layers/initializers.py index 51610f21b24f1d40f26630cc1e69ca723d130639..1192198ec26c9db749a9bd1ee07f52395fd16a0f 100644 --- a/tensorflow/contrib/layers/python/layers/initializers.py +++ b/tensorflow/contrib/layers/python/layers/initializers.py @@ -47,7 +47,7 @@ def xavier_initializer(uniform=True, seed=None, dtype=dtypes.float32): Args: uniform: Whether to use uniform or normal distributed random initialization. seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} for behavior. + `tf.set_random_seed` for behavior. dtype: The data type. Only floating point types are supported. Returns: @@ -98,7 +98,7 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False, mode: String. 'FAN_IN', 'FAN_OUT', 'FAN_AVG'. uniform: Whether to use uniform or normal distributed random initialization. seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} for behavior. + `tf.set_random_seed` for behavior. dtype: The data type. Only floating point types are supported. Returns: diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index b6d63c9640611abdda65f1205f544ee505dae1f0..04668f112d85b946f313f85e60ee607fe761f63c 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -55,9 +55,9 @@ from tensorflow.python.training import moving_averages # TODO(b/28426988): Replace legacy_* fns migrated from slim. # TODO(b/28426988): Remove legacy_* when all uses have migrated to new API. __all__ = [ - 'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv2d', 'conv3d', - 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution', - 'convolution1d', 'convolution2d', 'convolution2d_in_plane', + 'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv1d', 'conv2d', + 'conv3d', 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', + 'convolution', 'convolution1d', 'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose', 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse', 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', 'gdn', 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d', @@ -1584,7 +1584,7 @@ def dropout(inputs, outputs_collections: Collection to add the outputs. scope: Optional scope for name_scope. seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} for behavior. + `tf.set_random_seed` for behavior. Returns: A tensor representing the output of the operation. @@ -1702,19 +1702,22 @@ def _inner_flatten(inputs, new_rank, output_collections=None, scope=None): return utils.collect_named_outputs(output_collections, sc, flattened) -def _model_variable_getter(getter, - name, - shape=None, - dtype=None, - initializer=None, - regularizer=None, - trainable=True, - collections=None, - caching_device=None, - partitioner=None, - rename=None, - use_resource=None, - **_): +def _model_variable_getter( + getter, + name, + shape=None, + dtype=None, + initializer=None, + regularizer=None, + trainable=True, + collections=None, + caching_device=None, + partitioner=None, + rename=None, + use_resource=None, + synchronization=tf_variables.VariableSynchronization.AUTO, + aggregation=tf_variables.VariableAggregation.NONE, + **_): """Getter that uses model_variable for compatibility with core layers.""" short_name = name.split('/')[-1] if rename and short_name in rename: @@ -1732,7 +1735,9 @@ def _model_variable_getter(getter, caching_device=caching_device, partitioner=partitioner, custom_getter=getter, - use_resource=use_resource) + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation) def _build_variable_getter(rename=None): @@ -2655,7 +2660,7 @@ def separable_convolution2d( inputs, num_outputs, kernel_size, - depth_multiplier, + depth_multiplier=1, stride=1, padding='SAME', data_format=DATA_FORMAT_NHWC, @@ -2664,6 +2669,7 @@ def separable_convolution2d( normalizer_fn=None, normalizer_params=None, weights_initializer=initializers.xavier_initializer(), + pointwise_initializer=None, weights_regularizer=None, biases_initializer=init_ops.zeros_initializer(), biases_regularizer=None, @@ -2705,7 +2711,9 @@ def separable_convolution2d( `biases_regularizer` are ignored and `biases` are not created nor added. default set to None for no normalizer function normalizer_params: Normalization function parameters. - weights_initializer: An initializer for the weights. + weights_initializer: An initializer for the depthwise weights. + pointwise_initializer: An initializer for the pointwise weights. + default set to None, means use weights_initializer. weights_regularizer: Optional regularizer for the weights. biases_initializer: An initializer for the biases. If None skip biases. biases_regularizer: Optional regularizer for the biases. @@ -2737,6 +2745,9 @@ def separable_convolution2d( custom_getter=layer_variable_getter) as sc: inputs = ops.convert_to_tensor(inputs) + if pointwise_initializer is None: + pointwise_initializer = weights_initializer + df = ('channels_first' if data_format and data_format.startswith('NC') else 'channels_last') if num_outputs is not None: @@ -2752,7 +2763,7 @@ def separable_convolution2d( depth_multiplier=depth_multiplier, use_bias=not normalizer_fn and biases_initializer, depthwise_initializer=weights_initializer, - pointwise_initializer=weights_initializer, + pointwise_initializer=pointwise_initializer, bias_initializer=biases_initializer, depthwise_regularizer=weights_regularizer, pointwise_regularizer=weights_regularizer, @@ -3309,6 +3320,7 @@ relu6 = functools.partial(fully_connected, activation_fn=nn.relu6) linear = functools.partial(fully_connected, activation_fn=None) # Simple alias. +conv1d = convolution1d conv2d = convolution2d conv3d = convolution3d conv2d_transpose = convolution2d_transpose diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index c5c7269b1f15849956e90654e3bcf8ab0eebc393..51c7abb105a29ff0dfab49d77bc62d5b51517179 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1189,7 +1189,7 @@ class ConvolutionInPlaneTest(test.TestCase): result = sess.run(horz_gradients) expected = np.zeros((1, 10, 9, 1)) - self.assertAllEqual(result, expected) + self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5) def testHorzConvWithBlankImageAndPlaceholder(self): image = array_ops.placeholder(dtypes.float32, shape=(None, None, None, 1)) @@ -1209,7 +1209,7 @@ class ConvolutionInPlaneTest(test.TestCase): }) expected = np.zeros((1, 10, 9, 1)) - self.assertAllEqual(result, expected) + self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5) def testHorzConvWithRandomImageMultiBatch(self): np.random.seed(1) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py index 0e35b1aa8bf682c1b4f7e8d974d3e8fad69e33cb..dad3da3748097c26e07b4abe0495f62a18aad369 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -514,15 +514,15 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): original_vars = set(tape.watched_variables()) # Backward pass - def grad_fn(*output_grads, **kwargs): + def _grad_fn(output_grads, variables=None): """Recompute outputs for gradient computation.""" - variables = [] + variables = variables or [] if original_vars: - variables = kwargs["variables"] - if set(variables) != original_vars: - raise ValueError(_WRONG_VARS_ERR) - del kwargs - inputs = list(args) + assert variables, ("Fn created variables but the variables were not " + "passed to the gradient fn.") + if set(variables) != original_vars: + raise ValueError(_WRONG_VARS_ERR) + inputs = [array_ops.identity(x) for x in list(args)] # Recompute outputs with framework_ops.control_dependencies(output_grads): if use_data_dep_: @@ -538,7 +538,7 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): if original_vars != recompute_vars: raise ValueError(_WRONG_VARS_ERR) - if not (isinstance(outputs, list) or isinstance(outputs, tuple)): + if not isinstance(outputs, (list, tuple)): outputs = [outputs] outputs = list(outputs) grads = gradients_impl.gradients(outputs, inputs + variables, @@ -554,6 +554,16 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): grad_vars = grads[len(inputs):] return grad_inputs, grad_vars + # custom_gradient inspects the signature of the function to determine + # whether the user expects variables passed in the grad_fn. If the function + # created variables, the grad_fn should accept the "variables" kwarg. + if original_vars: + def grad_fn(*output_grads, **kwargs): + return _grad_fn(output_grads, kwargs["variables"]) + else: + def grad_fn(*output_grads): + return _grad_fn(output_grads) + return outputs, grad_fn return fn_with_recompute(*args) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py index bc09ba8d439808c1582f207a99504012afcf33a6..d5971fb9d8e2fbc1e14fd24fc79e7981a284a418 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -372,6 +372,26 @@ class RecomputeTest(test.TestCase): self.assertEqual(2, len(update_ops)) self.assertEqual([False, True], kwarg_values) + def testWithoutVariables(self): + + def concat_n(layer_list, num_inputs): + return math_ops.reduce_sum( + array_ops.concat([x for x in layer_list[-num_inputs:]], axis=-1), + axis=1, keepdims=True) + + @rev_block_lib.recompute_grad + def concat_n_wrap(*args): + return concat_n(args, 3) + + # DenseNet-style layers + layer_list = [random_ops.random_uniform((4, 8))] + for _ in range(5): + layer_list.append(math_ops.sqrt(concat_n_wrap(*layer_list))) + + grads = gradients_impl.gradients(layer_list[-1], layer_list[0]) + with self.test_session() as sess: + sess.run(grads) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index b56a88659bbd4467600788fc8e3e9dbf38ce8244..418b0cf39205391cd67bbdc5c6483f5dc0cfc381 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -79,16 +79,7 @@ py_library( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python:weights_broadcast_ops", - "//tensorflow/python/estimator", "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/estimator:export_export", - "//tensorflow/python/estimator:export_output", - "//tensorflow/python/estimator:inputs", - "//tensorflow/python/estimator:inputs_queues", - "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/estimator:pandas_io", - "//tensorflow/python/estimator:run_config", "//tensorflow/python/feature_column", "//tensorflow/python/feature_column:feature_column_py", "//tensorflow/python/ops/losses", @@ -117,7 +108,6 @@ py_test( size = "small", srcs = ["python/learn/learn_io/data_feeder_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":learn", "//tensorflow/python:client_testlib", @@ -171,9 +161,8 @@ tf_py_test( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variables", - "//tensorflow/python/estimator", + "//tensorflow/python/estimator:estimator_py", ], - tags = ["no_windows"], # TODO: needs investigation on Windows ) py_test( @@ -220,7 +209,7 @@ py_test( "//tensorflow/contrib/training:training_py", "//tensorflow/python:client_testlib", "//tensorflow/python:platform", - "//tensorflow/python/estimator:run_config", + "//tensorflow/python/estimator:estimator_py", ], ) @@ -245,7 +234,7 @@ py_test( "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:variables", - "//tensorflow/python/estimator", + "//tensorflow/python/estimator:estimator_py", ], ) @@ -259,7 +248,7 @@ py_test( "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python:training", - "//tensorflow/python/estimator:run_config", + "//tensorflow/python/estimator:estimator_py", ], ) @@ -600,7 +589,6 @@ py_test( size = "small", srcs = ["python/learn/learn_io/io_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":learn", "//tensorflow/contrib/learn/python/learn/datasets", @@ -621,7 +609,7 @@ py_test( "//tensorflow/python:control_flow_ops", "//tensorflow/python:session", "//tensorflow/python:training", - "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/saved_model:signature_constants", "@six_archive//:six", ], diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 7a026a15e4aeea0dde4ed9f7de053a757a0abb58..c1de42782efb3497660affb3ef7162457977c150 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -72,6 +72,7 @@ from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary as core_summary from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import device_setter from tensorflow.python.training import monitored_session from tensorflow.python.training import saver @@ -891,7 +892,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, # Check that model has been trained (if nothing has been set explicitly). if not checkpoint_path: - latest_path = saver.latest_checkpoint(self._model_dir) + latest_path = checkpoint_management.latest_checkpoint(self._model_dir) if not latest_path: raise NotFittedError( "Couldn't find trained model at %s." % self._model_dir) @@ -956,7 +957,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, as_iterable=True, iterate_batches=False): # Check that model has been trained. - checkpoint_path = saver.latest_checkpoint(self._model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir) if not checkpoint_path: raise NotFittedError( "Couldn't find trained model at %s." % self._model_dir) @@ -1364,7 +1365,7 @@ class Estimator(BaseEstimator): if not checkpoint_path: # Locate the latest checkpoint - checkpoint_path = saver.latest_checkpoint(self._model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir) if not checkpoint_path: raise NotFittedError( "Couldn't find trained model at %s." % self._model_dir) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 339c4e0e360ed9ef9906f0e51b64a0dc13826259..ded93d4a7fb473c0c5df446ea89c5ab7784e9f3c 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -563,10 +563,10 @@ def _mean_squared_loss(labels, logits, weights=None): labels = ops.convert_to_tensor(labels) # To prevent broadcasting inside "-". if len(labels.get_shape()) == 1: - labels = array_ops.expand_dims(labels, dim=(1,)) + labels = array_ops.expand_dims(labels, axis=(1,)) # TODO(zakaria): make sure it does not recreate the broadcast bug. if len(logits.get_shape()) == 1: - logits = array_ops.expand_dims(logits, dim=(1,)) + logits = array_ops.expand_dims(logits, axis=(1,)) logits.get_shape().assert_is_compatible_with(labels.get_shape()) loss = math_ops.square(logits - math_ops.to_float(labels), name=name) return _compute_weighted_loss(loss, weights) @@ -579,10 +579,10 @@ def _poisson_loss(labels, logits, weights=None): labels = ops.convert_to_tensor(labels) # To prevent broadcasting inside "-". if len(labels.get_shape()) == 1: - labels = array_ops.expand_dims(labels, dim=(1,)) + labels = array_ops.expand_dims(labels, axis=(1,)) # TODO(zakaria): make sure it does not recreate the broadcast bug. if len(logits.get_shape()) == 1: - logits = array_ops.expand_dims(logits, dim=(1,)) + logits = array_ops.expand_dims(logits, axis=(1,)) logits.get_shape().assert_is_compatible_with(labels.get_shape()) loss = nn.log_poisson_loss(labels, logits, compute_full_loss=True, name=name) @@ -797,7 +797,7 @@ def _log_loss_with_two_classes(labels, logits, weights=None): # TODO(ptucker): This will break for dynamic shapes. # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels. if len(labels.get_shape()) == 1: - labels = array_ops.expand_dims(labels, dim=(1,)) + labels = array_ops.expand_dims(labels, axis=(1,)) loss = nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits, name=name) return _compute_weighted_loss(loss, weights) diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py index 66ebcfd1d81904b9afe5be6bd1a648fe325e1e0b..21f7dcc5e427bf00ffbc71150475d94f5336f8aa 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py @@ -15,9 +15,9 @@ """Implementation of k-means clustering on top of `Estimator` API (deprecated). This module is deprecated. Please use -@{tf.contrib.factorization.KMeansClustering} instead of -@{tf.contrib.learn.KMeansClustering}. It has a similar interface, but uses the -@{tf.estimator.Estimator} API instead of @{tf.contrib.learn.Estimator}. +`tf.contrib.factorization.KMeansClustering` instead of +`tf.contrib.learn.KMeansClustering`. It has a similar interface, but uses the +`tf.estimator.Estimator` API instead of `tf.contrib.learn.Estimator`. """ from __future__ import absolute_import diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py index 14ee2ba6094760d52180d6de7763ea88b8ee98c8..08f23aa2231424887f3c935dbb8368a2aa46cc63 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py @@ -221,7 +221,7 @@ class ClusterConfig(object): class RunConfig(ClusterConfig, core_run_config.RunConfig): """This class specifies the configurations for an `Estimator` run. - This class is a deprecated implementation of @{tf.estimator.RunConfig} + This class is a deprecated implementation of `tf.estimator.RunConfig` interface. """ _USE_DEFAULT = 0 @@ -240,6 +240,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): keep_checkpoint_max=5, keep_checkpoint_every_n_hours=10000, log_step_count_steps=100, + protocol=None, evaluation_master='', model_dir=None, session_config=None): @@ -289,6 +290,8 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): session_config: a ConfigProto used to set session parameters, or None. Note - using this argument, it is easy to provide settings which break otherwise perfectly good models. Use with care. + protocol: An optional argument which specifies the protocol used when + starting server. None means default to grpc. """ # Neither parent class calls super().__init__(), so here we have to # manually call their __init__() methods. @@ -299,6 +302,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): # so instead of breaking compatibility with that assumption, we # just manually initialize this field: self._train_distribute = None + self._eval_distribute = None self._device_fn = None gpu_options = config_pb2.GPUOptions( @@ -313,6 +317,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): self._save_summary_steps = save_summary_steps self._save_checkpoints_secs = save_checkpoints_secs self._log_step_count_steps = log_step_count_steps + self._protocol = protocol self._session_config = session_config if save_checkpoints_secs == RunConfig._USE_DEFAULT: if save_checkpoints_steps is None: diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index f8a3709ee57a32734afa7ac8133271c75d152b2c..4e64efdd959eef0951c9ab782996fc2bd5919cc5 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -41,7 +41,7 @@ from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import server_lib from tensorflow.python.util import compat from tensorflow.python.util import function_utils @@ -95,7 +95,7 @@ class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener): # Load and cache the path of the most recent checkpoint to avoid duplicate # searches on GCS. logging.info("Checking for checkpoint in %s", self._model_dir) - latest_path = saver.latest_checkpoint(self._model_dir) + latest_path = checkpoint_management.latest_checkpoint(self._model_dir) if not latest_path: logging.warning("Skipping evaluation and export since model has not been " @@ -162,16 +162,16 @@ class Experiment(object): Args: estimator: Object implementing Estimator interface, which could be a - combination of @{tf.contrib.learn.Trainable} and - @{tf.contrib.learn.Evaluable} (deprecated), or - @{tf.estimator.Estimator}. + combination of `tf.contrib.learn.Trainable` and + `tf.contrib.learn.Evaluable` (deprecated), or + `tf.estimator.Estimator`. train_input_fn: function, returns features and labels for training. eval_input_fn: function, returns features and labels for evaluation. If `eval_steps` is `None`, this should be configured only to produce for a finite number of batches (generally, 1 epoch over the evaluation data). eval_metrics: `dict` of string, metric function. If `None`, default set is used. This should be `None` if the `estimator` is - @{tf.estimator.Estimator}. If metrics are provided they will be + `tf.estimator.Estimator`. If metrics are provided they will be *appended* to the default set. train_steps: Perform this many steps of training. `None`, the default, means train forever. @@ -516,7 +516,8 @@ class Experiment(object): start = time.time() error_msg = None - latest_path = saver.latest_checkpoint(self._estimator.model_dir) + latest_path = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) if not latest_path: error_msg = ("Estimator is not fitted yet. " "Will start an evaluation when a checkpoint is ready.") @@ -778,7 +779,8 @@ class Experiment(object): saving_listeners=self._saving_listeners) logging.info("Evaluating model now.") - latest_checkpoint = saver.latest_checkpoint(self._estimator.model_dir) + latest_checkpoint = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) eval_result = self._call_evaluate( input_fn=self._eval_input_fn, steps=self._eval_steps, diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py index 0d039d593b7850ead34484f88426255dc659b7fc..df156da3f467538ed1c6b640d651fdfd33ce243d 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.summary import summary +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib @@ -124,7 +125,7 @@ class GraphActionsTest(test.TestCase): # TODO(ptucker): Test number and contents of checkpoint files. def _assert_ckpt(self, output_dir, expected=True): - ckpt_state = saver_lib.get_checkpoint_state(output_dir) + ckpt_state = checkpoint_management.get_checkpoint_state(output_dir) if expected: pattern = '%s/model.ckpt-.*' % output_dir primary_ckpt_path = ckpt_state.model_checkpoint_path @@ -434,7 +435,7 @@ class GraphActionsTrainTest(test.TestCase): # TODO(ptucker): Test number and contents of checkpoint files. def _assert_ckpt(self, output_dir, expected=True): - ckpt_state = saver_lib.get_checkpoint_state(output_dir) + ckpt_state = checkpoint_management.get_checkpoint_state(output_dir) if expected: pattern = '%s/model.ckpt-.*' % output_dir primary_ckpt_path = ckpt_state.model_checkpoint_path diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index 77f7c73d5412d40b338eaff4cf04d99fd0892723..3d691d434044aab1e3e86457cee6aadb5bf798c7 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -51,7 +51,7 @@ from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary as core_summary -from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util from tensorflow.python.util import deprecation @@ -735,7 +735,8 @@ class ValidationMonitor(EveryN): return False self._last_checkpoint_check_time = current_time # Check that we are not running evaluation on the same checkpoint. - latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir) + latest_path = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) if latest_path is None: logging.debug("Skipping evaluation since model has not been saved yet " "at step %d.", step) @@ -1059,7 +1060,8 @@ class ExportMonitor(EveryN): def end(self, session=None): super(ExportMonitor, self).end(session=session) - latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir) + latest_path = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) if latest_path is None: logging.info("Skipping export at the end since model has not been saved " "yet.") diff --git a/tensorflow/contrib/learn/python/learn/monitors_test.py b/tensorflow/contrib/learn/python/learn/monitors_test.py index 5c34d0ddb01f3bcdc407e6926e7c5b73be1863b4..ff1da32c218b4e105b5503426ac01410665f9c7e 100644 --- a/tensorflow/contrib/learn/python/learn/monitors_test.py +++ b/tensorflow/contrib/learn/python/learn/monitors_test.py @@ -39,9 +39,9 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import gradient_descent from tensorflow.python.training import monitored_session -from tensorflow.python.training import saver from tensorflow.python.training import training_util @@ -317,7 +317,7 @@ class MonitorsTest(test.TestCase): self._run_monitor(monitor) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_no_ckpt(self, mock_latest_checkpoint, mock_estimator_class): estimator = mock_estimator_class() @@ -336,7 +336,7 @@ class MonitorsTest(test.TestCase): mock_latest_checkpoint.assert_called_with(model_dir) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_no_early_stopping_rounds(self, mock_latest_checkpoint, mock_estimator_class): @@ -356,7 +356,7 @@ class MonitorsTest(test.TestCase): self._assert_validation_monitor(monitor) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_invalid_metric(self, mock_latest_checkpoint, mock_estimator_class): estimator = mock_estimator_class() @@ -375,7 +375,7 @@ class MonitorsTest(test.TestCase): self._run_monitor(monitor, num_epochs=1, num_steps_per_epoch=1) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor(self, mock_latest_checkpoint, mock_estimator_class): estimator = mock_estimator_class() @@ -464,7 +464,7 @@ class MonitorsTest(test.TestCase): monitor.epoch_end(epoch=0) monitor.end() - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_with_core_estimator(self, mock_latest_checkpoint): estimator = test.mock.Mock(spec=core_estimator.Estimator) model_dir = 'model/dir' @@ -495,7 +495,7 @@ class MonitorsTest(test.TestCase): expected_best_metrics={'loss': 42.0, 'auc': 0.5}) monitor.post_step(step=step, session=None) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_fail_with_core_estimator_and_metrics( self, mock_latest_checkpoint): estimator = test.mock.Mock(spec=core_estimator.Estimator) diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index 3eacac7a3d3dcff4d39025fdee88e16e385b1b84..0144b93814a174cfb8c3162f407a595ac637f4f5 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as tf_saver from tensorflow.python.training import training_util @@ -298,7 +299,8 @@ def _export_estimator(estimator, # If checkpoint_path is specified, use the specified checkpoint path. checkpoint_path = (checkpoint_path or - tf_saver.latest_checkpoint(estimator._model_dir)) + checkpoint_management.latest_checkpoint( + estimator._model_dir)) with ops.Graph().as_default() as g: training_util.create_global_step(g) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index f8106d1e4a7e79f1cd651c40995be480721a8129..4f22054af3077fa5322b52f56e815fe76104f602 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -55,7 +55,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.summary import summary_iterator -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.util import compat from tensorflow.python.util.deprecation import deprecated @@ -415,7 +415,7 @@ def make_export_strategy(serving_input_fn, `InputFnOps`. default_output_alternative_key: the name of the head to serve when an incoming serving request does not explicitly request a specific head. - Must be `None` if the estimator inherits from @{tf.estimator.Estimator} + Must be `None` if the estimator inherits from `tf.estimator.Estimator` or for single-headed models. assets_extra: A dict specifying how to populate the assets.extra directory within the exported SavedModel. Each key should give the destination @@ -453,7 +453,7 @@ def make_export_strategy(serving_input_fn, The string path to the exported directory. Raises: - ValueError: If `estimator` is a @{tf.estimator.Estimator} instance + ValueError: If `estimator` is a `tf.estimator.Estimator` instance and `default_output_alternative_key` was specified. """ if isinstance(estimator, core_estimator.Estimator): @@ -504,7 +504,7 @@ def make_parsing_export_strategy(feature_columns, that must be provided at serving time (excluding labels!). default_output_alternative_key: the name of the head to serve when an incoming serving request does not explicitly request a specific head. - Must be `None` if the estimator inherits from @{tf.estimator.Estimator} + Must be `None` if the estimator inherits from `tf.estimator.Estimator` or for single-headed models. assets_extra: A dict specifying how to populate the assets.extra directory within the exported SavedModel. Each key should give the destination @@ -714,7 +714,8 @@ def make_best_model_export_strategy( # as soon as contrib is cleaned up and we can thus be sure that # estimator is a tf.estimator.Estimator and not a # tf.contrib.learn.Estimator - checkpoint_path = saver.latest_checkpoint(estimator.model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint( + estimator.model_dir) export_checkpoint_path, export_eval_result = best_model_selector.update( checkpoint_path, eval_result) @@ -766,7 +767,7 @@ def extend_export_strategy(base_export_strategy, The string path to the SavedModel indicated by post_export_fn. Raises: - ValueError: If `estimator` is a @{tf.estimator.Estimator} instance + ValueError: If `estimator` is a `tf.estimator.Estimator` instance and `default_output_alternative_key` was specified or if post_export_fn does not return a valid directory. RuntimeError: If unable to create temporary or final export directory. diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py index 5e7b422e3cc368a22eb94ed470297ae78293c4eb..e74244720896a835174f54bb97049c1d9b1c92f8 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py +++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py @@ -625,11 +625,13 @@ def attention_decoder(decoder_inputs, v = [] attention_vec_size = attn_size # Size of query vectors for attention. for a in xrange(num_heads): - k = variable_scope.get_variable("AttnW_%d" % a, - [1, 1, attn_size, attention_vec_size]) + k = variable_scope.get_variable( + "AttnW_%d" % a, [1, 1, attn_size, attention_vec_size], + dtype=dtype) hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME")) v.append( - variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size])) + variable_scope.get_variable( + "AttnV_%d" % a, [attention_vec_size], dtype=dtype)) state = initial_state @@ -647,11 +649,13 @@ def attention_decoder(decoder_inputs, with variable_scope.variable_scope("Attention_%d" % a): y = Linear(query, attention_vec_size, True)(query) y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size]) + y = math_ops.cast(y, dtype) # Attention mask is a softmax of v^T * tanh(...). s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y), [2, 3]) - a = nn_ops.softmax(s) + a = nn_ops.softmax(math_ops.cast(s, dtype=dtypes.float32)) # Now calculate the attention-weighted vector d. + a = math_ops.cast(a, dtype) d = math_ops.reduce_sum( array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2]) ds.append(array_ops.reshape(d, [-1, attn_size])) @@ -681,6 +685,7 @@ def attention_decoder(decoder_inputs, raise ValueError("Could not infer input size from input: %s" % inp.name) inputs = [inp] + attns + inputs = [math_ops.cast(e, dtype) for e in inputs] x = Linear(inputs, input_size, True)(inputs) # Run the RNN. cell_output, state = cell(x, state) @@ -693,6 +698,7 @@ def attention_decoder(decoder_inputs, attns = attention(state) with variable_scope.variable_scope("AttnOutputProjection"): + cell_output = math_ops.cast(cell_output, dtype) inputs = [cell_output] + attns output = Linear(inputs, output_size, True)(inputs) if loop_function is not None: diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD index 5b89c6cef9fa9fdef7c26ddee1efa03f3056d881..7534b50a4ae0076fb27fb9cd0d1dd58b29192876 100644 --- a/tensorflow/contrib/linear_optimizer/BUILD +++ b/tensorflow/contrib/linear_optimizer/BUILD @@ -41,6 +41,10 @@ py_test( size = "medium", srcs = ["python/kernel_tests/sdca_ops_test.py"], srcs_version = "PY2AND3", + tags = [ + "no_gpu", + "no_pip_gpu", + ], deps = [ ":sdca_ops_py", ":sparse_feature_column_py", diff --git a/tensorflow/contrib/lite/Android.bp b/tensorflow/contrib/lite/Android.bp index ff3d18b4b117863ab483b4ff9c5dac71fb5379ee..a960d853723fde8f1696dfc85dccd894e4f00b69 100644 --- a/tensorflow/contrib/lite/Android.bp +++ b/tensorflow/contrib/lite/Android.bp @@ -23,6 +23,9 @@ cc_defaults { "-Wextra", "-Wno-unused-parameter", ], + shared_libs: [ + "liblog", + ], } cc_library_static { @@ -31,6 +34,7 @@ cc_library_static { srcs: ["context.c"], cflags: [ "-Wno-typedef-redefinition", + "-Wno-visibility", ], } @@ -44,6 +48,7 @@ cc_library_static { "error_reporter.cc", "graph_info.cc", "interpreter.cc", + "mmap_allocation.cc", "model.cc", "op_resolver.cc", "nnapi_delegate.cc", @@ -64,6 +69,7 @@ cc_library_static { "-Wno-mismatched-tags", "-Wno-sign-compare", "-Wno-unused-lambda-capture", + "-Wno-invalid-partial-specialization", ], } diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 9c804d27854b8004d34c65691b48ca2b0d3bbf7c..1e6f1e7da212c3aeb1563dc2f4b6dff2cb550736 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -47,6 +47,10 @@ cc_test( name = "arena_planner_test", size = "small", srcs = ["arena_planner_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable", + ], deps = [ ":arena_planner", "//tensorflow/contrib/lite/testing:util", @@ -121,13 +125,26 @@ cc_library( "graph_info.cc", "interpreter.cc", "model.cc", - "nnapi_delegate.cc", "op_resolver.cc", "optional_debug_tools.cc", - ], + ] + select({ + "//tensorflow:android": [ + "nnapi_delegate.cc", + "mmap_allocation.cc", + ], + "//tensorflow:windows": [ + "nnapi_delegate_disabled.cc", + "mmap_allocation_disabled.cc", + ], + "//conditions:default": [ + "nnapi_delegate_disabled.cc", + "mmap_allocation.cc", + ], + }), hdrs = [ "allocation.h", "context.h", + "context_util.h", "error_reporter.h", "graph_info.h", "interpreter.h", @@ -145,6 +162,7 @@ cc_library( ":memory_planner", ":schema_fbs_version", ":simple_memory_arena", + ":string", ":util", "//tensorflow/contrib/lite/kernels:eigen_support", "//tensorflow/contrib/lite/kernels:gemm_support", @@ -184,6 +202,7 @@ cc_test( deps = [ ":framework", ":string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", "//tensorflow/contrib/lite/kernels:kernel_util", "//tensorflow/contrib/lite/kernels/internal:tensor_utils", "//tensorflow/contrib/lite/schema:schema_fbs", @@ -197,6 +216,7 @@ cc_test( name = "graph_info_test", size = "small", srcs = ["graph_info_test.cc"], + tags = ["no_oss"], deps = [ ":framework", ":string_util", @@ -241,6 +261,7 @@ cc_test( name = "op_resolver_test", size = "small", srcs = ["op_resolver_test.cc"], + tags = ["no_oss"], deps = [ ":framework", "//tensorflow/contrib/lite/testing:util", @@ -273,6 +294,7 @@ cc_test( name = "util_test", size = "small", srcs = ["util_test.cc"], + tags = ["no_oss"], deps = [ ":context", ":util", diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc index a4772731ecda92431c412672610a39c188dabf27..89462618148a2afbcf2ef6b1dd2985bcd0178734 100644 --- a/tensorflow/contrib/lite/allocation.cc +++ b/tensorflow/contrib/lite/allocation.cc @@ -13,56 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include +#include "tensorflow/contrib/lite/allocation.h" + #include #include -#include #include #include #include #include #include -#include "tensorflow/contrib/lite/allocation.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" -#include "tensorflow/contrib/lite/nnapi_delegate.h" namespace tflite { -MMAPAllocation::MMAPAllocation(const char* filename, - ErrorReporter* error_reporter) - : Allocation(error_reporter), mmapped_buffer_(MAP_FAILED) { - mmap_fd_ = open(filename, O_RDONLY); - if (mmap_fd_ == -1) { - error_reporter_->Report("Could not open '%s'.", filename); - return; - } - struct stat sb; - fstat(mmap_fd_, &sb); - buffer_size_bytes_ = sb.st_size; - mmapped_buffer_ = - mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0); - if (mmapped_buffer_ == MAP_FAILED) { - error_reporter_->Report("Mmap of '%s' failed.", filename); - return; - } -} - -MMAPAllocation::~MMAPAllocation() { - if (valid()) { - munmap(const_cast(mmapped_buffer_), buffer_size_bytes_); - } - if (mmap_fd_ != -1) close(mmap_fd_); -} - -const void* MMAPAllocation::base() const { return mmapped_buffer_; } - -size_t MMAPAllocation::bytes() const { return buffer_size_bytes_; } - -bool MMAPAllocation::valid() const { return mmapped_buffer_ != MAP_FAILED; } - +#ifndef TFLITE_MCU FileCopyAllocation::FileCopyAllocation(const char* filename, ErrorReporter* error_reporter) : Allocation(error_reporter) { @@ -94,7 +60,9 @@ FileCopyAllocation::FileCopyAllocation(const char* filename, filename); return; } - copied_buffer_ = std::move(buffer); + // Versions of GCC before 6.2.0 don't support std::move from non-const + // char[] to const char[] unique_ptrs. + copied_buffer_.reset(const_cast(buffer.release())); } FileCopyAllocation::~FileCopyAllocation() {} @@ -104,6 +72,7 @@ const void* FileCopyAllocation::base() const { return copied_buffer_.get(); } size_t FileCopyAllocation::bytes() const { return buffer_size_bytes_; } bool FileCopyAllocation::valid() const { return copied_buffer_ != nullptr; } +#endif MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes, ErrorReporter* error_reporter) diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h index 68aee2e64473320c461ec8b3f194904e7b8da43c..6ffdb4145ab36bbf88973ea6d49ec1f767409a3b 100644 --- a/tensorflow/contrib/lite/allocation.h +++ b/tensorflow/contrib/lite/allocation.h @@ -26,6 +26,8 @@ limitations under the License. namespace tflite { +using std::string; + // A memory allocation handle. This could be a mmap or shared memory. class Allocation { public: @@ -51,6 +53,8 @@ class MMAPAllocation : public Allocation { size_t bytes() const override; bool valid() const override; + static bool IsSupported(); + protected: // Data required for mmap. int mmap_fd_ = -1; // mmap file descriptor diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc index 4f836d367747e06de682b5764206d33f6e2fb983..02442575b3aeed04ac6569440dd52a4d5ddd4d98 100644 --- a/tensorflow/contrib/lite/arena_planner.cc +++ b/tensorflow/contrib/lite/arena_planner.cc @@ -17,29 +17,26 @@ limitations under the License. namespace tflite { -namespace { - -// Memory allocation tuning -constexpr const int kDefaultArenaAlignment = 64; -constexpr const int kDefaultTensorAlignment = 4; - -} // namespace - struct AllocationInfo { // The node index requesting this allocation. int node; // The tensor index to be allocated or deallocated. int tensor; // Whether to allocate or deallocate - enum { ALLOC, DEALLOC } type; + enum Type { ALLOC, DEALLOC } type; }; ArenaPlanner::ArenaPlanner(TfLiteContext* context, - std::unique_ptr graph_info) + std::unique_ptr graph_info, + bool preserve_inputs, bool preserve_intermediates, + int tensor_alignment) : context_(context), graph_info_(std::move(graph_info)), arena_(kDefaultArenaAlignment), - persistent_arena_(kDefaultArenaAlignment) {} + persistent_arena_(kDefaultArenaAlignment), + preserve_inputs_(preserve_inputs), + preserve_intermediates_(preserve_intermediates), + tensor_alignment_(tensor_alignment) {} ArenaPlanner::~ArenaPlanner() {} @@ -67,6 +64,33 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { // Keeps track of references to each tensor. std::vector refcounts(graph_info_->num_tensors(), 0); + // `allocated` and `deallocated` are technically list of boolean values. + // We're saving the compiled binary size by using `vector`. + std::vector allocated(graph_info_->num_tensors(), false); + std::vector deallocated(graph_info_->num_tensors(), false); + + auto allocate = [this, &allocated, &deallocated](int node, + int tensor) -> TfLiteStatus { + if (allocated[tensor]) { + return kTfLiteOk; + } + TF_LITE_ENSURE(context_, !deallocated[tensor]); + alloc_queue_.push_back({node, tensor, AllocationInfo::ALLOC}); + allocated[tensor] = true; + return kTfLiteOk; + }; + + auto deallocate = [this, &allocated, &deallocated]( + int node, int tensor) -> TfLiteStatus { + if (!allocated[tensor]) { + // Do not enqueue a DEALLOC if the tensor is never allocated. + // This happened with the constant tensors. + return kTfLiteOk; + } + TF_LITE_ENSURE(context_, !deallocated[tensor]); + alloc_queue_.push_back({node, tensor, AllocationInfo::DEALLOC}); + return kTfLiteOk; + }; // There will be an entry in alloc_queue_ for the allocation of each tensor // and another for their deallocation. @@ -79,6 +103,32 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { refcounts[tensor_index]++; } + // Variable tensors should are also never overwritten and need to be alive all + // the time. + for (int tensor_index : graph_info_->variables()) { + refcounts[tensor_index]++; + } + + // Queue all graph inputs for allocation. If preserve_inputs_ is true, make + // sure they never be overwritten. + for (int tensor_index : graph_info_->inputs()) { + if (tensor_index != kOptionalTensor) { + if (preserve_inputs_) { + refcounts[tensor_index]++; + } + TF_LITE_ENSURE_STATUS(allocate(0, tensor_index)); + } + } + + // Queue all graph variable tensors for allocation. + for (int tensor_index : graph_info_->variables()) { + if (tensor_index != kOptionalTensor) { + // Increase the reference count for input tensors by one, so it will + // never be deallocated. + TF_LITE_ENSURE_STATUS(allocate(0, tensor_index)); + } + } + // Count references to node input tensors. for (int i = 0; i < graph_info_->num_nodes(); ++i) { const TfLiteNode& node = graph_info_->node(i); @@ -94,10 +144,9 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { // Queue all graph inputs for allocation. for (int tensor_index : graph_info_->inputs()) { if (tensor_index != kOptionalTensor) { - alloc_queue_.push_back({0, tensor_index, AllocationInfo::ALLOC}); + TF_LITE_ENSURE_STATUS(allocate(0, tensor_index)); } } - // Go through the graph in execution order. for (int i = 0; i < graph_info_->num_nodes(); ++i) { const TfLiteNode& node = graph_info_->node(i); @@ -106,18 +155,20 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { TfLiteIntArray* node_outputs = node.outputs; for (int j = 0; j < node_outputs->size; ++j) { int tensor_index = node_outputs->data[j]; - alloc_queue_.push_back({i, tensor_index, AllocationInfo::ALLOC}); + TF_LITE_ENSURE_STATUS(allocate(i, tensor_index)); } // Then update the ref-counts of the node's inputs, and if necessary queue // them for deallocation. - TfLiteIntArray* node_inputs = node.inputs; - for (int j = 0; j < node_inputs->size; ++j) { - int tensor_index = node_inputs->data[j]; - if (tensor_index != kOptionalTensor) { - refcounts[tensor_index]--; - if (refcounts[tensor_index] == 0) { - alloc_queue_.push_back({i, tensor_index, AllocationInfo::DEALLOC}); + if (!preserve_intermediates_) { + TfLiteIntArray* node_inputs = node.inputs; + for (int j = 0; j < node_inputs->size; ++j) { + int tensor_index = node_inputs->data[j]; + if (tensor_index != kOptionalTensor) { + refcounts[tensor_index]--; + if (refcounts[tensor_index] == 0) { + TF_LITE_ENSURE_STATUS(deallocate(i, tensor_index)); + } } } } @@ -208,14 +259,12 @@ TfLiteStatus ArenaPlanner::ResolveTensorAllocation(int tensor_index) { TfLiteStatus ArenaPlanner::CalculateTensorAllocation(int tensor_index) { TfLiteTensor& tensor = *graph_info_->tensor(tensor_index); if (tensor.allocation_type == kTfLiteArenaRw) { - TF_LITE_ENSURE_STATUS(arena_.Allocate(context_, kDefaultTensorAlignment, - tensor.bytes, - &allocs_[tensor_index])); + TF_LITE_ENSURE_STATUS(arena_.Allocate( + context_, tensor_alignment_, tensor.bytes, &allocs_[tensor_index])); } if (tensor.allocation_type == kTfLiteArenaRwPersistent) { - TF_LITE_ENSURE_STATUS( - persistent_arena_.Allocate(context_, kDefaultTensorAlignment, - tensor.bytes, &allocs_[tensor_index])); + TF_LITE_ENSURE_STATUS(persistent_arena_.Allocate( + context_, tensor_alignment_, tensor.bytes, &allocs_[tensor_index])); } return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h index e9d0fbc5a9b5aec06e28da8757466b25f40da2f5..55003cf4e92d9ca79416c0f9f7a0c57e828af4ee 100644 --- a/tensorflow/contrib/lite/arena_planner.h +++ b/tensorflow/contrib/lite/arena_planner.h @@ -25,6 +25,10 @@ limitations under the License. namespace tflite { +// Memory allocation tuning +constexpr const int kDefaultArenaAlignment = 64; +constexpr const int kDefaultTensorAlignment = 64; + struct AllocationInfo; // A memory planner that makes all the allocations using arenas. @@ -43,8 +47,12 @@ struct AllocationInfo; class ArenaPlanner : public MemoryPlanner { public: // Ownership of 'context' is not taken and it must remain util the - // ArenaPlanner is destroyed. - ArenaPlanner(TfLiteContext* context, std::unique_ptr graph_info); + // ArenaPlanner is destroyed. If 'preserve_inputs' is true the inputs to the + // graph will not share memory with any other tensor, effectively preserving + // them until the end of inference. + ArenaPlanner(TfLiteContext* context, std::unique_ptr graph_info, + bool preserve_inputs, bool preserve_intermediates, + int tensor_alignment = kDefaultTensorAlignment); ~ArenaPlanner() override; ArenaPlanner(const ArenaPlanner&) = delete; ArenaPlanner& operator=(const ArenaPlanner&) = delete; @@ -100,6 +108,18 @@ class ArenaPlanner : public MemoryPlanner { // Raw memory buffer that is allocated for persistent tensors that are // declared as kTfLiteArenaRwPersistent. SimpleMemoryArena persistent_arena_; + + // Ensure that the memory self-allocated for inputs is never reused by the + // allocator. This allows for example, multiple runs without getting + // unpredictable results. + bool preserve_inputs_; + + // If true, then no overlapping of memory areas is done, meaning intermediates + // results can be queried after running (modulo running delegates). + bool preserve_intermediates_; + + // Number of bytes that tensor buffers should be aligned to. + int tensor_alignment_; }; } // namespace tflite diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc index 16171df10a7b18b22919c6e54fe3d1e8e0120f69..7d7c41289cad95b73423a7218bf1e0516b2e87a2 100644 --- a/tensorflow/contrib/lite/arena_planner_test.cc +++ b/tensorflow/contrib/lite/arena_planner_test.cc @@ -24,6 +24,8 @@ limitations under the License. namespace tflite { namespace { +constexpr const int kTensorAlignment = 4; + // A simple op to be used in tests, as syntactic sugar. class TestOp { public: @@ -100,12 +102,18 @@ class TestGraph { std::vector* tensors() { return &tensors_; } const std::vector& inputs() { return inputs_; } const std::vector& outputs() { return outputs_; } + const std::vector& variables() { return variables_; } + + void SetVariables(const std::vector& variables) { + variables_ = variables; + } private: std::vector nodes_; std::vector tensors_; std::vector inputs_; std::vector outputs_; + std::vector variables_; }; // The GraphInfo for a TestGraph. @@ -123,6 +131,9 @@ class TestGraphInfo : public GraphInfo { } const std::vector& inputs() const override { return graph_->inputs(); } const std::vector& outputs() const override { return graph_->outputs(); } + const std::vector& variables() const override { + return graph_->variables(); + } private: TestGraph* graph_; @@ -142,11 +153,12 @@ void ReportError(TfLiteContext* context, const char* format, ...) { class ArenaPlannerTest : public ::testing::Test { protected: - void SetGraph(TestGraph* graph) { + void SetGraph(TestGraph* graph, bool preserve_inputs = false) { graph_ = graph; context_.ReportError = ReportError; planner_.reset(new ArenaPlanner( - &context_, std::unique_ptr(new TestGraphInfo(graph)))); + &context_, std::unique_ptr(new TestGraphInfo(graph)), + preserve_inputs, /*preserve intermediates*/ false, kTensorAlignment)); CHECK(planner_->ResetAllocations() == kTfLiteOk); CHECK(planner_->PlanAllocations() == kTfLiteOk); } @@ -168,8 +180,8 @@ class ArenaPlannerTest : public ::testing::Test { const TfLiteTensor& tensor = (*graph_->tensors())[tensor_index]; int64_t offset = GetOffset(tensor_index) + tensor.bytes; // We must make sure the offset is aligned to kDefaultArenaAlignment. - if (offset % 4 != 0) { - offset += 4 - offset % 4; + if (offset % kTensorAlignment != 0) { + offset += kTensorAlignment - offset % kTensorAlignment; } return offset; }; @@ -234,6 +246,30 @@ TEST_F(ArenaPlannerTest, SimpleGraph) { EXPECT_EQ(GetOffset(3), 0); } +TEST_F(ArenaPlannerTest, SimpleGraphInputsPreserved) { + TestGraph graph({0, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4, 5}, {}}, // Second op + {{4, 5}, {3}, {}} // Third op + }, + {3}); + SetGraph(&graph, /*preserve_inputs=*/true); + Execute(0, 10); + + // Alloc(+) and dealloc(-) order: +0 +1 +2 +4 +5 -2 +3 -4 -5 + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(4)); + // Because we are keeping the inputs alive until the end (due to + // preserve_inputs=true), the output tensor will not be able to use that + // space. It will end up using the same are as tensor #2. + EXPECT_EQ(GetOffset(3), GetOffsetAfter(1)); +} + TEST_F(ArenaPlannerTest, SimpleGraphWithTemporary) { TestGraph graph({0, 1}, { @@ -306,13 +342,15 @@ TEST_F(ArenaPlannerTest, SimpleGraphWithPersistentTensor) { { /* in, out, tmp */ {{0, 1}, {2}, {}}, // First op - {{2, 0}, {4}, {5}}, // Second op, with temporary + {{2, 0}, {4}, {5}}, // Second op, with persistent {{4, -1}, {3}, {}} // Third op, with optional }, {3}); // Make #1 persistent so it goes into its own arena. (*graph.tensors())[1].allocation_type = kTfLiteArenaRwPersistent; + // The only use case for kTfLiteArenaRwPersistent is variable tensor now. + graph.SetVariables({1}); SetGraph(&graph); Execute(0, 10); diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 974e6c5d98e5691a3733495e915d919c4bf57d3a..ab694d768f9dc7affea72074d4605321296d758d 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -1,193 +1,218 @@ """Generate Flatbuffer binary from json.""" + load( "//tensorflow:tensorflow.bzl", + "tf_cc_shared_object", "tf_cc_test", ) def tflite_copts(): - """Defines compile time flags.""" - copts = [ - "-DFARMHASH_NO_CXX_STRING", - ] + select({ - str(Label("//tensorflow:android_arm64")): [ - "-std=c++11", - "-O3", - ], - str(Label("//tensorflow:android_arm")): [ - "-mfpu=neon", - "-mfloat-abi=softfp", - "-std=c++11", - "-O3", - ], - str(Label("//tensorflow:android_x86")): [ - "-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK", - ], - str(Label("//tensorflow:ios_x86_64")): [ - "-msse4.1", - ], - "//conditions:default": [], - }) + select({ - str(Label("//tensorflow:with_default_optimizations")): [], - "//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"], - }) + """Defines compile time flags.""" + copts = [ + "-DFARMHASH_NO_CXX_STRING", + ] + select({ + str(Label("//tensorflow:android_arm64")): [ + "-std=c++11", + "-O3", + ], + str(Label("//tensorflow:android_arm")): [ + "-mfpu=neon", + "-mfloat-abi=softfp", + "-std=c++11", + "-O3", + ], + str(Label("//tensorflow:android_x86")): [ + "-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK", + ], + str(Label("//tensorflow:ios_x86_64")): [ + "-msse4.1", + ], + str(Label("//tensorflow:windows")): [ + "/DTF_COMPILE_LIBRARY", + ], + "//conditions:default": [], + }) + select({ + str(Label("//tensorflow:with_default_optimizations")): [], + "//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"], + }) - return copts + return copts LINKER_SCRIPT = "//tensorflow/contrib/lite/java/src/main/native:version_script.lds" def tflite_linkopts_unstripped(): - """Defines linker flags to reduce size of TFLite binary. + """Defines linker flags to reduce size of TFLite binary. - These are useful when trying to investigate the relative size of the - symbols in TFLite. + These are useful when trying to investigate the relative size of the + symbols in TFLite. - Returns: - a select object with proper linkopts - """ - return select({ - "//tensorflow:android": [ - "-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj. - "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export. - "-Wl,--gc-sections", # Eliminate unused code and data. - "-Wl,--as-needed", # Don't link unused libs. - ], - "//tensorflow/contrib/lite:mips": [], - "//tensorflow/contrib/lite:mips64": [], - "//conditions:default": [ - "-Wl,--icf=all", # Identical code folding. - ], - }) + Returns: + a select object with proper linkopts + """ + return select({ + "//tensorflow:android": [ + "-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj. + "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export. + "-Wl,--gc-sections", # Eliminate unused code and data. + "-Wl,--as-needed", # Don't link unused libs. + ], + "//tensorflow:darwin": [], + "//tensorflow/contrib/lite:mips": [], + "//tensorflow/contrib/lite:mips64": [], + "//conditions:default": [ + "-Wl,--icf=all", # Identical code folding. + ], + }) def tflite_jni_linkopts_unstripped(): - """Defines linker flags to reduce size of TFLite binary with JNI. + """Defines linker flags to reduce size of TFLite binary with JNI. - These are useful when trying to investigate the relative size of the - symbols in TFLite. + These are useful when trying to investigate the relative size of the + symbols in TFLite. - Returns: - a select object with proper linkopts - """ - return select({ - "//tensorflow:android": [ - "-Wl,--gc-sections", # Eliminate unused code and data. - "-Wl,--as-needed", # Don't link unused libs. - ], - "//tensorflow/contrib/lite:mips": [], - "//tensorflow/contrib/lite:mips64": [], - "//conditions:default": [ - "-Wl,--icf=all", # Identical code folding. - ], - }) + Returns: + a select object with proper linkopts + """ + return select({ + "//tensorflow:android": [ + "-Wl,--gc-sections", # Eliminate unused code and data. + "-Wl,--as-needed", # Don't link unused libs. + ], + "//tensorflow:darwin": [], + "//tensorflow/contrib/lite:mips": [], + "//tensorflow/contrib/lite:mips64": [], + "//conditions:default": [ + "-Wl,--icf=all", # Identical code folding. + ], + }) def tflite_linkopts(): - """Defines linker flags to reduce size of TFLite binary.""" - return tflite_linkopts_unstripped() + select({ - "//tensorflow:android": [ - "-s", # Omit symbol table. - ], - "//conditions:default": [], - }) + """Defines linker flags to reduce size of TFLite binary.""" + return tflite_linkopts_unstripped() + select({ + "//tensorflow:android": [ + "-s", # Omit symbol table. + ], + "//conditions:default": [], + }) def tflite_jni_linkopts(): - """Defines linker flags to reduce size of TFLite binary with JNI.""" - return tflite_jni_linkopts_unstripped() + select({ - "//tensorflow:android": [ - "-s", # Omit symbol table. - "-latomic", # Required for some uses of ISO C++11 in x86. - ], - "//conditions:default": [], - }) + """Defines linker flags to reduce size of TFLite binary with JNI.""" + return tflite_jni_linkopts_unstripped() + select({ + "//tensorflow:android": [ + "-s", # Omit symbol table. + "-latomic", # Required for some uses of ISO C++11 in x86. + ], + "//conditions:default": [], + }) + +def tflite_jni_binary( + name, + copts = tflite_copts(), + linkopts = tflite_jni_linkopts(), + linkscript = LINKER_SCRIPT, + linkshared = 1, + linkstatic = 1, + deps = []): + """Builds a jni binary for TFLite.""" + linkopts = linkopts + [ + "-Wl,--version-script", # Export only jni functions & classes. + "$(location {})".format(linkscript), + ] + native.cc_binary( + name = name, + copts = copts, + linkshared = linkshared, + linkstatic = linkstatic, + deps = deps + [linkscript], + linkopts = linkopts, + ) -def tflite_jni_binary(name, - copts=tflite_copts(), - linkopts=tflite_jni_linkopts(), - linkscript=LINKER_SCRIPT, - linkshared=1, - linkstatic=1, - deps=[]): - """Builds a jni binary for TFLite.""" - linkopts = linkopts + [ - "-Wl,--version-script", # Export only jni functions & classes. - "$(location {})".format(linkscript), - ] - native.cc_binary( - name=name, - copts=copts, - linkshared=linkshared, - linkstatic=linkstatic, - deps= deps + [linkscript], - linkopts=linkopts) +def tflite_cc_shared_object( + name, + copts = tflite_copts(), + linkopts = [], + linkstatic = 1, + deps = []): + """Builds a shared object for TFLite.""" + tf_cc_shared_object( + name = name, + copts = copts, + linkstatic = linkstatic, + linkopts = linkopts + tflite_jni_linkopts(), + framework_so = [], + deps = deps, + ) def tf_to_tflite(name, src, options, out): - """Convert a frozen tensorflow graphdef to TF Lite's flatbuffer. + """Convert a frozen tensorflow graphdef to TF Lite's flatbuffer. - Args: - name: Name of rule. - src: name of the input graphdef file. - options: options passed to TOCO. - out: name of the output flatbuffer file. - """ + Args: + name: Name of rule. + src: name of the input graphdef file. + options: options passed to TOCO. + out: name of the output flatbuffer file. + """ - toco_cmdline = " ".join([ - "//tensorflow/contrib/lite/toco:toco", - "--input_format=TENSORFLOW_GRAPHDEF", - "--output_format=TFLITE", - ("--input_file=$(location %s)" % src), - ("--output_file=$(location %s)" % out), - ] + options ) - native.genrule( - name = name, - srcs=[src], - outs=[out], - cmd = toco_cmdline, - tools= ["//tensorflow/contrib/lite/toco:toco"], - ) + toco_cmdline = " ".join([ + "//tensorflow/contrib/lite/toco:toco", + "--input_format=TENSORFLOW_GRAPHDEF", + "--output_format=TFLITE", + ("--input_file=$(location %s)" % src), + ("--output_file=$(location %s)" % out), + ] + options) + native.genrule( + name = name, + srcs = [src], + outs = [out], + cmd = toco_cmdline, + tools = ["//tensorflow/contrib/lite/toco:toco"], + ) def tflite_to_json(name, src, out): - """Convert a TF Lite flatbuffer to JSON. + """Convert a TF Lite flatbuffer to JSON. - Args: - name: Name of rule. - src: name of the input flatbuffer file. - out: name of the output JSON file. - """ + Args: + name: Name of rule. + src: name of the input flatbuffer file. + out: name of the output JSON file. + """ - flatc = "@flatbuffers//:flatc" - schema = "//tensorflow/contrib/lite/schema:schema.fbs" - native.genrule( - name = name, - srcs = [schema, src], - outs = [out], - cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.bin &&" + - "$(location %s) --raw-binary --strict-json -t" + - " -o /tmp $(location %s) -- $${TMP}.bin &&" + - "cp $${TMP}.json $(location %s)") - % (src, flatc, schema, out), - tools = [flatc], - ) + flatc = "@flatbuffers//:flatc" + schema = "//tensorflow/contrib/lite/schema:schema.fbs" + native.genrule( + name = name, + srcs = [schema, src], + outs = [out], + cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.bin &&" + + "$(location %s) --raw-binary --strict-json -t" + + " -o /tmp $(location %s) -- $${TMP}.bin &&" + + "cp $${TMP}.json $(location %s)") % + (src, flatc, schema, out), + tools = [flatc], + ) def json_to_tflite(name, src, out): - """Convert a JSON file to TF Lite's flatbuffer. + """Convert a JSON file to TF Lite's flatbuffer. - Args: - name: Name of rule. - src: name of the input JSON file. - out: name of the output flatbuffer file. - """ + Args: + name: Name of rule. + src: name of the input JSON file. + out: name of the output flatbuffer file. + """ - flatc = "@flatbuffers//:flatc" - schema = "//tensorflow/contrib/lite/schema:schema_fbs" - native.genrule( - name = name, - srcs = [schema, src], - outs = [out], - cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.json &&" + - "$(location %s) --raw-binary --unknown-json --allow-non-utf8 -b" + - " -o /tmp $(location %s) $${TMP}.json &&" + - "cp $${TMP}.bin $(location %s)") - % (src, flatc, schema, out), - tools = [flatc], - ) + flatc = "@flatbuffers//:flatc" + schema = "//tensorflow/contrib/lite/schema:schema_fbs" + native.genrule( + name = name, + srcs = [schema, src], + outs = [out], + cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.json &&" + + "$(location %s) --raw-binary --unknown-json --allow-non-utf8 -b" + + " -o /tmp $(location %s) $${TMP}.json &&" + + "cp $${TMP}.bin $(location %s)") % + (src, flatc, schema, out), + tools = [flatc], + ) # This is the master list of generated examples that will be made into tests. A # function called make_XXX_tests() must also appear in generate_examples.py. @@ -195,13 +220,14 @@ def json_to_tflite(name, src, out): def generated_test_models(): return [ "add", - "arg_max", + "arg_min_max", "avg_pool", "batch_to_space_nd", "concat", "constant", "control_dep", "conv", + "conv_with_shared_weights", "depthwiseconv", "div", "equal", @@ -214,6 +240,7 @@ def generated_test_models(): "global_batch_norm", "greater", "greater_equal", + "sum", "l2norm", "l2_pool", "less", @@ -221,8 +248,10 @@ def generated_test_models(): "local_response_norm", "log_softmax", "log", - # TODO(b/110143200): Enable after resolving issues with LSTM conversion. - # "lstm", + "logical_and", + "logical_or", + "logical_xor", + "lstm", "max_pool", "maximum", "mean", @@ -230,14 +259,21 @@ def generated_test_models(): "mul", "neg", "not_equal", + "one_hot", + "pack", "pad", "padv2", - # "prelu", + "prelu", + "pow", + "reduce_max", + "reduce_prod", "relu", "relu1", "relu6", "reshape", "resize_bilinear", + "rsqrt", + "shape", "sigmoid", "sin", "slice", @@ -246,6 +282,7 @@ def generated_test_models(): "space_to_depth", "sparse_to_dense", "split", + "sqrt", "squeeze", "strided_slice", "strided_slice_1d_exhaustive", @@ -253,63 +290,63 @@ def generated_test_models(): "tile", "topk", "transpose", - "transpose_conv", + #"transpose_conv", # disabled due to b/111213074 "where", ] def gen_zip_test(name, test_name, **kwargs): - """Generate a zipped-example test and its dependent zip files. + """Generate a zipped-example test and its dependent zip files. - Args: - name: Resulting cc_test target name - test_name: Test targets this model. Comes from the list above. - **kwargs: tf_cc_test kwargs. - """ - gen_zipped_test_file( - name = "zip_%s" % test_name, - file = "%s.zip" % test_name, - ) - tf_cc_test(name, **kwargs) + Args: + name: Resulting cc_test target name + test_name: Test targets this model. Comes from the list above. + **kwargs: tf_cc_test kwargs. + """ + gen_zipped_test_file( + name = "zip_%s" % test_name, + file = "%s.zip" % test_name, + ) + tf_cc_test(name, **kwargs) def gen_zipped_test_file(name, file): - """Generate a zip file of tests by using :generate_examples. + """Generate a zip file of tests by using :generate_examples. - Args: - name: Name of output. We will produce "`file`.files" as a target. - file: The name of one of the generated_examples targets, e.g. "transpose" - """ - toco = "//tensorflow/contrib/lite/toco:toco" - native.genrule( - name = file + ".files", - cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco - + " --zip_to_output " + file + " $(@D)"), - outs = [file], - tools = [ - ":generate_examples", - toco, - ], - ) + Args: + name: Name of output. We will produce "`file`.files" as a target. + file: The name of one of the generated_examples targets, e.g. "transpose" + """ + toco = "//tensorflow/contrib/lite/toco:toco" + native.genrule( + name = file + ".files", + cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco + + " --zip_to_output " + file + " $(@D)"), + outs = [file], + tools = [ + ":generate_examples", + toco, + ], + ) - native.filegroup( - name = name, - srcs = [file], - ) + native.filegroup( + name = name, + srcs = [file], + ) def gen_selected_ops(name, model): - """Generate the library that includes only used ops. + """Generate the library that includes only used ops. - Args: - name: Name of the generated library. - model: TFLite model to interpret. - """ - out = name + "_registration.cc" - tool = "//tensorflow/contrib/lite/tools:generate_op_registrations" - tflite_path = "//tensorflow/contrib/lite" - native.genrule( - name = name, - srcs = [model], - outs = [out], - cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s") - % (tool, model, out, tflite_path[2:]), - tools = [tool], - ) + Args: + name: Name of the generated library. + model: TFLite model to interpret. + """ + out = name + "_registration.cc" + tool = "//tensorflow/contrib/lite/tools:generate_op_registrations" + tflite_path = "//tensorflow/contrib/lite" + native.genrule( + name = name, + srcs = [model], + outs = [out], + cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s") % + (tool, model, out, tflite_path[2:]), + tools = [tool], + ) diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh deleted file mode 100755 index 9f398f4a9f3dcafd7bd49fd5d95e9991b8b36b75..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/build_ios_universal_lib.sh +++ /dev/null @@ -1,40 +0,0 @@ -#!/bin/bash -x -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -set -e - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -cd "$SCRIPT_DIR/../../.." - -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 \ -$SCRIPT_DIR/gen/lib/ios_x86_64/libtensorflow-lite.a -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 \ -$SCRIPT_DIR/gen/lib/ios_i386/libtensorflow-lite.a -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8 \ -$SCRIPT_DIR/gen/lib/ios_armv7/libtensorflow-lite.a -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7s -j 8 \ -$SCRIPT_DIR/gen/lib/ios_armv7s/libtensorflow-lite.a -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=arm64 -j 8 \ -$SCRIPT_DIR/gen/lib/ios_arm64/libtensorflow-lite.a - -lipo \ -tensorflow/contrib/lite/gen/lib/ios_x86_64/libtensorflow-lite.a \ -tensorflow/contrib/lite/gen/lib/ios_i386/libtensorflow-lite.a \ -tensorflow/contrib/lite/gen/lib/ios_armv7/libtensorflow-lite.a \ -tensorflow/contrib/lite/gen/lib/ios_armv7s/libtensorflow-lite.a \ -tensorflow/contrib/lite/gen/lib/ios_arm64/libtensorflow-lite.a \ --create \ --output tensorflow/contrib/lite/gen/lib/libtensorflow-lite.a diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index c1cc4476fbd45fa6b3f5b3a1ed2cba39cc2ad54b..70178b2faabe85f8a53a94c2b5d2e3ea40c8ba05 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -92,8 +92,17 @@ typedef struct { TfLiteFusedActivation activation; } TfLiteSequenceRNNParams; +typedef enum { + kTfLiteFullyConnectedWeightsFormatDefault = 0, + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1, +} TfLiteFullyConnectedWeightsFormat; + typedef struct { + // Parameters for FullyConnected version 1 or above. TfLiteFusedActivation activation; + + // Parameters for FullyConnected version 2 or above. + TfLiteFullyConnectedWeightsFormat weights_format; } TfLiteFullyConnectedParams; typedef enum { @@ -215,7 +224,7 @@ typedef struct { typedef struct { bool keep_dims; -} TfLiteMeanParams; +} TfLiteReducerParams; typedef struct { int num_splits; @@ -240,6 +249,10 @@ typedef struct { TfLiteType output_type; } TfLiteArgMaxParams; +typedef struct { + TfLiteType output_type; +} TfLiteArgMinParams; + typedef struct { TfLitePadding padding; int stride_width; @@ -250,6 +263,29 @@ typedef struct { bool validate_indices; } TfLiteSparseToDenseParams; +typedef struct { + TfLiteType out_type; +} TfLiteShapeParams; + +typedef struct { + // Parameters supported by version 1: + float min; + float max; + int num_bits; + + // Parameters supported by version 2: + bool narrow_range; +} TfLiteFakeQuantParams; + +typedef struct { + int values_count; + int axis; +} TfLitePackParams; + +typedef struct { + int axis; +} TfLiteOneHotParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index aef9a92883f18dabfc36058507d739856c3c2af7..8a8eb9856886538a1483141ab5f67f54613ea2a1 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -99,6 +99,20 @@ typedef enum { kTfLiteBuiltinEqual = 71, kTfLiteBuiltinNotEqual = 72, kTfLiteBuiltinLog = 73, + kTfLiteBuiltinSum = 74, + kTfLiteBuiltinSqrt = 75, + kTfLiteBuiltinRsqrt = 76, + kTfLiteBuiltinShape = 77, + kTfLiteBuiltinPow = 78, + kTfLiteBuiltinArgMin = 79, + kTfLiteBuiltinFakeQuant = 80, + kTfLiteBuiltinReduceProd = 81, + kTfLiteBuiltinReduceMax = 82, + kTfLiteBuiltinPack = 83, + kTfLiteBuiltinLogicalOr = 84, + kTfLiteBuiltinOneHot = 85, + kTfLiteBuiltinLogicalAnd = 86, + kTfLiteBuiltinLogicalNot = 87, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/context.c index 5c6f5e72a47180cd98be46f60cfa8eaf28197806..7f2aa316f4a9a265b14a216a6ffa53c7f0757426 100644 --- a/tensorflow/contrib/lite/context.c +++ b/tensorflow/contrib/lite/context.c @@ -76,7 +76,7 @@ void TfLiteTensorFree(TfLiteTensor* t) { void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, TfLiteQuantizationParams quantization, char* buffer, size_t size, TfLiteAllocationType allocation_type, - const void* allocation, TfLiteTensor* tensor) { + const void* allocation, bool is_variable, TfLiteTensor* tensor) { TfLiteTensorFree(tensor); tensor->type = type; tensor->name = name; @@ -86,6 +86,7 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, tensor->bytes = size; tensor->allocation_type = allocation_type; tensor->allocation = allocation; + tensor->is_variable = is_variable; } void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) { diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index 4eb66cc225eb04923be9aaa445a335ad822c8a6f..c920f6a508bdf6719112af5ec9c675460f723d21 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -29,6 +29,9 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ #define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ +#if defined(_MSC_VER) +#include +#endif #include #include #include @@ -39,6 +42,26 @@ extern "C" { typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; +// The list of external context types known to TF Lite. This list exists solely +// to avoid conflicts and to ensure ops can share the external contexts they +// need. Access to the external contexts is controled by one of the +// corresponding support files. +typedef enum { + kTfLiteEigenContext = 0, // include eigen_support.h to use. + kTfLiteGemmLowpContext = 1, // include gemm_support.h to use. + kTfLiteMaxExternalContexts = 2 +} TfLiteExternalContextType; + +// An external context is a collection of information unrelated to the TF Lite +// framework, but useful to a subset of the ops. TF Lite knows very little +// about about the actual contexts, but it keeps a list of them, and is able to +// refresh them if configurations like the number of recommended threads +// change. +typedef struct { + TfLiteExternalContextType type; + TfLiteStatus (*Refresh)(struct TfLiteContext* context); +} TfLiteExternalContext; + // Forward declare so GetNode can use this is in Context. typedef struct _TfLiteRegistration TfLiteRegistration; typedef struct _TfLiteDelegate TfLiteDelegate; @@ -138,6 +161,8 @@ typedef enum { kTfLiteInt64 = 4, kTfLiteString = 5, kTfLiteBool = 6, + kTfLiteInt16 = 7, + kTfLiteComplex64 = 8, } TfLiteType; // Parameters for asymmetric quantization. Quantized values can be converted @@ -148,7 +173,7 @@ typedef struct { int32_t zero_point; } TfLiteQuantizationParams; -// A union of points that points to memory for a given tensor. +// A union of pointers that points to memory for a given tensor. typedef union { int* i32; int64_t* i64; @@ -157,6 +182,12 @@ typedef union { const char* raw_const; uint8_t* uint8; bool* b; + int16_t* i16; +#if defined(_MSC_VER) + _Fcomplex* c64; +#else + _Complex float* c64; +#endif } TfLitePtrUnion; // Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped @@ -223,6 +254,9 @@ typedef struct { // delegate buffer. // WARNING: This is an // experimental interface that is subject to change. bool data_is_stale; + + // True if the tensor is a variable. + bool is_variable; } TfLiteTensor; // Free data memory of tensor `t`; @@ -235,9 +269,11 @@ void TfLiteTensorFree(TfLiteTensor* t); void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, TfLiteQuantizationParams quantization, char* buffer, size_t size, TfLiteAllocationType allocation_type, - const void* allocation, TfLiteTensor* tensor); + const void* allocation, bool is_variable, + TfLiteTensor* tensor); -// Resize the allocated data of a (dynamic) tensor. +// Resize the allocated data of a (dynamic) tensor. Tensors with allocation +// types other than kTfLiteDynamic will be ignored. void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor); // A structure representing an instance of a node. @@ -330,10 +366,15 @@ typedef struct TfLiteContext { // eigen. int recommended_num_threads; - // TODO(ahentz): we should create a more general mechanism for this sort of - // library-global objects. - void* gemm_context; - void* eigen_context; + // Access external contexts by type. + // WARNING: This is an experimental interface that is subject to change. + TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*, + TfLiteExternalContextType); + // Set the value of a external context. Does not take ownership of the + // pointer. + // WARNING: This is an experimental interface that is subject to change. + void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType, + TfLiteExternalContext*); } TfLiteContext; typedef struct _TfLiteRegistration { @@ -368,6 +409,14 @@ typedef struct _TfLiteRegistration { // Returns kTfLiteOk on success. TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); + // profiling_string is called during summarization of profiling information + // in order to group executions together. Providing a value here will cause a + // given op to appear multiple times is the profiling report. This is + // particularly useful for custom ops that can perform significantly + // different calculations depending on their `user-data`. + const char* (*profiling_string)(const TfLiteContext* context, + const TfLiteNode* node); + // Builtin codes. If this kernel refers to a builtin this is the code // of the builtin. This is so we can do marshaling to other frameworks like // NN API. @@ -403,13 +452,15 @@ typedef struct _TfLiteDelegate { // Copy the data from delegate buffer handle to raw memory. // This can be null if the delegate doesn't use its own buffer. - TfLiteStatus (*CopyFromBufferHandle)(TfLiteDelegate* delegate, + TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context, + TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle, void* data, size_t size); // Copy the data from raw memory to delegate buffer handle. // This can be null if the delegate doesn't use its own buffer. - TfLiteStatus (*CopyToBufferHandle)(TfLiteDelegate* delegate, + TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context, + TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle, void* data, size_t size); @@ -417,11 +468,17 @@ typedef struct _TfLiteDelegate { // this doesn't release the underlying resource (e.g. textures). The // resources are either owned by application layer or the delegate. // This can be null if the delegate doesn't use its own buffer. - void (*FreeBufferHandle)(TfLiteDelegate* delegate, + void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate, TfLiteBufferHandle* handle); } TfLiteDelegate; // WARNING: This is an experimental interface that is subject to change. +// +// Currently, TfLiteDelegateParams has to be allocated in a way that it's +// trivially destructable. It will be stored as `builtin_data` field in +// `TfLiteNode` of the delegate node. +// +// See also the `CreateDelegateParams` function in `interpreter.cc` details. typedef struct { TfLiteDelegate* delegate; TfLiteIntArray* nodes_to_replace; diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..bb518becc582b776096fc0d2720042286b0b871e --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/BUILD @@ -0,0 +1,185 @@ +# +# This is a TF Lite delegate that is powered by TensorFlow's Eager. +# +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "buffer_map", + srcs = ["buffer_map.cc"], + hdrs = ["buffer_map.h"], + deps = [ + ":util", + "//tensorflow/c:c_api_internal", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_test( + name = "buffer_map_test", + size = "small", + srcs = ["buffer_map_test.cc"], + tags = [ + "tflite_not_portable", + ], + deps = [ + ":buffer_map", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:util", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "delegate", + srcs = [ + "delegate.cc", + ], + hdrs = [ + "delegate.h", + ], + deps = [ + ":buffer_map", + ":delegate_data", + ":kernel", + ":util", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/contrib/lite:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "delegate_test", + size = "small", + srcs = ["delegate_test.cc"], + tags = [ + "tflite_not_portable", + ], + deps = [ + ":delegate", + ":test_util", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "delegate_data", + srcs = ["delegate_data.cc"], + hdrs = ["delegate_data.h"], + deps = [ + ":buffer_map", + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + "//tensorflow/core/common_runtime/eager:context", + ], +) + +tf_cc_test( + name = "delegate_data_test", + size = "small", + srcs = ["delegate_data_test.cc"], + tags = [ + "tflite_not_portable", + ], + deps = [ + ":delegate_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:util", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "kernel", + srcs = ["kernel.cc"], + hdrs = ["kernel.h"], + deps = [ + ":delegate_data", + ":util", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/contrib/lite:string", + "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:execute", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "@flatbuffers", + ], +) + +tf_cc_test( + name = "kernel_test", + size = "small", + srcs = ["kernel_test.cc"], + tags = [ + "tflite_not_portable", + ], + deps = [ + ":delegate_data", + ":kernel", + ":test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "test_util", + testonly = True, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + "//tensorflow/c:c_api_internal", + "//tensorflow/contrib/lite:string", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_absl//absl/memory", + "@flatbuffers", + ], +) + +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + deps = [ + ":constants", + "//tensorflow/c:c_api_internal", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "util_test", + size = "small", + srcs = ["util_test.cc"], + tags = [ + "tflite_not_portable", + ], + deps = [ + ":util", + "//tensorflow/contrib/lite:string", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "constants", + hdrs = ["constants.h"], +) diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc b/tensorflow/contrib/lite/delegates/eager/buffer_map.cc new file mode 100644 index 0000000000000000000000000000000000000000..e5a19c39976969a0b05b28596c6d7d5ebe7c7782 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.cc @@ -0,0 +1,111 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h" + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/contrib/lite/delegates/eager/util.h" +#include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/log_memory.h" + +namespace tflite { +namespace eager { +namespace { +// A tensor buffer that is allocated, deallocated and populated by TF Lite. +class TfLiteTensorBuffer : public tensorflow::TensorBuffer { + public: + explicit TfLiteTensorBuffer(const TfLiteTensor* tensor) { + len_ = tensor->bytes; + // TODO(ahentz): if we can guarantee that TF Lite allocated tensors with + // the same alignment as TensorFlow (EIGEN_MAX_ALIGN_BYTES), then we can + // potentially eliminate the copy below. + data_ = + tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len_); + if (data_ != nullptr) { + if (tensorflow::LogMemory::IsEnabled()) { + tensorflow::LogMemory::RecordRawAllocation( + "TfLiteTensorBuffer_New", + tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, len_, + data_, tensorflow::cpu_allocator()); + } + std::memcpy(data_, tensor->data.raw, tensor->bytes); + } + } + + ~TfLiteTensorBuffer() override { + if (tensorflow::LogMemory::IsEnabled() && data_ != nullptr) { + tensorflow::LogMemory::RecordRawDeallocation( + "TfLiteTensorBuffer_Delete", + tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data_, + tensorflow::cpu_allocator(), false); + } + tensorflow::cpu_allocator()->DeallocateRaw(data_); + } + + void* data() const override { return data_; } + size_t size() const override { return len_; } + + TensorBuffer* root_buffer() override { return this; } + void FillAllocationDescription( + tensorflow::AllocationDescription* proto) const override { + tensorflow::int64 rb = size(); + proto->set_requested_bytes(rb); + proto->set_allocator_name(tensorflow::cpu_allocator()->Name()); + } + + // Prevents input forwarding from mutating this buffer. + bool OwnsMemory() const override { return false; } + + private: + void* data_; + size_t len_; +}; +} // namespace + +BufferMap::BufferMap() {} + +BufferMap::~BufferMap() {} + +bool BufferMap::HasTensor(int tensor_index) const { + return id_to_tensor_.count(tensor_index) != 0; +} + +tensorflow::Tensor BufferMap::GetTensor(int tensor_index) const { + return id_to_tensor_.at(tensor_index); +} + +void BufferMap::SetFromTfLite(int tensor_index, const TfLiteTensor* tensor) { + tensorflow::TensorShape shape; + int num_dims = tensor->dims->size; + for (int i = 0; i < num_dims; ++i) { + shape.AddDim(tensor->dims->data[i]); + } + // TODO(ahentz): we assume this is a new tensor and allocate a new buffer + // for it. This is not always the best approach. For example, this might + // be a reallocation after resizing tensors. In that case we would be + // preferable to somehow reuse the buffer. + auto* buf = new TfLiteTensorBuffer(tensor); + tensorflow::Tensor t = tensorflow::TensorCApi::MakeTensor( + GetTensorFlowDataType(tensor->type), shape, buf); + buf->Unref(); + + SetFromTensorFlow(tensor_index, std::move(t)); +} + +void BufferMap::SetFromTensorFlow(int tensor_index, tensorflow::Tensor tensor) { + id_to_tensor_[tensor_index] = std::move(tensor); +} + +} // namespace eager +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.h b/tensorflow/contrib/lite/delegates/eager/buffer_map.h new file mode 100644 index 0000000000000000000000000000000000000000..a28329ae7d14e3e0214c6602b28b09c43876bbf0 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.h @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_ + +#include + +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tflite { +namespace eager { + +// Maps a TF Lite tensor index into a TensorFlow tensor. +// +// The TF Lite interpreter assigns integer indices to each of its tensors, but +// the Eager delegate deals in terms of TensorFlow tensors. This class maps +// from indices to tensors and allows the creation of new tensors to be +// associated with a given index. +class BufferMap { + public: + BufferMap(); + ~BufferMap(); + + // Returns true if the given 'tensor_index' has a corresponding + // tensorflow::Tensor. + bool HasTensor(int tensor_index) const; + + // Returns the tensorflow::Tensor associated with the given 'tensor_index'. + // Precondition: HasTensor() is true. + tensorflow::Tensor GetTensor(int tensor_index) const; + + // Associates the given tensorflow::Tensor with the given 'tensor_index'. + // Note that tensorflow Tensors share data buffers, so this method is only a + // shallow copy. + void SetFromTensorFlow(int tensor_index, tensorflow::Tensor tensor); + + // Same as above but creates a new tensorflow::Tensor with a copy of the + // given TfLiteTensor's data. + void SetFromTfLite(int tensor_index, const TfLiteTensor* tensor); + + private: + std::map id_to_tensor_; +}; + +} // namespace eager +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc b/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a046943e56d2b80f2670b7fc3dd57b36dc4d2425 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc @@ -0,0 +1,174 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h" + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/contrib/lite/util.h" + +namespace tflite { +namespace eager { +namespace { + +using ::testing::ElementsAre; + +// A bit of RAII to simplify handling of TfLiteTensors in the tests. +using UniqueTfLiteTensor = + std::unique_ptr>; + +template +UniqueTfLiteTensor MakeLiteTensor(const std::vector& shape, + const std::vector& data) { + auto tensor = UniqueTfLiteTensor(new TfLiteTensor, [](TfLiteTensor* t) { + TfLiteTensorDataFree(t); + TfLiteIntArrayFree(t->dims); + delete t; + }); + tensor->allocation_type = kTfLiteDynamic; + tensor->type = typeToTfLiteType(); + tensor->dims = ConvertVectorToTfLiteIntArray(shape); + tensor->data.raw = nullptr; + TfLiteTensorRealloc(data.size() * sizeof(T), tensor.get()); + memcpy(tensor->data.raw, data.data(), data.size() * sizeof(T)); + return tensor; +} + +template +tensorflow::Tensor MakeTensor(const std::vector& shape, + const std::vector& data) { + BufferMap buffer_map; // BufferMap is the easiest way to build the tensor. + UniqueTfLiteTensor t1 = MakeLiteTensor(shape, data); + buffer_map.SetFromTfLite(0, t1.get()); + return buffer_map.GetTensor(0); +} + +std::vector GetTensorShape(const tensorflow::Tensor& t) { + std::vector shape(t.dims()); + for (int i = 0; i < t.dims(); ++i) { + shape[i] = t.dim_size(i); + } + return shape; +} + +template +std::vector GetTensorData(const tensorflow::Tensor& t) { + const T* data = t.flat().data(); + return std::vector(data, data + t.NumElements()); +} + +TEST(BufferMapTest, EmptyBuffer) { + BufferMap buffer_map; + EXPECT_FALSE(buffer_map.HasTensor(0)); +} + +TEST(BufferMapTest, SetFromTfLite) { + BufferMap buffer_map; + + UniqueTfLiteTensor t = + MakeLiteTensor({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + buffer_map.SetFromTfLite(0, t.get()); + ASSERT_TRUE(buffer_map.HasTensor(0)); + + EXPECT_THAT(GetTensorData(buffer_map.GetTensor(0)), + ElementsAre(0, 0, 0, 0.123f, 0, 0)); + + // Also check details of the tensor. + tensorflow::Tensor out_tensor = buffer_map.GetTensor(0); + ASSERT_EQ(out_tensor.dtype(), tensorflow::DT_FLOAT); + ASSERT_EQ(out_tensor.NumElements(), 6); + ASSERT_THAT(GetTensorShape(out_tensor), ElementsAre(1, 2, 1, 3)); +} + +TEST(BufferMapTest, SetFromTfLiteTwice) { + UniqueTfLiteTensor t1 = + MakeLiteTensor({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + UniqueTfLiteTensor t2 = + MakeLiteTensor({1, 2, 4}, {0, 0, 0, 3, 0, 0, 1, 2}); + + BufferMap buffer_map; + buffer_map.SetFromTfLite(0, t1.get()); + buffer_map.SetFromTfLite(0, t2.get()); + + EXPECT_THAT(GetTensorData(buffer_map.GetTensor(0)), + ElementsAre(0, 0, 0, 3, 0, 0, 1, 2)); +} + +TEST(BufferMapTest, SetFromTensorFlow) { + tensorflow::Tensor t1 = + MakeTensor({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + + BufferMap buffer_map; + buffer_map.SetFromTensorFlow(0, t1); + + EXPECT_THAT(GetTensorData(buffer_map.GetTensor(0)), + ElementsAre(0, 0, 0, 0.123f, 0, 0)); + + // Also check details of the tensor. + tensorflow::Tensor out_tensor = buffer_map.GetTensor(0); + ASSERT_EQ(out_tensor.dtype(), tensorflow::DT_FLOAT); + ASSERT_EQ(out_tensor.NumElements(), 6); + ASSERT_THAT(GetTensorShape(out_tensor), ElementsAre(1, 2, 1, 3)); +} + +TEST(BufferMapTest, SetFromTensorFlowTwice) { + tensorflow::Tensor t1 = + MakeTensor({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + tensorflow::Tensor t2 = MakeTensor({1, 2, 4}, {0, 0, 0, 3, 0, 0, 1, 2}); + BufferMap buffer_map; + buffer_map.SetFromTensorFlow(0, t1); + buffer_map.SetFromTensorFlow(0, t2); + + EXPECT_THAT(GetTensorData(buffer_map.GetTensor(0)), + ElementsAre(0, 0, 0, 3, 0, 0, 1, 2)); +} + +TEST(BufferMapTest, TfLiteOverwritesTensorFlow) { + tensorflow::Tensor t1 = + MakeTensor({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + UniqueTfLiteTensor t2 = + MakeLiteTensor({1, 2, 4}, {0, 0, 0, 3, 0, 0, 1, 2}); + + BufferMap buffer_map; + buffer_map.SetFromTensorFlow(0, t1); + buffer_map.SetFromTfLite(0, t2.get()); + + EXPECT_THAT(GetTensorData(buffer_map.GetTensor(0)), + ElementsAre(0, 0, 0, 3, 0, 0, 1, 2)); +} + +TEST(BufferMapTest, TensorFlowOverwritesTfLite) { + tensorflow::Tensor t1 = + MakeTensor({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + UniqueTfLiteTensor t2 = + MakeLiteTensor({1, 2, 4}, {0, 0, 0, 3, 0, 0, 1, 2}); + BufferMap buffer_map; + buffer_map.SetFromTfLite(0, t2.get()); + buffer_map.SetFromTensorFlow(0, t1); + + EXPECT_THAT(GetTensorData(buffer_map.GetTensor(0)), + ElementsAre(0, 0, 0, 0.123f, 0, 0)); +} + +} // namespace +} // namespace eager +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc b/tensorflow/contrib/lite/delegates/eager/constants.h similarity index 51% rename from tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc rename to tensorflow/contrib/lite/delegates/eager/constants.h index 0e08a04370592f6e3c92b5811fa7e163f808e03c..7ed6ab7552792c68e6d90056c83c3c574c3f69f7 100644 --- a/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc +++ b/tensorflow/contrib/lite/delegates/eager/constants.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,27 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -#include -#include +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_CONSTANTS_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_CONSTANTS_H_ namespace tflite { +namespace eager { -// Gets the elapsed wall-clock timespec. -timespec getCurrentTime() { - timespec time; - clock_gettime(CLOCK_MONOTONIC, &time); - return time; -} - -// Computes the time diff from two timespecs. Returns '-1' if 'stop' is earlier -// than 'start'. -jlong timespec_diff_nanoseconds(struct timespec* start, struct timespec* stop) { - jlong result = stop->tv_sec - start->tv_sec; - if (result < 0) return -1; - result = 1000000000 * result + (stop->tv_nsec - start->tv_nsec); - if (result < 0) return -1; - return result; -} +// The prefix of Eager op custom code. +// This will be matched agains the `custom_code` field in `OperatorCode` +// Flatbuffer Table. +constexpr char kCustomCodePrefix[] = "Eager"; +} // namespace eager } // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_CONSTANTS_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.cc b/tensorflow/contrib/lite/delegates/eager/delegate.cc new file mode 100644 index 0000000000000000000000000000000000000000..8ab768575e8cc4421ae0e0daddd9c25da79f6e24 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate.cc @@ -0,0 +1,109 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/delegate.h" + +#include + +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h" +#include "tensorflow/contrib/lite/delegates/eager/kernel.h" +#include "tensorflow/contrib/lite/delegates/eager/util.h" +#include "tensorflow/contrib/lite/util.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tflite { +namespace eager { +namespace delegate { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) { + // Get the nodes in the current execution plan. Interpreter owns this array. + TfLiteIntArray* plan; + TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); + + // Add all custom ops starting with "Eager" to list of supported nodes. + std::vector supported_nodes; + for (int node_index : TfLiteIntArrayView(plan)) { + TfLiteNode* node; + TfLiteRegistration* registration; + TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( + context, node_index, &node, ®istration)); + + if (IsEagerOp(registration->custom_name)) { + supported_nodes.push_back(node_index); + } + } + + // Request TFLite to partition the graph and make kernels for each independent + // subgraph. + TfLiteIntArray* size_and_nodes = + ConvertVectorToTfLiteIntArray(supported_nodes); + context->ReplaceSubgraphsWithDelegateKernels(context, GetKernel(), + size_and_nodes, delegate); + TfLiteIntArrayFree(size_and_nodes); + return kTfLiteOk; +} + +TfLiteStatus CopyFromBufferHandle(TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, void* data, + size_t size) { + BufferMap* buffer_map = + reinterpret_cast(delegate->data_)->GetBufferMap(context); + + if (!buffer_map->HasTensor(buffer_handle)) { + context->ReportError(context, "Invalid tensor index %d.", buffer_handle); + return kTfLiteError; + } + + tensorflow::Tensor t = buffer_map->GetTensor(buffer_handle); + tensorflow::StringPiece t_data = t.tensor_data(); + + if (size != t_data.size()) { + context->ReportError( + context, "Not enough space to store TensorFlow's aligned buffer."); + return kTfLiteError; + } + + memcpy(data, t_data.data(), t_data.size()); + return kTfLiteOk; +} + +} // namespace delegate +} // namespace eager + +EagerDelegate::EagerDelegate() {} + +EagerDelegate::~EagerDelegate() {} + +TfLiteStatus EagerDelegate::Apply(Interpreter* interpreter) { + if (!delegate_) { + if (!eager::DelegateData::Create(&delegate_data_).ok()) { + fprintf(stderr, "Unable to initialize TensorFlow context.\n"); + return kTfLiteError; + } + + delegate_.reset(new TfLiteDelegate{ + /*data_=*/delegate_data_.get(), + /*nullptr,*/ &eager::delegate::Prepare, + /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle, + /*CopyToBufferHandle=*/nullptr, + /*FreeBufferHandle=*/nullptr}); + } + + return interpreter->ModifyGraphWithDelegate(delegate_.get(), + /*allow_dynamic_tensors=*/true); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/eager/delegate.h new file mode 100644 index 0000000000000000000000000000000000000000..a07002f4870c018f5fd7cf5fe6ef14aa0320e65b --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_ + +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" +#include "tensorflow/contrib/lite/interpreter.h" + +namespace tflite { + +// WARNING: This is an experimental interface that is subject to change. +// Delegate that can be used to extract parts of a graph that are designed to be +// executed by TensorFlow's runtime via Eager. +// +// The interpreter must be constructed after the EagerDelegate and destructed +// before the EagerDelegate. This delegate may be used with multiple +// interpreters, but it is *not* thread-safe. +// +// Usage: +// EagerDelegate delegate; +// ... build interpreter ... +// +// delegate.Apply(interpreter); +// ... run inference ... +// ... destroy interpreter ... +// ... destroy delegate ... +class EagerDelegate { + public: + EagerDelegate(); + ~EagerDelegate(); + + // Modifies the graph loaded in the interpreter. + TfLiteStatus Apply(Interpreter* interpreter); + + private: + std::unique_ptr delegate_data_; + std::unique_ptr delegate_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data.cc new file mode 100644 index 0000000000000000000000000000000000000000..0fd5c976f8ca9be16f7e3c5e610573755b40c506 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate_data.cc @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" + +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tflite { +namespace eager { +tensorflow::Status DelegateData::Create(std::unique_ptr* data) { + std::vector devices; + + TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices( + tensorflow::SessionOptions(), "/job:localhost/replica:0/task:0", + &devices)); + + std::unique_ptr device_mgr( + new tensorflow::DeviceMgr(devices)); + // Note that Rendezvous is ref-counted so it will be automatically deleted. + tensorflow::Rendezvous* rendezvous = + new tensorflow::IntraProcessRendezvous(device_mgr.get()); + data->reset(new DelegateData(new tensorflow::EagerContext( + tensorflow::SessionOptions(), + tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, + /*async=*/false, std::move(device_mgr), rendezvous))); + return tensorflow::Status(); +} + +DelegateData::DelegateData(tensorflow::EagerContext* eager_context) + : eager_context_(eager_context) {} + +DelegateData::~DelegateData() {} + +} // namespace eager +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.h b/tensorflow/contrib/lite/delegates/eager/delegate_data.h new file mode 100644 index 0000000000000000000000000000000000000000..772d26f44e8b5b2b962c06f42b86df29ee1c1f8d --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate_data.h @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_ + +#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h" +#include "tensorflow/core/common_runtime/eager/context.h" + +namespace tflite { +namespace eager { + +// Data kept by the Eager delegate for the lifetime of an Interpreter. +class DelegateData { + public: + // Create a new DelegateData, initialized with a newly-created EagerContext. + static tensorflow::Status Create(std::unique_ptr* data); + + ~DelegateData(); + + // The EagerContext that is required for execution of Eager Ops. + tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); } + + // Map from TF Lite tensor index to TensorFlow tensor for a given context. + BufferMap* GetBufferMap(const TfLiteContext* context) { + return &buffer_map_[context]; + } + + private: + explicit DelegateData(tensorflow::EagerContext* eager_context); + + std::unique_ptr eager_context_; + // TODO(b/112439500): Clean up stale BufferMap instances after adding the + // necessary cleanup hook from a TfLiteContext to a TfLiteDelegate. + std::unordered_map buffer_map_; +}; + +} // namespace eager +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b3a0ffcec1d450ed4edcf10b9048e08d82b9eeca --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" + +#include +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +namespace eager { +namespace { + +TEST(DelegateDataTest, Basic) { + std::unique_ptr data; + // We only check for success because it is hard to make initialization fail. + // It only happens if we manage to not link the CPU device factory into the + // binary. + EXPECT_TRUE(DelegateData::Create(&data).ok()); + + TfLiteContext dummy_context1 = {}; + TfLiteContext dummy_context2 = {}; + EXPECT_NE(data->GetEagerContext(), nullptr); + EXPECT_NE(data->GetBufferMap(&dummy_context1), nullptr); + EXPECT_NE(data->GetBufferMap(&dummy_context1), + data->GetBufferMap(&dummy_context2)); +} + +} // namespace +} // namespace eager +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..511a239363ecaf9d64dc25f6a448435e3364df60 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc @@ -0,0 +1,198 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/delegate.h" + +#include +#include +#include "tensorflow/contrib/lite/delegates/eager/test_util.h" + +namespace tflite { +namespace eager { +namespace { + +using ::testing::ContainsRegex; +using ::testing::ElementsAre; + +class DelegateTest : public testing::EagerModelTest { + public: + DelegateTest() { + // The delegate needs to be constructed before the interpreter because the + // interpreter references data contained in the delegate. + delegate_.reset(new EagerDelegate()); + interpreter_.reset(new Interpreter(&error_reporter_)); + } + + ~DelegateTest() override { + // The delegate needs to be destructed after the interpreter because the + // interpreter references data contained in the delegate. + delete interpreter_.release(); + delete delegate_.release(); + } + + void ConfigureDelegate() { + CHECK(delegate_->Apply(interpreter_.get()) == kTfLiteOk); + } + + private: + std::unique_ptr delegate_; +}; + +TEST_F(DelegateTest, FullGraph) { + // Define the graph. + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfOp(testing::kMul, {6, 7}, {8}); + + // Apply the delegate. + ConfigureDelegate(); + + // Define inputs. + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); + ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); +} + +TEST_F(DelegateTest, MixedGraph) { + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfLiteMulOp({6, 7}, {8}); + + ConfigureDelegate(); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); + ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); +} + +TEST_F(DelegateTest, SplitGraph) { + AddTensors(10, {0}, {9}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kAdd, {1, 2}, {3}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + + AddTfLiteMulOp({4, 5}, {6}); + + AddTfOp(testing::kUnpack, {6}, {7, 8}); + AddTfOp(testing::kAdd, {7, 8}, {9}); + + ConfigureDelegate(); + + SetShape(0, {2, 2, 2, 1}); + SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(9), ElementsAre(1)); + ASSERT_THAT(GetValues(9), ElementsAre(10.0f)); +} + +TEST_F(DelegateTest, OnlyTFLite) { + // Only TFLite single op model. + AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3}); + AddTfLiteMulOp({0, 1}, {2}); + + ConfigureDelegate(); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(1, {2, 2, 1}); + SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1)); + ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f)); +} + +TEST_F(DelegateTest, MultipleInterpretersSameDelegate) { + // Build a graph, configure the delegate and set inputs. + { + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfOp(testing::kMul, {6, 7}, {8}); + ConfigureDelegate(); + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + } + + // Create a new interpreter, inject into the test framework and build + // a different graph using the *same* delegate. + std::unique_ptr interpreter(new Interpreter(&error_reporter_)); + interpreter_.swap(interpreter); + { + AddTensors(10, {0}, {9}, kTfLiteFloat32, {3}); + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kAdd, {1, 2}, {3}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfLiteMulOp({4, 5}, {6}); + AddTfOp(testing::kUnpack, {6}, {7, 8}); + AddTfOp(testing::kAdd, {7, 8}, {9}); + ConfigureDelegate(); + SetShape(0, {2, 2, 2, 1}); + SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f}); + } + + // Swap back in the first interpreter and validate inference. + interpreter_.swap(interpreter); + { + ASSERT_TRUE(Invoke()); + EXPECT_THAT(GetShape(8), ElementsAre(2, 1)); + EXPECT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); + } + + // Swap in the second interpreter and validate inference. + interpreter_.swap(interpreter); + { + ASSERT_TRUE(Invoke()); + EXPECT_THAT(GetShape(9), ElementsAre(1)); + EXPECT_THAT(GetValues(9), ElementsAre(10.0f)); + } +} + +} // namespace +} // namespace eager +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..1082b78725986ba3e6f31607f526ea2df2f1fdfb --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc @@ -0,0 +1,290 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/kernel.h" + +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/builtin_ops.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" +#include "tensorflow/contrib/lite/delegates/eager/util.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/execute.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/framework/node_def.pb.h" + +// Note: this is part of TF Lite's Eager delegation code which is to be +// completed soon. + +// This is the TF Lite op that is created by the eager delegate to handle +// execution of a supported subgraph. The usual flow is that the delegate +// informs the interpreter of supported nodes in a graph, and each supported +// subgraph is replaced with one instance of this kernel. +// +// The kernel is initialized with TfLiteDelegateParams from which we retrieve +// the global EagerContext and BufferMap, as well as a list of inputs and +// outputs to the subgraph. Those are used to build the OpData, with a list of +// TensorFlow Ops that should be executed in order (which we call an OpNode). +// +// For each node included in the subgraph, we query the interpreter and +// retrieve the associated NodeDef, which is then used to configure the +// corresponding TensorFlow/Eager Op. + +namespace tflite { +namespace eager { +namespace kernel { + +// Controls the lifetime of tensor handles in a vector. +class VectorOfHandles { + public: + explicit VectorOfHandles(int num_elements) : vector_(num_elements, nullptr) {} + + ~VectorOfHandles() { + for (auto* handle : vector_) { + if (handle) handle->Unref(); + } + } + + tensorflow::gtl::InlinedVector* GetVector() { + return &vector_; + } + + tensorflow::TensorHandle* GetHandle(int index) { return vector_[index]; } + + private: + tensorflow::gtl::InlinedVector vector_; +}; + +// Executes the TensorFlow op given by 'op_name', with the attributes specified +// in 'nodedef'. Inputs and outputs are given as indices into the 'buffer_map'. +tensorflow::Status ExecuteEagerOp(tensorflow::EagerContext* eager_context, + BufferMap* buffer_map, const string& op_name, + const tensorflow::NodeDef& nodedef, + const std::vector& inputs, + const std::vector& outputs) { + const tensorflow::AttrTypeMap* attr_types; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + tensorflow::AttrTypeMapForOp(op_name.c_str(), &attr_types), + " (while processing attributes of '", op_name, "')"); + + tensorflow::EagerOperation op(eager_context, op_name.c_str(), attr_types); + for (const auto& attr : nodedef.attr()) { + op.MutableAttrs()->Set(attr.first, attr.second); + } + + for (int input_index : inputs) { + if (!buffer_map->HasTensor(input_index)) { + return tensorflow::errors::Internal( + "Cannot read from invalid tensor index ", input_index); + } + auto* handle = new tensorflow::TensorHandle( + buffer_map->GetTensor(input_index), nullptr, nullptr, nullptr); + op.AddInput(handle); + handle->Unref(); + } + + int num_retvals = outputs.size(); + VectorOfHandles retvals(num_retvals); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + EagerExecute(&op, retvals.GetVector(), &num_retvals), + " (while executing '", op_name, "' via Eager)"); + + if (num_retvals != outputs.size()) { + return tensorflow::errors::Internal( + "Unexpected number of outputs from EagerExecute"); + } + + for (int i = 0; i < num_retvals; ++i) { + const tensorflow::Tensor* tensor = nullptr; + TF_RETURN_IF_ERROR(retvals.GetHandle(i)->Tensor(&tensor)); + buffer_map->SetFromTensorFlow(outputs[i], *tensor); + } + + return tensorflow::Status::OK(); +} + +// A single node within the larger 'op'. Note that this kernel executes many +// TensorFlow ops within a single TF Lite op. +struct OpNode { + // The name of the TensorFlow op to execute. + string name; + // The corresponding NodeDef, containing the attributes for the op. + tensorflow::NodeDef nodedef; + // List of inputs, as TF Lite tensor indices. + std::vector inputs; + // List of outputs, as TF Lite tensor indices. + std::vector outputs; +}; + +// The Larger 'op', which contains all the nodes in a supported subgraph. +struct OpData { + tensorflow::EagerContext* eager_context; + BufferMap* buffer_map; + std::vector nodes; + std::vector subgraph_inputs; + std::vector subgraph_outputs; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* op_data = new OpData; + + const TfLiteDelegateParams* params = + reinterpret_cast(buffer); + CHECK(params); + CHECK(params->delegate); + CHECK(params->delegate->data_); + op_data->eager_context = + reinterpret_cast(params->delegate->data_) + ->GetEagerContext(); + op_data->buffer_map = reinterpret_cast(params->delegate->data_) + ->GetBufferMap(context); + + CHECK(params->output_tensors); + for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) { + op_data->subgraph_outputs.push_back(tensor_index); + } + + CHECK(params->input_tensors); + for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) { + op_data->subgraph_inputs.push_back(tensor_index); + } + + CHECK(params->nodes_to_replace); + for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) { + TfLiteNode* node; + TfLiteRegistration* reg; + context->GetNodeAndRegistration(context, node_index, &node, ®); + + op_data->nodes.push_back(OpNode()); + OpNode& node_data = op_data->nodes.back(); + + node_data.name = ""; + if (node->custom_initial_data) { + // The flexbuffer contains a vector where the first elements is the + // op name and the second is a serialized NodeDef. + const flexbuffers::Vector& v = + flexbuffers::GetRoot( + reinterpret_cast(node->custom_initial_data), + node->custom_initial_data_size) + .AsVector(); + + node_data.name = v[0].AsString().str(); + if (!node_data.nodedef.ParseFromString(v[1].AsString().str())) { + // We will just leave the nodedef empty and error out in Eval(). + node_data.nodedef.Clear(); + } + } + + for (auto input_index : TfLiteIntArrayView(node->inputs)) { + node_data.inputs.push_back(input_index); + } + for (auto output_index : TfLiteIntArrayView(node->outputs)) { + node_data.outputs.push_back(output_index); + } + } + + return op_data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const auto* op_data = reinterpret_cast(node->user_data); + TF_LITE_ENSURE_MSG( + context, op_data->eager_context != nullptr, + "Failed to initialize eager context. This often happens when a CPU " + "device has not been registered, presumably because some symbols from " + "tensorflow/core:core_cpu_impl were not linked into the binary."); + + // Whenever we find a constant tensor, insert it in the buffer map. + BufferMap* buffer_map = op_data->buffer_map; + for (auto tensor_index : op_data->subgraph_inputs) { + TfLiteTensor* tensor = &context->tensors[tensor_index]; + if (IsConstantTensor(tensor)) { + if (!buffer_map->HasTensor(tensor_index)) { + buffer_map->SetFromTfLite(tensor_index, tensor); + } + } + } + + // All output tensors are allocated by TensorFlow/Eager, so we + // mark them as kTfLiteDynamic. + for (auto tensor_index : op_data->subgraph_outputs) { + SetTensorToDynamic(&context->tensors[tensor_index]); + } + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const auto* op_data = reinterpret_cast(node->user_data); + BufferMap* buffer_map = op_data->buffer_map; + tensorflow::EagerContext* eager_context = op_data->eager_context; + + // Insert a tensor in the buffer map for all inputs that are not constant. + // Constants were handled in Prepare() already. + for (auto tensor_index : op_data->subgraph_inputs) { + TfLiteTensor* tensor = &context->tensors[tensor_index]; + if (!IsConstantTensor(tensor)) { + buffer_map->SetFromTfLite(tensor_index, tensor); + } + } + + // Execute the TensorFlow Ops sequentially. + for (const auto& node_data : op_data->nodes) { + if (node_data.nodedef.op().empty()) { + context->ReportError(context, "Invalid NodeDef in Eager op '%s'", + node_data.name.c_str()); + return kTfLiteError; + } + auto status = + ExecuteEagerOp(eager_context, buffer_map, node_data.name, + node_data.nodedef, node_data.inputs, node_data.outputs); + TF_LITE_ENSURE_OK(context, ConvertStatus(context, status)); + } + + for (auto tensor_index : op_data->subgraph_outputs) { + if (!buffer_map->HasTensor(tensor_index)) { + context->ReportError(context, "Cannot write to invalid tensor index %d", + tensor_index); + return kTfLiteError; + } + + TfLiteTensor* tensor = &context->tensors[tensor_index]; + TF_LITE_ENSURE_OK( + context, + CopyShape(context, buffer_map->GetTensor(tensor_index), tensor)); + tensor->buffer_handle = tensor_index; + tensor->data_is_stale = true; + } + + return kTfLiteOk; +} + +} // namespace kernel + +TfLiteRegistration GetKernel() { + TfLiteRegistration registration{&kernel::Init, &kernel::Free, + &kernel::Prepare, &kernel::Eval, + nullptr, kTfLiteBuiltinDelegate}; + return registration; +} + +} // namespace eager +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.h b/tensorflow/contrib/lite/delegates/eager/kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..100672c82dcd3eaee17325f3b712140b081e8efe --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/kernel.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_ + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { +namespace eager { + +// Return the registration object used to initialize and execute ops that will +// be delegated to TensorFlow's Eager runtime. This TF Lite op is created by +// the eager delegate to handle execution of a supported subgraph. The usual +// flow is that the delegate informs the interpreter of supported nodes in a +// graph, and each supported subgraph is replaced with one instance of this +// kernel. +TfLiteRegistration GetKernel(); + +} // namespace eager +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/kernel_test.cc b/tensorflow/contrib/lite/delegates/eager/kernel_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..66f2226626677fa26a8c0eb2ae8ef448ed35c141 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/kernel_test.cc @@ -0,0 +1,230 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/kernel.h" + +#include +#include +#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" +#include "tensorflow/contrib/lite/delegates/eager/test_util.h" + +namespace tflite { +namespace eager { +namespace { + +using ::testing::ContainsRegex; +using ::testing::ElementsAre; + +TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate, + const std::vector& supported_nodes) { + TfLiteIntArray* size_and_nodes = + ConvertVectorToTfLiteIntArray(supported_nodes); + TF_LITE_ENSURE_STATUS(context->ReplaceSubgraphsWithDelegateKernels( + context, eager::GetKernel(), size_and_nodes, delegate)); + TfLiteIntArrayFree(size_and_nodes); + return kTfLiteOk; +} + +class KernelTest : public testing::EagerModelTest { + public: + KernelTest() { + CHECK(DelegateData::Create(&delegate_data_).ok()); + interpreter_.reset(new Interpreter(&error_reporter_)); + } + + ~KernelTest() override { + // The data needs to be released before the interpreter because the + // interpreter references the data. + delegate_data_.reset(); + interpreter_.reset(); + } + + template + void ConfigureDelegate(T prepare_function) { + delegate_.data_ = delegate_data_.get(); + delegate_.FreeBufferHandle = nullptr; + delegate_.Prepare = prepare_function; + delegate_.CopyFromBufferHandle = [](TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + void* data, size_t size) { + auto* delegate_data = reinterpret_cast(delegate->data_); + tensorflow::StringPiece values = delegate_data->GetBufferMap(context) + ->GetTensor(buffer_handle) + .tensor_data(); + memcpy(data, values.data(), values.size()); + return kTfLiteOk; + }; + CHECK(interpreter_->ModifyGraphWithDelegate( + &delegate_, /*allow_dynamic_tensors=*/true) == kTfLiteOk); + } + + private: + std::unique_ptr delegate_data_; + TfLiteDelegate delegate_; +}; + +TEST_F(KernelTest, FullGraph) { + // Define the graph. + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfOp(testing::kMul, {6, 7}, {8}); + + // Apply Delegate. + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0, 1, 2, 3, 4}); + }); + + // Define inputs. + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); + ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); +} + +TEST_F(KernelTest, BadTensorFlowOp) { + AddTensors(2, {0}, {1}, kTfLiteFloat32, {3}); + AddTfOp(testing::kNonExistent, {0}, {1}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_FALSE(Invoke()); + ASSERT_THAT(error_reporter().error_messages(), + ContainsRegex("while processing attributes of 'NonExistentOp'")); +} + +TEST_F(KernelTest, BadNumberOfOutputs) { + AddTensors(3, {0}, {1, 2}, kTfLiteFloat32, {3}); + AddTfOp(testing::kIdentity, {0}, {1, 2}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_FALSE(Invoke()); + ASSERT_THAT(error_reporter().error_messages(), + ContainsRegex("Unexpected number of outputs")); +} + +TEST_F(KernelTest, IncompatibleNodeDef) { + AddTensors(2, {0}, {1}, kTfLiteFloat32, {3}); + + // Cast is a TF op, but we don't add the proper nodedef to it in AddTfOp. + AddTfOp(testing::kIncompatibleNodeDef, {0}, {1}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_FALSE(Invoke()); + ASSERT_THAT(error_reporter().error_messages(), + ContainsRegex("while executing 'Cast' via Eager")); +} + +TEST_F(KernelTest, WrongSetOfNodes) { + AddTensors(4, {0}, {3}, kTfLiteFloat32, {3}); + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfLiteMulOp({1, 2}, {3}); + + // Specify that testing::kMul (#1) is supported when it actually isn't. + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0, 1}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_FALSE(Invoke()); + ASSERT_THAT(error_reporter().error_messages(), + ContainsRegex("Invalid NodeDef in Eager op")); +} + +TEST_F(KernelTest, MixedGraph) { + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfLiteMulOp({6, 7}, {8}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0, 1, 2, 3}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); + ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); +} + +TEST_F(KernelTest, SplitGraph) { + AddTensors(10, {0}, {9}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kAdd, {1, 2}, {3}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + + AddTfLiteMulOp({4, 5}, {6}); + + AddTfOp(testing::kUnpack, {6}, {7, 8}); + AddTfOp(testing::kAdd, {7, 8}, {9}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0, 1, 2, 4, 5}); + }); + + SetShape(0, {2, 2, 2, 1}); + SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(9), ElementsAre(1)); + ASSERT_THAT(GetValues(9), ElementsAre(10.0f)); +} + +} // namespace +} // namespace eager +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.cc b/tensorflow/contrib/lite/delegates/eager/test_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..26d96acc82064ba1046555940e1b1132874ef23e --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/test_util.cc @@ -0,0 +1,155 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/delegates/eager/test_util.h" + +#include "absl/memory/memory.h" +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/string.h" + +namespace tflite { +namespace eager { +namespace testing { + +bool EagerModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; } + +void EagerModelTest::SetValues(int tensor_index, + const std::vector& values) { + float* v = interpreter_->typed_tensor(tensor_index); + for (float f : values) { + *v++ = f; + } +} + +std::vector EagerModelTest::GetValues(int tensor_index) { + TfLiteTensor* o = interpreter_->tensor(tensor_index); + return std::vector(o->data.f, o->data.f + o->bytes / sizeof(float)); +} + +void EagerModelTest::SetShape(int tensor_index, + const std::vector& values) { + ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); +} + +std::vector EagerModelTest::GetShape(int tensor_index) { + std::vector result; + auto* dims = interpreter_->tensor(tensor_index)->dims; + result.reserve(dims->size); + for (int i = 0; i < dims->size; ++i) { + result.push_back(dims->data[i]); + } + return result; +} + +void EagerModelTest::AddTensors(int num_tensors, const std::vector& inputs, + const std::vector& outputs, + const TfLiteType& type, + const std::vector& dims) { + interpreter_->AddTensors(num_tensors); + for (int i = 0; i < num_tensors; ++i) { + TfLiteQuantizationParams quant; + CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type, + /*name=*/"", + /*dims=*/dims, quant), + kTfLiteOk); + } + + CHECK_EQ(interpreter_->SetInputs(inputs), kTfLiteOk); + CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk); +} + +void EagerModelTest::AddTfLiteMulOp(const std::vector& inputs, + const std::vector& outputs) { + static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + reg.builtin_code = BuiltinOperator_MUL; + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + auto* i0 = &context->tensors[node->inputs->data[0]]; + auto* o = &context->tensors[node->outputs->data[0]]; + return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims)); + }; + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + auto* i0 = &context->tensors[node->inputs->data[0]]; + auto* i1 = &context->tensors[node->inputs->data[1]]; + auto* o = &context->tensors[node->outputs->data[0]]; + for (int i = 0; i < o->bytes / sizeof(float); ++i) { + o->data.f[i] = i0->data.f[i] * i1->data.f[i]; + } + return kTfLiteOk; + }; + + CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0, + nullptr, ®), + kTfLiteOk); +} + +void EagerModelTest::AddTfOp(TfOpType op, const std::vector& inputs, + const std::vector& outputs) { + auto attr = [](const string& key, const string& value) { + return " attr{ key: '" + key + "' value {" + value + "}}"; + }; + + if (op == kUnpack) { + string attributes = attr("T", "type: DT_FLOAT") + attr("num", "i: 2") + + attr("axis", "i: 0"); + AddTfOp("EagerUnpack", "Unpack", attributes, inputs, outputs); + } else if (op == kIdentity) { + string attributes = attr("T", "type: DT_FLOAT"); + AddTfOp("EagerIdentity", "Identity", attributes, inputs, outputs); + } else if (op == kAdd) { + string attributes = attr("T", "type: DT_FLOAT"); + AddTfOp("EagerAdd", "Add", attributes, inputs, outputs); + } else if (op == kMul) { + string attributes = attr("T", "type: DT_FLOAT"); + AddTfOp("EagerMul", "Mul", attributes, inputs, outputs); + } else if (op == kNonExistent) { + AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs); + } else if (op == kIncompatibleNodeDef) { + // "Cast" op is created without attributes - making it incompatible. + AddTfOp("EagerCast", "Cast", "", inputs, outputs); + } +} + +void EagerModelTest::AddTfOp(const char* tflite_name, const string& tf_name, + const string& nodedef_str, + const std::vector& inputs, + const std::vector& outputs) { + static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + reg.builtin_code = BuiltinOperator_CUSTOM; + reg.custom_name = tflite_name; + + tensorflow::NodeDef nodedef; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString( + nodedef_str + " op: '" + tf_name + "'", &nodedef)); + string serialized_nodedef; + CHECK(nodedef.SerializeToString(&serialized_nodedef)); + flexbuffers::Builder fbb; + fbb.Vector([&]() { + fbb.String(nodedef.op()); + fbb.String(serialized_nodedef); + }); + fbb.Finish(); + + flexbuffers_.push_back(fbb.GetBuffer()); + auto& buffer = flexbuffers_.back(); + CHECK_EQ(interpreter_->AddNodeWithParameters( + inputs, outputs, reinterpret_cast(buffer.data()), + buffer.size(), nullptr, ®), + kTfLiteOk); +} + +} // namespace testing +} // namespace eager +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.h b/tensorflow/contrib/lite/delegates/eager/test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..0eab9e1135f02b4f22a4b36a85cf6771fbbb81d5 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/test_util.h @@ -0,0 +1,97 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_ + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" + +namespace tflite { +namespace eager { +namespace testing { + +enum TfOpType { + kUnpack, + kIdentity, + kAdd, + kMul, + // Represents an op that does not exist in TensorFlow. + kNonExistent, + // Represents an valid TensorFlow op where the NodeDef is incompatible. + kIncompatibleNodeDef, +}; + +// This class creates models with TF and TFLite ops. In order to use this class +// to test the Eager delegate, implement a function that calls +// interpreter->ModifyGraphWithDelegate. +class EagerModelTest : public ::testing::Test { + public: + EagerModelTest() {} + ~EagerModelTest() {} + + bool Invoke(); + + // Sets the tensor's values at the given index. + void SetValues(int tensor_index, const std::vector& values); + + // Returns the tensor's values at the given index. + std::vector GetValues(int tensor_index); + + // Sets the tensor's shape at the given index. + void SetShape(int tensor_index, const std::vector& values); + + // Returns the tensor's shape at the given index. + std::vector GetShape(int tensor_index); + + const TestErrorReporter& error_reporter() const { return error_reporter_; } + + // Adds `num_tensor` tensors to the model. `inputs` contains the indices of + // the input tensors and `outputs` contains the indices of the output + // tensors. All tensors are set to have `type` and `dims`. + void AddTensors(int num_tensors, const std::vector& inputs, + const std::vector& outputs, const TfLiteType& type, + const std::vector& dims); + + // Adds a TFLite Mul op. `inputs` contains the indices of the input tensors + // and `outputs` contains the indices of the output tensors. + void AddTfLiteMulOp(const std::vector& inputs, + const std::vector& outputs); + + // Adds a TensorFlow op. `inputs` contains the indices of the + // input tensors and `outputs` contains the indices of the output tensors. + // This function is limited to the set of ops defined in TfOpType. + void AddTfOp(TfOpType op, const std::vector& inputs, + const std::vector& outputs); + + protected: + std::unique_ptr interpreter_; + TestErrorReporter error_reporter_; + + private: + // Helper method to add a TensorFlow op. tflite_names needs to start with + // "Eager" in order to work with the Eager delegate. + void AddTfOp(const char* tflite_name, const string& tf_name, + const string& nodedef_str, const std::vector& inputs, + const std::vector& outputs); + + std::vector> flexbuffers_; +}; + +} // namespace testing +} // namespace eager +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/eager/util.cc new file mode 100644 index 0000000000000000000000000000000000000000..c8aa0b7f69f8f6bd3bff52b13f3cc7d689a514da --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/util.cc @@ -0,0 +1,78 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/util.h" +#include "tensorflow/contrib/lite/delegates/eager/constants.h" + +namespace tflite { +namespace eager { + +bool IsEagerOp(const char* custom_name) { + return custom_name && strncmp(custom_name, kCustomCodePrefix, + strlen(kCustomCodePrefix)) == 0; +} + +TfLiteStatus ConvertStatus(TfLiteContext* context, + const tensorflow::Status& status) { + if (!status.ok()) { + context->ReportError(context, "%s", status.error_message().c_str()); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src, + TfLiteTensor* tensor) { + int num_dims = src.dims(); + TfLiteIntArray* shape = TfLiteIntArrayCreate(num_dims); + for (int j = 0; j < num_dims; ++j) { + // We need to cast from TensorFlow's int64 to TF Lite's int32. Let's + // make sure there's no overflow. + if (src.dim_size(j) >= std::numeric_limits::max()) { + context->ReportError(context, + "Dimension value in TensorFlow shape is larger than " + "supported by TF Lite"); + TfLiteIntArrayFree(shape); + return kTfLiteError; + } + shape->data[j] = static_cast(src.dim_size(j)); + } + return context->ResizeTensor(context, tensor, shape); +} + +TF_DataType GetTensorFlowDataType(TfLiteType type) { + switch (type) { + case kTfLiteNoType: + return TF_FLOAT; + case kTfLiteFloat32: + return TF_FLOAT; + case kTfLiteInt16: + return TF_INT16; + case kTfLiteInt32: + return TF_INT32; + case kTfLiteUInt8: + return TF_UINT8; + case kTfLiteInt64: + return TF_INT64; + case kTfLiteComplex64: + return TF_COMPLEX64; + case kTfLiteString: + return TF_STRING; + case kTfLiteBool: + return TF_BOOL; + } +} + +} // namespace eager +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h new file mode 100644 index 0000000000000000000000000000000000000000..b7363361bec47f30e0741e3a76a5a375d7d9aeb1 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/util.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_ + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tflite { +namespace eager { + +// Checks whether the prefix of the custom name indicates the operation is an +// Eager operation. +bool IsEagerOp(const char* custom_name); + +// Converts a tensorflow:Status into a TfLiteStatus. If the original status +// represented an error, reports it using the given 'context'. +TfLiteStatus ConvertStatus(TfLiteContext* context, + const tensorflow::Status& status); + +// Copies the given shape of the given 'src' into a TF Lite 'tensor'. Logs an +// error and returns kTfLiteError if the shape can't be converted. +TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src, + TfLiteTensor* tensor); + +// Returns the TF C API Data type that corresponds to the given TfLiteType. +TF_DataType GetTensorFlowDataType(TfLiteType type); + +} // namespace eager +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/eager/util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..541d0b170197f7ac657cccfb79769522887e87e5 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/util_test.cc @@ -0,0 +1,124 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/util.h" + +#include + +#include +#include +#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +namespace eager { +namespace { + +using tensorflow::DT_FLOAT; +using tensorflow::Tensor; +using ::testing::ElementsAre; + +struct TestContext : public TfLiteContext { + string error; + std::vector new_size; +}; + +void ReportError(TfLiteContext* context, const char* format, ...) { + TestContext* c = static_cast(context); + const size_t kBufferSize = 1024; + char temp_buffer[kBufferSize]; + + va_list args; + va_start(args, format); + vsnprintf(temp_buffer, kBufferSize, format, args); + va_end(args); + + c->error = temp_buffer; +} + +TfLiteStatus ResizeTensor(TfLiteContext* context, TfLiteTensor* tensor, + TfLiteIntArray* new_size) { + TestContext* c = static_cast(context); + c->new_size.clear(); + for (int i = 0; i < new_size->size; ++i) { + c->new_size.push_back(new_size->data[i]); + } + TfLiteIntArrayFree(new_size); + return kTfLiteOk; +} + +TEST(UtilTest, ConvertStatus) { + TestContext context; + context.ReportError = ReportError; + + EXPECT_EQ(ConvertStatus(&context, tensorflow::errors::Internal("Some Error")), + kTfLiteError); + EXPECT_EQ(context.error, "Some Error"); + + context.error.clear(); + EXPECT_EQ(ConvertStatus(&context, tensorflow::Status()), kTfLiteOk); + EXPECT_TRUE(context.error.empty()); +} + +TEST(UtilTest, CopyShape) { + TestContext context; + context.ReportError = ReportError; + context.ResizeTensor = ResizeTensor; + + TfLiteTensor dst; + + EXPECT_EQ(CopyShape(&context, Tensor(), &dst), kTfLiteOk); + EXPECT_THAT(context.new_size, ElementsAre(0)); + + EXPECT_EQ(CopyShape(&context, Tensor(DT_FLOAT, {1, 2}), &dst), kTfLiteOk); + EXPECT_THAT(context.new_size, ElementsAre(1, 2)); + + EXPECT_EQ(CopyShape(&context, Tensor(DT_FLOAT, {1LL << 44, 2}), &dst), + kTfLiteError); + EXPECT_EQ(context.error, + "Dimension value in TensorFlow shape is larger than supported by " + "TF Lite"); +} + +TEST(UtilTest, TypeConversions) { + EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteNoType)); + EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteFloat32)); + EXPECT_EQ(TF_INT16, GetTensorFlowDataType(kTfLiteInt16)); + EXPECT_EQ(TF_INT32, GetTensorFlowDataType(kTfLiteInt32)); + EXPECT_EQ(TF_UINT8, GetTensorFlowDataType(kTfLiteUInt8)); + EXPECT_EQ(TF_INT64, GetTensorFlowDataType(kTfLiteInt64)); + EXPECT_EQ(TF_COMPLEX64, GetTensorFlowDataType(kTfLiteComplex64)); + EXPECT_EQ(TF_STRING, GetTensorFlowDataType(kTfLiteString)); + EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool)); +} + +TEST(UtilTest, IsEagerOp) { + EXPECT_TRUE(IsEagerOp("Eager")); + EXPECT_TRUE(IsEagerOp("EagerOp")); + EXPECT_FALSE(IsEagerOp("eager")); + EXPECT_FALSE(IsEagerOp("Eage")); + EXPECT_FALSE(IsEagerOp("OpEager")); + EXPECT_FALSE(IsEagerOp(nullptr)); + EXPECT_FALSE(IsEagerOp("")); +} + +} // namespace +} // namespace eager +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/contrib/lite/delegates/nnapi/BUILD index 35a8f6ca4166e373ea1a0af5d4a013327b30d2b6..954955f24b87f79a8dbe2863f608d532e25902c6 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/BUILD +++ b/tensorflow/contrib/lite/delegates/nnapi/BUILD @@ -22,6 +22,10 @@ tf_cc_test( name = "nnapi_delegate_test", size = "small", srcs = ["nnapi_delegate_test.cc"], + tags = [ + "no_oss", + "noasan", # TODO(b/112326936): re-enable for asan once fixed. + ], deps = [ ":nnapi_delegate", "//tensorflow/contrib/lite:framework", diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc index 0731d14419d2dec2ea5efa48ef5d4b7728af635f..e6cc3dd99c2e18bf297f8fac244e5d809954a01a 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc @@ -26,6 +26,12 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" +#ifdef __ANDROID__ +#include +#include +#include +#endif + namespace tflite { namespace { @@ -37,6 +43,32 @@ namespace { return kTfLiteError; \ } +namespace { +int32_t GetAndroidSdkVersion() { +#ifdef __ANDROID__ + const char* sdkProp = "ro.build.version.sdk"; + char sdkVersion[PROP_VALUE_MAX]; + int length = __system_property_get(sdkProp, sdkVersion); + if (length != 0) { + for (int i = 0; i < length; ++i) { + int digit = sdkVersion[i] - '0'; + if (digit < 0 || digit > 9) { + // Non-numeric SDK version, assume it's higher then expected; + return std::numeric_limits::max(); + } + } + return atoi(sdkVersion); + } +#endif // __ANDROID__ + return 0; +} + +constexpr int32_t kMinSdkVersionForNNAPI = 27; +constexpr int32_t kMinSdkVersionForNNAPI11 = 28; +static const int32_t kAndroidSdkVersion = GetAndroidSdkVersion(); + +} // namespace + // RAII NN API Model Destructor for use with std::unique_ptr struct NNFreeModel { void operator()(ANeuralNetworksModel* model) { @@ -50,6 +82,44 @@ struct NNFreeCompilation { } }; +// Manage NNAPI shared memory handle +class NNMemory { + public: + NNMemory(const char* name, size_t size) { +#ifdef __ANDROID__ + byte_size_ = size; + fd_ = ASharedMemory_create(name, size); + data_ptr_ = reinterpret_cast( + mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd_, 0)); + ANeuralNetworksMemory_createFromFd(size, PROT_READ | PROT_WRITE, fd_, 0, + &nn_memory_handle_); +#endif + } + + ~NNMemory() { +#ifdef __ANDROID__ + if (data_ptr_) { + munmap(data_ptr_, byte_size_); + } + if (nn_memory_handle_) { + ANeuralNetworksMemory_free(nn_memory_handle_); + } + if (fd_ > 0) close(fd_); +#endif + } + + ANeuralNetworksMemory* get_handle() { return nn_memory_handle_; } + uint8_t* get_data_ptr() { return data_ptr_; } + + private: +#ifdef __ANDROID__ + int fd_ = 0; + size_t byte_size_ = 0; +#endif + uint8_t* data_ptr_ = nullptr; + ANeuralNetworksMemory* nn_memory_handle_ = nullptr; +}; // namespace + // Track tensor indices to NN API tensor indices mapping. class OperandMapping { public: @@ -71,7 +141,7 @@ class OperandMapping { // Add a new mapping from `tflite_index` and return the NN API tensor index. int add_new_ann_tensor_index(int tflite_index) { if (tflite_index >= lite_tensor_to_ann_tensor_.size()) { - lite_tensor_to_ann_tensor_.resize(tflite_index + 1); + lite_tensor_to_ann_tensor_.resize(tflite_index + 1, -1); } int new_tensor_index = next_ann_tensor_index_++; lite_tensor_to_ann_tensor_[tflite_index] = new_tensor_index; @@ -98,14 +168,34 @@ class NNAPIOpBuilder { operand_mapping_(tensor_mapping), nn_model_(nn_model) {} - TfLiteStatus AddScalarInt32Operand(int value) { - ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_INT32}; - CHECK_NN(context_, - ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); - int ann_operand = operand_mapping_->add_new_non_tensor_operand(); - CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( - nn_model_, ann_operand, &value, sizeof(int32_t))); - augmented_inputs_.push_back(ann_operand); + TfLiteStatus AddScalarInt32Operand(int32_t value) { + return AddScalarOperand(value, ANEURALNETWORKS_INT32); + } + + TfLiteStatus AddScalarFloat32Operand(float value) { + return AddScalarOperand(value, ANEURALNETWORKS_FLOAT32); + } + + TfLiteStatus AddVectorInt32Operand(const int32_t* values, + uint32_t num_values) { + return AddVectorOperand(values, num_values, + ANEURALNETWORKS_TENSOR_INT32); + } + + TfLiteStatus AddVectorFloat32Operand(const float* values, + uint32_t num_values) { + return AddVectorOperand(values, num_values, + ANEURALNETWORKS_TENSOR_FLOAT32); + } + + TfLiteStatus AddPoolingParams(void* data) { + auto builtin = reinterpret_cast(data); + AddScalarInt32Operand(builtin->padding); + AddScalarInt32Operand(builtin->stride_width); + AddScalarInt32Operand(builtin->stride_height); + AddScalarInt32Operand(builtin->filter_width); + AddScalarInt32Operand(builtin->filter_height); + AddScalarInt32Operand(builtin->activation); return kTfLiteOk; } @@ -123,6 +213,37 @@ class NNAPIOpBuilder { return kTfLiteOk; } + TfLiteStatus AddAdditionalFloat32OutputTensor(uint32_t dimension_count) { + std::vector dims(dimension_count, 0); + ANeuralNetworksOperandType operand_type{ + .type = ANEURALNETWORKS_TENSOR_FLOAT32, + .dimensionCount = dimension_count, + .dimensions = dims.data()}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + int ann_operand = operand_mapping_->add_new_non_tensor_operand(); + augmented_outputs_.push_back(ann_operand); + return kTfLiteOk; + } + + TfLiteStatus AddStateFloat32Tensor(int tensor_index, + int* ann_tensor_index_out) { + TfLiteTensor* tensor = &context_->tensors[tensor_index]; + int ann_index = operand_mapping_->add_new_non_tensor_operand(); + + ANeuralNetworksOperandType operand_type{ + ANEURALNETWORKS_TENSOR_FLOAT32, + static_cast(tensor->dims->size), + reinterpret_cast(tensor->dims->data), tensor->params.scale, + tensor->params.zero_point}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + augmented_inputs_.push_back(ann_index); + + *ann_tensor_index_out = ann_index; + return kTfLiteOk; + } + // Adds a new NN API tensor that shadows the TF Lite tensor `tensor_index`. // This returns the NN API tensor index corresponding to the created tensor. // If another caller previously created a NN API tensor for `tensor_index` @@ -149,17 +270,20 @@ class NNAPIOpBuilder { return kTfLiteOk; case kTfLiteFloat32: nn_type = ANEURALNETWORKS_TENSOR_FLOAT32; - scale = 0.f; break; case kTfLiteUInt8: nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM; scale = tensor->params.scale; zeroPoint = tensor->params.zero_point; + if (scale == 0) { + // TENSOR_QUANT8_ASYMM with zero scale is not valid in NNAPI. + scale = 1; + } break; case kTfLiteInt32: nn_type = ANEURALNETWORKS_TENSOR_INT32; - scale = 0.f; - zeroPoint = 0; + scale = tensor->params.scale; + zeroPoint = tensor->params.zero_point; break; default: context_->ReportError(context_, "Logic error in NN API Delegate.\n"); @@ -192,12 +316,39 @@ class NNAPIOpBuilder { augmented_inputs_.data(), static_cast(augmented_outputs_.size()), augmented_outputs_.data())); - augmented_outputs_.clear(); + augmented_inputs_.clear(); augmented_outputs_.clear(); return kTfLiteOk; } private: + template + TfLiteStatus AddScalarOperand(T value, int32_t nn_type) { + ANeuralNetworksOperandType operand_type{.type = nn_type}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + int ann_operand = operand_mapping_->add_new_non_tensor_operand(); + CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( + nn_model_, ann_operand, &value, sizeof(T))); + augmented_inputs_.push_back(ann_operand); + return kTfLiteOk; + } + + template + TfLiteStatus AddVectorOperand(const T* values, uint32_t num_values, + int32_t nn_type) { + ANeuralNetworksOperandType operand_type{ + .type = nn_type, .dimensionCount = 1, .dimensions = &num_values}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + int ann_operand = operand_mapping_->add_new_non_tensor_operand(); + CHECK_NN(context_, + ANeuralNetworksModel_setOperandValue( + nn_model_, ann_operand, values, sizeof(T) * num_values)); + augmented_inputs_.push_back(ann_operand); + return kTfLiteOk; + } + // TfLiteContext for error handling. Must be named context for macros to // work. TfLiteContext* context_; @@ -215,41 +366,550 @@ class NNAPIOpBuilder { std::vector augmented_outputs_; }; +struct NNAPIOpMappingArgs { + TfLiteContext* context; + NNAPIOpBuilder* builder; + TfLiteNode* node; + std::vector* model_state_inputs; + std::vector* model_state_tfl_outputs; +}; + // The kernel that represents the subgraph of TF Lite being run on NN API. class NNAPIDelegateKernel { public: NNAPIDelegateKernel() = default; - typedef ANeuralNetworksOperationType (*MappingFn)(TfLiteContext*, - NNAPIOpBuilder* builder, - TfLiteNode* node); + typedef ANeuralNetworksOperationType (*MappingFn)( + const NNAPIOpMappingArgs& mapping_args); // Return a function that knows how to translate a node into its operands // when called. You can use this function to see if a node is supported // (i.e. that MappingFn is not nullptr). - MappingFn Map(TfLiteContext* context, int builtin_code, TfLiteNode* node) { + MappingFn Map(TfLiteContext* context, int builtin_code, int version, + TfLiteNode* node) { switch (builtin_code) { case kTfLiteBuiltinAdd: - return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { - auto builtin = reinterpret_cast(node->builtin_data); - builder->AddScalarInt32Operand(builtin->activation); - return ANEURALNETWORKS_ADD; - }; + if (version == 1) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + mapping_args.builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_ADD; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinMul: + if (version == 1) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + mapping_args.builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_MUL; + }; + } else { + return nullptr; + } break; case kTfLiteBuiltinAveragePool2d: - return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + if (version == 1) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + mapping_args.builder->AddPoolingParams( + mapping_args.node->builtin_data); + return ANEURALNETWORKS_AVERAGE_POOL_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinMaxPool2d: + if (version == 1) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + mapping_args.builder->AddPoolingParams( + mapping_args.node->builtin_data); + return ANEURALNETWORKS_MAX_POOL_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinL2Pool2d: + if (version == 1) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + mapping_args.builder->AddPoolingParams( + mapping_args.node->builtin_data); + return ANEURALNETWORKS_L2_POOL_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinConv2d: + if (version == 1) { auto builtin = - reinterpret_cast(node->builtin_data); - builder->AddScalarInt32Operand(builtin->padding); - builder->AddScalarInt32Operand(builtin->stride_width); - builder->AddScalarInt32Operand(builtin->stride_height); - builder->AddScalarInt32Operand(builtin->filter_width); - builder->AddScalarInt32Operand(builtin->filter_height); - builder->AddScalarInt32Operand(builtin->activation); - return ANEURALNETWORKS_AVERAGE_POOL_2D; + reinterpret_cast(node->builtin_data); + if (builtin->dilation_width_factor != 1 || + builtin->dilation_height_factor != 1 || node->inputs->size != 3) { + // NNAPI does not support dilated Conv2D. + return nullptr; + } + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + mapping_args.builder->AddScalarInt32Operand(builtin->padding); + mapping_args.builder->AddScalarInt32Operand(builtin->stride_width); + mapping_args.builder->AddScalarInt32Operand(builtin->stride_height); + mapping_args.builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_CONV_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinDepthwiseConv2d: + if (version == 1) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + mapping_args.builder->AddScalarInt32Operand(builtin->padding); + mapping_args.builder->AddScalarInt32Operand(builtin->stride_width); + mapping_args.builder->AddScalarInt32Operand(builtin->stride_height); + mapping_args.builder->AddScalarInt32Operand( + builtin->depth_multiplier); + mapping_args.builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_DEPTHWISE_CONV_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinFullyConnected: + if (version == 1) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + mapping_args.builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_FULLY_CONNECTED; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinSoftmax: + if (version == 1) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + mapping_args.builder->AddScalarFloat32Operand(builtin->beta); + return ANEURALNETWORKS_SOFTMAX; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinReshape: + if (version == 1) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_RESHAPE; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinSqueeze: + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + // Note that we add the squeeze dimensions even if the dimensions + // were unspecified (empty), as NNAPI requires the operand. + mapping_args.builder->AddVectorInt32Operand( + builtin->squeeze_dims, + static_cast(builtin->num_squeeze_dims)); + return ANEURALNETWORKS_SQUEEZE; + }; + } else { + return nullptr; + } + case kTfLiteBuiltinL2Normalization: { + auto builtin = + reinterpret_cast(node->builtin_data); + if (builtin->activation != kTfLiteActNone) { + // NNAPI does not support activations + return nullptr; + } + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_L2_NORMALIZATION; }; + } + case kTfLiteBuiltinLocalResponseNormalization: + if (version == 1) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + mapping_args.builder->AddScalarInt32Operand(builtin->radius); + mapping_args.builder->AddScalarFloat32Operand(builtin->bias); + mapping_args.builder->AddScalarFloat32Operand(builtin->alpha); + mapping_args.builder->AddScalarFloat32Operand(builtin->beta); + return ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION; + }; + } else { + // TODO(miaowang): clean-up code and return early in the unsupported + // case. + return nullptr; + } + break; + case kTfLiteBuiltinLshProjection: + if (version == 1) { + // NNAPI does not support sparse projection correctly (b/111751836). + if (reinterpret_cast(node->builtin_data) + ->type == kTfLiteLshProjectionSparse) { + return nullptr; + } + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + mapping_args.builder->AddScalarInt32Operand(builtin->type); + return ANEURALNETWORKS_LSH_PROJECTION; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinConcatenation: + if (version == 1 && + reinterpret_cast(node->builtin_data) + ->activation == kTfLiteActNone) { + if (context->tensors[node->inputs->data[0]].type == kTfLiteUInt8) { + // NNAPI only support concatenating quantized tensor of the same + // scale and offset. + auto first_param = context->tensors[node->inputs->data[0]].params; + for (int i = 0; i < node->inputs->size; i++) { + auto curr_param = context->tensors[node->inputs->data[i]].params; + if (curr_param.scale != first_param.scale || + curr_param.zero_point != first_param.zero_point) { + return nullptr; + } + } + } + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + mapping_args.builder->AddScalarInt32Operand(builtin->axis); + return ANEURALNETWORKS_CONCATENATION; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinDequantize: + if (version == 1) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_DEQUANTIZE; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinFloor: + if (version == 1) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_FLOOR; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinRelu: + if (version == 1) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_RELU; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinReluN1To1: + if (version == 1) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_RELU1; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinRelu6: + if (version == 1) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_RELU6; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinLogistic: + if (version == 1) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_LOGISTIC; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinTanh: + // TODO(miaowang): add additional checks for the parameters. + if (version == 1 && + context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) { + // NNAPI only support float tanh. + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_TANH; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinSub: + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11 && + context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) { + // NNAPI only support float sub. + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + mapping_args.builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_SUB; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinDiv: + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11 && + context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) { + // NNAPI only support float div. + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + mapping_args.builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_DIV; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinPad: + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11 && + node->inputs->size == 2 && + context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) { + // NNAPI does not support specifying the padding value. + // NNAPI pads physical zero for quantized tensors, so only delegate + // float pad to NNAPI. + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_PAD; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinSpaceToBatchNd: + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_SPACE_TO_BATCH_ND; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinStridedSlice: + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + mapping_args.builder->AddScalarInt32Operand(builtin->begin_mask); + mapping_args.builder->AddScalarInt32Operand(builtin->end_mask); + mapping_args.builder->AddScalarInt32Operand( + builtin->shrink_axis_mask); + return ANEURALNETWORKS_STRIDED_SLICE; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinTranspose: + // Note that the permutation input tensor value dictates the output + // dimensions. + // TODO(b/110888333): Support dynamically-sized tensors in delegates. + if ((version == 1) && + (kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) && + (node->inputs->size > 1) && + (context->tensors[node->inputs->data[1]].allocation_type == + kTfLiteMmapRo)) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_TRANSPOSE; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinRnn: + // NNAPI only support float32 weights. + // TODO(miaowang): check the number of inputs before accessing it. + if (version == 1 && + context->tensors[node->inputs->data[/*kWeightsTensor*/ 1]].type == + kTfLiteFloat32) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + // NNAPI need both state_in and state_out. + int ann_index; + mapping_args.builder->AddStateFloat32Tensor( + mapping_args.node->outputs->data[/*kHiddenStateTensor*/ 0], + &ann_index); + mapping_args.model_state_inputs->push_back(ann_index); + mapping_args.model_state_tfl_outputs->push_back( + mapping_args.node->outputs->data[/*kHiddenStateTensor*/ 0]); + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + mapping_args.builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_RNN; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinSvdf: + // NNAPI only support float32 weights. + if (version == 1 && + context->tensors[node->inputs->data[/*kWeightsFeatureTensor*/ 1]] + .type == kTfLiteFloat32) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + // NNAPI need both state_in and state_out. + int ann_index; + mapping_args.builder->AddStateFloat32Tensor( + mapping_args.node->outputs->data[/*kStateTensor*/ 0], + &ann_index); + mapping_args.model_state_inputs->push_back(ann_index); + mapping_args.model_state_tfl_outputs->push_back( + mapping_args.node->outputs->data[/*kStateTensor*/ 0]); + + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + mapping_args.builder->AddScalarInt32Operand(builtin->rank); + mapping_args.builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_SVDF; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinLstm: + // NNAPI only support float32 weights. + // TODO(miaowang): add loggings to indicate why the op is rejected. + if (version == 1 && node->inputs->size == 18 && + context->tensors[node->inputs + ->data[/*kInputToOutputWeightsTensor*/ 4]] + .type == kTfLiteFloat32) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + // NNAPI need both state_in and state_out for cell_state and + // output_state. + int ann_index; + mapping_args.builder->AddStateFloat32Tensor( + mapping_args.node->outputs->data[/*kOutputStateTensor*/ 0], + &ann_index); + mapping_args.model_state_inputs->push_back(ann_index); + mapping_args.model_state_tfl_outputs->push_back( + mapping_args.node->outputs->data[/*kOutputStateTensor*/ 0]); + mapping_args.builder->AddStateFloat32Tensor( + mapping_args.node->outputs->data[/*kCellStateTensor*/ 1], + &ann_index); + mapping_args.model_state_inputs->push_back(ann_index); + mapping_args.model_state_tfl_outputs->push_back( + mapping_args.node->outputs->data[/*kCellStateTensor*/ 1]); + + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + mapping_args.builder->AddScalarInt32Operand(builtin->activation); + mapping_args.builder->AddScalarFloat32Operand(builtin->cell_clip); + mapping_args.builder->AddScalarFloat32Operand(builtin->proj_clip); + + // Current NNAPI implementation requires the sratch_buffer as + // output. + mapping_args.builder->AddAdditionalFloat32OutputTensor(2); + return ANEURALNETWORKS_LSTM; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinMean: + // NNAPI does not support generating a scalar as output for MEAN. + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11 && + context->tensors[node->inputs->data[0]].type == kTfLiteFloat32 && + context->tensors[node->outputs->data[0]].dims->size > 0) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + mapping_args.node->builtin_data); + int32_t keep_dims = 0; + if (builtin->keep_dims) keep_dims = 1; + mapping_args.builder->AddScalarInt32Operand(keep_dims); + return ANEURALNETWORKS_MEAN; + }; + } else { + return nullptr; + } + case kTfLiteBuiltinEmbeddingLookup: + // NNAPI only support float32 values. + if (version == 1 && + context->tensors[node->inputs->data[1]].type == kTfLiteFloat32) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_EMBEDDING_LOOKUP; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinHashtableLookup: + // NNAPI only support float32 output. + if (version == 1 && + context->tensors[node->outputs->data[0]].type == kTfLiteFloat32) { + return [](const NNAPIOpMappingArgs& mapping_args) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_HASHTABLE_LOOKUP; + }; + } else { + return nullptr; + } break; default: return nullptr; @@ -290,23 +950,56 @@ class NNAPIDelegateKernel { // Set the input tensor buffers. Note: we access tflite tensors using // absolute indices but NN api indices inputs by relative indices. int relative_input_index = 0; + int num_optional_tensors = 0; + + size_t input_offset = 0; for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) { + if (absolute_input_index == kOptionalTensor) { + num_optional_tensors++; + continue; + } TfLiteTensor* tensor = &context->tensors[absolute_input_index]; - CHECK_NN(context, ANeuralNetworksExecution_setInput( - execution, relative_input_index, nullptr, - tensor->data.raw, tensor->bytes)); - relative_input_index++; + // TODO(miaowang): make sure the delegation works with dequantized weights + // as intermediate tensors. + if (tensor->allocation_type != kTfLiteMmapRo) { + // copy data to pre-allocated shared memory. + memcpy(nn_input_memory_->get_data_ptr() + input_offset, + tensor->data.raw, tensor->bytes); + CHECK_NN(context, ANeuralNetworksExecution_setInputFromMemory( + execution, relative_input_index, nullptr, + nn_input_memory_->get_handle(), input_offset, + tensor->bytes)); + input_offset += tensor->bytes; + relative_input_index++; + } } // Set the output tensor buffers. int relative_output_index = 0; + size_t output_offset = 0; for (auto output_index : TfLiteIntArrayView(node->outputs)) { TfLiteTensor* tensor = &context->tensors[output_index]; - CHECK_NN(context, ANeuralNetworksExecution_setOutput( + CHECK_NN(context, ANeuralNetworksExecution_setOutputFromMemory( execution, relative_output_index, nullptr, - tensor->data.raw, tensor->bytes)); + nn_output_memory_->get_handle(), output_offset, + tensor->bytes)); + output_offset += tensor->bytes; relative_output_index++; } + + // The state_out of previous invocation need to be mapped to state_in of + // current invocation. + for (size_t i = 0; i < model_state_tfl_outputs_.size(); i++) { + int state_tensor_idx = model_state_tfl_outputs_[i]; + TfLiteTensor* tensor = &context->tensors[state_tensor_idx]; + // Here we are using a deep copy for state_in tensors so that we are not + // reading and writing into the same buffer during a invocation. + // TODO(110369471): using double shared buffer to minimize the copies. + CHECK_NN(context, + ANeuralNetworksExecution_setInput( + execution, i + node->inputs->size - num_optional_tensors, + nullptr, tensor->data.raw, tensor->bytes)); + } // Invoke ANN in blocking fashion. ANeuralNetworksEvent* event = nullptr; CHECK_NN(context, ANeuralNetworksExecution_startCompute(execution, &event)); @@ -314,6 +1007,15 @@ class NNAPIDelegateKernel { ANeuralNetworksEvent_free(event); ANeuralNetworksExecution_free(execution); + // copy results from shared memory to the destination. + output_offset = 0; + for (auto output_index : TfLiteIntArrayView(node->outputs)) { + TfLiteTensor* tensor = &context->tensors[output_index]; + memcpy(tensor->data.raw, + nn_output_memory_->get_data_ptr() + output_offset, tensor->bytes); + output_offset += tensor->bytes; + } + return kTfLiteOk; } @@ -328,6 +1030,12 @@ class NNAPIDelegateKernel { // Track indices we use OperandMapping operand_mapping_; + std::vector model_state_inputs_; + std::vector model_state_tfl_outputs_; + + std::unique_ptr nn_input_memory_; + std::unique_ptr nn_output_memory_; + TfLiteStatus AddOpsAndTensors(TfLiteContext* context) { // The operand builder allows creating a single op. We create it at this // reduced power position rather than in the for loop to avoid reallocating @@ -342,11 +1050,22 @@ class NNAPIDelegateKernel { context->GetNodeAndRegistration(context, node_index, &node, ®); // Map inputs to NN API tensor indices. for (auto input_index : TfLiteIntArrayView(node->inputs)) { - TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index)); + if (input_index == kOptionalTensor && + (reg->builtin_code == kTfLiteBuiltinLstm || + reg->builtin_code == kTfLiteBuiltinSvdf)) { + // properly handle the optional tensor for LSTM and SVDF. + // currently only support float32. + // TODO(miaowang): make sure this is also able to handle quantized + // tensor when supported by NNAPI. + TF_LITE_ENSURE_STATUS(builder.AddVectorFloat32Operand(nullptr, 0)); + } else { + TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index)); + } } // Get op type and operands - int nn_op_type = - Map(context, reg->builtin_code, node)(context, &builder, node); + int nn_op_type = Map(context, reg->builtin_code, reg->version, + node)({context, &builder, node, &model_state_inputs_, + &model_state_tfl_outputs_}); // Map outputs to NN API tensor indices. for (auto output_index : TfLiteIntArrayView(node->outputs)) { TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index)); @@ -367,11 +1086,29 @@ class NNAPIDelegateKernel { inputs.reserve(input_tensors->size); std::vector outputs; outputs.reserve(output_tensors->size); + + size_t total_input_byte_size = 0; // Make the TensorFlow lite inputs and outputs to ann_indices. - for (int i : TfLiteIntArrayView(input_tensors)) - inputs.push_back(operand_mapping_.lite_index_to_ann(i)); - for (int i : TfLiteIntArrayView(output_tensors)) + for (int i : TfLiteIntArrayView(input_tensors)) { + // Constant tensors are not NNAPI inputs. + if (i != kOptionalTensor && + context->tensors[i].allocation_type != kTfLiteMmapRo) { + inputs.push_back(operand_mapping_.lite_index_to_ann(i)); + total_input_byte_size += context->tensors[i].bytes; + } + } + + // Add state input tensors as model inputs + for (int i : model_state_inputs_) { + inputs.push_back(i); + } + + size_t total_output_byte_size = 0; + for (int i : TfLiteIntArrayView(output_tensors)) { outputs.push_back(operand_mapping_.lite_index_to_ann(i)); + total_output_byte_size += context->tensors[i].bytes; + } + // Tell ANN to declare inputs/outputs CHECK_NN(context, ANeuralNetworksModel_identifyInputsAndOutputs( nn_model_.get(), inputs.size(), inputs.data(), @@ -379,6 +1116,11 @@ class NNAPIDelegateKernel { // Finalize the model CHECK_NN(context, ANeuralNetworksModel_finish(nn_model_.get())); + // Create shared memory pool for inputs and outputs. + nn_input_memory_.reset(new NNMemory("input_pool", total_input_byte_size)); + nn_output_memory_.reset( + new NNMemory("output_pool", total_output_byte_size)); + return kTfLiteOk; } }; @@ -392,7 +1134,9 @@ TfLiteDelegate* NnApiDelegate() { .Prepare = [](TfLiteContext* context, TfLiteDelegate* delegate) -> TfLiteStatus { // Do not check nodes_ if NN API is unavailable. - if (!NNAPIExists()) return kTfLiteOk; + if (kAndroidSdkVersion < kMinSdkVersionForNNAPI || !NNAPIExists()) { + return kTfLiteOk; + } std::vector supported_nodes(1); // We don't care about all nodes_, we only care about ones in the @@ -400,6 +1144,7 @@ TfLiteDelegate* NnApiDelegate() { TfLiteIntArray* plan; TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); int total_supported_nodes = 0; + // Check for every node if it is supported // TODO(b/80625235): Fix this to do more careful checking of versioning. for (int node_index : TfLiteIntArrayView(plan)) { @@ -408,7 +1153,8 @@ TfLiteDelegate* NnApiDelegate() { TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( context, node_index, &node, ®istration)); NNAPIDelegateKernel dummy_kernel; - if (dummy_kernel.Map(context, registration->builtin_code, node)) { + if (dummy_kernel.Map(context, registration->builtin_code, + registration->version, node)) { supported_nodes.push_back(node_index); } total_supported_nodes += 1; diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc index ff2e721423f07889f36746a2889afcc3369f28fc..3224b23a0c3bc8456bd75f2923d16f0eed7d53ff 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -21,16 +21,26 @@ limitations under the License. namespace tflite { namespace { +using ::testing::ElementsAre; using ::testing::ElementsAreArray; -class FloatAddOpModel : public SingleOpModel { +// TODO(b/110368244): figure out how to share the existing tests in kernels/ but +// with the delegation on. Also, add more unit tests to improve code coverage. + +class SingleOpModelWithNNAPI : public SingleOpModel { + public: + SingleOpModelWithNNAPI() { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate(), false); + }); + } +}; + +class FloatAddOpModel : public SingleOpModelWithNNAPI { public: FloatAddOpModel(const TensorData& input1, const TensorData& input2, const TensorData& output, ActivationFunctionType activation_type) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); input1_ = AddInput(input1); input2_ = AddInput(input2); output_ = AddOutput(output); @@ -72,6 +82,3439 @@ TEST(NNAPIDelegate, AddWithRelu) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({0.0, 0.4, 1.0, 1.3})); } +class FloatMulOpModel : public SingleOpModelWithNNAPI { + public: + FloatMulOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions, + CreateMulOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +TEST(NNAPIDelegate, MulWithNoActivation) { + FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4}))); +} + +class FloatPoolingOpModel : public SingleOpModelWithNNAPI { + public: + FloatPoolingOpModel(BuiltinOperator type, const TensorData& input, + int filter_width, int filter_height, + const TensorData& output) { + input_ = AddInput(input); + output_ = AddOutput(output); + + SetBuiltinOp( + type, BuiltinOptions_Pool2DOptions, + CreatePool2DOptions(builder_, Padding_VALID, 2, 2, filter_width, + filter_height, ActivationFunctionType_NONE) + .Union()); + + BuildInterpreter({GetShape(input_)}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int output_; +}; + +TEST(NNAPIDelegate, AveragePoolWithNoActivation) { + FloatPoolingOpModel m(BuiltinOperator_AVERAGE_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2.75, 5.75})); +} + +TEST(NNAPIDelegate, MaxPoolWithNoActivation) { + FloatPoolingOpModel m(BuiltinOperator_MAX_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 10})); +} + +TEST(NNAPIDelegate, L2PoolWithNoActivation) { + FloatPoolingOpModel m(BuiltinOperator_L2_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5})); +} + +class BaseConvolutionOpModel : public SingleOpModelWithNNAPI { + public: + BaseConvolutionOpModel( + const TensorData& input, const TensorData& filter, + const TensorData& output, int stride_width = 2, int stride_height = 2, + enum Padding padding = Padding_VALID, + enum ActivationFunctionType activation = ActivationFunctionType_NONE, + int dilation_width_factor = 1, int dilation_height_factor = 1) { + input_ = AddInput(input); + filter_ = AddInput(filter); + + int bias_size = GetShape(filter_)[0]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + if (input.type != TensorType_FLOAT32) { + // The following is required by quantized inference. It is the unittest's + // responsibility to make sure the output scale falls into the correct + // range. + CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); + } + + SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions, + CreateConv2DOptions( + builder_, padding, stride_width, stride_height, activation, + dilation_width_factor, dilation_height_factor) + .Union()); + + BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); + } + + protected: + int input_; + int filter_; + int bias_; + int output_; +}; + +class ConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetFilter(std::initializer_list data) { + QuantizeAndPopulate(filter_, data); + } + + void SetBias(std::initializer_list data) { + QuantizeAndPopulate(bias_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// In this tests we set the input and output scales so that the results +// match exactly the 'non-quantized' version. +TEST(NNAPIDelegate, SimpleTestQuantized) { + QuantizedConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, + {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}); + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + }, + 1e-5))); + // For good measure, let's also verify the quantized values: + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 145, 129, 132, // + 145, 129, 132, // + 144, 131, 130, // + 164, 131, 130, // + })); +} + +TEST(NNAPIDelegate, Conv2DWithNoActivation) { + ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_FLOAT32, {3, 2, 2, 1}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + })); +} + +class DepthwiseConvolutionOpModel : public SingleOpModelWithNNAPI { + public: + DepthwiseConvolutionOpModel(const TensorData& input, const TensorData& filter, + const TensorData& output) { + input_ = AddInput(input); + filter_ = AddInput(filter); + + int bias_size = GetShape(filter_)[3]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + + int input_depth = GetShape(input_)[3]; + int output_depth = GetShape(filter_)[3]; + int depth_mul = output_depth / input_depth; + + SetBuiltinOp( + BuiltinOperator_DEPTHWISE_CONV_2D, + BuiltinOptions_DepthwiseConv2DOptions, + CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul, + ActivationFunctionType_NONE) + .Union()); + + BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); + } + + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int filter_; + int bias_; + int output_; +}; + +TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) { + DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}}, + {TensorType_FLOAT32, {1, 2, 2, 4}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + 1, 2, 7, 8, // column 1 + 3, 4, 9, 10, // column 2 + 5, 6, 11, 12, // column 3 + }); + m.SetFilter({ + 1, 2, 3, 4, // + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }); + m.SetBias({1, 2, 3, 4}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 71, -34, 99, -20, // + 91, -26, 127, -4, // + })); +} + +class FloatFullyConnectedOpModel : public SingleOpModelWithNNAPI { + public: + FloatFullyConnectedOpModel(int units, int batches, const TensorData& input, + const TensorData& output = {TensorType_FLOAT32}) + : batches_(batches), units_(units) { + int total_input_size = 1; + for (int i = 0; i < input.shape.size(); ++i) { + total_input_size *= input.shape[i]; + } + input_size_ = total_input_size / batches_; + + input_ = AddInput(input); + weights_ = + AddInput({input.type, {units_, input_size_}, input.min, input.max}); + + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {units_}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(weights_); + TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + + SetBuiltinOp( + BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions, + CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU) + .Union()); + BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)}); + } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int weights_; + int bias_; + int output_; + + int batches_; + int units_; + int input_size_; +}; + +TEST(NNAPIDelegate, FullyConnectedSimpleTest) { + FloatFullyConnectedOpModel m(/*units=*/3, /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {2, 10}}); + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60)); +} + +class SoftmaxOpModel : public SingleOpModelWithNNAPI { + public: + SoftmaxOpModel(int batches, int size, float beta) + : batches_(batches), input_size_(size), beta_(beta) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions, + CreateSoftmaxOptions(builder_, beta_).Union()); + BuildInterpreter({{batches_, input_size_}}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; + + int batches_; + int input_size_; + float beta_; +}; + +TEST(NNAPIDelegate, SoftmaxSimpleTest) { + SoftmaxOpModel m(/*batches=*/2, /*size=*/5, /*beta=*/1.0); + m.SetInput({ + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 0 + }); + + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647, + 0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231}, + 1e-6))); +} + +class ReshapeOpModel : public SingleOpModelWithNNAPI { + public: + ReshapeOpModel(std::initializer_list input_shape, + std::initializer_list new_shape) { + input_ = AddInput(TensorType_FLOAT32); + new_shape_ = AddInput(TensorType_INT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions, + CreateReshapeOptions(builder_, builder_.CreateVector(new_shape)) + .Union()); + BuildInterpreter({input_shape, {static_cast(new_shape.size())}}); + PopulateTensor(new_shape_, new_shape); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int new_shape_; + int output_; +}; + +TEST(NNAPIDelegate, ReshapeSimpleTest) { + ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); +} + +class SqueezeOpModel : public SingleOpModelWithNNAPI { + public: + SqueezeOpModel(const TensorData& input, const TensorData& output, + std::initializer_list axis) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp( + BuiltinOperator_SQUEEZE, BuiltinOptions_SqueezeOptions, + CreateSqueezeOptions(builder_, builder_.CreateVector(axis)) + .Union()); + BuildInterpreter({GetShape(input_)}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int new_shape_; + int output_; +}; + +TEST(NNAPIDelegate, SqueezeSimpleTest) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, {TensorType_FLOAT32, {24}}, + {}); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({24})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0})); +} + +TEST(NNAPIDelegate, SqueezeWithAxisTest) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, {TensorType_FLOAT32, {24}}, + {2}); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 24})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0})); +} + +class L2NormOpModel : public SingleOpModelWithNNAPI { + public: + L2NormOpModel(const TensorData& input, const TensorData& output, + ActivationFunctionType activation_type) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions, + CreateL2NormOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input_)}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int new_shape_; + int output_; +}; + +TEST(NNAPIDelegate, L2NormSimpleTest) { + std::initializer_list data = {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}; + L2NormOpModel m({TensorType_FLOAT32, {1, 1, 1, 6}}, + {TensorType_FLOAT32, {1, 1, 1, 6}}, + ActivationFunctionType_NONE); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 6})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})); +} + +class TransposeSimpleModel : public SingleOpModelWithNNAPI { + public: + TransposeSimpleModel(std::initializer_list input_shape, + std::initializer_list perm_shape, + std::initializer_list perm) { + input_ = AddInput(TensorType_FLOAT32); + perm_ = AddConstInput(TensorType_INT32, perm, perm_shape); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions, + CreateTransposeOptions(builder_).Union()); + BuildInterpreter({input_shape, perm_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int perm_; + int output_; +}; + +TEST(NNAPIDelegate, TransposeSimpleTest) { + TransposeSimpleModel m({2, 3, 4}, {3}, {2, 0, 1}); + m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, + 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23})); +} + +class FloatSubOpModel : public SingleOpModelWithNNAPI { + public: + FloatSubOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_SUB, BuiltinOptions_SubOptions, + CreateMulOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +TEST(NNAPIDelegate, SubWithNoActivation) { + FloatSubOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-2.1, 0.0, 0.4, 0.3}))); +} + +class FloatDivOpModel : public SingleOpModelWithNNAPI { + public: + FloatDivOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_DIV, BuiltinOptions_DivOptions, + CreateMulOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +TEST(NNAPIDelegate, DivWithNoActivation) { + FloatDivOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.8, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.4, 0.2}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({-20, 1, 2, 4}))); +} + +class BaseConcatenationOpModel : public SingleOpModelWithNNAPI { + public: + BaseConcatenationOpModel() {} + BaseConcatenationOpModel(const TensorData& input_template, int axis, + int num_inputs) { + std::vector> all_input_shapes; + for (int i = 0; i < num_inputs; ++i) { + all_input_shapes.push_back(input_template.shape); + AddInput(input_template); + } + output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min, + input_template.max}); + SetBuiltinOp( + BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions, + CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE) + .Union()); + BuildInterpreter(all_input_shapes); + } + + protected: + int output_; +}; + +class ConcatenationOpModel : public BaseConcatenationOpModel { + public: + using BaseConcatenationOpModel::BaseConcatenationOpModel; + void SetInput(int index, std::initializer_list data) { + PopulateTensor(index, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +TEST(NNAPIDelegate, ConcatenationThreeDimensionalOneInput) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/1, + /*num_inputs=*/1); + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 3, 4, 7})); +} + +TEST(NNAPIDelegate, ConcatenationFourInputs) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/2, + /*num_inputs=*/4); + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray({ + 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // + 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // + })); +} + +class QuantizedConcatenationOpModel : public BaseConcatenationOpModel { + public: + using BaseConcatenationOpModel::BaseConcatenationOpModel; + QuantizedConcatenationOpModel(const std::vector& input_template, + int axis, int num_inputs, + const TensorData& output_template) { + std::vector> all_input_shapes; + CHECK_EQ(input_template.size(), num_inputs); + for (int i = 0; i < num_inputs; ++i) { + all_input_shapes.push_back(input_template[i].shape); + AddInput(input_template[i]); + } + output_ = AddOutput({output_template.type, /*shape=*/{}, + output_template.min, output_template.max}); + SetBuiltinOp( + BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions, + CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE) + .Union()); + BuildInterpreter(all_input_shapes); + } + void SetInput(int index, std::initializer_list data) { + QuantizeAndPopulate(index, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST(NNAPIDelegate, ConcatenationFourInputsQuantized) { + QuantizedConcatenationOpModel m0({TensorType_UINT8, {2, 1, 2}, -12.7, 12.8}, + /*axis=*/2, + /*num_inputs=*/4); + + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.Invoke(); + EXPECT_THAT(m0.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // + 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // + }))); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({ + 137, 157, 138, 158, 139, 159, 140, 160, // + 167, 197, 168, 198, 169, 199, 170, 200, // + })); +} + +TEST(NNAPIDelegate, ConcatenationFourInputsQuantizedMixedRange) { + QuantizedConcatenationOpModel m0({{TensorType_UINT8, {2, 1, 2}, -10.7, 10.8}, + {TensorType_UINT8, {2, 1, 2}, 0, 12.8}, + {TensorType_UINT8, {2, 1, 2}, -11, 11.8}, + {TensorType_UINT8, {2, 1, 2}, 0, 7.4}}, + /*axis=*/2, /*num_inputs=*/4, + {TensorType_UINT8, {2, 1, 2}, -12.7, 12.8}); + + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.Invoke(); + EXPECT_THAT(m0.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // + 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // + }))); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({ + 137, 157, 138, 158, 139, 159, 140, 160, // + 167, 197, 168, 198, 169, 199, 170, 200, // + })); +} + +class DequantizeOpModel : public SingleOpModelWithNNAPI { + public: + DequantizeOpModel(std::initializer_list shape, float min, float max) { + input_ = AddInput({TensorType_UINT8, shape, min, max}); + output_ = AddOutput({TensorType_FLOAT32, shape}); + SetBuiltinOp(BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions, + CreateDequantizeOptions(builder_).Union()); + + BuildInterpreter({GetShape(input_)}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(NNAPIDelegate, DequantizeFourDimensional) { + DequantizeOpModel m({2, 5}, -63.5, 64); + + m.SetInput({0, 1, 2, 3, 4, 251, 252, 253, 254, 255}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64}))); +} + +class FloorOpModel : public SingleOpModelWithNNAPI { + public: + FloorOpModel(std::initializer_list input_shape, TensorType input_type) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_FLOOR, BuiltinOptions_NONE, 0); + BuildInterpreter({ + input_shape, + }); + } + + int input() { return input_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(NNAPIDelegate, FloorSingleDim) { + FloorOpModel model({2}, TensorType_FLOAT32); + model.PopulateTensor(model.input(), {8.5, 0.0}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({8, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2})); +} + +TEST(NNAPIDelegate, FloorMultiDims) { + FloorOpModel model({2, 1, 1, 5}, TensorType_FLOAT32); + model.PopulateTensor(model.input(), { + 0.0001, + 8.0001, + 0.9999, + 9.9999, + 0.5, + -0.0001, + -8.0001, + -0.9999, + -9.9999, + -0.5, + }); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({0, 8, 0, 9, 0, -1, -9, -1, -10, -1})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 5})); +} + +class LocalResponseNormOpModel : public SingleOpModelWithNNAPI { + public: + LocalResponseNormOpModel(std::initializer_list input_shape, int radius, + float bias, float alpha, float beta) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + BuiltinOptions_LocalResponseNormalizationOptions, + CreateLocalResponseNormalizationOptions(builder_, radius, bias, + alpha, beta) + .Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(NNAPIDelegate, LocalResponseNormSameAsL2Norm) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0, + /*alpha=*/1.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 2. + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}))); +} + +TEST(NNAPIDelegate, LocalResponseNormWithAlpha) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 3. + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + {-0.275, 0.15, 0.175, 0.3, -0.175, 0.025}))); +} + +TEST(NNAPIDelegate, LocalResponseNormWithBias) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/9.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 5. + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.22, 0.12, 0.14, 0.24, -0.14, 0.02}))); +} + +TEST(NNAPIDelegate, LocalResponseNormSmallRadius) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/2, /*bias=*/9.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {-0.264926, 0.125109, 0.140112, 0.267261, -0.161788, 0.0244266}))); +} + +class LSHProjectionOpModel : public SingleOpModelWithNNAPI { + public: + LSHProjectionOpModel(LSHProjectionType type, + std::initializer_list hash_shape, + std::initializer_list input_shape, + std::initializer_list weight_shape) { + hash_ = AddInput(TensorType_FLOAT32); + input_ = AddInput(TensorType_INT32); + if (weight_shape.size() > 0) { + weight_ = AddInput(TensorType_FLOAT32); + } + output_ = AddOutput(TensorType_INT32); + + SetBuiltinOp(BuiltinOperator_LSH_PROJECTION, + BuiltinOptions_LSHProjectionOptions, + CreateLSHProjectionOptions(builder_, type).Union()); + if (weight_shape.size() > 0) { + BuildInterpreter({hash_shape, input_shape, weight_shape}); + } else { + BuildInterpreter({hash_shape, input_shape}); + } + + output_size_ = 1; + for (int i : hash_shape) { + output_size_ *= i; + if (type == LSHProjectionType_SPARSE) { + break; + } + } + } + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetHash(std::initializer_list data) { + PopulateTensor(hash_, data); + } + + void SetWeight(std::initializer_list f) { PopulateTensor(weight_, f); } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int hash_; + int weight_; + int output_; + + int output_size_; +}; + +TEST(NNAPIDelegate, LSHProjectionDense1DInputs) { + LSHProjectionOpModel m(LSHProjectionType_DENSE, {3, 2}, {5}, {5}); + + m.SetInput({12345, 54321, 67890, 9876, -12345678}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + m.SetWeight({1.0, 1.0, 1.0, 1.0, 1.0}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0, 0, 0, 1, 0, 0)); +} + +TEST(NNAPIDelegate, LSHProjectionSparse1DInputs) { + LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5}, {}); + + m.SetInput({12345, 54321, 67890, 9876, -12345678}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 0, 4 + 1, 8 + 0)); +} + +TEST(NNAPIDelegate, LSHProjectionSparse3DInputs) { + LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5, 2, 2}, {5}); + + m.SetInput({1234, 2345, 3456, 1234, 4567, 5678, 6789, 4567, 7891, 8912, + 9123, 7890, -987, -876, -765, -987, -543, -432, -321, -543}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + m.SetWeight({0.12, 0.34, 0.56, 0.67, 0.78}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 2, 4 + 1, 8 + 1)); +} + +class BaseActivationsOpModel : public SingleOpModelWithNNAPI { + public: + // Most activations don't take any options, so this constructor works for + // them. + BaseActivationsOpModel(BuiltinOperator type, TensorData input) { + input_ = AddInput(input); + if (input.type == TensorType_UINT8) { + output_ = AddOutput({input.type, {}, 0, 0, 1. / 256}); + } else { + output_ = AddOutput({input.type, {}}); + } + SetBuiltinOp(type, BuiltinOptions_NONE, 0); + BuildInterpreter({GetShape(input_)}); + } + + BaseActivationsOpModel(BuiltinOperator type, const TensorData& input, + const TensorData& output) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(type, BuiltinOptions_NONE, 0); + BuildInterpreter({GetShape(input_)}); + } + + protected: + int input_; + int output_; +}; + +class FloatActivationsOpModel : public BaseActivationsOpModel { + public: + using BaseActivationsOpModel::BaseActivationsOpModel; + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +const float kQuantizedTolerance = 2 * (1. / 256); + +class QuantizedActivationsOpModel : public BaseActivationsOpModel { + public: + using BaseActivationsOpModel::BaseActivationsOpModel; + + template + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + template + + std::vector GetOutput() { + return ExtractVector(output_); + } + template + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); + } +}; + +TEST(NNAPIDelegate, Relu) { + FloatActivationsOpModel m(BuiltinOperator_RELU, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 2, 4, // + 3, 0, 10, 1, // + })); +} + +TEST(NNAPIDelegate, Relu1) { + FloatActivationsOpModel m(BuiltinOperator_RELU_N1_TO_1, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0.0, -0.6, 0.2, -0.4, // + 0.3, -2.0, 1.1, -0.1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0.0, -0.6, 0.2, -0.4, // + 0.3, -1.0, 1.0, -0.1, // + })); +} + +TEST(NNAPIDelegate, Relu6) { + FloatActivationsOpModel m(BuiltinOperator_RELU6, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 2, 4, // + 3, 0, 6, 1, // + })); +} + +TEST(NNAPIDelegate, Tanh) { + FloatActivationsOpModel m(BuiltinOperator_TANH, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0, -0.9999877, 0.9640275, 0.999329, // + 0.99505475, -0.9640275, 1, 0.7615941, // + }))); +} + +TEST(NNAPIDelegate, LogisticFloat) { + FloatActivationsOpModel m(BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }))); +} + +TEST(NNAPIDelegate, LogisticQuantized) { + QuantizedActivationsOpModel m( + BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }, + kQuantizedTolerance))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({128, 1, 227, 251, 244, 32, 255, 188})); +} + +#if 0 +class ResizeBilinearOpModel : public SingleOpModelWithNNAPI { + public: + ResizeBilinearOpModel(const TensorData& input, + std::initializer_list size_data = {}) { + bool const_size = size_data.size() != 0; + input_ = AddInput(input); + if (const_size) { + size_ = AddConstInput(TensorType_INT32, size_data, {2}); + } else { + size_ = AddInput({TensorType_INT32, {2}}); + } + output_ = AddOutput(input.type); + SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR, + BuiltinOptions_ResizeBilinearOptions, + CreateResizeBilinearOptions(builder_).Union()); + if (const_size) { + BuildInterpreter({GetShape(input_)}); + } else { + BuildInterpreter({GetShape(input_), GetShape(size_)}); + } + } + + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetSize(std::initializer_list data) { PopulateTensor(size_, data); } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + private: + int input_; + int size_; + int output_; +}; + +TEST(NNAPIDelegate, ResizeBilinearHorizontal) { + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}); + m.SetInput({3, 6}); + m.SetSize({1, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3}); + const_m.SetInput({3, 6}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); +} + +TEST(NNAPIDelegate, ResizeBilinearVertical) { + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}); + m.SetInput({3, 9}); + m.SetSize({3, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1}); + const_m.SetInput({3, 9}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); +} + +TEST(NNAPIDelegate, ResizeBilinearTwoDimensional) { + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}); + m.SetInput({ + 3, 6, // + 9, 12 // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); +} +#endif + +template +class PadOpModel : public SingleOpModelWithNNAPI { + public: + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetQuantizedInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetQuantizedPadValue(float data) { + QuantizeAndPopulate(constant_values_, {data}); + } + + void SetPaddings(std::initializer_list paddings) { + PopulateTensor(paddings_, paddings); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } + + protected: + int input_; + int output_; + int paddings_; + int constant_values_; +}; + +class PadOpConstModel : public PadOpModel { + public: + PadOpConstModel(const TensorData& input, + std::initializer_list paddings_shape, + std::initializer_list paddings, + const TensorData& output) { + input_ = AddInput(input); + paddings_ = AddConstInput(TensorType_INT32, paddings, paddings_shape); + output_ = AddOutput(output); + + SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions, + CreatePadOptions(builder_).Union()); + BuildInterpreter({input.shape}); + } +}; + +TEST(NNAPIDelegate, PadAdvancedConstTest) { + PadOpConstModel m({TensorType_FLOAT32, {1, 2, 3, 1}}, {4, 2}, + {0, 0, 0, 2, 1, 3, 0, 0}, {TensorType_FLOAT32}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); +} + +class SpaceToBatchNDOpModel : public SingleOpModelWithNNAPI { + public: + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetBlockShape(std::initializer_list data) { + PopulateTensor(block_shape_, data); + } + + void SetPaddings(std::initializer_list data) { + PopulateTensor(paddings_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int block_shape_; + int paddings_; + int output_; +}; + +class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel { + public: + SpaceToBatchNDOpConstModel(std::initializer_list input_shape, + std::initializer_list block_shape, + std::initializer_list paddings) { + input_ = AddInput(TensorType_FLOAT32); + block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2}); + paddings_ = AddConstInput(TensorType_INT32, paddings, {2, 2}); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND, + BuiltinOptions_SpaceToBatchNDOptions, + CreateSpaceToBatchNDOptions(builder_).Union()); + BuildInterpreter({input_shape}); + } +}; + +TEST(NNAPIDelegate, SpaceToBatchNDSimpleConstTest) { + SpaceToBatchNDOpConstModel m({1, 4, 4, 1}, {2, 2}, {0, 0, 0, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7, + 13, 15, 6, 8, 14, 16})); +} + +TEST(NNAPIDelegate, SpaceToBatchNDMultipleInputBatchesConstTest) { + SpaceToBatchNDOpConstModel m({2, 2, 4, 1}, {2, 2}, {0, 0, 0, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8, 1, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7, + 13, 15, 6, 8, 14, 16})); +} + +TEST(NNAPIDelegate, SpaceToBatchNDSimplePaddingConstTest) { + SpaceToBatchNDOpConstModel m({1, 5, 2, 1}, {3, 2}, {1, 0, 2, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 0, 5, 0, 0, 0, 6, 0, 1, 0, 7, + 0, 2, 0, 8, 0, 3, 0, 9, 0, 4, 0, 10, + })); +} + +TEST(NNAPIDelegate, SpaceToBatchNDComplexPaddingConstTest) { + SpaceToBatchNDOpConstModel m({1, 4, 2, 1}, {3, 2}, {1, 1, 2, 4}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, + 0, 1, 0, 0, 0, 7, 0, 0, 0, 2, 0, 0, 0, 8, 0, 0, + 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, + })); +} + +template +class StridedSliceOpModel : public SingleOpModelWithNNAPI { + public: + StridedSliceOpModel(std::initializer_list input_shape, + std::initializer_list begin_shape, + std::initializer_list end_shape, + std::initializer_list strides_shape, int begin_mask, + int end_mask, int ellipsis_mask, int new_axis_mask, + int shrink_axis_mask) { + input_ = AddInput(tensor_input_type); + begin_ = AddInput(TensorType_INT32); + end_ = AddInput(TensorType_INT32); + strides_ = AddInput(TensorType_INT32); + output_ = AddOutput(tensor_input_type); + SetBuiltinOp( + BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions, + CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask, + new_axis_mask, shrink_axis_mask) + .Union()); + BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetBegin(std::initializer_list data) { + PopulateTensor(begin_, data); + } + void SetEnd(std::initializer_list data) { + PopulateTensor(end_, data); + } + void SetStrides(std::initializer_list data) { + PopulateTensor(strides_, data); + } + + std::vector GetOutput() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int begin_; + int end_; + int strides_; + int output_; +}; + +TEST(NNAPIDelegate, StridedSliceIn2D) { + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, 0}); + m.SetEnd({2, 2}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5})); +} + +TEST(NNAPIDelegate, StridedSliceIn2D_ShrinkAxis_NegativeSlice) { + // This is equivalent to tf.range(4)[:, tf.newaxis][-2, -1]. + StridedSliceOpModel<> m({4, 1}, {2}, {2}, {2}, 0, 0, 0, 0, 3); + m.SetInput({0, 1, 2, 3}); + m.SetBegin({-2, -1}); + m.SetEnd({-1, 0}); + m.SetStrides({1, 1}); + + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); +} + +TEST(NNAPIDelegate, StridedSliceIn2D_ShrinkAxisMask) { + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, 0}); + m.SetEnd({1, 1}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); +} + +static float rnn_input[] = { + 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, + 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471, + -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222, + 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933, + 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103, + 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043, + -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, + -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154, + 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584, + 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144, + 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351, + -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719, + 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, + -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881, + -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032, + -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374, + 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071, + -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219, + -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, + 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493, + -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265, + 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539, + 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446, + 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017, + -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, + 0.93455386, -0.6324693, -0.083922029}; + +static float rnn_golden_output[] = { + 0.496726, 0, 0.965996, 0, 0.0584254, 0, + 0, 0.12315, 0, 0, 0.612266, 0.456601, + 0, 0.52286, 1.16099, 0.0291232, + + 0, 0, 0.524901, 0, 0, 0, + 0, 1.02116, 0, 1.35762, 0, 0.356909, + 0.436415, 0.0355727, 0, 0, + + 0, 0, 0, 0.262335, 0, 0, + 0, 1.33992, 0, 2.9739, 0, 0, + 1.31914, 2.66147, 0, 0, + + 0.942568, 0, 0, 0, 0.025507, 0, + 0, 0, 0.321429, 0.569141, 1.25274, 1.57719, + 0.8158, 1.21805, 0.586239, 0.25427, + + 1.04436, 0, 0.630725, 0, 0.133801, 0.210693, + 0.363026, 0, 0.533426, 0, 1.25926, 0.722707, + 0, 1.22031, 1.30117, 0.495867, + + 0.222187, 0, 0.72725, 0, 0.767003, 0, + 0, 0.147835, 0, 0, 0, 0.608758, + 0.469394, 0.00720298, 0.927537, 0, + + 0.856974, 0.424257, 0, 0, 0.937329, 0, + 0, 0, 0.476425, 0, 0.566017, 0.418462, + 0.141911, 0.996214, 1.13063, 0, + + 0.967899, 0, 0, 0, 0.0831304, 0, + 0, 1.00378, 0, 0, 0, 1.44818, + 1.01768, 0.943891, 0.502745, 0, + + 0.940135, 0, 0, 0, 0, 0, + 0, 2.13243, 0, 0.71208, 0.123918, 1.53907, + 1.30225, 1.59644, 0.70222, 0, + + 0.804329, 0, 0.430576, 0, 0.505872, 0.509603, + 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311, + 0.0454298, 0.300267, 0.562784, 0.395095, + + 0.228154, 0, 0.675323, 0, 1.70536, 0.766217, + 0, 0, 0, 0.735363, 0.0759267, 1.91017, + 0.941888, 0, 0, 0, + + 0, 0, 1.5909, 0, 0, 0, + 0, 0.5755, 0, 0.184687, 0, 1.56296, + 0.625285, 0, 0, 0, + + 0, 0, 0.0857888, 0, 0, 0, + 0, 0.488383, 0.252786, 0, 0, 0, + 1.02817, 1.85665, 0, 0, + + 0.00981836, 0, 1.06371, 0, 0, 0, + 0, 0, 0, 0.290445, 0.316406, 0, + 0.304161, 1.25079, 0.0707152, 0, + + 0.986264, 0.309201, 0, 0, 0, 0, + 0, 1.64896, 0.346248, 0, 0.918175, 0.78884, + 0.524981, 1.92076, 2.07013, 0.333244, + + 0.415153, 0.210318, 0, 0, 0, 0, + 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, + 0.628881, 3.58099, 1.49974, 0}; + +static std::initializer_list rnn_weights = { + 0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}; + +static std::initializer_list rnn_recurrent_weights = { + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}; + +static std::initializer_list rnn_bias = { + 0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568, + -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178, + 0.37197268, 0.61957061, 0.3956964, -0.37609905}; + +class RNNOpModel : public SingleOpModelWithNNAPI { + public: + RNNOpModel(int batches, int units, int size, + const TensorType& weights = TensorType_FLOAT32, + const TensorType& recurrent_weights = TensorType_FLOAT32) + : batches_(batches), units_(units), input_size_(size) { + input_ = AddInput(TensorType_FLOAT32); + weights_ = AddInput(weights); + recurrent_weights_ = AddInput(recurrent_weights); + bias_ = AddInput(TensorType_FLOAT32); + hidden_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RNN, BuiltinOptions_RNNOptions, + CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union()); + BuildInterpreter({{batches_, input_size_}, + {units_, input_size_}, + {units_, units_}, + {units_}}); + } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetRecurrentWeights(std::initializer_list f) { + PopulateTensor(recurrent_weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + void ResetHiddenState() { + const int zero_buffer_size = units_ * batches_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(hidden_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + protected: + int input_; + int weights_; + int recurrent_weights_; + int bias_; + int hidden_state_; + int output_; + + int batches_; + int units_; + int input_size_; +}; + +TEST(NNAPIDelegate, RnnBlackBoxTest) { + RNNOpModel rnn(2, 16, 8); + rnn.SetWeights(rnn_weights); + rnn.SetBias(rnn_bias); + rnn.SetRecurrentWeights(rnn_recurrent_weights); + + rnn.ResetHiddenState(); + const int input_sequence_size = sizeof(rnn_input) / sizeof(float) / + (rnn.input_size() * rnn.num_batches()); + + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + rnn.SetInput(0, batch_start, batch_end); + rnn.SetInput(rnn.input_size(), batch_start, batch_end); + + rnn.Invoke(); + + float* golden_start = rnn_golden_output + i * rnn.num_units(); + float* golden_end = golden_start + rnn.num_units(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +static float svdf_input[] = { + 0.12609188, -0.46347019, -0.89598465, + 0.35867718, 0.36897406, 0.73463392, + + 0.14278367, -1.64410412, -0.75222826, + -0.57290924, 0.12729003, 0.7567004, + + 0.49837467, 0.19278903, 0.26584083, + 0.17660543, 0.52949083, -0.77931279, + + -0.11186574, 0.13164264, -0.05349274, + -0.72674477, -0.5683046, 0.55900657, + + -0.68892461, 0.37783599, 0.18263303, + -0.63690937, 0.44483393, -0.71817774, + + -0.81299269, -0.86831826, 1.43940818, + -0.95760226, 1.82078898, 0.71135032, + + -1.45006323, -0.82251364, -1.69082689, + -1.65087092, -1.89238167, 1.54172635, + + 0.03966608, -0.24936394, -0.77526885, + 2.06740379, -1.51439476, 1.43768692, + + 0.11771342, -0.23761693, -0.65898693, + 0.31088525, -1.55601168, -0.87661445, + + -0.89477462, 1.67204106, -0.53235275, + -0.6230064, 0.29819036, 1.06939757, +}; + +static float svdf_golden_output_rank_1[] = { + 0.014899, -0.0517661, -0.143725, -0.00271883, + -0.03004015, 0.09565311, 0.1587342, 0.00784263, + + 0.068281, -0.162217, -0.152268, 0.00323521, + 0.01582633, 0.03858774, -0.03001583, -0.02671271, + + -0.0317821, -0.0333089, 0.0609602, 0.0333759, + -0.01432795, 0.05524484, 0.1101355, -0.02382665, + + -0.00623099, -0.077701, -0.391193, -0.0136691, + -0.02333033, 0.02293761, 0.12338032, 0.04326871, + + 0.201551, -0.164607, -0.179462, -0.0592739, + 0.01064911, -0.17503069, 0.07821996, -0.00224009, + + 0.0886511, -0.0875401, -0.269283, 0.0281379, + -0.02282338, 0.09741908, 0.32973239, 0.12281385, + + -0.201174, -0.586145, -0.628624, -0.0330412, + 0.24780814, -0.39304617, -0.22473189, 0.02589256, + + -0.0839096, -0.299329, 0.108746, 0.109808, + 0.10084175, -0.06416984, 0.28936723, 0.0026358, + + 0.419114, -0.237824, -0.422627, 0.175115, + -0.2314795, -0.18584411, -0.4228974, -0.12928449, + + 0.36726, -0.522303, -0.456502, -0.175475, + 0.17012937, -0.34447709, 0.38505614, -0.28158101, +}; + +static float svdf_golden_output_rank_2[] = { + -0.09623547, -0.10193135, 0.11083051, -0.0347917, + 0.1141196, 0.12965347, -0.12652366, 0.01007236, + + -0.16396809, -0.21247184, 0.11259045, -0.04156673, + 0.10132131, -0.06143532, -0.00924693, 0.10084561, + + 0.01257364, 0.0506071, -0.19287863, -0.07162561, + -0.02033747, 0.22673416, 0.15487903, 0.02525555, + + -0.1411963, -0.37054959, 0.01774767, 0.05867489, + 0.09607603, -0.0141301, -0.08995658, 0.12867066, + + -0.27142537, -0.16955489, 0.18521598, -0.12528358, + 0.00331409, 0.11167502, 0.02218599, -0.07309391, + + 0.09593632, -0.28361851, -0.0773851, 0.17199151, + -0.00075242, 0.33691186, -0.1536046, 0.16572715, + + -0.27916506, -0.27626723, 0.42615682, 0.3225764, + -0.37472126, -0.55655634, -0.05013514, 0.289112, + + -0.24418658, 0.07540751, -0.1940318, -0.08911639, + 0.00732617, 0.46737891, 0.26449674, 0.24888524, + + -0.17225097, -0.54660404, -0.38795233, 0.08389944, + 0.07736043, -0.28260678, 0.15666828, 1.14949894, + + -0.57454878, -0.64704704, 0.73235172, -0.34616736, + 0.21120001, -0.22927976, 0.02455296, -0.35906726, +}; + +class BaseSVDFOpModel : public SingleOpModelWithNNAPI { + public: + BaseSVDFOpModel(int batches, int units, int input_size, int memory_size, + int rank, + TensorType weights_feature_type = TensorType_FLOAT32, + TensorType weights_time_type = TensorType_FLOAT32) + : batches_(batches), + units_(units), + input_size_(input_size), + memory_size_(memory_size), + rank_(rank) { + input_ = AddInput(TensorType_FLOAT32); + weights_feature_ = AddInput(weights_feature_type); + weights_time_ = AddInput(weights_time_type); + bias_ = AddNullInput(); + state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions, + CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union()); + BuildInterpreter({ + {batches_, input_size_}, // Input tensor + {units_ * rank, input_size_}, // weights_feature tensor + {units_ * rank, memory_size_}, // weights_time tensor + {units_} // bias tensor + }); + } + + // Populates the weights_feature tensor. + void SetWeightsFeature(std::initializer_list f) { + PopulateTensor(weights_feature_, f); + } + + // Populates the weights_time tensor. + void SetWeightsTime(std::initializer_list f) { + PopulateTensor(weights_time_, f); + } + + // Populates the input tensor. + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + // Resets the state of SVDF op by filling it with 0's. + void ResetState() { + const int zero_buffer_size = rank_ * units_ * batches_ * memory_size_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + // Extracts the output tensor from the SVDF op. + std::vector GetOutput() { return ExtractVector(output_); } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + protected: + int input_; + int weights_feature_; + int weights_time_; + int bias_; + int state_; + int output_; + + int batches_; + int units_; + int input_size_; + int memory_size_; + int rank_; +}; + +class SVDFOpModel : public BaseSVDFOpModel { + public: + using BaseSVDFOpModel::BaseSVDFOpModel; + + void VerifyGoldens(float golden_input[], float golden_output[], + int golden_size, float tolerance = 1e-5) { + const int svdf_num_batches = num_batches(); + const int svdf_input_size = input_size(); + const int svdf_num_units = num_units(); + const int input_sequence_size = + golden_size / sizeof(float) / (svdf_input_size * svdf_num_batches); + // Going over each input batch, setting the input tensor, invoking the SVDF + // op and checking the output with the expected golden values. + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = + golden_input + i * svdf_input_size * svdf_num_batches; + float* batch_end = batch_start + svdf_input_size * svdf_num_batches; + SetInput(0, batch_start, batch_end); + + Invoke(); + + const float* golden_start = + golden_output + i * svdf_num_units * svdf_num_batches; + const float* golden_end = + golden_start + svdf_num_units * svdf_num_batches; + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } + } +}; + +TEST(NNAPIDelegate, SVDFBlackBoxTestRank1) { + SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/1); + svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, + 0.22197971, 0.12416199, 0.27901134, 0.27557442, + 0.3905206, -0.36137494, -0.06634006, -0.10640851}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); + + svdf.ResetState(); + svdf.VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input)); +} + +TEST(NNAPIDelegate, SVDFBlackBoxTestRank2) { + SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/2); + svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, + 0.12416199, 0.15785322, 0.27901134, 0.3905206, + 0.21931258, -0.36137494, -0.10640851, 0.31053296, + -0.36118156, -0.0976817, -0.36916667, 0.22197971, + 0.15294972, 0.38031587, 0.27557442, 0.39635518, + -0.21580373, -0.06634006, -0.02702999, 0.27072677}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657, + + -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486, + 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187, + + -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589, + 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836, + + -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277, + -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214, + + 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326, + 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); + + svdf.ResetState(); + svdf.VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input)); +} + +class LSTMOpModel : public SingleOpModelWithNNAPI { + public: + LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg, + bool use_peephole, bool use_projection_weights, + bool use_projection_bias, float cell_clip, float proj_clip, + const std::vector>& input_shapes, + const TensorType& weight_type = TensorType_FLOAT32) + : n_batch_(n_batch), + n_input_(n_input), + n_cell_(n_cell), + n_output_(n_output) { + input_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + input_to_input_weights_ = AddNullInput(); + } else { + input_to_input_weights_ = AddInput(weight_type); + } + + input_to_forget_weights_ = AddInput(weight_type); + input_to_cell_weights_ = AddInput(weight_type); + input_to_output_weights_ = AddInput(weight_type); + + if (use_cifg) { + recurrent_to_input_weights_ = AddNullInput(); + } else { + recurrent_to_input_weights_ = AddInput(weight_type); + } + + recurrent_to_forget_weights_ = AddInput(weight_type); + recurrent_to_cell_weights_ = AddInput(weight_type); + recurrent_to_output_weights_ = AddInput(weight_type); + + if (use_peephole) { + if (use_cifg) { + cell_to_input_weights_ = AddNullInput(); + } else { + cell_to_input_weights_ = AddInput(weight_type); + } + cell_to_forget_weights_ = AddInput(weight_type); + cell_to_output_weights_ = AddInput(weight_type); + } else { + cell_to_input_weights_ = AddNullInput(); + cell_to_forget_weights_ = AddNullInput(); + cell_to_output_weights_ = AddNullInput(); + } + + if (use_cifg) { + input_gate_bias_ = AddNullInput(); + } else { + input_gate_bias_ = AddInput(TensorType_FLOAT32); + } + forget_gate_bias_ = AddInput(TensorType_FLOAT32); + cell_bias_ = AddInput(TensorType_FLOAT32); + output_gate_bias_ = AddInput(TensorType_FLOAT32); + + if (use_projection_weights) { + projection_weights_ = AddInput(weight_type); + if (use_projection_bias) { + projection_bias_ = AddInput(TensorType_FLOAT32); + } else { + projection_bias_ = AddNullInput(); + } + } else { + projection_weights_ = AddNullInput(); + projection_bias_ = AddNullInput(); + } + + output_state_ = AddOutput(TensorType_FLOAT32); + cell_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, + CreateLSTMOptions(builder_, ActivationFunctionType_TANH, + cell_clip, proj_clip) + .Union()); + BuildInterpreter(input_shapes); + } + + void SetInputToInputWeights(std::initializer_list f) { + PopulateTensor(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + PopulateTensor(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + PopulateTensor(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + PopulateTensor(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + PopulateTensor(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + PopulateTensor(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + PopulateTensor(cell_to_output_weights_, f); + } + + void SetInputGateBias(std::initializer_list f) { + PopulateTensor(input_gate_bias_, f); + } + + void SetForgetGateBias(std::initializer_list f) { + PopulateTensor(forget_gate_bias_, f); + } + + void SetCellBias(std::initializer_list f) { + PopulateTensor(cell_bias_, f); + } + + void SetOutputGateBias(std::initializer_list f) { + PopulateTensor(output_gate_bias_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + PopulateTensor(projection_weights_, f); + } + + void SetProjectionBias(std::initializer_list f) { + PopulateTensor(projection_bias_, f); + } + + void ResetOutputState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(output_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void ResetCellState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(cell_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void SetInput(int offset, const float* begin, const float* end) { + PopulateTensor(input_, offset, const_cast(begin), + const_cast(end)); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int num_inputs() { return n_input_; } + int num_outputs() { return n_output_; } + int num_cells() { return n_cell_; } + int num_batches() { return n_batch_; } + + protected: + int input_; + int input_to_input_weights_; + int input_to_forget_weights_; + int input_to_cell_weights_; + int input_to_output_weights_; + + int recurrent_to_input_weights_; + int recurrent_to_forget_weights_; + int recurrent_to_cell_weights_; + int recurrent_to_output_weights_; + + int cell_to_input_weights_; + int cell_to_forget_weights_; + int cell_to_output_weights_; + + int input_gate_bias_; + int forget_gate_bias_; + int cell_bias_; + int output_gate_bias_; + + int projection_weights_; + int projection_bias_; + int input_activation_state_; + int input_cell_state_; + + int output_; + int output_state_; + int cell_state_; + + int n_batch_; + int n_input_; + int n_cell_; + int n_output_; +}; + +class BaseLstmTest : public ::testing::Test { + protected: + // Weights of the LSTM model. Some are optional. + std::initializer_list input_to_input_weights_; + std::initializer_list input_to_cell_weights_; + std::initializer_list input_to_forget_weights_; + std::initializer_list input_to_output_weights_; + std::initializer_list input_gate_bias_; + std::initializer_list cell_gate_bias_; + std::initializer_list forget_gate_bias_; + std::initializer_list output_gate_bias_; + std::initializer_list recurrent_to_input_weights_; + std::initializer_list recurrent_to_cell_weights_; + std::initializer_list recurrent_to_forget_weights_; + std::initializer_list recurrent_to_output_weights_; + std::initializer_list cell_to_input_weights_; + std::initializer_list cell_to_forget_weights_; + std::initializer_list cell_to_output_weights_; + std::initializer_list projection_weights_; + + // LSTM input is stored as num_batch x num_inputs vector. + std::vector> lstm_input_; + // LSTM output is stored as num_batch x num_outputs vector. + std::vector> lstm_golden_output_; + + // Compares output up to tolerance to the result of the lstm given the input. + void VerifyGoldens(const std::vector>& input, + const std::vector>& output, + LSTMOpModel* lstm, float tolerance = 1e-5) { + const int num_batches = input.size(); + EXPECT_GT(num_batches, 0); + const int num_inputs = lstm->num_inputs(); + EXPECT_GT(num_inputs, 0); + const int input_sequence_size = input[0].size() / num_inputs; + EXPECT_GT(input_sequence_size, 0); + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* batch_start = input[b].data() + i * num_inputs; + const float* batch_end = batch_start + num_inputs; + + lstm->SetInput(b * lstm->num_inputs(), batch_start, batch_end); + } + + lstm->Invoke(); + + const int num_outputs = lstm->num_outputs(); + std::vector expected; + for (int b = 0; b < num_batches; ++b) { + const float* golden_start_batch = output[b].data() + i * num_outputs; + const float* golden_end_batch = golden_start_batch + num_outputs; + expected.insert(expected.end(), golden_start_batch, golden_end_batch); + } + EXPECT_THAT(lstm->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } + } +}; + +class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}; + input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, -0.29909778}; + input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}; + input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, + -0.1556896, 0.19487578}; + input_gate_bias_ = {0., 0., 0., 0.}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_input_weights_ = { + -0.0063535, -0.2042388, 0.31454784, -0.35746509, + 0.28902304, 0.08183324, -0.16555229, 0.02286911, + -0.13566875, 0.03034258, 0.48091322, -0.12528998, + 0.24077177, -0.51332325, -0.33502164, 0.10629296}; + + recurrent_to_cell_weights_ = { + -0.3407414, 0.24443203, -0.2078532, 0.26320225, + 0.05695659, -0.00123841, -0.4744786, -0.35869038, + -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}; + + recurrent_to_forget_weights_ = { + -0.48684245, -0.06655136, 0.42224967, 0.2112639, + 0.27654213, 0.20864892, -0.07646349, 0.45877004, + 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}; + + recurrent_to_output_weights_ = { + 0.43385774, -0.17194885, 0.2718237, 0.09215671, + 0.24107647, -0.39835793, 0.18212086, 0.01301402, + 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765, + -0.03716109, 0.12507336, 0.41193449, -0.20860538, + -0.15053082, 0.09120187, 0.24278517, -0.12222792}}; + } +}; + +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight_tensor + {n_cell, n_output}, // recurrent_to_forget_weight_tensor + {n_cell, n_output}, // recurrent_to_cell_weight_tensor + {n_cell, n_output}, // recurrent_to_output_weight_tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} + +class CifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726, + 0.05100781, 0.04717243, 0.48944736, + -0.38535351, -0.17212132}; + + input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, + 0.24407166, 0.33826375}; + + input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_cell_weights_ = { + 0.54066205, -0.32668582, -0.43562764, -0.56094903, + 0.42957711, 0.01841056, -0.32764608, -0.33027974, + -0.10826075, 0.20675004, 0.19069612, -0.03026325, + -0.54532051, 0.33003211, 0.44901288, 0.21193194}; + + recurrent_to_forget_weights_ = { + -0.13832897, -0.0515101, -0.2359007, -0.16661474, + -0.14340827, 0.36986142, 0.23414481, 0.55899, + 0.10798943, -0.41174671, 0.17751795, -0.34484994, + -0.35874045, -0.11352962, 0.27268326, 0.54058349}; + + recurrent_to_output_weights_ = { + 0.41613156, 0.42610586, -0.16495961, -0.5663873, + 0.30579174, -0.05115908, -0.33941799, 0.23364776, + 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}; + + cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408, + 0.31544167}; + cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703, + -0.77109635}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646, + -0.42312205, -0.01218222, 0.24201041, -0.08124574, + -0.358325, -0.04621704, 0.21641694, -0.06471302}}; + } +}; + +TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} + +class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = { + 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, + -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}; + + input_to_forget_weights_ = { + -0.0018401089, -0.004852237, 0.03698424, 0.014181704, + 0.028273236, -0.016726194, -0.05249759, -0.10204261, + 0.00861066, -0.040979505, -0.009899187, 0.01923892, + -0.028177269, -0.08535103, -0.14585495, 0.10662567, + -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, + 0.0814421, -0.12257899, -0.033945758, -0.031303465, + 0.045630626, 0.06843887, -0.13492945, -0.012480007, + -0.0811829, -0.07224499, -0.09628791, 0.045100946, + 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, + 0.052625068, 0.12784666, 0.07077897, 0.025725935, + 0.04165009, 0.07241905, 0.018668644, -0.037377294, + -0.06277783, -0.08833636, -0.040120605, -0.011405586, + -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, + 0.13396506, -0.08402166, -0.01901462, -0.044678304, + -0.07720565, 0.014350063, -0.11757958, -0.0652038, + -0.08185733, -0.076754324, -0.092614375, 0.10405491, + 0.052960336, 0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, + -0.054523353, 0.02582715, 0.02327355, -0.011857179, + -0.0011980024, -0.034641717, -0.026125094, -0.17582615, + -0.15923657, -0.27486774, -0.0006143371, 0.0001771948, + -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}; + + input_to_cell_weights_ = { + -0.04580283, -0.09549462, -0.032418985, -0.06454633, + -0.043528453, 0.043018587, -0.049152344, -0.12418144, + -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, + -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, -0.036816437, -0.02130134, -0.016518239, + 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, + -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, + 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, + 0.053568836, 0.06408714, 0.12835667, -0.008714329, + -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, + -0.036999565, -0.028842626, -0.0033637602, -0.017012902, + -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, + -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, + 0.07463894, 0.0075130584, 0.012850982, 0.04555431, + 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}; + + input_to_output_weights_ = { + -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}; + + input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666, + 0.053110216, -0.06928846, -0.13942584, -0.11816189, + 0.19483899, 0.03652339, -0.10250295, 0.036714908, + -0.18426876, 0.036065217, 0.21810818, 0.02383196, + -0.043370757, 0.08690144, -0.04444982, 0.00030581196}; + + forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}; + + cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}; + + output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113, + 0.027195795, 0.35373217, -0.018957434, 0.008907322, + -0.0762701, 0.12018895, 0.04216877, 0.0022856654, + 0.040952638, 0.3147856, 0.08225149, -0.057416286, + -0.14995944, -0.008040261, 0.13208859, 0.029760877}; + + recurrent_to_input_weights_ = { + -0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, -0.06402044, 0.062524505, + -0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + -0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, -0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}; + + recurrent_to_cell_weights_ = { + -0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, -0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + -0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, -0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, -0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, -0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}; + + recurrent_to_forget_weights_ = { + -0.057784554, -0.026057621, -0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, -0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, -0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + -0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, -0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}; + + recurrent_to_output_weights_ = { + 0.025825322, -0.05813119, 0.09495884, -0.045984812, + -0.01255415, -0.0026479573, -0.08196161, -0.054914974, + -0.0046604523, -0.029587349, -0.044576716, -0.07480124, + -0.082868785, 0.023254942, 0.027502948, -0.0039728214, + -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, + -0.08829125, -0.005139627, -0.08989442, -0.0555066, + 0.13596267, -0.025062224, -0.048351806, -0.03850004, + 0.07266485, -0.022414139, 0.05940088, 0.075114764, + 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, + 0.014416728, 0.043229222, 0.034178585, -0.07530371, + 0.035837382, -0.085607, -0.007721233, -0.03287832, + -0.043848954, -0.06404588, -0.06632928, -0.073643476, + 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, + -0.030063879, 0.008801774, -0.023021035, -0.019558564, + 0.05158114, -0.010947698, -0.011825728, 0.0075720972, + 0.0699727, -0.0039981045, 0.069350146, 0.08799282, + 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, + -0.00924166, 0.0046702605, -0.036598757, -0.08811812, + 0.10522024, -0.032441203, 0.008176899, -0.04454919, + 0.07058152, 0.0067963637, 0.039206743, 0.03259838, + 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, + 0.036879618, 0.043357447, 0.028362012, -0.05908629, + 0.0059240665, -0.04995891, -0.019187413, 0.0276265, + -0.01628143, 0.0025863599, 0.08800015, 0.035250366, + -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, + -0.009660886, 0.019076364, 0.018299393, -0.046004917, + 0.08891175, 0.0431396, -0.026327137, -0.051502608, + 0.08979574, -0.051670972, 0.04940282, -0.07491107, + -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, + -0.035936575, -0.011681591, 0.064818054, 0.0073146066, + -0.021745546, -0.043124277, -0.06471268, -0.07053354, + -0.029321948, -0.05330136, 0.016933719, -0.053782392, + 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, + -0.07924483, 0.06936997, 0.0034815092, -0.007305279, + -0.037325785, -0.07251102, -0.033633437, -0.08677009, + 0.091591336, -0.14165086, 0.021752775, 0.019683983, + 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, + 0.1183656, -0.0010731248, -0.023590032, -0.072285876, + -0.0724771, -0.026382286, -0.0014920527, 0.042667855, + 0.0018776858, 0.02986552, 0.009814309, 0.0733756, + 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, + -0.010036754, 0.02576849, -0.08307328, 0.010112348, + 0.042521734, -0.05869831, -0.071689695, 0.03876447, + -0.13275425, -0.0352966, -0.023077697, 0.10285965, + 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, + -0.08271222, -0.0030240538, -0.016368777, 0.1070414, + 0.042672627, 0.013456989, -0.0437609, -0.022309763, + 0.11576483, 0.04108048, 0.061026827, -0.0190714, + -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, + -0.023771819, -0.01965048, 0.007955471, -0.043740474, + 0.03346837, -0.10549954, 0.090567775, 0.042013682, + -0.03176985, 0.12569028, -0.02421228, -0.029526481, + 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, + -0.06861939, -0.021256343, -0.041093912, -0.06669611, + 0.035498552, 0.021757556, -0.09302526, -0.015403468, + -0.06614931, -0.051798206, -0.013874718, 0.03630673, + 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, + -0.020674974, -0.03944324, -0.008110165, -0.11113267, + 0.08484226, 0.043586485, 0.040582247, 0.0968012, + -0.065249965, -0.028036479, 0.0050708856, 0.0017462453, + 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, + -0.11768019, 0.085926116, -0.08251791, -0.045081906, + 0.0948852, 0.068401024, 0.024856757, 0.06978981, + -0.057309967, -0.012775832, -0.0032452994, 0.01977615, + -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }; + + cell_to_input_weights_ = { + 0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}; + + cell_to_forget_weights_ = { + -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}; + + cell_to_output_weights_ = { + 0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, + 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}; + + projection_weights_ = { + -0.009802181, 0.09401916, 0.0717386, -0.13895074, + 0.09641832, 0.060420845, 0.08539281, 0.054285463, + 0.061395317, 0.034448683, -0.042991187, 0.019801661, + -0.16840284, -0.015726732, -0.23041931, -0.024478018, + -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, + 0.16169067, 0.22465782, -0.03993472, -0.004017731, + 0.08633481, -0.28869787, 0.08682067, 0.17240396, + 0.014975425, 0.056431185, 0.031037588, 0.16702051, + 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, + 0.014777949, -0.20203483, 0.094781205, 0.19100232, + 0.13987629, -0.036132768, -0.06426278, -0.05108664, + 0.13221376, 0.009441198, -0.16715929, 0.15859416, + -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, + 0.0046453946, 0.050794356, 0.10770313, -0.20790008, + -0.07149004, -0.11425117, 0.008225835, -0.035802525, + 0.14374903, 0.15262283, 0.048710253, 0.1847461, + -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, + 0.016261552, 0.022461696, 0.12689082, -0.043589946, + -0.12035478, -0.08361797, -0.050666027, -0.1248618, + -0.1275799, -0.071875185, 0.07377272, 0.09944291, + -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, + 0.041546922, -0.20424393, 0.06907816, 0.050412357, + 0.00724631, 0.039827548, 0.12449835, 0.10747581, + 0.13708383, 0.09134148, -0.12617786, -0.06428341, + 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, + 0.022913318, -0.042050496, 0.16842307, -0.060597885, + 0.10531834, -0.06411776, -0.07451711, -0.03410368, + -0.13393489, 0.06534304, 0.003620307, 0.04490757, + 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, -0.07260381, 0.047201227, + -0.024575593, -0.036445823, 0.07155557, 0.009672501, + -0.02328883, 0.009533515, -0.03606021, -0.07421458, + -0.028082801, -0.2678904, -0.13221288, 0.18419984, + -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, + 0.14494097, -0.12522776, -0.098633975, -0.10766018, + -0.08317623, 0.08594209, 0.07749552, 0.039474737, + 0.1776665, -0.07409566, -0.0477268, 0.29323658, + 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, + 0.10034707, 0.045594677, 0.0635285, -0.0715442, + -0.089667566, -0.10811871, 0.00026344223, 0.08298446, + -0.009525053, 0.006585689, -0.24567553, -0.09450807, + 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, + 0.067035615, 0.19271925, -0.0032889997, -0.043264326, + 0.09663576, -0.057112187, -0.10100678, 0.0628376, + 0.04447668, 0.017961001, -0.10094388, -0.10190601, + 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, + -0.1947263, 0.02251204, 0.11216432, -0.10307853, + 0.17351969, -0.039091777, 0.08066188, -0.00561982, + 0.12633002, 0.11335965, -0.0088127935, -0.019777594, + 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, + -0.07468996, -0.0855457, 0.099339016, -0.07580735, + -0.13775392, 0.08434318, 0.08330512, -0.12131499, + 0.031935584, 0.09180414, -0.08876437, -0.08049874, + 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, + 0.04331237, 0.04299654, -0.036394123, -0.12915532, + 0.09793732, 0.07512415, -0.11319543, -0.032502122, + 0.15661901, 0.07671967, -0.005491124, -0.19379048, + -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, + -0.09334311, 0.15026465, -0.15493552, -0.057762887, + -0.11604192, -0.262013, -0.01391798, 0.012185008, + 0.11156489, -0.07483202, 0.06693364, -0.26151478, + 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, + 0.030635102, 0.010969227, 0.11109743, 0.010919218, + 0.027526086, 0.13519906, 0.01891392, -0.046839405, + -0.040167913, 0.017953383, -0.09700955, 0.0061885654, + -0.07000971, 0.026893595, -0.038844477, 0.14543656}; + + lstm_input_ = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0 + 0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1 + 0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2 + 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3 + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0 + 0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1 + 0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2 + 0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3 + }; + + lstm_golden_output_ = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, + -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, + -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, + 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, + -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, + -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, + 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, + 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, + 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, + -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, + -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, + -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, + 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, + 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, + -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, + -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, + 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, + 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, + 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, + -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, + -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + } +}; + +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} + +class BaseReduceOpModel : public SingleOpModelWithNNAPI { + public: + void SetAxis(const std::vector& data) { PopulateTensor(axis_, data); } + + template + void SetInput(std::vector data) { + PopulateTensor(input_, data); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } + + std::vector GetOutputShape() { return GetTensorShape(output_); } + + int Input() { return input_; } + + protected: + int input_; + int axis_; + int output_; +}; + +// Model for the tests case where axis is a const tensor. +class MeanOpConstModel : public BaseReduceOpModel { + public: + MeanOpConstModel(const TensorData& input, const TensorData& output, + std::initializer_list axis_shape, + std::initializer_list axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// Tests for reduce_mean +TEST(NNAPIDelegate, MeanFloatNotKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, + {4}, {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); +} + +TEST(NNAPIDelegate, MeanFloatKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, + {2}, {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); +} + +class BaseEmbeddingLookupOpModel : public SingleOpModelWithNNAPI { + public: + BaseEmbeddingLookupOpModel(std::initializer_list index_shape, + std::initializer_list weight_shape, + TensorType weight_type = TensorType_FLOAT32) { + input_ = AddInput(TensorType_INT32); + weight_ = AddInput(weight_type); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0); + BuildInterpreter({index_shape, weight_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int weight_; + int output_; +}; + +class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel { + public: + using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel; + + void Set3DWeightMatrix(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(weight_); + int rows = tensor->dims->data[0]; + int columns = tensor->dims->data[1]; + int features = tensor->dims->data[2]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < columns; j++) { + for (int k = 0; k < features; k++) { + tensor->data.f[(i * columns + j) * features + k] = function(i, j, k); + } + } + } + } +}; + +TEST(NNAPIDelegate, EmbeddingLookupSimpleTest) { + EmbeddingLookupOpModel m({3}, {3, 2, 4}); + m.SetInput({1, 0, 2}); + m.Set3DWeightMatrix( + [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }))); +} + +class HashtableLookupOpModel : public SingleOpModelWithNNAPI { + public: + HashtableLookupOpModel(std::initializer_list lookup_shape, + std::initializer_list key_shape, + std::initializer_list value_shape, + TensorType type) { + lookup_ = AddInput(TensorType_INT32); + key_ = AddInput(TensorType_INT32); + value_ = AddInput(type); + output_ = AddOutput(type); + hit_ = AddOutput(TensorType_UINT8); + SetBuiltinOp(BuiltinOperator_HASHTABLE_LOOKUP, BuiltinOptions_NONE, 0); + BuildInterpreter({lookup_shape, key_shape, value_shape}); + } + + void SetLookup(std::initializer_list data) { + PopulateTensor(lookup_, data); + } + + void SetHashtableKey(std::initializer_list data) { + PopulateTensor(key_, data); + } + + void SetHashtableValue(const std::vector& content) { + PopulateStringTensor(value_, content); + } + + void SetHashtableValue(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(value_); + int rows = tensor->dims->data[0]; + for (int i = 0; i < rows; i++) { + tensor->data.f[i] = function(i); + } + } + + void SetHashtableValue(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(value_); + int rows = tensor->dims->data[0]; + int features = tensor->dims->data[1]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < features; j++) { + tensor->data.f[i * features + j] = function(i, j); + } + } + } + + std::vector GetStringOutput() { + TfLiteTensor* output = interpreter_->tensor(output_); + int num = GetStringCount(output); + std::vector result(num); + for (int i = 0; i < num; i++) { + auto ref = GetString(output, i); + result[i] = string(ref.str, ref.len); + } + return result; + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetHit() { return ExtractVector(hit_); } + + private: + int lookup_; + int key_; + int value_; + int output_; + int hit_; +}; + +TEST(NNAPIDelegate, HashtableLookupTest2DInput) { + HashtableLookupOpModel m({4}, {3}, {3, 2}, TensorType_FLOAT32); + + m.SetLookup({1234, -292, -11, 0}); + m.SetHashtableKey({-11, 0, 1234}); + m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 2.0, 2.1, // 2-nd item + 0, 0, // Not found + 0.0, 0.1, // 0-th item + 1.0, 1.1, // 1-st item + }))); + EXPECT_THAT(m.GetHit(), ElementsAreArray({ + 1, + 0, + 1, + 1, + })); +} + +TEST(NNAPIDelegate, HashtableLookupTest1DInput) { + HashtableLookupOpModel m({4}, {3}, {3}, TensorType_FLOAT32); + + m.SetLookup({1234, -292, -11, 0}); + m.SetHashtableKey({-11, 0, 1234}); + m.SetHashtableValue([](int i) { return i * i / 10.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.4, // 2-nd item + 0, // Not found + 0.0, // 0-th item + 0.1, // 1-st item + }))); + EXPECT_THAT(m.GetHit(), ElementsAreArray({ + 1, + 0, + 1, + 1, + })); +} } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/error_reporter.cc index 03fcd5409ceab1895cea3b9e0e4fcb5a127e6a45..646913c0262c3483e999208651b5f0f872006cf6 100644 --- a/tensorflow/contrib/lite/error_reporter.cc +++ b/tensorflow/contrib/lite/error_reporter.cc @@ -16,6 +16,10 @@ limitations under the License. #include #include +#ifdef __ANDROID__ +#include +#endif + namespace tflite { ErrorReporter::~ErrorReporter() {} @@ -39,6 +43,15 @@ int ErrorReporter::ReportError(void*, const char* format, ...) { } int StderrReporter::Report(const char* format, va_list args) { +#ifdef __ANDROID__ + // On Android stderr is not captured for applications, only for code run from + // the shell. Rather than assume all users will set up a custom error + // reporter, let's output to logcat here + va_list args_for_log; + va_copy(args_for_log, args); + __android_log_vprint(ANDROID_LOG_ERROR, "tflite", format, args_for_log); + va_end(args_for_log); +#endif const int result = vfprintf(stderr, format, args); fputc('\n', stderr); return result; diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD index 57000072561303e8457f61b1ebe95d382fc01f10..4d2437e7d3714e1b8b427b0c6197b295c0355b07 100644 --- a/tensorflow/contrib/lite/examples/android/BUILD +++ b/tensorflow/contrib/lite/examples/android/BUILD @@ -1,6 +1,8 @@ # Description: # TensorFlow camera demo app for Android. +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 @@ -24,28 +26,29 @@ cc_library( android_binary( name = "tflite_demo", srcs = glob([ - "src/**/*.java", + "app/src/main/java/**/*.java", ]), # Package assets from assets dir as well as all model targets. # Remove undesired models (and corresponding Activities in source) # to reduce APK size. assets = [ - "//tensorflow/contrib/lite/examples/android/assets:labels_mobilenet_quant_v1_224.txt", + "//tensorflow/contrib/lite/examples/android/app/src/main/assets:labels_mobilenet_quant_v1_224.txt", "@tflite_mobilenet//:mobilenet_quant_v1_224.tflite", "@tflite_conv_actions_frozen//:conv_actions_frozen.tflite", - "//tensorflow/contrib/lite/examples/android/assets:conv_actions_labels.txt", + "//tensorflow/contrib/lite/examples/android/app/src/main/assets:conv_actions_labels.txt", "@tflite_mobilenet_ssd//:mobilenet_ssd.tflite", - "//tensorflow/contrib/lite/examples/android/assets:box_priors.txt", - "//tensorflow/contrib/lite/examples/android/assets:coco_labels_list.txt", + "@tflite_mobilenet_ssd_quant//:detect.tflite", + "//tensorflow/contrib/lite/examples/android/app/src/main/assets:box_priors.txt", + "//tensorflow/contrib/lite/examples/android/app/src/main/assets:coco_labels_list.txt", ], assets_dir = "", custom_package = "org.tensorflow.lite.demo", inline_constants = 1, - manifest = "AndroidManifest.xml", + manifest = "app/src/main/AndroidManifest.xml", nocompress_extensions = [ ".tflite", ], - resource_files = glob(["res/**"]), + resource_files = glob(["app/src/main/res/**"]), tags = [ "manual", "notap", @@ -55,31 +58,3 @@ android_binary( "//tensorflow/contrib/lite/java:tensorflowlite", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "bin/**", - "gen/**", - "gradleBuild/**", - "libs/**", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - -filegroup( - name = "java_files", - srcs = glob(["src/**/*.java"]), -) - -filegroup( - name = "resource_files", - srcs = glob(["res/**"]), -) - -exports_files(["AndroidManifest.xml"]) diff --git a/tensorflow/contrib/lite/examples/android/android.iml b/tensorflow/contrib/lite/examples/android/android.iml new file mode 100644 index 0000000000000000000000000000000000000000..f0a5ac2bf4cdfb7c98f5704310fbf2f16e9065a2 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/android.iml @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tensorflow/contrib/lite/examples/android/app/README.md b/tensorflow/contrib/lite/examples/android/app/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cbdeeac8790d93210a6c637953605b4ca270d3f6 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/app/README.md @@ -0,0 +1,19 @@ +# TF Lite Android App Example + +## Building from Source with Bazel + +1. Install [Bazel](https://docs.bazel.build/versions/master/install.html), the Android NDK and SDK. The recommended versions are specified on this [webpage](https://www.tensorflow.org/mobile/tflite/demo_android#build_tensorflow_lite_and_the_demo_app_from_source). + +2. Build this demo app with Bazel. The demo needs C++11. We configure the fat_apk_cpu flag to package support for 4 hardware variants. You may replace it with --config=android_arm64 on a 64-bit device and --config=android_arm for 32-bit device: + + ```shell + bazel build -c opt --cxxopt='--std=c++11' --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ + //tensorflow/contrib/lite/examples/android:tflite_demo + ``` + +3. Install the demo on a + [debug-enabled device](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install): + + ```shell + adb install bazel-bin/tensorflow/contrib/lite/examples/android/tflite_demo.apk + ``` diff --git a/tensorflow/contrib/lite/examples/android/app/build.gradle b/tensorflow/contrib/lite/examples/android/app/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..eb7fd705e18f53eb026600207faefa3d2bb072af --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/app/build.gradle @@ -0,0 +1,60 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 26 + buildToolsVersion '26.0.2' + defaultConfig { + applicationId "org.tensorflow.lite.demo" + minSdkVersion 15 + targetSdkVersion 26 + versionCode 1 + versionName "1.0" + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + + // Remove this block. + jackOptions { + enabled true + } + } + lintOptions { + abortOnError false + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "tflite" + } + + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +repositories { + maven { + url 'https://google.bintray.com/tensorflow' + } +} + +// import DownloadModels task +project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets' +project.ext.TMP_DIR = project.buildDir.toString() + '/downloads' + +// Download default models; if you wish to use your own models then +// place them in the "assets" directory and comment out this line. +apply from: "download-models.gradle" + +dependencies { + compile fileTree(dir: 'libs', include: ['*.jar']) + androidTestCompile('androidx.test.espresso:espresso-core:3.1.0-alpha3', { + exclude group: 'com.android.support', module: 'support-annotations' + }) + compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly' + + testCompile 'junit:junit:4.12' +} diff --git a/tensorflow/contrib/lite/examples/android/app/download-models.gradle b/tensorflow/contrib/lite/examples/android/app/download-models.gradle new file mode 100644 index 0000000000000000000000000000000000000000..c100e37c16f38a65f7b1f64a3f6e3eaa1477e8eb --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/app/download-models.gradle @@ -0,0 +1,74 @@ +/* + * download-models.gradle + * Downloads model files from ${MODEL_URL} into application's asset folder + * Input: + * project.ext.TMP_DIR: absolute path to hold downloaded zip files + * project.ext.ASSET_DIR: absolute path to save unzipped model files + * Output: + * 3 model files will be downloaded into given folder of ext.ASSET_DIR + */ +// hard coded model files +// LINT.IfChange + +def models = ['conv_actions_tflite.zip', + 'mobilenet_ssd_tflite_v1.zip', + 'mobilenet_v1_224_android_quant_2017_11_08.zip', + 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip'] +// LINT.ThenChange(//tensorflow/contrib/lite/examples/android/BUILD) + +// Root URL for model archives +def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models/tflite' + +buildscript { + repositories { + jcenter() + } + dependencies { + classpath 'de.undercouch:gradle-download-task:3.2.0' + } +} + +import de.undercouch.gradle.tasks.download.Download +task downloadFile(type: Download){ + for (f in models) { + def modelUrl = MODEL_URL + "/" + f + println "Downloading ${f} from ${modelUrl}" + src modelUrl + } + + dest new File(project.ext.TMP_DIR) + overwrite true +} + +task extractModels(type: Copy) { + for (f in models) { + def localFile = f.split("/")[-1] + from zipTree(project.ext.TMP_DIR + '/' + localFile) + } + + into file(project.ext.ASSET_DIR) + fileMode 0644 + exclude '**/LICENSE' + + def needDownload = false + for (f in models) { + def localFile = f.split("/")[-1] + if (!(new File(project.ext.TMP_DIR + '/' + localFile)).exists()) { + needDownload = true + } + } + + if (needDownload) { + dependsOn downloadFile + } +} + +tasks.whenTaskAdded { task -> + if (task.name == 'assembleDebug') { + task.dependsOn 'extractModels' + } + if (task.name == 'assembleRelease') { + task.dependsOn 'extractModels' + } +} + diff --git a/tensorflow/contrib/lite/examples/android/AndroidManifest.xml b/tensorflow/contrib/lite/examples/android/app/src/main/AndroidManifest.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/AndroidManifest.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/AndroidManifest.xml diff --git a/tensorflow/contrib/lite/examples/android/assets/BUILD b/tensorflow/contrib/lite/examples/android/app/src/main/assets/BUILD similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/BUILD rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/BUILD diff --git a/tensorflow/contrib/lite/examples/android/assets/box_priors.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/box_priors.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/box_priors.txt rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/box_priors.txt diff --git a/tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/coco_labels_list.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/coco_labels_list.txt diff --git a/tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/conv_actions_labels.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/conv_actions_labels.txt diff --git a/tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/labels_mobilenet_quant_v1_224.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/labels_mobilenet_quant_v1_224.txt diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/assets/pets_labels_list.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/pets_labels_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..d581f733e48ff8c2ba88162ee56b5e9d12aec7de --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/app/src/main/assets/pets_labels_list.txt @@ -0,0 +1,38 @@ +??? +Abyssinian +american_bulldog +american_pit_bull_terrier +basset_hound +beagle +Bengal +Birman +Bombay +boxer +British_Shorthair +chihuahua +Egyptian_Mau +english_cocker_spaniel +english_setter +german_shorthaired +great_pyrenees +havanese +japanese_chin +keeshond +leonberger +Maine_Coon +miniature_pinscher +newfoundland +Persian +pomeranian +pug +Ragdoll +Russian_Blue +saint_bernard +samoyed +scottish_terrier +shiba_inu +Siamese +Sphynx +staffordshire_bull_terrier +wheaten_terrier +yorkshire_terrier diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/AutoFitTextureView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/AutoFitTextureView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/AutoFitTextureView.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/AutoFitTextureView.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraActivity.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraActivity.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraActivity.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraConnectionFragment.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraConnectionFragment.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/Classifier.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/Classifier.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/Classifier.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/Classifier.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ClassifierActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ClassifierActivity.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/DetectorActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java similarity index 96% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/DetectorActivity.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java index de997e454a1e33254cb7c2c932ca79d0072539fa..87160f6b3fb8c0d24e5df131d9becbb3eb6e2980 100644 --- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/DetectorActivity.java +++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java @@ -1,5 +1,5 @@ /* - * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * Copyright 2018 The TensorFlow Authors. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,9 +50,10 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable // Configuration values for the prepackaged SSD model. private static final int TF_OD_API_INPUT_SIZE = 300; - private static final String TF_OD_API_MODEL_FILE = "mobilenet_ssd.tflite"; + private static final boolean TF_OD_API_IS_QUANTIZED = true; + private static final String TF_OD_API_MODEL_FILE = "detect.tflite"; private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/coco_labels_list.txt"; - + // Which detection model to use: by default uses Tensorflow Object Detection API frozen // checkpoints. private enum DetectorMode { @@ -107,7 +108,11 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable try { detector = TFLiteObjectDetectionAPIModel.create( - getAssets(), TF_OD_API_MODEL_FILE, TF_OD_API_LABELS_FILE, TF_OD_API_INPUT_SIZE); + getAssets(), + TF_OD_API_MODEL_FILE, + TF_OD_API_LABELS_FILE, + TF_OD_API_INPUT_SIZE, + TF_OD_API_IS_QUANTIZED); cropSize = TF_OD_API_INPUT_SIZE; } catch (final IOException e) { LOGGER.e("Exception initializing classifier!", e); diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/LegacyCameraConnectionFragment.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/LegacyCameraConnectionFragment.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/OverlayView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/OverlayView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/OverlayView.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/OverlayView.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognitionScoreView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognitionScoreView.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognizeCommands.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognizeCommands.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognizeCommands.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognizeCommands.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ResultsView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ResultsView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ResultsView.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ResultsView.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/SpeechActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/SpeechActivity.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/SpeechActivity.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/SpeechActivity.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteImageClassifier.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteImageClassifier.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteImageClassifier.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteImageClassifier.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java similarity index 50% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java index bfb4a0a04bc90566736864bf62340d1032961858..9eb21de9d03e387d3c25b38171e154a358dc81ce 100644 --- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java +++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java @@ -25,15 +25,14 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.util.ArrayList; -import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.PriorityQueue; -import java.util.StringTokenizer; import java.util.Vector; import org.tensorflow.demo.env.Logger; import org.tensorflow.lite.Interpreter; @@ -46,32 +45,35 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { private static final Logger LOGGER = new Logger(); // Only return this many results. - private static final int NUM_RESULTS = 1917; - private static final int NUM_CLASSES = 91; - - private static final float Y_SCALE = 10.0f; - private static final float X_SCALE = 10.0f; - private static final float H_SCALE = 5.0f; - private static final float W_SCALE = 5.0f; - + private static final int NUM_DETECTIONS = 10; + private boolean isModelQuantized; + // Float model + private static final float IMAGE_MEAN = 128.0f; + private static final float IMAGE_STD = 128.0f; + // Number of threads in the java app + private static final int NUM_THREADS = 4; // Config values. private int inputSize; - - private final float[][] boxPriors = new float[4][NUM_RESULTS]; - // Pre-allocated buffers. private Vector labels = new Vector(); private int[] intValues; + // outputLocations: array of shape [Batchsize, NUM_DETECTIONS,4] + // contains the location of detected boxes private float[][][] outputLocations; - private float[][][] outputClasses; - - float[][][][] img; + // outputClasses: array of shape [Batchsize, NUM_DETECTIONS] + // contains the classes of detected boxes + private float[][] outputClasses; + // outputScores: array of shape [Batchsize, NUM_DETECTIONS] + // contains the scores of detected boxes + private float[][] outputScores; + // numDetections: array of shape [Batchsize] + // contains the number of detected boxes + private float[] numDetections; + + private ByteBuffer imgData; private Interpreter tfLite; - private float expit(final float x) { - return (float) (1. / (1. + Math.exp(-x))); - } /** Memory-map the model file in Assets. */ private static MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename) @@ -84,77 +86,24 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); } - private void loadCoderOptions( - final AssetManager assetManager, final String locationFilename, final float[][] boxPriors) - throws IOException { - // Try to be intelligent about opening from assets or sdcard depending on prefix. - final String assetPrefix = "file:///android_asset/"; - InputStream is; - if (locationFilename.startsWith(assetPrefix)) { - is = assetManager.open(locationFilename.split(assetPrefix, -1)[1]); - } else { - is = new FileInputStream(locationFilename); - } - - final BufferedReader reader = new BufferedReader(new InputStreamReader(is)); - - for (int lineNum = 0; lineNum < 4; ++lineNum) { - String line = reader.readLine(); - final StringTokenizer st = new StringTokenizer(line, ", "); - int priorIndex = 0; - while (st.hasMoreTokens()) { - final String token = st.nextToken(); - try { - final float number = Float.parseFloat(token); - boxPriors[lineNum][priorIndex++] = number; - } catch (final NumberFormatException e) { - // Silently ignore. - } - } - if (priorIndex != NUM_RESULTS) { - throw new RuntimeException( - "BoxPrior length mismatch: " + priorIndex + " vs " + NUM_RESULTS); - } - } - - LOGGER.i("Loaded box priors!"); - } - - void decodeCenterSizeBoxes(float[][][] predictions) { - for (int i = 0; i < NUM_RESULTS; ++i) { - float ycenter = predictions[0][i][0] / Y_SCALE * boxPriors[2][i] + boxPriors[0][i]; - float xcenter = predictions[0][i][1] / X_SCALE * boxPriors[3][i] + boxPriors[1][i]; - float h = (float) Math.exp(predictions[0][i][2] / H_SCALE) * boxPriors[2][i]; - float w = (float) Math.exp(predictions[0][i][3] / W_SCALE) * boxPriors[3][i]; - - float ymin = ycenter - h / 2.f; - float xmin = xcenter - w / 2.f; - float ymax = ycenter + h / 2.f; - float xmax = xcenter + w / 2.f; - - predictions[0][i][0] = ymin; - predictions[0][i][1] = xmin; - predictions[0][i][2] = ymax; - predictions[0][i][3] = xmax; - } - } - /** * Initializes a native TensorFlow session for classifying images. * * @param assetManager The asset manager to be used to load assets. * @param modelFilename The filepath of the model GraphDef protocol buffer. * @param labelFilename The filepath of label file for classes. + * @param inputSize The size of image input + * @param isQuantized Boolean representing model is quantized or not */ public static Classifier create( final AssetManager assetManager, final String modelFilename, final String labelFilename, - final int inputSize) throws IOException { + final int inputSize, + final boolean isQuantized) + throws IOException { final TFLiteObjectDetectionAPIModel d = new TFLiteObjectDetectionAPIModel(); - d.loadCoderOptions(assetManager, "file:///android_asset/box_priors.txt", d.boxPriors); - InputStream labelsInput = null; String actualFilename = labelFilename.split("file:///android_asset/")[1]; labelsInput = assetManager.open(actualFilename); @@ -175,12 +124,23 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { throw new RuntimeException(e); } + d.isModelQuantized = isQuantized; // Pre-allocate buffers. - d.img = new float[1][inputSize][inputSize][3]; - + int numBytesPerChannel; + if (isQuantized) { + numBytesPerChannel = 1; // Quantized + } else { + numBytesPerChannel = 4; // Floating point + } + d.imgData = ByteBuffer.allocateDirect(1 * d.inputSize * d.inputSize * 3 * numBytesPerChannel); + d.imgData.order(ByteOrder.nativeOrder()); d.intValues = new int[d.inputSize * d.inputSize]; - d.outputLocations = new float[1][NUM_RESULTS][4]; - d.outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES]; + + d.tfLite.setNumThreads(NUM_THREADS); + d.outputLocations = new float[1][NUM_DETECTIONS][4]; + d.outputClasses = new float[1][NUM_DETECTIONS]; + d.outputScores = new float[1][NUM_DETECTIONS]; + d.numDetections = new float[1]; return d; } @@ -196,25 +156,37 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { // on the provided parameters. bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); + imgData.rewind(); for (int i = 0; i < inputSize; ++i) { for (int j = 0; j < inputSize; ++j) { - int pixel = intValues[j * inputSize + i]; - img[0][j][i][2] = (float) (pixel & 0xFF) / 128.0f - 1.0f; - img[0][j][i][1] = (float) ((pixel >> 8) & 0xFF) / 128.0f - 1.0f; - img[0][j][i][0] = (float) ((pixel >> 16) & 0xFF) / 128.0f - 1.0f; + int pixelValue = intValues[i * inputSize + j]; + if (isModelQuantized) { + // Quantized model + imgData.put((byte) ((pixelValue >> 16) & 0xFF)); + imgData.put((byte) ((pixelValue >> 8) & 0xFF)); + imgData.put((byte) (pixelValue & 0xFF)); + } else { // Float model + imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + } } } Trace.endSection(); // preprocessBitmap // Copy the input data into TensorFlow. Trace.beginSection("feed"); - outputLocations = new float[1][NUM_RESULTS][4]; - outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES]; + outputLocations = new float[1][NUM_DETECTIONS][4]; + outputClasses = new float[1][NUM_DETECTIONS]; + outputScores = new float[1][NUM_DETECTIONS]; + numDetections = new float[1]; - Object[] inputArray = {img}; + Object[] inputArray = {imgData}; Map outputMap = new HashMap<>(); outputMap.put(0, outputLocations); outputMap.put(1, outputClasses); + outputMap.put(2, outputScores); + outputMap.put(3, numDetections); Trace.endSection(); // Run the inference call. @@ -222,56 +194,26 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { tfLite.runForMultipleInputsOutputs(inputArray, outputMap); Trace.endSection(); - decodeCenterSizeBoxes(outputLocations); - - // Find the best detections. - final PriorityQueue pq = - new PriorityQueue( - 1, - new Comparator() { - @Override - public int compare(final Recognition lhs, final Recognition rhs) { - // Intentionally reversed to put high confidence at the head of the queue. - return Float.compare(rhs.getConfidence(), lhs.getConfidence()); - } - }); - - // Scale them back to the input size. - for (int i = 0; i < NUM_RESULTS; ++i) { - float topClassScore = -1000f; - int topClassScoreIndex = -1; - - // Skip the first catch-all class. - for (int j = 1; j < NUM_CLASSES; ++j) { - float score = expit(outputClasses[0][i][j]); - - if (score > topClassScore) { - topClassScoreIndex = j; - topClassScore = score; - } - } - - if (topClassScore > 0.001f) { - final RectF detection = - new RectF( - outputLocations[0][i][1] * inputSize, - outputLocations[0][i][0] * inputSize, - outputLocations[0][i][3] * inputSize, - outputLocations[0][i][2] * inputSize); - - pq.add( - new Recognition( - "" + i, - labels.get(topClassScoreIndex), - outputClasses[0][i][topClassScoreIndex], - detection)); - } - } - - final ArrayList recognitions = new ArrayList(); - for (int i = 0; i < Math.min(pq.size(), 10); ++i) { - Recognition recog = pq.poll(); - recognitions.add(recog); + // Show the best detections. + // after scaling them back to the input size. + final ArrayList recognitions = new ArrayList<>(NUM_DETECTIONS); + for (int i = 0; i < NUM_DETECTIONS; ++i) { + final RectF detection = + new RectF( + outputLocations[0][i][1] * inputSize, + outputLocations[0][i][0] * inputSize, + outputLocations[0][i][3] * inputSize, + outputLocations[0][i][2] * inputSize); + // SSD Mobilenet V1 Model assumes class 0 is background class + // in label file and class labels start from 1 to number_of_classes+1, + // while outputClasses correspond to class index from 0 to number_of_classes + int labelOffset = 1; + recognitions.add( + new Recognition( + "" + i, + labels.get((int) outputClasses[0][i] + labelOffset), + outputScores[0][i], + detection)); } Trace.endSection(); // "recognizeImage" return recognitions; diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/AssetUtils.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/AssetUtils.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/AssetUtils.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/AssetUtils.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/BorderedText.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/BorderedText.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/BorderedText.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/BorderedText.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/ImageUtils.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/ImageUtils.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/ImageUtils.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/ImageUtils.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Logger.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Logger.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Logger.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Logger.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Size.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Size.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Size.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Size.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/SplitTimer.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/SplitTimer.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/SplitTimer.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/SplitTimer.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/MultiBoxTracker.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/MultiBoxTracker.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/ObjectTracker.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/ObjectTracker.java diff --git a/tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/animator/color_animation.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/animator/color_animation.xml diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/tile.9.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/tile.9.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable/border.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable/border.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable/border.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable/border.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_camera.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_camera.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_speech.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_speech.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_stylize.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_stylize.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_stylize.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_stylize.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_tracking.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_tracking.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_tracking.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_tracking.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/list_text_item.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/list_text_item.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/list_text_item.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/list_text_item.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-dimens.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-dimens.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-dimens.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-dimens.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v11/styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v11/styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v11/template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v11/template-styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v14/styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v14/styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v14/styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v14/styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v21/base-colors.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-colors.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v21/base-colors.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-colors.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v21/base-template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v21/base-template-styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/attrs.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/attrs.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/attrs.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/attrs.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/base-strings.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/base-strings.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/base-strings.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/base-strings.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/colors.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/colors.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/colors.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/colors.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/strings.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/strings.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/strings.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/strings.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/template-dimens.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-dimens.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/template-dimens.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-dimens.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/template-styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/build.gradle b/tensorflow/contrib/lite/examples/android/build.gradle index 0d4de358156a5d139e35cc542b8d36ab24e763b9..a47fa4bbf6730c7d1269737564381c8464224713 100644 --- a/tensorflow/contrib/lite/examples/android/build.gradle +++ b/tensorflow/contrib/lite/examples/android/build.gradle @@ -1,52 +1,23 @@ -apply plugin: 'com.android.application' +// Top-level build file where you can add configuration options common to all sub-projects/modules. -android { - compileSdkVersion 26 - buildToolsVersion "26.0.1" - defaultConfig { - applicationId "org.tensorflow.lite.demo" - minSdkVersion 15 - targetSdkVersion 26 - versionCode 1 - versionName "1.0" - testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" - - // Remove this block. - jackOptions { - enabled true - } - } - lintOptions { - abortOnError false - } - buildTypes { - release { - minifyEnabled false - proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' - } - } - aaptOptions { - noCompress "tflite" +buildscript { + repositories { + jcenter() } + dependencies { + classpath 'com.android.tools.build:gradle:3.0.1' - compileOptions { - sourceCompatibility JavaVersion.VERSION_1_8 - targetCompatibility JavaVersion.VERSION_1_8 + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files } } -repositories { - maven { - url 'https://google.bintray.com/tensorflow' +allprojects { + repositories { + jcenter() } } -dependencies { - compile fileTree(dir: 'libs', include: ['*.jar']) - androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { - exclude group: 'com.android.support', module: 'support-annotations' - }) - compile 'org.tensorflow:tensorflow-lite:+' - - testCompile 'junit:junit:4.12' +task clean(type: Delete) { + delete rootProject.buildDir } diff --git a/tensorflow/contrib/lite/examples/android/settings.gradle b/tensorflow/contrib/lite/examples/android/settings.gradle new file mode 100644 index 0000000000000000000000000000000000000000..e7b4def49cb53d9aa04228dd3edb14c9e635e003 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/settings.gradle @@ -0,0 +1 @@ +include ':app' diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm index d74e275f0439b1ce56b29e0eadff5f211f6a4faa..734b15e0a10bfbd485b0a0a89296b27546ea5f40 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm @@ -26,7 +26,7 @@ #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" +#include "tensorflow/contrib/lite/op_resolver.h" #define LOG(x) std::cerr @@ -315,7 +315,7 @@ static void GetTopN(const uint8_t* prediction, const int prediction_size, const labelLayers = [[NSMutableArray alloc] init]; oldPredictionValues = [[NSMutableDictionary alloc] init]; - NSString* graph_path = FilePathForResourceName(model_file_name, @"tflite"); + NSString* graph_path = FilePathForResourceName(model_file_name, model_file_type); model = tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String]); if (!model) { LOG(FATAL) << "Failed to mmap model " << graph_path; diff --git a/tensorflow/contrib/lite/examples/ios/camera/Podfile b/tensorflow/contrib/lite/examples/ios/camera/Podfile index c7d3b1c966eaa0de71f5c37a6a77b3881e30ddd7..8084307ac794c3cb114270c3c3a08a73db0ef359 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/Podfile +++ b/tensorflow/contrib/lite/examples/ios/camera/Podfile @@ -2,4 +2,4 @@ platform :ios, '8.0' inhibit_all_warnings! target 'tflite_camera_example' - pod 'TensorFlowLite' + pod 'TensorFlowLite', '1.10.0' diff --git a/tensorflow/contrib/lite/examples/ios/simple/Podfile b/tensorflow/contrib/lite/examples/ios/simple/Podfile index e4aca2be82d437a0225d2c15d3e486b0344aa978..eea7ecb759688a9a919dade58f97d1a141a3ddeb 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/Podfile +++ b/tensorflow/contrib/lite/examples/ios/simple/Podfile @@ -2,4 +2,4 @@ platform :ios, '8.0' inhibit_all_warnings! target 'tflite_simple_example' - pod 'TensorFlowLite' + pod 'TensorFlowLite', '1.10.0' diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm index 0ab7aa25d0b4e6d2c02e61ec1d82b85258b3dfbc..650c73f7322c3169e60231ce52e86d2cdc86d0a4 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm @@ -25,7 +25,7 @@ #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" +#include "tensorflow/contrib/lite/op_resolver.h" #include "ios_image_load.h" diff --git a/tensorflow/contrib/lite/examples/label_image/BUILD b/tensorflow/contrib/lite/examples/label_image/BUILD index c61445114ecc6dfbe4f2b6ab666b28a8aa746be3..fc55a78019b4a12b24231034a7e4b912869389f2 100644 --- a/tensorflow/contrib/lite/examples/label_image/BUILD +++ b/tensorflow/contrib/lite/examples/label_image/BUILD @@ -63,6 +63,7 @@ cc_test( data = [ "testdata/grace_hopper.bmp", ], + tags = ["no_oss"], deps = [ ":bitmap_helpers", "@com_google_googletest//:gtest", diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h index e36218e4f12057a362af47c48454f7930fc495f2..6fdcf78b69c6799fc2e666af1150efb88b55ff5c 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h @@ -16,11 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ #define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/contrib/lite/examples/label_image/label_image.h" #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/interpreter.h" @@ -28,8 +24,6 @@ limitations under the License. #include "tensorflow/contrib/lite/string_util.h" #include "tensorflow/contrib/lite/version.h" -#include "tensorflow/contrib/lite/examples/label_image/label_image.h" - namespace tflite { namespace label_image { diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc index 86d7d1cc4a625243791d5e7d5b746526a58efb6d..7c6f523041ad5a516f348c1b4f66683128838228 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc @@ -213,22 +213,23 @@ void RunInference(Settings* s) { } } - const int output_size = 1000; - const size_t num_results = 5; const float threshold = 0.001f; std::vector> top_results; int output = interpreter->outputs()[0]; + TfLiteIntArray* output_dims = interpreter->tensor(output)->dims; + // assume output dims to be something like (1, 1, ... ,size) + auto output_size = output_dims->data[output_dims->size - 1]; switch (interpreter->tensor(output)->type) { case kTfLiteFloat32: get_top_n(interpreter->typed_output_tensor(0), output_size, - num_results, threshold, &top_results, true); + s->number_of_results, threshold, &top_results, true); break; case kTfLiteUInt8: get_top_n(interpreter->typed_output_tensor(0), - output_size, num_results, threshold, &top_results, - false); + output_size, s->number_of_results, threshold, + &top_results, false); break; default: LOG(FATAL) << "cannot handle output type " @@ -259,6 +260,7 @@ void display_usage() { << "--labels, -l: labels for the model\n" << "--tflite_model, -m: model_name.tflite\n" << "--profiling, -p: [0|1], profiling or not\n" + << "--num_results, -r: number of results to show\n" << "--threads, -t: number of threads\n" << "--verbose, -v: [0|1] print more information\n" << "\n"; @@ -280,12 +282,13 @@ int Main(int argc, char** argv) { {"threads", required_argument, nullptr, 't'}, {"input_mean", required_argument, nullptr, 'b'}, {"input_std", required_argument, nullptr, 's'}, + {"num_results", required_argument, nullptr, 'r'}, {nullptr, 0, nullptr, 0}}; /* getopt_long stores the option index here. */ int option_index = 0; - c = getopt_long(argc, argv, "a:b:c:f:i:l:m:p:s:t:v:", long_options, + c = getopt_long(argc, argv, "a:b:c:f:i:l:m:p:r:s:t:v:", long_options, &option_index); /* Detect the end of the options. */ @@ -315,6 +318,10 @@ int Main(int argc, char** argv) { s.profiling = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; + case 'r': + s.number_of_results = + strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) + break; case 's': s.input_std = strtod(optarg, nullptr); break; diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.h b/tensorflow/contrib/lite/examples/label_image/label_image.h index 4b48014e1c77eca1eca081f0fe906441a5dcce22..34c223f713b9fe7692440a6b7538f00be995ad11 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.h +++ b/tensorflow/contrib/lite/examples/label_image/label_image.h @@ -34,6 +34,7 @@ struct Settings { string labels_file_name = "./labels.txt"; string input_layer_type = "uint8_t"; int number_of_threads = 4; + int number_of_results = 5; }; } // namespace label_image diff --git a/tensorflow/contrib/lite/examples/minimal/BUILD b/tensorflow/contrib/lite/examples/minimal/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..b403628d6c457ce3fb67eac3675fd7bb9187deab --- /dev/null +++ b/tensorflow/contrib/lite/examples/minimal/BUILD @@ -0,0 +1,27 @@ +# Description: +# TensorFlow Lite minimal example. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts") + +tf_cc_binary( + name = "minimal", + srcs = [ + "minimal.cc", + ], + linkopts = tflite_linkopts() + select({ + "//tensorflow:android": [ + "-pie", # Android 5.0 and later supports only PIE + "-lm", # some builtin ops, e.g., tanh, need -lm + ], + "//conditions:default": [], + }), + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], +) diff --git a/tensorflow/contrib/lite/examples/minimal/minimal.cc b/tensorflow/contrib/lite/examples/minimal/minimal.cc index 8b0ace96ccaf06ac1cbdc2ea95ac6e92ef886993..8b65cde7b79fde19280ad778ea874c64b01d169a 100644 --- a/tensorflow/contrib/lite/examples/minimal/minimal.cc +++ b/tensorflow/contrib/lite/examples/minimal/minimal.cc @@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/model.h" +#include #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" -#include +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/optional_debug_tools.h" // This is an example that is minimal to read a model // from disk and perform inference. There is no data being loaded @@ -29,14 +30,13 @@ limitations under the License. using namespace tflite; -#define TFLITE_MINIMAL_CHECK(x) \ - if(!(x)) { \ - fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \ - exit(1); \ +#define TFLITE_MINIMAL_CHECK(x) \ + if (!(x)) { \ + fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \ + exit(1); \ } - -int main(int argc, char *argv[]) { +int main(int argc, char* argv[]) { if(argc != 2) { fprintf(stderr, "minimal \n"); return 1; @@ -44,8 +44,8 @@ int main(int argc, char *argv[]) { const char* filename = argv[1]; // Load model - std::unique_ptr model - = tflite::FlatBufferModel::BuildFromFile(filename); + std::unique_ptr model = + tflite::FlatBufferModel::BuildFromFile(filename); TFLITE_MINIMAL_CHECK(model != nullptr); // Build the interpreter @@ -57,12 +57,16 @@ int main(int argc, char *argv[]) { // Allocate tensor buffers. TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk); + printf("=== Pre-invoke Interpreter State ===\n"); + tflite::PrintInterpreterState(interpreter.get()); // Fill input buffers // TODO(user): Insert code to fill input tensors // Run inference TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk); + printf("\n\n=== Post-invoke Interpreter State ===\n"); + tflite::PrintInterpreterState(interpreter.get()); // Read output buffers // TODO(user): Insert getting data out code. diff --git a/tensorflow/contrib/lite/examples/python/BUILD b/tensorflow/contrib/lite/examples/python/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..d337c3ddc43a23e50a5afdab93b16c0f61ccd538 --- /dev/null +++ b/tensorflow/contrib/lite/examples/python/BUILD @@ -0,0 +1,13 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +py_binary( + name = "label_image", + srcs = ["label_image.py"], + main = "label_image.py", + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/lite/python:lite", + ], +) diff --git a/tensorflow/contrib/lite/examples/python/label_image.md b/tensorflow/contrib/lite/examples/python/label_image.md new file mode 100644 index 0000000000000000000000000000000000000000..e81192a96c142f2b3e7e85d160166fdd37ccdc53 --- /dev/null +++ b/tensorflow/contrib/lite/examples/python/label_image.md @@ -0,0 +1,50 @@ + +With model, input image (grace_hopper.bmp), and labels file (labels.txt) +in /tmp. + +The example input image and labels file are from TensorFlow repo and +MobileNet V1 model files. + +``` +curl https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp > /tmp/grace_hopper.bmp + +curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz | tar xzv -C /tmp mobilenet_v1_1.0_224/labels.txt +mv /tmp/mobilenet_v1_1.0_224/labels.txt /tmp/ + +``` + +Run + +``` +curl http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224_quant.tgz | tar xzv -C /tmp +bazel run --config opt //tensorflow/contrib/lite/examples/python:label_image +``` + +We can get results like + +``` +0.470588: military uniform +0.337255: Windsor tie +0.047059: bow tie +0.031373: mortarboard +0.019608: suit +``` + +Run + +``` +curl http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz | tar xzv -C /tmp +bazel run --config opt //tensorflow/contrib/lite/examples/python:label_image \ +-- --model_file /tmp/mobilenet_v1_1.0_224.tflite +``` + +We can get results like +``` +0.728693: military uniform +0.116163: Windsor tie +0.035517: bow tie +0.014874: mortarboard +0.011758: bolo tie +``` + +Check [models](../../g3doc/models.md) for models hosted by Google. diff --git a/tensorflow/contrib/lite/examples/python/label_image.py b/tensorflow/contrib/lite/examples/python/label_image.py new file mode 100644 index 0000000000000000000000000000000000000000..282118a1d2b43a08930b24366110a021fc634b5e --- /dev/null +++ b/tensorflow/contrib/lite/examples/python/label_image.py @@ -0,0 +1,86 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""label_image for tflite""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import numpy as np + +from PIL import Image + +from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper + +def load_labels(filename): + my_labels = [] + input_file = open(filename, 'r') + for l in input_file: + my_labels.append(l.strip()) + return my_labels + +if __name__ == "__main__": + floating_model = False + + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--image", default="/tmp/grace_hopper.bmp", \ + help="image to be classified") + parser.add_argument("-m", "--model_file", \ + default="/tmp/mobilenet_v1_1.0_224_quant.tflite", \ + help=".tflite model to be executed") + parser.add_argument("-l", "--label_file", default="/tmp/labels.txt", \ + help="name of file containing labels") + parser.add_argument("--input_mean", default=127.5, help="input_mean") + parser.add_argument("--input_std", default=127.5, \ + help="input standard deviation") + args = parser.parse_args() + + interpreter = interpreter_wrapper.Interpreter(model_path=args.model_file) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + # check the type of the input tensor + if input_details[0]['dtype'] == np.float32: + floating_model = True + + # NxHxWxC, H:1, W:2 + height = input_details[0]['shape'][1] + width = input_details[0]['shape'][2] + img = Image.open(args.image) + img = img.resize((width, height)) + + # add N dim + input_data = np.expand_dims(img, axis=0) + + if floating_model: + input_data = (np.float32(input_data) - args.input_mean) / args.input_std + + interpreter.set_tensor(input_details[0]['index'], input_data) + + interpreter.invoke() + + output_data = interpreter.get_tensor(output_details[0]['index']) + results = np.squeeze(output_data) + + top_k = results.argsort()[-5:][::-1] + labels = load_labels(args.label_file) + for i in top_k: + if floating_model: + print('{0:08.6f}'.format(float(results[i]))+":", labels[i]) + else: + print('{0:08.6f}'.format(float(results[i]/255.0))+":", labels[i]) diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..8fc07e8eb7eb1b53cc94eed75093c49c29679d77 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/BUILD @@ -0,0 +1,98 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/contrib/lite:build_def.bzl", + "tflite_cc_shared_object", + "tflite_copts", + "tflite_jni_binary", +) + +tflite_cc_shared_object( + name = "libtensorflowlite_c.so", + linkopts = select({ + "//tensorflow:darwin": [ + "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file + "$(location //tensorflow/contrib/lite/experimental/c:exported_symbols.lds)", + "-Wl,-install_name,@rpath/libtensorflowlite_c.so", + ], + "//tensorflow:windows": [], + "//conditions:default": [ + "-z defs", + "-Wl,--version-script", # This line must be directly followed by the version_script.lds file + "$(location //tensorflow/contrib/lite/experimental/c:version_script.lds)", + ], + }), + deps = [ + ":c_api", + ":c_api_experimental", + ":exported_symbols.lds", + ":version_script.lds", + ], +) + +cc_library( + name = "c_api_internal", + srcs = ["c_api.h"], + hdrs = ["c_api_internal.h"], + copts = tflite_copts(), + visibility = [ + "//tensorflow/contrib/lite/experimental/c:__subpackages__", + ], + deps = [ + "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite:framework", + ], +) + +cc_library( + name = "c_api", + srcs = ["c_api.cc"], + hdrs = ["c_api.h"], + copts = tflite_copts(), + deps = [ + ":c_api_internal", + "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], +) + +cc_library( + name = "c_api_experimental", + srcs = ["c_api_experimental.cc"], + hdrs = ["c_api_experimental.h"], + copts = tflite_copts(), + deps = [ + ":c_api", + ":c_api_internal", + ], +) + +cc_test( + name = "c_api_test", + size = "small", + srcs = ["c_api_test.cc"], + data = ["//tensorflow/contrib/lite:testdata/add.bin"], + deps = [ + ":c_api", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "c_api_experimental_test", + size = "small", + srcs = ["c_api_experimental_test.cc"], + data = ["//tensorflow/contrib/lite:testdata/add.bin"], + deps = [ + ":c_api", + ":c_api_experimental", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4ab0e8c306b5b1e514e1ddf0c166ba0b43d75d1 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/c_api.cc @@ -0,0 +1,154 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/c/c_api.h" + +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/experimental/c/c_api_internal.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// LINT.IfChange + +TFL_Model* TFL_NewModel(const void* model_data, size_t model_size) { + auto model = tflite::FlatBufferModel::BuildFromBuffer( + static_cast(model_data), model_size); + return model ? new TFL_Model{std::move(model)} : nullptr; +} + +TFL_Model* TFL_NewModelFromFile(const char* model_path) { + auto model = tflite::FlatBufferModel::BuildFromFile(model_path); + return model ? new TFL_Model{std::move(model)} : nullptr; +} + +void TFL_DeleteModel(TFL_Model* model) { delete model; } + +TFL_InterpreterOptions* TFL_NewInterpreterOptions() { + return new TFL_InterpreterOptions{}; +} + +void TFL_DeleteInterpreterOptions(TFL_InterpreterOptions* options) { + delete options; +} + +void TFL_InterpreterOptionsSetNumThreads(TFL_InterpreterOptions* options, + int32_t num_threads) { + options->num_threads = num_threads; +} + +TFL_Interpreter* TFL_NewInterpreter( + const TFL_Model* model, const TFL_InterpreterOptions* optional_options) { + if (!model || !model->impl) { + return nullptr; + } + + tflite::ops::builtin::BuiltinOpResolver resolver; + tflite::InterpreterBuilder builder(*model->impl, resolver); + std::unique_ptr interpreter; + if (builder(&interpreter) != kTfLiteOk) { + return nullptr; + } + + if (optional_options) { + if (optional_options->num_threads != + TFL_InterpreterOptions::kDefaultNumThreads) { + interpreter->SetNumThreads(optional_options->num_threads); + } + } + + return new TFL_Interpreter{std::move(interpreter)}; +} + +void TFL_DeleteInterpreter(TFL_Interpreter* interpreter) { delete interpreter; } + +int32_t TFL_InterpreterGetInputTensorCount(const TFL_Interpreter* interpreter) { + return static_cast(interpreter->impl->inputs().size()); +} + +TFL_Tensor* TFL_InterpreterGetInputTensor(const TFL_Interpreter* interpreter, + int32_t input_index) { + return interpreter->impl->tensor(interpreter->impl->inputs()[input_index]); +} + +TFL_Status TFL_InterpreterResizeInputTensor(TFL_Interpreter* interpreter, + int32_t input_index, + const int* input_dims, + int32_t input_dims_size) { + std::vector dims{input_dims, input_dims + input_dims_size}; + return interpreter->impl->ResizeInputTensor( + interpreter->impl->inputs()[input_index], dims); +} + +TFL_Status TFL_InterpreterAllocateTensors(TFL_Interpreter* interpreter) { + return interpreter->impl->AllocateTensors(); +} + +TFL_Status TFL_InterpreterInvoke(TFL_Interpreter* interpreter) { + return interpreter->impl->Invoke(); +} + +int32_t TFL_InterpreterGetOutputTensorCount( + const TFL_Interpreter* interpreter) { + return static_cast(interpreter->impl->outputs().size()); +} + +const TFL_Tensor* TFL_InterpreterGetOutputTensor( + const TFL_Interpreter* interpreter, int32_t output_index) { + return interpreter->impl->tensor(interpreter->impl->outputs()[output_index]); +} + +TFL_Type TFL_TensorType(const TFL_Tensor* tensor) { return tensor->type; } + +int32_t TFL_TensorNumDims(const TFL_Tensor* tensor) { + return tensor->dims->size; +} + +int32_t TFL_TensorDim(const TFL_Tensor* tensor, int32_t dim_index) { + return tensor->dims->data[dim_index]; +} + +size_t TFL_TensorByteSize(const TFL_Tensor* tensor) { return tensor->bytes; } + +void* TFL_TensorData(const TFL_Tensor* tensor) { + return static_cast(tensor->data.raw); +} + +TFL_Status TFL_TensorCopyFromBuffer(TFL_Tensor* tensor, const void* input_data, + size_t input_data_size) { + if (tensor->bytes != input_data_size) { + return kTfLiteError; + } + memcpy(tensor->data.raw, input_data, input_data_size); + return kTfLiteOk; +} + +TFL_Status TFL_TensorCopyToBuffer(const TFL_Tensor* tensor, void* output_data, + size_t output_data_size) { + if (tensor->bytes != output_data_size) { + return kTfLiteError; + } + memcpy(output_data, tensor->data.raw, output_data_size); + return kTfLiteOk; +} + +// LINT.ThenChange(//tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs) + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/tensorflow/contrib/lite/experimental/c/c_api.h b/tensorflow/contrib/lite/experimental/c/c_api.h new file mode 100644 index 0000000000000000000000000000000000000000..3757349b5510ea3c3ac876b50b5c8c7db14688c9 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/c_api.h @@ -0,0 +1,196 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_H_ + +#include + +// Eventually the various C APIs defined in context.h will be migrated into +// the appropriate /c/c_api*.h header. For now, we pull in existing definitions +// for convenience. +#include "tensorflow/contrib/lite/context.h" + +// -------------------------------------------------------------------------- +// Experimental C API for TensorFlowLite. +// +// The API leans towards simplicity and uniformity instead of convenience, as +// most usage will be by language-specific wrappers. +// +// Conventions: +// * We use the prefix TFL_ for everything in the API. +// * size_t is used to represent byte sizes of objects that are +// materialized in the address space of the calling process. +// * int is used as an index into arrays. + +#ifdef SWIG +#define TFL_CAPI_EXPORT +#else +#if defined(_WIN32) +#ifdef TF_COMPILE_LIBRARY +#define TFL_CAPI_EXPORT __declspec(dllexport) +#else +#define TFL_CAPI_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TFL_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 +#endif // SWIG + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef TfLiteTensor TFL_Tensor; +typedef TfLiteStatus TFL_Status; +typedef TfLiteType TFL_Type; + +// -------------------------------------------------------------------------- +// TFL_Model wraps a loaded TensorFlow Lite model. +typedef struct TFL_Model TFL_Model; + +// Returns a model from the provided buffer, or null on failure. +TFL_CAPI_EXPORT extern TFL_Model* TFL_NewModel(const void* model_data, + size_t model_size); + +// Returns a model from the provided file, or null on failure. +TFL_CAPI_EXPORT extern TFL_Model* TFL_NewModelFromFile(const char* model_path); + +// Destroys the model instance. +TFL_CAPI_EXPORT extern void TFL_DeleteModel(TFL_Model* model); + +// -------------------------------------------------------------------------- +// TFL_InterpreterOptions allows customized interpreter configuration. +typedef struct TFL_InterpreterOptions TFL_InterpreterOptions; + +// Returns a new interpreter options instances. +TFL_CAPI_EXPORT extern TFL_InterpreterOptions* TFL_NewInterpreterOptions(); + +// Destroys the interpreter options instance. +TFL_CAPI_EXPORT extern void TFL_DeleteInterpreterOptions( + TFL_InterpreterOptions* options); + +// Sets the number of CPU threads to use for the interpreter. +TFL_CAPI_EXPORT extern void TFL_InterpreterOptionsSetNumThreads( + TFL_InterpreterOptions* options, int32_t num_threads); + +// -------------------------------------------------------------------------- +// TFL_Interpreter provides inference from a provided model. +typedef struct TFL_Interpreter TFL_Interpreter; + +// Returns a new interpreter using the provided model and options, or null on +// failure. +// +// * `model` must be a valid model instance. The caller retains ownership of the +// object, and can destroy it immediately after creating the interpreter. +// * `optional_options` may be null. The caller retains ownership of the object, +// and can safely destroy it immediately after creating the interpreter. +// +// NOTE: The client *must* explicitly allocate tensors before attempting to +// access input tensor data or invoke the interpreter. +TFL_CAPI_EXPORT extern TFL_Interpreter* TFL_NewInterpreter( + const TFL_Model* model, const TFL_InterpreterOptions* optional_options); + +// Destroys the interpreter. +TFL_CAPI_EXPORT extern void TFL_DeleteInterpreter(TFL_Interpreter* interpreter); + +// Returns the number of input tensors associated with the model. +TFL_CAPI_EXPORT extern int TFL_InterpreterGetInputTensorCount( + const TFL_Interpreter* interpreter); + +// Returns the tensor associated with the input index. +// REQUIRES: 0 <= input_index < TFL_InterpreterGetInputTensorCount(tensor) +TFL_CAPI_EXPORT extern TFL_Tensor* TFL_InterpreterGetInputTensor( + const TFL_Interpreter* interpreter, int32_t input_index); + +// Resizes the specified input tensor. +// +// NOTE: After a resize, the client *must* explicitly allocate tensors before +// attempting to access the resized tensor data or invoke the interpreter. +// REQUIRES: 0 <= input_index < TFL_InterpreterGetInputTensorCount(tensor) +TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResizeInputTensor( + TFL_Interpreter* interpreter, int32_t input_index, const int* input_dims, + int32_t input_dims_size); + +// Updates allocations for all tensors, resizing dependent tensors using the +// specified input tensor dimensionality. +// +// This is a relatively expensive operation, and need only be called after +// creating the graph and/or resizing any inputs. +TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterAllocateTensors( + TFL_Interpreter* interpreter); + +// Runs inference for the loaded graph. +// +// NOTE: It is possible that the interpreter is not in a ready state to +// evaluate (e.g., if a ResizeInputTensor() has been performed without a call to +// AllocateTensors()). +TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterInvoke( + TFL_Interpreter* interpreter); + +// Returns the number of output tensors associated with the model. +TFL_CAPI_EXPORT extern int32_t TFL_InterpreterGetOutputTensorCount( + const TFL_Interpreter* interpreter); + +// Returns the tensor associated with the output index. +// REQUIRES: 0 <= input_index < TFL_InterpreterGetOutputTensorCount(tensor) +TFL_CAPI_EXPORT extern const TFL_Tensor* TFL_InterpreterGetOutputTensor( + const TFL_Interpreter* interpreter, int32_t output_index); + +// -------------------------------------------------------------------------- +// TFL_Tensor wraps data associated with a graph tensor. +// +// Note that, while the TFL_Tensor struct is not currently opaque, and its +// fields can be accessed directly, these methods are still convenient for +// language bindings. In the future the tensor struct will likely be made opaque +// in the public API. + +// Returns the type of a tensor element. +TFL_CAPI_EXPORT extern TFL_Type TFL_TensorType(const TFL_Tensor* tensor); + +// Returns the number of dimensions that the tensor has. +TFL_CAPI_EXPORT extern int32_t TFL_TensorNumDims(const TFL_Tensor* tensor); + +// Returns the length of the tensor in the "dim_index" dimension. +// REQUIRES: 0 <= dim_index < TFLiteTensorNumDims(tensor) +TFL_CAPI_EXPORT extern int32_t TFL_TensorDim(const TFL_Tensor* tensor, + int32_t dim_index); + +// Returns the size of the underlying data in bytes. +TFL_CAPI_EXPORT extern size_t TFL_TensorByteSize(const TFL_Tensor* tensor); + +// Returns a pointer to the underlying data buffer. +// +// Note: The result may be null if tensors have not yet been allocated, e.g., +// if the Tensor has just been created or resized and `TFL_AllocateTensors()` +// has yet to be called, or if the output tensor is dynamically sized and the +// interpreter hasn't been invoked. +TFL_CAPI_EXPORT extern void* TFL_TensorData(const TFL_Tensor* tensor); + +// Copies from the provided input buffer into the tensor's buffer. +// REQUIRES: input_data_size == TFL_TensorByteSize(tensor) +TFL_CAPI_EXPORT extern TFL_Status TFL_TensorCopyFromBuffer( + TFL_Tensor* tensor, const void* input_data, size_t input_data_size); + +// Copies to the provided output buffer from the tensor's buffer. +// REQUIRES: output_data_size == TFL_TensorByteSize(tensor) +TFL_CAPI_EXPORT extern TFL_Status TFL_TensorCopyToBuffer( + const TFL_Tensor* output_tensor, void* output_data, + size_t output_data_size); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_H_ diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc new file mode 100644 index 0000000000000000000000000000000000000000..c4dbc55cbf6b116df46553411be5337f83ceb4e7 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/c/c_api_experimental.h" + +#include "tensorflow/contrib/lite/experimental/c/c_api_internal.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +TFL_Status TFL_InterpreterResetVariableTensorsToZero( + TFL_Interpreter* interpreter) { + return interpreter->impl->ResetVariableTensorsToZero(); +} + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h new file mode 100644 index 0000000000000000000000000000000000000000..b0ac258dcf9bf4ab603ba847f1b111a89cf2f29b --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_ + +#include "tensorflow/contrib/lite/experimental/c/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Resets all variable tensors to zero. +TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResetVariableTensorsToZero( + TFL_Interpreter* interpreter); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_ diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..db6e5251de518d2e754f853edbfb1c1edc425a83 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/c/c_api_experimental.h" + +#include +#include "tensorflow/contrib/lite/experimental/c/c_api.h" +#include "tensorflow/contrib/lite/testing/util.h" + +namespace { + +TEST(CApiExperimentalSimple, Smoke) { + TFL_Model* model = TFL_NewModelFromFile( + "tensorflow/contrib/lite/testdata/add.bin"); + ASSERT_NE(model, nullptr); + + TFL_Interpreter* interpreter = + TFL_NewInterpreter(model, /*optional_options=*/nullptr); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(TFL_InterpreterAllocateTensors(interpreter), kTfLiteOk); + + EXPECT_EQ(TFL_InterpreterResetVariableTensorsToZero(interpreter), kTfLiteOk); + + TFL_DeleteModel(model); + TFL_DeleteInterpreter(interpreter); +} + +} // namespace + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/experimental/c/c_api_internal.h b/tensorflow/contrib/lite/experimental/c/c_api_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..c5c612a4c6d3f8ccc49697961fd87b81bc00b6a8 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/c_api_internal.h @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_INTERNAL_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_INTERNAL_H_ + +#include "tensorflow/contrib/lite/experimental/c/c_api.h" + +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/model.h" + +// Internal structures used by the C API. These are likely to change and should +// not be depended on. + +struct TFL_Model { + std::unique_ptr impl; +}; + +struct TFL_InterpreterOptions { + enum { + kDefaultNumThreads = -1, + }; + int num_threads = kDefaultNumThreads; +}; + +struct TFL_Interpreter { + std::unique_ptr impl; +}; + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_INTERNAL_H_ diff --git a/tensorflow/contrib/lite/experimental/c/c_api_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a631dae8906a2f5ab10b4125454f2eafb937823f --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/c_api_test.cc @@ -0,0 +1,90 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/contrib/lite/experimental/c/c_api.h" + +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/testing/util.h" + +namespace { + +TEST(CApiSimple, Smoke) { + TFL_Model* model = TFL_NewModelFromFile( + "tensorflow/contrib/lite/testdata/add.bin"); + ASSERT_NE(model, nullptr); + + TFL_InterpreterOptions* options = TFL_NewInterpreterOptions(); + ASSERT_NE(options, nullptr); + TFL_InterpreterOptionsSetNumThreads(options, 2); + + TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options); + ASSERT_NE(interpreter, nullptr); + + // The options/model can be deleted immediately after interpreter creation. + TFL_DeleteInterpreterOptions(options); + TFL_DeleteModel(model); + + ASSERT_EQ(TFL_InterpreterAllocateTensors(interpreter), kTfLiteOk); + ASSERT_EQ(TFL_InterpreterGetInputTensorCount(interpreter), 1); + ASSERT_EQ(TFL_InterpreterGetOutputTensorCount(interpreter), 1); + + std::array input_dims = {2}; + ASSERT_EQ(TFL_InterpreterResizeInputTensor(interpreter, 0, input_dims.data(), + input_dims.size()), + kTfLiteOk); + ASSERT_EQ(TFL_InterpreterAllocateTensors(interpreter), kTfLiteOk); + + TFL_Tensor* input_tensor = TFL_InterpreterGetInputTensor(interpreter, 0); + ASSERT_NE(input_tensor, nullptr); + EXPECT_EQ(TFL_TensorType(input_tensor), kTfLiteFloat32); + EXPECT_EQ(TFL_TensorNumDims(input_tensor), 1); + EXPECT_EQ(TFL_TensorDim(input_tensor, 0), 2); + EXPECT_EQ(TFL_TensorByteSize(input_tensor), sizeof(float) * 2); + + std::array input = {1.f, 3.f}; + ASSERT_EQ(TFL_TensorCopyFromBuffer(input_tensor, input.data(), + input.size() * sizeof(float)), + kTfLiteOk); + + ASSERT_EQ(TFL_InterpreterInvoke(interpreter), kTfLiteOk); + + const TFL_Tensor* output_tensor = + TFL_InterpreterGetOutputTensor(interpreter, 0); + ASSERT_NE(output_tensor, nullptr); + EXPECT_EQ(TFL_TensorType(output_tensor), kTfLiteFloat32); + EXPECT_EQ(TFL_TensorNumDims(output_tensor), 1); + EXPECT_EQ(TFL_TensorDim(output_tensor, 0), 2); + EXPECT_EQ(TFL_TensorByteSize(output_tensor), sizeof(float) * 2); + + std::array output; + ASSERT_EQ(TFL_TensorCopyToBuffer(output_tensor, output.data(), + output.size() * sizeof(float)), + kTfLiteOk); + EXPECT_EQ(output[0], 3.f); + EXPECT_EQ(output[1], 9.f); + + TFL_DeleteInterpreter(interpreter); +} + +} // namespace + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/experimental/c/exported_symbols.lds b/tensorflow/contrib/lite/experimental/c/exported_symbols.lds new file mode 100644 index 0000000000000000000000000000000000000000..a3ddc6bc8d370b1715fb1ebf2a66122296330249 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/exported_symbols.lds @@ -0,0 +1 @@ +_TFL_* diff --git a/tensorflow/contrib/lite/experimental/c/version_script.lds b/tensorflow/contrib/lite/experimental/c/version_script.lds new file mode 100644 index 0000000000000000000000000000000000000000..c0c8a2bca19afed186e6f8c72a58989a79c7b251 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/version_script.lds @@ -0,0 +1,9 @@ +VERS_1.0 { + # Export symbols in c_api.h. + global: + *TFL_*; + + # Hide everything else. + local: + *; +}; diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/.gitignore b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c72a5cae9ebfb15f60961fe25e622663cad89a41 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/.gitignore @@ -0,0 +1,13 @@ +# Unity generated +Builds/ +Temp/ +Library/ +obj/ +# Visual Studio / MonoDevelop generated +*.csproj +*.unityproj +*.sln +*.suo +*.userprefs +# OS generated +.DS_Store diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite.meta new file mode 100644 index 0000000000000000000000000000000000000000..ed9337b53e880b62f70953f197613dcb1409d208 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 71d1b4219b1da4aeaa1cebbec324fc81 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples.meta new file mode 100644 index 0000000000000000000000000000000000000000..edcce00939a298683b15ea45a5ec92709c6abc4f --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: d948aead14abd4c88947c9886d16f774 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite.meta new file mode 100644 index 0000000000000000000000000000000000000000..36b35516f0cee064c8d8e4814a2ae515e28590ce --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: b810b85b794fa48fd93100acf5525e1f +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes.meta new file mode 100644 index 0000000000000000000000000000000000000000..d4133da49a88d38a57d074d28b903f9f18102413 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 154f4201e2e454d4696fa5834eaa3ad3 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity new file mode 100644 index 0000000000000000000000000000000000000000..bcf24b89e335781877a7046001ac4deb6fc55041 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity @@ -0,0 +1,477 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!29 &1 +OcclusionCullingSettings: + m_ObjectHideFlags: 0 + serializedVersion: 2 + m_OcclusionBakeSettings: + smallestOccluder: 5 + smallestHole: 0.25 + backfaceThreshold: 100 + m_SceneGUID: 00000000000000000000000000000000 + m_OcclusionCullingData: {fileID: 0} +--- !u!104 &2 +RenderSettings: + m_ObjectHideFlags: 0 + serializedVersion: 8 + m_Fog: 0 + m_FogColor: {r: 0.5, g: 0.5, b: 0.5, a: 1} + m_FogMode: 3 + m_FogDensity: 0.01 + m_LinearFogStart: 0 + m_LinearFogEnd: 300 + m_AmbientSkyColor: {r: 0.212, g: 0.227, b: 0.259, a: 1} + m_AmbientEquatorColor: {r: 0.114, g: 0.125, b: 0.133, a: 1} + m_AmbientGroundColor: {r: 0.047, g: 0.043, b: 0.035, a: 1} + m_AmbientIntensity: 1 + m_AmbientMode: 3 + m_SubtractiveShadowColor: {r: 0.42, g: 0.478, b: 0.627, a: 1} + m_SkyboxMaterial: {fileID: 0} + m_HaloStrength: 0.5 + m_FlareStrength: 1 + m_FlareFadeSpeed: 3 + m_HaloTexture: {fileID: 0} + m_SpotCookie: {fileID: 10001, guid: 0000000000000000e000000000000000, type: 0} + m_DefaultReflectionMode: 0 + m_DefaultReflectionResolution: 128 + m_ReflectionBounces: 1 + m_ReflectionIntensity: 1 + m_CustomReflection: {fileID: 0} + m_Sun: {fileID: 0} + m_IndirectSpecularColor: {r: 0, g: 0, b: 0, a: 1} +--- !u!157 &3 +LightmapSettings: + m_ObjectHideFlags: 0 + serializedVersion: 11 + m_GIWorkflowMode: 1 + m_GISettings: + serializedVersion: 2 + m_BounceScale: 1 + m_IndirectOutputScale: 1 + m_AlbedoBoost: 1 + m_TemporalCoherenceThreshold: 1 + m_EnvironmentLightingMode: 0 + m_EnableBakedLightmaps: 0 + m_EnableRealtimeLightmaps: 0 + m_LightmapEditorSettings: + serializedVersion: 9 + m_Resolution: 2 + m_BakeResolution: 40 + m_TextureWidth: 1024 + m_TextureHeight: 1024 + m_AO: 0 + m_AOMaxDistance: 1 + m_CompAOExponent: 1 + m_CompAOExponentDirect: 0 + m_Padding: 2 + m_LightmapParameters: {fileID: 0} + m_LightmapsBakeMode: 1 + m_TextureCompression: 1 + m_FinalGather: 0 + m_FinalGatherFiltering: 1 + m_FinalGatherRayCount: 256 + m_ReflectionCompression: 2 + m_MixedBakeMode: 2 + m_BakeBackend: 0 + m_PVRSampling: 1 + m_PVRDirectSampleCount: 32 + m_PVRSampleCount: 500 + m_PVRBounces: 2 + m_PVRFilterTypeDirect: 0 + m_PVRFilterTypeIndirect: 0 + m_PVRFilterTypeAO: 0 + m_PVRFilteringMode: 1 + m_PVRCulling: 1 + m_PVRFilteringGaussRadiusDirect: 1 + m_PVRFilteringGaussRadiusIndirect: 5 + m_PVRFilteringGaussRadiusAO: 2 + m_PVRFilteringAtrousPositionSigmaDirect: 0.5 + m_PVRFilteringAtrousPositionSigmaIndirect: 2 + m_PVRFilteringAtrousPositionSigmaAO: 1 + m_ShowResolutionOverlay: 1 + m_LightingDataAsset: {fileID: 0} + m_UseShadowmask: 1 +--- !u!196 &4 +NavMeshSettings: + serializedVersion: 2 + m_ObjectHideFlags: 0 + m_BuildSettings: + serializedVersion: 2 + agentTypeID: 0 + agentRadius: 0.5 + agentHeight: 2 + agentSlope: 45 + agentClimb: 0.4 + ledgeDropHeight: 0 + maxJumpAcrossDistance: 0 + minRegionArea: 2 + manualCellSize: 0 + cellSize: 0.16666667 + manualTileSize: 0 + tileSize: 256 + accuratePlacement: 0 + debug: + m_Flags: 0 + m_NavMeshData: {fileID: 0} +--- !u!1 &492081941 +GameObject: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + serializedVersion: 5 + m_Component: + - component: {fileID: 492081945} + - component: {fileID: 492081944} + - component: {fileID: 492081943} + - component: {fileID: 492081942} + m_Layer: 0 + m_Name: Main Camera + m_TagString: MainCamera + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!81 &492081942 +AudioListener: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 492081941} + m_Enabled: 1 +--- !u!124 &492081943 +Behaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 492081941} + m_Enabled: 1 +--- !u!20 &492081944 +Camera: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 492081941} + m_Enabled: 1 + serializedVersion: 2 + m_ClearFlags: 1 + m_BackGroundColor: {r: 0.21933319, g: 0.21933319, b: 0.21933319, a: 0} + m_NormalizedViewPortRect: + serializedVersion: 2 + x: 0 + y: 0 + width: 1 + height: 1 + near clip plane: 0.3 + far clip plane: 1000 + field of view: 60 + orthographic: 1 + orthographic size: 5 + m_Depth: -1 + m_CullingMask: + serializedVersion: 2 + m_Bits: 4294967295 + m_RenderingPath: -1 + m_TargetTexture: {fileID: 0} + m_TargetDisplay: 0 + m_TargetEye: 3 + m_HDR: 1 + m_AllowMSAA: 1 + m_AllowDynamicResolution: 0 + m_ForceIntoRT: 0 + m_OcclusionCulling: 1 + m_StereoConvergence: 10 + m_StereoSeparation: 0.022 +--- !u!4 &492081945 +Transform: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 492081941} + m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: -10} + m_LocalScale: {x: 1, y: 1, z: 1} + m_Children: + - {fileID: 904015944} + m_Father: {fileID: 0} + m_RootOrder: 0 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} +--- !u!1 &871349752 +GameObject: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + serializedVersion: 5 + m_Component: + - component: {fileID: 871349756} + - component: {fileID: 871349755} + - component: {fileID: 871349754} + - component: {fileID: 871349753} + m_Layer: 5 + m_Name: Canvas + m_TagString: Untagged + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!114 &871349753 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 871349752} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 1301386320, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3} + m_Name: + m_EditorClassIdentifier: + m_IgnoreReversedGraphics: 1 + m_BlockingObjects: 0 + m_BlockingMask: + serializedVersion: 2 + m_Bits: 4294967295 +--- !u!114 &871349754 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 871349752} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 1980459831, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3} + m_Name: + m_EditorClassIdentifier: + m_UiScaleMode: 0 + m_ReferencePixelsPerUnit: 100 + m_ScaleFactor: 1 + m_ReferenceResolution: {x: 800, y: 600} + m_ScreenMatchMode: 0 + m_MatchWidthOrHeight: 0 + m_PhysicalUnit: 3 + m_FallbackScreenDPI: 96 + m_DefaultSpriteDPI: 96 + m_DynamicPixelsPerUnit: 1 +--- !u!223 &871349755 +Canvas: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 871349752} + m_Enabled: 1 + serializedVersion: 3 + m_RenderMode: 0 + m_Camera: {fileID: 0} + m_PlaneDistance: 100 + m_PixelPerfect: 0 + m_ReceivesEvents: 1 + m_OverrideSorting: 0 + m_OverridePixelPerfect: 0 + m_SortingBucketNormalizedSize: 0 + m_AdditionalShaderChannelsFlag: 0 + m_SortingLayerID: 0 + m_SortingOrder: 0 + m_TargetDisplay: 0 +--- !u!224 &871349756 +RectTransform: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 871349752} + m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: 0} + m_LocalScale: {x: 0, y: 0, z: 0} + m_Children: + - {fileID: 1726294324} + m_Father: {fileID: 0} + m_RootOrder: 1 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} + m_AnchorMin: {x: 0, y: 0} + m_AnchorMax: {x: 0, y: 0} + m_AnchoredPosition: {x: 0, y: 0} + m_SizeDelta: {x: 0, y: 0} + m_Pivot: {x: 0, y: 0} +--- !u!1 &904015943 +GameObject: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + serializedVersion: 5 + m_Component: + - component: {fileID: 904015944} + - component: {fileID: 904015945} + m_Layer: 0 + m_Name: HelloTFLite + m_TagString: Untagged + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!4 &904015944 +Transform: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 904015943} + m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: 0} + m_LocalScale: {x: 1, y: 1, z: 1} + m_Children: [] + m_Father: {fileID: 492081945} + m_RootOrder: 0 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} +--- !u!114 &904015945 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 904015943} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: 899510441e0ca4be0879d3055e467878, type: 3} + m_Name: + m_EditorClassIdentifier: + model: {fileID: 4900000, guid: adff4e1dbdba344c199ee4fe7e84457e, type: 3} + inputs: + - 1 + - 3 + - 7 + inferenceText: {fileID: 1726294325} +--- !u!1 &1726294323 +GameObject: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + serializedVersion: 5 + m_Component: + - component: {fileID: 1726294324} + - component: {fileID: 1726294326} + - component: {fileID: 1726294325} + m_Layer: 5 + m_Name: InferenceText + m_TagString: Untagged + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!224 &1726294324 +RectTransform: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 1726294323} + m_LocalRotation: {x: -0, y: -0, z: -0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: 0} + m_LocalScale: {x: 1, y: 1, z: 1} + m_Children: [] + m_Father: {fileID: 871349756} + m_RootOrder: 0 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} + m_AnchorMin: {x: 0.5, y: 0.5} + m_AnchorMax: {x: 0.5, y: 0.5} + m_AnchoredPosition: {x: 0, y: 25} + m_SizeDelta: {x: 450, y: 250} + m_Pivot: {x: 0.5, y: 0.5} +--- !u!114 &1726294325 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 1726294323} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 708705254, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3} + m_Name: + m_EditorClassIdentifier: + m_Material: {fileID: 0} + m_Color: {r: 0.9338235, g: 0.9338235, b: 0.9338235, a: 1} + m_RaycastTarget: 1 + m_OnCullStateChanged: + m_PersistentCalls: + m_Calls: [] + m_TypeName: UnityEngine.UI.MaskableGraphic+CullStateChangedEvent, UnityEngine.UI, + Version=1.0.0.0, Culture=neutral, PublicKeyToken=null + m_FontData: + m_Font: {fileID: 10102, guid: 0000000000000000e000000000000000, type: 0} + m_FontSize: 35 + m_FontStyle: 0 + m_BestFit: 0 + m_MinSize: 2 + m_MaxSize: 40 + m_Alignment: 4 + m_AlignByGeometry: 0 + m_RichText: 1 + m_HorizontalOverflow: 0 + m_VerticalOverflow: 0 + m_LineSpacing: 1 + m_Text: 'Inference took 0.0153 ms + + Input: 1,3,7 + + Output: 3,9,21' +--- !u!222 &1726294326 +CanvasRenderer: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 1726294323} +--- !u!1 &2026426602 +GameObject: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + serializedVersion: 5 + m_Component: + - component: {fileID: 2026426605} + - component: {fileID: 2026426604} + - component: {fileID: 2026426603} + m_Layer: 0 + m_Name: EventSystem + m_TagString: Untagged + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!114 &2026426603 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 2026426602} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 1077351063, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3} + m_Name: + m_EditorClassIdentifier: + m_HorizontalAxis: Horizontal + m_VerticalAxis: Vertical + m_SubmitButton: Submit + m_CancelButton: Cancel + m_InputActionsPerSecond: 10 + m_RepeatDelay: 0.5 + m_ForceModuleActive: 0 +--- !u!114 &2026426604 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 2026426602} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: -619905303, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3} + m_Name: + m_EditorClassIdentifier: + m_FirstSelected: {fileID: 0} + m_sendNavigationEvents: 1 + m_DragThreshold: 5 +--- !u!4 &2026426605 +Transform: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 2026426602} + m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: 0} + m_LocalScale: {x: 1, y: 1, z: 1} + m_Children: [] + m_Father: {fileID: 0} + m_RootOrder: 2 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity.meta new file mode 100644 index 0000000000000000000000000000000000000000..e1e13efb66027b555f1d45c76fe58fe2103774a2 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: f8a8c37a396584bb7b21687f33d6d3f8 +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes new file mode 100644 index 0000000000000000000000000000000000000000..aef0fe3d82c9d92dc444076d3b46e05af1923f46 Binary files /dev/null and b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes differ diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes.meta new file mode 100644 index 0000000000000000000000000000000000000000..ba24871413e06154afd0c0d5e2db83b7619d34a9 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: adff4e1dbdba344c199ee4fe7e84457e +TextScriptImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts.meta new file mode 100644 index 0000000000000000000000000000000000000000..28fde68b8b1346e88375dc7a8613270f0e2f2762 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: f7d1e2dec09b64acdb7b8f5aef9fcb44 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs new file mode 100644 index 0000000000000000000000000000000000000000..83291e61794819e7c57f69ed2be6ea40294e01da --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs @@ -0,0 +1,85 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using TensorFlowLite; +using UnityEngine; +using UnityEngine.UI; + +/// +/// Simple example demonstrating use of the experimental C# bindings for TensorFlowLite. +/// +public class HelloTFLite : MonoBehaviour { + + [Tooltip("Configurable TFLite model.")] + public TextAsset model; + + [Tooltip("Configurable TFLite input tensor data.")] + public float[] inputs; + + [Tooltip("Target Text widget for display of inference execution.")] + public Text inferenceText; + + private Interpreter interpreter; + private float[] outputs; + + void Awake() { + // As the demo is extremely simple, there's no need to run at full frame-rate. + QualitySettings.vSyncCount = 0; + Application.targetFrameRate = 5; + } + + void Start () { + interpreter = new Interpreter(model.bytes); + Debug.LogFormat( + "InputCount: {0}, OutputCount: {1}", + interpreter.GetInputTensorCount(), + interpreter.GetOutputTensorCount()); + } + + void Update () { + if (inputs == null) { + return; + } + + if (outputs == null || outputs.Length != inputs.Length) { + interpreter.ResizeInputTensor(0, new int[]{inputs.Length}); + interpreter.AllocateTensors(); + outputs = new float[inputs.Length]; + } + + float startTimeSeconds = Time.realtimeSinceStartup; + interpreter.SetInputTensorData(0, inputs); + interpreter.Invoke(); + interpreter.GetOutputTensorData(0, outputs); + float inferenceTimeSeconds = Time.realtimeSinceStartup - startTimeSeconds; + + inferenceText.text = string.Format( + "Inference took {0:0.0000} ms\nInput(s): {1}\nOutput(s): {2}", + inferenceTimeSeconds * 1000.0, + ArrayToString(inputs), + ArrayToString(outputs)); + } + + void OnDestroy() { + interpreter.Dispose(); + } + + private static string ArrayToString(float[] values) { + return string.Join(",", values.Select(x => x.ToString()).ToArray()); + } +} diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs.meta new file mode 100644 index 0000000000000000000000000000000000000000..ba83f45084bb624e5e7777684b0fda98b4d46688 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 899510441e0ca4be0879d3055e467878 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK.meta new file mode 100644 index 0000000000000000000000000000000000000000..bf5ce15c6a6932398d798d193b54f4ecfd8ba2d8 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 16dad1655bcdc48f7b325a2a634b9c69 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts.meta new file mode 100644 index 0000000000000000000000000000000000000000..22ed2c466bde1668595967f7a07f34a9193aaec8 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: d70863368f8904d509a9b73d3a555914 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs new file mode 100644 index 0000000000000000000000000000000000000000..b6905b5fbfe5b49e30d79b372b3be35d90fe252a --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs @@ -0,0 +1,155 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +using System; +using System.Runtime.InteropServices; + +using TFL_Interpreter = System.IntPtr; +using TFL_InterpreterOptions = System.IntPtr; +using TFL_Model = System.IntPtr; +using TFL_Tensor = System.IntPtr; + +namespace TensorFlowLite +{ + /// + /// Simple C# bindings for the experimental TensorFlowLite C API. + /// + public class Interpreter : IDisposable + { + private const string TensorFlowLibrary = "tensorflowlite_c"; + + private TFL_Interpreter handle; + + public Interpreter(byte[] modelData) { + GCHandle modelDataHandle = GCHandle.Alloc(modelData, GCHandleType.Pinned); + IntPtr modelDataPtr = modelDataHandle.AddrOfPinnedObject(); + TFL_Model model = TFL_NewModel(modelDataPtr, modelData.Length); + handle = TFL_NewInterpreter(model, /*options=*/IntPtr.Zero); + TFL_DeleteModel(model); + if (handle == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Interpreter"); + } + + ~Interpreter() { + Dispose(); + } + + public void Dispose() { + if (handle != IntPtr.Zero) TFL_DeleteInterpreter(handle); + handle = IntPtr.Zero; + } + + public void Invoke() { + ThrowIfError(TFL_InterpreterInvoke(handle)); + } + + public int GetInputTensorCount() { + return TFL_InterpreterGetInputTensorCount(handle); + } + + public void SetInputTensorData(int inputTensorIndex, Array inputTensorData) { + GCHandle tensorDataHandle = GCHandle.Alloc(inputTensorData, GCHandleType.Pinned); + IntPtr tensorDataPtr = tensorDataHandle.AddrOfPinnedObject(); + TFL_Tensor tensor = TFL_InterpreterGetInputTensor(handle, inputTensorIndex); + ThrowIfError(TFL_TensorCopyFromBuffer( + tensor, tensorDataPtr, Buffer.ByteLength(inputTensorData))); + } + + public void ResizeInputTensor(int inputTensorIndex, int[] inputTensorShape) { + ThrowIfError(TFL_InterpreterResizeInputTensor( + handle, inputTensorIndex, inputTensorShape, inputTensorShape.Length)); + } + + public void AllocateTensors() { + ThrowIfError(TFL_InterpreterAllocateTensors(handle)); + } + + public int GetOutputTensorCount() { + return TFL_InterpreterGetOutputTensorCount(handle); + } + + public void GetOutputTensorData(int outputTensorIndex, Array outputTensorData) { + GCHandle tensorDataHandle = GCHandle.Alloc(outputTensorData, GCHandleType.Pinned); + IntPtr tensorDataPtr = tensorDataHandle.AddrOfPinnedObject(); + TFL_Tensor tensor = TFL_InterpreterGetOutputTensor(handle, outputTensorIndex); + ThrowIfError(TFL_TensorCopyToBuffer( + tensor, tensorDataPtr, Buffer.ByteLength(outputTensorData))); + } + + private static void ThrowIfError(int resultCode) { + if (resultCode != 0) throw new Exception("TensorFlowLite operation failed."); + } + + #region Externs + + [DllImport (TensorFlowLibrary)] + private static extern unsafe TFL_Interpreter TFL_NewModel(IntPtr model_data, int model_size); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe TFL_Interpreter TFL_DeleteModel(TFL_Model model); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe TFL_Interpreter TFL_NewInterpreter( + TFL_Model model, + TFL_InterpreterOptions optional_options); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe void TFL_DeleteInterpreter(TFL_Interpreter interpreter); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_InterpreterGetInputTensorCount( + TFL_Interpreter interpreter); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe TFL_Tensor TFL_InterpreterGetInputTensor( + TFL_Interpreter interpreter, + int input_index); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_InterpreterResizeInputTensor( + TFL_Interpreter interpreter, + int input_index, + int[] input_dims, + int input_dims_size); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_InterpreterAllocateTensors( + TFL_Interpreter interpreter); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_InterpreterInvoke(TFL_Interpreter interpreter); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_InterpreterGetOutputTensorCount( + TFL_Interpreter interpreter); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe TFL_Tensor TFL_InterpreterGetOutputTensor( + TFL_Interpreter interpreter, + int output_index); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_TensorCopyFromBuffer( + TFL_Tensor tensor, + IntPtr input_data, + int input_data_size); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_TensorCopyToBuffer( + TFL_Tensor tensor, + IntPtr output_data, + int output_data_size); + + #endregion + } +} diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs.meta new file mode 100644 index 0000000000000000000000000000000000000000..5ec84ef7f70e9be45ff6292ed7a412fac35010de --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 0bbaf59e6ac914ed1b28174fb9008a09 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/AudioManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/AudioManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..da6112576a5ca4290108f6d4c731bd4c391e91d4 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/AudioManager.asset @@ -0,0 +1,17 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!11 &1 +AudioManager: + m_ObjectHideFlags: 0 + m_Volume: 1 + Rolloff Scale: 1 + Doppler Factor: 1 + Default Speaker Mode: 2 + m_SampleRate: 0 + m_DSPBufferSize: 0 + m_VirtualVoiceCount: 512 + m_RealVoiceCount: 32 + m_SpatializerPlugin: + m_AmbisonicDecoderPlugin: + m_DisableAudio: 0 + m_VirtualizeEffects: 1 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ClusterInputManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ClusterInputManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..e7886b266a005f4d9d80f2fef8d1649dcfd3ed2b --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ClusterInputManager.asset @@ -0,0 +1,6 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!236 &1 +ClusterInputManager: + m_ObjectHideFlags: 0 + m_Inputs: [] diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/DynamicsManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/DynamicsManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..78992f08c7ab7a4353c8a7d07cf1548174aaacbf --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/DynamicsManager.asset @@ -0,0 +1,29 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!55 &1 +PhysicsManager: + m_ObjectHideFlags: 0 + serializedVersion: 7 + m_Gravity: {x: 0, y: -9.81, z: 0} + m_DefaultMaterial: {fileID: 0} + m_BounceThreshold: 2 + m_SleepThreshold: 0.005 + m_DefaultContactOffset: 0.01 + m_DefaultSolverIterations: 6 + m_DefaultSolverVelocityIterations: 1 + m_QueriesHitBackfaces: 0 + m_QueriesHitTriggers: 1 + m_EnableAdaptiveForce: 0 + m_ClothInterCollisionDistance: 0 + m_ClothInterCollisionStiffness: 0 + m_ContactsGeneration: 1 + m_LayerCollisionMatrix: ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff + m_AutoSimulation: 1 + m_AutoSyncTransforms: 1 + m_ClothInterCollisionSettingsToggle: 0 + m_ContactPairsMode: 0 + m_BroadphaseType: 0 + m_WorldBounds: + m_Center: {x: 0, y: 0, z: 0} + m_Extent: {x: 250, y: 250, z: 250} + m_WorldSubdivisions: 8 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorBuildSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorBuildSettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..6dc24f7dfdb697ad6f5d0a4ec5599bcd3cbd2f43 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorBuildSettings.asset @@ -0,0 +1,7 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!1045 &1 +EditorBuildSettings: + m_ObjectHideFlags: 0 + serializedVersion: 2 + m_Scenes: [] diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorSettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..fcd016402f97e4c009a16640517a6930ed615ef9 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorSettings.asset @@ -0,0 +1,21 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!159 &1 +EditorSettings: + m_ObjectHideFlags: 0 + serializedVersion: 7 + m_ExternalVersionControlSupport: Visible Meta Files + m_SerializationMode: 2 + m_LineEndingsForNewScripts: 1 + m_DefaultBehaviorMode: 1 + m_SpritePackerMode: 4 + m_SpritePackerPaddingPower: 1 + m_EtcTextureCompressorBehavior: 1 + m_EtcTextureFastCompressor: 1 + m_EtcTextureNormalCompressor: 2 + m_EtcTextureBestCompressor: 4 + m_ProjectGenerationIncludedExtensions: txt;xml;fnt;cd;asmdef;rsp + m_ProjectGenerationRootNamespace: + m_UserGeneratedProjectSuffix: + m_CollabEditorSettings: + inProgressEnabled: 1 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..a9bbfb02d1e7065b7d0e90609a3928d667933477 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset @@ -0,0 +1,64 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!30 &1 +GraphicsSettings: + m_ObjectHideFlags: 0 + serializedVersion: 12 + m_Deferred: + m_Mode: 1 + m_Shader: {fileID: 69, guid: 0000000000000000f000000000000000, type: 0} + m_DeferredReflections: + m_Mode: 1 + m_Shader: {fileID: 74, guid: 0000000000000000f000000000000000, type: 0} + m_ScreenSpaceShadows: + m_Mode: 1 + m_Shader: {fileID: 64, guid: 0000000000000000f000000000000000, type: 0} + m_LegacyDeferred: + m_Mode: 1 + m_Shader: {fileID: 63, guid: 0000000000000000f000000000000000, type: 0} + m_DepthNormals: + m_Mode: 1 + m_Shader: {fileID: 62, guid: 0000000000000000f000000000000000, type: 0} + m_MotionVectors: + m_Mode: 1 + m_Shader: {fileID: 75, guid: 0000000000000000f000000000000000, type: 0} + m_LightHalo: + m_Mode: 1 + m_Shader: {fileID: 105, guid: 0000000000000000f000000000000000, type: 0} + m_LensFlare: + m_Mode: 1 + m_Shader: {fileID: 102, guid: 0000000000000000f000000000000000, type: 0} + m_AlwaysIncludedShaders: + - {fileID: 7, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 15104, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 15105, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 15106, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 10753, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 10770, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 17000, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 16000, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 16002, guid: 0000000000000000f000000000000000, type: 0} + m_PreloadedShaders: [] + m_SpritesDefaultMaterial: {fileID: 10754, guid: 0000000000000000f000000000000000, + type: 0} + m_CustomRenderPipeline: {fileID: 0} + m_TransparencySortMode: 0 + m_TransparencySortAxis: {x: 0, y: 0, z: 1} + m_DefaultRenderingPath: 1 + m_DefaultMobileRenderingPath: 1 + m_TierSettings: [] + m_LightmapStripping: 0 + m_FogStripping: 0 + m_InstancingStripping: 0 + m_LightmapKeepPlain: 1 + m_LightmapKeepDirCombined: 1 + m_LightmapKeepDynamicPlain: 1 + m_LightmapKeepDynamicDirCombined: 1 + m_LightmapKeepShadowMask: 1 + m_LightmapKeepSubtractive: 1 + m_FogKeepLinear: 1 + m_FogKeepExp: 1 + m_FogKeepExp2: 1 + m_AlbedoSwatchInfos: [] + m_LightsUseLinearIntensity: 0 + m_LightsUseColorTemperature: 0 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/InputManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/InputManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..17c8f538e2152c0a0310b4870979eeecece2153c --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/InputManager.asset @@ -0,0 +1,295 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!13 &1 +InputManager: + m_ObjectHideFlags: 0 + serializedVersion: 2 + m_Axes: + - serializedVersion: 3 + m_Name: Horizontal + descriptiveName: + descriptiveNegativeName: + negativeButton: left + positiveButton: right + altNegativeButton: a + altPositiveButton: d + gravity: 3 + dead: 0.001 + sensitivity: 3 + snap: 1 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Vertical + descriptiveName: + descriptiveNegativeName: + negativeButton: down + positiveButton: up + altNegativeButton: s + altPositiveButton: w + gravity: 3 + dead: 0.001 + sensitivity: 3 + snap: 1 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Fire1 + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: left ctrl + altNegativeButton: + altPositiveButton: mouse 0 + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Fire2 + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: left alt + altNegativeButton: + altPositiveButton: mouse 1 + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Fire3 + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: left shift + altNegativeButton: + altPositiveButton: mouse 2 + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Jump + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: space + altNegativeButton: + altPositiveButton: + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Mouse X + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: + altNegativeButton: + altPositiveButton: + gravity: 0 + dead: 0 + sensitivity: 0.1 + snap: 0 + invert: 0 + type: 1 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Mouse Y + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: + altNegativeButton: + altPositiveButton: + gravity: 0 + dead: 0 + sensitivity: 0.1 + snap: 0 + invert: 0 + type: 1 + axis: 1 + joyNum: 0 + - serializedVersion: 3 + m_Name: Mouse ScrollWheel + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: + altNegativeButton: + altPositiveButton: + gravity: 0 + dead: 0 + sensitivity: 0.1 + snap: 0 + invert: 0 + type: 1 + axis: 2 + joyNum: 0 + - serializedVersion: 3 + m_Name: Horizontal + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: + altNegativeButton: + altPositiveButton: + gravity: 0 + dead: 0.19 + sensitivity: 1 + snap: 0 + invert: 0 + type: 2 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Vertical + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: + altNegativeButton: + altPositiveButton: + gravity: 0 + dead: 0.19 + sensitivity: 1 + snap: 0 + invert: 1 + type: 2 + axis: 1 + joyNum: 0 + - serializedVersion: 3 + m_Name: Fire1 + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: joystick button 0 + altNegativeButton: + altPositiveButton: + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Fire2 + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: joystick button 1 + altNegativeButton: + altPositiveButton: + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Fire3 + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: joystick button 2 + altNegativeButton: + altPositiveButton: + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Jump + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: joystick button 3 + altNegativeButton: + altPositiveButton: + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Submit + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: return + altNegativeButton: + altPositiveButton: joystick button 0 + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Submit + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: enter + altNegativeButton: + altPositiveButton: space + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Cancel + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: escape + altNegativeButton: + altPositiveButton: joystick button 1 + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NavMeshAreas.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NavMeshAreas.asset new file mode 100644 index 0000000000000000000000000000000000000000..3b0b7c3d183abdd300112f56965916ef11667f54 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NavMeshAreas.asset @@ -0,0 +1,91 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!126 &1 +NavMeshProjectSettings: + m_ObjectHideFlags: 0 + serializedVersion: 2 + areas: + - name: Walkable + cost: 1 + - name: Not Walkable + cost: 1 + - name: Jump + cost: 2 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + m_LastAgentTypeID: -887442657 + m_Settings: + - serializedVersion: 2 + agentTypeID: 0 + agentRadius: 0.5 + agentHeight: 2 + agentSlope: 45 + agentClimb: 0.75 + ledgeDropHeight: 0 + maxJumpAcrossDistance: 0 + minRegionArea: 2 + manualCellSize: 0 + cellSize: 0.16666667 + manualTileSize: 0 + tileSize: 256 + accuratePlacement: 0 + debug: + m_Flags: 0 + m_SettingNames: + - Humanoid diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NetworkManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NetworkManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..5dc6a831d9f2a11f08ed96571e0f602e3c3908b5 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NetworkManager.asset @@ -0,0 +1,8 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!149 &1 +NetworkManager: + m_ObjectHideFlags: 0 + m_DebugLevel: 0 + m_Sendrate: 15 + m_AssetToPrefab: {} diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/Physics2DSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/Physics2DSettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..132ee6bc868f1aae138555dc139e054b0d1d8620 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/Physics2DSettings.asset @@ -0,0 +1,37 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!19 &1 +Physics2DSettings: + m_ObjectHideFlags: 0 + serializedVersion: 3 + m_Gravity: {x: 0, y: -9.81} + m_DefaultMaterial: {fileID: 0} + m_VelocityIterations: 8 + m_PositionIterations: 3 + m_VelocityThreshold: 1 + m_MaxLinearCorrection: 0.2 + m_MaxAngularCorrection: 8 + m_MaxTranslationSpeed: 100 + m_MaxRotationSpeed: 360 + m_BaumgarteScale: 0.2 + m_BaumgarteTimeOfImpactScale: 0.75 + m_TimeToSleep: 0.5 + m_LinearSleepTolerance: 0.01 + m_AngularSleepTolerance: 2 + m_DefaultContactOffset: 0.01 + m_AutoSimulation: 1 + m_QueriesHitTriggers: 1 + m_QueriesStartInColliders: 1 + m_ChangeStopsCallbacks: 0 + m_CallbacksOnDisable: 1 + m_AutoSyncTransforms: 1 + m_AlwaysShowColliders: 0 + m_ShowColliderSleep: 1 + m_ShowColliderContacts: 0 + m_ShowColliderAABB: 0 + m_ContactArrowScale: 0.2 + m_ColliderAwakeColor: {r: 0.5686275, g: 0.95686275, b: 0.54509807, a: 0.7529412} + m_ColliderAsleepColor: {r: 0.5686275, g: 0.95686275, b: 0.54509807, a: 0.36078432} + m_ColliderContactColor: {r: 1, g: 0, b: 1, a: 0.6862745} + m_ColliderAABBColor: {r: 1, g: 1, b: 0, a: 0.2509804} + m_LayerCollisionMatrix: ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectSettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..3fbfab76c13c84f66a166c5dfe1d4552503350ff --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectSettings.asset @@ -0,0 +1,641 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!129 &1 +PlayerSettings: + m_ObjectHideFlags: 0 + serializedVersion: 14 + productGUID: a084943b991dd4597b140f4ce2b41c65 + AndroidProfiler: 0 + AndroidFilterTouchesWhenObscured: 0 + defaultScreenOrientation: 4 + targetDevice: 2 + useOnDemandResources: 0 + accelerometerFrequency: 60 + companyName: DefaultCompany + productName: TensorFlowLitePlugin + defaultCursor: {fileID: 0} + cursorHotspot: {x: 0, y: 0} + m_SplashScreenBackgroundColor: {r: 0.13725491, g: 0.12156863, b: 0.1254902, a: 1} + m_ShowUnitySplashScreen: 1 + m_ShowUnitySplashLogo: 1 + m_SplashScreenOverlayOpacity: 1 + m_SplashScreenAnimation: 1 + m_SplashScreenLogoStyle: 1 + m_SplashScreenDrawMode: 0 + m_SplashScreenBackgroundAnimationZoom: 1 + m_SplashScreenLogoAnimationZoom: 1 + m_SplashScreenBackgroundLandscapeAspect: 1 + m_SplashScreenBackgroundPortraitAspect: 1 + m_SplashScreenBackgroundLandscapeUvs: + serializedVersion: 2 + x: 0 + y: 0 + width: 1 + height: 1 + m_SplashScreenBackgroundPortraitUvs: + serializedVersion: 2 + x: 0 + y: 0 + width: 1 + height: 1 + m_SplashScreenLogos: [] + m_VirtualRealitySplashScreen: {fileID: 0} + m_HolographicTrackingLossScreen: {fileID: 0} + defaultScreenWidth: 1024 + defaultScreenHeight: 768 + defaultScreenWidthWeb: 960 + defaultScreenHeightWeb: 600 + m_StereoRenderingPath: 0 + m_ActiveColorSpace: 0 + m_MTRendering: 1 + m_StackTraceTypes: 010000000100000001000000010000000100000001000000 + iosShowActivityIndicatorOnLoading: -1 + androidShowActivityIndicatorOnLoading: -1 + tizenShowActivityIndicatorOnLoading: -1 + iosAppInBackgroundBehavior: 0 + displayResolutionDialog: 1 + iosAllowHTTPDownload: 1 + allowedAutorotateToPortrait: 1 + allowedAutorotateToPortraitUpsideDown: 1 + allowedAutorotateToLandscapeRight: 1 + allowedAutorotateToLandscapeLeft: 1 + useOSAutorotation: 1 + use32BitDisplayBuffer: 1 + preserveFramebufferAlpha: 0 + disableDepthAndStencilBuffers: 0 + androidBlitType: 0 + defaultIsFullScreen: 1 + defaultIsNativeResolution: 1 + macRetinaSupport: 1 + runInBackground: 0 + captureSingleScreen: 0 + muteOtherAudioSources: 0 + Prepare IOS For Recording: 0 + Force IOS Speakers When Recording: 0 + deferSystemGesturesMode: 0 + hideHomeButton: 0 + submitAnalytics: 1 + usePlayerLog: 1 + bakeCollisionMeshes: 0 + forceSingleInstance: 0 + resizableWindow: 0 + useMacAppStoreValidation: 0 + macAppStoreCategory: public.app-category.games + gpuSkinning: 0 + graphicsJobs: 0 + xboxPIXTextureCapture: 0 + xboxEnableAvatar: 0 + xboxEnableKinect: 0 + xboxEnableKinectAutoTracking: 0 + xboxEnableFitness: 0 + visibleInBackground: 1 + allowFullscreenSwitch: 1 + graphicsJobMode: 0 + macFullscreenMode: 2 + d3d11FullscreenMode: 1 + xboxSpeechDB: 0 + xboxEnableHeadOrientation: 0 + xboxEnableGuest: 0 + xboxEnablePIXSampling: 0 + metalFramebufferOnly: 0 + n3dsDisableStereoscopicView: 0 + n3dsEnableSharedListOpt: 1 + n3dsEnableVSync: 0 + xboxOneResolution: 0 + xboxOneSResolution: 0 + xboxOneXResolution: 3 + xboxOneMonoLoggingLevel: 0 + xboxOneLoggingLevel: 1 + xboxOneDisableEsram: 0 + xboxOnePresentImmediateThreshold: 0 + videoMemoryForVertexBuffers: 0 + psp2PowerMode: 0 + psp2AcquireBGM: 1 + wiiUTVResolution: 0 + wiiUGamePadMSAA: 1 + wiiUSupportsNunchuk: 0 + wiiUSupportsClassicController: 0 + wiiUSupportsBalanceBoard: 0 + wiiUSupportsMotionPlus: 0 + wiiUSupportsProController: 0 + wiiUAllowScreenCapture: 1 + wiiUControllerCount: 0 + m_SupportedAspectRatios: + 4:3: 1 + 5:4: 1 + 16:10: 1 + 16:9: 1 + Others: 1 + bundleVersion: 1.0 + preloadedAssets: [] + metroInputSource: 0 + wsaTransparentSwapchain: 0 + m_HolographicPauseOnTrackingLoss: 1 + xboxOneDisableKinectGpuReservation: 0 + xboxOneEnable7thCore: 0 + vrSettings: + cardboard: + depthFormat: 0 + enableTransitionView: 0 + daydream: + depthFormat: 0 + useSustainedPerformanceMode: 0 + enableVideoLayer: 0 + useProtectedVideoMemory: 0 + minimumSupportedHeadTracking: 0 + maximumSupportedHeadTracking: 1 + hololens: + depthFormat: 1 + depthBufferSharingEnabled: 0 + oculus: + sharedDepthBuffer: 0 + dashSupport: 0 + protectGraphicsMemory: 0 + useHDRDisplay: 0 + m_ColorGamuts: 00000000 + targetPixelDensity: 30 + resolutionScalingMode: 0 + androidSupportedAspectRatio: 1 + androidMaxAspectRatio: 2.1 + applicationIdentifier: {} + buildNumber: {} + AndroidBundleVersionCode: 1 + AndroidMinSdkVersion: 16 + AndroidTargetSdkVersion: 0 + AndroidPreferredInstallLocation: 1 + aotOptions: + stripEngineCode: 1 + iPhoneStrippingLevel: 0 + iPhoneScriptCallOptimization: 0 + ForceInternetPermission: 0 + ForceSDCardPermission: 0 + CreateWallpaper: 0 + APKExpansionFiles: 0 + keepLoadedShadersAlive: 0 + StripUnusedMeshComponents: 0 + VertexChannelCompressionMask: + serializedVersion: 2 + m_Bits: 238 + iPhoneSdkVersion: 988 + iOSTargetOSVersionString: 7.0 + tvOSSdkVersion: 0 + tvOSRequireExtendedGameController: 0 + tvOSTargetOSVersionString: 9.0 + uIPrerenderedIcon: 0 + uIRequiresPersistentWiFi: 0 + uIRequiresFullScreen: 1 + uIStatusBarHidden: 1 + uIExitOnSuspend: 0 + uIStatusBarStyle: 0 + iPhoneSplashScreen: {fileID: 0} + iPhoneHighResSplashScreen: {fileID: 0} + iPhoneTallHighResSplashScreen: {fileID: 0} + iPhone47inSplashScreen: {fileID: 0} + iPhone55inPortraitSplashScreen: {fileID: 0} + iPhone55inLandscapeSplashScreen: {fileID: 0} + iPhone58inPortraitSplashScreen: {fileID: 0} + iPhone58inLandscapeSplashScreen: {fileID: 0} + iPadPortraitSplashScreen: {fileID: 0} + iPadHighResPortraitSplashScreen: {fileID: 0} + iPadLandscapeSplashScreen: {fileID: 0} + iPadHighResLandscapeSplashScreen: {fileID: 0} + appleTVSplashScreen: {fileID: 0} + appleTVSplashScreen2x: {fileID: 0} + tvOSSmallIconLayers: [] + tvOSSmallIconLayers2x: [] + tvOSLargeIconLayers: [] + tvOSTopShelfImageLayers: [] + tvOSTopShelfImageLayers2x: [] + tvOSTopShelfImageWideLayers: [] + tvOSTopShelfImageWideLayers2x: [] + iOSLaunchScreenType: 0 + iOSLaunchScreenPortrait: {fileID: 0} + iOSLaunchScreenLandscape: {fileID: 0} + iOSLaunchScreenBackgroundColor: + serializedVersion: 2 + rgba: 0 + iOSLaunchScreenFillPct: 100 + iOSLaunchScreenSize: 100 + iOSLaunchScreenCustomXibPath: + iOSLaunchScreeniPadType: 0 + iOSLaunchScreeniPadImage: {fileID: 0} + iOSLaunchScreeniPadBackgroundColor: + serializedVersion: 2 + rgba: 0 + iOSLaunchScreeniPadFillPct: 100 + iOSLaunchScreeniPadSize: 100 + iOSLaunchScreeniPadCustomXibPath: + iOSUseLaunchScreenStoryboard: 0 + iOSLaunchScreenCustomStoryboardPath: + iOSDeviceRequirements: [] + iOSURLSchemes: [] + iOSBackgroundModes: 0 + iOSMetalForceHardShadows: 0 + metalEditorSupport: 1 + metalAPIValidation: 1 + iOSRenderExtraFrameOnPause: 0 + appleDeveloperTeamID: + iOSManualSigningProvisioningProfileID: + tvOSManualSigningProvisioningProfileID: + appleEnableAutomaticSigning: 0 + clonedFromGUID: 00000000000000000000000000000000 + AndroidTargetDevice: 0 + AndroidSplashScreenScale: 0 + androidSplashScreen: {fileID: 0} + AndroidKeystoreName: + AndroidKeyaliasName: + AndroidTVCompatibility: 1 + AndroidIsGame: 1 + AndroidEnableTango: 0 + androidEnableBanner: 1 + androidUseLowAccuracyLocation: 0 + m_AndroidBanners: + - width: 320 + height: 180 + banner: {fileID: 0} + androidGamepadSupportLevel: 0 + resolutionDialogBanner: {fileID: 0} + m_BuildTargetIcons: [] + m_BuildTargetBatching: [] + m_BuildTargetGraphicsAPIs: [] + m_BuildTargetVRSettings: [] + m_BuildTargetEnableVuforiaSettings: [] + openGLRequireES31: 0 + openGLRequireES31AEP: 0 + m_TemplateCustomTags: {} + mobileMTRendering: + Android: 1 + iPhone: 1 + tvOS: 1 + m_BuildTargetGroupLightmapEncodingQuality: [] + wiiUTitleID: 0005000011000000 + wiiUGroupID: 00010000 + wiiUCommonSaveSize: 4096 + wiiUAccountSaveSize: 2048 + wiiUOlvAccessKey: 0 + wiiUTinCode: 0 + wiiUJoinGameId: 0 + wiiUJoinGameModeMask: 0000000000000000 + wiiUCommonBossSize: 0 + wiiUAccountBossSize: 0 + wiiUAddOnUniqueIDs: [] + wiiUMainThreadStackSize: 3072 + wiiULoaderThreadStackSize: 1024 + wiiUSystemHeapSize: 128 + wiiUTVStartupScreen: {fileID: 0} + wiiUGamePadStartupScreen: {fileID: 0} + wiiUDrcBufferDisabled: 0 + wiiUProfilerLibPath: + playModeTestRunnerEnabled: 0 + actionOnDotNetUnhandledException: 1 + enableInternalProfiler: 0 + logObjCUncaughtExceptions: 1 + enableCrashReportAPI: 0 + cameraUsageDescription: + locationUsageDescription: + microphoneUsageDescription: + switchNetLibKey: + switchSocketMemoryPoolSize: 6144 + switchSocketAllocatorPoolSize: 128 + switchSocketConcurrencyLimit: 14 + switchScreenResolutionBehavior: 2 + switchUseCPUProfiler: 0 + switchApplicationID: 0x01004b9000490000 + switchNSODependencies: + switchTitleNames_0: + switchTitleNames_1: + switchTitleNames_2: + switchTitleNames_3: + switchTitleNames_4: + switchTitleNames_5: + switchTitleNames_6: + switchTitleNames_7: + switchTitleNames_8: + switchTitleNames_9: + switchTitleNames_10: + switchTitleNames_11: + switchTitleNames_12: + switchTitleNames_13: + switchTitleNames_14: + switchPublisherNames_0: + switchPublisherNames_1: + switchPublisherNames_2: + switchPublisherNames_3: + switchPublisherNames_4: + switchPublisherNames_5: + switchPublisherNames_6: + switchPublisherNames_7: + switchPublisherNames_8: + switchPublisherNames_9: + switchPublisherNames_10: + switchPublisherNames_11: + switchPublisherNames_12: + switchPublisherNames_13: + switchPublisherNames_14: + switchIcons_0: {fileID: 0} + switchIcons_1: {fileID: 0} + switchIcons_2: {fileID: 0} + switchIcons_3: {fileID: 0} + switchIcons_4: {fileID: 0} + switchIcons_5: {fileID: 0} + switchIcons_6: {fileID: 0} + switchIcons_7: {fileID: 0} + switchIcons_8: {fileID: 0} + switchIcons_9: {fileID: 0} + switchIcons_10: {fileID: 0} + switchIcons_11: {fileID: 0} + switchIcons_12: {fileID: 0} + switchIcons_13: {fileID: 0} + switchIcons_14: {fileID: 0} + switchSmallIcons_0: {fileID: 0} + switchSmallIcons_1: {fileID: 0} + switchSmallIcons_2: {fileID: 0} + switchSmallIcons_3: {fileID: 0} + switchSmallIcons_4: {fileID: 0} + switchSmallIcons_5: {fileID: 0} + switchSmallIcons_6: {fileID: 0} + switchSmallIcons_7: {fileID: 0} + switchSmallIcons_8: {fileID: 0} + switchSmallIcons_9: {fileID: 0} + switchSmallIcons_10: {fileID: 0} + switchSmallIcons_11: {fileID: 0} + switchSmallIcons_12: {fileID: 0} + switchSmallIcons_13: {fileID: 0} + switchSmallIcons_14: {fileID: 0} + switchManualHTML: + switchAccessibleURLs: + switchLegalInformation: + switchMainThreadStackSize: 1048576 + switchPresenceGroupId: + switchLogoHandling: 0 + switchReleaseVersion: 0 + switchDisplayVersion: 1.0.0 + switchStartupUserAccount: 0 + switchTouchScreenUsage: 0 + switchSupportedLanguagesMask: 0 + switchLogoType: 0 + switchApplicationErrorCodeCategory: + switchUserAccountSaveDataSize: 0 + switchUserAccountSaveDataJournalSize: 0 + switchApplicationAttribute: 0 + switchCardSpecSize: -1 + switchCardSpecClock: -1 + switchRatingsMask: 0 + switchRatingsInt_0: 0 + switchRatingsInt_1: 0 + switchRatingsInt_2: 0 + switchRatingsInt_3: 0 + switchRatingsInt_4: 0 + switchRatingsInt_5: 0 + switchRatingsInt_6: 0 + switchRatingsInt_7: 0 + switchRatingsInt_8: 0 + switchRatingsInt_9: 0 + switchRatingsInt_10: 0 + switchRatingsInt_11: 0 + switchLocalCommunicationIds_0: + switchLocalCommunicationIds_1: + switchLocalCommunicationIds_2: + switchLocalCommunicationIds_3: + switchLocalCommunicationIds_4: + switchLocalCommunicationIds_5: + switchLocalCommunicationIds_6: + switchLocalCommunicationIds_7: + switchParentalControl: 0 + switchAllowsScreenshot: 1 + switchAllowsVideoCapturing: 1 + switchAllowsRuntimeAddOnContentInstall: 0 + switchDataLossConfirmation: 0 + switchSupportedNpadStyles: 3 + switchSocketConfigEnabled: 0 + switchTcpInitialSendBufferSize: 32 + switchTcpInitialReceiveBufferSize: 64 + switchTcpAutoSendBufferSizeMax: 256 + switchTcpAutoReceiveBufferSizeMax: 256 + switchUdpSendBufferSize: 9 + switchUdpReceiveBufferSize: 42 + switchSocketBufferEfficiency: 4 + switchSocketInitializeEnabled: 1 + switchNetworkInterfaceManagerInitializeEnabled: 1 + switchPlayerConnectionEnabled: 1 + ps4NPAgeRating: 12 + ps4NPTitleSecret: + ps4NPTrophyPackPath: + ps4ParentalLevel: 11 + ps4ContentID: ED1633-NPXX51362_00-0000000000000000 + ps4Category: 0 + ps4MasterVersion: 01.00 + ps4AppVersion: 01.00 + ps4AppType: 0 + ps4ParamSfxPath: + ps4VideoOutPixelFormat: 0 + ps4VideoOutInitialWidth: 1920 + ps4VideoOutBaseModeInitialWidth: 1920 + ps4VideoOutReprojectionRate: 60 + ps4PronunciationXMLPath: + ps4PronunciationSIGPath: + ps4BackgroundImagePath: + ps4StartupImagePath: + ps4StartupImagesFolder: + ps4IconImagesFolder: + ps4SaveDataImagePath: + ps4SdkOverride: + ps4BGMPath: + ps4ShareFilePath: + ps4ShareOverlayImagePath: + ps4PrivacyGuardImagePath: + ps4NPtitleDatPath: + ps4RemotePlayKeyAssignment: -1 + ps4RemotePlayKeyMappingDir: + ps4PlayTogetherPlayerCount: 0 + ps4EnterButtonAssignment: 1 + ps4ApplicationParam1: 0 + ps4ApplicationParam2: 0 + ps4ApplicationParam3: 0 + ps4ApplicationParam4: 0 + ps4DownloadDataSize: 0 + ps4GarlicHeapSize: 2048 + ps4ProGarlicHeapSize: 2560 + ps4Passcode: d3hjjul8UhK6ZnQCEBYYQPozR9sQV066 + ps4pnSessions: 1 + ps4pnPresence: 1 + ps4pnFriends: 1 + ps4pnGameCustomData: 1 + playerPrefsSupport: 0 + restrictedAudioUsageRights: 0 + ps4UseResolutionFallback: 0 + ps4ReprojectionSupport: 0 + ps4UseAudio3dBackend: 0 + ps4SocialScreenEnabled: 0 + ps4ScriptOptimizationLevel: 0 + ps4Audio3dVirtualSpeakerCount: 14 + ps4attribCpuUsage: 0 + ps4PatchPkgPath: + ps4PatchLatestPkgPath: + ps4PatchChangeinfoPath: + ps4PatchDayOne: 0 + ps4attribUserManagement: 0 + ps4attribMoveSupport: 0 + ps4attrib3DSupport: 0 + ps4attribShareSupport: 0 + ps4attribExclusiveVR: 0 + ps4disableAutoHideSplash: 0 + ps4videoRecordingFeaturesUsed: 0 + ps4contentSearchFeaturesUsed: 0 + ps4attribEyeToEyeDistanceSettingVR: 0 + ps4IncludedModules: [] + monoEnv: + psp2Splashimage: {fileID: 0} + psp2NPTrophyPackPath: + psp2NPSupportGBMorGJP: 0 + psp2NPAgeRating: 12 + psp2NPTitleDatPath: + psp2NPCommsID: + psp2NPCommunicationsID: + psp2NPCommsPassphrase: + psp2NPCommsSig: + psp2ParamSfxPath: + psp2ManualPath: + psp2LiveAreaGatePath: + psp2LiveAreaBackroundPath: + psp2LiveAreaPath: + psp2LiveAreaTrialPath: + psp2PatchChangeInfoPath: + psp2PatchOriginalPackage: + psp2PackagePassword: 3onkgZsAECEn0fzCoWiCtWCKe4l74pE5 + psp2KeystoneFile: + psp2MemoryExpansionMode: 0 + psp2DRMType: 0 + psp2StorageType: 0 + psp2MediaCapacity: 0 + psp2DLCConfigPath: + psp2ThumbnailPath: + psp2BackgroundPath: + psp2SoundPath: + psp2TrophyCommId: + psp2TrophyPackagePath: + psp2PackagedResourcesPath: + psp2SaveDataQuota: 10240 + psp2ParentalLevel: 1 + psp2ShortTitle: Not Set + psp2ContentID: IV0000-ABCD12345_00-0123456789ABCDEF + psp2Category: 0 + psp2MasterVersion: 01.00 + psp2AppVersion: 01.00 + psp2TVBootMode: 0 + psp2EnterButtonAssignment: 2 + psp2TVDisableEmu: 0 + psp2AllowTwitterDialog: 1 + psp2Upgradable: 0 + psp2HealthWarning: 0 + psp2UseLibLocation: 0 + psp2InfoBarOnStartup: 0 + psp2InfoBarColor: 0 + psp2ScriptOptimizationLevel: 0 + psmSplashimage: {fileID: 0} + splashScreenBackgroundSourceLandscape: {fileID: 0} + splashScreenBackgroundSourcePortrait: {fileID: 0} + spritePackerPolicy: + webGLMemorySize: 256 + webGLExceptionSupport: 1 + webGLNameFilesAsHashes: 0 + webGLDataCaching: 0 + webGLDebugSymbols: 0 + webGLEmscriptenArgs: + webGLModulesDirectory: + webGLTemplate: APPLICATION:Default + webGLAnalyzeBuildSize: 0 + webGLUseEmbeddedResources: 0 + webGLUseWasm: 0 + webGLCompressionFormat: 1 + scriptingDefineSymbols: {} + platformArchitecture: {} + scriptingBackend: {} + incrementalIl2cppBuild: {} + additionalIl2CppArgs: + scriptingRuntimeVersion: 0 + apiCompatibilityLevelPerPlatform: {} + m_RenderingPath: 1 + m_MobileRenderingPath: 1 + metroPackageName: TensorFlowLitePlugin + metroPackageVersion: + metroCertificatePath: + metroCertificatePassword: + metroCertificateSubject: + metroCertificateIssuer: + metroCertificateNotAfter: 0000000000000000 + metroApplicationDescription: TensorFlowLitePlugin + wsaImages: {} + metroTileShortName: + metroCommandLineArgsFile: + metroTileShowName: 0 + metroMediumTileShowName: 0 + metroLargeTileShowName: 0 + metroWideTileShowName: 0 + metroDefaultTileSize: 1 + metroTileForegroundText: 2 + metroTileBackgroundColor: {r: 0.13333334, g: 0.17254902, b: 0.21568628, a: 0} + metroSplashScreenBackgroundColor: {r: 0.12941177, g: 0.17254902, b: 0.21568628, + a: 1} + metroSplashScreenUseBackgroundColor: 0 + platformCapabilities: {} + metroFTAName: + metroFTAFileTypes: [] + metroProtocolName: + metroCompilationOverrides: 1 + tizenProductDescription: + tizenProductURL: + tizenSigningProfileName: + tizenGPSPermissions: 0 + tizenMicrophonePermissions: 0 + tizenDeploymentTarget: + tizenDeploymentTargetType: -1 + tizenMinOSVersion: 1 + n3dsUseExtSaveData: 0 + n3dsCompressStaticMem: 1 + n3dsExtSaveDataNumber: 0x12345 + n3dsStackSize: 131072 + n3dsTargetPlatform: 2 + n3dsRegion: 7 + n3dsMediaSize: 0 + n3dsLogoStyle: 3 + n3dsTitle: GameName + n3dsProductCode: + n3dsApplicationId: 0xFF3FF + XboxOneProductId: + XboxOneUpdateKey: + XboxOneSandboxId: + XboxOneContentId: + XboxOneTitleId: + XboxOneSCId: + XboxOneGameOsOverridePath: + XboxOnePackagingOverridePath: + XboxOneAppManifestOverridePath: + XboxOnePackageEncryption: 0 + XboxOnePackageUpdateGranularity: 2 + XboxOneDescription: + XboxOneLanguage: + - enus + XboxOneCapability: [] + XboxOneGameRating: {} + XboxOneIsContentPackage: 0 + XboxOneEnableGPUVariability: 0 + XboxOneSockets: {} + XboxOneSplashScreen: {fileID: 0} + XboxOneAllowedProductIds: [] + XboxOnePersistentLocalStorageSize: 0 + XboxOneXTitleMemory: 8 + xboxOneScriptCompiler: 0 + vrEditorSettings: + daydream: + daydreamIconForeground: {fileID: 0} + daydreamIconBackground: {fileID: 0} + cloudServicesEnabled: {} + facebookSdkVersion: 7.9.4 + apiCompatibilityLevel: 2 + cloudProjectId: + projectName: + organizationId: + cloudEnabled: 0 + enableNativePlatformBackendsForNewInputSystem: 0 + disableOldInputManagerSupport: 0 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectVersion.txt b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectVersion.txt new file mode 100644 index 0000000000000000000000000000000000000000..4a9cfb61ab55abc2f0d09b0225a802ef8122eaaf --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectVersion.txt @@ -0,0 +1 @@ +m_EditorVersion: 2017.4.6f1 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/QualitySettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/QualitySettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..05daac3c4922feef068af19efa921fcbb476afde --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/QualitySettings.asset @@ -0,0 +1,191 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!47 &1 +QualitySettings: + m_ObjectHideFlags: 0 + serializedVersion: 5 + m_CurrentQuality: 5 + m_QualitySettings: + - serializedVersion: 2 + name: Very Low + pixelLightCount: 0 + shadows: 0 + shadowResolution: 0 + shadowProjection: 1 + shadowCascades: 1 + shadowDistance: 15 + shadowNearPlaneOffset: 3 + shadowCascade2Split: 0.33333334 + shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667} + shadowmaskMode: 0 + blendWeights: 1 + textureQuality: 1 + anisotropicTextures: 0 + antiAliasing: 0 + softParticles: 0 + softVegetation: 0 + realtimeReflectionProbes: 0 + billboardsFaceCameraPosition: 0 + vSyncCount: 0 + lodBias: 0.3 + maximumLODLevel: 0 + particleRaycastBudget: 4 + asyncUploadTimeSlice: 2 + asyncUploadBufferSize: 4 + resolutionScalingFixedDPIFactor: 1 + excludedTargetPlatforms: [] + - serializedVersion: 2 + name: Low + pixelLightCount: 0 + shadows: 0 + shadowResolution: 0 + shadowProjection: 1 + shadowCascades: 1 + shadowDistance: 20 + shadowNearPlaneOffset: 3 + shadowCascade2Split: 0.33333334 + shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667} + shadowmaskMode: 0 + blendWeights: 2 + textureQuality: 0 + anisotropicTextures: 0 + antiAliasing: 0 + softParticles: 0 + softVegetation: 0 + realtimeReflectionProbes: 0 + billboardsFaceCameraPosition: 0 + vSyncCount: 0 + lodBias: 0.4 + maximumLODLevel: 0 + particleRaycastBudget: 16 + asyncUploadTimeSlice: 2 + asyncUploadBufferSize: 4 + resolutionScalingFixedDPIFactor: 1 + excludedTargetPlatforms: [] + - serializedVersion: 2 + name: Medium + pixelLightCount: 1 + shadows: 1 + shadowResolution: 0 + shadowProjection: 1 + shadowCascades: 1 + shadowDistance: 20 + shadowNearPlaneOffset: 3 + shadowCascade2Split: 0.33333334 + shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667} + shadowmaskMode: 0 + blendWeights: 2 + textureQuality: 0 + anisotropicTextures: 1 + antiAliasing: 0 + softParticles: 0 + softVegetation: 0 + realtimeReflectionProbes: 0 + billboardsFaceCameraPosition: 0 + vSyncCount: 1 + lodBias: 0.7 + maximumLODLevel: 0 + particleRaycastBudget: 64 + asyncUploadTimeSlice: 2 + asyncUploadBufferSize: 4 + resolutionScalingFixedDPIFactor: 1 + excludedTargetPlatforms: [] + - serializedVersion: 2 + name: High + pixelLightCount: 2 + shadows: 2 + shadowResolution: 1 + shadowProjection: 1 + shadowCascades: 2 + shadowDistance: 40 + shadowNearPlaneOffset: 3 + shadowCascade2Split: 0.33333334 + shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667} + shadowmaskMode: 1 + blendWeights: 2 + textureQuality: 0 + anisotropicTextures: 1 + antiAliasing: 0 + softParticles: 0 + softVegetation: 1 + realtimeReflectionProbes: 1 + billboardsFaceCameraPosition: 1 + vSyncCount: 1 + lodBias: 1 + maximumLODLevel: 0 + particleRaycastBudget: 256 + asyncUploadTimeSlice: 2 + asyncUploadBufferSize: 4 + resolutionScalingFixedDPIFactor: 1 + excludedTargetPlatforms: [] + - serializedVersion: 2 + name: Very High + pixelLightCount: 3 + shadows: 2 + shadowResolution: 2 + shadowProjection: 1 + shadowCascades: 2 + shadowDistance: 70 + shadowNearPlaneOffset: 3 + shadowCascade2Split: 0.33333334 + shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667} + shadowmaskMode: 1 + blendWeights: 4 + textureQuality: 0 + anisotropicTextures: 2 + antiAliasing: 2 + softParticles: 1 + softVegetation: 1 + realtimeReflectionProbes: 1 + billboardsFaceCameraPosition: 1 + vSyncCount: 1 + lodBias: 1.5 + maximumLODLevel: 0 + particleRaycastBudget: 1024 + asyncUploadTimeSlice: 2 + asyncUploadBufferSize: 4 + resolutionScalingFixedDPIFactor: 1 + excludedTargetPlatforms: [] + - serializedVersion: 2 + name: Ultra + pixelLightCount: 4 + shadows: 2 + shadowResolution: 2 + shadowProjection: 1 + shadowCascades: 4 + shadowDistance: 150 + shadowNearPlaneOffset: 3 + shadowCascade2Split: 0.33333334 + shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667} + shadowmaskMode: 1 + blendWeights: 4 + textureQuality: 0 + anisotropicTextures: 2 + antiAliasing: 2 + softParticles: 1 + softVegetation: 1 + realtimeReflectionProbes: 1 + billboardsFaceCameraPosition: 1 + vSyncCount: 1 + lodBias: 2 + maximumLODLevel: 0 + particleRaycastBudget: 4096 + asyncUploadTimeSlice: 2 + asyncUploadBufferSize: 4 + resolutionScalingFixedDPIFactor: 1 + excludedTargetPlatforms: [] + m_PerPlatformDefaultQuality: + Android: 2 + Nintendo 3DS: 5 + Nintendo Switch: 5 + PS4: 5 + PSM: 5 + PSP2: 2 + Standalone: 5 + Tizen: 2 + WebGL: 3 + WiiU: 5 + Windows Store Apps: 5 + XboxOne: 5 + iPhone: 2 + tvOS: 2 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TagManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TagManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..1c92a7840ec11895c76785f65d949a3d20d53355 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TagManager.asset @@ -0,0 +1,43 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!78 &1 +TagManager: + serializedVersion: 2 + tags: [] + layers: + - Default + - TransparentFX + - Ignore Raycast + - + - Water + - UI + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + m_SortingLayers: + - name: Default + uniqueID: 0 + locked: 0 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TimeManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TimeManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..558a017e1f50b2db73414a1abad3c033922774f8 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TimeManager.asset @@ -0,0 +1,9 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!5 &1 +TimeManager: + m_ObjectHideFlags: 0 + Fixed Timestep: 0.02 + Maximum Allowed Timestep: 0.33333334 + m_TimeScale: 1 + Maximum Particle Timestep: 0.03 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/UnityConnectSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/UnityConnectSettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..3da14d5baf1fa24df1746c3ce9d969eda3a9c59d --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/UnityConnectSettings.asset @@ -0,0 +1,34 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!310 &1 +UnityConnectSettings: + m_ObjectHideFlags: 0 + m_Enabled: 0 + m_TestMode: 0 + m_TestEventUrl: + m_TestConfigUrl: + m_TestInitMode: 0 + CrashReportingSettings: + m_EventUrl: https://perf-events.cloud.unity3d.com/api/events/crashes + m_NativeEventUrl: https://perf-events.cloud.unity3d.com/symbolicate + m_Enabled: 0 + m_CaptureEditorExceptions: 1 + UnityPurchasingSettings: + m_Enabled: 0 + m_TestMode: 0 + UnityAnalyticsSettings: + m_Enabled: 0 + m_InitializeOnStartup: 1 + m_TestMode: 0 + m_TestEventUrl: + m_TestConfigUrl: + UnityAdsSettings: + m_Enabled: 0 + m_InitializeOnStartup: 1 + m_TestMode: 0 + m_IosGameId: + m_AndroidGameId: + m_GameIds: {} + m_GameId: + PerformanceReportingSettings: + m_Enabled: 0 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f480c49cd050de2192e9673f72c9e4d5c3c6ceff --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md @@ -0,0 +1,29 @@ +# TF Lite Experimental Unity Plugin + +This directory contains an experimental sample Unity (2017) Plugin, based on +the experimental TF Lite C API. The sample demonstrates running inference within +Unity by way of a C# `Interpreter` wrapper. + +Note that the native TF Lite plugin(s) *must* be built before using the Unity +Plugin, and placed in Assets/TensorFlowLite/SDK/Plugins/. For the editor (note +that this has only been tested on Linux; the syntax may differ on Mac/Windows): + +```sh +bazel build -c opt --cxxopt=--std=c++11 \ + //tensorflow/contrib/lite/experimental/c:libtensorflowlite_c.so +``` + +and for Android: + +```sh +bazel build -c opt --cxxopt=--std=c++11 \ + --crosstool_top=//external:android/crosstool \ + --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ + --cpu=armeabi-v7a \ + //tensorflow/contrib/lite/experimental/c:libtensorflowlite_c.so +``` + +If you encounter issues with native plugin discovery on Mac ("Darwin") +platforms, try renaming `libtensorflowlite_c.so` to `tensorflowlite_c.bundle`. +Similarly, on Windows you'll likely need to rename `libtensorflowlite_c.so` to +`tensorflowlite_c.dll`. diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/UnityPackageManager/manifest.json b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/UnityPackageManager/manifest.json new file mode 100644 index 0000000000000000000000000000000000000000..526aca60573f334a6b6bd536fa5be9c26d678e0f --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/UnityPackageManager/manifest.json @@ -0,0 +1,4 @@ +{ + "dependencies": { + } +} diff --git a/tensorflow/contrib/lite/experimental/kernels/BUILD b/tensorflow/contrib/lite/experimental/kernels/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..9c06c4ebd958294586dbb1fde5040a0d328954ac --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/BUILD @@ -0,0 +1,84 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +# ctc support classes imported directly from TensorFlow. +cc_library( + name = "ctc_utils", + hdrs = [ + "ctc_beam_entry.h", + "ctc_beam_scorer.h", + "ctc_beam_search.h", + "ctc_decoder.h", + "ctc_loss_util.h", + ], + deps = [ + ":top_n", + "//tensorflow/contrib/lite/kernels/internal:types", + "//third_party/eigen3", + ], +) + +# top_n support classes imported directly from TensorFlow. +cc_library( + name = "top_n", + hdrs = [ + "top_n.h", + ], + deps = [ + "//tensorflow/contrib/lite/kernels/internal:types", + ], +) + +cc_library( + name = "experimental_ops", + srcs = [ + "ctc_beam_search_decoder.cc", + ], + # Suppress warnings that are introduced by Eigen Tensor. + copts = tflite_copts() + [ + "-Wno-error=reorder", + ] + select({ + "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"], + "//conditions:default": [ + ], + }), + deps = [ + ":ctc_utils", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:gemm_support", + "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/contrib/lite/kernels:op_macros", + "//tensorflow/contrib/lite/kernels/internal:kernel_utils", + "//tensorflow/contrib/lite/kernels/internal:optimized", + "//tensorflow/contrib/lite/kernels/internal:optimized_base", + "//tensorflow/contrib/lite/kernels/internal:quantization_util", + "//tensorflow/contrib/lite/kernels/internal:reference", + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "//tensorflow/contrib/lite/kernels/internal:tensor_utils", + "@flatbuffers", + ], +) + +tf_cc_test( + name = "ctc_beam_search_decoder_test", + size = "small", + srcs = ["ctc_beam_search_decoder_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":experimental_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h new file mode 100644 index 0000000000000000000000000000000000000000..a60ff2a1c53f1b3f9f490ab5cf2bc429ba09dff0 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h @@ -0,0 +1,150 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Copied from tensorflow/core/util/ctc/ctc_beam_entry.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ + +#include +#include +#include +#include + +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h" + +namespace tflite { +namespace experimental { +namespace ctc { + +// The ctc_beam_search namespace holds several classes meant to be accessed only +// in case of extending the CTCBeamSearch decoder to allow custom scoring +// functions. +// +// BeamEntry is exposed through template arguments BeamScorer and BeamComparer +// of CTCBeamSearch (ctc_beam_search.h). +namespace ctc_beam_search { + +struct EmptyBeamState {}; + +struct BeamProbability { + BeamProbability() : total(kLogZero), blank(kLogZero), label(kLogZero) {} + void Reset() { + total = kLogZero; + blank = kLogZero; + label = kLogZero; + } + float total; + float blank; + float label; +}; + +template +class BeamRoot; + +template +struct BeamEntry { + // BeamRoot::AddEntry() serves as the factory method. + friend BeamEntry* BeamRoot::AddEntry( + BeamEntry* p, int l); + inline bool Active() const { return newp.total != kLogZero; } + // Return the child at the given index, or construct a new one in-place if + // none was found. + BeamEntry& GetChild(int ind) { + auto entry = children.emplace(ind, nullptr); + auto& child_entry = entry.first->second; + // If this is a new child, populate the BeamEntry*. + if (entry.second) { + child_entry = beam_root->AddEntry(this, ind); + } + return *child_entry; + } + std::vector LabelSeq(bool merge_repeated) const { + std::vector labels; + int prev_label = -1; + const BeamEntry* c = this; + while (c->parent != nullptr) { // Checking c->parent to skip root leaf. + if (!merge_repeated || c->label != prev_label) { + labels.push_back(c->label); + } + prev_label = c->label; + c = c->parent; + } + std::reverse(labels.begin(), labels.end()); + return labels; + } + + BeamEntry* parent; + int label; + // All instances of child BeamEntry are owned by *beam_root. + std::unordered_map*> children; + BeamProbability oldp; + BeamProbability newp; + CTCBeamState state; + + private: + // Constructor giving parent, label, and the beam_root. + // The object pointed to by p cannot be copied and should not be moved, + // otherwise parent will become invalid. + // This private constructor is only called through the factory method + // BeamRoot::AddEntry(). + BeamEntry(BeamEntry* p, int l, BeamRoot* beam_root) + : parent(p), label(l), beam_root(beam_root) {} + BeamRoot* beam_root; + + BeamEntry(const BeamEntry&) = delete; + void operator=(const BeamEntry&) = delete; +}; + +// This class owns all instances of BeamEntry. This is used to avoid recursive +// destructor call during destruction. +template +class BeamRoot { + public: + BeamRoot(BeamEntry* p, int l) { root_entry_ = AddEntry(p, l); } + BeamRoot(const BeamRoot&) = delete; + BeamRoot& operator=(const BeamRoot&) = delete; + + BeamEntry* AddEntry(BeamEntry* p, int l) { + auto* new_entry = new BeamEntry(p, l, this); + beam_entries_.emplace_back(new_entry); + return new_entry; + } + BeamEntry* RootEntry() const { return root_entry_; } + + private: + BeamEntry* root_entry_ = nullptr; + std::vector>> beam_entries_; +}; + +// BeamComparer is the default beam comparer provided in CTCBeamSearch. +template +class BeamComparer { + public: + virtual ~BeamComparer() {} + virtual bool inline operator()(const BeamEntry* a, + const BeamEntry* b) const { + return a->newp.total > b->newp.total; + } +}; + +} // namespace ctc_beam_search + +} // namespace ctc +} // namespace experimental +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h new file mode 100644 index 0000000000000000000000000000000000000000..ec60e26257b0f4126e7a7abed6a663abe277ef12 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Collection of scoring classes that can be extended and provided to the +// CTCBeamSearchDecoder to incorporate additional scoring logic (such as a +// language model). +// +// To build a custom scorer extend and implement the pure virtual methods from +// BeamScorerInterface. The default CTC decoding behavior is implemented +// through BaseBeamScorer. + +// Copied from tensorflow/core/util/ctc/ctc_beam_scorer.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_ + +#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h" + +namespace tflite { +namespace experimental { +namespace ctc { + +// Base implementation of a beam scorer used by default by the decoder that can +// be subclassed and provided as an argument to CTCBeamSearchDecoder, if complex +// scoring is required. Its main purpose is to provide a thin layer for +// integrating language model scoring easily. +template +class BaseBeamScorer { + public: + virtual ~BaseBeamScorer() {} + // State initialization. + virtual void InitializeState(CTCBeamState* root) const {} + // ExpandState is called when expanding a beam to one of its children. + // Called at most once per child beam. In the simplest case, no state + // expansion is done. + virtual void ExpandState(const CTCBeamState& from_state, int from_label, + CTCBeamState* to_state, int to_label) const {} + // ExpandStateEnd is called after decoding has finished. Its purpose is to + // allow a final scoring of the beam in its current state, before resorting + // and retrieving the TopN requested candidates. Called at most once per beam. + virtual void ExpandStateEnd(CTCBeamState* state) const {} + // GetStateExpansionScore should be an inexpensive method to retrieve the + // (cached) expansion score computed within ExpandState. The score is + // multiplied (log-addition) with the input score at the current step from + // the network. + // + // The score returned should be a log-probability. In the simplest case, as + // there's no state expansion logic, the expansion score is zero. + virtual float GetStateExpansionScore(const CTCBeamState& state, + float previous_score) const { + return previous_score; + } + // GetStateEndExpansionScore should be an inexpensive method to retrieve the + // (cached) expansion score computed within ExpandStateEnd. The score is + // multiplied (log-addition) with the final probability of the beam. + // + // The score returned should be a log-probability. + virtual float GetStateEndExpansionScore(const CTCBeamState& state) const { + return 0; + } +}; + +} // namespace ctc +} // namespace experimental +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h new file mode 100644 index 0000000000000000000000000000000000000000..c658e43092519ba29d880a670a890af148230091 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h @@ -0,0 +1,420 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Copied from tensorflow/core/util/ctc/ctc_beam_search.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_ + +#include +#include +#include +#include +#include + +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h" +#include "tensorflow/contrib/lite/experimental/kernels/top_n.h" +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" + +namespace tflite { +namespace experimental { +namespace ctc { + +template > +class CTCBeamSearchDecoder : public CTCDecoder { + // Beam Search + // + // Example (GravesTh Fig. 7.5): + // a - + // P = [ 0.3 0.7 ] t = 0 + // [ 0.4 0.6 ] t = 1 + // + // Then P(l = -) = P(--) = 0.7 * 0.6 = 0.42 + // P(l = a) = P(a-) + P(aa) + P(-a) = 0.3*0.4 + ... = 0.58 + // + // In this case, Best Path decoding is suboptimal. + // + // For Beam Search, we use the following main recurrence relations: + // + // Relation 1: + // ---------------------------------------------------------- Eq. 1 + // P(l=abcd @ t=7) = P(l=abc @ t=6) * P(d @ 7) + // + P(l=abcd @ t=6) * (P(d @ 7) + P(- @ 7)) + // where P(l=? @ t=7), ? = a, ab, abc, abcd are all stored and + // updated recursively in the beam entry. + // + // Relation 2: + // ---------------------------------------------------------- Eq. 2 + // P(l=abc? @ t=3) = P(l=abc @ t=2) * P(? @ 3) + // for ? in a, b, d, ..., (not including c or the blank index), + // and the recurrence starts from the beam entry for P(l=abc @ t=2). + // + // For this case, the length of the new sequence equals t+1 (t + // starts at 0). This special case can be calculated as: + // P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3) + // but we calculate it recursively for speed purposes. + typedef ctc_beam_search::BeamEntry BeamEntry; + typedef ctc_beam_search::BeamRoot BeamRoot; + typedef ctc_beam_search::BeamProbability BeamProbability; + + public: + typedef BaseBeamScorer DefaultBeamScorer; + + // The beam search decoder is constructed specifying the beam_width (number of + // candidates to keep at each decoding timestep) and a beam scorer (used for + // custom scoring, for example enabling the use of a language model). + // The ownership of the scorer remains with the caller. The default + // implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, generates the + // standard beam search. + CTCBeamSearchDecoder(int num_classes, int beam_width, + BaseBeamScorer* scorer, int batch_size = 1, + bool merge_repeated = false) + : CTCDecoder(num_classes, batch_size, merge_repeated), + beam_width_(beam_width), + leaves_(beam_width), + beam_scorer_(scorer) { + Reset(); + } + + ~CTCBeamSearchDecoder() override {} + + // Run the hibernating beam search algorithm on the given input. + bool Decode(const CTCDecoder::SequenceLength& seq_len, + const std::vector& input, + std::vector* output, + CTCDecoder::ScoreOutput* scores) override; + + // Calculate the next step of the beam search and update the internal state. + template + void Step(const Vector& log_input_t); + + template + float GetTopK(const int K, const Vector& input, + std::vector* top_k_logits, + std::vector* top_k_indices); + + // Retrieve the beam scorer instance used during decoding. + BaseBeamScorer* GetBeamScorer() const { return beam_scorer_; } + + // Set label selection parameters for faster decoding. + // See comments for label_selection_size_ and label_selection_margin_. + void SetLabelSelectionParameters(int label_selection_size, + float label_selection_margin) { + label_selection_size_ = label_selection_size; + label_selection_margin_ = label_selection_margin; + } + + // Reset the beam search + void Reset(); + + // Extract the top n paths at current time step + bool TopPaths(int n, std::vector>* paths, + std::vector* log_probs, bool merge_repeated) const; + + private: + int beam_width_; + + // Label selection is designed to avoid possibly very expensive scorer calls, + // by pruning the hypotheses based on the input alone. + // Label selection size controls how many items in each beam are passed + // through to the beam scorer. Only items with top N input scores are + // considered. + // Label selection margin controls the difference between minimal input score + // (versus the best scoring label) for an item to be passed to the beam + // scorer. This margin is expressed in terms of log-probability. + // Default is to do no label selection. + // For more detail: https://research.google.com/pubs/pub44823.html + int label_selection_size_ = 0; // zero means unlimited + float label_selection_margin_ = -1; // -1 means unlimited. + + gtl::TopN leaves_; + std::unique_ptr beam_root_; + BaseBeamScorer* beam_scorer_; + + CTCBeamSearchDecoder(const CTCBeamSearchDecoder&) = delete; + void operator=(const CTCBeamSearchDecoder&) = delete; +}; + +template +bool CTCBeamSearchDecoder::Decode( + const CTCDecoder::SequenceLength& seq_len, + const std::vector& input, + std::vector* output, ScoreOutput* scores) { + // Storage for top paths. + std::vector> beams; + std::vector beam_log_probabilities; + int top_n = output->size(); + if (std::any_of(output->begin(), output->end(), + [this](const CTCDecoder::Output& output) -> bool { + return output.size() < this->batch_size_; + })) { + return false; + } + if (scores->rows() < batch_size_ || scores->cols() < top_n) { + return false; + } + + for (int b = 0; b < batch_size_; ++b) { + int seq_len_b = seq_len[b]; + Reset(); + + for (int t = 0; t < seq_len_b; ++t) { + // Pass log-probabilities for this example + time. + Step(input[t].row(b)); + } // for (int t... + + // O(n * log(n)) + std::unique_ptr> branches(leaves_.Extract()); + leaves_.Reset(); + for (int i = 0; i < branches->size(); ++i) { + BeamEntry* entry = (*branches)[i]; + beam_scorer_->ExpandStateEnd(&entry->state); + entry->newp.total += + beam_scorer_->GetStateEndExpansionScore(entry->state); + leaves_.push(entry); + } + + bool status = + TopPaths(top_n, &beams, &beam_log_probabilities, merge_repeated_); + if (!status) { + return status; + } + + TFLITE_DCHECK_EQ(top_n, beam_log_probabilities.size()); + TFLITE_DCHECK_EQ(beams.size(), beam_log_probabilities.size()); + + for (int i = 0; i < top_n; ++i) { + // Copy output to the correct beam + batch + (*output)[i][b].swap(beams[i]); + (*scores)(b, i) = -beam_log_probabilities[i]; + } + } // for (int b... + return true; +} + +template +template +float CTCBeamSearchDecoder::GetTopK( + const int K, const Vector& input, std::vector* top_k_logits, + std::vector* top_k_indices) { + // Find Top K choices, complexity nk in worst case. The array input is read + // just once. + TFLITE_DCHECK_EQ(num_classes_, input.size()); + top_k_logits->clear(); + top_k_indices->clear(); + top_k_logits->resize(K, -INFINITY); + top_k_indices->resize(K, -1); + for (int j = 0; j < num_classes_ - 1; ++j) { + const float logit = input(j); + if (logit > (*top_k_logits)[K - 1]) { + int k = K - 1; + while (k > 0 && logit > (*top_k_logits)[k - 1]) { + (*top_k_logits)[k] = (*top_k_logits)[k - 1]; + (*top_k_indices)[k] = (*top_k_indices)[k - 1]; + k--; + } + (*top_k_logits)[k] = logit; + (*top_k_indices)[k] = j; + } + } + // Return max value which is in 0th index or blank character logit + return std::max((*top_k_logits)[0], input(num_classes_ - 1)); +} + +template +template +void CTCBeamSearchDecoder::Step( + const Vector& raw_input) { + std::vector top_k_logits; + std::vector top_k_indices; + const bool top_k = + (label_selection_size_ > 0 && label_selection_size_ < raw_input.size()); + // Number of character classes to consider in each step. + const int max_classes = top_k ? label_selection_size_ : (num_classes_ - 1); + // Get max coefficient and remove it from raw_input later. + float max_coeff; + if (top_k) { + max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits, + &top_k_indices); + } else { + max_coeff = raw_input.maxCoeff(); + } + const float label_selection_input_min = + (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_) + : -std::numeric_limits::infinity(); + + // Extract the beams sorted in decreasing new probability + TFLITE_DCHECK_EQ(num_classes_, raw_input.size()); + + std::unique_ptr> branches(leaves_.Extract()); + leaves_.Reset(); + + for (BeamEntry* b : *branches) { + // P(.. @ t) becomes the new P(.. @ t-1) + b->oldp = b->newp; + } + + for (BeamEntry* b : *branches) { + if (b->parent != nullptr) { // if not the root + if (b->parent->Active()) { + // If last two sequence characters are identical: + // Plabel(l=acc @ t=6) = (Plabel(l=acc @ t=5) + // + Pblank(l=ac @ t=5)) + // else: + // Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5) + // + P(l=ab @ t=5)) + float previous = (b->label == b->parent->label) ? b->parent->oldp.blank + : b->parent->oldp.total; + b->newp.label = + LogSumExp(b->newp.label, + beam_scorer_->GetStateExpansionScore(b->state, previous)); + } + // Plabel(l=abc @ t=6) *= P(c @ 6) + b->newp.label += raw_input(b->label) - max_coeff; + } + // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6) + b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff; + // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6) + b->newp.total = LogSumExp(b->newp.blank, b->newp.label); + + // Push the entry back to the top paths list. + // Note, this will always fill leaves back up in sorted order. + leaves_.push(b); + } + + // we need to resort branches in descending oldp order. + + // branches is in descending oldp order because it was + // originally in descending newp order and we copied newp to oldp. + + // Grow new leaves + for (BeamEntry* b : *branches) { + // A new leaf (represented by its BeamProbability) is a candidate + // iff its total probability is nonzero and either the beam list + // isn't full, or the lowest probability entry in the beam has a + // lower probability than the leaf. + auto is_candidate = [this](const BeamProbability& prob) { + return (prob.total > kLogZero && + (leaves_.size() < beam_width_ || + prob.total > leaves_.peek_bottom()->newp.total)); + }; + + if (!is_candidate(b->oldp)) { + continue; + } + + for (int ind = 0; ind < max_classes; ind++) { + const int label = top_k ? top_k_indices[ind] : ind; + const float logit = top_k ? top_k_logits[ind] : raw_input(ind); + // Perform label selection: if input for this label looks very + // unpromising, never evaluate it with a scorer. + if (logit < label_selection_input_min) { + continue; + } + BeamEntry& c = b->GetChild(label); + if (!c.Active()) { + // Pblank(l=abcd @ t=6) = 0 + c.newp.blank = kLogZero; + // If new child label is identical to beam label: + // Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6) + // Otherwise: + // Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6) + beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label); + float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total; + c.newp.label = logit - max_coeff + + beam_scorer_->GetStateExpansionScore(c.state, previous); + // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6) + c.newp.total = c.newp.label; + + if (is_candidate(c.newp)) { + // Before adding the new node to the beam, check if the beam + // is already at maximum width. + if (leaves_.size() == beam_width_) { + // Bottom is no longer in the beam search. Reset + // its probability; signal it's no longer in the beam search. + BeamEntry* bottom = leaves_.peek_bottom(); + bottom->newp.Reset(); + } + leaves_.push(&c); + } else { + // Deactivate child. + c.oldp.Reset(); + c.newp.Reset(); + } + } + } + } // for (BeamEntry* b... +} + +template +void CTCBeamSearchDecoder::Reset() { + leaves_.Reset(); + + // This beam root, and all of its children, will be in memory until + // the next reset. + beam_root_.reset(new BeamRoot(nullptr, -1)); + beam_root_->RootEntry()->newp.total = 0.0; // ln(1) + beam_root_->RootEntry()->newp.blank = 0.0; // ln(1) + + // Add the root as the initial leaf. + leaves_.push(beam_root_->RootEntry()); + + // Call initialize state on the root object. + beam_scorer_->InitializeState(&beam_root_->RootEntry()->state); +} + +template +bool CTCBeamSearchDecoder::TopPaths( + int n, std::vector>* paths, std::vector* log_probs, + bool merge_repeated) const { + TFLITE_DCHECK(paths); + TFLITE_DCHECK(log_probs); + paths->clear(); + log_probs->clear(); + if (n > beam_width_) { + return false; + } + if (n > leaves_.size()) { + return false; + } + + gtl::TopN top_branches(n); + + // O(beam_width_ * log(n)), space complexity is O(n) + for (auto it = leaves_.unsorted_begin(); it != leaves_.unsorted_end(); ++it) { + top_branches.push(*it); + } + // O(n * log(n)) + std::unique_ptr> branches(top_branches.Extract()); + + for (int i = 0; i < n; ++i) { + BeamEntry* e((*branches)[i]); + paths->push_back(e->LabelSeq(merge_repeated)); + log_probs->push_back(e->newp.total); + } + return true; +} + +} // namespace ctc +} // namespace experimental +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc new file mode 100644 index 0000000000000000000000000000000000000000..834d1ebd666db2be46394166edadf2a166d958aa --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc @@ -0,0 +1,247 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace experimental { +namespace ctc_beam_search_decoder { + +constexpr int kInputsTensor = 0; +constexpr int kSequenceLengthTensor = 1; + +typedef struct { + int beam_width; + int top_paths; + bool merge_repeated; +} CTCBeamSearchDecoderParams; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + TFLITE_CHECK(buffer != nullptr); + const uint8_t* buffer_t = reinterpret_cast(buffer); + const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); + + CTCBeamSearchDecoderParams* option = new CTCBeamSearchDecoderParams; + option->beam_width = m["beam_width"].AsInt32(); + option->top_paths = m["top_paths"].AsInt32(); + option->merge_repeated = m["merge_repeated"].AsBool(); + + return option; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const CTCBeamSearchDecoderParams* option = + reinterpret_cast(node->user_data); + const int top_paths = option->top_paths; + TF_LITE_ENSURE(context, option->beam_width >= top_paths); + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + // The outputs should be top_paths * 3 + 1. + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3 * top_paths + 1); + + const TfLiteTensor* inputs = GetInput(context, node, kInputsTensor); + TF_LITE_ENSURE_EQ(context, NumDimensions(inputs), 3); + // TensorFlow only supports float. + TF_LITE_ENSURE_EQ(context, inputs->type, kTfLiteFloat32); + const int batch_size = SizeOfDimension(inputs, 1); + + const TfLiteTensor* sequence_length = + GetInput(context, node, kSequenceLengthTensor); + TF_LITE_ENSURE_EQ(context, NumDimensions(sequence_length), 1); + TF_LITE_ENSURE_EQ(context, NumElements(sequence_length), batch_size); + // TensorFlow only supports int32. + TF_LITE_ENSURE_EQ(context, sequence_length->type, kTfLiteInt32); + + // Resize decoded outputs. + // Do not resize indices & values cause we don't know the values yet. + for (int i = 0; i < top_paths; ++i) { + TfLiteTensor* indices = GetOutput(context, node, i); + SetTensorToDynamic(indices); + TfLiteTensor* values = GetOutput(context, node, i + top_paths); + SetTensorToDynamic(values); + TfLiteTensor* output_shape = GetOutput(context, node, i + 2 * top_paths); + SetTensorToDynamic(output_shape); + } + + // Resize log probability outputs. + TfLiteTensor* log_probability_output = + GetOutput(context, node, top_paths * 3); + TfLiteIntArray* log_probability_output_shape_array = TfLiteIntArrayCreate(2); + log_probability_output_shape_array->data[0] = batch_size; + log_probability_output_shape_array->data[1] = top_paths; + return context->ResizeTensor(context, log_probability_output, + log_probability_output_shape_array); +} + +TfLiteStatus Resize(TfLiteContext* context, + std::initializer_list output_shape, + TfLiteTensor* output) { + const int dimensions = output_shape.size(); + TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(dimensions); + int i = 0; + for (const int v : output_shape) { + output_shape_array->data[i++] = v; + } + return context->ResizeTensor(context, output, output_shape_array); +} + +TfLiteStatus StoreAllDecodedSequences( + TfLiteContext* context, + const std::vector>>& sequences, + TfLiteNode* node, int top_paths) { + const int32_t batch_size = sequences.size(); + std::vector num_entries(top_paths, 0); + + // Calculate num_entries per path + for (const auto& batch_s : sequences) { + TF_LITE_ENSURE_EQ(context, batch_s.size(), top_paths); + for (int p = 0; p < top_paths; ++p) { + num_entries[p] += batch_s[p].size(); + } + } + + for (int p = 0; p < top_paths; ++p) { + const int32_t p_num = num_entries[p]; + + // Resize the decoded outputs. + TfLiteTensor* indices = GetOutput(context, node, p); + TF_LITE_ENSURE_OK(context, Resize(context, {p_num, 2}, indices)); + + TfLiteTensor* values = GetOutput(context, node, p + top_paths); + TF_LITE_ENSURE_OK(context, Resize(context, {p_num}, values)); + + TfLiteTensor* decoded_shape = GetOutput(context, node, p + 2 * top_paths); + TF_LITE_ENSURE_OK(context, Resize(context, {2}, decoded_shape)); + + int32_t max_decoded = 0; + int32_t offset = 0; + + int32_t* indices_data = GetTensorData(indices); + int32_t* values_data = GetTensorData(values); + int32_t* decoded_shape_data = GetTensorData(decoded_shape); + for (int b = 0; b < batch_size; ++b) { + auto& p_batch = sequences[b][p]; + int32_t num_decoded = p_batch.size(); + max_decoded = std::max(max_decoded, num_decoded); + + std::copy_n(p_batch.begin(), num_decoded, values_data + offset); + for (int32_t t = 0; t < num_decoded; ++t, ++offset) { + indices_data[offset * 2] = b; + indices_data[offset * 2 + 1] = t; + } + } + + decoded_shape_data[0] = batch_size; + decoded_shape_data[1] = max_decoded; + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* inputs = GetInput(context, node, kInputsTensor); + const TfLiteTensor* sequence_length = + GetInput(context, node, kSequenceLengthTensor); + const CTCBeamSearchDecoderParams* option = + reinterpret_cast(node->user_data); + + const int max_time = SizeOfDimension(inputs, 0); + const int batch_size = SizeOfDimension(inputs, 1); + const int num_classes = SizeOfDimension(inputs, 2); + + const int beam_width = option->beam_width; + const int top_paths = option->top_paths; + const bool merge_repeated = option->merge_repeated; + + // Validate sequence length is less or equal than max time. + for (int i = 0; i < batch_size; ++i) { + TF_LITE_ENSURE(context, + max_time >= GetTensorData(sequence_length)[i]); + } + + // The following logic is implemented like + // tensorflow/core/kernels/ctc_decoder_ops.cc + std::vector::UnalignedConstMatrix> input_list_t; + + for (std::size_t t = 0; t < max_time; ++t) { + input_list_t.emplace_back( + GetTensorData(inputs) + t * batch_size * num_classes, batch_size, + num_classes); + } + + ::tflite::experimental::ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer + beam_scorer; + ::tflite::experimental::ctc::CTCBeamSearchDecoder<> beam_search( + num_classes, beam_width, &beam_scorer, 1 /* batch_size */, + merge_repeated); + + // Allocate temporary memory for holding chip operation data. + float* input_chip_t_data = + static_cast(malloc(num_classes * sizeof(float))); + Eigen::array dims; + dims[0] = num_classes; + optimized_ops::TTypes::Flat input_chip_t(input_chip_t_data, dims); + + std::vector>> best_paths(batch_size); + std::vector log_probs; + + TfLiteTensor* log_probabilities = GetOutput(context, node, 3 * top_paths); + float* log_probabilities_output = GetTensorData(log_probabilities); + + // Assumption: the blank index is num_classes - 1 + for (int b = 0; b < batch_size; ++b) { + auto& best_paths_b = best_paths[b]; + best_paths_b.resize(top_paths); + for (int t = 0; t < GetTensorData(sequence_length)[b]; ++t) { + input_chip_t = input_list_t[t].chip(b, 0); + auto input_bi = + Eigen::Map(input_chip_t.data(), num_classes); + beam_search.Step(input_bi); + } + TF_LITE_ENSURE(context, beam_search.TopPaths(top_paths, &best_paths_b, + &log_probs, merge_repeated)); + beam_search.Reset(); + + // Fill in log_probabilities output. + for (int bp = 0; bp < top_paths; ++bp) { + log_probabilities_output[b * top_paths + bp] = log_probs[bp]; + } + } + + free(input_chip_t_data); + return StoreAllDecodedSequences(context, best_paths, node, top_paths); +} + +} // namespace ctc_beam_search_decoder + +TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER() { + static TfLiteRegistration r = { + ctc_beam_search_decoder::Init, ctc_beam_search_decoder::Free, + ctc_beam_search_decoder::Prepare, ctc_beam_search_decoder::Eval}; + return &r; +} + +} // namespace experimental +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9d1e6a562f00905d1db7f7e055ac1c6b1cc34f9e --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc @@ -0,0 +1,238 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace ops { +namespace experimental { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER(); + +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +class CTCBeamSearchDecoderOpModel : public SingleOpModel { + public: + CTCBeamSearchDecoderOpModel(std::initializer_list input_shape, + std::initializer_list sequence_length_shape, + int beam_width, int top_paths, + bool merge_repeated) { + inputs_ = AddInput(TensorType_FLOAT32); + sequence_length_ = AddInput(TensorType_INT32); + + for (int i = 0; i < top_paths * 3; ++i) { + outputs_.push_back(AddOutput(TensorType_INT32)); + } + outputs_.push_back(AddOutput(TensorType_FLOAT32)); + + flexbuffers::Builder fbb; + fbb.Map([&]() { + fbb.Int("beam_width", beam_width); + fbb.Int("top_paths", top_paths); + fbb.Bool("merge_repeated", merge_repeated); + }); + fbb.Finish(); + SetCustomOp("CTCBeamSearchDecoder", fbb.GetBuffer(), + Register_CTC_BEAM_SEARCH_DECODER); + BuildInterpreter({input_shape, sequence_length_shape}); + } + + int inputs() { return inputs_; } + + int sequence_length() { return sequence_length_; } + + std::vector> GetDecodedOutpus() { + std::vector> outputs; + for (int i = 0; i < outputs_.size() - 1; ++i) { + outputs.push_back(ExtractVector(outputs_[i])); + } + return outputs; + } + + std::vector GetLogProbabilitiesOutput() { + return ExtractVector(outputs_[outputs_.size() - 1]); + } + + std::vector> GetOutputShapes() { + std::vector> output_shapes; + for (const int output : outputs_) { + output_shapes.push_back(GetTensorShape(output)); + } + return output_shapes; + } + + private: + int inputs_; + int sequence_length_; + std::vector outputs_; +}; + +TEST(CTCBeamSearchTest, SimpleTest) { + CTCBeamSearchDecoderOpModel m({2, 1, 2}, {1}, 1, 1, true); + m.PopulateTensor(m.inputs(), + {-0.50922557, -1.35512652, -2.55445064, -1.58419356}); + m.PopulateTensor(m.sequence_length(), {2}); + m.Invoke(); + + // Make sure the output shapes are right. + const std::vector>& output_shapes = m.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 4); + EXPECT_THAT(output_shapes[0], ElementsAre(1, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(1)); + EXPECT_THAT(output_shapes[2], ElementsAre(2)); + EXPECT_THAT(output_shapes[3], ElementsAre(1, 1)); + + // Check decoded outputs. + const std::vector>& decoded_outputs = m.GetDecodedOutpus(); + EXPECT_EQ(decoded_outputs.size(), 3); + EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0)); + EXPECT_THAT(decoded_outputs[1], ElementsAre(0)); + EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 1)); + // Check log probabilities output. + EXPECT_THAT(m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear({0.32134813}))); +} + +TEST(CTCBeamSearchTest, MultiBatchTest) { + CTCBeamSearchDecoderOpModel m({3, 3, 3}, {3}, 1, 1, true); + m.PopulateTensor( + m.inputs(), + {-0.63649208, -0.00487571, -0.04249819, -0.67754697, -1.0341399, + -2.14717721, -0.77686821, -3.41973774, -0.05151402, -0.21482619, + -0.57411168, -1.45039917, -0.73769373, -2.10941739, -0.44818325, + -0.25287673, -2.80057302, -0.54748312, -0.73334867, -0.86537719, + -0.2065197, -0.18725838, -1.42770405, -0.86051965, -1.61642301, + -2.07275114, -0.9201845}); + m.PopulateTensor(m.sequence_length(), {3, 3, 3}); + m.Invoke(); + + // Make sure the output shapes are right. + const std::vector>& output_shapes = m.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 4); + EXPECT_THAT(output_shapes[0], ElementsAre(4, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(4)); + EXPECT_THAT(output_shapes[2], ElementsAre(2)); + EXPECT_THAT(output_shapes[3], ElementsAre(3, 1)); + + // Check decoded outputs. + const std::vector>& decoded_outputs = m.GetDecodedOutpus(); + EXPECT_EQ(decoded_outputs.size(), 3); + EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 0, 1, 1, 0, 2, 0)); + EXPECT_THAT(decoded_outputs[1], ElementsAre(1, 0, 0, 0)); + EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 2)); + // Check log probabilities output. + EXPECT_THAT( + m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear({0.46403232, 0.49500442, 0.40443572}))); +} + +TEST(CTCBeamSearchTest, MultiPathsTest) { + CTCBeamSearchDecoderOpModel m({3, 2, 5}, {2}, 3, 2, true); + m.PopulateTensor( + m.inputs(), + {-2.206851, -0.09542714, -0.2393415, -3.81866197, -0.27241158, + -0.20371124, -0.68236623, -1.1397166, -0.17422639, -1.85224048, + -0.9406037, -0.32544678, -0.21846784, -0.38377237, -0.33498676, + -0.10139782, -0.51886883, -0.21678554, -0.15267063, -1.91164412, + -0.31328673, -0.27462716, -0.65975336, -1.53671973, -2.76554225, + -0.23920634, -1.2370502, -4.98751576, -3.12995717, -0.43129368}); + m.PopulateTensor(m.sequence_length(), {3, 3}); + m.Invoke(); + + // Make sure the output shapes are right. + const std::vector>& output_shapes = m.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 7); + EXPECT_THAT(output_shapes[0], ElementsAre(4, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(3, 2)); + EXPECT_THAT(output_shapes[2], ElementsAre(4)); + EXPECT_THAT(output_shapes[3], ElementsAre(3)); + EXPECT_THAT(output_shapes[4], ElementsAre(2)); + EXPECT_THAT(output_shapes[5], ElementsAre(2)); + EXPECT_THAT(output_shapes[6], ElementsAre(2, 2)); + + // Check decoded outputs. + const std::vector>& decoded_outputs = m.GetDecodedOutpus(); + EXPECT_EQ(decoded_outputs.size(), 6); + EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 0, 1, 1, 0, 1, 1)); + EXPECT_THAT(decoded_outputs[1], ElementsAre(0, 0, 0, 1, 1, 0)); + EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 2, 3, 0)); + EXPECT_THAT(decoded_outputs[3], ElementsAre(2, 1, 0)); + EXPECT_THAT(decoded_outputs[4], ElementsAre(2, 2)); + EXPECT_THAT(decoded_outputs[5], ElementsAre(2, 2)); + // Check log probabilities output. + EXPECT_THAT(m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear( + {0.91318405, 0.9060272, 1.0780245, 0.64358956}))); +} + +TEST(CTCBeamSearchTest, NonEqualSequencesTest) { + CTCBeamSearchDecoderOpModel m({3, 3, 4}, {3}, 3, 1, true); + m.PopulateTensor( + m.inputs(), + {-1.26658163, -0.25760023, -0.03917975, -0.63772235, -0.03794756, + -0.45063099, -0.27706473, -0.01569179, -0.59940385, -0.35700127, + -0.48920721, -1.42635476, -1.3462478, -0.02565498, -0.30179568, + -0.6491698, -0.55017719, -2.92291466, -0.92522973, -0.47592022, + -0.07099135, -0.31575624, -0.86345281, -0.36017021, -0.79208612, + -1.75306124, -0.65089224, -0.00912786, -0.42915003, -1.72606203, + -1.66337589, -0.70800793, -2.52272352, -0.67329562, -2.49145522, + -0.49786342}); + m.PopulateTensor(m.sequence_length(), {1, 2, 3}); + m.Invoke(); + + // Make sure the output shapes are right. + const std::vector>& output_shapes = m.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 4); + EXPECT_THAT(output_shapes[0], ElementsAre(3, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(3)); + EXPECT_THAT(output_shapes[2], ElementsAre(2)); + EXPECT_THAT(output_shapes[3], ElementsAre(3, 1)); + + // Check decoded outputs. + const std::vector>& decoded_outputs = m.GetDecodedOutpus(); + EXPECT_EQ(decoded_outputs.size(), 3); + EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 1, 0, 2, 0)); + EXPECT_THAT(decoded_outputs[1], ElementsAre(2, 0, 1)); + EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 1)); + // Check log probabilities output. + EXPECT_THAT(m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear({0., 1.0347567, 0.7833005}))); +} + +} // namespace +} // namespace experimental +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h b/tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h new file mode 100644 index 0000000000000000000000000000000000000000..596ad4a5f7264ae24caa5592d10c09c256629b06 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h @@ -0,0 +1,114 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Copied from tensorflow/core/util/ctc/ctc_decoder.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_ + +#include +#include + +#include "third_party/eigen3/Eigen/Core" + +namespace tflite { +namespace experimental { +namespace ctc { + +// The CTCDecoder is an abstract interface to be implemented when providing a +// decoding method on the timestep output of a RNN trained with CTC loss. +// +// The two types of decoding available are: +// - greedy path, through the CTCGreedyDecoder +// - beam search, through the CTCBeamSearchDecoder +class CTCDecoder { + public: + typedef Eigen::Map SequenceLength; + typedef Eigen::Map Input; + typedef std::vector> Output; + typedef Eigen::Map ScoreOutput; + + CTCDecoder(int num_classes, int batch_size, bool merge_repeated) + : num_classes_(num_classes), + blank_index_(num_classes - 1), + batch_size_(batch_size), + merge_repeated_(merge_repeated) {} + + virtual ~CTCDecoder() {} + + // Dimensionality of the input/output is expected to be: + // - seq_len[b] - b = 0 to batch_size_ + // - input[t].rows(b) - t = 0 to timesteps; b = 0 t batch_size_ + // - output.size() specifies the number of beams to be returned. + // - scores(b, i) - b = 0 to batch_size; i = 0 to output.size() + virtual bool Decode(const SequenceLength& seq_len, + const std::vector& input, + std::vector* output, ScoreOutput* scores) = 0; + + int batch_size() { return batch_size_; } + int num_classes() { return num_classes_; } + + protected: + int num_classes_; + int blank_index_; + int batch_size_; + bool merge_repeated_; +}; + +// CTCGreedyDecoder is an implementation of the simple best path decoding +// algorithm, selecting at each timestep the most likely class at each timestep. +class CTCGreedyDecoder : public CTCDecoder { + public: + CTCGreedyDecoder(int num_classes, int batch_size, bool merge_repeated) + : CTCDecoder(num_classes, batch_size, merge_repeated) {} + + bool Decode(const CTCDecoder::SequenceLength& seq_len, + const std::vector& input, + std::vector* output, + CTCDecoder::ScoreOutput* scores) override { + if (output->empty() || (*output)[0].size() < batch_size_) { + return false; + } + if (scores->rows() < batch_size_ || scores->cols() == 0) { + return false; + } + // For each batch entry, identify the transitions + for (int b = 0; b < batch_size_; ++b) { + int seq_len_b = seq_len[b]; + // Only writing to beam 0 + std::vector& output_b = (*output)[0][b]; + + int prev_class_ix = -1; + (*scores)(b, 0) = 0; + for (int t = 0; t < seq_len_b; ++t) { + auto row = input[t].row(b); + int max_class_ix; + (*scores)(b, 0) += -row.maxCoeff(&max_class_ix); + if (max_class_ix != blank_index_ && + !(merge_repeated_ && max_class_ix == prev_class_ix)) { + output_b.push_back(max_class_ix); + } + prev_class_ix = max_class_ix; + } + } + return true; + } +}; + +} // namespace ctc +} // namespace experimental +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h b/tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h new file mode 100644 index 0000000000000000000000000000000000000000..0bae732533716ac047a55ea31633c8ed51253fe0 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Copied from tensorflow/core/util/ctc/ctc_loss_util.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_ + +#include +#include + +namespace tflite { +namespace experimental { +namespace ctc { + +const float kLogZero = -std::numeric_limits::infinity(); + +// Add logarithmic probabilities using: +// ln(a + b) = ln(a) + ln(1 + exp(ln(b) - ln(a))) +// The two inputs are assumed to be log probabilities. +// (GravesTh) Eq. 7.18 +inline float LogSumExp(float log_prob_1, float log_prob_2) { + // Always have 'b' be the smaller number to avoid the exponential from + // blowing up. + if (log_prob_1 == kLogZero && log_prob_2 == kLogZero) { + return kLogZero; + } else { + return (log_prob_1 > log_prob_2) + ? log_prob_1 + log1pf(expf(log_prob_2 - log_prob_1)) + : log_prob_2 + log1pf(expf(log_prob_1 - log_prob_2)); + } +} + +} // namespace ctc +} // namespace experimental +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/top_n.h b/tensorflow/contrib/lite/experimental/kernels/top_n.h new file mode 100644 index 0000000000000000000000000000000000000000..cd2a2f1c80276d4659ccd2f8f05af3af030acb90 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/top_n.h @@ -0,0 +1,341 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This simple class finds the top n elements of an incrementally provided set +// of elements which you push one at a time. If the number of elements exceeds +// n, the lowest elements are incrementally dropped. At the end you get +// a vector of the top elements sorted in descending order (through Extract() or +// ExtractNondestructive()), or a vector of the top elements but not sorted +// (through ExtractUnsorted() or ExtractUnsortedNondestructive()). +// +// The value n is specified in the constructor. If there are p elements pushed +// altogether: +// The total storage requirements are O(min(n, p)) elements +// The running time is O(p * log(min(n, p))) comparisons +// If n is a constant, the total storage required is a constant and the running +// time is linear in p. +// +// NOTE(zhifengc): There is a way to do this in O(min(n, p)) storage and O(p) +// runtime. The basic idea is to repeatedly fill up a buffer of 2 * n elements, +// discarding the lowest n elements whenever the buffer is full using a linear- +// time median algorithm. This may have better performance when the input +// sequence is partially sorted. +// +// NOTE(zhifengc): This class should be redesigned to avoid reallocating a +// vector for each Extract. + +// Copied from tensorflow/core/lib/gtl/top_n.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" + +namespace tflite { +namespace gtl { + +// Cmp is an stl binary predicate. Note that Cmp is the "greater" predicate, +// not the more commonly used "less" predicate. +// +// If you use a "less" predicate here, the TopN will pick out the bottom N +// elements out of the ones passed to it, and it will return them sorted in +// ascending order. +// +// TopN is rule-of-zero copyable and movable if its members are. +template > +class TopN { + public: + // The TopN is in one of the three states: + // + // o UNORDERED: this is the state an instance is originally in, + // where the elements are completely orderless. + // + // o BOTTOM_KNOWN: in this state, we keep the invariant that there + // is at least one element in it, and the lowest element is at + // position 0. The elements in other positions remain + // unsorted. This state is reached if the state was originally + // UNORDERED and a peek_bottom() function call is invoked. + // + // o HEAP_SORTED: in this state, the array is kept as a heap and + // there are exactly (limit_+1) elements in the array. This + // state is reached when at least (limit_+1) elements are + // pushed in. + // + // The state transition graph is at follows: + // + // peek_bottom() (limit_+1) elements + // UNORDERED --------------> BOTTOM_KNOWN --------------------> HEAP_SORTED + // | ^ + // | (limit_+1) elements | + // +-----------------------------------------------------------+ + + enum State { UNORDERED, BOTTOM_KNOWN, HEAP_SORTED }; + using UnsortedIterator = typename std::vector::const_iterator; + + // 'limit' is the maximum number of top results to return. + explicit TopN(size_t limit) : TopN(limit, Cmp()) {} + TopN(size_t limit, const Cmp &cmp) : limit_(limit), cmp_(cmp) {} + + size_t limit() const { return limit_; } + + // Number of elements currently held by this TopN object. This + // will be no greater than 'limit' passed to the constructor. + size_t size() const { return std::min(elements_.size(), limit_); } + + bool empty() const { return size() == 0; } + + // If you know how many elements you will push at the time you create the + // TopN object, you can call reserve to preallocate the memory that TopN + // will need to process all 'n' pushes. Calling this method is optional. + void reserve(size_t n) { elements_.reserve(std::min(n, limit_ + 1)); } + + // Push 'v'. If the maximum number of elements was exceeded, drop the + // lowest element and return it in 'dropped' (if given). If the maximum is not + // exceeded, 'dropped' will remain unchanged. 'dropped' may be omitted or + // nullptr, in which case it is not filled in. + // Requires: T is CopyAssignable, Swappable + void push(const T &v) { push(v, nullptr); } + void push(const T &v, T *dropped) { PushInternal(v, dropped); } + + // Move overloads of push. + // Requires: T is MoveAssignable, Swappable + void push(T &&v) { // NOLINT(build/c++11) + push(std::move(v), nullptr); + } + void push(T &&v, T *dropped) { // NOLINT(build/c++11) + PushInternal(std::move(v), dropped); + } + + // Peeks the bottom result without calling Extract() + const T &peek_bottom(); + + // Extract the elements as a vector sorted in descending order. The caller + // assumes ownership of the vector and must delete it when done. This is a + // destructive operation. The only method that can be called immediately + // after Extract() is Reset(). + std::vector *Extract(); + + // Similar to Extract(), but makes no guarantees the elements are in sorted + // order. As with Extract(), the caller assumes ownership of the vector and + // must delete it when done. This is a destructive operation. The only + // method that can be called immediately after ExtractUnsorted() is Reset(). + std::vector *ExtractUnsorted(); + + // A non-destructive version of Extract(). Copy the elements in a new vector + // sorted in descending order and return it. The caller assumes ownership of + // the new vector and must delete it when done. After calling + // ExtractNondestructive(), the caller can continue to push() new elements. + std::vector *ExtractNondestructive() const; + + // A non-destructive version of Extract(). Copy the elements to a given + // vector sorted in descending order. After calling + // ExtractNondestructive(), the caller can continue to push() new elements. + // Note: + // 1. The given argument must to be allocated. + // 2. Any data contained in the vector prior to the call will be deleted + // from it. After the call the vector will contain only the elements + // from the data structure. + void ExtractNondestructive(std::vector *output) const; + + // A non-destructive version of ExtractUnsorted(). Copy the elements in a new + // vector and return it, with no guarantees the elements are in sorted order. + // The caller assumes ownership of the new vector and must delete it when + // done. After calling ExtractUnsortedNondestructive(), the caller can + // continue to push() new elements. + std::vector *ExtractUnsortedNondestructive() const; + + // A non-destructive version of ExtractUnsorted(). Copy the elements into + // a given vector, with no guarantees the elements are in sorted order. + // After calling ExtractUnsortedNondestructive(), the caller can continue + // to push() new elements. + // Note: + // 1. The given argument must to be allocated. + // 2. Any data contained in the vector prior to the call will be deleted + // from it. After the call the vector will contain only the elements + // from the data structure. + void ExtractUnsortedNondestructive(std::vector *output) const; + + // Return an iterator to the beginning (end) of the container, + // with no guarantees about the order of iteration. These iterators are + // invalidated by mutation of the data structure. + UnsortedIterator unsorted_begin() const { return elements_.begin(); } + UnsortedIterator unsorted_end() const { return elements_.begin() + size(); } + + // Accessor for comparator template argument. + Cmp *comparator() { return &cmp_; } + + // This removes all elements. If Extract() or ExtractUnsorted() have been + // called, this will put it back in an empty but useable state. + void Reset(); + + private: + template + void PushInternal(U &&v, T *dropped); // NOLINT(build/c++11) + + // elements_ can be in one of two states: + // elements_.size() <= limit_: elements_ is an unsorted vector of elements + // pushed so far. + // elements_.size() > limit_: The last element of elements_ is unused; + // the other elements of elements_ are an stl heap whose size is exactly + // limit_. In this case elements_.size() is exactly one greater than + // limit_, but don't use "elements_.size() == limit_ + 1" to check for + // that because you'll get a false positive if limit_ == size_t(-1). + std::vector elements_; + size_t limit_; // Maximum number of elements to find + Cmp cmp_; // Greater-than comparison function + State state_ = UNORDERED; +}; + +// ---------------------------------------------------------------------- +// Implementations of non-inline functions + +template +template +void TopN::PushInternal(U &&v, T *dropped) { // NOLINT(build/c++11) + if (limit_ == 0) { + if (dropped) *dropped = std::forward(v); // NOLINT(build/c++11) + return; + } + if (state_ != HEAP_SORTED) { + elements_.push_back(std::forward(v)); // NOLINT(build/c++11) + if (state_ == UNORDERED || cmp_(elements_.back(), elements_.front())) { + // Easy case: we just pushed the new element back + } else { + // To maintain the BOTTOM_KNOWN state, we need to make sure that + // the element at position 0 is always the smallest. So we put + // the new element at position 0 and push the original bottom + // element in the back. + // Warning: this code is subtle. + using std::swap; + swap(elements_.front(), elements_.back()); + } + if (elements_.size() == limit_ + 1) { + // Transition from unsorted vector to a heap. + std::make_heap(elements_.begin(), elements_.end(), cmp_); + if (dropped) *dropped = std::move(elements_.front()); + std::pop_heap(elements_.begin(), elements_.end(), cmp_); + state_ = HEAP_SORTED; + } + } else { + // Only insert the new element if it is greater than the least element. + if (cmp_(v, elements_.front())) { + elements_.back() = std::forward(v); // NOLINT(build/c++11) + std::push_heap(elements_.begin(), elements_.end(), cmp_); + if (dropped) *dropped = std::move(elements_.front()); + std::pop_heap(elements_.begin(), elements_.end(), cmp_); + } else { + if (dropped) *dropped = std::forward(v); // NOLINT(build/c++11) + } + } +} + +template +const T &TopN::peek_bottom() { + TFLITE_DCHECK(!empty()); + if (state_ == UNORDERED) { + // We need to do a linear scan to find out the bottom element + int min_candidate = 0; + for (size_t i = 1; i < elements_.size(); ++i) { + if (cmp_(elements_[min_candidate], elements_[i])) { + min_candidate = i; + } + } + // By swapping the element at position 0 and the minimal + // element, we transition to the BOTTOM_KNOWN state + if (min_candidate != 0) { + using std::swap; + swap(elements_[0], elements_[min_candidate]); + } + state_ = BOTTOM_KNOWN; + } + return elements_.front(); +} + +template +std::vector *TopN::Extract() { + auto out = new std::vector; + out->swap(elements_); + if (state_ != HEAP_SORTED) { + std::sort(out->begin(), out->end(), cmp_); + } else { + out->pop_back(); + std::sort_heap(out->begin(), out->end(), cmp_); + } + return out; +} + +template +std::vector *TopN::ExtractUnsorted() { + auto out = new std::vector; + out->swap(elements_); + if (state_ == HEAP_SORTED) { + // Remove the limit_+1'th element. + out->pop_back(); + } + return out; +} + +template +std::vector *TopN::ExtractNondestructive() const { + auto out = new std::vector; + ExtractNondestructive(out); + return out; +} + +template +void TopN::ExtractNondestructive(std::vector *output) const { + TFLITE_DCHECK(output); + *output = elements_; + if (state_ != HEAP_SORTED) { + std::sort(output->begin(), output->end(), cmp_); + } else { + output->pop_back(); + std::sort_heap(output->begin(), output->end(), cmp_); + } +} + +template +std::vector *TopN::ExtractUnsortedNondestructive() const { + auto elements = new std::vector; + ExtractUnsortedNondestructive(elements); + return elements; +} + +template +void TopN::ExtractUnsortedNondestructive(std::vector *output) const { + TFLITE_DCHECK(output); + *output = elements_; + if (state_ == HEAP_SORTED) { + // Remove the limit_+1'th element. + output->pop_back(); + } +} + +template +void TopN::Reset() { + elements_.clear(); + state_ = UNORDERED; +} + +} // namespace gtl +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_ diff --git a/tensorflow/contrib/lite/g3doc/README.md b/tensorflow/contrib/lite/g3doc/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e3db4784815b7562588d3afbd34f837b101f0977 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/README.md @@ -0,0 +1,4 @@ +This is a *work-in-progress* TF Lite subsite for: +https://www.tensorflow.org/mobile + +DO NOT PUBLISH diff --git a/tensorflow/contrib/lite/g3doc/_book.yaml b/tensorflow/contrib/lite/g3doc/_book.yaml new file mode 100644 index 0000000000000000000000000000000000000000..98abd5743b2412399496f2fb3a70cd25d8597bca --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/_book.yaml @@ -0,0 +1,58 @@ +upper_tabs: +# Tabs left of dropdown menu +- include: /_upper_tabs_left.yaml +# Dropdown menu +- name: Ecosystem + path: /ecosystem + is_default: True + menu: + - include: /ecosystem/_menu_toc.yaml + lower_tabs: + # Subsite tabs + other: + - name: Guide + contents: + - title: Overview + path: /mobile/overview + - title: Developer Guide + path: /mobile/devguide + - title: Android Demo App + path: /mobile/demo_android + - title: iOS Demo App + path: /mobile/demo_ios + - title: Performance + path: /mobile/performance + - break: True + - title: TensorFlow Lite APIs + path: /mobile/apis + - title: Custom operators + path: /mobile/custom_operators + - title: TensorFlow Lite Ops Versioning + path: /mobile/ops_versioning + - title: TensorFlow Lite Compatibility Guide + path: /mobile/tf_ops_compatibility + - title: List of Hosted Models + path: /mobile/models + - title: TensorFlow Lite for iOS + path: /mobile/ios + - title: TensorFlow Lite for Raspberry Pi + path: /mobile/rpi + + - heading: TF Mobile + status: deprecated + - title: Overview + path: /mobile/tfmobile/ + - title: Building TensorFlow on Android + path: /mobile/tfmobile/android_build + - title: Building TensorFlow on IOS + path: /mobile/tfmobile/ios_build + - title: Integrating TensorFlow libraries + path: /mobile/tfmobile/linking_libs + - title: Preparing models for mobile deployment + path: /mobile/tfmobile/prepare_models + - title: Optimizing for mobile + path: /mobile/tfmobile/optimizing + + - name: API + contents: + - include: /mobile/api_docs/python/_toc.yaml diff --git a/tensorflow/contrib/lite/g3doc/_index.yaml b/tensorflow/contrib/lite/g3doc/_index.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9119e49117ffbda268f36324072d30ffd83c9e6c --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/_index.yaml @@ -0,0 +1,67 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml +description: +landing_page: + rows: + - heading: TensorFlow Lite is a lightweight solution for mobile and embedded devices. + items: + - description: > + TensorFlow Lite is TensorFlow’s lightweight solution for mobile and + embedded devices. It enables on-device machine learning inference with + low latency and a small binary size. TensorFlow Lite also supports + hardware acceleration with the + Android Neural Networks API. + list: + - heading: Key point 1 + description: > + [high-level overview] + icon: + icon_name: chevron_right + foreground: theme + background: grey + - heading: Key point 2 + description: > + [high-level overview] + icon: + icon_name: chevron_right + foreground: theme + background: grey + - heading: Key point 3 + description: > + [high-level overview] + icon: + icon_name: chevron_right + foreground: theme + background: grey + - code_block: | +
+        $ toco --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
+               --input_format=TENSORFLOW_GRAPHDEF \
+               --output_format=TFLITE \
+               --output_file=/tmp/mobilenet_v1_1.0_224.tflite \
+               --inference_type=FLOAT \
+               --input_type=FLOAT \
+               --input_arrays=input \
+               --output_arrays=MobilenetV1/Predictions/Reshape_1 \
+               --input_shapes=1,224,224,3
+        
+ + - classname: devsite-landing-row-cards + items: + - heading: Using TensorFlow Lite on Android + image_path: /ecosystem/images/tf-logo-card-16x9.png + path: https://medium.com/tensorflow/using-tensorflow-lite-on-android-9bbc9cb7d69d + buttons: + - label: Read on TensorFlow blog + path: https://medium.com/tensorflow/using-tensorflow-lite-on-android-9bbc9cb7d69d + - heading: TensorFlow Lite at the Dev Summit + youtube_id: FAMfy7izB6A + buttons: + - label: Watch the video + path: https://www.youtube.com/watch?v=FAMfy7izB6A + - heading: TensorFlow Lite on GitHub + image_path: /ecosystem/images/github-card-16x9.png + path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite + buttons: + - label: View on GitHub + path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite diff --git a/tensorflow/contrib/lite/g3doc/_project.yaml b/tensorflow/contrib/lite/g3doc/_project.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b39666516baab42d289e4d40077c2877ed65d396 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/_project.yaml @@ -0,0 +1,10 @@ +name: TensorFlow Lite +breadcrumb_name: Mobile +home_url: /mobile/ +parent_project_metadata_path: /_project.yaml +description: > + TensorFlow Lite is a lightweight solution for mobile and embedded devices. +use_site_branding: True +hide_from_products_list: True +content_license: cc3-apache2 +buganizer_id: 316308 diff --git a/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml b/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e1c44c6929571144d8cf0b54463c48e37466022 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml @@ -0,0 +1,6 @@ +# Automatically generated file; please do not edit +toc: + - title: TensorFlow Lite + section: + - title: Overview + path: /mobile/api_docs/python/ diff --git a/tensorflow/contrib/lite/g3doc/api_docs/python/index.md b/tensorflow/contrib/lite/g3doc/api_docs/python/index.md new file mode 100644 index 0000000000000000000000000000000000000000..70031a3c3d26eb6557014879cc92288cd22331eb --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/api_docs/python/index.md @@ -0,0 +1,10 @@ +Project: /mobile/_project.yaml +Book: /mobile/_book.yaml +page_type: reference + + + + +# All symbols in TensorFlow Lite + +TEMP PAGE diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md index 50cc146a87ee9ab94aea6a92fb2fb5c531f83369..776803da8c7126c6198e3740448888119df030b9 100644 --- a/tensorflow/contrib/lite/g3doc/apis.md +++ b/tensorflow/contrib/lite/g3doc/apis.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # TensorFlow Lite APIs TensorFlow Lite provides programming APIs in C++ and Java, and in both cases @@ -7,6 +10,9 @@ no surprise that the APIs try to avoid unnecessary copies at the expense of convenience. Similarly, consistency with TensorFlow APIs was not an explicit goal and some variance is to be expected. +There is also a Python API for TensorFlow Lite described +[here](../toco/g3doc/python_api.md#interpreter). + ## C++ In order to run the inference model in TensorFlow Lite, one has to load the @@ -50,6 +56,7 @@ typedef enum { ``` Failures can be easily verified with: + ```c++ if (status != kTfLiteOk) { // ... error handling here ... diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md index 972e57f73e82961ebc5e341dd7a41bc00acc5d21..d979353bb3550fe53d86b2e6c76702a3970b01fe 100644 --- a/tensorflow/contrib/lite/g3doc/custom_operators.md +++ b/tensorflow/contrib/lite/g3doc/custom_operators.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # How to use custom operators TensorFlow Lite currently supports a subset of TensorFlow operators. However, it @@ -89,3 +92,83 @@ builtins.AddCustom("Sin", Register_SIN()); Note that a similar process as above can be followed for supporting for a set of operations instead of a single operator. + +## Best Practices for writing custom operators + +1. Optimize memory allocations and de-allocations cautiously. It is more + efficient to allocate memory in Prepare() instead of Invoke(), and allocate + memory before a loop instead of in every iteration. Use temporary tensors + data rather than mallocing yourself (see item 2). Use pointers/references + instead of copying as much as possible. + +2. If a data structure will persist during the entire operation, we advise + pre-allocating the memory using temporary tensors. You may need to use + OpData struct to reference the tensor indices in other functions. See + example in the + [kernel for convolution](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/kernels/conv.cc). + A sample code snippet is below + + ``` + auto* op_data = reinterpret_cast(node->user_data); + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(1); + node->temporaries->data[0] = op_data->temp_tensor_index; + TfLiteTensor* temp_tensor = &context->tensors[op_data->temp_tensor_index]; + temp_tensor->type = kTfLiteFloat32; + temp_tensor->allocation_type = kTfLiteArenaRw; + ``` + +3. If it doesn't cost too much wasted memory, prefer using a static fixed size + array (or in Resize() pre-allocated std::vector) rather than using a + dynamically allocating std::vector every iteration of execution. + +4. Avoid instantiating standard library container templates that don't already + exist, because they affect binary size. For example, if you need a std::map + in your operation that doesn't exist in other kernels, using a std::vector + with direct indexing mapping could work while keeping the binary size small. + See what other kernels use to gain insight (or ask). + +5. Check the pointer to the memory returned by malloc. If this pointer is + nullptr, no operations should be performed using that pointer. If you + malloc() in a function and have an error exit, deallocate memory before you + exit. + +6. Use TF_LITE_ENSURE(context, condition) to check for a specific condition. + Your code must not leave memory hanging when TF_LITE_ENSURE is done, i.e., + these should be done before any resources are allocated that will leak. + +## Special TF Graph Attributes + +When Toco convertes a TF graph into TFLite format, it makes some assumption +about custom operations that might be not correct. In this case, the generated +graph can be not executable. + +It is possible to add aditional information about your custom op output to TF +graph before it is converted. The following attributes are supported: + +- **_output_quantized** a boolean attribute, true if the operation outputs are + quantized +- **_output_types** a list of types for output tensors +- **_output_shapes** a list of shapes for output tensors + +### Setting the Attributes + +This is an example how the attributes can be set: + +```python +frozen_graph_def = tf.graph_util.convert_variables_to_constants(...) +for node in frozen_graph_def.node: + if node.op == 'sin': + node.attr['_output_types'].list.type.extend([ + types_pb2.DT_FLOAT, + ]) + node.attr['_output_shapes'].list.shape.extend([ + tf.TensorShape([10]), + ]) + node.attr['_output_quantized'].b = False +tflite_model = tf.contrib.lite.toco_convert( + frozen_graph_def,...) +``` + +**Note:** After the attributes are set, the graph can not be executed by +Tensorflow, therefore it should be done just before the conversion. diff --git a/tensorflow/docs_src/mobile/tflite/demo_android.md b/tensorflow/contrib/lite/g3doc/demo_android.md similarity index 82% rename from tensorflow/docs_src/mobile/tflite/demo_android.md rename to tensorflow/contrib/lite/g3doc/demo_android.md index 7f2f8882a24702d167599452e66afbe720026808..d79a2696b4e9cc10480aa67c7eaec5a356eff596 100644 --- a/tensorflow/docs_src/mobile/tflite/demo_android.md +++ b/tensorflow/contrib/lite/g3doc/demo_android.md @@ -1,7 +1,10 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # Android Demo App An example Android application using TensorFLow Lite is available -[on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app). +[on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo). The demo is a sample camera app that classifies images continuously using either a quantized Mobilenet model or a floating point Inception-v3 model. To run the demo, a device running Android 5.0 ( API 21) or higher is required. @@ -44,20 +47,22 @@ app: Android Studio project. * Install all the Gradle extensions it requests. -To get a model, either: +Now you can build and run the demo app. + +The build process downloads the quantized [Mobilenet TensorFlow Lite model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip), and unzips it into the assets directory: `tensorflow/contrib/lite/java/demo/app/src/main/assets/`. -* Download the quantized [Mobilenet TensorFlow Lite model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip) - and unzip and copy `mobilenet_quant_v1_224.tflite` to the assets directory: - `tensorflow/contrib/lite/java/demo/app/src/main/assets/`. -* Or, download the floating point [Inception-v3 model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip) - and unzip and copy `inceptionv3_non_slim_2015.tflite` to the assets - directory. Change the chosen classifier in - [Camera2BasicFragment.java](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java)
+Some additional details are available on the +[TF Lite Android App page](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md). + +### Using other models + +To use a different model: +* Download the floating point [Inception-v3 model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip). +* Unzip and copy `inceptionv3_non_slim_2015.tflite` to the assets directory. +* Change the chosen classifier in [Camera2BasicFragment.java](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java)
from: `classifier = new ImageClassifierQuantizedMobileNet(getActivity());`
to: `classifier = new ImageClassifierFloatInception(getActivity());`. -Now you can build and run the demo app. - ## Build TensorFlow Lite and the demo app from source diff --git a/tensorflow/docs_src/mobile/tflite/demo_ios.md b/tensorflow/contrib/lite/g3doc/demo_ios.md similarity index 97% rename from tensorflow/docs_src/mobile/tflite/demo_ios.md rename to tensorflow/contrib/lite/g3doc/demo_ios.md index 3be21da89f9e53d324c2ade0cb937f4b5b30fad4..a554898899e67a6bc2bc52733f5301767bc1c06a 100644 --- a/tensorflow/docs_src/mobile/tflite/demo_ios.md +++ b/tensorflow/contrib/lite/g3doc/demo_ios.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # iOS Demo App The TensorFlow Lite demo is a camera app that continuously classifies whatever diff --git a/tensorflow/docs_src/mobile/tflite/devguide.md b/tensorflow/contrib/lite/g3doc/devguide.md similarity index 89% rename from tensorflow/docs_src/mobile/tflite/devguide.md rename to tensorflow/contrib/lite/g3doc/devguide.md index 4133bc172a1924f0ce8bb515d66fc03d716923c8..dc9cc98c0821edff57cb9428a50637a15211cfda 100644 --- a/tensorflow/docs_src/mobile/tflite/devguide.md +++ b/tensorflow/contrib/lite/g3doc/devguide.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # Developer Guide Using a TensorFlow Lite model in your mobile app requires multiple @@ -54,10 +57,11 @@ both floating point and quantized inference. ### Train a custom model A developer may choose to train a custom model using Tensorflow (see the -@{$tutorials} for examples of building and training models). If you have already -written a model, the first step is to export this to a @{tf.GraphDef} file. This -is required because some formats do not store the model structure outside the -code, and we must communicate with other parts of the framework. See +[TensorFlow tutorials](../../tutorials/) for examples of building and training +models). If you have already written a model, the first step is to export this +to a `tf.GraphDef` file. This is required because some formats do not store the +model structure outside the code, and we must communicate with other parts of the +framework. See [Exporting the Inference Graph](https://github.com/tensorflow/models/blob/master/research/slim/README.md) to create .pb file for the custom model. @@ -70,12 +74,12 @@ grow in future Tensorflow Lite releases. ## 2. Convert the model format The model generated (or downloaded) in the previous step is a *standard* -Tensorflow model and you should now have a .pb or .pbtxt @{tf.GraphDef} file. +Tensorflow model and you should now have a .pb or .pbtxt `tf.GraphDef` file. Models generated with transfer learning (re-training) or custom models must be converted—but, we must first freeze the graph to convert the model to the Tensorflow Lite format. This process uses several model formats: -* @{tf.GraphDef} (.pb) —A protobuf that represents the TensorFlow training or +* `tf.GraphDef` (.pb) —A protobuf that represents the TensorFlow training or computation graph. It contains operators, tensors, and variables definitions. * *CheckPoint* (.ckpt) —Serialized variables from a TensorFlow graph. Since this does not contain a graph structure, it cannot be interpreted by itself. @@ -142,11 +146,11 @@ containing the model architecture. The [frozen_graph.pb](https://storage.googlea file used here is available for download. `output_file` is where the TensorFlow Lite model will get generated. The `input_type` and `inference_type` arguments should be set to `FLOAT`, unless converting a -@{$performance/quantization$quantized model}. Setting the `input_array`, -`output_array`, and `input_shape` arguments are not as straightforward. The -easiest way to find these values is to explore the graph using Tensorboard. Reuse -the arguments for specifying the output nodes for inference in the -`freeze_graph` step. +quantized model. +Setting the `input_array`, `output_array`, and `input_shape` arguments are not as +straightforward. The easiest way to find these values is to explore the graph +using Tensorboard. Reuse the arguments for specifying the output nodes for +inference in the `freeze_graph` step. It is also possible to use the Tensorflow Optimizing Converter with protobufs from either Python or from the command line (see the @@ -203,16 +207,16 @@ The open source Android demo app uses the JNI interface and is available [on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app). You can also download a [prebuilt APK](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk). -See the @{$tflite/demo_android} guide for details. +See the Android demo guide for details. -The @{$mobile/android_build} guide has instructions for installing TensorFlow on -Android and setting up `bazel` and Android Studio. +The Android mobile guide has instructions for +installing TensorFlow on Android and setting up `bazel` and Android Studio. ### iOS To integrate a TensorFlow model in an iOS app, see the [TensorFlow Lite for iOS](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md) -guide and @{$tflite/demo_ios} guide. +guide and iOS demo guide. #### Core ML support diff --git a/tensorflow/contrib/lite/g3doc/ios.md b/tensorflow/contrib/lite/g3doc/ios.md index e0358a444d6dffc377bf13ee72ba5477359d6e07..d78d373ccfea074872773693c562253b202a646b 100644 --- a/tensorflow/contrib/lite/g3doc/ios.md +++ b/tensorflow/contrib/lite/g3doc/ios.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # TensorFlow Lite for iOS ## Building diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md index c1c8ef049f693dae038e5e0ca242b9219329cc50..4ceb9a53dc0967ab6320a1bfdb1ddb859482c5dd 100644 --- a/tensorflow/contrib/lite/g3doc/models.md +++ b/tensorflow/contrib/lite/g3doc/models.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # List of Hosted Models ## Image classification (Float Models) @@ -39,22 +42,22 @@ single thread large core. Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance ------------------------ | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------: -Mobilenet_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.9% | 65.8% | 3.7 ms -Mobilenet_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 43.5% | 69.1% | 5.5 ms -Mobilenet_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 45.8% | 71.9% | 7.9 ms -Mobilenet_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.2% | 73.8% | 10.4 ms -Mobilenet_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.9% | 78.9% | 8.8 ms -Mobilenet_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.7% | 81.3% | 13.0 ms -Mobilenet_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 60.4% | 83.2% | 18.3 ms -Mobilenet_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 62.2% | 84.5% | 24.7 ms -Mobilenet_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 59.8% | 82.8% | 16.2 ms -Mobilenet_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 63.9% | 85.5% | 24.3 ms -Mobilenet_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.2% | 87.1% | 33.8 ms -Mobilenet_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 67.9% | 88.1% | 45.4 ms -Mobilenet_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 64.0% | 85.5% | 24.9 ms -Mobilenet_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 67.3% | 87.7% | 37.4 ms -Mobilenet_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.0% | 88.9% | 51.9 ms -Mobilenet_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 69.7% | 89.5% | 70.2 ms +Mobilenet_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.5% | 64.4% | 3.7 ms +Mobilenet_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 43.4% | 68.5% | 5.5 ms +Mobilenet_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 46.0% | 71.2% | 7.9 ms +Mobilenet_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.0% | 72.8% | 10.4 ms +Mobilenet_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.5% | 77.7% | 8.8 ms +Mobilenet_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.7% | 80.4% | 13.0 ms +Mobilenet_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 60.0% | 82.2% | 18.3 ms +Mobilenet_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 60.7% | 83.2% | 24.7 ms +Mobilenet_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 55.8% | 78.8% | 16.2 ms +Mobilenet_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 62.3% | 83.8% | 24.3 ms +Mobilenet_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.1% | 86.4% | 33.8 ms +Mobilenet_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 66.8% | 87.0% | 45.4 ms +Mobilenet_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 63.4% | 84.2% | 24.9 ms +Mobilenet_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 67.2% | 86.7% | 37.4 ms +Mobilenet_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.2% | 88.3% | 51.9 ms +Mobilenet_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 70.1% | 88.9% | 70.2 ms ## Other models diff --git a/tensorflow/contrib/lite/g3doc/ops_versioning.md b/tensorflow/contrib/lite/g3doc/ops_versioning.md index bd2f797e6c5b05f52bec9fc34f1b8011aca70330..b06f4fd3b893e5e5977f92de26109a6dd264531f 100644 --- a/tensorflow/contrib/lite/g3doc/ops_versioning.md +++ b/tensorflow/contrib/lite/g3doc/ops_versioning.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # TensorFlow Lite Ops Versioning This document describes TensorFlow Lite's op versioning schema. Op diff --git a/tensorflow/docs_src/mobile/tflite/index.md b/tensorflow/contrib/lite/g3doc/overview.md similarity index 93% rename from tensorflow/docs_src/mobile/tflite/index.md rename to tensorflow/contrib/lite/g3doc/overview.md index 3d1733024e493042a2cc85aa9f2fec4b75eefa94..be60d7941ade824ee201bfd05400fb3e4e9fae7e 100644 --- a/tensorflow/docs_src/mobile/tflite/index.md +++ b/tensorflow/contrib/lite/g3doc/overview.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # Introduction to TensorFlow Lite TensorFlow Lite is TensorFlow’s lightweight solution for mobile and embedded @@ -70,10 +73,9 @@ There are several factors which are fueling interest in this domain: We believe the next wave of machine learning applications will have significant processing on mobile and embedded devices. -## TensorFlow Lite developer preview highlights +## TensorFlow Lite highlights -TensorFlow Lite is available as a developer preview and includes the -following: +TensorFlow Lite provides: - A set of core operators, both quantized and float, many of which have been tuned for mobile platforms. These can be used to create and run custom @@ -129,9 +131,6 @@ following: - Java and C++ API support -Note: This is a developer release, and it’s likely that there will be changes in -the API in upcoming versions. We do not guarantee backward or forward -compatibility with this release. ## Getting Started @@ -201,9 +200,5 @@ possible performance for a particular model on a particular device. ## Next Steps -For the developer preview, most of our documentation is on GitHub. Please take a -look at the [TensorFlow Lite -repository](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite) -on GitHub for more information and for code samples, demo applications, and -more. - +The TensorFlow Lite [GitHub repository](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite). +contains additional docs, code samples, and demo applications. diff --git a/tensorflow/contrib/lite/g3doc/performance.md b/tensorflow/contrib/lite/g3doc/performance.md new file mode 100644 index 0000000000000000000000000000000000000000..5cd0aab44f10de1b76e1acb302fc1ee2711c8d74 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/performance.md @@ -0,0 +1,176 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + +# Performance + +This document lists TensorFlow Lite performance benchmarks when running well +known models on some Android and iOS devices. + +These performance benchmark numbers were generated with the +[Android TFLite benchmark binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark) +and the [iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios). + +# Android performance benchmarks + +For Android benchmarks, the CPU affinity is set to use big cores on the device to +reduce variance (see [details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#reducing-variance-between-runs-on-android)). + +It assumes that models were download and unzipped to the +`/data/local/tmp/tflite_models` directory. The benchmark binary is built +using [these instructions](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#on-android) +and assumed in the `/data/local/tmp` directory. + +To run the benchmark: + +``` +adb shell taskset ${CPU_MASK} /data/local/tmp/benchmark_model \ + --num_threads=1 \ + --graph=/data/local/tmp/tflite_models/${GRAPH} \ + --warmup_runs=1 \ + --num_runs=50 \ + --use_nnapi=false +``` + +Here, `${GRAPH}` is the name of model and `${CPU_MASK}` is the CPU affinity +chosen according to the following table: + +Device | CPU_MASK | +-------| ---------- +Pixel 2 | f0 | +Pixel xl | 0c | + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Model NameDevice Mean inference time (std dev)
+ Mobilenet_1.0_224(float) + Pixel 2 166.5 ms (2.6 ms)
Pixel xl 122.9 ms (1.8 ms)
+ Mobilenet_1.0_224 (quant) + Pixel 2 69.5 ms (0.9 ms)
Pixel xl 78.9 ms (2.2 ms)
+ NASNet mobile + Pixel 2 273.8 ms (3.5 ms)
Pixel xl 210.8 ms (4.2 ms)
+ SqueezeNet + Pixel 2 234.0 ms (2.1 ms)
Pixel xl 158.0 ms (2.1 ms)
+ Inception_ResNet_V2 + Pixel 2 2846.0 ms (15.0 ms)
Pixel xl 1973.0 ms (15.0 ms)
+ Inception_V4 + Pixel 2 3180.0 ms (11.7 ms)
Pixel xl 2262.0 ms (21.0 ms)
+ +# iOS benchmarks + +To run iOS benchmarks, the [benchmark +app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios) +was modified to include the appropriate model and `benchmark_params.json` was +modified to set `num_threads` to 1. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Model NameDevice Mean inference time (std dev)
+ Mobilenet_1.0_224(float) + iPhone 8 32.2 ms (0.8 ms)
+ Mobilenet_1.0_224 (quant) + iPhone 8 24.4 ms (0.8 ms)
+ NASNet mobile + iPhone 8 60.3 ms (0.6 ms)
+ SqueezeNet + iPhone 8 44.3 (0.7 ms)
+ Inception_ResNet_V2 + iPhone 8562.4 ms (18.2 ms)
+ Inception_V4 + iPhone 8 661.0 ms (29.2 ms)
diff --git a/tensorflow/contrib/lite/g3doc/rpi.md b/tensorflow/contrib/lite/g3doc/rpi.md index ab50789307414255bccd84d4cfcb6ddecc25ba08..9fcf79ba004d85566b64ce35b3693e01c4b0e2cf 100644 --- a/tensorflow/contrib/lite/g3doc/rpi.md +++ b/tensorflow/contrib/lite/g3doc/rpi.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # TensorFlow Lite for Raspberry Pi ## Cross compiling @@ -17,7 +20,7 @@ Clone this Tensorflow repository, Run this script at the root of the repository ```bash ./tensorflow/contrib/lite/download_dependencies.sh ``` -Note than you only need to to this once. +Note that you only need to do this once. You should then be able to compile: ```bash @@ -39,7 +42,7 @@ First, clone this TensorFlow repository. Run this at the root of the repository: ```bash ./tensorflow/contrib/lite/download_dependencies.sh ``` -Note than you only need to to this once. +Note that you only need to do this once. You should then be able to compile: ```bash diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 965273f0f04d33b52903c0551fff3533c31d3bd8..aa65ec99887a61df658dd7add7b5cc3b91d81846 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # TensorFlow Lite & TensorFlow Compatibility Guide TensorFlow Lite supports a number of TensorFlow operations used in common @@ -42,6 +45,7 @@ counterparts: *as long as the input tensor is 4D (1 batch + 2 spatial + 1 other) and the crops attribute is not used* * [tf.exp](https://www.tensorflow.org/api_docs/python/tf/exp) +* [tf.fake_quant*](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_args) * [tf.matmul](https://www.tensorflow.org/api_docs/python/tf/matmul) - *as long as the second argument is constant and transposition is not used* * [tf.nn.avg_pool](https://www.tensorflow.org/api_docs/python/tf/nn/avg_pool) @@ -58,6 +62,7 @@ counterparts: * [tf.nn.softmax](https://www.tensorflow.org/api_docs/python/tf/nn/softmax) - *as long as tensors are 2D and axis is the last dimension* * [tf.nn.top_k](https://www.tensorflow.org/api_docs/python/tf/nn/top_k) +* [tf.one_hot](https://www.tensorflow.org/api_docs/python/tf/one_hot) * [tf.pad](https://www.tensorflow.org/api_docs/python/tf/pad) - *as long as mode and constant_values are not used* * [tf.reduce_mean](https://www.tensorflow.org/api_docs/python/tf/reduce_mean) - @@ -584,6 +589,31 @@ Options { } ``` +**RSQRT** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: result of computing element-wise reciprocal square root of the input tensor +} +``` + +**SHAPE** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: a 1D tensor representing the shape of the input tensor +} +Options { + out_type: the output type of the op (int32 or int64). Defaults to int32. +} +``` + **SLICE** ``` @@ -670,6 +700,17 @@ Options { } ``` +**SQRT** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: result of computing element-wise square root of the input tensor +} +``` + **SQUEEZE** ``` @@ -742,6 +783,66 @@ Outputs { } ``` +**POW** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: elementwise pow of the input tensors +} +``` + +**ARG_MAX** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: A tensor of indices of maximum values. +} +``` + +**ARG_MIN** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: A tensor of indices of minium values. +} +``` + +**PACK** + +``` +Inputs { + 0: a list of tensors. + 1: an integer. +} +Outputs { + 0: A tensor of stacked tensors. +} +``` + +**LOGICAL_OR** + +``` +Inputs { + 0: a list of tensors. + 1: a list of tensors. +} +Outputs { + 0: A tensor of logical_or output tensors. +} +``` + And these are TensorFlow Lite operations that are present but not ready for custom models yet: diff --git a/tensorflow/docs_src/mobile/android_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md similarity index 97% rename from tensorflow/docs_src/mobile/android_build.md rename to tensorflow/contrib/lite/g3doc/tfmobile/android_build.md index f4b07db4591dddcfbf3633f471072f4a0eea9843..76e16fc9db27782fe0f9454ba463722f4bf6eb4b 100644 --- a/tensorflow/docs_src/mobile/android_build.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # Building TensorFlow on Android To get you started working with TensorFlow on Android, we'll walk through two @@ -91,7 +94,8 @@ using [ADB](https://developer.android.com/studio/command-line/adb.html). This requires some knowledge of build systems and Android developer tools, but we'll guide you through the basics here. -- First, follow our instructions for @{$install/install_sources$installing from sources}. +- First, follow our instructions for + installing from sources. This will also guide you through installing Bazel and cloning the TensorFlow code. diff --git a/tensorflow/docs_src/mobile/mobile_intro.md b/tensorflow/contrib/lite/g3doc/tfmobile/index.md similarity index 85% rename from tensorflow/docs_src/mobile/mobile_intro.md rename to tensorflow/contrib/lite/g3doc/tfmobile/index.md index 241f01d460ae35e818a61be4c4914b3bd8dae00a..bd047bfceceddfd0b5a9fd0c83cb47a339299abf 100644 --- a/tensorflow/docs_src/mobile/mobile_intro.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/index.md @@ -1,4 +1,45 @@ -# Introduction to TensorFlow Mobile +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + +# Overview + +TensorFlow was designed to be a good deep learning solution for mobile +platforms. Currently we have two solutions for deploying machine learning +applications on mobile and embedded devices: TensorFlow for Mobile and +TensorFlow Lite. + +## TensorFlow Lite versus TensorFlow Mobile + +Here are a few of the differences between the two: + +- TensorFlow Lite is an evolution of TensorFlow Mobile. In most cases, apps + developed with TensorFlow Lite will have a smaller binary size, fewer + dependencies, and better performance. + +- TensorFlow Lite is in developer preview, so not all use cases are covered yet. + We expect you to use TensorFlow Mobile to cover production cases. + +- TensorFlow Lite supports only a limited set of operators, so not all models + will work on it by default. TensorFlow for Mobile has a fuller set of + supported functionality. + +TensorFlow Lite provides better performance and a small binary size on mobile +platforms as well as the ability to leverage hardware acceleration if available +on their platforms. In addition, it has many fewer dependencies so it can be +built and hosted on simpler, more constrained device scenarios. TensorFlow Lite +also allows targeting accelerators through the [Neural Networks +API](https://developer.android.com/ndk/guides/neuralnetworks/index.html). + +TensorFlow Lite currently has coverage for a limited set of operators. While +TensorFlow for Mobile supports only a constrained set of ops by default, in +principle if you use an arbitrary operator in TensorFlow, it can be customized +to build that kernel. Thus use cases which are not currently supported by +TensorFlow Lite should continue to use TensorFlow for Mobile. As TensorFlow Lite +evolves, it will gain additional operators, and the decision will be easier to +make. + + +## Introduction to TensorFlow Mobile TensorFlow was designed from the ground up to be a good deep learning solution for mobile platforms like Android and iOS. This mobile guide should help you @@ -38,7 +79,8 @@ speech-driven interface, and many of these require on-device processing. Most of the time a user isn’t giving commands, and so streaming audio continuously to a remote server would be a waste of bandwidth, since it would mostly be silence or background noises. To solve this problem it’s common to have a small neural -network running on-device @{$tutorials/audio_recognition$listening out for a particular keyword}. +network running on-device +[listening out for a particular keyword](../tutorials/sequences/audio_recognition). Once that keyword has been spotted, the rest of the conversation can be transmitted over to the server for further processing if more computing power is needed. @@ -166,7 +208,7 @@ interesting products possible. TensorFlow runs on Ubuntu Linux, Windows 10, and OS X. For a list of all supported operating systems and instructions to install TensorFlow, see -@{$install$Installing Tensorflow}. +Installing Tensorflow. Note that some of the sample code we provide for mobile TensorFlow requires you to compile TensorFlow from source, so you’ll need more than just `pip install` @@ -240,8 +282,3 @@ results you’ll see. It’s common for an algorithm to get great training accur numbers but then fail to be useful within a real application because there’s a mismatch between the dataset and real usage. Prototype end-to-end usage as soon as possible to create a consistent user experience. - -## Next Steps - -We suggest you get started by building one of our demos for -@{$mobile/android_build$Android} or @{$mobile/ios_build$iOS}. diff --git a/tensorflow/docs_src/mobile/ios_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md similarity index 98% rename from tensorflow/docs_src/mobile/ios_build.md rename to tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md index 4c84a1214a26eeb90c1b6a186a369212377b06cd..6223707892ce7b288ecabf932b33cd39860446a6 100644 --- a/tensorflow/docs_src/mobile/ios_build.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # Building TensorFlow on iOS ## Using CocoaPods diff --git a/tensorflow/docs_src/mobile/linking_libs.md b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md similarity index 83% rename from tensorflow/docs_src/mobile/linking_libs.md rename to tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md index efef5dd0daa0b267d8384d32d62d9ce0226dc102..4c2071ed053125cfa643ed785fe302198f734ead 100644 --- a/tensorflow/docs_src/mobile/linking_libs.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # Integrating TensorFlow libraries Once you have made some progress on a model that addresses the problem you’re @@ -14,11 +17,11 @@ TensorFlow mobile demo apps. After you've managed to build the examples, you'll probably want to call TensorFlow from one of your existing applications. The very easiest way to do -this is to use the Pod installation steps described -@{$mobile/ios_build#using_cocoapods$here}, but if you want to build TensorFlow -from source (for example to customize which operators are included) you'll need -to break out TensorFlow as a framework, include the right header files, and link -against the built libraries and dependencies. +this is to use the Pod installation steps described in +Building TensorFlow on iOS, but if you want to build +TensorFlow from source (for example to customize which operators are included) +you'll need to break out TensorFlow as a framework, include the right header +files, and link against the built libraries and dependencies. ### Android @@ -82,10 +85,12 @@ recompile of the core. To achieve this capability, TensorFlow uses a registration pattern in a lot of places. In the code, it looks like this: - class MulKernel : OpKernel { - Status Compute(OpKernelContext* context) { … } - }; - REGISTER_KERNEL(MulKernel, “Mul”); +``` +class MulKernel : OpKernel { + Status Compute(OpKernelContext* context) { … } +}; +REGISTER_KERNEL(MulKernel, “Mul”); +``` This would be in a standalone `.cc` file linked into your application, either as part of the main set of kernels or as a separate custom library. The magic @@ -101,15 +106,17 @@ doesn’t offer a good mechanism for doing this sort of registration, so we have to resort to some tricky code. Under the hood, the macro is implemented so that it produces something like this: - class RegisterMul { - public: - RegisterMul() { - global_kernel_registry()->Register(“Mul”, [](){ - return new MulKernel() - }); - } - }; - RegisterMul g_register_mul; +``` +class RegisterMul { + public: + RegisterMul() { + global_kernel_registry()->Register(“Mul”, [](){ + return new MulKernel() + }); + } +}; +RegisterMul g_register_mul; +``` This sets up a class `RegisterMul` with a constructor that tells the global kernel registry what function to call when somebody asks it how to create a @@ -176,8 +183,10 @@ have an experimental script at [rename_protobuf.sh](https://github.com/tensorflo You need to run this as part of the makefile build, after you’ve downloaded all the dependencies: - tensorflow/contrib/makefile/download_dependencies.sh - tensorflow/contrib/makefile/rename_protobuf.sh +``` +tensorflow/contrib/makefile/download_dependencies.sh +tensorflow/contrib/makefile/rename_protobuf.sh +``` ## Calling the TensorFlow API @@ -193,18 +202,20 @@ use case, while on iOS and Raspberry Pi you call directly into the C++ API. Here’s what a typical Inference Library sequence looks like on Android: - // Load the model from disk. - TensorFlowInferenceInterface inferenceInterface = - new TensorFlowInferenceInterface(assetManager, modelFilename); +``` +// Load the model from disk. +TensorFlowInferenceInterface inferenceInterface = +new TensorFlowInferenceInterface(assetManager, modelFilename); - // Copy the input data into TensorFlow. - inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3); +// Copy the input data into TensorFlow. +inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3); - // Run the inference call. - inferenceInterface.run(outputNames, logStats); +// Run the inference call. +inferenceInterface.run(outputNames, logStats); - // Copy the output Tensor back into the output array. - inferenceInterface.fetch(outputName, outputs); +// Copy the output Tensor back into the output array. +inferenceInterface.fetch(outputName, outputs); +``` You can find the source of this code in the [Android examples](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java#L107). @@ -212,27 +223,29 @@ You can find the source of this code in the [Android examples](https://github.co Here’s the equivalent code for iOS and Raspberry Pi: - // Load the model. - PortableReadFileToProto(file_path, &tensorflow_graph); - - // Create a session from the model. - tensorflow::Status s = session->Create(tensorflow_graph); - if (!s.ok()) { - LOG(FATAL) << "Could not create TensorFlow Graph: " << s; - } - - // Run the model. - std::string input_layer = "input"; - std::string output_layer = "output"; - std::vector outputs; - tensorflow::Status run_status = session->Run({{input_layer, image_tensor}}, +``` +// Load the model. +PortableReadFileToProto(file_path, &tensorflow_graph); + +// Create a session from the model. +tensorflow::Status s = session->Create(tensorflow_graph); +if (!s.ok()) { + LOG(FATAL) << "Could not create TensorFlow Graph: " << s; +} + +// Run the model. +std::string input_layer = "input"; +std::string output_layer = "output"; +std::vector outputs; +tensorflow::Status run_status = session->Run({\{input_layer, image_tensor}}, {output_layer}, {}, &outputs); - if (!run_status.ok()) { - LOG(FATAL) << "Running model failed: " << run_status; - } +if (!run_status.ok()) { + LOG(FATAL) << "Running model failed: " << run_status; +} - // Access the output data. - tensorflow::Tensor* output = &outputs[0]; +// Access the output data. +tensorflow::Tensor* output = &outputs[0]; +``` This is all based on the [iOS sample code](https://www.tensorflow.org/code/tensorflow/examples/ios/simple/RunModelViewController.mm), diff --git a/tensorflow/docs_src/mobile/optimizing.md b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md similarity index 98% rename from tensorflow/docs_src/mobile/optimizing.md rename to tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md index 778e4d3a6233c3bec70b830bc998013745a1f0ba..a0192c3541483437b817e22eb92193bd7bcb4c28 100644 --- a/tensorflow/docs_src/mobile/optimizing.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # Optimizing for mobile There are some special issues that you have to deal with when you’re trying to @@ -77,7 +80,7 @@ out of a mobile device's memory faster. To understand how large your network will be on disk, start by looking at the size on disk of your `GraphDef` file after you’ve run `freeze_graph` and -`strip_unused_nodes` on it (see @{$mobile/prepare_models$Preparing models} for +`strip_unused_nodes` on it (see Preparing models for more details on these tools), since then it should only contain inference-related nodes. To double-check that your results are as expected, run the `summarize_graph` tool to see how many parameters are in constants: @@ -103,7 +106,8 @@ you multiply the number of const parameters by four, you should get something that’s close to the size of the file on disk. You can often get away with only eight-bits per parameter with very little loss of accuracy in the final result, so if your file size is too large you can try using -@{$performance/quantization$quantize_weights} to transform the parameters down. +quantize_weights +to transform the parameters down. bazel build tensorflow/tools/graph_transforms:transform_graph && \ bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ @@ -292,7 +296,8 @@ run it on a 64-bit ARM device: You can interpret the results in exactly the same way as the desktop version above. If you have any trouble figuring out what the right input and output -names and types are, take a look at the @{$mobile/prepare_models$Preparing models} +names and types are, take a look at the +Preparing models page for details about detecting these for your model, and look at the `summarize_graph` tool which may give you helpful information. diff --git a/tensorflow/docs_src/mobile/prepare_models.md b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md similarity index 98% rename from tensorflow/docs_src/mobile/prepare_models.md rename to tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md index 2b84dbb97388b16c6a4ae1d3472e0b1a993285f0..6b4e4a92bd9262139be3cf650b7d16714ee3a277 100644 --- a/tensorflow/docs_src/mobile/prepare_models.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # Preparing models for mobile deployment The requirements for storing model information during training are very @@ -255,8 +258,8 @@ The criteria for including ops and types fall into several categories: These ops are trimmed by default to optimize for inference on mobile, but it is possible to alter some build files to change the default. After alternating the build files, you will need to recompile TensorFlow. See below for more details -on how to do this, and also see @{$mobile/optimizing#binary_size$Optimizing} for -more on reducing your binary size. +on how to do this, and also see optimizing binary size +for more on reducing your binary size. ### Locate the implementation diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/contrib/lite/graph_info.h index 313af5fb7574b42bcdd53b4baad06e4ccfb34053..77268d7aebe9ebfb33b9f35b319d34e6de8324ee 100644 --- a/tensorflow/contrib/lite/graph_info.h +++ b/tensorflow/contrib/lite/graph_info.h @@ -46,6 +46,9 @@ class GraphInfo { // Returns the indices of the output tensors. virtual const std::vector& outputs() const = 0; + + // Returns the indices of the variable tensors. + virtual const std::vector& variables() const = 0; }; // Represents a subgraph of a TensorFlow Lite graph. diff --git a/tensorflow/contrib/lite/graph_info_test.cc b/tensorflow/contrib/lite/graph_info_test.cc index ea38b43993fef71c6820c7a978351d92d5420287..89a8f36b416b5dec54c1e374cdcdae3ab9ab0cde 100644 --- a/tensorflow/contrib/lite/graph_info_test.cc +++ b/tensorflow/contrib/lite/graph_info_test.cc @@ -45,6 +45,7 @@ class SimpleTestGraph : public GraphInfo { TfLiteTensor* tensor(size_t index) override { return &tensors_[index]; } const std::vector& inputs() const override { return inputs_; } const std::vector& outputs() const override { return outputs_; } + const std::vector& variables() const override { return variables_; } void AddNode(const std::vector& inputs, const std::vector& outputs) { @@ -67,6 +68,7 @@ class SimpleTestGraph : public GraphInfo { std::vector tensors_; std::vector inputs_; std::vector outputs_; + std::vector variables_; }; // Partition a graph to generate a list of subgraphs. This wraps the API call diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index ebb0aedc2001a86b7fcff67ef8703b5e4a845818..362e5887257f1a06263aadbdaef011b3893a577f 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -22,10 +22,9 @@ limitations under the License. #include "tensorflow/contrib/lite/arena_planner.h" #include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/context_util.h" #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/graph_info.h" -#include "tensorflow/contrib/lite/kernels/eigen_support.h" -#include "tensorflow/contrib/lite/kernels/gemm_support.h" #include "tensorflow/contrib/lite/memory_planner.h" #include "tensorflow/contrib/lite/nnapi_delegate.h" #include "tensorflow/contrib/lite/profiling/profiler.h" @@ -33,9 +32,21 @@ limitations under the License. #include "tensorflow/contrib/lite/util.h" namespace tflite { - namespace { +TfLiteStatus ReportOpError(TfLiteContext* context, const TfLiteNode& node, + const TfLiteRegistration& registration, + int node_index, const char* message) { + context->ReportError( + context, "Node number %d (%s) %s.\n", node_index, + registration.custom_name + ? registration.custom_name + : EnumNameBuiltinOperator( + static_cast(registration.builtin_code)), + message); + return kTfLiteError; +} + // Stub method which returns kTfLiteError when the function is forbidden. // We're registrating this function to several different function to save // compiled binary size. Please note the restrictions: @@ -53,6 +64,19 @@ void SetForbiddenContextFunction(FunctionType* func) { *func = reinterpret_cast(ForbiddenContextFunction); } +// Returns true if at least one tensor in the given list is kTfLiteDynamic. +template +bool HasDynamicTensorImpl(const TfLiteContext& context, + const TensorIntArray& int_array) { + for (int i : int_array) { + const TfLiteTensor& tensor = context.tensors[i]; + if (tensor.allocation_type == kTfLiteDynamic) { + return true; + } + } + return false; +} + } // namespace // A trivial implementation of GraphInfo around the Interpreter. @@ -82,6 +106,9 @@ class InterpreterInfo : public GraphInfo { const std::vector& outputs() const override { return interpreter_->outputs(); } + const std::vector& variables() const override { + return interpreter_->variables(); + } public: Interpreter* interpreter_; @@ -96,19 +123,22 @@ Interpreter::Interpreter(ErrorReporter* error_reporter) context_.AddTensors = AddTensors; context_.tensors = nullptr; context_.tensors_size = 0; - context_.eigen_context = nullptr; - context_.gemm_context = nullptr; context_.recommended_num_threads = -1; + context_.GetExternalContext = GetExternalContext; + context_.SetExternalContext = SetExternalContext; // Invalid to call these these except from TfLiteDelegate - SetForbiddenContextFunction(&context_.GetNodeAndRegistration); - SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels); - SetForbiddenContextFunction(&context_.GetExecutionPlan); + SwitchToKernelContext(); // Reserve some space for the tensors to avoid excessive resizing. tensors_.reserve(kTensorsReservedCapacity); nodes_and_registration_.reserve(kTensorsReservedCapacity); next_execution_plan_index_to_prepare_ = 0; + + for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) { + external_contexts_[i] = nullptr; + } + UseNNAPI(false); } @@ -127,7 +157,7 @@ Interpreter::~Interpreter() { TfLiteTensor* tensor = &context_.tensors[i]; if (tensor->buffer_handle != kTfLiteNullBufferHandle && tensor->delegate->FreeBufferHandle != nullptr) { - tensor->delegate->FreeBufferHandle(tensor->delegate, + tensor->delegate->FreeBufferHandle(&context_, tensor->delegate, &tensor->buffer_handle); } TfLiteTensorFree(tensor); @@ -243,8 +273,9 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( int node_index; TfLiteDelegateParams* params = CreateDelegateParams(delegate, subgraph); - AddNodeWithParameters(subgraph.input_tensors, subgraph.output_tensors, - nullptr, 0, params, ®istration, &node_index); + TF_LITE_ENSURE_STATUS(AddNodeWithParameters( + subgraph.input_tensors, subgraph.output_tensors, nullptr, 0, params, + ®istration, &node_index)); // Initialize the output tensors's delegate-related fields. for (int tensor_index : subgraph.output_tensors) { @@ -266,6 +297,33 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( return kTfLiteOk; } +TfLiteExternalContext* Interpreter::GetExternalContext( + TfLiteExternalContextType type) { + if (type >= 0 && type < kTfLiteMaxExternalContexts) { + return external_contexts_[type]; + } + return nullptr; +} + +TfLiteExternalContext* Interpreter::GetExternalContext( + struct TfLiteContext* context, TfLiteExternalContextType type) { + return static_cast(context->impl_)->GetExternalContext(type); +} + +void Interpreter::SetExternalContext(TfLiteExternalContextType type, + TfLiteExternalContext* ctx) { + if (type >= 0 && type < kTfLiteMaxExternalContexts) { + external_contexts_[type] = ctx; + } +} + +void Interpreter::SetExternalContext(struct TfLiteContext* context, + TfLiteExternalContextType type, + TfLiteExternalContext* ctx) { + return static_cast(context->impl_) + ->SetExternalContext(type, ctx); +} + // Gets an TfLiteIntArray* representing the execution plan. The interpreter owns // this memory and it is only guaranteed to exist during the invocation of the // delegate prepare. @@ -302,6 +360,13 @@ TfLiteStatus Interpreter::SetOutputs(std::vector outputs) { return kTfLiteOk; } +TfLiteStatus Interpreter::SetVariables(std::vector variables) { + TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("variables", variables.data(), + variables.size())); + variables_ = std::move(variables); + return kTfLiteOk; +} + TfLiteStatus Interpreter::CheckTensorIndices(const char* label, const int* indices, int length) { // Making sure kOptionalTensor is not re-defined to something other than -1. @@ -334,6 +399,9 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, case kTfLiteFloat32: *bytes = sizeof(float) * count; break; + case kTfLiteInt16: + *bytes = sizeof(int16_t) * count; + break; case kTfLiteInt32: *bytes = sizeof(int32_t) * count; break; @@ -346,32 +414,65 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, case kTfLiteBool: *bytes = sizeof(bool) * count; break; + case kTfLiteComplex64: + *bytes = sizeof(std::complex) * count; + break; default: - ReportError( - &context_, - "Only float32, int32, int64, uint8, bool supported currently."); + ReportError(&context_, + "Only float32, int16, int32, int64, uint8, bool, complex64 " + "supported currently."); return kTfLiteError; } return kTfLiteOk; } TfLiteStatus Interpreter::AllocateTensors() { - next_execution_plan_index_to_prepare_ = 0; - if (memory_planner_) { - TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations()); - } - if (!consistent_) { ReportError(&context_, "AllocateTensors() called on inconsistent model."); return kTfLiteError; } + // Explicit (re)allocation is necessary if nodes have been changed or tensors + // have been resized. For inputs marked as dynamic, we can't short-circuit the + // allocation as the client may have done the resize manually. + if (state_ != kStateUninvokable && !HasDynamicTensorImpl(context_, inputs_)) { + return kTfLiteOk; + } + + next_execution_plan_index_to_prepare_ = 0; + if (memory_planner_) { + TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations()); + } + TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); - if (state_ == kStateUninvokable) { - state_ = kStateInvokable; + + state_ = kStateInvokable; + + // Reset the variable tensors to zero after (re)allocating the tensors. + // Developers shouldn't rely on the side effect of this function to reset + // variable tesnsors. They should call `ResetVariableTensorsToZero` directly + // instead. + ResetVariableTensorsToZero(); + + return kTfLiteOk; +} + +// TODO(ycling): Consider to provide other functions to initialize variable +// tensors to non-zero values. +TfLiteStatus Interpreter::ResetVariableTensorsToZero() { + for (auto& tensor : tensors_) { + if (!tensor.is_variable) { + continue; + } + + // Variable tensors have to be `kTfLiteArenaRwPersistent`, and must be + // allocated after the initial `PrepareOpsAndTensors()` is called. + TF_LITE_ENSURE_EQ(&context_, tensor.allocation_type, + kTfLiteArenaRwPersistent); + TF_LITE_ENSURE(&context_, tensor.data.raw != nullptr); + + memset(tensor.data.raw, 0, tensor.bytes); } - TF_LITE_ENSURE(&context_, state_ == kStateInvokable || - state_ == kStateInvokableAndImmutable); return kTfLiteOk; } @@ -445,26 +546,26 @@ TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index, "ResizeInputTensor is disallowed when graph is immutable."); return kTfLiteError; } - state_ = kStateUninvokable; // TODO(aselle): All bounds checks can be implemented as one-sided bounds // checks by casting to unsigned for efficiency. Profile before doing this. TF_LITE_ENSURE(&context_, tensor_index < context_.tensors_size && tensor_index >= 0); - TfLiteIntArray* dims_lite = ConvertVectorToTfLiteIntArray(dims); - return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite); + TfLiteTensor* tensor = &context_.tensors[tensor_index]; + + // Short-circuit the state change if the dimensions don't change, avoiding + // unnecessary (re)allocations. + if (EqualArrayAndTfLiteIntArray(tensor->dims, dims.size(), dims.data())) { + return kTfLiteOk; + } + + state_ = kStateUninvokable; + return ResizeTensorImpl(tensor, ConvertVectorToTfLiteIntArray(dims)); } -// Returns true if at least one tensor in the given list is kTfLiteDynamic. bool HasDynamicTensor(const TfLiteContext& context, - const TfLiteIntArray* tensors) { - for (int i = 0; i < tensors->size; ++i) { - const TfLiteTensor& tensor = context.tensors[tensors->data[i]]; - if (tensor.allocation_type == kTfLiteDynamic) { - return true; - } - } - return false; + const TfLiteIntArray* int_array) { + return HasDynamicTensorImpl(context, TfLiteIntArrayView{int_array}); } TfLiteStatus Interpreter::PrepareOpsStartingAt( @@ -477,7 +578,8 @@ TfLiteStatus Interpreter::PrepareOpsStartingAt( nodes_and_registration_[node_index].second; EnsureTensorsVectorCapacity(); if (OpPrepare(registration, &node) == kTfLiteError) { - return kTfLiteError; + return ReportOpError(&context_, node, registration, node_index, + "failed to prepare"); } *last_execution_plan_index_prepared = execution_plan_index; @@ -495,7 +597,8 @@ TfLiteStatus Interpreter::PrepareOpsStartingAt( TfLiteStatus Interpreter::PrepareOpsAndTensors() { if (!memory_planner_) { memory_planner_.reset(new ArenaPlanner( - &context_, std::unique_ptr(new InterpreterInfo(this)))); + &context_, std::unique_ptr(new InterpreterInfo(this)), + /*preserve_inputs=*/true, /*preserve_intermediates*/ false)); memory_planner_->PlanAllocations(); } @@ -572,8 +675,17 @@ TfLiteStatus Interpreter::Invoke() { } EnsureTensorsVectorCapacity(); + tensor_resized_since_op_invoke_ = false; if (OpInvoke(registration, &node) == kTfLiteError) { - status = kTfLiteError; + status = ReportOpError(&context_, node, registration, node_index, + "failed to invoke"); + } + + // Force execution prep for downstream ops if the latest op triggered the + // resize of a dynamic tensor. + if (tensor_resized_since_op_invoke_ && + HasDynamicTensor(context_, node.outputs)) { + next_execution_plan_index_to_prepare_ = execution_plan_index + 1; } } @@ -687,7 +799,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly( state_ = kStateUninvokable; TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims), quantization, const_cast(buffer), bytes, - kTfLiteMmapRo, allocation, &tensor); + kTfLiteMmapRo, allocation, false, &tensor); } return kTfLiteOk; } @@ -698,7 +810,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly( // to Interpreter. TfLiteStatus Interpreter::SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const size_t rank, - const int* dims, TfLiteQuantizationParams quantization) { + const int* dims, TfLiteQuantizationParams quantization, bool is_variable) { if (state_ == kStateInvokableAndImmutable) { ReportError( &context_, @@ -716,11 +828,23 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite( TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims, rank, &required_bytes)); } + + TfLiteAllocationType allocation_type = kTfLiteArenaRw; + if (type == kTfLiteString) { + if (is_variable) { + // We don't have a real use case for string variable tensor. + ReportError(&context_, "String variable tensor isn't supported."); + return kTfLiteError; + } + allocation_type = kTfLiteDynamic; + } else if (is_variable) { + allocation_type = kTfLiteArenaRwPersistent; + } + TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims), quantization, - /*buffer=*/nullptr, required_bytes, - type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw, - nullptr, &context_.tensors[tensor_index]); + /*buffer=*/nullptr, required_bytes, allocation_type, + nullptr, is_variable, &context_.tensors[tensor_index]); return kTfLiteOk; } @@ -736,7 +860,10 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size) { // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too. if (tensor->allocation_type == kTfLiteArenaRw || - tensor->allocation_type == kTfLiteDynamic) { + tensor->allocation_type == kTfLiteDynamic || + tensor->allocation_type == kTfLiteArenaRwPersistent) { + tensor_resized_since_op_invoke_ |= + TfLiteIntArrayEqual(tensor->dims, new_size) == 0; if (tensor->type != kTfLiteString) { size_t bytesRequired; TfLiteStatus status = BytesRequired(tensor->type, new_size->data, @@ -770,7 +897,7 @@ void Interpreter::UseNNAPI(bool enable) { // TODO(aselle): This is a workaround for finding if NNAPI exists. // We also need to make sure getLibraryHandle() is renamed to be NNAPI // prefixed. - if (!NNAPIExists()) enable = false; + if (!NNAPIDelegate::IsSupported()) enable = false; if (!enable) { nnapi_delegate_.reset(); } else if (!nnapi_delegate_) { @@ -781,10 +908,25 @@ void Interpreter::UseNNAPI(bool enable) { void Interpreter::SetNumThreads(int num_threads) { context_.recommended_num_threads = num_threads; - // TODO(ahentz): find a way to avoid this. It causes gemmlowp and eigen to - // be required in order to compile the framework. - gemm_support::SetNumThreads(&context_, num_threads); - eigen_support::SetNumThreads(&context_, num_threads); + for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) { + auto* c = external_contexts_[i]; + if (c && c->Refresh) { + c->Refresh(&context_); + } + } +} + +void Interpreter::SwitchToDelegateContext() { + context_.GetNodeAndRegistration = GetNodeAndRegistration; + context_.ReplaceSubgraphsWithDelegateKernels = + ReplaceSubgraphsWithDelegateKernels; + context_.GetExecutionPlan = GetExecutionPlan; +} + +void Interpreter::SwitchToKernelContext() { + SetForbiddenContextFunction(&context_.GetNodeAndRegistration); + SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels); + SetForbiddenContextFunction(&context_.GetExecutionPlan); } TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate, @@ -813,24 +955,20 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate, // TODO(aselle): Consider if it is worth storing pointers to delegates. // Setup additional context interface. - context_.GetNodeAndRegistration = GetNodeAndRegistration; - context_.ReplaceSubgraphsWithDelegateKernels = - ReplaceSubgraphsWithDelegateKernels; - context_.GetExecutionPlan = GetExecutionPlan; + SwitchToDelegateContext(); TfLiteStatus status = delegate->Prepare(&context_, delegate); // Remove additional context info. - SetForbiddenContextFunction(&context_.GetNodeAndRegistration); - SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels); - SetForbiddenContextFunction(&context_.GetExecutionPlan); + SwitchToKernelContext(); TF_LITE_ENSURE_OK(&context_, status); if (!allow_dynamic_tensors) { + // Reset the state to force tensor/op reallocation. + state_ = kStateUninvokable; TF_LITE_ENSURE_OK(&context_, AllocateTensors()); - TF_LITE_ENSURE(&context_, state_ == kStateInvokable || - state_ == kStateInvokableAndImmutable); + TF_LITE_ENSURE_EQ(&context_, state_, kStateInvokable); // After using a delegate which doesn't support dynamic tensors, make the // entire graph immutable. state_ = kStateInvokableAndImmutable; @@ -850,7 +988,7 @@ TfLiteStatus Interpreter::SetBufferHandle(int tensor_index, tensor->delegate = delegate; if (tensor->buffer_handle != kTfLiteNullBufferHandle) { TF_LITE_ENSURE(&context_, tensor->delegate->FreeBufferHandle != nullptr); - tensor->delegate->FreeBufferHandle(tensor->delegate, + tensor->delegate->FreeBufferHandle(&context_, tensor->delegate, &tensor->buffer_handle); } tensor->buffer_handle = buffer_handle; diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 7315d8360680ca0d3c405dc80b593762275815ee..a27df4b964c1d2dea0274e51c3a86435172172c5 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -17,6 +17,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ #define TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ +#include #include #include #include @@ -39,6 +40,10 @@ constexpr TfLiteType typeToTfLiteType() { return kTfLiteInt32; } template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteInt16; +} +template <> constexpr TfLiteType typeToTfLiteType() { return kTfLiteInt64; } @@ -54,6 +59,14 @@ template <> constexpr TfLiteType typeToTfLiteType() { return kTfLiteBool; } +template <> +constexpr TfLiteType typeToTfLiteType>() { + return kTfLiteComplex64; +} +template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteString; +} // Forward declare since NNAPIDelegate uses Interpreter. class NNAPIDelegate; @@ -98,7 +111,7 @@ class Interpreter { // processing this model will be forwarded to the error_reporter object. // // Note, if error_reporter is nullptr, then a default StderrReporter is - // used. + // used. Ownership of 'error_reporter' remains with the caller. explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter()); ~Interpreter(); @@ -118,6 +131,11 @@ class Interpreter { // interpreter. TfLiteStatus SetOutputs(std::vector outputs); + // Provide a list of tensor indexes that are variable tensors. + // Each index is bound check and this modifies the consistent_ flag of the + // interpreter. + TfLiteStatus SetVariables(std::vector variables); + // Adds a node with the given parameters and returns the index of the new // node in `node_index` (optionally). Interpreter will take ownership of // `builtin_data` and destroy it with `free`. Ownership of 'init_data' @@ -147,7 +165,7 @@ class Interpreter { return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(), dims.data(), quantization, buffer, bytes, allocation); - }; + } TfLiteStatus SetTensorParametersReadOnly( int tensor_index, TfLiteType type, const char* name, const size_t rank, @@ -160,13 +178,15 @@ class Interpreter { // to Interpreter. inline TfLiteStatus SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, - const std::vector& dims, TfLiteQuantizationParams quantization) { + const std::vector& dims, TfLiteQuantizationParams quantization, + bool is_variable = false) { return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(), - dims.data(), quantization); + dims.data(), quantization, is_variable); } TfLiteStatus SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const size_t rank, - const int* dims, TfLiteQuantizationParams quantization); + const int* dims, TfLiteQuantizationParams quantization, + bool is_variable = false); // Functions to access tensor data @@ -182,6 +202,9 @@ class Interpreter { // Read only access to list of outputs. const std::vector& outputs() const { return outputs_; } + // Read only access to list of variable tensors. + const std::vector& variables() const { return variables_; } + // Return the name of a given output. The given index must be between 0 and // outputs().size(). const char* GetOutputName(int index) const { @@ -327,7 +350,7 @@ class Interpreter { // This can be null if the delegate doesn't use its own buffer. TF_LITE_ENSURE(&context_, tensor->delegate->CopyFromBufferHandle != nullptr); - tensor->delegate->CopyFromBufferHandle(tensor->delegate, + tensor->delegate->CopyFromBufferHandle(&context_, tensor->delegate, tensor->buffer_handle, tensor->data.raw, tensor->bytes); tensor->data_is_stale = false; @@ -379,7 +402,27 @@ class Interpreter { allow_buffer_handle_output_ = allow_buffer_handle_output; } + // Reset all variable tensors to zero. + // WARNING: This is an experimental API and subject to change. + TfLiteStatus ResetVariableTensorsToZero(); + + // Retrieve an operator's description of its work, for profiling purposes. + const char* OpProfilingString(const TfLiteRegistration& op_reg, + const TfLiteNode* node) const { + if (op_reg.profiling_string == nullptr) return nullptr; + return op_reg.profiling_string(&context_, node); + } + private: + friend class InterpreterTest; + + // Prevent 'context_' from accessing functions that are only available to + // delegated kernels. + void SwitchToKernelContext(); + + // Add delegate-only functions to 'context_'. + void SwitchToDelegateContext(); + // Give 'op_reg' a chance to initialize itself using the contents of // 'buffer'. void* OpInit(const TfLiteRegistration& op_reg, const char* buffer, @@ -466,6 +509,7 @@ class Interpreter { // Update the execution graph to replace some of the nodes with stub // nodes. Specifically any node index that has `nodes[index]==1` will be // slated for replacement with a delegate kernel specified by registration. + // Ownership of 'nodes_to_replace' and 'delegate' remains with the caller. // WARNING: This is an experimental interface that is subject to change. TfLiteStatus ReplaceSubgraphsWithDelegateKernels( TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace, @@ -483,15 +527,28 @@ class Interpreter { TfLiteRegistration** registration); // WARNING: This is an experimental interface that is subject to change. - // Gets an TfLiteIntArray* representing the execution plan. The caller owns - // this memory and must free it with TfLiteIntArrayFree(). + // Gets an TfLiteIntArray* representing the execution plan. The interpreter + // owns this memory and it is only guaranteed to exist during the invocation + // of the delegate prepare. TfLiteStatus GetExecutionPlan(TfLiteIntArray** execution_plan); // WARNING: This is an experimental interface that is subject to change. - // Entry point for C node plugin API to get the execution plan + // Entry point for C node plugin API to get the execution plan. static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context, TfLiteIntArray** execution_plan); + // Retrieve an existing external context by type. + TfLiteExternalContext* GetExternalContext(TfLiteExternalContextType type); + static TfLiteExternalContext* GetExternalContext( + struct TfLiteContext* context, TfLiteExternalContextType type); + + // Set the value of an external context. + void SetExternalContext(TfLiteExternalContextType type, + TfLiteExternalContext* ctx); + static void SetExternalContext(struct TfLiteContext* context, + TfLiteExternalContextType type, + TfLiteExternalContext* ctx); + // Ensures that `tensors_` has at least `kTensorsCapacityHeadroom` extra // capacity. Calling this function may invalidate existing pointers to // tensors. After calling this function, adding `kTensorsCapacityHeadroom` @@ -541,6 +598,9 @@ class Interpreter { // interpreter. std::vector outputs_; + // Array of indices representing the tensors that are variable tensors. + std::vector variables_; + // The error reporter delegate that tflite will forward queries errors to. ErrorReporter* error_reporter_; @@ -572,8 +632,16 @@ class Interpreter { bool allow_buffer_handle_output_ = false; + // Tracking bit for whether a tensor was resized in the course of an op + // invocation. This is a useful hint to ensure that dynamic tensor outputs + // trigger downstream reallocation after op invocation. + bool tensor_resized_since_op_invoke_ = false; + // Profiler for this interpreter instance. - profiling::Profiler* profiler_; + profiling::Profiler* profiler_ = nullptr; + + // List of active external contexts. + TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts]; }; } // namespace tflite diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 4c78466480bdcfb42b1f582ecc7c185201a81b05..f00697826c0113df4db687549232123d16a60686 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -23,6 +23,21 @@ limitations under the License. #include "tensorflow/contrib/lite/testing/util.h" namespace tflite { + +// InterpreterTest is a friend of Interpreter, so it can access context_. +class InterpreterTest : public ::testing::Test { + protected: + TfLiteContext* GetInterpreterContext() { return &interpreter_.context_; } + + Interpreter interpreter_; +}; + +namespace ops { +namespace builtin { +TfLiteRegistration* Register_PADV2(); +TfLiteRegistration* Register_NEG(); +} // namespace builtin +} // namespace ops namespace { // Make an interpreter that has no tensors and no nodes @@ -42,6 +57,22 @@ TEST(BasicInterpreter, InvokeInvalidModel) { ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); } +TEST(BasicInterpreter, TestAllocateTensorsResetVariableTensors) { + Interpreter interpreter; + int tensor_index; + ASSERT_EQ(interpreter.AddTensors(1, &tensor_index), kTfLiteOk); + constexpr int kTensorSize = 16; + interpreter.SetTensorParametersReadWrite(tensor_index, kTfLiteFloat32, "", + {kTensorSize}, {}, true); + interpreter.SetVariables({tensor_index}); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + TfLiteTensor* tensor = interpreter.tensor(tensor_index); + // Ensure that variable tensors are reset to zero. + for (int i = 0; i < kTensorSize; ++i) { + ASSERT_EQ(tensor->data.f[i], 0.0f); + } +} + // Test size accessor functions. TEST(BasicInterpreter, TestSizeFunctions) { Interpreter interpreter; @@ -106,10 +137,9 @@ TEST(BasicInterpreter, CheckAllocate) { TfLiteType type; size_t size; } cases[] = { - {kTfLiteFloat32, sizeof(float)}, - {kTfLiteInt32, sizeof(int32_t)}, - {kTfLiteUInt8, sizeof(uint8_t)}, - {kTfLiteInt64, sizeof(int64_t)}, + {kTfLiteFloat32, sizeof(float)}, {kTfLiteInt32, sizeof(int32_t)}, + {kTfLiteUInt8, sizeof(uint8_t)}, {kTfLiteInt64, sizeof(int64_t)}, + {kTfLiteInt16, sizeof(int16_t)}, }; for (auto test : cases) { @@ -134,6 +164,7 @@ TEST(BasicInterpreter, CheckResize) { const int32_t int32s[] = {-3, -4}; const uint8_t uint8s[] = {3, 4}; const int64_t int64s[] = {6, -7}; + const int16_t int16s[] = {8, -9}; struct { TfLiteType type; @@ -144,6 +175,7 @@ TEST(BasicInterpreter, CheckResize) { {kTfLiteInt32, sizeof(int32_t), reinterpret_cast(int32s)}, {kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast(uint8s)}, {kTfLiteInt64, sizeof(int64_t), reinterpret_cast(int64s)}, + {kTfLiteInt16, sizeof(int16_t), reinterpret_cast(int16s)}, }; for (auto test : cases) { @@ -179,10 +211,8 @@ TEST(BasicInterpreter, CheckAlignment) { struct { TfLiteType type; } cases[] = { - {kTfLiteFloat32}, - {kTfLiteInt32}, - {kTfLiteUInt8}, - {kTfLiteInt64}, + {kTfLiteFloat32}, {kTfLiteInt32}, {kTfLiteUInt8}, + {kTfLiteInt64}, {kTfLiteInt16}, }; for (auto test : cases) { @@ -226,32 +256,16 @@ TEST(BasicInterpreter, CheckArenaAllocation) { ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); - ASSERT_EQ(interpreter.tensor(0)->data.raw, interpreter.tensor(4)->data.raw); - ASSERT_EQ(interpreter.tensor(1)->data.raw, interpreter.tensor(7)->data.raw); - ASSERT_EQ(interpreter.tensor(8)->data.raw, nullptr); - - ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(1)->data.raw); - ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(1)->data.raw); ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(1)->data.raw); - - ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(2)->data.raw); ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(3)->data.raw); - - ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(4)->data.raw); ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(5)->data.raw, interpreter.tensor(7)->data.raw); + ASSERT_EQ(interpreter.tensor(6)->data.raw, interpreter.tensor(2)->data.raw); + // #7 is the one with the largest pointer. + ASSERT_EQ(interpreter.tensor(8)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw); } TEST(BasicInterpreter, BufferAccess) { @@ -287,6 +301,57 @@ TEST(BasicInterpreter, NoOpInterpreter) { ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); } +TEST(BasicInterpreter, RedundantAllocateTensors) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + const auto data_raw = interpreter.tensor(0)->data.raw; + ASSERT_NE(data_raw, nullptr); + + // A redundant allocation request should have no impact. + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.tensor(0)->data.raw, data_raw); +} + +TEST(BasicInterpreter, RedundantAllocateTensorsWithDynamicInputs) { + Interpreter interpreter; + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + interpreter.SetInputs({0}); + interpreter.SetOutputs({1}); + interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, ®); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 1, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + + // Configure the input tensor as dynamic. + interpreter.tensor(0)->data.raw = nullptr; + interpreter.tensor(0)->allocation_type = kTfLiteDynamic; + + ASSERT_EQ(interpreter.ResizeInputTensor(interpreter.inputs()[0], {1, 2, 3}), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr); + + // Reset the output tensor's buffer. + interpreter.tensor(1)->data.raw = nullptr; + + // A redundant allocation request should be honored, as the input tensor + // was marked dynamic. + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr); +} + TEST(BasicInterpreter, ResizingTensors) { Interpreter interpreter; ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); @@ -344,6 +409,37 @@ TEST(BasicInterpreter, ResizingTensors) { tensor->data.f[15] = 0.123f; } +TEST(BasicInterpreter, NoopResizingTensors) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + + int t = interpreter.inputs()[0]; + TfLiteTensor* tensor = interpreter.tensor(t); + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 3}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 6 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + tensor->data.f[5] = 0.123f; + + // Resizing to the same size should not trigger re-allocation. + ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 3}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 6 * sizeof(float)); + ASSERT_NE(tensor->data.raw, nullptr); + ASSERT_EQ(tensor->data.f[5], 0.123f); + + // Explicitly allocating should be a no-op, as no resize was performed. + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 6 * sizeof(float)); + ASSERT_NE(tensor->data.raw, nullptr); + ASSERT_EQ(tensor->data.f[5], 0.123f); +} + TEST(BasicInterpreter, OneOpInterpreter) { Interpreter interpreter; ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); @@ -551,18 +647,6 @@ TEST(BasicInterpreter, AllocateTwice) { ASSERT_EQ(old_tensor1_ptr, interpreter.tensor(1)->data.raw); } -struct TestErrorReporter : public ErrorReporter { - int Report(const char* format, va_list args) override { - char buffer[1024]; - int size = vsnprintf(buffer, sizeof(buffer), format, args); - all_reports += buffer; - calls++; - return size; - } - int calls = 0; - std::string all_reports; -}; - TEST(BasicInterpreter, TestNullErrorReporter) { TestErrorReporter reporter; Interpreter interpreter; @@ -572,8 +656,9 @@ TEST(BasicInterpreter, TestCustomErrorReporter) { TestErrorReporter reporter; Interpreter interpreter(&reporter); ASSERT_NE(interpreter.Invoke(), kTfLiteOk); - ASSERT_EQ(reporter.all_reports, "Invoke called on model that is not ready."); - ASSERT_EQ(reporter.calls, 1); + ASSERT_EQ(reporter.error_messages(), + "Invoke called on model that is not ready."); + ASSERT_EQ(reporter.num_calls(), 1); } TEST(BasicInterpreter, TestUnsupportedDelegateFunctions) { @@ -616,6 +701,59 @@ TEST(BasicInterpreter, TestUnsupportedDelegateFunctions) { EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteError); } +TEST(BasicInterpreter, DynamicTensorsResizeDescendants) { + // Assemble a graph with a node that has dynamically sized output (via the + // pad op), followed by a node with a standard element-wise op (negate). + Interpreter interpreter; + interpreter.AddTensors(4); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({3}); + TfLiteQuantizationParams quant; + interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {2, 2, 1, 1}, + quant); + interpreter.SetTensorParametersReadWrite(1, kTfLiteInt32, "", {4, 2}, quant); + interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {}, quant); + interpreter.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {}, quant); + + TfLiteRegistration* pad_op = tflite::ops::builtin::Register_PADV2(); + TfLiteRegistration* neg_op = tflite::ops::builtin::Register_NEG(); + interpreter.AddNodeWithParameters({0, 1}, {2}, nullptr, 0, nullptr, pad_op); + interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, neg_op); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + // Configure [[2,2],[4,4]] padding and execute the graph. + interpreter.typed_tensor(1)[0] = 2; + interpreter.typed_tensor(1)[1] = 2; + interpreter.typed_tensor(1)[2] = 2; + interpreter.typed_tensor(1)[3] = 2; + interpreter.typed_tensor(1)[4] = 0; + interpreter.typed_tensor(1)[5] = 0; + interpreter.typed_tensor(1)[6] = 0; + interpreter.typed_tensor(1)[7] = 0; + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + + // Both the output and intermediate tensor sizes should reflect the output + // from the dynamic pad operation. + ASSERT_EQ(interpreter.tensor(2)->bytes, sizeof(float) * 6 * 6); + ASSERT_EQ(interpreter.tensor(3)->bytes, sizeof(float) * 6 * 6); + + // Now configure [[4,4],[6,6]] padding and execute the graph. + interpreter.typed_tensor(1)[0] = 4; + interpreter.typed_tensor(1)[1] = 4; + interpreter.typed_tensor(1)[2] = 6; + interpreter.typed_tensor(1)[3] = 6; + interpreter.typed_tensor(1)[4] = 0; + interpreter.typed_tensor(1)[5] = 0; + interpreter.typed_tensor(1)[6] = 0; + interpreter.typed_tensor(1)[7] = 0; + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + + // Again, the output and intermediate tensor sizes should reflect the *new* + // resize from the latest pad operation. + ASSERT_EQ(interpreter.tensor(2)->bytes, sizeof(float) * 10 * 14); + ASSERT_EQ(interpreter.tensor(3)->bytes, sizeof(float) * 10 * 14); +} + TEST(InterpreterTensorsCapacityTest, TestWithinHeadroom) { Interpreter interpreter; ASSERT_EQ(interpreter.AddTensors(Interpreter::kTensorsReservedCapacity), @@ -656,6 +794,47 @@ TEST(InterpreterTensorsCapacityTest, TestExceedHeadroom) { ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); } +struct TestExternalContext : public TfLiteExternalContext { + static const TfLiteExternalContextType kType = kTfLiteGemmLowpContext; + + static TestExternalContext* Get(TfLiteContext* context) { + return reinterpret_cast( + context->GetExternalContext(context, kType)); + } + + static void Set(TfLiteContext* context, TestExternalContext* value) { + context->SetExternalContext(context, kType, value); + } + + int num_refreshes = 0; +}; + +TEST_F(InterpreterTest, GetSetResetExternalContexts) { + auto* context = GetInterpreterContext(); + + TestExternalContext external_context; + external_context.Refresh = [](TfLiteContext* context) { + auto* ptr = TestExternalContext::Get(context); + if (ptr != nullptr) { + ++ptr->num_refreshes; + } + return kTfLiteOk; + }; + + EXPECT_EQ(TestExternalContext::Get(context), nullptr); + interpreter_.SetNumThreads(4); + + TestExternalContext::Set(context, &external_context); + EXPECT_EQ(TestExternalContext::Get(context), &external_context); + interpreter_.SetNumThreads(4); + interpreter_.SetNumThreads(5); + EXPECT_EQ(external_context.num_refreshes, 2); + + TestExternalContext::Set(context, nullptr); + EXPECT_EQ(TestExternalContext::Get(context), nullptr); + interpreter_.SetNumThreads(4); +} + // Test fixture that allows playing with execution plans. It creates a two // node graph that can be executed in either [0,1] order or [1,0] order. // The CopyOp records when it is invoked in the class member run_order_ @@ -901,21 +1080,22 @@ class TestDelegate : public ::testing::Test { return kTfLiteOk; }; delegate_.CopyToBufferHandle = - [](TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle, - void* data, size_t size) -> TfLiteStatus { + [](TfLiteContext* context, TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, void* data, + size_t size) -> TfLiteStatus { // TODO(ycling): Implement tests to test buffer copying logic. return kTfLiteOk; }; delegate_.CopyFromBufferHandle = - [](TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle, - void* data, size_t size) -> TfLiteStatus { + [](TfLiteContext* context, TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, void* data, + size_t size) -> TfLiteStatus { // TODO(ycling): Implement tests to test buffer copying logic. return kTfLiteOk; }; - delegate_.FreeBufferHandle = [](TfLiteDelegate* delegate, - TfLiteBufferHandle* handle) { - *handle = kTfLiteNullBufferHandle; - }; + delegate_.FreeBufferHandle = + [](TfLiteContext* context, TfLiteDelegate* delegate, + TfLiteBufferHandle* handle) { *handle = kTfLiteNullBufferHandle; }; // Store type-punned data SimpleDelegate structure. delegate_.data_ = reinterpret_cast(this); } diff --git a/tensorflow/contrib/lite/java/AndroidManifest.xml b/tensorflow/contrib/lite/java/AndroidManifest.xml index f705feacbec38ab5152ce52b701320d8f1cd8d3d..b91c6d149a213926be90b9b131bd632d4f79a0fc 100644 --- a/tensorflow/contrib/lite/java/AndroidManifest.xml +++ b/tensorflow/contrib/lite/java/AndroidManifest.xml @@ -1,7 +1,12 @@ - - + package="org.tensorflow.lite"> + + + + + diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD index 593af81a18a1e20a41dcc8d9bb3a1d815876e294..098ba7e7731d833678fbd5eab9cce3f022570f23 100644 --- a/tensorflow/contrib/lite/java/BUILD +++ b/tensorflow/contrib/lite/java/BUILD @@ -69,6 +69,7 @@ java_test( size = "small", srcs = ["src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java"], javacopts = JAVACOPTS, + tags = ["no_oss"], test_class = "org.tensorflow.lite.TensorFlowLiteTest", deps = [ ":libtensorflowlite_jni.so", @@ -83,6 +84,7 @@ java_test( size = "small", srcs = ["src/test/java/org/tensorflow/lite/DataTypeTest.java"], javacopts = JAVACOPTS, + tags = ["no_oss"], test_class = "org.tensorflow.lite.DataTypeTest", deps = [ ":libtensorflowlite_jni.so", @@ -105,6 +107,7 @@ java_test( "src/testdata/with_custom_op.lite", ], javacopts = JAVACOPTS, + tags = ["no_oss"], test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest", deps = [ ":libtensorflowlite_jni.so", @@ -124,6 +127,7 @@ java_test( "src/testdata/mobilenet.tflite.bin", ], javacopts = JAVACOPTS, + tags = ["no_oss"], test_class = "org.tensorflow.lite.InterpreterTest", visibility = ["//visibility:private"], deps = [ @@ -142,6 +146,7 @@ java_test( "src/testdata/add.bin", ], javacopts = JAVACOPTS, + tags = ["no_oss"], test_class = "org.tensorflow.lite.TensorTest", deps = [ ":tensorflowlitelib", diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/contrib/lite/java/aar_with_jni.bzl index 4450bc9085555b3416f51bac07ea94a1240e919c..db837cf29edfc0ffe9950ffedc02cca1389b0fdf 100644 --- a/tensorflow/contrib/lite/java/aar_with_jni.bzl +++ b/tensorflow/contrib/lite/java/aar_with_jni.bzl @@ -1,5 +1,7 @@ """Generate zipped aar file including different variants of .so in jni folder.""" +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + def aar_with_jni(name, android_library): # Generate dummy AndroidManifest.xml for dummy apk usage # (dummy apk is generated by _dummy_app_for_so target below) @@ -19,7 +21,7 @@ EOF # Generate dummy apk including .so files and later we extract out # .so files and throw away the apk. - native.android_binary( + android_binary( name = name + "_dummy_app_for_so", manifest = name + "_generated_AndroidManifest.xml", custom_package = "dummy.package.for.so", diff --git a/tensorflow/contrib/lite/java/demo/.gitignore b/tensorflow/contrib/lite/java/demo/.gitignore index 39fb081a42a86ccf8f9cf99dbccc8bdf7c828bce..d245ab61095a6f9b6d2077aac934f9b13e66d85e 100644 --- a/tensorflow/contrib/lite/java/demo/.gitignore +++ b/tensorflow/contrib/lite/java/demo/.gitignore @@ -1,9 +1,29 @@ +# This file is based on https://github.com/github/gitignore/blob/master/Android.gitignore *.iml +.idea/compiler.xml +.idea/copyright +.idea/dictionaries +.idea/gradle.xml +.idea/libraries +.idea/inspectionProfiles +.idea/misc.xml +.idea/modules.xml +.idea/runConfigurations.xml +.idea/tasks.xml +.idea/workspace.xml .gradle -/local.properties -/.idea/workspace.xml -/.idea/libraries +local.properties .DS_Store -/build +build/ +gradleBuild/ +*.apk +*.ap_ +*.dex +*.class +bin/ +gen/ +out/ +*.log +.navigation/ /captures .externalNativeBuild diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md index 2e818f728ef208d30b0eeb27ffd7e3fa0c7c1a2d..e3cea19e1683ac2680521bce66d1328e4b2caf1c 100644 --- a/tensorflow/contrib/lite/java/demo/README.md +++ b/tensorflow/contrib/lite/java/demo/README.md @@ -1,5 +1,14 @@ # TF Lite Android App +## Building in Android Studio with TensorFlow Lite AAR from JCenter. +The build.gradle is configured to use TensorFlow Lite's nightly build. + +If you see a build error related to compatibility with Tensorflow Lite's Java API (example: method X is +undefined for type Interpreter), there has likely been a backwards compatible +change to the API. You will need to pull new app code that's compatible with the +nightly build and may need to first wait a few days for our external and internal +code to merge. + ## Building from Source with Bazel 1. Follow the [Bazel steps for the TF Demo App](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#bazel): diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle index b76eaad8bb91224805d16b3d6f7c3274c9feb90c..92f04c651c0488a5202def593774890630c8631f 100644 --- a/tensorflow/contrib/lite/java/demo/app/build.gradle +++ b/tensorflow/contrib/lite/java/demo/app/build.gradle @@ -5,11 +5,12 @@ android { buildToolsVersion "26.0.1" defaultConfig { applicationId "android.example.com.tflitecamerademo" - minSdkVersion 15 + // Required by Camera2 API. + minSdkVersion 21 targetSdkVersion 26 versionCode 1 versionName "1.0" - testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" // Remove this block. jackOptions { @@ -43,7 +44,7 @@ repositories { dependencies { compile fileTree(dir: 'libs', include: ['*.jar']) - androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + androidTestCompile('androidx.test.espresso:espresso-core:3.1.0-alpha3', { exclude group: 'com.android.support', module: 'support-annotations' }) compile 'com.android.support:appcompat-v7:25.2.0' @@ -52,7 +53,43 @@ dependencies { compile 'com.android.support:support-annotations:25.3.1' compile 'com.android.support:support-v13:25.2.0' - compile 'org.tensorflow:tensorflow-lite:+' + compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly' testCompile 'junit:junit:4.12' } + +def modelDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip" +def localCache = "build/intermediates/mobilenet_v1_224_android_quant_2017_11_08.zip" +def targetFolder = "src/main/assets" + +task downloadModel(type: DownloadUrlTask) { + doFirst { + println "Downloading ${modelDownloadUrl}" + } + sourceUrl = "${modelDownloadUrl}" + target = file("${localCache}") +} + +task unzipModel(type: Copy, dependsOn: 'downloadModel') { + doFirst { + println "Unzipping ${localCache}" + } + from zipTree("${localCache}") + into "${targetFolder}" +} + +// Ensure the model file is downloaded and extracted before every build +preBuild.dependsOn unzipModel + +class DownloadUrlTask extends DefaultTask { + @Input + String sourceUrl + + @OutputFile + File target + + @TaskAction + void download() { + ant.get(src: sourceUrl, dest: target) + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD index d6fbef9cc938993b283103984307ab51e609dd6e..220d6c2159b56f6349e93132418fa0f6c69d1ab3 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD +++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD @@ -1,3 +1,5 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + package(default_visibility = ["//visibility:private"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD index 362d93636f72205ddcda6d97fa9fae376ff211f1..06f46fb92394b19415ddb95dcf8c798753b630e3 100644 --- a/tensorflow/contrib/lite/java/ovic/BUILD +++ b/tensorflow/contrib/lite/java/ovic/BUILD @@ -1,6 +1,8 @@ # Description: # OVIC Benchmarker Java API. +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 @@ -16,6 +18,7 @@ java_test( "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", ], javacopts = JAVACOPTS, + tags = ["no_oss"], test_class = "org.tensorflow.ovic.OvicClassifierTest", visibility = ["//visibility:public"], deps = [ diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD index 83974f4b337baedebaf9c9ffc0a03501418a3e36..a8d751ade26adc358e130138381eab9956f2d848 100644 --- a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD +++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD @@ -1,3 +1,5 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + # Sample app for OVIC benchmarking. licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle index c5d19bad89a93988a6830a17fe2fb4a60e2fb00f..2a08608bbb121a2e279bbd16de6a014e68848796 100644 --- a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle +++ b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle @@ -9,7 +9,7 @@ android { targetSdkVersion 26 versionCode 1 versionName "1.0" - testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" // Remove this block. jackOptions { @@ -43,7 +43,7 @@ repositories { dependencies { compile fileTree(dir: 'libs', include: ['*.jar']) - androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + androidTestCompile('androidx.test.espresso:espresso-core:3.1.0-alpha3', { exclude group: 'com.android.support', module: 'support-annotations' }) compile 'com.android.support:appcompat-v7:25.2.0' diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java index 56f3e7604a5b172e907edbe862b017957594397f..1587c3c56f45c0baddfa75286c979fe0c0edffcc 100644 --- a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java +++ b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java @@ -127,12 +127,8 @@ public final class OvicClassifierTest { try { testResult = classifier.classifyByteBuffer(testImage); fail(); - } catch (RuntimeException e) { - assertThat(e) - .hasMessageThat() - .contains( - "Failed to get input dimensions. 0-th input should have 49152 bytes, " - + "but found 150528 bytes."); + } catch (IllegalArgumentException e) { + // Success. } } diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java index 75334cd96e8daadc356dadea063eee30ef6d5245..94a1ec65d64b6493cdb309fc0c19155eb9cb26cb 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java @@ -27,10 +27,7 @@ enum DataType { UINT8(3), /** 64-bit signed integer. */ - INT64(4), - - /** A {@link ByteBuffer}. */ - BYTEBUFFER(999); + INT64(4); private final int value; @@ -69,8 +66,6 @@ enum DataType { return 1; case INT64: return 8; - case BYTEBUFFER: - return 1; } throw new IllegalArgumentException( "DataType error: DataType " + this + " is not supported yet"); @@ -87,8 +82,6 @@ enum DataType { return "byte"; case INT64: return "long"; - case BYTEBUFFER: - return "ByteBuffer"; } throw new IllegalArgumentException( "DataType error: DataType " + this + " is not supported yet"); diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index fd1f0ffa68eeca7b5866b146ecaa1f9216ef377d..7002f826775b216e0a27ebe00f30680c9ce362bb 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -135,7 +135,8 @@ public final class Interpreter implements AutoCloseable { * including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large * input data. When {@link ByteBuffer} is used, its content should remain unchanged until * model inference is done. - * @param output a multidimensional array of output data. + * @param output a multidimensional array of output data, or a {@link ByteBuffer} of primitive + * types including int, float, long, and byte. */ public void run(@NonNull Object input, @NonNull Object output) { Object[] inputs = {input}; @@ -155,28 +156,16 @@ public final class Interpreter implements AutoCloseable { * primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred * way to pass large input data. When {@link ByteBuffer} is used, its content should remain * unchanged until model inference is done. - * @param outputs a map mapping output indices to multidimensional arrays of output data. It only - * needs to keep entries for the outputs to be used. + * @param outputs a map mapping output indices to multidimensional arrays of output data or {@link + * ByteBuffer}s of primitive types including int, float, long, and byte. It only needs to keep + * entries for the outputs to be used. */ public void runForMultipleInputsOutputs( @NonNull Object[] inputs, @NonNull Map outputs) { if (wrapper == null) { throw new IllegalStateException("Internal error: The Interpreter has already been closed."); } - Tensor[] tensors = wrapper.run(inputs); - if (outputs == null || tensors == null || outputs.size() > tensors.length) { - throw new IllegalArgumentException("Output error: Outputs do not match with model outputs."); - } - final int size = tensors.length; - for (Integer idx : outputs.keySet()) { - if (idx == null || idx < 0 || idx >= size) { - throw new IllegalArgumentException( - String.format( - "Output error: Invalid index of output %d (should be in range [0, %d))", - idx, size)); - } - tensors[idx].copyTo(outputs.get(idx)); - } + wrapper.run(inputs, outputs); } /** @@ -249,8 +238,10 @@ public final class Interpreter implements AutoCloseable { /** Release resources associated with the {@code Interpreter}. */ @Override public void close() { - wrapper.close(); - wrapper = null; + if (wrapper != null) { + wrapper.close(); + wrapper = null; + } } @Override diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index 2ae6c516b03ef4292667bbd944c73d2eeaf82db3..767a220f8cd5381ce10e044553317b1cb05ba17b 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -15,10 +15,10 @@ limitations under the License. package org.tensorflow.lite; -import java.lang.reflect.Array; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.MappedByteBuffer; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -40,6 +40,8 @@ final class NativeInterpreterWrapper implements AutoCloseable { modelHandle = createModel(modelPath, errorHandle); interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads); isMemoryAllocated = true; + inputTensors = new Tensor[getInputCount(interpreterHandle)]; + outputTensors = new Tensor[getOutputCount(interpreterHandle)]; } /** @@ -72,6 +74,8 @@ final class NativeInterpreterWrapper implements AutoCloseable { modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle); interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads); isMemoryAllocated = true; + inputTensors = new Tensor[getInputCount(interpreterHandle)]; + outputTensors = new Tensor[getOutputCount(interpreterHandle)]; } /** Releases resources associated with this {@code NativeInterpreterWrapper}. */ @@ -85,75 +89,63 @@ final class NativeInterpreterWrapper implements AutoCloseable { inputsIndexes = null; outputsIndexes = null; isMemoryAllocated = false; + Arrays.fill(inputTensors, null); + Arrays.fill(outputTensors, null); } /** Sets inputs, runs model inference and returns outputs. */ - Tensor[] run(Object[] inputs) { + void run(Object[] inputs, Map outputs) { + inferenceDurationNanoseconds = -1; if (inputs == null || inputs.length == 0) { throw new IllegalArgumentException("Input error: Inputs should not be null or empty."); } - int[] dataTypes = new int[inputs.length]; - Object[] sizes = new Object[inputs.length]; - int[] numsOfBytes = new int[inputs.length]; + if (outputs == null || outputs.isEmpty()) { + throw new IllegalArgumentException("Input error: Outputs should not be null or empty."); + } + + // TODO(b/80431971): Remove implicit resize after deprecating multi-dimensional array inputs. + // Rather than forcing an immediate resize + allocation if an input's shape differs, we first + // flush all resizes, avoiding redundant allocations. for (int i = 0; i < inputs.length; ++i) { - DataType dataType = dataTypeOf(inputs[i]); - dataTypes[i] = dataType.getNumber(); - if (dataType == DataType.BYTEBUFFER) { - ByteBuffer buffer = (ByteBuffer) inputs[i]; - if (buffer == null || !buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()) { - throw new IllegalArgumentException( - "Input error: ByteBuffer should be a direct ByteBuffer that uses " - + "ByteOrder.nativeOrder()."); - } - numsOfBytes[i] = buffer.limit(); - sizes[i] = getInputDims(interpreterHandle, i, numsOfBytes[i]); - } else if (isNonEmptyArray(inputs[i])) { - int[] dims = shapeOf(inputs[i]); - sizes[i] = dims; - numsOfBytes[i] = dataType.elemByteSize() * numElements(dims); - } else { - throw new IllegalArgumentException( - String.format( - "Input error: %d-th element of the %d inputs is not an array or a ByteBuffer.", - i, inputs.length)); + Tensor tensor = getInputTensor(i); + int[] newShape = tensor.getInputShapeIfDifferent(inputs[i]); + if (newShape != null) { + resizeInput(i, newShape); } } - inferenceDurationNanoseconds = -1; - long[] outputsHandles = - run( - interpreterHandle, - errorHandle, - sizes, - dataTypes, - numsOfBytes, - inputs, - this, - isMemoryAllocated); - if (outputsHandles == null || outputsHandles.length == 0) { - throw new IllegalStateException("Internal error: Interpreter has no outputs."); + + if (!isMemoryAllocated) { + allocateTensors(interpreterHandle, errorHandle); + isMemoryAllocated = true; + // Allocation can trigger dynamic resizing of output tensors, so clear the + // output tensor cache. + Arrays.fill(outputTensors, null); } - isMemoryAllocated = true; - Tensor[] outputs = new Tensor[outputsHandles.length]; - for (int i = 0; i < outputsHandles.length; ++i) { - outputs[i] = Tensor.fromHandle(outputsHandles[i]); + + for (int i = 0; i < inputs.length; ++i) { + getInputTensor(i).setTo(inputs[i]); + } + + long inferenceStartNanos = System.nanoTime(); + run(interpreterHandle, errorHandle); + long inferenceDurationNanoseconds = System.nanoTime() - inferenceStartNanos; + + for (Map.Entry output : outputs.entrySet()) { + getOutputTensor(output.getKey()).copyTo(output.getValue()); } - return outputs; + + // Only set if the entire operation succeeds. + this.inferenceDurationNanoseconds = inferenceDurationNanoseconds; } - private static native long[] run( - long interpreterHandle, - long errorHandle, - Object[] sizes, - int[] dtypes, - int[] numsOfBytes, - Object[] values, - NativeInterpreterWrapper wrapper, - boolean memoryAllocated); + private static native boolean run(long interpreterHandle, long errorHandle); /** Resizes dimensions of a specific input. */ void resizeInput(int idx, int[] dims) { if (resizeInput(interpreterHandle, errorHandle, idx, dims)) { isMemoryAllocated = false; + // Resizing will invalidate the Tensor's shape, so invalidate the Tensor handle. + inputTensors[idx] = null; } } @@ -212,78 +204,6 @@ final class NativeInterpreterWrapper implements AutoCloseable { } } - static int numElements(int[] shape) { - if (shape == null) { - return 0; - } - int n = 1; - for (int i = 0; i < shape.length; i++) { - n *= shape[i]; - } - return n; - } - - static boolean isNonEmptyArray(Object o) { - return (o != null && o.getClass().isArray() && Array.getLength(o) != 0); - } - - /** Returns the type of the data. */ - static DataType dataTypeOf(Object o) { - if (o != null) { - Class c = o.getClass(); - while (c.isArray()) { - c = c.getComponentType(); - } - if (float.class.equals(c)) { - return DataType.FLOAT32; - } else if (int.class.equals(c)) { - return DataType.INT32; - } else if (byte.class.equals(c)) { - return DataType.UINT8; - } else if (long.class.equals(c)) { - return DataType.INT64; - } else if (ByteBuffer.class.isInstance(o)) { - return DataType.BYTEBUFFER; - } - } - throw new IllegalArgumentException( - "DataType error: cannot resolve DataType of " + o.getClass().getName()); - } - - /** Returns the shape of an object as an int array. */ - static int[] shapeOf(Object o) { - int size = numDimensions(o); - int[] dimensions = new int[size]; - fillShape(o, 0, dimensions); - return dimensions; - } - - static int numDimensions(Object o) { - if (o == null || !o.getClass().isArray()) { - return 0; - } - if (Array.getLength(o) == 0) { - throw new IllegalArgumentException("Array lengths cannot be 0."); - } - return 1 + numDimensions(Array.get(o, 0)); - } - - static void fillShape(Object o, int dim, int[] shape) { - if (shape == null || dim == shape.length) { - return; - } - final int len = Array.getLength(o); - if (shape[dim] == 0) { - shape[dim] = len; - } else if (shape[dim] != len) { - throw new IllegalArgumentException( - String.format("Mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim)); - } - for (int i = 0; i < len; ++i) { - fillShape(Array.get(o, i), dim + 1, shape); - } - } - /** * Gets the last inference duration in nanoseconds. It returns null if there is no previous * inference run or the last inference run failed. @@ -293,26 +213,63 @@ final class NativeInterpreterWrapper implements AutoCloseable { } /** - * Gets the dimensions of an input. It throws IllegalArgumentException if input index is invalid. + * Gets the quantization zero point of an output. + * + * @throws IllegalArgumentException if the output index is invalid. */ - int[] getInputDims(int index) { - return getInputDims(interpreterHandle, index, -1); + int getOutputQuantizationZeroPoint(int index) { + return getOutputQuantizationZeroPoint(interpreterHandle, index); } /** - * Gets the dimensions of an input. If numBytes >= 0, it will check whether num of bytes match the - * input. + * Gets the quantization scale of an output. + * + * @throws IllegalArgumentException if the output index is invalid. */ - private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes); + float getOutputQuantizationScale(int index) { + return getOutputQuantizationScale(interpreterHandle, index); + } + + /** + * Gets the input {@link Tensor} for the provided input index. + * + * @throws IllegalArgumentException if the input index is invalid. + */ + Tensor getInputTensor(int index) { + if (index < 0 || index >= inputTensors.length) { + throw new IllegalArgumentException("Invalid input Tensor index: " + index); + } + Tensor inputTensor = inputTensors[index]; + if (inputTensor == null) { + inputTensor = + inputTensors[index] = Tensor.fromHandle(getInputTensor(interpreterHandle, index)); + } + return inputTensor; + } - /** Gets the type of an output. It throws IllegalArgumentException if output index is invalid. */ - String getOutputDataType(int index) { - int type = getOutputDataType(interpreterHandle, index); - return DataType.fromNumber(type).toStringName(); + /** + * Gets the output {@link Tensor} for the provided output index. + * + * @throws IllegalArgumentException if the output index is invalid. + */ + Tensor getOutputTensor(int index) { + if (index < 0 || index >= outputTensors.length) { + throw new IllegalArgumentException("Invalid output Tensor index: " + index); + } + Tensor outputTensor = outputTensors[index]; + if (outputTensor == null) { + outputTensor = + outputTensors[index] = Tensor.fromHandle(getOutputTensor(interpreterHandle, index)); + } + return outputTensor; } private static native int getOutputDataType(long interpreterHandle, int outputIdx); + private static native int getOutputQuantizationZeroPoint(long interpreterHandle, int outputIdx); + + private static native float getOutputQuantizationScale(long interpreterHandle, int outputIdx); + private static final int ERROR_BUFFER_SIZE = 512; private long errorHandle; @@ -321,18 +278,30 @@ final class NativeInterpreterWrapper implements AutoCloseable { private long modelHandle; - private int inputSize; - private long inferenceDurationNanoseconds = -1; private ByteBuffer modelByteBuffer; + // Lazily constructed maps of input and output names to input and output Tensor indexes. private Map inputsIndexes; - private Map outputsIndexes; + // Lazily constructed and populated arrays of input and output Tensor wrappers. + private final Tensor[] inputTensors; + private final Tensor[] outputTensors; + private boolean isMemoryAllocated = false; + private static native long allocateTensors(long interpreterHandle, long errorHandle); + + private static native long getInputTensor(long interpreterHandle, int inputIdx); + + private static native long getOutputTensor(long interpreterHandle, int outputIdx); + + private static native int getInputCount(long interpreterHandle); + + private static native int getOutputCount(long interpreterHandle); + private static native String[] getInputNames(long interpreterHandle); private static native String[] getOutputNames(long interpreterHandle); diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java index 09e887aae3339e9f114c07d689c0d7b5e2fc384b..2403570c527e762f6782e313731e383feeeef46d 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -15,6 +15,9 @@ limitations under the License. package org.tensorflow.lite; +import java.lang.reflect.Array; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.Arrays; /** @@ -29,30 +32,179 @@ final class Tensor { return new Tensor(nativeHandle); } - /** Reads Tensor content into an array. */ - T copyTo(T dst) { - if (NativeInterpreterWrapper.dataTypeOf(dst) != dtype) { + /** Returns the {@link DataType} of elements stored in the Tensor. */ + public DataType dataType() { + return dtype; + } + + /** Returns the size, in bytes, of the tensor data. */ + public int numBytes() { + return numBytes(nativeHandle); + } + + /** + * Returns the shape of + * the Tensor, i.e., the sizes of each dimension. + * + * @return an array where the i-th element is the size of the i-th dimension of the tensor. + */ + public int[] shape() { + return shapeCopy; + } + + /** + * Copies the contents of the provided {@code src} object to the Tensor. + * + *

The {@code src} should either be a (multi-dimensional) array with a shape matching that of + * this tensor, or a {@link ByteByffer} of compatible primitive type with a matching flat size. + * + * @throws IllegalArgumentException if the tensor is a scalar or if {@code src} is not compatible + * with the tensor (for example, mismatched data types or shapes). + */ + void setTo(Object src) { + throwExceptionIfTypeIsIncompatible(src); + if (isByteBuffer(src)) { + ByteBuffer srcBuffer = (ByteBuffer) src; + // For direct ByteBuffer instances we support zero-copy. Note that this assumes the caller + // retains ownership of the source buffer until inference has completed. + if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) { + writeDirectBuffer(nativeHandle, srcBuffer); + } else { + buffer().put(srcBuffer); + } + return; + } + writeMultiDimensionalArray(nativeHandle, src); + } + + /** + * Copies the contents of the tensor to {@code dst} and returns {@code dst}. + * + * @param dst the destination buffer, either an explicitly-typed array or a {@link ByteBuffer}. + * @throws IllegalArgumentException if {@code dst} is not compatible with the tensor (for example, + * mismatched data types or shapes). + */ + Object copyTo(Object dst) { + throwExceptionIfTypeIsIncompatible(dst); + if (dst instanceof ByteBuffer) { + ByteBuffer dstByteBuffer = (ByteBuffer) dst; + dstByteBuffer.put(buffer()); + return dst; + } + readMultiDimensionalArray(nativeHandle, dst); + return dst; + } + + /** Returns the provided buffer's shape if specified and different from this Tensor's shape. */ + // TODO(b/80431971): Remove this method after deprecating multi-dimensional array inputs. + int[] getInputShapeIfDifferent(Object input) { + // Implicit resizes based on ByteBuffer capacity isn't supported, so short-circuit that path. + // The ByteBuffer's size will be validated against this Tensor's size in {@link #setTo(Object)}. + if (isByteBuffer(input)) { + return null; + } + int[] inputShape = shapeOf(input); + if (Arrays.equals(shapeCopy, inputShape)) { + return null; + } + return inputShape; + } + + /** Returns the type of the data. */ + static DataType dataTypeOf(Object o) { + if (o != null) { + Class c = o.getClass(); + while (c.isArray()) { + c = c.getComponentType(); + } + if (float.class.equals(c)) { + return DataType.FLOAT32; + } else if (int.class.equals(c)) { + return DataType.INT32; + } else if (byte.class.equals(c)) { + return DataType.UINT8; + } else if (long.class.equals(c)) { + return DataType.INT64; + } + } + throw new IllegalArgumentException( + "DataType error: cannot resolve DataType of " + o.getClass().getName()); + } + + /** Returns the shape of an object as an int array. */ + static int[] shapeOf(Object o) { + int size = numDimensions(o); + int[] dimensions = new int[size]; + fillShape(o, 0, dimensions); + return dimensions; + } + + /** Returns the number of dimensions of a multi-dimensional array, otherwise 0. */ + static int numDimensions(Object o) { + if (o == null || !o.getClass().isArray()) { + return 0; + } + if (Array.getLength(o) == 0) { + throw new IllegalArgumentException("Array lengths cannot be 0."); + } + return 1 + numDimensions(Array.get(o, 0)); + } + + /** Recursively populates the shape dimensions for a given (multi-dimensional) array. */ + static void fillShape(Object o, int dim, int[] shape) { + if (shape == null || dim == shape.length) { + return; + } + final int len = Array.getLength(o); + if (shape[dim] == 0) { + shape[dim] = len; + } else if (shape[dim] != len) { + throw new IllegalArgumentException( + String.format("Mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim)); + } + for (int i = 0; i < len; ++i) { + fillShape(Array.get(o, i), dim + 1, shape); + } + } + + private void throwExceptionIfTypeIsIncompatible(Object o) { + if (isByteBuffer(o)) { + ByteBuffer oBuffer = (ByteBuffer) o; + if (oBuffer.capacity() != numBytes()) { + throw new IllegalArgumentException( + String.format( + "Cannot convert between a TensorFlowLite buffer with %d bytes and a " + + "ByteBuffer with %d bytes.", + numBytes(), oBuffer.capacity())); + } + return; + } + DataType oType = dataTypeOf(o); + if (oType != dtype) { throw new IllegalArgumentException( String.format( - "Output error: Cannot convert an TensorFlowLite tensor with type %s to a Java " - + "object of type %s (which is compatible with the TensorFlowLite type %s)", - dtype, dst.getClass().getName(), NativeInterpreterWrapper.dataTypeOf(dst))); + "Cannot convert between a TensorFlowLite tensor with type %s and a Java " + + "object of type %s (which is compatible with the TensorFlowLite type %s).", + dtype, o.getClass().getName(), oType)); } - int[] dstShape = NativeInterpreterWrapper.shapeOf(dst); - if (!Arrays.equals(dstShape, shapeCopy)) { + + int[] oShape = shapeOf(o); + if (!Arrays.equals(oShape, shapeCopy)) { throw new IllegalArgumentException( String.format( - "Output error: Shape of output target %s does not match with the shape of the " - + "Tensor %s.", - Arrays.toString(dstShape), Arrays.toString(shapeCopy))); + "Cannot copy between a TensorFlowLite tensor with shape %s and a Java object " + + "with shape %s.", + Arrays.toString(shapeCopy), Arrays.toString(oShape))); } - readMultiDimensionalArray(nativeHandle, dst); - return dst; } - final long nativeHandle; - final DataType dtype; - final int[] shapeCopy; + private static boolean isByteBuffer(Object o) { + return o instanceof ByteBuffer; + } + + private final long nativeHandle; + private final DataType dtype; + private final int[] shapeCopy; private Tensor(long nativeHandle) { this.nativeHandle = nativeHandle; @@ -60,11 +212,23 @@ final class Tensor { this.shapeCopy = shape(nativeHandle); } + private ByteBuffer buffer() { + return buffer(nativeHandle).order(ByteOrder.nativeOrder()); + } + + private static native ByteBuffer buffer(long handle); + + private static native void writeDirectBuffer(long handle, ByteBuffer src); + private static native int dtype(long handle); private static native int[] shape(long handle); - private static native void readMultiDimensionalArray(long handle, Object value); + private static native int numBytes(long handle); + + private static native void readMultiDimensionalArray(long handle, Object dst); + + private static native void writeMultiDimensionalArray(long handle, Object src); static { TensorFlowLite.init(); diff --git a/tensorflow/contrib/lite/java/src/main/native/BUILD b/tensorflow/contrib/lite/java/src/main/native/BUILD index 4399ed202597082fba36c04a744bf6378e4539a2..4b4e1c21d818dc56803ff31d83d19dea2ac08707 100644 --- a/tensorflow/contrib/lite/java/src/main/native/BUILD +++ b/tensorflow/contrib/lite/java/src/main/native/BUILD @@ -11,7 +11,6 @@ licenses(["notice"]) # Apache 2.0 cc_library( name = "native_framework_only", srcs = [ - "duration_utils_jni.cc", "exception_jni.cc", "nativeinterpreterwrapper_jni.cc", "tensor_jni.cc", diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index 1fb6997fb9ba180e9a3f3a89a6d177086440c0d7..fdcf00a0a08459d8d669f1def3ae2eb21dbd31c3 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -16,9 +16,6 @@ limitations under the License. #include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h" namespace { -const int kByteBufferValue = 999; -const int kBufferSize = 256; - tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) { if (handle == 0) { throwException(env, kIllegalArgumentException, @@ -62,22 +59,6 @@ std::vector convertJIntArrayToVector(JNIEnv* env, jintArray inputs) { return outputs; } -bool isByteBuffer(jint data_type) { return data_type == kByteBufferValue; } - -TfLiteType resolveDataType(jint data_type) { - switch (data_type) { - case 1: - return kTfLiteFloat32; - case 2: - return kTfLiteInt32; - case 3: - return kTfLiteUInt8; - case 4: - return kTfLiteInt64; - default: - return kTfLiteNoType; - } -} int getDataType(TfLiteType data_type) { switch (data_type) { @@ -108,64 +89,6 @@ void printDims(char* buffer, int max_size, int* dims, int num_dims) { } } -TfLiteStatus checkInputs(JNIEnv* env, tflite::Interpreter* interpreter, - const int input_size, jintArray data_types, - jintArray nums_of_bytes, jobjectArray values, - jobjectArray sizes) { - if (input_size != interpreter->inputs().size()) { - throwException(env, kIllegalArgumentException, - "Input error: Expected num of inputs is %d but got %d", - interpreter->inputs().size(), input_size); - return kTfLiteError; - } - if (input_size != env->GetArrayLength(data_types) || - input_size != env->GetArrayLength(nums_of_bytes) || - input_size != env->GetArrayLength(values)) { - throwException(env, kIllegalArgumentException, - "Internal error: Arrays in arguments should be of the same " - "length, but got %d sizes, %d data_types, %d nums_of_bytes, " - "and %d values", - input_size, env->GetArrayLength(data_types), - env->GetArrayLength(nums_of_bytes), - env->GetArrayLength(values)); - return kTfLiteError; - } - for (int i = 0; i < input_size; ++i) { - int input_idx = interpreter->inputs()[i]; - TfLiteTensor* target = interpreter->tensor(input_idx); - jintArray dims = - static_cast(env->GetObjectArrayElement(sizes, i)); - int num_dims = static_cast(env->GetArrayLength(dims)); - if (target->dims->size != num_dims) { - throwException(env, kIllegalArgumentException, - "Input error: %d-th input should have %d dimensions, but " - "found %d dimensions", - i, target->dims->size, num_dims); - return kTfLiteError; - } - jint* ptr = env->GetIntArrayElements(dims, nullptr); - for (int j = 1; j < num_dims; ++j) { - if (target->dims->data[j] != ptr[j]) { - std::unique_ptr expected_dims(new char[kBufferSize]); - std::unique_ptr obtained_dims(new char[kBufferSize]); - printDims(expected_dims.get(), kBufferSize, target->dims->data, - num_dims); - printDims(obtained_dims.get(), kBufferSize, ptr, num_dims); - throwException(env, kIllegalArgumentException, - "Input error: %d-th input dimension should be [%s], but " - "found [%s]", - i, expected_dims.get(), obtained_dims.get()); - env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT); - return kTfLiteError; - } - } - env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT); - env->DeleteLocalRef(dims); - if (env->ExceptionCheck()) return kTfLiteError; - } - return kTfLiteOk; -} - // Checks whether there is any difference between dimensions of a tensor and a // given dimensions. Returns true if there is difference, else false. bool areDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) { @@ -188,74 +111,6 @@ bool areDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) { return false; } -bool areInputDimensionsTheSame(JNIEnv* env, tflite::Interpreter* interpreter, - int input_size, jobjectArray sizes) { - if (interpreter->inputs().size() != input_size) { - return false; - } - for (int i = 0; i < input_size; ++i) { - int input_idx = interpreter->inputs()[i]; - jintArray dims = - static_cast(env->GetObjectArrayElement(sizes, i)); - TfLiteTensor* target = interpreter->tensor(input_idx); - if (areDimsDifferent(env, target, dims)) return false; - env->DeleteLocalRef(dims); - if (env->ExceptionCheck()) return false; - } - return true; -} - -TfLiteStatus resizeInputs(JNIEnv* env, tflite::Interpreter* interpreter, - int input_size, jobjectArray sizes) { - for (int i = 0; i < input_size; ++i) { - int input_idx = interpreter->inputs()[i]; - jintArray dims = - static_cast(env->GetObjectArrayElement(sizes, i)); - TfLiteStatus status = interpreter->ResizeInputTensor( - input_idx, convertJIntArrayToVector(env, dims)); - if (status != kTfLiteOk) { - return status; - } - env->DeleteLocalRef(dims); - if (env->ExceptionCheck()) return kTfLiteError; - } - return kTfLiteOk; -} - -TfLiteStatus setInputs(JNIEnv* env, tflite::Interpreter* interpreter, - int input_size, jintArray data_types, - jintArray nums_of_bytes, jobjectArray values) { - jint* data_type = env->GetIntArrayElements(data_types, nullptr); - jint* num_bytes = env->GetIntArrayElements(nums_of_bytes, nullptr); - for (int i = 0; i < input_size; ++i) { - int input_idx = interpreter->inputs()[i]; - TfLiteTensor* target = interpreter->tensor(input_idx); - jobject value = env->GetObjectArrayElement(values, i); - bool is_byte_buffer = isByteBuffer(data_type[i]); - if (is_byte_buffer) { - writeByteBuffer(env, value, &(target->data.raw), - static_cast(num_bytes[i])); - } else { - TfLiteType type = resolveDataType(data_type[i]); - if (type != target->type) { - throwException(env, kIllegalArgumentException, - "Input error: DataType (%d) of input data does not " - "match with the DataType (%d) of model inputs.", - type, target->type); - return kTfLiteError; - } - writeMultiDimensionalArray(env, value, target->type, target->dims->size, - &(target->data.raw), - static_cast(num_bytes[i])); - } - env->DeleteLocalRef(value); - if (env->ExceptionCheck()) return kTfLiteError; - } - env->ReleaseIntArrayElements(data_types, data_type, JNI_ABORT); - env->ReleaseIntArrayElements(nums_of_bytes, num_bytes, JNI_ABORT); - return kTfLiteOk; -} - // TODO(yichengfan): evaluate the benefit to use tflite verifier. bool VerifyModel(const void* buf, size_t len) { flatbuffers::Verifier verifier(static_cast(buf), len); @@ -287,6 +142,64 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env, return names; } +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors( + JNIEnv* env, jclass clazz, jlong handle, jlong error_handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return; + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return; + + if (interpreter->AllocateTensors() != kTfLiteOk) { + throwException( + env, kIllegalStateException, + "Internal error: Unexpected failure when preparing tensor allocations:" + " %s", + error_reporter->CachedErrorMessage()); + } +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env, + jclass clazz, + jlong handle, + jint index) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 0; + return reinterpret_cast( + interpreter->tensor(interpreter->inputs()[index])); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env, + jclass clazz, + jlong handle, + jint index) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 0; + return reinterpret_cast( + interpreter->tensor(interpreter->outputs()[index])); +} + +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env, + jclass clazz, + jlong handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 0; + return static_cast(interpreter->inputs().size()); +} + +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env, + jclass clazz, + jlong handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 0; + return static_cast(interpreter->outputs().size()); +} + JNIEXPORT jobjectArray JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env, jclass clazz, @@ -424,124 +337,32 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( // allocates memory status = interpreter->AllocateTensors(); if (status != kTfLiteOk) { - throwException(env, kNullPointerException, - "Internal error: Cannot allocate memory for the interpreter:" - " %s", - error_reporter->CachedErrorMessage()); + throwException( + env, kIllegalStateException, + "Internal error: Unexpected failure when preparing tensor allocations:" + " %s", + error_reporter->CachedErrorMessage()); return 0; } return reinterpret_cast(interpreter.release()); } // Sets inputs, runs inference, and returns outputs as long handles. -JNIEXPORT jlongArray JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_run( - JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, - jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, - jobjectArray values, jobject wrapper, jboolean memory_allocated) { +JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) { tflite::Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle); - if (interpreter == nullptr) return nullptr; + if (interpreter == nullptr) return; BufferErrorReporter* error_reporter = convertLongToErrorReporter(env, error_handle); - if (error_reporter == nullptr) return nullptr; - const int input_size = env->GetArrayLength(sizes); - // validates inputs - TfLiteStatus status = checkInputs(env, interpreter, input_size, data_types, - nums_of_bytes, values, sizes); - if (status != kTfLiteOk) return nullptr; - if (!memory_allocated || - !areInputDimensionsTheSame(env, interpreter, input_size, sizes)) { - // resizes inputs - status = resizeInputs(env, interpreter, input_size, sizes); - if (status != kTfLiteOk) { - throwException(env, kNullPointerException, - "Internal error: Can not resize the input: %s", - error_reporter->CachedErrorMessage()); - return nullptr; - } - // allocates memory - status = interpreter->AllocateTensors(); - if (status != kTfLiteOk) { - throwException(env, kNullPointerException, - "Internal error: Can not allocate memory for the given " - "inputs: %s", - error_reporter->CachedErrorMessage()); - return nullptr; - } - } - // sets inputs - status = setInputs(env, interpreter, input_size, data_types, nums_of_bytes, - values); - if (status != kTfLiteOk) return nullptr; - timespec beforeInference = ::tflite::getCurrentTime(); - // runs inference + if (error_reporter == nullptr) return; + if (interpreter->Invoke() != kTfLiteOk) { throwException(env, kIllegalArgumentException, "Internal error: Failed to run on the given Interpreter: %s", error_reporter->CachedErrorMessage()); - return nullptr; + return; } - timespec afterInference = ::tflite::getCurrentTime(); - jclass wrapper_clazz = env->GetObjectClass(wrapper); - jfieldID fid = - env->GetFieldID(wrapper_clazz, "inferenceDurationNanoseconds", "J"); - if (env->ExceptionCheck()) { - env->ExceptionClear(); - } else if (fid != nullptr) { - env->SetLongField( - wrapper, fid, - ::tflite::timespec_diff_nanoseconds(&beforeInference, &afterInference)); - } - // returns outputs - const std::vector& results = interpreter->outputs(); - if (results.empty()) { - throwException( - env, kIllegalArgumentException, - "Internal error: The Interpreter does not have any outputs."); - return nullptr; - } - jlongArray outputs = env->NewLongArray(results.size()); - size_t size = results.size(); - for (int i = 0; i < size; ++i) { - TfLiteTensor* source = interpreter->tensor(results[i]); - jlong output = reinterpret_cast(source); - env->SetLongArrayRegion(outputs, i, 1, &output); - } - return outputs; -} - -JNIEXPORT jintArray JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( - JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes) { - tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); - if (interpreter == nullptr) return nullptr; - const int idx = static_cast(input_idx); - if (input_idx < 0 || input_idx >= interpreter->inputs().size()) { - throwException(env, kIllegalArgumentException, - "Input error: Out of range: Failed to get %d-th input out of" - " %d inputs", - input_idx, interpreter->inputs().size()); - return nullptr; - } - TfLiteTensor* target = interpreter->tensor(interpreter->inputs()[idx]); - int size = target->dims->size; - if (num_bytes >= 0) { // verifies num of bytes matches if num_bytes if valid. - int expected_num_bytes = elementByteSize(target->type); - for (int i = 0; i < size; ++i) { - expected_num_bytes *= target->dims->data[i]; - } - if (num_bytes != expected_num_bytes) { - throwException(env, kIllegalArgumentException, - "Input error: Failed to get input dimensions. %d-th input " - "should have %d bytes, but found %d bytes.", - idx, expected_num_bytes, num_bytes); - return nullptr; - } - } - jintArray outputs = env->NewIntArray(size); - env->SetIntArrayRegion(outputs, 0, size, &(target->dims->data[0])); - return outputs; } JNIEXPORT jint JNICALL @@ -561,6 +382,38 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType( return static_cast(type); } +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationZeroPoint( + JNIEnv* env, jclass clazz, jlong handle, jint output_idx) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 0; + const int idx = static_cast(output_idx); + if (output_idx < 0 || output_idx >= interpreter->outputs().size()) { + throwException(env, kIllegalArgumentException, + "Failed to get %d-th output out of %d outputs", output_idx, + interpreter->outputs().size()); + return 0; + } + TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]); + return static_cast(target->params.zero_point); +} + +JNIEXPORT jfloat JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationScale( + JNIEnv* env, jclass clazz, jlong handle, jint output_idx) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 1.0f; + const int idx = static_cast(output_idx); + if (output_idx < 0 || output_idx >= interpreter->outputs().size()) { + throwException(env, kIllegalArgumentException, + "Failed to get %d-th output out of %d outputs", output_idx, + interpreter->outputs().size()); + return 1.0f; + } + TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]); + return static_cast(target->params.scale); +} + JNIEXPORT jboolean JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h index eaa765cb343e9764bd0ef018d636a76f4b8a13e4..618fba480e4a1c4a1ff8531cb3fbc29fcb8191d8 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -29,15 +29,63 @@ limitations under the License. namespace tflite { // This is to be provided at link-time by a library. extern std::unique_ptr CreateOpResolver(); -extern timespec getCurrentTime(); -extern jlong timespec_diff_nanoseconds(struct timespec* start, - struct timespec* stop); } // namespace tflite #ifdef __cplusplus extern "C" { #endif // __cplusplus +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: allocateTensors + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors( + JNIEnv* env, jclass clazz, jlong handle, jlong error_handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: getInputTensor + * Signature: (JI)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env, + jclass clazz, + jlong handle, + jint index); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: getOutputTensor + * Signature: (JI)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env, + jclass clazz, + jlong handle, + jint index); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: getInputCount + * Signature: (J)I + */ +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: getOutputCount + * Signature: (J)I + */ +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env, + jclass clazz, + jlong handle); + /* * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: @@ -118,38 +166,43 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( /* * Class: org_tensorflow_lite_NativeInterpreterWrapper - * Method: - * Signature: - * (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;Ljava/lang/Object;Z)[J + * Method: run + * Signature: (JJ)V */ -JNIEXPORT jlongArray JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_run( - JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, - jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, - jobjectArray values, jobject wrapper, jboolean memory_allocated); +JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle); /* * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: - * Signature: (JII)[I + * Signature: (JI)I * - * Gets input dimensions. If num_bytes is non-negative, it will check whether - * num_bytes matches num of bytes required by the input, and return null and - * throw IllegalArgumentException if not. + * Gets output dimensions. */ -JNIEXPORT jintArray JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( - JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes); +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType( + JNIEnv* env, jclass clazz, jlong handle, jint output_idx); /* * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: * Signature: (JI)I * - * Gets output dimensions. + * Gets output quantization zero point. */ JNIEXPORT jint JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType( +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationZeroPoint( + JNIEnv* env, jclass clazz, jlong handle, jint output_idx); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JI)F + * + * Gets output quantization scale. + */ +JNIEXPORT jfloat JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationScale( JNIEnv* env, jclass clazz, jlong handle, jint output_idx); /* diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc index 9e9387da86ebde7d711a7ce967461e370c95bc3e..7ff96a3172dcf020b34fcbe7491c9022fc7f51de 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc @@ -29,6 +29,35 @@ TfLiteTensor* convertLongToTensor(JNIEnv* env, jlong handle) { return reinterpret_cast(handle); } +size_t elementByteSize(TfLiteType data_type) { + // The code in this file makes the assumption that the + // TensorFlow TF_DataTypes and the Java primitive types + // have the same byte sizes. Validate that: + switch (data_type) { + case kTfLiteFloat32: + static_assert(sizeof(jfloat) == 4, + "Interal error: Java float not compatible with " + "kTfLiteFloat"); + return 4; + case kTfLiteInt32: + static_assert(sizeof(jint) == 4, + "Interal error: Java int not compatible with kTfLiteInt"); + return 4; + case kTfLiteUInt8: + static_assert(sizeof(jbyte) == 1, + "Interal error: Java byte not compatible with " + "kTfLiteUInt8"); + return 1; + case kTfLiteInt64: + static_assert(sizeof(jlong) == 8, + "Interal error: Java long not compatible with " + "kTfLiteInt64"); + return 8; + default: + return 0; + } +} + size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type, void* dst, size_t dst_size) { jarray array = static_cast(object); @@ -141,48 +170,6 @@ size_t readMultiDimensionalArray(JNIEnv* env, TfLiteType data_type, char* src, } } -} // namespace - -size_t elementByteSize(TfLiteType data_type) { - // The code in this file makes the assumption that the - // TensorFlow TF_DataTypes and the Java primitive types - // have the same byte sizes. Validate that: - switch (data_type) { - case kTfLiteFloat32: - static_assert(sizeof(jfloat) == 4, - "Interal error: Java float not compatible with " - "kTfLiteFloat"); - return 4; - case kTfLiteInt32: - static_assert(sizeof(jint) == 4, - "Interal error: Java int not compatible with kTfLiteInt"); - return 4; - case kTfLiteUInt8: - static_assert(sizeof(jbyte) == 1, - "Interal error: Java byte not compatible with " - "kTfLiteUInt8"); - return 1; - case kTfLiteInt64: - static_assert(sizeof(jlong) == 8, - "Interal error: Java long not compatible with " - "kTfLiteInt64"); - return 8; - default: - return 0; - } -} - -size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size) { - char* buf = static_cast(env->GetDirectBufferAddress(object)); - if (!buf) { - throwException(env, kIllegalArgumentException, - "Input ByteBuffer is not a direct buffer"); - return 0; - } - *dst = buf; - return dst_size; -} - size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, int dims_left, char** dst, int dst_size) { if (dims_left <= 1) { @@ -203,6 +190,37 @@ size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, } } +} // namespace + +JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env, + jclass clazz, + jlong handle) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return nullptr; + if (tensor->data.raw == nullptr) { + throwException(env, kIllegalArgumentException, + "Internal error: Tensor hasn't been allocated."); + return nullptr; + } + return env->NewDirectByteBuffer(static_cast(tensor->data.raw), + static_cast(tensor->bytes)); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer( + JNIEnv* env, jclass clazz, jlong handle, jobject src) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return; + + char* src_data_raw = static_cast(env->GetDirectBufferAddress(src)); + if (!src_data_raw) { + throwException(env, kIllegalArgumentException, + "Input ByteBuffer is not a direct buffer"); + return; + } + + tensor->data.raw = src_data_raw; +} + JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, jclass clazz, @@ -220,6 +238,27 @@ Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, num_dims, static_cast(value)); } +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env, + jclass clazz, + jlong handle, + jobject src) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return; + if (tensor->data.raw == nullptr) { + throwException(env, kIllegalArgumentException, + "Internal error: Target Tensor hasn't been allocated."); + return; + } + if (tensor->dims->size == 0) { + throwException(env, kIllegalArgumentException, + "Internal error: Cannot copy empty/scalar Tensors."); + return; + } + writeMultiDimensionalArray(env, src, tensor->type, tensor->dims->size, + &tensor->data.raw, tensor->bytes); +} + JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, jclass clazz, jlong handle) { @@ -237,3 +276,11 @@ Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) { env->SetIntArrayRegion(result, 0, num_dims, tensor->dims->data); return result; } + +JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env, + jclass clazz, + jlong handle) { + const TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return 0; + return static_cast(tensor->bytes); +} diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h index 3a4910dcc3a719fbb9f365dae693423de768349c..06e2546af8400de117ed6923a1d1bd67bcb998e2 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h @@ -24,8 +24,25 @@ extern "C" { #endif // __cplusplus /* - * Class: org_tensorflow_lite_TfLiteTensor - * Method: + * Class: org_tensorflow_lite_Tensor + * Method: buffer + * Signature: (J)Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_Tensor + * Method: writeDirectBuffer + * Signature: (JLjava/nio/ByteBuffer;) + */ +JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer( + JNIEnv* env, jclass clazz, jlong handle, jobject src); + +/* + * Class: org_tensorflow_lite_Tensor + * Method: dtype * Signature: (J)I */ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, @@ -33,8 +50,8 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, jlong handle); /* - * Class: org_tensorflow_lite_TfLiteTensor - * Method: + * Class: org_tensorflow_lite_Tensor + * Method: shape * Signature: (J)[I */ JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, @@ -42,31 +59,35 @@ JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jlong handle); /* - * Class: org_tensorflow_lite_TfLiteTensor - * Method: + * Class: org_tensorflow_lite_Tensor + * Method: numBytes + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_Tensor + * Method: readMultiDimensionalArray * Signature: (JLjava/lang/Object;) */ JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, jclass clazz, jlong handle, - jobject value); + jobject dst); /* - * Finds the size of each data type. - */ -size_t elementByteSize(TfLiteType data_type); - -/* - * Writes data of a ByteBuffer into dest. - */ -size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size); - -/* - * Writes a multi-dimensional array into dest. + * Class: org_tensorflow_lite_Tensor + * Method: writeMultidimensionalArray + * Signature: (JLjava/lang/Object;) */ -size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, - int dims_left, char** dst, int dst_size); +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env, + jclass clazz, + jlong handle, + jobject src); #ifdef __cplusplus } // extern "C" diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index 82007a6ab5be3492495125b1c20ed155907ae5a0..d66a73db94f06776fe2a7310ed0837941aba87c4 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -164,6 +164,24 @@ public final class InterpreterTest { interpreter.close(); } + @Test + public void testRunWithByteBufferOutput() { + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + ByteBuffer parsedOutput = + ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); + try (Interpreter interpreter = new Interpreter(MODEL_FILE)) { + interpreter.run(fourD, parsedOutput); + } + float[] outputOneD = { + parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8) + }; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + } + @Test public void testMobilenetRun() { // Create a gray image. @@ -203,7 +221,9 @@ public final class InterpreterTest { assertThat(e) .hasMessageThat() .contains( - "DataType (2) of input data does not match with the DataType (1) of model inputs."); + "Cannot convert between a TensorFlowLite tensor with type " + + "FLOAT32 and a Java object of type [[[[I (which is compatible with the" + + " TensorFlowLite type INT32)"); } interpreter.close(); } @@ -223,8 +243,8 @@ public final class InterpreterTest { assertThat(e) .hasMessageThat() .contains( - "Cannot convert an TensorFlowLite tensor with type " - + "FLOAT32 to a Java object of type [[[[I (which is compatible with the" + "Cannot convert between a TensorFlowLite tensor with type " + + "FLOAT32 and a Java object of type [[[[I (which is compatible with the" + " TensorFlowLite type INT32)"); } interpreter.close(); @@ -311,4 +331,11 @@ public final class InterpreterTest { interpreter.close(); fileChannel.close(); } + + @Test + public void testRedundantClose() throws Exception { + Interpreter interpreter = new Interpreter(MODEL_FILE); + interpreter.close(); + interpreter.close(); + } } diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 7c00d3196fd001a288d77d4e01f0b30978d72afe..9c4a5acd797ec3476f44fb203901c9ba0429ab26 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -20,6 +20,8 @@ import static org.junit.Assert.fail; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.HashMap; +import java.util.Map; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -41,6 +43,9 @@ public final class NativeInterpreterWrapperTest { private static final String BYTE_MODEL_PATH = "tensorflow/contrib/lite/java/src/testdata/uint8.bin"; + private static final String QUANTIZED_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/quantized.bin"; + private static final String INVALID_MODEL_PATH = "tensorflow/contrib/lite/java/src/testdata/invalid_model.bin"; @@ -98,16 +103,37 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); float[][][][] parsedOutputs = new float[2][8][8][3]; - outputs[0].copyTo(parsedOutputs); + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); float[] outputOneD = parsedOutputs[0][0][0]; float[] expected = {3.69f, -19.62f, 23.43f}; assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); wrapper.close(); } + @Test + public void testRunWithBufferOutput() { + try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH)) { + float[] oneD = {1.23f, -6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + ByteBuffer parsedOutput = + ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutput); + wrapper.run(inputs, outputs); + float[] outputOneD = { + parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8) + }; + float[] expected = {3.69f, -19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + } + } + @Test public void testRunWithInputsOfSameDims() { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); @@ -116,17 +142,16 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); float[][][][] parsedOutputs = new float[2][8][8][3]; - outputs[0].copyTo(parsedOutputs); + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); float[] outputOneD = parsedOutputs[0][0][0]; float[] expected = {3.69f, -19.62f, 23.43f}; assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); - outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); parsedOutputs = new float[2][8][8][3]; - outputs[0].copyTo(parsedOutputs); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); outputOneD = parsedOutputs[0][0][0]; assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); wrapper.close(); @@ -140,10 +165,10 @@ public final class NativeInterpreterWrapperTest { int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; int[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); int[][][][] parsedOutputs = new int[2][4][4][12]; - outputs[0].copyTo(parsedOutputs); + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); int[] outputOneD = parsedOutputs[0][0][0]; int[] expected = {3, 7, -4, 3, 7, -4, 3, 7, -4, 3, 7, -4}; assertThat(outputOneD).isEqualTo(expected); @@ -158,10 +183,10 @@ public final class NativeInterpreterWrapperTest { long[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; long[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); long[][][][] parsedOutputs = new long[2][4][4][12]; - outputs[0].copyTo(parsedOutputs); + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); long[] outputOneD = parsedOutputs[0][0][0]; long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L}; @@ -179,10 +204,10 @@ public final class NativeInterpreterWrapperTest { Object[] inputs = {fourD}; int[] inputDims = {2, 8, 8, 3}; wrapper.resizeInput(0, inputDims); - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); byte[][][][] parsedOutputs = new byte[2][4][4][12]; - outputs[0].copyTo(parsedOutputs); + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); byte[] outputOneD = parsedOutputs[0][0][0]; byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0}; @@ -205,13 +230,14 @@ public final class NativeInterpreterWrapperTest { } } } + bbuf.rewind(); Object[] inputs = {bbuf}; int[] inputDims = {2, 8, 8, 3}; wrapper.resizeInput(0, inputDims); - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); byte[][][][] parsedOutputs = new byte[2][4][4][12]; - outputs[0].copyTo(parsedOutputs); + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); byte[] outputOneD = parsedOutputs[0][0][0]; byte[] expected = { (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, @@ -237,21 +263,22 @@ public final class NativeInterpreterWrapperTest { } } Object[] inputs = {bbuf}; + float[][][][] parsedOutputs = new float[4][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { assertThat(e) .hasMessageThat() .contains( - "Failed to get input dimensions. 0-th input should have 768 bytes, but found 3072 bytes"); + "Cannot convert between a TensorFlowLite buffer with 768 bytes and a " + + "ByteBuffer with 3072 bytes."); } int[] inputDims = {4, 8, 8, 3}; wrapper.resizeInput(0, inputDims); - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); - float[][][][] parsedOutputs = new float[4][8][8][3]; - outputs[0].copyTo(parsedOutputs); + wrapper.run(inputs, outputs); float[] outputOneD = parsedOutputs[0][0][0]; float[] expected = {3.69f, -19.62f, 23.43f}; assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); @@ -264,14 +291,18 @@ public final class NativeInterpreterWrapperTest { ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 7 * 8 * 3); bbuf.order(ByteOrder.nativeOrder()); Object[] inputs = {bbuf}; + Map outputs = new HashMap<>(); + ByteBuffer parsedOutput = ByteBuffer.allocateDirect(2 * 7 * 8 * 3); + outputs.put(0, parsedOutput); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { assertThat(e) .hasMessageThat() .contains( - "Failed to get input dimensions. 0-th input should have 192 bytes, but found 336 bytes."); + "Cannot convert between a TensorFlowLite buffer with 192 bytes and a " + + "ByteBuffer with 336 bytes."); } wrapper.close(); } @@ -284,14 +315,18 @@ public final class NativeInterpreterWrapperTest { int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; int[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; + int[][][][] parsedOutputs = new int[2][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { assertThat(e) .hasMessageThat() .contains( - "DataType (2) of input data does not match with the DataType (1) of model inputs."); + "Cannot convert between a TensorFlowLite tensor with type FLOAT32 and a Java object " + + "of type [[[[I (which is compatible with the TensorFlowLite type INT32)"); } wrapper.close(); } @@ -305,8 +340,11 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { assertThat(e).hasMessageThat().contains("Invalid handle to Interpreter."); @@ -318,7 +356,7 @@ public final class NativeInterpreterWrapperTest { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); try { Object[] inputs = {}; - wrapper.run(inputs); + wrapper.run(inputs, null); fail(); } catch (IllegalArgumentException e) { assertThat(e).hasMessageThat().contains("Inputs should not be null or empty."); @@ -334,11 +372,14 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD, fourD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("Expected num of inputs is 1 but got 2"); + assertThat(e).hasMessageThat().contains("Invalid input Tensor index: 1"); } wrapper.close(); } @@ -350,13 +391,18 @@ public final class NativeInterpreterWrapperTest { float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD}; float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; Object[] inputs = {threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { assertThat(e) .hasMessageThat() - .contains("0-th input should have 4 dimensions, but found 3 dimensions"); + .contains( + "Cannot copy between a TensorFlowLite tensor with shape [8, 7, 3] and a " + + "Java object with shape [2, 8, 8, 3]."); } wrapper.close(); } @@ -369,91 +415,22 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { assertThat(e) .hasMessageThat() - .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]"); + .contains( + "Cannot copy between a TensorFlowLite tensor with shape [2, 8, 7, 3] and a " + + "Java object with shape [2, 8, 8, 3]."); } wrapper.close(); } - @Test - public void testNumElements() { - int[] shape = {2, 3, 4}; - int num = NativeInterpreterWrapper.numElements(shape); - assertThat(num).isEqualTo(24); - shape = null; - num = NativeInterpreterWrapper.numElements(shape); - assertThat(num).isEqualTo(0); - } - - @Test - public void testIsNonEmtpyArray() { - assertThat(NativeInterpreterWrapper.isNonEmptyArray(null)).isFalse(); - assertThat(NativeInterpreterWrapper.isNonEmptyArray(3.2)).isFalse(); - int[] emptyArray = {}; - assertThat(NativeInterpreterWrapper.isNonEmptyArray(emptyArray)).isFalse(); - int[] validArray = {9, 5, 2, 1}; - assertThat(NativeInterpreterWrapper.isNonEmptyArray(validArray)).isTrue(); - } - - @Test - public void testDataTypeOf() { - float[] testEmtpyArray = {}; - DataType dataType = NativeInterpreterWrapper.dataTypeOf(testEmtpyArray); - assertThat(dataType).isEqualTo(DataType.FLOAT32); - float[] testFloatArray = {0.783f, 0.251f}; - dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray); - assertThat(dataType).isEqualTo(DataType.FLOAT32); - float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray}; - dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray); - assertThat(dataType).isEqualTo(DataType.FLOAT32); - try { - double[] testDoubleArray = {0.783, 0.251}; - NativeInterpreterWrapper.dataTypeOf(testDoubleArray); - fail(); - } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("cannot resolve DataType of"); - } - try { - Float[] testBoxedArray = {0.783f, 0.251f}; - NativeInterpreterWrapper.dataTypeOf(testBoxedArray); - fail(); - } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;"); - } - } - - @Test - public void testNumDimensions() { - int scalar = 1; - assertThat(NativeInterpreterWrapper.numDimensions(scalar)).isEqualTo(0); - int[][] array = {{2, 4}, {1, 9}}; - assertThat(NativeInterpreterWrapper.numDimensions(array)).isEqualTo(2); - try { - int[] emptyArray = {}; - NativeInterpreterWrapper.numDimensions(emptyArray); - fail(); - } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("Array lengths cannot be 0."); - } - } - - @Test - public void testFillShape() { - int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}}; - int num = NativeInterpreterWrapper.numDimensions(array); - int[] shape = new int[num]; - NativeInterpreterWrapper.fillShape(array, 0, shape); - assertThat(num).isEqualTo(3); - assertThat(shape[0]).isEqualTo(2); - assertThat(shape[1]).isEqualTo(3); - assertThat(shape[2]).isEqualTo(1); - } - @Test public void testGetInferenceLatency() { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); @@ -462,8 +439,10 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isGreaterThan(0L); wrapper.close(); } @@ -483,13 +462,14 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { - assertThat(e) - .hasMessageThat() - .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]"); + // Expected. } assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isNull(); wrapper.close(); @@ -499,41 +479,19 @@ public final class NativeInterpreterWrapperTest { public void testGetInputDims() { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); int[] expectedDims = {1, 8, 8, 3}; - assertThat(wrapper.getInputDims(0)).isEqualTo(expectedDims); + assertThat(wrapper.getInputTensor(0).shape()).isEqualTo(expectedDims); wrapper.close(); } @Test - public void testGetInputDimsOutOfRange() { - NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); - try { - wrapper.getInputDims(-1); - fail(); - } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("Out of range"); + public void testGetOutputQuantizationParams() { + try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH)) { + assertThat(wrapper.getOutputQuantizationZeroPoint(0)).isEqualTo(0); + assertThat(wrapper.getOutputQuantizationScale(0)).isWithin(1e-6f).of(0.0f); } - try { - wrapper.getInputDims(1); - fail(); - } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("Out of range"); + try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(QUANTIZED_MODEL_PATH)) { + assertThat(wrapper.getOutputQuantizationZeroPoint(0)).isEqualTo(127); + assertThat(wrapper.getOutputQuantizationScale(0)).isWithin(1e-6f).of(0.25f); } - wrapper.close(); - } - - @Test - public void testGetOutputDataType() { - NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); - assertThat(wrapper.getOutputDataType(0)).contains("float"); - wrapper.close(); - wrapper = new NativeInterpreterWrapper(LONG_MODEL_PATH); - assertThat(wrapper.getOutputDataType(0)).contains("long"); - wrapper.close(); - wrapper = new NativeInterpreterWrapper(INT_MODEL_PATH); - assertThat(wrapper.getOutputDataType(0)).contains("int"); - wrapper.close(); - wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH); - assertThat(wrapper.getOutputDataType(0)).contains("byte"); - wrapper.close(); } } diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java index 94b6632bb8dd7117bf4074da1939bd23ce732efd..71ef04494357e8b951cbbbd2c68385b17c472736 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java @@ -18,6 +18,10 @@ package org.tensorflow.lite; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.HashMap; +import java.util.Map; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -32,7 +36,7 @@ public final class TensorTest { "tensorflow/contrib/lite/java/src/testdata/add.bin"; private NativeInterpreterWrapper wrapper; - private long nativeHandle; + private Tensor tensor; @Before public void setUp() { @@ -42,8 +46,10 @@ public final class TensorTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - nativeHandle = outputs[0].nativeHandle; + Map outputs = new HashMap<>(); + outputs.put(0, new float[2][8][8][3]); + wrapper.run(inputs, outputs); + tensor = wrapper.getOutputTensor(0); } @After @@ -52,17 +58,16 @@ public final class TensorTest { } @Test - public void testFromHandle() throws Exception { - Tensor tensor = Tensor.fromHandle(nativeHandle); + public void testBasic() throws Exception { assertThat(tensor).isNotNull(); int[] expectedShape = {2, 8, 8, 3}; - assertThat(tensor.shapeCopy).isEqualTo(expectedShape); - assertThat(tensor.dtype).isEqualTo(DataType.FLOAT32); + assertThat(tensor.shape()).isEqualTo(expectedShape); + assertThat(tensor.dataType()).isEqualTo(DataType.FLOAT32); + assertThat(tensor.numBytes()).isEqualTo(2 * 8 * 8 * 3 * 4); } @Test public void testCopyTo() { - Tensor tensor = Tensor.fromHandle(nativeHandle); float[][][][] parsedOutputs = new float[2][8][8][3]; tensor.copyTo(parsedOutputs); float[] outputOneD = parsedOutputs[0][0][0]; @@ -70,9 +75,32 @@ public final class TensorTest { assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); } + @Test + public void testCopyToByteBuffer() { + ByteBuffer parsedOutput = + ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); + tensor.copyTo(parsedOutput); + assertThat(parsedOutput.position()).isEqualTo(2 * 8 * 8 * 3 * 4); + float[] outputOneD = { + parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8) + }; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + } + + @Test + public void testCopyToInvalidByteBuffer() { + ByteBuffer parsedOutput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder()); + try { + tensor.copyTo(parsedOutput); + fail(); + } catch (IllegalArgumentException e) { + // Expected. + } + } + @Test public void testCopyToWrongType() { - Tensor tensor = Tensor.fromHandle(nativeHandle); int[][][][] parsedOutputs = new int[2][8][8][3]; try { tensor.copyTo(parsedOutputs); @@ -81,15 +109,13 @@ public final class TensorTest { assertThat(e) .hasMessageThat() .contains( - "Cannot convert an TensorFlowLite tensor with type " - + "FLOAT32 to a Java object of type [[[[I (which is compatible with the TensorFlowLite " - + "type INT32)"); + "Cannot convert between a TensorFlowLite tensor with type FLOAT32 and a Java object " + + "of type [[[[I (which is compatible with the TensorFlowLite type INT32)"); } } @Test public void testCopyToWrongShape() { - Tensor tensor = Tensor.fromHandle(nativeHandle); float[][][][] parsedOutputs = new float[1][8][8][3]; try { tensor.copyTo(parsedOutputs); @@ -98,8 +124,104 @@ public final class TensorTest { assertThat(e) .hasMessageThat() .contains( - "Shape of output target [1, 8, 8, 3] does not match " - + "with the shape of the Tensor [2, 8, 8, 3]."); + "Cannot copy between a TensorFlowLite tensor with shape [2, 8, 8, 3] " + + "and a Java object with shape [1, 8, 8, 3]."); + } + } + + @Test + public void testSetTo() { + float[][][][] input = new float[2][8][8][3]; + float[][][][] output = new float[2][8][8][3]; + ByteBuffer inputByteBuffer = + ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); + + input[0][0][0][0] = 2.0f; + tensor.setTo(input); + tensor.copyTo(output); + assertThat(output[0][0][0][0]).isEqualTo(2.0f); + + inputByteBuffer.putFloat(0, 3.0f); + tensor.setTo(inputByteBuffer); + tensor.copyTo(output); + assertThat(output[0][0][0][0]).isEqualTo(3.0f); + } + + @Test + public void testSetToInvalidByteBuffer() { + ByteBuffer input = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder()); + try { + tensor.setTo(input); + fail(); + } catch (IllegalArgumentException e) { + // Success. + } + } + + @Test + public void testGetInputShapeIfDifferent() { + ByteBuffer bytBufferInput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder()); + assertThat(tensor.getInputShapeIfDifferent(bytBufferInput)).isNull(); + + float[][][][] sameShapeInput = new float[2][8][8][3]; + assertThat(tensor.getInputShapeIfDifferent(sameShapeInput)).isNull(); + + float[][][][] differentShapeInput = new float[1][8][8][3]; + assertThat(tensor.getInputShapeIfDifferent(differentShapeInput)) + .isEqualTo(new int[] {1, 8, 8, 3}); + } + + @Test + public void testDataTypeOf() { + float[] testEmptyArray = {}; + DataType dataType = Tensor.dataTypeOf(testEmptyArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + float[] testFloatArray = {0.783f, 0.251f}; + dataType = Tensor.dataTypeOf(testFloatArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray}; + dataType = Tensor.dataTypeOf(testFloatArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + try { + double[] testDoubleArray = {0.783, 0.251}; + Tensor.dataTypeOf(testDoubleArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of"); } + try { + Float[] testBoxedArray = {0.783f, 0.251f}; + Tensor.dataTypeOf(testBoxedArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;"); + } + } + + @Test + public void testNumDimensions() { + int scalar = 1; + assertThat(Tensor.numDimensions(scalar)).isEqualTo(0); + int[][] array = {{2, 4}, {1, 9}}; + assertThat(Tensor.numDimensions(array)).isEqualTo(2); + try { + int[] emptyArray = {}; + Tensor.numDimensions(emptyArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("Array lengths cannot be 0."); + } + } + + @Test + public void testFillShape() { + int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}}; + int num = Tensor.numDimensions(array); + int[] shape = new int[num]; + Tensor.fillShape(array, 0, shape); + assertThat(num).isEqualTo(3); + assertThat(shape[0]).isEqualTo(2); + assertThat(shape[1]).isEqualTo(3); + assertThat(shape[2]).isEqualTo(1); } } diff --git a/tensorflow/contrib/lite/java/src/testdata/quantized.bin b/tensorflow/contrib/lite/java/src/testdata/quantized.bin new file mode 100644 index 0000000000000000000000000000000000000000..4062088cdf717e8752490de5c9acff35fd6af54f Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/quantized.bin differ diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD index b524246d436858bbf506809a38cead2897f78d93..af1d99ef41e6413d8ef2c6f478aaa8f9e3931ff8 100644 --- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD +++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD @@ -1,6 +1,8 @@ # Description: # Internal helper function to test TF Lite API. +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java index 3aef0c3bb6cc4748de0e55d31f0215a77320ae69..38b740021bb5037fc8980c75ca6aac2a9cc20c4e 100644 --- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java +++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java @@ -58,13 +58,32 @@ public class TestHelper { */ public static int[] getInputDims(Interpreter interpreter, int index) { if (interpreter != null && interpreter.wrapper != null) { - return interpreter.wrapper.getInputDims(index); + return interpreter.wrapper.getInputTensor(index).shape(); } else { throw new IllegalArgumentException( "Interpreter has not initialized;" + " Failed to get input dimensions."); } } + /** + * Gets the string name of the data type of an input. + * + * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code + * IllegalArgumentException} will be thrown. + * @param index an integer index of the input. If it is invalid, an {@code + * IllegalArgumentException} will be thrown. + * @return string name of the data type. Possible values include "float", "int", "byte", and + * "long". + */ + public static String getInputDataType(Interpreter interpreter, int index) { + if (interpreter != null && interpreter.wrapper != null) { + return interpreter.wrapper.getInputTensor(index).dataType().toStringName(); + } else { + throw new IllegalArgumentException( + "Interpreter has not initialized;" + " Failed to get input data type."); + } + } + /** * Gets the string name of the data type of an output. * @@ -77,7 +96,7 @@ public class TestHelper { */ public static String getOutputDataType(Interpreter interpreter, int index) { if (interpreter != null && interpreter.wrapper != null) { - return interpreter.wrapper.getOutputDataType(index); + return interpreter.wrapper.getOutputTensor(index).dataType().toStringName(); } else { throw new IllegalArgumentException( "Interpreter has not initialized;" + " Failed to get output data type."); diff --git a/tensorflow/contrib/lite/kernels/Android.bp b/tensorflow/contrib/lite/kernels/Android.bp index 0e89cc33f6e7438f38d121cac2639965b35f9e51..5963d7555dd3bbe037b8950f6ec549361f2d8c87 100644 --- a/tensorflow/contrib/lite/kernels/Android.bp +++ b/tensorflow/contrib/lite/kernels/Android.bp @@ -37,7 +37,7 @@ cc_library_static { srcs: [ "activations.cc", "add.cc", - "arg_max.cc", + "arg_min_max.cc", "basic_rnn.cc", "batch_to_space_nd.cc", "bidirectional_sequence_lstm.cc", @@ -54,6 +54,7 @@ cc_library_static { "embedding_lookup_sparse.cc", "exp.cc", "expand_dims.cc", + "fake_quant.cc", "floor.cc", "fully_connected.cc", "gather.cc", @@ -61,18 +62,23 @@ cc_library_static { "kernel_util.cc", "l2norm.cc", "local_response_norm.cc", + "logical.cc", "lsh_projection.cc", "lstm.cc", "maximum_minimum.cc", - "mean.cc", "mul.cc", "neg.cc", + "one_hot.cc", "pad.cc", + "pack.cc", "pooling.cc", + "pow.cc", + "reduce.cc", "register.cc", "reshape.cc", "resize_bilinear.cc", "select.cc", + "shape.cc", "skip_gram.cc", "slice.cc", "space_to_batch_nd.cc", @@ -110,5 +116,6 @@ cc_library_static { "-Wno-sign-compare", "-Wno-unused-local-typedef", "-Wno-unused-variable", + "-Wno-mismatched-tags", ], } diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index cf5d0b4ce9cb3c516c185f31fea12db70a2c3bdb..c5586475ec258849948ff6b960abc846e2ea1b3c 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -8,11 +8,27 @@ load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +# Suppress warnings that are introduced by Eigen Tensor. +EXTRA_EIGEN_COPTS = select({ + "//tensorflow:ios": [ + "-Wno-error=invalid-partial-specialization", + "-Wno-error=reorder", + ], + "//tensorflow:windows": [ + "/DEIGEN_HAS_C99_MATH", + "/DEIGEN_AVOID_STL_ARRAY", + ], + "//conditions:default": ["-Wno-error=reorder"], +}) + tf_cc_test( name = "optional_tensor_test", size = "small", srcs = ["optional_tensor_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -46,11 +62,12 @@ cc_library( hdrs = [ "eigen_support.h", ], - copts = tflite_copts(), + copts = tflite_copts() + EXTRA_EIGEN_COPTS, deps = [ ":op_macros", + "//tensorflow/contrib/lite:arena_planner", "//tensorflow/contrib/lite:context", - "//third_party/eigen3", + "//tensorflow/contrib/lite/kernels/internal:optimized", ], ) @@ -106,7 +123,10 @@ tf_cc_test( name = "kernel_util_test", size = "small", srcs = ["kernel_util_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":kernel_util", "//tensorflow/contrib/lite/testing:util", @@ -118,6 +138,7 @@ tf_cc_test( name = "test_util_test", size = "small", srcs = ["test_util_test.cc"], + tags = ["no_oss"], deps = [ ":test_util", "//tensorflow/contrib/lite/testing:util", @@ -130,7 +151,7 @@ cc_library( srcs = [ "activations.cc", "add.cc", - "arg_max.cc", + "arg_min_max.cc", "audio_spectrogram.cc", "basic_rnn.cc", "batch_to_space_nd.cc", @@ -142,31 +163,38 @@ cc_library( "conv.cc", "depthwise_conv.cc", "dequantize.cc", + "detection_postprocess.cc", "div.cc", "elementwise.cc", "embedding_lookup.cc", "embedding_lookup_sparse.cc", "exp.cc", "expand_dims.cc", + "fake_quant.cc", "floor.cc", "fully_connected.cc", "gather.cc", "hashtable_lookup.cc", "l2norm.cc", "local_response_norm.cc", + "logical.cc", "lsh_projection.cc", "lstm.cc", "maximum_minimum.cc", - "mean.cc", "mfcc.cc", "mul.cc", "neg.cc", + "one_hot.cc", + "pack.cc", "pad.cc", "pooling.cc", + "pow.cc", + "reduce.cc", "register.cc", "reshape.cc", "resize_bilinear.cc", "select.cc", + "shape.cc", "skip_gram.cc", "slice.cc", "space_to_batch_nd.cc", @@ -188,14 +216,7 @@ cc_library( "padding.h", "register.h", ], - # Suppress warnings that are introduced by Eigen Tensor. - copts = tflite_copts() + [ - "-Wno-error=reorder", - ] + select({ - "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"], - "//conditions:default": [ - ], - }), + copts = tflite_copts() + EXTRA_EIGEN_COPTS, deps = [ ":activation_functor", ":eigen_support", @@ -222,7 +243,10 @@ tf_cc_test( name = "audio_spectrogram_test", size = "small", srcs = ["audio_spectrogram_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -236,7 +260,27 @@ tf_cc_test( name = "mfcc_test", size = "small", srcs = ["mfcc_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) + +tf_cc_test( + name = "detection_postprocess_test", + size = "small", + srcs = ["detection_postprocess_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -273,10 +317,11 @@ tf_cc_test( ) tf_cc_test( - name = "arg_max_test", + name = "arg_min_max_test", size = "small", - srcs = ["arg_max_test.cc"], + srcs = ["arg_min_max_test.cc"], tags = [ + "no_oss", "tflite_not_portable_ios", ], deps = [ @@ -291,7 +336,10 @@ tf_cc_test( name = "div_test", size = "small", srcs = ["div_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -304,7 +352,10 @@ tf_cc_test( name = "sub_test", size = "small", srcs = ["sub_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -317,7 +368,10 @@ tf_cc_test( name = "transpose_test", size = "small", srcs = ["transpose_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -332,7 +386,10 @@ tf_cc_test( name = "space_to_batch_nd_test", size = "small", srcs = ["space_to_batch_nd_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -345,7 +402,10 @@ tf_cc_test( name = "batch_to_space_nd_test", size = "small", srcs = ["batch_to_space_nd_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -358,7 +418,10 @@ tf_cc_test( name = "cast_test", size = "small", srcs = ["cast_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -411,7 +474,10 @@ tf_cc_test( name = "dequantize_test", size = "small", srcs = ["dequantize_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -438,7 +504,10 @@ tf_cc_test( name = "bidirectional_sequence_lstm_test", size = "small", srcs = ["bidirectional_sequence_lstm_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -451,7 +520,10 @@ tf_cc_test( name = "floor_test", size = "small", srcs = ["floor_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -464,7 +536,10 @@ tf_cc_test( name = "elementwise_test", size = "small", srcs = ["elementwise_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -477,7 +552,10 @@ tf_cc_test( name = "unidirectional_sequence_lstm_test", size = "small", srcs = ["unidirectional_sequence_lstm_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -491,6 +569,7 @@ tf_cc_test( size = "small", srcs = ["bidirectional_sequence_rnn_test.cc"], tags = [ + "no_oss", "tflite_not_portable", ], deps = [ @@ -505,7 +584,10 @@ tf_cc_test( name = "unidirectional_sequence_rnn_test", size = "small", srcs = ["unidirectional_sequence_rnn_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -531,7 +613,26 @@ tf_cc_test( name = "exp_test", size = "small", srcs = ["exp_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "fake_quant_test", + size = "small", + srcs = ["fake_quant_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -544,7 +645,10 @@ tf_cc_test( name = "maximum_minimum_test", size = "small", srcs = ["maximum_minimum_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -554,10 +658,13 @@ tf_cc_test( ) tf_cc_test( - name = "mean_test", + name = "reduce_test", size = "small", - srcs = ["mean_test.cc"], - tags = ["tflite_not_portable_ios"], + srcs = ["reduce_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -583,7 +690,10 @@ tf_cc_test( name = "pad_test", size = "small", srcs = ["pad_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -609,7 +719,10 @@ tf_cc_test( name = "gather_test", size = "small", srcs = ["gather_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:builtin_op_data", @@ -623,7 +736,10 @@ tf_cc_test( name = "topk_v2_test", size = "small", srcs = ["topk_v2_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:builtin_op_data", @@ -744,7 +860,10 @@ tf_cc_test( name = "log_softmax_test", size = "small", srcs = ["log_softmax_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -825,7 +944,10 @@ tf_cc_test( name = "split_test", size = "small", srcs = ["split_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -838,7 +960,10 @@ tf_cc_test( name = "squeeze_test", size = "small", srcs = ["squeeze_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -851,7 +976,10 @@ tf_cc_test( name = "strided_slice_test", size = "small", srcs = ["strided_slice_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -864,7 +992,10 @@ tf_cc_test( name = "tile_test", size = "small", srcs = ["tile_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:builtin_op_data", @@ -881,6 +1012,7 @@ tf_cc_test( "comparisons_test.cc", ], tags = [ + "no_oss", "tflite_not_portable_ios", ], deps = [ @@ -895,7 +1027,10 @@ tf_cc_test( name = "neg_test", size = "small", srcs = ["neg_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -911,6 +1046,7 @@ tf_cc_test( "select_test.cc", ], tags = [ + "no_oss", "tflite_not_portable_ios", ], deps = [ @@ -928,6 +1064,7 @@ tf_cc_test( "slice_test.cc", ], tags = [ + "no_oss", "tflite_not_portable_ios", ], deps = [ @@ -942,7 +1079,10 @@ tf_cc_test( name = "transpose_conv_test", size = "small", srcs = ["transpose_conv_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -955,7 +1095,10 @@ tf_cc_test( name = "expand_dims_test", size = "small", srcs = ["expand_dims_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:builtin_op_data", @@ -969,6 +1112,84 @@ tf_cc_test( name = "sparse_to_dense_test", size = "small", srcs = ["sparse_to_dense_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "shape_test", + size = "small", + srcs = ["shape_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "pow_test", + size = "small", + srcs = ["pow_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "pack_test", + size = "small", + srcs = ["pack_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "one_hot_test", + size = "small", + srcs = ["one_hot_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "logical_test", + size = "small", + srcs = ["logical_test.cc"], tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index add36b46c0b8a4deab1e842d50194c8b99a3a20c..d6d62580e2d26328a9ac4c91dabe581fa3f47ad1 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -41,6 +40,11 @@ struct OpData { int diff_min = 0; }; +struct LogSoftmaxOpData : public OpData { + int32_t reverse_scaling_divisor = 0; + int32_t reverse_scaling_right_shift = 0; +}; + void* Init(TfLiteContext* context, const char* buffer, size_t length) { // This is a builtin op, so we don't use the contents in 'buffer', if any. // Instead, we allocate a new object to carry information from Prepare() to @@ -48,10 +52,19 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { return new OpData; } +void* LogSoftmaxInit(TfLiteContext* context, const char* buffer, + size_t length) { + return new LogSoftmaxOpData; +} + void Free(TfLiteContext* context, void* buffer) { delete reinterpret_cast(buffer); } +void LogSoftmaxFree(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -84,6 +97,38 @@ TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) { &data->input_left_shift); data->input_range_radius = CalculateInputRadius(kInputIntegerBits, data->input_left_shift); + } else if (input->type == kTfLiteInt16) { + static constexpr int kInputIntegerBits = 3; + static constexpr int kOutputFractionalBits = 15; + + // These operators are implemented in fixed-point arithmetic, + // which intrinsically wants symmetric ranges (zero_point==0) + // and power-of-two scales (power-of-two is abbreviated below as POT). + // While more general support would be possible by means of rescaling, + // that would add some overhead and some loss of accuracy and wouldn't + // be used at the moment as current quantized LSTM applications are + // happy with symmetric, power-of-two-scales quantization. So we just + // implement that narrow case only for now. + + TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + + int input_scale_log2_rounded; + TF_LITE_ENSURE(context, + CheckedLog2(input->params.scale, &input_scale_log2_rounded)); + + int output_scale_log2_rounded; + TF_LITE_ENSURE( + context, CheckedLog2(output->params.scale, &output_scale_log2_rounded)); + TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded, + -kOutputFractionalBits); + + data->input_left_shift = + (15 - kInputIntegerBits) + input_scale_log2_rounded; + // Support for shifts is limited until we have a parameterized version of + // SaturatingRoundingMultiplyByPOT(). + TF_LITE_ENSURE(context, data->input_left_shift >= 0); + TF_LITE_ENSURE(context, data->input_left_shift <= 1); } return context->ResizeTensor(context, output, @@ -114,6 +159,30 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) { &data->input_left_shift); data->input_range_radius = CalculateInputRadius(kInputIntegerBits, data->input_left_shift); + } else if (input->type == kTfLiteInt16) { + static constexpr int kInputIntegerBits = 3; + static constexpr int kOutputFractionalBits = 15; + + // See comments in TanhPrepare about requiring zero_point==0 + // and a power-of-two ("POT") scale. + + TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + + int input_scale_log2_rounded; + TF_LITE_ENSURE(context, + CheckedLog2(input->params.scale, &input_scale_log2_rounded)); + + int output_scale_log2_rounded; + TF_LITE_ENSURE( + context, CheckedLog2(output->params.scale, &output_scale_log2_rounded)); + TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded, + -kOutputFractionalBits); + + data->input_left_shift = + (15 - kInputIntegerBits) + input_scale_log2_rounded; + // The int16 logistic implementation does not support shifting of the input. + TF_LITE_ENSURE_EQ(context, data->input_left_shift, 0); } return context->ResizeTensor(context, output, @@ -130,8 +199,8 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, 0); TF_LITE_ENSURE_EQ(context, input->type, output->type); - TF_LITE_ENSURE(context, - NumDimensions(input) == 2 || NumDimensions(input) == 4); + const int num_dims = NumDimensions(input); + TF_LITE_ENSURE(context, num_dims == 1 || num_dims == 2 || num_dims == 4); if (input->type == kTfLiteUInt8) { TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); @@ -150,6 +219,34 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArrayCopy(input->dims)); } +TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { + LogSoftmaxOpData* data = reinterpret_cast(node->user_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + const TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255); + TF_LITE_ENSURE_EQ(context, output->params.scale, 16.0 / 256); + + static const double kBeta = 1.0; + static const int kScaledDiffIntegerBits = 5; + tflite::PreprocessLogSoftmaxScalingExp( + kBeta, input->params.scale, kScaledDiffIntegerBits, + &data->input_multiplier, &data->input_left_shift, + &data->reverse_scaling_divisor, &data->reverse_scaling_right_shift); + data->reverse_scaling_right_shift *= -1; + data->diff_min = -1.0 * tflite::CalculateInputRadius( + kScaledDiffIntegerBits, data->input_left_shift); + } + + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -157,25 +254,25 @@ TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, 0); const TfLiteTensor* alpha = GetInput(context, node, 1); - output->type = input->type; - // Currently only Float32 is supported // TODO(ycling): Support other data types. TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE_EQ(context, alpha->type, kTfLiteFloat32); + output->type = input->type; - // Currently, only support 4D `input` and 3D `alpha` with shape - // (1, 1, channels). - // TODO(impjdi): Support other cases where `alpha` is broadcastable - // to `input`. - TF_LITE_ENSURE_EQ(context, input->dims->size, 4); - TF_LITE_ENSURE_EQ(context, alpha->dims->size, 3); - TF_LITE_ENSURE_EQ(context, alpha->dims->data[0], 1); - TF_LITE_ENSURE_EQ(context, alpha->dims->data[1], 1); - TF_LITE_ENSURE_EQ(context, alpha->dims->data[2], input->dims->data[3]); + // PRelu (parameteric Relu) shares the same alpha value on "shared axis". + // This means it's always required to "broadcast" alpha values in PRelu. + TfLiteIntArray* output_size = nullptr; + TF_LITE_ENSURE_OK( + context, CalculateShapeForBroadcast(context, input, alpha, &output_size)); - return context->ResizeTensor(context, output, - TfLiteIntArrayCopy(input->dims)); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size)); + // After broadcasting, the output shape should always be the same as the + // input shape. + TF_LITE_ENSURE(context, HaveSameShapes(input, output)); + + return kTfLiteOk; } TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) { @@ -250,12 +347,19 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { for (; in < in_end; in++, out++) *out = std::tanh(*in); return kTfLiteOk; } break; + case kTfLiteInt16: { + optimized_ops::Tanh(GetTensorData(input), GetTensorShape(input), + data->input_left_shift, + GetTensorData(output), + GetTensorShape(output)); + return kTfLiteOk; + } break; case kTfLiteUInt8: { - optimized_ops::Tanh(GetTensorData(input), GetTensorDims(input), + optimized_ops::Tanh(GetTensorData(input), GetTensorShape(input), input->params.zero_point, data->input_range_radius, data->input_multiplier, data->input_left_shift, GetTensorData(output), - GetTensorDims(output)); + GetTensorShape(output)); return kTfLiteOk; } break; default: @@ -280,12 +384,18 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) { for (; in < in_end; in++, out++) *out = 1.f / (1.f + std::exp(-*in)); break; } + case kTfLiteInt16: { + optimized_ops::Logistic( + GetTensorData(input), GetTensorShape(input), + GetTensorData(output), GetTensorShape(output)); + break; + } case kTfLiteUInt8: { optimized_ops::Logistic( - GetTensorData(input), GetTensorDims(input), + GetTensorData(input), GetTensorShape(input), input->params.zero_point, data->input_range_radius, data->input_multiplier, data->input_left_shift, - GetTensorData(output), GetTensorDims(output)); + GetTensorData(output), GetTensorShape(output)); break; } default: @@ -296,13 +406,9 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -// Takes a 2D tensor and perform softmax along the second dimension. -void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output, - TfLiteSoftmaxParams* params) { - const int batch_size = input->dims->data[0]; - const int input_size = input->dims->data[1]; - float* in = input->data.f; - float* out = output->data.f; +// Performs softmax along the input of size (input_size * batch_size). +void Softmax(const float* in, const int input_size, const int batch_size, + const float beta, float* out) { TF_LITE_ASSERT(input_size > 0); // For each batch @@ -316,7 +422,7 @@ void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output, // Compute the normalized sum of exps. float exp_sum = 0.0; for (int i = 0; i < input_size; i++) { - out[i] = std::exp((in[i] - max_coeff) * params->beta); + out[i] = std::exp((in[i] - max_coeff) * beta); exp_sum += out[i]; } @@ -332,6 +438,33 @@ void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output, } } +// Takes a 1D tensor and performs softmax along it. +void Softmax1DFloat(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params) { + const int input_size = input->dims->data[0]; + Softmax(input->data.f, input_size, 1, params->beta, output->data.f); +} + +// Takes a 2D tensor and perform softmax along the last dimension. +void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params) { + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + Softmax(input->data.f, input_size, batch_size, params->beta, output->data.f); +} + +void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params, OpData* data) { + // TODO(ahentz): this is arguably a dirty trick. Since the implementation + // always traverses the last dimension of a 4D tensor, we will pretend our 1D + // tensor is 4D in a special way. We will convert a (Y) shape into a (1, + // 1, 1, Y) shape. + const int input_size = input->dims->data[0]; + optimized_ops::Softmax( + GetTensorData(input), GetTensorShape({1, 1, 1, input_size}), + data->input_multiplier, data->input_left_shift, data->diff_min, + GetTensorData(output), GetTensorShape({1, 1, 1, input_size})); +} void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output, TfLiteSoftmaxParams* params, OpData* data) { // TODO(ahentz): this is arguably a dirty trick. Since the implementation @@ -341,26 +474,26 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output, const int batch_size = input->dims->data[0]; const int input_size = input->dims->data[1]; optimized_ops::Softmax(GetTensorData(input), - GetTensorDims({batch_size, 1, 1, input_size}), + GetTensorShape({batch_size, 1, 1, input_size}), data->input_multiplier, data->input_left_shift, data->diff_min, GetTensorData(output), - GetTensorDims({batch_size, 1, 1, input_size})); + GetTensorShape({batch_size, 1, 1, input_size})); } // Takes a 4D tensor and perform softmax along the forth dimension. void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output, TfLiteSoftmaxParams* params) { - optimized_ops::Softmax(GetTensorData(input), GetTensorDims(input), + optimized_ops::Softmax(GetTensorData(input), GetTensorShape(input), params->beta, GetTensorData(output), - GetTensorDims(output)); + GetTensorShape(output)); } void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output, TfLiteSoftmaxParams* params, OpData* data) { - optimized_ops::Softmax(GetTensorData(input), GetTensorDims(input), + optimized_ops::Softmax(GetTensorData(input), GetTensorShape(input), data->input_multiplier, data->input_left_shift, data->diff_min, GetTensorData(output), - GetTensorDims(output)); + GetTensorShape(output)); } TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { @@ -374,6 +507,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { // dimensions. switch (input->type) { case kTfLiteFloat32: { + if (NumDimensions(input) == 1) { + Softmax1DFloat(input, output, params); + return kTfLiteOk; + } if (NumDimensions(input) == 2) { Softmax2DFloat(input, output, params); return kTfLiteOk; @@ -383,11 +520,15 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } context->ReportError( - context, "Only 2D and 4D tensors supported currently, got %dD.", + context, "Only 1D, 2D and 4D tensors supported currently, got %dD.", NumDimensions(input)); return kTfLiteError; } case kTfLiteUInt8: { + if (NumDimensions(input) == 1) { + Softmax1DQuantized(input, output, params, data); + return kTfLiteOk; + } if (NumDimensions(input) == 2) { Softmax2DQuantized(input, output, params, data); return kTfLiteOk; @@ -410,13 +551,23 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { } TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { + const LogSoftmaxOpData* data = + reinterpret_cast(node->user_data); const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); switch (input->type) { case kTfLiteFloat32: optimized_ops::LogSoftmax( - GetTensorData(input), GetTensorDims(input), - GetTensorData(output), GetTensorDims(output)); + GetTensorData(input), GetTensorShape(input), + GetTensorData(output), GetTensorShape(output)); + return kTfLiteOk; + case kTfLiteUInt8: + optimized_ops::LogSoftmax( + GetTensorData(input), GetTensorShape(input), + data->input_multiplier, data->input_left_shift, + data->reverse_scaling_divisor, data->reverse_scaling_right_shift, + data->diff_min, GetTensorData(output), + GetTensorShape(output)); return kTfLiteOk; default: context->ReportError(context, "Only float32 supported currently., got %d", @@ -425,33 +576,24 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { } } +template +T ApplyPrelu(T input, T alpha) { + return input >= 0.0 ? input : input * alpha; +} + TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, 0); const TfLiteTensor* alpha = GetInput(context, node, 1); - const TfLiteTensor* output = GetOutput(context, node, 0); - + TfLiteTensor* output = GetOutput(context, node, 0); if (input->type != kTfLiteFloat32) { context->ReportError(context, "Only float32 supported currently, got %d.", input->type); return kTfLiteError; } - TF_LITE_ENSURE_EQ(context, input->dims->size, 4); - const int batches = input->dims->data[0]; - const int height = input->dims->data[1]; - const int width = input->dims->data[2]; - const int channels = input->dims->data[3]; - - TF_LITE_ENSURE_EQ(context, alpha->dims->size, 3); - TF_LITE_ENSURE_EQ(context, alpha->dims->data[0], 1); - TF_LITE_ENSURE_EQ(context, alpha->dims->data[1], 1); - TF_LITE_ENSURE_EQ(context, alpha->dims->data[2], channels); - - const int n = batches * height * width * channels; - for (int i = 0; i < n; ++i) { - const float x = input->data.f[i]; - output->data.f[i] = x >= 0.0f ? x : alpha->data.f[i % channels] * x; - } - + reference_ops::BroadcastBinaryFunction( + GetTensorData(input), GetTensorDims(input), + GetTensorData(alpha), GetTensorDims(alpha), + GetTensorData(output), GetTensorDims(output), ApplyPrelu); return kTfLiteOk; } @@ -500,9 +642,9 @@ TfLiteRegistration* Register_SOFTMAX() { } TfLiteRegistration* Register_LOG_SOFTMAX() { - static TfLiteRegistration r = {activations::Init, activations::Free, - activations::GenericPrepare, - activations::LogSoftmaxEval}; + static TfLiteRegistration r = { + activations::LogSoftmaxInit, activations::LogSoftmaxFree, + activations::LogSoftmaxPrepare, activations::LogSoftmaxEval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc index 50a84edd475c8051a563cf8ed9fc03099829b786..e577e3a762b9db62a8b84f159b8502bc991f97e2 100644 --- a/tensorflow/contrib/lite/kernels/activations_test.cc +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -75,23 +75,42 @@ class FloatActivationsOpModel : public BaseActivationsOpModel { std::vector GetOutput() { return ExtractVector(output_); } }; -// TODO(ahentz): I don't quite understand the tradeoffs in the quantized -// implementation of sigmoid and software, but a tolerance of twice the output -// scale seems reasonable. We might want to change this if we have a better -// theoretical bound. +// Our fixed-point math function implementations have roughly 12 bits of +// accuracy, when specialized to 16-bit fixed-point arithmetic. +// That is purely an implementation compromise, it would have been possible +// to get closer to 16 bits of accuracy but that would be more expensive, +// and not needed for our purposes as ultimately the output is either +// immediately down-quantized to 8 bits, or will typically be at the output +// of the surrounding LSTM cell. +// So we can require roughly 2^-12 accuracy when the output is 16-bit, and +// we can more or less expect the full 2^-8 accuracy when the output is 8-bit. +// +// However, the representable output interval is often [-1, 1] (it has to be +// for tanh, and even for logistic, when we implement it in fixed-point, we +// typically have to do so on such a symmetric interval, e.g. ARM NEON only +// has signed fixed-point arithmetic (SQRDMULH)). As the width of [-1, 1] +// is 2, our representable values are often diluted by a factor of 2, whence +// the factor of 2 below. const float kQuantizedTolerance = 2 * (1. / 256); +const float kQuantizedToleranceInt16 = 2 * (1. / 4096); class QuantizedActivationsOpModel : public BaseActivationsOpModel { public: using BaseActivationsOpModel::BaseActivationsOpModel; + template void SetInput(std::initializer_list data) { - QuantizeAndPopulate(input_, data); + QuantizeAndPopulate(input_, data); } - std::vector GetOutput() { return ExtractVector(output_); } + template + + std::vector GetOutput() { + return ExtractVector(output_); + } + template std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); } }; @@ -152,24 +171,47 @@ TEST(FloatActivationsOpTest, Tanh) { } TEST(QuantizedActivationsOpTest, Tanh) { + const float kMin = -1; + const float kMax = 127.f / 128.f; QuantizedActivationsOpModel m( BuiltinOperator_TANH, - /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -8, 8}, - /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, -1, 1}); - m.SetInput({ + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 8 * kMin, 8 * kMax}, + /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, kMin, kMax}); + m.SetInput({ + 0, -6, 2, 4, // + -4, -2, 8, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.0, -0.999987, 0.964027, 0.999329, // + -0.999329, -0.96402, 0.99999, 0.76159, // + }, + kQuantizedTolerance))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({128, 0, 251, 255, 0, 5, 255, 225})); +} + +TEST(QuantizedActivationsOpTest, TanhInt16) { + const float kMin = -1; + const float kMax = 32767.f / 32768.f; + QuantizedActivationsOpModel m( + BuiltinOperator_TANH, + /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax}, + /*output=*/{TensorType_INT16, {1, 2, 4, 1}, kMin, kMax}); + m.SetInput({ 0, -6, 2, 4, // -4, -2, 8, 1, // }); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { 0.0, -0.999987, 0.964027, 0.999329, // - -0.996078, -0.96402, 0.99999, 0.76159, // + -0.999329, -0.96402, 0.99999, 0.76159, // }, - 4 * (1. / 256)))); - EXPECT_THAT(m.GetOutput(), - ElementsAreArray({128, 0, 251, 255, 0, 5, 255, 226})); + kQuantizedToleranceInt16))); } TEST(FloatActivationsOpTest, Sigmoid) { @@ -190,22 +232,43 @@ TEST(QuantizedActivationsOpTest, Sigmoid) { QuantizedActivationsOpModel m( BuiltinOperator_LOGISTIC, /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10}); - m.SetInput({ + m.SetInput({ 0, -6, 2, 4, // 3, -2, 10, 1, // }); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { 0.5, 0.002473, 0.880797, 0.982014, // 0.952574, 0.119203, 0.999955, 0.731059, // }, kQuantizedTolerance))); - EXPECT_THAT(m.GetOutput(), + EXPECT_THAT(m.GetOutput(), ElementsAreArray({128, 1, 227, 251, 244, 32, 255, 188})); } +TEST(QuantizedActivationsOpTest, SigmoidInt16) { + const float kMin = -1; + const float kMax = 32767.f / 32768.f; + QuantizedActivationsOpModel m( + BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax}, + /*output=*/{TensorType_INT16, {1, 2, 4, 1}, kMin, kMax}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }, + kQuantizedToleranceInt16))); +} + TEST(FloatActivationsOpTest, Softmax4D) { FloatActivationsOpModel m(0.1, /*input=*/{TensorType_FLOAT32, {1, 2, 1, 4}}); @@ -241,12 +304,12 @@ TEST(QuantizedActivationsOpTest, Softmax4D) { QuantizedActivationsOpModel m( 0.1, /*input=*/{TensorType_UINT8, {1, 2, 1, 4}, -10, 10}); - m.SetInput({ + m.SetInput({ 0, -6, 2, 4, // depth = 0 3, -2, 10, 1, // depth = 1 }); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { .23463, .12877, .28658, .35003, // @@ -258,21 +321,45 @@ TEST(QuantizedActivationsOpTest, Softmax4D) { QuantizedActivationsOpModel m2( 0.1, /*input=*/{TensorType_UINT8, {4, 1, 1, 2}, -10, 10}); - m2.SetInput({ + m2.SetInput({ 0, -6, // 2, 4, // 3, -2, // 10, 1, // }); m2.Invoke(); - EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - { - 0.645656, 0.354344, // - 0.450166, 0.549834, // - 0.622459, 0.377541, // - 0.710949, 0.28905, // - }, - kQuantizedTolerance))); + EXPECT_THAT(m2.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); +} + +TEST(FloatActivationsOpTest, Softmax1D) { + FloatActivationsOpModel m(0.1, + /*input=*/{TensorType_FLOAT32, {8}}); + m.SetInput({0, -6, 2, 4, 3, -2, 10, 1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {.09752, .05352, .11911, .14548, .13164, .07984, .26509, .10778}))); +} + +TEST(QuantizedActivationsOpTest, Softmax1D) { + QuantizedActivationsOpModel m(0.1, + /*input=*/{TensorType_UINT8, {8}, -10, 10}); + m.SetInput({0, -6, 2, 4, 3, -2, 10, 1}); + m.Invoke(); + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({0.09766, 0.05469, 0.12109, 0.14453, + 0.13281, 0.07813, 0.26563, 0.10938}, + kQuantizedTolerance))); } TEST(FloatActivationsOpTest, Softmax2D) { @@ -309,12 +396,12 @@ TEST(FloatActivationsOpTest, Softmax2D) { TEST(QuantizedActivationsOpTest, Softmax2D) { QuantizedActivationsOpModel m(0.1, /*input=*/{TensorType_UINT8, {2, 4}, -10, 10}); - m.SetInput({ + m.SetInput({ 0, -6, 2, 4, // 3, -2, 10, 1, // }); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { .23463, .12877, .28658, .35003, // @@ -325,21 +412,22 @@ TEST(QuantizedActivationsOpTest, Softmax2D) { // Same input, but a different shape. QuantizedActivationsOpModel m2(0.1, /*input=*/{TensorType_UINT8, {4, 2}, -10, 10}); - m2.SetInput({ + m2.SetInput({ 0, -6, // 2, 4, // 3, -2, // 10, 1, // }); m2.Invoke(); - EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - { - 0.645656, 0.354344, // - 0.450166, 0.549834, // - 0.622459, 0.377541, // - 0.710949, 0.28905, // - }, - kQuantizedTolerance))); + EXPECT_THAT(m2.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); } // This contains the same test values as the Softmax test, but reference answer @@ -383,6 +471,28 @@ TEST(FloatActivationsOpTest, LogSoftmax) { }))); } +TEST(QuantizedActivationsOpTest, LogSoftmax) { + const float kLogSoftmaxQuantizedTolerance = 16 / 256.0; + QuantizedActivationsOpModel m( + BuiltinOperator_LOG_SOFTMAX, + /*input=*/{TensorType_UINT8, {2, 4}, -10, 10}, + /*output=*/{TensorType_UINT8, {}, 0, 0, 16. / 256, 255}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + -4.14297, -10.14297, -2.14297, -.142971, // + -7.00104, -12.00104, -.00104087, -9.00104, // + }, + kLogSoftmaxQuantizedTolerance))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({189, 93, 221, 253, 142, 63, 255, 111})); +} + class PReluOpModel : public SingleOpModel { public: PReluOpModel(const TensorData& input, const TensorData& alpha) { diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc index 443ce8924a43669fb264e19561c733d7e3436cb0..af9b5c7013afc5d32d01cba07492a282727b3e12 100644 --- a/tensorflow/contrib/lite/kernels/add.cc +++ b/tensorflow/contrib/lite/kernels/add.cc @@ -39,6 +39,23 @@ constexpr int kOutputTensor = 0; struct OpData { bool requires_broadcast; + + // These fields are used in both the general 8-bit -> 8bit quantized path, + // and the special 16-bit -> 16bit quantized path + int input1_shift; + int input2_shift; + int32 output_activation_min; + int32 output_activation_max; + + // These fields are used only in the general 8-bit -> 8bit quantized path + int32 input1_multiplier; + int32 input2_multiplier; + int32 output_multiplier; + int output_shift; + int left_shift; + int32 input1_offset; + int32 input2_offset; + int32 output_offset; }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -52,6 +69,7 @@ void Free(TfLiteContext* context, void* buffer) { } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); OpData* data = reinterpret_cast(node->user_data); TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); @@ -74,92 +92,182 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_size = TfLiteIntArrayCopy(input1->dims); } + if (output->type == kTfLiteUInt8) { + // 8bit -> 8bit general quantized path, with general rescalings + data->input1_offset = -input1->params.zero_point; + data->input2_offset = -input2->params.zero_point; + data->output_offset = output->params.zero_point; + data->left_shift = 20; + const double twice_max_input_scale = + 2 * std::max(input1->params.scale, input2->params.scale); + const double real_input1_multiplier = + input1->params.scale / twice_max_input_scale; + const double real_input2_multiplier = + input2->params.scale / twice_max_input_scale; + const double real_output_multiplier = + twice_max_input_scale / + ((1 << data->left_shift) * output->params.scale); + + QuantizeMultiplierSmallerThanOneExp( + real_input1_multiplier, &data->input1_multiplier, &data->input1_shift); + + QuantizeMultiplierSmallerThanOneExp( + real_input2_multiplier, &data->input2_multiplier, &data->input2_shift); + + QuantizeMultiplierSmallerThanOneExp( + real_output_multiplier, &data->output_multiplier, &data->output_shift); + + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + + } else if (output->type == kTfLiteInt16) { + // 16bit -> 16bit special quantized path, supporting only a rather + // narrow case of quantization parameters: zero_points must all be 0 + // ("symmetric quantization") and scales must be power-of-two (which + // we abbreviate as "POT" below). The intended use case for this path + // is in LSTM cells, where, due to the constraints of implementing + // some of the math in these LSTM cells in fixed-point arithmetic, + // we need to have such symmetric, power-of-two quantization + // (Fixed-point formats are inherently symmetric, power-of-two). + TF_LITE_ENSURE_EQ(context, input1->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, input2->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + + int input1_scale_log2_rounded; + bool input1_scale_is_pot = + CheckedLog2(input1->params.scale, &input1_scale_log2_rounded); + TF_LITE_ENSURE(context, input1_scale_is_pot); + + int input2_scale_log2_rounded; + bool input2_scale_is_pot = + CheckedLog2(input2->params.scale, &input2_scale_log2_rounded); + TF_LITE_ENSURE(context, input2_scale_is_pot); + + int output_scale_log2_rounded; + bool output_scale_is_pot = + CheckedLog2(output->params.scale, &output_scale_log2_rounded); + TF_LITE_ENSURE(context, output_scale_is_pot); + + data->input1_shift = input1_scale_log2_rounded - output_scale_log2_rounded; + data->input2_shift = input2_scale_log2_rounded - output_scale_log2_rounded; + + // Shifting of one input is supported. The graph quantization should ensure + // that the other input matches the output. + TF_LITE_ENSURE(context, data->input1_shift == 0 || data->input2_shift == 0); + TF_LITE_ENSURE(context, data->input1_shift <= 0); + TF_LITE_ENSURE(context, data->input2_shift <= 0); + + CalculateActivationRangeQuantized(context, params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + return context->ResizeTensor(context, output, output_size); } template -void EvalAddFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteAddParams* params, const OpData* data, - const TfLiteTensor* input1, const TfLiteTensor* input2, - TfLiteTensor* output) { - float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); -#define TF_LITE_ADD(type, opname) \ - type::opname(GetTensorData(input1), GetTensorDims(input1), \ - GetTensorData(input2), GetTensorDims(input2), \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)) - if (kernel_type == kReference) { - if (data->requires_broadcast) { - TF_LITE_ADD(reference_ops, BroadcastAdd); +void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params, + const OpData* data, const TfLiteTensor* input1, + const TfLiteTensor* input2, TfLiteTensor* output) { +#define TF_LITE_ADD(type, opname, data_type) \ + data_type output_activation_min, output_activation_max; \ + CalculateActivationRange(params->activation, &output_activation_min, \ + &output_activation_max); \ + tflite::ArithmeticParams op_params; \ + SetActivationParams(output_activation_min, output_activation_max, \ + &op_params); \ + type::opname(op_params, GetTensorShape(input1), \ + GetTensorData(input1), GetTensorShape(input2), \ + GetTensorData(input2), GetTensorShape(output), \ + GetTensorData(output)) + if (output->type == kTfLiteInt32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, int32_t); + } else { + TF_LITE_ADD(reference_ops, Add, int32_t); + } } else { - TF_LITE_ADD(reference_ops, Add); + if (data->requires_broadcast) { + TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, int32_t); + } else { + TF_LITE_ADD(optimized_ops, Add, int32_t); + } } - } else { - if (data->requires_broadcast) { - TF_LITE_ADD(optimized_ops, BroadcastAdd); + } else if (output->type == kTfLiteFloat32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, float); + } else { + TF_LITE_ADD(reference_ops, Add, float); + } } else { - TF_LITE_ADD(optimized_ops, Add); + if (data->requires_broadcast) { + TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, float); + } else { + TF_LITE_ADD(optimized_ops, Add, float); + } } } #undef TF_LITE_ADD } template -void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteAddParams* params, const OpData* data, - const TfLiteTensor* input1, const TfLiteTensor* input2, - TfLiteTensor* output) { - auto input1_offset = -input1->params.zero_point; - auto input2_offset = -input2->params.zero_point; - auto output_offset = output->params.zero_point; - const int left_shift = 20; - const double twice_max_input_scale = - 2 * std::max(input1->params.scale, input2->params.scale); - const double real_input1_multiplier = - input1->params.scale / twice_max_input_scale; - const double real_input2_multiplier = - input2->params.scale / twice_max_input_scale; - const double real_output_multiplier = - twice_max_input_scale / ((1 << left_shift) * output->params.scale); - - int32 input1_multiplier; - int input1_shift; - QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, - &input1_multiplier, &input1_shift); - input1_shift *= -1; - int32 input2_multiplier; - int input2_shift; - QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, - &input2_multiplier, &input2_shift); - input2_shift *= -1; - int32 output_multiplier; - int output_shift; - QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, - &output_multiplier, &output_shift); - output_shift *= -1; - - int32 output_activation_min, output_activation_max; - CalculateActivationRangeUint8(params->activation, output, - &output_activation_min, &output_activation_max); - -#define TF_LITE_ADD(type, opname) \ - type::opname(left_shift, GetTensorData(input1), \ - GetTensorDims(input1), input1_offset, input1_multiplier, \ - input1_shift, GetTensorData(input2), \ - GetTensorDims(input2), input2_offset, input2_multiplier, \ - input2_shift, output_offset, output_multiplier, output_shift, \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)); - // The quantized version of Add doesn't support activations, so we - // always use BroadcastAdd. - if (kernel_type == kReference) { - TF_LITE_ADD(reference_ops, BroadcastAdd); - } else { - TF_LITE_ADD(optimized_ops, BroadcastAdd); - } +TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteAddParams* params, const OpData* data, + const TfLiteTensor* input1, + const TfLiteTensor* input2, + TfLiteTensor* output) { + if (output->type == kTfLiteUInt8) { +#define TF_LITE_ADD(type, opname) \ + tflite::ArithmeticParams op_params; \ + op_params.left_shift = data->left_shift; \ + op_params.input1_offset = data->input1_offset; \ + op_params.input1_multiplier = data->input1_multiplier; \ + op_params.input1_shift = data->input1_shift; \ + op_params.input2_offset = data->input2_offset; \ + op_params.input2_multiplier = data->input2_multiplier; \ + op_params.input2_shift = data->input2_shift; \ + op_params.output_offset = data->output_offset; \ + op_params.output_multiplier = data->output_multiplier; \ + op_params.output_shift = data->output_shift; \ + SetActivationParams(data->output_activation_min, \ + data->output_activation_max, &op_params); \ + type::opname(op_params, GetTensorShape(input1), \ + GetTensorData(input1), GetTensorShape(input2), \ + GetTensorData(input2), GetTensorShape(output), \ + GetTensorData(output)) + // The quantized version of Add doesn't support activations, so we + // always use BroadcastAdd. + if (kernel_type == kReference) { + TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow); + } else { + TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow); + } #undef TF_LITE_ADD + } else if (output->type == kTfLiteInt16) { +#define TF_LITE_ADD(type, opname) \ + tflite::ArithmeticParams op_params; \ + op_params.input1_shift = data->input1_shift; \ + op_params.input2_shift = data->input2_shift; \ + SetActivationParams(data->output_activation_min, \ + data->output_activation_max, &op_params); \ + type::opname(op_params, GetTensorShape(input1), \ + GetTensorData(input1), GetTensorShape(input2), \ + GetTensorData(input2), GetTensorShape(output), \ + GetTensorData(output)) + // The quantized version of Add doesn't support activations, so we + // always use BroadcastAdd. + if (kernel_type == kReference) { + TF_LITE_ADD(reference_ops, Add); + } else { + TF_LITE_ADD(optimized_ops, Add); + } +#undef TF_LITE_ADD + } + + return kTfLiteOk; } template @@ -171,15 +279,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - if (output->type == kTfLiteFloat32) { - EvalAddFloat(context, node, params, data, input1, input2, - output); - } else if (output->type == kTfLiteUInt8) { - EvalAddQuantized(context, node, params, data, input1, input2, - output); + if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) { + EvalAdd(context, node, params, data, input1, input2, output); + } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) { + TF_LITE_ENSURE_OK(context, + EvalAddQuantized(context, node, params, data, + input1, input2, output)); } else { context->ReportError(context, - "Inputs and outputs not all float|uint8 types."); + "Inputs and outputs not all float|uint8|int16 types."); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/contrib/lite/kernels/add_test.cc index 956d05bed5162f6ce59705d59aad77ff056dda77..0b5844321133de103919de76d367574f018a6698 100644 --- a/tensorflow/contrib/lite/kernels/add_test.cc +++ b/tensorflow/contrib/lite/kernels/add_test.cc @@ -52,6 +52,13 @@ class FloatAddOpModel : public BaseAddOpModel { std::vector GetOutput() { return ExtractVector(output_); } }; +class IntegerAddOpModel : public BaseAddOpModel { + public: + using BaseAddOpModel::BaseAddOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + class QuantizedAddOpModel : public BaseAddOpModel { public: using BaseAddOpModel::BaseAddOpModel; @@ -60,15 +67,26 @@ class QuantizedAddOpModel : public BaseAddOpModel { return Dequantize(ExtractVector(output_), GetScale(output_), GetZeroPoint(output_)); } + + std::vector GetDequantizedOutputInt16() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } }; // for quantized Add, the error shouldn't exceed 2*step -float GetTolerance(int min, int max) { +float GetTolerance(float min, float max) { float kQuantizedStep = (max - min) / 255.0; float kQuantizedTolerance = 2.0 * kQuantizedStep; return kQuantizedTolerance; } +float GetToleranceInt16(float min, float max) { + float kQuantizedStep = (max - min) / 32767.f; + float kQuantizedTolerance = 2.0 * kQuantizedStep; + return kQuantizedTolerance; +} + TEST(FloatAddOpModel, NoActivation) { FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, @@ -122,6 +140,57 @@ TEST(FloatAddOpModel, WithBroadcast) { } } +TEST(IntegerAddOpModel, NoActivation) { + IntegerAddOpModel m({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, + ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-19, 4, 10, 13})); +} + +TEST(IntegerAddOpModel, ActivationRELU_N1_TO_1) { + IntegerAddOpModel m({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, + ActivationFunctionType_RELU_N1_TO_1); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 1, 1, 1})); +} + +TEST(IntegerAddOpModel, VariousInputShapes) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + IntegerAddOpModel m({TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8, 11, 20}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5, 11, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-19, 04, 10, 13, 22, 21})) + << "With shape number " << i; + } +} + +TEST(IntegerAddOpModel, WithBroadcast) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + IntegerAddOpModel m({TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, {}}, // always a scalar + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8, 11, 20}); + m.PopulateTensor(m.input2(), {1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-19, 3, 8, 9, 12, 21}))) + << "With shape number " << i; + } +} + TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = { @@ -144,6 +213,31 @@ TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { } } +TEST(QuantizedAddOpModel, QuantizedTestsNoActivationInt16) { + const float kMin = -1.f; + const float kMax = 32767.f / 32768.f; + float kQuantizedTolerance = GetToleranceInt16(kMin, kMax); + std::vector> inputs1 = { + {0.1, 0.2, 0.3, 0.4}, {-0.8, 0.2, 0.4, 0.7}, {-0.8, 0.2, 0.7, 0.3}}; + std::vector> inputs2 = { + {0.6, 0.4, 0.3, 0.1}, {0.6, 0.4, 0.5, -0.8}, {0.6, 0.4, -0.8, 0.5}}; + std::vector> results = { + {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}}; + for (int i = 0; i < inputs1.size(); ++i) { + QuantizedAddOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {}, kMin, kMax}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), inputs1[i]); + m.QuantizeAndPopulate(m.input2(), inputs2[i]); + m.Invoke(); + EXPECT_THAT( + m.GetDequantizedOutputInt16(), + ElementsAreArray(ArrayFloatNear(results[i], kQuantizedTolerance))) + << "With test number " << i; + } +} + TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU_N1_TO_1) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = {{-0.8, 0.2, 0.9, 0.7}, diff --git a/tensorflow/contrib/lite/kernels/arg_max.cc b/tensorflow/contrib/lite/kernels/arg_min_max.cc similarity index 70% rename from tensorflow/contrib/lite/kernels/arg_max.cc rename to tensorflow/contrib/lite/kernels/arg_min_max.cc index 26f57e88962116f446e72fbc164d2747e8b633b4..4f30d09030fb8d26c08090b180fdd352a967807f 100644 --- a/tensorflow/contrib/lite/kernels/arg_max.cc +++ b/tensorflow/contrib/lite/kernels/arg_min_max.cc @@ -23,7 +23,7 @@ limitations under the License. namespace tflite { namespace ops { namespace builtin { -namespace arg_max { +namespace arg_min_max { constexpr int kInputTensor = 0; constexpr int kAxis = 1; @@ -80,30 +80,39 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, output_size); } +template +std::function GetComparefunction(bool is_arg_max) { + if (is_arg_max) { + return std::greater(); + } else { + return std::less(); + } +} + // The current impl actually ignores the axis argument. // Only determine the index of the maximum value in the last dimension. -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) { const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* axis = GetInput(context, node, kAxis); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); -#define TF_LITE_ARG_MAX(data_type, axis_type, output_type) \ - optimized_ops::ArgMax(GetTensorData(axis), \ - GetTensorData(input), GetTensorDims(input), \ - GetTensorData(output), \ - GetTensorDims(output)) +#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \ + optimized_ops::ArgMinMax( \ + GetTensorData(axis), GetTensorData(input), \ + GetTensorDims(input), GetTensorData(output), \ + GetTensorDims(output), GetComparefunction(is_arg_max)) if (axis->type == kTfLiteInt32) { switch (output->type) { case kTfLiteInt32: { switch (input->type) { case kTfLiteFloat32: - TF_LITE_ARG_MAX(float, int32_t, int32_t); + TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t); break; case kTfLiteUInt8: - TF_LITE_ARG_MAX(uint8_t, int32_t, int32_t); + TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t); break; case kTfLiteInt32: - TF_LITE_ARG_MAX(int32_t, int32_t, int32_t); + TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t); break; default: return kTfLiteError; @@ -112,13 +121,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: { switch (input->type) { case kTfLiteFloat32: - TF_LITE_ARG_MAX(float, int32_t, int64_t); + TF_LITE_ARG_MIN_MAX(float, int32_t, int64_t); break; case kTfLiteUInt8: - TF_LITE_ARG_MAX(uint8_t, int32_t, int64_t); + TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t); break; case kTfLiteInt32: - TF_LITE_ARG_MAX(int32_t, int32_t, int64_t); + TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t); break; default: return kTfLiteError; @@ -132,13 +141,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt32: { switch (input->type) { case kTfLiteFloat32: - TF_LITE_ARG_MAX(float, int64_t, int32_t); + TF_LITE_ARG_MIN_MAX(float, int64_t, int32_t); break; case kTfLiteUInt8: - TF_LITE_ARG_MAX(uint8_t, int64_t, int32_t); + TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int32_t); break; case kTfLiteInt32: - TF_LITE_ARG_MAX(int32_t, int64_t, int32_t); + TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int32_t); break; default: return kTfLiteError; @@ -147,13 +156,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: { switch (input->type) { case kTfLiteFloat32: - TF_LITE_ARG_MAX(float, int64_t, int64_t); + TF_LITE_ARG_MIN_MAX(float, int64_t, int64_t); break; case kTfLiteUInt8: - TF_LITE_ARG_MAX(uint8_t, int64_t, int64_t); + TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int64_t); break; case kTfLiteInt32: - TF_LITE_ARG_MAX(int32_t, int64_t, int64_t); + TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int64_t); break; default: return kTfLiteError; @@ -163,16 +172,30 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteError; } } -#undef TF_LITE_ARG_MAX +#undef TF_LITE_ARG_MIN_MAX return kTfLiteOk; } -} // namespace arg_max +TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, false); +} + +TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, true); +} + +} // namespace arg_min_max TfLiteRegistration* Register_ARG_MAX() { - static TfLiteRegistration r = {nullptr, nullptr, arg_max::Prepare, - arg_max::Eval}; + static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare, + arg_min_max::ArgMaxEval}; + return &r; +} + +TfLiteRegistration* Register_ARG_MIN() { + static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare, + arg_min_max::ArgMinEval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/arg_max_test.cc b/tensorflow/contrib/lite/kernels/arg_min_max_test.cc similarity index 52% rename from tensorflow/contrib/lite/kernels/arg_max_test.cc rename to tensorflow/contrib/lite/kernels/arg_min_max_test.cc index 31b15fe19ab87027c28bde9eaff7d88d03b2c213..90e5fdc532c821691aaeca6e6faa4c24919ca2c8 100644 --- a/tensorflow/contrib/lite/kernels/arg_max_test.cc +++ b/tensorflow/contrib/lite/kernels/arg_min_max_test.cc @@ -24,16 +24,13 @@ namespace { using ::testing::ElementsAreArray; template -class ArgMaxOpModel : public SingleOpModel { +class ArgBaseOpModel : public SingleOpModel { public: - ArgMaxOpModel(std::initializer_list input_shape, TensorType input_type, - TensorType output_type, TensorType index_output_type) { + ArgBaseOpModel(std::initializer_list input_shape, TensorType input_type, + TensorType output_type, TensorType index_output_type) { input_ = AddInput(input_type); axis_ = AddInput(TensorType_INT32); output_ = AddOutput(output_type); - SetBuiltinOp(BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions, - CreateArgMaxOptions(builder_, index_output_type).Union()); - BuildInterpreter({input_shape, {1, 1, 1, 1}}); } int input() { return input_; } @@ -42,12 +39,42 @@ class ArgMaxOpModel : public SingleOpModel { std::vector GetOutput() { return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } - private: + protected: int input_; int axis_; int output_; }; +template +class ArgMaxOpModel : public ArgBaseOpModel { + public: + ArgMaxOpModel(std::initializer_list input_shape, TensorType input_type, + TensorType output_type, TensorType index_output_type) + : ArgBaseOpModel(input_shape, input_type, output_type, + index_output_type) { + ArgBaseOpModel::SetBuiltinOp( + BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions, + CreateArgMaxOptions(ArgBaseOpModel::builder_, index_output_type) + .Union()); + ArgBaseOpModel::BuildInterpreter({input_shape, {1, 1, 1, 1}}); + } +}; + +template +class ArgMinOpModel : public ArgBaseOpModel { + public: + ArgMinOpModel(std::initializer_list input_shape, TensorType input_type, + TensorType output_type, TensorType index_output_type) + : ArgBaseOpModel(input_shape, input_type, output_type, + index_output_type) { + ArgBaseOpModel::SetBuiltinOp( + BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions, + CreateArgMinOptions(ArgBaseOpModel::builder_, index_output_type) + .Union()); + ArgBaseOpModel::BuildInterpreter({input_shape, {1, 1, 1, 1}}); + } +}; + TEST(ArgMaxOpTest, GetMaxArgFloat) { ArgMaxOpModel model({1, 1, 1, 4}, TensorType_FLOAT32, TensorType_INT32, TensorType_INT32); @@ -96,6 +123,54 @@ TEST(ArgMaxOpTest, GetMaxArgOutput64) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1})); } +TEST(ArgMinOpTest, GetMinArgFloat) { + ArgMinOpModel model({1, 1, 1, 4}, TensorType_FLOAT32, + TensorType_INT32, TensorType_INT32); + model.PopulateTensor(model.input(), {0.1, 0.9, 0.7, 0.3}); + // Currently only support the last dimension. + model.PopulateTensor(model.axis(), {3}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1})); +} + +TEST(ArgMinOpTest, GetMinArgInt) { + ArgMinOpModel model({1, 1, 1, 4}, TensorType_INT32, TensorType_INT32, + TensorType_INT32); + model.PopulateTensor(model.input(), {1, 9, 7, 3}); + // Currently only support the last dimension. + model.PopulateTensor(model.axis(), {3}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1})); +} + +TEST(ArgMinOpTest, GetMinArgMulDimensions) { + ArgMinOpModel model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32, + TensorType_INT32); + model.PopulateTensor(model.input(), {1, 2, 7, 8, 1, 9, 7, 3}); + // Currently only support the last dimension. + model.PopulateTensor(model.axis(), {3}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1})); +} + +TEST(ArgMinOpTest, GetMinArgOutput64) { + ArgMinOpModel model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT64, + TensorType_INT64); + model.PopulateTensor(model.input(), {10, 2, 7, 8, 1, 9, 7, 3}); + // Currently only support the last dimension. + model.PopulateTensor(model.axis(), {3}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc index 3425288f027a6fd9eb65f730bc7d039c832ace1c..a11a59aa050675314ac8b1316cdd0f15c81b8b15 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -276,27 +275,33 @@ TfLiteStatus CheckLstmTensorDimensions( TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteNode* node, int n_input, int n_output, int n_cell) { - CheckLstmTensorDimensions( - context, node, n_input, n_output, n_cell, kFwInputToInputWeightsTensor, - kFwInputToForgetWeightsTensor, kFwInputToCellWeightsTensor, - kFwInputToOutputWeightsTensor, kFwRecurrentToInputWeightsTensor, - kFwRecurrentToForgetWeightsTensor, kFwRecurrentToCellWeightsTensor, - kFwRecurrentToOutputWeightsTensor, kFwCellToInputWeightsTensor, - kFwCellToForgetWeightsTensor, kFwCellToOutputWeightsTensor, - kFwInputGateBiasTensor, kFwForgetGateBiasTensor, kFwCellGateBiasTensor, - kFwOutputGateBiasTensor, kFwProjectionWeightsTensor, - kFwProjectionBiasTensor); - - CheckLstmTensorDimensions( - context, node, n_input, n_output, n_cell, kBwInputToInputWeightsTensor, - kBwInputToForgetWeightsTensor, kBwInputToCellWeightsTensor, - kBwInputToOutputWeightsTensor, kBwRecurrentToInputWeightsTensor, - kBwRecurrentToForgetWeightsTensor, kBwRecurrentToCellWeightsTensor, - kBwRecurrentToOutputWeightsTensor, kBwCellToInputWeightsTensor, - kBwCellToForgetWeightsTensor, kBwCellToOutputWeightsTensor, - kBwInputGateBiasTensor, kBwForgetGateBiasTensor, kBwCellGateBiasTensor, - kBwOutputGateBiasTensor, kBwProjectionWeightsTensor, - kBwProjectionBiasTensor); + TF_LITE_ENSURE_OK( + context, + CheckLstmTensorDimensions( + context, node, n_input, n_output, n_cell, + kFwInputToInputWeightsTensor, kFwInputToForgetWeightsTensor, + kFwInputToCellWeightsTensor, kFwInputToOutputWeightsTensor, + kFwRecurrentToInputWeightsTensor, kFwRecurrentToForgetWeightsTensor, + kFwRecurrentToCellWeightsTensor, kFwRecurrentToOutputWeightsTensor, + kFwCellToInputWeightsTensor, kFwCellToForgetWeightsTensor, + kFwCellToOutputWeightsTensor, kFwInputGateBiasTensor, + kFwForgetGateBiasTensor, kFwCellGateBiasTensor, + kFwOutputGateBiasTensor, kFwProjectionWeightsTensor, + kFwProjectionBiasTensor)); + + TF_LITE_ENSURE_OK( + context, + CheckLstmTensorDimensions( + context, node, n_input, n_output, n_cell, + kBwInputToInputWeightsTensor, kBwInputToForgetWeightsTensor, + kBwInputToCellWeightsTensor, kBwInputToOutputWeightsTensor, + kBwRecurrentToInputWeightsTensor, kBwRecurrentToForgetWeightsTensor, + kBwRecurrentToCellWeightsTensor, kBwRecurrentToOutputWeightsTensor, + kBwCellToInputWeightsTensor, kBwCellToForgetWeightsTensor, + kBwCellToOutputWeightsTensor, kBwInputGateBiasTensor, + kBwForgetGateBiasTensor, kBwCellGateBiasTensor, + kBwOutputGateBiasTensor, kBwProjectionWeightsTensor, + kBwProjectionBiasTensor)); // Check if Forward and Backward tensors match along required dimensions. return kTfLiteOk; @@ -334,7 +339,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1]; // Check that input tensor dimensions matches with each other. - CheckInputTensorDimensions(context, node, n_input, n_fw_output, n_fw_cell); + TF_LITE_ENSURE_OK( + context, CheckInputTensorDimensions(context, node, n_input, n_fw_output, + n_fw_cell)); // Get the pointer to output, state and scratch buffer tensors. TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); @@ -404,7 +411,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1]; // Check that input tensor dimensions matches with each other. - CheckInputTensorDimensions(context, node, n_input, n_bw_output, n_bw_cell); + TF_LITE_ENSURE_OK( + context, CheckInputTensorDimensions(context, node, n_input, n_bw_output, + n_bw_cell)); // Get the pointer to output, output_state and cell_state buffer tensors. TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc index aa24c1f34cd1e8c02a6a75b62fbe5f3c629498ca..517309a226bcfb717186be8c1d02d68e3b337f8e 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc index 60770ca0aa8b85d9710d26beca3d4d603da5db2f..8dd48af57fd1bd9ef21256410d6bede6b7baa566 100644 --- a/tensorflow/contrib/lite/kernels/cast.cc +++ b/tensorflow/contrib/lite/kernels/cast.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include #include +#include #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" @@ -53,6 +54,20 @@ void copyCast(const FromT* in, ToT* out, int num_elements) { [](FromT a) { return static_cast(a); }); } +template +void copyCast(const std::complex* in, ToT* out, int num_elements) { + std::transform(in, in + num_elements, out, [](std::complex a) { + return static_cast(std::real(a)); + }); +} + +template <> +void copyCast(const std::complex* in, std::complex* out, + int num_elements) { + std::transform(in, in + num_elements, out, + [](std::complex a) { return a; }); +} + template TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out, int num_elements) { @@ -72,6 +87,10 @@ TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out, case kTfLiteBool: copyCast(in, out->data.b, num_elements); break; + case kTfLiteComplex64: + copyCast(in, reinterpret_cast*>(out->data.c64), + num_elements); + break; default: // Unsupported type. return kTfLiteError; @@ -95,6 +114,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return copyToTensor(input->data.f, output, num_elements); case kTfLiteBool: return copyToTensor(input->data.b, output, num_elements); + case kTfLiteComplex64: + return copyToTensor( + reinterpret_cast*>(input->data.c64), output, + num_elements); default: // Unsupported type. return kTfLiteError; diff --git a/tensorflow/contrib/lite/kernels/cast_test.cc b/tensorflow/contrib/lite/kernels/cast_test.cc index 53e20007378392467356ab29ecb8b217bb7a9e89..954f998206563a38c74a1382092851cfbee1013b 100644 --- a/tensorflow/contrib/lite/kernels/cast_test.cc +++ b/tensorflow/contrib/lite/kernels/cast_test.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" @@ -73,6 +75,71 @@ TEST(CastOpModel, CastBoolToFloat) { ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f})); } +TEST(CastOpModel, CastComplex64ToFloat) { + CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_FLOAT32, {2, 3}}); + m.PopulateTensor>( + m.input(), + {std::complex(1.0f, 11.0f), std::complex(2.0f, 12.0f), + std::complex(3.0f, 13.0f), std::complex(4.0f, 14.0f), + std::complex(5.0f, 15.0f), std::complex(6.0f, 16.0f)}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f})); +} + +TEST(CastOpModel, CastFloatToComplex64) { + CastOpModel m({TensorType_FLOAT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}}); + m.PopulateTensor(m.input(), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + m.Invoke(); + EXPECT_THAT( + m.ExtractVector>(m.output()), + ElementsAreArray( + {std::complex(1.0f, 0.0f), std::complex(2.0f, 0.0f), + std::complex(3.0f, 0.0f), std::complex(4.0f, 0.0f), + std::complex(5.0f, 0.0f), std::complex(6.0f, 0.0f)})); +} + +TEST(CastOpModel, CastComplex64ToInt) { + CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_INT32, {2, 3}}); + m.PopulateTensor>( + m.input(), + {std::complex(1.0f, 11.0f), std::complex(2.0f, 12.0f), + std::complex(3.0f, 13.0f), std::complex(4.0f, 14.0f), + std::complex(5.0f, 15.0f), std::complex(6.0f, 16.0f)}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(CastOpModel, CastIntToComplex64) { + CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}}); + m.PopulateTensor(m.input(), {1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT( + m.ExtractVector>(m.output()), + ElementsAreArray( + {std::complex(1.0f, 0.0f), std::complex(2.0f, 0.0f), + std::complex(3.0f, 0.0f), std::complex(4.0f, 0.0f), + std::complex(5.0f, 0.0f), std::complex(6.0f, 0.0f)})); +} + +TEST(CastOpModel, CastComplex64ToComplex64) { + CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_COMPLEX64, {2, 3}}); + m.PopulateTensor>( + m.input(), + {std::complex(1.0f, 11.0f), std::complex(2.0f, 12.0f), + std::complex(3.0f, 13.0f), std::complex(4.0f, 14.0f), + std::complex(5.0f, 15.0f), std::complex(6.0f, 16.0f)}); + m.Invoke(); + EXPECT_THAT( + m.ExtractVector>(m.output()), + ElementsAreArray( + {std::complex(1.0f, 11.0f), std::complex(2.0f, 12.0f), + std::complex(3.0f, 13.0f), std::complex(4.0f, 14.0f), + std::complex(5.0f, 15.0f), + std::complex(6.0f, 16.0f)})); +} + } // namespace } // namespace tflite int main(int argc, char** argv) { diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc index f678f48fa5bbbcece6c5b87030d951783378d78f..8b4d778332afd5f4b53509bd669a674c63d9f6f9 100644 --- a/tensorflow/contrib/lite/kernels/comparisons.cc +++ b/tensorflow/contrib/lite/kernels/comparisons.cc @@ -57,6 +57,57 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, output_size); } +// TODO(ruic): optimize macros below to using template functions. +#define TF_LITE_QUANTIZE_COMPARISON(opname) \ + void EvalQuantized##opname(TfLiteContext* context, TfLiteNode* node, \ + const TfLiteTensor* input1, \ + const TfLiteTensor* input2, TfLiteTensor* output, \ + bool requires_broadcast) { \ + if (input1->type == kTfLiteUInt8) { \ + auto input1_offset = -input1->params.zero_point; \ + auto input2_offset = -input2->params.zero_point; \ + const int left_shift = 20; \ + const double twice_max_input_scale = \ + 2 * std::max(input1->params.scale, input2->params.scale); \ + const double real_input1_multiplier = \ + input1->params.scale / twice_max_input_scale; \ + const double real_input2_multiplier = \ + input2->params.scale / twice_max_input_scale; \ + \ + int32 input1_multiplier; \ + int input1_shift; \ + QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, \ + &input1_multiplier, &input1_shift); \ + int32 input2_multiplier; \ + int input2_shift; \ + QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, \ + &input2_multiplier, &input2_shift); \ + \ + if (requires_broadcast) { \ + reference_ops::Broadcast##opname( \ + left_shift, GetTensorData(input1), GetTensorDims(input1), \ + input1_offset, input1_multiplier, input1_shift, \ + GetTensorData(input2), GetTensorDims(input2), \ + input2_offset, input2_multiplier, input2_shift, \ + GetTensorData(output), GetTensorDims(output)); \ + } else { \ + reference_ops::opname( \ + left_shift, GetTensorData(input1), GetTensorDims(input1), \ + input1_offset, input1_multiplier, input1_shift, \ + GetTensorData(input2), GetTensorDims(input2), \ + input2_offset, input2_multiplier, input2_shift, \ + GetTensorData(output), GetTensorDims(output)); \ + } \ + } \ + } +TF_LITE_QUANTIZE_COMPARISON(Equal); +TF_LITE_QUANTIZE_COMPARISON(NotEqual); +TF_LITE_QUANTIZE_COMPARISON(Greater); +TF_LITE_QUANTIZE_COMPARISON(GreaterEqual); +TF_LITE_QUANTIZE_COMPARISON(Less); +TF_LITE_QUANTIZE_COMPARISON(LessEqual); +#undef TF_LITE_QUANTIZE_COMPARISON + #define TF_LITE_COMPARISON(type, opname, requires_broadcast) \ requires_broadcast \ ? reference_ops::Broadcast##opname( \ @@ -73,7 +124,6 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); - // TODO(renjieliu): Support quantized data. switch (input1->type) { case kTfLiteFloat32: TF_LITE_COMPARISON(float, Equal, requires_broadcast); @@ -84,9 +134,13 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_COMPARISON(int64_t, Equal, requires_broadcast); break; + case kTfLiteUInt8: + EvalQuantizedEqual(context, node, input1, input2, output, + requires_broadcast); + break; default: context->ReportError(context, - "Does not support type %d, requires float|int", + "Does not support type %d, requires float|int|uint8", input1->type); return kTfLiteError; } @@ -99,7 +153,6 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); - // TODO(renjieliu): Support quantized data. switch (input1->type) { case kTfLiteFloat32: TF_LITE_COMPARISON(float, NotEqual, requires_broadcast); @@ -110,9 +163,13 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_COMPARISON(int64_t, NotEqual, requires_broadcast); break; + case kTfLiteUInt8: + EvalQuantizedNotEqual(context, node, input1, input2, output, + requires_broadcast); + break; default: context->ReportError(context, - "Does not support type %d, requires float|int", + "Does not support type %d, requires float|int|uint8", input1->type); return kTfLiteError; } @@ -124,7 +181,6 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); - // TODO(renjieliu): Support quantized data. switch (input1->type) { case kTfLiteFloat32: TF_LITE_COMPARISON(float, Greater, requires_broadcast); @@ -135,9 +191,13 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_COMPARISON(int64_t, Greater, requires_broadcast); break; + case kTfLiteUInt8: + EvalQuantizedGreater(context, node, input1, input2, output, + requires_broadcast); + break; default: context->ReportError(context, - "Does not support type %d, requires float|int", + "Does not support type %d, requires float|int|uint8", input1->type); return kTfLiteError; } @@ -149,7 +209,6 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); - // TODO(renjieliu): Support quantized data. switch (input1->type) { case kTfLiteFloat32: TF_LITE_COMPARISON(float, GreaterEqual, requires_broadcast); @@ -160,9 +219,13 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_COMPARISON(int64_t, GreaterEqual, requires_broadcast); break; + case kTfLiteUInt8: + EvalQuantizedGreaterEqual(context, node, input1, input2, output, + requires_broadcast); + break; default: context->ReportError(context, - "Does not support type %d, requires float|int", + "Does not support type %d, requires float|int|uint8", input1->type); return kTfLiteError; } @@ -174,7 +237,6 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); - // TODO(renjieliu): Support quantized data. switch (input1->type) { case kTfLiteFloat32: TF_LITE_COMPARISON(float, Less, requires_broadcast); @@ -185,9 +247,13 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_COMPARISON(int64_t, Less, requires_broadcast); break; + case kTfLiteUInt8: + EvalQuantizedLess(context, node, input1, input2, output, + requires_broadcast); + break; default: context->ReportError(context, - "Does not support type %d, requires float|int", + "Does not support type %d, requires float|int|uint8", input1->type); return kTfLiteError; } @@ -199,7 +265,6 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); - // TODO(renjieliu): Support quantized data. switch (input1->type) { case kTfLiteFloat32: TF_LITE_COMPARISON(float, LessEqual, requires_broadcast); @@ -210,9 +275,13 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_COMPARISON(int64_t, LessEqual, requires_broadcast); break; + case kTfLiteUInt8: + EvalQuantizedLessEqual(context, node, input1, input2, output, + requires_broadcast); + break; default: context->ReportError(context, - "Does not support type %d, requires float|int", + "Does not support type %d, requires float|int|uint8", input1->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc index bb02e1c812fdc40bf515f1f978e9e39b5a16a4ea..67a91c17fd4f25e4a9ea22de5e2a10dc1c17656d 100644 --- a/tensorflow/contrib/lite/kernels/comparisons_test.cc +++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc @@ -35,6 +35,15 @@ class ComparisonOpModel : public SingleOpModel { BuildInterpreter({input1_shape, input2_shape}); } + ComparisonOpModel(const TensorData& input1, const TensorData& input2, + TensorType input_type, BuiltinOperator op) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(TensorType_BOOL); + ConfigureBuiltinOp(op); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + int input1() { return input1_; } int input2() { return input2_; } @@ -354,6 +363,192 @@ TEST(ComparisonsTest, LessEqualBroadcastTwoD) { EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } +TEST(QuantizedComparisonsTest, EqualQuantized) { + const float kMin = -1.f; + const float kMax = 128.f; + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_EQUAL); + model.QuantizeAndPopulate(model.input1(), {1, 9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 7, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, false)); +} + +TEST(QuantizedComparisonsTest, NotEqualQuantized) { + const float kMin = -1.f; + const float kMax = 128.f; + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_NOT_EQUAL); + model.QuantizeAndPopulate(model.input1(), {1, 9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 7, 0}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, true)); +} + +TEST(ComparisonsTest, GreaterQuantized) { + const float kMin = -1.f; + const float kMax = 128.f; + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_GREATER); + model.QuantizeAndPopulate(model.input1(), {1, 9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 6, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); +} + +TEST(ComparisonsTest, GreaterEqualQuantized) { + const float kMin = -1.f; + const float kMax = 128.f; + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_GREATER_EQUAL); + model.QuantizeAndPopulate(model.input1(), {1, 9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 6, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, true, false)); +} + +TEST(ComparisonsTest, LessQuantized) { + const float kMin = -1.f; + const float kMax = 128.f; + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_LESS); + model.QuantizeAndPopulate(model.input1(), {1, 9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 6, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, true)); +} + +TEST(ComparisonsTest, LessEqualQuantized) { + const float kMin = -1.f; + const float kMax = 128.f; + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_LESS_EQUAL); + model.QuantizeAndPopulate(model.input1(), {1, 9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 6, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); +} + +TEST(ComparisonsTest, QuantizedEqualWithBroadcast) { + const float kMin = -1.f; + const float kMax = 128.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax}, + {TensorType_UINT8, {}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_EQUAL); + model.QuantizeAndPopulate(model.input1(), {20, 2, 7, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {2}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, false, false, false, false)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedNotEqualWithBroadcast) { + const float kMin = -1.f; + const float kMax = 128.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax}, + {TensorType_UINT8, {}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_NOT_EQUAL); + model.QuantizeAndPopulate(model.input1(), {20, 2, 7, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {2}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, true, true, true, true)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedGreaterWithBroadcast) { + const float kMin = -1.f; + const float kMax = 128.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax}, + {TensorType_UINT8, {}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_GREATER); + model.QuantizeAndPopulate(model.input1(), {20, 2, 7, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {8}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, false, false, true, true)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedGreaterEqualWithBroadcast) { + const float kMin = -1.f; + const float kMax = 128.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax}, + {TensorType_UINT8, {}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_GREATER_EQUAL); + model.QuantizeAndPopulate(model.input1(), {20, 2, 7, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {8}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, false, true, true, true)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedLessWithBroadcast) { + const float kMin = -1.f; + const float kMax = 128.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax}, + {TensorType_UINT8, {}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_LESS); + model.QuantizeAndPopulate(model.input1(), {20, 2, 7, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {8}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, true, false, false, false)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedLessEqualWithBroadcast) { + const float kMin = -1.f; + const float kMax = 128.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax}, + {TensorType_UINT8, {}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_LESS_EQUAL); + model.QuantizeAndPopulate(model.input1(), {20, 2, 7, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {8}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, true, true, false, false)) + << "With shape number " << i; + } +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc index 45ea8d00498455be98467f2f1addc8ad7dcf35fa..605a20ac3e7c8346db2bcf64e9422132b433b3da 100644 --- a/tensorflow/contrib/lite/kernels/concatenation.cc +++ b/tensorflow/contrib/lite/kernels/concatenation.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -58,7 +57,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, t0->dims->size <= 4); TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); TF_LITE_ENSURE(context, - input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8); + input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || + input_type == kTfLiteInt16 || input_type == kTfLiteInt32 || + input_type == kTfLiteInt64); // Output dimensions will match input dimensions, except 'axis', which // will be the sum of inputs @@ -122,6 +123,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_CONCATENATION(optimized_ops, float); } break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_CONCATENATION(reference_ops, int32); + } else { + TF_LITE_CONCATENATION(optimized_ops, int32); + } + break; case kTfLiteUInt8: if (kernel_type == kReference) { TF_LITE_CONCATENATION_QUANTIZED(reference_ops); @@ -129,6 +137,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_CONCATENATION_QUANTIZED(optimized_ops); } break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_CONCATENATION(reference_ops, int64_t); + } else { + TF_LITE_CONCATENATION(optimized_ops, int64_t); + } + break; + default: context->ReportError(context, "Only float32 and uint8 are currently supported."); diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index 14b399ef96eab1d5066a22a7eb95ab061e8ba2bc..04c0263b789e75727ed3bd4d6b3292063a4530e0 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -179,9 +178,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(context, node)); - bool hasBias = node->inputs->size == 3; + bool has_bias = node->inputs->size == 3; // Check number of inputs/outputs - TF_LITE_ENSURE(context, hasBias || node->inputs->size == 2); + TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2); TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; @@ -204,9 +203,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // TODO(ahentz): At this point the optimized versions require 'bias'. We can // either change that or document that convolution requires it. - TF_LITE_ENSURE(context, hasBias); + TF_LITE_ENSURE(context, has_bias); - if (hasBias) { + if (has_bias) { bias = &context->tensors[node->inputs->data[2]]; if (data_type == kTfLiteUInt8) { TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); @@ -226,29 +225,30 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Matching GetWindowedOutputSize in TensorFlow. auto padding = params->padding; - auto computeOutSize = [padding](int imageSize, int filterSize, int stride, - int dilationRate) -> int { - int effectiveFilterSize = (filterSize - 1) * dilationRate + 1; + auto compute_out_size = [padding](int image_size, int filter_size, int stride, + int dilation_rate) -> int { + int effective_filter_size = (filter_size - 1) * dilation_rate + 1; return padding == kTfLitePaddingSame - ? (imageSize + stride - 1) / stride + ? (image_size + stride - 1) / stride : padding == kTfLitePaddingValid - ? (imageSize - effectiveFilterSize + stride) / stride + ? (image_size - effective_filter_size + stride) / stride : 0; }; - int outWidth = computeOutSize(width, filter_width, params->stride_width, - params->dilation_width_factor); - int outHeight = computeOutSize(height, filter_height, params->stride_height, - params->dilation_height_factor); + int out_width = compute_out_size(width, filter_width, params->stride_width, + params->dilation_width_factor); + int out_height = + compute_out_size(height, filter_height, params->stride_height, + params->dilation_height_factor); data->padding.height = ComputePadding(params->stride_height, params->dilation_height_factor, - height, filter_height, outHeight); + height, filter_height, out_height); data->padding.width = ComputePadding(params->stride_width, params->dilation_width_factor, width, - filter_width, outWidth); + filter_width, out_width); - TF_LITE_ENSURE(context, hasBias); + TF_LITE_ENSURE(context, has_bias); // Note that quantized inference requires that all tensors have their // parameters set. This is usually done during quantized training. @@ -256,10 +256,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { double real_multiplier = 0.0; TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); - TF_LITE_ENSURE(context, real_multiplier < 1.0); - QuantizeMultiplierSmallerThanOneExp( - real_multiplier, &data->output_multiplier, &data->output_shift); - data->output_shift *= -1; + + int exponent; + QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); + data->output_shift = -exponent; CalculateActivationRangeUint8(params->activation, output, &data->output_activation_min, &data->output_activation_max); @@ -267,8 +267,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); output_size->data[0] = batches; - output_size->data[1] = outHeight; - output_size->data[2] = outWidth; + output_size->data[1] = out_height; + output_size->data[2] = out_width; output_size->data[3] = channels_out; auto output_status = context->ResizeTensor(context, output, output_size); @@ -308,18 +308,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* hwcn_weights = &context->tensors[node->temporaries->data[data->hwcn_weights_index]]; hwcn_weights->type = data_type; - hwcn_weights->allocation_type = kTfLiteDynamic; - // Make sure we release any previous allocations before we reallocate. - // TODO(petewarden): Persistent arenas would be a better fit for this, but - // they aren't fully implemented yet. - if (hwcn_weights->data.raw) { - free(hwcn_weights->data.raw); - hwcn_weights->data.raw = nullptr; - } + hwcn_weights->allocation_type = kTfLiteArenaRwPersistent; - // Note that hwcn_weights_status is a kTfLiteDynamic tensor, and - // ResizeTensor will actually allocate space for it. The would be more - // efficient if we placed hwcn_weights_status in the persistent arena. auto hwcn_weights_status = context->ResizeTensor(context, hwcn_weights, hwcn_weights_size); if (hwcn_weights_status != kTfLiteOk) return hwcn_weights_status; @@ -381,8 +371,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col, TfLiteTensor* hwcn_weights, TfLiteTensor* output) { float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); KernelType effective_kernel_type; if (((kernel_type == kMultithreadOptimized) || (kernel_type == kCblasOptimized)) && @@ -427,6 +417,7 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, filter_data = GetTensorData(filter); } multithreaded_ops::Conv( + *eigen_support::GetThreadPoolDevice(context), GetTensorData(input), GetTensorDims(input), filter_data, GetTensorDims(filter), GetTensorData(bias), GetTensorDims(bias), params->stride_width, params->stride_height, @@ -458,9 +449,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; TfLiteTensor* filter = &context->tensors[node->inputs->data[1]]; - bool hasBias = node->inputs->size == 3; + bool has_bias = node->inputs->size == 3; TfLiteTensor* bias = - hasBias ? &context->tensors[node->inputs->data[2]] : nullptr; + has_bias ? &context->tensors[node->inputs->data[2]] : nullptr; TfLiteTensor* im2col = data->need_im2col ? &context->tensors[node->temporaries->data[data->im2col_index]] diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc index 0dcfc826fd218d2d2dfbf89201d2c13fbfe6f0e1..24633c2fd7cb3725977ae6c6459daa829165ccfd 100644 --- a/tensorflow/contrib/lite/kernels/conv_test.cc +++ b/tensorflow/contrib/lite/kernels/conv_test.cc @@ -64,12 +64,6 @@ class BaseConvolutionOpModel : public SingleOpModel { } output_ = AddOutput(output); - if (input.type != TensorType_FLOAT32) { - // The following is required by quantized inference. It is the unittest's - // responsibility to make sure the output scale falls into the correct - // range. - CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); - } SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions, CreateConv2DOptions( @@ -441,6 +435,44 @@ TEST_P(ConvolutionOpTest, SimpleTestQuantized) { })); } +TEST_P(ConvolutionOpTest, SimpleTestQuantizedOutputMultiplierGreaterThan1) { + // output_multiplier = 1.0118 + QuantizedConvolutionOpModel quant_op( + GetRegistration(), {TensorType_UINT8, {2, 2, 4, 1}, -128.5, 128}, + {TensorType_UINT8, {3, 2, 2, 1}, -128.5, 128}, + {TensorType_UINT8, {}, -127, 128}); + ConvolutionOpModel float_op( + GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_FLOAT32, {3, 2, 2, 1}}, {TensorType_FLOAT32, {}}); + std::initializer_list input = { + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }; + std::initializer_list filter = { + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }; + std::initializer_list bias = {1, 2, 3}; + + quant_op.SetInput(input); + quant_op.SetFilter(filter); + quant_op.SetBias(bias); + quant_op.Invoke(); + + float_op.SetInput(input); + float_op.SetFilter(filter); + float_op.SetBias(bias); + float_op.Invoke(); + + EXPECT_THAT(quant_op.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1))); +} + TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { QuantizedConvolutionOpModel m(GetRegistration(), {TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64}, diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc index a308de055f49eddba99d02e264fad11409a799f4..21518156b851892f50c62df7901d71c41fd733f7 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc +++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -173,8 +172,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input, const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output) { float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); void (*depthwise_conv)(const float*, const Dims<4>&, const float*, const Dims<4>&, const float*, const Dims<4>&, int, int, diff --git a/tensorflow/contrib/lite/kernels/dequantize.cc b/tensorflow/contrib/lite/kernels/dequantize.cc index 672b2170e4990f0a7ca9755071d9d086f5ae5c2b..2b0f04489a48cd4402e7574ecc5eeecfd8c6234f 100644 --- a/tensorflow/contrib/lite/kernels/dequantize.cc +++ b/tensorflow/contrib/lite/kernels/dequantize.cc @@ -36,6 +36,21 @@ struct OpContext { TfLiteTensor* output; }; +struct OpData { + // This boolean value is only used when the input tensor is constant. + bool float_dequantized_weights_initialized; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* op_data = new OpData(); + op_data->float_dequantized_weights_initialized = false; + return op_data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -45,12 +60,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8); op_context.output->type = kTfLiteFloat32; + // If the input tensor is constant, we can persist the dequantized value in + // the output tensor. Otherwise we run dequantize upon each eval. + if (IsConstantTensor(op_context.input)) { + op_context.output->allocation_type = kTfLiteArenaRwPersistent; + } return context->ResizeTensor(context, op_context.output, TfLiteIntArrayCopy(op_context.input->dims)); } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OpData* op_data = reinterpret_cast(node->user_data); OpContext op_context(context, node); + if (IsConstantTensor(op_context.input) && + op_data->float_dequantized_weights_initialized) { + return kTfLiteOk; + } auto zero_point = op_context.input->params.zero_point; auto scale = op_context.input->params.scale; @@ -59,14 +84,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTensorDims(op_context.input), zero_point, scale, GetTensorData(op_context.output), GetTensorDims(op_context.output)); + + if (IsConstantTensor(op_context.input)) { + op_data->float_dequantized_weights_initialized = true; + } + return kTfLiteOk; } } // namespace dequantize TfLiteRegistration* Register_DEQUANTIZE_OPT() { - static TfLiteRegistration r = {nullptr, nullptr, dequantize::Prepare, - dequantize::Eval}; + static TfLiteRegistration r = {dequantize::Init, dequantize::Free, + dequantize::Prepare, dequantize::Eval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc new file mode 100644 index 0000000000000000000000000000000000000000..d7bde0ff79bd23fa4c277dd04ec4343663e0ad00 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc @@ -0,0 +1,591 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace detection_postprocess { + +// Input tensors +constexpr int kInputTensorBoxEncodings = 0; +constexpr int kInputTensorClassPredictions = 1; +constexpr int kInputTensorAnchors = 2; + +// Output tensors +constexpr int kOutputTensorDetectionBoxes = 0; +constexpr int kOutputTensorDetectionClasses = 1; +constexpr int kOutputTensorDetectionScores = 2; +constexpr int kOutputTensorNumDetections = 3; + +constexpr int kNumCoordBox = 4; +constexpr int kBatchSize = 1; + +// Object Detection model produces axis-aligned boxes in two formats: +// BoxCorner represents the upper right (xmin, ymin) and +// lower left corner (xmax, ymax). +// CenterSize represents the center (xcenter, ycenter), height and width. +// BoxCornerEncoding and CenterSizeEncoding are related as follows: +// ycenter = y / y_scale * anchor.h + anchor.y; +// xcenter = x / x_scale * anchor.w + anchor.x; +// half_h = 0.5*exp(h/ h_scale)) * anchor.h; +// half_w = 0.5*exp(w / w_scale)) * anchor.w; +// ymin = ycenter - half_h +// ymax = ycenter + half_h +// xmin = xcenter - half_w +// xmax = xcenter + half_w +struct BoxCornerEncoding { + float ymin; + float xmin; + float ymax; + float xmax; +}; + +struct CenterSizeEncoding { + float y; + float x; + float h; + float w; +}; +// We make sure that the memory allocations are contiguous with static assert. +static_assert(sizeof(BoxCornerEncoding) == sizeof(float) * kNumCoordBox, + "Size of BoxCornerEncoding is 4 float values"); +static_assert(sizeof(CenterSizeEncoding) == sizeof(float) * kNumCoordBox, + "Size of CenterSizeEncoding is 4 float values"); + +struct OpData { + int max_detections; + int max_classes_per_detection; + float non_max_suppression_score_threshold; + float intersection_over_union_threshold; + int num_classes; + CenterSizeEncoding scale_values; + // Indices of Temporary tensors + int decoded_boxes_index; + int scores_index; + int active_candidate_index; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* op_data = new OpData; + const uint8_t* buffer_t = reinterpret_cast(buffer); + const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); + op_data->max_detections = m["max_detections"].AsInt32(); + op_data->max_classes_per_detection = m["max_classes_per_detection"].AsInt32(); + op_data->non_max_suppression_score_threshold = + m["nms_score_threshold"].AsFloat(); + op_data->intersection_over_union_threshold = m["nms_iou_threshold"].AsFloat(); + op_data->num_classes = m["num_classes"].AsInt32(); + op_data->scale_values.y = m["y_scale"].AsFloat(); + op_data->scale_values.x = m["x_scale"].AsFloat(); + op_data->scale_values.h = m["h_scale"].AsFloat(); + op_data->scale_values.w = m["w_scale"].AsFloat(); + context->AddTensors(context, 1, &op_data->decoded_boxes_index); + context->AddTensors(context, 1, &op_data->scores_index); + context->AddTensors(context, 1, &op_data->active_candidate_index); + return op_data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +// TODO(chowdhery): Add to kernel_util.h +TfLiteStatus SetTensorSizes(TfLiteContext* context, TfLiteTensor* tensor, + std::initializer_list values) { + TfLiteIntArray* size = TfLiteIntArrayCreate(values.size()); + int index = 0; + for (int v : values) { + size->data[index] = v; + ++index; + } + return context->ResizeTensor(context, tensor, size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* op_data = reinterpret_cast(node->user_data); + // Inputs: box_encodings, scores, anchors + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + const TfLiteTensor* input_class_predictions = + GetInput(context, node, kInputTensorClassPredictions); + const TfLiteTensor* input_anchors = + GetInput(context, node, kInputTensorAnchors); + TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3); + TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3); + TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2); + // number of detected boxes + const int num_detected_boxes = + op_data->max_detections * op_data->max_classes_per_detection; + + // Outputs: detection_boxes, detection_scores, detection_classes, + // num_detections + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4); + // Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4) + TfLiteTensor* detection_boxes = + GetOutput(context, node, kOutputTensorDetectionBoxes); + detection_boxes->type = kTfLiteFloat32; + SetTensorSizes(context, detection_boxes, + {kBatchSize, num_detected_boxes, kNumCoordBox}); + + // Output Tensor detection_classes: size is set to (1, num_detected_boxes) + TfLiteTensor* detection_classes = + GetOutput(context, node, kOutputTensorDetectionClasses); + detection_classes->type = kTfLiteFloat32; + SetTensorSizes(context, detection_classes, {kBatchSize, num_detected_boxes}); + + // Output Tensor detection_scores: size is set to (1, num_detected_boxes) + TfLiteTensor* detection_scores = + GetOutput(context, node, kOutputTensorDetectionScores); + detection_scores->type = kTfLiteFloat32; + SetTensorSizes(context, detection_scores, {kBatchSize, num_detected_boxes}); + + // Output Tensor num_detections: size is set to 1 + TfLiteTensor* num_detections = + GetOutput(context, node, kOutputTensorNumDetections); + num_detections->type = kTfLiteFloat32; + // TODO (chowdhery): Make it a scalar when available + SetTensorSizes(context, num_detections, {1}); + + // Temporary tensors + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(3); + node->temporaries->data[0] = op_data->decoded_boxes_index; + node->temporaries->data[1] = op_data->scores_index; + node->temporaries->data[2] = op_data->active_candidate_index; + + // decoded_boxes + TfLiteTensor* decoded_boxes = &context->tensors[op_data->decoded_boxes_index]; + decoded_boxes->type = kTfLiteFloat32; + decoded_boxes->allocation_type = kTfLiteArenaRw; + SetTensorSizes(context, decoded_boxes, + {input_box_encodings->dims->data[1], kNumCoordBox}); + + // scores + TfLiteTensor* scores = &context->tensors[op_data->scores_index]; + scores->type = kTfLiteFloat32; + scores->allocation_type = kTfLiteArenaRw; + SetTensorSizes(context, scores, + {input_class_predictions->dims->data[1], + input_class_predictions->dims->data[2]}); + + // active_candidate + TfLiteTensor* active_candidate = + &context->tensors[op_data->active_candidate_index]; + active_candidate->type = kTfLiteUInt8; + active_candidate->allocation_type = kTfLiteArenaRw; + SetTensorSizes(context, active_candidate, + {input_box_encodings->dims->data[1]}); + + return kTfLiteOk; +} + +class Dequantizer { + public: + Dequantizer(int zero_point, float scale) + : zero_point_(zero_point), scale_(scale) {} + float operator()(uint8 x) { + return (static_cast(x) - zero_point_) * scale_; + } + + private: + int zero_point_; + float scale_; +}; + +void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx, + float quant_zero_point, float quant_scale, + CenterSizeEncoding* box_centersize) { + const uint8* boxes = + GetTensorData(input_box_encodings) + kNumCoordBox * idx; + Dequantizer dequantize(quant_zero_point, quant_scale); + box_centersize->y = dequantize(boxes[0]); + box_centersize->x = dequantize(boxes[1]); + box_centersize->h = dequantize(boxes[2]); + box_centersize->w = dequantize(boxes[3]); +} + +template +T ReInterpretTensor(const TfLiteTensor* tensor) { + // TODO (chowdhery): check float + const float* tensor_base = tensor->data.f; + return reinterpret_cast(tensor_base); +} + +template +T ReInterpretTensor(TfLiteTensor* tensor) { + // TODO (chowdhery): check float + float* tensor_base = tensor->data.f; + return reinterpret_cast(tensor_base); +} + +TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node, + OpData* op_data) { + // Parse input tensor boxencodings + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize); + const int num_boxes = input_box_encodings->dims->data[1]; + TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[2], kNumCoordBox); + const TfLiteTensor* input_anchors = + GetInput(context, node, kInputTensorAnchors); + + // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors + CenterSizeEncoding box_centersize; + CenterSizeEncoding scale_values = op_data->scale_values; + CenterSizeEncoding anchor; + for (int idx = 0; idx < num_boxes; ++idx) { + switch (input_box_encodings->type) { + // Quantized + case kTfLiteUInt8: + DequantizeBoxEncodings( + input_box_encodings, idx, + static_cast(input_box_encodings->params.zero_point), + static_cast(input_box_encodings->params.scale), + &box_centersize); + DequantizeBoxEncodings( + input_anchors, idx, + static_cast(input_anchors->params.zero_point), + static_cast(input_anchors->params.scale), &anchor); + break; + // Float + case kTfLiteFloat32: + box_centersize = ReInterpretTensor( + input_box_encodings)[idx]; + anchor = + ReInterpretTensor(input_anchors)[idx]; + break; + default: + // Unsupported type. + return kTfLiteError; + } + + float ycenter = box_centersize.y / scale_values.y * anchor.h + anchor.y; + float xcenter = box_centersize.x / scale_values.x * anchor.w + anchor.x; + float half_h = + 0.5f * static_cast(std::exp(box_centersize.h / scale_values.h)) * + anchor.h; + float half_w = + 0.5f * static_cast(std::exp(box_centersize.w / scale_values.w)) * + anchor.w; + TfLiteTensor* decoded_boxes = + &context->tensors[op_data->decoded_boxes_index]; + auto& box = ReInterpretTensor(decoded_boxes)[idx]; + box.ymin = ycenter - half_h; + box.xmin = xcenter - half_w; + box.ymax = ycenter + half_h; + box.xmax = xcenter + half_w; + } + return kTfLiteOk; +} + +void DecreasingPartialArgSort(const float* values, int num_values, + int num_to_sort, int* indices) { + std::iota(indices, indices + num_values, 0); + std::partial_sort( + indices, indices + num_to_sort, indices + num_values, + [&values](const int i, const int j) { return values[i] > values[j]; }); +} + +void SelectDetectionsAboveScoreThreshold(const std::vector& values, + const float threshold, + std::vector* keep_values, + std::vector* keep_indices) { + for (int i = 0; i < values.size(); i++) { + if (values[i] >= threshold) { + keep_values->emplace_back(values[i]); + keep_indices->emplace_back(i); + } + } +} + +bool ValidateBoxes(const TfLiteTensor* decoded_boxes, const int num_boxes) { + for (int i = 0; i < num_boxes; ++i) { + // ymax>=ymin, xmax>=xmin + auto& box = ReInterpretTensor(decoded_boxes)[i]; + if (box.ymin >= box.ymax || box.xmin >= box.xmax) { + return false; + } + } + return true; +} + +float ComputeIntersectionOverUnion(const TfLiteTensor* decoded_boxes, + const int i, const int j) { + auto& box_i = ReInterpretTensor(decoded_boxes)[i]; + auto& box_j = ReInterpretTensor(decoded_boxes)[j]; + const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin); + const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin); + if (area_i <= 0 || area_j <= 0) return 0.0; + const float intersection_ymin = std::max(box_i.ymin, box_j.ymin); + const float intersection_xmin = std::max(box_i.xmin, box_j.xmin); + const float intersection_ymax = std::min(box_i.ymax, box_j.ymax); + const float intersection_xmax = std::min(box_i.xmax, box_j.xmax); + const float intersection_area = + std::max(intersection_ymax - intersection_ymin, 0.0) * + std::max(intersection_xmax - intersection_xmin, 0.0); + return intersection_area / (area_i + area_j - intersection_area); +} + +// NonMaxSuppressionSingleClass() is O(n^2) pairwise comparison between boxes +// It assumes all boxes are good in beginning and sorts based on the scores. +// If lower-scoring box has too much overlap with a higher-scoring box, +// we get rid of the lower-scoring box. +TfLiteStatus NonMaxSuppressionSingleClassHelper( + TfLiteContext* context, TfLiteNode* node, OpData* op_data, + const std::vector& scores, std::vector* selected) { + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + const TfLiteTensor* decoded_boxes = + &context->tensors[op_data->decoded_boxes_index]; + const int num_boxes = input_box_encodings->dims->data[1]; + const int max_detections = op_data->max_detections; + const float non_max_suppression_score_threshold = + op_data->non_max_suppression_score_threshold; + const float intersection_over_union_threshold = + op_data->intersection_over_union_threshold; + // Maximum detections should be positive. + TF_LITE_ENSURE(context, (max_detections >= 0)); + // intersection_over_union_threshold should be positive + // and should be less than 1. + TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) && + (intersection_over_union_threshold <= 1.0f)); + // Validate boxes + TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes)); + + // threshold scores + std::vector keep_indices; + // TODO (chowdhery): Remove the dynamic allocation and replace it + // with temporaries, esp for std::vector + std::vector keep_scores; + SelectDetectionsAboveScoreThreshold( + scores, non_max_suppression_score_threshold, &keep_scores, &keep_indices); + + int num_scores_kept = keep_scores.size(); + std::vector sorted_indices; + sorted_indices.resize(num_scores_kept); + DecreasingPartialArgSort(keep_scores.data(), num_scores_kept, num_scores_kept, + sorted_indices.data()); + + const int num_boxes_kept = num_scores_kept; + const int output_size = std::min(num_boxes_kept, max_detections); + selected->clear(); + TfLiteTensor* active_candidate = + &context->tensors[op_data->active_candidate_index]; + TF_LITE_ENSURE(context, (active_candidate->dims->data[0]) == num_boxes); + int num_active_candidate = num_boxes_kept; + uint8_t* active_box_candidate = (active_candidate->data.uint8); + for (int row = 0; row < num_boxes_kept; row++) { + active_box_candidate[row] = 1; + } + + for (int i = 0; i < num_boxes_kept; ++i) { + if (num_active_candidate == 0 || selected->size() >= output_size) break; + if (active_box_candidate[i] == 1) { + selected->push_back(keep_indices[sorted_indices[i]]); + active_box_candidate[i] = 0; + num_active_candidate--; + } else { + continue; + } + for (int j = i + 1; j < num_boxes_kept; ++j) { + if (active_box_candidate[j] == 1) { + float intersection_over_union = ComputeIntersectionOverUnion( + decoded_boxes, keep_indices[sorted_indices[i]], + keep_indices[sorted_indices[j]]); + + if (intersection_over_union > intersection_over_union_threshold) { + active_box_candidate[j] = 0; + num_active_candidate--; + } + } + } + } + return kTfLiteOk; +} + +// This function implements a fast version of Non Maximal Suppression for +// multiple classes where +// 1) we keep the top-k scores for each anchor and +// 2) during NMS, each anchor only uses the highest class score for sorting. +// 3) Compared to standard NMS, the worst runtime of this version is O(N^2) +// instead of O(KN^2) where N is the number of anchors and K the number of +// classes. +TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context, + TfLiteNode* node, + OpData* op_data, + const float* scores) { + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + const TfLiteTensor* decoded_boxes = + &context->tensors[op_data->decoded_boxes_index]; + + TfLiteTensor* detection_boxes = + GetOutput(context, node, kOutputTensorDetectionBoxes); + TfLiteTensor* detection_classes = + GetOutput(context, node, kOutputTensorDetectionClasses); + TfLiteTensor* detection_scores = + GetOutput(context, node, kOutputTensorDetectionScores); + TfLiteTensor* num_detections = + GetOutput(context, node, kOutputTensorNumDetections); + + const int num_boxes = input_box_encodings->dims->data[1]; + const int num_classes = op_data->num_classes; + const int max_categories_per_anchor = op_data->max_classes_per_detection; + // The row index offset is 1 if background class is included and 0 otherwise. + const int label_offset = 1; + TF_LITE_ENSURE(context, (label_offset != -1)); + TF_LITE_ENSURE(context, (max_categories_per_anchor > 0)); + const int num_classes_with_background = num_classes + label_offset; + const int num_categories_per_anchor = + std::min(max_categories_per_anchor, num_classes); + std::vector max_scores; + max_scores.resize(num_boxes); + std::vector sorted_class_indices; + sorted_class_indices.resize(num_boxes * num_classes); + for (int row = 0; row < num_boxes; row++) { + const float* box_scores = + scores + row * num_classes_with_background + label_offset; + int* class_indices = sorted_class_indices.data() + row * num_classes; + DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor, + class_indices); + max_scores[row] = box_scores[class_indices[0]]; + } + // Perform non-maximal suppression on max scores + std::vector selected; + NonMaxSuppressionSingleClassHelper(context, node, op_data, max_scores, + &selected); + // Allocate output tensors + int output_box_index = 0; + for (const auto& selected_index : selected) { + const float* box_scores = + scores + selected_index * num_classes_with_background + label_offset; + const int* class_indices = + sorted_class_indices.data() + selected_index * num_classes; + + for (int col = 0; col < num_categories_per_anchor; ++col) { + int box_offset = num_categories_per_anchor * output_box_index + col; + // detection_boxes + ReInterpretTensor(detection_boxes)[box_offset] = + ReInterpretTensor( + decoded_boxes)[selected_index]; + // detection_classes + detection_classes->data.f[box_offset] = class_indices[col]; + // detection_scores + detection_scores->data.f[box_offset] = box_scores[class_indices[col]]; + output_box_index++; + } + } + num_detections->data.f[0] = output_box_index; + return kTfLiteOk; +} + +void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions, + const int num_boxes, + const int num_classes_with_background, + const TfLiteTensor* scores) { + float quant_zero_point = + static_cast(input_class_predictions->params.zero_point); + float quant_scale = static_cast(input_class_predictions->params.scale); + Dequantizer dequantize(quant_zero_point, quant_scale); + const uint8* scores_quant = GetTensorData(input_class_predictions); + for (int idx = 0; idx < num_boxes * num_classes_with_background; ++idx) { + scores->data.f[idx] = dequantize(scores_quant[idx]); + } +} + +TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context, + TfLiteNode* node, OpData* op_data) { + // Get the input tensors + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + const TfLiteTensor* input_class_predictions = + GetInput(context, node, kInputTensorClassPredictions); + const int num_boxes = input_box_encodings->dims->data[1]; + const int num_classes = op_data->num_classes; + TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0], + kBatchSize); + TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[1], num_boxes); + const int num_classes_with_background = + input_class_predictions->dims->data[2]; + + TF_LITE_ENSURE(context, (num_classes_with_background == num_classes + 1)); + + const TfLiteTensor* scores; + switch (input_class_predictions->type) { + case kTfLiteUInt8: { + TfLiteTensor* temporary_scores = &context->tensors[op_data->scores_index]; + DequantizeClassPredictions(input_class_predictions, num_boxes, + num_classes_with_background, temporary_scores); + scores = temporary_scores; + } break; + case kTfLiteFloat32: + scores = input_class_predictions; + break; + default: + // Unsupported type. + return kTfLiteError; + } + NonMaxSuppressionMultiClassFastHelper(context, node, op_data, + GetTensorData(scores)); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + // TODO(chowdhery): Generalize for any batch size + TF_LITE_ENSURE(context, (kBatchSize == 1)); + auto* op_data = reinterpret_cast(node->user_data); + // These two functions correspond to two blocks in the Object Detection model. + // In future, we would like to break the custom op in two blocks, which is + // currently not feasible because we would like to input quantized inputs + // and do all calculations in float. Mixed quantized/float calculations are + // currently not supported in TFLite. + + // This fills in temporary decoded_boxes + // by transforming input_box_encodings and input_anchors from + // CenterSizeEncodings to BoxCornerEncoding + DecodeCenterSizeBoxes(context, node, op_data); + // This fills in the output tensors + // by choosing effective set of decoded boxes + // based on Non Maximal Suppression, i.e. selecting + // highest scoring non-overlapping boxes. + NonMaxSuppressionMultiClass(context, node, op_data); + + return kTfLiteOk; +} +} // namespace detection_postprocess + +TfLiteRegistration* Register_DETECTION_POSTPROCESS() { + static TfLiteRegistration r = {detection_postprocess::Init, + detection_postprocess::Free, + detection_postprocess::Prepare, + detection_postprocess::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e0f8484a328d7d1668afd096ad3d08204fbb4a1 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc @@ -0,0 +1,235 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_DETECTION_POSTPROCESS(); + +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +class BaseDetectionPostprocessOpModel : public SingleOpModel { + public: + BaseDetectionPostprocessOpModel(const TensorData& input1, + const TensorData& input2, + const TensorData& input3, + const TensorData& output1, + const TensorData& output2, + const TensorData& output3, + const TensorData& output4) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + input3_ = AddInput(input3); + output1_ = AddOutput(output1); + output2_ = AddOutput(output2); + output3_ = AddOutput(output3); + output4_ = AddOutput(output4); + + flexbuffers::Builder fbb; + fbb.Map([&]() { + fbb.Int("max_detections", 3); + fbb.Int("max_classes_per_detection", 1); + fbb.Float("nms_score_threshold", 0.0); + fbb.Float("nms_iou_threshold", 0.5); + fbb.Int("num_classes", 2); + fbb.Float("y_scale", 10.0); + fbb.Float("x_scale", 10.0); + fbb.Float("h_scale", 5.0); + fbb.Float("w_scale", 5.0); + }); + fbb.Finish(); + SetCustomOp("TFLite_Detection_PostProcess", fbb.GetBuffer(), + Register_DETECTION_POSTPROCESS); + BuildInterpreter({GetShape(input1_), GetShape(input2_), GetShape(input3_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + int input3() { return input3_; } + + template + void SetInput1(std::initializer_list data) { + PopulateTensor(input1_, data); + } + + template + void SetInput2(std::initializer_list data) { + PopulateTensor(input2_, data); + } + + template + void SetInput3(std::initializer_list data) { + PopulateTensor(input3_, data); + } + + template + std::vector GetOutput1() { + return ExtractVector(output1_); + } + + template + std::vector GetOutput2() { + return ExtractVector(output2_); + } + + template + std::vector GetOutput3() { + return ExtractVector(output3_); + } + + template + std::vector GetOutput4() { + return ExtractVector(output4_); + } + + std::vector GetOutputShape1() { return GetTensorShape(output1_); } + std::vector GetOutputShape2() { return GetTensorShape(output2_); } + std::vector GetOutputShape3() { return GetTensorShape(output3_); } + std::vector GetOutputShape4() { return GetTensorShape(output4_); } + + protected: + int input1_; + int input2_; + int input3_; + int output1_; + int output2_; + int output3_; + int output4_; +}; + +TEST(DetectionPostprocessOpTest, FloatTest) { + BaseDetectionPostprocessOpModel m( + {TensorType_FLOAT32, {1, 6, 4}}, {TensorType_FLOAT32, {1, 6, 3}}, + {TensorType_FLOAT32, {6, 4}}, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}); + + // six boxes in center-size encoding + m.SetInput1({0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}); + // class scores - two classes with background + m.SetInput2({0., .9, .8, 0., .75, .72, 0., .6, .5, 0., .93, .95, 0., + .5, .4, 0., .3, .2}); + // six anchors in center-size encoding + m.SetInput3({0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, + 0.5, 0.5, 1.0, 1.0, 0.5, 10.5, 1.0, 1.0, + 0.5, 10.5, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0}); + // Same boxes in box-corner encoding: + // { 0.0, 0.0, 1.0, 1.0, + // 0.0, 0.1, 1.0, 1.1, + // 0.0, -0.1, 1.0, 0.9, + // 0.0, 10.0, 1.0, 11.0, + // 0.0, 10.1, 1.0, 11.1, + // 0.0, 100.0, 1.0, 101.0} + m.Invoke(); + // detection_boxes + // in center-size + std::vector output_shape1 = m.GetOutputShape1(); + EXPECT_THAT(output_shape1, ElementsAre(1, 3, 4)); + EXPECT_THAT( + m.GetOutput1(), + ElementsAreArray(ArrayFloatNear( + {0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 1.0, 1.0, 0.0, 100.0, 1.0, 101.0}, + 1e-1))); + // detection_classes + std::vector output_shape2 = m.GetOutputShape2(); + EXPECT_THAT(output_shape2, ElementsAre(1, 3)); + EXPECT_THAT(m.GetOutput2(), + ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1))); + // detection_scores + std::vector output_shape3 = m.GetOutputShape3(); + EXPECT_THAT(output_shape3, ElementsAre(1, 3)); + EXPECT_THAT(m.GetOutput3(), + ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1))); + // num_detections + std::vector output_shape4 = m.GetOutputShape4(); + EXPECT_THAT(output_shape4, ElementsAre(1)); + EXPECT_THAT(m.GetOutput4(), + ElementsAreArray(ArrayFloatNear({3.0}, 1e-1))); +} + +TEST(DetectionPostprocessOpTest, QuantizedTest) { + BaseDetectionPostprocessOpModel m( + {TensorType_UINT8, {1, 6, 4}, -1.0, 1.0}, + {TensorType_UINT8, {1, 6, 3}, 0.0, 1.0}, + {TensorType_UINT8, {6, 4}, 0.0, 100.5}, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}); + // six boxes in center-size encoding + std::vector> inputs1 = { + {0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}; + m.QuantizeAndPopulate(m.input1(), inputs1[0]); + // class scores - two classes with background + std::vector> inputs2 = { + {0., .9, .8, 0., .75, .72, 0., .6, .5, 0., .93, .95, 0., .5, .4, 0., .3, + .2}}; + m.QuantizeAndPopulate(m.input2(), inputs2[0]); + // six anchors in center-size encoding + std::vector> inputs3 = { + {0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, + 0.5, 10.5, 1.0, 1.0, 0.5, 10.5, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0}}; + m.QuantizeAndPopulate(m.input3(), inputs3[0]); + m.Invoke(); + // detection_boxes + // in center-size + std::vector output_shape1 = m.GetOutputShape1(); + EXPECT_THAT(output_shape1, ElementsAre(1, 3, 4)); + EXPECT_THAT( + m.GetOutput1(), + ElementsAreArray(ArrayFloatNear( + {0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 1.0, 1.0, 0.0, 100.0, 1.0, 101.0}, + 3e-1))); + // detection_classes + std::vector output_shape2 = m.GetOutputShape2(); + EXPECT_THAT(output_shape2, ElementsAre(1, 3)); + EXPECT_THAT(m.GetOutput2(), + ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1))); + // detection_scores + std::vector output_shape3 = m.GetOutputShape3(); + EXPECT_THAT(output_shape3, ElementsAre(1, 3)); + EXPECT_THAT(m.GetOutput3(), + ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1))); + // num_detections + std::vector output_shape4 = m.GetOutputShape4(); + EXPECT_THAT(output_shape4, ElementsAre(1)); + EXPECT_THAT(m.GetOutput4(), + ElementsAreArray(ArrayFloatNear({3.0}, 1e-1))); +} +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc index d264821e30cf622ff5d3d8ad513add46caa9e7ae..d7420ddd8e41a57c901527884e942d444e543aa6 100644 --- a/tensorflow/contrib/lite/kernels/div.cc +++ b/tensorflow/contrib/lite/kernels/div.cc @@ -78,29 +78,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } template -void EvalFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteDivParams* params, const OpData* data, - const TfLiteTensor* input1, const TfLiteTensor* input2, - TfLiteTensor* output) { - float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); -#define TF_LITE_DIV(type, opname) \ - type::opname(GetTensorData(input1), GetTensorDims(input1), \ - GetTensorData(input2), GetTensorDims(input2), \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)) - if (kernel_type == kReference) { - if (data->requires_broadcast) { - TF_LITE_DIV(reference_ops, BroadcastDiv); +void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params, + const OpData* data, const TfLiteTensor* input1, + const TfLiteTensor* input2, TfLiteTensor* output) { +#define TF_LITE_DIV(type, opname, data_type) \ + data_type output_activation_min, output_activation_max; \ + CalculateActivationRange(params->activation, &output_activation_min, \ + &output_activation_max); \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (output->type == kTfLiteInt32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_DIV(reference_ops, BroadcastDiv, int32_t); + } else { + TF_LITE_DIV(reference_ops, Div, int32_t); + } } else { - TF_LITE_DIV(reference_ops, Div); + if (data->requires_broadcast) { + TF_LITE_DIV(optimized_ops, BroadcastDiv, int32_t); + } else { + TF_LITE_DIV(optimized_ops, Div, int32_t); + } } - } else { - if (data->requires_broadcast) { - TF_LITE_DIV(optimized_ops, BroadcastDiv); + } else if (output->type == kTfLiteFloat32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_DIV(reference_ops, BroadcastDiv, float); + } else { + TF_LITE_DIV(reference_ops, Div, float); + } } else { - TF_LITE_DIV(optimized_ops, Div); + if (data->requires_broadcast) { + TF_LITE_DIV(optimized_ops, BroadcastDiv, float); + } else { + TF_LITE_DIV(optimized_ops, Div, float); + } } } #undef TF_LITE_DIV @@ -115,11 +130,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - if (output->type == kTfLiteFloat32) { - EvalFloat(context, node, params, data, input1, input2, output); + if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) { + EvalDiv(context, node, params, data, input1, input2, output); } else { context->ReportError( - context, "Div only supports FLOAT32 and quantized UINT8 now, got %d.", + context, + "Div only supports FLOAT32, INT32 and quantized UINT8 now, got %d.", output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/div_test.cc b/tensorflow/contrib/lite/kernels/div_test.cc index 276b8289fbc1b4dcbf4624b76b854300d0fd4912..97aa2fe04e27416b99f48ab61ece54b745597ae3 100644 --- a/tensorflow/contrib/lite/kernels/div_test.cc +++ b/tensorflow/contrib/lite/kernels/div_test.cc @@ -52,6 +52,13 @@ class FloatDivOpModel : public BaseDivOpModel { std::vector GetOutput() { return ExtractVector(output_); } }; +class IntegerDivOpModel : public BaseDivOpModel { + public: + using BaseDivOpModel::BaseDivOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + TEST(FloatDivOpTest, NoActivation) { FloatDivOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, @@ -75,7 +82,7 @@ TEST(FloatDivOpTest, ActivationRELU_N1_TO_1) { } TEST(FloatDivOpTest, VariousInputShapes) { - std::vector> test_shapes = { + std::vector> test_shapes = { {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { FloatDivOpModel m({TensorType_FLOAT32, test_shapes[i]}, @@ -92,7 +99,7 @@ TEST(FloatDivOpTest, VariousInputShapes) { } TEST(FloatDivOpTest, WithBroadcast) { - std::vector> test_shapes = { + std::vector> test_shapes = { {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { FloatDivOpModel m({TensorType_FLOAT32, test_shapes[i]}, @@ -108,6 +115,56 @@ TEST(FloatDivOpTest, WithBroadcast) { } } +TEST(IntegerDivOpTest, NoActivation) { + IntegerDivOpModel m({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, + ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2, 2, -15, 8}); + m.PopulateTensor(m.input2(), {5, -2, -3, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, -1, 5, 1})); +} + +TEST(IntegerDivOpTest, ActivationRELU_N1_TO_1) { + IntegerDivOpModel m({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, + ActivationFunctionType_RELU_N1_TO_1); + m.PopulateTensor(m.input1(), {-2, 2, -12, 8}); + m.PopulateTensor(m.input2(), {1, 2, -15, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 1, 0, 1})); +} + +TEST(IntegerDivOpTest, VariousInputShapes) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + IntegerDivOpModel m({TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 3, 8, 11, -20}); + m.PopulateTensor(m.input2(), {1, 2, 6, 5, -11, -1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-20, 1, 0, 1, -1, 20})) + << "With shape number " << i; + } +} + +TEST(IntegerDivOpTest, WithBroadcast) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + IntegerDivOpModel m({TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, {}}, // always a scalar + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 21, 7, 8, 11, -123}); + m.PopulateTensor(m.input2(), {3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-6, 7, 2, 2, 3, -41})) + << "With shape number " << i; + } +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc index 97f5e908240d76a88b3963c4d14aa3a6f1ea313c..e542ad076528fa30152abba074a5c7dcd6ca1f48 100644 --- a/tensorflow/contrib/lite/kernels/eigen_support.cc +++ b/tensorflow/contrib/lite/kernels/eigen_support.cc @@ -14,31 +14,100 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/kernels/eigen_support.h" -#include "Eigen/Core" +#include + +#include "tensorflow/contrib/lite/arena_planner.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { namespace eigen_support { +namespace { + +#ifndef EIGEN_DONT_ALIGN +// Eigen may require buffers to be algiend to 16, 32 or 64 bytes depending on +// hardware architecture and build configurations. +// If the static assertion fails, try to increase `kDefaultTensorAlignment` to +// in `arena_planner.h` to 32 or 64. +static_assert( + kDefaultTensorAlignment % EIGEN_MAX_ALIGN_BYTES == 0, + "kDefaultArenaAlignment doesn't comply with Eigen alignment requirement."); +#endif // EIGEN_DONT_ALIGN + +// We have a single global threadpool for all convolution operations. This means +// that inferences started from different threads may block each other, but +// since the underlying resource of CPU cores should be consumed by the +// operations anyway, it shouldn't affect overall performance. +class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { + public: + // Takes ownership of 'pool' + explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {} + ~EigenThreadPoolWrapper() override {} -struct RefCountedEigenContext { + void Schedule(std::function fn) override { + pool_->Schedule(std::move(fn)); + } + int NumThreads() const override { return pool_->NumThreads(); } + int CurrentThreadId() const override { return pool_->CurrentThreadId(); } + + private: + std::unique_ptr pool_; +}; + +struct RefCountedEigenContext : public TfLiteExternalContext { + std::unique_ptr thread_pool_wrapper; + std::unique_ptr device; int num_references = 0; }; +RefCountedEigenContext* GetEigenContext(TfLiteContext* context) { + return reinterpret_cast( + context->GetExternalContext(context, kTfLiteEigenContext)); +} + +void InitDevice(TfLiteContext* context, RefCountedEigenContext* ptr) { + int num_threads = 4; + if (context->recommended_num_threads != -1) { + num_threads = context->recommended_num_threads; + } + ptr->device.reset(); // destroy before we invalidate the thread pool + ptr->thread_pool_wrapper.reset( + new EigenThreadPoolWrapper(new Eigen::ThreadPool(num_threads))); + ptr->device.reset( + new Eigen::ThreadPoolDevice(ptr->thread_pool_wrapper.get(), num_threads)); +} + +TfLiteStatus Refresh(TfLiteContext* context) { + Eigen::setNbThreads(context->recommended_num_threads); + + auto* ptr = GetEigenContext(context); + if (ptr != nullptr) { + InitDevice(context, ptr); + } + + return kTfLiteOk; +} + +} // namespace + void IncrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast(context->eigen_context); + auto* ptr = GetEigenContext(context); if (ptr == nullptr) { if (context->recommended_num_threads != -1) { Eigen::setNbThreads(context->recommended_num_threads); } ptr = new RefCountedEigenContext; + ptr->type = kTfLiteEigenContext; + ptr->Refresh = Refresh; ptr->num_references = 0; - context->eigen_context = ptr; + InitDevice(context, ptr); + context->SetExternalContext(context, kTfLiteEigenContext, ptr); } ptr->num_references++; } void DecrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast(context->eigen_context); + auto* ptr = GetEigenContext(context); if (ptr == nullptr) { TF_LITE_FATAL( "Call to DecrementUsageCounter() not preceded by " @@ -46,14 +115,17 @@ void DecrementUsageCounter(TfLiteContext* context) { } if (--ptr->num_references == 0) { delete ptr; - context->eigen_context = nullptr; + context->SetExternalContext(context, kTfLiteEigenContext, nullptr); } } -void SetNumThreads(TfLiteContext* context, int num_threads) { - IncrementUsageCounter(context); - Eigen::setNbThreads(num_threads); - DecrementUsageCounter(context); +const Eigen::ThreadPoolDevice* GetThreadPoolDevice(TfLiteContext* context) { + auto* ptr = GetEigenContext(context); + if (ptr == nullptr) { + TF_LITE_FATAL( + "Call to GetFromContext() not preceded by IncrementUsageCounter()"); + } + return ptr->device.get(); } } // namespace eigen_support diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h index aa8c351fd8e8dae45f7d4807ce24d80bb393c41c..ec77856b1054e85c405193c6f44dc6e74b58a645 100644 --- a/tensorflow/contrib/lite/kernels/eigen_support.h +++ b/tensorflow/contrib/lite/kernels/eigen_support.h @@ -17,6 +17,10 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" +namespace EigenForTFLite { +class ThreadPoolDevice; +} + namespace tflite { namespace eigen_support { @@ -28,8 +32,8 @@ void IncrementUsageCounter(TfLiteContext* context); // usages all temporary Eigen objects will be deleted. void DecrementUsageCounter(TfLiteContext* context); -// Set the number of threads that can be used by Eigen. -void SetNumThreads(TfLiteContext* context, int num_threads); +const EigenForTFLite::ThreadPoolDevice* GetThreadPoolDevice( + TfLiteContext* context); } // namespace eigen_support } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc index 98c21ce9d390aaa1f3cb5fdb8f31cbffb1b81d6a..e19779ea59d441984d3562508e4237e10ce17515 100644 --- a/tensorflow/contrib/lite/kernels/elementwise.cc +++ b/tensorflow/contrib/lite/kernels/elementwise.cc @@ -22,59 +22,118 @@ namespace tflite { namespace ops { namespace builtin { namespace elementwise { +namespace { +bool IsNumericSupportedType(const TfLiteType type) { + return type == kTfLiteFloat32; +} + +bool IsLogicalSupportedType(const TfLiteType type) { + return type == kTfLiteBool; +} + +typedef bool (*IsSupportedType)(TfLiteType); +template TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); TF_LITE_ENSURE_EQ(context, input->type, output->type); - // Quantized float is not supported yet. - TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + if (!IsSupportedType(input->type)) { + context->ReportError(context, "Current data type %d is not supported.", + input->type); + return kTfLiteError; + } return context->ResizeTensor(context, output, TfLiteIntArrayCopy(input->dims)); } -inline TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, - float float_func(float)) { +template +inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node, + T func(T), TfLiteType expected_type) { const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); - switch (input->type) { - case kTfLiteFloat32: { - size_t elements = NumElements(input); - const float* in = GetTensorData(input); - const float* in_end = in + elements; - float* out = output->data.f; - for (; in < in_end; in++, out++) *out = float_func(*in); - return kTfLiteOk; - } - default: { - context->ReportError(context, "Input type is %d, requires float32", - input->type); - return kTfLiteError; - } + TF_LITE_ENSURE_EQ(context, input->type, expected_type); + const int64_t num_elements = NumElements(input); + const T* in_data = GetTensorData(input); + T* out_data = GetTensorData(output); + for (int64_t i = 0; i < num_elements; ++i) { + out_data[i] = func(in_data[i]); } + return kTfLiteOk; +} + +inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node, + float float_func(float)) { + return EvalImpl(context, node, float_func, kTfLiteFloat32); +} + +inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node, + bool bool_func(bool)) { + return EvalImpl(context, node, bool_func, kTfLiteBool); } TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { - return Eval(context, node, std::sin); + return EvalNumeric(context, node, std::sin); } TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) { - return Eval(context, node, std::log); + return EvalNumeric(context, node, std::log); +} + +TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) { + return EvalNumeric(context, node, std::sqrt); +} + +TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) { + return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); }); +} + +TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) { + return EvalLogical(context, node, [](bool v) { return !v; }); } +} // namespace } // namespace elementwise TfLiteRegistration* Register_SIN() { - static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, - elementwise::SinEval}; + static TfLiteRegistration r = { + /*init=*/nullptr, /*free=*/nullptr, + elementwise::GenericPrepare, + elementwise::SinEval}; return &r; } TfLiteRegistration* Register_LOG() { - static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, - elementwise::LogEval}; + static TfLiteRegistration r = { + /*init=*/nullptr, /*free=*/nullptr, + elementwise::GenericPrepare, + elementwise::LogEval}; + return &r; +} + +TfLiteRegistration* Register_SQRT() { + static TfLiteRegistration r = { + /*init=*/nullptr, /*free=*/nullptr, + elementwise::GenericPrepare, + elementwise::SqrtEval}; + return &r; +} + +TfLiteRegistration* Register_RSQRT() { + static TfLiteRegistration r = { + /*init=*/nullptr, /*free=*/nullptr, + elementwise::GenericPrepare, + elementwise::RsqrtEval}; + return &r; +} + +TfLiteRegistration* Register_LOGICAL_NOT() { + static TfLiteRegistration r = { + /*init=*/nullptr, /*free=*/nullptr, + elementwise::GenericPrepare, + elementwise::LogicalNotEval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc index 10e88d5a31868eeb5f65c7ade1f1c73827dea24a..b9d7d73c52862da9166f6881b1e27a6ff6b76bbc 100644 --- a/tensorflow/contrib/lite/kernels/elementwise_test.cc +++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc @@ -24,26 +24,40 @@ namespace { using ::testing::ElementsAreArray; -class ElementWiseOpModel : public SingleOpModel { +class ElementWiseOpBaseModel : public SingleOpModel { public: - ElementWiseOpModel(BuiltinOperator op, - std::initializer_list input_shape) { + int input() const { return input_; } + int output() const { return output_; } + + protected: + int input_; + int output_; +}; + +class ElementWiseOpFloatModel : public ElementWiseOpBaseModel { + public: + ElementWiseOpFloatModel(BuiltinOperator op, + std::initializer_list input_shape) { input_ = AddInput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(op, BuiltinOptions_NONE, 0); BuildInterpreter({input_shape}); } +}; - int input() const { return input_; } - int output() const { return output_; } - - private: - int input_; - int output_; +class ElementWiseOpBoolModel : public ElementWiseOpBaseModel { + public: + ElementWiseOpBoolModel(BuiltinOperator op, + std::initializer_list input_shape) { + input_ = AddInput(TensorType_BOOL); + output_ = AddOutput(TensorType_BOOL); + SetBuiltinOp(op, BuiltinOptions_NONE, 0); + BuildInterpreter({input_shape}); + } }; TEST(ElementWise, Sin) { - ElementWiseOpModel m(BuiltinOperator_SIN, {1, 1, 4, 1}); + ElementWiseOpFloatModel m(BuiltinOperator_SIN, {1, 1, 4, 1}); m.PopulateTensor(m.input(), {0, 3.1415926, -3.1415926, 1}); m.Invoke(); EXPECT_THAT(m.ExtractVector(m.output()), @@ -52,7 +66,7 @@ TEST(ElementWise, Sin) { } TEST(ElementWise, Log) { - ElementWiseOpModel m(BuiltinOperator_LOG, {1, 1, 4, 1}); + ElementWiseOpFloatModel m(BuiltinOperator_LOG, {1, 1, 4, 1}); m.PopulateTensor(m.input(), {1, 3.1415926, 1, 1}); m.Invoke(); EXPECT_THAT(m.ExtractVector(m.output()), @@ -60,6 +74,33 @@ TEST(ElementWise, Log) { EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); } +TEST(ElementWise, Sqrt) { + ElementWiseOpFloatModel m(BuiltinOperator_SQRT, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {0, 1, 2, 4}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear({0, 1, 1.41421, 2}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + +TEST(ElementWise, Rsqrt) { + ElementWiseOpFloatModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {1, 2, 4, 9}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear({1, 0.7071, 0.5, 0.33333}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + +TEST(ElementWise, LogicalNot) { + ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {true, false, true, false}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray({false, true, false, true})); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc index 9410bead5e7a68363d034c22fb2c0eff9f060ef1..b2dff87e6296c6038241c704d9158e174501f026 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc @@ -29,7 +29,6 @@ limitations under the License. // When indices are out of bound, the ops will not succeed. // -#include #include #include #include @@ -94,7 +93,7 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* lookup, const TfLiteTensor* value, TfLiteTensor* output) { const int row_size = SizeOfDimension(value, 0); - const double scaling_factor = 1.0 / value->params.scale; + const double scaling_factor = value->params.scale; // col_size after we flatten tensor into 2D. int col_size = 1; @@ -112,8 +111,9 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, // TODO(alanchiao): refactor scalar multiply into separate function // for ease of adding a neon equivalent if ever necessary. for (int j = 0; j < col_size; j++) { + const int8_t* value_ptr = reinterpret_cast(value->data.uint8); output->data.f[j + i * col_size] = - value->data.uint8[j + idx * col_size] * scaling_factor; + value_ptr[j + idx * col_size] * scaling_factor; } } } diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc index 04657fd86323ef1c58d069c06097c7665f55cc87..4a88d168c60203f10802e634def9b1d1316c9c6d 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc @@ -107,9 +107,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) { HybridEmbeddingLookupOpModel m({3}, {3, 8}); m.SetInput({1, 0, 2}); m.SetWeight({ - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }); m.Invoke(); @@ -117,9 +117,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( { - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }, 7.41e-03))); } @@ -128,9 +128,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) { HybridEmbeddingLookupOpModel m({3}, {3, 2, 4}); m.SetInput({1, 0, 2}); m.SetWeight({ - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }); m.Invoke(); @@ -138,9 +138,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( { - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }, 7.41e-03))); } @@ -149,9 +149,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) { HybridEmbeddingLookupOpModel m({3}, {3, 2, 2, 2}); m.SetInput({1, 0, 2}); m.SetWeight({ - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }); m.Invoke(); @@ -159,9 +159,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( { - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }, 7.41e-03))); } diff --git a/tensorflow/contrib/lite/kernels/expand_dims_test.cc b/tensorflow/contrib/lite/kernels/expand_dims_test.cc index b755e8ce293442813b26ec3177162a3c95af2f89..50dc860e5a83f185abc70a844abdbc974f7bc4e7 100644 --- a/tensorflow/contrib/lite/kernels/expand_dims_test.cc +++ b/tensorflow/contrib/lite/kernels/expand_dims_test.cc @@ -39,7 +39,7 @@ class ExpandDimsOpModel : public SingleOpModel { void SetInputFloat(std::initializer_list data) { PopulateTensor(input_, data); } - void SetAxis(int axis) { PopulateTensor(axis_, {axis}); } + void SetAxis(int axis) { PopulateTensor(axis_, {axis}); } std::vector GetValuesFloat() { return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } @@ -51,7 +51,7 @@ class ExpandDimsOpModel : public SingleOpModel { TEST(ExpandDimsOpTest, DifferentAxis) { ExpandDimsOpModel m({2, 2}, TensorType_FLOAT32); - const auto values = {-1.f, 1.f, -2.f, 2.f}; + std::initializer_list values = {-1.f, 1.f, -2.f, 2.f}; m.SetInputFloat(values); m.SetAxis(0); m.Invoke(); diff --git a/tensorflow/contrib/lite/kernels/fake_quant.cc b/tensorflow/contrib/lite/kernels/fake_quant.cc new file mode 100644 index 0000000000000000000000000000000000000000..0ef1a50b308b2e8a781bc9ed7195c22e627ea2de --- /dev/null +++ b/tensorflow/contrib/lite/kernels/fake_quant.cc @@ -0,0 +1,92 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace fake_quant { + +// This file has reference implementation of FakeQuant. +enum KernelType { + kReference, +}; + +struct OpContext { + OpContext(TfLiteContext* context, TfLiteNode* node) { + input = GetInput(context, node, 0); + output = GetOutput(context, node, 0); + } + const TfLiteTensor* input; + TfLiteTensor* output; +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const auto* params = + reinterpret_cast(node->builtin_data); + + if (params->narrow_range) { + context->ReportError( + context, + "narrow_range FakeQuant is not currently supported at runtime. " + "narrow_range is only meant to be applied to weights, not activations"); + return kTfLiteError; + } + + OpContext op_context(context, node); + TfLiteIntArray* output_dims = TfLiteIntArrayCopy(op_context.input->dims); + op_context.output->type = op_context.input->type; + return context->ResizeTensor(context, op_context.output, output_dims); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); + + const auto* params = + reinterpret_cast(node->builtin_data); + + reference_ops::FakeQuant(GetTensorData(op_context.input), + GetTensorDims(op_context.input), params->min, + params->max, params->num_bits, + GetTensorData(op_context.output), + GetTensorDims(op_context.output)); + + return kTfLiteOk; +} + +} // namespace fake_quant + +TfLiteRegistration* Register_FAKE_QUANT_REF() { + static TfLiteRegistration r = {nullptr, nullptr, fake_quant::Prepare, + fake_quant::Eval}; + return &r; +} + +TfLiteRegistration* Register_FAKE_QUANT() { return Register_FAKE_QUANT_REF(); } + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/fake_quant_test.cc b/tensorflow/contrib/lite/kernels/fake_quant_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..11a02f7ed7474e05b887955c111179d2d403f0e6 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/fake_quant_test.cc @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class FakeQuantOpModel : public SingleOpModel { + public: + FakeQuantOpModel(const TensorData& input, const TensorType& output, float min, + float max, int num_bits) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_FAKE_QUANT, BuiltinOptions_FakeQuantOptions, + CreateFakeQuantOptions(builder_, min, max, num_bits).Union()); + BuildInterpreter({GetShape(input_)}); + } + + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int output_; +}; + +TEST(FakeQuantOpTest, FloatPositiveRange8Test) { + std::initializer_list data = {0.0, 1.0, 0.25, + 0.50, 0.4444444, 0.00001}; + FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, 0.0f, + 1.0f, 8); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({0, 1, 0.25098, 0.498039, 0.443137, 0}))); +} + +TEST(FakeQuantOpTest, FloatNegativeRange8Test) { + std::initializer_list data = {0.0, -0.9, 0.25, + 0.50, 0.4444444, -0.00001}; + FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, -0.9f, + 0.9f, 8); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0, -0.896471, 0.247059, 0.501176, 0.444706, 0}))); +} + +TEST(FakeQuantOpTest, FloatPositiveRange16Test) { + std::initializer_list data = {0.0, 1.0, 0.25, + 0.50, 0.4444444, 0.00001}; + FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, 0.0f, + 1.0f, 16); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0, 1, 0.250004, 0.500008, 0.44445, 1.5259e-05}))); +} + +TEST(FakeQuantOpTest, FloatNegativeRange16Test) { + std::initializer_list data = {0.0, -0.9, 0.25, + 0.50, 0.4444444, -0.00001}; + FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, -0.9f, + 0.9f, 16); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0, -0.900014, 0.249998, 0.499995, 0.444431, 0}))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc index f6fc0f5b6ad12d58c541efc6eae566ab4b8327f4..eaf5a67d6787b9113bd0835d436b459e00ed7fff 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -63,6 +62,7 @@ constexpr int kInputTensor = 0; constexpr int kWeightsTensor = 1; constexpr int kBiasTensor = 2; constexpr int kOutputTensor = 0; +constexpr int kShuffledInputWorkspaceTensor = 1; constexpr int kScratchBufferTensor = 1; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -70,7 +70,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { // Instead, we allocate a new object to carry information from Prepare() to // Eval(). gemm_support::IncrementUsageCounter(context); - auto* op_data = new OpData; + auto* op_data = new OpData(); context->AddTensors(context, 1, &op_data->input_quantized_index); return op_data; } @@ -87,7 +87,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Check we have all the inputs and outputs we need. TF_LITE_ENSURE_EQ(context, node->inputs->size, 3); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + // Shuffled formats need a workspace to store the shuffled input activations. + const int expected_outputs_count = + params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault ? 1 + : 2; + TF_LITE_ENSURE_EQ(context, node->outputs->size, expected_outputs_count); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); @@ -117,13 +121,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { double real_multiplier = 0.0; TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); - TF_LITE_ENSURE(context, real_multiplier < 1.0); - QuantizeMultiplierSmallerThanOneExp( - real_multiplier, &data->output_multiplier, &data->output_shift); - data->output_shift *= -1; - CalculateActivationRangeUint8(params->activation, output, - &data->output_activation_min, - &data->output_activation_max); + int exponent; + QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); + data->output_shift = -exponent; + TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized( + context, params->activation, output, &data->output_activation_min, + &data->output_activation_max)); } // If we have to perform on-the-fly quantization (with quantized weights and @@ -278,44 +281,101 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, int32_t input_offset = -input->params.zero_point; int32_t filter_offset = -filter->params.zero_point; int32_t output_offset = output->params.zero_point; -#define TF_LITE_FULLY_CONNECTED(type) \ +#define TF_LITE_FULLY_CONNECTED(type, output_data_type) \ type::FullyConnected( \ GetTensorData(input), GetTensorDims(input), input_offset, \ GetTensorData(filter), GetTensorDims(filter), filter_offset, \ GetTensorData(bias), GetTensorDims(bias), output_offset, \ data->output_multiplier, data->output_shift, \ data->output_activation_min, data->output_activation_max, \ - GetTensorData(output), GetTensorDims(output), gemm_context) + GetTensorData(output), GetTensorDims(output), \ + gemm_context) if (kernel_type == kReference) { - TF_LITE_FULLY_CONNECTED(reference_ops); - } else if (kernel_type == kPie) { - if (input->type == kTfLiteFloat32) { - // Pie currently only supports quantized models and float inputs/outputs. - TfLiteTensor* input_quantized = - &context->tensors[node->temporaries->data[0]]; - return EvalPieQuantized(context, node, params, data, input, filter, bias, - input_quantized, output); - } else { - // TODO(ahentz): we don't have a quantized version of the PIE kernels, so - // we just defer to the MINI ones. - TF_LITE_FULLY_CONNECTED(optimized_ops); + switch (output->type) { + case kTfLiteUInt8: + TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t); + break; + case kTfLiteInt16: + TF_LITE_FULLY_CONNECTED(reference_ops, int16_t); + break; + default: + context->ReportError( + context, + "Quantized FullyConnected expects output data type uint8 or int16"); + return kTfLiteError; } + } else if (kernel_type == kPie && input->type == kTfLiteFloat32) { + // Pie currently only supports quantized models and float inputs/outputs. + TfLiteTensor* input_quantized = + &context->tensors[node->temporaries->data[0]]; + return EvalPieQuantized(context, node, params, data, input, filter, bias, + input_quantized, output); } else { - TF_LITE_FULLY_CONNECTED(optimized_ops); + switch (output->type) { + case kTfLiteUInt8: + TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t); + break; + case kTfLiteInt16: + TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t); + break; + default: + context->ReportError( + context, + "Quantized FullyConnected expects output data type uint8 or int16"); + return kTfLiteError; + } } #undef TF_LITE_FULLY_CONNECTED return kTfLiteOk; } +template +TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, + OpData* data, const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, + TfLiteTensor* output, + TfLiteTensor* shuffled_input_workspace) { + gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); + + // TODO(b/110697972) decide more consistently if / how / where we want + // to perform this kind of runtime data type checks. + if (input->type != kTfLiteUInt8 || filter->type != kTfLiteUInt8 || + bias->type != kTfLiteInt32 || output->type != kTfLiteInt16 || + shuffled_input_workspace->type != kTfLiteUInt8) { + context->ReportError(context, "Unexpected data type"); + return kTfLiteError; + } + +#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \ + type::ShuffledFullyConnected( \ + GetTensorData(input), GetTensorDims(input), \ + GetTensorData(filter), GetTensorDims(filter), \ + GetTensorData(bias), GetTensorDims(bias), \ + data->output_multiplier, data->output_shift, \ + data->output_activation_min, data->output_activation_max, \ + GetTensorData(output), GetTensorDims(output), \ + GetTensorData(shuffled_input_workspace), gemm_context) + if (kernel_type == kReference) { + TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops); + } else { + TF_LITE_SHUFFLED_FULLY_CONNECTED(optimized_ops); + } +#undef TF_LITE_SHUFFLED_FULLY_CONNECTED + + return kTfLiteOk; +} + template TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, TfLiteFullyConnectedParams* params, OpData* data, const TfLiteTensor* input, const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output) { float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); #define TF_LITE_FULLY_CONNECTED(type) \ type::FullyConnected(GetTensorData(input), GetTensorDims(input), \ GetTensorData(filter), GetTensorDims(filter), \ @@ -352,8 +412,22 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return EvalFloat(context, node, params, data, input, filter, bias, output); case kTfLiteUInt8: - return EvalQuantized(context, node, params, data, input, - filter, bias, output); + if (params->weights_format == + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) { + TfLiteTensor* shuffled_input_workspace = + GetOutput(context, node, kShuffledInputWorkspaceTensor); + return EvalShuffledQuantized(context, node, params, data, + input, filter, bias, output, + shuffled_input_workspace); + } else if (params->weights_format == + kTfLiteFullyConnectedWeightsFormatDefault) { + return EvalQuantized(context, node, params, data, input, + filter, bias, output); + } else { + context->ReportError(context, + "Unhandled fully-connected weights format"); + return kTfLiteError; + } default: context->ReportError(context, "Type %d not currently supported.", filter->type); diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc index 05dd028b484c09bdf90a09fab1238f48e8a9ddab..08b43209466a1b85613ae41d5aa776194f992c60 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected_test.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc @@ -15,6 +15,7 @@ limitations under the License. // Unit test for TFLite FULLY_CONNECTED op. #include +#include #include #include @@ -133,9 +134,12 @@ static float fully_connected_golden_output[] = { class BaseFullyConnectedOpModel : public SingleOpModel { public: // TODO(ahentz): test different activation types too. - BaseFullyConnectedOpModel(TfLiteRegistration* registration, int units, - int batches, const TensorData& input, - const TensorData& output = {TensorType_FLOAT32}) + BaseFullyConnectedOpModel( + TfLiteRegistration* registration, int units, int batches, + const TensorData& input, const TensorData& output = {TensorType_FLOAT32}, + ActivationFunctionType activation_func = ActivationFunctionType_RELU, + FullyConnectedOptionsWeightsFormat weights_format = + FullyConnectedOptionsWeightsFormat_DEFAULT) : batches_(batches), units_(units) { int total_input_size = 1; for (int i = 0; i < input.shape.size(); ++i) { @@ -159,10 +163,13 @@ class BaseFullyConnectedOpModel : public SingleOpModel { } output_ = AddOutput(output); + if (weights_format != FullyConnectedOptionsWeightsFormat_DEFAULT) { + AddOutput({TensorType_UINT8, input.shape}); + } SetBuiltinOp( BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions, - CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU) + CreateFullyConnectedOptions(builder_, activation_func, weights_format) .Union()); resolver_ = absl::make_unique( BuiltinOperator_FULLY_CONNECTED, registration); @@ -188,13 +195,11 @@ class FloatFullyConnectedOpModel : public BaseFullyConnectedOpModel { public: using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel; - void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + void SetBias(const std::vector& f) { PopulateTensor(bias_, f); } - void SetWeights(std::initializer_list f) { - PopulateTensor(weights_, f); - } + void SetWeights(const std::vector& f) { PopulateTensor(weights_, f); } - void SetInput(std::initializer_list data) { + void SetInput(const std::vector& data) { PopulateTensor(input_, data); } void SetInput(int offset, float* begin, float* end) { @@ -208,20 +213,50 @@ class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel { public: using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel; - void SetBias(std::initializer_list data) { + void SetBias(const std::vector& data) { QuantizeAndPopulate(bias_, data); } - void SetWeights(std::initializer_list data) { + void SetWeights(const std::vector& data) { QuantizeAndPopulate(weights_, data); } - void SetInput(std::initializer_list data) { + void ShuffleAndSetWeights(const std::vector& data, int input_depth, + int output_depth) { + std::vector shuffled_data(data.size()); + CHECK_EQ(input_depth % 16, 0); + CHECK_EQ(output_depth % 4, 0); + float* shuffled_data_ptr = shuffled_data.data(); + for (int block_o = 0; block_o < output_depth; block_o += 4) { + for (int block_i = 0; block_i < input_depth; block_i += 16) { + for (int o = 0; o < 4; o++) { + for (int i = 0; i < 16; i++) { + *shuffled_data_ptr++ = + data[(block_o + o) * input_depth + block_i + i]; + } + } + } + } + TfLiteTensor* t = interpreter_->tensor(weights_); + auto quantized_data = + Quantize(shuffled_data, t->params.scale, t->params.zero_point); + for (uint8_t& q : quantized_data) { + q ^= 0x80; + } + PopulateTensor(weights_, 0, quantized_data.data(), + quantized_data.data() + quantized_data.size()); + } + void SetInput(const std::vector& data) { QuantizeAndPopulate(input_, data); } - std::vector GetOutput() { return ExtractVector(output_); } + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + template std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); } }; @@ -256,12 +291,12 @@ class HybridFullyConnectedOpModel : public SingleOpModel { ops::builtin::Register_FULLY_CONNECTED_PIE()); BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)}); } - void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } - void SetWeights(std::initializer_list data) { + void SetBias(const std::vector& f) { PopulateTensor(bias_, f); } + void SetWeights(const std::vector& data) { SymmetricQuantizeAndPopulate(weights_, data); } - void SetInput(std::initializer_list f) { PopulateTensor(input_, f); } + void SetInput(const std::vector& f) { PopulateTensor(input_, f); } std::vector GetOutput() { return ExtractVector(output_); } int input_size() { return input_size_; } @@ -340,6 +375,24 @@ TEST_P(FloatFullyConnectedOpTest, SimpleTest) { EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60)); } +TEST_P(FloatFullyConnectedOpTest, SimpleTest2) { + FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/1, /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {2, 2}}); + m.SetWeights({ + 2, 4, // u = 0 + }); + m.SetBias({1}); + + m.SetInput({ + 1, 2, // b = 0 + 2, 1, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(11, 9)); +} + TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) { QuantizedFullyConnectedOpModel m( GetRegistration(), /*units=*/3, /*batches*/ 2, @@ -350,7 +403,38 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) { m.SetWeights({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // + 58, 59, 60, // + }))); + EXPECT_THAT(m.GetOutput(), + ElementsAre(151, 152, 153, 185, 186, 187)); +} + +TEST_P(QuantizedFullyConnectedOpTest, + SimpleTestQuantizedOutputMultiplierGreaterThan1) { + // real_multiplier = 2. + QuantizedFullyConnectedOpModel m( + GetRegistration(), /*units=*/3, /*batches*/ 2, + /*input=*/{TensorType_UINT8, {2, 10}, -127, 128}, + /*output=*/{TensorType_UINT8, {}, -63.5, 64}); + + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 }); m.SetBias({1, 2, 3}); @@ -361,11 +445,136 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) { m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ - 24, 25, 26, // - 58, 59, 60, // - }))); - EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187)); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // first batch + 58, 59, 60, // second batch + }))); + EXPECT_THAT(m.GetOutput(), + ElementsAre(175, 177, 179, 243, 245, 247)); +} + +void SimpleTestQuantizedInt16OutputCase( + TfLiteRegistration* registration, int input_depth, int output_depth, + int batches, FullyConnectedOptionsWeightsFormat weights_format) { + const uint8_t kWeightsZeroPoint = 128; + const float kWeightsScale = 1.f / 128.f; + const uint8_t kInputZeroPoint = 128; + const float kInputScale = 1.f / 128.f; + const float kInputMin = (0 - kInputZeroPoint) * kInputScale; + const float kInputMax = (255 - kInputZeroPoint) * kInputScale; + // Output ranges in [-8..8] encoded as int16 + const float kOutputScale = 8.f / 32768.f; + const float kOutputMin = -32768 * kOutputScale; + const float kOutputMax = 32767 * kOutputScale; + + QuantizedFullyConnectedOpModel m( + registration, output_depth, batches, + /*input=*/ + {TensorType_UINT8, {batches, input_depth}, kInputMin, kInputMax}, + /*output=*/{TensorType_INT16, {}, kOutputMin, kOutputMax}, + /*activation_func=*/ActivationFunctionType_NONE, weights_format); + + std::mt19937 random_engine; + std::uniform_int_distribution weights_dist; + + std::vector weights_data(input_depth * output_depth); + for (auto& w : weights_data) { + uint8_t q = weights_dist(random_engine); + w = (q - kWeightsZeroPoint) * kWeightsScale; + } + + // Based on weights_format, enforce any shape requirement for that format/path + // and set the (possibly shuffled) weights. + switch (weights_format) { + case FullyConnectedOptionsWeightsFormat_DEFAULT: + m.SetWeights(weights_data); + break; + case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: + // The shuffled path currently supports only a restrictive subset of + // shapes, described by the following assertions: + CHECK_EQ(input_depth % 16, 0); + CHECK_EQ(output_depth % 4, 0); + CHECK(batches == 1 || batches == 4); + m.ShuffleAndSetWeights(weights_data, input_depth, output_depth); + break; + default: + LOG(FATAL) << "Unhandled weights format"; + } + + std::uniform_int_distribution input_dist; + std::vector input_data(input_depth * batches); + for (auto& i : input_data) { + uint8_t q = input_dist(random_engine); + i = (q - kInputZeroPoint) * kInputScale; + } + + std::vector bias_data(output_depth); + // As the output ranges in [-8, 8], it's reasonable to have bias values + // in [-1, 1], this won't result in too much saturation. + std::uniform_real_distribution bias_dist(-1.f, 1.f); + for (auto& b : bias_data) { + b = bias_dist(random_engine); + } + + m.SetBias(bias_data); + m.SetInput(input_data); + + m.Invoke(); + + std::vector expected_output_data(output_depth * batches); + for (int b = 0; b < batches; b++) { + for (int o = 0; o < output_depth; o++) { + float accum = bias_data[o]; + for (int i = 0; i < input_depth; i++) { + accum += + input_data[b * input_depth + i] * weights_data[o * input_depth + i]; + } + accum = std::min(accum, kOutputMax); + accum = std::max(accum, kOutputMin); + expected_output_data[b * output_depth + o] = accum; + } + } + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear(expected_output_data, 3e-4f))); +} + +TEST_P(QuantizedFullyConnectedOpTest, + SimpleTestQuantizedInt16OutputDefaultWeights) { + for (int input_depth : {1, 3, 10, 100}) { + for (int output_depth : {1, 3, 10, 100}) { + for (int batch : {1, 3, 10, 100}) { + SimpleTestQuantizedInt16OutputCase( + GetRegistration(), input_depth, output_depth, batch, + FullyConnectedOptionsWeightsFormat_DEFAULT); + } + } + } +} + +TEST_P(QuantizedFullyConnectedOpTest, + SimpleTestQuantizedInt16OutputShuffled4x16Int8Weights) { + // The shuffled weights block shape is 4x16. The shape of the weights matrix + // is: rows = output_depth, cols = input_depth. It must be a multiple of 4x16. + // This means that output_depth must be a multiple of 4, and input_deth must + // be a multiple of 16. + for (int input_depth_numblocks : {1, 3}) { + for (int output_depth_numblocks : {1, 3}) { + int input_depth = 16 * input_depth_numblocks; + int output_depth = 4 * output_depth_numblocks; + // The fast shuffled path is currently supporting only batch sizes of 1 + // and 4. The idea is that the whole point of that path is to go as fast + // as possible for small batch size, which requires fully specializing + // it for each batch size, and for larger batch sizes the generic + // gemmlowp-based implementation is fast enough. + for (int batch : {1, 4}) { + SimpleTestQuantizedInt16OutputCase( + GetRegistration(), input_depth, output_depth, batch, + FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8); + } + } + } } TEST(HybridFullyConnectedOpTest, SimpleTestQuantized) { @@ -396,11 +605,11 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantized) { /*max_abs_error=*/1.3f))); } -TEST(FloatFullyConnectedOpTest, SimpleTest4DInput) { +TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) { // Note that it is not required that the first dimension be the number of // batches. All we care is that the input can be evenly distributed in // batches. In this case, we need the input to have multiples of '2'. - FloatFullyConnectedOpModel m(ops::builtin::Register_FULLY_CONNECTED_PIE(), + FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/3, /*batches=*/2, /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}}); m.SetWeights({ @@ -444,11 +653,44 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTest4dInputQuantized) { m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ - 24, 25, 26, // - 58, 59, 60, // - }))); - EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187)); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // + 58, 59, 60, // + }))); + EXPECT_THAT(m.GetOutput(), + ElementsAre(151, 152, 153, 185, 186, 187)); +} + +TEST_P(QuantizedFullyConnectedOpTest, + SimpleTest4dInputQuantizedOutputMultiplierGreaterThan1) { + // real_multiplier = 2. + QuantizedFullyConnectedOpModel m( + GetRegistration(), /*units=*/3, /*batches=*/2, + /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -127, 128}, + /*output=*/{TensorType_UINT8, {}, -63.5, 64}); + + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // first batch + 58, 59, 60, // second batch + }))); + EXPECT_THAT(m.GetOutput(), + ElementsAre(175, 177, 179, 243, 245, 247)); } INSTANTIATE_TEST_CASE_P( diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc index 6a2341461f2c627c78bd4783ee27579b59b5fde3..2b2a9e662051287fd1e3dbe8978f4689dc731064 100644 --- a/tensorflow/contrib/lite/kernels/gather.cc +++ b/tensorflow/contrib/lite/kernels/gather.cc @@ -40,10 +40,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // Only INT32 positions are supported. TF_LITE_ENSURE_EQ(context, positions->type, kTfLiteInt32); - // Check that input and output types match. - TF_LITE_ENSURE_EQ(context, input->type, output->type); - // TODO(mgubin): only 0D or 1D positions are currently supported. - TF_LITE_ENSURE(context, NumDimensions(positions) <= 1); + // Assign to output the input type. + output->type = input->type; // TODO(mgubin): Only default axis == 0 is supported. TF_LITE_ENSURE_EQ(context, params->axis, 0); // Check conditions for different types. @@ -102,6 +100,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_GATHER(int32_t, int32_t); break; case kTfLiteString: { + // TODO(mgubin): Currently support only for 1D output tensors. DynamicBuffer buffer; const int32* indexes = positions->data.i32; const int num_strings = GetStringCount(input); diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc index cdadbeda1884ba0186846826dd16be6ff69878d9..1d4292955cced59a47e0500833a86113cb9d3eb8 100644 --- a/tensorflow/contrib/lite/kernels/gather_test.cc +++ b/tensorflow/contrib/lite/kernels/gather_test.cc @@ -96,6 +96,15 @@ TEST(GatherOpTest, Test0DIndexWith0DResult) { EXPECT_TRUE(m.GetOutputShape().empty()); } +TEST(GatherOpTest, Test2DIndexWith2DResult) { + GatherOpModel m({3}, TensorType_FLOAT32, {1, 2}); + m.SetInputFloat({1.0, 2.0, 3.0}); + m.SetPositions({1, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray(ArrayFloatNear({2.0, 1.0}))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); +} + TEST(FloatGatherOpTest, Duplicate) { GatherOpModel m({1, 2, 2}, TensorType_FLOAT32, {2}); m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/contrib/lite/kernels/gemm_support.cc index 95f45ea768be7f9bae9570563f161792afbff436..ed334af2da877edf9f591612478e22f04cf15931 100644 --- a/tensorflow/contrib/lite/kernels/gemm_support.cc +++ b/tensorflow/contrib/lite/kernels/gemm_support.cc @@ -14,57 +14,70 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include + #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { namespace gemm_support { +namespace { -struct RefCountedGemmContext { - gemmlowp::GemmContext* gemm_context_ = nullptr; - int num_references_ = 0; +struct RefCountedGemmContext : public TfLiteExternalContext { + std::unique_ptr gemm_context; + int num_references = 0; }; +RefCountedGemmContext* GetGemmLowpContext(TfLiteContext* context) { + return reinterpret_cast( + context->GetExternalContext(context, kTfLiteGemmLowpContext)); +} + +TfLiteStatus Refresh(TfLiteContext* context) { + auto* ptr = GetGemmLowpContext(context); + if (ptr != nullptr) { + ptr->gemm_context->set_max_num_threads(context->recommended_num_threads); + } + return kTfLiteOk; +} + +} // namespace + void IncrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast(context->gemm_context); + auto* ptr = GetGemmLowpContext(context); if (ptr == nullptr) { ptr = new RefCountedGemmContext; - ptr->gemm_context_ = new gemmlowp::GemmContext(); + ptr->type = kTfLiteGemmLowpContext; + ptr->Refresh = Refresh; + ptr->gemm_context.reset(new gemmlowp::GemmContext()); if (context->recommended_num_threads != -1) { - ptr->gemm_context_->set_max_num_threads(context->recommended_num_threads); + ptr->gemm_context->set_max_num_threads(context->recommended_num_threads); } - ptr->num_references_ = 0; - context->gemm_context = ptr; + ptr->num_references = 0; + context->SetExternalContext(context, kTfLiteGemmLowpContext, ptr); } - ptr->num_references_++; + ptr->num_references++; } void DecrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast(context->gemm_context); + auto* ptr = GetGemmLowpContext(context); if (ptr == nullptr) { TF_LITE_FATAL( "Call to DecrementUsageCounter() not preceded by " "IncrementUsageCounter()"); } - if (--ptr->num_references_ == 0) { - delete ptr->gemm_context_; + if (--ptr->num_references == 0) { delete ptr; - context->gemm_context = nullptr; + context->SetExternalContext(context, kTfLiteGemmLowpContext, nullptr); } } gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) { - auto* ptr = reinterpret_cast(context->gemm_context); + auto* ptr = GetGemmLowpContext(context); if (ptr == nullptr) { TF_LITE_FATAL( "Call to GetFromContext() not preceded by IncrementUsageCounter()"); } - return ptr->gemm_context_; -} - -void SetNumThreads(TfLiteContext* context, int num_threads) { - IncrementUsageCounter(context); - GetFromContext(context)->set_max_num_threads(num_threads); - DecrementUsageCounter(context); + return ptr->gemm_context.get(); } } // namespace gemm_support diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h index f033501cb6e341aa014fa4d956b531bd79aa555b..37af772c6846f2f8124faabf1a0f0987e2e9393d 100644 --- a/tensorflow/contrib/lite/kernels/gemm_support.h +++ b/tensorflow/contrib/lite/kernels/gemm_support.h @@ -45,9 +45,6 @@ void IncrementUsageCounter(TfLiteContext* context); // 'context'. If there are no more usages the GemmContext will be deleted. void DecrementUsageCounter(TfLiteContext* context); -// Set the number of threads that can be used by gemmlowp. -void SetNumThreads(TfLiteContext* context, int num_threads); - } // namespace gemm_support } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc index 41211d41aa85a5a2da6ae96dc6f0337c54fb1a45..f37c66acb33eb9995772e595b84df6616e8d9e6a 100644 --- a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc +++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc @@ -31,7 +31,6 @@ limitations under the License. // Each item indicates whether the corresponding lookup has a returned value. // 0 for missing key, 1 for found key. -#include #include #include #include diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index 75298b995d6184985efc76c60c2f5541e9cbea40..a97db6c6b2523e09705c22ab0463c362ad3d2ff1 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -176,6 +176,40 @@ cc_library( }), ) +cc_library( + name = "legacy_optimized_base", + srcs = [], + hdrs = [ + "common.h", + "optimized/depthwiseconv_float.h", + "optimized/depthwiseconv_uint8.h", + "optimized/depthwiseconv_uint8_3x3_filter.h", + "optimized/legacy_optimized_ops.h", + "optimized/optimized_ops.h", + ], + copts = tflite_copts(), + deps = [ + ":quantization_util", + ":strided_slice_logic", + ":types", + ":legacy_reference_base", + ":round", + "//third_party/eigen3", + "@gemmlowp", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":haswell": tflite_deps_intel, + ":ios_x86_64": tflite_deps_intel, + ":k8": tflite_deps_intel, + ":x86": tflite_deps_intel, + ":x86_64": tflite_deps_intel, + ":darwin": tflite_deps_intel, + ":darwin_x86_64": tflite_deps_intel, + ":freebsd": tflite_deps_intel, + "//conditions:default": [], + }), +) + cc_library( name = "optimized", hdrs = [ @@ -198,6 +232,7 @@ cc_library( cc_test( name = "tensor_test", srcs = ["tensor_test.cc"], + tags = ["no_oss"], deps = [ ":reference", "@com_google_googletest//:gtest", @@ -226,6 +261,7 @@ cc_library( cc_test( name = "quantization_util_test", srcs = ["quantization_util_test.cc"], + tags = ["no_oss"], deps = [ ":quantization_util", "@com_google_googletest//:gtest", @@ -273,6 +309,37 @@ cc_library( }), ) +cc_library( + name = "legacy_reference_base", + srcs = [], + hdrs = [ + "common.h", + "reference/depthwiseconv_float.h", + "reference/depthwiseconv_uint8.h", + "reference/legacy_reference_ops.h", + "reference/reference_ops.h", + ], + deps = [ + ":quantization_util", + ":round", + ":strided_slice_logic", + ":types", + "//third_party/eigen3", + "@gemmlowp", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":haswell": tflite_deps_intel, + ":ios_x86_64": tflite_deps_intel, + ":k8": tflite_deps_intel, + ":x86": tflite_deps_intel, + ":x86_64": tflite_deps_intel, + ":darwin": tflite_deps_intel, + ":darwin_x86_64": tflite_deps_intel, + ":freebsd": tflite_deps_intel, + "//conditions:default": [], + }), +) + cc_library( name = "reference", hdrs = ["tensor.h"], @@ -414,6 +481,9 @@ cc_library( ":darwin": [ ":neon_tensor_utils", ], + ":darwin_x86_64": [ + ":neon_tensor_utils", + ], "//conditions:default": [ ":portable_tensor_utils", ], @@ -426,6 +496,7 @@ cc_library( hdrs = ["test_util.h"], deps = [ ":types", + "//tensorflow/contrib/lite:string", ], ) @@ -440,7 +511,10 @@ cc_test( "//conditions:default": [], }), linkstatic = 1, - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":tensor_utils", "//tensorflow/contrib/lite:builtin_op_data", @@ -452,6 +526,7 @@ cc_test( cc_test( name = "depthwiseconv_float_test", srcs = ["depthwiseconv_float_test.cc"], + tags = ["no_oss"], deps = [ ":optimized_base", ":reference_base", @@ -464,6 +539,10 @@ cc_test( cc_test( name = "depthwiseconv_quantized_test", srcs = ["depthwiseconv_quantized_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":optimized_base", ":reference_base", @@ -476,7 +555,10 @@ cc_test( cc_test( name = "resize_bilinear_test", srcs = ["resize_bilinear_test.cc"], - tags = ["tflite_not_portable"], + tags = [ + "no_oss", + "tflite_not_portable", + ], deps = [ ":optimized_base", ":reference_base", @@ -492,11 +574,13 @@ cc_test( srcs = [ "softmax_quantized_test.cc", ], + tags = ["no_oss"], deps = [ ":optimized_base", ":quantization_util", ":reference_base", ":test_util", + "//tensorflow/contrib/lite:string", "@com_google_googletest//:gtest_main", ], ) @@ -507,12 +591,16 @@ cc_test( srcs = [ "logsoftmax_quantized_test.cc", ], - tags = ["tflite_not_portable"], + tags = [ + "no_oss", + "tflite_not_portable", + ], deps = [ ":optimized_base", ":quantization_util", ":reference_base", ":test_util", + "//tensorflow/contrib/lite:string", "@com_google_googletest//:gtest_main", ], ) @@ -520,9 +608,11 @@ cc_test( cc_test( name = "log_quantized_test", srcs = ["log_quantized_test.cc"], + tags = ["no_oss"], deps = [ ":optimized_base", ":reference_base", + "//tensorflow/contrib/lite:string", "@com_google_googletest//:gtest_main", ], ) @@ -546,6 +636,7 @@ cc_library( cc_test( name = "batch_to_space_nd_test", srcs = ["batch_to_space_nd_test.cc"], + tags = ["no_oss"], deps = [ ":optimized_base", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h index 3543b6f79e1633b4ffa0872b2999161a519c8cad..4a29aecce5bd52b6a7fe1a142c270ef3430ff437 100644 --- a/tensorflow/contrib/lite/kernels/internal/common.h +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -119,6 +119,9 @@ template int CountLeadingZeros(T integer_input) { static_assert(std::is_unsigned::value, "Only unsigned integer types handled."); +#if defined(__GNUC__) + return integer_input ? __builtin_clz(integer_input) : 0; +#else const T one_in_leading_positive = static_cast(1) << (std::numeric_limits::digits - 1); int leading_zeros = 0; @@ -127,6 +130,140 @@ int CountLeadingZeros(T integer_input) { ++leading_zeros; } return leading_zeros; +#endif +} + +// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING +// BROADCASTING. +// +// NdArrayDesc describes the shape and memory layout of an N-dimensional +// rectangular array of numbers. +// +// NdArrayDesc is basically identical to Dims defined in types.h. +// However, as Dims is to be deprecated, this class exists as an adaptor +// to enable simple unoptimized implementations of element-wise broadcasting +// operations. +template +struct NdArrayDesc { + // The "extent" of each dimension. Indices along dimension d must be in the + // half-open interval [0, extents[d]). + int extents[N]; + + // The number of *elements* (not bytes) between consecutive indices of each + // dimension. + int strides[N]; +}; + +// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING +// BROADCASTING. +// +// Same as Offset(), except takes as NdArrayDesc instead of Dims. +inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2, + int i3) { + TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]); + TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]); + TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]); + TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]); + return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] + + i3 * desc.strides[3]; +} + +// Given the dimensions of the operands for an element-wise binary broadcast, +// adjusts them so that they can be directly iterated over with simple loops. +// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and +// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr. +// +// This function assumes that the two input shapes are compatible up to +// broadcasting and the shorter one has already been prepended with 1s to be the +// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64), +// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that +// Dims refer to shapes in reverse order. In this case, input0_dims will be +// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1). +// +// When two shapes are compatible up to broadcasting, for each dimension d, +// the input extents are either equal, or one of them is 1. +// +// This function performs the following for each dimension d: +// - If the extents are equal, then do nothing since the loop that walks over +// both of the input arrays is correct. +// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1 +// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows +// array0 to be referenced *at any index* in dimension d and still access the +// same slice. +template +inline void NdArrayDescsForElementwiseBroadcast(const Dims& input0_dims, + const Dims& input1_dims, + NdArrayDesc* desc0_out, + NdArrayDesc* desc1_out) { + TFLITE_DCHECK(desc0_out != nullptr); + TFLITE_DCHECK(desc1_out != nullptr); + + // Copy dims to desc. + for (int i = 0; i < N; ++i) { + desc0_out->extents[i] = input0_dims.sizes[i]; + desc0_out->strides[i] = input0_dims.strides[i]; + desc1_out->extents[i] = input1_dims.sizes[i]; + desc1_out->strides[i] = input1_dims.strides[i]; + } + + // Walk over each dimension. If the extents are equal do nothing. + // Otherwise, set the desc with extent 1 to have extent equal to the other and + // stride 0. + for (int i = 0; i < N; ++i) { + const int extent0 = ArraySize(input0_dims, i); + const int extent1 = ArraySize(input1_dims, i); + if (extent0 != extent1) { + if (extent0 == 1) { + desc0_out->strides[i] = 0; + desc0_out->extents[i] = extent1; + } else { + TFLITE_DCHECK_EQ(extent1, 1); + desc1_out->strides[i] = 0; + desc1_out->extents[i] = extent0; + } + } + } +} + +template +inline void NdArrayDescsForElementwiseBroadcast( + const RuntimeShape& input0_shape, const RuntimeShape& input1_shape, + NdArrayDesc* desc0_out, NdArrayDesc* desc1_out) { + TFLITE_DCHECK(desc0_out != nullptr); + TFLITE_DCHECK(desc1_out != nullptr); + + auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape); + auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape); + + // Copy dims to desc, calculating strides. + int desc0_stride = 1; + int desc1_stride = 1; + for (int i = N - 1; i >= 0; --i) { + desc0_out->extents[i] = extended_input0_shape.Dims(i); + desc0_out->strides[i] = desc0_stride; + desc0_stride *= extended_input0_shape.Dims(i); + desc1_out->extents[i] = extended_input1_shape.Dims(i); + desc1_out->strides[i] = desc1_stride; + desc1_stride *= extended_input1_shape.Dims(i); + } + + // Walk over each dimension. If the extents are equal do nothing. + // Otherwise, set the desc with extent 1 to have extent equal to the other and + // stride 0. + for (int i = 0; i < N; ++i) { + const int extent0 = extended_input0_shape.Dims(i); + const int extent1 = extended_input1_shape.Dims(i); + if (extent0 != extent1) { + if (extent0 == 1) { + desc0_out->strides[i] = 0; + desc0_out->extents[i] = extent1; + } else { + TFLITE_DCHECK_EQ(extent1, 1); + desc1_out->strides[i] = 0; + desc1_out->extents[i] = extent0; + } + } + } } } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc index 36c25388e8bde721d7644dc83d5b7c490d37b4d3..200f2f151582c2361dd2403164d0bbe119cbed72 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -255,14 +255,6 @@ void LstmStep( output_state_ptr); } -// TODO(alanchiao): move this to tensor_utils. -void VectorMultiply(const int8_t* vector, const int v_size, const float scale, - float* result) { - for (int i = 0; i < v_size; ++i) { - *result++ = scale * *vector++; - } -} - void LstmStep( const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, float input_to_input_weights_scale, @@ -415,8 +407,9 @@ void LstmStep( // For each batch and cell: update input gate. if (!use_cifg) { if (use_peephole && !is_cell_state_all_zeros) { - VectorMultiply(cell_to_input_weights_ptr, n_cell, - 1. / cell_to_input_weights_scale, recovered_cell_weights); + tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell, + cell_to_input_weights_scale, + recovered_cell_weights); tensor_utils::VectorBatchVectorCwiseProductAccumulate( recovered_cell_weights, n_cell, cell_state_ptr, n_batch, input_gate_scratch); @@ -427,8 +420,9 @@ void LstmStep( // For each batch and cell: update forget gate. if (use_peephole && !is_cell_state_all_zeros) { - VectorMultiply(cell_to_forget_weights_ptr, n_cell, - 1. / cell_to_forget_weights_scale, recovered_cell_weights); + tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell, + cell_to_forget_weights_scale, + recovered_cell_weights); tensor_utils::VectorBatchVectorCwiseProductAccumulate( recovered_cell_weights, n_cell, cell_state_ptr, n_batch, forget_gate_scratch); @@ -459,8 +453,9 @@ void LstmStep( tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); // For each batch and cell: update the output gate. if (use_peephole && !is_cell_state_all_zeros) { - VectorMultiply(cell_to_output_weights_ptr, n_cell, - 1. / cell_to_output_weights_scale, recovered_cell_weights); + tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell, + cell_to_output_weights_scale, + recovered_cell_weights); tensor_utils::VectorBatchVectorCwiseProductAccumulate( recovered_cell_weights, n_cell, cell_state_ptr, n_batch, output_gate_scratch); diff --git a/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc index 7e9ff5242a43a8b54e0e6ae167cdcf7a341c918e..8963abb9afd9d51473fe5a22d8e88d314b385ad9 100644 --- a/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc @@ -29,8 +29,9 @@ limitations under the License. #include #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/string.h" -namespace { +namespace tflite { class NumberGenerator { public: @@ -330,4 +331,4 @@ TEST_F(LogQuantizedTest, SelectedIntegerBits) { &generator_); } -} // namespace +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc index e786f785abe3aa66a9fb243dd4f332ca91676863..3624c20ae3bbf5f8eb5cb5fb51aadcde7327fd55 100644 --- a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc @@ -27,24 +27,27 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/test_util.h" +#include "tensorflow/contrib/lite/string.h" namespace tflite { namespace { void RunLogSoftmaxFloatReference(const uint8* input_data, - const Dims<4>& dims_common, int32 input_offset, - const double input_scale, int stride, - float beta, uint8* reference_output_data) { - const int ref_buffer_size = RequiredBufferSizeForDims(dims_common); + const RuntimeShape& shape_common, + int32 input_offset, const double input_scale, + int stride, float beta, + uint8* reference_output_data) { + const int ref_buffer_size = shape_common.FlatSize(); std::vector reference_dequant_data(ref_buffer_size); std::vector reference_output_float_data(ref_buffer_size); // Reference data generated via Dequant of input into float, and then applying // float LogSoftmax. - reference_ops::Dequantize(input_data, dims_common, input_offset, input_scale, - reference_dequant_data.data(), dims_common); - optimized_ops::LogSoftmax(reference_dequant_data.data(), dims_common, - reference_output_float_data.data(), dims_common); + reference_ops::Dequantize( + input_data, ToRuntimeDims(shape_common), input_offset, input_scale, + reference_dequant_data.data(), ToRuntimeDims(shape_common)); + optimized_ops::LogSoftmax(reference_dequant_data.data(), shape_common, + reference_output_float_data.data(), shape_common); // Work with quantized scaling for LogSoftmax, under which 255 represents 0, // and -16 gets nudged up to 0. for (int i = 0; i < ref_buffer_size; i++) { @@ -55,9 +58,9 @@ void RunLogSoftmaxFloatReference(const uint8* input_data, } void CheckOutputData(const uint8* test_output, const uint8* reference_output, - const Dims<4>& dims_common, const string& check_label, - bool be_exacting) { - const int buffer_size = RequiredBufferSizeForDims(dims_common); + const RuntimeShape& shape_common, + const string& check_label, bool be_exacting) { + const int buffer_size = shape_common.FlatSize(); // While calculating some metrics in floating point, we work with quantized // scaling. std::vector diff(buffer_size); @@ -99,15 +102,15 @@ void CheckOutputData(const uint8* test_output, const uint8* reference_output, // Runs the LogSoftmax and compares against the float reference implementation // and the quantized reference implementation. -void RunOneLogSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, - int32 input_offset, const double input_scale, - int stride, float beta) { - const int buffer_size = RequiredBufferSizeForDims(dims_common); +void RunOneLogSoftmaxTest(const uint8* input_data, + const RuntimeShape& shape_common, int32 input_offset, + const double input_scale, int stride, float beta) { + const int buffer_size = shape_common.FlatSize(); std::vector optimized_logsoftmax_output(buffer_size); std::vector reference_float_logsoftmax_output(buffer_size); std::vector reference_quant_logsoftmax_output(buffer_size); - RunLogSoftmaxFloatReference(input_data, dims_common, input_offset, + RunLogSoftmaxFloatReference(input_data, shape_common, input_offset, input_scale, stride, beta, reference_float_logsoftmax_output.data()); @@ -126,23 +129,23 @@ void RunOneLogSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits, input_beta_left_shift); - optimized_ops::LogSoftmax(input_data, dims_common, input_beta_multiplier, + optimized_ops::LogSoftmax(input_data, shape_common, input_beta_multiplier, input_beta_left_shift, reverse_scaling_divisor, reverse_scaling_right_shift, diff_min, - optimized_logsoftmax_output.data(), dims_common); + optimized_logsoftmax_output.data(), shape_common); reference_ops::LogSoftmax( - input_data, dims_common, input_beta_multiplier, input_beta_left_shift, + input_data, shape_common, input_beta_multiplier, input_beta_left_shift, reverse_scaling_divisor, reverse_scaling_right_shift, diff_min, - reference_quant_logsoftmax_output.data(), dims_common); + reference_quant_logsoftmax_output.data(), shape_common); CheckOutputData(optimized_logsoftmax_output.data(), - reference_float_logsoftmax_output.data(), dims_common, + reference_float_logsoftmax_output.data(), shape_common, "Optimized vs float reference", false); CheckOutputData(optimized_logsoftmax_output.data(), - reference_quant_logsoftmax_output.data(), dims_common, + reference_quant_logsoftmax_output.data(), shape_common, "Optimized vs quant reference", true); CheckOutputData(reference_quant_logsoftmax_output.data(), - reference_float_logsoftmax_output.data(), dims_common, + reference_float_logsoftmax_output.data(), shape_common, "Quant reference vs float reference", false); } @@ -165,13 +168,13 @@ bool TryOneUniformLogSoftmax() { const int32 input_offset = UniformRandomInt(-256, 0); static constexpr float beta = 1.0f; - Dims<4> dims_common = - MakeDimsForInference(input_depth, input_width, input_height, batch); - const int buffer_size = RequiredBufferSizeForDims(dims_common); + auto shape_common = + RuntimeShape({batch, input_height, input_width, input_depth}); + const int buffer_size = shape_common.FlatSize(); std::vector input_data(buffer_size); FillRandom(&input_data); - RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset, + RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, stride, beta); return true; } @@ -203,14 +206,14 @@ bool TryOneSkyscraperLogSoftmax(bool small_depth) { const int middle_min = UniformRandomInt(0, 255); const int sides_max = UniformRandomInt(0, middle_min); - Dims<4> dims_common = - MakeDimsForInference(input_depth, input_width, input_height, batch); - const int buffer_size = RequiredBufferSizeForDims(dims_common); + auto shape_common = + RuntimeShape({batch, input_height, input_width, input_depth}); + const int buffer_size = shape_common.FlatSize(); std::vector input_data(buffer_size); FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min, sides_max); - RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset, + RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, stride, beta); return true; } diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h index a7b0d805a3acd35b592a35ba4266dfff4eb992cd..0ce64f8c70d76f970df610f47947580a1efde720 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h @@ -26,7 +26,7 @@ namespace optimized_ops { // Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on // Jetson TX-2. This compiler does not support the offsetof() macro. #if defined(__aarch64__) && !defined(GOOGLE_L4T) - +#include // clang-format gets confused with this file and ends up formatting lines to // be larger than 80 characters. Turn off here and back on at the end of the // file. @@ -3242,6 +3242,7 @@ inline void DepthwiseConv3x3Filter( int32 output_shift, int32 output_activation_min, int32 output_activation_max, uint8* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__); DepthwiseConvParams params; params.input_depth = ArraySize(input_dims, 0); params.input_width = ArraySize(input_dims, 1); diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h index 655c8a3f022fe83c946d8028491ad750bb83b73a..2d96f547abb7e69b51bd9784f7e3e3abf4ec636b 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h @@ -33,7 +33,7 @@ limitations under the License. #include #ifdef _WIN32 -#include +#include #elif defined(__APPLE__) #include #else diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..03e2af22541f32eda1da10fe5ee5eddf1d9cda69 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h @@ -0,0 +1,631 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ + +#include +#include + +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +// Unoptimized reference ops: +using reference_ops::Relu1; +using reference_ops::Relu6; + +inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) { + return RuntimeShape( + {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]}); +} + +template +void L2Normalization(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + L2Normalization(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, uint8* output_data, + const Dims<4>& output_dims) { + L2Normalization(input_data, DimsToShape(input_dims), input_zero_point, + output_data, DimsToShape(output_dims)); +} + +inline void Relu(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Relu(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + tflite::ArithmeticParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + Add(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +template +inline void Add(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, int input2_shift, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + constexpr int kReverseShift = -1; + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + + tflite::ArithmeticParams op_params; + op_params.left_shift = left_shift; + op_params.input1_offset = input1_offset; + op_params.input1_multiplier = input1_multiplier; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_offset = input2_offset; + op_params.input2_multiplier = input2_multiplier; + op_params.input2_shift = kReverseShift * input2_shift; + op_params.output_offset = output_offset; + op_params.output_multiplier = output_multiplier; + op_params.output_shift = kReverseShift * output_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + Add(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +template +void Add(const int32* input1_data, const Dims<4>& input1_dims, + const int32* input2_data, const Dims<4>& input2_dims, + int32* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Add/int32"); + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + + tflite::ArithmeticParams op_params; + op_params.quantized_activation_min = std::numeric_limits::min(); + op_params.quantized_activation_max = std::numeric_limits::max(); + Add(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +template +void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, + T* output_data, const Dims<4>& output_dims) { + tflite::ArithmeticParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); +} + +template +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + constexpr int kReverseShift = -1; + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + + tflite::ArithmeticParams op_params; + op_params.left_shift = left_shift; + op_params.input1_offset = input1_offset; + op_params.input1_multiplier = input1_multiplier; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_offset = input2_offset; + op_params.input2_multiplier = input2_multiplier; + op_params.input2_shift = kReverseShift * input2_shift; + op_params.output_offset = output_offset; + op_params.output_multiplier = output_multiplier; + op_params.output_shift = kReverseShift * output_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); +} + +template +inline void BroadcastAddFivefold( + int y0, int y1, int y2, int y3, int y4, int left_shift, + const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + constexpr int kReverseShift = -1; + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + tflite::ArithmeticParams op_params; + op_params.broadcast_category = + tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast; + op_params.left_shift = left_shift; + op_params.input1_offset = input1_offset; + op_params.input1_multiplier = input1_multiplier; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_offset = input2_offset; + op_params.input2_multiplier = input2_multiplier; + op_params.input2_shift = kReverseShift * input2_shift; + op_params.output_offset = output_offset; + op_params.output_multiplier = output_multiplier; + op_params.output_shift = kReverseShift * output_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + op_params.broadcast_shape[4] = y0; + op_params.broadcast_shape[3] = y1; + op_params.broadcast_shape[2] = y2; + op_params.broadcast_shape[1] = y3; + op_params.broadcast_shape[0] = y4; + BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); +} + +// legacy, for compatibility with old checked-in code +template +void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + T output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims, + output_activation_min, output_activation_max, output_data, + output_dims); +} + +template +inline void Add(const int16* input1_data, const Dims<4>& input1_dims, + int input1_shift, const int16* input2_data, + const Dims<4>& input2_dims, int input2_shift, + int16 output_activation_min, int16 output_activation_max, + int16* output_data, const Dims<4>& output_dims) { + constexpr int kReverseShift = -1; + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, -32768); + TFLITE_DCHECK_EQ(output_activation_max, 32767); + } + + tflite::ArithmeticParams op_params; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_shift = kReverseShift * input2_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + Add(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +inline void Sub(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(FusedActivationFunctionType::kNone, + &output_activation_min, &output_activation_max); + tflite::ArithmeticParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + Sub(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +template +void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, + const Dims<4>& input2_dims, T* output_data, + const Dims<4>& output_dims) { + T output_activation_min, output_activation_max; + GetActivationMinMax(FusedActivationFunctionType::kNone, + &output_activation_min, &output_activation_max); + tflite::ArithmeticParams op_params; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + Sub(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + BroadcastMul4DSlow( + input1_data, input1_dims, input1_offset, input2_data, input2_dims, + input2_offset, output_offset, output_multiplier, + // This legacy version switches the sign of the output shift. + kReverseShift * output_shift, + // (Break to highlight preceding line.) + output_activation_min, output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + BroadcastMul(input1_data, input1_dims, input1_offset, input2_data, + input2_dims, input2_offset, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, + output_data, output_dims); +} + +inline void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + tflite::PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = kheight; + params.filter_width = kwidth; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.float_activation_min = output_activation_min; + params.float_activation_max = output_activation_max; + AveragePool(params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + tflite::PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = filter_height; + params.filter_width = filter_width; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.quantized_activation_min = output_activation_min; + params.quantized_activation_max = output_activation_max; + AveragePool(params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + tflite::PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = kheight; + params.filter_width = kwidth; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.float_activation_min = output_activation_min; + params.float_activation_max = output_activation_max; + MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims), + output_data); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = filter_height; + params.filter_width = filter_width; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.quantized_activation_min = output_activation_min; + params.quantized_activation_max = output_activation_max; + MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims), + output_data); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = filter_height; + params.filter_width = filter_width; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.float_activation_min = output_activation_min; + params.float_activation_max = output_activation_max; + L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims), + output_data); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void Softmax(const float* input_data, const Dims<4>& input_dims, + float beta, float* output_data, + const Dims<4>& output_dims) { + Softmax(input_data, DimsToShape(input_dims), beta, output_data, + DimsToShape(output_dims)); +} + +inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_beta_multiplier, int32 input_beta_left_shift, + int diff_min, uint8* output_data, + const Dims<4>& output_dims) { + Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier, + input_beta_left_shift, diff_min, output_data, + DimsToShape(output_dims)); +} + +inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + LogSoftmax(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_multiplier, int32 input_left_shift, + int32 reverse_scaling_divisor, + int32 reverse_scaling_right_shift, int diff_min, + uint8* output_data, const Dims<4>& output_dims) { + LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier, + input_left_shift, reverse_scaling_divisor, + reverse_scaling_right_shift, diff_min, output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), input_zero_point, + input_range_radius, input_multiplier, input_left_shift, output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const int16* input_data, const Dims<4>& input_dims, + int16* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), input_zero_point, + input_range_radius, input_multiplier, input_left_shift, output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const int16* input_data, const Dims<4>& input_dims, + int input_left_shift, int16* output_data, + const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data, + DimsToShape(output_dims)); +} + +} // namespace optimized_ops +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h index 27d9224512a835ea58911031f1b4d6dcf5482ba9..4a3545d47aca7d649061d39cbc23fa7ddf208156 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h @@ -35,35 +35,6 @@ limitations under the License. namespace tflite { namespace multithreaded_ops { -class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { - public: - explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {} - ~EigenThreadPoolWrapper() override {} - - void Schedule(std::function fn) override { - pool_->Schedule(std::move(fn)); - } - int NumThreads() const override { return pool_->NumThreads(); } - int CurrentThreadId() const override { return pool_->CurrentThreadId(); } - - private: - Eigen::ThreadPool* pool_ = nullptr; -}; - -// We have a single global threadpool for all convolution operations. This means -// that inferences started from different threads may block each other, but -// since the underlying resource of CPU cores should be consumed by the -// operations anyway, it shouldn't affect overall performance. -const Eigen::ThreadPoolDevice& GetThreadPoolDevice() { - const int thread_count = 4; - static Eigen::ThreadPool* tp = new Eigen::ThreadPool(thread_count); - static EigenThreadPoolWrapper* thread_pool_wrapper = - new EigenThreadPoolWrapper(tp); - static Eigen::ThreadPoolDevice* device = - new Eigen::ThreadPoolDevice(thread_pool_wrapper, thread_count); - return *device; -} - // Shorthands for the types we need when interfacing with the EigenTensor // library. typedef Eigen::TensorMap< @@ -113,14 +84,13 @@ class EigenTensorConvFunctor { } public: - void operator()(const T* input_data, T* im2col_buffer, int input_batches, - int input_height, int input_width, int input_depth, - const T* filter_data, int filter_height, int filter_width, - int filter_count, int stride_rows, int stride_cols, - int pad_width, int pad_height, TfLitePadding padding, - T* output_data, int output_height, int output_width) { - const Eigen::ThreadPoolDevice& device = GetThreadPoolDevice(); - + void operator()(const Eigen::ThreadPoolDevice& device, const T* input_data, + T* im2col_buffer, int input_batches, int input_height, + int input_width, int input_depth, const T* filter_data, + int filter_height, int filter_width, int filter_count, + int stride_rows, int stride_cols, int pad_width, + int pad_height, TfLitePadding padding, T* output_data, + int output_height, int output_width) { const bool is_1x1_kernel = (filter_height == 1 && filter_width == 1 && stride_rows == 1 && stride_cols == 1); if (is_1x1_kernel) { @@ -162,11 +132,11 @@ class EigenTensorConvFunctor { } }; -inline void Conv(const float* input_data, const Dims<4>& input_dims, - const float* filter_data, const Dims<4>& filter_dims, - const float* bias_data, const Dims<4>& bias_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, TfLitePadding padding, +inline void Conv(const Eigen::ThreadPoolDevice& device, const float* input_data, + const Dims<4>& input_dims, const float* filter_data, + const Dims<4>& filter_dims, const float* bias_data, + const Dims<4>& bias_dims, int stride_width, int stride_height, + int pad_width, int pad_height, TfLitePadding padding, float output_activation_min, float output_activation_max, float* output_data, const Dims<4>& output_dims, float* im2col_data, const Dims<4>& im2col_dims) { @@ -180,10 +150,11 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims, const int output_height = ArraySize(output_dims, 2); const int output_width = ArraySize(output_dims, 1); EigenTensorConvFunctor conv_functor; - conv_functor(input_data, im2col_data, batches, input_height, input_width, - input_depth, filter_data, filter_height, filter_width, - output_depth, stride_height, stride_width, pad_height, pad_width, - padding, output_data, output_height, output_width); + conv_functor(device, input_data, im2col_data, batches, input_height, + input_width, input_depth, filter_data, filter_height, + filter_width, output_depth, stride_height, stride_width, + pad_height, pad_width, padding, output_data, output_height, + output_width); optimized_ops::AddBiasAndEvalActivationFunction( bias_data, bias_dims, output_data, output_dims, output_activation_min, diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc index 38ad32c734a2286c7d23162810625169a4d8df43..420bc68b43dd8c135e95badcc7d18935449cfc73 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -55,83 +55,33 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, const int postamble_start = m_cols - (m_cols & (kFloatWeightsPerNeonLane - 1)); - // The arrays used to cache the vector. - void* aligned_vector_cache_free = nullptr; - float32x4_t* vector_cache_float32x4 = - reinterpret_cast(aligned_alloc( - sizeof(float32x4_t), (postamble_start >> 2) * sizeof(float32x4_t), - &aligned_vector_cache_free)); - - const int kUnrollSize = 2; for (int b = 0; b < n_batch; b++) { float* result_in_batch = result + b * m_rows * result_stride; const float* vector_in_batch = vector + b * m_cols; + const float* matrix_row = matrix; - const float* matrix_ptr0 = matrix; - // If there is only 1 row, we don't want to assign an illegal pointer. - const float* matrix_ptr1 = nullptr; - if (m_rows > 1) { - matrix_ptr1 = matrix + m_cols; - } - - // Cache the vector. - for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { - vector_cache_float32x4[c >> 2] = vld1q_f32(vector_in_batch + c); - } - - // Main matrix by vector multiplication loop, which handles two rows of - // matrix by vector multiplication. - for (int r = 0; r < (m_rows & ~(kUnrollSize - 1)); r += kUnrollSize) { - float32x4_t acc0_32x4 = vmovq_n_f32(0.0); - float32x4_t acc1_32x4 = vmovq_n_f32(0.0); + // Main matrix by vector multiplication loop + for (int r = 0; r < m_rows; r++) { + float32x4_t acc_32x4 = vmovq_n_f32(0.0); for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { - float32x4_t temp = vector_cache_float32x4[c >> 2]; - // Load 4 float values from vector1 and vector2 and accumulator. - float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c); - float32x4_t v1_f32x4 = vld1q_f32(matrix_ptr1 + c); - // Vector multiply-accumulate 4 float - acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp); - acc1_32x4 = vmlaq_f32(acc1_32x4, v1_f32x4, temp); + // Load 4 float values from vector and matrix row. + float32x4_t vector_f32x4 = vld1q_f32(vector_in_batch + c); + float32x4_t matrix_f32x4 = vld1q_f32(matrix_row + c); + // Multiply the vector and matrix row and add to accumulator. + acc_32x4 = vmlaq_f32(acc_32x4, matrix_f32x4, vector_f32x4); } // Add the 4 intermediate sum values to get the final dot-prod value for // this column. *result_in_batch += - (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) + - vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3)); - *(result_in_batch + result_stride) += - (vgetq_lane_f32(acc1_32x4, 0) + vgetq_lane_f32(acc1_32x4, 1) + - vgetq_lane_f32(acc1_32x4, 2) + vgetq_lane_f32(acc1_32x4, 3)); + (vgetq_lane_f32(acc_32x4, 0) + vgetq_lane_f32(acc_32x4, 1) + + vgetq_lane_f32(acc_32x4, 2) + vgetq_lane_f32(acc_32x4, 3)); for (int c = postamble_start; c < m_cols; c++) { - *result_in_batch += matrix_ptr0[c] * vector_in_batch[c]; - *(result_in_batch + result_stride) += - matrix_ptr1[c] * vector_in_batch[c]; + *result_in_batch += matrix_row[c] * vector_in_batch[c]; } - matrix_ptr0 += kUnrollSize * m_cols; - matrix_ptr1 += kUnrollSize * m_cols; - result_in_batch += kUnrollSize * result_stride; - } - for (int r = (m_rows & ~(kUnrollSize - 1)); r < m_rows; r++) { - float32x4_t acc0_32x4 = vmovq_n_f32(0.0); - for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { - float32x4_t temp = vector_cache_float32x4[c >> 2]; - // Load 4 float values from vector1 and vector2 and accumulator. - float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c); - // Vector multiply-accumulate 4 float - acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp); - } - // Add the 4 intermediate sum values to get the final dot-prod value for - // this column. - *result_in_batch += - (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) + - vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3)); - for (int c = postamble_start; c < m_cols; c++) { - *result_in_batch += matrix_ptr0[c] * vector_in_batch[c]; - } - matrix_ptr0 += m_cols; + matrix_row += m_cols; result_in_batch += result_stride; } } - free(aligned_vector_cache_free); } void NeonMatrixBatchVectorMultiplyAccumulate( @@ -162,7 +112,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate( int batch, row, col; for (batch = 0; batch < n_batch; ++batch) { - const float batch_scaling_factor_inv = 1.0 / scaling_factors[batch]; + const float batch_scaling_factor = scaling_factors[batch]; // Copy the vector data to an aligned vector. memcpy(aligned_vec, vectors + batch * m_cols, sizeof(int8) * m_cols); // Compute dot-product for every column. @@ -232,7 +182,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate( int32 neon_sum = vgetq_lane_s64(pairwiseAdded, 0) + vgetq_lane_s64(pairwiseAdded, 1); - *result += ((neon_sum + postable_sum) * batch_scaling_factor_inv); + *result += ((neon_sum + postable_sum) * batch_scaling_factor); } // for row } // for batch @@ -296,17 +246,6 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector, const int postamble_start = v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); - // The arrays used to cache the vector. - void* aligned_vector_cache_free = nullptr; - float32x4_t* vector_cache_float32x4 = - reinterpret_cast(aligned_alloc( - sizeof(float32x4_t), (postamble_start >> 2) * sizeof(float32x4_t), - &aligned_vector_cache_free)); - - for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { - vector_cache_float32x4[v >> 2] = vld1q_f32(vector + v); - } - float* result_ptr = result; const float* batch_vector_ptr = batch_vector; for (int b = 0; b < n_batch; b++) { @@ -314,9 +253,9 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector, // Load from memory to vectors. float32x4_t result_f32x4 = vld1q_f32(result_ptr + v); float32x4_t batch_vector_f32x4 = vld1q_f32(batch_vector_ptr + v); + float32x4_t vector_f32x4 = vld1q_f32(vector + v); // Multiply-accumulate. - result_f32x4 = vmlaq_f32(result_f32x4, batch_vector_f32x4, - vector_cache_float32x4[v >> 2]); + result_f32x4 = vmlaq_f32(result_f32x4, batch_vector_f32x4, vector_f32x4); // Store. vst1q_f32(result_ptr + v, result_f32x4); } @@ -328,7 +267,6 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector, result_ptr += v_size; batch_vector_ptr += v_size; } - free(aligned_vector_cache_free); } void NeonSub1Vector(const float* vector, int v_size, float* result) { @@ -404,6 +342,77 @@ void NeonClipVector(const float* vector, int v_size, float abs_limit, } } +void NeonVectorScalarMultiply(const int8_t* vector, const int v_size, + const float scale, float* result) { + // Here the assumption is that each buffer is 4-byte aligned. + const int kWeightsPerUint32 = 4; + TFLITE_CHECK_EQ((intptr_t)(&vector[0]) & (kWeightsPerUint32 - 1), 0); + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int kWeightsPerNeonLane = 16; + const int postamble_start = v_size - (v_size & (kWeightsPerNeonLane - 1)); + + // Create a vector of 4 floats with the scale value. + const float32x4_t scale_f32x4 = vdupq_n_f32(scale); + int v = 0; + for (; v < postamble_start; v += kWeightsPerNeonLane) { + // Load int8 values, sixteen at a time. + const int8x16_t v_i8x16 = vld1q_s8(vector + v); + // Split it into two components of size eight. + const int8x8_t v0_i8x8 = vget_low_s8(v_i8x16); + const int8x8_t v1_i8x8 = vget_high_s8(v_i8x16); + // Convert both components to int16 first. + const int16x8_t v0_i16x8 = vmovl_s8(v0_i8x8); + const int16x8_t v1_i16x8 = vmovl_s8(v1_i8x8); + // Split each of them into two components each. + const int16x4_t v0_i16x4 = vget_low_s16(v0_i16x8); + const int16x4_t v1_i16x4 = vget_high_s16(v0_i16x8); + const int16x4_t v2_i16x4 = vget_low_s16(v1_i16x8); + const int16x4_t v3_i16x4 = vget_high_s16(v1_i16x8); + // Convert these to int32 and then to float. + float32x4_t v0_f32x4 = vcvtq_f32_s32(vmovl_s16(v0_i16x4)); + float32x4_t v1_f32x4 = vcvtq_f32_s32(vmovl_s16(v1_i16x4)); + float32x4_t v2_f32x4 = vcvtq_f32_s32(vmovl_s16(v2_i16x4)); + float32x4_t v3_f32x4 = vcvtq_f32_s32(vmovl_s16(v3_i16x4)); + // Vector multiply four floats at a time. + v0_f32x4 = vmulq_f32(v0_f32x4, scale_f32x4); + v1_f32x4 = vmulq_f32(v1_f32x4, scale_f32x4); + v2_f32x4 = vmulq_f32(v2_f32x4, scale_f32x4); + v3_f32x4 = vmulq_f32(v3_f32x4, scale_f32x4); + // Store the results. + vst1q_f32(result + v, v0_f32x4); + vst1q_f32(result + v + 4, v1_f32x4); + vst1q_f32(result + v + 8, v2_f32x4); + vst1q_f32(result + v + 12, v3_f32x4); + } + + if (v_size - postamble_start >= (kWeightsPerNeonLane >> 1)) { + // Load eight int8 values, if there is at least eight remaining. + const int8x8_t v_i8x8 = vld1_s8(vector + v); + // Convert them to int16 first. + const int16x8_t v_i16x8 = vmovl_s8(v_i8x8); + // Split it into two components. + const int16x4_t v0_i16x4 = vget_low_s16(v_i16x8); + const int16x4_t v1_i16x4 = vget_high_s16(v_i16x8); + // Convert the components two floats. + float32x4_t v0_f32x4 = vcvtq_f32_s32(vmovl_s16(v0_i16x4)); + float32x4_t v1_f32x4 = vcvtq_f32_s32(vmovl_s16(v1_i16x4)); + // Vector multiply four floats at a time. + v0_f32x4 = vmulq_f32(v0_f32x4, scale_f32x4); + v1_f32x4 = vmulq_f32(v1_f32x4, scale_f32x4); + // Store the results. + vst1q_f32(result + v, v0_f32x4); + vst1q_f32(result + v + 4, v1_f32x4); + v += (kWeightsPerNeonLane >> 1); + } + + // Postamble loop. + for (; v < v_size; v++) { + result[v] = scale * vector[v]; + } +} + void NeonSymmetricQuantizeFloats(const float* values, const int size, int8_t* quantized_values, float* min, float* max, float* scaling_factor) { @@ -418,13 +427,14 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size, *scaling_factor = 1; return; } - *scaling_factor = kScale / range; + *scaling_factor = range / kScale; + const float scaling_factor_inv = 1.0f / *scaling_factor; const int postamble_start = size - (size & (2 * kFloatWeightsPerNeonLane - 1)); // Vectorized constants. - const float32x4_t q_factor_f32x4 = vmovq_n_f32(*scaling_factor); + const float32x4_t q_factor_f32x4 = vmovq_n_f32(scaling_factor_inv); const float32x4_t point5_f32x4 = vmovq_n_f32(0.5); const float32x4_t zero_f32x4 = vmovq_n_f32(0.0); const int32x4_t scale_i32x4 = vmovq_n_s32(kScale); @@ -476,7 +486,7 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size, for (int i = postamble_start; i < size; ++i) { const int32 quantized_value = - static_cast(TfLiteRound(*scaling_factor * values[i])); + static_cast(TfLiteRound(scaling_factor_inv * values[i])); quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value)); } } diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h index 7a5a8fc54123946229963abd1720030d0bb358bf..63c89d1eeef47b206fc871929f1fb1295b2f70ff 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -105,16 +105,20 @@ bool IsZeroVector(const float* vector, int v_size) { return NEON_OR_PORTABLE(IsZeroVector, vector, v_size); } +void VectorScalarMultiply(const int8_t* vector, int v_size, float scale, + float* result) { + NEON_OR_PORTABLE(VectorScalarMultiply, vector, v_size, scale, result); +} void ClipVector(const float* vector, int v_size, float abs_limit, float* result) { NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result); } void SymmetricQuantizeFloats(const float* values, const int size, - int8_t* quantized_values, float* min, float* max, - float* scaling_factor) { - NEON_OR_PORTABLE(SymmetricQuantizeFloats, values, size, quantized_values, min, - max, scaling_factor); + int8_t* quantized_values, float* min_value, + float* max_value, float* scaling_factor) { + NEON_OR_PORTABLE(SymmetricQuantizeFloats, values, size, quantized_values, + min_value, max_value, scaling_factor); } void VectorShiftLeft(float* vector, int v_size, float shift_value) { diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index be4825f2507ccb583298af6cd824dd6da1fc47c7..7eb1b5a81d3c7060d70f97e952cc69eaa24f310b 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -40,16 +40,34 @@ namespace tflite { namespace optimized_ops { // Unoptimized reference ops: +using reference_ops::ArgMax; +using reference_ops::ArgMinMax; +using reference_ops::BroadcastAdd4DSlow; using reference_ops::BroadcastGreater; using reference_ops::BroadcastGreaterEqual; using reference_ops::BroadcastLess; using reference_ops::BroadcastLessEqual; +using reference_ops::BroadcastMul4DSlow; +using reference_ops::BroadcastSub4DSlow; +using reference_ops::Concatenation; +using reference_ops::DepthConcatenation; +using reference_ops::Dequantize; +using reference_ops::Div; +using reference_ops::FakeQuant; +using reference_ops::Gather; using reference_ops::Greater; using reference_ops::GreaterEqual; using reference_ops::Less; using reference_ops::LessEqual; +using reference_ops::Mean; using reference_ops::RankOneSelect; +using reference_ops::Relu1; +using reference_ops::Relu6; +using reference_ops::ReluX; using reference_ops::Select; +using reference_ops::SpaceToBatchND; +using reference_ops::StridedSlice; +using reference_ops::Transpose; // TODO(b/80247582) Remove this constant. // This will be phased out as the shifts are revised with more thought. Use of a @@ -72,6 +90,12 @@ using VectorMap = typename std::conditional< Eigen::Dynamic, 1>>, Eigen::Map>>::type; +template +VectorMap MapAsVector(Scalar* data, const RuntimeShape& shape) { + const int size = shape.FlatSize(); + return VectorMap(data, size, 1); +} + template VectorMap MapAsVector(Scalar* data, const Dims& dims) { const int size = FlatSize(dims); @@ -88,6 +112,23 @@ using MatrixMap = typename std::conditional< Eigen::Dynamic, Eigen::Dynamic>>, Eigen::Map>>::type; +template +MatrixMap MapAsMatrixWithLastDimAsRows(Scalar* data, + const RuntimeShape& shape) { + const int dims_count = shape.DimensionsCount(); + const int rows = shape.Dims(dims_count - 1); + const int cols = FlatSizeSkipDim(shape, dims_count - 1); + return MatrixMap(data, rows, cols); +} + +template +MatrixMap MapAsMatrixWithFirstDimAsCols(Scalar* data, + const RuntimeShape& shape) { + const int cols = shape.Dims(0); + const int rows = FlatSizeSkipDim(shape, 0); + return MatrixMap(data, rows, cols); +} + template MatrixMap MapAsMatrixWithFirstDimAsRows(Scalar* data, const Dims& dims) { @@ -128,22 +169,27 @@ ArrayMap MapAsArrayWithFirstDimAsRows(Scalar* data, return ArrayMap(data, rows, cols); } +// Copied from tensorflow/core/framework/tensor_types.h +template +struct TTypes { + // Rank-1 tensor (vector) of scalar type T. + typedef Eigen::TensorMap, + Eigen::Aligned> + Flat; + typedef Eigen::TensorMap< + Eigen::Tensor> + UnalignedConstMatrix; +}; + // TODO(b/62193649): this function is only needed as long // as we have the --variable_batch hack. template MatrixMap MapAsMatrixWithGivenNumberOfRows(Scalar* data, const Dims& dims, int rows) { - int cols = 1; - bool matched_rows = false; - for (int d = 0; d < N; d++) { - cols *= dims.sizes[d]; - if (cols == rows) { - matched_rows = true; - cols = 1; - } - } - TFLITE_DCHECK(matched_rows); + const int flatsize = FlatSize(dims); + TFLITE_DCHECK((flatsize % rows) == 0); + const int cols = flatsize / rows; return MatrixMap(data, rows, cols); } @@ -186,98 +232,6 @@ SaturatingRoundingMultiplyByPOTParam( SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent)); } -// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE -// BROADCASTING. -// -// NdArrayDesc describes the shape and memory layout of an N-dimensional -// rectangular array of numbers. -// -// NdArrayDesc is basically identical to Dims defined in types.h. -// However, as Dims is to be deprecated, this class exists as an adaptor -// to enable simple unoptimized implementations of element-wise broadcasting -// operations. -template -struct NdArrayDesc { - // The "extent" of each dimension. Indices along dimension d must be in the - // half-open interval [0, extents[d]). - int extents[N]; - - // The number of *elements* (not bytes) between consecutive indices of each - // dimension. - int strides[N]; -}; - -// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING -// ELEMENT-WISE BROADCASTING. -// -// Same as Offset(), except takes as NdArrayDesc instead of Dims. -inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2, - int i3) { - TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]); - TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]); - TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]); - TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]); - return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] + - i3 * desc.strides[3]; -} - -// Given the dimensions of the operands for an element-wise binary broadcast, -// adjusts them so that they can be directly iterated over with simple loops. -// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and -// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr. -// -// This function assumes that the two input shapes are compatible up to -// broadcasting and the shorter one has already been prepended with 1s to be the -// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64), -// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that -// Dims refer to shapes in reverse order. In this case, input0_dims will be -// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1). -// -// When two shapes are compatible up to broadcasting, for each dimension d, -// the input extents are either equal, or one of them is 1. -// -// This function performs the following for each dimension d: -// - If the extents are equal, then do nothing since the loop that walks over -// both of the input arrays is correct. -// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1 -// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows -// array0 to be referenced *at any index* in dimension d and still access the -// same slice. -template -inline void NdArrayDescsForElementwiseBroadcast(const Dims& input0_dims, - const Dims& input1_dims, - NdArrayDesc* desc0_out, - NdArrayDesc* desc1_out) { - TFLITE_DCHECK(desc0_out != nullptr); - TFLITE_DCHECK(desc1_out != nullptr); - - // Copy dims to desc. - for (int i = 0; i < N; ++i) { - desc0_out->extents[i] = input0_dims.sizes[i]; - desc0_out->strides[i] = input0_dims.strides[i]; - desc1_out->extents[i] = input1_dims.sizes[i]; - desc1_out->strides[i] = input1_dims.strides[i]; - } - - // Walk over each dimension. If the extents are equal do nothing. - // Otherwise, set the desc with extent 1 to have extent equal to the other and - // stride 0. - for (int i = 0; i < N; ++i) { - const int extent0 = ArraySize(input0_dims, i); - const int extent1 = ArraySize(input1_dims, i); - if (extent0 != extent1) { - if (extent0 == 1) { - desc0_out->strides[i] = 0; - desc0_out->extents[i] = extent1; - } else { - TFLITE_DCHECK_EQ(extent1, 1); - desc1_out->strides[i] = 0; - desc1_out->extents[i] = extent0; - } - } - } -} - inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) { for (int i = 0; i < 4; i++) { if (dims1.sizes[i] != dims2.sizes[i]) { @@ -940,6 +894,7 @@ inline void FullyConnectedAsGEMV( const int input_size = FlatSizeSkipDim(input_dims, 3); const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0); static constexpr int kPeel = 4; + const bool shift_left = (output_shift <= 0); for (int k = 0; k < input_size; k += 64) { optimized_ops_preload_l1_stream(input_data + k); } @@ -1051,11 +1006,17 @@ inline void FullyConnectedAsGEMV( int32x4_t bias_vec = vld1q_s32(bias_ptr); bias_ptr += 4; reduced = vaddq_s32(reduced, bias_vec); - // Multiply by the fixed-point multiplier. - reduced = vqrdmulhq_n_s32(reduced, output_multiplier); - // Rounding-shift-right. - using gemmlowp::RoundingDivideByPOT; - reduced = RoundingDivideByPOT(reduced, output_shift); + if (shift_left) { + const int32 multiplier_power_of_two = 1 << -output_shift; + reduced = vmulq_n_s32(reduced, multiplier_power_of_two); + reduced = vqrdmulhq_n_s32(reduced, output_multiplier); + } else { + // Multiply by the fixed-point multiplier. + reduced = vqrdmulhq_n_s32(reduced, output_multiplier); + // Rounding-shift-right. + using gemmlowp::RoundingDivideByPOT; + reduced = RoundingDivideByPOT(reduced, output_shift); + } // Add the output offset. const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); reduced = vaddq_s32(reduced, output_offset_vec); @@ -1077,10 +1038,10 @@ inline void FullyConnectedAsGEMV( struct GemmlowpOutputPipeline { typedef gemmlowp::VectorMap ColVectorMap; - typedef std::tuple< - gemmlowp::OutputStageBiasAddition, - gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint, - gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToUint8> + typedef std::tuple, + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent, + gemmlowp::OutputStageClamp, + gemmlowp::OutputStageSaturatingCastToUint8> Pipeline; static Pipeline MakeExp(const int32* bias_data, int output_rows, int32 output_offset, int32 output_multiplier, @@ -1089,11 +1050,10 @@ struct GemmlowpOutputPipeline { ColVectorMap bias_vector(bias_data, output_rows); gemmlowp::OutputStageBiasAddition bias_addition_stage; bias_addition_stage.bias_vector = bias_vector; - gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint - quantize_down_stage; + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage; quantize_down_stage.result_offset_after_shift = output_offset; quantize_down_stage.result_fixedpoint_multiplier = output_multiplier; - quantize_down_stage.result_shift = -output_left_shift; + quantize_down_stage.result_exponent = output_left_shift; gemmlowp::OutputStageClamp clamp_stage; clamp_stage.min = output_activation_min; clamp_stage.max = output_activation_max; @@ -1256,11 +1216,11 @@ void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, } // Internal function doing the actual arithmetic work for -// ExperimentalShuffledFullyConnected. +// ShuffledFullyConnected. // May be called either directly by it (single-threaded case) or may be used // as the 'task' for worker threads to run (multi-threaded case, see -// ExperimentalShuffledFullyConnectedWorkerTask below). -inline void ExperimentalShuffledFullyConnectedWorkerImpl( +// ShuffledFullyConnectedWorkerTask below). +inline void ShuffledFullyConnectedWorkerImpl( const uint8* shuffled_input_workspace_data, const int8* shuffled_weights_data, int batches, int output_depth, int output_stride, int accum_depth, const int32* bias_data, @@ -1534,14 +1494,16 @@ inline void ExperimentalShuffledFullyConnectedWorkerImpl( #endif } -// Wraps ExperimentalShuffledFullyConnectedWorkerImpl into a Task class +// Wraps ShuffledFullyConnectedWorkerImpl into a Task class // to allow using gemmlowp's threadpool. -struct ExperimentalShuffledFullyConnectedWorkerTask : gemmlowp::Task { - ExperimentalShuffledFullyConnectedWorkerTask( - const uint8* input_data, const int8* shuffled_weights_data, int batches, - int output_depth, int output_stride, int accum_depth, - const int32* bias_data, int32 output_multiplier, int output_shift, - int16* output_data) +struct ShuffledFullyConnectedWorkerTask : gemmlowp::Task { + ShuffledFullyConnectedWorkerTask(const uint8* input_data, + const int8* shuffled_weights_data, + int batches, int output_depth, + int output_stride, int accum_depth, + const int32* bias_data, + int32 output_multiplier, int output_shift, + int16* output_data) : input_data_(input_data), shuffled_weights_data_(shuffled_weights_data), batches_(batches), @@ -1554,7 +1516,7 @@ struct ExperimentalShuffledFullyConnectedWorkerTask : gemmlowp::Task { output_data_(output_data) {} void Run() override { - ExperimentalShuffledFullyConnectedWorkerImpl( + ShuffledFullyConnectedWorkerImpl( input_data_, shuffled_weights_data_, batches_, output_depth_, output_stride_, accum_depth_, bias_data_, output_multiplier_, output_shift_, output_data_); @@ -1572,15 +1534,14 @@ struct ExperimentalShuffledFullyConnectedWorkerTask : gemmlowp::Task { int16* output_data_; }; -inline void ExperimentalShuffledFullyConnected( +inline void ShuffledFullyConnected( const uint8* input_data, const Dims<4>& input_dims, const uint8* shuffled_weights_data, const Dims<4>& weights_dims, const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier, int output_shift, int32 output_activation_min, int32 output_activation_max, int16* output_data, const Dims<4>& output_dims, uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) { - gemmlowp::ScopedProfilingLabel label( - "ExperimentalShuffledFullyConnected/8bit"); + gemmlowp::ScopedProfilingLabel label("ShuffledFullyConnected/8bit"); (void)gemm_context; // only used in optimized code. TFLITE_DCHECK_EQ(output_activation_min, -32768); TFLITE_DCHECK_EQ(output_activation_max, 32767); @@ -1664,7 +1625,7 @@ inline void ExperimentalShuffledFullyConnected( if (thread_count == 1) { // Single-thread case: do the computation on the current thread, don't // use a threadpool - ExperimentalShuffledFullyConnectedWorkerImpl( + ShuffledFullyConnectedWorkerImpl( shuffled_input_workspace_data, int8_shuffled_weights_data, batches, output_depth, output_depth, accum_depth, bias_data, output_multiplier, output_shift, output_data); @@ -1679,7 +1640,7 @@ inline void ExperimentalShuffledFullyConnected( int row_start = 0; for (int i = 0; i < thread_count; i++) { int row_end = std::min(output_depth, row_start + kRowsPerWorker); - tasks[i] = new ExperimentalShuffledFullyConnectedWorkerTask( + tasks[i] = new ShuffledFullyConnectedWorkerTask( shuffled_input_workspace_data, int8_shuffled_weights_data + row_start * accum_depth, batches, row_end - row_start, output_depth, accum_depth, bias_data + row_start, @@ -1821,8 +1782,8 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, // Use dimensions M and N to construct dims for indexing directly into im2col Dims<4> im2col_dims; - im2col_dims.sizes[0] = col_dims.strides[3]; - im2col_dims.sizes[1] = row_dims.strides[3]; + im2col_dims.sizes[0] = FlatSize(col_dims); + im2col_dims.sizes[1] = FlatSize(row_dims); im2col_dims.sizes[2] = 1; im2col_dims.sizes[3] = 1; ComputeStrides(&im2col_dims); @@ -1831,8 +1792,8 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { - // Each row is an output pixel. Arrange the input data into this row in - // an order we can conveniently multiply with the filter data. + // Each im2col row is an output pixel. Arrange the input data in this + // row in an order we can conveniently multiply with the filter data. int row_offset = Offset(row_dims, out_x, out_y, batch, 0); const int in_x_origin = (out_x * stride_width) - pad_width; const int in_y_origin = (out_y * stride_height) - pad_height; @@ -1848,7 +1809,7 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, T* dst = im2col_data + Offset(im2col_dims, col_offset, row_offset, 0, 0); if ((in_x >= 0) && (in_x < input_width)) { - // Filter pixel is within the input, copy the data. + // Filter pixel is within the input, copy the input data. T const* src = input_data + Offset(input_dims, 0, in_x, in_y, batch); memcpy(dst, src, input_depth * sizeof(T)); @@ -1858,7 +1819,7 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, } } } else { - // Filter row is outside the input, zero out the entire im2col row. + // Filter row is outside the input, zero out the entire filter row. int col_offset = Offset(col_dims, 0, 0, filter_y, 0); T* dst = im2col_data + Offset(im2col_dims, col_offset, row_offset, 0, 0); @@ -1922,7 +1883,7 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims, (void)im2col_dims; gemmlowp::ScopedProfilingLabel label("Conv"); - // A float set to 0x00000000h == 0.0f + // NB: static_cast(0x00000000h) == 0.0f const uint8 float_zero_byte = 0x00; const float* gemm_input_data = nullptr; const Dims<4>* gemm_input_dims = nullptr; @@ -2330,48 +2291,25 @@ void GlobalBatchNormalization(const float* input_data, } } -inline void Relu(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +inline void Relu(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Relu (not fused)"); - const auto input = MapAsVector(input_data, input_dims); - auto output = MapAsVector(output_data, output_dims); + const auto input = MapAsVector(input_data, input_shape); + auto output = MapAsVector(output_data, output_shape); output = input.cwiseMax(0.0f); } -inline void Relu1(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)"); - const int flat_size = MatchingFlatSize(input_dims, output_dims); - for (int i = 0; i < flat_size; ++i) { - const float val = input_data[i]; - const float upper = 1; - const float lower = -1; - const float clamped = val > upper ? upper : val < lower ? lower : val; - output_data[i] = clamped; - } -} - -inline void Relu6(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)"); - const int flat_size = MatchingFlatSize(input_dims, output_dims); - for (int i = 0; i < flat_size; ++i) { - const float val = input_data[i]; - const float upper = 6; - const float lower = 0; - const float clamped = val > upper ? upper : val < lower ? lower : val; - output_data[i] = clamped; - } -} - template -void L2Normalization(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +void L2Normalization(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("L2Normalization"); static_assert(Ac == FusedActivationFunctionType::kNone, ""); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { float squared_l2_norm = 0; for (int c = 0; c < depth; ++c) { @@ -2396,7 +2334,8 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input, ++*output_shift; } TFLITE_DCHECK_GT(input, 0); - const unsigned max_left_shift_bits = __builtin_clz(input) - 1; + const unsigned max_left_shift_bits = + CountLeadingZeros(static_cast(input)) - 1; const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1; *output_shift -= left_shift_bit_pairs; @@ -2434,17 +2373,20 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input, *output_shift *= kReverseShift; } -inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, +inline void L2Normalization(const uint8* input_data, + const RuntimeShape& input_shape, int32 input_zero_point, uint8* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit"); - TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); for (int i = 0; i < outer_size; ++i) { int32 square_l2_norm = 0; for (int c = 0; c < depth; c++) { + // Note that input_data advances by depth in the second pass below. int32 diff = input_data[c] - input_zero_point; square_l2_norm += diff * diff; } @@ -2466,20 +2408,17 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Add(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { +inline void Add(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const float* input1_data, + const RuntimeShape& input2_shape, const float* input2_data, + const RuntimeShape& output_shape, float* output_data) { gemmlowp::ScopedProfilingLabel label("Add"); - TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); int i = 0; - const int size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape); #ifdef USE_NEON - const auto activation_min = vdupq_n_f32(output_activation_min); - const auto activation_max = vdupq_n_f32(output_activation_max); + const auto activation_min = vdupq_n_f32(params.float_activation_min); + const auto activation_max = vdupq_n_f32(params.float_activation_max); for (; i <= size - 16; i += 16) { auto a10 = vld1q_f32(input1_data + i); auto a11 = vld1q_f32(input1_data + i + 4); @@ -2518,29 +2457,26 @@ inline void Add(const float* input1_data, const Dims<4>& input1_dims, for (; i < size; i++) { auto x = input1_data[i] + input2_data[i]; - output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min, - output_activation_max); + output_data[i] = ActivationFunctionWithMinMax( + x, params.float_activation_min, params.float_activation_max); } } // Element-wise add that can often be used for inner loop of broadcast add as // well as the non-broadcast add. -inline void AddElementwise(int size, int left_shift, const uint8* input1_data, - int32 input1_offset, int32 input1_multiplier, - int input1_shift, const uint8* input2_data, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data) { +inline void AddElementwise(int size, const ArithmeticParams& params, + const uint8* input1_data, const uint8* input2_data, + uint8* output_data) { int i = 0; - TFLITE_DCHECK_GT(input1_offset, -256); - TFLITE_DCHECK_GT(input2_offset, -256); - TFLITE_DCHECK_LT(input1_offset, 256); - TFLITE_DCHECK_LT(input2_offset, 256); + TFLITE_DCHECK_GT(params.input1_offset, -256); + TFLITE_DCHECK_GT(params.input2_offset, -256); + TFLITE_DCHECK_LT(params.input1_offset, 256); + TFLITE_DCHECK_LT(params.input2_offset, 256); #ifdef USE_NEON - const auto output_activation_min_vector = vdup_n_u8(output_activation_min); - const auto output_activation_max_vector = vdup_n_u8(output_activation_max); + const auto output_activation_min_vector = + vdup_n_u8(params.quantized_activation_min); + const auto output_activation_max_vector = + vdup_n_u8(params.quantized_activation_max); for (; i <= size - 8; i += 8) { const auto input1_val_original = vld1_u8(input1_data + i); const auto input2_val_original = vld1_u8(input2_data + i); @@ -2549,9 +2485,9 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data, const auto input2_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input2_val_original)); const auto input1_val = - vaddq_s16(input1_val_s16, vdupq_n_s16(input1_offset)); + vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset)); const auto input2_val = - vaddq_s16(input2_val_s16, vdupq_n_s16(input2_offset)); + vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset)); const auto input1_val_high = vget_high_s16(input1_val); const auto input1_val_low = vget_low_s16(input1_val); const auto input2_val_high = vget_high_s16(input2_val); @@ -2560,32 +2496,32 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data, auto x12 = vmovl_s16(input1_val_high); auto x21 = vmovl_s16(input2_val_low); auto x22 = vmovl_s16(input2_val_high); - const auto left_shift_dup = vdupq_n_s32(left_shift); + const auto left_shift_dup = vdupq_n_s32(params.left_shift); x11 = vshlq_s32(x11, left_shift_dup); x12 = vshlq_s32(x12, left_shift_dup); x21 = vshlq_s32(x21, left_shift_dup); x22 = vshlq_s32(x22, left_shift_dup); - x11 = vqrdmulhq_n_s32(x11, input1_multiplier); - x12 = vqrdmulhq_n_s32(x12, input1_multiplier); - x21 = vqrdmulhq_n_s32(x21, input2_multiplier); - x22 = vqrdmulhq_n_s32(x22, input2_multiplier); - const auto input1_shift_dup = vdupq_n_s32(-input1_shift); - const auto input2_shift_dup = vdupq_n_s32(-input2_shift); + x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier); + x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier); + x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier); + x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier); + const auto input1_shift_dup = vdupq_n_s32(params.input1_shift); + const auto input2_shift_dup = vdupq_n_s32(params.input2_shift); x11 = vshlq_s32(x11, input1_shift_dup); x12 = vshlq_s32(x12, input1_shift_dup); x21 = vshlq_s32(x21, input2_shift_dup); x22 = vshlq_s32(x22, input2_shift_dup); auto s1 = vaddq_s32(x11, x21); auto s2 = vaddq_s32(x12, x22); - s1 = vqrdmulhq_n_s32(s1, output_multiplier); - s2 = vqrdmulhq_n_s32(s2, output_multiplier); + s1 = vqrdmulhq_n_s32(s1, params.output_multiplier); + s2 = vqrdmulhq_n_s32(s2, params.output_multiplier); using gemmlowp::RoundingDivideByPOT; - s1 = RoundingDivideByPOT(s1, output_shift); - s2 = RoundingDivideByPOT(s2, output_shift); + s1 = RoundingDivideByPOT(s1, -params.output_shift); + s2 = RoundingDivideByPOT(s2, -params.output_shift); const auto s1_narrowed = vmovn_s32(s1); const auto s2_narrowed = vmovn_s32(s2); const auto s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed), - vdupq_n_s16(output_offset)); + vdupq_n_s16(params.output_offset)); const auto clamped = vmax_u8(output_activation_min_vector, vmin_u8(output_activation_max_vector, vqmovun_s16(s))); @@ -2594,113 +2530,74 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data, #endif // NEON for (; i < size; ++i) { - const int32 input1_val = input1_offset + input1_data[i]; - const int32 input2_val = input2_offset + input2_data[i]; - const int32 shifted_input1_val = input1_val * (1 << left_shift); - const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 input1_val = params.input1_offset + input1_data[i]; + const int32 input2_val = params.input2_offset + input2_data[i]; + const int32 shifted_input1_val = input1_val * (1 << params.left_shift); + const int32 shifted_input2_val = input2_val * (1 << params.left_shift); const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, - kReverseShift * input1_shift); + shifted_input1_val, params.input1_multiplier, params.input1_shift); const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, - kReverseShift * input2_shift); + shifted_input2_val, params.input2_multiplier, params.input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp( - raw_sum, output_multiplier, kReverseShift * output_shift) + - output_offset; - const int32 clamped_output = std::min( - output_activation_max, std::max(output_activation_min, raw_output)); + raw_sum, params.output_multiplier, params.output_shift) + + params.output_offset; + const int32 clamped_output = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, raw_output)); output_data[i] = static_cast(clamped_output); } } -// legacy, for compatibility with old checked-in code -template -void Add(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float* output_data, const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - - Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, - output_activation_max, output_data, output_dims); -} - -template -inline void Add(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, int input2_shift, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } +inline void Add(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const uint8* input1_data, + const RuntimeShape& input2_shape, const uint8* input2_data, + const RuntimeShape& output_shape, uint8* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); gemmlowp::ScopedProfilingLabel label("Add/8bit"); - const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); - TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); - TFLITE_DCHECK_GT(input1_offset, -256); - TFLITE_DCHECK_GT(input2_offset, -256); - TFLITE_DCHECK_LT(input1_offset, 256); - TFLITE_DCHECK_LT(input2_offset, 256); - AddElementwise(flat_size, left_shift, input1_data, input1_offset, - input1_multiplier, input1_shift, input2_data, input2_offset, - input2_multiplier, input2_shift, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_data); + TFLITE_DCHECK_GT(params.input1_offset, -256); + TFLITE_DCHECK_GT(params.input2_offset, -256); + TFLITE_DCHECK_LT(params.input1_offset, 256); + TFLITE_DCHECK_LT(params.input2_offset, 256); + AddElementwise(flat_size, params, input1_data, input2_data, output_data); } -template -inline void Add(const int16* input1_data, const Dims<4>& input1_dims, - int input1_shift, const int16* input2_data, - const Dims<4>& input2_dims, int input2_shift, - int16 output_activation_min, int16 output_activation_max, - int16* output_data, const Dims<4>& output_dims) { +inline void Add(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const int16* input1_data, + const RuntimeShape& input2_shape, const int16* input2_data, + const RuntimeShape& output_shape, int16* output_data) { gemmlowp::ScopedProfilingLabel label("Add/Int16"); - // This is a copy of the reference implementation. We do not currently have a - // properly optimized version. - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, -32768); - TFLITE_DCHECK_EQ(output_activation_max, 32767); - } - - const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); - - TFLITE_DCHECK(input1_shift == 0 || input2_shift == 0); - TFLITE_DCHECK_GE(input1_shift, 0); - TFLITE_DCHECK_GE(input2_shift, 0); + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + + const int input1_shift = params.input1_shift; + const int flat_size = + MatchingFlatSize(output_shape, input1_shape, input2_shape); + const int16 output_activation_min = params.quantized_activation_min; + const int16 output_activation_max = params.quantized_activation_max; + + TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0); + TFLITE_DCHECK_LE(input1_shift, 0); + TFLITE_DCHECK_LE(params.input2_shift, 0); const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data; const int16* shift_input = input1_shift == 0 ? input2_data : input1_data; - const int input_shift = input1_shift == 0 ? input2_shift : input1_shift; + const int input_right_shift = + input1_shift == 0 ? -params.input2_shift : -input1_shift; for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. using F0 = gemmlowp::FixedPoint; F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]); - F0 scaled_input = - F0::FromRaw(gemmlowp::RoundingDivideByPOT(shift_input[i], input_shift)); + F0 scaled_input = F0::FromRaw( + gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift)); F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled); const int16 raw_output = result.raw(); const int16 clamped_output = std::min( @@ -2709,159 +2606,59 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims, } } -template -void Add(const int32* input1_data, const Dims<4>& input1_dims, - const int32* input2_data, const Dims<4>& input2_dims, - int32* output_data, const Dims<4>& output_dims) { +inline void Add(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const int32* input1_data, + const RuntimeShape& input2_shape, const int32* input2_data, + const RuntimeShape& output_shape, int32* output_data) { gemmlowp::ScopedProfilingLabel label("Add/int32"); - TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); - auto input1_map = MapAsVector(input1_data, input1_dims); - auto input2_map = MapAsVector(input2_data, input2_dims); - auto output_map = MapAsVector(output_data, output_dims); - if (AreSameDims(input1_dims, input2_dims)) { + auto input1_map = MapAsVector(input1_data, input1_shape); + auto input2_map = MapAsVector(input2_data, input2_shape); + auto output_map = MapAsVector(output_data, output_shape); + if (input1_shape == input2_shape) { output_map.array() = input1_map.array() + input2_map.array(); - } else if (FlatSize(input2_dims) == 1) { + } else if (input2_shape.FlatSize() == 1) { auto scalar = input2_data[0]; output_map.array() = input1_map.array() + scalar; - } else if (FlatSize(input1_dims) == 1) { + } else if (input1_shape.FlatSize() == 1) { auto scalar = input1_data[0]; output_map.array() = scalar + input2_map.array(); } else { // Should not come here. TFLITE_DCHECK(false); } + output_map = output_map.cwiseMax(params.quantized_activation_min); + output_map = output_map.cwiseMin(params.quantized_activation_max); } -// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary -// dimensionality if the runtime code does a single loop over one dimension -// that handles broadcasting as the base case. The code generator would then -// generate max(D1, D2) nested for loops. -// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from -// reference_ops.h. Once an optimized version is implemented and NdArrayDesc -// is no longer referenced in this file, move NdArrayDesc from types.h to -// reference_ops.h. -template -void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T output_activation_min, T output_activation_max, - T* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastAdd"); - - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - - // In Tensorflow, the dimensions are canonically named (batch_number, row, - // col, channel), with extents (batches, height, width, depth), with the - // trailing dimension changing most rapidly (channels has the smallest stride, - // typically 1 element). - // - // In generated C code, we store arrays with the dimensions reversed. The - // first dimension has smallest stride. - // - // We name our variables by their Tensorflow convention, but generate C code - // nesting loops such that the innermost loop has the smallest stride for the - // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = - ActivationFunctionWithMinMax( - input1_data[SubscriptToIndex(desc1, c, x, y, b)] + - input2_data[SubscriptToIndex(desc2, c, x, y, b)], - output_activation_min, output_activation_max); - } - } - } - } -} - -// legacy, for compatibility with old checked-in code -template -void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T* output_data, const Dims<4>& output_dims) { - T output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - - BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims, - output_activation_min, output_activation_max, output_data, - output_dims); -} - -inline void BroadcastAdd(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastAddGeneric/8bit"); - - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - - // In Tensorflow, the dimensions are canonically named (batch_number, row, - // col, channel), with extents (batches, height, width, depth), with the - // trailing dimension changing most rapidly (channels has the smallest stride, - // typically 1 element). - // - // In generated C code, we store arrays with the dimensions reversed. The - // first dimension has smallest stride. - // - // We name our variables by their Tensorflow convention, but generate C code - // nesting loops such that the innermost loop has the smallest stride for the - // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - const int32 input1_val = - input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; - const int32 input2_val = - input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; - const int32 shifted_input1_val = input1_val * (1 << left_shift); - const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, - kReverseShift * input1_shift); - const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, - kReverseShift * input2_shift); - const int32 raw_sum = scaled_input1_val + scaled_input2_val; - const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - raw_sum, output_multiplier, kReverseShift * output_shift) + - output_offset; - const int32 clamped_output = - std::min(output_activation_max, - std::max(output_activation_min, raw_output)); - output_data[Offset(output_dims, c, x, y, b)] = - static_cast(clamped_output); - } - } - } - } -} - -inline void BroadcastAddFivefold( - int y0, int y1, int y2, int y3, int y4, int left_shift, - const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { +inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params, + const RuntimeShape& unswitched_input1_shape, + const uint8* unswitched_input1_data, + const RuntimeShape& unswitched_input2_shape, + const uint8* unswitched_input2_data, + const RuntimeShape& output_shape, + uint8* output_data) { gemmlowp::ScopedProfilingLabel label("BroadcastAddFivefold/8bit"); + ArithmeticParams switched_params = unswitched_params; + switched_params.input1_offset = unswitched_params.input2_offset; + switched_params.input1_multiplier = unswitched_params.input2_multiplier; + switched_params.input1_shift = unswitched_params.input2_shift; + switched_params.input2_offset = unswitched_params.input1_offset; + switched_params.input2_multiplier = unswitched_params.input1_multiplier; + switched_params.input2_shift = unswitched_params.input1_shift; + + const bool use_unswitched = + unswitched_params.broadcast_category == + tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast; + + const ArithmeticParams& params = + use_unswitched ? unswitched_params : switched_params; + const uint8* input1_data = + use_unswitched ? unswitched_input1_data : unswitched_input2_data; + const uint8* input2_data = + use_unswitched ? unswitched_input2_data : unswitched_input1_data; + // Fivefold nested loops. The second input resets its position for each // iteration of the second loop. The first input resets its position at the // beginning of the fourth loop. The innermost loop is an elementwise add of @@ -2869,82 +2666,29 @@ inline void BroadcastAddFivefold( uint8* output_data_ptr = output_data; const uint8* input1_data_ptr = input1_data; const uint8* input2_data_reset = input2_data; - for (int i4 = 0; i4 < y4; ++i4) { + int y0 = params.broadcast_shape[0]; + int y1 = params.broadcast_shape[1]; + int y2 = params.broadcast_shape[2]; + int y3 = params.broadcast_shape[3]; + int y4 = params.broadcast_shape[4]; + for (int i0 = 0; i0 < y0; ++i0) { const uint8* input2_data_ptr; - for (int i3 = 0; i3 < y3; ++i3) { + for (int i1 = 0; i1 < y1; ++i1) { input2_data_ptr = input2_data_reset; for (int i2 = 0; i2 < y2; ++i2) { - for (int i1 = 0; i1 < y1; ++i1) { - AddElementwise( - y0, left_shift, input1_data_ptr, input1_offset, input1_multiplier, - input1_shift, input2_data_ptr, input2_offset, input2_multiplier, - input2_shift, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_data_ptr); - input2_data_ptr += y0; - output_data_ptr += y0; + for (int i3 = 0; i3 < y3; ++i3) { + AddElementwise(y4, params, input1_data_ptr, input2_data_ptr, + output_data_ptr); + input2_data_ptr += y4; + output_data_ptr += y4; } - input1_data_ptr += y0; + input1_data_ptr += y4; } } input2_data_reset = input2_data_ptr; } } -template -inline void BroadcastAdd(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset, - input1_multiplier, input1_shift, input2_data, input2_dims, - input2_offset, input2_multiplier, input2_shift, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_data, output_dims); -} - -template -inline void BroadcastAddFivefold( - int y0, int y1, int y2, int y3, int y4, int left_shift, - const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - BroadcastAddFivefold(y0, y1, y2, y3, y4, left_shift, input1_data, input1_dims, - input1_offset, input1_multiplier, input1_shift, - input2_data, input2_dims, input2_offset, - input2_multiplier, input2_shift, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_data, output_dims); -} - inline void Mul(const float* input1_data, const Dims<4>& input1_dims, const float* input2_data, const Dims<4>& input2_dims, float output_activation_min, float output_activation_max, @@ -3018,6 +2762,20 @@ void Mul(const float* input1_data, const Dims<4>& input1_dims, output_activation_max, output_data, output_dims); } +inline void Mul(const int32* input1_data, const Dims<4>& input1_dims, + const int32* input2_data, const Dims<4>& input2_dims, + int32 output_activation_min, int32 output_activation_max, + int32* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Mul/int32"); + + const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax( + input1_data[i] * input2_data[i], output_activation_min, + output_activation_max); + } +} + template void Mul(const int32* input1_data, const Dims<4>& input1_dims, const int32* input2_data, const Dims<4>& input2_dims, @@ -3147,78 +2905,127 @@ void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, output_dims); } -inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, - int32 input1_offset, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit"); +// Element-wise mul that can often be used for inner loop of broadcast Mul as +// well as the non-broadcast Mul. +inline void MulElementwise(int size, const ArithmeticParams& params, + const uint8* input1_data, const uint8* input2_data, + uint8* output_data) { + for (int i = 0; i < size; ++i) { + const int32 input1_val = params.input1_offset + input1_data[i]; + const int32 input2_val = params.input2_offset + input2_data[i]; + const int32 unclamped_result = + params.output_offset + + MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val, + params.output_multiplier, + params.output_shift); + const int32 clamped_output = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, unclamped_result)); + output_data[i] = static_cast(clamped_output); + } +} - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); +// Broadcast mul that can often be used for inner loop of broadcast Mul. +inline void MulSimpleBroadcast(int size, const ArithmeticParams& params, + const uint8 broadcast_value, + const uint8* input2_data, uint8* output_data) { + const int32 input1_val = params.input1_offset + broadcast_value; - // In Tensorflow, the dimensions are canonically named (batch_number, row, - // col, channel), with extents (batches, height, width, depth), with the - // trailing dimension changing most rapidly (channels has the smallest stride, - // typically 1 element). - // - // In generated C code, we store arrays with the dimensions reversed. The - // first dimension has smallest stride. - // - // We name our variables by their Tensorflow convention, but generate C code - // nesting loops such that the innermost loop has the smallest stride for the - // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - const int32 input1_val = - input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; - const int32 input2_val = - input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; - const int32 unclamped_result = - output_offset + MultiplyByQuantizedMultiplierSmallerThanOneExp( - input1_val * input2_val, output_multiplier, - kReverseShift * output_shift); - const int32 clamped_output = - std::min(output_activation_max, - std::max(output_activation_min, unclamped_result)); - output_data[Offset(output_dims, c, x, y, b)] = - static_cast(clamped_output); - } - } - } + for (int i = 0; i < size; ++i) { + const int32 input2_val = params.input2_offset + input2_data[i]; + const int32 unclamped_result = + params.output_offset + + MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val, + params.output_multiplier, + params.output_shift); + const int32 clamped_output = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, unclamped_result)); + output_data[i] = static_cast(clamped_output); } } -// legacy, for compatibility with old checked-in code -template -inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, - int32 input1_offset, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - BroadcastMul(input1_data, input1_dims, input1_offset, input2_data, - input2_dims, input2_offset, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_data, output_dims); +inline void Mul(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const uint8* input1_data, + const RuntimeShape& input2_shape, const uint8* input2_data, + const RuntimeShape& output_shape, uint8* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + gemmlowp::ScopedProfilingLabel label("Mul/8bit"); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); + + MulElementwise(flat_size, params, input1_data, input2_data, output_data); } -// TODO(aselle): This is not actually optimized yet. -inline void Div(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); - for (int i = 0; i < flat_size; i++) { - output_data[i] = ActivationFunctionWithMinMax( - input1_data[i] / input2_data[i], output_activation_min, - output_activation_max); +inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params, + const RuntimeShape& unswitched_input1_shape, + const uint8* unswitched_input1_data, + const RuntimeShape& unswitched_input2_shape, + const uint8* unswitched_input2_data, + const RuntimeShape& output_shape, + uint8* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastMulFivefold/8bit"); + + ArithmeticParams switched_params = unswitched_params; + switched_params.input1_offset = unswitched_params.input2_offset; + switched_params.input2_offset = unswitched_params.input1_offset; + + const bool use_unswitched = + unswitched_params.broadcast_category == + tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast; + + const ArithmeticParams& params = + use_unswitched ? unswitched_params : switched_params; + const uint8* input1_data = + use_unswitched ? unswitched_input1_data : unswitched_input2_data; + const uint8* input2_data = + use_unswitched ? unswitched_input2_data : unswitched_input1_data; + + // Fivefold nested loops. The second input resets its position for each + // iteration of the second loop. The first input resets its position at the + // beginning of the fourth loop. The innermost loop is an elementwise Mul of + // sections of the arrays. + uint8* output_data_ptr = output_data; + const uint8* input1_data_ptr = input1_data; + const uint8* input2_data_reset = input2_data; + int y0 = params.broadcast_shape[0]; + int y1 = params.broadcast_shape[1]; + int y2 = params.broadcast_shape[2]; + int y3 = params.broadcast_shape[3]; + int y4 = params.broadcast_shape[4]; + if (y4 > 1) { + for (int i0 = 0; i0 < y0; ++i0) { + const uint8* input2_data_ptr; + for (int i1 = 0; i1 < y1; ++i1) { + input2_data_ptr = input2_data_reset; + for (int i2 = 0; i2 < y2; ++i2) { + for (int i3 = 0; i3 < y3; ++i3) { + MulElementwise(y4, params, input1_data_ptr, input2_data_ptr, + output_data_ptr); + input2_data_ptr += y4; + output_data_ptr += y4; + } + input1_data_ptr += y4; + } + } + input2_data_reset = input2_data_ptr; + } + } else { + for (int i0 = 0; i0 < y0; ++i0) { + const uint8* input2_data_ptr; + for (int i1 = 0; i1 < y1; ++i1) { + input2_data_ptr = input2_data_reset; + for (int i2 = 0; i2 < y2; ++i2) { + MulSimpleBroadcast(y3, params, *input1_data_ptr, input2_data_ptr, + output_data_ptr); + input2_data_ptr += y3; + output_data_ptr += y3; + ++input1_data_ptr; + } + } + input2_data_reset = input2_data_ptr; + } } } @@ -3268,224 +3075,81 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims, } // TODO(aselle): This is not actually optimized yet. -inline void Sub(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Sub"); - const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); +inline void SubNonBroadcast(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, + float* output_data) { + gemmlowp::ScopedProfilingLabel label("SubNonBroadcast"); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( - input1_data[i] - input2_data[i], output_activation_min, - output_activation_max); + input1_data[i] - input2_data[i], params.float_activation_min, + params.float_activation_max); } } -// TODO(jiawen): We can implement BroadcastSub on buffers of arbitrary -// dimensionality if the runtime code does a single loop over one dimension -// that handles broadcasting as the base case. The code generator would then -// generate max(D1, D2) nested for loops. -// TODO(benoitjacob): BroadcastSub is intentionally duplicated from -// reference_ops.h. Once an optimized version is implemented and NdArrayDesc -// is no longer referenced in this file, move NdArrayDesc from types.h to -// reference_ops.h. -template -void BroadcastSub(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T output_activation_min, T output_activation_max, - T* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastSub"); - - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - - // In Tensorflow, the dimensions are canonically named (batch_number, row, - // col, channel), with extents (batches, height, width, depth), with the - // trailing dimension changing most rapidly (channels has the smallest stride, - // typically 1 element). - // - // In generated C code, we store arrays with the dimensions reversed. The - // first dimension has smallest stride. - // - // We name our variables by their Tensorflow convention, but generate C code - // nesting loops such that the innermost loop has the smallest stride for the - // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = - ActivationFunctionWithMinMax( - input1_data[SubscriptToIndex(desc1, c, x, y, b)] - - input2_data[SubscriptToIndex(desc2, c, x, y, b)], - output_activation_min, output_activation_max); - } - } - } +inline void SubWithActivation(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const int32* input1_data, + const RuntimeShape& input2_shape, + const int32* input2_data, + const RuntimeShape& output_shape, + int32* output_data) { + gemmlowp::ScopedProfilingLabel label("SubWithActivation/int32"); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, input2_shape); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax( + input1_data[i] - input2_data[i], params.quantized_activation_min, + params.quantized_activation_max); } } -inline void BroadcastSub(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastSub/8bit"); - - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - - // In Tensorflow, the dimensions are canonically named (batch_number, row, - // col, channel), with extents (batches, height, width, depth), with the - // trailing dimension changing most rapidly (channels has the smallest stride, - // typically 1 element). - // - // In generated C code, we store arrays with the dimensions reversed. The - // first dimension has smallest stride. - // - // We name our variables by their Tensorflow convention, but generate C code - // nesting loops such that the innermost loop has the smallest stride for the - // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - const int32 input1_val = - input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; - const int32 input2_val = - input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; - const int32 shifted_input1_val = input1_val * (1 << left_shift); - const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, - kReverseShift * input1_shift); - const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, - kReverseShift * input2_shift); - const int32 raw_sub = scaled_input1_val - scaled_input2_val; - const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - raw_sub, output_multiplier, kReverseShift * output_shift) + - output_offset; - const int32 clamped_output = - std::min(output_activation_max, - std::max(output_activation_min, raw_output)); - output_data[Offset(output_dims, c, x, y, b)] = - static_cast(clamped_output); - } - } - } +inline void SubWithActivation(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, + float* output_data) { + gemmlowp::ScopedProfilingLabel label("SubWithActivation/float"); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, input2_shape); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax( + input1_data[i] - input2_data[i], params.float_activation_min, + params.float_activation_max); } } -template -void Concatenation(int concat_dim, const Scalar* const* input_data, - const Dims<4>* const* input_dims, int inputs_count, - Scalar* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Concatenation"); - int concat_size = 0; - for (int i = 0; i < inputs_count; i++) { - for (int j = 0; j < 4; j++) { - if (j != concat_dim) { - MatchingArraySize(*input_dims[i], j, output_dims, j); - } - } - concat_size += ArraySize(*input_dims[i], concat_dim); - } - TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); - // for now we dont have a model with a Concatenation - // with fused activation function. - TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); - int outer_size = 1; - for (int i = concat_dim + 1; i < 4; i++) { - outer_size *= output_dims.sizes[i]; - } - Scalar* output_ptr = output_data; - for (int k = 0; k < outer_size; k++) { - for (int i = 0; i < inputs_count; ++i) { - const int copy_size = - input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim]; - memcpy(output_ptr, input_data[i] + k * copy_size, - copy_size * sizeof(Scalar)); - output_ptr += copy_size; - } - } -} +template +void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape, + const T* input1_data, const RuntimeShape& input2_shape, + const T* input2_data, const RuntimeShape& output_shape, + T* output_data) { + gemmlowp::ScopedProfilingLabel label("Sub"); -// TODO(prabhumk): This is the same as the reference implementation. -// TODO(prabhumk): The quantized implementation of concatentation isn't fully -// quantized as it takes scale as a floating point value. This should be fixed -// when optimizng this routine further. -inline void Concatenation(int concat_dim, const uint8* const* input_data, - const Dims<4>* const* input_dims, - const int32* input_zeropoint, - const float* input_scale, int inputs_count, - uint8* output_data, const Dims<4>& output_dims, - const int32 output_zeropoint, - const float output_scale) { - // The arguments input_zeropoint and input_scale are expected to be an array - // that have the quantization parameters for all the inputs to the concat - // operator. - gemmlowp::ScopedProfilingLabel label("Concatenation"); - TFLITE_DCHECK_GT(inputs_count, 1); - int concat_size = 0; - for (int i = 0; i < inputs_count; i++) { - for (int j = 0; j < 4; j++) { - if (j != concat_dim) { - MatchingArraySize(*input_dims[i], j, output_dims, j); - } - } - concat_size += ArraySize(*input_dims[i], concat_dim); - } - TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); - int outer_size = 1; - for (int i = concat_dim + 1; i < 4; i++) { - outer_size *= output_dims.sizes[i]; - } - const float inverse_output_scale = 1.f / output_scale; - uint8* output_ptr = output_data; - for (int k = 0; k < outer_size; k++) { - for (int i = 0; i < inputs_count; ++i) { - const int copy_size = - input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim]; - const uint8* input_ptr = input_data[i] + k * copy_size; - if (input_zeropoint[i] == output_zeropoint && - input_scale[i] == output_scale) { - memcpy(output_ptr, input_ptr, copy_size); - } else { - const float scale = input_scale[i] * inverse_output_scale; - const float bias = -input_zeropoint[i] * scale; - for (int j = 0; j < copy_size; ++j) { - const int32_t value = - static_cast(round(input_ptr[j] * scale + bias)) + - output_zeropoint; - output_ptr[j] = - static_cast(std::max(std::min(255, value), 0)); - } - } - output_ptr += copy_size; - } + auto input1_map = MapAsVector(input1_data, input1_shape); + auto input2_map = MapAsVector(input2_data, input2_shape); + auto output_map = MapAsVector(output_data, output_shape); + if (input1_shape == input2_shape) { + output_map.array() = input1_map.array() - input2_map.array(); + } else if (input1_shape.FlatSize() == 1) { + auto scalar = input1_data[0]; + output_map.array() = scalar - input2_map.array(); + } else if (input2_shape.FlatSize() == 1) { + auto scalar = input2_data[0]; + output_map.array() = input1_map.array() - scalar; + } else { + BroadcastSub4DSlow(params, input1_shape, input1_data, input2_shape, + input2_data, output_shape, output_data); } } -template -void DepthConcatenation(const Scalar* const* input_data, - const Dims<4>* const* input_dims, int inputs_count, - Scalar* output_data, const Dims<4>& output_dims) { - Concatenation(0, input_data, input_dims, inputs_count, - output_data, output_dims); -} - inline void LstmCell(const float* input_data, const Dims<4>& input_dims, const float* prev_activ_data, const Dims<4>& prev_activ_dims, const float* weights_data, @@ -3848,23 +3512,24 @@ inline int NodeOffset(int b, int h, int w, int height, int width) { return (b * height + h) * width + w; } -inline void AveragePool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int kwidth, int kheight, - float output_activation_min, - float output_activation_max, float* output_data, - const Dims<4>& output_dims) { +inline void AveragePool(const PoolParams& params, + const RuntimeShape& input_shape, + const float* input_data, + const RuntimeShape& output_shape, float* output_data) { gemmlowp::ScopedProfilingLabel label("AveragePool"); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; // TODO(benoitjacob) make this a proper reference impl without Eigen! - const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); - auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); + auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape); // TODO(benoitjacob) get rid of the dynamic memory allocation here! Eigen::VectorXf out_count(out_mat.cols()); out_count.setZero(); @@ -3875,12 +3540,15 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims, for (int w = 0; w < input_width; ++w) { // (h_start, h_end) * (w_start, w_end) is the range that the input // vector projects to. - int hpad = h + pad_height; - int wpad = w + pad_width; - int h_start = - (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1; + int hpad = h + params.padding_values.height; + int wpad = w + params.padding_values.width; + int h_start = (hpad < params.filter_height) + ? 0 + : (hpad - params.filter_height) / stride_height + 1; int h_end = std::min(hpad / stride_height + 1, output_height); - int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1; + int w_start = (wpad < params.filter_width) + ? 0 + : (wpad - params.filter_width) / stride_width + 1; int w_end = std::min(wpad / stride_width + 1, output_width); // compute elementwise sum for (int ph = h_start; ph < h_end; ++ph) { @@ -3898,69 +3566,44 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims, TFLITE_DCHECK_GT(out_count.minCoeff(), 0); out_mat.array().rowwise() /= out_count.transpose().array(); - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < output_height; ++y) { - for (int x = 0; x < output_width; ++x) { - for (int c = 0; c < depth; ++c) { - output_data[Offset(output_dims, c, x, y, b)] = - ActivationFunctionWithMinMax( - output_data[Offset(output_dims, c, x, y, b)], - output_activation_min, output_activation_max); - } - } - } + const int flat_size = output_shape.FlatSize(); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax(output_data[i], + params.float_activation_min, + params.float_activation_max); } } -// legacy, for compatibility with old checked-in code -template -void AveragePool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int kwidth, int kheight, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - - AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, kwidth, kheight, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, - int filter_height, float* output_data, - const Dims<4>& output_dims) { - AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { +inline void AveragePool(const PoolParams& params, + const RuntimeShape& input_shape, + const uint8* input_data, + const RuntimeShape& output_shape, uint8* output_data) { gemmlowp::ScopedProfilingLabel label("AveragePool/8bit"); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; const int filter_x_start = std::max(0, -in_x_origin); const int filter_x_end = - std::min(filter_width, input_width - in_x_origin); + std::min(params.filter_width, input_width - in_x_origin); const int filter_y_start = std::max(0, -in_y_origin); const int filter_y_end = - std::min(filter_height, input_height - in_y_origin); + std::min(params.filter_height, input_height - in_y_origin); const int filter_count = (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start); // 1280 required by Inception v3 @@ -3969,11 +3612,12 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, uint16 acc[kAccBufferMaxSize]; memset(acc, 0, depth * sizeof(acc[0])); const uint8* input_ptr = - input_data + input_dims.strides[1] * in_x_origin + - input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch; + input_data + + depth * (in_x_origin + + input_width * (in_y_origin + input_height * batch)); for (int fy = filter_y_start; fy < filter_y_end; fy++) { - const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] + - filter_x_start * input_dims.strides[1]; + const uint8* input_row_ptr = + input_ptr + depth * (fy * input_width + filter_x_start); for (int fx = filter_x_start; fx < filter_x_end; fx++) { int channel = 0; #ifdef USE_NEON @@ -4004,21 +3648,21 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, } } uint8* output_ptr = - output_data + Offset(output_dims, 0, out_x, out_y, batch); + output_data + Offset(output_shape, batch, out_y, out_x, 0); int channel = 0; #ifdef USE_NEON -#define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \ - if (filter_count == FILTER_COUNT) { \ - for (; channel <= depth - 8; channel += 8) { \ - uint16 buf[8]; \ - for (int i = 0; i < 8; i++) { \ - buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \ - } \ - uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \ - buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max)); \ - buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min)); \ - vst1_u8(output_ptr + channel, buf8); \ - } \ +#define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \ + if (filter_count == FILTER_COUNT) { \ + for (; channel <= depth - 8; channel += 8) { \ + uint16 buf[8]; \ + for (int i = 0; i < 8; i++) { \ + buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \ + } \ + uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \ + buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max)); \ + buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min)); \ + vst1_u8(output_ptr + channel, buf8); \ + } \ } AVGPOOL_DIVIDING_BY(9) AVGPOOL_DIVIDING_BY(15) @@ -4029,15 +3673,15 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, buf[i] = (acc[channel + i] + filter_count / 2) / filter_count; } uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); - buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max)); - buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min)); + buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max)); + buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min)); vst1_u8(output_ptr + channel, buf8); } #endif for (; channel < depth; ++channel) { uint16 a = (acc[channel] + filter_count / 2) / filter_count; - a = std::max(a, output_activation_min); - a = std::min(a, output_activation_max); + a = std::max(a, params.quantized_activation_min); + a = std::min(a, params.quantized_activation_max); output_ptr[channel] = static_cast(a); } } @@ -4045,54 +3689,22 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void AveragePool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, - int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -inline void MaxPool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int kwidth, int kheight, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { +inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& output_shape, + float* output_data) { gemmlowp::ScopedProfilingLabel label("MaxPool"); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - - const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); - auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; + + const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); + auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape); // Prefill the output to minimum representable float value out_mat.setConstant(std::numeric_limits::lowest()); for (int b = 0; b < batches; ++b) { @@ -4100,12 +3712,15 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims, for (int w = 0; w < input_width; ++w) { // (h_start, h_end) * (w_start, w_end) is the range that the input // vector projects to. - int hpad = h + pad_height; - int wpad = w + pad_width; - int h_start = - (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1; + int hpad = h + params.padding_values.height; + int wpad = w + params.padding_values.width; + int h_start = (hpad < params.filter_height) + ? 0 + : (hpad - params.filter_height) / stride_height + 1; int h_end = std::min(hpad / stride_height + 1, output_height); - int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1; + int w_start = (wpad < params.filter_width) + ? 0 + : (wpad - params.filter_width) / stride_width + 1; int w_end = std::min(wpad / stride_width + 1, output_width); // compute elementwise sum for (int ph = h_start; ph < h_end; ++ph) { @@ -4120,78 +3735,55 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims, } } } - - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < output_height; ++y) { - for (int x = 0; x < output_width; ++x) { - for (int c = 0; c < depth; ++c) { - output_data[Offset(output_dims, c, x, y, b)] = - ActivationFunctionWithMinMax( - output_data[Offset(output_dims, c, x, y, b)], - output_activation_min, output_activation_max); - } - } - } + const int flat_size = output_shape.FlatSize(); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax(output_data[i], + params.float_activation_min, + params.float_activation_max); } } -// legacy, for compatibility with old checked-in code -template -void MaxPool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int kwidth, int kheight, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, kwidth, kheight, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { +inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& output_shape, + uint8* output_data) { gemmlowp::ScopedProfilingLabel label("MaxPool/8bit"); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; const int filter_x_start = std::max(0, -in_x_origin); const int filter_x_end = - std::min(filter_width, input_width - in_x_origin); + std::min(params.filter_width, input_width - in_x_origin); const int filter_y_start = std::max(0, -in_y_origin); const int filter_y_end = - std::min(filter_height, input_height - in_y_origin); + std::min(params.filter_height, input_height - in_y_origin); // 2048 required by Inception v3 static constexpr int kAccBufferMaxSize = 2048; TFLITE_DCHECK_LE(depth, kAccBufferMaxSize); uint8 acc[kAccBufferMaxSize]; memset(acc, 0, depth * sizeof(acc[0])); const uint8* input_ptr = - input_data + input_dims.strides[1] * in_x_origin + - input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch; + input_data + + depth * (in_x_origin + + input_width * (in_y_origin + input_height * batch)); for (int fy = filter_y_start; fy < filter_y_end; fy++) { - const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] + - filter_x_start * input_dims.strides[1]; + const uint8* input_row_ptr = + input_ptr + depth * (fy * input_width + filter_x_start); for (int fx = filter_x_start; fx < filter_x_end; fx++) { int channel = 0; #ifdef USE_NEON @@ -4217,26 +3809,26 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, } } uint8* output_ptr = - output_data + Offset(output_dims, 0, out_x, out_y, batch); + output_data + Offset(output_shape, batch, out_y, out_x, 0); int channel = 0; #ifdef USE_NEON for (; channel <= depth - 16; channel += 16) { uint8x16_t a = vld1q_u8(acc + channel); - a = vminq_u8(a, vdupq_n_u8(output_activation_max)); - a = vmaxq_u8(a, vdupq_n_u8(output_activation_min)); + a = vminq_u8(a, vdupq_n_u8(params.quantized_activation_max)); + a = vmaxq_u8(a, vdupq_n_u8(params.quantized_activation_min)); vst1q_u8(output_ptr + channel, a); } for (; channel <= depth - 8; channel += 8) { uint8x8_t a = vld1_u8(acc + channel); - a = vmin_u8(a, vdup_n_u8(output_activation_max)); - a = vmax_u8(a, vdup_n_u8(output_activation_min)); + a = vmin_u8(a, vdup_n_u8(params.quantized_activation_max)); + a = vmax_u8(a, vdup_n_u8(params.quantized_activation_min)); vst1_u8(output_ptr + channel, a); } #endif for (; channel < depth; ++channel) { uint8 a = acc[channel]; - a = std::max(a, output_activation_min); - a = std::min(a, output_activation_max); + a = std::max(a, params.quantized_activation_min); + a = std::min(a, params.quantized_activation_max); output_ptr[channel] = static_cast(a); } } @@ -4244,53 +3836,23 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -inline void L2Pool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { +inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& output_shape, + float* output_data) { gemmlowp::ScopedProfilingLabel label("L2Pool"); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; // Actually carry out L2 Pool. Code is written in forward mode: we go through // the input values once, and write to all the pooled regions that it maps to. - const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); - auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); + auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape); Eigen::VectorXf in_square(in_mat.rows()); Eigen::VectorXf out_count(out_mat.cols()); out_count.setZero(); @@ -4301,15 +3863,17 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims, for (int w = 0; w < input_width; ++w) { // (h_start, h_end) * (w_start, w_end) is the range that the input // vector projects to. - const int hpad = h + pad_height; - const int wpad = w + pad_width; - const int h_start = (hpad < filter_height) - ? 0 - : (hpad - filter_height) / stride_height + 1; + const int hpad = h + params.padding_values.height; + const int wpad = w + params.padding_values.width; + const int h_start = + (hpad < params.filter_height) + ? 0 + : (hpad - params.filter_height) / stride_height + 1; const int h_end = std::min(hpad / stride_height + 1, output_height); - const int w_start = (wpad < filter_width) - ? 0 - : (wpad - filter_width) / stride_width + 1; + const int w_start = + (wpad < params.filter_width) + ? 0 + : (wpad - params.filter_width) / stride_width + 1; const int w_end = std::min(wpad / stride_width + 1, output_width); // pre-compute square const int in_offset = w + input_width * (h + input_height * b); @@ -4330,28 +3894,13 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims, out_count = out_count.array().inverse(); out_mat = (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt(); -} -// legacy, for compatibility with old checked-in code -template -void L2Pool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); + const int flat_size = output_shape.FlatSize(); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax(output_data[i], + params.float_activation_min, + params.float_activation_max); + } } inline void LocalResponseNormalization(const float* input_data, @@ -4399,14 +3948,14 @@ inline void LocalResponseNormalization(const float* input_data, } } -inline void Softmax(const float* input_data, const Dims<4>& input_dims, +inline void Softmax(const float* input_data, const RuntimeShape& input_shape, float beta, float* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Softmax"); - MatchingFlatSize(input_dims, output_dims); + MatchingFlatSize(input_shape, output_shape); - const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); - auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); + auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape); // Compute the exponential first, removing the max coefficient for numerical // stability. out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * beta; @@ -4418,10 +3967,10 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims, out_mat.array().rowwise() *= scale; } -inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, +inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape, int32 input_beta_multiplier, int32 input_beta_left_shift, int diff_min, uint8* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { // The representation chosen for the input to the exp() function is Q5.26. // We need to leave extra space since values that we skip might be as large as // -32 before multiplying by input_beta_multiplier, and therefore as large as @@ -4435,8 +3984,11 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPoint0 = gemmlowp::FixedPoint; gemmlowp::ScopedProfilingLabel label("Softmax/8bit"); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int b = 0; b < outer_size; ++b) { const uint8* input_data_ptr = input_data + b * depth; @@ -4553,7 +4105,7 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, // perform a division by the above-computed sum-of-exponentials. int32 fixed_sum_of_exps = sum_of_exps.raw(); int headroom_plus_one = - __builtin_clz(static_cast(fixed_sum_of_exps)); + CountLeadingZeros(static_cast(fixed_sum_of_exps)); // This is the number of bits to the left of the binary point above 1.0. // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and // no later adjustment will be needed. @@ -4626,11 +4178,14 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, // TODO(myenik): This is the same as the reference implementation, not actually // optimized yet. -inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("LogSoftmax"); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { const float* block_input_data = input_data + i * depth; @@ -4696,7 +4251,7 @@ log_x_for_x_greater_than_or_equal_to_1_impl( // required shift "ourselves" instead of using, say, Rescale. FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw()); // z_a_pow_2 = input_integer_bits - z_a_headroom; - int z_a_headroom_plus_1 = __builtin_clz(static_cast(z_a.raw())); + int z_a_headroom_plus_1 = CountLeadingZeros(static_cast(z_a.raw())); FixedPoint0 r_a_tmp = SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1)); const int32 r_a_raw = @@ -4711,7 +4266,7 @@ log_x_for_x_greater_than_or_equal_to_1_impl( // z_b is treated like z_a, but premultiplying by sqrt(0.5). FixedPoint0 z_b = z_a * sqrt_half; - int z_b_headroom = __builtin_clz(static_cast(z_b.raw())) - 1; + int z_b_headroom = CountLeadingZeros(static_cast(z_b.raw())) - 1; const int32 r_b_raw = SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom); const FixedPointAccum z_b_pow_2_adj = SaturatingSub( @@ -4771,11 +4326,11 @@ log_x_for_x_greater_than_or_equal_to_1( } // Currently just a copy of the reference code. -inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, +inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape, int32 input_multiplier, int32 input_left_shift, int32 reverse_scaling_divisor, int32 reverse_scaling_right_shift, int diff_min, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("LogSoftmax/Uint8"); // The representation chosen for the input to the exp() function is Q5.26. // We need to leave extra space since values that we skip might be as large as @@ -4790,8 +4345,11 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPointAccum = gemmlowp::FixedPoint; using FixedPoint0 = gemmlowp::FixedPoint; - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { const uint8* block_input_data = input_data + i * depth; @@ -4855,21 +4413,21 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +inline void Logistic(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Logistic"); - auto input_map = MapAsVector(input_data, input_dims); - auto output_map = MapAsVector(output_data, output_dims); + auto input_map = MapAsVector(input_data, input_shape); + auto output_map = MapAsVector(output_data, output_shape); output_map.array() = input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op()); } -inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, +inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape, int32 input_zero_point, int32 input_range_radius, int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Logistic/Uint8"); - const int size = MatchingFlatSize(input_dims, output_dims); + const int size = MatchingFlatSize(input_shape, output_shape); int c = 0; #ifdef USE_NEON @@ -5001,10 +4559,10 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const int16* input_data, const Dims<4>& input_dims, - int16* output_data, const Dims<4>& output_dims) { +inline void Logistic(const int16* input_data, const RuntimeShape& input_shape, + int16* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Logistic/Int16"); - const int flat_size = MatchingFlatSize(output_dims, input_dims); + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { } @@ -5061,21 +4619,21 @@ inline void Logistic(const int16* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +inline void Tanh(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Tanh"); - auto input_map = MapAsVector(input_data, input_dims); - auto output_map = MapAsVector(output_data, output_dims); + auto input_map = MapAsVector(input_data, input_shape); + auto output_map = MapAsVector(output_data, output_shape); output_map.array() = input_map.array().tanh(); } -inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, +inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape, int32 input_zero_point, int32 input_range_radius, int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { // Note that this is almost the exact same code as in Logistic(). gemmlowp::ScopedProfilingLabel label("Tanh"); - const int size = MatchingFlatSize(input_dims, output_dims); + const int size = MatchingFlatSize(input_shape, output_shape); int c = 0; int32_t output_zero_point = 128; @@ -5216,16 +4774,16 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const int16* input_data, const Dims<4>& input_dims, +inline void Tanh(const int16* input_data, const RuntimeShape& input_shape, int input_left_shift, int16* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Tanh/Int16"); // Support for shifts is limited until we have a parameterized version of // SaturatingRoundingMultiplyByPOT(). TFLITE_DCHECK_GE(input_left_shift, 0); TFLITE_DCHECK_LE(input_left_shift, 1); - const int flat_size = MatchingFlatSize(output_dims, input_dims); + const int flat_size = MatchingFlatSize(input_shape, output_shape); int c = 0; const int16* input_data_ptr = input_data; @@ -5316,49 +4874,6 @@ inline void Tanh(const int16* input_data, const Dims<4>& input_dims, } } -inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, - int32 zero_point, double scale, float* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Dequantize"); - const int flat_size = MatchingFlatSize(output_dims, input_dims); - for (int i = 0; i < flat_size; ++i) { - int32 val = input_data[i]; - float result = static_cast(scale * (val - zero_point)); - output_data[i] = result; - } -} - -inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, - float rmin, float rmax, int num_bits, float* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("FakeQuant"); - - // 0 should always be a representable value. Let's assume that the initial - // min,max range contains 0. - TFLITE_DCHECK_LE(rmin, 0.0f); - TFLITE_DCHECK_GE(rmax, 0.0f); - TFLITE_DCHECK_LT(rmin, rmax); - - // Code matches tensorflow's FakeQuantWithMinMaxArgsFunctor. - int quant_min = 0; - int quant_max = (1 << num_bits) - 1; - float nudged_min, nudged_max, nudged_scale; - NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min, - &nudged_max, &nudged_scale); - const float inv_nudged_scale = 1.0f / nudged_scale; - - const int flat_size = MatchingFlatSize(output_dims, input_dims); - for (int i = 0; i < flat_size; ++i) { - const float src_val = input_data[i]; - const float clamped = std::min(nudged_max, std::max(nudged_min, src_val)); - const float clamped_shifted = clamped - nudged_min; - const float dst_val = - TfLiteRound(clamped_shifted * inv_nudged_scale) * nudged_scale + - nudged_min; - output_data[i] = dst_val; - } -} - template inline void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data, const Dims<4>& output_dims) { @@ -5376,26 +4891,6 @@ inline void Floor(const float* input_data, const Dims<4>& input_dims, output_map.array() = Eigen::floor(input_map.array()); } -template -inline void Gather(const T* input_data, const Dims<4>& input_dims, - int input_rank, const int32* coords_data, - const Dims<4>& coords_dims, T* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Gather"); - - TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]); - int stride = input_dims.strides[input_rank - 1]; - T* out = output_data; - - for (int i = 0; i < coords_dims.sizes[0]; i++) { - TFLITE_DCHECK_GE(coords_data[i], 0); - TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]); - const T* in = input_data + coords_data[i] * stride; - memcpy(out, in, sizeof(T) * stride); - out += stride; - } -} - #ifdef USE_NEON inline void ResizeBilinearKernel(const float* input_ptr, int32 depth, float scale, float* output_ptr) { @@ -5857,55 +5352,6 @@ inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, output_data, output_dims, /*align_corners=*/false); } -template -inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, - const int32* block_shape_data, - const Dims<4>& block_shape_dims, - const int32* paddings_data, - const Dims<4>& paddings_dims, T* output_data, - const Dims<4>& output_dims) { - // Unoptimized - Straight copy from reference ops. - gemmlowp::ScopedProfilingLabel label("SpaceToBatchND"); - - const int output_batch_size = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int input_batch_size = ArraySize(input_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int depth = ArraySize(input_dims, 0); - const int block_shape_height = block_shape_data[0]; - const int block_shape_width = block_shape_data[1]; - const int padding_top = paddings_data[0]; - const int padding_left = paddings_data[2]; - - for (int out_b = 0; out_b < output_batch_size; ++out_b) { - int input_batch = out_b % input_batch_size; - int shift_w = (out_b / input_batch_size) % block_shape_width; - int shift_h = (out_b / input_batch_size) / block_shape_width; - for (int out_h = 0; out_h < output_height; ++out_h) { - for (int out_w = 0; out_w < output_width; ++out_w) { - T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b); - if (out_h * block_shape_height + shift_h < padding_top || - out_h * block_shape_height + shift_h >= - padding_top + input_height || - out_w * block_shape_width + shift_w < padding_left || - out_w * block_shape_width + shift_w >= padding_left + input_width) { - memset(out, 0, depth * sizeof(T)); - } else { - const T* in = - input_data + - Offset(input_dims, 0, - (out_w * block_shape_width + shift_w) - padding_left, - (out_h * block_shape_height + shift_h) - padding_top, - input_batch); - memcpy(out, in, depth * sizeof(T)); - } - } - } - } -} - // Helper methods for BatchToSpaceND. // `spatial_index_dim` specifies post-crop offset index in this spatial // dimension, i.e. spatial offset introduced by flattening batch to spatial @@ -6108,54 +5554,6 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims, output_dims, 0); } -// UNOPTIMIZED COPY of StridedSlice from reference_ops.h. -template -inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, - int begin_mask, int end_mask, - const std::vector& start_indices, - const std::vector& stop_indices, - const std::vector& strides, T* output_data, - const Dims<4>& output_dims) { - TFLITE_DCHECK_EQ(start_indices.size(), 4); - TFLITE_DCHECK_EQ(stop_indices.size(), 4); - TFLITE_DCHECK_EQ(strides.size(), 4); - const int start_b = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 3); - const int stop_b = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 3); - const int start_h = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 2); - const int stop_h = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 2); - const int start_w = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 1); - const int stop_w = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 1); - const int start_d = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 0); - const int stop_d = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 0); - - T* out_ptr = output_data; - for (int in_b = start_b; - !strided_slice::LoopCondition(in_b, stop_b, strides[3]); - in_b += strides[3]) { - for (int in_h = start_h; - !strided_slice::LoopCondition(in_h, stop_h, strides[2]); - in_h += strides[2]) { - for (int in_w = start_w; - !strided_slice::LoopCondition(in_w, stop_w, strides[1]); - in_w += strides[1]) { - for (int in_d = start_d; - !strided_slice::LoopCondition(in_d, stop_d, strides[0]); - in_d += strides[0]) { - *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; - } - } - } - } -} - template inline void Slice(const T* input_data, const Dims<4>& input_dims, const std::vector& begin, const std::vector& size, @@ -6190,98 +5588,6 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims, } } -template -inline void Mean(const T* input_data, const Dims<4>& input_dims, - const std::vector& reduction_indices, T* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Mean"); - const int output_batch = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int output_depth = ArraySize(output_dims, 0); - - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - - // The current implementation only supports simultaneous reduction over - // width and height. - TFLITE_DCHECK_EQ(reduction_indices.size(), 2); - TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) || - (reduction_indices[0] == 2 && reduction_indices[1] == 1)); - TFLITE_DCHECK_EQ(output_height, 1); - TFLITE_DCHECK_EQ(output_width, 1); - - for (int out_b = 0; out_b < output_batch; ++out_b) { - for (int out_d = 0; out_d < output_depth; ++out_d) { - float value = 0; - for (int in_h = 0; in_h < input_height; ++in_h) { - for (int in_w = 0; in_w < input_width; ++in_w) { - value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)]; - } - } - output_data[Offset(output_dims, out_d, 0, 0, out_b)] = - value / (input_width * input_height); - } - } -} - -template -void GenericBroadcastSub(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("GenericBroadcastSub"); - - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - - // In Tensorflow, the dimensions are canonically named (batch_number, row, - // col, channel), with extents (batches, height, width, depth), with the - // trailing dimension changing most rapidly (channels has the smallest stride, - // typically 1 element). - // - // In generated C code, we store arrays with the dimensions reversed. The - // first dimension has smallest stride. - // - // We name our variables by their Tensorflow convention, but generate C code - // nesting loops such that the innermost loop has the smallest stride for the - // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = - input1_data[SubscriptToIndex(desc1, c, x, y, b)] - - input2_data[SubscriptToIndex(desc2, c, x, y, b)]; - } - } - } - } -} - -template -void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, - const Dims<4>& input2_dims, T* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Sub"); - - auto input1_map = MapAsVector(input1_data, input1_dims); - auto input2_map = MapAsVector(input2_data, input2_dims); - auto output_map = MapAsVector(output_data, output_dims); - if (AreSameDims(input1_dims, input2_dims)) { - output_map.array() = input1_map.array() - input2_map.array(); - } else if (FlatSize(input1_dims) == 1) { - auto scalar = input1_data[0]; - output_map.array() = scalar - input2_map.array(); - } else if (FlatSize(input2_dims) == 1) { - auto scalar = input2_data[0]; - output_map.array() = input1_map.array() - scalar; - } else { - GenericBroadcastSub(input1_data, input1_dims, input2_data, input2_dims, - output_data, output_dims); - } -} - template void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, T* output_data, @@ -6304,130 +5610,84 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, output_map.array() = input1_map.array().max(max_value); } -template -void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, - T2* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("ArgMax"); - - // The current ArgMax implemention can only determine the index of the maximum - // value in the last dimension. So the axis argument is ignored. - - // For ArgMax, the number of output dimensions = (number of input dimensions - - // 1). For the sake of simplicity, the output dimensions are equal to the - // input dimensions here. We enforce the constraint that the last dimension - // must always be 1. - TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = ArraySize(input_dims, 0); - for (int i = 0; i < outer_size; ++i) { - auto max_value = *input_data; - ++input_data; - int max_index = 0; - for (int d = 1; d < depth; ++d) { - const auto& curr_value = *input_data; - if (curr_value > max_value) { - max_value = curr_value; - max_index = d; - } - ++input_data; - } - *output_data = max_index; - ++output_data; - } -} - template -void Transpose(const T* input, const Dims<4>& input_dims, T* output, - const Dims<4>& output_dims, const int* permuted_axes) { - int out_sizes[4]; - // Compute the inverse permutation array so we can do an output centered - // transpose. Also, check to make sure output_dims is matching input_dims. - for (int k = 0; k < 4; k++) { - out_sizes[k] = - MatchingArraySize(input_dims, permuted_axes[k], output_dims, k); - } - - // Naive transpose loop (iterate on output index and compute input index). - int o[4]; // loop index (on output). - int i[4]; - for (o[3] = 0; o[3] < out_sizes[3]; o[3]++) { - i[permuted_axes[3]] = o[3]; - for (o[2] = 0; o[2] < out_sizes[2]; o[2]++) { - i[permuted_axes[2]] = o[2]; - for (o[1] = 0; o[1] < out_sizes[1]; o[1]++) { - i[permuted_axes[1]] = o[1]; - for (o[0] = 0; o[0] < out_sizes[0]; o[0]++) { - i[permuted_axes[0]] = o[0]; - output[Offset(output_dims, o)] = input[Offset(input_dims, i)]; - } - } - } - } -} +void TransposeIm2col(const T* input_data, const Dims<4>& input_dims, + const Dims<4>& filter_dims, int stride_width, + int stride_height, int pad_width, int pad_height, + const Dims<4>& output_dims, uint8 zero_byte, + T* im2col_data) { + gemmlowp::ScopedProfilingLabel label("TransposeIm2col"); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK(im2col_data); -inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, - const float* filter_data, const Dims<4>& filter_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, float* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("TransposeConv"); - // THIS FUNCTION IS A COPY FROM reference_ops.h. - // To optimize, start by using the conv code with transposed weights for the - // case of stride_height = stride_width = 1. const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); - const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); const int input_height = ArraySize(input_dims, 2); const int input_width = ArraySize(input_dims, 1); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3); const int filter_height = ArraySize(filter_dims, 2); const int filter_width = ArraySize(filter_dims, 1); const int output_height = ArraySize(output_dims, 2); const int output_width = ArraySize(output_dims, 1); + MatchingArraySize(output_dims, 0, filter_dims, 0); // output_depth - // Although transpose convolution simplifies to convolution with transposed - // weights for strides of 1, non-unitary striding complicates matters. To - // keep this reference implementation as clear as possible, we use a "scatter" - // access pattern, where we loop through all the input elements, computing - // their influence on the output, rather than looping through the output - // elements in the typical "gather" access pattern of a conv. We therefore - // must initialize the output array to zero. - for (int batch = 0; batch < batches; ++batch) { - for (int out_y = 0; out_y < output_height; ++out_y) { - for (int out_x = 0; out_x < output_width; ++out_x) { - for (int out_channel = 0; out_channel < output_depth; ++out_channel) { - output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] = - 0.0f; - } - } - } - } + // Construct the MxN sized im2col matrix. + // The rows M, are sub-ordered B x H x W + Dims<4> row_dims; + row_dims.sizes[0] = output_width; + row_dims.sizes[1] = output_height; + row_dims.sizes[2] = batches; + row_dims.sizes[3] = 1; + ComputeStrides(&row_dims); + + // The columns, N, are sub-ordered Kh x Kw x Din + Dims<4> col_dims; + col_dims.sizes[0] = input_depth; + col_dims.sizes[1] = filter_width; + col_dims.sizes[2] = filter_height; + col_dims.sizes[3] = 1; + ComputeStrides(&col_dims); + + // Use dimensions M and N to construct dims for indexing directly into im2col + Dims<4> im2col_dims; + im2col_dims.sizes[0] = FlatSize(col_dims); + im2col_dims.sizes[1] = FlatSize(row_dims); + im2col_dims.sizes[2] = 1; + im2col_dims.sizes[3] = 1; + ComputeStrides(&im2col_dims); + + // Build the im2col matrix by looping through all the input pixels, + // computing their influence on the output, rather than looping through all + // the output pixels. We therefore must initialize the im2col array to zero. + // This is potentially inefficient because we subsequently overwrite bytes + // set here. However, in practice memset is very fast and costs negligible. + memset(im2col_data, zero_byte, FlatSize(im2col_dims) * sizeof(T)); - // Loop through input elements one at a time. + // Loop through the output batches for (int batch = 0; batch < batches; ++batch) { + // Loop through input pixels one at a time. for (int in_y = 0; in_y < input_height; ++in_y) { for (int in_x = 0; in_x < input_width; ++in_x) { - for (int in_channel = 0; in_channel < input_depth; ++in_channel) { - // Loop through the output elements it will influence - const int out_x_origin = (in_x * stride_width) - pad_width; - const int out_y_origin = (in_y * stride_height) - pad_height; - for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + // Loop through the output pixels it will influence + const int out_x_origin = (in_x * stride_width) - pad_width; + const int out_y_origin = (in_y * stride_height) - pad_height; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + const int out_y = out_y_origin + filter_y; + // Is output pixel within height bounds? + if ((out_y >= 0) && (out_y < output_height)) { for (int filter_x = 0; filter_x < filter_width; ++filter_x) { - for (int out_channel = 0; out_channel < output_depth; - ++out_channel) { - // Compute output element location - const int out_x = out_x_origin + filter_x; - const int out_y = out_y_origin + filter_y; - // We cannot accumulate out of bounds - if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) && - (out_y < output_height)) { - float input_value = input_data[Offset(input_dims, in_channel, - in_x, in_y, batch)]; - float filter_value = - filter_data[Offset(filter_dims, in_channel, filter_x, - filter_y, out_channel)]; - output_data[Offset(output_dims, out_channel, out_x, out_y, - batch)] += input_value * filter_value; - } + const int out_x = out_x_origin + filter_x; + // Is output pixel within width bounds? + if ((out_x >= 0) && (out_x < output_width)) { + // Copy the input elements of this pixel + T const* src = + input_data + Offset(input_dims, 0, in_x, in_y, batch); + T* dst = im2col_data + + Offset(im2col_dims, + Offset(col_dims, 0, filter_x, filter_y, 0), + Offset(row_dims, out_x, out_y, batch, 0), 0, 0); + memcpy(dst, src, input_depth * sizeof(T)); } } } @@ -6437,6 +5697,31 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, } } +inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + gemmlowp::ScopedProfilingLabel label("TransposeConv"); + + // Note we could use transposed weights with forward conv for unstrided + // cases. But we are already getting good performance with this code as-is. + TFLITE_DCHECK(im2col_data); + TransposeIm2col(input_data, input_dims, filter_dims, stride_width, + stride_height, pad_width, pad_height, output_dims, 0, + im2col_data); + + const auto im2col_matrix_map = + MapAsMatrixWithFirstDimAsRows(im2col_data, im2col_dims); + const auto filter_matrix_map = + MapAsMatrixWithLastDimAsCols(filter_data, filter_dims); + auto output_matrix_map = + MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map); +} + } // namespace optimized_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h index f14667090f5c3867c7992211272063239f3b92aa..010b40b901e2821c36367da7e2c987fac113de11 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -19,6 +19,10 @@ limitations under the License. // structure. #include "tensorflow/contrib/lite/builtin_op_data.h" +#if defined(_MSC_VER) +#define __restrict__ __restrict +#endif + #ifndef USE_NEON #if defined(__ARM_NEON__) || defined(__ARM_NEON) #define USE_NEON @@ -124,6 +128,12 @@ void PortableCopyVector(const float* vector, int v_size, float* result); // Fill vector with 0.f. void PortableZeroVector(float* vector, int v_size); +// Multiply all elements of vector with a scalar. +void PortableVectorScalarMultiply(const int8_t* vector, int v_size, float scale, + float* result); +void NeonVectorScalarMultiply(const int8_t* vector, int v_size, float scale, + float* result); + // Limit a float input f between +abs_limit and -abs_limit. float PortableClip(float f, float abs_limit); diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc index 57ee859115cddbcbccae24ff639e848340d8e2ee..f882f9910e0c65d69eb5a86886bae4d3c881e6ab 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #include #include #include @@ -108,12 +109,12 @@ int CalculateInputRadius(int input_integer_bits, int input_left_shift) { void NudgeQuantizationRange(const float min, const float max, const int quant_min, const int quant_max, float* nudged_min, float* nudged_max, - float* scale) { + float* nudged_scale) { // This code originates from tensorflow/core/kernels/fake_quant_ops_functor.h. const float quant_min_float = static_cast(quant_min); const float quant_max_float = static_cast(quant_max); - *scale = (max - min) / (quant_max_float - quant_min_float); - const float zero_point_from_min = quant_min_float - min / *scale; + *nudged_scale = (max - min) / (quant_max_float - quant_min_float); + const float zero_point_from_min = quant_min_float - min / *nudged_scale; uint16 nudged_zero_point; if (zero_point_from_min < quant_min_float) { nudged_zero_point = static_cast(quant_min); @@ -122,8 +123,37 @@ void NudgeQuantizationRange(const float min, const float max, } else { nudged_zero_point = static_cast(TfLiteRound(zero_point_from_min)); } - *nudged_min = (quant_min_float - nudged_zero_point) * (*scale); - *nudged_max = (quant_max_float - nudged_zero_point) * (*scale); + *nudged_min = (quant_min_float - nudged_zero_point) * (*nudged_scale); + *nudged_max = (quant_max_float - nudged_zero_point) * (*nudged_scale); +} + +void FakeQuantizeArray(const float nudged_scale, const float nudged_min, + const float nudged_max, const float* input_data, + float* output_data, const float size) { + // This code originates from tensorflow/core/kernels/fake_quant_ops_functor.h. + const float inv_nudged_scale = 1.0f / nudged_scale; + + for (int i = 0; i < size; i++) { + const float src_val = input_data[i]; + const float clamped = std::min(nudged_max, std::max(nudged_min, src_val)); + const float clamped_shifted = clamped - nudged_min; + const float dst_val = + TfLiteRound(clamped_shifted * inv_nudged_scale) * nudged_scale + + nudged_min; + output_data[i] = dst_val; + } +} + +bool CheckedLog2(const float x, int* log2_result) { + // Using TfLiteRound instead of std::round and std::log instead of + // std::log2 to work around these fuctions being missing in a toolchain + // used in some TensorFlow tests as of May 2018. + const float x_log2 = std::log(x) * (1.0f / std::log(2.0f)); + const float x_log2_rounded = TfLiteRound(x_log2); + const float x_log2_fracpart = x_log2 - x_log2_rounded; + + *log2_result = static_cast(x_log2_rounded); + return std::abs(x_log2_fracpart) < 1e-3; } } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h index 182ee782c76fcccedc99327d47805b49bfb8580d..9ee4a47fbb5bba1a409830f99c7b9ba967325a0a 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -28,8 +28,9 @@ namespace tflite { // Given the min and max values of a float array, return // reasonable quantization parameters to use for this array. template -QuantizationParams ChooseQuantizationParams(double rmin, double rmax) { - const T qmin = std::numeric_limits::min(); +QuantizationParams ChooseQuantizationParams(double rmin, double rmax, + bool narrow_range) { + const T qmin = std::numeric_limits::min() + (narrow_range ? 1 : 0); const T qmax = std::numeric_limits::max(); const double qmin_double = qmin; const double qmax_double = qmax; @@ -97,6 +98,11 @@ QuantizationParams ChooseQuantizationParams(double rmin, double rmax) { return quantization_params; } +template +QuantizationParams ChooseQuantizationParams(double rmin, double rmax) { + return ChooseQuantizationParams(rmin, rmax, false); +} + // Converts a floating-point number to an integer. For all inputs x where // static_cast(x) is legal according to the C++ standard, the result // is identical to that cast (i.e. the result is x with its fractional part @@ -216,7 +222,20 @@ int CalculateInputRadius(int input_integer_bits, int input_left_shift); // Outputs nudged_min, nudged_max, nudged_scale. void NudgeQuantizationRange(const float min, const float max, const int quant_min, const int quant_max, - float* nudged_min, float* nudged_max, float* scale); + float* nudged_min, float* nudged_max, + float* nudged_scale); + +// Fake quantizes (quantizes and dequantizes) input_data using the scale, +// nudged_min, and nudged_max from NudgeQuantizationRange. This matches the code +// in TensorFlow's FakeQuantizeWithMinMaxVarsFunctor. +void FakeQuantizeArray(const float nudged_scale, const float nudged_min, + const float nudged_max, const float* input_data, + float* output_data, const float size); + +// If x is approximately a power of two (with any positive or negative +// exponent), stores that exponent (i.e. log2(x)) in *log2_result, otherwise +// returns false. +bool CheckedLog2(const float x, int* log2_result); } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc index 94773b47d3817d7ed7240f74545ad04e7fa4bd52..00fc3e91dc90254ca68d637941e5a2482e2832a8 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc @@ -130,22 +130,22 @@ void RunSafeCastTests() { } TEST(QuantizationUtilTest, SafeCast) { - RunSafeCastTests(); - RunSafeCastTests(); - RunSafeCastTests(); - RunSafeCastTests(); - RunSafeCastTests(); - RunSafeCastTests(); - RunSafeCastTests(); - RunSafeCastTests(); - RunSafeCastTests(); - RunSafeCastTests(); - RunSafeCastTests(); - RunSafeCastTests(); - RunSafeCastTests(); - RunSafeCastTests(); - RunSafeCastTests(); - RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); + RunSafeCastTests(); } // Example taken from http://www.tensorflow.org/performance/quantization diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..57afeeeea0a809c246db5b49ebfbe925452f6896 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h @@ -0,0 +1,634 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_ + +#include +#include + +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { + +namespace reference_ops { + +inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) { + return RuntimeShape( + {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]}); +} + +template +void L2Normalization(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + L2Normalization(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, uint8* output_data, + const Dims<4>& output_dims) { + L2Normalization(input_data, DimsToShape(input_dims), input_zero_point, + output_data, DimsToShape(output_dims)); +} + +inline void Relu(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Relu(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Relu1(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Relu1(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Relu6(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Relu6(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +template +inline void Add(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, int input2_shift, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + constexpr int kReverseShift = -1; + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + + tflite::ArithmeticParams op_params; + op_params.left_shift = left_shift; + op_params.input1_offset = input1_offset; + op_params.input1_multiplier = input1_multiplier; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_offset = input2_offset; + op_params.input2_multiplier = input2_multiplier; + op_params.input2_shift = kReverseShift * input2_shift; + op_params.output_offset = output_offset; + op_params.output_multiplier = output_multiplier; + op_params.output_shift = kReverseShift * output_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + Add(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +template +void Add(const int32* input1_data, const Dims<4>& input1_dims, + const int32* input2_data, const Dims<4>& input2_dims, + int32* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Add/int32"); + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + + tflite::ArithmeticParams op_params; + op_params.quantized_activation_min = std::numeric_limits::min(); + op_params.quantized_activation_max = std::numeric_limits::max(); + Add(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +template +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + constexpr int kReverseShift = -1; + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + + tflite::ArithmeticParams op_params; + op_params.left_shift = left_shift; + op_params.input1_offset = input1_offset; + op_params.input1_multiplier = input1_multiplier; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_offset = input2_offset; + op_params.input2_multiplier = input2_multiplier; + op_params.input2_shift = kReverseShift * input2_shift; + op_params.output_offset = output_offset; + op_params.output_multiplier = output_multiplier; + op_params.output_shift = kReverseShift * output_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); +} + +template +void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + tflite::ArithmeticParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + Add(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +template +void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, + T* output_data, const Dims<4>& output_dims) { + tflite::ArithmeticParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); +} + +template +inline void BroadcastAddFivefold( + int y0, int y1, int y2, int y3, int y4, int left_shift, + const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + constexpr int kReverseShift = -1; + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + tflite::ArithmeticParams op_params; + op_params.broadcast_category = + tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast; + op_params.left_shift = left_shift; + op_params.input1_offset = input1_offset; + op_params.input1_multiplier = input1_multiplier; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_offset = input2_offset; + op_params.input2_multiplier = input2_multiplier; + op_params.input2_shift = kReverseShift * input2_shift; + op_params.output_offset = output_offset; + op_params.output_multiplier = output_multiplier; + op_params.output_shift = kReverseShift * output_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + op_params.broadcast_shape[4] = y0; + op_params.broadcast_shape[3] = y1; + op_params.broadcast_shape[2] = y2; + op_params.broadcast_shape[1] = y3; + op_params.broadcast_shape[0] = y4; + BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); +} + +// legacy, for compatibility with old checked-in code +template +void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + T output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims, + output_activation_min, output_activation_max, output_data, + output_dims); +} + +template +inline void Add(const int16* input1_data, const Dims<4>& input1_dims, + int input1_shift, const int16* input2_data, + const Dims<4>& input2_dims, int input2_shift, + int16 output_activation_min, int16 output_activation_max, + int16* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, -32768); + TFLITE_DCHECK_EQ(output_activation_max, 32767); + } + + tflite::ArithmeticParams op_params; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_shift = kReverseShift * input2_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + Add(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +inline void Sub(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(FusedActivationFunctionType::kNone, + &output_activation_min, &output_activation_max); + tflite::ArithmeticParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + Sub(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +template +void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, + const Dims<4>& input2_dims, T* output_data, + const Dims<4>& output_dims) { + tflite::ArithmeticParams op_params; + op_params.quantized_activation_min = std::numeric_limits::min(); + op_params.quantized_activation_max = std::numeric_limits::max(); + Sub(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +inline void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + tflite::PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = kheight; + params.filter_width = kwidth; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.float_activation_min = output_activation_min; + params.float_activation_max = output_activation_max; + AveragePool(params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + BroadcastMul4DSlow( + input1_data, input1_dims, input1_offset, input2_data, input2_dims, + input2_offset, output_offset, output_multiplier, + // + kReverseShift * output_shift, + // + output_activation_min, output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + BroadcastMul(input1_data, input1_dims, input1_offset, input2_data, + input2_dims, input2_offset, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, + output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + tflite::PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = filter_height; + params.filter_width = filter_width; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.quantized_activation_min = output_activation_min; + params.quantized_activation_max = output_activation_max; + AveragePool(params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + tflite::PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = kheight; + params.filter_width = kwidth; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.float_activation_min = output_activation_min; + params.float_activation_max = output_activation_max; + MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims), + output_data); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = filter_height; + params.filter_width = filter_width; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.quantized_activation_min = output_activation_min; + params.quantized_activation_max = output_activation_max; + MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims), + output_data); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = filter_height; + params.filter_width = filter_width; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.float_activation_min = output_activation_min; + params.float_activation_max = output_activation_max; + L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims), + output_data); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void Softmax(const float* input_data, const Dims<4>& input_dims, + float beta, float* output_data, + const Dims<4>& output_dims) { + Softmax(input_data, DimsToShape(input_dims), beta, output_data, + DimsToShape(output_dims)); +} + +inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_beta_multiplier, int32 input_beta_left_shift, + int diff_min, uint8* output_data, + const Dims<4>& output_dims) { + Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier, + input_beta_left_shift, diff_min, output_data, + DimsToShape(output_dims)); +} + +inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + LogSoftmax(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_multiplier, int32 input_left_shift, + int32 reverse_scaling_divisor, + int32 reverse_scaling_right_shift, int diff_min, + uint8* output_data, const Dims<4>& output_dims) { + LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier, + input_left_shift, reverse_scaling_divisor, + reverse_scaling_right_shift, diff_min, output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), input_zero_point, + input_range_radius, input_multiplier, input_left_shift, output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const int16* input_data, const Dims<4>& input_dims, + int16* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), input_zero_point, + input_range_radius, input_multiplier, input_left_shift, output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const int16* input_data, const Dims<4>& input_dims, + int input_left_shift, int16* output_data, + const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data, + DimsToShape(output_dims)); +} + +} // namespace reference_ops +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc index f8c6f341f7e61529bbbac592f9caf115f6121e0c..aa93e857d7a9f98aa06e91ff3d6c743b00b17137 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -14,12 +14,17 @@ limitations under the License. ==============================================================================*/ #include #include +#include #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/round.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" +#if defined(_MSC_VER) +#define __restrict__ __restrict +#endif + namespace tflite { namespace tensor_utils { @@ -37,24 +42,23 @@ bool PortableIsZeroVector(const float* vector, int v_size) { } void PortableSymmetricQuantizeFloats(const float* values, const int size, - int8_t* quantized_values, - float* __restrict__ min, - float* __restrict__ max, - float* __restrict__ scaling_factor) { + int8_t* quantized_values, float* min_value, + float* max_value, float* scaling_factor) { auto minmax = std::minmax_element(values, values + size); - *min = *minmax.first; - *max = *minmax.second; + *min_value = *minmax.first; + *max_value = *minmax.second; const int kScale = 127; - const float range = std::max(std::abs(*min), std::abs(*max)); + const float range = std::max(std::abs(*min_value), std::abs(*max_value)); if (range == 0) { memset(quantized_values, 0, size * sizeof(int8_t)); *scaling_factor = 1; return; } - *scaling_factor = kScale / range; + *scaling_factor = range / kScale; + const float scaling_factor_inv = 1.0f / *scaling_factor; for (int i = 0; i < size; ++i) { const int32_t quantized_value = - static_cast(TfLiteRound(*scaling_factor * values[i])); + static_cast(TfLiteRound(values[i] * scaling_factor_inv)); // Clamp: just in case some odd numeric offset. quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value)); } @@ -69,10 +73,12 @@ void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, for (int b = 0; b < n_batch; b++) { const float* matrix_ptr = matrix; for (int r = 0; r < m_rows; r++) { + float dot_prod = 0.0f; const float* vector_in_batch = vector + b * m_cols; for (int c = 0; c < m_cols; c++) { - *result_in_batch += *matrix_ptr++ * *vector_in_batch++; + dot_prod += *matrix_ptr++ * *vector_in_batch++; } + *result_in_batch += dot_prod; result_in_batch += result_stride; } } @@ -80,25 +86,26 @@ void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, void PortableMatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, - const int8_t* __restrict__ vectors, - const float* __restrict__ scaling_factors, int n_batch, - float* __restrict__ result, int result_stride) { + const int8_t* __restrict__ vectors, const float* scaling_factors, + int n_batch, float* __restrict__ result, int result_stride) { int batch, row, col; for (batch = 0; batch < n_batch; ++batch, vectors += m_cols) { - const float batch_scaling_factor_inv = 1.0 / scaling_factors[batch]; + const float batch_scaling_factor = scaling_factors[batch]; // Get the address of the first row. const int8_t* row_ptr = matrix; for (row = 0; row < m_rows; ++row, result += result_stride) { // Initialize the dot product sum for the row to 0. int32_t dotprod = 0; +#if defined(__GNUC__) // Prefetch the row to cache. __builtin_prefetch(row_ptr, 0 /* prefetch for read */, 3 /* temporal locality */); +#endif // For every block of 16 8-bit elements (128-bit register) from each row. for (col = 0; col < m_cols; ++col, ++row_ptr) { dotprod += (*row_ptr) * (vectors[col]); } // for col - *result += (dotprod * batch_scaling_factor_inv); + *result += (dotprod * batch_scaling_factor); } // for row } // for batch } @@ -194,6 +201,13 @@ void PortableZeroVector(float* vector, int v_size) { memset(vector, 0, v_size * sizeof(float)); } +void PortableVectorScalarMultiply(const int8_t* vector, const int v_size, + const float scale, float* result) { + for (int v = 0; v < v_size; ++v) { + *result++ = scale * *vector++; + } +} + void PortableClipVector(const float* vector, int v_size, float abs_limit, float* result) { for (int v = 0; v < v_size; v++) { diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h index d2e1fecd25cf3d11d3daffcc566dc1d5df97128c..a375aaffa67ac19975cc8e371be11547d689dc72 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -19,6 +19,10 @@ limitations under the License. // structure. #include "tensorflow/contrib/lite/builtin_op_data.h" +#if defined(_MSC_VER) +#define __restrict__ __restrict +#endif + namespace tflite { namespace tensor_utils { @@ -28,8 +32,8 @@ float PortableClip(float f, float abs_limit); bool PortableIsZeroVector(const float* vector, int v_size); void PortableSymmetricQuantizeFloats(const float* values, const int size, - int8_t* quantized_values, float* min, - float* max, float* scaling_factor); + int8_t* quantized_values, float* min_value, + float* max_value, float* scaling_factor); // Multiply a matrix by a batch vector, and store results in a batch-size // vector. @@ -96,6 +100,10 @@ void PortableSub1Vector(const float* vector, int v_size, float* result); // Fill vector with 0.f. void PortableZeroVector(float* vector, int v_size); +// Multiply all elements of vector with a scalar. +void PortableVectorScalarMultiply(const int8_t* vector, int v_size, float scale, + float* result); + // Clip elements of a vector using a abs_limit value. void PortableClipVector(const float* vector, int v_size, float abs_limit, float* result); @@ -199,6 +207,12 @@ void ZeroVector(float* vector, int v_size) { PortableZeroVector(vector, v_size); } +// Multiply all elements of vector with a scalar. +void VectorScalarMultiply(const int8_t* vector, int v_size, float scale, + float* result) { + PortableVectorScalarMultiply(vector, v_size, scale, result); +} + void ClipVector(const float* vector, int v_size, float abs_limit, float* result) { PortableClipVector(vector, v_size, abs_limit, result); diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 07c215ebcb6f65b33df7086a2ffc1c3d1f04c605..2816fb5cb2fd5f46c35665923507b17470c50772 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -158,98 +158,6 @@ SaturatingRoundingMultiplyByPOTParam( SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent)); } -// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE -// BROADCASTING. -// -// NdArrayDesc describes the shape and memory layout of an N-dimensional -// rectangular array of numbers. -// -// NdArrayDesc is basically identical to Dims defined in types.h. -// However, as Dims is to be deprecated, this class exists as an adaptor -// to enable simple unoptimized implementations of element-wise broadcasting -// operations. -template -struct NdArrayDesc { - // The "extent" of each dimension. Indices along dimension d must be in the - // half-open interval [0, extents[d]). - int extents[N]; - - // The number of *elements* (not bytes) between consecutive indices of each - // dimension. - int strides[N]; -}; - -// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING -// ELEMENT-WISE BROADCASTING. -// -// Same as Offset(), except takes as NdArrayDesc instead of Dims. -inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2, - int i3) { - TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]); - TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]); - TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]); - TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]); - return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] + - i3 * desc.strides[3]; -} - -// Given the dimensions of the operands for an element-wise binary broadcast, -// adjusts them so that they can be directly iterated over with simple loops. -// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and -// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr. -// -// This function assumes that the two input shapes are compatible up to -// broadcasting and the shorter one has already been prepended with 1s to be the -// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64), -// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that -// Dims refer to shapes in reverse order. In this case, input0_dims will be -// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1). -// -// When two shapes are compatible up to broadcasting, for each dimension d, -// the input extents are either equal, or one of them is 1. -// -// This function performs the following for each dimension d: -// - If the extents are equal, then do nothing since the loop that walks over -// both of the input arrays is correct. -// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1 -// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows -// array0 to be referenced *at any index* in dimension d and still access the -// same slice. -template -inline void NdArrayDescsForElementwiseBroadcast(const Dims& input0_dims, - const Dims& input1_dims, - NdArrayDesc* desc0_out, - NdArrayDesc* desc1_out) { - TFLITE_DCHECK(desc0_out != nullptr); - TFLITE_DCHECK(desc1_out != nullptr); - - // Copy dims to desc. - for (int i = 0; i < N; ++i) { - desc0_out->extents[i] = input0_dims.sizes[i]; - desc0_out->strides[i] = input0_dims.strides[i]; - desc1_out->extents[i] = input1_dims.sizes[i]; - desc1_out->strides[i] = input1_dims.strides[i]; - } - - // Walk over each dimension. If the extents are equal do nothing. - // Otherwise, set the desc with extent 1 to have extent equal to the other and - // stride 0. - for (int i = 0; i < N; ++i) { - const int extent0 = ArraySize(input0_dims, i); - const int extent1 = ArraySize(input1_dims, i); - if (extent0 != extent1) { - if (extent0 == 1) { - desc0_out->strides[i] = 0; - desc0_out->extents[i] = extent1; - } else { - TFLITE_DCHECK_EQ(extent1, 1); - desc1_out->strides[i] = 0; - desc1_out->extents[i] = extent0; - } - } - } -} - inline void Conv(const float* input_data, const Dims<4>& input_dims, const float* filter_data, const Dims<4>& filter_dims, const float* bias_data, const Dims<4>& bias_dims, @@ -414,8 +322,8 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, if (bias_data) { acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; } - acc = MultiplyByQuantizedMultiplierSmallerThanOneExp( - acc, output_multiplier, kReverseShift * output_shift); + acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, + kReverseShift * output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); @@ -638,8 +546,8 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, if (bias_data) { acc += bias_data[Offset(bias_dims, out_c, 0, 0, 0)]; } - acc = MultiplyByQuantizedMultiplierSmallerThanOneExp( - acc, output_multiplier, kReverseShift * output_shift); + acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, + kReverseShift * output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); @@ -697,7 +605,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, } } -inline void ExperimentalShuffledFullyConnected( +inline void ShuffledFullyConnected( const uint8* input_data, const Dims<4>& input_dims, const uint8* shuffled_weights_data, const Dims<4>& weights_dims, const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier, @@ -914,9 +822,9 @@ void GlobalBatchNormalization(const float* input_data, } } -inline void Relu(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(input_dims, output_dims); +inline void Relu(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const float val = input_data[i]; const float lower = 0; @@ -925,9 +833,10 @@ inline void Relu(const float* input_data, const Dims<4>& input_dims, } } -inline void Relu1(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(input_dims, output_dims); +inline void Relu1(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)"); + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const float val = input_data[i]; const float upper = 1; @@ -937,9 +846,10 @@ inline void Relu1(const float* input_data, const Dims<4>& input_dims, } } -inline void Relu6(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(input_dims, output_dims); +inline void Relu6(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)"); + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const float val = input_data[i]; const float upper = 6; @@ -949,12 +859,28 @@ inline void Relu6(const float* input_data, const Dims<4>& input_dims, } } +inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data, + const RuntimeShape& input_shape, uint8* output_data, + const RuntimeShape& output_shape) { + gemmlowp::ScopedProfilingLabel label("Quantized ReluX (not fused)"); + const int flat_size = MatchingFlatSize(input_shape, output_shape); + for (int i = 0; i < flat_size; ++i) { + const uint8 val = input_data[i]; + const uint8 clamped = + val > max_value ? max_value : val < min_value ? min_value : val; + output_data[i] = clamped; + } +} + template -void L2Normalization(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +void L2Normalization(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { static_assert(Ac == FusedActivationFunctionType::kNone, ""); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { float squared_l2_norm = 0; for (int c = 0; c < depth; ++c) { @@ -977,7 +903,8 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input, ++*output_shift; } TFLITE_DCHECK_GT(input, 0); - const unsigned max_left_shift_bits = __builtin_clz(input) - 1; + const unsigned max_left_shift_bits = + CountLeadingZeros(static_cast(input)) - 1; const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1; *output_shift -= left_shift_bit_pairs; @@ -1015,16 +942,19 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input, *output_shift *= kReverseShift; } -inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, +inline void L2Normalization(const uint8* input_data, + const RuntimeShape& input_shape, int32 input_zero_point, uint8* output_data, - const Dims<4>& output_dims) { - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); + const RuntimeShape& output_shape) { + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); for (int i = 0; i < outer_size; ++i) { int32 square_l2_norm = 0; for (int c = 0; c < depth; c++) { - int32 diff = - input_data[Offset(input_dims, c, i, 0, 0)] - input_zero_point; + int32 diff = input_data[depth * i + c] - input_zero_point; square_l2_norm += diff * diff; } int32 inv_l2norm_multiplier; @@ -1033,136 +963,119 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, &inv_l2norm_shift); for (int c = 0; c < depth; c++) { - int32 diff = - input_data[Offset(input_dims, c, i, 0, 0)] - input_zero_point; + int32 diff = input_data[depth * i + c] - input_zero_point; int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp( 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); int32 unclamped_output_val = 128 + rescaled_diff; int32 output_val = std::min(255, std::max(0, unclamped_output_val)); - output_data[Offset(output_dims, c, i, 0, 0)] = - static_cast(output_val); + output_data[depth * i + c] = static_cast(output_val); } } } -inline void Add(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); +template +inline void Add(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const T* input1_data, + const RuntimeShape& input2_shape, const T* input2_data, + const RuntimeShape& output_shape, T* output_data) { + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( - input1_data[i] + input2_data[i], output_activation_min, - output_activation_max); + input1_data[i] + input2_data[i], params.quantized_activation_min, + params.quantized_activation_max); } } -// legacy, for compatibility with old checked-in code -template -void Add(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float* output_data, const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - - Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, - output_activation_max, output_data, output_dims); -} - -template -inline void Add(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, int input2_shift, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - const int batches = - MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); - const int height = - MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); - const int width = - MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); - const int depth = - MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - for (int c = 0; c < depth; ++c) { - const int32 input1_val = - input1_offset + input1_data[Offset(input1_dims, c, x, y, b)]; - const int32 input2_val = - input2_offset + input2_data[Offset(input2_dims, c, x, y, b)]; - const int32 shifted_input1_val = input1_val * (1 << left_shift); - const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, - kReverseShift * input1_shift); - const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, - kReverseShift * input2_shift); - const int32 raw_sum = scaled_input1_val + scaled_input2_val; - const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - raw_sum, output_multiplier, kReverseShift * output_shift) + - output_offset; - const int32 clamped_output = - std::min(output_activation_max, - std::max(output_activation_min, raw_output)); - output_data[Offset(output_dims, c, x, y, b)] = - static_cast(clamped_output); - } - } - } +inline void Add(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const float* input1_data, + const RuntimeShape& input2_shape, const float* input2_data, + const RuntimeShape& output_shape, float* output_data) { + const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape); + for (int i = 0; i < size; i++) { + auto x = input1_data[i] + input2_data[i]; + output_data[i] = ActivationFunctionWithMinMax( + x, params.float_activation_min, params.float_activation_max); } } -template -inline void Add(const int16* input1_data, const Dims<4>& input1_dims, - int input1_shift, const int16* input2_data, - const Dims<4>& input2_dims, int input2_shift, - int16 output_activation_min, int16 output_activation_max, - int16* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, -32768); - TFLITE_DCHECK_EQ(output_activation_max, 32767); +// Element-wise add that can often be used for inner loop of broadcast add as +// well as the non-broadcast add. +inline void AddElementwise(int size, const ArithmeticParams& params, + const uint8* input1_data, const uint8* input2_data, + uint8* output_data) { + TFLITE_DCHECK_GT(params.input1_offset, -256); + TFLITE_DCHECK_GT(params.input2_offset, -256); + TFLITE_DCHECK_LT(params.input1_offset, 256); + TFLITE_DCHECK_LT(params.input2_offset, 256); + + for (int i = 0; i < size; ++i) { + const int32 input1_val = params.input1_offset + input1_data[i]; + const int32 input2_val = params.input2_offset + input2_data[i]; + const int32 shifted_input1_val = input1_val * (1 << params.left_shift); + const int32 shifted_input2_val = input2_val * (1 << params.left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, params.input1_multiplier, params.input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, params.input2_multiplier, params.input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, params.output_multiplier, params.output_shift) + + params.output_offset; + const int32 clamped_output = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, raw_output)); + output_data[i] = static_cast(clamped_output); } +} - const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); - - TFLITE_DCHECK(input1_shift == 0 || input2_shift == 0); - TFLITE_DCHECK_GE(input1_shift, 0); - TFLITE_DCHECK_GE(input2_shift, 0); +inline void Add(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const uint8* input1_data, + const RuntimeShape& input2_shape, const uint8* input2_data, + const RuntimeShape& output_shape, uint8* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); + + TFLITE_DCHECK_GT(params.input1_offset, -256); + TFLITE_DCHECK_GT(params.input2_offset, -256); + TFLITE_DCHECK_LT(params.input1_offset, 256); + TFLITE_DCHECK_LT(params.input2_offset, 256); + AddElementwise(flat_size, params, input1_data, input2_data, output_data); +} + +inline void Add(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const int16* input1_data, + const RuntimeShape& input2_shape, const int16* input2_data, + const RuntimeShape& output_shape, int16* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + + const int input1_shift = params.input1_shift; + const int flat_size = + MatchingFlatSize(output_shape, input1_shape, input2_shape); + const int16 output_activation_min = params.quantized_activation_min; + const int16 output_activation_max = params.quantized_activation_max; + + TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0); + TFLITE_DCHECK_LE(input1_shift, 0); + TFLITE_DCHECK_LE(params.input2_shift, 0); const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data; const int16* shift_input = input1_shift == 0 ? input2_data : input1_data; - const int input_shift = input1_shift == 0 ? input2_shift : input1_shift; + const int input_right_shift = + input1_shift == 0 ? -params.input2_shift : -input1_shift; for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. using F0 = gemmlowp::FixedPoint; F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]); - F0 scaled_input = - F0::FromRaw(gemmlowp::RoundingDivideByPOT(shift_input[i], input_shift)); + F0 scaled_input = F0::FromRaw( + gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift)); F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled); const int16 raw_output = result.raw(); const int16 clamped_output = std::min( @@ -1175,16 +1088,24 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims, // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then // generate max(D1, D2) nested for loops. -template -void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T output_activation_min, T output_activation_max, - T* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastAdd"); - +// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from +// reference_ops.h. Once an optimized version is implemented and NdArrayDesc +// is no longer referenced in this file, move NdArrayDesc from types.h to +// reference_ops.h. +inline void BroadcastAdd4DSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, + float* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/float"); NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); // In Tensorflow, the dimensions are canonically named (batch_number, row, // col, channel), with extents (batches, height, width, depth), with the @@ -1197,49 +1118,77 @@ void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, // We name our variables by their Tensorflow convention, but generate C code // nesting loops such that the innermost loop has the smallest stride for the // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { + output_data[Offset(extended_output_shape, b, y, x, c)] = ActivationFunctionWithMinMax( - input1_data[SubscriptToIndex(desc1, c, x, y, b)] + - input2_data[SubscriptToIndex(desc2, c, x, y, b)], - output_activation_min, output_activation_max); + input1_data[SubscriptToIndex(desc1, b, y, x, c)] + + input2_data[SubscriptToIndex(desc2, b, y, x, c)], + params.float_activation_min, params.float_activation_max); } } } } } -// legacy, for compatibility with old checked-in code -template -void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T* output_data, const Dims<4>& output_dims) { - T output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); +inline void BroadcastAdd4DSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const int32* input1_data, + const RuntimeShape& input2_shape, + const int32* input2_data, + const RuntimeShape& output_shape, + int32* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/int32"); + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); - BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims, - output_activation_min, output_activation_max, output_data, - output_dims); + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { + output_data[Offset(extended_output_shape, b, y, x, c)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, b, y, x, c)] + + input2_data[SubscriptToIndex(desc2, b, y, x, c)], + params.quantized_activation_min, + params.quantized_activation_max); + } + } + } + } } -inline void BroadcastAdd(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit"); - +inline void BroadcastAdd4DSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const uint8* input1_data, + const RuntimeShape& input2_shape, + const uint8* input2_data, + const RuntimeShape& output_shape, + uint8* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/uint8"); NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); // In Tensorflow, the dimensions are canonically named (batch_number, row, // col, channel), with extents (batches, height, width, depth), with the @@ -1252,33 +1201,37 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data, // We name our variables by their Tensorflow convention, but generate C code // nesting loops such that the innermost loop has the smallest stride for the // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { const int32 input1_val = - input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + params.input1_offset + + input1_data[SubscriptToIndex(desc1, b, y, x, c)]; const int32 input2_val = - input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; - const int32 shifted_input1_val = input1_val * (1 << left_shift); - const int32 shifted_input2_val = input2_val * (1 << left_shift); + params.input2_offset + + input2_data[SubscriptToIndex(desc2, b, y, x, c)]; + const int32 shifted_input1_val = + input1_val * (1 << params.left_shift); + const int32 shifted_input2_val = + input2_val * (1 << params.left_shift); const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, - kReverseShift * input1_shift); + shifted_input1_val, params.input1_multiplier, + params.input1_shift); const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, - kReverseShift * input2_shift); + shifted_input2_val, params.input2_multiplier, + params.input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp( - raw_sum, output_multiplier, kReverseShift * output_shift) + - output_offset; + raw_sum, params.output_multiplier, params.output_shift) + + params.output_offset; const int32 clamped_output = - std::min(output_activation_max, - std::max(output_activation_min, raw_output)); - output_data[Offset(output_dims, c, x, y, b)] = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, raw_output)); + output_data[Offset(extended_output_shape, b, y, x, c)] = static_cast(clamped_output); } } @@ -1286,121 +1239,67 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data, } } -inline void BroadcastAddFivefold( - int y0, int y1, int y2, int y3, int y4, int left_shift, - const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastAddFivefold/8bit"); - - int sb1 = y0; - int sa2 = y0; - int sb2 = y0 * y1; - int sa3 = y0 * y2; - int sa4 = y0 * y2 * y3; - int sb4 = y0 * y1 * y2; - +inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params, + const RuntimeShape& unswitched_input1_shape, + const uint8* unswitched_input1_data, + const RuntimeShape& unswitched_input2_shape, + const uint8* unswitched_input2_data, + const RuntimeShape& output_shape, + uint8* output_data) { + ArithmeticParams switched_params = unswitched_params; + switched_params.input1_offset = unswitched_params.input2_offset; + switched_params.input1_multiplier = unswitched_params.input2_multiplier; + switched_params.input1_shift = unswitched_params.input2_shift; + switched_params.input2_offset = unswitched_params.input1_offset; + switched_params.input2_multiplier = unswitched_params.input1_multiplier; + switched_params.input2_shift = unswitched_params.input1_shift; + + const bool use_unswitched = + unswitched_params.broadcast_category == + tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast; + + const ArithmeticParams& params = + use_unswitched ? unswitched_params : switched_params; + const uint8* input1_data = + use_unswitched ? unswitched_input1_data : unswitched_input2_data; + const uint8* input2_data = + use_unswitched ? unswitched_input2_data : unswitched_input1_data; + + // Fivefold nested loops. The second input resets its position for each + // iteration of the second loop. The first input resets its position at the + // beginning of the fourth loop. The innermost loop is an elementwise add of + // sections of the arrays. uint8* output_data_ptr = output_data; - for (int i4 = 0; i4 < y4; ++i4) { - for (int i3 = 0; i3 < y3; ++i3) { + const uint8* input1_data_ptr = input1_data; + const uint8* input2_data_reset = input2_data; + int y0 = params.broadcast_shape[0]; + int y1 = params.broadcast_shape[1]; + int y2 = params.broadcast_shape[2]; + int y3 = params.broadcast_shape[3]; + int y4 = params.broadcast_shape[4]; + for (int i0 = 0; i0 < y0; ++i0) { + const uint8* input2_data_ptr; + for (int i1 = 0; i1 < y1; ++i1) { + input2_data_ptr = input2_data_reset; for (int i2 = 0; i2 < y2; ++i2) { - for (int i1 = 0; i1 < y1; ++i1) { - for (int i0 = 0; i0 < y0; ++i0) { - const int32 input1_val = - input1_offset + - input1_data[i4 * sa4 + i3 * sa3 + i2 * sa2 + i0]; - const int32 input2_val = - input2_offset + - input2_data[i4 * sb4 + i2 * sb2 + i1 * sb1 + i0]; - const int32 shifted_input1_val = input1_val * (1 << left_shift); - const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, - kReverseShift * input1_shift); - const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, - kReverseShift * input2_shift); - const int32 raw_sum = scaled_input1_val + scaled_input2_val; - const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - raw_sum, output_multiplier, kReverseShift * output_shift) + - output_offset; - const int32 clamped_output = - std::min(output_activation_max, - std::max(output_activation_min, raw_output)); - *output_data_ptr = static_cast(clamped_output); - ++output_data_ptr; - } + for (int i3 = 0; i3 < y3; ++i3) { + AddElementwise(y4, params, input1_data_ptr, input2_data_ptr, + output_data_ptr); + input2_data_ptr += y4; + output_data_ptr += y4; } + input1_data_ptr += y4; } } + input2_data_reset = input2_data_ptr; } } -template -inline void BroadcastAdd(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset, - input1_multiplier, input1_shift, input2_data, input2_dims, - input2_offset, input2_multiplier, input2_shift, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_data, output_dims); -} - -template -inline void BroadcastAddFivefold( - int y0, int y1, int y2, int y3, int y4, int left_shift, - const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - BroadcastAddFivefold(y0, y1, y2, y3, y4, left_shift, input1_data, input1_dims, - input1_offset, input1_multiplier, input1_shift, - input2_data, input2_dims, input2_offset, - input2_multiplier, input2_shift, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_data, output_dims); -} - -inline void Mul(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { +template +inline void Mul(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, + T* output_data, const Dims<4>& output_dims) { const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( @@ -1475,13 +1374,143 @@ void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, output_dims); } -inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, - int32 input1_offset, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { +// Element-wise mul that can often be used for inner loop of broadcast Mul as +// well as the non-broadcast Mul. +inline void MulElementwise(int size, const ArithmeticParams& params, + const uint8* input1_data, const uint8* input2_data, + uint8* output_data) { + for (int i = 0; i < size; ++i) { + const int32 input1_val = params.input1_offset + input1_data[i]; + const int32 input2_val = params.input2_offset + input2_data[i]; + const int32 unclamped_result = + params.output_offset + + MultiplyByQuantizedMultiplierSmallerThanOneExp(input1_val * input2_val, + params.output_multiplier, + params.output_shift); + const int32 clamped_output = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, unclamped_result)); + output_data[i] = static_cast(clamped_output); + } +} + +inline void Mul(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const uint8* input1_data, + const RuntimeShape& input2_shape, const uint8* input2_data, + const RuntimeShape& output_shape, uint8* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + gemmlowp::ScopedProfilingLabel label("Mul/8bit"); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); + + MulElementwise(flat_size, params, input1_data, input2_data, output_data); +} + +inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params, + const RuntimeShape& unswitched_input1_shape, + const uint8* unswitched_input1_data, + const RuntimeShape& unswitched_input2_shape, + const uint8* unswitched_input2_data, + const RuntimeShape& output_shape, + uint8* output_data) { + ArithmeticParams switched_params = unswitched_params; + switched_params.input1_offset = unswitched_params.input2_offset; + switched_params.input2_offset = unswitched_params.input1_offset; + + const bool use_unswitched = + unswitched_params.broadcast_category == + tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast; + + const ArithmeticParams& params = + use_unswitched ? unswitched_params : switched_params; + const uint8* input1_data = + use_unswitched ? unswitched_input1_data : unswitched_input2_data; + const uint8* input2_data = + use_unswitched ? unswitched_input2_data : unswitched_input1_data; + + // Fivefold nested loops. The second input resets its position for each + // iteration of the second loop. The first input resets its position at the + // beginning of the fourth loop. The innermost loop is an elementwise Mul of + // sections of the arrays. + uint8* output_data_ptr = output_data; + const uint8* input1_data_ptr = input1_data; + const uint8* input2_data_reset = input2_data; + int y0 = params.broadcast_shape[0]; + int y1 = params.broadcast_shape[1]; + int y2 = params.broadcast_shape[2]; + int y3 = params.broadcast_shape[3]; + int y4 = params.broadcast_shape[4]; + for (int i0 = 0; i0 < y0; ++i0) { + const uint8* input2_data_ptr; + for (int i1 = 0; i1 < y1; ++i1) { + input2_data_ptr = input2_data_reset; + for (int i2 = 0; i2 < y2; ++i2) { + for (int i3 = 0; i3 < y3; ++i3) { + MulElementwise(y4, params, input1_data_ptr, input2_data_ptr, + output_data_ptr); + input2_data_ptr += y4; + output_data_ptr += y4; + } + input1_data_ptr += y4; + } + } + input2_data_reset = input2_data_ptr; + } +} + +inline void BroadcastMul4DSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const uint8* input1_data, + const RuntimeShape& input2_shape, + const uint8* input2_data, + const RuntimeShape& output_shape, + uint8* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul4DSlow/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); + + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { + const int32 input1_val = + params.input1_offset + + input1_data[SubscriptToIndex(desc1, b, y, x, c)]; + const int32 input2_val = + params.input2_offset + + input2_data[SubscriptToIndex(desc2, b, y, x, c)]; + const int32 unclamped_result = + params.output_offset + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + input1_val * input2_val, params.output_multiplier, + params.output_shift); + const int32 clamped_output = std::min( + params.quantized_activation_max, + std::max(params.quantized_activation_min, unclamped_result)); + output_data[Offset(extended_output_shape, b, y, x, c)] = + static_cast(clamped_output); + } + } + } + } +} + +// Transitional version that will be moved shortly to legacy_reference_ops, as +// part of RuntimeShape revisions. +inline void BroadcastMul4DSlow(const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit"); NdArrayDesc<4> desc1; @@ -1508,9 +1537,9 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, const int32 input2_val = input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; const int32 unclamped_result = - output_offset + MultiplyByQuantizedMultiplierSmallerThanOneExp( - input1_val * input2_val, output_multiplier, - kReverseShift * output_shift); + output_offset + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + input1_val * input2_val, output_multiplier, output_shift); const int32 clamped_output = std::min(output_activation_max, std::max(output_activation_min, unclamped_result)); @@ -1565,21 +1594,6 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, } } -// legacy, for compatibility with old checked-in code -template -inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, - int32 input1_offset, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - BroadcastMul(input1_data, input1_dims, input1_offset, input2_data, - input2_dims, input2_offset, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_data, output_dims); -} - // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then @@ -1621,10 +1635,11 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims, } } -inline void Div(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { +template +inline void Div(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, + T* output_data, const Dims<4>& output_dims) { const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( @@ -1633,15 +1648,35 @@ inline void Div(const float* input1_data, const Dims<4>& input1_dims, } } -inline void Sub(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); +inline void SubNonBroadcast(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, + float* output_data) { + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( - input1_data[i] - input2_data[i], output_activation_min, - output_activation_max); + input1_data[i] - input2_data[i], params.float_activation_min, + params.float_activation_max); + } +} + +inline void SubNonBroadcast(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const int32* input1_data, + const RuntimeShape& input2_shape, + const int32* input2_data, + const RuntimeShape& output_shape, + int32* output_data) { + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax( + input1_data[i] - input2_data[i], params.quantized_activation_min, + params.quantized_activation_max); } } @@ -1649,16 +1684,24 @@ inline void Sub(const float* input1_data, const Dims<4>& input1_dims, // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then // generate max(D1, D2) nested for loops. -template -void BroadcastSub(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T output_activation_min, T output_activation_max, - T* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastSub"); - +// TODO(benoitjacob): BroadcastSub is intentionally duplicated from +// reference_ops.h. Once an optimized version is implemented and NdArrayDesc +// is no longer referenced in this file, move NdArrayDesc from types.h to +// reference_ops.h. +inline void BroadcastSub4DSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, + float* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/float"); NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); // In Tensorflow, the dimensions are canonically named (batch_number, row, // col, channel), with extents (batches, height, width, depth), with the @@ -1671,36 +1714,35 @@ void BroadcastSub(const T* input1_data, const Dims<4>& input1_dims, // We name our variables by their Tensorflow convention, but generate C code // nesting loops such that the innermost loop has the smallest stride for the // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { + output_data[Offset(extended_output_shape, b, y, x, c)] = ActivationFunctionWithMinMax( - input1_data[SubscriptToIndex(desc1, c, x, y, b)] - - input2_data[SubscriptToIndex(desc2, c, x, y, b)], - output_activation_min, output_activation_max); + input1_data[SubscriptToIndex(desc1, b, y, x, c)] - + input2_data[SubscriptToIndex(desc2, b, y, x, c)], + params.float_activation_min, params.float_activation_max); } } } } } -inline void BroadcastSub(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastSub/8bit"); - +inline void BroadcastSub4DSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const uint8* input1_data, + const RuntimeShape& input2_shape, + const uint8* input2_data, + const RuntimeShape& output_shape, + uint8* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/uint8"); NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); // In Tensorflow, the dimensions are canonically named (batch_number, row, // col, channel), with extents (batches, height, width, depth), with the @@ -1713,33 +1755,37 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, // We name our variables by their Tensorflow convention, but generate C code // nesting loops such that the innermost loop has the smallest stride for the // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { const int32 input1_val = - input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + params.input1_offset + + input1_data[SubscriptToIndex(desc1, b, y, x, c)]; const int32 input2_val = - input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; - const int32 shifted_input1_val = input1_val * (1 << left_shift); - const int32 shifted_input2_val = input2_val * (1 << left_shift); + params.input2_offset + + input2_data[SubscriptToIndex(desc2, b, y, x, c)]; + const int32 shifted_input1_val = + input1_val * (1 << params.left_shift); + const int32 shifted_input2_val = + input2_val * (1 << params.left_shift); const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, - kReverseShift * input1_shift); + shifted_input1_val, params.input1_multiplier, + params.input1_shift); const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, - kReverseShift * input2_shift); + shifted_input2_val, params.input2_multiplier, + params.input2_shift); const int32 raw_sub = scaled_input1_val - scaled_input2_val; const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp( - raw_sub, output_multiplier, kReverseShift * output_shift) + - output_offset; + raw_sub, params.output_multiplier, params.output_shift) + + params.output_offset; const int32 clamped_output = - std::min(output_activation_max, - std::max(output_activation_min, raw_output)); - output_data[Offset(output_dims, c, x, y, b)] = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, raw_output)); + output_data[Offset(extended_output_shape, b, y, x, c)] = static_cast(clamped_output); } } @@ -1747,11 +1793,160 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, } } +inline void BroadcastSub4DSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const int32* input1_data, + const RuntimeShape& input2_shape, + const int32* input2_data, + const RuntimeShape& output_shape, + int32* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/int32"); + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { + output_data[Offset(extended_output_shape, b, y, x, c)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, b, y, x, c)] - + input2_data[SubscriptToIndex(desc2, b, y, x, c)], + params.quantized_activation_min, + params.quantized_activation_max); + } + } + } + } +} + +template +void BroadcastSub4DSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const T* input1_data, + const RuntimeShape& input2_shape, const T* input2_data, + const RuntimeShape& output_shape, T* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/templated"); + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { + output_data[Offset(extended_output_shape, b, y, x, c)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, b, y, x, c)] - + input2_data[SubscriptToIndex(desc2, b, y, x, c)], + params.quantized_activation_min, + params.quantized_activation_max); + } + } + } + } +} + +template +void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape, + const T* input1_data, const RuntimeShape& input2_shape, + const T* input2_data, const RuntimeShape& output_shape, + T* output_data) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { + output_data[Offset(extended_output_shape, b, y, x, c)] = + input1_data[SubscriptToIndex(desc1, b, y, x, c)] - + input2_data[SubscriptToIndex(desc2, b, y, x, c)]; + } + } + } + } +} + +inline void SubWithActivation(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const int32* input1_data, + const RuntimeShape& input2_shape, + const int32* input2_data, + const RuntimeShape& output_shape, + int32* output_data) { + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, input2_shape); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax( + input1_data[i] - input2_data[i], params.quantized_activation_min, + params.quantized_activation_max); + } +} + +inline void SubWithActivation(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, + float* output_data) { + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, input2_shape); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax( + input1_data[i] - input2_data[i], params.float_activation_min, + params.float_activation_max); + } +} + template void Concatenation(int concat_dim, const Scalar* const* input_data, const Dims<4>* const* input_dims, int inputs_count, Scalar* output_data, const Dims<4>& output_dims) { - TFLITE_DCHECK_GT(inputs_count, 1); int concat_size = 0; for (int i = 0; i < inputs_count; i++) { for (int j = 0; j < 4; j++) { @@ -1762,7 +1957,9 @@ void Concatenation(int concat_dim, const Scalar* const* input_data, concat_size += ArraySize(*input_dims[i], concat_dim); } TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); - TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + // For now we don't have a model with a Concatenation with fused activation. + TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone); int outer_size = 1; for (int i = concat_dim + 1; i < 4; i++) { outer_size *= output_dims.sizes[i]; @@ -1779,6 +1976,26 @@ void Concatenation(int concat_dim, const Scalar* const* input_data, } } +template +void Pack(int dim, const Scalar* const* input_data, + const Dims<4>* const* input_dims, int inputs_count, + Scalar* output_data, const Dims<4>& output_dims) { + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + int outer_size = 1; + for (int i = dim + 1; i < 4; i++) { + outer_size *= output_dims.sizes[i]; + } + Scalar* output_ptr = output_data; + const int copy_size = FlatSize(**input_dims) / outer_size; + for (int k = 0; k < outer_size; k++) { + for (int i = 0; i < inputs_count; ++i) { + memcpy(output_ptr, input_data[i] + k * copy_size, + copy_size * sizeof(Scalar)); + output_ptr += copy_size; + } + } +} + // TODO(prabhumk): This is the same as the optimized implementation. // TODO(prabhumk): The quantized implementation of concatentation isn't fully // quantized as it takes scale as a floating point value. This should be fixed @@ -2240,32 +2457,36 @@ inline int NodeOffset(int b, int h, int w, int height, int width) { return (b * height + h) * width + w; } -inline void AveragePool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - float output_activation_min, - float output_activation_max, float* output_data, - const Dims<4>& output_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); +inline void AveragePool(const PoolParams& params, + const RuntimeShape& input_shape, + const float* input_data, + const RuntimeShape& output_shape, float* output_data) { + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { for (int channel = 0; channel < depth; ++channel) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; // Compute the boundaries of the filter region clamped so as to // ensure that the filter window fits in the input array. const int filter_x_start = std::max(0, -in_x_origin); const int filter_x_end = - std::min(filter_width, input_width - in_x_origin); + std::min(params.filter_width, input_width - in_x_origin); const int filter_y_start = std::max(0, -in_y_origin); const int filter_y_end = - std::min(filter_height, input_height - in_y_origin); + std::min(params.filter_height, input_height - in_y_origin); float total = 0.f; float filter_count = 0; for (int filter_y = filter_y_start; filter_y < filter_y_end; @@ -2275,70 +2496,52 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims, const int in_x = in_x_origin + filter_x; const int in_y = in_y_origin + filter_y; total += - input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + input_data[Offset(input_shape, batch, in_y, in_x, channel)]; filter_count++; } } const float average = total / filter_count; - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = - ActivationFunctionWithMinMax(average, output_activation_min, - output_activation_max); + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = + ActivationFunctionWithMinMax(average, params.float_activation_min, + params.float_activation_max); } } } } } -// legacy, for compatibility with old checked-in code -template -void AveragePool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, - int filter_height, float* output_data, - const Dims<4>& output_dims) { - AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); +inline void AveragePool(const PoolParams& params, + const RuntimeShape& input_shape, + const uint8* input_data, + const RuntimeShape& output_shape, uint8* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { for (int channel = 0; channel < depth; ++channel) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; // Compute the boundaries of the filter region clamped so as to // ensure that the filter window fits in the input array. const int filter_x_start = std::max(0, -in_x_origin); const int filter_x_end = - std::min(filter_width, input_width - in_x_origin); + std::min(params.filter_width, input_width - in_x_origin); const int filter_y_start = std::max(0, -in_y_origin); const int filter_y_end = - std::min(filter_height, input_height - in_y_origin); + std::min(params.filter_height, input_height - in_y_origin); int32 acc = 0; int filter_count = 0; for (int filter_y = filter_y_start; filter_y < filter_y_end; @@ -2347,79 +2550,51 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, ++filter_x) { const int in_x = in_x_origin + filter_x; const int in_y = in_y_origin + filter_y; - acc += input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + acc += + input_data[Offset(input_shape, batch, in_y, in_x, channel)]; filter_count++; } } acc = (acc + filter_count / 2) / filter_count; - acc = std::max(acc, output_activation_min); - acc = std::min(acc, output_activation_max); - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = - static_cast(acc); - } - } - } - } -} - -// legacy, for compatibility with old checked-in code -template -void AveragePool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); + acc = std::max(acc, params.quantized_activation_min); + acc = std::min(acc, params.quantized_activation_max); + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = + static_cast(acc); + } + } + } } - AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, - int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); } -inline void L2Pool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); +inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& output_shape, + float* output_data) { + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { for (int channel = 0; channel < depth; ++channel) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; // Compute the boundaries of the filter region clamped so as to // ensure that the filter window fits in the input array. const int filter_x_start = std::max(0, -in_x_origin); const int filter_x_end = - std::min(filter_width, input_width - in_x_origin); + std::min(params.filter_width, input_width - in_x_origin); const int filter_y_start = std::max(0, -in_y_origin); const int filter_y_end = - std::min(filter_height, input_height - in_y_origin); + std::min(params.filter_height, input_height - in_y_origin); float sum_squares = 0.f; int filter_count = 0; for (int filter_y = filter_y_start; filter_y < filter_y_end; @@ -2429,69 +2604,51 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims, const int in_x = in_x_origin + filter_x; const int in_y = in_y_origin + filter_y; const float val = - input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + input_data[Offset(input_shape, batch, in_y, in_x, channel)]; sum_squares += val * val; filter_count++; } } const float l2pool_result = std::sqrt(sum_squares / filter_count); - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = - ActivationFunctionWithMinMax(l2pool_result, output_activation_min, - output_activation_max); + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = + ActivationFunctionWithMinMax(l2pool_result, + params.float_activation_min, + params.float_activation_max); } } } } } -// legacy, for compatibility with old checked-in code -template -void L2Pool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - - L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void MaxPool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); +inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& output_shape, + float* output_data) { + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { for (int channel = 0; channel < depth; ++channel) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; // Compute the boundaries of the filter region clamped so as to // ensure that the filter window fits in the input array. const int filter_x_start = std::max(0, -in_x_origin); const int filter_x_end = - std::min(filter_width, input_width - in_x_origin); + std::min(params.filter_width, input_width - in_x_origin); const int filter_y_start = std::max(0, -in_y_origin); const int filter_y_end = - std::min(filter_height, input_height - in_y_origin); + std::min(params.filter_height, input_height - in_y_origin); float max = std::numeric_limits::lowest(); for (int filter_y = filter_y_start; filter_y < filter_y_end; ++filter_y) { @@ -2501,68 +2658,51 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims, const int in_y = in_y_origin + filter_y; max = std::max( max, - input_data[Offset(input_dims, channel, in_x, in_y, batch)]); + input_data[Offset(input_shape, batch, in_y, in_x, channel)]); } } - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = - ActivationFunctionWithMinMax(max, output_activation_min, - output_activation_max); + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = + ActivationFunctionWithMinMax(max, params.float_activation_min, + params.float_activation_max); } } } } } -// legacy, for compatibility with old checked-in code -template -void MaxPool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - TFLITE_DCHECK_GE(output_activation_min, 0); - TFLITE_DCHECK_LE(output_activation_max, 255); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); +inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& output_shape, + uint8* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + TFLITE_DCHECK_GE(params.quantized_activation_min, 0); + TFLITE_DCHECK_LE(params.quantized_activation_max, 255); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { for (int channel = 0; channel < depth; ++channel) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; // Compute the boundaries of the filter region clamped so as to // ensure that the filter window fits in the input array. const int filter_x_start = std::max(0, -in_x_origin); const int filter_x_end = - std::min(filter_width, input_width - in_x_origin); + std::min(params.filter_width, input_width - in_x_origin); const int filter_y_start = std::max(0, -in_y_origin); const int filter_y_end = - std::min(filter_height, input_height - in_y_origin); + std::min(params.filter_height, input_height - in_y_origin); uint8 max = 0; for (int filter_y = filter_y_start; filter_y < filter_y_end; ++filter_y) { @@ -2572,12 +2712,12 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, const int in_y = in_y_origin + filter_y; max = std::max( max, - input_data[Offset(input_dims, channel, in_x, in_y, batch)]); + input_data[Offset(input_shape, batch, in_y, in_x, channel)]); } } - max = std::max(max, output_activation_min); - max = std::min(max, output_activation_max); - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + max = std::max(max, params.quantized_activation_min); + max = std::min(max, params.quantized_activation_max); + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = static_cast(max); } } @@ -2585,38 +2725,6 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - inline void LocalResponseNormalization(const float* input_data, const Dims<4>& input_dims, int range, float bias, float alpha, float beta, @@ -2640,11 +2748,14 @@ inline void LocalResponseNormalization(const float* input_data, } } -inline void Softmax(const float* input_data, const Dims<4>& input_dims, +inline void Softmax(const float* input_data, const RuntimeShape& input_shape, float beta, float* output_data, - const Dims<4>& output_dims) { - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const RuntimeShape& output_shape) { + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { // Find max element value which we'll use to ensure numerical stability @@ -2669,10 +2780,10 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims, } } -inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, +inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape, int32 input_beta_multiplier, int32 input_beta_left_shift, int diff_min, uint8* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { // The representation chosen for the input to the exp() function is Q5.26. // We need to leave extra space since values that we skip might be as large as // -32 before multiplying by input_beta_multiplier, and therefore as large as @@ -2685,8 +2796,11 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPointAccum = gemmlowp::FixedPoint; using FixedPoint0 = gemmlowp::FixedPoint; - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { uint8 max_in_row = 0; @@ -2747,10 +2861,13 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, } } -inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); +inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { // Find max element value which we'll use to ensure numerical stability @@ -2890,11 +3007,11 @@ log_x_for_x_greater_than_or_equal_to_1( input_val); } -inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, +inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape, int32 input_multiplier, int32 input_left_shift, int32 reverse_scaling_divisor, int32 reverse_scaling_right_shift, int diff_min, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { // The representation chosen for the input to the exp() function is Q5.26. // We need to leave extra space since values that we skip might be as large as // -32 before multiplying by input_beta_multiplier, and therefore as large as @@ -2908,8 +3025,11 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPointAccum = gemmlowp::FixedPoint; using FixedPoint0 = gemmlowp::FixedPoint; - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { uint8 max_in_row = 0; @@ -2973,9 +3093,9 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); +inline void Logistic(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { float val = input_data[i]; @@ -2984,11 +3104,11 @@ inline void Logistic(const float* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, +inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape, int32 input_zero_point, int32 input_range_radius, int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); + uint8* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { const uint8 input_val_u8 = input_data[i]; @@ -3022,9 +3142,9 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const int16* input_data, const Dims<4>& input_dims, - int16* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); +inline void Logistic(const int16* input_data, const RuntimeShape& input_shape, + int16* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. @@ -3040,9 +3160,9 @@ inline void Logistic(const int16* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); +inline void Tanh(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { float val = input_data[i]; @@ -3051,12 +3171,12 @@ inline void Tanh(const float* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, +inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape, int32 input_zero_point, int32 input_range_radius, int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { const int32 output_zero_point = 128; - const int flat_size = MatchingFlatSize(output_dims, input_dims); + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { const uint8 input_val_u8 = input_data[i]; @@ -3091,15 +3211,15 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const int16* input_data, const Dims<4>& input_dims, +inline void Tanh(const int16* input_data, const RuntimeShape& input_shape, int input_left_shift, int16* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { // Support for shifts is limited until we have a parameterized version of // SaturatingRoundingMultiplyByPOT(). TFLITE_DCHECK_GE(input_left_shift, 0); TFLITE_DCHECK_LE(input_left_shift, 1); - const int flat_size = MatchingFlatSize(output_dims, input_dims); + const int flat_size = MatchingFlatSize(input_shape, output_shape); // F0 uses 0 integer bits, range [-1, 1]. // This is the return type of math functions such as tanh, logistic, @@ -3151,18 +3271,9 @@ inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, float nudged_min, nudged_max, nudged_scale; NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min, &nudged_max, &nudged_scale); - const float inv_nudged_scale = 1.0f / nudged_scale; - const int flat_size = MatchingFlatSize(output_dims, input_dims); - for (int i = 0; i < flat_size; i++) { - const float src_val = input_data[i]; - const float clamped = std::min(nudged_max, std::max(nudged_min, src_val)); - const float clamped_shifted = clamped - nudged_min; - const float dst_val = - TfLiteRound(clamped_shifted * inv_nudged_scale) * nudged_scale + - nudged_min; - output_data[i] = dst_val; - } + FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input_data, + output_data, flat_size); } template @@ -3280,7 +3391,8 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, const Dims<4>& block_shape_dims, const int32* paddings_data, const Dims<4>& paddings_dims, T* output_data, - const Dims<4>& output_dims) { + const Dims<4>& output_dims, + const int32_t pad_value) { const int output_batch_size = ArraySize(output_dims, 3); const int output_height = ArraySize(output_dims, 2); const int output_width = ArraySize(output_dims, 1); @@ -3305,7 +3417,7 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, padding_top + input_height || out_w * block_shape_width + shift_w < padding_left || out_w * block_shape_width + shift_w >= padding_left + input_width) { - memset(out, 0, depth * sizeof(T)); + memset(out, pad_value, depth * sizeof(T)); } else { const T* in = input_data + @@ -3320,6 +3432,17 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, } } +template +inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, + const int32* paddings_data, + const Dims<4>& paddings_dims, T* output_data, + const Dims<4>& output_dims) { + SpaceToBatchND(input_data, input_dims, block_shape_data, block_shape_dims, + paddings_data, paddings_dims, output_data, output_dims, 0); +} + template inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, const int32* block_shape_data, @@ -3431,7 +3554,7 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims, template inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, - int begin_mask, int end_mask, + int begin_mask, int end_mask, int shrink_axis_mask, const std::vector& start_indices, const std::vector& stop_indices, const std::vector& strides, T* output_data, @@ -3443,20 +3566,24 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, TFLITE_DCHECK_EQ(strides.size(), 4); const int start_b = strided_slice::StartForAxis(begin_mask, start_indices, strides, input_dims.sizes, 3); - const int stop_b = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 3); + const int stop_b = + strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, + strides, input_dims.sizes, 3, start_b); const int start_h = strided_slice::StartForAxis(begin_mask, start_indices, strides, input_dims.sizes, 2); - const int stop_h = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 2); + const int stop_h = + strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, + strides, input_dims.sizes, 2, start_h); const int start_w = strided_slice::StartForAxis(begin_mask, start_indices, strides, input_dims.sizes, 1); - const int stop_w = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 1); + const int stop_w = + strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, + strides, input_dims.sizes, 1, start_w); const int start_d = strided_slice::StartForAxis(begin_mask, start_indices, strides, input_dims.sizes, 0); - const int stop_d = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 0); + const int stop_d = + strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, + strides, input_dims.sizes, 0, start_d); T* out_ptr = output_data; for (int in_b = start_b; @@ -3519,8 +3646,6 @@ inline void Exp(const T* input_data, const size_t num_elements, } // A generic reduce method that can be used for reduce_sum, reduce_mean, etc. -// It takes a reducer function as input and returns false when numeric overflow -// is detected. // This method iterates through input data and reduce elements along the // dimensions given in axis. template @@ -3528,10 +3653,9 @@ inline bool Reduce(const In* input_data, const int* input_dims, const int* output_dims, const int input_num_dims, const int output_num_dims, const int* axis, const int num_axis, int* input_iter, - Out reducer(Out current, const In in, bool* overflow), + Out reducer(const Out current, const In in), Out* output_data) { // Reset input iterator. - TFLITE_DCHECK(input_num_dims > 0); for (int idx = 0; idx < input_num_dims; ++idx) { input_iter[idx] = 0; } @@ -3541,19 +3665,22 @@ inline bool Reduce(const In* input_data, const int* input_dims, ReducedOutputOffset(input_num_dims, input_dims, input_iter, 0, nullptr); size_t output_offset = ReducedOutputOffset(input_num_dims, input_dims, input_iter, num_axis, axis); - bool overflow = false; - output_data[output_offset] = reducer(output_data[output_offset], - input_data[input_offset], &overflow); - if (overflow) return false; + output_data[output_offset] = + reducer(output_data[output_offset], input_data[input_offset]); } while (NextIndex(input_num_dims, input_dims, input_iter)); return true; } -inline bool ResolveAxis(const int num_dims, const int* axis, const int num_axis, - int* out_axis, int* out_num_axis) { +inline bool ResolveAxis(const int num_dims, const int* axis, + const int64_t num_axis, int* out_axis, + int* out_num_axis) { *out_num_axis = 0; // Just in case. + // Short-circuit axis resolution for scalars; the axis will go unused. + if (num_dims == 0) { + return true; + } // o(n^2) is fine since out_num_axis should be really small, mostly <= 4 - for (int idx = 0; idx < num_axis; ++idx) { + for (int64_t idx = 0; idx < num_axis; ++idx) { // Handle negative index. int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx]; TFLITE_DCHECK(current >= 0 && current < num_dims); @@ -3579,7 +3706,7 @@ inline bool ReduceSumImpl(const In* input_data, const int* input_dims, const int output_num_dims, const int* axis, const int num_axis, int* input_iter, Out* output_data) { - auto reducer = [](Out current, const In in, bool* overflow) -> Out { + auto reducer = [](const Out current, const In in) -> Out { const Out actual_in = static_cast(in); return current + actual_in; }; @@ -3588,6 +3715,104 @@ inline bool ReduceSumImpl(const In* input_data, const int* input_dims, output_data); } +template +inline bool InitTensorDataForReduce(const int* dims, const int num_dims, + const T init_value, T* data) { + size_t num_elements = 1; + for (int idx = 0; idx < num_dims; ++idx) { + size_t current = static_cast(dims[idx]); + // Overflow prevention. + if (num_elements > std::numeric_limits::max() / current) { + return false; + } + num_elements *= current; + } + for (size_t idx = 0; idx < num_elements; ++idx) { + data[idx] = init_value; + } + return true; +} + +// Computes the sum of elements across dimensions given in axis. +template +inline bool Sum(const T* input_data, const int* input_dims, + const int input_num_dims, T* output_data, + const int* output_dims, const int output_num_dims, + const int* axis, const int num_axis_dimensions, bool keep_dims, + int* temp_index, int* resolved_axis) { + // Reset output data. + if (!InitTensorDataForReduce(output_dims, output_num_dims, static_cast(0), + output_data)) { + return false; + } + + // Resolve axis. + int num_resolved_axis = 0; + if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, + &num_resolved_axis)) { + return false; + } + + return ReduceSumImpl(input_data, input_dims, output_dims, + input_num_dims, output_num_dims, resolved_axis, + num_resolved_axis, temp_index, output_data); +} + +// Computes the max of elements across dimensions given in axis. +template +inline bool ReduceMax(const T* input_data, const int* input_dims, + const int input_num_dims, T* output_data, + const int* output_dims, const int output_num_dims, + const int* axis, const int64_t num_axis_dimensions, + bool keep_dims, int* temp_index, int* resolved_axis) { + T init_value = std::numeric_limits::lowest(); + // Reset output data. + if (!InitTensorDataForReduce(output_dims, output_num_dims, init_value, + output_data)) { + return false; + } + + // Resolve axis. + int num_resolved_axis = 0; + if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, + &num_resolved_axis)) { + return false; + } + + auto reducer = [](const T current, const T in) -> T { + return (in > current) ? in : current; + }; + return Reduce(input_data, input_dims, output_dims, input_num_dims, + output_num_dims, resolved_axis, num_resolved_axis, + temp_index, reducer, output_data); +} + +// Computes the prod of elements across dimensions given in axis. +template +inline bool ReduceProd(const T* input_data, const int* input_dims, + const int input_num_dims, T* output_data, + const int* output_dims, const int output_num_dims, + const int* axis, const int64_t num_axis_dimensions, + bool keep_dims, int* temp_index, int* resolved_axis) { + // Reset output data. + if (!InitTensorDataForReduce(output_dims, output_num_dims, static_cast(1), + output_data)) { + return false; + } + + // Resolve axis. + int num_resolved_axis = 0; + if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, + &num_resolved_axis)) { + return false; + } + + auto reducer = [](const T current, const T in) -> T { return in * current; }; + return Reduce(input_data, input_dims, output_dims, input_num_dims, + output_num_dims, resolved_axis, num_resolved_axis, + temp_index, reducer, output_data); +} + // Computes the mean of elements across dimensions given in axis. // It does so in two stages, first calculates the sum of elements along the axis // then divides it by the number of element in axis. @@ -3679,38 +3904,6 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims, } } -template -void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, - const Dims<4>& input2_dims, T* output_data, - const Dims<4>& output_dims) { - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - - // In Tensorflow, the dimensions are canonically named (batch_number, row, - // col, channel), with extents (batches, height, width, depth), with the - // trailing dimension changing most rapidly (channels has the smallest stride, - // typically 1 element). - // - // In generated C code, we store arrays with the dimensions reversed. The - // first dimension has smallest stride. - // - // We name our variables by their Tensorflow convention, but generate C code - // nesting loops such that the innermost loop has the smallest stride for the - // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = - input1_data[SubscriptToIndex(desc1, c, x, y, b)] - - input2_data[SubscriptToIndex(desc2, c, x, y, b)]; - } - } - } - } -} - template void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, T* output_data, @@ -3760,9 +3953,9 @@ void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims, } } -template -void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, - T2* output_data, const Dims<4>& output_dims) { +template +void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, + T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) { // The current ArgMax implemention can only determine the index of the maximum // value in the last dimension. So the axis argument is ignored. @@ -3775,22 +3968,31 @@ void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, const int depth = ArraySize(input_dims, 0); for (int i = 0; i < outer_size; ++i) { - auto max_value = input_data[i * depth]; - int max_index = 0; + auto min_max_value = input_data[i * depth]; + int min_max_index = 0; for (int d = 1; d < depth; ++d) { const auto& curr_value = input_data[i * depth + d]; - if (curr_value > max_value) { - max_value = curr_value; - max_index = d; + if (cmp(curr_value, min_max_value)) { + min_max_value = curr_value; + min_max_index = d; } } - output_data[i] = max_index; + output_data[i] = min_max_index; } } +// TODO(renjieliu): Remove this one. +template +void ArgMax(const T3* axis, const T1* input_data, + const tflite::Dims<4>& input_dims, T2* output_data, + const tflite::Dims<4>& output_dims) { + ArgMinMax(axis, input_data, input_dims, output_data, output_dims, + std::greater()); +} + template void Transpose(const T* input, const Dims<4>& input_dims, T* output, - const Dims<4>& output_dims, int* permuted_axes) { + const Dims<4>& output_dims, const int* permuted_axes) { int out_sizes[4]; // Compute the inverse permutation array so we can do an output centered // transpose. Also, check to make sure output_dims is matching input_dims. @@ -3821,7 +4023,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, const float* filter_data, const Dims<4>& filter_dims, int stride_width, int stride_height, int pad_width, int pad_height, float* output_data, - const Dims<4>& output_dims) { + const Dims<4>& output_dims, float* /*im2col_data*/, + const Dims<4>& /*im2col_dims*/) { const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); @@ -3839,7 +4042,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, // computing their influence on the output, rather than looping through the // output elements in the typical "gather" access pattern of a conv. We // therefore must initialize the output array to zero. - for (int i = 0; i < FlatSize(output_dims); i++) { + const int num_elements = FlatSize(output_dims); + for (int i = 0; i < num_elements; i++) { output_data[i] = 0.0f; } @@ -4093,8 +4297,8 @@ inline void RankOneSelect(const D* input_condition_data, } // For easy implementation, the indices is always a vector of size-4 vectors. -template -inline void SparseToDense(const std::vector>& indices, +template +inline void SparseToDense(const std::vector>& indices, const T* values, T default_value, T* output_data, const Dims<4>& output_dims, bool value_is_scalar) { const int value_count = indices.size(); @@ -4109,7 +4313,7 @@ inline void SparseToDense(const std::vector>& indices, // condition within the loop every time. if (value_is_scalar) { for (int i = 0; i < value_count; ++i) { - const std::vector& index = indices[i]; + const std::vector& index = indices[i]; TFLITE_DCHECK_EQ(index.size(), 4); const T value = *values; // just use the first value. output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] = @@ -4120,7 +4324,7 @@ inline void SparseToDense(const std::vector>& indices, // Go through the values and indices to fill the sparse values. for (int i = 0; i < value_count; ++i) { - const std::vector& index = indices[i]; + const std::vector& index = indices[i]; TFLITE_DCHECK_EQ(index.size(), 4); const T value = values[i]; output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] = @@ -4128,6 +4332,95 @@ inline void SparseToDense(const std::vector>& indices, } } +template +inline void Pow(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = std::pow(input1_data[i], input2_data[i]); + } +} + +template +inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + std::pow(input1_data[SubscriptToIndex(desc1, c, x, y, b)], + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +inline void Logical(const bool* input1_data, const Dims<4>& input1_dims, + const bool* input2_data, const Dims<4>& input2_dims, + bool* output_data, const Dims<4>& output_dims, + const std::function& func) { + const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = func(input1_data[i], input2_data[i]); + } +} + +inline void BroadcastLogical(const bool* input1_data, + const Dims<4>& input1_dims, + const bool* input2_data, + const Dims<4>& input2_dims, bool* output_data, + const Dims<4>& output_dims, + const std::function& func) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + func(input1_data[SubscriptToIndex(desc1, c, x, y, b)], + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more +// generalized and efficient BroadcastBinaryFunction. +// +// R: Result type. T1: Input 1 type. T2: Input 2 type. +template +inline void BroadcastBinaryFunction(const T1* input1_data, + const Dims<4>& input1_dims, + const T2* input2_data, + const Dims<4>& input2_dims, R* output_data, + const Dims<4>& output_dims, + R (*func)(T1, T2)) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + func(input1_data[SubscriptToIndex(desc1, c, x, y, b)], + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc index d781a7b642036f3c5ddaa366f257fe26511c83c3..ca94e7740eb18e9d2d36c676e1db2766d7050852 100644 --- a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc @@ -27,24 +27,27 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/test_util.h" +#include "tensorflow/contrib/lite/string.h" namespace tflite { namespace { void RunSoftmaxFloatReference(const uint8* input_data, - const Dims<4>& dims_common, int32 input_offset, - const double input_scale, int stride, float beta, + const RuntimeShape& shape_common, + int32 input_offset, const double input_scale, + int stride, float beta, uint8* reference_output_data) { - const int ref_buffer_size = RequiredBufferSizeForDims(dims_common); + const int ref_buffer_size = shape_common.FlatSize(); std::vector reference_dequant_data(ref_buffer_size); std::vector reference_output_float_data(ref_buffer_size); // Reference data generated via Dequant of input into float, and then applying // float Softmax. - reference_ops::Dequantize(input_data, dims_common, input_offset, input_scale, - reference_dequant_data.data(), dims_common); - optimized_ops::Softmax(reference_dequant_data.data(), dims_common, beta, - reference_output_float_data.data(), dims_common); + reference_ops::Dequantize( + input_data, ToRuntimeDims(shape_common), input_offset, input_scale, + reference_dequant_data.data(), ToRuntimeDims(shape_common)); + optimized_ops::Softmax(reference_dequant_data.data(), shape_common, beta, + reference_output_float_data.data(), shape_common); // Work with quantized scaling for Softmax, under which 256 represents 1, but // we limit this to 255. for (int i = 0; i < ref_buffer_size; i++) { @@ -55,9 +58,9 @@ void RunSoftmaxFloatReference(const uint8* input_data, } void CheckOutputData(const uint8* test_output, const uint8* reference_output, - const Dims<4>& dims_common, const string& check_label, - bool be_exacting) { - const int buffer_size = RequiredBufferSizeForDims(dims_common); + const RuntimeShape& shape_common, + const string& check_label, bool be_exacting) { + const int buffer_size = shape_common.FlatSize(); // While calculating some metrics in floating point, we work with quantized // scaling. std::vector diff(buffer_size); @@ -91,15 +94,15 @@ void CheckOutputData(const uint8* test_output, const uint8* reference_output, // Runs the Softmax and compares against the float reference implementation and // the quantized reference implementation. -void RunOneSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, - int32 input_offset, const double input_scale, int stride, - float beta) { - const int buffer_size = RequiredBufferSizeForDims(dims_common); +void RunOneSoftmaxTest(const uint8* input_data, + const RuntimeShape& shape_common, int32 input_offset, + const double input_scale, int stride, float beta) { + const int buffer_size = shape_common.FlatSize(); std::vector optimized_softmax_output(buffer_size); std::vector reference_float_softmax_output(buffer_size); std::vector reference_quant_softmax_output(buffer_size); - RunSoftmaxFloatReference(input_data, dims_common, input_offset, input_scale, + RunSoftmaxFloatReference(input_data, shape_common, input_offset, input_scale, stride, beta, reference_float_softmax_output.data()); int32 input_beta_multiplier; @@ -113,21 +116,21 @@ void RunOneSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits, input_beta_left_shift); - optimized_ops::Softmax(input_data, dims_common, input_beta_multiplier, + optimized_ops::Softmax(input_data, shape_common, input_beta_multiplier, input_beta_left_shift, diff_min, - optimized_softmax_output.data(), dims_common); - reference_ops::Softmax(input_data, dims_common, input_beta_multiplier, + optimized_softmax_output.data(), shape_common); + reference_ops::Softmax(input_data, shape_common, input_beta_multiplier, input_beta_left_shift, diff_min, - reference_quant_softmax_output.data(), dims_common); + reference_quant_softmax_output.data(), shape_common); CheckOutputData(optimized_softmax_output.data(), - reference_float_softmax_output.data(), dims_common, + reference_float_softmax_output.data(), shape_common, "Optimized vs float reference", false); CheckOutputData(optimized_softmax_output.data(), - reference_quant_softmax_output.data(), dims_common, + reference_quant_softmax_output.data(), shape_common, "Optimized vs quant reference", true); CheckOutputData(reference_quant_softmax_output.data(), - reference_float_softmax_output.data(), dims_common, + reference_float_softmax_output.data(), shape_common, "Quant reference vs float reference", false); } @@ -150,13 +153,13 @@ bool TryOneUniformSoftmax() { const int32 input_offset = UniformRandomInt(-256, 0); const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10); - Dims<4> dims_common = - MakeDimsForInference(input_depth, input_width, input_height, batch); - const int buffer_size = RequiredBufferSizeForDims(dims_common); + auto shape_common = + RuntimeShape({batch, input_height, input_width, input_depth}); + const int buffer_size = shape_common.FlatSize(); std::vector input_data(buffer_size); FillRandom(&input_data); - RunOneSoftmaxTest(input_data.data(), dims_common, input_offset, input_scale, + RunOneSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, stride, beta); return true; } @@ -188,14 +191,14 @@ bool TryOneSkyscraperSoftmax(bool small_depth) { const int middle_min = UniformRandomInt(0, 255); const int sides_max = UniformRandomInt(0, middle_min); - Dims<4> dims_common = - MakeDimsForInference(input_depth, input_width, input_height, batch); - const int buffer_size = RequiredBufferSizeForDims(dims_common); + auto shape_common = + RuntimeShape({batch, input_height, input_width, input_depth}); + const int buffer_size = shape_common.FlatSize(); std::vector input_data(buffer_size); FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min, sides_max); - RunOneSoftmaxTest(input_data.data(), dims_common, input_offset, input_scale, + RunOneSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, stride, beta); return true; } diff --git a/tensorflow/contrib/lite/kernels/internal/spectrogram.cc b/tensorflow/contrib/lite/kernels/internal/spectrogram.cc index 4eddf7bf0a2cbca695dae20ba8ba56a9cd72e4ba..20abcb725859d03f83c969369bddf1429895e0ba 100644 --- a/tensorflow/contrib/lite/kernels/internal/spectrogram.cc +++ b/tensorflow/contrib/lite/kernels/internal/spectrogram.cc @@ -43,13 +43,13 @@ bool Spectrogram::Initialize(int window_length, int step_length) { return Initialize(window, step_length); } -inline int Log2Floor(uint n) { +inline int Log2Floor(uint32_t n) { if (n == 0) return -1; int log = 0; - uint value = n; + uint32_t value = n; for (int i = 4; i >= 0; --i) { int shift = (1 << i); - uint x = value >> shift; + uint32_t x = value >> shift; if (x != 0) { value = x; log += shift; @@ -58,7 +58,7 @@ inline int Log2Floor(uint n) { return log; } -inline int Log2Ceiling(uint n) { +inline int Log2Ceiling(uint32_t n) { int floor = Log2Floor(n); if (n == (n & ~(n - 1))) // zero or a power of two return floor; @@ -66,7 +66,7 @@ inline int Log2Ceiling(uint n) { return floor + 1; } -inline uint NextPowerOfTwo(uint value) { +inline uint32_t NextPowerOfTwo(uint32_t value) { int exponent = Log2Ceiling(value); // DCHECK_LT(exponent, std::numeric_limits::digits); return 1 << exponent; diff --git a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h index ef77371bf65cc975dfa35275c8daa32de112a249..5994fad5c73df1dde6e33ba46dbd6e0802ea61be 100644 --- a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h +++ b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h @@ -74,12 +74,22 @@ inline int StartForAxis(int begin_mask, // size 4, this function would return 4 as the stop, because it is one past the // "real" indices of 0, 1, 2 & 3. template -inline int StopForAxis(int end_mask, std::vector const& stop_indices, +inline int StopForAxis(int end_mask, int shrink_axis_mask, + std::vector const& stop_indices, std::vector const& strides, - int const* input_shape, int axis) { + int const* input_shape, int axis, int start_for_axis) { // Begin with the specified index + const bool shrink_axis = shrink_axis_mask & (1 << axis); int stop = stop_indices[axis]; + // When shrinking an axis, the end position does not matter (and can be + // incorrect when negative indexing is used, see Issue #19260). Always use + // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has + // already been adjusted for negative indices. + if (shrink_axis) { + stop = start_for_axis + 1; + } + // end_mask override if (end_mask & (1 << axis)) { if (strides[axis] > 0) { @@ -93,7 +103,7 @@ inline int StopForAxis(int end_mask, std::vector const& stop_indices, } // Handle negative indices - int axis_size = input_shape[axis]; + const int axis_size = input_shape[axis]; if (stop < 0) { stop += axis_size; } diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h index ce887cea8b794b4b0cfd31722581cf9327be625e..ee2af5b46046c9e8bdc5816d5b6e9e9100cdc240 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ +#include #include #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/internal/types.h" @@ -34,6 +35,11 @@ inline uint8_t* GetTensorData(TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.uint8 : nullptr; } +template <> +inline int16_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i16 : nullptr; +} + template <> inline int32_t* GetTensorData(TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.i32 : nullptr; @@ -49,6 +55,13 @@ inline bool* GetTensorData(TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.b : nullptr; } +template <> +inline std::complex* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr + ? reinterpret_cast*>(tensor->data.c64) + : nullptr; +} + template inline const T* GetTensorData(const TfLiteTensor* tensor); @@ -62,6 +75,11 @@ inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.uint8 : nullptr; } +template <> +inline const int16_t* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i16 : nullptr; +} + template <> inline const int32_t* GetTensorData(const TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.i32 : nullptr; @@ -77,6 +95,13 @@ inline const bool* GetTensorData(const TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.b : nullptr; } +template <> +inline const std::complex* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr + ? reinterpret_cast*>(tensor->data.c64) + : nullptr; +} + inline int RemapDim(int max_dimensions, int d) { return max_dimensions - d - 1; } @@ -114,6 +139,19 @@ inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) { return GetTensorDims(dims->data, dims->size); } +inline RuntimeShape GetTensorShape(std::vector data) { + return RuntimeShape(data.size(), data.data()); +} + +inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) { + if (tensor == nullptr) { + return RuntimeShape(); + } + + auto* dims = tensor->dims; + return RuntimeShape(dims->size, dims->data); +} + // A list of tensors in a format that can be used by kernels like split and // concatenation. template diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h index 5160e22307ae0894fabd0e9c4f7b9cd38b00840e..1ff8cfe39c9aed7e9241815dc8eff7ab4d9fd585 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h @@ -17,6 +17,10 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" +#if defined(_MSC_VER) +#define __restrict__ __restrict +#endif + namespace tflite { namespace tensor_utils { @@ -31,8 +35,8 @@ bool IsZeroVector(const float* vector, int v_size); // It also outputs the range (min, max) of the floating point buffer, and the // scaling factor used to quantize the values. void SymmetricQuantizeFloats(const float* values, const int size, - int8_t* quantized_values, float* min, float* max, - float* scaling_factor); + int8_t* quantized_values, float* min_value, + float* max_value, float* scaling_factor); // Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch // dimension composed by input vectors independent from each other). The result @@ -124,6 +128,10 @@ void Sub1Vector(const float* vector, int v_size, float* result); // Fill vector with 0.f. void ZeroVector(float* vector, int v_size); +// Multiply all elements of vector with a scalar. +void VectorScalarMultiply(const int8_t* vector, int v_size, float scale, + float* result); + // Clip elements of a vector using a abs_limit value. void ClipVector(const float* vector, int v_size, float abs_limit, float* result); diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc index 14ee528394b6872d9e79969db0e431658277f56b..e8343f1223b2137a7df9cc264c56100bc66f9fc1 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc @@ -32,6 +32,22 @@ TEST(uKernels, ClipTest) { {0.0, -0.5, 1.0, -1.5, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0}))); } +TEST(uKernels, VectorScalarMultiply) { + constexpr int kVectorSize = 29; + static int8_t input[kVectorSize]; + for (int i = 0; i < 29; ++i) { + input[i] = static_cast(i - 14); + } + const float scale = 0.1f; + std::vector output(kVectorSize, 0.0f); + VectorScalarMultiply(input, kVectorSize, scale, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear( + {-1.4, -1.3, -1.2, -1.1, -1.0, -0.9, -0.8, -0.7, -0.6, -0.5, + -0.4, -0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3, 0.4, 0.5, + 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4}))); +} + TEST(uKernels, IsZeroTest) { constexpr int kVectorSize = 21; static float zeros[kVectorSize] = {0.0}; @@ -56,14 +72,15 @@ TEST(uKernels, SymmetricQuantizeFloatsTest) { static float input[kVectorSize] = {-640, -635.0, -630, 10.0, 2.0, -5.0, -10.0, 0.0, 1000.0}; - int8 output[kVectorSize]; + int8_t output[kVectorSize]; float min, max, scaling_factor; SymmetricQuantizeFloats(input, kVectorSize, output, &min, &max, &scaling_factor); EXPECT_EQ(min, -640); EXPECT_EQ(max, 1000); - EXPECT_NEAR(scaling_factor, 0.127, 1e-6); // EQ won't work due to fpoint. + // EQ won't work due to fpoint. + EXPECT_NEAR(scaling_factor, 1000 / 127.0, 1e-6); EXPECT_THAT(output, testing::ElementsAreArray({-81, -81, -80, 1, 0, -1, -1, 0, 127})); } @@ -72,7 +89,7 @@ TEST(uKernels, SymmetricQuantizeFloatsAllZerosTest) { constexpr int kVectorSize = 9; static float input[kVectorSize] = {0, 0, 0, 0, 0, 0, 0, 0, 0}; - int8 output[kVectorSize]; + int8_t output[kVectorSize]; float min, max, scaling_factor; SymmetricQuantizeFloats(input, kVectorSize, output, &min, &max, &scaling_factor); @@ -88,14 +105,14 @@ TEST(uKernels, SymmetricQuantizeFloatsAllAlmostZeroTest) { static float input[kVectorSize] = {-1e-5, 3e-5, -7e-6, -9e-5, 1e-6, 4e-5, 9e-6, 2e-4, 0}; - int8 output[kVectorSize]; + int8_t output[kVectorSize]; float min, max, scaling_factor; SymmetricQuantizeFloats(input, kVectorSize, output, &min, &max, &scaling_factor); EXPECT_NEAR(min, -9e-05, 1e-6); EXPECT_NEAR(max, 0.0002, 1e-6); - EXPECT_EQ(scaling_factor, 635000); + EXPECT_NEAR(scaling_factor, 1.57e-6, 1e-6); EXPECT_THAT(output, testing::ElementsAreArray({-6, 19, -4, -57, 1, 25, 6, 127, 0})); } @@ -126,6 +143,7 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateTest) { -1., 3., 7., 3., 23., 3.}))); } +#ifdef __ANDROID__ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) { // Note we use 29 columns as this exercises all the neon kernel: the // 16-block SIMD code, the 8-block postamble, and the leftover postamble. @@ -149,13 +167,13 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) { -13.13, 14.14, -15.15, 16.16, -17.17, 18.18, -19.19, 20.2, -21.21, 22.22, -23.23, 24.24, -25.25, 26.26, -27.27, 28.28, 0}; - int8* a_int8_data = reinterpret_cast( + int8_t* a_int8_data = reinterpret_cast( aligned_malloc(a_rows * a_cols, kWeightsPerUint32)); float a_min, a_max; float scaling_factor_a; SymmetricQuantizeFloats(a_float_data, a_rows * a_cols, a_int8_data, &a_min, &a_max, &scaling_factor_a); - const int8 expected_a_int8_data[] = { + const int8_t expected_a_int8_data[] = { /* 1st row */ 5, 10, @@ -346,7 +364,7 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) { }; // Quantized values of B: - int8 b_int8_data[b_rows * b_cols * batches]; + int8_t b_int8_data[b_rows * b_cols * batches]; float b_min, b_max; float scaling_factor_b[batches]; SymmetricQuantizeFloats(b_float_data, b_rows * b_cols, b_int8_data, &b_min, @@ -355,7 +373,7 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) { &b_int8_data[b_rows * b_cols], &b_min, &b_max, &scaling_factor_b[1]); - const int8 expected_b_int8_data[] = { + const int8_t expected_b_int8_data[] = { /* batch 1 */ 127, -127, @@ -448,6 +466,7 @@ TEST(uKernels, MatrixBatchVectorMultiplyAccumulateSymmetricQuantizedTest) { aligned_free(a_int8_data); } +#endif // __ANDROID__ TEST(uKernels, VectorVectorCwiseProductTest) { constexpr int kVectorSize = 10; diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 3ecef15271c6eb9dd9d6dd370377fddda2723fcf..81a479017e7d65a29af5e35139b054df4abb7116 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -23,7 +23,73 @@ limitations under the License. namespace tflite { enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu }; -enum class PaddingType { kNone, kSame, kValid }; +enum class PaddingType : uint8 { kNone, kSame, kValid }; + +struct PaddingValues { + int8 width; + int8 height; +}; + +// This enumeration allows for non-default formats for the weights array +// of a fully-connected operator, allowing the use of special optimized +// runtime paths. +enum class FullyConnectedWeightsFormat : uint8 { + // Default format (flat 2D layout, the inner contiguous dimension + // is input_depth, the outer non-contiguous dimension is output_depth) + kDefault, + // Summary: optimized layout for fast CPU runtime implementation, + // aimed specifically at ARM CPUs at the moment, and specialized for + // 8-bit quantized layers. + // + // The use case we're concerned with here is: 8-bit quantization, + // large weights matrix that doesn't fit in cache (e.g. 4096x2048 in + // a key application that drove this), very small batch size (e.g. 1 -- 4). + // + // Even with 8-bit quantization of weights, the performance of memory + // accesses to the weights can become the dominant issue when + // the batch size is small, so each weight value is used in only a few + // arithmetic ops, i.e. the fully-connected node has a low arithmetic + // intensity. The specific issues that arise are of three kinds: + // (1) One may, ideally, max out DRAM bandwidth, i.e. be truly memory + // bound. That's the "good" issue to run into. + // (2) One may run into sub-optimal pre-fetching: the data hasn't been + // prefetched into the cache by the time we need it. + // (3) One may run into cache aliasing: multiple values that are + // pre-fetched, alias each other in the L1 cache (which typically + // has only 4-way set associativity in ARM CPUs) and thus evict + // each other before we get to using them. + // + // The point of this shuffling is to avoid issues (2) and (3) so that + // we get as fast as possible given only the hard constraint (1). + // This is achieved by turning the difficulty into a solution: the + // difficulty, that each value loaded from memory is used only in + // one kernel iteration, making this operation memory-intensive, hints at + // the solution, of shuffling the weights so that they are stored in the + // exact order as the kernel needs to load them, so that the memory + // accesses made by the kernel are trivial. This solves (2) because the + // trivial memory access pattern allows the CPU's automatic prefetching + // to perform very well (no need even for preload instructions), and this + // solves (3) because the values being loaded concurrently are now + // contiguous in the address space, thus don't alias each other in the cache. + // + // On ARM, we typically want our kernel to process a 4x16 block of weights + // at a time, because: + // - 16 is the number of bytes in a NEON register. + // - 4 is how many rows we need to handle concurrently in the kernel in + // order to have sufficient mutual independence of instructions to + // maximize arithmetic throughput. + // + // Finally, the 'Int8' part in the name refers to the fact that this + // weights format has each weights value encoded as a signed int8 value, + // even if the data type of the weights buffer is uint8. This is intended + // to save runtime kernels the effort to have to XOR the top bit of these + // bytes before using them in signed arithmetic, see this file for more + // explanations on the 'signed int8 trick' in matrix multiplication kernels: + // + // tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc + // + kShuffled4x16Int8, +}; // Quantization parameters, determining the mapping of quantized values // to real values (i.e. determining how quantized values are mathematically @@ -53,6 +119,8 @@ class RuntimeShape { // larger shapes are separately allocated. static constexpr int kMaxSmallSize = 4; + RuntimeShape& operator=(RuntimeShape const&) = delete; + RuntimeShape() : size_(0) {} explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) { @@ -61,10 +129,35 @@ class RuntimeShape { } } + RuntimeShape(int shape_size, int32 value) : size_(0) { + Resize(shape_size); + for (int i = 0; i < shape_size; ++i) { + SetDim(i, value); + } + } + RuntimeShape(int dimensions_count, const int32* dims_data) : size_(0) { ReplaceWith(dimensions_count, dims_data); } + RuntimeShape(const std::initializer_list init_list) : size_(0) { + BuildFrom(init_list); + } + + // Avoid using this constructor. We should be able to delete it when C++17 + // rolls out. + RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) { + if (size_ > kMaxSmallSize) { + dims_pointer_ = new int32[size_]; + } + std::memcpy(DimsData(), other.DimsData(), sizeof(int32) * size_); + } + + bool operator==(const RuntimeShape& comp) const { + return this->size_ == comp.size_ && + std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32)) == 0; + } + ~RuntimeShape() { if (size_ > kMaxSmallSize) { delete[] dims_pointer_; @@ -121,6 +214,16 @@ class RuntimeShape { } } + // This will probably be factored out. Old code made substantial use of 4-D + // shapes, and so this function is used to extend smaller shapes. Note that + // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be + // reduced, and (b) some kernels are stricly 4-D, but then the shapes of their + // inputs should already be 4-D, so this function should not be needed. + inline static RuntimeShape ExtendedShape(int new_shape_size, + const RuntimeShape& shape) { + return RuntimeShape(new_shape_size, shape, 1); + } + inline void BuildFrom(const std::initializer_list init_list) { BuildFrom>(init_list); } @@ -138,7 +241,25 @@ class RuntimeShape { return buffer_size; } + bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); } + private: + // For use only by ExtendFrom(), written to guarantee (return-value) copy + // elision in C++17. + // This creates a shape padded to the desired size with the specified value. + RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value) + : size_(0) { + TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount()); + TFLITE_CHECK_LE(new_shape_size, kMaxSmallSize); + Resize(new_shape_size); + const int size_increase = new_shape_size - shape.DimensionsCount(); + for (int i = 0; i < size_increase; ++i) { + SetDim(i, pad_value); + } + std::memcpy(DimsData() + size_increase, shape.DimsData(), + sizeof(int32) * shape.DimensionsCount()); + } + int32 size_; union { int32 dims_[kMaxSmallSize]; @@ -164,7 +285,9 @@ inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) { // Gets next index to iterate through a multidimensional array. inline bool NextIndex(const int num_dims, const int* dims, int* current) { - TFLITE_DCHECK_GT(num_dims, 0); + if (num_dims == 0) { + return false; + } TFLITE_DCHECK(dims != nullptr); TFLITE_DCHECK(current != nullptr); int carry = 1; @@ -191,7 +314,9 @@ inline bool NextIndex(const int num_dims, const int* dims, int* current) { inline size_t ReducedOutputOffset(const int num_dims, const int* dims, const int* index, const int num_axis, const int* axis) { - TFLITE_DCHECK_GT(num_dims, 0); + if (num_dims == 0) { + return 0; + } TFLITE_DCHECK(dims != nullptr); TFLITE_DCHECK(index != nullptr); size_t offset = 0; @@ -214,6 +339,15 @@ inline size_t ReducedOutputOffset(const int num_dims, const int* dims, return offset; } +inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) { + TFLITE_DCHECK(i0 >= 0 && i0 < shape.Dims(0)); + TFLITE_DCHECK(i1 >= 0 && i1 < shape.Dims(1)); + TFLITE_DCHECK(i2 >= 0 && i2 < shape.Dims(2)); + TFLITE_DCHECK(i3 >= 0 && i3 < shape.Dims(3)); + const int* dims_data = shape.DimsData(); + return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3; +} + inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) { TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]); TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]); @@ -228,6 +362,9 @@ inline int Offset(const Dims<4>& dims, int* index) { } // Get array size, DCHECKing that the dim index is in range. +// +// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims() +// already performs this check. template int ArraySize(const Dims& array, int index) { TFLITE_DCHECK(index >= 0 && index < N); @@ -249,6 +386,21 @@ int MatchingArraySize(const ArrayType1& array1, int index1, return MatchingArraySize(array1, index1, args...); } +// Get common shape dim, DCHECKing that they all agree. +inline int MatchingDim(const RuntimeShape& shape1, int index1, + const RuntimeShape& shape2, int index2) { + TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2)); + return shape1.Dims(index1); +} + +template +int MatchingDim(const RuntimeShape& shape1, int index1, + const RuntimeShape& shape2, int index2, Args... args) { + TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2)); + return MatchingDim(shape1, index1, args...); +} + +// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize(). template inline int FlatSize(const Dims& dims) { int flat_size = 1; @@ -263,6 +415,54 @@ inline int RequiredBufferSizeForDims(const Dims<4>& dims) { return FlatSize(dims); } +// Flat size calculation, checking that dimensions match with one or more other +// arrays. +inline int MatchingFlatSize(const RuntimeShape& shape, + const RuntimeShape& check_shape_0) { + TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount()); + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + return shape.FlatSize(); +} + +inline int MatchingFlatSize(const RuntimeShape& shape, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1) { + TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount()); + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + return MatchingFlatSize(shape, check_shape_1); +} + +inline int MatchingFlatSize(const RuntimeShape& shape, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1, + const RuntimeShape& check_shape_2) { + TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount()); + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + return MatchingFlatSize(shape, check_shape_1, check_shape_2); +} + +inline int MatchingFlatSize(const RuntimeShape& shape, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1, + const RuntimeShape& check_shape_2, + const RuntimeShape& check_shape_3) { + TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount()); + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3); +} + // Flat size calculation, checking that dimensions match with one or more other // arrays. template @@ -289,7 +489,7 @@ inline int MatchingFlatSize(const Dims& dims, const Dims& check_dims_0, for (int i = 0; i < N; ++i) { TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i)); } - return FlatSize(dims, check_dims_1, check_dims_2); + return MatchingFlatSize(dims, check_dims_1, check_dims_2); } template @@ -300,7 +500,7 @@ inline int MatchingFlatSize(const Dims& dims, const Dims& check_dims_0, for (int i = 0; i < N; ++i) { TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i)); } - return FlatSize(dims, check_dims_1, check_dims_2, check_dims_3); + return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3); } // Data is required to be contiguous, and so many operators can use either the @@ -368,6 +568,72 @@ inline int MatchingFlatSizeSkipDim(const Dims& dims, int skip_dim, check_dims_3); } +// Data is required to be contiguous, and so many operators can use either the +// full array flat size or the flat size with one dimension skipped (commonly +// the depth). +inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) { + const int dims_count = shape.DimensionsCount(); + TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count); + const auto* dims_data = shape.DimsData(); + int flat_size = 1; + for (int i = 0; i < dims_count; ++i) { + flat_size *= (i == skip_dim) ? 1 : dims_data[i]; + } + return flat_size; +} + +// A combination of MatchingFlatSize() and FlatSizeSkipDim(). +inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim, + const RuntimeShape& check_shape_0) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + if (i != skip_dim) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + } + return FlatSizeSkipDim(shape, skip_dim); +} + +inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + if (i != skip_dim) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + } + return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1); +} + +inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1, + const RuntimeShape& check_shape_2) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + if (i != skip_dim) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + } + return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2); +} + +inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1, + const RuntimeShape& check_shape_2, + const RuntimeShape& check_shape_3) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + if (i != skip_dim) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + } + return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2, + check_shape_3); +} + template bool IsPackedWithoutStrides(const Dims& dims) { int expected_stride = 1; @@ -386,6 +652,82 @@ void ComputeStrides(Dims* dims) { } } +struct PoolParams { + FusedActivationFunctionType activation; + PaddingType padding_type; + PaddingValues padding_values; + int stride_height; + int stride_width; + int filter_height; + int filter_width; + // uint8, etc, activation params. + int32 quantized_activation_min; + int32 quantized_activation_max; + // float activation params. + float float_activation_min; + float float_activation_max; +}; + +enum class BroadcastableOpCategory : uint8 { + kNone, + kNonBroadcast, // Matching input shapes. + kFirstInputBroadcastsFast, // Fivefold nested loops. + kSecondInputBroadcastsFast, // Fivefold nested loops. + kGenericBroadcast, // Fall-back. +}; + +// For Add, Sub, Mul ops. +struct ArithmeticParams { + // Shape dependent / common to data / op types. + BroadcastableOpCategory broadcast_category; + // uint8 inference params. + int32 input1_offset; + int32 input2_offset; + int32 output_offset; + int32 output_multiplier; + int output_shift; + // Add / Sub, not Mul, uint8 inference params. + int left_shift; + int32 input1_multiplier; + int input1_shift; + int32 input2_multiplier; + int input2_shift; + // uint8, etc, activation params. + int32 quantized_activation_min; + int32 quantized_activation_max; + // float activation params. + float float_activation_min; + float float_activation_max; + + // Processed output dimensions. + // Let input "a" be the one that broadcasts in the faster-changing dimension. + // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and + // {b0, b1, b2, b3, b4}, + // broadcast_shape[4] = b0 = a0. + // broadcast_shape[3] = b1; a1 = 1. + // broadcast_shape[2] = b2 = a2. + // broadcast_shape[1] = a3; b3 = 1. + // broadcast_shape[0] = b4 = a4. + int broadcast_shape[5]; +}; + +template +inline void SetActivationParams(T min, T max, ArithmeticParams* params); + +template <> +inline void SetActivationParams(float min, float max, + ArithmeticParams* params) { + params->float_activation_min = min; + params->float_activation_max = max; +} + +template <> +inline void SetActivationParams(int32 min, int32 max, + ArithmeticParams* params) { + params->quantized_activation_min = min; + params->quantized_activation_max = max; +} + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc index 184028427fb193aa99cf155961c16eda1298e326..08f942c933552aa6ca7369550c928efba9e2e93e 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.cc +++ b/tensorflow/contrib/lite/kernels/kernel_util.cc @@ -43,12 +43,11 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context, return kTfLiteOk; } -void CalculateActivationRangeUint8(TfLiteFusedActivation activation, - TfLiteTensor* output, int32_t* act_min, - int32_t* act_max) { - const int32_t qmin = std::numeric_limits::min(); - const int32_t qmax = std::numeric_limits::max(); - +namespace { +void CalculateActivationRangeQuantizedImpl(TfLiteFusedActivation activation, + int32_t qmin, int32_t qmax, + TfLiteTensor* output, + int32_t* act_min, int32_t* act_max) { const auto scale = output->params.scale; const auto zero_point = output->params.zero_point; @@ -70,23 +69,38 @@ void CalculateActivationRangeUint8(TfLiteFusedActivation activation, *act_max = qmax; } } - -void CalculateActivationRangeFloat(TfLiteFusedActivation activation, - float* activation_min, - float* activation_max) { - if (activation == kTfLiteActRelu) { - *activation_min = 0.f; - *activation_max = std::numeric_limits::max(); - } else if (activation == kTfLiteActRelu6) { - *activation_min = 0.f; - *activation_max = 6.f; - } else if (activation == kTfLiteActRelu1) { - *activation_min = -1.f; - *activation_max = 1.f; +} // namespace + +TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context, + TfLiteFusedActivation activation, + TfLiteTensor* output, + int32_t* act_min, + int32_t* act_max) { + int32_t qmin = 0; + int32_t qmax = 0; + if (output->type == kTfLiteUInt8) { + qmin = std::numeric_limits::min(); + qmax = std::numeric_limits::max(); + } else if (output->type == kTfLiteInt16) { + qmin = std::numeric_limits::min(); + qmax = std::numeric_limits::max(); } else { - *activation_min = std::numeric_limits::lowest(); - *activation_max = std::numeric_limits::max(); + TF_LITE_ENSURE(context, false); } + + CalculateActivationRangeQuantizedImpl(activation, qmin, qmax, output, act_min, + act_max); + return kTfLiteOk; +} + +void CalculateActivationRangeUint8(TfLiteFusedActivation activation, + TfLiteTensor* output, int32_t* act_min, + int32_t* act_max) { + const int32_t qmin = std::numeric_limits::min(); + const int32_t qmax = std::numeric_limits::max(); + + CalculateActivationRangeQuantizedImpl(activation, qmin, qmax, output, act_min, + act_max); } bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2) { diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h index 82cded36f2ed2777daccafee5890f47c0d7254e8..c8ce3c917d5bf66e01fbae95c18dfe97b3c84bae 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.h +++ b/tensorflow/contrib/lite/kernels/kernel_util.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ +#include + #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" @@ -86,14 +88,35 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context, TfLiteTensor* output, double* multiplier); -// Calculates the useful range of an activation layer given its activation -// tensor. +// Calculates the useful quantized range of an activation layer given its +// activation tensor. +TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context, + TfLiteFusedActivation activation, + TfLiteTensor* output, + int32_t* act_min, + int32_t* act_max); void CalculateActivationRangeUint8(TfLiteFusedActivation activation, TfLiteTensor* output, int32_t* act_min, int32_t* act_max); -void CalculateActivationRangeFloat(TfLiteFusedActivation activation, - float* activation_min, - float* activation_max); +// Calculates the useful range of an activation layer given its activation +// tensor.a +template +void CalculateActivationRange(TfLiteFusedActivation activation, + T* activation_min, T* activation_max) { + if (activation == kTfLiteActRelu) { + *activation_min = 0; + *activation_max = std::numeric_limits::max(); + } else if (activation == kTfLiteActRelu6) { + *activation_min = 0; + *activation_max = 6; + } else if (activation == kTfLiteActRelu1) { + *activation_min = -1; + *activation_max = 1; + } else { + *activation_min = std::numeric_limits::lowest(); + *activation_max = std::numeric_limits::max(); + } +} // Return true if the given tensors have the same shape. bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2); diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc index 3205c1cc52724207904621a5870636841ef379fe..a7b54c6b842332feb2d9e7179e79ae054bd23bb9 100644 --- a/tensorflow/contrib/lite/kernels/l2norm.cc +++ b/tensorflow/contrib/lite/kernels/l2norm.cc @@ -70,8 +70,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (output->type == kTfLiteFloat32) { #define TF_LITE_L2NORM(type) \ type::L2Normalization( \ - GetTensorData(input), GetTensorDims(input), \ - GetTensorData(output), GetTensorDims(output)) + GetTensorData(input), GetTensorShape(input), \ + GetTensorData(output), GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_L2NORM(reference_ops); @@ -81,10 +81,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } #undef TF_LITE_L2NORM } else if (output->type == kTfLiteUInt8) { -#define TF_LITE_L2NORM(type) \ - type::L2Normalization(GetTensorData(input), GetTensorDims(input), \ - input->params.zero_point, \ - GetTensorData(output), GetTensorDims(output)) +#define TF_LITE_L2NORM(type) \ + type::L2Normalization(GetTensorData(input), GetTensorShape(input), \ + input->params.zero_point, \ + GetTensorData(output), GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_L2NORM(reference_ops); diff --git a/tensorflow/contrib/lite/kernels/log_softmax_test.cc b/tensorflow/contrib/lite/kernels/log_softmax_test.cc index 62820a2f5113cb6ae252386aaf3842135383b79f..9a8d35e82cbc3a7e55246e6c06599b2838d1ee67 100644 --- a/tensorflow/contrib/lite/kernels/log_softmax_test.cc +++ b/tensorflow/contrib/lite/kernels/log_softmax_test.cc @@ -90,10 +90,9 @@ TEST(LogSoftmaxOpTest, CompareWithTFmini) { m.Invoke(); std::unique_ptr output_buffer(new float[input_size * batch_size]); - static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, - {1, 0, 0, input_size}}; - tflite::reference_ops::LogSoftmax(input_buffer, input_dims, - output_buffer.get(), input_dims); + auto input_shape = RuntimeShape({batch_size, 1, 1, input_size}); + tflite::reference_ops::LogSoftmax(input_buffer, input_shape, + output_buffer.get(), input_shape); std::vector expected; expected.insert(expected.end(), output_buffer.get(), diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/contrib/lite/kernels/logical.cc new file mode 100644 index 0000000000000000000000000000000000000000..87c2fee667ccaf7bfdc4e2316309d2abc35b5324 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/logical.cc @@ -0,0 +1,134 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace logical { +namespace { + +// Input/output tensor index. +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +// Op data for logical op. +struct OpData { + bool requires_broadcast; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->requires_broadcast = false; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + // Reinterprete the opaque data provided by user. + OpData* data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, input1->type, input2->type); + + const TfLiteType type = input1->type; + if (type != kTfLiteBool) { + context->ReportError(context, "Logical ops only support bool type."); + return kTfLiteError; + } + output->type = type; + + data->requires_broadcast = !HaveSameShapes(input1, input2); + + TfLiteIntArray* output_size = nullptr; + if (data->requires_broadcast) { + TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( + context, input1, input2, &output_size)); + } else { + output_size = TfLiteIntArrayCopy(input1->dims); + } + + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node, + const std::function& func) { + OpData* data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (data->requires_broadcast) { + reference_ops::BroadcastLogical( + GetTensorData(input1), GetTensorDims(input1), + GetTensorData(input2), GetTensorDims(input2), + GetTensorData(output), GetTensorDims(output), func); + } else { + reference_ops::Logical(GetTensorData(input1), GetTensorDims(input1), + GetTensorData(input2), GetTensorDims(input2), + GetTensorData(output), GetTensorDims(output), + func); + } + + return kTfLiteOk; +} + +TfLiteStatus LogicalOrEval(TfLiteContext* context, TfLiteNode* node) { + const auto logical_or_func = std::logical_or(); + return LogicalImpl(context, node, logical_or_func); +} + +TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) { + const auto logical_and_func = std::logical_and(); + return LogicalImpl(context, node, logical_and_func); +} + +} // namespace +} // namespace logical + +TfLiteRegistration* Register_LOGICAL_OR() { + // Init, Free, Prepare, Eval are satisfying the Interface required by + // TfLiteRegistration. + static TfLiteRegistration r = {logical::Init, logical::Free, logical::Prepare, + logical::LogicalOrEval}; + return &r; +} + +TfLiteRegistration* Register_LOGICAL_AND() { + // Init, Free, Prepare, Eval are satisfying the Interface required by + // TfLiteRegistration. + static TfLiteRegistration r = {logical::Init, logical::Free, logical::Prepare, + logical::LogicalAndEval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/logical_test.cc b/tensorflow/contrib/lite/kernels/logical_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..206cbde98fa48ec5f7c863bbced9dccc9cab5207 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/logical_test.cc @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +class LogicalOpModel : public SingleOpModel { + public: + LogicalOpModel(std::initializer_list input1_shape, + std::initializer_list input2_shape, BuiltinOperator op) { + input1_ = AddInput(TensorType_BOOL); + input2_ = AddInput(TensorType_BOOL); + output_ = AddOutput(TensorType_BOOL); + ConfigureBuiltinOp(op); + BuildInterpreter({input1_shape, input2_shape}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input1_; + int input2_; + int output_; + + void ConfigureBuiltinOp(BuiltinOperator op) { + switch (op) { + case BuiltinOperator_LOGICAL_OR: { + SetBuiltinOp(op, BuiltinOptions_LogicalOrOptions, + CreateLogicalOrOptions(builder_).Union()); + break; + } + case BuiltinOperator_LOGICAL_AND: { + SetBuiltinOp(op, BuiltinOptions_LogicalAndOptions, + CreateLogicalAndOptions(builder_).Union()); + break; + } + default: { FAIL() << "We shouldn't get here."; } + } + } +}; + +TEST(LogicalTest, LogicalOr) { + LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, BuiltinOperator_LOGICAL_OR); + model.PopulateTensor(model.input1(), {true, false, false, true}); + model.PopulateTensor(model.input2(), {true, false, true, false}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(LogicalTest, BroadcastLogicalOr) { + LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, BuiltinOperator_LOGICAL_OR); + model.PopulateTensor(model.input1(), {true, false, false, true}); + model.PopulateTensor(model.input2(), {false}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(LogicalTest, LogicalAnd) { + LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, BuiltinOperator_LOGICAL_AND); + model.PopulateTensor(model.input1(), {true, false, false, true}); + model.PopulateTensor(model.input2(), {true, false, true, false}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(LogicalTest, BroadcastLogicalAnd) { + LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, BuiltinOperator_LOGICAL_AND); + model.PopulateTensor(model.input1(), {true, false, false, true}); + model.PopulateTensor(model.input2(), {true}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc index 863e08ed680588a8cfb73792b8262d390ea6bf36..6ed9385dc35ecb2b4db1c895a663f962ea689ccf 100644 --- a/tensorflow/contrib/lite/kernels/lsh_projection.cc +++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc @@ -50,7 +50,6 @@ limitations under the License. // Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] } // A flattened tensor represents projected bit vectors. -#include #include #include #include diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index eb26a02455ce2afccaa081a72d93a9ceeca746cc..ba251c451e549a09d265fc43fed7dc7eb6896d61 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -24,6 +23,7 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/gemm_support.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" @@ -37,14 +37,17 @@ namespace builtin { namespace lstm { struct OpData { - // Which kernel type to use. Full kernel (18-inputs) or basic kernel - // (5-inputs). + // Which kernel type to use. Full kernel (18 or 20 inputs) or basic kernel + // (5 inputs). TfLiteLSTMKernelType kernel_type; - // Only used by full kernel. + + // These fields are only used by full kernel. + int activation_state_tensor_index; + int cell_state_tensor_index; int scratch_tensor_index; }; -// For full inputs kernel (18-inputs). +// For full inputs kernel (18 or 20 inputs). namespace full { // Input Tensors of size {n_batch, n_input} @@ -78,13 +81,22 @@ constexpr int kProjectionWeightsTensor = 16; // Optional // Projection bias tensor of size {n_output} constexpr int kProjectionBiasTensor = 17; // Optional +// If the node has 20 inputs, the following 2 tensors are used as state tensors. +// These are defined as variable tensors, and will be modified by this op. +constexpr int kInputActivationStateTensor = 18; +constexpr int kInputCellStateTensor = 19; + // Output tensors. +// * If the node has 18 inputs, these 2 tensors are used as state tensors. +// * If the node has 20 inputs, these 2 tensors are ignored. +// TODO(ycling): Make the 2 output state tensors optional, and propagate the +// state to output tensors when the 2 tensors present. constexpr int kOutputStateTensor = 0; constexpr int kCellStateTensor = 1; constexpr int kOutputTensor = 2; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* op_data = new OpData; + auto* op_data = new OpData(); op_data->kernel_type = kTfLiteLSTMFullKernel; context->AddTensors(context, /*tensors_to_add=*/7, &op_data->scratch_tensor_index); @@ -246,10 +258,31 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { OpData* op_data = reinterpret_cast(node->user_data); - // Check we have all the inputs and outputs we need. - TF_LITE_ENSURE_EQ(context, node->inputs->size, 18); TF_LITE_ENSURE_EQ(context, node->outputs->size, 3); + // True if the node is using input variable state tensors. It means: + // * The state tensors are defined as inputs. In this case it would be the + // 19th and 20th input tensors. + // * Otherwise, the output tensors are used to store states. + bool use_input_variable_states; + if (node->inputs->size == 20) { + use_input_variable_states = true; + op_data->activation_state_tensor_index = + node->inputs->data[kInputActivationStateTensor]; + op_data->cell_state_tensor_index = + node->inputs->data[kInputCellStateTensor]; + } else if (node->inputs->size == 18) { + use_input_variable_states = false; + op_data->activation_state_tensor_index = + node->outputs->data[kOutputStateTensor]; + op_data->cell_state_tensor_index = node->outputs->data[kCellStateTensor]; + } else { + context->ReportError( + context, "The LSTM Full kernel expects 18 or 20 inputs. Got %d inputs", + node->inputs->size); + return kTfLiteError; + } + // Inferring batch size, number of outputs and number of cells from the // input tensors. const TfLiteTensor* input = GetInput(context, node, kInputTensor); @@ -272,36 +305,50 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int n_output = recurrent_to_output_weights->dims->data[1]; // Check that input tensor dimensions matches with each other. - CheckInputTensorDimensions(context, node, n_input, n_output, n_cell); + TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input, + n_output, n_cell)); - // Get the pointer to output, output_state and cell_state tensors. + // Get the pointer to output, activation_state and cell_state tensors. TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); - TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); - // Resize the output, output_state and cell_state tensors. + TfLiteTensor* activation_state = + &context->tensors[op_data->activation_state_tensor_index]; + TfLiteTensor* cell_state = + &context->tensors[op_data->cell_state_tensor_index]; + + if (use_input_variable_states) { + // Check the shape of input state tensors. + // These tensor may be 1D or 2D. It's fine as long as the total size is + // correct. + TF_LITE_ENSURE_EQ(context, NumElements(activation_state), + n_batch * n_output); + TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell); + } else { + // If the state tensors are outputs, this function takes the + // responsibility to resize the state tensors. + TfLiteIntArray* activation_state_size = TfLiteIntArrayCreate(2); + activation_state_size->data[0] = n_batch; + activation_state_size->data[1] = n_output; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state, + activation_state_size)); + + TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2); + cell_size->data[0] = n_batch; + cell_size->data[1] = n_cell; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, cell_state, cell_size)); + // Mark state tensors as persistent tensors. + activation_state->allocation_type = kTfLiteArenaRwPersistent; + cell_state->allocation_type = kTfLiteArenaRwPersistent; + } + + // Resize the output tensors. TfLiteIntArray* output_size = TfLiteIntArrayCreate(2); output_size->data[0] = n_batch; output_size->data[1] = n_output; TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, output_size)); - TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2); - output_state_size->data[0] = n_batch; - output_state_size->data[1] = n_output; - TF_LITE_ENSURE_OK( - context, context->ResizeTensor(context, output_state, output_state_size)); - - TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2); - cell_size->data[0] = n_batch; - cell_size->data[1] = n_cell; - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, cell_state, cell_size)); - - // Mark state tensors as persistent tensors. - output_state->allocation_type = kTfLiteArenaRwPersistent; - cell_state->allocation_type = kTfLiteArenaRwPersistent; - // The weights are of consistent type, so it suffices to check one. // TODO(mirkov): create a utility/macro for this check, so all Ops can use it. const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 && @@ -337,7 +384,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { if (is_hybrid_op) { // Allocate temporary tensors to store quantized values of input, - // output_state and cell_state tensors. + // activation_state and cell_state tensors. node->temporaries->data[1] = op_data->scratch_tensor_index + 1; TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); input_quantized->type = kTfLiteUInt8; @@ -348,17 +395,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { input_quantized_size)); } node->temporaries->data[2] = op_data->scratch_tensor_index + 2; - TfLiteTensor* output_state_quantized = + TfLiteTensor* activation_state_quantized = GetTemporary(context, node, /*index=*/2); - output_state_quantized->type = kTfLiteUInt8; - output_state_quantized->allocation_type = kTfLiteArenaRw; - if (!TfLiteIntArrayEqual(output_state_quantized->dims, - output_state->dims)) { - TfLiteIntArray* output_state_quantized_size = - TfLiteIntArrayCopy(output_state->dims); - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, output_state_quantized, - output_state_quantized_size)); + activation_state_quantized->type = kTfLiteUInt8; + activation_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(activation_state_quantized->dims, + activation_state->dims)) { + TfLiteIntArray* activation_state_quantized_size = + TfLiteIntArrayCopy(activation_state->dims); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, activation_state_quantized, + activation_state_quantized_size)); } node->temporaries->data[3] = op_data->scratch_tensor_index + 3; TfLiteTensor* cell_state_quantized = @@ -438,7 +485,7 @@ TfLiteStatus EvalFloat( const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, - TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* activation_state, TfLiteTensor* cell_state, TfLiteTensor* output) { const int n_batch = input->dims->data[0]; const int n_input = input->dims->data[1]; @@ -499,7 +546,7 @@ TfLiteStatus EvalFloat( const float* cell_bias_ptr = cell_bias->data.f; const float* output_gate_bias_ptr = output_gate_bias->data.f; - float* output_state_ptr = output_state->data.f; + float* activation_state_ptr = activation_state->data.f; float* cell_state_ptr = cell_state->data.f; float* output_ptr_batch = output->data.f; @@ -512,8 +559,8 @@ TfLiteStatus EvalFloat( cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, - output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch, - cell_scratch, output_gate_scratch, output_ptr_batch); + activation_state_ptr, cell_state_ptr, input_gate_scratch, + forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch); return kTfLiteOk; } @@ -536,9 +583,9 @@ TfLiteStatus EvalHybrid( const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, - TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, - TfLiteTensor* output_state, TfLiteTensor* cell_state, - TfLiteTensor* output) { + TfLiteTensor* activation_state_quantized, + TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state, + TfLiteTensor* cell_state, TfLiteTensor* output) { const int n_batch = input->dims->data[0]; const int n_input = input->dims->data[1]; // n_cell and n_output will be the same size when there is no projection. @@ -639,15 +686,15 @@ TfLiteStatus EvalHybrid( const float* cell_bias_ptr = cell_bias->data.f; const float* output_gate_bias_ptr = output_gate_bias->data.f; - float* output_state_ptr = output_state->data.f; + float* activation_state_ptr = activation_state->data.f; float* cell_state_ptr = cell_state->data.f; float* output_ptr_batch = output->data.f; // Temporary storage for quantized values and scaling factors. int8_t* quantized_input_ptr = reinterpret_cast(input_quantized->data.uint8); - int8_t* quantized_output_state_ptr = - reinterpret_cast(output_state_quantized->data.uint8); + int8_t* quantized_activation_state_ptr = + reinterpret_cast(activation_state_quantized->data.uint8); int8_t* quantized_cell_state_ptr = reinterpret_cast(cell_state_quantized->data.uint8); float* scaling_factors_ptr = scaling_factors->data.f; @@ -672,14 +719,16 @@ TfLiteStatus EvalHybrid( input_gate_scratch, forget_gate_scratch, cell_scratch, output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr, recovered_cell_weights_ptr, quantized_input_ptr, - quantized_output_state_ptr, quantized_cell_state_ptr, output_state_ptr, - cell_state_ptr, output_ptr_batch); + quantized_activation_state_ptr, quantized_cell_state_ptr, + activation_state_ptr, cell_state_ptr, output_ptr_batch); return kTfLiteOk; } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const auto* params = reinterpret_cast(node->builtin_data); + OpData* op_data = reinterpret_cast(node->user_data); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* input_to_input_weights = @@ -723,8 +772,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Index the scratch buffers pointers to the global scratch buffer. TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); - TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); - TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); + TfLiteTensor* activation_state = + &context->tensors[op_data->activation_state_tensor_index]; + TfLiteTensor* cell_state = + &context->tensors[op_data->cell_state_tensor_index]; + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // TODO(mirkov): add a check that weights are all uint8s or all floats. @@ -738,11 +790,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { cell_to_output_weights, input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, projection_weights, projection_bias, params, - scratch_buffer, output_state, cell_state, output); + scratch_buffer, activation_state, cell_state, output); } case kTfLiteUInt8: { TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); - TfLiteTensor* output_state_quantized = + TfLiteTensor* activation_state_quantized = GetTemporary(context, node, /*index=*/2); TfLiteTensor* cell_state_quantized = GetTemporary(context, node, /*index=*/3); @@ -760,8 +812,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, projection_weights, projection_bias, params, scratch_buffer, scaling_factors, prod_scaling_factors, recovered_cell_weights, - input_quantized, output_state_quantized, cell_state_quantized, - output_state, cell_state, output); + input_quantized, activation_state_quantized, cell_state_quantized, + activation_state, cell_state, output); } default: context->ReportError(context, "Type %d is not currently supported.", @@ -794,7 +846,7 @@ enum OutputTensor { }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* op_data = new OpData; + auto* op_data = new OpData(); op_data->kernel_type = kTfLiteLSTMBasicKernel; // `scratch_tensor_index` is unused in this kernel. op_data->scratch_tensor_index = -1; @@ -805,13 +857,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, node->inputs->size == kInputNum); TF_LITE_ENSURE(context, node->outputs->size == kOutputNum); - // Only Float32 is supported currently. - // TODO(ycling): Implement quantize uint8 support. - for (int index = 0; index < node->inputs->size; ++index) { - TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]]; - TF_LITE_ENSURE_EQ(context, tensor->type, kTfLiteFloat32); - } - const TfLiteTensor* input = GetInput(context, node, kInputData); const TfLiteTensor* prev_activation = GetInput(context, node, kInputPrevActivation); @@ -821,15 +866,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, input->dims->size, 2); const int num_batches = input->dims->data[0]; + const int input_depth = input->dims->data[1]; TF_LITE_ENSURE_EQ(context, prev_activation->dims->size, 2); TF_LITE_ENSURE_EQ(context, prev_activation->dims->data[0], num_batches); + const int activation_depth = prev_activation->dims->data[1]; + const int total_depth = input_depth + activation_depth; TF_LITE_ENSURE_EQ(context, weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, weights->dims->data[0], 4 * activation_depth); + TF_LITE_ENSURE_EQ(context, weights->dims->data[1], total_depth); + TF_LITE_ENSURE_EQ(context, bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, bias->dims->data[0], 4 * activation_depth); TF_LITE_ENSURE_EQ(context, prev_state->dims->size, 2); TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches); + TF_LITE_ENSURE_EQ(context, prev_state->dims->data[1], activation_depth); TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation); TfLiteTensor* state_out = GetOutput(context, node, kOutputState); @@ -843,14 +896,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, state_out, TfLiteIntArrayCopy(prev_state->dims))); + TfLiteIntArray* concat_temp_size = TfLiteIntArrayCreate(2); concat_temp_size->data[0] = num_batches; - concat_temp_size->data[1] = weights->dims->data[1]; + concat_temp_size->data[1] = total_depth; TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, concat_temp, concat_temp_size)); TfLiteIntArray* activation_temp_size = TfLiteIntArrayCreate(2); activation_temp_size->data[0] = num_batches; - activation_temp_size->data[1] = weights->dims->data[0]; + activation_temp_size->data[1] = 4 * activation_depth; TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_temp, activation_temp_size)); @@ -876,18 +930,73 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* activation_temp = GetOutput(context, node, kOutputActivationTemp); - optimized_ops::LstmCell( - // Inputs. - GetTensorData(input), GetTensorDims(input), - GetTensorData(prev_activation), GetTensorDims(prev_activation), - GetTensorData(weights), GetTensorDims(weights), - GetTensorData(bias), GetTensorDims(bias), - GetTensorData(prev_state), GetTensorDims(prev_state), - // Outputs. - GetTensorData(state_out), GetTensorDims(state_out), - GetTensorData(activation_out), GetTensorDims(activation_out), - GetTensorData(concat_temp), GetTensorDims(concat_temp), - GetTensorData(activation_temp), GetTensorDims(activation_temp)); + if (input->type == kTfLiteFloat32 && + prev_activation->type == kTfLiteFloat32 && + weights->type == kTfLiteFloat32 && bias->type == kTfLiteFloat32 && + prev_state->type == kTfLiteFloat32 && state_out->type == kTfLiteFloat32 && + activation_out->type == kTfLiteFloat32 && + concat_temp->type == kTfLiteFloat32 && + activation_temp->type == kTfLiteFloat32) { + optimized_ops::LstmCell( + // Inputs. + GetTensorData(input), GetTensorDims(input), + GetTensorData(prev_activation), GetTensorDims(prev_activation), + GetTensorData(weights), GetTensorDims(weights), + GetTensorData(bias), GetTensorDims(bias), + GetTensorData(prev_state), GetTensorDims(prev_state), + // Outputs. + GetTensorData(state_out), GetTensorDims(state_out), + GetTensorData(activation_out), GetTensorDims(activation_out), + GetTensorData(concat_temp), GetTensorDims(concat_temp), + GetTensorData(activation_temp), GetTensorDims(activation_temp)); + } else if (input->type == kTfLiteUInt8 && + prev_activation->type == kTfLiteUInt8 && + weights->type == kTfLiteUInt8 && bias->type == kTfLiteInt32 && + prev_state->type == kTfLiteInt16 && + state_out->type == kTfLiteInt16 && + activation_out->type == kTfLiteUInt8 && + concat_temp->type == kTfLiteUInt8 && + activation_temp->type == kTfLiteInt16) { + gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); + int state_scale_log2_rounded; + if (!CheckedLog2(state_out->params.scale, &state_scale_log2_rounded)) { + context->ReportError( + context, + "The internal state of a LSTM cell must have a power-of-two scale."); + return kTfLiteError; + } + const int state_integer_bits = 15 + state_scale_log2_rounded; + if (state_integer_bits != 4) { + context->ReportError(context, + "The only case of quantized LstmCell currently " + "supported is with StateIntegerBits==4"); + return kTfLiteError; + } + + double real_accum_multiplier = 4096 * bias->params.scale; + int32 accum_multiplier; + int accum_shift; + tflite::QuantizeMultiplier(real_accum_multiplier, &accum_multiplier, + &accum_shift); + optimized_ops::LstmCell<4>( + // Inputs. + GetTensorData(input), GetTensorDims(input), + GetTensorData(prev_activation), GetTensorDims(prev_activation), + GetTensorData(weights), GetTensorDims(weights), + GetTensorData(bias), GetTensorDims(bias), + GetTensorData(prev_state), GetTensorDims(prev_state), + // Outputs. + GetTensorData(state_out), GetTensorDims(state_out), + GetTensorData(activation_out), GetTensorDims(activation_out), + GetTensorData(concat_temp), GetTensorDims(concat_temp), + GetTensorData(activation_temp), GetTensorDims(activation_temp), + weights->params.zero_point, accum_multiplier, accum_shift, + gemm_context); + } else { + context->ReportError(context, + "Unsupported combination of data types for LstmCell"); + return kTfLiteError; + } // TODO(ycling): Investigate if this copy can be avoided with the 5-inputs // LSTM kernel. @@ -901,6 +1010,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace basic void* Init(TfLiteContext* context, const char* buffer, size_t length) { + gemm_support::IncrementUsageCounter(context); + const auto* params = reinterpret_cast(buffer); switch (params->kernel_type) { case kTfLiteLSTMFullKernel: @@ -910,6 +1021,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { } } void Free(TfLiteContext* context, void* buffer) { + gemm_support::DecrementUsageCounter(context); + delete reinterpret_cast(buffer); } diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc index 6da29a4a923f16f7b5ad382f51cfd820783504cd..0266f5fe57e6c60ea19ad5f8de05e879e7da9304 100644 --- a/tensorflow/contrib/lite/kernels/lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // Unit test for TFLite LSTM op. +// +// TODO(alanchiao): add unit test with invalid input dimensions for this and its +// variants. #include #include @@ -97,6 +100,12 @@ class LSTMOpModel : public SingleOpModel { projection_bias_ = AddNullInput(); } + // Adding the 2 input state tensors. + input_activation_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true); + input_cell_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); + output_state_ = AddOutput(TensorType_FLOAT32); cell_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); @@ -227,6 +236,8 @@ class LSTMOpModel : public SingleOpModel { int projection_weights_; int projection_bias_; + int input_activation_state_; + int input_cell_state_; int output_; int output_state_; @@ -352,14 +363,6 @@ class BaseLstmTest : public ::testing::Test { } EXPECT_THAT(lstm->GetOutput(), ElementsAreArray(ArrayFloatNear(expected, tolerance))); - for (int i = 0; i < num_outputs; ++i) { - std::cout << lstm->GetOutput()[i] << ", "; - } - std::cout << std::endl; - for (int i = 0; i < num_outputs; ++i) { - std::cout << expected[i] << ", "; - } - std::cout << std::endl; } } }; diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc b/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc index 0752aa1804722accb1f88910fe013ffd632a4503..fd4d5367c5a6369b5ffeeea30a910262bc0796a9 100644 --- a/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc +++ b/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc @@ -126,10 +126,10 @@ TEST(MaximumOpTest, FloatWithBroadcastTest) { TEST(MaximumOpTest, Int32WithBroadcastTest) { std::initializer_list data1 = {1, 0, -1, -2, 3, 11}; std::initializer_list data2 = {2}; - TestModel(BuiltinOperator_MAXIMUM, {TensorType_INT32, {3, 1, 2}}, + TestModel(BuiltinOperator_MAXIMUM, {TensorType_INT32, {3, 1, 2}}, {TensorType_INT32, {1}}, {TensorType_INT32, {3, 1, 2}}, data1, data2, {2, 2, 2, 2, 3, 11}); - TestModel(BuiltinOperator_MINIMUM, {TensorType_INT32, {3, 1, 2}}, + TestModel(BuiltinOperator_MINIMUM, {TensorType_INT32, {3, 1, 2}}, {TensorType_INT32, {1}}, {TensorType_INT32, {3, 1, 2}}, data1, data2, {1, 0, -1, -2, 2, 2}); } diff --git a/tensorflow/contrib/lite/kernels/mean_test.cc b/tensorflow/contrib/lite/kernels/mean_test.cc deleted file mode 100644 index 79c9957f76fdb994be0a71f2e90b883435de4815..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/kernels/mean_test.cc +++ /dev/null @@ -1,219 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/kernels/test_util.h" -#include "tensorflow/contrib/lite/model.h" - -namespace tflite { -namespace { - -using ::testing::ElementsAreArray; - -class BaseMeanOpModel : public SingleOpModel { - public: - void SetAxis(std::initializer_list data) { PopulateTensor(axis_, data); } - - template - void SetInput(std::initializer_list data) { - PopulateTensor(input_, data); - } - - template - std::vector GetOutput() { - return ExtractVector(output_); - } - - std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); - } - - std::vector GetOutputShape() { return GetTensorShape(output_); } - - int Input() { return input_; } - - protected: - int input_; - int axis_; - int output_; -}; - -// Model for the tests case where axis is a const tensor. -class MeanOpConstModel : public BaseMeanOpModel { - public: - MeanOpConstModel(const TensorData& input, const TensorData& output, - std::initializer_list axis_shape, - std::initializer_list axis, bool keep_dims) { - input_ = AddInput(input); - axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); - output_ = AddOutput(output); - SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions, - CreateMeanOptions(builder_, keep_dims).Union()); - BuildInterpreter({GetShape(input_)}); - } -}; - -// Model for the tests case where axis is a dynamic tensor. -class MeanOpDynamicModel : public BaseMeanOpModel { - public: - MeanOpDynamicModel(const TensorData& input, const TensorData& output, - const TensorData& axis, bool keep_dims) { - input_ = AddInput(input); - axis_ = AddInput(axis); - output_ = AddOutput(output); - SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions, - CreateMeanOptions(builder_, keep_dims).Union()); - BuildInterpreter({GetShape(input_)}); - } -}; - -TEST(ConstFloatMeanOpTest, NotKeepDims) { - std::initializer_list data = { - 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, - 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; - MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, - {4}, {1, 0, -3, -3}, false); - m.SetInput(data); - m.Invoke(); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); -} - -TEST(ConstFloatMeanOpTest, KeepDims) { - std::initializer_list data = { - 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, - 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; - MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, - {2}, {0, 2}, true); - m.SetInput(data); - m.Invoke(); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); - EXPECT_THAT(m.GetOutput(), - ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); -} - -TEST(DynamicFloatMeanOpTest, NotKeepDims) { - std::initializer_list data = { - 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, - 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; - MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, - {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}}, - false); - std::initializer_list axis = {1, 0, -3, -3}; - m.SetAxis(axis); - m.SetInput(data); - m.Invoke(); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); -} - -TEST(DynamicFloatMeanOpTest, KeepDims) { - std::initializer_list data = { - 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, - 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; - MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, - {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, - true); - std::initializer_list axis = {0, 2}; - m.SetAxis(axis); - m.SetInput(data); - m.Invoke(); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); - EXPECT_THAT(m.GetOutput(), - ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); -} - -TEST(DynamicFloatMeanOpTest, Scale) { - std::initializer_list data = {9.527}; - MeanOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}}, - {TensorType_INT32, {1}}, true); - std::initializer_list axis = {0}; - m.SetAxis(axis); - m.SetInput(data); - m.Invoke(); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({9.527}))); -} - -// for quantized Add, the error shouldn't exceed step -float GetTolerance(int min, int max) { return (max - min) / 255.0; } - -TEST(ConstUint8MeanOpTest, NotKeepDims) { - float kQuantizedTolerance = GetTolerance(-1.0, 1.0); - std::initializer_list data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; - MeanOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0}, - {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false); - m.QuantizeAndPopulate(m.Input(), data); - m.Invoke(); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); - EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - {0.4, 0.4}, kQuantizedTolerance))); -} - -TEST(ConstUint8MeanOpTest, KeepDims) { - float kQuantizedTolerance = GetTolerance(-1.0, 1.0); - std::initializer_list data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; - MeanOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0}, - {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true); - m.QuantizeAndPopulate(m.Input(), data); - m.Invoke(); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1})); - EXPECT_THAT( - m.GetDequantizedOutput(), - ElementsAreArray(ArrayFloatNear({0.3, 0.35, 0.55}, kQuantizedTolerance))); -} - -TEST(DynamicUint8MeanOpTest, NotKeepDims) { - float kQuantizedTolerance = GetTolerance(-5.0, 2.0); - std::initializer_list data = {1.3, -4.8, -3.6, 0.24}; - MeanOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0}, - {TensorType_UINT8, {2}, -5.0, 2.0}, - {TensorType_INT32, {1}}, false); - std::initializer_list axis = {1}; - m.SetAxis(axis); - m.QuantizeAndPopulate(m.Input(), data); - m.Invoke(); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); - EXPECT_THAT( - m.GetDequantizedOutput(), - ElementsAreArray(ArrayFloatNear({-1.75, -1.68}, kQuantizedTolerance))); -} - -TEST(DynamicUint8MeanOpTest, KeepDims) { - float kQuantizedTolerance = GetTolerance(-10.0, 12.0); - std::initializer_list data = {11.14, -0.14, 7.423, 0.879}; - MeanOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0}, - {TensorType_UINT8, {2}, -10.0, 12.0}, - {TensorType_INT32, {1}}, true); - std::initializer_list axis = {0}; - m.SetAxis(axis); - m.QuantizeAndPopulate(m.Input(), data); - m.Invoke(); - EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); - EXPECT_THAT( - m.GetDequantizedOutput(), - ElementsAreArray(ArrayFloatNear({9.2815, 0.3695}, kQuantizedTolerance))); -} - -} // namespace -} // namespace tflite - -int main(int argc, char** argv) { - ::tflite::LogToStderr(); - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc index b69a221447db963bcd3a7e6a69f132fe3767bfd1..561e39cfc694349b79a5aa991a489fe07d7f922f 100644 --- a/tensorflow/contrib/lite/kernels/mul.cc +++ b/tensorflow/contrib/lite/kernels/mul.cc @@ -39,6 +39,14 @@ constexpr int kOutputTensor = 0; struct OpData { bool requires_broadcast; + + // Parameters used in the quantized paths where the output is 8bit + int32 output_activation_min; + int32 output_activation_max; + + // Parameters used in all quantized paths + int32_t output_multiplier; + int output_shift; }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -52,6 +60,7 @@ void Free(TfLiteContext* context, void* buffer) { } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); OpData* data = reinterpret_cast(node->user_data); TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); @@ -62,7 +71,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TF_LITE_ENSURE_EQ(context, input1->type, input2->type); - output->type = input2->type; data->requires_broadcast = !HaveSameShapes(input1, input2); @@ -74,75 +82,121 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_size = TfLiteIntArrayCopy(input1->dims); } + if (output->type == kTfLiteUInt8) { + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + + if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) { + double real_multiplier = + input1->params.scale * input2->params.scale / output->params.scale; + QuantizeMultiplierSmallerThanOneExp( + real_multiplier, &data->output_multiplier, &data->output_shift); + } + return context->ResizeTensor(context, output, output_size); } template -void EvalFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteMulParams* params, const OpData* data, - const TfLiteTensor* input1, const TfLiteTensor* input2, - TfLiteTensor* output) { - float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); -#define TF_LITE_MUL(type, opname) \ - type::opname(GetTensorData(input1), GetTensorDims(input1), \ - GetTensorData(input2), GetTensorDims(input2), \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)) - if (kernel_type == kReference) { - if (data->requires_broadcast) { - TF_LITE_MUL(reference_ops, BroadcastMul); +void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params, + const OpData* data, const TfLiteTensor* input1, + const TfLiteTensor* input2, TfLiteTensor* output) { +#define TF_LITE_MUL(type, opname, data_type) \ + data_type output_activation_min, output_activation_max; \ + CalculateActivationRange(params->activation, &output_activation_min, \ + &output_activation_max); \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (output->type == kTfLiteInt32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_MUL(reference_ops, BroadcastMul, int32_t); + } else { + TF_LITE_MUL(reference_ops, Mul, int32_t); + } } else { - TF_LITE_MUL(reference_ops, Mul); + if (data->requires_broadcast) { + TF_LITE_MUL(optimized_ops, BroadcastMul, int32_t); + } else { + TF_LITE_MUL(optimized_ops, Mul, int32_t); + } } - } else { - if (data->requires_broadcast) { - TF_LITE_MUL(optimized_ops, BroadcastMul); + } else if (output->type == kTfLiteFloat32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_MUL(reference_ops, BroadcastMul, float); + } else { + TF_LITE_MUL(reference_ops, Mul, float); + } } else { - TF_LITE_MUL(optimized_ops, Mul); + if (data->requires_broadcast) { + TF_LITE_MUL(optimized_ops, BroadcastMul, float); + } else { + TF_LITE_MUL(optimized_ops, Mul, float); + } } } #undef TF_LITE_MUL } template -void EvalQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteMulParams* params, const OpData* data, - const TfLiteTensor* input1, const TfLiteTensor* input2, - TfLiteTensor* output) { - auto input1_offset = -input1->params.zero_point; - auto input2_offset = -input2->params.zero_point; - auto output_offset = output->params.zero_point; - - int32_t output_multiplier; - int output_shift; - - double real_multiplier = - input1->params.scale * input2->params.scale / output->params.scale; - QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier, - &output_shift); - output_shift *= -1; - - int32 output_activation_min, output_activation_max; - CalculateActivationRangeUint8(params->activation, output, - &output_activation_min, &output_activation_max); - -#define TF_LITE_MUL(type, opname) \ - type::opname(GetTensorData(input1), GetTensorDims(input1), \ - input1_offset, GetTensorData(input2), \ - GetTensorDims(input2), input2_offset, output_offset, \ - output_multiplier, output_shift, output_activation_min, \ - output_activation_max, GetTensorData(output), \ +TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteMulParams* params, const OpData* data, + const TfLiteTensor* input1, + const TfLiteTensor* input2, TfLiteTensor* output) { + if (input1->type == kTfLiteUInt8 && input2->type == kTfLiteUInt8 && + output->type == kTfLiteUInt8) { +#define TF_LITE_MUL(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + -input1->params.zero_point, GetTensorData(input2), \ + GetTensorDims(input2), -input2->params.zero_point, \ + output->params.zero_point, data->output_multiplier, \ + data->output_shift, data->output_activation_min, \ + data->output_activation_max, GetTensorData(output), \ GetTensorDims(output)); - // The quantized version of Mul doesn't support activations, so we - // always use BroadcastMul. - if (kernel_type == kReference) { - TF_LITE_MUL(reference_ops, BroadcastMul); + // The quantized version of Mul doesn't support activations, so we + // always use BroadcastMul. + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops, BroadcastMul4DSlow); + } else { + TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow); + } +#undef TF_LITE_MUL + } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 && + output->type == kTfLiteInt16) { +#define TF_LITE_MUL(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + GetTensorData(output), GetTensorDims(output)); + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops, Mul); + } else { + TF_LITE_MUL(optimized_ops, Mul); + } +#undef TF_LITE_MUL + } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 && + output->type == kTfLiteUInt8) { +#define TF_LITE_MUL(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output->params.zero_point, data->output_activation_min, \ + data->output_activation_max, GetTensorData(output), \ + GetTensorDims(output)); + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops, Mul); + } else { + TF_LITE_MUL(optimized_ops, Mul); + } +#undef TF_LITE_MUL } else { - TF_LITE_MUL(optimized_ops, BroadcastMul); + context->ReportError( + context, "Unsupported combination of input and output types in Mul."); + return kTfLiteError; } -#undef TF_LITE_MUL + return kTfLiteOk; } template @@ -154,15 +208,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - if (output->type == kTfLiteFloat32) { - EvalFloat(context, node, params, data, input1, input2, output); - } else if (output->type == kTfLiteUInt8) { - EvalQuantized(context, node, params, data, input1, input2, - output); + if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) { + EvalMul(context, node, params, data, input1, input2, output); + } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) { + TF_LITE_ENSURE_OK( + context, EvalQuantized(context, node, params, data, input1, + input2, output)); } else { - context->ReportError( - context, "Mul only supports FLOAT32 and quantized UINT8 now, got %d.", - output->type); + context->ReportError(context, + "Mul only supports FLOAT32, INT32 and quantized UINT8 " + "and INT16 now, got %d.", + output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc index f1a30f82634631ba8320421d5b36ffe446f443fa..2807550a6b07f3f9f1f1e3f72acc9882c76d166a 100644 --- a/tensorflow/contrib/lite/kernels/mul_test.cc +++ b/tensorflow/contrib/lite/kernels/mul_test.cc @@ -52,12 +52,22 @@ class FloatMulOpModel : public BaseMulOpModel { std::vector GetOutput() { return ExtractVector(output_); } }; +class IntegerMulOpModel : public BaseMulOpModel { + public: + using BaseMulOpModel::BaseMulOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + // For quantized Mul, the error shouldn't exceed (2*step + step^2). // The param min=-1.0 & max=1.0 is used in the following tests. // The tolerance value is ~0.0157. const float kQuantizedStep = 2.0 / 255.0; const float kQuantizedTolerance = 2.0 * kQuantizedStep + kQuantizedStep * kQuantizedStep; +const float kQuantizedStepInt16 = 2.0 / 32767.0; +const float kQuantizedToleranceInt16 = + 2.0 * kQuantizedStepInt16 + kQuantizedStepInt16 * kQuantizedStepInt16; class QuantizedMulOpModel : public BaseMulOpModel { public: @@ -67,6 +77,11 @@ class QuantizedMulOpModel : public BaseMulOpModel { return Dequantize(ExtractVector(output_), GetScale(output_), GetZeroPoint(output_)); } + + std::vector GetDequantizedOutputInt16() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } }; TEST(FloatMulOpTest, NoActivation) { @@ -125,6 +140,57 @@ TEST(FloatMulOpTest, WithBroadcast) { } } +TEST(IntegerMulOpTest, NoActivation) { + IntegerMulOpModel m({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, + ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-20, 4, 21, 40})); +} + +TEST(IntegerMulOpTest, ActivationRELU_N1_TO_1) { + IntegerMulOpModel m({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, + ActivationFunctionType_RELU_N1_TO_1); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 1, 1, 1})); +} + +TEST(IntegerMulOpTest, VariousInputShapes) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + IntegerMulOpModel m({TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8, 11, 20}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5, 11, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-20, 4, 21, 40, 121, 20})) + << "With shape number " << i; + } +} + +TEST(IntegerMulOpTest, WithBroadcast) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + IntegerMulOpModel m({TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, {}}, // always a scalar + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8, 11, 20}); + m.PopulateTensor(m.input2(), {1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-20, 2, 7, 8, 11, 20}))) + << "With shape number " << i; + } +} + TEST(QuantizedMulOpTest, NoActivation) { QuantizedMulOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, @@ -138,6 +204,38 @@ TEST(QuantizedMulOpTest, NoActivation) { kQuantizedTolerance))); } +TEST(QuantizedMulOpTest, NoActivationInt16) { + const float kMin = -1.f; + const float kMax = 32767.f / 32768.f; + QuantizedMulOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {}, kMin, kMax}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-0.8, 0.2, 0.9, 0.7}); + m.QuantizeAndPopulate(m.input2(), {0.6, 0.4, 0.9, 0.8}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutputInt16(), + ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56}, + kQuantizedToleranceInt16))); +} + +TEST(QuantizedMulOpTest, NoActivationInt16WithUint8Output) { + const float kMinInt16 = -1.f; + const float kMaxInt16 = 32767.f / 32768.f; + const float kMinUint8 = -1.f; + const float kMaxUint8 = 127.f / 128.f; + QuantizedMulOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMinInt16, kMaxInt16}, + {TensorType_INT16, {1, 2, 2, 1}, kMinInt16, kMaxInt16}, + {TensorType_UINT8, {}, kMinUint8, kMaxUint8}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-0.8, 0.2, 0.9, 0.7}); + m.QuantizeAndPopulate(m.input2(), {0.6, 0.4, 0.9, 0.8}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56}, + kQuantizedTolerance))); +} + // for quantized Mul, the error shouldn't exceed 2*step float GetTolerance(int min, int max) { float kQuantizedStep = (max - min) / 255.0; diff --git a/tensorflow/contrib/lite/kernels/neg_test.cc b/tensorflow/contrib/lite/kernels/neg_test.cc index 3c95ac8cc2727fdeff5f39aa2fe30eb6129a6022..3d3594c60bbe1684dff7b1816f5f8a715b1abc60 100644 --- a/tensorflow/contrib/lite/kernels/neg_test.cc +++ b/tensorflow/contrib/lite/kernels/neg_test.cc @@ -58,9 +58,9 @@ TEST(NegOpModel, NegFloat) { TEST(NegOpModel, NegInt32) { NegOpModel m({TensorType_INT32, {2, 3}}, {TensorType_INT32, {2, 3}}); - m.SetInput({-2, -1, 0, 1, 2, 3}); + m.SetInput({-2, -1, 0, 1, 2, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 1, 0, -1, -2, -3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 1, 0, -1, -2, -3})); } TEST(NegOpModel, NegInt64) { diff --git a/tensorflow/contrib/lite/kernels/one_hot.cc b/tensorflow/contrib/lite/kernels/one_hot.cc new file mode 100644 index 0000000000000000000000000000000000000000..9ff3dca932d4284321b299cfc79571c43fce7155 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/one_hot.cc @@ -0,0 +1,199 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace one_hot { + +constexpr int kIndicesTensor = 0; +constexpr int kDepthTensor = 1; +constexpr int kOnValueTensor = 2; +constexpr int kOffValueTensor = 3; +constexpr int kOutputTensor = 0; + +// Convenience utility for destructuring a node into the appropriate tensors and +// data for the op. Note that this destructuring is quite cheap, so we can avoid +// allocating op-specific, persistent data on the heap. +struct OneHotContext { + OneHotContext(TfLiteContext* context, TfLiteNode* node) { + indices = GetInput(context, node, kIndicesTensor); + depth = GetInput(context, node, kDepthTensor); + on_value = GetInput(context, node, kOnValueTensor); + off_value = GetInput(context, node, kOffValueTensor); + output = GetOutput(context, node, kOutputTensor); + + const auto* params = + reinterpret_cast(node->builtin_data); + const int indices_dims = indices->dims->size; + axis = (params->axis == -1) ? indices_dims : params->axis; + output_dims = indices_dims + 1; + dtype = on_value->type; + } + + const TfLiteTensor* indices; + const TfLiteTensor* depth; + const TfLiteTensor* on_value; + const TfLiteTensor* off_value; + TfLiteTensor* output; + int axis; + int output_dims; + TfLiteType dtype; +}; + +template +void OneHotComputeImpl(const OneHotContext& op_context) { + // prefix_dim_size == # of elements before the axis + // depth == # of elements per axis + // suffix_dim_size == # of elements after the axis + int prefix_dim_size = 1; + for (int i = 0; i < op_context.axis; ++i) { + prefix_dim_size *= op_context.indices->dims->data[i]; + } + const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size; + const int depth = *op_context.depth->data.i32; + + const T on_value = *GetTensorData(op_context.on_value); + const T off_value = *GetTensorData(op_context.off_value); + + // View the indices as a matrix of size: + // prefix_dim_size x suffix_dim_size + // View the output as a matrix of size: + // prefix_dim_size x depth x suffix_dim_size + // Then the output is: + // output(i, j, k) == (indices(i, k) == j) ? on : off + T* output = GetTensorData(op_context.output); + const TI* indices = GetTensorData(op_context.indices); + for (int i = 0; i < prefix_dim_size; ++i) { + for (int j = 0; j < depth; ++j) { + for (int k = 0; k < suffix_dim_size; ++k, ++output) { + *output = static_cast(indices[i * suffix_dim_size + k]) == j + ? on_value + : off_value; + } + } + } +} + +template +void OneHotCompute(const OneHotContext& op_context) { + if (op_context.indices->type == kTfLiteInt64) { + OneHotComputeImpl(op_context); + } else { + OneHotComputeImpl(op_context); + } +} + +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + const OneHotContext& op_context) { + TF_LITE_ENSURE(context, *op_context.depth->data.i32 >= 0); + TfLiteIntArray* output_size = TfLiteIntArrayCreate(op_context.output_dims); + for (int i = 0; i < op_context.output_dims; ++i) { + if (i < op_context.axis) { + output_size->data[i] = op_context.indices->dims->data[i]; + } else if (i == op_context.axis) { + output_size->data[i] = *op_context.depth->data.i32; + } else { + output_size->data[i] = op_context.indices->dims->data[i - 1]; + } + } + return context->ResizeTensor(context, op_context.output, output_size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + OneHotContext op_context{context, node}; + switch (op_context.dtype) { + // TODO(b/111744875): Support uint8 and quantization. + case kTfLiteFloat32: + case kTfLiteInt16: + case kTfLiteInt32: + case kTfLiteInt64: + case kTfLiteBool: + op_context.output->type = op_context.dtype; + break; + default: + context->ReportError(context, "Unknown output data type: %d", + op_context.dtype); + return kTfLiteError; + } + + TF_LITE_ENSURE(context, op_context.indices->type == kTfLiteInt32 || + op_context.indices->type == kTfLiteInt64); + TF_LITE_ENSURE(context, op_context.axis >= 0 && + op_context.axis < op_context.output_dims); + TF_LITE_ENSURE_EQ(context, NumElements(op_context.depth), 1); + TF_LITE_ENSURE_EQ(context, NumElements(op_context.on_value), 1); + TF_LITE_ENSURE_EQ(context, NumElements(op_context.off_value), 1); + TF_LITE_ENSURE_EQ(context, op_context.on_value->type, op_context.dtype); + TF_LITE_ENSURE_EQ(context, op_context.off_value->type, op_context.dtype); + + if (!IsConstantTensor(op_context.depth)) { + SetTensorToDynamic(op_context.output); + return kTfLiteOk; + } + + return ResizeOutputTensor(context, op_context); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OneHotContext op_context{context, node}; + + if (IsDynamicTensor(op_context.output)) { + ResizeOutputTensor(context, op_context); + } + + switch (op_context.output->type) { + case kTfLiteFloat32: + OneHotCompute(op_context); + break; + case kTfLiteInt32: + OneHotCompute(op_context); + break; + case kTfLiteInt64: + OneHotCompute(op_context); + break; + case kTfLiteBool: + OneHotCompute(op_context); + break; + default: + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace one_hot + +TfLiteRegistration* Register_ONE_HOT() { + static TfLiteRegistration r = { + nullptr, + nullptr, + one_hot::Prepare, + one_hot::Eval, + }; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/one_hot_test.cc b/tensorflow/contrib/lite/kernels/one_hot_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6b604ec7a7f86b333805d91a95cb5054f0257ae4 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/one_hot_test.cc @@ -0,0 +1,182 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template +class OneHotOpModel : public SingleOpModel { + public: + OneHotOpModel(std::initializer_list input_shape, int depth_value, + TensorType dtype, int axis = -1, T on_value = 1, + T off_value = 0, TensorType indices_type = TensorType_INT32) { + indices_ = AddInput(indices_type); + int depth = AddInput(TensorType_INT32); + int on = AddInput(dtype); + int off = AddInput(dtype); + output_ = AddOutput(dtype); + SetBuiltinOp(BuiltinOperator_ONE_HOT, BuiltinOptions_OneHotOptions, + CreateOneHotOptions(builder_, axis).Union()); + BuildInterpreter({input_shape}); + + PopulateTensor(depth, {depth_value}); + PopulateTensor(on, {on_value}); + PopulateTensor(off, {off_value}); + } + + template + void SetIndices(std::initializer_list data) { + PopulateTensor(indices_, data); + } + + TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); } + + int32_t GetOutputSize() { return GetTensorSize(output_); } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int indices_; + int output_; +}; + +TEST(OneHotOpTest, BasicFloat) { + const int depth = 3; + OneHotOpModel model({3}, depth, TensorType_FLOAT32); + model.SetIndices({0, 1, 2}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3})); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f})); +} + +TEST(OneHotOpTest, BasicInt) { + const int depth = 3; + OneHotOpModel model({3}, depth, TensorType_INT32); + model.SetIndices({0, 1, 2}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 1, 0, 0, 0, 1})); +} + +TEST(OneHotOpTest, BasicBool) { + const int depth = 3; + OneHotOpModel model({3}, depth, TensorType_BOOL); + model.SetIndices({0, 1, 2}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3})); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({true, false, false, false, true, false, false, + false, true})); +} + +TEST(OneHotOpTest, SmallDepth) { + const int depth = 1; + OneHotOpModel model({3}, depth, TensorType_INT32); + model.SetIndices({0, 1, 2}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 1})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0})); +} + +TEST(OneHotOpTest, BigDepth) { + const int depth = 4; + OneHotOpModel model({2}, depth, TensorType_INT32); + model.SetIndices({0, 1}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 0, 1, 0, 0})); +} + +TEST(OneHotOpTest, OnOffValues) { + const int depth = 3; + const int axis = -1; + const int on = 5; + const int off = 0; + OneHotOpModel model({4}, depth, TensorType_INT32, axis, on, off); + model.SetIndices({0, 2, -1, 1}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 3})); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({5, 0, 0, 0, 0, 5, 0, 0, 0, 0, 5, 0})); +} + +TEST(OneHotOpTest, ZeroAxis) { + const int depth = 3; + const int axis = 0; + const int on = 5; + const int off = 0; + OneHotOpModel model({4}, depth, TensorType_INT32, axis, on, off); + model.SetIndices({0, 2, -1, 1}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 4})); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({5, 0, 0, 0, 0, 0, 0, 5, 0, 5, 0, 0})); +} + +TEST(OneHotOpTest, MultiDimensionalIndices) { + const int depth = 3; + const int axis = -1; + const float on = 2; + const float off = 0; + OneHotOpModel model({2, 2}, depth, TensorType_FLOAT32, axis, on, off); + model.SetIndices({0, 2, 1, -1}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 3})); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({2, 0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0})); +} + +TEST(OneHotOpTest, Int64Indices) { + const int depth = 3; + const int axis = -1; + const int on = 1; + const int off = 0; + OneHotOpModel model({3}, depth, TensorType_INT32, axis, on, off, + TensorType_INT64); + std::initializer_list indices = {0, 1, 2}; + model.SetIndices(indices); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 1, 0, 0, 0, 1})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc index bcad58406af1cdd466e410a06011641692194be4..1c728a473326564a85a5e7d3d72718265979e29a 100644 --- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc +++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc @@ -95,6 +95,12 @@ class LSTMOpModel : public SingleOpModel { projection_bias_ = AddNullInput(); } + // Adding the 2 input state tensors. + input_activation_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true); + input_cell_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); + output_state_ = AddOutput(TensorType_FLOAT32); cell_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); @@ -228,6 +234,8 @@ class LSTMOpModel : public SingleOpModel { int projection_weights_; int projection_bias_; + int input_activation_state_; + int input_cell_state_; int output_; int output_state_; diff --git a/tensorflow/contrib/lite/kernels/pack.cc b/tensorflow/contrib/lite/kernels/pack.cc new file mode 100644 index 0000000000000000000000000000000000000000..bb3416f6a6ca60250f137986e479e8f1085e2558 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pack.cc @@ -0,0 +1,131 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace pack { +namespace { + +constexpr int kOutputTensor = 0; + +// Op data for pack op. +struct OpData { + int values_count; + int axis; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->axis = 0; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const OpData* data = reinterpret_cast(node->builtin_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* input0 = GetInput(context, node, 0); + TF_LITE_ENSURE(context, NumDimensions(input0) < 4); + TF_LITE_ENSURE(context, NumDimensions(input0) >= data->axis); + // TODO(renjieliu): Support negative axis. + TF_LITE_ENSURE(context, data->axis >= 0); + if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32) { + context->ReportError(context, + "Currently pack only supports int32 and float32."); + return kTfLiteError; + } + // Make sure all inputs have the same shape and type. + for (int i = 1; i < data->values_count; ++i) { + const TfLiteTensor* input = GetInput(context, node, i); + TF_LITE_ENSURE(context, HaveSameShapes(input0, input)); + TF_LITE_ENSURE_EQ(context, input0->type, input->type); + } + + // Resize output. rank R will become rank R + 1 + const int dimension_size = NumDimensions(input0) + 1; + const TfLiteIntArray* input_shape = input0->dims; + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dimension_size); + int i = 0; + for (int index = 0; index < dimension_size; ++index) { + if (index == data->axis) { + output_shape->data[index] = data->values_count; + } else { + output_shape->data[index] = input_shape->data[i++]; + } + } + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_EQ(context, output->type, input0->type); + + return context->ResizeTensor(context, output, output_shape); +} + +template +void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output, + int values_count, int axis) { + VectorOfTensors all_inputs(*context, *node->inputs); + reference_ops::Pack(RemapDim(NumDimensions(output), axis), + all_inputs.data(), all_inputs.dims(), values_count, + GetTensorData(output), GetTensorDims(output)); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const OpData* data = reinterpret_cast(node->builtin_data); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + switch (output->type) { + case kTfLiteFloat32: { + PackImpl(context, node, output, data->values_count, data->axis); + break; + } + case kTfLiteInt32: { + PackImpl(context, node, output, data->values_count, data->axis); + break; + } + default: { + context->ReportError(context, + "Currently pack only supports int32 and float32."); + return kTfLiteError; + } + } + + return kTfLiteOk; +} + +} // namespace +} // namespace pack + +TfLiteRegistration* Register_PACK() { + static TfLiteRegistration r = {pack::Init, pack::Free, pack::Prepare, + pack::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/pack_test.cc b/tensorflow/contrib/lite/kernels/pack_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..485a50ad3ac493fd02f619f7d7cbaf10d3a6aff0 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pack_test.cc @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +template +class PackOpModel : public SingleOpModel { + public: + PackOpModel(const TensorData& input_template, int axis, int values_count) { + std::vector> all_input_shapes; + for (int i = 0; i < values_count; ++i) { + all_input_shapes.push_back(input_template.shape); + AddInput(input_template); + } + output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min, + input_template.max}); + SetBuiltinOp(BuiltinOperator_PACK, BuiltinOptions_PackOptions, + CreatePackOptions(builder_, values_count, axis).Union()); + BuildInterpreter(all_input_shapes); + } + + void SetInput(int index, std::initializer_list data) { + PopulateTensor(index, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int output_; +}; + +TEST(PackOpTest, FloatThreeInputs) { + PackOpModel model({TensorType_FLOAT32, {2}}, 0, 3); + model.SetInput(0, {1, 4}); + model.SetInput(1, {2, 5}); + model.SetInput(2, {3, 6}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6})); +} + +TEST(PackOpTest, FloatThreeInputsDifferentAxis) { + PackOpModel model({TensorType_FLOAT32, {2}}, 1, 3); + model.SetInput(0, {1, 4}); + model.SetInput(1, {2, 5}); + model.SetInput(2, {3, 6}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(PackOpTest, FloatMultilDimensions) { + PackOpModel model({TensorType_FLOAT32, {2, 3}}, 1, 2); + model.SetInput(0, {1, 2, 3, 4, 5, 6}); + model.SetInput(1, {7, 8, 9, 10, 11, 12}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 3)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); +} + +TEST(PackOpTest, IntThreeInputs) { + PackOpModel model({TensorType_INT32, {2}}, 0, 3); + model.SetInput(0, {1, 4}); + model.SetInput(1, {2, 5}); + model.SetInput(2, {3, 6}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6})); +} + +TEST(PackOpTest, IntThreeInputsDifferentAxis) { + PackOpModel model({TensorType_INT32, {2}}, 1, 3); + model.SetInput(0, {1, 4}); + model.SetInput(1, {2, 5}); + model.SetInput(2, {3, 6}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(PackOpTest, IntMultilDimensions) { + PackOpModel model({TensorType_INT32, {2, 3}}, 1, 2); + model.SetInput(0, {1, 2, 3, 4, 5, 6}); + model.SetInput(1, {7, 8, 9, 10, 11, 12}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 3)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc index 83668cb4ca87e9eb53ab4ba9e88f91e3315594de..4be8c243c17c533e8c7d5aa7bb50c9d790b06995 100644 --- a/tensorflow/contrib/lite/kernels/pad.cc +++ b/tensorflow/contrib/lite/kernels/pad.cc @@ -128,7 +128,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // TODO(nupurgarg): Change kernel implementation to use padding arrays in // forward order (depth, width, height, batch). // Build paddings in order of int[] = {batch, height, width, depth} to match - // kernel implementation of Pad in referenced_ops.h and optimized_ops.h. + // kernel implementation of Pad in reference_ops.h and optimized_ops.h. for (int idx = op_context.dims - 1; idx >= 0; --idx) { before_padding.push_back(paddings_data[idx * 2]); after_padding.push_back(paddings_data[idx * 2 + 1]); diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc index 311e9b8399726d758182e1f084a890d6f10e57ce..29a5be068368365e67ad0653b775afe1e976f23a 100644 --- a/tensorflow/contrib/lite/kernels/pooling.cc +++ b/tensorflow/contrib/lite/kernels/pooling.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -80,24 +79,24 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { // Matching GetWindowedOutputSize in TensorFlow. auto padding = params->padding; - auto computeOutSize = [padding](int imageSize, int filterSize, - int stride) -> int { + auto compute_out_size = [padding](int image_size, int filter_size, + int stride) -> int { return padding == kTfLitePaddingSame - ? (imageSize + stride - 1) / stride + ? (image_size + stride - 1) / stride : padding == kTfLitePaddingValid - ? (imageSize - filterSize + stride) / stride + ? (image_size - filter_size + stride) / stride : 0; }; - int outWidth = - computeOutSize(width, params->filter_width, params->stride_width); - int outHeight = - computeOutSize(height, params->filter_height, params->stride_height); + int out_width = + compute_out_size(width, params->filter_width, params->stride_width); + int out_height = + compute_out_size(height, params->filter_height, params->stride_height); data->padding.height = ComputePadding(params->stride_height, 1, height, - params->filter_height, outHeight); + params->filter_height, out_height); data->padding.width = ComputePadding(params->stride_width, 1, width, - params->filter_width, outWidth); + params->filter_width, out_width); if (input->type == kTfLiteUInt8) { if (pool_type == kAverage || pool_type == kMax) { @@ -111,12 +110,12 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { } } - TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4); - outputSize->data[0] = batches; - outputSize->data[1] = outHeight; - outputSize->data[2] = outWidth; - outputSize->data[3] = channels_out; - return context->ResizeTensor(context, output, outputSize); + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = batches; + output_size->data[1] = out_height; + output_size->data[2] = out_width; + output_size->data[3] = channels_out; + return context->ResizeTensor(context, output, output_size); } template @@ -124,14 +123,21 @@ void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node, TfLitePoolParams* params, OpData* data, const TfLiteTensor* input, TfLiteTensor* output) { float activation_min, activation_max; - CalculateActivationRangeFloat(params->activation, &activation_min, - &activation_max); -#define TF_LITE_AVERAGE_POOL(type) \ - type::AveragePool( \ - GetTensorData(input), GetTensorDims(input), params->stride_width, \ - params->stride_height, data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, activation_min, \ - activation_max, GetTensorData(output), GetTensorDims(output)) + CalculateActivationRange(params->activation, &activation_min, + &activation_max); +#define TF_LITE_AVERAGE_POOL(type) \ + tflite::PoolParams op_params; \ + op_params.stride_height = params->stride_height; \ + op_params.stride_width = params->stride_width; \ + op_params.filter_height = params->filter_height; \ + op_params.filter_width = params->filter_width; \ + op_params.padding_values.height = data->padding.height; \ + op_params.padding_values.width = data->padding.width; \ + op_params.float_activation_min = activation_min; \ + op_params.float_activation_max = activation_max; \ + type::AveragePool(op_params, GetTensorShape(input), \ + GetTensorData(input), GetTensorShape(output), \ + GetTensorData(output)) if (kernel_type == kReference) { TF_LITE_AVERAGE_POOL(reference_ops); } else { @@ -148,13 +154,19 @@ void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node, int32_t activation_max; CalculateActivationRangeUint8(params->activation, output, &activation_min, &activation_max); -#define TF_LITE_AVERAGE_POOL(type) \ - type::AveragePool(GetTensorData(input), GetTensorDims(input), \ - params->stride_width, params->stride_height, \ - data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, \ - activation_min, activation_max, \ - GetTensorData(output), GetTensorDims(output)) +#define TF_LITE_AVERAGE_POOL(type) \ + tflite::PoolParams op_params; \ + op_params.stride_height = params->stride_height; \ + op_params.stride_width = params->stride_width; \ + op_params.filter_height = params->filter_height; \ + op_params.filter_width = params->filter_width; \ + op_params.padding_values.height = data->padding.height; \ + op_params.padding_values.width = data->padding.width; \ + op_params.quantized_activation_min = activation_min; \ + op_params.quantized_activation_max = activation_max; \ + type::AveragePool(op_params, GetTensorShape(input), \ + GetTensorData(input), GetTensorShape(output), \ + GetTensorData(output)) if (kernel_type == kReference) { TF_LITE_AVERAGE_POOL(reference_ops); } else { @@ -168,14 +180,20 @@ void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node, TfLitePoolParams* params, OpData* data, const TfLiteTensor* input, TfLiteTensor* output) { float activation_min, activation_max; - CalculateActivationRangeFloat(params->activation, &activation_min, - &activation_max); + CalculateActivationRange(params->activation, &activation_min, + &activation_max); #define TF_LITE_MAX_POOL(type) \ - type::MaxPool( \ - GetTensorData(input), GetTensorDims(input), params->stride_width, \ - params->stride_height, data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, activation_min, \ - activation_max, GetTensorData(output), GetTensorDims(output)) + tflite::PoolParams op_params; \ + op_params.stride_height = params->stride_height; \ + op_params.stride_width = params->stride_width; \ + op_params.filter_height = params->filter_height; \ + op_params.filter_width = params->filter_width; \ + op_params.padding_values.height = data->padding.height; \ + op_params.padding_values.width = data->padding.width; \ + op_params.float_activation_min = activation_min; \ + op_params.float_activation_max = activation_max; \ + type::MaxPool(op_params, GetTensorShape(input), GetTensorData(input), \ + GetTensorShape(output), GetTensorData(output)) if (kernel_type == kReference) { TF_LITE_MAX_POOL(reference_ops); } else { @@ -192,13 +210,19 @@ void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node, int32_t activation_max; CalculateActivationRangeUint8(params->activation, output, &activation_min, &activation_max); -#define TF_LITE_MAX_POOL(type) \ - type::MaxPool(GetTensorData(input), GetTensorDims(input), \ - params->stride_width, params->stride_height, \ - data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, activation_min, \ - activation_max, GetTensorData(output), \ - GetTensorDims(output)) +#define TF_LITE_MAX_POOL(type) \ + tflite::PoolParams op_params; \ + op_params.stride_height = params->stride_height; \ + op_params.stride_width = params->stride_width; \ + op_params.filter_height = params->filter_height; \ + op_params.filter_width = params->filter_width; \ + op_params.padding_values.height = data->padding.height; \ + op_params.padding_values.width = data->padding.width; \ + op_params.quantized_activation_min = activation_min; \ + op_params.quantized_activation_max = activation_max; \ + type::MaxPool(op_params, GetTensorShape(input), \ + GetTensorData(input), GetTensorShape(output), \ + GetTensorData(output)) if (kernel_type == kReference) { TF_LITE_MAX_POOL(reference_ops); } else { @@ -212,14 +236,20 @@ void L2EvalFloat(TfLiteContext* context, TfLiteNode* node, TfLitePoolParams* params, OpData* data, const TfLiteTensor* input, TfLiteTensor* output) { float activation_min, activation_max; - CalculateActivationRangeFloat(params->activation, &activation_min, - &activation_max); -#define TF_LITE_L2_POOL(type) \ - type::L2Pool( \ - GetTensorData(input), GetTensorDims(input), params->stride_width, \ - params->stride_height, data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, activation_min, \ - activation_max, GetTensorData(output), GetTensorDims(output)) + CalculateActivationRange(params->activation, &activation_min, + &activation_max); +#define TF_LITE_L2_POOL(type) \ + tflite::PoolParams op_params; \ + op_params.stride_height = params->stride_height; \ + op_params.stride_width = params->stride_width; \ + op_params.filter_height = params->filter_height; \ + op_params.filter_width = params->filter_width; \ + op_params.padding_values.height = data->padding.height; \ + op_params.padding_values.width = data->padding.width; \ + op_params.float_activation_min = activation_min; \ + op_params.float_activation_max = activation_max; \ + type::L2Pool(op_params, GetTensorShape(input), GetTensorData(input), \ + GetTensorShape(output), GetTensorData(output)) if (kernel_type == kReference) { TF_LITE_L2_POOL(reference_ops); } else { diff --git a/tensorflow/contrib/lite/kernels/pow.cc b/tensorflow/contrib/lite/kernels/pow.cc new file mode 100644 index 0000000000000000000000000000000000000000..4a539c47a8fbe392e0e6542ab8ffb9065b550485 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pow.cc @@ -0,0 +1,143 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace pow { +namespace { + +// Input/output tensor index. +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +// Op data for pow op. +struct OpData { + bool requires_broadcast; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->requires_broadcast = false; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + OpData* data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, input1->type, input2->type); + + const TfLiteType type = input1->type; + if (type != kTfLiteInt32 && type != kTfLiteFloat32) { + context->ReportError(context, "Unsupported data type %d.", type); + return kTfLiteError; + } + output->type = type; + + data->requires_broadcast = !HaveSameShapes(input1, input2); + + TfLiteIntArray* output_size = nullptr; + if (data->requires_broadcast) { + TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( + context, input1, input2, &output_size)); + } else { + output_size = TfLiteIntArrayCopy(input1->dims); + } + + return context->ResizeTensor(context, output, output_size); +} + +template +void PowImpl(const TfLiteTensor* input1, const TfLiteTensor* input2, + TfLiteTensor* output, bool requires_broadcast) { + if (requires_broadcast) { + reference_ops::BroadcastPow(GetTensorData(input1), GetTensorDims(input1), + GetTensorData(input2), GetTensorDims(input2), + GetTensorData(output), + GetTensorDims(output)); + } else { + reference_ops::Pow(GetTensorData(input1), GetTensorDims(input1), + GetTensorData(input2), GetTensorDims(input2), + GetTensorData(output), GetTensorDims(output)); + } +} + +TfLiteStatus CheckValue(TfLiteContext* context, const TfLiteTensor* input) { + const int64_t num_elements = NumElements(input); + const int32_t* data = GetTensorData(input); + for (int i = 0; i < num_elements; ++i) { + if (data[i] < 0) { + context->ReportError(context, + "POW does not support negative value for int32."); + return kTfLiteError; + } + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (output->type) { + case kTfLiteInt32: { + // TensorFlow does not support negative for int32. + TF_LITE_ENSURE_OK(context, CheckValue(context, input2)); + PowImpl(input1, input2, output, data->requires_broadcast); + break; + } + case kTfLiteFloat32: { + PowImpl(input1, input2, output, data->requires_broadcast); + break; + } + default: { + context->ReportError(context, "Unsupported data type: %d", output->type); + return kTfLiteError; + } + } + return kTfLiteOk; +} + +} // namespace +} // namespace pow + +TfLiteRegistration* Register_POW() { + static TfLiteRegistration r = {pow::Init, pow::Free, pow::Prepare, pow::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/pow_test.cc b/tensorflow/contrib/lite/kernels/pow_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..74b3aef5bd39d8bdb6375f24bd00d793889deef8 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pow_test.cc @@ -0,0 +1,117 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +template +class PowOpModel : public SingleOpModel { + public: + PowOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_POW, BuiltinOptions_PowOptions, + CreatePowOptions(builder_).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input1_; + int input2_; + int output_; +}; + +TEST(PowOpModel, Simple) { + PowOpModel model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {}}); + model.PopulateTensor(model.input1(), {12, 2, 7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 1}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), ElementsAre(12, 4, 343, 8)); +} + +TEST(PowOpModel, NegativeAndZeroValue) { + PowOpModel model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {}}); + model.PopulateTensor(model.input1(), {0, 2, -7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 0}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), ElementsAre(0, 4, -343, 1)); +} + +TEST(PowOpModel, Float) { + PowOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}); + model.PopulateTensor(model.input1(), {0.3, 0.4, 0.7, 5.8}); + model.PopulateTensor(model.input2(), {0.5, 2.7, 3.1, 3.2}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.5477226, 0.08424846, 0.33098164, 277.313}, 1e-3))); +} + +TEST(PowOpModel, NegativeFloatTest) { + PowOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}); + model.PopulateTensor(model.input1(), {0.3, 0.4, 0.7, 5.8}); + model.PopulateTensor(model.input2(), {0.5, -2.7, 3.1, -3.2}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.5477226, 11.869653, 0.33098164, 0.003606}, 1e-3))); +} + +TEST(PowOpModel, BroadcastTest) { + PowOpModel model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1}}, {TensorType_INT32, {}}); + model.PopulateTensor(model.input1(), {12, 2, 7, 8}); + model.PopulateTensor(model.input2(), {4}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), ElementsAre(20736, 16, 2401, 4096)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/mean.cc b/tensorflow/contrib/lite/kernels/reduce.cc similarity index 53% rename from tensorflow/contrib/lite/kernels/mean.cc rename to tensorflow/contrib/lite/kernels/reduce.cc index 03e5db24de3f3c2d4e17df21bc0b592a02078d6b..e99f67c7258c555903069dff67a86a3703249c7c 100644 --- a/tensorflow/contrib/lite/kernels/mean.cc +++ b/tensorflow/contrib/lite/kernels/reduce.cc @@ -25,21 +25,21 @@ limitations under the License. namespace tflite { namespace ops { namespace builtin { -namespace mean { +namespace reduce { -// This file has reference implementation of Mean. +// This file has reference implementation of reduce_* operators. enum KernelType { kReference, }; -struct MeanContext { - MeanContext(TfLiteContext* context, TfLiteNode* node) { - params = reinterpret_cast(node->builtin_data); +struct OpContext { + OpContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast(node->builtin_data); input = GetInput(context, node, 0); axis = GetInput(context, node, 1); output = GetOutput(context, node, 0); } - TfLiteMeanParams* params; + TfLiteReducerParams* params; const TfLiteTensor* input; const TfLiteTensor* axis; TfLiteTensor* output; @@ -58,7 +58,7 @@ void Free(TfLiteContext* context, void* buffer) { } // Resizes the temp tensor that stores resolved axis. -TfLiteStatus ResizeTempAxis(TfLiteContext* context, MeanContext* op_context, +TfLiteStatus ResizeTempAxis(TfLiteContext* context, OpContext* op_context, TfLiteTensor* resolved_axis) { TfLiteIntArray* axis_size = TfLiteIntArrayCreate(1); axis_size->data[0] = static_cast(NumElements(op_context->axis)); @@ -66,7 +66,7 @@ TfLiteStatus ResizeTempAxis(TfLiteContext* context, MeanContext* op_context, } // Resizes the temp tensor that stores temp sum of reduced elements. -TfLiteStatus ResizeTempSum(TfLiteContext* context, MeanContext* op_context, +TfLiteStatus ResizeTempSum(TfLiteContext* context, OpContext* op_context, TfLiteTensor* temp_sum) { TfLiteIntArray* size = TfLiteIntArrayCreate(1); size->data[0] = static_cast(NumElements(op_context->output)); @@ -74,11 +74,14 @@ TfLiteStatus ResizeTempSum(TfLiteContext* context, MeanContext* op_context, } // Resizes output array based on the input size and resolved axis. -TfLiteStatus ResizeOutputTensor(TfLiteContext* context, - MeanContext* op_context) { +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, OpContext* op_context) { size_t num_axis = NumElements(op_context->axis); const TfLiteIntArray* input_dims = op_context->input->dims; int input_num_dims = NumDimensions(op_context->input); + if (input_num_dims == 0) { + return context->ResizeTensor(context, op_context->output, + TfLiteIntArrayCreate(0)); + } const int* axis = GetTensorData(op_context->axis); if (op_context->params->keep_dims) { TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_num_dims); @@ -140,7 +143,7 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, // Initializes temp tensors to store index and resolved axis. TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, - MeanContext* op_context) { + OpContext* op_context) { // Creates a temp index to iterate through input data. int* scratch_tensor_index = reinterpret_cast(node->user_data); TfLiteIntArrayFree(node->temporaries); @@ -180,33 +183,44 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, return kTfLiteOk; } -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { +TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - MeanContext op_context(context, node); + OpContext op_context(context, node); TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context)); TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); - TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2); // Leaves work to Eval if axis is not constant; else resizes output. if (!IsConstantTensor(op_context.axis)) { SetTensorToDynamic(op_context.output); SetTensorToDynamic(resolved_axis); - SetTensorToDynamic(temp_sum); return kTfLiteOk; } resolved_axis->allocation_type = kTfLiteArenaRw; TF_LITE_ENSURE_OK(context, ResizeTempAxis(context, &op_context, resolved_axis)); TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + return kTfLiteOk; +} + +TfLiteStatus PrepareMean(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_OK(context, PrepareSimple(context, node)); + + // reduce_mean requires a buffer to store intermediate sum result. + OpContext op_context(context, node); + TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2); + if (!IsConstantTensor(op_context.axis)) { + SetTensorToDynamic(temp_sum); + return kTfLiteOk; + } temp_sum->allocation_type = kTfLiteArenaRw; return ResizeTempSum(context, &op_context, temp_sum); } template -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - MeanContext op_context(context, node); +TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); int num_axis = static_cast(NumElements(op_context.axis)); TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0); TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); @@ -255,16 +269,186 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { #undef TF_LITE_MEAN return kTfLiteOk; } -} // namespace mean + +template +TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); + int num_axis = static_cast(NumElements(op_context.axis)); + TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0); + TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, + ResizeTempAxis(context, &op_context, resolved_axis)); + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + +#define TF_LITE_SUM(kernel_type, data_type) \ + kernel_type::Sum<>( \ + GetTensorData(op_context.input), \ + op_context.input->dims->data, op_context.input->dims->size, \ + GetTensorData(op_context.output), \ + op_context.output->dims->data, op_context.output->dims->size, \ + GetTensorData(op_context.axis), num_axis, \ + op_context.params->keep_dims, GetTensorData(temp_index), \ + GetTensorData(resolved_axis)) + + if (kernel_type == kReference) { + switch (op_context.input->type) { + case kTfLiteFloat32: + TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, float)); + break; + case kTfLiteInt32: + TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int)); + break; + case kTfLiteInt64: + TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int64_t)); + break; + case kTfLiteUInt8: + TF_LITE_ENSURE_EQ(context, op_context.input->params.scale, + op_context.output->params.scale); + TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point, + op_context.output->params.zero_point); + TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, uint8_t)); + break; + default: + return kTfLiteError; + } + } +#undef TF_LITE_SUM + return kTfLiteOk; +} + +template +TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); + int64_t num_axis = NumElements(op_context.axis); + TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0); + TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, + ResizeTempAxis(context, &op_context, resolved_axis)); + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + +#define TF_LITE_PROD(kernel_type, data_type) \ + kernel_type::ReduceProd<>( \ + GetTensorData(op_context.input), \ + op_context.input->dims->data, op_context.input->dims->size, \ + GetTensorData(op_context.output), \ + op_context.output->dims->data, op_context.output->dims->size, \ + GetTensorData(op_context.axis), num_axis, \ + op_context.params->keep_dims, GetTensorData(temp_index), \ + GetTensorData(resolved_axis)) + + if (kernel_type == kReference) { + switch (op_context.input->type) { + case kTfLiteFloat32: + TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, float)); + break; + case kTfLiteInt32: + TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int)); + break; + case kTfLiteInt64: + TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int64_t)); + break; + case kTfLiteUInt8: + // TODO(wangtz): uint8 reduce_prod is not yet supported. + default: + return kTfLiteError; + } + } +#undef TF_LITE_PROD + return kTfLiteOk; +} + +template +TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); + int64_t num_axis = NumElements(op_context.axis); + TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0); + TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, + ResizeTempAxis(context, &op_context, resolved_axis)); + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + +#define TF_LITE_MAX(kernel_type, data_type) \ + kernel_type::ReduceMax<>( \ + GetTensorData(op_context.input), \ + op_context.input->dims->data, op_context.input->dims->size, \ + GetTensorData(op_context.output), \ + op_context.output->dims->data, op_context.output->dims->size, \ + GetTensorData(op_context.axis), num_axis, \ + op_context.params->keep_dims, GetTensorData(temp_index), \ + GetTensorData(resolved_axis)) + + if (kernel_type == kReference) { + switch (op_context.input->type) { + case kTfLiteFloat32: + TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, float)); + break; + case kTfLiteInt32: + TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int)); + break; + case kTfLiteInt64: + TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int64_t)); + break; + case kTfLiteUInt8: + TF_LITE_ENSURE_EQ(context, op_context.input->params.scale, + op_context.output->params.scale); + TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point, + op_context.output->params.zero_point); + TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, uint8_t)); + break; + default: + return kTfLiteError; + } + } +#undef TF_LITE_MAX + return kTfLiteOk; +} + +} // namespace reduce TfLiteRegistration* Register_MEAN_REF() { - static TfLiteRegistration r = {mean::Init, mean::Free, mean::Prepare, - mean::Eval}; + static TfLiteRegistration r = {reduce::Init, reduce::Free, + reduce::PrepareMean, + reduce::EvalMean}; + return &r; +} + +TfLiteRegistration* Register_SUM_REF() { + static TfLiteRegistration r = {reduce::Init, reduce::Free, + reduce::PrepareSimple, + reduce::EvalSum}; + return &r; +} + +TfLiteRegistration* Register_REDUCE_PROD_REF() { + static TfLiteRegistration r = {reduce::Init, reduce::Free, + reduce::PrepareSimple, + reduce::EvalProd}; + return &r; +} + +TfLiteRegistration* Register_REDUCE_MAX_REF() { + static TfLiteRegistration r = {reduce::Init, reduce::Free, + reduce::PrepareSimple, + reduce::EvalMax}; return &r; } // TODO(kanlig): add optimized implementation of Mean. TfLiteRegistration* Register_MEAN() { return Register_MEAN_REF(); } +TfLiteRegistration* Register_SUM() { return Register_SUM_REF(); } +TfLiteRegistration* Register_REDUCE_PROD() { + return Register_REDUCE_PROD_REF(); +} +TfLiteRegistration* Register_REDUCE_MAX() { return Register_REDUCE_MAX_REF(); } } // namespace builtin } // namespace ops diff --git a/tensorflow/contrib/lite/kernels/reduce_test.cc b/tensorflow/contrib/lite/kernels/reduce_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5d432d34ef5118e7164d7f767dad6017aa640e51 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/reduce_test.cc @@ -0,0 +1,675 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; +using ::testing::IsEmpty; + +class BaseOpModel : public SingleOpModel { + public: + void SetAxis(const std::vector& data) { PopulateTensor(axis_, data); } + + template + void SetInput(std::vector data) { + PopulateTensor(input_, data); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } + + std::vector GetOutputShape() { return GetTensorShape(output_); } + + int Input() { return input_; } + + protected: + int input_; + int axis_; + int output_; +}; + +// Model for the tests case where axis is a const tensor. +class MeanOpConstModel : public BaseOpModel { + public: + MeanOpConstModel(const TensorData& input, const TensorData& output, + std::initializer_list axis_shape, + std::initializer_list axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// Model for the tests case where axis is a dynamic tensor. +class MeanOpDynamicModel : public BaseOpModel { + public: + MeanOpDynamicModel(const TensorData& input, const TensorData& output, + const TensorData& axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddInput(axis); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// Model for the tests case where axis is a const tensor. +class SumOpConstModel : public BaseOpModel { + public: + SumOpConstModel(const TensorData& input, const TensorData& output, + std::initializer_list axis_shape, + std::initializer_list axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_SUM, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// Model for the tests case where axis is a dynamic tensor. +class SumOpDynamicModel : public BaseOpModel { + public: + SumOpDynamicModel(const TensorData& input, const TensorData& output, + const TensorData& axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddInput(axis); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_SUM, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// Model for the tests case where axis is a const tensor. +class ProdOpConstModel : public BaseOpModel { + public: + ProdOpConstModel(const TensorData& input, const TensorData& output, + std::initializer_list axis_shape, + std::initializer_list axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_REDUCE_PROD, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// Model for the tests case where axis is a dynamic tensor. +class ProdOpDynamicModel : public BaseOpModel { + public: + ProdOpDynamicModel(const TensorData& input, const TensorData& output, + const TensorData& axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddInput(axis); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_REDUCE_PROD, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// Model for the tests case where axis is a const tensor. +class MaxOpConstModel : public BaseOpModel { + public: + MaxOpConstModel(const TensorData& input, const TensorData& output, + std::initializer_list axis_shape, + std::initializer_list axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_REDUCE_MAX, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// Model for the tests case where axis is a dynamic tensor. +class MaxOpDynamicModel : public BaseOpModel { + public: + MaxOpDynamicModel(const TensorData& input, const TensorData& output, + const TensorData& axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddInput(axis); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_REDUCE_MAX, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// for quantized Add, the error shouldn't exceed step +float GetTolerance(int min, int max) { return (max - min) / 255.0; } + +// Tests for reduce_mean +TEST(ConstFloatMeanOpTest, NotKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, + {4}, {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); +} + +TEST(ConstFloatMeanOpTest, KeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, + {2}, {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); +} + +TEST(ConstFloatMeanOpTest, Scalar) { + std::vector data = {3.27}; + MeanOpConstModel m({TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, {}, + {0}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3.27}))); +} + +TEST(DynamicFloatMeanOpTest, NotKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}}, + false); + std::vector axis = {1, 0, -3, -3}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); +} + +TEST(DynamicFloatMeanOpTest, KeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, + true); + std::vector axis = {0, 2}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); +} + +TEST(DynamicFloatMeanOpTest, Scale) { + std::vector data = {9.527}; + MeanOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}}, + {TensorType_INT32, {1}}, true); + std::vector axis = {0}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({9.527}))); +} + + +TEST(ConstUint8MeanOpTest, NotKeepDims) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + MeanOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0}, + {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + {0.4, 0.4}, kQuantizedTolerance))); +} + +TEST(ConstUint8MeanOpTest, KeepDims) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + MeanOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0}, + {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1})); + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({0.3, 0.35, 0.55}, kQuantizedTolerance))); +} + +TEST(DynamicUint8MeanOpTest, NotKeepDims) { + float kQuantizedTolerance = GetTolerance(-5.0, 2.0); + std::vector data = {1.3, -4.8, -3.6, 0.24}; + MeanOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0}, + {TensorType_UINT8, {2}, -5.0, 2.0}, + {TensorType_INT32, {1}}, false); + std::vector axis = {1}; + m.SetAxis(axis); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-1.75, -1.68}, kQuantizedTolerance))); +} + +TEST(DynamicUint8MeanOpTest, KeepDims) { + float kQuantizedTolerance = GetTolerance(-10.0, 12.0); + std::vector data = {11.14, -0.14, 7.423, 0.879}; + MeanOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0}, + {TensorType_UINT8, {2}, -10.0, 12.0}, + {TensorType_INT32, {1}}, true); + std::vector axis = {0}; + m.SetAxis(axis); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({9.2815, 0.3695}, kQuantizedTolerance))); +} + +// Tests for reduce_sum + +TEST(ConstFloatSumOpTest, NotKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SumOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, + {4}, {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({144, 156}))); +} + +TEST(ConstFloatSumOpTest, KeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SumOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, + {2}, {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({84, 100, 116}))); +} + +TEST(DynamicFloatSumOpTest, NotKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SumOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}}, + false); + std::vector axis = {1, 0, -3, -3}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({144, 156}))); +} + +TEST(ConstFloatSumOpTest, Scalar) { + std::vector data = {17.}; + SumOpConstModel m({TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, {}, {0}, + false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({17.}))); +} + +TEST(DynamicFloatSumOpTest, KeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SumOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, true); + std::vector axis = {0, 2}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({84, 100, 116}))); +} + +TEST(DynamicFloatSumOpTest, Scale) { + std::vector data = {9.527}; + SumOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}}, + {TensorType_INT32, {1}}, true); + std::vector axis = {0}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({9.527}))); +} + +TEST(ConstUint8SumOpTest, NotKeepDims) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + SumOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0}, + {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({-0.823529, -0.815686}, kQuantizedTolerance))); +} + +TEST(ConstUint8SumOpTest, KeepDims) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + SumOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0}, + {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-0.407843, -0.313726, 0.0941177}, + kQuantizedTolerance))); +} + +TEST(DynamicUint8SumOpTest, NotKeepDims) { + float kQuantizedTolerance = GetTolerance(-5.0, 2.0); + std::vector data = {1.3, -4.8, -3.6, 0.24}; + SumOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0}, + {TensorType_UINT8, {2}, -5.0, 2.0}, + {TensorType_INT32, {1}}, false); + std::vector axis = {1}; + m.SetAxis(axis); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({1.48235, 1.64706}, kQuantizedTolerance))); +} + +TEST(DynamicUint8SumOpTest, KeepDims) { + float kQuantizedTolerance = GetTolerance(-10.0, 12.0); + std::vector data = {11.14, -0.14, 7.423, 0.879}; + SumOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0}, + {TensorType_UINT8, {2}, -10.0, 12.0}, + {TensorType_INT32, {1}}, true); + std::vector axis = {0}; + m.SetAxis(axis); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({6.47059, 10.698}, kQuantizedTolerance))); +} + +// Tests for reduce_prod + +TEST(ConstFloatProdOpTest, NotKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + ProdOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, + {4}, {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3.162341376e+11, 1.9619905536e+12}))); +} + +TEST(ConstFloatProdOpTest, KeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + ProdOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, + {2}, {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray( + ArrayFloatNear({7.74592e+06, 1.197504e+08, 6.6889152e+08}))); +} + +TEST(DynamicFloatProdOpTest, NotKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + ProdOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}}, + false); + std::vector axis = {1, 0, -3, -3}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3.16234143225e+11, 1.9619905536e+12}))); +} + +TEST(DynamicFloatProdOpTest, KeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + ProdOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, + true); + std::vector axis = {0, 2}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray( + ArrayFloatNear({7.74592e+06, 1.197504e+08, 6.6889152e+08}))); +} + +TEST(DynamicFloatProdOpTest, Scale) { + std::vector data = {9.527}; + ProdOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}}, + {TensorType_INT32, {1}}, true); + std::vector axis = {0}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({9.527}))); +} + +// Tests for reduce_max + +TEST(ConstFloatMaxOpTest, NotKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MaxOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, + {4}, {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({23, 24}))); +} + +TEST(ConstFloatMaxOpTest, KeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MaxOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, + {2}, {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({20, 22, 24}))); +} + +TEST(DynamicFloatMaxOpTest, NotKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MaxOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}}, + false); + std::vector axis = {1, 0, -3, -3}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({23, 24}))); +} + +TEST(DynamicFloatMaxOpTest, KeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MaxOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, true); + std::vector axis = {0, 2}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({20, 22, 24}))); +} + +TEST(DynamicFloatMaxOpTest, Scale) { + std::vector data = {9.527}; + MaxOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}}, + {TensorType_INT32, {1}}, true); + std::vector axis = {0}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({9.527}))); +} + +TEST(ConstUint8MaxOpTest, NotKeepDims) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + MaxOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0}, + {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({0.501961, 0.603922}, kQuantizedTolerance))); +} + +TEST(ConstUint8MaxOpTest, KeepDims) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + MaxOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0}, + {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({0.4, 0.4, 0.603922}, kQuantizedTolerance))); +} + +TEST(DynamicUint8MaxOpTest, NotKeepDims) { + float kQuantizedTolerance = GetTolerance(-5.0, 2.0); + std::vector data = {1.3, -4.8, -3.6, 0.24}; + MaxOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0}, + {TensorType_UINT8, {2}, -5.0, 2.0}, + {TensorType_INT32, {1}}, false); + std::vector axis = {1}; + m.SetAxis(axis); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({1.2902, 0.247059}, kQuantizedTolerance))); +} + +TEST(DynamicUint8MaxOpTest, KeepDims) { + float kQuantizedTolerance = GetTolerance(-10.0, 12.0); + std::vector data = {11.14, -0.14, 7.423, 0.879}; + MaxOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0}, + {TensorType_UINT8, {2}, -10.0, 12.0}, + {TensorType_INT32, {1}}, true); + std::vector axis = {0}; + m.SetAxis(axis); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({11.1294, 0.862745}, kQuantizedTolerance))); +} + +TEST(DynamicUint8MaxOpTest, Scalar) { + float kQuantizedTolerance = GetTolerance(-10.0, 12.0); + std::vector data = {11.14}; + MaxOpDynamicModel m({TensorType_UINT8, {}, -10.0, 12.0}, + {TensorType_UINT8, {}, -10.0, 12.0}, + {TensorType_INT32, {1}}, true); + std::vector axis = {0}; + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({11.1294}, kQuantizedTolerance))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index fca040649d8ff76b42db1962135cd4315e961ec0..c5e77fbfed10881d060de7e9292fd5cea5451e60 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -22,6 +22,7 @@ namespace custom { TfLiteRegistration* Register_AUDIO_SPECTROGRAM(); TfLiteRegistration* Register_MFCC(); +TfLiteRegistration* Register_DETECTION_POSTPROCESS(); } // namespace custom @@ -81,6 +82,7 @@ TfLiteRegistration* Register_PRELU(); TfLiteRegistration* Register_MAXIMUM(); TfLiteRegistration* Register_MINIMUM(); TfLiteRegistration* Register_ARG_MAX(); +TfLiteRegistration* Register_ARG_MIN(); TfLiteRegistration* Register_GREATER(); TfLiteRegistration* Register_GREATER_EQUAL(); TfLiteRegistration* Register_LESS(); @@ -88,6 +90,9 @@ TfLiteRegistration* Register_LESS_EQUAL(); TfLiteRegistration* Register_FLOOR(); TfLiteRegistration* Register_TILE(); TfLiteRegistration* Register_NEG(); +TfLiteRegistration* Register_SUM(); +TfLiteRegistration* Register_REDUCE_PROD(); +TfLiteRegistration* Register_REDUCE_MAX(); TfLiteRegistration* Register_SELECT(); TfLiteRegistration* Register_SLICE(); TfLiteRegistration* Register_SIN(); @@ -96,6 +101,43 @@ TfLiteRegistration* Register_EXPAND_DIMS(); TfLiteRegistration* Register_SPARSE_TO_DENSE(); TfLiteRegistration* Register_EQUAL(); TfLiteRegistration* Register_NOT_EQUAL(); +TfLiteRegistration* Register_SQRT(); +TfLiteRegistration* Register_RSQRT(); +TfLiteRegistration* Register_SHAPE(); +TfLiteRegistration* Register_POW(); +TfLiteRegistration* Register_FAKE_QUANT(); +TfLiteRegistration* Register_PACK(); +TfLiteRegistration* Register_ONE_HOT(); +TfLiteRegistration* Register_LOGICAL_OR(); +TfLiteRegistration* Register_LOGICAL_AND(); +TfLiteRegistration* Register_LOGICAL_NOT(); + +TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) { + context->ReportError( + context, + "Regular TensorFlow ops are not supported by this interpreter. Make sure " + "you invoke the Eager delegate before inference."); + return kTfLiteError; +} + +const TfLiteRegistration* BuiltinOpResolver::FindOp(tflite::BuiltinOperator op, + int version) const { + return MutableOpResolver::FindOp(op, version); +} + +const TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op, + int version) const { + // Return the NULL Op for all ops whose name start with "Eager:", allowing + // the interpreter to delegate their execution. + if (string(op).find("Eager:") == 0) { + static TfLiteRegistration null_op{ + nullptr, nullptr, &UnsupportedTensorFlowOp, + nullptr, nullptr, BuiltinOperator_CUSTOM, + "Eager", 1}; + return &null_op; + } + return MutableOpResolver::FindOp(op, version); +} BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -117,7 +159,9 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP()); AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, Register_EMBEDDING_LOOKUP_SPARSE()); - AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED()); + AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION()); AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP()); AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX()); @@ -159,6 +203,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM()); AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM()); AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX()); + AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN()); AddBuiltin(BuiltinOperator_GREATER, Register_GREATER()); AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL()); AddBuiltin(BuiltinOperator_LESS, Register_LESS()); @@ -170,10 +215,23 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SIN, Register_SIN()); AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV()); AddBuiltin(BuiltinOperator_TILE, Register_TILE()); + AddBuiltin(BuiltinOperator_SUM, Register_SUM()); + AddBuiltin(BuiltinOperator_REDUCE_PROD, Register_REDUCE_PROD()); + AddBuiltin(BuiltinOperator_REDUCE_MAX, Register_REDUCE_MAX()); AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS()); AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE()); AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL()); AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL()); + AddBuiltin(BuiltinOperator_SQRT, Register_SQRT()); + AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT()); + AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE()); + AddBuiltin(BuiltinOperator_POW, Register_POW()); + AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2); + AddBuiltin(BuiltinOperator_PACK, Register_PACK()); + AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT()); + AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR()); + AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND()); + AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT()); #if 0 // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that @@ -181,6 +239,8 @@ BuiltinOpResolver::BuiltinOpResolver() { AddCustom("Mfcc", tflite::ops::custom::Register_MFCC()); AddCustom("AudioSpectrogram", tflite::ops::custom::Register_AUDIO_SPECTROGRAM()); + AddCustom("TFLite_Detection_PostProcess", + tflite::ops::custom::Register_DETECTION_POSTPROCESS()); #endif } diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h index b928f1b302580d52f708bbf85dfcfc0f79ff1e69..0296152d68d6836fd592a65eeea69a7d4ebbb6ef 100644 --- a/tensorflow/contrib/lite/kernels/register.h +++ b/tensorflow/contrib/lite/kernels/register.h @@ -26,10 +26,14 @@ namespace builtin { class BuiltinOpResolver : public MutableOpResolver { public: BuiltinOpResolver(); + + const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, + int version) const override; + const TfLiteRegistration* FindOp(const char* op, int version) const override; }; } // namespace builtin } // namespace ops } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc index 3287040695140e3e7921c9f517450b9416b050b6..49ba0571e2f214c0b2407240753fcec0661c71bf 100644 --- a/tensorflow/contrib/lite/kernels/reshape.cc +++ b/tensorflow/contrib/lite/kernels/reshape.cc @@ -25,16 +25,11 @@ namespace builtin { namespace reshape { constexpr int kInputTensor = 0; +constexpr int kShapeTensor = 1; constexpr int kOutputTensor = 0; -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - - // TODO(ahentz): we are often given a tensor with the shape but we only pay - // attention to what the shape specified in 'params'. - TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - +TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node, + TfLiteIntArray* output_shape) { const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); @@ -42,37 +37,84 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // special -1 value, meaning it will be calculated automatically based on the // input. Here we calculate what that dimension should be so that the number // of output elements in the same as the number of input elements. - int num_input_elements = 1; - for (int i = 0; i < NumDimensions(input); ++i) { - num_input_elements *= SizeOfDimension(input, i); - } + int num_input_elements = NumElements(input); - TfLiteIntArray* output_size = TfLiteIntArrayCreate(params->num_dimensions); int num_output_elements = 1; int stretch_dim = -1; - for (int i = 0; i < params->num_dimensions; ++i) { - int value = params->shape[i]; + for (int i = 0; i < output_shape->size; ++i) { + int value = output_shape->data[i]; if (value == -1) { TF_LITE_ENSURE_EQ(context, stretch_dim, -1); stretch_dim = i; } else { num_output_elements *= value; - output_size->data[i] = value; } } if (stretch_dim != -1) { - output_size->data[stretch_dim] = num_input_elements / num_output_elements; - num_output_elements *= output_size->data[stretch_dim]; + output_shape->data[stretch_dim] = num_input_elements / num_output_elements; + num_output_elements *= output_shape->data[stretch_dim]; } TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements); - return context->ResizeTensor(context, output, output_size); + return context->ResizeTensor(context, output, output_shape); +} + +TfLiteStatus ResizeOutputWithShapeTensor(TfLiteContext* context, + TfLiteNode* node) { + const TfLiteTensor* shape = GetInput(context, node, kShapeTensor); + + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape->dims->data[0]); + for (int i = 0; i < output_shape->size; ++i) { + output_shape->data[i] = shape->data.i32[i]; + } + return ResizeOutput(context, node, output_shape); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + // Attempt to use shape tensor if it exists. + if (NumInputs(node) == 2) { + const TfLiteTensor* shape = GetInput(context, node, kShapeTensor); + // Check if the shape tensor is valid. + if (shape->dims->size == 1 && shape->type == kTfLiteInt32) { + // Set the output tensor as dynamic if the shape isn't constnat. + if (!IsConstantTensor(shape)) { + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + SetTensorToDynamic(output); + return kTfLiteOk; + } + // Shape is constant. Resize now. + return ResizeOutputWithShapeTensor(context, node); + } + } + // The function is returned above this line if the shape tensor is usable. + // Now fallback to the shape parameter in `TfLiteReshapeParams`. + int num_dimensions = params->num_dimensions; + if (num_dimensions == 1 && params->shape[0] == 0) { + // Legacy tflite models use a shape parameter of [0] to indicate scalars, + // so adjust accordingly. TODO(b/111614235): Allow zero-sized buffers during + // toco conversion. + num_dimensions = 0; + } + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); + for (int i = 0; i < num_dimensions; ++i) { + output_shape->data[i] = params->shape[i]; + } + return ResizeOutput(context, node, output_shape); } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputWithShapeTensor(context, node)); + } + memcpy(output->data.raw, input->data.raw, input->bytes); return kTfLiteOk; diff --git a/tensorflow/contrib/lite/kernels/reshape_test.cc b/tensorflow/contrib/lite/kernels/reshape_test.cc index aecbd0399f7454045e8189072f45b695b0525204..52d71350d3ba9a27bf9a8df7a194161c4fb7f87c 100644 --- a/tensorflow/contrib/lite/kernels/reshape_test.cc +++ b/tensorflow/contrib/lite/kernels/reshape_test.cc @@ -22,18 +22,27 @@ namespace tflite { namespace { using ::testing::ElementsAreArray; +using ::testing::IsEmpty; class ReshapeOpModel : public SingleOpModel { public: ReshapeOpModel(std::initializer_list input_shape, - std::initializer_list new_shape) { + std::initializer_list new_shape, + bool use_shape_input_tensor = false) { input_ = AddInput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); + int shape_input_tensor = + use_shape_input_tensor ? AddInput(TensorType_INT32) : -1; SetBuiltinOp( BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions, CreateReshapeOptions(builder_, builder_.CreateVector(new_shape)) .Union()); - BuildInterpreter({input_shape}); + if (use_shape_input_tensor) { + BuildInterpreter({input_shape, GetShape(shape_input_tensor)}); + PopulateTensor(shape_input_tensor, new_shape); + } else { + BuildInterpreter({input_shape}); + } } void SetInput(std::initializer_list data) { @@ -71,6 +80,14 @@ TEST(ReshapeOpTest, SimpleTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); } +TEST(ReshapeOpTest, ShapeTensorInput) { + ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2}, /*use_shape_input_tensor=*/true); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); +} + TEST(ReshapeOpTest, WithStretchDimension) { ReshapeOpModel m({1, 2, 4, 1}, {2, 1, -1}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); @@ -79,6 +96,22 @@ TEST(ReshapeOpTest, WithStretchDimension) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 4})); } +TEST(ReshapeOpTest, ScalarOutput) { + ReshapeOpModel m({1}, {}); + m.SetInput({3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); +} + +TEST(ReshapeOpTest, LegacyScalarOutput) { + ReshapeOpModel m({1}, {0}); + m.SetInput({3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc index 10caffea03ebcec7862df1627541ac3d076b04e4..f4289105f7931ae572f219a61b5479287aff926a 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -247,7 +247,7 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { 3, 6, // 9, 12, // 4, 10, // - 10, 16 // + 12, 16 // }); m.SetSize({3, 3}); m.Invoke(); @@ -256,8 +256,8 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { 7, 9, 10, // 9, 11, 12, // 4, 8, 10, // - 8, 12, 14, // - 10, 13, 16, // + 9, 12, 14, // + 12, 14, 16, // }))); ResizeBilinearOpModel const_m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3}); @@ -265,7 +265,7 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { 3, 6, // 9, 12, // 4, 10, // - 10, 16 // + 12, 16 // }); const_m.Invoke(); EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ @@ -273,35 +273,35 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { 7, 9, 10, // 9, 11, 12, // 4, 8, 10, // - 8, 12, 14, // - 10, 13, 16, // + 9, 12, 14, // + 12, 14, 16, // }))); } TEST(ResizeBilinearOpTest, ThreeDimensionalResize8Bit) { ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}}); m.SetInput({ - 3, 4, 6, 10, // - 9, 10, 12, 16, // + 3, 4, 6, 10, // + 10, 12, 14, 16, // }); m.SetSize({3, 3}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 4, 5, 8, 6, 10, // - 7, 8, 9, 12, 10, 14, // - 9, 10, 11, 13, 12, 16, // + 3, 4, 5, 8, 6, 10, // + 7, 9, 10, 12, 11, 14, // + 10, 12, 12, 14, 14, 16, // }))); ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3}); const_m.SetInput({ - 3, 4, 6, 10, // - 9, 10, 12, 16, // + 3, 4, 6, 10, // + 10, 12, 14, 16, // }); const_m.Invoke(); EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 4, 5, 8, 6, 10, // - 7, 8, 9, 12, 10, 14, // - 9, 10, 11, 13, 12, 16, // + 3, 4, 5, 8, 6, 10, // + 7, 9, 10, 12, 11, 14, // + 10, 12, 12, 14, 14, 16, // }))); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc index 9b6cee3cb55bf93b987fa8e59bdf9c591f5c0372..3cdb5db2090a3cb3eeb43c6e20a4fec09fe8a069 100644 --- a/tensorflow/contrib/lite/kernels/select.cc +++ b/tensorflow/contrib/lite/kernels/select.cc @@ -89,6 +89,9 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteUInt8: \ TF_LITE_SELECT(uint8_t, op); \ break; \ + case kTfLiteInt16: \ + TF_LITE_SELECT(int16_t, op); \ + break; \ case kTfLiteInt32: \ TF_LITE_SELECT(int32_t, op); \ break; \ diff --git a/tensorflow/contrib/lite/kernels/select_test.cc b/tensorflow/contrib/lite/kernels/select_test.cc index cfe24a5fc92765747d1c75bc3e6964b959e2205d..5b2e61cd29a7fd7c699fd81cb81e5f9a12c4b18f 100644 --- a/tensorflow/contrib/lite/kernels/select_test.cc +++ b/tensorflow/contrib/lite/kernels/select_test.cc @@ -88,11 +88,24 @@ TEST(SelectOpTest, SelectUInt8) { TensorType_UINT8); model.PopulateTensor(model.input1(), {false, true, false, false}); - model.PopulateTensor(model.input2(), {1, 2, 3, 4}); - model.PopulateTensor(model.input3(), {5, 6, 7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 4}); + model.PopulateTensor(model.input3(), {5, 6, 7, 8}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 2, 7, 8})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 2, 7, 8})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + +TEST(SelectOpTest, SelectInt16) { + SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4}, + TensorType_INT16); + + model.PopulateTensor(model.input1(), {false, true, false, false}); + model.PopulateTensor(model.input2(), {1, 2, 3, 4}); + model.PopulateTensor(model.input3(), {5, 6, 7, 8}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 2, 7, 8})); EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); } @@ -101,11 +114,11 @@ TEST(SelectOpTest, SelectInt32) { TensorType_INT32); model.PopulateTensor(model.input1(), {false, true, false, false}); - model.PopulateTensor(model.input2(), {1, 2, 3, 4}); - model.PopulateTensor(model.input3(), {5, 6, 7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 4}); + model.PopulateTensor(model.input3(), {5, 6, 7, 8}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 2, 7, 8})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 2, 7, 8})); EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); } @@ -113,11 +126,11 @@ TEST(SelectOpTest, RankOneSelectInt32) { SelectOpModel model({2}, {2, 1, 2, 1}, {2, 1, 2, 1}, TensorType_INT32); model.PopulateTensor(model.input1(), {false, true}); - model.PopulateTensor(model.input2(), {1, 2, 3, 4}); - model.PopulateTensor(model.input3(), {5, 6, 7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 4}); + model.PopulateTensor(model.input3(), {5, 6, 7, 8}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 6, 3, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 6, 3, 4})); EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 2, 1})); } @@ -125,11 +138,11 @@ TEST(SelectOpTest, RankZeroSelectInt32) { SelectOpModel model({1}, {1, 2, 2, 1}, {1, 2, 2, 1}, TensorType_INT32); model.PopulateTensor(model.input1(), {false}); - model.PopulateTensor(model.input2(), {1, 2, 3, 4}); - model.PopulateTensor(model.input3(), {5, 6, 7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 4}); + model.PopulateTensor(model.input3(), {5, 6, 7, 8}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 6, 7, 8})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 6, 7, 8})); EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 1})); } diff --git a/tensorflow/contrib/lite/kernels/shape.cc b/tensorflow/contrib/lite/kernels/shape.cc new file mode 100644 index 0000000000000000000000000000000000000000..dbcd2ef004f490f00193153be7a2cfda83e73c24 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/shape.cc @@ -0,0 +1,93 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace shape { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +template +void ExtractShape(const TfLiteTensor* input, OutType* output_data) { + for (int i = 0; i < NumDimensions(input); ++i) { + output_data[i] = SizeOfDimension(input, i); + } +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + auto* params = reinterpret_cast(node->builtin_data); + switch (params->out_type) { + case kTfLiteInt32: + output->type = kTfLiteInt32; + break; + case kTfLiteInt64: + output->type = kTfLiteInt64; + break; + default: + context->ReportError(context, "Unknown shape output data type: %d", + params->out_type); + return kTfLiteError; + } + + // Shape always produces a 1-dimensional output tensor, where each output + // element is the length of the corresponding input tensor's dimension. + TfLiteIntArray* output_size = TfLiteIntArrayCreate(1); + output_size->data[0] = NumDimensions(input); + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TFLITE_DCHECK_EQ(NumDimensions(output), 1); + TFLITE_DCHECK_EQ(SizeOfDimension(output, 0), NumDimensions(input)); + + switch (output->type) { + case kTfLiteInt32: + ExtractShape(input, GetTensorData(output)); + break; + case kTfLiteInt64: + ExtractShape(input, GetTensorData(output)); + break; + default: + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace shape + +TfLiteRegistration* Register_SHAPE() { + static TfLiteRegistration r = {nullptr, nullptr, shape::Prepare, shape::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/shape_test.cc b/tensorflow/contrib/lite/kernels/shape_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..27b48f4e992a8f02d56815bd1bd9074f5b41f400 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/shape_test.cc @@ -0,0 +1,95 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template +class ShapeOpModel : public SingleOpModel { + public: + ShapeOpModel(std::initializer_list input_shape, TensorType input_type, + TensorType output_type) { + input_ = AddInput(input_type); + output_ = AddOutput(output_type); + SetBuiltinOp(BuiltinOperator_SHAPE, BuiltinOptions_ShapeOptions, + CreateShapeOptions(builder_, output_type).Union()); + BuildInterpreter({input_shape}); + } + + TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); } + + int input() { return input_; } + + int32_t GetOutputSize() { return GetTensorSize(output_); } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(ShapeOpTest, OutTypeInt) { + ShapeOpModel model({1, 3, 1, 3, 5}, TensorType_FLOAT32, + TensorType_INT32); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5})); +} + +TEST(ShapeOpTest, OutTypeInt64) { + ShapeOpModel model({1, 3, 1, 3, 5}, TensorType_FLOAT32, + TensorType_INT64); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5})); +} + +TEST(ShapeOpTest, ScalarTensor) { + ShapeOpModel model({}, TensorType_FLOAT32, TensorType_INT32); + model.Invoke(); + + EXPECT_EQ(model.GetOutputSize(), 0); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({0})); +} + +TEST(ShapeOpTest, EmptyTensor) { + ShapeOpModel model({1, 0}, TensorType_FLOAT32, TensorType_INT32); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc index 6c5338ff0fd26337c9adc8e0b94a0a88edfde37f..727822f6beaa8a63ca8f1b57ba4993d2e59f7e0b 100644 --- a/tensorflow/contrib/lite/kernels/softmax_test.cc +++ b/tensorflow/contrib/lite/kernels/softmax_test.cc @@ -92,10 +92,9 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaEq1) { m.Invoke(); std::unique_ptr output_buffer(new float[input_size * batch_size]); - static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, - {1, 0, 0, input_size}}; - tflite::reference_ops::Softmax(input_buffer, input_dims, beta, - output_buffer.get(), input_dims); + auto input_shape = RuntimeShape({batch_size, 1, 1, input_size}); + tflite::reference_ops::Softmax(input_buffer, input_shape, beta, + output_buffer.get(), input_shape); std::vector expected; expected.insert(expected.end(), output_buffer.get(), @@ -120,10 +119,9 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) { m.Invoke(); std::unique_ptr output_buffer(new float[input_size * batch_size]); - static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, - {1, 0, 0, input_size}}; - tflite::reference_ops::Softmax(input_buffer, input_dims, beta, - output_buffer.get(), input_dims); + auto input_shape = RuntimeShape({batch_size, 1, 1, input_size}); + tflite::reference_ops::Softmax(input_buffer, input_shape, beta, + output_buffer.get(), input_shape); std::vector expected; expected.insert(expected.end(), output_buffer.get(), diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc index c9269599e58f095ded4788e2ab064583ae0a708c..03079f1c3b4110da9193f91ed22940594152b10f 100644 --- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc @@ -113,7 +113,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); } -#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar) \ +#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar, pad_value) \ type::SpaceToBatchND(GetTensorData(op_context.input), \ GetTensorDims(op_context.input), \ GetTensorData(op_context.block_shape), \ @@ -121,34 +121,36 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTensorData(op_context.paddings), \ GetTensorDims(op_context.paddings), \ GetTensorData(op_context.output), \ - GetTensorDims(op_context.output)) + GetTensorDims(op_context.output), pad_value) switch (op_context.input->type) { // Already know in/out types are same. case kTfLiteFloat32: if (kernel_type == kReference) { - TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float); + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float, 0); } else { - TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float); + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float, 0); } break; case kTfLiteUInt8: if (kernel_type == kReference) { - TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t); + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t, + op_context.output->params.zero_point); } else { - TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t); + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t, + op_context.output->params.zero_point); } break; case kTfLiteInt32: if (kernel_type == kReference) { - TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t); + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t, 0); } else { - TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t); + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t, 0); } break; case kTfLiteInt64: if (kernel_type == kReference) { - TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t); + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t, 0); } else { - TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t); + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t, 0); } break; default: diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc index 92a4a037d5873e608ee7bdbdfc5eaa5e9b62bc8c..5756573629a51917e39a312117a1fcd29c150dc0 100644 --- a/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc @@ -23,6 +23,7 @@ namespace tflite { namespace { using ::testing::ElementsAreArray; +using ::testing::Matcher; class SpaceToBatchNDOpModel : public SingleOpModel { public: @@ -30,6 +31,10 @@ class SpaceToBatchNDOpModel : public SingleOpModel { PopulateTensor(input_, data); } + void SetQuantizedInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + void SetBlockShape(std::initializer_list data) { PopulateTensor(block_shape_, data); } @@ -41,6 +46,11 @@ class SpaceToBatchNDOpModel : public SingleOpModel { std::vector GetOutput() { return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } + protected: int input_; int block_shape_; @@ -56,18 +66,19 @@ class SpaceToBatchNDOpModel : public SingleOpModel { // m.Invoke(); class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel { public: - SpaceToBatchNDOpConstModel(std::initializer_list input_shape, + SpaceToBatchNDOpConstModel(const TensorData& input, std::initializer_list block_shape, - std::initializer_list paddings) { - input_ = AddInput(TensorType_FLOAT32); + std::initializer_list paddings, + const TensorData& output) { + input_ = AddInput(input); block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2}); paddings_ = AddConstInput(TensorType_INT32, paddings, {2, 2}); - output_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(output); SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND, BuiltinOptions_SpaceToBatchNDOptions, CreateSpaceToBatchNDOptions(builder_).Union()); - BuildInterpreter({input_shape}); + BuildInterpreter({input.shape}); } }; @@ -81,26 +92,30 @@ class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel { // m.Invoke(); class SpaceToBatchNDOpDynamicModel : public SpaceToBatchNDOpModel { public: - SpaceToBatchNDOpDynamicModel(std::initializer_list input_shape) { - input_ = AddInput(TensorType_FLOAT32); + SpaceToBatchNDOpDynamicModel(const TensorData& input, + const TensorData& output) { + input_ = AddInput(input); block_shape_ = AddInput(TensorType_INT32); paddings_ = AddInput(TensorType_INT32); - output_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(output); SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND, BuiltinOptions_SpaceToBatchNDOptions, CreateSpaceToBatchNDOptions(builder_).Union()); - BuildInterpreter({input_shape, {2}, {2, 2}}); + BuildInterpreter({input.shape, {2}, {2, 2}}); } }; TEST(SpaceToBatchNDOpTest, InvalidShapeTest) { - EXPECT_DEATH(SpaceToBatchNDOpConstModel({1, 3, 3, 1}, {2, 2}, {0, 0, 0, 0}), - "Cannot allocate tensors"); + EXPECT_DEATH( + SpaceToBatchNDOpConstModel({TensorType_FLOAT32, {1, 3, 3, 1}}, {2, 2}, + {0, 0, 0, 0}, {TensorType_FLOAT32}), + "Cannot allocate tensors"); } TEST(SpaceToBatchNDOpTest, SimpleConstTest) { - SpaceToBatchNDOpConstModel m({1, 4, 4, 1}, {2, 2}, {0, 0, 0, 0}); + SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {1, 4, 4, 1}}, {2, 2}, + {0, 0, 0, 0}, {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 2, 1})); @@ -109,7 +124,8 @@ TEST(SpaceToBatchNDOpTest, SimpleConstTest) { } TEST(SpaceToBatchNDOpTest, SimpleDynamicTest) { - SpaceToBatchNDOpDynamicModel m({1, 4, 4, 1}); + SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {1, 4, 4, 1}}, + {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.SetBlockShape({2, 2}); m.SetPaddings({0, 0, 0, 0}); @@ -120,7 +136,8 @@ TEST(SpaceToBatchNDOpTest, SimpleDynamicTest) { } TEST(SpaceToBatchNDOpTest, MultipleInputBatchesConstTest) { - SpaceToBatchNDOpConstModel m({2, 2, 4, 1}, {2, 2}, {0, 0, 0, 0}); + SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, {2, 2}, + {0, 0, 0, 0}, {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8, 1, 2, 1})); @@ -129,7 +146,8 @@ TEST(SpaceToBatchNDOpTest, MultipleInputBatchesConstTest) { } TEST(SpaceToBatchNDOpTest, MultipleInputBatchesDynamicTest) { - SpaceToBatchNDOpDynamicModel m({2, 2, 4, 1}); + SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.SetBlockShape({2, 2}); m.SetPaddings({0, 0, 0, 0}); @@ -140,7 +158,8 @@ TEST(SpaceToBatchNDOpTest, MultipleInputBatchesDynamicTest) { } TEST(SpaceToBatchNDOpTest, SimplePaddingConstTest) { - SpaceToBatchNDOpConstModel m({1, 5, 2, 1}, {3, 2}, {1, 0, 2, 0}); + SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {1, 5, 2, 1}}, {3, 2}, + {1, 0, 2, 0}, {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); @@ -151,7 +170,8 @@ TEST(SpaceToBatchNDOpTest, SimplePaddingConstTest) { } TEST(SpaceToBatchNDOpTest, SimplePaddingDynamicTest) { - SpaceToBatchNDOpDynamicModel m({1, 5, 2, 1}); + SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {1, 5, 2, 1}}, + {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); m.SetBlockShape({3, 2}); m.SetPaddings({1, 0, 2, 0}); @@ -164,7 +184,8 @@ TEST(SpaceToBatchNDOpTest, SimplePaddingDynamicTest) { } TEST(SpaceToBatchNDOpTest, ComplexPaddingConstTest) { - SpaceToBatchNDOpConstModel m({1, 4, 2, 1}, {3, 2}, {1, 1, 2, 4}); + SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {1, 4, 2, 1}}, {3, 2}, + {1, 1, 2, 4}, {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); @@ -176,7 +197,8 @@ TEST(SpaceToBatchNDOpTest, ComplexPaddingConstTest) { } TEST(SpaceToBatchNDOpTest, ComplexPaddingDynamicTest) { - SpaceToBatchNDOpDynamicModel m({1, 4, 2, 1}); + SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {1, 4, 2, 1}}, + {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); m.SetBlockShape({3, 2}); m.SetPaddings({1, 1, 2, 4}); @@ -189,6 +211,88 @@ TEST(SpaceToBatchNDOpTest, ComplexPaddingDynamicTest) { })); } +class QuantizedSpaceToBatchNDOpTest : public ::testing::Test { + protected: + std::vector> DequantizedArrayNear( + const std::vector& values, const float min, const float max) { + const float quantization_tolerance = (max - min) / 255.0; + return ArrayFloatNear(values, quantization_tolerance); + } +}; + +TEST_F(QuantizedSpaceToBatchNDOpTest, ZeroNotInQuantizationRange) { + // The test_util and actual quantization code currently ensure that the range + // must include zero, but if that ever changes, this test will catch it. + EXPECT_DEATH(SpaceToBatchNDOpConstModel m( + {TensorType_UINT8, {1, 2, 2, 1}, 1.0, 2.0}, {4, 2}, + {0, 0, 1, 1, 1, 1, 0, 0}, {TensorType_UINT8, {}, 1.0, 2.0}), + ".*Check failed: f_min <= 0.*"); +} + +TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingConstTest) { + SpaceToBatchNDOpConstModel m({TensorType_UINT8, {1, 5, 2, 1}, -1.0, 1.0}, + {3, 2}, {1, 0, 2, 0}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + {0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7, + 0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1}, + -1.0, 1.0))); +} + +TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingDynamicTest) { + SpaceToBatchNDOpDynamicModel m({TensorType_UINT8, {1, 5, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1}); + m.SetBlockShape({3, 2}); + m.SetPaddings({1, 0, 2, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + {0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7, + 0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1}, + -1.0, 1.0))); +} + +TEST_F(QuantizedSpaceToBatchNDOpTest, ComplexPaddingConstTest) { + SpaceToBatchNDOpConstModel m({TensorType_UINT8, {1, 4, 2, 1}, -1.0, 1.0}, + {3, 2}, {1, 1, 2, 4}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + { + 0, 0, 0, 0, 0, -0.5, 0, 0, 0, 0, 0, 0, 0, 0.6, 0, 0, + 0, -0.1, 0, 0, 0, -0.7, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0, + 0, -0.3, 0, 0, 0, 0, 0, 0, 0, 0.4, 0, 0, 0, 0, 0, 0, + }, + -1.0, 1.0))); +} + +TEST_F(QuantizedSpaceToBatchNDOpTest, ComplexPaddingDynamicTest) { + SpaceToBatchNDOpDynamicModel m({TensorType_UINT8, {1, 4, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8}); + m.SetBlockShape({3, 2}); + m.SetPaddings({1, 1, 2, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + { + 0, 0, 0, 0, 0, -0.5, 0, 0, 0, 0, 0, 0, 0, 0.6, 0, 0, + 0, -0.1, 0, 0, 0, -0.7, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0, + 0, -0.3, 0, 0, 0, 0, 0, 0, 0, 0.4, 0, 0, 0, 0, 0, 0, + }, + -1.0, 1.0))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc index 404c32ad9ca8b9f1e467b747708ccb451f2a5118..fec2a6f0d97ae48e0c49d82c726278a46d96a7fc 100644 --- a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc +++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -188,7 +187,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return ResizeOutputShape(context, output_shape, output); } -template +template TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); const TfLiteTensor* output_shape = @@ -205,10 +204,10 @@ TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) { const int num_indices = SizeOfDimension(indices, 0); const bool value_is_scalar = NumDimensions(values) == 0; - std::vector> indices_vector; + std::vector> indices_vector; indices_vector.reserve(num_indices); - TF_LITE_ENSURE_OK(context, GetIndicesVector(context, indices, num_indices, - &indices_vector)); + TF_LITE_ENSURE_OK(context, GetIndicesVector(context, indices, num_indices, + &indices_vector)); reference_ops::SparseToDense(indices_vector, GetTensorData(values), *GetTensorData(default_value), GetTensorData(output), GetTensorDims(output), diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc index 43387df9ceb4d54a2784c3fa4718a95262948729..b14448604123253bac9c50c21f047891721ab122 100644 --- a/tensorflow/contrib/lite/kernels/split.cc +++ b/tensorflow/contrib/lite/kernels/split.cc @@ -76,8 +76,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits); auto input_type = op_context.input->type; - TF_LITE_ENSURE(context, - input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8); + TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 || + input_type == kTfLiteUInt8 || + input_type == kTfLiteInt16); for (int i = 0; i < NumOutputs(node); ++i) { GetOutput(context, node, i)->type = input_type; } @@ -137,9 +138,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_SPLIT(uint8_t); break; } + case kTfLiteInt16: { + TF_LITE_SPLIT(int16_t); + break; + } default: context->ReportError( - context, "Only float32 and uint8 are currently supported, got %d.", + context, + "Only float32, uint8 and int16 are currently supported, got %d.", op_context.input->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc index 725dd8105ab9506d5203ed38a11f8e06abdab603..bed2117f9ae3a64e963478eb03b46f0547f4c05f 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice.cc @@ -121,10 +121,19 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, int32_t begin = GetBeginValueAtIndex(op_context, idx); int32_t end = GetEndValueAtIndex(op_context, idx); + // When shrinking an axis, the end position does not matter (and can be + // incorrect when negative indexing is used, see Issue #19260). Always use + // begin + 1 to generate a length 1 slice, since begin has + // already been adjusted for negative indices by GetBeginValueAtIndex. + const bool shrink_axis = op_context->params->shrink_axis_mask & (1 << idx); + if (shrink_axis) { + end = begin + 1; + } + // This is valid for both positive and negative strides int32_t dim_shape = ceil((end - begin) / static_cast(stride)); dim_shape = dim_shape < 0 ? 0 : dim_shape; - if (!(op_context->params->shrink_axis_mask & (1 << idx))) { + if (!shrink_axis) { output_shape_vector.push_back(dim_shape); } } @@ -204,13 +213,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { int begin_mask = ReverseMaskBits(op_context.params->begin_mask, op_context.dims); int end_mask = ReverseMaskBits(op_context.params->end_mask, op_context.dims); - -#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ - kernel_type::StridedSlice(GetTensorData(op_context.input), \ - GetTensorDims(op_context.input), begin_mask, \ - end_mask, starts, stops, strides, \ - GetTensorData(op_context.output), \ - GetTensorDims(op_context.output)) + int shrink_axis_mask = + ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims); + +#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ + kernel_type::StridedSlice( \ + GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), begin_mask, end_mask, shrink_axis_mask, \ + starts, stops, strides, GetTensorData(op_context.output), \ + GetTensorDims(op_context.output)) switch (op_context.input->type) { case kTfLiteFloat32: diff --git a/tensorflow/contrib/lite/kernels/strided_slice_test.cc b/tensorflow/contrib/lite/kernels/strided_slice_test.cc index cc39179bc705aa1083e74b06f8f7f3fb45e9f616..c5d4f9affb46c82b4dec15bc0653d7315d132335 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice_test.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice_test.cc @@ -21,7 +21,6 @@ limitations under the License. namespace tflite { namespace { -using ::int32; using ::testing::ElementsAreArray; template data) { PopulateTensor(input_, data); } - void SetBegin(std::initializer_list data) { - PopulateTensor(begin_, data); + void SetBegin(std::initializer_list data) { + PopulateTensor(begin_, data); } - void SetEnd(std::initializer_list data) { - PopulateTensor(end_, data); + void SetEnd(std::initializer_list data) { + PopulateTensor(end_, data); } - void SetStrides(std::initializer_list data) { - PopulateTensor(strides_, data); + void SetStrides(std::initializer_list data) { + PopulateTensor(strides_, data); } std::vector GetOutput() { @@ -384,6 +383,45 @@ TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); } +TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1_NegativeSlice) { + // This is equivalent to tf.range(4)[-1]. + StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + m.SetInput({0, 1, 2, 3}); + m.SetBegin({-1}); + m.SetEnd({0}); + m.SetStrides({1}); + + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); +} + +TEST(StridedSliceOpTest, In2D_ShrinkAxis3_NegativeSlice) { + // This is equivalent to tf.range(4)[:, tf.newaxis][-2, -1]. + StridedSliceOpModel<> m({4, 1}, {2}, {2}, {2}, 0, 0, 0, 0, 3); + m.SetInput({0, 1, 2, 3}); + m.SetBegin({-2, -1}); + m.SetEnd({-1, 0}); + m.SetStrides({1, 1}); + + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); +} + +TEST(StridedSliceOpTest, In2D_ShrinkAxis2_BeginEndAxis1_NegativeSlice) { + // This is equivalent to tf.range(4)[:, tf.newaxis][:, -1]. + StridedSliceOpModel<> m({4, 1}, {2}, {2}, {2}, 1, 1, 0, 0, 2); + m.SetInput({0, 1, 2, 3}); + m.SetBegin({0, -1}); + m.SetEnd({0, 0}); + m.SetStrides({1, 1}); + + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 1, 2, 3})); +} + TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) { StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4}); @@ -395,17 +433,6 @@ TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); } -TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStrideShrinkAxisMask1) { - StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); - m.SetInput({1, 2, 3, 4}); - m.SetBegin({-2}); - m.SetEnd({-3}); - m.SetStrides({-1}); - m.Invoke(); - EXPECT_TRUE(m.GetOutputShape().empty()); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); -} - TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) { StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4, 5, 6}); @@ -538,7 +565,7 @@ TEST(StridedSliceOpTest, RunTwice) { } TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) { - StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); m.SetBegin({0, 0, 0}); diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc index a8b803589962032db3ed579d31e8b736c3afada0..77a1f596898bb7fa99a7509a25229c627d762bdd 100644 --- a/tensorflow/contrib/lite/kernels/sub.cc +++ b/tensorflow/contrib/lite/kernels/sub.cc @@ -78,29 +78,47 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } template -void EvalFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteSubParams* params, const OpData* data, - const TfLiteTensor* input1, const TfLiteTensor* input2, - TfLiteTensor* output) { - float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); -#define TF_LITE_SUB(type, opname) \ - type::opname(GetTensorData(input1), GetTensorDims(input1), \ - GetTensorData(input2), GetTensorDims(input2), \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)) - if (kernel_type == kReference) { - if (data->requires_broadcast) { - TF_LITE_SUB(reference_ops, BroadcastSub); +void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params, + const OpData* data, const TfLiteTensor* input1, + const TfLiteTensor* input2, TfLiteTensor* output) { +#define TF_LITE_SUB(type, opname, data_type) \ + data_type output_activation_min, output_activation_max; \ + CalculateActivationRange(params->activation, &output_activation_min, \ + &output_activation_max); \ + tflite::ArithmeticParams op_params; \ + SetActivationParams(output_activation_min, output_activation_max, \ + &op_params); \ + type::opname(op_params, GetTensorShape(input1), \ + GetTensorData(input1), GetTensorShape(input2), \ + GetTensorData(input2), GetTensorShape(output), \ + GetTensorData(output)) + if (output->type == kTfLiteInt32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_SUB(reference_ops, BroadcastSub4DSlow, int32_t); + } else { + TF_LITE_SUB(reference_ops, SubWithActivation, int32_t); + } } else { - TF_LITE_SUB(reference_ops, Sub); + if (data->requires_broadcast) { + TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow, int32_t); + } else { + TF_LITE_SUB(optimized_ops, SubWithActivation, int32_t); + } } - } else { - if (data->requires_broadcast) { - TF_LITE_SUB(optimized_ops, BroadcastSub); + } else if (output->type == kTfLiteFloat32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_SUB(reference_ops, BroadcastSub4DSlow, float); + } else { + TF_LITE_SUB(reference_ops, SubWithActivation, float); + } } else { - TF_LITE_SUB(optimized_ops, Sub); + if (data->requires_broadcast) { + TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow, float); + } else { + TF_LITE_SUB(optimized_ops, SubWithActivation, float); + } } } #undef TF_LITE_SUB @@ -128,36 +146,43 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, int input1_shift; QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier, &input1_shift); - input1_shift *= -1; int32 input2_multiplier; int input2_shift; QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier, &input2_shift); - input2_shift *= -1; int32 output_multiplier; int output_shift; QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier, &output_shift); - output_shift *= -1; int32 output_activation_min, output_activation_max; CalculateActivationRangeUint8(params->activation, output, &output_activation_min, &output_activation_max); -#define TF_LITE_SUB(type, opname) \ - type::opname(left_shift, GetTensorData(input1), \ - GetTensorDims(input1), input1_offset, input1_multiplier, \ - input1_shift, GetTensorData(input2), \ - GetTensorDims(input2), input2_offset, input2_multiplier, \ - input2_shift, output_offset, output_multiplier, output_shift, \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)); +#define TF_LITE_SUB(type, opname) \ + tflite::ArithmeticParams op_params; \ + op_params.left_shift = left_shift; \ + op_params.input1_offset = input1_offset; \ + op_params.input1_multiplier = input1_multiplier; \ + op_params.input1_shift = input1_shift; \ + op_params.input2_offset = input2_offset; \ + op_params.input2_multiplier = input2_multiplier; \ + op_params.input2_shift = input2_shift; \ + op_params.output_offset = output_offset; \ + op_params.output_multiplier = output_multiplier; \ + op_params.output_shift = output_shift; \ + SetActivationParams(output_activation_min, output_activation_max, \ + &op_params); \ + type::opname(op_params, GetTensorShape(input1), \ + GetTensorData(input1), GetTensorShape(input2), \ + GetTensorData(input2), GetTensorShape(output), \ + GetTensorData(output)) // The quantized version of Sub doesn't support activations, so we // always use BroadcastSub. if (kernel_type == kReference) { - TF_LITE_SUB(reference_ops, BroadcastSub); + TF_LITE_SUB(reference_ops, BroadcastSub4DSlow); } else { - TF_LITE_SUB(optimized_ops, BroadcastSub); + TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow); } #undef TF_LITE_SUB } @@ -171,14 +196,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - if (output->type == kTfLiteFloat32) { - EvalFloat(context, node, params, data, input1, input2, output); + if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) { + EvalSub(context, node, params, data, input1, input2, output); } else if (output->type == kTfLiteUInt8) { EvalQuantized(context, node, params, data, input1, input2, output); } else { context->ReportError( - context, "output type %d is not supported, requires float|uint8 types.", + context, + "output type %d is not supported, requires float|uint8|int32 types.", output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/sub_test.cc b/tensorflow/contrib/lite/kernels/sub_test.cc index ff07aeec49dbfcc0e1f65df3d674d5ec30f1b54c..5978c574d35492eda6b903fd83d95ecbd6b62148 100644 --- a/tensorflow/contrib/lite/kernels/sub_test.cc +++ b/tensorflow/contrib/lite/kernels/sub_test.cc @@ -52,6 +52,13 @@ class FloatSubOpModel : public BaseSubOpModel { std::vector GetOutput() { return ExtractVector(output_); } }; +class IntegerSubOpModel : public BaseSubOpModel { + public: + using BaseSubOpModel::BaseSubOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + class QuantizedSubOpModel : public BaseSubOpModel { public: using BaseSubOpModel::BaseSubOpModel; @@ -125,6 +132,57 @@ TEST(FloatSubOpModel, WithBroadcast) { } } +TEST(IntegerSubOpModel, NoActivation) { + IntegerSubOpModel m({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, + ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-21, 0, 4, 3})); +} + +TEST(IntegerSubOpModel, ActivationRELU_N1_TO_1) { + IntegerSubOpModel m({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, + ActivationFunctionType_RELU_N1_TO_1); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 0, 1, 1})); +} + +TEST(IntegerSubOpModel, VariousInputShapes) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + IntegerSubOpModel m({TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8, 11, 20}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5, 11, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-21, 0, 4, 3, 0, 19})) + << "With shape number " << i; + } +} + +TEST(IntegerSubOpModel, WithBroadcast) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + IntegerSubOpModel m({TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, {}}, // always a scalar + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8, 11, 20}); + m.PopulateTensor(m.input2(), {1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-21, 1, 6, 7, 10, 19}))) + << "With shape number " << i; + } +} + TEST(QuantizedSubOpModel, QuantizedTestsNoActivation) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = { diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc index 308860c299e9d74729d35b760e0f605437872c92..6d4912ce3aa40bf95dc1e26572b8a07fb6362744 100644 --- a/tensorflow/contrib/lite/kernels/svdf.cc +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include + +// SVDF op that compresses a fully connected op via low-rank matrix +// factorization. See https://research.google.com/pubs/archive/43813.pdf for +// details. #include #include #include @@ -32,6 +35,67 @@ namespace ops { namespace builtin { namespace svdf { +namespace { + +struct OpData { + int scratch_tensor_index; + bool float_weights_time_initialized; +}; + +static inline void ApplyTimeWeightsBiasAndActivation( + int batch_size, int memory_size, int num_filters, int num_units, int rank, + const TfLiteTensor* weights_time, const TfLiteTensor* bias, + TfLiteFusedActivation activation, TfLiteTensor* state, + TfLiteTensor* scratch, TfLiteTensor* output) { + // Compute matmul(state, weights_time). + // The right most column is used to save temporary output (with the size of + // num_filters). This is achieved by starting at state->data.f and having the + // stride equal to memory_size. + for (int b = 0; b < batch_size; ++b) { + float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + float* scratch_ptr_batch = scratch->data.f + b * num_filters; + tensor_utils::BatchVectorBatchVectorDotProduct( + weights_time->data.f, state_ptr_batch, memory_size, num_filters, + scratch_ptr_batch, /*result_stride=*/1); + } + + // Initialize output with bias if provided. + if (bias) { + tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size, + output->data.f); + } else { + tensor_utils::ZeroVector(output->data.f, batch_size * num_units); + } + + // Reduction sum. + for (int b = 0; b < batch_size; ++b) { + float* output_ptr_batch = output->data.f + b * num_units; + float* scratch_ptr_batch = scratch->data.f + b * num_filters; + tensor_utils::ReductionSumVector(scratch_ptr_batch, output_ptr_batch, + num_units, rank); + } + + // Apply activation. + for (int b = 0; b < batch_size; ++b) { + float* output_ptr_batch = output->data.f + b * num_units; + tensor_utils::ApplyActivationToVector(output_ptr_batch, num_units, + activation, output_ptr_batch); + } + + // Left shift the state to make room for next cycle's activation. + // TODO(alanchiao): explore collapsing this into a single loop. + for (int b = 0; b < batch_size; ++b) { + float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + for (int f = 0; f < num_filters; ++f) { + tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size, + /*shift_value=*/0.0); + state_ptr_batch += memory_size; + } + } +} + +} // namespace + constexpr int kInputTensor = 0; constexpr int kWeightsFeatureTensor = 1; constexpr int kWeightsTimeTensor = 2; @@ -40,29 +104,34 @@ constexpr int kStateTensor = 0; constexpr int kOutputTensor = 1; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* scratch_tensor_index = new int; - context->AddTensors(context, 1, scratch_tensor_index); - return scratch_tensor_index; + auto* op_data = new OpData(); + op_data->float_weights_time_initialized = false; + context->AddTensors(context, /*tensors_to_add=*/4, + &op_data->scratch_tensor_index); + return op_data; } void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); + delete reinterpret_cast(buffer); } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - int* scratch_tensor_index = reinterpret_cast(node->user_data); + const auto* params = reinterpret_cast(node->builtin_data); + OpData* op_data = reinterpret_cast(node->user_data); + int scratch_tensor_index = op_data->scratch_tensor_index; // Check we have all the inputs and outputs we need. TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); - TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* weights_feature = GetInput(context, node, kWeightsFeatureTensor); const TfLiteTensor* weights_time = GetInput(context, node, kWeightsTimeTensor); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + // Check all the parameters of tensor match within themselves and match the // input configuration. const int rank = params->rank; @@ -103,10 +172,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, output_size_array)); + // The weights are of consistent type, so it suffices to check one. + const bool is_hybrid_op = + (input->type == kTfLiteFloat32 && weights_feature->type == kTfLiteUInt8); + // Resize scratch. TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(1); - node->temporaries->data[0] = *scratch_tensor_index; + if (is_hybrid_op) { + node->temporaries = TfLiteIntArrayCreate(4); + } else { + node->temporaries = TfLiteIntArrayCreate(1); + } + node->temporaries->data[0] = scratch_tensor_index; TfLiteIntArray* scratch_size_array = TfLiteIntArrayCreate(2); scratch_size_array->data[0] = batch_size; @@ -118,24 +195,56 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_tensor, scratch_size_array)); - return kTfLiteOk; -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* weights_feature = - GetInput(context, node, kWeightsFeatureTensor); - const TfLiteTensor* weights_time = - GetInput(context, node, kWeightsTimeTensor); + if (is_hybrid_op) { + // Tell interpreter to allocate temporary tensors to store quantized values + // of input tensors. + node->temporaries->data[1] = scratch_tensor_index + 1; + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } - TfLiteTensor* state = GetOutput(context, node, kStateTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0); + // Tell interpreter to allocate temporary tensors to store scaling factors. + node->temporaries->data[2] = scratch_tensor_index + 2; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = batch_size; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } - const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + // Used to store dequantized weights_time matrix for hybrid computation + // of matmul(state, weights_time), which occurs in floating point. + node->temporaries->data[3] = scratch_tensor_index + 3; + TfLiteTensor* float_weights_time = GetTemporary(context, node, /*index=*/3); + float_weights_time->type = kTfLiteFloat32; + // Persistent so that we can compute the dequantized weights only once. + float_weights_time->allocation_type = kTfLiteArenaRwPersistent; + if (!TfLiteIntArrayEqual(float_weights_time->dims, weights_time->dims)) { + TfLiteIntArray* float_weights_time_size = + TfLiteIntArrayCopy(weights_time->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, float_weights_time, + float_weights_time_size)); + } + } + return kTfLiteOk; +} +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + const TfLiteTensor* input, + const TfLiteTensor* weights_feature, + const TfLiteTensor* weights_time, + const TfLiteTensor* bias, const TfLiteSVDFParams* params, + TfLiteTensor* scratch, TfLiteTensor* state, + TfLiteTensor* output) { const int rank = params->rank; const int batch_size = input->dims->data[0]; const int input_size = input->dims->data[1]; @@ -146,67 +255,151 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Clear the activation (state left most column). // TODO(ghodrat): Add a test which initialize state with invalid values in // left most column and make sure it passes. - for (int b = 0; b < batch_size; b++) { + for (int b = 0; b < batch_size; ++b) { float* state_ptr_batch = state->data.f + b * memory_size * num_filters; - for (int c = 0; c < num_filters; c++) { + for (int c = 0; c < num_filters; ++c) { float* state_ptr = state_ptr_batch + c * memory_size; state_ptr[memory_size - 1] = 0.0; } } // Compute conv1d(inputs, weights_feature). - // The state left most column is used to save current cycle activation. This + // The state right most column is used to save current cycle activation. This // is achieved by starting at state->data.f[memory_size - 1] and having the // stride equal to memory_size. tensor_utils::MatrixBatchVectorMultiplyAccumulate( weights_feature->data.f, num_filters, input_size, input->data.f, batch_size, &state->data.f[memory_size - 1], memory_size); - // Compute matmul(state, weights_time). - // The right most column is used to save temporary output (with the size of - // num_filters). This is achieved by starting at state->data.f and having the - // stride equal to memory_size. - for (int b = 0; b < batch_size; b++) { + ApplyTimeWeightsBiasAndActivation(batch_size, memory_size, num_filters, + num_units, rank, weights_time, bias, + params->activation, state, scratch, output); + return kTfLiteOk; +} + +TfLiteStatus EvalHybrid( + TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input, + const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time, + const TfLiteTensor* bias, const TfLiteSVDFParams* params, + TfLiteTensor* scratch, TfLiteTensor* scaling_factors, + TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output) { + const int rank = params->rank; + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + const int num_filters = weights_feature->dims->data[0]; + const int num_units = num_filters / rank; + const int memory_size = weights_time->dims->data[1]; + + // Initialize the pointer to input. + const float* input_ptr_batch = input->data.f; + + // Initialize the pointer to storage for quantized values and + // scaling factors. + int8_t* quantized_input_ptr_batch = + reinterpret_cast(input_quantized->data.uint8); + + float* scaling_factors_ptr = scaling_factors->data.f; + + // Other initializations. + const int8_t* weights_feature_ptr = + reinterpret_cast(weights_feature->data.uint8); + const float weights_feature_scale = weights_feature->params.scale; + + // Clear the activation (state left most column). + // TODO(ghodrat): Add a test which initialize state with invalid values in + // left most column and make sure it passes. + for (int b = 0; b < batch_size; ++b) { float* state_ptr_batch = state->data.f + b * memory_size * num_filters; - float* scratch_ptr_batch = scratch->data.f + b * num_filters; - tensor_utils::BatchVectorBatchVectorDotProduct( - weights_time->data.f, state_ptr_batch, memory_size, num_filters, - scratch_ptr_batch, /*result_stride=*/1); + for (int c = 0; c < num_filters; ++c) { + float* state_ptr = state_ptr_batch + c * memory_size; + state_ptr[memory_size - 1] = 0.0; + } } - // Initialize output with bias if provided. - if (bias) { - tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size, - output->data.f); - } else { - tensor_utils::ZeroVector(output->data.f, batch_size * num_units); - } + if (!tensor_utils::IsZeroVector(input_ptr_batch, batch_size * input_size)) { + // Quantize input from float to int8. + float unused_min, unused_max; + for (int b = 0; b < batch_size; ++b) { + const int offset = b * input_size; + tensor_utils::SymmetricQuantizeFloats( + input_ptr_batch + offset, input_size, + quantized_input_ptr_batch + offset, &unused_min, &unused_max, + &scaling_factors_ptr[b]); + scaling_factors_ptr[b] *= weights_feature_scale; + } - // Reduction sum - for (int b = 0; b < batch_size; b++) { - float* output_ptr_batch = output->data.f + b * num_units; - float* scratch_ptr_batch = scratch->data.f + b * num_filters; - tensor_utils::ReductionSumVector(scratch_ptr_batch, output_ptr_batch, - num_units, rank); + // Compute conv1d(inputs, weights_feature). + // The state right most column is used to save current cycle activation. + // This is achieved by starting at state->data.f[memory_size - 1] and having + // the stride equal to memory_size. + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + weights_feature_ptr, num_filters, input_size, quantized_input_ptr_batch, + scaling_factors_ptr, batch_size, &state->data.f[memory_size - 1], + memory_size); } - // Apply activation. - for (int b = 0; b < batch_size; b++) { - float* output_ptr_batch = output->data.f + b * num_units; - tensor_utils::ApplyActivationToVector(output_ptr_batch, num_units, - params->activation, output_ptr_batch); - } + // TODO(alanchiao): can optimize hybrid case ~5% by unrolling loop in applying + // time weights so that the inner loop multiplies eight elements at a time. + ApplyTimeWeightsBiasAndActivation(batch_size, memory_size, num_filters, + num_units, rank, weights_time, bias, + params->activation, state, scratch, output); + return kTfLiteOk; +} - // Right shift the state. - for (int b = 0; b < batch_size; b++) { - float* state_ptr_batch = state->data.f + b * memory_size * num_filters; - for (int f = 0; f < num_filters; f++) { - tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size, - /*shift_value=*/0.0); - state_ptr_batch += memory_size; +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* op_data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* weights_feature = + GetInput(context, node, kWeightsFeatureTensor); + const TfLiteTensor* weights_time = + GetInput(context, node, kWeightsTimeTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + + TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0); + + TfLiteTensor* state = GetOutput(context, node, kStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (weights_feature->type) { + case kTfLiteFloat32: { + return EvalFloat(context, node, input, weights_feature, weights_time, + bias, params, scratch, state, output); + break; } + case kTfLiteUInt8: { + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); + TfLiteTensor* float_weights_time = + GetTemporary(context, node, /*index=*/3); + + // Dequantize weights time. + // TODO(alanchiao): this dequantization initialization only needs to + // happen once per model and should theoretically be placed in either Init + // or Prepare. However, TFLite doesn't allocate float_weights_time until + // the Eval function. + // TODO(alanchiao): refactor logic out into dequantize function. + if (!op_data->float_weights_time_initialized) { + const float dequantization_scale = weights_time->params.scale; + const int8_t* weights_time_ptr = + reinterpret_cast(weights_time->data.uint8); + for (int i = 0; i < NumElements(float_weights_time); ++i) { + float_weights_time->data.f[i] = + weights_time_ptr[i] * dequantization_scale; + } + op_data->float_weights_time_initialized = true; + } + return EvalHybrid(context, node, input, weights_feature, + float_weights_time, bias, params, scratch, + scaling_factors, input_quantized, state, output); + break; + } + default: + context->ReportError(context, "Type %d not currently supported.", + weights_feature->type); + return kTfLiteError; } - return kTfLiteOk; } } // namespace svdf diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc index 0f166dc69b95f3459388135b3a6c4d9b73a31cb4..5af3ff85004ce43c5b75c6f12761f121c0d8deca 100644 --- a/tensorflow/contrib/lite/kernels/svdf_test.cc +++ b/tensorflow/contrib/lite/kernels/svdf_test.cc @@ -126,17 +126,20 @@ static float svdf_golden_output_rank_2[] = { }; // Derived class of SingleOpModel, which is used to test SVDF TFLite op. -class SVDFOpModel : public SingleOpModel { +class BaseSVDFOpModel : public SingleOpModel { public: - SVDFOpModel(int batches, int units, int input_size, int memory_size, int rank) + BaseSVDFOpModel(int batches, int units, int input_size, int memory_size, + int rank, + TensorType weights_feature_type = TensorType_FLOAT32, + TensorType weights_time_type = TensorType_FLOAT32) : batches_(batches), units_(units), input_size_(input_size), memory_size_(memory_size), rank_(rank) { input_ = AddInput(TensorType_FLOAT32); - weights_feature_ = AddInput(TensorType_FLOAT32); - weights_time_ = AddInput(TensorType_FLOAT32); + weights_feature_ = AddInput(weights_feature_type); + weights_time_ = AddInput(weights_time_type); bias_ = AddNullInput(); state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); @@ -182,7 +185,7 @@ class SVDFOpModel : public SingleOpModel { int num_units() { return units_; } int num_batches() { return batches_; } - private: + protected: int input_; int weights_feature_; int weights_time_; @@ -197,7 +200,61 @@ class SVDFOpModel : public SingleOpModel { int rank_; }; -TEST(SVDFOpTest, BlackBoxTestRank1) { +class SVDFOpModel : public BaseSVDFOpModel { + public: + using BaseSVDFOpModel::BaseSVDFOpModel; +}; + +class HybridSVDFOpModel : public BaseSVDFOpModel { + public: + HybridSVDFOpModel(int batches, int units, int input_size, int memory_size, + int rank) + : BaseSVDFOpModel(batches, units, input_size, memory_size, rank, + TensorType_UINT8, TensorType_UINT8) {} + + void SetWeightsFeature(std::initializer_list f) { + SymmetricQuantizeAndPopulate(weights_feature_, f); + } + + void SetWeightsTime(std::initializer_list f) { + SymmetricQuantizeAndPopulate(weights_time_, f); + } +}; + +class SVDFOpTest : public ::testing::Test { + protected: + void VerifyGoldens(float golden_input[], float golden_output[], + int golden_size, BaseSVDFOpModel* svdf, + float tolerance = 1e-5) { + const int svdf_num_batches = svdf->num_batches(); + const int svdf_input_size = svdf->input_size(); + const int svdf_num_units = svdf->num_units(); + const int input_sequence_size = + golden_size / sizeof(float) / (svdf_input_size * svdf_num_batches); + // Going over each input batch, setting the input tensor, invoking the SVDF + // op and checking the output with the expected golden values. + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = + golden_input + i * svdf_input_size * svdf_num_batches; + float* batch_end = batch_start + svdf_input_size * svdf_num_batches; + svdf->SetInput(0, batch_start, batch_end); + + svdf->Invoke(); + + const float* golden_start = + golden_output + i * svdf_num_units * svdf_num_batches; + const float* golden_end = + golden_start + svdf_num_units * svdf_num_batches; + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(svdf->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } + } +}; + +TEST_F(SVDFOpTest, BlackBoxTestRank1) { SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, /*memory_size=*/10, /*rank=*/1); svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, @@ -218,31 +275,11 @@ TEST(SVDFOpTest, BlackBoxTestRank1) { -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); svdf.ResetState(); - const int svdf_num_batches = svdf.num_batches(); - const int svdf_input_size = svdf.input_size(); - const int svdf_num_units = svdf.num_units(); - const int input_sequence_size = - sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches); - // Going over each input batch, setting the input tensor, invoking the SVDF op - // and checking the output with the expected golden values. - for (int i = 0; i < input_sequence_size; i++) { - float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches; - float* batch_end = batch_start + svdf_input_size * svdf_num_batches; - svdf.SetInput(0, batch_start, batch_end); - - svdf.Invoke(); - - float* golden_start = - svdf_golden_output_rank_1 + i * svdf_num_units * svdf_num_batches; - float* golden_end = golden_start + svdf_num_units * svdf_num_batches; - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - - EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input), + &svdf); } -TEST(SVDFOpTest, BlackBoxTestRank2) { +TEST_F(SVDFOpTest, BlackBoxTestRank2) { SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, /*memory_size=*/10, /*rank=*/2); svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, @@ -278,28 +315,75 @@ TEST(SVDFOpTest, BlackBoxTestRank2) { 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); svdf.ResetState(); - const int svdf_num_batches = svdf.num_batches(); - const int svdf_input_size = svdf.input_size(); - const int svdf_num_units = svdf.num_units(); - const int input_sequence_size = - sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches); - // Going over each input batch, setting the input tensor, invoking the SVDF op - // and checking the output with the expected golden values. - for (int i = 0; i < input_sequence_size; i++) { - float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches; - float* batch_end = batch_start + svdf_input_size * svdf_num_batches; - svdf.SetInput(0, batch_start, batch_end); - - svdf.Invoke(); - - float* golden_start = - svdf_golden_output_rank_2 + i * svdf_num_units * svdf_num_batches; - float* golden_end = golden_start + svdf_num_units * svdf_num_batches; - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - - EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input), + &svdf); +} + +TEST_F(SVDFOpTest, BlackBoxTestHybridRank1) { + HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/1); + svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, + 0.22197971, 0.12416199, 0.27901134, 0.27557442, + 0.3905206, -0.36137494, -0.06634006, -0.10640851}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); + + svdf.ResetState(); + VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input), + &svdf, + /*tolerance=*/0.002945); +} + +TEST_F(SVDFOpTest, BlackBoxTestHybridRank2) { + HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/2); + svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, + 0.12416199, 0.15785322, 0.27901134, 0.3905206, + 0.21931258, -0.36137494, -0.10640851, 0.31053296, + -0.36118156, -0.0976817, -0.36916667, 0.22197971, + 0.15294972, 0.38031587, 0.27557442, 0.39635518, + -0.21580373, -0.06634006, -0.02702999, 0.27072677}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657, + + -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486, + 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187, + + -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589, + 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836, + + -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277, + -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214, + + 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326, + 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); + + svdf.ResetState(); + VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input), + &svdf, + /*tolerance=*/0.00625109); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc index d23ec201b41887b0682242687fc938d76d058c44..9156917140b5af6c0f38c878ab77fef7f93b049a 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -32,8 +32,8 @@ std::vector> ArrayFloatNear(const std::vector& values, return matchers; } -int SingleOpModel::AddInput(const TensorData& t) { - int id = AddTensor(t, {}); +int SingleOpModel::AddInput(const TensorData& t, bool is_variable) { + int id = AddTensor(t, {}, is_variable); inputs_.push_back(id); return id; } @@ -120,6 +120,7 @@ void SingleOpModel::BuildInterpreter( CHECK(interpreter_->AllocateTensors() == kTfLiteOk) << "Cannot allocate tensors"; + interpreter_->ResetVariableTensorsToZero(); } void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); } diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h index db80c0082c394a2cb2f9388d3db5bd1a7cbe6266..bedbe93ae65662647f6a0fb0c9c6a6a921e148bb 100644 --- a/tensorflow/contrib/lite/kernels/test_util.h +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -126,8 +126,10 @@ class SingleOpModel { SingleOpModel& operator=(const SingleOpModel&) = delete; // Add a TensorType input tensor and return its index. - int AddInput(TensorType type) { return AddInput(TensorData{type}); } - int AddInput(const TensorData& t); + int AddInput(TensorType type, bool is_variable = false) { + return AddInput(TensorData{type}, is_variable); + } + int AddInput(const TensorData& t, bool is_variable = false); // Templated version of AddConstInput(). template @@ -146,20 +148,18 @@ class SingleOpModel { int AddOutput(const TensorData& t); template - void QuantizeAndPopulate(int index, std::initializer_list data) { + void QuantizeAndPopulate(int index, const std::vector& data) { TfLiteTensor* t = interpreter_->tensor(index); auto q = Quantize(data, t->params.scale, t->params.zero_point); PopulateTensor(index, 0, q.data(), q.data() + q.size()); } - void SymmetricQuantizeAndPopulate(int index, - std::initializer_list data) { + void SymmetricQuantizeAndPopulate(int index, const std::vector& data) { TfLiteTensor* t = interpreter_->tensor(index); - std::vector values(data); - const int length = values.size(); + const int length = data.size(); std::vector q(length); float min, max, scaling_factor; - tensor_utils::SymmetricQuantizeFloats(values.data(), length, q.data(), &min, + tensor_utils::SymmetricQuantizeFloats(data.data(), length, q.data(), &min, &max, &scaling_factor); // Update quantization params. t->params.scale = scaling_factor; @@ -196,8 +196,22 @@ class SingleOpModel { } // Populate the tensor given its index. + // TODO(b/110696148) clean up and merge with vector-taking variant below. + template + void PopulateTensor(int index, const std::initializer_list& data) { + T* v = interpreter_->typed_tensor(index); + CHECK(v) << "No tensor with index '" << index << "'."; + for (T f : data) { + *v = f; + ++v; + } + } + + // Populate the tensor given its index. + // TODO(b/110696148) clean up and merge with initializer_list-taking variant + // above. template - void PopulateTensor(int index, std::initializer_list data) { + void PopulateTensor(int index, const std::vector& data) { T* v = interpreter_->typed_tensor(index); CHECK(v) << "No tensor with index '" << index << "'."; for (T f : data) { @@ -260,7 +274,8 @@ class SingleOpModel { } template - int AddTensor(TensorData t, std::initializer_list data) { + int AddTensor(TensorData t, std::initializer_list data, + bool is_variable = false) { int id = tensors_.size(); // This is slightly different depending on whether we are adding a @@ -277,6 +292,9 @@ class SingleOpModel { } else if (t.type == TensorType_INT32) { std::tie(t.scale, t.zero_point) = QuantizationParams(t.min, t.max); + } else if (t.type == TensorType_INT16) { + std::tie(t.scale, t.zero_point) = + QuantizationParams(t.min, t.max); } else { LOG(FATAL) << "No support for the requested quantized type"; } @@ -309,7 +327,7 @@ class SingleOpModel { tensors_.push_back(CreateTensor(builder_, builder_.CreateVector(t.shape), t.type, /*buffer=*/buffer_id, - /*name=*/0, q_params)); + /*name=*/0, q_params, is_variable)); tensor_data_[id] = t; diff --git a/tensorflow/contrib/lite/kernels/test_util_test.cc b/tensorflow/contrib/lite/kernels/test_util_test.cc index 1e10e89061213b6fcabd404310893dd97a51d83f..236580347254d336609a3081736f54e069b5cb5a 100644 --- a/tensorflow/contrib/lite/kernels/test_util_test.cc +++ b/tensorflow/contrib/lite/kernels/test_util_test.cc @@ -22,22 +22,22 @@ using ::testing::ElementsAreArray; TEST(TestUtilTest, QuantizeVector) { std::vector data = {-1.0, -0.5, 0.0, 0.5, 1.0, 1000.0}; - auto q_data = Quantize(data, /*scale=*/1.0, /*zero_point=*/0); - std::vector expected = {0, 0, 0, 1, 1, 255}; + auto q_data = Quantize(data, /*scale=*/1.0, /*zero_point=*/0); + std::vector expected = {0, 0, 0, 1, 1, 255}; EXPECT_THAT(q_data, ElementsAreArray(expected)); } TEST(TestUtilTest, QuantizeVectorScalingDown) { std::vector data = {-1.0, -0.5, 0.0, 0.5, 1.0, 1000.0}; - auto q_data = Quantize(data, /*scale=*/10.0, /*zero_point=*/0); - std::vector expected = {0, 0, 0, 0, 0, 100}; + auto q_data = Quantize(data, /*scale=*/10.0, /*zero_point=*/0); + std::vector expected = {0, 0, 0, 0, 0, 100}; EXPECT_THAT(q_data, ElementsAreArray(expected)); } TEST(TestUtilTest, QuantizeVectorScalingUp) { std::vector data = {-1.0, -0.5, 0.0, 0.5, 1.0, 1000.0}; - auto q_data = Quantize(data, /*scale=*/0.1, /*zero_point=*/0); - std::vector expected = {0, 0, 0, 5, 10, 255}; + auto q_data = Quantize(data, /*scale=*/0.1, /*zero_point=*/0); + std::vector expected = {0, 0, 0, 5, 10, 255}; EXPECT_THAT(q_data, ElementsAreArray(expected)); } diff --git a/tensorflow/contrib/lite/kernels/tile.cc b/tensorflow/contrib/lite/kernels/tile.cc index af77f074742eb3fef10a74616ff679255911fbb2..5181a8f89a376302bad02913e3c7c1d094821da8 100644 --- a/tensorflow/contrib/lite/kernels/tile.cc +++ b/tensorflow/contrib/lite/kernels/tile.cc @@ -87,8 +87,9 @@ std::pair TileOneDimension(const TfLiteIntArray& in_dimensions, if (dimension == in_dimensions.size - 1) { CopyMultipleTimes(in_data, dimension_size, multipliers[dimension], out_data); - return std::make_pair(dimension_size, - dimension_size * multipliers[dimension]); + return std::make_pair( + dimension_size, + dimension_size * static_cast(multipliers[dimension])); } int total_stride_size = 0, total_tiled_stride_size = 0; const T* copy_from_data = in_data; diff --git a/tensorflow/contrib/lite/kernels/tile_test.cc b/tensorflow/contrib/lite/kernels/tile_test.cc index a134a75d56ae03a5d03a3cdf632146474b863971..4f78c224e54f0c71bc6622134a1c8e4142c22daa 100644 --- a/tensorflow/contrib/lite/kernels/tile_test.cc +++ b/tensorflow/contrib/lite/kernels/tile_test.cc @@ -38,27 +38,27 @@ class TileOpModel : public SingleOpModel { PopulateTensor(input_, data); } - void SetInputUInt8(std::initializer_list data) { - PopulateTensor(input_, data); + void SetInputUInt8(std::initializer_list data) { + PopulateTensor(input_, data); } - void SetInputInt32(std::initializer_list data) { - PopulateTensor(input_, data); + void SetInputInt32(std::initializer_list data) { + PopulateTensor(input_, data); } void SetInputInt64(std::initializer_list data) { PopulateTensor(input_, data); } - void SetMultipliers(std::initializer_list data) { - PopulateTensor(multipliers_, data); + void SetMultipliers(std::initializer_list data) { + PopulateTensor(multipliers_, data); } std::vector GetOutputFloat() { return ExtractVector(output_); } - std::vector GetOutputUInt8() { return ExtractVector(output_); } + std::vector GetOutputUInt8() { return ExtractVector(output_); } - std::vector GetOutputInt32() { return ExtractVector(output_); } + std::vector GetOutputInt32() { return ExtractVector(output_); } std::vector GetOutputInt64() { return ExtractVector(output_); diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc index fb0e49c90c41747f9b7e53570276c8b8045030fd..2dd760bbfebd1faa8b7ff9158bc1a1b1d4647525 100644 --- a/tensorflow/contrib/lite/kernels/topk_v2.cc +++ b/tensorflow/contrib/lite/kernels/topk_v2.cc @@ -56,11 +56,13 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { output_values_shape->data[num_dimensions - 1] = k; TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes); TfLiteTensor* output_values = GetOutput(context, node, kOutputValues); + // Force output types. + output_indexes->type = kTfLiteInt32; + output_values->type = input->type; auto resize_tensor = [context](TfLiteTensor* tensor, TfLiteIntArray* new_size, TfLiteIntArray* delete_on_error) { TfLiteStatus status = context->ResizeTensor(context, tensor, new_size); if (status != kTfLiteOk) { - TfLiteIntArrayFree(new_size); if (delete_on_error != nullptr) { TfLiteIntArrayFree(delete_on_error); } diff --git a/tensorflow/contrib/lite/kernels/topk_v2_test.cc b/tensorflow/contrib/lite/kernels/topk_v2_test.cc index 212f8acc76d4afba56933029175f69b34ea87a3e..2abb89b617742b33b9280b15ad379422c5c9b207 100644 --- a/tensorflow/contrib/lite/kernels/topk_v2_test.cc +++ b/tensorflow/contrib/lite/kernels/topk_v2_test.cc @@ -42,32 +42,32 @@ class TopKV2OpModel : public SingleOpModel { PopulateTensor(input_, data); } - void SetInputUInt8(std::initializer_list data) { - PopulateTensor(input_, data); + void SetInputUInt8(std::initializer_list data) { + PopulateTensor(input_, data); } - void SetInputInt32(std::initializer_list data) { - PopulateTensor(input_, data); + void SetInputInt32(std::initializer_list data) { + PopulateTensor(input_, data); } void SetInputInt64(std::initializer_list data) { PopulateTensor(input_, data); } - std::vector GetIndexes() { - return ExtractVector(output_indexes_); + std::vector GetIndexes() { + return ExtractVector(output_indexes_); } std::vector GetValuesFloat() { return ExtractVector(output_values_); } - std::vector GetValuesUInt8() { - return ExtractVector(output_values_); + std::vector GetValuesUInt8() { + return ExtractVector(output_values_); } - std::vector GetValuesInt32() { - return ExtractVector(output_values_); + std::vector GetValuesInt32() { + return ExtractVector(output_values_); } std::vector GetValuesInt64() { @@ -119,7 +119,7 @@ TEST(TopKV2OpTest, VectorFloat) { EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(ArrayFloatNear({0.8, 0.2}))); } -// Check that uint8 works. +// Check that uint8_t works. TEST(TopKV2OpTest, TypeUint8) { TopKV2OpModel m({2, 3}, TensorType_UINT8, 2); m.SetInputUInt8({1, 2, 3, 251, 250, 249}); @@ -128,7 +128,7 @@ TEST(TopKV2OpTest, TypeUint8) { EXPECT_THAT(m.GetValuesUInt8(), ElementsAreArray({3, 2, 251, 250})); } -// Check that int32 works. +// Check that int32_t works. TEST(TopKV2OpTest, TypeInt32) { TopKV2OpModel m({2, 3}, TensorType_INT32, 2); m.SetInputInt32({1, 2, 3, 10251, 10250, 10249}); diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc index e83b1ec9879d3c360203a52835d8486d0a9b81bb..a9baa5c6988877ccc2e007e5fefdc980d7a3a679 100644 --- a/tensorflow/contrib/lite/kernels/transpose_conv.cc +++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -119,10 +118,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Currently only support float32. switch (input->type) { case kTfLiteFloat32: - optimized_ops::TransposeConv( + reference_ops::TransposeConv( GetTensorData(input), GetTensorDims(input), GetTensorData(weights), GetTensorDims(weights), stride_width, stride_height, padding_size.width, padding_size.height, + GetTensorData(output), GetTensorDims(output), + // Last two args specify im2col which reference_ops ignores. + // (Note this does not lead to a performance regression, as the + // previous optimized version was just a copy of the reference code.) + // TODO(b/110208176): Allocate im2col tensors and switch to + // optimized_ops. GetTensorData(output), GetTensorDims(output)); break; default: diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc index 1c28123a24edd9886476bf8e9ea3ba4c692baa2b..0acd705950cb262bbb2625aa6143f88b429a6562 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -70,9 +69,21 @@ constexpr int kOutputStateTensor = 0; constexpr int kCellStateTensor = 1; constexpr int kOutputTensor = 2; +// Temporary tensors +enum TemporaryTensor { + kScratchBuffer = 0, + kInputQuantized = 1, + kOutputStateQuantized = 2, + kCellStateQuantized = 3, + kScalingFactors = 4, + kProductScalingFactors = 5, + kRecoveredCellWeights = 6, + kNumTemporaryTensors = 7 +}; + void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* scratch_tensor_index = new int; - context->AddTensors(context, 1, scratch_tensor_index); + context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index); return scratch_tensor_index; } @@ -84,7 +95,7 @@ void Free(TfLiteContext* context, void* buffer) { TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteNode* node, int n_input, int n_output, int n_cell) { - auto* params = reinterpret_cast(node->builtin_data); + const auto* params = reinterpret_cast(node->builtin_data); // Making sure clipping parameters have valid values. // == 0 means no clipping @@ -242,6 +253,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Inferring batch size, number of outputs and sequence length and // number of cells from the input tensors. const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE(context, input->dims->size > 1); const int max_time = input->dims->data[0]; const int n_batch = input->dims->data[1]; @@ -261,7 +273,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int n_output = recurrent_to_output_weights->dims->data[1]; // Check that input tensor dimensions matches with each other. - CheckInputTensorDimensions(context, node, n_input, n_output, n_cell); + TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input, + n_output, n_cell)); // Get the pointer to output, output_state and cell_state buffer tensors. TfLiteTensor* output = GetOutput(context, node, kOutputTensor); @@ -288,86 +301,156 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, cell_state, cell_size)); - // Create a scratch buffer tensor. + // Mark state tensors as persistent tensors. + output_state->allocation_type = kTfLiteArenaRwPersistent; + cell_state->allocation_type = kTfLiteArenaRwPersistent; + + // The weights are of consistent type, so it suffices to check one. + // TODO(mirkov): create a utility/macro for this check, so all Ops can use it. + const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 && + input->type == kTfLiteFloat32); + TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(1); + if (is_hybrid_op) { + node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors); + } else { + node->temporaries = TfLiteIntArrayCreate(1); + } node->temporaries->data[0] = *scratch_tensor_index; - TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); + + // Create a scratch buffer tensor. + TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer); scratch_buffer->type = input->type; scratch_buffer->allocation_type = kTfLiteArenaRw; - // Mark state tensors as persistent tensors. - output_state->allocation_type = kTfLiteArenaRwPersistent; - cell_state->allocation_type = kTfLiteArenaRwPersistent; - const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); const bool use_cifg = (input_to_input_weights == nullptr); + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; if (use_cifg) { - TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); - scratch_buffer_size->data[0] = n_batch; // Reserving space for Cell, Forget, Output gates scratch_buffer_size->data[1] = n_cell * 3; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, - scratch_buffer_size)); } else { - TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); - scratch_buffer_size->data[0] = n_batch; // Reserving space for Input, Cell, Forget, Output gates scratch_buffer_size->data[1] = n_cell * 4; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, - scratch_buffer_size)); + } + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, + scratch_buffer_size)); + + if (is_hybrid_op) { + // Allocate temporary tensors to store quantized values of input, + // output_state and cell_state tensors. + node->temporaries->data[kInputQuantized] = + *scratch_tensor_index + kInputQuantized; + TfLiteTensor* input_quantized = + GetTemporary(context, node, kInputQuantized); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } + node->temporaries->data[kOutputStateQuantized] = + *scratch_tensor_index + kOutputStateQuantized; + TfLiteTensor* output_state_quantized = + GetTemporary(context, node, kOutputStateQuantized); + output_state_quantized->type = kTfLiteUInt8; + output_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(output_state_quantized->dims, + output_state->dims)) { + TfLiteIntArray* output_state_quantized_size = + TfLiteIntArrayCopy(output_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output_state_quantized, + output_state_quantized_size)); + } + node->temporaries->data[kCellStateQuantized] = + *scratch_tensor_index + kCellStateQuantized; + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, kCellStateQuantized); + cell_state_quantized->type = kTfLiteUInt8; + cell_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) { + TfLiteIntArray* cell_state_quantized_size = + TfLiteIntArrayCopy(cell_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, cell_state_quantized, + cell_state_quantized_size)); + } + + // Allocate temporary tensors to store scaling factors and product scaling + // factors. The latter is a convenience storage which allows to quantize + // a vector once (which produces the scaling factors) and multiply it with + // different matrices (which requires multiplying the scaling factors with + // the scaling factor of the matrix). + node->temporaries->data[kScalingFactors] = + *scratch_tensor_index + kScalingFactors; + TfLiteTensor* scaling_factors = + GetTemporary(context, node, kScalingFactors); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } + node->temporaries->data[kProductScalingFactors] = + *scratch_tensor_index + kProductScalingFactors; + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, kProductScalingFactors); + prod_scaling_factors->type = kTfLiteFloat32; + prod_scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1); + prod_scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(prod_scaling_factors->dims, + prod_scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, prod_scaling_factors, + prod_scaling_factors_size)); + } + + // Allocate a temporary tensor to store the recovered cell weights. Since + // this is used for diagonal matrices, only need to store n_cell values. + node->temporaries->data[kRecoveredCellWeights] = + *scratch_tensor_index + kRecoveredCellWeights; + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, kRecoveredCellWeights); + recovered_cell_weights->type = kTfLiteFloat32; + recovered_cell_weights->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1); + recovered_cell_weights_size->data[0] = n_cell; + if (!TfLiteIntArrayEqual(recovered_cell_weights->dims, + recovered_cell_weights_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, recovered_cell_weights, + recovered_cell_weights_size)); + } } return kTfLiteOk; } // The LSTM Op engine. -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - - const TfLiteTensor* input_to_input_weights = - GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); - const TfLiteTensor* input_to_forget_weights = - GetInput(context, node, kInputToForgetWeightsTensor); - const TfLiteTensor* input_to_cell_weights = - GetInput(context, node, kInputToCellWeightsTensor); - const TfLiteTensor* input_to_output_weights = - GetInput(context, node, kInputToOutputWeightsTensor); - - const TfLiteTensor* recurrent_to_input_weights = - GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); - const TfLiteTensor* recurrent_to_forget_weights = - GetInput(context, node, kRecurrentToForgetWeightsTensor); - const TfLiteTensor* recurrent_to_cell_weights = - GetInput(context, node, kRecurrentToCellWeightsTensor); - const TfLiteTensor* recurrent_to_output_weights = - GetInput(context, node, kRecurrentToOutputWeightsTensor); - - const TfLiteTensor* cell_to_input_weights = - GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); - const TfLiteTensor* cell_to_forget_weights = - GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); - const TfLiteTensor* cell_to_output_weights = - GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); - - const TfLiteTensor* input_gate_bias = - GetOptionalInputTensor(context, node, kInputGateBiasTensor); - const TfLiteTensor* forget_gate_bias = - GetInput(context, node, kForgetGateBiasTensor); - const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); - const TfLiteTensor* output_gate_bias = - GetInput(context, node, kOutputGateBiasTensor); - - const TfLiteTensor* projection_weights = - GetOptionalInputTensor(context, node, kProjectionWeightsTensor); - const TfLiteTensor* projection_bias = - GetOptionalInputTensor(context, node, kProjectionBiasTensor); - - TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); - TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - +TfLiteStatus EvalFloat( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, + TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* output) { const int max_time = input->dims->data[0]; const int n_batch = input->dims->data[1]; const int n_input = input->dims->data[2]; @@ -380,8 +463,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const bool use_cifg = (input_to_input_weights == nullptr); const bool use_peephole = (cell_to_output_weights != nullptr); - // Index the scratch buffers pointers to the global scratch buffer. - TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); float* input_gate_scratch = nullptr; float* cell_scratch = nullptr; float* forget_gate_scratch = nullptr; @@ -432,6 +513,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { float* output_state_ptr = output_state->data.f; float* cell_state_ptr = cell_state->data.f; + // Feed the sequence into the LSTM step-by-step. for (int t = 0; t < max_time; t++) { const float* input_ptr_batch = input->data.f + t * n_batch * n_input; float* output_ptr_batch = output->data.f + t * n_batch * n_output; @@ -452,6 +534,262 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, + TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors, + TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, + TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, + TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* output) { + const int max_time = input->dims->data[0]; + const int n_batch = input->dims->data[1]; + const int n_input = input->dims->data[2]; + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existence of only one to get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Check optional tensors, the respective pointers can be null. + int8_t* input_to_input_weights_ptr = nullptr; + float input_to_input_weights_scale = 1.0f; + int8_t* recurrent_to_input_weights_ptr = nullptr; + float recurrent_to_input_weights_scale = 1.0f; + float* input_gate_bias_ptr = nullptr; + if (!use_cifg) { + input_to_input_weights_ptr = + reinterpret_cast(input_to_input_weights->data.uint8); + recurrent_to_input_weights_ptr = + reinterpret_cast(recurrent_to_input_weights->data.uint8); + input_gate_bias_ptr = input_gate_bias->data.f; + input_to_input_weights_scale = input_to_input_weights->params.scale; + recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale; + } + + int8_t* cell_to_input_weights_ptr = nullptr; + int8_t* cell_to_forget_weights_ptr = nullptr; + int8_t* cell_to_output_weights_ptr = nullptr; + float cell_to_input_weights_scale = 1.0f; + float cell_to_forget_weights_scale = 1.0f; + float cell_to_output_weights_scale = 1.0f; + if (use_peephole) { + if (!use_cifg) { + cell_to_input_weights_ptr = + reinterpret_cast(cell_to_input_weights->data.uint8); + cell_to_input_weights_scale = cell_to_input_weights->params.scale; + } + cell_to_forget_weights_ptr = + reinterpret_cast(cell_to_forget_weights->data.uint8); + cell_to_output_weights_ptr = + reinterpret_cast(cell_to_output_weights->data.uint8); + cell_to_forget_weights_scale = cell_to_forget_weights->params.scale; + cell_to_output_weights_scale = cell_to_output_weights->params.scale; + } + + const int8_t* projection_weights_ptr = + (projection_weights == nullptr) + ? nullptr + : reinterpret_cast(projection_weights->data.uint8); + float projection_weights_scale = + (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale; + const float* projection_bias_ptr = + (projection_bias == nullptr) ? nullptr : projection_bias->data.f; + + // Required tensors, pointers are non-null. + const int8_t* input_to_forget_weights_ptr = + reinterpret_cast(input_to_forget_weights->data.uint8); + const float input_to_forget_weights_scale = + input_to_forget_weights->params.scale; + const int8_t* input_to_cell_weights_ptr = + reinterpret_cast(input_to_cell_weights->data.uint8); + const float input_to_cell_weights_scale = input_to_cell_weights->params.scale; + const int8_t* input_to_output_weights_ptr = + reinterpret_cast(input_to_output_weights->data.uint8); + const float input_to_output_weights_scale = + input_to_output_weights->params.scale; + const int8_t* recurrent_to_forget_weights_ptr = + reinterpret_cast(recurrent_to_forget_weights->data.uint8); + const float recurrent_to_forget_weights_scale = + recurrent_to_forget_weights->params.scale; + const int8_t* recurrent_to_cell_weights_ptr = + reinterpret_cast(recurrent_to_cell_weights->data.uint8); + const float recurrent_to_cell_weights_scale = + recurrent_to_cell_weights->params.scale; + const int8_t* recurrent_to_output_weights_ptr = + reinterpret_cast(recurrent_to_output_weights->data.uint8); + const float recurrent_to_output_weights_scale = + recurrent_to_output_weights->params.scale; + const float* forget_gate_bias_ptr = forget_gate_bias->data.f; + const float* cell_bias_ptr = cell_bias->data.f; + const float* output_gate_bias_ptr = output_gate_bias->data.f; + + float* output_state_ptr = output_state->data.f; + float* cell_state_ptr = cell_state->data.f; + + // Temporary storage for quantized values and scaling factors. + int8_t* quantized_input_ptr = + reinterpret_cast(input_quantized->data.uint8); + int8_t* quantized_output_state_ptr = + reinterpret_cast(output_state_quantized->data.uint8); + int8_t* quantized_cell_state_ptr = + reinterpret_cast(cell_state_quantized->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; + float* prod_scaling_factors_ptr = prod_scaling_factors->data.f; + float* recovered_cell_weights_ptr = recovered_cell_weights->data.f; + + // Feed the sequence into the LSTM step-by-step. + for (int t = 0; t < max_time; t++) { + const float* input_ptr_batch = input->data.f + t * n_batch * n_input; + float* output_ptr_batch = output->data.f + t * n_batch * n_output; + + kernel_utils::LstmStep( + input_ptr_batch, input_to_input_weights_ptr, + input_to_input_weights_scale, input_to_forget_weights_ptr, + input_to_forget_weights_scale, input_to_cell_weights_ptr, + input_to_cell_weights_scale, input_to_output_weights_ptr, + input_to_output_weights_scale, recurrent_to_input_weights_ptr, + recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr, + recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr, + recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr, + recurrent_to_output_weights_scale, cell_to_input_weights_ptr, + cell_to_input_weights_scale, cell_to_forget_weights_ptr, + cell_to_forget_weights_scale, cell_to_output_weights_ptr, + cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr, + cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, + projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell, + n_input, n_output, input_gate_scratch, forget_gate_scratch, + cell_scratch, output_gate_scratch, scaling_factors_ptr, + prod_scaling_factors_ptr, recovered_cell_weights_ptr, + quantized_input_ptr, quantized_output_state_ptr, + quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, + output_ptr_batch); + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + + const TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + const TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + const TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + const TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + + const TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + const TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + const TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + const TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + + const TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + const TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + const TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + + const TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + const TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + const TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + + const TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + const TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + + // Index the scratch buffers pointers to the global scratch buffer. + TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); + + TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); + TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (input_to_output_weights->type) { + case kTfLiteFloat32: { + return EvalFloat(input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, + cell_to_output_weights, input_gate_bias, + forget_gate_bias, cell_bias, output_gate_bias, + projection_weights, projection_bias, params, + scratch_buffer, output_state, cell_state, output); + } + case kTfLiteUInt8: { + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + TfLiteTensor* output_state_quantized = + GetTemporary(context, node, /*index=*/2); + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, /*index=*/3); + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, /*index=*/5); + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, /*index=*/6); + return EvalHybrid( + input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, + input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, + projection_weights, projection_bias, params, scratch_buffer, + scaling_factors, prod_scaling_factors, recovered_cell_weights, + input_quantized, output_state_quantized, cell_state_quantized, + output_state, cell_state, output); + } + default: + context->ReportError(context, "Type %d is not currently supported.", + input_to_output_weights->type); + return kTfLiteError; + } + return kTfLiteOk; +} } // namespace unidirectional_sequence_lstm TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() { diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc index 5881ced7c7a616ef2c24db60892cbbf9eec7c42e..de38bdef6fd1b019c7790a664b29cd45d29e5dcc 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ // Unit test for TFLite Sequential LSTM op. -#include #include #include @@ -37,7 +36,8 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { bool use_peephole, bool use_projection_weights, bool use_projection_bias, float cell_clip, float proj_clip, - const std::vector>& input_shapes) + const std::vector>& input_shapes, + const TensorType& weights_type = TensorType_FLOAT32) : n_batch_(n_batch), n_input_(n_input), n_cell_(n_cell), @@ -48,31 +48,31 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { if (use_cifg) { input_to_input_weights_ = AddNullInput(); } else { - input_to_input_weights_ = AddInput(TensorType_FLOAT32); + input_to_input_weights_ = AddInput(weights_type); } - input_to_forget_weights_ = AddInput(TensorType_FLOAT32); - input_to_cell_weights_ = AddInput(TensorType_FLOAT32); - input_to_output_weights_ = AddInput(TensorType_FLOAT32); + input_to_forget_weights_ = AddInput(weights_type); + input_to_cell_weights_ = AddInput(weights_type); + input_to_output_weights_ = AddInput(weights_type); if (use_cifg) { recurrent_to_input_weights_ = AddNullInput(); } else { - recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_input_weights_ = AddInput(weights_type); } - recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); - recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); - recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_forget_weights_ = AddInput(weights_type); + recurrent_to_cell_weights_ = AddInput(weights_type); + recurrent_to_output_weights_ = AddInput(weights_type); if (use_peephole) { if (use_cifg) { cell_to_input_weights_ = AddNullInput(); } else { - cell_to_input_weights_ = AddInput(TensorType_FLOAT32); + cell_to_input_weights_ = AddInput(weights_type); } - cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); - cell_to_output_weights_ = AddInput(TensorType_FLOAT32); + cell_to_forget_weights_ = AddInput(weights_type); + cell_to_output_weights_ = AddInput(weights_type); } else { cell_to_input_weights_ = AddNullInput(); cell_to_forget_weights_ = AddNullInput(); @@ -89,7 +89,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { output_gate_bias_ = AddInput(TensorType_FLOAT32); if (use_projection_weights) { - projection_weights_ = AddInput(TensorType_FLOAT32); + projection_weights_ = AddInput(weights_type); if (use_projection_bias) { projection_bias_ = AddInput(TensorType_FLOAT32); } else { @@ -196,8 +196,9 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { zero_buffer.get() + zero_buffer_size); } - void SetInput(int offset, float* begin, float* end) { - PopulateTensor(input_, offset, begin, end); + void SetInput(int offset, const float* begin, const float* end) { + PopulateTensor(input_, offset, const_cast(begin), + const_cast(end)); } std::vector GetOutput() { return ExtractVector(output_); } @@ -208,7 +209,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { int num_batches() { return n_batch_; } int sequence_length() { return sequence_length_; } - private: + protected: int input_; int input_to_input_weights_; int input_to_forget_weights_; @@ -243,7 +244,183 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { int sequence_length_; }; -TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { +// The hybrid model has quantized weights. +class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel { + public: + HybridUnidirectionalLSTMOpModel( + int n_batch, int n_input, int n_cell, int n_output, int sequence_length, + bool use_cifg, bool use_peephole, bool use_projection_weights, + bool use_projection_bias, float cell_clip, float proj_clip, + const std::vector>& input_shapes) + : UnidirectionalLSTMOpModel( + n_batch, n_input, n_cell, n_output, sequence_length, use_cifg, + use_peephole, use_projection_weights, use_projection_bias, + cell_clip, proj_clip, input_shapes, TensorType_UINT8) {} + + void SetInputToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_output_weights_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(projection_weights_, f); + } +}; + +class BaseLstmTest : public ::testing::Test { + protected: + // Weights of the LSTM model. Some are optional. + std::initializer_list input_to_input_weights_; + std::initializer_list input_to_cell_weights_; + std::initializer_list input_to_forget_weights_; + std::initializer_list input_to_output_weights_; + std::initializer_list input_gate_bias_; + std::initializer_list cell_gate_bias_; + std::initializer_list forget_gate_bias_; + std::initializer_list output_gate_bias_; + std::initializer_list recurrent_to_input_weights_; + std::initializer_list recurrent_to_cell_weights_; + std::initializer_list recurrent_to_forget_weights_; + std::initializer_list recurrent_to_output_weights_; + std::initializer_list cell_to_input_weights_; + std::initializer_list cell_to_forget_weights_; + std::initializer_list cell_to_output_weights_; + std::initializer_list projection_weights_; + + // LSTM input is stored as num_batch x num_inputs vector. + std::vector> lstm_input_; + // LSTM output is stored as num_batch x num_outputs vector. + std::vector> lstm_golden_output_; + + // Compares output up to tolerance to the result of the lstm given the input. + void VerifyGoldens(const std::vector>& input, + const std::vector>& output, + UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5) { + const int num_batches = input.size(); + EXPECT_GT(num_batches, 0); + const int num_inputs = lstm->num_inputs(); + EXPECT_GT(num_inputs, 0); + const int input_sequence_size = input[0].size() / num_inputs; + EXPECT_GT(input_sequence_size, 0); + // Feed the whole sequence as input. + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* batch_start = input[b].data() + i * num_inputs; + const float* batch_end = batch_start + num_inputs; + + lstm->SetInput(((i * num_batches) + b) * lstm->num_inputs(), + batch_start, batch_end); + } + } + + lstm->Invoke(); + + const int num_outputs = lstm->num_outputs(); + EXPECT_GT(num_outputs, 0); + std::vector expected; + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* golden_start_batch = output[b].data() + i * num_outputs; + const float* golden_end_batch = golden_start_batch + num_outputs; + + expected.insert(expected.end(), golden_start_batch, golden_end_batch); + } + } + + EXPECT_THAT(lstm->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } +}; + +class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}; + input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, -0.29909778}; + input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}; + input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, + -0.1556896, 0.19487578}; + input_gate_bias_ = {0., 0., 0., 0.}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_input_weights_ = { + -0.0063535, -0.2042388, 0.31454784, -0.35746509, + 0.28902304, 0.08183324, -0.16555229, 0.02286911, + -0.13566875, 0.03034258, 0.48091322, -0.12528998, + 0.24077177, -0.51332325, -0.33502164, 0.10629296}; + + recurrent_to_cell_weights_ = { + -0.3407414, 0.24443203, -0.2078532, 0.26320225, + 0.05695659, -0.00123841, -0.4744786, -0.35869038, + -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}; + + recurrent_to_forget_weights_ = { + -0.48684245, -0.06655136, 0.42224967, 0.2112639, + 0.27654213, 0.20864892, -0.07646349, 0.45877004, + 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}; + + recurrent_to_output_weights_ = { + 0.43385774, -0.17194885, 0.2718237, 0.09215671, + 0.24107647, -0.39835793, 0.18212086, 0.01301402, + 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765, + -0.03716109, 0.12507336, 0.41193449, -0.20860538, + -0.15053082, 0.09120187, 0.24278517, -0.12222792}}; + } +}; + +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -252,9 +429,11 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { const int sequence_length = 3; UnidirectionalLSTMOpModel lstm( - n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, - /*use_peephole=*/false, /*use_projection_weights=*/false, - /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + n_batch, n_input, n_cell, n_output, sequence_length, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, { {sequence_length, n_batch, n_input}, // input tensor @@ -281,77 +460,138 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, - -0.34550029, 0.04266912, -0.15680569, - -0.34856534, 0.43890524}); + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, - -0.20583314, 0.44344562, 0.22077113, - -0.29909778}); + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, - -0.31343272, -0.40032279, 0.44781327, - 0.01387155, -0.35593212}); + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, - 0.40525138, 0.44272184, 0.03897077, -0.1556896, - 0.19487578}); +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + const int sequence_length = 3; - lstm.SetInputGateBias({0., 0., 0., 0.}); + HybridUnidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor - lstm.SetCellBias({0., 0., 0., 0.}); + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor - lstm.SetForgetGateBias({1., 1., 1., 1.}); + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor - lstm.SetOutputGateBias({0., 0., 0., 0.}); + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor - lstm.SetRecurrentToInputWeights( - {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, - -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, - -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor - lstm.SetRecurrentToCellWeights( - {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, - -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, - -0.46367589, 0.26016325, -0.03894562, -0.16368064}); + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); - lstm.SetRecurrentToForgetWeights( - {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, - -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, - 0.28053468, 0.01560611, -0.20127171, -0.01140004}); + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetRecurrentToOutputWeights( - {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, - 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, - -0.51818722, -0.15390486, 0.0468148, 0.39922136}); + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - // Input should have n_input * sequence_length many values. - static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; - static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126, - -0.15358765, -0.03716109, 0.12507336, - 0.41193449, -0.20860538, -0.15053082, - 0.09120187, 0.24278517, -0.12222792}; + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - float* batch0_start = lstm_input; - float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, + /*tolerance=*/0.0157651); +} - lstm.SetInput(0, batch0_start, batch0_end); +class CifgPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726, + 0.05100781, 0.04717243, 0.48944736, + -0.38535351, -0.17212132}; - lstm.Invoke(); + input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, + 0.24407166, 0.33826375}; - float* golden_start = lstm_golden_output; - float* golden_end = - golden_start + lstm.num_outputs() * lstm.sequence_length(); - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); -} + input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_cell_weights_ = { + 0.54066205, -0.32668582, -0.43562764, -0.56094903, + 0.42957711, 0.01841056, -0.32764608, -0.33027974, + -0.10826075, 0.20675004, 0.19069612, -0.03026325, + -0.54532051, 0.33003211, 0.44901288, 0.21193194}; + + recurrent_to_forget_weights_ = { + -0.13832897, -0.0515101, -0.2359007, -0.16661474, + -0.14340827, 0.36986142, 0.23414481, 0.55899, + 0.10798943, -0.41174671, 0.17751795, -0.34484994, + -0.35874045, -0.11352962, 0.27268326, 0.54058349}; + + recurrent_to_output_weights_ = { + 0.41613156, 0.42610586, -0.16495961, -0.5663873, + 0.30579174, -0.05115908, -0.33941799, 0.23364776, + 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}; + + cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408, + 0.31544167}; + cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703, + -0.77109635}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646, + -0.42312205, -0.01218222, 0.24201041, -0.08124574, + -0.358325, -0.04621704, 0.21641694, -0.06471302}}; + } +}; -TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { +TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -360,9 +600,11 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { const int sequence_length = 3; UnidirectionalLSTMOpModel lstm( - n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true, - /*use_peephole=*/true, /*use_projection_weights=*/false, - /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + n_batch, n_input, n_cell, n_output, sequence_length, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, { {sequence_length, n_batch, n_input}, // input tensor @@ -389,71 +631,690 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, - 0.04717243, 0.48944736, -0.38535351, - -0.17212132}); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, - -0.3633365, -0.22755712, 0.28253698, 0.24407166, - 0.33826375}); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, - -0.09426838, -0.44257352, 0.54939759, - 0.01533556, 0.42751634}); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} + +TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + const int sequence_length = 3; + + HybridUnidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor - lstm.SetCellBias({0., 0., 0., 0.}); + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor - lstm.SetForgetGateBias({1., 1., 1., 1.}); + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor - lstm.SetOutputGateBias({0., 0., 0., 0.}); + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor - lstm.SetRecurrentToCellWeights( - {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, - 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, - 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, - 0.21193194}); + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); - lstm.SetRecurrentToForgetWeights( - {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, - 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, - -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetRecurrentToOutputWeights( - {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, - -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, - 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetCellToForgetWeights( - {0.47485286, -0.51955009, -0.24458408, 0.31544167}); - lstm.SetCellToOutputWeights( - {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; - static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585, - -0.05163646, -0.42312205, -0.01218222, - 0.24201041, -0.08124574, -0.358325, - -0.04621704, 0.21641694, -0.06471302}; + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - float* batch0_start = lstm_input; - float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); - - lstm.SetInput(0, batch0_start, batch0_end); - - lstm.Invoke(); - - float* golden_start = lstm_golden_output; - float* golden_end = - golden_start + lstm.num_outputs() * lstm.sequence_length(); - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); } -TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { +class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = { + 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, + -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}; + + input_to_forget_weights_ = { + -0.0018401089, -0.004852237, 0.03698424, 0.014181704, + 0.028273236, -0.016726194, -0.05249759, -0.10204261, + 0.00861066, -0.040979505, -0.009899187, 0.01923892, + -0.028177269, -0.08535103, -0.14585495, 0.10662567, + -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, + 0.0814421, -0.12257899, -0.033945758, -0.031303465, + 0.045630626, 0.06843887, -0.13492945, -0.012480007, + -0.0811829, -0.07224499, -0.09628791, 0.045100946, + 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, + 0.052625068, 0.12784666, 0.07077897, 0.025725935, + 0.04165009, 0.07241905, 0.018668644, -0.037377294, + -0.06277783, -0.08833636, -0.040120605, -0.011405586, + -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, + 0.13396506, -0.08402166, -0.01901462, -0.044678304, + -0.07720565, 0.014350063, -0.11757958, -0.0652038, + -0.08185733, -0.076754324, -0.092614375, 0.10405491, + 0.052960336, 0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, + -0.054523353, 0.02582715, 0.02327355, -0.011857179, + -0.0011980024, -0.034641717, -0.026125094, -0.17582615, + -0.15923657, -0.27486774, -0.0006143371, 0.0001771948, + -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}; + + input_to_cell_weights_ = { + -0.04580283, -0.09549462, -0.032418985, -0.06454633, + -0.043528453, 0.043018587, -0.049152344, -0.12418144, + -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, + -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, -0.036816437, -0.02130134, -0.016518239, + 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, + -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, + 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, + 0.053568836, 0.06408714, 0.12835667, -0.008714329, + -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, + -0.036999565, -0.028842626, -0.0033637602, -0.017012902, + -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, + -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, + 0.07463894, 0.0075130584, 0.012850982, 0.04555431, + 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}; + + input_to_output_weights_ = { + -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}; + + input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666, + 0.053110216, -0.06928846, -0.13942584, -0.11816189, + 0.19483899, 0.03652339, -0.10250295, 0.036714908, + -0.18426876, 0.036065217, 0.21810818, 0.02383196, + -0.043370757, 0.08690144, -0.04444982, 0.00030581196}; + + forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}; + + cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}; + + output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113, + 0.027195795, 0.35373217, -0.018957434, 0.008907322, + -0.0762701, 0.12018895, 0.04216877, 0.0022856654, + 0.040952638, 0.3147856, 0.08225149, -0.057416286, + -0.14995944, -0.008040261, 0.13208859, 0.029760877}; + + recurrent_to_input_weights_ = { + -0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, -0.06402044, 0.062524505, + -0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + -0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, -0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}; + + recurrent_to_cell_weights_ = { + -0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, -0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + -0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, -0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, -0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, -0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}; + + recurrent_to_forget_weights_ = { + -0.057784554, -0.026057621, -0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, -0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, -0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + -0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, -0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}; + + recurrent_to_output_weights_ = { + 0.025825322, -0.05813119, 0.09495884, -0.045984812, + -0.01255415, -0.0026479573, -0.08196161, -0.054914974, + -0.0046604523, -0.029587349, -0.044576716, -0.07480124, + -0.082868785, 0.023254942, 0.027502948, -0.0039728214, + -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, + -0.08829125, -0.005139627, -0.08989442, -0.0555066, + 0.13596267, -0.025062224, -0.048351806, -0.03850004, + 0.07266485, -0.022414139, 0.05940088, 0.075114764, + 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, + 0.014416728, 0.043229222, 0.034178585, -0.07530371, + 0.035837382, -0.085607, -0.007721233, -0.03287832, + -0.043848954, -0.06404588, -0.06632928, -0.073643476, + 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, + -0.030063879, 0.008801774, -0.023021035, -0.019558564, + 0.05158114, -0.010947698, -0.011825728, 0.0075720972, + 0.0699727, -0.0039981045, 0.069350146, 0.08799282, + 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, + -0.00924166, 0.0046702605, -0.036598757, -0.08811812, + 0.10522024, -0.032441203, 0.008176899, -0.04454919, + 0.07058152, 0.0067963637, 0.039206743, 0.03259838, + 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, + 0.036879618, 0.043357447, 0.028362012, -0.05908629, + 0.0059240665, -0.04995891, -0.019187413, 0.0276265, + -0.01628143, 0.0025863599, 0.08800015, 0.035250366, + -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, + -0.009660886, 0.019076364, 0.018299393, -0.046004917, + 0.08891175, 0.0431396, -0.026327137, -0.051502608, + 0.08979574, -0.051670972, 0.04940282, -0.07491107, + -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, + -0.035936575, -0.011681591, 0.064818054, 0.0073146066, + -0.021745546, -0.043124277, -0.06471268, -0.07053354, + -0.029321948, -0.05330136, 0.016933719, -0.053782392, + 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, + -0.07924483, 0.06936997, 0.0034815092, -0.007305279, + -0.037325785, -0.07251102, -0.033633437, -0.08677009, + 0.091591336, -0.14165086, 0.021752775, 0.019683983, + 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, + 0.1183656, -0.0010731248, -0.023590032, -0.072285876, + -0.0724771, -0.026382286, -0.0014920527, 0.042667855, + 0.0018776858, 0.02986552, 0.009814309, 0.0733756, + 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, + -0.010036754, 0.02576849, -0.08307328, 0.010112348, + 0.042521734, -0.05869831, -0.071689695, 0.03876447, + -0.13275425, -0.0352966, -0.023077697, 0.10285965, + 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, + -0.08271222, -0.0030240538, -0.016368777, 0.1070414, + 0.042672627, 0.013456989, -0.0437609, -0.022309763, + 0.11576483, 0.04108048, 0.061026827, -0.0190714, + -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, + -0.023771819, -0.01965048, 0.007955471, -0.043740474, + 0.03346837, -0.10549954, 0.090567775, 0.042013682, + -0.03176985, 0.12569028, -0.02421228, -0.029526481, + 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, + -0.06861939, -0.021256343, -0.041093912, -0.06669611, + 0.035498552, 0.021757556, -0.09302526, -0.015403468, + -0.06614931, -0.051798206, -0.013874718, 0.03630673, + 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, + -0.020674974, -0.03944324, -0.008110165, -0.11113267, + 0.08484226, 0.043586485, 0.040582247, 0.0968012, + -0.065249965, -0.028036479, 0.0050708856, 0.0017462453, + 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, + -0.11768019, 0.085926116, -0.08251791, -0.045081906, + 0.0948852, 0.068401024, 0.024856757, 0.06978981, + -0.057309967, -0.012775832, -0.0032452994, 0.01977615, + -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }; + + cell_to_input_weights_ = { + 0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}; + + cell_to_forget_weights_ = { + -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}; + + cell_to_output_weights_ = { + 0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, + 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}; + + projection_weights_ = { + -0.009802181, 0.09401916, 0.0717386, -0.13895074, + 0.09641832, 0.060420845, 0.08539281, 0.054285463, + 0.061395317, 0.034448683, -0.042991187, 0.019801661, + -0.16840284, -0.015726732, -0.23041931, -0.024478018, + -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, + 0.16169067, 0.22465782, -0.03993472, -0.004017731, + 0.08633481, -0.28869787, 0.08682067, 0.17240396, + 0.014975425, 0.056431185, 0.031037588, 0.16702051, + 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, + 0.014777949, -0.20203483, 0.094781205, 0.19100232, + 0.13987629, -0.036132768, -0.06426278, -0.05108664, + 0.13221376, 0.009441198, -0.16715929, 0.15859416, + -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, + 0.0046453946, 0.050794356, 0.10770313, -0.20790008, + -0.07149004, -0.11425117, 0.008225835, -0.035802525, + 0.14374903, 0.15262283, 0.048710253, 0.1847461, + -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, + 0.016261552, 0.022461696, 0.12689082, -0.043589946, + -0.12035478, -0.08361797, -0.050666027, -0.1248618, + -0.1275799, -0.071875185, 0.07377272, 0.09944291, + -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, + 0.041546922, -0.20424393, 0.06907816, 0.050412357, + 0.00724631, 0.039827548, 0.12449835, 0.10747581, + 0.13708383, 0.09134148, -0.12617786, -0.06428341, + 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, + 0.022913318, -0.042050496, 0.16842307, -0.060597885, + 0.10531834, -0.06411776, -0.07451711, -0.03410368, + -0.13393489, 0.06534304, 0.003620307, 0.04490757, + 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, -0.07260381, 0.047201227, + -0.024575593, -0.036445823, 0.07155557, 0.009672501, + -0.02328883, 0.009533515, -0.03606021, -0.07421458, + -0.028082801, -0.2678904, -0.13221288, 0.18419984, + -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, + 0.14494097, -0.12522776, -0.098633975, -0.10766018, + -0.08317623, 0.08594209, 0.07749552, 0.039474737, + 0.1776665, -0.07409566, -0.0477268, 0.29323658, + 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, + 0.10034707, 0.045594677, 0.0635285, -0.0715442, + -0.089667566, -0.10811871, 0.00026344223, 0.08298446, + -0.009525053, 0.006585689, -0.24567553, -0.09450807, + 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, + 0.067035615, 0.19271925, -0.0032889997, -0.043264326, + 0.09663576, -0.057112187, -0.10100678, 0.0628376, + 0.04447668, 0.017961001, -0.10094388, -0.10190601, + 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, + -0.1947263, 0.02251204, 0.11216432, -0.10307853, + 0.17351969, -0.039091777, 0.08066188, -0.00561982, + 0.12633002, 0.11335965, -0.0088127935, -0.019777594, + 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, + -0.07468996, -0.0855457, 0.099339016, -0.07580735, + -0.13775392, 0.08434318, 0.08330512, -0.12131499, + 0.031935584, 0.09180414, -0.08876437, -0.08049874, + 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, + 0.04331237, 0.04299654, -0.036394123, -0.12915532, + 0.09793732, 0.07512415, -0.11319543, -0.032502122, + 0.15661901, 0.07671967, -0.005491124, -0.19379048, + -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, + -0.09334311, 0.15026465, -0.15493552, -0.057762887, + -0.11604192, -0.262013, -0.01391798, 0.012185008, + 0.11156489, -0.07483202, 0.06693364, -0.26151478, + 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, + 0.030635102, 0.010969227, 0.11109743, 0.010919218, + 0.027526086, 0.13519906, 0.01891392, -0.046839405, + -0.040167913, 0.017953383, -0.09700955, 0.0061885654, + -0.07000971, 0.026893595, -0.038844477, 0.14543656}; + + lstm_input_ = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0 + 0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1 + 0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2 + 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3 + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0 + 0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1 + 0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2 + 0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3 + }; + + lstm_golden_output_ = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, + -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, + -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, + 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, + -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, + -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, + 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, + 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, + 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, + -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, + -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, + -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, + 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, + 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, + -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, + -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, + 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, + 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, + 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, + -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, + -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + } +}; + +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 2; const int n_input = 5; const int n_cell = 20; @@ -461,8 +1322,9 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { const int sequence_length = 4; UnidirectionalLSTMOpModel lstm( - n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, - /*use_peephole=*/true, /*use_projection_weights=*/true, + n_batch, n_input, n_cell, n_output, sequence_length, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, { @@ -491,588 +1353,99 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToInputWeights( - {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, - 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, - -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, - -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, - -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, - -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, - -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, - 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, - 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, - 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, - -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, - 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, - -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, - -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, - -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, - 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, - -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, - -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, - -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, - -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}); - - lstm.SetInputToForgetWeights( - {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, - -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, - -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, - 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, - 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, - -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, - -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, - 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, - 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, - 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, - 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, - -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, - 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, - -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, - -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, - 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553, - 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, - 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, - -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, - 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); - - lstm.SetInputToCellWeights( - {-0.04580283, -0.09549462, -0.032418985, -0.06454633, - -0.043528453, 0.043018587, -0.049152344, -0.12418144, - -0.078985475, -0.07596889, 0.019484362, -0.11434962, - -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, - -0.025034338, -0.0028890965, 0.048929527, 0.06235075, - 0.10665918, -0.032036792, -0.08505916, -0.10843358, - -0.13002433, -0.036816437, -0.02130134, -0.016518239, - 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, - -0.10652836, -0.1037554, -0.13056071, -0.03266643, - -0.033702414, -0.006473424, -0.04611692, 0.014419339, - -0.025174323, 0.0396852, 0.081777506, 0.06157468, - 0.10210095, -0.009658194, 0.046511717, 0.03603906, - 0.0069369148, 0.015960095, -0.06507666, 0.09551598, - 0.053568836, 0.06408714, 0.12835667, -0.008714329, - -0.20211966, -0.12093674, 0.029450472, 0.2849013, - -0.029227901, 0.1164364, -0.08560263, 0.09941786, - -0.036999565, -0.028842626, -0.0033637602, -0.017012902, - -0.09720865, -0.11193351, -0.029155117, -0.017936034, - -0.009768936, -0.04223324, -0.036159635, 0.06505112, - -0.021742892, -0.023377212, -0.07221364, -0.06430552, - 0.05453865, 0.091149814, 0.06387331, 0.007518393, - 0.055960953, 0.069779344, 0.046411168, 0.10509911, - 0.07463894, 0.0075130584, 0.012850982, 0.04555431, - 0.056955688, 0.06555285, 0.050801456, -0.009862683, - 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); - - lstm.SetInputToOutputWeights( - {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, - -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, - 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, - -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, - -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, - 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, - -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, - -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, - -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, - -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, - 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, - 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, - 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, - -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, - 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, - 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, - -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, - 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, - -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, - -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}); - - lstm.SetInputGateBias( - {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, - -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, - -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, - 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); - - lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, - 0.11098921, 0.15378423, 0.09263801, 0.09790885, - 0.09508917, 0.061199076, 0.07665568, -0.015443159, - -0.03499149, 0.046190713, 0.08895977, 0.10899629, - 0.40694186, 0.06030037, 0.012413437, -0.06108739}); - - lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, - -0.1483596, -0.10639995, -0.091433935, 0.058573797, - -0.06809782, -0.07889636, -0.043246906, -0.09829136, - -0.4279842, 0.034901652, 0.18797937, 0.0075234566, - 0.016178843, 0.1749513, 0.13975595, 0.92058027}); - - lstm.SetOutputGateBias( - {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, - 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, - 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, - -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); - - lstm.SetRecurrentToInputWeights( - {-0.001374326, -0.078856036, 0.10672688, 0.029162422, - -0.11585556, 0.02557986, -0.13446963, -0.035785314, - -0.01244275, 0.025961924, -0.02337298, -0.044228926, - -0.055839065, -0.046598054, -0.010546039, -0.06900766, - 0.027239809, 0.022582639, -0.013296484, -0.05459212, - 0.08981, -0.045407712, 0.08682226, -0.06867011, - -0.14390695, -0.02916037, 0.000996957, 0.091420636, - 0.14283475, -0.07390571, -0.06402044, 0.062524505, - -0.093129106, 0.04860203, -0.08364217, -0.08119002, - 0.009352075, 0.22920375, 0.0016303885, 0.11583097, - -0.13732095, 0.012405723, -0.07551853, 0.06343048, - 0.12162708, -0.031923793, -0.014335606, 0.01790974, - -0.10650317, -0.0724401, 0.08554849, -0.05727212, - 0.06556731, -0.042729504, -0.043227166, 0.011683251, - -0.013082158, -0.029302018, -0.010899579, -0.062036745, - -0.022509435, -0.00964907, -0.01567329, 0.04260106, - -0.07787477, -0.11576462, 0.017356863, 0.048673786, - -0.017577527, -0.05527947, -0.082487635, -0.040137455, - -0.10820036, -0.04666372, 0.022746278, -0.07851417, - 0.01068115, 0.032956902, 0.022433773, 0.0026891115, - 0.08944216, -0.0685835, 0.010513544, 0.07228705, - 0.02032331, -0.059686817, -0.0005566496, -0.086984694, - 0.040414046, -0.1380399, 0.094208956, -0.05722982, - 0.012092817, -0.04989123, -0.086576, -0.003399834, - -0.04696032, -0.045747425, 0.10091314, 0.048676282, - -0.029037097, 0.031399418, -0.0040285117, 0.047237843, - 0.09504992, 0.041799378, -0.049185462, -0.031518843, - -0.10516937, 0.026374253, 0.10058866, -0.0033195973, - -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, - -0.10167381, 0.042500053, -0.01447153, 0.06464186, - -0.017142897, 0.03312627, 0.009205989, 0.024138335, - -0.011337001, 0.035530265, -0.010912711, 0.0706555, - -0.005894094, 0.051841937, -0.1401738, -0.02351249, - 0.0365468, 0.07590991, 0.08838724, 0.021681072, - -0.10086113, 0.019608743, -0.06195883, 0.077335775, - 0.023646897, -0.095322326, 0.02233014, 0.09756986, - -0.048691444, -0.009579111, 0.07595467, 0.11480546, - -0.09801813, 0.019894179, 0.08502348, 0.004032281, - 0.037211012, 0.068537936, -0.048005626, -0.091520436, - -0.028379958, -0.01556313, 0.06554592, -0.045599163, - -0.01672207, -0.020169014, -0.011877351, -0.20212261, - 0.010889619, 0.0047078193, 0.038385306, 0.08540671, - -0.017140968, -0.0035865551, 0.016678626, 0.005633034, - 0.015963363, 0.00871737, 0.060130805, 0.028611384, - 0.10109069, -0.015060172, -0.07894427, 0.06401885, - 0.011584063, -0.024466386, 0.0047652307, -0.09041358, - 0.030737216, -0.0046374933, 0.14215417, -0.11823516, - 0.019899689, 0.006106124, -0.027092824, 0.0786356, - 0.05052217, -0.058925, -0.011402121, -0.024987547, - -0.0013661642, -0.06832946, -0.015667673, -0.1083353, - -0.00096863037, -0.06988685, -0.053350925, -0.027275559, - -0.033664223, -0.07978348, -0.025200296, -0.017207067, - -0.058403496, -0.055697463, 0.005798788, 0.12965427, - -0.062582195, 0.0013350133, -0.10482091, 0.0379771, - 0.072521195, -0.0029455067, -0.13797039, -0.03628521, - 0.013806405, -0.017858358, -0.01008298, -0.07700066, - -0.017081132, 0.019358726, 0.0027079724, 0.004635139, - 0.062634714, -0.02338735, -0.039547626, -0.02050681, - 0.03385117, -0.083611414, 0.002862572, -0.09421313, - 0.058618143, -0.08598433, 0.00972939, 0.023867095, - -0.053934585, -0.023203006, 0.07452513, -0.048767887, - -0.07314807, -0.056307215, -0.10433547, -0.06440842, - 0.04328182, 0.04389765, -0.020006588, -0.09076438, - -0.11652589, -0.021705797, 0.03345259, -0.010329105, - -0.025767034, 0.013057034, -0.07316461, -0.10145612, - 0.06358255, 0.18531723, 0.07759293, 0.12006465, - 0.1305557, 0.058638252, -0.03393652, 0.09622831, - -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, - -0.005644518, 0.06857898, -0.12598175, -0.035084512, - 0.03156317, -0.12794146, -0.031963028, 0.04692781, - 0.030070418, 0.0071660685, -0.095516115, -0.004643372, - 0.040170413, -0.062104587, -0.0037324072, 0.0554317, - 0.08184801, -0.019164372, 0.06791302, 0.034257166, - -0.10307039, 0.021943003, 0.046745934, 0.0790918, - -0.0265588, -0.007824208, 0.042546265, -0.00977924, - -0.0002440307, -0.017384544, -0.017990116, 0.12252321, - -0.014512694, -0.08251313, 0.08861942, 0.13589665, - 0.026351685, 0.012641483, 0.07466548, 0.044301085, - -0.045414884, -0.051112458, 0.03444247, -0.08502782, - -0.04106223, -0.028126027, 0.028473156, 0.10467447}); - - lstm.SetRecurrentToForgetWeights( - {-0.057784554, -0.026057621, -0.068447545, -0.022581743, - 0.14811787, 0.10826372, 0.09471067, 0.03987225, - -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, - 0.08414449, -0.022036452, -0.00066928595, -0.09203576, - 0.032950465, -0.10985798, -0.023809856, 0.0021431844, - -0.02196096, -0.00326074, 0.00058621005, -0.074678116, - -0.06193199, 0.055729095, 0.03736828, 0.020123724, - 0.061878487, -0.04729229, 0.034919553, -0.07585433, - -0.04421272, -0.044019096, 0.085488975, 0.04058006, - -0.06890133, -0.030951202, -0.024628663, -0.07672815, - 0.034293607, 0.08556707, -0.05293577, -0.033561368, - -0.04899627, 0.0241671, 0.015736353, -0.095442444, - -0.029564252, 0.016493602, -0.035026584, 0.022337519, - -0.026871363, 0.004780428, 0.0077918363, -0.03601621, - 0.016435321, -0.03263031, -0.09543275, -0.047392778, - 0.013454138, 0.028934088, 0.01685226, -0.086110644, - -0.046250615, -0.01847454, 0.047608484, 0.07339695, - 0.034546845, -0.04881143, 0.009128804, -0.08802852, - 0.03761666, 0.008096139, -0.014454086, 0.014361001, - -0.023502491, -0.0011840804, -0.07607001, 0.001856849, - -0.06509276, -0.006021153, -0.08570962, -0.1451793, - 0.060212336, 0.055259194, 0.06974018, 0.049454916, - -0.027794661, -0.08077226, -0.016179763, 0.1169753, - 0.17213494, -0.0056326236, -0.053934924, -0.0124349, - -0.11520337, 0.05409887, 0.088759385, 0.0019655675, - 0.0042065294, 0.03881498, 0.019844765, 0.041858196, - -0.05695512, 0.047233116, 0.038937137, -0.06542224, - 0.014429736, -0.09719407, 0.13908425, -0.05379757, - 0.012321099, 0.082840554, -0.029899208, 0.044217527, - 0.059855383, 0.07711018, -0.045319796, 0.0948846, - -0.011724666, -0.0033288454, -0.033542685, -0.04764985, - -0.13873616, 0.040668588, 0.034832682, -0.015319203, - -0.018715994, 0.046002675, 0.0599172, -0.043107376, - 0.0294216, -0.002314414, -0.022424703, 0.0030315618, - 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, - 0.12375372, -0.0006038222, 0.029104086, 0.087442465, - 0.052958444, 0.07558703, 0.04817258, 0.044462286, - -0.015213451, -0.08783778, -0.0561384, -0.003008196, - 0.047060397, -0.002058388, 0.03429439, -0.018839769, - 0.024734668, 0.024614193, -0.042046934, 0.09597743, - -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, - -0.02558259, -0.022822596, -0.023273505, -0.02464396, - -0.10991725, -0.006240552, 0.0074488563, 0.024044557, - 0.04383914, -0.046476185, 0.028658995, 0.060410924, - 0.050786525, 0.009452605, -0.0073054377, -0.024810238, - 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, - 0.015898481, 0.021362653, -0.030262267, 0.016587038, - -0.011442813, 0.041154444, -0.007631438, -0.03423484, - -0.010977775, 0.036152758, 0.0066366293, 0.11915515, - 0.02318443, -0.041350313, 0.021485701, -0.10906167, - -0.028218046, -0.00954771, 0.020531068, -0.11995105, - -0.03672871, 0.024019798, 0.014255957, -0.05221243, - -0.00661567, -0.04630967, 0.033188973, 0.10107534, - -0.014027541, 0.030796422, -0.10270911, -0.035999842, - 0.15443139, 0.07684145, 0.036571592, -0.035900835, - -0.0034699554, 0.06209149, 0.015920248, -0.031122351, - -0.03858649, 0.01849943, 0.13872518, 0.01503974, - 0.069941424, -0.06948533, -0.0088794185, 0.061282158, - -0.047401894, 0.03100163, -0.041533746, -0.10430945, - 0.044574402, -0.01425562, -0.024290353, 0.034563623, - 0.05866852, 0.023947537, -0.09445152, 0.035450947, - 0.02247216, -0.0042998926, 0.061146557, -0.10250651, - 0.020881841, -0.06747029, 0.10062043, -0.0023941975, - 0.03532124, -0.016341697, 0.09685456, -0.016764693, - 0.051808182, 0.05875331, -0.04536488, 0.001626336, - -0.028892258, -0.01048663, -0.009793449, -0.017093895, - 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, - -0.001845119, -0.03551521, 0.0018358806, 0.05763657, - -0.01769146, 0.040995963, 0.02235177, -0.060430344, - 0.11475477, -0.023854522, 0.10071741, 0.0686208, - -0.014250481, 0.034261297, 0.047418304, 0.08562733, - -0.030519066, 0.0060542435, 0.014653856, -0.038836084, - 0.04096551, 0.032249358, -0.08355519, -0.026823482, - 0.056386515, -0.010401743, -0.028396193, 0.08507674, - 0.014410365, 0.020995233, 0.17040324, 0.11511526, - 0.02459721, 0.0066619175, 0.025853224, -0.023133837, - -0.081302024, 0.017264642, -0.009585969, 0.09491168, - -0.051313367, 0.054532815, -0.014298593, 0.10657464, - 0.007076659, 0.10964551, 0.0409152, 0.008275321, - -0.07283536, 0.07937492, 0.04192024, -0.1075027}); - - lstm.SetRecurrentToCellWeights( - {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, - 0.055647098, -0.05713207, -0.05626563, 0.005559383, - 0.03375411, -0.025757805, -0.088049285, 0.06017052, - -0.06570978, 0.007384076, 0.035123326, -0.07920549, - 0.053676967, 0.044480428, -0.07663568, 0.0071805613, - 0.08089997, 0.05143358, 0.038261272, 0.03339287, - -0.027673481, 0.044746667, 0.028349208, 0.020090483, - -0.019443132, -0.030755889, -0.0040000007, 0.04465846, - -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, - -0.10893326, 0.076739706, -0.08509834, -0.027997585, - 0.037871376, 0.01449768, -0.09002357, -0.06111149, - -0.046195522, 0.0422062, -0.005683705, -0.1253618, - -0.012925729, -0.04890792, 0.06985068, 0.037654128, - 0.03398274, -0.004781977, 0.007032333, -0.031787455, - 0.010868644, -0.031489216, 0.09525667, 0.013939797, - 0.0058680447, 0.0167067, 0.02668468, -0.04797466, - -0.048885044, -0.12722108, 0.035304096, 0.06554885, - 0.00972396, -0.039238118, -0.05159735, -0.11329045, - 0.1613692, -0.03750952, 0.06529313, -0.071974665, - -0.11769596, 0.015524369, -0.0013754242, -0.12446318, - 0.02786344, -0.014179351, 0.005264273, 0.14376344, - 0.015983658, 0.03406988, -0.06939408, 0.040699873, - 0.02111075, 0.09669095, 0.041345075, -0.08316494, - -0.07684199, -0.045768797, 0.032298047, -0.041805092, - 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, - -0.024950314, 0.11574242, 0.04508852, -0.04335324, - 0.06760663, -0.027437469, 0.07216407, 0.06977076, - -0.05438599, 0.034033038, -0.028602652, 0.05346137, - 0.043184172, -0.037189785, 0.10420091, 0.00882477, - -0.054019816, -0.074273005, -0.030617684, -0.0028467078, - 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, - 0.04361412, -0.007001822, 0.09631092, -0.06702025, - -0.042049985, -0.035070654, -0.04103342, -0.10273396, - 0.0544271, 0.037184782, -0.13150354, -0.0058036847, - -0.008264958, 0.042035464, 0.05891794, 0.029673764, - 0.0063542654, 0.044788733, 0.054816857, 0.062257513, - -0.00093483756, 0.048938446, -0.004952862, -0.007730018, - -0.04043371, -0.017094059, 0.07229206, -0.023670016, - -0.052195564, -0.025616996, -0.01520939, 0.045104615, - -0.007376126, 0.003533447, 0.006570588, 0.056037236, - 0.12436656, 0.051817212, 0.028532185, -0.08686856, - 0.11868599, 0.07663395, -0.07323171, 0.03463402, - -0.050708205, -0.04458982, -0.11590894, 0.021273347, - 0.1251325, -0.15313013, -0.12224372, 0.17228661, - 0.023029093, 0.086124025, 0.006445803, -0.03496501, - 0.028332196, 0.04449512, -0.042436164, -0.026587414, - -0.006041347, -0.09292539, -0.05678812, 0.03897832, - 0.09465633, 0.008115513, -0.02171956, 0.08304309, - 0.071401566, 0.019622514, 0.032163795, -0.004167056, - 0.02295182, 0.030739572, 0.056506045, 0.004612461, - 0.06524936, 0.059999723, 0.046395954, -0.0045512207, - -0.1335546, -0.030136576, 0.11584653, -0.014678886, - 0.0020118146, -0.09688814, -0.0790206, 0.039770417, - -0.0329582, 0.07922767, 0.029322514, 0.026405897, - 0.04207835, -0.07073373, 0.063781224, 0.0859677, - -0.10925287, -0.07011058, 0.048005477, 0.03438226, - -0.09606514, -0.006669445, -0.043381985, 0.04240257, - -0.06955775, -0.06769346, 0.043903265, -0.026784198, - -0.017840602, 0.024307009, -0.040079936, -0.019946516, - 0.045318738, -0.12233574, 0.026170589, 0.0074471775, - 0.15978073, 0.10185836, 0.10298046, -0.015476589, - -0.039390966, -0.072174534, 0.0739445, -0.1211869, - -0.0347889, -0.07943156, 0.014809798, -0.12412325, - -0.0030663363, 0.039695457, 0.0647603, -0.08291318, - -0.018529687, -0.004423833, 0.0037507233, 0.084633216, - -0.01514876, -0.056505352, -0.012800942, -0.06994386, - 0.012962922, -0.031234352, 0.07029052, 0.016418684, - 0.03618972, 0.055686004, -0.08663945, -0.017404709, - -0.054761406, 0.029065743, 0.052404847, 0.020238016, - 0.0048197987, -0.0214882, 0.07078733, 0.013016777, - 0.06262858, 0.009184685, 0.020785125, -0.043904778, - -0.0270329, -0.03299152, -0.060088247, -0.015162964, - -0.001828936, 0.12642565, -0.056757294, 0.013586685, - 0.09232601, -0.035886683, 0.06000002, 0.05229691, - -0.052580316, -0.082029596, -0.010794592, 0.012947712, - -0.036429964, -0.085508935, -0.13127148, -0.017744139, - 0.031502828, 0.036232427, -0.031581745, 0.023051167, - -0.05325106, -0.03421577, 0.028793324, -0.034633752, - -0.009881397, -0.043551125, -0.018609839, 0.0019097115, - -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); - - lstm.SetRecurrentToOutputWeights({ - 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, - -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, - -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, - -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, - -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, - -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, - -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, - 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, - -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, - 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, - -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, - -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, - 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, - 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, - -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, - 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, - 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, - 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, - 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, - 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, - -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, - 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, - -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, - 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, - 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, - 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, - -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, - -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, - -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, - -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, - -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, - -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, - 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, - 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, - -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, - 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, - -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, - -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, - -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, - 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, - 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, - 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, - -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, - 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, - -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, - -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, - -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, - -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, - 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, - -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, - 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, - -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, - -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, - -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, - -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, - 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, - 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, - -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, - 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, - 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, - -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, - 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, - 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, - 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, - }); - - lstm.SetCellToInputWeights( - {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, - -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, - -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, - 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); - - lstm.SetCellToForgetWeights( - {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, - -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, - -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, - 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}); - - lstm.SetCellToOutputWeights( - {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, - -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, - -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, - 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}); - - lstm.SetProjectionWeights( - {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, - 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, - -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, - -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, - 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, - 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, - 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, - 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, - -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, - -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768, - -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, - 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, - 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, - 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, - 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, - 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, - -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552, - 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, - -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, - 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, - -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, - -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, - 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, - -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, - 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, - -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776, - -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, - 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, - -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, - -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, - -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, - 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, - 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, - -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, - 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, - 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, - 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, - 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, - 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, - -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, - -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, - 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, - -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, - -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, - 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, - 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, - 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, - -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, - -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, - -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, - 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, - -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, - 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, - 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, - -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, - -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943, - -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, - 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, - -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, - -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, - -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, - 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, - 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, - 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656}); - - static float lstm_input[][20] = { - {// Batch0: 4 (input_sequence_size) * 5 (n_input) - 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, - 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, - 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, - - {// Batch1: 4 (input_sequence_size) * 5 (n_input) - 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, - 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, - 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; - - static float lstm_golden_output[][64] = { - {// Batch0: 4 (input_sequence_size) * 16 (n_output) - -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, - -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, - -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, - 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, - -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, - -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, - 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, - 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, - 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, - 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, - -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, - -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, - 0.0286833, 0.00824207, 0.0264887, 0.0305169}, - {// Batch1: 4 (input_sequence_size) * 16 (n_output) - -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, - -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, - 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, - 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, - -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, - -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, - 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, - 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, - 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, - 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, - -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, - -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, - 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - for (int i = 0; i < lstm.sequence_length(); i++) { - float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetInput(2 * i * lstm.num_inputs(), batch0_start, batch0_end); +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; + const int sequence_length = 4; - float* batch1_start = lstm_input[1] + i * lstm.num_inputs(); - float* batch1_end = batch1_start + lstm.num_inputs(); - lstm.SetInput((2 * i + 1) * lstm.num_inputs(), batch1_start, batch1_end); - } + HybridUnidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor - lstm.Invoke(); + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor - std::vector expected; - for (int i = 0; i < lstm.sequence_length(); i++) { - float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs(); - float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs(); - float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs(); - float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs(); - expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); - expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); - } - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc index 164a0cbd08d6ce82a413f12ba6b1703087a30aba..0d6d29a171735a00a8dcc6cd0213a859b9f8094a 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include diff --git a/tensorflow/contrib/lite/mmap_allocation.cc b/tensorflow/contrib/lite/mmap_allocation.cc new file mode 100644 index 0000000000000000000000000000000000000000..fa9a3cd1d839b07149bb80c3b7714b32b5eda235 --- /dev/null +++ b/tensorflow/contrib/lite/mmap_allocation.cc @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/error_reporter.h" + +namespace tflite { + +MMAPAllocation::MMAPAllocation(const char* filename, + ErrorReporter* error_reporter) + : Allocation(error_reporter), mmapped_buffer_(MAP_FAILED) { + mmap_fd_ = open(filename, O_RDONLY); + if (mmap_fd_ == -1) { + error_reporter_->Report("Could not open '%s'.", filename); + return; + } + struct stat sb; + fstat(mmap_fd_, &sb); + buffer_size_bytes_ = sb.st_size; + mmapped_buffer_ = + mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0); + if (mmapped_buffer_ == MAP_FAILED) { + error_reporter_->Report("Mmap of '%s' failed.", filename); + return; + } +} + +MMAPAllocation::~MMAPAllocation() { + if (valid()) { + munmap(const_cast(mmapped_buffer_), buffer_size_bytes_); + } + if (mmap_fd_ != -1) close(mmap_fd_); +} + +const void* MMAPAllocation::base() const { return mmapped_buffer_; } + +size_t MMAPAllocation::bytes() const { return buffer_size_bytes_; } + +bool MMAPAllocation::valid() const { return mmapped_buffer_ != MAP_FAILED; } + +bool MMAPAllocation::IsSupported() { return true; } + +} // namespace tflite diff --git a/tensorflow/contrib/lite/mmap_allocation_disabled.cc b/tensorflow/contrib/lite/mmap_allocation_disabled.cc new file mode 100644 index 0000000000000000000000000000000000000000..f3d4cf1a257d43ebd56cc9b8831de0bb1994d40c --- /dev/null +++ b/tensorflow/contrib/lite/mmap_allocation_disabled.cc @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/allocation.h" + +#include + +namespace tflite { + +MMAPAllocation::MMAPAllocation(const char* filename, + ErrorReporter* error_reporter) + : Allocation(error_reporter), mmapped_buffer_(nullptr) { + // The disabled variant should never be created. + assert(false); +} + +MMAPAllocation::~MMAPAllocation() {} + +const void* MMAPAllocation::base() const { return nullptr; } + +size_t MMAPAllocation::bytes() const { return 0; } + +bool MMAPAllocation::valid() const { return false; } + +bool MMAPAllocation::IsSupported() { return false; } + +} // namespace tflite diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 039f32b38eb29068b223dd63355c66295301beba..9edf5ba38f4c6506524074bc0a3ebe7e068c7ee3 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -16,16 +16,16 @@ limitations under the License. #include #include #include -#include #include #include -#include #include "tensorflow/contrib/lite/allocation.h" #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/model.h" +#ifndef TFLITE_MCU #include "tensorflow/contrib/lite/nnapi_delegate.h" +#endif #include "tensorflow/contrib/lite/version.h" namespace tflite { @@ -45,6 +45,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, case TensorType_FLOAT32: *type = kTfLiteFloat32; break; + case TensorType_INT16: + *type = kTfLiteInt16; + break; case TensorType_INT32: *type = kTfLiteInt32; break; @@ -60,6 +63,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, case TensorType_BOOL: *type = kTfLiteBool; break; + case TensorType_COMPLEX64: + *type = kTfLiteComplex64; + break; default: error_reporter->Report("Unimplemented data type %s (%d) in tensor\n", EnumNameTensorType(tensor_type), tensor_type); @@ -68,6 +74,7 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, return kTfLiteOk; } +#ifndef TFLITE_MCU // Loads a model from `filename`. If `mmap_file` is true then use mmap, // otherwise make a copy of the model in a buffer. std::unique_ptr GetAllocationFromFile(const char* filename, @@ -75,8 +82,8 @@ std::unique_ptr GetAllocationFromFile(const char* filename, ErrorReporter* error_reporter, bool use_nnapi) { std::unique_ptr allocation; - if (mmap_file) { - if (use_nnapi && NNAPIExists()) + if (mmap_file && MMAPAllocation::IsSupported()) { + if (use_nnapi && NNAPIDelegate::IsSupported()) allocation.reset(new NNAPIAllocation(filename, error_reporter)); else allocation.reset(new MMAPAllocation(filename, error_reporter)); @@ -115,6 +122,7 @@ std::unique_ptr FlatBufferModel::VerifyAndBuildFromFile( if (!model->initialized()) model.reset(); return model; } +#endif std::unique_ptr FlatBufferModel::BuildFromBuffer( const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) { @@ -180,6 +188,8 @@ InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model, op_resolver_(op_resolver), error_reporter_(ValidateErrorReporter(error_reporter)) {} +InterpreterBuilder::~InterpreterBuilder() {} + TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { TfLiteStatus status = kTfLiteOk; auto opcodes = model_->operator_codes(); @@ -198,8 +208,9 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { } else if (builtin_code != BuiltinOperator_CUSTOM) { registration = op_resolver_.FindOp(builtin_code, version); if (registration == nullptr) { - error_reporter_->Report("Didn't find op for builtin opcode '%s'\n", - EnumNameBuiltinOperator(builtin_code)); + error_reporter_->Report( + "Didn't find op for builtin opcode '%s' version '%d'\n", + EnumNameBuiltinOperator(builtin_code), version); status = kTfLiteError; } } else if (!opcode->custom_code()) { @@ -441,6 +452,18 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, op->builtin_options_as_FullyConnectedOptions()) { params->activation = parse_activation( fully_connected_params->fused_activation_function()); + switch (fully_connected_params->weights_format()) { + case FullyConnectedOptionsWeightsFormat_DEFAULT: + params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault; + break; + case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: + params->weights_format = + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8; + break; + default: + error_reporter->Report("Unhandled fully-connected weights format."); + return kTfLiteError; + } } *builtin_data = reinterpret_cast(params); break; @@ -594,9 +617,12 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_MEAN: { - auto* params = MallocPOD(); - if (auto* schema_params = op->builtin_options_as_MeanOptions()) { + case BuiltinOperator_MEAN: + case BuiltinOperator_REDUCE_MAX: + case BuiltinOperator_REDUCE_PROD: + case BuiltinOperator_SUM: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_ReducerOptions()) { params->keep_dims = schema_params->keep_dims(); } *builtin_data = reinterpret_cast(params); @@ -642,6 +668,15 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_ARG_MIN: { + auto* params = MallocPOD(); + if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) { + ConvertTensorType(schema_params->output_type(), ¶ms->output_type, + error_reporter); + } + *builtin_data = reinterpret_cast(params); + break; + } case BuiltinOperator_TRANSPOSE_CONV: { TfLiteTransposeConvParams* params = MallocPOD(); @@ -664,11 +699,48 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_SHAPE: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_ShapeOptions()) { + ConvertTensorType(schema_params->out_type(), ¶ms->out_type, + error_reporter); + } + *builtin_data = static_cast(params); + break; + } + case BuiltinOperator_PACK: { + TfLitePackParams* params = MallocPOD(); + if (auto* pack_params = op->builtin_options_as_PackOptions()) { + params->values_count = pack_params->values_count(); + params->axis = pack_params->axis(); + } + *builtin_data = reinterpret_cast(params); + break; + } case BuiltinOperator_DELEGATE: { // TODO(ycling): Revisit when supporting saving delegated models. error_reporter->Report("DELEGATE op shouldn't exist in model."); return kTfLiteError; } + case BuiltinOperator_FAKE_QUANT: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) { + params->min = schema_params->min(); + params->max = schema_params->max(); + params->num_bits = schema_params->num_bits(); + params->narrow_range = schema_params->narrow_range(); + } + *builtin_data = static_cast(params); + break; + } + case BuiltinOperator_ONE_HOT: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_OneHotOptions()) { + params->axis = schema_params->axis(); + } + *builtin_data = static_cast(params); + break; + } // Below are the ops with no builtin_data strcture. case BuiltinOperator_BATCH_TO_SPACE_ND: @@ -700,14 +772,20 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_RELU: case BuiltinOperator_RELU6: case BuiltinOperator_RELU_N1_TO_1: + case BuiltinOperator_RSQRT: case BuiltinOperator_SELECT: case BuiltinOperator_SIN: case BuiltinOperator_SLICE: case BuiltinOperator_SPACE_TO_BATCH_ND: + case BuiltinOperator_SQRT: case BuiltinOperator_TANH: case BuiltinOperator_TILE: case BuiltinOperator_TOPK_V2: case BuiltinOperator_TRANSPOSE: + case BuiltinOperator_POW: + case BuiltinOperator_LOGICAL_OR: + case BuiltinOperator_LOGICAL_AND: + case BuiltinOperator_LOGICAL_NOT: break; } return kTfLiteOk; @@ -730,7 +808,7 @@ TfLiteStatus InterpreterBuilder::ParseNodes( } const TfLiteRegistration* registration = - flatbuffer_op_index_to_registration_[op->opcode_index()]; + flatbuffer_op_index_to_registration_[index]; if (registration == nullptr) { error_reporter_->Report("Skipping op for opcode_index %d\n", index); status = kTfLiteError; @@ -849,7 +927,16 @@ TfLiteStatus InterpreterBuilder::ParseTensors( const char* buffer_ptr; TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size)); + bool is_variable = tensor->is_variable(); if (buffer_ptr) { + if (is_variable) { + error_reporter_->Report( + "Tensor %d is a variable tensor with buffer. " + "It's not supported now.\n", + i); + status = kTfLiteError; + } + if (interpreter->SetTensorParametersReadOnly( i, type, get_name(tensor), dims, quantization, buffer_ptr, buffer_size, allocation_) != kTfLiteOk) { @@ -858,8 +945,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors( status = kTfLiteError; } } else { - if (interpreter->SetTensorParametersReadWrite( - i, type, get_name(tensor), dims, quantization) != kTfLiteOk) { + if (interpreter->SetTensorParametersReadWrite(i, type, get_name(tensor), + dims, quantization, + is_variable) != kTfLiteOk) { error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", i); status = kTfLiteError; @@ -943,6 +1031,15 @@ TfLiteStatus InterpreterBuilder::operator()( if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk) return cleanup_and_error(); + std::vector variables; + for (int i = 0; i < (*interpreter)->tensors_size(); ++i) { + auto* tensor = (*interpreter)->tensor(i); + if (tensor->is_variable) { + variables.push_back(i); + } + } + (**interpreter).SetVariables(std::move(variables)); + return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h index 3946b490417104f620ecb55bb22d4ef99fd33bb7..8bc9ecd7ce9725c3d88985ccd92d48aed169fe31 100644 --- a/tensorflow/contrib/lite/model.h +++ b/tensorflow/contrib/lite/model.h @@ -156,6 +156,7 @@ class InterpreterBuilder { InterpreterBuilder(const ::tflite::Model* model, const OpResolver& op_resolver, ErrorReporter* error_reporter = DefaultErrorReporter()); + ~InterpreterBuilder(); InterpreterBuilder(const InterpreterBuilder&) = delete; InterpreterBuilder& operator=(const InterpreterBuilder&) = delete; TfLiteStatus operator()(std::unique_ptr* interpreter); diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc index 15bae21a411c1241cf71ab4d3f0e0289eaac8ef3..df4f60d4ad4eb71f48eb3ad364f95f93b84f3d75 100644 --- a/tensorflow/contrib/lite/model_test.cc +++ b/tensorflow/contrib/lite/model_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include "tensorflow/contrib/lite/model.h" @@ -242,14 +241,6 @@ TEST(BasicFlatBufferModel, TestWithNullVerifier) { "tensorflow/contrib/lite/testdata/test_model.bin", nullptr)); } -struct TestErrorReporter : public ErrorReporter { - int Report(const char* format, va_list args) override { - calls++; - return 0; - } - int calls = 0; -}; - // This makes sure the ErrorReporter is marshalled from FlatBufferModel to // the Interpreter. TEST(BasicFlatBufferModel, TestCustomErrorReporter) { @@ -263,7 +254,7 @@ TEST(BasicFlatBufferModel, TestCustomErrorReporter) { TrivialResolver resolver; InterpreterBuilder(*model, resolver)(&interpreter); ASSERT_NE(interpreter->Invoke(), kTfLiteOk); - ASSERT_EQ(reporter.calls, 1); + ASSERT_EQ(reporter.num_calls(), 1); } // This makes sure the ErrorReporter is marshalled from FlatBufferModel to diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD index 8b5fa240ac31d9ee61879c42aee3c5d449ae60db..9d88c396ba69948e3ae285c913a4499a1409b93a 100644 --- a/tensorflow/contrib/lite/models/smartreply/BUILD +++ b/tensorflow/contrib/lite/models/smartreply/BUILD @@ -47,6 +47,7 @@ cc_test( name = "extract_feature_op_test", size = "small", srcs = ["ops/extract_feature_test.cc"], + tags = ["no_oss"], deps = [ ":custom_ops", "//tensorflow/contrib/lite:framework", @@ -61,6 +62,7 @@ cc_test( name = "normalize_op_test", size = "small", srcs = ["ops/normalize_test.cc"], + tags = ["no_oss"], deps = [ ":custom_ops", "//tensorflow/contrib/lite:framework", @@ -75,6 +77,7 @@ cc_test( name = "predict_op_test", size = "small", srcs = ["ops/predict_test.cc"], + tags = ["no_oss"], deps = [ ":custom_ops", "//tensorflow/contrib/lite:framework", diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD index f8767b443a2aa64b666c3b6bfb7db30cc0be62ea..f18a2ca07a5f66b760e96a6d9a57db8d6c26b7b9 100644 --- a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD @@ -1,3 +1,5 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.h b/tensorflow/contrib/lite/models/smartreply/predictor.h index 90260c8d620b0e756f72089d3f4d8d9f92d44fbe..3151192d9277b6df513a76afb08af30d0379b7b1 100644 --- a/tensorflow/contrib/lite/models/smartreply/predictor.h +++ b/tensorflow/contrib/lite/models/smartreply/predictor.h @@ -65,9 +65,9 @@ struct SmartReplyConfig { float backoff_confidence; // Backoff responses are used when predicted responses cannot fulfill the // list. - const std::vector& backoff_responses; + std::vector backoff_responses; - SmartReplyConfig(std::vector backoff_responses) + SmartReplyConfig(const std::vector& backoff_responses) : num_response(kDefaultNumResponse), backoff_confidence(kDefaultBackoffConfidence), backoff_responses(backoff_responses) {} diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h index becd1f615f04a806cba9c494323285c004ec41df..42b8163445d252c766491e7bcd2fd7eea0dd7571 100644 --- a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h +++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h @@ -44,6 +44,19 @@ inline void* loadLibrary(const char* name) { return handle; } +typedef int (*ASharedMemory_create_fn)(const char* name, size_t size); + +// ASharedMemory_create was added in Android 8.0, so safe to use with NNAPI +// which was added in 8.1. +inline int ASharedMemory_create(const char* name, size_t size) { + static void* handle = loadLibrary("libandroid.so"); + static ASharedMemory_create_fn fn = + handle != nullptr ? reinterpret_cast( + dlsym(handle, "ASharedMemory_create")) + : nullptr; + return fn(name, size); +} + inline void* getLibraryHandle() { static void* handle = loadLibrary("libneuralnetworks.so"); return handle; diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index c71ad1d37937228662798a7370d66147191032ec..7cecebf425273e1ac0cf8a17bd9bc4312db7de62 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -25,32 +25,58 @@ limitations under the License. #include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" #ifdef __ANDROID__ +#include #include #endif namespace tflite { -// TODO(aselle): FATAL leaves resources hanging. -void FATAL(const char* format, ...) { - va_list args; - va_start(args, format); - vfprintf(stderr, format, args); - va_end(args); +void logError(const char* format, ...) { + // stderr is convenient for native tests, but is not captured for apps + va_list args_for_stderr; + va_start(args_for_stderr, format); + vfprintf(stderr, format, args_for_stderr); + va_end(args_for_stderr); + fprintf(stderr, "\n"); fflush(stderr); - exit(1); +#ifdef __ANDROID__ + // produce logcat output for general consumption + va_list args_for_log; + va_start(args_for_log, format); + __android_log_vprint(ANDROID_LOG_ERROR, "tflite", format, args_for_log); + va_end(args_for_log); +#endif } +#define FATAL(...) \ + logError(__VA_ARGS__); \ + exit(1); + // TODO(aselle): Change the error model to use status codes. -#define CHECK_TFLITE_SUCCESS(x) \ - if (x != kTfLiteOk) { \ - FATAL("Aborting since tflite returned failure."); \ +#define CHECK_TFLITE_SUCCESS(x) \ + if (x != kTfLiteOk) { \ + FATAL("Aborting since tflite returned failure nnapi_delegate.cc:%d.", \ + __LINE__); \ } -#define CHECK_NN(x) \ - if (x != ANEURALNETWORKS_NO_ERROR) { \ - FATAL("Aborting since tflite returned failure."); \ +#define CHECK_NN(x) \ + if (x != ANEURALNETWORKS_NO_ERROR) { \ + FATAL("Aborting since NNAPI returned failure nnapi_delegate.cc:%d", \ + __LINE__); \ } +#define RETURN_ERROR_IF_NN_FAILED(x) \ + if (x != ANEURALNETWORKS_NO_ERROR) { \ + logError( \ + "Returning error since NNAPI returned failure nnapi_delegate.cc:%d.", \ + __LINE__); \ + return kTfLiteError; \ + } + +// Tracking of NNAPI operand ids +static const int64_t kOperandIdNotSet = -1; +static const int64_t kOperandNotNeeded = -2; + namespace { int32_t GetAndroidSdkVersion() { @@ -105,21 +131,16 @@ NNAPIDelegate::~NNAPIDelegate() { } // Adds the tensors of the interpreter to the NN API model. -// Returns the number of operands added. -uint32_t addTensorOperands(tflite::Interpreter* interpreter, - ANeuralNetworksModel* nn_model, - const std::vector& skip_list) { +TfLiteStatus addTensorOperands(tflite::Interpreter* interpreter, + ANeuralNetworksModel* nn_model, + uint32_t* no_of_operands_added, + std::vector* nnapi_ids) { uint32_t next_id = 0; for (size_t i = 0; i < interpreter->tensors_size(); i++) { - // skip temporaries tensors. - bool shouldSkip = false; - for (auto skip_idx : skip_list) { - if (i == skip_idx) { - shouldSkip = true; - break; - } - } - if (shouldSkip) continue; + // Skip temporaries and RNN back-edges. + if ((*nnapi_ids)[i] == kOperandNotNeeded) continue; + + (*nnapi_ids)[i] = int64_t(next_id); int32_t nn_type = 0; // NNAPI requires 32-bit float scale to be zero, tflite doesn't care @@ -149,7 +170,18 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter, zeroPoint = tensor->params.zero_point; break; default: - FATAL("Unsupported type."); + logError("Unsupported tensor type %d", tensor->type); + return kTfLiteError; + } + if (tensor->dims->size == 0) { + logError("NNAPI doesn't support tensors with rank 0 (index %d name %s)", + i, tensor->name); + return kTfLiteError; + } + if (tensor->dims->size > 4) { + logError("NNAPI doesn't support tensors with rank > 4 (index %d name %s)", + i, tensor->name); + return kTfLiteError; } // We set size of all intermediate operands here. NNAPI is able to execute // a graph with unknown intermediate operand sizes but known sizes allow @@ -158,36 +190,53 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter, ANeuralNetworksOperandType operand_type{ nn_type, static_cast(tensor->dims->size), reinterpret_cast(tensor->dims->data), scale, zeroPoint}; - CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)); + RETURN_ERROR_IF_NN_FAILED( + ANeuralNetworksModel_addOperand(nn_model, &operand_type)); // TODO(aselle): Based on Michael's suggestion, limiting this to read // only memory if (tensor->allocation_type == kTfLiteMmapRo) { if (const NNAPIAllocation* alloc = dynamic_cast( static_cast(tensor->allocation))) { - CHECK_NN(ANeuralNetworksModel_setOperandValueFromMemory( - nn_model, next_id, alloc->memory(), alloc->offset(tensor->data.raw), - tensor->bytes)); + RETURN_ERROR_IF_NN_FAILED( + ANeuralNetworksModel_setOperandValueFromMemory( + nn_model, next_id, alloc->memory(), + alloc->offset(tensor->data.raw), tensor->bytes)); } else { - CHECK_NN(ANeuralNetworksModel_setOperandValue( + RETURN_ERROR_IF_NN_FAILED(ANeuralNetworksModel_setOperandValue( nn_model, next_id, tensor->data.raw, tensor->bytes)); } } else if (tensor->bytes == 0) { // These size 0 tensors are optional tensors reserved. - CHECK_NN( + RETURN_ERROR_IF_NN_FAILED( ANeuralNetworksModel_setOperandValue(nn_model, next_id, nullptr, 0)); } ++next_id; } - return next_id; + *no_of_operands_added = next_id; + return kTfLiteOk; +} + +void MapAndAddTensorIds(const int* from_ids_buf, size_t from_ids_count, + std::vector* into, + const std::vector& map) { + for (size_t i = 0; i < from_ids_count; i++) { + int from_id = from_ids_buf[i]; + if (from_id == kOptionalTensor) { + into->push_back(from_id); + } else { + into->push_back(map[from_id]); + } + } } // Adds the operations and their parameters to the NN API model. // 'next-id' is the operand ID of the next operand of the model. -void AddOpsAndParams(tflite::Interpreter* interpreter, - ANeuralNetworksModel* nn_model, uint32_t next_id, - std::vector* model_state_inputs, - std::vector* model_state_outputs) { +TfLiteStatus AddOpsAndParams( + tflite::Interpreter* interpreter, ANeuralNetworksModel* nn_model, + uint32_t next_id, std::vector* model_state_inputs, + std::vector* model_state_outputs, + const std::vector& tensor_id_to_nnapi_id) { for (size_t i = 0; i < interpreter->nodes_size(); i++) { const auto* node_and_registration = interpreter->node_and_registration(i); const TfLiteNode& node = node_and_registration->first; @@ -196,10 +245,11 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, static_cast(registration.builtin_code); // Add the parameters. - std::vector augmented_inputs( - node.inputs->data, node.inputs->data + node.inputs->size); - std::vector augmented_outputs( - node.outputs->data, node.outputs->data + node.outputs->size); + std::vector augmented_inputs, augmented_outputs; + MapAndAddTensorIds(node.inputs->data, node.inputs->size, &augmented_inputs, + tensor_id_to_nnapi_id); + MapAndAddTensorIds(node.outputs->data, node.outputs->size, + &augmented_outputs, tensor_id_to_nnapi_id); auto add_scalar_int32 = [&nn_model, &augmented_inputs, &next_id](int value) { @@ -219,6 +269,17 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, augmented_inputs.push_back(next_id++); }; + auto add_vector_int32 = [&](const int* values, uint32_t num_values) { + ANeuralNetworksOperandType operand_type{ + .type = ANEURALNETWORKS_TENSOR_INT32, + .dimensionCount = 1, + .dimensions = &num_values}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) + CHECK_NN(ANeuralNetworksModel_setOperandValue( + nn_model, next_id, values, sizeof(int32_t) * num_values)); + augmented_inputs.push_back(next_id++); + }; + // Handle state tensors of RNN, LSTM, SVDF. // For each state_out tensor, a corresponding state_in operand needs to be // created for NNAPI. @@ -237,42 +298,54 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, model_state_outputs->push_back(tensor_id); next_id++; }; + auto check_and_add_activation = [&add_scalar_int32](int activation) { + if (activation > kTfLiteActRelu6) { + FATAL("NNAPI only supports RELU, RELU1 and RELU6 activations"); + } + add_scalar_int32(activation); + }; auto add_add_params = [&add_scalar_int32](void* data) { auto* builtin = reinterpret_cast(data); + if (builtin->activation > kTfLiteActRelu6) { + FATAL("NNAPI only supports RELU, RELU1 and RELU6 activations"); + } add_scalar_int32(builtin->activation); }; - auto add_pooling_params = [&add_scalar_int32](void* data) { + auto add_pooling_params = [&add_scalar_int32, + &check_and_add_activation](void* data) { auto builtin = reinterpret_cast(data); add_scalar_int32(builtin->padding); add_scalar_int32(builtin->stride_width); add_scalar_int32(builtin->stride_height); add_scalar_int32(builtin->filter_width); add_scalar_int32(builtin->filter_height); - add_scalar_int32(builtin->activation); + check_and_add_activation(builtin->activation); }; - auto add_convolution_params = [&add_scalar_int32](void* data) { + auto add_convolution_params = [&add_scalar_int32, + &check_and_add_activation](void* data) { auto builtin = reinterpret_cast(data); add_scalar_int32(builtin->padding); add_scalar_int32(builtin->stride_width); add_scalar_int32(builtin->stride_height); - add_scalar_int32(builtin->activation); + check_and_add_activation(builtin->activation); }; - auto add_depthwise_conv_params = [&add_scalar_int32](void* data) { + auto add_depthwise_conv_params = [&add_scalar_int32, + &check_and_add_activation](void* data) { auto builtin = reinterpret_cast(data); add_scalar_int32(builtin->padding); add_scalar_int32(builtin->stride_width); add_scalar_int32(builtin->stride_height); add_scalar_int32(builtin->depth_multiplier); - add_scalar_int32(builtin->activation); + check_and_add_activation(builtin->activation); }; - auto add_fully_connected_params = [&add_scalar_int32](void* data) { + auto add_fully_connected_params = [&check_and_add_activation](void* data) { auto builtin = reinterpret_cast(data); - add_scalar_int32(builtin->activation); + check_and_add_activation(builtin->activation); }; auto add_concatenation_params = [&add_scalar_int32](void* data) { @@ -304,6 +377,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, // LSTM in NNAPI requires scratch tensor as an output operand. auto add_lstm_scratch_tensor_float32 = [interpreter, &node, &nn_model, &next_id, &augmented_outputs]() { + if (node.temporaries->size == 0) return; int scratch_buffer_index = node.temporaries->data[0]; const TfLiteTensor* tensor = interpreter->tensor(scratch_buffer_index); ANeuralNetworksOperandType operand_type{ @@ -316,7 +390,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, }; auto add_mean_params = [&add_scalar_int32](void* data) { - auto builtin = reinterpret_cast(data); + auto builtin = reinterpret_cast(data); add_scalar_int32(builtin->keep_dims); }; @@ -331,6 +405,14 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, add_scalar_int32(builtin->activation); }; + auto add_squeeze_params = [&](void* data) { + const auto* builtin = reinterpret_cast(data); + // Note that we add the squeeze dimensions even if the dimensions were + // unspecified (empty), as NNAPI requires the operand. + add_vector_int32(builtin->squeeze_dims, + static_cast(builtin->num_squeeze_dims)); + }; + // Handle optional input tensors. auto add_optional_tensors = [&nn_model, &augmented_inputs, &next_id](int nn_type) { @@ -370,7 +452,14 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, add_pooling_params(node.builtin_data); nn_op_type = ANEURALNETWORKS_L2_POOL_2D; break; - case tflite::BuiltinOperator_CONV_2D: + case tflite::BuiltinOperator_CONV_2D: { + auto builtin = reinterpret_cast(node.builtin_data); + if (builtin->dilation_width_factor != 1 || + builtin->dilation_height_factor != 1 || node.inputs->size != 3) { + logError("NNAPI does not support dilated Conv2D."); + return kTfLiteError; + } + } add_convolution_params(node.builtin_data); nn_op_type = ANEURALNETWORKS_CONV_2D; break; @@ -414,6 +503,10 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, nn_op_type = ANEURALNETWORKS_SPACE_TO_DEPTH; break; case tflite::BuiltinOperator_LSTM: { + if (node.inputs->size + /* no of params */ 3 != 21) { + logError("NNAPI only supports 21-input LSTMs"); + return kTfLiteError; + } duplicate_state_tensor_float32( node.outputs->data[/*kOutputStateTensor*/ 0]); duplicate_state_tensor_float32( @@ -452,20 +545,62 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_DIV: nnapi_version = 11; // require NNAPI 1.1 nn_op_type = ANEURALNETWORKS_DIV; + check_and_add_activation( + reinterpret_cast(node.builtin_data)->activation); break; case tflite::BuiltinOperator_SUB: nnapi_version = 11; // require NNAPI 1.1 nn_op_type = ANEURALNETWORKS_SUB; + check_and_add_activation( + reinterpret_cast(node.builtin_data)->activation); + break; + case tflite::BuiltinOperator_SQUEEZE: + nnapi_version = 11; // requires NNAPI 1.1 + add_squeeze_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_SQUEEZE; + break; + case tflite::BuiltinOperator_TRANSPOSE: + // The permutation input tensor value dictates the output dimensions. + // TODO(b/110888333): Support dynamically-sized tensors in delegates. + if ((node.inputs->size > 1) && + (interpreter->tensor(node.inputs->data[1])->allocation_type != + kTfLiteMmapRo)) { + logError("NNAPI does not yet support dynamic tensors."); + return kTfLiteError; + } + nnapi_version = 11; // require NNAPI 1.1 + nn_op_type = ANEURALNETWORKS_TRANSPOSE; + break; + case tflite::BuiltinOperator_L2_NORMALIZATION: + nn_op_type = ANEURALNETWORKS_L2_NORMALIZATION; + if (reinterpret_cast(node.builtin_data) + ->activation != kTfLiteActNone) { + logError( + "NNAPI does not support L2Normalization with fused activations"); + return kTfLiteError; + } + if ((node.inputs->size > 0) && + (interpreter->tensor(node.inputs->data[0])->dims->size != 4)) { + logError("NNAPI only supports input rank 4 for L2Normalization"); + return kTfLiteError; + } + break; + case tflite::BuiltinOperator_HASHTABLE_LOOKUP: + if (interpreter->tensor(node.outputs->data[0])->type != + kTfLiteFloat32) { + logError("NNAPI only support HASHTABLE_LOOKUP with float32 output", + builtin); + return kTfLiteError; + } + nn_op_type = ANEURALNETWORKS_HASHTABLE_LOOKUP; break; case tflite::BuiltinOperator_CONCAT_EMBEDDINGS: case tflite::BuiltinOperator_LSH_PROJECTION: - case tflite::BuiltinOperator_HASHTABLE_LOOKUP: case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: - case tflite::BuiltinOperator_L2_NORMALIZATION: case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: case tflite::BuiltinOperator_PADV2: case tflite::BuiltinOperator_RESIZE_BILINEAR: @@ -476,9 +611,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_SPACE_TO_BATCH_ND: case tflite::BuiltinOperator_BATCH_TO_SPACE_ND: case tflite::BuiltinOperator_TOPK_V2: - case tflite::BuiltinOperator_TRANSPOSE: case tflite::BuiltinOperator_SPLIT: - case tflite::BuiltinOperator_SQUEEZE: case tflite::BuiltinOperator_STRIDED_SLICE: case tflite::BuiltinOperator_EXP: case tflite::BuiltinOperator_LOG_SOFTMAX: @@ -489,6 +622,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_MAXIMUM: case tflite::BuiltinOperator_MINIMUM: case tflite::BuiltinOperator_ARG_MAX: + case tflite::BuiltinOperator_ARG_MIN: case tflite::BuiltinOperator_GREATER: case tflite::BuiltinOperator_GREATER_EQUAL: case tflite::BuiltinOperator_LESS: @@ -504,12 +638,25 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_SPARSE_TO_DENSE: case tflite::BuiltinOperator_EQUAL: case tflite::BuiltinOperator_NOT_EQUAL: - FATAL("Op code %d is currently not delegated to NNAPI", builtin); - nn_op_type = -1; // set to invalid + case tflite::BuiltinOperator_SUM: + case tflite::BuiltinOperator_REDUCE_MAX: + case tflite::BuiltinOperator_REDUCE_PROD: + case tflite::BuiltinOperator_SQRT: + case tflite::BuiltinOperator_RSQRT: + case tflite::BuiltinOperator_SHAPE: + case tflite::BuiltinOperator_POW: + case tflite::BuiltinOperator_FAKE_QUANT: + case tflite::BuiltinOperator_PACK: + case tflite::BuiltinOperator_LOGICAL_OR: + case tflite::BuiltinOperator_ONE_HOT: + case tflite::BuiltinOperator_LOGICAL_AND: + case tflite::BuiltinOperator_LOGICAL_NOT: + logError("Op code %d is currently not delegated to NNAPI", builtin); + return kTfLiteError; break; case tflite::BuiltinOperator_CUSTOM: - FATAL("Custom operations are not supported when using NNAPI."); - nn_op_type = -1; // set to invalid + logError("Custom operations are not supported when using NNAPI."); + return kTfLiteError; break; } @@ -518,47 +665,70 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, } // Add the operation. - CHECK_NN(ANeuralNetworksModel_addOperation( + RETURN_ERROR_IF_NN_FAILED(ANeuralNetworksModel_addOperation( nn_model, nn_op_type, static_cast(augmented_inputs.size()), augmented_inputs.data(), static_cast(augmented_outputs.size()), reinterpret_cast(augmented_outputs.data()))); } + return kTfLiteOk; } TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) { - // TODO(aselle): This is not correct. need to handle resize invalidation. - if (nn_model_ && nn_compiled_model_) return kTfLiteOk; + if (nn_model_ && nn_compiled_model_) return model_status_; + // TODO(aselle): This is not correct. need to handle resize invalidation. if (!nn_model_) { CHECK_NN(ANeuralNetworksModel_create(&nn_model_)); - // Find all the temporary tensors and put them in a skip_list. - std::vector skip_list; + // Find which tensors should be added to NNAPI. TFLite has temporaries + // and RNN back-edges which are are not valid for NNAPI. We look through all + // inputs and outputs and mark the mapping in tensor_id_to_nnapi_id with + // kOperandIdNotSet. addTensorOperands will replace those with the + // corresponding NNAPI operand ids and skip kOperandNotNeeded entries. + std::vector tensor_id_to_nnapi_id(interpreter->tensors_size(), + kOperandNotNeeded); + auto set_ids_to_not_set = [&tensor_id_to_nnapi_id](const int* buf, + size_t count) { + for (int j = 0; j < count; j++) { + auto tensor_id = buf[j]; + if (tensor_id != kOptionalTensor) { + tensor_id_to_nnapi_id[tensor_id] = kOperandIdNotSet; + } + } + }; for (size_t i = 0; i < interpreter->nodes_size(); i++) { const auto* node_and_registration = interpreter->node_and_registration(i); const TfLiteNode& node = node_and_registration->first; - if (node.temporaries != nullptr) { - for (int j = 0; j < node.temporaries->size; j++) { - skip_list.push_back(static_cast(node.temporaries->data[j])); - } - } + set_ids_to_not_set(node.inputs->data, node.inputs->size); + set_ids_to_not_set(node.outputs->data, node.outputs->size); } - - uint32_t next_id = addTensorOperands(interpreter, nn_model_, skip_list); - AddOpsAndParams(interpreter, nn_model_, next_id, &model_states_inputs_, - &model_states_outputs_); - - std::vector augmented_inputs = interpreter->inputs(); - std::vector augmented_outputs = interpreter->outputs(); - - // All state tensors input/output need to be treated as model input/output. + set_ids_to_not_set(interpreter->inputs().data(), + interpreter->inputs().size()); + set_ids_to_not_set(interpreter->outputs().data(), + interpreter->outputs().size()); + + uint32_t next_id = 0; + RETURN_ERROR_IF_NN_FAILED(addTensorOperands( + interpreter, nn_model_, &next_id, &tensor_id_to_nnapi_id)); + RETURN_ERROR_IF_NN_FAILED( + AddOpsAndParams(interpreter, nn_model_, next_id, &model_states_inputs_, + &model_states_outputs_, tensor_id_to_nnapi_id)); + + std::vector augmented_inputs; + MapAndAddTensorIds(interpreter->inputs().data(), + interpreter->inputs().size(), &augmented_inputs, + tensor_id_to_nnapi_id); augmented_inputs.insert(augmented_inputs.end(), model_states_inputs_.begin(), model_states_inputs_.end()); - augmented_outputs.insert(augmented_outputs.end(), - model_states_outputs_.begin(), - model_states_outputs_.end()); + std::vector augmented_outputs; + MapAndAddTensorIds(interpreter->outputs().data(), + interpreter->outputs().size(), &augmented_outputs, + tensor_id_to_nnapi_id); + MapAndAddTensorIds(model_states_outputs_.data(), + model_states_outputs_.size(), &augmented_outputs, + tensor_id_to_nnapi_id); CHECK_NN(ANeuralNetworksModel_identifyInputsAndOutputs( nn_model_, static_cast(augmented_inputs.size()), @@ -576,7 +746,13 @@ TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) { TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) { if (!nn_model_) { - TF_LITE_ENSURE_STATUS(BuildGraph(interpreter)); + model_status_ = BuildGraph(interpreter); + if (model_status_ != kTfLiteOk) { + logError("Failed to build graph for NNAPI"); + } + } + if (model_status_ != kTfLiteOk) { + return model_status_; } ANeuralNetworksExecution* execution = nullptr; @@ -640,4 +816,6 @@ TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) { return kTfLiteOk; } +bool NNAPIDelegate::IsSupported() { return NNAPIExists(); } + } // namespace tflite diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h index 94dea4f9b23f208fddbacd3c77d889ea753a8a1d..2bdb2cc5c8211a48ea07e7ec45f9eebc0a3f7c10 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.h +++ b/tensorflow/contrib/lite/nnapi_delegate.h @@ -19,9 +19,10 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" -class ANeuralNetworsModel; +class ANeuralNetworksModel; +class ANeuralNetworksMemory; +class ANeuralNetworksCompilation; namespace tflite { @@ -54,19 +55,24 @@ class NNAPIDelegate { // Run TfLiteStatus Invoke(Interpreter* interpreter); + // Whether the current platform supports NNAPI delegation. + static bool IsSupported(); + private: // The NN API model handle ANeuralNetworksModel* nn_model_ = nullptr; // The NN API compilation handle ANeuralNetworksCompilation* nn_compiled_model_ = nullptr; + // Model status + TfLiteStatus model_status_ = kTfLiteOk; // List of state tensors for LSTM, RNN, SVDF. // NN API does not allow ops to maintain states across multiple // invocations. We need to manually create state input tensors from // corresponding state output tensors of TFLite operations, and map them // correctly. - std::vector model_states_inputs_; - std::vector model_states_outputs_; + std::vector model_states_inputs_; // holds NNAPI operand ids + std::vector model_states_outputs_; // holds TFLite tensor ids }; } // namespace tflite diff --git a/tensorflow/contrib/lite/nnapi_delegate_disabled.cc b/tensorflow/contrib/lite/nnapi_delegate_disabled.cc new file mode 100644 index 0000000000000000000000000000000000000000..efde72b1a76a86728f4cccd8782ca0e993dd0338 --- /dev/null +++ b/tensorflow/contrib/lite/nnapi_delegate_disabled.cc @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/nnapi_delegate.h" + +#include + +namespace tflite { + +NNAPIAllocation::NNAPIAllocation(const char* filename, + ErrorReporter* error_reporter) + : MMAPAllocation(filename, error_reporter) { + // The disabled variant should never be created. + assert(false); +} + +NNAPIAllocation::~NNAPIAllocation() {} + +NNAPIDelegate::~NNAPIDelegate() {} + +TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) { + return kTfLiteError; +} + +TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) { + return kTfLiteError; +} + +bool NNAPIDelegate::IsSupported() { return false; } + +} // namespace tflite diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/contrib/lite/optional_debug_tools.cc index dfdd80ea8a42af683632be1d7e8ab0062847077d..f1f025f777c987c5ee47bdea457a973896b9bb82 100644 --- a/tensorflow/contrib/lite/optional_debug_tools.cc +++ b/tensorflow/contrib/lite/optional_debug_tools.cc @@ -50,6 +50,10 @@ const char* TensorTypeName(TfLiteType type) { return "kTfLiteString"; case kTfLiteBool: return "kTfLiteBool"; + case kTfLiteInt16: + return "kTfLiteInt16"; + case kTfLiteComplex64: + return "kTfLiteComplex64"; } return "(invalid)"; } @@ -82,13 +86,13 @@ void PrintInterpreterState(Interpreter* interpreter) { for (int tensor_index = 0; tensor_index < interpreter->tensors_size(); tensor_index++) { TfLiteTensor* tensor = interpreter->tensor(tensor_index); - printf("Tensor %3d %10s %15s %10zu bytes (%4.1f MB) ", tensor_index, - TensorTypeName(tensor->type), AllocTypeName(tensor->allocation_type), - tensor->bytes, float(tensor->bytes) / float(1 << 20)); + printf("Tensor %3d %-20s %10s %15s %10zu bytes (%4.1f MB) ", tensor_index, + tensor->name, TensorTypeName(tensor->type), + AllocTypeName(tensor->allocation_type), tensor->bytes, + (static_cast(tensor->bytes) / (1 << 20))); PrintTfLiteIntVector(tensor->dims); - printf("\n"); } - + printf("\n"); for (int node_index = 0; node_index < interpreter->nodes_size(); node_index++) { const std::pair* node_and_reg = @@ -104,7 +108,4 @@ void PrintInterpreterState(Interpreter* interpreter) { } } -// Prints a dump of what tensors and what nodes are in the interpreter. -TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter); - } // namespace tflite diff --git a/tensorflow/contrib/lite/optional_debug_tools.h b/tensorflow/contrib/lite/optional_debug_tools.h index 1b6998cda382782b974bea3d18ffb6217e8f780c..7fb4b8d8b7ae87cc6e8dd8503c8a4ce0cef2ce8d 100644 --- a/tensorflow/contrib/lite/optional_debug_tools.h +++ b/tensorflow/contrib/lite/optional_debug_tools.h @@ -24,9 +24,6 @@ namespace tflite { // Prints a dump of what tensors and what nodes are in the interpreter. void PrintInterpreterState(Interpreter* interpreter); -// Prints a dump of what tensors and what nodes are in the interpreter. -TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter); - } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_ diff --git a/tensorflow/contrib/lite/profiling/BUILD b/tensorflow/contrib/lite/profiling/BUILD index a162b87b8f98576ec7c3b2623d1d34f2baef6cce..1172722f7a70771af73eb07571349e431755471c 100644 --- a/tensorflow/contrib/lite/profiling/BUILD +++ b/tensorflow/contrib/lite/profiling/BUILD @@ -58,6 +58,7 @@ cc_test( name = "profile_summarizer_test", srcs = ["profile_summarizer_test.cc"], copts = common_copts, + tags = ["no_oss"], deps = [ ":profile_summarizer", "//tensorflow/contrib/lite:framework", diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.cc b/tensorflow/contrib/lite/profiling/profile_summarizer.cc index 45388b500c7897c8b33b49eb6ab4e9f8c4fdb37c..720bd717b9e3b0c45cbdbaaad2b6900edacc3051 100644 --- a/tensorflow/contrib/lite/profiling/profile_summarizer.cc +++ b/tensorflow/contrib/lite/profiling/profile_summarizer.cc @@ -23,8 +23,6 @@ namespace tflite { namespace profiling { namespace { -using Detail = tensorflow::StatsCalculator::Detail; - struct OperatorDetails { std::string name; std::vector inputs; @@ -78,8 +76,13 @@ OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter, } else { op_name = tflite::EnumNamesBuiltinOperator()[code]; } + const char* profiling_string = + interpreter.OpProfilingString(node_reg->second, &node_reg->first); OperatorDetails details; details.name = op_name; + if (profiling_string) { + details.name += ":" + std::string(profiling_string); + } details.inputs = GetTensorNames(interpreter, inputs); details.outputs = GetTensorNames(interpreter, outputs); return details; @@ -120,28 +123,17 @@ void ProfileSummarizer::ProcessProfiles( int64_t base_start_us = events[0]->begin_timestamp_us; int node_num = 0; int64_t curr_total_us = 0; - std::map details; for (auto event : events) { auto op_details = GetOperatorDetails(interpreter, event->event_metadata); auto node_name = ToString(op_details.outputs); - auto result = details.emplace(node_name, Detail()); - Detail* detail = &(result.first->second); - detail->start_us.UpdateStat(event->begin_timestamp_us - base_start_us); + int64_t start_us = event->begin_timestamp_us - base_start_us; int64_t node_exec_time = event->end_timestamp_us - event->begin_timestamp_us; - detail->rel_end_us.UpdateStat(node_exec_time); + stats_calculator_->AddNodeStats(node_name, op_details.name, node_num, + start_us, node_exec_time, 0 /*memory */); curr_total_us += node_exec_time; ++node_num; - - if (result.second) { - detail->name = node_name; - detail->type = op_details.name; - detail->run_order = node_num; - detail->times_called = 0; - } - ++detail->times_called; } - stats_calculator_->UpdateDetails(details); stats_calculator_->UpdateRunTotalUs(curr_total_us); } } // namespace profiling diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc index 35cf780713b93db559f86dcaf62e1ac004b5049a..67a5eecfa05379c7a721e7d669fcd02602e5e369 100644 --- a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc +++ b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc @@ -31,6 +31,7 @@ namespace profiling { namespace { +#ifdef TFLITE_PROFILING_ENABLED TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input1 = tflite::GetInput(context, node, /*index=*/0); const TfLiteTensor* input2 = tflite::GetInput(context, node, /*index=*/1); @@ -42,20 +43,35 @@ TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +const char* SimpleOpProfilingString(const TfLiteContext* context, + const TfLiteNode* node) { + return "Profile"; +} + TfLiteRegistration* RegisterSimpleOp() { + static TfLiteRegistration registration = { + nullptr, nullptr, nullptr, + SimpleOpEval, nullptr, tflite::BuiltinOperator_CUSTOM, + "SimpleOpEval", 1}; + return ®istration; +} + +TfLiteRegistration* RegisterSimpleOpWithProfilingDetails() { static TfLiteRegistration registration = {nullptr, nullptr, nullptr, SimpleOpEval, + SimpleOpProfilingString, tflite::BuiltinOperator_CUSTOM, "SimpleOpEval", 1}; return ®istration; } +#endif class SimpleOpModel : public SingleOpModel { public: - void Init(); + void Init(const std::function& registration); tflite::Interpreter* GetInterpreter() { return interpreter_.get(); } void SetInputs(int32_t x, int32_t y) { PopulateTensor(inputs_[0], {x}); @@ -68,11 +84,12 @@ class SimpleOpModel : public SingleOpModel { int output_; }; -void SimpleOpModel::Init() { +void SimpleOpModel::Init( + const std::function& registration) { inputs_[0] = AddInput({TensorType_INT32, {1}}); inputs_[1] = AddInput({TensorType_INT32, {1}}); output_ = AddOutput({TensorType_INT32, {}}); - SetCustomOp("SimpleAdd", {}, RegisterSimpleOp); + SetCustomOp("SimpleAdd", {}, registration); BuildInterpreter({GetShape(inputs_[0]), GetShape(inputs_[1])}); } @@ -86,7 +103,28 @@ TEST(ProfileSummarizerTest, Empty) { TEST(ProfileSummarizerTest, Interpreter) { Profiler profiler; SimpleOpModel m; - m.Init(); + m.Init(RegisterSimpleOp); + auto interpreter = m.GetInterpreter(); + interpreter->SetProfiler(&profiler); + profiler.StartProfiling(); + m.SetInputs(1, 2); + m.Invoke(); + // 3 = 1 + 2 + EXPECT_EQ(m.GetOutput(), 3); + profiler.StopProfiling(); + ProfileSummarizer summarizer; + auto events = profiler.GetProfileEvents(); + EXPECT_EQ(1, events.size()); + summarizer.ProcessProfiles(profiler.GetProfileEvents(), *interpreter); + auto output = summarizer.GetOutputString(); + // TODO(shashishekhar): Add a better test here. + ASSERT_TRUE(output.find("SimpleOpEval") != std::string::npos) << output; +} + +TEST(ProfileSummarizerTest, InterpreterPlusProfilingDetails) { + Profiler profiler; + SimpleOpModel m; + m.Init(RegisterSimpleOpWithProfilingDetails); auto interpreter = m.GetInterpreter(); interpreter->SetProfiler(&profiler); profiler.StartProfiling(); @@ -101,8 +139,10 @@ TEST(ProfileSummarizerTest, Interpreter) { summarizer.ProcessProfiles(profiler.GetProfileEvents(), *interpreter); auto output = summarizer.GetOutputString(); // TODO(shashishekhar): Add a better test here. - ASSERT_TRUE(output.find("SimpleOp") != std::string::npos) << output; + ASSERT_TRUE(output.find("SimpleOpEval:Profile") != std::string::npos) + << output; } + #endif } // namespace diff --git a/tensorflow/contrib/lite/profiling/time.cc b/tensorflow/contrib/lite/profiling/time.cc index 446660bb747cd6e3b694669b64ac1d95cf415fbe..875ddb02bcfc30f4c2ef543fe1c15bec467e5410 100644 --- a/tensorflow/contrib/lite/profiling/time.cc +++ b/tensorflow/contrib/lite/profiling/time.cc @@ -14,16 +14,34 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/profiling/time.h" +#if defined(_MSC_VER) +#include // NOLINT(build/c++11) +#else #include +#endif namespace tflite { namespace profiling { namespace time { + +#if defined(_MSC_VER) + +uint64_t NowMicros() { + return std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); +} + +#else + uint64_t NowMicros() { struct timeval tv; gettimeofday(&tv, nullptr); return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; } + +#endif // defined(_MSC_VER) + } // namespace time } // namespace profiling } // namespace tflite diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index 27909a9458f6b09f96cb556a5254f01e54f46e05..860aff9e7e2de9616dea40f42a33bc1e4ee9f400 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -19,6 +19,8 @@ py_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/contrib/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper", + "//tensorflow/python:util", + "//third_party/py/numpy", ], ) @@ -30,9 +32,10 @@ py_test( tags = ["no_oss"], deps = [ ":interpreter", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:platform_test", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform", + "//third_party/py/numpy", ], ) @@ -69,7 +72,10 @@ py_test( srcs = ["lite_test.py"], data = [":interpreter_test_data"], srcs_version = "PY2AND3", - tags = ["no_windows"], + tags = [ + "no_oss", + "no_windows", + ], deps = [ ":lite", ], @@ -161,7 +167,10 @@ py_test( name = "convert_saved_model_test", srcs = ["convert_saved_model_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], + tags = [ + "no_oss", + "no_windows", + ], visibility = ["//visibility:public"], deps = [ ":convert_saved_model", diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py index c038c88945b71f30bf091a1098dcf853f5415b1b..11d4bdbe82295bff9a7a457e2fd5ca1f8fe04036 100644 --- a/tensorflow/contrib/lite/python/convert.py +++ b/tensorflow/contrib/lite/python/convert.py @@ -25,7 +25,6 @@ import tempfile as _tempfile from tensorflow.contrib.lite.python import lite_constants from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2 from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 -from tensorflow.python.framework import dtypes as _dtypes from tensorflow.python.platform import resource_loader as _resource_loader from tensorflow.python.util.lazy_loader import LazyLoader @@ -55,7 +54,7 @@ def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str): """Convert `input_data_str` according to model and toco parameters. Unless you know what you are doing consider using - the more friendly @{tf.contrib.lite.toco_convert}}. + the more friendly `tf.contrib.lite.toco_convert`. Args: model_flags_str: Serialized proto describing model properties, see @@ -116,6 +115,7 @@ def build_toco_convert_protos(input_tensors, inference_type=lite_constants.FLOAT, inference_input_type=None, input_format=lite_constants.TENSORFLOW_GRAPHDEF, + input_shapes=None, output_format=lite_constants.TFLITE, quantized_input_stats=None, default_ranges_stats=None, @@ -135,13 +135,15 @@ def build_toco_convert_protos(input_tensors, input_tensors: List of input tensors. Type and shape are computed using `foo.get_shape()` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). - inference_type: Target data type of arrays in the output file. Currently - must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) - inference_input_type: Target data type of input arrays. Allows for a - different type for input arrays in the case of quantization. Currently - must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) + inference_type: Target data type of real-number arrays in the output file. + Must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) + inference_input_type: Target data type of real-number input arrays. Allows + for a different type for input arrays in the case of quantization. + Must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) input_format: Type of data to read Currently must be `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF) + input_shapes: Input array shape. It needs to be a list of the same length + as `input_tensors`, or None. (default None) output_format: Output file format. Currently must be `{TFLITE, GRAPHVIZ_DOT}`. (default TFLITE) quantized_input_stats: List of tuples of integers representing the mean and @@ -202,31 +204,19 @@ def build_toco_convert_protos(input_tensors, if dump_graphviz_dir: toco.dump_graphviz_dir = dump_graphviz_dir toco.dump_graphviz_include_video = dump_graphviz_video + model = _model_flags_pb2.ModelFlags() model.change_concat_input_ranges = change_concat_input_ranges for idx, input_tensor in enumerate(input_tensors): - if input_tensor.dtype == _dtypes.float32: - tflite_input_type = lite_constants.FLOAT - elif input_tensor.dtype == _dtypes.int32: - tflite_input_type = lite_constants.INT32 - elif input_tensor.dtype == _dtypes.int64: - tflite_input_type = lite_constants.INT64 - elif input_tensor.dtype == _dtypes.uint8: - tflite_input_type = lite_constants.QUANTIZED_UINT8 - # TODO(aselle): Insert strings when they are available - else: - raise ValueError("Tensors %s not known type %r" % (input_tensor.name, - input_tensor.dtype)) - input_array = model.input_arrays.add() - if inference_type == lite_constants.QUANTIZED_UINT8: - if tflite_input_type == lite_constants.FLOAT: - tflite_input_type = lite_constants.QUANTIZED_UINT8 input_array.mean_value, input_array.std_value = quantized_input_stats[idx] - input_array.name = tensor_name(input_tensor) - input_array.shape.dims.extend(map(int, input_tensor.get_shape())) + if input_shapes is None: + shape = input_tensor.get_shape() + else: + shape = input_shapes[idx] + input_array.shape.dims.extend(map(int, shape)) for output_tensor in output_tensors: model.output_arrays.append(tensor_name(output_tensor)) diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py index 779bda4c9d05fd056d6a262412fdcf0d47e7c57c..1be61fe05343a0e7d39f2808c78672698e0d767f 100644 --- a/tensorflow/contrib/lite/python/interpreter.py +++ b/tensorflow/contrib/lite/python/interpreter.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys +import numpy as np from tensorflow.python.util.lazy_loader import LazyLoader # Lazy load since some of the performance benchmark skylark rules @@ -52,20 +54,49 @@ class Interpreter(object): if not self._interpreter: raise ValueError('Failed to open {}'.format(model_path)) elif model_content and not model_path: + # Take a reference, so the pointer remains valid. + # Since python strings are immutable then PyString_XX functions + # will always return the same pointer. + self._model_content = model_content self._interpreter = ( _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer( model_content)) - if not self._interpreter: - raise ValueError( - 'Failed to create model from {} bytes'.format(len(model_content))) elif not model_path and not model_path: raise ValueError('`model_path` or `model_content` must be specified.') else: raise ValueError('Can\'t both provide `model_path` and `model_content`') def allocate_tensors(self): - if not self._interpreter.AllocateTensors(): - raise ValueError('Failed to allocate tensors') + self._ensure_safe() + return self._interpreter.AllocateTensors() + + def _safe_to_run(self): + """Returns true if there exist no numpy array buffers. + + This means it is safe to run tflite calls that may destroy internally + allocated memory. This works, because in the wrapper.cc we have made + the numpy base be the self._interpreter. + """ + # NOTE, our tensor() call in cpp will use _interpreter as a base pointer. + # If this environment is the only _interpreter, then the ref count should be + # 2 (1 in self and 1 in temporary of sys.getrefcount). + return sys.getrefcount(self._interpreter) == 2 + + def _ensure_safe(self): + """Makes sure no numpy arrays pointing to internal buffers are active. + + This should be called from any function that will call a function on + _interpreter that may reallocate memory e.g. invoke(), ... + + Raises: + RuntimeError: If there exist numpy objects pointing to internal memory + then we throw. + """ + if not self._safe_to_run(): + raise RuntimeError("""There is at least 1 reference to internal data + in the interpreter in the form of a numpy array or slice. Be sure to + only hold the function returned from tensor() if you are using raw + data access.""") def _get_tensor_details(self, tensor_index): """Gets tensor details. @@ -109,7 +140,10 @@ class Interpreter(object): ] def set_tensor(self, tensor_index, value): - """Sets the value of the input. + """Sets the value of the input tensor. Note this copies data in `value`. + + If you want to avoid copying, you can use the `tensor()` function to get a + numpy buffer pointing to the input buffer in the tflite interpreter. Args: tensor_index: Tensor index of tensor to set. This value can be gotten from @@ -119,8 +153,7 @@ class Interpreter(object): Raises: ValueError: If the interpreter could not set the tensor. """ - if not self._interpreter.SetTensor(tensor_index, value): - raise ValueError('Failed to set tensor') + self._interpreter.SetTensor(tensor_index, value) def resize_tensor_input(self, input_index, tensor_size): """Resizes an input tensor. @@ -133,8 +166,11 @@ class Interpreter(object): Raises: ValueError: If the interpreter could not resize the input tensor. """ - if not self._interpreter.ResizeInputTensor(input_index, tensor_size): - raise ValueError('Failed to resize input') + self._ensure_safe() + # `ResizeInputTensor` now only accepts int32 numpy array as `tensor_size + # parameter. + tensor_size = np.array(tensor_size, dtype=np.int32) + self._interpreter.ResizeInputTensor(input_index, tensor_size) def get_output_details(self): """Gets model output details. @@ -147,7 +183,9 @@ class Interpreter(object): ] def get_tensor(self, tensor_index): - """Sets the value of the input. + """Gets the value of the input tensor (get a copy). + + If you wish to avoid the copy, use `tensor()`. Args: tensor_index: Tensor index of tensor to get. This value can be gotten from @@ -158,6 +196,62 @@ class Interpreter(object): """ return self._interpreter.GetTensor(tensor_index) + def tensor(self, tensor_index): + """Returns function that gives a numpy view of the current tensor buffer. + + This allows reading and writing to this tensors w/o copies. This more + closely mirrors the C++ Interpreter class interface's tensor() member, hence + the name. Be careful to not hold these output references through calls + to `allocate_tensors()` and `invoke()`. + + Usage: + + interpreter.allocate_tensors() + input = interpreter.tensor(interpreter.get_input_details()[0]["index"]) + output = interpreter.tensor(interpreter.get_output_details()[0]["index"]) + for i in range(10): + input().fill(3.) + interpreter.invoke() + print("inference %s" % output()) + + Notice how this function avoids making a numpy array directly. This is + because it is important to not hold actual numpy views to the data longer + than necessary. If you do, then the interpreter can no longer be invoked, + because it is possible the interpreter would resize and invalidate the + referenced tensors. The NumPy API doesn't allow any mutability of the + the underlying buffers. + + WRONG: + + input = interpreter.tensor(interpreter.get_input_details()[0]["index"])() + output = interpreter.tensor(interpreter.get_output_details()[0]["index"])() + interpreter.allocate_tensors() # This will throw RuntimeError + for i in range(10): + input.fill(3.) + interpreter.invoke() # this will throw RuntimeError since input,output + + Args: + tensor_index: Tensor index of tensor to get. This value can be gotten from + the 'index' field in get_output_details. + + Returns: + A function that can return a new numpy array pointing to the internal + TFLite tensor state at any point. It is safe to hold the function forever, + but it is not safe to hold the numpy array forever. + """ + return lambda: self._interpreter.tensor(self._interpreter, tensor_index) + def invoke(self): - if not self._interpreter.Invoke(): - raise ValueError('Failed to invoke TFLite model') + """Invoke the interpreter. + + Be sure to set the input sizes, allocate tensors and fill values before + calling this. + + Raises: + ValueError: When the underlying interpreter fails raise ValueError. + """ + self._ensure_safe() + self._interpreter.Invoke() + + def reset_all_variables_to_zero(self): + return self._interpreter.ResetVariableTensorsToZero() diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py index f802edf020db8a9d4e7bb890aadaae7e34e983a8..e77d52ca9950ec42300264bb56ebce253d4982b1 100644 --- a/tensorflow/contrib/lite/python/interpreter_test.py +++ b/tensorflow/contrib/lite/python/interpreter_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import io import numpy as np +import six from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper from tensorflow.python.framework import test_util @@ -82,7 +83,7 @@ class InterpreterTest(test_util.TensorFlowTestCase): test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8) expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8) interpreter.resize_tensor_input(input_details[0]['index'], - np.array(test_input.shape, dtype=np.int32)) + test_input.shape) interpreter.allocate_tensors() interpreter.set_tensor(input_details[0]['index'], test_input) interpreter.invoke() @@ -91,5 +92,83 @@ class InterpreterTest(test_util.TensorFlowTestCase): self.assertTrue((expected_output == output_data).all()) +class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase): + + def testInvalidModelContent(self): + with self.assertRaisesRegexp(ValueError, + 'Model provided has model identifier \''): + interpreter_wrapper.Interpreter(model_content=six.b('garbage')) + + def testInvalidModelFile(self): + with self.assertRaisesRegexp( + ValueError, 'Could not open \'totally_invalid_file_name\''): + interpreter_wrapper.Interpreter( + model_path='totally_invalid_file_name') + + def testInvokeBeforeReady(self): + interpreter = interpreter_wrapper.Interpreter( + model_path=resource_loader.get_path_to_datafile( + 'testdata/permute_float.tflite')) + with self.assertRaisesRegexp(RuntimeError, + 'Invoke called on model that is not ready'): + interpreter.invoke() + + +class InterpreterTensorAccessorTest(test_util.TensorFlowTestCase): + + def setUp(self): + self.interpreter = interpreter_wrapper.Interpreter( + model_path=resource_loader.get_path_to_datafile( + 'testdata/permute_float.tflite')) + self.interpreter.allocate_tensors() + self.input0 = self.interpreter.get_input_details()[0]['index'] + self.initial_data = np.array([[-1., -2., -3., -4.]], np.float32) + + def testTensorAccessor(self): + """Check that tensor returns a reference.""" + array_ref = self.interpreter.tensor(self.input0) + np.copyto(array_ref(), self.initial_data) + self.assertAllEqual(array_ref(), self.initial_data) + self.assertAllEqual( + self.interpreter.get_tensor(self.input0), self.initial_data) + + def testGetTensorAccessor(self): + """Check that get_tensor returns a copy.""" + self.interpreter.set_tensor(self.input0, self.initial_data) + array_initial_copy = self.interpreter.get_tensor(self.input0) + new_value = np.add(1., array_initial_copy) + self.interpreter.set_tensor(self.input0, new_value) + self.assertAllEqual(array_initial_copy, self.initial_data) + self.assertAllEqual(self.interpreter.get_tensor(self.input0), new_value) + + def testBase(self): + self.assertTrue(self.interpreter._safe_to_run()) + _ = self.interpreter.tensor(self.input0) + self.assertTrue(self.interpreter._safe_to_run()) + in0 = self.interpreter.tensor(self.input0)() + self.assertFalse(self.interpreter._safe_to_run()) + in0b = self.interpreter.tensor(self.input0)() + self.assertFalse(self.interpreter._safe_to_run()) + # Now get rid of the buffers so that we can evaluate. + del in0 + del in0b + self.assertTrue(self.interpreter._safe_to_run()) + + def testBaseProtectsFunctions(self): + in0 = self.interpreter.tensor(self.input0)() + # Make sure we get an exception if we try to run an unsafe operation + with self.assertRaisesRegexp( + RuntimeError, 'There is at least 1 reference'): + _ = self.interpreter.allocate_tensors() + # Make sure we get an exception if we try to run an unsafe operation + with self.assertRaisesRegexp( + RuntimeError, 'There is at least 1 reference'): + _ = self.interpreter.invoke() + # Now test that we can run + del in0 # this is our only buffer reference, so now it is safe to change + in0safe = self.interpreter.tensor(self.input0) + _ = self.interpreter.allocate_tensors() + del in0safe # make sure in0Safe is held but lint doesn't complain + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD index 12ab38847dc0f838ae2c6bf80ed80805285e4b8b..69ee95c320b72b68052c6f76f32c1493707f34b1 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD @@ -13,8 +13,7 @@ cc_library( deps = [ "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/core:lib", - "//tensorflow/python:numpy_lib", + "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", "@com_google_absl//absl/memory", ], diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 5f304ad45d400b13e20bda8184b5b40cfe13f6c2..9ab05f3068494a573ffa5b46f84be66a12d54e46 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -14,14 +14,21 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h" +#include #include #include "absl/memory/memory.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/python/lib/core/numpy.h" + +// Disallow Numpy 1.7 deprecated symbols. +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include + +#include "numpy/arrayobject.h" +#include "numpy/ufuncobject.h" #if PY_MAJOR_VERSION >= 3 #define PY_TO_CPPSTRING PyBytes_AsStringAndSize @@ -31,10 +38,66 @@ limitations under the License. #define CPP_TO_PYSTRING PyString_FromStringAndSize #endif +#define TFLITE_PY_CHECK(x) \ + if ((x) != kTfLiteOk) { \ + return error_reporter_->exception(); \ + } + +#define TFLITE_PY_TENSOR_BOUNDS_CHECK(i) \ + if (i >= interpreter_->tensors_size() || i < 0) { \ + PyErr_Format(PyExc_ValueError, \ + "Invalid tensor index %d exceeds max tensor index %lu", i, \ + interpreter_->tensors_size()); \ + return nullptr; \ + } + +#define TFLITE_PY_ENSURE_VALID_INTERPRETER() \ + if (!interpreter_) { \ + PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \ + return nullptr; \ + } + namespace tflite { namespace interpreter_wrapper { +class PythonErrorReporter : public tflite::ErrorReporter { + public: + PythonErrorReporter() {} + + // Report an error message + int Report(const char* format, va_list args) override { + char buf[1024]; + int formatted = vsnprintf(buf, sizeof(buf), format, args); + buffer_ << buf; + return formatted; + } + + // Set's a Python runtime exception with the last error. + PyObject* exception() { + std::string last_message = message(); + PyErr_SetString(PyExc_RuntimeError, last_message.c_str()); + return nullptr; + } + + // Gets the last error message and clears the buffer. + std::string message() { + std::string value = buffer_.str(); + buffer_.clear(); + return value; + } + + private: + std::stringstream buffer_; +}; + namespace { + +// Calls PyArray's initialization to initialize all the API pointers. Note that +// this usage implies only this translation unit can use the pointers. See +// tensorflow/python/core/numpy.cc for a strategy if we ever need to extend +// this further. +void ImportNumpy() { import_array1(); } + std::unique_ptr CreateInterpreter( const tflite::FlatBufferModel* model, const tflite::ops::builtin::BuiltinOpResolver& resolver) { @@ -42,22 +105,11 @@ std::unique_ptr CreateInterpreter( return nullptr; } - tensorflow::ImportNumpy(); + ImportNumpy(); std::unique_ptr interpreter; - tflite::InterpreterBuilder(*model, resolver)(&interpreter); - if (interpreter) { - for (const int input_index : interpreter->inputs()) { - const TfLiteTensor* tensor = interpreter->tensor(input_index); - CHECK(tensor); - const TfLiteIntArray* dims = tensor->dims; - if (!dims) { - continue; - } - - std::vector input_dims(dims->data, dims->data + dims->size); - interpreter->ResizeInputTensor(input_index, input_dims); - } + if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) { + return nullptr; } return interpreter; } @@ -68,6 +120,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { return NPY_FLOAT32; case kTfLiteInt32: return NPY_INT32; + case kTfLiteInt16: + return NPY_INT16; case kTfLiteUInt8: return NPY_UINT8; case kTfLiteInt64: @@ -76,11 +130,13 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { return NPY_OBJECT; case kTfLiteBool: return NPY_BOOL; + case kTfLiteComplex64: + return NPY_COMPLEX64; case kTfLiteNoType: - return -1; + return NPY_NOTYPE; + // Avoid default so compiler errors created when new types are made. } - LOG(ERROR) << "Unknown TfLiteType " << tf_lite_type; - return -1; + return NPY_NOTYPE; } TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { @@ -90,6 +146,8 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { return kTfLiteFloat32; case NPY_INT32: return kTfLiteInt32; + case NPY_INT16: + return kTfLiteInt16; case NPY_UINT8: return kTfLiteUInt8; case NPY_INT64: @@ -100,8 +158,10 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { case NPY_STRING: case NPY_UNICODE: return kTfLiteString; + case NPY_COMPLEX64: + return kTfLiteComplex64; + // Avoid default so compiler errors created when new types are made. } - LOG(ERROR) << "Unknown PyArray dtype " << pyarray_type; return kTfLiteNoType; } @@ -124,33 +184,54 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) { } // namespace +InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper( + std::unique_ptr model, + std::unique_ptr error_reporter, + std::string* error_msg) { + if (!model) { + *error_msg = error_reporter->message(); + return nullptr; + } + + auto resolver = absl::make_unique(); + auto interpreter = CreateInterpreter(model.get(), *resolver); + if (!interpreter) { + *error_msg = error_reporter->message(); + return nullptr; + } + + InterpreterWrapper* wrapper = + new InterpreterWrapper(std::move(model), std::move(error_reporter), + std::move(resolver), std::move(interpreter)); + return wrapper; +} + InterpreterWrapper::InterpreterWrapper( - std::unique_ptr model) + std::unique_ptr model, + std::unique_ptr error_reporter, + std::unique_ptr resolver, + std::unique_ptr interpreter) : model_(std::move(model)), - resolver_(absl::make_unique()), - interpreter_(CreateInterpreter(model_.get(), *resolver_)) {} + error_reporter_(std::move(error_reporter)), + resolver_(std::move(resolver)), + interpreter_(std::move(interpreter)) {} InterpreterWrapper::~InterpreterWrapper() {} -bool InterpreterWrapper::AllocateTensors() { - if (!interpreter_) { - LOG(ERROR) << "Cannot allocate tensors: invalid interpreter."; - return false; - } - - if (interpreter_->AllocateTensors() != kTfLiteOk) { - LOG(ERROR) << "Unable to allocate tensors."; - return false; - } - - return true; +PyObject* InterpreterWrapper::AllocateTensors() { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_CHECK(interpreter_->AllocateTensors()); + Py_RETURN_NONE; } -bool InterpreterWrapper::Invoke() { - return interpreter_ ? (interpreter_->Invoke() == kTfLiteOk) : false; +PyObject* InterpreterWrapper::Invoke() { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_CHECK(interpreter_->Invoke()); + Py_RETURN_NONE; } PyObject* InterpreterWrapper::InputIndices() const { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(), interpreter_->inputs().size()); @@ -164,35 +245,36 @@ PyObject* InterpreterWrapper::OutputIndices() const { return PyArray_Return(reinterpret_cast(np_array)); } -bool InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) { - if (!interpreter_) { - LOG(ERROR) << "Invalid interpreter."; - return false; - } +PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); std::unique_ptr array_safe( PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr)); if (!array_safe) { - LOG(ERROR) << "Failed to convert value into readable tensor."; - return false; + PyErr_SetString(PyExc_ValueError, + "Failed to convert numpy value into readable tensor."); + return nullptr; } PyArrayObject* array = reinterpret_cast(array_safe.get()); if (PyArray_NDIM(array) != 1) { - LOG(ERROR) << "Expected 1-D defining input shape."; - return false; + PyErr_Format(PyExc_ValueError, "Shape should be 1D instead of %d.", + PyArray_NDIM(array)); + return nullptr; } if (PyArray_TYPE(array) != NPY_INT32) { - LOG(ERROR) << "Shape must be an int32 array"; - return false; + PyErr_Format(PyExc_ValueError, "Shape must be type int32 (was %d).", + PyArray_TYPE(array)); + return nullptr; } std::vector dims(PyArray_SHAPE(array)[0]); memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int)); - return (interpreter_->ResizeInputTensor(i, dims) == kTfLiteOk); + TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims)); + Py_RETURN_NONE; } std::string InterpreterWrapper::TensorName(int i) const { @@ -205,21 +287,21 @@ std::string InterpreterWrapper::TensorName(int i) const { } PyObject* InterpreterWrapper::TensorType(int i) const { - if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) { - return nullptr; - } + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(i); const TfLiteTensor* tensor = interpreter_->tensor(i); - int typenum = TfLiteTypeToPyArrayType(tensor->type); - return PyArray_TypeObjectFromType(typenum); + int code = TfLiteTypeToPyArrayType(tensor->type); + if (code == -1) { + PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code); + return nullptr; + } + return PyArray_TypeObjectFromType(code); } PyObject* InterpreterWrapper::TensorSize(int i) const { - if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) { - Py_INCREF(Py_None); - return Py_None; - } - + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(i); const TfLiteTensor* tensor = interpreter_->tensor(i); PyObject* np_array = PyArrayFromIntVector(tensor->dims->data, tensor->dims->size); @@ -228,120 +310,166 @@ PyObject* InterpreterWrapper::TensorSize(int i) const { } PyObject* InterpreterWrapper::TensorQuantization(int i) const { - if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) { - Py_INCREF(Py_None); - return Py_None; - } - + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(i); const TfLiteTensor* tensor = interpreter_->tensor(i); return PyTupleFromQuantizationParam(tensor->params); } -bool InterpreterWrapper::SetTensor(int i, PyObject* value) { - if (!interpreter_) { - LOG(ERROR) << "Invalid interpreter."; - return false; - } - - if (i >= interpreter_->tensors_size()) { - LOG(ERROR) << "Invalid tensor index: " << i << " exceeds max tensor index " - << interpreter_->tensors_size(); - return false; - } +PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(i); std::unique_ptr array_safe( PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr)); if (!array_safe) { - LOG(ERROR) << "Failed to convert value into readable tensor."; - return false; + PyErr_SetString(PyExc_ValueError, + "Failed to convert value into readable tensor."); + return nullptr; } PyArrayObject* array = reinterpret_cast(array_safe.get()); const TfLiteTensor* tensor = interpreter_->tensor(i); if (TfLiteTypeFromPyArray(array) != tensor->type) { - LOG(ERROR) << "Cannot set tensor:" - << " Got tensor of type " << TfLiteTypeFromPyArray(array) - << " but expected type " << tensor->type << " for input " << i; - return false; + PyErr_Format(PyExc_ValueError, + "Cannot set tensor:" + " Got tensor of type %d" + " but expected type %d for input %d ", + TfLiteTypeFromPyArray(array), tensor->type, i); + return nullptr; } if (PyArray_NDIM(array) != tensor->dims->size) { - LOG(ERROR) << "Cannot set tensor: Dimension mismatch"; - return false; + PyErr_SetString(PyExc_ValueError, "Cannot set tensor: Dimension mismatch"); + return nullptr; } for (int j = 0; j < PyArray_NDIM(array); j++) { if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) { - LOG(ERROR) << "Cannot set tensor: Dimension mismatch"; - return false; + PyErr_SetString(PyExc_ValueError, + "Cannot set tensor: Dimension mismatch"); + return nullptr; } } size_t size = PyArray_NBYTES(array); - DCHECK_EQ(size, tensor->bytes); + if (size != tensor->bytes) { + PyErr_Format(PyExc_ValueError, + "numpy array had %zu bytes but expected %zu bytes.", size, + tensor->bytes); + return nullptr; + } memcpy(tensor->data.raw, PyArray_DATA(array), size); - return true; + Py_RETURN_NONE; } -PyObject* InterpreterWrapper::GetTensor(int i) const { - if (!interpreter_) { - LOG(ERROR) << "Invalid interpreter."; - Py_INCREF(Py_None); - return Py_None; - } +namespace { + +// Checks to see if a tensor access can succeed (returns nullptr on error). +// Otherwise returns Py_None. +PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index, + TfLiteTensor** tensor, int* type_num) { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(tensor_index); - if (i >= interpreter_->tensors_size()) { - LOG(ERROR) << "Invalid tensor index: " << i << " exceeds max tensor index " - << interpreter_->inputs().size(); - Py_INCREF(Py_None); - return Py_None; + *tensor = interpreter_->tensor(tensor_index); + if ((*tensor)->bytes == 0) { + PyErr_SetString(PyExc_ValueError, "Invalid tensor size."); + return nullptr; } - const TfLiteTensor* output_tensor = interpreter_->tensor(i); - const int tensor_size = output_tensor->bytes; - if (tensor_size <= 0) { - LOG(ERROR) << "Invalid tensor size"; - Py_INCREF(Py_None); - return Py_None; + *type_num = TfLiteTypeToPyArrayType((*tensor)->type); + if (*type_num == -1) { + PyErr_SetString(PyExc_ValueError, "Unknown tensor type."); + return nullptr; } - int type_num = TfLiteTypeToPyArrayType(output_tensor->type); - if (type_num == -1) { - LOG(ERROR) << "Unknown tensor type " << output_tensor->type; - Py_INCREF(Py_None); - return Py_None; + if (!(*tensor)->data.raw) { + PyErr_SetString(PyExc_ValueError, "Tensor data is null."); + return nullptr; } - void* data = malloc(tensor_size); - memcpy(data, output_tensor->data.raw, tensor_size); + Py_RETURN_NONE; +} - const TfLiteIntArray* output_dims = output_tensor->dims; - std::vector dims(output_dims->data, - output_dims->data + output_dims->size); +} // namespace + +PyObject* InterpreterWrapper::GetTensor(int i) const { + // Sanity check accessor + TfLiteTensor* tensor = nullptr; + int type_num = 0; + + PyObject* check_result = + CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num); + if (check_result == nullptr) return check_result; + Py_XDECREF(check_result); + + std::vector dims(tensor->dims->data, + tensor->dims->data + tensor->dims->size); + // Make a buffer copy but we must tell Numpy It owns that data or else + // it will leak. + void* data = malloc(tensor->bytes); + if (!data) { + PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed."); + return nullptr; + } + memcpy(data, tensor->data.raw, tensor->bytes); PyObject* np_array = PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data); - + PyArray_ENABLEFLAGS(reinterpret_cast(np_array), + NPY_ARRAY_OWNDATA); return PyArray_Return(reinterpret_cast(np_array)); } +PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) { + // Sanity check accessor + TfLiteTensor* tensor = nullptr; + int type_num = 0; + + PyObject* check_result = + CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num); + if (check_result == nullptr) return check_result; + Py_XDECREF(check_result); + + std::vector dims(tensor->dims->data, + tensor->dims->data + tensor->dims->size); + PyArrayObject* np_array = + reinterpret_cast(PyArray_SimpleNewFromData( + dims.size(), dims.data(), type_num, tensor->data.raw)); + Py_INCREF(base_object); // SetBaseObject steals, so we need to add. + PyArray_SetBaseObject(np_array, base_object); + return PyArray_Return(np_array); +} + InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( - const char* model_path) { + const char* model_path, std::string* error_msg) { + std::unique_ptr error_reporter(new PythonErrorReporter); std::unique_ptr model = - tflite::FlatBufferModel::BuildFromFile(model_path); - return model ? new InterpreterWrapper(std::move(model)) : nullptr; + tflite::FlatBufferModel::BuildFromFile(model_path, error_reporter.get()); + return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), + error_msg); } InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( - PyObject* data) { + PyObject* data, std::string* error_msg) { char * buf = nullptr; Py_ssize_t length; + std::unique_ptr error_reporter(new PythonErrorReporter); if (PY_TO_CPPSTRING(data, &buf, &length) == -1) { return nullptr; } std::unique_ptr model = - tflite::FlatBufferModel::BuildFromBuffer(buf, length); - return model ? new InterpreterWrapper(std::move(model)) : nullptr; + tflite::FlatBufferModel::BuildFromBuffer(buf, length, + error_reporter.get()); + return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), + error_msg); +} + +PyObject* InterpreterWrapper::ResetVariableTensorsToZero() { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_CHECK(interpreter_->ResetVariableTensorsToZero()); + Py_RETURN_NONE; } } // namespace interpreter_wrapper diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h index c02aa3804367f787016ef78fc8557005507f051b..641dd93db5b9df292e03e9704a218299f48b14fb 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -21,6 +21,9 @@ limitations under the License. // Place `` before to avoid build failures in macOS. #include + +// The empty line above is on purpose as otherwise clang-format will +// automatically move before . #include // We forward declare TFLite classes here to avoid exposing them to SWIG. @@ -36,38 +39,63 @@ class Interpreter; namespace interpreter_wrapper { +class PythonErrorReporter; + class InterpreterWrapper { public: // SWIG caller takes ownership of pointer. - static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path); + static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path, + std::string* error_msg); // SWIG caller takes ownership of pointer. - static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data); + static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data, + std::string* error_msg); ~InterpreterWrapper(); - bool AllocateTensors(); - bool Invoke(); + PyObject* AllocateTensors(); + PyObject* Invoke(); PyObject* InputIndices() const; PyObject* OutputIndices() const; - bool ResizeInputTensor(int i, PyObject* value); + PyObject* ResizeInputTensor(int i, PyObject* value); std::string TensorName(int i) const; PyObject* TensorType(int i) const; PyObject* TensorSize(int i) const; PyObject* TensorQuantization(int i) const; - bool SetTensor(int i, PyObject* value); + PyObject* SetTensor(int i, PyObject* value); PyObject* GetTensor(int i) const; + PyObject* ResetVariableTensorsToZero(); + + // Returns a reference to tensor index i as a numpy array. The base_object + // should be the interpreter object providing the memory. + PyObject* tensor(PyObject* base_object, int i); private: - InterpreterWrapper(std::unique_ptr model); + // Helper function to construct an `InterpreterWrapper` object. + // It only returns InterpreterWrapper if it can construct an `Interpreter`. + // Otherwise it returns `nullptr`. + static InterpreterWrapper* CreateInterpreterWrapper( + std::unique_ptr model, + std::unique_ptr error_reporter, + std::string* error_msg); + + InterpreterWrapper( + std::unique_ptr model, + std::unique_ptr error_reporter, + std::unique_ptr resolver, + std::unique_ptr interpreter); // InterpreterWrapper is not copyable or assignable. We avoid the use of // InterpreterWrapper() = delete here for SWIG compatibility. InterpreterWrapper(); InterpreterWrapper(const InterpreterWrapper& rhs); + // The public functions which creates `InterpreterWrapper` should ensure all + // these member variables are initialized successfully. Otherwise it should + // report the error and return `nullptr`. const std::unique_ptr model_; + const std::unique_ptr error_reporter_; const std::unique_ptr resolver_; const std::unique_ptr interpreter_; }; diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i index 7f51f9f00d1b2fe057052f7b7bd52bcb65231164..afb2092eacab1d8dcccf8c75cee1d8d5c34d7e75 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i @@ -18,8 +18,51 @@ limitations under the License. %{ #define SWIG_FILE_WITH_INIT +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h" %} %include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h" + +namespace tflite { +namespace interpreter_wrapper { +%extend InterpreterWrapper { + + // Version of the constructor that handles producing Python exceptions + // that propagate strings. + static PyObject* CreateWrapperCPPFromFile(const char* model_path) { + std::string error; + if(tflite::interpreter_wrapper::InterpreterWrapper* ptr = + tflite::interpreter_wrapper::InterpreterWrapper + ::CreateWrapperCPPFromFile( + model_path, &error)) { + return SWIG_NewPointerObj( + ptr, SWIGTYPE_p_tflite__interpreter_wrapper__InterpreterWrapper, 1); + } else { + PyErr_SetString(PyExc_ValueError, error.c_str()); + return nullptr; + } + } + + // Version of the constructor that handles producing Python exceptions + // that propagate strings. + static PyObject* CreateWrapperCPPFromBuffer( + PyObject* data) { + std::string error; + if(tflite::interpreter_wrapper::InterpreterWrapper* ptr = + tflite::interpreter_wrapper::InterpreterWrapper + ::CreateWrapperCPPFromBuffer( + data, &error)) { + return SWIG_NewPointerObj( + ptr, SWIGTYPE_p_tflite__interpreter_wrapper__InterpreterWrapper, 1); + } else { + PyErr_SetString(PyExc_ValueError, error.c_str()); + return nullptr; + } + } +} + +} // namespace interpreter_wrapper +} // namespace tflite diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 876ffbbffa5a47b91d1318baf431050fe364aac0..5ec52035add63ffe5a47fffae258ce4a2efd1bcc 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -22,6 +22,7 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice. @@Interpreter @@OpHint @@convert_op_hints_to_stubs +@@build_toco_convert_protos @@FLOAT @@QUANTIZED_UINT8 @@ -39,22 +40,23 @@ from google.protobuf import text_format as _text_format from google.protobuf.message import DecodeError from tensorflow.contrib.lite.python import lite_constants as constants from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import -from tensorflow.contrib.lite.python.convert import tensor_name +from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name from tensorflow.contrib.lite.python.convert import toco_convert from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import -from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model -from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names -from tensorflow.contrib.lite.python.convert_saved_model import set_tensor_shapes +from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model +from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names as _get_tensors_from_tensor_names +from tensorflow.contrib.lite.python.convert_saved_model import set_tensor_shapes as _set_tensor_shapes from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: disable=unused-import from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import from tensorflow.core.framework import graph_pb2 as _graph_pb2 +from tensorflow.python import keras as _keras from tensorflow.python.client import session as _session -from tensorflow.python.framework import graph_util as tf_graph_util -from tensorflow.python.framework.importer import import_graph_def -from tensorflow.python.ops.variables import global_variables_initializer -from tensorflow.python.saved_model import signature_constants -from tensorflow.python.saved_model import tag_constants +from tensorflow.python.framework import graph_util as _tf_graph_util +from tensorflow.python.framework import ops as _ops +from tensorflow.python.framework.importer import import_graph_def as _import_graph_def +from tensorflow.python.saved_model import signature_constants as _signature_constants +from tensorflow.python.saved_model import tag_constants as _tag_constants class TocoConverter(object): @@ -65,11 +67,11 @@ class TocoConverter(object): Attributes: - inference_type: Target data type of arrays in the output file. Currently - must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) - inference_input_type: Target data type of input arrays. Allows for a - different type for input arrays in the case of quantization. Currently - must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) + inference_type: Target data type of real-number arrays in the output file. + Must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) + inference_input_type: Target data type of real-number input arrays. Allows + for a different type for input arrays in the case of quantization. + Must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) output_format: Output file format. Currently must be `{TFLITE, GRAPHVIZ_DOT}`. (default TFLITE) quantized_input_stats: Dict of strings representing input tensor names @@ -129,7 +131,7 @@ class TocoConverter(object): Args: - graph_def: TensorFlow GraphDef. + graph_def: Frozen TensorFlow GraphDef. input_tensors: List of input tensors. Type and shape are computed using `foo.get_shape()` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). @@ -175,7 +177,7 @@ class TocoConverter(object): """Creates a TocoConverter class from a file containing a frozen GraphDef. Args: - graph_def_file: Full filepath of file containing TensorFlow GraphDef. + graph_def_file: Full filepath of file containing frozen GraphDef. input_arrays: List of input tensors to freeze graph with. output_arrays: List of output tensors to freeze graph with. input_shapes: Dict of strings representing input tensor names to list of @@ -192,42 +194,41 @@ class TocoConverter(object): The graph is not frozen. input_arrays or output_arrays contains an invalid tensor name. """ - with _session.Session() as sess: - sess.run(global_variables_initializer()) - - # Read GraphDef from file. - graph_def = _graph_pb2.GraphDef() - with open(graph_def_file, "rb") as f: - file_content = f.read() - try: - graph_def.ParseFromString(file_content) - except (_text_format.ParseError, DecodeError): + with _ops.Graph().as_default(): + with _session.Session() as sess: + # Read GraphDef from file. + graph_def = _graph_pb2.GraphDef() + with open(graph_def_file, "rb") as f: + file_content = f.read() try: - print("Ignore 'tcmalloc: large alloc' warnings.") - - if not isinstance(file_content, str): - if PY3: - file_content = file_content.decode('utf-8') - else: - file_content = file_content.encode('utf-8') - _text_format.Merge(file_content, graph_def) + graph_def.ParseFromString(file_content) except (_text_format.ParseError, DecodeError): - raise ValueError( - "Unable to parse input file '{}'.".format(graph_def_file)) - sess.graph.as_default() - import_graph_def(graph_def, name="") - - # Get input and output tensors. - input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays) - output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays) - set_tensor_shapes(input_tensors, input_shapes) + try: + print("Ignore 'tcmalloc: large alloc' warnings.") + + if not isinstance(file_content, str): + if PY3: + file_content = file_content.decode("utf-8") + else: + file_content = file_content.encode("utf-8") + _text_format.Merge(file_content, graph_def) + except (_text_format.ParseError, DecodeError): + raise ValueError( + "Unable to parse input file '{}'.".format(graph_def_file)) + _import_graph_def(graph_def, name="") + + # Get input and output tensors. + input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays) + output_tensors = _get_tensors_from_tensor_names(sess.graph, + output_arrays) + _set_tensor_shapes(input_tensors, input_shapes) - # Check if graph is frozen. - if not _is_frozen_graph(sess): - raise ValueError("Please freeze the graph using freeze_graph.py.") + # Check if graph is frozen. + if not _is_frozen_graph(sess): + raise ValueError("Please freeze the graph using freeze_graph.py.") - # Create TocoConverter class. - return cls(sess.graph_def, input_tensors, output_tensors) + # Create TocoConverter class. + return cls(sess.graph_def, input_tensors, output_tensors) @classmethod def from_saved_model(cls, @@ -258,15 +259,57 @@ class TocoConverter(object): TocoConverter class. """ if tag_set is None: - tag_set = set([tag_constants.SERVING]) + tag_set = set([_tag_constants.SERVING]) if signature_key is None: - signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - result = freeze_saved_model(saved_model_dir, input_arrays, input_shapes, - output_arrays, tag_set, signature_key) + result = _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, + output_arrays, tag_set, signature_key) return cls( graph_def=result[0], input_tensors=result[1], output_tensors=result[2]) + @classmethod + def from_keras_model_file(cls, + model_file, + input_arrays=None, + input_shapes=None, + output_arrays=None): + """Creates a TocoConverter class from a tf.keras model file. + + Args: + model_file: Full filepath of HDF5 file containing the tf.keras model. + input_arrays: List of input tensors to freeze graph with. Uses input + arrays from SignatureDef when none are provided. (default None) + input_shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : + None}). (default None) + output_arrays: List of output tensors to freeze graph with. Uses output + arrays from SignatureDef when none are provided. (default None) + + Returns: + TocoConverter class. + """ + _keras.backend.clear_session() + _keras.backend.set_learning_phase(False) + keras_model = _keras.models.load_model(model_file) + sess = _keras.backend.get_session() + + # Get input and output tensors. + if input_arrays: + input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays) + else: + input_tensors = keras_model.inputs + + if output_arrays: + output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays) + else: + output_tensors = keras_model.outputs + _set_tensor_shapes(input_tensors, input_shapes) + + graph_def = _freeze_graph(sess, output_tensors) + return cls(graph_def, input_tensors, output_tensors) + def convert(self): """Converts a TensorFlow GraphDef based on instance variables. @@ -283,12 +326,12 @@ class TocoConverter(object): for tensor in self._input_tensors: if not tensor.get_shape(): raise ValueError("Provide an input shape for input array '{0}'.".format( - tensor_name(tensor))) + _tensor_name(tensor))) shape = tensor.get_shape().as_list() if None in shape[1:]: raise ValueError( "None is only supported in the 1st dimension. Tensor '{0}' has " - "invalid shape '{1}'.".format(tensor_name(tensor), shape)) + "invalid shape '{1}'.".format(_tensor_name(tensor), shape)) elif shape[0] is None: self._set_batch_size(batch_size=1) @@ -298,7 +341,7 @@ class TocoConverter(object): quantized_stats = [] invalid_stats = [] for tensor in self._input_tensors: - name = tensor_name(tensor) + name = _tensor_name(tensor) if name in self.quantized_input_stats: quantized_stats.append(self.quantized_input_stats[name]) else: @@ -336,7 +379,7 @@ class TocoConverter(object): Returns: List of strings. """ - return [tensor_name(tensor) for tensor in self._input_tensors] + return [_tensor_name(tensor) for tensor in self._input_tensors] def _set_batch_size(self, batch_size): """Sets the first dimension of the input tensor to `batch_size`. @@ -364,7 +407,7 @@ def _is_frozen_graph(sess): Bool. """ for op in sess.graph.get_operations(): - if op.type.startswith("Variable"): + if op.type.startswith("Variable") or op.type.endswith("VariableOp"): return False return True @@ -383,9 +426,8 @@ def _freeze_graph(sess, output_tensors): Frozen GraphDef. """ if not _is_frozen_graph(sess): - sess.run(global_variables_initializer()) - output_arrays = [tensor_name(tensor) for tensor in output_tensors] - return tf_graph_util.convert_variables_to_constants(sess, sess.graph_def, - output_arrays) + output_arrays = [_tensor_name(tensor) for tensor in output_tensors] + return _tf_graph_util.convert_variables_to_constants( + sess, sess.graph_def, output_arrays) else: return sess.graph_def diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index 8c9d2c1651dd2d0b3cd27cf638c04429e3131efb..2f1368422842846aa616eaa7bc1e60ee6b0deaaf 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -19,11 +19,13 @@ from __future__ import division from __future__ import print_function import os +import tempfile import numpy as np from tensorflow.contrib.lite.python import lite from tensorflow.contrib.lite.python import lite_constants from tensorflow.contrib.lite.python.interpreter import Interpreter +from tensorflow.python import keras from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -31,6 +33,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.saved_model import saved_model @@ -196,6 +199,7 @@ class FromSessionTest(test_util.TensorFlowTestCase): 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32) out_tensor = in_tensor + var sess = session.Session() + sess.run(_global_variables_initializer()) # Convert model and ensure model is not None. converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) @@ -267,7 +271,8 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertTrue(num_items_graphviz_video > num_items_graphviz) def testInferenceInputType(self): - in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8) + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) out_tensor = in_tensor + in_tensor sess = session.Session() @@ -286,14 +291,13 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertEqual('Placeholder', input_details[0]['name']) self.assertEqual(np.uint8, input_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) - self.assertEqual((0., 0.), input_details[0]['quantization']) + self.assertEqual((1., 0.), input_details[0]['quantization']) output_details = interpreter.get_output_details() self.assertEqual(1, len(output_details)) self.assertEqual('add', output_details[0]['name']) - self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertEqual(np.float32, output_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) - self.assertEqual((0., 0.), input_details[0]['quantization']) def testDefaultRangesStats(self): in_tensor = array_ops.placeholder( @@ -618,5 +622,313 @@ class FromSavedModelTest(test_util.TensorFlowTestCase): self.assertTrue(tflite_model) +class FromKerasFile(test_util.TensorFlowTestCase): + + def setUp(self): + keras.backend.clear_session() + + def _getSequentialModel(self): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + model.predict(x) + + try: + fd, keras_file = tempfile.mkstemp('.h5') + keras.models.save_model(model, keras_file) + finally: + os.close(fd) + return keras_file + + def testSequentialModel(self): + """Test a Sequential tf.keras model with default inputs.""" + keras_file = self._getSequentialModel() + + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check tensor details of converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('dense_input', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('time_distributed/Reshape_1', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + # Check inference of converted model. + input_data = np.array([[1, 2, 3]], dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], input_data) + interpreter.invoke() + tflite_result = interpreter.get_tensor(output_details[0]['index']) + + keras_model = keras.models.load_model(keras_file) + keras_result = keras_model.predict(input_data) + + np.testing.assert_almost_equal(tflite_result, keras_result, 5) + os.remove(keras_file) + + def testSequentialModelInputArray(self): + """Test a Sequential tf.keras model testing input arrays argument.""" + keras_file = self._getSequentialModel() + + # Invalid input array raises error. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_keras_model_file( + keras_file, input_arrays=['invalid-input']) + self.assertEqual("Invalid tensors 'invalid-input' were found.", + str(error.exception)) + + # Valid input array. + converter = lite.TocoConverter.from_keras_model_file( + keras_file, input_arrays=['dense_input']) + tflite_model = converter.convert() + os.remove(keras_file) + self.assertTrue(tflite_model) + + def testSequentialModelInputShape(self): + """Test a Sequential tf.keras model testing input shapes argument.""" + keras_file = self._getSequentialModel() + + # Passing in shape of invalid input array has no impact as long as all input + # arrays have a shape. + converter = lite.TocoConverter.from_keras_model_file( + keras_file, input_shapes={'invalid-input': [2, 3]}) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Passing in shape of valid input array. + converter = lite.TocoConverter.from_keras_model_file( + keras_file, input_shapes={'dense_input': [2, 3]}) + tflite_model = converter.convert() + os.remove(keras_file) + self.assertTrue(tflite_model) + + # Check input shape from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('dense_input', input_details[0]['name']) + self.assertTrue(([2, 3] == input_details[0]['shape']).all()) + + def testSequentialModelOutputArray(self): + """Test a Sequential tf.keras model testing output arrays argument.""" + keras_file = self._getSequentialModel() + + # Invalid output array raises error. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_keras_model_file( + keras_file, output_arrays=['invalid-output']) + self.assertEqual("Invalid tensors 'invalid-output' were found.", + str(error.exception)) + + # Valid output array. + converter = lite.TocoConverter.from_keras_model_file( + keras_file, output_arrays=['time_distributed/Reshape_1']) + tflite_model = converter.convert() + os.remove(keras_file) + self.assertTrue(tflite_model) + + def testFunctionalModel(self): + """Test a Functional tf.keras model with default inputs.""" + inputs = keras.layers.Input(shape=(3,), name='input') + x = keras.layers.Dense(2)(inputs) + output = keras.layers.Dense(3)(x) + + model = keras.models.Model(inputs, output) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.categorical_accuracy]) + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + model.predict(x) + fd, keras_file = tempfile.mkstemp('.h5') + try: + keras.models.save_model(model, keras_file) + finally: + os.close(fd) + + # Convert to TFLite model. + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check tensor details of converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('input', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('dense_1/BiasAdd', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + # Check inference of converted model. + input_data = np.array([[1, 2, 3]], dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], input_data) + interpreter.invoke() + tflite_result = interpreter.get_tensor(output_details[0]['index']) + + keras_model = keras.models.load_model(keras_file) + keras_result = keras_model.predict(input_data) + + np.testing.assert_almost_equal(tflite_result, keras_result, 5) + os.remove(keras_file) + + def testFunctionalModelMultipleInputs(self): + """Test a Functional tf.keras model with multiple inputs and outputs.""" + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(3,), name='input_b') + dense = keras.layers.Dense(4, name='dense') + c = dense(a) + d = dense(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + + model = keras.models.Model([a, b], [d, e]) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.mae], + loss_weights=[1., 0.5]) + + input_a_np = np.random.random((10, 3)) + input_b_np = np.random.random((10, 3)) + output_d_np = np.random.random((10, 4)) + output_e_np = np.random.random((10, 4)) + model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np]) + + model.predict([input_a_np, input_b_np], batch_size=5) + fd, keras_file = tempfile.mkstemp('.h5') + keras.models.save_model(model, keras_file) + + # Convert to TFLite model. + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + os.close(fd) + os.remove(keras_file) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(2, len(input_details)) + self.assertEqual('input_a', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + self.assertEqual('input_b', input_details[1]['name']) + self.assertEqual(np.float32, input_details[1]['dtype']) + self.assertTrue(([1, 3] == input_details[1]['shape']).all()) + self.assertEqual((0., 0.), input_details[1]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(2, len(output_details)) + self.assertEqual('dense_1/BiasAdd', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 4] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + self.assertEqual('dropout/Identity', output_details[1]['name']) + self.assertEqual(np.float32, output_details[1]['dtype']) + self.assertTrue(([1, 4] == output_details[1]['shape']).all()) + self.assertEqual((0., 0.), output_details[1]['quantization']) + + def testFunctionalSequentialModel(self): + """Test a Functional tf.keras model containing a Sequential model.""" + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + model = keras.models.Model(model.input, model.output) + + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + model.predict(x) + + model.predict(x) + fd, keras_file = tempfile.mkstemp('.h5') + try: + keras.models.save_model(model, keras_file) + finally: + os.close(fd) + + # Convert to TFLite model. + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check tensor details of converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('dense_input', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('time_distributed/Reshape_1', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + # Check inference of converted model. + input_data = np.array([[1, 2, 3]], dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], input_data) + interpreter.invoke() + tflite_result = interpreter.get_tensor(output_details[0]['index']) + + keras_model = keras.models.load_model(keras_file) + keras_result = keras_model.predict(input_data) + + np.testing.assert_almost_equal(tflite_result, keras_result, 5) + os.remove(keras_file) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py index f497533bed054d260aefc7b3fe67ae655c7cbcda..a76cc3963580767ab8bd745a9bcd7c9c780ec2b5 100644 --- a/tensorflow/contrib/lite/python/tflite_convert.py +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -23,19 +23,15 @@ import os import sys from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.lite.python import lite_constants from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2 from tensorflow.python.platform import app -def _parse_array(values): +def _parse_array(values, type_fn=str): if values: - return values.split(",") - - -def _parse_int_array(values): - if values: - return [int(val) for val in values.split(",")] + return [type_fn(val) for val in values.split(",") if val] def _parse_set(values): @@ -57,7 +53,8 @@ def _get_toco_converter(flags): input_shapes = None if flags.input_shapes: input_shapes_list = [ - _parse_int_array(shape) for shape in flags.input_shapes.split(":") + _parse_array(shape, type_fn=int) + for shape in flags.input_shapes.split(":") ] input_shapes = dict(zip(input_arrays, input_shapes_list)) output_arrays = _parse_array(flags.output_arrays) @@ -77,6 +74,9 @@ def _get_toco_converter(flags): converter_kwargs["saved_model_dir"] = flags.saved_model_dir converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set) converter_kwargs["signature_key"] = flags.saved_model_signature_key + elif flags.keras_model_file: + converter_fn = lite.TocoConverter.from_keras_model_file + converter_kwargs["model_file"] = flags.keras_model_file return converter_fn(**converter_kwargs) @@ -103,9 +103,9 @@ def _convert_model(flags): if flags.mean_values and flags.std_dev_values: input_arrays = converter.get_input_arrays() - std_dev_values = _parse_int_array(flags.std_dev_values) - mean_values = _parse_int_array(flags.mean_values) - quant_stats = zip(mean_values, std_dev_values) + std_dev_values = _parse_array(flags.std_dev_values, type_fn=int) + mean_values = _parse_array(flags.mean_values, type_fn=int) + quant_stats = list(zip(mean_values, std_dev_values)) if ((not flags.input_arrays and len(input_arrays) > 1) or (len(input_arrays) != len(quant_stats))): raise ValueError("Mismatching --input_arrays, --std_dev_values, and " @@ -130,6 +130,9 @@ def _convert_model(flags): if flags.allow_custom_ops: converter.allow_custom_ops = flags.allow_custom_ops if flags.quantize_weights: + if flags.inference_type == lite_constants.QUANTIZED_UINT8: + raise ValueError("--quantized_weights is not supported with " + "--inference_type=QUANTIZED_UINT8") converter.quantize_weights = flags.quantize_weights if flags.dump_graphviz_dir: converter.dump_graphviz_dir = flags.dump_graphviz_dir @@ -200,6 +203,10 @@ def _check_flags(flags, unparsed): raise ValueError("--default_ranges_min and --default_ranges_max must be " "used together") + if flags.dump_graphviz_video and not flags.dump_graphviz_dir: + raise ValueError("--dump_graphviz_video must be used with " + "--dump_graphviz_dir") + def run_main(_): """Main in toco_convert.py.""" @@ -219,11 +226,15 @@ def run_main(_): input_file_group.add_argument( "--graph_def_file", type=str, - help="Full filepath of file containing TensorFlow GraphDef.") + help="Full filepath of file containing frozen TensorFlow GraphDef.") input_file_group.add_argument( "--saved_model_dir", type=str, help="Full filepath of directory containing the SavedModel.") + input_file_group.add_argument( + "--keras_model_file", + type=str, + help="Full filepath of HDF5 file containing tf.Keras model.") # Model format flags. parser.add_argument( @@ -235,19 +246,19 @@ def run_main(_): "--inference_type", type=str.upper, choices=["FLOAT", "QUANTIZED_UINT8"], - help="Target data type of arrays in the output file.") + help="Target data type of real-number arrays in the output file.") parser.add_argument( "--inference_input_type", type=str.upper, choices=["FLOAT", "QUANTIZED_UINT8"], - help=("Target data type of input arrays. Allows for a different type for " - "input arrays in the case of quantization.")) + help=("Target data type of real-number input arrays. Allows for a " + "different type for input arrays in the case of quantization.")) # Input and output arrays flags. parser.add_argument( "--input_arrays", type=str, - help="Names of the output arrays, comma-separated.") + help="Names of the input arrays, comma-separated.") parser.add_argument( "--input_shapes", type=str, @@ -275,12 +286,12 @@ def run_main(_): "--std_dev_values", type=str, help=("Standard deviation of training data for each input tensor, " - "comma-separated. Used for quantization. (default None)")) + "comma-separated integers. Used for quantization. (default None)")) parser.add_argument( "--mean_values", type=str, - help=("Mean of training data for each input tensor, comma-separated. " - "Used for quantization. (default None)")) + help=("Mean of training data for each input tensor, comma-separated " + "integers. Used for quantization. (default None)")) parser.add_argument( "--default_ranges_min", type=int, diff --git a/tensorflow/contrib/lite/rpi_makefile.inc b/tensorflow/contrib/lite/rpi_makefile.inc deleted file mode 100644 index 832ef5824bea86a368184bd7e3d17915739e9d46..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/rpi_makefile.inc +++ /dev/null @@ -1,33 +0,0 @@ -# Settings for Raspberry Pi. -ifeq ($(TARGET), RPI) - ifeq ($(TARGET_ARCH), armv7) - CXXFLAGS += \ - -march=armv7-a \ - -mfpu=neon-vfpv4 \ - -funsafe-math-optimizations \ - -ftree-vectorize - - CCFLAGS += \ - -march=armv7-a \ - -mfpu=neon-vfpv4 \ - -funsafe-math-optimizations \ - -ftree-vectorize - - LDFLAGS := \ - -Wl,--no-export-dynamic \ - -Wl,--exclude-libs,ALL \ - -Wl,--gc-sections \ - -Wl,--as-needed - endif - - LIBS := \ - -lstdc++ \ - -lpthread \ - -lm \ - -ldl - - OBJDIR := $(OBJDIR)rpi_$(TARGET_ARCH)/ - LIBDIR := $(LIBDIR)rpi_$(TARGET_ARCH)/ - BINDIR := $(BINDIR)rpi_$(TARGET_ARCH)/ - DEPDIR := $(DEPDIR)rpi_$(TARGET_ARCH)/ -endif diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD index 9717a4a1a496b888348514584888e62c4e3703b4..b616e449e6ddae6467a6b86269cd108c7eec0c26 100644 --- a/tensorflow/contrib/lite/schema/BUILD +++ b/tensorflow/contrib/lite/schema/BUILD @@ -30,7 +30,10 @@ py_test( size = "small", srcs = ["upgrade_schema_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_oss", + "no_pip", + ], deps = [ ":upgrade_schema", "//tensorflow/python:client_testlib", @@ -64,7 +67,9 @@ cc_test( "schema_v3.fbs", ], tags = [ + "no_oss", "tflite_not_portable_android", + "tflite_not_portable_ios", ], deps = [ "//tensorflow/core:lib_platform", diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/BUILD b/tensorflow/contrib/lite/schema/builtin_ops_header/BUILD index 0148149a6adc141d67e82808f7e8c72ddb7e309a..4a627761daf45b0fddd7b99e8a9c3d0d0ed2ee5e 100644 --- a/tensorflow/contrib/lite/schema/builtin_ops_header/BUILD +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/BUILD @@ -24,6 +24,7 @@ cc_binary( cc_test( name = "generator_test", srcs = ["generator_test.cc"], + tags = ["no_oss"], deps = [ ":generator", "@com_google_googletest//:gtest", @@ -36,6 +37,7 @@ cc_test( data = [ "//tensorflow/contrib/lite:builtin_ops.h", ], + tags = ["no_oss"], deps = [ ":generator", "@com_google_googletest//:gtest", diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index ee5208df1456d01f1a99ecc69722f5fb4ab0763a..14f88b4c009e4f7cd913c2a27799ab418562fb1f 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -34,6 +34,8 @@ enum TensorType : byte { INT64 = 4, STRING = 5, BOOL = 6, + INT16 = 7, + COMPLEX64 = 8, } // Parameters for converting a quantized tensor back to float. Given a @@ -42,7 +44,7 @@ enum TensorType : byte { table QuantizationParameters { min:[float]; // For importing back into tensorflow. max:[float]; // For importing back into tensorflow. - scale:[float]; + scale:[float]; // For dequantizing the tensor's values. zero_point:[long]; } @@ -63,6 +65,8 @@ table Tensor { buffer:uint; name:string; // For debugging and importing back into tensorflow. quantization:QuantizationParameters; // Optional. + + is_variable:bool = false; } // A list of builtin operators. Builtin operators are slightly faster than custom @@ -151,6 +155,20 @@ enum BuiltinOperator : byte { EQUAL = 71, NOT_EQUAL = 72, LOG = 73, + SUM = 74, + SQRT = 75, + RSQRT = 76, + SHAPE = 77, + POW = 78, + ARG_MIN = 79, + FAKE_QUANT = 80, + REDUCE_PROD = 81, + REDUCE_MAX = 82, + PACK = 83, + LOGICAL_OR = 84, + ONE_HOT = 85, + LOGICAL_AND = 86, + LOGICAL_NOT = 87, } // Options for the builtin operators. @@ -181,7 +199,7 @@ union BuiltinOptions { BatchToSpaceNDOptions, SpaceToBatchNDOptions, TransposeOptions, - MeanOptions, + ReducerOptions, SubOptions, DivOptions, SqueezeOptions, @@ -209,6 +227,15 @@ union BuiltinOptions { ExpandDimsOptions, EqualOptions, NotEqualOptions, + ShapeOptions, + PowOptions, + ArgMinOptions, + FakeQuantOptions, + PackOptions, + LogicalOrOptions, + OneHotOptions, + LogicalAndOptions, + LogicalNotOptions, } enum Padding : byte { SAME, VALID } @@ -286,9 +313,18 @@ table BidirectionalSequenceRNNOptions { fused_activation_function:ActivationFunctionType; } +enum FullyConnectedOptionsWeightsFormat: byte { + DEFAULT = 0, + SHUFFLED4x16INT8 = 1, +} + // An implementation of TensorFlow fully_connected (a.k.a Dense) layer. table FullyConnectedOptions { + // Parameters for FullyConnected version 1 or above. fused_activation_function:ActivationFunctionType; + + // Parameters for FullyConnected version 2 or above. + weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT; } table SoftmaxOptions { @@ -408,7 +444,7 @@ table TransposeOptions { table ExpOptions { } -table MeanOptions { +table ReducerOptions { keep_dims: bool; } @@ -449,6 +485,10 @@ table ArgMaxOptions { output_type : TensorType; } +table ArgMinOptions { + output_type : TensorType; +} + table GreaterOptions { } @@ -489,6 +529,42 @@ table EqualOptions { table NotEqualOptions { } +table ShapeOptions { + // Optional output type of the operation (int32 or int64). Defaults to int32. + out_type : TensorType; +} + +table PowOptions { +} + +table FakeQuantOptions { + // Parameters supported by version 1: + min:float; + max:float; + num_bits:int; + + // Parameters supported by version 2: + narrow_range:bool; +} + +table PackOptions { + values_count:int; + axis:int; +} + +table LogicalOrOptions { +} + +table OneHotOptions { + axis:int; +} + +table LogicalAndOptions { +} + +table LogicalNotOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { @@ -520,6 +596,16 @@ table Operator { builtin_options:BuiltinOptions; custom_options:[ubyte]; custom_options_format:CustomOptionsFormat; + + // A list of booleans indicating the input tensors which are being mutated by + // this operator.(e.g. used by RNN and LSTM). + // For example, if the "inputs" array refers to 5 tensors and the second and + // fifth are mutable variables, then this list will contain + // [false, true, false, false, true]. + // + // If the list is empty, no variable is mutated in this operator. + // The list either has the same length as `inputs`, or is empty. + mutating_variable_inputs:[bool]; } // The root type, defining a subgraph, which typically represents an entire diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 887e47ed1ea309d025d4be8745ffb8da06e8ee6b..3efa153e2cfd98dcac9352ff0ef4d8eb9bb6b66a 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -127,8 +127,8 @@ struct TransposeOptionsT; struct ExpOptions; struct ExpOptionsT; -struct MeanOptions; -struct MeanOptionsT; +struct ReducerOptions; +struct ReducerOptionsT; struct SqueezeOptions; struct SqueezeOptionsT; @@ -157,6 +157,9 @@ struct TileOptionsT; struct ArgMaxOptions; struct ArgMaxOptionsT; +struct ArgMinOptions; +struct ArgMinOptionsT; + struct GreaterOptions; struct GreaterOptionsT; @@ -193,6 +196,30 @@ struct EqualOptionsT; struct NotEqualOptions; struct NotEqualOptionsT; +struct ShapeOptions; +struct ShapeOptionsT; + +struct PowOptions; +struct PowOptionsT; + +struct FakeQuantOptions; +struct FakeQuantOptionsT; + +struct PackOptions; +struct PackOptionsT; + +struct LogicalOrOptions; +struct LogicalOrOptionsT; + +struct OneHotOptions; +struct OneHotOptionsT; + +struct LogicalAndOptions; +struct LogicalAndOptionsT; + +struct LogicalNotOptions; +struct LogicalNotOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -216,11 +243,13 @@ enum TensorType { TensorType_INT64 = 4, TensorType_STRING = 5, TensorType_BOOL = 6, + TensorType_INT16 = 7, + TensorType_COMPLEX64 = 8, TensorType_MIN = TensorType_FLOAT32, - TensorType_MAX = TensorType_BOOL + TensorType_MAX = TensorType_COMPLEX64 }; -inline TensorType (&EnumValuesTensorType())[7] { +inline TensorType (&EnumValuesTensorType())[9] { static TensorType values[] = { TensorType_FLOAT32, TensorType_FLOAT16, @@ -228,7 +257,9 @@ inline TensorType (&EnumValuesTensorType())[7] { TensorType_UINT8, TensorType_INT64, TensorType_STRING, - TensorType_BOOL + TensorType_BOOL, + TensorType_INT16, + TensorType_COMPLEX64 }; return values; } @@ -242,6 +273,8 @@ inline const char **EnumNamesTensorType() { "INT64", "STRING", "BOOL", + "INT16", + "COMPLEX64", nullptr }; return names; @@ -326,11 +359,25 @@ enum BuiltinOperator { BuiltinOperator_EQUAL = 71, BuiltinOperator_NOT_EQUAL = 72, BuiltinOperator_LOG = 73, + BuiltinOperator_SUM = 74, + BuiltinOperator_SQRT = 75, + BuiltinOperator_RSQRT = 76, + BuiltinOperator_SHAPE = 77, + BuiltinOperator_POW = 78, + BuiltinOperator_ARG_MIN = 79, + BuiltinOperator_FAKE_QUANT = 80, + BuiltinOperator_REDUCE_PROD = 81, + BuiltinOperator_REDUCE_MAX = 82, + BuiltinOperator_PACK = 83, + BuiltinOperator_LOGICAL_OR = 84, + BuiltinOperator_ONE_HOT = 85, + BuiltinOperator_LOGICAL_AND = 86, + BuiltinOperator_LOGICAL_NOT = 87, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_LOG + BuiltinOperator_MAX = BuiltinOperator_LOGICAL_NOT }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[73] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[87] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -404,7 +451,21 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[73] { BuiltinOperator_EXPAND_DIMS, BuiltinOperator_EQUAL, BuiltinOperator_NOT_EQUAL, - BuiltinOperator_LOG + BuiltinOperator_LOG, + BuiltinOperator_SUM, + BuiltinOperator_SQRT, + BuiltinOperator_RSQRT, + BuiltinOperator_SHAPE, + BuiltinOperator_POW, + BuiltinOperator_ARG_MIN, + BuiltinOperator_FAKE_QUANT, + BuiltinOperator_REDUCE_PROD, + BuiltinOperator_REDUCE_MAX, + BuiltinOperator_PACK, + BuiltinOperator_LOGICAL_OR, + BuiltinOperator_ONE_HOT, + BuiltinOperator_LOGICAL_AND, + BuiltinOperator_LOGICAL_NOT }; return values; } @@ -485,6 +546,20 @@ inline const char **EnumNamesBuiltinOperator() { "EQUAL", "NOT_EQUAL", "LOG", + "SUM", + "SQRT", + "RSQRT", + "SHAPE", + "POW", + "ARG_MIN", + "FAKE_QUANT", + "REDUCE_PROD", + "REDUCE_MAX", + "PACK", + "LOGICAL_OR", + "ONE_HOT", + "LOGICAL_AND", + "LOGICAL_NOT", nullptr }; return names; @@ -523,7 +598,7 @@ enum BuiltinOptions { BuiltinOptions_BatchToSpaceNDOptions = 24, BuiltinOptions_SpaceToBatchNDOptions = 25, BuiltinOptions_TransposeOptions = 26, - BuiltinOptions_MeanOptions = 27, + BuiltinOptions_ReducerOptions = 27, BuiltinOptions_SubOptions = 28, BuiltinOptions_DivOptions = 29, BuiltinOptions_SqueezeOptions = 30, @@ -551,11 +626,20 @@ enum BuiltinOptions { BuiltinOptions_ExpandDimsOptions = 52, BuiltinOptions_EqualOptions = 53, BuiltinOptions_NotEqualOptions = 54, + BuiltinOptions_ShapeOptions = 55, + BuiltinOptions_PowOptions = 56, + BuiltinOptions_ArgMinOptions = 57, + BuiltinOptions_FakeQuantOptions = 58, + BuiltinOptions_PackOptions = 59, + BuiltinOptions_LogicalOrOptions = 60, + BuiltinOptions_OneHotOptions = 61, + BuiltinOptions_LogicalAndOptions = 62, + BuiltinOptions_LogicalNotOptions = 63, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_NotEqualOptions + BuiltinOptions_MAX = BuiltinOptions_LogicalNotOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[55] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[64] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -584,7 +668,7 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[55] { BuiltinOptions_BatchToSpaceNDOptions, BuiltinOptions_SpaceToBatchNDOptions, BuiltinOptions_TransposeOptions, - BuiltinOptions_MeanOptions, + BuiltinOptions_ReducerOptions, BuiltinOptions_SubOptions, BuiltinOptions_DivOptions, BuiltinOptions_SqueezeOptions, @@ -611,7 +695,16 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[55] { BuiltinOptions_TileOptions, BuiltinOptions_ExpandDimsOptions, BuiltinOptions_EqualOptions, - BuiltinOptions_NotEqualOptions + BuiltinOptions_NotEqualOptions, + BuiltinOptions_ShapeOptions, + BuiltinOptions_PowOptions, + BuiltinOptions_ArgMinOptions, + BuiltinOptions_FakeQuantOptions, + BuiltinOptions_PackOptions, + BuiltinOptions_LogicalOrOptions, + BuiltinOptions_OneHotOptions, + BuiltinOptions_LogicalAndOptions, + BuiltinOptions_LogicalNotOptions }; return values; } @@ -645,7 +738,7 @@ inline const char **EnumNamesBuiltinOptions() { "BatchToSpaceNDOptions", "SpaceToBatchNDOptions", "TransposeOptions", - "MeanOptions", + "ReducerOptions", "SubOptions", "DivOptions", "SqueezeOptions", @@ -673,6 +766,15 @@ inline const char **EnumNamesBuiltinOptions() { "ExpandDimsOptions", "EqualOptions", "NotEqualOptions", + "ShapeOptions", + "PowOptions", + "ArgMinOptions", + "FakeQuantOptions", + "PackOptions", + "LogicalOrOptions", + "OneHotOptions", + "LogicalAndOptions", + "LogicalNotOptions", nullptr }; return names; @@ -791,8 +893,8 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_TransposeOptions; }; -template<> struct BuiltinOptionsTraits { - static const BuiltinOptions enum_value = BuiltinOptions_MeanOptions; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ReducerOptions; }; template<> struct BuiltinOptionsTraits { @@ -903,6 +1005,42 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_NotEqualOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ShapeOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_PowOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ArgMinOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_FakeQuantOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_PackOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LogicalOrOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_OneHotOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LogicalAndOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LogicalNotOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1142,13 +1280,13 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_TransposeOptions ? reinterpret_cast(value) : nullptr; } - MeanOptionsT *AsMeanOptions() { - return type == BuiltinOptions_MeanOptions ? - reinterpret_cast(value) : nullptr; + ReducerOptionsT *AsReducerOptions() { + return type == BuiltinOptions_ReducerOptions ? + reinterpret_cast(value) : nullptr; } - const MeanOptionsT *AsMeanOptions() const { - return type == BuiltinOptions_MeanOptions ? - reinterpret_cast(value) : nullptr; + const ReducerOptionsT *AsReducerOptions() const { + return type == BuiltinOptions_ReducerOptions ? + reinterpret_cast(value) : nullptr; } SubOptionsT *AsSubOptions() { return type == BuiltinOptions_SubOptions ? @@ -1366,6 +1504,78 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_NotEqualOptions ? reinterpret_cast(value) : nullptr; } + ShapeOptionsT *AsShapeOptions() { + return type == BuiltinOptions_ShapeOptions ? + reinterpret_cast(value) : nullptr; + } + const ShapeOptionsT *AsShapeOptions() const { + return type == BuiltinOptions_ShapeOptions ? + reinterpret_cast(value) : nullptr; + } + PowOptionsT *AsPowOptions() { + return type == BuiltinOptions_PowOptions ? + reinterpret_cast(value) : nullptr; + } + const PowOptionsT *AsPowOptions() const { + return type == BuiltinOptions_PowOptions ? + reinterpret_cast(value) : nullptr; + } + ArgMinOptionsT *AsArgMinOptions() { + return type == BuiltinOptions_ArgMinOptions ? + reinterpret_cast(value) : nullptr; + } + const ArgMinOptionsT *AsArgMinOptions() const { + return type == BuiltinOptions_ArgMinOptions ? + reinterpret_cast(value) : nullptr; + } + FakeQuantOptionsT *AsFakeQuantOptions() { + return type == BuiltinOptions_FakeQuantOptions ? + reinterpret_cast(value) : nullptr; + } + const FakeQuantOptionsT *AsFakeQuantOptions() const { + return type == BuiltinOptions_FakeQuantOptions ? + reinterpret_cast(value) : nullptr; + } + PackOptionsT *AsPackOptions() { + return type == BuiltinOptions_PackOptions ? + reinterpret_cast(value) : nullptr; + } + const PackOptionsT *AsPackOptions() const { + return type == BuiltinOptions_PackOptions ? + reinterpret_cast(value) : nullptr; + } + LogicalOrOptionsT *AsLogicalOrOptions() { + return type == BuiltinOptions_LogicalOrOptions ? + reinterpret_cast(value) : nullptr; + } + const LogicalOrOptionsT *AsLogicalOrOptions() const { + return type == BuiltinOptions_LogicalOrOptions ? + reinterpret_cast(value) : nullptr; + } + OneHotOptionsT *AsOneHotOptions() { + return type == BuiltinOptions_OneHotOptions ? + reinterpret_cast(value) : nullptr; + } + const OneHotOptionsT *AsOneHotOptions() const { + return type == BuiltinOptions_OneHotOptions ? + reinterpret_cast(value) : nullptr; + } + LogicalAndOptionsT *AsLogicalAndOptions() { + return type == BuiltinOptions_LogicalAndOptions ? + reinterpret_cast(value) : nullptr; + } + const LogicalAndOptionsT *AsLogicalAndOptions() const { + return type == BuiltinOptions_LogicalAndOptions ? + reinterpret_cast(value) : nullptr; + } + LogicalNotOptionsT *AsLogicalNotOptions() { + return type == BuiltinOptions_LogicalNotOptions ? + reinterpret_cast(value) : nullptr; + } + const LogicalNotOptionsT *AsLogicalNotOptions() const { + return type == BuiltinOptions_LogicalNotOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -1473,6 +1683,35 @@ inline const char *EnumNameLSHProjectionType(LSHProjectionType e) { return EnumNamesLSHProjectionType()[index]; } +enum FullyConnectedOptionsWeightsFormat { + FullyConnectedOptionsWeightsFormat_DEFAULT = 0, + FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8 = 1, + FullyConnectedOptionsWeightsFormat_MIN = FullyConnectedOptionsWeightsFormat_DEFAULT, + FullyConnectedOptionsWeightsFormat_MAX = FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8 +}; + +inline FullyConnectedOptionsWeightsFormat (&EnumValuesFullyConnectedOptionsWeightsFormat())[2] { + static FullyConnectedOptionsWeightsFormat values[] = { + FullyConnectedOptionsWeightsFormat_DEFAULT, + FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8 + }; + return values; +} + +inline const char **EnumNamesFullyConnectedOptionsWeightsFormat() { + static const char *names[] = { + "DEFAULT", + "SHUFFLED4x16INT8", + nullptr + }; + return names; +} + +inline const char *EnumNameFullyConnectedOptionsWeightsFormat(FullyConnectedOptionsWeightsFormat e) { + const size_t index = static_cast(e); + return EnumNamesFullyConnectedOptionsWeightsFormat()[index]; +} + enum LSTMKernelType { LSTMKernelType_FULL = 0, LSTMKernelType_BASIC = 1, @@ -1671,9 +1910,11 @@ struct TensorT : public flatbuffers::NativeTable { uint32_t buffer; std::string name; std::unique_ptr quantization; + bool is_variable; TensorT() : type(TensorType_FLOAT32), - buffer(0) { + buffer(0), + is_variable(false) { } }; @@ -1684,7 +1925,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_TYPE = 6, VT_BUFFER = 8, VT_NAME = 10, - VT_QUANTIZATION = 12 + VT_QUANTIZATION = 12, + VT_IS_VARIABLE = 14 }; const flatbuffers::Vector *shape() const { return GetPointer *>(VT_SHAPE); @@ -1701,6 +1943,9 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const QuantizationParameters *quantization() const { return GetPointer(VT_QUANTIZATION); } + bool is_variable() const { + return GetField(VT_IS_VARIABLE, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && @@ -1711,6 +1956,7 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.Verify(name()) && VerifyOffset(verifier, VT_QUANTIZATION) && verifier.VerifyTable(quantization()) && + VerifyField(verifier, VT_IS_VARIABLE) && verifier.EndTable(); } TensorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -1736,6 +1982,9 @@ struct TensorBuilder { void add_quantization(flatbuffers::Offset quantization) { fbb_.AddOffset(Tensor::VT_QUANTIZATION, quantization); } + void add_is_variable(bool is_variable) { + fbb_.AddElement(Tensor::VT_IS_VARIABLE, static_cast(is_variable), 0); + } explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -1754,12 +2003,14 @@ inline flatbuffers::Offset CreateTensor( TensorType type = TensorType_FLOAT32, uint32_t buffer = 0, flatbuffers::Offset name = 0, - flatbuffers::Offset quantization = 0) { + flatbuffers::Offset quantization = 0, + bool is_variable = false) { TensorBuilder builder_(_fbb); builder_.add_quantization(quantization); builder_.add_name(name); builder_.add_buffer(buffer); builder_.add_shape(shape); + builder_.add_is_variable(is_variable); builder_.add_type(type); return builder_.Finish(); } @@ -1770,14 +2021,16 @@ inline flatbuffers::Offset CreateTensorDirect( TensorType type = TensorType_FLOAT32, uint32_t buffer = 0, const char *name = nullptr, - flatbuffers::Offset quantization = 0) { + flatbuffers::Offset quantization = 0, + bool is_variable = false) { return tflite::CreateTensor( _fbb, shape ? _fbb.CreateVector(*shape) : 0, type, buffer, name ? _fbb.CreateString(name) : 0, - quantization); + quantization, + is_variable); } flatbuffers::Offset CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -2511,22 +2764,29 @@ flatbuffers::Offset CreateBidirectionalSequence struct FullyConnectedOptionsT : public flatbuffers::NativeTable { typedef FullyConnectedOptions TableType; ActivationFunctionType fused_activation_function; + FullyConnectedOptionsWeightsFormat weights_format; FullyConnectedOptionsT() - : fused_activation_function(ActivationFunctionType_NONE) { + : fused_activation_function(ActivationFunctionType_NONE), + weights_format(FullyConnectedOptionsWeightsFormat_DEFAULT) { } }; struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef FullyConnectedOptionsT NativeTableType; enum { - VT_FUSED_ACTIVATION_FUNCTION = 4 + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_WEIGHTS_FORMAT = 6 }; ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } + FullyConnectedOptionsWeightsFormat weights_format() const { + return static_cast(GetField(VT_WEIGHTS_FORMAT, 0)); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + VerifyField(verifier, VT_WEIGHTS_FORMAT) && verifier.EndTable(); } FullyConnectedOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -2540,6 +2800,9 @@ struct FullyConnectedOptionsBuilder { void add_fused_activation_function(ActivationFunctionType fused_activation_function) { fbb_.AddElement(FullyConnectedOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } + void add_weights_format(FullyConnectedOptionsWeightsFormat weights_format) { + fbb_.AddElement(FullyConnectedOptions::VT_WEIGHTS_FORMAT, static_cast(weights_format), 0); + } explicit FullyConnectedOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2554,8 +2817,10 @@ struct FullyConnectedOptionsBuilder { inline flatbuffers::Offset CreateFullyConnectedOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, + FullyConnectedOptionsWeightsFormat weights_format = FullyConnectedOptionsWeightsFormat_DEFAULT) { FullyConnectedOptionsBuilder builder_(_fbb); + builder_.add_weights_format(weights_format); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } @@ -3822,16 +4087,16 @@ inline flatbuffers::Offset CreateExpOptions( flatbuffers::Offset CreateExpOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -struct MeanOptionsT : public flatbuffers::NativeTable { - typedef MeanOptions TableType; +struct ReducerOptionsT : public flatbuffers::NativeTable { + typedef ReducerOptions TableType; bool keep_dims; - MeanOptionsT() + ReducerOptionsT() : keep_dims(false) { } }; -struct MeanOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef MeanOptionsT NativeTableType; +struct ReducerOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ReducerOptionsT NativeTableType; enum { VT_KEEP_DIMS = 4 }; @@ -3843,38 +4108,38 @@ struct MeanOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_KEEP_DIMS) && verifier.EndTable(); } - MeanOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(MeanOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + ReducerOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ReducerOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; -struct MeanOptionsBuilder { +struct ReducerOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_keep_dims(bool keep_dims) { - fbb_.AddElement(MeanOptions::VT_KEEP_DIMS, static_cast(keep_dims), 0); + fbb_.AddElement(ReducerOptions::VT_KEEP_DIMS, static_cast(keep_dims), 0); } - explicit MeanOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + explicit ReducerOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } - MeanOptionsBuilder &operator=(const MeanOptionsBuilder &); - flatbuffers::Offset Finish() { + ReducerOptionsBuilder &operator=(const ReducerOptionsBuilder &); + flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset(end); + auto o = flatbuffers::Offset(end); return o; } }; -inline flatbuffers::Offset CreateMeanOptions( +inline flatbuffers::Offset CreateReducerOptions( flatbuffers::FlatBufferBuilder &_fbb, bool keep_dims = false) { - MeanOptionsBuilder builder_(_fbb); + ReducerOptionsBuilder builder_(_fbb); builder_.add_keep_dims(keep_dims); return builder_.Finish(); } -flatbuffers::Offset CreateMeanOptions(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateReducerOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SqueezeOptionsT : public flatbuffers::NativeTable { typedef SqueezeOptions TableType; @@ -4374,6 +4639,60 @@ inline flatbuffers::Offset CreateArgMaxOptions( flatbuffers::Offset CreateArgMaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct ArgMinOptionsT : public flatbuffers::NativeTable { + typedef ArgMinOptions TableType; + TensorType output_type; + ArgMinOptionsT() + : output_type(TensorType_FLOAT32) { + } +}; + +struct ArgMinOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ArgMinOptionsT NativeTableType; + enum { + VT_OUTPUT_TYPE = 4 + }; + TensorType output_type() const { + return static_cast(GetField(VT_OUTPUT_TYPE, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_OUTPUT_TYPE) && + verifier.EndTable(); + } + ArgMinOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ArgMinOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ArgMinOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_output_type(TensorType output_type) { + fbb_.AddElement(ArgMinOptions::VT_OUTPUT_TYPE, static_cast(output_type), 0); + } + explicit ArgMinOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ArgMinOptionsBuilder &operator=(const ArgMinOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateArgMinOptions( + flatbuffers::FlatBufferBuilder &_fbb, + TensorType output_type = TensorType_FLOAT32) { + ArgMinOptionsBuilder builder_(_fbb); + builder_.add_output_type(output_type); + return builder_.Finish(); +} + +flatbuffers::Offset CreateArgMinOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct GreaterOptionsT : public flatbuffers::NativeTable { typedef GreaterOptions TableType; GreaterOptionsT() { @@ -4906,166 +5225,592 @@ inline flatbuffers::Offset CreateNotEqualOptions( flatbuffers::Offset CreateNotEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -struct OperatorCodeT : public flatbuffers::NativeTable { - typedef OperatorCode TableType; - BuiltinOperator builtin_code; - std::string custom_code; - int32_t version; - OperatorCodeT() - : builtin_code(BuiltinOperator_ADD), - version(1) { +struct ShapeOptionsT : public flatbuffers::NativeTable { + typedef ShapeOptions TableType; + TensorType out_type; + ShapeOptionsT() + : out_type(TensorType_FLOAT32) { } }; -struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef OperatorCodeT NativeTableType; +struct ShapeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ShapeOptionsT NativeTableType; enum { - VT_BUILTIN_CODE = 4, - VT_CUSTOM_CODE = 6, - VT_VERSION = 8 + VT_OUT_TYPE = 4 }; - BuiltinOperator builtin_code() const { - return static_cast(GetField(VT_BUILTIN_CODE, 0)); - } - const flatbuffers::String *custom_code() const { - return GetPointer(VT_CUSTOM_CODE); - } - int32_t version() const { - return GetField(VT_VERSION, 1); + TensorType out_type() const { + return static_cast(GetField(VT_OUT_TYPE, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_BUILTIN_CODE) && - VerifyOffset(verifier, VT_CUSTOM_CODE) && - verifier.Verify(custom_code()) && - VerifyField(verifier, VT_VERSION) && + VerifyField(verifier, VT_OUT_TYPE) && verifier.EndTable(); } - OperatorCodeT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(OperatorCodeT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + ShapeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ShapeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; -struct OperatorCodeBuilder { +struct ShapeOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_builtin_code(BuiltinOperator builtin_code) { - fbb_.AddElement(OperatorCode::VT_BUILTIN_CODE, static_cast(builtin_code), 0); - } - void add_custom_code(flatbuffers::Offset custom_code) { - fbb_.AddOffset(OperatorCode::VT_CUSTOM_CODE, custom_code); - } - void add_version(int32_t version) { - fbb_.AddElement(OperatorCode::VT_VERSION, version, 1); + void add_out_type(TensorType out_type) { + fbb_.AddElement(ShapeOptions::VT_OUT_TYPE, static_cast(out_type), 0); } - explicit OperatorCodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + explicit ShapeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } - OperatorCodeBuilder &operator=(const OperatorCodeBuilder &); - flatbuffers::Offset Finish() { + ShapeOptionsBuilder &operator=(const ShapeOptionsBuilder &); + flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset(end); + auto o = flatbuffers::Offset(end); return o; } }; -inline flatbuffers::Offset CreateOperatorCode( +inline flatbuffers::Offset CreateShapeOptions( flatbuffers::FlatBufferBuilder &_fbb, - BuiltinOperator builtin_code = BuiltinOperator_ADD, - flatbuffers::Offset custom_code = 0, - int32_t version = 1) { - OperatorCodeBuilder builder_(_fbb); - builder_.add_version(version); - builder_.add_custom_code(custom_code); - builder_.add_builtin_code(builtin_code); + TensorType out_type = TensorType_FLOAT32) { + ShapeOptionsBuilder builder_(_fbb); + builder_.add_out_type(out_type); return builder_.Finish(); } -inline flatbuffers::Offset CreateOperatorCodeDirect( - flatbuffers::FlatBufferBuilder &_fbb, - BuiltinOperator builtin_code = BuiltinOperator_ADD, - const char *custom_code = nullptr, - int32_t version = 1) { - return tflite::CreateOperatorCode( - _fbb, - builtin_code, - custom_code ? _fbb.CreateString(custom_code) : 0, - version); -} +flatbuffers::Offset CreateShapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -flatbuffers::Offset CreateOperatorCode(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct PowOptionsT : public flatbuffers::NativeTable { + typedef PowOptions TableType; + PowOptionsT() { + } +}; -struct OperatorT : public flatbuffers::NativeTable { - typedef Operator TableType; - uint32_t opcode_index; - std::vector inputs; - std::vector outputs; - BuiltinOptionsUnion builtin_options; - std::vector custom_options; - CustomOptionsFormat custom_options_format; - OperatorT() - : opcode_index(0), - custom_options_format(CustomOptionsFormat_FLEXBUFFERS) { +struct PowOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef PowOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); } + PowOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(PowOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; -struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef OperatorT NativeTableType; - enum { - VT_OPCODE_INDEX = 4, - VT_INPUTS = 6, - VT_OUTPUTS = 8, - VT_BUILTIN_OPTIONS_TYPE = 10, - VT_BUILTIN_OPTIONS = 12, - VT_CUSTOM_OPTIONS = 14, - VT_CUSTOM_OPTIONS_FORMAT = 16 - }; - uint32_t opcode_index() const { - return GetField(VT_OPCODE_INDEX, 0); +struct PowOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit PowOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); } - const flatbuffers::Vector *inputs() const { - return GetPointer *>(VT_INPUTS); + PowOptionsBuilder &operator=(const PowOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; } - const flatbuffers::Vector *outputs() const { - return GetPointer *>(VT_OUTPUTS); +}; + +inline flatbuffers::Offset CreatePowOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + PowOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreatePowOptions(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct FakeQuantOptionsT : public flatbuffers::NativeTable { + typedef FakeQuantOptions TableType; + float min; + float max; + int32_t num_bits; + bool narrow_range; + FakeQuantOptionsT() + : min(0.0f), + max(0.0f), + num_bits(0), + narrow_range(false) { } - BuiltinOptions builtin_options_type() const { - return static_cast(GetField(VT_BUILTIN_OPTIONS_TYPE, 0)); +}; + +struct FakeQuantOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef FakeQuantOptionsT NativeTableType; + enum { + VT_MIN = 4, + VT_MAX = 6, + VT_NUM_BITS = 8, + VT_NARROW_RANGE = 10 + }; + float min() const { + return GetField(VT_MIN, 0.0f); } - const void *builtin_options() const { - return GetPointer(VT_BUILTIN_OPTIONS); + float max() const { + return GetField(VT_MAX, 0.0f); } - template const T *builtin_options_as() const; - const Conv2DOptions *builtin_options_as_Conv2DOptions() const { - return builtin_options_type() == BuiltinOptions_Conv2DOptions ? static_cast(builtin_options()) : nullptr; + int32_t num_bits() const { + return GetField(VT_NUM_BITS, 0); } - const DepthwiseConv2DOptions *builtin_options_as_DepthwiseConv2DOptions() const { - return builtin_options_type() == BuiltinOptions_DepthwiseConv2DOptions ? static_cast(builtin_options()) : nullptr; + bool narrow_range() const { + return GetField(VT_NARROW_RANGE, 0) != 0; } - const ConcatEmbeddingsOptions *builtin_options_as_ConcatEmbeddingsOptions() const { - return builtin_options_type() == BuiltinOptions_ConcatEmbeddingsOptions ? static_cast(builtin_options()) : nullptr; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_MIN) && + VerifyField(verifier, VT_MAX) && + VerifyField(verifier, VT_NUM_BITS) && + VerifyField(verifier, VT_NARROW_RANGE) && + verifier.EndTable(); } - const LSHProjectionOptions *builtin_options_as_LSHProjectionOptions() const { - return builtin_options_type() == BuiltinOptions_LSHProjectionOptions ? static_cast(builtin_options()) : nullptr; + FakeQuantOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(FakeQuantOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct FakeQuantOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_min(float min) { + fbb_.AddElement(FakeQuantOptions::VT_MIN, min, 0.0f); } - const Pool2DOptions *builtin_options_as_Pool2DOptions() const { - return builtin_options_type() == BuiltinOptions_Pool2DOptions ? static_cast(builtin_options()) : nullptr; + void add_max(float max) { + fbb_.AddElement(FakeQuantOptions::VT_MAX, max, 0.0f); } - const SVDFOptions *builtin_options_as_SVDFOptions() const { - return builtin_options_type() == BuiltinOptions_SVDFOptions ? static_cast(builtin_options()) : nullptr; + void add_num_bits(int32_t num_bits) { + fbb_.AddElement(FakeQuantOptions::VT_NUM_BITS, num_bits, 0); } - const RNNOptions *builtin_options_as_RNNOptions() const { - return builtin_options_type() == BuiltinOptions_RNNOptions ? static_cast(builtin_options()) : nullptr; + void add_narrow_range(bool narrow_range) { + fbb_.AddElement(FakeQuantOptions::VT_NARROW_RANGE, static_cast(narrow_range), 0); } - const FullyConnectedOptions *builtin_options_as_FullyConnectedOptions() const { - return builtin_options_type() == BuiltinOptions_FullyConnectedOptions ? static_cast(builtin_options()) : nullptr; + explicit FakeQuantOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); } - const SoftmaxOptions *builtin_options_as_SoftmaxOptions() const { - return builtin_options_type() == BuiltinOptions_SoftmaxOptions ? static_cast(builtin_options()) : nullptr; + FakeQuantOptionsBuilder &operator=(const FakeQuantOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; } - const ConcatenationOptions *builtin_options_as_ConcatenationOptions() const { - return builtin_options_type() == BuiltinOptions_ConcatenationOptions ? static_cast(builtin_options()) : nullptr; +}; + +inline flatbuffers::Offset CreateFakeQuantOptions( + flatbuffers::FlatBufferBuilder &_fbb, + float min = 0.0f, + float max = 0.0f, + int32_t num_bits = 0, + bool narrow_range = false) { + FakeQuantOptionsBuilder builder_(_fbb); + builder_.add_num_bits(num_bits); + builder_.add_max(max); + builder_.add_min(min); + builder_.add_narrow_range(narrow_range); + return builder_.Finish(); +} + +flatbuffers::Offset CreateFakeQuantOptions(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct PackOptionsT : public flatbuffers::NativeTable { + typedef PackOptions TableType; + int32_t values_count; + int32_t axis; + PackOptionsT() + : values_count(0), + axis(0) { + } +}; + +struct PackOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef PackOptionsT NativeTableType; + enum { + VT_VALUES_COUNT = 4, + VT_AXIS = 6 + }; + int32_t values_count() const { + return GetField(VT_VALUES_COUNT, 0); + } + int32_t axis() const { + return GetField(VT_AXIS, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_VALUES_COUNT) && + VerifyField(verifier, VT_AXIS) && + verifier.EndTable(); + } + PackOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(PackOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const PackOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct PackOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_values_count(int32_t values_count) { + fbb_.AddElement(PackOptions::VT_VALUES_COUNT, values_count, 0); + } + void add_axis(int32_t axis) { + fbb_.AddElement(PackOptions::VT_AXIS, axis, 0); + } + explicit PackOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + PackOptionsBuilder &operator=(const PackOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreatePackOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t values_count = 0, + int32_t axis = 0) { + PackOptionsBuilder builder_(_fbb); + builder_.add_axis(axis); + builder_.add_values_count(values_count); + return builder_.Finish(); +} + +flatbuffers::Offset CreatePackOptions(flatbuffers::FlatBufferBuilder &_fbb, const PackOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LogicalOrOptionsT : public flatbuffers::NativeTable { + typedef LogicalOrOptions TableType; + LogicalOrOptionsT() { + } +}; + +struct LogicalOrOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef LogicalOrOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + LogicalOrOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LogicalOrOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LogicalOrOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit LogicalOrOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + LogicalOrOptionsBuilder &operator=(const LogicalOrOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateLogicalOrOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + LogicalOrOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateLogicalOrOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct OneHotOptionsT : public flatbuffers::NativeTable { + typedef OneHotOptions TableType; + int32_t axis; + OneHotOptionsT() + : axis(0) { + } +}; + +struct OneHotOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef OneHotOptionsT NativeTableType; + enum { + VT_AXIS = 4 + }; + int32_t axis() const { + return GetField(VT_AXIS, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_AXIS) && + verifier.EndTable(); + } + OneHotOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(OneHotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct OneHotOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_axis(int32_t axis) { + fbb_.AddElement(OneHotOptions::VT_AXIS, axis, 0); + } + explicit OneHotOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + OneHotOptionsBuilder &operator=(const OneHotOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateOneHotOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t axis = 0) { + OneHotOptionsBuilder builder_(_fbb); + builder_.add_axis(axis); + return builder_.Finish(); +} + +flatbuffers::Offset CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LogicalAndOptionsT : public flatbuffers::NativeTable { + typedef LogicalAndOptions TableType; + LogicalAndOptionsT() { + } +}; + +struct LogicalAndOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef LogicalAndOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + LogicalAndOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LogicalAndOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LogicalAndOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit LogicalAndOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + LogicalAndOptionsBuilder &operator=(const LogicalAndOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateLogicalAndOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + LogicalAndOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateLogicalAndOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LogicalNotOptionsT : public flatbuffers::NativeTable { + typedef LogicalNotOptions TableType; + LogicalNotOptionsT() { + } +}; + +struct LogicalNotOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef LogicalNotOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + LogicalNotOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LogicalNotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LogicalNotOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit LogicalNotOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + LogicalNotOptionsBuilder &operator=(const LogicalNotOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateLogicalNotOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + LogicalNotOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateLogicalNotOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct OperatorCodeT : public flatbuffers::NativeTable { + typedef OperatorCode TableType; + BuiltinOperator builtin_code; + std::string custom_code; + int32_t version; + OperatorCodeT() + : builtin_code(BuiltinOperator_ADD), + version(1) { + } +}; + +struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef OperatorCodeT NativeTableType; + enum { + VT_BUILTIN_CODE = 4, + VT_CUSTOM_CODE = 6, + VT_VERSION = 8 + }; + BuiltinOperator builtin_code() const { + return static_cast(GetField(VT_BUILTIN_CODE, 0)); + } + const flatbuffers::String *custom_code() const { + return GetPointer(VT_CUSTOM_CODE); + } + int32_t version() const { + return GetField(VT_VERSION, 1); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_BUILTIN_CODE) && + VerifyOffset(verifier, VT_CUSTOM_CODE) && + verifier.Verify(custom_code()) && + VerifyField(verifier, VT_VERSION) && + verifier.EndTable(); + } + OperatorCodeT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(OperatorCodeT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct OperatorCodeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_builtin_code(BuiltinOperator builtin_code) { + fbb_.AddElement(OperatorCode::VT_BUILTIN_CODE, static_cast(builtin_code), 0); + } + void add_custom_code(flatbuffers::Offset custom_code) { + fbb_.AddOffset(OperatorCode::VT_CUSTOM_CODE, custom_code); + } + void add_version(int32_t version) { + fbb_.AddElement(OperatorCode::VT_VERSION, version, 1); + } + explicit OperatorCodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + OperatorCodeBuilder &operator=(const OperatorCodeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateOperatorCode( + flatbuffers::FlatBufferBuilder &_fbb, + BuiltinOperator builtin_code = BuiltinOperator_ADD, + flatbuffers::Offset custom_code = 0, + int32_t version = 1) { + OperatorCodeBuilder builder_(_fbb); + builder_.add_version(version); + builder_.add_custom_code(custom_code); + builder_.add_builtin_code(builtin_code); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateOperatorCodeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + BuiltinOperator builtin_code = BuiltinOperator_ADD, + const char *custom_code = nullptr, + int32_t version = 1) { + return tflite::CreateOperatorCode( + _fbb, + builtin_code, + custom_code ? _fbb.CreateString(custom_code) : 0, + version); +} + +flatbuffers::Offset CreateOperatorCode(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct OperatorT : public flatbuffers::NativeTable { + typedef Operator TableType; + uint32_t opcode_index; + std::vector inputs; + std::vector outputs; + BuiltinOptionsUnion builtin_options; + std::vector custom_options; + CustomOptionsFormat custom_options_format; + std::vector mutating_variable_inputs; + OperatorT() + : opcode_index(0), + custom_options_format(CustomOptionsFormat_FLEXBUFFERS) { + } +}; + +struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef OperatorT NativeTableType; + enum { + VT_OPCODE_INDEX = 4, + VT_INPUTS = 6, + VT_OUTPUTS = 8, + VT_BUILTIN_OPTIONS_TYPE = 10, + VT_BUILTIN_OPTIONS = 12, + VT_CUSTOM_OPTIONS = 14, + VT_CUSTOM_OPTIONS_FORMAT = 16, + VT_MUTATING_VARIABLE_INPUTS = 18 + }; + uint32_t opcode_index() const { + return GetField(VT_OPCODE_INDEX, 0); + } + const flatbuffers::Vector *inputs() const { + return GetPointer *>(VT_INPUTS); + } + const flatbuffers::Vector *outputs() const { + return GetPointer *>(VT_OUTPUTS); + } + BuiltinOptions builtin_options_type() const { + return static_cast(GetField(VT_BUILTIN_OPTIONS_TYPE, 0)); + } + const void *builtin_options() const { + return GetPointer(VT_BUILTIN_OPTIONS); + } + template const T *builtin_options_as() const; + const Conv2DOptions *builtin_options_as_Conv2DOptions() const { + return builtin_options_type() == BuiltinOptions_Conv2DOptions ? static_cast(builtin_options()) : nullptr; + } + const DepthwiseConv2DOptions *builtin_options_as_DepthwiseConv2DOptions() const { + return builtin_options_type() == BuiltinOptions_DepthwiseConv2DOptions ? static_cast(builtin_options()) : nullptr; + } + const ConcatEmbeddingsOptions *builtin_options_as_ConcatEmbeddingsOptions() const { + return builtin_options_type() == BuiltinOptions_ConcatEmbeddingsOptions ? static_cast(builtin_options()) : nullptr; + } + const LSHProjectionOptions *builtin_options_as_LSHProjectionOptions() const { + return builtin_options_type() == BuiltinOptions_LSHProjectionOptions ? static_cast(builtin_options()) : nullptr; + } + const Pool2DOptions *builtin_options_as_Pool2DOptions() const { + return builtin_options_type() == BuiltinOptions_Pool2DOptions ? static_cast(builtin_options()) : nullptr; + } + const SVDFOptions *builtin_options_as_SVDFOptions() const { + return builtin_options_type() == BuiltinOptions_SVDFOptions ? static_cast(builtin_options()) : nullptr; + } + const RNNOptions *builtin_options_as_RNNOptions() const { + return builtin_options_type() == BuiltinOptions_RNNOptions ? static_cast(builtin_options()) : nullptr; + } + const FullyConnectedOptions *builtin_options_as_FullyConnectedOptions() const { + return builtin_options_type() == BuiltinOptions_FullyConnectedOptions ? static_cast(builtin_options()) : nullptr; + } + const SoftmaxOptions *builtin_options_as_SoftmaxOptions() const { + return builtin_options_type() == BuiltinOptions_SoftmaxOptions ? static_cast(builtin_options()) : nullptr; + } + const ConcatenationOptions *builtin_options_as_ConcatenationOptions() const { + return builtin_options_type() == BuiltinOptions_ConcatenationOptions ? static_cast(builtin_options()) : nullptr; } const AddOptions *builtin_options_as_AddOptions() const { return builtin_options_type() == BuiltinOptions_AddOptions ? static_cast(builtin_options()) : nullptr; @@ -5115,8 +5860,8 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const TransposeOptions *builtin_options_as_TransposeOptions() const { return builtin_options_type() == BuiltinOptions_TransposeOptions ? static_cast(builtin_options()) : nullptr; } - const MeanOptions *builtin_options_as_MeanOptions() const { - return builtin_options_type() == BuiltinOptions_MeanOptions ? static_cast(builtin_options()) : nullptr; + const ReducerOptions *builtin_options_as_ReducerOptions() const { + return builtin_options_type() == BuiltinOptions_ReducerOptions ? static_cast(builtin_options()) : nullptr; } const SubOptions *builtin_options_as_SubOptions() const { return builtin_options_type() == BuiltinOptions_SubOptions ? static_cast(builtin_options()) : nullptr; @@ -5199,12 +5944,42 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const NotEqualOptions *builtin_options_as_NotEqualOptions() const { return builtin_options_type() == BuiltinOptions_NotEqualOptions ? static_cast(builtin_options()) : nullptr; } + const ShapeOptions *builtin_options_as_ShapeOptions() const { + return builtin_options_type() == BuiltinOptions_ShapeOptions ? static_cast(builtin_options()) : nullptr; + } + const PowOptions *builtin_options_as_PowOptions() const { + return builtin_options_type() == BuiltinOptions_PowOptions ? static_cast(builtin_options()) : nullptr; + } + const ArgMinOptions *builtin_options_as_ArgMinOptions() const { + return builtin_options_type() == BuiltinOptions_ArgMinOptions ? static_cast(builtin_options()) : nullptr; + } + const FakeQuantOptions *builtin_options_as_FakeQuantOptions() const { + return builtin_options_type() == BuiltinOptions_FakeQuantOptions ? static_cast(builtin_options()) : nullptr; + } + const PackOptions *builtin_options_as_PackOptions() const { + return builtin_options_type() == BuiltinOptions_PackOptions ? static_cast(builtin_options()) : nullptr; + } + const LogicalOrOptions *builtin_options_as_LogicalOrOptions() const { + return builtin_options_type() == BuiltinOptions_LogicalOrOptions ? static_cast(builtin_options()) : nullptr; + } + const OneHotOptions *builtin_options_as_OneHotOptions() const { + return builtin_options_type() == BuiltinOptions_OneHotOptions ? static_cast(builtin_options()) : nullptr; + } + const LogicalAndOptions *builtin_options_as_LogicalAndOptions() const { + return builtin_options_type() == BuiltinOptions_LogicalAndOptions ? static_cast(builtin_options()) : nullptr; + } + const LogicalNotOptions *builtin_options_as_LogicalNotOptions() const { + return builtin_options_type() == BuiltinOptions_LogicalNotOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } CustomOptionsFormat custom_options_format() const { return static_cast(GetField(VT_CUSTOM_OPTIONS_FORMAT, 0)); } + const flatbuffers::Vector *mutating_variable_inputs() const { + return GetPointer *>(VT_MUTATING_VARIABLE_INPUTS); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_OPCODE_INDEX) && @@ -5218,6 +5993,8 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_CUSTOM_OPTIONS) && verifier.Verify(custom_options()) && VerifyField(verifier, VT_CUSTOM_OPTIONS_FORMAT) && + VerifyOffset(verifier, VT_MUTATING_VARIABLE_INPUTS) && + verifier.Verify(mutating_variable_inputs()) && verifier.EndTable(); } OperatorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -5329,8 +6106,8 @@ template<> inline const TransposeOptions *Operator::builtin_options_as inline const MeanOptions *Operator::builtin_options_as() const { - return builtin_options_as_MeanOptions(); +template<> inline const ReducerOptions *Operator::builtin_options_as() const { + return builtin_options_as_ReducerOptions(); } template<> inline const SubOptions *Operator::builtin_options_as() const { @@ -5409,36 +6186,72 @@ template<> inline const LessEqualOptions *Operator::builtin_options_as inline const SelectOptions *Operator::builtin_options_as() const { - return builtin_options_as_SelectOptions(); +template<> inline const SelectOptions *Operator::builtin_options_as() const { + return builtin_options_as_SelectOptions(); +} + +template<> inline const SliceOptions *Operator::builtin_options_as() const { + return builtin_options_as_SliceOptions(); +} + +template<> inline const TransposeConvOptions *Operator::builtin_options_as() const { + return builtin_options_as_TransposeConvOptions(); +} + +template<> inline const SparseToDenseOptions *Operator::builtin_options_as() const { + return builtin_options_as_SparseToDenseOptions(); +} + +template<> inline const TileOptions *Operator::builtin_options_as() const { + return builtin_options_as_TileOptions(); +} + +template<> inline const ExpandDimsOptions *Operator::builtin_options_as() const { + return builtin_options_as_ExpandDimsOptions(); +} + +template<> inline const EqualOptions *Operator::builtin_options_as() const { + return builtin_options_as_EqualOptions(); +} + +template<> inline const NotEqualOptions *Operator::builtin_options_as() const { + return builtin_options_as_NotEqualOptions(); +} + +template<> inline const ShapeOptions *Operator::builtin_options_as() const { + return builtin_options_as_ShapeOptions(); +} + +template<> inline const PowOptions *Operator::builtin_options_as() const { + return builtin_options_as_PowOptions(); } -template<> inline const SliceOptions *Operator::builtin_options_as() const { - return builtin_options_as_SliceOptions(); +template<> inline const ArgMinOptions *Operator::builtin_options_as() const { + return builtin_options_as_ArgMinOptions(); } -template<> inline const TransposeConvOptions *Operator::builtin_options_as() const { - return builtin_options_as_TransposeConvOptions(); +template<> inline const FakeQuantOptions *Operator::builtin_options_as() const { + return builtin_options_as_FakeQuantOptions(); } -template<> inline const SparseToDenseOptions *Operator::builtin_options_as() const { - return builtin_options_as_SparseToDenseOptions(); +template<> inline const PackOptions *Operator::builtin_options_as() const { + return builtin_options_as_PackOptions(); } -template<> inline const TileOptions *Operator::builtin_options_as() const { - return builtin_options_as_TileOptions(); +template<> inline const LogicalOrOptions *Operator::builtin_options_as() const { + return builtin_options_as_LogicalOrOptions(); } -template<> inline const ExpandDimsOptions *Operator::builtin_options_as() const { - return builtin_options_as_ExpandDimsOptions(); +template<> inline const OneHotOptions *Operator::builtin_options_as() const { + return builtin_options_as_OneHotOptions(); } -template<> inline const EqualOptions *Operator::builtin_options_as() const { - return builtin_options_as_EqualOptions(); +template<> inline const LogicalAndOptions *Operator::builtin_options_as() const { + return builtin_options_as_LogicalAndOptions(); } -template<> inline const NotEqualOptions *Operator::builtin_options_as() const { - return builtin_options_as_NotEqualOptions(); +template<> inline const LogicalNotOptions *Operator::builtin_options_as() const { + return builtin_options_as_LogicalNotOptions(); } struct OperatorBuilder { @@ -5465,6 +6278,9 @@ struct OperatorBuilder { void add_custom_options_format(CustomOptionsFormat custom_options_format) { fbb_.AddElement(Operator::VT_CUSTOM_OPTIONS_FORMAT, static_cast(custom_options_format), 0); } + void add_mutating_variable_inputs(flatbuffers::Offset> mutating_variable_inputs) { + fbb_.AddOffset(Operator::VT_MUTATING_VARIABLE_INPUTS, mutating_variable_inputs); + } explicit OperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -5485,8 +6301,10 @@ inline flatbuffers::Offset CreateOperator( BuiltinOptions builtin_options_type = BuiltinOptions_NONE, flatbuffers::Offset builtin_options = 0, flatbuffers::Offset> custom_options = 0, - CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS) { + CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS, + flatbuffers::Offset> mutating_variable_inputs = 0) { OperatorBuilder builder_(_fbb); + builder_.add_mutating_variable_inputs(mutating_variable_inputs); builder_.add_custom_options(custom_options); builder_.add_builtin_options(builtin_options); builder_.add_outputs(outputs); @@ -5505,7 +6323,8 @@ inline flatbuffers::Offset CreateOperatorDirect( BuiltinOptions builtin_options_type = BuiltinOptions_NONE, flatbuffers::Offset builtin_options = 0, const std::vector *custom_options = nullptr, - CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS) { + CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS, + const std::vector *mutating_variable_inputs = nullptr) { return tflite::CreateOperator( _fbb, opcode_index, @@ -5514,7 +6333,8 @@ inline flatbuffers::Offset CreateOperatorDirect( builtin_options_type, builtin_options, custom_options ? _fbb.CreateVector(*custom_options) : 0, - custom_options_format); + custom_options_format, + mutating_variable_inputs ? _fbb.CreateVector(*mutating_variable_inputs) : 0); } flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -5885,6 +6705,7 @@ inline void Tensor::UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t { auto _e = buffer(); _o->buffer = _e; }; { auto _e = name(); if (_e) _o->name = _e->str(); }; { auto _e = quantization(); if (_e) _o->quantization = std::unique_ptr(_e->UnPack(_resolver)); }; + { auto _e = is_variable(); _o->is_variable = _e; }; } inline flatbuffers::Offset Tensor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -5900,13 +6721,15 @@ inline flatbuffers::Offset CreateTensor(flatbuffers::FlatBufferBuilder & auto _buffer = _o->buffer; auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); auto _quantization = _o->quantization ? CreateQuantizationParameters(_fbb, _o->quantization.get(), _rehasher) : 0; + auto _is_variable = _o->is_variable; return tflite::CreateTensor( _fbb, _shape, _type, _buffer, _name, - _quantization); + _quantization, + _is_variable); } inline Conv2DOptionsT *Conv2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -6210,6 +7033,7 @@ inline void FullyConnectedOptions::UnPackTo(FullyConnectedOptionsT *_o, const fl (void)_o; (void)_resolver; { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; + { auto _e = weights_format(); _o->weights_format = _e; }; } inline flatbuffers::Offset FullyConnectedOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -6221,9 +7045,11 @@ inline flatbuffers::Offset CreateFullyConnectedOptions(fl (void)_o; struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FullyConnectedOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; + auto _weights_format = _o->weights_format; return tflite::CreateFullyConnectedOptions( _fbb, - _fused_activation_function); + _fused_activation_function, + _weights_format); } inline SoftmaxOptionsT *SoftmaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -6830,28 +7656,28 @@ inline flatbuffers::Offset CreateExpOptions(flatbuffers::FlatBufferB _fbb); } -inline MeanOptionsT *MeanOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new MeanOptionsT(); +inline ReducerOptionsT *ReducerOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ReducerOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void MeanOptions::UnPackTo(MeanOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void ReducerOptions::UnPackTo(ReducerOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; { auto _e = keep_dims(); _o->keep_dims = _e; }; } -inline flatbuffers::Offset MeanOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateMeanOptions(_fbb, _o, _rehasher); +inline flatbuffers::Offset ReducerOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateReducerOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateMeanOptions(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateReducerOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MeanOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ReducerOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _keep_dims = _o->keep_dims; - return tflite::CreateMeanOptions( + return tflite::CreateReducerOptions( _fbb, _keep_dims); } @@ -7093,6 +7919,32 @@ inline flatbuffers::Offset CreateArgMaxOptions(flatbuffers::FlatB _output_type); } +inline ArgMinOptionsT *ArgMinOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ArgMinOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ArgMinOptions::UnPackTo(ArgMinOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = output_type(); _o->output_type = _e; }; +} + +inline flatbuffers::Offset ArgMinOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateArgMinOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateArgMinOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ArgMinOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _output_type = _o->output_type; + return tflite::CreateArgMinOptions( + _fbb, + _output_type); +} + inline GreaterOptionsT *GreaterOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new GreaterOptionsT(); UnPackTo(_o, _resolver); @@ -7381,6 +8233,214 @@ inline flatbuffers::Offset CreateNotEqualOptions(flatbuffers::F _fbb); } +inline ShapeOptionsT *ShapeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ShapeOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ShapeOptions::UnPackTo(ShapeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = out_type(); _o->out_type = _e; }; +} + +inline flatbuffers::Offset ShapeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateShapeOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateShapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ShapeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _out_type = _o->out_type; + return tflite::CreateShapeOptions( + _fbb, + _out_type); +} + +inline PowOptionsT *PowOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new PowOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void PowOptions::UnPackTo(PowOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset PowOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreatePowOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreatePowOptions(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PowOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreatePowOptions( + _fbb); +} + +inline FakeQuantOptionsT *FakeQuantOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new FakeQuantOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void FakeQuantOptions::UnPackTo(FakeQuantOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = min(); _o->min = _e; }; + { auto _e = max(); _o->max = _e; }; + { auto _e = num_bits(); _o->num_bits = _e; }; + { auto _e = narrow_range(); _o->narrow_range = _e; }; +} + +inline flatbuffers::Offset FakeQuantOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateFakeQuantOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateFakeQuantOptions(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FakeQuantOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _min = _o->min; + auto _max = _o->max; + auto _num_bits = _o->num_bits; + auto _narrow_range = _o->narrow_range; + return tflite::CreateFakeQuantOptions( + _fbb, + _min, + _max, + _num_bits, + _narrow_range); +} + +inline PackOptionsT *PackOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new PackOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void PackOptions::UnPackTo(PackOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = values_count(); _o->values_count = _e; }; + { auto _e = axis(); _o->axis = _e; }; +} + +inline flatbuffers::Offset PackOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PackOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreatePackOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreatePackOptions(flatbuffers::FlatBufferBuilder &_fbb, const PackOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PackOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _values_count = _o->values_count; + auto _axis = _o->axis; + return tflite::CreatePackOptions( + _fbb, + _values_count, + _axis); +} + +inline LogicalOrOptionsT *LogicalOrOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new LogicalOrOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void LogicalOrOptions::UnPackTo(LogicalOrOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset LogicalOrOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateLogicalOrOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateLogicalOrOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LogicalOrOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateLogicalOrOptions( + _fbb); +} + +inline OneHotOptionsT *OneHotOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new OneHotOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void OneHotOptions::UnPackTo(OneHotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = axis(); _o->axis = _e; }; +} + +inline flatbuffers::Offset OneHotOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateOneHotOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OneHotOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _axis = _o->axis; + return tflite::CreateOneHotOptions( + _fbb, + _axis); +} + +inline LogicalAndOptionsT *LogicalAndOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new LogicalAndOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void LogicalAndOptions::UnPackTo(LogicalAndOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset LogicalAndOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateLogicalAndOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateLogicalAndOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LogicalAndOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateLogicalAndOptions( + _fbb); +} + +inline LogicalNotOptionsT *LogicalNotOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new LogicalNotOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void LogicalNotOptions::UnPackTo(LogicalNotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset LogicalNotOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateLogicalNotOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateLogicalNotOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LogicalNotOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateLogicalNotOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -7429,6 +8489,7 @@ inline void Operator::UnPackTo(OperatorT *_o, const flatbuffers::resolver_functi { auto _e = builtin_options(); if (_e) _o->builtin_options.value = BuiltinOptionsUnion::UnPack(_e, builtin_options_type(), _resolver); }; { auto _e = custom_options(); if (_e) { _o->custom_options.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->custom_options[_i] = _e->Get(_i); } } }; { auto _e = custom_options_format(); _o->custom_options_format = _e; }; + { auto _e = mutating_variable_inputs(); if (_e) { _o->mutating_variable_inputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->mutating_variable_inputs[_i] = _e->Get(_i) != 0; } } }; } inline flatbuffers::Offset Operator::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -7446,6 +8507,7 @@ inline flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuild auto _builtin_options = _o->builtin_options.Pack(_fbb); auto _custom_options = _o->custom_options.size() ? _fbb.CreateVector(_o->custom_options) : 0; auto _custom_options_format = _o->custom_options_format; + auto _mutating_variable_inputs = _o->mutating_variable_inputs.size() ? _fbb.CreateVector(_o->mutating_variable_inputs) : 0; return tflite::CreateOperator( _fbb, _opcode_index, @@ -7454,7 +8516,8 @@ inline flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuild _builtin_options_type, _builtin_options, _custom_options, - _custom_options_format); + _custom_options_format, + _mutating_variable_inputs); } inline SubGraphT *SubGraph::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -7671,8 +8734,8 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } - case BuiltinOptions_MeanOptions: { - auto ptr = reinterpret_cast(obj); + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case BuiltinOptions_SubOptions: { @@ -7783,6 +8846,42 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ArgMinOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_FakeQuantOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_PackOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LogicalOrOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LogicalAndOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LogicalNotOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -7905,8 +9004,8 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } - case BuiltinOptions_MeanOptions: { - auto ptr = reinterpret_cast(obj); + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_SubOptions: { @@ -8017,6 +9116,42 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ArgMinOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_FakeQuantOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_PackOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LogicalOrOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LogicalAndOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LogicalNotOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -8127,9 +9262,9 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateTransposeOptions(_fbb, ptr, _rehasher).Union(); } - case BuiltinOptions_MeanOptions: { - auto ptr = reinterpret_cast(value); - return CreateMeanOptions(_fbb, ptr, _rehasher).Union(); + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(value); + return CreateReducerOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SubOptions: { auto ptr = reinterpret_cast(value); @@ -8239,6 +9374,42 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateNotEqualOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(value); + return CreateShapeOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast(value); + return CreatePowOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ArgMinOptions: { + auto ptr = reinterpret_cast(value); + return CreateArgMinOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_FakeQuantOptions: { + auto ptr = reinterpret_cast(value); + return CreateFakeQuantOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_PackOptions: { + auto ptr = reinterpret_cast(value); + return CreatePackOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LogicalOrOptions: { + auto ptr = reinterpret_cast(value); + return CreateLogicalOrOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast(value); + return CreateOneHotOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LogicalAndOptions: { + auto ptr = reinterpret_cast(value); + return CreateLogicalAndOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LogicalNotOptions: { + auto ptr = reinterpret_cast(value); + return CreateLogicalNotOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -8349,8 +9520,8 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new TransposeOptionsT(*reinterpret_cast(u.value)); break; } - case BuiltinOptions_MeanOptions: { - value = new MeanOptionsT(*reinterpret_cast(u.value)); + case BuiltinOptions_ReducerOptions: { + value = new ReducerOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SubOptions: { @@ -8461,6 +9632,42 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new NotEqualOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_ShapeOptions: { + value = new ShapeOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_PowOptions: { + value = new PowOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ArgMinOptions: { + value = new ArgMinOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_FakeQuantOptions: { + value = new FakeQuantOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_PackOptions: { + value = new PackOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LogicalOrOptions: { + value = new LogicalOrOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_OneHotOptions: { + value = new OneHotOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LogicalAndOptions: { + value = new LogicalAndOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LogicalNotOptions: { + value = new LogicalNotOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -8598,8 +9805,8 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } - case BuiltinOptions_MeanOptions: { - auto ptr = reinterpret_cast(value); + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(value); delete ptr; break; } @@ -8738,6 +9945,51 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ArgMinOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_FakeQuantOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_PackOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LogicalOrOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LogicalAndOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LogicalNotOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/contrib/lite/schema/upgrade_schema.py b/tensorflow/contrib/lite/schema/upgrade_schema.py index e0b36d3d3ee94b00cccd3968d14c63fe19c3c27c..a2ddf6295014f3b29fa584f2bb367a7e0a4399e7 100644 --- a/tensorflow/contrib/lite/schema/upgrade_schema.py +++ b/tensorflow/contrib/lite/schema/upgrade_schema.py @@ -99,9 +99,9 @@ class Converter(object): # dispatch function table. self._schemas.sort() self._new_version, self._new_schema = self._schemas[-1][:2] - self._upgrade_dispatch = dict( - (version, dispatch) - for version, unused1, unused2, dispatch in self._schemas) + self._upgrade_dispatch = { + version: dispatch + for version, unused1, unused2, dispatch in self._schemas} def _Read(self, input_file, schema, raw_binary=False): """Read a tflite model assuming the given flatbuffer schema. diff --git a/tensorflow/contrib/lite/simple_memory_arena.cc b/tensorflow/contrib/lite/simple_memory_arena.cc index 4eaf6f1bfe76efc1e6737d03d58be9bc87bb849d..cd0f1f7c17a50f6ce61fa2033e5d13580399f5cf 100644 --- a/tensorflow/contrib/lite/simple_memory_arena.cc +++ b/tensorflow/contrib/lite/simple_memory_arena.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/lite/simple_memory_arena.h" +#include #include #include #include @@ -34,7 +35,7 @@ namespace tflite { TfLiteStatus SimpleMemoryArena::Allocate(TfLiteContext* context, size_t alignment, size_t size, ArenaAlloc* new_alloc) { - TF_LITE_ENSURE(context, alignment < arena_alignment_); + TF_LITE_ENSURE(context, alignment <= arena_alignment_); if (size == 0) { new_alloc->offset = 0; diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc index a89776b29f895fe82ee71efe00c0949c58c109df..a316a40b62d89189da43768d448acdf5bbeca129 100644 --- a/tensorflow/contrib/lite/string_util.cc +++ b/tensorflow/contrib/lite/string_util.cc @@ -105,7 +105,7 @@ void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) { dims->data[0] = offset_.size() - 1; // Store number of strings. TfLiteTensorReset(tensor->type, tensor->name, dims, tensor->params, tensor_buffer, bytes, kTfLiteDynamic, tensor->allocation, - tensor); + tensor->is_variable, tensor); } int GetStringCount(const char* raw_buffer) { diff --git a/tensorflow/contrib/lite/testdata/add.bin b/tensorflow/contrib/lite/testdata/add.bin new file mode 100644 index 0000000000000000000000000000000000000000..aef0fe3d82c9d92dc444076d3b46e05af1923f46 Binary files /dev/null and b/tensorflow/contrib/lite/testdata/add.bin differ diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index b823c97f38e7660652aa0ce3538b11de59dc9aea..a788d41ba7b370cd0e84c343202f1dca090180f3 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -140,6 +140,7 @@ cc_test( cc_library( name = "join", hdrs = ["join.h"], + deps = ["//tensorflow/contrib/lite:string"], ) cc_test( @@ -172,6 +173,7 @@ cc_test( data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], tags = [ "tflite_not_portable_android", + "tflite_not_portable_ios", ], deps = [ ":tflite_driver", @@ -208,6 +210,10 @@ cc_library( cc_library( name = "util", hdrs = ["util.h"], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string", + ], ) cc_test( @@ -251,6 +257,7 @@ cc_test( srcs = ["tf_driver_test.cc"], data = ["//tensorflow/contrib/lite:testdata/multi_add.pb"], tags = [ + "no_oss", "tflite_not_portable", ], deps = [ @@ -267,6 +274,7 @@ cc_library( ":join", ":split", ":tf_driver", + "//tensorflow/contrib/lite:string", "//tensorflow/core:framework", ], ) @@ -276,6 +284,7 @@ cc_test( size = "small", srcs = ["generate_testspec_test.cc"], tags = [ + "no_oss", "tflite_not_portable", ], deps = [ @@ -332,7 +341,7 @@ tf_cc_test( ], tags = [ "no_cuda_on_cpu_tap", - "no_oss", + "no_oss", # needs test data "tflite_not_portable", ], deps = [ diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index f5e25784fa17209af7cfb06d32aeea2b9b947196..9dd5c8ae449013986594131dc100f3411892729e 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -90,12 +90,10 @@ TEST_INPUT_DEPTH = 3 # matching the expression will be considered due to the corresponding bug. KNOWN_BUGS = { # TOCO doesn't support scalars as input. - r"relu.*input_shape=\[\]": "67587484", - r"sigmoid.*input_shape=\[\]": "67645668", # Concat doesn't work with a single input tensor r"concat.*num_tensors=1": "67378344", - # Transposition in MatMul is not supported. - r"fully_connected.*transpose_.=True": "67586970", + # Transposition in MatMul is not fully supported. + "fully_connected.*transpose_a=True": "67586970", # Softmax graphs are too complex. r"softmax.*dim=0": "67749831", # BatchToSpaceND only supports 4D tensors. @@ -137,7 +135,7 @@ def toco_options(data_types, Returns: the options in a string. """ - shape_str = ":".join([",".join(str(y) for y in x) for x in shapes]) + shape_str = ":".join([",".join(str(y) for y in x) for x in shapes if x]) inference_type = "FLOAT" # TODO(ahentz): if we get multi-input quantization to work we need this # to change @@ -228,7 +226,9 @@ _TF_TYPE_INFO = { tf.float16: (np.float16, "FLOAT"), tf.int32: (np.int32, "INT32"), tf.uint8: (np.uint8, "QUANTIZED_UINT8"), + tf.int16: (np.int16, "QUANTIZED_INT16"), tf.int64: (np.int64, "INT64"), + tf.bool: (np.bool, "BOOL"), } @@ -240,9 +240,12 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100): if dtype in (tf.float32, tf.float16): value = (max_value-min_value)*np.random.random_sample(shape)+min_value - elif dtype in (tf.int32, tf.uint8, tf.int64): + elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16): value = np.random.randint(min_value, max_value+1, shape) - return value.astype(dtype) + elif dtype == tf.bool: + value = np.random.choice([True, False], size=shape) + return np.dtype(dtype).type(value) if np.isscalar(value) else value.astype( + dtype) def create_scalar_data(dtype, min_value=-100, max_value=100): @@ -253,7 +256,7 @@ def create_scalar_data(dtype, min_value=-100, max_value=100): if dtype in (tf.float32, tf.float16): value = (max_value - min_value) * np.random.random() + min_value - elif dtype in (tf.int32, tf.uint8, tf.int64): + elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16): value = np.random.randint(min_value, max_value + 1) return np.array(value, dtype=dtype) @@ -479,7 +482,7 @@ def make_zip_of_tests(zip_path, else report_lib.FAILED) report["toco_log"] = toco_log - if FLAGS.save_graphdefs: + if True or FLAGS.save_graphdefs: archive.writestr(label + ".pbtxt", text_format.MessageToString(graph_def), zipfile.ZIP_DEFLATED) @@ -678,6 +681,63 @@ def make_relu6_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_prelu_tests(zip_path): + """Make a set of tests to do PReLU.""" + + test_parameters = [ + { + # The canonical case for image processing is having a 4D `input` + # (NHWC)and `shared_axes`=[1, 2], so the alpha parameter is per + # channel. + "input_shape": [[1, 10, 10, 3], [3, 3, 3, 3]], + "shared_axes": [[1, 2], [1]], + }, + { + # 2D-3D example. Share the 2nd axis. + "input_shape": [[20, 20], [20, 20, 20]], + "shared_axes": [[1]], + } + ] + + def build_graph(parameters): + """Build the graph for the test case.""" + + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + prelu = tf.keras.layers.PReLU(shared_axes=parameters["shared_axes"]) + out = prelu(input_tensor) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + """Build the inputs for the test case.""" + + input_shape = parameters["input_shape"] + input_values = create_tensor_data( + np.float32, input_shape, min_value=-10, max_value=10) + shared_axes = parameters["shared_axes"] + + alpha_shape = [] + for dim in range(1, len(input_shape)): + alpha_shape.append(1 if dim in shared_axes else input_shape[dim]) + + alpha_values = create_tensor_data(np.float32, alpha_shape) + + # There should be only 1 trainable variable tensor. + variables = tf.all_variables() + assert len(variables) == 1 + sess.run(variables[0].assign(alpha_values)) + + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + use_frozen_graph=True) + + # This function tests various TensorFLow functions that generates Const op, # including `tf.ones`, `tf.zeros` and random functions. def make_constant_tests(zip_path): @@ -685,27 +745,28 @@ def make_constant_tests(zip_path): test_parameters = [{ "dtype": [tf.float32, tf.int32], - "input_shape": [[1], [2], [1, 1, 1, 1], [2, 2, 2, 2]], + "input_shape": [[], [1], [2], [1, 1, 1, 1], [2, 2, 2, 2]], }] def build_graph(parameters): - # Since Toco & Tflite can't have a single constant op in the entire graph, - # this test adds a zero tensor with a constant op tensor. - input1 = tf.placeholder(dtype=parameters["dtype"], name="input1", - shape=parameters["input_shape"]) - out = tf.ones(parameters["input_shape"], dtype=parameters["dtype"]) + input1 - return [input1], [out] + dummy_input = tf.placeholder( + dtype=parameters["dtype"], + name="input1", + shape=parameters["input_shape"]) + out = tf.constant( + create_tensor_data(parameters["dtype"], parameters["input_shape"])) + return [dummy_input], [out] def build_inputs(parameters, sess, inputs, outputs): - input1 = np.zeros(parameters["input_shape"], - dtype=_TF_TYPE_INFO[parameters["dtype"]][0]) - return [input1], sess.run(outputs, feed_dict={inputs[0]: input1}) + dummy_input = np.zeros( + parameters["input_shape"], dtype=_TF_TYPE_INFO[parameters["dtype"]][0]) + return [dummy_input], sess.run(outputs, feed_dict={inputs[0]: dummy_input}) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) def make_binary_op_tests(zip_path, binary_operator): - """Make a set of tests to do add with and without broadcast.""" + """Make a set of tests to do binary ops with and without broadcast.""" # These parameters are split because we don't support broadcasting. test_parameters = [{ @@ -723,6 +784,11 @@ def make_binary_op_tests(zip_path, binary_operator): "input_shape_1": [[1, 3, 4, 3]], "input_shape_2": [[3]], "activation": [True] + }, { + "dtype": [tf.float32], + "input_shape_1": [[]], + "input_shape_2": [[]], + "activation": [False] }] def build_graph(parameters): @@ -755,11 +821,13 @@ def make_binary_op_tests(zip_path, binary_operator): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) -def make_reduce_tests(reduce_op): +def make_reduce_tests(reduce_op, min_value=-10, max_value=10): """Make a set of tests to do reduce operation. Args: reduce_op: TensorFlow reduce operation to test, i.e. `tf.reduce_mean`. + min_value: min value for created tensor data. + max_value: max value for created tensor data. Returns: a function representing the true generator with `reduce_op_in` curried. @@ -772,7 +840,7 @@ def make_reduce_tests(reduce_op): "input_dtype": [tf.float32, tf.int32, tf.int64], "input_shape": [[3, 2, 4]], "axis": [ - None, 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0], + 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0], [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], [-1, 0], [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3] ], @@ -782,13 +850,19 @@ def make_reduce_tests(reduce_op): "input_dtype": [tf.float32], "input_shape": [[1, 8, 8, 3]], "axis": [ - None, 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3], + 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3], [3, 2, 1, 0], [3, 1, 0, 2], [2, 0], [3, 0], [3, 1], [1, 0], -1, -2, -3, -4, [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2], [2, 2, 3], [-3, -3, -4], [-3, 2, 1] ], "const_axis": [True, False], "keepdims": [True, False], + }, { + "input_dtype": [tf.float32], + "input_shape": [[], [1, 8, 8, 3], [3, 2, 4]], + "axis": [None], + "const_axis": [True], + "keepdims": [True, False], }] def build_graph(parameters): @@ -806,7 +880,7 @@ def make_reduce_tests(reduce_op): if isinstance(parameters["axis"], list): shape = [len(parameters["axis"])] else: - shape = [0] # shape for None or integers. + shape = [] # shape for None or integers. axis = tf.placeholder(dtype=tf.int32, name="axis", shape=shape) input_tensors = [input_tensor, axis] @@ -816,11 +890,14 @@ def make_reduce_tests(reduce_op): def build_inputs(parameters, sess, inputs, outputs): values = [ - create_tensor_data(parameters["input_dtype"], - parameters["input_shape"])] + create_tensor_data( + parameters["input_dtype"], + parameters["input_shape"], + min_value=min_value, + max_value=max_value) + ] if not parameters["const_axis"]: - if parameters["axis"]: - values.append(np.array(parameters["axis"])) + values.append(np.array(parameters["axis"])) return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -830,16 +907,31 @@ def make_reduce_tests(reduce_op): def make_mean_tests(zip_path): """Make a set of tests to do mean.""" - return make_reduce_tests(tf.reduce_mean)(zip_path) +def make_sum_tests(zip_path): + """Make a set of tests to do sum.""" + return make_reduce_tests(tf.reduce_sum)(zip_path) + + +def make_reduce_prod_tests(zip_path): + """Make a set of tests to do prod.""" + # set min max value to be -2, 2 to avoid overflow. + return make_reduce_tests(tf.reduce_prod, -2, 2)(zip_path) + + +def make_reduce_max_tests(zip_path): + """Make a set of tests to do max.""" + return make_reduce_tests(tf.reduce_max)(zip_path) + + def make_exp_tests(zip_path): """Make a set of tests to do exp.""" test_parameters = [{ "input_dtype": [tf.float32], - "input_shape": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + "input_shape": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], }] def build_graph(parameters): @@ -898,8 +990,8 @@ def make_maximum_tests(zip_path): test_parameters = [{ "input_dtype": [tf.float32], - "input_shape_1": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], - "input_shape_2": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + "input_shape_1": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + "input_shape_2": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], }] def build_graph(parameters): @@ -933,8 +1025,8 @@ def make_minimum_tests(zip_path): test_parameters = [{ "input_dtype": [tf.float32], - "input_shape_1": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], - "input_shape_2": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + "input_shape_1": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + "input_shape_2": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], }] def build_graph(parameters): @@ -984,6 +1076,10 @@ def make_mul_tests(zip_path): make_binary_op_tests(zip_path, tf.multiply) +def make_pow_tests(zip_path): + make_binary_op_tests(zip_path, tf.pow) + + def make_gather_tests(zip_path): """Make a set of tests to do gather.""" @@ -1159,6 +1255,75 @@ def make_conv_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +# Note: This is a regression test for a bug (b/112436267) that Toco incorrectly +# fuses weights when multiple Conv2D/FULLY_CONNECTED ops share the same constant +# weight tensor. +def make_conv_with_shared_weights_tests(zip_path): + """Make a test where 2 Conv ops shared the same constant weight tensor.""" + + test_parameters = [{ + "input_shape": [[1, 10, 10, 3]], + "filter_shape": [[3, 3]], + "strides": [[1, 1, 1, 1]], + "dilations": [[1, 1, 1, 1]], + "padding": ["SAME"], + "data_format": ["NHWC"], + "channel_multiplier": [1], + }] + + def get_tensor_shapes(parameters): + input_shape = parameters["input_shape"] + filter_size = parameters["filter_shape"] + filter_shape = filter_size + [ + input_shape[3], parameters["channel_multiplier"] + ] + return [input_shape, filter_shape] + + def build_graph(parameters): + """Build a conv graph given `parameters`.""" + input_shape, filter_shape = get_tensor_shapes(parameters) + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=input_shape) + + # Construct a constant weights tensor which will be used by both Conv2D. + filter_tensor = tf.constant( + create_tensor_data(np.float32, filter_shape), dtype=tf.float32) + input_tensors = [input_tensor] + + # Construct 2 Conv2D operations which use exactly the same input and + # weights. + result1 = tf.nn.conv2d( + input_tensor, + filter_tensor, + strides=parameters["strides"], + dilations=parameters["dilations"], + padding=parameters["padding"], + data_format=parameters["data_format"]) + result2 = tf.nn.conv2d( + input_tensor, + filter_tensor, + strides=parameters["strides"], + dilations=parameters["dilations"], + padding=parameters["padding"], + data_format=parameters["data_format"]) + # Add MUL ops after Conv2D ops. These MUL ops should be fused into the + # weights of Conv2D. + result1 = result1 * 2 + result2 = result2 * 3 + # Add the 2 results up. + out = result1 + result2 + return input_tensors, [out] + + def build_inputs(parameters, sess, inputs, outputs): + # Build list of input values either containing 1 tensor (input) or 2 tensors + # (input, filter) based on whether filter is constant or variable input. + input_shape, unused_filter_shape = get_tensor_shapes(parameters) + values = [create_tensor_data(np.float32, input_shape)] + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_depthwiseconv_tests(zip_path): """Make a set of tests to do convolution.""" @@ -1261,6 +1426,7 @@ def make_concat_tests(zip_path): "base_shape": [[1, 3, 4, 3], [3, 4]], "num_tensors": [1, 2, 3, 4, 5, 6], "axis": [0, 1, 2, 3, -3, -2, -1], + "type": [tf.float32, tf.uint8, tf.int32, tf.int64], }] def get_shape(parameters, delta): @@ -1276,7 +1442,8 @@ def make_concat_tests(zip_path): def build_graph(parameters): all_tensors = [] for n in range(0, parameters["num_tensors"]): - input_tensor = tf.placeholder(dtype=tf.float32, name=("input%d" % n), + input_tensor = tf.placeholder(dtype=parameters["type"], + name=("input%d" % n), shape=get_shape(parameters, n)) all_tensors.append(input_tensor) out = tf.concat(all_tensors, parameters["axis"]) @@ -1285,8 +1452,8 @@ def make_concat_tests(zip_path): def build_inputs(parameters, sess, inputs, outputs): all_values = [] for n in range(0, parameters["num_tensors"]): - input_values = create_tensor_data(np.float32, - get_shape(parameters, n)) + input_values = create_tensor_data( + parameters["type"], get_shape(parameters, n)) all_values.append(input_values) return all_values, sess.run( outputs, feed_dict=dict(zip(inputs, all_values))) @@ -1315,6 +1482,12 @@ def make_fully_connected_tests(zip_path): "transpose_a": [False], "transpose_b": [False], "constant_filter": [True, False], + }, { + "shape1": [[40, 37]], + "shape2": [[40, 37]], + "transpose_a": [False], + "transpose_b": [True], + "constant_filter": [True, False], }] def build_graph(parameters): @@ -1522,19 +1695,124 @@ def make_reshape_tests(zip_path): "dtype": [tf.float32, tf.int32], "input_shape": [[3, 4, 5, 7], [4, 105], [21, 5, 2, 2], [420]], "output_shape": [[15, 28], [420], [1, -1, 5, 7], [-1]], + "constant_shape": [True, False], + }, { + "dtype": [tf.float32], + "input_shape": [[1]], + "output_shape": [[]], + "constant_shape": [True, False], }] def build_graph(parameters): input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", shape=parameters["input_shape"]) - out = tf.reshape(input_tensor, shape=parameters["output_shape"]) - return [input_tensor], [out] + + # Get shape as either a placeholder or constants. + if parameters["constant_shape"]: + output_shape = parameters["output_shape"] + input_tensors = [input_tensor] + else: + # The shape of the shape tensor. + shape_tensor_shape = [len(parameters["output_shape"])] + output_shape = tf.placeholder( + dtype=tf.int32, name="output_shape", shape=shape_tensor_shape) + input_tensors = [input_tensor, output_shape] + out = tf.reshape(input_tensor, shape=output_shape) + return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(parameters["dtype"], - parameters["input_shape"]) - return [input_values], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_values]))) + values = [ + create_tensor_data(parameters["dtype"], parameters["input_shape"]) + ] + if not parameters["constant_shape"]: + values.append(np.array(parameters["output_shape"])) + + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_shape_tests(zip_path): + """Make a set of tests to do shape.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32], + "input_shape": [[], [0], [1, 1, 1, 3], [2, 3, 4, 5], [5, 5], [10]], + "out_type": [tf.int32, tf.int64], + }] + + def build_graph(parameters): + """Build the shape op testing graph.""" + # Note that we intentionally leave out the shape from the input placeholder + # to prevent the Shape operation from being optimized out during conversion. + input_value = tf.placeholder(dtype=parameters["input_dtype"], name="input") + out = tf.shape(input_value, out_type=parameters["out_type"]) + return [input_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_one_hot_tests(zip_path): + """Make a set of tests to do one_hot.""" + + test_parameters = [{ + "indices_type": [tf.int32, tf.int64], + "indices_shape": [[3], [4, 4], [1, 5], [5, 1]], + "axis": [0, 1], + "dtype": [tf.int32, tf.int64, tf.float32], + "provide_optional_inputs": [True, False], + }] + + def build_graph(parameters): + indices = tf.placeholder( + dtype=parameters["indices_type"], + name="indices", + shape=parameters["indices_shape"]) + depth = tf.placeholder(dtype=tf.int32, name="depth", shape=()) + + if not parameters["provide_optional_inputs"]: + out = tf.one_hot(indices=indices, depth=depth) + return [indices, depth], [out] + + on_value = tf.placeholder( + dtype=parameters["dtype"], name="on_value", shape=()) + off_value = tf.placeholder( + dtype=parameters["dtype"], name="off_value", shape=()) + out = tf.one_hot( + indices=indices, + depth=depth, + on_value=on_value, + off_value=off_value, + axis=parameters["axis"], + dtype=parameters["dtype"]) + return [indices, depth, on_value, off_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = [ + create_tensor_data( + parameters["indices_type"], + shape=parameters["indices_shape"], + min_value=-1, + max_value=10), + create_tensor_data(tf.int32, shape=None, min_value=1, max_value=10), + ] + + if parameters["provide_optional_inputs"]: + input_values.append( + create_tensor_data( + parameters["dtype"], shape=None, min_value=1, max_value=10)) + input_values.append( + create_tensor_data( + parameters["dtype"], shape=None, min_value=-1, max_value=0)) + + return input_values, sess.run( + outputs, feed_dict=dict(zip(inputs, input_values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -2112,6 +2390,7 @@ def make_topk_tests(zip_path): test_parameters = [{ "input_dtype": [tf.float32, tf.int32], "input_shape": [[10], [5, 20]], + "input_k": [None, 1, 3], }] def build_graph(parameters): @@ -2120,27 +2399,36 @@ def make_topk_tests(zip_path): dtype=parameters["input_dtype"], name="input", shape=parameters["input_shape"]) - k = tf.constant(3, name="k") + if parameters["input_k"] is not None: + k = tf.placeholder(dtype=tf.int32, name="input_k", shape=[]) + else: + k = tf.constant(3, name="k") out = tf.nn.top_k(input_value, k) - return [input_value], [out[1]] + return [input_value, k], [out[1]] def build_inputs(parameters, sess, inputs, outputs): input_value = create_tensor_data(parameters["input_dtype"], parameters["input_shape"]) - return [input_value], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_value]))) + if parameters["input_k"] is not None: + k = np.array(parameters["input_k"], dtype=np.int32) + return [input_value, k], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value, k]))) + else: + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) -def make_arg_max_tests(zip_path): +def make_arg_min_max_tests(zip_path): """Make a set of tests to do arg_max.""" test_parameters = [{ "input_dtype": [tf.float32, tf.int32], - "input_shape": [[1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5], [10]], + "input_shape": [[], [1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5], [10]], "output_type": [tf.int32, tf.int64], "axis_is_last_dim": [True, False], + "is_arg_max": [True], }] def build_graph(parameters): @@ -2153,7 +2441,10 @@ def make_arg_max_tests(zip_path): axis = len(parameters["input_shape"]) - 1 else: axis = random.randint(0, max(len(parameters["input_shape"]) - 2, 0)) - out = tf.arg_max(input_value, axis, output_type=parameters["output_type"]) + if parameters["is_arg_max"]: + out = tf.arg_max(input_value, axis, output_type=parameters["output_type"]) + else: + out = tf.arg_min(input_value, axis, output_type=parameters["output_type"]) return [input_value], [out] def build_inputs(parameters, sess, inputs, outputs): @@ -2170,7 +2461,8 @@ def make_equal_tests(zip_path): test_parameters = [{ "input_dtype": [tf.float32, tf.int32, tf.int64], - "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + "input_shape_pair": [([], []), + ([1, 1, 1, 3], [1, 1, 1, 3]), ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), ([5, 5], [1]), ([10], [2, 4, 10])], }] @@ -2427,11 +2719,11 @@ def _make_elementwise_tests(op): """Actual function that generates examples.""" test_parameters = [{ "input_dtype": [tf.float32], - "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], + "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], }] def build_graph(parameters): - """Build the sin op testing graph.""" + """Build the unary op testing graph.""" input_value = tf.placeholder( dtype=parameters["input_dtype"], name="input1", @@ -2460,6 +2752,16 @@ def make_log_tests(zip_path): return _make_elementwise_tests(tf.log)(zip_path) +def make_sqrt_tests(zip_path): + """Make a set of tests to do sqrt.""" + return _make_elementwise_tests(tf.sqrt)(zip_path) + + +def make_rsqrt_tests(zip_path): + """Make a set of tests to do 1/sqrt.""" + return _make_elementwise_tests(tf.rsqrt)(zip_path) + + def make_where_tests(zip_path): """Make a set of tests to do where.""" @@ -2739,6 +3041,95 @@ def make_sparse_to_dense_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_pack_tests(zip_path): + """Make a set of tests to do stack.""" + + test_parameters = [{ + "base_shape": [[3, 4, 3], [3, 4], [5]], + "num_tensors": [1, 2, 3, 4, 5, 6], + "axis": [0, 1, 2, 3], + "additional_shape": [1, 2, 3], + }] + + def get_shape(parameters): + """Return a tweaked version of 'base_shape'.""" + axis = parameters["axis"] + shape = parameters["base_shape"][:] + if axis < len(shape): + shape[axis] += parameters["additional_shape"] + return shape + + def build_graph(parameters): + all_tensors = [] + for n in range(0, parameters["num_tensors"]): + input_tensor = tf.placeholder( + dtype=tf.float32, name=("input%d" % n), shape=get_shape(parameters)) + all_tensors.append(input_tensor) + out = tf.stack(all_tensors, parameters["axis"]) + return all_tensors, [out] + + def build_inputs(parameters, sess, inputs, outputs): + all_values = [] + for _ in range(0, parameters["num_tensors"]): + input_values = create_tensor_data(np.float32, get_shape(parameters)) + all_values.append(input_values) + return all_values, sess.run( + outputs, feed_dict=dict(zip(inputs, all_values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def _make_logical_tests(op): + """Make a set of tests to do logical operations.""" + + def logical(zip_path): + """Generate examples.""" + test_parameters = [{ + "input_shape_pair": [([], []), ([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the logical testing graph.""" + input_value1 = tf.placeholder( + dtype=tf.bool, name="input1", shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=tf.bool, name="input2", shape=parameters["input_shape_pair"][1]) + out = op(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(tf.bool, + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(tf.bool, + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + return logical + + +def make_logical_or_tests(zip_path): + """Make a set of tests to do logical_or.""" + return _make_logical_tests(tf.logical_or)(zip_path) + + +def make_logical_and_tests(zip_path): + """Make a set of tests to do logical_and.""" + return _make_logical_tests(tf.logical_and)(zip_path) + + +def make_logical_xor_tests(zip_path): + """Make a set of tests to do logical_xor. + + Test logical_not as well. + """ + return _make_logical_tests(tf.logical_xor)(zip_path) + + # Toco binary path provided by the generate rule. bin_path = None diff --git a/tensorflow/contrib/lite/testing/generate_testspec.cc b/tensorflow/contrib/lite/testing/generate_testspec.cc index c0c861ff6da2fc144b9303dfdd48f19794cebeca..f29c188e6c2c55bdb13d257c70e23c2943abfa4a 100644 --- a/tensorflow/contrib/lite/testing/generate_testspec.cc +++ b/tensorflow/contrib/lite/testing/generate_testspec.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/contrib/lite/testing/generate_testspec.h" #include "tensorflow/contrib/lite/testing/join.h" #include "tensorflow/contrib/lite/testing/split.h" @@ -25,7 +27,7 @@ namespace testing { template void GenerateCsv(const std::vector& shape, float min, float max, string* out) { - auto random_float = [](int min, int max) { + auto random_float = [](float min, float max) { static unsigned int seed; return min + (max - min) * static_cast(rand_r(&seed)) / RAND_MAX; }; @@ -37,16 +39,10 @@ void GenerateCsv(const std::vector& shape, float min, float max, *out = Join(data.data(), data.size(), ","); } -bool GenerateTestSpecFromTensorflowModel( - std::iostream& stream, const string& tensorflow_model_path, - const string& tflite_model_path, const std::vector& input_layer, +std::vector GenerateInputValues( + const std::vector& input_layer, const std::vector& input_layer_type, - const std::vector& input_layer_shape, - const std::vector& output_layer) { - CHECK_EQ(input_layer.size(), input_layer_type.size()); - CHECK_EQ(input_layer.size(), input_layer_shape.size()); - - // Generate inputs. + const std::vector& input_layer_shape) { std::vector input_values; input_values.resize(input_layer.size()); for (int i = 0; i < input_layer.size(); i++) { @@ -73,57 +69,82 @@ bool GenerateTestSpecFromTensorflowModel( default: fprintf(stderr, "Unsupported type %d (%s) when generating testspec.\n", type, input_layer_type[i].c_str()); - return false; + input_values.clear(); + return input_values; } } + return input_values; +} + +bool GenerateTestSpecFromTensorflowModel( + std::iostream& stream, const string& tensorflow_model_path, + const string& tflite_model_path, int num_invocations, + const std::vector& input_layer, + const std::vector& input_layer_type, + const std::vector& input_layer_shape, + const std::vector& output_layer) { + CHECK_EQ(input_layer.size(), input_layer_type.size()); + CHECK_EQ(input_layer.size(), input_layer_shape.size()); // Invoke tensorflow model. TfDriver runner(input_layer, input_layer_type, input_layer_shape, output_layer); if (!runner.IsValid()) { - cerr << runner.GetErrorMessage() << endl; + std::cerr << runner.GetErrorMessage() << std::endl; return false; } runner.LoadModel(tensorflow_model_path); if (!runner.IsValid()) { - cerr << runner.GetErrorMessage() << endl; + std::cerr << runner.GetErrorMessage() << std::endl; return false; } - for (int i = 0; i < input_values.size(); i++) { - runner.SetInput(i, input_values[i]); - if (!runner.IsValid()) { - cerr << runner.GetErrorMessage() << endl; - return false; - } - } - - runner.Invoke(); - if (!runner.IsValid()) { - cerr << runner.GetErrorMessage() << endl; - return false; - } - - // Write test spec. + // Write first part of test spec, defining model and input shapes. stream << "load_model: " << tflite_model_path << "\n"; stream << "reshape {\n"; for (const auto& shape : input_layer_shape) { stream << " input: \"" << shape << "\"\n"; } stream << "}\n"; - stream << "invoke {\n"; - for (const auto& value : input_values) { - stream << " input: \"" << value << "\"\n"; - } - for (int i = 0; i < output_layer.size(); i++) { - stream << " output: \"" << runner.ReadOutput(i) << "\"\n"; + + // Generate inputs. + for (int i = 0; i < num_invocations; ++i) { + // Note that the input values are random, so each invocation will have a + // different set. + std::vector input_values = + GenerateInputValues(input_layer, input_layer_type, input_layer_shape); + if (input_values.empty()) return false; + + // Run TensorFlow. + for (int j = 0; j < input_values.size(); j++) { + runner.SetInput(j, input_values[j]); + if (!runner.IsValid()) { + std::cerr << runner.GetErrorMessage() << std::endl; + return false; + } + } + + runner.Invoke(); if (!runner.IsValid()) { - cerr << runner.GetErrorMessage() << endl; + std::cerr << runner.GetErrorMessage() << std::endl; return false; } + + // Write second part of test spec, with inputs and outputs. + stream << "invoke {\n"; + for (const auto& value : input_values) { + stream << " input: \"" << value << "\"\n"; + } + for (int j = 0; j < output_layer.size(); j++) { + stream << " output: \"" << runner.ReadOutput(j) << "\"\n"; + if (!runner.IsValid()) { + std::cerr << runner.GetErrorMessage() << std::endl; + return false; + } + } + stream << "}\n"; } - stream << "}\n"; return true; } diff --git a/tensorflow/contrib/lite/testing/generate_testspec.h b/tensorflow/contrib/lite/testing/generate_testspec.h index 6e31a853c3f7f82a89126ff83af784ffd418741a..b3d0db31c01a8cb1b8f34ff6dbb00c77de29b131 100644 --- a/tensorflow/contrib/lite/testing/generate_testspec.h +++ b/tensorflow/contrib/lite/testing/generate_testspec.h @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "tensorflow/contrib/lite/string.h" + namespace tflite { namespace testing { @@ -30,13 +32,15 @@ namespace testing { // stream: mutable iostream that contains the contents of test spec. // tensorflow_model_path: path to TensorFlow model. // tflite_model_path: path to tflite_model_path that the test spec runs +// num_invocations: how many pairs of inputs and outputs will be generated. // against. input_layer: names of input tensors. Example: input1 // input_layer_type: datatypes of input tensors. Example: float // input_layer_shape: shapes of input tensors, separated by comma. example: // 1,3,4 output_layer: names of output tensors. Example: output bool GenerateTestSpecFromTensorflowModel( std::iostream& stream, const string& tensorflow_model_path, - const string& tflite_model_path, const std::vector& input_layer, + const string& tflite_model_path, int num_invocations, + const std::vector& input_layer, const std::vector& input_layer_type, const std::vector& input_layer_shape, const std::vector& output_layer); diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 8a59d756f8dbbcefc930b5285c1ced8ce6b08845..e67fee2a1ca40790a171dc236dd2d85203690a62 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -33,15 +33,21 @@ namespace testing { namespace { bool FLAGS_ignore_known_bugs = true; -// TODO(b/71769302) zip_files_dir should have a more accurate default, if -// possible -string* FLAGS_zip_file_path = new string("./"); +// As archive file names are test-specific, no default is possible. +// +// This test supports input as both zip and tar, as a stock android image does +// not have unzip but does have tar. +string* FLAGS_zip_file_path = new string; +string* FLAGS_tar_file_path = new string; #ifndef __ANDROID__ string* FLAGS_unzip_binary_path = new string("/usr/bin/unzip"); +string* FLAGS_tar_binary_path = new string("/bin/tar"); #else string* FLAGS_unzip_binary_path = new string("/system/bin/unzip"); +string* FLAGS_tar_binary_path = new string("/system/bin/tar"); #endif bool FLAGS_use_nnapi = false; +bool FLAGS_ignore_unsupported_nnapi = false; } // namespace // TensorFlow system environment for file system called. @@ -52,13 +58,6 @@ tensorflow::Env* env = tensorflow::Env::Default(); // Key is a substring of the test name and value is a bug number. // TODO(ahentz): make sure we clean this list up frequently. std::map kBrokenTests = { - // Add only supports float32. (and "constant" tests use Add) - {R"(^\/add_a.*int32)", "68808744"}, - {R"(^\/constant.*int32)", "68808744"}, - {R"(^\/mul.*int32)", "68808744"}, - {R"(^\/div.*int32)", "68808744"}, - {R"(^\/sub.*int32)", "68808744"}, - // Pad and PadV2 only supports 4D tensors. {R"(^\/pad.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])", "70527055"}, @@ -92,25 +91,23 @@ std::map kBrokenTests = { // Transpose only supports 1D-4D input tensors. {R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"}, - // PRelu only supports 4D input with (1, 1, channels) 3D alpha now. - {R"(^\/prelu.*shared_axes=\[1\])", "75975192"}, - // No support for axis!=0 in GatherV2. {R"(^\/gather.*axis=1)", "76910444"}, // No support for arbitrary dimensions in ArgMax. - {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.,.,.\])", + {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.,.,.\])", "77546240"}, - {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.,.\])", + {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.,.\])", + "77546240"}, + {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.\])", "77546240"}, - {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.\])", "77546240"}, }; -// Allows test data to be unzipped into a temporary directory and makes +// Allows test data to be unarchived into a temporary directory and makes // sure those temporary directories are removed later. -class ZipEnvironment : public ::testing::Environment { +class ArchiveEnvironment : public ::testing::Environment { public: - ~ZipEnvironment() override {} + ~ArchiveEnvironment() override {} // Delete all temporary directories on teardown. void TearDown() override { @@ -122,15 +119,26 @@ class ZipEnvironment : public ::testing::Environment { temporary_directories_.clear(); } - // Unzip `zip` file into a new temporary directory `out_dir`. - tensorflow::Status UnZip(const string& zip, string* out_dir) { + // Unarchive `archive` file into a new temporary directory `out_dir`. + tensorflow::Status UnArchive(const string& zip, const string& tar, + string* out_dir) { string dir; TF_CHECK_OK(MakeTemporaryDirectory(&dir)); tensorflow::SubProcess proc; - string unzip_binary = *FLAGS_unzip_binary_path; - TF_CHECK_OK(env->FileExists(unzip_binary)); - TF_CHECK_OK(env->FileExists(zip)); - proc.SetProgram(unzip_binary, {"unzip", "-d", dir, zip}); + if (!zip.empty()) { + string unzip_binary = *FLAGS_unzip_binary_path; + TF_CHECK_OK(env->FileExists(unzip_binary)); + TF_CHECK_OK(env->FileExists(zip)); + proc.SetProgram(unzip_binary, {"unzip", "-d", dir, zip}); + } else { + string tar_binary = *FLAGS_tar_binary_path; + TF_CHECK_OK(env->FileExists(tar_binary)); + TF_CHECK_OK(env->FileExists(tar)); + // 'o' needs to be explicitly set on Android so that + // untarring works as non-root (otherwise tries to chown + // files, which fails) + proc.SetProgram(tar_binary, {"tar", "xfo", tar, "-C", dir}); + } proc.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE); proc.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE); if (!proc.Start()) @@ -164,15 +172,15 @@ class ZipEnvironment : public ::testing::Environment { std::vector temporary_directories_; }; -// Return the singleton zip_environment. -ZipEnvironment* zip_environment() { - static ZipEnvironment* env = new ZipEnvironment; +// Return the singleton archive_environment. +ArchiveEnvironment* archive_environment() { + static ArchiveEnvironment* env = new ArchiveEnvironment; return env; } -// Read the manifest.txt out of the unarchived zip file. Specifically +// Read the manifest.txt out of the unarchived archive file. Specifically // `original_file` is the original zip file for error messages. `dir` is -// the temporary directory where the zip file has been unarchived and +// the temporary directory where the archive file has been unarchived and // `test_paths` is the list of test prefixes that were in the manifest. // Note, it is an error for a manifest to contain no tests. tensorflow::Status ReadManifest(const string& original_file, const string& dir, @@ -198,12 +206,22 @@ tensorflow::Status ReadManifest(const string& original_file, const string& dir, return tensorflow::Status::OK(); } -// Get a list of tests from a zip file `zip_file_name`. -std::vector UnarchiveZipAndFindTestNames(const string& zip_file) { +// Get a list of tests from either zip or tar file +std::vector UnarchiveAndFindTestNames(const string& zip_file, + const string& tar_file) { + if (zip_file.empty() && tar_file.empty()) { + TF_CHECK_OK(tensorflow::Status(tensorflow::error::UNKNOWN, + "Neither zip_file nor tar_file was given")); + } string decompress_tmp_dir; - TF_CHECK_OK(zip_environment()->UnZip(zip_file, &decompress_tmp_dir)); + TF_CHECK_OK(archive_environment()->UnArchive(zip_file, tar_file, + &decompress_tmp_dir)); std::vector stuff; - TF_CHECK_OK(ReadManifest(zip_file, decompress_tmp_dir, &stuff)); + if (!zip_file.empty()) { + TF_CHECK_OK(ReadManifest(zip_file, decompress_tmp_dir, &stuff)); + } else { + TF_CHECK_OK(ReadManifest(tar_file, decompress_tmp_dir, &stuff)); + } return stuff; } @@ -228,16 +246,21 @@ TEST_P(OpsTest, RunZipTests) { } bool result = tflite::testing::ParseAndRunTests(&tflite_stream, &test_driver); + string message = test_driver.GetErrorMessage(); if (bug_number.empty()) { - EXPECT_TRUE(result) << test_driver.GetErrorMessage(); + if (FLAGS_use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) { + EXPECT_EQ(message, string("Failed to invoke interpreter")) << message; + } else { + EXPECT_TRUE(result) << message; + } } else { if (FLAGS_ignore_known_bugs) { EXPECT_FALSE(result) << "Test was expected to fail but is now passing; " "you can mark http://b/" << bug_number << " as fixed! Yay!"; } else { - EXPECT_TRUE(result) << test_driver.GetErrorMessage() - << ": Possibly due to http://b/" << bug_number; + EXPECT_TRUE(result) << message << ": Possibly due to http://b/" + << bug_number; } } } @@ -258,30 +281,40 @@ struct ZipPathParamName { } }; -INSTANTIATE_TEST_CASE_P( - tests, OpsTest, - ::testing::ValuesIn(UnarchiveZipAndFindTestNames(*FLAGS_zip_file_path)), - ZipPathParamName()); +INSTANTIATE_TEST_CASE_P(tests, OpsTest, + ::testing::ValuesIn(UnarchiveAndFindTestNames( + *FLAGS_zip_file_path, *FLAGS_tar_file_path)), + ZipPathParamName()); } // namespace testing } // namespace tflite int main(int argc, char** argv) { - ::testing::AddGlobalTestEnvironment(tflite::testing::zip_environment()); + ::testing::AddGlobalTestEnvironment(tflite::testing::archive_environment()); std::vector flags = { tensorflow::Flag( "ignore_known_bugs", &tflite::testing::FLAGS_ignore_known_bugs, "If a particular model is affected by a known bug, the " "corresponding test should expect the outputs to not match."), - tensorflow::Flag("zip_file_path", tflite::testing::FLAGS_zip_file_path, - "Required: Location of the test zip file."), + tensorflow::Flag( + "tar_file_path", tflite::testing::FLAGS_tar_file_path, + "Required (or zip_file_path): Location of the test tar file."), + tensorflow::Flag( + "zip_file_path", tflite::testing::FLAGS_zip_file_path, + "Required (or tar_file_path): Location of the test zip file."), tensorflow::Flag("unzip_binary_path", tflite::testing::FLAGS_unzip_binary_path, - "Required: Location of a suitable unzip binary."), + "Location of a suitable unzip binary."), + tensorflow::Flag("tar_binary_path", + tflite::testing::FLAGS_tar_binary_path, + "Location of a suitable tar binary."), tensorflow::Flag("use_nnapi", &tflite::testing::FLAGS_use_nnapi, - "Whether to enable the NNAPI delegate")}; - + "Whether to enable the NNAPI delegate"), + tensorflow::Flag("ignore_unsupported_nnapi", + &tflite::testing::FLAGS_ignore_unsupported_nnapi, + "Don't fail tests just because delegation to NNAPI " + "is not possible")}; bool success = tensorflow::Flags::Parse(&argc, argv, flags); if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) { fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str()); diff --git a/tensorflow/contrib/lite/testing/join.h b/tensorflow/contrib/lite/testing/join.h index 1edee01cf97da3c53be1895e667b005551ac2991..4be19ad7569c3333b6647b91adbc6e77ff088f10 100644 --- a/tensorflow/contrib/lite/testing/join.h +++ b/tensorflow/contrib/lite/testing/join.h @@ -17,7 +17,8 @@ limitations under the License. #include #include -#include + +#include "tensorflow/contrib/lite/string.h" namespace tflite { namespace testing { diff --git a/tensorflow/contrib/lite/testing/test_runner.h b/tensorflow/contrib/lite/testing/test_runner.h index 96ab6be54e528334f9e4a8cc259e44f99878fefb..fac7d01aab4b1e4c251213041eb4b823cd7d66aa 100644 --- a/tensorflow/contrib/lite/testing/test_runner.h +++ b/tensorflow/contrib/lite/testing/test_runner.h @@ -90,7 +90,7 @@ class TestRunner { // Invalidate the test runner, preventing it from executing any further. void Invalidate(const string& error_message) { - cerr << error_message << std::endl; + std::cerr << error_message << std::endl; error_message_ = error_message; } bool IsValid() const { return error_message_.empty(); } diff --git a/tensorflow/contrib/lite/testing/tf_driver.cc b/tensorflow/contrib/lite/testing/tf_driver.cc index 3b27f6f3da92ce80c3830feb7c6af095e7c48e9c..ec435ca60d959a11a9392b6fbab99b0561f50942 100644 --- a/tensorflow/contrib/lite/testing/tf_driver.cc +++ b/tensorflow/contrib/lite/testing/tf_driver.cc @@ -28,8 +28,8 @@ namespace { tensorflow::Tensor CreateTensor(const tensorflow::DataType type, const std::vector& dim) { - tensorflow::TensorShape shape{gtl::ArraySlice{ - reinterpret_cast(dim.data()), dim.size()}}; + tensorflow::TensorShape shape{tensorflow::gtl::ArraySlice{ + reinterpret_cast(dim.data()), dim.size()}}; return {type, shape}; } @@ -179,7 +179,7 @@ void TfDriver::Invoke() { auto status = session_->Run({input_tensors_.begin(), input_tensors_.end()}, output_names_, {}, &output_tensors_); if (!status.ok()) { - Invalidate("Failed to invoke interpreter"); + Invalidate("Failed to run input data on graph"); } } diff --git a/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc b/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc index 5afa0f800cdaa8bf70a11cb6e2ac64ace8138e79..f2c49fe389763110279b3dd1e4f13b1522de0460 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc +++ b/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc @@ -20,12 +20,29 @@ int main(int argc, char** argv) { ::tflite::testing::DiffOptions options = ::tflite::testing::ParseTfliteDiffFlags(&argc, argv); if (options.tensorflow_model.empty()) return 1; + int failure_count = 0; - for (int i = 0; i < 100; i++) { - if (!tflite::testing::RunDiffTest(options)) { + for (int i = 0; i < options.num_runs_per_pass; i++) { + if (!tflite::testing::RunDiffTest(options, /*num_invocations=*/1)) { ++failure_count; } } - fprintf(stderr, "Num errors: %d\n", failure_count); + int failures_in_first_pass = failure_count; + + if (failure_count == 0) { + // Let's try again with num_invocations > 1 to make sure we can do multiple + // invocations without resetting the interpreter. + for (int i = 0; i < options.num_runs_per_pass; i++) { + if (!tflite::testing::RunDiffTest(options, /*num_invocations=*/2)) { + ++failure_count; + } + } + } + + fprintf(stderr, "Num errors in single-inference pass: %d\n", + failures_in_first_pass); + fprintf(stderr, "Num errors in multi-inference pass : %d\n", + failure_count - failures_in_first_pass); + return failure_count != 0 ? 1 : 0; } diff --git a/tensorflow/contrib/lite/testing/tflite_diff_flags.h b/tensorflow/contrib/lite/testing/tflite_diff_flags.h index 706108ed73bb3fd9bd784cffffe322d6981433e6..695c2a3de6c5d7c74a943134f0c97390710ef1e7 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_flags.h +++ b/tensorflow/contrib/lite/testing/tflite_diff_flags.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_FLAGS_H_ #define TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_FLAGS_H_ +#include + #include "tensorflow/contrib/lite/testing/split.h" #include "tensorflow/contrib/lite/testing/tflite_diff_util.h" #include "tensorflow/core/util/command_line_flags.h" @@ -30,6 +32,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { string input_layer_type; string input_layer_shape; string output_layer; + int32_t num_runs_per_pass = 100; } values; std::vector flags = { @@ -49,6 +52,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { tensorflow::Flag("output_layer", &values.output_layer, "Names of output tensors, separated by comma. Example " "output_1,output_2"), + tensorflow::Flag("num_runs_per_pass", &values.num_runs_per_pass, + "Number of full runs in each pass."), }; bool no_inputs = *argc == 1; @@ -63,7 +68,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { Split(values.input_layer, ","), Split(values.input_layer_type, ","), Split(values.input_layer_shape, ":"), - Split(values.output_layer, ",")}; + Split(values.output_layer, ","), + values.num_runs_per_pass}; } } // namespace testing diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.cc b/tensorflow/contrib/lite/testing/tflite_diff_util.cc index f601d3752ddb5df9f2b5ac73d9bc303efaade4a5..19f34c0a51e442804bf2824adc3a1d8bde1eb4b0 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_util.cc +++ b/tensorflow/contrib/lite/testing/tflite_diff_util.cc @@ -25,13 +25,14 @@ limitations under the License. namespace tflite { namespace testing { -bool RunDiffTest(const DiffOptions& options) { +bool RunDiffTest(const DiffOptions& options, int num_invocations) { std::stringstream tflite_stream; if (!GenerateTestSpecFromTensorflowModel( tflite_stream, options.tensorflow_model, options.tflite_model, - options.input_layer, options.input_layer_type, - options.input_layer_shape, options.output_layer)) + num_invocations, options.input_layer, options.input_layer_type, + options.input_layer_shape, options.output_layer)) { return false; + } TfLiteDriver tflite_driver(/*use_nnapi=*/true); tflite_driver.LoadModel(options.tflite_model); return tflite::testing::ParseAndRunTests(&tflite_stream, &tflite_driver); diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.h b/tensorflow/contrib/lite/testing/tflite_diff_util.h index 326fa6c3e28000dee9b6eb9cc5b3a6c5c87e28d0..4ab2f230fdcdfe4616ab1706aa41f0e806665f66 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_util.h +++ b/tensorflow/contrib/lite/testing/tflite_diff_util.h @@ -40,10 +40,14 @@ struct DiffOptions { // Names of output tensors. // Example output_1,output_2 std::vector output_layer; + // Number of full runs (from building interpreter to checking outputs) in + // each of the passes. The first pass has a single inference, while the + // second pass does multiple inferences back to back. + int num_runs_per_pass; }; // Run a single TensorFLow Lite diff test with a given options. -bool RunDiffTest(const DiffOptions& options); +bool RunDiffTest(const DiffOptions& options, int num_invocations); } // namespace testing } // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc index f518bf864c6a71679400f0013bbcd40142bb8ca1..4d08fb545801521213890a4f5a9b010de57b27cd 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -285,9 +285,11 @@ bool TfLiteDriver::CheckResults() { } void TfLiteDriver::ResetLSTMStateTensors() { - // This is a workaround for initializing state tensors for LSTM. - // TODO(ycling): Refactoring and find a better way to initialize state - // tensors. Maybe write the reset instructions into the test data. + interpreter_->ResetVariableTensorsToZero(); + + // Below is a workaround for initializing state tensors for LSTM. + // TODO(ycling): Remove the code below after nobody is using the 18-inputs + // definition. for (auto node_index : interpreter_->execution_plan()) { const auto& node_and_reg = interpreter_->node_and_registration(node_index); const auto& node = node_and_reg->first; @@ -297,19 +299,12 @@ void TfLiteDriver::ResetLSTMStateTensors() { const auto* params = reinterpret_cast(node.builtin_data); if (params->kernel_type == kTfLiteLSTMFullKernel && - node.outputs->size >= 2) { + node.inputs->size == 18 && node.outputs->size >= 2) { // The first 2 outputs of LSTM are state tensors. for (int i = 0; i < 2; ++i) { int node_index = node.outputs->data[i]; ResetTensor(node_index); } - } else if (params->kernel_type == kTfLiteLSTMBasicKernel && - node.inputs->size == 5) { - // The 2th and 5th inputs are state tensors. - for (int i : {1, 4}) { - int node_index = node.inputs->data[i]; - ResetTensor(node_index); - } } } } diff --git a/tensorflow/contrib/lite/testing/util.h b/tensorflow/contrib/lite/testing/util.h index 6d20aec141c7c3a3e48af290edb169c6fd7254cf..8aa639157b8b68061f9ee8c3483959a79cb5794e 100644 --- a/tensorflow/contrib/lite/testing/util.h +++ b/tensorflow/contrib/lite/testing/util.h @@ -15,8 +15,39 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ #define TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ +#include + +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/string.h" + namespace tflite { +// An ErrorReporter that collects error message in a string, in addition +// to printing to stderr. +class TestErrorReporter : public ErrorReporter { + public: + int Report(const char* format, va_list args) override { + char buffer[1024]; + int size = vsnprintf(buffer, sizeof(buffer), format, args); + fprintf(stderr, "%s", buffer); + error_messages_ += buffer; + num_calls_++; + return size; + } + + void Reset() { + num_calls_ = 0; + error_messages_.clear(); + } + + int num_calls() const { return num_calls_; } + const string& error_messages() const { return error_messages_; } + + private: + int num_calls_ = 0; + string error_messages_; +}; + inline void LogToStderr() { #ifdef PLATFORM_GOOGLE FLAGS_logtostderr = true; diff --git a/tensorflow/contrib/lite/tflite_static.bp b/tensorflow/contrib/lite/tflite_static.bp index 5a78ace359bb99b25f374262b6f689a4ad88cbd1..a0181778ffb78a8bad117ec610a388243aaaab0e 100644 --- a/tensorflow/contrib/lite/tflite_static.bp +++ b/tensorflow/contrib/lite/tflite_static.bp @@ -24,6 +24,7 @@ cc_library_static { "error_reporter.cc", "graph_info.cc", "interpreter.cc", + "mmap_allocation.cc", "model.cc", "nnapi_delegate.cc", "optional_debug_tools.cc", @@ -36,7 +37,7 @@ cc_library_static { "kernels/topk_v2.cc", "kernels/activations.cc", "kernels/add.cc", - "kernels/arg_max.cc", + "kernels/arg_min_max.cc", "kernels/basic_rnn.cc", "kernels/batch_to_space_nd.cc", "kernels/bidirectional_sequence_lstm.cc", @@ -53,6 +54,7 @@ cc_library_static { "kernels/embedding_lookup_sparse.cc", "kernels/exp.cc", "kernels/expand_dims.cc", + "kernels/fake_quant.cc", "kernels/floor.cc", "kernels/fully_connected.cc", "kernels/gather.cc", @@ -61,18 +63,23 @@ cc_library_static { "kernels/kernel_util.cc", "kernels/l2norm.cc", "kernels/local_response_norm.cc", + "kernels/logical.cc", "kernels/lsh_projection.cc", "kernels/lstm.cc", "kernels/maximum_minimum.cc", - "kernels/mean.cc", "kernels/mul.cc", "kernels/neg.cc", + "kernels/one_hot.cc", "kernels/pad.cc", + "kernels/pack.cc", "kernels/pooling.cc", + "kernels/pow.cc", + "kernels/reduce.cc", "kernels/register.cc", "kernels/reshape.cc", "kernels/resize_bilinear.cc", "kernels/select.cc", + "kernels/shape.cc", "kernels/skip_gram.cc", "kernels/slice.cc", "kernels/space_to_batch_nd.cc", @@ -119,6 +126,9 @@ cc_library_static { "-Wno-unused-local-typedef", "-Wno-unused-parameter", "-Wno-unused-variable", + "-Wno-invalid-partial-specialization", + "-Wno-mismatched-tags", + "-Wno-visibility", ], stl: "libc++_static", } diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 0789dc99286361183aea4c95db98c11ff700ea79..02d0890a7af606627c237314fe5ee108924d761d 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -11,6 +11,7 @@ load( "//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", + "tf_copts", ) tf_proto_library_cc( @@ -93,6 +94,7 @@ cc_library( ":runtime", ":toco_port", "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", ], ) @@ -143,7 +145,6 @@ cc_library( ":toco_graphviz_dump_options", ":toco_port", ":types_proto_cc", - "//tensorflow/cc/saved_model:tag_constants", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "@com_google_absl//absl/strings", @@ -169,41 +170,6 @@ cc_library( ], ) -cc_library( - name = "toco_saved_model", - srcs = [ - "toco_saved_model.cc", - ], - hdrs = [ - "toco_saved_model.h", - ], - visibility = ["//visibility:public"], - deps = [ - ":model_cmdline_flags", - ":model_flags_proto_cc", - ":toco_flags_proto_cc", - ":types_proto_cc", - "//tensorflow/cc/tools:freeze_saved_model", - "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "toco_saved_model_test", - srcs = ["toco_saved_model_test.cc"], - deps = [ - ":model_cmdline_flags", - ":toco_cmdline_flags", - ":toco_saved_model", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:scope", - "//tensorflow/core:test", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - ], -) - cc_library( name = "graph_transformations", srcs = [ @@ -212,7 +178,8 @@ cc_library( "graph_transformations/convert_reorder_axes.cc", "graph_transformations/convert_squeeze_to_reshape.cc", "graph_transformations/convert_trivial_addn_to_add.cc", - "graph_transformations/convert_trivial_stack_to_reshape.cc", + "graph_transformations/convert_trivial_pack_to_reshape.cc", + "graph_transformations/convert_trivial_tile_to_concat.cc", "graph_transformations/convert_trivial_transpose_to_reshape.cc", "graph_transformations/create_im2col_arrays.cc", "graph_transformations/dequantize.cc", @@ -220,10 +187,10 @@ cc_library( "graph_transformations/drop_im2col_arrays.cc", "graph_transformations/ensure_bias_vectors.cc", "graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc", - "graph_transformations/experimental_shuffle_fc_weights.cc", "graph_transformations/fuse_activation_functions.cc", "graph_transformations/fuse_binary_into_following_affine.cc", "graph_transformations/fuse_binary_into_preceding_affine.cc", + "graph_transformations/fuse_broadcast_into_following_binary.cc", "graph_transformations/graph_transformations.cc", "graph_transformations/hardcode_min_max.cc", "graph_transformations/identify_dilated_conv.cc", @@ -237,6 +204,7 @@ cc_library( "graph_transformations/lstm_utils.cc", "graph_transformations/make_initial_dequantize_operator.cc", "graph_transformations/merge_reshape_into_preceding_transpose.cc", + "graph_transformations/move_binary_operator_before_reshape.cc", "graph_transformations/propagate_activation_function_into_constants.cc", "graph_transformations/propagate_array_data_types.cc", "graph_transformations/propagate_default_min_max.cc", @@ -246,7 +214,7 @@ cc_library( "graph_transformations/quantization_util.h", "graph_transformations/quantize.cc", "graph_transformations/quantize_weights.cc", - "graph_transformations/read_fake_quant_min_max.cc", + "graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc", "graph_transformations/remove_final_dequantize_op.cc", "graph_transformations/remove_tensorflow_assert.cc", "graph_transformations/remove_tensorflow_identity.cc", @@ -270,19 +238,23 @@ cc_library( "graph_transformations/resolve_constant_fake_quant.cc", "graph_transformations/resolve_constant_fill.cc", "graph_transformations/resolve_constant_gather.cc", + "graph_transformations/resolve_constant_pack.cc", "graph_transformations/resolve_constant_random_uniform.cc", "graph_transformations/resolve_constant_range.cc", "graph_transformations/resolve_constant_reshape.cc", + "graph_transformations/resolve_constant_select.cc", "graph_transformations/resolve_constant_shape_or_rank.cc", "graph_transformations/resolve_constant_slice.cc", - "graph_transformations/resolve_constant_stack.cc", "graph_transformations/resolve_constant_strided_slice.cc", + "graph_transformations/resolve_constant_tile.cc", "graph_transformations/resolve_constant_transpose.cc", "graph_transformations/resolve_constant_unary.cc", - "graph_transformations/resolve_mean_attributes.cc", + "graph_transformations/resolve_fake_quant_args_from_vars.cc", + "graph_transformations/resolve_gather_attributes.cc", "graph_transformations/resolve_multiply_by_zero.cc", "graph_transformations/resolve_pad_attributes.cc", "graph_transformations/resolve_padv2_attributes.cc", + "graph_transformations/resolve_reduce_attributes.cc", "graph_transformations/resolve_reorder_axes.cc", "graph_transformations/resolve_reshape_attributes.cc", "graph_transformations/resolve_slice_attributes.cc", @@ -293,8 +265,8 @@ cc_library( "graph_transformations/resolve_tensorflow_matmul.cc", "graph_transformations/resolve_tensorflow_merge.cc", "graph_transformations/resolve_tensorflow_switch.cc", - "graph_transformations/resolve_tensorflow_tile.cc", "graph_transformations/resolve_transpose_attributes.cc", + "graph_transformations/shuffle_fc_weights.cc", "graph_transformations/unfuse_activation_functions.cc", "graph_transformations/unpartition_embedding_lookup.cc", "graph_transformations/unroll_batch_matmul.cc", @@ -336,7 +308,7 @@ cc_library( "tensorflow_util.h", "toco_tooling.h", ], - copts = select({ + copts = tf_copts() + select({ "//tensorflow:darwin": ["-DTOCO_SUPPORT_PORTABLE_PROTOS=0"], "//conditions:default": [], }), @@ -370,6 +342,7 @@ cc_library( tf_cc_test( name = "import_tensorflow_test", srcs = ["import_tensorflow_test.cc"], + tags = ["no_oss"], deps = [ ":toco_tooling", "//tensorflow/core:framework", @@ -390,6 +363,7 @@ cc_library( "dump_graphviz.h", "tooling_util.h", ], + copts = tf_copts(), visibility = ["//visibility:public"], deps = [ ":model", @@ -409,6 +383,7 @@ cc_library( tf_cc_test( name = "tooling_util_test", srcs = ["tooling_util_test.cc"], + tags = ["no_oss"], deps = [ ":model", ":tooling_util", @@ -430,7 +405,6 @@ tf_cc_binary( ":toco_cmdline_flags", ":toco_flags_proto_cc", ":toco_port", - ":toco_saved_model", ":toco_tooling", ":types_proto_cc", "//tensorflow/core:lib", @@ -444,6 +418,7 @@ tf_cc_test( data = [ "toco_port_test.cc", ], + tags = ["no_oss"], deps = [ ":toco_port", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/contrib/lite/toco/README.md b/tensorflow/contrib/lite/toco/README.md index 522e260ad2a14c5f8e080c0a0f538f4192b7ed2d..2db6a627ab59604a99cafe3b38df08b70092d989 100644 --- a/tensorflow/contrib/lite/toco/README.md +++ b/tensorflow/contrib/lite/toco/README.md @@ -17,11 +17,12 @@ Usage information is given in these documents: Once an application developer has a trained TensorFlow model, TOCO will accept that model and generate a TensorFlow Lite [FlatBuffer](https://google.github.io/flatbuffers/) file. TOCO currently supports -[SavedModels](https://www.tensorflow.org/programmers_guide/saved_model#using_savedmodel_with_estimators) -and frozen graphs (models generated via -[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)). -The TensorFlow Lite FlatBuffer file can be shipped to client devices, generally -mobile devices, where the TensorFlow Lite interpreter handles them on-device. -This flow is represented in the diagram below. +[SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators), +frozen graphs (models generated via +[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)), +and `tf.Keras` model files. The TensorFlow Lite FlatBuffer file can be shipped +to client devices, generally mobile devices, where the TensorFlow Lite +interpreter handles them on-device. This flow is represented in the diagram +below. ![drawing](g3doc/toco_landscape.svg) diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index 9f5ca66d050f0ead9b8856c77dba8d9bbd182d10..aef35ad490656c09a7d7314aa033bc985b3af661 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -21,13 +21,13 @@ limitations under the License. #include #include #include +#include "tensorflow/contrib/lite/toco/toco_port.h" #if defined(PLATFORM_GOOGLE) #include "strings/split.h" +#include "strings/strip.h" #endif #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" -#include "tensorflow/cc/saved_model/tag_constants.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/toco_types.h" namespace toco { @@ -145,8 +145,10 @@ class Arg final { } string outer_member_copy = outer_member; absl::StripAsciiWhitespace(&outer_member); - if (!TryStripPrefixString(outer_member, "{", &outer_member)) return false; - if (!TryStripSuffixString(outer_member, "}", &outer_member)) return false; + if (!strings::TryStripPrefixString(outer_member, "{", &outer_member)) + return false; + if (!strings::TryStripSuffixString(outer_member, "}", &outer_member)) + return false; const std::vector inner_fields_vector = absl::StrSplit(outer_member, ','); @@ -223,7 +225,7 @@ struct ParsedTocoFlags { Arg output_file; Arg input_format = Arg("TENSORFLOW_GRAPHDEF"); Arg output_format = Arg("TFLITE"); - Arg savedmodel_tagset = Arg(tensorflow::kSavedModelTagServe); + Arg savedmodel_tagset; // TODO(aselle): command_line_flags doesn't support doubles Arg default_ranges_min = Arg(0.); Arg default_ranges_max = Arg(0.); diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc index 878bda36ef3900d6d8c509aca40cee834cefe514..30525efd2391bb63afd7035b8134e5858add45f2 100644 --- a/tensorflow/contrib/lite/toco/dump_graphviz.cc +++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc @@ -167,7 +167,7 @@ NodeProperties GetPropertiesForArray(const Model& model, node_properties.label += "]"; int buffer_size = 0; - if (IsValid(array.shape())) { + if (IsNonEmpty(array.shape())) { buffer_size = RequiredBufferSizeForShape(array.shape()); node_properties.log2_buffer_size = std::log2(static_cast(buffer_size)); @@ -227,7 +227,7 @@ NodeProperties GetPropertiesForArray(const Model& model, NodeProperties GetPropertiesForOperator(const Operator& op) { NodeProperties node_properties; - if (op.type == OperatorType::kTensorFlowUnsupported) { + if (op.type == OperatorType::kUnsupported) { node_properties.label = static_cast(op).tensorflow_op; } else { diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index c7c80ab21cc88ecb94e5234750431130dbd8400c..02671f0408f55726df730dbe0fe9a4f936d22632 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -145,7 +145,7 @@ void ConvertFloatTensorConst(const string& name, const Shape& input_shape, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -162,7 +162,7 @@ void ConvertFloatTensorConst(const string& name, const Shape& input_shape, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -178,7 +178,7 @@ void ConvertFloatTensorConst(const Model& model, const string& name, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -199,7 +199,7 @@ void ConvertFloatTensorConst(const Model& model, const string& name, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -215,6 +215,30 @@ void ConvertFloatTensorConst(const Model& model, const string& name, LegacyScalarPolicy::kAvoidLegacyScalars); } +void ConvertBoolTensorConst(const Model& model, const string& name, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + CHECK(model.HasArray(name)); + const auto& array = model.GetArray(name); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_BOOL); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_BOOL); + const auto& data = array.GetBuffer().data; + for (auto index : data) { + tensor->add_bool_val(index); + } + const auto& array_shape = array.shape(); + auto* shape = tensor->mutable_tensor_shape(); + for (int i = 0; i < array_shape.dimensions_count(); i++) { + shape->add_dim()->set_size(array_shape.dims(i)); + } +} + void ConvertIntTensorConst(const Model& model, const string& name, GraphDef* tensorflow_graph) { if (HasAlreadyExportedConst(name, *tensorflow_graph)) { @@ -222,7 +246,7 @@ void ConvertIntTensorConst(const Model& model, const string& name, } CHECK(model.HasArray(name)); const auto& array = model.GetArray(name); - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -245,7 +269,7 @@ void CreateIntTensorConst(const string& name, const std::vector& data, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -268,7 +292,7 @@ void CreateMatrixShapeTensorConst(const string& name, int rows, int cols, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -286,7 +310,7 @@ void CreateDummyConcatDimTensorConst(const string& name, int dim, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -301,7 +325,7 @@ void CreateReshapeShapeTensorConst(const string& name, if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; } - auto* const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); const_op->set_op("Const"); const_op->set_name(name); (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -341,7 +365,7 @@ void ConvertConvOperator(const Model& model, const ConvOperator& src_op, conv_output += "/conv"; } - auto* conv2d_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node(); conv2d_op->set_op("Conv2D"); conv2d_op->set_name(conv_output); *conv2d_op->add_input() = src_op.inputs[0]; @@ -377,7 +401,7 @@ void ConvertConvOperator(const Model& model, const ConvOperator& src_op, (*conv2d_op->mutable_attr())["padding"].set_s(padding); if (has_bias) { - auto* biasadd_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node(); biasadd_op->set_op("BiasAdd"); biasadd_op->set_name(src_op.outputs[0]); biasadd_op->add_input(conv_output); @@ -409,7 +433,7 @@ void ConvertDepthwiseConvOperator(const Model& model, conv_output += "/conv"; } - auto* dc2d_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* dc2d_op = tensorflow_graph->add_node(); dc2d_op->set_op("DepthwiseConv2dNative"); dc2d_op->set_name(conv_output); *dc2d_op->add_input() = src_op.inputs[0]; @@ -457,7 +481,7 @@ void ConvertDepthwiseConvOperator(const Model& model, (*dc2d_op->mutable_attr())["padding"].set_s(padding); if (has_bias) { - auto* biasadd_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node(); biasadd_op->set_op("BiasAdd"); biasadd_op->set_name(src_op.outputs[0]); biasadd_op->add_input(conv_output); @@ -482,7 +506,7 @@ void ConvertDepthwiseConvOperator(const Model& model, void ConvertTransposeConvOperator(const Model& model, const TransposeConvOperator& src_op, GraphDef* tensorflow_graph) { - auto* conv2d_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node(); conv2d_op->set_op("Conv2DBackpropInput"); conv2d_op->set_name(src_op.outputs[0]); *conv2d_op->add_input() = src_op.inputs[0]; @@ -514,7 +538,7 @@ void ConvertTransposeConvOperator(const Model& model, void ConvertDepthToSpaceOperator(const Model& model, const DepthToSpaceOperator& src_op, GraphDef* tensorflow_graph) { - auto* op = tensorflow_graph->add_node(); + tensorflow::NodeDef* op = tensorflow_graph->add_node(); op->set_op("DepthToSpace"); op->set_name(src_op.outputs[0]); *op->add_input() = src_op.inputs[0]; @@ -525,7 +549,7 @@ void ConvertDepthToSpaceOperator(const Model& model, void ConvertSpaceToDepthOperator(const Model& model, const SpaceToDepthOperator& src_op, GraphDef* tensorflow_graph) { - auto* op = tensorflow_graph->add_node(); + tensorflow::NodeDef* op = tensorflow_graph->add_node(); op->set_op("SpaceToDepth"); op->set_name(src_op.outputs[0]); *op->add_input() = src_op.inputs[0]; @@ -546,7 +570,7 @@ void ConvertFullyConnectedOperator(const Model& model, CHECK_EQ(fc_weights_shape.dimensions_count(), 2); CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1, tensorflow_graph); - auto* reshape_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node(); reshape_op->set_op("Reshape"); reshape_op->set_name(reshape_output); reshape_op->add_input(src_op.inputs[0]); @@ -568,7 +592,7 @@ void ConvertFullyConnectedOperator(const Model& model, const string transpose_perm = AvailableArrayName(model, transpose_output + "/perm"); CreateIntTensorConst(transpose_perm, {1, 0}, {2}, tensorflow_graph); - auto transpose_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node(); transpose_op->set_op("Transpose"); transpose_op->set_name(transpose_output); *transpose_op->add_input() = src_op.inputs[1]; @@ -577,7 +601,7 @@ void ConvertFullyConnectedOperator(const Model& model, GetTensorFlowDataType(model, src_op.inputs[1])); (*transpose_op->mutable_attr())["Tperm"].set_type(DT_INT32); - auto* matmul_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node(); matmul_op->set_op("MatMul"); matmul_op->set_name(matmul_output); *matmul_op->add_input() = reshape_output; @@ -590,7 +614,7 @@ void ConvertFullyConnectedOperator(const Model& model, // Add the bias, if it exists. if (has_bias) { - auto* biasadd_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node(); biasadd_op->set_op("BiasAdd"); biasadd_op->set_name(src_op.outputs[0]); biasadd_op->add_input(matmul_output); @@ -615,45 +639,61 @@ void ConvertFullyConnectedOperator(const Model& model, void ConvertAddOperator(const Model& model, const AddOperator& src_op, GraphDef* tensorflow_graph) { - auto* add_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* add_op = tensorflow_graph->add_node(); add_op->set_op("Add"); add_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *add_op->add_input() = src_op.inputs[0]; *add_op->add_input() = src_op.inputs[1]; - (*add_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*add_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); } void ConvertAddNOperator(const Model& model, const AddNOperator& src_op, GraphDef* tensorflow_graph) { - auto* add_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* add_op = tensorflow_graph->add_node(); add_op->set_op("AddN"); add_op->set_name(src_op.outputs[0]); for (const auto& input : src_op.inputs) { *add_op->add_input() = input; } (*add_op->mutable_attr())["N"].set_i(src_op.inputs.size()); - (*add_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*add_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); } void ConvertMulOperator(const Model& model, const MulOperator& src_op, GraphDef* tensorflow_graph) { - auto* add_op = tensorflow_graph->add_node(); - add_op->set_op("Mul"); - add_op->set_name(src_op.outputs[0]); + tensorflow::NodeDef* mul_op = tensorflow_graph->add_node(); + mul_op->set_op("Mul"); + mul_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); - *add_op->add_input() = src_op.inputs[0]; - *add_op->add_input() = src_op.inputs[1]; - (*add_op->mutable_attr())["T"].set_type(DT_FLOAT); + *mul_op->add_input() = src_op.inputs[0]; + *mul_op->add_input() = src_op.inputs[1]; + (*mul_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); +} + +void ConvertDivOperator(const Model& model, const DivOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* div_op = tensorflow_graph->add_node(); + div_op->set_op("Div"); + div_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *div_op->add_input() = src_op.inputs[0]; + *div_op->add_input() = src_op.inputs[1]; + (*div_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); } -void ConvertReluOperator(const ReluOperator& src_op, +void ConvertReluOperator(const Model& model, const ReluOperator& src_op, GraphDef* tensorflow_graph) { - auto* relu_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* relu_op = tensorflow_graph->add_node(); relu_op->set_op("Relu"); relu_op->set_name(src_op.outputs[0]); *relu_op->add_input() = src_op.inputs[0]; - (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*relu_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); } void ConvertRelu1Operator(const Relu1Operator& src_op, @@ -662,7 +702,7 @@ void ConvertRelu1Operator(const Relu1Operator& src_op, const string min_bounds = src_op.outputs[0] + "/min_bounds"; const string max_output = src_op.outputs[0] + "/max_output"; - auto* max_bounds_const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* max_bounds_const_op = tensorflow_graph->add_node(); max_bounds_const_op->set_op("Const"); max_bounds_const_op->set_name(max_bounds); (*max_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -671,7 +711,7 @@ void ConvertRelu1Operator(const Relu1Operator& src_op, max_bounds_const_op_tensor->set_dtype(DT_FLOAT); max_bounds_const_op_tensor->add_float_val(-1.0f); - auto* min_bounds_const_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* min_bounds_const_op = tensorflow_graph->add_node(); min_bounds_const_op->set_op("Const"); min_bounds_const_op->set_name(min_bounds); (*min_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -680,14 +720,14 @@ void ConvertRelu1Operator(const Relu1Operator& src_op, min_bounds_const_op_tensor->set_dtype(DT_FLOAT); min_bounds_const_op_tensor->add_float_val(1.0f); - auto* max_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* max_op = tensorflow_graph->add_node(); max_op->set_op("Maximum"); max_op->set_name(max_output); *max_op->add_input() = src_op.inputs[0]; *max_op->add_input() = max_bounds; (*max_op->mutable_attr())["T"].set_type(DT_FLOAT); - auto* min_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* min_op = tensorflow_graph->add_node(); min_op->set_op("Minimum"); min_op->set_name(src_op.outputs[0]); *min_op->add_input() = max_output; @@ -697,7 +737,7 @@ void ConvertRelu1Operator(const Relu1Operator& src_op, void ConvertRelu6Operator(const Relu6Operator& src_op, GraphDef* tensorflow_graph) { - auto* relu_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* relu_op = tensorflow_graph->add_node(); relu_op->set_op("Relu6"); relu_op->set_name(src_op.outputs[0]); *relu_op->add_input() = src_op.inputs[0]; @@ -705,7 +745,7 @@ void ConvertRelu6Operator(const Relu6Operator& src_op, } void ConvertLogOperator(const LogOperator& src_op, GraphDef* tensorflow_graph) { - auto* op = tensorflow_graph->add_node(); + tensorflow::NodeDef* op = tensorflow_graph->add_node(); op->set_op("Log"); op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -715,7 +755,7 @@ void ConvertLogOperator(const LogOperator& src_op, GraphDef* tensorflow_graph) { void ConvertLogisticOperator(const LogisticOperator& src_op, GraphDef* tensorflow_graph) { - auto* relu_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* relu_op = tensorflow_graph->add_node(); relu_op->set_op("Sigmoid"); relu_op->set_name(src_op.outputs[0]); *relu_op->add_input() = src_op.inputs[0]; @@ -724,7 +764,7 @@ void ConvertLogisticOperator(const LogisticOperator& src_op, void ConvertTanhOperator(const TanhOperator& src_op, GraphDef* tensorflow_graph) { - auto* tanh_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* tanh_op = tensorflow_graph->add_node(); tanh_op->set_op("Tanh"); tanh_op->set_name(src_op.outputs[0]); *tanh_op->add_input() = src_op.inputs[0]; @@ -735,8 +775,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, GraphDef* tensorflow_graph) { string softmax_input; Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]); - if (providing_op != nullptr && - providing_op->type == OperatorType::kTensorFlowReshape) { + if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) { softmax_input = src_op.inputs[0]; } else { // Insert a reshape operator that reduces the dimensions down to the 2 that @@ -745,7 +784,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, const string softmax_size = src_op.outputs[0] + "/softmax_insert_size"; softmax_input = reshape_output; - auto* reshape_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node(); reshape_op->set_op("Reshape"); reshape_op->set_name(reshape_output); *reshape_op->add_input() = src_op.inputs[0]; @@ -762,7 +801,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph); } - auto* softmax_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* softmax_op = tensorflow_graph->add_node(); softmax_op->set_op("Softmax"); softmax_op->set_name(src_op.outputs[0]); *softmax_op->add_input() = softmax_input; @@ -776,8 +815,7 @@ void ConvertLogSoftmaxOperator(const Model& model, GraphDef* tensorflow_graph) { string softmax_input; Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]); - if (providing_op != nullptr && - providing_op->type == OperatorType::kTensorFlowReshape) { + if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) { softmax_input = src_op.inputs[0]; } else { // Insert a reshape operator that reduces the dimensions down to the 2 that @@ -787,7 +825,7 @@ void ConvertLogSoftmaxOperator(const Model& model, const string softmax_size = src_op.outputs[0] + "/log_softmax_insert_size"; softmax_input = reshape_output; - auto* reshape_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node(); reshape_op->set_op("Reshape"); reshape_op->set_name(reshape_output); *reshape_op->add_input() = src_op.inputs[0]; @@ -804,7 +842,7 @@ void ConvertLogSoftmaxOperator(const Model& model, CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph); } - auto* log_softmax_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* log_softmax_op = tensorflow_graph->add_node(); log_softmax_op->set_op("LogSoftmax"); log_softmax_op->set_name(src_op.outputs[0]); *log_softmax_op->add_input() = softmax_input; @@ -819,7 +857,7 @@ void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op, const string rsqrt_output = src_op.outputs[0] + "/rsqrt"; const string rsqrt_tiled_output = src_op.outputs[0] + "/rsqrt_tiled"; - auto* sum_reduction_indices_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sum_reduction_indices_op = tensorflow_graph->add_node(); sum_reduction_indices_op->set_op("Const"); sum_reduction_indices_op->set_name(sum_reduction_indices); (*sum_reduction_indices_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -833,26 +871,26 @@ void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op, sum_reduction_indices_tensor->add_int_val(0); sum_reduction_indices_tensor->add_int_val(1); - auto* square_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* square_op = tensorflow_graph->add_node(); square_op->set_op("Square"); square_op->set_name(square_output); *square_op->add_input() = src_op.inputs[0]; (*square_op->mutable_attr())["T"].set_type(DT_FLOAT); - auto* sum_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sum_op = tensorflow_graph->add_node(); sum_op->set_op("Sum"); sum_op->set_name(sum_output); *sum_op->add_input() = square_output; *sum_op->add_input() = sum_reduction_indices; (*sum_op->mutable_attr())["T"].set_type(DT_FLOAT); - auto* rsqrt_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node(); rsqrt_op->set_op("Rsqrt"); rsqrt_op->set_name(rsqrt_output); *rsqrt_op->add_input() = sum_output; (*rsqrt_op->mutable_attr())["T"].set_type(DT_FLOAT); - auto* mul_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* mul_op = tensorflow_graph->add_node(); mul_op->set_op("Mul"); mul_op->set_name(src_op.outputs[0]); *mul_op->add_input() = src_op.inputs[0]; @@ -863,7 +901,7 @@ void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op, void ConvertLocalResponseNormalizationOperator( const LocalResponseNormalizationOperator& src_op, GraphDef* tensorflow_graph) { - auto* lrn_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* lrn_op = tensorflow_graph->add_node(); lrn_op->set_op("LRN"); lrn_op->set_name(src_op.outputs[0]); *lrn_op->add_input() = src_op.inputs[0]; @@ -875,7 +913,7 @@ void ConvertLocalResponseNormalizationOperator( void ConvertFakeQuantOperator(const FakeQuantOperator& src_op, GraphDef* tensorflow_graph) { - auto* fakequant_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* fakequant_op = tensorflow_graph->add_node(); fakequant_op->set_op("FakeQuantWithMinMaxArgs"); fakequant_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -886,11 +924,14 @@ void ConvertFakeQuantOperator(const FakeQuantOperator& src_op, if (src_op.num_bits) { (*fakequant_op->mutable_attr())["num_bits"].set_i(src_op.num_bits); } + if (src_op.narrow_range) { + (*fakequant_op->mutable_attr())["narrow_range"].set_b(src_op.narrow_range); + } } void ConvertMaxPoolOperator(const MaxPoolOperator& src_op, GraphDef* tensorflow_graph) { - auto* maxpool_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* maxpool_op = tensorflow_graph->add_node(); maxpool_op->set_op("MaxPool"); maxpool_op->set_name(src_op.outputs[0]); *maxpool_op->add_input() = src_op.inputs[0]; @@ -918,7 +959,7 @@ void ConvertMaxPoolOperator(const MaxPoolOperator& src_op, void ConvertAveragePoolOperator(const AveragePoolOperator& src_op, GraphDef* tensorflow_graph) { - auto* avgpool_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node(); avgpool_op->set_op("AvgPool"); avgpool_op->set_name(src_op.outputs[0]); *avgpool_op->add_input() = src_op.inputs[0]; @@ -947,7 +988,7 @@ void ConvertAveragePoolOperator(const AveragePoolOperator& src_op, void ConvertConcatenationOperator(const Model& model, const ConcatenationOperator& src_op, GraphDef* tensorflow_graph) { - auto* dc_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* dc_op = tensorflow_graph->add_node(); dc_op->set_op("ConcatV2"); dc_op->set_name(src_op.outputs[0]); const string dummy_axis = src_op.outputs[0] + "/axis"; @@ -965,7 +1006,7 @@ void ConvertConcatenationOperator(const Model& model, void ConvertTensorFlowReshapeOperator(const Model& model, const TensorFlowReshapeOperator& src_op, GraphDef* tensorflow_graph) { - auto* reshape_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node(); reshape_op->set_op("Reshape"); reshape_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -987,7 +1028,7 @@ void ConvertL2PoolOperator(const L2PoolOperator& src_op, const string square_output = src_op.outputs[0] + "/square"; const string avgpool_output = src_op.outputs[0] + "/avgpool"; - auto* square_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* square_op = tensorflow_graph->add_node(); square_op->set_op("Square"); square_op->set_name(square_output); *square_op->add_input() = src_op.inputs[0]; @@ -1002,7 +1043,7 @@ void ConvertL2PoolOperator(const L2PoolOperator& src_op, LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; } - auto* avgpool_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node(); avgpool_op->set_op("AvgPool"); avgpool_op->set_name(avgpool_output); *avgpool_op->add_input() = square_output; @@ -1020,7 +1061,7 @@ void ConvertL2PoolOperator(const L2PoolOperator& src_op, ksize.mutable_list()->add_i(src_op.kwidth); ksize.mutable_list()->add_i(1); - auto* sqrt_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node(); sqrt_op->set_op("Sqrt"); sqrt_op->set_name(src_op.outputs[0]); *sqrt_op->add_input() = avgpool_output; @@ -1029,7 +1070,7 @@ void ConvertL2PoolOperator(const L2PoolOperator& src_op, void ConvertSquareOperator(const TensorFlowSquareOperator& src_op, GraphDef* tensorflow_graph) { - auto* square_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* square_op = tensorflow_graph->add_node(); square_op->set_op("Square"); square_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -1039,7 +1080,7 @@ void ConvertSquareOperator(const TensorFlowSquareOperator& src_op, void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op, GraphDef* tensorflow_graph) { - auto* sqrt_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node(); sqrt_op->set_op("Sqrt"); sqrt_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -1047,10 +1088,23 @@ void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op, (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT); } +void ConvertRsqrtOperator(const Model& model, + const TensorFlowRsqrtOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node(); + rsqrt_op->set_op("Rsqrt"); + rsqrt_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *rsqrt_op->add_input() = src_op.inputs[0]; + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*rsqrt_op->mutable_attr())["T"].set_type(data_type); +} + void ConvertSplitOperator(const Model& model, const TensorFlowSplitOperator& src_op, GraphDef* tensorflow_graph) { - auto* split_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* split_op = tensorflow_graph->add_node(); split_op->set_op("Split"); split_op->set_name(src_op.outputs[0]); for (const auto& input : src_op.inputs) { @@ -1071,7 +1125,7 @@ void ConvertSplitOperator(const Model& model, void ConvertCastOperator(const Model& model, const CastOperator& src_op, GraphDef* tensorflow_graph) { - auto* cast_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* cast_op = tensorflow_graph->add_node(); cast_op->set_op("Cast"); cast_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -1085,7 +1139,7 @@ void ConvertCastOperator(const Model& model, const CastOperator& src_op, void ConvertFloorOperator(const Model& model, const FloorOperator& src_op, GraphDef* tensorflow_graph) { - auto* floor_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* floor_op = tensorflow_graph->add_node(); floor_op->set_op("Floor"); floor_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -1095,21 +1149,36 @@ void ConvertFloorOperator(const Model& model, const FloorOperator& src_op, void ConvertGatherOperator(const Model& model, const GatherOperator& src_op, GraphDef* tensorflow_graph) { - auto* gather_op = tensorflow_graph->add_node(); - gather_op->set_op("Gather"); + tensorflow::NodeDef* gather_op = tensorflow_graph->add_node(); + gather_op->set_op("GatherV2"); gather_op->set_name(src_op.outputs[0]); - CHECK_EQ(src_op.inputs.size(), 2); *gather_op->add_input() = src_op.inputs[0]; *gather_op->add_input() = src_op.inputs[1]; + if (!src_op.axis) { + // Dynamic axis. + CHECK_EQ(src_op.inputs.size(), 3); + *gather_op->add_input() = src_op.inputs[2]; + } else { + // Constant axis. + CHECK_EQ(src_op.inputs.size(), 2); + const string gather_axis = + AvailableArrayName(model, gather_op->name() + "/axis"); + CreateIntTensorConst(gather_axis, {src_op.axis.value()}, {}, + tensorflow_graph); + *gather_op->add_input() = gather_axis; + } + (*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32); - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*gather_op->mutable_attr())["Taxis"].set_type(DT_INT32); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*gather_op->mutable_attr())["Tparams"].set_type(params_type); } void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op, GraphDef* tensorflow_graph) { - auto* argmax_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* argmax_op = tensorflow_graph->add_node(); argmax_op->set_op("ArgMax"); argmax_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1123,10 +1192,26 @@ void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op, GetTensorFlowDataType(model, src_op.outputs[0])); } +void ConvertArgMinOperator(const Model& model, const ArgMinOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* argmin_op = tensorflow_graph->add_node(); + argmin_op->set_op("ArgMin"); + argmin_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *argmin_op->add_input() = src_op.inputs[0]; + *argmin_op->add_input() = src_op.inputs[1]; + (*argmin_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); + (*argmin_op->mutable_attr())["Tidx"].set_type( + GetTensorFlowDataType(model, src_op.inputs[1])); + (*argmin_op->mutable_attr())["output_type"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); +} + void ConvertTransposeOperator(const Model& model, const TransposeOperator& src_op, GraphDef* tensorflow_graph) { - auto* transpose_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node(); transpose_op->set_op("Transpose"); transpose_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1141,7 +1226,7 @@ void ConvertTransposeOperator(const Model& model, void ConvertTensorFlowShapeOperator(const Model& model, const TensorFlowShapeOperator& src_op, GraphDef* tensorflow_graph) { - auto* shape_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* shape_op = tensorflow_graph->add_node(); shape_op->set_op("Shape"); shape_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -1154,7 +1239,7 @@ void ConvertTensorFlowShapeOperator(const Model& model, void ConvertRankOperator(const Model& model, const RankOperator& src_op, GraphDef* tensorflow_graph) { - auto* rank_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* rank_op = tensorflow_graph->add_node(); rank_op->set_op("Rank"); rank_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); @@ -1165,7 +1250,7 @@ void ConvertRankOperator(const Model& model, const RankOperator& src_op, void ConvertRangeOperator(const Model& model, const RangeOperator& src_op, GraphDef* tensorflow_graph) { - auto* range_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* range_op = tensorflow_graph->add_node(); range_op->set_op("Range"); range_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 3); @@ -1176,22 +1261,22 @@ void ConvertRangeOperator(const Model& model, const RangeOperator& src_op, GetTensorFlowDataType(src_op.dtype)); } -void ConvertStackOperator(const Model& model, const StackOperator& src_op, - GraphDef* tensorflow_graph) { - auto* stack_op = tensorflow_graph->add_node(); - stack_op->set_op("Stack"); - stack_op->set_name(src_op.outputs[0]); +void ConvertPackOperator(const Model& model, const PackOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* pack_op = tensorflow_graph->add_node(); + pack_op->set_op("Pack"); + pack_op->set_name(src_op.outputs[0]); for (const auto& input : src_op.inputs) { - *stack_op->add_input() = input; + *pack_op->add_input() = input; } - (*stack_op->mutable_attr())["elem_type"].set_type( - GetTensorFlowDataType(model, src_op.outputs[0])); - (*stack_op->mutable_attr())["axis"].set_i(src_op.axis); + (*pack_op->mutable_attr())["axis"].set_i(src_op.axis); + (*pack_op->mutable_attr())["N"].set_i(src_op.inputs.size()); + (*pack_op->mutable_attr())["T"].set_type(GetTensorFlowDataType(src_op.dtype)); } void ConvertFillOperator(const Model& model, const FillOperator& src_op, GraphDef* tensorflow_graph) { - auto* fill_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* fill_op = tensorflow_graph->add_node(); fill_op->set_op("Fill"); fill_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1205,7 +1290,7 @@ void ConvertFillOperator(const Model& model, const FillOperator& src_op, void ConvertFloorDivOperator(const Model& model, const FloorDivOperator& src_op, GraphDef* tensorflow_graph) { - auto* floor_div_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* floor_div_op = tensorflow_graph->add_node(); floor_div_op->set_op("FloorDiv"); floor_div_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1218,7 +1303,7 @@ void ConvertFloorDivOperator(const Model& model, const FloorDivOperator& src_op, void ConvertExpandDimsOperator(const Model& model, const ExpandDimsOperator& src_op, GraphDef* tensorflow_graph) { - auto* expand_dims_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* expand_dims_op = tensorflow_graph->add_node(); expand_dims_op->set_op("ExpandDims"); expand_dims_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1233,7 +1318,7 @@ void ConvertExpandDimsOperator(const Model& model, void ConvertResizeBilinearOperator(const Model& model, const ResizeBilinearOperator& src_op, GraphDef* tensorflow_graph) { - auto* resize_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* resize_op = tensorflow_graph->add_node(); resize_op->set_op("ResizeBilinear"); resize_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1243,6 +1328,20 @@ void ConvertResizeBilinearOperator(const Model& model, (*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners); } +void ConvertOneHotOperator(const Model& model, const OneHotOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* onehot_op = tensorflow_graph->add_node(); + onehot_op->set_op("OneHot"); + onehot_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 4); + for (const auto& input : src_op.inputs) { + *onehot_op->add_input() = input; + } + (*onehot_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); + (*onehot_op->mutable_attr())["axis"].set_i(src_op.axis); +} + namespace { // TODO(aselle): Remove when available in absl absl::string_view FindLongestCommonPrefix(absl::string_view a, @@ -1283,7 +1382,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // works the same since the tensor has the same underlying data layout. const string axis_output = concat_output + "/axis"; CreateDummyConcatDimTensorConst(axis_output, axis, tensorflow_graph); - auto* concat_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* concat_op = tensorflow_graph->add_node(); concat_op->set_op("ConcatV2"); concat_op->set_name(concat_output); *concat_op->add_input() = src_op.inputs[LstmCellOperator::DATA_INPUT]; @@ -1311,7 +1410,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Fully connected matrix multiply const string matmul_output = base + "MatMul"; - auto* matmul_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node(); matmul_op->set_op("MatMul"); matmul_op->set_name(matmul_output); *matmul_op->add_input() = concat_output; @@ -1340,7 +1439,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Add biases string biasadd_output = base + "BiasAdd"; - auto* biasadd_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node(); biasadd_op->set_op("BiasAdd"); biasadd_op->set_name(biasadd_output); biasadd_op->add_input(matmul_output); @@ -1353,7 +1452,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // The dimension is the same as the concatenation dimension CreateDummyConcatDimTensorConst(split_dim_output, axis, tensorflow_graph); string split_output = base + "split"; - auto* split_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* split_op = tensorflow_graph->add_node(); split_op->set_op("Split"); split_op->set_name(split_output); *split_op->add_input() = split_dim_output; @@ -1363,21 +1462,21 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Activation functions and memory computations const string tanh_0_output = base + "Tanh"; - auto* tanh_0_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* tanh_0_op = tensorflow_graph->add_node(); tanh_0_op->set_op("Tanh"); tanh_0_op->set_name(tanh_0_output); *tanh_0_op->add_input() = split_output + ":1"; (*tanh_0_op->mutable_attr())["T"].set_type(DT_FLOAT); const string sigmoid_1_output = base + "Sigmoid_1"; - auto* logistic_1_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* logistic_1_op = tensorflow_graph->add_node(); logistic_1_op->set_op("Sigmoid"); logistic_1_op->set_name(sigmoid_1_output); *logistic_1_op->add_input() = split_output; (*logistic_1_op->mutable_attr())["T"].set_type(DT_FLOAT); const string mul_1_output = base + "mul_1"; - auto* mul_1_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* mul_1_op = tensorflow_graph->add_node(); mul_1_op->set_op("Mul"); mul_1_op->set_name(mul_1_output); *mul_1_op->add_input() = sigmoid_1_output; @@ -1385,21 +1484,21 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, (*mul_1_op->mutable_attr())["T"].set_type(DT_FLOAT); const string sigmoid_0_output = base + "Sigmoid"; - auto* logistic_2_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* logistic_2_op = tensorflow_graph->add_node(); logistic_2_op->set_op("Sigmoid"); logistic_2_op->set_name(sigmoid_0_output); *logistic_2_op->add_input() = split_output + ":2"; (*logistic_2_op->mutable_attr())["T"].set_type(DT_FLOAT); const string sigmoid_2_output = base + "Sigmoid_2"; - auto* logistic_3_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* logistic_3_op = tensorflow_graph->add_node(); logistic_3_op->set_op("Sigmoid"); logistic_3_op->set_name(sigmoid_2_output); *logistic_3_op->add_input() = split_output + ":3"; (*logistic_3_op->mutable_attr())["T"].set_type(DT_FLOAT); const string mul_0_output = base + "mul"; - auto* mul_0_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* mul_0_op = tensorflow_graph->add_node(); mul_0_op->set_op("Mul"); mul_0_op->set_name(mul_0_output); *mul_0_op->add_input() = src_op.inputs[LstmCellOperator::PREV_STATE_INPUT]; @@ -1407,7 +1506,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, (*mul_0_op->mutable_attr())["T"].set_type(DT_FLOAT); const string add_1_output = src_op.outputs[LstmCellOperator::STATE_OUTPUT]; - auto* add_1_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* add_1_op = tensorflow_graph->add_node(); add_1_op->set_op("Add"); add_1_op->set_name(add_1_output); *add_1_op->add_input() = mul_0_output; @@ -1415,14 +1514,14 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, (*add_1_op->mutable_attr())["T"].set_type(DT_FLOAT); const string tanh_1_output = base + "Tanh_1"; - auto* tanh_1_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* tanh_1_op = tensorflow_graph->add_node(); tanh_1_op->set_op("Tanh"); tanh_1_op->set_name(tanh_1_output); *tanh_1_op->add_input() = add_1_output; (*tanh_1_op->mutable_attr())["T"].set_type(DT_FLOAT); const string mul_2_output = src_op.outputs[LstmCellOperator::ACTIV_OUTPUT]; - auto* mul_2_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* mul_2_op = tensorflow_graph->add_node(); mul_2_op->set_op("Mul"); mul_2_op->set_name(mul_2_output); *mul_2_op->add_input() = tanh_1_output; @@ -1433,14 +1532,15 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, void ConvertSpaceToBatchNDOperator(const Model& model, const SpaceToBatchNDOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("SpaceToBatchND"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 3); *new_op->add_input() = src_op.inputs[0]; *new_op->add_input() = src_op.inputs[1]; *new_op->add_input() = src_op.inputs[2]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32); (*new_op->mutable_attr())["Tpaddings"].set_type(DT_INT32); @@ -1449,14 +1549,15 @@ void ConvertSpaceToBatchNDOperator(const Model& model, void ConvertBatchToSpaceNDOperator(const Model& model, const BatchToSpaceNDOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("BatchToSpaceND"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 3); *new_op->add_input() = src_op.inputs[0]; *new_op->add_input() = src_op.inputs[1]; *new_op->add_input() = src_op.inputs[2]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32); (*new_op->mutable_attr())["Tcrops"].set_type(DT_INT32); @@ -1464,18 +1565,19 @@ void ConvertBatchToSpaceNDOperator(const Model& model, void ConvertPadOperator(const Model& model, const PadOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("Pad"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *new_op->add_input() = src_op.inputs[0]; *new_op->add_input() = src_op.inputs[1]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); // Create the params tensor. - auto* params_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* params_op = tensorflow_graph->add_node(); params_op->set_op("Const"); params_op->set_name(src_op.inputs[1]); (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -1494,7 +1596,7 @@ void ConvertPadOperator(const Model& model, const PadOperator& src_op, void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("PadV2"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); @@ -1502,11 +1604,12 @@ void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op, *new_op->add_input() = src_op.inputs[1]; *new_op->add_input() = src_op.inputs[2]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); // Create the params tensor. - auto* params_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* params_op = tensorflow_graph->add_node(); params_op->set_op("Const"); params_op->set_name(src_op.inputs[1]); (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -1525,7 +1628,7 @@ void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op, void CreateSliceInput(const string& input_name, const std::vector& values, GraphDef* tensorflow_graph) { - auto* params_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* params_op = tensorflow_graph->add_node(); params_op->set_op("Const"); params_op->set_name(input_name); (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -1542,7 +1645,7 @@ void CreateSliceInput(const string& input_name, const std::vector& values, void ConvertStridedSliceOperator(const Model& model, const StridedSliceOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("StridedSlice"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 4); @@ -1551,7 +1654,8 @@ void ConvertStridedSliceOperator(const Model& model, *new_op->add_input() = src_op.inputs[2]; *new_op->add_input() = src_op.inputs[3]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); (*new_op->mutable_attr())["Index"].set_type(DT_INT32); @@ -1569,7 +1673,7 @@ void ConvertStridedSliceOperator(const Model& model, void ConvertSliceOperator(const Model& model, const SliceOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("Slice"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 3); @@ -1577,7 +1681,8 @@ void ConvertSliceOperator(const Model& model, const SliceOperator& src_op, *new_op->add_input() = src_op.inputs[1]; *new_op->add_input() = src_op.inputs[2]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); (*new_op->mutable_attr())["Index"].set_type(DT_INT32); @@ -1586,24 +1691,29 @@ void ConvertSliceOperator(const Model& model, const SliceOperator& src_op, CreateSliceInput(src_op.inputs[2], src_op.size, tensorflow_graph); } -void ConvertMeanOperator(const Model& model, const MeanOperator& src_op, - GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); - new_op->set_op("Mean"); +template +void ConvertReduceOperator(const Model& model, const T& src_op, + GraphDef* tensorflow_graph, const string& op_name) { + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); + new_op->set_op(op_name); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *new_op->add_input() = src_op.inputs[0]; *new_op->add_input() = src_op.inputs[1]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); + const tensorflow::DataType indices_type = + GetTensorFlowDataType(model, src_op.inputs[1]); + (*new_op->mutable_attr())["Tidx"].set_type(indices_type); if (src_op.keep_dims) { (*new_op->mutable_attr())["keep_dims"].set_b(true); } // Create the params tensor. - auto* params_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* params_op = tensorflow_graph->add_node(); params_op->set_op("Const"); params_op->set_name(src_op.inputs[1]); (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); @@ -1619,13 +1729,14 @@ void ConvertMeanOperator(const Model& model, const MeanOperator& src_op, void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op, GraphDef* tensorflow_graph) { - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("Squeeze"); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 1); *new_op->add_input() = src_op.inputs[0]; - const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType params_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); if (!src_op.squeeze_dims.empty()) { @@ -1638,63 +1749,87 @@ void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op, void ConvertSubOperator(const Model& model, const SubOperator& src_op, GraphDef* tensorflow_graph) { - auto* sub_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); sub_op->set_op("Sub"); sub_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *sub_op->add_input() = src_op.inputs[0]; *sub_op->add_input() = src_op.inputs[1]; - const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*sub_op->mutable_attr())["T"].set_type(data_type); } void ConvertTensorFlowMinimumOperator(const Model& model, const TensorFlowMinimumOperator& src_op, GraphDef* tensorflow_graph) { - auto* sub_op = tensorflow_graph->add_node(); - sub_op->set_op("Minimum"); - sub_op->set_name(src_op.outputs[0]); + tensorflow::NodeDef* min_op = tensorflow_graph->add_node(); + min_op->set_op("Minimum"); + min_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); - *sub_op->add_input() = src_op.inputs[0]; - *sub_op->add_input() = src_op.inputs[1]; - const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); - (*sub_op->mutable_attr())["T"].set_type(data_type); + *min_op->add_input() = src_op.inputs[0]; + *min_op->add_input() = src_op.inputs[1]; + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*min_op->mutable_attr())["T"].set_type(data_type); } void ConvertTensorFlowMaximumOperator(const Model& model, const TensorFlowMaximumOperator& src_op, GraphDef* tensorflow_graph) { - auto* sub_op = tensorflow_graph->add_node(); - sub_op->set_op("Maximum"); - sub_op->set_name(src_op.outputs[0]); + tensorflow::NodeDef* max_op = tensorflow_graph->add_node(); + max_op->set_op("Maximum"); + max_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); - *sub_op->add_input() = src_op.inputs[0]; - *sub_op->add_input() = src_op.inputs[1]; - const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); - (*sub_op->mutable_attr())["T"].set_type(data_type); + *max_op->add_input() = src_op.inputs[0]; + *max_op->add_input() = src_op.inputs[1]; + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*max_op->mutable_attr())["T"].set_type(data_type); } void ConvertSelectOperator(const Model& model, const SelectOperator& src_op, GraphDef* tensorflow_graph) { - auto* sub_op = tensorflow_graph->add_node(); - sub_op->set_op("Select"); - sub_op->set_name(src_op.outputs[0]); + tensorflow::NodeDef* select_op = tensorflow_graph->add_node(); + select_op->set_op("Select"); + select_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 3); - *sub_op->add_input() = src_op.inputs[0]; - *sub_op->add_input() = src_op.inputs[1]; - *sub_op->add_input() = src_op.inputs[2]; - const auto data_type = GetTensorFlowDataType(model, src_op.inputs[1]); - (*sub_op->mutable_attr())["T"].set_type(data_type); + *select_op->add_input() = src_op.inputs[0]; + *select_op->add_input() = src_op.inputs[1]; + *select_op->add_input() = src_op.inputs[2]; + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[1]); + (*select_op->mutable_attr())["T"].set_type(data_type); +} + +void ConvertTileOperator(const Model& model, + const TensorFlowTileOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* tile_op = tensorflow_graph->add_node(); + tile_op->set_op("Tile"); + tile_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *tile_op->add_input() = src_op.inputs[0]; + *tile_op->add_input() = src_op.inputs[1]; + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*tile_op->mutable_attr())["T"].set_type(data_type); + const tensorflow::DataType multiples_data_type = + GetTensorFlowDataType(model, src_op.inputs[1]); + (*tile_op->mutable_attr())["Tmultiples"].set_type(multiples_data_type); } void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op, GraphDef* tensorflow_graph) { - auto* topk_op = tensorflow_graph->add_node(); - topk_op->set_op("TOPKV2"); + tensorflow::NodeDef* topk_op = tensorflow_graph->add_node(); + topk_op->set_op("TopKV2"); topk_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *topk_op->add_input() = src_op.inputs[0]; *topk_op->add_input() = src_op.inputs[1]; + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*topk_op->mutable_attr())["T"].set_type(data_type); (*topk_op->mutable_attr())["sorted"].set_b(true); } @@ -1702,12 +1837,13 @@ void ConvertRandomUniformOperator(const Model& model, const RandomUniformOperator& src_op, GraphDef* tensorflow_graph) { CHECK(tensorflow_graph != nullptr); - auto* new_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); new_op->set_op("RandomUniform"); CHECK_EQ(src_op.inputs.size(), 1); new_op->set_name(src_op.outputs[0]); *new_op->add_input() = src_op.inputs[0]; - const auto shape_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType shape_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(shape_type); (*new_op->mutable_attr())["dtype"].set_type( GetTensorFlowDataType(src_op.dtype)); @@ -1718,13 +1854,14 @@ void ConvertRandomUniformOperator(const Model& model, void ConvertComparisonOperator(const Model& model, const Operator& src_op, const char* op_name, GraphDef* tensorflow_graph) { - auto* comparison_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* comparison_op = tensorflow_graph->add_node(); comparison_op->set_op(op_name); comparison_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *comparison_op->add_input() = src_op.inputs[0]; *comparison_op->add_input() = src_op.inputs[1]; - const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*comparison_op->mutable_attr())["T"].set_type(data_type); } @@ -1732,21 +1869,104 @@ void ConvertSparseToDenseOperator(const Model& model, const SparseToDenseOperator& src_op, const char* op_name, GraphDef* tensorflow_graph) { - auto* sparse_to_dense_op = tensorflow_graph->add_node(); + tensorflow::NodeDef* sparse_to_dense_op = tensorflow_graph->add_node(); sparse_to_dense_op->set_op(op_name); sparse_to_dense_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 4); for (int i = 0; i < 4; ++i) { *sparse_to_dense_op->add_input() = src_op.inputs[i]; } - const auto data_type = GetTensorFlowDataType(model, src_op.inputs[3]); + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[3]); (*sparse_to_dense_op->mutable_attr())["T"].set_type(data_type); - const auto index_type = GetTensorFlowDataType(model, src_op.inputs[0]); + const tensorflow::DataType index_type = + GetTensorFlowDataType(model, src_op.inputs[0]); (*sparse_to_dense_op->mutable_attr())["Tindices"].set_type(index_type); (*sparse_to_dense_op->mutable_attr())["Tindices"].set_b( src_op.validate_indices); } +void ConvertPowOperator(const Model& model, const PowOperator& src_op, + const char* op_name, GraphDef* tensorflow_graph) { + tensorflow::NodeDef* pow_op = tensorflow_graph->add_node(); + pow_op->set_op(op_name); + pow_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *pow_op->add_input() = src_op.inputs[i]; + } + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*pow_op->mutable_attr())["T"].set_type(data_type); +} + +void ConvertAnyOperator(const Model& model, const AnyOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* any_op = tensorflow_graph->add_node(); + any_op->set_op("Any"); + any_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *any_op->add_input() = src_op.inputs[i]; + } + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[1]); + (*any_op->mutable_attr())["Tidx"].set_type(data_type); + (*any_op->mutable_attr())["keep_dims"].set_b(src_op.keep_dims); +} + +void ConvertLogicalAndOperator(const Model& model, + const LogicalAndOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* logical_op = tensorflow_graph->add_node(); + logical_op->set_op("LogicalAnd"); + logical_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *logical_op->add_input() = src_op.inputs[i]; + } +} + +void ConvertLogicalNotOperator(const Model& model, + const LogicalNotOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* logical_op = tensorflow_graph->add_node(); + logical_op->set_op("LogicalNot"); + logical_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *logical_op->add_input() = src_op.inputs[0]; +} + +void ConvertLogicalOrOperator(const Model& model, + const LogicalOrOperator& src_op, + const char* op_name, GraphDef* tensorflow_graph) { + tensorflow::NodeDef* logical_or_op = tensorflow_graph->add_node(); + logical_or_op->set_op(op_name); + logical_or_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *logical_or_op->add_input() = src_op.inputs[i]; + } + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*logical_or_op->mutable_attr())["T"].set_type(data_type); +} + +void ConvertCTCBeamSearchDecoderOperator( + const Model& model, const CTCBeamSearchDecoderOperator& src_op, + const char* op_name, GraphDef* tensorflow_graph) { + auto* op = tensorflow_graph->add_node(); + op->set_op(op_name); + op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *op->add_input() = src_op.inputs[i]; + } + (*op->mutable_attr())["beam_width"].set_i(src_op.beam_width); + (*op->mutable_attr())["top_paths"].set_i(src_op.top_paths); + (*op->mutable_attr())["merge_repeated"].set_b(src_op.merge_repeated); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -1782,8 +2002,11 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kMul) { ConvertMulOperator(model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kDiv) { + ConvertDivOperator(model, static_cast(src_op), + tensorflow_graph); } else if (src_op.type == OperatorType::kRelu) { - ConvertReluOperator(static_cast(src_op), + ConvertReluOperator(model, static_cast(src_op), tensorflow_graph); } else if (src_op.type == OperatorType::kRelu1) { ConvertRelu1Operator(static_cast(src_op), @@ -1827,20 +2050,24 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertConcatenationOperator( model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowReshape) { + } else if (src_op.type == OperatorType::kReshape) { ConvertTensorFlowReshapeOperator( model, static_cast(src_op), tensorflow_graph); } else if (src_op.type == OperatorType::kL2Pool) { ConvertL2PoolOperator(static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowSquare) { + } else if (src_op.type == OperatorType::kSquare) { ConvertSquareOperator(static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowSqrt) { + } else if (src_op.type == OperatorType::kSqrt) { ConvertSqrtOperator(static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowSplit) { + } else if (src_op.type == OperatorType::kRsqrt) { + ConvertRsqrtOperator(model, + static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kSplit) { ConvertSplitOperator(model, static_cast(src_op), tensorflow_graph); @@ -1879,16 +2106,32 @@ void ConvertOperator(const Model& model, const Operator& src_op, model, static_cast(src_op), tensorflow_graph); } else if (src_op.type == OperatorType::kMean) { - ConvertMeanOperator(model, static_cast(src_op), - tensorflow_graph); + ConvertReduceOperator(model, static_cast(src_op), + tensorflow_graph, "Mean"); + } else if (src_op.type == OperatorType::kSum) { + ConvertReduceOperator(model, + static_cast(src_op), + tensorflow_graph, "Sum"); + } else if (src_op.type == OperatorType::kReduceProd) { + ConvertReduceOperator(model, + static_cast(src_op), + tensorflow_graph, "Prod"); + } else if (src_op.type == OperatorType::kReduceMin) { + ConvertReduceOperator(model, + static_cast(src_op), + tensorflow_graph, "Min"); + } else if (src_op.type == OperatorType::kReduceMax) { + ConvertReduceOperator(model, + static_cast(src_op), + tensorflow_graph, "Max"); } else if (src_op.type == OperatorType::kSub) { ConvertSubOperator(model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowMinimum) { + } else if (src_op.type == OperatorType::kMinimum) { ConvertTensorFlowMinimumOperator( model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowMaximum) { + } else if (src_op.type == OperatorType::kMaximum) { ConvertTensorFlowMaximumOperator( model, static_cast(src_op), tensorflow_graph); @@ -1901,13 +2144,16 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kArgMax) { ConvertArgMaxOperator(model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kArgMin) { + ConvertArgMinOperator(model, static_cast(src_op), + tensorflow_graph); } else if (src_op.type == OperatorType::kTopK_V2) { ConvertTopKV2Operator(model, static_cast(src_op), tensorflow_graph); } else if (src_op.type == OperatorType::kTranspose) { ConvertTransposeOperator( model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowShape) { + } else if (src_op.type == OperatorType::kShape) { ConvertTensorFlowShapeOperator( model, static_cast(src_op), tensorflow_graph); @@ -1917,9 +2163,9 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kRange) { ConvertRangeOperator(model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kStack) { - ConvertStackOperator(model, static_cast(src_op), - tensorflow_graph); + } else if (src_op.type == OperatorType::kPack) { + ConvertPackOperator(model, static_cast(src_op), + tensorflow_graph); } else if (src_op.type == OperatorType::kFill) { ConvertFillOperator(model, static_cast(src_op), tensorflow_graph); @@ -1938,21 +2184,50 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertRandomUniformOperator( model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowEqual) { + } else if (src_op.type == OperatorType::kEqual) { ConvertComparisonOperator(model, src_op, "Equal", tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowNotEqual) { + } else if (src_op.type == OperatorType::kNotEqual) { ConvertComparisonOperator(model, src_op, "NotEqual", tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowGreater) { + } else if (src_op.type == OperatorType::kGreater) { ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowGreaterEqual) { + } else if (src_op.type == OperatorType::kGreaterEqual) { ConvertComparisonOperator(model, src_op, "GreaterEqual", tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowLess) { + } else if (src_op.type == OperatorType::kLess) { ConvertComparisonOperator(model, src_op, "Less", tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowLessEqual) { + } else if (src_op.type == OperatorType::kLessEqual) { ConvertComparisonOperator(model, src_op, "LessEqual", tensorflow_graph); } else if (src_op.type == OperatorType::kSelect) { ConvertSelectOperator(model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kTile) { + ConvertTileOperator(model, + static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kPow) { + ConvertPowOperator(model, static_cast(src_op), "Pow", + tensorflow_graph); + } else if (src_op.type == OperatorType::kAny) { + ConvertAnyOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kLogicalAnd) { + ConvertLogicalAndOperator(model, + static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kLogicalNot) { + ConvertLogicalNotOperator(model, + static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kOneHot) { + ConvertOneHotOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kLogicalOr) { + ConvertLogicalOrOperator(model, + static_cast(src_op), + "LogicalOr", tensorflow_graph); + } else if (src_op.type == OperatorType::kCTCBeamSearchDecoder) { + ConvertCTCBeamSearchDecoderOperator( + model, static_cast(src_op), + "CTCBeamSearchDecoder", tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } @@ -1960,7 +2235,7 @@ void ConvertOperator(const Model& model, const Operator& src_op, void AddPlaceholder(const string& name, ArrayDataType type, GraphDef* tensorflow_graph) { - auto* placeholder = tensorflow_graph->add_node(); + tensorflow::NodeDef* placeholder = tensorflow_graph->add_node(); placeholder->set_op("Placeholder"); switch (type) { case ArrayDataType::kBool: @@ -1989,7 +2264,7 @@ void AddPlaceholder(const string& name, ArrayDataType type, void AddPlaceholderForRNNState(const Model& model, const string& name, int size, GraphDef* tensorflow_graph) { - auto* placeholder = tensorflow_graph->add_node(); + tensorflow::NodeDef* placeholder = tensorflow_graph->add_node(); placeholder->set_op("Placeholder"); placeholder->set_name(name); (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT); @@ -2031,6 +2306,9 @@ void ExportTensorFlowGraphDefImplementation(const Model& model, const auto& array = *array_pair.second; if (array.buffer) { switch (array.data_type) { + case ArrayDataType::kBool: + ConvertBoolTensorConst(model, array_name, tensorflow_graph); + break; case ArrayDataType::kFloat: ConvertFloatTensorConst(model, array_name, tensorflow_graph); break; diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md index 8e93f02ef109f7bccd07ce54baff3d0bb4ae50c7..4bf47aa3c4d1b682808ab8175c4d07d8a347067a 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md @@ -9,57 +9,56 @@ complemented by the following documents: Table of contents: -* [Convert a TensorFlow SavedModel to TensorFlow Lite](#savedmodel) -* [Convert a TensorFlow GraphDef to TensorFlow Lite for float - inference](#graphdef-float) +* [Command-line tools](#tools) + * [Converting models prior to TensorFlow 1.9.](#pre-tensorflow-1.9) +* [Basic examples](#basic) + * [Convert a TensorFlow GraphDef](#graphdef) + * [Convert a TensorFlow SavedModel](#savedmodel) + * [Convert a tf.keras model](#keras) * [Quantization](#quantization) - * [Convert a TensorFlow GraphDef to TensorFlow Lite for quantized - inference](#graphdef-quant) + * [Convert a TensorFlow GraphDef for quantized inference](#graphdef-quant) * [Use "dummy-quantization" to try out quantized inference on a float graph](#dummy-quant) * [Specifying input and output arrays](#specifying-input-and-output-arrays) - * [Multiple output arrays](#multiple-output-arrays) * [Multiple input arrays](#multiple-input-arrays) + * [Multiple output arrays](#multiple-output-arrays) * [Specifying subgraphs](#specifying-subgraphs) -* [Other conversions supported by TOCO](#other-conversions) - * [Optimize a TensorFlow GraphDef](#optimize-graphdef) - * [Convert a TensorFlow Lite FlatBuffer back into TensorFlow GraphDef - format](#to-graphdef) -* [Logging](#logging) - * [Graph "video" logging](#graph-video-logging) * [Graph visualizations](#graph-visualizations) * [Using --output_format=GRAPHVIZ_DOT](#using-output-formatgraphviz-dot) * [Using --dump_graphviz](#using-dump-graphviz) + * [Graph "video" logging](#graph-video-logging) * [Legend for the graph visualizations](#graphviz-legend) -## Convert a TensorFlow SavedModel to TensorFlow Lite +## Command-line tools -The follow example converts a basic TensorFlow SavedModel into a Tensorflow Lite -FlatBuffer to perform floating-point inference. +There are two approaches to running TOCO via command line. -``` -bazel run --config=opt \ - third_party/tensorflow/contrib/lite/toco:toco -- \ - --savedmodel_directory=/tmp/saved_model \ - --output_file=/tmp/foo.tflite -``` +* `tflite_convert`: Starting from TensorFlow 1.9, the command-line tool + `tflite_convert` will be installed as part of the Python package. All of the + examples below use `tflite_convert` for simplicity. + * Example: `tflite_convert --output_file=...` +* `bazel`: In order to run the latest version of TOCO, [clone the TensorFlow + repository](https://www.tensorflow.org/install/install_sources#clone_the_tensorflow_repository) + and use `bazel`. This is the recommended approach for converting models that + utilize new features that were not supported by TOCO in TensorFlow 1.9. + * Example: `bazel run + //tensorflow/contrib/lite/python:tflite_convert -- + --output_file=...` -[SavedModel](https://www.tensorflow.org/programmers_guide/saved_model#using_savedmodel_with_estimators) -has fewer required flags than frozen graphs (described [below](#graphdef-float)) -due to access to additional data contained within the SavedModel. The values for -`--input_arrays` and `--output_arrays` are an aggregated, alphabetized list of -the inputs and outputs in the -[SignatureDefs](https://www.tensorflow.org/serving/signature_defs) within the -[MetaGraphDef](https://www.tensorflow.org/programmers_guide/saved_model#apis_to_build_and_load_a_savedmodel) -specified by `--savedmodel_tagset`. The value for `input_shapes` is -automatically determined from the MetaGraphDef whenever possible. The default -value for `--inference_type` for SavedModels is `FLOAT`. +### Converting models prior to TensorFlow 1.9. -There is currently no support for MetaGraphDefs without a SignatureDef or for -MetaGraphDefs that use the [`assets/` -directory](https://www.tensorflow.org/programmers_guide/saved_model#structure_of_a_savedmodel_directory). +The recommended approach for using TOCO prior to TensorFlow 1.9 is the [Python +API](python_api.md#pre-tensorflow-1.9). If a command line tool is desired, the +`toco` command line tool was available in TensorFlow 1.7. Enter `toco --help` in +Terminal for additional details on the command-line flags available. There were +no command line tools in TensorFlow 1.8. + +## Basic examples + +The following section shows examples of how to convert a basic float-point model +from each of the supported data formats into a TensorFlow Lite FlatBuffers. -## Convert a TensorFlow GraphDef to TensorFlow Lite for float inference +### Convert a TensorFlow GraphDef The follow example converts a basic TensorFlow GraphDef (frozen by [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)) @@ -69,19 +68,54 @@ graphs contain the variables stored in Checkpoint files as Const ops. ``` curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ +tflite_convert \ --output_file=/tmp/foo.tflite \ - --inference_type=FLOAT \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 + --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ + --input_arrays=input \ + --output_arrays=MobilenetV1/Predictions/Reshape_1 +``` + +The value for `input_shapes` is automatically determined whenever possible. + +### Convert a TensorFlow SavedModel + +The follow example converts a basic TensorFlow SavedModel into a Tensorflow Lite +FlatBuffer to perform floating-point inference. + +``` +tflite_convert \ + --output_file=/tmp/foo.tflite \ + --saved_model_dir=/tmp/saved_model +``` + +[SavedModel](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators) +has fewer required flags than frozen graphs due to access to additional data +contained within the SavedModel. The values for `--input_arrays` and +`--output_arrays` are an aggregated, alphabetized list of the inputs and outputs +in the [SignatureDefs](https://www.tensorflow.org/serving/signature_defs) within +the +[MetaGraphDef](https://www.tensorflow.org/guide/saved_model#apis_to_build_and_load_a_savedmodel) +specified by `--saved_model_tag_set`. As with the GraphDef, the value for +`input_shapes` is automatically determined whenever possible. + +There is currently no support for MetaGraphDefs without a SignatureDef or for +MetaGraphDefs that use the [`assets/` +directory](https://www.tensorflow.org/guide/saved_model#structure_of_a_savedmodel_directory). + +### Convert a tf.Keras model + +The following example converts a `tf.keras` model into a TensorFlow Lite +Flatbuffer. The `tf.keras` file must contain both the model and the weights. + +``` +tflite_convert \ + --output_file=/tmp/foo.tflite \ + --keras_model_file=/tmp/keras_model.h5 ``` ## Quantization -### Convert a TensorFlow GraphDef to TensorFlow Lite for quantized inference +### Convert a TensorFlow GraphDef for quantized inference TOCO is compatible with fixed point quantization models described [here](https://www.tensorflow.org/performance/quantization). These are float @@ -95,18 +129,14 @@ The following command generates a quantized TensorFlow Lite FlatBuffer from a "quantized" TensorFlow GraphDef. ``` -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/some_quantized_graph.pb \ +tflite_convert \ --output_file=/tmp/foo.tflite \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ + --graph_def_file=/tmp/some_quantized_graph.pb \ --inference_type=QUANTIZED_UINT8 \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 \ - --mean_value=128 \ - --std_value=127 + --input_arrays=input \ + --output_arrays=MobilenetV1/Predictions/Reshape_1 \ + --mean_values=128 \ + --std_dev_values=127 ``` ### Use \"dummy-quantization\" to try out quantized inference on a float graph @@ -124,45 +154,20 @@ a reasonable guess is that most activation ranges should be contained in [0, 6]. ``` curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ +tflite_convert \ --output_file=/tmp/foo.cc \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ + --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ --inference_type=QUANTIZED_UINT8 \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 \ + --input_arrays=input \ + --output_arrays=MobilenetV1/Predictions/Reshape_1 \ --default_ranges_min=0 \ --default_ranges_max=6 \ - --mean_value=127.5 \ - --std_value=127.5 + --mean_values=128 \ + --std_dev_values=127 ``` ## Specifying input and output arrays -### Multiple output arrays - -The flag `output_arrays` takes in a comma-separated list of output arrays as -seen in the example below. This is useful for models or subgraphs with multiple -outputs. - -``` -curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \ - | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \ - --output_file=/tmp/foo.tflite \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ - --inference_type=FLOAT \ - --input_shape=1,224,224,3 \ - --input_array=input \ - --output_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu -``` - ### Multiple input arrays The flag `input_arrays` takes in a comma-separated list of input arrays as seen @@ -172,21 +177,33 @@ inputs. ``` curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \ | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \ +tflite_convert \ + --graph_def_file=/tmp/inception_v1_2016_08_28_frozen.pb \ --output_file=/tmp/foo.tflite \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ - --inference_type=FLOAT \ --input_shapes=1,28,28,96:1,28,28,16:1,28,28,192:1,28,28,64 \ --input_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_3/MaxPool_0a_3x3/MaxPool,InceptionV1/InceptionV1/Mixed_3b/Branch_0/Conv2d_0a_1x1/Relu \ - --output_array=InceptionV1/Logits/Predictions/Reshape_1 + --output_arrays=InceptionV1/Logits/Predictions/Reshape_1 ``` Note that `input_shapes` is provided as a colon-separated list. Each input shape corresponds to the input array at the same position in the respective list. +### Multiple output arrays + +The flag `output_arrays` takes in a comma-separated list of output arrays as +seen in the example below. This is useful for models or subgraphs with multiple +outputs. + +``` +curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \ + | tar xzv -C /tmp +tflite_convert \ + --graph_def_file=/tmp/inception_v1_2016_08_28_frozen.pb \ + --output_file=/tmp/foo.tflite \ + --input_arrays=input \ + --output_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu +``` + ### Specifying subgraphs Any array in the input file can be specified as an input or output array in @@ -201,115 +218,57 @@ GraphDef. ``` curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \ | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \ +tflite_convert \ + --graph_def_file=/tmp/inception_v1_2016_08_28_frozen.pb \ --output_file=/tmp/foo.pb \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TENSORFLOW_GRAPHDEF \ --input_shapes=1,28,28,96:1,28,28,16:1,28,28,192:1,28,28,64 \ --input_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_3/MaxPool_0a_3x3/MaxPool,InceptionV1/InceptionV1/Mixed_3b/Branch_0/Conv2d_0a_1x1/Relu \ - --output_array=InceptionV1/InceptionV1/Mixed_3b/concat_v2 + --output_arrays=InceptionV1/InceptionV1/Mixed_3b/concat_v2 ``` -Note that the final representation of an on-device inference workload (say, in -TensorFlow Lite FlatBuffers format) tends to have coarser granularity than the -very fine granularity of the TensorFlow GraphDef representation. For example, -while a fully-connected layer is typically represented as at least four separate -ops in TensorFlow GraphDef (Reshape, MatMul, BiasAdd, Relu...), it is typically -represented as a single "fused" op (FullyConnected) in the converter's optimized -representation and in the final on-device representation (e.g. in TensorFlow -Lite FlatBuffer format). As the level of granularity gets coarser, some +Note that the final representation in TensorFlow Lite FlatBuffers tends to have +coarser granularity than the very fine granularity of the TensorFlow GraphDef +representation. For example, while a fully-connected layer is typically +represented as at least four separate ops in TensorFlow GraphDef (Reshape, +MatMul, BiasAdd, Relu...), it is typically represented as a single "fused" op +(FullyConnected) in the converter's optimized representation and in the final +on-device representation. As the level of granularity gets coarser, some intermediate arrays (say, the array between the MatMul and the BiasAdd in the -TensorFlow GraphDef) are dropped. When specifying intermediate arrays as -`--input_arrays` / `--output_arrays`, it is desirable (and often required) to -specify arrays that are meant to survive in the final form of the graph, after -fusing. These are typically the outputs of activation functions (since -everything in each layer until the activation function tends to get fused). - -## Other conversions supported by TOCO - -The converter accepts both TENSORFLOW_GRAPHDEF and TFLITE file formats as both -`--input_format` and `--output_format`. This means that conversion to and from -any supported format is possible. - -### Optimize a TensorFlow GraphDef - -Same-format "conversions" can be used to optimize and simplify a graph or be -used to [get a subgraph](#specifying-subgraphs) of a graph. The flag -`--inference_type` is not required because TensorFlow graphs, including those -containing the -[`FakeQuant*`](https://www.tensorflow.org/api_guides/python/array_ops#Fake_quantization) -ops are always float graphs. - -``` -curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ - | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ - --output_file=/tmp/foo.pb \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TENSORFLOW_GRAPHDEF \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 -``` +TensorFlow GraphDef) are dropped. -### Convert a TensorFlow Lite FlatBuffer back into TensorFlow GraphDef format - -The converter supports file format conversions from TensorFlow Lite, back into -TensorFlow GraphDef format. - -``` -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/foo.tflite \ - --output_file=/tmp/foo.pb \ - --input_format=TFLITE \ - --output_format=TENSORFLOW_GRAPHDEF \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 -``` +When specifying intermediate arrays as `--input_arrays` and `--output_arrays`, +it is desirable (and often required) to specify arrays that are meant to survive +in the final form of the graph, after fusing. These are typically the outputs of +activation functions (since everything in each layer until the activation +function tends to get fused). ## Logging -### Graph "video" logging - -When `--dump_graphviz=` is used (see the section on [graph -visualizations](#graph-visualizations)), one may additionally pass -`--dump_graphviz_video`, which causes a graph visualization to be dumped after -each individual graph transformation. This results in thousands of files. -Typically, one would then bisect into these files to understand when a given -change was introduced in the graph. ## Graph visualizations TOCO can export a graph to the GraphViz Dot format for easy visualization via -either the `--output_format` flag or the `--dump_graphviz` flag. The subsections -below outline the use cases for each. +either the `--output_format` flag or the `--dump_graphviz_dir` flag. The +subsections below outline the use cases for each. ### Using `--output_format=GRAPHVIZ_DOT` The first way to get a graphviz rendering is to pass `GRAPHVIZ_DOT` into `--output_format`. This results in a plausible visualization of the graph. This -reduces the requirements that normally exist during conversion between other -input and output formats. For example, this may be useful if conversion from -TENSORFLOW_GRAPHDEF to TFLITE is failing. +reduces the requirements that exist during conversion between other input and +output formats. This may be useful if conversion from TENSORFLOW_GRAPHDEF to +TFLITE is failing. ``` curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ +tflite_convert \ + --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ --output_file=/tmp/foo.dot \ - --input_format=TENSORFLOW_GRAPHDEF \ --output_format=GRAPHVIZ_DOT \ --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 + --input_arrays=input \ + --output_arrays=MobilenetV1/Predictions/Reshape_1 ``` The resulting `.dot` file can be rendered into a PDF as follows: @@ -330,49 +289,35 @@ Example PDF files are viewable online in the next section. ### Using `--dump_graphviz` -The second way to get a graphviz rendering is to pass the `--dump_graphviz=` +The second way to get a graphviz rendering is to pass the `--dump_graphviz_dir` flag, specifying a destination directory to dump GraphViz rendering to. Unlike -the previous approach, this one allows you to keep your real command-line (with -your real `--output_format` and other flags) unchanged, just appending a -`--dump_graphviz=` flag to it. This provides a visualization of the actual graph -during a specific conversion process. +the previous approach, this one retains the original output format. This +provides a visualization of the actual graph resulting from a specific +conversion process. ``` curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ +tflite_convert \ + --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ --output_file=/tmp/foo.tflite \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ - --inference_type=FLOAT \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 \ - --dump_graphviz=/tmp -``` - -This generates a few files in the destination directory, here `/tmp`. The two -most important files are: - -``` -/tmp/toco_AT_IMPORT.dot -/tmp/toco_AFTER_TRANSFORMATIONS.dot + --input_arrays=input \ + --output_arrays=MobilenetV1/Predictions/Reshape_1 \ + --dump_graphviz_dir=/tmp ``` -`toco_AT_IMPORT.dot` represents the graph as it was imported from -`--input_file`, before any transformation was applied to it (besides some -transformations that are applied immediately while importing). This tends to be -a complex visualization with limited information, but is useful especially in -situations where a conversion command fails (this file is generated even if the -conversion subsequently fails). +This generates a few files in the destination directory. The two most important +files are `toco_AT_IMPORT.dot` and `/tmp/toco_AFTER_TRANSFORMATIONS.dot`. +`toco_AT_IMPORT.dot` represents the original graph containing only the +transformations done at import time. This tends to be a complex visualization +with limited information about each node. It is useful in situations where a +conversion command fails. `toco_AFTER_TRANSFORMATIONS.dot` represents the graph after all transformations -were applied to it, just before it was exported to the `--output_file`. -Typically, this is a much smaller graph with more information about each node. +were applied to it, just before it is exported. Typically, this is a much +smaller graph with more information about each node. -Again, these can be rendered to PDFs: +As before, these can be rendered to PDFs: ``` dot -Tpdf -O /tmp/toco_*.dot @@ -383,6 +328,14 @@ Sample output files can be seen here: * [toco_AT_IMPORT.dot.pdf](https://storage.googleapis.com/download.tensorflow.org/example_images/toco_AT_IMPORT.dot.pdf) * [toco_AFTER_TRANSFORMATIONS.dot.pdf](https://storage.googleapis.com/download.tensorflow.org/example_images/toco_AFTER_TRANSFORMATIONS.dot.pdf). +### Graph "video" logging + +When `--dump_graphviz_dir` is used, one may additionally pass +`--dump_graphviz_video`. This causes a graph visualization to be dumped after +each individual graph transformation, resulting in thousands of files. +Typically, one would then bisect into these files to understand when a given +change was introduced in the graph. + ### Legend for the graph visualizations * Operators are red square boxes with the following hues of red: diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md index 8085ae07489816c38677ff792e7ac71f1a75fa71..decc8a45a40ffba2a27320ce8391b1916391d744 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md @@ -1,7 +1,8 @@ # TensorFlow Lite Optimizing Converter command-line glossary -This page is complete reference of command-line flags. It is complemented by the -following other documents: +This page is complete reference of command-line flags used by TOCO's command +line starting from TensorFlow 1.9 up until the most recent build of TensorFlow. +It is complemented by the following other documents: * [README](../README.md) * [Command-line examples](cmdline_examples.md) @@ -16,116 +17,81 @@ Table of contents: ## High-level flags -The following high level flags specify the location of the input and output +The following high level flags specify the details of the input and output files. The flag `--output_file` is always required. Additionally, either -`--input_file` or `--savedmodel_directory` is required. - -* `--savedmodel_directory`. Type: string. Specifies the full path to the - directory containing the SavedModel. -* `--savedmodel_tagset`. Type: string. Default: +`--graph_def_file`, `--saved_model_dir` or `--keras_model_file` is required. + +* `--output_file`. Type: string. Specifies the full path of the output file. +* `--graph_def_file`. Type: string. Specifies the full path of the input + GraphDef file frozen using + [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py). +* `--saved_model_dir`. Type: string. Specifies the full path to the directory + containing the SavedModel. +* `--keras_model_file`. Type: string. Specifies the full path of the HDF5 file + containing the tf.keras model. +* `--output_format`. Type: string. Default: `TFLITE`. Specifies the format of + the output file. Allowed values: + * `TFLITE`: TensorFlow Lite FlatBuffer format. + * `GRAPHVIZ_DOT`: GraphViz `.dot` format containg a visualization of the + graph after graph transformations. + * Note that passing `GRAPHVIZ_DOT` to `--output_format` leads to loss + of TFLite specific transformations. Therefore, the resulting + visualization may not reflect the final set of graph + transformations. To get a final visualization with all graph + transformations use `--dump_graphviz` instead. + +The following flags specify optional parameters when using SavedModels. + +* `--saved_model_tag_set`. Type: string. Default: [kSavedModelTagServe](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/tag_constants.h). Specifies a comma-separated set of tags identifying the MetaGraphDef within the SavedModel to analyze. All tags in the tag set must be specified. -* `--input_file`. Type: string. Specifies the path of the input file. This may - be either an absolute or a relative path. -* `--output_file`. Type: string. Specifies the path of the output file. - -The following high level flags specify the types of the input and output files: - -* `--input_format`. Type: string. Default: `TENSORFLOW_GRAPHDEF`. Specifies - the format of the input file. Allowed values: - * `TENSORFLOW_GRAPHDEF` — The TensorFlow GraphDef format. Both - binary and text proto formats are allowed. - * `TFLITE` — The TensorFlow Lite FlatBuffers format. -* `--output_format`. Type: string. Default: `TFLITE`. Specifies the format of - the output file. Allowed values: - * `TENSORFLOW_GRAPHDEF` — The TensorFlow GraphDef format. Always - produces a file in binary (not text) proto format. - * `TFLITE` — The TensorFlow Lite FlatBuffers format. - * Whether a float or quantized TensorFlow Lite file will be produced - depends on the `--inference_type` flag. - * `GRAPHVIZ_DOT` — The GraphViz `.dot` format. This asks the - converter to generate a reasonable graphical representation of the graph - after simplification by a generic set of transformation. - * A typical `dot` command line to view the resulting graph might look - like: `dot -Tpdf -O file.dot`. - * Note that since passing this `--output_format` means losing the - information of which output format you actually care about, and - since the converter's transformations depend on the specific output - format, the resulting visualization may not fully reflect what you - would get on the actual output format that you are using. To avoid - that concern, and generally to get a visualization of exactly what - you get in your actual output format as opposed to just a merely - plausible visualization of a model, consider using `--dump_graphviz` - instead and keeping your true `--output_format`. +* `--saved_model_signature_key`. Type: string. Default: + [DEFAULT_SERVING_SIGNATURE_DEF_KEY](https://www.tensorflow.org/api_docs/python/tf/saved_model/signature_constants). + Specifies the key identifying the SignatureDef containing inputs and + outputs. ## Model flags *Model flags* provide additional information about the model stored in the input file. -* `--output_array`. Type: string. Specifies a single array as the output - activations. Incompatible with `--output_arrays`. -* `--output_arrays`. Type: comma-separated list of strings. Specifies a list - of arrays as the output activations, for models with multiple outputs. - Incompatible with `--output_array`. -* `--input_array`. Type: string. Specifies a single array as the input - activations. Incompatible with `--input_arrays`. -* `--input_arrays`. Type: comma-separated list of strings. Specifies a list of - arrays as the input activations, for models with multiple inputs. - Incompatible with `--input_array`. -* `--batch_size`. Type: integer. Default: 1. Specifies the batch size for the - model. Replaces the first dimension of an input size array if undefined. Use - only with SavedModels when neither `--input_shape` nor `input_shapes` flags - are specified. Incompatible with GraphDefs. - -When `--input_array` is used, the following flags are available to provide -additional information about the single input array: - -* `--input_shape`. Type: comma-separated list of integers. Specifies the shape - of the input array, in TensorFlow convention: starting with the outer-most - dimension (the dimension corresponding to the largest offset stride in the - array layout), ending with the inner-most dimension (the dimension along - which array entries are typically laid out contiguously in memory). - * For example, a typical vision model might pass - `--input_shape=1,60,80,3`, meaning a batch size of 1 (no batching), an - input image height of 60, an input image width of 80, and an input image - depth of 3, for the typical case where the input image is a RGB bitmap - (3 channels, depth=3) stored by horizontal scanlines (so 'width' is the - next innermost dimension after 'depth'). -* `--mean_value` and `--std_value`. Type: floating-point. The decimal point - character is always the dot (`.`) regardless of the locale. These specify - the (de-)quantization parameters of the input array, when it is quantized. - * The meaning of mean_value and std_value is as follows: each quantized - value in the quantized input array will be interpreted as a mathematical - real number (i.e. as an input activation value) according to the - following formula: +* `--input_arrays`. Type: comma-separated list of strings. Specifies the list + of names of input activation tensors. +* `--output_arrays`. Type: comma-separated list of strings. Specifies the list + of names of output activation tensors. + +The following flags define properties of the input tensors. Each item in the +`--input_arrays` flag should correspond to each item in the following flags +based on index. + +* `--input_shapes`. Type: colon-separated list of comma-separated lists of + integers. Each comma-separated list of integers gives the shape of one of + the input arrays specified in [TensorFlow + convention](https://www.tensorflow.org/versions/r1.2/programmers_guide/dims_types#shape). + * Example: `--input_shapes=1,60,80,3` for a typical vision model means a + batch size of 1, an input image height of 60, an input image width of + 80, and an input image depth of 3 (representing RGB channels). + * Example: `--input_arrays=foo,bar --input_shapes=2,3:4,5,6` means "foo" + has a shape of [2, 3] and "bar" has a shape of [4, 5, 6]. +* `--std_dev_values`, `--mean_values`. Type: comma-separated list of integers. + These specify the (de-)quantization parameters of the input array, when it + is quantized. + * The meaning of `mean_values` and `std_dev_values` is as follows: each + quantized value in the quantized input array will be interpreted as a + mathematical real number (i.e. as an input activation value) according + to the following formula: * `real_value = (quantized_input_value - mean_value) / std_value`. * When performing float inference (`--inference_type=FLOAT`) on a quantized input, the quantized input would be immediately dequantized by the inference code according to the above formula, before proceeding with float inference. * When performing quantized inference - (`--inference_type=QUANTIZED_UINT8`), no dequantization is ever to be - performed by the inference code; however, the quantization parameters of - all arrays, including those of the input arrays as specified by - mean_value and std_value, all participate in the determination of the - fixed-point multipliers used in the quantized inference code. - -When `--input_arrays` is used, the following flags are available to provide -additional information about the multiple input arrays: - -* `--input_shapes`. Type: colon-separated list of comma-separated lists of - integers. Each comma-separated list of integer gives the shape of one of the - input arrays specified in `--input_arrays`, in the same order. See - `--input_shape` for details. - * Example: `--input_arrays=foo,bar --input_shapes=2,3:4,5,6` means that - there are two input arrays. The first one, "foo", has shape [2,3]. The - second one, "bar", has shape [4,5,6]. -* `--mean_values`, `--std_values`. Type: comma-separated lists of - floating-point numbers. Each number gives the corresponding value for one of - the input arrays specified in `--input_arrays`, in the same order. See - `--mean_value`, `--std_value` for details. + (`--inference_type=QUANTIZED_UINT8`), no dequantization is performed by + the inference code. However, the quantization parameters of all arrays, + including those of the input arrays as specified by `mean_value` and + `std_dev_value`, determine the fixed-point multipliers used in the + quantized inference code. ## Transformation flags @@ -133,21 +99,13 @@ additional information about the multiple input arrays: the graph, i.e. they specify requested properties that the output file should have. -* `--inference_type`. Type: string. Sets the type of real-number arrays in the - output file, that is, controls the representation (quantization) of real - numbers in the output file, except for input arrays, which are controlled by - `--inference_input_type`. - - This flag only impacts real-number arrays. By "real-number" we mean float - arrays, and quantized arrays. This excludes plain integer arrays, strings - arrays, and every other data type. +* `--inference_type`. Type: string. Default: `FLOAT`. Data type of all + real-number arrays in the output file except for input arrays (defined by + `--inference_input_type`). Must be `{FLOAT, QUANTIZED_UINT8}`. - For real-number arrays, the impact of this flag is to allow the output file - to choose a different real-numbers representation (quantization) from what - the input file used. For any other types of arrays, changing the data type - would not make sense. - - Specifically: + This flag only impacts real-number arrays including float and quantized + arrays. This excludes all other data types including plain integer arrays + and string arrays. Specifically: * If `FLOAT`, then real-numbers arrays will be of type float in the output file. If they were quantized in the input file, then they get @@ -155,66 +113,54 @@ have. * If `QUANTIZED_UINT8`, then real-numbers arrays will be quantized as uint8 in the output file. If they were float in the input file, then they get quantized. - * If not set, then all real-numbers arrays retain the same type in the - output file as they have in the input file. - -* `--inference_input_type`. Type: string. Similar to inference_type, but - allows to control specifically the quantization of input arrays, separately - from other arrays. - - If not set, then the value of `--inference_type` is implicitly used, i.e. by - default input arrays are quantized like other arrays. - - Like `--inference_type`, this only affects real-number arrays. By - "real-number" we mean float arrays, and quantized arrays. This excludes - plain integer arrays, strings arrays, and every other data type. - - The typical use for this flag is for vision models taking a bitmap as input, - typically with uint8 channels, yet still requiring floating-point inference. - For such image models, the uint8 input is quantized, i.e. the uint8 values - are interpreted as real numbers, and the quantization parameters used for - such input arrays are their `mean_value`, `std_value` parameters. - -* `--default_ranges_min`, `--default_ranges_max`. Type: floating-point. The - decimal point character is always the dot (`.`) regardless of the locale. - These flags enable what is called "dummy quantization". If defined, their - effect is to define fallback (min, max) range values for all arrays that do - not have a properly specified (min, max) range in the input file, thus - allowing to proceed with quantization of non-quantized or - incorrectly-quantized input files. This enables easy performance prototyping - ("how fast would my model run if I quantized it?") but should never be used - in production as the resulting quantized arithmetic is inaccurate. - -* `--drop_fake_quant`. Type: boolean. Default: false. Causes fake-quantization - nodes to be dropped from the graph. This may be used to recover a plain - float graph from a fake-quantized graph. - -* `--reorder_across_fake_quant`. Type: boolean. Default: false. Normally, - fake-quantization nodes must be strict boundaries for graph transformations, - in order to ensure that quantized inference has the exact same arithmetic - behavior as quantized training --- which is the whole point of quantized - training and of FakeQuant nodes in the first place. However, that entails - subtle requirements on where exactly FakeQuant nodes must be placed in the - graph. Some quantized graphs have FakeQuant nodes at unexpected locations, - that prevent graph transformations that are necessary in order to generate a - well-formed quantized representation of these graphs. Such graphs should be - fixed, but as a temporary work-around, setting this - reorder_across_fake_quant flag allows the converter to perform necessary - graph transformations on them, at the cost of no longer faithfully matching - inference and training arithmetic. - -* `--quantize_weights`. Type: boolean. Default: false. Store weights as - quantized weights followed by dequantize operations. Computation is still - done in float, but reduces model size (at the cost of accuracy and latency). + +* `--inference_input_type`. Type: string. Data type of a real-number input + array in the output file. By default the `--inference_type` is used as type + of all of the input arrays. Flag is primarily intended for generating a + float-point graph with a quantized input array. A Dequantized operator is + added immediately after the input array. Must be `{FLOAT, QUANTIZED_UINT8}`. + + The flag is typically used for vision models taking a bitmap as input but + requiring floating-point inference. For such image models, the uint8 input + is quantized and the quantization parameters used for such input arrays are + their `mean_value` and `std_dev_value` parameters. + +* `--default_ranges_min`, `--default_ranges_max`. Type: floating-point. + Default value for the (min, max) range values used for all arrays without a + specified range. Allows user to proceed with quantization of non-quantized + or incorrectly-quantized input files. These flags produce models with low + accuracy. They are intended for easy experimentation with quantization via + "dummy quantization". + +* `--drop_control_dependency`. Type: boolean. Default: True. Indicates whether + to drop control dependencies silently. This is due to TensorFlow Lite not + supporting control dependencies. + +* `--reorder_across_fake_quant`. Type: boolean. Default: False. Indicates + whether to reorder FakeQuant nodes in unexpected locations. Used when the + location of the FakeQuant nodes is preventing graph transformations + necessary to convert the graph. Results in a graph that differs from the + quantized training graph, potentially causing differing arithmetic behavior. + +* `--allow_custom_ops`. Type: string. Default: False. Indicates whether to + allow custom operations. When false, any unknown operation is an error. When + true, custom ops are created for any op that is unknown. The developer will + need to provide these to the TensorFlow Lite runtime with a custom resolver. + +* `--quantize_weights`. Type: boolean. Default: False. Indicates whether to + store weights as quantized weights followed by dequantize operations. + Computation is still done in float, but reduces model size (at the cost of + accuracy and latency). ## Logging flags -The following flags allow to generate graph visualizations of the actual graph -at various points during transformations: +The following flags generate graph visualizations of the graph as +[GraphViz](https://www.graphviz.org/) `.dot` files at various points during +graph transformations: -* `--dump_graphviz=/path` enables dumping of the graphs at various stages of - processing as GraphViz `.dot` files. Generally preferred over - `--output_format=GRAPHVIZ_DOT` as this allows you to keep your actually - relevant `--output_format`. -* `--dump_graphviz_video` enables dumping of the graph after every single - graph transformation (for debugging purposes). +* `--dump_graphviz_dir`. Type: string. Specifies the full path of the + directory to output GraphViz `.dot` files. Outputs the graph immediately + after reading in the graph and after all of the transformations have been + completed. +* `--dump_graphviz_video`. Type: boolean. Outputs GraphViz after every graph + transformation. Requires `--dump_graphviz_dir` to be specified. diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md index a7841a685528fb18bb08f1943278339a2daec16a..3799eac0a1181afe3b63d2f8651745c2ec61f5e0 100644 --- a/tensorflow/contrib/lite/toco/g3doc/python_api.md +++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md @@ -15,11 +15,15 @@ Table of contents: * [Exporting a GraphDef from tf.Session](#basic-graphdef-sess) * [Exporting a GraphDef from file](#basic-graphdef-file) * [Exporting a SavedModel](#basic-savedmodel) + * [Exporting a tf.keras File](#basic-keras-file) * [Complex examples](#complex) * [Exporting a quantized GraphDef](#complex-quant) * [TensorFlow Lite Python interpreter](#interpreter) * [Using the interpreter from a model file](#interpreter-file) * [Using the interpreter from model data](#interpreter-data) +* [Additional instructions](#additional-instructions) + * [Build from source code](#latest-package) + * [Converting models prior to TensorFlow 1.9.](#pre-tensorflow-1.9) ## High-level overview @@ -31,15 +35,17 @@ designing a model that can be targeted to devices with mobile. ## API -The API for converting TensorFlow models to TensorFlow Lite is -`tf.contrib.lite.TocoConverter`. The API for calling the Python intepreter is +The API for converting TensorFlow models to TensorFlow Lite as of TensorFlow 1.9 +is `tf.contrib.lite.TocoConverter`. The API for calling the Python intepreter is `tf.contrib.lite.Interpreter`. `TocoConverter` provides class methods based on the original format of the model. `TocoConverter.from_session()` is available for GraphDefs. -`TocoConverter.from_saved_model()` is available for SavedModels. Example usages -for simple float-point models are shown in [Basic Examples](#basic). Examples -usages for more complex models is shown in [Complex Examples](#complex). +`TocoConverter.from_saved_model()` is available for SavedModels. +`TocoConverter.from_keras_model_file()` is available for `tf.Keras` files. +Example usages for simple float-point models are shown in [Basic +Examples](#basic). Examples usages for more complex models is shown in [Complex +Examples](#complex). **NOTE**: Currently, `TocoConverter` will cause a fatal error to the Python interpreter when the conversion fails. This will be remedied as soon as @@ -111,6 +117,51 @@ For more complex SavedModels, the optional parameters that can be passed into `output_arrays`, `tag_set` and `signature_key`. Details of each parameter are available by running `help(tf.contrib.lite.TocoConverter)`. +### Exporting a tf.keras File + +The following example shows how to convert a `tf.keras` model into a TensorFlow +Lite FlatBuffer. + +```python +import tensorflow as tf + +converter = tf.contrib.lite.TocoConverter.from_keras_model_file("keras_model.h5") +tflite_model = converter.convert() +open("converted_model.tflite", "wb").write(tflite_model) +``` + +The `tf.keras` file must contain both the model and the weights. A comprehensive +example including model construction can be seen below. + +```python +import numpy as np +import tensorflow as tf + +# Generate tf.keras model. +model = tf.keras.models.Sequential() +model.add(tf.keras.layers.Dense(2, input_shape=(3,))) +model.add(tf.keras.layers.RepeatVector(3)) +model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3))) +model.compile(loss=tf.keras.losses.MSE, + optimizer=tf.keras.optimizers.RMSprop(lr=0.0001), + metrics=[tf.keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + +x = np.random.random((1, 3)) +y = np.random.random((1, 3, 3)) +model.train_on_batch(x, y) +model.predict(x) + +# Save tf.keras model in HDF5 format. +keras_file = "keras_model.h5" +tf.keras.models.save_model(model, keras_file) + +# Convert to TensorFlow Lite model. +converter = tf.contrib.lite.TocoConverter.from_keras_model_file(keras_file) +tflite_model = converter.convert() +open("converted_model.tflite", "wb").write(tflite_model) +``` + ## Complex examples For models where the default value of the attributes is not sufficient, the @@ -200,3 +251,18 @@ with tf.Session() as sess: interpreter = tf.contrib.lite.Interpreter(model_content=tflite_model) interpreter.allocate_tensors() ``` + +## Additional instructions + +### Build from source code + +In order to run the latest version of the TOCO Python API, clone the TensorFlow +repository, configure the installation, and build and install the pip package. +Detailed instructions are available +[here](https://www.tensorflow.org/install/install_sources). + +### Converting models prior to TensorFlow 1.9. + +To use TOCO in TensorFlow 1.7 and TensorFlow 1.8, use the `toco_convert` +function. Run `help(tf.contrib.lite.toco_convert)` to get details about accepted +parameters. diff --git a/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg b/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg index a47c088991299159be39bc490149720dae43eb53..262e13a591b998c4f38f0a9f44a5b385f612df90 100644 --- a/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg +++ b/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc index 56f48d47de4e86ece76ceef1d09a25f50957a8dc..310a88484c246b8035aa73b5e04ad677d575e4c4 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc @@ -40,11 +40,6 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) { // Yield until input dims have been resolved. return false; } - if (input_array.shape().dimensions_count() == 0) { - // Input array cannot be 0-D. - // (Unsure if this is TF behavior, but was required to get a test to pass.) - return false; - } const auto& axis_array = model->GetArray(expand_op->inputs[1]); if (!axis_array.has_shape()) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc index 0fffab574ddd8ad75ec07ae4442f363a36ed289e..1ea83abf8eb1b49f649e81def29857094cd0c2d7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc @@ -38,6 +38,16 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { // Depthwise conv does not support dilation return false; } + auto& input_array = model->GetArray(conv_op->inputs[0]); + if (!input_array.has_shape()) { + // Shapes not propagated yet + return false; + } + if (input_array.shape().dims(3) != 1) { + // Not a pure convolution: Conv does accumulation across the depth + // dimension. + return false; + } auto& weights_array = model->GetArray(conv_op->inputs[1]); if (!weights_array.buffer) { // Yield until the weights are resolved as a constant array. @@ -46,11 +56,6 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { if (weights_array.data_type != ArrayDataType::kFloat) { return false; } - if (weights_array.shape().dims(3) != 1) { - // Not a pure convolution: Conv does accumulation across the depth - // dimension. - return false; - } // At this point we know we have a pure conv. Rewrite it as DepthwiseConv. AddMessageF( "%s is purely convolutional (input/weights depth is 1), replacing it by " diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_stack_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc similarity index 72% rename from tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_stack_to_reshape.cc rename to tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc index 0615b5e6c6db910ee847188427b416fd812aa141..75113a2a8c7c446bd13de8b5c1a8d8ef3cf7fdd6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_stack_to_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc @@ -25,19 +25,19 @@ limitations under the License. namespace toco { -bool ConvertTrivialStackToReshape::Run(Model* model, std::size_t op_index) { - auto stack_it = model->operators.begin() + op_index; - if (stack_it->get()->type != OperatorType::kStack) { +bool ConvertTrivialPackToReshape::Run(Model* model, std::size_t op_index) { + auto pack_it = model->operators.begin() + op_index; + if (pack_it->get()->type != OperatorType::kPack) { return false; } - auto* stack_op = static_cast(stack_it->get()); - if (stack_op->inputs.size() > 1) { + auto* pack_op = static_cast(pack_it->get()); + if (pack_op->inputs.size() > 1) { // Not trivial. return false; } - CHECK_EQ(stack_op->outputs.size(), 1); + CHECK_EQ(pack_op->outputs.size(), 1); - const auto& input_array = model->GetArray(stack_op->inputs[0]); + const auto& input_array = model->GetArray(pack_op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. return false; @@ -48,16 +48,16 @@ bool ConvertTrivialStackToReshape::Run(Model* model, std::size_t op_index) { return false; } - AddMessageF("Converting trivial %s to a reshape", LogName(*stack_op)); + AddMessageF("Converting trivial %s to a reshape", LogName(*pack_op)); // Note that we could convert to ExpandDims but toco prefers reshapes. auto* reshape_op = new TensorFlowReshapeOperator; - reshape_op->inputs = {stack_op->inputs[0]}; - reshape_op->outputs = stack_op->outputs; + reshape_op->inputs = {pack_op->inputs[0]}; + reshape_op->outputs = pack_op->outputs; // Create shape param. string shape_array_name = - AvailableArrayName(*model, stack_op->outputs[0] + "_shape"); + AvailableArrayName(*model, pack_op->outputs[0] + "_shape"); Array& shape_array = model->GetOrCreateArray(shape_array_name); *(shape_array.mutable_shape()->mutable_dims()) = { 1 + input_array.shape().dimensions_count()}; @@ -70,10 +70,10 @@ bool ConvertTrivialStackToReshape::Run(Model* model, std::size_t op_index) { } // Replace the operator in the graph. - const auto reshape_it = model->operators.emplace(stack_it, reshape_op); - stack_it = reshape_it + 1; - CHECK_EQ(stack_it->get(), stack_op); - model->operators.erase(stack_it); + const auto reshape_it = model->operators.emplace(pack_it, reshape_op); + pack_it = reshape_it + 1; + CHECK_EQ(pack_it->get(), pack_op); + model->operators.erase(pack_it); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc new file mode 100644 index 0000000000000000000000000000000000000000..b689be07926ecd9be4cc317735dc88eb90950e13 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc @@ -0,0 +1,94 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) { + auto tile_it = model->operators.begin() + op_index; + if (tile_it->get()->type != OperatorType::kTile) { + return false; + } + auto* tile_op = static_cast(tile_it->get()); + + const auto& input_array = model->GetArray(tile_op->inputs[0]); + const auto& multiples_array = model->GetArray(tile_op->inputs[1]); + const auto& output_array = model->GetArray(tile_op->outputs[0]); + if (!input_array.has_shape() || !multiples_array.has_shape() || + !output_array.has_shape()) { + // Yield until PropagateFixedSizes has been run on this op. + return false; + } + // Note: We can assume we have error checked inputs in PropagateFixedSizes. + + if (!multiples_array.buffer) { + // Yield until the multiples is constant. + return false; + } + std::vector const& multiples = + multiples_array.GetBuffer().data; + + // We can simplify the tile if only a single dimension is being multiplied. + // It then just becomes a concat along that dimension. + int non_one_dims = 0; + int concat_axis = 0; + for (int i = 0; i < multiples.size(); ++i) { + if (multiples[i] != 1) { + ++non_one_dims; + concat_axis = i; + } + } + if (non_one_dims != 1) { + // The tile is non-trivial. Good luck. + AddMessageF("Tile %s is non-trivial (has more than one multiply dimension)", + LogName(*tile_op)); + return false; + } + + // The tile is like a concat. + AddMessageF("Simplifying %s to a Concat along a single axis %d", + LogName(*tile_op), concat_axis); + + auto* concat_op = new ConcatenationOperator; + + // Copy input and output. + // Note that we multiply out the input by the number of times requested. + for (int i = 0; i < multiples[concat_axis]; ++i) { + concat_op->inputs.push_back(tile_op->inputs[0]); + } + concat_op->axis = concat_axis; + concat_op->outputs = tile_op->outputs; + + // Delete multiples array if unused. + if (IsDiscardableArray(*model, tile_op->inputs[1]) && + CountOpsWithInput(*model, tile_op->inputs[1]) == 1) { + model->EraseArray(tile_op->inputs[1]); + } + + // Replace the operator in the graph. + const auto concat_it = model->operators.emplace(tile_it, concat_op); + tile_it = concat_it + 1; + CHECK_EQ(tile_it->get(), tile_op); + model->operators.erase(tile_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc index 8ca2cd66ac6377a70a4c504ac006dc0388b88bf7..1e68cd678bce6c27f1852a5ae0c13362d8938cdd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc @@ -25,17 +25,12 @@ limitations under the License. namespace toco { -bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { - auto conv_it = model->operators.begin() + op_index; - if (conv_it->get()->type != OperatorType::kConv) { - return false; - } - auto* conv_op = static_cast(conv_it->get()); - if (conv_op->outputs.size() == 2) { +bool ProcessConvOperator(Model* model, ConvOperator* op) { + if (op->outputs.size() == 2) { // We already have an im2col array return false; } - const auto& weights_array = model->GetArray(conv_op->inputs[1]); + const auto& weights_array = model->GetArray(op->inputs[1]); if (!weights_array.has_shape()) { // We need to yield until weights dims have been resolved, because // from the weights dims we determine whether an im2col array is @@ -45,26 +40,52 @@ bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { const auto& weights_shape = weights_array.shape(); const int kheight = weights_shape.dims(1); const int kwidth = weights_shape.dims(2); - if (kwidth == 1 && kheight == 1 && conv_op->stride_width == 1 && - conv_op->stride_height == 1 && conv_op->dilation_width_factor == 1 && - conv_op->dilation_height_factor == 1) { + if (kwidth == 1 && kheight == 1 && op->stride_width == 1 && + op->stride_height == 1 && op->dilation_width_factor == 1 && + op->dilation_height_factor == 1) { // 1x1 unstrided undilated conv does not need an im2col array. return false; } // Create the im2col array. - CHECK_EQ(conv_op->outputs.size(), 1); + CHECK_EQ(op->outputs.size(), 1); const string& im2col_array_name = - AvailableArrayName(*model, conv_op->inputs[0] + "_im2col"); + AvailableArrayName(*model, op->inputs[0] + "_im2col"); model->GetOrCreateArray(im2col_array_name); - conv_op->outputs.push_back(im2col_array_name); - AddMessageF( - "Created an im2col array for %s, with %dx%d kernel and stride_width=%d, " - "stride_height=%d", - LogName(*conv_op), kwidth, kheight, conv_op->stride_width, - conv_op->stride_height); + op->outputs.push_back(im2col_array_name); return true; } +bool ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { + if (op->outputs.size() == 2) { + // We already have an im2col array + return false; + } + + // Always create an im2col array for transpose_conv. + CHECK_EQ(op->outputs.size(), 1); + const string& im2col_array_name = AvailableArrayName( + *model, op->inputs[TransposeConvOperator::DATA_INPUT] + "_im2col"); + model->GetOrCreateArray(im2col_array_name); + op->outputs.push_back(im2col_array_name); + + return true; +} + +bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + auto* op = it->get(); + + switch (op->type) { + case OperatorType::kConv: + return ProcessConvOperator(model, static_cast(op)); + case OperatorType::kTransposeConv: + return ProcessTransposeConvOperator( + model, static_cast(op)); + default: + return false; + } +} + } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc index 498c864bde6d656c8318e981204cb42cb3a4d03f..1688586733b0434c7fc98686a19f0ceb8092f33b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc @@ -111,7 +111,7 @@ bool DequantizeArray(const string& array_name, auto* op_outputting_array = GetOpWithOutput(*model, array_name); if (op_outputting_array) { - if (op_outputting_array->type == OperatorType::kTensorFlowReshape) { + if (op_outputting_array->type == OperatorType::kReshape) { return true; } } @@ -159,6 +159,7 @@ bool DequantizeArray(const string& array_name, new_array.GetOrCreateMinMax() = array->GetMinMax(); fakequant_op->minmax.reset(new MinMax); *fakequant_op->minmax = array->GetMinMax(); + fakequant_op->narrow_range = array->narrow_range; if (must_insert_fakequant_before) { for (const auto& op : model->operators) { for (string& output : op->outputs) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc index 708ecf6e0a96811ab274fbb25f748f562cd3afad..e80ed036b311cfc586c40ece410ef6a6432a0cd9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc @@ -26,17 +26,38 @@ namespace toco { namespace { +int GetOutputDepthFromWeights(const Model& model, const Operator& op) { + const string& weights_name = op.inputs[1]; + const auto& weights_shape = model.GetArray(weights_name).shape(); + if (op.type == OperatorType::kConv || + op.type == OperatorType::kFullyConnected) { + return weights_shape.dims(0); + } + if (op.type == OperatorType::kDepthwiseConv) { + return weights_shape.dims(3); + } + LOG(FATAL) << "Unhandled operator type"; + return 0; +} + bool ProcessLinearOperator(Model* model, Operator* op) { if (op->inputs.size() >= 3) { return false; } const string& output_name = op->outputs[0]; + const string& weights_name = op->inputs[1]; + if (!model->GetArray(weights_name).has_shape()) { + return false; + } + const int depth = GetOutputDepthFromWeights(*model, *op); const string& bias_name = AvailableArrayName(*model, output_name + "_bias"); op->inputs.push_back(bias_name); DCHECK_EQ(op->inputs.size(), 3); auto& bias_array = model->GetOrCreateArray(bias_name); bias_array.data_type = ArrayDataType::kFloat; - + bias_array.mutable_shape()->mutable_dims()->push_back(depth); + auto& bias_buffer = bias_array.GetMutableBuffer(); + bias_buffer.data.resize(depth, 0.f); return true; } } // namespace diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc index 394fa349e2663e2806344f27a96a5132a2d4a810..c13fc0de7502a9edc80dc399354708a5b1b96b02 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc @@ -122,7 +122,7 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model, case OperatorType::kFullyConnected: { weights_index = 1; const auto& fc_op = static_cast(op); - CHECK(!fc_op.experimental_shuffled_weights) + CHECK(fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) << "This graph transformation expects to run before FC weights get " "shuffled."; break; @@ -181,7 +181,7 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model, // future without worrying. static constexpr int kMinDistanceBetweenBadValues = 16; if (distance < kMinDistanceBetweenBadValues) { - if (allow_nudging_weights()) { + if (allow_nudging_weights() || has_default_ranges_flag()) { buffer_data[i] = 1; changed = true; continue; @@ -200,6 +200,15 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model, } if (changed) { + if (has_default_ranges_flag()) { + std::cerr + << "Since the specified values of --default_ranges_min and " + "--default_ranges_max result in values incompatible with TFLite's " + "fast int8 kernels, " + "--allow_nudging_weights_to_use_fast_gemm_kernel " + "has been enabled. This may affect the accuracy of the model." + << std::endl; + } AddMessageF("Tweaked weights values for %s", LogName(op)); } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc index 76c6be00d407ca30b898d088c9fa34cd7f76f656..19f166d35de6c304909c5bc76bdc8d3449aebcea 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc @@ -274,8 +274,10 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { return false; } - const auto& weights = model->GetArray(preceding_op->inputs[1]); - const auto& bias = model->GetArray(preceding_op->inputs[2]); + const auto& weights_name = preceding_op->inputs[1]; + const auto& bias_name = preceding_op->inputs[2]; + const auto& weights = model->GetArray(weights_name); + const auto& bias = model->GetArray(bias_name); if (binary_op->type == OperatorType::kAdd || binary_op->type == OperatorType::kSub) { if (!bias.buffer) { @@ -285,6 +287,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { LogName(*binary_op), LogName(*preceding_op)); return false; } + int count_ops_consuming_bias = CountOpsWithInput(*model, bias_name); + if (count_ops_consuming_bias > 1) { + AddMessageF( + "Not fusing %s because the bias of the preceding %s is consumed by " + "another op", + LogName(*binary_op), LogName(*preceding_op)); + return false; + } } else { if (!weights.buffer || !bias.buffer) { AddMessageF( @@ -293,6 +303,14 @@ bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { LogName(*binary_op), LogName(*preceding_op)); return false; } + int count_ops_consuming_weights = CountOpsWithInput(*model, weights_name); + if (count_ops_consuming_weights > 1) { + AddMessageF( + "Not fusing %s because the weights of the preceding %s is consumed " + "by another op", + LogName(*binary_op), LogName(*preceding_op)); + return false; + } } int count_ops_consuming_output = diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc new file mode 100644 index 0000000000000000000000000000000000000000..874d8def571fbce4219de15285c8df6fd2487a9a --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +// Returns true if the given op is strictly a broadcasting operation. +// This is commonly seen as a Concat of the same input multiple times, and is +// often generated from Tile ops that were converted via the +// convert_trivial_tile_to_concat transformation. +bool IsBroadcastingOp(const Model& model, Operator* op) { + // Concatenation of identical inputs is usually a broadcast. + if (op->type == OperatorType::kConcatenation) { + // Verify that all inputs are the same. + for (int i = 1; i < op->inputs.size(); ++i) { + if (op->inputs[i] != op->inputs[0]) { + return false; + } + } + return true; + } + + // There are other things we could look for (Stack/etc) when needed. + return false; +} + +} // namespace + +// Finds an operation that looks like a broadcast (concat of the same sources +// along the last dimension) and drops it by relying on the ability of certain +// binary ops to perform an implicit broadcast. +bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + auto* binary_op = binary_it->get(); + + // Test for binary ops of types that we know how to resolve + if (binary_op->inputs.size() != 2) { + return false; + } + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv) { + return false; + } + + // NOTE: either of these ops may be nullptr if the input array is constant. + Operator* const op[2] = { + GetOpWithOutput(*model, binary_op->inputs[0]), + GetOpWithOutput(*model, binary_op->inputs[1]), + }; + + // Check whether either input is a broadcast-like concat. + bool is_op_0_broadcast = op[0] && IsBroadcastingOp(*model, op[0]); + bool is_op_1_broadcast = op[1] && IsBroadcastingOp(*model, op[1]); + if (!is_op_0_broadcast && !is_op_1_broadcast) { + // Neither input is a broadcast-looking thing. + AddMessageF("Neither input looks broadcasty"); + return false; + } else if (is_op_0_broadcast && is_op_1_broadcast) { + AddMessageF( + "Unable to fuse broadcast into %s as both inputs (%s, %s) are " + "broadcasts", + LogName(*binary_op), op[0] ? LogName(*op[0]) : "(?)", + op[1] ? LogName(*op[1]) : "(?)"); + return false; + } + int broadcast_index = is_op_0_broadcast ? 0 : 1; + + // Just pull out the input of the broadcast op and pass it directly to the + // binary op. + AddMessageF("Fusing broadcast op %s into the following binary %s", + LogName(*op[broadcast_index]), LogName(*binary_op)); + binary_op->inputs[broadcast_index] = op[broadcast_index]->inputs[0]; + + // We leave the broadcast op in; it'll get cleaned up if it's not used later. + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 1bc7557d46cfa5e1b27468d2da271e75fd491d36..99f4a7d8f61eb3e75b643673d42d4b2103309f2e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -116,13 +116,15 @@ DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise) DECLARE_GRAPH_TRANSFORMATION(ConvertSqueezeToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialAddNToAdd) -DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialStackToReshape) +DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialPackToReshape) +DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTileToConcat) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTransposeToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes) DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors) DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions) DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine) DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine) +DECLARE_GRAPH_TRANSFORMATION(FuseBroadcastIntoFollowingBinary) DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization) DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool) DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell) @@ -133,6 +135,7 @@ DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1) DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu) DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv) DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator) +DECLARE_GRAPH_TRANSFORMATION(MoveBinaryOperatorBeforeReshape) DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants) DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes) DECLARE_GRAPH_TRANSFORMATION(PropagateFakeQuantNumBits); @@ -156,7 +159,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantBinaryOperator) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantUnaryOperator) DECLARE_GRAPH_TRANSFORMATION(CreateIm2colArrays) DECLARE_GRAPH_TRANSFORMATION(DropIm2colArrays) -DECLARE_GRAPH_TRANSFORMATION(ReadFakeQuantMinMax) +DECLARE_GRAPH_TRANSFORMATION(ReadArrayMinmaxAndNarrowRangeFromFakeQuant) DECLARE_GRAPH_TRANSFORMATION(ReorderElementwiseUnary) DECLARE_GRAPH_TRANSFORMATION(ReorderReshapeTranspose) DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes) @@ -165,7 +168,6 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge) DECLARE_GRAPH_TRANSFORMATION(ResolveSqueezeAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch) -DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantReshape) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTranspose) @@ -178,20 +180,24 @@ DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolvePadV2Attributes) DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes) -DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes) +DECLARE_GRAPH_TRANSFORMATION(ResolveReduceAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantPack) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRandomUniform) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRange) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSlice) -DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStack) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantGather) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSelect) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTile) DECLARE_GRAPH_TRANSFORMATION(ResolveMultiplyByZero) DECLARE_GRAPH_TRANSFORMATION(Dequantize) DECLARE_GRAPH_TRANSFORMATION(UnpartitionEmbeddingLookup) -DECLARE_GRAPH_TRANSFORMATION(ExperimentalShuffleFCWeights) +DECLARE_GRAPH_TRANSFORMATION(ShuffleFCWeights) +DECLARE_GRAPH_TRANSFORMATION(ResolveFakeQuantArgsFromVars) +DECLARE_GRAPH_TRANSFORMATION(ResolveGatherAttributes) class PropagateDefaultMinMax : public GraphTransformation { public: @@ -258,8 +264,12 @@ class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation { bool allow_nudging_weights() const { return allow_nudging_weights_; } void set_allow_nudging_weights(bool val) { allow_nudging_weights_ = val; } + bool has_default_ranges_flag() const { return has_default_ranges_flag_; } + void set_has_default_ranges_flag(bool val) { has_default_ranges_flag_ = val; } + private: bool allow_nudging_weights_ = false; + bool has_default_ranges_flag_ = false; }; #undef DECLARE_GRAPH_TRANSFORMATION diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index bda6dce22be0f0ca83eb8339ad17573b0267c18c..d26c3b2878b8499fcbabc5448de9ec045eb07879 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -133,24 +133,20 @@ bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) { } bool HardcodeMinMaxForSplit(Model* model, Operator* op) { - for (const auto& output : op->outputs) { - if (model->GetArray(output).minmax) { - LOG(WARNING) << "Skipping min-max setting for " << LogName(*op) - << " because output " << output << " already has min-max."; - return false; - } - } // Data is in second input. auto& input_array = model->GetArray(op->inputs[1]); if (!input_array.minmax) { return false; - } else { - for (const auto& output : op->outputs) { - auto& array = model->GetArray(output); + } + bool changed = false; + for (const auto& output : op->outputs) { + auto& array = model->GetArray(output); + if (!array.minmax || !(array.GetMinMax() == input_array.GetMinMax())) { + changed = true; array.GetOrCreateMinMax() = *input_array.minmax; } - return true; } + return changed; } // The output of average or max pooling is within the same range as its input. @@ -232,6 +228,14 @@ bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min, return true; } +bool MinMaxApproximatelyEqual(const MinMax& minmax1, const MinMax& minmax2) { + const double magnitude = + std::min(minmax1.max - minmax1.min, minmax2.max - minmax2.min); + const double tolerated = 1e-6 * magnitude; + return std::abs(minmax1.min - minmax2.min) < tolerated && + std::abs(minmax1.max - minmax2.max) < tolerated; +} + // Propagates MinMax from any of the listed arrays, to all others. // If multiple of these arrays have MinMax, then these are required // to agree with each other. @@ -254,7 +258,7 @@ bool PropagateMinMaxAmongArrays(Model* model, for (const string& array_name : array_names) { auto& array = model->GetArray(array_name); if (array.minmax) { - CHECK(*array.minmax == *reference_minmax) + CHECK(MinMaxApproximatelyEqual(*array.minmax, *reference_minmax)) << "Both the following arrays have minmax, and they disagree: " << reference_array_name << " (" << reference_minmax->min << "," << reference_minmax->max << ") and " << array_name << " (" @@ -353,7 +357,7 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForConcatenation(model, op); break; - case OperatorType::kTensorFlowSplit: + case OperatorType::kSplit: changed = HardcodeMinMaxForSplit(model, op); break; @@ -366,13 +370,27 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { case OperatorType::kSlice: case OperatorType::kStridedSlice: case OperatorType::kSqueeze: - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: + case OperatorType::kExpandDims: case OperatorType::kPad: case OperatorType::kGather: case OperatorType::kTranspose: case OperatorType::kMean: changed = HardcodeMinMaxFromFirstInput(model, op); break; + case OperatorType::kSum: + // reduce_sum is expected to change the output range. Hence + // a fake_quant op is necessary in the output to minimize error. However + // in special circumstances like when computing expected value using + // reduce_sum the input range and the output range matches. Hence the + // below code would act as a fallback. If a fake_quant node is observed in + // the output that takes precendence over the hard coding logic below. + changed = HardcodeMinMaxFromFirstInput(model, op); + if (changed) { + LOG(WARNING) << "Using the input range for output in reduce_sum op." + << "This could have an impact on your model accuracy."; + } + break; case OperatorType::kSelect: changed = HardcodeMinMaxForSelect(model, op); break; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc index 419a0776a6b987a18df059d3c1d4bf4370cd24d8..b78efd7fc3602dc2d6e03fd28d694c344b61c17c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc @@ -44,10 +44,9 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { const auto* div_or_mul_op = div_it->get(); OperatorType expected_op_type_producing_div_or_mul_input; if (div_or_mul_op->type == OperatorType::kDiv) { - expected_op_type_producing_div_or_mul_input = OperatorType::kTensorFlowSqrt; + expected_op_type_producing_div_or_mul_input = OperatorType::kSqrt; } else if (div_or_mul_op->type == OperatorType::kMul) { - expected_op_type_producing_div_or_mul_input = - OperatorType::kTensorFlowRsqrt; + expected_op_type_producing_div_or_mul_input = OperatorType::kRsqrt; } else { return false; } @@ -75,8 +74,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { Operator* add_op = nullptr; Operator* op_producing_add_input = nullptr; if (op_producing_sqrt_or_rsqrt_input->type == OperatorType::kAdd || - op_producing_sqrt_or_rsqrt_input->type == - OperatorType::kTensorFlowMaximum) { + op_producing_sqrt_or_rsqrt_input->type == OperatorType::kMaximum) { add_op = op_producing_sqrt_or_rsqrt_input; bool add_can_be_removed = false; CHECK_EQ(op_producing_sqrt_or_rsqrt_input->inputs.size(), 2); @@ -113,7 +111,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { Operator* sum_op = add_op ? op_producing_add_input : op_producing_sqrt_or_rsqrt_input; - if (sum_op->type != OperatorType::kTensorFlowSum) { + if (sum_op->type != OperatorType::kSum) { AddMessageF( "Giving up trying to identify L2Normalization subgraph: " "expected Sum op, got %s", @@ -122,7 +120,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { } Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]); - if (square_op->type != OperatorType::kTensorFlowSquare) { + if (square_op->type != OperatorType::kSquare) { AddMessageF( "Giving up trying to identify L2Normalization subgraph: " "expected Square op, got %s", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc index e4d52476c649de53b3ab663f53ce7a5538dbb5ab..705e73779b7f74698149d5e9e56f69a371326ceb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc @@ -41,7 +41,7 @@ std::vector>::iterator FindOperator( bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { const auto sqrt_it = model->operators.begin() + op_index; const auto* sqrt_op = sqrt_it->get(); - if (sqrt_op->type != OperatorType::kTensorFlowSqrt) { + if (sqrt_op->type != OperatorType::kSqrt) { return false; } @@ -52,6 +52,13 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { const Operator* square_op; Operator* prev_to_sqrt_op = GetOpWithOutput(*model, sqrt_op->inputs[0]); + if (prev_to_sqrt_op == nullptr) { + AddMessageF( + "Giving up trying to identify L2Pool subgraph: " + "expected AveragePool op, but Sqrt op has no preceding op"); + return false; + } + if (prev_to_sqrt_op->type != OperatorType::kAveragePool) { AddMessageF( "Giving up trying to identify L2Pool subgraph: " @@ -65,7 +72,7 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { square_op = GetOpWithOutput(*model, avpool_op->inputs[0]); CHECK_EQ(square_op->inputs.size(), 1); - if (square_op->type != OperatorType::kTensorFlowSquare) { + if (square_op->type != OperatorType::kSquare) { AddMessageF( "Giving up trying to identify L2Pool subgraph: " "expected Square op, got %s", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc index e9842524c829b839b97b3453a36c41efe186efbb..c0b014b45eb1df25173ce3ca3fa488b0655c3c76 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc @@ -35,19 +35,24 @@ std::vector>::iterator FindOperator( return it; } -bool GetStateArrayForBackEdge(const Model& model, - const string& back_edge_source_array, - string* state_array = nullptr) { - for (const auto& rnn_state : model.flags.rnn_states()) { - if (back_edge_source_array == rnn_state.back_edge_source_array()) { - // Found LSTM cell output - if (state_array) { - *state_array = rnn_state.state_array(); - } - return true; +bool ValidateSourceOp(const Model& model, const string& array_name, + OperatorType op_type, Operator** source_op) { + if (op_type == OperatorType::kNone) { + CHECK(!source_op); + } else { + CHECK(source_op); + *source_op = GetOpWithOutput(model, array_name); + if (*source_op == nullptr) { + return false; + } + + // Check that first operator, if connected, is of correct type + if ((*source_op)->type != op_type) { + return false; } } - return false; + + return true; } // Returns true if the given operator has exactly 1 input, and is connected to @@ -62,24 +67,10 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, } // Check if first input is disconnected/connected to an operator - Operator* x = GetOpWithOutput(model, op.inputs[0]); - if ((op_type == OperatorType::kNone) && (x != nullptr)) { - return false; - } - if ((op_type != OperatorType::kNone) && (x == nullptr)) { + if (!ValidateSourceOp(model, op.inputs[0], op_type, connected_op)) { return false; } - // Check that first operator, if connected, is of correct type - if ((x != nullptr) && (x->type != op_type)) { - return false; - } - - // Successfully matched. Optionally return matching input operators. - if (connected_op) { - *connected_op = x; - } - return true; } @@ -96,40 +87,15 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, } // Check if first input is disconnected/connected to an operator - Operator* x = GetOpWithOutput(model, op.inputs[0]); - if ((a_op_type == OperatorType::kNone) && (x != nullptr)) { - return false; - } - if ((a_op_type != OperatorType::kNone) && (x == nullptr)) { - return false; - } - - // Check that first operator, if connected, is of correct type - if ((x != nullptr) && (x->type != a_op_type)) { + if (!ValidateSourceOp(model, op.inputs[0], a_op_type, a_op)) { return false; } // Check if second input is disconnected/connected to an operator - Operator* y = GetOpWithOutput(model, op.inputs[1]); - if ((b_op_type == OperatorType::kNone) && (y != nullptr)) { - return false; - } - if ((b_op_type != OperatorType::kNone) && (y == nullptr)) { + if (!ValidateSourceOp(model, op.inputs[1], b_op_type, b_op)) { return false; } - // Check that second operator, if connected, is of correct type - if ((y != nullptr) && (y->type != b_op_type)) { - return false; - } - - // Successfully matched. Optionally return matching input operators. - if (a_op != nullptr) { - *a_op = x; - } - if (b_op != nullptr) { - *b_op = y; - } return true; } @@ -147,57 +113,20 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, } // Check if first input is disconnected/connected to an operator - Operator* x = GetOpWithOutput(model, op.inputs[0]); - if ((a_op_type == OperatorType::kNone) && (x != nullptr)) { - return false; - } - if ((a_op_type != OperatorType::kNone) && (x == nullptr)) { - return false; - } - - // Check that first operator, if connected, is of correct type - if ((x != nullptr) && (x->type != a_op_type)) { + if (!ValidateSourceOp(model, op.inputs[0], a_op_type, a_op)) { return false; } // Check if second input is disconnected/connected to an operator - Operator* y = GetOpWithOutput(model, op.inputs[1]); - if ((b_op_type == OperatorType::kNone) && (y != nullptr)) { - return false; - } - if ((b_op_type != OperatorType::kNone) && (y == nullptr)) { - return false; - } - - // Check that second operator, if connected, is of correct type - if ((y != nullptr) && (y->type != b_op_type)) { + if (!ValidateSourceOp(model, op.inputs[1], b_op_type, b_op)) { return false; } // Check if third input is disconnected/connected to an operator - Operator* z = GetOpWithOutput(model, op.inputs[2]); - if ((c_op_type == OperatorType::kNone) && (z != nullptr)) { - return false; - } - if ((c_op_type != OperatorType::kNone) && (z == nullptr)) { - return false; - } - - // Check that third operator, if connected, is of correct type - if ((z != nullptr) && (z->type != c_op_type)) { + if (!ValidateSourceOp(model, op.inputs[2], c_op_type, c_op)) { return false; } - // Successfully matched. Optionally return matching input operators. - if (a_op != nullptr) { - *a_op = x; - } - if (b_op != nullptr) { - *b_op = y; - } - if (c_op != nullptr) { - *c_op = z; - } return true; } @@ -231,11 +160,6 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { &state_combine_add)) { return false; } - string prev_state; - if (!GetStateArrayForBackEdge(*model, state_output_tanh->inputs[0], - &prev_state)) { - return false; - } // State forget & remember addition Operator *state_forget_mul, *state_remember_mul; @@ -244,9 +168,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { &state_remember_mul)) { return false; } - if (state_forget_mul->inputs[0] != prev_state) { - return false; - } + const string prev_state = state_forget_mul->inputs[0]; // State forget gate Operator* state_forget_sig; @@ -266,26 +188,26 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { // State remember "information" activation function Operator* fc_output_split; - if (!MatchOperatorInputs(*state_info_tanh, *model, - OperatorType::kTensorFlowSplit, &fc_output_split)) { + if (!MatchOperatorInputs(*state_info_tanh, *model, OperatorType::kSplit, + &fc_output_split)) { return false; } // State remember gate activation function Operator* tmp; - if (!MatchOperatorInputs(*state_remember_sig, *model, - OperatorType::kTensorFlowSplit, &tmp) || + if (!MatchOperatorInputs(*state_remember_sig, *model, OperatorType::kSplit, + &tmp) || (tmp != fc_output_split)) { return false; } // State forget gate activation function - if (!MatchOperatorInputs(*state_forget_sig, *model, - OperatorType::kTensorFlowSplit, &tmp) || + if (!MatchOperatorInputs(*state_forget_sig, *model, OperatorType::kSplit, + &tmp) || (tmp != fc_output_split)) { return false; } // Fully connected output activation function - if (!MatchOperatorInputs(*fc_output_sig, *model, - OperatorType::kTensorFlowSplit, &tmp) || + if (!MatchOperatorInputs(*fc_output_sig, *model, OperatorType::kSplit, + &tmp) || (tmp != fc_output_split)) { return false; } @@ -306,8 +228,8 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { return false; } - if (static_cast(fully_connected) - ->experimental_shuffled_weights) { + if (static_cast(fully_connected)->weights_format != + FullyConnectedWeightsFormat::kDefault) { // Not yet implemented: experimental shuffled weights in fused LSTM cell. return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc index e6e3dfa1de9c9fdd5e759fd547d11a7b8c95d837..46d1fce50e5d6e2a74cf5461d731e46469dde5bf 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc @@ -74,6 +74,12 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { lstm_cell_op->inputs[kInputTensor] = curr_op->inputs[LstmCellOperator::ACTIV_OUTPUT]; + // Previous states. + lstm_cell_op->inputs[kInputActivationStateTensor] = + curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]; + lstm_cell_op->inputs[kInputCellStateTensor] = + curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT]; + // Get original weight tensor and decompose 1 tensor to 8 sub tensors. Array& kernel = model->GetArray(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]); @@ -160,10 +166,6 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { // Erase curr lstm op being replaced. DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT], model); DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::BIASES_INPUT], model); - DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT], - model); - DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT], - model); model->operators.erase(FindOp(*model, curr_op)); return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc index 30be4ac0aa5e9f639bbf0630e142c2806faa3260..b90a156a0dcfcd77c3e2b47bb0d77e246f2fc625 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc @@ -74,14 +74,30 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) { const auto* relu_neg_input_op = GetOpWithOutput(*model, mul_op->inputs[1]); if (relu_neg_input_op == nullptr || - relu_neg_input_op->type != OperatorType::kNeg || - relu_neg_input_op->fused_activation_function != - FusedActivationFunctionType::kRelu || relu_neg_input_op->inputs.size() != 1) { return false; } - if (relu_input_op->inputs[0] != relu_neg_input_op->inputs[0]) { + const Operator* final_input_op; + if (relu_neg_input_op->type == OperatorType::kNeg && + relu_neg_input_op->fused_activation_function == + FusedActivationFunctionType::kRelu) { + // This detects a Neg op with fused Relu activation function. + final_input_op = relu_neg_input_op; + } else { + // This detects a Neg op followed by a separated Relu op. + const auto* neg_input_op = + GetOpWithOutput(*model, relu_neg_input_op->inputs[0]); + if (neg_input_op == nullptr || neg_input_op->inputs.size() != 1 || + relu_neg_input_op->type != OperatorType::kRelu || + relu_neg_input_op->fused_activation_function != + FusedActivationFunctionType::kNone) { + return false; + } + final_input_op = neg_input_op; + } + + if (relu_input_op->inputs[0] != final_input_op->inputs[0]) { return false; } @@ -112,7 +128,6 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) { // intermediate tensors aren't used by other ops, those will be removed by // other graph transformation rules. model->operators.erase(FindOp(*model, add_op)); - return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc index bddb563206f763a756685d196836fa41825cf045..94820a016622a12654e91967737e05fc91ed404c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc @@ -60,24 +60,22 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) { // Follow sequences of min+max and max+min. First get the leading op. const auto op_it = model->operators.begin() + op_index; const auto* op_0 = op_it->get(); - if (op_0->type != OperatorType::kTensorFlowMinimum && - op_0->type != OperatorType::kTensorFlowMaximum) { + if (op_0->type != OperatorType::kMinimum && + op_0->type != OperatorType::kMaximum) { return false; } // Get the paired op and ensure it's the counter to the first. const auto* op_1 = GetOpWithInput(*model, op_0->outputs[0]); if (!op_1 || - (op_1->type != OperatorType::kTensorFlowMinimum && - op_1->type != OperatorType::kTensorFlowMaximum) || + (op_1->type != OperatorType::kMinimum && + op_1->type != OperatorType::kMaximum) || op_0->type == op_1->type) { return false; } - const auto* min_op = - op_0->type == OperatorType::kTensorFlowMinimum ? op_0 : op_1; - const auto* max_op = - op_0->type == OperatorType::kTensorFlowMaximum ? op_0 : op_1; + const auto* min_op = op_0->type == OperatorType::kMinimum ? op_0 : op_1; + const auto* max_op = op_0->type == OperatorType::kMaximum ? op_0 : op_1; if (min_op->inputs.size() != 2 || max_op->inputs.size() != 2) { return false; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h index 1c32a781698ec78003ebbf9caff28557924323e5..6d8603a1133a7478647b8bcc49ea1eceba28df31 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h @@ -47,10 +47,14 @@ enum ExtendedLstmCellInputs { kOutputGateBiasTensor = 15, kProjectionWeightsTensor = 16, // Optional kProjectionBiasTensor = 17, // Optional - kExtendedLstmInputCount = 18 + kInputActivationStateTensor = 18, + // The op can handle 18 inputs or 20 inputs. + kInputCellStateTensor = 19, + kExtendedLstmInputCount = 20, }; enum ExtendedLstmCellOutputs { + // TODO(ycling): Make the 2 output state tensors optional. kOutputStateTensor = 0, kCellStateTensor = 1, kOutputTensor = 2, diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc index 45d9f73a1e6416b8f3fe3936c740da637961b7fc..f684de08abf72d05d4408bf6341fa5a3c2ed11cd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc @@ -85,15 +85,8 @@ bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op, dequantized_input_minmax = input_minmax; auto& input_qparams = input_array.GetOrCreateQuantizationParams(); input_array.data_type = input_array.final_data_type; - if (input_array.data_type == ArrayDataType::kUint8) { - GetQuantizationParamsFromMinMax(input_minmax, - &input_qparams); - } else if (input_array.data_type == ArrayDataType::kInt16) { - GetQuantizationParamsFromMinMax(input_minmax, - &input_qparams); - } else { - LOG(FATAL) << "unhandled data type"; - } + ChooseQuantizationParamsForArrayAndQuantizedDataType( + input_array, input_array.data_type, &input_qparams); transformation->AddMessageF( "Created %s" diff --git a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc index 5065004093434475172a39efdcfd26c10c49148b..95bc7f7d4b8b517c1cc5a73b3e85bbd985ce460f 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc @@ -106,7 +106,7 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model, std::size_t op_index) { auto it = model->operators.begin() + op_index; auto* reshape_op = ConvertOperator( - it->get(), OperatorType::kTensorFlowReshape); + it->get(), OperatorType::kReshape); if (reshape_op == nullptr) { return false; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc new file mode 100644 index 0000000000000000000000000000000000000000..7f44c65285bdef6ba314b16122fdd550bfa47e6a --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc @@ -0,0 +1,178 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +namespace { + +bool IsTailOfShape(const Shape& tail, const Shape& shape) { + // Return true if 'tail' dimensions are the same as the ending dimensions of + // 'shape'. + + int shape_end = shape.dimensions_count() - 1; + int tail_end = tail.dimensions_count() - 1; + + if (tail_end > shape_end) { + // tail cannot be longer than shape. + return false; + } + + // Walk dimensions back to front and compare + for (int i = 0; i <= tail_end; i++) { + if (shape.dims(shape_end - i) != tail.dims(tail_end - i)) { + return false; + } + } + return true; +} + +} // namespace + +// If a binary operator is doing a broadcast operation from a constant array, +// and the constant array shape is the tail of both the other input shape, and a +// subsequent reshape op's output shape, we can swap their order. Since we +// prefer to have reshape ops after mathematic ops, this can allow for the +// collapsing of some reshapes. The WaveNet model in particular benefits from +// this transformation. +// +// Note we are testing for one particular case of a broader set of possible +// binary-reshape op transformations. This transformation could be generalized. +bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + Operator* binary_op = binary_it->get(); + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv && + binary_op->type != OperatorType::kFloorDiv && + binary_op->type != OperatorType::kFloorMod && + binary_op->type != OperatorType::kMinimum && + binary_op->type != OperatorType::kMaximum && + binary_op->type != OperatorType::kLess && + binary_op->type != OperatorType::kLessEqual && + binary_op->type != OperatorType::kGreater && + binary_op->type != OperatorType::kGreaterEqual) { + return false; + } + + // BINARY OP INPUT CHECKS + CHECK_EQ(binary_op->inputs.size(), 2); + const bool input_is_const[2] = { + IsConstantParameterArray(*model, binary_op->inputs[0]), + IsConstantParameterArray(*model, binary_op->inputs[1]), + }; + if (!input_is_const[0] && !input_is_const[1]) { + // To limit our scope, we require one constant input. Though there's no + // reason this transformation wouldn't work with all variable inputs. + return false; + } + if (input_is_const[0] && input_is_const[1]) { + // Both inputs are constants. Leave this for constants propagation. + return false; + } + const int constant_input_idx = input_is_const[0] ? 0 : 1; + const int variable_input_idx = input_is_const[0] ? 1 : 0; + CHECK(input_is_const[constant_input_idx]); + CHECK(!input_is_const[variable_input_idx]); + + const auto& variable_input_array = + model->GetArray(binary_op->inputs[variable_input_idx]); + if (!variable_input_array.has_shape()) { + AddMessageF( + "Not moving %s because it's non-constant input shape is not resolved.", + LogName(*binary_op)); + return false; + } + if (!IsTailOfShape( + model->GetArray(binary_op->inputs[constant_input_idx]).shape(), + model->GetArray(binary_op->inputs[variable_input_idx]).shape())) { + // Constant array shape must be the latter part of the variable shape. + return false; + } + + // RESHAPE OP CHECKS + auto reshape_it = + FindOpWithOutput(*model, binary_op->inputs[variable_input_idx]); + if (reshape_it == model->operators.end()) { + AddMessageF("Not moving %s because it's variable input is not connected.", + LogName(*binary_op)); + return false; + } + Operator* reshape_op = reshape_it->get(); + if (reshape_op->type != OperatorType::kReshape) { + AddMessageF("Not moving %s because the preceding %s is not a reshape op", + LogName(*binary_op), LogName(*reshape_op)); + return false; + } + const auto& reshape_input_array = model->GetArray(reshape_op->inputs[0]); + if (!reshape_input_array.has_shape()) { + AddMessageF( + "Not moving %s because it's non-constant input shape is not resolved " + "yet", + LogName(*binary_op)); + return false; + } + if (!IsTailOfShape( + model->GetArray(binary_op->inputs[constant_input_idx]).shape(), + model->GetArray(reshape_op->outputs[0]).shape())) { + // Constant array shape must be the latter part of the binary op output + // shape. + return false; + } + + // EXTRA CHECKS ON CONNECTING ARRAY + for (const string& output_array : model->flags.output_arrays()) { + if (binary_op->inputs[variable_input_idx] == output_array) { + AddMessageF( + "Not moving %s because the output of reshape op %s is an output op.", + LogName(*binary_op), LogName(*reshape_op)); + return false; + } + } + int count_ops_consuming_output = + CountOpsWithInput(*model, binary_op->inputs[variable_input_idx]); + DCHECK_GE(count_ops_consuming_output, 1); + if (count_ops_consuming_output > 1) { + AddMessageF( + "Not moving %s because the output of reshape op %s is consumed by " + "another op", + LogName(*binary_op), LogName(*reshape_op)); + return false; + } + + // SWAP ORDER OF BINARY AND RESHAPE OPS + AddMessageF("Moving op %s before reshape op %s", LogName(*binary_op), + LogName(*reshape_op)); + + // Swap op input and outputs + std::iter_swap(reshape_op->inputs.begin(), + binary_op->inputs.begin() + variable_input_idx); + std::iter_swap(reshape_op->outputs.begin(), binary_op->outputs.begin()); + + // Swap operator ordering + std::iter_swap(binary_it, reshape_it); + + // Clear binary output shape so it will be re-propagated + model->GetArray(binary_op->outputs[0]).clear_shape(); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 92d283ca2cc7069f4b80c95ffdadcad81a884cbf..c8310161cb33bcc7137e8b163ea6469698ed2fd7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -56,22 +56,26 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { // These operators unconditionally produce float outputs SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat); break; - case OperatorType::kTensorFlowLess: - case OperatorType::kTensorFlowLessEqual: - case OperatorType::kTensorFlowGreater: - case OperatorType::kTensorFlowGreaterEqual: - case OperatorType::kTensorFlowEqual: - case OperatorType::kTensorFlowNotEqual: + case OperatorType::kLess: + case OperatorType::kLessEqual: + case OperatorType::kGreater: + case OperatorType::kGreaterEqual: + case OperatorType::kEqual: + case OperatorType::kNotEqual: + case OperatorType::kAny: + case OperatorType::kLogicalAnd: + case OperatorType::kLogicalNot: + case OperatorType::kLogicalOr: // These operators unconditionally produce bool outputs SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); break; case OperatorType::kRank: - case OperatorType::kTensorFlowShape: + case OperatorType::kShape: // These operators only produce int32 outputs. SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32); break; - case OperatorType::kTensorFlowSplit: - case OperatorType::kTensorFlowConcat: + case OperatorType::kSplit: + case OperatorType::kConcat: case OperatorType::kFill: { // These operators produce an output with the same type as their 2nd input CHECK_GE(op->inputs.size(), 2); @@ -100,6 +104,13 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type; break; } + case OperatorType::kArgMin: { + // Data type of the ArgMin op is specified. + CHECK_EQ(op->outputs.size(), 1); + auto* argmin_op = static_cast(op); + model->GetArray(op->outputs[0]).data_type = argmin_op->output_data_type; + break; + } case OperatorType::kRange: { auto* range_op = static_cast(op); // Output type of the Range op can be set via an attribute @@ -131,11 +142,12 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { CHECK_EQ(op->inputs.size(), 2); CHECK_EQ(op->outputs.size(), 2); CHECK(model->GetArray(op->inputs[1]).data_type == ArrayDataType::kInt32); - model->GetArray(op->outputs[0]).data_type = model->GetArray(op->inputs[0]).data_type; + model->GetArray(op->outputs[0]).data_type = + model->GetArray(op->inputs[0]).data_type; model->GetArray(op->outputs[1]).data_type = ArrayDataType ::kInt32; break; } - case OperatorType::kTensorFlowUnsupported: { + case OperatorType::kUnsupported: { auto* unsupported_op = static_cast(op); // Some output tensors from the op could be eliminated by optimization. // This can make unsupported_op->output_data_types have more elements than @@ -144,8 +156,8 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { return false; } for (int i = 0; i < op->outputs.size(); ++i) { - auto output = op->outputs[i]; - auto data_type = unsupported_op->output_data_types[i]; + const string& output = op->outputs[i]; + const ArrayDataType data_type = unsupported_op->output_data_types[i]; model->GetArray(output).data_type = data_type; } break; @@ -175,6 +187,46 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { SetDataTypeForAllOutputs(model, op, data_type); break; } + case OperatorType::kPow: { + CHECK_EQ(op->inputs.size(), 2); + CHECK(model->GetArray(op->inputs[0]).data_type == + model->GetArray(op->inputs[1]).data_type); + const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; + SetDataTypeForAllOutputs(model, op, data_type); + break; + } + case OperatorType::kPack: { + const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; + for (const auto& input : op->inputs) { + CHECK(data_type == model->GetArray(input).data_type); + } + SetDataTypeForAllOutputs(model, op, data_type); + break; + } + case OperatorType::kOneHot: { + CHECK_EQ(op->inputs.size(), 4); + CHECK_EQ(op->outputs.size(), 1); + const ArrayDataType on_value_type = + model->GetArray(op->inputs[OneHotOperator::ON_VALUE_INPUT]).data_type; + const ArrayDataType off_value_type = + model->GetArray(op->inputs[OneHotOperator::OFF_VALUE_INPUT]) + .data_type; + CHECK(on_value_type == off_value_type); + model->GetArray(op->outputs[0]).data_type = on_value_type; + break; + } + case OperatorType::kCTCBeamSearchDecoder: { + CHECK_EQ(op->inputs.size(), 2); + // All outputs (sparse tensors) are int32s (although tf uses int64s) + // except the last one (log probabilities) is float. + const int output_size = op->outputs.size(); + for (int i = 0; i < output_size - 1; ++i) { + model->GetArray(op->outputs[i]).data_type = ArrayDataType::kInt32; + } + model->GetArray(op->outputs[output_size - 1]).data_type = + ArrayDataType::kFloat; + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc index 50b90e7c2bfddb0382a4d44ad6c90fc7f7701273..cd078ef189e922682098a0ec8dc4743060181aac 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc @@ -25,6 +25,14 @@ limitations under the License. namespace toco { +namespace { + +bool SupportsMinMax(const Array& array) { + return array.data_type == ArrayDataType::kFloat; +} + +} // namespace + // Propagates default min/max values to any operator input/output array that // is missing them. // @@ -39,14 +47,16 @@ bool PropagateDefaultMinMax::Run(Model* model, std::size_t op_index) { for (const auto& input : op->inputs) { auto& input_array = model->GetArray(input); - if (!input_array.minmax && !input_array.buffer) { + if (!input_array.minmax && !input_array.buffer && + SupportsMinMax(input_array)) { did_change |= SetArrayMinMax(input, &input_array); } } for (const auto& output : op->outputs) { auto& output_array = model->GetArray(output); - if (!output_array.minmax && !output_array.buffer) { + if (!output_array.minmax && !output_array.buffer && + SupportsMinMax(output_array)) { did_change |= SetArrayMinMax(output, &output_array); } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc index 6d51fc8c31e6c86701c3dc1fd07a9a5479114738..3ad6b0ec6f7a3c4a9a0ab3964c1198ee757ea4b5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc @@ -27,11 +27,15 @@ namespace toco { namespace { -void ChangeArrayDataType(GraphTransformation* transformation, Array* array, +bool ChangeArrayDataType(GraphTransformation* transformation, Array* array, ArrayDataType new_data_type, const MinMax* new_minmax) { // Ensure the array ends up in the new type (if it hasn't yet been quantized). - array->final_data_type = new_data_type; + bool changed = false; + if (array->final_data_type != new_data_type) { + array->final_data_type = new_data_type; + changed = true; + } if (array->minmax && array->quantization_params) { // The array is already quantized and has min/max info. @@ -62,18 +66,16 @@ void ChangeArrayDataType(GraphTransformation* transformation, Array* array, "Rescaling min/max from %g,%g (%s) to %g,%g (%s)", array_minmax.min, array_minmax.max, ArrayDataTypeName(array->data_type), min, max, ArrayDataTypeName(new_data_type)); - array_minmax.min = min; array_minmax.max = max; - GetQuantizationParamsFromMinMax( - array_minmax, array->quantization_params.get()); - + ChooseQuantizationParamsForArrayAndQuantizedDataType( + *array, new_data_type, array->quantization_params.get()); // Directly change the type as the array was already quantized. array->data_type = new_data_type; - } else { + changed = true; + } else if (!array->quantization_params) { // Array has not yet been quantized so we can just set the final data type // and assign the new min/max value (if provided). - CHECK(!array->quantization_params); if (!array->minmax && new_minmax) { transformation->AddMessageF("Forcing new minmax to %g,%g (%s)", @@ -82,16 +84,19 @@ void ChangeArrayDataType(GraphTransformation* transformation, Array* array, auto& array_minmax = array->GetOrCreateMinMax(); array_minmax.min = new_minmax->min; array_minmax.max = new_minmax->max; + changed = true; } } + + return changed; } // Returns true if the op blocks our backward recursive data type propagation. bool DoesOpBlockBackwardPropagation(const Operator& op) { switch (op.type) { case OperatorType::kConcatenation: - case OperatorType::kTensorFlowConcat: - case OperatorType::kTensorFlowConcatV2: + case OperatorType::kConcat: + case OperatorType::kConcatV2: // Concat shouldn't block propagation, but we do expect that all inputs // have the same range. return false; @@ -100,9 +105,10 @@ bool DoesOpBlockBackwardPropagation(const Operator& op) { // FakeQuant so make sure we move across them. case OperatorType::kGather: // Gathers need their parameters changed to the appropriate data type. - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: case OperatorType::kTranspose: case OperatorType::kSelect: + case OperatorType::kTile: // Reshapes and transposes don't change values. return false; default: @@ -120,10 +126,13 @@ bool DoesOpInputBlockBackwardPropagation(const Operator& op, int input_index) { // Ignore gather indices. return input_index != 0; break; - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: case OperatorType::kTranspose: // Ignore reshape/transpose shapes/dimensions. return input_index != 0; + case OperatorType::kTile: + // Ignore tile multiples. + return input_index != 0; default: return false; } @@ -155,9 +164,8 @@ bool RecursivelyBackwardPropagateDataType(GraphTransformation* transformation, "Adjusting input final data type of array %s from %s to %s", input, ArrayDataTypeName(input_array.final_data_type), ArrayDataTypeName(new_data_type)); - did_change = true; - ChangeArrayDataType(transformation, &input_array, new_data_type, - &new_minmax); + did_change |= ChangeArrayDataType(transformation, &input_array, + new_data_type, &new_minmax); // Walk up into all ops producing the inputs to this op. for (auto& producing_op : model->operators) { @@ -208,9 +216,8 @@ bool RecursivelyForwardPropagateDataType(GraphTransformation* transformation, "Adjusting output final data type of array %s from %s to %s", output, ArrayDataTypeName(output_array.final_data_type), ArrayDataTypeName(new_data_type)); - did_change = true; - ChangeArrayDataType(transformation, &output_array, new_data_type, - nullptr); + did_change |= ChangeArrayDataType(transformation, &output_array, + new_data_type, nullptr); // Walk down into all ops consuming the output of this op. for (auto& consuming_op : model->operators) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 170a499d4eeea6f38704dadd5274e52da7ae2817..91e290439ae4bfd491c8201b02b161fe2caf2f8d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -120,49 +120,7 @@ void ComputeBinaryOperatorOutputSize(const Shape& input_shape_x, CHECK(output_array->has_shape()); } -int GetOutputDepthFromWeights(const Model& model, const Operator& op) { - const string& weights_name = op.inputs[1]; - const auto& weights_shape = model.GetArray(weights_name).shape(); - if (op.type == OperatorType::kConv || - op.type == OperatorType::kFullyConnected) { - return weights_shape.dims(0); - } else if (op.type == OperatorType::kDepthwiseConv) { - return weights_shape.dims(3); - } else { - LOG(FATAL) << "Unhandled operator type"; - } -} - -bool EnsureBiasVectorShape(Model* model, Operator* op) { - const string& weights_name = op->inputs[1]; - const auto& weights_array = model->GetArray(weights_name); - // Yield until weights shape has been resolved. - if (!weights_array.has_shape()) { - return false; - } - - if (op->inputs.size() < 3) { - return false; - } - auto& bias_array = model->GetArray(op->inputs[2]); - if (bias_array.has_shape()) { - return true; - } - - const int output_depth = GetOutputDepthFromWeights(*model, *op); - bias_array.copy_shape(Shape({output_depth})); - - auto& float_buffer = bias_array.GetMutableBuffer(); - float_buffer.data.resize(output_depth, 0); - - return true; -} - void ProcessConvOperator(Model* model, ConvOperator* op) { - if (!EnsureBiasVectorShape(model, op)) { - return; - } - const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { @@ -211,12 +169,6 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { // might as well calculate the output shape and ensure it matches the // specified one - // Check if we have already run. - auto& output_array = model->GetArray(op->outputs[0]); - if (output_array.has_shape()) { - return; - } - // SPECIFIED OUTPUT SHAPE // The below is the specified, or prescribed output shape, _given_ to the // operator as an input. @@ -284,14 +236,20 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { // Set the output shape according to the specified output shape. std::vector const& specified_output_shape = specified_output_shape_array.GetBuffer().data; + auto& output_array = model->GetArray(op->outputs[0]); *(output_array.mutable_shape()->mutable_dims()) = specified_output_shape; -} -void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { - if (!EnsureBiasVectorShape(model, op)) { - return; + // Set im2col array dimensions if there is one. + if (op->outputs.size() == 2) { + const int input_depth = weights_shape.dims(3); + auto& im2col_array = model->GetArray(op->outputs[1]); + im2col_array.copy_shape( + Shape{specified_output_shape[0], specified_output_shape[1], + specified_output_shape[2], input_depth * kheight * kwidth}); } +} +void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { @@ -321,7 +279,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { if (!op->depth_multiplier) { op->depth_multiplier = output_depth / input_depth; } - QCHECK_EQ(output_depth, input_depth * op->depth_multiplier) + CHECK_EQ(output_depth, input_depth * op->depth_multiplier) << "input/output depths and depth_multiplier don't match"; const int kheight = weights_shape.dims(1); @@ -406,10 +364,6 @@ void ProcessOpWithShapeInput(Model* model, Operator* op) { } void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) { - if (!EnsureBiasVectorShape(model, op)) { - return; - } - const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { @@ -483,6 +437,7 @@ void ProcessTensorFlowReshapeOperator(Model* model, product_non_wildcard_dims *= shape_data[i]; } } + const int input_flat_size = RequiredBufferSizeForShape(input_shape); if (has_wildcard) { CHECK_GE(input_flat_size, product_non_wildcard_dims) @@ -491,6 +446,12 @@ void ProcessTensorFlowReshapeOperator(Model* model, << op->outputs[0] << "\". Are your input shapes correct?"; shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims; } + + if (shape_data.size() == 1 && shape_data[0] == 0) { + // We have reshaped a scalar, so preserve as a scalar. + shape_data.clear(); + } + auto& output_shape = *output_array.mutable_shape(); *output_shape.mutable_dims() = shape_data; CHECK_EQ(input_flat_size, RequiredBufferSizeForShape(output_shape)) @@ -568,12 +529,14 @@ void ProcessAddNOperator(Model* model, Operator* op) { bool KeepDims(const Operator& op) { switch (op.type) { - case OperatorType::kTensorFlowMin: + case OperatorType::kReduceMin: // Reduction Min return static_cast(op).keep_dims; - case OperatorType::kTensorFlowMax: + case OperatorType::kReduceMax: // Reduction Max return static_cast(op).keep_dims; - case OperatorType::kTensorFlowSum: + case OperatorType::kSum: return static_cast(op).keep_dims; + case OperatorType::kReduceProd: + return static_cast(op).keep_dims; case OperatorType::kMean: return static_cast(op).keep_dims; default: @@ -1080,20 +1043,28 @@ void ProcessGatherOperator(Model* model, GatherOperator* op) { return; } + // Yield until the axis has been resolved. + if (!op->axis) { + return; + } + int axis = op->axis.value(); + const auto& input_shape = input_array.shape(); const auto& indices_shape = indices_array.shape(); QCHECK_GE(input_shape.dimensions_count(), 1); op->input_rank = input_shape.dimensions_count(); + QCHECK_LT(axis, op->input_rank); - // We only support 1-D indices. - QCHECK_EQ(indices_shape.dimensions_count(), 1); - - // Copy the input dimensions to the output except for dimension 0, + // Copy the input dimensions to the output except for the axis dimensions // where the dimension of indices_shape is used. - // TODO(mgubin): if axis != 0 this is not true, change when it's supported. auto output_dims = output_array.mutable_shape()->mutable_dims(); - output_dims->push_back(indices_shape.dims(0)); - for (int dim = 1; dim < input_shape.dimensions_count(); dim++) { + for (int dim = 0; dim < axis; ++dim) { + output_dims->push_back(input_shape.dims(dim)); + } + for (int dim = 0; dim < indices_shape.dimensions_count(); ++dim) { + output_dims->push_back(indices_shape.dims(dim)); + } + for (int dim = axis + 1; dim < input_shape.dimensions_count(); ++dim) { output_dims->push_back(input_shape.dims(dim)); } } @@ -1111,27 +1082,23 @@ void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) { } // Yield until input dims have been resolved. - if (!input_values.has_shape()) { + if (!input_values.has_shape() || !input_k.has_shape()) { return; } - const auto& input_values_shape = input_values.shape(); - auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims(); - auto output_values_dims = output_values.mutable_shape()->mutable_dims(); - for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) { - output_indexes_dims->push_back(input_values_shape.dims(dim)); - output_values_dims->push_back(input_values_shape.dims(dim)); - } // If the value is initialized, we can specify the last dimension, otherwise // unknown. if (input_k.buffer) { + const auto& input_values_shape = input_values.shape(); + auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims(); + auto output_values_dims = output_values.mutable_shape()->mutable_dims(); + for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) { + output_indexes_dims->push_back(input_values_shape.dims(dim)); + output_values_dims->push_back(input_values_shape.dims(dim)); + } const int32_t k_value = input_k.GetBuffer().data[0]; output_indexes_dims->push_back(k_value); output_values_dims->push_back(k_value); - - } else { - output_indexes_dims->push_back(0); - output_values_dims->push_back(0); } } @@ -1239,7 +1206,7 @@ void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) { output_shape->ReplaceDims({input_array.shape().dimensions_count()}); } -void ProcessStackOperator(Model* model, StackOperator* op) { +void ProcessPackOperator(Model* model, PackOperator* op) { CHECK_GE(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); auto& output_array = model->GetArray(op->outputs[0]); @@ -1248,7 +1215,7 @@ void ProcessStackOperator(Model* model, StackOperator* op) { return; } - std::unique_ptr stacked_shape; + std::unique_ptr packed_shape; for (const auto& input : op->inputs) { const auto& input_array = model->GetArray(input); if (!input_array.has_shape()) { @@ -1257,23 +1224,23 @@ void ProcessStackOperator(Model* model, StackOperator* op) { } Shape shape = input_array.shape(); - if (!stacked_shape) { - stacked_shape.reset(new Shape(shape)); + if (!packed_shape) { + packed_shape.reset(new Shape(shape)); } else { - CHECK(*stacked_shape == shape) << "All input arrays to Stack operators " - "must have the same shape. Input \"" - << input << "\" is different."; + CHECK(*packed_shape == shape) << "All input arrays to Pack operators " + "must have the same shape. Input \"" + << input << "\" is different."; } } int axis = op->axis; if (axis < 0) { // Handle negative axis - axis += stacked_shape->dims().size() + 1; + axis += packed_shape->dims().size() + 1; } - stacked_shape->mutable_dims()->insert( - stacked_shape->mutable_dims()->begin() + axis, op->inputs.size()); - output_array.copy_shape(*stacked_shape); + packed_shape->mutable_dims()->insert( + packed_shape->mutable_dims()->begin() + axis, op->inputs.size()); + output_array.copy_shape(*packed_shape); } void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { @@ -1337,8 +1304,8 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { op->begin_mask, op->start_indices, op->strides, input_array.shape().dims().data(), axis); int stop_index = tflite::strided_slice::StopForAxis( - op->end_mask, op->stop_indices, op->strides, - input_array.shape().dims().data(), axis); + op->end_mask, op->shrink_axis_mask, op->stop_indices, op->strides, + input_array.shape().dims().data(), axis, start_index); int dim_size = ceil(static_cast(stop_index - start_index) / op->strides[axis]); @@ -1453,7 +1420,8 @@ void ProcessTransposeOperator(Model* model, TransposeOperator* op) { } } -void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) { +template +void ProcessArgMinMaxOperator(Model* model, Op* op) { CHECK_EQ(op->inputs.size(), 2); const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. @@ -1505,6 +1473,162 @@ void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) { } } +void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // We have already run. + return; + } + + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.has_shape()) { + // Yield until input dims have been resolved. + return; + } + const auto& input_shape = input_array.shape(); + + auto& multiples_array = model->GetArray(op->inputs[1]); + if (!multiples_array.has_shape()) { + // Yield until multiples shape been resolved. + return; + } + if (!multiples_array.buffer) { + // Yield until the multiples is constant. + return; + } + CHECK(multiples_array.data_type == ArrayDataType::kInt32) + << "Tile multiples input must be int32"; + + std::vector const& multiples = + multiples_array.GetBuffer().data; + CHECK_EQ(multiples.size(), input_shape.dimensions_count()) + << "Tile multiples input " << op->inputs[1] + << " must be same length as input dimensions"; + + auto* mutable_dims = output_array.mutable_shape()->mutable_dims(); + mutable_dims->resize(multiples.size()); + for (int i = 0; i < mutable_dims->size(); ++i) { + (*mutable_dims)[i] = input_shape.dims(i) * multiples[i]; + } +} + +void ProcessAnyOperator(Model* model, AnyOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // We have already run. + return; + } + + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.has_shape()) { + // Yield until input dims have been resolved. + return; + } + const auto& input_shape = input_array.shape(); + + auto& reduction_indices_array = model->GetArray(op->inputs[1]); + if (!reduction_indices_array.has_shape()) { + // Yield until reduction indices shape been resolved. + return; + } + if (!reduction_indices_array.buffer) { + // Yield until the reduction indices are constant. + return; + } + CHECK(reduction_indices_array.data_type == ArrayDataType::kInt32) + << "Any reduction input must be int32"; + + int input_rank = input_shape.dimensions_count(); + std::set true_indices; + const auto& reduction_indices = + reduction_indices_array.GetBuffer().data; + for (int i = 0; i < reduction_indices.size(); ++i) { + const int32 reduction_index = reduction_indices[i]; + if (reduction_index < -input_rank || reduction_index >= input_rank) { + CHECK(false) << "Invalid reduction dimension " << reduction_index + << " for input with " << input_rank << " dimensions"; + } + int32 wrapped_index = reduction_index; + if (wrapped_index < 0) { + wrapped_index += input_rank; + } + true_indices.insert(wrapped_index); + } + + auto* mutable_dims = output_array.mutable_shape()->mutable_dims(); + mutable_dims->clear(); + for (int i = 0; i < input_rank; ++i) { + if (true_indices.count(i) > 0) { + if (op->keep_dims) { + mutable_dims->emplace_back(1); + } + } else { + mutable_dims->emplace_back(input_shape.dims(i)); + } + } +} + +void ProcessOneHotOperator(Model* model, OneHotOperator* op) { + CHECK_EQ(op->inputs.size(), 4); + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // Shape already propagated + return; + } + + // Yield until indices dims have been resolved. + const auto& indices_array = + model->GetArray(op->inputs[OneHotOperator::INDICES_INPUT]); + if (!indices_array.has_shape()) { + return; + } + + // Yield until depth is constant and dims have been resolved. + if (!IsConstantParameterArray(*model, + op->inputs[OneHotOperator::DEPTH_INPUT])) { + return; + } + const auto& depth_array = + model->GetArray(op->inputs[OneHotOperator::DEPTH_INPUT]); + if (!depth_array.has_shape()) { + return; + } + + CHECK(depth_array.data_type == ArrayDataType::kInt32) + << "Depth array must be int32."; + CHECK_EQ(RequiredBufferSizeForShape(depth_array.shape()), 1) + << "Depth array must be scalar."; + + const int depth = depth_array.GetBuffer().data[0]; + CHECK_GE(depth, 0) << "Depth must be non-negative."; + + const int indices_dims = indices_array.shape().dimensions_count(); + const int output_dims = indices_dims + 1; + const int axis = op->axis == -1 ? indices_dims : op->axis; + CHECK_GE(axis, 0) << "Resolved axis must be non-negative."; + + auto* mutable_dims = output_array.mutable_shape()->mutable_dims(); + mutable_dims->resize(output_dims); + for (int i = 0; i < output_dims; ++i) { + int dim = 0; + if (i < axis) { + dim = indices_array.shape().dims(i); + } else if (i == axis) { + dim = depth; + } else { + dim = indices_array.shape().dims(i - 1); + } + (*mutable_dims)[i] = dim; + } +} + } // namespace bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { @@ -1531,18 +1655,21 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kLogistic: case OperatorType::kTanh: case OperatorType::kLocalResponseNormalization: - case OperatorType::kTensorFlowIdentity: + case OperatorType::kIdentity: case OperatorType::kFakeQuant: case OperatorType::kNeg: - case OperatorType::kTensorFlowRsqrt: - case OperatorType::kTensorFlowSqrt: - case OperatorType::kTensorFlowSquare: - case OperatorType::kTensorFlowAll: - case OperatorType::kTensorFlowAssert: + case OperatorType::kRsqrt: + case OperatorType::kSqrt: + case OperatorType::kSquare: + case OperatorType::kAll: + case OperatorType::kAssert: case OperatorType::kCast: case OperatorType::kFloor: case OperatorType::kExp: case OperatorType::kSin: + case OperatorType::kLogicalAnd: + case OperatorType::kLogicalNot: + case OperatorType::kLogicalOr: ProcessSimpleOperator(model, op, 0); break; case OperatorType::kGather: @@ -1557,14 +1684,15 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kDiv: case OperatorType::kFloorDiv: case OperatorType::kFloorMod: - case OperatorType::kTensorFlowLess: - case OperatorType::kTensorFlowLessEqual: - case OperatorType::kTensorFlowGreater: - case OperatorType::kTensorFlowMaximum: - case OperatorType::kTensorFlowMinimum: - case OperatorType::kTensorFlowGreaterEqual: - case OperatorType::kTensorFlowEqual: - case OperatorType::kTensorFlowNotEqual: + case OperatorType::kLess: + case OperatorType::kLessEqual: + case OperatorType::kGreater: + case OperatorType::kMaximum: // Element-wise Maximum + case OperatorType::kMinimum: // Element-wise Minimum + case OperatorType::kGreaterEqual: + case OperatorType::kEqual: + case OperatorType::kNotEqual: + case OperatorType::kPow: ProcessSimpleBinaryOperator(model, op); break; case OperatorType::kAddN: @@ -1597,7 +1725,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessFullyConnectedOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: ProcessTensorFlowReshapeOperator( model, static_cast(op)); break; @@ -1610,9 +1738,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kL2Pool: ProcessL2PoolOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowMin: - case OperatorType::kTensorFlowMax: - case OperatorType::kTensorFlowSum: + case OperatorType::kReduceMin: // Reduction Min + case OperatorType::kReduceMax: // Reduction Max + case OperatorType::kSum: + case OperatorType::kReduceProd: case OperatorType::kMean: ProcessTensorFlowReductionOperator(model, op); break; @@ -1623,34 +1752,26 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessSliceOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowTile: - // We don't currently implement the propagation of fixed sizes through - // a TensorFlow Tile. - // - // Fortunately, we don't need to: so far, we have only dealt with Tile - // or Slice ops in subgraphs that are identified as L2Normalization. - // See IdentifyL2Normalization. - break; - case OperatorType::kTensorFlowSwitch: + case OperatorType::kSwitch: // We can't know the sizes of the outputs until we have resolved the // predicate, and once we have resolved the predicate, the whole // Switch node will get resolved away. // See ResolveTensorFlowSwitch. break; - case OperatorType::kTensorFlowMerge: + case OperatorType::kMerge: // No need to bother resolving TensorFlow Merge ops: other graph // transformations will remove them anyway. // See ResolveTensorFlowMerge. break; - case OperatorType::kTensorFlowSplit: + case OperatorType::kSplit: ProcessTensorFlowSplitOperator(model, static_cast(op)); break; case OperatorType::kSqueeze: ProcessSqueezeOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowConcat: - case OperatorType::kTensorFlowConcatV2: + case OperatorType::kConcat: + case OperatorType::kConcatV2: // Unimplemented, hopefully another graph transformation will // drop it or rewrite it. Concretely, either ResolveTensorFlowConcat // will resolve this node to a DepthConcatenation, or else we have @@ -1666,11 +1787,11 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kRank: ProcessRankOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowShape: + case OperatorType::kShape: ProcessShapeOperator(model, static_cast(op)); break; - case OperatorType::kStack: - ProcessStackOperator(model, static_cast(op)); + case OperatorType::kPack: + ProcessPackOperator(model, static_cast(op)); break; case OperatorType::kReorderAxes: ProcessReorderAxesOperator(model, static_cast(op)); @@ -1687,7 +1808,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessLstmCellOperator(model, static_cast(op)); break; case OperatorType::kBatchMatMul: - case OperatorType::kTensorFlowMatMul: + case OperatorType::kMatMul: // MatMul operators are converted to FullyConnected, after which their // shapes are propagated. break; @@ -1710,10 +1831,26 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { static_cast(op)); break; case OperatorType::kArgMax: - ProcessArgMaxOperator(model, static_cast(op)); + ProcessArgMinMaxOperator( + model, static_cast(op)); + break; + case OperatorType::kArgMin: + ProcessArgMinMaxOperator( + model, static_cast(op)); break; - case OperatorType::kTensorFlowUnsupported: + case OperatorType::kUnsupported: { + const auto* unsupported_op = + static_cast(op); + // Attribute can be not specified, ignore it. + if (unsupported_op->output_shapes.size() < op->outputs.size()) { + return false; + } + for (int i = 0; i < op->outputs.size(); ++i) { + const string& output = op->outputs[i]; + model->GetArray(output).copy_shape(unsupported_op->output_shapes.at(i)); + } break; + } case OperatorType::kSvdf: ProcessSvdfOperator(model, static_cast(op)); break; @@ -1734,6 +1871,15 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessSparseToDenseOperator(model, static_cast(op)); break; + case OperatorType::kTile: + ProcessTileOperator(model, static_cast(op)); + break; + case OperatorType::kAny: + ProcessAnyOperator(model, static_cast(op)); + break; + case OperatorType::kOneHot: + ProcessOneHotOperator(model, static_cast(op)); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc index d74cad9a626b3a472e2740d6bdaaaf7aab5bd484..44733391f5a1d9ebf9a24f4f31b425a35354e1fc 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc @@ -74,46 +74,54 @@ ArrayDataType GetQuantizedDataType(const Array& array, } } -void GetQuantizationParams(ArrayDataType data_type, const MinMax& minmax, - QuantizationParams* quantization_params) { - switch (data_type) { +template +void ChooseQuantizationParamsForArrayAndQuantizedDataType( + const Array& array, QuantizationParams* quantization_params) { + *quantization_params = ::tflite::ChooseQuantizationParams>( + array.minmax->min, array.minmax->max, array.narrow_range); +} + +void ChooseQuantizationParamsForArrayAndQuantizedDataType( + const Array& array, ArrayDataType quantized_data_type, + QuantizationParams* quantization_params) { + switch (quantized_data_type) { case ArrayDataType::kInt8: - GetQuantizationParamsFromMinMax( - minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType< + ArrayDataType::kInt8>(array, quantization_params); break; case ArrayDataType::kUint8: - GetQuantizationParamsFromMinMax( - minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType< + ArrayDataType::kUint8>(array, quantization_params); break; case ArrayDataType::kInt16: - GetQuantizationParamsFromMinMax( - minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType< + ArrayDataType::kInt16>(array, quantization_params); break; case ArrayDataType::kUint16: - GetQuantizationParamsFromMinMax( - minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType< + ArrayDataType::kUint16>(array, quantization_params); break; case ArrayDataType::kInt32: - GetQuantizationParamsFromMinMax( - minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType< + ArrayDataType::kInt32>(array, quantization_params); break; case ArrayDataType::kUint32: - GetQuantizationParamsFromMinMax( - minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType< + ArrayDataType::kUint32>(array, quantization_params); break; case ArrayDataType::kInt64: - GetQuantizationParamsFromMinMax( - minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType< + ArrayDataType::kInt64>(array, quantization_params); break; case ArrayDataType::kUint64: - GetQuantizationParamsFromMinMax( - minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType< + ArrayDataType::kUint64>(array, quantization_params); break; case ArrayDataType::kFloat: case ArrayDataType::kNone: default: LOG(FATAL) << "Unhandled final quantization type " - << static_cast(data_type); + << static_cast(quantized_data_type); } } @@ -121,8 +129,8 @@ namespace { template std::unique_ptr QuantizeBuffer( - const GenericBuffer& buffer, - const QuantizationParams& quantization_params) { + const Array& array, const QuantizationParams& quantization_params) { + const GenericBuffer& buffer = *array.buffer; const auto inverse_scale = 1. / quantization_params.scale; CHECK(buffer.type == ArrayDataType::kFloat); const auto& float_buffer = @@ -140,8 +148,15 @@ std::unique_ptr QuantizeBuffer( } else { scaled_val = quantization_params.zero_point + inverse_scale * src_val; } - quantized_buffer->data[i] = - tflite::SafeCast>(std::round(scaled_val)); + auto integer_val = tflite::SafeCast>(std::round(scaled_val)); + // In addition to its effect on the choice of quantization params upstream + // of here, narrow_range also means nudge the min quantized value by +1, + // so e.g. uint8 values get constrained to [1, 255]. + if (integer_val == std::numeric_limits>::min() && + array.narrow_range) { + integer_val++; + } + quantized_buffer->data[i] = integer_val; } return std::unique_ptr(quantized_buffer); } @@ -155,7 +170,7 @@ void QuantizeArray(GraphTransformation* transformation, Model* model, CHECK(!array.quantization_params); array.GetOrCreateQuantizationParams() = quantization_params; if (array.buffer) { - array.buffer = QuantizeBuffer(*array.buffer, quantization_params); + array.buffer = QuantizeBuffer(array, quantization_params); } array.data_type = A; array.final_data_type = A; @@ -210,8 +225,8 @@ bool IsArrayQuantizedRangeSubset(GraphTransformation* transformation, } else { // Work around cases where we are asking for this prior to the Quantize // transformation having added the quantization_params. - GetQuantizationParams(quantized_data_type, *array.minmax, - &quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType( + array, quantized_data_type, &quantization_params); transformation->AddMessageF( "No quantization params - infering from data type %s with minmax " "%g,%g as zero_point=%g, scale=%g", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h index 79a2ce7e50887b4608b278471da0e5e63b5673e3..cf093c6f17b45839156dae0d06ca2fc7e5e2f3c6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h @@ -38,21 +38,11 @@ bool GetQuantizedDataTypeNumericalRange(ArrayDataType data_type, ArrayDataType GetQuantizedDataType(const Array& array, ArrayDataType default_type); -// Returns the quantization params for the array with the given data type and -// minmax. -void GetQuantizationParams(ArrayDataType data_type, const MinMax& minmax, - QuantizationParams* quantization_params); - -// Returns the quantization params for the data type and minmax values. -template -void GetQuantizationParamsFromMinMax(const MinMax& minmax, - QuantizationParams* quantization_params) { - using Integer = DataType; - const double rmin = minmax.min; - const double rmax = minmax.max; - *quantization_params = - ::tflite::ChooseQuantizationParams(rmin, rmax); -} +// Chooses the quantization params for a given array and a given target +// quantized data type (which may not be the array's current data type). +void ChooseQuantizationParamsForArrayAndQuantizedDataType( + const Array& array, ArrayDataType quantized_data_type, + QuantizationParams* quantization_params); // Quantizes an array by setting its data type and (if constant) quantizing // all values in the array. diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index eca2c701f8bbf889088794c939af7082db0734dd..8d22ae2eb1356b8c9c9430c517acddfc971b9f57 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -33,7 +33,7 @@ namespace { bool SupportsQuantization(const Operator& op) { auto type = op.type; - if (type == OperatorType::kTensorFlowUnsupported) { + if (type == OperatorType::kUnsupported) { auto* unsupported = static_cast(&op); return unsupported->quantized; } @@ -42,29 +42,39 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kConcatenation || type == OperatorType::kL2Normalization || type == OperatorType::kAdd || type == OperatorType::kAveragePool || type == OperatorType::kMaxPool || - type == OperatorType::kTensorFlowMinimum || - type == OperatorType::kTensorFlowMaximum || + type == OperatorType::kMinimum || type == OperatorType::kMaximum || type == OperatorType::kLogistic || type == OperatorType::kSoftmax || type == OperatorType::kLogSoftmax || type == OperatorType::kSlice || type == OperatorType::kResizeBilinear || - type == OperatorType::kTensorFlowSplit || type == OperatorType::kSub || + type == OperatorType::kSplit || type == OperatorType::kSub || type == OperatorType::kSqueeze || type == OperatorType::kPad || - type == OperatorType::kPadV2 || - type == OperatorType::kTensorFlowReshape || + type == OperatorType::kPadV2 || type == OperatorType::kReshape || type == OperatorType::kTanh || type == OperatorType::kMul || + type == OperatorType::kBatchToSpaceND || type == OperatorType::kSum || type == OperatorType::kSpaceToBatchND || type == OperatorType::kSpaceToDepth || type == OperatorType::kStridedSlice || type == OperatorType::kDepthToSpace || type == OperatorType::kLstmCell || type == OperatorType::kGather || type == OperatorType::kTranspose || type == OperatorType::kMean || - type == OperatorType::kTensorFlowGreater || - type == OperatorType::kTensorFlowGreaterEqual || - type == OperatorType::kTensorFlowLess || - type == OperatorType::kTensorFlowLessEqual || - type == OperatorType::kSelect || type == OperatorType::kArgMax; + type == OperatorType::kGreater || + type == OperatorType::kGreaterEqual || type == OperatorType::kLess || + type == OperatorType::kLessEqual || type == OperatorType::kSelect || + type == OperatorType::kArgMax || type == OperatorType::kRelu || + type == OperatorType::kRelu1 || type == OperatorType::kRelu6 || + type == OperatorType::kShape || type == OperatorType::kExpandDims; } +// The quantized op allows output arrays of type float using +// the attribute support_output_type_float_in_quantized_op +bool SupportOutputTypeFloatInQuantizedOp(const Operator& op) { + auto type = op.type; + if (type == OperatorType::kUnsupported) { + auto* unsupported = static_cast(&op); + return unsupported->support_output_type_float_in_quantized_op; + } + return false; +} const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) { auto& array = model->GetArray(array_name); // Normally we should have a MinMax recorded on this Array, @@ -214,13 +224,15 @@ bool ChooseQuantizationForOperatorInput( if (op.type == OperatorType::kLstmCell) { if (input_index == LstmCellOperator::PREV_STATE_INPUT) { *quantized_data_type = ArrayDataType::kInt16; - GetQuantizationParams(*quantized_data_type, minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType( + array, *quantized_data_type, quantization_params); return true; } } *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8); - GetQuantizationParams(*quantized_data_type, minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType( + array, *quantized_data_type, quantization_params); transformation->AddMessageF( "For input array %s with min=%g, max=%g, chose to quantize as %s (f=%s) " "with zero_point=%d, scale=%g", @@ -328,14 +340,15 @@ bool ChooseQuantizationForOperatorOutput( output, OperatorTypeName(op.type)); return true; } - if ((op.type == OperatorType::kDepthToSpace) || - (op.type == OperatorType::kSpaceToDepth) || - (op.type == OperatorType::kTensorFlowReshape) || - (op.type == OperatorType::kTensorFlowSplit) || - (op.type == OperatorType::kConcatenation && - model->flags.change_concat_input_ranges())) { + if ((op.type == OperatorType::kConcatenation && + model->flags.change_concat_input_ranges()) || + op.type == OperatorType::kDepthToSpace || + op.type == OperatorType::kSpaceToDepth || + op.type == OperatorType::kReshape || op.type == OperatorType::kSplit || + op.type == OperatorType::kRelu || op.type == OperatorType::kRelu1 || + op.type == OperatorType::kRelu6) { int data_input_index = 0; - if (op.type == OperatorType::kTensorFlowSplit) { + if (op.type == OperatorType::kSplit) { data_input_index = 1; } // Copying and rearrangement ops should preserve the quantization parameters @@ -359,12 +372,14 @@ bool ChooseQuantizationForOperatorOutput( if (output_index == LstmCellOperator::STATE_OUTPUT || output_index == LstmCellOperator::ACTIV_TEMP) { *quantized_data_type = ArrayDataType::kInt16; - GetQuantizationParams(*quantized_data_type, minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType( + array, *quantized_data_type, quantization_params); return true; } } *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8); - GetQuantizationParams(*quantized_data_type, minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType( + array, *quantized_data_type, quantization_params); transformation->AddMessageF( "For output array %s with min=%g, max=%g" ", chose to quantize as %s with zero_point=%d" @@ -508,36 +523,47 @@ bool Quantize::Run(Model* model, std::size_t op_index) { // Check if the output of that Dequantize op was not used by any // other operator. We will then erase that Dequantize op. if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) { - // If any of the model's output_arrays was pointing to the - // Dequantize op's output, let it point to the Dequantize op's - // input instead. - for (int i = 0; i < model->flags.output_arrays_size(); i++) { - if (model->flags.output_arrays(i) == dequantize_op->outputs[0]) { - // TODO(b/78013785): never rename output arrays. - if (IsInputArray(*model, dequantize_op->inputs[0])) { - // The op input is an input array and the output is an output - // array and we can't have an array be both. Insert a copy - // op to ensure the two arrays stay separate. - AddMessageF( - "Tried to rename output array %d while removing dequant " - "op %s but array is also an input; inserting copy %s " - "-> %s", - i, LogName(*dequantize_op), model->flags.output_arrays(i), - dequantize_op->inputs[0]); - InsertCopyOperator(model, dequantize_op->inputs[0], - dequantize_op->outputs[0]); - } else { - // Op output is strictly used as an output array, so we can - // just rename the array and directly bypass the op. - AddMessageF( - "Renaming output array %d after removing dequant op %s: " - "%s -> %s", - i, LogName(*dequantize_op), model->flags.output_arrays(i), - dequantize_op->inputs[0]); - model->flags.set_output_arrays(i, dequantize_op->inputs[0]); - model->EraseArray(dequantize_op->outputs[0]); + if (IsDiscardableArray(*model, dequantize_op->outputs[0])) { + // Usual case: we can just discard the dequantize output. + model->EraseArray(dequantize_op->outputs[0]); + } else { + // The dequantize output is not discardable. Special care needed. + // If any of the model's output_arrays was pointing to the + // Dequantize op's output, let it point to the Dequantize op's + // input instead. + for (int i = 0; i < model->flags.output_arrays_size(); i++) { + if (model->flags.output_arrays(i) == + dequantize_op->outputs[0]) { + // TODO(b/78013785): never rename output arrays. + if (IsInputArray(*model, dequantize_op->inputs[0])) { + // The op input is an input array and the output is an + // output array and we can't have an array be both. Insert a + // copy op to ensure the two arrays stay separate. + AddMessageF( + "Tried to rename output array %d while removing " + "dequant " + "op %s but array is also an input; inserting copy %s " + "-> %s", + i, LogName(*dequantize_op), + model->flags.output_arrays(i), + dequantize_op->inputs[0]); + InsertCopyOperator(model, dequantize_op->inputs[0], + dequantize_op->outputs[0]); + } else { + // Op output is strictly used as an output array, so we can + // just rename the array and directly bypass the op. + AddMessageF( + "Renaming output array %d after removing dequant op " + "%s: " + "%s -> %s", + i, LogName(*dequantize_op), + model->flags.output_arrays(i), + dequantize_op->inputs[0]); + model->flags.set_output_arrays(i, dequantize_op->inputs[0]); + model->EraseArray(dequantize_op->outputs[0]); + } + break; } - break; } } model->operators.erase(dequantize_it); @@ -569,61 +595,67 @@ bool Quantize::Run(Model* model, std::size_t op_index) { } // Quantize outputs, add Dequantize ops as needed on the outputs side - for (std::size_t output_index = 0; output_index < op.outputs.size(); - output_index++) { - ArrayDataType quantized_data_type; - QuantizationParams quantization_params; - if (ChooseQuantizationForOperatorOutput(this, model, op, output_index, - &quantized_data_type, - &quantization_params)) { - changed = true; - const auto& output = op.outputs[output_index]; - auto& output_array = model->GetArray(output); - - // Fix up the min/max information on the output array to match the chosen - // quantization parameters. - CHECK(output_array.minmax) - << "Output array named " << output << " lacks minmax"; - auto& output_minmax = output_array.GetMinMax(); - FixMinMaxPostQuantization(this, quantized_data_type, quantization_params, - &output_minmax); - - QuantizeArray(this, model, output, quantized_data_type, - quantization_params); - - const auto& dequantized_output = - AvailableArrayName(*model, output + "_dequantized"); - auto& dequantized_output_array = - model->GetOrCreateArray(dequantized_output); - dequantized_output_array.data_type = ArrayDataType::kFloat; - dequantized_output_array.final_data_type = output_array.data_type; - auto& dequantized_output_minmax = - dequantized_output_array.GetOrCreateMinMax(); - dequantized_output_minmax.min = output_minmax.min; - dequantized_output_minmax.max = output_minmax.max; - for (const auto& other_op : model->operators) { - for (auto& other_op_input : other_op->inputs) { - if (other_op_input == output) { - other_op_input = dequantized_output; + if (SupportOutputTypeFloatInQuantizedOp(op)) { + LOG(WARNING) + << HelpfulOperatorTypeName(op) << " is a quantized op" + << "but it has a model flag that sets the output arrays to float."; + } else { + for (std::size_t output_index = 0; output_index < op.outputs.size(); + output_index++) { + QuantizationParams quantization_params; + ArrayDataType quantized_data_type; + if (ChooseQuantizationForOperatorOutput(this, model, op, output_index, + &quantized_data_type, + &quantization_params)) { + changed = true; + const auto& output = op.outputs[output_index]; + auto& output_array = model->GetArray(output); + + // Fix up the min/max information on the output array to match the + // chosen quantization parameters. + CHECK(output_array.minmax) + << "Output array named " << output << " lacks minmax"; + auto& output_minmax = output_array.GetMinMax(); + FixMinMaxPostQuantization(this, quantized_data_type, + quantization_params, &output_minmax); + + QuantizeArray(this, model, output, quantized_data_type, + quantization_params); + + const auto& dequantized_output = + AvailableArrayName(*model, output + "_dequantized"); + auto& dequantized_output_array = + model->GetOrCreateArray(dequantized_output); + dequantized_output_array.data_type = ArrayDataType::kFloat; + dequantized_output_array.final_data_type = output_array.data_type; + auto& dequantized_output_minmax = + dequantized_output_array.GetOrCreateMinMax(); + dequantized_output_minmax.min = output_minmax.min; + dequantized_output_minmax.max = output_minmax.max; + for (const auto& other_op : model->operators) { + for (auto& other_op_input : other_op->inputs) { + if (other_op_input == output) { + other_op_input = dequantized_output; + } } } - } - auto* dequantize_op = new DequantizeOperator; - dequantize_op->inputs = {output}; - dequantize_op->outputs = {dequantized_output}; - for (int i = 0; i < model->flags.output_arrays_size(); i++) { - if (model->flags.output_arrays(i) == output) { - // TODO(b/78013785): never rename output arrays. - AddMessageF( - "Renaming output array %d after inserting dequant op %s: %s -> " - "%s", - i, LogName(*dequantize_op), model->flags.output_arrays(i), - dequantized_output); - model->flags.set_output_arrays(i, dequantized_output); + auto* dequantize_op = new DequantizeOperator; + dequantize_op->inputs = {output}; + dequantize_op->outputs = {dequantized_output}; + for (int i = 0; i < model->flags.output_arrays_size(); i++) { + if (model->flags.output_arrays(i) == output) { + // TODO(b/78013785): never rename output arrays. + AddMessageF( + "Renaming output array %d after inserting dequant op %s: %s -> " + "%s", + i, LogName(*dequantize_op), model->flags.output_arrays(i), + dequantized_output); + model->flags.set_output_arrays(i, dequantized_output); + } } + const auto op_it = FindOp(*model, &op); + model->operators.emplace(op_it + 1, dequantize_op); } - const auto op_it = FindOp(*model, &op); - model->operators.emplace(op_it + 1, dequantize_op); } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc index 88ea0945e7dd15ba325d34ea3fdbf34ff7d91381..7a8515f6d12f96d464ea0764907f9cc2a487d3e7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc @@ -36,10 +36,8 @@ void GetQuantizationParamsFromArray(const Array& array, const std::vector& float_vals = array.GetBuffer().data; auto minmax = std::minmax_element(float_vals.begin(), float_vals.end()); - MinMax toco_minmax; - toco_minmax.min = *minmax.first; - toco_minmax.max = *minmax.second; - GetQuantizationParams(ArrayDataType::kUint8, toco_minmax, params); + *params = tflite::ChooseQuantizationParams( + *minmax.first, *minmax.second, array.narrow_range); } } // namespace diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc new file mode 100644 index 0000000000000000000000000000000000000000..5b41c49bfaff245d599d26989e4ed3f9b0d582cf --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool ApplyAttrsToArray(GraphTransformation* transformation, Model* model, + const FakeQuantOperator& fq_op, + const string& array_name) { + bool changed = false; + auto& annotated_array = model->GetArray(array_name); + if (!annotated_array.minmax) { + const MinMax& minmax = *fq_op.minmax; + annotated_array.GetOrCreateMinMax() = minmax; + transformation->AddMessageF( + "Read min/max annotation for array %s: min=%g, max=%g", array_name, + minmax.min, minmax.max); + changed = true; + } + if (fq_op.narrow_range && !annotated_array.narrow_range) { + annotated_array.narrow_range = true; + transformation->AddMessageF("Read narrow_range annotation for array %s", + array_name); + changed = true; + } + return changed; +} + +} // end namespace + +bool ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run(Model* model, + std::size_t op_index) { + const auto fakequant_it = model->operators.begin() + op_index; + auto* fakequant_base_op = fakequant_it->get(); + if (fakequant_base_op->type != OperatorType::kFakeQuant) { + return false; + } + auto* fq_op = static_cast(fakequant_base_op); + + if (!fq_op->minmax) { + // Need to be resolved first by ResolveFakeQuantArgsFromVars. + return false; + } + + // At this point, this FakeQuantOperator should have a MinMax + // attached to it, and should only have 1 input (it should not have + // 2nd and 3rd input arrays giving min and max anymore). + CHECK(fq_op->minmax); + CHECK_EQ(1, fq_op->inputs.size()); + + return ApplyAttrsToArray(this, model, *fq_op, fq_op->inputs[0]) || + ApplyAttrsToArray(this, model, *fq_op, fq_op->outputs[0]); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc deleted file mode 100644 index bdcca5b7caf61a62203debaa32c4d7a9b2eb43fa..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc +++ /dev/null @@ -1,112 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include -#include -#include -#include -#include - -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" - -namespace toco { - -namespace { - -bool ApplyMinMaxToArray(GraphTransformation* transformation, Model* model, - const MinMax& minmax, const string& array_name) { - auto& annotated_array = model->GetArray(array_name); - if (annotated_array.minmax) { - return false; - } - annotated_array.GetOrCreateMinMax() = minmax; - transformation->AddMessageF( - "Read min/max annotation for array %s: min=%g, max=%g", array_name, - minmax.min, minmax.max); - return true; -} - -} // end namespace - -bool ReadFakeQuantMinMax::Run(Model* model, std::size_t op_index) { - const auto fakequant_it = model->operators.begin() + op_index; - auto* fakequant_base_op = fakequant_it->get(); - if (fakequant_base_op->type != OperatorType::kFakeQuant) { - return false; - } - auto* fakequant_op = static_cast(fakequant_base_op); - - bool changed = false; - - if (!fakequant_op->minmax) { - CHECK_EQ(fakequant_op->inputs.size(), 3); - // We need to yield until the min and max parameters have been - // resolved to constant arrays. - for (int i = 1; i <= 2; i++) { - if (!IsConstantParameterArray(*model, fakequant_op->inputs[1])) { - return false; - } - } - - // Obtain the final min/max values - const auto& min_array = model->GetArray(fakequant_op->inputs[1]); - const auto& max_array = model->GetArray(fakequant_op->inputs[2]); - CHECK_EQ(RequiredBufferSizeForShape(min_array.shape()), 1); - CHECK_EQ(RequiredBufferSizeForShape(max_array.shape()), 1); - fakequant_op->minmax.reset(new MinMax); - MinMax& minmax = *fakequant_op->minmax; - minmax.min = min_array.GetBuffer().data[0]; - minmax.max = max_array.GetBuffer().data[0]; - // We always want [min, max] to contain 0. - if (minmax.min > 0 || minmax.max < 0) { - LOG(ERROR) << "For " << LogName(*fakequant_op) << " the MinMax range " - << "[" << minmax.min << ", " << minmax.max - << "] does not contain 0. " - << "Proceeding by tweaking it to contain 0, which will result " - "in poor accuracy."; - } - minmax.min = std::min(minmax.min, 0.); - minmax.max = std::max(minmax.max, 0.); - - // We won't use the input arrays that provided these min and max - // values, anymore. Delete them unless they are used by something - // else. - for (int i = 1; i <= 2; i++) { - if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) { - model->EraseArray(fakequant_op->inputs[i]); - } - } - fakequant_op->inputs.resize(1); - changed = true; - } - - // At this point, this FakeQuantOperator should have a MinMax - // attached to it, and should only have 1 input (it should not have - // 2nd and 3rd input arrays giving min and max anymore). - CHECK(fakequant_op->minmax); - CHECK_EQ(1, fakequant_op->inputs.size()); - - const MinMax& minmax = *fakequant_op->minmax; - - // Record the MinMax info on the input and output arrays - changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->inputs[0]); - changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->outputs[0]); - - return changed; -} - -} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc index 35a0c465327f352863350e7a8af714d16b7be393..73ad326299bbd929afbb8dda2c41b97a126afbe1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc @@ -26,7 +26,7 @@ namespace toco { bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) { const auto assert_it = model->operators.begin() + op_index; const auto* assert_op = assert_it->get(); - if (assert_op->type != OperatorType::kTensorFlowAssert) { + if (assert_op->type != OperatorType::kAssert) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc index 404269bbfd9312bbbab32489783d9e4217ecbd89..7ec7752f25dad1c24b821733c0e6dafbd1cd8bf2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc @@ -28,7 +28,7 @@ namespace toco { bool RemoveTensorFlowIdentity::Run(Model* model, std::size_t op_index) { const auto passthru_it = model->operators.begin() + op_index; const auto* passthru_op = passthru_it->get(); - if (passthru_op->type != OperatorType::kTensorFlowIdentity) { + if (passthru_op->type != OperatorType::kIdentity) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc index a950fe6442bc656b725a1f0687f4c024f4fb0f84..fc49fbda59c78f056a7e194367618b43c0a4a7db 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc @@ -48,20 +48,26 @@ void RerouteEdges(const string& from_array, const string& to_array, } // namespace bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, - Model* model, std::size_t op_index) { + Model* model, std::size_t op_index, + int input_index) { const auto passthru_it = model->operators.begin() + op_index; auto* passthru_op = passthru_it->get(); CHECK_EQ(passthru_op->outputs.size(), 1); CHECK_GE(passthru_op->inputs.size(), 1); - int count_nonconstant_input_arrays = 0; - // We call 'main input' the unique nonconstant input array if there is one, - // or else the 0-th input. + int main_input_array_index = 0; - for (int i = 0; i < passthru_op->inputs.size(); i++) { - if (!model->GetArray(passthru_op->inputs[i]).buffer) { - count_nonconstant_input_arrays++; - if (count_nonconstant_input_arrays == 1) { - main_input_array_index = i; + if (input_index != -1) { + main_input_array_index = input_index; + } else { + // We call 'main input' the unique nonconstant input array if there is one, + // or else the 0-th input. + int count_nonconstant_input_arrays = 0; + for (int i = 0; i < passthru_op->inputs.size(); i++) { + if (!model->GetArray(passthru_op->inputs[i]).buffer) { + count_nonconstant_input_arrays++; + if (count_nonconstant_input_arrays == 1) { + main_input_array_index = i; + } } } } @@ -97,7 +103,7 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, "Cannot remove %s, neither its main input nor its output may be " "discarded", LogName(*passthru_op)); - if (passthru_op->type != OperatorType::kTensorFlowReshape && + if (passthru_op->type != OperatorType::kReshape && model->GetArray(main_input_name).has_shape()) { // We can't remove either array but we can remove the op. Converting it to // a reshape gives us some hope of later on fixing that (either in the diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h index 9d448c3ee9088c16b96aa7ddc84457d2cab3231a..663704e5acf745d3768ad682e0a7888f0a690e6c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h @@ -50,7 +50,8 @@ namespace toco { // and then discards it and returns true, or, if it's not trivial (if neither // the input nor the output may be discarded), returns false. bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, - Model* model, std::size_t op_index); + Model* model, std::size_t op_index, + int input_index = -1); } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc index eaee1c662b7cedb2baec7be47e12e348c3e7b25c..142c876b154755ac9c6b93e560f22ec8d6ec6563 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc @@ -47,11 +47,11 @@ bool IsTrivialMinMax(GraphTransformation* transformation, const Model& model, double clamp_min; double clamp_max; switch (op_type) { - case OperatorType::kTensorFlowMinimum: + case OperatorType::kMinimum: // Element-wise Minimum clamp_min = -std::numeric_limits::infinity(); clamp_max = clamp_value; break; - case OperatorType::kTensorFlowMaximum: + case OperatorType::kMaximum: // Element-wise Maximum clamp_min = clamp_value; clamp_max = std::numeric_limits::infinity(); break; @@ -72,8 +72,8 @@ bool IsTrivialMinMax(GraphTransformation* transformation, const Model& model, bool RemoveTrivialQuantizedMinMax::Run(Model* model, std::size_t op_index) { const auto it = model->operators.begin() + op_index; auto* op = it->get(); - if ((op->type != OperatorType::kTensorFlowMinimum && - op->type != OperatorType::kTensorFlowMaximum) || + if ((op->type != OperatorType::kMinimum && + op->type != OperatorType::kMaximum) || op->inputs.size() != 2) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc index e28d8cf01eafee64e08ac2cc4b43ea7c227456c2..5295eeccecb05b05232922f4b5e4ef75a2b04672 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc @@ -30,7 +30,7 @@ namespace { bool IsReshapeTrivial(const Model& model, const Operator& op, RemoveTrivialReshape* transformation) { - CHECK(op.type == OperatorType::kTensorFlowReshape); + CHECK(op.type == OperatorType::kReshape); // One way in which a reshape can be trivial is if its // output shape is == its input shape @@ -58,7 +58,16 @@ bool IsReshapeTrivial(const Model& model, const Operator& op, // is only consumed by another reshape. if (CountOpsWithInput(model, op.outputs[0]) == 1) { const auto* next_op = GetOpWithInput(model, op.outputs[0]); - if (next_op->type == OperatorType::kTensorFlowReshape) { + if (next_op->type == OperatorType::kReshape) { + if (!IsDiscardableArray(model, next_op->outputs[0])) { + // If the |next_op| output is used as a model output we need to preserve + // its shape. + transformation->AddMessageF( + "%s cannot be merged into following reshape %s as it is " + "non-discardable and must keep the specified shape", + LogName(op), LogName(*next_op)); + return false; + } transformation->AddMessageF( "%s is trivial because its output is only consumed by another " "Reshape op %s", @@ -75,7 +84,7 @@ bool IsReshapeTrivial(const Model& model, const Operator& op, bool RemoveTrivialReshape::Run(Model* model, std::size_t op_index) { const auto reshape_it = model->operators.begin() + op_index; auto* reshape_op = reshape_it->get(); - if (reshape_op->type != OperatorType::kTensorFlowReshape) { + if (reshape_op->type != OperatorType::kReshape) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc index 1956ab2d2021cda84a0d715534923d6174c30dd1..dde91234a8240f4518cd105c2cc4e79102735980 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc @@ -48,7 +48,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) { for (const auto& rnn_state : model->flags.rnn_states()) { if (output == rnn_state.state_array()) { CHECK(op->type == OperatorType::kFill || - op->type == OperatorType::kTensorFlowIdentity); + op->type == OperatorType::kIdentity); found_output_as_rnn_state_array = true; break; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc index 9f5b7920cb937b021eb23fc1d5fdc3c1ff18a72d..550de83018f25a7aa4da82707fedb86434615fb0 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc @@ -37,8 +37,8 @@ bool IsElementwiseOperator(OperatorType optype) { case OperatorType::kRelu1: case OperatorType::kRelu6: case OperatorType::kTanh: - case OperatorType::kTensorFlowSqrt: - case OperatorType::kTensorFlowSquare: + case OperatorType::kSqrt: + case OperatorType::kSquare: return true; default: return false; @@ -51,7 +51,7 @@ bool IsMoveOperator(OperatorType optype) { case OperatorType::kExpandDims: case OperatorType::kSpaceToDepth: case OperatorType::kSqueeze: - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: case OperatorType::kTranspose: return true; default: diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc index 9e7fe1b1ccd851dd998e59e75ff798f52f7c6e5a..c907a597cb719b68dbf36868a75e49a7c5181423 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc @@ -123,8 +123,8 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) { } TensorFlowReshapeOperator* reshape_op = - ConvertOperator( - reshape_it->get(), OperatorType::kTensorFlowReshape); + ConvertOperator(reshape_it->get(), + OperatorType::kReshape); if (reshape_op == nullptr) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc index a06919e228dc2084f8943a714a0ca111d013c159..b8b35161d77e5b6dd8c30e03959dba3c60d1d56c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc @@ -50,7 +50,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) { // will delete this op. return false; } - std::vector crops_buffer = + const std::vector& crops_buffer = crops_array.GetBuffer().data; for (int i = 0; i < crops_dims[0]; ++i) { op->before_crops.push_back(crops_buffer[i * 2]); @@ -62,7 +62,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) { if (!block_shape_array.has_shape()) return false; const std::vector& block_shape_dims = block_shape_array.shape().dims(); CHECK_EQ(block_shape_dims.size(), 1); - std::vector block_shape_buffer = + const std::vector& block_shape_buffer = block_shape_array.GetBuffer().data; for (int i = 0; i < block_shape_dims[0]; ++i) { op->block_shape.push_back(block_shape_buffer[i]); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc index 6e78653fad238085da5ba66166884093ea9b0214..f7e5aa6609bd4f7eb2a95750125e30a7803b36e1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc @@ -145,17 +145,17 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model, outval = floor(val0 / val1); } else if (binary_op->type == OperatorType::kFloorMod) { outval = val0 - (floor(val0 / val1) * val1); - } else if (binary_op->type == OperatorType::kTensorFlowMinimum) { + } else if (binary_op->type == OperatorType::kMinimum) { outval = std::min(val0, val1); - } else if (binary_op->type == OperatorType::kTensorFlowMaximum) { + } else if (binary_op->type == OperatorType::kMaximum) { outval = std::max(val0, val1); - } else if (binary_op->type == OperatorType::kTensorFlowLess) { + } else if (binary_op->type == OperatorType::kLess) { outval = val0 < val1; - } else if (binary_op->type == OperatorType::kTensorFlowLessEqual) { + } else if (binary_op->type == OperatorType::kLessEqual) { outval = val0 <= val1; - } else if (binary_op->type == OperatorType::kTensorFlowGreater) { + } else if (binary_op->type == OperatorType::kGreater) { outval = val0 > val1; - } else if (binary_op->type == OperatorType::kTensorFlowGreaterEqual) { + } else if (binary_op->type == OperatorType::kGreaterEqual) { outval = val0 >= val1; } else { LOG(FATAL) << "should not get here"; @@ -198,12 +198,12 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { binary_op->type != OperatorType::kDiv && binary_op->type != OperatorType::kFloorDiv && binary_op->type != OperatorType::kFloorMod && - binary_op->type != OperatorType::kTensorFlowMinimum && - binary_op->type != OperatorType::kTensorFlowMaximum && - binary_op->type != OperatorType::kTensorFlowLess && - binary_op->type != OperatorType::kTensorFlowLessEqual && - binary_op->type != OperatorType::kTensorFlowGreater && - binary_op->type != OperatorType::kTensorFlowGreaterEqual) { + binary_op->type != OperatorType::kMinimum && + binary_op->type != OperatorType::kMaximum && + binary_op->type != OperatorType::kLess && + binary_op->type != OperatorType::kLessEqual && + binary_op->type != OperatorType::kGreater && + binary_op->type != OperatorType::kGreaterEqual) { return false; } CHECK_EQ(binary_op->inputs.size(), 2); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc index efb7bb218421dd045e3e8e2a38b9c70989f222e1..d395d7a6a0862d93fd4f52bb8b8d8d3ea7f8dc1e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc @@ -25,6 +25,40 @@ limitations under the License. namespace toco { +template +void GetBoundsForQuantizedDataType(float* min, float* max) { + using limits = std::numeric_limits>; + *min = limits::min(); + *max = limits::max(); +} + +void GetBoundsForQuantizedDataType(ArrayDataType quantized_data_type, + float* min, float* max) { + // It is important for matching accuracy between TF training and TFLite + // inference, that the min and max values are float to match TF's + // FakeQuantWithMinMaxVarsFunctor. + switch (quantized_data_type) { + case ArrayDataType::kUint8: + return GetBoundsForQuantizedDataType(min, max); + case ArrayDataType::kInt8: + return GetBoundsForQuantizedDataType(min, max); + case ArrayDataType::kUint16: + return GetBoundsForQuantizedDataType(min, max); + case ArrayDataType::kInt16: + return GetBoundsForQuantizedDataType(min, max); + case ArrayDataType::kUint32: + return GetBoundsForQuantizedDataType(min, max); + case ArrayDataType::kInt32: + return GetBoundsForQuantizedDataType(min, max); + case ArrayDataType::kUint64: + return GetBoundsForQuantizedDataType(min, max); + case ArrayDataType::kInt64: + return GetBoundsForQuantizedDataType(min, max); + default: + LOG(FATAL) << "unhandled quantized data type"; + } +} + bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { const auto fakequant_it = model->operators.begin() + op_index; const auto* fakequant_base_op = fakequant_it->get(); @@ -76,18 +110,25 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { const int size = input_buffer.data.size(); output_buffer.data.resize(size); QuantizationParams qparams; - GetQuantizationParamsFromMinMax(*fakequant_op->minmax, - &qparams); - for (int i = 0; i < size; i++) { - const double src_val = input_buffer.data[i]; - const double unclamped_quantized_val = - std::round(qparams.zero_point + src_val / qparams.scale); - const double quantized_val = - std::min(255., std::max(0., unclamped_quantized_val)); - const double dst_val = qparams.scale * (quantized_val - qparams.zero_point); - output_buffer.data[i] = dst_val; + ChooseQuantizationParamsForArrayAndQuantizedDataType( + output_array, quantized_data_type, &qparams); + float quantized_min, quantized_max; + GetBoundsForQuantizedDataType(quantized_data_type, &quantized_min, + &quantized_max); + if (fakequant_op->narrow_range) { + quantized_min++; } + // It is important for matching accuracy between TF training and TFLite + // inference, that the following variables are float to match TF's + // FakeQuantWithMinMaxVarsFunctor. + const float scale = qparams.scale; + const float nudged_min = (quantized_min - qparams.zero_point) * scale; + const float nudged_max = (quantized_max - qparams.zero_point) * scale; + tflite::FakeQuantizeArray(scale, nudged_min, nudged_max, + input_buffer.data.data(), output_buffer.data.data(), + size); + if (IsDiscardableArray(*model, fakequant_op->inputs[0]) && CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) { model->EraseArray(fakequant_op->inputs[0]); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc index debe298a5a93034bcb928d7384b5ec1fc7439e47..36d7dad0ce9de81ec132ef992538b6022916bfbd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc @@ -69,7 +69,7 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) { } const auto* op = static_cast(base_op); - CHECK_EQ(op->inputs.size(), 2); + CHECK_GE(op->inputs.size(), 2); CHECK_EQ(op->outputs.size(), 1); auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { @@ -81,10 +81,14 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) { return false; } - // Only handling axis=0 for now. - if (op->axis != 0) { + if (!op->axis) { + // Yield until axis has been set by ResolveGatherAttributes. + return false; + } + if (op->axis.value() != 0) { + // Only handling axis=0 for now. AddMessageF("%s has axis %d; only axis=0 is supported", LogName(*op), - op->axis); + op->axis.value()); return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc similarity index 81% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc rename to tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc index 69db1942cd52af810acf38a818997c71122d8500..e86616574d5a0f1345cde167d4ce0d41665d5a02 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc @@ -24,7 +24,7 @@ namespace toco { namespace { template -void Stack(Model* model, StackOperator const& op) { +void Pack(Model* model, PackOperator const& op) { auto& output_array = model->GetArray(op.outputs[0]); CHECK(output_array.data_type == Type); @@ -33,15 +33,15 @@ void Stack(Model* model, StackOperator const& op) { output_array.GetMutableBuffer().data; output_data.resize(RequiredBufferSizeForShape(output_array.shape())); - // Stack inputs into buffer - CHECK_EQ(op.axis, 0) << "Stacking only supported along first axis"; + // Pack inputs into buffer + CHECK_EQ(op.axis, 0) << "Packing only supported along first axis"; int dst_offset = 0; for (int i = 0; i < op.inputs.size(); i++) { // Append array data to output for each input array const auto& input_array = model->GetArray(op.inputs[i]); int input_size = RequiredBufferSizeForShape(input_array.shape()); memcpy(&output_data[dst_offset], &input_array.GetBuffer().data[0], - input_size * sizeof(Type)); + input_size * ElementSize(Type)); dst_offset += input_size; } CHECK_EQ(dst_offset, output_data.size()); @@ -49,13 +49,13 @@ void Stack(Model* model, StackOperator const& op) { } // namespace -bool ResolveConstantStack::Run(Model* model, std::size_t op_index) { +bool ResolveConstantPack::Run(Model* model, std::size_t op_index) { auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); - if (base_op->type != OperatorType::kStack) { + if (base_op->type != OperatorType::kPack) { return false; } - const auto* op = static_cast(base_op); + const auto* op = static_cast(base_op); CHECK_GE(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); @@ -82,24 +82,24 @@ bool ResolveConstantStack::Run(Model* model, std::size_t op_index) { // Handle negative axis axis += model->GetArray(op->inputs[0]).shape().dims().size(); } - CHECK_EQ(axis, 0) << "Stacking only supported along 0th axis"; + CHECK_EQ(axis, 0) << "Packing only supported along 0th axis"; CHECK(!output_array.buffer); switch (output_array.data_type) { case ArrayDataType::kFloat: - Stack(model, *op); + Pack(model, *op); break; case ArrayDataType::kUint8: - Stack(model, *op); + Pack(model, *op); break; case ArrayDataType::kInt32: - Stack(model, *op); + Pack(model, *op); break; case ArrayDataType::kInt64: - Stack(model, *op); + Pack(model, *op); break; default: - LOG(FATAL) << "Unsupported data type given to Stack op with output \"" + LOG(FATAL) << "Unsupported data type given to Pack op with output \"" << op->outputs[0] << "\""; break; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc index 7e7ad383e7789891f5396845241e70143dc8b76f..41562ab393694d76c5cb6c5df5f7df2a71f893f5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc @@ -25,7 +25,7 @@ namespace toco { bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) { auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); - if (base_op->type != OperatorType::kTensorFlowReshape) { + if (base_op->type != OperatorType::kReshape) { return false; } const auto* op = static_cast(base_op); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc new file mode 100644 index 0000000000000000000000000000000000000000..e880a3f44dab376e5e441e3d6c0f747ee8490489 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc @@ -0,0 +1,78 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +// Resolves a constant Select operation. +// +// This implementation is looking strictly for all-or-nothing on the select +// condition. It's possible to enhance this by looking per-element and possibly +// producing a Mul op. +bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + const auto* base_op = it->get(); + if (base_op->type != OperatorType::kSelect) { + return false; + } + const auto* op = static_cast(base_op); + + CHECK_GE(op->inputs.size(), 3); + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes. + return false; + } + if (!output_array.has_shape()) { + // Yield until the output shape has been set by PropagateFixedShapes. + return false; + } + + // We require the cond input to be constant. + if (!IsConstantParameterArray(*model, op->inputs[0])) { + return false; + } + const Array& cond_array = model->GetArray(op->inputs[0]); + CHECK(cond_array.data_type == ArrayDataType::kBool) + << "Only bool conditions are supported"; + const auto& cond_data = cond_array.GetBuffer().data; + if (cond_data.empty()) { + return false; + } + + // Check if the condition is the same for all elements. + bool cond_value = cond_data[0]; + for (size_t i = 1; i < cond_data.size(); ++i) { + if (cond_data[i] != cond_value) { + AddMessageF( + "Cannot resolve %s as constant; cond_array has differing " + "per-element values", + LogName(*op)); + return false; + } + } + + // Pass-through the selected input. + return RemoveTrivialPassthroughOp(this, model, op_index, cond_value ? 1 : 2); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc index 9ea01acd05364224ce219bed533c999793a2a2f1..8a0e3e8995839a737b5671701a97b514b0fc7bf1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc @@ -22,8 +22,7 @@ namespace toco { bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) { const auto it = model->operators.begin() + op_index; const auto* op = it->get(); - if (!(op->type == OperatorType::kTensorFlowShape || - op->type == OperatorType::kRank)) { + if (!(op->type == OperatorType::kShape || op->type == OperatorType::kRank)) { return false; } @@ -48,7 +47,7 @@ bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) { // Compute the output CHECK(!output_array.buffer); auto& output_buffer = output_array.GetMutableBuffer(); - if (op->type == OperatorType::kTensorFlowShape) { + if (op->type == OperatorType::kShape) { // Copy the input shape into the output buffer. output_buffer.data = input_array.shape().dims(); } else if (op->type == OperatorType::kRank) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc index 1dd52e906900e997f282740404a81b9fcd21e867..9d8bd4fc39344a4ea1fa4942a2a99ec535b5bee8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc @@ -38,6 +38,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, CHECK_EQ(op.new_axis_mask, 0); int num_input_axes = op.start_indices.size(); + CHECK_EQ(num_input_axes, op.start_indices.size()); CHECK_EQ(num_input_axes, op.stop_indices.size()); CHECK_EQ(num_input_axes, op.strides.size()); @@ -49,11 +50,16 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, // Initialize source coordinate Shape const& input_shape = input_array.shape(); Buffer const& input_buffer = input_array.GetBuffer(); - std::vector src_coord(op.start_indices.size()); + std::vector src_coord(num_input_axes); + std::vector stop_for_axis(num_input_axes); for (int axis = 0; axis < num_input_axes; axis++) { - src_coord[axis] = tflite::strided_slice::StartForAxis( + int start = tflite::strided_slice::StartForAxis( op.begin_mask, op.start_indices, op.strides, input_shape.dims().data(), axis); + src_coord[axis] = start; + stop_for_axis[axis] = tflite::strided_slice::StopForAxis( + op.end_mask, op.shrink_axis_mask, op.stop_indices, op.strides, + input_shape.dims().data(), axis, start); } // In order to handle any number (N) of dimensions, we copy elements one by @@ -76,9 +82,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, } // Check if we've overflowed. - int stop = tflite::strided_slice::StopForAxis( - op.end_mask, op.stop_indices, op.strides, input_shape.dims().data(), - axis); + int stop = stop_for_axis[axis]; if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) { // Reset axis and set carry src_coord[axis] = tflite::strided_slice::StartForAxis( @@ -155,14 +159,7 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) { break; } - // Erase input array if no longer used - if (IsDiscardableArray(*model, op->inputs[0]) && - CountOpsWithInput(*model, op->inputs[0]) == 1) { - model->EraseArray(op->inputs[0]); - } - - // Erase the operator - model->operators.erase(it); + DeleteOpAndArraysIfUnused(model, it->get()); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc new file mode 100644 index 0000000000000000000000000000000000000000..0b0d0707146255562c093dd27b91ccb2b603a587 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc @@ -0,0 +1,173 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +// NOTE: the Tile implementation here is taken from tflite's Tile kernel. + +template +void CopyMultipleTimes(const T* in_data, int32_t in_size, int32_t multiplier, + T* out_data) { + for (int i = 0; i < multiplier; ++i) { + const T* in_end = in_data + in_size; + T* new_out_data = std::copy(in_data, in_end, out_data); + in_data = out_data; + out_data = new_out_data; + } +} + +template +std::pair TileOneDimension(const Shape& in_dimensions, + const T* in_data, const M* multipliers, + T* out_data, int dimension) { + const int dimension_size = in_dimensions.dims(dimension); + if (dimension == in_dimensions.dimensions_count() - 1) { + CopyMultipleTimes(in_data, dimension_size, multipliers[dimension], + out_data); + return std::make_pair( + dimension_size, + dimension_size * static_cast(multipliers[dimension])); + } + int total_stride_size = 0, total_tiled_stride_size = 0; + const T* copy_from_data = in_data; + T* copy_to_data = out_data; + for (int i = 0; i < dimension_size; ++i) { + int stride_size = 0, tiled_stride_size = 0; + std::tie(stride_size, tiled_stride_size) = + TileOneDimension(in_dimensions, copy_from_data, multipliers, + copy_to_data, dimension + 1); + copy_from_data += stride_size; + copy_to_data += tiled_stride_size; + total_stride_size += stride_size; + total_tiled_stride_size += tiled_stride_size; + } + CopyMultipleTimes(out_data, total_tiled_stride_size, + multipliers[dimension] - 1, + out_data + total_tiled_stride_size); + return std::make_pair(total_stride_size, + total_tiled_stride_size * multipliers[dimension]); +} + +template +inline void Tile(const Array& input_array, const Array& multiples_array, + Array* output_array) { + // Allocate output storage. + auto& output_data = output_array->GetMutableBuffer().data; + output_data.resize(RequiredBufferSizeForShape(output_array->shape())); + + switch (multiples_array.data_type) { + case ArrayDataType::kInt32: + TileOneDimension( + input_array.shape(), input_array.GetBuffer().data.data(), + multiples_array.GetBuffer().data.data(), + output_array->GetMutableBuffer().data.data(), 0); + break; + case ArrayDataType::kInt64: + TileOneDimension( + input_array.shape(), input_array.GetBuffer().data.data(), + multiples_array.GetBuffer().data.data(), + output_array->GetMutableBuffer().data.data(), 0); + break; + default: + CHECK(false); + break; + } +} + +} // namespace + +// Resolves a constant Tile operation. +bool ResolveConstantTile::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + const auto* base_op = it->get(); + if (base_op->type != OperatorType::kTile) { + return false; + } + const auto* op = static_cast(base_op); + + CHECK_GE(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes. + return false; + } + if (!output_array.has_shape()) { + // Yield until the output shape has been set by PropagateFixedShapes. + return false; + } + + // We require constant inputs. + if (!IsConstantParameterArray(*model, op->inputs[0]) || + !IsConstantParameterArray(*model, op->inputs[1])) { + return false; + } + const Array& input_array = model->GetArray(op->inputs[0]); + const Array& multiples_array = model->GetArray(op->inputs[1]); + CHECK(multiples_array.data_type == ArrayDataType::kInt32 || + multiples_array.data_type == ArrayDataType::kInt64) + << "Only int32/int64 indices are supported"; + + // Copy min/max info if present. The ranges of the selected values may be + // a subset of the original range but we want to ensure the quantization + // params stay the same. + if (input_array.minmax) { + const auto& input_minmax = input_array.GetMinMax(); + auto& output_minmax = output_array.GetOrCreateMinMax(); + output_minmax.min = input_minmax.min; + output_minmax.max = input_minmax.max; + } + + CHECK(!output_array.buffer); + switch (output_array.data_type) { + case ArrayDataType::kFloat: + Tile(input_array, multiples_array, &output_array); + break; + case ArrayDataType::kUint8: + Tile(input_array, multiples_array, &output_array); + break; + case ArrayDataType::kInt16: + Tile(input_array, multiples_array, &output_array); + break; + case ArrayDataType::kInt32: + Tile(input_array, multiples_array, &output_array); + break; + case ArrayDataType::kInt64: + Tile(input_array, multiples_array, &output_array); + break; + default: + LOG(FATAL) << "Unsupported data type given to Tile op with output \"" + << op->outputs[0] << "\""; + break; + } + + // Erase input arrays if no longer used after we remove the op. + DeleteArrayIfUsedOnce(op->inputs[0], model); + DeleteArrayIfUsedOnce(op->inputs[1], model); + + // Erase the operator. + model->operators.erase(it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc index f6c8f79d8d3311dc2294e3ec406a184b2a16a6b5..475415e4814387fe10cb630a84b5d0304352b1e8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -53,13 +53,13 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { case OperatorType::kCast: case OperatorType::kLog: case OperatorType::kNeg: - case OperatorType::kTensorFlowRsqrt: - case OperatorType::kTensorFlowSqrt: - case OperatorType::kTensorFlowSquare: - case OperatorType::kTensorFlowSum: - case OperatorType::kTensorFlowMin: - case OperatorType::kTensorFlowMax: - case OperatorType::kTensorFlowReshape: + case OperatorType::kRsqrt: + case OperatorType::kSqrt: + case OperatorType::kSquare: + case OperatorType::kSum: + case OperatorType::kReduceMin: // Reduction Min + case OperatorType::kReduceMax: // Reduction Max + case OperatorType::kReshape: case OperatorType::kRelu6: case OperatorType::kRelu1: case OperatorType::kRelu: @@ -103,7 +103,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { // The min-max is only copied for ops that copy data without arithmetic. // In future trivial transpose, etc, can be handled here. - if (unary_op->type == OperatorType::kTensorFlowReshape) { + if (unary_op->type == OperatorType::kReshape) { CopyMinMaxFromFirstInput(*unary_op, model); } @@ -164,10 +164,10 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { } output_float_data[i] = outval; } - } else if (unary_op->type == OperatorType::kTensorFlowReshape) { + } else if (unary_op->type == OperatorType::kReshape) { CHECK(input_buffer_size == output_buffer_size); output_float_data = *input_float_data; - } else if (unary_op->type == OperatorType::kTensorFlowSum) { + } else if (unary_op->type == OperatorType::kSum) { CHECK_EQ(unary_op->inputs.size(), 2) << "Sum needs 2 inputs"; if (!IsConstantParameterArray(*model, unary_op->inputs[1])) { AddMessageF("Axis input is non-constant"); @@ -196,7 +196,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { } output_float_data[i] = sum; } - } else if (unary_op->type == OperatorType::kTensorFlowMin) { + } else if (unary_op->type == OperatorType::kReduceMin) { // At the moment only full reduction across all dimensions is supported. // TODO(starka): Output should not be padded. for (int i = 0; i < output_dims_count; i++) { @@ -207,7 +207,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { min = std::min(min, (*input_float_data)[i]); } output_float_data[0] = min; - } else if (unary_op->type == OperatorType::kTensorFlowMax) { + } else if (unary_op->type == OperatorType::kReduceMax) { // At the moment only full reduction across all dimensions is supported. // TODO(starka): Output should not be padded. for (int i = 0; i < output_dims_count; i++) { @@ -220,9 +220,9 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { output_float_data[0] = max; } else if (unary_op->type == OperatorType::kNeg || unary_op->type == OperatorType::kLog || - unary_op->type == OperatorType::kTensorFlowRsqrt || - unary_op->type == OperatorType::kTensorFlowSqrt || - unary_op->type == OperatorType::kTensorFlowSquare) { + unary_op->type == OperatorType::kRsqrt || + unary_op->type == OperatorType::kSqrt || + unary_op->type == OperatorType::kSquare) { // Element-wise ops. Should have perfectly matching sizes here. for (int i = 0; i < output_dims_count; i++) { CHECK_EQ(output_shape.dims(i), input_shape.dims(i)); @@ -235,19 +235,19 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { outval = -val; } else if (unary_op->type == OperatorType::kLog) { outval = std::log(val); - } else if (unary_op->type == OperatorType::kTensorFlowRsqrt) { + } else if (unary_op->type == OperatorType::kRsqrt) { outval = 1.0f / std::sqrt(val); - } else if (unary_op->type == OperatorType::kTensorFlowSqrt) { + } else if (unary_op->type == OperatorType::kSqrt) { outval = std::sqrt(val); - } else if (unary_op->type == OperatorType::kTensorFlowSquare) { + } else if (unary_op->type == OperatorType::kSquare) { outval = val * val; } else { LOG(FATAL) << "should not get here."; } output_float_data[i] = outval; } - } else if (unary_op->type == OperatorType::kRelu6 && - unary_op->type == OperatorType::kRelu1 && + } else if (unary_op->type == OperatorType::kRelu6 || + unary_op->type == OperatorType::kRelu1 || unary_op->type == OperatorType::kRelu) { for (size_t i = 0; i < output_buffer_size; ++i) { const float value = (*input_float_data)[i]; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc new file mode 100644 index 0000000000000000000000000000000000000000..0dda1fd0b35fb0cdc3c605360df5126c52c05403 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc @@ -0,0 +1,80 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveFakeQuantArgsFromVars::Run(Model* model, std::size_t op_index) { + const auto fakequant_it = model->operators.begin() + op_index; + auto* fakequant_base_op = fakequant_it->get(); + if (fakequant_base_op->type != OperatorType::kFakeQuant) { + return false; + } + auto* fakequant_op = static_cast(fakequant_base_op); + + if (fakequant_op->minmax) { + // Already resolved. + return false; + } + + CHECK_EQ(fakequant_op->inputs.size(), 3); + // We need to yield until the min and max parameters have been + // resolved to constant arrays. + for (int i = 1; i <= 2; i++) { + if (!IsConstantParameterArray(*model, fakequant_op->inputs[i])) { + return false; + } + } + + // Obtain the final min/max values + const auto& min_array = model->GetArray(fakequant_op->inputs[1]); + const auto& max_array = model->GetArray(fakequant_op->inputs[2]); + CHECK_EQ(RequiredBufferSizeForShape(min_array.shape()), 1); + CHECK_EQ(RequiredBufferSizeForShape(max_array.shape()), 1); + fakequant_op->minmax.reset(new MinMax); + MinMax& minmax = *fakequant_op->minmax; + minmax.min = min_array.GetBuffer().data[0]; + minmax.max = max_array.GetBuffer().data[0]; + // We always want [min, max] to contain 0. + if (minmax.min > 0 || minmax.max < 0) { + LOG(ERROR) << "For " << LogName(*fakequant_op) << " the MinMax range " + << "[" << minmax.min << ", " << minmax.max + << "] does not contain 0. " + << "Proceeding by tweaking it to contain 0, which will result " + "in poor accuracy."; + } + minmax.min = std::min(minmax.min, 0.); + minmax.max = std::max(minmax.max, 0.); + + // We won't use the input arrays that provided these min and max + // values, anymore. Delete them unless they are used by something + // else. + for (int i = 1; i <= 2; i++) { + DeleteArrayIfUsedOnce(fakequant_op->inputs[i], model); + } + fakequant_op->inputs.resize(1); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc new file mode 100644 index 0000000000000000000000000000000000000000..ce825c91af428c866ca9f83b765399f209606af9 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveGatherAttributes::Run(Model* model, std::size_t op_index) { + auto* gather_op = model->operators[op_index].get(); + if (gather_op->type != OperatorType::kGather) return false; + auto* op = static_cast(gather_op); + + if (op->axis) { + // Attributes already resolved + return false; + } + if (op->inputs.size() != 3) return false; + if (!IsConstantParameterArray(*model, op->inputs[2])) return false; + + const auto& indices_array = model->GetArray(op->inputs[2]); + if (!indices_array.has_shape()) return false; + const auto& axis_data = indices_array.GetBuffer().data; + CHECK_EQ(axis_data.size(), 1) + << "Multidimensional gather not supported on " << LogName(*op); + op->axis = {axis_data[0]}; + + // Drop the axis array as we no longer need it. + DeleteArrayIfUsedOnce(op->inputs[2], model); + op->inputs.resize(2); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc similarity index 60% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc rename to tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc index 013b50ac9ba8a51c23b19953d987b2fbf63fcea1..7d456af2fbc69352662b798cf1314f1653ef9f98 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc @@ -24,11 +24,8 @@ limitations under the License. namespace toco { -bool ResolveMeanAttributes::Run(Model* model, std::size_t op_index) { - auto* mean_op = model->operators[op_index].get(); - if (mean_op->type != OperatorType::kMean) return false; - auto* op = static_cast(mean_op); - +template +bool ResolveAttributes(Model* model, T* op) { if (!op->axis.empty()) { // Attributes already resolved return false; @@ -36,10 +33,28 @@ bool ResolveMeanAttributes::Run(Model* model, std::size_t op_index) { if (op->inputs.size() != 2) return false; if (!IsConstantParameterArray(*model, op->inputs[1])) return false; - const auto& indices_array = model->GetArray(op->inputs[1]); + const Array& indices_array = model->GetArray(op->inputs[1]); if (!indices_array.has_shape()) return false; op->axis = indices_array.GetBuffer().data; return true; } +bool ResolveReduceAttributes::Run(Model* model, std::size_t op_index) { + Operator* op = model->operators[op_index].get(); + switch (op->type) { + case OperatorType::kMean: + return ResolveAttributes(model, static_cast(op)); + case OperatorType::kSum: + return ResolveAttributes(model, static_cast(op)); + case OperatorType::kReduceProd: + return ResolveAttributes(model, static_cast(op)); + case OperatorType::kReduceMin: + return ResolveAttributes(model, static_cast(op)); + case OperatorType::kReduceMax: + return ResolveAttributes(model, static_cast(op)); + default: + return false; + } +} + } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc index bc70db0bd8c26319fa140616de96452260a01058..8266e2c205b65e9d8a969643f102bb852be9125b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc @@ -51,11 +51,12 @@ void ReorderAxes(AxesOrder input_axes_order, AxesOrder output_axes_order, } bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { - auto reorder_it = model->operators.begin() + op_index; - auto* reorder_op = static_cast(reorder_it->get()); - if (reorder_op->type != OperatorType::kReorderAxes) { + auto it = model->operators.begin() + op_index; + auto* op = it->get(); + if (op->type != OperatorType::kReorderAxes) { return false; } + auto* reorder_op = static_cast(op); const auto& input_array_name = reorder_op->inputs[0]; const auto& output_array_name = reorder_op->outputs[0]; auto& input_array = model->GetArray(input_array_name); @@ -95,7 +96,7 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { // Remove the op and output array. model->EraseArray(output_array_name); - model->operators.erase(reorder_it); + model->operators.erase(it); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc index 2e063e35548aa5e51c3bcc94a2dfc7992180d014..b615c9a545695e5d14fa5809e0c38a770f23ea24 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc @@ -28,7 +28,7 @@ namespace toco { bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) { const auto reshape_it = model->operators.begin() + op_index; auto* reshape_op = reshape_it->get(); - if (reshape_op->type != OperatorType::kTensorFlowReshape) { + if (reshape_op->type != OperatorType::kReshape) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc index dad6aceccfd201b3db07c29c99a8c6ef75bb89a1..fab50bec1fc5ec50cecba53845457931ed59c0b8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc @@ -53,7 +53,7 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) { // will delete this op. return false; } - std::vector paddings_buffer = + const std::vector& paddings_buffer = paddings_array.GetBuffer().data; for (int i = 0; i < paddings_dims[0]; ++i) { op->before_paddings.push_back(paddings_buffer[i * 2]); @@ -66,7 +66,7 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) { if (!block_shape_array.has_shape()) return false; const std::vector& block_shape_dims = block_shape_array.shape().dims(); CHECK_EQ(block_shape_dims.size(), 1); - std::vector block_shape_buffer = + const std::vector& block_shape_buffer = block_shape_array.GetBuffer().data; for (int i = 0; i < block_shape_dims[0]; ++i) { op->block_shape.push_back(block_shape_buffer[i]); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc index dd3e73635ae0215510f0a8d1aee487da5af35700..e8bb85704e1c750300079681b5a12f6a488b6b48 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc @@ -36,7 +36,7 @@ bool ResolveSqueezeAttributes::Run(Model* model, std::size_t op_index) { // If the output is consumed by a reshape op, it's a trivial squeeze. if (CountOpsWithInput(*model, squeeze_op->outputs[0]) == 1) { const auto* next_op = GetOpWithInput(*model, squeeze_op->outputs[0]); - if (next_op->type == OperatorType::kTensorFlowReshape) { + if (next_op->type == OperatorType::kReshape) { AddMessageF( "%s is trivial because its output is only consumed by a " "Reshape op", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc index 5c0c1e3478fa0d94104d1b76bab176b98b314c50..fa5ee899334bdf2d39a6861b0e0c4548142e9d2a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc @@ -28,8 +28,8 @@ namespace toco { bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) { auto concat_it = model->operators.begin() + op_index; const auto* tf_concat_op = concat_it->get(); - if (tf_concat_op->type != OperatorType::kTensorFlowConcat && - tf_concat_op->type != OperatorType::kTensorFlowConcatV2) { + if (tf_concat_op->type != OperatorType::kConcat && + tf_concat_op->type != OperatorType::kConcatV2) { return false; } @@ -38,7 +38,7 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) { // of inputs: in Concat,the axis is the first input, while in // ConcatV2, it is the last input. std::size_t axis_pos = 0; - if (tf_concat_op->type == OperatorType::kTensorFlowConcatV2) { + if (tf_concat_op->type == OperatorType::kConcatV2) { axis_pos = tf_concat_op->inputs.size() - 1; } const string axis_name = tf_concat_op->inputs[axis_pos]; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc index 2a236d3f98784e8244942f94d5a250b5bc00a8ad..fcf30bd34725fc59bb819e75deda0dadf330f372 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc @@ -26,27 +26,40 @@ namespace toco { bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { auto matmul_it = model->operators.begin() + op_index; - if (matmul_it->get()->type != OperatorType::kTensorFlowMatMul) { + if (matmul_it->get()->type != OperatorType::kMatMul) { return false; } const auto* matmul_op = static_cast(matmul_it->get()); + // Handling transposition of the first input here isn't very simple because + // we need to know the actual shape in order to produce a proper + // TransposeOperator. However, the second input is supposed to be 2D, so we + // can actually handle transposition of that matrix, which happens to be more + // common anyway. + CHECK(!matmul_op->transpose_a); + // Reorder the axes on the second input. TensorFlow uses row-major ordering // on both inputs, however this is inefficient for the FullyConnected // operator. We'll transpose the second input to be in column-major order now // and let constant propagation optimize things (if possible). - auto* transpose_op = new TransposeOperator; - transpose_op->inputs = { - matmul_op->inputs[1], - CreateInt32Array( - model, - AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose/perm"), - {1, 0})}; - transpose_op->outputs = { - AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")}; - model->GetOrCreateArray(transpose_op->outputs[0]); - model->operators.emplace(matmul_it, transpose_op); + string input_lhs = matmul_op->inputs[0]; + string input_rhs = matmul_op->inputs[1]; + if (!matmul_op->transpose_b) { + auto* transpose_op = new TransposeOperator; + transpose_op->inputs = { + matmul_op->inputs[1], + CreateInt32Array(model, + AvailableArrayName( + *model, matmul_op->inputs[1] + "/transpose/perm"), + {1, 0})}; + transpose_op->outputs = { + AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")}; + model->GetOrCreateArray(transpose_op->outputs[0]); + model->operators.emplace(matmul_it, transpose_op); + + input_rhs = transpose_op->outputs[0]; + } // Refresh iterator. matmul_it = model->operators.begin(); @@ -57,9 +70,6 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { } DCHECK_EQ(matmul_it->get(), matmul_op); - string input_lhs = matmul_op->inputs[0]; - string input_rhs = transpose_op->outputs[0]; - // Construct the new FullyConnectedOperator. auto* fc_op = new FullyConnectedOperator; fc_op->outputs = matmul_op->outputs; @@ -97,7 +107,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { // MatMul op as a FullyConnected. However, TensorFlow skips the Reshape ops if // the input doesn't need reshaping, so we can't just match (Reshape, MatMul) // pairs. - if (previous_op && previous_op->type == OperatorType::kTensorFlowReshape) { + if (previous_op && previous_op->type == OperatorType::kReshape) { AddMessageF("Combining %s and %s into %s", LogName(*previous_op), LogName(*matmul_op), LogName(*fc_op)); const auto& previous_op_output = previous_op->outputs[0]; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc index 38e0005890ac10410df4ddb5290be8fcc948c349..4edffe3d48fd880c0261b34fc407b8e2ac66ccb9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc @@ -27,7 +27,7 @@ namespace toco { bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) { const auto merge_it = model->operators.begin() + op_index; const auto* merge_op = merge_it->get(); - if (merge_op->type != OperatorType::kTensorFlowMerge) { + if (merge_op->type != OperatorType::kMerge) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc index a418073441f1241a5acb1164b36f332828ea2e99..8bef440afd21572d7014e4f376be3aba2d80127d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc @@ -27,7 +27,7 @@ namespace toco { bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) { const auto switch_it = model->operators.begin() + op_index; const auto* switch_op = switch_it->get(); - if (switch_op->type != OperatorType::kTensorFlowSwitch) { + if (switch_op->type != OperatorType::kSwitch) { return false; } @@ -92,7 +92,9 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) { if (*input_it == switch_op->outputs[nonselected_output_index]) { // Let us guard our assumption that only Merge nodes consume the outputs // of Switch nodes: - CHECK(other_op->type == OperatorType::kTensorFlowMerge); + CHECK(other_op->type == OperatorType::kMerge) + << "Found " << HelpfulOperatorTypeName(*other_op) + << " as non-selected output from Switch, but only Merge supported."; input_it = other_op->inputs.erase(input_it); } else { ++input_it; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc deleted file mode 100644 index 1ddf54c778cd1fae7a8fce0ecb97209274e71ac0..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include -#include -#include -#include - -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" - -namespace toco { - -namespace { - -void RemoveTileOperator(Model* model, Operator* tile_op, Operator* binary_op, - int operand_index) { - CHECK(tile_op->type == OperatorType::kTensorFlowTile); - CHECK_EQ(binary_op->inputs.size(), 2); - CHECK_EQ(tile_op->inputs.size(), 2); - const string tile_multiplier_array = tile_op->inputs[1]; - const string tile_output_array = tile_op->outputs[0]; - binary_op->inputs[operand_index] = tile_op->inputs[0]; - auto tile_it = model->operators.begin(); - for (; tile_it != model->operators.end(); ++tile_it) { - if (tile_it->get() == tile_op) { - break; - } - } - CHECK(tile_it != model->operators.end()); - CHECK(tile_it->get() == tile_op); - model->operators.erase(tile_it); - if (!CountOpsWithInput(*model, tile_multiplier_array) && - !GetOpWithOutput(*model, tile_multiplier_array)) { - model->EraseArray(tile_multiplier_array); - } - if (!CountOpsWithInput(*model, tile_output_array)) { - model->EraseArray(tile_output_array); - } -} -} // namespace - -bool ResolveTensorFlowTile::Run(Model* model, std::size_t op_index) { - const auto binary_it = model->operators.begin() + op_index; - auto* binary_op = binary_it->get(); - // Test for binary ops of types that we know how to resolve - if (binary_op->inputs.size() != 2) { - return false; - } - if (binary_op->type != OperatorType::kAdd && - binary_op->type != OperatorType::kMul && - binary_op->type != OperatorType::kSub && - binary_op->type != OperatorType::kDiv) { - return false; - } - - Operator* const op[2] = { - GetOpWithOutput(*model, binary_op->inputs[0]), - GetOpWithOutput(*model, binary_op->inputs[1]), - }; - - // In the unlikely case where both operands are Tile, we can't infer the - // output - // size without the Tile nodes, so we have to bail out. - if (op[0] && op[0]->type == OperatorType::kTensorFlowTile && op[1] && - op[1]->type == OperatorType::kTensorFlowTile) { - return false; - } - - for (int i = 0; i < 2; i++) { - if (op[i] && op[i]->type == OperatorType::kTensorFlowTile) { - // We can only remove a Tile operator is no other op than the present - // binary op was consuming its tiled output. - if (CountOpsWithInput(*model, binary_op->inputs[i]) == 1) { - AddMessageF("Removing %s", LogName(*op[i])); - RemoveTileOperator(model, op[i], binary_op, i); - return true; - } - } - } - return false; -} - -} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/experimental_shuffle_fc_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc similarity index 96% rename from tensorflow/contrib/lite/toco/graph_transformations/experimental_shuffle_fc_weights.cc rename to tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc index c00cdcb944b085dda41033b95c96537cc2e047c3..22c258cec5fde4144c4b048d5ec60a8604362cbb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/experimental_shuffle_fc_weights.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc @@ -24,14 +24,14 @@ limitations under the License. namespace toco { -bool ExperimentalShuffleFCWeights::Run(Model* model, std::size_t op_index) { +bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { Operator* op = model->operators[op_index].get(); if (op->type != OperatorType::kFullyConnected) { return false; } FullyConnectedOperator* fc_op = static_cast(op); // Exit if this FC op already has shuffled weights - if (fc_op->experimental_shuffled_weights) { + if (fc_op->weights_format != FullyConnectedWeightsFormat::kDefault) { return false; } const Array& input_array = model->GetArray(fc_op->inputs[0]); @@ -135,7 +135,7 @@ bool ExperimentalShuffleFCWeights::Run(Model* model, std::size_t op_index) { CHECK_EQ(shuffled_data_ptr, shuffled_data.data() + rows * cols); // Switch this FC op to using the shuffled weights. weights_data = std::move(shuffled_data); - fc_op->experimental_shuffled_weights = true; + fc_op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8; AddMessageF("Applied experimental shuffling to the weights of %s", LogName(*op)); // Add a second output array to this FC op, serving as a workspace to perform diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD index 95e8433be2a332cfce5175f4f65ea0b83d5638c5..e163fc9ae1422504ef1b0a3c567c420f649f0827 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD @@ -10,6 +10,7 @@ load( tf_cc_test( name = "lstm_utils_test", srcs = ["lstm_utils_test.cc"], + tags = ["no_oss"], deps = [ "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", @@ -21,6 +22,7 @@ tf_cc_test( tf_cc_test( name = "quantize_weights_test", srcs = ["quantize_weights_test.cc"], + tags = ["no_oss"], deps = [ "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", @@ -33,6 +35,7 @@ tf_cc_test( tf_cc_test( name = "resolve_constant_concatenation_test", srcs = ["resolve_constant_concatenation_test.cc"], + tags = ["no_oss"], deps = [ "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc index 2c7046c8c77c94a89fc05a26d7d72b3661380475..69bad2fa89cb89cd74e3a4bca98da906a322a670 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc @@ -64,7 +64,14 @@ bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) { const string& tmp_array_name = AvailableArrayName(*model, op->outputs[0] + "_unfused"); CHECK(!model->HasArray(tmp_array_name)); - model->GetOrCreateArray(tmp_array_name); + + const auto& output_array = model->GetArray(op->outputs[0]); + auto& tmp_array = model->GetOrCreateArray(tmp_array_name); + if (output_array.quantization_params) { + tmp_array.GetOrCreateQuantizationParams() = + output_array.GetQuantizationParams(); + } + ac_op->inputs = {tmp_array_name}; op->outputs = {tmp_array_name}; return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc b/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc index cbea39bcc09ea6787c055d5aaca7f291c2b47a7f..dd9e26e68bd7e6d5cb751fdbf705b861c3f2f188 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc @@ -187,6 +187,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { AvailableArrayName(*model, gather_ops[0]->inputs[0] + "_permuted/perm")); gather_params_permute_op->outputs.push_back( AvailableArrayName(*model, gather_ops[0]->inputs[0] + "_permuted")); + gather_params_permute_op->axis = {0}; op_it = model->operators.emplace(op_it, gather_params_permute_op) + 1; model->GetOrCreateArray(gather_params_permute_op->outputs[0]); const auto& partition_array = model->GetArray(gather_ops[0]->inputs[0]); @@ -212,6 +213,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { mod_op->inputs[0]}; merged_gather_op->outputs = {stitch_op->outputs[0]}; merged_gather_op->input_rank = partition_array.shape().dimensions_count(); + merged_gather_op->axis = {0}; model->operators.emplace(op_it, merged_gather_op); AddMessageF( diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc index da81ea2ff3b4ab0bee0550874a9c4ea1044a3579..5f0cece67a49de6d50fd08896d14d3f27df46b44 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc @@ -76,7 +76,7 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) { AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op), batch_count); auto tail_it = batch_op_it; - std::vector stack_inputs; + std::vector pack_inputs; for (int batch = 0; batch < batch_count; ++batch) { std::string batch_name = std::string(batch_op->outputs[0]) + "_b" + std::to_string(batch); @@ -146,15 +146,15 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) { tail_it = model->operators.emplace(tail_it, matmul_op) + 1; // Add to stack. - stack_inputs.push_back(matmul_op->outputs[0]); + pack_inputs.push_back(matmul_op->outputs[0]); } - // The stack that will join all the individual matmul results together. - auto* stack_op = new StackOperator; - stack_op->inputs = stack_inputs; - stack_op->outputs = {batch_op->outputs[0]}; - stack_op->axis = 0; - model->operators.emplace(tail_it, stack_op); + // The pack that will join all the individual matmul results together. + auto* pack_op = new PackOperator; + pack_op->inputs = pack_inputs; + pack_op->outputs = {batch_op->outputs[0]}; + pack_op->axis = 0; + model->operators.emplace(tail_it, pack_op); // Remove the old batch matmul now that we've unrolled. batch_op_it = model->operators.begin(); diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index cd4f034dfea57b6d379b67a90ba4fa3fe3d615d5..b7fffbce2223a71ac1e16ec1ce18ba9f610cc2ac 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -215,7 +215,7 @@ tensorflow::Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_FLOAT); const auto& input_shape = input_tensor.tensor_shape(); - CHECK_LE(input_shape.dim_size(), 4); + CHECK_LE(input_shape.dim_size(), 6); int input_flat_size; auto status = ImportShape(input_shape.dim(), &input_flat_size, output_array->mutable_shape()); @@ -253,7 +253,7 @@ tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_QUINT8); const auto& input_shape = input_tensor.tensor_shape(); - CHECK_LE(input_shape.dim_size(), 4); + CHECK_LE(input_shape.dim_size(), 6); int input_flat_size; auto status = ImportShape(input_shape.dim(), &input_flat_size, output_array->mutable_shape()); @@ -263,7 +263,11 @@ tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor, output_array->GetMutableBuffer().data; output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); CHECK_GE(output_int_data.size(), input_flat_size); - if (input_tensor.int_val_size()) { + if (input_tensor.int_val_size() == 1) { + for (int i = 0; i < input_flat_size; i++) { + output_int_data[i] = input_tensor.int_val(0); + } + } else if (input_tensor.int_val_size() == input_flat_size) { for (int i = 0; i < input_tensor.int_val_size(); i++) { output_int_data[i] = input_tensor.int_val(i); } @@ -286,7 +290,7 @@ tensorflow::Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT32); const auto& input_shape = input_tensor.tensor_shape(); - CHECK_LE(input_shape.dim_size(), 4); + CHECK_LE(input_shape.dim_size(), 6); int input_flat_size; auto status = ImportShape(input_shape.dim(), &input_flat_size, output_array->mutable_shape()); @@ -296,7 +300,11 @@ tensorflow::Status ImportInt32Array(const TensorProto& input_tensor, output_array->GetMutableBuffer().data; output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); CHECK_GE(output_int_data.size(), input_flat_size); - if (input_tensor.int_val_size()) { + if (input_tensor.int_val_size() == 1) { + for (int i = 0; i < input_flat_size; i++) { + output_int_data[i] = input_tensor.int_val(0); + } + } else if (input_tensor.int_val_size() == input_flat_size) { for (int i = 0; i < input_tensor.int_val_size(); i++) { output_int_data[i] = input_tensor.int_val(i); } @@ -318,7 +326,7 @@ tensorflow::Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT64); const auto& input_shape = input_tensor.tensor_shape(); - CHECK_LE(input_shape.dim_size(), 4); + CHECK_LE(input_shape.dim_size(), 6); int input_flat_size; auto status = ImportShape(input_shape.dim(), &input_flat_size, output_array->mutable_shape()); @@ -328,8 +336,12 @@ tensorflow::Status ImportInt64Array(const TensorProto& input_tensor, output_array->GetMutableBuffer().data; output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0); CHECK_GE(output_int_data.size(), input_flat_size); - if (input_tensor.int64_val_size()) { - for (int i = 0; i < input_tensor.int64_val_size(); i++) { + if (input_tensor.int64_val_size() == 1) { + for (int i = 0; i < input_flat_size; i++) { + output_int_data[i] = input_tensor.int64_val(0); + } + } else if (input_tensor.int64_val_size() == input_flat_size) { + for (int i = 0; i < input_tensor.float_val_size(); i++) { output_int_data[i] = input_tensor.int64_val(i); } } else if (input_tensor.tensor_content().size() == @@ -351,7 +363,7 @@ tensorflow::Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_BOOL); const auto& input_shape = input_tensor.tensor_shape(); - CHECK_LE(input_shape.dim_size(), 4); + CHECK_LE(input_shape.dim_size(), 6); int input_flat_size; auto status = ImportShape(input_shape.dim(), &input_flat_size, output_array->mutable_shape()); @@ -362,7 +374,11 @@ tensorflow::Status ImportBoolArray(const TensorProto& input_tensor, output_bool_data.resize(RequiredBufferSizeForShape(output_array->shape()), false); CHECK_GE(output_bool_data.size(), input_flat_size); - if (input_tensor.bool_val_size()) { + if (input_tensor.bool_val_size() == 1) { + for (int i = 0; i < input_flat_size; i++) { + output_bool_data[i] = input_tensor.bool_val(0); + } + } else if (input_tensor.bool_val_size() == input_flat_size) { for (int i = 0; i < input_tensor.bool_val_size(); i++) { output_bool_data[i] = input_tensor.bool_val(i); } @@ -393,7 +409,7 @@ tensorflow::Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_STRING); const auto& input_shape = input_tensor.tensor_shape(); - CHECK_LE(input_shape.dim_size(), 4); + CHECK_LE(input_shape.dim_size(), 6); int input_flat_size; auto status = ImportShape(input_shape.dim(), &input_flat_size, output_array->mutable_shape()); @@ -426,18 +442,19 @@ int GetInputsCount(const NodeDef& node, return i; } } - return node.input_size(); - } else { - return node.input_size(); } + return node.input_size(); } -void CheckInputsCount(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - int expected_input_count) { - QCHECK_EQ(GetInputsCount(node, tf_import_flags), expected_input_count) - << node.op() << " node expects " << expected_input_count - << " input(s) other than control dependencies: " << node.DebugString(); +tensorflow::Status CheckInputsCount( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + int expected_input_count) { + if (GetInputsCount(node, tf_import_flags) != expected_input_count) { + return tensorflow::errors::FailedPrecondition( + node.op(), " node expects ", expected_input_count, + " input(s) other than control dependencies: ", node.DebugString()); + } + return tensorflow::Status::OK(); } template @@ -504,7 +521,7 @@ tensorflow::Status ConvertConvOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { CHECK_EQ(node.op(), "Conv2D"); - CheckInputsCount(node, tf_import_flags, 2); + TF_RETURN_IF_ERROR(CheckInputsCount(node, tf_import_flags, 2)); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. @@ -574,11 +591,11 @@ tensorflow::Status ConvertConvOperator( return tensorflow::Status::OK(); } -void ConvertDepthwiseConvOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertDepthwiseConvOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "DepthwiseConv2dNative"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. @@ -625,13 +642,14 @@ void ConvertDepthwiseConvOperator(const NodeDef& node, LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; } model->operators.emplace_back(conv); + return tensorflow::Status::OK(); } -void ConvertDepthToSpaceOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertDepthToSpaceOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "DepthToSpace"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); auto* op = new DepthToSpaceOperator; @@ -640,13 +658,14 @@ void ConvertDepthToSpaceOperator(const NodeDef& node, op->block_size = GetIntAttr(node, "block_size"); QCHECK_GE(op->block_size, 2); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSpaceToDepthOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSpaceToDepthOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "SpaceToDepth"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); tensorflow::DataType dtype = GetDataTypeAttr(node, "T"); if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 && @@ -662,13 +681,14 @@ void ConvertSpaceToDepthOperator(const NodeDef& node, op->block_size = GetIntAttr(node, "block_size"); QCHECK_GE(op->block_size, 2); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertBiasAddOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertBiasAddOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "BiasAdd"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); const auto& input_name = node.input(0); const auto& bias_name = node.input(1); @@ -678,13 +698,14 @@ void ConvertBiasAddOperator(const NodeDef& node, biasadd->inputs.push_back(bias_name); biasadd->outputs.push_back(node.name()); model->operators.emplace_back(biasadd); + return tensorflow::Status::OK(); } -void ConvertRandomUniform(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertRandomUniform( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "RandomUniform"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); CHECK_EQ(GetDataTypeAttr(node, "T"), DT_INT32); auto op = absl::make_unique(); @@ -695,11 +716,12 @@ void ConvertRandomUniform(const NodeDef& node, op->seed2 = GetIntAttr(node, "seed2"); CHECK(model != nullptr); model->operators.emplace_back(std::move(op)); + return tensorflow::Status::OK(); } -void ConvertIdentityOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertIdentityOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" || node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient"); auto* op = new TensorFlowIdentityOperator; @@ -716,13 +738,14 @@ void ConvertIdentityOperator(const NodeDef& node, op->inputs.push_back(input_name); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertFakeQuantWithMinMaxArgs( +tensorflow::Status ConvertFakeQuantWithMinMaxArgs( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); auto* op = new FakeQuantOperator; op->inputs.push_back(node.input(0)); op->minmax.reset(new MinMax); @@ -732,10 +755,14 @@ void ConvertFakeQuantWithMinMaxArgs( op->outputs.push_back(node.name()); // tf.fake_quant_with_min_max_args num_bits defaults to 8. op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8; + if (HasAttr(node, "narrow_range")) { + op->narrow_range = GetBoolAttr(node, "narrow_range"); + } model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertFakeQuantWithMinMaxVars( +tensorflow::Status ConvertFakeQuantWithMinMaxVars( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars"); @@ -750,15 +777,18 @@ void ConvertFakeQuantWithMinMaxVars( } op->outputs.push_back(node.name()); op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8; + if (HasAttr(node, "narrow_range")) { + op->narrow_range = GetBoolAttr(node, "narrow_range"); + } model->operators.emplace_back(op); + return tensorflow::Status::OK(); } - -void ConvertSqueezeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSqueezeOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Squeeze"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); auto* op = new SqueezeOperator; op->inputs.push_back(node.input(0)); op->outputs.push_back(node.name()); @@ -772,28 +802,14 @@ void ConvertSqueezeOperator(const NodeDef& node, } model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSumOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sum"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowSumOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); - if (HasAttr(node, "keep_dims")) { - op->keep_dims = GetBoolAttr(node, "keep_dims"); - } -} - -void ConvertSplitOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSplitOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Split"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowSplitOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -804,13 +820,14 @@ void ConvertSplitOperator(const NodeDef& node, } op->num_split = num_split; model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSwitchOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSwitchOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Switch"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowSwitchOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -818,13 +835,14 @@ void ConvertSwitchOperator(const NodeDef& node, // Switch operators have two outputs: "name" and "name:1". op->outputs.push_back(node.name() + ":1"); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSoftmaxOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSoftmaxOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Softmax"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); auto* softmax = new SoftmaxOperator; softmax->inputs.push_back(input_name); @@ -833,13 +851,14 @@ void ConvertSoftmaxOperator(const NodeDef& node, CHECK(!node.attr().count("beta")); // Stab in the dark, just in case. softmax->beta = 1.f; model->operators.emplace_back(softmax); + return tensorflow::Status::OK(); } -void ConvertLRNOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertLRNOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "LRN"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); auto* lrn = new LocalResponseNormalizationOperator; lrn->inputs.push_back(input_name); @@ -849,13 +868,14 @@ void ConvertLRNOperator(const NodeDef& node, lrn->alpha = GetFloatAttr(node, "alpha"); lrn->beta = GetFloatAttr(node, "beta"); model->operators.emplace_back(lrn); + return tensorflow::Status::OK(); } -void ConvertMaxPoolOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertMaxPoolOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "MaxPool"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. @@ -891,13 +911,14 @@ void ConvertMaxPoolOperator(const NodeDef& node, LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; } model->operators.emplace_back(maxpool); + return tensorflow::Status::OK(); } -void ConvertAvgPoolOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertAvgPoolOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "AvgPool"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. @@ -929,13 +950,13 @@ void ConvertAvgPoolOperator(const NodeDef& node, LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; } model->operators.emplace_back(avgpool); + return tensorflow::Status::OK(); } - -void ConvertBatchMatMulOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CheckInputsCount(node, tf_import_flags, 2); +tensorflow::Status ConvertBatchMatMulOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); // https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops/matrix_math_functions CHECK(!HasAttr(node, "adj_a") || (GetBoolAttr(node, "adj_a") == false)); @@ -945,33 +966,36 @@ void ConvertBatchMatMulOperator(const NodeDef& node, batch_matmul->inputs = {node.input(0), node.input(1)}; batch_matmul->outputs = {node.name()}; model->operators.emplace_back(batch_matmul); + return tensorflow::Status::OK(); } -void ConvertMatMulOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CheckInputsCount(node, tf_import_flags, 2); +tensorflow::Status ConvertMatMulOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - // Transpose flags should be easy to support, but we don't have a - // GraphDef with them to test on at the moment. - CHECK_EQ(HasAttr(node, "transpose_a") && GetBoolAttr(node, "transpose_a"), - false); - CHECK_EQ(HasAttr(node, "transpose_b") && GetBoolAttr(node, "transpose_b"), - false); CHECK(!HasAttr(node, "adjoint_a") || (GetBoolAttr(node, "adjoint_a") == false)); CHECK(!HasAttr(node, "adjoint_b") || (GetBoolAttr(node, "adjoint_b") == false)); auto* matmul = new TensorFlowMatMulOperator; + if (HasAttr(node, "transpose_a")) { + matmul->transpose_a = GetBoolAttr(node, "transpose_a"); + } + if (HasAttr(node, "transpose_b")) { + matmul->transpose_b = GetBoolAttr(node, "transpose_b"); + } + matmul->inputs = {node.input(0), node.input(1)}; matmul->outputs = {node.name()}; model->operators.emplace_back(matmul); + return tensorflow::Status::OK(); } -void ConvertConcatOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertConcatOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { Operator* op = nullptr; if (node.op() == "Concat") { op = new TensorFlowConcatOperator; @@ -991,13 +1015,14 @@ void ConvertConcatOperator(const NodeDef& node, } op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } // This method supports simple operators without additional attributes. template -void ConvertSimpleOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSimpleOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { auto* op = new Op; const int num_inputs = GetInputsCount(node, tf_import_flags); for (int i = 0; i < num_inputs; ++i) { @@ -1005,51 +1030,28 @@ void ConvertSimpleOperator(const NodeDef& node, } op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } // This method supports simple operators without additional attributes. template -void ConvertSimpleOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CheckInputsCount(node, tf_import_flags, NumInputs); - ConvertSimpleOperator(node, tf_import_flags, model); -} - -void ConvertMaxOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Max"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowMaxOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); - if (HasAttr(node, "keep_dims")) { - op->keep_dims = GetBoolAttr(node, "keep_dims"); - } -} - -void ConvertMinOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Min"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowMinOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); - if (HasAttr(node, "keep_dims")) { - op->keep_dims = GetBoolAttr(node, "keep_dims"); - } +tensorflow::Status ConvertSimpleOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, NumInputs)); + return ConvertSimpleOperator(node, tf_import_flags, model); } +tensorflow::Status ConvertUnsupportedOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + // Names of special attributes in TF graph that are used by Toco. + static constexpr char kAttrOutputQuantized[] = "_output_quantized"; + static constexpr char kAttrOutputTypes[] = "_output_types"; + static constexpr char kAttrOutputShapes[] = "_output_shapes"; + static constexpr char kAttrSupportOutputTypeFloatInQuantizedOp[] = + "_support_output_type_float_in_quantized_op"; -void ConvertUnsupportedOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { LOG(INFO) << "Converting unsupported operation: " << node.op(); auto* op = new TensorFlowUnsupportedOperator; const int num_inputs = GetInputsCount(node, tf_import_flags); @@ -1060,11 +1062,17 @@ void ConvertUnsupportedOperator(const NodeDef& node, op->tensorflow_op = node.op(); node.SerializeToString(&op->tensorflow_node_def); model->operators.emplace_back(op); - if (HasAttr(node, "_output_quantized")) { - op->quantized = GetBoolAttr(node, "_output_quantized"); + // Parse if the op supports quantization + if (HasAttr(node, kAttrOutputQuantized)) { + op->quantized = GetBoolAttr(node, kAttrOutputQuantized); + } + // Parse if the quantized op allows output arrays of type float + if (HasAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp)) { + op->support_output_type_float_in_quantized_op = + GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp); } - if (HasAttr(node, "_output_types")) { - const auto& output_types = GetListAttr(node, "_output_types"); + if (HasAttr(node, kAttrOutputTypes)) { + const auto& output_types = GetListAttr(node, kAttrOutputTypes); for (int i = 0; i < output_types.type_size(); ++i) { op->output_data_types.push_back(ConvertDataType(output_types.type(i))); } @@ -1072,15 +1080,29 @@ void ConvertUnsupportedOperator(const NodeDef& node, const auto& output_type = GetDataTypeAttr(node, "Tout"); op->output_data_types.push_back(ConvertDataType(output_type)); } + if (HasAttr(node, kAttrOutputShapes)) { + const auto& output_shapes = GetListAttr(node, kAttrOutputShapes); + Shape output_shape; + for (int i = 0; i < output_shapes.shape_size(); ++i) { + const auto status = + ImportShape(output_shapes.shape(i).dim(), /*input_flat_size=*/nullptr, + &output_shape); + if (!status.ok()) { + return status; + } + op->output_shapes.push_back(output_shape); + } + } + return tensorflow::Status::OK(); } -void ConvertStridedSliceOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertStridedSliceOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "StridedSlice"); // TODO(soroosh): The 4th input (strides) should be e optional, to be // consistent with TF. - CheckInputsCount(node, tf_import_flags, 4); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4)); auto* op = new StridedSliceOperator; for (const auto& input : node.input()) { @@ -1100,14 +1122,15 @@ void ConvertStridedSliceOperator(const NodeDef& node, : 0; model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertPlaceholderOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertPlaceholderOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput"); if (node.op() == "Placeholder") { - CheckInputsCount(node, tf_import_flags, 0); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 0)); } auto& array = model->GetOrCreateArray(node.name()); if (node.attr().count("dtype")) { @@ -1132,17 +1155,20 @@ void ConvertPlaceholderOperator(const NodeDef& node, } } } + return tensorflow::Status::OK(); } -void ConvertNoOpOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) {} +tensorflow::Status ConvertNoOpOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + return tensorflow::Status::OK(); +} -void ConvertCastOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertCastOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Cast"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT"); const auto tf_dst_dtype = GetDataTypeAttr(node, "DstT"); auto* op = new CastOperator; @@ -1151,43 +1177,57 @@ void ConvertCastOperator(const NodeDef& node, op->inputs.push_back(node.input(0)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertFloorOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertFloorOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Floor"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto data_type = GetDataTypeAttr(node, "T"); CHECK(data_type == DT_FLOAT); auto* op = new FloorOperator; op->inputs.push_back(node.input(0)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertGatherOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertGatherOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK(node.op() == "Gather" || node.op() == "GatherV2"); - if (node.op() == "Gather") CheckInputsCount(node, tf_import_flags, 2); - if (node.op() == "GatherV2") CheckInputsCount(node, tf_import_flags, 3); + if (node.op() == "Gather") + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); + if (node.op() == "GatherV2") + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); const auto indices_data_type = GetDataTypeAttr(node, "Tindices"); CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64); auto* op = new GatherOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); - // TODO(ahentz): we currently ignore the third tensor in GatherV2 but we - // should read it an pass it on to the TF Lite Interpreter. + if (node.input_size() >= 3) { + // GatherV2 form where we are provided an axis. It may be either a constant + // or runtime defined value, so we just wire up the array and let + // ResolveGatherAttributes take care of it later on. + const auto axis_data_type = GetDataTypeAttr(node, "Taxis"); + CHECK(axis_data_type == DT_INT32 || axis_data_type == DT_INT64); + op->inputs.push_back(node.input(2)); + } else { + // Gather form that assumes axis=0. + op->axis = {0}; + } op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertArgMaxOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "ArgMax"); - CheckInputsCount(node, tf_import_flags, 2); +template +tensorflow::Status ConvertArgMinMaxOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); const auto axis_data_type = HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32; const auto output_type = HasAttr(node, "output_type") @@ -1195,19 +1235,34 @@ void ConvertArgMaxOperator(const NodeDef& node, : DT_INT64; CHECK(axis_data_type == DT_INT64 || axis_data_type == DT_INT32); CHECK(output_type == DT_INT64 || output_type == DT_INT32); - auto* op = new ArgMaxOperator; + auto* op = new Op; op->output_data_type = ConvertDataType(output_type); op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertArgMaxOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "ArgMax"); + return ConvertArgMinMaxOperator(node, tf_import_flags, model); +} + +tensorflow::Status ConvertArgMinOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "ArgMin"); + return ConvertArgMinMaxOperator(node, tf_import_flags, model); } -void ConvertResizeBilinearOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertResizeBilinearOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "ResizeBilinear"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new ResizeBilinearOperator; op->align_corners = false; @@ -1219,13 +1274,14 @@ void ConvertResizeBilinearOperator(const NodeDef& node, op->inputs.push_back(node.input(1)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertBatchNormWithGlobalNormalizationOperator( +tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization"); - CheckInputsCount(node, tf_import_flags, 5); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5)); // TODO(ahentz): to really match tensorflow we need to add variance_epsilon // to the input, before feeding it into TensorFlowRsqrtOperator. @@ -1268,13 +1324,14 @@ void ConvertBatchNormWithGlobalNormalizationOperator( op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertFusedBatchNormOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertFusedBatchNormOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "FusedBatchNorm"); - CheckInputsCount(node, tf_import_flags, 5); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5)); // Declare shortcuts for the inputs. const string& gamma_input = node.input(1); @@ -1320,13 +1377,14 @@ void ConvertFusedBatchNormOperator(const NodeDef& node, op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSpaceToBatchNDOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSpaceToBatchNDOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "SpaceToBatchND"); - CheckInputsCount(node, tf_import_flags, 3); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32); CHECK_EQ(GetDataTypeAttr(node, "Tpaddings"), DT_INT32); auto* op = new SpaceToBatchNDOperator; @@ -1335,13 +1393,14 @@ void ConvertSpaceToBatchNDOperator(const NodeDef& node, op->inputs.push_back(node.input(2)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertBatchToSpaceNDOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertBatchToSpaceNDOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "BatchToSpaceND"); - CheckInputsCount(node, tf_import_flags, 3); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32); CHECK_EQ(GetDataTypeAttr(node, "Tcrops"), DT_INT32); auto* op = new BatchToSpaceNDOperator; @@ -1350,14 +1409,15 @@ void ConvertBatchToSpaceNDOperator(const NodeDef& node, op->inputs.push_back(node.input(2)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertMeanOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Mean"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new MeanOperator; +template +tensorflow::Status ConvertReduceOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); + auto* op = new T; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); op->outputs.push_back(node.name()); @@ -1367,11 +1427,12 @@ void ConvertMeanOperator(const NodeDef& node, } else if (HasAttr(node, "keep_dims")) { op->keep_dims = GetBoolAttr(node, "keep_dims"); } + return tensorflow::Status::OK(); } -void ConvertSvdfOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSvdfOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Svdf"); const int input_size = GetInputsCount(node, tf_import_flags); QCHECK(input_size == 3 || input_size == 4) @@ -1394,14 +1455,15 @@ void ConvertSvdfOperator(const NodeDef& node, } op->rank = node.attr().at("Rank").i(); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } // This is just bare bones support to get the shapes to propagate. -void ConvertTransposeConvOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertTransposeConvOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Conv2DBackpropInput"); - CheckInputsCount(node, tf_import_flags, 3); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); auto* op = new TransposeConvOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -1465,14 +1527,14 @@ void ConvertTransposeConvOperator(const NodeDef& node, "Conv2DBackpropInput nodes."; } model->operators.emplace_back(op); + return tensorflow::Status::OK(); } - -void ConvertRangeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertRangeOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Range"); - CheckInputsCount(node, tf_import_flags, 3); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); auto* op = new RangeOperator; if (HasAttr(node, "Tidx")) { const auto dtype = toco::GetDataTypeAttr(node, "Tidx"); @@ -1485,13 +1547,18 @@ void ConvertRangeOperator(const NodeDef& node, op->inputs.push_back(node.input(2)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertStackOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK((node.op() == "Stack") || (node.op() == "Pack")); - auto* op = new StackOperator; +// Note that it's easy to confuse/conflate "Stack" and "Pack" operators, but +// they aren't the same thing. tf.stack results in a "Pack" operator. "Stack" +// operators also exist, but involve manipulating the TF runtime stack, and are +// not directly related to tf.stack() usage. +tensorflow::Status ConvertPackOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "Pack"); + auto op = absl::make_unique(); const int num_inputs = GetInputsCount(node, tf_import_flags); QCHECK_GE(num_inputs, 1) << node.op() @@ -1501,13 +1568,14 @@ void ConvertStackOperator(const NodeDef& node, for (int i = 0; i < num_inputs; ++i) { op->inputs.push_back(node.input(i)); } - // Both "Stack" and "Pack" have the "axis" attribute. + op->values_count = HasAttr(node, "N") ? GetIntAttr(node, "N") : num_inputs; op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0; + op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T")); op->outputs.push_back(node.name()); - model->operators.emplace_back(op); + model->operators.emplace_back(std::move(op)); + return tensorflow::Status::OK(); } - // Some TensorFlow ops only occur in graph cycles, representing // control flow. We do not currently support control flow, so we wouldn't // be able to fully support such graphs, including performing inference, @@ -1518,7 +1586,7 @@ void ConvertStackOperator(const NodeDef& node, // such ops as RNN back-edges, which is technically incorrect (does not // allow representing the op's semantics) but good enough to get a // graph visualization. -void ConvertOperatorSpecialCasedAsRNNBackEdge( +tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { // At the moment, the only type of operator special-cased in this way is @@ -1531,6 +1599,41 @@ void ConvertOperatorSpecialCasedAsRNNBackEdge( rnn_state->set_discardable(true); rnn_state->set_state_array(node.name()); rnn_state->set_back_edge_source_array(node.input(0)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertShapeOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "Shape"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); + const auto out_type = + HasAttr(node, "out_type") ? GetDataTypeAttr(node, "out_type") : DT_INT32; + CHECK(out_type == DT_INT64 || out_type == DT_INT32); + auto op = absl::make_unique(); + op->output_data_type = ConvertDataType(out_type); + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.push_back(std::move(op)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertAnyOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "Any"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); + const auto idx_type = + HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32; + CHECK(idx_type == DT_INT32); + auto op = absl::make_unique(); + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + op->keep_dims = + HasAttr(node, "keep_dims") ? GetBoolAttr(node, "keep_dims") : false; + model->operators.push_back(std::move(op)); + return tensorflow::Status::OK(); } void StripCaretFromArrayNames(Model* model) { @@ -1673,9 +1776,9 @@ bool InlineAllFunctions(GraphDef* graphdef) { return graph_modified; } -void ConvertTopKV2Operator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertTopKV2Operator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK((node.op() == "TopK") || (node.op() == "TopKV2")); auto op = absl::make_unique(); op->inputs.push_back(node.input(0)); @@ -1685,22 +1788,23 @@ void ConvertTopKV2Operator(const NodeDef& node, model, node.name() + "k", {static_cast(GetIntAttr(node, "k"))}); op->inputs.push_back(k_array); } else { - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); op->inputs.push_back(node.input(1)); } // The op has two outputs. op->outputs.push_back(node.name()); op->outputs.push_back(node.name() + ":1"); model->operators.emplace_back(op.release()); + return tensorflow::Status::OK(); } -void ConvertDynamicPartitionOperator( +tensorflow::Status ConvertDynamicPartitionOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { auto op = absl::make_unique(); CHECK(HasAttr(node, "num_partitions")); op->num_partitions = GetIntAttr(node, "num_partitions"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); CHECK_GT(op->num_partitions, 1); @@ -1709,11 +1813,12 @@ void ConvertDynamicPartitionOperator( op->outputs.push_back(node.name() + ":" + std::to_string(i)); } model->operators.emplace_back(op.release()); + return tensorflow::Status::OK(); } -void ConvertDynamicStitchOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertDynamicStitchOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { // The parallel and non-parallel variants are the same besides whether they // have a parallel loop; there are no behavioral differences. CHECK(node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch"); @@ -1721,19 +1826,20 @@ void ConvertDynamicStitchOperator(const NodeDef& node, CHECK(HasAttr(node, "N")); op->num_partitions = GetIntAttr(node, "N"); // Expect all ID partitions + all value partitions. - CheckInputsCount(node, tf_import_flags, op->num_partitions * 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, op->num_partitions * 2)); for (int i = 0; i < op->num_partitions * 2; ++i) { op->inputs.push_back(node.input(i)); } op->outputs.push_back(node.name()); model->operators.emplace_back(op.release()); + return tensorflow::Status::OK(); } -void ConvertSparseToDenseOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSparseToDenseOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "SparseToDense"); - CheckInputsCount(node, tf_import_flags, 4); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4)); auto* op = new SparseToDenseOperator; for (const string& input : node.input()) { @@ -1745,217 +1851,188 @@ void ConvertSparseToDenseOperator(const NodeDef& node, ? GetBoolAttr(node, "validate_indices") : true; model->operators.emplace_back(op); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertOneHotOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "OneHot"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4)); + + const auto dtype = GetDataTypeAttr(node, "T"); + // TODO(b/111744875): Support DT_UINT8 and quantization. + CHECK(dtype == DT_INT32 || dtype == DT_INT64 || dtype == DT_FLOAT || + dtype == DT_BOOL); + + auto op = absl::make_unique(); + op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : -1; + for (const string& input : node.input()) { + op->inputs.push_back(input); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op.release()); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertCTCBeamSearchDecoderOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "CTCBeamSearchDecoder"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); + + auto* op = new CTCBeamSearchDecoderOperator; + for (const string& input : node.input()) { + op->inputs.push_back(input); + } + + op->beam_width = + HasAttr(node, "beam_width") ? GetIntAttr(node, "beam_width") : 1; + op->top_paths = + HasAttr(node, "top_paths") ? GetIntAttr(node, "top_paths") : 1; + op->merge_repeated = HasAttr(node, "merge_repeated") + ? GetBoolAttr(node, "merge_repeated") + : true; + + // There are top_paths + 1 outputs. + op->outputs.push_back(node.name()); // Implicit :0. + for (int i = 0; i < op->top_paths; ++i) { + op->outputs.push_back(node.name() + ":" + std::to_string(i + 1)); + } + model->operators.emplace_back(op); + return tensorflow::Status::OK(); } } // namespace namespace internal { + +using ConverterType = tensorflow::Status (*)( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model); +using ConverterMapType = std::unordered_map; + +ConverterMapType GetTensorFlowNodeConverterMap() { + return std::unordered_map({ + {"Add", ConvertSimpleOperator}, + {"AddN", ConvertSimpleOperator}, + {"All", ConvertSimpleOperator}, + {"Any", ConvertAnyOperator}, + {"ArgMax", ConvertArgMaxOperator}, + {"ArgMin", ConvertArgMinOperator}, + {"Assert", ConvertSimpleOperator}, + {"AvgPool", ConvertAvgPoolOperator}, + {"BatchMatMul", ConvertBatchMatMulOperator}, + {"BatchNormWithGlobalNormalization", + ConvertBatchNormWithGlobalNormalizationOperator}, + {"BatchToSpaceND", ConvertBatchToSpaceNDOperator}, + {"BiasAdd", ConvertBiasAddOperator}, + {"Cast", ConvertCastOperator}, + {"CheckNumerics", ConvertIdentityOperator}, + {"Concat", ConvertConcatOperator}, + {"ConcatV2", ConvertConcatOperator}, + {"Const", ConvertConstOperator}, + {"Conv2D", ConvertConvOperator}, + {"Conv2DBackpropInput", ConvertTransposeConvOperator}, + {"CTCBeamSearchDecoder", ConvertCTCBeamSearchDecoderOperator}, + {"DepthToSpace", ConvertDepthToSpaceOperator}, + {"DepthwiseConv2dNative", ConvertDepthwiseConvOperator}, + {"Div", ConvertSimpleOperator}, + {"DynamicPartition", ConvertDynamicPartitionOperator}, + {"DynamicStitch", ConvertDynamicStitchOperator}, + {"Equal", ConvertSimpleOperator}, + {"Exp", ConvertSimpleOperator}, + {"ExpandDims", ConvertSimpleOperator}, + {"FakeQuantWithMinMaxArgs", ConvertFakeQuantWithMinMaxArgs}, + {"FakeQuantWithMinMaxVars", ConvertFakeQuantWithMinMaxVars}, + {"Fill", ConvertSimpleOperator}, + {"Floor", ConvertFloorOperator}, + {"FloorDiv", ConvertSimpleOperator}, + {"FloorMod", ConvertSimpleOperator}, + {"FusedBatchNorm", ConvertFusedBatchNormOperator}, + {"Gather", ConvertGatherOperator}, + {"GatherV2", ConvertGatherOperator}, + {"Greater", ConvertSimpleOperator}, + {"GreaterEqual", + ConvertSimpleOperator}, + {"Identity", ConvertIdentityOperator}, + {"LRN", ConvertLRNOperator}, + {"LegacyFedInput", ConvertPlaceholderOperator}, + {"Less", ConvertSimpleOperator}, + {"LessEqual", ConvertSimpleOperator}, + {"Log", ConvertSimpleOperator}, + {"LogicalAnd", ConvertSimpleOperator}, + {"LogicalOr", ConvertSimpleOperator}, + {"LogicalNot", ConvertSimpleOperator}, + {"LogSoftmax", ConvertSimpleOperator}, + {"MatMul", ConvertMatMulOperator}, + {"Max", ConvertReduceOperator}, + {"MaxPool", ConvertMaxPoolOperator}, + {"Maximum", ConvertSimpleOperator}, + {"Mean", ConvertReduceOperator}, + {"Merge", ConvertSimpleOperator}, + {"Min", ConvertReduceOperator}, + {"Minimum", ConvertSimpleOperator}, + {"Mul", ConvertSimpleOperator}, + {"Neg", ConvertSimpleOperator}, + {"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge}, + {"NoOp", ConvertNoOpOperator}, + {"NotEqual", ConvertSimpleOperator}, + {"OneHot", ConvertOneHotOperator}, + {"Pack", ConvertPackOperator}, + {"Pad", ConvertSimpleOperator}, + {"PadV2", ConvertSimpleOperator}, + {"ParallelDynamicStitch", ConvertDynamicStitchOperator}, + {"Placeholder", ConvertPlaceholderOperator}, + {"PlaceholderWithDefault", ConvertIdentityOperator}, + {"Pow", ConvertSimpleOperator}, + {"Prod", ConvertReduceOperator}, + {"RandomUniform", ConvertRandomUniform}, + {"Range", ConvertRangeOperator}, + {"Rank", ConvertSimpleOperator}, + {"RealDiv", ConvertSimpleOperator}, + {"Relu", ConvertSimpleOperator}, + {"Relu6", ConvertSimpleOperator}, + {"Reshape", ConvertSimpleOperator}, + {"ResizeBilinear", ConvertResizeBilinearOperator}, + {"Rsqrt", ConvertSimpleOperator}, + {"Select", ConvertSimpleOperator}, + {"Shape", ConvertShapeOperator}, + {"Sigmoid", ConvertSimpleOperator}, + {"Sin", ConvertSimpleOperator}, + {"Slice", ConvertSimpleOperator}, + {"Softmax", ConvertSoftmaxOperator}, + {"SpaceToBatchND", ConvertSpaceToBatchNDOperator}, + {"SpaceToDepth", ConvertSpaceToDepthOperator}, + {"SparseToDense", ConvertSparseToDenseOperator}, + {"Split", ConvertSplitOperator}, + {"Sqrt", ConvertSimpleOperator}, + {"Square", ConvertSimpleOperator}, + {"Squeeze", ConvertSqueezeOperator}, + {"StopGradient", ConvertIdentityOperator}, + {"StridedSlice", ConvertStridedSliceOperator}, + {"Sub", ConvertSimpleOperator}, + {"Sum", ConvertReduceOperator}, + {"Svdf", ConvertSvdfOperator}, + {"Switch", ConvertSwitchOperator}, + {"Tanh", ConvertSimpleOperator}, + {"Tile", ConvertSimpleOperator}, + {"TopK", ConvertTopKV2Operator}, + {"TopKV2", ConvertTopKV2Operator}, + {"Transpose", ConvertSimpleOperator}, + }); +} + tensorflow::Status ImportTensorFlowNode( const tensorflow::NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, Model* model) { - // TODO(ahentz): Historically these functions all CHECK-fail on error. We've - // been slowly converting them to return Status. - if (node.op() == "Const") { - return ConvertConstOperator(node, tf_import_flags, model); - } else if (node.op() == "Conv2D") { - return ConvertConvOperator(node, tf_import_flags, model); - } else if (node.op() == "Conv2DBackpropInput") { - ConvertTransposeConvOperator(node, tf_import_flags, model); - } else if (node.op() == "DepthwiseConv2dNative") { - ConvertDepthwiseConvOperator(node, tf_import_flags, model); - } else if (node.op() == "DepthToSpace") { - ConvertDepthToSpaceOperator(node, tf_import_flags, model); - } else if (node.op() == "SpaceToDepth") { - ConvertSpaceToDepthOperator(node, tf_import_flags, model); - } else if (node.op() == "BiasAdd") { - ConvertBiasAddOperator(node, tf_import_flags, model); - } else if (node.op() == "Relu") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Relu6") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Sigmoid") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Tanh") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "MaxPool") { - ConvertMaxPoolOperator(node, tf_import_flags, model); - } else if (node.op() == "AvgPool") { - ConvertAvgPoolOperator(node, tf_import_flags, model); - } else if (node.op() == "Reshape") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "BatchMatMul") { - ConvertBatchMatMulOperator(node, tf_import_flags, model); - } else if (node.op() == "MatMul") { - ConvertMatMulOperator(node, tf_import_flags, model); - } else if (node.op() == "Div" || node.op() == "RealDiv") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Identity" || node.op() == "CheckNumerics" || - node.op() == "StopGradient") { - ConvertIdentityOperator(node, tf_import_flags, model); - } else if (node.op() == "FakeQuantWithMinMaxVars") { - ConvertFakeQuantWithMinMaxVars(node, tf_import_flags, model); - } else if (node.op() == "FakeQuantWithMinMaxArgs") { - ConvertFakeQuantWithMinMaxArgs(node, tf_import_flags, model); - } else if (node.op() == "Neg") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Rsqrt") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Squeeze") { - ConvertSqueezeOperator(node, tf_import_flags, model); - } else if (node.op() == "Sqrt") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Square") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Add") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "AddN") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Mul") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Sub") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Sum") { - ConvertSumOperator(node, tf_import_flags, model); - } else if (node.op() == "Tile") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Concat" || node.op() == "ConcatV2") { - ConvertConcatOperator(node, tf_import_flags, model); - } else if (node.op() == "LRN") { - ConvertLRNOperator(node, tf_import_flags, model); - } else if (node.op() == "Softmax") { - ConvertSoftmaxOperator(node, tf_import_flags, model); - } else if (node.op() == "Log") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "LogSoftmax") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "All") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Assert") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Less") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "LessEqual") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Greater") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "GreaterEqual") { - ConvertSimpleOperator( - node, tf_import_flags, model); - } else if (node.op() == "Max") { - ConvertMaxOperator(node, tf_import_flags, model); - } else if (node.op() == "Min") { - ConvertMinOperator(node, tf_import_flags, model); - } else if (node.op() == "Maximum") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Minimum") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Merge") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Pad") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "PadV2") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "StridedSlice") { - ConvertStridedSliceOperator(node, tf_import_flags, model); - } else if (node.op() == "Shape") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Slice") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Split") { - ConvertSplitOperator(node, tf_import_flags, model); - } else if (node.op() == "Switch") { - ConvertSwitchOperator(node, tf_import_flags, model); - } else if (node.op() == "Placeholder") { - ConvertPlaceholderOperator(node, tf_import_flags, model); - } else if (node.op() == "PlaceholderWithDefault") { - ConvertIdentityOperator(node, tf_import_flags, model); - } else if (node.op() == "LegacyFedInput") { - ConvertPlaceholderOperator(node, tf_import_flags, model); - } else if (node.op() == "NoOp") { - ConvertNoOpOperator(node, tf_import_flags, model); - } else if (node.op() == "Cast") { - ConvertCastOperator(node, tf_import_flags, model); - } else if (node.op() == "Floor") { - ConvertFloorOperator(node, tf_import_flags, model); - } else if (node.op() == "Gather" || node.op() == "GatherV2") { - ConvertGatherOperator(node, tf_import_flags, model); - } else if (node.op() == "ResizeBilinear") { - ConvertResizeBilinearOperator(node, tf_import_flags, model); - } else if (node.op() == "BatchNormWithGlobalNormalization") { - ConvertBatchNormWithGlobalNormalizationOperator(node, tf_import_flags, - model); - } else if (node.op() == "FusedBatchNorm") { - ConvertFusedBatchNormOperator(node, tf_import_flags, model); - } else if (node.op() == "SpaceToBatchND") { - ConvertSpaceToBatchNDOperator(node, tf_import_flags, model); - } else if (node.op() == "BatchToSpaceND") { - ConvertBatchToSpaceNDOperator(node, tf_import_flags, model); - } else if (node.op() == "Mean") { - ConvertMeanOperator(node, tf_import_flags, model); - } else if (node.op() == "Svdf") { - ConvertSvdfOperator(node, tf_import_flags, model); - } else if (node.op() == "NextIteration") { - ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model); - } else if (node.op() == "ExpandDims") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Fill") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "FloorDiv") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "FloorMod") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Range") { - ConvertRangeOperator(node, tf_import_flags, model); - } else if (node.op() == "Rank") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Stack" || node.op() == "Pack") { - ConvertStackOperator(node, tf_import_flags, model); - } else if (node.op() == "Transpose") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "ArgMax") { - ConvertArgMaxOperator(node, tf_import_flags, model); - } else if (node.op() == "Exp") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "TopK" || node.op() == "TopKV2") { - ConvertTopKV2Operator(node, tf_import_flags, model); - } else if (node.op() == "DynamicPartition") { - ConvertDynamicPartitionOperator(node, tf_import_flags, model); - } else if (node.op() == "DynamicStitch" || - node.op() == "ParallelDynamicStitch") { - ConvertDynamicStitchOperator(node, tf_import_flags, model); - } else if (node.op() == "RandomUniform") { - ConvertRandomUniform(node, tf_import_flags, model); - } else if (node.op() == "Sin") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Log") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Select") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "SparseToDense") { - ConvertSparseToDenseOperator(node, tf_import_flags, model); - } else if (node.op() == "Equal") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "NotEqual") { - ConvertSimpleOperator(node, tf_import_flags, - model); + const TensorFlowImportFlags& tf_import_flags, Model* model, + const ConverterMapType& converter_map) { + auto converter = converter_map.find(node.op()); + if (converter == converter_map.end()) { + return ConvertUnsupportedOperator(node, tf_import_flags, model); } else { - ConvertUnsupportedOperator(node, tf_import_flags, model); + return converter->second(node, tf_import_flags, model); } - return tensorflow::Status::OK(); } } // namespace internal @@ -1981,10 +2058,13 @@ std::unique_ptr ImportTensorFlowGraphDef( } Model* model = new Model; + const internal::ConverterMapType& converter_map = + internal::GetTensorFlowNodeConverterMap(); for (auto node : inlined_graph.node()) { StripZeroOutputIndexFromInputs(&node); - auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model); + auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model, + converter_map); CHECK(status.ok()) << status.error_message(); } diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc index d18c329a43411236f8fd5446998c168803b9373a..90e6f698efee6a6a32da18a658e72c3e8b6550c0 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -36,8 +36,14 @@ using tensorflow::NodeDef; using tensorflow::Status; namespace internal { +using ConverterType = tensorflow::Status (*)( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model); +using ConverterMapType = std::unordered_map; + +ConverterMapType GetTensorFlowNodeConverterMap(); Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&, - Model*); + Model*, const ConverterMapType&); } // namespace internal namespace { @@ -105,8 +111,9 @@ class ShapeImportTest : public ::testing::TestWithParam { Status ImportNode(const NodeDef& node) { Model model; - return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), - &model); + const auto converter = internal::GetTensorFlowNodeConverterMap(); + return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), &model, + converter); } }; diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 2f43adb07b1c9dc9645942ce6ec868595704baa5..412e14c4ada3280dafcd2fcfa59e2908dd785f9f 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ #define TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ +#include #include #include #include @@ -22,6 +23,7 @@ limitations under the License. #include #include +#include "absl/types/optional.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/runtime/types.h" #include "tensorflow/contrib/lite/toco/toco_port.h" @@ -32,7 +34,7 @@ namespace toco { using tflite::QuantizationParams; -enum class OperatorType { +enum class OperatorType : uint8 { kNone, // General-purpose neural network operators. kAdd, @@ -62,6 +64,7 @@ enum class OperatorType { kMaxPool, kFakeQuant, kMul, + kOneHot, kRandomUniform, kRange, kRank, @@ -80,10 +83,11 @@ enum class OperatorType { kResizeBilinear, kSin, kSpaceToBatchND, - kStack, + kPack, kBatchToSpaceND, kPad, kPadV2, + kReduceProd, // Reduction product kStridedSlice, kSlice, kSqueeze, @@ -96,38 +100,38 @@ enum class OperatorType { // Special operators used for importing TensorFlow nodes. // The general intent is to have some graph transformation either // drop them or rewrite them as general-purpose operators. - kTensorFlowAll, - kTensorFlowAssert, - kTensorFlowConcat, - kTensorFlowConcatV2, - kTensorFlowGreater, - kTensorFlowGreaterEqual, - kTensorFlowIdentity, - kTensorFlowLess, - kTensorFlowLessEqual, - kTensorFlowMax, - kTensorFlowMaximum, - kTensorFlowMin, - kTensorFlowMinimum, - kTensorFlowMatMul, - kTensorFlowMerge, + kAll, + kAssert, + kConcat, + kConcatV2, + kGreater, + kGreaterEqual, + kIdentity, + kLess, + kLessEqual, + kReduceMax, // Reduction Max + kMaximum, // Element-wise Maximum + kReduceMin, // Reduction Min + kMinimum, // Element-wise Minimum + kMatMul, + kMerge, kNeg, - kTensorFlowReshape, - kTensorFlowRsqrt, - kTensorFlowShape, - kTensorFlowSplit, - kTensorFlowSqrt, - kTensorFlowSquare, - kTensorFlowSum, - kTensorFlowSwitch, - kTensorFlowTile, + kReshape, + kRsqrt, + kShape, + kSplit, + kSqrt, + kSquare, + kSum, + kSwitch, + kTile, kTranspose, kTopK_V2, kDynamicPartition, kDynamicStitch, // An unsupported TF operation. It's only needed to be able to represent TF // graph internally and is expected to be dropped by graph transformations. - kTensorFlowUnsupported, + kUnsupported, // Finally, TensorFlow uses different conventions for axes ordering, // see AxesOrder, and this cannot always be resolved at the time of importing // nodes, as TensorFlow parameters may be constant-expression subgraphs @@ -136,8 +140,15 @@ enum class OperatorType { kReorderAxes, kSelect, kSparseToDense, - kTensorFlowEqual, - kTensorFlowNotEqual, + kEqual, + kNotEqual, + kPow, + kArgMin, + kAny, + kLogicalAnd, + kLogicalNot, + kLogicalOr, + kCTCBeamSearchDecoder, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -160,21 +171,22 @@ enum class AxesOrder { // The type of the scalars in an array. // Note that the type does not by itself tell whether the values in the array -// are real (are literally interpreted as real numbers) or quantized (only -// acquire a meaning as real numbers in conjunction with QuantizationParams). +// are non-quantized (can be accessed directly) or quantized (must be +// interpreted in conjunction with QuantizationParams). // // In practice though: -// float values are always real +// float values are never quantized // uint8 values are always quantized -// int32 values are either real or quantized (depending on whether +// int32 values are sometimes quantized (depending on whether // QuantizationParams are present). -// other types are unused at the moment. +// complex values are never quantized +// other types are never quantized at the moment. // // kNone means that we don't know the data type yet, or that we don't care // because we'll be dropping the array anyway (e.g. some exotic array types // may be involved only in debug-only subgraphs that we may not be interested // in actually supporting). -enum class ArrayDataType { +enum class ArrayDataType : uint8 { kNone, // 0 kBool, kFloat, @@ -186,7 +198,8 @@ enum class ArrayDataType { kUint32, kInt64, kUint64, // 10 - kString + kString, + kComplex64, }; // Compile-time logic to map ArrayDataType to the corresponding C++ scalar type @@ -240,6 +253,10 @@ template <> struct DataTypeImpl { typedef string Type; }; +template <> +struct DataTypeImpl { + typedef std::complex Type; +}; template using DataType = typename DataTypeImpl::Type; @@ -278,6 +295,46 @@ struct Buffer : GenericBuffer { std::vector> data; }; +class Shape { + public: + // For Shape, we stick to half-way encapsulation for now: + // we hide the raw dims_ member, but expose it raw by accessors + // because from some brainstorming, it's not at all easy to + // anticipate which flavor of more hermetic encapsulation would + // actually buy us future-proof-ness without being needlessly + // cumbersome. + Shape() {} + Shape(std::initializer_list dim_list) : dims_(dim_list) {} + + void ReplaceDims(std::initializer_list dim_list) { + dims_ = std::vector(dim_list); + } + + const std::vector& dims() const { return dims_; } + std::vector* mutable_dims() { return &dims_; } + const int dimensions_count() const { return dims_.size(); } + + // We still have that one convenience accessor to avoid + // the awkward double bracket issue: shape.dims()[i]. + int dims(int i) const { + // Always check for out-of-bounds accesses, even in optimized builds where + // standard assertions are disabled. Out-of-bounds access here is a common + // occurrence. + CHECK_GE(i, 0); + CHECK_GT(dims_.size(), i); + return dims_[i]; + } + + bool operator==(const Shape& comp) const { + return (this->dims_ == comp.dims()); + } + + bool operator!=(const Shape& comp) const { return !((*this) == comp); } + + private: + std::vector dims_; +}; + // Base class for all operator classes. struct Operator { // Non-default-constructible: only OperatorType-specific subclass @@ -382,6 +439,28 @@ struct ConvOperator : Operator { int dilation_height_factor = 1; }; +// CTCBeamSearchDecoder operator: +// +// Inputs: +// inputs[0]: required: the logits. +// inputs[1]: required: sequence length. +// inputs[2]: optional: beam width. +// inputs[3]: optional: top paths. +// inputs[4]: optional: merge repeated. +// +// Outputs: +// outputs[0]: deocoded. +// outputs[1]: log probability. +// +// TensorFlow equivalent: CTCBeamSearchDecoder +struct CTCBeamSearchDecoderOperator : Operator { + CTCBeamSearchDecoderOperator() + : Operator(OperatorType::kCTCBeamSearchDecoder) {} + int beam_width; + int top_paths; + bool merge_repeated = true; +}; + // Depthwise-separable convolution operator. // // Inputs: @@ -433,7 +512,8 @@ struct SpaceToDepthOperator : Operator { // input activations as a matrix, followed by a MatMul node. struct FullyConnectedOperator : Operator { FullyConnectedOperator() : Operator(OperatorType::kFullyConnected) {} - bool experimental_shuffled_weights = false; + FullyConnectedWeightsFormat weights_format = + FullyConnectedWeightsFormat::kDefault; }; // Dequantization operator, converting a quantized array of integers with @@ -781,6 +861,7 @@ struct FakeQuantOperator : Operator { FakeQuantOperator() : Operator(OperatorType::kFakeQuant) {} std::unique_ptr minmax; int num_bits = 8; + bool narrow_range = false; }; // Element-wise division operator. @@ -801,7 +882,7 @@ struct DivOperator : Operator { // // TensorFlow equivalent: Identity struct TensorFlowIdentityOperator : Operator { - TensorFlowIdentityOperator() : Operator(OperatorType::kTensorFlowIdentity) {} + TensorFlowIdentityOperator() : Operator(OperatorType::kIdentity) {} }; // Batch matrix multiplication operator. This comes from the (deprecated) @@ -827,7 +908,9 @@ struct BatchMatMulOperator : Operator { // // TensorFlow equivalent: MatMul struct TensorFlowMatMulOperator : Operator { - TensorFlowMatMulOperator() : Operator(OperatorType::kTensorFlowMatMul) {} + TensorFlowMatMulOperator() : Operator(OperatorType::kMatMul) {} + bool transpose_a = false; + bool transpose_b = false; }; // Padding operator. Pads a tensor with zeros. @@ -961,7 +1044,7 @@ struct StridedSliceOperator : Operator { // TensorFlow equivalent: Reshape --- except that we only support a special case // here, where the output shape is a matrix (2D) shape. struct TensorFlowReshapeOperator : Operator { - TensorFlowReshapeOperator() : Operator(OperatorType::kTensorFlowReshape) {} + TensorFlowReshapeOperator() : Operator(OperatorType::kReshape) {} std::vector shape; }; @@ -1131,7 +1214,7 @@ struct SelectOperator : Operator { // // TensorFlow equivalent: Rsqrt struct TensorFlowRsqrtOperator : Operator { - TensorFlowRsqrtOperator() : Operator(OperatorType::kTensorFlowRsqrt) {} + TensorFlowRsqrtOperator() : Operator(OperatorType::kRsqrt) {} }; // Stacks a list of rank-R tensors into one rank-(R+1) tensor. @@ -1143,10 +1226,12 @@ struct TensorFlowRsqrtOperator : Operator { // Inputs: this operator accepts any number >= 1 of inputs. // inputs[i]: the i-th array to merge. // -// TensorFlow equivalent: Stack or Pack -struct StackOperator : Operator { - StackOperator() : Operator(OperatorType::kStack) {} +// TensorFlow equivalent: Pack +struct PackOperator : Operator { + PackOperator() : Operator(OperatorType::kPack) {} + int values_count; int axis = 0; + ArrayDataType dtype = ArrayDataType::kNone; }; // Shape operator. Extracts the shape of the tensor. @@ -1157,10 +1242,10 @@ struct StackOperator : Operator { // This operation outputs a 1-D integer tensor representing the shape of // the input. // -// TensorFlow equivalent: Shape. We currently assume that the output is int32 -// and not int64. The output type could be stored herein. +// TensorFlow equivalent: Shape. struct TensorFlowShapeOperator : Operator { - TensorFlowShapeOperator() : Operator(OperatorType::kTensorFlowShape) {} + TensorFlowShapeOperator() : Operator(OperatorType::kShape) {} + ArrayDataType output_data_type = ArrayDataType::kInt32; }; // Element-wise square-root (x^0.5) operator. @@ -1170,7 +1255,7 @@ struct TensorFlowShapeOperator : Operator { // // TensorFlow equivalent: Sqrt struct TensorFlowSqrtOperator : Operator { - TensorFlowSqrtOperator() : Operator(OperatorType::kTensorFlowSqrt) {} + TensorFlowSqrtOperator() : Operator(OperatorType::kSqrt) {} }; // Element-wise square (x*x) operator. @@ -1180,7 +1265,7 @@ struct TensorFlowSqrtOperator : Operator { // // TensorFlow equivalent: Square struct TensorFlowSquareOperator : Operator { - TensorFlowSquareOperator() : Operator(OperatorType::kTensorFlowSquare) {} + TensorFlowSquareOperator() : Operator(OperatorType::kSquare) {} }; // Transposes a tensor. @@ -1208,24 +1293,37 @@ struct SubOperator : Operator { SubOperator() : Operator(OperatorType::kSub) {} }; -// Global sum reduction: computes the sum of all of entries in the input array. -// Thus the output is "0-dimensional": it consists of a single scalar value. +// Sum reduction: computes the sum of all of entries across the axes. // // Inputs: // inputs[0]: required: the input array // -// TensorFlow equivalent: Sum --- except that we only support the special case -// of global reduction across all dimensions. +// TensorFlow equivalent: Sum struct TensorFlowSumOperator : Operator { - TensorFlowSumOperator() : Operator(OperatorType::kTensorFlowSum) {} + TensorFlowSumOperator() : Operator(OperatorType::kSum) {} + std::vector axis; + bool keep_dims = false; +}; + +// Prod reduction: computes the product of all of entries across the axes. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Prod +struct TensorFlowProdOperator : Operator { + TensorFlowProdOperator() : Operator(OperatorType::kReduceProd) {} + std::vector axis; bool keep_dims = false; }; // TensorFlow Tile equivalent. Refer to TensorFlow documentation for details. -// Not fully supported, just a placeholder to handle TensorFlow graphs and -// support graph transformations to other operator types by matching sub-graphs. +// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: required: int array with length of rank(input[0]) struct TensorFlowTileOperator : Operator { - TensorFlowTileOperator() : Operator(OperatorType::kTensorFlowTile) {} + TensorFlowTileOperator() : Operator(OperatorType::kTile) {} }; // TensorFlow Slice equivalent. Refer to TensorFlow documentation for details. @@ -1240,7 +1338,7 @@ struct SliceOperator : Operator { // Not fully supported, just a placeholder to handle TensorFlow graphs and // support graph transformations to other operator types by matching sub-graphs. struct TensorFlowSplitOperator : Operator { - TensorFlowSplitOperator() : Operator(OperatorType::kTensorFlowSplit) {} + TensorFlowSplitOperator() : Operator(OperatorType::kSplit) {} int num_split = 0; }; @@ -1251,7 +1349,7 @@ struct TensorFlowSplitOperator : Operator { // dimension then we can change this op into a DepthConcatenation op. // Otherwise, we hope for some other graph transformation to drop this node. struct TensorFlowConcatOperator : Operator { - TensorFlowConcatOperator() : Operator(OperatorType::kTensorFlowConcat) {} + TensorFlowConcatOperator() : Operator(OperatorType::kConcat) {} }; // TensorFlow ConcatV2 equivalent. Refer to TensorFlow documentation for @@ -1262,7 +1360,7 @@ struct TensorFlowConcatOperator : Operator { // dimension then we can change this op into a DepthConcatenation op. // Otherwise, we hope for some other graph transformation to drop this node. struct TensorFlowConcatV2Operator : Operator { - TensorFlowConcatV2Operator() : Operator(OperatorType::kTensorFlowConcatV2) {} + TensorFlowConcatV2Operator() : Operator(OperatorType::kConcatV2) {} }; // TensorFlow Merge equivalent. Refer to TensorFlow documentation for details. @@ -1278,7 +1376,7 @@ struct TensorFlowConcatV2Operator : Operator { // control flow that can be resolved at tooling time (independently of input // activations). struct TensorFlowMergeOperator : Operator { - TensorFlowMergeOperator() : Operator(OperatorType::kTensorFlowMerge) {} + TensorFlowMergeOperator() : Operator(OperatorType::kMerge) {} }; // TensorFlow Switch equivalent. Refer to TensorFlow documentation for details. @@ -1301,7 +1399,7 @@ struct TensorFlowMergeOperator : Operator { // control flow that can be resolved at tooling time (independently of input // activations). struct TensorFlowSwitchOperator : Operator { - TensorFlowSwitchOperator() : Operator(OperatorType::kTensorFlowSwitch) {} + TensorFlowSwitchOperator() : Operator(OperatorType::kSwitch) {} }; // TensorFlow All equivalent. Refer to TensorFlow documentation for details. @@ -1310,7 +1408,7 @@ struct TensorFlowSwitchOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowAllOperator : Operator { - TensorFlowAllOperator() : Operator(OperatorType::kTensorFlowAll) {} + TensorFlowAllOperator() : Operator(OperatorType::kAll) {} }; // TensorFlow Assert equivalent. Refer to TensorFlow documentation for details. @@ -1318,7 +1416,7 @@ struct TensorFlowAllOperator : Operator { // support graph transformations to other operator types by matching sub-graphs. // Typically, we just drop Assert nodes. struct TensorFlowAssertOperator : Operator { - TensorFlowAssertOperator() : Operator(OperatorType::kTensorFlowAssert) {} + TensorFlowAssertOperator() : Operator(OperatorType::kAssert) {} }; // TensorFlow Less equivalent. Refer to TensorFlow documentation for details. @@ -1327,7 +1425,7 @@ struct TensorFlowAssertOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowLessOperator : Operator { - TensorFlowLessOperator() : Operator(OperatorType::kTensorFlowLess) {} + TensorFlowLessOperator() : Operator(OperatorType::kLess) {} }; // TensorFlow LessEqual equivalent. Refer to TensorFlow documentation for @@ -1337,8 +1435,7 @@ struct TensorFlowLessOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowLessEqualOperator : Operator { - TensorFlowLessEqualOperator() - : Operator(OperatorType::kTensorFlowLessEqual) {} + TensorFlowLessEqualOperator() : Operator(OperatorType::kLessEqual) {} }; // TensorFlow Less equivalent. Refer to TensorFlow documentation for details. @@ -1347,7 +1444,7 @@ struct TensorFlowLessEqualOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowGreaterOperator : Operator { - TensorFlowGreaterOperator() : Operator(OperatorType::kTensorFlowGreater) {} + TensorFlowGreaterOperator() : Operator(OperatorType::kGreater) {} }; // TensorFlow GreaterEqual equivalent. Refer to TensorFlow documentation for @@ -1357,8 +1454,7 @@ struct TensorFlowGreaterOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowGreaterEqualOperator : Operator { - TensorFlowGreaterEqualOperator() - : Operator(OperatorType::kTensorFlowGreaterEqual) {} + TensorFlowGreaterEqualOperator() : Operator(OperatorType::kGreaterEqual) {} }; // TensorFlow Equal equivalent. Refer to TensorFlow documentation for @@ -1368,38 +1464,36 @@ struct TensorFlowGreaterEqualOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowEqualOperator : Operator { - TensorFlowEqualOperator() : Operator(OperatorType::kTensorFlowEqual) {} + TensorFlowEqualOperator() : Operator(OperatorType::kEqual) {} }; // TensorFlow Not Equal equivalent. Refer to TensorFlow documentation for // details. struct TensorFlowNotEqualOperator : Operator { - TensorFlowNotEqualOperator() : Operator(OperatorType::kTensorFlowNotEqual) {} + TensorFlowNotEqualOperator() : Operator(OperatorType::kNotEqual) {} }; -// Global max reduction: computes the max of all of entries in the input array. -// Thus the output is "0-dimensional": it consists of a single scalar value. +// Max reduction: computes the max of all of entries across the axes. // // Inputs: // inputs[0]: required: the input array // -// TensorFlow equivalent: Max --- except that we only support the special case -// of global reduction across all dimensions. +// TensorFlow equivalent: Max struct TensorFlowMaxOperator : Operator { - TensorFlowMaxOperator() : Operator(OperatorType::kTensorFlowMax) {} + TensorFlowMaxOperator() : Operator(OperatorType::kReduceMax) {} + std::vector axis; bool keep_dims = false; }; -// Global min reduction: computes the min of all of entries in the input array. -// Thus the output is "0-dimensional": it consists of a single scalar value. +// Min reduction: computes the min of all of entries across the axes. // // Inputs: // inputs[0]: required: the input array // -// TensorFlow equivalent: Min --- except that we only support the special case -// of global reduction across all dimensions. +// TensorFlow equivalent: Min struct TensorFlowMinOperator : Operator { - TensorFlowMinOperator() : Operator(OperatorType::kTensorFlowMin) {} + TensorFlowMinOperator() : Operator(OperatorType::kReduceMin) {} + std::vector axis; bool keep_dims = false; }; @@ -1412,7 +1506,7 @@ struct TensorFlowMinOperator : Operator { // // TensorFlow equivalent: Maximum struct TensorFlowMaximumOperator : Operator { - TensorFlowMaximumOperator() : Operator(OperatorType::kTensorFlowMaximum) {} + TensorFlowMaximumOperator() : Operator(OperatorType::kMaximum) {} }; // Element-wise minimum operator. Currently it only supports scalar as @@ -1424,14 +1518,13 @@ struct TensorFlowMaximumOperator : Operator { // // TensorFlow equivalent: Minimum struct TensorFlowMinimumOperator : Operator { - TensorFlowMinimumOperator() : Operator(OperatorType::kTensorFlowMinimum) {} + TensorFlowMinimumOperator() : Operator(OperatorType::kMinimum) {} }; // General TF operation, unsupported by tf.mini. Expected to be dropped by // graph transformations. struct TensorFlowUnsupportedOperator : Operator { - TensorFlowUnsupportedOperator() - : Operator(OperatorType::kTensorFlowUnsupported) {} + TensorFlowUnsupportedOperator() : Operator(OperatorType::kUnsupported) {} // The original TF operation type. Used for diagnostic purposes. string tensorflow_op; @@ -1439,8 +1532,13 @@ struct TensorFlowUnsupportedOperator : Operator { string tensorflow_node_def; // A boolean indicating if the unsupported op should be treated as quantized. bool quantized = false; + // A boolean indicating if the unsupported op output should allow float values + // in quantized mode. + bool support_output_type_float_in_quantized_op = false; // Output data types std::vector output_data_types; + // Output shapes. + std::vector output_shapes; }; // Softmax activation function. @@ -1501,11 +1599,15 @@ struct FloorOperator : Operator { // Inputs: // inputs[0]: required: the params array // inputs[1]: required: the indices to gather +// inputs[2]: optional: axis // // TensorFlow equivalent: Gather struct GatherOperator : Operator { GatherOperator() : Operator(OperatorType::kGather) {} - int axis = 0; + // Axis is populated explicitly or implicitly from the axis input by + // ResolveGatherAttributes. An empty axis indicates that the axis has not yet + // be resolved. + absl::optional axis; int input_rank = 0; }; @@ -1520,6 +1622,17 @@ struct ArgMaxOperator : Operator { ArrayDataType output_data_type = ArrayDataType::kInt64; }; +// ArgMin operator. It returns the index of the minimum value along axis. +// +// Inputs: +// inputs[0]: required: the input tensor +// +// TensorFlow equivalent: ArgMin +struct ArgMinOperator : Operator { + ArgMinOperator() : Operator(OperatorType::kArgMin) {} + ArrayDataType output_data_type = ArrayDataType::kInt64; +}; + // ResizeBilinear operator. It resizes input images with bilinear interpolation. // It does not support align_corners at the moment. // @@ -1639,6 +1752,82 @@ struct SparseToDenseOperator : Operator { bool validate_indices; }; +// Pow operator: +// +// Inputs: +// Inputs[0]: required: A tensor. +// Inputs[1]: required: A tensor. +// +// TensorFlow equivalent: Pow. +struct PowOperator : Operator { + PowOperator() : Operator(OperatorType::kPow) {} +}; + +// Any operator: +// +// Inputs: +// Inputs[0]: required: A boolean input tensor. +// Inputs[1]: required: reduction_indices. +// +// TensorFlow equivalent: tf.reduce_any. +struct AnyOperator : Operator { + AnyOperator() : Operator(OperatorType::kAny) {} + bool keep_dims = false; +}; + +// LogicalAnd operator: +// +// Inputs: +// Inputs[0]: required: A boolean tensor. +// Inputs[1]: required: A boolean tensor. +// +// TensorFlow equivalent: tf.logical_and. +struct LogicalAndOperator : Operator { + LogicalAndOperator() : Operator(OperatorType::kLogicalAnd) {} +}; + +// LogicalNot operator: +// +// Inputs: +// Inputs[0]: required: A boolean tensor. +// +// TensorFlow equivalent: tf.logical_not. +struct LogicalNotOperator : Operator { + LogicalNotOperator() : Operator(OperatorType::kLogicalNot) {} +}; + +// OneHot operator: +// +// Inputs: +// Inputs[0]: required: indices. +// Inputs[1]: required: depth. +// Inputs[2]: required: on_value. +// Inputs[3]: required: off_value. +// +// TensorFlow equivalent: OneHot. +struct OneHotOperator : Operator { + enum Inputs { + INDICES_INPUT = 0, + DEPTH_INPUT = 1, + ON_VALUE_INPUT = 2, + OFF_VALUE_INPUT = 3, + }; + + OneHotOperator() : Operator(OperatorType::kOneHot) {} + int axis = -1; +}; + +// LogicalOr operator: +// +// Inputs: +// Inputs[0]: required: A Bool tensor. +// Inputs[1]: required: A Bool tensor. +// +// TensorFlow equivalent: LogicalOr. +struct LogicalOrOperator : Operator { + LogicalOrOperator() : Operator(OperatorType::kLogicalOr) {} +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are @@ -1652,46 +1841,6 @@ inline bool operator<(const Alloc& a, const Alloc& b) { return a.start < b.start; } -class Shape { - public: - // For Shape, we stick to half-way encapsulation for now: - // we hide the raw dims_ member, but expose it raw by accessors - // because from some brainstorming, it's not at all easy to - // anticipate which flavor of more hermetic encapsulation would - // actually buy us future-proof-ness without being needlessly - // cumbersome. - Shape() {} - Shape(std::initializer_list dim_list) : dims_(dim_list) {} - - void ReplaceDims(std::initializer_list dim_list) { - dims_ = std::vector(dim_list); - } - - const std::vector& dims() const { return dims_; } - std::vector* mutable_dims() { return &dims_; } - const int dimensions_count() const { return dims_.size(); } - - // We still have that one convenience accessor to avoid - // the awkward double bracket issue: shape.dims()[i]. - int dims(int i) const { - // Always check for out-of-bounds accesses, even in optimized builds where - // standard assertions are disabled. Out-of-bounds access here is a common - // occurrence. - CHECK_GE(i, 0); - CHECK_GT(dims_.size(), i); - return dims_[i]; - } - - bool operator==(const Shape& comp) const { - return (this->dims_ == comp.dims()); - } - - bool operator!=(const Shape& comp) const { return !((*this) == comp); } - - private: - std::vector dims_; -}; - // Array represents an array (either a constant parameter array or an // activations array) in a Model. struct Array { @@ -1823,6 +1972,40 @@ struct Array { // If this is non-null, then these quantization parameters are to be used // to assign a meaning as real numbers to the elements of this array. std::unique_ptr quantization_params; + // narrow_range is a detail of how toco handles FakeQuant operators with + // narrow_range, see + // https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_vars + // + // For more context about what that is useful for, see the big comment in + // graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc + // + // The narrow_range flag applies only to quantized arrays, and changes + // their quantization in the following way when it is set to 'true': + // 1. The computation of {zero_point, scale} from {min, max} needs to be + // amended so that the real min value will get quantized to + // (min_quantized_value + 1) instead of just (min_quantized_value). + // E.g. for uint8 quantization, the real min value should get quantized to + // the uint8 value 1, not 0. + // 2. Quantized values should get clamped to the interval + // [min_quantized_value + 1, max_value]. Equivalently, the + // min_quantized_value should get nudged to (min_quantized_value + 1). + // The reason why 1. does not imply 2. is that real values may not belong to + // the stated [min, max] interval. Concretely, weights recorded at the last + // learning step may not fall in the [min, max] interval recorded over + // previous learning steps, as the values evolve across learning steps. + // + // Rationale why this is directly a field on Array: + // - This can't be just a field on FakeQuantOperator, because + // FakeQuantOperators are gone (DropFakeQuant) before we get to using that + // information (Quantize). We need a place to store that bit in the interim. + // - This can't be in QuantizationParams because we need to record this + // ahead of quantization, and QuantizationParams are only created during + // quantization. + // - This could be in MinMax, but that would be an abuse of what MinMax is + // about, and would break existing code that assumes that a MinMax is just + // a min and a max. Unlike MinMax which is agnostic as to the quantized + // data type, narrow_range refers to values in the quantized data type. + bool narrow_range = false; private: std::unique_ptr array_shape; @@ -1888,7 +2071,7 @@ class Model { std::size_t transient_data_size = 0; // For code-generation only: required alignment of the transient_data buffer std::size_t transient_data_alignment = 0; - // Arithmatic operations performed in the model. + // Arithmetic operations performed in the model. int64 ops_count = 0; private: diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc index 4c9f1aa4b0274b5123bb3baa9b9fca1463bda4c3..d34da63e43eee3b48e575c33ddb6c89f7701865c 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -74,10 +74,10 @@ bool ParseModelFlagsFromCommandLineFlags( "height, input array width, input array depth."), Flag("batch_size", parsed_flags.batch_size.bind(), parsed_flags.batch_size.default_value(), - "Batch size for the model. Replaces the first dimension of an " - "input size array if undefined. Use only with SavedModels when " - "--input_shapes flag is not specified. Always use --input_shapes " - "flag with frozen graphs."), + "Deprecated. Batch size for the model. Replaces the first dimension " + "of an input size array if undefined. Use only with SavedModels " + "when --input_shapes flag is not specified. Always use " + "--input_shapes flag with frozen graphs."), Flag("input_data_type", parsed_flags.input_data_type.bind(), parsed_flags.input_data_type.default_value(), "Deprecated: use --input_data_types instead. Input array type, if " @@ -322,6 +322,10 @@ void ReadModelFlagsFromCommandLineFlags( for (int i = 0; i < input_shapes.size(); ++i) { auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape(); shape->clear_dims(); + // Treat an empty input shape as a scalar. + if (input_shapes[i].empty()) { + continue; + } for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) { int size; CHECK(absl::SimpleAtoi(dim_str, &size)) diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD index 93fe756a55d378fa205ff88be5e18aff586e5dca..33c5b164622cee94d7ba16e7b1a3006dbacb9ca9 100644 --- a/tensorflow/contrib/lite/toco/python/BUILD +++ b/tensorflow/contrib/lite/toco/python/BUILD @@ -53,5 +53,8 @@ tf_py_test( data = [ ":toco_from_protos", ], - tags = ["no_pip"], + tags = [ + "no_oss", + "no_pip", + ], ) diff --git a/tensorflow/contrib/lite/toco/runtime/types.h b/tensorflow/contrib/lite/toco/runtime/types.h index f5de5a5781a5304634642680e6a3cef60e7b844b..207f2c1706ef4cc12572e381c38f61a504ece232 100644 --- a/tensorflow/contrib/lite/toco/runtime/types.h +++ b/tensorflow/contrib/lite/toco/runtime/types.h @@ -24,6 +24,7 @@ namespace toco { // TODO(ahentz): These are just stopgaps for now, untils we move all // the code over to tflite. using tflite::Dims; +using tflite::FullyConnectedWeightsFormat; using tflite::FusedActivationFunctionType; using tflite::RequiredBufferSizeForDims; diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD index 336e94de1ed3238d64f521cf1347acc8f0737de7..ea1fc2827ead7e7442bbf7f569e3ea88c3b0de57 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD @@ -60,6 +60,7 @@ cc_library( tf_cc_test( name = "resolve_svdf_test", srcs = ["resolve_svdf_test.cc"], + tags = ["no_oss"], deps = [ ":cluster", ":cluster_utils", diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD index e1025c66642d2860c5916bf7625f1c0403c9901c..709c53606b1081111fb2e2f8971ba71e5d38b629 100644 --- a/tensorflow/contrib/lite/toco/tflite/BUILD +++ b/tensorflow/contrib/lite/toco/tflite/BUILD @@ -24,8 +24,10 @@ cc_library( deps = [ ":types", "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:ptr_util", "@com_google_absl//absl/memory", "@flatbuffers", ], @@ -36,6 +38,7 @@ tf_cc_test( srcs = [ "operator_test.cc", ], + tags = ["no_oss"], deps = [ ":operator", "//tensorflow/contrib/lite/toco:tooling_util", @@ -65,6 +68,7 @@ tf_cc_test( srcs = [ "types_test.cc", ], + tags = ["no_oss"], deps = [ ":types", "@com_google_googletest//:gtest_main", @@ -97,6 +101,7 @@ tf_cc_test( srcs = [ "export_test.cc", ], + tags = ["no_oss"], deps = [ ":export", "//tensorflow/contrib/lite/schema:schema_fbs", @@ -130,6 +135,7 @@ tf_cc_test( srcs = [ "import_test.cc", ], + tags = ["no_oss"], deps = [ ":import", "//tensorflow/contrib/lite:schema_fbs_version", diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index a2d753657b0bf6c88f5c94a20a1240fb7c13a37c..5ad307af14a0613188482ae17aed491dea06f984 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -49,7 +49,7 @@ details::OperatorKey GetOperatorKey( const ::toco::Operator& op, const std::map>& ops_by_type) { string custom_code; - if (op.type == OperatorType::kTensorFlowUnsupported) { + if (op.type == OperatorType::kUnsupported) { const TensorFlowUnsupportedOperator& unsupported_op = static_cast(op); custom_code = unsupported_op.tensorflow_op; @@ -99,7 +99,8 @@ void LoadOperatorsMap( Offset>> ExportTensors( const Model& model, const details::TensorsMap& tensors_map, - FlatBufferBuilder* builder, std::vector* buffers_to_write) { + FlatBufferBuilder* builder, std::vector* buffers_to_write, + const std::set& variable_tensor_indices) { // In the end we will need to produce a vector sorted by the indices of the // tensors in the tensors_map. std::map> ordered_tensors; @@ -139,9 +140,11 @@ Offset>> ExportTensors( scale, zero_point); int index = tensors_map.at(tensor_name); + bool is_variable = + variable_tensor_indices.find(index) != variable_tensor_indices.end(); ordered_tensors[index] = CreateTensor(*builder, builder->CreateVector(shape), type, buffer_index, - builder->CreateString(tensor_name), q_param); + builder->CreateString(tensor_name), q_param, is_variable); } std::vector> tensor_vector; @@ -208,7 +211,7 @@ Offset>> ExportOperatorCodes( ordered_opcodes[op_index] = CreateOperatorCode(*builder, builtin_ops[name], 0, op_version); } else { - // This could be a kTensorFlowUnsupported, in which case we should be + // This could be a kUnsupported, in which case we should be // able to retrieve the original Tensorflow name from the OperatorKey, or // this could be a proper TOCO operator that is completely unknown to TF // Lite. @@ -239,7 +242,10 @@ Offset>> ExportOperators( const Model& model, const std::map>& ops_by_type, const details::OperatorsMap& operators_map, - const details::TensorsMap& tensors_map, FlatBufferBuilder* builder) { + const details::TensorsMap& tensors_map, FlatBufferBuilder* builder, + std::set* variable_tensor_indices) { + variable_tensor_indices->clear(); + // The operators are in execution order, so we just follow tf.mini order. std::vector> op_vector; for (const auto& op : model.operators) { @@ -256,18 +262,36 @@ Offset>> ExportOperators( int op_index = operators_map.at(GetOperatorKey(*op, ops_by_type)); - // This is a custom op unless we can find it in ops_by_type, and even then - // it could be a custom op (such as kTensorFlowUnsupported). + auto tflite_op_it = ops_by_type.find(op->type); + BaseOperator* tflite_op = tflite_op_it == ops_by_type.end() + ? nullptr + : tflite_op_it->second.get(); + // This is a custom op unless we can find it in ops_by_type, and even then + // it could be a custom op (such as kUnsupported). auto options = Options::Custom(0); - if (ops_by_type.count(op->type) != 0) { - options = ops_by_type.at(op->type)->Serialize(*op, builder); + + std::vector mutating_input_variables; + if (tflite_op) { + options = tflite_op->Serialize(*op, builder); + mutating_input_variables = tflite_op->GetMutatingInputVariables(*op); + + if (!mutating_input_variables.empty()) { + for (int i = 0; i < op->inputs.size(); ++i) { + if (!mutating_input_variables[i]) { + continue; + } + int32_t variable_tensor_index = tensors_map.at(op->inputs[i]); + variable_tensor_indices->insert(variable_tensor_index); + } + } } // The only supported CustomOptionFormat is FLEXBUFFERS now. op_vector.push_back(CreateOperator( *builder, op_index, builder->CreateVector(inputs), builder->CreateVector(outputs), options.type, options.builtin, - options.custom, ::tflite::CustomOptionsFormat_FLEXBUFFERS)); + options.custom, ::tflite::CustomOptionsFormat_FLEXBUFFERS, + builder->CreateVector(mutating_input_variables))); } return builder->CreateVector(op_vector); @@ -308,24 +332,17 @@ void Export( Array empty_array; buffers_to_write.push_back(&empty_array); - auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write); - auto inputs = ExportInputTensors(model, tensors_map, &builder); - auto outputs = ExportOutputTensors(model, tensors_map, &builder); - std::set error_summary; auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map, &builder, &error_summary); - const string fake_quant_operation_name = "FAKE_QUANT"; - - if (error_summary.count(fake_quant_operation_name) != 0) { - LOG(ERROR) - << fake_quant_operation_name - << " operation was not converted. If running quantized make sure you " - "are passing --inference_type=QUANTIZED_UINT8 and values for " - "--std_values and --mean_values."; - // Remove the fake quant operation from the errors, since it shouldn't - // be provided a custom implementation. - error_summary.erase(fake_quant_operation_name); + + for (const auto& op : model.operators) { + if (op->type == OperatorType::kFakeQuant) { + LOG(WARNING) << "FAKE_QUANT operation " << LogName(*op) + << " was not converted. If running quantized make sure you " + "are passing --inference_type=QUANTIZED_UINT8 and values " + "for --std_values and --mean_values."; + } } if (!allow_custom_ops && !error_summary.empty()) { // Remove ExpandDims and ReorderAxes from unimplemented list unless they @@ -353,11 +370,18 @@ void Export( << absl::StrJoin(error_summary_final, ", ") << "."; } - auto ops = - ExportOperators(model, ops_by_type, operators_map, tensors_map, &builder); + std::set variable_tensor_indices; + auto ops = ExportOperators(model, ops_by_type, operators_map, tensors_map, + &builder, &variable_tensor_indices); + + auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write, + variable_tensor_indices); + auto inputs = ExportInputTensors(model, tensors_map, &builder); + auto outputs = ExportOutputTensors(model, tensors_map, &builder); // TODO(aselle): add support to toco for multiple subgraphs. - auto subgraph = CreateSubGraph(builder, tensors, inputs, outputs, ops); + auto subgraph = CreateSubGraph(builder, tensors, inputs, outputs, ops, + /* name */ 0); std::vector> subgraphs = {subgraph}; auto buffers = ExportBuffers(model, buffers_to_write, &builder); diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h index 098d2163e6c2fe26f3cb9cdf9959df62a1a4baf0..58ea5c725c378827aac79f2a5a2cdca59ccc0162 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.h +++ b/tensorflow/contrib/lite/toco/tflite/export.h @@ -45,7 +45,7 @@ namespace details { using TensorsMap = std::unordered_map; // A key to identify an operator. -// Only when `type` is `kTensorFlowUnsupported`, `custom_code` is filled to +// Only when `type` is `kUnsupported`, `custom_code` is filled to // identify which operation is used. struct OperatorKey { OperatorKey(OperatorType type, const std::string& custom_code, int version) diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc index 409e7d72a57076ec2832c5d12b52829477624f74..a95937ba0f4f66fedfab6c1528c8dc4e417297d0 100644 --- a/tensorflow/contrib/lite/toco/tflite/export_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc @@ -73,8 +73,8 @@ TEST_F(ExportTest, LoadOperatorsMap) { EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]); EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]); EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]); - EXPECT_EQ(3, operators[details::OperatorKey( - OperatorType::kTensorFlowUnsupported, "MyCrazyOp", 1)]); + EXPECT_EQ(3, operators[details::OperatorKey(OperatorType::kUnsupported, + "MyCrazyOp", 1)]); } TEST_F(ExportTest, Export) { @@ -262,7 +262,7 @@ TEST_F(VersionedOpExportTest, Export) { EXPECT_EQ(1, (*operators)[1]->opcode_index()); } -// TODO(ahentz): tests for tensors, inputs, outpus, opcodes and operators. +// TODO(ahentz): tests for tensors, inputs, outputs, opcodes and operators. } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc index cb44a5e6d7356a1cf5597bbe48565c5b1e1949a6..1dd4915b31413e5afb04b45ee7c4893a2eded66d 100644 --- a/tensorflow/contrib/lite/toco/tflite/import.cc +++ b/tensorflow/contrib/lite/toco/tflite/import.cc @@ -124,7 +124,7 @@ void ImportOperators( new_op = ops_by_name.at(effective_opname) ->Deserialize(input_op->builtin_options(), input_op->custom_options()); - if (new_op->type == OperatorType::kTensorFlowUnsupported) { + if (new_op->type == OperatorType::kUnsupported) { auto* unsupported_op = static_cast(new_op.get()); unsupported_op->tensorflow_op = opname; @@ -221,6 +221,8 @@ std::unique_ptr Import(const ModelFlags& model_flags, model.get()); ImportIOTensors(*input_model, tensors_table, model.get()); + UndoWeightsShuffling(model.get()); + return model; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 7490ab960b9b0c62bef4c343927664ac6ae4eb9d..75808f2b690fb6699f86d61a3078ef458db6d295 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -14,13 +14,16 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/toco/tflite/operator.h" +// TODO(ycling): Consider refactoring to extract the LSTM definition out of +// graph_transformation module. +#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" #include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h" #include "tensorflow/contrib/lite/toco/tflite/custom_operator.h" #include "tensorflow/contrib/lite/toco/tflite/simple_operator.h" #include "tensorflow/contrib/lite/toco/tflite/types.h" - #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/util/ptr_util.h" namespace toco { @@ -279,25 +282,31 @@ class DepthToSpace : public CustomOperator { int GetVersion(const Operator& op) const override { return 1; } }; -class FakeQuant : public CustomOperator { +class FakeQuant + : public BuiltinOperator { public: - using CustomOperator::CustomOperator; - void WriteOptions(const TocoOperator& op, - flexbuffers::Builder* fbb) const override { - fbb->Float("min", op.minmax->min); - fbb->Float("max", op.minmax->max); - fbb->Int("num_bits", op.num_bits); + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateFakeQuantOptions( + *builder, op.minmax->min, op.minmax->max, op.num_bits, op.narrow_range); } - void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { auto* minmax = new MinMax; - minmax->min = m["min"].AsFloat(); - minmax->max = m["max"].AsFloat(); + minmax->min = options.min(); + minmax->max = options.max(); op->minmax.reset(minmax); - const auto& num_bits = m["num_bits"]; - op->num_bits = num_bits.IsInt() ? num_bits.AsInt32() : 8; + op->num_bits = options.num_bits(); + op->narrow_range = options.narrow_range(); } - int GetVersion(const Operator& op) const override { return 1; } + int GetVersion(const Operator& op) const override { + const auto& fq_op = static_cast(op); + return fq_op.narrow_range ? 2 : 1; + } }; class FullyConnected @@ -311,16 +320,47 @@ class FullyConnected flatbuffers::FlatBufferBuilder* builder) const override { auto activation_function = ActivationFunction::Serialize(op.fused_activation_function); - return ::tflite::CreateFullyConnectedOptions(*builder, activation_function); + ::tflite::FullyConnectedOptionsWeightsFormat tflite_weights_format; + switch (op.weights_format) { + case FullyConnectedWeightsFormat::kDefault: + tflite_weights_format = + ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT; + break; + case FullyConnectedWeightsFormat::kShuffled4x16Int8: + tflite_weights_format = + ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8; + break; + default: + LOG(ERROR) << "Unhandled FC weights format"; + tflite_weights_format = + ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT; + } + return ::tflite::CreateFullyConnectedOptions(*builder, activation_function, + tflite_weights_format); } void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override { op->fused_activation_function = ActivationFunction::Deserialize(options.fused_activation_function()); + switch (options.weights_format()) { + case ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT: + op->weights_format = FullyConnectedWeightsFormat::kDefault; + break; + case ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: + op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8; + break; + default: + LOG(ERROR) << "Unhandled FC weights format"; + op->weights_format = FullyConnectedWeightsFormat::kDefault; + } } - int GetVersion(const Operator& op) const override { return 1; } + int GetVersion(const Operator& op) const override { + const auto& fc_op = static_cast(op); + return fc_op.weights_format == FullyConnectedWeightsFormat::kDefault ? 1 + : 2; + } }; class Gather : public BuiltinOperator WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - return ::tflite::CreateGatherOptions(*builder, op.axis); + int axis = op.axis ? op.axis.value() : 0; + return ::tflite::CreateGatherOptions(*builder, axis); } void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override { - op->axis = options.axis(); + op->axis = {options.axis()}; } int GetVersion(const Operator& op) const override { return 1; } @@ -668,16 +709,93 @@ class Lstm : public BuiltinOperator GetMutatingInputVariables( + const Operator& op) const override { + const auto& lstm_op = static_cast(op); + + std::vector mutating_input_variables(op.inputs.size(), false); + switch (lstm_op.kernel_type) { + case LstmCellOperator::KERNEL_FULL: { + mutating_input_variables[kInputActivationStateTensor] = true; + mutating_input_variables[kInputCellStateTensor] = true; + break; + } + case LstmCellOperator::KERNEL_BASIC: { + mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true; + mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true; + break; + } + } + return mutating_input_variables; + } +}; + +class Mean : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateReducerOptions(*builder, op.keep_dims); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->keep_dims = options.keep_dims(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class Sum + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateReducerOptions(*builder, op.keep_dims); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->keep_dims = options.keep_dims(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class ReduceMax + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateReducerOptions(*builder, op.keep_dims); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->keep_dims = options.keep_dims(); + } + + int GetVersion(const Operator& op) const override { return 1; } }; -class Mean : public BuiltinOperator { +class ReduceProd + : public BuiltinOperator { public: using BuiltinOperator::BuiltinOperator; flatbuffers::Offset WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - return ::tflite::CreateMeanOptions(*builder, op.keep_dims); + return ::tflite::CreateReducerOptions(*builder, op.keep_dims); } void ReadOptions(const TfLiteOptions& options, @@ -812,6 +930,25 @@ class ArgMax : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateArgMinOptions( + *builder, DataType::Serialize(op.output_data_type)); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->output_data_type = DataType::Deserialize(options.output_type()); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class TransposeConv : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreatePackOptions(*builder, op.values_count, op.axis); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->values_count = options.values_count(); + op->axis = options.axis(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class Shape + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateShapeOptions( + *builder, DataType::Serialize(op.output_data_type)); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->output_data_type = DataType::Deserialize(options.out_type()); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class OneHot : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateOneHotOptions(*builder, op.axis); + } + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->axis = options.axis(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class CTCBeamSearchDecoder + : public CustomOperator { + public: + using CustomOperator::CustomOperator; + + void WriteOptions(const TocoOperator& op, + flexbuffers::Builder* fbb) const override { + fbb->Int("beam_width", op.beam_width); + fbb->Int("top_paths", op.top_paths); + fbb->Bool("merge_repeated", op.merge_repeated); + } + + void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { + op->beam_width = m["beam_width"].AsInt32(); + op->top_paths = m["top_paths"].AsInt32(); + op->merge_repeated = m["merge_repeated"].AsBool(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class TensorFlowUnsupported : public BaseOperator { public: using BaseOperator::BaseOperator; @@ -936,6 +1151,20 @@ class TensorFlowUnsupported : public BaseOperator { fbb->Bool(key, attr.b()); has_valid_attr = true; break; + case tensorflow::AttrValue::kList: + if (attr.list().i_size() > 0) { + auto start = fbb->StartVector(key); + for (const int64_t v : attr.list().i()) { + fbb->Add(v); + } + fbb->EndVector(start, /*typed=*/true, /*fixed=*/false); + has_valid_attr = true; + } else { + LOG(WARNING) + << "Ignoring unsupported type in list attribute with key '" + << key << "'"; + } + break; default: LOG(WARNING) << "Ignoring unsupported attribute type with key '" << key << "'"; @@ -971,7 +1200,21 @@ class TensorFlowUnsupported : public BaseOperator { break; case flexbuffers::TYPE_BOOL: (*attr)[key].set_b(value.AsBool()); + if (string(key) == "_output_quantized") { + op->quantized = value.AsBool(); + } + if (string(key) == "_support_output_type_float_in_quantized_op") { + op->support_output_type_float_in_quantized_op = value.AsBool(); + } break; + case flexbuffers::TYPE_VECTOR_INT: { + auto* list = (*attr)[key].mutable_list(); + const auto& vector = value.AsTypedVector(); + for (size_t i = 0; i < vector.size(); i++) { + list->add_i(vector[i].AsInt64()); + } + break; + } default: LOG(WARNING) << "Ignoring unsupported attribute type with key '" << key << "'"; @@ -992,138 +1235,176 @@ namespace { // Build a vector containing all the known operators. std::vector> BuildOperatorList() { std::vector> ops; - + using tensorflow::MakeUnique; // Builtin Operators. - ops.emplace_back(new Add(::tflite::BuiltinOperator_ADD, OperatorType::kAdd)); - ops.emplace_back(new Div(::tflite::BuiltinOperator_DIV, OperatorType::kDiv)); - ops.emplace_back(new Sub(::tflite::BuiltinOperator_SUB, OperatorType::kSub)); - ops.emplace_back(new AveragePool(::tflite::BuiltinOperator_AVERAGE_POOL_2D, - OperatorType::kAveragePool)); - ops.emplace_back( - new SpaceToBatchND(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND, - OperatorType::kSpaceToBatchND)); - ops.emplace_back( - new BatchToSpaceND(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND, - OperatorType::kBatchToSpaceND)); - ops.emplace_back(new Concatenation(::tflite::BuiltinOperator_CONCATENATION, - OperatorType::kConcatenation)); - ops.emplace_back( - new Convolution(::tflite::BuiltinOperator_CONV_2D, OperatorType::kConv)); - ops.emplace_back( - new DepthwiseConvolution(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D, - OperatorType::kDepthwiseConv)); - ops.emplace_back(new FullyConnected(::tflite::BuiltinOperator_FULLY_CONNECTED, - OperatorType::kFullyConnected)); - ops.emplace_back( - new Gather(::tflite::BuiltinOperator_GATHER, OperatorType::kGather)); - ops.emplace_back( - new L2Normalization(::tflite::BuiltinOperator_L2_NORMALIZATION, - OperatorType::kL2Normalization)); - ops.emplace_back( - new L2Pool(::tflite::BuiltinOperator_L2_POOL_2D, OperatorType::kL2Pool)); - ops.emplace_back(new LocalResponseNormalization( + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_ADD, OperatorType::kAdd)); + ops.push_back( + MakeUnique

(::tflite::BuiltinOperator_DIV, OperatorType::kDiv)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_SUB, OperatorType::kSub)); + ops.push_back(MakeUnique( + ::tflite::BuiltinOperator_AVERAGE_POOL_2D, OperatorType::kAveragePool)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND, + OperatorType::kSpaceToBatchND)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND, + OperatorType::kBatchToSpaceND)); + ops.push_back(MakeUnique( + ::tflite::BuiltinOperator_CONCATENATION, OperatorType::kConcatenation)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_CONV_2D, + OperatorType::kConv)); + ops.push_back(MakeUnique( + ::tflite::BuiltinOperator_DEPTHWISE_CONV_2D, + OperatorType::kDepthwiseConv)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_FULLY_CONNECTED, + OperatorType::kFullyConnected)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_GATHER, + OperatorType::kGather)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_L2_NORMALIZATION, + OperatorType::kL2Normalization)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_L2_POOL_2D, + OperatorType::kL2Pool)); + ops.push_back(MakeUnique( ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, OperatorType::kLocalResponseNormalization)); - ops.emplace_back(new MaxPool(::tflite::BuiltinOperator_MAX_POOL_2D, - OperatorType::kMaxPool)); - ops.emplace_back(new Mul(::tflite::BuiltinOperator_MUL, OperatorType::kMul)); - ops.emplace_back(new Pad(::tflite::BuiltinOperator_PAD, OperatorType::kPad)); - ops.emplace_back( - new PadV2(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2)); - ops.emplace_back(new Reshape(::tflite::BuiltinOperator_RESHAPE, - OperatorType::kTensorFlowReshape)); - ops.emplace_back( - new Softmax(::tflite::BuiltinOperator_SOFTMAX, OperatorType::kSoftmax)); - ops.emplace_back(new SpaceToDepth(::tflite::BuiltinOperator_SPACE_TO_DEPTH, - OperatorType::kSpaceToDepth)); - ops.emplace_back( - new Svdf(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf)); - ops.emplace_back(new Transpose(::tflite::BuiltinOperator_TRANSPOSE, - OperatorType::kTranspose)); - ops.emplace_back( - new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean)); - ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR, - OperatorType::kResizeBilinear)); - ops.emplace_back( - new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze)); - ops.emplace_back(new Split(::tflite::BuiltinOperator_SPLIT, - OperatorType::kTensorFlowSplit)); - ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE, - OperatorType::kStridedSlice)); - ops.emplace_back( - new TopK_V2(::tflite::BuiltinOperator_TOPK_V2, OperatorType::kTopK_V2)); - ops.emplace_back( - new Lstm(::tflite::BuiltinOperator_LSTM, OperatorType::kLstmCell)); - ops.emplace_back( - new Cast(::tflite::BuiltinOperator_CAST, OperatorType::kCast)); - ops.emplace_back( - new ArgMax(::tflite::BuiltinOperator_ARG_MAX, OperatorType::kArgMax)); - ops.emplace_back( - new Tile(::tflite::BuiltinOperator_TILE, OperatorType::kTensorFlowTile)); - ops.emplace_back(new ExpandDims(::tflite::BuiltinOperator_EXPAND_DIMS, - OperatorType::kExpandDims)); - ops.emplace_back(new TransposeConv(::tflite::BuiltinOperator_TRANSPOSE_CONV, - OperatorType::kTransposeConv)); - ops.emplace_back(new SparseToDense(::tflite::BuiltinOperator_SPARSE_TO_DENSE, - OperatorType::kSparseToDense)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_MAX_POOL_2D, + OperatorType::kMaxPool)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_MUL, OperatorType::kMul)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_PAD, OperatorType::kPad)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_RESHAPE, + OperatorType::kReshape)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_SOFTMAX, + OperatorType::kSoftmax)); + ops.push_back(MakeUnique( + ::tflite::BuiltinOperator_SPACE_TO_DEPTH, OperatorType::kSpaceToDepth)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_TRANSPOSE, + OperatorType::kTranspose)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_MEAN, OperatorType::kMean)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_SUM, OperatorType::kSum)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_REDUCE_PROD, + OperatorType::kReduceProd)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_REDUCE_MAX, + OperatorType::kReduceMax)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_RESIZE_BILINEAR, + OperatorType::kResizeBilinear)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_SQUEEZE, + OperatorType::kSqueeze)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit)); + ops.push_back(MakeUnique( + ::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_TOPK_V2, + OperatorType::kTopK_V2)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_LSTM, + OperatorType::kLstmCell)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_CAST, OperatorType::kCast)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_ARG_MAX, + OperatorType::kArgMax)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_ARG_MIN, + OperatorType::kArgMin)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_TILE, OperatorType::kTile)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_EXPAND_DIMS, + OperatorType::kExpandDims)); + ops.push_back(MakeUnique( + ::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv)); + ops.push_back(MakeUnique( + ::tflite::BuiltinOperator_SPARSE_TO_DENSE, OperatorType::kSparseToDense)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_FAKE_QUANT, + OperatorType::kFakeQuant)); + ops.push_back( + MakeUnique(::tflite::BuiltinOperator_PACK, OperatorType::kPack)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_ONE_HOT, + OperatorType::kOneHot)); // Custom Operators. - ops.emplace_back( - new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); - ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant)); - ops.emplace_back(new TensorFlowUnsupported( - "TENSORFLOW_UNSUPPORTED", OperatorType::kTensorFlowUnsupported)); + ops.push_back( + MakeUnique("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); + ops.push_back(MakeUnique( + "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder)); + ops.push_back(MakeUnique("TENSORFLOW_UNSUPPORTED", + OperatorType::kUnsupported)); // There operators are supported by Toco, but not by TF Lite, and has no // attributes. - ops.emplace_back( - new SimpleOperator("ADDN", OperatorType::kAddN)); - ops.emplace_back(new SimpleOperator( - "RSQRT", OperatorType::kTensorFlowRsqrt)); + ops.push_back( + MakeUnique>("ADDN", OperatorType::kAddN)); // Simple Operators. - ops.emplace_back(new SimpleOperator( + ops.push_back(MakeUnique>( "DEQUANTIZE", OperatorType::kDequantize)); - ops.emplace_back( - new SimpleOperator("FLOOR", OperatorType::kFloor)); - ops.emplace_back( - new SimpleOperator("RELU", OperatorType::kRelu)); - ops.emplace_back( - new SimpleOperator("RELU_N1_TO_1", OperatorType::kRelu1)); - ops.emplace_back( - new SimpleOperator("RELU6", OperatorType::kRelu6)); - ops.emplace_back( - new SimpleOperator("PRELU", OperatorType::kPRelu)); - ops.emplace_back(new SimpleOperator( + ops.push_back( + MakeUnique>("FLOOR", OperatorType::kFloor)); + ops.push_back( + MakeUnique>("RELU", OperatorType::kRelu)); + ops.push_back(MakeUnique>( + "RELU_N1_TO_1", OperatorType::kRelu1)); + ops.push_back( + MakeUnique>("RELU6", OperatorType::kRelu6)); + ops.push_back( + MakeUnique>("PRELU", OperatorType::kPRelu)); + ops.push_back(MakeUnique>( "LOGISTIC", OperatorType::kLogistic)); - ops.emplace_back( - new SimpleOperator("TANH", OperatorType::kTanh)); - ops.emplace_back(new SimpleOperator("EXP", OperatorType::kExp)); - ops.emplace_back(new SimpleOperator( + ops.push_back( + MakeUnique>("TANH", OperatorType::kTanh)); + ops.push_back( + MakeUnique>("EXP", OperatorType::kExp)); + ops.push_back(MakeUnique>( "LOG_SOFTMAX", OperatorType::kLogSoftmax)); - ops.emplace_back(new SimpleOperator( - "MAXIMUM", OperatorType::kTensorFlowMaximum)); - ops.emplace_back(new SimpleOperator( - "MINIMUM", OperatorType::kTensorFlowMinimum)); - ops.emplace_back(new SimpleOperator( - "GREATER", OperatorType::kTensorFlowGreater)); - ops.emplace_back(new SimpleOperator( - "GREATER_EQUAL", OperatorType::kTensorFlowGreaterEqual)); - ops.emplace_back(new SimpleOperator( - "LESS", OperatorType::kTensorFlowLess)); - ops.emplace_back(new SimpleOperator( - "LESS_EQUAL", OperatorType::kTensorFlowLessEqual)); - ops.emplace_back(new SimpleOperator( - "EQUAL", OperatorType::kTensorFlowEqual)); - ops.emplace_back(new SimpleOperator( - "NOT_EQUAL", OperatorType::kTensorFlowNotEqual)); - ops.emplace_back(new SimpleOperator("NEG", OperatorType::kNeg)); - ops.emplace_back( - new SimpleOperator("SELECT", OperatorType::kSelect)); - ops.emplace_back( - new SimpleOperator("SLICE", OperatorType::kSlice)); + ops.push_back(MakeUnique>( + "MAXIMUM", OperatorType::kMaximum)); // Element-wise Maximum + ops.push_back(MakeUnique>( + "MINIMUM", OperatorType::kMinimum)); // Element-wise Minimum + ops.push_back(MakeUnique>( + "GREATER", OperatorType::kGreater)); + ops.push_back(MakeUnique>( + "GREATER_EQUAL", OperatorType::kGreaterEqual)); + ops.push_back(MakeUnique>( + "LESS", OperatorType::kLess)); + ops.push_back(MakeUnique>( + "LESS_EQUAL", OperatorType::kLessEqual)); + ops.push_back(MakeUnique>( + "EQUAL", OperatorType::kEqual)); + ops.push_back(MakeUnique>( + "NOT_EQUAL", OperatorType::kNotEqual)); + ops.push_back( + MakeUnique>("NEG", OperatorType::kNeg)); + ops.push_back(MakeUnique>( + "SELECT", OperatorType::kSelect)); + ops.push_back( + MakeUnique>("SLICE", OperatorType::kSlice)); + ops.push_back( + MakeUnique>("POW", OperatorType::kPow)); + ops.push_back(MakeUnique>( + "LOGICAL_OR", OperatorType::kLogicalOr)); + ops.emplace_back(new SimpleOperator( + "LOGICAL_AND", OperatorType::kLogicalAnd)); + ops.emplace_back(new SimpleOperator( + "LOGICAL_NOT", OperatorType::kLogicalNot)); // Element-wise operator - ops.emplace_back(new SimpleOperator("SIN", OperatorType::kSin)); - ops.emplace_back(new SimpleOperator("LOG", OperatorType::kLog)); + ops.push_back( + MakeUnique>("SIN", OperatorType::kSin)); + ops.push_back( + MakeUnique>("LOG", OperatorType::kLog)); + ops.push_back(MakeUnique>( + "SQRT", OperatorType::kSqrt)); + ops.push_back(MakeUnique>( + "RSQRT", OperatorType::kRsqrt)); return ops; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h index 5e9c20e40dd6274e0839379883b6dbe53064a0fc..d9ea23edf2b08146773ca58762623397e0f6257c 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.h +++ b/tensorflow/contrib/lite/toco/tflite/operator.h @@ -87,6 +87,17 @@ class BaseOperator { // overridden. (See example in `operator_test.cc`) virtual int GetVersion(const Operator& op) const = 0; + // Given a Toco `Operator`, return a list of booleans indicating the op + // mutates which input variables. + // * If the op mutates any input variables, it should return a list of bool + // with the same length as inputs. + // * Otherwise, it will return an empty list. + virtual std::vector GetMutatingInputVariables( + const Operator& op) const { + // Most ops don't have variable tensors. This function can be overridden. + return std::vector(); + } + private: string name_; OperatorType type_; diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 03bb20b3208196e964d950c0f0954d1fc0ba9e86..fc854461b4e816e12e12590479501b6542258fef 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -112,20 +112,27 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator("LOG_SOFTMAX", OperatorType::kLogSoftmax); CheckSimpleOperator( - "MAXIMUM", OperatorType::kTensorFlowMaximum); + "MAXIMUM", OperatorType::kMaximum); // Element-wise Maximum CheckSimpleOperator( - "MINIMUM", OperatorType::kTensorFlowMinimum); - CheckSimpleOperator("LESS", - OperatorType::kTensorFlowLess); + "MINIMUM", OperatorType::kMinimum); // Element-wise Minimum + CheckSimpleOperator("LESS", OperatorType::kLess); CheckSimpleOperator("NEG", OperatorType::kNeg); CheckSimpleOperator("SELECT", OperatorType::kSelect); CheckSimpleOperator("SLICE", OperatorType::kSlice); CheckSimpleOperator("SIN", OperatorType::kSin); - CheckSimpleOperator("EQUAL", - OperatorType::kTensorFlowEqual); - CheckSimpleOperator( - "NOT_EQUAL", OperatorType::kTensorFlowNotEqual); + CheckSimpleOperator("EQUAL", OperatorType::kEqual); + CheckSimpleOperator("NOT_EQUAL", + OperatorType::kNotEqual); CheckSimpleOperator("LOG", OperatorType::kLog); + CheckSimpleOperator("SQRT", OperatorType::kSqrt); + CheckSimpleOperator("RSQRT", OperatorType::kRsqrt); + CheckSimpleOperator("POW", OperatorType::kPow); + CheckSimpleOperator("LOGICAL_OR", + OperatorType::kLogicalOr); + CheckSimpleOperator("LOGICAL_AND", + OperatorType::kLogicalAnd); + CheckSimpleOperator("LOGICAL_NOT", + OperatorType::kLogicalNot); } TEST_F(OperatorTest, BuiltinAdd) { @@ -254,7 +261,7 @@ TEST_F(OperatorTest, BuiltinReshape) { TensorFlowReshapeOperator op; op.shape = {1, 2, 4, 5, 8}; auto output_toco_op = SerializeAndDeserialize( - GetOperator("RESHAPE", OperatorType::kTensorFlowReshape), op); + GetOperator("RESHAPE", OperatorType::kReshape), op); EXPECT_EQ(op.shape, output_toco_op->shape); } @@ -277,8 +284,8 @@ TEST_F(OperatorTest, BuiltinSpaceToDepth) { TEST_F(OperatorTest, CustomSplit) { TensorFlowSplitOperator op; op.num_split = 123; - auto output_toco_op = SerializeAndDeserialize( - GetOperator("SPLIT", OperatorType::kTensorFlowSplit), op); + auto output_toco_op = + SerializeAndDeserialize(GetOperator("SPLIT", OperatorType::kSplit), op); EXPECT_EQ(op.num_split, output_toco_op->num_split); } @@ -415,6 +422,13 @@ TEST_F(OperatorTest, BuiltinArgMax) { EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type); } +TEST_F(OperatorTest, BuiltinArgMin) { + ArgMinOperator op; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("ARG_MIN", OperatorType::kArgMin), op); + EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type); +} + TEST_F(OperatorTest, BuiltinTransposeConv) { TransposeConvOperator op; op.stride_width = 123; @@ -427,6 +441,14 @@ TEST_F(OperatorTest, BuiltinTransposeConv) { EXPECT_EQ(op.padding.type, output_toco_op->padding.type); } +TEST_F(OperatorTest, BuiltinShape) { + TensorFlowShapeOperator op; + op.output_data_type = ArrayDataType::kInt64; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("SHAPE", OperatorType::kShape), op); + EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type); +} + TEST_F(OperatorTest, BuiltinSparseToDense) { SparseToDenseOperator op; op.validate_indices = false; @@ -436,6 +458,38 @@ TEST_F(OperatorTest, BuiltinSparseToDense) { EXPECT_EQ(op.validate_indices, output_toco_op->validate_indices); } +TEST_F(OperatorTest, BuiltinPack) { + PackOperator op; + op.values_count = 3; + op.axis = 1; + std::unique_ptr output_toco_op = + SerializeAndDeserialize(GetOperator("PACK", OperatorType::kPack), op); + EXPECT_EQ(op.values_count, output_toco_op->values_count); + EXPECT_EQ(op.axis, output_toco_op->axis); +} + +TEST_F(OperatorTest, BuiltinOneHot) { + OneHotOperator op; + op.axis = 2; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("ONE_HOT", OperatorType::kOneHot), op); + EXPECT_EQ(op.axis, output_toco_op->axis); +} + +TEST_F(OperatorTest, CustomCTCBeamSearchDecoder) { + CTCBeamSearchDecoderOperator op; + op.beam_width = 3; + op.top_paths = 2; + op.merge_repeated = false; + std::unique_ptr output_toco_op = + SerializeAndDeserialize(GetOperator("CTC_BEAM_SEARCH_DECODER", + OperatorType::kCTCBeamSearchDecoder), + op); + EXPECT_EQ(op.beam_width, output_toco_op->beam_width); + EXPECT_EQ(op.top_paths, output_toco_op->top_paths); + EXPECT_EQ(op.merge_repeated, output_toco_op->merge_repeated); +} + TEST_F(OperatorTest, TensorFlowUnsupported) { TensorFlowUnsupportedOperator op; op.tensorflow_op = "MyCustomUnsupportedOp"; @@ -446,12 +500,17 @@ TEST_F(OperatorTest, TensorFlowUnsupported) { (*attr)["str_attr"].set_s("Hello World"); (*attr)["int_attr"].set_i(17); (*attr)["bool_attr"].set_b(true); + { + auto* list = (*attr)["list_int_attr"].mutable_list(); + list->add_i(1); + list->add_i(20); + list->add_i(1LL << 40); + list->add_i(-(1LL << 40)); + } node_def.SerializeToString(&op.tensorflow_node_def); - auto output_toco_op = - SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED", - OperatorType::kTensorFlowUnsupported), - op); + auto output_toco_op = SerializeAndDeserialize( + GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op); ::tensorflow::NodeDef output_node_def; output_node_def.ParseFromString(output_toco_op->tensorflow_node_def); @@ -460,15 +519,22 @@ TEST_F(OperatorTest, TensorFlowUnsupported) { EXPECT_EQ("Hello World", output_attr.at("str_attr").s()); EXPECT_EQ(17, output_attr.at("int_attr").i()); EXPECT_EQ(true, output_attr.at("bool_attr").b()); + + { + const auto& list = output_attr.at("list_int_attr").list(); + ASSERT_EQ(4, list.i_size()); + EXPECT_EQ(1, list.i(0)); + EXPECT_EQ(20, list.i(1)); + EXPECT_EQ(1LL << 40, list.i(2)); + EXPECT_EQ(-(1LL << 40), list.i(3)); + } } TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) { TensorFlowUnsupportedOperator op; op.tensorflow_op = "MyCustomUnsupportedOp"; - auto output_toco_op = - SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED", - OperatorType::kTensorFlowUnsupported), - op); + auto output_toco_op = SerializeAndDeserialize( + GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op); ::tensorflow::NodeDef output_node_def; output_node_def.ParseFromString(output_toco_op->tensorflow_node_def); diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc index 4867c3a62e68406428644cd05bddf212008c2656..754f0b4b8c661355c99d9e5a86f2d7844414a303 100644 --- a/tensorflow/contrib/lite/toco/tflite/types.cc +++ b/tensorflow/contrib/lite/toco/tflite/types.cc @@ -88,6 +88,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) { switch (array_data_type) { case ArrayDataType::kFloat: return ::tflite::TensorType_FLOAT32; + case ArrayDataType::kInt16: + return ::tflite::TensorType_INT16; case ArrayDataType::kInt32: return ::tflite::TensorType_INT32; case ArrayDataType::kInt64: @@ -98,6 +100,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) { return ::tflite::TensorType_STRING; case ArrayDataType::kBool: return ::tflite::TensorType_BOOL; + case ArrayDataType::kComplex64: + return ::tflite::TensorType_COMPLEX64; default: // FLOAT32 is filled for unknown data types. // TODO(ycling): Implement type inference in TF Lite interpreter. @@ -109,6 +113,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) { switch (::tflite::TensorType(tensor_type)) { case ::tflite::TensorType_FLOAT32: return ArrayDataType::kFloat; + case ::tflite::TensorType_INT16: + return ArrayDataType::kInt16; case ::tflite::TensorType_INT32: return ArrayDataType::kInt32; case ::tflite::TensorType_INT64: @@ -119,6 +125,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) { return ArrayDataType::kUint8; case ::tflite::TensorType_BOOL: return ArrayDataType::kBool; + case ::tflite::TensorType_COMPLEX64: + return ArrayDataType::kComplex64; default: LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'."; } @@ -131,6 +139,8 @@ flatbuffers::Offset> DataBuffer::Serialize( switch (array.data_type) { case ArrayDataType::kFloat: return CopyBuffer(array, builder); + case ArrayDataType::kInt16: + return CopyBuffer(array, builder); case ArrayDataType::kInt32: return CopyBuffer(array, builder); case ArrayDataType::kInt64: @@ -141,6 +151,8 @@ flatbuffers::Offset> DataBuffer::Serialize( return CopyBuffer(array, builder); case ArrayDataType::kBool: return CopyBoolToBuffer(array, builder); + case ArrayDataType::kComplex64: + return CopyBuffer(array, builder); default: LOG(FATAL) << "Unhandled array data type."; } @@ -154,6 +166,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor, switch (tensor.type()) { case ::tflite::TensorType_FLOAT32: return CopyBuffer(buffer, array); + case ::tflite::TensorType_INT16: + return CopyBuffer(buffer, array); case ::tflite::TensorType_INT32: return CopyBuffer(buffer, array); case ::tflite::TensorType_INT64: @@ -164,6 +178,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor, return CopyBuffer(buffer, array); case ::tflite::TensorType_BOOL: return CopyBuffer(buffer, array); + case ::tflite::TensorType_COMPLEX64: + return CopyBuffer(buffer, array); default: LOG(FATAL) << "Unhandled tensor type."; } diff --git a/tensorflow/contrib/lite/toco/tflite/types_test.cc b/tensorflow/contrib/lite/toco/tflite/types_test.cc index 564f303b9bb41a777633ecabd666aa93ec3faefe..8e9f30ba3a6e6b98fa9c4237567b0797a5a797aa 100644 --- a/tensorflow/contrib/lite/toco/tflite/types_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/types_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/toco/tflite/types.h" +#include + #include #include @@ -71,7 +73,8 @@ TEST(DataType, SupportedTypes) { {ArrayDataType::kInt32, ::tflite::TensorType_INT32}, {ArrayDataType::kInt64, ::tflite::TensorType_INT64}, {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32}, - {ArrayDataType::kBool, ::tflite::TensorType_BOOL}}; + {ArrayDataType::kBool, ::tflite::TensorType_BOOL}, + {ArrayDataType::kComplex64, ::tflite::TensorType_COMPLEX64}}; for (auto x : testdata) { EXPECT_EQ(x.second, DataType::Serialize(x.first)); EXPECT_EQ(x.first, DataType::Deserialize(x.second)); @@ -151,6 +154,12 @@ TEST(DataBuffer, Int32) { ::testing::ElementsAre(1, 1 << 30)); } +TEST(DataBuffer, Int16) { + Array recovered = ToFlatBufferAndBack({1, 1 << 14}); + EXPECT_THAT(recovered.GetBuffer().data, + ::testing::ElementsAre(1, 1 << 14)); +} + TEST(DataBuffer, String) { Array recovered = ToFlatBufferAndBack( {"AA", "BBB", "Best. String. Ever."}); @@ -165,6 +174,14 @@ TEST(DataBuffer, Bool) { ::testing::ElementsAre(true, false, true)); } +TEST(DataBuffer, Complex64) { + Array recovered = ToFlatBufferAndBack( + {std::complex(1.0f, 2.0f), std::complex(3.0f, 4.0f)}); + EXPECT_THAT(recovered.GetBuffer().data, + ::testing::ElementsAre(std::complex(1.0f, 2.0f), + std::complex(3.0f, 4.0f))); +} + TEST(Padding, All) { EXPECT_EQ(::tflite::Padding_SAME, Padding::Serialize(PaddingType::kSame)); EXPECT_EQ(PaddingType::kSame, Padding::Deserialize(::tflite::Padding_SAME)); diff --git a/tensorflow/contrib/lite/toco/toco.cc b/tensorflow/contrib/lite/toco/toco.cc index 8041aa9e7fbfdaf44134395fee4b2bb01633893a..0b460bd178a49cafefd3438b7ae1c38a07b2ab7c 100644 --- a/tensorflow/contrib/lite/toco/toco.cc +++ b/tensorflow/contrib/lite/toco/toco.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" #include "tensorflow/contrib/lite/toco/toco_flags.pb.h" #include "tensorflow/contrib/lite/toco/toco_port.h" -#include "tensorflow/contrib/lite/toco/toco_saved_model.h" #include "tensorflow/contrib/lite/toco/toco_tooling.h" #include "tensorflow/contrib/lite/toco/toco_types.h" #include "tensorflow/core/platform/logging.h" @@ -49,17 +48,6 @@ void CheckFrozenModelPermissions(const Arg& input_file) { << input_file.value() << ".\n"; } -// Checks the permissions of the SavedModel directory. -void CheckSavedModelPermissions(const Arg& savedmodel_directory) { - QCHECK(savedmodel_directory.specified()) - << "Missing required flag --savedmodel_directory.\n"; - QCHECK( - port::file::Exists(savedmodel_directory.value(), port::file::Defaults()) - .ok()) - << "Specified savedmodel_directory does not exist: " - << savedmodel_directory.value() << ".\n"; -} - // Reads the contents of the GraphDef from either the frozen graph file or the // SavedModel directory. If it reads the SavedModel directory, it updates the // ModelFlags and TocoFlags accordingly. @@ -69,24 +57,16 @@ void ReadInputData(const ParsedTocoFlags& parsed_toco_flags, string* graph_def_contents) { port::CheckInitGoogleIsDone("InitGoogle is not done yet.\n"); - bool has_input_file = parsed_toco_flags.input_file.specified(); - bool has_savedmodel_dir = parsed_toco_flags.savedmodel_directory.specified(); - - // Ensure either input_file or savedmodel_directory flag has been set. - QCHECK_NE(has_input_file, has_savedmodel_dir) - << "Specify either input_file or savedmodel_directory flag.\n"; + // Ensure savedmodel_directory is not set. + QCHECK(!parsed_toco_flags.savedmodel_directory.specified()) + << "Use `tensorflow/contrib/lite/python/tflite_convert` script with " + << "SavedModel directories.\n"; // Checks the input file permissions and reads the contents. - if (has_input_file) { - CheckFrozenModelPermissions(parsed_toco_flags.input_file); - CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(), - graph_def_contents, port::file::Defaults()) - .ok()); - } else { - CheckSavedModelPermissions(parsed_toco_flags.savedmodel_directory); - GetSavedModelContents(parsed_toco_flags, parsed_model_flags, toco_flags, - model_flags, graph_def_contents); - } + CheckFrozenModelPermissions(parsed_toco_flags.input_file); + CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(), + graph_def_contents, port::file::Defaults()) + .ok()); } void ToolMain(const ParsedTocoFlags& parsed_toco_flags, diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc index 87a1e429b928bf59cb14597980602953732a7659..c6d0a03452f7477841d7e68665baf32dff45f41c 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -41,7 +41,7 @@ bool ParseTocoFlagsFromCommandLineFlags( "extension."), Flag("savedmodel_directory", parsed_flags.savedmodel_directory.bind(), parsed_flags.savedmodel_directory.default_value(), - "Full path to the directory containing the SavedModel."), + "Deprecated. Full path to the directory containing the SavedModel."), Flag("output_file", parsed_flags.output_file.bind(), parsed_flags.output_file.default_value(), "Output file. " @@ -55,9 +55,9 @@ bool ParseTocoFlagsFromCommandLineFlags( "One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT."), Flag("savedmodel_tagset", parsed_flags.savedmodel_tagset.bind(), parsed_flags.savedmodel_tagset.default_value(), - "Comma-separated set of tags identifying the MetaGraphDef within " - "the SavedModel to analyze. All tags in the tag set must be " - "specified."), + "Deprecated. Comma-separated set of tags identifying the " + "MetaGraphDef within the SavedModel to analyze. All tags in the tag " + "set must be specified."), Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(), parsed_flags.default_ranges_min.default_value(), "If defined, will be used as the default value for the min bound " diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto index ad4e94ded9f9730842a257e065d9aec2b1cbfac8..b4a9870d5834d1d5689d15ebc131ac0ead3e9850 100644 --- a/tensorflow/contrib/lite/toco/toco_flags.proto +++ b/tensorflow/contrib/lite/toco/toco_flags.proto @@ -37,7 +37,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 21. +// Next ID to use: 26. message TocoFlags { // Input file format optional FileFormat input_format = 1; diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/contrib/lite/toco/toco_port.cc index de76fd4032d24eff8a6c2fd0c16a911b9c00186b..14168fa33f77a75706a52f00ddfa6b1120d90626 100644 --- a/tensorflow/contrib/lite/toco/toco_port.cc +++ b/tensorflow/contrib/lite/toco/toco_port.cc @@ -38,7 +38,8 @@ void CopyToBuffer(const Cord& src, char* dest) { src.CopyToArray(dest); } } // namespace port } // namespace toco -#if defined(PLATFORM_GOOGLE) && !defined(__APPLE__) && !defined(__ANDROID__) +#if defined(PLATFORM_GOOGLE) && !defined(__APPLE__) && \ + !defined(__ANDROID__) && !defined(_WIN32) // Wrap Google file operations. @@ -115,9 +116,12 @@ string JoinPath(const string& a, const string& b) { } // namespace port } // namespace toco -#else // (__APPLE__ || __ANDROID__) +#else // !PLATFORM_GOOGLE || __APPLE__ || __ANDROID__ || _WIN32 #include +#if defined(_WIN32) +#include // for _close, _open, _read +#endif #include #include #include @@ -130,6 +134,19 @@ string JoinPath(const string& a, const string& b) { namespace toco { namespace port { +#if defined(_WIN32) +#define close _close +#define open _open +#define read _read +#define O_RDONLY _O_RDONLY +#define O_CREAT _O_CREAT +#define O_WRONLY _O_WRONLY +// Windows does not support the same set of file permissions as other platforms. +constexpr int kFileCreateMode = _S_IREAD | _S_IWRITE; +#else +constexpr int kFileCreateMode = 0664; +#endif // _WIN32 + static bool port_initialized = false; void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags) { @@ -209,7 +226,7 @@ tensorflow::Status GetContents(const string& path, string* output, tensorflow::Status SetContents(const string& filename, const string& contents, const file::Options& options) { - int fd = open(filename.c_str(), O_WRONLY | O_CREAT, 0664); + int fd = open(filename.c_str(), O_WRONLY | O_CREAT, kFileCreateMode); if (fd == -1) { return tensorflow::errors::Internal("can't open() for write"); } @@ -243,4 +260,4 @@ string JoinPath(const string& base, const string& filename) { } // namespace port } // namespace toco -#endif // (__APPLE || __ANDROID__) +#endif // !PLATFORM_GOOGLE || __APPLE || __ANDROID__ || _WIN32 diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.cc b/tensorflow/contrib/lite/toco/toco_saved_model.cc deleted file mode 100644 index 26f55a66c729894a990258080e397bb42ea98a13..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/toco/toco_saved_model.cc +++ /dev/null @@ -1,189 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "absl/strings/numbers.h" -#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" -#include "tensorflow/contrib/lite/toco/toco_saved_model.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" - -namespace toco { -namespace { - -// Loads a SavedModel from the directory specified in parsed_toco_flags. -// Returns a SavedModelBundle with the requested MetaGraphDef. -const tensorflow::SavedModelBundle* LoadSavedModel( - const ParsedTocoFlags& parsed_toco_flags) { - const string model_path = parsed_toco_flags.savedmodel_directory.value(); - QCHECK(tensorflow::MaybeSavedModelDirectory(model_path)) - << "Model is not saved in the supported SavedModel format.\n"; - - // Gets the tags identifying the MetaGraphDef from the command line arguments. - string tags_str; - if (parsed_toco_flags.savedmodel_tagset.specified()) { - tags_str = parsed_toco_flags.savedmodel_tagset.value(); - } else { - tags_str = parsed_toco_flags.savedmodel_tagset.default_value(); - } - auto tags = absl::StrSplit(tags_str, ','); - - // Loads MetaGraphDef. - auto* bundle = new tensorflow::SavedModelBundle; - TF_CHECK_OK(tensorflow::LoadSavedModel(tensorflow::SessionOptions(), - tensorflow::RunOptions(), model_path, - tags, bundle)) - << "Failed to load exported model from " << model_path - << ". Ensure the model contains the required tags '" << tags_str - << "'.\n"; - return bundle; -} - -// Returns the array name without the postfix. -// -// e.g. reduces "input:0" to "input". -string GetArrayName(const string& name) { - const std::vector& names = absl::StrSplit(name, ':'); - return names[0]; -} - -// Returns the list of array names without the postfix sorted alphabetically. -std::set GetSortedNames(const std::unordered_set& names) { - std::vector final_names; - final_names.reserve(names.size()); - for (const auto& name : names) { - final_names.push_back(GetArrayName(name)); - } - return std::set(final_names.begin(), final_names.end()); -} - -// Gets the final shape after replacing the first dimension with batch size, if -// it is undefined (containing the value -1). Returns whether the shape is -// valid. -bool ReplaceShapeBatchSize(const tensorflow::TensorShapeProto& shape, - int batch_size, - tensorflow::TensorShapeProto* final_shape) { - for (int idx = 0; idx < shape.dim().size(); ++idx) { - int64 final_dim = shape.dim()[idx].size(); - if (final_dim == -1) { - if (idx > 0) return false; - final_dim = batch_size; - } - final_shape->add_dim()->set_size(final_dim); - } - return true; -} - -// Updates the input arrays in ModelFlags to contain the shape of the array. -void ProcessInputShapes(const tensorflow::GraphDef& graph_def, int batch_size, - ModelFlags* model_flags) { - // Build map of input array names to input arrays. - std::unordered_map input_data_map; - for (auto& input : *model_flags->mutable_input_arrays()) { - input_data_map[input.name()] = &input; - } - - // Adds shapes to the input arrays if the shape is valid. - for (const tensorflow::NodeDef& node_def : graph_def.node()) { - if (input_data_map.find(node_def.name()) != input_data_map.end()) { - const auto shape_it = node_def.attr().find("shape"); - if (shape_it != node_def.attr().end()) { - tensorflow::TensorShapeProto final_shape; - bool is_valid = ReplaceShapeBatchSize(shape_it->second.shape(), - batch_size, &final_shape); - - if (is_valid) { - auto* shape = input_data_map.at(node_def.name())->mutable_shape(); - QCHECK_EQ(shape->dims_size(), 0) - << "The shape for the input '" << node_def.name() - << "' was previously defined. For clarity please define inputs " - << "via --input_arrays and input_shapes flags.\n"; - for (const auto& dim : final_shape.dim()) { - shape->add_dims(dim.size()); - } - } - } - } - } - - // Checks all input arrays have a shape. - for (auto const& input : model_flags->input_arrays()) { - QCHECK(input.shape().dims_size() > 0) - << "A valid input shape was not found for input '" << input.name() - << "'. Please define via --input_arrays and --input_shapes flags.\n"; - } -} - -} // namespace - -void ParseMetaData(const tensorflow::GraphDef& graph_def, - const std::unordered_set& inputs, - const std::unordered_set& outputs, - const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags, - TocoFlags* toco_flags, ModelFlags* model_flags) { - if (!parsed_model_flags.input_arrays.specified()) { - const std::set sorted_inputs = GetSortedNames(inputs); - for (const auto& input_name : sorted_inputs) { - model_flags->add_input_arrays()->set_name(input_name); - } - } - - if (!parsed_model_flags.output_arrays.specified()) { - const std::set sorted_outputs = GetSortedNames(outputs); - for (const auto& output_name : sorted_outputs) { - model_flags->add_output_arrays(GetArrayName(output_name)); - } - } - - if (!parsed_model_flags.input_shapes.specified()) { - int batch_size = parsed_model_flags.batch_size.value(); - ProcessInputShapes(graph_def, batch_size, model_flags); - } - - if (!parsed_toco_flags.inference_type.specified()) { - toco_flags->set_inference_type(IODataType::FLOAT); - } -} - -// TODO(nupurgarg): Add top level tests. -void GetSavedModelContents(const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags, - TocoFlags* toco_flags, ModelFlags* model_flags, - string* graph_def_contents) { - // Loads the MetaGraphDef within a SavedModelBundle. - auto bundle = LoadSavedModel(parsed_toco_flags); - - // Converts the MetaGraphDef to frozen GraphDef. - tensorflow::GraphDef frozen_graph_def; - std::unordered_set inputs; - std::unordered_set outputs; - TF_CHECK_OK(tensorflow::FreezeSavedModel(*bundle, &frozen_graph_def, &inputs, - &outputs)); - - // Reads the frozen GraphDef into a string. - QCHECK(frozen_graph_def.SerializeToString(graph_def_contents)) - << "Unable to generate serialized GraphDef.\n"; - - // Process inputs and outputs and metadata within GraphDef. - const tensorflow::GraphDef graph_def = bundle->meta_graph_def.graph_def(); - ParseMetaData(graph_def, inputs, outputs, parsed_toco_flags, - parsed_model_flags, toco_flags, model_flags); -} - -} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.h b/tensorflow/contrib/lite/toco/toco_saved_model.h deleted file mode 100644 index 7a0fabd82d90131a3b2d28c757c08dcb0f9e3988..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/toco/toco_saved_model.h +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_ -#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_ - -#include -#include - -#include "tensorflow/cc/tools/freeze_saved_model.h" -#include "tensorflow/contrib/lite/toco/args.h" -#include "tensorflow/contrib/lite/toco/model_flags.pb.h" -#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" -#include "tensorflow/contrib/lite/toco/types.pb.h" - -namespace toco { - -// Parses metadata into `toco_flags` and `model_flags`. -// -// Stores `inputs` as input_arrays and `outputs` as output_arrays in -// `model_flags`. Infers input_shapes from the GraphDef and stores it in -// `model_flags` as part of the input_arrays. Assumes inference_type is FLOAT -// and stores it in `toco_flags`. -void ParseMetaData(const tensorflow::GraphDef& graph_def, - const std::unordered_set& inputs, - const std::unordered_set& outputs, - const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags, - TocoFlags* toco_flags, ModelFlags* model_flags); - -// Generates a frozen graph from the SavedModel in the directory specified in -// `toco_flags`. Reads frozen graph contents into `graph_def_contents`. Parses -// metadata relating to the GraphDef into `toco_flags` and `model_flags`. -void GetSavedModelContents(const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags, - TocoFlags* toco_flags, ModelFlags* model_flags, - string* graph_def_contents); - -} // namespace toco - -#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_ diff --git a/tensorflow/contrib/lite/toco/toco_saved_model_test.cc b/tensorflow/contrib/lite/toco/toco_saved_model_test.cc deleted file mode 100644 index 5e122afe65dc29abc85f142f4019aae5058ace51..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/toco/toco_saved_model_test.cc +++ /dev/null @@ -1,274 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/contrib/lite/toco/toco_saved_model.h" -#include "absl/strings/str_join.h" -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" -#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" -#include "tensorflow/core/lib/core/status_test_util.h" - -#include -#include - -namespace toco { -namespace { - -using tensorflow::ops::Add; -using tensorflow::ops::Const; -using tensorflow::ops::FakeQuantWithMinMaxArgs; -using tensorflow::ops::Placeholder; - -class TocoSavedModelTest : public ::testing::Test { - protected: - // Calls functions to process cmdline arguments and calls ParseMetaData. - // ParseMetaData parses input_arrays, output_arrays, and gets metadata from - // SavedModel it is not defined in the cmdline arguments. - void ProcessGraphDefMetadata(const std::unordered_set& inputs, - const std::unordered_set& outputs, - const tensorflow::GraphDef& graph_def) { - ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags_, &toco_flags_); - ReadModelFlagsFromCommandLineFlags(parsed_model_flags_, &model_flags_); - ParseMetaData(graph_def, inputs, outputs, parsed_toco_flags_, - parsed_model_flags_, &toco_flags_, &model_flags_); - } - - // Gets the GraphDef from the SavedModelBundle and processes metadata. - void ProcessSavedModelMetadata(const std::unordered_set& inputs, - const std::unordered_set& outputs) { - const tensorflow::GraphDef graph_def = bundle_.meta_graph_def.graph_def(); - ProcessGraphDefMetadata(inputs, outputs, graph_def); - } - - // Returns a GraphDef representing a simple float model with a single input. - tensorflow::GraphDef GetFloatGraphDef(const std::vector& shape) { - tensorflow::GraphDef graph_def; - tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); - - tensorflow::Output input = - Placeholder(scope.WithOpName("input"), tensorflow::DT_FLOAT, - Placeholder::Shape(tensorflow::PartialTensorShape(shape))); - tensorflow::Output zero = Const(scope.WithOpName("zero"), 0.0f, {}); - tensorflow::Output add = Add(scope.WithOpName("add"), input, zero); - - TF_EXPECT_OK(scope.ToGraphDef(&graph_def)); - return graph_def; - } - - // Returns a GraphDef representing a simple float model with two inputs. - tensorflow::GraphDef GetComplexFloatGraphDef() { - tensorflow::GraphDef graph_def; - tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); - - tensorflow::Output inputA = - Placeholder(scope.WithOpName("inputA"), tensorflow::DT_FLOAT, - Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1}))); - tensorflow::Output inputB = - Placeholder(scope.WithOpName("inputB"), tensorflow::DT_FLOAT, - Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1}))); - tensorflow::Output add = Add(scope.WithOpName("add"), inputB, inputA); - - TF_EXPECT_OK(scope.ToGraphDef(&graph_def)); - return graph_def; - } - - // Returns a GraphDef representing a simple quantized model. - tensorflow::GraphDef GetQuantizedGraphDef() { - tensorflow::GraphDef graph_def; - tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); - - tensorflow::Output input = - Placeholder(scope.WithOpName("input"), tensorflow::DT_FLOAT, - Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1}))); - tensorflow::Output zero = Const(scope.WithOpName("zero"), 0.0f, {}); - tensorflow::Output fake_quant = - FakeQuantWithMinMaxArgs(scope.WithOpName("quant"), zero); - tensorflow::Output add = Add(scope.WithOpName("add"), input, fake_quant); - - TF_EXPECT_OK(scope.ToGraphDef(&graph_def)); - return graph_def; - } - - // Gets the values in the input_arrays flag. - std::vector GetInputArrays() { - std::vector actual; - for (const auto& input : model_flags_.input_arrays()) { - actual.push_back(input.name()); - } - return actual; - } - - // Gets the values in the output_arrays flag. - std::vector GetOutputArrays() { - std::vector actual(model_flags_.output_arrays().begin(), - model_flags_.output_arrays().end()); - return actual; - } - - // Gets the shape of the given input array. - string GetInputShape(const string& input_array) { - for (const auto& input : model_flags_.input_arrays()) { - if (input.name() == input_array) { - std::vector dims; - for (int idx = 0; idx < input.shape().dims_size(); ++idx) { - dims.push_back(std::to_string(input.shape().dims(idx))); - } - return absl::StrJoin(dims, ","); - } - } - return ""; - } - - tensorflow::SavedModelBundle bundle_; - ParsedTocoFlags parsed_toco_flags_; - ParsedModelFlags parsed_model_flags_; - TocoFlags toco_flags_; - ModelFlags model_flags_; -}; - -// Tests if input_arrays, output_arrays, inference_type, and output_arrays are -// added to ModelFlags if they are not specified in cmdline arguments. -// Tests if the default batch size replaces a -1 in the first dimension. -TEST_F(TocoSavedModelTest, NoCmdLine) { - tensorflow::GraphDef graph_def = GetFloatGraphDef({-1, 3, 3, 1}); - - ProcessGraphDefMetadata({"input"}, {"add"}, graph_def); - EXPECT_EQ(GetInputArrays(), std::vector({"input"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); - EXPECT_EQ(GetInputShape("input"), "1,3,3,1"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Tests if the order of input_arrays and output_arrays is deterministic when -// they are taken from the SavedModel. -TEST_F(TocoSavedModelTest, NoCmdLineMultipleArrays) { - tensorflow::GraphDef graph_def = GetComplexFloatGraphDef(); - - // Note: The model does not have two outputs. However, the function does not - // need an accurate output_array list. This is only meant to test order. - ProcessGraphDefMetadata({"inputB", "inputA"}, {"add", "invalid"}, graph_def); - EXPECT_EQ(GetInputArrays(), std::vector({"inputA", "inputB"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"add", "invalid"})); - EXPECT_EQ(GetInputShape("inputA"), "1,3,3,1"); - EXPECT_EQ(GetInputShape("inputB"), "1,3,3,1"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Tests if input_shapes is inferred when input_arrays is passed in via cmdline -// arguments. -TEST_F(TocoSavedModelTest, InputNameWithoutInputShape) { - parsed_model_flags_.input_arrays.bind()("input"); - tensorflow::GraphDef graph_def = GetFloatGraphDef({2, 3, 3, 1}); - - ProcessGraphDefMetadata({"not_used_input"}, {"add"}, graph_def); - EXPECT_EQ(GetInputArrays(), std::vector({"input"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); - EXPECT_EQ(GetInputShape("input"), "2,3,3,1"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Ensures a failure occurs when input_shapes is defined without input_arrays. -TEST_F(TocoSavedModelTest, InputShapeWithoutInputName) { - parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12"); - tensorflow::GraphDef graph_def = GetFloatGraphDef({1, 3, 3, 1}); - - EXPECT_DEATH(ProcessGraphDefMetadata({"input"}, {"add"}, graph_def), - "failed: input_shapes.size\\(\\) == " - "model_flags->input_arrays_size\\(\\)"); -} - -// Tests if the cmdline values of input_arrays, input_shapes are used when -// specified with an empty GraphDef. -TEST_F(TocoSavedModelTest, InputArraysCmdLine) { - parsed_model_flags_.input_arrays.bind()("inputA,inputB"); - parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12"); - - ProcessSavedModelMetadata({"input0", "input1"}, {"output0", "output1"}); - EXPECT_EQ(GetInputArrays(), std::vector({"inputA", "inputB"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"output0", "output1"})); - EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1"); - EXPECT_EQ(GetInputShape("inputB"), "9,12"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Tests if the cmdline values of input_arrays, input_shapes are used when -// specified even if values exist within the GraphDef. -TEST_F(TocoSavedModelTest, InputArraysCmdLineWithGraphDef) { - parsed_model_flags_.input_arrays.bind()("inputA"); - parsed_model_flags_.input_shapes.bind()("1,224,224,1"); - tensorflow::GraphDef graph_def = GetFloatGraphDef({1, 3, 3, 1}); - - ProcessGraphDefMetadata({"inputA"}, {"add"}, graph_def); - EXPECT_EQ(GetInputArrays(), std::vector({"inputA"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); - EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Tests if the cmdline values of input_arrays, input_shapes, inference_type, -// and output_arrays are used when specified with an empty GraphDef. -TEST_F(TocoSavedModelTest, AllParamsCmdLine) { - parsed_model_flags_.input_arrays.bind()("inputA,inputB"); - parsed_model_flags_.output_arrays.bind()("outputA,outputB"); - parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12"); - parsed_toco_flags_.inference_type.bind()("FLOAT"); - - ProcessSavedModelMetadata({"input0", "input1"}, {"output0", "output1"}); - EXPECT_EQ(GetInputArrays(), std::vector({"inputA", "inputB"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"outputA", "outputB"})); - EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1"); - EXPECT_EQ(GetInputShape("inputB"), "9,12"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Tests if a quantized graph gives the correct values assuming type is passed -// in via command line. -TEST_F(TocoSavedModelTest, QuantizedNoCmdLine) { - parsed_toco_flags_.inference_type.bind()("QUANTIZED_UINT8"); - tensorflow::GraphDef graph_def = GetQuantizedGraphDef(); - - ProcessGraphDefMetadata({"input"}, {"add"}, graph_def); - EXPECT_EQ(GetInputArrays(), std::vector({"input"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); - EXPECT_EQ(GetInputShape("input"), "1,3,3,1"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::QUANTIZED_UINT8); -} - -// Tests if the provided batch size replaces a -1 in the first dimension of -// input shape. -TEST_F(TocoSavedModelTest, MissingShapeParameterValid) { - parsed_model_flags_.batch_size.bind()(3); - tensorflow::GraphDef graph_def = GetFloatGraphDef({-1, 3, 3, 1}); - - ProcessGraphDefMetadata({"input"}, {"add"}, graph_def); - EXPECT_EQ(GetInputArrays(), std::vector({"input"})); - EXPECT_EQ(GetOutputArrays(), std::vector({"add"})); - EXPECT_EQ(GetInputShape("input"), "3,3,3,1"); - EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT); -} - -// Ensures a failure occurs if there is a -1 in a dimension aside from the first -// position of input shape. -TEST_F(TocoSavedModelTest, MissingShapeParameterInvalid) { - parsed_model_flags_.batch_size.bind()(3); - tensorflow::GraphDef graph_def = GetFloatGraphDef({1, -1, 3, 1}); - - EXPECT_DEATH(ProcessGraphDefMetadata({"input"}, {"add"}, graph_def), - "A valid input shape was not found for input 'input'."); -} - -} // namespace -} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 1fe76f8163cdf23b27f8baaf2d9c6d99b1aa3747..34130a02b03d0104df6f2a16ebccc50202f34f46 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -34,11 +34,11 @@ limitations under the License. namespace toco { namespace { -// CHECK-fails if the model contains a kTensorFlowUnsupported operation. +// CHECK-fails if the model contains a kUnsupported operation. void CheckUnsupportedOperations(const Model& model) { std::set unsupported_ops; for (auto& op : model.operators) { - if (op->type == OperatorType::kTensorFlowUnsupported) { + if (op->type == OperatorType::kUnsupported) { unsupported_ops.insert( static_cast(op.get()) ->tensorflow_op); @@ -55,7 +55,8 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ConvertExpandDimsToReshape); transformations->Add(new ConvertSqueezeToReshape); transformations->Add(new ConvertTrivialAddNToAdd); - transformations->Add(new ConvertTrivialStackToReshape); + transformations->Add(new ConvertTrivialPackToReshape); + transformations->Add(new ConvertTrivialTileToConcat); transformations->Add(new ConvertTrivialTransposeToReshape); transformations->Add(new ConvertReorderAxes); transformations->Add(new ResolveReshapeAttributes); @@ -76,25 +77,28 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveTensorFlowMatMul); transformations->Add(new FuseBinaryIntoPrecedingAffine); transformations->Add(new FuseBinaryIntoFollowingAffine); + transformations->Add(new FuseBroadcastIntoFollowingBinary); transformations->Add(new MergeReshapeIntoPrecedingTranspose); + transformations->Add(new MoveBinaryOperatorBeforeReshape); transformations->Add(new ReorderElementwiseUnary); transformations->Add(new ReorderReshapeTranspose); transformations->Add(new ResolveBatchNormalization); transformations->Add(new ResolveConstantBinaryOperator); transformations->Add(new ResolveConstantFill); transformations->Add(new ResolveConstantGather); + transformations->Add(new ResolveConstantPack); transformations->Add(new ResolveConstantRandomUniform); transformations->Add(new ResolveConstantRange); transformations->Add(new ResolveConstantReshape); + transformations->Add(new ResolveConstantSelect); transformations->Add(new ResolveConstantSlice); - transformations->Add(new ResolveConstantStack); transformations->Add(new ResolveConstantStridedSlice); + transformations->Add(new ResolveConstantTile); transformations->Add(new ResolveConstantTranspose); transformations->Add(new ResolveConstantUnaryOperator); transformations->Add(new ResolveTensorFlowMerge); transformations->Add(new ResolveSqueezeAttributes); transformations->Add(new ResolveTensorFlowSwitch); - transformations->Add(new ResolveTensorFlowTile); transformations->Add(new ResolveTensorFlowConcat); transformations->Add(new ResolveMultiplyByZero); transformations->Add(new IdentifyDilatedConv); @@ -103,17 +107,19 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new IdentifyRelu1); transformations->Add(new IdentifyPRelu); transformations->Add(new RemoveTrivialBinaryOperator); - transformations->Add(new ReadFakeQuantMinMax); + transformations->Add(new ResolveFakeQuantArgsFromVars); + transformations->Add(new ReadArrayMinmaxAndNarrowRangeFromFakeQuant); transformations->Add(new ResolveSpaceToBatchNDAttributes); transformations->Add(new ResolveBatchToSpaceNDAttributes); transformations->Add(new ResolvePadAttributes); transformations->Add(new ResolvePadV2Attributes); transformations->Add(new ResolveStridedSliceAttributes); transformations->Add(new ResolveSliceAttributes); - transformations->Add(new ResolveMeanAttributes); + transformations->Add(new ResolveReduceAttributes); transformations->Add(new ResolveConstantShapeOrRank); transformations->Add(new MakeInitialDequantizeOperator); transformations->Add(new UnpartitionEmbeddingLookup); + transformations->Add(new ResolveGatherAttributes); } bool SupportsQuantization(FileFormat format) { @@ -133,6 +139,8 @@ bool SupportsPreallocatedWorkspace(FileFormat format) { return (format == TFLITE); } +bool SupportsShuffledFCWeights(FileFormat format) { return format == TFLITE; } + bool IsRealValued(toco::ArrayDataType type) { // TODO(benoitjacob) - this is hardcoding that uint8 and int16 are only used // for quantized real-number values, and no other integer type is ever used @@ -269,13 +277,16 @@ void Transform(const TocoFlags& toco_flags, Model* model) { transformations.Add(new toco::MergeLstmCellInputs); } } - if (toco_flags.quantize_weights()) { - transformations.Add(new QuantizeWeights); - } transformations.Add(new ResolveConstantConcatenation); RunGraphTransformations(model, "general graph transformations", transformations); + if (toco_flags.quantize_weights()) { + // Run the quantize weights transformation after batchnorms have been + // folded into the weights. + RunGraphTransformations(model, "quantize weights transformation", + {new QuantizeWeights}); + } if (quantize_output) { if (toco_flags.propagate_fake_quant_num_bits()) { RunGraphTransformations(model, @@ -300,8 +311,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) { // HardcodeMinMax to move changes through the graph as we make changes. auto propagate_default_min_max = absl::make_unique(); - if (toco_flags.has_default_ranges_min() && - toco_flags.has_default_ranges_max()) { + bool has_default_ranges_flag = (toco_flags.has_default_ranges_min() && + toco_flags.has_default_ranges_max()); + if (has_default_ranges_flag) { propagate_default_min_max->DefineTypeRange( ArrayDataType::kUint8, toco_flags.default_ranges_min(), toco_flags.default_ranges_max()); @@ -326,6 +338,8 @@ void Transform(const TocoFlags& toco_flags, Model* model) { new EnsureUint8WeightsSafeForFastInt8Kernels; ensure_safe_for_int8_kernels->set_allow_nudging_weights( toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel()); + ensure_safe_for_int8_kernels->set_has_default_ranges_flag( + has_default_ranges_flag); RunGraphTransformations(model, "quantization graph transformations", { new RemoveTrivialQuantizedActivationFunc, @@ -334,6 +348,10 @@ void Transform(const TocoFlags& toco_flags, Model* model) { new RemoveFinalDequantizeOp, ensure_safe_for_int8_kernels, }); + if (SupportsShuffledFCWeights(output_format)) { + RunGraphTransformations(model, "shuffling of FC weights", + {new ShuffleFCWeights}); + } } else { GraphTransformationsSet dequantization_transformations{new Dequantize}; // Dequantize creates FakeQuant nodes. We may want to discard diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 92bab5246cb85052b5e0216f1cb8a04736ae7a79..2ad27198119b4a8150a7381c047a4edb51aebfe6 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -338,46 +338,47 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Div) HANDLE_OPERATORTYPENAME_CASE(Tanh) HANDLE_OPERATORTYPENAME_CASE(Sin) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowAll) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowAssert) + HANDLE_OPERATORTYPENAME_CASE(All) + HANDLE_OPERATORTYPENAME_CASE(Assert) HANDLE_OPERATORTYPENAME_CASE(ExpandDims) HANDLE_OPERATORTYPENAME_CASE(Fill) HANDLE_OPERATORTYPENAME_CASE(FloorMod) HANDLE_OPERATORTYPENAME_CASE(FloorDiv) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreater) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreaterEqual) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowIdentity) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowLess) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowLessEqual) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMatMul) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMax) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMaximum) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMerge) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMin) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMinimum) + HANDLE_OPERATORTYPENAME_CASE(Greater) + HANDLE_OPERATORTYPENAME_CASE(GreaterEqual) + HANDLE_OPERATORTYPENAME_CASE(Identity) + HANDLE_OPERATORTYPENAME_CASE(Less) + HANDLE_OPERATORTYPENAME_CASE(LessEqual) + HANDLE_OPERATORTYPENAME_CASE(MatMul) + HANDLE_OPERATORTYPENAME_CASE(ReduceMax) // Reduction Max + HANDLE_OPERATORTYPENAME_CASE(Maximum) // Element-wise Maximum + HANDLE_OPERATORTYPENAME_CASE(Merge) + HANDLE_OPERATORTYPENAME_CASE(ReduceMin) // Reduction Min + HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum HANDLE_OPERATORTYPENAME_CASE(Neg) + HANDLE_OPERATORTYPENAME_CASE(OneHot) + HANDLE_OPERATORTYPENAME_CASE(Pack) HANDLE_OPERATORTYPENAME_CASE(Pad) HANDLE_OPERATORTYPENAME_CASE(PadV2) HANDLE_OPERATORTYPENAME_CASE(StridedSlice) - HANDLE_OPERATORTYPENAME_CASE(Stack) HANDLE_OPERATORTYPENAME_CASE(Range) HANDLE_OPERATORTYPENAME_CASE(Rank) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowReshape) + HANDLE_OPERATORTYPENAME_CASE(Reshape) HANDLE_OPERATORTYPENAME_CASE(Squeeze) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowRsqrt) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowShape) + HANDLE_OPERATORTYPENAME_CASE(Rsqrt) + HANDLE_OPERATORTYPENAME_CASE(Shape) HANDLE_OPERATORTYPENAME_CASE(Slice) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSplit) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSqrt) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSquare) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSwitch) + HANDLE_OPERATORTYPENAME_CASE(Split) + HANDLE_OPERATORTYPENAME_CASE(Sqrt) + HANDLE_OPERATORTYPENAME_CASE(Square) + HANDLE_OPERATORTYPENAME_CASE(Switch) HANDLE_OPERATORTYPENAME_CASE(Sub) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSum) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowTile) + HANDLE_OPERATORTYPENAME_CASE(Sum) + HANDLE_OPERATORTYPENAME_CASE(Tile) HANDLE_OPERATORTYPENAME_CASE(Transpose) HANDLE_OPERATORTYPENAME_CASE(TransposeConv) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcat) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcatV2) + HANDLE_OPERATORTYPENAME_CASE(Concat) + HANDLE_OPERATORTYPENAME_CASE(ConcatV2) HANDLE_OPERATORTYPENAME_CASE(Cast) HANDLE_OPERATORTYPENAME_CASE(Floor) HANDLE_OPERATORTYPENAME_CASE(Gather) @@ -385,17 +386,25 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND) HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND) HANDLE_OPERATORTYPENAME_CASE(Mean) + HANDLE_OPERATORTYPENAME_CASE(ReduceProd) HANDLE_OPERATORTYPENAME_CASE(Svdf) HANDLE_OPERATORTYPENAME_CASE(ArgMax) + HANDLE_OPERATORTYPENAME_CASE(ArgMin) HANDLE_OPERATORTYPENAME_CASE(TopK_V2) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowUnsupported) + HANDLE_OPERATORTYPENAME_CASE(Unsupported) HANDLE_OPERATORTYPENAME_CASE(Exp) HANDLE_OPERATORTYPENAME_CASE(DynamicPartition) HANDLE_OPERATORTYPENAME_CASE(DynamicStitch) HANDLE_OPERATORTYPENAME_CASE(Select) HANDLE_OPERATORTYPENAME_CASE(SparseToDense) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowEqual) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowNotEqual) + HANDLE_OPERATORTYPENAME_CASE(Equal) + HANDLE_OPERATORTYPENAME_CASE(NotEqual) + HANDLE_OPERATORTYPENAME_CASE(Pow) + HANDLE_OPERATORTYPENAME_CASE(Any) + HANDLE_OPERATORTYPENAME_CASE(LogicalAnd) + HANDLE_OPERATORTYPENAME_CASE(LogicalNot) + HANDLE_OPERATORTYPENAME_CASE(LogicalOr) + HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE @@ -403,7 +412,7 @@ const char* OperatorTypeName(OperatorType type) { } string HelpfulOperatorTypeName(const Operator& op) { - if (op.type == OperatorType::kTensorFlowUnsupported) { + if (op.type == OperatorType::kUnsupported) { return toco::port::StringF( "(Unsupported TensorFlow op: %s)", static_cast(op).tensorflow_op); @@ -413,16 +422,20 @@ string HelpfulOperatorTypeName(const Operator& op) { bool OperatorSupportsFusedActivation(OperatorType type) { switch (type) { - case OperatorType::kConcatenation: - case OperatorType::kFakeQuant: - case OperatorType::kGather: - case OperatorType::kSlice: - case OperatorType::kSqueeze: - case OperatorType::kTensorFlowReshape: - case OperatorType::kTensorFlowSplit: - return false; - default: + case OperatorType::kAdd: + case OperatorType::kAveragePool: + case OperatorType::kBatchNormalization: + case OperatorType::kConv: + case OperatorType::kDepthwiseConv: + case OperatorType::kDiv: + case OperatorType::kFullyConnected: + case OperatorType::kL2Pool: + case OperatorType::kMaxPool: + case OperatorType::kMul: + case OperatorType::kSub: return true; + default: + return false; } } @@ -442,8 +455,12 @@ void LogSummary(int log_level, const Model& model) { } void LogArray(int log_level, const Model& model, const string& name) { - const auto& array = model.GetArray(name); VLOG(log_level) << "Array: " << name; + if (!model.HasArray(name)) { + VLOG(log_level) << " DOES NOT EXIST"; + return; + } + const auto& array = model.GetArray(name); VLOG(log_level) << " Data type: " << ArrayDataTypeName(array.data_type); VLOG(log_level) << " Final type: " << ArrayDataTypeName(array.final_data_type); @@ -585,14 +602,33 @@ void UnextendShape(Shape* shape, int new_shape_size) { shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction); } -bool IsValid(const Shape& shape) { +// In general, zero-sized dimensions are disallowed, but there are exceptions, +// e.g., if the tensor data itself represents a scalar (rank 0) shape, its +// shape will have dimensions [0]. CheckNonEmptyShapeDimensions is more +// strict, and is appropriate for ops and comparisons where an empty shape +// doesn't make sense. +template +void CheckValidShapeDimensions(const Dims& dims) { + if (dims.size() == 1 && dims[0] == 0) { + return; + } + for (const auto& dim : dims) { + CHECK_GE(dim, 1); + } +} + +void CheckValidShape(const Shape& shape) { + CheckValidShapeDimensions(shape.dims()); +} + +bool IsNonEmpty(const Shape& shape) { for (int i = 0; i < shape.dimensions_count(); ++i) { if (shape.dims(i) < 1) return false; } return true; } -void CheckShapeDimensions(const Shape& shape) { +void CheckNonEmptyShapeDimensions(const Shape& shape) { for (int i = 0; i < shape.dimensions_count(); ++i) { CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i << ". shape = " << ShapeToString(shape); @@ -600,8 +636,8 @@ void CheckShapeDimensions(const Shape& shape) { } bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) { - CheckShapeDimensions(shape0); - CheckShapeDimensions(shape1); + CheckNonEmptyShapeDimensions(shape0); + CheckNonEmptyShapeDimensions(shape1); const Shape* longer = &shape0; const Shape* shorter = &shape1; @@ -628,8 +664,8 @@ bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) { } bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) { - CheckShapeDimensions(shape0); - CheckShapeDimensions(shape1); + CheckNonEmptyShapeDimensions(shape0); + CheckNonEmptyShapeDimensions(shape1); const Shape* longer = &shape0; const Shape* shorter = &shape1; @@ -666,9 +702,9 @@ bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) { } int RequiredBufferSizeForShape(const Shape& shape) { + CheckValidShape(shape); int max_offset = 1; for (const auto& dim : shape.dims()) { - CHECK_GE(dim, 1); max_offset *= dim; } return max_offset; @@ -929,9 +965,7 @@ void CheckEachArray(const Model& model) { // shape. CHECK(array->has_shape()); // Constant buffer should has a valid shape. - for (int d : array->shape().dims()) { - CHECK_GE(d, 1); - } + CheckValidShape(array->shape()); // The shape flat-size should agree with the buffer length. CHECK_EQ(array->buffer->Length(), RequiredBufferSizeForShape(array->shape())); @@ -1256,8 +1290,13 @@ void InsertCopyOperator(Model* model, const string& source_array_name, auto* copy_op = new TensorFlowReshapeOperator; copy_op->inputs = { source_array_name, - CreateInt32Array(model, target_array_name + "_copy_shape", shape)}; + CreateInt32Array( + model, AvailableArrayName(*model, target_array_name + "_copy_shape"), + shape)}; copy_op->outputs = {target_array_name}; + if (target_array.has_shape()) { + copy_op->shape = target_array.shape().dims(); + } model->operators.emplace_back(copy_op); } @@ -1518,8 +1557,8 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { if (!input_array.has_shape()) { if (input_array_proto.has_shape()) { auto& input_array_dims = *input_array.mutable_shape()->mutable_dims(); + CheckValidShapeDimensions(input_array_proto.shape().dims()); for (auto dim : input_array_proto.shape().dims()) { - CHECK_GE(dim, 1); input_array_dims.push_back(dim); } } @@ -1562,11 +1601,6 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { model); } - for (const auto& input_array : model->flags.input_arrays()) { - if (input_array.has_shape()) { - CHECK(input_array.shape().dims_size()); - } - } model->flags.set_change_concat_input_ranges( model_flags.change_concat_input_ranges()); model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays()); @@ -1599,11 +1633,12 @@ void CheckIsReadyForQuantization(const Model& model) { << "Array " << input << ", which is an input to the " << HelpfulOperatorTypeName(*op) << " operator producing the output " << "array " << op->outputs[0] << ", is lacking min/max data, " - << "which is necessary for quantization. Either target a " - << "non-quantized output format, or change the input graph to " - << "contain min/max information, or pass --default_ranges_min= and " - << "--default_ranges_max= if you do not care about the accuracy of " - << "results."; + << "which is necessary for quantization. If accuracy matters, either " + << "target a non-quantized output format, or run quantized training " + << "with your model from a floating point checkpoint to change the " + << "input graph to contain min/max information. If you don't care " + << "about accuracy, you can pass --default_ranges_min= and " + << "--default_ranges_max= for easy experimentation."; } } } @@ -2196,4 +2231,51 @@ void UseArraysExtraInfo(Model* model, bool quantize_output) { } } +void UndoWeightsShuffling(Model* model) { + for (const auto& op : model->operators) { + if (op->type != toco::OperatorType::kFullyConnected) { + continue; + } + const auto& fc_op = static_cast(*op); + if (fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) { + continue; + } + const string& weights_name = fc_op.inputs[1]; + QCHECK_EQ(CountOpsWithInput(*model, weights_name), 1); + auto& weights_array = model->GetArray(weights_name); + QCHECK(weights_array.data_type == ArrayDataType::kUint8); + auto& weights_data = + weights_array.GetMutableBuffer().data; + const auto& weights_shape = weights_array.shape(); + QCHECK_EQ(weights_shape.dimensions_count(), 2); + const int rows = weights_shape.dims(0); + const int cols = weights_shape.dims(1); + QCHECK_EQ(rows % 4, 0); + QCHECK_EQ(cols % 16, 0); + CHECK_EQ(rows * cols, weights_data.size()); + // Compute the de-shuffled weights + std::vector deshuffled_data(weights_data.size()); + uint8* shuffled_data_ptr = weights_data.data(); + for (int r = 0; r < rows; r += 4) { + for (int c = 0; c < cols; c += 16) { + for (int i = 0; i < 4; i++) { + uint8* deshuffled_data_ptr = + deshuffled_data.data() + (r + i) * cols + c; + for (int j = 0; j < 16; j++) { + uint8 shuffled_val = *shuffled_data_ptr++; + // Deshuffling isn't only about deshuffling the storage layout, + // it's also about undoing the flipping of the sign bit, which is + // performed on the shuffled weights. + uint8 deshuffled_val = shuffled_val ^ 0x80; + *deshuffled_data_ptr++ = deshuffled_val; + } + } + } + } + CHECK_EQ(shuffled_data_ptr, weights_data.data() + rows * cols); + // Switch this FC op to using the deshuffled weights. + weights_data = std::move(deshuffled_data); + } +} + } // namespace toco diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index 7681ce9d39ec56f9447896682b52bd4efb1d0e54..b99e6111fe92be178b5ff8b83477f1ce10c20926 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -101,6 +101,8 @@ std::vector>::iterator FindOp(Model& model, const char* OperatorTypeName(OperatorType type); string HelpfulOperatorTypeName(const Operator& op); +// Whether the operator can be fused with an activation function. Note that this +// will return false by default for new operators; fusing support is opt-in. bool OperatorSupportsFusedActivation(OperatorType type); void DumpGraphvizVideoFrame(const Model& model); @@ -113,10 +115,9 @@ void ExtendShape(Shape* shape, int new_shape_size); // TODO(b/36075966): Clean up when dims superseded by array shape. void UnextendShape(Shape* shape, int new_shape_size); -// Checks that all dimensions of 'shape' are at least 1. -bool IsValid(const Shape& shape); -// Same as above, but reports error using CHECK. -void CheckShapeDimensions(const Shape& shape); +// Checks that all dimensions of 'shape' are at least 1. Note that scalars, +// lacking dimensions, satisfy this condition and are considered non-empty. +bool IsNonEmpty(const Shape& shape); // Given two shapes with potentially different dimensionality and dimension // arrays d0 and d1. Without loss of generality, assume that shape0 may have @@ -342,6 +343,11 @@ tensorflow::Status NumElements(const std::vector& shape, U* num_elements) { return tensorflow::Status::OK(); } +// A model file may have shuffled FC weights. +// When that happens, we want to de-shuffle them immediately on import, +// so that the rest of toco doesn't need to know about shuffled weights. +void UndoWeightsShuffling(Model* model); + } // namespace toco #endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/contrib/lite/toco/tooling_util_test.cc index a683867374c8b8dcb274478adf6b5fa0691d1c5a..eb495646a2df0d0295eab54fcc5a5bf156a59d39 100644 --- a/tensorflow/contrib/lite/toco/tooling_util_test.cc +++ b/tensorflow/contrib/lite/toco/tooling_util_test.cc @@ -39,6 +39,8 @@ std::vector CreateShapePairs() { {Shape({256, 256, 3}), Shape({256, 256, 3}), Agreement::kBroadcast}, {Shape({256, 256, 3}), Shape({3}), Agreement::kBroadcast}, {Shape({8, 1, 6, 1}), Shape({7, 1, 5}), Agreement::kBroadcast}, + {Shape({}), Shape({3}), Agreement::kBroadcast}, + {Shape({}), Shape({3, 1}), Agreement::kBroadcast}, // These extend (and therefore broadcast). {Shape({3}), Shape({3}), Agreement::kExtend}, @@ -54,6 +56,7 @@ std::vector CreateShapePairs() { {Shape({15, 3, 5}), Shape({15, 1, 5}), Agreement::kBroadcastNotExtend}, {Shape({15, 3, 5}), Shape({3, 5}), Agreement::kBroadcastNotExtend}, {Shape({15, 3, 5}), Shape({3, 1}), Agreement::kBroadcastNotExtend}, + {Shape({3, 1}), Shape({}), Agreement::kBroadcastNotExtend}, // These do not broadcast (and therefore also do not extend). {Shape({3}), Shape({4}), Agreement::kNeither}, @@ -175,4 +178,24 @@ TEST(NumElementsTest, UnsignedInt64) { EXPECT_EQ(status.error_message(), kLargeTensorMessage); } +TEST(NumElementsTest, Scalar) { + tensorflow::Status status = tensorflow::Status::OK(); + + int32_t count; + status = NumElements(std::vector{}, &count); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(count, 1); + + uint64_t countu64; + status = NumElements(std::vector{}, &countu64); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(countu64, 1ULL); +} + +TEST(FusedActivationTest, DefaultsToUnfused) { + EXPECT_TRUE(OperatorSupportsFusedActivation(OperatorType::kAdd)); + EXPECT_FALSE(OperatorSupportsFusedActivation(OperatorType::kNone)); + EXPECT_FALSE(OperatorSupportsFusedActivation(static_cast(255))); +} + } // namespace toco diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD index 5913847329eeae7373d0d21834dd37327e4068c4..0b268264031f4f1e86b2956a75bde173a945ddf4 100644 --- a/tensorflow/contrib/lite/tools/BUILD +++ b/tensorflow/contrib/lite/tools/BUILD @@ -14,6 +14,7 @@ py_binary( srcs = ["visualize.py"], data = [ "//tensorflow/contrib/lite/schema:schema.fbs", + "//tensorflow/python:platform", "@flatbuffers//:flatc", ], srcs_version = "PY2AND3", @@ -52,7 +53,9 @@ cc_test( "//tensorflow/contrib/lite:testdata/test_model_broken.bin", ], tags = [ + "no_oss", "tflite_not_portable_android", + "tflite_not_portable_ios", ], deps = [ ":gen_op_registration", @@ -77,6 +80,7 @@ cc_test( size = "small", srcs = ["verifier_test.cc"], tags = [ + "no_oss", "tflite_not_portable", ], deps = [ diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD index 8857062c000201e1077469fc36e3bf2760924a30..2cb07eb6ec9405a5fefec9cc49f3b1aaff663e4b 100644 --- a/tensorflow/contrib/lite/tools/benchmark/BUILD +++ b/tensorflow/contrib/lite/tools/benchmark/BUILD @@ -10,11 +10,16 @@ load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") common_copts = ["-Wall"] + tflite_copts() +cc_library( + name = "logging", + hdrs = ["logging.h"], + copts = common_copts, +) + cc_binary( name = "benchmark_model", srcs = [ "benchmark_main.cc", - "logging.h", ], copts = common_copts, linkopts = tflite_linkopts() + select({ @@ -26,6 +31,26 @@ cc_binary( }), deps = [ ":benchmark_tflite_model_lib", + ":logging", + ], +) + +cc_test( + name = "benchmark_test", + srcs = ["benchmark_test.cc"], + args = [ + "--graph=$(location //tensorflow/contrib/lite:testdata/multi_add.bin)", + ], + data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], + tags = [ + "tflite_not_portable_android", + "tflite_not_portable_ios", + ], + deps = [ + ":benchmark_tflite_model_lib", + ":command_line_flags", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", ], ) @@ -58,6 +83,7 @@ cc_library( copts = common_copts, deps = [ ":benchmark_model_lib", + ":logging", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/kernels:builtin_ops", @@ -66,16 +92,27 @@ cc_library( ], ) +cc_library( + name = "benchmark_params", + srcs = [ + "benchmark_params.cc", + ], + hdrs = ["benchmark_params.h"], + copts = common_copts, + deps = [":logging"], +) + cc_library( name = "benchmark_model_lib", srcs = [ "benchmark_model.cc", - "logging.h", ], hdrs = ["benchmark_model.h"], copts = common_copts, deps = [ + ":benchmark_params", ":command_line_flags", + ":logging", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/kernels:builtin_ops", diff --git a/tensorflow/contrib/lite/tools/benchmark/README.md b/tensorflow/contrib/lite/tools/benchmark/README.md index c10826afff6d5569545d4b7df73c88d24d9dcd1a..f1e257ad104885a23cd7f17b9c21317c0881ccc0 100644 --- a/tensorflow/contrib/lite/tools/benchmark/README.md +++ b/tensorflow/contrib/lite/tools/benchmark/README.md @@ -3,7 +3,38 @@ ## Description A simple C++ binary to benchmark a TFLite model and its individual operators, -both on desktop machines and on Android. +both on desktop machines and on Android. The binary takes a TFLite model, +generates random inputs and then repeatedly runs the model for specified number +of runs. Aggregrate latency statistics are reported after running the benchmark. + +The instructions below are for running the binary on Desktop and Android, +for iOS please use the +[iOS benchmark app] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios). + +## Parameters + +The binary takes the following required parameters: + +* `graph`: `string` \ + The path to the TFLite model file. +* `input_layer`: `string` \ + The name of the input layer, this is typically the first layer of the model. +* `input_layer_shape`: `string` \ + The shape of the input layer. This is a comma separated string of the shape + of tensor of input layer. + +and the following optional parameters: + +* `num_threads`: `int` (default=1) \ + The number of threads to use for running TFLite interpreter. +* `warmup_runs`: `int` (default=1) \ + The number of warmup runs to do before starting the benchmark. +* `run_delay`: `float` (default=-1.0) \ + The delay in seconds between subsequent benchmark runs. Non-positive values + mean use no delay. +* `use_nnapi`: `bool` (default=false) \ + Whether to use [Android NNAPI] (https://developer.android.com/ndk/guides/neuralnetworks/). + This API is available on recent Android devices. ## To build/install/run @@ -44,7 +75,7 @@ adb push mobilenet_quant_v1_224.tflite /data/local/tmp ``` adb shell /data/local/tmp/benchmark_model \ --graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \ - --input_layer="Placeholder" \ + --input_layer="input" \ --input_layer_shape="1,224,224,3" \ --num_threads=4 ``` @@ -70,6 +101,30 @@ bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model \ The MobileNet graph used as an example here may be downloaded from https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip + +## Reducing variance between runs on Android. + +Most modern Android phones use [ARM big.LITTLE](https://en.wikipedia.org/wiki/ARM_big.LITTLE) +architecture where some cores are more power hungry but faster than other cores. +When running benchmarks on these phones there can be significant variance +between different runs of the benchmark. One way to reduce variance between runs +is to set the [CPU affinity](https://en.wikipedia.org/wiki/Processor_affinity) +before running the benchmark. On Android this can be done using the `taskset` +command. +E.g. for running the benchmark on big cores on Pixel 2 with a single thread one +can use the following command: + +``` +adb shell taskset f0 /data/local/tmp/benchmark_model \ + --graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \ + --input_layer="input" \ + --input_layer_shape="1,224,224,3" \ + --num_threads=1 +``` + +where `f0` is the affinity mask for big cores on Pixel 2. +Note: The affinity mask varies with the device. + ## Profiling model operators The benchmark model binary also allows you to profile operators and give execution times of each operator. To do this, compile the binary with a compiler flag that enables profiling to be compiled in. Pass **--copt=-DTFLITE_PROFILING_ENABLED** diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc index a8a9a6112c1ec050be8d0bcfe9dc5f00df40d3ff..f86c0445b0525cd053c733b18bb7f1205d310d43 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc @@ -48,6 +48,19 @@ namespace tflite { namespace benchmark { using tensorflow::Stat; +BenchmarkParams BenchmarkModel::DefaultParams() { + BenchmarkParams params; + params.AddParam("num_runs", BenchmarkParam::Create(50)); + params.AddParam("run_delay", BenchmarkParam::Create(-1.0f)); + params.AddParam("num_threads", BenchmarkParam::Create(1)); + params.AddParam("benchmark_name", BenchmarkParam::Create("")); + params.AddParam("output_prefix", BenchmarkParam::Create("")); + params.AddParam("warmup_runs", BenchmarkParam::Create(1)); + return params; +} + +BenchmarkModel::BenchmarkModel() : params_(DefaultParams()) {} + void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) { auto inference_us = results.inference_time_us(); auto init_us = results.startup_latency_us(); @@ -60,30 +73,38 @@ void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) { std::vector BenchmarkModel::GetFlags() { return { - Flag("num_runs", ¶ms_.num_runs, "number of runs"), - Flag("run_delay", ¶ms_.run_delay, "delay between runs in seconds"), - Flag("num_threads", ¶ms_.num_threads, "number of threads"), - Flag("benchmark_name", ¶ms_.benchmark_name, "benchmark name"), - Flag("output_prefix", ¶ms_.output_prefix, "benchmark output prefix"), - Flag("warmup_runs", ¶ms_.warmup_runs, - "how many runs to initialize model"), + CreateFlag("num_runs", ¶ms_, "number of runs"), + CreateFlag("run_delay", ¶ms_, "delay between runs in seconds"), + CreateFlag("num_threads", ¶ms_, "number of threads"), + CreateFlag("benchmark_name", ¶ms_, "benchmark name"), + CreateFlag("output_prefix", ¶ms_, + "benchmark output prefix"), + CreateFlag("warmup_runs", ¶ms_, + "how many runs to initialize model"), }; } -void BenchmarkModel::LogFlags() { - TFLITE_LOG(INFO) << "Num runs: [" << params_.num_runs << "]"; - TFLITE_LOG(INFO) << "Inter-run delay (seconds): [" << params_.run_delay +void BenchmarkModel::LogParams() { + TFLITE_LOG(INFO) << "Num runs: [" << params_.Get("num_runs") << "]"; + TFLITE_LOG(INFO) << "Inter-run delay (seconds): [" + << params_.Get("run_delay") << "]"; + TFLITE_LOG(INFO) << "Num threads: [" << params_.Get("num_threads") + << "]"; + TFLITE_LOG(INFO) << "Benchmark name: [" + << params_.Get("benchmark_name") << "]"; + TFLITE_LOG(INFO) << "Output prefix: [" + << params_.Get("output_prefix") << "]"; + TFLITE_LOG(INFO) << "Warmup runs: [" << params_.Get("warmup_runs") << "]"; - TFLITE_LOG(INFO) << "Num threads: [" << params_.num_threads << "]"; - TFLITE_LOG(INFO) << "Benchmark name: [" << params_.benchmark_name << "]"; - TFLITE_LOG(INFO) << "Output prefix: [" << params_.output_prefix << "]"; - TFLITE_LOG(INFO) << "Warmup runs: [" << params_.warmup_runs << "]"; } +void BenchmarkModel::PrepareInputsAndOutputs() {} + Stat BenchmarkModel::Run(int num_times, RunType run_type) { Stat run_stats; TFLITE_LOG(INFO) << "Running benchmark for " << num_times << " iterations "; for (int run = 0; run < num_times; run++) { + PrepareInputsAndOutputs(); listeners_.OnSingleRunStart(run_type); int64_t start_us = profiling::time::NowMicros(); RunImpl(); @@ -91,7 +112,7 @@ Stat BenchmarkModel::Run(int num_times, RunType run_type) { listeners_.OnSingleRunEnd(); run_stats.UpdateStat(end_us - start_us); - SleepForSeconds(params_.run_delay); + SleepForSeconds(params_.Get("run_delay")); } std::stringstream stream; @@ -101,12 +122,18 @@ Stat BenchmarkModel::Run(int num_times, RunType run_type) { return run_stats; } +bool BenchmarkModel::ValidateParams() { return true; } + void BenchmarkModel::Run(int argc, char **argv) { if (!ParseFlags(argc, argv)) { return; } + Run(); +} - LogFlags(); +void BenchmarkModel::Run() { + ValidateParams(); + LogParams(); listeners_.OnBenchmarkStart(params_); int64_t initialization_start_us = profiling::time::NowMicros(); @@ -117,8 +144,10 @@ void BenchmarkModel::Run(int argc, char **argv) { << "ms"; uint64_t input_bytes = ComputeInputBytes(); - Stat warmup_time_us = Run(params_.warmup_runs, WARMUP); - Stat inference_time_us = Run(params_.num_runs, REGULAR); + Stat warmup_time_us = + Run(params_.Get("warmup_runs"), WARMUP); + Stat inference_time_us = + Run(params_.Get("num_runs"), REGULAR); listeners_.OnBenchmarkEnd( {startup_latency_us, input_bytes, warmup_time_us, inference_time_us}); } @@ -132,7 +161,7 @@ bool BenchmarkModel::ParseFlags(int argc, char **argv) { TFLITE_LOG(ERROR) << usage; return false; } - return ValidateFlags(); + return true; } } // namespace benchmark diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h index d48f693693c2cee0cd2e2a6f2b4c590998feffb3..677a1ee68c247fb016c7ede4e1a614bacb7a0a15 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_params.h" #include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" #include "tensorflow/core/util/stats_calculator.h" @@ -63,17 +64,6 @@ class BenchmarkResults { tensorflow::Stat inference_time_us_; }; -struct BenchmarkParams { - BenchmarkParams() - : num_runs(50), warmup_runs(1), run_delay(-1.0), num_threads(1) {} - int num_runs; - int warmup_runs; - float run_delay; - int num_threads; - std::string benchmark_name; - std::string output_prefix; -}; - class BenchmarkListener { public: virtual void OnBenchmarkStart(const BenchmarkParams& params) {} @@ -130,26 +120,38 @@ class BenchmarkLoggingListener : public BenchmarkListener { void OnBenchmarkEnd(const BenchmarkResults& results) override; }; +template +Flag CreateFlag(const char* name, BenchmarkParams* params, + const std::string& usage) { + return Flag(name, [params, name](const T& val) { params->Set(name, val); }, + params->Get(name), usage); +} + // Benchmarks a model. // // Subclasses need to implement initialization and running of the model. // The results can be collected by adding BenchmarkListener(s). class BenchmarkModel { public: + static BenchmarkParams DefaultParams(); + BenchmarkModel(); + BenchmarkModel(BenchmarkParams params) : params_(std::move(params)) {} virtual ~BenchmarkModel() {} - bool ParseFlags(int argc, char** argv); virtual void Init() = 0; void Run(int argc, char** argv); + virtual void Run(); void AddListener(BenchmarkListener* listener) { listeners_.AddListener(listener); } protected: - virtual void LogFlags(); - virtual bool ValidateFlags() { return true; } + virtual void LogParams(); + virtual bool ValidateParams(); + bool ParseFlags(int argc, char** argv); virtual std::vector GetFlags(); virtual uint64_t ComputeInputBytes() = 0; virtual tensorflow::Stat Run(int num_times, RunType run_type); + virtual void PrepareInputsAndOutputs(); virtual void RunImpl() = 0; BenchmarkParams params_; BenchmarkListeners listeners_; diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc new file mode 100644 index 0000000000000000000000000000000000000000..1dcf580a9d4995e6cb3706d3562bc8a2f4670082 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_params.h" + +#include +#include +#include + +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" + +namespace tflite { +namespace benchmark { + +void BenchmarkParam::AssertHasSameType(BenchmarkParam::ParamType a, + BenchmarkParam::ParamType b) { + TFLITE_BENCHMARK_CHECK(a == b) << "Type mismatch while accessing parameter."; +} + +template <> +BenchmarkParam::ParamType BenchmarkParam::GetValueType() { + return BenchmarkParam::ParamType::TYPE_INT32; +} + +template <> +BenchmarkParam::ParamType BenchmarkParam::GetValueType() { + return BenchmarkParam::ParamType::TYPE_BOOL; +} + +template <> +BenchmarkParam::ParamType BenchmarkParam::GetValueType() { + return BenchmarkParam::ParamType::TYPE_FLOAT; +} + +template <> +BenchmarkParam::ParamType BenchmarkParam::GetValueType() { + return BenchmarkParam::ParamType::TYPE_STRING; +} + +void BenchmarkParams::AssertParamExists(const std::string& name) const { + TFLITE_BENCHMARK_CHECK(HasParam(name)) << name << " was not found."; +} + +} // namespace benchmark +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h new file mode 100644 index 0000000000000000000000000000000000000000..c98f47bb0d89864dff54d7cdebe764e56e4cfda2 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" + +namespace tflite { +namespace benchmark { + +template +class TypedBenchmarkParam; + +class BenchmarkParam { + protected: + enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING }; + template + static ParamType GetValueType(); + + public: + template + static std::unique_ptr Create(const T& default_value) { + return std::unique_ptr( + new TypedBenchmarkParam(default_value)); + } + + template + TypedBenchmarkParam* AsTyped() { + AssertHasSameType(GetValueType(), type_); + return static_cast*>(this); + } + virtual ~BenchmarkParam() {} + BenchmarkParam(ParamType type) : type_(type) {} + + private: + static void AssertHasSameType(ParamType a, ParamType b); + + const ParamType type_; +}; + +template +class TypedBenchmarkParam : public BenchmarkParam { + public: + TypedBenchmarkParam(const T& value) + : BenchmarkParam(GetValueType()), value_(value) {} + void Set(const T& value) { value_ = value; } + + T Get() { return value_; } + + private: + T value_; +}; + +class BenchmarkParams { + public: + void AddParam(const std::string& name, + std::unique_ptr value) { + params_[name] = std::move(value); + } + + bool HasParam(const std::string& name) const { + return params_.find(name) != params_.end(); + } + + template + void Set(const std::string& name, const T& value) { + AssertParamExists(name); + params_.at(name)->AsTyped()->Set(value); + } + + template + T Get(const std::string& name) const { + AssertParamExists(name); + return params_.at(name)->AsTyped()->Get(); + } + + private: + void AssertParamExists(const std::string& name) const; + std::unordered_map> params_; +}; + +} // namespace benchmark +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_test.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b697bb394db9b967dfaaff649517dcc23e85ccb0 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_test.cc @@ -0,0 +1,74 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" +#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" + +namespace { +const std::string* g_model_path = nullptr; +} + +namespace tflite { +namespace benchmark { +namespace { + +BenchmarkParams CreateParams() { + BenchmarkParams params; + params.AddParam("num_runs", BenchmarkParam::Create(2)); + params.AddParam("run_delay", BenchmarkParam::Create(-1.0f)); + params.AddParam("num_threads", BenchmarkParam::Create(1)); + params.AddParam("benchmark_name", BenchmarkParam::Create("")); + params.AddParam("output_prefix", BenchmarkParam::Create("")); + params.AddParam("warmup_runs", BenchmarkParam::Create(1)); + params.AddParam("graph", BenchmarkParam::Create(*g_model_path)); + params.AddParam("input_layer", BenchmarkParam::Create("")); + params.AddParam("input_layer_shape", BenchmarkParam::Create("")); + params.AddParam("use_nnapi", BenchmarkParam::Create(false)); + return params; +} + +TEST(BenchmarkTest, DoesntCrash) { + ASSERT_THAT(g_model_path, testing::NotNull()); + + BenchmarkTfLiteModel benchmark(CreateParams()); + benchmark.Run(); +} + +} // namespace +} // namespace benchmark +} // namespace tflite + +int main(int argc, char** argv) { + std::string model_path; + std::vector flags = { + tflite::Flag::CreateFlag("graph", &model_path, "Path to model file.")}; + g_model_path = &model_path; + const bool parse_result = + tflite::Flags::Parse(&argc, const_cast(argv), flags); + if (!parse_result) { + std::cerr << tflite::Flags::Usage(argv[0], flags); + return 1; + } + + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc index 5f803cec197858953180d379c763ed7ebd34ee1d..7f97f5d0cd6c412653f6d510406daf86b7baa3f7 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc @@ -162,35 +162,60 @@ bool PopulateInputLayerInfo( return true; } +BenchmarkParams GetDefaultParams() { + BenchmarkParams default_params = BenchmarkModel::DefaultParams(); + default_params.AddParam("graph", BenchmarkParam::Create("")); + default_params.AddParam("input_layer", + BenchmarkParam::Create("")); + default_params.AddParam("input_layer_shape", + BenchmarkParam::Create("")); + default_params.AddParam("use_nnapi", BenchmarkParam::Create(false)); + return default_params; +} + } // namespace +BenchmarkTfLiteModel::BenchmarkTfLiteModel() + : BenchmarkModel(GetDefaultParams()) { + AddListener(&profiling_listener_); +} + +BenchmarkTfLiteModel::BenchmarkTfLiteModel(BenchmarkParams params) + : BenchmarkModel(std::move(params)) { + AddListener(&profiling_listener_); +} + std::vector BenchmarkTfLiteModel::GetFlags() { std::vector flags = BenchmarkTfLiteModel::BenchmarkModel::GetFlags(); std::vector specific_flags = { - Flag("graph", &graph, "graph file name"), - Flag("input_layer", &input_layer_string, "input layer names"), - Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"), - Flag("use_nnapi", &use_nnapi, "use nnapi api")}; + CreateFlag("graph", ¶ms_, "graph file name"), + CreateFlag("input_layer", ¶ms_, "input layer names"), + CreateFlag("input_layer_shape", ¶ms_, + "input layer shape"), + CreateFlag("use_nnapi", ¶ms_, "use nnapi api")}; flags.insert(flags.end(), specific_flags.begin(), specific_flags.end()); return flags; } -void BenchmarkTfLiteModel::LogFlags() { - BenchmarkModel::LogFlags(); - TFLITE_LOG(INFO) << "Graph: [" << graph << "]"; - TFLITE_LOG(INFO) << "Input layers: [" << input_layer_string << "]"; - TFLITE_LOG(INFO) << "Input shapes: [" << input_layer_shape_string << "]"; - TFLITE_LOG(INFO) << "Use nnapi : [" << use_nnapi << "]"; +void BenchmarkTfLiteModel::LogParams() { + BenchmarkModel::LogParams(); + TFLITE_LOG(INFO) << "Graph: [" << params_.Get("graph") << "]"; + TFLITE_LOG(INFO) << "Input layers: [" + << params_.Get("input_layer") << "]"; + TFLITE_LOG(INFO) << "Input shapes: [" + << params_.Get("input_layer_shape") << "]"; + TFLITE_LOG(INFO) << "Use nnapi : [" << params_.Get("use_nnapi") << "]"; } -bool BenchmarkTfLiteModel::ValidateFlags() { - if (graph.empty()) { +bool BenchmarkTfLiteModel::ValidateParams() { + if (params_.Get("graph").empty()) { TFLITE_LOG(ERROR) << "Please specify the name of your TF Lite input file with --graph"; return false; } - return PopulateInputLayerInfo(input_layer_string, input_layer_shape_string, + return PopulateInputLayerInfo(params_.Get("input_layer"), + params_.Get("input_layer_shape"), &inputs); } @@ -205,6 +230,7 @@ uint64_t BenchmarkTfLiteModel::ComputeInputBytes() { } void BenchmarkTfLiteModel::Init() { + std::string graph = params_.Get("graph"); model = tflite::FlatBufferModel::BuildFromFile(graph.c_str()); if (!model) { TFLITE_LOG(FATAL) << "Failed to mmap model " << graph; @@ -226,10 +252,14 @@ void BenchmarkTfLiteModel::Init() { } profiling_listener_.SetInterpreter(interpreter.get()); - if (params_.num_threads != -1) { - interpreter->SetNumThreads(params_.num_threads); + const int32_t num_threads = params_.Get("num_threads"); + + if (num_threads != -1) { + interpreter->SetNumThreads(num_threads); } + bool use_nnapi = params_.Get("use_nnapi"); + interpreter->UseNNAPI(use_nnapi); auto interpreter_inputs = interpreter->inputs(); diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h index ffb93da964b2da0328616e749abd9c5a84189468..9931dcbafe06cb9f8673462858244f6f2793b29d 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h @@ -50,13 +50,12 @@ class ProfilingListener : public BenchmarkListener { // Benchmarks a TFLite model by running tflite interpreter. class BenchmarkTfLiteModel : public BenchmarkModel { public: - BenchmarkTfLiteModel() : use_nnapi(false) { - AddListener(&profiling_listener_); - } + BenchmarkTfLiteModel(); + BenchmarkTfLiteModel(BenchmarkParams params); std::vector GetFlags() override; - void LogFlags() override; - bool ValidateFlags() override; + void LogParams() override; + bool ValidateParams() override; uint64_t ComputeInputBytes() override; void Init() override; void RunImpl() override; @@ -70,13 +69,7 @@ class BenchmarkTfLiteModel : public BenchmarkModel { private: std::unique_ptr model; std::unique_ptr interpreter; - std::string graph; - std::string input_layer_string; - std::string input_layer_type_string; - std::string input_layer_shape_string; - std::string input_layer_values_string; std::vector inputs; - bool use_nnapi; ProfilingListener profiling_listener_; }; diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc index 8195fc44beb288eec3c020791b47eefa01536fb7..ff818b9dcb5ee0b58b95c3dceae74083dbd4f0da 100644 --- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include #include +#include #include namespace tflite { @@ -44,76 +45,79 @@ bool ParseFlag(const std::string& arg, const std::string& flag, } template -bool ParseFlag(const std::string& flag_value, T* value) { +bool ParseFlag(const std::string& flag_value, + const std::function& hook) { std::istringstream stream(flag_value); T read_value; stream >> read_value; if (!stream.eof() && !stream.good()) { return false; } - *value = read_value; + hook(read_value); return true; } -bool ParseBoolFlag(const std::string& flag_value, bool* value) { +bool ParseBoolFlag(const std::string& flag_value, + const std::function& hook) { if (flag_value != "true" && flag_value != "false") { return false; } - *value = (flag_value == "true"); + hook(flag_value == "true"); return true; } - -bool ParseStringFlag(const std::string& flag_value, std::string* value) { - *value = flag_value; - return true; -} - } // namespace -Flag::Flag(const char* name, int32_t* dst, const std::string& usage_text) +Flag::Flag(const char* name, const std::function& hook, + int32_t default_value, const std::string& usage_text) : name_(name), type_(TYPE_INT32), - value_hook_([dst](const std::string& flag_value) { - return ParseFlag(flag_value, dst); + value_hook_([hook](const std::string& flag_value) { + return ParseFlag(flag_value, hook); }), - default_for_display_(ToString(*dst)), + default_for_display_(ToString(default_value)), usage_text_(usage_text) {} -Flag::Flag(const char* name, int64_t* dst, const std::string& usage_text) +Flag::Flag(const char* name, const std::function& hook, + int64_t default_value, const std::string& usage_text) : name_(name), type_(TYPE_INT64), - value_hook_([dst](const std::string& flag_value) { - return ParseFlag(flag_value, dst); + value_hook_([hook](const std::string& flag_value) { + return ParseFlag(flag_value, hook); }), - default_for_display_(ToString(*dst)), + default_for_display_(ToString(default_value)), usage_text_(usage_text) {} -Flag::Flag(const char* name, float* dst, const std::string& usage_text) +Flag::Flag(const char* name, const std::function& hook, + float default_value, const std::string& usage_text) : name_(name), type_(TYPE_FLOAT), - value_hook_([dst](const std::string& flag_value) { - return ParseFlag(flag_value, dst); + value_hook_([hook](const std::string& flag_value) { + return ParseFlag(flag_value, hook); }), - default_for_display_(ToString(*dst)), + default_for_display_(ToString(default_value)), usage_text_(usage_text) {} -Flag::Flag(const char* name, bool* dst, const std::string& usage_text) +Flag::Flag(const char* name, const std::function& hook, + bool default_value, const std::string& usage_text) : name_(name), type_(TYPE_BOOL), - value_hook_([dst](const std::string& flag_value) { - return ParseBoolFlag(flag_value, dst); + value_hook_([hook](const std::string& flag_value) { + return ParseBoolFlag(flag_value, hook); }), - default_for_display_((*dst) ? "true" : "false"), + default_for_display_(default_value ? "true" : "false"), usage_text_(usage_text) {} -Flag::Flag(const char* name, std::string* dst, const std::string& usage_text) +Flag::Flag(const char* name, + const std::function& hook, + const std::string& default_value, const std::string& usage_text) : name_(name), type_(TYPE_STRING), - value_hook_([dst](const std::string& flag_value) { - return ParseStringFlag(flag_value, dst); + value_hook_([hook](const std::string& flag_value) { + hook(flag_value); + return true; }), - default_for_display_(*dst), + default_for_display_(default_value), usage_text_(usage_text) {} bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const { diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h index 36f9e64767315a317338bc4d2db2ec2d43bee875..2e514ae3ead3b602b8217998ec09177b1e6a2376 100644 --- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h @@ -33,10 +33,11 @@ namespace tflite { // int some_int = 10; // bool some_switch = false; // std::string some_name = "something"; +// // std::vector flag_list = { -// Flag("some_int", &some_int, "an integer that affects X"), -// Flag("some_switch", &some_switch, "a bool that affects Y"), -// Flag("some_name", &some_name, "a std::string that affects Z") +// Flag::CreateFlag("some_int", &some_int, "an integer that affects X"), +// Flag::CreateFlag("some_switch", &some_switch, "a bool that affects Y"), +// Flag::CreateFlag("some_name", &some_name, "a string that affects Z") // }; // // Get usage message before ParseFlags() to capture default values. // std::string usage = Flag::Usage(argv[0], flag_list); @@ -63,11 +64,21 @@ namespace tflite { // text, and a pointer to the corresponding variable. class Flag { public: - Flag(const char* name, int32_t* dst, const std::string& usage_text); - Flag(const char* name, int64_t* dst, const std::string& usage_text); - Flag(const char* name, bool* dst, const std::string& usage_text); - Flag(const char* name, std::string* dst, const std::string& usage_text); - Flag(const char* name, float* dst, const std::string& usage_text); + template + static Flag CreateFlag(const char* name, T* val, const char* usage) { + return Flag(name, [val](const T& v) { *val = v; }, *val, usage); + } + + Flag(const char* name, const std::function& hook, + int32_t default_value, const std::string& usage_text); + Flag(const char* name, const std::function& hook, + int64_t default_value, const std::string& usage_text); + Flag(const char* name, const std::function& hook, + float default_value, const std::string& usage_text); + Flag(const char* name, const std::function& hook, + bool default_value, const std::string& usage_text); + Flag(const char* name, const std::function& hook, + const std::string& default_value, const std::string& usage_text); private: friend class Flags; diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc index 620d61b027d30044ba9d449a8e308375f72ad76f..03da8051099899241fa5241374d754adb1aa93c6 100644 --- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc @@ -34,15 +34,15 @@ TEST(CommandLineFlagsTest, BasicUsage) { "--some_name=somethingelse", "--some_float=42.0"}; int argc = 6; - bool parsed_ok = - Flags::Parse(&argc, reinterpret_cast(argv_strings), - { - Flag("some_int32", &some_int32, "some int32"), - Flag("some_int64", &some_int64, "some int64"), - Flag("some_switch", &some_switch, "some switch"), - Flag("some_name", &some_name, "some name"), - Flag("some_float", &some_float, "some float"), - }); + bool parsed_ok = Flags::Parse( + &argc, reinterpret_cast(argv_strings), + { + Flag::CreateFlag("some_int32", &some_int32, "some int32"), + Flag::CreateFlag("some_int64", &some_int64, "some int64"), + Flag::CreateFlag("some_switch", &some_switch, "some switch"), + Flag::CreateFlag("some_name", &some_name, "some name"), + Flag::CreateFlag("some_float", &some_float, "some float"), + }); EXPECT_EQ(true, parsed_ok); EXPECT_EQ(20, some_int32); @@ -57,9 +57,9 @@ TEST(CommandLineFlagsTest, EmptyStringFlag) { int argc = 2; std::string some_string = "invalid"; const char* argv_strings[] = {"program_name", "--some_string="}; - bool parsed_ok = - Flags::Parse(&argc, reinterpret_cast(argv_strings), - {Flag("some_string", &some_string, "some string")}); + bool parsed_ok = Flags::Parse( + &argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_string", &some_string, "some string")}); EXPECT_EQ(true, parsed_ok); EXPECT_EQ(some_string, ""); @@ -72,7 +72,7 @@ TEST(CommandLineFlagsTest, BadIntValue) { const char* argv_strings[] = {"program_name", "--some_int=notanumber"}; bool parsed_ok = Flags::Parse(&argc, reinterpret_cast(argv_strings), - {Flag("some_int", &some_int, "some int")}); + {Flag::CreateFlag("some_int", &some_int, "some int")}); EXPECT_EQ(false, parsed_ok); EXPECT_EQ(10, some_int); @@ -83,9 +83,9 @@ TEST(CommandLineFlagsTest, BadBoolValue) { bool some_switch = false; int argc = 2; const char* argv_strings[] = {"program_name", "--some_switch=notabool"}; - bool parsed_ok = - Flags::Parse(&argc, reinterpret_cast(argv_strings), - {Flag("some_switch", &some_switch, "some switch")}); + bool parsed_ok = Flags::Parse( + &argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_switch", &some_switch, "some switch")}); EXPECT_EQ(false, parsed_ok); EXPECT_EQ(false, some_switch); @@ -98,7 +98,7 @@ TEST(CommandLineFlagsTest, BadFloatValue) { const char* argv_strings[] = {"program_name", "--some_float=notanumber"}; bool parsed_ok = Flags::Parse(&argc, reinterpret_cast(argv_strings), - {Flag("some_float", &some_float, "some float")}); + {Flag::CreateFlag("some_float", &some_float, "some float")}); EXPECT_EQ(false, parsed_ok); EXPECT_NEAR(-23.23f, some_float, 1e-5f); @@ -136,10 +136,11 @@ TEST(CommandLineFlagsTest, UsageString) { // match against, and we don't want a flakey test. const std::string tool_name = "some_tool_name"; std::string usage = Flags::Usage( - tool_name + " ", {Flag("some_int", &some_int, "some int"), - Flag("some_int64", &some_int64, "some int64"), - Flag("some_switch", &some_switch, "some switch"), - Flag("some_name", &some_name, "some name")}); + tool_name + " ", + {Flag::CreateFlag("some_int", &some_int, "some int"), + Flag::CreateFlag("some_int64", &some_int64, "some int64"), + Flag::CreateFlag("some_switch", &some_switch, "some switch"), + Flag::CreateFlag("some_name", &some_name, "some name")}); // Match the usage message, being sloppy about whitespace. const char* expected_usage = " usage: some_tool_name \n" diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/README.md b/tensorflow/contrib/lite/tools/benchmark/ios/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c8d3307e29efaebdc5c309dc7e4262b54d64943f --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/README.md @@ -0,0 +1,43 @@ +# TFLite iOS benchmark app. + +## Description + +An iOS app to benchmark TFLite models. + +The app reads benchmark parameters from a JSON file named `benchmark_params.json` +in its `benchmark_data` directory. Any downloaded models for benchmarking should +also be placed in `benchmark_data` directory. + +The JSON file specifies the name of the model file and other benchmarking +parameters like inputs to the model, type of inputs, number of iterations, +number of threads. The default values in the JSON file are for the +Mobilenet_1.0_224 model +([paper](https://arxiv.org/pdf/1704.04861.pdf), +[tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz)) + +## To build/install/run + +- Follow instructions at [iOS build for TFLite] +(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md) +to build TFLite. + +Running + +```bash +tensorflow/contrib/lite/build_ios_universal_lib.sh +``` +will also build `tensorflow/contrib/lite/gen/lib/benchmark-lib.a` . + +- Now copy the downloaded model file to `benchmark_data` directory. + +- Modify `benchmark_params.json` change the `input_layer`, `input_layer_shape` +and other benchmark parameters. + +- Change `Build Phases -> Copy Bundle Resources` and add the model file to the +resources that need to be copied. + +- Ensure that `Build Phases -> Link Binary With Library` contains the +`Accelerate framework` and `tensorflow/contrib/lite/gen/lib/benchmark-lib.a`. + +- Now try running the app. The app has a single button that runs the benchmark + on the model and displays results in a text view below. diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj new file mode 100644 index 0000000000000000000000000000000000000000..b908f733d49b56a6b41ebea4185f1fe8c11edc60 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj @@ -0,0 +1,381 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 50; + objects = { + +/* Begin PBXBuildFile section */ + 6FE7579A20D59CE500F01636 /* benchmark_params.json in Resources */ = {isa = PBXBuildFile; fileRef = 6FE7579920D59CE500F01636 /* benchmark_params.json */; }; + 6FE7579D20D5A5E000F01636 /* benchmark-lib.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 6FE7579C20D5A5E000F01636 /* benchmark-lib.a */; }; + 6FE7579F20D5A6A700F01636 /* Accelerate.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 6FE7579E20D5A6A700F01636 /* Accelerate.framework */; }; + 6FE757A120D5AB8100F01636 /* mobilenet_v1_1.0_224.tflite in Resources */ = {isa = PBXBuildFile; fileRef = 6FE757A020D5AB8000F01636 /* mobilenet_v1_1.0_224.tflite */; }; + 6FE93FFD20D592D8008C9FE4 /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 6FE93FFC20D592D8008C9FE4 /* AppDelegate.m */; }; + 6FE9400020D592D8008C9FE4 /* BenchmarkViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 6FE93FFF20D592D8008C9FE4 /* BenchmarkViewController.mm */; }; + 6FE9400320D592D8008C9FE4 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 6FE9400120D592D8008C9FE4 /* Main.storyboard */; }; + 6FE9400520D592DA008C9FE4 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 6FE9400420D592DA008C9FE4 /* Assets.xcassets */; }; + 6FE9400B20D592DA008C9FE4 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 6FE9400A20D592DA008C9FE4 /* main.m */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 6FE7579920D59CE500F01636 /* benchmark_params.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; path = benchmark_params.json; sourceTree = ""; }; + 6FE7579C20D5A5E000F01636 /* benchmark-lib.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "benchmark-lib.a"; path = "$SRCROOT/../../../../../../../tensorflow/contrib/lite/gen/lib/benchmark-lib.a"; sourceTree = ""; }; + 6FE7579E20D5A6A700F01636 /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = System/Library/Frameworks/Accelerate.framework; sourceTree = SDKROOT; }; + 6FE757A020D5AB8000F01636 /* mobilenet_v1_1.0_224.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_v1_1.0_224.tflite; sourceTree = ""; }; + 6FE93FF820D592D8008C9FE4 /* TFLiteBenchmark.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = TFLiteBenchmark.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 6FE93FFB20D592D8008C9FE4 /* AppDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; + 6FE93FFC20D592D8008C9FE4 /* AppDelegate.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = AppDelegate.m; sourceTree = ""; }; + 6FE93FFE20D592D8008C9FE4 /* BenchmarkViewController.h */ = {isa = PBXFileReference; explicitFileType = sourcecode.cpp.h; path = BenchmarkViewController.h; sourceTree = ""; }; + 6FE93FFF20D592D8008C9FE4 /* BenchmarkViewController.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = BenchmarkViewController.mm; sourceTree = ""; }; + 6FE9400220D592D8008C9FE4 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/Main.storyboard; sourceTree = ""; }; + 6FE9400420D592DA008C9FE4 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; + 6FE9400920D592DA008C9FE4 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; + 6FE9400A20D592DA008C9FE4 /* main.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = main.m; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 6FE93FF520D592D8008C9FE4 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + 6FE7579F20D5A6A700F01636 /* Accelerate.framework in Frameworks */, + 6FE7579D20D5A5E000F01636 /* benchmark-lib.a in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 6FE7579820D59C8B00F01636 /* benchmark_data */ = { + isa = PBXGroup; + children = ( + 6FE757A020D5AB8000F01636 /* mobilenet_v1_1.0_224.tflite */, + 6FE7579920D59CE500F01636 /* benchmark_params.json */, + ); + path = benchmark_data; + sourceTree = ""; + }; + 6FE7579B20D5A5E000F01636 /* Frameworks */ = { + isa = PBXGroup; + children = ( + 6FE7579E20D5A6A700F01636 /* Accelerate.framework */, + 6FE7579C20D5A5E000F01636 /* benchmark-lib.a */, + ); + name = Frameworks; + sourceTree = ""; + }; + 6FE93FEF20D592D8008C9FE4 = { + isa = PBXGroup; + children = ( + 6FE93FFA20D592D8008C9FE4 /* TFLiteBenchmark */, + 6FE93FF920D592D8008C9FE4 /* Products */, + 6FE7579B20D5A5E000F01636 /* Frameworks */, + ); + sourceTree = ""; + }; + 6FE93FF920D592D8008C9FE4 /* Products */ = { + isa = PBXGroup; + children = ( + 6FE93FF820D592D8008C9FE4 /* TFLiteBenchmark.app */, + ); + name = Products; + sourceTree = ""; + }; + 6FE93FFA20D592D8008C9FE4 /* TFLiteBenchmark */ = { + isa = PBXGroup; + children = ( + 6FE7579820D59C8B00F01636 /* benchmark_data */, + 6FE93FFB20D592D8008C9FE4 /* AppDelegate.h */, + 6FE93FFC20D592D8008C9FE4 /* AppDelegate.m */, + 6FE93FFE20D592D8008C9FE4 /* BenchmarkViewController.h */, + 6FE93FFF20D592D8008C9FE4 /* BenchmarkViewController.mm */, + 6FE9400120D592D8008C9FE4 /* Main.storyboard */, + 6FE9400420D592DA008C9FE4 /* Assets.xcassets */, + 6FE9400920D592DA008C9FE4 /* Info.plist */, + 6FE9400A20D592DA008C9FE4 /* main.m */, + ); + path = TFLiteBenchmark; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 6FE93FF720D592D8008C9FE4 /* TFLiteBenchmark */ = { + isa = PBXNativeTarget; + buildConfigurationList = 6FE9400E20D592DA008C9FE4 /* Build configuration list for PBXNativeTarget "TFLiteBenchmark" */; + buildPhases = ( + 6FE93FF420D592D8008C9FE4 /* Sources */, + 6FE93FF520D592D8008C9FE4 /* Frameworks */, + 6FE93FF620D592D8008C9FE4 /* Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = TFLiteBenchmark; + productName = TFLiteBenchmark; + productReference = 6FE93FF820D592D8008C9FE4 /* TFLiteBenchmark.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 6FE93FF020D592D8008C9FE4 /* Project object */ = { + isa = PBXProject; + attributes = { + LastUpgradeCheck = 1000; + ORGANIZATIONNAME = Example; + TargetAttributes = { + 6FE93FF720D592D8008C9FE4 = { + CreatedOnToolsVersion = 10.0; + }; + }; + }; + buildConfigurationList = 6FE93FF320D592D8008C9FE4 /* Build configuration list for PBXProject "TFLiteBenchmark" */; + compatibilityVersion = "Xcode 9.3"; + developmentRegion = en; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 6FE93FEF20D592D8008C9FE4; + productRefGroup = 6FE93FF920D592D8008C9FE4 /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 6FE93FF720D592D8008C9FE4 /* TFLiteBenchmark */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 6FE93FF620D592D8008C9FE4 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 6FE757A120D5AB8100F01636 /* mobilenet_v1_1.0_224.tflite in Resources */, + 6FE9400520D592DA008C9FE4 /* Assets.xcassets in Resources */, + 6FE9400320D592D8008C9FE4 /* Main.storyboard in Resources */, + 6FE7579A20D59CE500F01636 /* benchmark_params.json in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 6FE93FF420D592D8008C9FE4 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 6FE9400020D592D8008C9FE4 /* BenchmarkViewController.mm in Sources */, + 6FE9400B20D592DA008C9FE4 /* main.m in Sources */, + 6FE93FFD20D592D8008C9FE4 /* AppDelegate.m in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin PBXVariantGroup section */ + 6FE9400120D592D8008C9FE4 /* Main.storyboard */ = { + isa = PBXVariantGroup; + children = ( + 6FE9400220D592D8008C9FE4 /* Base */, + ); + name = Main.storyboard; + sourceTree = ""; + }; +/* End PBXVariantGroup section */ + +/* Begin XCBuildConfiguration section */ + 6FE9400C20D592DA008C9FE4 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 11.0; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + ONLY_ACTIVE_ARCH = YES; + OTHER_CFLAGS = ""; + OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; + SDKROOT = iphoneos; + }; + name = Debug; + }; + 6FE9400D20D592DA008C9FE4 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 11.0; + MTL_ENABLE_DEBUG_INFO = NO; + OTHER_CFLAGS = ""; + OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; + SDKROOT = iphoneos; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; + 6FE9400F20D592DA008C9FE4 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_STYLE = Automatic; + "HEADER_SEARCH_PATHS[arch=*]" = ( + $SRCROOT/../../../../../../../, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/eigen, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/gemmlowp, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/neon_2_sse, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/farmhash/src, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/flatbuffers/include, + ); + INFOPLIST_FILE = TFLiteBenchmark/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + "LIBRARY_SEARCH_PATHS[arch=*]" = $SRCROOT/../../../../../../../tensorflow/contrib/lite/gen/lib; + PRODUCT_BUNDLE_IDENTIFIER = example.TFLiteBenchmark; + PRODUCT_NAME = "$(TARGET_NAME)"; + TARGETED_DEVICE_FAMILY = "1,2"; + "USER_HEADER_SEARCH_PATHS[arch=*]" = ""; + }; + name = Debug; + }; + 6FE9401020D592DA008C9FE4 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_STYLE = Automatic; + "HEADER_SEARCH_PATHS[arch=*]" = ( + $SRCROOT/../../../../../../../, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/eigen, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/gemmlowp, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/neon_2_sse, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/farmhash/src, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/flatbuffers/include, + ); + INFOPLIST_FILE = TFLiteBenchmark/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + "LIBRARY_SEARCH_PATHS[arch=*]" = $SRCROOT/../../../../../../../tensorflow/contrib/lite/gen/lib; + PRODUCT_BUNDLE_IDENTIFIER = example.TFLiteBenchmark; + PRODUCT_NAME = "$(TARGET_NAME)"; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 6FE93FF320D592D8008C9FE4 /* Build configuration list for PBXProject "TFLiteBenchmark" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 6FE9400C20D592DA008C9FE4 /* Debug */, + 6FE9400D20D592DA008C9FE4 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 6FE9400E20D592DA008C9FE4 /* Build configuration list for PBXNativeTarget "TFLiteBenchmark" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 6FE9400F20D592DA008C9FE4 /* Debug */, + 6FE9401020D592DA008C9FE4 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 6FE93FF020D592D8008C9FE4 /* Project object */; +} diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h new file mode 100644 index 0000000000000000000000000000000000000000..a55c03e00b5065e3b149c65f820f11d13c064d87 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h @@ -0,0 +1,22 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface AppDelegate : UIResponder + +@property(strong, nonatomic) UIWindow *window; + +@end diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m new file mode 100644 index 0000000000000000000000000000000000000000..b1165940e9a29ac693d473a1c852b7b0681392fc --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m @@ -0,0 +1,27 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "AppDelegate.h" + +@interface AppDelegate () + +@end + +@implementation AppDelegate +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + return YES; +} +@end diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..d8db8d65fd79fd541b2b7eba75c7378af3448f9c --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,98 @@ +{ + "images" : [ + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "3x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "83.5x83.5", + "scale" : "2x" + }, + { + "idiom" : "ios-marketing", + "size" : "1024x1024", + "scale" : "1x" + } + ], + "info" : { + "version" : 1, + "author" : "xcode" + } +} \ No newline at end of file diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..da4a164c918651cdd1e11dca5cc62c333f097601 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + "version" : 1, + "author" : "xcode" + } +} \ No newline at end of file diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..bfa36129419f8bd7ad73581cb9f07b8c6eec3fcf --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..adcfe1ef4e708ea6f87c77f4a740b58e5027d3e5 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard @@ -0,0 +1,60 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h new file mode 100644 index 0000000000000000000000000000000000000000..ec6dea0546060881682c44ad451f4812a2f3d7ea --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h @@ -0,0 +1,21 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface BenchmarkViewController : UIViewController +@property(weak, nonatomic) IBOutlet UITextView *resultsView; + +@end diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm new file mode 100644 index 0000000000000000000000000000000000000000..356d5b0e17abc715de9b8f7a20ec7459f3468da1 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm @@ -0,0 +1,125 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "BenchmarkViewController.h" +#import +#import +#import +#import +#import "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" +#import "tensorflow/contrib/lite/tools/benchmark/logging.h" + +namespace { +NSString* FilePathForResourceName(NSString* filename) { + NSString* name = [filename stringByDeletingPathExtension]; + NSString* extension = [filename pathExtension]; + NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; + if (file_path == NULL) { + TFLITE_LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String] + << "' in bundle."; + } + return file_path; +} + +NSDictionary* ParseJson() { + NSString* params_json_path = FilePathForResourceName(@"benchmark_params.json"); + NSData* data = [NSData dataWithContentsOfFile:params_json_path]; + return [NSJSONSerialization JSONObjectWithData:data options:kNilOptions error:nil]; +} + +std::string FormatCommandLineParam(NSString* key, NSString* value) { + std::ostringstream stream; + stream << "--" << [key UTF8String] << "=" << [value UTF8String]; + return stream.str(); +} + +// Reads the |benchmark_params.json| to read command line parameters and returns them as a vector of +// strings. +void ReadCommandLineParameters(std::vector* params) { + NSDictionary* param_dict = ParseJson(); + for (NSString* key in param_dict) { + NSString* value = param_dict[key]; + if ([key isEqualToString:@"graph"]) { + value = FilePathForResourceName(value); + } + params->push_back(FormatCommandLineParam(key, value)); + } +} +std::vector StringVecToCharPtrVec(const std::vector& str_vec) { + std::vector charptr_vec; + std::transform(str_vec.begin(), str_vec.end(), std::back_inserter(charptr_vec), + [](const std::string& s) -> char* { return const_cast(s.c_str()); }); + return charptr_vec; +} + +class ResultsListener : public tflite::benchmark::BenchmarkListener { + public: + void OnBenchmarkEnd(const tflite::benchmark::BenchmarkResults& results) override; + std::string Results() { return results_; } + + private: + std::string results_; +}; + +void OutputMicrosecondsStatToStream(const tensorflow::Stat& time_us, + const std::string& prefix, std::ostringstream* stream) { + *stream << prefix << "Num runs: " << time_us.count() << "\n"; + + *stream << prefix << "Average: " << time_us.avg() / 1e3 << " ms\n"; + *stream << prefix << "Min: " << time_us.min() / 1e3 << " ms \n"; + *stream << prefix << "Max: " << time_us.max() / 1e3 << " ms \n"; + *stream << prefix << "Std deviation: " << time_us.std_deviation() / 1e3 << " ms\n"; +} + +void ResultsListener::OnBenchmarkEnd(const tflite::benchmark::BenchmarkResults& results) { + std::ostringstream stream; + const std::string prefix = " - "; + stream << "Startup latency: "; + stream << results.startup_latency_us() / 1e3 << " ms\n"; + stream << "\nInference:\n"; + OutputMicrosecondsStatToStream(results.inference_time_us(), prefix, &stream); + stream << "\nWarmup:\n"; + OutputMicrosecondsStatToStream(results.warmup_time_us(), prefix, &stream); + + results_ = stream.str(); +} + +std::string RunBenchmark() { + ResultsListener listener; + tflite::benchmark::BenchmarkTfLiteModel benchmark; + benchmark.AddListener(&listener); + // TODO(shashishekhar): Passing arguments like this is brittle, refactor the BenchmarkParams + // so that it contains arguments for BenchmarkTfLiteModel and set parameters using BenchmarkParams + std::vector command_line_params; + // Benchmark model expects first arg to be program name. + // push a string for name of program. + command_line_params.push_back("benchmark_tflite_model"); + ReadCommandLineParameters(&command_line_params); + std::vector argv = StringVecToCharPtrVec(command_line_params); + int argc = static_cast(argv.size()); + benchmark.Run(argc, argv.data()); + return listener.Results(); +} +} // namespace + +@interface BenchmarkViewController () +@end + +@implementation BenchmarkViewController +- (IBAction)onBenchmarkModel:(UIButton*)sender { + std::string results = RunBenchmark(); + [_resultsView setText:[NSString stringWithUTF8String:results.c_str()]]; +} +@end diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist new file mode 100644 index 0000000000000000000000000000000000000000..96051cf08ff54b51f458eca6f0126dd99dfc51dc --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist @@ -0,0 +1,43 @@ + + + + + UILaunchStoryboardName + Main + CFBundleDevelopmentRegion + $(DEVELOPMENT_LANGUAGE) + CFBundleExecutable + $(EXECUTABLE_NAME) + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + $(PRODUCT_NAME) + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleVersion + 1 + LSRequiresIPhoneOS + + UIMainStoryboardFile + Main + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + UIInterfaceOrientationPortraitUpsideDown + UIInterfaceOrientationLandscapeLeft + UIInterfaceOrientationLandscapeRight + + + diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json new file mode 100644 index 0000000000000000000000000000000000000000..d344a7a5efaef53500bc0f88d29ca7aecf59290a --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json @@ -0,0 +1,10 @@ +{ + "benchmark_name" : "mobile_net_benchmark", + "num_threads" : "4", + "num_runs" : "20", + "warmup_runs" : "1", + "graph" : "mobilenet_v1_1.0_224.tflite", + "input_layer" : "input", + "input_layer_shape" : "1,224,224,3", + "run_delay" : "-1" +} diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m new file mode 100644 index 0000000000000000000000000000000000000000..1e70b9cd1d82f320ec048642520dbc54dc0f7934 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m @@ -0,0 +1,23 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "AppDelegate.h" + +int main(int argc, char* argv[]) { + @autoreleasepool { + return UIApplicationMain(argc, argv, nil, NSStringFromClass([AppDelegate class])); + } +} diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/tools/make/Makefile similarity index 51% rename from tensorflow/contrib/lite/Makefile rename to tensorflow/contrib/lite/tools/make/Makefile index cc8a8035d1dadeec98886ba1dae4cdf403f26de4..e30cc1d70e1370f6243d9dcd39eeaa8f20cc4b1a 100644 --- a/tensorflow/contrib/lite/Makefile +++ b/tensorflow/contrib/lite/tools/make/Makefile @@ -6,61 +6,74 @@ endif # Try to figure out the host system HOST_OS := ifeq ($(OS),Windows_NT) - HOST_OS = WINDOWS + HOST_OS = windows else UNAME_S := $(shell uname -s) ifeq ($(UNAME_S),Linux) - HOST_OS := LINUX + HOST_OS := linux endif ifeq ($(UNAME_S),Darwin) - HOST_OS := OSX + HOST_OS := osx endif endif -ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi) +HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi) -# Where compiled objects are stored. -OBJDIR := $(MAKEFILE_DIR)/gen/obj/ -BINDIR := $(MAKEFILE_DIR)/gen/bin/ -LIBDIR := $(MAKEFILE_DIR)/gen/lib/ -GENDIR := $(MAKEFILE_DIR)/gen/obj/ - -# Settings for the host compiler. -CXX := $(CC_PREFIX)gcc -CXXFLAGS := --std=c++11 -O3 -DNDEBUG -CC := $(CC_PREFIX)gcc -CCFLAGS := -O3 -DNDEBUG -LDOPTS := -LDOPTS += -L/usr/local/lib +# Override these on the make command line to target a specific architecture. For example: +# make -f tensorflow/contrib/lite/Makefile TARGET=rpi TARGET_ARCH=armv7l +TARGET := $(HOST_OS) +TARGET_ARCH := $(HOST_ARCH) + +# These are the default libraries needed, but they can be added to or +# overridden by the platform-specific settings in target makefiles. +LIBS := \ +-lstdc++ \ +-lpthread \ +-lm \ +-lz + +# There are no rules for compiling objects for the host system (since we don't +# generate things like the protobuf compiler that require that), so all of +# these settings are for the target compiler. +CXXFLAGS := -O3 -DNDEBUG +CCFLAGS := ${CXXFLAGS} +CXXFLAGS += --std=c++11 +CFLAGS := +LDOPTS := -L/usr/local/lib ARFLAGS := -r +TARGET_TOOLCHAIN_PREFIX := +CC_PREFIX := + +# These target-specific makefiles should modify or replace options like +# CXXFLAGS or LIBS to work for a specific targetted architecture. All logic +# based on platforms or architectures should happen within these files, to +# keep this main makefile focused on the sources and dependencies. +include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc) + +# Where compiled objects are stored. +GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/ +OBJDIR := $(GENDIR)obj/ +BINDIR := $(GENDIR)bin/ +LIBDIR := $(GENDIR)lib/ INCLUDES := \ -I. \ --I$(MAKEFILE_DIR)/../../../ \ +-I$(MAKEFILE_DIR)/../../../../../ \ +-I$(MAKEFILE_DIR)/../../../../../../ \ -I$(MAKEFILE_DIR)/downloads/ \ -I$(MAKEFILE_DIR)/downloads/eigen \ -I$(MAKEFILE_DIR)/downloads/gemmlowp \ -I$(MAKEFILE_DIR)/downloads/neon_2_sse \ -I$(MAKEFILE_DIR)/downloads/farmhash/src \ -I$(MAKEFILE_DIR)/downloads/flatbuffers/include \ --I$(GENDIR) +-I$(OBJDIR) # This is at the end so any globally-installed frameworks like protobuf don't # override local versions in the source tree. INCLUDES += -I/usr/local/include -LIBS := \ --lstdc++ \ --lpthread \ --lm \ --lz - -# If we're on Linux, also link in the dl library. -ifeq ($(HOST_OS),LINUX) - LIBS += -ldl -endif - -include $(MAKEFILE_DIR)/ios_makefile.inc -include $(MAKEFILE_DIR)/rpi_makefile.inc +CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++ +CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc +AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar # This library is the main target for this makefile. It will contain a minimal # runtime that can be linked in to other programs. @@ -70,6 +83,12 @@ LIB_PATH := $(LIBDIR)$(LIB_NAME) # A small example program that shows how to link against the library. MINIMAL_PATH := $(BINDIR)minimal +# Benchmark static library and binary +BENCHMARK_LIB_NAME := benchmark-lib.a +BENCHMARK_BINARY_NAME := benchmark_model +BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME) +BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME) + MINIMAL_SRCS := \ tensorflow/contrib/lite/examples/minimal/minimal.cc MINIMAL_OBJS := $(addprefix $(OBJDIR), \ @@ -78,19 +97,29 @@ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS)))) # What sources we want to compile, must be kept in sync with the main Bazel # build files. +PROFILER_SRCS := \ + tensorflow/contrib/lite/profiling/time.cc +PROFILE_SUMMARIZER_SRCS := \ + tensorflow/contrib/lite/profiling/profile_summarizer.cc \ + tensorflow/core/util/stats_calculator.cc + CORE_CC_ALL_SRCS := \ $(wildcard tensorflow/contrib/lite/*.cc) \ +$(wildcard tensorflow/contrib/lite/*.c) +ifneq ($(BUILD_TYPE),micro) +CORE_CC_ALL_SRCS += \ $(wildcard tensorflow/contrib/lite/kernels/*.cc) \ $(wildcard tensorflow/contrib/lite/kernels/internal/*.cc) \ $(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.cc) \ $(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.cc) \ -$(wildcard tensorflow/contrib/lite/*.c) \ +$(PROFILER_SRCS) \ $(wildcard tensorflow/contrib/lite/kernels/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.c) \ -$(wildcard tensorflow/contrib/lite/downloads/farmhash/src/farmhash.cc) \ -$(wildcard tensorflow/contrib/lite/downloads/fft2d/fftsg.c) +$(wildcard tensorflow/contrib/lite/tools/make/downloads/farmhash/src/farmhash.cc) \ +$(wildcard tensorflow/contrib/lite/tools/make/downloads/fft2d/fftsg.c) +endif # Remove any duplicates. CORE_CC_ALL_SRCS := $(sort $(CORE_CC_ALL_SRCS)) CORE_CC_EXCLUDE_SRCS := \ @@ -100,6 +129,11 @@ $(wildcard tensorflow/contrib/lite/*/*/*test.cc) \ $(wildcard tensorflow/contrib/lite/*/*/*/*test.cc) \ $(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \ $(MINIMAL_SRCS) +ifeq ($(BUILD_TYPE),micro) +CORE_CC_EXCLUDE_SRCS += \ +tensorflow/contrib/lite/mmap_allocation.cc \ +tensorflow/contrib/lite/nnapi_delegate.cc +endif # Filter out all the excluded files. TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) # File names of the intermediate files target compilation generates. @@ -107,21 +141,40 @@ TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS)))) LIB_OBJS := $(TF_LITE_CC_OBJS) +# Benchmark sources +BENCHMARK_SRCS_DIR := tensorflow/contrib/lite/tools/benchmark +BENCHMARK_ALL_SRCS := $(TFLITE_CC_SRCS) \ + $(wildcard $(BENCHMARK_SRCS_DIR)/*.cc) \ + $(PROFILE_SUMMARIZER_SRCS) + +BENCHMARK_SRCS := $(filter-out \ + $(wildcard $(BENCHMARK_SRCS_DIR)/*_test.cc), \ + $(BENCHMARK_ALL_SRCS)) + +BENCHMARK_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS)))) + # For normal manually-created TensorFlow C++ source files. $(OBJDIR)%.o: %.cc @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ - # For normal manually-created TensorFlow C++ source files. $(OBJDIR)%.o: %.c @mkdir -p $(dir $@) $(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@ # The target that's compiled if there's no command-line arguments. -all: $(LIB_PATH) $(MINIMAL_PATH) +all: $(LIB_PATH) $(MINIMAL_PATH) $(BENCHMARK_BINARY) + +# The target that's compiled for micro-controllers +micro: $(LIB_PATH) + +# Hack for generating schema file bypassing flatbuffer parsing +tensorflow/contrib/lite/schema/schema_generated.h: + @cp -u tensorflow/contrib/lite/schema/schema_generated.h.OPENSOURCE tensorflow/contrib/lite/schema/schema_generated.h # Gathers together all the objects we've compiled into a single '.a' archive. -$(LIB_PATH): $(LIB_OBJS) +$(LIB_PATH): tensorflow/contrib/lite/schema/schema_generated.h $(LIB_OBJS) @mkdir -p $(dir $@) $(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS) @@ -131,6 +184,21 @@ $(MINIMAL_PATH): $(MINIMAL_OBJS) $(LIB_PATH) -o $(MINIMAL_PATH) $(MINIMAL_OBJS) \ $(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS) + +$(BENCHMARK_LIB) : $(LIB_PATH) $(BENCHMARK_OBJS) + @mkdir -p $(dir $@) + $(AR) $(ARFLAGS) $(BENCHMARK_LIB) $(LIB_OBJS) $(BENCHMARK_OBJS) + +benchmark_lib: $(BENCHMARK_LIB) +$(info $(BENCHMARK_BINARY)) +$(BENCHMARK_BINARY) : $(BENCHMARK_LIB) + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) \ + -o $(BENCHMARK_BINARY) \ + $(LIBFLAGS) $(BENCHMARK_LIB) $(LDFLAGS) $(LIBS) + +benchmark: $(BENCHMARK_BINARY) + # Gets rid of all generated files. clean: rm -rf $(MAKEFILE_DIR)/gen diff --git a/tensorflow/contrib/lite/tools/make/build_ios_universal_lib.sh b/tensorflow/contrib/lite/tools/make/build_ios_universal_lib.sh new file mode 100755 index 0000000000000000000000000000000000000000..fe056945a652b04d078947f58bfe6ab60aa1f387 --- /dev/null +++ b/tensorflow/contrib/lite/tools/make/build_ios_universal_lib.sh @@ -0,0 +1,40 @@ +#!/bin/bash -x +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR/../../../../.." + +# Build library for supported architectures and packs them in a fat binary. +make_library() { + for arch in x86_64 armv7 armv7s arm64 + do + make -f tensorflow/contrib/lite/tools/make/Makefile TARGET=ios TARGET_ARCH=${arch} \ + -j 8 + done + mkdir -p tensorflow/contrib/lite/tools/make/gen/lib + lipo \ + tensorflow/contrib/lite/tools/make/gen/ios_x86_64/lib/${1} \ + tensorflow/contrib/lite/tools/make/gen/ios_armv7/lib/${1} \ + tensorflow/contrib/lite/tools/make/gen/ios_armv7s/lib/${1} \ + tensorflow/contrib/lite/tools/make/gen/ios_arm64/lib/${1} \ + -create \ + -output tensorflow/contrib/lite/tools/make/gen/lib/${1} +} + +make_library libtensorflow-lite.a +make_library benchmark-lib.a diff --git a/tensorflow/contrib/lite/build_rpi_lib.sh b/tensorflow/contrib/lite/tools/make/build_rpi_lib.sh similarity index 90% rename from tensorflow/contrib/lite/build_rpi_lib.sh rename to tensorflow/contrib/lite/tools/make/build_rpi_lib.sh index 3824b16412ed26a6cab79df3242da6017c3322b0..24ecd4356df12c25dbdbf81684b7de128e8d11f4 100755 --- a/tensorflow/contrib/lite/build_rpi_lib.sh +++ b/tensorflow/contrib/lite/tools/make/build_rpi_lib.sh @@ -17,6 +17,6 @@ set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -cd "$SCRIPT_DIR/../../.." +cd "$SCRIPT_DIR/../../../../.." -CC_PREFIX=arm-linux-gnueabihf- make -j 3 -f tensorflow/contrib/lite/Makefile TARGET=RPI TARGET_ARCH=armv7 +CC_PREFIX=arm-linux-gnueabihf- make -j 3 -f tensorflow/contrib/lite/tools/make/Makefile TARGET=rpi TARGET_ARCH=armv7l diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/tools/make/download_dependencies.sh similarity index 96% rename from tensorflow/contrib/lite/download_dependencies.sh rename to tensorflow/contrib/lite/tools/make/download_dependencies.sh index 840015a7fad173dbd2ea353786871dd4e89abb98..29afa45133775224cef5c2bdd59cc513b0a47914 100755 --- a/tensorflow/contrib/lite/download_dependencies.sh +++ b/tensorflow/contrib/lite/tools/make/download_dependencies.sh @@ -17,9 +17,9 @@ set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -cd "$SCRIPT_DIR/../../.." +cd "$SCRIPT_DIR/../../../../.." -DOWNLOADS_DIR=tensorflow/contrib/lite/downloads +DOWNLOADS_DIR=tensorflow/contrib/lite/tools/make/downloads BZL_FILE_PATH=tensorflow/workspace.bzl # Ensure it is being run from repo root @@ -35,7 +35,7 @@ GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.g ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip" FARMHASH_URL="https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz" -FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/master.zip" +FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/v1.8.0.zip" FFT2D_URL="https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz" # TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, diff --git a/tensorflow/contrib/lite/ios_makefile.inc b/tensorflow/contrib/lite/tools/make/targets/ios_makefile.inc similarity index 67% rename from tensorflow/contrib/lite/ios_makefile.inc rename to tensorflow/contrib/lite/tools/make/targets/ios_makefile.inc index 079320586ffd01fc77818a81e0c5962f1d28c1f1..7f36b8ecef4715a4b89e74bd9ef17d28bbf72ae2 100644 --- a/tensorflow/contrib/lite/ios_makefile.inc +++ b/tensorflow/contrib/lite/tools/make/targets/ios_makefile.inc @@ -1,11 +1,11 @@ # Settings for iOS. -ifeq ($(TARGET), IOS) - BUILD_FOR_IOS_SIMULATOR := false - ifeq ($(IOS_ARCH), x86_64) - BUILD_FOR_IOS_SIMULATOR := true +ifeq ($(TARGET), ios) + BUILD_FOR_IOS_SIMULATOR := false + ifeq ($(TARGET_ARCH), x86_64) + BUILD_FOR_IOS_SIMULATOR := true endif - ifeq ($(IOS_ARCH), i386) - BUILD_FOR_IOS_SIMULATOR := true + ifeq ($(TARGET_ARCH), i386) + BUILD_FOR_IOS_SIMULATOR := true endif ifeq ($(BUILD_FOR_IOS_SIMULATOR), true) IPHONEOS_PLATFORM := $(shell xcrun --sdk iphonesimulator \ @@ -18,8 +18,8 @@ ifeq ($(TARGET), IOS) endif IOS_SDK_VERSION := $(shell xcrun --sdk iphoneos --show-sdk-version) MIN_SDK_VERSION := 9.0 - # Override IOS_ARCH with armv7, armv7s, arm64, i386, or x86_64. - IOS_ARCH := x86_64 + # Override TARGET_ARCH with armv7, armv7s, arm64, i386, or x86_64. + TARGET_ARCH := x86_64 CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ -DTFLITE_USE_APPLE_ACCELERATE_FOR_CONV \ @@ -29,21 +29,17 @@ ifeq ($(TARGET), IOS) -fno-exceptions \ -isysroot \ ${IPHONEOS_SYSROOT} \ - -arch $(IOS_ARCH) \ + -arch $(TARGET_ARCH) \ -O3 CCFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ -fembed-bitcode \ -mno-thumb \ -isysroot \ ${IPHONEOS_SYSROOT} \ - -arch $(IOS_ARCH) \ + -arch $(TARGET_ARCH) \ -O3 LDFLAGS := -fembed-bitcode \ -miphoneos-version-min=${MIN_SDK_VERSION} \ -framework Accelerate \ - -arch $(IOS_ARCH) - OBJDIR := $(OBJDIR)ios_$(IOS_ARCH)/ - LIBDIR := $(LIBDIR)ios_$(IOS_ARCH)/ - BINDIR := $(BINDIR)ios_$(IOS_ARCH)/ - DEPDIR := $(DEPDIR)ios_$(IOS_ARCH)/ + -arch $(TARGET_ARCH) endif diff --git a/tensorflow/contrib/lite/tools/make/targets/linux_makefile.inc b/tensorflow/contrib/lite/tools/make/targets/linux_makefile.inc new file mode 100644 index 0000000000000000000000000000000000000000..86499da99e25c4d025707bc71ebf47d821b3a924 --- /dev/null +++ b/tensorflow/contrib/lite/tools/make/targets/linux_makefile.inc @@ -0,0 +1,10 @@ +# Settings for Linux. +ifeq ($(TARGET), linux) + CXXFLAGS += \ + -fPIC \ + -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ + -pthread + # TODO(petewarden): In the future we may want to add architecture-specific + # flags like -msse4.2 + LIBS += -ldl +endif diff --git a/tensorflow/contrib/lite/tools/make/targets/riscv_makefile.inc b/tensorflow/contrib/lite/tools/make/targets/riscv_makefile.inc new file mode 100644 index 0000000000000000000000000000000000000000..1a82afec33e092090ebb90c1fe18c5adf881f959 --- /dev/null +++ b/tensorflow/contrib/lite/tools/make/targets/riscv_makefile.inc @@ -0,0 +1,10 @@ +# Settings for RiscV platforms. +ifeq ($(TARGET), riscv) + TARGET_ARCH := riscv + TARGET_TOOLCHAIN_PREFIX := riscv32-unknown-elf- + + #CXXFLAGS += -march=gap8 + CXXFLAGS += -DTFLITE_MCU + LIBS += -ldl + BUILD_TYPE := micro +endif diff --git a/tensorflow/contrib/lite/tools/make/targets/rpi_makefile.inc b/tensorflow/contrib/lite/tools/make/targets/rpi_makefile.inc new file mode 100644 index 0000000000000000000000000000000000000000..1ad0c502372e32a5f5d01cde6c8d775189406777 --- /dev/null +++ b/tensorflow/contrib/lite/tools/make/targets/rpi_makefile.inc @@ -0,0 +1,60 @@ +# Settings for Raspberry Pi. +ifeq ($(TARGET),rpi) + # Default to the architecture used on the Pi Two/Three (ArmV7), but override this + # with TARGET_ARCH=armv6 to build for the Pi Zero or One. + TARGET_ARCH := armv7l + TARGET_TOOLCHAIN_PREFIX := arm-linux-gnueabihf- + + ifeq ($(TARGET_ARCH), armv7l) + CXXFLAGS += \ + -march=armv7-a \ + -mfpu=neon-vfpv4 \ + -funsafe-math-optimizations \ + -ftree-vectorize \ + -fPIC + + CCFLAGS += \ + -march=armv7-a \ + -mfpu=neon-vfpv4 \ + -funsafe-math-optimizations \ + -ftree-vectorize \ + -fPIC + + LDFLAGS := \ + -Wl,--no-export-dynamic \ + -Wl,--exclude-libs,ALL \ + -Wl,--gc-sections \ + -Wl,--as-needed + endif + + # TODO(petewarden) In the future, we'll want to use OpenBLAS as a faster + # alternative to Eigen on non-NEON ARM hardware like armv6. + ifeq ($(TARGET_ARCH), armv6) + CXXFLAGS += \ + -march=armv6 \ + -mfpu=vfp \ + -funsafe-math-optimizations \ + -ftree-vectorize \ + -fPIC + + CCFLAGS += \ + -march=armv6 \ + -mfpu=vfp \ + -funsafe-math-optimizations \ + -ftree-vectorize \ + -fPIC + + LDFLAGS := \ + -Wl,--no-export-dynamic \ + -Wl,--exclude-libs,ALL \ + -Wl,--gc-sections \ + -Wl,--as-needed + endif + + LIBS := \ + -lstdc++ \ + -lpthread \ + -lm \ + -ldl + +endif diff --git a/tensorflow/contrib/lite/tools/make/targets/stm32f1_makefile.inc b/tensorflow/contrib/lite/tools/make/targets/stm32f1_makefile.inc new file mode 100644 index 0000000000000000000000000000000000000000..7418e4d196ed1384bc16baa2c0289173060f74ac --- /dev/null +++ b/tensorflow/contrib/lite/tools/make/targets/stm32f1_makefile.inc @@ -0,0 +1,21 @@ +# Settings for STM32F1 platforms. +ifeq ($(TARGET), stm32f1) + TARGET_ARCH := armm1 + TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- + + CXXFLAGS += \ + -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ + -mcpu=cortex-m1 \ + -mthumb \ + -DTFLITE_MCU \ + -fno-rtti \ + -fmessage-length=0 \ + -fno-exceptions \ + -fno-builtin \ + -ffunction-sections \ + -fdata-sections \ + -funsigned-char \ + -MMD + LIBS += -ldl + BUILD_TYPE := micro +endif diff --git a/tensorflow/contrib/lite/tools/make/targets/stm32f7_makefile.inc b/tensorflow/contrib/lite/tools/make/targets/stm32f7_makefile.inc new file mode 100644 index 0000000000000000000000000000000000000000..48af71e5b4ba34897bd20d42b6a01ae1198a83ef --- /dev/null +++ b/tensorflow/contrib/lite/tools/make/targets/stm32f7_makefile.inc @@ -0,0 +1,41 @@ +# Settings for STM32F7 platforms. +ifeq ($(TARGET), stm32f7) + TARGET_ARCH := armf7 + TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- + + CXXFLAGS += \ + -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ + -DTFLITE_MCU \ + -fno-rtti \ + -fmessage-length=0 \ + -fno-exceptions \ + -fno-builtin \ + -ffunction-sections \ + -fdata-sections \ + -funsigned-char \ + -MMD \ + -mcpu=cortex-m7 \ + -mthumb \ + -mfpu=fpv5-sp-d16 \ + -mfloat-abi=softfp \ + -std=gnu++11 \ + -fno-rtti \ + -Wvla \ + -c \ + -Wall \ + -Wextra \ + -Wno-unused-parameter \ + -Wno-missing-field-initializers \ + -fmessage-length=0 \ + -fno-exceptions \ + -fno-builtin \ + -ffunction-sections \ + -fdata-sections \ + -funsigned-char \ + -MMD \ + -fno-delete-null-pointer-checks \ + -fomit-frame-pointer \ + -Os + LIBS += -ldl + BUILD_TYPE := micro +endif diff --git a/tensorflow/contrib/lite/tools/visualize.py b/tensorflow/contrib/lite/tools/visualize.py index f571dd59da0a3f4aff264b48fba3e41f75b50404..597dede63b0c089da21f4b0ede065189d8bbe1d8 100644 --- a/tensorflow/contrib/lite/tools/visualize.py +++ b/tensorflow/contrib/lite/tools/visualize.py @@ -28,11 +28,24 @@ import json import os import sys +from tensorflow.python.platform import resource_loader + # Schema to use for flatbuffers _SCHEMA = "third_party/tensorflow/contrib/lite/schema/schema.fbs" -# Where the binary will be once built in for the flatc converter -_BINARY = "third_party/flatbuffers/flatc" +# TODO(angerson): fix later when rules are simplified.. +_SCHEMA = resource_loader.get_path_to_datafile("../schema/schema.fbs") +_BINARY = resource_loader.get_path_to_datafile("../../../../flatbuffers/flatc") +# Account for different package positioning internal vs. external. +if not os.path.exists(_BINARY): + _BINARY = resource_loader.get_path_to_datafile( + "../../../../../flatbuffers/flatc") + +if not os.path.exists(_SCHEMA): + raise RuntimeError("Sorry, schema file cannot be found at %r" % _SCHEMA) +if not os.path.exists(_BINARY): + raise RuntimeError("Sorry, flatc is not available at %r" % _BINARY) + # A CSS description for making the visualizer _CSS = """ @@ -321,7 +334,7 @@ def CreateHtmlFile(tflite_input, html_output): for key, mapping in toplevel_stuff: if not mapping: mapping = lambda x: x - html += "%s%s\n" % (key, mapping(data[key])) + html += "%s%s\n" % (key, mapping(data.get(key))) html += "\n" # Spec on what keys to display diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h index 89d9b4f5cffa99e708f391fd8fe19208009b5e79..3c4801183bad834e5789c97a56416cdf4668f897 100644 --- a/tensorflow/contrib/lite/util.h +++ b/tensorflow/contrib/lite/util.h @@ -26,12 +26,17 @@ limitations under the License. namespace tflite { -// Converts a `std::vector` to a `TfLiteIntArray`. +// Converts a `std::vector` to a `TfLiteIntArray`. The caller takes ownership +// of the returned pointer. TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector& input); +// Converts an array (of the given size) to a `TfLiteIntArray`. The caller +// takes ownership of the returned pointer, and must make sure 'dims' has at +// least 'rank' elemnts. TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims); // Checks whether a `TfLiteIntArray` and an int array have matching elements. +// The caller must guarantee that 'b' has at least 'b_size' elements. bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size, const int* b); diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 4942d941765951ed2ee5555138e91a202b96bf7c..8c0bfefb30319456e378a85c717c28910811159b 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -20,7 +20,6 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_lookup_ops from tensorflow.python.ops import lookup_ops # pylint: disable=unused-import @@ -395,17 +394,12 @@ class MutableHashTable(LookupInterface): Raises: TypeError: when `keys` do not match the table data types. """ - if keys.dtype.base_dtype != self._key_dtype: - raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % - (self._key_dtype, keys.dtype)) - with ops.name_scope(name, "%s_lookup_table_find" % self._name, (self._table_ref, keys, self._default_value)) as name: + keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") with ops.colocate_with(self._table_ref): values = gen_lookup_ops.lookup_table_find_v2( self._table_ref, keys, self._default_value, name=name) - - values.set_shape(keys.get_shape().concatenate(self._value_shape)) return values def insert(self, keys, values, name=None): @@ -451,9 +445,6 @@ class MutableHashTable(LookupInterface): with ops.colocate_with(self._table_ref): exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( self._table_ref, self._key_dtype, self._value_dtype, name=name) - - exported_values.set_shape(exported_keys.get_shape().concatenate( - self._value_shape)) return exported_keys, exported_values class _Saveable(BaseSaverBuilder.SaveableObject): @@ -537,14 +528,15 @@ class MutableDenseHashTable(LookupInterface): ValueError: If checkpoint is True and no name was specified. """ self._default_value = ops.convert_to_tensor( - default_value, dtype=value_dtype) + default_value, dtype=value_dtype, name="default_value") self._value_shape = self._default_value.get_shape() # The table must be shared if checkpointing is requested for multi-worker # training to work correctly. Use the node name if no shared_name has been # explicitly specified. use_node_name_sharing = checkpoint and shared_name is None - empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype) + empty_key = ops.convert_to_tensor( + empty_key, dtype=key_dtype, name="empty_key") self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2( empty_key=empty_key, shared_name=shared_name, @@ -591,20 +583,13 @@ class MutableDenseHashTable(LookupInterface): Raises: TypeError: when `keys` do not match the table data types. """ - if keys.dtype.base_dtype != self._key_dtype: - raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % - (self._key_dtype, keys.dtype)) - with ops.name_scope(name, "%s_lookup_table_find" % self._name, [self._table_ref, keys]) as name: + keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") with ops.colocate_with(self._table_ref): values = gen_lookup_ops.lookup_table_find_v2( self._table_ref, keys, self._default_value, name=name) - if keys.get_shape().ndims is not None and keys.get_shape().ndims > 0: - values.set_shape( - tensor_shape.TensorShape([keys.get_shape().dims[0]]).concatenate( - self._value_shape)) return values def insert(self, keys, values, name=None): @@ -624,11 +609,11 @@ class MutableDenseHashTable(LookupInterface): TypeError: when `keys` or `values` doesn't match the table data types. """ - # pylint: disable=protected-access - lookup_ops._check_table_dtypes(self, keys.dtype, values.dtype) - # pylint: enable=protected-access with ops.name_scope(name, "%s_lookup_table_insert" % self._name, [self._table_ref, keys, values]) as name: + keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") + values = ops.convert_to_tensor( + values, dtype=self._value_dtype, name="values") with ops.colocate_with(self._table_ref): op = gen_lookup_ops.lookup_table_insert_v2( self._table_ref, keys, values, name=name) @@ -650,8 +635,6 @@ class MutableDenseHashTable(LookupInterface): exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2( self._table_ref, self._key_dtype, self._value_dtype, name=name) - exported_values.set_shape(exported_keys.get_shape().concatenate( - self._value_shape)) return exported_keys, exported_values class _Saveable(BaseSaverBuilder.SaveableObject): diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 5a080cceabb55c307dcd1a457a9e30d24e0bd172..6fb5244fc6230e1c6f6da7708fe30c20a163494c 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -280,6 +280,21 @@ class HashTableOpTest(test.TestCase): table.init.run() self.assertAllEqual(3, table.size().eval()) + def testHashTableInt32String(self): + with self.test_session(): + default_val = "n/a" + keys = constant_op.constant([0, 1, 2], dtypes.int32) + values = constant_op.constant(["brain", "salad", "surgery"]) + table = lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), default_val) + table.init.run() + + input_tensor = constant_op.constant([0, 1, -1]) + output = table.lookup(input_tensor) + + result = output.eval() + self.assertAllEqual([b"brain", b"salad", b"n/a"], result) + class MutableHashTableOpTest(test.TestCase): @@ -419,8 +434,10 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual([[0, 1], [2, 3], [-1, -1]], result) exported_keys, exported_values = table.export() - self.assertAllEqual([None], exported_keys.get_shape().as_list()) - self.assertAllEqual([None, 2], exported_values.get_shape().as_list()) + self.assertAllEqual([None], exported_keys.get_shape().as_list(), + msg="Saw shape %s" % exported_keys.shape) + self.assertAllEqual([None, 2], exported_values.get_shape().as_list(), + msg="Saw shape %s" % exported_values.shape) # exported data is in the order of the internal map, i.e. undefined sorted_keys = np.sort(exported_keys.eval()) sorted_values = np.sort(exported_values.eval()) @@ -654,7 +671,7 @@ class MutableHashTableOpTest(test.TestCase): # lookup with keys of the wrong type input_string = constant_op.constant([1, 2, 3], dtypes.int64) - with self.assertRaises(TypeError): + with self.assertRaises(ValueError): table.lookup(input_string).eval() # default value of the wrong type @@ -838,7 +855,8 @@ class MutableDenseHashTableOpTest(test.TestCase): input_string = constant_op.constant([11, 12, 15], dtypes.int64) output = table.lookup(input_string) - self.assertAllEqual([3, 4], output.get_shape()) + self.assertAllEqual( + [3, 4], output.shape, msg="Saw shape: %s" % output.shape) result = output.eval() self.assertAllEqual([[0, 1, 2, 3], [3, 4, 5, 6], [-1, -2, -3, -4]], @@ -1397,7 +1415,7 @@ class KeyValueTensorInitializerTest(test.TestCase): class IndexTableFromTensor(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_index_table_from_tensor_with_tensor_init(self): table = lookup.index_table_from_tensor( mapping=("brain", "salad", "surgery"), num_oov_buckets=1) @@ -1670,7 +1688,7 @@ class InitializeTableFromFileOpTest(test.TestCase): f.write("\n".join(values) + "\n") return vocabulary_file - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitializeStringTable(self): vocabulary_file = self._createVocabFile("one_column_1.txt") default_value = -1 diff --git a/tensorflow/contrib/makefile/build_all_android.sh b/tensorflow/contrib/makefile/build_all_android.sh index fc88f59e0948e1d3ed7cce9b809bf30ba280af12..fb9e77ae1bcfc3404f1fdf90ab2697a4e79a9836 100755 --- a/tensorflow/contrib/makefile/build_all_android.sh +++ b/tensorflow/contrib/makefile/build_all_android.sh @@ -30,6 +30,14 @@ arm64-v8a armeabi armeabi-v7a mips mips64 x86 x86_64 tegra)" exit 1 } +echo "********************************************************************" +echo "TensorFlow Lite is the recommended library for mobile and embedded machine learning inference." +echo "You are currently using an older version. Please switch over to TensorFlow Lite." +echo "" +echo "Link to the code: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite" +echo "********************************************************************" +echo "" + if [[ -z "${NDK_ROOT}" ]]; then echo "NDK_ROOT should be set as an environment variable" 1>&2 exit 1 diff --git a/tensorflow/contrib/makefile/build_all_ios.sh b/tensorflow/contrib/makefile/build_all_ios.sh index 0a458a27b3ac9b1a24b0f42de2f0166d515e8cd9..1d4677ef4bd1e8811998d1464e63902544153a49 100755 --- a/tensorflow/contrib/makefile/build_all_ios.sh +++ b/tensorflow/contrib/makefile/build_all_ios.sh @@ -31,6 +31,14 @@ usage() { exit 1 } +echo "********************************************************************" +echo "TensorFlow Lite is the recommended library for mobile and embedded machine learning inference." +echo "You are currently using an older version. Please switch over to TensorFlow Lite." +echo "" +echo "Link to the code: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite" +echo "********************************************************************" +echo "" + DEFAULT_ARCH="i386 x86_64 armv7 armv7s arm64" while getopts "a:g:T" opt_name; do case "$opt_name" in diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index 48953e2e3843ff92744514d28bd725cc0d72f3a8..dc9b17a62783817ec9a2998c4d5548c0f05e073b 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -30,8 +30,14 @@ EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -RE2_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" +# Note: The Protobuf source in `tensorflow/workspace.bzl` in TensorFlow +# 1.10 branch does not work. `make distclean` fails and blocks the build +# process. For now we're hardcoding to the version which is used by +# TensorFlow 1.9. +PROTOBUF_URL="https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz" +# TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' once +# the archive has been propagated in mirror.bazel.build. +RE2_URL="$(grep -o 'https://github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" DOUBLE_CONVERSION_URL="$(grep -o "https.*google/double-conversion.*\.zip" "${BZL_FILE_PATH}" | head -n1)" ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt index 76428bc1d4e682e000998a6e28fc290e218c2341..7d26429f9c3b26bcd8819e92cbc15daed60ea9f4 100644 --- a/tensorflow/contrib/makefile/proto_text_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt @@ -35,6 +35,7 @@ tensorflow/core/lib/random/random.cc tensorflow/core/lib/random/distribution_sampler.cc tensorflow/core/lib/io/zlib_outputbuffer.cc tensorflow/core/lib/io/zlib_inputstream.cc +tensorflow/core/lib/io/zlib_compression_options.cc tensorflow/core/lib/io/two_level_iterator.cc tensorflow/core/lib/io/table_builder.cc tensorflow/core/lib/io/table.cc diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 89db9ee2794ddf0a99951dca327e74c5d9694d23..ecf2e120df98d82cca068e186f95e91e71ebc66d 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -92,6 +92,7 @@ tensorflow/core/kernels/reduction_ops_common.cc tensorflow/core/kernels/reduction_ops_any.cc tensorflow/core/kernels/reduction_ops_all.cc tensorflow/core/kernels/roll_op.cc +tensorflow/core/kernels/queue_op.cc tensorflow/core/kernels/queue_ops.cc tensorflow/core/kernels/queue_base.cc tensorflow/core/kernels/pooling_ops_common.cc @@ -228,6 +229,8 @@ tensorflow/core/kernels/cast_op_impl_int32.cc tensorflow/core/kernels/cast_op_impl_int64.cc tensorflow/core/kernels/cast_op_impl_int8.cc tensorflow/core/kernels/cast_op_impl_uint16.cc +tensorflow/core/kernels/cast_op_impl_uint32.cc +tensorflow/core/kernels/cast_op_impl_uint64.cc tensorflow/core/kernels/cast_op_impl_uint8.cc tensorflow/core/kernels/boosted_trees/prediction_ops.cc tensorflow/core/kernels/boosted_trees/resource_ops.cc diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD index 4f2c82ca23011667662c74507fcbd99bcde4c7c0..21cd34f73ffbbf615a81c18b9d365bffa61397f4 100644 --- a/tensorflow/contrib/metrics/BUILD +++ b/tensorflow/contrib/metrics/BUILD @@ -31,6 +31,7 @@ py_library( "//tensorflow/python:check_ops", "//tensorflow/python:confusion_matrix", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:distribute", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:histogram_ops", "//tensorflow/python:init_ops", @@ -77,7 +78,31 @@ py_test( py_test( name = "metric_ops_test", srcs = ["python/ops/metric_ops_test.py"], - shard_count = 16, + shard_count = 30, + srcs_version = "PY2AND3", + tags = ["noasan"], # times out b/63678675 + deps = [ + ":metrics_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "metric_ops_large_test", + size = "large", + srcs = ["python/ops/metric_ops_large_test.py"], srcs_version = "PY2AND3", tags = ["noasan"], # times out b/63678675 deps = [ diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py index 5effea3596bb83a08e0a8627e411684262aef5f7..88798d61b71388de63e492ba69284a72303d32ab 100644 --- a/tensorflow/contrib/metrics/__init__.py +++ b/tensorflow/contrib/metrics/__init__.py @@ -63,6 +63,7 @@ See the @{$python/contrib.metrics} guide. @@aggregate_metrics @@aggregate_metric_map @@confusion_matrix +@@f1_score @@set_difference @@set_intersection @@set_size diff --git a/tensorflow/contrib/metrics/python/metrics/classification.py b/tensorflow/contrib/metrics/python/metrics/classification.py index 26aba1cc51446e589856013d69526007fbe9d921..7053907da05b487df73481e3ced269bb69b8deae 100644 --- a/tensorflow/contrib/metrics/python/metrics/classification.py +++ b/tensorflow/contrib/metrics/python/metrics/classification.py @@ -22,6 +22,9 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics_impl +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import distribution_strategy_context # TODO(nsilberman): move into metrics/python/ops/ @@ -62,3 +65,121 @@ def accuracy(predictions, labels, weights=None, name=None): return math_ops.div(math_ops.reduce_sum(is_correct), math_ops.reduce_sum(num_values)) return math_ops.reduce_mean(is_correct) + + +def f1_score(labels, predictions, weights=None, num_thresholds=200, + metrics_collections=None, updates_collections=None, name=None): + """Computes the approximately best F1-score across different thresholds. + + The f1_score function applies a range of thresholds to the predictions to + convert them from [0, 1] to bool. Precision and recall are computed by + comparing them to the labels. The F1-Score is then defined as + 2 * precision * recall / (precision + recall). The best one across the + thresholds is returned. + + Disclaimer: In practice it may be desirable to choose the best threshold on + the validation set and evaluate the F1 score with this threshold on a + separate test set. Or it may be desirable to use a fixed threshold (e.g. 0.5). + + This function internally creates four local variables, `true_positives`, + `true_negatives`, `false_positives` and `false_negatives` that are used to + compute the pairs of recall and precision values for a linearly spaced set of + thresholds from which the best f1-score is derived. + + This value is ultimately returned as `f1-score`, an idempotent operation that + computes the F1-score (computed using the aforementioned variables). The + `num_thresholds` variable controls the degree of discretization with larger + numbers of thresholds more closely approximating the true best F1-score. + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the F1-score. + + Example usage with a custom estimator: + def model_fn(features, labels, mode): + predictions = make_predictions(features) + loss = make_loss(predictions, labels) + train_op = tf.contrib.training.create_train_op( + total_loss=loss, + optimizer='Adam') + eval_metric_ops = {'f1': f1_score(labels, predictions)} + return tf.estimator.EstimatorSpec( + mode=mode, + predictions=predictions, + loss=loss, + train_op=train_op, + eval_metric_ops=eval_metric_ops, + export_outputs=export_outputs) + estimator = tf.estimator.Estimator(model_fn=model_fn) + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + labels: A `Tensor` whose shape matches `predictions`. Will be cast to + `bool`. + predictions: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + num_thresholds: The number of thresholds to use when discretizing the roc + curve. + metrics_collections: An optional list of collections that `f1_score` should + be added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + name: An optional variable_scope name. + + Returns: + f1_score: A scalar `Tensor` representing the current best f1-score across + different thresholds. + update_op: An operation that increments the `true_positives`, + `true_negatives`, `false_positives` and `false_negatives` variables + appropriately and whose value matches the `f1_score`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + """ + with variable_scope.variable_scope( + name, 'f1', (labels, predictions, weights)): + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + predictions=predictions, labels=labels, weights=weights) + # To account for floating point imprecisions / avoid division by zero. + epsilon = 1e-7 + thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) + for i in range(num_thresholds - 2)] + thresholds = [0.0 - epsilon] + thresholds + [1.0 + epsilon] + + # Confusion matrix. + values, update_ops = metrics_impl._confusion_matrix_at_thresholds( # pylint: disable=protected-access + labels, predictions, thresholds, weights, includes=('tp', 'fp', 'fn')) + + # Compute precision and recall at various thresholds. + def compute_best_f1_score(tp, fp, fn, name): + precision_at_t = math_ops.div(tp, epsilon + tp + fp, + name='precision_' + name) + recall_at_t = math_ops.div(tp, epsilon + tp + fn, name='recall_' + name) + # Compute F1 score. + f1_at_thresholds = ( + 2.0 * precision_at_t * recall_at_t / + (precision_at_t + recall_at_t + epsilon)) + return math_ops.reduce_max(f1_at_thresholds) + + def f1_across_towers(_, values): + best_f1 = compute_best_f1_score(tp=values['tp'], fp=values['fp'], + fn=values['fn'], name='value') + if metrics_collections: + ops.add_to_collections(metrics_collections, best_f1) + return best_f1 + + best_f1 = distribution_strategy_context.get_tower_context().merge_call( + f1_across_towers, values) + + update_op = compute_best_f1_score(tp=update_ops['tp'], fp=update_ops['fp'], + fn=update_ops['fn'], name='update') + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return best_f1, update_op diff --git a/tensorflow/contrib/metrics/python/metrics/classification_test.py b/tensorflow/contrib/metrics/python/metrics/classification_test.py index fa0f12d029620ad6427f715f035ff69f15c133e7..3d0b81c1bed02dae013141367fb052e16d31fe08 100644 --- a/tensorflow/contrib/metrics/python/metrics/classification_test.py +++ b/tensorflow/contrib/metrics/python/metrics/classification_test.py @@ -18,9 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.metrics.python.metrics import classification +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -108,5 +115,200 @@ class ClassificationTest(test.TestCase): self.assertEqual(result, 0.5) +class F1ScoreTest(test.TestCase): + + def setUp(self): + super(F1ScoreTest, self).setUp() + np.random.seed(1) + + def testVars(self): + classification.f1_score( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + num_thresholds=3) + expected = {'f1/true_positives:0', 'f1/false_positives:0', + 'f1/false_negatives:0'} + self.assertEquals( + expected, set(v.name for v in variables.local_variables())) + self.assertEquals( + set(expected), set(v.name for v in variables.local_variables())) + self.assertEquals( + set(expected), + set(v.name for v in ops.get_collection(ops.GraphKeys.METRIC_VARIABLES))) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + f1, _ = classification.f1_score( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + num_thresholds=3, + metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [f1]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, f1_op = classification.f1_score( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + num_thresholds=3, + updates_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [f1_op]) + + def testValueTensorIsIdempotent(self): + predictions = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes.float32, seed=1) + labels = random_ops.random_uniform( + (10, 3), maxval=2, dtype=dtypes.int64, seed=2) + f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + + # Run several updates. + for _ in range(10): + sess.run([f1_op]) + + # Then verify idempotency. + initial_f1 = f1.eval() + for _ in range(10): + self.assertAllClose(initial_f1, f1.eval()) + + def testAllCorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + + with self.test_session() as sess: + predictions = constant_op.constant(inputs, dtype=dtypes.float32) + labels = constant_op.constant(inputs) + f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3) + + sess.run(variables.local_variables_initializer()) + sess.run([f1_op]) + + self.assertEqual(1, f1.eval()) + + def testSomeCorrect(self): + predictions = constant_op.constant( + [1, 0, 1, 0], shape=(1, 4), dtype=dtypes.float32) + labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) + f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=1) + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run([f1_op]) + # Threshold 0 will have around 0.5 precision and 1 recall yielding an F1 + # score of 2 * 0.5 * 1 / (1 + 0.5). + self.assertAlmostEqual(2 * 0.5 * 1 / (1 + 0.5), f1.eval()) + + def testAllIncorrect(self): + inputs = np.random.randint(0, 2, size=(10000, 1)) + + with self.test_session() as sess: + predictions = constant_op.constant(inputs, dtype=dtypes.float32) + labels = constant_op.constant(1 - inputs, dtype=dtypes.float32) + f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3) + + sess.run(variables.local_variables_initializer()) + sess.run([f1_op]) + + # Threshold 0 will have around 0.5 precision and 1 recall yielding an F1 + # score of 2 * 0.5 * 1 / (1 + 0.5). + self.assertAlmostEqual(2 * 0.5 * 1 / (1 + 0.5), f1.eval(), places=2) + + def testWeights1d(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes.float32) + labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) + weights = constant_op.constant( + [[0], [1]], shape=(2, 1), dtype=dtypes.float32) + f1, f1_op = classification.f1_score(predictions, labels, weights, + num_thresholds=3) + sess.run(variables.local_variables_initializer()) + sess.run([f1_op]) + + self.assertAlmostEqual(1.0, f1.eval(), places=5) + + def testWeights2d(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes.float32) + labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) + weights = constant_op.constant( + [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes.float32) + f1, f1_op = classification.f1_score(predictions, labels, weights, + num_thresholds=3) + sess.run(variables.local_variables_initializer()) + sess.run([f1_op]) + + self.assertAlmostEqual(1.0, f1.eval(), places=5) + + def testZeroLabelsPredictions(self): + with self.test_session() as sess: + predictions = array_ops.zeros([4], dtype=dtypes.float32) + labels = array_ops.zeros([4]) + f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3) + sess.run(variables.local_variables_initializer()) + sess.run([f1_op]) + + self.assertAlmostEqual(0.0, f1.eval(), places=5) + + def testWithMultipleUpdates(self): + num_samples = 1000 + batch_size = 10 + num_batches = int(num_samples / batch_size) + + # Create the labels and data. + labels = np.random.randint(0, 2, size=(num_samples, 1)) + noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1)) + predictions = 0.4 + 0.2 * labels + noise + predictions[predictions > 1] = 1 + predictions[predictions < 0] = 0 + thresholds = [-0.01, 0.5, 1.01] + + expected_max_f1 = -1.0 + for threshold in thresholds: + tp = 0 + fp = 0 + fn = 0 + tn = 0 + for i in range(num_samples): + if predictions[i] >= threshold: + if labels[i] == 1: + tp += 1 + else: + fp += 1 + else: + if labels[i] == 1: + fn += 1 + else: + tn += 1 + epsilon = 1e-7 + expected_prec = tp / (epsilon + tp + fp) + expected_rec = tp / (epsilon + tp + fn) + expected_f1 = (2 * expected_prec * expected_rec / + (epsilon + expected_prec + expected_rec)) + if expected_f1 > expected_max_f1: + expected_max_f1 = expected_f1 + + labels = labels.astype(np.float32) + predictions = predictions.astype(np.float32) + tf_predictions, tf_labels = (dataset_ops.Dataset + .from_tensor_slices((predictions, labels)) + .repeat() + .batch(batch_size) + .make_one_shot_iterator() + .get_next()) + f1, f1_op = classification.f1_score(tf_labels, tf_predictions, + num_thresholds=3) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for _ in range(num_batches): + sess.run([f1_op]) + # Since this is only approximate, we can't expect a 6 digits match. + # Although with higher number of samples/thresholds we should see the + # accuracy improving + self.assertAlmostEqual(expected_max_f1, f1.eval(), 2) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index a6be2084aae6bb05f958929b45977ed21b570603..a328670526089988c181a8e1146c911309640009 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -1064,7 +1064,7 @@ def streaming_auc(predictions, name=name) -def _compute_dynamic_auc(labels, predictions, curve='ROC'): +def _compute_dynamic_auc(labels, predictions, curve='ROC', weights=None): """Computes the apporixmate AUC by a Riemann sum with data-derived thresholds. Computes the area under the ROC or PR curve using each prediction as a @@ -1077,13 +1077,22 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'): predictions: A 1-D `Tensor` of predictions whose values are `float64`. curve: The name of the curve to be computed, 'ROC' for the Receiving Operating Characteristic or 'PR' for the Precision-Recall curve. + weights: A 1-D `Tensor` of weights whose values are `float64`. Returns: A scalar `Tensor` containing the area-under-curve value for the input. """ - # Count the total number of positive and negative labels in the input. + # Compute the total weight and the total positive weight. size = array_ops.size(predictions) - total_positive = math_ops.cast(math_ops.reduce_sum(labels), dtypes.int32) + if weights is None: + weights = array_ops.ones_like(labels, dtype=dtypes.float64) + labels, predictions, weights = metrics_impl._remove_squeezable_dimensions( + labels, predictions, weights) + total_weight = math_ops.reduce_sum(weights) + total_positive = math_ops.reduce_sum( + array_ops.where( + math_ops.greater(labels, 0), weights, + array_ops.zeros_like(labels, dtype=dtypes.float64))) def continue_computing_dynamic_auc(): """Continues dynamic auc computation, entered if labels are not all equal. @@ -1091,9 +1100,11 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'): Returns: A scalar `Tensor` containing the area-under-curve value. """ - # Sort the predictions descending, and the corresponding labels as well. + # Sort the predictions descending, keeping the same order for the + # corresponding labels and weights. ordered_predictions, indices = nn.top_k(predictions, k=size) ordered_labels = array_ops.gather(labels, indices) + ordered_weights = array_ops.gather(weights, indices) # Get the counts of the unique ordered predictions. _, _, counts = array_ops.unique_with_counts(ordered_predictions) @@ -1103,23 +1114,39 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'): array_ops.pad(math_ops.cumsum(counts), paddings=[[1, 0]]), dtypes.int32) # Count the positives to the left of the split indices. - positives = math_ops.cast( - array_ops.pad(math_ops.cumsum(ordered_labels), paddings=[[1, 0]]), - dtypes.int32) - true_positives = array_ops.gather(positives, splits) + true_positives = array_ops.gather( + array_ops.pad( + math_ops.cumsum( + array_ops.where( + math_ops.greater(ordered_labels, 0), ordered_weights, + array_ops.zeros_like(ordered_labels, + dtype=dtypes.float64))), + paddings=[[1, 0]]), splits) if curve == 'ROC': - # Count the negatives to the left of every split point and the total - # number of negatives for computing the FPR. - false_positives = math_ops.subtract(splits, true_positives) - total_negative = size - total_positive + # Compute the weight of the negatives to the left of every split point and + # the total weight of the negatives number of negatives for computing the + # FPR. + false_positives = array_ops.gather( + array_ops.pad( + math_ops.cumsum( + array_ops.where( + math_ops.less(ordered_labels, 1), ordered_weights, + array_ops.zeros_like( + ordered_labels, dtype=dtypes.float64))), + paddings=[[1, 0]]), splits) + total_negative = total_weight - total_positive x_axis_values = math_ops.truediv(false_positives, total_negative) y_axis_values = math_ops.truediv(true_positives, total_positive) elif curve == 'PR': x_axis_values = math_ops.truediv(true_positives, total_positive) # For conformance, set precision to 1 when the number of positive # classifications is 0. + positives = array_ops.gather( + array_ops.pad(math_ops.cumsum(ordered_weights), paddings=[[1, 0]]), + splits) y_axis_values = array_ops.where( - math_ops.greater(splits, 0), math_ops.truediv(true_positives, splits), + math_ops.greater(splits, 0), + math_ops.truediv(true_positives, positives), array_ops.ones_like(true_positives, dtype=dtypes.float64)) # Calculate trapezoid areas. @@ -1133,7 +1160,7 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'): return control_flow_ops.cond( math_ops.logical_or( math_ops.equal(total_positive, 0), math_ops.equal( - total_positive, size)), + total_positive, total_weight)), true_fn=lambda: array_ops.constant(0, dtypes.float64), false_fn=continue_computing_dynamic_auc) @@ -1143,7 +1170,8 @@ def streaming_dynamic_auc(labels, curve='ROC', metrics_collections=(), updates_collections=(), - name=None): + name=None, + weights=None): """Computes the apporixmate AUC by a Riemann sum with data-derived thresholds. USAGE NOTE: this approach requires storing all of the predictions and labels @@ -1168,6 +1196,8 @@ def streaming_dynamic_auc(labels, should be added to. name: An optional name for the variable_scope that contains the metric variables. + weights: A 'Tensor' of non-negative weights whose values are castable to + `float64`. Will be flattened into a 1-D `Tensor`. Returns: auc: A scalar `Tensor` containing the current area-under-curve value. @@ -1195,14 +1225,24 @@ def streaming_dynamic_auc(labels, check_ops.assert_less_equal( labels, array_ops.ones_like(labels, dtypes.int64), - message='labels must be 0 or 1, at least one is >1') + message='labels must be 0 or 1, at least one is >1'), ]): preds_accum, update_preds = streaming_concat( predictions, name='concat_preds') labels_accum, update_labels = streaming_concat( labels, name='concat_labels') - update_op = control_flow_ops.group(update_labels, update_preds) - auc = _compute_dynamic_auc(labels_accum, preds_accum, curve=curve) + if weights is not None: + weights = array_ops.reshape( + math_ops.cast(weights, dtypes.float64), [-1]) + weights_accum, update_weights = streaming_concat( + weights, name='concat_weights') + update_op = control_flow_ops.group(update_labels, update_preds, + update_weights) + else: + weights_accum = None + update_op = control_flow_ops.group(update_labels, update_preds) + auc = _compute_dynamic_auc( + labels_accum, preds_accum, curve=curve, weights=weights_accum) if updates_collections: ops.add_to_collections(updates_collections, update_op) if metrics_collections: @@ -3675,6 +3715,7 @@ def count(values, name=None): """Computes the number of examples, or sum of `weights`. + This metric keeps track of the denominator in `tf.metrics.mean`. When evaluating some metric (e.g. mean) on one or more subsets of the data, this auxiliary metric is useful for keeping track of how many examples there are in each subset. @@ -3701,15 +3742,21 @@ def count(values, ValueError: If `weights` is not `None` and its shape doesn't match `values`, or if either `metrics_collections` or `updates_collections` are not a list or tuple. + RuntimeError: If eager execution is enabled. """ + if context.executing_eagerly(): + raise RuntimeError('tf.contrib.metrics.count is not supported when eager ' + 'execution is enabled.') with variable_scope.variable_scope(name, 'count', (values, weights)): + count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') if weights is None: num_values = math_ops.to_float(array_ops.size(values)) else: - _, _, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + values = math_ops.to_float(values) + values, _, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions=values, labels=None, weights=weights) @@ -3718,15 +3765,14 @@ def count(values, num_values = math_ops.reduce_sum(weights) with ops.control_dependencies([values]): - update_op = state_ops.assign_add(count_, num_values) + update_count_op = state_ops.assign_add(count_, num_values) - if metrics_collections: - ops.add_to_collections(metrics_collections, count_) + count_ = metrics_impl._aggregate_variable(count_, metrics_collections) # pylint: disable=protected-access if updates_collections: - ops.add_to_collections(updates_collections, update_op) + ops.add_to_collections(updates_collections, update_count_op) - return count_, update_op + return count_, update_count_op def cohen_kappa(labels, diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7acfc383eb9a659a600752cf57b4978daa8a07bc --- /dev/null +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py @@ -0,0 +1,66 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Large tests for metric_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib.metrics.python.ops import metric_ops +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class StreamingPrecisionRecallAtEqualThresholdsLargeTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testLargeCase(self): + shape = [32, 512, 256, 1] + predictions = random_ops.random_uniform( + shape, 0.0, 1.0, dtype=dtypes_lib.float32) + labels = math_ops.greater(random_ops.random_uniform(shape, 0.0, 1.0), 0.5) + + result, update_op = metric_ops.precision_recall_at_equal_thresholds( + labels=labels, predictions=predictions, num_thresholds=201) + # Run many updates, enough to cause highly inaccurate values if the + # code used float32 for accumulation. + num_updates = 71 + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for _ in xrange(num_updates): + sess.run(update_op) + + prdata = sess.run(result) + + # Since we use random values, we won't know the tp/fp/tn/fn values, but + # tp and fp at threshold 0 should be the total number of positive and + # negative labels, hence their sum should be total number of pixels. + expected_value = 1.0 * np.product(shape) * num_updates + got_value = prdata.tp[0] + prdata.fp[0] + # They should be at least within 1. + self.assertNear(got_value, expected_value, 1.0) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index b13f08a37d9e856d56903324fc6e7cf1457bb191..401fedcbed8fef12308d563d108725a418dfef17 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -2127,6 +2127,44 @@ class StreamingDynamicAUCTest(test.TestCase): sess.run(update_op) self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-5) + def testWithWeights(self): + batch_size = 10 + num_batches = 100 + labels = np.array([]) + predictions = np.array([]) + weights = np.array([]) + tf_labels = variables.Variable( + array_ops.ones(batch_size, dtypes_lib.int32), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.int32) + tf_predictions = variables.Variable( + array_ops.ones(batch_size), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.float32) + tf_weights = variables.Variable( + array_ops.ones(batch_size), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.float32) + auc, update_op = metrics.streaming_dynamic_auc(tf_labels, + tf_predictions, + weights=tf_weights) + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for _ in xrange(num_batches): + new_labels = np.random.randint(0, 2, size=batch_size) + noise = np.random.uniform(-0.2, 0.2, size=batch_size) + new_predictions = 0.4 + 0.2 * new_labels + noise + new_weights = np.random.uniform(0.0, 3.0, size=batch_size) + labels = np.concatenate([labels, new_labels]) + predictions = np.concatenate([predictions, new_predictions]) + weights = np.concatenate([weights, new_weights]) + sess.run([tf_labels.assign(new_labels), + tf_predictions.assign(new_predictions), + tf_weights.assign(new_weights)]) + sess.run(update_op) + expected_auc = _np_auc(predictions, labels, weights) + self.assertAlmostEqual(expected_auc, auc.eval()) + class AucWithConfidenceIntervalsTest(test.TestCase): @@ -2391,34 +2429,6 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): for _ in range(3): self._testResultsEqual(initial_result, result) - def testLargeCase(self): - self.skipTest("Test consistently timing out") - shape = [32, 512, 256, 1] - predictions = random_ops.random_uniform( - shape, 0.0, 1.0, dtype=dtypes_lib.float32) - labels = math_ops.greater(random_ops.random_uniform(shape, 0.0, 1.0), 0.5) - - result, update_op = metric_ops.precision_recall_at_equal_thresholds( - labels=labels, predictions=predictions, num_thresholds=201) - # Run many updates, enough to cause highly inaccurate values if the - # code used float32 for accumulation. - num_updates = 71 - - with self.test_session() as sess: - sess.run(variables.local_variables_initializer()) - for _ in xrange(num_updates): - sess.run(update_op) - - prdata = sess.run(result) - - # Since we use random values, we won't know the tp/fp/tn/fn values, but - # tp and fp at threshold 0 should be the total number of positive and - # negative labels, hence their sum should be total number of pixels. - expected_value = 1.0 * np.product(shape) * num_updates - got_value = prdata.tp[0] + prdata.fp[0] - # They should be at least within 1. - self.assertNear(got_value, expected_value, 1.0) - def _testCase(self, predictions, labels, @@ -4727,199 +4737,204 @@ class StreamingSparseRecallTest(test.TestCase): self._test_sparse_recall_at_top_k( labels, top_k_predictions, expected=1.0 / 2) - def test_one_label_at_k1_weighted(self): + def _test_one_label_at_k1_weighted(self, labels): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] top_k_predictions = [[3], [3]] - sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1], - [0, 0, 1, 0]]) - dense_labels = np.array([[3], [2]], dtype=np.int64) - for labels in (sparse_labels, dense_labels): - # Class 3: 1 label, 2 predictions, 1 correct. - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=1.0 / 1, - class_id=3, - weights=(1.0,)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=1.0 / 1, - class_id=3, - weights=(1.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=1.0 / 1, - class_id=3, - weights=(2.0,)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=1.0 / 1, - class_id=3, - weights=(2.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=NAN, - class_id=3, - weights=(0.0, 0.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=NAN, - class_id=3, - weights=(0.0, 0.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=NAN, - class_id=3, - weights=(0.0, 1.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=NAN, - class_id=3, - weights=(0.0, 1.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=1.0 / 1, - class_id=3, - weights=(1.0, 0.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=1.0 / 1, - class_id=3, - weights=(1.0, 0.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=1.0 / 1, - class_id=3, - weights=(1.0, 1.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=1.0 / 1, - class_id=3, - weights=(1.0, 1.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=2.0 / 2, - class_id=3, - weights=(2.0, 3.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=2.0 / 2, - class_id=3, - weights=(2.0, 3.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=3.0 / 3, - class_id=3, - weights=(3.0, 2.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=3.0 / 3, - class_id=3, - weights=(3.0, 2.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=0.3 / 0.3, - class_id=3, - weights=(0.3, 0.6)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=0.3 / 0.3, - class_id=3, - weights=(0.3, 0.6)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=0.6 / 0.6, - class_id=3, - weights=(0.6, 0.3)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=0.6 / 0.6, - class_id=3, - weights=(0.6, 0.3)) + # Class 3: 1 label, 2 predictions, 1 correct. + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=1.0 / 1, + class_id=3, + weights=(1.0,)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(1.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=1.0 / 1, + class_id=3, + weights=(2.0,)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(2.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=NAN, + class_id=3, + weights=(0.0, 0.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=NAN, + class_id=3, + weights=(0.0, 0.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=NAN, + class_id=3, + weights=(0.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=NAN, + class_id=3, + weights=(0.0, 1.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 0.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 0.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 1.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=2.0 / 2, + class_id=3, + weights=(2.0, 3.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=2.0 / 2, + class_id=3, + weights=(2.0, 3.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=3.0 / 3, + class_id=3, + weights=(3.0, 2.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=3.0 / 3, + class_id=3, + weights=(3.0, 2.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=0.3 / 0.3, + class_id=3, + weights=(0.3, 0.6)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=0.3 / 0.3, + class_id=3, + weights=(0.3, 0.6)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=0.6 / 0.6, + class_id=3, + weights=(0.6, 0.3)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=0.6 / 0.6, + class_id=3, + weights=(0.6, 0.3)) - # All classes: 2 labels, 2 predictions, 1 correct. - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=NAN, weights=(0.0,)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=NAN, weights=(0.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,)) + # All classes: 2 labels, 2 predictions, 1 correct. + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=NAN, weights=(0.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, weights=(0.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3)) + + def test_one_label_at_k1_weighted_sparse_labels(self): + sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1], + [0, 0, 1, 0]]) + self._test_one_label_at_k1_weighted(sparse_labels) + + def test_one_label_at_k1_weighted_dense_labels(self): + dense_labels = np.array([[3], [2]], dtype=np.int64) + self._test_one_label_at_k1_weighted(dense_labels) def test_three_labels_at_k5_nan(self): predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], @@ -6839,6 +6854,11 @@ class CountTest(test.TestCase): array_ops.ones([4, 3]), updates_collections=[my_collection_name]) self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) + def testReturnType(self): + c, op = metrics.count(array_ops.ones([4, 3])) + self.assertTrue(isinstance(c, ops.Tensor)) + self.assertTrue(isinstance(op, ops.Operation) or isinstance(op, ops.Tensor)) + def testBasic(self): with self.test_session() as sess: values_queue = data_flow_ops.FIFOQueue( diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_manager.py b/tensorflow/contrib/mixed_precision/python/loss_scale_manager.py index be7377b1519f3bdab8755411af3de7aa0c2dc9eb..eba505881fb648cf4993e2b8ce7d935dca0f4830 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_manager.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_manager.py @@ -41,12 +41,12 @@ class LossScaleManager(object): applied on variables. This class is used together with - @{tf.contrib.mixed_precision.LossScaleOptimizer} for mixed precision training + `tf.contrib.mixed_precision.LossScaleOptimizer` for mixed precision training (float32 variables and float16 ops) on Nvidia GPUs in order to achieve the same model quality as single precision training, with the benefits of potential higher throughput. - See @{tf.contrib.mixed_precision.LossScaleOptimizer} for more details. + See `tf.contrib.mixed_precision.LossScaleOptimizer` for more details. """ @abc.abstractmethod diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py index 480f5f6eaf493c5c87c27cc9f8e510ea9c085a72..1b0383d24c0c472b4875d15c3650e37dfd2439e1 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py @@ -34,7 +34,7 @@ def _GetExampleIter(inputs): class FixedLossScaleManagerTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_basic(self): itr = _GetExampleIter([True] * 10 + [False] * 10) @@ -84,13 +84,13 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): actual_outputs.append(self.evaluate(lsm.get_loss_scale())) self.assertEqual(actual_outputs, expected_outputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_increase_every_n_steps(self): inputs = [True] * 6 expected_outputs = [1, 2, 2, 4, 4, 8] self._test_helper(inputs, expected_outputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_keep_increasing_until_capped(self): init_loss_scale = np.finfo(np.float32).max / 4 + 10 max_float = np.finfo(np.float32).max @@ -104,7 +104,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): self._test_helper(inputs, expected_outputs, init_loss_scale) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_decrease_every_n_steps(self): inputs = [False] * 6 init_loss_scale = 1024 @@ -112,7 +112,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): self._test_helper(inputs, expected_outputs, init_loss_scale) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_keep_decreasing_until_one(self): inputs = [False] * 10 init_loss_scale = 16 @@ -120,19 +120,19 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): self._test_helper(inputs, expected_outputs, init_loss_scale) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_incr_bad_step_clear_good_step(self): inputs = [True, True, True, False, True] expected_outputs = [1, 2, 2, 2, 2] self._test_helper(inputs, expected_outputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_incr_good_step_does_not_clear_bad_step(self): inputs = [True, True, True, False, True, False] expected_outputs = [1, 2, 2, 2, 2, 1] self._test_helper(inputs, expected_outputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_trigger_loss_scale_update_each_step(self): """Test when incr_every_n_step and decr_every_n_nan_or_inf is 1.""" init_loss_scale = 1 @@ -145,7 +145,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): self._test_helper(inputs, expected_outputs, init_loss_scale, incr_every_n_step, decr_every_n_nan_or_inf) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_alternating_good_and_bad_gradients_trigger_each_step(self): init_loss_scale = 1 incr_every_n_step = 1 @@ -156,7 +156,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): self._test_helper(inputs, expected_outputs, init_loss_scale, incr_every_n_step, decr_every_n_nan_or_inf) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_alternating_good_and_bad_gradients_trigger_incr_every_2steps(self): init_loss_scale = 32 incr_every_n_step = 2 @@ -167,7 +167,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase): self._test_helper(inputs, expected_outputs, init_loss_scale, incr_every_n_step, decr_every_n_nan_or_inf) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_random_mix_good_and_bad_gradients(self): init_loss_scale = 4 inputs = [ diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py index ef34f7bf7bf3eba047b50ce8abf883b0ed741a63..fcce52a07a88547af437382c3ec060b23c9d334e 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py @@ -77,7 +77,7 @@ class LossScaleOptimizer(optimizer.Optimizer): If gradients clipping is applied, one can call `optimizer.compute_gradients()` and `optimizer.apply_gradients()` - seperately. + separately. Notice the following way of using LossScaleOptimizer is not intended. Always use `loss_scale_optimizer.compute_gradients()` to compute gradients instead of @@ -103,7 +103,7 @@ class LossScaleOptimizer(optimizer.Optimizer): Args: opt: The actual optimizer that will be used to compute and apply the - gradients. Must be an implementation of the @{tf.train.Optimizer} + gradients. Must be an implementation of the `tf.train.Optimizer` interface. loss_scale_manager: A LossScaleManager object. """ @@ -117,7 +117,7 @@ class LossScaleOptimizer(optimizer.Optimizer): aggregation_method=None, colocate_gradients_with_ops=False, grad_loss=None): - """Compute gradients. See base class @{tf.train.Optimizer}.""" + """Compute gradients. See base class `tf.train.Optimizer`.""" loss_scale = self._loss_scale_manager.get_loss_scale() if context.executing_eagerly(): @@ -141,7 +141,7 @@ class LossScaleOptimizer(optimizer.Optimizer): return self._down_scale(grads_and_vars, loss_scale) def apply_gradients(self, grads_and_vars, global_step=None, name=None): - """Apply gradients. See base class @{tf.train.Optimizer}.""" + """Apply gradients. See base class `tf.train.Optimizer`.""" grads = [g for (g, _) in grads_and_vars] is_finite_grad = [] diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py index dded61ccd58eb79b338d7264e8a057c9456c8695..9009df0eefec13146090ba5fc2096e71ba6eb89d 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py @@ -54,7 +54,7 @@ class LossScaleOptimizerTest(test.TestCase): opt = loss_scale_opt_fn(opt) return x, loss, opt - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_float16_underflow_without_loss_scale(self): lr = 1 init_val = 1. @@ -73,7 +73,7 @@ class LossScaleOptimizerTest(test.TestCase): rtol=0, atol=min(symbolic_update, 1e-6)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_float16_with_loss_scale(self): lr = 1. init_val = 1. @@ -95,7 +95,7 @@ class LossScaleOptimizerTest(test.TestCase): rtol=0, atol=min(expected_update, 1e-6)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_compute_gradients_with_loss_scale(self): lr = 1 init_val = 1. @@ -115,7 +115,7 @@ class LossScaleOptimizerTest(test.TestCase): # Gradients aren't applied. self.assertAllClose(init_val, self.evaluate(x), rtol=0, atol=1e-6) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_compute_gradients_without_loss_scale(self): lr = 1 init_val = 1. @@ -127,7 +127,7 @@ class LossScaleOptimizerTest(test.TestCase): g_v = self.evaluate(grads_and_vars[0][0]) self.assertAllClose(g_v, 0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_apply_gradients(self): x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32) @@ -155,7 +155,7 @@ class LossScaleOptimizerTest(test.TestCase): actual_output.append(self.evaluate(x)) self.assertAllClose(expected_output, actual_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_apply_gradients_loss_scale_is_updated(self): class SimpleLossScaleManager(lsm_lib.LossScaleManager): diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD index 54bd39afacbec07f054f61b72eda0a3654858aa7..16ddc38f5a5ba88485e18b136b2b1081b0e2ff0f 100644 --- a/tensorflow/contrib/model_pruning/BUILD +++ b/tensorflow/contrib/model_pruning/BUILD @@ -95,6 +95,22 @@ py_library( ], ) +py_library( + name = "strip_pruning_vars_lib", + srcs = ["python/strip_pruning_vars_lib.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":pruning", + "//tensorflow/python:client", + "//tensorflow/python:framework", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_test( name = "pruning_utils_test", size = "small", @@ -129,6 +145,31 @@ py_test( ], ) +py_test( + name = "strip_pruning_vars_test", + size = "small", + srcs = ["python/strip_pruning_vars_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":layers", + ":pruning", + ":rnn_cells", + ":strip_pruning_vars_lib", + "//tensorflow/python:client_testlib", + ], +) + +py_binary( + name = "strip_pruning_vars", + srcs = ["python/strip_pruning_vars.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":strip_pruning_vars_lib", + "//tensorflow/python:platform", + ], +) + py_library( name = "init_py", srcs = ["__init__.py"], @@ -145,5 +186,6 @@ py_library( ":learning", ":pruning", ":rnn_cells", + ":strip_pruning_vars_lib", ], ) diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md index 86f4fd6adf60d8fa54c13989bf4087e28f1e006f..a5267fd90482287a65a4c38ae257a0af349523e8 100644 --- a/tensorflow/contrib/model_pruning/README.md +++ b/tensorflow/contrib/model_pruning/README.md @@ -4,7 +4,15 @@ This document describes the API that facilitates magnitude-based pruning of neural network's weight tensors. The API helps inject necessary tensorflow op into the training graph so the model can be pruned while it is being trained. -### Model creation +## Table of contents +1. [Model creation](#model-creation) +2. [Hyperparameters for pruning](#hyperparameters) + - [Block sparsity](#block-sparsity) +3. [Adding pruning ops to the training graph](#adding-pruning-ops) +4. [Removing pruning ops from trained model](#remove) +5. [Example](#example) + +### Model creation The first step involves adding mask and threshold variables to the layers that need to undergo pruning. The variable mask is the same shape as the layer's @@ -33,7 +41,7 @@ auxiliary variables built-in (see * [rnn_cells.MaskedLSTMCell](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py?l=154) -### Adding pruning ops to the training graph +### Pruning-related hyperparameters The pruning library allows for specification of the following hyper parameters: @@ -42,7 +50,7 @@ The pruning library allows for specification of the following hyper parameters: | name | string | model_pruning | Name of the pruning specification. Used for adding summaries and ops under a common tensorflow name_scope | | begin_pruning_step | integer | 0 | The global step at which to begin pruning | | end_pruning_step | integer | -1 | The global step at which to terminate pruning. Defaults to -1 implying that pruning continues till the training stops | -| do_not_prune | list of strings | [""] | list of layers names that are not pruned | +| weight_sparsity_map | list of strings | [""] | list of weight variable name (or layer name):target sparsity pairs. Eg. [conv1:0.9,conv2/kernel:0.8]. For layers/weights not in this list, sparsity as specified by the target_sparsity hyperparameter is used. | | threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds | | pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) | | nbins | integer | 256 | Number of bins to use for histogram computation | @@ -64,12 +72,18 @@ is divided into $$n$$ intervals of size equal to the pruning_frequency ($$\Delta t$$). $$s_f$$ is the target_sparsity, $$s_i$$ is the initial_sparsity, $$t_0$$ is the sparsity_function_begin_step. In this equation, the sparsity_function_exponent is set to 3. -### Adding pruning ops to the training graph -The final step involves adding ops to the training graph that monitors the -distribution of the layer's weight magnitudes and determines the layer threshold -such masking all the weights below this threshold achieves the sparsity level -desired for the current training step. This can be achieved as follows: +#### Block Sparsity + +For some hardware architectures, it may be beneficial to induce spatially correlated sparsity. To train models in which the weight tensors have block sparse structure, set *block_height* and *block_width* hyperparameters to the desired block configuration (2x2, 4x4, 4x1, 1x8, etc). Currently, block sparsity is only supported for weight tensors which can be squeezed to rank 2. The matrix is partitioned into non-overlapping blocks of size *[block_height, block_dim]* and the either the average or max absolute value in this block is taken as a proxy for the entire block (set by *block_pooling_function* hyperparameter). +The convolution layer tensors are always pruned used block dimensions of [1,1]. + +### Adding pruning ops to the training graph + +The final step involves adding ops to the training graph that monitor the +distribution of the layer's weight magnitudes and determine the layer threshold, +such that masking all the weights below this threshold achieves the sparsity +level desired for the current training step. This can be achieved as follows: ```python tf.app.flags.DEFINE_string( @@ -79,7 +93,7 @@ tf.app.flags.DEFINE_string( with tf.graph.as_default(): # Create global step variable - global_step = tf.train.get_global_step() + global_step = tf.train.get_or_create_global_step() # Parse pruning hyperparameters pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams) @@ -103,8 +117,21 @@ with tf.graph.as_default(): mon_sess.run(mask_update_op) ``` +Ensure that `global_step` is being [incremented](https://www.tensorflow.org/api_docs/python/tf/train/Optimizer#minimize), otherwise pruning will not work! + +### Removing pruning ops from the trained graph +Once the model is trained, it is necessary to remove the auxiliary variables (mask, threshold) and pruning ops added to the graph in the steps above. This can be accomplished using the `strip_pruning_vars` utility. + +This utility generates a binary GraphDef in which the variables have been converted to constants. In particular, the threshold variables are removed from the graph and the mask variable is fused with the corresponding weight tensor to produce a `masked_weight` tensor. This tensor is sparse, has the same size as the weight tensor, and the sparsity is as set by the `target_sparsity` or the `weight_sparsity_map` hyperparameters above. + +```shell +$ bazel build -c opt contrib/model_pruning:strip_pruning_vars +$ bazel-bin/contrib/model_pruning/strip_pruning_vars --checkpoint_dir=/path/to/checkpoints/ --output_node_names=graph_node1,graph_node2 --output_dir=/tmp --filename=pruning_stripped.pb +``` + +For now, it is assumed that the underlying hardware platform will provide mechanisms for compressing the sparse tensors and/or accelerating the sparse tensor computations. -## Example: Pruning and training deep CNNs on the cifar10 dataset +## Example: Pruning and training deep CNNs on the cifar10 dataset Please see https://www.tensorflow.org/tutorials/deep_cnn for details on neural network architecture, setting up inputs etc. The additional changes needed to @@ -120,7 +147,7 @@ incorporate pruning are captured in the following: To train the pruned version of cifar10: -```bash +```shell $ examples_dir=contrib/model_pruning/examples $ bazel build -c opt $examples_dir/cifar10:cifar10_{train,eval} $ bazel-bin/$examples_dir/cifar10/cifar10_train --pruning_hparams=name=cifar10_pruning,begin_pruning_step=10000,end_pruning_step=100000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000 @@ -132,10 +159,14 @@ Eval: $ bazel-bin/$examples_dir/cifar10/cifar10_eval --run_once ``` -### Block Sparsity +Removing pruning nodes from the trained graph: -For some hardware architectures, it may be beneficial to induce spatially correlated sparsity. To train models in which the weight tensors have block sparse structure, set *block_height* and *block_width* hyperparameters to the desired block configuration (2x2, 4x4, 4x1, 1x8, etc). Currently, block sparsity is only supported for weight tensors which can be squeezed to rank 2. The matrix is partitioned into non-overlapping blocks of size *[block_height, block_dim]* and the either the average or max absolute value in this block is taken as a proxy for the entire block (set by *block_pooling_function* hyperparameter). -The convolution layer tensors are always pruned used block dimensions of [1,1]. +```shell +$ bazel build -c opt contrib/model_pruning:strip_pruning_vars +$ bazel-bin/contrib/model_pruning/strip_pruning_vars --checkpoint_path=/tmp/cifar10_train --output_node_names=softmax_linear/softmax_linear_2 --filename=cifar_pruned.pb +``` + +The generated GraphDef (cifar_pruned.pb) may be visualized using the [`import_pb_to_tensorboard`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/tools/import_pb_to_tensorboard.py) utility ## References diff --git a/tensorflow/contrib/model_pruning/__init__.py b/tensorflow/contrib/model_pruning/__init__.py index d32bedbcd6b63bc8e473a9e9d1c8e0753877e6f8..6eca54aaee186f5873a84ef2cb3ff3c7cfb42cd4 100644 --- a/tensorflow/contrib/model_pruning/__init__.py +++ b/tensorflow/contrib/model_pruning/__init__.py @@ -33,6 +33,9 @@ from tensorflow.contrib.model_pruning.python.pruning import get_thresholds from tensorflow.contrib.model_pruning.python.pruning import get_weight_sparsity from tensorflow.contrib.model_pruning.python.pruning import get_weights from tensorflow.contrib.model_pruning.python.pruning import Pruning +from tensorflow.contrib.model_pruning.python.strip_pruning_vars_lib import graph_def_from_checkpoint +from tensorflow.contrib.model_pruning.python.strip_pruning_vars_lib import strip_pruning_vars_fn + # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented @@ -41,7 +44,8 @@ _allowed_symbols = [ 'masked_convolution', 'masked_conv2d', 'masked_fully_connected', 'MaskedBasicLSTMCell', 'MaskedLSTMCell', 'train', 'apply_mask', 'get_masked_weights', 'get_masks', 'get_pruning_hparams', 'get_thresholds', - 'get_weights', 'get_weight_sparsity', 'Pruning' + 'get_weights', 'get_weight_sparsity', 'Pruning', 'strip_pruning_vars_fn', + 'graph_def_from_checkpoint' ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/model_pruning/python/layers/layers.py b/tensorflow/contrib/model_pruning/python/layers/layers.py index 466daf204a1ae86a7f37107342046305ea7249fc..d453e350f05c8e66df13c3861959980d69a564e8 100644 --- a/tensorflow/contrib/model_pruning/python/layers/layers.py +++ b/tensorflow/contrib/model_pruning/python/layers/layers.py @@ -139,7 +139,7 @@ def masked_convolution(inputs, with "NC". num_outputs: Integer, the number of output filters. kernel_size: A sequence of N positive integers specifying the spatial - dimensions of of the filters. Can be a single integer to specify the same + dimensions of the filters. Can be a single integer to specify the same value for all spatial dimensions. stride: A sequence of N positive integers specifying the stride at which to compute output. Can be a single integer to specify the same value for all diff --git a/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py b/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py index a5b050d25d00b298a20f7ce6abdda7c1d00db899..5f6c6aea74f2965ccfe552a58cde290b5506ef12 100644 --- a/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py +++ b/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py @@ -48,7 +48,7 @@ class MaskedBasicLSTMCell(tf_rnn.BasicLSTMCell): It does not allow cell clipping, a projection layer, and does not use peep-hole connections: it is the basic baseline. - For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell} + For advanced models, please use the full `tf.nn.rnn_cell.LSTMCell` that follows. """ diff --git a/tensorflow/contrib/model_pruning/python/learning.py b/tensorflow/contrib/model_pruning/python/learning.py index 2b79c23cefe961b1c4056d41b5fcc0a0521efec6..26695237c27cc4fbe4e9fbaa2666d55836ed39b8 100644 --- a/tensorflow/contrib/model_pruning/python/learning.py +++ b/tensorflow/contrib/model_pruning/python/learning.py @@ -33,11 +33,14 @@ to support training of pruned models # Create the train_op train_op = slim.learning.create_train_op(total_loss, optimizer) - # Set up sparsity - sparsity = pruning.setup_gradual_sparsity(self.global_step) + # Parse pruning hyperparameters + pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams) - # Create mask update op - mask_update_op = pruning.add_mask_update_ip(sparsity) + # Create a pruning object using the pruning_hparams + p = pruning.Pruning(pruning_hparams) + + # Add mask update ops to the graph + mask_update_op = p.conditional_mask_update_op() # Run training. learning.train(train_op, diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index 4b7af18b3316950afdb90c344ce777848c63e4c1..cd58526ed3620d4bd880cf36d806afac70c4bff7 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -152,8 +152,11 @@ def get_pruning_hparams(): end_pruning_step: integer the global step at which to terminate pruning. Defaults to -1 implying that pruning continues till the training stops - do_not_prune: list of strings - list of layers that are not pruned + weight_sparsity_map: list of strings + comma separed list of weight variable name:target sparsity pairs. + For layers/weights not in this list, sparsity as specified by the + target_sparsity hyperparameter is used. + Eg. [conv1:0.9,conv2/kernel:0.8] threshold_decay: float the decay factor to use for exponential decay of the thresholds pruning_frequency: integer @@ -200,7 +203,7 @@ def get_pruning_hparams(): name='model_pruning', begin_pruning_step=0, end_pruning_step=-1, - do_not_prune=[''], + weight_sparsity_map=[''], threshold_decay=0.9, pruning_frequency=10, nbins=256, @@ -234,6 +237,9 @@ class Pruning(object): # Pruning specification self._spec = spec if spec else get_pruning_hparams() + # Sanity check for pruning hparams + self._validate_spec() + # A tensorflow variable that tracks the sparsity function. # If not provided as input, the graph must already contain the global_step # variable before calling this constructor. @@ -256,6 +262,37 @@ class Pruning(object): # Block pooling function self._block_pooling_function = self._spec.block_pooling_function + # Mapping of weight names and target sparsity + self._weight_sparsity_map = self._get_weight_sparsity_map() + + def _validate_spec(self): + spec = self._spec + if spec.begin_pruning_step < 0: + raise ValueError('Illegal value for begin_pruning_step') + + if spec.begin_pruning_step >= spec.end_pruning_step: + if spec.end_pruning_step != -1: + raise ValueError( + 'Pruning must begin before it can end. begin_step=%d, end_step=%d.' + 'Set end_pruning_step to -1 if pruning is required till training' + 'stops' % (spec.begin_pruning_step, spec.end_pruning_step)) + + if spec.sparsity_function_begin_step < 0: + raise ValueError('Illegal value for sparsity_function_begin_step') + + if spec.sparsity_function_begin_step >= spec.sparsity_function_end_step: + raise ValueError( + 'Sparsity function requires begin_step < end_step') + + if not 0.0 <= spec.threshold_decay < 1.0: + raise ValueError('threshold_decay must be in range [0,1)') + + if not 0.0 <= spec.initial_sparsity < 1.0: + raise ValueError('initial_sparsity must be in range [0,1)') + + if not 0.0 <= spec.target_sparsity < 1.0: + raise ValueError('target_sparsity must be in range [0,1)') + def _setup_global_step(self, global_step): graph_global_step = global_step if graph_global_step is None: @@ -270,11 +307,6 @@ class Pruning(object): target_sparsity = self._spec.target_sparsity exponent = self._spec.sparsity_function_exponent - if begin_step >= end_step: - raise ValueError( - 'Pruning must begin before it can end. begin_step=%d, end_step=%d' % - (begin_step, end_step)) - with ops.name_scope(self._spec.name): p = math_ops.minimum( 1.0, @@ -306,15 +338,36 @@ class Pruning(object): 'last_mask_update_step', dtype=dtypes.int32) return last_update_step - def _exists_in_do_not_prune_list(self, tensor_name): - do_not_prune_list = self._spec.do_not_prune - if not do_not_prune_list[0]: - return False - for layer_name in do_not_prune_list: - if tensor_name.find(layer_name) != -1: - return True - - return False + def _get_weight_sparsity_map(self): + """Return the map of weight_name:sparsity parsed from the hparams.""" + weight_sparsity_map = {} + val_list = self._spec.weight_sparsity_map + filtered_val_list = [l for l in val_list if l] + for val in filtered_val_list: + weight_name, sparsity = val.split(':') + if float(sparsity) >= 1.0: + raise ValueError('Weight sparsity can not exceed 1.0') + weight_sparsity_map[weight_name] = float(sparsity) + + return weight_sparsity_map + + def _get_sparsity(self, weight_name): + """Return target sparsity for the given layer/weight name.""" + target_sparsity = [ + sparsity for name, sparsity in self._weight_sparsity_map.items() + if weight_name.find(name) != -1 + ] + if not target_sparsity: + return self._sparsity + + if len(target_sparsity) > 1: + raise ValueError( + 'Multiple matches in weight_sparsity_map for weight %s' % weight_name) + # TODO(suyoggupta): This will work when initial_sparsity = 0. Generalize + # to handle other cases as well. + return math_ops.mul( + self._sparsity, + math_ops.div(target_sparsity[0], self._spec.target_sparsity)) def _update_mask(self, weights, threshold): """Updates the mask for a given weight tensor. @@ -342,6 +395,8 @@ class Pruning(object): if self._sparsity is None: raise ValueError('Sparsity variable undefined') + sparsity = self._get_sparsity(weights.op.name) + with ops.name_scope(weights.op.name + '_pruning_ops'): abs_weights = math_ops.abs(weights) max_value = math_ops.reduce_max(abs_weights) @@ -354,7 +409,7 @@ class Pruning(object): math_ops.div( math_ops.reduce_sum( math_ops.cast( - math_ops.less(norm_cdf, self._sparsity), dtypes.float32)), + math_ops.less(norm_cdf, sparsity), dtypes.float32)), float(self._spec.nbins)), max_value) smoothed_threshold = math_ops.add_n([ @@ -453,10 +508,6 @@ class Pruning(object): if is_partitioned: weight = weight.as_tensor() - if self._spec.do_not_prune: - if self._exists_in_do_not_prune_list(mask.name): - continue - new_threshold, new_mask = self._maybe_update_block_mask(weight, threshold) self._assign_ops.append( pruning_utils.variable_assign(threshold, new_threshold)) @@ -507,22 +558,15 @@ class Pruning(object): no_update_op) def add_pruning_summaries(self): - """Adds summaries for this pruning spec. - - Args: none - - Returns: none - """ + """Adds summaries of weight sparsities and thresholds.""" with ops.name_scope(self._spec.name + '_summaries'): summary.scalar('sparsity', self._sparsity) summary.scalar('last_mask_update_step', self._last_update_step) masks = get_masks() thresholds = get_thresholds() - for index, mask in enumerate(masks): - if not self._exists_in_do_not_prune_list(mask.name): - summary.scalar(mask.name + '/sparsity', nn_impl.zero_fraction(mask)) - summary.scalar(thresholds[index].op.name + '/threshold', - thresholds[index]) + for mask, threshold in zip(masks, thresholds): + summary.scalar(mask.op.name + '/sparsity', nn_impl.zero_fraction(mask)) + summary.scalar(threshold.op.name + '/threshold', threshold) def print_hparams(self): logging.info(self._spec.to_json()) diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py index f80b7c52c000f13b5ce98dd442ff21abfac37761..33c4ad58bd7f57422935fc839ddfc64d5e1f00f5 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_test.py @@ -35,8 +35,8 @@ from tensorflow.python.training import training_util class PruningHParamsTest(test.TestCase): PARAM_LIST = [ "name=test", "threshold_decay=0.9", "pruning_frequency=10", - "do_not_prune=[conv1,conv2]", "sparsity_function_end_step=100", - "target_sparsity=0.9" + "sparsity_function_end_step=100", "target_sparsity=0.9", + "weight_sparsity_map=[conv1:0.8,conv2/kernel:0.8]" ] TEST_HPARAMS = ",".join(PARAM_LIST) @@ -55,9 +55,10 @@ class PruningHParamsTest(test.TestCase): self.assertEqual(p._spec.name, "test") self.assertAlmostEqual(p._spec.threshold_decay, 0.9) self.assertEqual(p._spec.pruning_frequency, 10) - self.assertAllEqual(p._spec.do_not_prune, ["conv1", "conv2"]) self.assertEqual(p._spec.sparsity_function_end_step, 100) self.assertAlmostEqual(p._spec.target_sparsity, 0.9) + self.assertEqual(p._weight_sparsity_map["conv1"], 0.8) + self.assertEqual(p._weight_sparsity_map["conv2/kernel"], 0.8) def testInitWithExternalSparsity(self): with self.test_session(): @@ -211,6 +212,37 @@ class PruningTest(test.TestCase): expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40] self.assertAllEqual(expected_non_zero_count, non_zero_count) + def testWeightSpecificSparsity(self): + param_list = [ + "begin_pruning_step=1", "pruning_frequency=1", "end_pruning_step=100", + "target_sparsity=0.5", "weight_sparsity_map=[layer2/weights:0.75]", + "threshold_decay=0.0" + ] + test_spec = ",".join(param_list) + pruning_hparams = pruning.get_pruning_hparams().parse(test_spec) + + with variable_scope.variable_scope("layer1"): + w1 = variables.Variable( + math_ops.linspace(1.0, 100.0, 100), name="weights") + _ = pruning.apply_mask(w1) + with variable_scope.variable_scope("layer2"): + w2 = variables.Variable( + math_ops.linspace(1.0, 100.0, 100), name="weights") + _ = pruning.apply_mask(w2) + + p = pruning.Pruning(pruning_hparams) + mask_update_op = p.conditional_mask_update_op() + increment_global_step = state_ops.assign_add(self.global_step, 1) + + with self.test_session() as session: + variables.global_variables_initializer().run() + for _ in range(110): + session.run(mask_update_op) + session.run(increment_global_step) + + self.assertAllEqual( + session.run(pruning.get_weight_sparsity()), [0.5, 0.75]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/model_pruning/python/strip_pruning_vars.py b/tensorflow/contrib/model_pruning/python/strip_pruning_vars.py new file mode 100644 index 0000000000000000000000000000000000000000..3385103807f6dbdab2d27882c670a3ccf6a26e9d --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/strip_pruning_vars.py @@ -0,0 +1,103 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Removes the auxiliary variables and ops added by the pruning library. + +Usage: + +bazel build tensorflow/contrib/model_pruning:strip_pruning_vars && \ +bazel-bin/tensorflow/contrib/model_pruning/strip_pruning_vars \ +--checkpoint_dir=/tmp/model_ckpts \ +--output_node_names=softmax \ +--output_dir=/tmp \ +--filename=pruning_stripped.pb +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os +import sys + +from tensorflow.contrib.model_pruning.python import strip_pruning_vars_lib +from tensorflow.python.framework import graph_io +from tensorflow.python.platform import app +from tensorflow.python.platform import tf_logging as logging + +FLAGS = None + + +def strip_pruning_vars(checkpoint_dir, output_node_names, output_dir, filename): + """Remove pruning-related auxiliary variables and ops from the graph. + + Accepts training checkpoints and produces a GraphDef in which the pruning vars + and ops have been removed. + + Args: + checkpoint_dir: Path to the checkpoints. + output_node_names: The name of the output nodes, comma separated. + output_dir: Directory where to write the graph. + filename: Output GraphDef file name. + + Returns: + None + + Raises: + ValueError: if output_nodes_names are not provided. + """ + if not output_node_names: + raise ValueError( + 'Need to specify atleast 1 output node through output_node_names flag') + output_node_names = output_node_names.replace(' ', '').split(',') + + initial_graph_def = strip_pruning_vars_lib.graph_def_from_checkpoint( + checkpoint_dir, output_node_names) + + final_graph_def = strip_pruning_vars_lib.strip_pruning_vars_fn( + initial_graph_def, output_node_names) + graph_io.write_graph(final_graph_def, output_dir, filename, as_text=False) + logging.info('\nFinal graph written to %s', os.path.join( + output_dir, filename)) + + +def main(unused_args): + return strip_pruning_vars(FLAGS.checkpoint_dir, FLAGS.output_node_names, + FLAGS.output_dir, FLAGS.filename) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument( + '--checkpoint_dir', type=str, default='', help='Path to the checkpoints.') + parser.add_argument( + '--output_node_names', + type=str, + default='', + help='The name of the output nodes, comma separated.') + parser.add_argument( + '--output_dir', + type=str, + default='/tmp', + help='Directory where to write the graph.') + parser.add_argument( + '--filename', + type=str, + default='pruning_stripped.pb', + help='Output \'GraphDef\' file name.') + + FLAGS, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/model_pruning/python/strip_pruning_vars_lib.py b/tensorflow/contrib/model_pruning/python/strip_pruning_vars_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..fc4b10863f7c46235059f948fbbfcfcf83d3e15b --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/strip_pruning_vars_lib.py @@ -0,0 +1,142 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities to remove pruning-related ops and variables from a GraphDef. +""" + +# pylint: disable=missing-docstring +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import graph_util +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import saver as saver_lib + + +def _node_name(tensor_name): + """Remove the trailing ':0' from the variable name.""" + if ':' not in tensor_name: + return tensor_name + + return tensor_name.split(':')[0] + + +def _tensor_name(node_name): + """Appends the :0 in the op name to get the canonical tensor name.""" + if ':' in node_name: + return node_name + + return node_name + ':0' + + +def _get_masked_weights(input_graph_def): + """Extracts masked_weights from the graph as a dict of {var_name:ndarray}.""" + input_graph = ops.Graph() + with input_graph.as_default(): + importer.import_graph_def(input_graph_def, name='') + + with session.Session(graph=input_graph) as sess: + masked_weights_dict = {} + for node in input_graph_def.node: + if 'masked_weight' in node.name: + masked_weight_val = sess.run( + sess.graph.get_tensor_by_name(_tensor_name(node.name))) + logging.info( + '%s has %d values, %1.2f%% zeros \n', node.name, + np.size(masked_weight_val), + 100 - float(100 * np.count_nonzero(masked_weight_val)) / + np.size(masked_weight_val)) + masked_weights_dict.update({node.name: masked_weight_val}) + return masked_weights_dict + + +def strip_pruning_vars_fn(input_graph_def, output_node_names): + """Removes mask variable from the graph. + + Replaces the masked_weight tensor with element-wise multiplication of mask + and the corresponding weight variable. + + Args: + input_graph_def: A GraphDef in which the variables have been converted to + constants. This is typically the output of + tf.graph_util.convert_variables_to_constant() + output_node_names: List of name strings for the result nodes of the graph + + Returns: + A GraphDef in which pruning-related variables have been removed + """ + masked_weights_dict = _get_masked_weights(input_graph_def) + pruned_graph_def = graph_pb2.GraphDef() + + # Replace masked_weight with a const op containing the + # result of tf.multiply(mask,weight) + for node in input_graph_def.node: + output_node = node_def_pb2.NodeDef() + if 'masked_weight' in node.name: + output_node.op = 'Const' + output_node.name = node.name + dtype = node.attr['T'] + data = masked_weights_dict[node.name] + output_node.attr['dtype'].CopyFrom(dtype) + output_node.attr['value'].CopyFrom( + attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(data))) + + else: + output_node.CopyFrom(node) + pruned_graph_def.node.extend([output_node]) + + # Remove stranded nodes: mask and weights + return graph_util.extract_sub_graph(pruned_graph_def, output_node_names) + + +def graph_def_from_checkpoint(checkpoint_dir, output_node_names): + """Converts checkpoint data to GraphDef. + + Reads the latest checkpoint data and produces a GraphDef in which the + variables have been converted to constants. + + Args: + checkpoint_dir: Path to the checkpoints. + output_node_names: List of name strings for the result nodes of the graph. + + Returns: + A GraphDef from the latest checkpoint + + Raises: + ValueError: if no checkpoint is found + """ + checkpoint_path = saver_lib.latest_checkpoint(checkpoint_dir) + if checkpoint_path is None: + raise ValueError('Could not find a checkpoint at: {0}.' + .format(checkpoint_dir)) + + saver_for_restore = saver_lib.import_meta_graph( + checkpoint_path + '.meta', clear_devices=True) + with session.Session() as sess: + saver_for_restore.restore(sess, checkpoint_path) + graph_def = ops.get_default_graph().as_graph_def() + output_graph_def = graph_util.convert_variables_to_constants( + sess, graph_def, output_node_names) + + return output_graph_def diff --git a/tensorflow/contrib/model_pruning/python/strip_pruning_vars_test.py b/tensorflow/contrib/model_pruning/python/strip_pruning_vars_test.py new file mode 100644 index 0000000000000000000000000000000000000000..255daa036099c0d3ef2dbc5eb37fdb0c31c71383 --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/strip_pruning_vars_test.py @@ -0,0 +1,232 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for strip_pruning_vars.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re + +from tensorflow.contrib.model_pruning.python import pruning +from tensorflow.contrib.model_pruning.python import strip_pruning_vars_lib +from tensorflow.contrib.model_pruning.python.layers import layers +from tensorflow.contrib.model_pruning.python.layers import rnn_cells +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import graph_util +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell as tf_rnn_cells +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import training_util + + +def _get_number_pruning_vars(graph_def): + number_vars = 0 + for node in graph_def.node: + if re.match(r"^.*(mask$)|(threshold$)", node.name): + number_vars += 1 + return number_vars + + +def _get_node_names(tensor_names): + return [ + strip_pruning_vars_lib._node_name(tensor_name) + for tensor_name in tensor_names + ] + + +class StripPruningVarsTest(test.TestCase): + + def setUp(self): + param_list = [ + "pruning_frequency=1", "begin_pruning_step=1", "end_pruning_step=10", + "nbins=2048", "threshold_decay=0.0" + ] + self.initial_graph = ops.Graph() + self.initial_graph_def = None + self.final_graph = ops.Graph() + self.final_graph_def = None + self.pruning_spec = ",".join(param_list) + with self.initial_graph.as_default(): + self.sparsity = variables.Variable(0.5, name="sparsity") + self.global_step = training_util.get_or_create_global_step() + self.increment_global_step = state_ops.assign_add(self.global_step, 1) + self.mask_update_op = None + + def _build_convolutional_model(self, number_of_layers): + # Create a graph with several conv2d layers + kernel_size = 3 + base_depth = 4 + depth_step = 7 + height, width = 7, 9 + with variable_scope.variable_scope("conv_model"): + input_tensor = array_ops.ones((8, height, width, base_depth)) + top_layer = input_tensor + for ix in range(number_of_layers): + top_layer = layers.masked_conv2d( + top_layer, + base_depth + (ix + 1) * depth_step, + kernel_size, + scope="Conv_" + str(ix)) + + return top_layer + + def _build_fully_connected_model(self, number_of_layers): + base_depth = 4 + depth_step = 7 + + input_tensor = array_ops.ones((8, base_depth)) + + top_layer = input_tensor + + with variable_scope.variable_scope("fc_model"): + for ix in range(number_of_layers): + top_layer = layers.masked_fully_connected( + top_layer, base_depth + (ix + 1) * depth_step) + + return top_layer + + def _build_lstm_model(self, number_of_layers): + batch_size = 8 + dim = 10 + inputs = variables.Variable(random_ops.random_normal([batch_size, dim])) + + def lstm_cell(): + return rnn_cells.MaskedBasicLSTMCell( + dim, forget_bias=0.0, state_is_tuple=True, reuse=False) + + cell = tf_rnn_cells.MultiRNNCell( + [lstm_cell() for _ in range(number_of_layers)], state_is_tuple=True) + + outputs = rnn.static_rnn( + cell, [inputs], + initial_state=cell.zero_state(batch_size, dtypes.float32)) + + return outputs + + def _prune_model(self, session): + pruning_hparams = pruning.get_pruning_hparams().parse(self.pruning_spec) + p = pruning.Pruning(pruning_hparams, sparsity=self.sparsity) + self.mask_update_op = p.conditional_mask_update_op() + + variables.global_variables_initializer().run() + for _ in range(20): + session.run(self.mask_update_op) + session.run(self.increment_global_step) + + def _get_outputs(self, session, input_graph, tensors_list, graph_prefix=None): + outputs = [] + + for output_tensor in tensors_list: + if graph_prefix: + output_tensor = graph_prefix + "/" + output_tensor + outputs.append( + session.run(session.graph.get_tensor_by_name(output_tensor))) + + return outputs + + def _get_initial_outputs(self, output_tensor_names_list): + with self.test_session(graph=self.initial_graph) as sess1: + self._prune_model(sess1) + reference_outputs = self._get_outputs(sess1, self.initial_graph, + output_tensor_names_list) + + self.initial_graph_def = graph_util.convert_variables_to_constants( + sess1, sess1.graph.as_graph_def(), + _get_node_names(output_tensor_names_list)) + return reference_outputs + + def _get_final_outputs(self, output_tensor_names_list): + self.final_graph_def = strip_pruning_vars_lib.strip_pruning_vars_fn( + self.initial_graph_def, _get_node_names(output_tensor_names_list)) + _ = importer.import_graph_def(self.final_graph_def, name="final") + + with self.test_session(self.final_graph) as sess2: + final_outputs = self._get_outputs( + sess2, + self.final_graph, + output_tensor_names_list, + graph_prefix="final") + return final_outputs + + def _check_removal_of_pruning_vars(self, number_masked_layers): + self.assertEqual( + _get_number_pruning_vars(self.initial_graph_def), number_masked_layers) + self.assertEqual(_get_number_pruning_vars(self.final_graph_def), 0) + + def _check_output_equivalence(self, initial_outputs, final_outputs): + for initial_output, final_output in zip(initial_outputs, final_outputs): + self.assertAllEqual(initial_output, final_output) + + def testConvolutionalModel(self): + with self.initial_graph.as_default(): + number_masked_conv_layers = 5 + top_layer = self._build_convolutional_model(number_masked_conv_layers) + output_tensor_names = [top_layer.name] + initial_outputs = self._get_initial_outputs(output_tensor_names) + + # Remove pruning-related nodes. + with self.final_graph.as_default(): + final_outputs = self._get_final_outputs(output_tensor_names) + + # Check that the final graph has no pruning-related vars + self._check_removal_of_pruning_vars(number_masked_conv_layers) + + # Check that outputs remain the same after removal of pruning-related nodes + self._check_output_equivalence(initial_outputs, final_outputs) + + def testFullyConnectedModel(self): + with self.initial_graph.as_default(): + number_masked_fc_layers = 3 + top_layer = self._build_fully_connected_model(number_masked_fc_layers) + output_tensor_names = [top_layer.name] + initial_outputs = self._get_initial_outputs(output_tensor_names) + + # Remove pruning-related nodes. + with self.final_graph.as_default(): + final_outputs = self._get_final_outputs(output_tensor_names) + + # Check that the final graph has no pruning-related vars + self._check_removal_of_pruning_vars(number_masked_fc_layers) + + # Check that outputs remain the same after removal of pruning-related nodes + self._check_output_equivalence(initial_outputs, final_outputs) + + def testLSTMModel(self): + with self.initial_graph.as_default(): + number_masked_lstm_layers = 2 + outputs = self._build_lstm_model(number_masked_lstm_layers) + output_tensor_names = [outputs[0][0].name] + initial_outputs = self._get_initial_outputs(output_tensor_names) + + # Remove pruning-related nodes. + with self.final_graph.as_default(): + final_outputs = self._get_final_outputs(output_tensor_names) + + # Check that the final graph has no pruning-related vars + self._check_removal_of_pruning_vars(number_masked_lstm_layers) + + # Check that outputs remain the same after removal of pruning-related nodes + self._check_output_equivalence(initial_outputs, final_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/mpi_collectives/BUILD b/tensorflow/contrib/mpi_collectives/BUILD index a7be92a35e0d62a61f7923ac61bb2c1267d039c6..ecac06354d2ce796f2a6021cdf2370d7c30ccab7 100644 --- a/tensorflow/contrib/mpi_collectives/BUILD +++ b/tensorflow/contrib/mpi_collectives/BUILD @@ -52,6 +52,7 @@ tf_custom_op_library( deps = [ ":mpi_defines", ":mpi_message_proto_cc", + "//tensorflow/stream_executor:stream_executor_headers_lib", "//third_party/mpi", ], ) diff --git a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc index ed22ee667f1d73b3f86f77e09bad9bfec7e46391..e4b0c2c6541836243347d2950686c60ef06d2bfc 100644 --- a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc +++ b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc @@ -73,7 +73,7 @@ limitations under the License. */ template -using StatusOr = se::port::StatusOr; +using StatusOr = stream_executor::port::StatusOr; using CPUDevice = Eigen::ThreadPoolDevice; using GPUDevice = Eigen::GpuDevice; diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops.py b/tensorflow/contrib/mpi_collectives/mpi_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..bd7096d9cee2d32bde5227a95038ae65cd8a6e18 --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/mpi_ops.py @@ -0,0 +1,163 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Inter-process communication using MPI.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.python.framework import errors +from tensorflow.python.framework import load_library +from tensorflow.python.framework import ops +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import tf_logging as logging + + +def _load_library(name, op_list=None): + """Loads a .so file containing the specified operators. + + Args: + name: The name of the .so file to load. + op_list: A list of names of operators that the library should have. If None + then the .so file's contents will not be verified. + + Raises: + NameError if one of the required ops is missing. + """ + try: + filename = resource_loader.get_path_to_datafile(name) + library = load_library.load_op_library(filename) + for expected_op in (op_list or []): + for lib_op in library.OP_LIST.op: + if lib_op.name == expected_op: + break + else: + raise NameError('Could not find operator %s in dynamic library %s' % + (expected_op, name)) + return library + except errors.NotFoundError: + logging.warning('%s file could not be loaded.', name) + + +MPI_LIB = _load_library( + 'mpi_collectives.so', + ['MPISize', 'MPIRank', 'MPILocalRank', 'MPIAllgather', 'MPIAllreduce']) + + +def size(name=None): + """An op which returns the number of MPI processes. + + This is equivalent to running `MPI_Comm_size(MPI_COMM_WORLD, ...)` to get the + size of the global communicator. + + Returns: + An integer scalar containing the number of MPI processes. + """ + return MPI_LIB.mpi_size(name=name) + + +ops.NotDifferentiable('MPISize') + + +def rank(name=None): + """An op which returns the MPI rank of the calling process. + + This is equivalent to running `MPI_Comm_rank(MPI_COMM_WORLD, ...)` to get the + rank of the current process in the global communicator. + + Returns: + An integer scalar with the MPI rank of the calling process. + """ + return MPI_LIB.mpi_rank(name=name) + + +ops.NotDifferentiable('MPIRank') + + +def init(name=None): + """An op which initializes MPI on the device on which it is run. + + All future MPI ops must be run on the same device that the `init` op was run + on. + """ + return MPI_LIB.mpi_init(name=name) + + +ops.NotDifferentiable('MPIInit') + + +def local_rank(name=None): + """An op which returns the local MPI rank of the calling process, within the + node that it is running on. For example, if there are seven processes running + on a node, their local ranks will be zero through six, inclusive. + + This is equivalent to running `MPI_Comm_rank(...)` on a new communicator + which only includes processes on the same node. + + Returns: + An integer scalar with the local MPI rank of the calling process. + """ + return MPI_LIB.mpi_local_rank(name=name) + + +ops.NotDifferentiable('MPILocalRank') + + +def _allreduce(tensor, name=None): + """An op which sums an input tensor over all the MPI processes. + + The reduction operation is keyed by the name of the op. The tensor type and + shape must be the same on all MPI processes for a given name. The reduction + will not start until all processes are ready to send and receive the tensor. + + Returns: + A tensor of the same shape and type as `tensor`, summed across all + processes. + """ + return MPI_LIB.mpi_allreduce(tensor, name=name) + + +ops.NotDifferentiable('MPIAllreduce') + + +def allgather(tensor, name=None): + """An op which concatenates the input tensor with the same input tensor on + all other MPI processes. + + The concatenation is done on the first dimension, so the input tensors on the + different processes must have the same rank and shape, except for the first + dimension, which is allowed to be different. + + Returns: + A tensor of the same type as `tensor`, concatenated on dimension zero + across all processes. The shape is identical to the input shape, except for + the first dimension, which may be greater and is the sum of all first + dimensions of the tensors in different MPI processes. + """ + # Specify that first allgather is to collect the tensor gather sizes, + # indicated by passing in a scalar (0-D tensor) of value 0 + sizes_flag = tf.constant(0, dtype=tf.int64, name='size_flag_const') + my_size = tf.slice( + tf.shape(tensor, out_type=tf.int64), [0], [1], name='size_slice') + if name is None: + name = 'allgather' + sizing_name = '{}_sizing'.format(name) + sizes = MPI_LIB.mpi_allgather(my_size, sizes_flag, name=sizing_name) + return MPI_LIB.mpi_allgather(tensor, sizes, name=name) + + +ops.NotDifferentiable('MPIAllgather') diff --git a/tensorflow/contrib/mpi_collectives/ring.cc b/tensorflow/contrib/mpi_collectives/ring.cc new file mode 100644 index 0000000000000000000000000000000000000000..d93233eb210b80df10fd9c2c7975ce77112d18a2 --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/ring.cc @@ -0,0 +1,80 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef TENSORFLOW_USE_MPI + +#define EIGEN_USE_THREADS + +#include "tensorflow/contrib/mpi_collectives/ring.h" + +namespace tensorflow { +namespace contrib { +namespace mpi { + +using CPUDevice = Eigen::ThreadPoolDevice; + +extern template MPI_Datatype MPIType(); +extern template MPI_Datatype MPIType(); +extern template MPI_Datatype MPIType(); +extern template DataType TensorFlowDataType(); +extern template DataType TensorFlowDataType(); +extern template DataType TensorFlowDataType(); + +// Generate all necessary specializations for RingAllreduce. +template Status RingAllreduce(OpKernelContext*, const Tensor*, + Tensor*, Tensor*); +template Status RingAllreduce(OpKernelContext*, + const Tensor*, Tensor*, + Tensor*); +template Status RingAllreduce(OpKernelContext*, const Tensor*, + Tensor*, Tensor*); + +// Generate all necessary specializations for RingAllgather. +template Status RingAllgather(OpKernelContext*, const Tensor*, + const std::vector&, + Tensor*); +template Status RingAllgather(OpKernelContext*, + const Tensor*, + const std::vector&, + Tensor*); +template Status RingAllgather(OpKernelContext*, const Tensor*, + const std::vector&, + Tensor*); + +// Copy data on a CPU using a straight-forward memcpy. +template <> +void CopyTensorData(void* dst, void* src, size_t size) { + std::memcpy(dst, src, size); +}; + +// Accumulate values on a CPU. +#define GENERATE_ACCUMULATE(type) \ + template <> \ + void AccumulateTensorData(type * dst, type * src, \ + size_t size) { \ + for (unsigned int i = 0; i < size; i++) { \ + dst[i] += src[i]; \ + } \ + }; +GENERATE_ACCUMULATE(int); +GENERATE_ACCUMULATE(long long); +GENERATE_ACCUMULATE(float); +#undef GENERATE_ACCUMULATE + +} // namespace mpi +} // namespace contrib +} // namespace tensorflow + +#endif // TENSORFLOW_USE_MPI diff --git a/tensorflow/contrib/mpi_collectives/ring.cu.cc b/tensorflow/contrib/mpi_collectives/ring.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..2f3eef366a9a3c10e59cd5298fc1626e1094dff8 --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/ring.cu.cc @@ -0,0 +1,117 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef TENSORFLOW_USE_MPI + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/contrib/mpi_collectives/ring.h" + +namespace tensorflow { +namespace contrib { +namespace mpi { + +using CPUDevice = Eigen::ThreadPoolDevice; + +template <> +MPI_Datatype MPIType() { + return MPI_FLOAT; +}; +template <> +MPI_Datatype MPIType() { + return MPI_INT; +}; +template <> +MPI_Datatype MPIType() { + return MPI_LONG_LONG; +}; + +template <> +DataType TensorFlowDataType() { + return DT_FLOAT; +}; +template <> +DataType TensorFlowDataType() { + return DT_INT32; +}; +template <> +DataType TensorFlowDataType() { + return DT_INT64; +}; + +// Generate all necessary specializations for RingAllreduce. +template Status RingAllreduce(OpKernelContext*, const Tensor*, + Tensor*, Tensor*); +template Status RingAllreduce(OpKernelContext*, + const Tensor*, Tensor*, + Tensor*); +template Status RingAllreduce(OpKernelContext*, const Tensor*, + Tensor*, Tensor*); + +// Generate all necessary specializations for RingAllgather. +template Status RingAllgather(OpKernelContext*, const Tensor*, + const std::vector&, + Tensor*); +template Status RingAllgather(OpKernelContext*, + const Tensor*, + const std::vector&, + Tensor*); +template Status RingAllgather(OpKernelContext*, const Tensor*, + const std::vector&, + Tensor*); + +// Synchronously copy data on the GPU, using a different stream than the default +// and than TensorFlow to avoid synchronizing on operations unrelated to the +// allreduce. +template <> +void CopyTensorData(void* dst, void* src, size_t size) { + auto stream = CudaStreamForMPI(); + cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream); + cudaStreamSynchronize(stream); +}; + +// Elementwise accumulation kernel for GPU. +template +__global__ void elemwise_accum(T* out, const T* in, const size_t N) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + out[i] += in[i]; + } +} + +// Synchronously accumulate tensors on the GPU, using a different stream than +// the default and than TensorFlow to avoid synchronizing on operations +// unrelated to the allreduce. +#define GENERATE_ACCUMULATE(type) \ + template <> \ + void AccumulateTensorData(type * dst, type * src, \ + size_t size) { \ + auto stream = CudaStreamForMPI(); \ + elemwise_accum<<<32, 256, 0, stream>>>(dst, src, size); \ + cudaStreamSynchronize(stream); \ + }; +GENERATE_ACCUMULATE(int); +GENERATE_ACCUMULATE(long long); +GENERATE_ACCUMULATE(float); +#undef GENERATE_ACCUMULATE + +} // namespace mpi +} // namespace contrib +} // namespace tensorflow +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_USE_MPI diff --git a/tensorflow/contrib/mpi_collectives/ring.h b/tensorflow/contrib/mpi_collectives/ring.h new file mode 100644 index 0000000000000000000000000000000000000000..cae57ce60eb09509af69f8ccab9eacedea361548 --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/ring.h @@ -0,0 +1,327 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_MPI_H_ +#define TENSORFLOW_CONTRIB_MPI_H_ + +#ifdef TENSORFLOW_USE_MPI + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" + +#if GOOGLE_CUDA +#include "cuda_runtime.h" +#endif + +// Needed to avoid header issues with C++-supporting MPI implementations +#define OMPI_SKIP_MPICXX +#include "third_party/mpi/mpi.h" + +#define TAG_TENSOR 12 + +namespace tensorflow { +namespace contrib { +namespace mpi { + +using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; + +// Convert from templated types to values we can pass to MPI. +template +MPI_Datatype MPIType(); + +// Convert from templated types to TensorFlow data types. +template +DataType TensorFlowDataType(); + +#define MPI_REQUIRES_OK(MPI_STATUS) \ + if ((MPI_STATUS) != MPI_SUCCESS) { \ + return errors::Unknown("MPI operation failed unexpectedly."); \ + } + +// Copy data from one tensor to another tensor. +// This uses a custom CUDA stream on GPU, which is necessary to overlay the +// backpropagation computations with the allreduce. +template +void CopyTensorData(void* destination, void* source, size_t size); + +// Add a tensor into another tensor, accumulating in place. +// This uses a custom CUDA stream on GPU, which is necessary to overlay the +// backpropagation computations with the allreduce. +template +void AccumulateTensorData(T* destination, T* source, size_t size); + +// We need to get the right stream for doing CUDA memory transfers and +// operations, which is possibly different from the standard TensorFlow stream. +#if GOOGLE_CUDA +cudaStream_t CudaStreamForMPI(); +#endif + +/* Perform a ring allreduce on the data. Allocate the necessary output tensor + * and store it in the output parameter. + * + * Assumes that all MPI processes are doing an allreduce of the same tensor, + * with the same dimensions. + * + * A ring allreduce is a bandwidth-optimal way to do an allreduce. To do the + * allreduce, the nodes involved are arranged in a ring: + * + * .--0--. + * / \ + * 3 1 + * \ / + * *--2--* + * + * Each node always sends to the next clockwise node in the ring, and receives + * from the previous one. + * + * The allreduce is done in two parts: a scatter-reduce and an allgather. In + * the scatter reduce, a reduction is done, so that each node ends up with a + * chunk of the final output tensor which has contributions from all other + * nodes. In the allgather, those chunks are distributed among all the nodes, + * so that all nodes have the entire output tensor. + * + * Both of these operations are done by dividing the input tensor into N + * evenly sized chunks (where N is the number of nodes in the ring). + * + * The scatter-reduce is done in N-1 steps. In the ith step, node j will send + * the (j - i)th chunk and receive the (j - i - 1)th chunk, adding it in to + * its existing data for that chunk. For example, in the first iteration with + * the ring depicted above, you will have the following transfers: + * + * Segment 0: Node 0 --> Node 1 + * Segment 1: Node 1 --> Node 2 + * Segment 2: Node 2 --> Node 3 + * Segment 3: Node 3 --> Node 0 + * + * In the second iteration, you'll have the following transfers: + * + * Segment 0: Node 1 --> Node 2 + * Segment 1: Node 2 --> Node 3 + * Segment 2: Node 3 --> Node 0 + * Segment 3: Node 0 --> Node 1 + * + * After this iteration, Node 2 has 3 of the four contributions to Segment 0. + * The last iteration has the following transfers: + * + * Segment 0: Node 2 --> Node 3 + * Segment 1: Node 3 --> Node 0 + * Segment 2: Node 0 --> Node 1 + * Segment 3: Node 1 --> Node 2 + * + * After this iteration, Node 3 has the fully accumulated Segment 0; Node 0 + * has the fully accumulated Segment 1; and so on. The scatter-reduce is + * complete. + * + * Next, the allgather distributes these fully accumululated chunks across all + * nodes. Communication proceeds in the same ring, once again in N-1 steps. At + * the ith step, node j will send chunk (j - i + 1) and receive chunk (j - i). + * For example, at the first iteration, the following transfers will occur: + * + * Segment 0: Node 3 --> Node 0 + * Segment 1: Node 0 --> Node 1 + * Segment 2: Node 1 --> Node 2 + * Segment 3: Node 2 --> Node 3 + * + * After the first iteration, Node 0 will have a fully accumulated Segment 0 + * (from Node 3) and Segment 1. In the next iteration, Node 0 will send its + * just-received Segment 0 onward to Node 1, and receive Segment 3 from Node 3. + * After this has continued for N - 1 iterations, all nodes will have a the + * fully accumulated tensor. + * + * Each node will do (N-1) sends for the scatter-reduce and (N-1) sends for the + * allgather. Each send will contain K / N bytes, if there are K bytes in the + * original tensor on every node. Thus, each node sends and receives 2K(N - 1)/N + * bytes of data, and the performance of the allreduce (assuming no latency in + * connections) is constrained by the slowest interconnect between the nodes. + * + */ +template +Status RingAllreduce(OpKernelContext* context, const Tensor* input, + Tensor* temp, Tensor* output) { + // Acquire MPI size and rank + int n, r; + MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n)); + MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r)); + + T* buffer = (T*)output->tensor_data().data(); + + CopyTensorData((void*)buffer, (void*)input->tensor_data().data(), + output->tensor_data().size()); + + // Calculate segment sizes and segment ends + const size_t elements_to_reduce = input->NumElements(); + const size_t segment_size = elements_to_reduce / n; + std::vector segment_sizes(n, segment_size); + + const size_t residual = elements_to_reduce % n; + for (size_t i = 0; i < residual; ++i) { + segment_sizes[i]++; + } + + std::vector segment_starts(n); + segment_starts[0] = 0; + for (size_t i = 1; i < segment_starts.size(); ++i) { + segment_starts[i] = segment_starts[i - 1] + segment_sizes[i - 1]; + } + + assert(segment_starts[n - 1] + segment_sizes[n - 1] == elements_to_reduce); + + T* segment_recv = (T*)temp->tensor_data().data(); + + // Receive from your left neighbor with wrap-around + const size_t recv_from = ((r - 1) + n) % n; + + // Send to your right neighbor with wrap-around + const size_t send_to = (r + 1) % n; + + MPI_Status recv_status; + MPI_Request recv_req; + + // Now start ring. At every step, for every rank, we iterate through + // segments with wraparound and send and recv from our neighbors and reduce + // locally. At the i'th iteration, rank r, sends segment (r-i) and receives + // segment (r-i-1). + for (int i = 0; i < n - 1; i++) { + const size_t send_seg_id = ((r - i) + n) % n; + const size_t recv_seg_id = ((r - i - 1) + n) % n; + + T* segment_send = &(buffer[segment_starts[send_seg_id]]); + + MPI_REQUIRES_OK(MPI_Irecv(segment_recv, segment_sizes[recv_seg_id], + MPIType(), recv_from, TAG_TENSOR, + MPI_COMM_WORLD, &recv_req)); + + MPI_REQUIRES_OK(MPI_Send(segment_send, segment_sizes[send_seg_id], + MPIType(), send_to, TAG_TENSOR, + MPI_COMM_WORLD)); + + T* segment_update = &(buffer[segment_starts[recv_seg_id]]); + + // Wait for recv to complete before reduction + MPI_REQUIRES_OK(MPI_Wait(&recv_req, &recv_status)); + + const size_t recv_seg_size = segment_sizes[recv_seg_id]; + AccumulateTensorData(segment_update, segment_recv, + recv_seg_size); + } + + // Now start pipelined ring allgather. At every step, for every rank, we + // iterate through segments with wraparound and send and recv from our + // neighbors. At the i'th iteration, rank r, sends segment (r-i+1) and + // receives segment (r-i). + for (size_t i = 0; i < n - 1; ++i) { + const size_t send_seg_id = ((r - i + 1) + n) % n; + const size_t recv_seg_id = ((r - i) + n) % n; + + // Segment to send - at every iteration we send segment (r-i+1) + T* segment_send = &(buffer[segment_starts[send_seg_id]]); + + // Segment to recv - at every iteration we receive segment (r-i) + T* segment_recv = &(buffer[segment_starts[recv_seg_id]]); + + MPI_REQUIRES_OK(MPI_Sendrecv( + segment_send, segment_sizes[send_seg_id], MPIType(), send_to, + TAG_TENSOR, segment_recv, segment_sizes[recv_seg_id], MPIType(), + recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status)); + } + + return Status::OK(); +} + +// Perform a ring allgather on a Tensor. Other ranks may allgather with a +// tensor which differs in the first dimension only; all other dimensions must +// be the same. +// +// For more information on the ring allgather, read the documentation for the +// ring allreduce, which includes a ring allgather. +template +Status RingAllgather(OpKernelContext* context, const Tensor* input, + const std::vector& sizes, Tensor* output) { + // Acquire MPI size and rank + int n, r; + MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n)); + MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r)); + + assert(sizes.size() == n); + assert(input->dim_size(0) == sizes[r]); + + // Compute number of elements in every "row". We can't compute number of + // elements in every chunks, because those chunks are variable length. + size_t elements_per_row = 1; + for (int i = 1; i < input->shape().dims(); i++) { + elements_per_row *= input->dim_size(i); + } + + // Copy data from input tensor to correct place in output tensor. + std::vector segment_starts(n); + segment_starts[0] = 0; + for (int i = 1; i < n; i++) { + segment_starts[i] = segment_starts[i - 1] + elements_per_row * sizes[i - 1]; + } + size_t offset = segment_starts[r]; + + // Copy data to the right offset for this rank. + T* buffer = (T*)output->tensor_data().data(); + CopyTensorData((void*)(buffer + offset), + (void*)input->tensor_data().data(), + elements_per_row * sizes[r] * sizeof(T)); + + // Receive from your left neighbor with wrap-around + const size_t recv_from = ((r - 1) + n) % n; + + // Send to your right neighbor with wrap-around + const size_t send_to = (r + 1) % n; + + // Perform a ring allgather. At every step, for every rank, we iterate + // through segments with wraparound and send and recv from our neighbors. + // At the i'th iteration, rank r, sends segment (r-i) and receives segment + // (r-1-i). + MPI_Status recv_status; + for (size_t i = 0; i < n - 1; ++i) { + const size_t send_seg_id = ((r - i) + n) % n; + const size_t recv_seg_id = ((r - i - 1) + n) % n; + + // Segment to send - at every iteration we send segment (r-i) + size_t offset_send = segment_starts[send_seg_id]; + size_t rows_send = sizes[send_seg_id]; + T* segment_send = &(buffer[offset_send]); + + // Segment to recv - at every iteration we receive segment (r-1-i) + size_t offset_recv = segment_starts[recv_seg_id]; + size_t rows_recv = sizes[recv_seg_id]; + T* segment_recv = &(buffer[offset_recv]); + + MPI_REQUIRES_OK(MPI_Sendrecv( + segment_send, elements_per_row * rows_send, MPIType(), send_to, + TAG_TENSOR, segment_recv, elements_per_row * rows_recv, MPIType(), + recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status)); + } + + return Status::OK(); +} + +} // namespace mpi +} // namespace contrib +} // namespace tensorflow + +#endif // TENSORFLOW_USE_MPI + +#undef TENSORFLOW_CONTRIB_MPI_H_ +#endif // TENSORFLOW_CONTRIB_MPI_H_ diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD index 334e70318dd88185cecd93ebeb2587861b7999b9..62996d1fd83f46145e9a1b773b1be57e27903127 100644 --- a/tensorflow/contrib/nccl/BUILD +++ b/tensorflow/contrib/nccl/BUILD @@ -19,17 +19,18 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow:tensorflow.bzl", "if_not_windows_cuda") tf_custom_op_library( name = "python/ops/_nccl_ops.so", srcs = [ "ops/nccl_ops.cc", ], - gpu_srcs = [ + gpu_srcs = if_not_windows_cuda([ "kernels/nccl_manager.cc", "kernels/nccl_manager.h", "kernels/nccl_ops.cc", - ], + ]), deps = if_cuda([ "@local_config_nccl//:nccl", "//tensorflow/core:gpu_headers_lib", @@ -97,18 +98,19 @@ tf_gen_op_wrapper_py( deps = [":nccl_ops_op_lib"], ) +# Test only nccl ops lib without dso to test behavior when NCCL lib is not +# installed. See nccl_dependency_test for more details. +# +# Users should use the public nccl_py lib that also adds the dso. tf_custom_op_py_library( - name = "nccl_py", + name = "nccl_ops_lib_without_dso", srcs = [ "__init__.py", "python/ops/nccl_ops.py", ], - dso = [":python/ops/_nccl_ops.so"], kernels = if_cuda([":nccl_kernels"]) + [ ":nccl_ops_op_lib", ], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], deps = [ ":nccl_ops", "//tensorflow/contrib/util:util_py", @@ -120,6 +122,15 @@ tf_custom_op_py_library( ], ) +tf_custom_op_py_library( + name = "nccl_py", + dso = [":python/ops/_nccl_ops.so"], + visibility = ["//visibility:public"], + deps = [ + ":nccl_ops_lib_without_dso", + ], +) + cuda_py_test( name = "nccl_ops_test", size = "small", @@ -141,3 +152,25 @@ cuda_py_test( "notap", ], ) + +cuda_py_test( + name = "nccl_dependency_test", + size = "small", + srcs = ["python/ops/nccl_dependency_test.py"], + additional_deps = [ + ":nccl_ops_lib_without_dso", + "//tensorflow/python:constant_op", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform_test", + ], + # Disable this test internally as static linking is used internally and only + # run for OSS to verify that NCCL is an optional dynamic dependency. + tags = [ + "manual", + "noguitar", + "notap", + ], +) diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.cc b/tensorflow/contrib/nccl/kernels/nccl_manager.cc index b1cb89391ceaa70813be47cc1bba0c16f4f70e77..99fecf96517935bf3bde3636df83b4a9a4e1c779 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_manager.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_manager.cc @@ -445,7 +445,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) { se::Stream* comm_stream = nccl_stream->stream.get(); ScopedActivateExecutorContext scoped_context(nccl_stream->executor); const cudaStream_t* cu_stream = reinterpret_cast( - comm_stream->implementation()->CudaStreamMemberHack()); + comm_stream->implementation()->GpuStreamMemberHack()); while (true) { // Find collective to run. diff --git a/tensorflow/contrib/nccl/python/ops/nccl_dependency_test.py b/tensorflow/contrib/nccl/python/ops/nccl_dependency_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c766080dbee7c9a6f4383ef6fa8cade7bba158af --- /dev/null +++ b/tensorflow/contrib/nccl/python/ops/nccl_dependency_test.py @@ -0,0 +1,59 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Dependency test for nccl to test behavior when NCCL is not installed.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import nccl +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.util import tf_inspect + + +class NcclDependencyTest(test.TestCase): + """Verifies that importing nccl ops lib does not fail even if NCCL is not + installed but nccl ops throws an exception on use if NCCL is not installed. + """ + + def test_nccl_ops(self): + """Tests behavior of nccl ops when NCCL is not installed.""" + + public_methods = [ + m[0] + for m in tf_inspect.getmembers(nccl, tf_inspect.isfunction) + if not m[0].startswith('_') + ] + for method_name in public_methods: + with ops.device('/device:CPU:0'): + tensor = constant_op.constant(1) + + if method_name == 'broadcast': + arg = tensor + else: + arg = [tensor] + + nccl_op = getattr(nccl, method_name) + with ops.device('/device:CPU:0'): + with self.assertRaisesRegexp(errors_impl.NotFoundError, + r'cannot open shared object file'): + nccl_op(arg) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops.py b/tensorflow/contrib/nccl/python/ops/nccl_ops.py index 794372a1f4b0dcc41bcf0da611f5bc2ec9301973..fa597cf3efcf915311047f3a483772c45cc314fd 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops.py @@ -26,8 +26,10 @@ from tensorflow.python.framework import device from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader -_nccl_ops_so = loader.load_op_library( - resource_loader.get_path_to_datafile('_nccl_ops.so')) + +_nccl_ops_so = None +_module_lock = threading.Lock() +_shared_name_counter = 0 def all_sum(tensors): @@ -61,12 +63,12 @@ def _all_sum_grad(op, grad): Raises: LookupError: If `reduction` is not `sum`. """ - if op.get_attr('reduction') != 'sum': + if op.get_attr('reduction') != b'sum': raise LookupError('No gradient defined for NcclAllReduce except sum.') _check_device(grad, expected=op.device) num_devices = op.get_attr('num_devices') - shared_name = op.get_attr('shared_name') + '_grad' + shared_name = op.get_attr('shared_name') + b'_grad' with ops.device(op.device): return gen_nccl_ops.nccl_all_reduce( @@ -160,7 +162,7 @@ def _reduce_sum_grad(op, grad): Raises: LookupError: If the reduction attribute of op is not `sum`. """ - if op.get_attr('reduction') != 'sum': + if op.get_attr('reduction') != b'sum': raise LookupError('No gradient defined for NcclReduce except sum.') _check_device(grad, expected=op.device) @@ -180,7 +182,7 @@ def broadcast(tensor): A tensor with the value of `src_tensor`, which can be used as input to ops on other GPU devices. """ - _check_graph_mode() + _validate_and_load_nccl_so() _check_device(tensor) with ops.device(tensor.device): @@ -212,7 +214,7 @@ def _apply_all_reduce(reduction, tensors): """Helper function for all_* functions.""" if not tensors: raise ValueError('Must pass >0 tensors to all reduce operations') - _check_graph_mode() + _validate_and_load_nccl_so() shared_name = _get_shared_name() res = [] @@ -234,7 +236,7 @@ def _apply_reduce(reduction, tensors): """Helper function for reduce_* functions.""" if not tensors: raise ValueError('Must pass >0 tensors to reduce operations') - _check_graph_mode() + _validate_and_load_nccl_so() for t in tensors: _check_device(t) @@ -246,14 +248,10 @@ def _apply_reduce(reduction, tensors): return result -_lock = threading.Lock() -_shared_name_counter = 0 - - def _get_shared_name(): global _shared_name_counter - with _lock: + with _module_lock: val = _shared_name_counter _shared_name_counter += 1 return 'c%s' % val @@ -266,6 +264,25 @@ def _check_device(tensor, expected=None): raise ValueError('Expected device %s, got %s' % (expected, tensor.device)) -def _check_graph_mode(): +def _maybe_load_nccl_ops_so(): + """Loads nccl ops so if it hasn't been loaded already.""" + + with _module_lock: + global _nccl_ops_so + if not _nccl_ops_so: + _nccl_ops_so = loader.load_op_library( + resource_loader.get_path_to_datafile('_nccl_ops.so')) + + +def _validate_and_load_nccl_so(): + """Validates calling context and loads nccl ops so file. + + Raises: + ValueError: Ops are not supported. + errors_impl.NotFoundError: nccl library is not installed. + """ + if context.executing_eagerly(): raise ValueError('Nccl ops are not supported in eager mode') + + _maybe_load_nccl_ops_so() diff --git a/tensorflow/contrib/nn/python/ops/alpha_dropout.py b/tensorflow/contrib/nn/python/ops/alpha_dropout.py index 2f92d05ba81f30a91f68f3c3ec51b6695d3d0371..98f4264fe0813d421f559594efae73608e53ca62 100644 --- a/tensorflow/contrib/nn/python/ops/alpha_dropout.py +++ b/tensorflow/contrib/nn/python/ops/alpha_dropout.py @@ -43,7 +43,7 @@ def alpha_dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylin noise_shape: A 1-D `Tensor` of type `int32`, representing the shape for randomly generated keep/drop flags. seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} for behavior. + `tf.set_random_seed` for behavior. name: A name for this operation (optional). Returns: diff --git a/tensorflow/contrib/nn/python/ops/sampling_ops.py b/tensorflow/contrib/nn/python/ops/sampling_ops.py index e65925610c5f5125c2d2e92edc1cf708c54255d4..de71b0845e292b3ee03848afc6cc05c15286d9e8 100644 --- a/tensorflow/contrib/nn/python/ops/sampling_ops.py +++ b/tensorflow/contrib/nn/python/ops/sampling_ops.py @@ -123,15 +123,15 @@ def rank_sampled_softmax_loss(weights, """Computes softmax loss using rank-based adaptive resampling. This has been shown to improve rank loss after training compared to - @{tf.nn.sampled_softmax_loss}. For a description of the algorithm and some + `tf.nn.sampled_softmax_loss`. For a description of the algorithm and some experimental results, please see: [TAPAS: Two-pass Approximate Adaptive Sampling for Softmax](https://arxiv.org/abs/1707.03073). Sampling follows two phases: * In the first phase, `num_sampled` classes are selected using - @{tf.nn.learned_unigram_candidate_sampler} or supplied `sampled_values`. + `tf.nn.learned_unigram_candidate_sampler` or supplied `sampled_values`. The logits are calculated on those sampled classes. This phases is - similar to @{tf.nn.sampled_softmax_loss}. + similar to `tf.nn.sampled_softmax_loss`. * In the second phase, the `num_resampled` classes with highest predicted probability are kept. Probabilities are `LogSumExp(logits / resampling_temperature)`, where the sum is over @@ -142,7 +142,7 @@ def rank_sampled_softmax_loss(weights, picks more candidates close to the predicted classes. A common strategy is to decrease the temperature as training proceeds. - See @{tf.nn.sampled_softmax_loss} for more documentation on sampling and + See `tf.nn.sampled_softmax_loss` for more documentation on sampling and for typical default values for some of the parameters. This operation is for training only. It is generally an underestimate of @@ -197,7 +197,7 @@ def rank_sampled_softmax_loss(weights, where a sampled class equals one of the target classes. partition_strategy: A string specifying the partitioning strategy, relevant if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported. - See @{tf.nn.embedding_lookup} for more details. + See `tf.nn.embedding_lookup` for more details. name: A name for the operation (optional). Returns: diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 114b344d38413208755a47f36f45badc1a5ecaa9..778b710d78a2095b8a1315018641c67419c26b98 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -19,6 +19,7 @@ py_library( "python/training/drop_stale_gradient_optimizer.py", "python/training/elastic_average_optimizer.py", "python/training/external_optimizer.py", + "python/training/ggt.py", "python/training/lazy_adam_optimizer.py", "python/training/model_average_optimizer.py", "python/training/moving_average_optimizer.py", @@ -26,18 +27,22 @@ py_library( "python/training/nadam_optimizer.py", "python/training/powersign.py", "python/training/reg_adagrad_optimizer.py", + "python/training/shampoo.py", "python/training/sign_decay.py", "python/training/variable_clipping_optimizer.py", "python/training/weight_decay_optimizers.py", ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/optimizer_v2:optimizer_v2_py", "//tensorflow/python:array_ops", "//tensorflow/python:clip_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_ops", "//tensorflow/python:gradients", "//tensorflow/python:init_ops", + "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:state_ops", @@ -322,3 +327,41 @@ py_test( "//third_party/py/numpy", ], ) + +py_test( + name = "ggt_test", + srcs = ["python/training/ggt_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "shampoo_test", + size = "large", + srcs = ["python/training/shampoo_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index 5df5d35f8e4f8fcc2c5aa09bd8f3254e16e3a74f..9471fb018162ee377e9c614d6e4d745b4282165a 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -22,16 +22,18 @@ from __future__ import print_function from tensorflow.contrib.opt.python.training.adamax import * from tensorflow.contrib.opt.python.training.addsign import * from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import * +from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * from tensorflow.contrib.opt.python.training.external_optimizer import * +from tensorflow.contrib.opt.python.training.ggt import * from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import * +from tensorflow.contrib.opt.python.training.model_average_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * from tensorflow.contrib.opt.python.training.nadam_optimizer import * +from tensorflow.contrib.opt.python.training.shampoo import * from tensorflow.contrib.opt.python.training.weight_decay_optimizers import * from tensorflow.contrib.opt.python.training.powersign import * from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import * -from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * -from tensorflow.contrib.opt.python.training.model_average_optimizer import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented @@ -58,7 +60,9 @@ _allowed_symbols = [ 'ElasticAverageOptimizer', 'ElasticAverageCustomGetter', 'ModelAverageOptimizer', - 'ModelAverageCustomGetter' + 'ModelAverageCustomGetter', + 'GGTOptimizer', + 'ShampooOptimizer', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/opt/python/training/addsign_test.py b/tensorflow/contrib/opt/python/training/addsign_test.py index 08d45ed73f3ae4b580d7078272e79fef22ef67c5..628a735e721d2f0c594dd59b5193499dfd7da02e 100644 --- a/tensorflow/contrib/opt/python/training/addsign_test.py +++ b/tensorflow/contrib/opt/python/training/addsign_test.py @@ -214,7 +214,7 @@ class AddSignTest(test.TestCase): # Run 7 steps of AddSign # first 4 steps with positive gradient # last 3 steps with negative gradient (sign(gm) should be -1) - for t in range(1, 4): + for t in range(1, 8): if t < 5: update.run() else: @@ -222,7 +222,7 @@ class AddSignTest(test.TestCase): var0_np, m0 = addsign_update_numpy( var0_np, - grads0_np, + grads0_np if t < 5 else -grads0_np, m0, learning_rate, alpha=alpha, @@ -232,7 +232,7 @@ class AddSignTest(test.TestCase): ) var1_np, m1 = addsign_update_numpy( var1_np, - grads1_np, + grads1_np if t < 5 else -grads1_np, m1, learning_rate, alpha=alpha, diff --git a/tensorflow/contrib/opt/python/training/ggt.py b/tensorflow/contrib/opt/python/training/ggt.py new file mode 100644 index 0000000000000000000000000000000000000000..cae952d8f50acbc3a176697fb3989db6c9ac3e9b --- /dev/null +++ b/tensorflow/contrib/opt/python/training/ggt.py @@ -0,0 +1,312 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""GGT for Tensorflow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import numpy as np +from tensorflow.contrib.optimizer_v2 import optimizer_v2 +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops + + +class GGTOptimizer(optimizer_v2.OptimizerV2): + """Optimizer that implements the GGT algorithm. + + GGT has an advantage over sgd and adam on large models with poor conditioning, + for example language models and CNNs, + see [[ABCHSZZ 2018]](https://arxiv.org/pdf/1806.02958.pdf). + """ + + def __init__(self, + learning_rate=0.001, + beta1=0.9, + use_locking=False, + name="GGT", + window=10, + eps=1e-4, + svd_eps=1e-6, + sigma_eps=1e-2): + """Construct a new GGT optimizer. + + Initialization: + + ``` + t <- 0 (Initialize timestep) + grad_buffer <- 0 (Initialize buffer for keeping past gradients) + flat_grad <- 0 (Initialize flattened gradient that contains gradients of all + variables) + m_0 <- 0 (Initialize 1st moment vector) + ``` + + Suppose all variables and their gradients are concatenated into vectors + `flat_vars` and `flat_grad`. The update rule for `flat_vars` + uses an optimization described at the beginning of section 2 of the paper: + + ``` + t <- t + 1 + + m_t <- beta1 * m_{t-1} + (1 - beta1) * flat_grad + grad_buffer[(t-1) % window, :] <- m_t + + M <- grad_buffer^T / sqrt(min(t, window)) + U, sigma, _ <- SVD(M^TM + I * svd_eps) + + sigma_sqrt_inv <- (sqrt(sigma) + sigma_eps)^(-3) + sigma_sqrt_min <- min(sqrt(sigma)) + + if sigma_sqrt_min > eps: + new_step <- M U diag(sigma_sqrt_inv) U^T M^T m_t + + (m_t - M U diag(1/sigma) U^T M^T m_t) / sigma_sqrt_min + else: + new_step <- M U diag(sigma_sqrt_inv) U^T M^T m_t + + flat_vars <- flat_vars - learning_rate * new_step + ``` + + GGT provides the power of full-matrix adaptive regularization at a cost not + much larger than SGD. As a result it is suited for large models where the + gradient covariance matrix has a poor condition number that slows down first + order methods. + GGT uses the preconditioner from full-matrix AdaGrad, with gradient history + attenuated exponentially as in Adam, and truncated to a window parameter. + It has provable guarantees even for non-convex optimization that is never + significantly worse than SGD and in some cases better. + + Args: + learning_rate: A float hyperparameter. The learning rate. + beta1: A float hyperparameter. The exponential decay rate for the 1st + moment estimates. + use_locking: If True use locks for update operations. + name: Optional name for the operations created when applying gradients. + Defaults to "GGT". + window: An integer hyperparameter. The number of first moments to keep in + computing the adaptive preconditioner. + eps: A float hyperparameter. Used to truncate small eigenvalues of the + gradient covariance matrix. + svd_eps: A float hyperparameter. Used to stabilize SVD. + sigma_eps: A float hyperparameter. Used to regularize matrix inversion. + """ + super(GGTOptimizer, self).__init__(use_locking, name) + self._set_hyper("lr", learning_rate) + self._set_hyper("beta1", beta1) + self._set_hyper("window", window) + self._set_hyper("eps", eps) + self._set_hyper("svd_eps", svd_eps) + self._set_hyper("sigma_eps", sigma_eps) + + self.index_dict = {} + self.shape_dict = {} + + def _create_vars(self, var_list, state): + # Construct ordered dictionary for variable dimensions, sorted by name. + shape_dict = {} + for v in var_list: + shape_dict[v.name] = np.prod(v.get_shape()).value + self.shape_dict = collections.OrderedDict( + sorted(shape_dict.items(), key=lambda t: t[0])) + + # Assign each variable its location in flat_grad. The locations are based on + # the order of sorted names. + idx = 0 + for v_name, v_dim in self.shape_dict.items(): + self.index_dict[v_name] = idx + idx += v_dim + + state.create_non_slot( + initial_value=math_ops.cast(0., dtype=var_list[0].dtype.base_dtype), + name="global_step") + + # Buffer for keeping past gradients. + window = state.get_hyper("window") + grad_buffer_init = array_ops.zeros( + [window, idx], dtype=var_list[0].dtype.base_dtype) + state.create_non_slot(initial_value=grad_buffer_init, name="grad_buffer") + + state.create_non_slot( + initial_value=array_ops.zeros( + (idx,), dtype=var_list[0].dtype.base_dtype), + name="moment1") + + # Flattened gradient that contains gradients for all variables in the model. + state.create_non_slot( + initial_value=array_ops.zeros( + (idx,), dtype=var_list[0].dtype.base_dtype), + name="flat_grad") + + def _get_global_step(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("global_step") + + def _get_moment1(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("moment1") + + def _get_grad_buffer(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("grad_buffer") + + def _get_flat_grad(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("flat_grad") + + def _apply_sparse(self, grad, var): + raise NotImplementedError("Sparse gradient updates are not supported.") + + def _prepare(self, state): + self._variables = [] + + def _apply_dense(self, grad, var, state): + self._variables.append(var) + dim = self.shape_dict[var.name] + start_index = self.index_dict[var.name] + end_index = start_index + dim + + # Update flat_gradient at the index associated with the variable. + flat_grad = self._get_flat_grad(state) + new_flat_grad = array_ops.reshape(grad, [-1]) + flat_grad_updated = state_ops.scatter_update( + flat_grad, math_ops.range(start_index, end_index), new_flat_grad) + + return flat_grad_updated + + def _resource_apply_dense(self, grad, var, state): + self._variables.append(var) + dim = self.shape_dict[var.name] + start_index = self.index_dict[var.name] + end_index = start_index + dim + + # Update flat_gradient at the index associated with the variable. + flat_grad = self._get_flat_grad(state) + new_flat_grad = array_ops.reshape(grad, [-1]) + flat_grad_updated = state_ops.scatter_update( + flat_grad, math_ops.range(start_index, end_index), new_flat_grad) + + return flat_grad_updated + + def _finish(self, state): + var_dtype = self._variables[0].dtype.base_dtype + # Update global step. + global_step = self._get_global_step(state) + update_global_step = state_ops.assign_add(global_step, 1.) + + # Update the first moment estimate. + beta1 = state.get_hyper("beta1", dtype=var_dtype) + moment1 = self._get_moment1(state) + flat_grad = self._get_flat_grad(state) + # moment1_t := beta1 * moment1_{t-1} + (1 - beta1) * flat_grad_t + update_moment1 = moment1.assign(beta1 * moment1 + (1. - beta1) * flat_grad) + + # Update the gradient buffer. + window = state.get_hyper("window") + grad_buffer = self._get_grad_buffer(state) + next_grad_index = math_ops.floormod( + math_ops.to_int32(update_global_step - 1.), window) + # grad_buffer[(t-1) % window] := moment1_t + update_grad_buffer = state_ops.scatter_update(grad_buffer, next_grad_index, + update_moment1) + + # Compute the update step. + eps = state.get_hyper("eps", dtype=var_dtype) + svd_eps = state.get_hyper("svd_eps", dtype=var_dtype) + sigma_eps = state.get_hyper("sigma_eps", dtype=var_dtype) + lr = state.get_hyper("lr", dtype=var_dtype) + denom = math_ops.sqrt( + math_ops.minimum( + ops.convert_to_tensor(update_global_step), + ops.convert_to_tensor(math_ops.cast(window, dtype=var_dtype)))) + moment1_2d = array_ops.expand_dims(update_moment1, -1) + + # m = grad_buffer^T / sqrt(min(t, window)) + # m has shape [model dimension, window], where model dimension is the sum + # of the dimensions of the flattened variables. + m = array_ops.transpose(math_ops.divide(update_grad_buffer, denom)) + + # sigma, u, _ = SVD(m^Tm + I * svd_eps) + mm = math_ops.matmul(m, m, transpose_a=True) + damping = math_ops.cast(linalg_ops.eye(window), dtype=var_dtype) * svd_eps + sigma, u, _ = linalg_ops.svd(mm + damping) + sigma_sqrt = math_ops.sqrt(sigma) + sigma_sqrt_min = math_ops.reduce_min(sigma_sqrt) + + # sigma_sqrt_inv = 1 / (\sqrt{sigma} + sigma_eps) ^ 3 + # We add sigma_eps to alleviate numerical instability. + # Note that (m^Tm)^(-3/2) = u diag(sigma_sqrt_inv) u^T. + sigma_sqrt_inv = math_ops.divide( + math_ops.cast(1.0, dtype=var_dtype), + math_ops.pow(sigma_sqrt + sigma_eps, 3)) + + # In full matrix AdaGrad, the update step computes (mm^T)^(-1/2)g, where the + # inversion of a model dimension by model dimension matrix is needed. To + # speed up this computation we calculate the following instead: + # m(m^Tm)^(-3/2)m^T moment1 = m u diag(sigma_sqrt_inv) u^T m^T moment1. + new_step = array_ops.expand_dims( + array_ops.zeros(flat_grad.get_shape(), dtype=var_dtype), -1) + head = math_ops.matmul( + m, + math_ops.matmul( + u, + math_ops.matmul( + array_ops.diag(sigma_sqrt_inv), + math_ops.matmul( + u, + math_ops.matmul(m, moment1_2d, transpose_a=True), + transpose_a=True)))) + + # When inverting (mm^t)^(1/2), we also add epsilon * I regularization for + # degenerate cases. We expand ((mm^t)^(1/2) + epsilon * I)^(-1) using + # Woodbury's identity. + # For full derivation please see paper at + # https://arxiv.org/pdf/1806.02958.pdf + tail = moment1_2d - math_ops.matmul( + m, + math_ops.matmul( + u, + math_ops.matmul( + array_ops.diag( + math_ops.divide(math_ops.cast(1.0, dtype=var_dtype), + sigma)), + math_ops.matmul( + u, + math_ops.matmul(m, moment1_2d, transpose_a=True), + transpose_a=True)))) + scaled_tail = math_ops.divide(tail, sigma_sqrt_min) + + update_new_step = control_flow_ops.cond( + sigma_sqrt_min > eps, lambda: math_ops.add(head, scaled_tail), + lambda: math_ops.add(new_step, head)) + + # Update each variable. + update_step = [] + for var in self._variables: + dim = self.shape_dict[var.name] + start_index = self.index_dict[var.name] + end_index = start_index + dim + var_update_correct_shape = array_ops.reshape( + update_new_step[start_index:end_index], var.get_shape()) + var_updated = state_ops.assign_sub(var, lr * var_update_correct_shape) + update_step.append(var_updated) + + return control_flow_ops.group(update_step) diff --git a/tensorflow/contrib/opt/python/training/ggt_test.py b/tensorflow/contrib/opt/python/training/ggt_test.py new file mode 100644 index 0000000000000000000000000000000000000000..42162960b049cd90c663989fb4fc9d7f179a84ff --- /dev/null +++ b/tensorflow/contrib/opt/python/training/ggt_test.py @@ -0,0 +1,183 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for GGTOptimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.contrib.opt.python.training.ggt import GGTOptimizer +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def ggt_update_numpy(param, + g_t, + lr, + grad_buffer, + m, + window, + t, + beta1=0.9, + eps=1e-4, + svd_eps=1e-6, + sigma_eps=1e-2): + """Tests the correctness of one step of GGT.""" + m_t = m * beta1 + (1 - beta1) * g_t + grad_buffer[((t - 1) % window), :] = m_t + m_matrix = np.transpose(grad_buffer / np.sqrt(np.minimum(t, window))) + mm = np.dot(np.transpose(m_matrix), m_matrix) + damping = np.eye(window) * svd_eps + u, sigma, _ = np.linalg.svd(mm + damping) + + sigma_sqrt_inv = np.power(np.sqrt(sigma) + sigma_eps, -3) + new_step = np.linalg.multi_dot([ + m_matrix, u, + np.diag(sigma_sqrt_inv), + np.transpose(u), + np.transpose(m_matrix), m_t + ]) + + sigma_sqrt_min = np.sqrt(sigma).min() + + if sigma_sqrt_min > eps: + new_step += (m_t - np.linalg.multi_dot([ + m_matrix, u, + np.diag(1.0 / sigma), + np.transpose(u), + np.transpose(m_matrix), m_t + ])) * (1.0 / sigma_sqrt_min) + + param_t = param - lr * new_step + return param_t, m_t, grad_buffer + + +class GGTOptimizerTest(test.TestCase): + + def doTestBasic(self, use_resource=False): + # SVD does not support float16 + for i, dtype in enumerate([dtypes.float32, dtypes.float64]): + with self.test_session(graph=ops.Graph()): + # Initialize variables for numpy implementation. + m0 = 0.0 + window = 3 + grad_buffer = np.zeros((window, 4), dtype=dtype.as_numpy_dtype) + lr = 0.001 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + var0 = variables.Variable(var0_np, name="var0") + var1 = variables.Variable(var1_np, name="var1") + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = GGTOptimizer(learning_rate=lr, window=window) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + opt_variables = opt.variables() + + m_t = opt._get_moment1() + grad_buffer_t = opt._get_grad_buffer() + g_t = opt._get_flat_grad() + self.assertTrue(m_t is not None) + self.assertTrue(grad_buffer_t is not None) + self.assertTrue(g_t is not None) + self.assertIn(m_t, opt_variables) + self.assertIn(grad_buffer_t, opt_variables) + self.assertIn(g_t, opt_variables) + + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + m_t = opt._get_moment1() + grad_buffer_t = opt._get_grad_buffer() + g_t = opt._get_flat_grad() + + # Run 3 steps of GGT + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + elif t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + if t == 1: + self.assertAllCloseAccordingToType( + np.array([0.01, 0.01, 0.001, 0.001]), self.evaluate(m_t)) + self.assertAllCloseAccordingToType( + np.array([[0.01, 0.01, 0.001, 0.001], [0., 0., 0., 0.], + [0., 0., 0., 0.]]), self.evaluate(grad_buffer_t)) + elif t == 2: + self.assertAllCloseAccordingToType( + np.array([0.019, 0.019, 0.0019, 0.0019]), self.evaluate(m_t)) + self.assertAllCloseAccordingToType( + np.array([[0.01, 0.01, 0.001, 0.001], + [0.019, 0.019, 0.0019, 0.0019], [0., 0., 0., 0.]]), + self.evaluate(grad_buffer_t)) + else: + self.assertAllCloseAccordingToType( + np.array([0.0271, 0.0271, 0.00271, 0.00271]), + self.evaluate(m_t)) + self.assertAllCloseAccordingToType( + np.array([[0.01, 0.01, 0.001, + 0.001], [0.019, 0.019, 0.0019, 0.0019], + [0.0271, 0.0271, 0.00271, 0.00271]]), + self.evaluate(grad_buffer_t)) + + self.assertAllCloseAccordingToType([0.1, 0.1, 0.01, 0.01], + self.evaluate(g_t)) + + var_np = np.append(var0_np, var1_np) + grads_np = np.append(grads0_np, grads1_np) + var_np, m0, grad_buffer = ggt_update_numpy(var_np, grads_np, lr, + grad_buffer, m0, window, t) + + var0_np = var_np[:2] + var1_np = var_np[2:] + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testBasic(self): + with self.test_session(): + self.doTestBasic(use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/opt/python/training/powersign_test.py b/tensorflow/contrib/opt/python/training/powersign_test.py index 5214082dd66f00eadadad71d50f7e00b178b8c10..0bcf5d230a8b7b5b778d233a79922dc34449f8dd 100644 --- a/tensorflow/contrib/opt/python/training/powersign_test.py +++ b/tensorflow/contrib/opt/python/training/powersign_test.py @@ -216,7 +216,7 @@ class PowerSignTest(test.TestCase): self.assertAllClose([1.0, 2.0], var0.eval()) self.assertAllClose([3.0, 4.0], var1.eval()) - # Run 3 steps of powersign + # Run 7 steps of powersign # first 4 steps with positive gradient # last 3 steps with negative gradient (sign(gm) should be -1) for t in range(1, 8): diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py new file mode 100644 index 0000000000000000000000000000000000000000..294627f42a839f399f747bcd6ba50968aadb35a1 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/shampoo.py @@ -0,0 +1,474 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""The Shampoo Optimizer. + +Variant of Adagrad using one preconditioner matrix per variable dimension. +For details, see https://arxiv.org/abs/1802.09568 +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.platform import tf_logging +from tensorflow.python.training import optimizer + + +def GetParam(var, timestep): + if callable(var): + return var(timestep) + else: + return var + + +class ShampooOptimizer(optimizer.Optimizer): + """The Shampoo Optimizer + + Variant of Adagrad using one preconditioner matrix per variable dimension. + For details, see https://arxiv.org/abs/1802.09568 + + gbar is time-weighted accumulated gradient: + gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t] + + mat_gbar is time-weighted accumulated gradient square: + mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1] + + mat_gbar_weight[t] * gg_j[t] + where if g[t] = g_abcd then gg_a[t] = g_abcd g_a'bcd (Einstein notation) + + Update rule: + w[t+1] = w[t] - learning_rate[t] * Prod_j mat_gbar_j[t]^(-alpha/n) gbar[t] + Again, mat_gbar_j[t]^(-alpha) gbar[t] is a tensor contraction along the + j'th dimension of gbar[t] with the first dimension of + mat_gbar_j[t]^(-alpha/n), where alpha is a hyperparameter, + and n = rank of the variable. + Prod_j represents doing this contraction for all j in 0..n-1. + + Typically learning_rate is constant, but could be time dependent by passing + a lambda function that depends on step. + """ + + def __init__(self, + global_step=0, + max_matrix_size=768, + gbar_decay=0.0, + gbar_weight=1.0, + mat_gbar_decay=1.0, + mat_gbar_weight=1.0, + learning_rate=1.0, + svd_interval=1, + precond_update_interval=1, + epsilon=0.1, + alpha=0.5, + use_iterative_root=False, + use_locking=False, + name="Shampoo"): + """Default values of the various hyper-parameters. + + gbar_decay, gbar_weight etc. can be a float or a time varying parameter. + For time-varying parameters use e.g. "lambda T: T / (T + 1.0)" + where the expression in the lambda is a tensorflow expression + + Args: + global_step: tensorflow variable indicating the step. + max_matrix_size: We do not perform SVD for matrices larger than this. + gbar_decay: + gbar_weight: Used to update gbar: + gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t] + mat_gbar_decay: + mat_gbar_weight: Used to update mat_gbar: + mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1] + + mat_gbar_weight[t] * gg_j[t] + learning_rate: Similar to SGD + svd_interval: We should do SVD after this many steps. Default = 1, i.e. + every step. Usually 20 leads to no loss of accuracy, and + 50 or 100 is also OK. May also want more often early, + and less often later - set in caller as for example: + "svd_interval = lambda(T): tf.cond( + T < 2000, lambda: 20.0, lambda: 1000.0)" + precond_update_interval: We should update the preconditioners after + this many steps. Default = 1. Usually less than + svd_interval. + epsilon: epsilon * I_n is added to each mat_gbar_j for stability + alpha: total power of the preconditioners. + use_iterative_root: should the optimizer use SVD (faster) or the + iterative root method (for TPU) for finding the + roots of PSD matrices. + use_locking: + name: name of optimizer. + """ + + super(ShampooOptimizer, self).__init__(use_locking, name) + + self._global_step = math_ops.to_float(global_step) + self._max_matrix_size = max_matrix_size + self._gbar_decay = gbar_decay + self._gbar_weight = gbar_weight + self._mat_gbar_decay = mat_gbar_decay + self._mat_gbar_weight = mat_gbar_weight + self._learning_rate = learning_rate + self._svd_interval = svd_interval + self._precond_update_interval = precond_update_interval + self._epsilon = epsilon + self._alpha = alpha + self._use_iterative_root = use_iterative_root + self._name = name + + def _create_slots(self, var_list): + for v in var_list: + with ops.colocate_with(v): + _ = self._zeros_slot(v, "gbar", self._name) + shape = np.array(v.get_shape()) + for i, d in enumerate(shape): + d_tensor = ops.convert_to_tensor(d) + if d <= self._max_matrix_size: + mat_g_init = array_ops.zeros_like(linalg_ops.eye(d_tensor)) + if self._svd_interval > 1: + _ = self._get_or_make_slot(v, linalg_ops.eye(d_tensor), + "H_" + str(i), self._name) + else: + mat_g_init = array_ops.zeros([d_tensor]) + + _ = self._get_or_make_slot(v, mat_g_init, "Gbar_" + str(i), + self._name) + + def _resource_apply_dense(self, grad, var): + return self._apply_dense(grad, var) + + def _apply_dense(self, grad, var): + return self._apply_gradient(grad, var) + + def _resource_apply_sparse(self, grad_values, var, grad_indices): + return self._apply_sparse_shared(grad_values, grad_indices, var) + + def _apply_sparse(self, grad, var): + return self._apply_sparse_shared(grad.values, grad.indices, var) + + def _apply_sparse_shared(self, grad_values, grad_indices, var): + if var.get_shape()[0] <= self._max_matrix_size or self._gbar_decay != 0.0: + # The dimension is small enough, we can make the variable dense and + # do a dense update + dense_grad = array_ops.scatter_nd( + array_ops.expand_dims(grad_indices, axis=1), grad_values, + array_ops.shape(var, out_type=grad_indices.dtype)) + return self._apply_gradient(dense_grad, var) + return self._apply_gradient(grad_values, var, grad_indices) + + def _weighted_average(self, var, weight, weight_t, rest): + """Computes exponential weighted average: var = weight_t * var + rest. + + Important to ensure that var does not occur in rest, otherwise + we can get race conditions in a distributed setting. + + Args: + var: variable to be updated + weight: parameter to be checked. If it is a constant, we can optimize. + weight_t: current value of parameter, used for weighting + rest: the remaining tensor to be added + + Returns: + updated variable. + """ + if weight == 0.0: + return rest # no need to update var, we will never use it. + if weight == 1.0: # common case + return state_ops.assign_add(var, rest) + # The op below can cause race conditions in a distributed setting, + # since computing weight_t * var + rest can take some time, during + # which var may be set by another worker. To prevent this, it should + # be implemented as a C++ op. + return var.assign_add((weight_t - 1) * var + rest) + + def _update_mat_g(self, mat_g, grad, axes, mat_gbar_decay, + mat_gbar_weight, i): + """Updates the cumulative outer products of the gradients. + + Args: + mat_g: the matrix to be updated + grad: the gradient of the variable + axes: a list of k-1 integers 0 to k-1, except i + mat_gbar_decay: constant for weighted average: + mat_g = mat_g * decay + grad * weight + mat_gbar_weight: constant for weighted average + i: index of dimension to be updated. + + Returns: + updated mat_g = mat_g * mat_gbar_decay + grad_outer * mat_gbar_weight + + In Einstein notation if i = 0: grad_outer_aa'= g_abcd g_a'bcd + thus grad_outer is a matrix d_i x d_i, where d_i is the size of the + i'th dimension of g. + Alternate view: If mat_i(grad) is the flattening of grad to a + d_i x (d_1d_2...d_{i-1}d_{i+1}...d_k) matrix, then + grad_outer = mat_i(grad) mat_i(grad).transpose + """ + grad_outer = math_ops.tensordot(grad, grad, axes=(axes, axes), + name="grad_outer_" + str(i)) + return self._weighted_average(mat_g, self._mat_gbar_decay, mat_gbar_decay, + mat_gbar_weight * grad_outer) + + def _compute_power_svd(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name): + """Computes mat_h = mat_g^alpha using svd. mat_g is a symmetric PSD matrix. + + Args: + var: the variable we are updating. + mat_g: the symmetric PSD matrix whose power it to be computed + mat_g_size: size of mat_g + alpha: a real number + mat_h_slot_name: name of slot to store the power, if needed. + + Returns: + mat_h = mat_g^alpha + + Stores mat_h in the appropriate slot, if it exists. + Note that mat_g is PSD. So we could use linalg_ops.self_adjoint_eig. + """ + if mat_g_size == 1: + mat_h = math_ops.pow(mat_g + self._epsilon, alpha) + else: + damping = self._epsilon * linalg_ops.eye(math_ops.to_int32(mat_g_size)) + diag_d, mat_u, mat_v = linalg_ops.svd(mat_g + damping, full_matrices=True) + mat_h = math_ops.matmul( + mat_v * math_ops.pow(math_ops.maximum(diag_d, self._epsilon), alpha), + array_ops.transpose(mat_u)) + if mat_h_slot_name is not None: + return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h) + return mat_h + + def _compute_power_iter(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name, + iter_count=100, epsilon=1e-6): + """Computes mat_g^alpha, where alpha = -1/p, p a positive integer. + + We use an iterative Schur-Newton method from equation 3.2 on page 9 of: + + A Schur-Newton Method for the Matrix p-th Root and its Inverse + by Chun-Hua Guo and Nicholas J. Higham + SIAM Journal on Matrix Analysis and Applications, + 2006, Vol. 28, No. 3 : pp. 788-804 + https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf + + Args: + var: the variable we are updating. + mat_g: the symmetric PSD matrix whose power it to be computed + mat_g_size: size of mat_g. + alpha: exponent, must be -1/p for p a positive integer. + mat_h_slot_name: name of slot to store the power, if needed. + iter_count: Maximum number of iterations. + epsilon: accuracy indicator, useful for early termination. + + Returns: + mat_g^alpha + """ + + identity = linalg_ops.eye(math_ops.to_int32(mat_g_size)) + + def MatPower(mat_m, p): + """Computes mat_m^p, for p a positive integer. + + Power p is known at graph compile time, so no need for loop and cond. + Args: + mat_m: a square matrix + p: a positive integer + + Returns: + mat_m^p + """ + assert p == int(p) and p > 0 + power = None + while p > 0: + if p % 2 == 1: + power = math_ops.matmul(mat_m, power) if power is not None else mat_m + p //= 2 + mat_m = math_ops.matmul(mat_m, mat_m) + return power + + def IterCondition(i, mat_m, _): + return math_ops.logical_and( + i < iter_count, + math_ops.reduce_max(math_ops.abs(mat_m - identity)) > epsilon) + + def IterBody(i, mat_m, mat_x): + mat_m_i = (1 - alpha) * identity + alpha * mat_m + return (i + 1, math_ops.matmul(MatPower(mat_m_i, -1.0/alpha), mat_m), + math_ops.matmul(mat_x, mat_m_i)) + + if mat_g_size == 1: + mat_h = math_ops.pow(mat_g + self._epsilon, alpha) + else: + damped_mat_g = mat_g + self._epsilon * identity + z = (1 - 1 / alpha) / (2 * linalg_ops.norm(damped_mat_g)) + # The best value for z is + # (1 - 1/alpha) * (c_max^{-alpha} - c_min^{-alpha}) / + # (c_max^{1-alpha} - c_min^{1-alpha}) + # where c_max and c_min are the largest and smallest singular values of + # damped_mat_g. + # The above estimate assumes that c_max > c_min * 2^p. (p = -1/alpha) + # Can replace above line by the one below, but it is less accurate, + # hence needs more iterations to converge. + # z = (1 - 1/alpha) / math_ops.trace(damped_mat_g) + # If we want the method to always converge, use z = 1 / norm(damped_mat_g) + # or z = 1 / math_ops.trace(damped_mat_g), but these can result in many + # extra iterations. + _, _, mat_h = control_flow_ops.while_loop( + IterCondition, IterBody, + [0, damped_mat_g * z, identity * math_ops.pow(z, -alpha)]) + if mat_h_slot_name is not None: + return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h) + return mat_h + + def _compute_power(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name=None): + """Just a switch between the iterative power vs svd.""" + with ops.name_scope("matrix_iterative_power"): + if self._use_iterative_root: + return self._compute_power_iter(var, mat_g, mat_g_size, alpha, + mat_h_slot_name) + else: + return self._compute_power_svd(var, mat_g, mat_g_size, alpha, + mat_h_slot_name) + + def _apply_gradient(self, grad, var, indices=None): + """The main function to update a variable. + + Args: + grad: A Tensor containing gradient to apply. + var: A Tensor containing the variable to update. + indices: An array of integers, for sparse update. + + Returns: + Updated variable var = var - learning_rate * preconditioner * grad + + If the gradient is dense, var and grad have the same shape. + If the update is sparse, then the first dimension of the gradient and var + may differ, others are all the same. In this case the indices array + provides the set of indices of the variable which are to be updated with + each row of the gradient. + """ + global_step = self._global_step + 1 + + # Update accumulated weighted average of gradients + gbar = self.get_slot(var, "gbar") + gbar_decay_t = GetParam(self._gbar_decay, global_step) + gbar_weight_t = GetParam(self._gbar_weight, global_step) + if indices is not None: + # Note - the sparse update is not easily implemented, since the + # algorithm needs all indices of gbar to be updated + # if mat_gbar_decay != 1 or mat_gbar_decay != 0. + # One way to make mat_gbar_decay = 1 is by rescaling. + # If we want the update: + # G_{t+1} = a_{t+1} G_t + b_{t+1} w_t + # define: + # r_{t+1} = a_{t+1} * r_t + # h_t = G_t / r_t + # Then: + # h_{t+1} = h_t + (b_{t+1} / r_{t+1}) * w_t + # So we get the mat_gbar_decay = 1 as desired. + # We can implement this in a future version as needed. + # However we still need gbar_decay = 0, otherwise all indices + # of the variable will need to be updated. + if self._gbar_decay != 0.0: + tf_logging.warning("Not applying momentum for variable: %s" % var.name) + gbar_updated = grad + else: + gbar_updated = self._weighted_average(gbar, self._gbar_decay, + gbar_decay_t, + gbar_weight_t * grad) + + # Update the preconditioners and compute the preconditioned gradient + shape = var.get_shape() + mat_g_list = [] + for i in range(len(shape)): + mat_g_list.append(self.get_slot(var, "Gbar_" + str(i))) + mat_gbar_decay_t = GetParam(self._mat_gbar_decay, global_step) + mat_gbar_weight_t = GetParam(self._mat_gbar_weight, global_step) + + preconditioned_grad = gbar_updated + v_rank = len(mat_g_list) + neg_alpha = - GetParam(self._alpha, global_step) / v_rank + svd_interval = GetParam(self._svd_interval, global_step) + precond_update_interval = GetParam(self._precond_update_interval, + global_step) + for i, mat_g in enumerate(mat_g_list): + # axes is the list of indices to reduce - everything but the current i. + axes = list(range(i)) + list(range(i+1, v_rank)) + if shape[i] <= self._max_matrix_size: + # If the tensor size is sufficiently small perform full Shampoo update + # Note if precond_update_interval > 1 and mat_gbar_decay_t != 1, this + # is not strictly correct. However we will use it for now, and + # fix if needed. (G_1 = aG + bg ==> G_n = a^n G + (1+a+..+a^{n-1})bg) + + # pylint: disable=g-long-lambda,cell-var-from-loop + mat_g_updated = control_flow_ops.cond( + math_ops.mod(global_step, precond_update_interval) < 1, + lambda: self._update_mat_g( + mat_g, grad, axes, mat_gbar_decay_t, + mat_gbar_weight_t * precond_update_interval, i), + lambda: mat_g) + + if self._svd_interval == 1: + mat_h = self._compute_power(var, mat_g_updated, shape[i], neg_alpha) + else: + mat_h = control_flow_ops.cond( + math_ops.mod(global_step, svd_interval) < 1, + lambda: self._compute_power(var, mat_g_updated, shape[i], + neg_alpha, "H_" + str(i)), + lambda: self.get_slot(var, "H_" + str(i))) + + # mat_h is a square matrix of size d_i x d_i + # preconditioned_grad is a d_i x ... x d_n x d_0 x ... d_{i-1} tensor + # After contraction with a d_i x d_i tensor + # it becomes a d_{i+1} x ... x d_n x d_0 x ... d_i tensor + # (the first dimension is contracted out, and the second dimension of + # mat_h is appended). After going through all the indices, it becomes + # a d_0 x ... x d_n tensor again. + preconditioned_grad = math_ops.tensordot(preconditioned_grad, mat_h, + axes=([0], [0]), + name="precond_" + str(i)) + else: + # Tensor size is too large -- perform diagonal Shampoo update + grad_outer = math_ops.reduce_sum(grad * grad, axis=axes) + if i == 0 and indices is not None: + assert self._mat_gbar_decay == 1.0 + mat_g_updated = state_ops.scatter_add(mat_g, indices, + mat_gbar_weight_t * grad_outer) + mat_h = math_ops.pow( + array_ops.gather(mat_g_updated, indices) + self._epsilon, + neg_alpha) + else: + mat_g_updated = self._weighted_average(mat_g, + self._mat_gbar_decay, + mat_gbar_decay_t, + mat_gbar_weight_t * grad_outer) + mat_h = math_ops.pow(mat_g_updated + self._epsilon, neg_alpha) + + # Need to do the transpose to ensure that the tensor becomes + # a d_{i+1} x ... x d_n x d_0 x ... d_i tensor as described above. + preconditioned_grad = array_ops.transpose( + preconditioned_grad, perm=list(range(1, v_rank)) + [0]) * mat_h + + # Update the variable based on the Shampoo update + learning_rate_t = GetParam(self._learning_rate, global_step) + if indices is not None: + var_updated = state_ops.scatter_add( + var, indices, -learning_rate_t * preconditioned_grad) + else: + var_updated = state_ops.assign_sub(var, + learning_rate_t * preconditioned_grad) + return var_updated diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2e0a202ae293664d85ece884a505096455cde73c --- /dev/null +++ b/tensorflow/contrib/opt/python/training/shampoo_test.py @@ -0,0 +1,734 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Functional tests for AdaMoo optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.opt.python.training import shampoo +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + +TOLERANCE = 1e-3 + + +def np_power(mat_g, alpha): + """Computes mat_g^alpha for a square symmetric matrix mat_g.""" + + mat_u, diag_d, mat_v = np.linalg.svd(mat_g) + diag_d = np.power(diag_d, alpha) + return np.dot(np.dot(mat_u, np.diag(diag_d)), mat_v) + + +class ShampooTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters(('Var', False), ('ResourceVar', True)) + def testBasicVector(self, use_resource_var): + """Similar to the full Adagrad update.""" + + size = 20 + init_var_np = np.zeros(size) + grad_np = np.random.rand(size) + grad_np_2 = np.random.rand(size) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_g^{-0.5} * grad + # lr = 1 + mat_g = np.outer(grad_np, grad_np) + mat_h = np_power(mat_g + 0.1 * np.eye(size), -0.5) + new_val_np = init_var_np - np.dot(mat_h, grad_np) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g += np.outer(grad_np_2, grad_np_2) + mat_h = np_power(mat_g + 0.1 * np.eye(size), -0.5) + new_val_np -= np.dot(mat_h, grad_np_2) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters(('Var', False), ('ResourceVar', True)) + def testBasicMatrix(self, use_resource_var): + """Check update when gradient is a matrix.""" + size = [10, 5] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1]) + grad_np_2 = np.random.rand(size[0], size[1]) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_g1^{-0.25} * grad * mat_g2^{-0.25} + # lr = 1 + mat_g1 = np.dot(grad_np, grad_np.transpose()) + mat_left = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.25) + mat_g2 = np.dot(grad_np.transpose(), grad_np) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np = init_var_np - np.dot(np.dot(mat_left, grad_np), mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.dot(grad_np_2, grad_np_2.transpose()) + mat_left = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.25) + mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np -= np.dot(np.dot(mat_left, grad_np_2), mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + def _testBasicTensor(self, use_iterative_root, use_resource_var): + """Check update when gradient is a tensor. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. + """ + size = [10, 5, 7] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1], size[2]) + grad_np_2 = np.random.rand(size[0], size[1], size[2]) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + mat_g1 = np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 = np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 = np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_np, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np = init_var_np - precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_np_2, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testBasicTensor(self, use_iterative_root, use_resource_var): + self._testBasicTensor(use_iterative_root, use_resource_var) + + @parameterized.named_parameters(('Var', False), ('ResourceVar', True)) + def testLargeVector(self, use_resource_var): + """This is just the diagonal Adagrad update.""" + + size = 2000 + init_var_np = np.zeros(size) + grad_np = np.random.rand(size) + grad_np_2 = np.random.rand(size) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * gg^{-0.5} * grad + # lr = 1 + mat_g = grad_np * grad_np + 0.1 + new_val_np = init_var_np - np.power(mat_g, -0.5) * grad_np + + self.assertAllCloseAccordingToType(new_val_np, new_val) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g += grad_np_2 * grad_np_2 + new_val_np -= np.power(mat_g, -0.5) * grad_np_2 + + self.assertAllCloseAccordingToType(new_val_np, new_val) + + @parameterized.named_parameters(('Var', False), ('ResourceVar', True)) + def testLargeMatrix(self, use_resource_var): + """Gradient is a matrix, one of whose dimensions is large. + + We do diagonal updates for large dimensions. + + Args: + use_resource_var: use resource var as variables. + """ + + size = [2000, 3] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1]) + grad_np_2 = np.random.rand(size[0], size[1]) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_left * grad * mat_right + # where the mat_left * grad is just element-wise product, + # with broadcasting + # lr = 1 + + mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True) + mat_left = np.power(mat_g1 + 0.1, -0.25) + mat_g2 = np.dot(grad_np.transpose(), grad_np) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np = init_var_np - np.dot(grad_np * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True) + mat_left = np.power(mat_g1 + 0.1, -0.25) + mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np -= np.dot(grad_np_2 * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters(('Var', False)) + def testSparseUpdateLarge(self, use_resource_var): + """Check update when gradient is of type IndexSlices. + + We do diagonal updates for the first dimension, unless it is very small. + + Args: + use_resource_var: use resource var as variables. + """ + size = [2000, 3] + sample_size_1 = 100 + init_var_np = np.zeros(size) + grad_indices = np.sort(np.random.choice(np.arange(size[0]), sample_size_1, + replace=False)) + grad_np = np.random.rand(sample_size_1, size[1]) + + sample_size_2 = 7 + grad_indices_2 = np.sort(np.random.choice(np.arange(size[0]), sample_size_2, + replace=False)) + grad_np_2 = np.random.rand(sample_size_2, size[1]) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = ops.IndexedSlices( + constant_op.constant(grad_np, dtype=dtypes.float32), + constant_op.constant(grad_indices), + constant_op.constant(size)) + grad_2 = ops.IndexedSlices( + constant_op.constant(grad_np_2, dtype=dtypes.float32), + constant_op.constant(grad_indices_2), + constant_op.constant(size)) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_left * grad * mat_right + # where the mat_left * grad is just element-wise product, + # with broadcasting + # lr = 1 + # In this case the update lr * mat_left * grad * mat_right is + # of size 10 x 2. + # So the correct indices of var need to be updated. + + mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True) + mat_g1_acc = np.zeros((size[0], 1)) + mat_g1_acc[grad_indices] += mat_g1 + mat_left = np.power(mat_g1 + 0.1, -0.25) + mat_g2 = np.dot(grad_np.transpose(), grad_np) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np = init_var_np + new_val_np[grad_indices, :] -= np.dot(grad_np * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 = np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True) + mat_g1_acc[grad_indices_2] += mat_g1 + mat_left = np.power(mat_g1_acc[grad_indices_2] + 0.1, -0.25) + mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np[grad_indices_2, :] -= np.dot(grad_np_2 * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + def _testSparseUpdateSmall(self, use_iterative_root, use_resource_var): + """Gradient is of type IndexSlices, but the first dimension is small. + + We create dense gradient and do the full update with SVD etc. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. + """ + + size = [100, 3, 5] + sample_size = 10 + init_var_np = np.zeros(size) + grad_indices = np.sort(np.random.choice(np.arange(size[0]), sample_size, + replace=False)) + grad_np = np.random.rand(sample_size, size[1], size[2]) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = ops.IndexedSlices( + constant_op.constant(grad_np, dtype=dtypes.float32), + constant_op.constant(grad_indices), + constant_op.constant(size)) + + opt = shampoo.ShampooOptimizer(global_step, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.125} grad + # lr = 1 + grad_dense = np.zeros_like(init_var_np) + grad_dense[grad_indices] = grad_np + + mat_g1 = np.tensordot(grad_dense, grad_dense, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 = np.tensordot(grad_dense, grad_dense, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 = np.tensordot(grad_dense, grad_dense, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_dense, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np = init_var_np - precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testSparseUpdateSmall(self, use_iterative_root, use_resource_var): + self._testSparseUpdateSmall(use_iterative_root, use_resource_var) + + def _testBasicTensorWithMomentum(self, use_iterative_root, use_resource_var): + """Check update with momentum when gradient is a tensor. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. + """ + size = [10, 5, 7] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1], size[2]) + grad_np_2 = np.random.rand(size[0], size[1], size[2]) + gbar_decay = 0.9 + gbar_weight = 0.1 + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step, gbar_decay=gbar_decay, + gbar_weight=gbar_weight, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + mat_g1 = np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 = np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 = np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + gbar_np = gbar_weight * grad_np + precond_grad = np.tensordot(gbar_np, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np = init_var_np - precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + gbar_np_2 = gbar_decay * gbar_np + gbar_weight * grad_np_2 + precond_grad = np.tensordot(gbar_np_2, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testBasicTensorWithMomentum(self, use_iterative_root, use_resource_var): + self._testBasicTensorWithMomentum(use_iterative_root, use_resource_var) + + def _testDelayedSVD(self, use_iterative_root, use_resource_var): + """Performing the SVD every nth step. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. + """ + size = [10, 5, 7] + init_var_np = np.zeros(size).astype(np.float32) + iterations = 20 + svd_interval = 5 + grad_np = np.random.rand( + iterations, size[0], size[1], size[2]).astype(np.float32) + mat_g1_a = np.eye(size[0]) + mat_g1 = np.zeros_like(mat_g1_a) + mat_g2_a = np.eye(size[1]) + mat_g2 = np.zeros_like(mat_g2_a) + mat_g3_a = np.eye(size[2]) + mat_g3 = np.zeros_like(mat_g3_a) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = array_ops.placeholder(dtypes.float32, shape=size) + + opt = shampoo.ShampooOptimizer(global_step, svd_interval=svd_interval, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + new_val_np = init_var_np + + # Run n steps of Shampoo + for i in range(iterations): + _ = sess.run(update, feed_dict={grad: grad_np[i]}) + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + mat_g1 += np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) + mat_g2 += np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) + mat_g3 += np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) + if (i + 1) % svd_interval == 0: + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testDelayedSVD(self, use_iterative_root, use_resource_var): + self._testDelayedSVD(use_iterative_root, use_resource_var) + + def _testDelayedPrecondUpdate(self, use_iterative_root, use_resource_var): + """Update the squared sum every nth step, drop the other steps. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. + """ + size = [10, 5, 7] + init_var_np = np.zeros(size).astype(np.float32) + iterations = 100 + grad_np = np.random.rand( + iterations, size[0], size[1], size[2]).astype(np.float32) + svd_interval = 20 + precond_update_interval = 5 + mat_g1_a = np.eye(size[0]) + mat_g1 = np.zeros_like(mat_g1_a) + mat_g2_a = np.eye(size[1]) + mat_g2 = np.zeros_like(mat_g2_a) + mat_g3_a = np.eye(size[2]) + mat_g3 = np.zeros_like(mat_g3_a) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = array_ops.placeholder(dtypes.float32, shape=size) + + opt = shampoo.ShampooOptimizer( + global_step, svd_interval=svd_interval, + precond_update_interval=precond_update_interval, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + new_val_np = init_var_np + + # Run n steps of Shampoo + for i in range(iterations): + _ = sess.run(update, feed_dict={grad: grad_np[i]}) + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + if (i + 1) % precond_update_interval == 0: + mat_g1 += (np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) + * precond_update_interval) + mat_g2 += (np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) + * precond_update_interval) + mat_g3 += (np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) + * precond_update_interval) + + if (i + 1) % svd_interval == 0: + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testDelayedPrecondUpdate(self, use_iterative_root, use_resource_var): + self._testDelayedPrecondUpdate(use_iterative_root, use_resource_var) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py index 8aa40aeb45d4ec15140bdfc5ebd824e8aa08d8d9..b9cf40eb7b2d11c98b93c51213145ca4e2670318 100644 --- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py @@ -19,13 +19,13 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops -from tensorflow.python.training import optimizer from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops from tensorflow.python.training import adam from tensorflow.python.training import momentum as momentum_opt +from tensorflow.python.training import optimizer from tensorflow.python.util.tf_export import tf_export -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import resource_variable_ops class DecoupledWeightDecayExtension(object): @@ -65,7 +65,7 @@ class DecoupledWeightDecayExtension(object): Args: weight_decay: A `Tensor` or a floating point value, the factor by which a variable is decayed in the update step. - decay_var_list: Optional list or tuple or set of `Variable` objects to + **kwargs: Optional list or tuple or set of `Variable` objects to decay. """ self._decay_var_list = None # is set in minimize or apply_gradients @@ -85,6 +85,28 @@ class DecoupledWeightDecayExtension(object): If decay_var_list is None, all variables in var_list are decayed. For more information see the documentation of Optimizer.minimize. + + Args: + loss: A `Tensor` containing the value to minimize. + global_step: Optional `Variable` to increment by one after the + variables have been updated. + var_list: Optional list or tuple of `Variable` objects to update to + minimize `loss`. Defaults to the list of variables collected in + the graph under the key `GraphKeys.TRAINABLE_VARIABLES`. + gate_gradients: How to gate the computation of gradients. Can be + `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. + aggregation_method: Specifies the method used to combine gradient terms. + Valid values are defined in the class `AggregationMethod`. + colocate_gradients_with_ops: If True, try colocating gradients with + the corresponding op. + name: Optional name for the returned operation. + grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. + decay_var_list: Optional list of decay variables. + + Returns: + An Operation that updates the variables in `var_list`. If `global_step` + was not `None`, that operation also increments `global_step`. + """ self._decay_var_list = set(decay_var_list) if decay_var_list else False return super(DecoupledWeightDecayExtension, self).minimize( @@ -103,6 +125,19 @@ class DecoupledWeightDecayExtension(object): are decayed. For more information see the documentation of Optimizer.apply_gradients. + + Args: + grads_and_vars: List of (gradient, variable) pairs as returned by + `compute_gradients()`. + global_step: Optional `Variable` to increment by one after the + variables have been updated. + name: Optional name for the returned operation. Default to the + name passed to the `Optimizer` constructor. + decay_var_list: Optional list of decay variables. + + Returns: + An `Operation` that applies the specified gradients. If `global_step` + was not None, that operation also increments `global_step`. """ self._decay_var_list = set(decay_var_list) if decay_var_list else False return super(DecoupledWeightDecayExtension, self).apply_gradients( @@ -197,6 +232,7 @@ def extend_with_decoupled_weight_decay(base_optimizer): A new optimizer class that inherits from DecoupledWeightDecayExtension and base_optimizer. """ + class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension, base_optimizer): """Base_optimizer with decoupled weight decay. diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py index 74d1cdbbdac8724518937d141a976abf9fec6ce3..76d8a5697acb79e7748175c4a81dfdd85807dd49 100644 --- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.opt.python.training import weight_decay_optimizers from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -29,7 +30,6 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import adam -from tensorflow.contrib.opt.python.training import weight_decay_optimizers WEIGHT_DECAY = 0.01 @@ -91,7 +91,6 @@ class WeightDecayOptimizerTest(test.TestCase): opt = optimizer() update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - if not context.executing_eagerly(): with ops.Graph().as_default(): # Shouldn't return non-slot variables from other graphs. @@ -171,9 +170,9 @@ class ExtendWithWeightDecayTest(WeightDecayOptimizerTest): @staticmethod def get_optimizer(): - AdamW = weight_decay_optimizers.extend_with_decoupled_weight_decay( + adamw = weight_decay_optimizers.extend_with_decoupled_weight_decay( adam.AdamOptimizer) - return AdamW(WEIGHT_DECAY) + return adamw(WEIGHT_DECAY) def testBasic(self): self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m", @@ -185,6 +184,5 @@ class ExtendWithWeightDecayTest(WeightDecayOptimizerTest): use_resource=True) - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/optimizer_v2/BUILD b/tensorflow/contrib/optimizer_v2/BUILD index 5225ecc14fef3cec9506eceb776805b74a87714e..3ba3ee29ec79687df522eb330665a2ce80061682 100644 --- a/tensorflow/contrib/optimizer_v2/BUILD +++ b/tensorflow/contrib/optimizer_v2/BUILD @@ -193,6 +193,7 @@ cuda_py_test( srcs = ["rmsprop_test.py"], additional_deps = [ ":training", + "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", "//tensorflow/python:embedding_ops", "//tensorflow/python:framework", diff --git a/tensorflow/contrib/optimizer_v2/adam.py b/tensorflow/contrib/optimizer_v2/adam.py index d538ad0fb02699ed8514f512208914f629a47436..631d4f44dfb646541244bfe1d15136dd29f02703 100644 --- a/tensorflow/contrib/optimizer_v2/adam.py +++ b/tensorflow/contrib/optimizer_v2/adam.py @@ -103,9 +103,9 @@ class AdamOptimizer(optimizer_v2.OptimizerV2): def _create_vars(self, var_list, state): # Non-slot variables end up on the same device(s). - state.create_non_slot(initial_value=state.get_hyper("beta1"), + state.create_non_slot(initial_value=lambda: state.get_hyper("beta1"), name="beta1_power") - state.create_non_slot(initial_value=state.get_hyper("beta2"), + state.create_non_slot(initial_value=lambda: state.get_hyper("beta2"), name="beta2_power") # Create slots for the first and second moments. diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index 64b95786b5c7a71ee514201d8eb60c26975938b5..28a531dfecf275c48fea54310b93b5266a79899a 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -41,17 +41,18 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import template from tensorflow.python.ops import variable_scope +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as core_saver from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.checkpointable import util as checkpointable_utils +from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.checkpointable import util -class NonLayerCheckpointable(checkpointable.Checkpointable): +class NonLayerCheckpointable(tracking.Checkpointable): def __init__(self): super(NonLayerCheckpointable, self).__init__() - self.a_variable = checkpointable_utils.add_variable( + self.a_variable = util.add_variable( self, name="a_variable", shape=[]) @@ -88,29 +89,6 @@ class _MirroringSaveable( self._mirrored_variable.assign(tensor)) -class _OwnsMirroredVariables(checkpointable.CheckpointableBase): - """A Checkpointable object which returns a more complex SaveableObject.""" - - def __init__(self): - self.non_dep_variable = variable_scope.get_variable( - name="non_dep_variable", initializer=6., use_resource=True) - self.mirrored = variable_scope.get_variable( - name="mirrored", initializer=15., use_resource=True) - - def _gather_saveables_for_checkpoint(self): - def _saveable_factory(name=self.non_dep_variable.name): - return _MirroringSaveable( - primary_variable=self.non_dep_variable, - mirrored_variable=self.mirrored, - name=name) - return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} - - # The Saver sorts by name before parsing, so we need a name property. - @property - def name(self): - return self.non_dep_variable.name - - class CheckpointingTests(test.TestCase): @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) @@ -122,7 +100,7 @@ class CheckpointingTests(test.TestCase): other_model = MyModel() optimizer = adam.AdamOptimizer(0.001) optimizer_step = training_util.get_or_create_global_step() - root_checkpointable = checkpointable_utils.Checkpoint( + root_checkpointable = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=optimizer_step) if context.executing_eagerly(): optimizer.minimize( @@ -137,11 +115,11 @@ class CheckpointingTests(test.TestCase): optimizer.minimize( other_model(input_value), global_step=optimizer_step) - self.evaluate(checkpointable_utils.gather_initializers( + self.evaluate(util.gather_initializers( root_checkpointable)) self.evaluate(train_op) named_variables, serialized_graph, _ = ( - checkpointable_utils._serialize_object_graph( + util._serialize_object_graph( root_checkpointable, saveables_cache=None)) expected_checkpoint_names = ( # Created in the root node, so no prefix. @@ -226,11 +204,11 @@ class CheckpointingTests(test.TestCase): optimizer_node.slot_variables[0] .slot_variable_node_id].attributes[0].checkpoint_key) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveRestore(self): model = MyModel() optimizer = adam.AdamOptimizer(0.001) - root_checkpointable = checkpointable_utils.Checkpoint( + root_checkpointable = util.Checkpoint( optimizer=optimizer, model=model) input_value = constant_op.constant([[3.]]) if context.executing_eagerly(): @@ -240,7 +218,7 @@ class CheckpointingTests(test.TestCase): train_op = optimizer.minimize(model(input_value)) # TODO(allenl): Make initialization more pleasant when graph building. root_checkpointable.save_counter # pylint: disable=pointless-statement - self.evaluate(checkpointable_utils.gather_initializers( + self.evaluate(util.gather_initializers( root_checkpointable)) self.evaluate(train_op) prefix = os.path.join(self.get_temp_dir(), "ckpt") @@ -266,7 +244,7 @@ class CheckpointingTests(test.TestCase): # Preserve beta1_power and beta2_power when appying gradients so we can # test that they've been restored correctly. beta1=1.0, beta2=1.0) - on_create_root = checkpointable_utils.Checkpoint( + on_create_root = util.Checkpoint( optimizer=on_create_optimizer, model=on_create_model) # Deferred restoration status = on_create_root.restore(save_path=save_path) @@ -298,10 +276,11 @@ class CheckpointingTests(test.TestCase): for training_continuation in range(3): model = MyModel() optimizer = adam.AdamOptimizer(0.001) - root = checkpointable_utils.Checkpoint( + root = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=training_util.get_or_create_global_step()) - root.restore(core_saver.latest_checkpoint(checkpoint_directory)) + root.restore(checkpoint_management.latest_checkpoint( + checkpoint_directory)) for _ in range(num_training_steps): # TODO(allenl): Use a Dataset and serialize/checkpoint it. input_value = constant_op.constant([[3.]]) @@ -322,14 +301,15 @@ class CheckpointingTests(test.TestCase): with ops.Graph().as_default(): model = MyModel() optimizer = adam.AdamOptimizer(0.001) - root = checkpointable_utils.Checkpoint( + root = util.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) input_value = constant_op.constant([[3.]]) train_op = optimizer.minimize( model(input_value), global_step=root.global_step) - checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) with self.test_session(graph=ops.get_default_graph()) as session: status = root.restore(save_path=checkpoint_path) status.initialize_or_restore(session=session) @@ -347,7 +327,7 @@ class CheckpointingTests(test.TestCase): self.assertEqual(training_continuation + 1, session.run(root.save_counter)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAgnosticUsage(self): """Graph/eager agnostic usage.""" # Does create garbage when executing eagerly due to ops.Graph() creation. @@ -359,10 +339,11 @@ class CheckpointingTests(test.TestCase): graph=ops.get_default_graph()), test_util.device(use_gpu=True): model = MyModel() optimizer = adam.AdamOptimizer(0.001) - root = checkpointable_utils.Checkpoint( + root = util.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) - checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) status = root.restore(save_path=checkpoint_path) input_value = constant_op.constant([[3.]]) train_fn = functools.partial( @@ -381,7 +362,7 @@ class CheckpointingTests(test.TestCase): self.evaluate(root.save_counter)) # pylint: disable=cell-var-from-loop - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testWithDefun(self): num_training_steps = 2 checkpoint_directory = self.get_temp_dir() @@ -392,10 +373,11 @@ class CheckpointingTests(test.TestCase): model = MyModel() # Don't actually train so we can test variable values optimizer = adam.AdamOptimizer(0.) - root = checkpointable_utils.Checkpoint( + root = util.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) - checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) status = root.restore(save_path=checkpoint_path) def train_fn(): @function.defun @@ -442,7 +424,7 @@ class CheckpointingTests(test.TestCase): optimizer = adam.AdamOptimizer(learning_rate=0.05) checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - checkpoint = checkpointable_utils.Checkpoint( + checkpoint = util.Checkpoint( model=model, optimizer=optimizer) for _ in range(2): checkpoint.save(checkpoint_prefix) @@ -453,12 +435,12 @@ class CheckpointingTests(test.TestCase): optimizer.apply_gradients( [(g, v) for g, v in zip(grad, model.vars)]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDeferredSlotRestoration(self): checkpoint_directory = self.get_temp_dir() - root = checkpointable.Checkpointable() - root.var = checkpointable_utils.add_variable( + root = tracking.Checkpointable() + root.var = util.add_variable( root, name="var", initializer=0.) optimizer = adam.AdamOptimizer(0.1) if context.executing_eagerly(): @@ -468,28 +450,28 @@ class CheckpointingTests(test.TestCase): # Note that `optimizer` has not been added as a dependency of # `root`. Create a one-off grouping so that slot variables for `root.var` # get initialized too. - self.evaluate(checkpointable_utils.gather_initializers( - checkpointable_utils.Checkpoint(root=root, optimizer=optimizer))) + self.evaluate(util.gather_initializers( + util.Checkpoint(root=root, optimizer=optimizer))) self.evaluate(train_op) self.evaluate(state_ops.assign(root.var, 12.)) - no_slots_path = checkpointable_utils.CheckpointableSaver(root).save( + no_slots_path = util.CheckpointableSaver(root).save( os.path.join(checkpoint_directory, "no_slots")) root.optimizer = optimizer self.evaluate(state_ops.assign(root.var, 13.)) self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var), 14.)) - slots_path = checkpointable_utils.CheckpointableSaver(root).save( + slots_path = util.CheckpointableSaver(root).save( os.path.join(checkpoint_directory, "with_slots")) - new_root = checkpointable.Checkpointable() + new_root = tracking.Checkpointable() # Load the slot-containing checkpoint (deferred), then immediately overwrite # the non-slot variable (also deferred). - slot_status = checkpointable_utils.CheckpointableSaver( + slot_status = util.CheckpointableSaver( new_root).restore(slots_path) - no_slot_status = checkpointable_utils.CheckpointableSaver( + no_slot_status = util.CheckpointableSaver( new_root).restore(no_slots_path) with self.assertRaises(AssertionError): no_slot_status.assert_consumed() - new_root.var = checkpointable_utils.add_variable( + new_root.var = util.add_variable( new_root, name="var", shape=[]) no_slot_status.assert_consumed() no_slot_status.run_restore_ops() @@ -525,12 +507,12 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.test_session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() + obj = tracking.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) - self.evaluate(checkpointable_utils.gather_initializers(obj)) - saver = checkpointable_utils.CheckpointableSaver(obj) + self.evaluate(util.gather_initializers(obj)) + saver = util.CheckpointableSaver(obj) saver.save(checkpoint_prefix) before_ops = graph.get_operations() saver.save(checkpoint_prefix) @@ -543,12 +525,12 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.test_session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() + obj = tracking.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) - self.evaluate(checkpointable_utils.gather_initializers(obj)) - saver = checkpointable_utils.CheckpointableSaver(obj) + self.evaluate(util.gather_initializers(obj)) + saver = util.CheckpointableSaver(obj) save_path = saver.save(checkpoint_prefix) saver.restore(save_path) before_ops = graph.get_operations() @@ -565,10 +547,10 @@ class CheckpointingTests(test.TestCase): first_session = session_lib.Session(graph=first_graph) with first_graph.as_default(), first_session.as_default(): first_variable = resource_variable_ops.ResourceVariable([1.]) - first_root_checkpointable = checkpointable_utils.Checkpoint( + first_root_checkpointable = util.Checkpoint( optimizer=optimizer, variable=first_variable) train_op = optimizer.minimize(first_variable.read_value) - self.evaluate(checkpointable_utils.gather_initializers( + self.evaluate(util.gather_initializers( first_root_checkpointable)) self.evaluate(train_op) self.evaluate(first_variable.assign([1.])) @@ -581,7 +563,7 @@ class CheckpointingTests(test.TestCase): second_graph = ops.Graph() with second_graph.as_default(), session_lib.Session(graph=second_graph): second_variable = resource_variable_ops.ResourceVariable([1.]) - second_root_checkpointable = checkpointable_utils.Checkpoint( + second_root_checkpointable = util.Checkpoint( optimizer=optimizer, variable=second_variable) train_op = optimizer.minimize(second_variable.read_value) second_root_checkpointable.restore(None).initialize_or_restore() @@ -616,7 +598,7 @@ class CheckpointingTests(test.TestCase): class TemplateTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_checkpointable_save_restore(self): def _templated(): @@ -631,7 +613,7 @@ class TemplateTests(test.TestCase): save_template = template.make_template("s1", _templated) v1_save, _, v2_save = save_template() optimizer = adam.AdamOptimizer(0.0) - save_root = checkpointable_utils.Checkpoint( + save_root = util.Checkpoint( my_template=save_template, optimizer=optimizer) optimizer.minimize(v1_save.read_value) self.evaluate([v.initializer for v in optimizer.variables()]) @@ -643,7 +625,7 @@ class TemplateTests(test.TestCase): load_template = template.make_template("s2", _templated) load_optimizer = adam.AdamOptimizer(0.0) - load_root = checkpointable_utils.Checkpoint( + load_root = util.Checkpoint( my_template=load_template, optimizer=load_optimizer) status = load_root.restore(save_path) var, var_plus_one, var2 = load_template() @@ -664,12 +646,12 @@ class CheckpointCompatibilityTests(test.TestCase): model = MyModel() optimizer = adam.AdamOptimizer(0.001) optimizer_step = training_util.get_or_create_global_step() - root_checkpointable = checkpointable_utils.Checkpoint( + root_checkpointable = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=optimizer_step) train_op = optimizer.minimize( functools.partial(model, input_value), global_step=optimizer_step) - self.evaluate(checkpointable_utils.gather_initializers( + self.evaluate(util.gather_initializers( root_checkpointable)) self.evaluate(train_op) # A regular variable, a slot variable, and a non-slot Optimizer variable @@ -712,7 +694,7 @@ class CheckpointCompatibilityTests(test.TestCase): sess=session, save_path=checkpoint_prefix, global_step=root.optimizer_step) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLoadFromNameBasedSaver(self): """Save a name-based checkpoint, load it using the object-based API.""" with test_util.device(use_gpu=True): @@ -721,7 +703,7 @@ class CheckpointCompatibilityTests(test.TestCase): self._set_sentinels(root) with self.assertRaises(AssertionError): self._check_sentinels(root) - object_saver = checkpointable_utils.CheckpointableSaver(root) + object_saver = util.CheckpointableSaver(root) self._set_sentinels(root) status = object_saver.restore(save_path) if context.executing_eagerly(): diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index f537318b32986c941b6c41eb363929e906027dd7..f6ecaba834600f7477453fb63842941c6a6e1a04 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -34,6 +34,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import optimizer as optimizer_v1 from tensorflow.python.training import slot_creator from tensorflow.python.training.checkpointable import base as checkpointable @@ -162,12 +163,12 @@ def _get_processor(v): def _var_key_v2(var): """Key for representing a primary variable, for looking up slots.""" # pylint: disable=protected-access - if hasattr(var, "_mirrored_container"): - mirrored_container = var._mirrored_container() - assert mirrored_container is not None + if hasattr(var, "_distributed_container"): + distributed_container = var._distributed_container() + assert distributed_container is not None if context.executing_eagerly(): - return mirrored_container._unique_id - return mirrored_container._shared_name + return distributed_container._unique_id + return distributed_container._shared_name if context.executing_eagerly(): return var._unique_id return var.op.name @@ -211,8 +212,9 @@ class _OptimizerV2State(object): # This dict starts with a single item with key "None" with the hyper # parameter value converted to a Tensor. Other items have dtype keys # with that Tensor cast to that dtype. - self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)} - for name, (dynamic, value) in hyper.items() if not dynamic} + with ops.init_scope(): + self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)} + for name, (dynamic, value) in hyper.items() if not dynamic} self._slots = {} self._non_slot_dict = {} # Extra state to help Optimizers implement Checkpointable. Holds information @@ -619,7 +621,7 @@ class OptimizerV2(optimizer_v1.Optimizer): # Map from graph_key to state for that graph. We use the graph_key # since it works in both eager and graph mode, and gives the outer # graph inside functions. - tower_context = distribute_lib.get_tower_context() + tower_context = distribution_strategy_context.get_tower_context() if tower_context is None: # In a cross-tower context for a DistributionStrategy, which means # only one Optimizer will be created, not one per tower. @@ -765,9 +767,11 @@ class OptimizerV2(optimizer_v1.Optimizer): # *after* loss() is evaluated, so we know what loss reduction it uses. if scale_loss_by_num_towers is None: scale_loss_by_num_towers = ( - distribute_lib.get_loss_reduction() == "mean") + distribute_lib.get_loss_reduction() == + variable_scope.VariableAggregation.MEAN) if scale_loss_by_num_towers: - num_towers = distribute_lib.get_distribution_strategy().num_towers + num_towers = distribution_strategy_context.get_distribution_strategy( + ).num_towers if num_towers > 1: loss_value *= 1. / num_towers @@ -783,9 +787,11 @@ class OptimizerV2(optimizer_v1.Optimizer): # Scale loss for number of towers (non-callable-loss case). if scale_loss_by_num_towers is None: scale_loss_by_num_towers = ( - distribute_lib.get_loss_reduction() == "mean") + distribute_lib.get_loss_reduction() == + variable_scope.VariableAggregation.MEAN) if scale_loss_by_num_towers: - num_towers = distribute_lib.get_distribution_strategy().num_towers + num_towers = distribution_strategy_context.get_distribution_strategy( + ).num_towers if num_towers > 1: loss *= 1. / num_towers @@ -859,7 +865,7 @@ class OptimizerV2(optimizer_v1.Optimizer): if not filtered: raise ValueError("No gradients provided for any variable: %s." % ([str(v) for _, v in grads_and_vars],)) - return distribute_lib.get_tower_context().merge_call( + return distribution_strategy_context.get_tower_context().merge_call( self._distributed_apply, filtered, global_step=global_step, name=name) def _get_or_create_state(self, var_list=None): @@ -895,7 +901,8 @@ class OptimizerV2(optimizer_v1.Optimizer): def _distributed_apply(self, distribution, grads_and_vars, global_step, name): """`apply_gradients` for use with a `DistributionStrategy`.""" - reduced_grads = distribution.batch_reduce("sum", grads_and_vars) + reduced_grads = distribution.batch_reduce( + variable_scope.VariableAggregation.SUM, grads_and_vars) var_list = [v for _, v in grads_and_vars] grads_and_vars = zip(reduced_grads, var_list) diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py index 8599af32f6f4cc5529cd812e83c02ef3812cb71e..a44bfd1bfd97e678fbf4c402ef5b1298dc518f75 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py @@ -35,15 +35,11 @@ from tensorflow.python.platform import test class OptimizerTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBasic(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - # Note that we name the variables uniquely here since the variables don't - # seem to be getting deleted at the end of the loop. - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, - name='a_%d' % i) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, - name='b_%d' % i) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) def loss(): return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop # Note that for eager execution, minimize expects a function instead of a @@ -113,7 +109,7 @@ class OptimizerTest(test.TestCase): self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)], var1.eval()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoVariables(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: # pylint: disable=cell-var-from-loop @@ -128,15 +124,11 @@ class OptimizerTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'No.*variables'): sgd_op.minimize(loss) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoGradients(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - # Note that we name the variables uniquely here since the variables don't - # seem to be getting deleted at the end of the loop. - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, - name='a%d' % i) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, - name='b%d' % i) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) # pylint: disable=cell-var-from-loop def loss(): return 5 * var0 @@ -146,15 +138,11 @@ class OptimizerTest(test.TestCase): # var1 has no gradient sgd_op.minimize(loss, var_list=[var1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoGradientsForAnyVariables_Minimize(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - # Note that we name the variables uniquely here since the variables don't - # seem to be getting deleted at the end of the loop. - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, - name='a_%d' % i) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, - name='b_%d' % i) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) def loss(): return constant_op.constant(5.0) sgd_op = gradient_descent.GradientDescentOptimizer(3.0) @@ -162,29 +150,21 @@ class OptimizerTest(test.TestCase): 'No gradients provided for any variable'): sgd_op.minimize(loss, var_list=[var0, var1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoGradientsForAnyVariables_ApplyGradients(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - # Note that we name the variables uniquely here since the variables don't - # seem to be getting deleted at the end of the loop. - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, - name='a_%d' % i) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, - name='b_%d' % i) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) sgd_op = gradient_descent.GradientDescentOptimizer(3.0) with self.assertRaisesRegexp(ValueError, 'No gradients provided for any variable'): sgd_op.apply_gradients([(None, var0), (None, var1)]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientsAsVariables(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - # Note that we name the variables uniquely here since the variables don't - # seem to be getting deleted at the end of the loop. - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, - name='a%d' % i) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, - name='b%d' % i) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) def loss(): return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop sgd_op = gradient_descent.GradientDescentOptimizer(3.0) @@ -216,7 +196,7 @@ class OptimizerTest(test.TestCase): self.assertAllClose([-14., -13.], self.evaluate(var0)) self.assertAllClose([-6., -5.], self.evaluate(var1)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testComputeGradientsWithTensors(self): x = ops.convert_to_tensor(1.0) def f(): diff --git a/tensorflow/contrib/optimizer_v2/rmsprop_test.py b/tensorflow/contrib/optimizer_v2/rmsprop_test.py index ed68f6afbf8bf9678649c1ce6fc59c3b91026dc0..dc23ef241a43900ed40f029f1b857820459e43d0 100644 --- a/tensorflow/contrib/optimizer_v2/rmsprop_test.py +++ b/tensorflow/contrib/optimizer_v2/rmsprop_test.py @@ -19,15 +19,16 @@ from __future__ import division from __future__ import print_function import copy -import itertools import math +from absl.testing import parameterized import numpy as np from tensorflow.contrib.optimizer_v2 import rmsprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -48,13 +49,8 @@ _TEST_PARAM_VALUES = [ [0.5, 0.95, 0.9, 1e-5, True, False], ] -_TESTPARAMS = [ - [data_type] + values - for data_type, values in itertools.product(_DATA_TYPES, _TEST_PARAM_VALUES) -] - -class RMSPropOptimizerTest(test.TestCase): +class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase): def _rmsprop_update_numpy(self, var, g, mg, rms, mom, lr, decay, momentum, epsilon, centered): @@ -87,362 +83,366 @@ class RMSPropOptimizerTest(test.TestCase): var_t[gindex] = var[gindex] - mom_t[gindex] return var_t, mg_t, rms_t, mom_t - def testDense(self): - # TODO(yori): Use ParameterizedTest when available - for (dtype, learning_rate, decay, momentum, - epsilon, centered, use_resource) in _TESTPARAMS: - with self.test_session(use_gpu=True): - # Initialize variables for numpy implementation. - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.2], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.2], dtype=dtype.as_numpy_dtype) - - if use_resource: - var0 = resource_variable_ops.ResourceVariable(var0_np) - var1 = resource_variable_ops.ResourceVariable(var1_np) - else: - var0 = variables.Variable(var0_np) - var1 = variables.Variable(var1_np) - grads0 = constant_op.constant(grads0_np) - grads1 = constant_op.constant(grads1_np) - opt = rmsprop.RMSPropOptimizer( - learning_rate=learning_rate, - decay=decay, - momentum=momentum, - epsilon=epsilon, - centered=centered) - - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - - mg0 = opt.get_slot(var0, "mg") - self.assertEqual(mg0 is not None, centered) - mg1 = opt.get_slot(var1, "mg") - self.assertEqual(mg1 is not None, centered) - rms0 = opt.get_slot(var0, "rms") - self.assertTrue(rms0 is not None) - rms1 = opt.get_slot(var1, "rms") - self.assertTrue(rms1 is not None) - mom0 = opt.get_slot(var0, "momentum") - self.assertTrue(mom0 is not None) - mom1 = opt.get_slot(var1, "momentum") - self.assertTrue(mom1 is not None) - - mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) - rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) - mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - - # Run 4 steps of RMSProp - for _ in range(1, 5): - update.run() - - var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy( - var0_np, grads0_np, mg0_np, rms0_np, mom0_np, learning_rate, - decay, momentum, epsilon, centered) - var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy( - var1_np, grads1_np, mg1_np, rms1_np, mom1_np, learning_rate, - decay, momentum, epsilon, centered) - - # Validate updated params - if centered: - self.assertAllCloseAccordingToType(mg0_np, mg0.eval()) - self.assertAllCloseAccordingToType(mg1_np, mg1.eval()) - self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) - self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) - self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) - self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) - - def testMinimizeSparseResourceVariable(self): - for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) - x = constant_op.constant([[4.0], [5.0]], dtype=dtype) - pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) - loss = pred * pred - sgd_op = rmsprop.RMSPropOptimizer( - learning_rate=1.0, - decay=0.0, - momentum=0.0, - epsilon=0.0, - centered=False).minimize(loss) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - self.assertAllCloseAccordingToType( - [[0., 1.]], var0.eval(), atol=0.01) - - def testMinimizeSparseResourceVariableCentered(self): - for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) - x = constant_op.constant([[4.0], [5.0]], dtype=dtype) - pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) - loss = pred * pred - sgd_op = rmsprop.RMSPropOptimizer( - learning_rate=1.0, - decay=0.0, - momentum=0.0, - epsilon=1.0, - centered=True).minimize(loss) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - self.assertAllCloseAccordingToType( - [[-111, -138]], var0.eval(), atol=0.01) - - def testSparse(self): - # TODO(yori): Use ParameterizedTest when available - for (dtype, learning_rate, decay, - momentum, epsilon, centered, _) in _TESTPARAMS: - with self.test_session(use_gpu=True): - # Initialize variables for numpy implementation. - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01], dtype=dtype.as_numpy_dtype) - + @parameterized.named_parameters( + *test_util.generate_combinations_with_testcase_name( + dtype=_DATA_TYPES, param_value=_TEST_PARAM_VALUES)) + def testDense(self, dtype, param_value): + (learning_rate, decay, momentum, epsilon, centered, use_resource) = tuple( + param_value) + with self.test_session(use_gpu=True): + # Initialize variables for numpy implementation. + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.2], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.2], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: var0 = variables.Variable(var0_np) var1 = variables.Variable(var1_np) - grads0_np_indices = np.array([0], dtype=np.int32) - grads0 = ops.IndexedSlices( - constant_op.constant(grads0_np), - constant_op.constant(grads0_np_indices), constant_op.constant([1])) - grads1_np_indices = np.array([1], dtype=np.int32) - grads1 = ops.IndexedSlices( - constant_op.constant(grads1_np), - constant_op.constant(grads1_np_indices), constant_op.constant([1])) - opt = rmsprop.RMSPropOptimizer( - learning_rate=learning_rate, - decay=decay, - momentum=momentum, - epsilon=epsilon, - centered=centered) - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - - mg0 = opt.get_slot(var0, "mg") - self.assertEqual(mg0 is not None, centered) - mg1 = opt.get_slot(var1, "mg") - self.assertEqual(mg1 is not None, centered) - rms0 = opt.get_slot(var0, "rms") - self.assertTrue(rms0 is not None) - rms1 = opt.get_slot(var1, "rms") - self.assertTrue(rms1 is not None) - mom0 = opt.get_slot(var0, "momentum") - self.assertTrue(mom0 is not None) - mom1 = opt.get_slot(var1, "momentum") - self.assertTrue(mom1 is not None) - - mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) - rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) - mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - - # Run 4 steps of RMSProp - for _ in range(1, 5): - update.run() - - var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy( - var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np, mom0_np, - learning_rate, decay, momentum, epsilon, centered) - var1_np, mg1_np, rms1_np, mom1_np = self._sparse_rmsprop_update_numpy( - var1_np, grads1_np_indices, grads1_np, mg1_np, rms1_np, mom1_np, - learning_rate, decay, momentum, epsilon, centered) - - # Validate updated params - if centered: - self.assertAllCloseAccordingToType(mg0_np, mg0.eval()) - self.assertAllCloseAccordingToType(mg1_np, mg1.eval()) - self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) - self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) - self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) - self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) - - def testWithoutMomentum(self): - for dtype in [dtypes.half, dtypes.float32]: - with self.test_session(use_gpu=True): - var0 = variables.Variable([1.0, 2.0], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - opt = rmsprop.RMSPropOptimizer( - learning_rate=2.0, decay=0.9, momentum=0.0, epsilon=1.0) - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - - rms0 = opt.get_slot(var0, "rms") - self.assertTrue(rms0 is not None) - rms1 = opt.get_slot(var1, "rms") - self.assertTrue(rms1 is not None) - mom0 = opt.get_slot(var0, "momentum") - self.assertTrue(mom0 is not None) - mom1 = opt.get_slot(var1, "momentum") - self.assertTrue(mom1 is not None) - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - # Step 1: the rms accumulators where 1. So we should see a normal - # update: v -= grad * learning_rate - update.run() - # Check the root mean square accumulators. - self.assertAllCloseAccordingToType( - np.array([0.901, 0.901]), rms0.eval()) - self.assertAllCloseAccordingToType( - np.array([0.90001, 0.90001]), rms1.eval()) - # Check the parameters. - self.assertAllCloseAccordingToType( - np.array([ - 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)), - 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - ]), var0.eval()) - self.assertAllCloseAccordingToType( - np.array([ - 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)), - 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - ]), var1.eval()) - # Step 2: the root mean square accumulators contain the previous update. - update.run() - # Check the rms accumulators. - self.assertAllCloseAccordingToType( - np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval()) - self.assertAllCloseAccordingToType( - np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval()) - # Check the parameters. - self.assertAllCloseAccordingToType( - np.array([ - 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)), - 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)) - ]), var0.eval()) - self.assertAllCloseAccordingToType( - np.array([ - 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)), - 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)) - ]), var1.eval()) - - def testWithMomentum(self): - for dtype in [dtypes.half, dtypes.float32]: - with self.test_session(use_gpu=True): - var0 = variables.Variable([1.0, 2.0], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - - opt = rmsprop.RMSPropOptimizer( - learning_rate=2.0, decay=0.9, momentum=0.5, epsilon=1e-5) - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - - rms0 = opt.get_slot(var0, "rms") - self.assertTrue(rms0 is not None) - rms1 = opt.get_slot(var1, "rms") - self.assertTrue(rms1 is not None) - mom0 = opt.get_slot(var0, "momentum") - self.assertTrue(mom0 is not None) - mom1 = opt.get_slot(var1, "momentum") - self.assertTrue(mom1 is not None) - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - # Step 1: rms = 1, mom = 0. So we should see a normal - # update: v -= grad * learning_rate + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = rmsprop.RMSPropOptimizer( + learning_rate=learning_rate, + decay=decay, + momentum=momentum, + epsilon=epsilon, + centered=centered) + + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + mg0 = opt.get_slot(var0, "mg") + self.assertEqual(mg0 is not None, centered) + mg1 = opt.get_slot(var1, "mg") + self.assertEqual(mg1 is not None, centered) + rms0 = opt.get_slot(var0, "rms") + self.assertIsNotNone(rms0) + rms1 = opt.get_slot(var1, "rms") + self.assertIsNotNone(rms1) + mom0 = opt.get_slot(var0, "momentum") + self.assertIsNotNone(mom0) + mom1 = opt.get_slot(var1, "momentum") + self.assertIsNotNone(mom1) + + mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) + rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) + mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 4 steps of RMSProp + for _ in range(4): update.run() - # Check the root mean square accumulators. - self.assertAllCloseAccordingToType( - np.array([0.901, 0.901]), rms0.eval()) - self.assertAllCloseAccordingToType( - np.array([0.90001, 0.90001]), rms1.eval()) - # Check the momentum accumulators - self.assertAllCloseAccordingToType( - np.array([(0.1 * 2.0 / math.sqrt(0.901 + 1e-5)), - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5))]), mom0.eval()) - self.assertAllCloseAccordingToType( - np.array([(0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)), - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5))]), mom1.eval()) - - # Check that the parameters. - self.assertAllCloseAccordingToType( - np.array([ - 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)), - 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) - ]), var0.eval()) - self.assertAllCloseAccordingToType( - np.array([ - 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)), - 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) - ]), var1.eval()) - - # Step 2: the root mean square accumulators contain the previous update. + + var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy( + var0_np, grads0_np, mg0_np, rms0_np, mom0_np, learning_rate, + decay, momentum, epsilon, centered) + var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy( + var1_np, grads1_np, mg1_np, rms1_np, mom1_np, learning_rate, + decay, momentum, epsilon, centered) + + # Validate updated params + if centered: + self.assertAllCloseAccordingToType(mg0_np, mg0.eval()) + self.assertAllCloseAccordingToType(mg1_np, mg1.eval()) + self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) + self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) + self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) + self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + @parameterized.parameters([dtypes.float32, dtypes.float64]) + def testMinimizeSparseResourceVariable(self, dtype): + with self.test_session(): + var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) + x = constant_op.constant([[4.0], [5.0]], dtype=dtype) + pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) + loss = pred * pred + sgd_op = rmsprop.RMSPropOptimizer( + learning_rate=1.0, + decay=0.0, + momentum=0.0, + epsilon=0.0, + centered=False).minimize(loss) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + self.assertAllCloseAccordingToType( + [[0., 1.]], var0.eval(), atol=0.01) + + @parameterized.parameters([dtypes.float32, dtypes.float64]) + def testMinimizeSparseResourceVariableCentered(self, dtype): + with self.test_session(): + var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) + x = constant_op.constant([[4.0], [5.0]], dtype=dtype) + pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) + loss = pred * pred + sgd_op = rmsprop.RMSPropOptimizer( + learning_rate=1.0, + decay=0.0, + momentum=0.0, + epsilon=1.0, + centered=True).minimize(loss) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + self.assertAllCloseAccordingToType( + [[-111, -138]], var0.eval(), atol=0.01) + + @parameterized.named_parameters( + *test_util.generate_combinations_with_testcase_name( + dtype=_DATA_TYPES, param_value=_TEST_PARAM_VALUES)) + def testSparse(self, dtype, param_value): + (learning_rate, decay, momentum, epsilon, centered, _) = tuple( + param_value) + with self.test_session(use_gpu=True): + # Initialize variables for numpy implementation. + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0_np_indices = np.array([0], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), constant_op.constant([1])) + grads1_np_indices = np.array([1], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), constant_op.constant([1])) + opt = rmsprop.RMSPropOptimizer( + learning_rate=learning_rate, + decay=decay, + momentum=momentum, + epsilon=epsilon, + centered=centered) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + mg0 = opt.get_slot(var0, "mg") + self.assertEqual(mg0 is not None, centered) + mg1 = opt.get_slot(var1, "mg") + self.assertEqual(mg1 is not None, centered) + rms0 = opt.get_slot(var0, "rms") + self.assertIsNotNone(rms0) + rms1 = opt.get_slot(var1, "rms") + self.assertIsNotNone(rms1) + mom0 = opt.get_slot(var0, "momentum") + self.assertIsNotNone(mom0) + mom1 = opt.get_slot(var1, "momentum") + self.assertIsNotNone(mom1) + + mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) + rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) + mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 4 steps of RMSProp + for _ in range(4): update.run() - # Check the rms accumulators. - self.assertAllCloseAccordingToType( - np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval()) - self.assertAllCloseAccordingToType( - np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval()) - self.assertAllCloseAccordingToType( - np.array([ - 0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)), - 0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)) - ]), mom0.eval()) - self.assertAllCloseAccordingToType( - np.array([ - 0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)), - 0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)) - ]), mom1.eval()) - - # Check the parameters. - self.assertAllCloseAccordingToType( - np.array([ - 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) - - (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))), - 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) - - (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))) - ]), var0.eval()) - - self.assertAllCloseAccordingToType( - np.array([ - 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) - - (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))), - 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) - - (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))) - ]), var1.eval()) + + var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy( + var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np, mom0_np, + learning_rate, decay, momentum, epsilon, centered) + var1_np, mg1_np, rms1_np, mom1_np = self._sparse_rmsprop_update_numpy( + var1_np, grads1_np_indices, grads1_np, mg1_np, rms1_np, mom1_np, + learning_rate, decay, momentum, epsilon, centered) + + # Validate updated params + if centered: + self.assertAllCloseAccordingToType(mg0_np, mg0.eval()) + self.assertAllCloseAccordingToType(mg1_np, mg1.eval()) + self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) + self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) + self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) + self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + @parameterized.parameters(_DATA_TYPES) + def testWithoutMomentum(self, dtype): + with self.test_session(use_gpu=True): + var0 = variables.Variable([1.0, 2.0], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + opt = rmsprop.RMSPropOptimizer( + learning_rate=2.0, decay=0.9, momentum=0.0, epsilon=1.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + rms0 = opt.get_slot(var0, "rms") + self.assertIsNotNone(rms0) + rms1 = opt.get_slot(var1, "rms") + self.assertIsNotNone(rms1) + mom0 = opt.get_slot(var0, "momentum") + self.assertIsNotNone(mom0) + mom1 = opt.get_slot(var1, "momentum") + self.assertIsNotNone(mom1) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + # Step 1: the rms accumulators where 1. So we should see a normal + # update: v -= grad * learning_rate + update.run() + # Check the root mean square accumulators. + self.assertAllCloseAccordingToType( + np.array([0.901, 0.901]), rms0.eval()) + self.assertAllCloseAccordingToType( + np.array([0.90001, 0.90001]), rms1.eval()) + # Check the parameters. + self.assertAllCloseAccordingToType( + np.array([ + 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)), + 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) + ]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([ + 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)), + 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) + ]), var1.eval()) + # Step 2: the root mean square accumulators contain the previous update. + update.run() + # Check the rms accumulators. + self.assertAllCloseAccordingToType( + np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval()) + self.assertAllCloseAccordingToType( + np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval()) + # Check the parameters. + self.assertAllCloseAccordingToType( + np.array([ + 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)), + 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)) + ]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([ + 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)), + 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)) + ]), var1.eval()) + + @parameterized.parameters(_DATA_TYPES) + def testWithMomentum(self, dtype): + with self.test_session(use_gpu=True): + var0 = variables.Variable([1.0, 2.0], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + + opt = rmsprop.RMSPropOptimizer( + learning_rate=2.0, decay=0.9, momentum=0.5, epsilon=1e-5) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + rms0 = opt.get_slot(var0, "rms") + self.assertIsNotNone(rms0) + rms1 = opt.get_slot(var1, "rms") + self.assertIsNotNone(rms1) + mom0 = opt.get_slot(var0, "momentum") + self.assertIsNotNone(mom0) + mom1 = opt.get_slot(var1, "momentum") + self.assertIsNotNone(mom1) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + # Step 1: rms = 1, mom = 0. So we should see a normal + # update: v -= grad * learning_rate + update.run() + # Check the root mean square accumulators. + self.assertAllCloseAccordingToType( + np.array([0.901, 0.901]), rms0.eval()) + self.assertAllCloseAccordingToType( + np.array([0.90001, 0.90001]), rms1.eval()) + # Check the momentum accumulators + self.assertAllCloseAccordingToType( + np.array([(0.1 * 2.0 / math.sqrt(0.901 + 1e-5)), + (0.1 * 2.0 / math.sqrt(0.901 + 1e-5))]), mom0.eval()) + self.assertAllCloseAccordingToType( + np.array([(0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)), + (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5))]), mom1.eval()) + + # Check that the parameters. + self.assertAllCloseAccordingToType( + np.array([ + 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)), + 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + ]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([ + 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)), + 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + ]), var1.eval()) + + # Step 2: the root mean square accumulators contain the previous update. + update.run() + # Check the rms accumulators. + self.assertAllCloseAccordingToType( + np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval()) + self.assertAllCloseAccordingToType( + np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval()) + self.assertAllCloseAccordingToType( + np.array([ + 0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)), + 0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)) + ]), mom0.eval()) + self.assertAllCloseAccordingToType( + np.array([ + 0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)), + 0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)) + ]), mom1.eval()) + + # Check the parameters. + self.assertAllCloseAccordingToType( + np.array([ + 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) - + (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))), + 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) - + (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))) + ]), var0.eval()) + + self.assertAllCloseAccordingToType( + np.array([ + 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) - + (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))), + 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) - + (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))) + ]), var1.eval()) if __name__ == "__main__": diff --git a/tensorflow/contrib/periodic_resample/BUILD b/tensorflow/contrib/periodic_resample/BUILD index 976b312e8345a801ad07f622b6117b88af2cf603..f2171efc959362c1e4392fefbd5842f0883571d7 100644 --- a/tensorflow/contrib/periodic_resample/BUILD +++ b/tensorflow/contrib/periodic_resample/BUILD @@ -97,6 +97,8 @@ tf_cc_test( ], deps = [ ":all_ops", + "//tensorflow/core:framework", + "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", ], diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc index 55edf76fcd3eed461e1465b569e1c2e9e2facbc0..43b7c1799ffb2e27f9d15bc6011d49334867b6ec 100644 --- a/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc +++ b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD index 36e21af618f5af744ce793509813eaf36e1b8479..72ea777ca7036bad91b15d8d2163fdee842b1e32 100644 --- a/tensorflow/contrib/predictor/BUILD +++ b/tensorflow/contrib/predictor/BUILD @@ -60,7 +60,7 @@ py_library( ":base_predictor", "//tensorflow/python:framework_ops", "//tensorflow/python:training", - "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/saved_model:signature_constants", ], ) @@ -90,9 +90,7 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python/estimator", - "//tensorflow/python/estimator:export", - "//tensorflow/python/estimator:export_output", - "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/saved_model:signature_constants", ], ) diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor.py b/tensorflow/contrib/predictor/contrib_estimator_predictor.py index af3b2ad1b531b835f484a155efcc57bbe634f2df..c2166594e598857065a7fd109ec599a3b36e2d2c 100644 --- a/tensorflow/contrib/predictor/contrib_estimator_predictor.py +++ b/tensorflow/contrib/predictor/contrib_estimator_predictor.py @@ -22,8 +22,8 @@ from __future__ import print_function from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils from tensorflow.contrib.predictor import predictor from tensorflow.python.framework import ops +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import monitored_session -from tensorflow.python.training import saver class ContribEstimatorPredictor(predictor.Predictor): @@ -57,7 +57,8 @@ class ContribEstimatorPredictor(predictor.Predictor): # pylint: disable=protected-access model_fn_ops = estimator._get_predict_ops(input_fn_ops.features) # pylint: enable=protected-access - checkpoint_path = saver.latest_checkpoint(estimator.model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint( + estimator.model_dir) self._session = monitored_session.MonitoredSession( session_creator=monitored_session.ChiefSessionCreator( config=config, diff --git a/tensorflow/contrib/predictor/predictor_factories.py b/tensorflow/contrib/predictor/predictor_factories.py index f275bc15adfa0a51a48964dff8edddbd45500e45..7886744b3ce7fc438bc73cb81bccfd0ddeea873e 100644 --- a/tensorflow/contrib/predictor/predictor_factories.py +++ b/tensorflow/contrib/predictor/predictor_factories.py @@ -108,6 +108,8 @@ def from_estimator(estimator, def from_saved_model(export_dir, signature_def_key=None, signature_def=None, + input_names=None, + output_names=None, tags=None, graph=None, config=None): @@ -121,6 +123,12 @@ def from_saved_model(export_dir, signature_def: A `SignatureDef` proto specifying the inputs and outputs for prediction. Only one of `signature_def_key` and `signature_def` should be specified. + input_names: A dictionary mapping strings to `Tensor`s in the `SavedModel` + that represent the input. The keys can be any string of the user's + choosing. + output_names: A dictionary mapping strings to `Tensor`s in the + `SavedModel` that represent the output. The keys can be any string of + the user's choosing. tags: Optional. Tags that will be used to retrieve the correct `SignatureDef`. Defaults to `DEFAULT_TAGS`. graph: Optional. The Tensorflow `graph` in which prediction should be @@ -138,6 +146,8 @@ def from_saved_model(export_dir, export_dir, signature_def_key=signature_def_key, signature_def=signature_def, + input_names=input_names, + output_names=output_names, tags=tags, graph=graph, config=config) diff --git a/tensorflow/contrib/proto/BUILD b/tensorflow/contrib/proto/BUILD index 3e9b1a0b8d8ec7c3c5fe5d1f2cf896dbb6c3de72..b27142cf4a6413eccb8489ea3eb775060ffd787b 100644 --- a/tensorflow/contrib/proto/BUILD +++ b/tensorflow/contrib/proto/BUILD @@ -16,17 +16,3 @@ py_library( "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", ], ) - -py_library( - name = "proto_pip", - data = [ - "//tensorflow/contrib/proto/python/kernel_tests:test_messages", - ] + if_static( - [], - otherwise = ["//tensorflow/contrib/proto/python/kernel_tests:libtestexample.so"], - ), - deps = [ - ":proto", - "//tensorflow/contrib/proto/python/kernel_tests:py_test_deps", - ], -) diff --git a/tensorflow/contrib/proto/python/kernel_tests/BUILD b/tensorflow/contrib/proto/python/kernel_tests/BUILD index a380a131f86abc8dd921a123afdb964bf6c2466c..125c1cee292092e55bc17294a29f175c8cc3999c 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/BUILD +++ b/tensorflow/contrib/proto/python/kernel_tests/BUILD @@ -4,47 +4,41 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -# Much of the work in this BUILD file actually happens in the corresponding -# build_defs.bzl, which creates an individual testcase for each example .pbtxt -# file in this directory. -# -load(":build_defs.bzl", "decode_proto_test_suite") -load(":build_defs.bzl", "encode_proto_test_suite") - -# This expands to a tf_py_test for each test file. -# It defines the test_suite :decode_proto_op_tests. -decode_proto_test_suite( - name = "decode_proto_tests", - examples = glob(["*.pbtxt"]), -) - -# This expands to a tf_py_test for each test file. -# It defines the test_suite :encode_proto_op_tests. -encode_proto_test_suite( - name = "encode_proto_tests", - examples = glob(["*.pbtxt"]), -) - -# Below here are tests that are not tied to an example text proto. -filegroup( - name = "test_messages", - srcs = glob(["*.pbtxt"]), -) - load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") tf_py_test( - name = "decode_proto_fail_test", + name = "decode_proto_op_test", size = "small", - srcs = ["decode_proto_fail_test.py"], + srcs = ["decode_proto_op_test.py"], additional_deps = [ + ":decode_proto_op_test_base", + ":py_test_deps", + "//tensorflow/contrib/proto:proto", + "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", + ], + data = if_static( + [], + otherwise = [":libtestexample.so"], + ), + tags = [ + "no_pip", # TODO(b/78026780) + "no_windows", # TODO(b/78028010) + ], +) + +tf_py_test( + name = "encode_proto_op_test", + size = "small", + srcs = ["encode_proto_op_test.py"], + additional_deps = [ + ":encode_proto_op_test_base", ":py_test_deps", - "//third_party/py/numpy", "//tensorflow/contrib/proto:proto", "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", + "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", ], data = if_static( [], @@ -57,19 +51,41 @@ tf_py_test( ) py_library( - name = "test_case", - srcs = ["test_case.py"], - deps = ["//tensorflow/python:client_testlib"], + name = "proto_op_test_base", + testonly = 1, + srcs = ["proto_op_test_base.py"], + deps = [ + ":test_example_proto_py", + "//tensorflow/python:client_testlib", + ], +) + +py_library( + name = "decode_proto_op_test_base", + testonly = 1, + srcs = ["decode_proto_op_test_base.py"], + deps = [ + ":proto_op_test_base", + ":test_example_proto_py", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], ) py_library( - name = "py_test_deps", + name = "encode_proto_op_test_base", + testonly = 1, + srcs = ["encode_proto_op_test_base.py"], deps = [ - ":test_case", + ":proto_op_test_base", ":test_example_proto_py", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) +py_library(name = "py_test_deps") + tf_proto_library( name = "test_example_proto", srcs = ["test_example.proto"], @@ -84,3 +100,30 @@ tf_cc_shared_object( ":test_example_proto_cc", ], ) + +py_library( + name = "descriptor_source_test_base", + testonly = 1, + srcs = ["descriptor_source_test_base.py"], + deps = [ + ":proto_op_test_base", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + "@protobuf_archive//:protobuf_python", + ], +) + +tf_py_test( + name = "descriptor_source_test", + size = "small", + srcs = ["descriptor_source_test.py"], + additional_deps = [ + ":descriptor_source_test_base", + "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", + "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", + "//tensorflow/python:client_testlib", + ], + tags = [ + "no_pip", + ], +) diff --git a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl b/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl deleted file mode 100644 index f425601691e21b36914f340d53ccadf9b4e3641f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl +++ /dev/null @@ -1,89 +0,0 @@ -"""BUILD rules for generating file-driven proto test cases. - -The decode_proto_test_suite() and encode_proto_test_suite() rules take a list -of text protos and generates a tf_py_test() for each one. -""" - -load("//tensorflow:tensorflow.bzl", "tf_py_test") -load("//tensorflow:tensorflow.bzl", "register_extension_info") -load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") - -def _test_name(test, path): - return "%s_%s_test" % (test, path.split("/")[-1].split(".")[0]) - -def decode_proto_test_suite(name, examples): - """Build the decode_proto py_test for each test filename.""" - for test_filename in examples: - tf_py_test( - name = _test_name("decode_proto", test_filename), - srcs = ["decode_proto_op_test.py"], - size = "small", - data = [test_filename] + if_static( - [], - otherwise = [":libtestexample.so"], - ), - main = "decode_proto_op_test.py", - args = [ - "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename), - ], - additional_deps = [ - ":py_test_deps", - "//third_party/py/numpy", - "//tensorflow/contrib/proto:proto", - "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", - ], - tags = [ - "no_pip", # TODO(b/78026780) - "no_windows", # TODO(b/78028010) - ], - ) - native.test_suite( - name = name, - tests = [":" + _test_name("decode_proto", test_filename) - for test_filename in examples], - ) - -def encode_proto_test_suite(name, examples): - """Build the encode_proto py_test for each test filename.""" - for test_filename in examples: - tf_py_test( - name = _test_name("encode_proto", test_filename), - srcs = ["encode_proto_op_test.py"], - size = "small", - data = [test_filename] + if_static( - [], - otherwise = [":libtestexample.so"], - ), - main = "encode_proto_op_test.py", - args = [ - "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename), - ], - additional_deps = [ - ":py_test_deps", - "//third_party/py/numpy", - "//tensorflow/contrib/proto:proto", - "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", - "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", - ], - tags = [ - "no_pip", # TODO(b/78026780) - "no_windows", # TODO(b/78028010) - ], - ) - native.test_suite( - name = name, - tests = [":" + _test_name("encode_proto", test_filename) - for test_filename in examples], - ) - -register_extension_info( - extension_name = "decode_proto_test_suite", - label_regex_map = { - "deps": "deps:decode_example_.*", - }) - -register_extension_info( - extension_name = "encode_proto_test_suite", - label_regex_map = { - "deps": "deps:encode_example_.*", - }) diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py deleted file mode 100644 index 5298342ee79b08a50b13ce8715e891a332efb3bc..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py +++ /dev/null @@ -1,68 +0,0 @@ -# ============================================================================= -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================= - -# Python3 preparedness imports. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.proto.python.kernel_tests import test_case -from tensorflow.contrib.proto.python.ops import decode_proto_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.platform import test - - -class DecodeProtoFailTest(test_case.ProtoOpTestCase): - """Test failure cases for DecodeToProto.""" - - def _TestCorruptProtobuf(self, sanitize): - """Test failure cases for DecodeToProto.""" - - # The goal here is to check the error reporting. - # Testing against a variety of corrupt protobufs is - # done by fuzzing. - corrupt_proto = 'This is not a binary protobuf' - - # Numpy silently truncates the strings if you don't specify dtype=object. - batch = np.array(corrupt_proto, dtype=object) - msg_type = 'tensorflow.contrib.proto.TestCase' - field_names = ['sizes'] - field_types = [dtypes.int32] - - with self.test_session() as sess: - ctensor, vtensor = decode_proto_op.decode_proto( - batch, - message_type=msg_type, - field_names=field_names, - output_types=field_types, - sanitize=sanitize) - with self.assertRaisesRegexp(errors.DataLossError, - 'Unable to parse binary protobuf' - '|Failed to consume entire buffer'): - _ = sess.run([ctensor] + vtensor) - - def testCorrupt(self): - self._TestCorruptProtobuf(sanitize=False) - - def testSanitizerCorrupt(self): - self._TestCorruptProtobuf(sanitize=True) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py index d1c13c82bc264bc8bcc721eb68ee3916f32ef7a8..934035ec4c97e04846f493817d4b4ed65db94f14 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py +++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py @@ -13,287 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= -"""Table-driven test for decode_proto op. +"""Tests for decode_proto op.""" -This test is run once with each of the *.TestCase.pbtxt files -in the test directory. -""" # Python3 preparedness imports. from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from google.protobuf import text_format - -from tensorflow.contrib.proto.python.kernel_tests import test_case -from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.contrib.proto.python.kernel_tests import decode_proto_op_test_base as test_base from tensorflow.contrib.proto.python.ops import decode_proto_op -from tensorflow.python.framework import dtypes -from tensorflow.python.platform import flags from tensorflow.python.platform import test -FLAGS = flags.FLAGS - -flags.DEFINE_string('message_text_file', None, - 'A file containing a text serialized TestCase protobuf.') - - -class DecodeProtoOpTest(test_case.ProtoOpTestCase): - - def _compareValues(self, fd, vs, evs): - """Compare lists/arrays of field values.""" - - if len(vs) != len(evs): - self.fail('Field %s decoded %d outputs, expected %d' % - (fd.name, len(vs), len(evs))) - for i, ev in enumerate(evs): - # Special case fuzzy match for float32. TensorFlow seems to mess with - # MAX_FLT slightly and the test doesn't work otherwise. - # TODO(nix): ask on TF list about why MAX_FLT doesn't pass through. - if fd.cpp_type == fd.CPPTYPE_FLOAT: - # Numpy isclose() is better than assertIsClose() which uses an absolute - # value comparison. - self.assertTrue( - np.isclose(vs[i], ev), 'expected %r, actual %r' % (ev, vs[i])) - elif fd.cpp_type == fd.CPPTYPE_STRING: - # In Python3 string tensor values will be represented as bytes, so we - # reencode the proto values to match that. - self.assertEqual(vs[i], ev.encode('ascii')) - else: - # Doubles and other types pass through unscathed. - self.assertEqual(vs[i], ev) - - def _compareRepeatedPrimitiveValue(self, batch_shape, sizes, fields, - field_dict): - """Compare protos of type RepeatedPrimitiveValue. - - Args: - batch_shape: the shape of the input tensor of serialized messages. - sizes: int matrix of repeat counts returned by decode_proto - fields: list of test_example_pb2.FieldSpec (types and expected values) - field_dict: map from field names to decoded numpy tensors of values - """ - - # Check that expected values match. - for field in fields: - values = field_dict[field.name] - self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype) - - fd = field.expected.DESCRIPTOR.fields_by_name[field.name] - - # Values has the same shape as the input plus an extra - # dimension for repeats. - self.assertEqual(list(values.shape)[:-1], batch_shape) - - # Nested messages are represented as TF strings, requiring - # some special handling. - if field.name == 'message_value': - vs = [] - for buf in values.flat: - msg = test_example_pb2.PrimitiveValue() - msg.ParseFromString(buf) - vs.append(msg) - evs = getattr(field.expected, field.name) - if len(vs) != len(evs): - self.fail('Field %s decoded %d outputs, expected %d' % - (fd.name, len(vs), len(evs))) - for v, ev in zip(vs, evs): - self.assertEqual(v, ev) - continue - - # This can be a little confusing. For testing we are using - # RepeatedPrimitiveValue in two ways: it's the proto that we - # decode for testing, and it's used in the expected value as a - # union type. The two cases are slightly different: this is the - # second case. - # We may be fetching the uint64_value from the test proto, but - # in the expected proto we store it in the int64_value field - # because TensorFlow doesn't support unsigned int64. - tf_type_to_primitive_value_field = { - dtypes.float32: - 'float_value', - dtypes.float64: - 'double_value', - dtypes.int32: - 'int32_value', - dtypes.uint8: - 'uint8_value', - dtypes.int8: - 'int8_value', - dtypes.string: - 'string_value', - dtypes.int64: - 'int64_value', - dtypes.bool: - 'bool_value', - # Unhandled TensorFlow types: - # DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32 - # DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16 - } - tf_field_name = tf_type_to_primitive_value_field.get(field.dtype) - if tf_field_name is None: - self.fail('Unhandled tensorflow type %d' % field.dtype) - - self._compareValues(fd, values.flat, - getattr(field.expected, tf_field_name)) - - def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch, - message_type, message_format, sanitize, - force_disordered=False): - """Run decode tests on a batch of messages. - - Args: - fields: list of test_example_pb2.FieldSpec (types and expected values) - case_sizes: expected sizes array - batch_shape: the shape of the input tensor of serialized messages - batch: list of serialized messages - message_type: descriptor name for messages - message_format: format of messages, 'text' or 'binary' - sanitize: whether to sanitize binary protobuf inputs - force_disordered: whether to force fields encoded out of order. - """ - - if force_disordered: - # Exercise code path that handles out-of-order fields by prepending extra - # fields with tag numbers higher than any real field. Note that this won't - # work with sanitization because that forces reserialization using a - # trusted decoder and encoder. - assert not sanitize - extra_fields = test_example_pb2.ExtraFields() - extra_fields.string_value = 'IGNORE ME' - extra_fields.bool_value = False - extra_msg = extra_fields.SerializeToString() - batch = [extra_msg + msg for msg in batch] - - # Numpy silently truncates the strings if you don't specify dtype=object. - batch = np.array(batch, dtype=object) - batch = np.reshape(batch, batch_shape) - - field_names = [f.name for f in fields] - output_types = [f.dtype for f in fields] - - with self.test_session() as sess: - sizes, vtensor = decode_proto_op.decode_proto( - batch, - message_type=message_type, - field_names=field_names, - output_types=output_types, - message_format=message_format, - sanitize=sanitize) - - vlist = sess.run([sizes] + vtensor) - sizes = vlist[0] - # Values is a list of tensors, one for each field. - value_tensors = vlist[1:] - - # Check that the repeat sizes are correct. - self.assertTrue( - np.all(np.array(sizes.shape) == batch_shape + [len(field_names)])) - - # Check that the decoded sizes match the expected sizes. - self.assertEqual(len(sizes.flat), len(case_sizes)) - self.assertTrue( - np.all(sizes.flat == np.array( - case_sizes, dtype=np.int32))) - - field_dict = dict(zip(field_names, value_tensors)) - - self._compareRepeatedPrimitiveValue(batch_shape, sizes, fields, - field_dict) - - def testBinary(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - - batch = [primitive.SerializeToString() for primitive in case.primitive] - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - batch, - 'tensorflow.contrib.proto.RepeatedPrimitiveValue', - 'binary', - sanitize=False) - - def testBinaryDisordered(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - - batch = [primitive.SerializeToString() for primitive in case.primitive] - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - batch, - 'tensorflow.contrib.proto.RepeatedPrimitiveValue', - 'binary', - sanitize=False, - force_disordered=True) - - def testPacked(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - - # Now try with the packed serialization. - # We test the packed representations by loading the same test cases - # using PackedPrimitiveValue instead of RepeatedPrimitiveValue. - # To do this we rely on the text format being the same for packed and - # unpacked fields, and reparse the test message using the packed version - # of the proto. - packed_batch = [ - # Note: float_format='.17g' is necessary to ensure preservation of - # doubles and floats in text format. - text_format.Parse( - text_format.MessageToString( - primitive, float_format='.17g'), - test_example_pb2.PackedPrimitiveValue()).SerializeToString() - for primitive in case.primitive - ] - - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - packed_batch, - 'tensorflow.contrib.proto.PackedPrimitiveValue', - 'binary', - sanitize=False) - - def testText(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - - # Note: float_format='.17g' is necessary to ensure preservation of - # doubles and floats in text format. - text_batch = [ - text_format.MessageToString( - primitive, float_format='.17g') for primitive in case.primitive - ] - - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - text_batch, - 'tensorflow.contrib.proto.RepeatedPrimitiveValue', - 'text', - sanitize=False) - def testSanitizerGood(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) +class DecodeProtoOpTest(test_base.DecodeProtoOpTestBase): - batch = [primitive.SerializeToString() for primitive in case.primitive] - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - batch, - 'tensorflow.contrib.proto.RepeatedPrimitiveValue', - 'binary', - sanitize=True) + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + super(DecodeProtoOpTest, self).__init__(decode_proto_op, methodName) if __name__ == '__main__': diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..e3570e38a3aac738b01b28eb4bfdf57e6abbc595 --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py @@ -0,0 +1,303 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for decode_proto op.""" + +# Python3 preparedness imports. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + + +from google.protobuf import text_format + +from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base +from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors + + +class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): + """Base class for testing proto decoding ops.""" + + def __init__(self, decode_module, methodName='runTest'): # pylint: disable=invalid-name + """DecodeProtoOpTestBase initializer. + + Args: + decode_module: a module containing the `decode_proto_op` method + methodName: the name of the test method (same as for test.TestCase) + """ + + super(DecodeProtoOpTestBase, self).__init__(methodName) + self._decode_module = decode_module + + def _compareValues(self, fd, vs, evs): + """Compare lists/arrays of field values.""" + + if len(vs) != len(evs): + self.fail('Field %s decoded %d outputs, expected %d' % + (fd.name, len(vs), len(evs))) + for i, ev in enumerate(evs): + # Special case fuzzy match for float32. TensorFlow seems to mess with + # MAX_FLT slightly and the test doesn't work otherwise. + # TODO(nix): ask on TF list about why MAX_FLT doesn't pass through. + if fd.cpp_type == fd.CPPTYPE_FLOAT: + # Numpy isclose() is better than assertIsClose() which uses an absolute + # value comparison. + self.assertTrue( + np.isclose(vs[i], ev), 'expected %r, actual %r' % (ev, vs[i])) + elif fd.cpp_type == fd.CPPTYPE_STRING: + # In Python3 string tensor values will be represented as bytes, so we + # reencode the proto values to match that. + self.assertEqual(vs[i], ev.encode('ascii')) + else: + # Doubles and other types pass through unscathed. + self.assertEqual(vs[i], ev) + + def _compareProtos(self, batch_shape, sizes, fields, field_dict): + """Compare protos of type TestValue. + + Args: + batch_shape: the shape of the input tensor of serialized messages. + sizes: int matrix of repeat counts returned by decode_proto + fields: list of test_example_pb2.FieldSpec (types and expected values) + field_dict: map from field names to decoded numpy tensors of values + """ + + # Check that expected values match. + for field in fields: + values = field_dict[field.name] + self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype) + + fd = field.value.DESCRIPTOR.fields_by_name[field.name] + + # Values has the same shape as the input plus an extra + # dimension for repeats. + self.assertEqual(list(values.shape)[:-1], batch_shape) + + # Nested messages are represented as TF strings, requiring + # some special handling. + if field.name == 'message_value': + vs = [] + for buf in values.flat: + msg = test_example_pb2.PrimitiveValue() + msg.ParseFromString(buf) + vs.append(msg) + evs = getattr(field.value, field.name) + if len(vs) != len(evs): + self.fail('Field %s decoded %d outputs, expected %d' % + (fd.name, len(vs), len(evs))) + for v, ev in zip(vs, evs): + self.assertEqual(v, ev) + continue + + tf_type_to_primitive_value_field = { + dtypes.bool: + 'bool_value', + dtypes.float32: + 'float_value', + dtypes.float64: + 'double_value', + dtypes.int8: + 'int8_value', + dtypes.int32: + 'int32_value', + dtypes.int64: + 'int64_value', + dtypes.string: + 'string_value', + dtypes.uint8: + 'uint8_value', + dtypes.uint32: + 'uint32_value', + dtypes.uint64: + 'uint64_value', + } + tf_field_name = tf_type_to_primitive_value_field.get(field.dtype) + if tf_field_name is None: + self.fail('Unhandled tensorflow type %d' % field.dtype) + + self._compareValues(fd, values.flat, + getattr(field.value, tf_field_name)) + + def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch, + message_type, message_format, sanitize, + force_disordered=False): + """Run decode tests on a batch of messages. + + Args: + fields: list of test_example_pb2.FieldSpec (types and expected values) + case_sizes: expected sizes array + batch_shape: the shape of the input tensor of serialized messages + batch: list of serialized messages + message_type: descriptor name for messages + message_format: format of messages, 'text' or 'binary' + sanitize: whether to sanitize binary protobuf inputs + force_disordered: whether to force fields encoded out of order. + """ + + if force_disordered: + # Exercise code path that handles out-of-order fields by prepending extra + # fields with tag numbers higher than any real field. Note that this won't + # work with sanitization because that forces reserialization using a + # trusted decoder and encoder. + assert not sanitize + extra_fields = test_example_pb2.ExtraFields() + extra_fields.string_value = 'IGNORE ME' + extra_fields.bool_value = False + extra_msg = extra_fields.SerializeToString() + batch = [extra_msg + msg for msg in batch] + + # Numpy silently truncates the strings if you don't specify dtype=object. + batch = np.array(batch, dtype=object) + batch = np.reshape(batch, batch_shape) + + field_names = [f.name for f in fields] + output_types = [f.dtype for f in fields] + + with self.test_session() as sess: + sizes, vtensor = self._decode_module.decode_proto( + batch, + message_type=message_type, + field_names=field_names, + output_types=output_types, + message_format=message_format, + sanitize=sanitize) + + vlist = sess.run([sizes] + vtensor) + sizes = vlist[0] + # Values is a list of tensors, one for each field. + value_tensors = vlist[1:] + + # Check that the repeat sizes are correct. + self.assertTrue( + np.all(np.array(sizes.shape) == batch_shape + [len(field_names)])) + + # Check that the decoded sizes match the expected sizes. + self.assertEqual(len(sizes.flat), len(case_sizes)) + self.assertTrue( + np.all(sizes.flat == np.array( + case_sizes, dtype=np.int32))) + + field_dict = dict(zip(field_names, value_tensors)) + + self._compareProtos(batch_shape, sizes, fields, field_dict) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testBinary(self, case): + batch = [value.SerializeToString() for value in case.values] + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + batch, + 'tensorflow.contrib.proto.TestValue', + 'binary', + sanitize=False) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testBinaryDisordered(self, case): + batch = [value.SerializeToString() for value in case.values] + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + batch, + 'tensorflow.contrib.proto.TestValue', + 'binary', + sanitize=False, + force_disordered=True) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testPacked(self, case): + # Now try with the packed serialization. + # + # We test the packed representations by loading the same test case using + # PackedTestValue instead of TestValue. To do this we rely on the text + # format being the same for packed and unpacked fields, and reparse the + # test message using the packed version of the proto. + packed_batch = [ + # Note: float_format='.17g' is necessary to ensure preservation of + # doubles and floats in text format. + text_format.Parse( + text_format.MessageToString( + value, float_format='.17g'), + test_example_pb2.PackedTestValue()).SerializeToString() + for value in case.values + ] + + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + packed_batch, + 'tensorflow.contrib.proto.PackedTestValue', + 'binary', + sanitize=False) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testText(self, case): + # Note: float_format='.17g' is necessary to ensure preservation of + # doubles and floats in text format. + text_batch = [ + text_format.MessageToString( + value, float_format='.17g') for value in case.values + ] + + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + text_batch, + 'tensorflow.contrib.proto.TestValue', + 'text', + sanitize=False) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testSanitizerGood(self, case): + batch = [value.SerializeToString() for value in case.values] + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + batch, + 'tensorflow.contrib.proto.TestValue', + 'binary', + sanitize=True) + + @parameterized.parameters((False), (True)) + def testCorruptProtobuf(self, sanitize): + corrupt_proto = 'This is not a binary protobuf' + + # Numpy silently truncates the strings if you don't specify dtype=object. + batch = np.array(corrupt_proto, dtype=object) + msg_type = 'tensorflow.contrib.proto.TestCase' + field_names = ['sizes'] + field_types = [dtypes.int32] + + with self.test_session() as sess: + ctensor, vtensor = self._decode_module.decode_proto( + batch, + message_type=msg_type, + field_names=field_names, + output_types=field_types, + sanitize=sanitize) + with self.assertRaisesRegexp(errors.DataLossError, + 'Unable to parse binary protobuf' + '|Failed to consume entire buffer'): + _ = sess.run([ctensor] + vtensor) diff --git a/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt deleted file mode 100644 index 4e316819077c7dbb28beefd4dc260568f26da680..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt +++ /dev/null @@ -1,94 +0,0 @@ -primitive { - # No fields specified, so we get all defaults -} -shape: 1 -sizes: 0 -field { - name: "double_default" - dtype: DT_DOUBLE - expected { double_value: 1.0 } -} -sizes: 0 -field { - name: "float_default" - dtype: DT_DOUBLE # Try casting the float field to double. - expected { double_value: 2.0 } -} -sizes: 0 -field { - name: "int64_default" - dtype: DT_INT64 - expected { int64_value: 3 } -} -sizes: 0 -field { - name: "uint64_default" - dtype: DT_INT64 - expected { int64_value: 4 } -} -sizes: 0 -field { - name: "int32_default" - dtype: DT_INT32 - expected { int32_value: 5 } -} -sizes: 0 -field { - name: "fixed64_default" - dtype: DT_INT64 - expected { int64_value: 6 } -} -sizes: 0 -field { - name: "fixed32_default" - dtype: DT_INT32 - expected { int32_value: 7 } -} -sizes: 0 -field { - name: "bool_default" - dtype: DT_BOOL - expected { bool_value: true } -} -sizes: 0 -field { - name: "string_default" - dtype: DT_STRING - expected { string_value: "a" } -} -sizes: 0 -field { - name: "bytes_default" - dtype: DT_STRING - expected { string_value: "a longer default string" } -} -sizes: 0 -field { - name: "uint32_default" - dtype: DT_INT32 - expected { int32_value: -1 } -} -sizes: 0 -field { - name: "sfixed32_default" - dtype: DT_INT32 - expected { int32_value: 10 } -} -sizes: 0 -field { - name: "sfixed64_default" - dtype: DT_INT64 - expected { int64_value: 11 } -} -sizes: 0 -field { - name: "sint32_default" - dtype: DT_INT32 - expected { int32_value: 12 } -} -sizes: 0 -field { - name: "sint64_default" - dtype: DT_INT64 - expected { int64_value: 13 } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_case.py b/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test.py similarity index 65% rename from tensorflow/contrib/proto/python/kernel_tests/test_case.py rename to tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test.py index b95202c5df654cfc02339477b242b2c58575a4d5..32ca318f733ce11221539838dfdbcf710dca51a1 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/test_case.py +++ b/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test.py @@ -13,23 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= -"""Test case base for testing proto operations.""" - +"""Tests for proto ops reading descriptors from other sources.""" # Python3 preparedness imports. from __future__ import absolute_import from __future__ import division from __future__ import print_function -import ctypes as ct -import os - +from tensorflow.contrib.proto.python.kernel_tests import descriptor_source_test_base as test_base +from tensorflow.contrib.proto.python.ops import decode_proto_op +from tensorflow.contrib.proto.python.ops import encode_proto_op from tensorflow.python.platform import test -class ProtoOpTestCase(test.TestCase): +class DescriptorSourceTest(test_base.DescriptorSourceTestBase): def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - super(ProtoOpTestCase, self).__init__(methodName) - lib = os.path.join(os.path.dirname(__file__), 'libtestexample.so') - if os.path.isfile(lib): - ct.cdll.LoadLibrary(lib) + super(DescriptorSourceTest, self).__init__(decode_proto_op, encode_proto_op, + methodName) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..9a1c04af324620fc893583ebb17cd99ea3ba166d --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test_base.py @@ -0,0 +1,176 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for proto ops reading descriptors from other sources.""" +# Python3 preparedness imports. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np + +from google.protobuf.descriptor_pb2 import FieldDescriptorProto +from google.protobuf.descriptor_pb2 import FileDescriptorSet +from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base +from tensorflow.python.framework import dtypes +from tensorflow.python.platform import test + + +class DescriptorSourceTestBase(test.TestCase): + """Base class for testing descriptor sources.""" + + def __init__(self, decode_module, encode_module, methodName='runTest'): # pylint: disable=invalid-name + """DescriptorSourceTestBase initializer. + + Args: + decode_module: a module containing the `decode_proto_op` method + encode_module: a module containing the `encode_proto_op` method + methodName: the name of the test method (same as for test.TestCase) + """ + + super(DescriptorSourceTestBase, self).__init__(methodName) + self._decode_module = decode_module + self._encode_module = encode_module + + # NOTE: We generate the descriptor programmatically instead of via a compiler + # because of differences between different versions of the compiler. + # + # The generated descriptor should capture the subset of `test_example.proto` + # used in `test_base.simple_test_case()`. + def _createDescriptorFile(self): + set_proto = FileDescriptorSet() + + file_proto = set_proto.file.add( + name='types.proto', + package='tensorflow', + syntax='proto3') + enum_proto = file_proto.enum_type.add(name='DataType') + enum_proto.value.add(name='DT_DOUBLE', number=0) + enum_proto.value.add(name='DT_BOOL', number=1) + + file_proto = set_proto.file.add( + name='test_example.proto', + package='tensorflow.contrib.proto', + dependency=['types.proto']) + message_proto = file_proto.message_type.add(name='TestCase') + message_proto.field.add( + name='values', + number=1, + type=FieldDescriptorProto.TYPE_MESSAGE, + type_name='.tensorflow.contrib.proto.TestValue', + label=FieldDescriptorProto.LABEL_REPEATED) + message_proto.field.add( + name='shapes', + number=2, + type=FieldDescriptorProto.TYPE_INT32, + label=FieldDescriptorProto.LABEL_REPEATED) + message_proto.field.add( + name='sizes', + number=3, + type=FieldDescriptorProto.TYPE_INT32, + label=FieldDescriptorProto.LABEL_REPEATED) + message_proto.field.add( + name='fields', + number=4, + type=FieldDescriptorProto.TYPE_MESSAGE, + type_name='.tensorflow.contrib.proto.FieldSpec', + label=FieldDescriptorProto.LABEL_REPEATED) + + message_proto = file_proto.message_type.add( + name='TestValue') + message_proto.field.add( + name='double_value', + number=1, + type=FieldDescriptorProto.TYPE_DOUBLE, + label=FieldDescriptorProto.LABEL_REPEATED) + message_proto.field.add( + name='bool_value', + number=2, + type=FieldDescriptorProto.TYPE_BOOL, + label=FieldDescriptorProto.LABEL_REPEATED) + + message_proto = file_proto.message_type.add( + name='FieldSpec') + message_proto.field.add( + name='name', + number=1, + type=FieldDescriptorProto.TYPE_STRING, + label=FieldDescriptorProto.LABEL_OPTIONAL) + message_proto.field.add( + name='dtype', + number=2, + type=FieldDescriptorProto.TYPE_ENUM, + type_name='.tensorflow.DataType', + label=FieldDescriptorProto.LABEL_OPTIONAL) + message_proto.field.add( + name='value', + number=3, + type=FieldDescriptorProto.TYPE_MESSAGE, + type_name='.tensorflow.contrib.proto.TestValue', + label=FieldDescriptorProto.LABEL_OPTIONAL) + + fn = os.path.join(self.get_temp_dir(), 'descriptor.pb') + with open(fn, 'wb') as f: + f.write(set_proto.SerializeToString()) + return fn + + def _testRoundtrip(self, descriptor_source): + # Numpy silently truncates the strings if you don't specify dtype=object. + in_bufs = np.array( + [test_base.ProtoOpTestBase.simple_test_case().SerializeToString()], + dtype=object) + message_type = 'tensorflow.contrib.proto.TestCase' + field_names = ['values', 'shapes', 'sizes', 'fields'] + tensor_types = [dtypes.string, dtypes.int32, dtypes.int32, dtypes.string] + + with self.test_session() as sess: + sizes, field_tensors = self._decode_module.decode_proto( + in_bufs, + message_type=message_type, + field_names=field_names, + output_types=tensor_types, + descriptor_source=descriptor_source) + + out_tensors = self._encode_module.encode_proto( + sizes, + field_tensors, + message_type=message_type, + field_names=field_names, + descriptor_source=descriptor_source) + + out_bufs, = sess.run([out_tensors]) + + # Check that the re-encoded tensor has the same shape. + self.assertEqual(in_bufs.shape, out_bufs.shape) + + # Compare the input and output. + for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat): + # Check that the input and output serialized messages are identical. + # If we fail here, there is a difference in the serialized + # representation but the new serialization still parses. This could + # be harmless (a change in map ordering?) or it could be bad (e.g. + # loss of packing in the encoding). + self.assertEqual(in_buf, out_buf) + + def testWithFileDescriptorSet(self): + # First try parsing with a local proto db, which should fail. + with self.assertRaisesOpError('No descriptor found for message type'): + self._testRoundtrip('local://') + + # Now try parsing with a FileDescriptorSet which contains the test proto. + descriptor_file = self._createDescriptorFile() + self._testRoundtrip(descriptor_file) diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py index 30e58e6336dc66830418c7cd2b3111a851d691b6..fc5cd25d43be1df2480630396c39f7a83e0eb57a 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py +++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py @@ -13,167 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= -"""Table-driven test for encode_proto op. +"""Tests for encode_proto op.""" -This test is run once with each of the *.TestCase.pbtxt files -in the test directory. - -It tests that encode_proto is a lossless inverse of decode_proto -(for the specified fields). -""" # Python3 readiness boilerplate from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from google.protobuf import text_format - -from tensorflow.contrib.proto.python.kernel_tests import test_case -from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.contrib.proto.python.kernel_tests import encode_proto_op_test_base as test_base from tensorflow.contrib.proto.python.ops import decode_proto_op from tensorflow.contrib.proto.python.ops import encode_proto_op -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import flags from tensorflow.python.platform import test -FLAGS = flags.FLAGS - -flags.DEFINE_string('message_text_file', None, - 'A file containing a text serialized TestCase protobuf.') - - -class EncodeProtoOpTest(test_case.ProtoOpTestCase): - - def testBadInputs(self): - # Invalid field name - with self.test_session(): - with self.assertRaisesOpError('Unknown field: non_existent_field'): - encode_proto_op.encode_proto( - sizes=[[1]], - values=[np.array([[0.0]], dtype=np.int32)], - message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', - field_names=['non_existent_field']).eval() - - # Incorrect types. - with self.test_session(): - with self.assertRaisesOpError( - 'Incompatible type for field double_value.'): - encode_proto_op.encode_proto( - sizes=[[1]], - values=[np.array([[0.0]], dtype=np.int32)], - message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', - field_names=['double_value']).eval() - - # Incorrect shapes of sizes. - with self.test_session(): - with self.assertRaisesOpError( - r'sizes should be batch_size \+ \[len\(field_names\)\]'): - sizes = array_ops.placeholder(dtypes.int32) - values = array_ops.placeholder(dtypes.float64) - encode_proto_op.encode_proto( - sizes=sizes, - values=[values], - message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', - field_names=['double_value']).eval(feed_dict={ - sizes: [[[0, 0]]], - values: [[0.0]] - }) - - # Inconsistent shapes of values. - with self.test_session(): - with self.assertRaisesOpError( - 'Values must match up to the last dimension'): - sizes = array_ops.placeholder(dtypes.int32) - values1 = array_ops.placeholder(dtypes.float64) - values2 = array_ops.placeholder(dtypes.int32) - (encode_proto_op.encode_proto( - sizes=[[1, 1]], - values=[values1, values2], - message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', - field_names=['double_value', 'int32_value']).eval(feed_dict={ - values1: [[0.0]], - values2: [[0], [0]] - })) - - def _testRoundtrip(self, in_bufs, message_type, fields): - - field_names = [f.name for f in fields] - out_types = [f.dtype for f in fields] - - with self.test_session() as sess: - sizes, field_tensors = decode_proto_op.decode_proto( - in_bufs, - message_type=message_type, - field_names=field_names, - output_types=out_types) - - out_tensors = encode_proto_op.encode_proto( - sizes, - field_tensors, - message_type=message_type, - field_names=field_names) - - out_bufs, = sess.run([out_tensors]) - - # Check that the re-encoded tensor has the same shape. - self.assertEqual(in_bufs.shape, out_bufs.shape) - - # Compare the input and output. - for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat): - in_obj = test_example_pb2.RepeatedPrimitiveValue() - in_obj.ParseFromString(in_buf) - - out_obj = test_example_pb2.RepeatedPrimitiveValue() - out_obj.ParseFromString(out_buf) - - # Check that the deserialized objects are identical. - self.assertEqual(in_obj, out_obj) - - # Check that the input and output serialized messages are identical. - # If we fail here, there is a difference in the serialized - # representation but the new serialization still parses. This could - # be harmless (a change in map ordering?) or it could be bad (e.g. - # loss of packing in the encoding). - self.assertEqual(in_buf, out_buf) - - def testRoundtrip(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - - in_bufs = [primitive.SerializeToString() for primitive in case.primitive] - - # np.array silently truncates strings if you don't specify dtype=object. - in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shape)) - return self._testRoundtrip( - in_bufs, 'tensorflow.contrib.proto.RepeatedPrimitiveValue', case.field) - - def testRoundtripPacked(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - # Now try with the packed serialization. - # We test the packed representations by loading the same test cases - # using PackedPrimitiveValue instead of RepeatedPrimitiveValue. - # To do this we rely on the text format being the same for packed and - # unpacked fields, and reparse the test message using the packed version - # of the proto. - in_bufs = [ - # Note: float_format='.17g' is necessary to ensure preservation of - # doubles and floats in text format. - text_format.Parse( - text_format.MessageToString( - primitive, float_format='.17g'), - test_example_pb2.PackedPrimitiveValue()).SerializeToString() - for primitive in case.primitive - ] +class EncodeProtoOpTest(test_base.EncodeProtoOpTestBase): - # np.array silently truncates strings if you don't specify dtype=object. - in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shape)) - return self._testRoundtrip( - in_bufs, 'tensorflow.contrib.proto.PackedPrimitiveValue', case.field) + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + super(EncodeProtoOpTest, self).__init__(decode_proto_op, encode_proto_op, + methodName) if __name__ == '__main__': diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..07dfb924d3ede5bdb9b848c5eb0d3382ec053121 --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py @@ -0,0 +1,177 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Table-driven test for encode_proto op. + +This test is run once with each of the *.TestCase.pbtxt files +in the test directory. + +It tests that encode_proto is a lossless inverse of decode_proto +(for the specified fields). +""" +# Python3 readiness boilerplate +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from google.protobuf import text_format + +from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base +from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops + + +class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): + """Base class for testing proto encoding ops.""" + + def __init__(self, decode_module, encode_module, methodName='runTest'): # pylint: disable=invalid-name + """EncodeProtoOpTestBase initializer. + + Args: + decode_module: a module containing the `decode_proto_op` method + encode_module: a module containing the `encode_proto_op` method + methodName: the name of the test method (same as for test.TestCase) + """ + + super(EncodeProtoOpTestBase, self).__init__(methodName) + self._decode_module = decode_module + self._encode_module = encode_module + + def testBadInputs(self): + # Invalid field name + with self.test_session(): + with self.assertRaisesOpError('Unknown field: non_existent_field'): + self._encode_module.encode_proto( + sizes=[[1]], + values=[np.array([[0.0]], dtype=np.int32)], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['non_existent_field']).eval() + + # Incorrect types. + with self.test_session(): + with self.assertRaisesOpError( + 'Incompatible type for field double_value.'): + self._encode_module.encode_proto( + sizes=[[1]], + values=[np.array([[0.0]], dtype=np.int32)], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value']).eval() + + # Incorrect shapes of sizes. + with self.test_session(): + with self.assertRaisesOpError( + r'sizes should be batch_size \+ \[len\(field_names\)\]'): + sizes = array_ops.placeholder(dtypes.int32) + values = array_ops.placeholder(dtypes.float64) + self._encode_module.encode_proto( + sizes=sizes, + values=[values], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value']).eval(feed_dict={ + sizes: [[[0, 0]]], + values: [[0.0]] + }) + + # Inconsistent shapes of values. + with self.test_session(): + with self.assertRaisesOpError( + 'Values must match up to the last dimension'): + sizes = array_ops.placeholder(dtypes.int32) + values1 = array_ops.placeholder(dtypes.float64) + values2 = array_ops.placeholder(dtypes.int32) + (self._encode_module.encode_proto( + sizes=[[1, 1]], + values=[values1, values2], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value', 'int32_value']).eval(feed_dict={ + values1: [[0.0]], + values2: [[0], [0]] + })) + + def _testRoundtrip(self, in_bufs, message_type, fields): + + field_names = [f.name for f in fields] + out_types = [f.dtype for f in fields] + + with self.test_session() as sess: + sizes, field_tensors = self._decode_module.decode_proto( + in_bufs, + message_type=message_type, + field_names=field_names, + output_types=out_types) + + out_tensors = self._encode_module.encode_proto( + sizes, + field_tensors, + message_type=message_type, + field_names=field_names) + + out_bufs, = sess.run([out_tensors]) + + # Check that the re-encoded tensor has the same shape. + self.assertEqual(in_bufs.shape, out_bufs.shape) + + # Compare the input and output. + for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat): + in_obj = test_example_pb2.TestValue() + in_obj.ParseFromString(in_buf) + + out_obj = test_example_pb2.TestValue() + out_obj.ParseFromString(out_buf) + + # Check that the deserialized objects are identical. + self.assertEqual(in_obj, out_obj) + + # Check that the input and output serialized messages are identical. + # If we fail here, there is a difference in the serialized + # representation but the new serialization still parses. This could + # be harmless (a change in map ordering?) or it could be bad (e.g. + # loss of packing in the encoding). + self.assertEqual(in_buf, out_buf) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testRoundtrip(self, case): + in_bufs = [value.SerializeToString() for value in case.values] + + # np.array silently truncates strings if you don't specify dtype=object. + in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shapes)) + return self._testRoundtrip( + in_bufs, 'tensorflow.contrib.proto.TestValue', case.fields) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testRoundtripPacked(self, case): + # Now try with the packed serialization. + # We test the packed representations by loading the same test cases using + # PackedTestValue instead of TestValue. To do this we rely on the text + # format being the same for packed and unpacked fields, and reparse the test + # message using the packed version of the proto. + in_bufs = [ + # Note: float_format='.17g' is necessary to ensure preservation of + # doubles and floats in text format. + text_format.Parse( + text_format.MessageToString( + value, float_format='.17g'), + test_example_pb2.PackedTestValue()).SerializeToString() + for value in case.values + ] + + # np.array silently truncates strings if you don't specify dtype=object. + in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shapes)) + return self._testRoundtrip( + in_bufs, 'tensorflow.contrib.proto.PackedTestValue', case.fields) diff --git a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt deleted file mode 100644 index b170f89c0f00dd9dffd5785197bb3bfd1ca2cfee..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt +++ /dev/null @@ -1,161 +0,0 @@ -primitive { - double_value: -1.7976931348623158e+308 - double_value: 2.2250738585072014e-308 - double_value: 1.7976931348623158e+308 - float_value: -3.402823466e+38 - float_value: 1.175494351e-38 - float_value: 3.402823466e+38 - int64_value: -9223372036854775808 - int64_value: 9223372036854775807 - uint64_value: 0 - uint64_value: 18446744073709551615 - int32_value: -2147483648 - int32_value: 2147483647 - fixed64_value: 0 - fixed64_value: 18446744073709551615 - fixed32_value: 0 - fixed32_value: 4294967295 - bool_value: false - bool_value: true - string_value: "" - string_value: "I refer to the infinite." - uint32_value: 0 - uint32_value: 4294967295 - sfixed32_value: -2147483648 - sfixed32_value: 2147483647 - sfixed64_value: -9223372036854775808 - sfixed64_value: 9223372036854775807 - sint32_value: -2147483648 - sint32_value: 2147483647 - sint64_value: -9223372036854775808 - sint64_value: 9223372036854775807 -} -shape: 1 -sizes: 3 -sizes: 3 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: -1.7976931348623158e+308 - double_value: 2.2250738585072014e-308 - double_value: 1.7976931348623158e+308 - } -} -field { - name: "float_value" - dtype: DT_FLOAT - expected { - float_value: -3.402823466e+38 - float_value: 1.175494351e-38 - float_value: 3.402823466e+38 - } -} -field { - name: "int64_value" - dtype: DT_INT64 - expected { - int64_value: -9223372036854775808 - int64_value: 9223372036854775807 - } -} -field { - name: "uint64_value" - dtype: DT_INT64 - expected { - int64_value: 0 - int64_value: -1 - } -} -field { - name: "int32_value" - dtype: DT_INT32 - expected { - int32_value: -2147483648 - int32_value: 2147483647 - } -} -field { - name: "fixed64_value" - dtype: DT_INT64 - expected { - int64_value: 0 - int64_value: -1 # unsigned is 18446744073709551615 - } -} -field { - name: "fixed32_value" - dtype: DT_INT32 - expected { - int32_value: 0 - int32_value: -1 # unsigned is 4294967295 - } -} -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: false - bool_value: true - } -} -field { - name: "string_value" - dtype: DT_STRING - expected { - string_value: "" - string_value: "I refer to the infinite." - } -} -field { - name: "uint32_value" - dtype: DT_INT32 - expected { - int32_value: 0 - int32_value: -1 # unsigned is 4294967295 - } -} -field { - name: "sfixed32_value" - dtype: DT_INT32 - expected { - int32_value: -2147483648 - int32_value: 2147483647 - } -} -field { - name: "sfixed64_value" - dtype: DT_INT64 - expected { - int64_value: -9223372036854775808 - int64_value: 9223372036854775807 - } -} -field { - name: "sint32_value" - dtype: DT_INT32 - expected { - int32_value: -2147483648 - int32_value: 2147483647 - } -} -field { - name: "sint64_value" - dtype: DT_INT64 - expected { - int64_value: -9223372036854775808 - int64_value: 9223372036854775807 - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt deleted file mode 100644 index c664e52851b5bb3c439544537ce6402fc7cf3362..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt +++ /dev/null @@ -1,16 +0,0 @@ -primitive { - message_value { - double_value: 23.5 - } -} -shape: 1 -sizes: 1 -field { - name: "message_value" - dtype: DT_STRING - expected { - message_value { - double_value: 23.5 - } - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt deleted file mode 100644 index 125651d7eaa1901e4804712bb807322b02ed5bc6..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt +++ /dev/null @@ -1,20 +0,0 @@ -primitive { - bool_value: true -} -shape: 1 -sizes: 1 -sizes: 0 -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: true - } -} -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: 0.0 - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt deleted file mode 100644 index bc07efc8f3038c6c540855c97b2254575e517ef3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt +++ /dev/null @@ -1,29 +0,0 @@ -primitive { - fixed32_value: 4294967295 - uint32_value: 4294967295 -} -shape: 1 -sizes: 1 -field { - name: "fixed32_value" - dtype: DT_INT64 - expected { - int64_value: 4294967295 - } -} -sizes: 1 -field { - name: "uint32_value" - dtype: DT_INT64 - expected { - int64_value: 4294967295 - } -} -sizes: 0 -field { - name: "uint32_default" - dtype: DT_INT64 - expected { - int64_value: 4294967295 # Comes from an explicitly-specified default - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..2950c7dfdc59a11ba7d2c07d8406bd4af26b5bd9 --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py @@ -0,0 +1,419 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Test case base for testing proto operations.""" + +# Python3 preparedness imports. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ctypes as ct +import os + +from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.python.platform import test + + +class ProtoOpTestBase(test.TestCase): + """Base class for testing proto decoding and encoding ops.""" + + def __init__(self, methodName="runTest"): # pylint: disable=invalid-name + super(ProtoOpTestBase, self).__init__(methodName) + lib = os.path.join(os.path.dirname(__file__), "libtestexample.so") + if os.path.isfile(lib): + ct.cdll.LoadLibrary(lib) + + @staticmethod + def named_parameters(): + return ( + ("defaults", ProtoOpTestBase.defaults_test_case()), + ("minmax", ProtoOpTestBase.minmax_test_case()), + ("nested", ProtoOpTestBase.nested_test_case()), + ("optional", ProtoOpTestBase.optional_test_case()), + ("promote", ProtoOpTestBase.promote_test_case()), + ("ragged", ProtoOpTestBase.ragged_test_case()), + ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()), + ("simple", ProtoOpTestBase.simple_test_case()), + ) + + @staticmethod + def defaults_test_case(): + test_case = test_example_pb2.TestCase() + test_case.values.add() # No fields specified, so we get all defaults. + test_case.shapes.append(1) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "double_value_with_default" + field.dtype = types_pb2.DT_DOUBLE + field.value.double_value.append(1.0) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "float_value_with_default" + field.dtype = types_pb2.DT_FLOAT + field.value.float_value.append(2.0) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "int64_value_with_default" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(3) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "sfixed64_value_with_default" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(11) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "sint64_value_with_default" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(13) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "uint64_value_with_default" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(4) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "fixed64_value_with_default" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(6) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "int32_value_with_default" + field.dtype = types_pb2.DT_INT32 + field.value.int32_value.append(5) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "sfixed32_value_with_default" + field.dtype = types_pb2.DT_INT32 + field.value.int32_value.append(10) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "sint32_value_with_default" + field.dtype = types_pb2.DT_INT32 + field.value.int32_value.append(12) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "uint32_value_with_default" + field.dtype = types_pb2.DT_UINT32 + field.value.uint32_value.append(9) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "fixed32_value_with_default" + field.dtype = types_pb2.DT_UINT32 + field.value.uint32_value.append(7) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "bool_value_with_default" + field.dtype = types_pb2.DT_BOOL + field.value.bool_value.append(True) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "string_value_with_default" + field.dtype = types_pb2.DT_STRING + field.value.string_value.append("a") + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "bytes_value_with_default" + field.dtype = types_pb2.DT_STRING + field.value.string_value.append("a longer default string") + return test_case + + @staticmethod + def minmax_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + value.double_value.append(-1.7976931348623158e+308) + value.double_value.append(2.2250738585072014e-308) + value.double_value.append(1.7976931348623158e+308) + value.float_value.append(-3.402823466e+38) + value.float_value.append(1.175494351e-38) + value.float_value.append(3.402823466e+38) + value.int64_value.append(-9223372036854775808) + value.int64_value.append(9223372036854775807) + value.sfixed64_value.append(-9223372036854775808) + value.sfixed64_value.append(9223372036854775807) + value.sint64_value.append(-9223372036854775808) + value.sint64_value.append(9223372036854775807) + value.uint64_value.append(0) + value.uint64_value.append(18446744073709551615) + value.fixed64_value.append(0) + value.fixed64_value.append(18446744073709551615) + value.int32_value.append(-2147483648) + value.int32_value.append(2147483647) + value.sfixed32_value.append(-2147483648) + value.sfixed32_value.append(2147483647) + value.sint32_value.append(-2147483648) + value.sint32_value.append(2147483647) + value.uint32_value.append(0) + value.uint32_value.append(4294967295) + value.fixed32_value.append(0) + value.fixed32_value.append(4294967295) + value.bool_value.append(False) + value.bool_value.append(True) + value.string_value.append("") + value.string_value.append("I refer to the infinite.") + test_case.shapes.append(1) + test_case.sizes.append(3) + field = test_case.fields.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.value.double_value.append(-1.7976931348623158e+308) + field.value.double_value.append(2.2250738585072014e-308) + field.value.double_value.append(1.7976931348623158e+308) + test_case.sizes.append(3) + field = test_case.fields.add() + field.name = "float_value" + field.dtype = types_pb2.DT_FLOAT + field.value.float_value.append(-3.402823466e+38) + field.value.float_value.append(1.175494351e-38) + field.value.float_value.append(3.402823466e+38) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "int64_value" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(-9223372036854775808) + field.value.int64_value.append(9223372036854775807) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "sfixed64_value" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(-9223372036854775808) + field.value.int64_value.append(9223372036854775807) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "sint64_value" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(-9223372036854775808) + field.value.int64_value.append(9223372036854775807) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "uint64_value" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(0) + field.value.uint64_value.append(18446744073709551615) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "fixed64_value" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(0) + field.value.uint64_value.append(18446744073709551615) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "int32_value" + field.dtype = types_pb2.DT_INT32 + field.value.int32_value.append(-2147483648) + field.value.int32_value.append(2147483647) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "sfixed32_value" + field.dtype = types_pb2.DT_INT32 + field.value.int32_value.append(-2147483648) + field.value.int32_value.append(2147483647) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "sint32_value" + field.dtype = types_pb2.DT_INT32 + field.value.int32_value.append(-2147483648) + field.value.int32_value.append(2147483647) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "uint32_value" + field.dtype = types_pb2.DT_UINT32 + field.value.uint32_value.append(0) + field.value.uint32_value.append(4294967295) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "fixed32_value" + field.dtype = types_pb2.DT_UINT32 + field.value.uint32_value.append(0) + field.value.uint32_value.append(4294967295) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.value.bool_value.append(False) + field.value.bool_value.append(True) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "string_value" + field.dtype = types_pb2.DT_STRING + field.value.string_value.append("") + field.value.string_value.append("I refer to the infinite.") + return test_case + + @staticmethod + def nested_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + message_value = value.message_value.add() + message_value.double_value = 23.5 + test_case.shapes.append(1) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "message_value" + field.dtype = types_pb2.DT_STRING + message_value = field.value.message_value.add() + message_value.double_value = 23.5 + return test_case + + @staticmethod + def optional_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + value.bool_value.append(True) + test_case.shapes.append(1) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.value.bool_value.append(True) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.value.double_value.append(0.0) + return test_case + + @staticmethod + def promote_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + value.sint32_value.append(2147483647) + value.sfixed32_value.append(2147483647) + value.int32_value.append(2147483647) + value.fixed32_value.append(4294967295) + value.uint32_value.append(4294967295) + test_case.shapes.append(1) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "sint32_value" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(2147483647) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "sfixed32_value" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(2147483647) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "int32_value" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(2147483647) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "fixed32_value" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(4294967295) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "uint32_value" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(4294967295) + return test_case + + @staticmethod + def ragged_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + value.double_value.append(23.5) + value.double_value.append(123.0) + value.bool_value.append(True) + value = test_case.values.add() + value.double_value.append(3.1) + value.bool_value.append(False) + test_case.shapes.append(2) + test_case.sizes.append(2) + test_case.sizes.append(1) + test_case.sizes.append(1) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.value.double_value.append(23.5) + field.value.double_value.append(123.0) + field.value.double_value.append(3.1) + field.value.double_value.append(0.0) + field = test_case.fields.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.value.bool_value.append(True) + field.value.bool_value.append(False) + return test_case + + @staticmethod + def shaped_batch_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + value.double_value.append(23.5) + value.bool_value.append(True) + value = test_case.values.add() + value.double_value.append(44.0) + value.bool_value.append(False) + value = test_case.values.add() + value.double_value.append(3.14159) + value.bool_value.append(True) + value = test_case.values.add() + value.double_value.append(1.414) + value.bool_value.append(True) + value = test_case.values.add() + value.double_value.append(-32.2) + value.bool_value.append(False) + value = test_case.values.add() + value.double_value.append(0.0001) + value.bool_value.append(True) + test_case.shapes.append(3) + test_case.shapes.append(2) + for _ in range(12): + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.value.double_value.append(23.5) + field.value.double_value.append(44.0) + field.value.double_value.append(3.14159) + field.value.double_value.append(1.414) + field.value.double_value.append(-32.2) + field.value.double_value.append(0.0001) + field = test_case.fields.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.value.bool_value.append(True) + field.value.bool_value.append(False) + field.value.bool_value.append(True) + field.value.bool_value.append(True) + field.value.bool_value.append(False) + field.value.bool_value.append(True) + return test_case + + @staticmethod + def simple_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + value.double_value.append(23.5) + value.bool_value.append(True) + test_case.shapes.append(1) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.value.double_value.append(23.5) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.value.bool_value.append(True) + return test_case diff --git a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt deleted file mode 100644 index 61c7ac53f72b0764a0d57241cbdcdd93fcbd9279..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt +++ /dev/null @@ -1,32 +0,0 @@ -primitive { - double_value: 23.5 - double_value: 123.0 - bool_value: true -} -primitive { - double_value: 3.1 - bool_value: false -} -shape: 2 -sizes: 2 -sizes: 1 -sizes: 1 -sizes: 1 -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: 23.5 - double_value: 123.0 - double_value: 3.1 - double_value: 0.0 - } -} -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: true - bool_value: false - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt deleted file mode 100644 index f4828076d52dc5d03a887c4a445dbcf52414c361..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt +++ /dev/null @@ -1,62 +0,0 @@ -primitive { - double_value: 23.5 - bool_value: true -} -primitive { - double_value: 44.0 - bool_value: false -} -primitive { - double_value: 3.14159 - bool_value: true -} -primitive { - double_value: 1.414 - bool_value: true -} -primitive { - double_value: -32.2 - bool_value: false -} -primitive { - double_value: 0.0001 - bool_value: true -} -shape: 3 -shape: 2 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: 23.5 - double_value: 44.0 - double_value: 3.14159 - double_value: 1.414 - double_value: -32.2 - double_value: 0.0001 - } -} -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: true - bool_value: false - bool_value: true - bool_value: true - bool_value: false - bool_value: true - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt deleted file mode 100644 index dc20ac147b0e772f05b4fc614f9f56513aceb1d5..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt +++ /dev/null @@ -1,21 +0,0 @@ -primitive { - double_value: 23.5 - bool_value: true -} -shape: 1 -sizes: 1 -sizes: 1 -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: 23.5 - } -} -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: true - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto index a2c88e372bf7c6b7f14c5bb55776b66c4c06bcd4..674d881220a1113631def47c5111e3ef401b99f3 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto +++ b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto @@ -1,6 +1,4 @@ // Test description and protos to work with it. -// -// Many of the protos in this file are for unit tests that haven't been written yet. syntax = "proto2"; @@ -8,54 +6,27 @@ import "tensorflow/core/framework/types.proto"; package tensorflow.contrib.proto; -// A TestCase holds a proto and a bunch of assertions -// about how it should decode. +// A TestCase holds a proto and assertions about how it should decode. message TestCase { - // A batch of primitives to be serialized and decoded. - repeated RepeatedPrimitiveValue primitive = 1; - // The shape of the batch. - repeated int32 shape = 2; + // Batches of primitive values. + repeated TestValue values = 1; + // The batch shapes. + repeated int32 shapes = 2; // Expected sizes for each field. repeated int32 sizes = 3; // Expected values for each field. - repeated FieldSpec field = 4; + repeated FieldSpec fields = 4; }; // FieldSpec describes the expected output for a single field. message FieldSpec { optional string name = 1; optional tensorflow.DataType dtype = 2; - optional RepeatedPrimitiveValue expected = 3; + optional TestValue value = 3; }; +// NOTE: This definition must be kept in sync with PackedTestValue. message TestValue { - optional PrimitiveValue primitive_value = 1; - optional EnumValue enum_value = 2; - optional MessageValue message_value = 3; - optional RepeatedMessageValue repeated_message_value = 4; - optional RepeatedPrimitiveValue repeated_primitive_value = 6; -} - -message PrimitiveValue { - optional double double_value = 1; - optional float float_value = 2; - optional int64 int64_value = 3; - optional uint64 uint64_value = 4; - optional int32 int32_value = 5; - optional fixed64 fixed64_value = 6; - optional fixed32 fixed32_value = 7; - optional bool bool_value = 8; - optional string string_value = 9; - optional bytes bytes_value = 12; - optional uint32 uint32_value = 13; - optional sfixed32 sfixed32_value = 15; - optional sfixed64 sfixed64_value = 16; - optional sint32 sint32_value = 17; - optional sint64 sint64_value = 18; -} - -// NOTE: This definition must be kept in sync with PackedPrimitiveValue. -message RepeatedPrimitiveValue { repeated double double_value = 1; repeated float float_value = 2; repeated int64 int64_value = 3; @@ -74,30 +45,31 @@ message RepeatedPrimitiveValue { repeated PrimitiveValue message_value = 19; // Optional fields with explicitly-specified defaults. - optional double double_default = 20 [default = 1.0]; - optional float float_default = 21 [default = 2.0]; - optional int64 int64_default = 22 [default = 3]; - optional uint64 uint64_default = 23 [default = 4]; - optional int32 int32_default = 24 [default = 5]; - optional fixed64 fixed64_default = 25 [default = 6]; - optional fixed32 fixed32_default = 26 [default = 7]; - optional bool bool_default = 27 [default = true]; - optional string string_default = 28 [default = "a"]; - optional bytes bytes_default = 29 [default = "a longer default string"]; - optional uint32 uint32_default = 30 [default = 4294967295]; - optional sfixed32 sfixed32_default = 31 [default = 10]; - optional sfixed64 sfixed64_default = 32 [default = 11]; - optional sint32 sint32_default = 33 [default = 12]; - optional sint64 sint64_default = 34 [default = 13]; + optional double double_value_with_default = 20 [default = 1.0]; + optional float float_value_with_default = 21 [default = 2.0]; + optional int64 int64_value_with_default = 22 [default = 3]; + optional uint64 uint64_value_with_default = 23 [default = 4]; + optional int32 int32_value_with_default = 24 [default = 5]; + optional fixed64 fixed64_value_with_default = 25 [default = 6]; + optional fixed32 fixed32_value_with_default = 26 [default = 7]; + optional bool bool_value_with_default = 27 [default = true]; + optional string string_value_with_default = 28 [default = "a"]; + optional bytes bytes_value_with_default = 29 + [default = "a longer default string"]; + optional uint32 uint32_value_with_default = 30 [default = 9]; + optional sfixed32 sfixed32_value_with_default = 31 [default = 10]; + optional sfixed64 sfixed64_value_with_default = 32 [default = 11]; + optional sint32 sint32_value_with_default = 33 [default = 12]; + optional sint64 sint64_value_with_default = 34 [default = 13]; } -// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue -// in the text format, but the binary serializion is different. -// We test the packed representations by loading the same test cases -// using this definition instead of RepeatedPrimitiveValue. -// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue -// in every way except the packed=true declaration. -message PackedPrimitiveValue { +// A PackedTestValue looks exactly the same as a TestValue in the text format, +// but the binary serializion is different. We test the packed representations +// by loading the same test cases using this definition instead of TestValue. +// +// NOTE: This definition must be kept in sync with TestValue in every way except +// the packed=true declaration. +message PackedTestValue { repeated double double_value = 1 [packed = true]; repeated float float_value = 2 [packed = true]; repeated int64 int64_value = 3 [packed = true]; @@ -115,23 +87,53 @@ message PackedPrimitiveValue { repeated sint64 sint64_value = 18 [packed = true]; repeated PrimitiveValue message_value = 19; - optional double double_default = 20 [default = 1.0]; - optional float float_default = 21 [default = 2.0]; - optional int64 int64_default = 22 [default = 3]; - optional uint64 uint64_default = 23 [default = 4]; - optional int32 int32_default = 24 [default = 5]; - optional fixed64 fixed64_default = 25 [default = 6]; - optional fixed32 fixed32_default = 26 [default = 7]; - optional bool bool_default = 27 [default = true]; - optional string string_default = 28 [default = "a"]; - optional bytes bytes_default = 29 [default = "a longer default string"]; - optional uint32 uint32_default = 30 [default = 4294967295]; - optional sfixed32 sfixed32_default = 31 [default = 10]; - optional sfixed64 sfixed64_default = 32 [default = 11]; - optional sint32 sint32_default = 33 [default = 12]; - optional sint64 sint64_default = 34 [default = 13]; + optional double double_value_with_default = 20 [default = 1.0]; + optional float float_value_with_default = 21 [default = 2.0]; + optional int64 int64_value_with_default = 22 [default = 3]; + optional uint64 uint64_value_with_default = 23 [default = 4]; + optional int32 int32_value_with_default = 24 [default = 5]; + optional fixed64 fixed64_value_with_default = 25 [default = 6]; + optional fixed32 fixed32_value_with_default = 26 [default = 7]; + optional bool bool_value_with_default = 27 [default = true]; + optional string string_value_with_default = 28 [default = "a"]; + optional bytes bytes_value_with_default = 29 + [default = "a longer default string"]; + optional uint32 uint32_value_with_default = 30 [default = 9]; + optional sfixed32 sfixed32_value_with_default = 31 [default = 10]; + optional sfixed64 sfixed64_value_with_default = 32 [default = 11]; + optional sint32 sint32_value_with_default = 33 [default = 12]; + optional sint64 sint64_value_with_default = 34 [default = 13]; } +message PrimitiveValue { + optional double double_value = 1; + optional float float_value = 2; + optional int64 int64_value = 3; + optional uint64 uint64_value = 4; + optional int32 int32_value = 5; + optional fixed64 fixed64_value = 6; + optional fixed32 fixed32_value = 7; + optional bool bool_value = 8; + optional string string_value = 9; + optional bytes bytes_value = 12; + optional uint32 uint32_value = 13; + optional sfixed32 sfixed32_value = 15; + optional sfixed64 sfixed64_value = 16; + optional sint32 sint32_value = 17; + optional sint64 sint64_value = 18; +} + +// Message containing fields with field numbers higher than any field above. +// An instance of this message is prepended to each binary message in the test +// to exercise the code path that handles fields encoded out of order of field +// number. +message ExtraFields { + optional string string_value = 1776; + optional bool bool_value = 1777; +} + +// The messages below are for yet-to-be created tests. + message EnumValue { enum Color { RED = 0; @@ -171,12 +173,3 @@ message RepeatedMessageValue { repeated NestedMessageValue message_values = 11; } - -// Message containing fields with field numbers higher than any field above. An -// instance of this message is prepended to each binary message in the test to -// exercise the code path that handles fields encoded out of order of field -// number. -message ExtraFields { - optional string string_value = 1776; - optional bool bool_value = 1777; -} diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index 55479bf5f74299bf09f131a6127f9f11d6192d90..d9f179bee48de587976872dabb470cfd5c69114c 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -120,8 +120,10 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay): scaled_weight_tensor = math_ops.multiply( weights, multiplier_tensor, name='mul_fold') + new_layer_tensor = _CloneWithNewOperands( - match.layer_op, match.input_tensor, scaled_weight_tensor) + match.layer_op, match.input_tensor, scaled_weight_tensor, + match.batch_to_space_op) if correction_recip is not None: new_layer_tensor = math_ops.multiply( @@ -149,6 +151,8 @@ def _FindFusedBatchNorms(graph): _FusedBatchNormMatches. """ input_pattern = graph_matcher.OpTypePattern('*') + # In practice, the weight pattern can match a Variable or a SpaceToBatchND + # operation that follows a variable for atrous convolutions. weight_pattern = graph_matcher.OpTypePattern('*') gamma_pattern = graph_matcher.OpTypePattern('*') beta_pattern = graph_matcher.OpTypePattern('*') @@ -160,16 +164,27 @@ def _FindFusedBatchNorms(graph): layer_pattern = graph_matcher.OpTypePattern( 'Conv2D|DepthwiseConv2dNative|MatMul', inputs=[input_pattern, weight_pattern]) + batch_to_space_pattern = graph_matcher.OpTypePattern( + 'BatchToSpaceND', + inputs=[ + layer_pattern, + graph_matcher.OpTypePattern('*'), + graph_matcher.OpTypePattern('*') + ]) + layer_output_pattern = graph_matcher.OneofPattern( + [layer_pattern, batch_to_space_pattern]) # MatMul has a Reshape between it and FusedBatchNorm. matmul_reshape_pattern = graph_matcher.OpTypePattern( - 'Reshape', inputs=[layer_pattern, - graph_matcher.OpTypePattern('*')]) + 'Reshape', + inputs=[layer_output_pattern, + graph_matcher.OpTypePattern('*')]) batch_norm_pattern = graph_matcher.OpTypePattern( 'FusedBatchNorm', inputs=[ - graph_matcher.OneofPattern([matmul_reshape_pattern, layer_pattern]), - gamma_pattern, beta_pattern, mean_pattern, variance_pattern + graph_matcher.OneofPattern( + [matmul_reshape_pattern, layer_output_pattern]), gamma_pattern, + beta_pattern, mean_pattern, variance_pattern ]) matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern( 'Reshape', inputs=[batch_norm_pattern, @@ -192,6 +207,7 @@ def _FindFusedBatchNorms(graph): moving_variance_tensor = None bn_decay_mean_tensor = None bn_decay_var_tensor = None + batch_to_space_op = None layer_op = match_result.get_op(layer_pattern) layer_tensor = match_result.get_tensor(layer_pattern) bn_op = match_result.get_op(batch_norm_pattern) @@ -213,6 +229,7 @@ def _FindFusedBatchNorms(graph): if not output_tensor.consumers(): continue + batch_to_space_op = match_result.get_op(batch_to_space_pattern) input_tensor = match_result.get_tensor(input_pattern) weight_tensor = match_result.get_tensor(weight_pattern) gamma_tensor = match_result.get_tensor(gamma_pattern) @@ -276,7 +293,8 @@ def _FindFusedBatchNorms(graph): moving_variance_tensor=moving_variance_tensor, bn_decay_mean_tensor=bn_decay_mean_tensor, bn_decay_var_tensor=bn_decay_var_tensor, - batch_epsilon=batch_epsilon) + batch_epsilon=batch_epsilon, + batch_to_space_op=batch_to_space_op) def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, @@ -351,20 +369,20 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, lambda: bn_decay_zero, lambda: match.bn_decay_mean_tensor, name='freeze_moving_mean') + graph_editor.reroute_ts( [bn_decay_mean_out], [match.bn_decay_mean_tensor], can_modify=bn_decay_mean_consumers) - if fused_batch_norm is False: - bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers()) - bn_decay_var_out = utils.smart_cond( - use_mv_avg, - lambda: bn_decay_zero, - lambda: match.bn_decay_var_tensor, - name='freeze_moving_var') - graph_editor.reroute_ts( - [bn_decay_var_out], [match.bn_decay_var_tensor], - can_modify=bn_decay_var_consumers) + bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers()) + bn_decay_var_out = utils.smart_cond( + use_mv_avg, + lambda: bn_decay_zero, + lambda: match.bn_decay_var_tensor, + name='freeze_moving_var') + graph_editor.reroute_ts( + [bn_decay_var_out], [match.bn_decay_var_tensor], + can_modify=bn_decay_var_consumers) correction_recip = utils.smart_cond( use_mv_avg, @@ -380,7 +398,8 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, return correction_scale, correction_recip, correction_offset -def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor): +def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor, + batch_to_space_op): """Clones layer_op with input_tensor and weight_tensor as new inputs.""" new_layer_name = layer_op.name.split('/')[-1] + '_Fold' if layer_op.type == 'Conv2D': @@ -400,12 +419,25 @@ def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor): transpose_b=layer_op.get_attr('transpose_b'), name=new_layer_name) elif layer_op.type == 'DepthwiseConv2dNative': - return nn.depthwise_conv2d( + conv = nn.depthwise_conv2d( input_tensor, weight_tensor, + rate=layer_op.get_attr('dilations'), strides=layer_op.get_attr('strides'), padding=layer_op.get_attr('padding'), name=new_layer_name) + # Copy the batch to space operation if we have a atrous convolution. + if batch_to_space_op: + batch_to_space_op = layer_op.outputs[0].consumers()[0] + # TODO(suharshs): It's hard to make this name match with the unfused name. + # Restructure this code to not rely on scope at all. + new_batch_to_space_name = batch_to_space_op.name.split('/')[-1] + '_Fold' + conv = array_ops.batch_to_space_nd( + conv, + batch_to_space_op.inputs[1], + batch_to_space_op.inputs[2], + name=new_batch_to_space_name) + return conv else: raise ValueError('Cannot handle operation of type: %s' % layer_op.type) @@ -617,7 +649,8 @@ def _GetBatchNormParams(graph, context, has_scaling): moving_variance_tensor=moving_variance_tensor, bn_decay_mean_tensor=bn_decay_mean_tensor, bn_decay_var_tensor=bn_decay_var_tensor, - batch_epsilon=batch_epsilon) + batch_epsilon=batch_epsilon, + batch_to_space_op=None) def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, @@ -651,6 +684,11 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, '/BatchNorm/batchnorm_1/' + mul_scale_name) op_below = mul_scale.inputs[0].op + # Skip over the BatchToSpace operation in the case of atrous convolutions. + batch_to_space_op = None + if op_below.type == 'BatchToSpaceND': + batch_to_space_op = op_below + op_below = op_below.inputs[0].op weights = op_below.inputs[1] match = _GetBatchNormParams( graph=graph, context=context, has_scaling=has_scaling) @@ -691,7 +729,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, context + '/correction_mult') mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights)]) else: - raise ValueError('Cannot handle operation of type: %s' % op_below.op) + raise ValueError('Cannot handle operation of type: %s' % op_below.type) _AssertShapesMatch('mul_fold', mul_fold.inputs[0], mul_fold.outputs[0]) conv_or_fc_folded = _CloneOp(op_below, op_below.name + '_Fold', @@ -701,6 +739,13 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay, context + '/BatchNorm/batchnorm_1/add_1') corrected_output = conv_or_fc_folded.outputs[0] + # Copy the batch to space operation if we have a atrous convolution. + if batch_to_space_op: + corrected_output = array_ops.batch_to_space_nd( + corrected_output, + batch_to_space_op.inputs[1], + batch_to_space_op.inputs[2], + name=batch_to_space_op.name + '_Fold') if correction_offset is not None: with ops.device(conv_or_fc_folded.device): corrected_output = math_ops.multiply(correction_recip, corrected_output, @@ -898,7 +943,8 @@ class _BatchNormMatch(object): def __init__(self, layer_op, bn_op, output_tensor, input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, variance_tensor, moving_mean_tensor, moving_variance_tensor, - bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon): + bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon, + batch_to_space_op): self._layer_op = layer_op self._bn_op = bn_op self._output_tensor = output_tensor @@ -913,6 +959,7 @@ class _BatchNormMatch(object): self._bn_decay_mean_tensor = bn_decay_mean_tensor self._bn_decay_var_tensor = bn_decay_var_tensor self._batch_epsilon = batch_epsilon + self._batch_to_space_op = batch_to_space_op @property def layer_op(self): @@ -969,3 +1016,7 @@ class _BatchNormMatch(object): @property def bn_decay_var_tensor(self): return self._bn_decay_var_tensor + + @property + def batch_to_space_op(self): + return self._batch_to_space_op diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py index bfa9d3bf705e327091098a8e416b7902f852605a..3f8063cc022726cb745d42aba3c834c71e876e70 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py @@ -128,6 +128,9 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + if freeze_batch_norm_delay is not None: + self._AssertMovingAveragesAreFrozen(g, scope) + for op in g.get_operations(): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) @@ -216,6 +219,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): ]) output_op_names = [scope + '/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + if freeze_batch_norm_delay is not None: + self._AssertMovingAveragesAreFrozen(g, scope) for op in g.get_operations(): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) @@ -284,6 +289,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + if freeze_batch_norm_delay is not None: + self._AssertMovingAveragesAreFrozen(g, scope) for op in g.get_operations(): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) @@ -351,6 +358,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + if freeze_batch_norm_delay is not None: + self._AssertMovingAveragesAreFrozen(g, scope) for op in g.get_operations(): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) @@ -431,6 +440,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + if freeze_batch_norm_delay is not None: + self._AssertMovingAveragesAreFrozen(g, scope) for op in g.get_operations(): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) @@ -438,6 +449,92 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): def testFoldDepthwiseConv2d(self): self._RunTestOverParameters(self._TestFoldDepthwiseConv2d) + def _TestFoldAtrousConv2d(self, relu, relu_op_name, with_bypass, has_scaling, + fused_batch_norm, freeze_batch_norm_delay): + """Tests folding: inputs -> AtrousConv2d with batch norm -> Relu*. + + Args: + relu: Callable that returns an Operation, a factory method for the Relu*. + relu_op_name: String, name of the Relu* operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. + freeze_batch_norm_delay: None or the number of steps after which training + switches to using frozen mean and variance + """ + g = ops.Graph() + with g.as_default(): + batch_size, height, width = 5, 128, 128 + inputs = array_ops.zeros((batch_size, height, width, 3)) + dilation_rate = 2 + activation_fn = None if with_bypass else relu + scope = 'test/test2' if with_bypass else 'test' + node = separable_conv2d( + inputs, + None, [3, 3], + rate=dilation_rate, + depth_multiplier=1.0, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + relu(node, name='test/' + relu_op_name) + + fold_batch_norms.FoldBatchNorms( + g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay) + + folded_mul = g.get_operation_by_name(scope + '/mul_fold') + self.assertEqual(folded_mul.type, 'Mul') + if fused_batch_norm: + scale_reshape_op_name = scope + '/BatchNorm_Fold/scale_reshape' + else: + scale_reshape_op_name = scope + '/scale_reshape' + self._AssertInputOpsAre(folded_mul, + [scope + '/correction_mult', scale_reshape_op_name]) + self._AssertOutputGoesToOps(folded_mul, g, [scope + '/depthwise_Fold']) + + scale_reshape = g.get_operation_by_name(scale_reshape_op_name) + self.assertEqual(scale_reshape.type, 'Reshape') + self._AssertInputOpsAre(scale_reshape, [ + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm), + scale_reshape_op_name + '/shape' + ]) + self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold']) + + folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold') + self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative') + self._AssertInputOpsAre( + folded_conv, [scope + '/mul_fold', scope + '/depthwise/SpaceToBatchND']) + if fused_batch_norm: + self._AssertOutputGoesToOps(folded_conv, g, + [scope + '/BatchToSpaceND_Fold']) + else: + self._AssertOutputGoesToOps(folded_conv, g, + [scope + '/depthwise/BatchToSpaceND_Fold']) + + folded_add = g.get_operation_by_name(scope + '/add_fold') + self.assertEqual(folded_add.type, 'Add') + self._AssertInputOpsAre(folded_add, [ + scope + '/correction_add', + self._BathNormBiasName(scope, fused_batch_norm) + ]) + output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] + self._AssertOutputGoesToOps(folded_add, g, output_op_names) + if freeze_batch_norm_delay is not None: + self._AssertMovingAveragesAreFrozen(g, scope) + + for op in g.get_operations(): + self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) + + def testFoldAtrousConv2d(self): + self._RunTestOverParameters(self._TestFoldAtrousConv2d) + def _TestCompareFoldAndUnfolded(self, relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm, freeze_batch_norm_delay): @@ -560,6 +657,22 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): out_op = graph.get_operation_by_name(out_op_name) self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs]) + def _AssertMovingAveragesAreFrozen(self, graph, scope): + """Asserts to check if moving mean and variance are frozen. + + Args: + graph: Graph where the operations are located. + scope: Scope of batch norm op + """ + moving_average_mult = graph.get_operation_by_name( + scope + '/BatchNorm/AssignMovingAvg/mul') + self.assertTrue( + moving_average_mult.inputs[1].name.find('freeze_moving_mean/Merge') > 0) + moving_var_mult = graph.get_operation_by_name( + scope + '/BatchNorm/AssignMovingAvg_1/mul') + self.assertTrue( + moving_var_mult.inputs[1].name.find('freeze_moving_var/Merge') > 0) + def _CopyGraph(self, graph): """Return a copy of graph.""" meta_graph = saver_lib.export_meta_graph( diff --git a/tensorflow/contrib/quantize/python/quant_ops_test.py b/tensorflow/contrib/quantize/python/quant_ops_test.py index c2a8def48012c808da18587c8ff462fa33a363c0..a45840009b758881c14fb64b2d39af6cd4ec4bc4 100644 --- a/tensorflow/contrib/quantize/python/quant_ops_test.py +++ b/tensorflow/contrib/quantize/python/quant_ops_test.py @@ -75,7 +75,7 @@ class QuantOpsTest(googletest.TestCase): self.assertGreater(max_value, 0.0) self.assertLess(max_value, 1.0) - def testVariablesNotParitioned_LastValue(self): + def testVariablesNotPartitioned_LastValue(self): # Variables added should not use a default partiioner since they are # scalar. There would be a tensorflow error thrown if the partitioner was # respected by the rewrite. @@ -90,7 +90,7 @@ class QuantOpsTest(googletest.TestCase): is_training=True, vars_collection=_MIN_MAX_VARS) - def testVariablesNotParitioned_MovingAvg(self): + def testVariablesNotPartitioned_MovingAvg(self): # Variables added should not use a default partiioner since they are # scalar. There would be a tensorflow error thrown if the partitioner was # respected by the rewrite. diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index cbba72643f7f166c473b6181edc292f695c4cbc2..cb66fd1f76bcdb0a8f77fc7c476511576368ab4e 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -194,9 +194,11 @@ def _FindLayersToQuantize(graph): / conv|fc | + [batch_to_space_nd] + | [post_conv_correction] | - biasadd|folded_bias + [biasadd|folded_bias] | [bypass] | @@ -247,9 +249,31 @@ def _FindLayersToQuantize(graph): ], ordered_inputs=False) + # For atrous convolutions a BatchToSpaceND will occur after the depthwise + # convolution. + batch_to_space_pattern = graph_matcher.OpTypePattern( + 'BatchToSpaceND', + inputs=[ + layer_pattern, + graph_matcher.OpTypePattern('*'), + graph_matcher.OpTypePattern('*') + ]) + + layer_output_pattern = graph_matcher.OneofPattern( + [batch_to_space_pattern, layer_pattern]) + + # For separable convolutions, we are looking for a conv, followed by a conv + # with no activations between the two. + sep_conv_pattern = graph_matcher.OpTypePattern( + '|'.join(_QUANTIZABLE_TYPES), + inputs=[ + graph_matcher.OneofPattern([layer_output_pattern]), + graph_matcher.OpTypePattern('*') + ], + ordered_inputs=False) folded_bias_mul_pattern = graph_matcher.OpTypePattern( 'Mul', - inputs=[graph_matcher.OpTypePattern('*'), layer_pattern], + inputs=[graph_matcher.OpTypePattern('*'), layer_output_pattern], ordered_inputs=False) post_layer_op_correction_pattern = graph_matcher.OpTypePattern( 'Add', @@ -264,29 +288,39 @@ def _FindLayersToQuantize(graph): ], ordered_inputs=False) + # batch_norms with forced updates have an Identity operation at the end. + # TODO(suharshs): Find a way to easily skip extra Identity operations. The + # current issue is that doing so can often match patterns across many layers + # incorrectly. + batch_norm_identity = graph_matcher.OpTypePattern( + 'Identity', inputs=[folded_bias_add_pattern]) + bias_add_pattern = graph_matcher.OpTypePattern( - 'Add|BiasAdd', inputs=[layer_pattern, '*'], ordered_inputs=False) + 'Add|BiasAdd', inputs=[layer_output_pattern, '*'], ordered_inputs=False) # The bias can come from the bias add or the folded bias add. bypass_pattern = graph_matcher.OpTypePattern( 'Add', inputs=[ graph_matcher.OneofPattern( - [bias_add_pattern, folded_bias_add_pattern]), '*' + [bias_add_pattern, folded_bias_add_pattern, batch_norm_identity]), + '*' ], ordered_inputs=False) # The input to the activation can come from bias add, fold bias add, the # bypasses. # TODO(suharshs): We should ideally skip Identity operations instead of - # treating them as an activation. + # treating them as activations. activation_pattern = graph_matcher.OpTypePattern( '|'.join(_ACTIVATION_TYPES) + '|Identity', inputs=[ graph_matcher.OneofPattern([ bias_add_pattern, folded_bias_add_pattern, + batch_norm_identity, bypass_pattern, + layer_pattern, ]) ]) @@ -370,15 +404,18 @@ def _FindLayersToQuantize(graph): layer_matches.append( _LayerMatch(layer_op, weight_tensor, activation_op, None, None, None)) - return layer_matches - + # Look for separable convolutions here + sep_conv_matcher = graph_matcher.GraphMatcher(sep_conv_pattern) + for match_result in sep_conv_matcher.match_graph(graph): + layer_op = match_result.get_op(layer_pattern) + weight_tensor = match_result.get_tensor(weight_identity_pattern) + activation_op = match_result.get_op(layer_pattern) + if layer_op not in matched_layer_set: + matched_layer_set.add(layer_op) + layer_matches.append( + _LayerMatch(layer_op, weight_tensor, activation_op, None, None, None)) -def _HasPostActivationBypass(activation_op): - for activation_tensor in activation_op.outputs: - for output_op in activation_tensor.consumers(): - if output_op.type == 'Add': - return True - return False + return layer_matches class _LayerMatch(object): diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py index 11d052d7f491dc029d1bda9b47364d6e9c880a67..2944f964c7078814111c96890f18abe1607b68fc 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph.py +++ b/tensorflow/contrib/quantize/python/quantize_graph.py @@ -191,6 +191,7 @@ def experimental_create_training_graph(input_graph=None, def experimental_create_eval_graph(input_graph=None, weight_bits=8, activation_bits=8, + quant_delay=None, scope=None): """Rewrites an eval input_graph in place for simulated quantization. @@ -209,6 +210,8 @@ def experimental_create_eval_graph(input_graph=None, default graph. weight_bits: Number of bits to use for quantizing weights. activation_bits: Number of bits to use for quantizing activations. + quant_delay: Number of steps after which weights and activations are + quantized during eval. scope: The scope to be transformed. If it's not None, only the ops which are in this scope will be transformed. @@ -221,4 +224,5 @@ def experimental_create_eval_graph(input_graph=None, is_training=False, weight_bits=weight_bits, activation_bits=activation_bits, + quant_delay=quant_delay, scope=scope) diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py index db745aa56212af6a9c20e06ee9e4e5d6e27cf3c3..31a2955ddb3b32f2b07c6125c8f83ffba335cc5f 100644 --- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py +++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py @@ -276,6 +276,52 @@ class QuantizeTest(test_util.TensorFlowTestCase): graph, scope, 'DepthwiseConv2dNative', activation_op_name, with_bypass, delay, use_resource) + def testQuantize_AtrousConvWithoutBatchNorm(self): + self._RunWithoutBatchNormTestOverParameters( + self._TestQuantize_AtrousConvWithoutBatchNorm) + + def _TestQuantize_AtrousConvWithoutBatchNorm( + self, activation, activation_op_name, with_bypass, delay, use_resource): + """Tests quantization: inputs -> atrous conv no batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + use_resource: Bool, when true uses resource variables. + """ + graph = ops.Graph() + with graph.as_default(): + variable_scope.get_variable_scope().set_use_resource(use_resource) + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + dilation_rate = 2 + activation_fn = None if with_bypass else activation + scope = 'test/test2' if with_bypass else 'test' + node = separable_conv2d( + inputs, + None, [3, 3], + rate=dilation_rate, + depth_multiplier=1.0, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + quantize.Quantize(graph, True, quant_delay=delay) + + self._AssertCorrectQuantizedGraphWithoutBatchNorm( + graph, scope, 'DepthwiseConv2dNative', activation_op_name, with_bypass, + delay, use_resource) + def _RunBatchNormTestOverParameters(self, test_fn): # TODO(suharshs): Use parameterized test once OSS TF supports it. parameters_list = [ @@ -543,6 +589,61 @@ class QuantizeTest(test_util.TensorFlowTestCase): graph, scope, 'DepthwiseConv2dNative', activation_op_name, with_bypass, delay, use_resource) + def testQuantize_AtrousConvWithBatchNorm(self): + self._RunBatchNormTestOverParameters( + self._TestQuantize_AtrousConvWithBatchNorm) + + def _TestQuantize_AtrousConvWithBatchNorm( + self, activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_resource): + """Tests quantization: inputs -> atrous conv with batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. + use_resource: Bool, when true uses resource variables. + """ + graph = ops.Graph() + with graph.as_default(): + variable_scope.get_variable_scope().set_use_resource(use_resource) + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + dilation_rate = 2 + scope = 'test/test2' if with_bypass else 'test' + node = separable_conv2d( + inputs, + None, [3, 3], + rate=dilation_rate, + depth_multiplier=1.0, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams(fused_batch_norm), + scope=scope) + + # Manually add a bypass (optional) and an activation. + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + + node = activation(node, name='test/' + activation_op_name) + + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + + fold_batch_norms.FoldBatchNorms(graph, is_training=True) + quantize.Quantize(graph, True, quant_delay=delay) + + self._AssertCorrectQuantizedGraphWithBatchNorm( + graph, scope, 'DepthwiseConv2dNative', activation_op_name, + with_bypass, delay, use_resource) + def _AssertIdempotent(self, graph): # Ensure that calling the rewrite again doesn't change the graph. graph_def_before = str(graph.as_graph_def()) @@ -553,8 +654,80 @@ class QuantizeTest(test_util.TensorFlowTestCase): graph_def_after = str(graph.as_graph_def()) self.assertEqual(graph_def_before, graph_def_after) - def _BatchNormParams(self, fused=False): - return {'center': True, 'scale': True, 'decay': 1.0 - 0.003, 'fused': fused} + def testBatchNormForcedUpdates(self): + parameter_list = [ + # (activation, activation_op_name, fused_batch_norm) + (nn_ops.relu6, 'Relu6', False), + (nn_ops.relu, 'Relu', False), + (array_ops.identity, 'Identity', False), + (nn_ops.relu6, 'Relu6', True), + (nn_ops.relu, 'Relu', True), + (array_ops.identity, 'Identity', True), + ] + for params in parameter_list: + self._TestBatchNormForcedUpdates(params[0], params[1], params[2], False) + self._TestBatchNormForcedUpdates(params[0], params[1], params[2], True) + + def _TestBatchNormForcedUpdates(self, activation, activation_op_name, + fused_batch_norm, use_resource): + """post_activation bypass quantization should happen with forced updates.""" + graph = ops.Graph() + with graph.as_default(): + variable_scope.get_variable_scope().set_use_resource(use_resource) + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + input2 = array_ops.zeros((batch_size, height / 2, width / 2, 32)) + # Setting updates_collections to None forces updates adding an extra + # identity operation following batch norms. + bn_params = self._BatchNormParams( + fused=fused_batch_norm, force_updates=True) + conv = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation, + normalizer_fn=batch_norm, + normalizer_params=bn_params, + scope='test/test') + bypass_tensor = math_ops.add(conv, input2, name='test/add') + # The output of the post_activation bypass will be another layer. + _ = conv2d( + bypass_tensor, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + normalizer_fn=batch_norm, + normalizer_params=bn_params, + activation_fn=activation, + scope='test/unused') + + fold_batch_norms.FoldBatchNorms(graph, is_training=True) + quantize.Quantize(graph, is_training=True) + + # Ensure that the bypass node is preceded by and followed by a + # FakeQuantWithMinMaxVar operation, since the output of the Add isn't an + # activation. + self.assertTrue('FakeQuantWithMinMaxVars' in + [c.type for c in bypass_tensor.consumers()]) + self.assertTrue('FakeQuantWithMinMaxVars' in + [i.op.type for i in bypass_tensor.op.inputs]) + + with open('/tmp/bn_quant_test.pbtxt', 'w') as f: + f.write(str(graph.as_graph_def())) + + def _BatchNormParams(self, fused=False, force_updates=False): + params = { + 'center': True, + 'scale': True, + 'decay': 1.0 - 0.003, + 'fused': fused + } + if force_updates: + params['updates_collections'] = None + return params def _WeightInit(self, stddev): """Returns truncated normal variable initializer. diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index 92ca4a1b0c3126ebccf2b525f01f4d6455c4d527..06ebcdfee1617af0c13cd6ed09a2ec5190c5a718 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -122,12 +122,67 @@ class QuantizeTest(test_util.TensorFlowTestCase): array_ops.identity(node, name='control_dependency') quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + # Check if output of bias add is quantized + quantization_node_name = 'FakeQuantWithMinMaxVars' + conv_quant = graph.get_operation_by_name('test/test/conv_quant/' + + quantization_node_name) + self.assertEqual(conv_quant.type, quantization_node_name) + + for op in graph.get_operations(): + if op.type == quantization_node_name: + quant_op = graph.get_operation_by_name(op.name) + # Scan through all FakeQuant operations, ensuring that the activation + # identity op isn't in the consumers of the operation. + consumers = [] + for output in quant_op.outputs: + consumers.extend(output.consumers()) + + self.assertNotIn('test/relu6', [c.name for c in consumers]) + + def testInsertQuantOpInSeparableConv2d(self): + self._RunTestOverParameters(self._TestInsertQuantOpInSeparableConv2d) + + def _TestInsertQuantOpInSeparableConv2d(self, is_training): + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + input2 = array_ops.zeros((batch_size, height / 2, width / 2, depth)) + conv = separable_conv2d( + input1, + 3, [5, 5], + stride=2, + depth_multiplier=1.0, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + scope='test/test') + node = math_ops.add(conv, input2, name='test/add') + node = nn_ops.relu6(node, name='test/relu6') + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + # Check if output of bias add is quantized quantization_node_name = 'FakeQuantWithMinMaxVars' conv_quant = graph.get_operation_by_name('test/test/conv_quant/' + quantization_node_name) self.assertEqual(conv_quant.type, quantization_node_name) + # Check if weights for both convs inside seperable conv are quantized + pointwise_weight_quant = graph.get_operation_by_name( + 'test/test/weights_quant/' + quantization_node_name) + self.assertEqual(pointwise_weight_quant.type, quantization_node_name) + depthwise_weight_quant = graph.get_operation_by_name( + 'test/test/separable_conv2d/weights_quant/' + quantization_node_name) + self.assertEqual(depthwise_weight_quant.type, quantization_node_name) + + # Check if activations after first depthwise conv are quantized. + depthwise_act_quant = graph.get_operation_by_name( + 'test/test/separable_conv2d/act_quant/' + quantization_node_name) + self.assertEqual(depthwise_act_quant.type, quantization_node_name) + for op in graph.get_operations(): if op.type == quantization_node_name: quant_op = graph.get_operation_by_name(op.name) @@ -139,6 +194,33 @@ class QuantizeTest(test_util.TensorFlowTestCase): self.assertNotIn('test/relu6', [c.name for c in consumers]) + def testLayerActivationQuantized(self): + self._RunTestOverParameters(self._TestLayerActivationQuantized) + + def _TestLayerActivationQuantized(self, is_training): + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + _ = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=nn_ops.relu6, + biases_initializer=None, + scope='test') + # Ensure that both weights and output of activations are quantized + # when we have a conv->relu6 with no bias add + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + activation_op = graph.get_operation_by_name('test/Relu6') + conv_op = graph.get_operation_by_name('test/Conv2D') + self.assertTrue('test/weights_quant/FakeQuantWithMinMaxVars:0' in + [tensor_in.name for tensor_in in conv_op.inputs]) + self.assertTrue('FakeQuantWithMinMaxVars' in + [op.type for op in activation_op.outputs[0].consumers()]) + def testFinalLayerQuantized(self): self._RunTestOverParameters(self._TestFinalLayerQuantized) diff --git a/tensorflow/contrib/recurrent/BUILD b/tensorflow/contrib/recurrent/BUILD index b3cb04ce26d96333f516f1298c8d5c331964f05b..f9827f766da022b184b3348fc24b1570bac8678f 100644 --- a/tensorflow/contrib/recurrent/BUILD +++ b/tensorflow/contrib/recurrent/BUILD @@ -102,5 +102,8 @@ cuda_py_tests( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], - tags = ["nopip"], + tags = [ + "nopip", + "optonly", + ], ) diff --git a/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py b/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py index 0f19ac7dbe0cee2eb6c780ec5ea6266bc847abd7..f23194a6f2e64e0619049bac51891d6d6099831f 100644 --- a/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py +++ b/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py @@ -61,10 +61,17 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase): func, args = self._CELLDEFS[celldef_name] return func(*args) - def _CreateInputs(self): - inputs = np.random.random([FunctionalRnnTest._BATCH_SIZE, - FunctionalRnnTest._TOTAL_TIME, - FunctionalRnnTest._INPUT_SIZE]) + def _CreateInputs(self, time_major=False): + if time_major: + inputs = np.random.random([ + FunctionalRnnTest._TOTAL_TIME, FunctionalRnnTest._BATCH_SIZE, + FunctionalRnnTest._INPUT_SIZE + ]) + else: + inputs = np.random.random([ + FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._TOTAL_TIME, + FunctionalRnnTest._INPUT_SIZE + ]) # Always leave one time slot empty, to check max_length behavior. sequence_length = np.random.randint( 0, high=FunctionalRnnTest._TOTAL_TIME - 1, @@ -72,15 +79,51 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase): dtype=np.int) return (inputs, sequence_length) - def _CreateRnnGraph(self, create_rnn_computation_func, cell, tf_inputs, - tf_sequence_length, initial_state=None, - time_major=None, scope=None): - tf_result = create_rnn_computation_func(cell=cell, inputs=tf_inputs, - sequence_length=tf_sequence_length, - initial_state=initial_state, - dtype=dtypes.float32, - time_major=time_major, - scope=scope) + def _CreateSymmetricInputs(self): + # total time = batch size + inputs = np.zeros( + (FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._BATCH_SIZE, + FunctionalRnnTest._INPUT_SIZE)) + for i in range(FunctionalRnnTest._BATCH_SIZE): + for j in range(i, FunctionalRnnTest._BATCH_SIZE): + inputs[i][j] = np.random.random([FunctionalRnnTest._INPUT_SIZE]) + inputs[j][i] = inputs[i][j] + + # Always leave one time slot empty, to check max_length behavior. + sequence_length = np.random.randint( + 0, + high=FunctionalRnnTest._BATCH_SIZE - 1, + size=FunctionalRnnTest._BATCH_SIZE, + dtype=np.int) + return (inputs, sequence_length) + + def _CreateRnnGraph(self, + create_rnn_computation_func, + cell, + tf_inputs, + tf_sequence_length, + is_bidirectional, + initial_state=None, + time_major=None, + scope=None): + if is_bidirectional: + tf_result = create_rnn_computation_func( + cell_fw=cell, + cell_bw=cell, + inputs=tf_inputs, + sequence_length=tf_sequence_length, + dtype=dtypes.float32, + time_major=time_major, + scope=scope) + else: + tf_result = create_rnn_computation_func( + cell=cell, + inputs=tf_inputs, + sequence_length=tf_sequence_length, + initial_state=initial_state, + dtype=dtypes.float32, + time_major=time_major, + scope=scope) grad = gradients_impl.gradients(tf_result, variables.trainable_variables()) return {'inference': tf_result, 'grad': grad} @@ -102,15 +145,26 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase): variable_cache[n] = v def _RunRnn(self, numpy_inputs, numpy_slen, cell_name, variable_cache, - is_dynamic): + is_dynamic, time_major=None, is_bidirectional=False): with ops.Graph().as_default() as graph: tf_inputs = array_ops.placeholder( dtypes.float32, shape=numpy_inputs.shape) tf_slen = array_ops.placeholder(dtypes.int32) feeds = {tf_inputs: numpy_inputs, tf_slen: numpy_slen} cell = self._CreateCell(cell_name) - fn = rnn_lib.dynamic_rnn if is_dynamic else functional_rnn.functional_rnn - fetches = self._CreateRnnGraph(fn, cell, tf_inputs, tf_slen) + if is_dynamic: + if is_bidirectional: + fn = rnn_lib.bidirectional_dynamic_rnn + else: + fn = rnn_lib.dynamic_rnn + else: + if is_bidirectional: + fn = functional_rnn.bidirectional_functional_rnn + else: + fn = functional_rnn.functional_rnn + + fetches = self._CreateRnnGraph( + fn, cell, tf_inputs, tf_slen, is_bidirectional, time_major=time_major) with self.test_session(graph=graph) as sess: sess.run(variables.global_variables_initializer()) # Note that cell.trainable_variables it not always set. @@ -158,6 +212,78 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase): self.assertAllClose(dyn_rnn['inference'], func_rnn['inference']) self.assertAllClose(dyn_rnn['grad'], func_rnn['grad']) + def testLstmWithTimeMajorInputs(self): + """Checks an LSTM against the reference implementation, with time_major.""" + time_major = True + np_inputs, np_slen = self._CreateInputs(time_major=True) + var_cache = {} + args = [np_inputs, np_slen, 'lstm', var_cache] + _, func_rnn = self._RunRnn(*(args + [False]), time_major=time_major) + _, dyn_rnn = self._RunRnn(*(args + [True]), time_major=time_major) + self.assertAllClose(dyn_rnn['inference'], func_rnn['inference']) + self.assertAllClose(dyn_rnn['grad'], func_rnn['grad']) + + def testBidirectionalLstmWithTimeMajorInputs(self): + """Checks a bi-directional LSTM with time-major inputs.""" + time_major = True + np_inputs, np_slen = self._CreateInputs(time_major) + var_cache = {} + args = [np_inputs, np_slen, 'lstm', var_cache] + _, func_rnn = self._RunRnn( + *(args + [False]), time_major=time_major, is_bidirectional=True) + _, dyn_rnn = self._RunRnn( + *(args + [True]), time_major=time_major, is_bidirectional=True) + self.assertAllClose(dyn_rnn['inference'], func_rnn['inference']) + # TODO(b/112170761): comment out this line after the bug is fixed. + # self.assertAllClose(dyn_rnn['grad'], func_rnn['grad']) + + def testBidirectionalLstm(self): + """Checks time-major and batch-major rnn produce consistent results.""" + time_major_inputs, np_slen = self._CreateInputs(True) + batch_major_inputs = np.transpose(time_major_inputs, [1, 0, 2]) + var_cache = {} + args = [np_slen, 'lstm', var_cache, False] + _, time_major_rnn = self._RunRnn( + *([time_major_inputs] + args), time_major=True, is_bidirectional=True) + _, batch_major_rnn = self._RunRnn( + *([batch_major_inputs]+ args), time_major=False, is_bidirectional=True) + # Convert the batch-major outputs to be time-major before the comparasion. + outputs, state = batch_major_rnn['inference'] + outputs = [np.transpose(x, [1, 0, 2]) for x in outputs] + batch_major_rnn['inference'] = [outputs, state] + self.assertAllClose(time_major_rnn['inference'], + batch_major_rnn['inference']) + self.assertAllClose(time_major_rnn['grad'], batch_major_rnn['grad']) + + def testBidirectionalLstmWithSymmetricInputs(self): + """Checks a bi-directional LSTM with symmetric inputs. + + time-major and batch-major rnn produce the same result with symmetric + inputs. + """ + np_inputs, np_slen = self._CreateSymmetricInputs() + var_cache = {} + args = [np_inputs, np_slen, 'lstm', var_cache] + _, time_major_func_rnn = self._RunRnn( + *(args + [False]), time_major=True, is_bidirectional=True) + _, batch_major_func_rnn = self._RunRnn( + *(args + [False]), time_major=False, is_bidirectional=True) + _, time_major_dyn_rnn = self._RunRnn( + *(args + [True]), time_major=True, is_bidirectional=True) + _, batch_major_dyn_rnn = self._RunRnn( + *(args + [True]), time_major=False, is_bidirectional=True) + self.assertAllClose(time_major_func_rnn['inference'], + batch_major_func_rnn['inference']) + self.assertAllClose(time_major_func_rnn['grad'], + batch_major_func_rnn['grad']) + self.assertAllClose(time_major_dyn_rnn['inference'], + batch_major_dyn_rnn['inference']) + self.assertAllClose(time_major_dyn_rnn['grad'], batch_major_dyn_rnn['grad']) + self.assertAllClose(time_major_func_rnn['inference'], + batch_major_dyn_rnn['inference']) + self.assertAllClose(time_major_func_rnn['grad'], + batch_major_dyn_rnn['grad']) + if __name__ == '__main__': test_lib.main() diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py index a085474c1bf6117ba5663139c78d8f08f71392d3..67a8f59c3c03d01a5957a9eff8bd026e70770a45 100644 --- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py +++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py @@ -206,7 +206,7 @@ def _PickFinalStateFromHistory(acc_state, sequence_length): lengths = array_ops.tile(array_ops.reshape(sequence_length, [-1, 1]), [1, max_time]) last_idx = math_ops.cast(math_ops.equal(output_time, lengths - 1), - dtype=dtypes.float32) + dtype=state_var.dtype) last_idx = array_ops.transpose(last_idx) last_idx_for_bcast = array_ops.expand_dims(last_idx, -1) sliced = math_ops.multiply(last_idx_for_bcast, state_var) @@ -284,8 +284,13 @@ def functional_rnn(cell, inputs, sequence_length=None, inputs=inputs, cell_fn=func_cell.cell_step, use_tpu=use_tpu) - return _PostProcessOutput(extended_acc_state, extended_final_state, - func_cell, inputs_flat[0].shape[0], sequence_length) + tf_output, tf_state = _PostProcessOutput( + extended_acc_state, extended_final_state, func_cell, + inputs_flat[0].shape[0], sequence_length) + + if time_major: + tf_output = array_ops.transpose(tf_output, [1, 0, 2]) + return tf_output, tf_state def bidirectional_functional_rnn( diff --git a/tensorflow/contrib/recurrent/python/ops/recurrent.py b/tensorflow/contrib/recurrent/python/ops/recurrent.py index fa16b82ab62f27d034c3ca7584e7e1ca14be6f9b..4f289e0c85e2260a44a8ea2f3f1d6cacbc839f66 100644 --- a/tensorflow/contrib/recurrent/python/ops/recurrent.py +++ b/tensorflow/contrib/recurrent/python/ops/recurrent.py @@ -79,7 +79,7 @@ def _Index(struct, index): """ index = ops.convert_to_tensor(index) index.get_shape().assert_has_rank(0) - return nest.map_structure(lambda x: x[index], struct) + return nest.map_structure(lambda x: array_ops.gather(x, index), struct) def _Update(struct_acc, struct_x, t): diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 4eb5c920b3517a8968ff730003e786ae2a9c9e26..1a1be5dc70069288a76ffd5cbf6ee3ab9ce3ae4d 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -118,7 +118,6 @@ cuda_py_tests( "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", "//tensorflow/python:rnn", "//tensorflow/python:rnn_cell", "//tensorflow/python:variable_scope", @@ -176,7 +175,7 @@ cuda_py_tests( tf_py_test( name = "fused_rnn_cell_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/fused_rnn_cell_test.py"], additional_deps = [ ":rnn_py", @@ -193,10 +192,6 @@ tf_py_test( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], - tags = [ - "manual", - "notap", - ], ) cuda_py_tests( diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py index 67f31785b57fddef67733c18c3b744322532c28c..cb437f2a2f252fcb0763587b07fed19be5887282 100644 --- a/tensorflow/contrib/rnn/__init__.py +++ b/tensorflow/contrib/rnn/__init__.py @@ -58,6 +58,10 @@ See @{$python/contrib.rnn} guide. @@Conv3DLSTMCell @@HighwayWrapper @@GLSTMCell +@@SRUCell +@@IndRNNCell +@@IndyGRUCell +@@IndyLSTMCell @@AttentionCellWrapper diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index b8840a8f2420f1bc6c75f0a02e5465c595378dec..85f0f8ced91e15cd0f9b3bc51f3a9e3aee12c978 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools import os import numpy as np @@ -35,7 +34,6 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope @@ -117,6 +115,27 @@ class RNNCellTest(test.TestCase): }) self.assertEqual(res[0].shape, (1, 2)) + def testIndRNNCell(self): + with self.test_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 2]) + cell = contrib_rnn_cell.IndRNNCell(2) + g, _ = cell(x, m) + self.assertEqual([ + "root/ind_rnn_cell/%s_w:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/ind_rnn_cell/%s_u:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/ind_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME + ], [v.name for v in cell.trainable_variables]) + self.assertFalse(cell.non_trainable_variables) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + self.assertEqual(res[0].shape, (1, 2)) + def testGRUCell(self): with self.test_session() as sess: with variable_scope.variable_scope( @@ -145,6 +164,34 @@ class RNNCellTest(test.TestCase): # Smoke test self.assertAllClose(res[0], [[0.156736, 0.156736]]) + def testIndyGRUCell(self): + with self.test_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 2]) + g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.185265, 0.17704]]) + with variable_scope.variable_scope( + "other", initializer=init_ops.constant_initializer(0.5)): + # Test IndyGRUCell with input_size != num_units. + x = array_ops.zeros([1, 3]) + m = array_ops.zeros([1, 2]) + g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.155127, 0.157328]]) + def testSRUCell(self): with self.test_session() as sess: with variable_scope.variable_scope( @@ -345,6 +392,72 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1], expected_mem0) self.assertAllClose(res[2], expected_mem1) + def testIndyLSTMCell(self): + for dtype in [dtypes.float16, dtypes.float32]: + np_dtype = dtype.as_numpy_dtype + with self.test_session(graph=ops.Graph()) as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2], dtype=dtype) + state_0 = (array_ops.zeros([1, 2], dtype=dtype),) * 2 + state_1 = (array_ops.zeros([1, 2], dtype=dtype),) * 2 + cell = rnn_cell_impl.MultiRNNCell( + [contrib_rnn_cell.IndyLSTMCell(2) for _ in range(2)]) + self.assertEqual(cell.dtype, None) + self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name) + self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name) + cell.get_config() # Should not throw an error + g, (out_state_0, out_state_1) = cell(x, (state_0, state_1)) + # Layer infers the input type. + self.assertEqual(cell.dtype, dtype.name) + expected_variable_names = [ + "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_w:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_u:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s:0" % + rnn_cell_impl._BIAS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_w:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_u:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s:0" % + rnn_cell_impl._BIAS_VARIABLE_NAME + ] + self.assertEqual(expected_variable_names, + [v.name for v in cell.trainable_variables]) + self.assertFalse(cell.non_trainable_variables) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run( + [g, out_state_0, out_state_1], { + x.name: np.array([[1., 1.]]), + state_0[0].name: 0.1 * np.ones([1, 2]), + state_0[1].name: 0.1 * np.ones([1, 2]), + state_1[0].name: 0.1 * np.ones([1, 2]), + state_1[1].name: 0.1 * np.ones([1, 2]), + }) + self.assertEqual(len(res), 3) + variables = variables_lib.global_variables() + self.assertEqual(expected_variable_names, [v.name for v in variables]) + # Only check the range of outputs as this is just a smoke test. + self.assertAllInRange(res[0], -1.0, 1.0) + self.assertAllInRange(res[1], -1.0, 1.0) + self.assertAllInRange(res[2], -1.0, 1.0) + with variable_scope.variable_scope( + "other", initializer=init_ops.constant_initializer(0.5)): + # Test IndyLSTMCell with input_size != num_units. + x = array_ops.zeros([1, 3], dtype=dtype) + state = (array_ops.zeros([1, 2], dtype=dtype),) * 2 + g, out_state = contrib_rnn_cell.IndyLSTMCell(2)(x, state) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run( + [g, out_state], { + x.name: np.array([[1., 1., 1.]], dtype=np_dtype), + state[0].name: 0.1 * np.ones([1, 2], dtype=np_dtype), + state[1].name: 0.1 * np.ones([1, 2], dtype=np_dtype), + }) + self.assertEqual(len(res), 2) + def testLSTMCell(self): with self.test_session() as sess: num_units = 8 @@ -443,7 +556,7 @@ class RNNCellTest(test.TestCase): self.assertTrue( float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testWrapperCheckpointing(self): for wrapper_type in [ rnn_cell_impl.DropoutWrapper, @@ -935,50 +1048,6 @@ class DropoutWrapperTest(test.TestCase): self.assertAllClose(res0[1].h, res1[1].h) -class SlimRNNCellTest(test.TestCase): - - def testBasicRNNCell(self): - with self.test_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - my_cell = functools.partial(basic_rnn_cell, num_units=2) - # pylint: disable=protected-access - g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m) - # pylint: enable=protected-access - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - self.assertEqual(res[0].shape, (1, 2)) - - def testBasicRNNCellMatch(self): - batch_size = 32 - input_size = 100 - num_units = 10 - with self.test_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - inputs = random_ops.random_uniform((batch_size, input_size)) - _, initial_state = basic_rnn_cell(inputs, None, num_units) - rnn_cell = rnn_cell_impl.BasicRNNCell(num_units) - outputs, state = rnn_cell(inputs, initial_state) - variable_scope.get_variable_scope().reuse_variables() - my_cell = functools.partial(basic_rnn_cell, num_units=num_units) - # pylint: disable=protected-access - slim_cell = rnn_cell_impl._SlimRNNCell(my_cell) - # pylint: enable=protected-access - slim_outputs, slim_state = slim_cell(inputs, initial_state) - self.assertEqual(slim_outputs.get_shape(), outputs.get_shape()) - self.assertEqual(slim_state.get_shape(), state.get_shape()) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([slim_outputs, slim_state, outputs, state]) - self.assertAllClose(res[0], res[2]) - self.assertAllClose(res[1], res[3]) - - def basic_rnn_cell(inputs, state, num_units, scope=None): if state is None: if inputs is not None: diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index be99a5d67a3e49b1d522406601d050392f75e963..f7088c63313f91b709e1ed0c583841d1f3be3c28 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -921,7 +921,7 @@ class LSTMTest(test.TestCase): # Smoke test, this should not raise an error rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDynamicRNNWithTupleStates(self): num_units = 3 input_size = 5 @@ -997,7 +997,7 @@ class LSTMTest(test.TestCase): self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic) self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDynamicRNNWithNestedTupleStates(self): num_units = 3 input_size = 5 @@ -1285,9 +1285,9 @@ class LSTMTest(test.TestCase): "Comparing individual variable gradients iteration %d" % i) self.assertAllEqual(a, b) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDynamicEquivalentToStaticRNN(self): - self._testDynamicEquivalentToStaticRNN(use_sequence_length=False) + self._testDynamicEquivalentToStaticRNN(use_sequence_length=True) self._testDynamicEquivalentToStaticRNN(use_sequence_length=False) diff --git a/tensorflow/contrib/rnn/python/ops/rnn.py b/tensorflow/contrib/rnn/python/ops/rnn.py index 2f0caadda336b878e58e973e1c995cbec65d5732..0266b72dcb15e4aba01a9a31b4be75c5b84d44da 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn.py +++ b/tensorflow/contrib/rnn/python/ops/rnn.py @@ -175,7 +175,7 @@ def stack_bidirectional_dynamic_rnn(cells_fw, Returns: A tuple (outputs, output_state_fw, output_state_bw) where: outputs: Output `Tensor` shaped: - `batch_size, max_time, layers_output]`. Where layers_output + `[batch_size, max_time, layers_output]`. Where layers_output are depth-concatenated forward and backward outputs. output_states_fw is the final states, one tensor per layer, of the forward rnn. diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index b12e2cd5eddc3f8abdba62781692673a40e41d9b..f74c95f96299cf132a9a1d8ab8b238a532e2695b 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -23,6 +23,7 @@ import math from tensorflow.contrib.compiler import jit from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.rnn.python.ops import core_rnn_cell +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops @@ -30,6 +31,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_impl # pylint: disable=unused-import @@ -3050,3 +3052,343 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell): new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) return new_h, new_state + + +class IndRNNCell(rnn_cell_impl.LayerRNNCell): + """Independently Recurrent Neural Network (IndRNN) cell + (cf. https://arxiv.org/abs/1803.04831). + + Args: + num_units: int, The number of units in the RNN cell. + activation: Nonlinearity to use. Default: `tanh`. + reuse: (optional) Python boolean describing whether to reuse variables + in an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. + name: String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require reuse=True in such + cases. + dtype: Default dtype of the layer (default of `None` means use the type + of the first input). Required when `build` is called before `call`. + """ + + def __init__(self, + num_units, + activation=None, + reuse=None, + name=None, + dtype=None): + super(IndRNNCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) + + # Inputs must be 2-dimensional. + self.input_spec = base_layer.InputSpec(ndim=2) + + self._num_units = num_units + self._activation = activation or math_ops.tanh + + @property + def state_size(self): + return self._num_units + + @property + def output_size(self): + return self._num_units + + def build(self, inputs_shape): + if inputs_shape[1].value is None: + raise ValueError( + "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) + + input_depth = inputs_shape[1].value + # pylint: disable=protected-access + self._kernel_w = self.add_variable( + "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + shape=[input_depth, self._num_units]) + self._kernel_u = self.add_variable( + "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + shape=[1, self._num_units], + initializer=init_ops.random_uniform_initializer( + minval=-1, maxval=1, dtype=self.dtype)) + self._bias = self.add_variable( + rnn_cell_impl._BIAS_VARIABLE_NAME, + shape=[self._num_units], + initializer=init_ops.zeros_initializer(dtype=self.dtype)) + # pylint: enable=protected-access + + self.built = True + + def call(self, inputs, state): + """IndRNN: output = new_state = act(W * input + u * state + B).""" + + gate_inputs = math_ops.matmul(inputs, self._kernel_w) + ( + state * self._kernel_u) + gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) + output = self._activation(gate_inputs) + return output, output + + +class IndyGRUCell(rnn_cell_impl.LayerRNNCell): + r"""Independently Gated Recurrent Unit cell. + + Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to GRUCell, + yet with the \(U_r\), \(U_z\), and \(U\) matrices in equations 5, 6, and + 8 of http://arxiv.org/abs/1406.1078 respectively replaced by diagonal + matrices, i.e. a Hadamard product with a single vector: + + $$r_j = \sigma\left([\mathbf W_r\mathbf x]_j + + [\mathbf u_r\circ \mathbf h_{(t-1)}]_j\right)$$ + $$z_j = \sigma\left([\mathbf W_z\mathbf x]_j + + [\mathbf u_z\circ \mathbf h_{(t-1)}]_j\right)$$ + $$\tilde{h}^{(t)}_j = \phi\left([\mathbf W \mathbf x]_j + + [\mathbf u \circ \mathbf r \circ \mathbf h_{(t-1)}]_j\right)$$ + + where \(\circ\) denotes the Hadamard operator. This means that each IndyGRU + node sees only its own state, as opposed to seeing all states in the same + layer. + + TODO(gonnet): Write a paper describing this and add a reference here. + + Args: + num_units: int, The number of units in the GRU cell. + activation: Nonlinearity to use. Default: `tanh`. + reuse: (optional) Python boolean describing whether to reuse variables + in an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. + kernel_initializer: (optional) The initializer to use for the weight + matrices applied to the input. + bias_initializer: (optional) The initializer to use for the bias. + name: String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require reuse=True in such + cases. + dtype: Default dtype of the layer (default of `None` means use the type + of the first input). Required when `build` is called before `call`. + """ + + def __init__(self, + num_units, + activation=None, + reuse=None, + kernel_initializer=None, + bias_initializer=None, + name=None, + dtype=None): + super(IndyGRUCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) + + # Inputs must be 2-dimensional. + self.input_spec = base_layer.InputSpec(ndim=2) + + self._num_units = num_units + self._activation = activation or math_ops.tanh + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + + @property + def state_size(self): + return self._num_units + + @property + def output_size(self): + return self._num_units + + def build(self, inputs_shape): + if inputs_shape[1].value is None: + raise ValueError( + "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) + + input_depth = inputs_shape[1].value + # pylint: disable=protected-access + self._gate_kernel_w = self.add_variable( + "gates/%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + shape=[input_depth, 2 * self._num_units], + initializer=self._kernel_initializer) + self._gate_kernel_u = self.add_variable( + "gates/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + shape=[1, 2 * self._num_units], + initializer=init_ops.random_uniform_initializer( + minval=-1, maxval=1, dtype=self.dtype)) + self._gate_bias = self.add_variable( + "gates/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME, + shape=[2 * self._num_units], + initializer=(self._bias_initializer + if self._bias_initializer is not None else + init_ops.constant_initializer(1.0, dtype=self.dtype))) + self._candidate_kernel_w = self.add_variable( + "candidate/%s" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + shape=[input_depth, self._num_units], + initializer=self._kernel_initializer) + self._candidate_kernel_u = self.add_variable( + "candidate/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + shape=[1, self._num_units], + initializer=init_ops.random_uniform_initializer( + minval=-1, maxval=1, dtype=self.dtype)) + self._candidate_bias = self.add_variable( + "candidate/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME, + shape=[self._num_units], + initializer=(self._bias_initializer + if self._bias_initializer is not None else + init_ops.zeros_initializer(dtype=self.dtype))) + # pylint: enable=protected-access + + self.built = True + + def call(self, inputs, state): + """Gated recurrent unit (GRU) with nunits cells.""" + + gate_inputs = math_ops.matmul(inputs, self._gate_kernel_w) + ( + gen_array_ops.tile(state, [1, 2]) * self._gate_kernel_u) + gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias) + + value = math_ops.sigmoid(gate_inputs) + r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) + + r_state = r * state + + candidate = math_ops.matmul(inputs, self._candidate_kernel_w) + ( + r_state * self._candidate_kernel_u) + candidate = nn_ops.bias_add(candidate, self._candidate_bias) + + c = self._activation(candidate) + new_h = u * state + (1 - u) * c + return new_h, new_h + + +class IndyLSTMCell(rnn_cell_impl.LayerRNNCell): + r"""Basic IndyLSTM recurrent network cell. + + Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to + BasicLSTMCell, yet with the \(U_f\), \(U_i\), \(U_o\) and \(U_c\) + matrices in + https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate + replaced by diagonal matrices, i.e. a Hadamard product with a single vector: + + $$f_t = \sigma_g\left(W_f x_t + u_f \circ h_{t-1} + b_f\right)$$ + $$i_t = \sigma_g\left(W_i x_t + u_i \circ h_{t-1} + b_i\right)$$ + $$o_t = \sigma_g\left(W_o x_t + u_o \circ h_{t-1} + b_o\right)$$ + $$c_t = f_t \circ c_{t-1} + + i_t \circ \sigma_c\left(W_c x_t + u_c \circ h_{t-1} + b_c\right)$$ + + where \(\circ\) denotes the Hadamard operator. This means that each IndyLSTM + node sees only its own state \(h\) and \(c\), as opposed to seeing all + states in the same layer. + + We add forget_bias (default: 1) to the biases of the forget gate in order to + reduce the scale of forgetting in the beginning of the training. + + It does not allow cell clipping, a projection layer, and does not + use peep-hole connections: it is the basic baseline. + + For advanced models, please use the full `tf.nn.rnn_cell.LSTMCell` + that follows. + + TODO(gonnet): Write a paper describing this and add a reference here. + """ + + def __init__(self, + num_units, + forget_bias=1.0, + activation=None, + reuse=None, + kernel_initializer=None, + bias_initializer=None, + name=None, + dtype=None): + """Initialize the IndyLSTM cell. + + Args: + num_units: int, The number of units in the LSTM cell. + forget_bias: float, The bias added to forget gates (see above). + Must set to `0.0` manually when restoring from CudnnLSTM-trained + checkpoints. + activation: Activation function of the inner states. Default: `tanh`. + reuse: (optional) Python boolean describing whether to reuse variables + in an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. + kernel_initializer: (optional) The initializer to use for the weight + matrix applied to the inputs. + bias_initializer: (optional) The initializer to use for the bias. + name: String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require reuse=True in such + cases. + dtype: Default dtype of the layer (default of `None` means use the type + of the first input). Required when `build` is called before `call`. + """ + super(IndyLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) + + # Inputs must be 2-dimensional. + self.input_spec = base_layer.InputSpec(ndim=2) + + self._num_units = num_units + self._forget_bias = forget_bias + self._activation = activation or math_ops.tanh + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + + @property + def state_size(self): + return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units) + + @property + def output_size(self): + return self._num_units + + def build(self, inputs_shape): + if inputs_shape[1].value is None: + raise ValueError( + "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) + + input_depth = inputs_shape[1].value + # pylint: disable=protected-access + self._kernel_w = self.add_variable( + "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + shape=[input_depth, 4 * self._num_units], + initializer=self._kernel_initializer) + self._kernel_u = self.add_variable( + "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + shape=[1, 4 * self._num_units], + initializer=init_ops.random_uniform_initializer( + minval=-1, maxval=1, dtype=self.dtype)) + self._bias = self.add_variable( + rnn_cell_impl._BIAS_VARIABLE_NAME, + shape=[4 * self._num_units], + initializer=(self._bias_initializer + if self._bias_initializer is not None else + init_ops.zeros_initializer(dtype=self.dtype))) + # pylint: enable=protected-access + + self.built = True + + def call(self, inputs, state): + """Independent Long short-term memory cell (IndyLSTM). + + Args: + inputs: `2-D` tensor with shape `[batch_size, input_size]`. + state: An `LSTMStateTuple` of state tensors, each shaped + `[batch_size, num_units]`. + + Returns: + A pair containing the new hidden state, and the new state (a + `LSTMStateTuple`). + """ + sigmoid = math_ops.sigmoid + one = constant_op.constant(1, dtype=dtypes.int32) + c, h = state + + gate_inputs = math_ops.matmul(inputs, self._kernel_w) + gate_inputs += gen_array_ops.tile(h, [1, 4]) * self._kernel_u + gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + i, j, f, o = array_ops.split( + value=gate_inputs, num_or_size_splits=4, axis=one) + + forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype) + # Note that using `add` and `multiply` instead of `+` and `*` gives a + # performance improvement. So using those at the cost of readability. + add = math_ops.add + multiply = math_ops.multiply + new_c = add( + multiply(c, sigmoid(add(f, forget_bias_tensor))), + multiply(sigmoid(i), self._activation(j))) + new_h = multiply(self._activation(new_c), sigmoid(o)) + + new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) + return new_h, new_state diff --git a/tensorflow/contrib/rpc/python/kernel_tests/BUILD b/tensorflow/contrib/rpc/python/kernel_tests/BUILD index 2311c15a68c46090cec0f97bd950296506b0817e..cb0b89ae55b96361428c7845d4d6aab72543feb7 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/BUILD +++ b/tensorflow/contrib/rpc/python/kernel_tests/BUILD @@ -1,5 +1,3 @@ -# TODO(b/76425722): Port everything in here to OS (currently excluded). - package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 @@ -17,7 +15,6 @@ tf_proto_library( srcs = ["test_example.proto"], has_services = 1, cc_api_version = 2, - protodeps = ["//tensorflow/core:protos_all"], ) py_library( diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py index 27273d16b1c09eba60e124e632b353b09ea2d063..1c23c28860dac6203ea4ec8e808f63d3e9e467e2 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py +++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py @@ -51,23 +51,23 @@ class RpcOpTestBase(object): def testScalarHostPortRpc(self): with self.test_session() as sess: request_tensors = ( - test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString()) + test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString()) response_tensors = self.rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=self._address, request=request_tensors) self.assertEqual(response_tensors.shape, ()) response_values = sess.run(response_tensors) response_message = test_example_pb2.TestCase() self.assertTrue(response_message.ParseFromString(response_values)) - self.assertAllEqual([2, 3, 4], response_message.shape) + self.assertAllEqual([2, 3, 4], response_message.values) def testScalarHostPortTryRpc(self): with self.test_session() as sess: request_tensors = ( - test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString()) + test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString()) response_tensors, status_code, status_message = self.try_rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=self._address, request=request_tensors) self.assertEqual(status_code.shape, ()) @@ -77,7 +77,7 @@ class RpcOpTestBase(object): sess.run((response_tensors, status_code, status_message))) response_message = test_example_pb2.TestCase() self.assertTrue(response_message.ParseFromString(response_values)) - self.assertAllEqual([2, 3, 4], response_message.shape) + self.assertAllEqual([2, 3, 4], response_message.values) # For the base Rpc op, don't expect to get error status back. self.assertEqual(errors.OK, status_code_values) self.assertEqual(b'', status_message_values) @@ -86,7 +86,7 @@ class RpcOpTestBase(object): with self.test_session() as sess: request_tensors = [] response_tensors = self.rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=self._address, request=request_tensors) self.assertAllEqual(response_tensors.shape, [0]) @@ -95,7 +95,7 @@ class RpcOpTestBase(object): def testInvalidMethod(self): for method in [ - '/InvalidService.IncrementTestShapes', + '/InvalidService.Increment', self.get_method_name('InvalidMethodName') ]: with self.test_session() as sess: @@ -115,12 +115,12 @@ class RpcOpTestBase(object): with self.assertRaises(errors.UnavailableError): sess.run( self.rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=address, request='')) _, status_code_value, status_message_value = sess.run( self.try_rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=address, request='')) self.assertEqual(errors.UNAVAILABLE, status_code_value) @@ -182,10 +182,10 @@ class RpcOpTestBase(object): with self.test_session() as sess: request_tensors = [ test_example_pb2.TestCase( - shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20) + values=[i, i + 1, i + 2]).SerializeToString() for i in range(20) ] response_tensors = self.rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=self._address, request=request_tensors) self.assertEqual(response_tensors.shape, (20,)) @@ -194,17 +194,17 @@ class RpcOpTestBase(object): for i in range(20): response_message = test_example_pb2.TestCase() self.assertTrue(response_message.ParseFromString(response_values[i])) - self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape) + self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values) def testVecHostPortManyParallelRpcs(self): with self.test_session() as sess: request_tensors = [ test_example_pb2.TestCase( - shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20) + values=[i, i + 1, i + 2]).SerializeToString() for i in range(20) ] many_response_tensors = [ self.rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=self._address, request=request_tensors) for _ in range(10) ] @@ -216,25 +216,25 @@ class RpcOpTestBase(object): for i in range(20): response_message = test_example_pb2.TestCase() self.assertTrue(response_message.ParseFromString(response_values[i])) - self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape) + self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values) def testVecHostPortRpcUsingEncodeAndDecodeProto(self): with self.test_session() as sess: request_tensors = encode_proto_op.encode_proto( message_type='tensorflow.contrib.rpc.TestCase', - field_names=['shape'], + field_names=['values'], sizes=[[3]] * 20, values=[ [[i, i + 1, i + 2] for i in range(20)], ]) response_tensor_strings = self.rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=self._address, request=request_tensors) _, (response_shape,) = decode_proto_op.decode_proto( bytes=response_tensor_strings, message_type='tensorflow.contrib.rpc.TestCase', - field_names=['shape'], + field_names=['values'], output_types=[dtypes.int32]) response_shape_values = sess.run(response_shape) self.assertAllEqual([[i + 1, i + 2, i + 3] @@ -285,9 +285,9 @@ class RpcOpTestBase(object): addresses = flatten([[ self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@' ] for _ in range(10)]) - request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString() + request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString() response_tensors, status_code, _ = self.try_rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=addresses, request=request) response_tensors_values, status_code_values = sess.run((response_tensors, @@ -303,9 +303,9 @@ class RpcOpTestBase(object): flatten = lambda x: list(itertools.chain.from_iterable(x)) with self.test_session() as sess: methods = flatten( - [[self.get_method_name('IncrementTestShapes'), 'InvalidMethodName'] + [[self.get_method_name('Increment'), 'InvalidMethodName'] for _ in range(10)]) - request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString() + request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString() response_tensors, status_code, _ = self.try_rpc( method=methods, address=self._address, request=request) response_tensors_values, status_code_values = sess.run((response_tensors, @@ -325,10 +325,10 @@ class RpcOpTestBase(object): ] for _ in range(10)]) requests = [ test_example_pb2.TestCase( - shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20) + values=[i, i + 1, i + 2]).SerializeToString() for i in range(20) ] response_tensors, status_code, _ = self.try_rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=addresses, request=requests) response_tensors_values, status_code_values = sess.run((response_tensors, @@ -343,4 +343,4 @@ class RpcOpTestBase(object): response_message = test_example_pb2.TestCase() self.assertTrue( response_message.ParseFromString(response_tensors_values[i])) - self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape) + self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values) diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py index 7cbd636cb16e3befc9ae27cb231696634e859a22..265254aa51c64ff5a76ad3a9f7e081c56dd639e7 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py +++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py @@ -30,8 +30,8 @@ from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2_grpc class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer): """Test servicer for RpcOp tests.""" - def IncrementTestShapes(self, request, context): - """Increment the entries in the shape attribute of request. + def Increment(self, request, context): + """Increment the entries in the `values` attribute of request. Args: request: input TestCase. @@ -40,8 +40,8 @@ class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer): Returns: output TestCase. """ - for i in range(len(request.shape)): - request.shape[i] += 1 + for i in range(len(request.values)): + request.values[i] += 1 return request def AlwaysFailWithInvalidArgument(self, request, context): diff --git a/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto b/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto index 96f4550f62bc17e713abe1f3843ec0964f57b046..8141466349afcebcd104153a9f28c8f382458098 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto +++ b/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto @@ -1,29 +1,17 @@ // Test description and protos to work with it. -// -// Many of the protos in this file are for unit tests that haven't been written yet. syntax = "proto2"; -import "tensorflow/core/framework/types.proto"; - package tensorflow.contrib.rpc; -// A TestCase holds a proto and a bunch of assertions -// about how it should decode. +// A TestCase holds a sequence of values. message TestCase { - // A batch of primitives to be serialized and decoded. - repeated RepeatedPrimitiveValue primitive = 1; - // The shape of the batch. - repeated int32 shape = 2; - // Expected sizes for each field. - repeated int32 sizes = 3; - // Expected values for each field. - repeated FieldSpec field = 4; + repeated int32 values = 1; }; service TestCaseService { - // Copy input, and increment each entry in 'shape' by 1. - rpc IncrementTestShapes(TestCase) returns (TestCase) { + // Copy input, and increment each entry in 'values' by 1. + rpc Increment(TestCase) returns (TestCase) { } // Sleep forever. @@ -42,130 +30,3 @@ service TestCaseService { rpc SometimesFailWithInvalidArgument(TestCase) returns (TestCase) { } }; - -// FieldSpec describes the expected output for a single field. -message FieldSpec { - optional string name = 1; - optional tensorflow.DataType dtype = 2; - optional RepeatedPrimitiveValue expected = 3; -}; - -message TestValue { - optional PrimitiveValue primitive_value = 1; - optional EnumValue enum_value = 2; - optional MessageValue message_value = 3; - optional RepeatedMessageValue repeated_message_value = 4; - optional RepeatedPrimitiveValue repeated_primitive_value = 6; -} - -message PrimitiveValue { - optional double double_value = 1; - optional float float_value = 2; - optional int64 int64_value = 3; - optional uint64 uint64_value = 4; - optional int32 int32_value = 5; - optional fixed64 fixed64_value = 6; - optional fixed32 fixed32_value = 7; - optional bool bool_value = 8; - optional string string_value = 9; - optional bytes bytes_value = 12; - optional uint32 uint32_value = 13; - optional sfixed32 sfixed32_value = 15; - optional sfixed64 sfixed64_value = 16; - optional sint32 sint32_value = 17; - optional sint64 sint64_value = 18; -} - -// NOTE: This definition must be kept in sync with PackedPrimitiveValue. -message RepeatedPrimitiveValue { - repeated double double_value = 1; - repeated float float_value = 2; - repeated int64 int64_value = 3; - repeated uint64 uint64_value = 4; - repeated int32 int32_value = 5; - repeated fixed64 fixed64_value = 6; - repeated fixed32 fixed32_value = 7; - repeated bool bool_value = 8; - repeated string string_value = 9; - repeated bytes bytes_value = 12; - repeated uint32 uint32_value = 13; - repeated sfixed32 sfixed32_value = 15; - repeated sfixed64 sfixed64_value = 16; - repeated sint32 sint32_value = 17; - repeated sint64 sint64_value = 18; - repeated PrimitiveValue message_value = 19; -} - -// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue -// in the text format, but the binary serializion is different. -// We test the packed representations by loading the same test cases -// using this definition instead of RepeatedPrimitiveValue. -// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue -// in every way except the packed=true declaration. -message PackedPrimitiveValue { - repeated double double_value = 1 [packed = true]; - repeated float float_value = 2 [packed = true]; - repeated int64 int64_value = 3 [packed = true]; - repeated uint64 uint64_value = 4 [packed = true]; - repeated int32 int32_value = 5 [packed = true]; - repeated fixed64 fixed64_value = 6 [packed = true]; - repeated fixed32 fixed32_value = 7 [packed = true]; - repeated bool bool_value = 8 [packed = true]; - repeated string string_value = 9; - repeated bytes bytes_value = 12; - repeated uint32 uint32_value = 13 [packed = true]; - repeated sfixed32 sfixed32_value = 15 [packed = true]; - repeated sfixed64 sfixed64_value = 16 [packed = true]; - repeated sint32 sint32_value = 17 [packed = true]; - repeated sint64 sint64_value = 18 [packed = true]; - repeated PrimitiveValue message_value = 19; -} - -message EnumValue { - enum Color { - RED = 0; - ORANGE = 1; - YELLOW = 2; - GREEN = 3; - BLUE = 4; - INDIGO = 5; - VIOLET = 6; - }; - optional Color enum_value = 14; - repeated Color repeated_enum_value = 15; -} - - -message InnerMessageValue { - optional float float_value = 2; - repeated bytes bytes_values = 8; -} - -message MiddleMessageValue { - repeated int32 int32_values = 5; - optional InnerMessageValue message_value = 11; - optional uint32 uint32_value = 13; -} - -message MessageValue { - optional double double_value = 1; - optional MiddleMessageValue message_value = 11; -} - -message RepeatedMessageValue { - message NestedMessageValue { - optional float float_value = 2; - repeated bytes bytes_values = 8; - } - - repeated NestedMessageValue message_values = 11; -} - -// Message containing fields with field numbers higher than any field above. An -// instance of this message is prepended to each binary message in the test to -// exercise the code path that handles fields encoded out of order of field -// number. -message ExtraFields { - optional string string_value = 1776; - optional bool bool_value = 1777; -} diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index 26fd4e2023806765ea4088f4c13a780ca7338bff..e7eb4ac5635037e30ecf88a037343c8967986447 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -93,3 +93,31 @@ py_test( "//tensorflow/python/saved_model:utils", ], ) + +py_library( + name = "keras_saved_model", + srcs = ["python/saved_model/keras_saved_model.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/python:lib", + "//tensorflow/python:util", + "//tensorflow/python/keras:engine", + "//tensorflow/python/saved_model:constants", + ], +) + +py_test( + name = "keras_saved_model_test", + size = "small", + srcs = ["python/saved_model/keras_saved_model_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":saved_model_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:training", + "//tensorflow/python/keras", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/saved_model/__init__.py b/tensorflow/contrib/saved_model/__init__.py index b4f27a055dad7a5b95112d561cc878609a558f8d..95e1a8967b2223fd3feb112af3cbe0c5991d2d03 100644 --- a/tensorflow/contrib/saved_model/__init__.py +++ b/tensorflow/contrib/saved_model/__init__.py @@ -24,11 +24,12 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,wildcard-import,line-too-long +from tensorflow.contrib.saved_model.python.saved_model.keras_saved_model import * from tensorflow.contrib.saved_model.python.saved_model.signature_def_utils import * # pylint: enable=unused-import,widcard-import,line-too-long from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ["get_signature_def_by_key"] +_allowed_symbols = ["get_signature_def_by_key", "load_model", "save_model"] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/saved_model/python/saved_model/__init__.py b/tensorflow/contrib/saved_model/python/saved_model/__init__.py index 7b91622b6127413ce122c4166a18255b65365d32..e3b76bb6f34846f02ccdf623d48ddd9c5909fdce 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/__init__.py +++ b/tensorflow/contrib/saved_model/python/saved_model/__init__.py @@ -24,5 +24,6 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import +from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e2a969f053d3f1ded8aecd6411a62a198df48bb0 --- /dev/null +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py @@ -0,0 +1,108 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +"""Utility functions to save/load keras Model to/from SavedModel.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.keras.models import model_from_json +from tensorflow.python.lib.io import file_io +from tensorflow.python.saved_model import constants +from tensorflow.python.util import compat + + +def save_model(model, saved_model_path): + """Save a `tf.keras.Model` into Tensorflow SavedModel format. + + `save_model` generates such files/folders under the `saved_model_path` folder: + 1) an asset folder containing the json string of the model's + configuration(topology). + 2) a checkpoint containing the model weights. + + Note that subclassed models can not be saved via this function, unless you + provide an implementation for get_config() and from_config(). + Also note that `tf.keras.optimizers.Optimizer` instances can not currently be + saved to checkpoints. Use optimizers from `tf.train`. + + Args: + model: A `tf.keras.Model` to be saved. + saved_model_path: a string specifying the path to the SavedModel directory. + + Raises: + NotImplementedError: If the passed in model is a subclassed model. + """ + if not model._is_graph_network: + raise NotImplementedError + + # save model configuration as a json string under assets folder. + model_json = model.to_json() + assets_destination_dir = os.path.join( + compat.as_bytes(saved_model_path), + compat.as_bytes(constants.ASSETS_DIRECTORY)) + + if not file_io.file_exists(assets_destination_dir): + file_io.recursive_create_dir(assets_destination_dir) + + model_json_filepath = os.path.join( + compat.as_bytes(assets_destination_dir), + compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON)) + file_io.write_string_to_file(model_json_filepath, model_json) + + # save model weights in checkpoint format. + checkpoint_destination_dir = os.path.join( + compat.as_bytes(saved_model_path), + compat.as_bytes(constants.VARIABLES_DIRECTORY)) + + if not file_io.file_exists(checkpoint_destination_dir): + file_io.recursive_create_dir(checkpoint_destination_dir) + + checkpoint_prefix = os.path.join( + compat.as_text(checkpoint_destination_dir), + compat.as_text(constants.VARIABLES_FILENAME)) + model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True) + + +def load_model(saved_model_path): + """Load a keras.Model from SavedModel. + + load_model reinstantiates model state by: + 1) loading model topology from json (this will eventually come + from metagraph). + 2) loading model weights from checkpoint. + + Args: + saved_model_path: a string specifying the path to an existing SavedModel. + + Returns: + a keras.Model instance. + """ + # restore model topology from json string + model_json_filepath = os.path.join( + compat.as_bytes(saved_model_path), + compat.as_bytes(constants.ASSETS_DIRECTORY), + compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON)) + model_json = file_io.read_file_to_string(model_json_filepath) + model = model_from_json(model_json) + + # restore model weights + checkpoint_prefix = os.path.join( + compat.as_text(saved_model_path), + compat.as_text(constants.VARIABLES_DIRECTORY), + compat.as_text(constants.VARIABLES_FILENAME)) + model.load_weights(checkpoint_prefix) + return model diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py new file mode 100644 index 0000000000000000000000000000000000000000..107ae1b07b777570e4124337595ceecd6e33cd0b --- /dev/null +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py @@ -0,0 +1,201 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +"""Tests for saving/loading function for keras Model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import numpy as np + +from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model +from tensorflow.python import keras +from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import training +from tensorflow.python.platform import test +from tensorflow.python.training import training as training_module + + +class TestModelSavingandLoading(test.TestCase): + + def test_saving_sequential_model(self): + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + + ref_y = model.predict(x) + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + + temp_saved_model = os.path.join(temp_dir, 'saved_model') + keras_saved_model.save_model(model, temp_saved_model) + + loaded_model = keras_saved_model.load_model(temp_saved_model) + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + @test_util.run_in_graph_and_eager_modes + def test_saving_sequential_model_without_compile(self): + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + + x = np.random.random((1, 3)) + ref_y = model.predict(x) + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + + temp_saved_model = os.path.join(temp_dir, 'saved_model') + keras_saved_model.save_model(model, temp_saved_model) + loaded_model = keras_saved_model.load_model(temp_saved_model) + + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + def test_saving_functional_model(self): + with self.test_session(): + inputs = keras.layers.Input(shape=(3,)) + x = keras.layers.Dense(2)(inputs) + output = keras.layers.Dense(3)(x) + + model = keras.models.Model(inputs, output) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + ref_y = model.predict(x) + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + + temp_saved_model = os.path.join(temp_dir, 'saved_model') + keras_saved_model.save_model(model, temp_saved_model) + loaded_model = keras_saved_model.load_model(temp_saved_model) + + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + @test_util.run_in_graph_and_eager_modes + def test_saving_functional_model_without_compile(self): + with self.test_session(): + inputs = keras.layers.Input(shape=(3,)) + x = keras.layers.Dense(2)(inputs) + output = keras.layers.Dense(3)(x) + + model = keras.models.Model(inputs, output) + + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + + ref_y = model.predict(x) + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + + temp_saved_model = os.path.join(temp_dir, 'saved_model') + keras_saved_model.save_model(model, temp_saved_model) + loaded_model = keras_saved_model.load_model(temp_saved_model) + + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + @test_util.run_in_graph_and_eager_modes + def test_saving_with_tf_optimizer(self): + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.Dense(3)) + model.compile( + loss='mse', + optimizer=training_module.RMSPropOptimizer(0.1), + metrics=['acc']) + + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + ref_y = model.predict(x) + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + + temp_saved_model = os.path.join(temp_dir, 'saved_model') + keras_saved_model.save_model(model, temp_saved_model) + loaded_model = keras_saved_model.load_model(temp_saved_model) + loaded_model.compile( + loss='mse', + optimizer=training_module.RMSPropOptimizer(0.1), + metrics=['acc']) + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + # test that new updates are the same with both models + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + + ref_loss = model.train_on_batch(x, y) + loss = loaded_model.train_on_batch(x, y) + self.assertAllClose(ref_loss, loss, atol=1e-05) + + ref_y = model.predict(x) + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + # test saving/loading again + keras_saved_model.save_model(loaded_model, temp_saved_model) + loaded_model = keras_saved_model.load_model(temp_saved_model) + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + def test_saving_subclassed_model_raise_error(self): + # For now, saving subclassed model should raise an error. It should be + # avoided later with loading from SavedModel.pb. + + class SubclassedModel(training.Model): + + def __init__(self): + super(SubclassedModel, self).__init__() + self.layer1 = keras.layers.Dense(3) + self.layer2 = keras.layers.Dense(1) + + def call(self, inp): + return self.layer2(self.layer1(inp)) + + model = SubclassedModel() + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + temp_saved_model = os.path.join(temp_dir, 'saved_model') + with self.assertRaises(NotImplementedError): + keras_saved_model.save_model(model, temp_saved_model) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD index 1a1591d798f6f904e23987d9d7a60193c124c20e..18b56cd21942e28cb0dc3210df0bb04d55c1e16f 100644 --- a/tensorflow/contrib/seq2seq/BUILD +++ b/tensorflow/contrib/seq2seq/BUILD @@ -177,7 +177,7 @@ cuda_py_test( cuda_py_test( name = "beam_search_decoder_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/beam_search_decoder_test.py"], additional_deps = [ ":seq2seq_py", diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index 178328619f087789df040489cd150ba018cc8d14..4073b390fc72cf0f84edd0d2ab56df5ffeb3e2e5 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -132,6 +132,48 @@ class TestGatherTree(test.TestCase): def test_gather_tree_from_array_2d(self): self._test_gather_tree_from_array(depth_ndims=2) + def test_gather_tree_from_array_complex_trajectory(self): + # Max. time = 7, batch = 1, beam = 5. + array = np.expand_dims(np.array( + [[[25, 12, 114, 89, 97]], + [[9, 91, 64, 11, 162]], + [[34, 34, 34, 34, 34]], + [[2, 4, 2, 2, 4]], + [[2, 3, 6, 2, 2]], + [[2, 2, 2, 3, 2]], + [[2, 2, 2, 2, 2]]]), -1) + parent_ids = np.array( + [[[0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0]], + [[0, 1, 2, 3, 4]], + [[0, 0, 1, 2, 1]], + [[0, 1, 1, 2, 3]], + [[0, 1, 3, 1, 2]], + [[0, 1, 2, 3, 4]]]) + expected_array = np.expand_dims(np.array( + [[[25, 25, 25, 25, 25]], + [[9, 9, 91, 9, 9]], + [[34, 34, 34, 34, 34]], + [[2, 4, 2, 4, 4]], + [[2, 3, 6, 3, 6]], + [[2, 2, 2, 3, 2]], + [[2, 2, 2, 2, 2]]]), -1) + sequence_length = [[4, 6, 4, 7, 6]] + + array = ops.convert_to_tensor( + array, dtype=dtypes.float32) + parent_ids = ops.convert_to_tensor( + parent_ids, dtype=dtypes.int32) + expected_array = ops.convert_to_tensor( + expected_array, dtype=dtypes.float32) + + sorted_array = beam_search_decoder.gather_tree_from_array( + array, parent_ids, sequence_length) + + with self.test_session() as sess: + sorted_array, expected_array = sess.run([sorted_array, expected_array]) + self.assertAllEqual(expected_array, sorted_array) + class TestArrayShapeChecks(test.TestCase): diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 1c9d179e3c55ad07fcf709f66028c91c20e8eea0..0ba32cd3bf8a374f5f55bdc6b2325b03443cd545 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -382,8 +382,8 @@ class LuongAttention(_BaseAttentionMechanism): for values past the respective sequence lengths. scale: Python boolean. Whether to scale the energy term. probability_fn: (optional) A `callable`. Converts the score to - probabilities. The default is @{tf.nn.softmax}. Other options include - @{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}. + probabilities. The default is `tf.nn.softmax`. Other options include + `tf.contrib.seq2seq.hardmax` and `tf.contrib.sparsemax.sparsemax`. Its signature should be: `probabilities = probability_fn(score)`. score_mask_value: (optional) The mask value for score before passing into `probability_fn`. The default is -inf. Only used if @@ -529,8 +529,8 @@ class BahdanauAttention(_BaseAttentionMechanism): for values past the respective sequence lengths. normalize: Python boolean. Whether to normalize the energy term. probability_fn: (optional) A `callable`. Converts the score to - probabilities. The default is @{tf.nn.softmax}. Other options include - @{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}. + probabilities. The default is `tf.nn.softmax`. Other options include + `tf.contrib.seq2seq.hardmax` and `tf.contrib.sparsemax.sparsemax`. Its signature should be: `probabilities = probability_fn(score)`. score_mask_value: (optional): The mask value for score before passing into `probability_fn`. The default is -inf. Only used if @@ -1091,7 +1091,7 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): `AttentionWrapper`, then you must ensure that: - The encoder output has been tiled to `beam_width` via - @{tf.contrib.seq2seq.tile_batch} (NOT `tf.tile`). + `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`). - The `batch_size` argument passed to the `zero_state` method of this wrapper is equal to `true_batch_size * beam_width`. - The initial state created with `zero_state` above contains a diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 184144f64a56358206014a0f75473b4a9b16617a..74741a7bd6306181c248af50e9784f45dfc41c55 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -145,24 +145,20 @@ def gather_tree_from_array(t, parent_ids, sequence_length): array_ops.expand_dims(math_ops.range(beam_width), 0), 0) beam_ids = array_ops.tile(beam_ids, [max_time, batch_size, 1]) - mask = array_ops.sequence_mask( - sequence_length, maxlen=max_time, dtype=dtypes.int32) - mask = array_ops.transpose(mask, perm=[2, 0, 1]) - - # Use beam_width + 1 to mark the end of beam. - masked_beam_ids = (beam_ids * mask) + (1 - mask) * (beam_width + 1) - max_sequence_lengths = math_ops.to_int32( math_ops.reduce_max(sequence_length, axis=1)) sorted_beam_ids = beam_search_ops.gather_tree( - step_ids=masked_beam_ids, + step_ids=beam_ids, parent_ids=parent_ids, max_sequence_lengths=max_sequence_lengths, end_token=beam_width + 1) # For out of range steps, simply copy the same beam. + in_bound_steps = array_ops.transpose( + array_ops.sequence_mask(sequence_length, maxlen=max_time), + perm=[2, 0, 1]) sorted_beam_ids = array_ops.where( - math_ops.cast(mask, dtypes.bool), x=sorted_beam_ids, y=beam_ids) + in_bound_steps, x=sorted_beam_ids, y=beam_ids) # Generate indices for gather_nd. time_ind = array_ops.tile(array_ops.reshape( @@ -238,7 +234,7 @@ class BeamSearchDecoder(decoder.Decoder): `AttentionWrapper`, then you must ensure that: - The encoder output has been tiled to `beam_width` via - @{tf.contrib.seq2seq.tile_batch} (NOT `tf.tile`). + `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`). - The `batch_size` argument passed to the `zero_state` method of this wrapper is equal to `true_batch_size * beam_width`. - The initial state created with `zero_state` above contains a @@ -250,7 +246,7 @@ class BeamSearchDecoder(decoder.Decoder): ``` tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( encoder_outputs, multiplier=beam_width) - tiled_encoder_final_state = tf.conrib.seq2seq.tile_batch( + tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( encoder_final_state, multiplier=beam_width) tiled_sequence_length = tf.contrib.seq2seq.tile_batch( sequence_length, multiplier=beam_width) diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py index e69725ff8ab1ba4de880c914a6f5fdad5e54566d..f58268eff525a4b592c79acb32207e1a3f62bdc7 100644 --- a/tensorflow/contrib/seq2seq/python/ops/decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py @@ -21,6 +21,7 @@ from __future__ import print_function import abc import six +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -182,19 +183,20 @@ def dynamic_decode(decoder, raise TypeError("Expected decoder to be type Decoder, but saw: %s" % type(decoder)) - def _is_xla_tensor(tensor): - try: - op = tensor.op - except AttributeError: - return False - if control_flow_util.IsInXLAContext(op): - return True - return False - with variable_scope.variable_scope(scope, "decoder") as varscope: - # Properly cache variable values inside the while_loop - if varscope.caching_device is None: - varscope.set_caching_device(lambda op: op.device) + # Determine context types. + ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access + is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None + in_while_loop = ( + control_flow_util.GetContainingWhileContext(ctxt) is not None) + # Properly cache variable values inside the while_loop. + # Don't set a caching device when running in a loop, since it is possible + # that train steps could be wrapped in a tf.while_loop. In that scenario + # caching prevents forward computations in loop iterations from re-reading + # the updated weights. + if not context.executing_eagerly() and not in_while_loop: + if varscope.caching_device is None: + varscope.set_caching_device(lambda op: op.device) if maximum_iterations is not None: maximum_iterations = ops.convert_to_tensor( @@ -208,9 +210,6 @@ def dynamic_decode(decoder, decoder.output_dtype, decoder.batch_size) - is_xla = False - if any([_is_xla_tensor(i) for i in nest.flatten(initial_inputs)]): - is_xla = True if is_xla and maximum_iterations is None: raise ValueError("maximum_iterations is required for XLA compilation.") if maximum_iterations is not None: diff --git a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py index 345eb6cfaa67fd4cda6e7e3f01a1243bbf3c9fa1..f4348e80eac54933d67cdf7bd281d6a9c6c10381 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py @@ -53,7 +53,8 @@ def spectrogram_to_mel_matrix(num_mel_bins=20, num_spectrogram_bins=129, audio_sample_rate=8000, lower_edge_hertz=125.0, - upper_edge_hertz=3800.0): + upper_edge_hertz=3800.0, + unused_dtype=None): """Return a matrix that can post-multiply spectrogram rows to make mel. Copied from @@ -132,9 +133,9 @@ class LinearToMelTest(test.TestCase): # lower_edge_hertz, upper_edge_hertz) to test. configs = [ # Defaults. - (20, 129, 8000.0, 125.0, 3800.0), + (20, 129, 8000.0, 125.0, 3800.0, dtypes.float64), # Settings used by Tacotron (https://arxiv.org/abs/1703.10135). - (80, 1025, 24000.0, 80.0, 12000.0) + (80, 1025, 24000.0, 80.0, 12000.0, dtypes.float64) ] with self.test_session(use_gpu=True): for config in configs: @@ -143,7 +144,8 @@ class LinearToMelTest(test.TestCase): self.assertAllClose(mel_matrix_np, mel_matrix.eval(), atol=3e-6) def test_dtypes(self): - for dtype in (dtypes.float16, dtypes.float32, dtypes.float64): + # LinSpace is not supported for tf.float16. + for dtype in (dtypes.bfloat16, dtypes.float32, dtypes.float64): self.assertEqual(dtype, mel_ops.linear_to_mel_weight_matrix(dtype=dtype).dtype) @@ -167,7 +169,8 @@ class LinearToMelTest(test.TestCase): def test_constant_folding(self): """Mel functions should be constant foldable.""" - for dtype in (dtypes.float16, dtypes.float32, dtypes.float64): + # TODO(rjryan): tf.bloat16 cannot be constant folded by Grappler. + for dtype in (dtypes.float32, dtypes.float64): g = ops.Graph() with g.as_default(): mel_matrix = mel_ops.linear_to_mel_weight_matrix(dtype=dtype) diff --git a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py index 03d6da7765ba5249a9fb22f56a469cf07c310479..f10d78259a3be3a3a6f7f78c196ab107f18a53aa 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py @@ -147,7 +147,7 @@ class SpectralOpsTest(test.TestCase): inverse_stft = spectral_ops.inverse_stft(stft, frame_length=8, fft_length=16, frame_step=8) expected_length = (stft.shape[0] - 1) * 8 + 8 - self.assertAllEqual([None], inverse_stft.shape.as_list()) + self.assertAllEqual([256], inverse_stft.shape.as_list()) self.assertAllEqual([expected_length], inverse_stft.eval().shape) def test_stft_and_inverse_stft(self): diff --git a/tensorflow/contrib/signal/python/kernel_tests/test_util.py b/tensorflow/contrib/signal/python/kernel_tests/test_util.py index 9a3603b6a97ef7c3a4b940b83281ebceda93c9db..b4422a49887378187a2be46275d4dabf1fbd40a1 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/test_util.py +++ b/tensorflow/contrib/signal/python/kernel_tests/test_util.py @@ -27,18 +27,19 @@ def grappler_optimize(graph, fetches=None, rewriter_config=None): """Tries to optimize the provided graph using grappler. Args: - graph: A @{tf.Graph} instance containing the graph to optimize. + graph: A `tf.Graph` instance containing the graph to optimize. fetches: An optional list of `Tensor`s to fetch (i.e. not optimize away). Grappler uses the 'train_op' collection to look for fetches, so if not provided this collection should be non-empty. - rewriter_config: An optional @{tf.RewriterConfig} to use when rewriting the + rewriter_config: An optional `tf.RewriterConfig` to use when rewriting the graph. Returns: - A @{tf.GraphDef} containing the rewritten graph. + A `tf.GraphDef` containing the rewritten graph. """ if rewriter_config is None: rewriter_config = rewriter_config_pb2.RewriterConfig() + rewriter_config.min_graph_nodes = -1 if fetches is not None: for fetch in fetches: graph.add_to_collection('train_op', fetch) diff --git a/tensorflow/contrib/signal/python/ops/mel_ops.py b/tensorflow/contrib/signal/python/ops/mel_ops.py index 1e84006116daa3f28c760037cb9eeafd53eaafb8..ecc2fedb9f82151511bab3f3c0496bc4e290903f 100644 --- a/tensorflow/contrib/signal/python/ops/mel_ops.py +++ b/tensorflow/contrib/signal/python/ops/mel_ops.py @@ -108,7 +108,7 @@ def linear_to_mel_weight_matrix(num_mel_bins=20, # `M` has shape [frames, num_mel_bins] M = tf.matmul(S, A) - The matrix can be used with @{tf.tensordot} to convert an arbitrary rank + The matrix can be used with `tf.tensordot` to convert an arbitrary rank `Tensor` of linear-scale spectral bins into the mel scale. # S has shape [..., num_spectrogram_bins]. @@ -151,22 +151,21 @@ def linear_to_mel_weight_matrix(num_mel_bins=20, _validate_arguments(num_mel_bins, sample_rate, lower_edge_hertz, upper_edge_hertz, dtype) - # To preserve accuracy, we compute the matrix at float64 precision and then - # cast to `dtype` at the end. This function can be constant folded by graph - # optimization since there are no Tensor inputs. + # This function can be constant folded by graph optimization since there are + # no Tensor inputs. sample_rate = ops.convert_to_tensor( - sample_rate, dtypes.float64, name='sample_rate') + sample_rate, dtype, name='sample_rate') lower_edge_hertz = ops.convert_to_tensor( - lower_edge_hertz, dtypes.float64, name='lower_edge_hertz') + lower_edge_hertz, dtype, name='lower_edge_hertz') upper_edge_hertz = ops.convert_to_tensor( - upper_edge_hertz, dtypes.float64, name='upper_edge_hertz') - zero_float64 = ops.convert_to_tensor(0.0, dtypes.float64) + upper_edge_hertz, dtype, name='upper_edge_hertz') + zero = ops.convert_to_tensor(0.0, dtype) # HTK excludes the spectrogram DC bin. bands_to_zero = 1 nyquist_hertz = sample_rate / 2.0 linear_frequencies = math_ops.linspace( - zero_float64, nyquist_hertz, num_spectrogram_bins)[bands_to_zero:] + zero, nyquist_hertz, num_spectrogram_bins)[bands_to_zero:] spectrogram_bins_mel = array_ops.expand_dims( _hertz_to_mel(linear_frequencies), 1) @@ -193,11 +192,8 @@ def linear_to_mel_weight_matrix(num_mel_bins=20, # Intersect the line segments with each other and zero. mel_weights_matrix = math_ops.maximum( - zero_float64, math_ops.minimum(lower_slopes, upper_slopes)) + zero, math_ops.minimum(lower_slopes, upper_slopes)) # Re-add the zeroed lower bins we sliced out above. - mel_weights_matrix = array_ops.pad( - mel_weights_matrix, [[bands_to_zero, 0], [0, 0]]) - - # Cast to the desired type. - return math_ops.cast(mel_weights_matrix, dtype, name=name) + return array_ops.pad( + mel_weights_matrix, [[bands_to_zero, 0], [0, 0]], name=name) diff --git a/tensorflow/contrib/signal/python/ops/reconstruction_ops.py b/tensorflow/contrib/signal/python/ops/reconstruction_ops.py index 653c030a04c2bbc7e3ee49b9c85a781fb49de8d0..4db8dc2ca090534f2cda66bd55c30dfa389b860a 100644 --- a/tensorflow/contrib/signal/python/ops/reconstruction_ops.py +++ b/tensorflow/contrib/signal/python/ops/reconstruction_ops.py @@ -90,22 +90,28 @@ def overlap_and_add(signal, frame_step, name=None): raise ValueError("frame_step must be an integer. Got %s" % frame_step.dtype) - # If frame_length and frame_step are known at graph construction time, check - # frame_step is less than or equal to frame_length. - frame_step_static = tensor_util.constant_value(frame_step) - if (frame_step_static is not None and signal.shape.ndims is not None and - signal.shape[-1].value is not None and - frame_step_static > signal.shape[-1].value): - raise ValueError( - "frame_step (%d) must be less than or equal to frame_length (%d)" % ( - frame_step_static, signal.shape[-1].value)) - signal_shape = array_ops.shape(signal) # All dimensions that are not part of the overlap-and-add. Can be empty for # rank 2 inputs. outer_dimensions = signal_shape[:-2] + # If frame_length and frame_step are known at graph construction time, check + # frame_step is less than or equal to frame_length. + frame_step_static = tensor_util.constant_value(frame_step) + if (frame_step_static is not None and signal.shape.ndims is not None and + signal.shape[-1].value is not None): + if frame_step_static > signal.shape[-1].value: + raise ValueError( + "frame_step (%d) must be less than or equal to " + "frame_length (%d)" % ( + frame_step_static, signal.shape[-1].value)) + # If frame_length is equal to frame_step, there's no overlap so just + # reshape the tensor. + if frame_step_static == signal.shape[-1].value: + return array_ops.reshape(signal, array_ops.concat( + [outer_dimensions, [-1]], 0)) + signal_rank = array_ops.rank(signal) frames = signal_shape[-2] frame_length = signal_shape[-1] diff --git a/tensorflow/contrib/slim/python/slim/evaluation.py b/tensorflow/contrib/slim/python/slim/evaluation.py index 5cfd5ee82e2a0fce33311a8783d2d4ceb031544d..0feb3925eb8ec4eca7c7fd527510f45ceb83091b 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation.py +++ b/tensorflow/contrib/slim/python/slim/evaluation.py @@ -22,7 +22,8 @@ modules using a variety of metrics and summarizing the results. ********************** In the simplest use case, we use a model to create the predictions, then specify -the metrics and finally call the `evaluation` method: +the metrics and choose one model checkpoint, finally call the`evaluation_once` +method: # Create model and obtain the predictions: images, labels = LoadData(...) @@ -34,20 +35,24 @@ the metrics and finally call the `evaluation` method: "mse": slim.metrics.mean_squared_error(predictions, labels), }) + checkpoint_path = '/tmp/my_model_dir/my_checkpoint' + log_dir = '/tmp/my_model_eval/' + initial_op = tf.group( tf.global_variables_initializer(), tf.local_variables_initializer()) - with tf.Session() as sess: - metric_values = slim.evaluation( - sess, - num_evals=1, - initial_op=initial_op, - eval_op=names_to_updates.values(), - final_op=name_to_values.values()) + metric_values = slim.evaluate_once( + master='', + checkpoint_path=checkpoint_path, + log_dir=log_dir, + num_evals=1, + initial_op=initial_op, + eval_op=names_to_updates.values(), + final_op=name_to_values.values()) - for metric, value in zip(names_to_values.keys(), metric_values): - logging.info('Metric %s has value: %f', metric, value) + for metric, value in zip(names_to_values.keys(), metric_values): + logging.info('Metric %s has value: %f', metric, value) ************************************************ * Evaluating a Checkpointed Model with Metrics * diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py index 3d0308aaf3da3b5b16fd22a2905db36917e8c97b..2c97834523424d0fab56330b4d9355a75427e0ef 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation_test.py +++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py @@ -33,7 +33,6 @@ from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.wrappers import hooks from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics @@ -242,7 +241,7 @@ class SingleEvaluationTest(test.TestCase): checkpoint_path = os.path.join(self.get_temp_dir(), 'this_file_doesnt_exist') log_dir = os.path.join(self.get_temp_dir(), 'error_raised') - with self.assertRaises(errors.NotFoundError): + with self.assertRaises(ValueError): evaluation.evaluate_once('', checkpoint_path, log_dir) def _prepareCheckpoint(self, checkpoint_path): diff --git a/tensorflow/contrib/stat_summarizer/BUILD b/tensorflow/contrib/stat_summarizer/BUILD index 0b8fc0cdc66ae41807cce92776ada263675b1f94..412a2c81a140fbd44d3d01efcc90b1fc419068f1 100644 --- a/tensorflow/contrib/stat_summarizer/BUILD +++ b/tensorflow/contrib/stat_summarizer/BUILD @@ -31,8 +31,5 @@ tf_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:variables", ], - tags = [ - "no_windows", - "notap", # TODO(b/80546574): test is flaky - ], + tags = ["notap"], # TODO(b/80546574): test is flaky ) diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py index d22b80ac88a9ced541a952fcbb58c50366464075..42898e797cc351e3de290cc65fc825f1406c739d 100644 --- a/tensorflow/contrib/summary/summary.py +++ b/tensorflow/contrib/summary/summary.py @@ -17,7 +17,7 @@ The operations in this package are safe to use with eager execution turned on or off. It has a more flexible API that allows summaries to be written directly from ops to places other than event log files, rather than propagating protos -from @{tf.summary.merge_all} to @{tf.summary.FileWriter}. +from `tf.summary.merge_all` to `tf.summary.FileWriter`. To use with eager execution enabled, write your code as follows: diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index f1ef218e74bbd225071324a8269fdfeb5de0e038..4d1807130c57039976dfa57c27bb0d4807e75212 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -20,6 +20,8 @@ import os import tempfile import time +import sqlite3 + import numpy as np import six @@ -81,6 +83,19 @@ class EagerFileTest(test_util.TensorFlowTestCase): # test here that we're calling them correctly. self.assertTrue(gfile.Exists(logdir)) + @test_util.assert_no_new_pyobjects_executing_eagerly + def testEagerMemory(self): + training_util.get_or_create_global_step() + logdir = self.get_temp_dir() + with summary_ops.create_file_writer( + logdir, max_queue=0, + name='t0').as_default(), summary_ops.always_record_summaries(): + summary_ops.generic('tensor', 1, '') + summary_ops.scalar('scalar', 2.0) + summary_ops.histogram('histogram', [1.0]) + summary_ops.image('image', [[[[1.0]]]]) + summary_ops.audio('audio', [[1.0]], 1.0, 1) + def testDefunSummarys(self): training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() @@ -262,6 +277,22 @@ class EagerFileTest(test_util.TensorFlowTestCase): class EagerDbTest(summary_test_util.SummaryDbTest): + def testDbURIOpen(self): + tmpdb_path = os.path.join(self.get_temp_dir(), 'tmpDbURITest.sqlite') + tmpdb_uri = six.moves.urllib_parse.urljoin("file:", tmpdb_path) + tmpdb_writer = summary_ops.create_db_writer( + tmpdb_uri, + "experimentA", + "run1", + "user1") + with summary_ops.always_record_summaries(): + with tmpdb_writer.as_default(): + summary_ops.scalar('t1', 2.0) + tmpdb = sqlite3.connect(tmpdb_path) + num = get_one(tmpdb, 'SELECT count(*) FROM Tags WHERE tag_name = "t1"') + self.assertEqual(num, 1) + tmpdb.close() + def testIntegerSummaries(self): step = training_util.create_global_step() writer = self.create_db_writer() diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index 136856c0156c41046f9af61cdd6e3d5f8213309e..164f3e58e6c0b2486d270c457500c8dca0c7e7eb 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -223,7 +223,6 @@ tf_kernel_library( ":model_ops_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", ], alwayslink = 1, ) @@ -319,7 +318,6 @@ tf_kernel_library( ":stats_ops_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", ], alwayslink = 1, ) diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index 35e8c92aba325d9115c7ee566363a1625e6e76fc..8fa0b3ada9488d33f444858d309f8bd7885af75c 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -22,10 +22,12 @@ from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib - from tensorflow.contrib.tensor_forest.client import eval_metrics from tensorflow.contrib.tensor_forest.python import tensor_forest - +from tensorflow.python.estimator import estimator as core_estimator +from tensorflow.python.estimator.canned import head as core_head_lib +from tensorflow.python.estimator.export.export_output import PredictOutput +from tensorflow.python.feature_column import feature_column as fc_core from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops @@ -34,12 +36,12 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util - KEYS_NAME = 'keys' LOSS_NAME = 'rf_training_loss' TREE_PATHS_PREDICTION_KEY = 'tree_paths' @@ -48,6 +50,11 @@ ALL_SERVING_KEY = 'tensorforest_all' EPSILON = 0.000001 +class ModelBuilderOutputType(object): + MODEL_FN_OPS = 0 + ESTIMATOR_SPEC = 1 + + class TensorForestRunOpAtEndHook(session_run_hook.SessionRunHook): def __init__(self, op_dict): @@ -106,20 +113,34 @@ class TensorForestLossHook(session_run_hook.SessionRunHook): run_context.request_stop() -def get_default_head(params, weights_name, name=None): - if params.regression: - return head_lib.regression_head( - weight_column_name=weights_name, - label_dimension=params.num_outputs, - enable_centered_bias=False, - head_name=name) +def _get_default_head(params, weights_name, output_type, name=None): + """Creates a default head based on a type of a problem.""" + if output_type == ModelBuilderOutputType.MODEL_FN_OPS: + if params.regression: + return head_lib.regression_head( + weight_column_name=weights_name, + label_dimension=params.num_outputs, + enable_centered_bias=False, + head_name=name) + else: + return head_lib.multi_class_head( + params.num_classes, + weight_column_name=weights_name, + enable_centered_bias=False, + head_name=name) else: - return head_lib.multi_class_head( - params.num_classes, - weight_column_name=weights_name, - enable_centered_bias=False, - head_name=name) - + if params.regression: + return core_head_lib._regression_head( # pylint:disable=protected-access + weight_column=weights_name, + label_dimension=params.num_outputs, + name=name, + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + else: + return core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint:disable=protected-access + n_classes=params.num_classes, + weight_column=weights_name, + name=name, + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) def get_model_fn(params, graph_builder_class, @@ -135,19 +156,27 @@ def get_model_fn(params, report_feature_importances=False, local_eval=False, head_scope=None, - include_all_in_serving=False): + include_all_in_serving=False, + output_type=ModelBuilderOutputType.MODEL_FN_OPS): """Return a model function given a way to construct a graph builder.""" if model_head is None: - model_head = get_default_head(params, weights_name) + model_head = _get_default_head(params, weights_name, output_type) def _model_fn(features, labels, mode): """Function that returns predictions, training loss, and training op.""" + if (isinstance(features, ops.Tensor) or isinstance(features, sparse_tensor.SparseTensor)): features = {'features': features} if feature_columns: features = features.copy() - features.update(layers.transform_features(features, feature_columns)) + + if output_type == ModelBuilderOutputType.MODEL_FN_OPS: + features.update(layers.transform_features(features, feature_columns)) + else: + for fc in feature_columns: + tensor = fc_core._transform_features(features, [fc])[fc] # pylint: disable=protected-access + features[fc.name] = tensor weights = None if weights_name and weights_name in features: @@ -201,52 +230,95 @@ def get_model_fn(params, def _train_fn(unused_loss): return training_graph - model_ops = model_head.create_model_fn_ops( - features=features, - labels=labels, - mode=mode, - train_op_fn=_train_fn, - logits=logits, - scope=head_scope) # Ops are run in lexigraphical order of their keys. Run the resource # clean-up op last. all_handles = graph_builder.get_all_resource_handles() ops_at_end = { - '9: clean up resources': control_flow_ops.group( - *[resource_variable_ops.destroy_resource_op(handle) - for handle in all_handles])} + '9: clean up resources': + control_flow_ops.group(*[ + resource_variable_ops.destroy_resource_op(handle) + for handle in all_handles + ]) + } if report_feature_importances: ops_at_end['1: feature_importances'] = ( graph_builder.feature_importances()) - training_hooks.append(TensorForestRunOpAtEndHook(ops_at_end)) - - if early_stopping_rounds: - training_hooks.append( - TensorForestLossHook( - early_stopping_rounds, - early_stopping_loss_threshold=early_stopping_loss_threshold, - loss_op=model_ops.loss)) - - model_ops.training_hooks.extend(training_hooks) - - if keys is not None: - model_ops.predictions[keys_name] = keys - - if params.inference_tree_paths: - model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths - - model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance - if include_all_in_serving: - # In order to serve the variance we need to add the prediction dict - # to output_alternatives dict. - if not model_ops.output_alternatives: - model_ops.output_alternatives = {} - model_ops.output_alternatives[ALL_SERVING_KEY] = ( - constants.ProblemType.UNSPECIFIED, model_ops.predictions) - return model_ops + training_hooks = [TensorForestRunOpAtEndHook(ops_at_end)] + + if output_type == ModelBuilderOutputType.MODEL_FN_OPS: + model_ops = model_head.create_model_fn_ops( + features=features, + labels=labels, + mode=mode, + train_op_fn=_train_fn, + logits=logits, + scope=head_scope) + + if early_stopping_rounds: + training_hooks.append( + TensorForestLossHook( + early_stopping_rounds, + early_stopping_loss_threshold=early_stopping_loss_threshold, + loss_op=model_ops.loss)) + + model_ops.training_hooks.extend(training_hooks) + + if keys is not None: + model_ops.predictions[keys_name] = keys + + if params.inference_tree_paths: + model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths + + model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance + + if include_all_in_serving: + # In order to serve the variance we need to add the prediction dict + # to output_alternatives dict. + if not model_ops.output_alternatives: + model_ops.output_alternatives = {} + model_ops.output_alternatives[ALL_SERVING_KEY] = ( + constants.ProblemType.UNSPECIFIED, model_ops.predictions) + + return model_ops + + else: + # Estimator spec + estimator_spec = model_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_fn, + logits=logits) + + if early_stopping_rounds: + training_hooks.append( + TensorForestLossHook( + early_stopping_rounds, + early_stopping_loss_threshold=early_stopping_loss_threshold, + loss_op=estimator_spec.loss)) + + estimator_spec = estimator_spec._replace( + training_hooks=training_hooks + list(estimator_spec.training_hooks)) + if keys is not None: + estimator_spec.predictions[keys_name] = keys + if params.inference_tree_paths: + estimator_spec.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths + estimator_spec.predictions[VARIANCE_PREDICTION_KEY] = regression_variance + + if include_all_in_serving: + outputs = estimator_spec.export_outputs + if not outputs: + outputs = {} + outputs = {ALL_SERVING_KEY: PredictOutput(estimator_spec.predictions)} + print(estimator_spec.export_outputs) + # In order to serve the variance we need to add the prediction dict + # to output_alternatives dict. + estimator_spec = estimator_spec._replace(export_outputs=outputs) + + return estimator_spec return _model_fn @@ -493,8 +565,11 @@ class MultiForestMultiHeadEstimator(estimator.Estimator): params, graph_builder_class, device_assigner, - model_head=get_default_head( - params, weight_column, name='head{0}'.format(i)), + model_head=_get_default_head( + params, + weight_column, + name='head{0}'.format(i), + output_type=ModelBuilderOutputType.MODEL_FN_OPS), weights_name=weight_column, keys_name=keys_column, early_stopping_rounds=early_stopping_rounds, @@ -509,3 +584,142 @@ class MultiForestMultiHeadEstimator(estimator.Estimator): model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn) + + +class CoreTensorForestEstimator(core_estimator.Estimator): + """A CORE estimator that can train and evaluate a random forest. + + Example: + + ```python + params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams( + num_classes=2, num_features=40, num_trees=10, max_nodes=1000) + + # Estimator using the default graph builder. + estimator = CoreTensorForestEstimator(params, model_dir=model_dir) + + # Or estimator using TrainingLossForest as the graph builder. + estimator = CoreTensorForestEstimator( + params, graph_builder_class=tensor_forest.TrainingLossForest, + model_dir=model_dir) + + # Input builders + def input_fn_train: # returns x, y + ... + def input_fn_eval: # returns x, y + ... + estimator.train(input_fn=input_fn_train) + estimator.evaluate(input_fn=input_fn_eval) + + # Predict returns an iterable of dicts. + results = list(estimator.predict(x=x)) + prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME] + prediction0 = results[0][eval_metrics.INFERENCE_PRED_NAME] + ``` + """ + + def __init__(self, + params, + device_assigner=None, + model_dir=None, + feature_columns=None, + graph_builder_class=tensor_forest.RandomForestGraphs, + config=None, + weight_column=None, + keys_column=None, + feature_engineering_fn=None, + early_stopping_rounds=100, + early_stopping_loss_threshold=0.001, + num_trainers=1, + trainer_id=0, + report_feature_importances=False, + local_eval=False, + version=None, + head=None, + include_all_in_serving=False): + """Initializes a TensorForestEstimator instance. + + Args: + params: ForestHParams object that holds random forest hyperparameters. + These parameters will be passed into `model_fn`. + device_assigner: An `object` instance that controls how trees get + assigned to devices. If `None`, will use + `tensor_forest.RandomForestDeviceAssigner`. + model_dir: Directory to save model parameters, graph, etc. To continue + training a previously saved model, load checkpoints saved to this + directory into an estimator. + feature_columns: An iterable containing all the feature columns used by + the model. All items in the set should be instances of classes derived + from `_FeatureColumn`. + graph_builder_class: An `object` instance that defines how TF graphs for + random forest training and inference are built. By default will use + `tensor_forest.RandomForestGraphs`. Can be overridden by version + kwarg. + config: `RunConfig` object to configure the runtime settings. + weight_column: A string defining feature column name representing + weights. Will be multiplied by the loss of the example. Used to + downweight or boost examples during training. + keys_column: A string naming one of the features to strip out and + pass through into the inference/eval results dict. Useful for + associating specific examples with their prediction. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + early_stopping_rounds: Allows training to terminate early if the forest is + no longer growing. 100 by default. Set to a Falsy value to disable + the default training hook. + early_stopping_loss_threshold: Percentage (as fraction) that loss must + improve by within early_stopping_rounds steps, otherwise training will + terminate. + num_trainers: Number of training jobs, which will partition trees + among them. + trainer_id: Which trainer this instance is. + report_feature_importances: If True, print out feature importances + during evaluation. + local_eval: If True, don't use a device assigner for eval. This is to + support some common setups where eval is done on a single machine, even + though training might be distributed. + version: Unused. + head: A heads_lib.Head object that calculates losses and such. If None, + one will be automatically created based on params. + include_all_in_serving: if True, allow preparation of the complete + prediction dict including the variance to be exported for serving with + the Servo lib; and it also requires calling export_savedmodel with + default_output_alternative_key=ALL_SERVING_KEY, i.e. + estimator.export_savedmodel(export_dir_base=your_export_dir, + serving_input_fn=your_export_input_fn, + default_output_alternative_key=ALL_SERVING_KEY) + if False, resort to default behavior, i.e. export scores and + probabilities but no variances. In this case + default_output_alternative_key should be None while calling + export_savedmodel(). + Note, that due to backward compatibility we cannot always set + include_all_in_serving to True because in this case calling + export_saved_model() without + default_output_alternative_key=ALL_SERVING_KEY (legacy behavior) the + saved_model_export_utils.get_output_alternatives() would raise + ValueError. + + Returns: + A `TensorForestEstimator` instance. + """ + + super(CoreTensorForestEstimator, self).__init__( + model_fn=get_model_fn( + params.fill(), + graph_builder_class, + device_assigner, + feature_columns=feature_columns, + model_head=head, + weights_name=weight_column, + keys_name=keys_column, + early_stopping_rounds=early_stopping_rounds, + early_stopping_loss_threshold=early_stopping_loss_threshold, + num_trainers=num_trainers, + trainer_id=trainer_id, + report_feature_importances=report_feature_importances, + local_eval=local_eval, + include_all_in_serving=include_all_in_serving, + output_type=ModelBuilderOutputType.ESTIMATOR_SPEC), + model_dir=model_dir, + config=config) diff --git a/tensorflow/contrib/tensor_forest/client/random_forest_test.py b/tensorflow/contrib/tensor_forest/client/random_forest_test.py index ac42364d25796aa34ef0831a00c768656cc64adb..e951592f85396b9e6c85f008d910e1e36908f15c 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest_test.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest_test.py @@ -23,7 +23,39 @@ import numpy as np from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.contrib.tensor_forest.client import random_forest from tensorflow.contrib.tensor_forest.python import tensor_forest +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column_lib as core_feature_column +from tensorflow.python.framework import ops +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_utils + + +def _get_classification_input_fns(): + iris = base.load_iris() + data = iris.data.astype(np.float32) + labels = iris.target.astype(np.int32) + + train_input_fn = numpy_io.numpy_input_fn( + x=data, y=labels, batch_size=150, num_epochs=None, shuffle=False) + + predict_input_fn = numpy_io.numpy_input_fn( + x=data[:1,], y=None, batch_size=1, num_epochs=1, shuffle=False) + return train_input_fn, predict_input_fn + + +def _get_regression_input_fns(): + boston = base.load_boston() + data = boston.data.astype(np.float32) + labels = boston.target.astype(np.int32) + + train_input_fn = numpy_io.numpy_input_fn( + x=data, y=labels, batch_size=506, num_epochs=None, shuffle=False) + + predict_input_fn = numpy_io.numpy_input_fn( + x=data[:1,], y=None, batch_size=1, num_epochs=1, shuffle=False) + return train_input_fn, predict_input_fn class TensorForestTrainerTests(test.TestCase): @@ -39,32 +71,285 @@ class TensorForestTrainerTests(test.TestCase): inference_tree_paths=True) classifier = random_forest.TensorForestEstimator(hparams.fill()) + input_fn, predict_input_fn = _get_classification_input_fns() + classifier.fit(input_fn=input_fn, steps=100) + res = classifier.evaluate(input_fn=input_fn, steps=10) + + self.assertEqual(1.0, res['accuracy']) + self.assertAllClose(0.55144483, res['loss']) + + predictions = list(classifier.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0.576117, 0.211942, 0.211942]], + [pred['probabilities'] for pred in predictions]) + + def testRegression(self): + """Tests regression using matrix data as input.""" + + hparams = tensor_forest.ForestHParams( + num_trees=5, + max_nodes=1000, + num_classes=1, + num_features=13, + regression=True, + split_after_samples=20) + + regressor = random_forest.TensorForestEstimator(hparams.fill()) + + input_fn, predict_input_fn = _get_regression_input_fns() + + regressor.fit(input_fn=input_fn, steps=100) + res = regressor.evaluate(input_fn=input_fn, steps=10) + self.assertGreaterEqual(0.1, res['loss']) + + predictions = list(regressor.predict(input_fn=predict_input_fn)) + self.assertAllClose([24.], [pred['scores'] for pred in predictions], atol=1) + + def testAdditionalOutputs(self): + """Tests multi-class classification using matrix data as input.""" + hparams = tensor_forest.ForestHParams( + num_trees=1, + max_nodes=100, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + classifier = random_forest.TensorForestEstimator( + hparams.fill(), keys_column='keys', include_all_in_serving=True) + iris = base.load_iris() data = iris.data.astype(np.float32) labels = iris.target.astype(np.int32) - classifier.fit(x=data, y=labels, steps=100, batch_size=50) - classifier.evaluate(x=data, y=labels, steps=10) + input_fn = numpy_io.numpy_input_fn( + x={ + 'x': data, + 'keys': np.arange(len(iris.data)).reshape(150, 1) + }, + y=labels, + batch_size=10, + num_epochs=1, + shuffle=False) - def testRegression(self): + classifier.fit(input_fn=input_fn, steps=100) + predictions = list(classifier.predict(input_fn=input_fn)) + # Check that there is a key column, tree paths and var. + for pred in predictions: + self.assertTrue('keys' in pred) + self.assertTrue('tree_paths' in pred) + self.assertTrue('prediction_variance' in pred) + + def _assert_checkpoint(self, model_dir, global_step): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) + + def testEarlyStopping(self): """Tests multi-class classification using matrix data as input.""" + hparams = tensor_forest.ForestHParams( + num_trees=100, + max_nodes=10000, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + classifier = random_forest.TensorForestEstimator( + hparams.fill(), + # Set a crazy threshold - 30% loss change. + early_stopping_loss_threshold=0.3, + early_stopping_rounds=2) + + input_fn, _ = _get_classification_input_fns() + classifier.fit(input_fn=input_fn, steps=100) + + # We stopped early. + self._assert_checkpoint(classifier.model_dir, global_step=5) + + +class CoreTensorForestTests(test.TestCase): + + def testTrainEvaluateInferDoesNotThrowErrorForClassifier(self): + head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) hparams = tensor_forest.ForestHParams( num_trees=3, max_nodes=1000, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + + est = random_forest.CoreTensorForestEstimator(hparams.fill(), head=head_fn) + + input_fn, predict_input_fn = _get_classification_input_fns() + + est.train(input_fn=input_fn, steps=100) + res = est.evaluate(input_fn=input_fn, steps=1) + + self.assertEqual(1.0, res['accuracy']) + self.assertAllClose(0.55144483, res['loss']) + + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0.576117, 0.211942, 0.211942]], + [pred['probabilities'] for pred in predictions]) + + def testRegression(self): + """Tests regression using matrix data as input.""" + head_fn = head_lib._regression_head( + label_dimension=1, + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + hparams = tensor_forest.ForestHParams( + num_trees=5, + max_nodes=1000, num_classes=1, num_features=13, regression=True, split_after_samples=20) - regressor = random_forest.TensorForestEstimator(hparams.fill()) + regressor = random_forest.CoreTensorForestEstimator( + hparams.fill(), head=head_fn) + + input_fn, predict_input_fn = _get_regression_input_fns() + + regressor.train(input_fn=input_fn, steps=100) + res = regressor.evaluate(input_fn=input_fn, steps=10) + self.assertGreaterEqual(0.1, res['loss']) + + predictions = list(regressor.predict(input_fn=predict_input_fn)) + self.assertAllClose( + [[24.]], [pred['predictions'] for pred in predictions], atol=1) + + def testWithFeatureColumns(self): + head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + hparams = tensor_forest.ForestHParams( + num_trees=3, + max_nodes=1000, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + + est = random_forest.CoreTensorForestEstimator( + hparams.fill(), + head=head_fn, + feature_columns=[core_feature_column.numeric_column('x')]) + + iris = base.load_iris() + data = {'x': iris.data.astype(np.float32)} + labels = iris.target.astype(np.int32) + + input_fn = numpy_io.numpy_input_fn( + x=data, y=labels, batch_size=150, num_epochs=None, shuffle=False) + + est.train(input_fn=input_fn, steps=100) + res = est.evaluate(input_fn=input_fn, steps=1) + + self.assertEqual(1.0, res['accuracy']) + self.assertAllClose(0.55144483, res['loss']) + + def testAutofillsClassificationHead(self): + hparams = tensor_forest.ForestHParams( + num_trees=3, + max_nodes=1000, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + + est = random_forest.CoreTensorForestEstimator(hparams.fill()) + + input_fn, _ = _get_classification_input_fns() + + est.train(input_fn=input_fn, steps=100) + res = est.evaluate(input_fn=input_fn, steps=1) + + self.assertEqual(1.0, res['accuracy']) + self.assertAllClose(0.55144483, res['loss']) + + def testAutofillsRegressionHead(self): + hparams = tensor_forest.ForestHParams( + num_trees=5, + max_nodes=1000, + num_classes=1, + num_features=13, + regression=True, + split_after_samples=20) + + regressor = random_forest.CoreTensorForestEstimator(hparams.fill()) + + input_fn, predict_input_fn = _get_regression_input_fns() + + regressor.train(input_fn=input_fn, steps=100) + res = regressor.evaluate(input_fn=input_fn, steps=10) + self.assertGreaterEqual(0.1, res['loss']) + + predictions = list(regressor.predict(input_fn=predict_input_fn)) + self.assertAllClose( + [[24.]], [pred['predictions'] for pred in predictions], atol=1) + + def testAdditionalOutputs(self): + """Tests multi-class classification using matrix data as input.""" + hparams = tensor_forest.ForestHParams( + num_trees=1, + max_nodes=100, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) + classifier = random_forest.CoreTensorForestEstimator( + hparams.fill(), keys_column='keys', include_all_in_serving=True) + + iris = base.load_iris() + data = iris.data.astype(np.float32) + labels = iris.target.astype(np.int32) + + input_fn = numpy_io.numpy_input_fn( + x={ + 'x': data, + 'keys': np.arange(len(iris.data)).reshape(150, 1) + }, + y=labels, + batch_size=10, + num_epochs=1, + shuffle=False) + + classifier.train(input_fn=input_fn, steps=100) + predictions = list(classifier.predict(input_fn=input_fn)) + # Check that there is a key column, tree paths and var. + for pred in predictions: + self.assertTrue('keys' in pred) + self.assertTrue('tree_paths' in pred) + self.assertTrue('prediction_variance' in pred) + + def _assert_checkpoint(self, model_dir, global_step): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) + + def testEarlyStopping(self): + head_fn = head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=3, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + hparams = tensor_forest.ForestHParams( + num_trees=3, + max_nodes=1000, + num_classes=3, + num_features=4, + split_after_samples=20, + inference_tree_paths=True) - boston = base.load_boston() - data = boston.data.astype(np.float32) - labels = boston.target.astype(np.int32) + est = random_forest.CoreTensorForestEstimator( + hparams.fill(), + head=head_fn, + # Set a crazy threshold - 30% loss change. + early_stopping_loss_threshold=0.3, + early_stopping_rounds=2) - regressor.fit(x=data, y=labels, steps=100, batch_size=50) - regressor.evaluate(x=data, y=labels, steps=10) + input_fn, _ = _get_classification_input_fns() + est.train(input_fn=input_fn, steps=100) + # We stopped early. + self._assert_checkpoint(est.model_dir, global_step=5) if __name__ == "__main__": diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc index 7e25579070eef13682dedfcd3c9e435333f65687..6cb2c881e2428dfcac3187bf7364582e857b9879 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc @@ -51,7 +51,8 @@ std::unique_ptr CreateBinaryDecisionNodeEvaluator( InequalityDecisionNodeEvaluator::InequalityDecisionNodeEvaluator( const decision_trees::InequalityTest& test, int32 left, int32 right) : BinaryDecisionNodeEvaluator(left, right) { - safe_strto32(test.feature_id().id().value(), &feature_num_); + CHECK(safe_strto32(test.feature_id().id().value(), &feature_num_)) + << "Invalid feature ID: [" << test.feature_id().id().value() << "]"; threshold_ = test.threshold().float_value(); include_equals_ = test.type() == decision_trees::InequalityTest::LESS_OR_EQUAL; @@ -72,7 +73,9 @@ ObliqueInequalityDecisionNodeEvaluator::ObliqueInequalityDecisionNodeEvaluator( : BinaryDecisionNodeEvaluator(left, right) { for (int i = 0; i < test.oblique().features_size(); ++i) { int32 val; - safe_strto32(test.oblique().features(i).id().value(), &val); + CHECK(safe_strto32(test.oblique().features(i).id().value(), &val)) + << "Invalid feature ID: [" << test.oblique().features(i).id().value() + << "]"; feature_num_.push_back(val); feature_weights_.push_back(test.oblique().weights(i)); } @@ -97,7 +100,8 @@ int32 ObliqueInequalityDecisionNodeEvaluator::Decide( MatchingValuesDecisionNodeEvaluator::MatchingValuesDecisionNodeEvaluator( const decision_trees::MatchingValuesTest& test, int32 left, int32 right) : BinaryDecisionNodeEvaluator(left, right) { - safe_strto32(test.feature_id().id().value(), &feature_num_); + CHECK(safe_strto32(test.feature_id().id().value(), &feature_num_)) + << "Invalid feature ID: [" << test.feature_id().id().value() << "]"; for (const auto& val : test.value()) { values_.push_back(val.float_value()); } diff --git a/tensorflow/contrib/tensorboard/db/BUILD b/tensorflow/contrib/tensorboard/db/BUILD index 3f6b4cdc9ad10f5089f28af35a8be408918c7f90..6507546ee9f81108add181a9c83064c9860005e2 100644 --- a/tensorflow/contrib/tensorboard/db/BUILD +++ b/tensorflow/contrib/tensorboard/db/BUILD @@ -106,6 +106,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:png_internal", "//tensorflow/core:protos_all_cc", ], ) diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index a5d8b061b6b26f9d05be40a1162481ae219b0e9c..fc0d22d112efcccd1a3be6388d36478cf2076ff5 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -3,7 +3,7 @@ # and provide TensorRT operators and converter package. # APIs are meant to change over time. -package(default_visibility = ["//tensorflow:__subpackages__"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 @@ -11,7 +11,6 @@ exports_files(["LICENSE"]) load( "//tensorflow:tensorflow.bzl", - "py_test", "tf_cc_test", "tf_copts", "tf_cuda_library", @@ -20,6 +19,7 @@ load( "tf_gen_op_libs", "tf_gen_op_wrapper_py", ) +load("//tensorflow:tensorflow.bzl", "cuda_py_tests") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") @@ -33,11 +33,13 @@ tf_cuda_cc_test( size = "small", srcs = ["tensorrt_test.cc"], tags = [ - "manual", - "notap", + "no_windows", + "nomac", ], deps = [ + "//tensorflow/core:gpu_init", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor", "//tensorflow/core:test", "//tensorflow/core:test_main", ] + if_tensorrt([ @@ -49,7 +51,6 @@ tf_cuda_cc_test( tf_custom_op_library( name = "python/ops/_trt_engine_op.so", srcs = [ - "ops/trt_calib_op.cc", "ops/trt_engine_op.cc", ], deps = [ @@ -76,33 +77,35 @@ tf_cuda_library( cc_library( name = "trt_engine_op_kernel", srcs = [ - "kernels/trt_calib_op.cc", "kernels/trt_engine_op.cc", ], hdrs = [ - "kernels/trt_calib_op.h", "kernels/trt_engine_op.h", ], copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + ":test_utils", + ":trt_allocator", + ":trt_conversion", ":trt_logging", ":trt_plugins", ":trt_resources", + ":utils", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:stream_executor_headers_lib", + "//tensorflow/core/grappler/costs:graph_properties", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]) + tf_custom_op_library_additional_deps(), - # TODO(laigd) + # TODO(laigd): fix this by merging header file in cc file. alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs ) tf_gen_op_libs( op_lib_names = [ "trt_engine_op", - "trt_calib_op", ], ) @@ -120,9 +123,7 @@ tf_cuda_library( tf_gen_op_wrapper_py( name = "trt_engine_op", - gen_locally = True, deps = [ - ":trt_calib_op_op_lib", ":trt_engine_op_op_lib", ":trt_logging", ":trt_shape_function", @@ -140,7 +141,6 @@ tf_custom_op_py_library( kernels = [ ":trt_engine_op_kernel", ":trt_engine_op_op_lib", - ":trt_calib_op_op_lib", ":trt_shape_function", ], srcs_version = "PY2AND3", @@ -159,6 +159,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":tf_trt_integration_test_base", ":trt_convert_py", ":trt_ops_py", "//tensorflow/python:errors", @@ -184,14 +185,19 @@ py_library( ], ) +# TODO(aaroey): this wrapper has been causing troubles of double linking, so +# either get rid of it, or split to make it contain minimum dependencies. tf_py_wrap_cc( name = "wrap_conversion", srcs = ["trt_conversion.i"], copts = tf_copts(), + swig_includes = [ + "//tensorflow/python:platform/base.i", + ], deps = [ + ":test_utils", ":trt_conversion", ":trt_engine_op_kernel", - "//tensorflow/core:framework_lite", "//third_party/python_runtime:headers", ], ) @@ -199,18 +205,31 @@ tf_py_wrap_cc( tf_cuda_library( name = "trt_resources", srcs = [ - "resources/trt_allocator.cc", "resources/trt_int8_calibrator.cc", "resources/trt_resource_manager.cc", ], hdrs = [ - "resources/trt_allocator.h", "resources/trt_int8_calibrator.h", "resources/trt_resource_manager.h", "resources/trt_resources.h", ], deps = [ + ":trt_allocator", ":trt_logging", + ":utils", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + +tf_cuda_library( + name = "trt_allocator", + srcs = ["resources/trt_allocator.cc"], + hdrs = ["resources/trt_allocator.h"], + deps = [ "//tensorflow/core:framework_headers_lib", "//tensorflow/core:framework_lite", "//tensorflow/core:lib_proto_parsing", @@ -219,6 +238,21 @@ tf_cuda_library( ]), ) +tf_cc_test( + name = "trt_allocator_test", + size = "small", + srcs = ["resources/trt_allocator_test.cc"], + tags = [ + "no_windows", + "nomac", + ], + deps = [ + ":trt_allocator", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + # Library for the node-level conversion portion of TensorRT operation creation tf_cuda_library( name = "trt_conversion", @@ -234,16 +268,17 @@ tf_cuda_library( ], deps = [ ":segment", + ":test_utils", + ":trt_allocator", ":trt_plugins", ":trt_logging", ":trt_resources", + ":utils", "//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", - "//tensorflow/core:framework", - "//tensorflow/core:gpu_runtime", "//tensorflow/core:framework_lite", "//tensorflow/core:graph", "//tensorflow/core:lib", @@ -278,13 +313,21 @@ tf_cc_test( name = "segment_test", size = "small", srcs = ["segment/segment_test.cc"], + tags = [ + "no_windows", + "nomac", + ], deps = [ ":segment", - "//tensorflow/c:c_api", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + "//tensorflow/core:core_cpu", "//tensorflow/core:lib", + "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) @@ -314,8 +357,9 @@ tf_cuda_cc_test( size = "small", srcs = ["plugin/trt_plugin_factory_test.cc"], tags = [ - "manual", - "notap", + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", ], deps = [ ":trt_plugins", @@ -328,18 +372,61 @@ tf_cuda_cc_test( ]), ) -py_test( +py_library( + name = "tf_trt_integration_test_base", + srcs = ["test/tf_trt_integration_test_base.py"], + deps = [ + ":trt_convert_py", + ":trt_ops_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + ], +) + +cuda_py_tests( name = "tf_trt_integration_test", - srcs = ["test/tf_trt_integration_test.py"], - main = "test/tf_trt_integration_test.py", - srcs_version = "PY2AND3", - tags = [ - "manual", - "notap", + srcs = [ + "test/base_test.py", + # "test/batch_matmul_test.py", + # "test/biasadd_matmul_test.py", + # "test/binary_tensor_weight_broadcast_test.py", # Blocked by trt4 installation + # "test/concatenation_test.py", # Blocked by trt4 installation + "test/const_broadcast_test.py", + "test/multi_connection_neighbor_engine_test.py", + "test/neighboring_engine_test.py", + # "test/unary_test.py", # Blocked by trt4 installation + # "test/vgg_block_nchw_test.py", + # "test/vgg_block_test.py", + "test/memory_alignment_test.py", ], - deps = [ - ":init_py", + additional_deps = [ + ":tf_trt_integration_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", ], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], +) + +cc_library( + name = "utils", + srcs = ["convert/utils.cc"], + hdrs = ["convert/utils.h"], + copts = tf_copts(), + deps = [ + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "test_utils", + srcs = ["test/utils.cc"], + hdrs = ["test/utils.h"], + deps = [ + "//tensorflow/core:lib", + "@com_googlesource_code_re2//:re2", + ], ) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index da4dd5a14cd74591fc9df63cd5868044e4e369ec..21ec8b0b30c595a1fad01b69bce9b16393742704 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -14,20 +14,26 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include #include #include #include #include +#include #include #include #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/contrib/tensorrt/segment/segment.h" -#include "tensorflow/core/common_runtime/gpu/gpu_id.h" -#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" -#include "tensorflow/core/common_runtime/gpu/process_state.h" +#include "tensorflow/contrib/tensorrt/test/utils.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -39,17 +45,39 @@ limitations under the License. #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" // NOLINT #include "tensorflow/core/protobuf/device_properties.pb.h" // NOLINT +#include "tensorflow/core/protobuf/rewriter_config.pb.h" // NOLINT +#include "tensorflow/core/util/device_name_utils.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT +#include "cuda/include/cuda_runtime_api.h" #include "tensorrt/include/NvInfer.h" - namespace tensorflow { namespace tensorrt { namespace convert { +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +// Returns compiled TRT version information {Maj, Min, Patch} +std::vector GetLinkedTensorRTVersion() { + return {NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR, NV_TENSORRT_PATCH}; +} + +// Returns loaded TRT library version {Maj, Min, Patch} +std::vector GetLoadedTensorRTVersion() { + int ver = getInferLibVersion(); + int ver_major = ver / 1000; + ver = ver - ver_major * 1000; + int ver_minor = ver / 100; + int ver_patch = ver - ver_minor * 100; + return {ver_major, ver_minor, ver_patch}; +} + namespace { bool IsTensorRTCandidate(const tensorflow::Node* node) { @@ -57,447 +85,857 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(jie): Segmentation shouldn't associated with op name. // Split it into a registration for each kernel. static const std::set candidate_ops = { - "Identity", - "Snapshot", - "Const", - "Conv2D", - "MaxPool", - "BiasAdd", - "Relu", - "Add", - "Mul", - "Sub", - "Rsqrt", - "Pad", - "Mean", - "AvgPool", - "ConcatV2", - "DepthwiseConv2dNative", - "FusedBatchNorm", - "FusedBatchNormV2", - // TODO(ben,jie): ... + "Identity", + "Snapshot", + "Const", + "Conv2D", + "MaxPool", + "BiasAdd", + "Relu", + "Add", + "Mul", + "Sub", + "Rsqrt", + "Pad", + "Mean", + "AvgPool", + "ConcatV2", + "DepthwiseConv2dNative", + "FusedBatchNorm", + "FusedBatchNormV2", + "Div", + "RealDiv", + "Rsqrt", + "Reciprocal", + "Exp", + "Log", + "Sqrt", + "Abs", + "Neg", +#if NV_TENSORRT_MAJOR > 3 + "MatMul", + "BatchMatMul", + "Softmax", + "Minimum", + "Maximum", + "TopKV2", + "Sum", + "Prod", + "Max", + "Min", +#endif + // TODO(ben,jie): ... }; - // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h) + // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc) return (candidate_ops.count(node->type_string()) || PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); } -void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, - const std::set& subgraph_node_ids, - tensorflow::EdgeSet* incoming_edges) { - for (int node_id : subgraph_node_ids) { - const tensorflow::Node* node = graph.FindNodeId(node_id); - for (const tensorflow::Edge* edge : node->in_edges()) { - if (!subgraph_node_ids.count(edge->src()->id()) && - !edge->src()->IsSource() && !edge->IsControlEdge()) { - incoming_edges->insert(edge); - VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name() - << " Y, "; - } else { - VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name() - << " N, "; - } +tensorflow::Status BuildNodeMap( + const tensorflow::Graph& graph, + std::unordered_map* node_map) { + for (auto* node : graph.op_nodes()) { + if (!node_map->insert({node->name(), node}).second) { + return tensorflow::errors::AlreadyExists( + "Node name is not unique in graph: " + node->name()); } } + return tensorflow::Status::OK(); } -void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph, - const std::set& subgraph_node_ids, - tensorflow::EdgeSet* outgoing_edges) { - for (int node_id : subgraph_node_ids) { - const tensorflow::Node* node = graph.FindNodeId(node_id); - for (const tensorflow::Edge* edge : node->out_edges()) { - if (!subgraph_node_ids.count(edge->dst()->id()) && - !edge->dst()->IsSink() && !edge->IsControlEdge()) { - VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name() - << " Y, "; - outgoing_edges->insert(edge); +} // namespace + +// Function to get calibration from ResourceMgr and put them into nodedef. +tensorflow::Status ConvertCalibGraphToInferGraph( + const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph, + bool is_dyn_op) { + VLOG(0) << "Starting Calib Conversion"; + infer_graph->CopyFrom(graph_def); + auto trt_rm = TRTResourceManager::instance(); + auto calib_rm = trt_rm->getManager("TRTCalibration"); + int num_nodes = infer_graph->node_size(); + if (!is_dyn_op) { + LOG(WARNING) << "Construction of static int8 engine is not implemented " + "yet!. Dynamic engine will be constructed"; + } + for (int i = 0; i < num_nodes; ++i) { + auto n = infer_graph->mutable_node(i); + if (n->op() == "TRTEngineOp") { + VLOG(1) << "Processing " << n->name(); + const string& container_name = n->attr().at("segment_funcdef_name").s(); + TRTCalibrationResource* cres = nullptr; + auto status = calib_rm->Lookup(container_name, "Calibrator", &cres); + if (!status.ok()) { + LOG(ERROR) << "Could not get Calibration information. Did you run with " + "calibration data?"; + return tensorflow::errors::FailedPrecondition( + "Need to run graph with calibration data first!"); + } + if (cres->calibrator_) { + cres->calibrator_->waitAndSetDone(); + cres->thr_->join(); + const auto& calibration_table = + cres->calibrator_->getCalibrationTableAsString(); + if (!calibration_table.size()) { + LOG(ERROR) << "Calibration table is empty"; + return tensorflow::errors::Unknown( + "Calibration table is missing. This shouldn't have happened!"); + } + n->mutable_attr()->at("calibration_data").set_s(calibration_table); } else { - VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name() - << " N, "; + LOG(ERROR) << "Can't get TRTCalibrator from resource manager!"; + return tensorflow::errors::Unknown( + "Can't get TRTCalibrator from resource manager!"); } + cres->Unref(); + TF_RETURN_IF_ERROR(calib_rm->Cleanup(container_name)); } } + return tensorflow::Status::OK(); } -std::pair ParseTensorName(const string& name, - int default_idx = 0) { - string name_no_idx = name; - int idx = default_idx; - const size_t sep = name_no_idx.find_last_of(':'); - if (sep != string::npos) { - name_no_idx = name_no_idx.substr(0, sep); - idx = std::stoi(name.substr(sep + 1)); +tensorflow::Status ConvertGraphDefToTensorRT( + const tensorflow::GraphDef& graph_def, + const std::vector& output_names, size_t max_batch_size, + size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, + int precision_mode, int minimum_segment_size, bool is_dyn_op, + int max_cached_engines, std::vector cached_engine_batches) { + // Create GrapplerItem. + tensorflow::grappler::GrapplerItem item; + item.fetch = output_names; + item.graph = graph_def; + + // TODO(aaroey): we should have used single machine cluster like the + // following, but the problem is then wrap_conversion will depend on + // direct_session and cause double linking problems. To fix this we need to + // fix or get rid of the swig dependency. Here we use VirtualCluster + // as a work around, and we need to create a session to initialize the + // underlying device before calling this method. +#if 0 + // Create single machine cluster. Note that this will create a session and + // initialize the gpu devices. + const int num_cpu_cores = + tensorflow::grappler::GetNumAvailableLogicalCPUCores(); + const int num_gpus = tensorflow::grappler::GetNumAvailableGPUs(); + VLOG(2) << "cpu_cores: " << num_cpu_cores; + VLOG(2) << "gpus: " << num_gpus; + const int timeout_s = 60 * 10; + std::unique_ptr cluster( + new tensorflow::grappler::SingleMachine( + timeout_s, num_cpu_cores, num_gpus)); + // These settings are the defaults in tensorflow/python/grappler/cluster.py. + cluster->DisableDetailedStats(true); + cluster->AllowSoftPlacement(true); + cluster->SetNumWarmupSteps(10); + TF_RETURN_IF_ERROR(cluster->Provision()); +#else + // Create virtual cluster. Grappler requires a virtual cluster with a proper + // GPU device in order to calculate flops>0 or fails with FATAL in dbg mode. + // We add numbers from a Pascal card here to have flops>0. + tensorflow::DeviceProperties device_properties; + device_properties.set_type("GPU"); + device_properties.mutable_environment()->insert({"architecture", "6"}); + device_properties.set_num_cores(3584); + device_properties.set_frequency(1531); + std::unique_ptr cluster( + new tensorflow::grappler::VirtualCluster( + {{"/GPU:0", device_properties}})); +#endif + + // Create RewriterConfig. + tensorflow::RewriterConfig rw_cfg; + // TODO(aaroey): use only const folding and layout for the time being since + // new optimizers break the graph for trt. + rw_cfg.add_optimizers("constfold"); + rw_cfg.add_optimizers("layout"); + auto optimizer = rw_cfg.add_custom_optimizers(); + optimizer->set_name("TensorRTOptimizer"); + auto& parameters = *(optimizer->mutable_parameter_map()); + parameters["minimum_segment_size"].set_i(minimum_segment_size); + parameters["max_batch_size"].set_i(max_batch_size); + parameters["is_dynamic_op"].set_b(is_dyn_op); + parameters["max_workspace_size_bytes"].set_i(max_workspace_size_bytes); + TF_RETURN_IF_ERROR(GetPrecisionModeName( + precision_mode, parameters["precision_mode"].mutable_s())); + parameters["maximum_cached_engines"].set_i(max_cached_engines); + if (!cached_engine_batches.empty()) { + auto list = parameters["cached_engine_batches"].mutable_list(); + for (const int batch : cached_engine_batches) { + list->add_i(batch); + } } - return std::make_pair(name_no_idx, idx); -} -std::unordered_map> BuildTensorNameMap( - const std::vector& tensor_names) { - std::unordered_map> result; - for (const string& tensor_name : tensor_names) { - string node_name; - int index; - std::tie(node_name, index) = ParseTensorName(tensor_name); - result[node_name].push_back(index); + // Run optimizer. + tensorflow::grappler::MetaOptimizer meta_opt(nullptr, rw_cfg); + TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, new_graph_def)); + + if (VLOG_IS_ON(5)) { + std::fstream f; + f.open("TRTConversionInput.pb", + std::fstream::out | std::fstream::binary | std::fstream::trunc); + f << new_graph_def->SerializeAsString(); + f.close(); } - return result; + return Status::OK(); } -// TODO(sami): convert references to pointers -struct ConvertGraphParams { - ConvertGraphParams( - tensorflow::Graph& inp_graph, - const std::vector& output_node_names, - const std::set& subgraph_node_id_numbers, - size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes, - const tensorflow::grappler::GraphProperties& current_graph_properties, - std::unordered_map>* output_edges, - int engine_precision_mode, const string& device_name, - std::shared_ptr allocator, int cuda_gpu_id) - : graph(inp_graph), - output_names(output_node_names), - subgraph_node_ids(subgraph_node_id_numbers), - max_batch_size(max_supported_batch_size), - max_workspace_size_bytes(max_consumed_workspace_size_bytes), - graph_properties(current_graph_properties), - output_edge_map(output_edges), - precision_mode(engine_precision_mode), - device_name_(device_name), - allocator_(allocator), - cuda_gpu_id_(cuda_gpu_id) {} - tensorflow::Graph& graph; - const std::vector& output_names; - const std::set& subgraph_node_ids; - size_t max_batch_size; - size_t max_workspace_size_bytes; - const tensorflow::grappler::GraphProperties& graph_properties; - std::unordered_map>* output_edge_map; - int precision_mode; - string device_name_; - std::shared_ptr allocator_; - int cuda_gpu_id_; - std::vector> subgraph_inputs; - std::vector> subgraph_outputs; - tensorflow::EdgeSet subgraph_incoming_edges; - tensorflow::EdgeSet subgraph_outgoing_edges; -}; - -static tensorflow::Status FillSubGraphEdgeSets(ConvertGraphParams* p) { - GetSubGraphIncomingEdges(p->graph, p->subgraph_node_ids, - &p->subgraph_incoming_edges); - - std::set> unique_tensors; - // Add only unique input source nodes. If output of an outside node is shared - // between multiple nodes inside the engine, only one edge should be created - for (const tensorflow::Edge* edge : p->subgraph_incoming_edges) { - unique_tensors.insert({edge->src()->id(), edge->src_output()}); - } - p->subgraph_inputs.insert(p->subgraph_inputs.begin(), unique_tensors.begin(), - unique_tensors.end()); - GetSubGraphOutgoingEdges(p->graph, p->subgraph_node_ids, - &p->subgraph_outgoing_edges); - unique_tensors.clear(); - // Similar to above, if multiple ouside nodes are sharing the output of an - // internal node only one output port should be created and shared between - // outputs - for (const tensorflow::Edge* edge : p->subgraph_outgoing_edges) { - unique_tensors.insert({edge->src()->id(), edge->src_output()}); +// Function to get subsegment information structure. +tensorflow::Status GetEngineInfo( + const tensorflow::Graph* g, + const tensorflow::grappler::GraphProperties& graph_properties, + const std::set& segment_nodes, + const std::unordered_map& node_map, + const std::vector& reverse_topo_order, + EngineInfo* info) { + std::vector subgraph_node_ids; // Topologically sorted node ids. + std::set subgraph_node_names = segment_nodes; + std::set added_const_node_ids; // Used to prevent double insertion. + std::set segment_devices; + + // Map from src_node_name+port to the unique port numbers of the TRT op, where + // the src_node_name is the name of the source node of the input/output + // edge, thus there must not be any duplicates since source nodes of + // input/output edges must be in different split of the graph. + // TODO(aaroey): consider using node id and port instead. + // TODO(aaroey): using topo order instead of reverting reverse topo order. + std::unordered_map input_to_engine_port, output_to_engine_port; + for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend(); + ++it) { + const auto& node_name = (*it)->name(); + if (segment_nodes.count(node_name) == 0) continue; + auto node = *it; + auto node_device = node->requested_device(); + if (!node_device.empty()) { + segment_devices.insert(node_device); + } else { + if (node->has_assigned_device_name()) { + segment_devices.insert(node->assigned_device_name()); + } else { + VLOG(2) << "Node " << node->name() + << " neither have requested device nor assigned device"; + } + } + const int node_id = node->id(); + subgraph_node_ids.push_back(node_id); + // Create input connections. + for (const auto edge : node->in_edges()) { + auto input_node = edge->src(); + if (input_node->IsSource() || segment_nodes.count(input_node->name())) { + continue; + } + if (edge->IsControlEdge()) { + // Control input. + info->connections.emplace_back(input_node->name(), input_node->id(), + node_name, node_id, + /*input_edge=*/true); + } else if (input_node->type_string() == "Const") { + // Add constant data input nodes into the segment graphdef (thus also in + // the engine). We don't care if it has other output edges going into + // other engines or TF nodes. Since we add it only to the segment + // graphdef, not the segment itself, it won't be removed from the graph. + // If it doesn't have any edges, TF will prune it out. + // + // Note that the segmenter already ensure that the constant data input + // is valid and suppported by the engine. + if (!added_const_node_ids.insert(input_node->id()).second) { + // Already added before. + continue; + } + VLOG(1) << "Adding const node " << input_node->name(); + QCHECK(subgraph_node_names.insert(input_node->name()).second); + // Since we already add (duplicate) the const input node to the segment + // graphdef, it's now not a data dependency any more, but to make the + // dependency correct we still add a control dependency. + info->connections.emplace_back(input_node->name(), input_node->id(), + node_name, node_id, + /*input_edge=*/true); + } else { + // Non-const data input. + int port = Graph::kControlSlot - 1; + // Use the source non-segment node name/port as key. + const string s = StrCat(input_node->name(), ":", edge->src_output()); + VLOG(1) << "Input edge = " << s; + if (input_to_engine_port.count(s)) { + port = input_to_engine_port.at(s); + } else { + port = input_to_engine_port.size(); + input_to_engine_port.insert({s, port}); + } + info->connections.emplace_back( + input_node->name(), input_node->id(), edge->src_output(), node_name, + node_id, edge->dst_input(), /*input_edge=*/true, port); + } + } + // Create output connections. + for (const auto edge : node->out_edges()) { + auto output_node = edge->dst(); + if (output_node->IsSink() || segment_nodes.count(output_node->name())) { + continue; + } + if (edge->IsControlEdge()) { + // Control output. + info->connections.emplace_back(output_node->name(), output_node->id(), + node_name, node_id, + /*input_edge=*/false); + } else { + // Data output. + int port = Graph::kControlSlot - 1; + // Use the source segment node name/port as key. + const string s = StrCat(node_name, ":", edge->src_output()); + VLOG(1) << "Output edge = " << s; + if (output_to_engine_port.count(s)) { + port = output_to_engine_port.at(s); + } else { + port = output_to_engine_port.size(); + output_to_engine_port.insert({s, port}); + } + info->connections.emplace_back( + output_node->name(), output_node->id(), edge->dst_input(), + node_name, node_id, edge->src_output(), /*input_edge=*/false, port); + } + } + } // For each segment node in topological order. + + // Construct the const nodes first. + subgraph_node_ids.insert(subgraph_node_ids.begin(), + added_const_node_ids.begin(), + added_const_node_ids.end()); + TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef( + g, graph_properties, subgraph_node_names, subgraph_node_ids, + &info->connections, &info->segment_graph_def, &info->engine_name)); + // TODO(sami): This should not happen once segmenter is updated. + if (segment_devices.size() == 1) { + info->device = *segment_devices.begin(); + } else if (segment_devices.size() > 1) { + LOG(WARNING) << "Detected multiple(" << segment_devices.size() + << ") devices for the segment. Picking first one to continue " + << "but this shouldn't have happened"; + info->device = *segment_devices.begin(); + } else { + LOG(ERROR) << "Can't find a device placement for the op!"; } - p->subgraph_outputs.reserve(unique_tensors.size()); - p->subgraph_outputs.insert(p->subgraph_outputs.begin(), - unique_tensors.begin(), unique_tensors.end()); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status GetCalibNode(ConvertGraphParams* params) { - TF_RETURN_IF_ERROR(FillSubGraphEdgeSets(params)); - tensorflow::NodeDef trt_node_def; - SubGraphParams s(params->graph, params->subgraph_node_ids, - params->subgraph_inputs, params->subgraph_outputs, - params->max_batch_size, params->max_workspace_size_bytes, - params->graph_properties, params->output_edge_map, - &trt_node_def, params->precision_mode, params->device_name_, - params->allocator_, params->cuda_gpu_id_); - TF_RETURN_IF_ERROR(InjectCalibrationNode(s)); - tensorflow::Status status; - tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status); - - TF_RETURN_IF_ERROR(status); - - for (auto in_edge : - params->subgraph_incoming_edges) { // loop over incoming edges and - // attach them to calib node - auto src_output = in_edge->src_output(); - auto dst_node = in_edge->dst(); - auto dst_input = in_edge->dst_input(); - VLOG(1) << " update edge " << trt_node->name() << ":" << src_output - << " -> " << dst_node->name() << ":" << dst_input; - TF_RETURN_IF_ERROR( - params->graph.UpdateEdge(trt_node, src_output, dst_node, dst_input)); +// Helper function to update edge connection from the removed node to the +// engine node. If an outside node is gone, it must have been absorbed into +// an engine node. Find the engine node. +void UpdateToEngineNode(const std::vector& infos, + const size_t my_engine_id, + const std::vector& engine_nodes, + const bool is_input_edge, const string& node_name, + tensorflow::Node** node, int* port) { + for (size_t t = 0; t < infos.size(); ++t) { + if (t == my_engine_id) { + continue; + } + const auto& info = infos.at(t); + for (const auto& eng_conn : info.connections) { + // If the connection being updated is an input connection, the source of + // the connection must be an output connection of another engine. And vise + // versa. + if (is_input_edge == eng_conn.is_input_edge) continue; + if (eng_conn.inside_node_name == node_name && + eng_conn.inside_port == *port) { + *node = CHECK_NOTNULL(engine_nodes[t]); + QCHECK_EQ(info.engine_name, (**node).name()) + << "Engine name mismatch: " << info.engine_name << " vs " + << (**node).name(); + *port = eng_conn.port_number; + return; + } + } } - return tensorflow::Status::OK(); + LOG(FATAL) << "Node " << (**node).name() << " not found in any engine."; } -tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) { - TF_RETURN_IF_ERROR(FillSubGraphEdgeSets(params)); - tensorflow::NodeDef trt_node_def; - - SubGraphParams s(params->graph, params->subgraph_node_ids, - params->subgraph_inputs, params->subgraph_outputs, - params->max_batch_size, params->max_workspace_size_bytes, - params->graph_properties, params->output_edge_map, - &trt_node_def, params->precision_mode, params->device_name_, - params->allocator_, params->cuda_gpu_id_); - TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(s)); - tensorflow::Status status; - tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status); - - // AddNode does not wire edges. - // Re-map incoming edges to use the new TRT node instead of the orig subgraph - std::map, int> subgraph_edge_to_input_map; - for (size_t i = 0; i < params->subgraph_inputs.size(); ++i) { - subgraph_edge_to_input_map.insert({params->subgraph_inputs.at(i), i}); +// Function to insert a TRT engine node into the graph. +// Create engine nodes in the following way: +// 1. Each invocation of CreateTRTNode creates an engine node for infos[pos] +// 2. When an engine node is created, add it into the graph with necessary +// re-wiring. +// 2.1. If the outside connected node is existing, connect the engine +// node to it. +// 2.2. If the outside connected node is gone, it must have been absorted +// into another engine node (which was processed before the processing +// one). Connect to the pre-existing engine node instead. +// 3. In this way, we ensure the graph is topologically sort-able after each +// invocation of CreateTRTNode(). +tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, + int max_batch_size, tensorflow::Graph* graph, + nvinfer1::IGpuAllocator* alloc, + std::vector* engine_nodes) { + const auto& info = infos.at(pos); + TRT_RETURN_IF_TEST_VALUE(StrCat(info.engine_name, ":CreateTRTNode"), "fail"); + std::vector output_shape_protos; + std::vector input_shape_protos; + std::vector input_shapes; + std::vector inputs; + std::vector input_nodes; + std::vector control_input_nodes; + std::unordered_set control_input_names; + std::vector out_types; + + VLOG(1) << "Processing " << info.engine_name; + // Collect needed info for creating the engine node in the graph + for (const auto& conn : info.connections) { + // Control edges + if (conn.is_control_edge()) { + // Skip control outputs for now. control output info are not needed for + // node creation and will be processed later. + if (!conn.is_input_edge) continue; + + // Rewrire control input if it's not found in original graph. + tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id); + int port = tensorflow::Graph::kControlSlot; + if (!input_node) { + UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true, + conn.outside_node_name, &input_node, &port); + QCHECK_EQ(Graph::kControlSlot, port); + } + if (!control_input_names.insert(input_node->name()).second) { + continue; + } + control_input_nodes.push_back(input_node); + VLOG(1) << "Engine Control Input " << input_node->name() << " -> " + << info.engine_name; + } else { + // Data edges + if (!conn.is_input_edge) { + // Set the shapes and data types of output edge. + tensorflow::TensorShapeProto out_shape; + // shape of the output node inside segment + conn.inside_shape.AsProto(&out_shape); + if (output_shape_protos.size() <= conn.port_number) { + output_shape_protos.resize(conn.port_number + 1); + out_types.resize(conn.port_number + 1); + } + output_shape_protos.at(conn.port_number) = out_shape; + out_types.at(conn.port_number) = conn.connection_type; + } else { + // Set the shapes and data types of input edge. + tensorflow::TensorShapeProto in_shape; + conn.outside_shape.AsProto(&in_shape); + if (input_shape_protos.size() <= conn.port_number) { + input_shape_protos.resize(conn.port_number + 1); + input_shapes.resize(conn.port_number + 1); + } + input_shape_protos.at(conn.port_number) = in_shape; + input_shapes.at(conn.port_number) = conn.outside_shape; + + // Rewrire data input if it's not found in original graph. + tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id); + int port = conn.outside_port; + if (!input_node) { + UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true, + conn.outside_node_name, &input_node, &port); + } + if (std::find_if( + std::begin(inputs), std::end(inputs), + [input_node, &port](const NodeDefBuilder::NodeOut& inp) { + return inp.node == input_node->name() && inp.index == port; + }) == std::end(inputs)) { + inputs.emplace_back(input_node->name(), port, conn.connection_type); + input_nodes.push_back(CHECK_NOTNULL(input_node)); + VLOG(1) << "Engine Input " << input_node->name() << ":" << port + << " -> " << info.engine_name << ":" << inputs.size() - 1; + } + } + } } - std::set> unique_tensors; - for (const tensorflow::Edge* edge : params->subgraph_incoming_edges) { - std::pair old_src = {edge->src()->id(), edge->src_output()}; - if (unique_tensors.count(old_src)) continue; - unique_tensors.insert(old_src); - int new_src_output = subgraph_edge_to_input_map.at(old_src); - params->graph.AddEdge(edge->src(), edge->src_output(), trt_node, - new_src_output); - VLOG(1) << "Wire " << edge->src()->name() << ":" << edge->src_output() - << " -> " << trt_node->name() << ":" << new_src_output; - params->graph.RemoveEdge(edge); + string segment_string; + if (info.engine_type == EngineInfo::EngineType::TRTStatic || + info.precision_mode == INT8MODE) { + // Create static engine for fp32/fp16 mode, and test validity of the engine + // for int8 mode. We don't want engine to fail at the calibration time. + // So we are constructing a FP32 engine here to check its validity, and if + // it is a valid engine then we put the serialized graphdef to the op. + // Otherwise we skip node creation for this engine. + Logger trt_logger; + TrtUniquePtrType engine; + // TODO(sami): What happens if 1st dim is not batch? + TF_RETURN_IF_ERROR(ConvertGraphDefToEngine( + info.segment_graph_def, + info.precision_mode == INT8MODE ? FP32MODE : info.precision_mode, + max_batch_size, info.max_workspace_size_bytes, input_shapes, + &trt_logger, alloc, /*calibrator=*/nullptr, &engine, + /*convert_successfully=*/nullptr)); + TrtUniquePtrType engine_data(engine->serialize()); + segment_string = + string((const char*)engine_data->data(), engine_data->size()); + if (info.precision_mode == INT8MODE) { + // See above comment about why not putting this inside the 'else' branch. + segment_string = info.segment_graph_def.SerializeAsString(); + } + } else { + segment_string = info.segment_graph_def.SerializeAsString(); + } + + // TODO(aaroey): use enum instead, and add a helper method to do the + // conversion. + string prec_string; + TF_RETURN_IF_ERROR(GetPrecisionModeName(info.precision_mode, &prec_string)); + if (info.precision_mode == INT8MODE && + !TRTResourceManager::instance()->getManager("TRTCalibration")) { + LOG(ERROR) << "Failed to construct calibration storage"; } - if (VLOG_IS_ON(2)) { - VLOG(2) << "new edge count: " << trt_node->in_edges().size(); - for (const tensorflow::Edge* edge : trt_node->in_edges()) { - VLOG(2) << edge->src()->name() << " port: " << edge->src_output(); + tensorflow::NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp"); + if (!info.device.empty()) node_builder.Device(info.device); + if (VLOG_IS_ON(1)) { + string ins = StrCat(info.engine_name, " inputs= "); + for (const auto& ii : inputs) { + StrAppend(&ins, ii.node, ":", ii.index, " "); } + VLOG(1) << ins; + } + node_builder.Input(inputs); + for (const string& c : control_input_names) { + node_builder.ControlInput(c); } - TF_RETURN_IF_ERROR(status); - // Re-map outgoing edges to use the new TRT node instead of the orig subgraph - std::map, int> subgraph_edge_to_output_map; - for (size_t i = 0; i < params->subgraph_outputs.size(); ++i) { - subgraph_edge_to_output_map.insert({params->subgraph_outputs.at(i), i}); + if (info.engine_type == EngineInfo::EngineType::TRTStatic && + info.cached_engine_batches.size()) { + LOG(WARNING) << "Cached engine batches are ignored for static engines"; } - TF_RETURN_IF_ERROR(status); - for (const tensorflow::Edge* edge : params->subgraph_outgoing_edges) { - std::pair old_src = {edge->src()->id(), edge->src_output()}; - int new_src_output = subgraph_edge_to_output_map.at(old_src); - TF_RETURN_IF_ERROR(params->graph.UpdateEdge( - trt_node, new_src_output, edge->dst(), edge->dst_input())); - VLOG(1) << "Wire " << trt_node->name() << ":" << new_src_output << " -> " - << edge->dst()->name() << ":" << edge->dst_input(); + tensorflow::NodeDef trt_node; + tensorflow::Status status = + node_builder.Attr("input_shapes", input_shape_protos) + .Attr("output_shapes", output_shape_protos) + .Attr("static_engine", + info.engine_type == EngineInfo::EngineType::TRTStatic) + .Attr("segment_funcdef_name", + StrCat(info.engine_name, "_native_segment")) + .Attr("serialized_segment", segment_string) + .Attr("calibration_data", "") + .Attr("max_cached_engines_count", info.maximum_cached_engines) + .Attr("cached_engine_batches", {max_batch_size}) + .Attr("workspace_size_bytes", info.max_workspace_size_bytes) + .Attr("precision_mode", prec_string) + .Attr("OutT", out_types) + .Finalize(&trt_node); + if (!status.ok()) { + LOG(ERROR) << "Node construction failed with" << status; + return status; } - // Remove the original subgraph - for (int node_id : params->subgraph_node_ids) { - tensorflow::Node* node = params->graph.FindNodeId(node_id); - // Don't remove the input placeholders - if (node->type_string() == "Placeholder") { + VLOG(1) << "Adding TRTEngine " << info.engine_name << " to graph"; + + // Up until this point, graph is not modified. If we return !status.ok() from + // here, this segment will be skipped + // TODO(aaroey): let it return proper error status for the following logic + // instead of checking fail. + tensorflow::Node* engine_node = graph->AddNode(trt_node, &status); + (*engine_nodes)[pos] = engine_node; + if (!status.ok()) { + LOG(ERROR) << "Adding node failed " << status; + return status; + } + // Add control input and input edges to the engine node. + for (const auto in : control_input_nodes) { + VLOG(1) << "Connecting control edge from " << in->name() << " to " + << engine_node->name(); + graph->AddControlEdge(in, engine_node); + } + VLOG(1) << "input_nodes size = " << input_nodes.size(); + for (int i = 0; i < input_nodes.size(); ++i) { + Node* n = CHECK_NOTNULL(input_nodes[i]); + const auto& in = inputs[i]; + VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index + << " to " << engine_node->name() << ":" << i; + graph->AddEdge(n, in.index, engine_node, i); + } + + // Updates the inputs of output edges destination nodes, and point them to the + // engine node. + for (auto& conn : info.connections) { + if (conn.is_input_edge) { continue; } - params->graph.RemoveNode(node); + tensorflow::Node* output_node = graph->FindNodeId(conn.outside_id); + int port = conn.outside_port; + if (!output_node) { + UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false, + conn.outside_node_name, &output_node, &port); + } + VLOG(1) << "Updating " << engine_node->name() << ":" << conn.port_number + << " to " << output_node->name() << ":" << port; + if (conn.is_control_edge()) { + QCHECK_EQ(Graph::kControlSlot, port); + graph->AddControlEdge(engine_node, output_node); + } else { + auto new_edge = + graph->AddEdge(engine_node, conn.port_number, output_node, port); + QCHECK(new_edge) << "Adding a new edge failed " << engine_node->name() + << ":" << conn.port_number << " -> " + << output_node->name() << ":" << conn.outside_port; + } } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status BuildNodeMap( - const tensorflow::Graph& graph, - std::unordered_map* node_map) { - for (auto* node : graph.op_nodes()) { - if (!node_map->insert({node->name(), node}).second) { - return tensorflow::errors::AlreadyExists( - "Node name is not unique in graph: " + node->name()); +// Function to construct a funcdef from the segment and add it to the graph. +tensorflow::Status RegisterSegmentFunctionToFunctionLibrary( + tensorflow::Graph* graph, const tensorflow::GraphDef& segment, + const string& name) { + tensorflow::Graph sgraph(graph->flib_def()); + tensorflow::GraphConstructorOptions gcopts; + TF_RETURN_IF_ERROR( + tensorflow::ConvertGraphDefToGraph(gcopts, segment, &sgraph)); + std::map io_nodes; + int num_inputs = 0; + for (auto n : sgraph.op_nodes()) { + if (tensorflow::str_util::StartsWith(n->name(), kInputPHName)) { + num_inputs++; + io_nodes.insert({n->name(), n}); + } else if (tensorflow::str_util::StartsWith(n->name(), kOutputPHName)) { + io_nodes.insert({n->name(), n}); } } - return tensorflow::Status::OK(); -} -} // namespace -tensorflow::Status ConvertCalibGraphToInferGraph( - const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph) { - VLOG(0) << "Starting Calib Conversion"; - tensorflow::Graph graph(tensorflow::OpRegistry::Global()); - TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( - tensorflow::GraphConstructorOptions(), graph_def, &graph)); - // get calib nodes - std::vector calib_nodes; - std::vector topo_order; - tensorflow::GetPostOrder(graph, &topo_order); - for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) { - auto node = *rit; - if (node->type_string() == "TRTCalibOp") { - VLOG(1) << "Found Calib Node " << node->name(); - calib_nodes.push_back(node); + for (int i = 0; i < num_inputs; ++i) { + auto name = StrCat(kInputPHName, i); + auto node = io_nodes[name]; + tensorflow::NodeDef nd; + tensorflow::NodeDefBuilder node_builder( + StrCat(name, "_Arg"), tensorflow::FunctionLibraryDefinition::kArgOp); + VLOG(1) << "Adding " << StrCat(name, "_Arg"); + TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0)) + .Attr("index", i) + .Finalize(&nd)); + tensorflow::Status s; + auto node_arg = sgraph.AddNode(nd, &s); + if (!s.ok()) { + LOG(ERROR) << "Couldn't add _Arg node for " << name; + } + for (auto edge : node->out_edges()) { + sgraph.AddEdge(node_arg, 0, edge->dst(), edge->dst_input()); + VLOG(1) << "Updating funcdef input " << node_arg->name() << ":" << 0 + << " - > " << edge->dst()->name() << ":" << edge->dst_input(); + if (!s.ok()) { + LOG(ERROR) << "Failed to update edge from " << node_arg->name() + << " to " << edge->dst()->name() << ":" << edge->dst_input(); + } + } + sgraph.RemoveNode(node); + } + + for (int i = 0; i < io_nodes.size() - num_inputs; ++i) { + auto name = StrCat(kOutputPHName, i); + auto node = io_nodes[name]; + tensorflow::NodeDef nd; + tensorflow::NodeDefBuilder node_builder( + StrCat(name, "_Ret"), tensorflow::FunctionLibraryDefinition::kRetOp); + auto edge = *(node->in_edges().begin()); + tensorflow::NodeDefBuilder::NodeOut nout( + edge->src()->name(), edge->src_output(), + edge->src()->output_type(edge->src_output())); + VLOG(1) << " input " << nout.node << ":" << nout.index + << " dtype=" << tensorflow::DataTypeString(nout.data_type); + // nvcc complains that Input() is + // ambiguous, so do not use Input({nout}). + node_builder.Input(nout); + TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0)) + .Attr("index", i) + .Finalize(&nd)); + if (VLOG_IS_ON(3)) { + VLOG(3) << nd.DebugString(); } + tensorflow::Status s; + auto node_ret = sgraph.AddNode(nd, &s); + if (!s.ok()) { + LOG(ERROR) << "Couldn't add _Ret node for " << name; + } + VLOG(1) << "Update edge from " << edge->src()->name() << ":" + << edge->src_output() << " - > " << node_ret->name() << ":" << 0; + sgraph.AddEdge(edge->src(), edge->src_output(), node_ret, 0); + s = sgraph.UpdateEdge(edge->src(), edge->src_output(), node_ret, 0); + if (!s.ok()) { + LOG(ERROR) << "Failed to update edge from " << edge->src()->name() << ":" + << edge->src_output() << " - > " << node_ret->name() << ":" + << 0; + } + sgraph.RemoveNode(node); } - VLOG(0) << "Num Calib nodes in graph= " << calib_nodes.size(); - if (calib_nodes.size() == 0) - return tensorflow::errors::FailedPrecondition( - "Graph doesn't contain any calibration nodes!." - " Please generate calibration graph and run calibration first"); - for (auto n : calib_nodes) { - TF_RETURN_IF_ERROR( - tensorrt::convert::ConvertCalibrationNodeToEngineNode(graph, n)); + tensorflow::FunctionDefLibrary fdeflib; + auto native_segment = fdeflib.add_function(); + TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef( + sgraph, StrCat(name, "_native_segment"), native_segment)); + if (VLOG_IS_ON(7)) { + VLOG(7) << name << " Function_Def "; + VLOG(7) << native_segment->DebugString(); } - graph.ToGraphDef(infer_graph); + VLOG(1) << "Adding funcdef to graphlib"; + TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdeflib)); return tensorflow::Status::OK(); } -tensorflow::Status ConvertGraphDefToTensorRT( - const tensorflow::GraphDef& graph_def, - const std::vector& output_names, size_t max_batch_size, - size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, - int precision_mode = FP32MODE, int minimum_segment_size = 3) { - // optimization pass - tensorflow::grappler::GrapplerItem item; - item.fetch = output_names; - item.graph = graph_def; - - tensorflow::DeviceProperties device_properties; - device_properties.set_type("GPU"); - device_properties.mutable_environment()->insert({"architecture", "6"}); - tensorflow::grappler::Cluster* cluster = - new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}}); - - // single machine - int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores(); - int num_gpus = tensorflow::grappler::GetNumAvailableGPUs(); - VLOG(2) << "cpu_cores: " << num_cpu_cores; - VLOG(2) << "gpus: " << num_gpus; - tensorflow::RewriterConfig rw_cfg; - tensorflow::grappler::MetaOptimizer meta_opt(nullptr, rw_cfg); - tensorflow::GraphDef gdef; - TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster, item, &gdef)); - item.graph = gdef; - - // AJ refactoring shape inference through grappler/GraphProperties. - tensorflow::grappler::GraphProperties static_graph_properties(item); - TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); - // Build full graph - - return ConvertAfterShapes(gdef, output_names, max_batch_size, - max_workspace_size_bytes, new_graph_def, - precision_mode, minimum_segment_size, - static_graph_properties, nullptr); +std::pair GetDeviceAndAllocator( + const ConversionParams& params, const EngineInfo& engine) { + int cuda_device_id = -1; + tensorflow::Allocator* dev_allocator = nullptr; + if (params.cluster) { + std::vector devices; + if (!engine.device.empty() && params.cluster->GetDeviceSet()) { + DeviceNameUtils::ParsedName parsed_name; + if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) && + parsed_name.has_id) { + params.cluster->GetDeviceSet()->FindMatchingDevices(parsed_name, + &devices); + } + } + if (!devices.empty()) { + if (devices.size() > 1) { + string msg = "Found multiple matching devices using name '"; + StrAppend(&msg, engine.device, "': "); + for (auto d : devices) StrAppend(&msg, d->name(), ", "); + StrAppend(&msg, ". Will get the allocator from first one."); + LOG(WARNING) << msg; + } + tensorflow::AllocatorAttributes alloc_attr; + cuda_device_id = devices[0]->tensorflow_gpu_device_info()->gpu_id; + dev_allocator = devices[0]->GetAllocator(alloc_attr); + VLOG(1) << "Using allocator " << dev_allocator->Name() + << " and cuda_device_id " << cuda_device_id; + } else { + LOG(WARNING) << "Cluster is set but device '" << engine.device + << "' is not found in the cluster"; + } + } + return std::make_pair(cuda_device_id, dev_allocator); } -tensorflow::Status ConvertAfterShapes( - const tensorflow::GraphDef& gdef, const std::vector& output_names, - size_t max_batch_size, size_t max_workspace_size_bytes, - tensorflow::GraphDef* new_graph_def, int precision_mode, - int minimum_segment_size, - const tensorflow::grappler::GraphProperties& graph_properties, - const tensorflow::grappler::Cluster* cluster) { - // Segment the graph into subgraphs that can be converted to TensorRT - tensorflow::tensorrt::segment::SegmentOptions segment_options; +// Entry function from optimization pass. +// TODO(aaeory): parameter should use pointer type. +tensorflow::Status ConvertAfterShapes(ConversionParams& params) { + // Convert graphdef to graph. tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(), - gdef.library()); + params.input_graph_def->library()); tensorflow::Graph graph(flib); TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( - tensorflow::GraphConstructorOptions(), gdef, &graph)); + tensorflow::GraphConstructorOptions(), *params.input_graph_def, &graph)); + // Segment the graph into subgraphs that can be converted to TensorRT + tensorflow::tensorrt::segment::SegmentOptions segment_options; // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT) - for (auto node : output_names) { + for (auto node : *(params.output_names)) { segment_options.exclude_node_list.insert(node); } - - // TODO(sami): this should be passed as a knob!!!! - segment_options.minimum_segment_size = minimum_segment_size; - tensorflow::tensorrt::segment::SegmentNodesVector segments; + segment_options.minimum_segment_size = params.minimum_segment_size; + tensorflow::tensorrt::segment::SegmentNodesVector initial_segments; TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph( - &graph, IsTensorRTCandidate, segment_options, &segments)); - if (segments.size() > 1) { - VLOG(0) << "MULTIPLE tensorrt candidate conversion: " << segments.size(); + &graph, IsTensorRTCandidate, InputEdgeValidator(*params.graph_properties), + OutputEdgeValidator(), segment_options, &initial_segments)); + if (initial_segments.size() > 1) { + VLOG(0) << "MULTIPLE tensorrt candidate conversion: " + << initial_segments.size(); } + + // Get the EngineInfo for each segment. std::unordered_map node_map; TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map)); - std::unordered_map> output_edge_map; - int count = 0; float total_num_nodes_in_segments = 0.; - for (auto s : segments) { - total_num_nodes_in_segments += s.first.size(); - } - // We create the map here since cluster may not be available in all cases. - std::map name_to_device_map; - if (cluster) { - // TODO(aaroey): consider using DeviceSet::FindDeviceByName(), as in a - // distributed environment, devices from different workers can have same - // short name. - for (const auto dm : cluster->GetDeviceSet()->devices()) { - name_to_device_map[dm->name()] = dm; + std::vector engine_segments; + engine_segments.reserve(initial_segments.size()); + std::vector reverse_topo_order; + tensorflow::GetPostOrder(graph, &reverse_topo_order); + size_t total_engine_bytes_size = 0; + std::vector engine_bytes_size; + tensorflow::tensorrt::segment::SegmentNodesVector converted_segments; + converted_segments.reserve(initial_segments.size()); + for (size_t t = 0; t < initial_segments.size(); t++) { + auto& curr_segment = initial_segments.at(t); + EngineInfo curr_engine; + Status status = + GetEngineInfo(&graph, *params.graph_properties, curr_segment.first, + node_map, reverse_topo_order, &curr_engine); + if (!status.ok()) { + LOG(WARNING) << "Failed to get engine info for segment " << t << ": " + << status; + continue; } - } - for (const auto& segment_nodes_and_device : segments) { - const std::set& subgraph_node_names = - segment_nodes_and_device.first; - std::set subgraph_node_ids; - size_t max_mem_per_engine = - max_workspace_size_bytes * - ((float)subgraph_node_names.size() / total_num_nodes_in_segments); - std::stringstream oss; - for (const string& node_name : subgraph_node_names) { - oss << " " << node_name; - subgraph_node_ids.insert(node_map.at(node_name)->id()); + curr_engine.precision_mode = params.precision_mode; + curr_engine.engine_type = + (params.is_dyn_op || params.precision_mode == INT8MODE + ? EngineInfo::EngineType::TRTDynamic + : EngineInfo::EngineType::TRTStatic); + curr_engine.cached_engine_batches = params.cached_engine_batches; + curr_engine.maximum_cached_engines = params.max_cached_engines; + StrAppend(&curr_engine.engine_name, "my_trt_op_", t); + status = RegisterSegmentFunctionToFunctionLibrary( + &graph, curr_engine.segment_graph_def, curr_engine.engine_name); + if (!status.ok()) { + LOG(WARNING) << "Failed to register segment graphdef as a function " << t + << ": " << status; + continue; } - VLOG(1) << "Subgraph nodes at device " << segment_nodes_and_device.second - << " : " << oss.str(); - auto target_device = - name_to_device_map.find(segment_nodes_and_device.second); - std::shared_ptr allocator(0); + engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong()); + total_engine_bytes_size += engine_bytes_size.back(); + total_num_nodes_in_segments += curr_segment.first.size(); + engine_segments.push_back(std::move(curr_engine)); + converted_segments.push_back(std::move(curr_segment)); + + if (VLOG_IS_ON(8)) { + string fname = curr_engine.engine_name; + StrAppend(&fname, ".pb"); + std::fstream f; + f.open(fname.c_str(), std::fstream::out | std::fstream::binary); + f << engine_segments.at(t).segment_graph_def.SerializeAsString(); + f.close(); + } + } + + // Create a TRT node for each segment using its EngineInfo. + int old_cuda_device = 0; + auto err = cudaGetDevice(&old_cuda_device); + if (err != cudaSuccess) { + LOG(ERROR) << "Couldn't get current device: " << cudaGetErrorString(err); + } + VLOG(1) << "Current cuda device is " << old_cuda_device; + std::vector engine_nodes; + engine_nodes.resize(engine_segments.size()); + for (int i = 0; i < engine_segments.size(); ++i) { + auto& engine = engine_segments.at(i); + // Partition the workspace size by the average of node ratio and segment + // graphdef size + engine.max_workspace_size_bytes = + params.max_workspace_size_bytes * + (engine_bytes_size.at(i) / total_engine_bytes_size + + converted_segments.at(i).first.size() / total_num_nodes_in_segments) / + 2.0; + // The allocator is used to build the engine. The build and the built engine + // will be destroyed after we get the serialized engine string, so it's fine + // to use unique_ptr here. + std::unique_ptr alloc; + auto device_alloc = GetDeviceAndAllocator(params, engine); int cuda_device_id = 0; - if (target_device != name_to_device_map.end()) { - tensorflow::TfGpuId tf_gpu_id(target_device->second->parsed_name().id); - CudaGpuId cuda_gpu_id; - Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id); - if (!s.ok()) { - LOG(ERROR) - << "Cuda device identification failed, using device 0. Error= " - << s; - } else { - cuda_device_id = cuda_gpu_id.value(); - } - tensorflow::GPUOptions gpuoptions; - // we need to us PM here since in python path there is no way to get to - // allocators - auto pm = tensorflow::ProcessState::singleton(); - // this should be instantiated by now - auto dev_allocator = pm->GetGPUAllocator(gpuoptions, tf_gpu_id, 1); - VLOG(1) << "Got an allocator for device tf_device=" << tf_gpu_id.value() - << " cuda device= " << cuda_device_id << " at " << dev_allocator; - allocator = std::make_shared(dev_allocator); - } else { // device unknown or not available - allocator = std::make_shared(); + if (device_alloc.first >= 0) { + cuda_device_id = device_alloc.first; + alloc.reset(new TRTDeviceAllocator(device_alloc.second)); + } else { + // Setting allocator as nullptr should get revert to the cudamalloc + LOG(WARNING) << "Can't identify the cuda device. Running on device 0 "; } - ConvertGraphParams p(graph, output_names, subgraph_node_ids, max_batch_size, - max_mem_per_engine, graph_properties, &output_edge_map, - precision_mode, segment_nodes_and_device.second, - allocator, cuda_device_id); - if (precision_mode == INT8MODE) { - tensorflow::Status status = GetCalibNode(&p); - if (status != tensorflow::Status::OK()) { - LOG(WARNING) << "subgraph conversion error for subgraph_index:" << count - << " due to: \"" << status.ToString() - << "\" SKIPPING......( " << subgraph_node_names.size() - << " nodes)"; + cudaSetDevice(cuda_device_id); + auto status = CreateTRTNode(engine_segments, i, params.max_batch_size, + &graph, alloc.get(), &engine_nodes); + // If status is ok, we successfully added the node to the graph and can + // remove segment ops. Otherwise graph is not modified. + const string msg = StrCat("Engine ", engine.engine_name, + " creation for segment ", i, ", composed of ", + converted_segments.at(i).first.size(), " nodes"); + if (status.ok()) { + LOG(INFO) << msg << " succeeded."; + for (auto node_name : converted_segments.at(i).first) { + graph.RemoveNode(node_map.at(node_name)); } } else { - tensorflow::Status status = ConvertSubGraphToTensorRT(&p); - if (status != tensorflow::Status::OK()) { - LOG(WARNING) << "subgraph conversion error for subgraph_index:" << count - << " due to: \"" << status.ToString() - << "\" SKIPPING......( " << subgraph_node_names.size() - << " nodes)"; - } + // Graph is not modified. + LOG(WARNING) << msg << " failed: " << status << ". Skipping..."; } - count++; } - graph.ToGraphDef(new_graph_def); + cudaSetDevice(old_cuda_device); + graph.ToGraphDef(params.output_graph_def); + VLOG(1) << "Returning from conversion"; return tensorflow::Status::OK(); } diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h index 65a67d7e73e32f904bd636a4f4aaefe32b0c092d..9d986e489043c0a0e16e379166aa2e8f7ac0b11f 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.h +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h @@ -30,29 +30,60 @@ namespace tensorflow { namespace tensorrt { namespace convert { -// This method converts an already generated calibration graph which was used in -// calibration runs to an inference graph +struct ConversionParams { + ConversionParams() + : input_graph_def(nullptr), + max_batch_size(1), + max_workspace_size_bytes(1 << 30), + output_graph_def(nullptr), + precision_mode(1), + minimum_segment_size(3), + graph_properties(nullptr), + cluster(nullptr), + is_dyn_op(false), + fixed_input_size(true), + max_cached_engines(1) {} + const tensorflow::GraphDef* input_graph_def; + const std::vector* output_names; + size_t max_batch_size; + size_t max_workspace_size_bytes; + tensorflow::GraphDef* output_graph_def; + int precision_mode; + int minimum_segment_size; + const tensorflow::grappler::GraphProperties* graph_properties; + const tensorflow::grappler::Cluster* cluster; + bool is_dyn_op; // Whether to create engine on conversion or execution time + bool fixed_input_size; // Assume non-batch ranks of input tensors are fixed + int max_cached_engines; // maximum number of cached engines + std::vector cached_engine_batches; // list of cached engines +}; + +// This method extracts calibration information from the resource managers +// and puts them in to engine nodedefs. tensorflow::Status ConvertCalibGraphToInferGraph( - const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def); + const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def, + bool is_dyn_op); -// max_batch_size: maximum batch size which can be used for inference for -// optimization targets inference run with max batch size. -// max_workspace_size_bytes: The upper bound of memory allowance for -// engine building. +// - max_batch_size: maximum batch size which can be used for inference for +// optimization targets inference run with max batch size. +// - max_workspace_size_bytes: The upper bound of memory allowance for engine +// building. tensorflow::Status ConvertGraphDefToTensorRT( const tensorflow::GraphDef& graph_def, const std::vector& output_names, size_t max_batch_size, size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, - int precision_mode, int minimum_segment_size); + int precision_mode = 1, int minimum_segment_size = 3, + bool is_dyn_op = false, int max_cached_engines = 1, + std::vector cached_engine_batches = {}); // Method to call from optimization pass -tensorflow::Status ConvertAfterShapes( - const tensorflow::GraphDef& graph, const std::vector& output_names, - size_t max_batch_size, size_t max_workspace_size_bytes, - tensorflow::GraphDef* new_graph_def, int precision_mode, - int minimum_segment_size, - const tensorflow::grappler::GraphProperties& graph_properties, - const tensorflow::grappler::Cluster* cluster); +tensorflow::Status ConvertAfterShapes(ConversionParams& params); + +// Return compile time TensorRT library version information. +std::vector GetLinkedTensorRTVersion(); + +// Return runtime time TensorRT library version information. +std::vector GetLoadedTensorRTVersion(); } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 4e4d295538edadd26a347a38ec141737f097f26f..35fa590254137d62fea868882d5c225848829ca1 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -14,18 +14,21 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include +#include #include #include #include #include #include +#include #include #include +#include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" #include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT @@ -37,6 +40,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -47,15 +51,37 @@ limitations under the License. #if GOOGLE_TENSORRT #include "tensorrt/include/NvInfer.h" -// Check if the types are equal. Cast to int first so that failure log message -// would work! -#define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2) +// Check if the types are equal. Cast to int first so that failure log message +// would work! +#define TFTRT_CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2) + +#define TFTRT_INTERNAL_ERROR_AT_NODE(node) \ + do { \ + return tensorflow::errors::Internal( \ + "TFTRT::", __FUNCTION__, "failed to add TRT layer, at: ", node); \ + } while (0) + +#define TFTRT_RETURN_ERROR_IF_FALSE(status, node) \ + do { \ + if (status == false) { \ + TFTRT_INTERNAL_ERROR_AT_NODE(node); \ + } \ + } while (0) + +#define TFTRT_RETURN_ERROR_IF_NULLPTR(ptr, node) \ + do { \ + if (ptr == nullptr) { \ + TFTRT_INTERNAL_ERROR_AT_NODE(node); \ + } \ + } while (0) namespace tensorflow { namespace tensorrt { namespace convert { +using ::tensorflow::str_util::Split; using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; + namespace { inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, @@ -70,13 +96,163 @@ inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, case tensorflow::DataType::DT_HALF: *trt_dtype = nvinfer1::DataType::kHALF; break; +#if NV_TENSORRT_MAJOR > 3 + case tensorflow::DataType::DT_INT32: + *trt_dtype = nvinfer1::DataType::kINT32; + break; +#endif default: return tensorflow::errors::InvalidArgument( - "Unsupported data type " + tensorflow::DataTypeString(tf_dtype)); + "Unsupported data type ", tensorflow::DataTypeString(tf_dtype)); } return tensorflow::Status::OK(); } +void GetInputProperties(const grappler::GraphProperties& graph_properties, + const Node* outside_node, const int out_port, + PartialTensorShape* shape, + tensorflow::DataType* dtype) { + if (graph_properties.HasOutputProperties(outside_node->name())) { + auto output_params = + graph_properties.GetOutputProperties(outside_node->name()); + auto out_shape = output_params.at(out_port); + *dtype = out_shape.dtype(); + *shape = out_shape.shape(); + } else { + VLOG(0) << "Unknown output shape" << outside_node->name(); + *dtype = outside_node->output_type(out_port); + } +} + +void GetOutputProperties(const grappler::GraphProperties& graph_properties, + const Node* outside_node, const int in_port, + PartialTensorShape* shape, + tensorflow::DataType* dtype) { + if (graph_properties.HasInputProperties(outside_node->name())) { + auto input_params = + graph_properties.GetInputProperties(outside_node->name()); + auto in_shape = input_params.at(in_port); + *dtype = in_shape.dtype(); + *shape = in_shape.shape(); + } else { + *dtype = outside_node->input_type(in_port); + } +} + +tensorflow::Status ValidateInputProperties(const PartialTensorShape& shape, + const tensorflow::DataType dtype, + nvinfer1::DataType* trt_dtype) { + // TODO(aaroey): some of these checks also apply to IsTensorRTCandidate(), so + // put them there instead. + TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype)); + if (shape.dims() < 0) { + return tensorflow::errors::InvalidArgument("Input tensor rank is unknown."); + } + if (shape.dims() > 9) { + return tensorflow::errors::OutOfRange( + "Input tensor rank is greater than 8."); + } + for (int d = 1; d < shape.dims(); ++d) { + if (shape.dim_size(d) < 0) { + return tensorflow::errors::InvalidArgument( + "Input tensor has a unknown non-batch dimemension at dim ", d); + } + } + return Status::OK(); +} + +// Return whether or not the broadcast is feasible; +bool TensorRTGetBroadcastShape(const nvinfer1::Dims& operand_l, + const bool operand_l_is_tensor, + const nvinfer1::Dims& operand_r, + const bool operand_r_is_tensor, + nvinfer1::Dims* operand_l_new_shape, + nvinfer1::Dims* operand_r_new_shape) { + // *************************************************************************** + // TensorRT Elementwise op supports broadcast but requires both tensor to be + // of Identical rank + // + // We consider case of: + // 1. operand_l to be a Tensor & operand_r to be a Const; + // 2. operand_l to be a Tensor & operand_r to be a Tensor; + // note: const op const (constant folding) should fallback to TensorFlow + // + // broadcast scheme: + // T: 1 3 5 (tensor would not have batch dimension) + // W: 1 1 3 1 (weight would have all explicit dimensions) + // i. fill in explicit dimensions + // -> T: -1 1 3 5 (we put a -1 for batch dimension) + // -> W: 1 1 3 1 + // ii. compare broadcast feasibility + // + // We cannot support the following since TensorRT does not allow manipulation + // on batch dimension, we cannot generate output with proper shape + // T: 3 5 1 + // W: 1 1 1 1 3 5 1 + // -> T: 1 1 1 -1 3 5 1 + // -> W: 1 1 1 1 3 5 1 + // *************************************************************************** + const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1; + const size_t element_size = sizeof(operand_l.d[0]); + + // fill in dimensions + int l_s[max_nb_dims]; + std::fill(l_s, l_s + max_nb_dims, 1); + int l_d = operand_l_is_tensor ? operand_l.nbDims + 1 : operand_l.nbDims; + int r_s[max_nb_dims]; + std::fill(r_s, r_s + max_nb_dims, 1); + int r_d = operand_r_is_tensor ? operand_r.nbDims + 1 : operand_r.nbDims; + + int max_d = std::max(l_d, r_d); + std::memcpy(l_s + max_d - operand_l.nbDims, operand_l.d, + operand_l.nbDims * element_size); + std::memcpy(r_s + max_d - operand_r.nbDims, operand_r.d, + operand_r.nbDims * element_size); + + // set -1 for batch dimension, since batch size is not supposed to be + // broadcasted + if (operand_l_is_tensor) { + if (max_d != l_d) { // if broadcast beyond batch dimension, fail + return false; + } + l_s[0] = -1; + } + if (operand_r_is_tensor) { + if (max_d != r_d) { // if broadcast beyond batch dimension, fail + return false; + } + r_s[0] = -1; + } + + // compare broadcast feasibility + for (int i = max_d - 1; i >= 0; i--) { + if ((l_s[i] != r_s[i]) && (l_s[i] != 1) && (r_s[i] != 1)) { + return false; + } + } + + // output new TensorRT Dimension (stripping the batch dimension) + operand_l_new_shape->nbDims = max_d - 1; + std::memcpy(operand_l_new_shape->d, l_s + 1, (max_d - 1) * element_size); + operand_r_new_shape->nbDims = max_d - 1; + std::memcpy(operand_r_new_shape->d, r_s + 1, (max_d - 1) * element_size); + + return true; +} + +inline bool DimsEqual(const nvinfer1::Dims& dim_l, + const nvinfer1::Dims& dim_r) { + if (dim_l.nbDims != dim_r.nbDims) { + return false; + } + for (int i = 0; i < dim_l.nbDims; i++) { + if (dim_l.d[i] != dim_r.d[i]) { + return false; + } + } + return true; +} + inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) { nvinfer1::Dims dims; dims.nbDims = tensor.dims(); @@ -86,7 +262,7 @@ inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) { return dims; } -inline int64_t GetShapeSize(nvinfer1::Dims shape) { +inline int64_t GetShapeSize(const nvinfer1::Dims& shape) { // Returns total number of elements in shape int64_t count = 1; for (int d = 0; d < shape.nbDims; ++d) { @@ -99,7 +275,7 @@ static std::vector> CreateSamePadding( const nvinfer1::DimsHW& stride, const nvinfer1::DimsHW& kernel, const std::vector& input_dims) { std::vector> padding(input_dims.size()); - CHECK_EQ((size_t)stride.nbDims, input_dims.size()); // TODO(jie): N+C? NC+? + CHECK_EQ(stride.nbDims, input_dims.size()); // TODO(jie): N+C? NC+? for (size_t i = 0; i < input_dims.size(); ++i) { // Formula to calculate the padding @@ -121,16 +297,15 @@ static std::vector> CreateSamePadding( string GetCommonNameScope(const string& op_name_a, const string& op_name_b) { size_t last_scope_separator = 0; - for (size_t i = 0; i < std::min(op_name_a.size(), op_name_b.size()); ++i) { - if (op_name_a[i] != op_name_b[i]) { - break; - } else if (op_name_a[i] == '/') { - last_scope_separator = i + 1; - } + const size_t min_size = std::min(op_name_a.size(), op_name_b.size()); + for (size_t i = 0; i < min_size; ++i) { + if (op_name_a[i] != op_name_b[i]) break; + if (op_name_a[i] == '/') last_scope_separator = i + 1; } return op_name_a.substr(0, last_scope_separator); } +// Class to convert TF weight to TRT weight. class TRT_ShapedWeights { public: TRT_ShapedWeights(tensorflow::DataType type, const void* values, @@ -142,12 +317,14 @@ class TRT_ShapedWeights { explicit TRT_ShapedWeights(tensorflow::DataType type) : shape_(), type_(type), values_(nullptr), empty_weight_flag_(true) {} + // TODO(aaroey): use rvalue reference. TRT_ShapedWeights(const TRT_ShapedWeights& rhs) : shape_(rhs.shape_), type_(rhs.type_), values_(rhs.values_), empty_weight_flag_(rhs.empty_weight_flag_) {} + // TODO(aaroey): use GetShapeSize() instead. int64_t count() const { int64_t c = 1; for (int i = 0; i < shape_.nbDims; i++) c *= shape_.d[i]; @@ -165,6 +342,7 @@ class TRT_ShapedWeights { const void* GetValues() const { return values_; } + // TODO(aaroey): get rid of this method. void SetValues(const void* values) { values_ = values; } size_t size_bytes() const { @@ -175,10 +353,12 @@ class TRT_ShapedWeights { // Default converter operator nvinfer1::Weights() const { return GetWeightsForTRT(); } + // TODO(aaroey): make these private. nvinfer1::Dims shape_; tensorflow::DataType type_; private: + // TODO(aaroey): this should not be const as it's always from TRTWeightStore. const void* values_; bool empty_weight_flag_; }; @@ -189,6 +369,7 @@ class TRT_TensorOrWeights { : tensor_(tensor), weights_(DT_FLOAT), variant_(TRT_NODE_TENSOR) {} explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights) : tensor_(nullptr), weights_(weights), variant_(TRT_NODE_WEIGHTS) {} + // TODO(aaroey): use rvalue reference. TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs) : tensor_(rhs.tensor_), weights_(rhs.weights_), variant_(rhs.variant_) {} ~TRT_TensorOrWeights() {} @@ -197,19 +378,19 @@ class TRT_TensorOrWeights { bool is_weights() const { return variant_ == TRT_NODE_WEIGHTS; } nvinfer1::ITensor* tensor() { - CHECK_EQ(is_tensor(), true); + CHECK(is_tensor()); return tensor_; } const nvinfer1::ITensor* tensor() const { - CHECK_EQ(is_tensor(), true); + CHECK(is_tensor()); return tensor_; } TRT_ShapedWeights& weights() { - CHECK_EQ(is_weights(), true); + CHECK(is_weights()); return weights_; } const TRT_ShapedWeights& weights() const { - CHECK_EQ(is_weights(), true); + CHECK(is_weights()); return weights_; } nvinfer1::Dims shape() const { @@ -233,21 +414,25 @@ class TFAttrs { attrs_.insert({attr.first, &attr.second}); } } - bool count(string key) const { return attrs_.count(key); } - tensorflow::AttrValue const* at(string key) const { + + bool count(const string& key) const { return attrs_.count(key); } + + tensorflow::AttrValue const* at(const string& key) const { if (!attrs_.count(key)) { LOG(FATAL) << "Attribute not found: " << key; } return attrs_.at(key); } + template T get(const string& key) const; + template T get(const string& key, const T& default_value) const { return attrs_.count(key) ? this->get(key) : default_value; } - std::vector GetAllAttrKey() { + std::vector GetAllAttrKeys() const { std::vector attr_list; for (const auto& attr_item : attrs_) { attr_list.emplace_back(attr_item.first); @@ -282,15 +467,6 @@ std::vector TFAttrs::get>(const string& key) const { auto attr = this->at(key)->list().s(); return std::vector(attr.begin(), attr.end()); } -template <> -nvinfer1::Dims TFAttrs::get(const string& key) const { - auto values = this->get>(key); - nvinfer1::Dims dims; - dims.nbDims = values.size(); - std::copy(values.begin(), values.end(), dims.d); - // Note: No dimension type information is included - return dims; -} template <> nvinfer1::DataType TFAttrs::get(const string& key) const { @@ -316,10 +492,11 @@ bool TFAttrs::get(const string& key) const { } // TODO(jie): reorder4 & reorder2 should be merged? +// TODO(aaroey): fix the order of parameters. template -void Reorder4(nvinfer1::DimsNCHW shape, const T* idata, - nvinfer1::DimsNCHW istrides, T* odata, - nvinfer1::DimsNCHW ostrides) { +void Reorder4(const nvinfer1::DimsNCHW& shape, const T* idata, + const nvinfer1::DimsNCHW& istrides, T* odata, + const nvinfer1::DimsNCHW& ostrides) { for (int n = 0; n < shape.n(); ++n) { for (int c = 0; c < shape.c(); ++c) { for (int h = 0; h < shape.h(); ++h) { @@ -334,12 +511,13 @@ void Reorder4(nvinfer1::DimsNCHW shape, const T* idata, } template -void Reorder2(nvinfer1::DimsHW shape, const T* idata, nvinfer1::DimsHW istrides, - T* odata, nvinfer1::DimsHW ostrides) { +void Reorder2(const nvinfer1::DimsHW& shape, const T* idata, + const nvinfer1::DimsHW& istrides, T* odata, + const nvinfer1::DimsHW& ostrides) { for (int h = 0; h < shape.h(); ++h) { for (int w = 0; w < shape.w(); ++w) { odata[h * ostrides.h() + w * ostrides.w()] = - idata[h * ostrides.h() + w * ostrides.w()]; + idata[h * istrides.h() + w * istrides.w()]; } } } @@ -347,16 +525,17 @@ void Reorder2(nvinfer1::DimsHW shape, const T* idata, nvinfer1::DimsHW istrides, // TODO(jie): fallback to tensorflow!! void ReorderCKtoKC(const TRT_ShapedWeights& iweights, TRT_ShapedWeights* oweights) { - int c = iweights.shape_.d[0]; - int k = iweights.shape_.d[1]; + const int c = iweights.shape_.d[0]; + const int k = iweights.shape_.d[1]; oweights->shape_.d[0] = k; oweights->shape_.d[1] = c; - nvinfer1::DimsHW istrides = {1, k}; - nvinfer1::DimsHW ostrides = {c, 1}; + const nvinfer1::DimsHW istrides = {1, k}; + const nvinfer1::DimsHW ostrides = {c, 1}; switch (iweights.type_) { case tensorflow::DataType::DT_FLOAT: { Reorder2({k, c}, static_cast(iweights.GetValues()), istrides, + // TODO(aaroey): get rid of all the const_cast like this. static_cast(const_cast(oweights->GetValues())), ostrides); break; @@ -379,21 +558,24 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, TRT_ShapedWeights* oweights, int num_groups) { CHECK_EQ(iweights.type_, oweights->type_); CHECK_EQ(iweights.size_bytes(), oweights->size_bytes()); - int r = iweights.shape_.d[0]; - int s = iweights.shape_.d[1]; - // TRT requires GKcRS, while TF depthwise has RSCK - // where c=1, C=G + // K indexes over output channels, C over input channels, and R and S over the + // height and width of the convolution + const int r = iweights.shape_.d[0]; + const int s = iweights.shape_.d[1]; + // TRT requires GKcRS, while TF depthwise has RSCK where c=1, C=G VLOG(2) << "num_groups: " << num_groups; - int c = iweights.shape_.d[2] / num_groups; + const int c = iweights.shape_.d[2] / num_groups; VLOG(2) << "c" << iweights.shape_.d[2] << " then " << c; - int k = iweights.shape_.d[3] * num_groups; + const int k = iweights.shape_.d[3] * num_groups; VLOG(2) << "k" << iweights.shape_.d[3] << " then " << k; + VLOG(2) << "r" << iweights.shape_.d[0] << " then " << r; + VLOG(2) << "s" << iweights.shape_.d[1] << " then " << s; oweights->shape_.d[0] = k / num_groups; oweights->shape_.d[1] = c * num_groups; oweights->shape_.d[2] = r; oweights->shape_.d[3] = s; - nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k}; - nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1}; + const nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k}; + const nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1}; switch (iweights.type_) { case tensorflow::DataType::DT_FLOAT: { Reorder4({k, c, r, s}, static_cast(iweights.GetValues()), @@ -417,20 +599,6 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, } } -struct InferDeleter { - template - void operator()(T* obj) const { - if (obj) { - obj->destroy(); - } - } -}; - -template -inline std::shared_ptr infer_object(T* obj) { - return std::shared_ptr(obj, InferDeleter()); -} - class Converter; using OpConverter = @@ -439,19 +607,22 @@ using OpConverter = std::vector*)>; class Converter { + // TODO(aaroey): fix the order of members. std::unordered_map trt_tensors_; std::unordered_map op_registry_; OpConverter plugin_converter_; nvinfer1::INetworkDefinition* trt_network_; std::list> temp_bufs_; - tensorflow::tensorrt::TRTWeightStore* weight_store_; + // TODO(aaroey): inline the definition of TRTWeightStore here, and add APIs to + // operate the stored weights instead of operating it directly. + TRTWeightStore* weight_store_; bool fp16_; void register_op_converters(); tensorflow::Status get_inputs(const tensorflow::NodeDef& node_def, std::vector* inputs) { for (auto const& input_name : node_def.input()) { /************************************************************************* - * TODO(jie) handle case 1) here + * TODO(jie): handle case 1) here. * Normalizes the inputs and extracts associated metadata: * 1) Inputs can contain a colon followed by a suffix of characters. * That suffix may be a single number (e.g. inputName:1) or several @@ -465,6 +636,7 @@ class Converter { if (input_name[0] == '^') continue; string name = input_name; auto first = name.find_first_of(':'); + // TODO(aaroey): why removing the colon but not the zero? A bug? if (first != string::npos && first + 2 == name.size() && name[first + 1] == '0') name.erase(first); @@ -473,12 +645,13 @@ class Converter { if (trt_tensors_.count(name)) { inputs->push_back(trt_tensors_.at(name)); } else { - string str("Node "); - StrAppend(&str, node_def.name(), " should have an input named '", name, + // TODO(aaroey): this should not happen, make it a CHECK. + // TODO(aaroey): use StrCat for pattern like this. + string msg("Node "); + StrAppend(&msg, node_def.name(), " should have an input named '", name, "' but it is not available"); - LOG(WARNING) << "input: " << name << " not available for node at " - << node_def.name(); - return tensorflow::errors::InvalidArgument(str); + LOG(ERROR) << msg; + return tensorflow::errors::InvalidArgument(msg); } } return tensorflow::Status::OK(); @@ -486,11 +659,11 @@ class Converter { public: explicit Converter(nvinfer1::INetworkDefinition* trt_network, - tensorflow::tensorrt::TRTWeightStore* ws, bool fp16) + TRTWeightStore* ws, bool fp16) : trt_network_(trt_network), weight_store_(ws), fp16_(fp16) { this->register_op_converters(); } - tensorflow::tensorrt::TRTWeightStore* weight_store() { return weight_store_; } + TRTWeightStore* weight_store() { return weight_store_; } TRT_ShapedWeights get_temp_weights(tensorflow::DataType type, nvinfer1::Dims shape) { TRT_ShapedWeights weights(type, nullptr, shape); @@ -499,6 +672,7 @@ class Converter { weights.SetValues(weight_store_->store_.back().data()); return weights; } + // TODO(aaroey): fix all the namings. bool isFP16() { return fp16_; } TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) { return this->get_temp_weights(weights.type_, weights.shape_); @@ -507,9 +681,10 @@ class Converter { tensorflow::Status convert_node(const tensorflow::NodeDef& node_def) { std::vector inputs; TF_RETURN_IF_ERROR(this->get_inputs(node_def, &inputs)); - string op = node_def.op(); + const string& op = node_def.op(); std::vector outputs; if (PluginFactoryTensorRT::GetInstance()->IsPlugin(op)) { + // TODO(aaroey): plugin_converter_ is not set, fix it. TF_RETURN_IF_ERROR(plugin_converter_(*this, node_def, inputs, &outputs)); } else { if (!op_registry_.count(op)) { @@ -520,7 +695,7 @@ class Converter { TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs)); } for (size_t i = 0; i < outputs.size(); ++i) { - TRT_TensorOrWeights output = outputs.at(i); + TRT_TensorOrWeights& output = outputs[i]; // TODO(jie): tf protobuf seems to be omitting the :0 suffix string output_name = node_def.name(); if (i != 0) output_name = StrCat(output_name, ":", i); @@ -538,26 +713,29 @@ class Converter { nvinfer1::INetworkDefinition* network() { return trt_network_; } - TRT_TensorOrWeights get_tensor(string name) { + TRT_TensorOrWeights get_tensor(const string& name) { if (!trt_tensors_.count(name)) { return TRT_TensorOrWeights(nullptr); } return trt_tensors_.at(name); } - bool insert_input_tensor(string name, nvinfer1::ITensor* tensor) { + bool insert_input_tensor(const string& name, nvinfer1::ITensor* tensor) { return trt_tensors_.insert({name, TRT_TensorOrWeights(tensor)}).second; } nvinfer1::ITensor* TransposeTensor(nvinfer1::ITensor* input_tensor, - std::vector order) { - auto dims = input_tensor->getDimensions(); + const std::vector& order) { + const auto dims = input_tensor->getDimensions(); // TODO(jie): change the return to status and properly exit if (order.size() - 1 != size_t(dims.nbDims)) LOG(ERROR) << "Dimension does not match, fail gracefully"; nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor); + if (layer == nullptr) { + return nullptr; + } nvinfer1::Permutation permutation; for (int32_t i = 0; i < dims.nbDims; ++i) { permutation.order[i] = order[i + 1] - 1; @@ -588,13 +766,14 @@ TRT_ShapedWeights ConvertFP32ToFP16(Converter& ctx, } return weights; } + // **************************************************************************** // Constant folding functions // TODO(jie): once optimizer kicks in, we should have done constant folding // there. -//*****************************************************************************/ +// ***************************************************************************** struct LambdaFactory { - enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB }; + enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB, RECIP }; OP_CATEGORY op; template @@ -606,6 +785,8 @@ struct LambdaFactory { } case OP_CATEGORY::NEG: return [](T t) -> T { return -t; }; + case OP_CATEGORY::RECIP: + return [](T t) -> T { return 1.0 / t; }; default: VLOG(2) << "Not supported op for unary: " << static_cast(op); return nullptr; @@ -639,7 +820,6 @@ struct LambdaFactory { VLOG(2) << "LAMBDA VAL : " << val; return l + val; }; - // Return [val](T l)-> T {return l+val;}; case OP_CATEGORY::SUB: return [val](T l) -> T { VLOG(2) << "LAMBDA VAL : " << val; @@ -699,11 +879,13 @@ std::function LambdaFactory::unary() { } case OP_CATEGORY::NEG: return [](Eigen::half t) -> Eigen::half { return -t; }; + // TODO(aaroey): can we support RECIP? default: VLOG(2) << "Not supported op for unary: " << static_cast(op); return nullptr; } } + tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights, TRT_ShapedWeights* oweights, LambdaFactory unary_op) { @@ -749,6 +931,7 @@ tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l, if (iweights_l.count() != iweights_r.count()) { // We only supports broadcast of RankZero if (iweights_l.count() == 1) { + // TODO(aaroey): Remove loggings like this. VLOG(2) << "I bet it is not working!" << (*inp_l); std::transform(inp_r, inp_r + iweights_r.count(), oup, binary_op.broadcast_l(*inp_l)); @@ -801,117 +984,21 @@ tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l, return tensorflow::Status::OK(); } -tensorflow::Status ConstantFoldUnary( - Converter& ctx, const tensorflow::NodeDef& node_def, - const std::vector& inputs, - std::vector* outputs) { - TRT_ShapedWeights weights_input = inputs.at(0).weights(); - - // Allocate output weights - TRT_ShapedWeights weights_output = ctx.get_temp_weights_like(weights_input); - - // FIXME assume type matches input weights - // Get trt type & shape - // Maybe this part has to be moved into the block of rsqrt later - // Check type consistency - CHECK_EQ(weights_input.type_, - TFAttrs(node_def).get("T")); - - LambdaFactory unary_op; - if (node_def.op() == "Rsqrt") { - // Compute rsqrt - unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT; - auto ret = UnaryCompute(weights_input, &weights_output, unary_op); - // Pass the output - if (ret == tensorflow::Status::OK()) { - outputs->push_back(TRT_TensorOrWeights(weights_output)); - } - return ret; - } else { - return tensorflow::errors::Unimplemented("Binary op not supported: " + - node_def.op()); - } -} - -// TODO(jie,ben) broadcast is needed yet not implemented -// Let's get the simple stuff working first. Maybe we should fall back to TF -// approach for constant folding -tensorflow::Status ConstantFoldBinary( - Converter& ctx, const tensorflow::NodeDef& node_def, - const std::vector& inputs, - std::vector* outputs) { - TRT_ShapedWeights weights_input_l = inputs.at(0).weights(); - TRT_ShapedWeights weights_input_r = inputs.at(1).weights(); - - // Check type consistency - CHECK_EQ(weights_input_l.type_, weights_input_r.type_); - - if (weights_input_l.shape_.nbDims != weights_input_r.shape_.nbDims) - return tensorflow::errors::Unimplemented( - "Binary op implicit broadcast not supported: " + node_def.op()); - - // TODO(jie): constant fold should really fall back to TF. - int num_dims = weights_input_l.shape_.nbDims; - nvinfer1::Dims output_shape; - output_shape.nbDims = num_dims; - VLOG(2) << "nb_dims: " << num_dims - << ", the other: " << weights_input_r.shape_.nbDims; - for (int i = 0; i < num_dims; i++) { - if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) { - output_shape.d[i] = weights_input_l.shape_.d[i]; - } else if (weights_input_l.shape_.d[i] == 1 || - weights_input_r.shape_.d[i] == 1) { - output_shape.d[i] = - std::max(weights_input_l.shape_.d[i], weights_input_r.shape_.d[i]); - } else { - return tensorflow::errors::Unimplemented( - "Binary op with incompatible shape at, " + node_def.op()); - } - VLOG(2) << "left: " << weights_input_l.shape_.d[i] - << "right: " << weights_input_r.shape_.d[i] - << "output: " << output_shape.d[i]; - } - - // FIXME assume type matches input weights - // Get trt type & shape - TFAttrs attrs(node_def); - // Maybe this part has to be moved into the block of rsqrt later - tensorflow::DataType dtype = attrs.get("T"); - - // Allocate output weights - TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape); - - LambdaFactory binary_op; - if (node_def.op() == "Sub") { - binary_op.op = LambdaFactory::OP_CATEGORY::SUB; - } else if (node_def.op() == "Mul") { - binary_op.op = LambdaFactory::OP_CATEGORY::MUL; - } else if (node_def.op() == "Add") { - binary_op.op = LambdaFactory::OP_CATEGORY::ADD; - } else { - return tensorflow::errors::Unimplemented("Binary op not supported: " + - node_def.op()); - } - auto ret = BinaryCompute(weights_input_l, weights_input_r, &weights_output, - binary_op); - - // Pass the output - if (ret == tensorflow::Status::OK()) { - outputs->push_back(TRT_TensorOrWeights(weights_output)); - } - - return ret; -} - // TODO(jie): broadcast is needed yet not implemented. // Only implemented channel wise for the time being tensorflow::Status BinaryTensorOpWeight( Converter& ctx, const tensorflow::NodeDef& node_def, const nvinfer1::ITensor* tensor, TRT_ShapedWeights weights, - std::vector* outputs) { - // FIXME assume type matches input weights - // Get trt type & shape - // Maybe this part has to be moved into the block of rsqrt later + bool swapped_inputs, std::vector* outputs) { + // tensor is the left operand while weights is the right operand; + // when swapped_inputs set to true, those two are swapped. + // TODO(aaroey): use a set. + if (node_def.op() != "Sub" && node_def.op() != "Add" && + node_def.op() != "Mul" && node_def.op() != "Div" && + node_def.op() != "RealDiv") { + return tensorflow::errors::Unimplemented( + "op not supported: " + node_def.op() + ", at: " + node_def.name()); + } // Check type consistency nvinfer1::DataType ttype; @@ -921,6 +1008,12 @@ tensorflow::Status BinaryTensorOpWeight( auto dims_w = weights.shape_; auto dims_t = tensor->getDimensions(); + // TODO(jie): addScale checks for input tensor dimension + if (dims_t.nbDims != 3) { + return tensorflow::errors::InvalidArgument( + "addScale requires tensor with rank 3, " + node_def.name()); + } + // default to element-wise auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; @@ -991,6 +1084,7 @@ tensorflow::Status BinaryTensorOpWeight( permutation[dims_t.nbDims] = 1; tensor = ctx.TransposeTensor(const_cast(tensor), permutation); + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name()); } else { return tensorflow::errors::InvalidArgument( "Transpose cannot be applied, " + node_def.name()); @@ -1008,11 +1102,35 @@ tensorflow::Status BinaryTensorOpWeight( // Maybe I should do a switch if (node_def.op() == "Sub") { - TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights); - LambdaFactory unary_op; - unary_op.op = LambdaFactory::OP_CATEGORY::NEG; - TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op)); - shift_weights = neg_weights; + if (swapped_inputs) { + shift_weights = weights; + nvinfer1::IUnaryLayer* layer = + ctx.network()->addUnary(*const_cast(tensor), + nvinfer1::UnaryOperation::kNEG); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + tensor = layer->getOutput(0); + } else { + TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights); + LambdaFactory unary_op; + unary_op.op = LambdaFactory::OP_CATEGORY::NEG; + TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op)); + shift_weights = neg_weights; + } + } else if (node_def.op() == "Div" || node_def.op() == "RealDiv") { + if (swapped_inputs) { + scale_weights = weights; + nvinfer1::IUnaryLayer* layer = + ctx.network()->addUnary(*const_cast(tensor), + nvinfer1::UnaryOperation::kRECIP); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + tensor = layer->getOutput(0); + } else { + TRT_ShapedWeights recip_weights = ctx.get_temp_weights_like(weights); + LambdaFactory unary_op; + unary_op.op = LambdaFactory::OP_CATEGORY::RECIP; + TF_RETURN_IF_ERROR(UnaryCompute(weights, &recip_weights, unary_op)); + scale_weights = recip_weights; + } } else if (node_def.op() == "Mul") { scale_weights = weights; } else if (node_def.op() == "Add") { @@ -1025,11 +1143,13 @@ tensorflow::Status BinaryTensorOpWeight( nvinfer1::IScaleLayer* layer = ctx.network()->addScale( *const_cast(tensor), scale_mode, shift_weights, scale_weights, power_weights); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); // transpose back dimension if (permutation_flag) { output_tensor = ctx.TransposeTensor(output_tensor, permutation); + TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name()); } // Pass the output @@ -1053,20 +1173,31 @@ tensorflow::Status ConvertConv2DHelper( if (data_format == "NHWC") { tensor = ctx.TransposeTensor(const_cast(tensor), {0, 3, 1, 2}); + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name()); h_index = 1; w_index = 2; // TODO(jie): transpose it } // tensor after transpose (NCHW) - auto tensor_dim = tensor->getDimensions(); + const auto tensor_dim = tensor->getDimensions(); int num_groups = group; - if (num_groups == 0) // depthwise convolution - num_groups = tensor_dim.d[0]; + if (num_groups == 0) num_groups = tensor_dim.d[0]; // depthwise convolution VLOG(2) << "groups count: " << num_groups; TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); + + VLOG(2) << "weight shape: " << weights_rsck.shape_.nbDims; + for (int i = 0; i < weights_rsck.shape_.nbDims; i++) { + VLOG(2) << weights_rsck.shape_.d[i]; + } + + if (weights_rsck.shape_.nbDims != 4) { + return tensorflow::errors::Internal( + "Conv2D expects kernel of dimension 4, at: " + node_def.name()); + } + if (ctx.isFP16()) { weights_rsck = ConvertFP32ToFP16(ctx, inputs.at(1).weights()); } @@ -1074,18 +1205,22 @@ tensorflow::Status ConvertConv2DHelper( TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck); ReorderRSCKToKCRS(weights_rsck, &weights, num_groups); TRT_ShapedWeights biases(weights.type_); - int noutput = weights.shape_.d[0] * num_groups; + const int noutput = weights.shape_.d[0] * num_groups; nvinfer1::DimsHW kernel_size; kernel_size.h() = weights.shape_.d[2]; kernel_size.w() = weights.shape_.d[3]; + VLOG(2) << "RSCK: "; + for (int i = 0; i < 4; i++) { + VLOG(2) << " " << weights.shape_.d[i]; + } VLOG(2) << "kernel size: " << kernel_size.h() << ", " << kernel_size.w(); // TODO(jie): stride. (NHWC/NCHW) - auto tf_stride = attrs.get>("strides"); + const auto tf_stride = attrs.get>("strides"); VLOG(2) << "h_INDEX" << h_index << ", w_index " << w_index; VLOG(2) << "stride!!!: " << tf_stride[0] << tf_stride[1] << tf_stride[2] << tf_stride[3]; - nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); + const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); std::vector> padding; // TODO(jie): padding. @@ -1113,6 +1248,7 @@ tensorflow::Status ConvertConv2DHelper( *const_cast(tensor), nvinfer1::DimsHW(padding[0].first, padding[1].first), nvinfer1::DimsHW(padding[0].second, padding[1].second)); + TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name()); padding = {{0, 0}, {0, 0}}; tensor = pad_layer->getOutput(0); auto dim_after = tensor->getDimensions(); @@ -1123,6 +1259,7 @@ tensorflow::Status ConvertConv2DHelper( nvinfer1::IConvolutionLayer* layer = ctx.network()->addConvolution(*const_cast(tensor), noutput, kernel_size, weights, biases); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); layer->setStride(stride); layer->setPadding({padding[0].first, padding[1].first}); @@ -1137,6 +1274,7 @@ tensorflow::Status ConvertConv2DHelper( if (data_format == "NHWC") { // TODO(jie): transpose it back! output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1}); + TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name()); } else { VLOG(2) << "NCHW !!!!"; } @@ -1158,35 +1296,91 @@ tensorflow::Status ConvertConv2DHelper( node_def.name()); } +// Helper function converts input into tensor with shape specified by dims. +bool PrepareTensorForShape(Converter& ctx, const TRT_TensorOrWeights& input, + const nvinfer1::Dims& dims, + const nvinfer1::ITensor** tensor) { + if (input.is_tensor()) { + if (DimsEqual(input.shape(), dims)) { + *tensor = input.tensor(); + } else { + nvinfer1::IShuffleLayer* layer = ctx.network()->addShuffle( + *const_cast(input.tensor())); + if (layer != nullptr) { + layer->setReshapeDimensions(dims); + *tensor = layer->getOutput(0); + } else { + return false; + } + } + } else { +#if NV_TENSORRT_MAJOR > 3 + nvinfer1::IConstantLayer* layer = + ctx.network()->addConstant(dims, input.weights()); + if (layer != nullptr) { + *tensor = layer->getOutput(0); + } else { + return false; + } +#else + return false; +#endif + } + return true; +} + tensorflow::Status BinaryTensorOpTensor( Converter& ctx, const tensorflow::NodeDef& node_def, - const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r, + const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r, std::vector* outputs) { static const std::unordered_map ops{ {"Add", nvinfer1::ElementWiseOperation::kSUM}, {"Mul", nvinfer1::ElementWiseOperation::kPROD}, {"Sub", nvinfer1::ElementWiseOperation::kSUB}, {"Div", nvinfer1::ElementWiseOperation::kDIV}, + {"RealDiv", nvinfer1::ElementWiseOperation::kDIV}, + {"Minimum", nvinfer1::ElementWiseOperation::kMIN}, + {"Maximum", nvinfer1::ElementWiseOperation::kMAX}, }; - // FIXME assume type matches input weights + const nvinfer1::ITensor* tensor_l; + const nvinfer1::ITensor* tensor_r; + + nvinfer1::Dims dim_l; + nvinfer1::Dims dim_r; + + if (!TensorRTGetBroadcastShape(operand_l.shape(), operand_l.is_tensor(), + operand_r.shape(), operand_r.is_tensor(), + &dim_l, &dim_r)) { + return tensorflow::errors::InvalidArgument( + "Binary op broadcast scheme not supported by TensorRT op: " + + node_def.op() + ", at: " + node_def.name()); + } + + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, operand_l, dim_l, &tensor_l), node_def.name()); + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, operand_r, dim_r, &tensor_r), node_def.name()); + // get trt type & shape TFAttrs attrs(node_def); // maybe this part has to be moved into the block of rsqrt later nvinfer1::DataType dtype = attrs.get("T"); // check type consistency - CHECK_EQ_TYPE(tensor_l->getType(), dtype); - CHECK_EQ_TYPE(tensor_r->getType(), dtype); + TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype); + TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype); auto op_pair = ops.find(node_def.op()); - if (op_pair == ops.end()) + if (op_pair == ops.end()) { return tensorflow::errors::Unimplemented( - "binary op: " + node_def.op() + - " not supported at: " + node_def.name()); + "binary op: ", node_def.op(), " not supported at: ", node_def.name()); + } nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise( + // TODO(aaroey): will tensor_l/tensor_r get modified? *const_cast(tensor_l), *const_cast(tensor_r), op_pair->second); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); @@ -1213,7 +1407,7 @@ tensorflow::Status ConvertPlugin(Converter& ctx, // passing attributes // TODO(jie): support more general attribute TFAttrs attrs(node_def); - auto attr_key_vector = attrs.GetAllAttrKey(); + auto attr_key_vector = attrs.GetAllAttrKeys(); for (auto attr_key : attr_key_vector) { // TODO(jie): support only list of float for toy example here. auto data = attrs.get>(attr_key); @@ -1234,29 +1428,6 @@ tensorflow::Status ConvertPlugin(Converter& ctx, return tensorflow::Status::OK(); } -tensorflow::Status ConvertPlaceholder( - Converter& ctx, const tensorflow::NodeDef& node_def, - const std::vector& inputs, - std::vector* outputs) { - VLOG(2) << "Placeholder should have been replace already"; - return tensorflow::errors::Unimplemented("cannot convert Placeholder op"); - // OK this make sense since we are supposed to replace it with input - TFAttrs attrs(node_def); - nvinfer1::DataType dtype = attrs.get("dtype"); - nvinfer1::Dims dims = attrs.get("shape"); - - dims.nbDims--; - for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1]; - - nvinfer1::ITensor* output = - ctx.network()->addInput(node_def.name().c_str(), dtype, dims); - if (!output) { - return tensorflow::errors::InvalidArgument("Failed to create Input layer"); - } - outputs->push_back(TRT_TensorOrWeights(output)); - return tensorflow::Status::OK(); -} - tensorflow::Status ConvertConv2D(Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector& inputs, @@ -1282,65 +1453,64 @@ tensorflow::Status ConvertPool(Converter& ctx, int h_index = 2; int w_index = 3; - auto data_format = attrs.get("data_format"); + const auto data_format = attrs.get("data_format"); if (data_format == "NHWC") { h_index = 1; w_index = 2; tensor = ctx.TransposeTensor(const_cast(tensor), {0, 3, 1, 2}); - } else { - VLOG(2) << "NCHW !!!!"; + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name()); } + nvinfer1::PoolingType type; - // TODO(jie): support other pooling type - if (node_def.op() == "MaxPool") + if (node_def.op() == "MaxPool") { type = nvinfer1::PoolingType::kMAX; - else if (node_def.op() == "AvgPool") + } else if (node_def.op() == "AvgPool") { type = nvinfer1::PoolingType::kAVERAGE; - else - return tensorflow::errors::Unimplemented("Only supports Max pool"); + } else { + return tensorflow::errors::Unimplemented("Unsupported pool type: ", + node_def.op()); + } - // TODO(jie): NCHW - auto tf_stride = attrs.get>("strides"); - nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); + const auto tf_stride = attrs.get>("strides"); + const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); - auto tf_kernel = attrs.get>("ksize"); - nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]); + const auto tf_kernel = attrs.get>("ksize"); + const nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]); auto tensor_dim = tensor->getDimensions(); std::vector> padding; - // TODO(jie): padding. - if (attrs.get("padding") == "SAME") { + const string padding_type = attrs.get("padding"); + if (padding_type == "SAME") { // This is NCHW tensor with no batch dimension. // 1 -> h // 2 -> w padding = CreateSamePadding( stride, ksize, {static_cast(tensor_dim.d[1]), static_cast(tensor_dim.d[2])}); - } else if (attrs.get("padding") == "VALID") { - // No padding for valid padding here - VLOG(2) << "No padding added for VALID padding in pool" << node_def.name(); + } else if (padding_type == "VALID") { padding = {{0, 0}, {0, 0}}; } else { - return tensorflow::errors::Unimplemented( - "Current MaxPool cannot support padding other than SAME"); + return tensorflow::errors::Unimplemented("Unsupported padding type: ", + padding_type); } if (padding[0].first != padding[0].second || padding[1].first != padding[1].second) { - // TODO(jie): handle asymmetric padding VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second << padding[1].first << padding[1].second; auto pad_layer = ctx.network()->addPadding( *const_cast(tensor), nvinfer1::DimsHW(padding[0].first, padding[1].first), nvinfer1::DimsHW(padding[0].second, padding[1].second)); + TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name()); padding = {{0, 0}, {0, 0}}; tensor = pad_layer->getOutput(0); } nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling( *const_cast(tensor), type, ksize); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); layer->setStride(stride); layer->setPadding({padding[0].first, padding[1].first}); @@ -1348,10 +1518,8 @@ tensorflow::Status ConvertPool(Converter& ctx, nvinfer1::ITensor* output_tensor = layer->getOutput(0); if (data_format == "NHWC") { - // TODO(jie): transpose it back! output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1}); - } else { - VLOG(2) << "NCHW !!!!"; + TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name()); } outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); @@ -1364,6 +1532,7 @@ tensorflow::Status ConvertActivation( const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); nvinfer1::IActivationLayer* layer = ctx.network()->addActivation( *const_cast(tensor), nvinfer1::ActivationType::kRELU); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); @@ -1374,40 +1543,61 @@ tensorflow::Status ConvertScale(Converter& ctx, const std::vector& inputs, std::vector* outputs) { if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) + !inputs.at(1).is_weights()) { return tensorflow::errors::Unimplemented( - "Only supports tensor op weight for now, at " + node_def.name()); - // Implement tensor binaryOp weight [channel wise] for now; - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + "ConvertScale only supports tensorweight: ", node_def.name()); + } + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); TRT_ShapedWeights weights = inputs.at(1).weights(); if (ctx.isFP16()) { weights = ConvertFP32ToFP16(ctx, inputs.at(1).weights()); } TRT_ShapedWeights empty_weights(weights.type_); - TFAttrs attrs(node_def); - // Transpose NHWC - auto data_format = attrs.get("data_format"); + const auto data_format = attrs.get("data_format"); + int channel_index; + const auto dims = tensor->getDimensions(); if (data_format == "NHWC") { - tensor = ctx.TransposeTensor(const_cast(tensor), - {0, 3, 1, 2}); - // TODO(jie): transpose it + // 1). NHWC is really N+C + channel_index = dims.nbDims - 1; // batch dimension is implicit here! } else { - VLOG(2) << "NCHW !!!!"; + // 2). NCHW is really N+CHW + channel_index = dims.nbDims - 3; // batch dimension is implicit here! } - auto dims = tensor->getDimensions(); - VLOG(2) << "tensor dimensions: " << dims.nbDims; - for (int i = 0; i < dims.nbDims; i++) { - VLOG(2) << "i: " << dims.d[i]; + nvinfer1::Permutation permutation; + for (int32_t i = 0; i < dims.nbDims; ++i) { + permutation.order[i] = i; } - dims = weights.shape_; - VLOG(2) << "tensor dimensions: " << dims.nbDims; - for (int i = 0; i < dims.nbDims; i++) { - VLOG(2) << "i: " << dims.d[i]; + + if (channel_index >= 0) { + permutation.order[0] = channel_index; + permutation.order[channel_index] = 0; + } else { + return tensorflow::errors::Unimplemented( + "TFTRT::BiasAdd cannot apply on batch dimension, at ", node_def.name()); + } + + // TensorRT addScale requires input to be of rank 3, we need to apply + // transpose as well as reshape + if (channel_index != 0 || dims.nbDims != 3) { + nvinfer1::IShuffleLayer* shuffle_layer = + ctx.network()->addShuffle(*const_cast(tensor)); + TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name()); + nvinfer1::Dims reshape_dims; + reshape_dims.nbDims = 3; + reshape_dims.d[0] = 0; // 0 copy from the input + reshape_dims.d[1] = dims.nbDims >= 2 ? 0 : 1; // 0 copy from the input + reshape_dims.d[2] = dims.nbDims >= 3 ? -1 : 1; // -1 infer from the rest + if (channel_index != 0) { + // maybe we do not need this check. concerned about TRT optimization + shuffle_layer->setFirstTranspose(permutation); + } + shuffle_layer->setReshapeDimensions(reshape_dims); + tensor = shuffle_layer->getOutput(0); } nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL; @@ -1418,14 +1608,26 @@ tensorflow::Status ConvertScale(Converter& ctx, nvinfer1::IScaleLayer* layer = ctx.network()->addScale(*const_cast(tensor), mode, weights, empty_weights, empty_weights); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); - if (data_format == "NHWC") { - // TODO(jie): transpose it back! - output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1}); - } else { - VLOG(2) << "NCHW !!!!"; + + // restore transpose & reshape + if (channel_index != 0 || dims.nbDims != 3) { + nvinfer1::IShuffleLayer* shuffle_layer = ctx.network()->addShuffle( + *const_cast(output_tensor)); + TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name()); + nvinfer1::Dims reshape_dims = dims; + int tmp = reshape_dims.d[channel_index]; + reshape_dims.d[channel_index] = reshape_dims.d[0]; + reshape_dims.d[0] = tmp; + shuffle_layer->setReshapeDimensions(reshape_dims); + if (channel_index != 0) { + shuffle_layer->setSecondTranspose(permutation); + } + output_tensor = shuffle_layer->getOutput(0); } + outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } @@ -1442,11 +1644,13 @@ tensorflow::Status ConvertConst(Converter& ctx, // Create shaped weights as output tensorflow::Tensor tensor; - if (!tensor.FromProto(weights_tensor)) - return tensorflow::errors::Internal("Cannot parse weight tensor proto: " + + if (!tensor.FromProto(weights_tensor)) { + return tensorflow::errors::Internal("Cannot parse weight tensor proto: ", node_def.name()); + } TRT_ShapedWeights weights(dtype); + // TODO(aaroey): we should choose the array using dtype and shape. if (!weights_tensor.float_val().empty()) { VLOG(2) << "SCALAR!!!" << node_def.name(); nvinfer1::Dims scalar_shape; @@ -1454,22 +1658,16 @@ tensorflow::Status ConvertConst(Converter& ctx, VLOG(2) << "dimensions: " << tensor.dims(); VLOG(2) << "size: " << weights_tensor.float_val_size(); scalar_shape = GetTensorShape(tensor); + VLOG(2) << "details: "; for (int i = 0; i < scalar_shape.nbDims; i++) VLOG(2) << scalar_shape.d[i]; - if (GetShapeSize(scalar_shape) != weights_tensor.float_val_size()) { - if (weights_tensor.float_val_size() == 1 || - scalar_shape.d[0] == weights_tensor.float_val_size()) { - scalar_shape.nbDims = 1; - // no dimension provided. flatten it - scalar_shape.d[0] = weights_tensor.float_val_size(); - scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL; - } else { - LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and" - << " kUNIFORM, at: " << node_def.name(); - string err_str("Broadcast method is not supported for '"); - StrAppend(&err_str, node_def.name(), "' of type ", node_def.op()); - return tensorflow::errors::InvalidArgument(err_str); - } + if (GetShapeSize(scalar_shape) != weights_tensor.float_val_size() && + weights_tensor.float_val_size() != 1) { + LOG(ERROR) << "Broadcast on weights only supports kCHANNEL and" + << " kUNIFORM, at: " << node_def.name(); + string err_str("Broadcast method is not supported for '"); + StrAppend(&err_str, node_def.name(), "' of type ", node_def.op()); + return tensorflow::errors::InvalidArgument(err_str); } } else { VLOG(2) << "Dimensions: " << tensor.dims(); @@ -1479,39 +1677,42 @@ tensorflow::Status ConvertConst(Converter& ctx, scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL; for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) { scalar_shape.d[i] = 0; - scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL; } } + // TODO(aaroey): use GetShapeSize(). size_t len_data = tensorflow::DataTypeSize(dtype); for (int i = 0; i < scalar_shape.nbDims; i++) len_data *= scalar_shape.d[i]; ctx.weight_store()->store_.push_back(std::vector(len_data)); void* dst = static_cast(&(ctx.weight_store()->store_.back()[0])); - std::vector tensor_data( - weights_tensor.float_val().begin(), - weights_tensor.float_val() - .end()); // make a local copy first to flatten - memcpy(dst, tensor_data.data(), len_data); // store into weight store + if (weights_tensor.float_val_size() == 1) { + std::fill_n((float*)dst, GetShapeSize(scalar_shape), + *weights_tensor.float_val().begin()); + } else { + // TODO(aaroey): get rid of this copy as RepeatedField is always + // contiguous make a local copy first to flatten doesn't have to be + // contiguous + std::vector tensor_data(weights_tensor.float_val().begin(), + weights_tensor.float_val().end()); + memcpy(dst, tensor_data.data(), len_data); // store into weight store + } + VLOG(2) << "create shape details: "; + for (int i = 0; i < scalar_shape.nbDims; i++) VLOG(2) << scalar_shape.d[i]; weights = TRT_ShapedWeights(dtype, dst, scalar_shape); } else if (!weights_tensor.int_val().empty()) { + // TODO(aaroey): this is very similar to the above code for float, merge + // them. VLOG(2) << "int!!!" << node_def.name(); nvinfer1::Dims scalar_shape; if (tensor.dims() > 0) { VLOG(2) << "dimensions: " << tensor.dims(); scalar_shape = GetTensorShape(tensor); - if (GetShapeSize(scalar_shape) != weights_tensor.int_val_size()) { - if (weights_tensor.int_val_size() == 1 || - scalar_shape.d[0] == weights_tensor.int_val_size()) { - scalar_shape.nbDims = 1; - // no dimension provided. flatten it - scalar_shape.d[0] = weights_tensor.int_val_size(); - scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL; - } else { - LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and" - << " kUNIFORM, at: " << node_def.name(); - string err_str("Broadcast method is not supported for '"); - StrAppend(&err_str, node_def.name(), "' of type ", node_def.op()); - return tensorflow::errors::InvalidArgument(err_str); - } + if (GetShapeSize(scalar_shape) != weights_tensor.int_val_size() && + weights_tensor.int_val_size() != 1) { + LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and" + << " kUNIFORM, at: " << node_def.name(); + string err_str("Broadcast method is not supported for '"); + StrAppend(&err_str, node_def.name(), "' of type ", node_def.op()); + return tensorflow::errors::InvalidArgument(err_str); } } else { VLOG(2) << "dimensions: " << tensor.dims(); @@ -1524,23 +1725,30 @@ tensorflow::Status ConvertConst(Converter& ctx, scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL; } } - // we should not have converted //if (ctx.isFP16()) { + // we should not have converted size_t len_data = tensorflow::DataTypeSize(dtype); for (int i = 0; i < scalar_shape.nbDims; i++) len_data *= scalar_shape.d[i]; size_t len_tensor = weights_tensor.int_val_size() * sizeof(int32); len_data = std::max(len_data, len_tensor); ctx.weight_store()->store_.push_back(std::vector(len_data)); void* dst = static_cast(&(ctx.weight_store()->store_.back()[0])); - std::vector tensor_data( - weights_tensor.int_val().begin(), - weights_tensor.int_val().end()); // make a local copy first to flatten - // doesn't have to be contigous - memcpy(dst, tensor_data.data(), len_tensor); // store into weight store + if (weights_tensor.int_val_size() == 1) { + std::fill_n((int*)dst, GetShapeSize(scalar_shape), + *weights_tensor.int_val().begin()); + } else { + // TODO(aaroey): get rid of this copy as RepeatedField is always + // contiguous make a local copy first to flatten doesn't have to be + // contiguous + std::vector tensor_data(weights_tensor.int_val().begin(), + weights_tensor.int_val().end()); + memcpy(dst, tensor_data.data(), len_tensor); // store into weight store + } weights = TRT_ShapedWeights(dtype, dst, scalar_shape); } else if (!weights_tensor.tensor_content().empty()) { - // obsolete method. - // After optimization path, we do not see weights in this format. - // fp16 conversion technically should be needed here. + // obsolete method. + // After optimization path, we do not see weights in this format. + // TODO(aaroey): why? + // fp16 conversion technically should be needed here. VLOG(2) << "TENSOR!!!" << node_def.name(); const auto& content = weights_tensor.tensor_content(); @@ -1554,8 +1762,8 @@ tensorflow::Status ConvertConst(Converter& ctx, content, static_cast(const_cast(weights.GetValues()))); } } else { - return tensorflow::errors::Unimplemented( - "Not supported constant type, at " + node_def.name()); + return tensorflow::errors::Unimplemented("Not supported constant type, at ", + node_def.name()); } // Pass the output outputs->push_back(TRT_TensorOrWeights(weights)); @@ -1574,96 +1782,144 @@ tensorflow::Status ConvertBinary(Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector& inputs, std::vector* outputs) { - if (inputs.size() != 2) + if (inputs.size() != 2) { return tensorflow::errors::FailedPrecondition( - "Binary ops require two tensor input, at " + node_def.name()); - - if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) - return ConstantFoldBinary(ctx, node_def, inputs, outputs); - - if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) - return BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(), - inputs.at(1).weights(), outputs); + "Binary ops require two tensor input, at ", node_def.name()); + } - if (inputs.at(0).is_weights() && inputs.at(1).is_tensor()) - return BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(), - inputs.at(0).weights(), outputs); + // Constant folding should have been done by TensorFlow - if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) - return BinaryTensorOpTensor(ctx, node_def, inputs.at(0).tensor(), - inputs.at(1).tensor(), outputs); + if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) { + return tensorflow::errors::Unimplemented( + "Constant folding is falled back to TensorFlow, binary op received " + "both input as constant at: ", + node_def.name()); + } - return tensorflow::errors::Unknown("Binary op input error, at " + - node_def.name()); + // Try to convert into Scale layer first (for better performance) + // Since scale layer supports restricted broadcast policy and op types, we + // allow failure and try to handle it through Elementwise op + // (BinaryTensorOpTensor) + Status status = tensorflow::Status::OK(); + if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) { + status = BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(), + inputs.at(1).weights(), false, outputs); + } else if (inputs.at(0).is_weights() && inputs.at(1).is_tensor()) { + status = BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(), + inputs.at(0).weights(), true, outputs); +#if NV_TENSORRT_MAJOR == 3 + } else { +#else + } + if ((inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) || !status.ok()) { +#endif + status = BinaryTensorOpTensor(ctx, node_def, inputs.at(0), inputs.at(1), + outputs); + } + return status; } tensorflow::Status ConvertUnary(Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector& inputs, std::vector* outputs) { - if (inputs.size() != 1) + static const std::unordered_map ops{ + {"Neg", nvinfer1::UnaryOperation::kNEG}, + {"Exp", nvinfer1::UnaryOperation::kEXP}, + {"Log", nvinfer1::UnaryOperation::kLOG}, + {"Sqrt", nvinfer1::UnaryOperation::kSQRT}, + {"Abs", nvinfer1::UnaryOperation::kABS}, + {"Reciprocal", nvinfer1::UnaryOperation::kRECIP}, + }; + + if (inputs.size() != 1) { return tensorflow::errors::FailedPrecondition( - "Unary ops require single tensor input, at " + node_def.name()); + "Unary ops require single tensor input, at ", node_def.name()); + } - if (inputs.at(0).is_weights()) - return ConstantFoldUnary(ctx, node_def, inputs, outputs); - else if (inputs.at(0).is_tensor()) +#if NV_TENSORRT_MAJOR == 3 + if (inputs.at(0).is_weights()) { return tensorflow::errors::Unimplemented( - "Unary op for tensor not supported, at " + node_def.name()); + "Constant folding for unary op is not supported", node_def.name()); + } +#endif + + // TODO(jie): check type + const nvinfer1::ITensor* tensor; + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, inputs.at(0), inputs.at(0).shape(), &tensor), + node_def.name()); + + nvinfer1::IUnaryLayer* layer; + if (node_def.op() == "Rsqrt") { + layer = ctx.network()->addUnary(*const_cast(tensor), + nvinfer1::UnaryOperation::kSQRT); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + tensor = layer->getOutput(0); + layer = ctx.network()->addUnary(*const_cast(tensor), + nvinfer1::UnaryOperation::kRECIP); + } else if (ops.count(node_def.op()) != 0) { + layer = ctx.network()->addUnary(*const_cast(tensor), + ops.at(node_def.op())); + } else { + return tensorflow::errors::InvalidArgument( + "Binary op: ", node_def.op(), " not supported, at ", node_def.name()); + } - return tensorflow::errors::Unknown("Binary op input error, at " + - node_def.name()); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); } -tensorflow::Status ConvertReduce(Converter& ctx, - const tensorflow::NodeDef& node_def, - const std::vector& inputs, - std::vector* outputs) { +#if NV_TENSORRT_MAJOR == 3 +tensorflow::Status ConvertReducePool( + Converter& ctx, const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) + !inputs.at(1).is_weights()) { return tensorflow::errors::InvalidArgument( - "Input expects tensor and weights, at" + node_def.name()); + "Input expects tensor and weights, at", node_def.name()); + } // Implement tensor binaryOp weight [channel wise] for now; const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - auto dims = tensor->getDimensions(); + const auto dims = tensor->getDimensions(); // Restore implicit batch dimension - int nb_dims = dims.nbDims + 1; + const int nb_dims = dims.nbDims + 1; TRT_ShapedWeights index_list = inputs.at(1).weights(); - TFAttrs attrs(node_def); - // TODO(jie): handle data type. - // Index type here is done through TF type, so I can leverage their - // EnumToDataType for my cast auto index_type = attrs.get("Tidx"); // Only expect to handle INT32 as attributes for now - if (index_type != tensorflow::DataType::DT_INT32) + if (index_type != tensorflow::DataType::DT_INT32) { return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32"); - auto index_list_data = + } + const auto index_list_data = static_cast(const_cast(index_list.GetValues())); - // Hack warning: have to fall back to pool layer since reduce is not in public - // TRT yet. - if (nb_dims != 4) + if (nb_dims != 4) { return tensorflow::errors::InvalidArgument( - "TRT only support reduce on 4 dimensional tensors, at" + + "TRT only support reduce on 4 dimensional tensors, at", node_def.name()); - if (index_list.count() > 2) + } + if (index_list.count() > 2) { return tensorflow::errors::InvalidArgument( - "TRT cannot support reduce on more than 2 dimensions, at" + + "TRT cannot support reduce on more than 2 dimensions, at", node_def.name()); + } std::set idx_set; // We cannot operate on Channel. permutation flag used to transpose tensor int permuted_index = -1; for (int i = 0; i < index_list.count(); i++) { - if (index_list_data[i] == 0) - return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at" + + if (index_list_data[i] == 0) { + return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at", node_def.name()); + } if (index_list_data[i] == 1) permuted_index = 1; - idx_set.emplace(index_list_data[i]); } @@ -1684,6 +1940,7 @@ tensorflow::Status ConvertReduce(Converter& ctx, // Apply permutation before extracting dimension for pool_kernel tensor = ctx.TransposeTensor(const_cast(tensor), permutation_order); + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name()); } // Apply permutation before extracting dimension for pool_kernel @@ -1696,34 +1953,104 @@ tensorflow::Status ConvertReduce(Converter& ctx, nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling(*const_cast(tensor), nvinfer1::PoolingType::kAVERAGE, pool_kernel); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); output_tensor = layer->getOutput(0); } else { - return tensorflow::errors::Unimplemented( - "Op not supported " + node_def.op() + " , at " + node_def.name()); + return tensorflow::errors::Unimplemented("Op not supported ", node_def.op(), + " , at ", node_def.name()); } if (permuted_index != -1) { // Apply permutation before extracting dimension for pool_kernel output_tensor = ctx.TransposeTensor( const_cast(output_tensor), permutation_order); + TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name()); } outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } +#elif NV_TENSORRT_MAJOR > 3 +tensorflow::Status ConvertReduce(Converter& ctx, + const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + if (inputs.size() != 2 || !inputs.at(0).is_tensor() || + !inputs.at(1).is_weights()) { + return tensorflow::errors::InvalidArgument( + "Input expects tensor and weights, at", node_def.name()); + } + + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + TRT_ShapedWeights index_list = inputs.at(1).weights(); + + TFAttrs attrs(node_def); + auto index_type = attrs.get("Tidx"); + + // Only expect to handle INT32 as attributes for now + if (index_type != tensorflow::DataType::DT_INT32) { + return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32"); + } + + const auto keep_dims = attrs.get("keep_dims"); + auto index_list_data = + static_cast(const_cast(index_list.GetValues())); + + int axes = 0; + if (index_list.count() == 0) { + return tensorflow::errors::InvalidArgument( + "TRT cannot support reduce on all (batch) dimensions, at", + node_def.name()); + } else { + for (int i = 0; i < index_list.count(); i++) { + if (index_list_data[i] == 0) { + return tensorflow::errors::InvalidArgument( + "TRT cannot reduce at batch dimension, at", node_def.name()); + } + axes |= (1 << (index_list_data[i] - 1)); + } + } + + nvinfer1::ReduceOperation reduce_operation; + if (node_def.op() == "Sum") { + reduce_operation = nvinfer1::ReduceOperation::kSUM; + } else if (node_def.op() == "Prod") { + reduce_operation = nvinfer1::ReduceOperation::kPROD; + } else if (node_def.op() == "Max") { + reduce_operation = nvinfer1::ReduceOperation::kMAX; + } else if (node_def.op() == "Min") { + reduce_operation = nvinfer1::ReduceOperation::kMIN; + } else if (node_def.op() == "Mean") { + reduce_operation = nvinfer1::ReduceOperation::kAVG; + } else { + return tensorflow::errors::Unimplemented("Op not supported ", node_def.op(), + " , at ", node_def.name()); + } + + nvinfer1::ILayer* layer = + ctx.network()->addReduce(*const_cast(tensor), + reduce_operation, axes, keep_dims); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + + outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return tensorflow::Status::OK(); +} +#endif tensorflow::Status ConvertPad(Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector& inputs, std::vector* outputs) { + // TODO(aaroey): make a routine for this check and reuse it. if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) + !inputs.at(1).is_weights()) { return tensorflow::errors::InvalidArgument( - "Input expects tensor and weights, at" + node_def.name()); + "Input expects tensor and weights, at", node_def.name()); + } // Implement tensor binaryOp weight [channel wise] for now; const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - auto dims = tensor->getDimensions(); + const auto dims = tensor->getDimensions(); // Restore implicit batch dimension - int nb_dims = dims.nbDims + 1; + const int nb_dims = dims.nbDims + 1; TRT_ShapedWeights pads = inputs.at(1).weights(); @@ -1733,21 +2060,24 @@ tensorflow::Status ConvertPad(Converter& ctx, auto padding_type = attrs.get("Tpaddings"); // TODO(jie): handle data type conversion for TRT? - if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2) + if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2) { return tensorflow::errors::InvalidArgument( - "Pad only supports explicit padding on 4 dimensional tensor, at " + + "Pad only supports explicit padding on 4 dimensional tensor, at ", node_def.name()); + } // Only expect to handle INT32 as attributes for now - if (padding_type != tensorflow::DataType::DT_INT32) + if (padding_type != tensorflow::DataType::DT_INT32) { return tensorflow::errors::Unimplemented( "Tpaddings supports only DT_INT32"); + } auto pad_data = static_cast(const_cast(pads.GetValues())); std::vector pad_index; for (int i = 0; i < nb_dims; i++) { - if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0) + if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0) { pad_index.push_back(i); + } } // No padding at all, we should exit @@ -1757,20 +2087,23 @@ tensorflow::Status ConvertPad(Converter& ctx, } // Only supports padding on less than 2 axis GIE-2579 - if (pad_index.size() > 2) + if (pad_index.size() > 2) { return tensorflow::errors::InvalidArgument( "Padding layer does not support padding on > 2"); + } // Padding on batch dimension is not supported - if (pad_index[0] == 0) + if (pad_index[0] == 0) { return tensorflow::errors::InvalidArgument( "Padding layer does not support padding on batch dimension"); + } // Not doing the legit thing here. ignoring padding on dim 1 and 3; // TODO(jie): implement pad as uff parser - if (pad_index.size() == 2 && pad_index[0] == 0 && pad_index[1] == 3) + if (pad_index.size() == 2 && pad_index[0] == 0 && pad_index[1] == 3) { return tensorflow::errors::Unimplemented( "Padding layer does not support padding on dimension 1 and 3 yet"); + } bool legit_pad = true; nvinfer1::DimsHW pre_padding(0, 0); @@ -1781,6 +2114,7 @@ tensorflow::Status ConvertPad(Converter& ctx, legit_pad = false; tensor = ctx.TransposeTensor(const_cast(tensor), {0, 3, 2, 1}); + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name()); permuted_pad_index[0] = 3; } @@ -1797,11 +2131,14 @@ tensorflow::Status ConvertPad(Converter& ctx, nvinfer1::IPaddingLayer* layer = ctx.network()->addPadding( *const_cast(tensor), pre_padding, post_padding); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); - if (!legit_pad) + if (!legit_pad) { output_tensor = ctx.TransposeTensor( const_cast(output_tensor), {0, 3, 2, 1}); + TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name()); + } outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); @@ -1814,9 +2151,10 @@ tensorflow::Status ConvertConcat(Converter& ctx, // not including the last input (axis) here int input_size = static_cast(inputs.size()) - 1; - if (!inputs.at(0).is_tensor()) + if (!inputs.at(0).is_tensor()) { return tensorflow::errors::InvalidArgument( - "Concat in TRT support only Tensor input, at " + node_def.name()); + "Concat in TRT support only Tensor input, at ", node_def.name()); + } // We are retrieving the axis TRT_ShapedWeights axis = inputs.at(input_size).weights(); @@ -1827,8 +2165,8 @@ tensorflow::Status ConvertConcat(Converter& ctx, // TODO(jie): handle data type // Only expect to handle INT32 as index attributes for now if (index_type != tensorflow::DataType::DT_INT32) - return tensorflow::errors::Unimplemented( - "Tidx supports only DT_INT32, at " + node_def.name()); + return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32, at ", + node_def.name()); int index = *(static_cast(const_cast(axis.GetValues()))); @@ -1836,23 +2174,29 @@ tensorflow::Status ConvertConcat(Converter& ctx, auto dim = inputs.at(0).tensor()->getDimensions(); // dimension check - if (index > dim.nbDims + 1) + if (index > dim.nbDims + 1) { return tensorflow::errors::InvalidArgument( - "Concatenate on axis out of dimension range, at " + node_def.name()); - - if (index == 0) + "Concatenate on axis out of dimension range, at ", node_def.name()); + } + if (index == 0) { return tensorflow::errors::InvalidArgument( - "Concatenate on batch dimension not supported, at " + node_def.name()); + "Concatenate on batch dimension not supported, at ", node_def.name()); + } + if (index < 0) { + index = dim.nbDims + index + 1; + } +#if NV_TENSORRT_MAJOR == 3 // incase we need permutation; std::vector permutation_order(dim.nbDims + 1); for (int i = 0; i < dim.nbDims + 1; i++) permutation_order[i] = i; if (index != 1) { - permutation_order[1] = index - 1; - permutation_order[index - 1] = 1; + permutation_order[1] = index; + permutation_order[index] = 1; } +#endif std::vector inputs_vec; // Shap chack (all input tensor should have same shape) @@ -1860,24 +2204,28 @@ tensorflow::Status ConvertConcat(Converter& ctx, for (int i = 0; i < input_size; i++) { auto tensor_i = inputs.at(i).tensor(); auto dim_i = tensor_i->getDimensions(); - if (dim_i.nbDims != dim.nbDims) + if (dim_i.nbDims != dim.nbDims) { return tensorflow::errors::InvalidArgument( - "Concatenate receives inputs with inconsistent dimensions, at " + + "Concatenate receives inputs with inconsistent dimensions, at ", node_def.name()); - + } for (int j = 0; j < dim.nbDims; j++) { // check dimension consistency on non-concatenate axis - if (j != index - 1 && dim_i.d[j] != dim.d[j]) + if (j != index - 1 && dim_i.d[j] != dim.d[j]) { return tensorflow::errors::InvalidArgument( - "Concatenate receives inputs with inconsistent shape, at" + + "Concatenate receives inputs with inconsistent shape, at", node_def.name()); + } } - // TRT does concatenation only on channel! - if (index != 1) +#if NV_TENSORRT_MAJOR == 3 + // TRT3 does concatenation only on channel! + if (index != 1) { tensor_i = ctx.TransposeTensor(const_cast(tensor_i), permutation_order); - + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor_i, node_def.name()); + } +#endif inputs_vec.push_back(tensor_i); } @@ -1885,11 +2233,18 @@ tensorflow::Status ConvertConcat(Converter& ctx, nvinfer1::IConcatenationLayer* layer = ctx.network()->addConcatenation( const_cast(inputs_vec.data()), inputs_vec.size()); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); +#if NV_TENSORRT_MAJOR > 3 + layer->setAxis(index - 1); +#endif nvinfer1::ITensor* output_tensor = layer->getOutput(0); +#if NV_TENSORRT_MAJOR == 3 if (index != 1) { output_tensor = ctx.TransposeTensor(output_tensor, permutation_order); + TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name()); } +#endif outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } @@ -2008,112 +2363,243 @@ tensorflow::Status ConvertFusedBatchNorm( combined_offset_weights.GetWeightsForTRT(), combined_scale_weights.GetWeightsForTRT(), dummy_power_weights.GetWeightsForTRT()); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } -tensorflow::Status ConvertMatMul(Converter& ctx, - const tensorflow::NodeDef& node_def, - const std::vector& inputs, - std::vector* outputs) { - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - - // TODO(jie): transpose! - TFAttrs attrs(node_def); +#if NV_TENSORRT_MAJOR > 3 +tensorflow::Status ConvertMatMulHelper( + Converter& ctx, TRT_TensorOrWeights tensor_input, + TRT_ShapedWeights weights_raw, bool transpose_weight, string node_name, + std::vector* outputs) { + nvinfer1::ITensor* output_tensor; + if (!tensor_input.is_tensor()) { + return tensorflow::errors::InvalidArgument("Input 0 expects tensor"); + } + const nvinfer1::ITensor* tensor = tensor_input.tensor(); - TRT_ShapedWeights weights_ck = inputs.at(1).weights(); - TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_ck); - ReorderCKtoKC(weights_ck, &weights); + TRT_ShapedWeights weights(weights_raw.type_); + if (transpose_weight) { + weights = weights_raw; + } else { + TRT_ShapedWeights weights_ck = weights_raw; + weights = ctx.get_temp_weights_like(weights_ck); + ReorderCKtoKC(weights_raw, &weights); + } TRT_ShapedWeights biases(weights.type_); int noutput = weights.shape_.d[0]; + auto input_dim = tensor->getDimensions(); + while (input_dim.nbDims != 3) { + input_dim.d[input_dim.nbDims++] = 1; + } + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, tensor_input, input_dim, &tensor), node_name); + nvinfer1::IFullyConnectedLayer* layer = ctx.network()->addFullyConnected( *const_cast(tensor), noutput, weights, biases); - - nvinfer1::ITensor* output_tensor = layer->getOutput(0); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_name); + output_tensor = layer->getOutput(0); + + const nvinfer1::ITensor* temp_tensor; + auto output_dim = output_tensor->getDimensions(); + output_dim.nbDims = 1; + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, TRT_TensorOrWeights(output_tensor), output_dim, + &temp_tensor), + node_name); + output_tensor = const_cast(temp_tensor); outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } -tensorflow::Status ConvertReshape( +// inputs are both two dimensional (tensorflow::ops::MatMul) +tensorflow::Status ConvertMatMul(Converter& ctx, + const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + if (!inputs.at(0).is_tensor()) { + return tensorflow::errors::InvalidArgument("Input 0 expects tensor, at" + + node_def.name()); + } + + TFAttrs attrs(node_def); + // TODO(jie): INT32 should be converted? + tensorflow::DataType tf_dtype = attrs.get("T"); + if (tf_dtype != tensorflow::DataType::DT_FLOAT && + tf_dtype != tensorflow::DataType::DT_HALF) { + return tensorflow::errors::Unimplemented( + "data type is not supported, for node " + node_def.name() + " got " + + tensorflow::DataTypeString(tf_dtype)); + } + bool transpose_a = attrs.get("transpose_a"); + bool transpose_b = attrs.get("transpose_b"); + + // FullyConnected: + if (transpose_a) { + return tensorflow::errors::Internal( + "Transpose_a is not supported for TensorRT FullyConnected (op: " + + node_def.op() + "), at: " + node_def.name()); + } + if (inputs.at(1).is_tensor()) { + return tensorflow::errors::Internal( + "Operand 1 must be constant for TensorRT FullyConnected (op: " + + node_def.op() + "), at: " + node_def.name()); + } + return ConvertMatMulHelper(ctx, inputs.at(0), inputs.at(1).weights(), + transpose_b, node_def.name(), outputs); +} + +tensorflow::Status ConvertBatchMatMul( Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector& inputs, std::vector* outputs) { - if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) - return tensorflow::errors::InvalidArgument( - "Input expects tensor and weights, at" + node_def.name()); + TFAttrs attrs(node_def); - // implement tensor binaryOp weight [channel wise] for now; - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - auto dims = tensor->getDimensions(); - // restore implicit batch dimension + // TODO(jie): INT32 should be converted? + tensorflow::DataType tf_dtype = attrs.get("T"); + if (tf_dtype != tensorflow::DataType::DT_FLOAT && + tf_dtype != tensorflow::DataType::DT_HALF) { + return tensorflow::errors::Unimplemented( + "data type is not supported, for node " + node_def.name() + " got " + + tensorflow::DataTypeString(tf_dtype)); + } - TRT_ShapedWeights shape = inputs.at(1).weights(); + bool transpose_a = attrs.get("adj_x"); + bool transpose_b = attrs.get("adj_y"); - TFAttrs attrs(node_def); + auto dims = inputs.at(0).shape(); + if (dims.nbDims == 1) { // NC * CK is only supported through fully connected + if (transpose_a == false && inputs.at(0).is_tensor() && + inputs.at(1).is_weights()) { + return ConvertMatMulHelper(ctx, inputs.at(0), inputs.at(1).weights(), + transpose_b, node_def.name(), outputs); + } else { + return tensorflow::errors::InvalidArgument( + "Invalid configuration for MatMul, at: " + node_def.name()); + } + } - auto padding_type = attrs.get("Tshape"); + const nvinfer1::ITensor* tensor_l; + const nvinfer1::ITensor* tensor_r; + auto dims_l = inputs.at(0).shape(); + auto dims_r = inputs.at(1).shape(); + if (inputs.at(0).is_weights()) { + if (inputs.at(0).shape().d[0] != 1) { + return tensorflow::errors::InvalidArgument( + "Input 0 as weight assumes broadcast across batch for MatMul, at: " + + node_def.name()); + } else { + for (int i = 0; i < dims_l.nbDims - 1; i++) { + dims_l.d[i] = dims_l.d[i + 1]; + } + dims_l.nbDims--; + } + } + if (inputs.at(1).is_weights()) { + if (inputs.at(1).shape().d[0] != 1) { + return tensorflow::errors::InvalidArgument( + "Input 1 as weight assumes broadcast across batch for MatMul, at: " + + node_def.name()); + } else { + for (int i = 0; i < dims_r.nbDims - 1; i++) { + dims_r.d[i] = dims_r.d[i + 1]; + } + dims_r.nbDims--; + } + } - if (shape.shape_.nbDims != 1) - return tensorflow::errors::InvalidArgument( - "reshape new shape is not 1 dimensional, at " + node_def.name()); + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, inputs.at(0), dims_l, &tensor_l), + node_def.name()); + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, inputs.at(1), dims_r, &tensor_r), + node_def.name()); - // Only expect to handle INT32 as attributes for now - if (padding_type != tensorflow::DataType::DT_INT32) - return tensorflow::errors::Unimplemented( - "reshape new shape supports only DT_INT32, at " + node_def.name()); + nvinfer1::IMatrixMultiplyLayer* layer = ctx.network()->addMatrixMultiply( + *const_cast(tensor_l), transpose_a, + *const_cast(tensor_r), transpose_b); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} +#endif - auto shape_data = static_cast(const_cast(shape.GetValues())); +#if NV_TENSORRT_MAJOR > 3 +tensorflow::Status ConvertSoftmax( + Converter& ctx, const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - if (shape_data[0] != -1) + int nbDims = tensor->getDimensions().nbDims; + if (nbDims == 0) { return tensorflow::errors::InvalidArgument( - "reshape new shape first dimension is not -1, at " + node_def.name()); + "TensorRT Softmax cannot apply on batch dimension, at" + + node_def.name()); + } + nvinfer1::ISoftMaxLayer* layer = + ctx.network()->addSoftMax(*const_cast(tensor)); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + // Tensorflow SoftMax assumes applying softmax on the last dimension. + layer->setAxes(1 << (nbDims - 1)); - auto shape_num_dims = shape.shape_.d[0]; - VLOG(2) << "shape dimensions: " << shape_num_dims; - int volume_w = 1; - for (int i = 1; i < shape.shape_.d[0]; i++) volume_w *= shape_data[i]; + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} +#endif - int volume_t = 1; - for (int i = 0; i < dims.nbDims; i++) volume_t *= dims.d[i]; +#if NV_TENSORRT_MAJOR > 3 +tensorflow::Status ConvertTopK(Converter& ctx, + const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - VLOG(2) << "volume: " << volume_t << " volume weights: " << volume_w; - if (volume_w != volume_t) + int nbDims = tensor->getDimensions().nbDims; + if (nbDims == 0) { return tensorflow::errors::InvalidArgument( - "volume does not agree between tensor and new shape, at " + - node_def.name()); + "TensorRT TopK cannot apply on batch dimension, at" + node_def.name()); + } - nvinfer1::IShuffleLayer* layer = - ctx.network()->addShuffle(*const_cast(tensor)); + TRT_ShapedWeights k_w = inputs.at(1).weights(); + int k = *(static_cast(const_cast(k_w.GetValues()))); - nvinfer1::Dims reshape_dims; - VLOG(2) << "new dimension: " << shape_num_dims - 1; - reshape_dims.nbDims = shape_num_dims - 1; - for (int32_t i = 0; i < reshape_dims.nbDims; ++i) { - reshape_dims.d[i] = shape_data[i + 1]; + nvinfer1::TopKOperation op; + uint32_t reducedAxes = 0; + if (node_def.op() == "TopKV2") { + op = nvinfer1::TopKOperation::kMAX; + reducedAxes |= 1 << (nbDims - 1); + } else { + return tensorflow::errors::Unimplemented( + "Operation: " + node_def.op() + + " not implemented, at: " + node_def.name()); } - layer->setReshapeDimensions(reshape_dims); - VLOG(2) << "new dimension: " << shape_num_dims - 1; - nvinfer1::ITensor* output_tensor = layer->getOutput(0); - auto dims_output = output_tensor->getDimensions(); - VLOG(2) << "output tensor dimension:" << dims_output.nbDims; - outputs->push_back(TRT_TensorOrWeights(output_tensor)); + nvinfer1::ITopKLayer* layer = ctx.network()->addTopK( + *const_cast(tensor), op, k, reducedAxes); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + + nvinfer1::ITensor* output_value_tensor = layer->getOutput(0); + nvinfer1::ITensor* output_indices_tensor = layer->getOutput(1); + outputs->push_back(TRT_TensorOrWeights(output_value_tensor)); + outputs->push_back(TRT_TensorOrWeights(output_indices_tensor)); return tensorflow::Status::OK(); } +#endif void Converter::register_op_converters() { // vgg_16 slim implementation - op_registry_["Placeholder"] = ConvertPlaceholder; op_registry_["Conv2D"] = ConvertConv2D; op_registry_["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; op_registry_["Relu"] = ConvertActivation; op_registry_["MaxPool"] = ConvertPool; op_registry_["AvgPool"] = ConvertPool; - // This could be really handled as ConvertBinary op_registry_["BiasAdd"] = ConvertScale; op_registry_["Const"] = ConvertConst; // TODO(ben,jie): this is a temp hack. @@ -2124,576 +2610,350 @@ void Converter::register_op_converters() { op_registry_["Add"] = ConvertBinary; op_registry_["Mul"] = ConvertBinary; op_registry_["Sub"] = ConvertBinary; - op_registry_["Rsqrt"] = ConvertUnary; - op_registry_["Mean"] = ConvertReduce; op_registry_["Pad"] = ConvertPad; - // TODO(ben,jie): Add more ops op_registry_["ConcatV2"] = ConvertConcat; - op_registry_["MatMul"] = ConvertMatMul; - op_registry_["Reshape"] = ConvertReshape; op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm; op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm; + op_registry_["Div"] = ConvertBinary; + op_registry_["RealDiv"] = ConvertBinary; + + op_registry_["Rsqrt"] = ConvertUnary; + op_registry_["Reciprocal"] = ConvertUnary; + op_registry_["Exp"] = ConvertUnary; + op_registry_["Log"] = ConvertUnary; + op_registry_["Sqrt"] = ConvertUnary; + op_registry_["Abs"] = ConvertUnary; + op_registry_["Neg"] = ConvertUnary; +#if NV_TENSORRT_MAJOR == 3 + op_registry_["Mean"] = ConvertReducePool; +#endif +#if NV_TENSORRT_MAJOR > 3 + op_registry_["Sum"] = ConvertReduce; + op_registry_["Prod"] = ConvertReduce; + op_registry_["Max"] = ConvertReduce; + op_registry_["Min"] = ConvertReduce; + op_registry_["Mean"] = ConvertReduce; + op_registry_["Maximum"] = ConvertBinary; + op_registry_["Minimum"] = ConvertBinary; + op_registry_["Softmax"] = ConvertSoftmax; + op_registry_["MatMul"] = ConvertMatMul; + op_registry_["BatchMatMul"] = ConvertBatchMatMul; + op_registry_["TopKV2"] = ConvertTopK; +#endif + plugin_converter_ = ConvertPlugin; } } // namespace -tensorflow::Status ConvertCalibrationNodeToEngineNode( - tensorflow::Graph& graph, tensorflow::Node* c_node) { - const auto ndef = c_node->def(); - - TFAttrs attrs(ndef); - std::vector segment_nodes( - attrs.get>("segment_nodes")); - std::vector output_nodes( - attrs.get>("segment_output_names")); - std::vector input_names( - attrs.get>("input_names")); - string res_name = attrs.get("resource_name"); - VLOG(1) << "Node name " << c_node->name() << " res_name " << res_name; - string engine_name = "my_trt_op"; - { - const auto node_id = tensorflow::str_util::Split(res_name, "_"); - engine_name += node_id.back(); - } - std::map node_maps; - - for (auto n : graph.op_nodes()) { - node_maps.insert({n->name(), n}); - } - std::set subgraph_ids; - for (const auto internal_node : segment_nodes) { - subgraph_ids.insert(node_maps.at(internal_node)->id()); - } - if (VLOG_IS_ON(2)) { - string node_names = StrCat(c_node->name(), " segment nodes= "); - - for (const auto& node_name : segment_nodes) { - StrAppend(&node_names, node_name, ", "); - } - VLOG(2) << node_names; - } - - VLOG(1) << "Output Nodes:"; - std::vector out_types; - std::vector out_edges; - - for (auto& i : output_nodes) { - auto node_port = tensorflow::str_util::Split(i, ":"); - VLOG(1) << " " << i << " in graph " << node_maps.count(i); - auto out_node_name = node_port.at(0); - if (node_port.size() > 1) { - VLOG(1) << "Multi port output" << node_port.at(0) << " " - << node_port.at(1) << " size=" << node_port.size(); - } - auto node_it = node_maps.find(out_node_name); - if (node_it != node_maps.end()) { - tensorflow::Node* out_node = node_it->second; - int port = 0; - if (node_port.size() == 2) { - port = std::strtoul(node_port.at(1).c_str(), nullptr, 10); - out_types.push_back(out_node->output_type(port)); - } else { - out_types.push_back(out_node->output_type(0)); - } - for (auto out_edge : out_node->out_edges()) { - if (subgraph_ids.count(out_edge->dst()->id())) - continue; // skip internal edges; - if (out_edge->src_output() == port) { - out_edges.push_back(out_edge); - VLOG(1) << "OUTPUT EDGE " << out_edge->src()->name() << ":" - << out_edge->src_output() << " -> " << out_edge->dst()->name() - << ":" << out_edge->dst_input(); - } - } - } else { - LOG(WARNING) << " couldn't find output node " << out_node_name; - } - } - if (VLOG_IS_ON(1)) { - VLOG(1) << c_node->name() << " Input Nodes:"; - for (auto& i : input_names) { - VLOG(1) << " Input " << i << " in graph " << node_maps.count(i); - } - } - auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance(); - auto resmgr = trt_rm->getManager("TRTCalibOps"); - tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr; - auto status = resmgr->Lookup(res_name, res_name, &calib_res); - if (!status.ok() || !calib_res->calibrator_) { - return tensorflow::errors::FailedPrecondition( - "You must run calibration" - " and inference conversion in the same process"); +tensorflow::Status ConvertGraphDefToEngine( + const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size, + size_t max_workspace_size_bytes, + const std::vector& input_shapes, + Logger* logger, nvinfer1::IGpuAllocator* allocator, + TRTInt8Calibrator* calibrator, + TrtUniquePtrType* engine, + bool* convert_successfully) { + engine->reset(); + if (convert_successfully) *convert_successfully = false; + + // Create the builder. + TrtUniquePtrType builder( + nvinfer1::createInferBuilder(*logger)); + builder->setMaxBatchSize(max_batch_size); + // TODO(aaroey): use the allocator to allocate the TRT workspace. + builder->setMaxWorkspaceSize(max_workspace_size_bytes); +#if NV_TENSORRT_MAJOR > 3 + builder->setGpuAllocator(allocator); +#endif + if (precision_mode == FP16MODE) { + builder->setHalf2Mode(true); + } else if (precision_mode == INT8MODE) { + builder->setInt8Mode(true); + builder->setInt8Calibrator(calibrator); } - calib_res->calibrator_->setDone(); - calib_res->thr_->join(); - delete calib_res->thr_; - if (!calib_res->engine_) { - LOG(ERROR) << "Calibration failed!, engine does not exist. Did you run " - "calibration graph?"; - return tensorflow::errors::FailedPrecondition( - "Calibration graph needs to be executed on" - " calibration data before convertsion to inference graph"); - } - auto weight_rmgr = trt_rm->getManager("WeightStore"); - TF_CHECK_OK(weight_rmgr->Delete( - res_name, res_name)); - auto engine_plan = calib_res->engine_->serialize(); - calib_res->engine_->destroy(); - calib_res->network_->destroy(); - calib_res->builder_->destroy(); - calib_res->thr_ = nullptr; - calib_res->engine_ = nullptr; - calib_res->builder_ = nullptr; - tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp"); - std::vector income_edges; - income_edges.resize(c_node->num_inputs()); - for (const auto in_edge : c_node->in_edges()) { - auto src = in_edge->src(); - int dest_port = in_edge->dst_input(); - VLOG(1) << "Incoming connection " << src->name() << ":" - << in_edge->src_output() << " -> " << c_node->name() << ":" - << dest_port; - income_edges.at(dest_port) = {src->name(), in_edge->src_output(), - c_node->input_type(dest_port)}; - } - tensorflow::gtl::ArraySlice input_list( - income_edges); - if (VLOG_IS_ON(2)) { - for (const auto& inp : input_list) { - VLOG(2) << " Input from inputlist " << inp.node << ":" << inp.index << " " - << tensorflow::DataTypeString(inp.data_type); - } - } - op_builder.Input(input_list); - tensorflow::NodeDef engine_node; - const char* engine_plan_data = static_cast(engine_plan->data()); - string engine_plan_string(engine_plan_data, - engine_plan_data + engine_plan->size()); - status = op_builder.Attr("serialized_engine", engine_plan_string) - .Attr("input_nodes", input_names) - .Attr("output_nodes", output_nodes) - .Attr("OutT", out_types) - .Finalize(&engine_node); - if (!status.ok()) { - LOG(ERROR) << "Engine Node creation failed"; - return status; - } - auto trt_engine_node = graph.AddNode(engine_node, &status); - TF_RETURN_IF_ERROR(status); - std::map port_map; - for (size_t t = 0; t < output_nodes.size(); t++) { - port_map.insert({output_nodes.at(t), t}); - } - for (auto& i : out_edges) { - string s(i->src()->name()); - if (i->src_output()) StrAppend(&s, ":", i->src_output()); - int out_port = port_map.at(s); - VLOG(1) << "Connecting " << trt_engine_node->name() << ":" << out_port - << " -> " << i->dst()->name() << ":" << i->dst_input(); - TF_RETURN_IF_ERROR( - graph.UpdateEdge(trt_engine_node, out_port, i->dst(), i->dst_input())); - } - for (const auto ed : trt_engine_node->in_edges()) { - VLOG(1) << "In Edge " << ed->src()->name() << ":" << ed->src_output() - << " -> " << ed->dst()->name() << ":" << ed->dst_input(); - } - for (const auto ed : trt_engine_node->out_edges()) { - VLOG(1) << "Out Edge " << ed->src()->name() << ":" << ed->src_output() - << " -> " << ed->dst()->name() << ":" << ed->dst_input(); - } - VLOG(1) << "Segment nodes:"; - for (auto& i : segment_nodes) { - VLOG(1) << " " << i << " in graph " << node_maps.count(i); - auto it = node_maps.find(i); - if (it != node_maps.end()) { - graph.RemoveNode(it->second); - } - } - graph.RemoveNode(c_node); - return tensorflow::Status::OK(); -} - -tensorflow::Status ReverseTopologicalSort( - const tensorrt::convert::SubGraphParams& s, - std::list* order) { - std::vector order_vec; - tensorflow::GetPostOrder(s.graph, &order_vec); - // Select just the subgraph - for (tensorflow::Node* node : order_vec) { - if (s.subgraph_node_ids.count(node->id())) { - // We want topological order to contstruct the - // network layer by layer - order->push_front(node); - } + // Create the network. + auto trt_network = + TrtUniquePtrType(builder->createNetwork()); + if (!trt_network) { + return tensorflow::errors::Internal( + "Failed to create TensorRT network object"); } - return tensorflow::Status::OK(); -} + auto ws = std::unique_ptr(new TRTWeightStore()); -tensorflow::Status SetInputList( - const tensorrt::convert::SubGraphParams& s, - tensorflow::NodeDefBuilder* op_builder, - const std::vector* input_names, - std::vector* input_dtypes) { - std::vector income_edges; - VLOG(2) << "input edge size: " << input_names->size(); - for (size_t i = 0; i < input_names->size(); ++i) { - VLOG(2) << "input edges: " << i << " " << input_names->at(i); - int output_idx = s.input_inds.at(i).second; - // we wired up the input here already, it is redundant to do it again in - // ConvertSubGraphToTensorRT(convert_graph.cc) - auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut( - input_names->at(i), output_idx, input_dtypes->at(i)); - income_edges.push_back(incoming_edge); - } - tensorflow::gtl::ArraySlice input_list( - income_edges); - op_builder->Input(input_list); - return tensorflow::Status::OK(); -} + // Build the network + VLOG(1) << "Starting engine conversion "; + Converter converter(trt_network.get(), ws.get(), precision_mode == FP16MODE); + std::vector> output_tensors; + // Graph nodes are already topologically sorted during construction + for (const auto& node_def : gdef.node()) { + string node_name = node_def.name(); + VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op(); + if (tensorflow::str_util::StartsWith(node_name, kInputPHName) && + (node_def.op() == "Placeholder")) { + nvinfer1::DimsCHW input_dim_pseudo_chw; + for (int i = 0; i < 8; i++) input_dim_pseudo_chw.d[i] = 0; + int32 slot_number = -1; + if (!tensorflow::strings::safe_strto32( + node_name.c_str() + strlen(kInputPHName), &slot_number)) { + return tensorflow::errors::InvalidArgument( + "Failed to parse slot number from ", node_name); + } + nvinfer1::DataType dtype; + auto shape = input_shapes.at(slot_number); + auto status = ValidateInputProperties( + shape, node_def.attr().at("dtype").type(), &dtype); + if (!status.ok()) { + const string error_message = + StrCat("Validation failed for ", node_name, " and input slot ", + slot_number, ": ", status.error_message()); + LOG(WARNING) << error_message; + return Status(status.code(), error_message); + } + if (VLOG_IS_ON(1)) { + string dim_str("dims="); + StrAppend(&dim_str, "[ ", shape.dim_size(0)); + for (int i = 1; i < shape.dims(); i++) { + StrAppend(&dim_str, ", ", shape.dim_size(i)); + } + StrAppend(&dim_str, " ]"); + VLOG(1) << dim_str; + } + for (int i = 1; i < shape.dims(); i++) { + input_dim_pseudo_chw.d[i - 1] = shape.dim_size(i); + } -string SubgraphNameScopeGenerator(const std::list* order) { - string subgraph_name_scope; - if (!order->empty()) { - subgraph_name_scope = order->front()->name(); - } - for (const tensorflow::Node* node : *order) { - subgraph_name_scope = GetCommonNameScope(subgraph_name_scope, node->name()); + input_dim_pseudo_chw.nbDims = shape.dims() - 1; + nvinfer1::ITensor* input_tensor = converter.network()->addInput( + node_name.c_str(), dtype, input_dim_pseudo_chw); + if (!input_tensor) { + return tensorflow::errors::InvalidArgument( + "Failed to create Input layer tensor ", node_name, + " rank=", shape.dims() - 1); + } + VLOG(1) << "Input tensor name :" << node_name; + if (!converter.insert_input_tensor(node_name, input_tensor)) { + return tensorflow::errors::AlreadyExists( + "Output tensor already exists for op: " + node_name); + } + } else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) && + (node_def.op() == "Identity")) { + int32 slot_number = -1; + if (!tensorflow::strings::safe_strto32( + node_name.c_str() + strlen(kOutputPHName), &slot_number)) { + return tensorflow::errors::InvalidArgument( + "Failed to parse slot number from ", node_name); + } + if (output_tensors.size() <= slot_number) { + output_tensors.resize(slot_number + 1); + } + output_tensors.at(slot_number) = {node_def.input(0), node_name}; + } else { + VLOG(2) << "Converting node: " << node_def.name() << " , " + << node_def.op(); + TF_RETURN_IF_ERROR(converter.convert_node(node_def)); + } } - // TODO(sami,ben,jie): proper naming! - return subgraph_name_scope; -} - -tensorflow::Status ConvertSubgraph( - Converter& converter, tensorrt::convert::SubGraphParams& s, - std::list* order, std::vector* input_names, - std::vector* input_dtypes, - std::vector* output_names, - std::vector* output_dtypes, - const string& engine_name) { - std::set added_tensors; - for (const std::pair& input : s.input_inds) { - VLOG(2) << "parsing input. Node id= " << input.first; - int node_id = input.first; - int output_idx = input.second; - tensorflow::Node* node = s.graph.FindNodeId(node_id); - auto node_name = node->name(); - // input_names should use the node name in the graph - // here it should be the input tensor name -> matching the binding - // insert original node name without port - auto tensor_name = node_name; - if (output_idx != 0) { - tensor_name = StrCat(tensor_name, ":", output_idx); - } - - VLOG(2) << "input name: " << node_name << " tensor_name: " << tensor_name - << " idx: " << output_idx; - - auto shape_inference_node_name = node_name; - auto shape_inference_output_idx = output_idx; - // rewire the shape inference to original node in the graph - if (s.output_edge_map->count(tensor_name)) { - shape_inference_node_name = s.output_edge_map->at(tensor_name).second; - shape_inference_output_idx = s.output_edge_map->at(tensor_name).first; - } - if (shape_inference_output_idx < 0) continue; - VLOG(2) << "shapeinference name: " << shape_inference_node_name - << " idx: " << shape_inference_output_idx; - - if (!s.graph_properties.HasOutputProperties(shape_inference_node_name)) - return tensorflow::errors::Internal("failed to find input node: " + - shape_inference_node_name); - - auto op_info_vec = - s.graph_properties.GetOutputProperties(shape_inference_node_name); - if (static_cast(op_info_vec.size()) <= shape_inference_output_idx) - return tensorflow::errors::Internal( - "accessing output index of: ", shape_inference_output_idx, - ", at node: ", shape_inference_node_name, - " with output entry from shape_map: ", op_info_vec.size()); - - auto op_info = op_info_vec.at(shape_inference_output_idx); - tensorflow::DataType tf_dtype = op_info.dtype(); - - nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); - auto type_status = ConvertDType(tf_dtype, &dtype); - if (type_status != tensorflow::Status::OK()) { - LOG(WARNING) << "Type conversion failed for " << node_name; - return type_status; - } - - VLOG(2) << "Accessing output index of: " << output_idx - << ", at node: " << node_name - << " with output entry from shape_map: " << op_info_vec.size(); - // TODO(ben,jie): update TRT input format/dimension - nvinfer1::DimsCHW input_dim_pseudo_chw; - for (int i = 0; i < 3; i++) input_dim_pseudo_chw.d[i] = 1; - - // TODO(jie): TRT 3.x only support 4 dimensional input tensor. - // update the code once TRT 4.0 comes out. - if (op_info.shape().dim_size() != 4) { - string err_str = "Require 4 dimensional input."; - StrAppend(&err_str, " Got ", op_info.shape().dim_size(), " ", - shape_inference_node_name); - return tensorflow::errors::Unimplemented(err_str); - } - - for (int i = 1; i < op_info.shape().dim_size(); i++) { - VLOG(2) << "dimension: " << i - << " , size: " << op_info.shape().dim(i).size(); - input_dim_pseudo_chw.d[i - 1] = op_info.shape().dim(i).size(); - } - - // TODO(ben,jie): proper way to restore input tensor name? - auto input_tensor_name = node_name; - if (output_idx != 0) { - input_tensor_name = StrCat(node_name, ":", output_idx); - } - if (added_tensors.count(input_tensor_name)) continue; - added_tensors.insert(input_tensor_name); - input_names->push_back(input_tensor_name); - input_dtypes->push_back(tf_dtype); - nvinfer1::ITensor* input_tensor = converter.network()->addInput( - input_tensor_name.c_str(), dtype, input_dim_pseudo_chw); - - if (!input_tensor) - return tensorflow::errors::InvalidArgument( - "Failed to create Input layer"); - VLOG(2) << "Input tensor name :" << input_tensor_name; - - if (!converter.insert_input_tensor(input_tensor_name, input_tensor)) - return tensorflow::errors::AlreadyExists( - "Output tensor already exists for op: " + input_tensor_name); - } - - for (const tensorflow::Node* node : *order) { - const tensorflow::NodeDef& node_def = node->def(); - VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op(); - TF_RETURN_IF_ERROR(converter.convert_node(node_def)); - } - - VLOG(2) << "Finished conversion"; - - // Gather output metadata - int trt_engine_op_output_idx = 0; - added_tensors.clear(); - for (const std::pair& output : s.output_inds) { - int node_id = output.first; - int output_idx = output.second; - tensorflow::Node* node = s.graph.FindNodeId(node_id); - string op_name = node->name(); - string tensor_name = op_name; - - s.output_edge_map->insert( - {trt_engine_op_output_idx == 0 - ? engine_name - : StrCat(engine_name, ":", trt_engine_op_output_idx), - {output_idx, tensor_name}}); - trt_engine_op_output_idx++; - if (output_idx != 0) - tensorflow::strings::StrAppend(&tensor_name, ":", output_idx); - VLOG(2) << "Output tensor name: " << tensor_name; - if (added_tensors.count(tensor_name)) continue; - added_tensors.insert(tensor_name); - output_names->push_back(tensor_name); - auto tensor_or_weights = converter.get_tensor(tensor_name); + for (const auto& output : output_tensors) { + auto tensor_or_weights = converter.get_tensor(output.first); if (!tensor_or_weights.is_tensor()) { - return tensorflow::errors::InvalidArgument("Output node '" + tensor_name + - "' is weights not tensor"); + return tensorflow::errors::InvalidArgument( + "Output node '" + output.first + "' is weights not tensor"); } nvinfer1::ITensor* tensor = tensor_or_weights.tensor(); + tensor->setName(output.second.c_str()); if (!tensor) { return tensorflow::errors::NotFound("Output tensor not found: " + - tensor_name); + output.first); } + VLOG(1) << "Marking output tensor " << output.first << ", as output tensor " + << output.second; + converter.network()->markOutput(*tensor); - tensorflow::DataType tf_dtype = node->output_type(output_idx); - output_dtypes->push_back(tf_dtype); - nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT; - TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype)); - tensor->setType(trt_dtype); } + if (convert_successfully) *convert_successfully = true; + // Build the engine. + VLOG(1) << "Starting engine creation"; + engine->reset(builder->buildCudaEngine(*converter.network())); + if (engine->get() == nullptr) { + return tensorflow::errors::Internal("Failed to build TensorRT engine"); + } + VLOG(1) << "Finished conversion"; return tensorflow::Status::OK(); } -tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) { - // Visit nodes in reverse topological order and construct the TRT network. - // Toposort - std::list order; - TF_RETURN_IF_ERROR(ReverseTopologicalSort(s, &order)); - - static int static_id = 0; - string subgraph_name_scope = SubgraphNameScopeGenerator(&order); - // TODO(sami,ben,jie): proper naming! - string calib_op_name = - StrCat(subgraph_name_scope, "my_trt_calib_op_", static_id); - string engine_name = StrCat(subgraph_name_scope, "my_trt_op", static_id); - static_id++; - - auto trt_rmgr = tensorflow::tensorrt::TRTResourceManager::instance(); - auto op_rmgr = trt_rmgr->getManager("TRTCalibOps"); - auto op_res = new tensorflow::tensorrt::TRTCalibrationResource(); - TF_CHECK_OK(op_rmgr->Create(calib_op_name, calib_op_name, op_res)); - op_res->logger_ = new tensorflow::tensorrt::Logger(); - cudaSetDevice(s.cuda_gpu_id_); - op_res->builder_ = nvinfer1::createInferBuilder(*(op_res->logger_)); - op_res->allocator_ = s.allocator_; -#if NV_TENSORRT_MAJOR > 3 - op_res->builder_->setGpuAllocator(s.allocator_.get()); -#endif - if (!op_res->builder_) { - return tensorflow::errors::Internal( - "failed to create TensorRT builder object"); - } - - op_res->network_ = op_res->builder_->createNetwork(); - if (!op_res->network_) { - return tensorflow::errors::Internal( - "failed to create TensorRT network object"); +tensorflow::Status ConvertSegmentToGraphDef( + const tensorflow::Graph* graph, + const tensorflow::grappler::GraphProperties& graph_properties, + const std::set& subgraph_node_names, + const std::vector& subgraph_node_ids, // In topological order + std::vector* connections, + tensorflow::GraphDef* segment_def, string* common_scope) { + std::set marker_nodes; + // Update connection shapes/data types and add corresponding input/output + // nodes in the segment graphdef. + for (size_t i = 0; i < connections->size(); ++i) { + auto& connection = connections->at(i); + if (connection.is_control_edge()) continue; + auto outside_node = graph->FindNodeId(connection.outside_id); + if (!outside_node) { + // This should never happen, unless the original graph is problematic. + return tensorflow::errors::NotFound( + "Cannot find node with id ", connection.outside_id, " in the graph."); + } + // Updates the shape and data types of input/output connections. + tensorflow::DataType dtype; + tensorflow::PartialTensorShape partial_shape; + if (connection.is_input_edge) { + GetInputProperties(graph_properties, + graph->FindNodeId(connection.outside_id), + connection.outside_port, &partial_shape, &dtype); + connection.outside_shape = partial_shape; + } else { + GetOutputProperties(graph_properties, + graph->FindNodeId(connection.outside_id), + connection.outside_port, &partial_shape, &dtype); + connection.inside_shape = partial_shape; + } + connection.connection_type = dtype; + + // Add dummy input/output nodes to the segment graphdef. + if (connection.is_input_edge) { + const string node_name = StrCat(kInputPHName, connection.port_number); + if (marker_nodes.count(node_name)) { + VLOG(1) << "Reusing input " << node_name << " for the edge " + << connection.outside_node_name << ":" + << connection.outside_port << " -> " + << connection.inside_node_name << ":" << connection.inside_port; + continue; + } + marker_nodes.insert(node_name); + auto seg_node = segment_def->add_node(); + tensorflow::NodeDefBuilder builder(node_name, "Placeholder"); + auto status = builder.Attr("shape", partial_shape) + .Attr("dtype", dtype) + .Finalize(seg_node); + VLOG(1) << "Constructing input " << node_name << " for the edge " + << connection.outside_node_name << ":" << connection.outside_port + << " -> " << connection.inside_node_name << ":" + << connection.inside_port; + } else { + const string node_name = StrCat(kOutputPHName, connection.port_number); + if (marker_nodes.count(node_name)) { + VLOG(1) << "Reusing output " << node_name << " for the edge " + << connection.inside_node_name << ":" << connection.inside_port + << " -> " << connection.outside_node_name << ":" + << connection.outside_port; + continue; + } + marker_nodes.insert(node_name); + auto seg_node = segment_def->add_node(); + tensorflow::NodeDefBuilder builder(node_name, "Identity"); + auto status = builder.Input(connection.inside_node_name, 0, dtype) + .Finalize(seg_node); + VLOG(1) << "Constructing output " << node_name << " for the edge " + << connection.inside_node_name << ":" << connection.inside_port + << " -> " << connection.outside_node_name << ":" + << connection.outside_port; + } + } // for each connection. + + std::unordered_map old_to_new_id_map; + // Copy internal nodes to new graphdef + string local_scope = graph->FindNodeId(*subgraph_node_ids.begin())->name(); + for (const auto node_id : subgraph_node_ids) { + const auto node = graph->FindNodeId(node_id); + local_scope = GetCommonNameScope(local_scope, node->name()); + old_to_new_id_map[node_id] = segment_def->node_size(); + auto snode = segment_def->add_node(); + snode->CopyFrom(node->def()); + VLOG(2) << "Copying " << snode->name() << " to subgraph"; + } + // Update the inputs of the new input nodes to point to placeholder nodes. + for (int i = 0; i < connections->size(); ++i) { + auto& connection = connections->at(i); + if (connection.is_control_edge() || !connection.is_input_edge) continue; + auto snode = + segment_def->mutable_node(old_to_new_id_map[connection.inside_id]); + const string placeholder_name = + StrCat(kInputPHName, connection.port_number); + VLOG(1) << "Updating " << snode->name() << ":" << connection.inside_port + << " from " << snode->input(connection.inside_port) << " to " + << placeholder_name; + snode->set_input(connection.inside_port, placeholder_name); + } + // Remove control inputs that are not inside the segment. + for (int i = 0; i < segment_def->node_size(); ++i) { + auto snode = segment_def->mutable_node(i); + const int input_size = snode->input_size(); + int input_idx = 0; + int actual_input_idx = 0; + while (input_idx < input_size) { + TensorId input = ParseTensorName(snode->input(input_idx)); + if (!subgraph_node_names.count( + string(input.first.data(), input.first.size())) && + !str_util::StartsWith(input.first, kInputPHName)) { + if (input.second == Graph::kControlSlot) { + VLOG(1) << "... removing control inputs " << input.first + << " from subgraph."; + ++input_idx; + continue; + } else { + return tensorflow::errors::InvalidArgument( + "Found non control input outside the segment that is not an " + "engine connection to ", + snode->name(), ": ", input.first); + } + } + if (actual_input_idx != input_idx) { + snode->set_input(actual_input_idx, snode->input(input_idx)); + } + ++input_idx; + ++actual_input_idx; + } + for (int remove = input_size - actual_input_idx; remove > 0; --remove) { + snode->mutable_input()->RemoveLast(); + } } - - // Build the network - auto weight_rmgr = trt_rmgr->getManager("WeightStore"); - auto ws = new tensorflow::tensorrt::TRTWeightStore(); - TF_CHECK_OK(weight_rmgr->Create(calib_op_name, calib_op_name, ws)); - Converter converter(op_res->network_, ws, s.precision_mode == FP16MODE); - - std::vector input_names; - std::vector input_dtypes; - std::vector output_names; - std::vector output_dtypes; - TF_RETURN_IF_ERROR(ConvertSubgraph(converter, s, &order, &input_names, - &input_dtypes, &output_names, - &output_dtypes, engine_name)); - - VLOG(2) << "Finished processing outputs"; - - // Build the engine - op_res->builder_->setMaxBatchSize(s.max_batch_size); - op_res->builder_->setMaxWorkspaceSize(s.max_workspace_size_bytes); - VLOG(0) << "Max batch size= " << s.max_batch_size - << " max workspace size= " << s.max_workspace_size_bytes; - - // Build the TRT op - // TODO(sami,ben,jie): proper naming! - tensorflow::NodeDefBuilder op_builder(calib_op_name, "TRTCalibOp"); - TF_RETURN_IF_ERROR(SetInputList(s, &op_builder, &input_names, &input_dtypes)); - - std::vector segment_names; - segment_names.reserve(s.subgraph_node_ids.size()); - for (int i : s.subgraph_node_ids) { - auto node = s.graph.FindNodeId(i); - segment_names.push_back(node->name()); - } - LOG(INFO) << "finished op preparation"; - - auto status = op_builder.Attr("segment_nodes", segment_names) - .Attr("input_names", input_names) - .Attr("segment_output_names", output_names) - .Attr("resource_name", calib_op_name) - .Finalize(s.trt_node); - - LOG(INFO) << status.ToString(); - LOG(INFO) << "finished op building"; - + *common_scope = local_scope; + VLOG(0) << "Segment @scope '" << local_scope << "', converted to graph"; return tensorflow::Status::OK(); } -tensorflow::Status ConvertSubGraphToTensorRTNodeDef( - tensorrt::convert::SubGraphParams& s) { - // Visit nodes in reverse topological order and construct the TRT network. - std::list order; - TF_RETURN_IF_ERROR(ReverseTopologicalSort(s, &order)); - - static int static_id = 0; - string subgraph_name_scope = SubgraphNameScopeGenerator(&order); - string engine_name = StrCat(subgraph_name_scope, "my_trt_op", static_id++); - - tensorflow::tensorrt::Logger trt_logger; - cudaSetDevice(s.cuda_gpu_id_); - auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger)); - if (!trt_builder) { - return tensorflow::errors::Internal( - "Failed to create TensorRT builder object"); +bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const { + if (in_edge->IsControlEdge()) return true; + PartialTensorShape shape; + tensorflow::DataType dtype; + GetInputProperties(graph_properties_, in_edge->src(), in_edge->src_output(), + &shape, &dtype); + nvinfer1::DataType trt_dtype; + Status status = ValidateInputProperties(shape, dtype, &trt_dtype); + if (!status.ok()) { + VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name() + << ": " << status; + return false; } -#if NV_TENSORRT_MAJOR > 3 - trt_builder->setGpuAllocator(s.allocator_.get()); -#endif - auto trt_network = infer_object(trt_builder->createNetwork()); - if (!trt_network) { - return tensorflow::errors::Internal( - "Failed to create TensorRT network object"); + if (shape.dims() < 3 && in_edge->src()->type_string() != "Const") { + VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name() + << " which has an input at port " << in_edge->dst_input() + << " with #dim<3 and is not a const: " << shape; + return false; } + return true; +} - auto trt_rmgr = tensorflow::tensorrt::TRTResourceManager::instance(); - auto weight_rmgr = trt_rmgr->getManager("WeightStore"); - auto ws = new tensorflow::tensorrt::TRTWeightStore(); - TF_CHECK_OK(weight_rmgr->Create(engine_name, engine_name, ws)); - - // Build the network - Converter converter(trt_network.get(), ws, s.precision_mode == FP16MODE); - - std::vector input_names; - std::vector input_dtypes; - std::vector output_names; - std::vector output_dtypes; - TF_RETURN_IF_ERROR(ConvertSubgraph(converter, s, &order, &input_names, - &input_dtypes, &output_names, - &output_dtypes, engine_name)); - - VLOG(2) << "Finished output"; - - // Build the engine - trt_builder->setMaxBatchSize(s.max_batch_size); - trt_builder->setMaxWorkspaceSize(s.max_workspace_size_bytes); - VLOG(0) << "Max batch size= " << s.max_batch_size - << " max workspace size= " << s.max_workspace_size_bytes; - if (s.precision_mode == FP16MODE) { - trt_builder->setHalf2Mode(true); - VLOG(0) << "Using FP16 precision mode"; - } - LOG(INFO) << "starting build engine"; - string engine_plan_string; - { - auto trt_engine = - infer_object(trt_builder->buildCudaEngine(*converter.network())); - VLOG(0) << "Built network"; - if (trt_engine.get() == nullptr) { - return tensorflow::errors::Internal("Engine building failure"); - } - auto engine_plan = infer_object(trt_engine->serialize()); - VLOG(0) << "Serialized engine"; - const char* engine_plan_data = - static_cast(engine_plan->data()); - engine_plan_string = - string(engine_plan_data, engine_plan_data + engine_plan->size()); - } - TF_RETURN_IF_ERROR(weight_rmgr->Delete( - engine_name, engine_name)); - LOG(INFO) << "finished engine " << engine_name << " containing " - << s.subgraph_node_ids.size() << " nodes"; - - // Build the TRT op - tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp"); - TF_RETURN_IF_ERROR(SetInputList(s, &op_builder, &input_names, &input_dtypes)); - - VLOG(0) << "Finished op preparation"; - - auto status = op_builder.Attr("serialized_engine", engine_plan_string) - .Attr("input_nodes", input_names) - .Attr("output_nodes", output_names) - .Attr("OutT", output_dtypes) - .Device(s.device_name_) - .Finalize(s.trt_node); - - VLOG(0) << status.ToString() << " finished op building for " << engine_name - << " on device " << s.device_name_; - - return tensorflow::Status::OK(); +bool OutputEdgeValidator::operator()(const tensorflow::Edge* out_edge) const { + if (out_edge->IsControlEdge()) return true; + if (out_edge->src()->type_string() == "Const") { + VLOG(1) << "--> Need to remove output node " << out_edge->src()->name() + << " which is a Const."; + return false; + } + return true; } } // namespace convert diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 3f6592cd25ff013cadc0621ba64f0553983dd10b..a60253740fe0b27dcd9c20618d6d05aa7001a1a1 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -22,69 +22,153 @@ limitations under the License. #include #include +#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" +#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/lib/core/status.h" + #if GOOGLE_CUDA #if GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { +static const char* kInputPHName = "TensorRTInputPH_"; +static const char* kOutputPHName = "TensorRTOutputPH_"; namespace convert { -const int FP32MODE = 0; -const int FP16MODE = 1; -const int INT8MODE = 2; - -struct SubGraphParams { - SubGraphParams( - tensorflow::Graph& inp_graph, - const std::set& subgraph_node_id_numbers, - const std::vector>& input_indices, - const std::vector>& output_indices, - size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes, - const tensorflow::grappler::GraphProperties& current_graph_properties, - std::unordered_map>* output_edges, - tensorflow::NodeDef* constructed_trt_node, - int engine_precision_mode = FP32MODE, const string& device_name = "", - std::shared_ptr allocator = nullptr, - int cuda_gpu_id = 0) - : graph(inp_graph), - subgraph_node_ids(subgraph_node_id_numbers), - input_inds(input_indices), - output_inds(output_indices), - max_batch_size(max_supported_batch_size), - max_workspace_size_bytes(max_consumed_workspace_size_bytes), - graph_properties(current_graph_properties), - output_edge_map(output_edges), - trt_node(constructed_trt_node), - precision_mode(engine_precision_mode), - device_name_(device_name), - allocator_(allocator), - cuda_gpu_id_(cuda_gpu_id) {} - - tensorflow::Graph& graph; - const std::set& subgraph_node_ids; - const std::vector>& input_inds; // {node_id, output_idx} - const std::vector>& output_inds; // {node_id, output_idx} - size_t max_batch_size; - size_t max_workspace_size_bytes; - const tensorflow::grappler::GraphProperties& graph_properties; - std::unordered_map>* output_edge_map; - tensorflow::NodeDef* trt_node; - const int precision_mode; - const string device_name_; - std::shared_ptr allocator_; - const int cuda_gpu_id_; +struct EngineConnection { + // Constructs a non-control edge. + EngineConnection(const string& outside, int out_id, int out_port, + const string& inside, int in_id, int in_port, + bool input_edge, int port) + : outside_node_name(outside), + outside_id(out_id), + outside_port(out_port), + inside_node_name(inside), + inside_id(in_id), + inside_port(in_port), + is_input_edge(input_edge), + port_number(port) {} + + // Constructs a control edge. + EngineConnection(const string& outside, int out_id, const string& inside, + int in_id, bool input_edge) + : outside_node_name(outside), + outside_id(out_id), + outside_port(Graph::kControlSlot), + inside_node_name(inside), + inside_id(in_id), + inside_port(Graph::kControlSlot), + is_input_edge(input_edge), + port_number(Graph::kControlSlot) {} + + bool is_control_edge() const { return port_number == Graph::kControlSlot; } + + const string outside_node_name; + const int outside_id; + const int outside_port; + tensorflow::PartialTensorShape outside_shape; // Only set for input edge. + + const string inside_node_name; + const int inside_id; + const int inside_port; + tensorflow::PartialTensorShape inside_shape; // Only set for output edge. + + tensorflow::DataType connection_type; + const bool is_input_edge; + + // The port number of the TRT node connected with this edge. + const int port_number; +}; + +struct EngineInfo { + EngineInfo() + : engine_type(EngineType::TRTStatic), + max_workspace_size_bytes(0), + precision_mode(FP32MODE) {} + + string engine_name; + string device; + tensorflow::GraphDef segment_graph_def; + + // Non-control input connections inside this vector are sorted in a way such + // that, the segment nodes connecting to them are topological sorted. + // In addition, for non-control connections, there must be no duplicates. + std::vector connections; + + enum class EngineType { TRTStatic = 0, TRTDynamic = 1 }; + EngineType engine_type; + int64 max_workspace_size_bytes; + int maximum_cached_engines; + std::vector cached_engine_batches; + int precision_mode; +}; + +// Constructs a graphdef from the segment in the given graph. Adds placeholder +// nodes for input edges (InputPH_*) and identity nodes for output edges +// (OutputPH_*). This function needs to be called before TensorRT nodes +// inserted in order to correctly get sizes from the original graph. +// +// - subgraph_node_names: the node names of the subgraph. +// - subgraph_node_ids: the node ids of the subgraph, must be sorted in +// topological order. +// - segment_def: the output GraphDef, whose non-input/output nodedefs will be +// sorted in topological order. +// +// TODO(aaroey): add tests to validate these properties. +tensorflow::Status ConvertSegmentToGraphDef( + const tensorflow::Graph* graph, + const tensorflow::grappler::GraphProperties& graph_properties, + const std::set& subgraph_node_names, + const std::vector& subgraph_node_ids, + std::vector* connections, + tensorflow::GraphDef* segment_def, string* common_scope); + +// Converts given subgraph to a TRT engine saved in 'engine'. Returns ok iff +// 'builder' successfully build the engine. If the result is not ok, 'engine' +// will be set to nullptr +// Once returned, 'builder' is not needed any more and can be safely detroyed. +// +// - convert_successfully: indicates whether the converson to TensorRT network +// is successful. This is different than successfully building the engine: +// building can still fail afterwards. +tensorflow::Status ConvertGraphDefToEngine( + const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size, + size_t max_workspace_size_bytes, + const std::vector& input_shapes, + Logger* logger, nvinfer1::IGpuAllocator* allocator, + TRTInt8Calibrator* calibrator, + TrtUniquePtrType* engine, + bool* convert_successfully); + +// Helper class for the segmenter to determine whether an input edge to the TRT +// segment is valid. +class InputEdgeValidator { + public: + InputEdgeValidator(const grappler::GraphProperties& graph_properties) + : graph_properties_(graph_properties) {} + + // Return true if the specified edge is eligible to be an input edge of the + // TRT segment. + bool operator()(const tensorflow::Edge* in_edge) const; + + private: + const grappler::GraphProperties& graph_properties_; +}; + +// Helper class for the segmenter to determine whether an output edge from the +// TRT segment is valid. +class OutputEdgeValidator { + public: + // Return true if the specified edge is eligible to be an output edge of the + // TRT segment. + bool operator()(const tensorflow::Edge* out_edge) const; }; -// TODO(sami): Replace references with const reference or pointers -tensorflow::Status ConvertSubGraphToTensorRTNodeDef(SubGraphParams& params); -tensorflow::Status InjectCalibrationNode(SubGraphParams& params); -tensorflow::Status ConvertCalibrationNodeToEngineNode(tensorflow::Graph& graph, - tensorflow::Node* c_node); } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index 8f634b1f74717310a69a6bab5d5224c9bdbf10cc..f33f2cc4d68f5ac10eafeb744f8162bfca0abfab 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -17,9 +17,11 @@ limitations under the License. #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stacktrace.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -45,8 +47,24 @@ tensorflow::Status TRTOptimizationPass::Init( if (params.count("max_batch_size")) { maximum_batch_size_ = params.at("max_batch_size").i(); } - if (params.count("max_workspace_size_bytes")) + is_dynamic_op_ = false; + if (params.count("is_dynamic_op")) { + is_dynamic_op_ = params.at("is_dynamic_op").b(); + } + if (params.count("cached_engine_batches")) { + auto batch_vec = params.at("cached_engine_batches").list(); + batches_.reserve(batch_vec.i_size()); + for (const auto i : batch_vec.i()) { + batches_.push_back(i); + } + } + max_cached_batches_ = 1; + if (params.count("maximum_cached_engines")) { + max_cached_batches_ = params.at("maximum_cached_engines").i(); + } + if (params.count("max_workspace_size_bytes")) { maximum_workspace_size_ = params.at("max_workspace_size_bytes").i(); + } if (params.count("precision_mode")) { string pm = Uppercase(params.at("precision_mode").s()); if (pm == "FP32") { @@ -172,7 +190,19 @@ tensorflow::Status TRTOptimizationPass::Optimize( tensorflow::grappler::Cluster* cluster, const tensorflow::grappler::GrapplerItem& item, GraphDef* optimized_graph) { VLOG(1) << "Called TRTOptimization Pass " << name_; + // This is a hack to workaround optimizer issue. MetaOptimizer calls + // optimization passes on function objects as well, we should not modify + // generated funcdefs! This is fragile but we don't have any other option + // until framework fixes it. + if (item.id != "tf_graph") { + LOG(WARNING) << name_ + << " is probably called on funcdef! This optimizer must *NOT* " + "be called on function objects."; + *optimized_graph = item.graph; + return tensorflow::Status::OK(); + } if (VLOG_IS_ON(1)) { + VLOG(2) << CurrentStackTrace(); PrintDebugInfo(cluster, item); } int max_dim = -1; @@ -204,11 +234,39 @@ tensorflow::Status TRTOptimizationPass::Optimize( } tensorflow::grappler::GraphProperties static_graph_properties(item); TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); - auto status = tensorflow::tensorrt::convert::ConvertAfterShapes( - item.graph, item.fetch, maximum_batch_size_, maximum_workspace_size_, - optimized_graph, precision_mode_, minimum_segment_size_, - static_graph_properties, cluster); + tensorflow::tensorrt::convert::ConversionParams cp; + + std::vector nodes_to_preserve; + for (const auto& n : item.NodesToPreserve()) { + auto tokens = str_util::Split(n, ":"); + string s = tokens.at(0); + for (int i = 1; i < tokens.size() - 1; ++i) { + StrAppend(&s, ":", tokens.at(i)); + } + int dumm_port = -1; + // If the last token is not an integer, it must be part of the name. + // Otherwise it is port number. + if (tokens.size() > 1 && + !strings::safe_strto32(tokens.back(), &dumm_port)) { + StrAppend(&s, ":", tokens.back()); + } + nodes_to_preserve.push_back(s); + } + cp.input_graph_def = &item.graph; + cp.output_names = &nodes_to_preserve; + cp.max_batch_size = maximum_batch_size_; + cp.max_workspace_size_bytes = maximum_workspace_size_; + cp.output_graph_def = optimized_graph; + cp.precision_mode = precision_mode_; + cp.minimum_segment_size = minimum_segment_size_; + cp.graph_properties = &static_graph_properties; + cp.cluster = cluster; + cp.is_dyn_op = is_dynamic_op_; + cp.cached_engine_batches = batches_; + cp.max_cached_engines = max_cached_batches_; + auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(cp); VLOG(2) << optimized_graph->DebugString(); + VLOG(1) << "Returning from " << name_; return status; } diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h index d8ecead23efaa5c3bab95b8ba481e2307b0af772..463ed3883e4808408104c618a289989472c497ea 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h @@ -61,6 +61,9 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { int minimum_segment_size_; int precision_mode_; int maximum_batch_size_; + bool is_dynamic_op_; + std::vector batches_; + int max_cached_batches_; int64_t maximum_workspace_size_; }; diff --git a/tensorflow/contrib/tensorrt/convert/utils.cc b/tensorflow/contrib/tensorrt/convert/utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..e7a1febb8c076891596741fe30721e7acca15a73 --- /dev/null +++ b/tensorflow/contrib/tensorrt/convert/utils.cc @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/convert/utils.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace tensorrt { + +bool IsGoogleTensorRTEnabled() { + // TODO(laigd): consider also checking if tensorrt shared libraries are + // accessible. We can then direct users to this function to make sure they can + // safely write code that uses tensorrt conditionally. E.g. if it does not + // check for for tensorrt, and user mistakenly uses tensorrt, they will just + // crash and burn. +#if GOOGLE_CUDA && GOOGLE_TENSORRT + return true; +#else + return false; +#endif +} + +Status GetPrecisionModeName(const int precision_mode, string* name) { + switch (precision_mode) { + case FP32MODE: + *name = "FP32"; + break; + case FP16MODE: + *name = "FP16"; + break; + case INT8MODE: + *name = "INT8"; + break; + default: + return tensorflow::errors::OutOfRange("Unknown precision mode"); + } + return Status::OK(); +} + +Status GetPrecisionMode(const string& name, int* precision_mode) { + if (name == "FP32") { + *precision_mode = FP32MODE; + } else if (name == "FP16") { + *precision_mode = FP16MODE; + } else if (name == "INT8") { + *precision_mode = INT8MODE; + } else { + return tensorflow::errors::InvalidArgument("Invalid precision mode name: ", + name); + } + return Status::OK(); +} + +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/utils.h b/tensorflow/contrib/tensorrt/convert/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..0592f31462af2b20f3a13fe5119e89c2ba42dd8a --- /dev/null +++ b/tensorflow/contrib/tensorrt/convert/utils.h @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ + +#include + +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace tensorrt { + +template +struct TrtDestroyer { + void operator()(T* t) { + if (t) t->destroy(); + } +}; + +template +using TrtUniquePtrType = std::unique_ptr>; + +bool IsGoogleTensorRTEnabled(); + +// TODO(aaroey): use an enum instead. +const int FP32MODE = 0; +const int FP16MODE = 1; +const int INT8MODE = 2; + +Status GetPrecisionModeName(const int precision_mode, string* name); + +Status GetPrecisionMode(const string& name, int* precision_mode); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index a89cf3ab8bfaecc74fc5890ccb7e7a7147278182..69058c5826822c519a69d50860c06b8ab3ec6578 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -112,7 +112,9 @@ cuda_py_test( ], tags = [ "manual", + "no_windows", "noguitar", + "nomac", "notap", ], ) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc index 988b35f74f3989481f59c52c6320623a26704327..2de79737501a11d9760f9a7d3953cf132e512145 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -65,7 +65,7 @@ class IncPluginTRT : public OpKernel { reinterpret_cast(context->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); IncrementKernel(input_tensor.flat().data(), inc_, output_tensor->flat().data(), input_shape.num_elements(), *stream); diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc deleted file mode 100644 index aea44fd8a2fcc4c359a6cb0c98ae34711708326e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc +++ /dev/null @@ -1,136 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/contrib/tensorrt/kernels/trt_calib_op.h" -#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/platform/stream_executor.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT -#include "cuda/include/cuda_runtime_api.h" -#include "tensorrt/include/NvInfer.h" - -namespace tensorflow { -namespace tensorrt { - -TRTCalibOp::TRTCalibOp(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("segment_nodes", &segment_nodes_)); - OP_REQUIRES_OK(context, context->GetAttr("input_names", &input_names_)); - OP_REQUIRES_OK(context, context->GetAttr("resource_name", &resource_name_)); -}; - -#define TYPECASE(dt, X, Y) \ - case dt: { \ - return (void*)X->flat::Type>().data(); \ - } - -void* GetTensorAddress(const Tensor* tensor_ptr) { - auto tensor_type = tensor_ptr->dtype(); - switch (tensor_type) { - TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr); - TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr); - TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr); - default: { - LOG(FATAL) << "Unsupported Data type " - << tensorflow::DataTypeString(tensor_type); - return nullptr; - } - } -} - -void TRTCalibOp::Compute(tensorflow::OpKernelContext* ctx) { - // TODO(aaroey): make sure ctx->resource_mgr() is used in future PR. - auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance(); - auto res_mgr = trt_rm->getManager("TRTCalibOps"); - tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr; - auto status = res_mgr->Lookup(resource_name_, resource_name_, &calib_res); - - if (!status.ok()) { - ctx->SetStatus(status); - return; - } - int num_inputs = ctx->num_inputs(); - // first run instantiate calibrator - if (calib_res->calibrator_ == nullptr) { - dev_tensors_.resize(num_inputs); - int batch_size = ctx->input(0).dim_size(0); - VLOG(1) << " Constructing calibrator"; - for (int i = 0; i < num_inputs; i++) { - // allocate workspace on device for inputs - const tensorflow::Tensor& t = ctx->input(i); - OP_REQUIRES_OK(ctx, - ctx->allocate_persistent(t.dtype(), t.shape(), - &dev_tensors_.at(i), nullptr)); - const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); - CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); - void* device_address = GetTensorAddress(device_tensor); - device_buffers_.emplace(input_names_.at(i), - std::pair( - device_address, device_tensor->TotalBytes())); - } - - calib_res->calibrator_ = - new TRTInt8Calibrator(device_buffers_, batch_size, resource_name_); - string label(resource_name_); - calib_res->thr_ = new std::thread([calib_res, label]() { - VLOG(1) << "Starting calibration thread, Calibration Resource @ " - << calib_res; - calib_res->builder_->setInt8Calibrator(calib_res->calibrator_); - calib_res->builder_->setInt8Mode(true); - calib_res->engine_ = calib_res->builder_->buildCudaEngine( - *calib_res->network_); // will loop until we terminate calibrator - VLOG(1) << "Calibration loop terminated " << label; - }); - VLOG(1) << "initialized calibrator resource"; - } // calibrator initialized - - // Pass input data to calibrator - std::unordered_map input_data; - for (int i = 0; i < num_inputs; i++) { - const Tensor& t = ctx->input(i); - void* data_address = GetTensorAddress(&t); - const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); - CHECK_EQ(t.TotalBytes(), - device_tensor->TotalBytes()); // use the tensor so FW keeps it - input_data.emplace(input_names_.at(i), data_address); - ctx->set_output(i, t); - } - VLOG(2) << "Filled map for sending"; - // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files - const cudaStream_t* stream = CHECK_NOTNULL( - reinterpret_cast(ctx->op_device_context() - ->stream() - ->implementation() - ->CudaStreamMemberHack())); - calib_res->calibrator_->setBatch(input_data, *stream); - VLOG(2) << "Passed calibration data"; - // TODO(aaroey): make sure we wait for the completion of calibration on the - // last batch in future PR. -}; - -#undef TYPECASE - -REGISTER_KERNEL_BUILDER(Name("TRTCalibOp").Device(DEVICE_GPU), TRTCalibOp); - -} // namespace tensorrt -} // namespace tensorflow -#endif -#endif diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h deleted file mode 100644 index 23df9db32f077a080eaff7479fcbe90d6a504c42..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H -#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H - -#include -#include -#include -#include -#include -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/platform/types.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT -namespace tensorflow { -namespace tensorrt { -// TODO(sami): Convert this to async kernel! -class TRTCalibOp : public OpKernel { - public: - explicit TRTCalibOp(OpKernelConstruction* context); - - void Compute(OpKernelContext* context) override; - - private: - string resource_name_; - std::vector segment_nodes_; - std::vector input_names_; - std::vector shapes_; - std::unordered_map> device_buffers_; - std::vector dev_tensors_; -}; -} // namespace tensorrt -} // namespace tensorflow -#endif -#endif -#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 9ac8047944874181de228a6cc58e2dafe46abe50..2b42d81f475189f74a934c3aeed7d7fc34d4eb53 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -14,8 +14,19 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" +#include + +#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "tensorflow/contrib/tensorrt/test/utils.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" @@ -25,144 +36,574 @@ limitations under the License. #include "cuda/include/cuda_runtime_api.h" namespace tensorflow { -static ::tensorflow::tensorrt::Logger logger; -using IRuntime = nvinfer1::IRuntime; -using Dims = nvinfer1::Dims; - namespace tensorrt { +static Logger logger; +using ::nvinfer1::IRuntime; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +// A helper class to call done() when destructed for asynchronous execution. +// Helps simultaneous execution of native and TRT engines. +class AsyncHelper : public tensorflow::core::RefCounted { + public: + AsyncHelper(AsyncOpKernel::DoneCallback done) { done_ = done; } + ~AsyncHelper() override { done_(); } + + private: + AsyncOpKernel::DoneCallback done_; +}; + +#define TYPECASE(dt, X, Y) \ + case dt: { \ + return (void*)X->flat::Type>().data(); \ + } + +void* GetTensorAddress(const Tensor* tensor_ptr) { + auto tensor_type = tensor_ptr->dtype(); + switch (tensor_type) { + TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr); + TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr); + TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr); + default: { + LOG(ERROR) << "Unsupported Data type " + << tensorflow::DataTypeString(tensor_type); + return nullptr; + } + } +} -TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) { +tensorflow::Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) { + VLOG(1) << "Constructing function handle"; + auto lib = ctx->function_library(); + if (lib == nullptr) { + return tensorflow::errors::Internal("Context function library is null"); + } + auto fdef = lib->GetFunctionLibraryDefinition()->Find(funcdef_name_); + if (fdef == nullptr) { + return tensorflow::errors::Internal("Native FunctionDef ", funcdef_name_, + " can't be found in function library"); + } + tensorflow::FunctionLibraryRuntime::InstantiateOptions inst_ops; + inst_ops.overlay_lib = nullptr; + inst_ops.state_handle = ""; + inst_ops.target = ctx->device()->name(); + native_func_ = 0; + auto status = lib->Instantiate(funcdef_name_, AttrSlice(&fdef->attr()), + inst_ops, &native_func_); + if (!status.ok()) { + LOG(ERROR) << " Instantiating native function " << funcdef_name_ + << " failed!"; + } + return status; +} + +TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) + : AsyncOpKernel(context) { // read serialized_engine OP_REQUIRES_OK(context, - context->GetAttr("serialized_engine", &serialized_engine_)); + context->GetAttr("serialized_segment", &serialized_segment_)); + OP_REQUIRES_OK(context, + context->GetAttr("workspace_size_bytes", &workspace_size_)); + OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_)); + if (!static_engine_) { + if (!segment_graph_.ParseFromString(serialized_segment_)) { + LOG(ERROR) << "Parsing segment graph failed!"; + context->SetStatus(tensorflow::errors::InvalidArgument( + "Failed to parse segment graphdef!")); + return; + } + serialized_segment_.resize(0); + } + VLOG(1) << "Constructing " << name(); + string precision_string; + OP_REQUIRES_OK(context, + context->GetAttr("precision_mode", &precision_string)); + string calibration_data; + OP_REQUIRES_OK(context, + context->GetAttr("calibration_data", &calibration_data)); + OP_REQUIRES_OK(context, + context->GetAttr("segment_funcdef_name", &funcdef_name_)); + OP_REQUIRES_OK(context, GetPrecisionMode(precision_string, &precision_mode_)); + calibration_mode_ = + (precision_mode_ == INT8MODE && calibration_data.size() == 0); + if (calibration_data.size()) { + calibrator_.reset(new TRTInt8Calibrator(calibration_data)); + calibration_data.resize(0); + } + native_func_ = tensorflow::kInvalidHandle; + OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count", + &max_cached_engines_)); + OP_REQUIRES_OK(context, + context->GetAttr("fixed_input_size", &fixed_input_size_)); + OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches", + &cached_engine_batches_)); + std::sort(cached_engine_batches_.begin(), cached_engine_batches_.end()); + if (VLOG_IS_ON(1)) { + string s("Engine Batches= "); + for (auto i : cached_engine_batches_) { + StrAppend(&s, i, " "); + } + VLOG(1) << s; + } +} - // register input output node name in trt_sub_graph - OP_REQUIRES_OK(context, context->GetAttr("input_nodes", &input_nodes_)); - OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_)); +void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, + AsyncHelper* helper) { + if (!calibration_mode_) { + VLOG(1) << "Executing native engine"; + } + std::vector inputs; + std::vector* outputs = new std::vector(); + if (native_func_ == tensorflow::kInvalidHandle) { + auto status = ConstructFunctionHandle(ctx); + if (!status.ok()) { + LOG(ERROR) << "Couldn't construct function handle " << funcdef_name_; + ctx->SetStatus(status); + return; + } + } + auto lib = ctx->function_library(); + tensorflow::FunctionLibraryRuntime::Options opts; + opts.step_id = ctx->step_id(); + opts.rendezvous = ctx->rendezvous(); + opts.cancellation_manager = ctx->cancellation_manager(); + opts.runner = ctx->runner(); + for (int i = 0; i < ctx->num_inputs(); i++) { + inputs.push_back(ctx->input(i)); + } + helper->Ref(); // Increment count for calculating native graph + VLOG(1) << "Executing native segment " << name(); + lib->Run(opts, native_func_, inputs, outputs, + [this, ctx, outputs, helper](const tensorflow::Status& s) { + tensorflow::core::ScopedUnref sc(helper); + VLOG(1) << "Native Segment completed"; + if (!s.ok()) { + ctx->SetStatus(s); + return; + } + for (size_t t = 0; t < outputs->size(); ++t) { + ctx->set_output(t, outputs->at(t)); + } + test::AddTestValue(StrCat(this->name(), ":ExecuteNativeSegment"), + "done"); + delete outputs; + }); } -void TRTEngineOp::Compute(OpKernelContext* context) { - // TODO(samikama) runtime should be taken from a resourcemanager as well. - // Only engine should be in the op and context and runtime should be taken - // from resourcemanager +void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, + AsyncHelper* helper) { + helper->Ref(); + tensorflow::core::ScopedUnref sc(helper); + // TODO(aaroey): remove the ResourceMgr singleton. + auto trt_rm = TRTResourceManager::instance(); + auto res_mgr = trt_rm->getManager("TRTCalibration"); + TRTCalibrationResource* calib_res = nullptr; + auto status = res_mgr->LookupOrCreate( + funcdef_name_, "Calibrator", &calib_res, + {[ctx, this](TRTCalibrationResource** cr) -> tensorflow::Status { + return this->AllocateCalibrationResources(ctx, cr); + }}); + if (!status.ok()) { + ctx->SetStatus(status); + return; + } + int num_inputs = ctx->num_inputs(); + // Pass input data to calibrator + std::unordered_map input_data; + for (int i = 0; i < num_inputs; i++) { + const Tensor& t = ctx->input(i); + void* data_address = GetTensorAddress(&t); + if (data_address == nullptr) { + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "Unsupported data type encountered in input ", i)); + return; + } + // Check the allocated buffer is sufficient for input + const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); + CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); + input_data.emplace(StrCat(kInputPHName, i), data_address); + } + VLOG(2) << "Filled map for sending"; + // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files + const cudaStream_t* stream = CHECK_NOTNULL( + reinterpret_cast(ctx->op_device_context() + ->stream() + ->implementation() + ->GpuStreamMemberHack())); + calib_res->calibrator_->setBatch(input_data, *stream); + test::AddTestValue(StrCat(name(), ":ExecuteCalibration"), "done"); + VLOG(2) << "Passed calibration data"; + ExecuteNativeSegment(ctx, helper); +} - if (!trt_execution_context_ptr_) { - IRuntime* infer = nvinfer1::createInferRuntime(logger); -#if NV_TENSORRT_MAJOR > 3 - auto device = context->device(); - auto dev_allocator = - device->GetAllocator(tensorflow::AllocatorAttributes()); - if (!dev_allocator) { - LOG(FATAL) << "Can't find device allocator for gpu device " - << device->name(); - } - allocator_ = std::make_shared(dev_allocator); - infer->setGpuAllocator(allocator_.get()); -#endif - trt_engine_ptr_.reset(infer->deserializeCudaEngine( - serialized_engine_.c_str(), serialized_engine_.size(), - PluginFactoryTensorRT::GetInstance())); - trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext()); - // Runtime is safe to delete after engine creation - infer->destroy(); - serialized_engine_.clear(); +int TRTEngineOp::GetEngineBatch(OpKernelContext* ctx) { + int num_batch = ctx->input(0).shape().dim_size(0); + int smallest_engine = 0; + for (const auto i : cached_engine_batches_) { + if (i >= num_batch) { + smallest_engine = i; + break; + } } - int num_binding = context->num_inputs() + context->num_outputs(); - std::vector buffers(num_binding); + // TODO(sami): Need an LRU here + if (smallest_engine == 0) { + if (max_cached_engines_ > cached_engine_batches_.size()) { + smallest_engine = num_batch; + cached_engine_batches_.push_back(num_batch); + VLOG(1) << "Running with batch size " << num_batch; + } else { + string msg = + StrCat("Engine buffer is full. buffer limit=", max_cached_engines_, + ", current entries="); + for (auto i : cached_engine_batches_) StrAppend(&msg, i, ","); + StrAppend(&msg, " requested batch=", num_batch); + LOG(WARNING) << msg; + return -1; + } + } + return smallest_engine; +} + +void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, + AsyncOpKernel::DoneCallback done) { + auto helper = new AsyncHelper(done); + tensorflow::core::ScopedUnref sc(helper); + if (calibration_mode_) { + ExecuteCalibration(ctx, helper); + return; + } + const int smallest_engine = GetEngineBatch(ctx); + if (smallest_engine < 0) { + LOG(WARNING) << "Failed to get engine batch, running native segment for " + << name(); + ExecuteNativeSegment(ctx, helper); + return; + } + + const int num_batch = ctx->input(0).shape().dim_size(0); + auto& engine_ctx_pair = GetEngine(smallest_engine, ctx); + auto& trt_engine_ptr = engine_ctx_pair.first; + if (!trt_engine_ptr) { + LOG(WARNING) << "Engine retrieval for batch size " << num_batch + << " failed. Running native segment for " << name(); + ExecuteNativeSegment(ctx, helper); + return; + } + const bool retry = ExecuteTrtEngine(ctx, num_batch, trt_engine_ptr.get(), + engine_ctx_pair.second.get()); + if (retry) { + LOG(WARNING) << "Failed to execute engine, " + << "retrying with native segment for " << name(); + ExecuteNativeSegment(ctx, helper); + return; + } +} - size_t binding_index; - int num_batch = 0; - for (int i = 0; i < context->num_inputs(); i++) { - // Grab the input tensor - binding_index = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str()); +bool TRTEngineOp::ExecuteTrtEngine( + OpKernelContext* ctx, const int num_batch, + nvinfer1::ICudaEngine* trt_engine_ptr, + nvinfer1::IExecutionContext* trt_execution_context_ptr) { + const bool kRetry = true; + const int num_binding = ctx->num_inputs() + ctx->num_outputs(); + std::vector buffers(num_binding); + for (int i = 0; i < ctx->num_inputs(); i++) { + const string input_name = StrCat(kInputPHName, i); + const size_t binding_index = + trt_engine_ptr->getBindingIndex(input_name.c_str()); + if (binding_index == -1) { + LOG(ERROR) << "Input node not found, at " << input_name; + return kRetry; + } - const Tensor& input_tensor = context->input(i); + const Tensor& input_tensor = ctx->input(i); const TensorShape& input_shape = input_tensor.shape(); - if (i == 0) { - num_batch = input_shape.dim_size(0); - if (num_batch > trt_engine_ptr_->getMaxBatchSize()) { - LOG(FATAL) << "input tensor batch larger than max_batch_size: " - << trt_engine_ptr_->getMaxBatchSize(); - } - } else if (num_batch != input_shape.dim_size(0)) { - LOG(FATAL) << "input data inconsistent batch size"; - break; + if (num_batch != input_shape.dim_size(0)) { + LOG(ERROR) << "Input data has inconsistent batch size: " << num_batch + << " vs " << input_shape.dim_size(0); + return kRetry; } - auto dtype = trt_engine_ptr_->getBindingDataType(binding_index); + auto dtype = trt_engine_ptr->getBindingDataType(binding_index); switch (dtype) { case nvinfer1::DataType::kFLOAT: buffers[binding_index] = (void*)(input_tensor.flat().data()); break; case nvinfer1::DataType::kHALF: - LOG(FATAL) << "half size is not supported yet!"; - break; + LOG(ERROR) << "FP16 inputs are not supported yet!"; + return kRetry; case nvinfer1::DataType::kINT8: - LOG(FATAL) << "int8 is not supported yet!"; + LOG(ERROR) << "INT8 inputs are not supported yet!"; + return kRetry; +#if NV_TENSORRT_MAJOR > 3 + case nvinfer1::DataType::kINT32: + buffers[binding_index] = (void*)(input_tensor.flat().data()); break; +#endif default: - LOG(FATAL) << "Unknown data type: " << int(dtype); - break; + LOG(ERROR) << "Unknown TRT data type: " << int(dtype); + return kRetry; } } - for (int i = 0; i < static_cast(output_nodes_.size()); i++) { - // This is bad that we have to reallocate output buffer every run. + for (int i = 0; i < ctx->num_outputs(); i++) { // Create an output tensor - binding_index = trt_engine_ptr_->getBindingIndex(output_nodes_[i].c_str()); + const string output_name = StrCat(kOutputPHName, i); + const size_t binding_index = + trt_engine_ptr->getBindingIndex(output_name.c_str()); Tensor* output_tensor = nullptr; TensorShape output_shape; if (binding_index != -1) { - auto dims = trt_engine_ptr_->getBindingDimensions(binding_index); + auto dims = trt_engine_ptr->getBindingDimensions(binding_index); std::vector trt_shape(dims.nbDims + 1); trt_shape[0] = num_batch; for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j]; - OP_REQUIRES_OK(context, - TensorShapeUtils::MakeShape( - trt_shape.data(), trt_shape.size(), &output_shape)); + auto status = TensorShapeUtils::MakeShape( + trt_shape.data(), trt_shape.size(), &output_shape); + if (!status.ok()) { + LOG(ERROR) << "Failed to get output shape: " << status; + return kRetry; + } } else { - LOG(FATAL) << "output node not found, at " << output_nodes_[i]; - break; + LOG(ERROR) << "Output node not found, at " << output_name; + return kRetry; } - - OP_REQUIRES_OK(context, - context->allocate_output(i, output_shape, &output_tensor)); - auto dtype = trt_engine_ptr_->getBindingDataType(binding_index); + auto status = ctx->allocate_output(i, output_shape, &output_tensor); + if (!status.ok()) { + LOG(ERROR) << "Allocating output failed with " << status; + ctx->SetStatus(status); + // Do not retry since we cannot allocate the same output twice. + // TODO(aaroey): ideally we should retry, fix this. + return !kRetry; + } + auto dtype = trt_engine_ptr->getBindingDataType(binding_index); switch (dtype) { case nvinfer1::DataType::kFLOAT: buffers[binding_index] = reinterpret_cast(output_tensor->flat().data()); break; case nvinfer1::DataType::kHALF: - LOG(FATAL) << "half size is not supported yet!"; - break; + LOG(WARNING) << "half size is not supported yet!"; + return kRetry; case nvinfer1::DataType::kINT8: - LOG(FATAL) << "int8 is not supported yet!"; + LOG(WARNING) << "int8 is not supported yet!"; + return kRetry; +#if NV_TENSORRT_MAJOR > 3 + case nvinfer1::DataType::kINT32: + buffers[binding_index] = + reinterpret_cast(output_tensor->flat().data()); break; +#endif default: - LOG(FATAL) << "Unknown data type: " << int(dtype); - break; + LOG(WARNING) << "Unknown TRT data type: " << static_cast(dtype); + return kRetry; } } - // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files + // Copied from cuda_kernel_helper since it seems only valid in *.cu.cc files const cudaStream_t* stream = CHECK_NOTNULL( - reinterpret_cast(context->op_device_context() + reinterpret_cast(ctx->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); // TODO(jie): trt enqueue does not return error - auto ret = trt_execution_context_ptr_->enqueue(num_batch, &buffers[0], - *stream, nullptr); - VLOG(2) << "enqueue returns: " << ret; - // sync should be done by TF. + auto ret = trt_execution_context_ptr->enqueue(num_batch, &buffers[0], *stream, + nullptr); + if (!ret) { + LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name(); + return kRetry; + } + test::AddTestValue(StrCat(name(), ":ExecuteTrtEngine"), "done"); + // Synchronization will be done by TF. + return !kRetry; } + TRTEngineOp::~TRTEngineOp() { - // Order matters! - trt_execution_context_ptr_.reset(); - trt_engine_ptr_.reset(); + // We need to manually destroy the engine and execution context before + // the allocator is destructed. + for (auto& eng : engine_map_) { + eng.second.first.reset(); + eng.second.second.reset(); + } allocator_.reset(); } + +nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) { + if (allocator_) return allocator_.get(); + auto device = ctx->device(); + auto alloc = device->GetAllocator(tensorflow::AllocatorAttributes()); + if (!alloc) { + LOG(ERROR) << "Can't find device allocator for gpu device " + << device->name(); + return nullptr; + } + allocator_.reset(new TRTDeviceAllocator(alloc)); + return allocator_.get(); +} + +TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, + OpKernelContext* ctx) { + static EngineCtxPair null_pair = { + TrtUniquePtrType(nullptr), + TrtUniquePtrType(nullptr)}; + // TODO(sami): This method needs to be re-written to use resource manager and + // with LRU mechanism option. + tensorflow::mutex_lock lock(engine_mutex_); + + if (static_engine_) { + if (engine_map_.size()) { + if (engine_map_.begin()->first >= batch_size) { + return engine_map_.begin()->second; + } + return null_pair; + } + TrtUniquePtrType infer(nvinfer1::createInferRuntime(logger)); +#if NV_TENSORRT_MAJOR > 3 + auto allocator = GetAllocator(ctx); + if (allocator == nullptr) { + return null_pair; + } + infer->setGpuAllocator(allocator); +#endif + TrtUniquePtrType static_engine( + infer->deserializeCudaEngine(serialized_segment_.c_str(), + serialized_segment_.size(), + PluginFactoryTensorRT::GetInstance())); + auto raw_static_engine = static_engine.get(); + const auto max_batch_size = raw_static_engine->getMaxBatchSize(); + engine_map_[max_batch_size] = { + std::move(static_engine), + TrtUniquePtrType( + raw_static_engine->createExecutionContext())}; + // Runtime is safe to delete after engine creation + serialized_segment_.clear(); + if (max_batch_size < batch_size) { + return null_pair; + } + return engine_map_.at(max_batch_size); + } // static_engine_ + + // Handle the dynamic engine case. + auto engine_it = engine_map_.find(batch_size); + if (engine_it == engine_map_.end() && + engine_map_.size() < (size_t)max_cached_engines_) { + nvinfer1::IGpuAllocator* allocator = nullptr; +#if NV_TENSORRT_MAJOR > 3 + allocator = GetAllocator(ctx); + if (allocator == nullptr) { + return null_pair; + } +#endif + std::vector shapes; + for (int i = 0; i < ctx->num_inputs(); ++i) { + shapes.emplace_back(ctx->input(i).shape()); + } + TrtUniquePtrType engine; + bool convert_successfully = false; + VLOG(0) << name() << " Constructing a new engine with batch size " + << batch_size; + // Up to this point, calibrator_ can never be empty, since otherwise it + // means calibration_mode_ is true and this path won't get executed. + auto status = convert::ConvertGraphDefToEngine( + segment_graph_, precision_mode_, batch_size, workspace_size_, shapes, + &logger, allocator, calibrator_.get(), &engine, &convert_successfully); + if (!status.ok()) { + if (convert_successfully) { + // This means it fail to build the engine even when the network is built + // successfully, probably due to internal issues. In this case we don't + // retry in the future. + engine_map_[batch_size] = {nullptr, nullptr}; + } + LOG(WARNING) << "Engine creation for batch size " << batch_size + << " failed " << status; + return null_pair; + } + VLOG(1) << "Conversion is done"; + TrtUniquePtrType exec_context( + engine->createExecutionContext()); + engine_map_[batch_size] = {std::move(engine), std::move(exec_context)}; + } + return engine_map_.at(batch_size); +} + +tensorflow::Status TRTEngineOp::AllocateCalibrationResources( + OpKernelContext* ctx, TRTCalibrationResource** cr) { + auto cres = new TRTCalibrationResource(); + *cr = cres; + // Get the allocator. + auto alloc = ctx->device()->GetAllocator(tensorflow::AllocatorAttributes()); + if (!alloc) { + LOG(WARNING) << "Can't get device allocator will not be able to " + "allocate memory from TensorFlow memory pool"; + cres->allocator_.reset(new TRTCudaAllocator); + } else { + cres->allocator_.reset(new TRTDeviceAllocator(alloc)); + } + // Get the input shapes. + const int batch_size = ctx->input(0).dim_size(0); + const int num_inputs = ctx->num_inputs(); + std::vector shapes; + dev_tensors_.resize(num_inputs); + VLOG(1) << " Constructing calibrator"; + for (int i = 0; i < num_inputs; i++) { + // allocate workspace on device for inputs + const tensorflow::Tensor& t = ctx->input(i); + shapes.emplace_back(t.shape()); + Tensor* device_tensor; + TF_RETURN_IF_ERROR(ctx->allocate_persistent( + t.dtype(), t.shape(), &dev_tensors_.at(i), &device_tensor)); + CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); + void* device_address = GetTensorAddress(device_tensor); + if (device_address == nullptr) { + return tensorflow::errors::InvalidArgument( + "Unsupported data type encountered in input ", i); + } + device_buffers_.emplace( + StrCat(kInputPHName, i), + std::pair(device_address, device_tensor->TotalBytes())); + } + cres->calibrator_.reset( + new TRTInt8Calibrator(device_buffers_, batch_size, name())); + const string label(name()); + auto segment_graph = &segment_graph_; + const int cuda_gpu_id = ctx->device()->tensorflow_gpu_device_info()->gpu_id; + if (cuda_gpu_id < 0) { + LOG(ERROR) << "Can't get gpu_device_info from context->device()"; + return tensorflow::errors::InvalidArgument( + "Context->device doesn't contain device info!"); + } + const int64 workspace_size_bytes = workspace_size_; + cres->thr_.reset(new std::thread([cres, label, segment_graph, shapes, + cuda_gpu_id, workspace_size_bytes]() { + VLOG(0) << "Starting calibration thread on device " << cuda_gpu_id + << ", Calibration Resource @ " << cres; + auto err = cudaSetDevice(cuda_gpu_id); + if (err != cudaSuccess) { + // TODO(aaroey): should return error here. + LOG(ERROR) << "Couldn't set cuda device to " << cuda_gpu_id + << " in calibration thread"; + } + // ConvertGraphDefToEngine() will try to build the engine. This thread + // will loop inside buildCudaEngine() consuming the calibration data + // that is set by the TF op, and drive the builder until calibrator returns + // false. Engine is discarded after calibration table is generated + // + // TODO(aaroey): maybe setting the max batch size using the python + // calibration wrapper class. + auto s = convert::ConvertGraphDefToEngine( + *segment_graph, INT8MODE, cres->calibrator_->getBatchSize(), + workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(), + cres->calibrator_.get(), &cres->engine_, + /*convert_successfully=*/nullptr); + if (!s.ok()) { + LOG(ERROR) << "Calibration failed: " << s; + cres->calibrator_->setDone(); // Ignore further pushes + } + VLOG(1) << "Calibration loop terminated " << label; + })); + VLOG(1) << "initialized calibrator resource"; + return tensorflow::Status::OK(); +} + REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp); } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h index e613a71422852e60565ba7554516d7eace6b9cc7..8fe06758914261035c90a6fda3f114a63a8ac93a 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h @@ -19,9 +19,14 @@ limitations under the License. #include #include +#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/mutex.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -30,32 +35,101 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -class Logger; - +struct TRTInt8Calibrator; +class TRTCalibrationResource; +class AsyncHelper; // TODO(Sami): Remove this file? -class TRTEngineOp : public OpKernel { + +// This OP can construct TRTEngine on the fly and if construction of engine +// fails, executes equivalent subgraph as a TensorFlow function. +class TRTEngineOp : public AsyncOpKernel { public: explicit TRTEngineOp(OpKernelConstruction* context); - void Compute(OpKernelContext* context) override; + void ComputeAsync(OpKernelContext* context, + AsyncOpKernel::DoneCallback done) override; ~TRTEngineOp(); private: - template - struct Destroyer { - void operator()(T* d) { d->destroy(); } - }; - - template - using destroyed_ptr = std::unique_ptr>; - destroyed_ptr trt_engine_ptr_; + // Execute calibration + void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper); + + // Construct a function handle for executing native funcdef graph + Status ConstructFunctionHandle(OpKernelContext* ctx); + + // Execute replaced native segment as function Op. + void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper); + + // Execute the tensorrt engine. Returns whether we need to retry by running + // the native segment. + bool ExecuteTrtEngine(OpKernelContext* ctx, const int num_batch, + nvinfer1::ICudaEngine* trt_engine_ptr, + nvinfer1::IExecutionContext* trt_execution_context_ptr); + + // Allocate necessary resources for calibration + Status AllocateCalibrationResources(OpKernelContext* ctx, + TRTCalibrationResource** cr); + // TODO(samikama): context should go to a resource manager! - destroyed_ptr trt_execution_context_ptr_; + typedef std::pair, + TrtUniquePtrType> + EngineCtxPair; + EngineCtxPair& GetEngine(int batch_size, OpKernelContext* ctx); + + // Return engine batch closest to input batch. + int GetEngineBatch(OpKernelContext* ctx); + + nvinfer1::IGpuAllocator* GetAllocator(OpKernelContext* ctx); + // map to keep engines and their execution context for given batch size. + std::unordered_map engine_map_; std::vector input_nodes_; std::vector output_nodes_; - std::shared_ptr allocator_; - string serialized_engine_; + + // keep device allocator for TRT. + std::unique_ptr allocator_; + + // serialized protobuf segment or trt engine depending on static_engine_ flag. + string serialized_segment_; + + // Name of the function for TF native execution of the segment. + string funcdef_name_; + + // GraphDef representation of the segment. + GraphDef segment_graph_; + + // Lookup table for temporary staging areas of input tensors for calibration. + std::unordered_map> device_buffers_; + + // Temporary staging areas for calibration inputs. + std::vector dev_tensors_; + + // Engine Precision mode. + int precision_mode_; + + // Whether engine is constructed during the conversion or needs to be + // constructed from protobuf segment. + bool static_engine_; + + // Whether to calibrate INT8 engine. + bool calibration_mode_; + + // Whether non-batch ranks of the inputs are assumed to be fixed or not for + // engine construction. + bool fixed_input_size_; + + // Batches of the cached engines + std::vector cached_engine_batches_; + + // Maximum number of cached engines + int max_cached_engines_; + + int64 workspace_size_; + mutex engine_mutex_; + FunctionLibraryRuntime::Handle native_func_; + + // The finalized calibrator for inference. + std::unique_ptr calibrator_; }; } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc index 079d73f7bec3f9a9740e455b31a259cec287f849..e0c7b6272379a20e3dacb6cd7c3b39de735d844d 100644 --- a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc @@ -28,14 +28,28 @@ extern Status TRTEngineOpShapeInference(InferenceContext* c); } REGISTER_OP("TRTEngineOp") - .Attr("serialized_engine: string") - .Attr("input_nodes: list(string)") - .Attr("output_nodes: list(string)") - .Attr("InT: list({float32})") - .Attr("OutT: list({float32})") + .Attr("serialized_segment: string") + .Attr("input_shapes: list(shape)") + .Attr("output_shapes: list(shape)") + .Attr("segment_funcdef_name: string") + .Attr("InT: list({int8,float16,float32})") + .Attr("OutT: list({int8,float16,float32})") + .Attr("static_engine: bool = true") + .Attr("fixed_input_size: bool = true") + .Attr("cached_engine_batches: list(int) = []") + .Attr("max_cached_engines_count: int = 1") + .Attr("workspace_size_bytes: int") + .Attr("precision_mode: {'FP32', 'FP16', 'INT8', 'INT8CALIB'}") + .Attr("calibration_data: string = ''") .Input("in_tensor: InT") - .Output("out_tensor: OutT") - .SetShapeFn(shape_inference::TRTEngineOpShapeInference); + .Output("out_tensor: OutT"); +// TODO(jie): TF requires concrete output shape for concrete input shapes. +// This is tricky for batch dimension, since we cannot ensure which input +// would carry the correct batch dimension (for the current stage of the +// implementation, we do require all input tensor to carry the same batch +// size, but this could change in the future). Hence we disable shape +// inference function as a workaround. +// .SetShapeFn(shape_inference::TRTEngineOpShapeInference); } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc index 2bc591484dcaf5b35c39f3d0523dd89dcd152e6a..cccc91226265ed139fb8db0b71c40b868f729562 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -65,9 +65,6 @@ bool PluginFactoryTensorRT::RegisterPlugin( void PluginFactoryTensorRT::DestroyPlugins() { tensorflow::mutex_lock lock(instance_m_); - for (auto& owned_plugin_ptr : owned_plugins_) { - owned_plugin_ptr.release(); - } owned_plugins_.clear(); } diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py index 0b2321b5fc7bcbd53c01d1c97cafcfcb229a83ef..7cdfe2b1a612be2eec473d806d0eb44b611ca68a 100644 --- a/tensorflow/contrib/tensorrt/python/__init__.py +++ b/tensorflow/contrib/tensorrt/python/__init__.py @@ -20,6 +20,11 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long from tensorflow.contrib.tensorrt.python.ops import trt_engine_op +from tensorflow.contrib.tensorrt.python.trt_convert import add_test_value from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph +from tensorflow.contrib.tensorrt.python.trt_convert import clear_test_values from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph +from tensorflow.contrib.tensorrt.python.trt_convert import enable_test_value +from tensorflow.contrib.tensorrt.python.trt_convert import get_test_value +from tensorflow.contrib.tensorrt.python.trt_convert import is_tensorrt_enabled # pylint: enable=unused-import,line-too-long diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 338475d90ea55ab2c1bb8df77f27a71a4a36a5dd..4116f2fe30aa5c0c9ea139100291abe3b13da94b 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -20,27 +20,35 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long import six as _six +from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert -from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert +from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values +from tensorflow.contrib.tensorrt.wrap_conversion import enable_test_value +from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version +from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version +from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value +from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled from tensorflow.core.framework import graph_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 -from tensorflow.python.framework import errors from tensorflow.python.framework import errors_impl as _impl -from tensorflow.python.framework import meta_graph +from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.grappler import tf_optimizer -from tensorflow.python.util import compat +from tensorflow.python.platform import tf_logging +from tensorflow.python.training import saver # pylint: enable=unused-import,line-too-long -# TODO(skama): get outputs from session when implemented as c++ -# optimization pass def create_inference_graph(input_graph_def, outputs, max_batch_size=1, max_workspace_size_bytes=2 << 20, precision_mode="FP32", - minimum_segment_size=3): + minimum_segment_size=3, + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=None): """Python wrapper for the TRT transformation. Args: @@ -51,6 +59,10 @@ def create_inference_graph(input_graph_def, precision_mode: one of 'FP32', 'FP16' and 'INT8' minimum_segment_size: the minimum number of nodes required for a subgraph to be replaced by TRTEngineOp. + is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT + network and engine at run time. + maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops. + cached_engine_batches: batch sizes used to pre-create cached engines. Returns: New GraphDef with TRTEngineOps placed in graph replacing subgraphs. @@ -65,6 +77,29 @@ def create_inference_graph(input_graph_def, "It should be one of {}").format( precision_mode, "{'FP32', 'FP16', 'INT8'}")) mode = supported_precision_modes[precision_mode.upper()] + compiled_version = get_linked_tensorrt_version() + loaded_version = get_loaded_tensorrt_version() + version_mismatch = False + if loaded_version[0] < compiled_version[0]: + tf_logging.error( + "TensorRT version mismatch. Tensorflow was compiled against " + + "TensorRT %s but library loaded from environment is TensorRT %s" % + (".".join([str(x) for x in compiled_version]), + ".".join([str(x) for x in loaded_version])) + + ". Please make sure that correct version of TensorRT " + + "is available in the system and added to ldconfig or LD_LIBRARY_PATH") + raise RuntimeError("Incompatible TensorRT library version") + for i in zip(loaded_version, compiled_version): + if i[0] != i[1]: + tf_logging.warn("TensorRT mismatch. Compiled against version " + + "%s, but loaded %s. Things may not work" % + (".".join([str(x) for x in compiled_version]), + ".".join([str(x) for x in loaded_version]))) + version_mismatch = True + break + if not version_mismatch: + tf_logging.info("Running against TensorRT version %s" % ".".join( + [str(x) for x in loaded_version])) def py2bytes(inp): return inp @@ -85,46 +120,50 @@ def create_inference_graph(input_graph_def, to_bytes = py3bytes to_string = py3string - out_names = [] - for i in outputs: - if isinstance(i, ops.Tensor): - out_names.append(to_bytes(i.name)) - else: - out_names.append(to_bytes(i)) - - input_graph_def_str = input_graph_def.SerializeToString() - - # TODO(sami): Fix this when we can return status from C++ library - # There is a problem with the TF internal library setup that doesn't - # allow us to return a status object from C++. Thus we return a - # pair or strings where first one is encoded status and the second - # one is the transformed graphs protobuf string. - out = trt_convert(input_graph_def_str, out_names, max_batch_size, - max_workspace_size_bytes, mode, minimum_segment_size) - status = to_string(out[0]) - output_graph_def_string = out[1] - del input_graph_def_str # Save some memory - if len(status) < 2: - raise _impl.UnknownError(None, None, status) - if status[:2] != "OK": - msg = status.split(";") - if len(msg) == 1: - raise RuntimeError("Status message is malformed {}".format(status)) - # pylint: disable=protected-access - raise _impl._make_specific_exception(None, None, ";".join(msg[1:]), - int(msg[0])) - # pylint: enable=protected-access - output_graph_def = graph_pb2.GraphDef() - output_graph_def.ParseFromString(output_graph_def_string) - del output_graph_def_string # Save some memory - return output_graph_def - - -def calib_graph_to_infer_graph(calibration_graph_def): + # Create MetaGraphDef + graph = ops.Graph() + with graph.as_default(): + importer.import_graph_def(input_graph_def, name="") + meta_graph = saver.export_meta_graph( + graph_def=graph.as_graph_def(), graph=graph) + if outputs: + output_collection = meta_graph_pb2.CollectionDef() + output_list = output_collection.node_list.value + for i in outputs: + if isinstance(i, ops.Tensor): + output_list.append(to_bytes(i.name)) + else: + output_list.append(to_bytes(i)) + meta_graph.collection_def["train_op"].CopyFrom(output_collection) + + # Create RewriterConfig. + rewriter_cfg = rewriter_config_pb2.RewriterConfig() + rewriter_cfg.optimizers.extend(["constfold", "layout"]) + optimizer = rewriter_cfg.custom_optimizers.add() + optimizer.name = "TensorRTOptimizer" + optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size + optimizer.parameter_map["max_batch_size"].i = max_batch_size + optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op + optimizer.parameter_map[ + "max_workspace_size_bytes"].i = max_workspace_size_bytes + optimizer.parameter_map["precision_mode"].s = to_bytes(precision_mode) + optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines + if cached_engine_batches: + if not isinstance(cached_engine_batches, list): + raise TypeError("cached_engine_batches should be a list.") + optimizer.parameter_map["cached_engine_batches"].list.i.extend( + cached_engine_batches) + + return tf_optimizer.OptimizeGraph( + rewriter_cfg, meta_graph, graph_id=b"tf_graph") + + +def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False): """Convert an existing calibration graph to inference graph. Args: calibration_graph_def: the calibration GraphDef object with calibration data + is_dynamic_op: whether to create dynamic static engines from calibration Returns: New GraphDef with TRTEngineOps placed in graph replacing calibration nodes. Raises: @@ -141,9 +180,16 @@ def calib_graph_to_infer_graph(calibration_graph_def): to_string = py2string else: to_string = py3string - + is_calib_graph = False + for n in calibration_graph_def.node: + if n.op == "TRTEngineOp": + is_calib_graph = is_calib_graph or not n.attr["calibration_data"].s + if not is_calib_graph: + tf_logging.error( + "Not a calib graph. Doesn't seem to contain any calibration nodes.") + return None graph_str = calibration_graph_def.SerializeToString() - out = calib_convert(graph_str) + out = calib_convert(graph_str, is_dynamic_op) status = to_string(out[0]) output_graph_def_string = out[1] del graph_str # Save some memory diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc index 0f0508331c13055096714352e83fc360f0ef39b4..d8f97bfbbc7adb10a5dda6fbc2f7a660f6cd7742 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc @@ -19,12 +19,42 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT +#include "cuda/include/cuda_runtime_api.h" +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +namespace tensorflow { +namespace tensorrt { + +// std::align is not supported, so this method mimic its behavior. +void* Align(size_t alignment, size_t size, void*& ptr, size_t& space) { + QCHECK_GT(alignment, 0) << "alignment must be greater than 0."; + QCHECK_EQ(0, alignment & (alignment - 1)) << "Alignment must be power of 2."; + QCHECK_GT(size, 0) << "size must be greater than 0."; + QCHECK(ptr) << "ptr must not be nullptr."; + QCHECK_GT(space, 0) << "space must be greater than 0."; + const uintptr_t ptr_val = reinterpret_cast(ptr); + QCHECK_GE(ptr_val + space, ptr_val) << "Provided space overflows."; + if (size > space) return nullptr; + const uintptr_t aligned_ptr_val = ((ptr_val + alignment - 1) & -alignment); + if (aligned_ptr_val > ptr_val + space - size) return nullptr; + ptr = reinterpret_cast(aligned_ptr_val); + const uintptr_t diff = aligned_ptr_val - ptr_val; + space -= diff; + return ptr; +} + +} // namespace tensorrt +} // namespace tensorflow + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT #if NV_TENSORRT_MAJOR > 2 -#include "cuda/include/cuda_runtime_api.h" namespace tensorflow { namespace tensorrt { + void* TRTCudaAllocator::allocate(uint64_t size, uint64_t alignment, uint32_t flags) { assert((alignment & (alignment - 1)) == 0); // zero or a power of 2. @@ -37,10 +67,23 @@ void TRTCudaAllocator::free(void* memory) { cudaFree(memory); } void* TRTDeviceAllocator::allocate(uint64_t size, uint64_t alignment, uint32_t flags) { + // WAR for allocator alignment requirement. Certain cuda API calls require GPU + // memory with alignemtn to cudaDeviceProp::textureAlignment. + // See issue #20856 + alignment = 512; assert((alignment & (alignment - 1)) == 0); // zero or a power of 2. - void* mem = allocator_->AllocateRaw(alignment, size); - VLOG(2) << "Allocated " << size << " bytes with alignment " << alignment - << " @ " << mem; + size_t total_size = size + alignment; + void* mem = allocator_->AllocateRaw(alignment, total_size); + if (!mem) return nullptr; + + void* alloc_mem = mem; + QCHECK(Align(alignment, size, mem, total_size)); + if (mem != alloc_mem) { + QCHECK(mem_map_.insert({mem, alloc_mem}).second); + } + VLOG(2) << "Allocated " << total_size << " bytes memory @" << alloc_mem + << "; aligned to " << size << " bytes @" << mem << " with alignment " + << alignment; return mem; } @@ -50,13 +93,21 @@ TRTDeviceAllocator::TRTDeviceAllocator(tensorflow::Allocator* allocator) } void TRTDeviceAllocator::free(void* memory) { - VLOG(2) << "Deallocating " << memory; - allocator_->DeallocateRaw(memory); + VLOG(2) << "Deallocating @ " << memory; + // allocated memory adjusted for alignment, restore the original pointer + if (memory) { + auto alloc_mem = mem_map_.find(memory); + if (alloc_mem != mem_map_.end()) { + memory = alloc_mem->second; + mem_map_.erase(alloc_mem->first); + } + allocator_->DeallocateRaw(memory); + } } } // namespace tensorrt } // namespace tensorflow #endif -#endif -#endif +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.h b/tensorflow/contrib/tensorrt/resources/trt_allocator.h index a0c2540a7698bc46a65dbd967412351bac2a4dd2..6f944920835b475fc7d12167dbcefa0111b6fb19 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.h +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.h @@ -16,14 +16,25 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ #define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ +#include -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/core/framework/allocator.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT #include "tensorrt/include/NvInfer.h" +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +namespace tensorflow { +namespace tensorrt { +// std::align is not supported, so this function mimic its behavior. +void* Align(size_t alignment, size_t size, void*& ptr, size_t& space); +} // namespace tensorrt +} // namespace tensorflow +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT #if NV_TENSORRT_MAJOR == 3 // Define interface here temporarily until TRT 4.0 is released namespace nvinfer1 { @@ -38,7 +49,14 @@ class IGpuAllocator { namespace tensorflow { namespace tensorrt { -class TRTCudaAllocator : public nvinfer1::IGpuAllocator { +class TRTBaseAllocator : public nvinfer1::IGpuAllocator { + // Base allocator class so we can have a virtual destructor; + public: + // python wrapper seems to be not happy with an pure virtual destructor; + virtual ~TRTBaseAllocator() = default; +}; + +class TRTCudaAllocator : public TRTBaseAllocator { // Allocator implementation that is using cuda allocator instead of device // allocator in case we can't get device allocator from TF. public: @@ -48,16 +66,24 @@ class TRTCudaAllocator : public nvinfer1::IGpuAllocator { void free(void* memory) override; }; -class TRTDeviceAllocator : public nvinfer1::IGpuAllocator { +class TRTDeviceAllocator : public TRTBaseAllocator { // Allocator implementation wrapping TF device allocators. public: TRTDeviceAllocator(tensorflow::Allocator* allocator); - virtual ~TRTDeviceAllocator() {} + + // TODO(aaroey): base class doesn't have a virtual destructor, work with + // Nvidia to fix it. + virtual ~TRTDeviceAllocator() { + VLOG(1) << "Destroying allocator attached to " << allocator_->Name(); + } void* allocate(uint64_t size, uint64_t alignment, uint32_t flags) override; void free(void* memory) override; private: tensorflow::Allocator* allocator_; + + // supporting alignment from allocation request requires a map to free; + std::unordered_map mem_map_; }; } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f515ed03f245f11ad461bac07970c5001a56aaad --- /dev/null +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace tensorrt { + +bool RunTest(const size_t alignment, const size_t size, + const intptr_t orig_ptr_val, const size_t orig_space) { + void* const orig_ptr = reinterpret_cast(orig_ptr_val); + void* ptr = orig_ptr; + size_t space = orig_space; + void* result = Align(alignment, size, ptr, space); + if (result == nullptr) { + EXPECT_EQ(orig_ptr, ptr); + EXPECT_EQ(orig_space, space); + return false; + } else { + EXPECT_EQ(result, ptr); + const intptr_t ptr_val = reinterpret_cast(ptr); + EXPECT_EQ(0, ptr_val % alignment); + EXPECT_GE(ptr_val, orig_ptr_val); + EXPECT_GE(space, size); + EXPECT_LE(space, orig_space); + EXPECT_EQ(ptr_val + space, orig_ptr_val + orig_space); + return true; + } +} + +TEST(TRTAllocatorTest, Align) { + for (const size_t space : + {1, 2, 3, 4, 7, 8, 9, 10, 16, 32, 511, 512, 513, 700, 12345}) { + for (size_t alignment = 1; alignment <= space * 4; alignment *= 2) { + for (const intptr_t ptr_val : + {1ul, alignment == 1 ? 1ul : alignment - 1, alignment, alignment + 1, + alignment + (alignment / 2)}) { + if (ptr_val % alignment == 0) { + for (const size_t size : + {1ul, space == 1 ? 1ul : space - 1, space, space + 1}) { + EXPECT_EQ(space >= size, RunTest(alignment, size, ptr_val, space)); + } + } else { + EXPECT_FALSE(RunTest(alignment, space, ptr_val, space)); + const size_t diff = alignment - ptr_val % alignment; + if (space > diff) { + EXPECT_TRUE( + RunTest(alignment, space - diff, ptr_val + diff, space - diff)); + for (const size_t size : + {1ul, space - diff > 1 ? space - diff - 1 : 1ul, space - diff, + space - diff + 1, space - 1}) { + EXPECT_EQ(space - diff >= size, + RunTest(alignment, size, ptr_val, space)); + } + } else { + EXPECT_FALSE(RunTest(alignment, 1, ptr_val, space)); + } + } + } + } + } +} + +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc index dc7c93f869f5ef7c8eaa2a87eed26cfe69597fdb..dab1dd9343be7d5b033a3e04bf0b49fbbf37e9e5 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" #include -#include #include #include "tensorflow/core/platform/logging.h" @@ -37,20 +36,29 @@ TRTInt8Calibrator::TRTInt8Calibrator( : batch_size_(batch_size), done_(false), dev_buffers_(dev_buffers), - calib_running_(false), + // Make sure setBatch() waits until getBatch() is called (the first time). + calib_running_(true), batch_is_set_(false), engine_name_(engine_name) {} +TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data) + : batch_size_(0), + done_(true), + calib_running_(false), + batch_is_set_(false), + calibration_table_(calib_data) {} + bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, const cudaStream_t stream) { tensorflow::mutex_lock lock(cond_mtx_); - while ((calib_running_ || batch_is_set_) && - !done_) { // wait while calibration is running - cond_.wait(lock); - } + + // Wait while the queue is full or calibration is running. + while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock); if (done_) return false; CHECK(!calib_running_ && !batch_is_set_); VLOG(1) << "Set Batch Waiting finished"; + + // Sets the batch. for (const auto it : data) { auto devptr = dev_buffers_.find(it.first); if (devptr == dev_buffers_.end()) { @@ -59,8 +67,6 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, } const auto& d = devptr->second; - // TODO(aaroey): we should not use sync copy on default stream. Make sure - // stream->ThenMemcpy() is used in future PRs. // TODO(sami,aaroey): Need to figure out a way to ensure synchronization // between stream, perhaps using a tensor? auto status = cudaMemcpyAsync(d.first, it.second, d.second, @@ -72,8 +78,8 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, } // TODO(Sami, aaorey): Find an alternative way! - cudaStreamSynchronize( - stream); // we have to wait for the stream before returning! + // we have to wait for the stream before returning! + cudaStreamSynchronize(stream); batch_is_set_ = true; cond_.notify_all(); return true; @@ -82,23 +88,21 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, int num_bindings) { tensorflow::mutex_lock lock(cond_mtx_); + // Notify finish of last round of calibration. calib_running_ = false; cond_.notify_all(); - while ((!batch_is_set_ && !done_)) { // wait until new batch arrives - cond_.wait(lock); - } - if (done_) { - return false; - } + // Wait until new batch arrives + while ((!batch_is_set_ && !done_)) cond_.wait(lock); + if (done_) return false; + // Gets the batch for (int i = 0; i < num_bindings; i++) { auto it = dev_buffers_.find(names[i]); if (it == dev_buffers_.end()) { LOG(FATAL) << "Calibration engine asked for unknown tensor name '" << names[i] << "' at position " << i; } - bindings[i] = it->second.first; } batch_is_set_ = false; @@ -106,8 +110,21 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, return true; } +void TRTInt8Calibrator::waitAndSetDone() { + tensorflow::mutex_lock lock(cond_mtx_); + // Wait while the queue is full or calibration is running, so we don't miss + // the last batch. + while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock); + if (!done_) { + done_ = true; + cond_.notify_all(); + } +} + const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) { - return nullptr; + if (calibration_table_.empty()) return nullptr; + length = calibration_table_.size(); + return calibration_table_.data(); } void TRTInt8Calibrator::setDone() { @@ -117,7 +134,11 @@ void TRTInt8Calibrator::setDone() { } void TRTInt8Calibrator::writeCalibrationCache(const void* ptr, - std::size_t length) {} + std::size_t length) { + calibration_table_ = string((const char*)ptr, length); + VLOG(1) << "Got calibration data for " << engine_name_ << " @" << ptr + << " length=" << length; +} TRTInt8Calibrator::~TRTInt8Calibrator() { VLOG(1) << "Destroying calibrator for " << engine_name_; } diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h index d77aa2c5ab184756adaee38f88180b3c128ebe03..65466c9741989fda5f82fc27d813d026f35fe386 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h @@ -36,32 +36,59 @@ namespace tensorrt { struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { public: + // Construct a calibrator for future calibration. TRTInt8Calibrator( const std::unordered_map>& dev_buffers, int batch_size, string engine_name); + + // Construct a finalized calibrator where we don't need to run calibration any + // more, as the calibration data is provided. + TRTInt8Calibrator(const string& calibration_data); + + ~TRTInt8Calibrator(); + int getBatchSize() const override; + bool getBatch(void* bindings[], const char* names[], int num_bindings) override; + bool setBatch(const std::unordered_map& data, const cudaStream_t stream); + + // Wait until the last batch is consumed by the calibrator and set done. + void waitAndSetDone(); + + // Notify that calibration is done and future batches provided by setBatch() + // will be ignored. void setDone(); + + // If not null, calibration is skipped. const void* readCalibrationCache(std::size_t& length) override; + void writeCalibrationCache(const void* ptr, std::size_t length) override; - ~TRTInt8Calibrator(); + + const string& getCalibrationTableAsString() { return calibration_table_; } private: const int batch_size_; - tensorflow::mutex cond_mtx_; // mutex for condition_variable - tensorflow::condition_variable cond_; // condition variable to implement - // producer-consumer queue for - // calibration + + // mutex for condition_variable + tensorflow::mutex cond_mtx_; + + // condition variable to implement producer-consumer queue for calibration + tensorflow::condition_variable cond_; + + // Is calibration finished? bool done_; - const std::unordered_map> - dev_buffers_; // map to keep tensorrt input buffers and sizes keyed with - // buffer names + + // Map to keep tensorrt input buffers and sizes keyed with buffer names + const std::unordered_map> dev_buffers_; + bool calib_running_; bool batch_is_set_; + string engine_name_; + string calibration_table_; }; } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h index e3469124acd4b9f6f4dd81b9998aa60bfe469b35..d7d56cb95e033ea55bd3aa385a707e7a7cfc557b 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" #include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" @@ -34,50 +35,48 @@ limitations under the License. namespace tensorflow { namespace tensorrt { + class TRTCalibrationResource : public tensorflow::ResourceBase { public: - TRTCalibrationResource() - : calibrator_(nullptr), - builder_(nullptr), - network_(nullptr), - engine_(nullptr), - logger_(nullptr), - thr_(nullptr) {} - ~TRTCalibrationResource() { VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString(); + builder_.reset(); + engine_.reset(); + // We need to manually destroy the builder and engine before the allocator + // is destroyed. + allocator_.reset(); } string DebugString() override { std::stringstream oss; - oss << " Calibrator = " << std::hex << calibrator_ << std::dec << std::endl - << " Builder = " << std::hex << builder_ << std::dec << std::endl - << " Network = " << std::hex << network_ << std::dec << std::endl - << " Engine = " << std::hex << engine_ << std::dec << std::endl - << " Logger = " << std::hex << logger_ << std::dec << std::endl - << " Allocator = " << std::hex << allocator_.get() << std::dec - << std::endl - << " Thread = " << std::hex << thr_ << std::dec << std::endl; + using std::dec; + using std::endl; + using std::hex; + oss << " Calibrator = " << hex << calibrator_.get() << dec << endl + << " Builder = " << hex << builder_.get() << dec << endl + << " Engine = " << hex << engine_.get() << dec << endl + << " Logger = " << hex << &logger_ << dec << endl + << " Allocator = " << hex << allocator_.get() << dec << endl + << " Thread = " << hex << thr_.get() << dec << endl; return oss.str(); } - TRTInt8Calibrator* calibrator_; - nvinfer1::IBuilder* builder_; - nvinfer1::INetworkDefinition* network_; - nvinfer1::ICudaEngine* engine_; - std::shared_ptr allocator_; - tensorflow::tensorrt::Logger* logger_; + std::unique_ptr calibrator_; + TrtUniquePtrType builder_; + TrtUniquePtrType engine_; + std::unique_ptr allocator_; + tensorflow::tensorrt::Logger logger_; // TODO(sami): Use threadpool threads! - std::thread* thr_; + std::unique_ptr thr_; }; -class TRTWeightStore : public tensorflow::ResourceBase { +class TRTWeightStore { public: TRTWeightStore() {} virtual ~TRTWeightStore() { VLOG(1) << "Destroying store" << DebugString(); } - string DebugString() override { + string DebugString() { std::stringstream oss; size_t len_bytes = 0; for (const auto& v : store_) { diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index cc42913ecadc3e15fbb4a4a322f125579f075da2..b43f1b190f5f8cfe98959dd9f2838e4d45759e5c 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/segment/segment.h" +#include #include #include #include @@ -32,6 +33,7 @@ namespace tensorflow { namespace tensorrt { namespace segment { using ::tensorflow::strings::StrAppend; + // A simple graph representation to mirror tensorflow::Graph. This structure // helps saving memory since segmenter modifies the graph in place, preventing // the need to create a copy of the graph. It is composed of edges and nodes. @@ -215,7 +217,7 @@ namespace { bool CheckCycles(const std::unique_ptr& g, const SimpleNode* src, const std::vector& start) { - // copied from TF ReverseDFS. + // Copied from TF ReverseDFS, which only works for tensorflow::Graph. struct Work { SimpleNode* node; bool leave; // Are we entering or leaving n? @@ -269,6 +271,24 @@ bool CanContractEdge(const SimpleEdge* edge, // 1. Get all nodes incoming to 'dst', excluding 'src' // 2. Reverse DFS from those nodes // 3. If reverse DFS reaches 'src' then we have a cycle + // + // TODO(aaroey): there are several problems with the current approach: + // 1. src->dst->src, this is not detected but it should be; + // 2. src->dst->...(any node sequence that doesn't contain src)...->dst, this + // is detected but it should not be. + // + // Note that it's fine that dst connects back to src indirectly (i.e. through + // a path with length > 1 that consists of intermedia nodes other than src). + // While loops is one example. + // + // The goal is to make sure that the trt subgraph: + // 1. has no loops (i.e. is a DAG), and + // 2. if there is a path in the subgraph from X to Y (X and Y are both nodes + // in the subgraph), then all paths from X to Y are in the subgraph. + // + // To achieve this goal, the correct way seems to be: + // 1. remove any direct edge from src->dst; + // 2. detect if src can reach dst, if so they cannot be merged. std::vector dfs_start_nodes; for (SimpleNode* node : dst->in_nodes()) { if (node != src) { @@ -276,8 +296,8 @@ bool CanContractEdge(const SimpleEdge* edge, } } - bool is_cycle = CheckCycles(graph, src, dfs_start_nodes); - return !is_cycle; + const bool has_cycle = CheckCycles(graph, src, dfs_start_nodes); + return !has_cycle; } } // namespace @@ -342,22 +362,20 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph, } tensorflow::Status SegmentGraph( - const tensorflow::GraphDef& gdef, - const std::function& candidate_fn, - const SegmentOptions& options, SegmentNodesVector* segments) { - // Create a Graph representation of the GraphDef. - tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(), - gdef.library()); - tensorflow::Graph graph(flib); - TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( - tensorflow::GraphConstructorOptions(), gdef, &graph)); - return SegmentGraph(&graph, candidate_fn, options, segments); -} - -tensorflow::Status SegmentGraph( - tensorflow::Graph* tf_graph, + const tensorflow::Graph* tf_graph, const std::function& candidate_fn, + const std::function& input_candidate_fn, + const std::function& output_candidate_fn, const SegmentOptions& options, SegmentNodesVector* segments) { + // Steps: + // 1. run the segmentation algorithm to find all the segments, which uses + // candidate_fn to determine the candidates segment nodes; + // 2. for each segments, remove the nodes that are inputs/outputs of the + // segment but are not eligible, using input/output_candidate_fn to + // determine the eligibilities; + // 3. convert the segment into expected return format and return the result. + + // --------------------------------- Step 1 --------------------------------- auto graph = std::unique_ptr(new SimpleGraph(tf_graph)); // Use a union-find to collect the nodes that belong to the same // segment. A node value of nullptr indicates that the node is not a candidate @@ -372,14 +390,19 @@ tensorflow::Status SegmentGraph( node_segments.emplace_back(node); } - // The segmentation algorithm below visits nodes in reverse - // topological order and attempts to merge nodes along output - // edges. That means that subgraphs grow from the output-side of the - // network towards the inputs. In general this is not guaranteed to - // produce a globally optimal segmentation. In the future if we have - // a measure of how beneficial it is to include a given node in a - // TRT subgraph then we can revisit this algorithm to take advantage - // of that information. + // The segmentation algorithm below visits nodes in reverse topological order + // and attempts to merge nodes along output edges. That means that subgraphs + // grow from the output-side of the network towards the inputs. + // + // In general this is not guaranteed to produce a globally optimal + // segmentation. For exaample, consider graph with node {A, B, C, D} and edges + // {A->B, A->C, B->D, C->D), where A, B, D are trt compatible but C is not, so + // in theory we can choose to contract either A, B or B, D but not both, but + // here it always choose to contract B, D. + // + // In the future if we have a measure of how beneficial it is to include a + // given node in a TRT subgraph then we can revisit this algorithm to take + // advantage of that information. std::vector tforder; tensorflow::GetPostOrder(*tf_graph, &tforder); // use postorder implementation from tensorflow and construct mirror in @@ -391,44 +414,39 @@ tensorflow::Status SegmentGraph( } for (const SimpleNode* node : order) { // All output nodes of 'node' have been visited... - VLOG(2) << "Trying node " << node->name() << " id=" << node->id(); - + VLOG(3) << "Trying node " << node->name() << " id=" << node->id(); // 'node' must be a TRT candidate... if (node_segments[node->id()].Value() == nullptr) { - VLOG(2) << "... not a TRT candidate"; + VLOG(3) << "... not a TRT candidate"; continue; } - // Contract output edges to combine 'node' with output // nodes. Iterate since combining two nodes may unblock other // combining. while (true) { std::set contract_edges; for (const SimpleEdge* out_edge : node->out_edges()) { - VLOG(2) << "... out node " << out_edge->dst()->name() << " ( " + VLOG(3) << "... out node " << out_edge->dst()->name() << " ( " << out_edge->dst()->id() << " <- " << node->id() << " )"; if (out_edge->IsControlEdge()) { - VLOG(2) << "... ... Control Edge, Skipping"; + VLOG(3) << "... ... Control Edge, Skipping"; continue; } // Out node must be TRT candidate... if (node_segments[out_edge->dst()->id()].Value() == nullptr) { - VLOG(2) << "... ... not a TRT candidate"; + VLOG(3) << "... ... not a TRT candidate"; continue; } - if (CanContractEdge(out_edge, graph)) { - VLOG(2) << "... ... can contract"; + VLOG(3) << "... ... can contract"; contract_edges.insert(out_edge); } else { - VLOG(2) << "... ... cannot contract, would form cycle"; + VLOG(3) << "... ... cannot contract, would form cycle"; } } - if (contract_edges.empty()) { break; } - // Contract edges and collect the adjacent nodes into the same // segment/subgraph. while (!contract_edges.empty()) { @@ -436,7 +454,7 @@ tensorflow::Status SegmentGraph( const SimpleNode* src = contract_edge->src(); const SimpleNode* dst = contract_edge->dst(); - VLOG(2) << "Merge " << src->name() << " <- " << dst->name() << " (" + VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " (" << src->id() << " <- " << dst->id(); node_segments[src->id()].Merge(&node_segments[dst->id()]); @@ -457,11 +475,22 @@ tensorflow::Status SegmentGraph( // Collect the segments/subgraphs. Each subgraph is represented by a // set of the names of the nodes in that subgraph. - std::unordered_map> sg_map; + + // A map from the segment identifier (currently the name of the root node of + // the segment tree) to the segment nodes set. + std::map> sg_map; + + // A map from the segment identifier (currently the name of the root node of + // the segment tree) to the device names that the nodes in the segment are + // assigned to. + // + // TODO(aaroey): nodes assigned to different devices should not be merged, + // fix this. std::unordered_map> device_maps; + for (auto& u : node_segments) { if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) { - sg_map[u.ParentValue()->name()].insert(u.Value()->name()); + sg_map[u.ParentValue()->name()].insert(u.Value()->tf_node()); auto tf_node = u.Value()->tf_node(); // has_assigned_device_name() is expected to return true // when called from optimization pass. However, since graph @@ -482,25 +511,113 @@ tensorflow::Status SegmentGraph( } } + // --------------------------------- Step 2 --------------------------------- + // Remove ineligible input/output nodes. + for (auto& itr : sg_map) { + std::set& segment_nodes = itr.second; + VLOG(1) << "Segment original size: " << segment_nodes.size(); + while (true) { + std::deque in_nodes_que, out_nodes_que; + // Find an input node that is not eligible and add it to the queue. + // Nodes that has no incoming edges should not be treated as "input", + // as there are really no inputs to them. Similar for output nodes. + for (auto node : segment_nodes) { + bool added = false; + for (const tensorflow::Edge* edge : node->in_edges()) { + if (!edge->IsControlEdge() && !edge->src()->IsSource() && + !segment_nodes.count(edge->src())) { // 'node' is an input node. + if (!input_candidate_fn(edge)) { + in_nodes_que.push_back(node); + added = true; + break; + } + } + } + if (added) continue; // Only adding the node once to either queue. + for (const tensorflow::Edge* edge : node->out_edges()) { + if (!edge->dst()->IsSink() && !edge->IsControlEdge() && + !segment_nodes.count(edge->dst())) { // 'node' is an output node. + if (!output_candidate_fn(edge)) { + out_nodes_que.push_back(node); + break; + } + } + } + } + if (in_nodes_que.empty() && out_nodes_que.empty()) { + // No more ineligible input/output nodes. + break; + } + // Now for each ineligible node, remove all of its inputs or outputs from + // the subgraph. + // + // It can be proven that, if the original subgraph: + // 1. is a DAG, and + // 2. all paths between two nodes in the subgraph are all inside the + // subgraph + // then after doing this operation the resulting subgraph will keep the + // same properties 1 and 2. + // + // For simplicity we use heuristics: for input and const output nodes + // remove all their inputs, and for non-const output nodes remove all + // their outputs. In this way, for common cases the number of removed + // nodes should be minimum. + auto remove_nodes = [&segment_nodes]( + bool is_input_nodes, + std::deque* que) { + // Run a BFS on the queue to find all the input/output nodes. + std::set visited; + std::set logged(que->begin(), que->end()); + while (!que->empty()) { + auto node = que->front(); + que->pop_front(); + if (!visited.insert(node).second) continue; + segment_nodes.erase(node); + for (auto in : (is_input_nodes || node->type_string() == "Const") + ? node->in_nodes() + : node->out_nodes()) { + if (segment_nodes.count(in)) { + que->push_back(in); + if (VLOG_IS_ON(2)) { + if (!logged.count(in)) { + VLOG(2) << "----> Need to remove node " << in->name() + << " because one of its " + << (is_input_nodes ? "output" : "input") + << " nodes in the graph was removed: " + << node->name(); + logged.insert(in); + } + } + } + } + } + }; + remove_nodes(true, &in_nodes_que); + remove_nodes(false, &out_nodes_que); + } + VLOG(1) << "Segment new size: " << segment_nodes.size(); + } + + // --------------------------------- Step 3 --------------------------------- // Convert the segments into the expected return format for (const auto& itr : sg_map) { - const auto& segment_node_names = itr.second; + const std::set& segment_nodes = itr.second; if (VLOG_IS_ON(1)) { - string s; - for (const auto& name : segment_node_names) { - s += " " + name; - } - VLOG(1) << "Segment " << segments->size() << ":" << s; + string s = "parent=" + itr.first + ":"; + for (auto node : segment_nodes) s += " " + node->name(); + VLOG(1) << "Segment " << segments->size() << ": " << s; } // Don't use small segments. - if (static_cast(segment_node_names.size()) < - options.minimum_segment_size) { + if (static_cast(segment_nodes.size()) < options.minimum_segment_size) { VLOG(1) << "Segment " << segments->size() << " has only " - << segment_node_names.size() << " nodes, dropping"; + << segment_nodes.size() << " nodes, dropping"; continue; } + // TODO(sami): Make segmenter placement aware once trtscopes are in place + std::set segment_node_names; + for (auto node : itr.second) segment_node_names.insert(node->name()); const auto& dev_itr = device_maps.find(itr.first); if (dev_itr == device_maps.end() || dev_itr->second.empty()) { VLOG(1) << "No device assigned to segment " << segments->size(); diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h index 1568dd915344e6ba982b5a5550cc5386e047ff9f..8c44eb782aa37052680d0e06023f29dc65e327c6 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.h +++ b/tensorflow/contrib/tensorrt/segment/segment.h @@ -29,8 +29,9 @@ namespace tensorflow { namespace tensorrt { namespace segment { -// vector of segments, each entry contains a device name and a set of nodes in -// segment +// Vector of segments, each entry contains a set of node names and a device name +// in the segment. +// TODO(aaroey): use node pointer instead of node name. using SegmentNodesVector = std::vector, string>>; struct SegmentOptions { @@ -39,20 +40,6 @@ struct SegmentOptions { std::set exclude_node_list; }; -// Get the subgraphs of a graph that can be handled by TensorRT. -// -// @param gdef The GraphDef describing the network -// @param candidate_fn A function that returns true for a NodeDef if -// that node can be handled by TensorRT. -// @param segments Returns the TensorRT segments/subgraphs. Each entry -// in the vector describes a subgraph by giving a set of the names of -// all the NodeDefs in that subgraph. -// @return the status. -tensorflow::Status SegmentGraph( - const tensorflow::GraphDef& gdef, - const std::function& candidate_fn, - const SegmentOptions& options, SegmentNodesVector* segments); - // Get the subgraphs of a graph that can be handled by TensorRT. // // @param graph tensorflow::Graph of the network @@ -63,8 +50,10 @@ tensorflow::Status SegmentGraph( // all the NodeDefs in that subgraph. // @return the status. tensorflow::Status SegmentGraph( - tensorflow::Graph* tf_graph, + const tensorflow::Graph* tf_graph, const std::function& candidate_fn, + const std::function& input_candidate_fn, + const std::function& output_candidate_fn, const SegmentOptions& options, SegmentNodesVector* segments); } // namespace segment diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc index f5b2d258d70d5577a9d68f2d9f6d6e678ede97ce..5937fa8259a39339e92b150862d195ee1f23f70a 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -14,350 +14,245 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/segment/segment.h" -#include "tensorflow/c/c_api.h" -#include "tensorflow/core/framework/graph.pb.h" + +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session.h" namespace tensorflow { namespace tensorrt { namespace segment { namespace test { +namespace ops = ::tensorflow::ops; class SegmentTest : public ::testing::Test { - public: - bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def); - - TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name); - TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, - TF_Status* s, const char* name); - - std::function MakeCandidateFn( - const std::set& node_names); - protected: - void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name, - TF_Operation** op); - void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, - TF_Status* s, const char* name, TF_Operation** op, bool check); - - SegmentOptions default_options_; -}; - -bool SegmentTest::GetGraphDef(TF_Graph* graph, - tensorflow::GraphDef* graph_def) { - TF_Status* s = TF_NewStatus(); - TF_Buffer* buffer = TF_NewBuffer(); - TF_GraphToGraphDef(graph, buffer, s); - bool ret = TF_GetCode(s) == TF_OK; - EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length); - TF_DeleteBuffer(buffer); - TF_DeleteStatus(s); - return ret; -} + std::function MakeCandidateFn( + const std::set& node_names) { + return [node_names](const tensorflow::Node* node) -> bool { + return node_names.find(node->name()) != node_names.end(); + }; + } -std::function SegmentTest::MakeCandidateFn( - const std::set& node_names) { - return [node_names](const tensorflow::Node* node) -> bool { - return node_names.find(node->name()) != node_names.end(); - }; -} + std::function MakeInputEdgeCandidateFn( + const std::set& node_names) { + return [node_names](const tensorflow::Edge* in_edge) -> bool { + return node_names.find(in_edge->dst()->name()) != node_names.end(); + }; + } -void SegmentTest::PlaceholderHelper(TF_Graph* graph, TF_Status* s, - const char* name, TF_Operation** op) { - TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name); - TF_SetAttrType(desc, "dtype", TF_INT32); - *op = TF_FinishOperation(desc, s); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - ASSERT_NE(*op, nullptr); -} + std::function MakeOutputEdgeCandidateFn( + const std::set& node_names) { + return [node_names](const tensorflow::Edge* out_edge) -> bool { + return node_names.find(out_edge->src()->name()) != node_names.end(); + }; + } -TF_Operation* SegmentTest::Placeholder(TF_Graph* graph, TF_Status* s, - const char* name) { - TF_Operation* op; - PlaceholderHelper(graph, s, name, &op); - return op; -} + void RunTest(const tensorflow::Graph* graph, + const std::set& candidates, + const std::set& input_candidates, + const std::set& output_candidates, + const std::vector>& expected_segments) { + SegmentNodesVector segments; + TF_EXPECT_OK(SegmentGraph(graph, MakeCandidateFn(candidates), + MakeInputEdgeCandidateFn(input_candidates), + MakeOutputEdgeCandidateFn(output_candidates), + default_options_, &segments)); + ValidateSegment(segments, expected_segments); + } -void SegmentTest::AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, - TF_Status* s, const char* name, TF_Operation** op, - bool check) { - TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); - TF_Output add_inputs[2] = {{l, 0}, {r, 0}}; - TF_AddInputList(desc, add_inputs, 2); - *op = TF_FinishOperation(desc, s); - if (check) { - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - ASSERT_NE(*op, nullptr); + void ValidateSegment(const SegmentNodesVector& segments, + const std::vector>& expected_segments) { + EXPECT_EQ(expected_segments.size(), segments.size()); + for (int i = 0; i < segments.size(); ++i) { + const auto& segment_node_names = segments[i].first; + const auto& expected = expected_segments[i]; + for (const auto& name : expected) { + EXPECT_TRUE(segment_node_names.count(name)) + << "Segment " << i << " is missing expected node: " << name; + } + if (segment_node_names.size() == expected.size()) continue; + for (const auto& name : segment_node_names) { + EXPECT_TRUE(expected.count(name)) + << "Unexpected node found in segment " << i << ": " << name; + } + } } -} -TF_Operation* SegmentTest::Add(TF_Operation* l, TF_Operation* r, - TF_Graph* graph, TF_Status* s, - const char* name) { - TF_Operation* op; - AddHelper(l, r, graph, s, name, &op, true); - return op; + SegmentOptions default_options_; +}; + +std::set operator-(const std::set& lhs, const string& rhs) { + std::set result = lhs; + CHECK(result.erase(rhs)); + return result; } TEST_F(SegmentTest, Empty) { - TF_Graph* graph = TF_NewGraph(); - - GraphDef graph_def; - ASSERT_TRUE(GetGraphDef(graph, &graph_def)); - - SegmentNodesVector segments; - ASSERT_EQ( - SegmentGraph(graph_def, MakeCandidateFn({}), default_options_, &segments), - tensorflow::Status::OK()); - + Scope s = Scope::NewRootScope(); + tensorflow::Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); // Expect no segments/subgraphs. - EXPECT_TRUE(segments.empty()); - TF_DeleteGraph(graph); + RunTest(&g, {}, {}, {}, {}); } TEST_F(SegmentTest, Simple) { - TF_Status* s = TF_NewStatus(); - TF_Graph* graph = TF_NewGraph(); - // feed - // // || + // // \\ // add0 add1 - // | | / + // | \ / // | add2 - // | / || + // | / \\ // add3 add4 - // | / + // \ / // - // - TF_Operation* feed = Placeholder(graph, s, "feed"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); - - TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add1 = Add(feed, feed, graph, s, "add1"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add2 = Add(add0, add1, graph, s, "add2"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add3 = Add(add0, add2, graph, s, "add3"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("add3"), string(TF_OperationName(add3))); - TF_Operation* add4 = Add(add2, add2, graph, s, "add4"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("add4"), string(TF_OperationName(add4))); - - GraphDef graph_def; - ASSERT_TRUE(GetGraphDef(graph, &graph_def)); - - SegmentNodesVector segments; - ASSERT_EQ( - SegmentGraph(graph_def, - MakeCandidateFn({"add0", "add1", "add2", "add3", "add4"}), - default_options_, &segments), - tensorflow::Status::OK()); - - // Expect all Add operations to be collapsed into a single segment - ASSERT_EQ(segments.size(), 1); - std::vector expected{"add0", "add1", "add2", "add3", "add4"}; - for (const auto& ex : expected) { - EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end()) - << "Missing expected node " << ex; - } - TF_DeleteGraph(graph); - TF_DeleteStatus(s); + Scope s = Scope::NewRootScope(); + auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT); + auto add0 = ops::Add(s.WithOpName("add0"), feed, feed); + auto add1 = ops::Add(s.WithOpName("add1"), feed, feed); + auto add2 = ops::Add(s.WithOpName("add2"), add0, add1); + auto add3 = ops::Add(s.WithOpName("add3"), add0, add2); + auto add4 = ops::Add(s.WithOpName("add4"), add2, add2); + tensorflow::Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); + + // All Add operations are candidates, and we expect all of them to be + // collapsed into a single segment + const std::set all_adds = {"add0", "add1", "add2", "add3", "add4"}; + RunTest(&g, all_adds, all_adds, all_adds, {all_adds}); + + // Make add1 not a candidate, and we expect all other Add operations to be + // collapsed into a single segment + auto without_add1 = all_adds - "add1"; + RunTest(&g, without_add1, without_add1, without_add1, {without_add1}); + + // Make add1 not a candidate and add2 not an input candidate, and we expect + // add0 and add2 are removed from the segment. + auto without_add2 = all_adds - "add2"; + RunTest(&g, without_add1, without_add2, without_add1, {{"add3", "add4"}}); + + // Making add2 not an input candidate itself won't affect anything. + RunTest(&g, all_adds, without_add2, all_adds, {all_adds}); + + // Making add1 not an input candidate. + RunTest(&g, all_adds, without_add1, all_adds, {without_add1}); + + // Making add3 not an output candidate doesn't affect anything, since it's + // output is sink. + auto without_add3 = all_adds - "add3"; + RunTest(&g, all_adds, all_adds, without_add3, {all_adds}); } TEST_F(SegmentTest, AvoidCycle) { - TF_Status* s = TF_NewStatus(); - TF_Graph* graph = TF_NewGraph(); - - // add2 is not a TRT candidate so add0/add3 cannot be formed as a - // subgraph - // // feed - // // || + // // \\ // add0 add1 - // | | / + // | \ / // | add2 - // | / || + // | / \\ // add3 add4 - // | / + // \ / // - // - TF_Operation* feed = Placeholder(graph, s, "feed"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); - - TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add1 = Add(feed, feed, graph, s, "add1"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add2 = Add(add0, add1, graph, s, "add2"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add3 = Add(add0, add2, graph, s, "add3"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("add3"), string(TF_OperationName(add3))); - TF_Operation* add4 = Add(add2, add2, graph, s, "add4"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("add4"), string(TF_OperationName(add4))); - - GraphDef graph_def; - ASSERT_TRUE(GetGraphDef(graph, &graph_def)); - - SegmentNodesVector segments; - ASSERT_EQ( - SegmentGraph(graph_def, MakeCandidateFn({"add0", "add1", "add3", "add4"}), - default_options_, &segments), - tensorflow::Status::OK()); - - // Expect no subgraphs - EXPECT_EQ(segments.size(), 0); - TF_DeleteGraph(graph); - TF_DeleteStatus(s); + Scope s = Scope::NewRootScope(); + auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT); + auto add0 = ops::Add(s.WithOpName("add0"), feed, feed); + auto add1 = ops::Add(s.WithOpName("add1"), feed, feed); + auto add2 = ops::Add(s.WithOpName("add2"), add0, add1); + auto add3 = ops::Add(s.WithOpName("add3"), add0, add2); + auto add4 = ops::Add(s.WithOpName("add4"), add2, add2); + tensorflow::Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); + + // add2 is not a TRT candidate so there should be no segments generated. + const std::set without_add2 = {"add0", "add1", "add3", "add4"}; + RunTest(&g, without_add2, without_add2, without_add2, {}); } TEST_F(SegmentTest, Multiple) { - TF_Status* s = TF_NewStatus(); - TF_Graph* graph = TF_NewGraph(); - - // add5 is not a TRT candidate so two subgraphs should be formed - // - // feed - // // || || - // add0 add1 add7 - // | | / / || - // | add2-----add5 add8 - // | / | | | | - // add3 add4 add6 - // | | / - // - // - TF_Operation* feed = Placeholder(graph, s, "feed"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); - - TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add1 = Add(feed, feed, graph, s, "add1"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add7 = Add(feed, feed, graph, s, "add7"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add2 = Add(add0, add1, graph, s, "add2"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add5 = Add(add2, add7, graph, s, "add5"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add8 = Add(add7, add7, graph, s, "add8"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add3 = Add(add0, add2, graph, s, "add3"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("add3"), string(TF_OperationName(add3))); - TF_Operation* add4 = Add(add2, add5, graph, s, "add4"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("add4"), string(TF_OperationName(add4))); - TF_Operation* add6 = Add(add5, add8, graph, s, "add6"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("add6"), string(TF_OperationName(add6))); - - GraphDef graph_def; - ASSERT_TRUE(GetGraphDef(graph, &graph_def)); - - SegmentNodesVector segments; - ASSERT_EQ(SegmentGraph(graph_def, - MakeCandidateFn({"add0", "add1", "add2", "add3", - "add4", "add6", "add7", "add8"}), - default_options_, &segments), - tensorflow::Status::OK()); - - // Expect two subgraphs - EXPECT_EQ(segments.size(), 2); - - std::vector expected0{"add6", "add8"}; - for (const auto& ex : expected0) { - EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end()) - << "Missing expected node " << ex; - } - - std::vector expected1{"add0", "add1", "add2", "add3"}; - for (const auto& ex : expected1) { - EXPECT_TRUE(segments[1].first.find(ex) != segments[1].first.end()) - << "Missing expected node " << ex; - } - TF_DeleteGraph(graph); - TF_DeleteStatus(s); + // feed + // // || \\ + // add0 add1 add7 + // | \ / / \\ + // | add2 / \\ + // | || \ | || + // | || add5 add8 + // | / \ / \ / + // add3 add4 add6 + // \ | / + // + Scope s = Scope::NewRootScope(); + auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT); + auto add0 = ops::Add(s.WithOpName("add0"), feed, feed); + auto add1 = ops::Add(s.WithOpName("add1"), feed, feed); + auto add7 = ops::Add(s.WithOpName("add7"), feed, feed); + auto add2 = ops::Add(s.WithOpName("add2"), add0, add1); + auto add5 = ops::Add(s.WithOpName("add5"), add2, add7); + auto add8 = ops::Add(s.WithOpName("add8"), add7, add7); + auto add3 = ops::Add(s.WithOpName("add3"), add0, add2); + auto add4 = ops::Add(s.WithOpName("add4"), add2, add5); + auto add6 = ops::Add(s.WithOpName("add6"), add5, add8); + tensorflow::Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); + + const std::set all_adds = {"add0", "add1", "add2", "add3", "add4", + "add5", "add6", "add7", "add8"}; + // Make add5 not a TRT candidate, and we expect two segments. + auto without_add5 = all_adds - "add5"; + RunTest(&g, without_add5, without_add5, without_add5, + {{"add0", "add1", "add2", "add3"}, {"add6", "add8"}}); + + // Make add8 not a candidate and add6 not an input candidate, then all direct + // and indirect inputs of add6 will be removed from the segment. + auto without_add8 = all_adds - "add8"; + auto without_add6 = all_adds - "add6"; + RunTest(&g, without_add8, without_add6, all_adds, {{"add3", "add4"}}); + + // Make add3 not a candidate and add0 not an output candidate, then all + // direct and indirect outputs of add0 will be removed from the segment. + auto without_add3 = all_adds - "add3"; + auto without_add0 = all_adds - "add0"; + RunTest(&g, without_add3, all_adds, without_add0, {{"add1", "add7", "add8"}}); } TEST_F(SegmentTest, BigIfElse) { - TF_Status* s = TF_NewStatus(); - TF_Graph* graph = TF_NewGraph(); - - // add2 is not a TRT candidate - // // feed // || // add0 - // // || + // // \\ // add1 add4 // || || // add2 add5 // || || // add3 add6 - // || // + // \\ // // add7 // || // - // - TF_Operation* feed = Placeholder(graph, s, "feed"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); - - TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add1 = Add(add0, add0, graph, s, "add1"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add2 = Add(add1, add1, graph, s, "add2"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add3 = Add(add2, add2, graph, s, "add3"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add4 = Add(add0, add0, graph, s, "add4"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add5 = Add(add4, add4, graph, s, "add5"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add6 = Add(add5, add5, graph, s, "add6"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add7 = Add(add3, add6, graph, s, "add7"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("add7"), string(TF_OperationName(add7))); - - GraphDef graph_def; - ASSERT_TRUE(GetGraphDef(graph, &graph_def)); - - SegmentNodesVector segments; - ASSERT_EQ(SegmentGraph(graph_def, - MakeCandidateFn({"add0", "add1", "add3", "add4", - "add5", "add6", "add7"}), - default_options_, &segments), - tensorflow::Status::OK()); - - // Expect 2 subgraphs - EXPECT_EQ(segments.size(), 2); - - std::vector expected0{"add3", "add4", "add5", "add6", "add7"}; - for (const auto& ex : expected0) { - EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end()) - << "Missing expected node " << ex; - } - - std::vector expected1{"add0", "add1"}; - for (const auto& ex : expected1) { - EXPECT_TRUE(segments[1].first.find(ex) != segments[1].first.end()) - << "Missing expected node " << ex; - } - TF_DeleteGraph(graph); - TF_DeleteStatus(s); + Scope s = Scope::NewRootScope(); + auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT); + auto add0 = ops::Add(s.WithOpName("add0"), feed, feed); + auto add1 = ops::Add(s.WithOpName("add1"), add0, add0); + auto add2 = ops::Add(s.WithOpName("add2"), add1, add1); + auto add3 = ops::Add(s.WithOpName("add3"), add2, add2); + auto add4 = ops::Add(s.WithOpName("add4"), add0, add0); + auto add5 = ops::Add(s.WithOpName("add5"), add4, add4); + auto add6 = ops::Add(s.WithOpName("add6"), add5, add5); + auto add7 = ops::Add(s.WithOpName("add7"), add3, add6); + tensorflow::Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); + + // Make add2 not a TRT candidate, and we expect 2 segments. + const std::set all_adds = {"add0", "add1", "add2", "add3", + "add4", "add5", "add6", "add7"}; + RunTest(&g, all_adds - "add2", all_adds, all_adds, + {{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}}); } } // namespace test diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index f36495f6b69ecb2f2a8d730b9ae4919fea3c04b8..f30dba59ad55317d7ad7730e4dc66c9aba4e6a6b 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -28,59 +28,47 @@ limitations under the License. namespace tensorflow { namespace shape_inference { -tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) { - tensorflow::tensorrt::Logger logger; - string serialized_engine; - TF_RETURN_IF_ERROR(context->GetAttr("serialized_engine", &serialized_engine)); - nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger); - nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine( - serialized_engine.c_str(), serialized_engine.size(), - tensorrt::PluginFactoryTensorRT::GetInstance()); - - int num_batch = -1; - std::vector<::tensorflow::DataType> input_type; - TF_RETURN_IF_ERROR(context->GetAttr("InT", &input_type)); - for (size_t i = 0; i < context->num_inputs(); i++) { - // Check if input shape is legit - auto input_shape = context->input(i); - for (int j = 0; j < context->Rank(input_shape); j++) { - auto dim_handler = context->Dim(input_shape, j); - if (j == 0) { - if (i == 0) { - num_batch = context->Value(dim_handler); - } else if (num_batch != context->Value(dim_handler)) { - // TODO(jie): TensorRT engine requires consistent batch between inputs - // tensors. Segmenter should be aware of this. - LOG(FATAL) << "TensorRT engine requires consistent batch size"; - } - } - } +tensorflow::Status TRTEngineOpShapeInference(InferenceContext* c) { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->UnknownShape()); } - // Arrange input here - std::vector input_nodes; - TF_RETURN_IF_ERROR(context->GetAttr("input_nodes", &input_nodes)); - - // Arrange output here - std::vector output_nodes; - TF_RETURN_IF_ERROR(context->GetAttr("output_nodes", &output_nodes)); - for (size_t i = 0; i < output_nodes.size(); i++) { - int binding_index = trt_engine->getBindingIndex(output_nodes[i].c_str()); - ShapeHandle output_shape; - std::vector dim_vec; - dim_vec.emplace_back(context->MakeDim(num_batch)); - if (binding_index != -1) { - auto dims = trt_engine->getBindingDimensions(binding_index); - for (int j = 0; j < dims.nbDims; j++) { - dim_vec.emplace_back(context->MakeDim(dims.d[j])); - } - } else { - LOG(FATAL) << "TensorRT engine cannot find binding: " << output_nodes[i]; + // Check the sanity of the input shapes. + std::vector input_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("input_shapes", &input_shapes)); + if (input_shapes.size() != c->num_inputs()) { + return tensorflow::errors::InvalidArgument( + "The actual number of inputs doesn't match the number of input " + "shapes set in the attr: ", + c->num_inputs(), " vs ", input_shapes.size()); + } + bool input_match = true; + for (int i = 0; i < c->num_inputs(); ++i) { + ShapeHandle handle; + TF_RETURN_IF_ERROR( + c->MakeShapeFromTensorShape(input_shapes.at(i), &handle)); + ShapeHandle merged; + if (!c->Merge(c->input(i), handle, &merged).ok()) { + // Input shape doesn't match what was set in attr, fine. + input_match = false; } - output_shape = context->MakeShape(dim_vec); - context->set_output(i, output_shape); } + // Check the sanity of the output shapes. + std::vector output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + if (output_shapes.size() != c->num_outputs()) { + return tensorflow::errors::InvalidArgument( + "The actual number of outputs doesn't match the number of output " + "shapes set in the attr: ", + c->num_outputs(), " vs ", output_shapes.size()); + } + for (size_t i = 0; i < output_shapes.size(); ++i) { + ShapeHandle handle; + TF_RETURN_IF_ERROR( + c->MakeShapeFromTensorShape(output_shapes.at(i), &handle)); + if (input_match) c->set_output(i, handle); + } return Status::OK(); } diff --git a/tensorflow/contrib/tensorrt/tensorrt_test.cc b/tensorflow/contrib/tensorrt/tensorrt_test.cc index 3712a9a6fe349d949ef2666652b9d750538d5535..769982c6456f76663e50fe3ec59651127e3720ac 100644 --- a/tensorflow/contrib/tensorrt/tensorrt_test.cc +++ b/tensorflow/contrib/tensorrt/tensorrt_test.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/common_runtime/gpu/gpu_init.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/test.h" #if GOOGLE_CUDA @@ -130,6 +132,13 @@ void Execute(nvinfer1::IExecutionContext* context, const float* input, } TEST(TensorrtTest, BasicFunctions) { + // Handle the case where the test is run on machine with no gpu available. + if (CHECK_NOTNULL(GPUMachineManager())->VisibleDeviceCount() <= 0) { + LOG(WARNING) << "No gpu device available, probably not being run on a gpu " + "machine. Skipping..."; + return; + } + // Create the network model. nvinfer1::IHostMemory* model = CreateNetwork(); // Use the model to create an engine and then an execution context. diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8ea5a6373525a8045d13f70aa9e12d66d4c08f0a --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/base_test.py @@ -0,0 +1,346 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Basic tests for TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.python import trt_convert +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + + +class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing single segment.""" + # TODO(aaroey): test graph with different dtypes. + dtype = dtypes.float32 + input_name = "input" + input_dims = [100, 24, 24, 2] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtype, shape=[None] + input_dims[1:], name=input_name) + with g.device("/GPU:0"): + conv_filter = constant_op.constant( + [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], + name="weights", + dtype=dtype) + conv = nn.conv2d( + input=inp, + filter=conv_filter, + strides=[1, 2, 2, 1], + padding="SAME", + name="conv") + bias = constant_op.constant( + [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtype) + added = nn.bias_add(conv, bias, name="bias_add") + relu = nn.relu(added, "relu") + identity = array_ops.identity(relu, "identity") + pool = nn_ops.max_pool( + identity, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") + array_ops.squeeze(pool, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which + # breaks the connection check, fix it. + # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add", + # "relu", "identity", "max_pool"] + expected_engines=["my_trt_op_0"], + expected_output_dims=(100, 6, 6, 6), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing multiple segment.""" + # TODO(aaroey): test graph with different dtypes. + dtype = dtypes.float32 + input_name = "input" + input_dims = [100, 24, 24, 2] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtype, shape=[None] + input_dims[1:], name=input_name) + with g.device("/GPU:0"): + conv_filter = constant_op.constant( + [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], + name="weights", + dtype=dtype) + conv = nn.conv2d( + input=inp, + filter=conv_filter, + strides=[1, 2, 2, 1], + padding="SAME", + name="conv") + c1 = constant_op.constant( + np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c1") + p = math_ops.mul(conv, c1, name="mul") + c2 = constant_op.constant( + np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c2") + q = math_ops.div(conv, c2, name="div") + + edge = self.trt_incompatible_op(q, name="incompatible") + edge = math_ops.div(edge, edge, name="div1") + r = math_ops.add(edge, edge, name="add") + + p = math_ops.sub(p, edge, name="sub") + q = math_ops.mul(q, edge, name="mul1") + s = math_ops.add(p, q, name="add1") + s = math_ops.sub(s, r, name="sub1") + array_ops.squeeze(s, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which + # breaks the connection check, fix it. + # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1", + # "add", "sub1"]; + # - my_trt_op_1 should have ["weights","conv", "div"] + expected_engines=["my_trt_op_0", "my_trt_op_1"], + expected_output_dims=(100, 12, 12, 6), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): + + def setUp(self): + """Setup method.""" + super(PartiallyConvertedTestA, self).setUp() + # Let it fail to build the second engine. + trt_convert.add_test_value("my_trt_op_1:CreateTRTNode", "fail") + + def GetParams(self): + """Create a graph containing two segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + n = inp + for i in range(2): + c = constant_op.constant(1.0, name="c%d" % i) + n = math_ops.add(n, c, name="add%d" % i) + n = math_ops.mul(n, n, name="mul%d" % i) + edge = self.trt_incompatible_op(n, name="incompatible") + with g.control_dependencies([edge]): + c = constant_op.constant(1.0, name="c2") + n = math_ops.add(n, c, name="add2") + n = math_ops.mul(n, n, name="mul2") + c = constant_op.constant(1.0, name="c3") + n = math_ops.add(n, c, name="add3") + n = math_ops.mul(n, n, name="mul3") + array_ops.squeeze(n, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + # Only the first engine is built. + "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"] + }, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +class PartiallyConvertedTestB(PartiallyConvertedTestA): + + def setUp(self): + """Setup method.""" + super(PartiallyConvertedTestB, self).setUp() + # Let it fail to build the first engine. + trt_convert.clear_test_values("") + trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail") + + def GetParams(self): + """Create a graph containing two segment.""" + return super(PartiallyConvertedTestB, self).GetParams()._replace( + expected_engines={ + # Only the second engine is built. + "my_trt_op_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"] + }) + + +class ConstInputTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing multiple segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + n = inp + c = constant_op.constant(1.0, name="c") + # Adds control dependency from the constant op to a trt incompatible op, + # and adds control dependency from the trt incompatible op to all other + # ops, to make sure the constant op cannot be contracted with any trt + # segment that depends on it. + with g.control_dependencies([c]): + d = self.trt_incompatible_op(n, name="incompatible") + with g.control_dependencies([d]): + n = math_ops.add(n, c, name="add") + n = math_ops.mul(n, n, name="mul") + n = math_ops.add(n, n, name="add1") + n = self.trt_incompatible_op(n, name="incompatible1") + with g.control_dependencies([d]): + n = math_ops.add(n, c, name="add2") + n = math_ops.mul(n, n, name="mul1") + n = math_ops.add(n, n, name="add3") + array_ops.squeeze(n, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + "my_trt_op_0": ["add", "add1", "mul"], + "my_trt_op_1": ["add2", "add3", "mul1"] + }, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing single segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + n = inp + c = constant_op.constant(1.0, name="c") + n = math_ops.add(n, c, name="add") + n = math_ops.mul(n, n, name="mul") + n = math_ops.add(n, n, name="add1") + array_ops.squeeze(n, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={"my_trt_op_0": ["c", "add", "add1", "mul"]}, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing multiple segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + n = inp + c = constant_op.constant(1.0, name="c") + n = math_ops.add(n, c, name="add") + n = math_ops.mul(n, n, name="mul") + n = math_ops.add(n, n, name="add1") + n = self.trt_incompatible_op(n, name="incompatible1") + n = math_ops.add(n, c, name="add2") + n = math_ops.mul(n, n, name="mul1") + n = math_ops.add(n, n, name="add3") + array_ops.squeeze(n, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + "my_trt_op_0": ["add2", "add3", "mul1"], + # Why segment ["add", "add1", "mul"] was assigned segment id 1 + # instead of 0: the parent node of this segment is actually const + # node 'c', but it's removed later since it's const output of the + # segment which is not allowed. + "my_trt_op_1": ["add", "add1", "mul"] + }, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing multiple segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + c1 = constant_op.constant(1.0, name="c1") + c2 = constant_op.constant(1.0, name="c2") + d1 = constant_op.constant(1.0, name="d1") + d2 = self.trt_incompatible_op(inp, name="d2") + with g.control_dependencies([d1, d2]): + add = math_ops.add(inp, c1, name="add") + with g.control_dependencies([d1, d2]): + mul = math_ops.mul(add, add, name="mul") + with g.control_dependencies([d1, d2]): + add1 = math_ops.add(mul, mul, name="add1") + edge = self.trt_incompatible_op(add1, name="incompatible") + with g.control_dependencies([d1, d2, add, mul]): + add2 = math_ops.add(edge, c2, name="add2") + with g.control_dependencies([d1, d2, add1, mul]): + mul1 = math_ops.mul(add2, add2, name="mul1") + with g.control_dependencies([d1, d2, add, add1]): + add3 = math_ops.add(mul1, mul1, name="add3") + array_ops.squeeze(add3, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + "my_trt_op_0": ["c1", "add", "add1", "mul"], + "my_trt_op_1": ["c2", "add2", "add3", "mul1"] + }, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2e1107e30383926f6428c6551682caf66cd97498 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py @@ -0,0 +1,76 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Testing conversion of BatchMatMul in TF-TRT conversion.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [12, 5, 8, 12] + w1_name = "matmul_w1" + w1_dims = [12, 5, 12, 7] + w2_name = "matmul_w2" + w2_dims = [12, 12, 7] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtype, shape=[None] + input_dims[1:], name=input_name) + w1 = array_ops.placeholder(dtype=dtype, shape=w1_dims, name=w1_name) + w2 = array_ops.placeholder(dtype=dtype, shape=w2_dims, name=w2_name) + with g.device("/GPU:0"): + b = constant_op.constant(np.random.randn(12, 5, 12, 7), dtype=dtype) + c = constant_op.constant(np.random.randn(5, 1, 1), dtype=dtype) + d = constant_op.constant(np.random.randn(5, 1, 1), dtype=dtype) + x1 = math_ops.matmul(inp, b) + x1 = x1 + c + x2 = math_ops.matmul(inp, w1) + x2 = x2 * d + e = gen_array_ops.reshape(inp, [12, 40, 12]) + x3 = math_ops.matmul(e, w2) + f = constant_op.constant(np.random.randn(40, 1), dtype=dtype) + x3 = x3 + f + x3 = gen_array_ops.reshape(x3, [12, 5, 8, 7]) + out = x1 + x2 + x3 + array_ops.squeeze(out, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name, w1_name, w2_name], + input_dims=[input_dims, w1_dims, w2_dims], + expected_engines=["my_trt_op_0"], + expected_output_dims=(12, 5, 8, 7), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8be32f59b48e64412466370950298feafc03b35c --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py @@ -0,0 +1,115 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test + + +class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Testing conversion of BiasAdd MatMul in TF-TRT conversion.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [48, 12] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + + b = constant_op.constant(np.random.randn(12, 4), dtype=dtype) + x1 = math_ops.matmul(x, b) + b = constant_op.constant(np.random.randn(1, 4), dtype=dtype) + x1 = x1 + b + + b = constant_op.constant(np.random.randn(48, 4), dtype=dtype) + x2 = math_ops.matmul(x, b, transpose_a=True) + x2 = gen_array_ops.reshape(x2, [48, 1]) + + b = constant_op.constant(np.random.randn(4, 12), dtype=dtype) + x3 = math_ops.matmul(x, b, transpose_b=True) + + b = constant_op.constant(np.random.randn(16, 48), dtype=dtype) + x4 = math_ops.matmul(x, b, transpose_b=True, transpose_a=True) + x4 = gen_array_ops.reshape(x4, [48, 4]) + + x5 = gen_array_ops.reshape(x, [4, 144]) + b = constant_op.constant(np.random.randn(144, 48), dtype=dtype) + x5 = math_ops.matmul(x5, b) + b = constant_op.constant(np.random.randn(48), dtype=dtype) + x5 = nn.bias_add(x5, b) + x5 = gen_array_ops.reshape(x5, [48, 4]) + + x6 = gen_array_ops.reshape(x, [4, 12, 12]) + b = constant_op.constant(np.random.randn(12), dtype=dtype) + x6 = nn.bias_add(x6, b, data_format="NHWC") + x6 = gen_array_ops.reshape(x6, [48, -1]) + + x7 = gen_array_ops.reshape(x, [4, 12, 3, 4]) + b = constant_op.constant(np.random.randn(4), dtype=dtype) + x7 = nn.bias_add(x7, b, data_format="NHWC") + x7 = gen_array_ops.reshape(x7, [48, -1]) + + x8 = gen_array_ops.reshape(x, [4, 12, 3, 2, 2]) + b = constant_op.constant(np.random.randn(2), dtype=dtype) + x8 = nn.bias_add(x8, b, data_format="NHWC") + x8 = gen_array_ops.reshape(x8, [48, -1]) + + x9 = gen_array_ops.reshape(x, [4, 12, 3, 2, 2]) + b = constant_op.constant(np.random.randn(3), dtype=dtype) + x9 = nn.bias_add(x9, b, data_format="NCHW") + x9 = gen_array_ops.reshape(x9, [48, -1]) + + x10 = gen_array_ops.reshape(x, [4, 12, 3, 4]) + b = constant_op.constant(np.random.randn(12), dtype=dtype) + x10 = nn.bias_add(x10, b, data_format="NCHW") + x10 = gen_array_ops.reshape(x10, [48, -1]) + + x11 = gen_array_ops.reshape(x, [4, 12, 12]) + b = constant_op.constant(np.random.randn(4), dtype=dtype) + x11 = nn.bias_add(x11, b, data_format="NCHW") + x11 = gen_array_ops.reshape(x11, [48, -1]) + + out = array_ops.concat( + [x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11], axis=-1) + out = array_ops.squeeze(out, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines=[ + "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3", + "my_trt_op_4", "my_trt_op_5", "my_trt_op_6" + ], + expected_output_dims=(48, 89), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9316b14da07d5f7e47953504680e14d5d20c17a4 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py @@ -0,0 +1,136 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Tests for scale & elementwise layers in TF-TRT.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [10, 24, 24, 20] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + # scale + a = constant_op.constant(np.random.randn(1), dtype=dtype) + f = x + a + x = math_ops.sigmoid(f) + # scale + a = constant_op.constant(np.random.randn(1), dtype=dtype) + f = a + x + x = math_ops.sigmoid(f) + # scale + a = constant_op.constant(np.random.randn(24, 1, 1), dtype=dtype) + f = x + a + x = math_ops.sigmoid(f) + # scale + a = constant_op.constant(np.random.randn(24, 1, 1), dtype=dtype) + f = a + x + x = math_ops.sigmoid(f) + # scale + a = constant_op.constant(np.random.randn(24, 24, 20), dtype=dtype) + f = a + x + x = math_ops.sigmoid(f) + # scale + a = constant_op.constant(np.random.randn(24, 24, 20), dtype=dtype) + f = x + a + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(20), dtype=dtype) + f = x + a + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(20), dtype=dtype) + f = a + x + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(1, 24, 1, 1), dtype=dtype) + f = a + x + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(1, 24, 1, 1), dtype=dtype) + f = x + a + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(1, 24, 24, 1), dtype=dtype) + f = a + x + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(1, 24, 24, 1), dtype=dtype) + f = x + a + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(1, 24, 24, 20), dtype=dtype) + f = a + x + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(1, 24, 24, 20), dtype=dtype) + f = x + a + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(24, 20), dtype=dtype) + f = a + x + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(24, 20), dtype=dtype) + f = x + a + x = math_ops.sigmoid(f) + gen_array_ops.reshape(x, [5, -1], name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines=[ + "my_trt_op_0", + "my_trt_op_1", + "my_trt_op_2", + "my_trt_op_3", + "my_trt_op_4", + "my_trt_op_5", + "my_trt_op_6", + "my_trt_op_7", + "my_trt_op_8", + "my_trt_op_9", + "my_trt_op_10", + "my_trt_op_11", + "my_trt_op_12", + "my_trt_op_13", + "my_trt_op_14", + "my_trt_op_15", + ], + expected_output_dims=(5, 23040), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/concatenation_test.py b/tensorflow/contrib/tensorrt/test/concatenation_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1874b9dd45390407d3d36798cae620848df50c8d --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/concatenation_test.py @@ -0,0 +1,83 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.platform import test + + +class ConcatenationTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Testing Concatenation in TF-TRT conversion.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [2, 3, 3, 1] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + # scale + a = constant_op.constant(np.random.randn(3, 1, 1), dtype=dtype) + r1 = x / a + a = constant_op.constant(np.random.randn(3, 1, 1), dtype=dtype) + r2 = a / x + a = constant_op.constant(np.random.randn(1, 3, 1), dtype=dtype) + r3 = a + x + a = constant_op.constant(np.random.randn(1, 3, 1), dtype=dtype) + r4 = x * a + a = constant_op.constant(np.random.randn(3, 1, 1), dtype=dtype) + r5 = x - a + a = constant_op.constant(np.random.randn(3, 1, 1), dtype=dtype) + r6 = a - x + a = constant_op.constant(np.random.randn(3, 1), dtype=dtype) + r7 = x - a + a = constant_op.constant(np.random.randn(3, 1), dtype=dtype) + r8 = a - x + a = constant_op.constant(np.random.randn(3, 1, 1), dtype=dtype) + r9 = gen_math_ops.maximum(x, a) + a = constant_op.constant(np.random.randn(3, 1), dtype=dtype) + r10 = gen_math_ops.minimum(a, x) + a = constant_op.constant(np.random.randn(3), dtype=dtype) + r11 = x * a + a = constant_op.constant(np.random.randn(1), dtype=dtype) + r12 = a * x + concat1 = array_ops.concat([r1, r2, r3, r4, r5, r6], axis=-1) + concat2 = array_ops.concat([r7, r8, r9, r10, r11, r12], axis=3) + x = array_ops.concat([concat1, concat2], axis=-1) + gen_array_ops.reshape(x, [2, -1], name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines=["my_trt_op_0"], + expected_output_dims=(2, 126), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8c59000b70e04cedc84308249865cfcb23ce80a3 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py @@ -0,0 +1,68 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test + + +class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Test for Constant broadcasting in TF-TRT.""" + dtype = dtypes.float32 + input_name = 'input' + input_dims = [5, 12, 12, 2] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + filt1 = constant_op.constant( + 0.3, shape=(3, 3, 2, 1), dtype=dtype, name='filt1') + y1 = nn.conv2d(x, filt1, strides=[1, 1, 1, 1], padding='SAME', name='y1') + z1 = nn.relu(y1, name='z1') + filt2 = constant_op.constant( + np.random.randn(9), shape=(3, 3, 1, 1), dtype=dtype, name='filt2') + y2 = nn.conv2d(z1, filt2, strides=[1, 1, 1, 1], padding='SAME', name='y2') + z2 = nn.relu(y2, name='z') + filt3 = constant_op.constant( + np.random.randn(3, 3, 1, 1), + shape=(3, 3, 1, 1), + dtype=dtype, + name='filt3') + y3 = nn.conv2d(z2, filt3, strides=[1, 1, 1, 1], padding='SAME', name='y3') + nn.relu(y3, name='output') + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines=['my_trt_op_0'], + expected_output_dims=(5, 12, 12, 1), + allclose_atol=1.e-02, + allclose_rtol=1.e-02) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py new file mode 100644 index 0000000000000000000000000000000000000000..66eb6be757d3f4dcc390435486f7ed4f6517f875 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py @@ -0,0 +1,72 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test + + +class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Testing conversion of BatchMatMul in TF-TRT conversion.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [2, 15, 15, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtype, shape=[None] + input_dims[1:], name=input_name) + with g.device("/GPU:0"): + e1 = constant_op.constant( + np.random.randn(1, 1, 3, 5), name="kernel_1", dtype=dtype) + e2 = constant_op.constant( + np.random.randn(1, 1, 5, 10), name="kernel_2", dtype=dtype) + conv = nn.conv2d( + input=inp, + filter=e1, + strides=[1, 1, 1, 1], + padding="VALID", + name="conv") + out = nn.conv2d( + input=conv, + filter=e2, + strides=[1, 1, 1, 1], + padding="VALID", + name="conv_2") + array_ops.squeeze(out, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines=["my_trt_op_0"], + expected_output_dims=(2, 15, 15, 10), + allclose_atol=1.e-02, + allclose_rtol=1.e-02) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fd55b8cd99171fe34424e48a417eb8981b051c17 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py @@ -0,0 +1,87 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test + + +class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Test for multi connection neighboring nodes wiring tests in TF-TRT.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [2, 3, 7, 5] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + e = constant_op.constant( + np.random.normal(.05, .005, [3, 2, 3, 4]), + name="weights", + dtype=dtype) + conv = nn.conv2d( + input=x, + filter=e, + data_format="NCHW", + strides=[1, 1, 1, 1], + padding="VALID", + name="conv") + b = constant_op.constant( + np.random.normal(2.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype) + t = conv + b + + b = constant_op.constant( + np.random.normal(5.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype) + q = conv - b + edge = math_ops.sigmoid(q) + + b = constant_op.constant( + np.random.normal(5.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype) + d = b + conv + edge3 = math_ops.sigmoid(d) + + edge1 = gen_math_ops.tan(conv) + t = t - edge1 + q = q + edge + t = t + q + t = t + d + t = t - edge3 + array_ops.squeeze(t, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines=["my_trt_op_0", "my_trt_op_1"], + expected_output_dims=(2, 4, 5, 4), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py new file mode 100644 index 0000000000000000000000000000000000000000..51c905a50b29c017719d66f9049e9b1bc3a9ec97 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py @@ -0,0 +1,72 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test + + +class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Neighboring node wiring tests in TF-TRT conversion.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [2, 3, 7, 5] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + e = constant_op.constant( + np.random.normal(.3, 0.05, [3, 2, 3, 4]), name="weights", dtype=dtype) + conv = nn.conv2d( + input=x, + filter=e, + data_format="NCHW", + strides=[1, 1, 1, 1], + padding="VALID", + name="conv") + b = constant_op.constant( + np.random.normal(1.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype) + t = math_ops.mul(conv, b, name="mul") + e = self.trt_incompatible_op(conv, name="incompatible") + t = math_ops.sub(t, e, name="sub") + array_ops.squeeze(t, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + "my_trt_op_0": ["bias", "mul", "sub"], + "my_trt_op_1": ["weights", "conv"] + }, + expected_output_dims=(2, 4, 5, 4), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py index 175ccd800686255092e241aa59568df407d6eebc..090aa8bdb0487973e186631af3b4edac48096a5f 100644 --- a/tensorflow/contrib/tensorrt/test/test_tftrt.py +++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py @@ -20,6 +20,7 @@ from __future__ import print_function import argparse import numpy as np +import six as _six # normally we should do import tensorflow as tf and then # tf.placeholder, tf.constant, tf.nn.conv2d etc but @@ -35,10 +36,75 @@ from tensorflow.python.framework import dtypes as dtypes from tensorflow.python.framework import importer as importer from tensorflow.python.framework import ops as ops from tensorflow.python.ops import array_ops as aops +from tensorflow.python.ops import math_ops as mops from tensorflow.python.ops import nn as nn from tensorflow.python.ops import nn_ops as nn_ops +def py2bytes(inp): + return inp + + +def py3bytes(inp): + return inp.encode("utf-8", errors="surrogateescape") + + +def py2string(inp): + return inp + + +def py3string(inp): + return inp.decode("utf-8") + + +if _six.PY2: + to_bytes = py2bytes + to_string = py2string +else: + to_bytes = py3bytes + to_string = py3string + + +def get_multi_engine_graph_def(mode="FP32"): + """Create a simple graph and return its graph_def.""" + dtype = dtypes.float32 + if mode.upper() == "FP16": + dtype = dtypes.float16 + else: + pass + + g = ops.Graph() + with g.as_default(): + x = aops.placeholder(shape=[None, 3, 7, 5], name="input", dtype=dtype) + with g.name_scope("Global_scope"): + with g.name_scope("first_scope"): + e = cop.constant( + np.random.randn(3, 2, 3, 4), name="weights", dtype=dtype) + conv = nn.conv2d( + input=x, + filter=e, + data_format="NCHW", + strides=[1, 1, 1, 1], + padding="VALID", + name="conv") + b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias1", dtype=dtype) + t = conv * b + + b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias2", dtype=dtype) + q = conv / b + edge = mops.sin(q) + edge1 = mops.cos(conv) + with g.name_scope("test_scope"): + de = edge + edge1 + t -= edge1 + q *= edge + t += q + t -= de + k = aops.squeeze(t, name="output") + print(k.dtype) + return g.as_graph_def() + + def get_simple_graph_def(): """Create a simple graph and return its graph_def.""" g = ops.Graph() @@ -65,7 +131,9 @@ def get_simple_graph_def(): def execute_graph(gdef, dumm_inp): """Run given graphdef once.""" print("executing") - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + gpu_options = None + if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: + gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) sessconfig = cpb2.ConfigProto(gpu_options=gpu_options) ops.reset_default_graph() g = ops.Graph() @@ -83,7 +151,9 @@ def execute_graph(gdef, dumm_inp): # for calibration. For this test script it is random data. def execute_calibration(gdef, dumm_inp): """Run given calibration graph multiple times.""" - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + gpu_options = None + if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: + gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) ops.reset_default_graph() g = ops.Graph() with g.as_default(): @@ -100,12 +170,17 @@ def execute_calibration(gdef, dumm_inp): return val -def user(run_graph=execute_graph, run_calibration=execute_calibration): +def user(multi_engine, + run_graph=execute_graph, + run_calibration=execute_calibration): """Example function that converts a graph to TFTRT graph.""" - - inp_dims = (100, 24, 24, 2) + if multi_engine: + inp_dims = (2, 3, 7, 5) + orig_graph = get_multi_engine_graph_def() + else: + inp_dims = (100, 24, 24, 2) + orig_graph = get_simple_graph_def() # use a frozen graph for inference dummy_input = np.random.random_sample(inp_dims) - orig_graph = get_simple_graph_def() # use a frozen graph for inference # Get optimized graph trt_graph = trt.create_inference_graph( input_graph_def=orig_graph, @@ -113,8 +188,10 @@ def user(run_graph=execute_graph, run_calibration=execute_calibration): max_batch_size=inp_dims[0], max_workspace_size_bytes=1 << 25, precision_mode="FP32", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2 # minimum number of nodes in an engine - ) + minimum_segment_size=2, # minimum number of nodes in an engine + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=[]) o1 = run_graph(orig_graph, dummy_input) o2 = run_graph(trt_graph, dummy_input) o3 = run_graph(trt_graph, dummy_input) @@ -126,40 +203,51 @@ def user(run_graph=execute_graph, run_calibration=execute_calibration): max_batch_size=inp_dims[0], max_workspace_size_bytes=1 << 25, precision_mode="FP16", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2 # minimum number of nodes in an engine - ) + minimum_segment_size=2, # minimum number of nodes in an engine + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=[]) int8_calib_gdef = trt.create_inference_graph( input_graph_def=orig_graph, outputs=["output"], max_batch_size=inp_dims[0], max_workspace_size_bytes=1 << 25, precision_mode="INT8", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2 # minimum number of nodes in an engine - ) + minimum_segment_size=2, # minimum number of nodes in an engine + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=[]) o4 = run_graph(fp16_graph, dummy_input) _ = run_calibration(int8_calib_gdef, dummy_input) int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef) o5 = run_graph(int8_graph, dummy_input) - assert np.allclose(o1, o4) - assert np.allclose(o1, o5) + print("Is FP32 == FP16? %s (False is possible)" % np.allclose(o1, o4)) + print("Is FP32 == INT8? %s (False is possible)" % np.allclose(o1, o5)) print("Pass") -def auto(): +def auto(multi_engine): """Run the conversion as an optimization pass.""" - inp_dims = (100, 24, 24, 2) + if multi_engine: + inp_dims = (2, 3, 7, 5) + orig_graph = get_multi_engine_graph_def() + else: + inp_dims = (100, 24, 24, 2) + orig_graph = get_simple_graph_def() # use a frozen graph for inference dummy_input = np.random.random_sample(inp_dims) - orig_graph = get_simple_graph_def() opt_config = rwpb2.RewriterConfig() + opt_config.meta_optimizer_iterations = opt_config.ONE opt_config.optimizers.extend(["constfold", "layout"]) custom_op = opt_config.custom_optimizers.add() custom_op.name = "TensorRTOptimizer" custom_op.parameter_map["minimum_segment_size"].i = 3 - custom_op.parameter_map["precision_mode"].s = "FP32" + custom_op.parameter_map["precision_mode"].s = to_bytes("FP32") custom_op.parameter_map["max_batch_size"].i = inp_dims[0] custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25 print(custom_op) - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + gpu_options = None + if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: + gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) graph_options = cpb2.GraphOptions(rewrite_options=opt_config) sessconfig = cpb2.ConfigProto( gpu_options=gpu_options, graph_options=graph_options) @@ -168,7 +256,7 @@ def auto(): ops.reset_default_graph() with g.as_default(): inp, out = importer.import_graph_def( - graph_def=orig_graph, return_elements=["input", "output"]) + graph_def=orig_graph, return_elements=["input", "output"], name="") inp = inp.outputs[0] out = out.outputs[0] with csess.Session(config=sessconfig, graph=g) as sess: @@ -186,8 +274,14 @@ if "__main__" in __name__: action="store_true", help="Do TRT conversion automatically", default=False) + P.add_argument( + "--multi-engine", + "-m", + action="store_true", + help="Use a graph that will result in 2 engines", + default=False) flags, unparsed = P.parse_known_args() if flags.automatic: - auto() + auto(flags.multi_engine) else: - user() + user(flags.multi_engine) diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py deleted file mode 100644 index 0403b652d72877196c3537a3181529aeeb997395..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Script to test TF-TensorRT integration.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import warnings -import numpy as np - -from tensorflow.contrib import tensorrt as trt -from tensorflow.core.protobuf import config_pb2 as cpb2 -from tensorflow.python.framework import constant_op as cop -from tensorflow.python.framework import dtypes as dtypes -from tensorflow.python.framework import importer as importer -from tensorflow.python.framework import ops as ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops as aops -from tensorflow.python.ops import nn as nn -from tensorflow.python.ops import nn_ops as nn_ops -from tensorflow.python.platform import googletest - - -class IntegrationTest(test_util.TensorFlowTestCase): - """Class to test Tensorflow-TensorRT integration.""" - - def setUp(self): - """Setup method.""" - super(IntegrationTest, self).setUp() - warnings.simplefilter("always") - inp_dims = (100, 24, 24, 2) - self._input = np.random.random_sample(inp_dims) - self._original_graph = self.get_simple_graph_def() - self._gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) - self._config = cpb2.ConfigProto(gpu_options=self._gpu_options) - self._reference = self.run_graph(self._original_graph, self._input) - - def get_simple_graph_def(self): - """Create a simple graph and return its graph_def.""" - g = ops.Graph() - with g.as_default(): - a = aops.placeholder( - dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") - e = cop.constant( - [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], - name="weights", - dtype=dtypes.float32) - conv = nn.conv2d( - input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv") - b = cop.constant( - [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtypes.float32) - t = nn.bias_add(conv, b, name="biasAdd") - relu = nn.relu(t, "relu") - idty = aops.identity(relu, "ID") - v = nn_ops.max_pool( - idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") - aops.squeeze(v, name="output") - return g.as_graph_def() - - def run_graph(self, gdef, dumm_inp): - """Run given graphdef once.""" - ops.reset_default_graph() - g = ops.Graph() - with g.as_default(): - inp, out = importer.import_graph_def( - graph_def=gdef, return_elements=["input", "output"]) - inp = inp.outputs[0] - out = out.outputs[0] - with self.test_session( - graph=g, config=self._config, use_gpu=True, force_gpu=True) as sess: - val = sess.run(out, {inp: dumm_inp}) - return val - - # Use real data that is representative of the inference dataset - # for calibration. For this test script it is random data. - def run_calibration(self, gdef, dumm_inp): - """Run given calibration graph multiple times.""" - ops.reset_default_graph() - g = ops.Graph() - with g.as_default(): - inp, out = importer.import_graph_def( - graph_def=gdef, return_elements=["input", "output"]) - inp = inp.outputs[0] - out = out.outputs[0] - # run over real calibration data here, we are mimicking a calibration - # set of 30 different batches. Use as much calibration data as you want - with self.test_session( - graph=g, config=self._config, use_gpu=True, force_gpu=True) as sess: - for _ in range(30): - val = sess.run(out, {inp: dumm_inp}) - return val - - def get_trt_graph(self, mode): - """Return trt converted graph.""" - if mode in ["FP32", "FP16", "INT8"]: - return trt.create_inference_graph( - input_graph_def=self._original_graph, - outputs=["output"], - max_batch_size=self._input.shape[0], - max_workspace_size_bytes=1 << 25, - precision_mode=mode, # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2 # minimum number of nodes in an engine - ) - return None - - def testFP32(self): - """Test FP32 conversion. Results should be identical to native case.""" - trt_graph = self.get_trt_graph("FP32") - result = self.run_graph(trt_graph, self._input) - self.assertAllEqual(self._reference, result) - result1 = self.run_graph(trt_graph, self._input) - self.assertAllEqual(result1, result) - - def testFP16(self): - """Test FP16 conversion. Results may be different from native case.""" - trt_graph = self.get_trt_graph("FP16") - result = self.run_graph(trt_graph, self._input) - self.assertAllClose(self._reference, result, rtol=1.e-03) - result1 = self.run_graph(trt_graph, self._input) - self.assertAllEqual(result1, result) - - def testINT8(self): - """Test INT8 conversion. Results may be different from native case.""" - calib_graph = self.get_trt_graph("INT8") - result = self.run_calibration(calib_graph, self._input) - self.assertAllEqual(self._reference, result) - int8_graph = trt.calib_graph_to_infer_graph(calib_graph) - result = self.run_graph(int8_graph, self._input) - self.assertAllClose(self._reference, result, rtol=1.e-03) - result1 = self.run_graph(int8_graph, self._input) - self.assertAllEqual(result1, result) - - -if __name__ == "__main__": - googletest.main() diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..6f85ada4649563d099c6054e8e17da27954071f7 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py @@ -0,0 +1,461 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple +import itertools +import os +import warnings +import numpy as np +import six + +from tensorflow.contrib.tensorrt.python import trt_convert +# pylint: disable=unused-import +from tensorflow.contrib.tensorrt.python.ops import trt_engine_op +# pylint: enable=unused-import +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.framework import graph_io +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import tf_logging as logging + +TfTrtIntegrationTestParams = namedtuple("TfTrtIntegrationTestParams", [ + "gdef", "input_names", "input_dims", "expected_engines", + "expected_output_dims", "allclose_atol", "allclose_rtol" +]) + +RunParams = namedtuple( + "RunParams", + ["use_optimizer", "precision_mode", "dynamic_engine", "test_name"]) + +PRECISION_MODES = ["FP32", "FP16", "INT8"] + + +def _IsQuantizationMode(mode): + return mode == "INT8" + + +class GraphState(object): + ORIGINAL = 0 + CALIBRATE = 1 + INFERENCE = 2 + + +class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): + """Class to test Tensorflow-TensorRT integration.""" + + @property + def output_name(self): + return "output" + + @property + def trt_incompatible_op(self): + return math_ops.sin + + @property + def precision_modes(self): + return ["FP32", "FP16", "INT8"] + + # str is bytes in py2, but unicode in py3. + def _ToUnicode(self, s): + if six.PY2: + if isinstance(s, unicode): + return s + return s.decode("utf-8") + else: + if isinstance(s, str): + return s + return s.decode("utf-8") + + def _ToBytes(self, s): + if six.PY2: + if isinstance(s, unicode): + return s.encode("utf-8") + return s + else: + if isinstance(s, str): + return s.encode("utf-8") + return s + + def _ToString(self, s): + if six.PY2: + if isinstance(s, unicode): + return s.encode("utf-8") + return s + else: + if isinstance(s, str): + return s + return s.decode("utf-8") + + @classmethod + def setUpClass(cls): + """Setup method for the module.""" + super(TfTrtIntegrationTestBase, cls).setUpClass() + trt_convert.enable_test_value() + + def setUp(self): + """Setup method.""" + super(TfTrtIntegrationTestBase, self).setUp() + warnings.simplefilter("always") + trt_convert.clear_test_values("") + + def GetParams(self): + """Return a TfTrtIntegrationTestParams for test, implemented by subclass.""" + raise NotImplementedError() + + def _PrepareRun(self, params, graph_state): + """Set up necessary testing environment before calling sess.run().""" + # Clear test values added by TRTEngineOp. + trt_convert.clear_test_values("my_trt_op_.*:ExecuteTrtEngine") + trt_convert.clear_test_values("my_trt_op_.*:ExecuteCalibration") + trt_convert.clear_test_values("my_trt_op_.*:ExecuteNativeSegment") + + def _VerifyRun(self, params, graph_state): + """Verify the state after sess.run().""" + for engine_name in params.expected_engines: + if graph_state == GraphState.ORIGINAL: + self._ExpectCalibration(engine_name, "") + self._ExpectNativeSegment(engine_name, "") + self._ExpectTrtEngine(engine_name, "") + elif graph_state == GraphState.CALIBRATE: + self._ExpectCalibration(engine_name, "done") + self._ExpectNativeSegment(engine_name, "done") + self._ExpectTrtEngine(engine_name, "") + elif graph_state == GraphState.INFERENCE: + self._ExpectCalibration(engine_name, "") + self._ExpectNativeSegment(engine_name, "") + self._ExpectTrtEngine(engine_name, "done") + + def _GetConfigProto(self, params, run_params, graph_state): + """Get config proto based on specific settings.""" + if graph_state != GraphState.ORIGINAL and run_params.use_optimizer: + rewriter_cfg = rewriter_config_pb2.RewriterConfig() + rewriter_cfg.optimizers.extend(["constfold", "layout"]) + custom_op = rewriter_cfg.custom_optimizers.add() + custom_op.name = "TensorRTOptimizer" + custom_op.parameter_map["minimum_segment_size"].i = 2 + custom_op.parameter_map["max_batch_size"].i = max( + [dims[0] for dims in params.input_dims]) + custom_op.parameter_map["is_dynamic_op"].b = run_params.dynamic_engine + custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25 + custom_op.parameter_map["precision_mode"].s = self._ToBytes( + run_params.precision_mode) + graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg) + else: + graph_options = config_pb2.GraphOptions() + + gpu_options = config_pb2.GPUOptions() + gpu_options.allow_growth = True + if trt_convert.get_linked_tensorrt_version()[0] == 3: + gpu_options.per_process_gpu_memory_fraction = 0.50 + + config = config_pb2.ConfigProto( + gpu_options=gpu_options, graph_options=graph_options) + return config + + def _ExpectTestValue(self, engine_name, method, expected_value): + label = "%s:%s" % (engine_name, method) + actual_value = trt_convert.get_test_value(label) + self.assertEqual( + expected_value, + actual_value, + msg="Unexpected test value with label %s. Actual: %s; expected: %s" % + (label, actual_value, expected_value)) + + def _ExpectCalibration(self, engine_name, value): + self._ExpectTestValue(engine_name, "ExecuteCalibration", value) + + def _ExpectTrtEngine(self, engine_name, value): + self._ExpectTestValue(engine_name, "ExecuteTrtEngine", value) + + def _ExpectNativeSegment(self, engine_name, value): + self._ExpectTestValue(engine_name, "ExecuteNativeSegment", value) + + def _RunGraph(self, params, gdef, input_data, config, graph_state, + num_runs=2): + """Run given graphdef multiple times.""" + assert len(params.input_names) == len(input_data) + g = ops.Graph() + with g.as_default(): + io_ops = importer.import_graph_def( + graph_def=gdef, + return_elements=params.input_names + [self.output_name], + name="") + inp = [i.outputs[0] for i in io_ops[:-1]] + assert len(inp) == len(input_data) + out = io_ops[-1].outputs[0] + with self.test_session( + graph=g, config=config, use_gpu=True, force_gpu=True) as sess: + val = None + # Defaults to 2 runs to verify result across multiple runs is same. + for _ in range(num_runs): + self._PrepareRun(params, graph_state) + new_val = sess.run(out, + {inp[i]: input_data[i] for i in range(len(inp))}) + self.assertEqual(params.expected_output_dims, new_val.shape) + if val is not None: + self.assertAllEqual(val, new_val) + val = new_val + self._VerifyRun(params, graph_state) + return val + + # Use real data that is representative of the inference dataset + # for calibration. For this test script it is random data. + def _RunCalibration(self, params, gdef, input_data, config): + """Run calibration on given graph.""" + return self._RunGraph( + params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5) + + def _GetTrtGraphDef(self, params, run_params, gdef): + """Return trt converted graphdef.""" + return trt_convert.create_inference_graph( + input_graph_def=gdef, + outputs=[self.output_name], + max_batch_size=max([dims[0] for dims in params.input_dims]), + max_workspace_size_bytes=1 << 25, + precision_mode=run_params.precision_mode, + minimum_segment_size=2, + is_dynamic_op=run_params.dynamic_engine) + + def _WriteGraph(self, params, run_params, gdef, graph_state): + if graph_state == GraphState.ORIGINAL: + label = "Original" + elif graph_state == GraphState.CALIBRATE: + label = "CalibEngine" + elif graph_state == GraphState.INFERENCE: + label = "InferEngine" + graph_name = ( + self.__class__.__name__ + "_" + run_params.test_name + "_" + label + + ".pbtxt") + temp_dir = os.getenv("TRT_TEST_TMPDIR", self.get_temp_dir()) + logging.info("Writing graph to %s/%s", temp_dir, graph_name) + graph_io.write_graph(gdef, temp_dir, graph_name) + + def _VerifyConnections(self, params, converted_gdef): + old_to_new_node_map = { + self._ToString(node.name): self._ToString(node.name) + for node in params.gdef.node + } + for engine_name, node_names in params.expected_engines.items(): + for node_name in node_names: + old_to_new_node_map[node_name] = engine_name + name_to_node_map = { + self._ToString(node.name): node for node in params.gdef.node + } + + def _InputName(inp): + inp = self._ToString(inp) + prefix = "" + if inp[0] == "^": + prefix = "^" + inp = inp[1:] + parts = inp.split(":") + if len(parts) > 1 and parts[-1].isdigit(): + inp = inp[:-len(parts[-1]) - 1] + return (prefix, inp) + + expected_input_map = {} + for node in params.gdef.node: + name_str = self._ToString(node.name) + target_node_name = old_to_new_node_map[name_str] + is_engine_op = (target_node_name != name_str) + if target_node_name not in expected_input_map: + expected_input_map[target_node_name] = set() + input_set = expected_input_map[target_node_name] + for inp in node.input: + (prefix, inp_name) = _InputName(inp) + # Add the input only if it's outside the segment (note that it could be + # in a different engine). + if (not is_engine_op or + old_to_new_node_map[inp_name] != target_node_name): + if is_engine_op and name_to_node_map[inp_name].op == "Const": + # Const data input nodes to the segment has been copied to the + # segment graphdef and the engine, and the dependency has been + # converted to control dependendy. + input_set.add("^" + old_to_new_node_map[inp_name]) + else: + input_set.add(prefix + old_to_new_node_map[inp_name]) + + actual_input_map = {} + for node in converted_gdef.node: + name_str = self._ToString(node.name) + actual_input_map[name_str] = set() + input_set = actual_input_map[name_str] + for inp in node.input: + (prefix, node_name) = _InputName(inp) + input_set.add(prefix + node_name) + + self.assertEqual( + expected_input_map, + actual_input_map, + msg="expected:\n%s\nvs actual:\n%s" % (sorted( + expected_input_map.items()), sorted(actual_input_map.items()))) + + def _VerifyGraphDef(self, params, run_params, gdef, graph_state): + self._WriteGraph(params, run_params, gdef, graph_state) + + num_engines = 0 + for node in gdef.node: + if node.op == "TRTEngineOp": + num_engines += 1 + self.assertTrue(node.name in params.expected_engines) + self.assertTrue(len(node.attr["serialized_segment"].s)) + self.assertTrue(len(node.attr["segment_funcdef_name"].s)) + self.assertEqual( + self._ToBytes(run_params.precision_mode), + node.attr["precision_mode"].s) + + is_dynamic_engine = not node.attr["static_engine"].b + self.assertEqual(run_params.dynamic_engine, is_dynamic_engine) + + has_calibration_data = len(node.attr["calibration_data"].s) + if (_IsQuantizationMode(run_params.precision_mode) and + graph_state == GraphState.INFERENCE): + self.assertTrue(has_calibration_data) + else: + self.assertFalse(has_calibration_data) + if graph_state == GraphState.ORIGINAL: + self.assertEqual(0, num_engines) + else: + self.assertEqual(num_engines, len(params.expected_engines)) + if isinstance(params.expected_engines, dict): + self._VerifyConnections(params, gdef) + # TODO(aaroey): consider verifying the corresponding TF function. + + def RunTest(self, params, run_params): + assert run_params.precision_mode in PRECISION_MODES + input_data = [np.random.random_sample(dims) for dims in params.input_dims] + input_gdef = params.gdef + self._VerifyGraphDef(params, run_params, input_gdef, GraphState.ORIGINAL) + + # Get reference result without running trt. + config_no_trt = self._GetConfigProto(params, run_params, + GraphState.ORIGINAL) + logging.info("Running original graph w/o trt, config:\n%s", + str(config_no_trt)) + ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt, + GraphState.ORIGINAL) + + # Run calibration if necessary. + if _IsQuantizationMode(run_params.precision_mode): + + calib_config = self._GetConfigProto(params, run_params, + GraphState.CALIBRATE) + logging.info("Running calibration graph, config:\n%s", str(calib_config)) + if run_params.use_optimizer: + result = self._RunCalibration(params, input_gdef, input_data, + calib_config) + else: + calib_gdef = self._GetTrtGraphDef(params, run_params, input_gdef) + self._VerifyGraphDef(params, run_params, calib_gdef, + GraphState.CALIBRATE) + result = self._RunCalibration(params, calib_gdef, input_data, + calib_config) + infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef) + self._VerifyGraphDef(params, run_params, infer_gdef, GraphState.INFERENCE) + + self.assertAllClose( + ref_result, + result, + atol=params.allclose_atol, + rtol=params.allclose_rtol) + else: + infer_gdef = input_gdef + + # Run inference. + infer_config = self._GetConfigProto(params, run_params, + GraphState.INFERENCE) + logging.info("Running final inference graph, config:\n%s", + str(infer_config)) + if run_params.use_optimizer: + result = self._RunGraph(params, infer_gdef, input_data, infer_config, + GraphState.INFERENCE) + else: + trt_infer_gdef = self._GetTrtGraphDef(params, run_params, infer_gdef) + self._VerifyGraphDef(params, run_params, trt_infer_gdef, + GraphState.INFERENCE) + result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config, + GraphState.INFERENCE) + + self.assertAllClose( + ref_result, + result, + atol=params.allclose_atol, + rtol=params.allclose_rtol) + + def testIdempotence(self): + # Test that applying tensorrt optimizer or offline conversion tools multiple + # times to the same graph will result in same graph. + # + # TODO(aaroey): currently the conversion is not deterministic, this is + # mainly because during tensorflow::ConvertGraphDefToGraph(), the graph uses + # EdgeSet which use a map keyed by Edge*, so the order of input/output edges + # of a node is nondeterministic, thus the order for segmenter to contract + # edges is nondeterministic. Need to evaluate whether we should fix this. + pass + + +def _AddTests(test_class): + """Adds test methods to TfTrtIntegrationTestBase.""" + + def _GetTest(run_params): + """Gets a single test method based on the parameters.""" + + def _Test(self): + params = self.GetParams() + logging.info( + "Running test %s with parameters: use_optimizer=%s, " + "precision_mode=%s, dynamic_engine=%s", + "testTfTrt_" + run_params.test_name, run_params.use_optimizer, + run_params.precision_mode, run_params.dynamic_engine) + self.RunTest(params, run_params) + + return _Test + + use_optimizer_options = [False, True] + dynamic_engine_options = [False, True] + for (use_optimizer, precision_mode, dynamic_engine) in itertools.product( + use_optimizer_options, PRECISION_MODES, dynamic_engine_options): + if _IsQuantizationMode(precision_mode): + if use_optimizer: + # TODO(aaroey): if use_optimizer is True we need to get the inference + # graphdef using custom python wrapper class, which is not currently + # supported yet. + continue + if not dynamic_engine: + # TODO(aaroey): construction of static calibration engine is not + # supported yet. + continue + + conversion = "OptimizerConversion" if use_optimizer else "ToolConversion" + engine_type = ("DynamicEngine" if dynamic_engine else "StaticEngine") + test_name = "%s_%s_%s" % (conversion, precision_mode, engine_type) + run_params = RunParams( + use_optimizer=use_optimizer, + precision_mode=precision_mode, + dynamic_engine=dynamic_engine, + test_name=test_name) + setattr(test_class, "testTfTrt_" + test_name, _GetTest(run_params)) + + +if trt_convert.is_tensorrt_enabled(): + _AddTests(TfTrtIntegrationTestBase) diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/contrib/tensorrt/test/unary_test.py new file mode 100644 index 0000000000000000000000000000000000000000..500057a36d60efa3b7f96f22e27973444ecc277c --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/unary_test.py @@ -0,0 +1,113 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class UnaryTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Test for unary operations in TF-TRT.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [12, 5, 8, 1, 1, 12] + input2_name = "input_2" + input2_dims = [12, 5, 8, 1, 12, 1, 1] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + q = math_ops.abs(x) + q = q + 1.0 + q = gen_math_ops.exp(q) + q = gen_math_ops.log(q) + q = array_ops.squeeze(q, axis=-2) + q = math_ops.abs(q) + q = q + 2.2 + q = gen_math_ops.sqrt(q) + q = gen_math_ops.rsqrt(q) + q = math_ops.negative(q) + q = array_ops.squeeze(q, axis=3) + q = math_ops.abs(q) + q = q + 3.0 + a = gen_math_ops.reciprocal(q) + + x = constant_op.constant(np.random.randn(5, 8, 12), dtype=dtype) + q = math_ops.abs(x) + q = q + 2.0 + q = gen_math_ops.exp(q) + q = gen_math_ops.log(q) + q = math_ops.abs(q) + q = q + 2.1 + q = gen_math_ops.sqrt(q) + q = gen_math_ops.rsqrt(q) + q = math_ops.negative(q) + q = math_ops.abs(q) + q = q + 4.0 + b = gen_math_ops.reciprocal(q) + + # TODO(jie): this one will break, broadcasting on batch. + x = array_ops.placeholder( + dtype=dtype, shape=input2_dims, name=input2_name) + q = math_ops.abs(x) + q = q + 5.0 + q = gen_math_ops.exp(q) + q = array_ops.squeeze(q, axis=[-1, -2, 3]) + q = gen_math_ops.log(q) + q = math_ops.abs(q) + q = q + 5.1 + q = gen_array_ops.reshape(q, [12, 5, 1, 1, 8, 1, 12]) + q = array_ops.squeeze(q, axis=[5, 2, 3]) + q = gen_math_ops.sqrt(q) + q = math_ops.abs(q) + q = q + 5.2 + q = gen_math_ops.rsqrt(q) + q = math_ops.negative(q) + q = math_ops.abs(q) + q = q + 5.3 + c = gen_math_ops.reciprocal(q) + + q = a * b + q = q / c + array_ops.squeeze(q, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name, input2_name], + input_dims=[input_dims, input2_dims], + expected_engines=[ + "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3", + "my_trt_op_4" + ], + expected_output_dims=(12, 5, 8, 12), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/utils.cc b/tensorflow/contrib/tensorrt/test/utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..276308b3a0a6ce864969afb0179c6a3f00d6b70b --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/utils.cc @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/test/utils.h" + +#include +#include + +#include "re2/re2.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace tensorrt { +namespace test { + +// TODO(aaroey): make this class thread-safe. +class TestValueManager { + public: + static TestValueManager* singleton() { + static TestValueManager* manager = new TestValueManager(); + return manager; + } + + void Enable() { + VLOG(1) << "Enabling test value"; + enabled_ = true; + } + + void Add(const string& label, const string& value) { + if (TF_PREDICT_FALSE(enabled_)) { + QCHECK_NE("", value); + VLOG(1) << "Adding test value: " << label << " -> " << value; + values_.insert({label, value}); + } + } + + string Get(const string& label) { + if (TF_PREDICT_FALSE(enabled_)) { + VLOG(1) << "Getting test value by " << label; + auto itr = values_.find(label); + if (itr == values_.end()) return ""; + return itr->second; + } + return ""; + } + + void Clear(const string& pattern) { + if (TF_PREDICT_FALSE(enabled_)) { + VLOG(1) << "Clearing test values"; + if (pattern.empty()) { + values_.clear(); + return; + } + std::vector keys_to_clear; + for (const auto& kv : values_) { + if (RE2::FullMatch(kv.first, pattern)) { + keys_to_clear.push_back(kv.first); + } + } + for (const string& key : keys_to_clear) { + values_.erase(key); + } + } + } + + private: + TestValueManager() : enabled_(false) {} + + bool enabled_; + std::unordered_map values_; +}; + +void EnableTestValue() { TestValueManager::singleton()->Enable(); } + +void ClearTestValues(const string& pattern) { + TestValueManager::singleton()->Clear(pattern); +} + +void AddTestValue(const string& label, const string& value) { + TestValueManager::singleton()->Add(label, value); +} + +string GetTestValue(const string& label) { + return TestValueManager::singleton()->Get(label); +} + +} // namespace test +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/test/utils.h b/tensorflow/contrib/tensorrt/test/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..4bb4120206cfaae70107e55d1818e3af2f02717a --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/utils.h @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace tensorrt { +namespace test { + +// Helper methods to inject values used by testing tools. +void EnableTestValue(); +void ClearTestValues(const string& pattern); +void AddTestValue(const string& label, const string& value); +string GetTestValue(const string& label); + +#define TRT_RETURN_IF_TEST_VALUE(label, value_to_return) \ + do { \ + if (::tensorflow::tensorrt::test::GetTestValue(label) == \ + value_to_return) { \ + return errors::Internal("Injected manually"); \ + } \ + } while (0) + +} // namespace test +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ab4d224db4d88c91c9b06d278b404879d989a834 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py @@ -0,0 +1,82 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_impl +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + + +class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Single vgg layer in NCHW unit tests in TF-TRT.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [5, 2, 8, 8] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + x, _, _ = nn_impl.fused_batch_norm( + x, + np.random.randn(2).astype(np.float32), + np.random.randn(2).astype(np.float32), + mean=np.random.randn(2).astype(np.float32), + variance=np.random.randn(2).astype(np.float32), + data_format="NCHW", + is_training=False) + e = constant_op.constant( + np.random.randn(1, 1, 2, 6), name="weights", dtype=dtype) + conv = nn.conv2d( + input=x, + filter=e, + data_format="NCHW", + strides=[1, 1, 2, 2], + padding="SAME", + name="conv") + b = constant_op.constant(np.random.randn(6), name="bias", dtype=dtype) + t = nn.bias_add(conv, b, data_format="NCHW", name="biasAdd") + relu = nn.relu(t, "relu") + idty = array_ops.identity(relu, "ID") + v = nn_ops.max_pool( + idty, [1, 1, 2, 2], [1, 1, 2, 2], + "VALID", + data_format="NCHW", + name="max_pool") + array_ops.squeeze(v, name="output") + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines=["my_trt_op_0"], + expected_output_dims=(5, 6, 2, 2), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_test.py new file mode 100644 index 0000000000000000000000000000000000000000..56bdf848eadbdde3d5896e415ecd9754ed387eeb --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/vgg_block_test.py @@ -0,0 +1,73 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_impl +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + + +class VGGBlockTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Single vgg layer test in TF-TRT conversion.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [5, 8, 8, 2] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + x, _, _ = nn_impl.fused_batch_norm( + x, + np.random.randn(2).astype(np.float32), + np.random.randn(2).astype(np.float32), + mean=np.random.randn(2).astype(np.float32), + variance=np.random.randn(2).astype(np.float32), + is_training=False) + e = constant_op.constant( + np.random.randn(1, 1, 2, 6), name="weights", dtype=dtype) + conv = nn.conv2d( + input=x, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv") + b = constant_op.constant(np.random.randn(6), name="bias", dtype=dtype) + t = nn.bias_add(conv, b, name="biasAdd") + relu = nn.relu(t, "relu") + idty = array_ops.identity(relu, "ID") + v = nn_ops.max_pool( + idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") + array_ops.squeeze(v, name="output") + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines=["my_trt_op_0"], + expected_output_dims=(5, 2, 2, 6), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i index 46480e99a113afb34702b0ecd71468d4bdc83f98..6ea15fb8eff13663625420288a37ba002d57fa47 100644 --- a/tensorflow/contrib/tensorrt/trt_conversion.i +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -48,34 +48,78 @@ PyObject* pair_helper(std::pair* in) { } return tuple; } + +struct version_struct{ + int vmajor; + int vminor; + int vpatch; +}; + +PyObject* version_helper(version_struct* in) { + PyObject *tuple(nullptr); + tuple = Py_BuildValue("(iii)", in->vmajor, in->vminor, in->vpatch); + if (!tuple) { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, + "Tuple creation from version structure failed!"); + } + return NULL; + } + return tuple; +} +/* Define converters for vector */ +template<> +bool _PyObjAs(PyObject *pyobj, int* dest) { + *dest = PyLong_AsLong(pyobj); + return true; +} + +template<> +PyObject *_PyObjFrom(const int& src) { + return PyLong_FromLong(src); +} + %} + +_LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); + %typemap(out) std::pair { PyObject *tuple = pair_helper(&$1); if (!tuple) SWIG_fail; $result = tuple; } + +%typemap(out) version_struct { + PyObject *tuple = version_helper(&$1); + if (!tuple) SWIG_fail; + $result = tuple; +} + %{ #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/util/stat_summarizer.h" #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/contrib/tensorrt/test/utils.h" %} %ignoreall %unignore tensorflow; -%unignore trt_convert; %unignore calib_convert; +%unignore get_linked_tensorrt_version; +%unignore get_loaded_tensorrt_version; +%unignore is_tensorrt_enabled; +%unignore enable_test_value; +%unignore clear_test_values; +%unignore add_test_value; +%unignore get_test_value; %{ -std::pair trt_convert( - string graph_def_string, // The serialized GraphDef string. - std::vector output_names, - size_t max_batch_size, - size_t max_workspace_size_bytes, - int precision_mode, - int minimum_segment_size - // Unfortunately we can't use TF_Status here since it +std::pair calib_convert( + string graph_def_string, bool is_dyn_op + // unfortunately we can't use TF_Status here since it // is in c/c_api and brings in a lot of other libraries // which in turn declare ops. These ops are included // statically in our library and cause an abort when @@ -93,20 +137,11 @@ std::pair trt_convert( out_status = "InvalidArgument;Couldn't interpret input as a GraphDef"; return std::pair{out_status, ""}; } - - if(precision_mode < 0 || precision_mode > 2){ - out_status = "InvalidArgument;Invalid precision_mode"; - return std::pair{out_status, ""}; - } - if (!output_names.size()) { - out_status = "InvalidArgument;Size of the output_names vector is 0"; - return std::pair{out_status, ""}; - } - tensorflow::GraphDef outGraph; + graph_def_string.resize(0); + tensorflow::GraphDef out_graph; tensorflow::Status conversion_status = - tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT( - graph_def, output_names, max_batch_size, max_workspace_size_bytes, - &outGraph, precision_mode, minimum_segment_size); + tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph( + graph_def, &out_graph, is_dyn_op); if (!conversion_status.ok()) { auto retCode = (int)conversion_status.code(); char buff[2000]; @@ -116,7 +151,7 @@ std::pair trt_convert( return std::pair{out_status, ""}; } string result; - if (!outGraph.SerializeToString(&result)) { + if (!out_graph.SerializeToString(&result)) { out_status = "InvalidArgument;Couldn't serialize output as a GraphDef"; return std::pair{out_status, ""}; } @@ -128,59 +163,72 @@ std::pair trt_convert( #endif // GOOGLE_CUDA && GOOGLE_TENSORRT } -std::pair calib_convert(string graph_def_string // const tensorflow::GraphDef& - // unfortunately we can't use TF_Status here since it - // is in c/c_api and brings in a lot of other libraries - // which in turn declare ops. These ops are included - // statically in our library and cause an abort when - // module is loaded due to double registration - // until Tensorflow properly exposes these headers - // we have to work around this by returning a string - // and converting it to exception on python side. - //,TF_Status* out_status) { -) { +version_struct get_linked_tensorrt_version() { + // Return the version at the link time. + version_struct s; #if GOOGLE_CUDA && GOOGLE_TENSORRT - string out_status; + const auto &lv = tensorflow::tensorrt::convert::GetLinkedTensorRTVersion(); + s.vmajor = lv[0]; + s.vminor = lv[1]; + s.vpatch = lv[2]; +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + return s; +} - tensorflow::GraphDef graph_def; - if (!graph_def.ParseFromString(graph_def_string)) { - out_status = "InvalidArgument;Couldn't interpret input as a GraphDef"; - return std::pair{out_status, ""}; - } +version_struct get_loaded_tensorrt_version() { + // Return the version from the loaded library. + version_struct s; +#if GOOGLE_CUDA && GOOGLE_TENSORRT + const auto &lv = tensorflow::tensorrt::convert::GetLoadedTensorRTVersion(); + s.vmajor = lv[0]; + s.vminor = lv[1]; + s.vpatch = lv[2]; +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + return s; +} - tensorflow::GraphDef outGraph; - tensorflow::Status conversion_status = - tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph(graph_def, - &outGraph); - if (!conversion_status.ok()) { - auto retCode = (int)conversion_status.code(); - char buff[2000]; - snprintf(buff, 2000, "%d;%s", retCode, - conversion_status.error_message().c_str()); - out_status = buff; - return std::pair{out_status, ""}; - } - string result; - if (!outGraph.SerializeToString(&result)) { - out_status = "InvalidArgument;Couldn't serialize output as a GraphDef"; - return std::pair{out_status, ""}; - } - out_status = "OK;All good!"; - return std::pair{out_status, result}; +bool is_tensorrt_enabled() { + return tensorflow::tensorrt::IsGoogleTensorRTEnabled(); +} + +void enable_test_value() { + tensorflow::tensorrt::test::EnableTestValue(); +} + +#if PY_MAJOR_VERSION < 3 +#define TRT_PY_TO_CPP_STRING PyString_AsString +#define TRT_CPP_TO_PY_STRING PyString_FromString #else - // Returns FAILED_PRECONDITION. - return std::pair{"9;TensorRT is not enabled!", ""}; -#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +#define TRT_PY_TO_CPP_STRING PyUnicode_AsUTF8 +#define TRT_CPP_TO_PY_STRING PyUnicode_FromString +#endif + +void clear_test_values(PyObject* pattern) { + tensorflow::tensorrt::test::ClearTestValues( + string(TRT_PY_TO_CPP_STRING(pattern))); } -%} -std::pair calib_convert(string graph_def_string); +void add_test_value(PyObject* label, PyObject* value) { + tensorflow::tensorrt::test::AddTestValue( + string(TRT_PY_TO_CPP_STRING(label)), string(TRT_PY_TO_CPP_STRING(value))); +} + +PyObject* get_test_value(PyObject* label) { + string value = tensorflow::tensorrt::test::GetTestValue( + string(TRT_PY_TO_CPP_STRING(label))); + return TRT_CPP_TO_PY_STRING(value.c_str()); +} -std::pair trt_convert(string graph_def_string, - std::vector output_names, - size_t max_batch_size, - size_t max_workspace_size_bytes, - int precision_mode, int minimum_segment_size); +%} +std::pair calib_convert( + string graph_def_string, bool is_dyn_op); +version_struct get_linked_tensorrt_version(); +version_struct get_loaded_tensorrt_version(); +bool is_tensorrt_enabled(); +void enable_test_value(); +void clear_test_values(PyObject* pattern); +void add_test_value(PyObject* label, PyObject* value); +PyObject* get_test_value(PyObject* label); %unignoreall diff --git a/tensorflow/contrib/timeseries/__init__.py b/tensorflow/contrib/timeseries/__init__.py index 11db56b1b7a48b401efeece91283eb7084747c14..654a4db098757a969c2d298f7ed490083e63b9da 100644 --- a/tensorflow/contrib/timeseries/__init__.py +++ b/tensorflow/contrib/timeseries/__init__.py @@ -27,6 +27,9 @@ @@TrainEvalFeatures @@FilteringResults + +@@TimeSeriesRegressor +@@OneShotPredictionHead """ from __future__ import absolute_import diff --git a/tensorflow/contrib/timeseries/examples/multivariate.py b/tensorflow/contrib/timeseries/examples/multivariate.py index ed799542fd50cd150f13533c5f33bd67ed09fff6..e81cb18ad7b928a6fd2a748ea6b258c49cf722ae 100644 --- a/tensorflow/contrib/timeseries/examples/multivariate.py +++ b/tensorflow/contrib/timeseries/examples/multivariate.py @@ -80,8 +80,8 @@ def multivariate_train_and_sample( session=session, steps=1)) next_sample = numpy.random.multivariate_normal( # Squeeze out the batch and series length dimensions (both 1). - mean=numpy.squeeze(current_prediction["mean"], axis=[0, 1]), - cov=numpy.squeeze(current_prediction["covariance"], axis=[0, 1])) + mean=numpy.squeeze(current_prediction["mean"], axis=(0, 1)), + cov=numpy.squeeze(current_prediction["covariance"], axis=(0, 1))) # Update model state so that future predictions are conditional on the # value we just sampled. filtering_features = { diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index e4963596d38dbe8aea98fddbc67dbbf761c215c8..c230919168b937b26c68e141e15f0762ad70f3e6 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -94,7 +94,6 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/estimator:export", "//tensorflow/python/feature_column", ], ) @@ -149,17 +148,16 @@ py_library( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/estimator:export", - "//tensorflow/python/estimator:head", - "//tensorflow/python/estimator:metric_keys", ], ) py_test( name = "head_test", + size = "large", srcs = [ "head_test.py", ], + shard_count = 4, srcs_version = "PY2AND3", tags = ["no_pip_gpu"], # b/63391119 deps = [ @@ -184,6 +182,7 @@ py_test( "//tensorflow/python/saved_model:loader", "//tensorflow/python/saved_model:tag_constants", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/timeseries/python/timeseries/__init__.py b/tensorflow/contrib/timeseries/python/timeseries/__init__.py index c683dad71de8f8502f08a4e823faa79d60d5604d..8462138339cda8557d9c9ee6e79d4c7a67ad1aa7 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/__init__.py +++ b/tensorflow/contrib/timeseries/python/timeseries/__init__.py @@ -24,5 +24,6 @@ from tensorflow.contrib.timeseries.python.timeseries import saved_model_utils from tensorflow.contrib.timeseries.python.timeseries.ar_model import * from tensorflow.contrib.timeseries.python.timeseries.estimators import * from tensorflow.contrib.timeseries.python.timeseries.feature_keys import * +from tensorflow.contrib.timeseries.python.timeseries.head import * from tensorflow.contrib.timeseries.python.timeseries.input_pipeline import * # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py index 63f5d3568bc208e1ce0ae69abb3a93132163c860..5eb4deefb9494566bc31b2b8a72aab4f04f2980e 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py @@ -195,7 +195,7 @@ class ARModelTest(test.TestCase): self.train_helper(input_window_size=10, loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS, train_steps=300, - max_loss=1.5, + max_loss=2.5, anomaly_distribution=None) def test_autoregression_normal_multiple_periods(self): diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py index 4ec8d26116159fee3ac00581010d1603ac9e19f3..0ddc4b4144da25206735b0480aa0886374ed43a8 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py @@ -37,6 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.training import training as train from tensorflow.python.util import nest @@ -79,12 +80,137 @@ class TimeSeriesRegressor(estimator_lib.Estimator): model_dir=model_dir, config=config) - # TODO(allenl): A parsing input receiver function, which takes a serialized - # tf.Example containing all features (times, values, any exogenous features) - # and serialized model state (possibly also as a tf.Example). - def build_raw_serving_input_receiver_fn(self, - default_batch_size=None, - default_series_length=None): + def _model_start_state_placeholders( + self, batch_size_tensor, static_batch_size=None): + """Creates placeholders with zeroed start state for the current model.""" + gathered_state = {} + # Models may not know the shape of their state without creating some + # variables/ops. Avoid polluting the default graph by making a new one. We + # use only static metadata from the returned Tensors. + with ops.Graph().as_default(): + self._model.initialize_graph() + # Evaluate the initial state as same-dtype "zero" values. These zero + # constants aren't used, but are necessary for feeding to + # placeholder_with_default for the "cold start" case where state is not + # fed to the model. + def _zeros_like_constant(tensor): + return tensor_util.constant_value(array_ops.zeros_like(tensor)) + start_state = nest.map_structure( + _zeros_like_constant, self._model.get_start_state()) + for prefixed_state_name, state in ts_head_lib.state_to_dictionary( + start_state).items(): + state_shape_with_batch = tensor_shape.TensorShape( + (static_batch_size,)).concatenate(state.shape) + default_state_broadcast = array_ops.tile( + state[None, ...], + multiples=array_ops.concat( + [batch_size_tensor[None], + array_ops.ones(len(state.shape), dtype=dtypes.int32)], + axis=0)) + gathered_state[prefixed_state_name] = array_ops.placeholder_with_default( + input=default_state_broadcast, + name=prefixed_state_name, + shape=state_shape_with_batch) + return gathered_state + + def build_one_shot_parsing_serving_input_receiver_fn( + self, filtering_length, prediction_length, default_batch_size=None, + values_input_dtype=None, truncate_values=False): + """Build an input_receiver_fn for export_savedmodel accepting tf.Examples. + + Only compatible with `OneShotPredictionHead` (see `head`). + + Args: + filtering_length: The number of time steps used as input to the model, for + which values are provided. If more than `filtering_length` values are + provided (via `truncate_values`), only the first `filtering_length` + values are used. + prediction_length: The number of time steps requested as predictions from + the model. Times and all exogenous features must be provided for these + steps. + default_batch_size: If specified, must be a scalar integer. Sets the batch + size in the static shape information of all feature Tensors, which means + only this batch size will be accepted by the exported model. If None + (default), static shape information for batch sizes is omitted. + values_input_dtype: An optional dtype specification for values in the + tf.Example protos (either float32 or int64, since these are the numeric + types supported by tf.Example). After parsing, values are cast to the + model's dtype (float32 or float64). + truncate_values: If True, expects `filtering_length + prediction_length` + values to be provided, but only uses the first `filtering_length`. If + False (default), exactly `filtering_length` values must be provided. + + Returns: + An input_receiver_fn which may be passed to the Estimator's + export_savedmodel. + + Expects features contained in a vector of serialized tf.Examples with + shape [batch size] (dtype `tf.string`), each tf.Example containing + features with the following shapes: + times: [filtering_length + prediction_length] integer + values: [filtering_length, num features] floating point. If + `truncate_values` is True, expects `filtering_length + + prediction_length` values but only uses the first `filtering_length`. + all exogenous features: [filtering_length + prediction_length, ...] + (various dtypes) + """ + if values_input_dtype is None: + values_input_dtype = dtypes.float32 + if truncate_values: + values_proto_length = filtering_length + prediction_length + else: + values_proto_length = filtering_length + + def _serving_input_receiver_fn(): + """A receiver function to be passed to export_savedmodel.""" + times_column = feature_column.numeric_column( + key=feature_keys.TrainEvalFeatures.TIMES, dtype=dtypes.int64) + values_column = feature_column.numeric_column( + key=feature_keys.TrainEvalFeatures.VALUES, dtype=values_input_dtype, + shape=(self._model.num_features,)) + parsed_features_no_sequence = ( + feature_column.make_parse_example_spec( + list(self._model.exogenous_feature_columns) + + [times_column, values_column])) + parsed_features = {} + for key, feature_spec in parsed_features_no_sequence.items(): + if isinstance(feature_spec, parsing_ops.FixedLenFeature): + if key == feature_keys.TrainEvalFeatures.VALUES: + parsed_features[key] = feature_spec._replace( + shape=((values_proto_length,) + + feature_spec.shape)) + else: + parsed_features[key] = feature_spec._replace( + shape=((filtering_length + prediction_length,) + + feature_spec.shape)) + elif feature_spec.dtype == dtypes.string: + parsed_features[key] = parsing_ops.FixedLenFeature( + shape=(filtering_length + prediction_length,), + dtype=dtypes.string) + else: # VarLenFeature + raise ValueError("VarLenFeatures not supported, got %s for key %s" + % (feature_spec, key)) + tfexamples = array_ops.placeholder( + shape=[default_batch_size], dtype=dtypes.string, name="input") + features = parsing_ops.parse_example( + serialized=tfexamples, + features=parsed_features) + features[feature_keys.TrainEvalFeatures.TIMES] = array_ops.squeeze( + features[feature_keys.TrainEvalFeatures.TIMES], axis=-1) + features[feature_keys.TrainEvalFeatures.VALUES] = math_ops.cast( + features[feature_keys.TrainEvalFeatures.VALUES], + dtype=self._model.dtype)[:, :filtering_length] + features.update( + self._model_start_state_placeholders( + batch_size_tensor=array_ops.shape( + features[feature_keys.TrainEvalFeatures.TIMES])[0], + static_batch_size=default_batch_size)) + return export_lib.ServingInputReceiver( + features, {"examples": tfexamples}) + return _serving_input_receiver_fn + + def build_raw_serving_input_receiver_fn( + self, default_batch_size=None, default_series_length=None): """Build an input_receiver_fn for export_savedmodel which accepts arrays. Automatically creates placeholders for exogenous `FeatureColumn`s passed to @@ -149,34 +275,10 @@ class TimeSeriesRegressor(estimator_lib.Estimator): + batch_only_feature_shape[1:]) placeholders[feature_key] = array_ops.placeholder( dtype=value_dtype, name=feature_key, shape=feature_shape) - # Models may not know the shape of their state without creating some - # variables/ops. Avoid polluting the default graph by making a new one. We - # use only static metadata from the returned Tensors. - with ops.Graph().as_default(): - self._model.initialize_graph() - # Evaluate the initial state as same-dtype "zero" values. These zero - # constants aren't used, but are necessary for feeding to - # placeholder_with_default for the "cold start" case where state is not - # fed to the model. - def _zeros_like_constant(tensor): - return tensor_util.constant_value(array_ops.zeros_like(tensor)) - start_state = nest.map_structure( - _zeros_like_constant, self._model.get_start_state()) batch_size_tensor = array_ops.shape(time_placeholder)[0] - for prefixed_state_name, state in ts_head_lib.state_to_dictionary( - start_state).items(): - state_shape_with_batch = tensor_shape.TensorShape( - (default_batch_size,)).concatenate(state.shape) - default_state_broadcast = array_ops.tile( - state[None, ...], - multiples=array_ops.concat( - [batch_size_tensor[None], - array_ops.ones(len(state.shape), dtype=dtypes.int32)], - axis=0)) - placeholders[prefixed_state_name] = array_ops.placeholder_with_default( - input=default_state_broadcast, - name=prefixed_state_name, - shape=state_shape_with_batch) + placeholders.update( + self._model_start_state_placeholders( + batch_size_tensor, static_batch_size=default_batch_size)) return export_lib.ServingInputReceiver(placeholders, placeholders) return _serving_input_receiver_fn @@ -288,7 +390,7 @@ class StateSpaceRegressor(TimeSeriesRegressor): """An Estimator for general state space models.""" def __init__(self, model, state_manager=None, optimizer=None, model_dir=None, - config=None): + config=None, head_type=ts_head_lib.TimeSeriesRegressionHead): """See TimeSeriesRegressor. Uses the ChainingStateManager by default.""" if not isinstance(model, state_space_model.StateSpaceModel): raise ValueError( @@ -301,7 +403,8 @@ class StateSpaceRegressor(TimeSeriesRegressor): state_manager=state_manager, optimizer=optimizer, model_dir=model_dir, - config=config) + config=config, + head_type=head_type) class StructuralEnsembleRegressor(StateSpaceRegressor): @@ -344,7 +447,8 @@ class StructuralEnsembleRegressor(StateSpaceRegressor): anomaly_prior_probability=None, optimizer=None, model_dir=None, - config=None): + config=None, + head_type=ts_head_lib.TimeSeriesRegressionHead): """Initialize the Estimator. Args: @@ -401,6 +505,8 @@ class StructuralEnsembleRegressor(StateSpaceRegressor): from tf.train.Optimizer. Defaults to Adam with step size 0.02. model_dir: See `Estimator`. config: See `Estimator`. + head_type: The kind of head to use for the model (inheriting from + `TimeSeriesRegressionHead`). """ if anomaly_prior_probability is not None: filtering_postprocessor = StateInterpolatingAnomalyDetector( @@ -424,4 +530,5 @@ class StructuralEnsembleRegressor(StateSpaceRegressor): model=model, optimizer=optimizer, model_dir=model_dir, - config=config) + config=config, + head_type=head_type) diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py index a28a5872b850b51630240bdeb3ff22f372613523..32194e400e6ada594ef2a067bf612826a6e4acd3 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head.py @@ -19,24 +19,22 @@ from __future__ import print_function import re -from tensorflow.python.training import training_util -from tensorflow.contrib.layers.python.layers import optimizers - from tensorflow.contrib.timeseries.python.timeseries import feature_keys - from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.export import export_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.util import nest from tensorflow.python.summary import summary +from tensorflow.python.training import training_util +from tensorflow.python.util import nest class _NoStatePredictOutput(export_lib.PredictOutput): @@ -102,12 +100,9 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce use_resource=True): model_outputs = self.create_loss(features, mode) - train_op = optimizers.optimize_loss( + train_op = self.optimizer.minimize( model_outputs.loss, - global_step=training_util.get_global_step(), - optimizer=self.optimizer, - # Learning rate is set in the Optimizer object - learning_rate=None) + global_step=training_util.get_global_step()) return estimator_lib.EstimatorSpec( loss=model_outputs.loss, mode=mode, @@ -132,7 +127,8 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce loss=model_outputs.loss, mode=mode, eval_metric_ops=metrics, - predictions={}) + # needed for custom metrics. + predictions=model_outputs.predictions) def _predict_ops(self, features): """Add ops for prediction to the graph.""" @@ -185,7 +181,7 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce return math_ops.cast(value, self.model.dtype) if name == feature_keys.PredictionFeatures.STATE_TUPLE: return value # Correct dtypes are model-dependent - return ops.convert_to_tensor(value) + return sparse_tensor.convert_to_tensor_or_sparse_tensor(value) def _gather_state(self, features): """Returns `features` with state packed, indicates if packing was done.""" @@ -207,15 +203,38 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce flat_sequence=[tensor for _, _, tensor in numbered_state]) return features, True + def _check_predict_features(self, features): + """Raises errors if features are not suitable for prediction.""" + if feature_keys.PredictionFeatures.TIMES not in features: + raise ValueError("Expected a '{}' feature for prediction.".format( + feature_keys.PredictionFeatures.TIMES)) + if feature_keys.PredictionFeatures.STATE_TUPLE not in features: + raise ValueError("Expected a '{}' feature for prediction.".format( + feature_keys.PredictionFeatures.STATE_TUPLE)) + times_feature = features[feature_keys.PredictionFeatures.TIMES] + if not times_feature.get_shape().is_compatible_with([None, None]): + raise ValueError( + ("Expected shape (batch dimension, window size) for feature '{}' " + "(got shape {})").format(feature_keys.PredictionFeatures.TIMES, + times_feature.get_shape())) + _check_feature_shapes_compatible_with( + features=features, + compatible_with_name=feature_keys.PredictionFeatures.TIMES, + compatible_with_value=times_feature, + ignore=set([ + # Model-dependent shapes + feature_keys.PredictionFeatures.STATE_TUPLE + ])) + def create_estimator_spec(self, features, mode, labels=None): """Performs basic error checking and returns an EstimatorSpec.""" with ops.name_scope(self._name, "head"): - if labels: + if labels is not None and labels != {}: # for better error messages. raise ValueError( - "The model received a `labels` dictionary, which is " - "not supported. Pass '{}' and '{}' as " - "features.".format(feature_keys.TrainEvalFeatures.TIMES, - feature_keys.TrainEvalFeatures.VALUES)) + "The model received a `labels`, which is not supported. " + "Pass '{}' and '{}' as features.".format( + feature_keys.TrainEvalFeatures.TIMES, + feature_keys.TrainEvalFeatures.VALUES)) del labels features = { name: self._convert_feature_to_tensor(name=name, value=value) @@ -235,7 +254,7 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce mode == estimator_lib.ModeKeys.EVAL): _check_train_eval_features(features, self.model) elif mode == estimator_lib.ModeKeys.PREDICT: - _check_predict_features(features) + self._check_predict_features(features) else: raise ValueError("Unknown mode '{}' passed to model_fn.".format(mode)) @@ -272,6 +291,44 @@ class OneShotPredictionHead(TimeSeriesRegressionHead): each time predictions are requested when using this head. """ + def _check_predict_features(self, features): + """Raises errors if features are not suitable for one-shot prediction.""" + if feature_keys.PredictionFeatures.TIMES not in features: + raise ValueError("Expected a '{}' feature for prediction.".format( + feature_keys.PredictionFeatures.TIMES)) + if feature_keys.TrainEvalFeatures.VALUES not in features: + raise ValueError("Expected a '{}' feature for prediction.".format( + feature_keys.TrainEvalFeatures.VALUES)) + if feature_keys.PredictionFeatures.STATE_TUPLE not in features: + raise ValueError("Expected a '{}' feature for prediction.".format( + feature_keys.PredictionFeatures.STATE_TUPLE)) + times_feature = features[feature_keys.PredictionFeatures.TIMES] + if not times_feature.get_shape().is_compatible_with([None, None]): + raise ValueError( + ("Expected shape (batch dimension, window size) for feature '{}' " + "(got shape {})").format(feature_keys.PredictionFeatures.TIMES, + times_feature.get_shape())) + _check_feature_shapes_compatible_with( + features=features, + compatible_with_name=feature_keys.PredictionFeatures.TIMES, + compatible_with_value=times_feature, + ignore=set([ + # Model-dependent shapes + feature_keys.PredictionFeatures.STATE_TUPLE, + # One shot prediction head relies on values being shorter than + # times. Even though we're predicting eventually, we need values for + # the filtering phase. + feature_keys.TrainEvalFeatures.VALUES, + ])) + + def _evaluate_ops(self, features): + """Add ops for evaluation (aka filtering) to the graph.""" + spec = super(OneShotPredictionHead, self)._evaluate_ops(features) + # No state is fed to OneShotPredictionHead, so we don't return it; it being + # a tuple can cause issues for downstream infrastructure. + del spec.eval_metric_ops[feature_keys.State.STATE_TUPLE] + return spec + def _serving_ops(self, features): """Add ops for serving to the graph.""" with variable_scope.variable_scope("model", use_resource=True): @@ -338,29 +395,6 @@ def _check_feature_shapes_compatible_with(features, times_shape=compatible_with_value.get_shape())) -def _check_predict_features(features): - """Raises errors if features are not suitable for prediction.""" - if feature_keys.PredictionFeatures.TIMES not in features: - raise ValueError("Expected a '{}' feature for prediction.".format( - feature_keys.PredictionFeatures.TIMES)) - if feature_keys.PredictionFeatures.STATE_TUPLE not in features: - raise ValueError("Expected a '{}' feature for prediction.".format( - feature_keys.PredictionFeatures.STATE_TUPLE)) - times_feature = features[feature_keys.PredictionFeatures.TIMES] - if not times_feature.get_shape().is_compatible_with([None, None]): - raise ValueError( - ("Expected shape (batch dimension, window size) for feature '{}' " - "(got shape {})").format(feature_keys.PredictionFeatures.TIMES, - times_feature.get_shape())) - _check_feature_shapes_compatible_with( - features=features, - compatible_with_name=feature_keys.PredictionFeatures.TIMES, - compatible_with_value=times_feature, - ignore=set([ - feature_keys.PredictionFeatures.STATE_TUPLE # Model-dependent shapes - ])) - - def _check_train_eval_features(features, model): """Raise errors if features are not suitable for training/evaluation.""" if feature_keys.TrainEvalFeatures.TIMES not in features: diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py index c606db76a668235ab6a837159b9dec072b5fd801..bda3b53aca0d0156e542e2bedcadf5caa6b3d2cf 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py @@ -18,16 +18,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools +import os + +from absl.testing import parameterized import numpy import six +from tensorflow.contrib.estimator.python.estimator import extenders from tensorflow.contrib.timeseries.examples import lstm as lstm_example +from tensorflow.contrib.timeseries.python.timeseries import ar_model from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators from tensorflow.contrib.timeseries.python.timeseries import feature_keys from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib from tensorflow.contrib.timeseries.python.timeseries import input_pipeline from tensorflow.contrib.timeseries.python.timeseries import model from tensorflow.contrib.timeseries.python.timeseries import state_management +from tensorflow.core.example import example_pb2 from tensorflow.python.client import session as session_lib from tensorflow.python.estimator import estimator_lib @@ -35,6 +42,7 @@ from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics from tensorflow.python.ops import variables @@ -53,9 +61,12 @@ class HeadTest(test.TestCase): model_fn = _stub_model_fn() for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL, estimator_lib.ModeKeys.PREDICT]: - with self.assertRaisesRegexp(ValueError, "labels"): + with self.assertRaisesRegexp(ValueError, "received a `labels`"): model_fn(features={}, labels={"a": "b"}, mode=mode) + with self.assertRaisesRegexp(ValueError, "received a `labels`"): + model_fn(features={}, labels=array_ops.zeros([]), mode=mode) + def test_unknown_mode(self): model_fn = _stub_model_fn() with self.assertRaisesRegexp(ValueError, "Unknown mode 'Not a mode'"): @@ -128,6 +139,44 @@ class EvaluationMetricsTests(test.TestCase): coordinator.request_stop() coordinator.join() + def test_custom_metrics(self): + """Tests that the custom metrics can be applied to the estimator.""" + model_dir = self.get_temp_dir() + estimator = ts_estimators.TimeSeriesRegressor( + model=lstm_example._LSTMModel(num_features=1, num_units=4), + optimizer=adam.AdamOptimizer(0.001), + config=estimator_lib.RunConfig(tf_random_seed=4), + model_dir=model_dir) + + def input_fn(): + return { + feature_keys.TrainEvalFeatures.TIMES: [[1, 2, 3], [7, 8, 9]], + feature_keys.TrainEvalFeatures.VALUES: + numpy.array([[[0.], [1.], [0.]], [[2.], [3.], [2.]]]) + } + + def metrics_fn(predictions, features): + # checking that the inputs are properly passed. + predict = predictions["mean"] + target = features[feature_keys.TrainEvalFeatures.VALUES][:, -1, 0] + return { + "plain_boring_metric386": + (math_ops.reduce_mean(math_ops.abs(predict - target)), + control_flow_ops.no_op()), + "fun_metric101": (math_ops.reduce_sum(predict + target), + control_flow_ops.no_op()), + } + + # Evaluation without training is enough for testing custom metrics. + estimator = extenders.add_metrics(estimator, metrics_fn) + evaluation = estimator.evaluate(input_fn, steps=1) + self.assertIn("plain_boring_metric386", evaluation) + self.assertIn("fun_metric101", evaluation) + # The values are deterministic because of fixed tf_random_seed. + # However if they become flaky, remove such exacts comparisons. + self.assertAllClose(evaluation["plain_boring_metric386"], 1.130380) + self.assertAllClose(evaluation["fun_metric101"], 10.435442) + class _StubModel(object): num_features = 3 @@ -274,10 +323,56 @@ class PredictFeatureCheckingTests(test.TestCase): mode=estimator_lib.ModeKeys.PREDICT) -class OneShotTests(test.TestCase): - - def test_one_shot_prediction_head_export(self): - model_dir = self.get_temp_dir() +def _custom_time_series_regressor( + model_dir, head_type, exogenous_feature_columns): + return ts_estimators.TimeSeriesRegressor( + model=lstm_example._LSTMModel( + num_features=5, num_units=128, + exogenous_feature_columns=exogenous_feature_columns), + optimizer=adam.AdamOptimizer(0.001), + config=estimator_lib.RunConfig(tf_random_seed=4), + state_manager=state_management.ChainingStateManager(), + head_type=head_type, + model_dir=model_dir) + + +def _structural_ensemble_regressor( + model_dir, head_type, exogenous_feature_columns): + return ts_estimators.StructuralEnsembleRegressor( + periodicities=None, + num_features=5, + exogenous_feature_columns=exogenous_feature_columns, + head_type=head_type, + model_dir=model_dir) + + +def _ar_lstm_regressor( + model_dir, head_type, exogenous_feature_columns): + return ts_estimators.TimeSeriesRegressor( + model=ar_model.ARModel( + periodicities=10, input_window_size=10, output_window_size=6, + num_features=5, + exogenous_feature_columns=exogenous_feature_columns, + prediction_model_factory=functools.partial( + ar_model.LSTMPredictionModel, + num_units=10)), + head_type=head_type, + model_dir=model_dir) + + +class OneShotTests(parameterized.TestCase): + + @parameterized.named_parameters( + {"testcase_name": "ar_lstm_regressor", + "estimator_factory": _ar_lstm_regressor}, + {"testcase_name": "custom_time_series_regressor", + "estimator_factory": _custom_time_series_regressor}, + {"testcase_name": "structural_ensemble_regressor", + "estimator_factory": _structural_ensemble_regressor}) + def test_one_shot_prediction_head_export(self, estimator_factory): + def _new_temp_dir(): + return os.path.join(test.get_temp_dir(), str(ops.uid())) + model_dir = _new_temp_dir() categorical_column = feature_column.categorical_column_with_hash_bucket( key="categorical_exogenous_feature", hash_bucket_size=16) exogenous_feature_columns = [ @@ -285,15 +380,10 @@ class OneShotTests(test.TestCase): "2d_exogenous_feature", shape=(2,)), feature_column.embedding_column( categorical_column=categorical_column, dimension=10)] - estimator = ts_estimators.TimeSeriesRegressor( - model=lstm_example._LSTMModel( - num_features=5, num_units=128, - exogenous_feature_columns=exogenous_feature_columns), - optimizer=adam.AdamOptimizer(0.001), - config=estimator_lib.RunConfig(tf_random_seed=4), - state_manager=state_management.ChainingStateManager(), - head_type=ts_head_lib.OneShotPredictionHead, - model_dir=model_dir) + estimator = estimator_factory( + model_dir=model_dir, + exogenous_feature_columns=exogenous_feature_columns, + head_type=ts_head_lib.OneShotPredictionHead) train_features = { feature_keys.TrainEvalFeatures.TIMES: numpy.arange( 20, dtype=numpy.int64), @@ -307,8 +397,10 @@ class OneShotTests(test.TestCase): input_pipeline.NumpyReader(train_features), shuffle_seed=2, num_threads=1, batch_size=16, window_size=16) estimator.train(input_fn=train_input_fn, steps=5) + result = estimator.evaluate(input_fn=train_input_fn, steps=1) + self.assertNotIn(feature_keys.State.STATE_TUPLE, result) input_receiver_fn = estimator.build_raw_serving_input_receiver_fn() - export_location = estimator.export_savedmodel(self.get_temp_dir(), + export_location = estimator.export_savedmodel(_new_temp_dir(), input_receiver_fn) graph = ops.Graph() with graph.as_default(): @@ -342,7 +434,42 @@ class OneShotTests(test.TestCase): for output_key, output_value in predict_signature.outputs.items()} output = session.run(fetches, feed_dict=feeds) - self.assertAllEqual((2, 15, 5), output["mean"].shape) + self.assertEqual((2, 15, 5), output["mean"].shape) + # Build a parsing input function, then make a tf.Example for it to parse. + export_location = estimator.export_savedmodel( + _new_temp_dir(), + estimator.build_one_shot_parsing_serving_input_receiver_fn( + filtering_length=20, prediction_length=15)) + graph = ops.Graph() + with graph.as_default(): + with session_lib.Session() as session: + example = example_pb2.Example() + times = example.features.feature[feature_keys.TrainEvalFeatures.TIMES] + values = example.features.feature[feature_keys.TrainEvalFeatures.VALUES] + times.int64_list.value.extend(range(35)) + for i in range(20): + values.float_list.value.extend( + [float(i) * 2. + feature_number + for feature_number in range(5)]) + real_feature = example.features.feature["2d_exogenous_feature"] + categortical_feature = example.features.feature[ + "categorical_exogenous_feature"] + for i in range(35): + real_feature.float_list.value.extend([1, 1]) + categortical_feature.bytes_list.value.append(b"strkey") + # Serialize the tf.Example for feeding to the Session + examples = [example.SerializeToString()] * 2 + signatures = loader.load( + session, [tag_constants.SERVING], export_location) + predict_signature = signatures.signature_def[ + feature_keys.SavedModelLabels.PREDICT] + ((_, input_value),) = predict_signature.inputs.items() + feeds = {graph.as_graph_element(input_value.name): examples} + fetches = {output_key: graph.as_graph_element(output_value.name) + for output_key, output_value + in predict_signature.outputs.items()} + output = session.run(fetches, feed_dict=feeds) + self.assertEqual((2, 15, 5), output["mean"].shape) if __name__ == "__main__": diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index f84ff1bfe9b014733205a8e51b43f79c63b227cb..56e451e2e37b48496902ad5bb7468cb48111f65b 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -15,8 +15,8 @@ package( default_visibility = [ "//cloud/vmm/testing/tests/tpu:__subpackages__", "//learning/brain:__subpackages__", + "//learning/deepmind:__subpackages__", "//tensorflow:__subpackages__", - "//third_party/cloud_tpu:__subpackages__", ], ) @@ -37,16 +37,17 @@ cc_library( py_library( name = "tpu_estimator", srcs = [ + "python/tpu/error_handling.py", "python/tpu/tpu_config.py", "python/tpu/tpu_context.py", "python/tpu/tpu_estimator.py", - "python/tpu/tpu_system_metadata.py", "python/tpu/util.py", ], srcs_version = "PY2AND3", deps = [ ":tpu_lib", - ":tpu_py", + "//tensorflow/compiler/xla/experimental/xla_sharding", + "//tensorflow/compiler/xla/python_api:xla_shape", "//tensorflow/contrib/training:training_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -61,10 +62,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/estimator:run_config", - "//tensorflow/python/estimator:util", + "//tensorflow/python/estimator:estimator_py", "@six_archive//:six", ], ) @@ -133,7 +131,7 @@ py_library( tf_custom_op_py_library( name = "tpu_py", - srcs = glob(["python/ops/*.py"]) + ["__init__.py"], + srcs = glob(["python/ops/*.py"]), dso = [":python/ops/_tpu_ops.so"], kernels = [ ":all_ops", @@ -152,21 +150,62 @@ tf_custom_op_py_library( py_library( name = "tpu", - srcs = ["python/tpu/__init__.py"], + srcs = [ + "__init__.py", + "python/tpu/__init__.py", + ], srcs_version = "PY2AND3", deps = [ + ":keras_support", # split out to avoid cycle with tpu_strategy ":tpu_estimator", ":tpu_lib", ], ) +py_library( + name = "keras_support", + srcs = [ + "python/tpu/keras_support.py", + ], + srcs_version = "PY2AND3", + visibility = [ + "//cloud/vmm/testing/tests/tpu:__subpackages__", + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", + "//third_party/cloud_tpu/models/keras:__subpackages__", + ], + deps = [ + ":tpu_lib", + "//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py", + "//tensorflow/contrib/distribute", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:random_ops", + "//tensorflow/python:session", + "//tensorflow/python:tensor_spec", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/keras:backend", + "//tensorflow/python/keras:engine", + "//tensorflow/python/keras:layers", + "//third_party/py/numpy", + ], +) + py_library( name = "tpu_lib", srcs = [ "python/tpu/__init__.py", "python/tpu/bfloat16.py", "python/tpu/device_assignment.py", - "python/tpu/keras_support.py", "python/tpu/session_support.py", "python/tpu/topology.py", "python/tpu/tpu.py", @@ -174,6 +213,7 @@ py_library( "python/tpu/tpu_function.py", "python/tpu/tpu_optimizer.py", "python/tpu/tpu_sharding.py", + "python/tpu/tpu_system_metadata.py", "python/tpu/training_loop.py", ], srcs_version = "PY2AND3", @@ -181,6 +221,7 @@ py_library( ":datasets", ":profiler", ":tpu_py", + "//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py", "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", "//tensorflow/contrib/tpu/proto:topology_proto_py", "//tensorflow/core:protos_all_py", @@ -224,7 +265,6 @@ tf_py_test( ":datasets", ], grpc_enabled = True, - tags = ["no_windows"], ) tf_py_test( @@ -306,3 +346,13 @@ tf_py_test( "//tensorflow/python:framework_test_lib", ], ) + +tf_py_test( + name = "topology_test", + size = "small", + srcs = ["python/tpu/topology_test.py"], + additional_deps = [ + ":tpu", + "//tensorflow/python:framework_test_lib", + ], +) diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py index dc9066855990f372c28dc481959117daa4c2da97..537d94b7979af3e4bd3fb7392c8dcc5a210e98af 100644 --- a/tensorflow/contrib/tpu/__init__.py +++ b/tensorflow/contrib/tpu/__init__.py @@ -18,6 +18,10 @@ @@cross_replica_sum @@infeed_dequeue @@infeed_dequeue_tuple +@@infeed_enqueue +@@infeed_enqueue_tuple +@@outfeed_dequeue +@@outfeed_dequeue_tuple @@outfeed_enqueue @@outfeed_enqueue_tuple @@ -42,9 +46,14 @@ @@TPUEstimator @@TPUEstimatorSpec +@@export_estimator_savedmodel @@RunConfig @@InputPipelineConfig @@TPUConfig +@@bfloat16_scope + +@@TPUDistributionStrategy +@@keras_to_tpu_model """ from __future__ import absolute_import @@ -56,11 +65,13 @@ from tensorflow.contrib.tpu.python import profiler from tensorflow.contrib.tpu.python.ops.tpu_ops import * from tensorflow.contrib.tpu.python.tpu.bfloat16 import * from tensorflow.contrib.tpu.python.tpu.device_assignment import * +from tensorflow.contrib.tpu.python.tpu.keras_support import tpu_model as keras_to_tpu_model +from tensorflow.contrib.tpu.python.tpu.keras_support import TPUDistributionStrategy from tensorflow.contrib.tpu.python.tpu.topology import * from tensorflow.contrib.tpu.python.tpu.tpu import * from tensorflow.contrib.tpu.python.tpu.tpu_config import * from tensorflow.contrib.tpu.python.tpu.tpu_estimator import * -from tensorflow.contrib.tpu.python.tpu.tpu_feed import * +from tensorflow.contrib.tpu.python.tpu.tpu_feed import InfeedQueue from tensorflow.contrib.tpu.python.tpu.tpu_optimizer import * from tensorflow.contrib.tpu.python.tpu.training_loop import * # pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc index d389050e67f9a9e48b91583e5088058ec4e2832f..06553929dc44ca1f75ce64532a4dcdf1c8aae3eb 100644 --- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc +++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc @@ -23,15 +23,23 @@ REGISTER_OP("CrossReplicaSum") .Input("input: T") .Output("output: T") .Attr("T: {bfloat16, float}") + .Attr("group_assignment: list(int) = []") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( An Op to sum inputs across replicated TPU instances. Each -instance supplies its own input, and the output of each is the sum of -all the inputs. +instance supplies its own input. If group_assignment is empty, the output of +each is the sum of all the inputs, otherwise the output of each is the sum of +the inputs belonging to the same group. + +For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing +group_assignment=`[0,1,0,1]` sets `A, C` as group 0, and `B, D` as group 1. +Thus we get the outputs: `[A+C, B+D, A+C, B+D]`. input: The local input to the sum. output: The sum of all the distributed inputs. T: The type of elements to be summed. +group_assignment: The list of group ids. `group_assignment[i]` represents the + group id of replica i. )doc"); } // namespace tensorflow diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc index ab2a7a0d4bec48d6b3b459bb3144e8ddae614ca0..15a2bb17a93212afe9ce5604a28d9dba5825f7d4 100644 --- a/tensorflow/contrib/tpu/ops/replication_ops.cc +++ b/tensorflow/contrib/tpu/ops/replication_ops.cc @@ -44,6 +44,27 @@ REGISTER_OP("TPUReplicatedInput") " with other shapes."); } c->set_output(0, cur); + + // If this is a resource, unify the resource shapes. + DataType dtype; + TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype)); + if (dtype == DT_RESOURCE) { + const std::vector* shapes_and_types = + nullptr; + for (int i = c->num_inputs() - 1; i >= 0; --i) { + if (shapes_and_types) { + // The return value of MergeInputHandleShapesAndTypes indicates + // the shape was refined, not that there was an error. + // TODO(phawkins): there seems to be no way to discover errors. + (void)c->MergeInputHandleShapesAndTypes(i, *shapes_and_types); + } else { + shapes_and_types = c->input_handle_shapes_and_types(i); + } + } + if (shapes_and_types) { + c->set_output_handle_shapes_and_types(0, *shapes_and_types); + } + } return Status::OK(); }) .Doc( diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index 3b2d7adfff6b8de3145a73756a8b5306445034c5..38d1c3049ef7185f2f9f448361029d066678cdae 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -49,11 +49,11 @@ tf_cc_binary( ":tpu_profiler_analysis_proto_cc", ":tpu_profiler_proto_cc", ":version", + "//tensorflow:grpc++", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/platform/cloud:gcs_file_system", - "@grpc//:grpc++", ], ) diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index f80f5652af79d410946971573ae160fdd0b85f6d..8e6e9aa0cded630f39bfd699def37e06a8b920e8 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -84,8 +84,6 @@ ProfileRequest PopulateProfileRequest(int duration_ms, request.add_tools("memory_viewer"); request.add_tools("overview_page"); *request.mutable_opts() = opts; - std::cout << "Limiting the number of trace events to " << kMaxEvents - << std::endl; return request; } @@ -99,7 +97,6 @@ bool Profile(const string& service_addr, const string& logdir, int duration_ms, ::grpc::ClientContext context; ::grpc::ChannelArguments channel_args; - // TODO(ioeric): use `SetMaxReceiveMessageSize` instead once it's available. // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their // `ValidateHostPortPair` checks for empty host string case. channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, @@ -166,6 +163,85 @@ bool NewSession(const string& service_addr, return new_session_response.empty_trace(); } +// Starts tracing on a single or multiple TPU hosts and saves the result in the +// given logdir. If no trace was collected, retries tracing for +// num_tracing_attempts. +void StartTracing(const tensorflow::string& service_addr, + const tensorflow::string& logdir, + const tensorflow::string& workers_list, + bool include_dataset_ops, int duration_ms, + int num_tracing_attempts) { + // Use the current timestamp as the run name. + tensorflow::string session_id = GetCurrentTimeStampAsString(); + constexpr char kProfilePluginDirectory[] = "plugins/profile/"; + tensorflow::string repository_root = + io::JoinPath(logdir, kProfilePluginDirectory); + std::vector hostnames = + tensorflow::str_util::Split(workers_list, ","); + + bool empty_trace = false; + int remaining_attempts = num_tracing_attempts; + tensorflow::ProfileOptions opts; + opts.set_include_dataset_ops(include_dataset_ops); + while (true) { + std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. " + << "Remaining attempt(s): " << remaining_attempts-- << std::endl; + if (hostnames.empty()) { + empty_trace = tensorflow::tpu::Profile(service_addr, logdir, duration_ms, + repository_root, session_id, opts); + } else { + tensorflow::string tpu_master = service_addr; + empty_trace = + tensorflow::tpu::NewSession(tpu_master, hostnames, duration_ms, + repository_root, session_id, opts); + } + if (remaining_attempts <= 0 || !empty_trace) break; + std::cout << "No trace event is collected. Automatically retrying." + << std::endl + << std::endl; + } + + if (empty_trace) { + std::cout << "No trace event is collected after " << num_tracing_attempts + << " attempt(s). " + << "Perhaps, you want to try again (with more attempts?)." + << std::endl + << "Tip: increase number of attempts with --num_tracing_attempts." + << std::endl; + } +} + +MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level) { + MonitorRequest request; + request.set_duration_ms(duration_ms); + request.set_monitoring_level(monitoring_level); + return request; +} + +// Repeatedly collects profiles and shows user-friendly metrics for +// 'num_queries' time(s). +void StartMonitoring(const tensorflow::string& service_addr, int duration_ms, + int monitoring_level, int num_queries) { + for (int query = 0; query < num_queries; ++query) { + MonitorRequest request = + PopulateMonitorRequest(duration_ms, monitoring_level); + + ::grpc::ClientContext context; + ::grpc::ChannelArguments channel_args; + channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, + std::numeric_limits::max()); + std::unique_ptr stub = + TPUProfiler::NewStub(::grpc::CreateCustomChannel( + "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), + channel_args)); + MonitorResponse response; + TF_QCHECK_OK(FromGrpcStatus(stub->Monitor(&context, request, &response))); + + std::cout << "Xprof Monitoring Results (Sample " << query + 1 << "):\n\n" + << response.data() << std::flush; + } +} + } // namespace } // namespace tpu } // namespace tensorflow @@ -174,9 +250,11 @@ int main(int argc, char** argv) { tensorflow::string FLAGS_service_addr; tensorflow::string FLAGS_logdir; tensorflow::string FLAGS_workers_list; - int FLAGS_duration_ms = 2000; + int FLAGS_duration_ms = 0; int FLAGS_num_tracing_attempts = 3; bool FLAGS_include_dataset_ops = true; + int FLAGS_monitoring_level = 0; + int FLAGS_num_queries = 100; std::vector flag_list = { tensorflow::Flag("service_addr", &FLAGS_service_addr, "Address of TPU profiler service e.g. localhost:8466"), @@ -186,21 +264,38 @@ int main(int argc, char** argv) { tensorflow::Flag("logdir", &FLAGS_logdir, "Path of TensorBoard log directory e.g. /tmp/tb_log, " "gs://tb_bucket"), - tensorflow::Flag("duration_ms", &FLAGS_duration_ms, - "Duration of tracing in ms. Default is 2000ms."), + tensorflow::Flag( + "duration_ms", &FLAGS_duration_ms, + "Duration of tracing or monitoring in ms. Default is 2000ms for " + "tracing and 1000ms for monitoring."), tensorflow::Flag("num_tracing_attempts", &FLAGS_num_tracing_attempts, "Automatically retry N times when no trace event " "is collected. Default is 3."), tensorflow::Flag("include_dataset_ops", &FLAGS_include_dataset_ops, "Set to false to profile longer TPU device traces."), - }; + tensorflow::Flag("monitoring_level", &FLAGS_monitoring_level, + "Choose a monitoring level between 1 and 2 to monitor " + "your TPU job continuously. Level 2 is more verbose " + "than level 1 and shows more metrics."), + tensorflow::Flag("num_queries", &FLAGS_num_queries, + "This script will run monitoring for num_queries before " + "it stops.")}; std::cout << "Welcome to the Cloud TPU Profiler v" << TPU_PROFILER_VERSION << std::endl; tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); - if (!parse_ok || FLAGS_service_addr.empty() || FLAGS_logdir.empty()) { + if (!parse_ok || FLAGS_service_addr.empty() || + (FLAGS_logdir.empty() && FLAGS_monitoring_level == 0)) { + // Fail if flags are not parsed correctly or service_addr not provided. + // Also, fail if neither logdir is provided (required for tracing) nor + // monitoring level is provided (required for monitoring). + std::cout << usage.c_str() << std::endl; + return 2; + } + if (FLAGS_monitoring_level < 0 || FLAGS_monitoring_level > 2) { + // Invalid monitoring level. std::cout << usage.c_str() << std::endl; return 2; } @@ -213,52 +308,27 @@ int main(int argc, char** argv) { } tensorflow::port::InitMain(argv[0], &argc, &argv); - // Sets the minimum duration_ms and tracing attempts to one. - int duration_ms = std::max(FLAGS_duration_ms, 1); - int remaining_attempts = std::max(FLAGS_num_tracing_attempts, 1); - tensorflow::ProfileOptions opts; - opts.set_include_dataset_ops(FLAGS_include_dataset_ops); - tensorflow::ProfileResponse response; - - // Use the current timestamp as the run name. - tensorflow::string session_id = - tensorflow::tpu::GetCurrentTimeStampAsString(); - constexpr char kProfilePluginDirectory[] = "plugins/profile/"; - tensorflow::string repository_root = - ::tensorflow::io::JoinPath(FLAGS_logdir, kProfilePluginDirectory); - std::vector hostnames = - tensorflow::str_util::Split(FLAGS_workers_list, ","); - - bool empty_trace = false; - while (true) { - std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. " - << "Remaining attempt(s): " << remaining_attempts-- << std::endl; - if (hostnames.empty()) { - empty_trace = tensorflow::tpu::Profile(FLAGS_service_addr, FLAGS_logdir, - duration_ms, repository_root, - session_id, opts); - } else { - tensorflow::string tpu_master = FLAGS_service_addr; - empty_trace = - tensorflow::tpu::NewSession(tpu_master, hostnames, duration_ms, - repository_root, session_id, opts); - } - if (remaining_attempts <= 0 || !empty_trace) break; - std::cout << "No trace event is collected. Automatically retrying." - << std::endl - << std::endl; + // Sets the minimum duration_ms, tracing attempts and num queries. + int duration_ms = std::max(FLAGS_duration_ms, 0); + if (duration_ms == 0) { + // If profiling duration was not set by user or set to a negative value, we + // set it to default values of 2000ms for tracing and 1000ms for monitoring. + duration_ms = FLAGS_monitoring_level == 0 ? 2000 : 1000; } + int num_tracing_attempts = std::max(FLAGS_num_tracing_attempts, 1); + int num_queries = std::max(FLAGS_num_queries, 1); - if (empty_trace) { - std::cout << "No trace event is collected after " - << FLAGS_num_tracing_attempts << " attempt(s). " - << "Perhaps, you want to try again (with more attempts?)." - << std::endl - << "Tip: increase number of attempts with --num_tracing_attempts." + if (FLAGS_monitoring_level != 0) { + std::cout << "Since monitoring level is provided, profile " + << FLAGS_service_addr << " for " << duration_ms + << "ms and show metrics for " << num_queries << " time(s)." << std::endl; - // Don't dump profile data if no trace is collected. - return 0; + tensorflow::tpu::StartMonitoring(FLAGS_service_addr, duration_ms, + FLAGS_monitoring_level, num_queries); + } else { + tensorflow::tpu::StartTracing(FLAGS_service_addr, FLAGS_logdir, + FLAGS_workers_list, FLAGS_include_dataset_ops, + duration_ms, num_tracing_attempts); } - return 0; } diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py index 508c7a842fb82ec080082d7e7f02f8d2f2a79447..438f4428483a86b75ca1feb31d9c43f860fcc287 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py @@ -17,12 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from absl import flags - import os import subprocess import sys - +from absl import flags +from distutils.version import LooseVersion import tensorflow as tf # Cloud TPU Cluster Resolvers @@ -35,70 +34,87 @@ flags.DEFINE_string( None, help='GCE zone where the Cloud TPU is located in. If not specified, we ' 'will attempt to automatically detect the GCE project from metadata.') -flags.DEFINE_string('tpu_name', None, - 'Name of the Cloud TPU for Cluster Resolvers. You must ' - 'specify either this flag or --service_addr.') +flags.DEFINE_string( + 'tpu', None, 'Name of the Cloud TPU for Cluster Resolvers. You must ' + 'specify either this flag or --service_addr.') # Tool specific parameters flags.DEFINE_string( 'service_addr', None, 'Address of TPU profiler service e.g. ' - 'localhost:8466, you must specify either this flag or --tpu_name.') + 'localhost:8466, you must specify either this flag or --tpu.') flags.DEFINE_string( 'workers_list', None, 'The list of worker TPUs that we are about to profile' - ' e.g. 10.0.1.2, 10.0.1.3. You can specify this flag with --tpu_name or ' + ' e.g. 10.0.1.2, 10.0.1.3. You can specify this flag with --tpu or ' '--service_addr to profile a subset of tpu nodes. You can also use only' - '--tpu_name and leave this flag unspecified to profile all the tpus.') -flags.DEFINE_string('logdir', None, - 'Path of TensorBoard log directory e.g. /tmp/tb_log, ' - 'gs://tb_bucket') -flags.DEFINE_integer('duration_ms', 2000, 'Duration of tracing in ms.') -flags.DEFINE_integer('num_tracing_attempts', 3, - 'Automatically retry N times when no trace ' - 'event is collected.') + '--tpu and leave this flag unspecified to profile all the tpus.') +flags.DEFINE_string( + 'logdir', None, 'Path of TensorBoard log directory e.g. /tmp/tb_log, ' + 'gs://tb_bucket') +flags.DEFINE_integer('duration_ms', 0, + 'Duration of tracing or monitoring in ms.') +flags.DEFINE_integer( + 'num_tracing_attempts', 3, 'Automatically retry N times when no trace ' + 'event is collected.') flags.DEFINE_boolean('include_dataset_ops', True, 'Set to false to profile longer TPU ' 'device traces.') +# Monitoring parameters +flags.DEFINE_integer( + 'monitoring_level', 0, 'Choose a monitoring level between ' + '1 and 2 to monitor your TPU job continuously.') +flags.DEFINE_integer( + 'num_queries', 100, + 'This script will run monitoring for num_queries before it stops.') + FLAGS = flags.FLAGS EXECUTABLE = 'data/capture_tpu_profile' JOB_NAME = 'worker' + def get_workers_list(cluster_resolver): cluster_spec = cluster_resolver.cluster_spec() task_indices = cluster_spec.task_indices(JOB_NAME) - workers_list = [cluster_spec.task_address(JOB_NAME, i).split(':')[0] - for i in task_indices] + workers_list = [ + cluster_spec.task_address(JOB_NAME, i).split(':')[0] for i in task_indices + ] return ','.join(workers_list) + def run_main(): tf.app.run(main) + def main(unused_argv=None): tf.logging.set_verbosity(tf.logging.INFO) + tf_version = tf.__version__ + print('TensorFlow version %s detected' % tf_version) - if FLAGS.service_addr is None and FLAGS.tpu_name is None: - sys.exit('You must specify either --service_addr or --tpu_name.') + if FLAGS.service_addr is None and FLAGS.tpu is None: + sys.exit('You must specify either --service_addr or --tpu.') tpu_cluster_resolver = None if FLAGS.service_addr is not None: - if FLAGS.tpu_name is not None: - tf.logging.warn('Both --service_addr and --tpu_name are set. Ignoring ' - '--tpu_name and using --service_addr.') + if FLAGS.tpu is not None: + tf.logging.warn('Both --service_addr and --tpu are set. Ignoring ' + '--tpu and using --service_addr.') service_addr = FLAGS.service_addr else: tpu_cluster_resolver = ( tf.contrib.cluster_resolver.TPUClusterResolver( - [FLAGS.tpu_name], - zone=FLAGS.tpu_zone, - project=FLAGS.gcp_project)) + [FLAGS.tpu], zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)) service_addr = tpu_cluster_resolver.get_master() service_addr = service_addr.replace('grpc://', '').replace(':8470', ':8466') - workers_list = "" - if FLAGS.workers_list is not None: - workers_list = FLAGS.workers_list - elif tpu_cluster_resolver is not None: - workers_list = get_workers_list(tpu_cluster_resolver) + workers_list = '' + if LooseVersion(tf_version) < LooseVersion('1.9'): + tf.logging.warn('Attempt to profile with legacy support under TensorFlow ' + 'version %s' % tf_version) + else: + if FLAGS.workers_list is not None: + workers_list = FLAGS.workers_list + elif tpu_cluster_resolver is not None: + workers_list = get_workers_list(tpu_cluster_resolver) if not FLAGS.logdir: sys.exit('logdir must be provided.') @@ -111,6 +127,8 @@ def main(unused_argv=None): cmd.append('--duration_ms=' + str(FLAGS.duration_ms)) cmd.append('--num_tracing_attempts=' + str(FLAGS.num_tracing_attempts)) cmd.append('--include_dataset_ops=' + str(FLAGS.include_dataset_ops).lower()) + cmd.append('--monitoring_level=' + str(FLAGS.monitoring_level)) + cmd.append('--num_queries=' + str(FLAGS.num_queries)) subprocess.call(cmd) diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py index ebd478fd02295108b9d2454963eb06165828b523..d4ccb0f24679af830365037819d51529874f4fcc 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py @@ -20,7 +20,7 @@ from __future__ import print_function from setuptools import setup -_VERSION = '1.6.0' +_VERSION = '1.10.0' CONSOLE_SCRIPTS = [ 'capture_tpu_profile=cloud_tpu_profiler.main:run_main', @@ -46,7 +46,7 @@ setup( # 3 - Alpha # 4 - Beta # 5 - Production/Stable - 'Development Status :: 4 - Beta', + 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'Intended Audience :: Education', 'Intended Audience :: Science/Research', diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto index f0fca63db0bca80cdaa27e491b2a03ae2246c007..da4a95e0450a9d0c20593ca60b69f3ad467d455d 100644 --- a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto +++ b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto @@ -11,6 +11,9 @@ service TPUProfiler { // Starts a profiling session, blocks until it completes, and returns data. rpc Profile(ProfileRequest) returns (ProfileResponse) { } + // Collects profiling data and returns user-friendly metrics. + rpc Monitor(MonitorRequest) returns (MonitorResponse) { + } } message ProfileOptions { @@ -104,3 +107,26 @@ message ProfileResponse { // next-field: 8 } + +message MonitorRequest { + // Duration for which to profile between each update. + uint64 duration_ms = 1; + + // Indicates the level at which we want to monitor. Currently, two levels are + // supported: + // Level 1: An ultra lightweight mode that captures only some utilization + // metrics. + // Level 2: More verbose than level 1. Collects utilization metrics, device + // information, step time information, etc. Do not use this option if the TPU + // host is being very heavily used. + int32 monitoring_level = 2; + + // next-field: 3 +} + +message MonitorResponse { + // Properly formatted string data that can be directly returned back to user. + string data = 1; + + // next-field: 2 +} diff --git a/tensorflow/contrib/tpu/profiler/version.h b/tensorflow/contrib/tpu/profiler/version.h index 618479e1a6ccf26a4103ea1f182b662d7d9998da..aee094177bf8a36c98463055aafc777a7ed40f44 100644 --- a/tensorflow/contrib/tpu/profiler/version.h +++ b/tensorflow/contrib/tpu/profiler/version.h @@ -16,6 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ #define TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ -#define TPU_PROFILER_VERSION "1.6.0" +#define TPU_PROFILER_VERSION "1.10.0" #endif // TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ diff --git a/tensorflow/contrib/tpu/proto/BUILD b/tensorflow/contrib/tpu/proto/BUILD index 7ecb36852c53bb74d70ed0f8c70ca1ce860a037a..598b73b438cb239187a911b2d1425b434c889d8d 100644 --- a/tensorflow/contrib/tpu/proto/BUILD +++ b/tensorflow/contrib/tpu/proto/BUILD @@ -2,7 +2,12 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_additional_all_protos", + "tf_proto_library", + "tf_proto_library_py", +) tf_proto_library( name = "tpu_embedding_config_proto", @@ -10,6 +15,16 @@ tf_proto_library( "tpu_embedding_config.proto", ], cc_api_version = 2, + protodeps = [":optimization_parameters_proto"], + visibility = ["//visibility:public"], +) + +tf_proto_library( + name = "optimization_parameters_proto", + srcs = [ + "optimization_parameters.proto", + ], + cc_api_version = 2, visibility = ["//visibility:public"], ) @@ -22,12 +37,14 @@ tf_proto_library( visibility = ["//visibility:public"], ) -tf_proto_library( +tf_proto_library_py( name = "compilation_result_proto", srcs = [ "compilation_result.proto", ], - cc_api_version = 2, - protodeps = ["//tensorflow/core:protos_all"], + protodeps = tf_additional_all_protos() + [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + ], visibility = ["//visibility:public"], ) diff --git a/tensorflow/contrib/tpu/proto/compilation_result.proto b/tensorflow/contrib/tpu/proto/compilation_result.proto index cf52897de3d0fefa55e68a6b889ae9af7b45864a..88585a5bd10fc28aa34bb0de72de970e21b2adb2 100644 --- a/tensorflow/contrib/tpu/proto/compilation_result.proto +++ b/tensorflow/contrib/tpu/proto/compilation_result.proto @@ -3,6 +3,7 @@ syntax = "proto3"; option cc_enable_arenas = true; package tensorflow.tpu; +import "tensorflow/compiler/xla/service/hlo.proto"; import "tensorflow/core/lib/core/error_codes.proto"; // Describes the result of a TPU compilation. @@ -10,4 +11,7 @@ message CompilationResultProto { // The error message, if any, returned during compilation. error.Code status_code = 1; string status_error_message = 2; + + // HLO proto. + repeated xla.HloProto hlo_protos = 3; } diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto new file mode 100644 index 0000000000000000000000000000000000000000..2cc17d6d928370afbb0e3b1e89252f7a687c27d3 --- /dev/null +++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto @@ -0,0 +1,164 @@ +syntax = "proto3"; + +package tensorflow.tpu; + +import "google/protobuf/wrappers.proto"; + +message ClippingLimits { + google.protobuf.FloatValue lower = 1; // -inf if not set + google.protobuf.FloatValue upper = 2; // +inf if not set +} + +// Get the learning rate from a source that can change +// dynamically. +message DynamicLearningRate { +} + +// Source of learning rate to use. +message LearningRate { + oneof learning_rate { + float constant = 1; + DynamicLearningRate dynamic = 2; + } +} + +message AdagradParameters { + float initial_accumulator = 1; +} + +message StochasticGradientDescentParameters { +} + +message FtrlParameters { + float l1 = 1; + float l2 = 2; + float lr_power = 3; + float initial_accum = 4; + float initial_linear = 5; +} + +// The Adam optimizer does not implement hyper-parameter update; use the dynamic +// learning rate feature instead, setting the learning rate to: +// user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) +// Here, t is the current timestep. +// https://github.com/tensorflow/tensorflow/blob/ab51450c817674c8ff08a7ae4f8ac50cdc4bed8b/tensorflow/python/training/adam.py#L54 +message AdamParameters { + float beta1 = 3; + float beta2 = 4; + float epsilon = 5; + float initial_m = 6; + float initial_v = 7; +} + +message MomentumParameters { + float momentum = 1; + bool use_nesterov = 2; + float initial_accum = 3; +} + +message RmsPropParameters { + float rho = 1; + float momentum = 2; + float epsilon = 3; + float initial_ms = 4; + float initial_mom = 5; +} + +message CenteredRmsPropParameters { + float rho = 1; + float momentum = 2; + float epsilon = 3; + float initial_ms = 4; + float initial_mom = 5; + float initial_mg = 6; +} + +message MdlAdagradLightParameters { + float l2 = 1; + float lr_power = 2; + float min_servable_mdl_benefit = 3; + float mdl_mix_in_margin = 4; + float mdl_benefit_rampup_coeff = 5; + float mdl_min_weight = 6; + float benefit_revisit_scale = 7; + float max_event_benefit = 8; + float max_total_benefit = 9; + float mdl_hard_limit = 10; + bool hard_limit_min_benefit = 11; + bool mdl_regularize = 12; + float initial_accumulator = 13; + float initial_weight = 14; + float initial_benefit = 15; +} + +message AdadeltaParameters { + float rho = 1; + float epsilon = 2; + float initial_accumulator = 3; + float initial_update = 4; +} + +message ProximalAdagradParameters { + float l1 = 1; + float l2 = 2; + float initial_accumulator = 3; +} + +message OptimizationParameters { + // Learning rate used for updating the embedding layer parameters. + LearningRate learning_rate = 13; + reserved 1; // Old learning rate tag. + + // Limits to which to clip the weight values after the backward pass; not + // present means no limits are applied. + ClippingLimits clipping_limits = 2; + + // Limits to which to clip the backward pass gradient before using it for + // updates; not present means no limits are applied. + ClippingLimits gradient_clipping_limits = 7; + + // Whether to use gradient accumulation (do two passes over the input + // gradients: one to accumulate them into a temporary array and another to + // apply them using the actual optimization algorithm). + bool use_gradient_accumulation = 15; + + // Optimization algorithm parameters; which field is selected determines which + // algorithm to use. + oneof parameters { + AdagradParameters adagrad = 3; + StochasticGradientDescentParameters stochastic_gradient_descent = 4; + FtrlParameters ftrl = 5; + AdamParameters adam = 6; + MomentumParameters momentum = 8; + RmsPropParameters rms_prop = 9; + CenteredRmsPropParameters centered_rms_prop = 10; + MdlAdagradLightParameters mdl_adagrad_light = 11; + AdadeltaParameters adadelta = 12; + ProximalAdagradParameters proximal_adagrad = 14; + } +} + +// Specification of an optimization algorithm's state variables (both the main +// value vector and any extra accumulators, etc.). +message StateVariableSpecification { + // Parameter name for the state variable. + string name = 1; + + // A normal state variable that should be saved and restored in checkpoints + // and used as an input or output to non-debug TensorFlow ops. + message UserDefined { + } + + // A state variable that should be filled with a constant and normally hidden + // from users (used for intermediate gradients being accumulated, for + // example). + message FillWithConstant { + double initial_value = 1; + } + + // Usage type of this state variable. + oneof usage { + UserDefined user_defined = 2; + FillWithConstant fill_with_constant = 3; + } +} diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto b/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto index b0ec968d3a401f1b80ed1bf6fd7a83a69c068fe2..3476cc89534efb7fe05640935d1387d02737f240 100644 --- a/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto +++ b/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package tensorflow.tpu; +import "tensorflow/contrib/tpu/proto/optimization_parameters.proto"; + // The TPUEmbeddingConfiguration contains specification of TPU Embedding lookups // and gradient updates separate from the TF Graph. message TPUEmbeddingConfiguration { @@ -30,15 +32,6 @@ message TPUEmbeddingConfiguration { // The number of training examples per TensorNode. int32 batch_size = 4; - message GradientDescentOptimizer { - float learning_rate = 1; - } - - message AdagradOptimizer { - float learning_rate = 1; - float initial_accumulator = 2; - } - // Each Embedding message TPUEmbeddingTable { // Name of the embedding table. This will be used to name Variables in the @@ -66,10 +59,7 @@ message TPUEmbeddingConfiguration { // separately to the convolutional or recurrent network. int32 num_features = 5; - oneof optimizer { - GradientDescentOptimizer gradient_descent = 6; - AdagradOptimizer adagrad = 7; - } + OptimizationParameters optimization_parameters = 6; } repeated TPUEmbeddingTable table_config = 5; diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index 14c63a79763300dcfe8d6c8e09b90f8e9c772358..bf442d9116d2ceca499ffc66258c64b5b94dd881 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -38,9 +38,8 @@ if platform.system() != "Windows": @ops.RegisterGradient("CrossReplicaSum") def _cross_replica_sum_grad(op, grad): - del op # Unused # The gradient of a cross replica sum is also a cross-replica sum. - return gen_tpu_ops.cross_replica_sum(grad) + return gen_tpu_ops.cross_replica_sum(grad, op.get_attr("group_assignment")) # This extra type checking exists to give a more helpful error message in # the common case that uint8 and int64 values are infed. Remove when both diff --git a/tensorflow/contrib/tpu/python/tpu/device_assignment.py b/tensorflow/contrib/tpu/python/tpu/device_assignment.py index 726b2d248e3086e1882004827076ed3e563d960d..471b1fa46c679dcab70e9bc12d61ada84cba79bb 100644 --- a/tensorflow/contrib/tpu/python/tpu/device_assignment.py +++ b/tensorflow/contrib/tpu/python/tpu/device_assignment.py @@ -175,6 +175,8 @@ class DeviceAssignment(object): """Returns the physical topology coordinates of a logical core.""" if logical_core is None: logical_core = np.array([0, 0, 0], np.int32) + else: + logical_core = np.asarray(logical_core) if any(logical_core < 0) or any(logical_core >= self.computation_shape): raise ValueError("Invalid core {}; computation shape is {}".format( diff --git a/tensorflow/contrib/tpu/python/tpu/error_handling.py b/tensorflow/contrib/tpu/python/tpu/error_handling.py new file mode 100644 index 0000000000000000000000000000000000000000..52e1ea42370d653d1de7c12eee4b456ec7ce921c --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/error_handling.py @@ -0,0 +1,132 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =================================================================== +"""ErrorRendezvous handler for collecting errors from multiple threads.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import sys +import threading +import time + +import six + +from tensorflow.python.framework import errors +from tensorflow.python.platform import tf_logging as logging + +_UNINTERESTING_ERRORS = (errors.CancelledError,) + + +class ErrorRendezvous(object): + """Resolve errors from multiple threads during TPU execution. + + TPU errors can occur on the infeed or outfeed threads as well as the main + training thread. + + Depending on which thread "wins" and receives the session error first, we may + end up showing users a confusing and non-actionable error message (session + cancelled) instead of a root cause (e.g. a bad filename). + + The rendezvous object provides a location to capture these errors until all + threads terminate. At that point we can choose the most informative error + to report. + """ + + def __init__(self, num_sources): + # string -> (message, traceback) + self._errors = {} + self._num_sources = num_sources + self._session_cancel_timer = None + + def record_error(self, source, exc_info, session=None): + """Report an exception from the given source. + + If a session is passed, a timer will be registered to close it after a few + seconds. This is necessary to ensure the main training loop does not hang + if an infeed/oufeed error occurs. We sleep a few seconds to allow a more + interesting error from another thread to propagate. + + Args: + source: string, source of the error + exc_info: Output from `sys.exc_info` (type, value, traceback) + session: Session to close after delay. + """ + _, value, _ = exc_info + self._errors[source] = exc_info + logging.info('Error recorded from %s: %s', source, value) + + if session is not None and self._session_cancel_timer is None: + + def _cancel_session(): + time.sleep(5) + try: + session.close() + except: # pylint: disable=bare-except + pass + + self._session_cancel_timer = threading.Thread(target=_cancel_session,) + self._session_cancel_timer.daemon = True + self._session_cancel_timer.start() + + def record_done(self, source): + """Mark execution source `source` as done. + + If an error was originally reported from `source` it is left intact. + + Args: + source: `str`, source being recorded + """ + logging.info('%s marked as finished', source) + if source not in self._errors: + self._errors[source] = None + + @contextlib.contextmanager + def catch_errors(self, source, session=None): + """Context manager to report any errors within a block.""" + try: + yield + except Exception: # pylint: disable=broad-except + self.record_error(source, sys.exc_info(), session) + + def raise_errors(self, timeout_sec=0): + """Wait for up to `timeout` seconds for all error sources to finish. + + Preferentially raise "interesting" errors (errors not in the + _UNINTERESTING_ERRORS) set. + + Args: + timeout_sec: Seconds to wait for other error sources. + """ + for _ in range(timeout_sec): + if len(self._errors) == self._num_sources: + break + time.sleep(1) + + kept_errors = [(k, v) for (k, v) in self._errors.items() if v is not None] + + # First check for any interesting errors, then fall back on the session + # cancelled errors etc. + for k, (typ, value, traceback) in kept_errors: + if isinstance(value, _UNINTERESTING_ERRORS): + continue + else: + logging.warn('Reraising captured error') + six.reraise(typ, value, traceback) + + for k, (typ, value, traceback) in kept_errors: + logging.warn('Reraising captured error') + six.reraise(typ, value, traceback) diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index f1a11fa6548b87d6222a97c72b8db5442c8ef774..ff893a722f4e77c743edd3b8db77aa90be1e498d 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -19,15 +19,16 @@ To use, wrap your model with the `keras_support.tpu_model` function. Example usage: ``` -# Must activate before building TPU models -keras_support.setup_tpu_session(master_address) - image = tf.keras.layers.Input(shape=(28, 28, 3), name='image') c1 = tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3))( image) flattened = tf.keras.layers.Flatten()(c1) logits = tf.keras.layers.Dense(10, activation='softmax')(flattened) model = tf.keras.Model(inputs=[image], outputs=[logits]) -model = keras_support.tpu_model(model) + +strategy = keras_support.TPUDistributionStrategy(num_cores_per_host=8) +model = keras_support.tpu_model(model, + strategy=strategy, + tpu_name_or_address=tpu_name) # Only TF optimizers are currently supported. model.compile(optimizer=tf.train.AdamOptimizer(), ...) @@ -35,9 +36,6 @@ model.compile(optimizer=tf.train.AdamOptimizer(), ...) # `images` and `labels` should be Numpy arrays. Support for tensor input # (e.g. datasets) is planned. model.fit(images, labels) - -# Invoke before shutting down -keras_support.shutdown_tpu_session() ``` """ @@ -47,30 +45,49 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import collections +import contextlib import re +import sys import time +import numpy as np + +from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu +from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.contrib.tpu.python.tpu import tpu_optimizer from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session as tf_session +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.keras import backend as K -from tensorflow.python.keras import layers from tensorflow.python.keras import models from tensorflow.python.keras import optimizers as keras_optimizers +from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.layers import embeddings from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging +# Work-around dependency cycle between DistributionStrategy and TPU lib. +def TPUDistributionStrategy(*args, **kw): # pylint: disable=invalid-name + from tensorflow.contrib.distribute.python import tpu_strategy # pylint: disable=g-import-not-at-top + return tpu_strategy.TPUStrategy(*args, **kw) + + class TPUEmbedding(embeddings.Embedding): """TPU compatible embedding layer. @@ -92,11 +109,49 @@ class TPUEmbedding(embeddings.Embedding): return math_ops.tensordot(inputs, self.embeddings, 1) +class KerasCrossShardOptimizer(keras_optimizers.Optimizer): + """An optimizer that averages gradients across TPU shards.""" + + def __init__(self, opt, name='KerasCrossShardOptimizer'): + """Construct a new cross-shard optimizer. + + Args: + opt: An existing `Optimizer` to encapsulate. + name: Optional name prefix for the operations created when applying + gradients. Defaults to "KerasCrossShardOptimizer". + + Raises: + ValueError: If reduction is not a valid cross-shard reduction. + """ + super(KerasCrossShardOptimizer, self).__init__() + self._name = name + self._opt = opt + + def get_updates(self, loss, params): + logging.info('Get updates: %s', loss) + self._opt.get_gradients = self.get_gradients + return self._opt.get_updates(loss, params) + + def get_gradients(self, loss, params): + num_shards = tpu_function.get_tpu_context().number_of_shards + grads = super(KerasCrossShardOptimizer, self).get_gradients(loss, params) + return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads] + + def set_weights(self, weights): + self._opt.set_weights() + + def get_weights(self): + return self._opt.get_weights() + + @property + def lr(self): + return self._opt.lr + + class TPUModelOp( - collections.namedtuple( - 'TPUModelOp', - ['compile_op', 'execute_op', 'infeed_tensors', 'infeed_op', - 'outfeed_op'])): + collections.namedtuple('TPUModelOp', [ + 'compile_op', 'execute_op', 'infeed_tensors', 'infeed_op', 'outfeed_op' + ])): pass @@ -105,13 +160,444 @@ def _valid_name(tensor_name): return re.sub('[^a-zA-Z0-9_-]+', '', tensor_name) -def _replicated_optimizer(opt, num_replicas): +def _replicated_optimizer(opt): """Wrap the optimizer `opt` with CrossShardOptimizer if applicable.""" - if num_replicas == 1: + if tpu_function.get_tpu_context().number_of_shards == 1: return opt - return keras_optimizers.TFOptimizer( - optimizer=tpu_optimizer.CrossShardOptimizer(opt.optimizer) - ) + + if isinstance(opt, keras_optimizers.TFOptimizer): + return tpu_optimizer.CrossShardOptimizer(opt.optimizer) + else: + return KerasCrossShardOptimizer(opt) + + +class TPURewriteContext(object): + """Prepare the environment for a Keras model during `tpu.rewrite`. + + This overrides the default placeholder behaviour to instead refer to a preset + input mapping. Placeholders are unsupported in TPU compiled code, and must + be replaced with explicit inputs or values from the infeed queue. + + Instead of explicitly threading inputs all the way through the Keras codebase, + we override the behavior of the placeholder while compiling and inject the + Tensors from the infeed in place of the placeholder. + + Similarly, as we compile a new sub-graph for each unique shape and execution + mode, we need to override the behavior of an embedded `name_scope` call in + the base Keras layer code. This allows us to re-use the same weights across + many compiles and share a single session/graph. + """ + + def __init__(self, input_map): + self._input_map = input_map + self._default_placeholder = None + self._default_name_scope = None + + def __enter__(self): + + def _placeholder(dtype, shape=None, name=None): # pylint: disable=unused-argument + logging.info('Remapping placeholder for %s', name) + if name in self._input_map: + return self._input_map[name] + else: + logging.info('Default: %s', name) + return self._default_placeholder(dtype, shape, name) + + def _name_scope(name, default_name=None, values=None): + caller_frame = sys._getframe().f_back + caller_obj = caller_frame.f_locals.get('self') + if (caller_obj is not None and + isinstance(caller_obj, base_layer.Layer) and name is not None): + return variable_scope.variable_scope( + name, default_name, values, reuse=variable_scope.AUTO_REUSE) + + return self._default_name_scope(name, default_name, values) + + self._default_placeholder = array_ops.placeholder + self._default_name_scope = ops.name_scope + self._default_make_variable = base_layer.make_variable + self._default_random_normal = random_ops.random_normal + self._default_qr = gen_linalg_ops.qr + + array_ops.placeholder = _placeholder + + # Replace random_ops.random_normal with a dummy function because + # `random_normal` isn't yet implemented on the TPU. Because these + # initialized values are overwritten by the CPU values, this is okay. + def random_normal(shape, + mean=0.0, + stddev=1.0, + dtype=dtypes.float32, + seed=None, + name=None): + del mean + del stddev + del seed + return array_ops.zeros(shape, dtype=dtype, name=name) + + random_ops.random_normal = random_normal + + # Replace gen_linalg_ops.qr because QR decomposition is not yet implemented. + # TODO(saeta): Remove qr override once we confirm the qr implementation is + # ok. + # pylint: disable=redefined-builtin + def qr(input, full_matrices=False, name=None): + """Dummy implementation of qr decomposition.""" + del full_matrices # TODO(saeta): Properly handle the full matrix case. + input_shape = input.shape + if len(input_shape) < 2: + raise ValueError('Invalid shape passed to qr: %s' % input_shape) + p = min(input_shape[-1], input_shape[-2]) + if len(input_shape) == 2: + q = array_ops.zeros((p, p), name=name) + r = array_ops.zeros(input_shape, name=name) + return (r, q) + elif len(input_shape) == 3: + n = input_shape[0] + q = array_ops.zeros((n, p, p), name=name) + r = array_ops.zeros(input_shape, name=name) + return (r, q) + else: + raise ValueError('Invalid shape passed to qr: %s' % input_shape) + gen_linalg_ops.qr = qr + + ops.name_scope = _name_scope + base_layer.make_variable = variable_scope.get_variable + logging.info('Overriding default placeholder.') + return + + def __exit__(self, exc_type, exc_val, exc_tb): + array_ops.placeholder = self._default_placeholder + ops.name_scope = self._default_name_scope + base_layer.make_variable = self._default_make_variable + random_ops.random_normal = self._default_random_normal + gen_linalg_ops.qr = self._default_qr + + +class SizedInfeed(collections.namedtuple('SizedInfeed', + ['sharded_infeed_tensors', + 'infeed_ops'])): + """Represents an instantiation of the infeed ops for a concrete input shape. + + sharded_infeed_tensors: A data structure of Tensors used to represent the + placeholder tensors that must be fed when using feed_dicts. + + infeed_ops: the set of ops that will be run to drive infeed for a single step. + """ + pass + + +class TPUInfeedInstance(object): + """TPUInfeedInstance represents the logic to manage feeding in a single step. + + See the comments on the `TPUInfeedManager` for a description for how infeed + is managed. + """ + + @abc.abstractmethod + def make_input_specs(self, input_tensors): + """Constructs the infeed_specs for the given Infeed instance. + + Args: + input_tensors: The inputs to the model. + + Returns: + A list of + """ + pass + + def make_feed_dict(self, tpu_model_op): + """Constructs a feed_dict for this instance, given the tpu_model_op. + + Args: + tpu_model_op: A `TPUModelOp` representing the TPU Model for this + instance's input spec. + + Returns: + A dictionary to use as the feed_dict of a `session.run` call. + """ + pass + + +class TPUInfeedManager(object): + """TPUInfeedManager manages the data infeeding of data to a TPU computation. + + Because there are multiple data sources (e.g. in-memory NumPy arrays, + `tf.data.Dataset`s), we abstract the different logic behind a single + interface: the `TPUInfeedManager`. + + (1) A `TPUFunction` is called with a set of inputs. Based on the inputs, + `TPUFunction` retrieves the corresponding `TPUInfeedManager` (or constructs a + new one if required). + + (2) The `TPUFunction` calls `make_infeed_instance` on the `TPUInfeedManager` + which returns a `TPUInfeedInstance`. + + (3) The `TPUFunction` checks in the shape cache for a pre-compiled instance of + the model based on the returned `input_specs` from `TPUInfeedInstance`. + + (4) [Optional.] If the model has not already been instantiated for the given + input spec, the `TPUFunction` compiles the model for the input spec (using the + `TPUInfeedManager`). + + (5) The `TPUInfeedInstance` constructs the session.run's feed_dict given the + compiled model instance corresponding to its shape. + """ + + @abc.abstractmethod + def make_infeed_instance(self, inputs): + """Given a single step's input, construct a `TPUInfeedInstance`. + + Args: + inputs: The inputs to a given step. + + Returns: + A subclass of `TPUInfeedInstance`. + """ + pass + + @abc.abstractmethod + def build_infeed_from_input_specs(self, input_specs, execution_mode): + """For a given input specification (size, type), construct the infeed ops. + + This is called only once for a given input specification and builds the + graph ops. It does not have a pointer to the actual infeed data. + + Args: + input_specs: TODO(saeta): Document me! + execution_mode: TODO(saeta): Document me! + + Returns: + A `SizedInfeed` instance. + """ + pass + + +class TPUNumpyInfeedManager(TPUInfeedManager): + """TPU Infeed manager for Numpy inputs.""" + + class NumpyInfeedInstance(TPUInfeedInstance): + """Infeed instance for Numpy inputs.""" + + def __init__(self, sharded_inputs): + self._sharded_inputs = sharded_inputs + + def make_input_specs(self, input_tensors): + # Compute an input specification (used to generate infeed enqueue and + # dequeue operations). We use the shape from our input array and the + # dtype from our model. A user may pass in a float64 for a float32 + # input: for model compatibility we still must generate a float32 infeed. + input_specs = [] + # We use the shape and dtype from the first shard to compute the input + # metadata (`input_specs`); all replicas have the same type and shape. + for tensor, ary in zip(input_tensors, self._sharded_inputs[0]): + input_specs.append( + tensor_spec.TensorSpec(ary.shape, tensor.dtype, + _valid_name(tensor.name))) + + return input_specs + + def make_feed_dict(self, tpu_model_op): + infeed_dict = {} + for infeed_tensors, inputs in zip(tpu_model_op.infeed_tensors, + self._sharded_inputs): + for tensor, value in zip(infeed_tensors, inputs): + infeed_dict[tensor] = value + return infeed_dict + + def __init__(self, distribution_strategy): + self._strategy = distribution_strategy + + def _split_tensors(self, inputs): + """Split input data across shards. + + Each input is sliced along the batch axis. + + Args: + inputs: List of Numpy arrays to run on the TPU. + + Returns: + List of lists containing the input to feed to each TPU shard. + """ + if self._strategy.num_towers == 1: + return [inputs] + + batch_size = inputs[0].shape[0] + assert batch_size % self._strategy.num_towers == 0, ( + 'batch_size must be divisible by strategy.num_towers (%s vs %s)' % + (batch_size, self._strategy.num_towers)) + shard_size = batch_size // self._strategy.num_towers + input_list = [] + for index in range(self._strategy.num_towers): + shard_inputs = [ + x[index * shard_size:(index + 1) * shard_size] for x in inputs + ] + input_list.append(shard_inputs) + return input_list + + def make_infeed_instance(self, inputs): + sharded_inputs = self._split_tensors(inputs) + return self.NumpyInfeedInstance(sharded_inputs) + + def build_infeed_from_input_specs(self, input_specs, execution_mode): + infeed_op = [] + shard_infeed_tensors = [] + + for shard_id in range(self._strategy.num_towers): + with ops.device('/device:CPU:0'): + infeed_tensors = [] + with ops.device('/device:TPU:%d' % shard_id): + for spec in input_specs: + # Construct placeholders for each of the inputs. + infeed_tensors.append( + array_ops.placeholder( + dtype=spec.dtype, + shape=spec.shape, + name='infeed-enqueue-%s-%d' % (spec.name, shard_id))) + shard_infeed_tensors.append(infeed_tensors) + + infeed_op.append( + tpu_ops.infeed_enqueue_tuple( + infeed_tensors, [spec.shape for spec in input_specs], + name='infeed-enqueue-%s-%d' % (execution_mode, shard_id), + device_ordinal=shard_id)) + return SizedInfeed(infeed_ops=infeed_op, + sharded_infeed_tensors=shard_infeed_tensors) + + +class TPUDatasetInfeedManager(TPUInfeedManager): + """Manages infeed for a `tf.data.Dataset` into a TPU computation. + """ + + class DatasetInfeedInstance(TPUInfeedInstance): + """An instance of the TPU infeed.""" + + def __init__(self, input_specs): + self._input_specs = input_specs + + def make_input_specs(self, input_tensors): + # TODO(saeta): Do error checking here! + return self._input_specs + + def make_feed_dict(self, tpu_model_op): + # TODO(saeta): Verify tpu_model_op is as expected! + return {} + + def __init__(self, dataset, distribution_strategy, tpu_session): + """Constructs a TPUDatasetInfeedManager. + + Must be called within a `KerasTPUModel.tpu_session` context! + + Args: + dataset: A `tf.data.Dataset` to infeed. + distribution_strategy: The `TPUDistributionStrategy` used to configure the + Keras TPU model. + tpu_session: The `tf.Session` object used for running the TPU model. + """ + self._verify_dataset_shape(dataset) + self._dataset = dataset + self._strategy = distribution_strategy + dummy_x_shape = dataset.output_shapes[0].as_list() + dummy_x_shape[0] *= distribution_strategy.num_towers + dummy_y_shape = dataset.output_shapes[1].as_list() + dummy_y_shape[0] *= distribution_strategy.num_towers + self._iterator = dataset.make_initializable_iterator() + tpu_session.run(self._iterator.initializer) + + self._get_next_ops = [] + ctrl_deps = [] + for i in range(distribution_strategy.num_towers): + with ops.control_dependencies(ctrl_deps): # Ensure deterministic + # TODO(saeta): Ensure correct placement! + get_next_op = self._iterator.get_next() + self._get_next_ops.append(get_next_op) + ctrl_deps.extend(get_next_op) + + # Use dummy numpy inputs for the rest of Keras' shape checking. We + # intercept them when building the model. + self._dummy_x = np.zeros(dummy_x_shape, + dtype=dataset.output_types[0].as_numpy_dtype) + self._dummy_y = np.zeros(dummy_y_shape, + dtype=dataset.output_types[1].as_numpy_dtype) + + input_specs = [] + if isinstance(self._iterator.output_shapes, tuple): + assert isinstance(self._iterator.output_types, tuple) + assert len(self._iterator.output_shapes) == len( + self._iterator.output_types) + for i in range(len(self._iterator.output_shapes)): + spec = tensor_spec.TensorSpec(self._iterator.output_shapes[i], + self._iterator.output_types[i]) + input_specs.append(spec) + elif isinstance(self._iterator.output_shapes, tensor_shape.TensorShape): + spec = tensor_spec.TensorSpec(self._iterator.output_shapes, + self._iterator.output_types) + input_specs.append(spec) + + self._infeed_instance = self.DatasetInfeedInstance(input_specs) + + def _verify_dataset_shape(self, dataset): + """Verifies a dataset is of an appropriate shape for TPUs.""" + if not isinstance(dataset, dataset_ops.Dataset): + raise ValueError('The function passed as the `x` parameter did not ' + 'return a `tf.data.Dataset`.') + if not isinstance(dataset.output_classes, tuple): + raise ValueError('The dataset must return a tuple of tf.Tensors, ' + 'instead it returns: %s' % dataset.output_classes) + if len(dataset.output_classes) != 2: + raise ValueError( + 'The dataset must return a 2-element tuple, got ' + '%s output classes instead.' % (dataset.output_classes,)) + for i, cls in enumerate(dataset.output_classes): + if cls != ops.Tensor: + raise ValueError('The dataset returned a non-Tensor type (%s) at ' + 'index %d.' % (cls, i)) + for i, shape in enumerate(dataset.output_shapes): + if not shape: + raise ValueError('The dataset returns a scalar tensor in ' + 'tuple index %d. Did you forget to batch? ' + '(Output shapes: %s).' % (i, + dataset.output_shapes)) + for j, dim in enumerate(shape): + if dim.value is None: + if j == 0: + hint = (' Hint: did you use `ds.batch(BATCH_SIZE, ' + 'drop_remainder=True)`?') + else: + hint = '' + raise ValueError( + 'The Keras-TPU integration for `tf.data` ' + 'currently requires static shapes. The provided ' + 'dataset only has a partially defined shape. ' + '(Dimension %d of output tensor %d is not statically known ' + 'for output shapes: %s.%s)' % (i, j, dataset.output_shapes, hint)) + + @property + def dummy_x(self): + return self._dummy_x + + @property + def dummy_y(self): + return self._dummy_y + + def make_infeed_instance(self, inputs): + # TODO(saeta): Verify inputs is as expected. + return self._infeed_instance + + def build_infeed_from_input_specs(self, input_specs, execution_mode): + shard_infeed_tensors = self._get_next_ops + assert len(shard_infeed_tensors) == self._strategy.num_towers + infeed_ops = [] + for shard_id in range(self._strategy.num_towers): + with ops.device('/device:CPU:0'): + infeed_ops.append( + tpu_ops.infeed_enqueue_tuple( + shard_infeed_tensors[shard_id], + [spec.shape for spec in input_specs], + name='infeed-enqueue-%s-%d' % (execution_mode, shard_id), + device_ordinal=shard_id)) + return SizedInfeed(infeed_ops=infeed_ops, + sharded_infeed_tensors=shard_infeed_tensors) class TPUFunction(object): @@ -126,19 +612,24 @@ class TPUFunction(object): instead of being injected as `feed_dict` items or fetches. """ - def __init__(self, model, execution_mode, num_replicas=1): + def __init__(self, model, execution_mode, strategy): self.model = model self.execution_mode = execution_mode + self._strategy = strategy self._compilation_cache = {} - self.num_replicas = num_replicas + self._cloned_model = None + + # Copy optimizer configuration. This is done prior to `_specialize_model` + # as the configuration may require evaluating variables in the CPU session. + self._optimizer_config = None + if not isinstance(self.model.optimizer, keras_optimizers.TFOptimizer): + self._optimizer_config = self.model.optimizer.get_config() - def _specialize_model(self, input_specs): + def _specialize_model(self, input_specs, infeed_manager): """Specialize `self.model` (a Keras model) for the given input shapes.""" # Re-create our input and output layers inside our subgraph. They will be # attached to the true computation when we clone our model in `tpu_fn`. - K.set_learning_phase( - self.execution_mode == model_fn_lib.ModeKeys.TRAIN - ) + K.set_learning_phase(self.execution_mode == model_fn_lib.ModeKeys.TRAIN) # functools.partial and callable objects are not supported by tpu.rewrite def _model_fn(): @@ -160,27 +651,38 @@ class TPUFunction(object): name='infeed-%s' % self.execution_mode) assert len(infeed_tensors) == len(infeed_layers), ( - 'Infeed inputs did not match model: %s vs %s', (infeed_layers, - infeed_tensors)) + 'Infeed inputs did not match model: %s vs %s' % (infeed_layers, + infeed_tensors)) tpu_targets = [] - tpu_inputs = [] + tpu_input_map = {} # Sort infeed outputs into inputs and labels for calling our Keras model. for tensor, layer in zip(infeed_tensors, infeed_layers): if layer in self.model._input_layers: - tpu_inputs.append(layers.Input(name=layer.name, tensor=tensor)) + tpu_input_map[layer.name] = tensor if layer in self.model._output_layers: tpu_targets.append(tensor) - # Call our model with our infeed inputs (re-using the weights). - model_outputs = self.model(tpu_inputs) - child_model = models.Model(inputs=tpu_inputs, outputs=model_outputs) + # Clone our CPU model, running within the TPU device context. + with TPURewriteContext(tpu_input_map): + # TODO(power): Replicate variables. + with ops.device('/device:TPU:0'): + self._cloned_model = models.clone_model(self.model) + + # Create a copy of the optimizer for this graph. + if isinstance(self.model.optimizer, keras_optimizers.TFOptimizer): + cloned_optimizer = keras_optimizers.TFOptimizer( + self.model.optimizer.optimizer) + else: + logging.info('Cloning %s %s', self.model.optimizer.__class__.__name__, + self._optimizer_config) + cloned_optimizer = self.model.optimizer.__class__.from_config( + self._optimizer_config) if is_training or is_test: - child_model.compile( - optimizer=_replicated_optimizer(self.model.optimizer, - self.num_replicas), + self._cloned_model.compile( + optimizer=_replicated_optimizer(cloned_optimizer), loss=self.model.loss, loss_weights=self.model.loss_weights, metrics=self.model.metrics, @@ -190,37 +692,37 @@ class TPUFunction(object): # Compute our outfeed depending on the execution mode if is_training: - child_model._make_train_function() + self._cloned_model._make_train_function() self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) - for tensor in child_model.train_function.outputs + for tensor in self._cloned_model.train_function.outputs ] return [ - child_model.train_function.updates_op, + self._cloned_model.train_function.updates_op, tpu_ops.outfeed_enqueue_tuple( - child_model.train_function.outputs, + self._cloned_model.train_function.outputs, name='outfeed-enqueue-train') ] elif is_test: - child_model._make_test_function() + self._cloned_model._make_test_function() self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) - for tensor in child_model.test_function.outputs + for tensor in self._cloned_model.test_function.outputs ] return [ tpu_ops.outfeed_enqueue_tuple( - child_model.test_function.outputs, + self._cloned_model.test_function.outputs, name='outfeed-enqueue-test') ] elif is_predict: - child_model._make_predict_function() + self._cloned_model._make_predict_function() self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) - for tensor in child_model.predict_function.outputs + for tensor in self._cloned_model.predict_function.outputs ] return [ tpu_ops.outfeed_enqueue_tuple( - child_model.predict_function.outputs, + self._cloned_model.predict_function.outputs, name='outfeed-enqueue-predict', ) ] @@ -235,84 +737,57 @@ class TPUFunction(object): # `execute op` replicates `_model_fn` `num_replicas` times, with each shard # running on a different logical core. compile_op, execute_op = tpu.split_compile_and_replicate( - _model_fn, inputs=[[]] * self.num_replicas) + _model_fn, inputs=[[]] * self._strategy.num_towers) # Generate CPU side operations to enqueue features/labels and dequeue # outputs from the model call. - infeed_op = [] + sized_infeed = infeed_manager.build_infeed_from_input_specs( + input_specs, self.execution_mode) + # Build output ops. outfeed_op = [] - shard_infeed_tensors = [] - - for shard_id in range(self.num_replicas): - with ops.device('/device:TPU:%d' % shard_id): - infeed_tensors = [] - for spec in input_specs: - infeed_tensors.append( - array_ops.placeholder( - dtype=spec.dtype, - shape=spec.shape, - name='infeed-enqueue-%s-%d' % (spec.name, shard_id))) - shard_infeed_tensors.append(infeed_tensors) - - infeed_op.append(tpu_ops.infeed_enqueue_tuple( - infeed_tensors, [spec.shape for spec in input_specs], - name='infeed-enqueue-%s-%d' % (self.execution_mode, shard_id))) - - outfeed_op.extend(tpu_ops.outfeed_dequeue_tuple( - dtypes=[spec.dtype for spec in self._outfeed_spec], - shapes=[spec.shape for spec in self._outfeed_spec], - name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id))) + for shard_id in range(self._strategy.num_towers): + with ops.device('/device:CPU:0'): + outfeed_op.extend( + tpu_ops.outfeed_dequeue_tuple( + dtypes=[spec.dtype for spec in self._outfeed_spec], + shapes=[spec.shape for spec in self._outfeed_spec], + name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id), + device_ordinal=shard_id)) return TPUModelOp( - compile_op, execute_op, infeed_tensors=shard_infeed_tensors, - infeed_op=infeed_op, outfeed_op=outfeed_op) + compile_op, + execute_op, + infeed_tensors=sized_infeed.sharded_infeed_tensors, + infeed_op=sized_infeed.infeed_ops, + outfeed_op=outfeed_op) def _test_model_compiles(self, tpu_model_ops): """Verifies that the given TPUModelOp can be compiled via XLA.""" - session = K.get_session() - logging.info('Started compiling') start_time = time.clock() - result = session.run(tpu_model_ops.compile_op) + result = K.get_session().run(tpu_model_ops.compile_op) proto = tpu_compilation_result.CompilationResultProto() proto.ParseFromString(result) if proto.status_error_message: - raise RuntimeError( - 'Compilation failed: {}'.format(proto.status_error_message)) + raise RuntimeError('Compilation failed: {}'.format( + proto.status_error_message)) end_time = time.clock() logging.info('Finished compiling. Time elapsed: %s secs', end_time - start_time) - def _split_tensors(self, inputs): - """Split input data across shards. - - Each input is sliced along the batch axis. - - Args: - inputs: List of Numpy arrays to run on the TPU. - - Returns: - List of lists containing the input to feed to each TPU shard. - """ - if self.num_replicas == 1: - return [inputs] - - batch_size = inputs[0].shape[0] - assert batch_size % self.num_replicas == 0, ( - 'batch_size must be divisible by num_replicas') - shard_size = batch_size // self.num_replicas - input_list = [] - for index in range(self.num_replicas): - shard_inputs = [x[index * shard_size:(index + 1) * shard_size] - for x in inputs] - input_list.append(shard_inputs) - return input_list - def __call__(self, inputs): assert isinstance(inputs, list) + infeed_manager = None + for x, mgr in self.model._numpy_to_infeed_manager_list: + if inputs[0] is x: + infeed_manager = mgr + break + if infeed_manager is None: + infeed_manager = TPUNumpyInfeedManager(self.model._strategy) + # Strip sample weight from inputs if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or self.execution_mode == model_fn_lib.ModeKeys.EVAL): @@ -321,21 +796,9 @@ class TPUFunction(object): else: input_tensors = self.model._feed_inputs - shard_inputs = self._split_tensors(inputs) + infeed_instance = infeed_manager.make_infeed_instance(inputs) del inputs # To avoid accident usage. - - # Compute an input specification (used to generate infeed enqueue and - # dequeue operations). We use the shape from our input array and the - # dtype from our model. A user may pass in a float64 for a float32 - # input: for model compatibility we still must generate a float32 infeed. - input_specs = [] - - # We use the shape and dtype from the first shard to compute the input - # metadata (`input_specs`); all replicas have the same type and shape. - for tensor, ary in zip(input_tensors, shard_inputs[0]): - input_specs.append( - tensor_spec.TensorSpec(ary.shape, tensor.dtype, - _valid_name(tensor.name))) + input_specs = infeed_instance.make_input_specs(input_tensors) # XLA requires every operation in the graph has a fixed shape. To # handle varying batch sizes we recompile a new sub-graph for each @@ -343,72 +806,103 @@ class TPUFunction(object): shape_key = tuple([tuple(spec.shape.as_list()) for spec in input_specs]) if shape_key not in self._compilation_cache: - logging.info('New input shapes; (re-)compiling: mode=%s, %s', - self.execution_mode, input_specs) - new_tpu_model_ops = self._specialize_model(input_specs) - self._compilation_cache[shape_key] = new_tpu_model_ops - self._test_model_compiles(new_tpu_model_ops) - + with self.model.tpu_session(): + logging.info('New input shapes; (re-)compiling: mode=%s, %s', + self.execution_mode, input_specs) + new_tpu_model_ops = self._specialize_model(input_specs, + infeed_manager) + self._compilation_cache[shape_key] = new_tpu_model_ops + self._test_model_compiles(new_tpu_model_ops) + + # Initialize our TPU weights on the first compile. + self.model._initialize_weights(self._cloned_model) tpu_model_ops = self._compilation_cache[shape_key] - infeed_dict = {} - for infeed_tensors, inputs in zip(tpu_model_ops.infeed_tensors, - shard_inputs): - for tensor, value in zip(infeed_tensors, inputs): - infeed_dict[tensor] = value + infeed_dict = infeed_instance.make_feed_dict(tpu_model_ops) - session = K.get_session() - _, _, outfeed_outputs = session.run([ - tpu_model_ops.infeed_op, tpu_model_ops.execute_op, - tpu_model_ops.outfeed_op - ], infeed_dict) + with self.model.tpu_session() as session: + _, _, outfeed_outputs = session.run([ + tpu_model_ops.infeed_op, tpu_model_ops.execute_op, + tpu_model_ops.outfeed_op + ], infeed_dict) # TODO(xiejw): Decide how to reduce outputs, or just discard all but first. - return outfeed_outputs[:len(outfeed_outputs) // self.num_replicas] + if self.execution_mode == model_fn_lib.ModeKeys.PREDICT: + outputs = [[]] * len(self._outfeed_spec) + outputs_per_replica = len(self._outfeed_spec) + for i in range(self._strategy.num_towers): + output_group = outfeed_outputs[i * outputs_per_replica:(i + 1) * + outputs_per_replica] + for j in range(outputs_per_replica): + outputs[j].append(output_group[j]) -@experimental -def setup_tpu_session(master): - """Initializes and returns a Keras/TF session connected the TPU `master`.""" - session = tf_session.Session( - target=master, config=config_pb2.ConfigProto(isolate_session_state=True)) - K.set_session(session) - K.get_session().run(tpu.initialize_system()) - return session - - -@experimental -def shutdown_tpu_session(session=None): - """Shutdown the TPU attached to session. - - This should be called to cleanly shut down the TPU system before the client - exits. - - Args: - session: Session to shutdown, or None to use the default session. - - Returns: - - """ - if session is None: - session = K.get_session() - - session.run(tpu.shutdown_system()) + return [np.concatenate(group) for group in outputs] + else: + return outfeed_outputs[:len(outfeed_outputs) // self._strategy.num_towers] class KerasTPUModel(models.Model): """TPU compatible Keras model wrapper.""" - def __init__(self, inputs, outputs, name, replicas=1): + def __init__(self, cpu_model, tpu_name_or_address, strategy): super(models.Model, self).__init__( # pylint: disable=bad-super-call - inputs=inputs, - outputs=outputs, - name=name, + inputs=cpu_model.inputs, + outputs=cpu_model.outputs, + name=cpu_model.name, ) + + # Create a mapping from numpy arrays to infeed managers. + # Note: uses a list of tuples instead of a map because numpy arrays are + # not hashable. + self._numpy_to_infeed_manager_list = [] + self.predict_function = None self.test_function = None self.train_function = None - self.replicas = replicas + self._strategy = strategy + + self._tpu_name_or_address = tpu_name_or_address + self._cpu_model = cpu_model + self._tpu_model = None + self._tpu_weights_initialized = False + self._graph = ops.Graph() + + self._cluster_resolver = tpu_cluster_resolver.TPUClusterResolver( + tpu_name_or_address) + master = self._cluster_resolver.master() + cluster_spec = self._cluster_resolver.cluster_spec() + self._session = tf_session.Session( + graph=self._graph, + target=master, + config=config_pb2.ConfigProto(isolate_session_state=True)) + + # TODO(saeta): Confirm the lines below work in ClusterSpec propagation env. + if cluster_spec: + self._session.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) + + with self._graph.as_default(): + self._session.run(tpu.initialize_system()) + + # If the input CPU model has already been compiled, compile our TPU model + # immediately. + if self._cpu_model.optimizer: + self.compile( + self._cpu_model.optimizer, + self._cpu_model.loss, + self._cpu_model.metrics, + self._cpu_model.loss_weights, + self._cpu_model.sample_weight_mode, + self._cpu_model.weighted_metrics, + self._cpu_model.target_tensors, + ) + + def get_config(self): + return { + 'cpu_model': self._cpu_model, + 'tpu_name_or_address': self._tpu_name_or_address, + 'strategy': self._strategy, + } def compile(self, optimizer, @@ -430,44 +924,183 @@ class KerasTPUModel(models.Model): sample_weight_mode, weighted_metrics, target_tensors, **kwargs) - # Keras optimizers are not compatible with TPU rewrite - if not isinstance(self.optimizer, keras_optimizers.TFOptimizer): + if not self._cpu_model.optimizer: + self._cpu_model.compile(optimizer, loss, metrics, loss_weights, + sample_weight_mode, weighted_metrics, + target_tensors, **kwargs) + + def fit(self, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose=1, + callbacks=None, + validation_split=0., + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + **kwargs): + assert not self._numpy_to_infeed_manager_list # Ensure empty. + + infeed_managers = [] # Managers to clean up at the end of the fit call. + if isinstance(x, dataset_ops.Dataset): + # TODO(b/111413240): Support taking a tf.data.Dataset directly. + raise ValueError( + 'Taking a Dataset directly is not yet supported. Please ' + 'wrap your dataset construction code in a function and ' + 'pass that to fit instead. For examples, see: ' + 'https://github.com/tensorflow/tpu/tree/master/models/experimental' + '/keras') + if callable(x): + with self.tpu_session() as sess: + dataset = x() + if steps_per_epoch is None: + raise ValueError('When using tf.data as input to a model, you ' + 'should specify the steps_per_epoch argument.') + if y is not None: + raise ValueError('When using tf.data as input to a model, y must be ' + 'None') + infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess) + # Use dummy numpy inputs for the rest of Keras' shape checking. We + # intercept them when building the model. + x = infeed_manager.dummy_x + y = infeed_manager.dummy_y + infeed_managers.append((x, infeed_manager)) + + if isinstance(validation_data, dataset_ops.Dataset): + # TODO(b/111413240): Support taking a tf.data.Dataset directly. raise ValueError( - 'Optimizer must be a TFOptimizer, got: %s' % self.optimizer) + 'Taking a Dataset directly is not yet supported. Please ' + 'wrap your dataset construction code in a function and ' + 'pass that to fit instead. For examples, see: ' + 'https://github.com/tensorflow/tpu/tree/master/models/experimental' + '/keras') + if callable(validation_data): + with self.tpu_session() as sess: + dataset = validation_data() + if validation_steps is None: + raise ValueError('When using tf.data as validation for a model, you ' + 'should specify the validation_steps argument.') + infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess) + # Use dummy numpy inputs for the rest of Keras' shape checking. We + # intercept them when building the model. + val_x = infeed_manager.dummy_x + val_y = infeed_manager.dummy_y + infeed_managers.append((val_x, infeed_manager)) + validation_data = (val_x, val_y) + + self._numpy_to_infeed_manager_list = infeed_managers + try: + return super(KerasTPUModel, self).fit( + x, + y, + batch_size, + epochs, + verbose, + callbacks, + validation_split, + validation_data, + shuffle, + class_weight, + sample_weight, + initial_epoch, + steps_per_epoch, + validation_steps, + **kwargs) + finally: + self._numpy_to_infeed_manager_list = [] def _make_train_function(self): if not self.train_function: - self.train_function = TPUFunction(self, model_fn_lib.ModeKeys.TRAIN, - num_replicas=self.replicas) + self.train_function = TPUFunction( + self, model_fn_lib.ModeKeys.TRAIN, strategy=self._strategy) return self.train_function def _make_test_function(self): if not self.test_function: - self.test_function = TPUFunction(self, model_fn_lib.ModeKeys.EVAL) + self.test_function = TPUFunction( + self, model_fn_lib.ModeKeys.EVAL, strategy=self._strategy) return self.test_function def _make_predict_function(self): if not self.predict_function: - self.predict_function = TPUFunction(self, model_fn_lib.ModeKeys.PREDICT) + self.predict_function = TPUFunction( + self, model_fn_lib.ModeKeys.PREDICT, strategy=self._strategy) return self.predict_function - def cpu_model(self): - cpu_model = models.Model( - inputs=self.inputs, - outputs=self.outputs, - name=self.name, - ) + def _initialize_weights(self, cloned_model): + """Initialize TPU weights. - if self.optimizer: - cpu_model.compile( - optimizer=self.optimizer, - loss=self.loss, - metrics=self.metrics, - loss_weights=self.loss_weights, - ) + This is called on the first compile of the TPU model (first call to + fit/predict/evaluate). - return cpu_model + Args: + cloned_model: `keras.Model`, TPU model to initialize. + """ + if self._tpu_weights_initialized: + return + + self._tpu_model = cloned_model + self._tpu_weights_initialized = True + + weights = self._cpu_model.get_weights() + with self.tpu_session(): + logging.info('Setting weights on TPU model.') + cloned_model.set_weights(weights) + + def sync_to_cpu(self): + """Copy weights from the CPU, returning a synchronized CPU model.""" + if self._tpu_weights_initialized: + with self.tpu_session(): + logging.info('Copying TPU weights to the CPU') + tpu_weights = self._tpu_model.get_weights() + + self._cpu_model.set_weights(tpu_weights) + + return self._cpu_model + + def get_weights(self): + return self.sync_to_cpu().get_weights() + + def save_weights(self, *args, **kw): + return self.sync_to_cpu().save_weights(*args, **kw) + + def save(self, *args, **kw): + return self.sync_to_cpu().save(*args, **kw) + + def set_weights(self, weights): + # We may not have a TPU model available if we haven't run fit/predict, so + # we can't directly set the TPU weights here. + # Instead, reset CPU model weights and force TPU re-initialization at the + # next call. + self._cpu_model.set_weights(weights) + self._tpu_weights_initialized = False + + @contextlib.contextmanager + def tpu_session(self): + """Yields a TPU session and sets it as the default Keras session.""" + with self._graph.as_default(): + default_session = K.get_session() + # N.B. We have to call `K.set_session()` AND set our session as the + # TF default. `K.get_session()` surprisingly does not return the value + # supplied by K.set_session otherwise. + K.set_session(self._session) + with self._session.as_default(): + yield self._session + K.set_session(default_session) + + def shutdown(self): + # TODO(b/111364423): Actually shut down the system. + logging.info('Skipping shutting down TPU system.') + # with self.tpu_session() as session: + # session.run(tpu.shutdown_system()) + self._session.close() def _validate_shapes(model): @@ -504,26 +1137,8 @@ Output shape: %(output_shape)s @experimental -def tpu_model(model, replicas=None): - """Runs a model on TPU(s). - - Usage: - ``` - a = Input(shape=(32,)) - b = Dense(32)(a) - model = Model(inputs=a, outputs=b) - - model = keras_support.tpu_model(model) - model.compile( - optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0), - ...) - ``` - - If `replicas` is set, replicates the model computation on all TPU cores. The - model computation is replicated `num_replicas` times; each shard will run on a - different TPU core. - - Limitation: Currently, replication is only supported for training. +def tpu_model(model, tpu_name_or_address=None, strategy=None): + """Copy `model` along with weights to the TPU. Returns a TPU model. Usage: ``` @@ -531,26 +1146,39 @@ def tpu_model(model, replicas=None): b = Dense(32)(a) model = Model(inputs=a, outputs=b) - model = keras_support.tpu_model(model, replicas=2) + # If `num_cores_per_host` is greater than one, batch parallelism will be used + # to run on multiple TPU cores. + strategy = keras_support.TPUDistributionStrategy(num_cores_per_host=8) + model = keras_support.tpu_model(model, strategy) model.compile( optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0), ...) + model.shutdown() ``` Args: model: A `KerasTPUModel`. - replicas: (Optional) Int, number of TPU cores which to create model - replicas. If `None`, the model runs on single core only, i.e., no - replication. + tpu_name_or_address: A string that is either the name of the Cloud TPU, + the grpc address of the Cloud TPU, or (Googlers only) the BNS name of the + Cloud TPU. If tpu_name_or_address is None, the TPUClusterResolver will + examine the environment to determine a potential Cloud TPU to use. + strategy: `TPUDistributionStrategy`. The strategy to use for replicating + model across multiple TPU cores. Returns: A new `KerasTPUModel` instance. """ + # Force initialization of the CPU model. + model.get_weights() + model.reset_states() + _validate_shapes(model) # TODO(xiejw): Validate TPU model. TPUModel only? # TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset? # TODO(xiejw): Adds reduction option. - replicas = 1 if replicas is None else replicas + if strategy is None: + strategy = TPUDistributionStrategy(num_cores_per_host=1) return KerasTPUModel( - inputs=model.inputs, outputs=model.outputs, name=model.name, - replicas=replicas) + cpu_model=model, + tpu_name_or_address=tpu_name_or_address, + strategy=strategy) diff --git a/tensorflow/contrib/tpu/python/tpu/topology.py b/tensorflow/contrib/tpu/python/tpu/topology.py index cda9a63f204ed686b527c95dd5b4fd7786ac60cf..1fb26e701a392d5ef3bc40d5772d4541fa38f773 100644 --- a/tensorflow/contrib/tpu/python/tpu/topology.py +++ b/tensorflow/contrib/tpu/python/tpu/topology.py @@ -55,8 +55,9 @@ class Topology(object): rank 3 numpy int32 array that describes a valid coordinate mapping. """ + self._serialized = serialized + if serialized: - self._serialized = serialized self._parse_topology(serialized) else: self._mesh_shape = np.asarray(mesh_shape, dtype=np.int32) @@ -131,7 +132,7 @@ class Topology(object): proto.mesh_shape[:] = list(self._mesh_shape) proto.num_tasks = self._device_coordinates.shape[0] proto.num_tpu_devices_per_task = self._device_coordinates.shape[1] - proto.device_coordinates = list(self._device_coordinates.flatten()) + proto.device_coordinates.extend(list(self._device_coordinates.flatten())) self._serialized = proto.SerializeToString() return self._serialized diff --git a/tensorflow/contrib/tpu/python/tpu/topology_test.py b/tensorflow/contrib/tpu/python/tpu/topology_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e67fdb263aa48a37f65c3623365ebcf8f98bebd4 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/topology_test.py @@ -0,0 +1,46 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Tests for topology.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tpu.python.tpu import topology + +from tensorflow.python.platform import test + + +class TopologyTest(test.TestCase): + + def testSerialization(self): + """Test if the class is able to generate serialzied string.""" + original_topology = topology.Topology( + mesh_shape=[1, 1, 2], + device_coordinates=[[[0, 0, 0], [0, 0, 1]]], + ) + serialized_str = original_topology.serialized() + new_topology = topology.Topology(serialized=serialized_str) + + # Make sure the topology recovered from serialized str is same as the + # original topology. + self.assertAllEqual( + original_topology.mesh_shape, new_topology.mesh_shape) + self.assertAllEqual( + original_topology.device_coordinates, new_topology.device_coordinates) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index cd0fd6ae8a2b35efa85bb4583ed3846a1a33395f..7fa06d6d560a4b6ffa6d9a3fd0fa208b4c60ee7f 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -151,6 +151,41 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._name = name self._unsupported_ops = [] self._pivot = pivot + self._replicated_vars = {} + + def get_replicated_var_handle(self, var): + """Returns a variable handle for replicated TPU variable 'var'. + + This is an method used by an experimental replicated variable + implementation and is not intended as a public API. + + Args: + var: The replicated TPU variable. + + Returns: + The handle of the TPU replicated input node. + """ + handle = self._replicated_vars.get(var) + if handle is not None: + return handle + + # Builds a TPUReplicatedInput node for the variable, if one does not already + # exist. The TPUReplicatedInput node must belong to the enclosing + # control-flow scope of the TPUReplicateContext. + # TODO(phawkins): consider changing the contract of the TPU encapsulation + # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope + # instead. + + # pylint: disable=protected-access + graph = ops.get_default_graph() + saved_context = graph._get_control_flow_context() + graph._set_control_flow_context(self.outer_context) + handle = tpu_ops.tpu_replicated_input( + [v.handle for v in var._vars], name=var.name + "/handle") + graph._set_control_flow_context(saved_context) + # pylint: enable=protected-access + self._replicated_vars[var] = handle + return handle def report_unsupported_operations(self): if self._unsupported_ops: @@ -227,19 +262,26 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): class FakeOp(object): """A helper class to determine the current device. - Supports only the device set/get methods needed to run the + Supports only the type and device set/get methods needed to run the graph's _apply_device_function method. """ def __init__(self): self._device = "" + @property + def type(self): + return "FakeOp" + @property def device(self): return self._device def _set_device(self, device): - self._device = device.to_string() + if isinstance(device, pydev.DeviceSpec): + self._device = device.to_string() + else: + self._device = device if self._outside_compilation_cluster: raise NotImplementedError("Cannot nest outside_compilation clusters") @@ -272,7 +314,9 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): # Capture the device function stack at the time of first entry # since that is the stack that will be used outside_compilation. graph = ops.get_default_graph() - self._outer_device_function_stack = list(graph._device_function_stack) # pylint: disable=protected-access + # pylint: disable=protected-access + self._outer_device_function_stack = graph._device_function_stack.copy() + # pylint: enable=protected-access super(TPUReplicateContext, self).Enter() def HostComputeCore(self): @@ -591,17 +635,14 @@ def split_compile_and_replicate(computation, with tpu_function.tpu_shard_context( num_replicas), ops.control_dependencies([metadata]): - # The EncapsulateTPUComputations rewrite needs to identify the - # replicated arguments inside each computation. Adds identity operators - # tagged with an attribute _tpu_replicated_input to identify the - # replicated inputs. - # pylint: disable=protected-access - with graph._attr_scope({"_tpu_replicated_input": - attr_value_pb2.AttrValue(b=True)}): - computation_inputs = [ - array_ops.identity(x, name="replicated_input_{}".format(i)) - for i, x in enumerate(computation_inputs)] - # pylint: enable=protected-access + # Add identity ops so even unused inputs are "consumed" by the + # computation. This is to avoid orphaned TPUReplicatedInput nodes. + # TODO(phawkins): consider instead pruning unused TPUReplicatedInput + # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs. + computation_inputs = [ + array_ops.identity(x, name="replicated_input_{}".format(i)) + for i, x in enumerate(computation_inputs) + ] # If there is an infeed queue, adds the dequeued values to the # computation's inputs. @@ -929,8 +970,15 @@ def rewrite(computation, Args: computation: A Python function that builds a computation to apply to the input. If the function takes n inputs, 'inputs' should be - a list of n tensors. If the function returns m outputs, rewrite - will return a list of m tensors. + a list of n tensors. + + `computation` may return a list of operations and tensors. Tensors must + come before operations in the returned list. The return value of + `rewrite` is a list of tensors corresponding to the tensors from the + from `computation`. + + All `Operation`s returned from `computation` will be executed when + evaluating any of the returned output tensors. inputs: A list of input tensors or `None` (equivalent to an empty list). infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple of arguments as inputs to `computation`. @@ -967,6 +1015,19 @@ _BLACKLISTED_INFERENCE_OPS = set([ ]) +def under_tpu_inference_context(): + """Check if it is currently under `tpu.rewrite_for_inference()`.""" + graph = ops.get_default_graph() + + context = graph._get_control_flow_context() # pylint: disable=protected-access + while context: + if isinstance(context, _TPUInferenceContext): + return True + context = context.outer_context + + return False + + class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext): """A `ControlFlowContext` for nodes inside a TPU inference computation. diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py index 6d7331e3c79ade9c12c15de79f550cf3973c4e6c..8d05e081a7c6e0327fedae6dc2c3ba45df40d029 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -23,8 +23,6 @@ import collections import json import os -import numpy as np - from tensorflow.contrib.tpu.python.tpu import util as util_lib from tensorflow.core.protobuf import config_pb2 from tensorflow.python.estimator import run_config as run_config_lib @@ -43,17 +41,18 @@ class InputPipelineConfig(object): PER_SHARD_V1 = 1 PER_HOST_V1 = 2 PER_HOST_V2 = 3 + BROADCAST = 4 -# TODO(b/72511246) Provide a simplified api to configure model parallelism. class TPUConfig( collections.namedtuple('TPUConfig', [ 'iterations_per_loop', 'num_shards', - 'computation_shape', + 'num_cores_per_replica', 'per_host_input_for_training', 'tpu_job_name', 'initial_infeed_sleep_secs', + 'input_partition_dims', ])): r"""TPU related configuration required by `TPUEstimator`. @@ -67,22 +66,22 @@ class TPUConfig( case, this number equals the total number of TPU cores. For model-parallelism, the total number of TPU cores equals product(computation_shape) * num_shards. - computation_shape: Defaults to `None`, which disables model parallelism. A - list of size 3 which describes the shape of a model replica's block of - cores. This is required by model-parallelism which enables partitioning - the model to multiple cores. For example, [2, 2, 1] means the model is - partitioned across 4 cores which span two cores in both x and y - coordinates. Please refer to @{tf.contrib.tpu.Topology} for the - geometry of a TPU mesh. + num_cores_per_replica: Defaults to `None`, which disables model parallelism. + An integer which describes the number of TPU cores per model replica. This + is required by model-parallelism which enables partitioning + the model to multiple cores. Currently num_cores_per_replica must be + 1, 2, 4, or 8. per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`, - `input_fn` is invoked per-host rather than per-core. With per-host input - pipeline configuration, `input_fn` is invoked once on each host. With the - per-core input pipeline configuration, it is invoked once for each core. + `input_fn` is invoked once on each host. With the per-core input pipeline + configuration, it is invoked once for each core. With a global batch size `train_batch_size` in `TPUEstimator` constructor, the batch size for each shard is `train_batch_size` // #hosts in the `True` or `PER_HOST_V1` mode. In `PER_HOST_V2` mode, it is - `train_batch_size` // #cores. With the per-core input pipeline - configuration, the shard batch size is also `train_batch_size` // #cores. + `train_batch_size` // #cores. In `BROADCAST` mode, `input_fn` is only + invoked once on host 0 and the tensors are broadcasted to all other + replicas. The batch size equals to train_batch_size`. With the per-core + input pipeline configuration, the shard batch size is also + `train_batch_size` // #cores. Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN. tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred within TPUEstimator, however when using ClusterSpec propagation in more @@ -91,6 +90,17 @@ class TPUConfig( initial_infeed_sleep_secs: The number of seconds the infeed thread should wait before enqueueing the first batch. This helps avoid timeouts for models that require a long compilation time. + input_partition_dims: A nested list to describe the partition dims + for all the tensors from input_fn(). The structure of + input_partition_dims must match the structure of `features` and + `labels` from input_fn(). The total number of partitions must match + `num_cores_per_replica`. For example, if input_fn() returns two tensors: + images with shape [N, H, W, C] and labels [N]. + input_partition_dims = [[1, 2, 2, 1], None] will split the images to 4 + pieces and feed into 4 TPU cores. labels tensor are directly broadcasted + to all the TPU cores since the partition dims is `None`. + Current limitations: This feature is only supported with the PER_HOST_V2 + input mode. Raises: ValueError: If `computation_shape` or `computation_shape` are invalid. @@ -99,10 +109,11 @@ class TPUConfig( def __new__(cls, iterations_per_loop=2, num_shards=None, - computation_shape=None, + num_cores_per_replica=None, per_host_input_for_training=True, tpu_job_name=None, - initial_infeed_sleep_secs=None): + initial_infeed_sleep_secs=None, + input_partition_dims=None): # Check iterations_per_loop. util_lib.check_positive_integer(iterations_per_loop, @@ -112,19 +123,26 @@ class TPUConfig( if num_shards is not None: util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards') - # Check computation_shape - if computation_shape is not None and len(computation_shape) != 3: - raise ValueError( - 'computation_shape must be a list with length 3 or None; got {}'. - format(str(computation_shape))) + if input_partition_dims is not None: + if len(input_partition_dims) != 1 and len(input_partition_dims) != 2: + raise ValueError( + 'input_partition_dims must be a list/tuple with one or two' + ' elements.') + + if per_host_input_for_training is not InputPipelineConfig.PER_HOST_V2: + raise ValueError( + 'input_partition_dims is only supported in PER_HOST_V2 mode.') + + if num_cores_per_replica is None: + raise ValueError( + 'input_partition_dims requires setting num_cores_per_replica.') - if computation_shape is not None: - computation_shape_array = np.asarray(computation_shape, dtype=np.int32) - # This prevents any computation being replicated across multiple hosts, so - # that each host feeds the same number of computations. - if any(computation_shape_array < 1) or any(computation_shape_array > 2): - raise ValueError('computation_shape elements can only be 1 or 2; got ' - 'computation_shape={}'.format(computation_shape)) + # Parse computation_shape + if num_cores_per_replica is not None: + if num_cores_per_replica not in [1, 2, 4, 8]: + raise ValueError( + 'num_cores_per_replica must be 1, 2, 4, or 8; got {}'.format( + str(num_cores_per_replica))) # per_host_input_for_training may be True, False, or integer in [1..3]. # Map legacy values (True, False) to numeric values. @@ -144,10 +162,11 @@ class TPUConfig( cls, iterations_per_loop=iterations_per_loop, num_shards=num_shards, - computation_shape=computation_shape, + num_cores_per_replica=num_cores_per_replica, per_host_input_for_training=per_host_input_for_training, tpu_job_name=tpu_job_name, - initial_infeed_sleep_secs=initial_infeed_sleep_secs) + initial_infeed_sleep_secs=initial_infeed_sleep_secs, + input_partition_dims=input_partition_dims) class RunConfig(run_config_lib.RunConfig): @@ -214,6 +233,12 @@ class RunConfig(run_config_lib.RunConfig): self._session_config.cluster_def.CopyFrom( self._cluster_spec.as_cluster_def()) + def _maybe_overwrite_session_config_for_distributed_training(self): + # Overrides the parent class session_config overwrite for between-graph. TPU + # runs with in-graph, which should not have device filter. Doing nothing + # ("pass") basically disables it. + pass + @property def evaluation_master(self): return self._evaluation_master diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py index 37ef3dbe1e66efe18b13ab9153ee346c08b9774a..2326fe97a807e6708a9cdc24fea889b998025a45 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import json from tensorflow.contrib.tpu.python.tpu import tpu_config as tpu_config_lib +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.platform import test @@ -33,6 +34,46 @@ def _set_tf_config_env_variable(tf_config): class TPURunConfigTest(test.TestCase): + def test_no_session_config_set_in_local_case(self): + run_config = tpu_config_lib.RunConfig() + self.assertIsNone(run_config.session_config) + + def test_no_session_config_overwrite_in_local_case(self): + session_config = config_pb2.ConfigProto(allow_soft_placement=True) + run_config = tpu_config_lib.RunConfig(session_config=session_config) + self.assertEqual(session_config, run_config.session_config) + + def test_no_session_config_set_with_cluster_spec(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host3:3'], + run_config_lib.TaskType.WORKER: ['host3:4'] + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + with _set_tf_config_env_variable(tf_config): + run_config = tpu_config_lib.RunConfig() + self.assertIsNone(run_config.session_config) + + def test_no_session_config_overwrite_with_cluster_spec(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host3:3'], + run_config_lib.TaskType.WORKER: ['host3:4'] + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + with _set_tf_config_env_variable(tf_config): + session_config = config_pb2.ConfigProto(allow_soft_placement=True) + run_config = tpu_config_lib.RunConfig(session_config=session_config) + self.assertEqual(session_config, run_config.session_config) + def test_fail_with_invalid_num_shards(self): with self.assertRaisesRegexp(ValueError, 'must be positive'): tpu_config_lib.RunConfig( @@ -43,15 +84,11 @@ class TPURunConfigTest(test.TestCase): tpu_config_lib.RunConfig( tpu_config=tpu_config_lib.TPUConfig(iterations_per_loop=0)) - def test_fail_with_invalid_computation_shape(self): - with self.assertRaisesRegexp(ValueError, - 'computation_shape must be a list with length' - ' 3 or None'): - tpu_config_lib.TPUConfig(computation_shape=[2, 1]) - - with self.assertRaisesRegexp(ValueError, - 'computation_shape elements can only be'): - tpu_config_lib.TPUConfig(computation_shape=[1, 3, 1]) + def test_fail_with_invalid_num_cores_per_replica(self): + with self.assertRaisesRegexp( + ValueError, 'num_cores_per_replica must be 1, 2, 4, or 8;' + ' got 7'): + tpu_config_lib.TPUConfig(num_cores_per_replica=7) class TPURunConfigMasterTest(test.TestCase): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index 5b9aeaa8797b92b4cc596744812f440607054dce..806ae1c4c9918be0bf0af8579c12386c0a18aff0 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -21,8 +21,6 @@ from __future__ import print_function from contextlib import contextmanager import copy -import numpy as np - from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment from tensorflow.contrib.tpu.python.tpu import tpu_config from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib @@ -33,15 +31,26 @@ from tensorflow.python.platform import tf_logging as logging _DEFAULT_JOB_NAME = 'tpu_worker' _DEFAULT_COORDINATOR_JOB_NAME = 'coordinator' _LOCAL_MASTERS = ('', 'local') +_NUM_CORES_TO_COMPUTATION_SHAPE = { + 1: [1, 1, 1], + 2: [1, 1, 2], + 4: [1, 2, 2], + 8: [2, 2, 2] +} class TPUContext(object): """The context of current input_fn invocation.""" - def __init__(self, internal_ctx, input_device=None, invocation_index=None): + def __init__(self, + internal_ctx, + input_device=None, + invocation_index=None, + call_from_input_fn=True): self._internal_ctx = internal_ctx self._input_device = input_device self._invocation_index = invocation_index + self._call_from_input_fn = call_from_input_fn def current_input_fn_deployment(self): """The configuration of the current input_fn invocation. @@ -69,11 +78,21 @@ class TPUContext(object): total invocation count is equal to the number of hosts in the system and num replicas consumed by current invocation is equal to number of cores per host. + + Raises: + RuntimeError: If this method must not be called from input_fn. """ + if not self._call_from_input_fn: + raise RuntimeError('This TPUContext instance must not be called from' + ' model_fn.') + if self._internal_ctx.is_input_sharded_per_core(): total_invocation_count = (self._internal_ctx.num_hosts * self._internal_ctx.num_of_replicas_per_host) replicas_consumed = 1 + elif self._internal_ctx.is_input_broadcast_with_iterators(): + total_invocation_count = 1 + replicas_consumed = self._internal_ctx.num_replicas else: total_invocation_count = self._internal_ctx.num_hosts replicas_consumed = self._internal_ctx.num_of_replicas_per_host @@ -92,6 +111,27 @@ class TPUContext(object): """ return self._internal_ctx.num_replicas + @property + def num_hosts(self): + """The number of hosts for the TPU system.""" + return self._internal_ctx.num_hosts + + @property + def num_of_replicas_per_host(self): + """The number of replicas for each host.""" + if self._internal_ctx.model_parallelism_enabled: + raise ValueError( + 'num_of_replicas_per_host is not supported for model_parallelism') + return self._internal_ctx.num_of_replicas_per_host + + @property + def device_assignment(self): + """Returns device_assignment object.""" + if self._call_from_input_fn: + raise RuntimeError('This TPUContext instance must not be called from' + ' input_fn.') + return self._internal_ctx.device_assignment + def device_for_replica(self, replica_id): """Returns the tuple of (CPU device and device ordinal) for replica. @@ -106,24 +146,7 @@ class TPUContext(object): # Note that: For the non-model parallelism, the mapping could be # a random permutation. The order should not matter in most cases # as far as model is replicated to all cores in the system. - - # If the precise replica_id to device mapping is required, please - # set the computation_shape as [1,1,1] in TPUConfig to enable - # the model parallelism. - if self._internal_ctx.model_parallelism_enabled: - return RuntimeError( - 'device_for_replica is not yet implemented for model parallelism. ' - 'b/79689078.') - - master = self._internal_ctx.master_job - job_device = '' if master is None else ('/job:%s' % master) - - num_of_replicas_per_host = self._internal_ctx.num_of_replicas_per_host - host_id = replica_id / num_of_replicas_per_host - ordinal_id = replica_id % num_of_replicas_per_host - - host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id) - return (host_device, ordinal_id) + return self._internal_ctx.device_for_replica(replica_id) class _InternalTPUContext(object): @@ -162,9 +185,14 @@ class _InternalTPUContext(object): self._eval_on_tpu = eval_on_tpu self._model_parallelism_enabled = ( - use_tpu and config.tpu_config.computation_shape) + use_tpu and config.tpu_config.num_cores_per_replica) self._mode = None - + num_cores_per_replica = config.tpu_config.num_cores_per_replica + if num_cores_per_replica: + self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[ + num_cores_per_replica] + else: + self._computation_shape = None self._lazy_tpu_system_metadata_dict = {} # key by master address self._lazy_device_assignment_dict = {} # key by master address self._lazy_validation_dict = {} # key by ModeKeys @@ -204,11 +232,16 @@ class _InternalTPUContext(object): if tpu_system_metadata is not None: return tpu_system_metadata + cluster_def = None + if (self._config.session_config and + self._config.session_config.cluster_def.job): + cluster_def = self._config.session_config.cluster_def + # pylint: disable=protected-access tpu_system_metadata = ( tpu_system_metadata_lib._query_tpu_system_metadata( master, - run_config=self._config, + cluster_def=cluster_def, query_topology=self.model_parallelism_enabled)) self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata @@ -225,11 +258,12 @@ class _InternalTPUContext(object): device_assignment = tpu_device_assignment.device_assignment( tpu_system_metadata.topology, - computation_shape=self._config.tpu_config.computation_shape, + computation_shape=self._computation_shape, num_replicas=self.num_replicas) - logging.info('computation_shape: %s', - str(self._config.tpu_config.computation_shape)) + logging.info('num_cores_per_replica: %s', + str(self._config.tpu_config.num_cores_per_replica)) + logging.info('computation_shape: %s', str(self._computation_shape)) logging.info('num_replicas: %d', self.num_replicas) logging.info('device_assignment.topology.device_coordinates: %s', str(device_assignment.topology.device_coordinates)) @@ -243,6 +277,10 @@ class _InternalTPUContext(object): def model_parallelism_enabled(self): return self._model_parallelism_enabled + @property + def input_partition_dims(self): + return self._config.tpu_config.input_partition_dims + @property def device_assignment(self): return (self._get_device_assignment() @@ -270,23 +308,20 @@ class _InternalTPUContext(object): num_cores_in_system = self.num_cores if self.model_parallelism_enabled: - computation_shape_array = np.asarray( - self._config.tpu_config.computation_shape, dtype=np.int32) - num_cores_per_replica = np.prod(computation_shape_array) + num_cores_per_replica = self._config.tpu_config.num_cores_per_replica if num_cores_per_replica > num_cores_in_system: raise ValueError( 'The num of cores required by the model parallelism, specified by ' - 'TPUConfig.computation_shape, is larger than the total num of ' - 'TPU cores in the system. computation_shape: {}, num cores ' - 'in the system: {}'.format( - self._config.tpu_config.computation_shape, - num_cores_in_system)) + 'TPUConfig.num_cores_per_replica, is larger than the total num of ' + 'TPU cores in the system. num_cores_per_replica: {}, num cores ' + 'in the system: {}'.format(num_cores_per_replica, + num_cores_in_system)) if num_cores_in_system % num_cores_per_replica != 0: raise RuntimeError( 'The num of cores in the system ({}) is not divisible by the num ' 'of cores ({}) required by the model parallelism, specified by ' - 'TPUConfig.computation_shape. This should never happen!'.format( + 'TPUConfig.num_cores_per_replica. This should never happen!'.format( num_cores_in_system, num_cores_per_replica)) return num_cores_in_system // num_cores_per_replica @@ -314,6 +349,11 @@ class _InternalTPUContext(object): return (self._config.tpu_config.per_host_input_for_training is tpu_config.InputPipelineConfig.PER_HOST_V2) + def is_input_broadcast_with_iterators(self): + """Return true if input_fn should be run in the full_replicae config.""" + return (self._config.tpu_config.per_host_input_for_training is + tpu_config.InputPipelineConfig.BROADCAST) + def is_running_on_cpu(self, is_export_mode=False): """Determines whether the input_fn and model_fn should be invoked on CPU. @@ -378,15 +418,13 @@ class _InternalTPUContext(object): """Returns the shard batch size for `input_fn`.""" global_batch_size = self.global_batch_size - if self.is_running_on_cpu(): + if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()): return global_batch_size # On TPU if self.is_input_sharded_per_core() or ( self.is_input_per_host_with_iterators()): - # We prohibit per core input sharding for the model parallelism case, - # therefore it is safe to use num_cores here. - return global_batch_size // self.num_cores + return global_batch_size // self.num_replicas else: return global_batch_size // self.num_hosts @@ -395,7 +433,7 @@ class _InternalTPUContext(object): """Returns the shard batch size for `model_fn`.""" global_batch_size = self.global_batch_size - if self.is_running_on_cpu(): + if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()): return global_batch_size # On TPU. always sharded per shard. @@ -452,17 +490,23 @@ class _InternalTPUContext(object): master = self.master_job - def _placement_function(_sentinal=None, core_id=None, host_id=None): # pylint: disable=invalid-name + def _placement_function(_sentinal=None, replica_id=None, host_id=None): # pylint: disable=invalid-name + """Return the host device given replica_id or host_id.""" assert _sentinal is None - if core_id is not None and host_id is not None: + if replica_id is not None and host_id is not None: raise RuntimeError( - 'core_id and host_id can have only one non-None value.') + 'replica_id and host_id can have only one non-None value.') if master is None: return '/replica:0/task:0/device:CPU:0' else: - if core_id is not None: - host_id = core_id / self.num_of_cores_per_host + if replica_id is not None: + if self.model_parallelism_enabled: + return self.device_assignment.host_device( + replica=replica_id, job=master) + else: + host_id = replica_id / self.num_of_cores_per_host + return '/job:%s/task:%d/device:CPU:0' % (master, host_id) return _placement_function @@ -484,25 +528,27 @@ class _InternalTPUContext(object): return _placement_function - @property - def tpu_ordinal_function(self): + def tpu_ordinal_function(self, host_id): """Returns the TPU ordinal fn.""" - def _tpu_ordinal_function(index): + def _tpu_ordinal_function(shard_index_in_host): """Return the TPU ordinal associated with a shard. Required because the enqueue ops are placed on CPU. Args: - index: the shard index + shard_index_in_host: the shard index Returns: The ordinal of the TPU device the shard's infeed should be placed on. """ if self.model_parallelism_enabled: - return self.device_assignment.tpu_ordinal(replica=index) + # We put both enqueue/dequeue ops at tpu.core(0) in each replica. + replica = self.device_assignment.lookup_replicas( + host_id, (0, 0, 0))[shard_index_in_host] + return self.device_assignment.tpu_ordinal(replica=replica) else: - return index % self.num_of_cores_per_host + return shard_index_in_host % self.num_of_cores_per_host return _tpu_ordinal_function @@ -533,7 +579,7 @@ class _InternalTPUContext(object): 'be ({}), got ({}). For non-model-parallelism, num_replicas should ' 'be the total num of TPU cores in the system. For ' 'model-parallelism, the total number of TPU cores should be ' - 'product(computation_shape) * num_replicas. Please set it ' + 'num_cores_per_replica * num_replicas. Please set it ' 'accordingly or leave it as `None`'.format( self._get_master_address(), num_replicas, user_provided_num_replicas)) @@ -541,7 +587,8 @@ class _InternalTPUContext(object): raise ValueError(message) if mode == model_fn_lib.ModeKeys.TRAIN: - if self._train_batch_size % num_replicas != 0: + if (self._train_batch_size % num_replicas != 0 and + not self.is_input_broadcast_with_iterators()): raise ValueError( 'train batch size {} must be divisible by number of replicas {}' .format(self._train_batch_size, num_replicas)) @@ -551,11 +598,12 @@ class _InternalTPUContext(object): raise ValueError( 'eval_batch_size in TPUEstimator constructor cannot be `None`' 'if .evaluate is running on TPU.') - if self._eval_batch_size % num_replicas != 0: + if (self._eval_batch_size % num_replicas != 0 and + not self.is_input_broadcast_with_iterators()): raise ValueError( 'eval batch size {} must be divisible by number of replicas {}' .format(self._eval_batch_size, num_replicas)) - if num_hosts > 1: + if num_hosts > 1 and not self.is_input_broadcast_with_iterators(): raise ValueError( 'TPUEstimator.evaluate should be running on single TPU worker. ' 'got {}.'.format(num_hosts)) @@ -565,11 +613,12 @@ class _InternalTPUContext(object): raise ValueError( 'predict_batch_size in TPUEstimator constructor should not be ' '`None` if .predict is running on TPU.') - if self._predict_batch_size % num_replicas != 0: + if (self._predict_batch_size % num_replicas != 0 and + not self.is_input_broadcast_with_iterators()): raise ValueError( 'predict batch size {} must be divisible by number of replicas {}' .format(self._predict_batch_size, num_replicas)) - if num_hosts > 1: + if num_hosts > 1 and not self.is_input_broadcast_with_iterators(): raise ValueError( 'TPUEstimator.predict should be running on single TPU worker. ' 'got {}.'.format(num_hosts)) @@ -577,6 +626,33 @@ class _InternalTPUContext(object): # Record the state "validated" into lazy dictionary. self._lazy_validation_dict[mode] = True + def device_for_replica(self, replica_id): + """Returns the tuple of (CPU device and device ordinal) for replica. + + This should be used for full replicate for non-model-parallelism. + + Args: + replica_id: Int, the replica index. + + Returns: + A tuple of device spec for CPU device and int device ordinal. + """ + master = self.master_job + + if self.model_parallelism_enabled: + return (self.device_assignment.host_device( + replica=replica_id, job=master), + self.device_assignment.tpu_ordinal(replica=replica_id)) + + job_device = '' if master is None else ('/job:%s' % master) + + num_of_replicas_per_host = self.num_of_replicas_per_host + host_id = replica_id / num_of_replicas_per_host + ordinal_id = replica_id % num_of_replicas_per_host + + host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id) + return (host_device, ordinal_id) + class _OneCoreTPUContext(_InternalTPUContext): """Special _InternalTPUContext for one core usage.""" @@ -612,7 +688,7 @@ def _get_tpu_context(config, train_batch_size, eval_batch_size, """Returns an instance of `_InternalTPUContext`.""" if (config.tpu_config.num_shards == 1 and - config.tpu_config.computation_shape is None): + config.tpu_config.num_cores_per_replica is None): logging.warning( 'Setting TPUConfig.num_shards==1 is an unsupported behavior. ' 'Please fix as soon as possible (leaving num_shards as None.') diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index e94bd78833f6cbe9adb1b6ca3f29a88bd8a53f64..029492b489ea2b790660d7a02dfd189451acf26c 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -22,9 +22,9 @@ import collections import copy import os import signal +import sys import threading import time -import traceback import numpy as np import six @@ -32,6 +32,7 @@ from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import error_handling from tensorflow.contrib.tpu.python.tpu import session_support from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_config @@ -81,12 +82,17 @@ _TPU_ESTIMATOR = 'tpu_estimator' _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' _BATCH_SIZE_KEY = 'batch_size' _CTX_KEY = 'context' +_USE_TPU_KEY = 'use_tpu' _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _ONE_GIGABYTE = 1024 * 1024 * 1024 _TPU_ENQUEUE_OPS = '_tpu_enqueue_ops' _TPU_TRAIN_OP = '_tpu_train_op' _REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference' +# Ideally _USE_TPU_KEY should be reserved as well. However there are already +# models that make use of this key, thus it can not be reserved now to prevent +# breakage. In the long run, we would like to mitigate this by migrating models +# off of using _USE_TPU_KEY. _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY] @@ -211,14 +217,14 @@ class _SIGNAL(object): class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`. - See `EstimatorSpec` for `mode`, 'predictions, 'loss', 'train_op', and - 'export_outputs`. + See `EstimatorSpec` for `mode`, `predictions`, `loss`, `train_op`, and + `export_outputs`. For evaluation, `eval_metrics `is a tuple of `metric_fn` and `tensors`, where `metric_fn` runs on CPU to generate metrics and `tensors` represents the `Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`. To be precise, TPU evaluation expects a slightly different signature from the - @{tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a + `tf.estimator.Estimator`. While `EstimatorSpec.eval_metric_ops` expects a dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`. The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The `tensors` usually specify the model logits, which are transferred back from @@ -226,7 +232,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote size is the first dimension. Once all tensors are available at CPU host from all shards, they are concatenated (on CPU) and passed as positional arguments to the `metric_fn` if `tensors` is list or keyword arguments if `tensors` is - dict. `metric_fn` takes the `tensors` and returns a dict from metric string + a dict. `metric_fn` takes the `tensors` and returns a dict from metric string name to the result of calling a metric function, namely a `(metric_tensor, update_op)` tuple. See `TPUEstimator` for MNIST example how to specify the `eval_metrics`. @@ -241,7 +247,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote sending tensors from TPU to CPU. To reduce the overhead, try reducing the size of the tensors. The `tensors` are concatenated along their major (batch) dimension, and so must be >= rank 1. The `host_call` is useful for writing - summaries with @{tf.contrib.summary.create_file_writer}. + summaries with `tf.contrib.summary.create_file_writer`. """ def __new__(cls, @@ -252,7 +258,10 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote eval_metrics=None, export_outputs=None, scaffold_fn=None, - host_call=None): + host_call=None, + training_hooks=None, + evaluation_hooks=None, + prediction_hooks=None): """Creates a validated `TPUEstimatorSpec` instance.""" host_calls = {} if eval_metrics is not None: @@ -260,6 +269,17 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote if host_call is not None: host_calls['host_call'] = host_call _OutfeedHostCall.validate(host_calls) + + training_hooks = list(training_hooks or []) + evaluation_hooks = list(evaluation_hooks or []) + prediction_hooks = list(prediction_hooks or []) + + for hook in training_hooks + evaluation_hooks + prediction_hooks: + if not isinstance(hook, session_run_hook.SessionRunHook): + raise TypeError( + 'All hooks must be SessionRunHook instances, given: {}'.format( + hook)) + return super(TPUEstimatorSpec, cls).__new__( cls, mode=mode, @@ -269,7 +289,10 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote eval_metrics=eval_metrics, export_outputs=export_outputs, scaffold_fn=scaffold_fn, - host_call=host_call) + host_call=host_call, + training_hooks=training_hooks, + evaluation_hooks=evaluation_hooks, + prediction_hooks=prediction_hooks) def as_estimator_spec(self): """Creates an equivalent `EstimatorSpec` used by CPU train/eval.""" @@ -285,6 +308,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote hooks = None if self.host_call is not None: hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])] + hooks = list(hooks or []) scaffold = self.scaffold_fn() if self.scaffold_fn else None return model_fn_lib.EstimatorSpec( mode=self.mode, @@ -294,9 +318,9 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote eval_metric_ops=eval_metric_ops, export_outputs=self.export_outputs, scaffold=scaffold, - training_hooks=hooks, - evaluation_hooks=hooks, - prediction_hooks=hooks) + training_hooks=self.training_hooks + hooks, + evaluation_hooks=self.evaluation_hooks + hooks, + prediction_hooks=self.prediction_hooks + hooks) class _OpQueueContext(object): @@ -360,17 +384,17 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): ctx, enqueue_ops, dequeue_ops, - run_infeed_loop_on_coordinator=True): + run_infeed_loop_on_coordinator=True, + rendezvous=None): self._master_job = ctx.master_job self._enqueue_ops = enqueue_ops self._dequeue_ops = dequeue_ops + self._rendezvous = rendezvous self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator self._initial_infeed_sleep_secs = ( ctx.config.tpu_config.initial_infeed_sleep_secs) - self._session_cancel_timer = None - self._feed_error = None self._finished = False @@ -387,62 +411,6 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): for op in summary_writer_init_ops: self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) - def _log_error(self, session, error): - """Log an infeed or outfeed error. - - This logs a short error message immediately, and schedules a timer to - emit the full stack trace and error message after a short period of time. - If the main session has terminated by the time the timer triggers, we - assume the real source of the error was from the main session and avoid - emitting a stack trace for the infeed. - - Args: - session: `tf.Session`, session to be terminated error: exception that - triggered logging. - error: the Exception to log. - """ - logging.warning( - '\n\n' - 'Error occurred during infeed/outfeed. This may be due to a compile ' - 'error in the main session. Waiting for a short time for the main ' - 'session to come back.\n\n%s', error) - - self._feed_error = traceback.format_exc() - - # If we've already encountered a feed error, don't schedule another - # cancellation op. - if self._session_cancel_timer: - return - - def _cancel_session(): - """Close the session to avoid the main thread from hanging. - - If input pipeline triggers any error, the infeed thread dies but the main - thread for TPU computation waits for the infeed enqueue forever. Close the - Session to cancel the main thread Session.run execution. - - We sleep for a few seconds before closing to give some time for the TPU - compilation error, if any, propagating, from TPU to CPU host. Compilation - errors should be reported by the main thread so that the program can be - interrupted and users can take action. Due to a race condition, the - infeed thread might see an error first. Closing the session here - immediately would result in a session cancellation exception in the main - thread, instead of the expected compile error. User code that depends on - having the proper exception type will therefore be confused. - """ - time.sleep(5) - - # If the main session is still running, the infeed/outfeed errors are - # legitimate, and should be logged. - if not self._finished and self._feed_error: - logging.error('Feed error: %s', self._feed_error) - logging.error('Closing session. A RuntimeError should follow.') - session.close() - - self._session_cancel_timer = threading.Thread(target=_cancel_session) - self._session_cancel_timer.daemon = True - self._session_cancel_timer.start() - def _run_infeed(self, queue_ctx, session): logging.info('Starting infeed thread controller.') if self._initial_infeed_sleep_secs: @@ -451,7 +419,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): time.sleep(self._initial_infeed_sleep_secs) logging.info('%s thread starting after sleep', self._name) - try: + with self._rendezvous.catch_errors(source='infeed', session=session): if self._run_infeed_loop_on_coordinator: for count, steps in enumerate(queue_ctx.read_iteration_counts()): for i in xrange(steps): @@ -461,19 +429,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): for _ in queue_ctx.read_iteration_counts(): session.run(self._enqueue_ops) logging.info('Infeed thread finished, shutting down.') - except Exception as e: # pylint: disable=broad-except - self._log_error(session, e) def _run_outfeed(self, queue_ctx, session): logging.info('Starting outfeed thread controller.') - try: + with self._rendezvous.catch_errors(source='outfeed', session=session): for count, steps in enumerate(queue_ctx.read_iteration_counts()): for i in xrange(steps): logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i) session.run(self._dequeue_ops) logging.info('Outfeed thread finished, shutting down.') - except Exception as e: # pylint: disable=broad-except - self._log_error(session, e) def _create_infeed_controller(self, name, target, args): return _OpQueueContext(name=name, target=target, args=args) @@ -492,11 +456,6 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): def before_run(self, run_context): self._feed_error = None - # Wait for the cancellation timer to complete before continuing. - if self._session_cancel_timer: - self._session_cancel_timer.join() - self._session_cancel_timer = None - iterations = run_context.session.run(self._iterations_per_loop_var) logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations) @@ -507,16 +466,14 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): self._outfeed_controller.send_next_batch_signal(iterations) def end(self, session): - if self._session_cancel_timer: - logging.warning('Feed error occurred; waiting for message.') - self._session_cancel_timer.join() - self._finished = True logging.info('Stop infeed thread controller') self._infeed_controller.join() + self._rendezvous.record_done('infeed') logging.info('Stop output thread controller') self._outfeed_controller.join() + self._rendezvous.record_done('outfeed') logging.info('Shutdown TPU system.') session.run(self._finalize_ops) @@ -524,9 +481,10 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook): - def __init__(self, ctx, enqueue_ops, dequeue_ops): + def __init__(self, ctx, enqueue_ops, dequeue_ops, rendezvous=None): super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__( - ctx, enqueue_ops, dequeue_ops, run_infeed_loop_on_coordinator=False) + ctx, enqueue_ops, dequeue_ops, run_infeed_loop_on_coordinator=False, + rendezvous=rendezvous) def _create_infeed_controller(self, name, target, args): return _OpSignalOnceQueueContext(name=name, target=target, args=args) @@ -664,6 +622,7 @@ def generate_per_core_enqueue_ops_fn_for_host( ctx, input_fn, inputs_structure_recorder, host_device, host_id): """Generates infeed enqueue ops for per-core input_fn on a single host.""" captured_infeed_queue = _CapturedObject() + tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) def enqueue_ops_fn(): """A fn returns enqueue_ops.""" @@ -695,11 +654,9 @@ def generate_per_core_enqueue_ops_fn_for_host( infeed_queue = tpu_feed.InfeedQueue( number_of_tuple_elements=len(per_host_sharded_inputs[0])) captured_infeed_queue.capture(infeed_queue) - infeed_queue.set_configuration_from_sharded_input_tensors( - per_host_sharded_inputs) per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, tpu_ordinal_function=ctx.tpu_ordinal_function) + per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) return per_host_enqueue_ops return enqueue_ops_fn, captured_infeed_queue @@ -734,19 +691,7 @@ def generate_per_host_enqueue_ops_fn_for_host( if is_dataset: hooks.append(inputs.dataset_initializer_hook()) - # TODO(ylc): Refactoring the code to merge the tpu ordinal logic here and the - # _InternalTPUContext.tpu_ordinal_function. We should either introduce another - # abstraction or a different helper method. - def _tpu_ordinal_function_impl(shard_index_in_host): - # We put both enqueue/dequeue op at tpu.core(0) in each replica. - replica = ctx.device_assignment.lookup_replicas( - host_id, (0, 0, 0))[shard_index_in_host] - return ctx.device_assignment.tpu_ordinal(replica=replica) - - if ctx.model_parallelism_enabled: - tpu_ordinal_function = _tpu_ordinal_function_impl - else: - tpu_ordinal_function = None + tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) def enqueue_ops_fn(): """A Fn returning the TPU infeed enqueue ops. @@ -782,7 +727,7 @@ def generate_per_host_enqueue_ops_fn_for_host( infeed_queue.split_inputs_and_generate_enqueue_ops( unsharded_tensor_list, placement_function=lambda x: device, - tpu_ordinal_function=tpu_ordinal_function)) + tpu_ordinal_function=tpu_ordinal_function_impl)) if signals is None: return per_host_enqueue_ops else: @@ -816,6 +761,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( raise TypeError('Most PREDICT not yet supported in PER_HOST_V2 mode.') hooks.append(inputs.dataset_initializer_hook()) + tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) def enqueue_ops_fn(): """Generates the per_host enqueue ops.""" @@ -835,19 +781,105 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( flattened_inputs = ( inputs_structure_recorder.flatten_features_and_labels( features, labels)) - control_deps.extend(flattened_inputs) per_host_sharded_inputs.append(flattened_inputs) + if inputs_structure_recorder.flattened_input_dims: + # pylint: disable=protected-access + infeed_queue = tpu_feed._PartitionedInfeedQueue( + number_of_tuple_elements=len(per_host_sharded_inputs[0]), + host_id=host_id, + input_partition_dims=inputs_structure_recorder.flattened_input_dims, + device_assignment=ctx.device_assignment) + per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( + per_host_sharded_inputs) + else: + infeed_queue = tpu_feed.InfeedQueue( + number_of_tuple_elements=len(per_host_sharded_inputs[0])) + per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( + per_host_sharded_inputs, + tpu_ordinal_function=tpu_ordinal_function_impl) + captured_infeed_queue.capture(infeed_queue) + + return per_host_enqueue_ops + + return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset + + +def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder, + num_hosts): + """Generates infeed enqueue ops for one input_fn on all the hosts.""" + captured_infeed_queue = _CapturedObject() + hooks = [] + device_0 = ctx.tpu_host_placement_function(host_id=0) + with ops.device(device_0): + user_context = tpu_context.TPUContext( + internal_ctx=ctx, input_device=device_0, invocation_index=0) + inputs = _Inputs.from_input_fn(input_fn(user_context)) + + is_dataset = inputs.is_dataset + if ctx.mode == model_fn_lib.ModeKeys.PREDICT: + if not is_dataset: + raise TypeError( + 'For mode PREDICT, `input_fn` must return `Dataset` instead of ' + '`features` and `labels`.') + + inputs = _InputsWithStoppingSignals( + dataset=inputs.dataset, + batch_size=ctx.batch_size_for_input_fn, + add_padding=True) + + if is_dataset: + hooks.append(inputs.dataset_initializer_hook()) + num_replicas_per_host = ctx.num_of_replicas_per_host + + def tpu_ordinal_function_impl(replica_id): + if ctx.device_assignment: + return ctx.device_assignment.tpu_ordinal(replica=replica_id) + else: + return replica_id % num_replicas_per_host + + def device_function_impl(replica_id): + return ctx.tpu_host_placement_function(replica_id=replica_id) + + def enqueue_ops_fn(): + """Generates enqueue ops for all the hosts.""" + broadcasted_inputs = [] + flattened_inputs = None # Cache result from input_fn. + signals = None + for host_id in xrange(num_hosts): + with ops.device(ctx.tpu_host_placement_function(host_id=host_id)): + for _ in xrange(ctx.num_of_replicas_per_host): + # Note: input_fn is only called once at host 0 for the first replica. + # The features and labels returned from that invocation are + # broadcasted to other replicas(including the replicas on other + # hosts). + if flattened_inputs is None: + features, labels = inputs.features_and_labels() # Calls get_next() + signals = inputs.signals() + + inputs_structure_recorder.validate_and_record_structure( + features, labels, signals) + flattened_inputs = ( + inputs_structure_recorder.flatten_features_and_labels( + features, labels, signals)) + broadcasted_inputs.append(flattened_inputs) + infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(per_host_sharded_inputs[0])) + number_of_tuple_elements=len(broadcasted_inputs[0])) captured_infeed_queue.capture(infeed_queue) - infeed_queue.set_configuration_from_sharded_input_tensors( - per_host_sharded_inputs) + enqueue_ops = infeed_queue.generate_enqueue_ops( + broadcasted_inputs, + tpu_ordinal_function=tpu_ordinal_function_impl, + placement_function=device_function_impl) - per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, tpu_ordinal_function=ctx.tpu_ordinal_function) - return per_host_enqueue_ops + if signals is None: + return enqueue_ops + else: + return { + 'ops': enqueue_ops, + 'signals': signals, + } return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset @@ -885,21 +917,68 @@ class _InputPipeline(object): class InputsStructureRecorder(object): """The recorder to record inputs structure.""" - def __init__(self): + def __init__(self, input_partition_dims=None): # Holds the structure of inputs self._feature_names = [] self._label_names = [] self._has_labels = False self._signals_helper = None + self._flattened_input_dims = None + + if input_partition_dims: + # This should have been validated in TPUConfig. + assert len(input_partition_dims) <= 2, 'must have 1 or 2 elements.' + if len(input_partition_dims) == 2: + self._feature_dims, self._label_dims = input_partition_dims + else: + self._feature_dims = input_partition_dims[0] + self._label_dims = None + + assert self._feature_dims is not None, ('input_partition_dims[0] must ' + 'not be None') + else: + self._feature_dims = None + self._label_dims = None # Internal state. self._initialized = False + @property + def flattened_input_dims(self): + assert self._initialized, 'InputsStructureRecorder is not initialized.' + return self._flattened_input_dims + def has_labels(self): return self._has_labels + def _flatten_input_dims(self, feature_dims, feature_dims_names, label_dims, + label_dims_names, label_names, has_labels): + """Flatten input dims with the same order as flattened input tensors.""" + flattened_input_dims = [] + if feature_dims_names: + # We need a fixed ordering for matching the tensors in features. + flattened_input_dims.extend( + [feature_dims[name] for name in feature_dims_names]) + else: + flattened_input_dims.append(feature_dims) + + if label_dims_names: + # We need a fixed ordering for matching the tensors in labels. + flattened_input_dims.extend( + [label_dims[name] for name in label_dims_names]) + else: + if label_names: + num_tensors_in_label = len(label_names) + else: + num_tensors_in_label = int(has_labels) + # Setting `None` in input_partition_dims[1] will apply `None` to + # all the tensors in labels, regardless of internal structure. + flattened_input_dims.extend([label_dims] * num_tensors_in_label) + + return flattened_input_dims + def validate_and_record_structure(self, features, labels, signals=None): - """Validates and records the structure of features` and `labels`.""" + """Validates and records the structure of `features` and `labels`.""" def _extract_key_names(tensor_or_dict): if tensor_or_dict is None: @@ -927,6 +1006,24 @@ class _InputPipeline(object): self._feature_names = feature_names self._label_names = label_names self._has_labels = has_labels + if self._feature_dims is not None: + feature_dims_names = _extract_key_names(self._feature_dims) + if feature_dims_names != feature_names: + raise ValueError( + 'TPUConfig.input_partition_dims[0] mismatched feature' + ' keys. Expected {}, got {}'.format(feature_names, + feature_dims_names)) + + label_dims_names = _extract_key_names(self._label_dims) + if self._label_dims is not None and label_dims_names != label_names: + raise ValueError( + 'TPUConfig.input_partition_dims[1] mismatched label' + ' keys. Expected {}, got {}'.format(label_names, + label_dims_names)) + + self._flattened_input_dims = self._flatten_input_dims( + self._feature_dims, feature_dims_names, self._label_dims, + label_dims_names, label_names, has_labels) def flatten_features_and_labels(self, features, labels, signals=None): """Flattens the `features` and `labels` to a single tensor list.""" @@ -1021,7 +1118,8 @@ class _InputPipeline(object): Raises: ValueError: If both `sharded_features` and `num_cores` are `None`. """ - self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder() + self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder( + ctx.input_partition_dims) self._sharded_per_core = ctx.is_input_sharded_per_core() self._input_fn = input_fn @@ -1084,6 +1182,24 @@ class _InputPipeline(object): # Infeed_queue_getter must be called after enqueue_ops_fn is called. infeed_queues.append(captured_infeed_queue.get()) + elif self._ctx.is_input_broadcast_with_iterators(): + # Only calls input_fn in host 0. + host_device = tpu_host_placement_fn(host_id=0) + enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = ( + generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn, + self._inputs_structure_recorder, + num_hosts)) + all_hooks.extend(hooks) + if is_dataset: + run_infeed_loop_on_coordinator = False + wrap_fn = ( + _wrap_computation_in_while_loop + if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else + _wrap_computation_in_while_loop_with_stopping_signals) + enqueue_ops.append(wrap_fn(device=host_device, op_fn=enqueue_ops_fn)) + else: + enqueue_ops.append(enqueue_ops_fn()) + infeed_queues.append(captured_infeed_queue.get()) else: for host_id in range(num_hosts): host_device = tpu_host_placement_fn(host_id=host_id) @@ -1146,7 +1262,7 @@ class _InputPipeline(object): err_msg = ('Input pipeline contains one or more QueueRunners. ' 'It could be slow and not scalable. Please consider ' 'converting your input pipeline to use `tf.data` instead (see ' - 'https://www.tensorflow.org/programmers_guide/datasets for ' + 'https://www.tensorflow.org/guide/datasets for ' 'instructions.') if _WRAP_INPUT_FN_INTO_WHILE_LOOP: raise RuntimeError(err_msg) @@ -1198,6 +1314,7 @@ class _ModelFnWrapper(object): host_call = _OutfeedHostCall(self._ctx) captured_scaffold_fn = _CapturedObject() + captured_training_hooks = _CapturedObject() def train_step(loss): """Training step function for use inside a while loop.""" @@ -1214,6 +1331,8 @@ class _ModelFnWrapper(object): else: captured_scaffold_fn.capture(None) + captured_training_hooks.capture(estimator_spec.training_hooks) + # We must run train_op to update the variables prior to running the # outfeed. with ops.control_dependencies([train_op]): @@ -1225,7 +1344,8 @@ class _ModelFnWrapper(object): with ops.control_dependencies(host_call_outfeed_ops): return array_ops.identity(loss) - return train_step, host_call, captured_scaffold_fn + return (train_step, host_call, captured_scaffold_fn, + captured_training_hooks) def convert_to_single_tpu_eval_step(self, dequeue_fn): """Converts user provided model_fn` as a single eval step on TPU. @@ -1255,6 +1375,7 @@ class _ModelFnWrapper(object): """ host_calls = _OutfeedHostCall(self._ctx) captured_scaffold_fn = _CapturedObject() + captured_eval_hooks = _CapturedObject() def eval_step(total_loss): """Evaluation step function for use inside a while loop.""" @@ -1269,8 +1390,11 @@ class _ModelFnWrapper(object): loss = tpu_estimator_spec.loss captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) + captured_eval_hooks.capture(tpu_estimator_spec.evaluation_hooks) + to_record = {} - to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics + if tpu_estimator_spec.eval_metrics: + to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics if tpu_estimator_spec.host_call is not None: # We assume that evaluate won't update global step, so we don't wrap # this host_call. @@ -1280,7 +1404,7 @@ class _ModelFnWrapper(object): with ops.control_dependencies(host_calls.create_enqueue_op()): return math_ops.add(total_loss, loss) - return eval_step, host_calls, captured_scaffold_fn + return eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks def convert_to_single_tpu_predict_step(self, dequeue_fn): """Converts user provided model_fn` as a single predict step on TPU. @@ -1295,6 +1419,7 @@ class _ModelFnWrapper(object): """ host_calls = _OutfeedHostCall(self._ctx) captured_scaffold_fn = _CapturedObject() + captured_predict_hooks = _CapturedObject() def predict_step(unused_scalar_stopping_signal): """Evaluation step function for use inside a while loop.""" @@ -1315,6 +1440,7 @@ class _ModelFnWrapper(object): self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions) captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) + captured_predict_hooks.capture(tpu_estimator_spec.prediction_hooks) to_record = {} identity_fn = lambda **kwargs: kwargs to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions] @@ -1326,7 +1452,8 @@ class _ModelFnWrapper(object): with ops.control_dependencies(host_calls.create_enqueue_op()): return _StopSignals.as_scalar_stopping_signal(stopping_signals) - return predict_step, host_calls, captured_scaffold_fn + return (predict_step, host_calls, captured_scaffold_fn, + captured_predict_hooks) def _verify_tpu_spec_predictions(self, predictions): """Validates TPUEstimatorSpec.predictions dict.""" @@ -1424,8 +1551,16 @@ class _ModelFnWrapper(object): if batch_size_for_model_fn is not None: _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn) + running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode) + _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu) + + if not running_on_cpu: + user_context = tpu_context.TPUContext( + internal_ctx=self._ctx, call_from_input_fn=False) + _add_item_to_params(params, _CTX_KEY, user_context) + estimator_spec = self._model_fn(features=features, **kwargs) - if (self._ctx.is_running_on_cpu(is_export_mode) and + if (running_on_cpu and isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access # The estimator_spec will be passed to `Estimator` directly, which expects # type `EstimatorSpec`. @@ -1440,11 +1575,9 @@ class _ModelFnWrapper(object): err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.' if estimator_spec.training_chief_hooks: - raise ValueError(err_msg.format('training_chief_hooks')) - if estimator_spec.training_hooks: - raise ValueError(err_msg.format('training_hooks')) - if estimator_spec.evaluation_hooks: - raise ValueError(err_msg.format('evaluation_hooks')) + raise ValueError( + err_msg.format('training_chief_hooks') + 'If you want' + + ' to pass training hooks, please pass via training_hooks.') if estimator_spec.scaffold: logging.warning('EstimatorSpec.Scaffold is ignored by TPU train/eval. ' @@ -1565,7 +1698,7 @@ class _OutfeedHostCall(object): RuntimeError: If outfeed tensor is scalar. """ if not self._names: - return [] + return {} ret = {} # For each i, dequeue_ops[i] is a list containing the tensors from all @@ -1584,11 +1717,13 @@ class _OutfeedHostCall(object): # Outfeed ops execute on each replica's first logical core. Note: we must # constraint it such that we have at most one outfeed dequeue and enqueue # per replica. - tpu_device_placement_fn = self._ctx.tpu_device_placement_function for i in xrange(self._ctx.num_replicas): - with ops.device(tpu_device_placement_fn(i)): + host_device, ordinal_id = self._ctx.device_for_replica(i) + with ops.device(host_device): outfeed_tensors = tpu_ops.outfeed_dequeue_tuple( - dtypes=tensor_dtypes, shapes=tensor_shapes) + dtypes=tensor_dtypes, + shapes=tensor_shapes, + device_ordinal=ordinal_id) for j, item in enumerate(outfeed_tensors): dequeue_ops[j].append(item) @@ -1603,7 +1738,7 @@ class _OutfeedHostCall(object): # place all ops on tpu host if possible. # # TODO(jhseu): Evaluate whether this is right for summaries. - with ops.device(self._ctx.tpu_host_placement_function(core_id=0)): + with ops.device(self._ctx.tpu_host_placement_function(replica_id=0)): for name in self._names: dequeue_ops = dequeue_ops_by_name[name] for i, item in enumerate(dequeue_ops): @@ -1712,6 +1847,9 @@ class InstallSignalHandlerHook(session_run_hook.SessionRunHook): class TPUEstimator(estimator_lib.Estimator): """Estimator with TPU support. + TPUEstimator also supports training on CPU and GPU. You don't need to define + a separate `tf.estimator.Estimator`. + TPUEstimator handles many of the details of running on TPU devices, such as replicating inputs and models for each core, and returning to host periodically to run hooks. @@ -1749,7 +1887,8 @@ class TPUEstimator(estimator_lib.Estimator): Current limitations: -------------------- - 1. TPU evaluation only works on a single host (one TPU worker). + 1. TPU evaluation only works on a single host (one TPU worker) except + BROADCAST mode. 2. `input_fn` for evaluation should **NOT** raise an end-of-input exception (`OutOfRangeError` or `StopIteration`). And all evaluation steps and all @@ -1920,10 +2059,9 @@ class TPUEstimator(estimator_lib.Estimator): """Constructs an `TPUEstimator` instance. Args: - model_fn: Model function as required by `Estimator`. For training, the - returned `EstimatorSpec` cannot have hooks as it is not supported in - `TPUEstimator`. Instead, the user can pass the training hooks as - an argument to `TPUEstimator.train()`. + model_fn: Model function as required by `Estimator` which returns + EstimatorSpec or TPUEstimatorSpec. `training_hooks`, 'evaluation_hooks', + and `prediction_hooks` must not capure any TPU Tensor inside the model_fn. model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. If `None`, the model_dir in @@ -1988,7 +2126,7 @@ class TPUEstimator(estimator_lib.Estimator): if (config.tpu_config.per_host_input_for_training is tpu_config.InputPipelineConfig.PER_SHARD_V1 and - config.tpu_config.computation_shape): + config.tpu_config.num_cores_per_replica): raise ValueError( 'Model parallelism only supports per host input for training. ' 'Please adjust TPURunconfig.per_host_input_for_training.') @@ -2035,6 +2173,7 @@ class TPUEstimator(estimator_lib.Estimator): self._export_to_tpu = export_to_tpu self._is_input_fn_invoked = None + self._rendezvous = {} def _add_meta_graph_for_mode(self, builder, @@ -2043,24 +2182,29 @@ class TPUEstimator(estimator_lib.Estimator): strip_default_attrs, save_variables=True, mode=model_fn_lib.ModeKeys.PREDICT, - export_tags=None): + export_tags=None, + check_variables=True): if mode != model_fn_lib.ModeKeys.PREDICT: raise NotImplementedError( 'TPUEstimator only handles mode PREDICT for export_savedmodel(); ' 'got {}.'.format(mode)) - super(TPUEstimator, self)._add_meta_graph_for_mode(builder, - input_receiver_fn_map, - checkpoint_path, - strip_default_attrs, - save_variables, - mode=mode) + (super(TPUEstimator, self). + _add_meta_graph_for_mode(builder, + input_receiver_fn_map, + checkpoint_path, + strip_default_attrs, + save_variables, + mode=mode, + export_tags=export_tags, + check_variables=check_variables)) if self._export_to_tpu: input_receiver_fn_map = {_REWRITE_FOR_INFERENCE_MODE: input_receiver_fn_map[mode]} export_tags = [tag_constants.SERVING, tag_constants.TPU] mode = _REWRITE_FOR_INFERENCE_MODE + # See b/110052256 for why `check_variables` is `False`. (super(TPUEstimator, self). _add_meta_graph_for_mode(builder, input_receiver_fn_map, @@ -2068,7 +2212,8 @@ class TPUEstimator(estimator_lib.Estimator): strip_default_attrs, save_variables=False, mode=mode, - export_tags=export_tags)) + export_tags=export_tags, + check_variables=False)) def _call_model_fn(self, features, labels, mode, config): if mode == _REWRITE_FOR_INFERENCE_MODE: @@ -2272,6 +2417,65 @@ class TPUEstimator(estimator_lib.Estimator): """ pass + def train(self, + input_fn, + hooks=None, + steps=None, + max_steps=None, + saving_listeners=None): + rendezvous = error_handling.ErrorRendezvous(num_sources=3) + self._rendezvous[model_fn_lib.ModeKeys.TRAIN] = rendezvous + try: + return super(TPUEstimator, self).train( + input_fn=input_fn, hooks=hooks, steps=steps, max_steps=max_steps, + saving_listeners=saving_listeners + ) + except Exception: # pylint: disable=broad-except + rendezvous.record_error('training_loop', sys.exc_info()) + finally: + rendezvous.record_done('training_loop') + rendezvous.raise_errors() + + def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, + name=None): + rendezvous = error_handling.ErrorRendezvous(num_sources=3) + self._rendezvous[model_fn_lib.ModeKeys.EVAL] = rendezvous + try: + return super(TPUEstimator, self).evaluate( + input_fn, steps=steps, hooks=hooks, checkpoint_path=checkpoint_path, + name=name + ) + except Exception: # pylint: disable=broad-except + rendezvous.record_error('evaluation_loop', sys.exc_info()) + finally: + rendezvous.record_done('evaluation_loop') + rendezvous.raise_errors() + + def predict(self, + input_fn, + predict_keys=None, + hooks=None, + checkpoint_path=None, + yield_single_examples=True): + rendezvous = error_handling.ErrorRendezvous(num_sources=3) + self._rendezvous[model_fn_lib.ModeKeys.PREDICT] = rendezvous + try: + for result in super(TPUEstimator, self).predict( + input_fn=input_fn, + predict_keys=predict_keys, + hooks=hooks, + checkpoint_path=checkpoint_path, + yield_single_examples=yield_single_examples): + yield result + except Exception: # pylint: disable=broad-except + rendezvous.record_error('prediction_loop', sys.exc_info()) + finally: + rendezvous.record_done('prediction_loop') + rendezvous.raise_errors() + + rendezvous.record_done('prediction_loop') + rendezvous.raise_errors() + def _augment_model_fn(self, model_fn, batch_axis): """Returns a new model_fn, which wraps the TPU support.""" @@ -2294,10 +2498,20 @@ class TPUEstimator(estimator_lib.Estimator): # Clear the bit. self._is_input_fn_invoked = None + # examples_hook is added to training_hooks for both CPU and TPU + # execution. + examples_hook = ExamplesPerSecondHook( + ctx.global_batch_size, + output_dir=self.model_dir, + every_n_steps=self._log_every_n_steps) + if ctx.is_running_on_cpu(is_export_mode=is_export_mode): logging.info('Running %s on CPU', mode) - return model_fn_wrapper.call_without_tpu( + estimator_spec = model_fn_wrapper.call_without_tpu( features, labels, is_export_mode=is_export_mode) + estimator_spec = estimator_spec._replace( + training_hooks=estimator_spec.training_hooks + (examples_hook,)) + return estimator_spec assert labels is None, '`labels` passed to `model_fn` must be `None`.' # TPUEstimator._call_input_fn passes `input_fn` as features to here. @@ -2316,7 +2530,7 @@ class TPUEstimator(estimator_lib.Estimator): graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op) if mode == model_fn_lib.ModeKeys.TRAIN: - loss, host_call, scaffold = ( + loss, host_call, scaffold, training_hooks = ( _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) host_ops = host_call.create_tpu_hostcall() if host_ops is None: @@ -2356,7 +2570,9 @@ class TPUEstimator(estimator_lib.Estimator): enqueue_ops, host_ops, run_infeed_loop_on_coordinator=( - run_infeed_loop_on_coordinator)), + run_infeed_loop_on_coordinator), + rendezvous=self._rendezvous[mode], + ), InstallSignalHandlerHook(), training.LoggingTensorHook( { @@ -2365,14 +2581,13 @@ class TPUEstimator(estimator_lib.Estimator): }, every_n_iter=logging_hook_frequency) ]) - examples_hook = ExamplesPerSecondHook( - ctx.global_batch_size, - output_dir=self.model_dir, - every_n_steps=self._log_every_n_steps) examples_hook._set_steps_per_run( # pylint: disable=protected-access self._config.tpu_config.iterations_per_loop) hooks.append(examples_hook) + if training_hooks: + hooks.extend(training_hooks) + chief_hooks = [] if (self._config.save_checkpoints_secs or self._config.save_checkpoints_steps): @@ -2384,6 +2599,7 @@ class TPUEstimator(estimator_lib.Estimator): checkpoint_hook._set_steps_per_run( # pylint: disable=protected-access self._config.tpu_config.iterations_per_loop) chief_hooks.append(checkpoint_hook) + summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) with ops.control_dependencies([loss]): update_ops = _sync_variables_ops() @@ -2403,7 +2619,7 @@ class TPUEstimator(estimator_lib.Estimator): scaffold=scaffold) if mode == model_fn_lib.ModeKeys.EVAL: - total_loss, host_calls, scaffold = _eval_on_tpu_system( + total_loss, host_calls, scaffold, eval_hooks = _eval_on_tpu_system( ctx, model_fn_wrapper, dequeue_fn) iterations_per_loop_var = _create_or_get_iterations_per_loop() mean_loss = math_ops.div(total_loss, @@ -2428,7 +2644,8 @@ class TPUEstimator(estimator_lib.Estimator): host_call_ret = host_calls.create_tpu_hostcall() eval_metric_ops = {} eval_update_ops = [] - for k, v in host_call_ret['eval_metrics'].items(): + + for k, v in host_call_ret.get('eval_metrics', {}).items(): eval_metric_ops[k] = (v[0], dummy_update_op) eval_update_ops.append(v[1]) @@ -2442,9 +2659,13 @@ class TPUEstimator(estimator_lib.Estimator): enqueue_ops, eval_update_ops + host_ops, run_infeed_loop_on_coordinator=( - run_infeed_loop_on_coordinator)), + run_infeed_loop_on_coordinator), + rendezvous=self._rendezvous[mode]), ] + input_hooks + if eval_hooks: + hooks.extend(eval_hooks) + return model_fn_lib.EstimatorSpec( mode, loss=mean_loss, @@ -2455,8 +2676,9 @@ class TPUEstimator(estimator_lib.Estimator): # Predict assert mode == model_fn_lib.ModeKeys.PREDICT - dummy_predict_op, host_calls, scaffold = _predict_on_tpu_system( - ctx, model_fn_wrapper, dequeue_fn) + (dummy_predict_op, host_calls, + scaffold, prediction_hooks) = _predict_on_tpu_system( + ctx, model_fn_wrapper, dequeue_fn) with ops.control_dependencies([dummy_predict_op]): internal_ops_to_run = _sync_variables_ops() with ops.control_dependencies(internal_ops_to_run): @@ -2508,10 +2730,13 @@ class TPUEstimator(estimator_lib.Estimator): hooks = [ _StoppingPredictHook(scalar_stopping_signal), - TPUInfeedOutfeedSessionHookForPrediction(ctx, enqueue_ops, - host_ops), + TPUInfeedOutfeedSessionHookForPrediction( + ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode]), ] + input_hooks + if prediction_hooks: + hooks.extend(prediction_hooks) + return model_fn_lib.EstimatorSpec( mode, prediction_hooks=hooks, @@ -2595,8 +2820,8 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" iterations_per_loop_var = _create_or_get_iterations_per_loop() - single_tpu_eval_step, host_calls, captured_scaffold_fn = ( - model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn)) + (single_tpu_eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks + ) = model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn) def multi_tpu_eval_steps_on_single_shard(): return training_loop.repeat( @@ -2611,15 +2836,16 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): device_assignment=ctx.device_assignment) scaffold = _get_scaffold(captured_scaffold_fn) - return loss, host_calls, scaffold + return loss, host_calls, scaffold, captured_eval_hooks.get() def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" iterations_per_loop_var = _create_or_get_iterations_per_loop() - single_tpu_train_step, host_call, captured_scaffold_fn = ( - model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn)) + (single_tpu_train_step, host_call, captured_scaffold_fn, + captured_training_hooks) = ( + model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn)) def multi_tpu_train_steps_on_single_shard(): return training_loop.repeat( @@ -2634,15 +2860,16 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): device_assignment=ctx.device_assignment) scaffold = _get_scaffold(captured_scaffold_fn) - return loss, host_call, scaffold + return loss, host_call, scaffold, captured_training_hooks.get() def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" num_cores = ctx.num_cores - single_tpu_predict_step, host_calls, captured_scaffold_fn = ( - model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn)) + (single_tpu_predict_step, host_calls, captured_scaffold_fn, + captured_predict_hooks + ) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn) def multi_tpu_predict_steps_on_single_shard(): @@ -2659,10 +2886,11 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): multi_tpu_predict_steps_on_single_shard, inputs=[], num_shards=num_cores, - outputs_from_all_shards=False) + outputs_from_all_shards=False, + device_assignment=ctx.device_assignment) scaffold = _get_scaffold(captured_scaffold_fn) - return dummy_predict_op, host_calls, scaffold + return dummy_predict_op, host_calls, scaffold, captured_predict_hooks.get() def _wrap_computation_in_while_loop(device, op_fn): @@ -3115,7 +3343,7 @@ class _SignalsHelper(object): def __init__(self, signals): self._signal_keys = [] - for key in sorted(signals.iterkeys()): + for key in sorted(iter(signals.keys())): self._signal_keys.append(key) @property @@ -3127,7 +3355,7 @@ class _SignalsHelper(object): @staticmethod def as_tensor_list(signals): - return [signals[key] for key in sorted(signals.iterkeys())] + return [signals[key] for key in sorted(iter(signals.keys()))] def _verify_cross_hosts_transfer_size(tensor_dict, message): @@ -3153,9 +3381,53 @@ def _add_item_to_params(params, key, value): if isinstance(params, hparam.HParams): # For HParams, we need to use special API. if key in params: - params.key = value + params.set_hparam(key, value) else: params.add_hparam(key, value) else: # Now params is Python dict. params[key] = value + + +def export_estimator_savedmodel(estimator, + export_dir_base, + serving_input_receiver_fn, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False): + """Export `Estimator` trained model for TPU inference. + + Args: + estimator: `Estimator` with which model has been trained. + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + serving_input_receiver_fn: A function that takes no argument and + returns a `ServingInputReceiver` or `TensorServingInputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. + + Returns: + The string path to the exported directory. + """ + # `TPUEstimator` requires `tpu_config.RunConfig`, so we cannot use + # `estimator.config`. + config = tpu_config.RunConfig(model_dir=estimator.model_dir) + est = TPUEstimator( + estimator._model_fn, # pylint: disable=protected-access + config=config, + params=estimator.params, + use_tpu=True, + train_batch_size=2048, # Does not matter. + eval_batch_size=2048, # Does not matter. + ) + return est.export_savedmodel(export_dir_base, serving_input_receiver_fn, + assets_extra, + as_text, + checkpoint_path, + strip_default_attrs) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py index 604e6600c81a4136a1f10e79a725a887a96f4d86..d9c77a3ea1bbc456f058f36d78eec1f0843ddc79 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py @@ -20,8 +20,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import itertools + +import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding +from tensorflow.compiler.xla.python_api import xla_shape from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_sharding @@ -30,6 +35,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.util import nest class InfeedQueue(object): @@ -461,7 +467,10 @@ class InfeedQueue(object): name=full_name, device_ordinal=tpu_ordinal) - def generate_enqueue_ops(self, sharded_inputs, tpu_ordinal_function=None): + def generate_enqueue_ops(self, + sharded_inputs, + tpu_ordinal_function=None, + placement_function=None): """Generates the host-side Ops to enqueue the shards of a tuple. sharded_inputs is a list, one for each shard, of lists of @@ -483,6 +492,9 @@ class InfeedQueue(object): shard index as input and returns the ordinal of the TPU device the shard's infeed should be placed on. tpu_ordinal_function must be set if the inputs are placed on CPU devices. + placement_function: if not None, a function that takes the shard index as + input and returns the host device where the enqueue op should be placed + on. Returns: A list of host-side Ops, one for each shard, that when executed together @@ -508,8 +520,12 @@ class InfeedQueue(object): tpu_ordinal_function = lambda index: -1 name_prefix = "%s/enqueue" % self._name return [ - self._generate_enqueue_op(shard, name_prefix, index, - tpu_ordinal=tpu_ordinal_function(index)) + self._generate_enqueue_op( + shard, + name_prefix, + index, + tpu_ordinal=tpu_ordinal_function(index), + device=placement_function(index) if placement_function else None) for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) ] @@ -630,3 +646,264 @@ class InfeedQueue(object): tpu_ordinal=tpu_ordinal_function(index)) for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) ] + + +class _PartitionedInfeedQueue(InfeedQueue): + """A helper object to build a device infeed queue with input partition. + + Args: + number_of_tuple_elements: the number of Tensors fed atomically through the + queue, must be present unless it can be inferred from other arguments. + device_assignment: A TPU `DeviceAssignment` which is used to place all the + partitions to different TPU infeed queues. + host_id: The id of the host machine. + input_partition_dims: A nested list/tuple of integers. Each inner + list/tuple describes how to partition the corresponding input tensor. + tuple_types: If not None, a list of types of the elements of the queue. + tuple_shapes: If not None, a list of shapes of the elements of the queue. + name: The name of the queue. + """ + + def __init__(self, + number_of_tuple_elements, + device_assignment, + host_id, + input_partition_dims=None, + tuple_types=None, + tuple_shapes=None, + name=None): + super(_PartitionedInfeedQueue, self).__init__( + number_of_tuple_elements=number_of_tuple_elements, + tuple_types=tuple_types, + tuple_shapes=None, + shard_dimensions=None, + name="PartitionedInfeedQueue" if name is None else name) + self._input_partition_dims = input_partition_dims + self._host_id = host_id + self._device_assignment = device_assignment + + def generate_dequeue_op(self, tpu_device=0): + """Generate TPU dequeue ops. + + Args: + tpu_device: The TPU device ordinal where the infeed instruction should be + placed. + + Returns: + A list of Outputs corresponding to a partition of infeed dequeued + into XLA, suitable for use within a replicated block. + + Raises: + ValueError: if the types or shapes of the tuple elements have not been + set; or if a dequeue op has already been generated. + """ + self.freeze() + if self._generated_dequeue_op: + raise ValueError("Can't generate two dequeue Ops from the same queue") + self._generated_dequeue_op = True + full_name = "%s/dequeue" % self._name + sharded_shapes = [ + policy.get_sharded_shape(shape) + for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) + ] + with ops.device(tpu.core(tpu_device)): + values = tpu_ops.infeed_dequeue_tuple( + dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) + return self._tag_sharding_attribute_for_dequeued_tensors( + values, self._input_partition_dims) + + def generate_enqueue_ops(self, per_host_sharded_inputs): + """Generates the host-side Ops to enqueue the partitioned inputs. + + per_host_sharded_inputs is a list, one for each replica, of lists of + Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed + replica i. + sharded_inputs[i][j] is partitioned by self._input_partition_dims[j]. + + For example, if sharded_inputs[i][j] is a 2-D Tensor: + [[A, B, C, D], + [E ,F, G, H]] + self._input_partition_dims[j] is [2, 4]. + + sharded_inputs[i][j] will be partitioned and flattened into: + [A, B, C, D, E, F, G, H] and fed into the logical core ids: + [0, 1, 2, 3, 4, 5, 6, 7] respectively. + + Args: + per_host_sharded_inputs: a list of lists of Tensors. The length of the + outer list determines the number of shards. Each inner list indicates + the types and shapes of the tuples in the corresponding shard. + + Returns: + A list of host-side Ops, one for each shard, that when executed together + will enqueue a full-size element of infeed. + + Raises: + ValueError: if the queue configuration has previously been frozen and the + shapes of the elements of sharded_inputs are not compatible with the + frozen configuration; or if the shapes of the elements of sharded_inputs + don't form a consistent unsharded tuple; or if the elements of a tuple + have different device constraints; or if the partition dims are invalid. + TypeError: if the queue configuration has previously been frozen and the + types of the elements of sharded_inputs are not compatible with the + frozen configuration; or if the types of the elements of sharded_inputs + don't form a consistent unsharded tuple. + """ + self.set_configuration_from_sharded_input_tensors(per_host_sharded_inputs) + number_of_replicas_per_host = len(per_host_sharded_inputs) + number_of_tuple_elements = len(per_host_sharded_inputs[0]) + + assert len(self._input_partition_dims) == number_of_tuple_elements + per_host_enqueue_ops = [] + + for replica_index in range(number_of_replicas_per_host): + flattened_inputs = per_host_sharded_inputs[replica_index] + inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs, + self._input_partition_dims) + inputs_parted_iters = [ + iter(self._partition_or_replicate_on_host(x, dims)) for x, dims in + zip(per_host_sharded_inputs[replica_index], inputs_part_dims_flat) + ] + + for core_index in xrange(self._device_assignment.num_cores_per_replica): + # Places different partitions to different logic cores. + logical_core = self._get_logical_core(core_index) + replica_id = self._device_assignment.lookup_replicas( + self._host_id, logical_core)[replica_index] + ordinal = self._device_assignment.tpu_ordinal( + replica=replica_id, logical_core=logical_core) + infeed_inputs = [] + for it in inputs_parted_iters: + input_for_device = next(it, None) + if input_for_device is not None: + infeed_inputs.append(input_for_device) + + if infeed_inputs: + per_host_enqueue_ops.append( + tpu_ops.infeed_enqueue_tuple( + inputs=infeed_inputs, + shapes=[x.shape for x in infeed_inputs], + name="enqueue/replica_{0}/input_{1}".format( + replica_index, core_index), + device_ordinal=ordinal)) + return per_host_enqueue_ops + + def _check_input_partition_dims(self, tensor, dims): + """Checks that input partition dims are valid for the `Tensor`. + + Args: + tensor: Input tensor for partitioning. + dims: A list of integer describes how to partition the input tensor. + + Raises: + ValueError: If the tensor can't be partitioned by dims or the + num_cores_per_replica doesn't match the number of + partitions(dims.prod()). + """ + if dims is None: + return + + dims = np.array(dims) + + if (dims < 1).any(): + raise ValueError("All input partition dims must be >= 1.") + + # No partitioning, so don't perform further checks. + if dims.prod() == 1: + return + + if dims.prod() != self._device_assignment.num_cores_per_replica: + raise ValueError( + "The product of each input parition dim should equal to " + "num_cores_per_replica. (dim = {}, num_cores_per_replica " + "= {})".format(dims, self._device_assignment.num_cores_per_replica)) + if dims.shape[0] != tensor.shape.ndims: + raise ValueError( + "Input partition dims must have the same number of dimensions " + "as the `Tensor` to be partitioned. (tensor shape = {}, input " + "partition dims = {}).".format(tensor.shape.as_list(), dims)) + + tensor.shape.assert_is_fully_defined() + if (np.array(tensor.shape.as_list()) % dims != 0).any(): + raise ValueError( + "All input partition dims must divide exactly into the `Tensor` " + "shape (tensor shape = {}, input partition dims = {}).".format( + tensor.shape.as_list(), dims)) + + def _partition_or_replicate_on_host(self, tensor, dims): + """Partitions or replicates the input tensor. + + The ops inside this function are placed on the host side. + + Args: + tensor: The input tensor which will be partioned or replicated. + dims: A list of integer describes how to partition the input tensor. + Returns: + An iterator of `Tensor`s or a list of partioned tensors. + """ + self._check_input_partition_dims(tensor, dims) + if dims is None: + return itertools.repeat(tensor) + else: + output = [tensor] + for axis, dim in enumerate(dims): + if dim > 1: + output = [array_ops.split(x, dim, axis=axis) for x in output] + output = nest.flatten(output) + return output + + def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims): + """Tags appropriate XLA sharding attribute to the dequeued tensor. + + Args: + tensor: The dequeued tensor on TPU. + dims: A list of integer describes how the tensor is partitioned. + + Returns: + The same tensor with the xla_sharding attribute. + """ + if dims is None: + return xla_sharding.replicate(tensor) + elif np.prod(dims) == 1: + return xla_sharding.assign_device(tensor, 0) + else: + tile_shape = np.array(tensor.shape.as_list()) // dims + tile_assignment = np.arange(np.prod(dims)).reshape(dims) + return xla_sharding.tile( + tensor=tensor, + tile_shape=xla_shape.CreateShapeFromDtypeAndTuple( + dtype=np.dtype(tensor.dtype.as_numpy_dtype), + shape_tuple=tile_shape), + tile_assignment=tile_assignment) + + def _tag_sharding_attribute_for_dequeued_tensors(self, dequeues, dims): + """Tags appropriate XLA sharding attribute to the dequeued tensors. + + Args: + dequeues: A list of dequeued tensors on TPU. + dims: A list of integer describes how the tensor is partitioned. + + Returns: + The same dequeues with appropriate xla_sharding attribute. + """ + nest.assert_shallow_structure(dequeues, dims) + return nest.map_structure_up_to( + dequeues, self._tag_sharding_attribute_for_dequeued_tensor, dequeues, + dims) + + def _get_logical_core(self, core_index): + """Maps the core index to the 3D coordinate within replica. + + The lowest dimension number in computation_shape is the slowest varying + dimension (most major). + + Args: + core_index: An integer represents the core index within replcia. + + Returns: + A tuple with three integers which represents the 3D coordinate. + """ + computation_shape = self._device_assignment.computation_shape + return (core_index // (computation_shape[1] * computation_shape[2]), + core_index % (computation_shape[1] * computation_shape[2]) // + computation_shape[2], core_index % computation_shape[2]) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py index e76cf83e4ddcd86ab3971bcecefe2e2dc979bf63..53d33f40777a1c6d93f19c30b2ef5902d63ad2fd 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py @@ -19,8 +19,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_function +from tensorflow.python.framework import ops from tensorflow.python.ops.losses import losses from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import optimizer @@ -32,7 +35,8 @@ class CrossShardOptimizer(optimizer.Optimizer): def __init__(self, opt, reduction=losses.Reduction.MEAN, - name="CrossShardOptimizer"): + name="CrossShardOptimizer", + group_assignment=None): """Construct a new cross-shard optimizer. Args: @@ -40,6 +44,8 @@ class CrossShardOptimizer(optimizer.Optimizer): reduction: The reduction to apply to the shard losses. name: Optional name prefix for the operations created when applying gradients. Defaults to "CrossShardOptimizer". + group_assignment: Optional list of group ids for applying the optimizer + to subgroups. Raises: ValueError: If reduction is not a valid cross-shard reduction. @@ -50,6 +56,35 @@ class CrossShardOptimizer(optimizer.Optimizer): super(CrossShardOptimizer, self).__init__(False, name) self._opt = opt self._reduction = reduction + self._group_assignment = group_assignment + + def _verify_and_get_subgroup_size(self, group_assignment, num_shards): + """Verify group_assignment and get the subgroup size". + + Args: + group_assignment: list of group ids for applying the optimizer + to subgroups. + num_shards: The number of TPU shards. + + Returns: + The size of one subgroup in group_assignment. + + Raises: + ValueError: If group_assignment is invalid. + """ + if not group_assignment: + return None + if len(group_assignment) != num_shards: + raise ValueError("The size of group_assignment does not equal to " + "num_shard({0}). Got group_assignment={1}".format( + num_shards, self._group_assignment)) + subgroup_size_list = dict(collections.Counter(group_assignment)).values() + if all(subgroup_size_list[0] == size for size in subgroup_size_list): + return subgroup_size_list[0] + else: + raise ValueError("The size of each subgroup in group_assignment must " + "be equal. Got group_assignment={}".format( + self._group_assignment)) def compute_gradients(self, loss, var_list=None, **kwargs): """Compute gradients of "loss" for the variables in "var_list". @@ -71,7 +106,8 @@ class CrossShardOptimizer(optimizer.Optimizer): A list of (gradient, variable) pairs. Raises: - ValueError: If not within a tpu_shard_context. + ValueError: If not within a tpu_shard_context or group_assignment is + invalid. """ num_shards = tpu_function.get_tpu_context().number_of_shards if num_shards is None: @@ -79,9 +115,17 @@ class CrossShardOptimizer(optimizer.Optimizer): "CrossShardOptimizer should be used within a tpu_shard_context, but " "got unset number_of_shards. Assuming 1.") num_shards = 1 + + subgroup_size = self._verify_and_get_subgroup_size(self._group_assignment, + num_shards) + if num_shards > 1 and self._reduction == losses.Reduction.MEAN: - scale = 1.0 / num_shards + if self._group_assignment: + scale = 1.0 / subgroup_size + else: + scale = 1.0 / num_shards loss *= scale + return self._opt.compute_gradients(loss, var_list=var_list, **kwargs) def apply_gradients(self, grads_and_vars, global_step=None, name=None): @@ -110,7 +154,9 @@ class CrossShardOptimizer(optimizer.Optimizer): if grad is None: summed_grads_and_vars.append((grad, var)) else: - summed_grads_and_vars.append((tpu_ops.cross_replica_sum(grad), var)) + with ops.colocate_with(grad): + summed_grads_and_vars.append((tpu_ops.cross_replica_sum( + grad, self._group_assignment), var)) return self._opt.apply_gradients(summed_grads_and_vars, global_step, name) def get_slot(self, *args, **kwargs): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py index 894f21d0635ca47d3da1c0d2c3f5c37bac690920..ec682e5829c4df536a043334b74200f0b6259df3 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py @@ -45,7 +45,7 @@ _TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [ ]) -def _query_tpu_system_metadata(master_address, run_config, +def _query_tpu_system_metadata(master_address, cluster_def=None, query_topology=False): """Automatically detects the TPU system metadata in the system.""" tpu_core_count = 0 @@ -61,7 +61,8 @@ def _query_tpu_system_metadata(master_address, run_config, with session_lib.Session( master_address, config=get_session_config_with_timeout( - _PINGING_MASTER_TIMEOUT_IN_MS, run_config)) as sess: + _PINGING_MASTER_TIMEOUT_IN_MS, + cluster_def)) as sess: devices = sess.list_devices() for device in devices: match = _TPU_DEVICE_REG.match(device.name) @@ -105,7 +106,7 @@ def _query_tpu_system_metadata(master_address, run_config, 'TPU worker has some problems. Available devices: {}'.format( master_address, devices)) - topology = _obtain_topology(master_address, run_config) + topology = _obtain_topology(master_address, cluster_def) metadata = _TPUSystemMetadata( num_cores=tpu_core_count, @@ -127,14 +128,15 @@ def _query_tpu_system_metadata(master_address, run_config, return metadata -def _obtain_topology(master_address, run_config): +def _obtain_topology(master_address, cluster_def): + """Obtains TPU fabric topology.""" try: logging.info('Initializing TPU system (master: %s) to fetch topology ' 'for model parallelism. This might take a while.', master_address) with ops.Graph().as_default(): session_config = get_session_config_with_timeout( - _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, run_config) + _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def) with session_lib.Session( master_address, config=session_config) as sess: topology = sess.run(tpu.initialize_system()) @@ -146,11 +148,8 @@ def _obtain_topology(master_address, run_config): master_address)) -def get_session_config_with_timeout(timeout_in_secs, run_config): - cluster_def = None - if run_config.session_config and run_config.session_config.cluster_def.job: - cluster_def = run_config.session_config.cluster_def - +def get_session_config_with_timeout(timeout_in_secs, cluster_def): + """Returns a session given a timeout and a cluster configuration.""" config = config_pb2.ConfigProto( operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def) return config diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index 5de55b5f7f2a41ac6edd27e5a102e565f33df12c..ddf8365d6130dcb4c8234ac60c91955d007e2410 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -61,7 +61,7 @@ py_library( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python/data", - "//tensorflow/python/estimator:inputs_queues", + "//tensorflow/python/estimator:estimator_py", "//third_party/py/numpy", "@six_archive//:six", ], @@ -133,7 +133,7 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:session", "//tensorflow/python:training", - "//tensorflow/python/estimator:inputs_queues", + "//tensorflow/python/estimator:estimator_py", "//third_party/py/numpy", ], ) @@ -295,7 +295,7 @@ py_test( tags = ["notsan"], deps = [ ":training_py", - "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test", + "//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradients", diff --git a/tensorflow/contrib/training/python/training/evaluation.py b/tensorflow/contrib/training/python/training/evaluation.py index f7fd66d33fc0c329db7daaf87373385156d84217..01bac891da7ddf8523e6cc8c99decf4a61aa2741 100644 --- a/tensorflow/contrib/training/python/training/evaluation.py +++ b/tensorflow/contrib/training/python/training/evaluation.py @@ -142,9 +142,9 @@ from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import evaluation from tensorflow.python.training import monitored_session -from tensorflow.python.training import saver as tf_saver from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util @@ -189,7 +189,7 @@ def wait_for_new_checkpoint(checkpoint_dir, logging.info('Waiting for new checkpoint at %s', checkpoint_dir) stop_time = time.time() + timeout if timeout is not None else None while True: - checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir) + checkpoint_path = checkpoint_management.latest_checkpoint(checkpoint_dir) if checkpoint_path is None or checkpoint_path == last_checkpoint: if stop_time is not None and time.time() + seconds_to_sleep > stop_time: return None diff --git a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py index 39d75a080604e3a7ae93391652d4c03be9857218..53e4f23a7cd940c026e462dc7fb55cf9f175bf02 100644 --- a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py +++ b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py @@ -988,14 +988,14 @@ class SequenceQueueingStateSaver(object): assert isinstance(sequences, dict) assert isinstance(context, dict) assert isinstance(states, dict) - self._name_to_index = dict( - (name, ix) + self._name_to_index = { + name: ix for (ix, name) in enumerate([ "__length", "__total_length", "__next_key", "__sequence", "__sequence_count" ] + ["__sequence__%s" % k for k in sequences.keys()] + [ "__context__%s" % k for k in context.keys() - ] + ["__state__%s" % k for k in states.keys()])) + ] + ["__state__%s" % k for k in states.keys()])} self._index_to_name = [ name for (name, _) in sorted( diff --git a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay.py b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay.py new file mode 100644 index 0000000000000000000000000000000000000000..ed0f398e30a7f3c0b1b9378f8fc5d5bfbea1536a --- /dev/null +++ b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay.py @@ -0,0 +1,187 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""SGDR learning rate decay function.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops, control_flow_ops + + +def sgdr_decay(learning_rate, global_step, initial_period_steps, + t_mul=2.0, m_mul=1.0, name=None): + """Implements Stochastic Gradient Descent with Warm Restarts (SGDR). + + As described in "SGDR: Stochastic Gradient Descent + with Warm Restarts" by Ilya Loshchilov & Frank Hutter, Proceedings of + ICLR'2017, available at https://arxiv.org/pdf/1608.03983.pdf + + The learning rate decreases according to cosine annealing: + + ```python + learning_rate * 0.5 * (1 + cos(x_val * pi)) # for x_val defined in [0, 1] + ``` + + Thus, at the beginning (when the restart index i = 0), + the learning rate decreases for `initial_period_steps` steps from the initial + learning rate `learning_rate` (when `x_val=0`, we get `cos(0)=1`) to + 0 (when `x_val=1`, we get `cos(pi)=-1`). + + The decrease within the i-th period takes `t_i` steps, + where `t_0` = `initial_period_steps` is the user-defined number of batch + iterations (not epochs as in the paper) to be performed before the first + restart is launched. + + Then, we perform the first restart (i=1) by setting the learning rate to + `learning_rate*(m_mul^i)`, where `m_mul in [0,1]` (set to 1 by default). + The i-th restart runs for `t_i=t_0*(t_mul^i)` steps, i.e., every new + restart runs `t_mul` times longer than the previous one. + + Importantly, when one has no access to a validation set, SGDR suggests + to report the best expected / recommended solution in the following way: + When we are within our initial run (i=0), every new solution represents + SGDR's recommended solution. Instead, when i>0, the recommended solution is + the one obtained at the end of each restart. + + Note that the minimum learning rate is set to 0 for simplicity, + you can adjust the code to deal with any positive minimum learning rate + as defined in the paper. + + `initial_period_steps` is the duration of the first period measured in terms + of number of minibatch updates. If one wants to use epochs, one should compute + the number of updates required for an epoch. + + For example, assume the following parameters and intention: + Minibatch size: 100 + Training dataset size: 10000 + If the user wants the first decay period to span across 5 epochs, then + `initial_period_steps` = 5 * 10000/100 = 500 + + Train for 10000 batch iterations with the initial learning rate set to + 0.1, then restart to run 2 times longer, i.e, for 20000 batch iterations + and with the initial learning rate 0.05, then restart again and again, + doubling the runtime of each new period and with two times smaller + initial learning rate. + + To accomplish the above, one would write: + + ```python + ... + global_step = tf.Variable(0, trainable=False) + starter_learning_rate = 0.1 + learning_rate = sgdr_decay(starter_learning_rate, global_step, + initial_period_steps=10000, t_mul=2, m_mul=0.5) + # Passing global_step to minimize() will increment it at each step. + learning_step = ( + tf.train.GradientDescentOptimizer(learning_rate) + .minimize(...my loss..., global_step=global_step) + ) + + # Step | 0 | 1000 | 5000 | 9000 | 9999 | 10000 | 11000 | + # LR | 0.1 | 0.097 | 0.05 | 0.002 | 0.00 | 0.05 | 0.0496 | + + # Step | 20000 | 29000 | 29999 | 30000 | + # LR | 0.025 | 0.0003 | 0.00 | 0.025 | + ``` + + Args: + learning_rate: A scalar `float32` or `float64` `Tensor` or a + Python number. The initial learning rate. + global_step: A scalar `int32` or `int64` `Tensor` or a Python number. + Global step to use for the decay computation. Must not be negative. + initial_period_steps: Duration of the first period measured as the number + of minibatch updates, if one wants to use epochs, one should compute + the number of updates required for an epoch. + t_mul: A scalar `float32` or `float64` `Tensor` or a Python number. + Must be positive. + Used to derive the number of iterations in the i-th period: + `initial_period_steps * (t_mul^i)`. Defaults to 2.0. + m_mul: A scalar `float32` or `float64` `Tensor` or a Python number. + Must be positive. + Used to derive the initial learning rate of the i-th period: + `learning_rate * (m_mul^i)`. Defaults to 1.0 + + Returns: + A scalar `Tensor` of the same type as `learning_rate`. + The learning rate for a provided global_step. + Raises: + ValueError: if `global_step` is not supplied. + """ + + if global_step is None: + raise ValueError("global_step is required for sgdr_decay.") + with ops.name_scope(name, "SGDRDecay", + [learning_rate, global_step, + initial_period_steps, t_mul, m_mul]) as name: + learning_rate = ops.convert_to_tensor(learning_rate, + name="initial_learning_rate") + dtype = learning_rate.dtype + global_step = math_ops.cast(global_step, dtype) + t_0 = math_ops.cast(initial_period_steps, dtype) + t_mul = math_ops.cast(t_mul, dtype) + m_mul = math_ops.cast(m_mul, dtype) + + c_one = math_ops.cast(constant_op.constant(1.0), dtype) + c_half = math_ops.cast(constant_op.constant(0.5), dtype) + c_pi = math_ops.cast(constant_op.constant(math.pi), dtype) + + # Find normalized value of the current step + x_val = math_ops.div(global_step, t_0) + + def compute_step(x_val, geometric=False): + if geometric: + # Consider geometric series where t_mul != 1 + # 1 + t_mul + t_mul^2 ... = (1 - t_mul^i_restart) / (1 - t_mul) + + # First find how many restarts were performed for a given x_val + # Find maximal integer i_restart value for which this equation holds + # x_val >= (1 - t_mul^i_restart) / (1 - t_mul) + # x_val * (1 - t_mul) <= (1 - t_mul^i_restart) + # t_mul^i_restart <= (1 - x_val * (1 - t_mul)) + + # tensorflow allows only log with base e + # i_restart <= log(1 - x_val * (1 - t_mul) / log(t_mul) + # Find how many restarts were performed + + i_restart = math_ops.floor( + math_ops.log(c_one - x_val * (c_one - t_mul)) / math_ops.log(t_mul)) + # Compute the sum of all restarts before the current one + sum_r = (c_one - t_mul ** i_restart) / (c_one - t_mul) + # Compute our position within the current restart + x_val = (x_val - sum_r) / t_mul ** i_restart + + else: + # Find how many restarts were performed + i_restart = math_ops.floor(x_val) + # Compute our position within the current restart + x_val = x_val - i_restart + return i_restart, x_val + + i_restart, x_val = control_flow_ops.cond( + math_ops.equal(t_mul, c_one), + lambda: compute_step(x_val, geometric=False), + lambda: compute_step(x_val, geometric=True)) + + # If m_mul < 1, then the initial learning rate of every new restart will be + # smaller, i.e., by a factor of m_mul ** i_restart at i_restart-th restart + m_fac = learning_rate * (m_mul ** i_restart) + + return math_ops.multiply(c_half * m_fac, + (math_ops.cos(x_val * c_pi) + c_one), name=name) diff --git a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4a46e9a49ef203384e36698f81d6cbe3a3881ef8 --- /dev/null +++ b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py @@ -0,0 +1,145 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Functional test for sgdr learning rate decay.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from sgdr_learning_rate_decay import sgdr_decay +from tensorflow.python.platform import googletest +from tensorflow.python.framework import test_util +from tensorflow.python.framework import dtypes +from tensorflow import placeholder + + +class SGDRDecayTest(test_util.TensorFlowTestCase): + """Unit tests for SGDR learning rate decay.""" + + def get_original_values(self, lr, t_e, mult_factor, iter_per_epoch, epochs): + """Get an array with learning rate values from the consecutive steps using + the original implementation + (https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py).""" + t0 = math.pi / 2.0 + tt = 0 + te_next = t_e + + lr_values = [] + sh_lr = lr + for epoch in range(epochs): + for _ in range(iter_per_epoch): + # In the original approach training function is executed here + lr_values.append(sh_lr) + dt = 2.0 * math.pi / float(2.0 * t_e) + tt = tt + float(dt) / iter_per_epoch + if tt >= math.pi: + tt = tt - math.pi + cur_t = t0 + tt + new_lr = lr * (1.0 + math.sin(cur_t)) / 2.0 # lr_min = 0, lr_max = lr + sh_lr = new_lr + if (epoch + 1) == te_next: # time to restart + sh_lr = lr + tt = 0 # by setting to 0 we set lr to lr_max, see above + t_e = t_e * mult_factor # change the period of restarts + te_next = te_next + t_e # note the next restart's epoch + + return lr_values + + def get_sgdr_values(self, lr, initial_period_steps, t_mul, iters): + """Get an array with learning rate values from the consecutive steps + using current tensorflow implementation.""" + with self.test_session(): + step = placeholder(dtypes.int32) + + decay = sgdr_decay(lr, step, initial_period_steps, t_mul) + lr_values = [] + for i in range(iters): + lr_values.append(decay.eval(feed_dict={step: i})) + + return lr_values + + def testCompareToOriginal(self): + """Compare values generated by tensorflow implementation to the values + generated by the original implementation + (https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py).""" + with self.test_session(): + lr = 10.0 + init_steps = 2 + t_mul = 3 + iters = 10 + epochs = 50 + + org_lr = self.get_original_values(lr, init_steps, t_mul, iters, epochs) + sgdr_lr = self.get_sgdr_values(lr, init_steps*iters, t_mul, iters*epochs) + + for org, sgdr in zip(org_lr, sgdr_lr): + self.assertAllClose(org, sgdr) + + def testMDecay(self): + """Test m_mul argument. Check values for learning rate at the beginning + of the first, second, third and fourth period. """ + with self.test_session(): + step = placeholder(dtypes.int32) + + lr = 0.1 + t_e = 10 + t_mul = 3 + m_mul = 0.9 + + decay = sgdr_decay(lr, step, t_e, t_mul, m_mul) + + test_step = 0 + self.assertAllClose(decay.eval(feed_dict={step: test_step}), + lr) + + test_step = t_e + self.assertAllClose(decay.eval(feed_dict={step: test_step}), + lr * m_mul) + + test_step = t_e + t_e*t_mul + self.assertAllClose(decay.eval(feed_dict={step: test_step}), + lr * m_mul**2) + + test_step = t_e + t_e*t_mul + t_e * (t_mul**2) + self.assertAllClose(decay.eval(feed_dict={step: test_step}), + lr * (m_mul**3)) + + def testCos(self): + """Check learning rate values at the beginning, in the middle + and at the end of the period.""" + with self.test_session(): + step = placeholder(dtypes.int32) + lr = 0.2 + t_e = 1000 + t_mul = 1 + + decay = sgdr_decay(lr, step, t_e, t_mul) + + test_step = 0 + self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr) + + test_step = t_e//2 + self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr/2) + + test_step = t_e + self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr) + + test_step = t_e*3//2 + self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr/2) + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py index a2444934bc21d58ed57d15494b3548a31ce3a2df..f46d03209ce7b111415b61181906c496f8181e71 100644 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py +++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py @@ -156,7 +156,7 @@ def prepend_from_queue_and_padded_batch_dataset(batch_size, Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py index 0338f409a203c232e63e99534a8f6d6a43fa661e..df0a186f4f6963d7e874bb4ab74a8db7e10a52ee 100644 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py +++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py @@ -19,7 +19,7 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.contrib.training.python.training import tensor_queue_dataset as tqd from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes diff --git a/tensorflow/contrib/training/python/training/training.py b/tensorflow/contrib/training/python/training/training.py index f72e0a3f831f9e9c61a2e9d77828ffb12d8428b1..c272a2ac144068cfb7355c2647eebf5bd0ce9d50 100644 --- a/tensorflow/contrib/training/python/training/training.py +++ b/tensorflow/contrib/training/python/training/training.py @@ -484,7 +484,8 @@ def train(train_op, save_checkpoint_secs=600, save_summaries_steps=100, config=None, - max_wait_secs=7200): + max_wait_secs=7200, + run_metadata=None): """Runs the training loop. Args: @@ -511,6 +512,7 @@ def train(train_op, become available. This should be kept relatively short to help detect incorrect code, but sometimes may need to be increased if the chief takes a while to start up. + run_metadata: A [`RunMetadata`] protocol buffer. Returns: the value of the loss function after training. @@ -541,5 +543,5 @@ def train(train_op, max_wait_secs=max_wait_secs) as session: loss = None while not session.should_stop(): - loss = session.run(train_op) + loss = session.run(train_op, run_metadata=run_metadata) return loss diff --git a/tensorflow/contrib/training/python/training/training_test.py b/tensorflow/contrib/training/python/training/training_test.py index 4877c010fad2c567d26b9674d2904274c0895f55..94cf7788b2bd3bc3fe87eefd599ce88de03042af 100644 --- a/tensorflow/contrib/training/python/training/training_test.py +++ b/tensorflow/contrib/training/python/training/training_test.py @@ -36,6 +36,7 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import gradient_descent from tensorflow.python.training import monitored_session from tensorflow.python.training import saver as saver_lib @@ -421,7 +422,7 @@ class TrainTest(test.TestCase): train_op = self.create_train_op() model_variables = variables_lib2.global_variables() - model_path = saver_lib.latest_checkpoint(logdir1) + model_path = checkpoint_management.latest_checkpoint(logdir1) assign_fn = variables_lib.assign_from_checkpoint_fn( model_path, model_variables) diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD index 1b45584dcb84fe62de6cc14017e7cae575f99b2f..19cb8983b6836266ebfac70c54657a96324e8435 100644 --- a/tensorflow/contrib/verbs/BUILD +++ b/tensorflow/contrib/verbs/BUILD @@ -53,12 +53,12 @@ cc_library( ":grpc_verbs_service_impl", ":rdma_mgr", ":verbs_service_proto_cc", + "//tensorflow:grpc++", "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:session_mgr", "//tensorflow/core/distributed_runtime/rpc:async_service_interface", "//tensorflow/core/distributed_runtime/rpc:grpc_call", "//tensorflow/core/distributed_runtime/rpc:grpc_util", - "@grpc//:grpc++", ], alwayslink = 1, ) @@ -69,7 +69,7 @@ cc_library( hdrs = ["grpc_verbs_service_impl.h"], deps = [ ":verbs_service_proto_cc", - "@grpc//:grpc++", + "//tensorflow:grpc++", ], ) diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index 86350a08e57e5050f18d019fe80d70f6381c1f7d..f7c979e86320d59ad033e2b8d7fcdff89ce0d133 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -24,8 +24,8 @@ limitations under the License. #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/process_util.h" #if GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" #include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#include "tensorflow/core/common_runtime/gpu/process_state.h" #endif #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" @@ -1084,7 +1084,7 @@ void RdmaTensorResponse::RecvHandler(Rendezvous::ParsedKey parsed, // The tensor must be copied from GPU to CPU, because either: // 1. The tensor is located on a non GDR compatible GPU. // 2. The tensor's meta-data has changed. - Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); + Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); copy = Tensor(alloc, in.dtype(), in.shape()); CountCopies(rm_.name_, (void*)DMAHelper::base(&in), (void*)DMAHelper::base(©), in.TotalBytes(), true); @@ -1541,7 +1541,7 @@ bool RdmaTensorRequest::AllocateTensors() { if (mr_ == nullptr) { // Can't RDMA directly to result. Use a proxy. proxy_tensor_ = - new Tensor(ProcessState::singleton()->GetCUDAHostAllocator(0), + new Tensor(GPUProcessState::singleton()->GetCUDAHostAllocator(0), result_tensor_->dtype(), result_tensor_->shape()); rdma_addr_ = DMAHelper::base(proxy_tensor_); mr_ = diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc index 369bd986df5313955bc22d6e5c6d38815908ada3..3cb5e61facf860f2740935f66bf548096296280f 100644 --- a/tensorflow/contrib/verbs/rdma_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_mgr.cc @@ -21,8 +21,10 @@ limitations under the License. #include "tensorflow/contrib/verbs/grpc_verbs_client.h" #include "tensorflow/contrib/verbs/verbs_service.pb.h" #include "tensorflow/core/common_runtime/bfc_allocator.h" +#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" #include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#include "tensorflow/core/common_runtime/gpu/process_state.h" +#include "tensorflow/core/common_runtime/pool_allocator.h" +#include "tensorflow/core/common_runtime/process_state.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" #include "tensorflow/core/framework/allocator_registry.h" @@ -254,37 +256,34 @@ void MRDeleter(ibv_mr* mr) { } } -// TODO(byronyi): remove this class duplicated from the one in -// common/runtime/gpu/pool_allocator.h when it is available in common_runtime -class BasicCPUAllocator : public SubAllocator { - public: - ~BasicCPUAllocator() override {} - - void* Alloc(size_t alignment, size_t num_bytes) override { - return port::AlignedMalloc(num_bytes, alignment); - } - void Free(void* ptr, size_t) override { port::AlignedFree(ptr); } -}; - // TODO(byronyi): remove this class and its registration when the default -// cpu_allocator() returns visitable allocator +// cpu_allocator() returns visitable allocator, or cpu_allocator() is no +// longer in use. class BFCRdmaAllocator : public BFCAllocator { public: BFCRdmaAllocator() - : BFCAllocator(new BasicCPUAllocator(), 1LL << 36, true, "cpu_rdma_bfc") { + : BFCAllocator(new BasicCPUAllocator(port::kNUMANoAffinity), 1LL << 36, + true, "cpu_rdma_bfc") {} +}; +class BFCRdmaAllocatorFactory : public AllocatorFactory { + public: + Allocator* CreateAllocator() { return new BFCRdmaAllocator; } + + SubAllocator* CreateSubAllocator(int numa_node) { + return new BasicCPUAllocator(numa_node); } }; -REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocator); +REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocatorFactory); void RdmaMgr::InitAllocators() { RdmaMemoryMgr::Singleton().pd_ = rdma_adapter_->pd_; Allocator* allocators[] = { #if GOOGLE_CUDA - ProcessState::singleton()->GetCUDAHostAllocator(0), - ProcessState::singleton()->GetCPUAllocator(0), + GPUProcessState::singleton()->GetCUDAHostAllocator(0), #endif // GOOGLE_CUDA + ProcessState::singleton()->GetCPUAllocator(0), cpu_allocator(), }; @@ -323,7 +322,8 @@ void RdmaMgr::InitAllocators() { std::bind(&RdmaMemoryMgr::InsertMemoryRegion, &RdmaMemoryMgr::Singleton(), _1, _2, std::string(buf)); - ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor); + GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id, + cuda_alloc_visitor); LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id; } #endif // GOOGLE_CUDA diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index b6b48a077cdafe12aeb1e4e0988493692c82eace..aef652e18582ee743eae0d39bc30612ca6bd29b0 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -89,6 +89,7 @@ load( "tf_generate_proto_text_sources", "tf_genrule_cmd_append_to_srcs", "tf_opts_nortti_if_android", + "tf_features_nomodules_if_android", ) load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl") load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") @@ -148,8 +149,8 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library") load( "//third_party/mkl:build_defs.bzl", "if_mkl", + "mkl_deps", ) -load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library") exports_files(["ops/ops.pbtxt"]) @@ -233,7 +234,6 @@ tf_proto_library( srcs = [], cc_api_version = 2, default_header = True, - j2objc_api_version = 1, java_api_version = 2, js_api_version = 2, protodeps = [ @@ -334,6 +334,7 @@ filegroup( "platform/init_main.h", "platform/mem.h", "platform/mutex.h", + "platform/numa.h", "platform/thread_annotations.h", ], visibility = ["//visibility:private"], @@ -662,6 +663,7 @@ cc_library( "lib/random/random_distributions.h", "lib/random/simple_philox.h", "lib/strings/numbers.h", + "lib/strings/proto_serialization.h", "lib/strings/str_util.h", "lib/strings/strcat.h", "lib/strings/stringprintf.h", @@ -734,7 +736,10 @@ cc_library( "util/reporter.h", ], copts = tf_copts(), - linkopts = ["-lm"], + linkopts = select({ + "//tensorflow:windows": [], + "//conditions:default": ["-lm"], + }), visibility = ["//visibility:public"], deps = [ ":lib", @@ -792,6 +797,7 @@ tf_cuda_library( "framework/graph_def_util.h", "framework/graph_to_functiondef.h", "framework/kernel_def_builder.h", + "framework/kernel_def_util.h", "framework/log_memory.h", "framework/lookup_interface.h", "framework/memory_types.h", @@ -845,6 +851,7 @@ tf_cuda_library( "util/sparse/sparse_tensor.h", "util/stat_summarizer.h", "util/stat_summarizer_options.h", + "util/status_util.h", "util/stream_executor_util.h", "util/strided_slice_op.h", "util/tensor_format.h", @@ -857,7 +864,6 @@ tf_cuda_library( "util/work_sharder.h", ] + select({ "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "util/memmapped_file_system.h", "util/memmapped_file_system_writer.h", @@ -881,6 +887,16 @@ cc_library( copts = tf_copts(), ) +tf_cc_test( + name = "stats_calculator_test", + srcs = ["util/stats_calculator_test.cc"], + deps = [ + ":stats_calculator_portable", + ":test", + ":test_main", + ], +) + cc_library( name = "overflow", hdrs = ["util/overflow.h"], @@ -901,6 +917,15 @@ cc_library( hdrs = ["util/ptr_util.h"], ) +cc_library( + name = "status_util", + hdrs = ["util/status_util.h"], + deps = [ + ":graph", + ":lib", + ], +) + cc_library( name = "reader_base", srcs = ["framework/reader_base.cc"], @@ -998,6 +1023,7 @@ tf_gen_op_libs( "nn_ops", "no_op", "parsing_ops", + "random_grad", "random_ops", "remote_fused_graph_ops", "resource_variable_ops", @@ -1196,6 +1222,7 @@ tf_cuda_library( hdrs = [ "common_runtime/device.h", "common_runtime/device_factory.h", + "common_runtime/function.h", "common_runtime/optimization_registry.h", "common_runtime/shape_refiner.h", "graph/algorithm.h", @@ -1250,6 +1277,7 @@ cc_library( "//tensorflow/core/kernels:fake_quant_ops", "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:functional_ops", + "//tensorflow/core/kernels:grappler", "//tensorflow/core/kernels:histogram_op", "//tensorflow/core/kernels:image", "//tensorflow/core/kernels:io", @@ -1631,6 +1659,7 @@ cc_library( copts = tf_copts(android_optimization_level_override = None) + [ "-DSUPPORT_SELECTIVE_REGISTRATION", ], + linkopts = if_android(["-lz"]), tags = [ "manual", "notap", @@ -1654,6 +1683,7 @@ cc_library( copts = tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_android() + [ "-DSUPPORT_SELECTIVE_REGISTRATION", ], + linkopts = if_android(["-lz"]), tags = [ "manual", "notap", @@ -1939,8 +1969,10 @@ LIB_INTERNAL_PRIVATE_HEADERS = ["framework/resource_handle.h"] + glob( "**/*test*", "lib/gif/**/*", "lib/jpeg/**/*", + "lib/png/**/*", "platform/gif.h", "platform/jpeg.h", + "platform/png.h", "platform/**/cuda.h", "platform/**/stream_executor.h", ], @@ -2007,7 +2039,7 @@ cc_library( linkopts = select({ "//tensorflow:freebsd": [], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], + "//tensorflow:android": [], "//conditions:default": [ "-ldl", "-lpthread", @@ -2035,6 +2067,7 @@ cc_library( "lib/hash/crc32c_accelerate.cc", "lib/gif/**/*", "lib/jpeg/**/*", + "lib/png/**/*", "platform/**/env_time.cc", "platform/**/cuda_libdevice_path.cc", "platform/**/device_tracer.cc", @@ -2095,7 +2128,6 @@ cc_library( linkopts = select({ "//tensorflow:freebsd": [], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": ["-ldl"], }), deps = [ @@ -2120,7 +2152,6 @@ cc_library( linkopts = select({ "//tensorflow:freebsd": [], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": ["-ldl"], }), deps = [ @@ -2130,6 +2161,38 @@ cc_library( ], ) +cc_library( + name = "png_internal", + srcs = ["lib/png/png_io.cc"], + hdrs = [ + "lib/bfloat16/bfloat16.h", + "lib/core/casts.h", + "lib/core/stringpiece.h", + "lib/png/png_io.h", + "platform/byte_order.h", + "platform/cpu_info.h", + "platform/default/integral_types.h", + "platform/default/logging.h", + "platform/logging.h", + "platform/macros.h", + "platform/platform.h", + "platform/png.h", + "platform/types.h", + ], + copts = tf_copts(), + linkopts = select({ + "//tensorflow:freebsd": [], + "//tensorflow:windows": [], + "//conditions:default": ["-ldl"], + }), + deps = [ + ":lib", + ":lib_internal", + "//tensorflow/core/platform/default/build_config:png", + "@zlib_archive//:zlib", + ], +) + cc_library( name = "tflite_portable_logging", srcs = [], @@ -2175,6 +2238,7 @@ cc_library( linkopts = ["-ldl"], deps = [ "//tensorflow/core/platform/default/build_config:jpeg", + "//tensorflow/core/platform/default/build_config:logging", ], ) @@ -2203,6 +2267,7 @@ cc_library( linkopts = ["-ldl"], deps = [ "//tensorflow/core/platform/default/build_config:gif", + "//tensorflow/core/platform/default/build_config:logging", ], ) @@ -2229,6 +2294,7 @@ cc_library( copts = tf_copts(), linkopts = ["-ldl"], deps = [ + "//tensorflow/core/platform/default/build_config:logging", "@png_archive//:png", ], ) @@ -2238,7 +2304,6 @@ tf_proto_library( srcs = ERROR_CODES_PROTO_SRCS, cc_api_version = 2, default_header = True, - j2objc_api_version = 1, java_api_version = 2, js_api_version = 2, provide_cc_alias = True, @@ -2260,7 +2325,6 @@ tf_proto_library( srcs = COMMON_PROTO_SRCS + ADDITIONAL_CORE_PROTO_SRCS, cc_api_version = 2, default_header = True, - j2objc_api_version = 1, java_api_version = 2, js_api_version = 2, protodeps = [ @@ -2339,6 +2403,7 @@ FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [ "framework/op_segment.h", "framework/rendezvous.h", # only needed for tests + "framework/resource_var.h", "framework/tensor_reference.h", "framework/tracking_allocator.h", # only needed for tests "framework/unique_tensor_references.h", @@ -2416,11 +2481,11 @@ tf_cuda_library( "framework/resource_handle.cc", "util/memmapped_file_system.*", "util/memmapped_file_system_writer.*", + "util/stats_calculator.*", "util/version_info.cc", ], ) + select({ "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "util/memmapped_file_system.cc", "util/memmapped_file_system_writer.cc", @@ -2429,19 +2494,20 @@ tf_cuda_library( hdrs = FRAMEWORK_INTERNAL_PUBLIC_HEADERS, copts = tf_copts(), linkopts = select({ - "//tensorflow:freebsd": [], + "//tensorflow:freebsd": ["-lm"], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], - "//conditions:default": ["-ldl"], - }) + [ - "-lm", - ], + "//conditions:default": [ + "-ldl", + "-lm", + ], + }), deps = [ ":lib", ":lib_internal", ":protos_all_proto_text", ":error_codes_proto_text", ":protos_all_cc", + ":stats_calculator_portable", ":version_lib", "//tensorflow/core/platform/default/build_config:platformlib", "//tensorflow/core/kernels:bounds_check", @@ -2449,12 +2515,7 @@ tf_cuda_library( ] + if_static( extra_deps = ["@protobuf_archive//:protobuf"], otherwise = ["@protobuf_archive//:protobuf_headers"], - ) + if_mkl( - [ - "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ], - ), + ) + mkl_deps(), alwayslink = 1, ) @@ -2657,6 +2718,8 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/step_stats_collector.h", "common_runtime/threadpool_device.h", "common_runtime/visitable_allocator.h", + "common_runtime/process_state.h", + "common_runtime/pool_allocator.h", "graph/gradients.h", "graph/quantize_training.h", ] + if_mkl(["graph/mkl_graph_util.h"]) @@ -2695,7 +2758,9 @@ tf_cuda_library( "common_runtime/optimization_registry.cc", "common_runtime/parallel_concat_optimizer.cc", "common_runtime/placer.cc", + "common_runtime/pool_allocator.cc", "common_runtime/process_function_library_runtime.cc", + "common_runtime/process_state.cc", "common_runtime/process_util.cc", "common_runtime/renamed_device.cc", "common_runtime/rendezvous_mgr.cc", @@ -2731,12 +2796,7 @@ tf_cuda_library( ":protos_all_cc", "//third_party/eigen3", "//tensorflow/core/grappler:grappler_item", - ] + if_mkl( - [ - "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ], - ), + ] + mkl_deps(), alwayslink = 1, ) @@ -2776,12 +2836,7 @@ tf_cuda_library( "//tensorflow/core/grappler/optimizers:meta_optimizer", "//third_party/eigen3", "//tensorflow/core/kernels:required", - ] + if_mkl( - [ - "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ], - ) + tf_additional_core_deps() + if_static([":core_cpu_impl"]), + ] + mkl_deps() + tf_additional_core_deps() + if_static([":core_cpu_impl"]), alwayslink = 1, ) @@ -2856,6 +2911,14 @@ tf_cuda_library( ] + tf_additional_device_tracer_deps(), ) +cc_library( + name = "session_ref", + srcs = ["common_runtime/session_ref.cc"], + hdrs = ["common_runtime/session_ref.h"], + copts = tf_copts(), + deps = [":core_cpu_base"], +) + cc_library( name = "gpu_id", hdrs = [ @@ -2882,6 +2945,7 @@ cc_library( ) GPU_RUNTIME_HEADERS = [ + "common_runtime/gpu/cuda_host_allocator.h", "common_runtime/gpu/gpu_bfc_allocator.h", "common_runtime/gpu/gpu_cudamalloc_allocator.h", "common_runtime/gpu/gpu_debug_allocator.h", @@ -2891,10 +2955,9 @@ GPU_RUNTIME_HEADERS = [ "common_runtime/gpu/gpu_id_utils.h", "common_runtime/gpu/gpu_init.h", "common_runtime/gpu/gpu_managed_allocator.h", + "common_runtime/gpu/gpu_process_state.h", "common_runtime/gpu/gpu_stream_util.h", "common_runtime/gpu/gpu_util.h", - "common_runtime/gpu/pool_allocator.h", - "common_runtime/gpu/process_state.h", "common_runtime/gpu_device_context.h", ] @@ -2907,11 +2970,10 @@ tf_cuda_library( "common_runtime/gpu/gpu_device.cc", "common_runtime/gpu/gpu_device_factory.cc", "common_runtime/gpu/gpu_managed_allocator.cc", + "common_runtime/gpu/gpu_process_state.cc", "common_runtime/gpu/gpu_stream_util.cc", "common_runtime/gpu/gpu_util.cc", "common_runtime/gpu/gpu_util_platform_specific.cc", - "common_runtime/gpu/pool_allocator.cc", - "common_runtime/gpu/process_state.cc", ], hdrs = GPU_RUNTIME_HEADERS, copts = tf_copts(), @@ -3067,7 +3129,10 @@ cc_library( testonly = 1, srcs = ["platform/test_main.cc"], copts = tf_copts(), - linkopts = ["-lm"], + linkopts = select({ + "//tensorflow:windows": [], + "//conditions:default": ["-lm"], + }), visibility = ["//tensorflow:internal"], deps = [ ":lib", @@ -3158,6 +3223,7 @@ tf_cc_tests( "platform/fingerprint_test.cc", "platform/integral_types_test.cc", "platform/logging_test.cc", + "platform/mutex_test.cc", "platform/net_test.cc", "platform/port_test.cc", "platform/profile_utils/cpu_utils_test.cc", @@ -3172,6 +3238,7 @@ tf_cc_tests( ":test", ":test_main", "//third_party/eigen3", + "@zlib_archive//:zlib", ], ) @@ -3221,6 +3288,28 @@ tf_cc_test( ], ) +tf_cc_test( + name = "platform_numa_test", + size = "small", + srcs = ["platform/numa_test.cc"], + tags = [ + # This test will not pass unless it has access to all NUMA nodes + # on the executing machine. + "manual", + "notap", + ], + deps = [ + ":framework", + ":lib", + ":lib_internal", + ":lib_test_internal", + ":protos_all_cc", + ":test", + ":test_main", + "//third_party/eigen3", + ], +) + tf_cc_test( name = "platform_setround_test", size = "small", @@ -3369,10 +3458,12 @@ tf_cc_tests( "framework/bfloat16_test.cc", "framework/cancellation_test.cc", "framework/common_shape_fns_test.cc", + "framework/device_base_test.cc", "framework/function_test.cc", "framework/graph_def_util_test.cc", "framework/graph_to_functiondef_test.cc", "framework/kernel_def_builder_test.cc", + "framework/kernel_def_util_test.cc", "framework/memory_types_test.cc", "framework/node_def_builder_test.cc", "framework/node_def_util_test.cc", @@ -3390,6 +3481,7 @@ tf_cc_tests( "framework/tensor_shape_test.cc", "framework/tensor_slice_test.cc", "framework/tensor_test.cc", + "framework/tensor_testutil_test.cc", "framework/tensor_util_test.cc", "framework/tracking_allocator_test.cc", "framework/types_test.cc", @@ -3397,6 +3489,7 @@ tf_cc_tests( "framework/variant_op_registry_test.cc", "framework/variant_test.cc", "graph/algorithm_test.cc", + "graph/control_flow_test.cc", "graph/edgeset_test.cc", "graph/graph_def_builder_test.cc", "graph/graph_partition_test.cc", @@ -3421,6 +3514,7 @@ tf_cc_tests( "util/semver_test.cc", "util/sparse/sparse_tensor_test.cc", "util/stat_summarizer_test.cc", + "util/status_util_test.cc", "util/tensor_format_test.cc", "util/tensor_slice_reader_test.cc", "util/tensor_slice_set_test.cc", @@ -3445,6 +3539,7 @@ tf_cc_tests( ":ops", ":protos_all_cc", ":protos_test_cc", + ":status_util", ":test", ":test_main", ":testlib", @@ -3580,6 +3675,7 @@ tf_cc_test_mkl( deps = [ ":core", ":core_cpu", + ":core_cpu_internal", ":framework", ":framework_internal", ":test", @@ -3653,7 +3749,6 @@ tf_cc_tests_gpu( "common_runtime/gpu/gpu_bfc_allocator_test.cc", "common_runtime/gpu/gpu_device_test.cc", "common_runtime/gpu/gpu_id_manager_test.cc", - "common_runtime/gpu/gpu_event_mgr_test.cc", "common_runtime/gpu/pool_allocator_test.cc", ], linkstatic = tf_kernel_tests_linkstatic(), @@ -3677,6 +3772,23 @@ tf_cc_tests_gpu( ], ) +tf_cc_test_gpu( + name = "gpu_event_mgr_test", + srcs = ["common_runtime/gpu/gpu_event_mgr_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":protos_all_cc", + ":test", + ":test_main", + ":testlib", + ], +) + tf_cuda_cc_test( name = "gpu_device_unified_memory_test", size = "small", @@ -3731,11 +3843,7 @@ tf_cuda_only_cc_test( ":test", ":test_main", "//third_party/eigen3", - ] + if_mkl( - [ - "//third_party/mkl:intel_binary_blob", - ], - ), + ] + mkl_deps(), ) tf_cc_test_gpu( @@ -3903,13 +4011,13 @@ tf_cc_test( ], ) -tf_cc_test( +tf_cuda_cc_test( name = "common_runtime_direct_session_test", size = "small", srcs = ["common_runtime/direct_session_test.cc"], + args = [] + if_cuda(["--heap_check=local"]), # The GPU tracer leaks memory linkstatic = tf_kernel_tests_linkstatic(), deps = [ - ":core", ":core_cpu", ":core_cpu_internal", ":direct_session_internal", @@ -3922,6 +4030,7 @@ tf_cc_test( ":test", ":test_main", ":testlib", + "//third_party/eigen3", "//tensorflow/cc:cc_ops", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:cwise_op", @@ -3935,8 +4044,7 @@ tf_cc_test( "//tensorflow/core/kernels:queue_ops", "//tensorflow/core/kernels:session_ops", "//tensorflow/core/kernels:variable_ops", - "//third_party/eigen3", - ], + ] + if_cuda([":cuda"]), ) # This is identical to :common_runtime_direct_session_test with the addition of @@ -4456,6 +4564,8 @@ filegroup( # PNG data "lib/png/testdata/lena_gray.png", "lib/png/testdata/lena_rgba.png", + "lib/png/testdata/lena_palette.png", + "lib/png/testdata/lena_palette_trns.png", # JPEG data "lib/jpeg/testdata/jpeg_merge_test1.jpg", "lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg", diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index 19d643880966f7607405539a5ad43d8e03dc13fb..06b797e32edc046bab498f8d775040d57ef62ce9 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -4,6 +4,7 @@ # The following targets can be used to access ApiDefs: # :base_api_def # :python_api_def +# :java_api_def package( default_visibility = ["//visibility:private"], @@ -29,6 +30,12 @@ filegroup( visibility = ["//tensorflow:internal"], ) +filegroup( + name = "java_api_def", + srcs = glob(["java_api/*"]), + visibility = ["//tensorflow:internal"], +) + cc_library( name = "excluded_ops_lib", srcs = ["excluded_ops.cc"], diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index 477a0b670e49f8aa4ee8c250d4957886eb865ed5..51812caeb2979270c913adee4fba2ce02f9c4d0e 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -59,8 +59,8 @@ void GetGoldenApiDefs(Env* env, const string& api_files_dir, file_contents = PBTxtFromMultiline(file_contents); ApiDefs api_defs; - CHECK(tensorflow::protobuf::TextFormat::ParseFromString(file_contents, - &api_defs)) + QCHECK(tensorflow::protobuf::TextFormat::ParseFromString(file_contents, + &api_defs)) << "Failed to load " << file_path; CHECK_EQ(api_defs.op_size(), 1); (*name_to_api_def)[api_defs.op(0).graph_op_name()] = api_defs.op(0); @@ -149,6 +149,33 @@ void TestAllApiDefAttributeNamesAreValid( } } } + +void TestDeprecatedAttributesSetCorrectly( + const std::unordered_map& api_defs_map) { + for (const auto& name_and_api_def : api_defs_map) { + int num_deprecated_endpoints = 0; + const auto& api_def = name_and_api_def.second; + for (const auto& endpoint : api_def.endpoint()) { + if (endpoint.deprecated()) { + ++num_deprecated_endpoints; + } + } + + const auto& name = name_and_api_def.first; + ASSERT_TRUE(api_def.deprecation_message().empty() || + num_deprecated_endpoints == 0) + << "Endpoints are set to 'deprecated' for deprecated op " << name + << ". If an op is deprecated (i.e. deprecation_message is set), " + << "all the endpoints are deprecated implicitly and 'deprecated' " + << "field should not be set."; + if (num_deprecated_endpoints > 0) { + ASSERT_NE(num_deprecated_endpoints, api_def.endpoint_size()) + << "All " << name << " endpoints are deprecated. Please, set " + << "deprecation_message in api_def_" << name << ".pbtxt instead. " + << "to indicate that the op is deprecated."; + } + } +} } // namespace class BaseApiTest : public ::testing::Test { @@ -171,7 +198,7 @@ TEST_F(BaseApiTest, AllOpsAreInApiDef) { if (excluded_ops->find(op.name()) != excluded_ops->end()) { continue; } - ASSERT_TRUE(api_defs_map_.find(op.name()) != api_defs_map_.end()) + EXPECT_TRUE(api_defs_map_.find(op.name()) != api_defs_map_.end()) << op.name() << " op does not have api_def_*.pbtxt file. " << "Please add api_def_" << op.name() << ".pbtxt file " << "under tensorflow/core/api_def/base_api/ directory."; @@ -236,6 +263,11 @@ TEST_F(BaseApiTest, AllApiDefAttributeNamesAreValid) { TestAllApiDefAttributeNamesAreValid(ops_, api_defs_map_); } +// Checks that deprecation is set correctly. +TEST_F(BaseApiTest, DeprecationSetCorrectly) { + TestDeprecatedAttributesSetCorrectly(api_defs_map_); +} + class PythonApiTest : public ::testing::Test { protected: PythonApiTest() { @@ -272,4 +304,9 @@ TEST_F(PythonApiTest, AllApiDefAttributeNamesAreValid) { TestAllApiDefAttributeNamesAreValid(ops_, api_defs_map_); } +// Checks that deprecation is set correctly. +TEST_F(PythonApiTest, DeprecationSetCorrectly) { + TestDeprecatedAttributesSetCorrectly(api_defs_map_); +} + } // namespace tensorflow diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCenterBias.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCenterBias.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b58b974eb4e43b49d6630449de1a0a6c37a15859 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCenterBias.pbtxt @@ -0,0 +1,41 @@ +op { + graph_op_name: "BoostedTreesCenterBias" + visibility: HIDDEN + in_arg { + name: "tree_ensemble_handle" + description: <
The helpers in `tf.train` that create these queues and enqueuing operations add -a @{tf.train.QueueRunner} to the +a `tf.train.QueueRunner` to the graph using the -@{tf.train.add_queue_runner} +`tf.train.add_queue_runner` function. Each `QueueRunner` is responsible for one stage, and holds the list of enqueue operations that need to be run in threads. Once the graph is constructed, the -@{tf.train.start_queue_runners} +`tf.train.start_queue_runners` function asks each QueueRunner in the graph to start its threads running the enqueuing operations. If all goes well, you can now run your training steps and the queues will be filled by the background threads. If you have set an epoch limit, at some point an attempt to dequeue examples will get an -@{tf.errors.OutOfRangeError}. This +`tf.errors.OutOfRangeError`. This is the TensorFlow equivalent of "end of file" (EOF) -- this means the epoch limit has been reached and no more examples are available. The last ingredient is the -@{tf.train.Coordinator}. This is responsible +`tf.train.Coordinator`. This is responsible for letting all the threads know if anything has signaled a shut down. Most commonly this would be because an exception was raised, for example one of the threads got an error when running some operation (or an ordinary Python @@ -396,21 +396,21 @@ associated with a single QueueRunner. If this isn't the last thread in the QueueRunner, the `OutOfRange` error just causes the one thread to exit. This allows the other threads, which are still finishing up their last file, to proceed until they finish as well. (Assuming you are using a -@{tf.train.Coordinator}, +`tf.train.Coordinator`, other types of errors will cause all the threads to stop.) Once all the reader threads hit the `OutOfRange` error, only then does the next queue, the example queue, gets closed. Again, the example queue will have some elements queued, so training will continue until those are exhausted. If the example queue is a -@{tf.RandomShuffleQueue}, say +`tf.RandomShuffleQueue`, say because you are using `shuffle_batch` or `shuffle_batch_join`, it normally will avoid ever having fewer than its `min_after_dequeue` attr elements buffered. However, once the queue is closed that restriction will be lifted and the queue will eventually empty. At that point the actual training threads, when they try and dequeue from example queue, will start getting `OutOfRange` errors and exiting. Once all the training threads are done, -@{tf.train.Coordinator.join} +`tf.train.Coordinator.join` will return and you can exit cleanly. ### Filtering records or producing multiple examples per record @@ -426,7 +426,7 @@ when calling one of the batching functions (such as `shuffle_batch` or SparseTensors don't play well with queues. If you use SparseTensors you have to decode the string records using -@{tf.parse_example} **after** +`tf.parse_example` **after** batching (instead of using `tf.parse_single_example` before batching). ## Preloaded data @@ -475,11 +475,11 @@ update it when training. Setting `collections=[]` keeps the variable out of the `GraphKeys.GLOBAL_VARIABLES` collection used for saving and restoring checkpoints. Either way, -@{tf.train.slice_input_producer} +`tf.train.slice_input_producer` can be used to produce a slice at a time. This shuffles the examples across an entire epoch, so further shuffling when batching is undesirable. So instead of using the `shuffle_batch` functions, we use the plain -@{tf.train.batch} function. To use +`tf.train.batch` function. To use multiple preprocessing threads, set the `num_threads` parameter to a number bigger than 1. @@ -500,7 +500,7 @@ sessions, maybe in separate processes: * The evaluation process restores the checkpoint files into an inference model that reads validation input data. -This is what is done @{tf.estimator$estimators} and manually in +This is what is done `tf.estimator` and manually in @{$deep_cnn#save-and-restore-checkpoints$the example CIFAR-10 model}. This has a couple of benefits: @@ -511,12 +511,12 @@ You can have the train and eval in the same graph in the same process, and share their trained variables or layers. See @{$variables$the shared variables tutorial}. To support the single-graph approach -@{$programmers_guide/datasets$`tf.data`} also supplies -@{$programmers_guide/datasets#creating_an_iterator$advanced iterator types} that +@{$guide/datasets$`tf.data`} also supplies +@{$guide/datasets#creating_an_iterator$advanced iterator types} that that allow the user to change the input pipeline without rebuilding the graph or session. Note: Regardless of the implementation, many -operations (like @{tf.layers.batch_normalization}, and @{tf.layers.dropout}) +operations (like `tf.layers.batch_normalization`, and `tf.layers.dropout`) need to know if they are in training or evaluation mode, and you must be careful to set this appropriately if you change the data source. diff --git a/tensorflow/docs_src/api_guides/python/regression_examples.md b/tensorflow/docs_src/api_guides/python/regression_examples.md index 7de2be05521d9293e33664cdbbd7bf16b9ad7c52..f8abbf0f9741e379cd628f0ee3cf18fdb8152a0b 100644 --- a/tensorflow/docs_src/api_guides/python/regression_examples.md +++ b/tensorflow/docs_src/api_guides/python/regression_examples.md @@ -8,25 +8,25 @@ to implement regression in Estimators: linear_regression.py - Use the @{tf.estimator.LinearRegressor} Estimator to train a + Use the `tf.estimator.LinearRegressor` Estimator to train a regression model on numeric data. linear_regression_categorical.py - Use the @{tf.estimator.LinearRegressor} Estimator to train a + Use the `tf.estimator.LinearRegressor` Estimator to train a regression model on categorical data. dnn_regression.py - Use the @{tf.estimator.DNNRegressor} Estimator to train a + Use the `tf.estimator.DNNRegressor` Estimator to train a regression model on discrete data with a deep neural network. custom_regression.py - Use @{tf.estimator.Estimator} to train a customized dnn + Use `tf.estimator.Estimator` to train a customized dnn regression model. @@ -219,7 +219,7 @@ The `custom_regression.py` example also trains a model that predicts the price of a car based on mixed real-valued and categorical input features, described by feature_columns. Unlike `linear_regression_categorical.py`, and `dnn_regression.py` this example does not use a pre-made estimator, but defines -a custom model using the base @{tf.estimator.Estimator$`Estimator`} class. The +a custom model using the base `tf.estimator.Estimator` class. The custom model is quite similar to the model defined by `dnn_regression.py`. The custom model is defined by the `model_fn` argument to the constructor. The @@ -227,6 +227,6 @@ customization is made more reusable through `params` dictionary, which is later passed through to the `model_fn` when the `model_fn` is called. The `model_fn` returns an -@{tf.estimator.EstimatorSpec$`EstimatorSpec`} which is a simple structure +`tf.estimator.EstimatorSpec` which is a simple structure indicating to the `Estimator` which operations should be run to accomplish various tasks. diff --git a/tensorflow/docs_src/api_guides/python/session_ops.md b/tensorflow/docs_src/api_guides/python/session_ops.md index 5176e3549c38e07d789401c5e684c16449d84a8a..5f41bcf209b13b4f3a4a14322cf20e82cc3d27d8 100644 --- a/tensorflow/docs_src/api_guides/python/session_ops.md +++ b/tensorflow/docs_src/api_guides/python/session_ops.md @@ -1,7 +1,7 @@ # Tensor Handle Operations Note: Functions taking `Tensor` arguments can also take anything accepted by -@{tf.convert_to_tensor}. +`tf.convert_to_tensor`. [TOC] @@ -10,6 +10,6 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by TensorFlow provides several operators that allows the user to keep tensors "in-place" across run calls. -* @{tf.get_session_handle} -* @{tf.get_session_tensor} -* @{tf.delete_session_tensor} +* `tf.get_session_handle` +* `tf.get_session_tensor` +* `tf.delete_session_tensor` diff --git a/tensorflow/docs_src/api_guides/python/sparse_ops.md b/tensorflow/docs_src/api_guides/python/sparse_ops.md index 19d5faba05a6ac79229b721ab6e45e4e36fd9f7a..b360055ed0ed0cde59a68c89f0a0f4ae1d5758ab 100644 --- a/tensorflow/docs_src/api_guides/python/sparse_ops.md +++ b/tensorflow/docs_src/api_guides/python/sparse_ops.md @@ -1,7 +1,7 @@ # Sparse Tensors Note: Functions taking `Tensor` arguments can also take anything accepted by -@{tf.convert_to_tensor}. +`tf.convert_to_tensor`. [TOC] @@ -12,34 +12,34 @@ in multiple dimensions. Contrast this representation with `IndexedSlices`, which is efficient for representing tensors that are sparse in their first dimension, and dense along all other dimensions. -* @{tf.SparseTensor} -* @{tf.SparseTensorValue} +* `tf.SparseTensor` +* `tf.SparseTensorValue` ## Conversion -* @{tf.sparse_to_dense} -* @{tf.sparse_tensor_to_dense} -* @{tf.sparse_to_indicator} -* @{tf.sparse_merge} +* `tf.sparse_to_dense` +* `tf.sparse_tensor_to_dense` +* `tf.sparse_to_indicator` +* `tf.sparse_merge` ## Manipulation -* @{tf.sparse_concat} -* @{tf.sparse_reorder} -* @{tf.sparse_reshape} -* @{tf.sparse_split} -* @{tf.sparse_retain} -* @{tf.sparse_reset_shape} -* @{tf.sparse_fill_empty_rows} -* @{tf.sparse_transpose} +* `tf.sparse_concat` +* `tf.sparse_reorder` +* `tf.sparse_reshape` +* `tf.sparse_split` +* `tf.sparse_retain` +* `tf.sparse_reset_shape` +* `tf.sparse_fill_empty_rows` +* `tf.sparse_transpose` ## Reduction -* @{tf.sparse_reduce_sum} -* @{tf.sparse_reduce_sum_sparse} +* `tf.sparse_reduce_sum` +* `tf.sparse_reduce_sum_sparse` ## Math Operations -* @{tf.sparse_add} -* @{tf.sparse_softmax} -* @{tf.sparse_tensor_dense_matmul} -* @{tf.sparse_maximum} -* @{tf.sparse_minimum} +* `tf.sparse_add` +* `tf.sparse_softmax` +* `tf.sparse_tensor_dense_matmul` +* `tf.sparse_maximum` +* `tf.sparse_minimum` diff --git a/tensorflow/docs_src/api_guides/python/spectral_ops.md b/tensorflow/docs_src/api_guides/python/spectral_ops.md index 022c471ef10ff9b37311dcc0c114507376f6b6c4..f6d109a3a080b467eb8606f36671b449fb6e5c4d 100644 --- a/tensorflow/docs_src/api_guides/python/spectral_ops.md +++ b/tensorflow/docs_src/api_guides/python/spectral_ops.md @@ -2,24 +2,25 @@ [TOC] -The @{tf.spectral} module supports several spectral decomposition operations +The `tf.spectral` module supports several spectral decomposition operations that you can use to transform Tensors of real and complex signals. ## Discrete Fourier Transforms -* @{tf.spectral.fft} -* @{tf.spectral.ifft} -* @{tf.spectral.fft2d} -* @{tf.spectral.ifft2d} -* @{tf.spectral.fft3d} -* @{tf.spectral.ifft3d} -* @{tf.spectral.rfft} -* @{tf.spectral.irfft} -* @{tf.spectral.rfft2d} -* @{tf.spectral.irfft2d} -* @{tf.spectral.rfft3d} -* @{tf.spectral.irfft3d} +* `tf.spectral.fft` +* `tf.spectral.ifft` +* `tf.spectral.fft2d` +* `tf.spectral.ifft2d` +* `tf.spectral.fft3d` +* `tf.spectral.ifft3d` +* `tf.spectral.rfft` +* `tf.spectral.irfft` +* `tf.spectral.rfft2d` +* `tf.spectral.irfft2d` +* `tf.spectral.rfft3d` +* `tf.spectral.irfft3d` ## Discrete Cosine Transforms -* @{tf.spectral.dct} +* `tf.spectral.dct` +* `tf.spectral.idct` diff --git a/tensorflow/docs_src/api_guides/python/state_ops.md b/tensorflow/docs_src/api_guides/python/state_ops.md index ec2d8773860f0595cabe91d591a5fdc025e99b83..fc55ea14813ef0a20b0a30fbb35888777c5f152f 100644 --- a/tensorflow/docs_src/api_guides/python/state_ops.md +++ b/tensorflow/docs_src/api_guides/python/state_ops.md @@ -1,68 +1,68 @@ # Variables Note: Functions taking `Tensor` arguments can also take anything accepted by -@{tf.convert_to_tensor}. +`tf.convert_to_tensor`. [TOC] ## Variables -* @{tf.Variable} +* `tf.Variable` ## Variable helper functions TensorFlow provides a set of functions to help manage the set of variables collected in the graph. -* @{tf.global_variables} -* @{tf.local_variables} -* @{tf.model_variables} -* @{tf.trainable_variables} -* @{tf.moving_average_variables} -* @{tf.global_variables_initializer} -* @{tf.local_variables_initializer} -* @{tf.variables_initializer} -* @{tf.is_variable_initialized} -* @{tf.report_uninitialized_variables} -* @{tf.assert_variables_initialized} -* @{tf.assign} -* @{tf.assign_add} -* @{tf.assign_sub} +* `tf.global_variables` +* `tf.local_variables` +* `tf.model_variables` +* `tf.trainable_variables` +* `tf.moving_average_variables` +* `tf.global_variables_initializer` +* `tf.local_variables_initializer` +* `tf.variables_initializer` +* `tf.is_variable_initialized` +* `tf.report_uninitialized_variables` +* `tf.assert_variables_initialized` +* `tf.assign` +* `tf.assign_add` +* `tf.assign_sub` ## Saving and Restoring Variables -* @{tf.train.Saver} -* @{tf.train.latest_checkpoint} -* @{tf.train.get_checkpoint_state} -* @{tf.train.update_checkpoint_state} +* `tf.train.Saver` +* `tf.train.latest_checkpoint` +* `tf.train.get_checkpoint_state` +* `tf.train.update_checkpoint_state` ## Sharing Variables TensorFlow provides several classes and operations that you can use to create variables contingent on certain conditions. -* @{tf.get_variable} -* @{tf.get_local_variable} -* @{tf.VariableScope} -* @{tf.variable_scope} -* @{tf.variable_op_scope} -* @{tf.get_variable_scope} -* @{tf.make_template} -* @{tf.no_regularizer} -* @{tf.constant_initializer} -* @{tf.random_normal_initializer} -* @{tf.truncated_normal_initializer} -* @{tf.random_uniform_initializer} -* @{tf.uniform_unit_scaling_initializer} -* @{tf.zeros_initializer} -* @{tf.ones_initializer} -* @{tf.orthogonal_initializer} +* `tf.get_variable` +* `tf.get_local_variable` +* `tf.VariableScope` +* `tf.variable_scope` +* `tf.variable_op_scope` +* `tf.get_variable_scope` +* `tf.make_template` +* `tf.no_regularizer` +* `tf.constant_initializer` +* `tf.random_normal_initializer` +* `tf.truncated_normal_initializer` +* `tf.random_uniform_initializer` +* `tf.uniform_unit_scaling_initializer` +* `tf.zeros_initializer` +* `tf.ones_initializer` +* `tf.orthogonal_initializer` ## Variable Partitioners for Sharding -* @{tf.fixed_size_partitioner} -* @{tf.variable_axis_size_partitioner} -* @{tf.min_max_variable_partitioner} +* `tf.fixed_size_partitioner` +* `tf.variable_axis_size_partitioner` +* `tf.min_max_variable_partitioner` ## Sparse Variable Updates @@ -73,38 +73,38 @@ only a small subset of embedding vectors change in any given step. Since a sparse update of a large tensor may be generated automatically during gradient computation (as in the gradient of -@{tf.gather}), -an @{tf.IndexedSlices} class is provided that encapsulates a set +`tf.gather`), +an `tf.IndexedSlices` class is provided that encapsulates a set of sparse indices and values. `IndexedSlices` objects are detected and handled automatically by the optimizers in most cases. -* @{tf.scatter_update} -* @{tf.scatter_add} -* @{tf.scatter_sub} -* @{tf.scatter_mul} -* @{tf.scatter_div} -* @{tf.scatter_min} -* @{tf.scatter_max} -* @{tf.scatter_nd_update} -* @{tf.scatter_nd_add} -* @{tf.scatter_nd_sub} -* @{tf.sparse_mask} -* @{tf.IndexedSlices} +* `tf.scatter_update` +* `tf.scatter_add` +* `tf.scatter_sub` +* `tf.scatter_mul` +* `tf.scatter_div` +* `tf.scatter_min` +* `tf.scatter_max` +* `tf.scatter_nd_update` +* `tf.scatter_nd_add` +* `tf.scatter_nd_sub` +* `tf.sparse_mask` +* `tf.IndexedSlices` ### Read-only Lookup Tables -* @{tf.initialize_all_tables} -* @{tf.tables_initializer} +* `tf.initialize_all_tables` +* `tf.tables_initializer` ## Exporting and Importing Meta Graphs -* @{tf.train.export_meta_graph} -* @{tf.train.import_meta_graph} +* `tf.train.export_meta_graph` +* `tf.train.import_meta_graph` # Deprecated functions (removed after 2017-03-02). Please don't use them. -* @{tf.all_variables} -* @{tf.initialize_all_variables} -* @{tf.initialize_local_variables} -* @{tf.initialize_variables} +* `tf.all_variables` +* `tf.initialize_all_variables` +* `tf.initialize_local_variables` +* `tf.initialize_variables` diff --git a/tensorflow/docs_src/api_guides/python/string_ops.md b/tensorflow/docs_src/api_guides/python/string_ops.md index e9be4f156a9b40fac41dfee16e3265464e940d7e..24a3aad642d16eaef25f427ae0223b884ef884d7 100644 --- a/tensorflow/docs_src/api_guides/python/string_ops.md +++ b/tensorflow/docs_src/api_guides/python/string_ops.md @@ -1,7 +1,7 @@ # Strings Note: Functions taking `Tensor` arguments can also take anything accepted by -@{tf.convert_to_tensor}. +`tf.convert_to_tensor`. [TOC] @@ -10,30 +10,30 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by String hashing ops take a string input tensor and map each element to an integer. -* @{tf.string_to_hash_bucket_fast} -* @{tf.string_to_hash_bucket_strong} -* @{tf.string_to_hash_bucket} +* `tf.string_to_hash_bucket_fast` +* `tf.string_to_hash_bucket_strong` +* `tf.string_to_hash_bucket` ## Joining String joining ops concatenate elements of input string tensors to produce a new string tensor. -* @{tf.reduce_join} -* @{tf.string_join} +* `tf.reduce_join` +* `tf.string_join` ## Splitting -* @{tf.string_split} -* @{tf.substr} +* `tf.string_split` +* `tf.substr` ## Conversion -* @{tf.as_string} -* @{tf.string_to_number} +* `tf.as_string` +* `tf.string_to_number` -* @{tf.decode_raw} -* @{tf.decode_csv} +* `tf.decode_raw` +* `tf.decode_csv` -* @{tf.encode_base64} -* @{tf.decode_base64} +* `tf.encode_base64` +* `tf.decode_base64` diff --git a/tensorflow/docs_src/api_guides/python/summary.md b/tensorflow/docs_src/api_guides/python/summary.md index eda119ab24edf2caeb6d2de01abc541b590289f4..e290703b7d844504291bd3f6fc9819f7e6782d45 100644 --- a/tensorflow/docs_src/api_guides/python/summary.md +++ b/tensorflow/docs_src/api_guides/python/summary.md @@ -7,17 +7,17 @@ then accessible in tools such as @{$summaries_and_tensorboard$TensorBoard}. ## Generation of Summaries ### Class for writing Summaries -* @{tf.summary.FileWriter} -* @{tf.summary.FileWriterCache} +* `tf.summary.FileWriter` +* `tf.summary.FileWriterCache` ### Summary Ops -* @{tf.summary.tensor_summary} -* @{tf.summary.scalar} -* @{tf.summary.histogram} -* @{tf.summary.audio} -* @{tf.summary.image} -* @{tf.summary.merge} -* @{tf.summary.merge_all} +* `tf.summary.tensor_summary` +* `tf.summary.scalar` +* `tf.summary.histogram` +* `tf.summary.audio` +* `tf.summary.image` +* `tf.summary.merge` +* `tf.summary.merge_all` ## Utilities -* @{tf.summary.get_summary_description} +* `tf.summary.get_summary_description` diff --git a/tensorflow/docs_src/api_guides/python/test.md b/tensorflow/docs_src/api_guides/python/test.md index 5dc88124e7e1c26237c5c150b624486ab0df1283..b6e0a332b9d2e906af96d36d4ef856199e485a05 100644 --- a/tensorflow/docs_src/api_guides/python/test.md +++ b/tensorflow/docs_src/api_guides/python/test.md @@ -23,25 +23,25 @@ which adds methods relevant to TensorFlow tests. Here is an example: ``` `tf.test.TestCase` inherits from `unittest.TestCase` but adds a few additional -methods. See @{tf.test.TestCase} for details. +methods. See `tf.test.TestCase` for details. -* @{tf.test.main} -* @{tf.test.TestCase} -* @{tf.test.test_src_dir_path} +* `tf.test.main` +* `tf.test.TestCase` +* `tf.test.test_src_dir_path` ## Utilities Note: `tf.test.mock` is an alias to the python `mock` or `unittest.mock` depending on the python version. -* @{tf.test.assert_equal_graph_def} -* @{tf.test.get_temp_dir} -* @{tf.test.is_built_with_cuda} -* @{tf.test.is_gpu_available} -* @{tf.test.gpu_device_name} +* `tf.test.assert_equal_graph_def` +* `tf.test.get_temp_dir` +* `tf.test.is_built_with_cuda` +* `tf.test.is_gpu_available` +* `tf.test.gpu_device_name` ## Gradient checking -@{tf.test.compute_gradient} and @{tf.test.compute_gradient_error} perform +`tf.test.compute_gradient` and `tf.test.compute_gradient_error` perform numerical differentiation of graphs for comparison against registered analytic gradients. diff --git a/tensorflow/docs_src/api_guides/python/tfdbg.md b/tensorflow/docs_src/api_guides/python/tfdbg.md index 2212a2da0e8c4f339120453c15d5b61b4574f8ee..9778cdc0b0a6bdf4acecce95e19deb99490d669e 100644 --- a/tensorflow/docs_src/api_guides/python/tfdbg.md +++ b/tensorflow/docs_src/api_guides/python/tfdbg.md @@ -8,9 +8,9 @@ Public Python API of TensorFlow Debugger (tfdbg). These functions help you modify `RunOptions` to specify which `Tensor`s are to be watched when the TensorFlow graph is executed at runtime. -* @{tfdbg.add_debug_tensor_watch} -* @{tfdbg.watch_graph} -* @{tfdbg.watch_graph_with_blacklists} +* `tfdbg.add_debug_tensor_watch` +* `tfdbg.watch_graph` +* `tfdbg.watch_graph_with_blacklists` ## Classes for debug-dump data and directories @@ -18,13 +18,13 @@ be watched when the TensorFlow graph is executed at runtime. These classes allow you to load and inspect tensor values dumped from TensorFlow graphs during runtime. -* @{tfdbg.DebugTensorDatum} -* @{tfdbg.DebugDumpDir} +* `tfdbg.DebugTensorDatum` +* `tfdbg.DebugDumpDir` ## Functions for loading debug-dump data -* @{tfdbg.load_tensor_from_event_file} +* `tfdbg.load_tensor_from_event_file` ## Tensor-value predicates @@ -32,7 +32,7 @@ TensorFlow graphs during runtime. Built-in tensor-filter predicates to support conditional breakpoint between runs. See `DebugDumpDir.find()` for more details. -* @{tfdbg.has_inf_or_nan} +* `tfdbg.has_inf_or_nan` ## Session wrapper class and `SessionRunHook` implementations @@ -44,7 +44,7 @@ These classes allow you to * generate `SessionRunHook` objects to debug `tf.contrib.learn` models (see `DumpingDebugHook` and `LocalCLIDebugHook`). -* @{tfdbg.DumpingDebugHook} -* @{tfdbg.DumpingDebugWrapperSession} -* @{tfdbg.LocalCLIDebugHook} -* @{tfdbg.LocalCLIDebugWrapperSession} +* `tfdbg.DumpingDebugHook` +* `tfdbg.DumpingDebugWrapperSession` +* `tfdbg.LocalCLIDebugHook` +* `tfdbg.LocalCLIDebugWrapperSession` diff --git a/tensorflow/docs_src/api_guides/python/threading_and_queues.md b/tensorflow/docs_src/api_guides/python/threading_and_queues.md index 8ad4c4c07512d04d1df43062954f2e64b1d8e177..48f0778b732919c4d70154f0200d0f065139bac3 100644 --- a/tensorflow/docs_src/api_guides/python/threading_and_queues.md +++ b/tensorflow/docs_src/api_guides/python/threading_and_queues.md @@ -25,7 +25,7 @@ longer holds, the queue will unblock the step and allow execution to proceed. TensorFlow implements several classes of queue. The principal difference between these classes is the order that items are removed from the queue. To get a feel for queues, let's consider a simple example. We will create a "first in, first -out" queue (@{tf.FIFOQueue}) and fill it with zeros. Then we'll construct a +out" queue (`tf.FIFOQueue`) and fill it with zeros. Then we'll construct a graph that takes an item off the queue, adds one to that item, and puts it back on the end of the queue. Slowly, the numbers on the queue increase. @@ -47,8 +47,8 @@ Now that you have a bit of a feel for queues, let's dive into the details... ## Queue usage overview -Queues, such as @{tf.FIFOQueue} -and @{tf.RandomShuffleQueue}, +Queues, such as `tf.FIFOQueue` +and `tf.RandomShuffleQueue`, are important TensorFlow objects that aid in computing tensors asynchronously in a graph. @@ -59,11 +59,11 @@ prepare inputs for training a model as follows: * A training thread executes a training op that dequeues mini-batches from the queue -We recommend using the @{tf.data.Dataset.shuffle$`shuffle`} -and @{tf.data.Dataset.batch$`batch`} methods of a -@{tf.data.Dataset$`Dataset`} to accomplish this. However, if you'd prefer +We recommend using the `tf.data.Dataset.shuffle` +and `tf.data.Dataset.batch` methods of a +`tf.data.Dataset` to accomplish this. However, if you'd prefer to use a queue-based version instead, you can find a full implementation in the -@{tf.train.shuffle_batch} function. +`tf.train.shuffle_batch` function. For demonstration purposes a simplified implementation is given below. @@ -93,8 +93,8 @@ def simple_shuffle_batch(source, capacity, batch_size=10): return queue.dequeue_many(batch_size) ``` -Once started by @{tf.train.start_queue_runners}, or indirectly through -@{tf.train.MonitoredSession}, the `QueueRunner` will launch the +Once started by `tf.train.start_queue_runners`, or indirectly through +`tf.train.MonitoredSession`, the `QueueRunner` will launch the threads in the background to fill the queue. Meanwhile the main thread will execute the `dequeue_many` op to pull data from it. Note how these ops do not depend on each other, except indirectly through the internal state of the queue. @@ -126,7 +126,7 @@ with tf.train.MonitoredSession() as sess: ``` For most use cases, the automatic thread startup and management provided -by @{tf.train.MonitoredSession} is sufficient. In the rare case that it is not, +by `tf.train.MonitoredSession` is sufficient. In the rare case that it is not, TensorFlow provides tools for manually managing your threads and queues. ## Manual Thread Management @@ -139,8 +139,8 @@ threads must be able to stop together, exceptions must be caught and reported, and queues must be properly closed when stopping. TensorFlow provides two classes to help: -@{tf.train.Coordinator} and -@{tf.train.QueueRunner}. These two classes +`tf.train.Coordinator` and +`tf.train.QueueRunner`. These two classes are designed to be used together. The `Coordinator` class helps multiple threads stop together and report exceptions to a program that waits for them to stop. The `QueueRunner` class is used to create a number of threads cooperating to @@ -148,14 +148,14 @@ enqueue tensors in the same queue. ### Coordinator -The @{tf.train.Coordinator} class manages background threads in a TensorFlow +The `tf.train.Coordinator` class manages background threads in a TensorFlow program and helps multiple threads stop together. Its key methods are: -* @{tf.train.Coordinator.should_stop}: returns `True` if the threads should stop. -* @{tf.train.Coordinator.request_stop}: requests that threads should stop. -* @{tf.train.Coordinator.join}: waits until the specified threads have stopped. +* `tf.train.Coordinator.should_stop`: returns `True` if the threads should stop. +* `tf.train.Coordinator.request_stop`: requests that threads should stop. +* `tf.train.Coordinator.join`: waits until the specified threads have stopped. You first create a `Coordinator` object, and then create a number of threads that use the coordinator. The threads typically run loops that stop when @@ -191,11 +191,11 @@ coord.join(threads) Obviously, the coordinator can manage threads doing very different things. They don't have to be all the same as in the example above. The coordinator -also has support to capture and report exceptions. See the @{tf.train.Coordinator} documentation for more details. +also has support to capture and report exceptions. See the `tf.train.Coordinator` documentation for more details. ### QueueRunner -The @{tf.train.QueueRunner} class creates a number of threads that repeatedly +The `tf.train.QueueRunner` class creates a number of threads that repeatedly run an enqueue op. These threads can use a coordinator to stop together. In addition, a queue runner will run a *closer operation* that closes the queue if an exception is reported to the coordinator. diff --git a/tensorflow/docs_src/api_guides/python/train.md b/tensorflow/docs_src/api_guides/python/train.md index cbc50529469b32afbb9c0646a0cfd27627563f87..a118123665e42cdee28819a86e5b24a2a106f5df 100644 --- a/tensorflow/docs_src/api_guides/python/train.md +++ b/tensorflow/docs_src/api_guides/python/train.md @@ -1,7 +1,7 @@ # Training [TOC] -@{tf.train} provides a set of classes and functions that help train models. +`tf.train` provides a set of classes and functions that help train models. ## Optimizers @@ -12,19 +12,19 @@ optimization algorithms such as GradientDescent and Adagrad. You never instantiate the Optimizer class itself, but instead instantiate one of the subclasses. -* @{tf.train.Optimizer} -* @{tf.train.GradientDescentOptimizer} -* @{tf.train.AdadeltaOptimizer} -* @{tf.train.AdagradOptimizer} -* @{tf.train.AdagradDAOptimizer} -* @{tf.train.MomentumOptimizer} -* @{tf.train.AdamOptimizer} -* @{tf.train.FtrlOptimizer} -* @{tf.train.ProximalGradientDescentOptimizer} -* @{tf.train.ProximalAdagradOptimizer} -* @{tf.train.RMSPropOptimizer} +* `tf.train.Optimizer` +* `tf.train.GradientDescentOptimizer` +* `tf.train.AdadeltaOptimizer` +* `tf.train.AdagradOptimizer` +* `tf.train.AdagradDAOptimizer` +* `tf.train.MomentumOptimizer` +* `tf.train.AdamOptimizer` +* `tf.train.FtrlOptimizer` +* `tf.train.ProximalGradientDescentOptimizer` +* `tf.train.ProximalAdagradOptimizer` +* `tf.train.RMSPropOptimizer` -See @{tf.contrib.opt} for more optimizers. +See `tf.contrib.opt` for more optimizers. ## Gradient Computation @@ -34,10 +34,10 @@ optimizer classes automatically compute derivatives on your graph, but creators of new Optimizers or expert users can call the lower-level functions below. -* @{tf.gradients} -* @{tf.AggregationMethod} -* @{tf.stop_gradient} -* @{tf.hessians} +* `tf.gradients` +* `tf.AggregationMethod` +* `tf.stop_gradient` +* `tf.hessians` ## Gradient Clipping @@ -47,22 +47,22 @@ functions to your graph. You can use these functions to perform general data clipping, but they're particularly useful for handling exploding or vanishing gradients. -* @{tf.clip_by_value} -* @{tf.clip_by_norm} -* @{tf.clip_by_average_norm} -* @{tf.clip_by_global_norm} -* @{tf.global_norm} +* `tf.clip_by_value` +* `tf.clip_by_norm` +* `tf.clip_by_average_norm` +* `tf.clip_by_global_norm` +* `tf.global_norm` ## Decaying the learning rate -* @{tf.train.exponential_decay} -* @{tf.train.inverse_time_decay} -* @{tf.train.natural_exp_decay} -* @{tf.train.piecewise_constant} -* @{tf.train.polynomial_decay} -* @{tf.train.cosine_decay} -* @{tf.train.linear_cosine_decay} -* @{tf.train.noisy_linear_cosine_decay} +* `tf.train.exponential_decay` +* `tf.train.inverse_time_decay` +* `tf.train.natural_exp_decay` +* `tf.train.piecewise_constant` +* `tf.train.polynomial_decay` +* `tf.train.cosine_decay` +* `tf.train.linear_cosine_decay` +* `tf.train.noisy_linear_cosine_decay` ## Moving Averages @@ -70,7 +70,7 @@ Some training algorithms, such as GradientDescent and Momentum often benefit from maintaining a moving average of variables during optimization. Using the moving averages for evaluations often improve results significantly. -* @{tf.train.ExponentialMovingAverage} +* `tf.train.ExponentialMovingAverage` ## Coordinator and QueueRunner @@ -79,61 +79,61 @@ for how to use threads and queues. For documentation on the Queue API, see @{$python/io_ops#queues$Queues}. -* @{tf.train.Coordinator} -* @{tf.train.QueueRunner} -* @{tf.train.LooperThread} -* @{tf.train.add_queue_runner} -* @{tf.train.start_queue_runners} +* `tf.train.Coordinator` +* `tf.train.QueueRunner` +* `tf.train.LooperThread` +* `tf.train.add_queue_runner` +* `tf.train.start_queue_runners` ## Distributed execution See @{$distributed$Distributed TensorFlow} for more information about how to configure a distributed TensorFlow program. -* @{tf.train.Server} -* @{tf.train.Supervisor} -* @{tf.train.SessionManager} -* @{tf.train.ClusterSpec} -* @{tf.train.replica_device_setter} -* @{tf.train.MonitoredTrainingSession} -* @{tf.train.MonitoredSession} -* @{tf.train.SingularMonitoredSession} -* @{tf.train.Scaffold} -* @{tf.train.SessionCreator} -* @{tf.train.ChiefSessionCreator} -* @{tf.train.WorkerSessionCreator} +* `tf.train.Server` +* `tf.train.Supervisor` +* `tf.train.SessionManager` +* `tf.train.ClusterSpec` +* `tf.train.replica_device_setter` +* `tf.train.MonitoredTrainingSession` +* `tf.train.MonitoredSession` +* `tf.train.SingularMonitoredSession` +* `tf.train.Scaffold` +* `tf.train.SessionCreator` +* `tf.train.ChiefSessionCreator` +* `tf.train.WorkerSessionCreator` ## Reading Summaries from Event Files See @{$summaries_and_tensorboard$Summaries and TensorBoard} for an overview of summaries, event files, and visualization in TensorBoard. -* @{tf.train.summary_iterator} +* `tf.train.summary_iterator` ## Training Hooks Hooks are tools that run in the process of training/evaluation of the model. -* @{tf.train.SessionRunHook} -* @{tf.train.SessionRunArgs} -* @{tf.train.SessionRunContext} -* @{tf.train.SessionRunValues} -* @{tf.train.LoggingTensorHook} -* @{tf.train.StopAtStepHook} -* @{tf.train.CheckpointSaverHook} -* @{tf.train.NewCheckpointReader} -* @{tf.train.StepCounterHook} -* @{tf.train.NanLossDuringTrainingError} -* @{tf.train.NanTensorHook} -* @{tf.train.SummarySaverHook} -* @{tf.train.GlobalStepWaiterHook} -* @{tf.train.FinalOpsHook} -* @{tf.train.FeedFnHook} +* `tf.train.SessionRunHook` +* `tf.train.SessionRunArgs` +* `tf.train.SessionRunContext` +* `tf.train.SessionRunValues` +* `tf.train.LoggingTensorHook` +* `tf.train.StopAtStepHook` +* `tf.train.CheckpointSaverHook` +* `tf.train.NewCheckpointReader` +* `tf.train.StepCounterHook` +* `tf.train.NanLossDuringTrainingError` +* `tf.train.NanTensorHook` +* `tf.train.SummarySaverHook` +* `tf.train.GlobalStepWaiterHook` +* `tf.train.FinalOpsHook` +* `tf.train.FeedFnHook` ## Training Utilities -* @{tf.train.global_step} -* @{tf.train.basic_train_loop} -* @{tf.train.get_global_step} -* @{tf.train.assert_global_step} -* @{tf.train.write_graph} +* `tf.train.global_step` +* `tf.train.basic_train_loop` +* `tf.train.get_global_step` +* `tf.train.assert_global_step` +* `tf.train.write_graph` diff --git a/tensorflow/docs_src/community/index.md b/tensorflow/docs_src/community/index.md index eec2e51a8706b73abcedb8329df3ad03e3b349c3..0aa8e7612a6dfd96dc3f59403e2691df00418cb5 100644 --- a/tensorflow/docs_src/community/index.md +++ b/tensorflow/docs_src/community/index.md @@ -54,7 +54,7 @@ with content from the TensorFlow team and the best articles from the community. ### YouTube -Our [YouTube Channel](http://youtube.com/tensorflow/) focuses on machine learing +Our [YouTube Channel](http://youtube.com/tensorflow/) focuses on machine learning and AI with TensorFlow. On it we have a number of new shows, including: - TensorFlow Meets: meet with community contributors to learn and share what they're doing diff --git a/tensorflow/docs_src/community/lists.md b/tensorflow/docs_src/community/lists.md index 7450ab36c436538dd584541fb0dafb5a2c6067b3..bc2f573c29ca445cc1770a3a2c520a7b60e52855 100644 --- a/tensorflow/docs_src/community/lists.md +++ b/tensorflow/docs_src/community/lists.md @@ -32,6 +32,8 @@ These projects inside the TensorFlow GitHub organization have lists dedicated to and peer support for TensorFlow.js. * [tflite](https://groups.google.com/a/tensorflow.org/d/forum/tflite) - Discussion and peer support for TensorFlow Lite. +* [tfprobability](https://groups.google.com/a/tensorflow.org/d/forum/tfprobability) - Discussion and + peer support for TensorFlow Probability. * [tpu-users](https://groups.google.com/a/tensorflow.org/d/forum/tpu-users) - Community discussion and support for TPU users. diff --git a/tensorflow/docs_src/community/style_guide.md b/tensorflow/docs_src/community/style_guide.md index c9268790a71fad9328f60f6a889c19c32117497e..daf0d2fdc042509972f7ab7446bb5876bb218657 100644 --- a/tensorflow/docs_src/community/style_guide.md +++ b/tensorflow/docs_src/community/style_guide.md @@ -47,27 +47,7 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) ``` -* At the end of every BUILD file, should contain: -``` -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) -``` - -* When adding new BUILD file, add this line to `tensorflow/BUILD` file into `all_opensource_files` target. - -``` -"//tensorflow/:all_files", -``` * For all Python BUILD targets (libraries and tests) add next line: @@ -80,6 +60,9 @@ srcs_version = "PY2AND3", * Operations that deal with batches may assume that the first dimension of a Tensor is the batch dimension. +* In most models the *last dimension* is the number of channels. + +* Dimensions excluding the first and last usually make up the "space" dimensions: Sequence-length or Image-size. ## Python operations @@ -148,37 +131,6 @@ Usage: ## Layers -A *Layer* is a Python operation that combines variable creation and/or one or many -other graph operations. Follow the same requirements as for regular Python -operation. - -* If a layer creates one or more variables, the layer function - should take next arguments also following order: - - `initializers`: Optionally allow to specify initializers for the variables. - - `regularizers`: Optionally allow to specify regularizers for the variables. - - `trainable`: which control if their variables are trainable or not. - - `scope`: `VariableScope` object that variable will be put under. - - `reuse`: `bool` indicator if the variable should be reused if - it's present in the scope. - -* Layers that behave differently during training should take: - - `is_training`: `bool` indicator to conditionally choose different - computation paths (e.g. using `tf.cond`) during execution. - -Example: - - def conv2d(inputs, - num_filters_out, - kernel_size, - stride=1, - padding='SAME', - activation_fn=tf.nn.relu, - normalization_fn=add_bias, - normalization_params=None, - initializers=None, - regularizers=None, - trainable=True, - scope=None, - reuse=None): - ... see implementation at tensorflow/contrib/layers/python/layers/layers.py ... +Use `tf.keras.layers`, not `tf.layers`. +See `tf.keras.layers` and [the Keras guide](../guide/keras.md#custom_layers) for details on how to sub-class layers. diff --git a/tensorflow/docs_src/deploy/distributed.md b/tensorflow/docs_src/deploy/distributed.md index d7ed6b1debdf256a800aed7304152acf5972bf72..6a760f53c878a38d69e3edb8706b20b67aabf5dd 100644 --- a/tensorflow/docs_src/deploy/distributed.md +++ b/tensorflow/docs_src/deploy/distributed.md @@ -2,7 +2,7 @@ This document shows how to create a cluster of TensorFlow servers, and how to distribute a computation graph across that cluster. We assume that you are -familiar with the @{$programmers_guide/low_level_intro$basic concepts} of +familiar with the @{$guide/low_level_intro$basic concepts} of writing low level TensorFlow programs. ## Hello distributed TensorFlow! @@ -21,7 +21,7 @@ $ python ``` The -@{tf.train.Server.create_local_server} +`tf.train.Server.create_local_server` method creates a single-process cluster, with an in-process server. ## Create a cluster @@ -55,7 +55,7 @@ the following: The cluster specification dictionary maps job names to lists of network addresses. Pass this dictionary to -the @{tf.train.ClusterSpec} +the `tf.train.ClusterSpec` constructor. For example: @@ -84,10 +84,10 @@ tf.train.ClusterSpec({ ### Create a `tf.train.Server` instance in each task -A @{tf.train.Server} object contains a +A `tf.train.Server` object contains a set of local devices, a set of connections to other tasks in its `tf.train.ClusterSpec`, and a -@{tf.Session} that can use these +`tf.Session` that can use these to perform a distributed computation. Each server is a member of a specific named job and has a task index within that job. A server can communicate with any other server in the cluster. @@ -117,7 +117,7 @@ which you'd like to see support, please raise a ## Specifying distributed devices in your model To place operations on a particular process, you can use the same -@{tf.device} +`tf.device` function that is used to specify whether ops run on the CPU or GPU. For example: ```python @@ -165,7 +165,7 @@ simplify the work of specifying a replicated model. Possible approaches include: for each `/job:worker` task, typically in the same process as the worker task. Each client builds a similar graph containing the parameters (pinned to `/job:ps` as before using - @{tf.train.replica_device_setter} + `tf.train.replica_device_setter` to map them deterministically to the same tasks); and a single copy of the compute-intensive part of the model, pinned to the local task in `/job:worker`. @@ -180,7 +180,7 @@ simplify the work of specifying a replicated model. Possible approaches include: gradient averaging as in the [CIFAR-10 multi-GPU trainer](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py)), and between-graph replication (e.g. using the - @{tf.train.SyncReplicasOptimizer}). + `tf.train.SyncReplicasOptimizer`). ### Putting it all together: example trainer program @@ -314,11 +314,11 @@ serve multiple clients. **Cluster** -A TensorFlow cluster comprises a one or more "jobs", each divided into lists of +A TensorFlow cluster comprises one or more "jobs", each divided into lists of one or more "tasks". A cluster is typically dedicated to a particular high-level objective, such as training a neural network, using many machines in parallel. A cluster is defined by -a @{tf.train.ClusterSpec} object. +a `tf.train.ClusterSpec` object. **Job** @@ -344,7 +344,7 @@ to a single process. A task belongs to a particular "job" and is identified by its index within that job's list of tasks. **TensorFlow server** A process running -a @{tf.train.Server} instance, which is +a `tf.train.Server` instance, which is a member of a cluster, and exports a "master service" and "worker service". **Worker service** diff --git a/tensorflow/docs_src/deploy/s3.md b/tensorflow/docs_src/deploy/s3.md index 9ef9674338a905388abac819693aa79226e90a7d..079c796aa7766377c46f47087268e47b41356a12 100644 --- a/tensorflow/docs_src/deploy/s3.md +++ b/tensorflow/docs_src/deploy/s3.md @@ -40,7 +40,7 @@ AWS_SECRET_ACCESS_KEY=XXXXX AWS_REGION=us-east-1 # Region for the S3 bucket, this is not always needed. Default is us-east-1. S3_ENDPOINT=s3.us-east-1.amazonaws.com # The S3 API Endpoint to connect to. This is specified in a HOST:PORT format. S3_USE_HTTPS=1 # Whether or not to use HTTPS. Disable with 0. -S3_VERIFY_SSL=1 # If HTTPS is used, conterols if SSL should be enabled. Disable with 0. +S3_VERIFY_SSL=1 # If HTTPS is used, controls if SSL should be enabled. Disable with 0. ``` ## Usage @@ -90,4 +90,4 @@ S3 was invented by Amazon, but the S3 API has spread in popularity and has sever * [Amazon S3](https://aws.amazon.com/s3/) * [Google Storage](https://cloud.google.com/storage/docs/interoperability) -* [Minio](https://www.minio.io/kubernetes.html)(Standalone mode only) +* [Minio](https://www.minio.io/kubernetes.html) diff --git a/tensorflow/docs_src/extend/adding_an_op.md b/tensorflow/docs_src/extend/adding_an_op.md index 1b028be4ea16af89b8aac8a8a73e9ceca9e842c5..fbf5c0b90d57bfcd23ea8a09611d43b395c36c09 100644 --- a/tensorflow/docs_src/extend/adding_an_op.md +++ b/tensorflow/docs_src/extend/adding_an_op.md @@ -46,7 +46,7 @@ To incorporate your custom op you'll need to: 4. Write a function to compute gradients for the op (optional). 5. Test the op. We usually do this in Python for convenience, but you can also test the op in C++. If you define gradients, you can verify them with the - Python @{tf.test.compute_gradient_error$gradient checker}. + Python `tf.test.compute_gradient_error`. See [`relu_op_test.py`](https://www.tensorflow.org/code/tensorflow/python/kernel_tests/relu_op_test.py) as an example that tests the forward functions of Relu-like operators and @@ -388,7 +388,7 @@ $ bazel build --config opt //tensorflow/core/user_ops:zero_out.so ## Use the op in Python TensorFlow Python API provides the -@{tf.load_op_library} function to +`tf.load_op_library` function to load the dynamic library and register the op with the TensorFlow framework. `load_op_library` returns a Python module that contains the Python wrappers for the op and the kernel. Thus, once you have built the op, you can @@ -538,7 +538,7 @@ REGISTER_OP("ZeroOut") ``` (Note that the set of [attribute types](#attr_types) is different from the -@{tf.DType$tensor types} used for inputs and outputs.) +`tf.DType` used for inputs and outputs.) Your kernel can then access this attr in its constructor via the `context` parameter: @@ -615,7 +615,7 @@ define an attr with constraints, you can use the following ``s: * `{, }`: The value is of type `type`, and must be one of `` or ``, where `` and `` are supported - @{tf.DType$tensor types}. You don't specify + `tf.DType`. You don't specify that the type of the attr is `type`. This is implied when you have a list of types in `{...}`. For example, in this case the attr `t` is a type that must be an `int32`, a `float`, or a `bool`: @@ -649,7 +649,7 @@ define an attr with constraints, you can use the following ``s: ``` Lists can be combined with other lists and single types. The following - op allows attr `t` to be any of the numberic types, or the bool type: + op allows attr `t` to be any of the numeric types, or the bool type: ```c++ REGISTER_OP("NumberOrBooleanType") @@ -714,7 +714,7 @@ REGISTER_OP("AttrDefaultExampleForAllTypes") ``` Note in particular that the values of type `type` -use @{tf.DType$the `DT_*` names for the types}. +use `tf.DType`. #### Polymorphism @@ -1056,7 +1056,7 @@ expressions: `string`). This specifies a single tensor of the given type. See - @{tf.DType$the list of supported Tensor types}. + `tf.DType`. ```c++ REGISTER_OP("BuiltInTypesExample") @@ -1098,8 +1098,7 @@ expressions: * For a sequence of tensors with the same type: ` * `, where `` is the name of an [Attr](#attrs) with type `int`. The `` can - either be - @{tf.DType$a specific type like `int32` or `float`}, + either be a `tf.DType`, or the name of an attr with type `type`. As an example of the first, this op accepts a list of `int32` tensors: @@ -1202,7 +1201,7 @@ There are several examples of kernels with GPU support in Notice some kernels have a CPU version in a `.cc` file, a GPU version in a file ending in `_gpu.cu.cc`, and some code shared in common in a `.h` file. -For example, the @{tf.pad} has +For example, the `tf.pad` has everything but the GPU kernel in [`tensorflow/core/kernels/pad_op.cc`][pad_op]. The GPU kernel is in [`tensorflow/core/kernels/pad_op_gpu.cu.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/pad_op_gpu.cu.cc), @@ -1307,16 +1306,16 @@ def _zero_out_grad(op, grad): ``` Details about registering gradient functions with -@{tf.RegisterGradient}: +`tf.RegisterGradient`: * For an op with one output, the gradient function will take an - @{tf.Operation} `op` and a - @{tf.Tensor} `grad` and build new ops + `tf.Operation` `op` and a + `tf.Tensor` `grad` and build new ops out of the tensors [`op.inputs[i]`](../../api_docs/python/framework.md#Operation.inputs), [`op.outputs[i]`](../../api_docs/python/framework.md#Operation.outputs), and `grad`. Information about any attrs can be found via - @{tf.Operation.get_attr}. + `tf.Operation.get_attr`. * If the op has multiple outputs, the gradient function will take `op` and `grads`, where `grads` is a list of gradients with respect to each output. diff --git a/tensorflow/docs_src/extend/architecture.md b/tensorflow/docs_src/extend/architecture.md index c8f522a03ab0c15083abad927b5ca1dac2851740..83d70c9468e940b4b347d0d5652327c226ecffe4 100644 --- a/tensorflow/docs_src/extend/architecture.md +++ b/tensorflow/docs_src/extend/architecture.md @@ -7,9 +7,8 @@ learning models and system-level optimizations. This document describes the system architecture that makes this combination of scale and flexibility possible. It assumes that you have basic familiarity with TensorFlow programming concepts such as the computation graph, operations, -and sessions. See @{$programmers_guide/low_level_intro$this document} -for an introduction to these topics. Some familiarity -with @{$distributed$distributed TensorFlow} +and sessions. See @{$guide/low_level_intro$this document} for an introduction to +these topics. Some familiarity with @{$distributed$distributed TensorFlow} will also be helpful. This document is for developers who want to extend TensorFlow in some way not @@ -82,7 +81,7 @@ implementation from all client languages. Most of the training libraries are still Python-only, but C++ does have support for efficient inference. The client creates a session, which sends the graph definition to the -distributed master as a @{tf.GraphDef} +distributed master as a `tf.GraphDef` protocol buffer. When the client evaluates a node or nodes in the graph, the evaluation triggers a call to the distributed master to initiate computation. @@ -97,7 +96,7 @@ feature vector (x), adds a bias term (b) and saves the result in a variable ### Code -* @{tf.Session} +* `tf.Session` ## Distributed master diff --git a/tensorflow/docs_src/extend/index.md b/tensorflow/docs_src/extend/index.md index 1ab0340ad983de891ef5e18a729c1e4fb3c4e0d9..0e4bfd1dc46a2f669902dca30dfab512356705f3 100644 --- a/tensorflow/docs_src/extend/index.md +++ b/tensorflow/docs_src/extend/index.md @@ -17,7 +17,8 @@ TensorFlow: Python is currently the only language supported by TensorFlow's API stability promises. However, TensorFlow also provides functionality in C++, Go, Java and -[JavaScript](https://js.tensorflow.org), +[JavaScript](https://js.tensorflow.org) (including +[Node.js](https://github.com/tensorflow/tfjs-node)), plus community support for [Haskell](https://github.com/tensorflow/haskell) and [Rust](https://github.com/tensorflow/rust). If you'd like to create or develop TensorFlow features in a language other than these languages, read the diff --git a/tensorflow/docs_src/extend/new_data_formats.md b/tensorflow/docs_src/extend/new_data_formats.md index d1d1f69766987c2ff02d4ba447a02c0bd91f3a1b..47a8344b70adade03612532d6fab340b2576bed7 100644 --- a/tensorflow/docs_src/extend/new_data_formats.md +++ b/tensorflow/docs_src/extend/new_data_formats.md @@ -15,25 +15,24 @@ We divide the task of supporting a file format into two pieces: * Record formats: We use decoder or parsing ops to turn a string record into tensors usable by TensorFlow. -For example, to read a -[CSV file](https://en.wikipedia.org/wiki/Comma-separated_values), we use -@{tf.data.TextLineDataset$a dataset for reading text files line-by-line} -and then @{tf.data.Dataset.map$map} an -@{tf.decode_csv$op} that parses CSV data from each line of text in the dataset. +For example, to re-implement `tf.contrib.data.make_csv_dataset` function, we +could use `tf.data.TextLineDataset` to extract the records, and then +use `tf.data.Dataset.map` and `tf.decode_csv` to parses the CSV records from +each line of text in the dataset. [TOC] ## Writing a `Dataset` for a file format -A @{tf.data.Dataset} represents a sequence of *elements*, which can be the +A `tf.data.Dataset` represents a sequence of *elements*, which can be the individual records in a file. There are several examples of "reader" datasets that are already built into TensorFlow: -* @{tf.data.TFRecordDataset} +* `tf.data.TFRecordDataset` ([source in `kernels/data/reader_dataset_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/data/reader_dataset_ops.cc)) -* @{tf.data.FixedLengthRecordDataset} +* `tf.data.FixedLengthRecordDataset` ([source in `kernels/data/reader_dataset_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/data/reader_dataset_ops.cc)) -* @{tf.data.TextLineDataset} +* `tf.data.TextLineDataset` ([source in `kernels/data/reader_dataset_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/data/reader_dataset_ops.cc)) Each of these implementations comprises three related classes: @@ -64,7 +63,7 @@ need to: that implement the reading logic. 2. In C++, register a new reader op and kernel with the name `"MyReaderDataset"`. -3. In Python, define a subclass of @{tf.data.Dataset} called `MyReaderDataset`. +3. In Python, define a subclass of `tf.data.Dataset` called `MyReaderDataset`. You can put all the C++ code in a single file, such as `my_reader_dataset_op.cc`. It will help if you are @@ -77,18 +76,24 @@ can be used as a starting point for your implementation: #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" -namespace tensorflow { +namespace myproject { namespace { -class MyReaderDatasetOp : public DatasetOpKernel { +using ::tensorflow::DT_STRING; +using ::tensorflow::PartialTensorShape; +using ::tensorflow::Status; + +class MyReaderDatasetOp : public tensorflow::DatasetOpKernel { public: - MyReaderDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { + MyReaderDatasetOp(tensorflow::OpKernelConstruction* ctx) + : DatasetOpKernel(ctx) { // Parse and validate any attrs that define the dataset using // `ctx->GetAttr()`, and store them in member variables. } - void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + void MakeDataset(tensorflow::OpKernelContext* ctx, + tensorflow::DatasetBase** output) override { // Parse and validate any input tensors 0that define the dataset using // `ctx->input()` or the utility function // `ParseScalarArgument(ctx, &arg)`. @@ -99,14 +104,14 @@ class MyReaderDatasetOp : public DatasetOpKernel { } private: - class Dataset : public GraphDatasetBase { + class Dataset : public tensorflow::GraphDatasetBase { public: - Dataset(OpKernelContext* ctx) : GraphDatasetBase(ctx) {} + Dataset(tensorflow::OpKernelContext* ctx) : GraphDatasetBase(ctx) {} - std::unique_ptr MakeIteratorInternal( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr( - new Iterator({this, strings::StrCat(prefix, "::MyReader")})); + return std::unique_ptr(new Iterator( + {this, tensorflow::strings::StrCat(prefix, "::MyReader")})); } // Record structure: Each record is represented by a scalar string tensor. @@ -114,8 +119,8 @@ class MyReaderDatasetOp : public DatasetOpKernel { // Dataset elements can have a fixed number of components of different // types and shapes; replace the following two methods to customize this // aspect of the dataset. - const DataTypeVector& output_dtypes() const override { - static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + const tensorflow::DataTypeVector& output_dtypes() const override { + static auto* const dtypes = new tensorflow::DataTypeVector({DT_STRING}); return *dtypes; } const std::vector& output_shapes() const override { @@ -132,16 +137,16 @@ class MyReaderDatasetOp : public DatasetOpKernel { // Implement this method if you want to be able to save and restore // instances of this dataset (and any iterators over it). Status AsGraphDefInternal(DatasetGraphDefBuilder* b, - Node** output) const override { + tensorflow::Node** output) const override { // Construct nodes to represent any of the input tensors from this // object's member variables using `b->AddScalar()` and `b->AddVector()`. - std::vector input_tensors; + std::vector input_tensors; TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output)); return Status::OK(); } private: - class Iterator : public DatasetIterator { + class Iterator : public tensorflow::DatasetIterator { public: explicit Iterator(const Params& params) : DatasetIterator(params), i_(0) {} @@ -158,15 +163,15 @@ class MyReaderDatasetOp : public DatasetOpKernel { // return `Status::OK()`. // 3. If an error occurs, return an error status using one of the helper // functions from "tensorflow/core/lib/core/errors.h". - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, + Status GetNextInternal(tensorflow::IteratorContext* ctx, + std::vector* out_tensors, bool* end_of_sequence) override { // NOTE: `GetNextInternal()` may be called concurrently, so it is // recommended that you protect the iterator state with a mutex. - mutex_lock l(mu_); + tensorflow::mutex_lock l(mu_); if (i_ < 10) { // Create a scalar string tensor and add it to the output. - Tensor record_tensor(ctx->allocator({}), DT_STRING, {}); + tensorflow::Tensor record_tensor(ctx->allocator({}), DT_STRING, {}); record_tensor.scalar()() = "MyReader!"; out_tensors->emplace_back(std::move(record_tensor)); ++i_; @@ -183,20 +188,20 @@ class MyReaderDatasetOp : public DatasetOpKernel { // // Implement these two methods if you want to be able to save and restore // instances of this iterator. - Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); + Status SaveInternal(tensorflow::IteratorStateWriter* writer) override { + tensorflow::mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); return Status::OK(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); + Status RestoreInternal(tensorflow::IteratorContext* ctx, + tensorflow::IteratorStateReader* reader) override { + tensorflow::mutex_lock l(mu_); TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_)); return Status::OK(); } private: - mutex mu_; + tensorflow::mutex mu_; int64 i_ GUARDED_BY(mu_); }; }; @@ -211,20 +216,20 @@ class MyReaderDatasetOp : public DatasetOpKernel { REGISTER_OP("MyReaderDataset") .Output("handle: variant") .SetIsStateful() - .SetShapeFn(shape_inference::ScalarShape); + .SetShapeFn(tensorflow::shape_inference::ScalarShape); // Register the kernel implementation for MyReaderDataset. -REGISTER_KERNEL_BUILDER(Name("MyReaderDataset").Device(DEVICE_CPU), +REGISTER_KERNEL_BUILDER(Name("MyReaderDataset").Device(tensorflow::DEVICE_CPU), MyReaderDatasetOp); } // namespace -} // namespace tensorflow +} // namespace myproject ``` The last step is to build the C++ code and add a Python wrapper. The easiest way to do this is by @{$adding_an_op#build_the_op_library$compiling a dynamic library} (e.g. called `"my_reader_dataset_op.so"`), and adding a Python class -that subclasses @{tf.data.Dataset} to wrap it. An example Python program is +that subclasses `tf.data.Dataset` to wrap it. An example Python program is given here: ```python @@ -287,14 +292,14 @@ track down where the bad data came from. Examples of Ops useful for decoding records: -* @{tf.parse_single_example} (and @{tf.parse_example}) -* @{tf.decode_csv} -* @{tf.decode_raw} +* `tf.parse_single_example` (and `tf.parse_example`) +* `tf.decode_csv` +* `tf.decode_raw` Note that it can be useful to use multiple Ops to decode a particular record format. For example, you may have an image saved as a string in [a `tf.train.Example` protocol buffer](https://www.tensorflow.org/code/tensorflow/core/example/example.proto). Depending on the format of that image, you might take the corresponding output -from a @{tf.parse_single_example} op and call @{tf.image.decode_jpeg}, -@{tf.image.decode_png}, or @{tf.decode_raw}. It is common to take the output -of `tf.decode_raw` and use @{tf.slice} and @{tf.reshape} to extract pieces. +from a `tf.parse_single_example` op and call `tf.image.decode_jpeg`, +`tf.image.decode_png`, or `tf.decode_raw`. It is common to take the output +of `tf.decode_raw` and use `tf.slice` and `tf.reshape` to extract pieces. diff --git a/tensorflow/docs_src/get_started/eager.md b/tensorflow/docs_src/get_started/eager.md deleted file mode 100644 index bbb25e20c62f6a2eec78668250a0e748494797c5..0000000000000000000000000000000000000000 --- a/tensorflow/docs_src/get_started/eager.md +++ /dev/null @@ -1,3 +0,0 @@ -# Get Started with Eager Execution - -[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/r1.9.0/samples/core/get_started/eager.ipynb) diff --git a/tensorflow/docs_src/get_started/index.md b/tensorflow/docs_src/get_started/index.md deleted file mode 100644 index 232d2f154703dc10320f9ee074c67d6e1a8ee850..0000000000000000000000000000000000000000 --- a/tensorflow/docs_src/get_started/index.md +++ /dev/null @@ -1,29 +0,0 @@ -# Get Started - -If you are new to machine learning, we recommend taking the following online -course prior to diving into TensorFlow documentation: - - * [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/), - which introduces machine learning concepts and encourages experimentation - with existing TensorFlow code. - -TensorFlow is a tool for machine learning. While it contains a wide range of -functionality, TensorFlow is mainly designed for deep neural network models. - -The easiest way to get started with TensorFlow is by using Eager Execution. - - * @{$get_started/eager}, is for anyone new to machine learning or TensorFlow. - -TensorFlow provides many APIs. The remainder of this section focuses on the -Estimator API which provide scalable, high-performance models. See the -@{$estimators} guide. - -For more advanced users: - - * The @{$low_level_intro$Low Level Introduction} demonstrates how to use - TensorFlow outside of the Estimator framework, for debugging and - experimentation. - * The @{$programmers_guide$Programmer's Guide} details major - TensorFlow components. - * The @{$tutorials$Tutorials} provide walkthroughs of a variety of - TensorFlow models. diff --git a/tensorflow/docs_src/get_started/leftnav_files b/tensorflow/docs_src/get_started/leftnav_files deleted file mode 100644 index e6cc8d565810683947e9cf9692e7cccb43916e66..0000000000000000000000000000000000000000 --- a/tensorflow/docs_src/get_started/leftnav_files +++ /dev/null @@ -1,4 +0,0 @@ -index.md - -eager.md -datasets_quickstart.md diff --git a/tensorflow/docs_src/guide/autograph.md b/tensorflow/docs_src/guide/autograph.md new file mode 100644 index 0000000000000000000000000000000000000000..823e1c6d6bfff8e575eb7479d9c3592cdb2c01cf --- /dev/null +++ b/tensorflow/docs_src/guide/autograph.md @@ -0,0 +1,3 @@ +# AutoGraph: Easy control flow for graphs + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/guide/autograph.ipynb) diff --git a/tensorflow/docs_src/programmers_guide/checkpoints.md b/tensorflow/docs_src/guide/checkpoints.md similarity index 95% rename from tensorflow/docs_src/programmers_guide/checkpoints.md rename to tensorflow/docs_src/guide/checkpoints.md index 8dfd91e3c8368f4a649c5b5fa3947e97441ef390..e1add298527f27c063cc7622c26e5f3cc28e863d 100644 --- a/tensorflow/docs_src/programmers_guide/checkpoints.md +++ b/tensorflow/docs_src/guide/checkpoints.md @@ -8,9 +8,8 @@ Estimators. TensorFlow provides two model formats: * SavedModel, which is a format independent of the code that created the model. -This document focuses on checkpoints. For details on SavedModel, see the -@{$saved_model$Saving and Restoring} chapter of the -*TensorFlow Programmer's Guide*. +This document focuses on checkpoints. For details on `SavedModel`, see the +@{$saved_model$Saving and Restoring} guide. ## Sample code @@ -130,7 +129,7 @@ in the `model_dir` according to the following schedule: You may alter the default schedule by taking the following steps: -1. Create a @{tf.estimator.RunConfig$`RunConfig`} object that defines the +1. Create a `tf.estimator.RunConfig` object that defines the desired schedule. 2. When instantiating the Estimator, pass that `RunConfig` object to the Estimator's `config` argument. @@ -232,8 +231,7 @@ This separation will keep your checkpoints recoverable. Checkpoints provide an easy automatic mechanism for saving and restoring models created by Estimators. -See the @{$saved_model$Saving and Restoring} -chapter of the *TensorFlow Programmer's Guide* for details on: +See the @{$saved_model$Saving and Restoring} guide for details about: * Saving and restoring models using low-level TensorFlow APIs. * Exporting and importing models in the SavedModel format, which is a diff --git a/tensorflow/docs_src/programmers_guide/custom_estimators.md b/tensorflow/docs_src/guide/custom_estimators.md similarity index 91% rename from tensorflow/docs_src/programmers_guide/custom_estimators.md rename to tensorflow/docs_src/guide/custom_estimators.md index fb20b35c128b5bdafbb88ccb19df05f6a73c9977..199a0e93de1b8bf1b00b3539e975278481781cb1 100644 --- a/tensorflow/docs_src/programmers_guide/custom_estimators.md +++ b/tensorflow/docs_src/guide/custom_estimators.md @@ -2,9 +2,9 @@ # Creating Custom Estimators This document introduces custom Estimators. In particular, this document -demonstrates how to create a custom @{tf.estimator.Estimator$Estimator} that +demonstrates how to create a custom `tf.estimator.Estimator` that mimics the behavior of the pre-made Estimator -@{tf.estimator.DNNClassifier$`DNNClassifier`} in solving the Iris problem. See +`tf.estimator.DNNClassifier` in solving the Iris problem. See the @{$premade_estimators$Pre-Made Estimators chapter} for details on the Iris problem. @@ -34,7 +34,7 @@ with ## Pre-made vs. custom As the following figure shows, pre-made Estimators are subclasses of the -@{tf.estimator.Estimator} base class, while custom Estimators are an instance +`tf.estimator.Estimator` base class, while custom Estimators are an instance of tf.estimator.Estimator:
@@ -144,12 +144,12 @@ The caller may pass `params` to an Estimator's constructor. Any `params` passed to the constructor are in turn passed on to the `model_fn`. In [`custom_estimator.py`](https://github.com/tensorflow/models/blob/master/samples/core/get_started/custom_estimator.py) the following lines create the estimator and set the params to configure the -model. This configuration step is similar to how we configured the @{tf.estimator.DNNClassifier} in +model. This configuration step is similar to how we configured the `tf.estimator.DNNClassifier` in @{$premade_estimators}. ```python classifier = tf.estimator.Estimator( - model_fn=my_model, + model_fn=my_model_fn, params={ 'feature_columns': my_feature_columns, # Two hidden layers of 10 nodes each. @@ -178,7 +178,7 @@ The basic deep neural network model must define the following three sections: ### Define the input layer -The first line of the `model_fn` calls @{tf.feature_column.input_layer} to +The first line of the `model_fn` calls `tf.feature_column.input_layer` to convert the feature dictionary and `feature_columns` into input for your model, as follows: @@ -202,7 +202,7 @@ creating the model's input layer. If you are creating a deep neural network, you must define one or more hidden layers. The Layers API provides a rich set of functions to define all types of hidden layers, including convolutional, pooling, and dropout layers. For Iris, -we're simply going to call @{tf.layers.dense} to create hidden layers, with +we're simply going to call `tf.layers.dense` to create hidden layers, with dimensions defined by `params['hidden_layers']`. In a `dense` layer each node is connected to every node in the preceding layer. Here's the relevant code: @@ -231,14 +231,14 @@ simplicity, the figure does not show all the units in each layer. src="../images/custom_estimators/add_hidden_layer.png">
-Note that @{tf.layers.dense} provides many additional capabilities, including +Note that `tf.layers.dense` provides many additional capabilities, including the ability to set a multitude of regularization parameters. For the sake of simplicity, though, we're going to simply accept the default values of the other parameters. ### Output Layer -We'll define the output layer by calling @{tf.layers.dense} yet again, this +We'll define the output layer by calling `tf.layers.dense` yet again, this time without an activation function: ```python @@ -265,7 +265,7 @@ score, or "logit", calculated for the associated class of Iris: Setosa, Versicolor, or Virginica, respectively. Later on, these logits will be transformed into probabilities by the -@{tf.nn.softmax} function. +`tf.nn.softmax` function. ## Implement training, evaluation, and prediction {#modes} @@ -290,9 +290,9 @@ function with the mode parameter set as follows: | Estimator method | Estimator Mode | |:---------------------------------|:------------------| -|@{tf.estimator.Estimator.train$`train()`} |@{tf.estimator.ModeKeys.TRAIN$`ModeKeys.TRAIN`} | -|@{tf.estimator.Estimator.evaluate$`evaluate()`} |@{tf.estimator.ModeKeys.EVAL$`ModeKeys.EVAL`} | -|@{tf.estimator.Estimator.predict$`predict()`}|@{tf.estimator.ModeKeys.PREDICT$`ModeKeys.PREDICT`} | +|`tf.estimator.Estimator.train` |`tf.estimator.ModeKeys.TRAIN` | +|`tf.estimator.Estimator.evaluate` |`tf.estimator.ModeKeys.EVAL` | +|`tf.estimator.Estimator.predict`|`tf.estimator.ModeKeys.PREDICT` | For example, suppose you instantiate a custom Estimator to generate an object named `classifier`. Then, you make the following call: @@ -350,8 +350,8 @@ The `predictions` holds the following three key/value pairs: * `logit` holds the raw logit values (in this example, -1.3, 2.6, and -0.9) We return that dictionary to the caller via the `predictions` parameter of the -@{tf.estimator.EstimatorSpec}. The Estimator's -@{tf.estimator.Estimator.predict$`predict`} method will yield these +`tf.estimator.EstimatorSpec`. The Estimator's +`tf.estimator.Estimator.predict` method will yield these dictionaries. ### Calculate the loss @@ -361,11 +361,11 @@ model's loss. This is the [objective](https://developers.google.com/machine-learning/glossary/#objective) that will be optimized. -We can calculate the loss by calling @{tf.losses.sparse_softmax_cross_entropy}. -The value returned by this function will be lowest, approximately 0, -probability of the correct class (at index `label`) is near 1.0. The loss value -returned is progressively larger as the probability of the correct class -decreases. +We can calculate the loss by calling `tf.losses.sparse_softmax_cross_entropy`. +The value returned by this function will be approximately 0 at lowest, +when the probability of the correct class (at index `label`) is near 1.0. +The loss value returned is progressively larger as the probability of the +correct class decreases. This function returns the average over the whole batch. @@ -382,12 +382,12 @@ When the Estimator's `evaluate` method is called, the `model_fn` receives or more metrics. Although returning metrics is optional, most custom Estimators do return at -least one metric. TensorFlow provides a Metrics module @{tf.metrics} to +least one metric. TensorFlow provides a Metrics module `tf.metrics` to calculate common metrics. For brevity's sake, we'll only return accuracy. The -@{tf.metrics.accuracy} function compares our predictions against the +`tf.metrics.accuracy` function compares our predictions against the true values, that is, against the labels provided by the input function. The -@{tf.metrics.accuracy} function requires the labels and predictions to have the -same shape. Here's the call to @{tf.metrics.accuracy}: +`tf.metrics.accuracy` function requires the labels and predictions to have the +same shape. Here's the call to `tf.metrics.accuracy`: ``` python # Compute evaluation metrics. @@ -396,7 +396,7 @@ accuracy = tf.metrics.accuracy(labels=labels, name='acc_op') ``` -The @{tf.estimator.EstimatorSpec$`EstimatorSpec`} returned for evaluation +The `tf.estimator.EstimatorSpec` returned for evaluation typically contains the following information: * `loss`, which is the model's loss @@ -416,7 +416,7 @@ if mode == tf.estimator.ModeKeys.EVAL: mode, loss=loss, eval_metric_ops=metrics) ``` -The @{tf.summary.scalar} will make accuracy available to TensorBoard +The `tf.summary.scalar` will make accuracy available to TensorBoard in both `TRAIN` and `EVAL` modes. (More on this later). ### Train @@ -426,7 +426,7 @@ with `mode = ModeKeys.TRAIN`. In this case, the model function must return an `EstimatorSpec` that contains the loss and a training operation. Building the training operation will require an optimizer. We will use -@{tf.train.AdagradOptimizer} because we're mimicking the `DNNClassifier`, which +`tf.train.AdagradOptimizer` because we're mimicking the `DNNClassifier`, which also uses `Adagrad` by default. The `tf.train` package provides many other optimizers—feel free to experiment with them. @@ -437,14 +437,14 @@ optimizer = tf.train.AdagradOptimizer(learning_rate=0.1) ``` Next, we build the training operation using the optimizer's -@{tf.train.Optimizer.minimize$`minimize`} method on the loss we calculated +`tf.train.Optimizer.minimize` method on the loss we calculated earlier. The `minimize` method also takes a `global_step` parameter. TensorFlow uses this parameter to count the number of training steps that have been processed (to know when to end a training run). Furthermore, the `global_step` is essential for TensorBoard graphs to work correctly. Simply call -@{tf.train.get_global_step} and pass the result to the `global_step` +`tf.train.get_global_step` and pass the result to the `global_step` argument of `minimize`. Here's the code to train the model: @@ -453,7 +453,7 @@ Here's the code to train the model: train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step()) ``` -The @{tf.estimator.EstimatorSpec$`EstimatorSpec`} returned for training +The `tf.estimator.EstimatorSpec` returned for training must have the following fields set: * `loss`, which contains the value of the loss function. @@ -474,7 +474,7 @@ Instantiate the custom Estimator through the Estimator base class as follows: ```python # Build 2 hidden layer DNN with 10, 10 units respectively. classifier = tf.estimator.Estimator( - model_fn=my_model, + model_fn=my_model_fn, params={ 'feature_columns': my_feature_columns, # Two hidden layers of 10 nodes each. diff --git a/tensorflow/docs_src/programmers_guide/datasets.md b/tensorflow/docs_src/guide/datasets.md similarity index 97% rename from tensorflow/docs_src/programmers_guide/datasets.md rename to tensorflow/docs_src/guide/datasets.md index 8b69860a68461e849a445f5c01c2e9b71d614a46..bb18e8b79cef8cd9958fa77ac20819d1dc7675e1 100644 --- a/tensorflow/docs_src/programmers_guide/datasets.md +++ b/tensorflow/docs_src/guide/datasets.md @@ -1,6 +1,6 @@ # Importing Data -The @{tf.data} API enables you to build complex input pipelines from +The `tf.data` API enables you to build complex input pipelines from simple, reusable pieces. For example, the pipeline for an image model might aggregate data from files in a distributed file system, apply random perturbations to each image, and merge randomly selected images into a batch @@ -51,7 +51,7 @@ Once you have a `Dataset` object, you can *transform* it into a new `Dataset` by chaining method calls on the `tf.data.Dataset` object. For example, you can apply per-element transformations such as `Dataset.map()` (to apply a function to each element), and multi-element transformations such as -`Dataset.batch()`. See the documentation for @{tf.data.Dataset} +`Dataset.batch()`. See the documentation for `tf.data.Dataset` for a complete list of transformations. The most common way to consume values from a `Dataset` is to make an @@ -211,13 +211,13 @@ for _ in range(20): sess.run(next_element) ``` -A **feedable** iterator can be used together with @{tf.placeholder} to select -what `Iterator` to use in each call to @{tf.Session.run}, via the familiar +A **feedable** iterator can be used together with `tf.placeholder` to select +what `Iterator` to use in each call to `tf.Session.run`, via the familiar `feed_dict` mechanism. It offers the same functionality as a reinitializable iterator, but it does not require you to initialize the iterator from the start of a dataset when you switch between iterators. For example, using the same training and validation example from above, you can use -@{tf.data.Iterator.from_string_handle} to define a feedable iterator +`tf.data.Iterator.from_string_handle` to define a feedable iterator that allows you to switch between the two datasets: ```python @@ -329,12 +329,12 @@ of an iterator will include all components in a single expression. ### Saving iterator state -The @{tf.contrib.data.make_saveable_from_iterator} function creates a +The `tf.contrib.data.make_saveable_from_iterator` function creates a `SaveableObject` from an iterator, which can be used to save and restore the current state of the iterator (and, effectively, the whole input -pipeline). A saveable object thus created can be added to @{tf.train.Saver} +pipeline). A saveable object thus created can be added to `tf.train.Saver` variables list or the `tf.GraphKeys.SAVEABLE_OBJECTS` collection for saving and -restoring in the same manner as a @{tf.Variable}. Refer to +restoring in the same manner as a `tf.Variable`. Refer to @{$saved_model$Saving and Restoring} for details on how to save and restore variables. @@ -488,7 +488,7 @@ dataset = dataset.flat_map( ### Consuming CSV data The CSV file format is a popular format for storing tabular data in plain text. -The @{tf.contrib.data.CsvDataset} class provides a way to extract records from +The `tf.contrib.data.CsvDataset` class provides a way to extract records from one or more CSV files that comply with [RFC 4180](https://tools.ietf.org/html/rfc4180). Given one or more filenames and a list of defaults, a `CsvDataset` will produce a tuple of elements whose types correspond to the types of the defaults @@ -757,9 +757,9 @@ dataset = dataset.repeat() ### Using high-level APIs -The @{tf.train.MonitoredTrainingSession} API simplifies many aspects of running +The `tf.train.MonitoredTrainingSession` API simplifies many aspects of running TensorFlow in a distributed setting. `MonitoredTrainingSession` uses the -@{tf.errors.OutOfRangeError} to signal that training has completed, so to use it +`tf.errors.OutOfRangeError` to signal that training has completed, so to use it with the `tf.data` API, we recommend using `Dataset.make_one_shot_iterator()`. For example: @@ -782,7 +782,7 @@ with tf.train.MonitoredTrainingSession(...) as sess: sess.run(training_op) ``` -To use a `Dataset` in the `input_fn` of a @{tf.estimator.Estimator}, we also +To use a `Dataset` in the `input_fn` of a `tf.estimator.Estimator`, we also recommend using `Dataset.make_one_shot_iterator()`. For example: ```python diff --git a/tensorflow/docs_src/get_started/datasets_quickstart.md b/tensorflow/docs_src/guide/datasets_for_estimators.md similarity index 88% rename from tensorflow/docs_src/get_started/datasets_quickstart.md rename to tensorflow/docs_src/guide/datasets_for_estimators.md index 020e40dd3b8f046f0144e3806468f58833f7b607..969ea579f7e85fc296f928c6ab71ea94d47d0fb5 100644 --- a/tensorflow/docs_src/get_started/datasets_quickstart.md +++ b/tensorflow/docs_src/guide/datasets_for_estimators.md @@ -1,6 +1,6 @@ -# Datasets Quick Start +# Datasets for Estimators -The @{tf.data} module contains a collection of classes that allows you to +The `tf.data` module contains a collection of classes that allows you to easily load data, manipulate it, and pipe it into your model. This document introduces the API by walking through two simple examples: @@ -73,12 +73,12 @@ Let's walk through the `train_input_fn()`. ### Slices -The function starts by using the @{tf.data.Dataset.from_tensor_slices} function -to create a @{tf.data.Dataset} representing slices of the array. The array is +The function starts by using the `tf.data.Dataset.from_tensor_slices` function +to create a `tf.data.Dataset` representing slices of the array. The array is sliced across the first dimension. For example, an array containing the -@{$tutorials/layers$mnist training data} has a shape of `(60000, 28, 28)`. -Passing this to `from_tensor_slices` returns a `Dataset` object containing -60000 slices, each one a 28x28 image. +MNIST training data has a shape of `(60000, 28, 28)`. Passing this to +`from_tensor_slices` returns a `Dataset` object containing 60000 slices, each one +a 28x28 image. The code that returns this `Dataset` is as follows: @@ -91,8 +91,8 @@ print(mnist_ds) ``` This will print the following line, showing the -@{$programmers_guide/tensors#shapes$shapes} and -@{$programmers_guide/tensors#data_types$types} of the items in +@{$guide/tensors#shapes$shapes} and +@{$guide/tensors#data_types$types} of the items in the dataset. Note that a `Dataset` does not know how many items it contains. ``` None @@ -128,7 +128,7 @@ print(dataset) Here we see that when a `Dataset` contains structured elements, the `shapes` and `types` of the `Dataset` take on the same structure. This dataset contains -dictionaries of @{$programmers_guide/tensors#rank$scalars}, all of type +dictionaries of @{$guide/tensors#rank$scalars}, all of type `tf.float64`. The first line of the iris `train_input_fn` uses the same functionality, but @@ -170,15 +170,15 @@ function takes advantage of several of these methods: dataset = dataset.shuffle(1000).repeat().batch(batch_size) ``` -The @{tf.data.Dataset.shuffle$`shuffle`} method uses a fixed-size buffer to +The `tf.data.Dataset.shuffle` method uses a fixed-size buffer to shuffle the items as they pass through. In this case the `buffer_size` is greater than the number of examples in the `Dataset`, ensuring that the data is completely shuffled (The Iris data set only contains 150 examples). -The @{tf.data.Dataset.repeat$`repeat`} method restarts the `Dataset` when +The `tf.data.Dataset.repeat` method restarts the `Dataset` when it reaches the end. To limit the number of epochs, set the `count` argument. -The @{tf.data.Dataset.batch$`batch`} method collects a number of examples and +The `tf.data.Dataset.batch` method collects a number of examples and stacks them, to create batches. This adds a dimension to their shape. The new dimension is added as the first dimension. The following code uses the `batch` method on the MNIST `Dataset`, from earlier. This results in a @@ -234,7 +234,7 @@ The `labels` can/should be omitted when using the `predict` method. ## Reading a CSV File The most common real-world use case for the `Dataset` class is to stream data -from files on disk. The @{tf.data} module includes a variety of +from files on disk. The `tf.data` module includes a variety of file readers. Let's see how parsing the Iris dataset from the csv file looks using a `Dataset`. @@ -255,9 +255,9 @@ from the local files. ### Build the `Dataset` -We start by building a @{tf.data.TextLineDataset$`TextLineDataset`} object to +We start by building a `tf.data.TextLineDataset` object to read the file one line at a time. Then, we call the -@{tf.data.Dataset.skip$`skip`} method to skip over the first line of the file, which contains a header, not an example: +`tf.data.Dataset.skip` method to skip over the first line of the file, which contains a header, not an example: ``` python ds = tf.data.TextLineDataset(train_path).skip(1) @@ -268,11 +268,11 @@ ds = tf.data.TextLineDataset(train_path).skip(1) We will start by building a function to parse a single line. The following `iris_data.parse_line` function accomplishes this task using the -@{tf.decode_csv} function, and some simple python code: +`tf.decode_csv` function, and some simple python code: We must parse each of the lines in the dataset in order to generate the necessary `(features, label)` pairs. The following `_parse_line` function -calls @{tf.decode_csv} to parse a single line into its features +calls `tf.decode_csv` to parse a single line into its features and the label. Since Estimators require that features be represented as a dictionary, we rely on Python's built-in `dict` and `zip` functions to build that dictionary. The feature names are the keys of that dictionary. @@ -301,7 +301,7 @@ def _parse_line(line): ### Parse the lines Datasets have many methods for manipulating the data while it is being piped -to a model. The most heavily-used method is @{tf.data.Dataset.map$`map`}, which +to a model. The most heavily-used method is `tf.data.Dataset.map`, which applies a transformation to each element of the `Dataset`. The `map` method takes a `map_func` argument that describes how each item in the @@ -311,7 +311,7 @@ The `map` method takes a `map_func` argument that describes how each item in the
-The @{tf.data.Dataset.map$`map`} method applies the `map_func` to +The `tf.data.Dataset.map` method applies the `map_func` to transform each item in the Dataset.
@@ -382,6 +382,6 @@ Estimator. Consider the following documents next: * The @{$low_level_intro#datasets$Low Level Introduction}, which demonstrates how to experiment directly with `tf.data.Datasets` using TensorFlow's low level APIs. -* @{$programmers_guide/datasets} which goes into great detail about additional +* @{$guide/datasets} which goes into great detail about additional functionality of `Datasets`. diff --git a/tensorflow/docs_src/programmers_guide/debugger.md b/tensorflow/docs_src/guide/debugger.md similarity index 97% rename from tensorflow/docs_src/programmers_guide/debugger.md rename to tensorflow/docs_src/guide/debugger.md index fc845c68f48204ab5e9796234a6c630f27704def..4c4a04a88af19ec2d3b1fc0b093a38153666d2de 100644 --- a/tensorflow/docs_src/programmers_guide/debugger.md +++ b/tensorflow/docs_src/guide/debugger.md @@ -17,7 +17,7 @@ how to use the graphical user interface (GUI) of tfdbg, i.e., the Note: The TensorFlow debugger uses a [curses](https://en.wikipedia.org/wiki/Curses_\(programming_library\))-based text user interface. On Mac OS X, the `ncurses` library is required and can be -installed with `brew install homebrew/dupes/ncurses`. On Windows, curses isn't as +installed with `brew install ncurses`. On Windows, curses isn't as well supported, so a [readline](https://en.wikipedia.org/wiki/GNU_Readline)-based interface can be used with tfdbg by installing `pyreadline` with `pip`. If you use Anaconda3, you can install it with a command such as @@ -89,7 +89,7 @@ control the execution and inspect the graph's internal state. the diagnosis of issues. In this example, we have already registered a tensor filter called -@{tfdbg.has_inf_or_nan}, +`tfdbg.has_inf_or_nan`, which simply determines if there are any `nan` or `inf` values in any intermediate tensors (tensors that are neither inputs or outputs of the `Session.run()` call, but are in the path leading from the inputs to the @@ -98,13 +98,11 @@ we ship it with the @{$python/tfdbg#Classes_for_debug_dump_data_and_directories$`debug_data`} module. -Note: You can also write your own custom filters. See -the @{tfdbg.DebugDumpDir.find$API documentation} -of `DebugDumpDir.find()` for additional information. +Note: You can also write your own custom filters. See `tfdbg.DebugDumpDir.find` +for additional information. ## Debugging Model Training with tfdbg - Let's try training the model again, but with the `--debug` flag added this time: ```none @@ -210,6 +208,7 @@ Try the following commands at the `tfdbg>` prompt (referencing the code at | **`config`** | | **Set or show persistent TFDBG UI configuration.** | | | | `set` | Set the value of a config item: {`graph_recursion_depth`, `mouse_mode`}. | `config set graph_recursion_depth 3` | | | `show` | Show current persistent UI configuration. | `config show` | +| **`version`** | | **Print the version of TensorFlow and its key dependencies.** | `version` | | **`help`** | | **Print general help information** | `help` | | | `help ` | Print help for given command. | `help lt` | @@ -428,9 +427,9 @@ described in the preceding sections inapplicable. Fortunately, you can still debug them by using special `hook`s provided by `tfdbg`. `tfdbg` can debug the -@{tf.estimator.Estimator.train$`train()`}, -@{tf.estimator.Estimator.evaluate$`evaluate()`} and -@{tf.estimator.Estimator.predict$`predict()`} +`tf.estimator.Estimator.train`, +`tf.estimator.Estimator.evaluate` and +`tf.estimator.Estimator.predict` methods of tf-learn `Estimator`s. To debug `Estimator.train()`, create a `LocalCLIDebugHook` and supply it in the `hooks` argument. For example: @@ -462,7 +461,6 @@ predict_results = classifier.predict(predict_input_fn, hooks=hooks) ``` [debug_tflearn_iris.py](https://www.tensorflow.org/code/tensorflow/python/debug/examples/debug_tflearn_iris.py), -based on [tf-learn's iris tutorial](https://www.tensorflow.org/versions/r1.8/get_started/tflearn), contains a full example of how to use the tfdbg with `Estimator`s. To run this example, do: @@ -473,7 +471,7 @@ python -m tensorflow.python.debug.examples.debug_tflearn_iris --debug The `LocalCLIDebugHook` also allows you to configure a `watch_fn` that can be used to flexibly specify what `Tensor`s to watch on different `Session.run()` calls, as a function of the `fetches` and `feed_dict` and other states. See -@{tfdbg.DumpingDebugWrapperSession.__init__$this API doc} +`tfdbg.DumpingDebugWrapperSession.__init__` for more details. ## Debugging Keras Models with TFDBG @@ -556,7 +554,7 @@ and the higher-level `Estimator` API. If you interact directly with the `tf.Session` API in `python`, you can configure the `RunOptions` proto that you call your `Session.run()` method -with, by using the method @{tfdbg.watch_graph}. +with, by using the method `tfdbg.watch_graph`. This will cause the intermediate tensors and runtime graphs to be dumped to a shared storage location of your choice when the `Session.run()` call occurs (at the cost of slower performance). For example: @@ -629,7 +627,7 @@ hooks = [tf_debug.DumpingDebugHook("/shared/storage/location/tfdbg_dumps_1")] Then this `hook` can be used in the same way as the `LocalCLIDebugHook` examples described earlier in this document. -As the training, evalution or prediction happens with `Estimator`, +As the training, evaluation or prediction happens with `Estimator`, tfdbg creates directories having the following name pattern: `/shared/storage/location/tfdbg_dumps_1/run__`. Each directory corresponds to a `Session.run()` call that underlies @@ -715,7 +713,7 @@ You might encounter this problem in any of the following situations: * models with many intermediate tensors * very large intermediate tensors -* many @{tf.while_loop} iterations +* many `tf.while_loop` iterations There are three possible workarounds or solutions: @@ -770,17 +768,17 @@ sess.run(b) **A**: The reason why you see no data dumped is because every node in the executed TensorFlow graph is constant-folded by the TensorFlow runtime. - In this exapmle, `a` is a constant tensor; therefore, the fetched + In this example, `a` is a constant tensor; therefore, the fetched tensor `b` is effectively also a constant tensor. TensorFlow's graph optimization folds the graph that contains `a` and `b` into a single node to speed up future runs of the graph, which is why `tfdbg` does not generate any intermediate tensor dumps. However, if `a` were a - @{tf.Variable}, as in the following example: + `tf.Variable`, as in the following example: ``` python import numpy as np -a = tf.Variable(np.ones[10], name="a") +a = tf.Variable(np.ones(10), name="a") b = tf.add(a, a, name="b") sess = tf.Session() sess.run(tf.global_variables_initializer()) diff --git a/tensorflow/docs_src/programmers_guide/eager.md b/tensorflow/docs_src/guide/eager.md similarity index 94% rename from tensorflow/docs_src/programmers_guide/eager.md rename to tensorflow/docs_src/guide/eager.md index b2bc3273b4529233faaaca8ce325f4e12979faa4..e47a8b599cd00220b186ff5117bf2c3cdccf7c22 100644 --- a/tensorflow/docs_src/programmers_guide/eager.md +++ b/tensorflow/docs_src/guide/eager.md @@ -225,7 +225,7 @@ the tape backwards and then discard. A particular `tf.GradientTape` can only compute one gradient; subsequent calls throw a runtime error. ```py -w = tfe.Variable([[1.0]]) +w = tf.Variable([[1.0]]) with tf.GradientTape() as tape: loss = w * w @@ -260,8 +260,8 @@ def grad(weights, biases): train_steps = 200 learning_rate = 0.01 # Start with arbitrary values for W and B on the same batch of data -W = tfe.Variable(5.) -B = tfe.Variable(10.) +W = tf.Variable(5.) +B = tf.Variable(10.) print("Initial loss: {:.3f}".format(loss(W, B))) @@ -316,9 +316,8 @@ for (batch, (images, labels)) in enumerate(dataset): The following example creates a multi-layer model that classifies the standard -[MNIST handwritten digits](https://www.tensorflow.org/tutorials/layers). It -demonstrates the optimizer and layer APIs to build trainable graphs in an eager -execution environment. +MNIST handwritten digits. It demonstrates the optimizer and layer APIs to build +trainable graphs in an eager execution environment. ### Train a model @@ -408,11 +407,11 @@ with tf.device("/gpu:0"): ### Variables and optimizers -`tfe.Variable` objects store mutable `tf.Tensor` values accessed during +`tf.Variable` objects store mutable `tf.Tensor` values accessed during training to make automatic differentiation easier. The parameters of a model can be encapsulated in classes as variables. -Better encapsulate model parameters by using `tfe.Variable` with +Better encapsulate model parameters by using `tf.Variable` with `tf.GradientTape`. For example, the automatic differentiation example above can be rewritten: @@ -420,9 +419,9 @@ can be rewritten: class Model(tf.keras.Model): def __init__(self): super(Model, self).__init__() - self.W = tfe.Variable(5., name='weight') - self.B = tfe.Variable(10., name='bias') - def predict(self, inputs): + self.W = tf.Variable(5., name='weight') + self.B = tf.Variable(10., name='bias') + def call(self, inputs): return inputs * self.W + self.B # A toy dataset of points around 3 * x + 2 @@ -433,7 +432,7 @@ training_outputs = training_inputs * 3 + 2 + noise # The loss function to be optimized def loss(model, inputs, targets): - error = model.predict(inputs) - targets + error = model(inputs) - targets return tf.reduce_mean(tf.square(error)) def grad(model, inputs, targets): @@ -499,19 +498,19 @@ is removed, and is then deleted. ```py with tf.device("gpu:0"): - v = tfe.Variable(tf.random_normal([1000, 1000])) + v = tf.Variable(tf.random_normal([1000, 1000])) v = None # v no longer takes up GPU memory ``` ### Object-based saving -`tfe.Checkpoint` can save and restore `tfe.Variable`s to and from +`tf.train.Checkpoint` can save and restore `tf.Variable`s to and from checkpoints: ```py -x = tfe.Variable(10.) +x = tf.Variable(10.) -checkpoint = tfe.Checkpoint(x=x) # save as "x" +checkpoint = tf.train.Checkpoint(x=x) # save as "x" x.assign(2.) # Assign a new value to the variables and save. save_path = checkpoint.save('./ckpt/') @@ -524,18 +523,18 @@ checkpoint.restore(save_path) print(x) # => 2.0 ``` -To save and load models, `tfe.Checkpoint` stores the internal state of objects, +To save and load models, `tf.train.Checkpoint` stores the internal state of objects, without requiring hidden variables. To record the state of a `model`, -an `optimizer`, and a global step, pass them to a `tfe.Checkpoint`: +an `optimizer`, and a global step, pass them to a `tf.train.Checkpoint`: ```py model = MyModel() optimizer = tf.train.AdamOptimizer(learning_rate=0.001) checkpoint_dir = ‘/path/to/model_dir’ checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") -root = tfe.Checkpoint(optimizer=optimizer, - model=model, - optimizer_step=tf.train.get_or_create_global_step()) +root = tf.train.Checkpoint(optimizer=optimizer, + model=model, + optimizer_step=tf.train.get_or_create_global_step()) root.save(file_prefix=checkpoint_prefix) # or @@ -569,9 +568,8 @@ inserted during model construction. For example, to record summaries once every 100 global steps: ```py +global_step = tf.train.get_or_create_global_step() writer = tf.contrib.summary.create_file_writer(logdir) -global_step=tf.train.get_or_create_global_step() # return global step var - writer.set_as_default() for _ in range(iterations): @@ -613,7 +611,7 @@ def line_search_step(fn, init_x, rate=1.0): `tf.GradientTape` is a powerful interface for computing gradients, but there is another [Autograd](https://github.com/HIPS/autograd)-style API available for automatic differentiation. These functions are useful if writing math code with -only tensors and gradient functions, and without `tfe.Variables`: +only tensors and gradient functions, and without `tf.Variables`: * `tfe.gradients_function` —Returns a function that computes the derivatives of its input function parameter with respect to its arguments. The input @@ -728,7 +726,13 @@ def measure(x, steps): start = time.time() for i in range(steps): x = tf.matmul(x, x) - _ = x.numpy() # Make sure to execute op and not just enqueue it + # tf.matmul can return before completing the matrix multiplication + # (e.g., can return after enqueing the operation on a CUDA stream). + # The x.numpy() call below will ensure that all enqueued operations + # have completed (and will also copy the result to host memory, + # so we're including a little more than just the matmul operation + # time). + _ = x.numpy() end = time.time() return end - start @@ -752,8 +756,8 @@ Output (exact numbers depend on hardware): ``` Time to multiply a (1000, 1000) matrix by itself 200 times: -CPU: 4.614904403686523 secs -GPU: 0.5581181049346924 secs +CPU: 1.46628093719 secs +GPU: 0.0593810081482 secs ``` A `tf.Tensor` object can be copied to a different device to execute its @@ -825,7 +829,7 @@ gives you eager's interactive experimentation and debuggability with the distributed performance benefits of graph execution. Write, debug, and iterate in eager execution, then import the model graph for -production deployment. Use `tfe.Checkpoint` to save and restore model +production deployment. Use `tf.train.Checkpoint` to save and restore model variables, this allows movement between eager and graph execution environments. See the examples in: [tensorflow/contrib/eager/python/examples](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples). diff --git a/tensorflow/docs_src/programmers_guide/embedding.md b/tensorflow/docs_src/guide/embedding.md similarity index 100% rename from tensorflow/docs_src/programmers_guide/embedding.md rename to tensorflow/docs_src/guide/embedding.md diff --git a/tensorflow/docs_src/programmers_guide/estimators.md b/tensorflow/docs_src/guide/estimators.md similarity index 91% rename from tensorflow/docs_src/programmers_guide/estimators.md rename to tensorflow/docs_src/guide/estimators.md index b13b47184d2b32fffb2390b0318fba8612d7826a..7b54e3de29a9f215f5b9396b25b78fc848d2d7e7 100644 --- a/tensorflow/docs_src/programmers_guide/estimators.md +++ b/tensorflow/docs_src/guide/estimators.md @@ -1,6 +1,6 @@ # Estimators -This document introduces @{tf.estimator$**Estimators**}--a high-level TensorFlow +This document introduces `tf.estimator`--a high-level TensorFlow API that greatly simplifies machine learning programming. Estimators encapsulate the following actions: @@ -11,10 +11,13 @@ the following actions: You may either use the pre-made Estimators we provide or write your own custom Estimators. All Estimators--whether pre-made or custom--are -classes based on the @{tf.estimator.Estimator} class. +classes based on the `tf.estimator.Estimator` class. + +For a quick example try [Estimator tutorials]](../tutorials/estimators/linear). +To see each sub-topic in depth, see the [Estimator guides](premade_estimators). Note: TensorFlow also includes a deprecated `Estimator` class at -@{tf.contrib.learn.Estimator}, which you should not use. +`tf.contrib.learn.Estimator`, which you should not use. ## Advantages of Estimators @@ -29,14 +32,14 @@ Estimators provide the following benefits: * You can develop a state of the art model with high-level intuitive code. In short, it is generally much easier to create models with Estimators than with the low-level TensorFlow APIs. -* Estimators are themselves built on @{tf.layers}, which +* Estimators are themselves built on `tf.keras.layers`, which simplifies customization. * Estimators build the graph for you. * Estimators provide a safe distributed training loop that controls how and when to: * build the graph * initialize variables - * start queues + * load data * handle exceptions * create checkpoint files and recover from failures * save summaries for TensorBoard @@ -52,9 +55,9 @@ Pre-made Estimators enable you to work at a much higher conceptual level than the base TensorFlow APIs. You no longer have to worry about creating the computational graph or sessions since Estimators handle all the "plumbing" for you. That is, pre-made Estimators create and manage -@{tf.Graph$`Graph`} and @{tf.Session$`Session`} objects for you. Furthermore, +`tf.Graph` and `tf.Session` objects for you. Furthermore, pre-made Estimators let you experiment with different model architectures by -making only minimal code changes. @{tf.estimator.DNNClassifier$`DNNClassifier`}, +making only minimal code changes. `tf.estimator.DNNClassifier`, for example, is a pre-made Estimator class that trains classification models based on dense, feed-forward neural networks. @@ -81,9 +84,9 @@ of the following four steps: ... # manipulate dataset, extracting the feature dict and the label return feature_dict, label - (See @{$programmers_guide/datasets} for full details.) + (See @{$guide/datasets} for full details.) -2. **Define the feature columns.** Each @{tf.feature_column} +2. **Define the feature columns.** Each `tf.feature_column` identifies a feature name, its type, and any input pre-processing. For example, the following snippet creates three feature columns that hold integer or floating-point data. The first two @@ -155,7 +158,7 @@ We recommend the following workflow: You can convert existing Keras models to Estimators. Doing so enables your Keras model to access Estimator's strengths, such as distributed training. Call -@{tf.keras.estimator.model_to_estimator} as in the +`tf.keras.estimator.model_to_estimator` as in the following sample: ```python @@ -190,4 +193,4 @@ and similarly, the predicted output names can be obtained from `keras_inception_v3.output_names`. For more details, please refer to the documentation for -@{tf.keras.estimator.model_to_estimator}. +`tf.keras.estimator.model_to_estimator`. diff --git a/tensorflow/docs_src/programmers_guide/faq.md b/tensorflow/docs_src/guide/faq.md similarity index 87% rename from tensorflow/docs_src/programmers_guide/faq.md rename to tensorflow/docs_src/guide/faq.md index b6291a9fface404406829d8d7ce5cc36980661a3..8370097560c01d10cba038be63bd1f152115e7f5 100644 --- a/tensorflow/docs_src/programmers_guide/faq.md +++ b/tensorflow/docs_src/guide/faq.md @@ -28,13 +28,13 @@ See also the #### Why does `c = tf.matmul(a, b)` not execute the matrix multiplication immediately? In the TensorFlow Python API, `a`, `b`, and `c` are -@{tf.Tensor} objects. A `Tensor` object is +`tf.Tensor` objects. A `Tensor` object is a symbolic handle to the result of an operation, but does not actually hold the values of the operation's output. Instead, TensorFlow encourages users to build up complicated expressions (such as entire neural networks and its gradients) as a dataflow graph. You then offload the computation of the entire dataflow graph (or a subgraph of it) to a TensorFlow -@{tf.Session}, which is able to execute the +`tf.Session`, which is able to execute the whole computation much more efficiently than executing the operations one-by-one. @@ -46,7 +46,7 @@ device, and `"/device:GPU:i"` (or `"/gpu:i"`) for the *i*th GPU device. #### How do I place operations on a particular device? To place a group of operations on a device, create them within a -@{tf.device$`with tf.device(name):`} context. See +`tf.device` context. See the how-to documentation on @{$using_gpu$using GPUs with TensorFlow} for details of how TensorFlow assigns operations to devices, and the @@ -63,17 +63,17 @@ See also the Feeding is a mechanism in the TensorFlow Session API that allows you to substitute different values for one or more tensors at run time. The `feed_dict` -argument to @{tf.Session.run} is a -dictionary that maps @{tf.Tensor} objects to +argument to `tf.Session.run` is a +dictionary that maps `tf.Tensor` objects to numpy arrays (and some other types), which will be used as the values of those tensors in the execution of a step. #### What is the difference between `Session.run()` and `Tensor.eval()`? -If `t` is a @{tf.Tensor} object, -@{tf.Tensor.eval} is shorthand for -@{tf.Session.run}, where `sess` is the -current @{tf.get_default_session}. The +If `t` is a `tf.Tensor` object, +`tf.Tensor.eval` is shorthand for +`tf.Session.run`, where `sess` is the +current `tf.get_default_session`. The two following snippets of code are equivalent: ```python @@ -99,11 +99,11 @@ sessions, it may be more straightforward to make explicit calls to #### Do Sessions have a lifetime? What about intermediate tensors? Sessions can own resources, such as -@{tf.Variable}, -@{tf.QueueBase}, and -@{tf.ReaderBase}. These resources can sometimes use +`tf.Variable`, +`tf.QueueBase`, and +`tf.ReaderBase`. These resources can sometimes use a significant amount of memory, and can be released when the session is closed by calling -@{tf.Session.close}. +`tf.Session.close`. The intermediate tensors that are created as part of a call to @{$python/client$`Session.run()`} will be freed at or before the @@ -120,7 +120,7 @@ dimensions: devices, which makes it possible to speed up @{$deep_cnn$CIFAR-10 training using multiple GPUs}. * The Session API allows multiple concurrent steps (i.e. calls to - @{tf.Session.run} in parallel). This + `tf.Session.run` in parallel). This enables the runtime to get higher throughput, if a single step does not use all of the resources in your computer. @@ -151,8 +151,8 @@ than 3.5. #### Why does `Session.run()` hang when using a reader or a queue? -The @{tf.ReaderBase} and -@{tf.QueueBase} classes provide special operations that +The `tf.ReaderBase` and +`tf.QueueBase` classes provide special operations that can *block* until input (or free space in a bounded queue) becomes available. These operations allow you to build sophisticated @{$reading_data$input pipelines}, at the cost of making the @@ -169,9 +169,9 @@ See also the how-to documentation on @{$variables$variables} and #### What is the lifetime of a variable? A variable is created when you first run the -@{tf.Variable.initializer} +`tf.Variable.initializer` operation for that variable in a session. It is destroyed when that -@{tf.Session.close}. +`tf.Session.close`. #### How do variables behave when they are concurrently accessed? @@ -179,32 +179,31 @@ Variables allow concurrent read and write operations. The value read from a variable may change if it is concurrently updated. By default, concurrent assignment operations to a variable are allowed to run with no mutual exclusion. To acquire a lock when assigning to a variable, pass `use_locking=True` to -@{tf.Variable.assign}. +`tf.Variable.assign`. ## Tensor shapes See also the -@{tf.TensorShape}. +`tf.TensorShape`. #### How can I determine the shape of a tensor in Python? In TensorFlow, a tensor has both a static (inferred) shape and a dynamic (true) shape. The static shape can be read using the -@{tf.Tensor.get_shape} +`tf.Tensor.get_shape` method: this shape is inferred from the operations that were used to create the -tensor, and may be -@{tf.TensorShape$partially complete}. If the static -shape is not fully defined, the dynamic shape of a `Tensor` `t` can be -determined by evaluating @{tf.shape$`tf.shape(t)`}. +tensor, and may be partially complete (the static-shape may contain `None`). If +the static shape is not fully defined, the dynamic shape of a `tf.Tensor`, `t` +can be determined using `tf.shape(t)`. #### What is the difference between `x.set_shape()` and `x = tf.reshape(x)`? -The @{tf.Tensor.set_shape} method updates +The `tf.Tensor.set_shape` method updates the static shape of a `Tensor` object, and it is typically used to provide additional shape information when this cannot be inferred directly. It does not change the dynamic shape of the tensor. -The @{tf.reshape} operation creates +The `tf.reshape` operation creates a new tensor with a different dynamic shape. #### How do I build a graph that works with variable batch sizes? @@ -212,9 +211,9 @@ a new tensor with a different dynamic shape. It is often useful to build a graph that works with variable batch sizes so that the same code can be used for (mini-)batch training, and single-instance inference. The resulting graph can be -@{tf.Graph.as_graph_def$saved as a protocol buffer} +`tf.Graph.as_graph_def` and -@{tf.import_graph_def$imported into another program}. +`tf.import_graph_def`. When building a variable-size graph, the most important thing to remember is not to encode the batch size as a Python constant, but instead to use a symbolic @@ -224,7 +223,7 @@ to encode the batch size as a Python constant, but instead to use a symbolic to extract the batch dimension from a `Tensor` called `input`, and store it in a `Tensor` called `batch_size`. -* Use @{tf.reduce_mean} instead +* Use `tf.reduce_mean` instead of `tf.reduce_sum(...) / batch_size`. @@ -259,19 +258,19 @@ See the how-to documentation for There are three main options for dealing with data in a custom format. The easiest option is to write parsing code in Python that transforms the data -into a numpy array. Then, use @{tf.data.Dataset.from_tensor_slices} to +into a numpy array. Then, use `tf.data.Dataset.from_tensor_slices` to create an input pipeline from the in-memory data. If your data doesn't fit in memory, try doing the parsing in the Dataset pipeline. Start with an appropriate file reader, like -@{tf.data.TextLineDataset}. Then convert the dataset by mapping -@{tf.data.Dataset.map$mapping} appropriate operations over it. -Prefer predefined TensorFlow operations such as @{tf.decode_raw}, -@{tf.decode_csv}, @{tf.parse_example}, or @{tf.image.decode_png}. +`tf.data.TextLineDataset`. Then convert the dataset by mapping +`tf.data.Dataset.map` appropriate operations over it. +Prefer predefined TensorFlow operations such as `tf.decode_raw`, +`tf.decode_csv`, `tf.parse_example`, or `tf.image.decode_png`. If your data is not easily parsable with the built-in TensorFlow operations, consider converting it, offline, to a format that is easily parsable, such -as @{tf.python_io.TFRecordWriter$`TFRecord`} format. +as `tf.python_io.TFRecordWriter` format. The most efficient method to customize the parsing behavior is to @{$adding_an_op$add a new op written in C++} that parses your diff --git a/tensorflow/docs_src/programmers_guide/feature_columns.md b/tensorflow/docs_src/guide/feature_columns.md similarity index 93% rename from tensorflow/docs_src/programmers_guide/feature_columns.md rename to tensorflow/docs_src/guide/feature_columns.md index 90f5c53a17f23200f238f6b0d171e1e225330e27..b189c4334ed5a5428de223f92de8d93f4ef052ba 100644 --- a/tensorflow/docs_src/programmers_guide/feature_columns.md +++ b/tensorflow/docs_src/guide/feature_columns.md @@ -6,10 +6,10 @@ enabling you to transform a diverse range of raw data into formats that Estimators can use, allowing easy experimentation. In @{$premade_estimators$Premade Estimators}, we used the premade -Estimator, @{tf.estimator.DNNClassifier$`DNNClassifier`} to train a model to +Estimator, `tf.estimator.DNNClassifier` to train a model to predict different types of Iris flowers from four input features. That example created only numerical feature columns (of type -@{tf.feature_column.numeric_column}). Although numerical feature columns model +`tf.feature_column.numeric_column`). Although numerical feature columns model the lengths of petals and sepals effectively, real world data sets contain all kinds of features, many of which are non-numerical. @@ -59,7 +59,7 @@ Feature columns bridge raw data with the data your model needs. To create feature columns, call functions from the -@{tf.feature_column} module. This document explains nine of the functions in +`tf.feature_column` module. This document explains nine of the functions in that module. As the following figure shows, all nine functions return either a Categorical-Column or a Dense-Column object, except `bucketized_column`, which inherits from both classes: @@ -75,7 +75,7 @@ Let's look at these functions in more detail. ### Numeric column -The Iris classifier calls the @{tf.feature_column.numeric_column} function for +The Iris classifier calls the `tf.feature_column.numeric_column` function for all input features: * `SepalLength` @@ -119,7 +119,7 @@ matrix_feature_column = tf.feature_column.numeric_column(key="MyMatrix", Often, you don't want to feed a number directly into the model, but instead split its value into different categories based on numerical ranges. To do so, -create a @{tf.feature_column.bucketized_column$bucketized column}. For +create a `tf.feature_column.bucketized_column`. For example, consider raw data that represents the year a house was built. Instead of representing that year as a scalar numeric column, we could split the year into the following four buckets: @@ -194,7 +194,7 @@ value. That is: * `1="electronics"` * `2="sport"` -Call @{tf.feature_column.categorical_column_with_identity} to implement a +Call `tf.feature_column.categorical_column_with_identity` to implement a categorical identity column. For example: ``` python @@ -230,8 +230,8 @@ As you can see, categorical vocabulary columns are kind of an enum version of categorical identity columns. TensorFlow provides two different functions to create categorical vocabulary columns: -* @{tf.feature_column.categorical_column_with_vocabulary_list} -* @{tf.feature_column.categorical_column_with_vocabulary_file} +* `tf.feature_column.categorical_column_with_vocabulary_list` +* `tf.feature_column.categorical_column_with_vocabulary_file` `categorical_column_with_vocabulary_list` maps each string to an integer based on an explicit vocabulary list. For example: @@ -281,7 +281,7 @@ categories can be so big that it's not possible to have individual categories for each vocabulary word or integer because that would consume too much memory. For these cases, we can instead turn the question around and ask, "How many categories am I willing to have for my input?" In fact, the -@{tf.feature_column.categorical_column_with_hash_bucket} function enables you +`tf.feature_column.categorical_column_with_hash_bucket` function enables you to specify the number of categories. For this type of feature column the model calculates a hash value of the input, then puts it into one of the `hash_bucket_size` categories using the modulo operator, as in the following @@ -289,7 +289,7 @@ pseudocode: ```python # pseudocode -feature_id = hash(raw_feature) % hash_buckets_size +feature_id = hash(raw_feature) % hash_bucket_size ``` The code to create the `feature_column` might look something like this: @@ -298,7 +298,7 @@ The code to create the `feature_column` might look something like this: hashed_feature_column = tf.feature_column.categorical_column_with_hash_bucket( key = "some_feature", - hash_buckets_size = 100) # The number of categories + hash_bucket_size = 100) # The number of categories ``` At this point, you might rightfully think: "This is crazy!" After all, we are forcing the different input values to a smaller set of categories. This means @@ -349,7 +349,7 @@ equal size. For the solution, we used a combination of the `bucketized_column` we looked at -earlier, with the @{tf.feature_column.crossed_column} function. +earlier, with the `tf.feature_column.crossed_column` function. @@ -440,7 +440,7 @@ Representing data in indicator columns. Here's how you create an indicator column by calling -@{tf.feature_column.indicator_column}: +`tf.feature_column.indicator_column`: ``` python categorical_column = ... # Create any type of categorical column. @@ -521,7 +521,7 @@ number of dimensions is 3: Note that this is just a general guideline; you can set the number of embedding dimensions as you please. -Call @{tf.feature_column.embedding_column} to create an `embedding_column` as +Call `tf.feature_column.embedding_column` to create an `embedding_column` as suggested by the following snippet: ``` python @@ -534,7 +534,7 @@ embedding_column = tf.feature_column.embedding_column( dimension=embedding_dimensions) ``` -@{$programmers_guide/embedding$Embeddings} is a significant topic within machine +@{$guide/embedding$Embeddings} is a significant topic within machine learning. This information was just to get you started using them as feature columns. @@ -543,15 +543,15 @@ columns. As the following list indicates, not all Estimators permit all types of `feature_columns` argument(s): -* @{tf.estimator.LinearClassifier$`LinearClassifier`} and - @{tf.estimator.LinearRegressor$`LinearRegressor`}: Accept all types of +* `tf.estimator.LinearClassifier` and + `tf.estimator.LinearRegressor`: Accept all types of feature column. -* @{tf.estimator.DNNClassifier$`DNNClassifier`} and - @{tf.estimator.DNNRegressor$`DNNRegressor`}: Only accept dense columns. Other +* `tf.estimator.DNNClassifier` and + `tf.estimator.DNNRegressor`: Only accept dense columns. Other column types must be wrapped in either an `indicator_column` or `embedding_column`. -* @{tf.estimator.DNNLinearCombinedClassifier$`DNNLinearCombinedClassifier`} and - @{tf.estimator.DNNLinearCombinedRegressor$`DNNLinearCombinedRegressor`}: +* `tf.estimator.DNNLinearCombinedClassifier` and + `tf.estimator.DNNLinearCombinedRegressor`: * The `linear_feature_columns` argument accepts any feature column type. * The `dnn_feature_columns` argument only accepts dense columns. @@ -561,9 +561,9 @@ For more examples on feature columns, view the following: * The @{$low_level_intro#feature_columns$Low Level Introduction} demonstrates how experiment directly with `feature_columns` using TensorFlow's low level APIs. -* The @{$wide$wide} and @{$wide_and_deep$Wide & Deep} Tutorials solve a - binary classification problem using `feature_columns` on a variety of input - data types. +* The [Estimator wide and deep learning tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep) + solves a binary classification problem using `feature_columns` on a variety of + input data types. To learn more about embeddings, see the following: diff --git a/tensorflow/docs_src/programmers_guide/graph_viz.md b/tensorflow/docs_src/guide/graph_viz.md similarity index 99% rename from tensorflow/docs_src/programmers_guide/graph_viz.md rename to tensorflow/docs_src/guide/graph_viz.md index f581ae56dae45238d697196e8ad56c86f7309604..97b0e2d4de8e8658f6cde787bc030fe074e59d49 100644 --- a/tensorflow/docs_src/programmers_guide/graph_viz.md +++ b/tensorflow/docs_src/guide/graph_viz.md @@ -15,7 +15,7 @@ variable names can be scoped and the visualization uses this information to define a hierarchy on the nodes in the graph. By default, only the top of this hierarchy is shown. Here is an example that defines three operations under the `hidden` name scope using -@{tf.name_scope}: +`tf.name_scope`: ```python import tensorflow as tf @@ -248,7 +248,8 @@ The images below show the CIFAR-10 model with tensor shape information: Often it is useful to collect runtime metadata for a run, such as total memory usage, total compute time, and tensor shapes for nodes. The code example below is a snippet from the train and test section of a modification of the -@{$layers$simple MNIST tutorial}, in which we have recorded summaries and +[Estimators MNIST tutorial](../tutorials/estimators/cnn.md), in which we have +recorded summaries and runtime statistics. See the @{$summaries_and_tensorboard#serializing-the-data$Summaries Tutorial} for details on how to record summaries. diff --git a/tensorflow/docs_src/programmers_guide/graphs.md b/tensorflow/docs_src/guide/graphs.md similarity index 71% rename from tensorflow/docs_src/programmers_guide/graphs.md rename to tensorflow/docs_src/guide/graphs.md index f0dd8def17fd6dfed241167a5ebb5be678152c16..2bb44fbb327d14fe2650bfa8adb1740312f136f0 100644 --- a/tensorflow/docs_src/programmers_guide/graphs.md +++ b/tensorflow/docs_src/guide/graphs.md @@ -7,7 +7,7 @@ TensorFlow **session** to run parts of the graph across a set of local and remote devices. This guide will be most useful if you intend to use the low-level programming -model directly. Higher-level APIs such as @{tf.estimator.Estimator} and Keras +model directly. Higher-level APIs such as `tf.estimator.Estimator` and Keras hide the details of graphs and sessions from the end user, but this guide may also be useful if you want to understand how these APIs are implemented. @@ -18,12 +18,12 @@ also be useful if you want to understand how these APIs are implemented. [Dataflow](https://en.wikipedia.org/wiki/Dataflow_programming) is a common programming model for parallel computing. In a dataflow graph, the nodes represent units of computation, and the edges represent the data consumed or -produced by a computation. For example, in a TensorFlow graph, the @{tf.matmul} +produced by a computation. For example, in a TensorFlow graph, the `tf.matmul` operation would correspond to a single node with two incoming edges (the matrices to be multiplied) and one outgoing edge (the result of the multiplication). - + Dataflow has several advantages that TensorFlow leverages when executing your programs: @@ -48,9 +48,9 @@ programs: low-latency inference. -## What is a @{tf.Graph}? +## What is a `tf.Graph`? -A @{tf.Graph} contains two relevant kinds of information: +A `tf.Graph` contains two relevant kinds of information: * **Graph structure.** The nodes and edges of the graph, indicating how individual operations are composed together, but not prescribing how they @@ -59,78 +59,78 @@ A @{tf.Graph} contains two relevant kinds of information: context that source code conveys. * **Graph collections.** TensorFlow provides a general mechanism for storing - collections of metadata in a @{tf.Graph}. The @{tf.add_to_collection} function - enables you to associate a list of objects with a key (where @{tf.GraphKeys} - defines some of the standard keys), and @{tf.get_collection} enables you to + collections of metadata in a `tf.Graph`. The `tf.add_to_collection` function + enables you to associate a list of objects with a key (where `tf.GraphKeys` + defines some of the standard keys), and `tf.get_collection` enables you to look up all objects associated with a key. Many parts of the TensorFlow - library use this facility: for example, when you create a @{tf.Variable}, it + library use this facility: for example, when you create a `tf.Variable`, it is added by default to collections representing "global variables" and - "trainable variables". When you later come to create a @{tf.train.Saver} or - @{tf.train.Optimizer}, the variables in these collections are used as the + "trainable variables". When you later come to create a `tf.train.Saver` or + `tf.train.Optimizer`, the variables in these collections are used as the default arguments. -## Building a @{tf.Graph} +## Building a `tf.Graph` Most TensorFlow programs start with a dataflow graph construction phase. In this -phase, you invoke TensorFlow API functions that construct new @{tf.Operation} -(node) and @{tf.Tensor} (edge) objects and add them to a @{tf.Graph} +phase, you invoke TensorFlow API functions that construct new `tf.Operation` +(node) and `tf.Tensor` (edge) objects and add them to a `tf.Graph` instance. TensorFlow provides a **default graph** that is an implicit argument to all API functions in the same context. For example: -* Calling `tf.constant(42.0)` creates a single @{tf.Operation} that produces the - value `42.0`, adds it to the default graph, and returns a @{tf.Tensor} that +* Calling `tf.constant(42.0)` creates a single `tf.Operation` that produces the + value `42.0`, adds it to the default graph, and returns a `tf.Tensor` that represents the value of the constant. -* Calling `tf.matmul(x, y)` creates a single @{tf.Operation} that multiplies - the values of @{tf.Tensor} objects `x` and `y`, adds it to the default graph, - and returns a @{tf.Tensor} that represents the result of the multiplication. +* Calling `tf.matmul(x, y)` creates a single `tf.Operation` that multiplies + the values of `tf.Tensor` objects `x` and `y`, adds it to the default graph, + and returns a `tf.Tensor` that represents the result of the multiplication. -* Executing `v = tf.Variable(0)` adds to the graph a @{tf.Operation} that will - store a writeable tensor value that persists between @{tf.Session.run} calls. - The @{tf.Variable} object wraps this operation, and can be used [like a +* Executing `v = tf.Variable(0)` adds to the graph a `tf.Operation` that will + store a writeable tensor value that persists between `tf.Session.run` calls. + The `tf.Variable` object wraps this operation, and can be used [like a tensor](#tensor-like_objects), which will read the current value of the - stored value. The @{tf.Variable} object also has methods such as - @{tf.Variable.assign$`assign`} and @{tf.Variable.assign_add$`assign_add`} that - create @{tf.Operation} objects that, when executed, update the stored value. - (See @{$programmers_guide/variables} for more information about variables.) + stored value. The `tf.Variable` object also has methods such as + `tf.Variable.assign` and `tf.Variable.assign_add` that + create `tf.Operation` objects that, when executed, update the stored value. + (See @{$guide/variables} for more information about variables.) -* Calling @{tf.train.Optimizer.minimize} will add operations and tensors to the - default graph that calculates gradients, and return a @{tf.Operation} that, +* Calling `tf.train.Optimizer.minimize` will add operations and tensors to the + default graph that calculates gradients, and return a `tf.Operation` that, when run, will apply those gradients to a set of variables. Most programs rely solely on the default graph. However, see [Dealing with multiple graphs](#programming_with_multiple_graphs) for more -advanced use cases. High-level APIs such as the @{tf.estimator.Estimator} API +advanced use cases. High-level APIs such as the `tf.estimator.Estimator` API manage the default graph on your behalf, and--for example--may create different graphs for training and evaluation. Note: Calling most functions in the TensorFlow API merely adds operations and tensors to the default graph, but **does not** perform the actual -computation. Instead, you compose these functions until you have a @{tf.Tensor} -or @{tf.Operation} that represents the overall computation--such as performing -one step of gradient descent--and then pass that object to a @{tf.Session} to -perform the computation. See the section "Executing a graph in a @{tf.Session}" +computation. Instead, you compose these functions until you have a `tf.Tensor` +or `tf.Operation` that represents the overall computation--such as performing +one step of gradient descent--and then pass that object to a `tf.Session` to +perform the computation. See the section "Executing a graph in a `tf.Session`" for more details. ## Naming operations -A @{tf.Graph} object defines a **namespace** for the @{tf.Operation} objects it +A `tf.Graph` object defines a **namespace** for the `tf.Operation` objects it contains. TensorFlow automatically chooses a unique name for each operation in your graph, but giving operations descriptive names can make your program easier to read and debug. The TensorFlow API provides two ways to override the name of an operation: -* Each API function that creates a new @{tf.Operation} or returns a new - @{tf.Tensor} accepts an optional `name` argument. For example, - `tf.constant(42.0, name="answer")` creates a new @{tf.Operation} named - `"answer"` and returns a @{tf.Tensor} named `"answer:0"`. If the default graph +* Each API function that creates a new `tf.Operation` or returns a new + `tf.Tensor` accepts an optional `name` argument. For example, + `tf.constant(42.0, name="answer")` creates a new `tf.Operation` named + `"answer"` and returns a `tf.Tensor` named `"answer:0"`. If the default graph already contains an operation named `"answer"`, then TensorFlow would append `"_1"`, `"_2"`, and so on to the name, in order to make it unique. -* The @{tf.name_scope} function makes it possible to add a **name scope** prefix +* The `tf.name_scope` function makes it possible to add a **name scope** prefix to all operations created in a particular context. The current name scope - prefix is a `"/"`-delimited list of the names of all active @{tf.name_scope} + prefix is a `"/"`-delimited list of the names of all active `tf.name_scope` context managers. If a name scope has already been used in the current context, TensorFlow appends `"_1"`, `"_2"`, and so on. For example: @@ -160,7 +160,7 @@ The graph visualizer uses name scopes to group operations and reduce the visual complexity of a graph. See [Visualizing your graph](#visualizing-your-graph) for more information. -Note that @{tf.Tensor} objects are implicitly named after the @{tf.Operation} +Note that `tf.Tensor` objects are implicitly named after the `tf.Operation` that produces the tensor as output. A tensor name has the form `":"` where: @@ -171,7 +171,7 @@ where: ## Placing operations on different devices If you want your TensorFlow program to use multiple different devices, the -@{tf.device} function provides a convenient way to request that all operations +`tf.device` function provides a convenient way to request that all operations created in a particular context are placed on the same device (or type of device). @@ -186,7 +186,7 @@ where: * `` is an alpha-numeric string that does not start with a number. * `` is a registered device type (such as `GPU` or `CPU`). * `` is a non-negative integer representing the index of the task - in the job named ``. See @{tf.train.ClusterSpec} for an explanation + in the job named ``. See `tf.train.ClusterSpec` for an explanation of jobs and tasks. * `` is a non-negative integer representing the index of the device, for example, to distinguish between different GPU devices used in the @@ -194,7 +194,7 @@ where: You do not need to specify every part of a device specification. For example, if you are running in a single-machine configuration with a single GPU, you -might use @{tf.device} to pin some operations to the CPU and GPU: +might use `tf.device` to pin some operations to the CPU and GPU: ```python # Operations created outside either context will run on the "best possible" @@ -229,13 +229,13 @@ with tf.device("/job:worker"): layer_2 = tf.matmul(train_batch, weights_2) + biases_2 ``` -@{tf.device} gives you a lot of flexibility to choose placements for individual +`tf.device` gives you a lot of flexibility to choose placements for individual operations or broad regions of a TensorFlow graph. In many cases, there are simple heuristics that work well. For example, the -@{tf.train.replica_device_setter} API can be used with @{tf.device} to place +`tf.train.replica_device_setter` API can be used with `tf.device` to place operations for **data-parallel distributed training**. For example, the -following code fragment shows how @{tf.train.replica_device_setter} applies -different placement policies to @{tf.Variable} objects and other operations: +following code fragment shows how `tf.train.replica_device_setter` applies +different placement policies to `tf.Variable` objects and other operations: ```python with tf.device(tf.train.replica_device_setter(ps_tasks=3)): @@ -253,41 +253,41 @@ with tf.device(tf.train.replica_device_setter(ps_tasks=3)): ## Tensor-like objects -Many TensorFlow operations take one or more @{tf.Tensor} objects as arguments. -For example, @{tf.matmul} takes two @{tf.Tensor} objects, and @{tf.add_n} takes -a list of `n` @{tf.Tensor} objects. For convenience, these functions will accept -a **tensor-like object** in place of a @{tf.Tensor}, and implicitly convert it -to a @{tf.Tensor} using the @{tf.convert_to_tensor} method. Tensor-like objects +Many TensorFlow operations take one or more `tf.Tensor` objects as arguments. +For example, `tf.matmul` takes two `tf.Tensor` objects, and `tf.add_n` takes +a list of `n` `tf.Tensor` objects. For convenience, these functions will accept +a **tensor-like object** in place of a `tf.Tensor`, and implicitly convert it +to a `tf.Tensor` using the `tf.convert_to_tensor` method. Tensor-like objects include elements of the following types: -* @{tf.Tensor} -* @{tf.Variable} +* `tf.Tensor` +* `tf.Variable` * [`numpy.ndarray`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.html) * `list` (and lists of tensor-like objects) * Scalar Python types: `bool`, `float`, `int`, `str` You can register additional tensor-like types using -@{tf.register_tensor_conversion_function}. +`tf.register_tensor_conversion_function`. -Note: By default, TensorFlow will create a new @{tf.Tensor} each time you use +Note: By default, TensorFlow will create a new `tf.Tensor` each time you use the same tensor-like object. If the tensor-like object is large (e.g. a `numpy.ndarray` containing a set of training examples) and you use it multiple times, you may run out of memory. To avoid this, manually call -@{tf.convert_to_tensor} on the tensor-like object once and use the returned -@{tf.Tensor} instead. +`tf.convert_to_tensor` on the tensor-like object once and use the returned +`tf.Tensor` instead. -## Executing a graph in a @{tf.Session} +## Executing a graph in a `tf.Session` -TensorFlow uses the @{tf.Session} class to represent a connection between the +TensorFlow uses the `tf.Session` class to represent a connection between the client program---typically a Python program, although a similar interface is -available in other languages---and the C++ runtime. A @{tf.Session} object +available in other languages---and the C++ runtime. A `tf.Session` object provides access to devices in the local machine, and remote devices using the distributed TensorFlow runtime. It also caches information about your -@{tf.Graph} so that you can efficiently run the same computation multiple times. +`tf.Graph` so that you can efficiently run the same computation multiple times. -### Creating a @{tf.Session} +### Creating a `tf.Session` -If you are using the low-level TensorFlow API, you can create a @{tf.Session} +If you are using the low-level TensorFlow API, you can create a `tf.Session` for the current default graph as follows: ```python @@ -300,50 +300,50 @@ with tf.Session("grpc://example.org:2222"): # ... ``` -Since a @{tf.Session} owns physical resources (such as GPUs and +Since a `tf.Session` owns physical resources (such as GPUs and network connections), it is typically used as a context manager (in a `with` block) that automatically closes the session when you exit the block. It is also possible to create a session without using a `with` block, but you should -explicitly call @{tf.Session.close} when you are finished with it to free the +explicitly call `tf.Session.close` when you are finished with it to free the resources. -Note: Higher-level APIs such as @{tf.train.MonitoredTrainingSession} or -@{tf.estimator.Estimator} will create and manage a @{tf.Session} for you. These +Note: Higher-level APIs such as `tf.train.MonitoredTrainingSession` or +`tf.estimator.Estimator` will create and manage a `tf.Session` for you. These APIs accept optional `target` and `config` arguments (either directly, or as -part of a @{tf.estimator.RunConfig} object), with the same meaning as +part of a `tf.estimator.RunConfig` object), with the same meaning as described below. -@{tf.Session.__init__} accepts three optional arguments: +`tf.Session.__init__` accepts three optional arguments: * **`target`.** If this argument is left empty (the default), the session will only use devices in the local machine. However, you may also specify a `grpc://` URL to specify the address of a TensorFlow server, which gives the session access to all devices on machines that this server controls. See - @{tf.train.Server} for details of how to create a TensorFlow + `tf.train.Server` for details of how to create a TensorFlow server. For example, in the common **between-graph replication** - configuration, the @{tf.Session} connects to a @{tf.train.Server} in the same + configuration, the `tf.Session` connects to a `tf.train.Server` in the same process as the client. The [distributed TensorFlow](../deploy/distributed.md) deployment guide describes other common scenarios. -* **`graph`.** By default, a new @{tf.Session} will be bound to---and only able +* **`graph`.** By default, a new `tf.Session` will be bound to---and only able to run operations in---the current default graph. If you are using multiple graphs in your program (see [Programming with multiple graphs](#programming_with_multiple_graphs) for more details), you can specify - an explicit @{tf.Graph} when you construct the session. + an explicit `tf.Graph` when you construct the session. -* **`config`.** This argument allows you to specify a @{tf.ConfigProto} that +* **`config`.** This argument allows you to specify a `tf.ConfigProto` that controls the behavior of the session. For example, some of the configuration options include: * `allow_soft_placement`. Set this to `True` to enable a "soft" device - placement algorithm, which ignores @{tf.device} annotations that attempt + placement algorithm, which ignores `tf.device` annotations that attempt to place CPU-only operations on a GPU device, and places them on the CPU instead. * `cluster_def`. When using distributed TensorFlow, this option allows you to specify what machines to use in the computation, and provide a mapping between job names, task indices, and network addresses. See - @{tf.train.ClusterSpec.as_cluster_def} for details. + `tf.train.ClusterSpec.as_cluster_def` for details. * `graph_options.optimizer_options`. Provides control over the optimizations that TensorFlow performs on your graph before executing it. @@ -353,21 +353,21 @@ described below. rather than allocating most of the memory at startup. -### Using @{tf.Session.run} to execute operations +### Using `tf.Session.run` to execute operations -The @{tf.Session.run} method is the main mechanism for running a @{tf.Operation} -or evaluating a @{tf.Tensor}. You can pass one or more @{tf.Operation} or -@{tf.Tensor} objects to @{tf.Session.run}, and TensorFlow will execute the +The `tf.Session.run` method is the main mechanism for running a `tf.Operation` +or evaluating a `tf.Tensor`. You can pass one or more `tf.Operation` or +`tf.Tensor` objects to `tf.Session.run`, and TensorFlow will execute the operations that are needed to compute the result. -@{tf.Session.run} requires you to specify a list of **fetches**, which determine -the return values, and may be a @{tf.Operation}, a @{tf.Tensor}, or -a [tensor-like type](#tensor-like_objects) such as @{tf.Variable}. These fetches -determine what **subgraph** of the overall @{tf.Graph} must be executed to +`tf.Session.run` requires you to specify a list of **fetches**, which determine +the return values, and may be a `tf.Operation`, a `tf.Tensor`, or +a [tensor-like type](#tensor-like_objects) such as `tf.Variable`. These fetches +determine what **subgraph** of the overall `tf.Graph` must be executed to produce the result: this is the subgraph that contains all operations named in the fetch list, plus all operations whose outputs are used to compute the value of the fetches. For example, the following code fragment shows how different -arguments to @{tf.Session.run} cause different subgraphs to be executed: +arguments to `tf.Session.run` cause different subgraphs to be executed: ```python x = tf.constant([[37.0, -23.0], [1.0, 4.0]]) @@ -390,8 +390,8 @@ with tf.Session() as sess: y_val, output_val = sess.run([y, output]) ``` -@{tf.Session.run} also optionally takes a dictionary of **feeds**, which is a -mapping from @{tf.Tensor} objects (typically @{tf.placeholder} tensors) to +`tf.Session.run` also optionally takes a dictionary of **feeds**, which is a +mapping from `tf.Tensor` objects (typically `tf.placeholder` tensors) to values (typically Python scalars, lists, or NumPy arrays) that will be substituted for those tensors in the execution. For example: @@ -415,7 +415,7 @@ with tf.Session() as sess: sess.run(y, {x: 37.0}) ``` -@{tf.Session.run} also accepts an optional `options` argument that enables you +`tf.Session.run` also accepts an optional `options` argument that enables you to specify options about the call, and an optional `run_metadata` argument that enables you to collect metadata about the execution. For example, you can use these options together to collect tracing information about the execution: @@ -447,8 +447,8 @@ with tf.Session() as sess: TensorFlow includes tools that can help you to understand the code in a graph. The **graph visualizer** is a component of TensorBoard that renders the structure of your graph visually in a browser. The easiest way to create a -visualization is to pass a @{tf.Graph} when creating the -@{tf.summary.FileWriter}: +visualization is to pass a `tf.Graph` when creating the +`tf.summary.FileWriter`: ```python # Build your graph. @@ -471,7 +471,7 @@ with tf.Session() as sess: writer.close() ``` -Note: If you are using a @{tf.estimator.Estimator}, the graph (and any +Note: If you are using a `tf.estimator.Estimator`, the graph (and any summaries) will be logged automatically to the `model_dir` that you specified when creating the estimator. @@ -486,7 +486,7 @@ subgraph inside. ![](../images/mnist_deep.png) For more information about visualizing your TensorFlow application with -TensorBoard, see the [TensorBoard tutorial](../get_started/summaries_and_tensorboard.md). +TensorBoard, see the [TensorBoard guide](./summaries_and_tensorboard.md). ## Programming with multiple graphs @@ -495,8 +495,8 @@ graph for training your model, and a separate graph for evaluating or performing inference with a trained model. In many cases, the inference graph will be different from the training graph: for example, techniques like dropout and batch normalization use different operations in each case. Furthermore, by -default utilities like @{tf.train.Saver} use the names of @{tf.Variable} objects -(which have names based on an underlying @{tf.Operation}) to identify each +default utilities like `tf.train.Saver` use the names of `tf.Variable` objects +(which have names based on an underlying `tf.Operation`) to identify each variable in a saved checkpoint. When programming this way, you can either use completely separate Python processes to build and execute the graphs, or you can use multiple graphs in the same process. This section describes how to use @@ -507,21 +507,21 @@ to all API functions in the same context. For many applications, a single graph is sufficient. However, TensorFlow also provides methods for manipulating the default graph, which can be useful in more advanced use cases. For example: -* A @{tf.Graph} defines the namespace for @{tf.Operation} objects: each +* A `tf.Graph` defines the namespace for `tf.Operation` objects: each operation in a single graph must have a unique name. TensorFlow will "uniquify" the names of operations by appending `"_1"`, `"_2"`, and so on to their names if the requested name is already taken. Using multiple explicitly created graphs gives you more control over what name is given to each operation. -* The default graph stores information about every @{tf.Operation} and - @{tf.Tensor} that was ever added to it. If your program creates a large number +* The default graph stores information about every `tf.Operation` and + `tf.Tensor` that was ever added to it. If your program creates a large number of unconnected subgraphs, it may be more efficient to use a different - @{tf.Graph} to build each subgraph, so that unrelated state can be garbage + `tf.Graph` to build each subgraph, so that unrelated state can be garbage collected. -You can install a different @{tf.Graph} as the default graph, using the -@{tf.Graph.as_default} context manager: +You can install a different `tf.Graph` as the default graph, using the +`tf.Graph.as_default` context manager: ```python g_1 = tf.Graph() @@ -548,8 +548,8 @@ assert d.graph is g_2 assert sess_2.graph is g_2 ``` -To inspect the current default graph, call @{tf.get_default_graph}, which -returns a @{tf.Graph} object: +To inspect the current default graph, call `tf.get_default_graph`, which +returns a `tf.Graph` object: ```python # Print all of the operations in the default graph. diff --git a/tensorflow/docs_src/programmers_guide/index.md b/tensorflow/docs_src/guide/index.md similarity index 59% rename from tensorflow/docs_src/programmers_guide/index.md rename to tensorflow/docs_src/guide/index.md index 0c2d4afb115c592c1925dde98b3a1a8c2a7ccad1..1c920e7d700c29b2851927beafa5ca4207787a09 100644 --- a/tensorflow/docs_src/programmers_guide/index.md +++ b/tensorflow/docs_src/guide/index.md @@ -1,29 +1,26 @@ -# Programmer's Guide +# TensorFlow Guide The documents in this unit dive into the details of how TensorFlow works. The units are as follows: ## High Level APIs - * @{$programmers_guide/keras}, TensorFlow's high-level API for building and + * @{$guide/keras}, TensorFlow's high-level API for building and training deep learning models. - * @{$programmers_guide/eager}, an API for writing TensorFlow code + * @{$guide/eager}, an API for writing TensorFlow code imperatively, like you would use Numpy. - * @{$programmers_guide/estimators}, a high-level API that provides - fully-packaged models ready for large-scale training and production. - * @{$programmers_guide/datasets}, easy input pipelines to bring your data into + * @{$guide/datasets}, easy input pipelines to bring your data into your TensorFlow program. + * @{$guide/estimators}, a high-level API that provides + fully-packaged models ready for large-scale training and production. ## Estimators -* @{$estimators} provides an introduction. -* @{$premade_estimators}, introduces Estimators for machine learning. -* @{$custom_estimators}, which demonstrates how to build and train models you - design yourself. -* @{$feature_columns}, which shows how an Estimator can handle a variety of input - data types without changes to the model. -* @{$checkpoints}, which explains how to save training progress and resume where - you left off. +* @{$premade_estimators}, the basics of premade Estimators. +* @{$checkpoints}, save training progress and resume where you left off. +* @{$feature_columns}, handle a variety of input data types without changes to the model. +* @{$datasets_for_estimators}, use `tf.data` to input data. +* @{$custom_estimators}, write your own Estimator. ## Accelerators @@ -33,13 +30,13 @@ works. The units are as follows: ## Low Level APIs - * @{$programmers_guide/low_level_intro}, which introduces the + * @{$guide/low_level_intro}, which introduces the basics of how you can use TensorFlow outside of the high Level APIs. - * @{$programmers_guide/tensors}, which explains how to create, + * @{$guide/tensors}, which explains how to create, manipulate, and access Tensors--the fundamental object in TensorFlow. - * @{$programmers_guide/variables}, which details how + * @{$guide/variables}, which details how to represent shared, persistent state in your program. - * @{$programmers_guide/graphs}, which explains: + * @{$guide/graphs}, which explains: * dataflow graphs, which are TensorFlow's representation of computations as dependencies between operations. * sessions, which are TensorFlow's mechanism for running dataflow graphs @@ -49,19 +46,19 @@ works. The units are as follows: such as Estimators or Keras, the high-level API creates and manages graphs and sessions for you, but understanding graphs and sessions can still be helpful. - * @{$programmers_guide/saved_model}, which + * @{$guide/saved_model}, which explains how to save and restore variables and models. ## ML Concepts - * @{$programmers_guide/embedding}, which introduces the concept + * @{$guide/embedding}, which introduces the concept of embeddings, provides a simple example of training an embedding in TensorFlow, and explains how to view embeddings with the TensorBoard Embedding Projector. ## Debugging - * @{$programmers_guide/debugger}, which + * @{$guide/debugger}, which explains how to use the TensorFlow debugger (tfdbg). ## TensorBoard @@ -69,17 +66,17 @@ works. The units are as follows: TensorBoard is a utility to visualize different aspects of machine learning. The following guides explain how to use TensorBoard: - * @{$programmers_guide/summaries_and_tensorboard}, + * @{$guide/summaries_and_tensorboard}, which introduces TensorBoard. - * @{$programmers_guide/graph_viz}, which + * @{$guide/graph_viz}, which explains how to visualize the computational graph. - * @{$programmers_guide/tensorboard_histograms} which demonstrates the how to + * @{$guide/tensorboard_histograms} which demonstrates the how to use TensorBoard's histogram dashboard. ## Misc - * @{$programmers_guide/version_compat}, + * @{$guide/version_compat}, which explains backward compatibility guarantees and non-guarantees. - * @{$programmers_guide/faq}, which contains frequently asked + * @{$guide/faq}, which contains frequently asked questions about TensorFlow. diff --git a/tensorflow/docs_src/programmers_guide/keras.md b/tensorflow/docs_src/guide/keras.md similarity index 94% rename from tensorflow/docs_src/programmers_guide/keras.md rename to tensorflow/docs_src/guide/keras.md index c6aca7ebf4edd085e2b47492fd3a86578620492d..2330fa03c7401c91d588f47d7d62484c09a73c5f 100644 --- a/tensorflow/docs_src/programmers_guide/keras.md +++ b/tensorflow/docs_src/guide/keras.md @@ -19,7 +19,7 @@ fast prototyping, advanced research, and production, with three key advantages: [Keras API specification](https://keras.io){:.external}. This is a high-level API to build and train models that includes first-class support for TensorFlow-specific functionality, such as [eager execution](#eager_execution), -`tf.data` pipelines, and [Estimators](/programmers_guide/estimators). +`tf.data` pipelines, and [Estimators](./estimators.md). `tf.keras` makes TensorFlow easier to use without sacrificing flexibility and performance. @@ -35,8 +35,8 @@ from tensorflow import keras * The `tf.keras` version in the latest TensorFlow release might not be the same as the latest `keras` version from PyPI. Check `tf.keras.__version__`. * When [saving a model's weights](#weights_only), `tf.keras` defaults to the - [checkpoint format](/get_started/checkpoints). Pass `save_format='h5'` to use - HDF5. + [checkpoint format](./checkpoints.md). Pass `save_format='h5'` to + use HDF5. ## Build a simple model @@ -179,7 +179,7 @@ model.fit(data, labels, epochs=10, batch_size=32, ### Input tf.data datasets -Use the [Datasets API](/programmers_guide/datasets) to scale to large datasets +Use the [Datasets API](./datasets.md) to scale to large datasets or multi-device training. Pass a `tf.data.Dataset` instance to the `fit` method: @@ -221,7 +221,7 @@ To *evaluate* the inference-mode loss and metrics for the data provided: ```python model.evaluate(x, y, batch_size=32) -model.evaluate(dataset, steps=30 +model.evaluate(dataset, steps=30) ``` And to *predict* the output of the last layer in inference for the data provided, @@ -285,7 +285,7 @@ your own forward pass. Create layers in the `__init__` method and set them as attributes of the class instance. Define the forward pass in the `call` method. Model subclassing is particularly useful when -[eager execution](/programmers_guide/eager) is enabled since the forward pass +[eager execution](./eager.md) is enabled since the forward pass can be written imperatively. Key Point: Use the right API for the job. While model subclassing offers @@ -410,7 +410,7 @@ during training. You can write your own custom callback, or use the built-in * `tf.keras.callbacks.EarlyStopping`: Interrupt training when validation performance has stopped improving. * `tf.keras.callbacks.TensorBoard`: Monitor the model's behavior using - [TensorBoard](/programmers_guide/summaries_and_tensorboard). + [TensorBoard](./summaries_and_tensorboard.md). To use a `tf.keras.callbacks.Callback`, pass it to the model's `fit` method: @@ -442,8 +442,8 @@ model.load_weights('my_model') ``` By default, this saves the model's weights in the -[TensorFlow checkpoint](/get_started/checkpoints) file format. Weights can also -be saved to the Keras HDF5 format (the default for the multi-backend +[TensorFlow checkpoint](./checkpoints.md) file format. Weights can +also be saved to the Keras HDF5 format (the default for the multi-backend implementation of Keras): ```python @@ -467,13 +467,13 @@ JSON and YAML serialization formats: json_string = model.to_json() # Recreate the model (freshly initialized) -fresh_model = keras.models.from_json(json_string) +fresh_model = keras.models.model_from_json(json_string) # Serializes a model to YAML format yaml_string = model.to_yaml() # Recreate the model -fresh_model = keras.models.from_yaml(yaml_string) +fresh_model = keras.models.model_from_yaml(yaml_string) ``` Caution: Subclassed models are not serializable because their architecture is @@ -509,7 +509,7 @@ model = keras.models.load_model('my_model.h5') ## Eager execution -[Eager execution](/programmers_guide/eager) is an imperative programming +[Eager execution](./eager.md) is an imperative programming environment that evaluates operations immediately. This is not required for Keras, but is supported by `tf.keras` and useful for inspecting your program and debugging. @@ -520,7 +520,7 @@ especially benefits *model subclassing* and building *custom layers*—the APIs that require you to write the forward pass as code (instead of the APIs that create models by assembling existing layers). -See the [eager execution guide](/programmers_guide/eager#build_a_model) for +See the [eager execution guide](./eager.md#build_a_model) for examples of using Keras models with custom training loops and `tf.GradientTape`. @@ -528,14 +528,14 @@ examples of using Keras models with custom training loops and `tf.GradientTape`. ### Estimators -The [Estimators](/programmers_guide/estimators) API is used for training models +The [Estimators](./estimators.md) API is used for training models for distributed environments. This targets industry use cases such as distributed training on large datasets that can export a model for production. A `tf.keras.Model` can be trained with the `tf.estimator` API by converting the model to an `tf.estimator.Estimator` object with `tf.keras.estimator.model_to_estimator`. See -[Creating Estimators from Keras models](/programmers_guide/estimators#creating_estimators_from_keras_models). +[Creating Estimators from Keras models](./estimators.md#creating_estimators_from_keras_models). ```python model = keras.Sequential([layers.Dense(10,activation='softmax'), @@ -548,8 +548,8 @@ model.compile(optimizer=tf.train.RMSPropOptimizer(0.001), estimator = keras.estimator.model_to_estimator(model) ``` -Note: Enable [eager execution](/programmers_guide/eager) for debugging -[Estimator input functions](/programmers_guide/premade_estimators#create_input_functions) +Note: Enable [eager execution](./eager.md) for debugging +[Estimator input functions](./premade_estimators.md#create_input_functions) and inspecting data. ### Multiple GPUs @@ -581,15 +581,6 @@ model.compile(loss='binary_crossentropy', optimizer=optimizer) model.summary() ``` -Convert the Keras model to a `tf.estimator.Estimator` instance: - -```python -keras_estimator = keras.estimator.model_to_estimator( - keras_model=model, - config=config, - model_dir='/tmp/model_dir') -``` - Define an *input pipeline*. The `input_fn` returns a `tf.data.Dataset` object used to distribute the data across multiple devices—with each device processing a slice of the input batch. @@ -615,6 +606,15 @@ strategy = tf.contrib.distribute.MirroredStrategy() config = tf.estimator.RunConfig(train_distribute=strategy) ``` +Convert the Keras model to a `tf.estimator.Estimator` instance: + +```python +keras_estimator = keras.estimator.model_to_estimator( + keras_model=model, + config=config, + model_dir='/tmp/model_dir') +``` + Finally, train the `Estimator` instance by providing the `input_fn` and `steps` arguments: diff --git a/tensorflow/docs_src/programmers_guide/leftnav_files b/tensorflow/docs_src/guide/leftnav_files similarity index 90% rename from tensorflow/docs_src/programmers_guide/leftnav_files rename to tensorflow/docs_src/guide/leftnav_files index 3bcf864e13db0cef40cec74ab872c807c2ec2fb0..8e227e0c8fc5cf7a30ed222706f89db9af482ec0 100644 --- a/tensorflow/docs_src/programmers_guide/leftnav_files +++ b/tensorflow/docs_src/guide/leftnav_files @@ -4,13 +4,14 @@ index.md keras.md eager.md datasets.md +estimators.md: Introduction to Estimators ### Estimators -estimators.md: Introduction to Estimators premade_estimators.md -custom_estimators.md -feature_columns.md checkpoints.md +feature_columns.md +datasets_for_estimators.md +custom_estimators.md ### Accelerators using_gpu.md @@ -22,6 +23,7 @@ tensors.md variables.md graphs.md saved_model.md +autograph.md : Control flow ### ML Concepts embedding.md diff --git a/tensorflow/docs_src/programmers_guide/low_level_intro.md b/tensorflow/docs_src/guide/low_level_intro.md similarity index 91% rename from tensorflow/docs_src/programmers_guide/low_level_intro.md rename to tensorflow/docs_src/guide/low_level_intro.md index 478e2bb70bc7f58156398c9f9fef4e76ba581e1a..dc6cb9ee0dfec37ce56f2c791f99f3f4917cf4f9 100644 --- a/tensorflow/docs_src/programmers_guide/low_level_intro.md +++ b/tensorflow/docs_src/guide/low_level_intro.md @@ -63,17 +63,17 @@ TensorFlow uses numpy arrays to represent tensor **values**. You might think of TensorFlow Core programs as consisting of two discrete sections: -1. Building the computational graph (a @{tf.Graph}). -2. Running the computational graph (using a @{tf.Session}). +1. Building the computational graph (a `tf.Graph`). +2. Running the computational graph (using a `tf.Session`). ### Graph A **computational graph** is a series of TensorFlow operations arranged into a graph. The graph is composed of two types of objects. - * @{tf.Operation$Operations} (or "ops"): The nodes of the graph. + * `tf.Operation` (or "ops"): The nodes of the graph. Operations describe calculations that consume and produce tensors. - * @{tf.Tensor$Tensors}: The edges in the graph. These represent the values + * `tf.Tensor`: The edges in the graph. These represent the values that will flow through the graph. Most TensorFlow functions return `tf.Tensors`. @@ -149,7 +149,7 @@ For more about TensorBoard's graph visualization tools see @{$graph_viz}. ### Session -To evaluate tensors, instantiate a @{tf.Session} object, informally known as a +To evaluate tensors, instantiate a `tf.Session` object, informally known as a **session**. A session encapsulates the state of the TensorFlow runtime, and runs TensorFlow operations. If a `tf.Graph` is like a `.py` file, a `tf.Session` is like the `python` executable. @@ -232,7 +232,7 @@ z = x + y The preceding three lines are a bit like a function in which we define two input parameters (`x` and `y`) and then an operation on them. We can evaluate this graph with multiple inputs by using the `feed_dict` argument of -the @{tf.Session.run$run method} to feed concrete values to the placeholders: +the `tf.Session.run` method to feed concrete values to the placeholders: ```python print(sess.run(z, feed_dict={x: 3, y: 4.5})) @@ -251,15 +251,15 @@ that placeholders throw an error if no value is fed to them. ## Datasets -Placeholders work for simple experiments, but @{tf.data$Datasets} are the +Placeholders work for simple experiments, but `tf.data` are the preferred method of streaming data into a model. To get a runnable `tf.Tensor` from a Dataset you must first convert it to a -@{tf.data.Iterator}, and then call the Iterator's -@{tf.data.Iterator.get_next$`get_next`} method. +`tf.data.Iterator`, and then call the Iterator's +`tf.data.Iterator.get_next` method. The simplest way to create an Iterator is with the -@{tf.data.Dataset.make_one_shot_iterator$`make_one_shot_iterator`} method. +`tf.data.Dataset.make_one_shot_iterator` method. For example, in the following code the `next_item` tensor will return a row from the `my_data` array on each `run` call: @@ -275,7 +275,7 @@ next_item = slices.make_one_shot_iterator().get_next() ``` Reaching the end of the data stream causes `Dataset` to throw an -@{tf.errors.OutOfRangeError$`OutOfRangeError`}. For example, the following code +`tf.errors.OutOfRangeError`. For example, the following code reads the `next_item` until there is no more data to read: ``` python @@ -303,12 +303,12 @@ while True: break ``` -For more details on Datasets and Iterators see: @{$programmers_guide/datasets}. +For more details on Datasets and Iterators see: @{$guide/datasets}. ## Layers A trainable model must modify the values in the graph to get new outputs with -the same input. @{tf.layers$Layers} are the preferred way to add trainable +the same input. `tf.layers` are the preferred way to add trainable parameters to a graph. Layers package together both the variables and the operations that act @@ -321,7 +321,7 @@ The connection weights and biases are managed by the layer object. ### Creating Layers -The following code creates a @{tf.layers.Dense$`Dense`} layer that takes a +The following code creates a `tf.layers.Dense` layer that takes a batch of input vectors, and produces a single output value for each. To apply a layer to an input, call the layer as if it were a function. For example: @@ -375,8 +375,8 @@ will generate a two-element output vector such as the following: ### Layer Function shortcuts -For each layer class (like @{tf.layers.Dense}) TensorFlow also supplies a -shortcut function (like @{tf.layers.dense}). The only difference is that the +For each layer class (like `tf.layers.Dense`) TensorFlow also supplies a +shortcut function (like `tf.layers.dense`). The only difference is that the shortcut function versions create and run the layer in a single call. For example, the following code is equivalent to the earlier version: @@ -390,17 +390,17 @@ sess.run(init) print(sess.run(y, {x: [[1, 2, 3], [4, 5, 6]]})) ``` -While convenient, this approach allows no access to the @{tf.layers.Layer} +While convenient, this approach allows no access to the `tf.layers.Layer` object. This makes introspection and debugging more difficult, and layer reuse impossible. ## Feature columns The easiest way to experiment with feature columns is using the -@{tf.feature_column.input_layer} function. This function only accepts +`tf.feature_column.input_layer` function. This function only accepts @{$feature_columns$dense columns} as inputs, so to view the result of a categorical column you must wrap it in an -@{tf.feature_column.indicator_column}. For example: +`tf.feature_column.indicator_column`. For example: ``` python features = { @@ -422,9 +422,9 @@ inputs = tf.feature_column.input_layer(features, columns) Running the `inputs` tensor will parse the `features` into a batch of vectors. Feature columns can have internal state, like layers, so they often need to be -initialized. Categorical columns use @{tf.contrib.lookup$lookup tables} +initialized. Categorical columns use `tf.contrib.lookup` internally and these require a separate initialization op, -@{tf.tables_initializer}. +`tf.tables_initializer`. ``` python var_init = tf.global_variables_initializer() @@ -501,7 +501,7 @@ To optimize a model, you first need to define the loss. We'll use the mean square error, a standard loss for regression problems. While you could do this manually with lower level math operations, -the @{tf.losses} module provides a set of common loss functions. You can use it +the `tf.losses` module provides a set of common loss functions. You can use it to calculate the mean square error as follows: ``` python @@ -520,10 +520,10 @@ This will produce a loss value, something like: TensorFlow provides [**optimizers**](https://developers.google.com/machine-learning/glossary/#optimizer) implementing standard optimization algorithms. These are implemented as -sub-classes of @{tf.train.Optimizer}. They incrementally change each +sub-classes of `tf.train.Optimizer`. They incrementally change each variable in order to minimize the loss. The simplest optimization algorithm is [**gradient descent**](https://developers.google.com/machine-learning/glossary/#gradient_descent), -implemented by @{tf.train.GradientDescentOptimizer}. It modifies each +implemented by `tf.train.GradientDescentOptimizer`. It modifies each variable according to the magnitude of the derivative of loss with respect to that variable. For example: diff --git a/tensorflow/docs_src/programmers_guide/premade_estimators.md b/tensorflow/docs_src/guide/premade_estimators.md similarity index 95% rename from tensorflow/docs_src/programmers_guide/premade_estimators.md rename to tensorflow/docs_src/guide/premade_estimators.md index f6dd75eacab1c99215ab918a0854b0a33d0d9cca..dc38f0c1d38d8ffed8abb820eadf7f093307d01b 100644 --- a/tensorflow/docs_src/programmers_guide/premade_estimators.md +++ b/tensorflow/docs_src/guide/premade_estimators.md @@ -78,10 +78,10 @@ provides a programming stack consisting of multiple API layers: We strongly recommend writing TensorFlow programs with the following APIs: -* @{$programmers_guide/estimators$Estimators}, which represent a complete model. +* @{$guide/estimators$Estimators}, which represent a complete model. The Estimator API provides methods to train the model, to judge the model's accuracy, and to generate predictions. -* @{$get_started/datasets_quickstart$Datasets}, which build a data input +* @{$guide/datasets_for_estimators}, which build a data input pipeline. The Dataset API has methods to load and manipulate data, and feed it into your model. The Dataset API meshes well with the Estimators API. @@ -173,11 +173,11 @@ example is an Iris Versicolor. An Estimator is TensorFlow's high-level representation of a complete model. It handles the details of initialization, logging, saving and restoring, and many other features so you can concentrate on your model. For more details see -@{$programmers_guide/estimators}. +@{$guide/estimators}. -An Estimator is any class derived from @{tf.estimator.Estimator}. TensorFlow +An Estimator is any class derived from `tf.estimator.Estimator`. TensorFlow provides a collection of -@{tf.estimator$pre-made Estimators} +`tf.estimator` (for example, `LinearRegressor`) to implement common ML algorithms. Beyond those, you may write your own @{$custom_estimators$custom Estimators}. @@ -200,7 +200,7 @@ Let's see how those tasks are implemented for Iris classification. You must create input functions to supply data for training, evaluating, and prediction. -An **input function** is a function that returns a @{tf.data.Dataset} object +An **input function** is a function that returns a `tf.data.Dataset` object which outputs the following two-element tuple: * [`features`](https://developers.google.com/machine-learning/glossary/#feature) - A Python dictionary in which: @@ -271,7 +271,7 @@ A [**feature column**](https://developers.google.com/machine-learning/glossary/# is an object describing how the model should use raw input data from the features dictionary. When you build an Estimator model, you pass it a list of feature columns that describes each of the features you want the model to use. -The @{tf.feature_column} module provides many options for representing data +The `tf.feature_column` module provides many options for representing data to the model. For Iris, the 4 raw features are numeric values, so we'll build a list of @@ -299,10 +299,10 @@ features, we can build the estimator. The Iris problem is a classic classification problem. Fortunately, TensorFlow provides several pre-made classifier Estimators, including: -* @{tf.estimator.DNNClassifier} for deep models that perform multi-class +* `tf.estimator.DNNClassifier` for deep models that perform multi-class classification. -* @{tf.estimator.DNNLinearCombinedClassifier} for wide & deep models. -* @{tf.estimator.LinearClassifier} for classifiers based on linear models. +* `tf.estimator.DNNLinearCombinedClassifier` for wide & deep models. +* `tf.estimator.LinearClassifier` for classifiers based on linear models. For the Iris problem, `tf.estimator.DNNClassifier` seems like the best choice. Here's how we instantiated this Estimator: @@ -424,9 +424,7 @@ Now that you've gotten started writing TensorFlow programs, consider the following material: * @{$checkpoints$Checkpoints} to learn how to save and restore models. -* @{$get_started/datasets_quickstart$Datasets} to learn more about importing - data into your - model. +* @{$guide/datasets_for_estimators} to learn more about importing + data into your model. * @{$custom_estimators$Creating Custom Estimators} to learn how to write your own Estimator, customized for a particular problem. - diff --git a/tensorflow/docs_src/programmers_guide/saved_model.md b/tensorflow/docs_src/guide/saved_model.md similarity index 93% rename from tensorflow/docs_src/programmers_guide/saved_model.md rename to tensorflow/docs_src/guide/saved_model.md index c6ef87c54a3bc37dbfc0553232a8e3d30f8ee2f6..c260da79668807eaefb3811fd475151571cb69bf 100644 --- a/tensorflow/docs_src/programmers_guide/saved_model.md +++ b/tensorflow/docs_src/guide/saved_model.md @@ -1,10 +1,9 @@ # Save and Restore -The @{tf.train.Saver} class provides methods to save and restore models. The -@{tf.saved_model.simple_save} function is an easy way to build a -@{tf.saved_model$saved model} suitable for serving. -[Estimators](@{$programmers_guide/estimators}) automatically save and restore -variables in the `model_dir`. +The `tf.train.Saver` class provides methods to save and restore models. The +`tf.saved_model.simple_save` function is an easy way to build a +`tf.saved_model` suitable for serving. [Estimators](./estimators) +automatically save and restore variables in the `model_dir`. ## Save and restore variables @@ -146,13 +145,13 @@ Notes: * If you only restore a subset of the model variables at the start of a session, you have to run an initialize op for the other variables. See - @{tf.variables_initializer} for more information. + `tf.variables_initializer` for more information. * To inspect the variables in a checkpoint, you can use the [`inspect_checkpoint`](https://www.tensorflow.org/code/tensorflow/python/tools/inspect_checkpoint.py) library, particularly the `print_tensors_in_checkpoint_file` function. -* By default, `Saver` uses the value of the @{tf.Variable.name} property +* By default, `Saver` uses the value of the `tf.Variable.name` property for each variable. However, when you create a `Saver` object, you may optionally choose names for the variables in the checkpoint files. @@ -197,15 +196,15 @@ Use `SavedModel` to save and load your model—variables, the graph, and the graph's metadata. This is a language-neutral, recoverable, hermetic serialization format that enables higher-level systems and tools to produce, consume, and transform TensorFlow models. TensorFlow provides several ways to -interact with `SavedModel`, including the @{tf.saved_model} APIs, -@{tf.estimator.Estimator}, and a command-line interface. +interact with `SavedModel`, including the `tf.saved_model` APIs, +`tf.estimator.Estimator`, and a command-line interface. ## Build and load a SavedModel ### Simple save -The easiest way to create a `SavedModel` is to use the @{tf.saved_model.simple_save} +The easiest way to create a `SavedModel` is to use the `tf.saved_model.simple_save` function: ```python @@ -219,14 +218,14 @@ This configures the `SavedModel` so it can be loaded by [TensorFlow serving](/serving/serving_basic) and supports the [Predict API](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/predict.proto). To access the classify, regress, or multi-inference APIs, use the manual -`SavedModel` builder APIs or an @{tf.estimator.Estimator}. +`SavedModel` builder APIs or an `tf.estimator.Estimator`. ### Manually build a SavedModel -If your use case isn't covered by @{tf.saved_model.simple_save}, use the manual -@{tf.saved_model.builder$builder APIs} to create a `SavedModel`. +If your use case isn't covered by `tf.saved_model.simple_save`, use the manual +`tf.saved_model.builder` to create a `SavedModel`. -The @{tf.saved_model.builder.SavedModelBuilder} class provides functionality to +The `tf.saved_model.builder.SavedModelBuilder` class provides functionality to save multiple `MetaGraphDef`s. A **MetaGraph** is a dataflow graph, plus its associated variables, assets, and signatures. A **`MetaGraphDef`** is the protocol buffer representation of a MetaGraph. A **signature** is @@ -273,16 +272,16 @@ builder.save() Following the guidance below gives you forward compatibility only if the set of Ops has not changed. -The @{tf.saved_model.builder.SavedModelBuilder$`SavedModelBuilder`} class allows +The `tf.saved_model.builder.SavedModelBuilder` class allows users to control whether default-valued attributes must be stripped from the @{$extend/tool_developers#nodes$`NodeDefs`} while adding a meta graph to the SavedModel bundle. Both -@{tf.saved_model.builder.SavedModelBuilder.add_meta_graph_and_variables$`SavedModelBuilder.add_meta_graph_and_variables`} -and @{tf.saved_model.builder.SavedModelBuilder.add_meta_graph$`SavedModelBuilder.add_meta_graph`} +`tf.saved_model.builder.SavedModelBuilder.add_meta_graph_and_variables` +and `tf.saved_model.builder.SavedModelBuilder.add_meta_graph` methods accept a Boolean flag `strip_default_attrs` that controls this behavior. -If `strip_default_attrs` is `False`, the exported @{tf.MetaGraphDef} will have -the default valued attributes in all its @{tf.NodeDef} instances. +If `strip_default_attrs` is `False`, the exported `tf.MetaGraphDef` will have +the default valued attributes in all its `tf.NodeDef` instances. This can break forward compatibility with a sequence of events such as the following: @@ -299,13 +298,13 @@ following: added attributes with defaults don't cause older model consumers to fail loading models regenerated with newer training binaries. -See [compatibility guidance](https://www.tensorflow.org/programmers_guide/version_compat) +See [compatibility guidance](./version_compat.md) for more information. ### Loading a SavedModel in Python The Python version of the SavedModel -@{tf.saved_model.loader$loader} +`tf.saved_model.loader` provides load and restore capability for a SavedModel. The `load` operation requires the following information: @@ -424,20 +423,20 @@ the model. This function has the following purposes: * To add any additional ops needed to convert data from the input format into the feature `Tensor`s expected by the model. -The function returns a @{tf.estimator.export.ServingInputReceiver} object, +The function returns a `tf.estimator.export.ServingInputReceiver` object, which packages the placeholders and the resulting feature `Tensor`s together. A typical pattern is that inference requests arrive in the form of serialized `tf.Example`s, so the `serving_input_receiver_fn()` creates a single string placeholder to receive them. The `serving_input_receiver_fn()` is then also -responsible for parsing the `tf.Example`s by adding a @{tf.parse_example} op to +responsible for parsing the `tf.Example`s by adding a `tf.parse_example` op to the graph. When writing such a `serving_input_receiver_fn()`, you must pass a parsing -specification to @{tf.parse_example} to tell the parser what feature names to +specification to `tf.parse_example` to tell the parser what feature names to expect and how to map them to `Tensor`s. A parsing specification takes the -form of a dict from feature names to @{tf.FixedLenFeature}, @{tf.VarLenFeature}, -and @{tf.SparseFeature}. Note this parsing specification should not include +form of a dict from feature names to `tf.FixedLenFeature`, `tf.VarLenFeature`, +and `tf.SparseFeature`. Note this parsing specification should not include any label or weight columns, since those will not be available at serving time—in contrast to a parsing specification used in the `input_fn()` at training time. @@ -458,7 +457,7 @@ def serving_input_receiver_fn(): return tf.estimator.export.ServingInputReceiver(features, receiver_tensors) ``` -The @{tf.estimator.export.build_parsing_serving_input_receiver_fn} utility +The `tf.estimator.export.build_parsing_serving_input_receiver_fn` utility function provides that input receiver for the common case. > Note: when training a model to be served using the Predict API with a local @@ -469,7 +468,7 @@ Even if you require no parsing or other input processing—that is, if the serving system will feed feature `Tensor`s directly—you must still provide a `serving_input_receiver_fn()` that creates placeholders for the feature `Tensor`s and passes them through. The -@{tf.estimator.export.build_raw_serving_input_receiver_fn} utility provides for +`tf.estimator.export.build_raw_serving_input_receiver_fn` utility provides for this. If these utilities do not meet your needs, you are free to write your own @@ -489,7 +488,7 @@ By contrast, the *output* portion of the signature is determined by the model. ### Specify the outputs of a custom model When writing a custom `model_fn`, you must populate the `export_outputs` element -of the @{tf.estimator.EstimatorSpec} return value. This is a dict of +of the `tf.estimator.EstimatorSpec` return value. This is a dict of `{name: output}` describing the output signatures to be exported and used during serving. @@ -499,9 +498,9 @@ is represented by an entry in this dict. In this case the `name` is a string of your choice that can be used to request a specific head at serving time. Each `output` value must be an `ExportOutput` object such as -@{tf.estimator.export.ClassificationOutput}, -@{tf.estimator.export.RegressionOutput}, or -@{tf.estimator.export.PredictOutput}. +`tf.estimator.export.ClassificationOutput`, +`tf.estimator.export.RegressionOutput`, or +`tf.estimator.export.PredictOutput`. These output types map straightforwardly to the [TensorFlow Serving APIs](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/prediction_service.proto), @@ -521,7 +520,7 @@ does not specify one. ### Perform the export To export your trained Estimator, call -@{tf.estimator.Estimator.export_savedmodel} with the export base path and +`tf.estimator.Estimator.export_savedmodel` with the export base path and the `serving_input_receiver_fn`. ```py @@ -794,11 +793,12 @@ Here's the syntax: ``` usage: saved_model_cli run [-h] --dir DIR --tag_set TAG_SET --signature_def SIGNATURE_DEF_KEY [--inputs INPUTS] - [--input_exprs INPUT_EXPRS] [--outdir OUTDIR] + [--input_exprs INPUT_EXPRS] + [--input_examples INPUT_EXAMPLES] [--outdir OUTDIR] [--overwrite] [--tf_debug] ``` -The `run` command provides the following two ways to pass inputs to the model: +The `run` command provides the following three ways to pass inputs to the model: * `--inputs` option enables you to pass numpy ndarray in files. * `--input_exprs` option enables you to pass Python expressions. @@ -847,7 +847,7 @@ dictionary is stored in the pickle file and the value corresponding to the *variable_name* will be used. -#### `--inputs_exprs` +#### `--input_exprs` To pass inputs through Python expressions, specify the `--input_exprs` option. This can be useful for when you don't have data @@ -869,7 +869,7 @@ example: (Note that the `numpy` module is already available to you as `np`.) -#### `--inputs_examples` +#### `--input_examples` To pass `tf.train.Example` as inputs, specify the `--input_examples` option. For each input key, it takes a list of dictionary, where each dictionary is an diff --git a/tensorflow/docs_src/programmers_guide/summaries_and_tensorboard.md b/tensorflow/docs_src/guide/summaries_and_tensorboard.md similarity index 98% rename from tensorflow/docs_src/programmers_guide/summaries_and_tensorboard.md rename to tensorflow/docs_src/guide/summaries_and_tensorboard.md index fadfa03e78349801d69e0045991a8fa9a0a59df9..6177c3393b203620842d88f9a18bfdde2239f369 100644 --- a/tensorflow/docs_src/programmers_guide/summaries_and_tensorboard.md +++ b/tensorflow/docs_src/guide/summaries_and_tensorboard.md @@ -41,7 +41,7 @@ data from, and decide which nodes you would like to annotate with For example, suppose you are training a convolutional neural network for recognizing MNIST digits. You'd like to record how the learning rate varies over time, and how the objective function is changing. Collect these by -attaching @{tf.summary.scalar} ops +attaching `tf.summary.scalar` ops to the nodes that output the learning rate and loss respectively. Then, give each `scalar_summary` a meaningful `tag`, like `'learning rate'` or `'loss function'`. @@ -49,7 +49,7 @@ function'`. Perhaps you'd also like to visualize the distributions of activations coming off a particular layer, or the distribution of gradients or weights. Collect this data by attaching -@{tf.summary.histogram} ops to +`tf.summary.histogram` ops to the gradient outputs and to the variable that holds your weights, respectively. For details on all of the summary operations available, check out the docs on @@ -60,13 +60,13 @@ depends on their output. And the summary nodes that we've just created are peripheral to your graph: none of the ops you are currently running depend on them. So, to generate summaries, we need to run all of these summary nodes. Managing them by hand would be tedious, so use -@{tf.summary.merge_all} +`tf.summary.merge_all` to combine them into a single op that generates all the summary data. Then, you can just run the merged summary op, which will generate a serialized `Summary` protobuf object with all of your summary data at a given step. Finally, to write this summary data to disk, pass the summary protobuf to a -@{tf.summary.FileWriter}. +`tf.summary.FileWriter`. The `FileWriter` takes a logdir in its constructor - this logdir is quite important, it's the directory where all of the events will be written out. diff --git a/tensorflow/docs_src/programmers_guide/tensorboard_histograms.md b/tensorflow/docs_src/guide/tensorboard_histograms.md similarity index 98% rename from tensorflow/docs_src/programmers_guide/tensorboard_histograms.md rename to tensorflow/docs_src/guide/tensorboard_histograms.md index 918deda190a930e504b7d1b213a2af611b2e919e..af8f2cadd1202a5d902eb43dc993710e7bf81aff 100644 --- a/tensorflow/docs_src/programmers_guide/tensorboard_histograms.md +++ b/tensorflow/docs_src/guide/tensorboard_histograms.md @@ -13,8 +13,8 @@ TensorFlow has an op which is perfect for this purpose. As is usually the case with TensorBoard, we will ingest data using a summary op; in this case, ['tf.summary.histogram'](https://www.tensorflow.org/api_docs/python/tf/summary/histogram). -For a primer on how summaries work, please see the general -[TensorBoard tutorial](https://www.tensorflow.org/get_started/summaries_and_tensorboard). +For a primer on how summaries work, please see the +[TensorBoard guide](./summaries_and_tensorboard.md). Here is a code snippet that will generate some histogram summaries containing normally distributed data, where the mean of the distribution increases over diff --git a/tensorflow/docs_src/programmers_guide/tensors.md b/tensorflow/docs_src/guide/tensors.md similarity index 98% rename from tensorflow/docs_src/programmers_guide/tensors.md rename to tensorflow/docs_src/guide/tensors.md index 1248c3cabe23c8d5f200fc1bf46e60851ba532a6..6b5a110a1c3e59b2b9d18c8c43d56c4323bdbf55 100644 --- a/tensorflow/docs_src/programmers_guide/tensors.md +++ b/tensorflow/docs_src/guide/tensors.md @@ -26,7 +26,7 @@ some cases it's only possible to find the shape of a tensor at graph execution time. Some types of tensors are special, and these will be covered in other -units of the Programmer's guide. The main ones are: +units of the TensorFlow guide. The main ones are: * `tf.Variable` * `tf.constant` @@ -176,7 +176,7 @@ Rank | Shape | Dimension number | Example n | [D0, D1, ... Dn-1] | n-D | A tensor with shape [D0, D1, ... Dn-1]. Shapes can be represented via Python lists / tuples of ints, or with the -@{tf.TensorShape}. +`tf.TensorShape`. ### Getting a `tf.Tensor` object's shape @@ -230,7 +230,7 @@ yet_another = tf.reshape(matrixAlt, [13, 2, -1]) # ERROR! ## Data types In addition to dimensionality, Tensors have a data type. Refer to the -`tf.DataType` page in the programmer's guide for a full list of the data types. +`tf.DType` page for a complete list of the data types. It is not possible to have a `tf.Tensor` with more than one data type. It is possible, however, to serialize arbitrary data structures as `string`s and store diff --git a/tensorflow/docs_src/programmers_guide/using_gpu.md b/tensorflow/docs_src/guide/using_gpu.md similarity index 99% rename from tensorflow/docs_src/programmers_guide/using_gpu.md rename to tensorflow/docs_src/guide/using_gpu.md index c429ca4750753278e4736650a08fd0c71e0d9fad..c0218fd12e1f5d69d667e50472fa75ed394e9318 100644 --- a/tensorflow/docs_src/programmers_guide/using_gpu.md +++ b/tensorflow/docs_src/guide/using_gpu.md @@ -143,7 +143,7 @@ If the device you have specified does not exist, you will get ``` InvalidArgumentError: Invalid argument: Cannot assign a device to node 'b': Could not satisfy explicit device specification '/device:GPU:2' - [[Node: b = Const[dtype=DT_FLOAT, value=Tensor, _device="/device:GPU:2"]()]] ``` diff --git a/tensorflow/docs_src/programmers_guide/using_tpu.md b/tensorflow/docs_src/guide/using_tpu.md similarity index 91% rename from tensorflow/docs_src/programmers_guide/using_tpu.md rename to tensorflow/docs_src/guide/using_tpu.md index 44aabf05571bb7f325a5d642f06362e0088607d2..90a663b75ed87e724009897045abac7bb338e911 100644 --- a/tensorflow/docs_src/programmers_guide/using_tpu.md +++ b/tensorflow/docs_src/guide/using_tpu.md @@ -17,9 +17,9 @@ This doc is aimed at users who: ## TPUEstimator -@{tf.estimator.Estimator$Estimators} are TensorFlow's model-level abstraction. +`tf.estimator.Estimator` are TensorFlow's model-level abstraction. Standard `Estimators` can drive models on CPU and GPUs. You must use -@{tf.contrib.tpu.TPUEstimator} to drive a model on TPUs. +`tf.contrib.tpu.TPUEstimator` to drive a model on TPUs. Refer to TensorFlow's Getting Started section for an introduction to the basics of using a @{$premade_estimators$pre-made `Estimator`}, and @@ -44,10 +44,10 @@ my_estimator = tf.estimator.Estimator( model_fn=my_model_fn) ``` -The changes required to use a @{tf.contrib.tpu.TPUEstimator} on your local +The changes required to use a `tf.contrib.tpu.TPUEstimator` on your local machine are relatively minor. The constructor requires two additional arguments. You should set the `use_tpu` argument to `False`, and pass a -@{tf.contrib.tpu.RunConfig} as the `config` argument, as shown below: +`tf.contrib.tpu.RunConfig` as the `config` argument, as shown below: ``` python my_tpu_estimator = tf.contrib.tpu.TPUEstimator( @@ -117,7 +117,7 @@ my_tpu_run_config = tf.contrib.tpu.RunConfig( ) ``` -Then you must pass the @{tf.contrib.tpu.RunConfig} to the constructor: +Then you must pass the `tf.contrib.tpu.RunConfig` to the constructor: ``` python my_tpu_estimator = tf.contrib.tpu.TPUEstimator( @@ -137,7 +137,7 @@ training locally to training on a cloud TPU you would need to: ## Optimizer When training on a cloud TPU you **must** wrap the optimizer in a -@{tf.contrib.tpu.CrossShardOptimizer}, which uses an `allreduce` to aggregate +`tf.contrib.tpu.CrossShardOptimizer`, which uses an `allreduce` to aggregate gradients and broadcast the result to each shard (each TPU core). The `CrossShardOptimizer` is not compatible with local training. So, to have @@ -171,7 +171,7 @@ This section details the changes you must make to the model function During regular usage TensorFlow attempts to determine the shapes of each `tf.Tensor` during graph construction. During execution any unknown shape dimensions are determined dynamically, -see @{$programmers_guide/tensors#shape$Tensor Shapes} for more details. +see @{$guide/tensors#shape$Tensor Shapes} for more details. To run on Cloud TPUs TensorFlow models are compiled using @{$xla$XLA}. XLA uses a similar system for determining shapes at compile time. XLA requires @@ -195,12 +195,12 @@ TPU. Build your evaluation metrics dictionary in a stand-alone `metric_fn`. - + Evaluation metrics are an essential part of training a model. These are fully supported on Cloud TPUs, but with a slightly different syntax. -A standard @{tf.metrics} returns two tensors. The first returns the running +A standard `tf.metrics` returns two tensors. The first returns the running average of the metric value, while the second updates the running average and returns the value for this batch: @@ -242,15 +242,15 @@ An `Estimator`'s `model_fn` must return an `EstimatorSpec`. An `EstimatorSpec` is a simple structure of named fields containing all the `tf.Tensors` of the model that the `Estimator` may need to interact with. -`TPUEstimators` use a @{tf.contrib.tpu.TPUEstimatorSpec}. There are a few -differences between it and a standard @{tf.estimator.EstimatorSpec}: +`TPUEstimators` use a `tf.contrib.tpu.TPUEstimatorSpec`. There are a few +differences between it and a standard `tf.estimator.EstimatorSpec`: * The `eval_metric_ops` must be wrapped into a `metrics_fn`, this field is renamed `eval_metrics` ([see above](#metrics)). -* The @{tf.train.SessionRunHook$hooks} are unsupported, so these fields are +* The `tf.train.SessionRunHook` are unsupported, so these fields are omitted. -* The @{tf.train.Scaffold$`scaffold`}, if used, must also be wrapped in a +* The `tf.train.Scaffold`, if used, must also be wrapped in a function. This field is renamed to `scaffold_fn`. `Scaffold` and `Hooks` are for advanced usage, and can typically be omitted. @@ -304,7 +304,7 @@ In many cases the batch size is the only unknown dimension. A typical input pipeline, using `tf.data`, will usually produce batches of a fixed size. The last batch of a finite `Dataset`, however, is typically smaller, containing just the remaining elements. Since a `Dataset` does not know its own -length or finiteness, the standard @{tf.data.Dataset.batch$`batch`} method +length or finiteness, the standard `tf.data.Dataset.batch` method cannot determine if all batches will have a fixed size batch on its own: ``` @@ -317,7 +317,7 @@ cannot determine if all batches will have a fixed size batch on its own: ``` The most straightforward fix is to -@{tf.data.Dataset.apply$apply} @{tf.contrib.data.batch_and_drop_remainder} +`tf.data.Dataset.apply` `tf.contrib.data.batch_and_drop_remainder` as follows: ``` @@ -346,19 +346,19 @@ TPU, as it is impossible to use the Cloud TPU's unless you can feed it data quickly enough. See @{$datasets_performance} for details on dataset performance. For all but the simplest experimentation (using -@{tf.data.Dataset.from_tensor_slices} or other in-graph data) you will need to +`tf.data.Dataset.from_tensor_slices` or other in-graph data) you will need to store all data files read by the `TPUEstimator`'s `Dataset` in Google Cloud Storage Buckets. For most use-cases, we recommend converting your data into `TFRecord` -format and using a @{tf.data.TFRecordDataset} to read it. This, however, is not +format and using a `tf.data.TFRecordDataset` to read it. This, however, is not a hard requirement and you can use other dataset readers (`FixedLengthRecordDataset` or `TextLineDataset`) if you prefer. Small datasets can be loaded entirely into memory using -@{tf.data.Dataset.cache}. +`tf.data.Dataset.cache`. Regardless of the data format used, it is strongly recommended that you @{$performance_guide#use_large_files$use large files}, on the order of diff --git a/tensorflow/docs_src/programmers_guide/variables.md b/tensorflow/docs_src/guide/variables.md similarity index 98% rename from tensorflow/docs_src/programmers_guide/variables.md rename to tensorflow/docs_src/guide/variables.md index cd8c4b5b9a026f01af4957ade0e132477b0066a5..5d5d73394c6f2529c9af5513e2e8d661a1f8a147 100644 --- a/tensorflow/docs_src/programmers_guide/variables.md +++ b/tensorflow/docs_src/guide/variables.md @@ -119,7 +119,7 @@ It is particularly important for variables to be in the correct device in distributed settings. Accidentally putting variables on workers instead of parameter servers, for example, can severely slow down training or, in the worst case, let each worker blithely forge ahead with its own independent copy of each -variable. For this reason we provide @{tf.train.replica_device_setter}, which +variable. For this reason we provide `tf.train.replica_device_setter`, which can automatically place variables in parameter servers. For example: ``` python @@ -211,7 +211,7 @@ sess.run(assignment) # or assignment.op.run(), or assignment.eval() Most TensorFlow optimizers have specialized ops that efficiently update the values of variables according to some gradient descent-like algorithm. See -@{tf.train.Optimizer} for an explanation of how to use optimizers. +`tf.train.Optimizer` for an explanation of how to use optimizers. Because variables are mutable it's sometimes useful to know what version of a variable's value is being used at any point in time. To force a re-read of the diff --git a/tensorflow/docs_src/programmers_guide/version_compat.md b/tensorflow/docs_src/guide/version_compat.md similarity index 96% rename from tensorflow/docs_src/programmers_guide/version_compat.md rename to tensorflow/docs_src/guide/version_compat.md index 72e427c5f8f0f6581d528f4ead18699736eafd04..29ac066e6f2b94fa456a3af2c851a5e87be765da 100644 --- a/tensorflow/docs_src/programmers_guide/version_compat.md +++ b/tensorflow/docs_src/guide/version_compat.md @@ -66,7 +66,7 @@ patch versions. The public APIs consist of Some API functions are explicitly marked as "experimental" and can change in backward incompatible ways between minor releases. These include: -* **Experimental APIs**: The @{tf.contrib} module and its submodules in Python +* **Experimental APIs**: The `tf.contrib` module and its submodules in Python and any functions in the C API or fields in protocol buffers that are explicitly commented as being experimental. In particular, any field in a protocol buffer which is called "experimental" and all its fields and @@ -79,6 +79,7 @@ backward incompatible ways between minor releases. These include: [`tensorflow/cc`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/cc)). - [Java](../api_docs/java/reference/org/tensorflow/package-summary), - [Go](https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go) + - [JavaScript](https://js.tensorflow.org) * **Details of composite ops:** Many public functions in Python expand to several primitive ops in the graph, and these details will be part of any @@ -252,13 +253,13 @@ ops has not changed: 1. If forward compatibility is desired, set `strip_default_attrs` to `True` while exporting the model using either the - @{tf.saved_model.builder.SavedModelBuilder.add_meta_graph_and_variables$`add_meta_graph_and_variables`} - and @{tf.saved_model.builder.SavedModelBuilder.add_meta_graph$`add_meta_graph`} + `tf.saved_model.builder.SavedModelBuilder.add_meta_graph_and_variables` + and `tf.saved_model.builder.SavedModelBuilder.add_meta_graph` methods of the `SavedModelBuilder` class, or - @{tf.estimator.Estimator.export_savedmodel$`Estimator.export_savedmodel`} + `tf.estimator.Estimator.export_savedmodel` 2. This strips off the default valued attributes at the time of producing/exporting the models. This makes sure that the exported - @{tf.MetaGraphDef} does not contain the new op-attribute when the default + `tf.MetaGraphDef` does not contain the new op-attribute when the default value is used. 3. Having this control could allow out-of-date consumers (for example, serving binaries that lag behind training binaries) to continue loading the models @@ -301,8 +302,10 @@ existing producer scripts will not suddenly use the new functionality. #### Change an op's functionality 1. Add a new similar op named `SomethingV2` or similar and go through the - process of adding it and switching existing Python wrappers to use it, which - may take three weeks if forward compatibility is desired. + process of adding it and switching existing Python wrappers to use it. + To ensure forward compatibility use the checks suggested in + [compat.py](https://www.tensorflow.org/code/tensorflow/python/compat/compat.py) + when changing the Python wrappers. 2. Remove the old op (Can only take place with a major version change due to backward compatibility). 3. Increase `min_consumer` to rule out consumers with the old op, add back the diff --git a/tensorflow/docs_src/install/index.md b/tensorflow/docs_src/install/index.md index c2e5a991d459f52be7c08a6da3bcad0c57e58934..55481cc4001429a8257b846a2161f088ce2d9c10 100644 --- a/tensorflow/docs_src/install/index.md +++ b/tensorflow/docs_src/install/index.md @@ -1,36 +1,39 @@ -# Installing TensorFlow +# Install TensorFlow -We've built and tested TensorFlow on the following 64-bit laptop/desktop -operating systems: +Note: Run the [TensorFlow tutorials](../tutorials) in a pre-configured +[Colab notebook environment](https://colab.research.google.com/notebooks/welcome.ipynb){: .external}, +without installation. + +TensorFlow is built and tested on the following 64-bit operating systems: * macOS 10.12.6 (Sierra) or later. * Ubuntu 16.04 or later * Windows 7 or later. * Raspbian 9.0 or later. -Although you might be able to install TensorFlow on other laptop or desktop -systems, we only support (and only fix issues in) the preceding configurations. +While TensorFlow may work on other systems, we only support—and fix issues in—the +systems listed above. The following guides explain how to install a version of TensorFlow that enables you to write applications in Python: - * @{$install_linux$Installing TensorFlow on Ubuntu} - * @{$install_mac$Installing TensorFlow on macOS} - * @{$install_windows$Installing TensorFlow on Windows} - * @{$install_raspbian$Installing TensorFlow on a Raspberry Pi} - * @{$install_sources$Installing TensorFlow from Sources} + * @{$install_linux$Install TensorFlow on Ubuntu} + * @{$install_mac$Install TensorFlow on macOS} + * @{$install_windows$Install TensorFlow on Windows} + * @{$install_raspbian$Install TensorFlow on a Raspberry Pi} + * @{$install_sources$Install TensorFlow from source code} Many aspects of the Python TensorFlow API changed from version 0.n to 1.0. The following guide explains how to migrate older TensorFlow applications to Version 1.0: - * @{$migration$Transitioning to TensorFlow 1.0} + * @{$migration$Transition to TensorFlow 1.0} The following guides explain how to install TensorFlow libraries for use in other programming languages. These APIs are aimed at deploying TensorFlow models in applications and are not as extensive as the Python APIs. - * @{$install_java$Installing TensorFlow for Java} - * @{$install_c$Installing TensorFlow for C} - * @{$install_go$Installing TensorFlow for Go} + * @{$install_java$Install TensorFlow for Java} + * @{$install_c$Install TensorFlow for C} + * @{$install_go$Install TensorFlow for Go} diff --git a/tensorflow/docs_src/install/install_c.md b/tensorflow/docs_src/install/install_c.md index 29018487450951661f2b6aaef59e2fe59ab052f0..4a63f11fcac03a1b56f900fc29b1950bdba2e4bf 100644 --- a/tensorflow/docs_src/install/install_c.md +++ b/tensorflow/docs_src/install/install_c.md @@ -1,4 +1,4 @@ -# Installing TensorFlow for C +# Install TensorFlow for C TensorFlow provides a C API defined in [`c_api.h`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.h), @@ -38,7 +38,7 @@ enable TensorFlow for C: OS="linux" # Change to "darwin" for macOS TARGET_DIRECTORY="/usr/local" curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.9.0-rc0.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.10.0.tar.gz" | sudo tar -C $TARGET_DIRECTORY -xz The `tar` command extracts the TensorFlow C library into the `lib` diff --git a/tensorflow/docs_src/install/install_go.md b/tensorflow/docs_src/install/install_go.md index 55bc0f64e799ecb7115cc656c8a08ec1ce2a6108..f0f8436777ea17885b6ccd2b0f75fbb9e900d15f 100644 --- a/tensorflow/docs_src/install/install_go.md +++ b/tensorflow/docs_src/install/install_go.md @@ -1,4 +1,4 @@ -# Installing TensorFlow for Go +# Install TensorFlow for Go TensorFlow provides APIs for use in Go programs. These APIs are particularly well-suited to loading models created in Python and executing them within @@ -6,7 +6,7 @@ a Go application. This guide explains how to install and set up the [TensorFlow Go package](https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go). Warning: The TensorFlow Go API is *not* covered by the TensorFlow -[API stability guarantees](https://www.tensorflow.org/programmers_guide/version_semantics). +[API stability guarantees](../guide/version_compat.md). ## Supported Platforms @@ -38,7 +38,7 @@ steps to install this library and enable TensorFlow for Go: TF_TYPE="cpu" # Change to "gpu" for GPU support TARGET_DIRECTORY='/usr/local' curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.9.0-rc0.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.10.0.tar.gz" | sudo tar -C $TARGET_DIRECTORY -xz The `tar` command extracts the TensorFlow C library into the `lib` diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md index 637231da1252097a9a143edea23f3248c7cf2eb6..c131a2ea766625a57af6df60ad425cc46bf7cad2 100644 --- a/tensorflow/docs_src/install/install_java.md +++ b/tensorflow/docs_src/install/install_java.md @@ -1,4 +1,4 @@ -# Installing TensorFlow for Java +# Install TensorFlow for Java TensorFlow provides APIs for use in Java programs. These APIs are particularly well-suited to loading models created in Python and executing them within a @@ -7,7 +7,7 @@ Java application. This guide explains how to install and use it in a Java application. Warning: The TensorFlow Java API is *not* covered by the TensorFlow -[API stability guarantees](https://www.tensorflow.org/programmers_guide/version_semantics). +[API stability guarantees](../guide/version_semantics.md). ## Supported Platforms @@ -36,7 +36,7 @@ following to the project's `pom.xml` to use the TensorFlow Java APIs: org.tensorflow tensorflow - 1.9.0-rc0 + 1.10.0 ``` @@ -65,7 +65,7 @@ As an example, these steps will create a Maven project that uses TensorFlow: org.tensorflow tensorflow - 1.9.0-rc0 + 1.10.0 @@ -124,12 +124,12 @@ instead: org.tensorflow libtensorflow - 1.9.0-rc0 + 1.10.0 org.tensorflow libtensorflow_jni_gpu - 1.9.0-rc0 + 1.10.0 ``` @@ -148,7 +148,7 @@ refer to the simpler instructions above instead. Take the following steps to install TensorFlow for Java on Linux or macOS: 1. Download - [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.9.0-rc0.jar), + [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.10.0.jar), which is the TensorFlow Java Archive (JAR). 2. Decide whether you will run TensorFlow for Java on CPU(s) only or with @@ -167,7 +167,7 @@ Take the following steps to install TensorFlow for Java on Linux or macOS: OS=$(uname -s | tr '[:upper:]' '[:lower:]') mkdir -p ./jni curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.9.0-rc0.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.10.0.tar.gz" | tar -xz -C ./jni ### Install on Windows @@ -175,10 +175,10 @@ Take the following steps to install TensorFlow for Java on Linux or macOS: Take the following steps to install TensorFlow for Java on Windows: 1. Download - [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.9.0-rc0.jar), + [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.10.0.jar), which is the TensorFlow Java Archive (JAR). 2. Download the following Java Native Interface (JNI) file appropriate for - [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.9.0-rc0.zip). + [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.10.0.zip). 3. Extract this .zip file. __Note__: The native library (`tensorflow_jni.dll`) requires `msvcp140.dll` at runtime, which is included in the [Visual C++ 2015 Redistributable](https://www.microsoft.com/en-us/download/details.aspx?id=48145) package. @@ -227,7 +227,7 @@ must be part of your `classpath`. For example, you can include the downloaded `.jar` in your `classpath` by using the `-cp` compilation flag as follows: -
javac -cp libtensorflow-1.9.0-rc0.jar HelloTF.java
+
javac -cp libtensorflow-1.10.0.jar HelloTF.java
### Running @@ -241,11 +241,11 @@ two files are available to the JVM: For example, the following command line executes the `HelloTF` program on Linux and macOS X: -
java -cp libtensorflow-1.9.0-rc0.jar:. -Djava.library.path=./jni HelloTF
+
java -cp libtensorflow-1.10.0.jar:. -Djava.library.path=./jni HelloTF
And the following command line executes the `HelloTF` program on Windows: -
java -cp libtensorflow-1.9.0-rc0.jar;. -Djava.library.path=jni HelloTF
+
java -cp libtensorflow-1.10.0.jar;. -Djava.library.path=jni HelloTF
If the program prints Hello from version, you've successfully installed TensorFlow for Java and are ready to use the API. If the program diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md index c8d706cf3c232d1ea91265bd7ad38d5227c440f0..0febdee99fd267947858cea2b2a3fcbfc59f986d 100644 --- a/tensorflow/docs_src/install/install_linux.md +++ b/tensorflow/docs_src/install/install_linux.md @@ -1,38 +1,38 @@ -# Installing TensorFlow on Ubuntu +# Install TensorFlow on Ubuntu This guide explains how to install TensorFlow on Ubuntu Linux. While these -instructions may work on other Linux variants, they are tested and supported with -the following system requirements: - -* 64-bit desktops or laptops -* Ubuntu 16.04 or higher +instructions may work on other Linux variants, they are tested and supported +with the following system requirements: +* 64-bit desktops or laptops +* Ubuntu 16.04 or higher ## Choose which TensorFlow to install The following TensorFlow variants are available for installation: -* __TensorFlow with CPU support only__. If your system does not have a - NVIDIA® GPU, you must install this version. This version of TensorFlow is - usually easier to install, so even if you have an NVIDIA GPU, we recommend - installing this version first. -* __TensorFlow with GPU support__. TensorFlow programs usually run much faster on - a GPU instead of a CPU. If you run performance-critical applications and your - system has an NVIDIA® GPU that meets the prerequisites, you should install - this version. See [TensorFlow GPU support](#NVIDIARequirements) for details. - +* __TensorFlow with CPU support only__. If your system does not have a + NVIDIA® GPU, you must install this version. This version of TensorFlow + is usually easier to install, so even if you have an NVIDIA GPU, we + recommend installing this version first. +* __TensorFlow with GPU support__. TensorFlow programs usually run much faster + on a GPU instead of a CPU. If you run performance-critical applications and + your system has an NVIDIA® GPU that meets the prerequisites, you should + install this version. See [TensorFlow GPU support](#NVIDIARequirements) for + details. ## How to install TensorFlow There are a few options to install TensorFlow on your machine: -* [Use pip in a virtual environment](#InstallingVirtualenv) *(recommended)* -* [Use pip in your system environment](#InstallingNativePip) -* [Configure a Docker container](#InstallingDocker) -* [Use pip in Anaconda](#InstallingAnaconda) -* [Install TensorFlow from source](/install/install_sources) +* [Use pip in a virtual environment](#InstallingVirtualenv) *(recommended)* +* [Use pip in your system environment](#InstallingNativePip) +* [Configure a Docker container](#InstallingDocker) +* [Use pip in Anaconda](#InstallingAnaconda) +* [Install TensorFlow from source](/install/install_sources) + ### Use `pip` in a virtual environment Key Point: Using a virtual environment is the recommended install method. @@ -41,8 +41,8 @@ The [Virtualenv](https://virtualenv.pypa.io/en/stable/) tool creates virtual Python environments that are isolated from other Python development on the same machine. In this scenario, you install TensorFlow and its dependencies within a virtual environment that is available when *activated*. Virtualenv provides a -reliable way to install and run TensorFlow while avoiding conflicts with the rest -of the system. +reliable way to install and run TensorFlow while avoiding conflicts with the +rest of the system. ##### 1. Install Python, `pip`, and `virtualenv`. @@ -62,10 +62,10 @@ To install these packages on Ubuntu: We *recommend* using `pip` version 8.1 or higher. If using a release before -version 8.1, upgrade `pip`: +version 8.1, upgrade `pip`:
-  sudo pip install -U pip
+  pip install --upgrade pip
 
If not using Ubuntu and [setuptools](https://pypi.org/project/setuptools/) is @@ -102,7 +102,7 @@ When the Virtualenv is activated, the shell prompt displays as `(venv) $`. Within the active virtual environment, upgrade `pip`:
-(venv)$ pip install -U pip
+(venv)$ pip install --upgrade pip
 
You can install other Python packages within the virtual environment without @@ -112,15 +112,15 @@ affecting packages outside the `virtualenv`. Choose one of the available TensorFlow packages for installation: -* `tensorflow` —Current release for CPU -* `tensorflow-gpu` —Current release with GPU support -* `tf-nightly` —Nightly build for CPU -* `tf-nightly-gpu` —Nightly build with GPU support +* `tensorflow` —Current release for CPU +* `tensorflow-gpu` —Current release with GPU support +* `tf-nightly` —Nightly build for CPU +* `tf-nightly-gpu` —Nightly build with GPU support Within an active Virtualenv environment, use `pip` to install the package:
-  pip install -U tensorflow
+  pip install --upgrade tensorflow
 
Use `pip list` to show the packages installed in the virtual environment. @@ -160,14 +160,14 @@ To uninstall TensorFlow, remove the Virtualenv directory you created in step 2: rm -r ~/tensorflow/venv - + ### Use `pip` in your system environment Use `pip` to install the TensorFlow package directly on your system without using a container or virtual environment for isolation. This method is -recommended for system administrators that want a TensorFlow installation that is -available to everyone on a multi-user system. +recommended for system administrators that want a TensorFlow installation that +is available to everyone on a multi-user system. Since a system install is not isolated, it could interfere with other Python-based installations. But if you understand `pip` and your Python @@ -195,10 +195,10 @@ To install these packages on Ubuntu: We *recommend* using `pip` version 8.1 or higher. If using a release before -version 8.1, upgrade `pip`: +version 8.1, upgrade `pip`:
-  sudo pip install -U pip
+  pip install --upgrade pip
 
If not using Ubuntu and [setuptools](https://pypi.org/project/setuptools/) is @@ -212,16 +212,16 @@ installed, use `easy_install` to install `pip`: Choose one of the available TensorFlow packages for installation: -* `tensorflow` —Current release for CPU -* `tensorflow-gpu` —Current release with GPU support -* `tf-nightly` —Nightly build for CPU -* `tf-nightly-gpu` —Nightly build with GPU support +* `tensorflow` —Current release for CPU +* `tensorflow-gpu` —Current release with GPU support +* `tf-nightly` —Nightly build for CPU +* `tf-nightly-gpu` —Nightly build with GPU support And use `pip` to install the package for Python 2 or 3:
-  sudo pip install -U tensorflow   # Python 2.7
-  sudo pip3 install -U tensorflow  # Python 3.n
+  pip install --upgrade --user tensorflow   # Python 2.7
+  pip3 install --upgrade --user tensorflow  # Python 3.n
 
Use `pip list` to show the packages installed on the system. @@ -239,8 +239,8 @@ If the above steps failed, try installing the TensorFlow binary using the remote URL of the `pip` package:
-  sudo pip install --upgrade remote-pkg-URL   # Python 2.7
-  sudo pip3 install --upgrade remote-pkg-URL  # Python 3.n
+  pip install --user --upgrade remote-pkg-URL   # Python 2.7
+  pip3 install --user --upgrade remote-pkg-URL  # Python 3.n
 
The remote-pkg-URL depends on the operating system, Python version, @@ -255,42 +255,41 @@ encounter problems. To uninstall TensorFlow on your system, use one of following commands:
-  sudo pip uninstall tensorflow   # for Python 2.7
-  sudo pip3 uninstall tensorflow  # for Python 3.n
+  pip uninstall tensorflow   # for Python 2.7
+  pip3 uninstall tensorflow  # for Python 3.n
 
+ ### Configure a Docker container -Docker completely isolates the TensorFlow installation -from pre-existing packages on your machine. The Docker container contains -TensorFlow and all its dependencies. Note that the Docker image can be quite -large (hundreds of MBs). You might choose the Docker installation if you are -incorporating TensorFlow into a larger application architecture that already -uses Docker. +Docker completely isolates the TensorFlow installation from pre-existing +packages on your machine. The Docker container contains TensorFlow and all its +dependencies. Note that the Docker image can be quite large (hundreds of MBs). +You might choose the Docker installation if you are incorporating TensorFlow +into a larger application architecture that already uses Docker. Take the following steps to install TensorFlow through Docker: - 1. Install Docker on your machine as described in the - [Docker documentation](http://docs.docker.com/engine/installation/). - 2. Optionally, create a Linux group called docker to allow - launching containers without sudo as described in the - [Docker documentation](https://docs.docker.com/engine/installation/linux/linux-postinstall/). - (If you don't do this step, you'll have to use sudo each time - you invoke Docker.) - 3. To install a version of TensorFlow that supports GPUs, you must first - install [nvidia-docker](https://github.com/NVIDIA/nvidia-docker), which - is stored in github. - 4. Launch a Docker container that contains one of the - [TensorFlow binary images](https://hub.docker.com/r/tensorflow/tensorflow/tags/). +1. Install Docker on your machine as described in the + [Docker documentation](http://docs.docker.com/engine/installation/). +2. Optionally, create a Linux group called docker to allow + launching containers without sudo as described in the + [Docker documentation](https://docs.docker.com/engine/installation/linux/linux-postinstall/). + (If you don't do this step, you'll have to use sudo each time you invoke + Docker.) +3. To install a version of TensorFlow that supports GPUs, you must first + install [nvidia-docker](https://github.com/NVIDIA/nvidia-docker), which is + stored in github. +4. Launch a Docker container that contains one of the + [TensorFlow binary images](https://hub.docker.com/r/tensorflow/tensorflow/tags/). The remainder of this section explains how to launch a Docker container. - #### CPU-only -To launch a Docker container with CPU-only support (that is, without -GPU support), enter a command of the following format: +To launch a Docker container with CPU-only support (that is, without GPU +support), enter a command of the following format:
 $ docker run -it -p hostPort:containerPort TensorFlowCPUImage
@@ -298,29 +297,31 @@ $ docker run -it -p hostPort:containerPort TensorFlowCPUImage
 
 where:
 
-  * -p hostPort:containerPort is optional.
-    If you plan to run TensorFlow programs from the shell, omit this option.
-    If you plan to run TensorFlow programs as Jupyter notebooks, set both
-    hostPort and containerPort
-    to 8888.  If you'd like to run TensorBoard inside the container,
-    add a second `-p` flag, setting both hostPort and containerPort
-    to 6006.
-  * TensorFlowCPUImage is required. It identifies the Docker
+*   -p hostPort:containerPort is optional. If you plan to run
+    TensorFlow programs from the shell, omit this option. If you plan to run
+    TensorFlow programs as Jupyter notebooks, set both hostPort
+    and containerPort to 8888. If you'd like to run
+    TensorBoard inside the container, add a second `-p` flag, setting both
+    hostPort and containerPort to 6006.
+*   TensorFlowCPUImage is required. It identifies the Docker
     container. Specify one of the following values:
-    * tensorflow/tensorflow, which is the TensorFlow CPU binary image.
-    * tensorflow/tensorflow:latest-devel, which is the latest
-      TensorFlow CPU Binary image plus source code.
-    * tensorflow/tensorflow:version, which is the
-      specified version (for example, 1.1.0rc1) of TensorFlow CPU binary image.
-    * tensorflow/tensorflow:version-devel, which is
-      the specified version (for example, 1.1.0rc1) of the TensorFlow GPU
-      binary image plus source code.
+
+    *   tensorflow/tensorflow, which is the TensorFlow CPU binary
+        image.
+    *   tensorflow/tensorflow:latest-devel, which is the latest
+        TensorFlow CPU Binary image plus source code.
+    *   tensorflow/tensorflow:version, which is the specified
+        version (for example, 1.1.0rc1) of TensorFlow CPU binary image.
+    *   tensorflow/tensorflow:version-devel, which is the
+        specified version (for example, 1.1.0rc1) of the TensorFlow GPU binary
+        image plus source code.
 
     TensorFlow images are available at
     [dockerhub](https://hub.docker.com/r/tensorflow/tensorflow/).
 
-For example, the following command launches the latest TensorFlow CPU binary image
-in a Docker container from which you can run TensorFlow programs in a shell:
+For example, the following command launches the latest TensorFlow CPU binary
+image in a Docker container from which you can run TensorFlow programs in a
+shell:
 
 
 $ docker run -it tensorflow/tensorflow bash
@@ -336,10 +337,11 @@ $ docker run -it -p 8888:8888 tensorflow/tensorflow
 
 Docker will download the TensorFlow binary image the first time you launch it.
 
-
 #### GPU support
 
-To launch a Docker container with NVidia GPU support, enter a command of the following format (this [does not require any local CUDA installation](https://github.com/nvidia/nvidia-docker/wiki/CUDA#requirements)):
+To launch a Docker container with NVidia GPU support, enter a command of the
+following format (this
+[does not require any local CUDA installation](https://github.com/nvidia/nvidia-docker/wiki/CUDA#requirements)):
 
 
 $ nvidia-docker run -it -p hostPort:containerPort TensorFlowGPUImage
@@ -347,34 +349,34 @@ $ nvidia-docker run -it -p hostPort:containerPort TensorFlowGPUImage-p hostPort:containerPort is optional. If you plan
-    to run TensorFlow programs from the shell, omit this option. If you plan
-    to run TensorFlow programs as Jupyter notebooks, set both
-    hostPort and containerPort to `8888`.
-  * TensorFlowGPUImage specifies the Docker container. You must
-    specify one of the following values:
-    * tensorflow/tensorflow:latest-gpu, which is the latest
-      TensorFlow GPU binary image.
-    * tensorflow/tensorflow:latest-devel-gpu, which is
-      the latest TensorFlow GPU Binary image plus source code.
-    * tensorflow/tensorflow:version-gpu, which is the
-      specified version (for example, 0.12.1) of the TensorFlow GPU
-      binary image.
-    * tensorflow/tensorflow:version-devel-gpu, which is
-      the specified version (for example, 0.12.1) of the TensorFlow GPU
-      binary image plus source code.
-
-We recommend installing one of the `latest` versions. For example, the
-following command launches the latest TensorFlow GPU binary image in a
-Docker container from which you can run TensorFlow programs in a shell:
+*   -p hostPort:containerPort is optional. If you plan to run
+    TensorFlow programs from the shell, omit this option. If you plan to run
+    TensorFlow programs as Jupyter notebooks, set both hostPort
+    and containerPort to `8888`.
+*   TensorFlowGPUImage specifies the Docker container. You must specify
+    one of the following values:
+    *   tensorflow/tensorflow:latest-gpu, which is the latest
+        TensorFlow GPU binary image.
+    *   tensorflow/tensorflow:latest-devel-gpu, which is the latest
+        TensorFlow GPU Binary image plus source code.
+    *   tensorflow/tensorflow:version-gpu, which is the
+        specified version (for example, 0.12.1) of the TensorFlow GPU binary
+        image.
+    *   tensorflow/tensorflow:version-devel-gpu, which is the
+        specified version (for example, 0.12.1) of the TensorFlow GPU binary
+        image plus source code.
+
+We recommend installing one of the `latest` versions. For example, the following
+command launches the latest TensorFlow GPU binary image in a Docker container
+from which you can run TensorFlow programs in a shell:
 
 
 $ nvidia-docker run -it tensorflow/tensorflow:latest-gpu bash
 
-The following command also launches the latest TensorFlow GPU binary image -in a Docker container. In this Docker container, you can run TensorFlow -programs in a Jupyter notebook: +The following command also launches the latest TensorFlow GPU binary image in a +Docker container. In this Docker container, you can run TensorFlow programs in a +Jupyter notebook:
 $ nvidia-docker run -it -p 8888:8888 tensorflow/tensorflow:latest-gpu
@@ -390,14 +392,12 @@ Docker will download the TensorFlow binary image the first time you launch it.
 For more details see the
 [TensorFlow docker readme](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/docker).
 
-
 #### Next Steps
 
-You should now
-[validate your installation](#ValidateYourInstallation).
-
+You should now [validate your installation](#ValidateYourInstallation).
 
 
+
 ### Use `pip` in Anaconda
 
 Anaconda provides the `conda` utility to create a virtual environment. However,
@@ -410,61 +410,59 @@ not tested on new TensorFlow releases.
 
 Take the following steps to install TensorFlow in an Anaconda environment:
 
-  1. Follow the instructions on the
-     [Anaconda download site](https://www.continuum.io/downloads)
-     to download and install Anaconda.
+1.  Follow the instructions on the
+    [Anaconda download site](https://www.continuum.io/downloads) to download and
+    install Anaconda.
 
-  2. Create a conda environment named tensorflow to run a version
-     of Python by invoking the following command:
+2.  Create a conda environment named tensorflow to run a version of
+    Python by invoking the following command:
 
      
$ conda create -n tensorflow pip python=2.7 # or python=3.3, etc.
- 3. Activate the conda environment by issuing the following command: +3. Activate the conda environment by issuing the following command:
$ source activate tensorflow
      (tensorflow)$  # Your prompt should change 
- 4. Issue a command of the following format to install - TensorFlow inside your conda environment: +4. Issue a command of the following format to install TensorFlow inside your + conda environment:
(tensorflow)$ pip install --ignore-installed --upgrade tfBinaryURL
- where tfBinaryURL is the - [URL of the TensorFlow Python package](#the_url_of_the_tensorflow_python_package). - For example, the following command installs the CPU-only version of - TensorFlow for Python 3.4: + where tfBinaryURL is the + [URL of the TensorFlow Python package](#the_url_of_the_tensorflow_python_package). + For example, the following command installs the CPU-only version of + TensorFlow for Python 3.4:
      (tensorflow)$ pip install --ignore-installed --upgrade \
-     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp34-cp34m-linux_x86_64.whl
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.10.0-cp34-cp34m-linux_x86_64.whl
+ ## Validate your installation To validate your TensorFlow installation, do the following: - 1. Ensure that your environment is prepared to run TensorFlow programs. - 2. Run a short TensorFlow program. - +1. Ensure that your environment is prepared to run TensorFlow programs. +2. Run a short TensorFlow program. ### Prepare your environment -If you installed on native pip, Virtualenv, or Anaconda, then -do the following: +If you installed on native pip, Virtualenv, or Anaconda, then do the following: - 1. Start a terminal. - 2. If you installed with Virtualenv or Anaconda, activate your container. - 3. If you installed TensorFlow source code, navigate to any - directory *except* one containing TensorFlow source code. +1. Start a terminal. +2. If you installed with Virtualenv or Anaconda, activate your container. +3. If you installed TensorFlow source code, navigate to any directory *except* + one containing TensorFlow source code. -If you installed through Docker, start a Docker container -from which you can run bash. For example: +If you installed through Docker, start a Docker container from which you can run +bash. For example:
 $ docker run -it tensorflow/tensorflow bash
 
- ### Run a short TensorFlow program Invoke python from your shell as follows: @@ -486,100 +484,71 @@ TensorFlow programs:
Hello, TensorFlow!
-If the system outputs an error message instead of a greeting, see [Common -installation problems](#common_installation_problems). - -If you are new to machine learning, we recommend the following: - -* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course) -* @{$get_started/eager} +If the system outputs an error message instead of a greeting, see +[Common installation problems](#common_installation_problems). -If you are experienced with machine learning but new to TensorFlow, see -@{$get_started/eager}. +To learn more, see the [TensorFlow tutorials](../tutorials/). -## TensorFlow GPU support - -To install TensorFlow with GPU support, configure the following NVIDIA® software -on your system: - -* [CUDA Toolkit 9.0](http://nvidia.com/cuda). For details, see - [NVIDIA's documentation](http://docs.nvidia.com/cuda/cuda-installation-guide-linux/). - Append the relevant CUDA pathnames to the `LD_LIBRARY_PATH` environmental - variable as described in the NVIDIA documentation. -* [cuDNN SDK v7](http://developer.nvidia.com/cudnn). For details, see - [NVIDIA's documentation](http://docs.nvidia.com/deeplearning/sdk/cudnn-install/). - Create the `CUDA_HOME` environment variable as described in the NVIDIA - documentation. -* A GPU card with CUDA Compute Capability 3.0 or higher for building TensorFlow - from source. To use the TensorFlow binaries, version 3.5 or higher is required. - See the [NVIDIA documentation](https://developer.nvidia.com/cuda-gpus) for a - list of supported GPU cards. -* [GPU drivers](http://nvidia.com/drivers) that support your version of the CUDA - Toolkit. -* The `libcupti-dev` library is the NVIDIA CUDA Profile Tools Interface. This - library provides advanced profiling support. To install this library, - use the following command for CUDA Toolkit >= 8.0: - -
-  sudo apt-get install cuda-command-line-tools
-
- -Add this path to the `LD_LIBRARY_PATH` environmental variable: - -
-  export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}/usr/local/cuda/extras/CUPTI/lib64
-
-* *OPTIONAL*: For optimized performance during inference, install - *NVIDIA TensorRT 3.0*. To install the minimal amount of TensorRT - runtime components required to use with the pre-built `tensorflow-gpu` package: - -
-  wget https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1404/x86_64/nvinfer-runtime-trt-repo-ubuntu1404-3.0.4-ga-cuda9.0_1.0-1_amd64.deb
-  sudo dpkg -i nvinfer-runtime-trt-repo-ubuntu1404-3.0.4-ga-cuda9.0_1.0-1_amd64.deb
-  sudo apt-get update
-  sudo apt-get install -y --allow-downgrades libnvinfer-dev libcudnn7-dev=7.0.5.15-1+cuda9.0 libcudnn7=7.0.5.15-1+cuda9.0
-
- -Note: For compatibility with the pre-built `tensorflow-gpu` package, use the -Ubuntu *14.04* package of TensorRT (shown above). Use this even when installing -on an Ubuntu 16.04 system. - -To build the TensorFlow-TensorRT integration module from source instead of using -the pre-built binaries, see the -[module documentation](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/tensorrt#using-tensorrt-in-tensorflow). -For detailed TensorRT installation instructions, see -[NVIDIA's TensorRT documentation](http://docs.nvidia.com/deeplearning/sdk/tensorrt-install-guide/index.html). - -To avoid cuDNN version conflicts during later system upgrades, hold the cuDNN -version at 7.0.5: - -
-  sudo apt-mark hold libcudnn7 libcudnn7-dev
-
- -To allow upgrades, remove the this hold: - -
-  sudo apt-mark unhold libcudnn7 libcudnn7-dev
-
- -If you have an earlier version of the preceding packages, upgrade to the -specified versions. If upgrading is not possible, you can still run TensorFlow -with GPU support by @{$install_sources}. +## TensorFlow GPU support +Note: Due to the number of libraries required, using [Docker](#InstallingDocker) +is recommended over installing directly on the host system. + +The following NVIDIA® hardware must be installed on your system: + +* GPU card with CUDA Compute Capability 3.5 or higher. See + [NVIDIA documentation](https://developer.nvidia.com/cuda-gpus) for a list of + supported GPU cards. + +The following NVIDIA® software must be installed on your system: + +* [GPU drivers](http://nvidia.com/driver). CUDA 9.0 requires 384.x or higher. +* [CUDA Toolkit 9.0](http://nvidia.com/cuda). +* [cuDNN SDK](http://developer.nvidia.com/cudnn) (>= 7.0). Version 7.1 is + recommended. +* [CUPTI](http://docs.nvidia.com/cuda/cupti/) ships with the CUDA Toolkit, but + you also need to append its path to the `LD_LIBRARY_PATH` environment + variable: `export + LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/extras/CUPTI/lib64` +* *OPTIONAL*: [NCCL 2.2](https://developer.nvidia.com/nccl) to use TensorFlow + with multiple GPUs. +* *OPTIONAL*: + [TensorRT](http://docs.nvidia.com/deeplearning/sdk/tensorrt-install-guide/index.html) + which can improve latency and throughput for inference for some models. + +To use a GPU with CUDA Compute Capability 3.0, or different versions of the +preceding NVIDIA libraries see +@{$install_sources$installing TensorFlow from Sources}. If using Ubuntu 16.04 +and possibly other Debian based linux distros, `apt-get` can be used with the +NVIDIA repository to simplify installation. + +```bash +# Adds NVIDIA package repository. +sudo apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/7fa2af80.pub +wget http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/cuda-repo-ubuntu1604_9.1.85-1_amd64.deb +wget http://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvidia-machine-learning-repo-ubuntu1604_1.0.0-1_amd64.deb +sudo dpkg -i cuda-repo-ubuntu1604_9.1.85-1_amd64.deb +sudo dpkg -i nvidia-machine-learning-repo-ubuntu1604_1.0.0-1_amd64.deb +sudo apt-get update +# Includes optional NCCL 2.x. +sudo apt-get install cuda9.0 cuda-cublas-9-0 cuda-cufft-9-0 cuda-curand-9-0 \ + cuda-cusolver-9-0 cuda-cusparse-9-0 libcudnn7=7.1.4.18-1+cuda9.0 \ + libnccl2=2.2.13-1+cuda9.0 cuda-command-line-tools-9-0 +# Optionally install TensorRT runtime, must be done after above cuda install. +sudo apt-get update +sudo apt-get install libnvinfer4=4.1.2-1+cuda9.0 +``` ## Common installation problems We are relying on Stack Overflow to document TensorFlow installation problems -and their remedies. The following table contains links to Stack Overflow -answers for some common installation problems. -If you encounter an error message or other -installation problem not listed in the following table, search for it -on Stack Overflow. If Stack Overflow doesn't show the error message, -ask a new question about it on Stack Overflow and specify -the `tensorflow` tag. +and their remedies. The following table contains links to Stack Overflow answers +for some common installation problems. If you encounter an error message or +other installation problem not listed in the following table, search for it on +Stack Overflow. If Stack Overflow doesn't show the error message, ask a new +question about it on Stack Overflow and specify the `tensorflow` tag.
@@ -663,74 +632,67 @@ the `tensorflow` tag.
Link to GitHub or Stack Overflow Error Message
- + ## The URL of the TensorFlow Python package A few installation mechanisms require the URL of the TensorFlow Python package. The value you specify depends on three factors: - * operating system - * Python version - * CPU only vs. GPU support +* operating system +* Python version +* CPU only vs. GPU support This section documents the relevant values for Linux installations. - ### Python 2.7 CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.10.0-cp27-none-linux_x86_64.whl
 
- GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.10.0-cp27-none-linux_x86_64.whl
 
Note that GPU support requires the NVIDIA hardware and software described in [NVIDIA requirements to run TensorFlow with GPU support](#NVIDIARequirements). - ### Python 3.4 CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.10.0-cp34-cp34m-linux_x86_64.whl
 
- GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.10.0-cp34-cp34m-linux_x86_64.whl
 
Note that GPU support requires the NVIDIA hardware and software described in [NVIDIA requirements to run TensorFlow with GPU support](#NVIDIARequirements). - ### Python 3.5 CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.10.0-cp35-cp35m-linux_x86_64.whl
 
- GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.10.0-cp35-cp35m-linux_x86_64.whl
 
- Note that GPU support requires the NVIDIA hardware and software described in [NVIDIA requirements to run TensorFlow with GPU support](#NVIDIARequirements). @@ -739,16 +701,14 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.10.0-cp36-cp36m-linux_x86_64.whl
 
- GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.10.0-cp36-cp36m-linux_x86_64.whl
 
- Note that GPU support requires the NVIDIA hardware and software described in [NVIDIA requirements to run TensorFlow with GPU support](#NVIDIARequirements). diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md index 9d01271c5a0beebf75be9e32682583ddc4a666b1..c4d63cc10716b2f399df15bd462c3551944375b6 100644 --- a/tensorflow/docs_src/install/install_mac.md +++ b/tensorflow/docs_src/install/install_mac.md @@ -1,4 +1,4 @@ -# Installing TensorFlow on macOS +# Install TensorFlow on macOS This guide explains how to install TensorFlow on macOS. Although these instructions might also work on other macOS variants, we have only @@ -119,7 +119,7 @@ Take the following steps to install TensorFlow with Virtualenv: TensorFlow in the active Virtualenv is as follows:
 $ pip3 install --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py3-none-any.whl
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.10.0-py3-none-any.whl If you encounter installation problems, see [Common Installation Problems](#common-installation-problems). @@ -242,7 +242,7 @@ take the following steps: issue the following command:
 $ sudo pip3 install --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py3-none-any.whl 
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.10.0-py3-none-any.whl If the preceding command fails, see [installation problems](#common-installation-problems). @@ -350,7 +350,7 @@ Take the following steps to install TensorFlow in an Anaconda environment: TensorFlow for Python 2.7:
 (targetDirectory)$ pip install --ignore-installed --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py2-none-any.whl
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.10.0-py2-none-any.whl @@ -403,12 +403,7 @@ writing TensorFlow programs: If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). -If you are new to machine learning, we recommend the -[Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course). - -If you are experienced with machine learning but new to TensorFlow, see -@{$get_started/eager}. - +To learn more, see the [TensorFlow tutorials](../tutorials/). ## Common installation problems @@ -522,7 +517,7 @@ The value you specify depends on your Python version.
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py2-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.10.0-py2-none-any.whl
 
@@ -530,5 +525,5 @@ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py2-none-a
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py3-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.10.0-py3-none-any.whl
 
diff --git a/tensorflow/docs_src/install/install_raspbian.md b/tensorflow/docs_src/install/install_raspbian.md index 2f425162a1c63f084702727b5280ed266196b955..cf6b6b4f79113fee7fde6e83522af4fe6d9d7f43 100644 --- a/tensorflow/docs_src/install/install_raspbian.md +++ b/tensorflow/docs_src/install/install_raspbian.md @@ -1,4 +1,4 @@ -# Installing TensorFlow on Raspbian +# Install TensorFlow on Raspbian This guide explains how to install TensorFlow on a Raspberry Pi running Raspbian. Although these instructions might also work on other Pi variants, we @@ -60,7 +60,7 @@ If it gives the error "Command not found", then the package has not been installed yet. To install if for the first time, run:
$ sudo apt-get install python3-pip # for Python 3.n
-sudo apt-get install python-pip # for Python 2.7
+$ sudo apt-get install python-pip # for Python 2.7 You can find more help on installing and upgrading pip in [the Raspberry Pi documentation](https://www.raspberrypi.org/documentation/linux/software/python.md). @@ -78,8 +78,8 @@ your system, run the following command: Assuming the prerequisite software is installed on your Pi, install TensorFlow by invoking **one** of the following commands: -
 $ pip3 install tensorflow     # Python 3.n
-     $ pip install tensorflow      # Python 2.7
+
$ pip3 install tensorflow     # Python 3.n
+$ pip install tensorflow      # Python 2.7
This can take some time on certain platforms like the Pi Zero, where some Python packages like scipy that TensorFlow depends on need to be compiled before the @@ -230,11 +230,7 @@ problems, despite the log message. If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). -If you are new to machine learning, we recommend the [Machine Learning Crash -Course](https://developers.google.com/machine-learning/crash-course). - -If you are experienced with machine learning but new to TensorFlow, see -@{$get_started/eager}. +To learn more, see the [TensorFlow tutorials](../tutorials/). ## Common installation problems diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md index dc6c1e36fc237c2a160887e6417e7f373008309e..dfd9fbce4b53dce2a981526b1794d6b359312e40 100644 --- a/tensorflow/docs_src/install/install_sources.md +++ b/tensorflow/docs_src/install/install_sources.md @@ -1,28 +1,27 @@ -# Installing TensorFlow from Sources +# Install TensorFlow from Sources -This guide explains how to build TensorFlow sources into a TensorFlow -binary and how to install that TensorFlow binary. Note that we provide -well-tested, pre-built TensorFlow binaries for Ubuntu, macOS, and Windows -systems. In addition, there are pre-built TensorFlow -[docker images](https://hub.docker.com/r/tensorflow/tensorflow/). -So, don't build a TensorFlow binary yourself unless you are very -comfortable building complex packages from source and dealing with -the inevitable aftermath should things not go exactly as documented. +This guide explains how to build TensorFlow sources into a TensorFlow binary and +how to install that TensorFlow binary. Note that we provide well-tested, +pre-built TensorFlow binaries for Ubuntu, macOS, and Windows systems. In +addition, there are pre-built TensorFlow +[docker images](https://hub.docker.com/r/tensorflow/tensorflow/). So, don't +build a TensorFlow binary yourself unless you are very comfortable building +complex packages from source and dealing with the inevitable aftermath should +things not go exactly as documented. -If the last paragraph didn't scare you off, welcome. This guide explains -how to build TensorFlow on 64-bit desktops and laptops running either of -the following operating systems: +If the last paragraph didn't scare you off, welcome. This guide explains how to +build TensorFlow on 64-bit desktops and laptops running either of the following +operating systems: * Ubuntu * macOS X -Note: Some users have successfully built and installed TensorFlow from -sources on non-supported systems. Please remember that we do not fix -issues stemming from these attempts. +Note: Some users have successfully built and installed TensorFlow from sources +on non-supported systems. Please remember that we do not fix issues stemming +from these attempts. -We **do not support** building TensorFlow on Windows. That said, if you'd -like to try to build TensorFlow on Windows anyway, use either of the -following: +We **do not support** building TensorFlow on Windows. That said, if you'd like +to try to build TensorFlow on Windows anyway, use either of the following: * [Bazel on Windows](https://bazel.build/versions/master/docs/windows.html) * [TensorFlow CMake build](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/cmake) @@ -32,38 +31,33 @@ instructions. Older CPUs may not be able to execute these binaries. ## Determine which TensorFlow to install -You must choose one of the following types of TensorFlow to build and -install: - -* **TensorFlow with CPU support only**. If your system does not have a - NVIDIA® GPU, build and install this version. Note that this version of - TensorFlow is typically easier to build and install, so even if you - have an NVIDIA GPU, we recommend building and installing this version - first. -* **TensorFlow with GPU support**. TensorFlow programs typically run - significantly faster on a GPU than on a CPU. Therefore, if your system - has a NVIDIA GPU and you need to run performance-critical applications, - you should ultimately build and install this version. - Beyond the NVIDIA GPU itself, your system must also fulfill the NVIDIA - software requirements described in one of the following documents: +You must choose one of the following types of TensorFlow to build and install: - * @{$install_linux#NVIDIARequirements$Installing TensorFlow on Ubuntu} - * @{$install_mac#NVIDIARequirements$Installing TensorFlow on macOS} +* **TensorFlow with CPU support only**. If your system does not have a NVIDIA® + GPU, build and install this version. Note that this version of TensorFlow is + typically easier to build and install, so even if you have an NVIDIA GPU, we + recommend building and installing this version first. +* **TensorFlow with GPU support**. TensorFlow programs typically run + significantly faster on a GPU than on a CPU. Therefore, if your system has a + NVIDIA GPU and you need to run performance-critical applications, you should + ultimately build and install this version. Beyond the NVIDIA GPU itself, + your system must also fulfill the NVIDIA software requirements described in + one of the following documents: + * @ {$install_linux#NVIDIARequirements$Installing TensorFlow on Ubuntu} + * @ {$install_mac#NVIDIARequirements$Installing TensorFlow on macOS} ## Clone the TensorFlow repository -Start the process of building TensorFlow by cloning a TensorFlow -repository. +Start the process of building TensorFlow by cloning a TensorFlow repository. To clone **the latest** TensorFlow repository, issue the following command:
$ git clone https://github.com/tensorflow/tensorflow 
-The preceding git clone command creates a subdirectory -named `tensorflow`. After cloning, you may optionally build a -**specific branch** (such as a release branch) by invoking the -following commands: +The preceding git clone command creates a subdirectory named +`tensorflow`. After cloning, you may optionally build a **specific branch** +(such as a release branch) by invoking the following commands:
 $ cd tensorflow
@@ -75,38 +69,34 @@ issue the following command:
 
 
$ git checkout r1.0
-Next, you must prepare your environment for -[Linux](#PrepareLinux) -or +Next, you must prepare your environment for [Linux](#PrepareLinux) or [macOS](#PrepareMac) - -## Prepare environment for Linux -Before building TensorFlow on Linux, install the following build -tools on your system: +## Prepare environment for Linux - * bazel - * TensorFlow Python dependencies - * optionally, NVIDIA packages to support TensorFlow for GPU. +Before building TensorFlow on Linux, install the following build tools on your +system: +* bazel +* TensorFlow Python dependencies +* optionally, NVIDIA packages to support TensorFlow for GPU. ### Install Bazel If bazel is not installed on your system, install it now by following [these directions](https://bazel.build/versions/master/docs/install.html). - ### Install TensorFlow Python dependencies To install TensorFlow, you must install the following packages: - * `numpy`, which is a numerical processing package that TensorFlow requires. - * `dev`, which enables adding extensions to Python. - * `pip`, which enables you to install and manage certain Python packages. - * `wheel`, which enables you to manage Python compressed packages in - the wheel (.whl) format. +* `numpy`, which is a numerical processing package that TensorFlow requires. +* `dev`, which enables adding extensions to Python. +* `pip`, which enables you to install and manage certain Python packages. +* `wheel`, which enables you to manage Python compressed packages in the wheel + (.whl) format. To install these packages for Python 2.7, issue the following command: @@ -120,94 +110,101 @@ To install these packages for Python 3.n, issue the following command: $ sudo apt-get install python3-numpy python3-dev python3-pip python3-wheel
- ### Optional: install TensorFlow for GPU prerequisites If you are building TensorFlow without GPU support, skip this section. -The following NVIDIA hardware must be installed on your system: - - * GPU card with CUDA Compute Capability 3.0 or higher. See - [NVIDIA documentation](https://developer.nvidia.com/cuda-gpus) - for a list of supported GPU cards. - -The following NVIDIA software must be installed on your system: - - * [CUDA Toolkit](http://nvidia.com/cuda) (>= 8.0). We recommend version 9.0. - For details, see - [NVIDIA's documentation](http://docs.nvidia.com/cuda/cuda-installation-guide-linux/). - Ensure that you append the relevant CUDA pathnames to the - `LD_LIBRARY_PATH` environment variable as described in the - NVIDIA documentation. - * [GPU drivers](http://nvidia.com/driver) supporting your version of the CUDA - Toolkit. - * [cuDNN SDK](http://developer.nvidia.com/cudnn) (>= 6.0). We recommend version 7.0. For details, see - [NVIDIA's documentation](http://docs.nvidia.com/deeplearning/sdk/cudnn-install/). - * [CUPTI](http://docs.nvidia.com/cuda/cupti/) ships with the CUDA Toolkit, but - you also need to append its path to the `LD_LIBRARY_PATH` environment - variable: +The following NVIDIA® hardware must be installed on your system: + +* GPU card with CUDA Compute Capability 3.5 or higher. See + [NVIDIA documentation](https://developer.nvidia.com/cuda-gpus) for a list of + supported GPU cards. -
 $ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/extras/CUPTI/lib64 
+The following NVIDIA® software must be installed on your system: + +* [GPU drivers](http://nvidia.com/driver). CUDA 9.0 requires 384.x or higher. +* [CUDA Toolkit](http://nvidia.com/cuda) (>= 8.0). We recommend version 9.0. +* [cuDNN SDK](http://developer.nvidia.com/cudnn) (>= 6.0). We recommend + version 7.1.x. +* [CUPTI](http://docs.nvidia.com/cuda/cupti/) ships with the CUDA Toolkit, but + you also need to append its path to the `LD_LIBRARY_PATH` environment + variable: `export + LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/extras/CUPTI/lib64` +* *OPTIONAL*: [NCCL 2.2](https://developer.nvidia.com/nccl) to use TensorFlow + with multiple GPUs. +* *OPTIONAL*: + [TensorRT](http://docs.nvidia.com/deeplearning/sdk/tensorrt-install-guide/index.html) + which can improve latency and throughput for inference for some models. + +While it is possible to install the NVIDIA libraries via `apt-get` from the +NVIDIA repository, the libraries and headers are installed in locations that +make it difficult to configure and debug build issues. Downloading and +installing the libraries manually or using docker +([latest-devel-gpu](https://hub.docker.com/r/tensorflow/tensorflow/tags/)) is +recommended. ### Next After preparing the environment, you must now [configure the installation](#ConfigureInstallation). - + ## Prepare environment for macOS Before building TensorFlow, you must install the following on your system: - * bazel - * TensorFlow Python dependencies. - * optionally, NVIDIA packages to support TensorFlow for GPU. - +* bazel +* TensorFlow Python dependencies. +* optionally, NVIDIA packages to support TensorFlow for GPU. ### Install bazel If bazel is not installed on your system, install it now by following [these directions](https://bazel.build/versions/master/docs/install.html#mac-os-x). - ### Install python dependencies To build TensorFlow, you must install the following packages: - * six - * numpy, which is a numerical processing package that TensorFlow requires. - * wheel, which enables you to manage Python compressed packages - in the wheel (.whl) format. +* six +* mock +* numpy, which is a numerical processing package that TensorFlow requires. +* wheel, which enables you to manage Python compressed packages in the wheel + (.whl) format. -You may install the python dependencies using pip. If you don't have pip -on your machine, we recommend using homebrew to install Python and pip as +You may install the python dependencies using pip. If you don't have pip on your +machine, we recommend using homebrew to install Python and pip as [documented here](http://docs.python-guide.org/en/latest/starting/install/osx/). If you follow these instructions, you will not need to disable SIP. After installing pip, invoke the following commands: -
 $ sudo pip install six numpy wheel 
+
 $ sudo pip install six numpy wheel mock h5py
+ $ sudo pip install keras_applications==1.0.4 --no-deps
+ $ sudo pip install keras_preprocessing==1.0.2 --no-deps
+
Note: These are just the minimum requirements to _build_ tensorflow. Installing the pip package will download additional packages required to _run_ it. If you plan on executing tasks directly with `bazel` , without the pip installation, -you may need to install additional python packages. For example, you should -`pip install mock enum34` before running TensorFlow's tests with bazel. +you may need to install additional python packages. For example, you should `pip +install mock enum34` before running TensorFlow's tests with bazel. + ## Configure the installation -The root of the source tree contains a bash script named -configure. This script asks you to identify the pathname of all -relevant TensorFlow dependencies and specify other build configuration options -such as compiler flags. You must run this script *prior* to -creating the pip package and installing TensorFlow. +The root of the source tree contains a bash script named configure. +This script asks you to identify the pathname of all relevant TensorFlow +dependencies and specify other build configuration options such as compiler +flags. You must run this script *prior* to creating the pip package and +installing TensorFlow. -If you wish to build TensorFlow with GPU, `configure` will ask -you to specify the version numbers of CUDA and cuDNN. If several -versions of CUDA or cuDNN are installed on your system, explicitly select -the desired version instead of relying on the default. +If you wish to build TensorFlow with GPU, `configure` will ask you to specify +the version numbers of CUDA and cuDNN. If several versions of CUDA or cuDNN are +installed on your system, explicitly select the desired version instead of +relying on the default. One of the questions that `configure` will ask is as follows: @@ -215,73 +212,117 @@ One of the questions that `configure` will ask is as follows: Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native] -This question refers to a later phase in which you'll use bazel to [build the -pip package](#build-the-pip-package) or the [C/Java libraries](#BuildCorJava). -We recommend accepting the default (`-march=native`), which will optimize the -generated code for your local machine's CPU type. However, if you are building -TensorFlow on one CPU type but will run TensorFlow on a different CPU type, then -consider specifying a more specific optimization -flag as described in [the gcc -documentation](https://gcc.gnu.org/onlinedocs/gcc-4.5.3/gcc/i386-and-x86_002d64-Options.html). +This question refers to a later phase in which you'll use bazel to +[build the pip package](#build-the-pip-package) or the +[C/Java libraries](#BuildCorJava). We recommend accepting the default +(`-march=native`), which will optimize the generated code for your local +machine's CPU type. However, if you are building TensorFlow on one CPU type but +will run TensorFlow on a different CPU type, then consider specifying a more +specific optimization flag as described in +[the gcc documentation](https://gcc.gnu.org/onlinedocs/gcc-4.5.3/gcc/i386-and-x86_002d64-Options.html). -Here is an example execution of the `configure` script. Note that your -own input will likely differ from our sample input: +Here is an example execution of the `configure` script. Note that your own input +will likely differ from our sample input:
 $ cd tensorflow  # cd to the top-level directory created
 $ ./configure
+You have bazel 0.15.0 installed.
 Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python2.7
+
+
 Found possible Python library paths:
   /usr/local/lib/python2.7/dist-packages
   /usr/lib/python2.7/dist-packages
 Please input the desired Python library path to use.  Default is [/usr/lib/python2.7/dist-packages]
 
-Using python library path: /usr/local/lib/python2.7/dist-packages
-Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native]:
-Do you wish to use jemalloc as the malloc implementation? [Y/n]
-jemalloc enabled
-Do you wish to build TensorFlow with Google Cloud Platform support? [y/N]
-No Google Cloud Platform support will be enabled for TensorFlow
-Do you wish to build TensorFlow with Hadoop File System support? [y/N]
-No Hadoop File System support will be enabled for TensorFlow
-Do you wish to build TensorFlow with the XLA just-in-time compiler (experimental)? [y/N]
-No XLA support will be enabled for TensorFlow
-Do you wish to build TensorFlow with VERBS support? [y/N]
-No VERBS support will be enabled for TensorFlow
-Do you wish to build TensorFlow with OpenCL support? [y/N]
-No OpenCL support will be enabled for TensorFlow
-Do you wish to build TensorFlow with CUDA support? [y/N] Y
-CUDA support will be enabled for TensorFlow
-Do you want to use clang as CUDA compiler? [y/N]
-nvcc will be used as CUDA compiler
+Do you wish to build TensorFlow with jemalloc as malloc support? [Y/n]:
+jemalloc as malloc support will be enabled for TensorFlow.
+
+Do you wish to build TensorFlow with Google Cloud Platform support? [Y/n]:
+Google Cloud Platform support will be enabled for TensorFlow.
+
+Do you wish to build TensorFlow with Hadoop File System support? [Y/n]:
+Hadoop File System support will be enabled for TensorFlow.
+
+Do you wish to build TensorFlow with Amazon AWS Platform support? [Y/n]:
+Amazon AWS Platform support will be enabled for TensorFlow.
+
+Do you wish to build TensorFlow with Apache Kafka Platform support? [Y/n]:
+Apache Kafka Platform support will be enabled for TensorFlow.
+
+Do you wish to build TensorFlow with XLA JIT support? [y/N]:
+No XLA JIT support will be enabled for TensorFlow.
+
+Do you wish to build TensorFlow with GDR support? [y/N]:
+No GDR support will be enabled for TensorFlow.
+
+Do you wish to build TensorFlow with VERBS support? [y/N]:
+No VERBS support will be enabled for TensorFlow.
+
+Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]:
+No OpenCL SYCL support will be enabled for TensorFlow.
+
+Do you wish to build TensorFlow with CUDA support? [y/N]: Y
+CUDA support will be enabled for TensorFlow.
+
 Please specify the CUDA SDK version you want to use. [Leave empty to default to CUDA 9.0]: 9.0
+
+
 Please specify the location where CUDA 9.0 toolkit is installed. Refer to README.md for more details. [Default is /usr/local/cuda]:
-Please specify which gcc should be used by nvcc as the host compiler. [Default is /usr/bin/gcc]:
-Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 7.0]: 7
+
+
+Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 7.0]: 7.0
+
+
 Please specify the location where cuDNN 7 library is installed. Refer to README.md for more details. [Default is /usr/local/cuda]:
-Please specify a list of comma-separated CUDA compute capabilities you want to build with.
+
+
+Do you wish to build TensorFlow with TensorRT support? [y/N]:
+No TensorRT support will be enabled for TensorFlow.
+
+Please specify the NCCL version you want to use. If NCLL 2.2 is not installed, then you can use version 1.3 that can be fetched automatically but it may have worse performance with multiple GPUs. [Default is 2.2]: 1.3
+
+
+Please specify a list of comma-separated Cuda compute capabilities you want to build with.
 You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus.
-Please note that each additional compute capability significantly increases your build time and binary size.
-[Default is: "3.5,5.2"]: 3.0
-Do you wish to build TensorFlow with MPI support? [y/N]
-MPI support will not be enabled for TensorFlow
+Please note that each additional compute capability significantly increases your
+build time and binary size. [Default is: 3.5,7.0] 6.1
+
+
+Do you want to use clang as CUDA compiler? [y/N]:
+nvcc will be used as CUDA compiler.
+
+Please specify which gcc should be used by nvcc as the host compiler. [Default is /usr/bin/gcc]:
+
+
+Do you wish to build TensorFlow with MPI support? [y/N]:
+No MPI support will be enabled for TensorFlow.
+
+Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native]:
+
+
+Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]:
+Not configuring the WORKSPACE for Android builds.
+
+Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See tools/bazel.rc for more details.
+    --config=mkl            # Build with MKL support.
+    --config=monolithic     # Config for mostly static monolithic build.
 Configuration finished
 
-If you told `configure` to build for GPU support, then `configure` -will create a canonical set of symbolic links to the CUDA libraries -on your system. Therefore, every time you change the CUDA library paths, -you must rerun the `configure` script before re-invoking -the bazel build command. +If you told `configure` to build for GPU support, then `configure` will create a +canonical set of symbolic links to the CUDA libraries on your system. Therefore, +every time you change the CUDA library paths, you must rerun the `configure` +script before re-invoking the bazel build command. Note the following: - * Although it is possible to build both CUDA and non-CUDA configs - under the same source tree, we recommend running `bazel clean` when - switching between these two configurations in the same source tree. - * If you don't run the `configure` script *before* running the - `bazel build` command, the `bazel build` command will fail. - +* Although it is possible to build both CUDA and non-CUDA configs under the + same source tree, we recommend running `bazel clean` when switching between + these two configurations in the same source tree. +* If you don't run the `configure` script *before* running the `bazel build` + command, the `bazel build` command will fail. ## Build the pip package @@ -289,49 +330,58 @@ Note: If you're only interested in building the libraries for the TensorFlow C or Java APIs, see [Build the C or Java libraries](#BuildCorJava), you do not need to build the pip package in that case. -To build a pip package for TensorFlow with CPU-only support, -you would typically invoke the following command: +### CPU-only support + +To build a pip package for TensorFlow with CPU-only support: + +
+$ bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package
+
+ +To build a pip package for TensorFlow with CPU-only support for the Intel® +MKL-DNN:
-$ bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package
+$ bazel build --config=mkl --config=opt //tensorflow/tools/pip_package:build_pip_package
 
-To build a pip package for TensorFlow with GPU support, -invoke the following command: +### GPU support -
$ bazel build --config=opt --config=cuda //tensorflow/tools/pip_package:build_pip_package 
+To build a pip package for TensorFlow with GPU support: -**NOTE on gcc 5 or later:** the binary pip packages available on the -TensorFlow website are built with gcc 4, which uses the older ABI. To -make your build compatible with the older ABI, you need to add -`--cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"` to your `bazel build` command. -ABI compatibility allows custom ops built against the TensorFlow pip package -to continue to work against your built package. +
+$ bazel build --config=opt --config=cuda //tensorflow/tools/pip_package:build_pip_package
+
-Tip: By default, building TensorFlow from sources consumes -a lot of RAM. If RAM is an issue on your system, you may limit RAM usage -by specifying --local_resources 2048,.5,1.0 while -invoking `bazel`. +**NOTE on gcc 5 or later:** the binary pip packages available on the TensorFlow +website are built with gcc 4, which uses the older ABI. To make your build +compatible with the older ABI, you need to add +`--cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0"` to your `bazel build` command. ABI +compatibility allows custom ops built against the TensorFlow pip package to +continue to work against your built package. -The bazel build command builds a script named -`build_pip_package`. Running this script as follows will build -a `.whl` file within the `/tmp/tensorflow_pkg` directory: +Tip: By default, building TensorFlow from sources consumes a lot of RAM. +If RAM is an issue on your system, you may limit RAM usage by specifying +--local_resources 2048,.5,1.0 while invoking `bazel`. + +The bazel build command builds a script named `build_pip_package`. +Running this script as follows will build a `.whl` file within the +`/tmp/tensorflow_pkg` directory:
 $ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
 
- ## Install the pip package -Invoke `pip install` to install that pip package. -The filename of the `.whl` file depends on your platform. -For example, the following command will install the pip package +Invoke `pip install` to install that pip package. The filename of the `.whl` +file depends on your platform. For example, the following command will install +the pip package -for TensorFlow 1.9.0rc0 on Linux: +for TensorFlow 1.10.0 on Linux:
-$ sudo pip install /tmp/tensorflow_pkg/tensorflow-1.9.0rc0-py2-none-any.whl
+$ sudo pip install /tmp/tensorflow_pkg/tensorflow-1.10.0-py2-none-any.whl
 
## Validate your installation @@ -362,28 +412,31 @@ TensorFlow programs:
Hello, TensorFlow!
-If you are new to TensorFlow, see @{$get_started/eager}. +To learn more, see the [TensorFlow tutorials](../tutorials/). -If the system outputs an error message instead of a greeting, see [Common -installation problems](#common_installation_problems). +If the system outputs an error message instead of a greeting, see +[Common installation problems](#common_installation_problems). ## Common build and installation problems The build and installation problems you encounter typically depend on the -operating system. See the "Common installation problems" section -of one of the following guides: - - * @{$install_linux#common_installation_problems$Installing TensorFlow on Linux} - * @{$install_mac#common_installation_problems$Installing TensorFlow on Mac OS} - * @{$install_windows#common_installation_problems$Installing TensorFlow on Windows} - -Beyond the errors documented in those two guides, the following table -notes additional errors specific to building TensorFlow. Note that we -are relying on Stack Overflow as the repository for build and installation -problems. If you encounter an error message not listed in the preceding -two guides or in the following table, search for it on Stack Overflow. If -Stack Overflow doesn't show the error message, ask a new question on -Stack Overflow and specify the `tensorflow` tag. +operating system. See the "Common installation problems" section of one of the +following guides: + +* @ + {$install_linux#common_installation_problems$Installing TensorFlow on Linux} +* @ + {$install_mac#common_installation_problems$Installing TensorFlow on Mac OS} +* @ + {$install_windows#common_installation_problems$Installing TensorFlow on Windows} + +Beyond the errors documented in those two guides, the following table notes +additional errors specific to building TensorFlow. Note that we are relying on +Stack Overflow as the repository for build and installation problems. If you +encounter an error message not listed in the preceding two guides or in the +following table, search for it on Stack Overflow. If Stack Overflow doesn't show +the error message, ask a new question on Stack Overflow and specify the +`tensorflow` tag. @@ -430,9 +483,12 @@ Stack Overflow and specify the `tensorflow` tag.
Stack Overflow Link Error Message
## Tested source configurations + **Linux** + + @@ -458,6 +514,7 @@ Stack Overflow and specify the `tensorflow` tag. **Mac**
Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
tensorflow-1.10.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.15.0N/AN/A
tensorflow_gpu-1.10.0GPU2.7, 3.3-3.6GCC 4.8Bazel 0.15.079
tensorflow-1.9.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.11.0N/AN/A
tensorflow_gpu-1.9.0GPU2.7, 3.3-3.6GCC 4.8Bazel 0.11.079
tensorflow-1.8.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.10.0N/AN/A
+ @@ -475,6 +532,8 @@ Stack Overflow and specify the `tensorflow` tag. **Windows**
Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
tensorflow-1.10.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.15.0N/AN/A
tensorflow-1.9.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.11.0N/AN/A
tensorflow-1.8.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.10.1N/AN/A
tensorflow-1.7.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.10.1N/AN/A
+ + @@ -498,6 +557,7 @@ Stack Overflow and specify the `tensorflow` tag.
Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
tensorflow-1.10.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
tensorflow_gpu-1.10.0GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.379
tensorflow-1.9.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
tensorflow_gpu-1.9.0GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.379
tensorflow-1.8.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
+ ## Build the C or Java libraries The instructions above are tailored to building the TensorFlow Python packages. @@ -506,10 +566,12 @@ If you're interested in building the libraries for the TensorFlow C API, do the following: 1. Follow the steps up to [Configure the installation](#ConfigureInstallation) -2. Build the C libraries following instructions in the [README](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/lib_package/README.md). +2. Build the C libraries following instructions in the + [README](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/lib_package/README.md). -If you're interested inv building the libraries for the TensorFlow Java API, -do the following: +If you're interested inv building the libraries for the TensorFlow Java API, do +the following: 1. Follow the steps up to [Configure the installation](#ConfigureInstallation) -2. Build the Java library following instructions in the [README](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/lib_package/README.md). +2. Build the Java library following instructions in the + [README](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/lib_package/README.md). diff --git a/tensorflow/docs_src/install/install_sources_windows.md b/tensorflow/docs_src/install/install_sources_windows.md new file mode 100644 index 0000000000000000000000000000000000000000..a1da12231738259969d35e4dffc7612e45aab031 --- /dev/null +++ b/tensorflow/docs_src/install/install_sources_windows.md @@ -0,0 +1,320 @@ +# Install TensorFlow from Sources on Windows + +This guide explains how to build TensorFlow sources into a TensorFlow binary and +how to install that TensorFlow binary on Windows. + +## Determine which TensorFlow to install + +You must choose one of the following types of TensorFlow to build and install: + +* **TensorFlow with CPU support only**. If your system does not have a NVIDIA® + GPU, build and install this version. Note that this version of TensorFlow is + typically easier to build and install, so even if you have an NVIDIA GPU, we + recommend building and installing this version first. +* **TensorFlow with GPU support**. TensorFlow programs typically run + significantly faster on a GPU than on a CPU. Therefore, if your system has a + NVIDIA GPU and you need to run performance-critical applications, you should + ultimately build and install this version. Beyond the NVIDIA GPU itself, + your system must also fulfill the NVIDIA software requirements described in + the following document: + + * [Installing TensorFlow on Windows](install_windows.md#NVIDIARequirements) + +## Prepare environment for Windows + +Before building TensorFlow on Windows, install the following build tools on your +system: + +* [MSYS2](#InstallMSYS2) +* [Visual C++ build tools](#InstallVCBuildTools) +* [Bazel for Windows](#InstallBazel) +* [TensorFlow Python dependencies](#InstallPython) +* [optionally, NVIDIA packages to support TensorFlow for GPU](#InstallCUDA) + + + +### Install MSYS2 + +Bash bin tools are used in TensorFlow Bazel build, you can install them through [MSYS2](https://www.msys2.org/). + +Assume you installed MSYS2 at `C:\msys64`, add `C:\msys64\usr\bin` to your `%PATH%` environment variable. + +To install necessary bash bin tools, issue the following command under `cmd.exe`: + +
+C:\> pacman -S git patch unzip
+
+ + + +### Install Visual C++ Build Tools 2015 + +To build TensorFlow, you need to install Visual C++ build tools 2015. It is a part of Visual Studio 2015. +But you can install it separately by the following way: + + * Open the [official downloand page](https://visualstudio.microsoft.com/vs/older-downloads/). + * Go to Redistributables and Build Tools section. + * Find Microsoft Build Tools 2015 Update 3 and click download. + * Run the installer. + +It's possible to build TensorFlow with newer version of Visual C++ build tools, +but we only test against Visual Studio 2015 Update 3. + + + +### Install Bazel + +If bazel is not installed on your system, install it now by following +[these instructions](https://docs.bazel.build/versions/master/install-windows.html). +It is recommended to use a Bazel version >= `0.15.0`. + +Add the directory where you installed Bazel to your `%PATH%` environment variable. + + + +### Install TensorFlow Python dependencies + +If you don't have Python 3.5 or Python 3.6 installed, install it now: + + * [Python 3.5.x 64-bit from python.org](https://www.python.org/downloads/release/python-352/) + * [Python 3.6.x 64-bit from python.org](https://www.python.org/downloads/release/python-362/) + +To build and install TensorFlow, you must install the following python packages: + +* `six`, which provides simple utilities for wrapping over differences between + Python 2 and Python 3. +* `numpy`, which is a numerical processing package that TensorFlow requires. +* `wheel`, which enables you to manage Python compressed packages in the wheel + (.whl) format. +* `keras_applications`, the applications module of the Keras deep learning library. +* `keras_preprocessing`, the data preprocessing and data augmentation module + of the Keras deep learning library. + +Assume you already have `pip3` in `%PATH%`, issue the following command: + +
+C:\> pip3 install six numpy wheel
+C:\> pip3 install keras_applications==1.0.4 --no-deps
+C:\> pip3 install keras_preprocessing==1.0.2 --no-deps
+
+ + + +### Optional: install TensorFlow for GPU prerequisites + +If you are building TensorFlow without GPU support, skip this section. + +The following NVIDIA® _hardware_ must be installed on your system: + +* GPU card with CUDA Compute Capability 3.5 or higher. See + [NVIDIA documentation](https://developer.nvidia.com/cuda-gpus) for a list of + supported GPU cards. + +The following NVIDIA® _software_ must be installed on your system: + +* [GPU drivers](http://nvidia.com/driver). CUDA 9.0 requires 384.x or higher. +* [CUDA Toolkit](http://nvidia.com/cuda) (>= 8.0). We recommend version 9.0. +* [cuDNN SDK](http://developer.nvidia.com/cudnn) (>= 6.0). We recommend + version 7.1.x. +* [CUPTI](http://docs.nvidia.com/cuda/cupti/) ships with the CUDA Toolkit, but + you also need to append its path to `%PATH%` environment + variable. + +Assume you have CUDA Toolkit installed at `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0` +and cuDNN at `C:\tools\cuda`, issue the following commands. + +
+C:\> SET PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0\bin;%PATH%
+C:\> SET PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.0\extras\CUPTI\libx64;%PATH%
+C:\> SET PATH=C:\tools\cuda\bin;%PATH%
+
+ +## Clone the TensorFlow repository + +Now you need to clone **the latest** TensorFlow repository, +thanks to MSYS2 we already have `git` avaiable, issue the following command: + +
C:\> git clone https://github.com/tensorflow/tensorflow.git 
+ +The preceding git clone command creates a subdirectory named +`tensorflow`. After cloning, you may optionally build a **specific branch** +(such as a release branch) by invoking the following commands: + +
+C:\> cd tensorflow
+C:\> git checkout Branch # where Branch is the desired branch
+
+ +For example, to work with the `r1.11` release instead of the master release, +issue the following command: + +
C:\> git checkout r1.11
+ +Next, you must now configure the installation. + +## Configure the installation + +The root of the source tree contains a python script named configure.py. +This script asks you to identify the pathname of all relevant TensorFlow +dependencies and specify other build configuration options such as compiler +flags. You must run this script *prior* to creating the pip package and +installing TensorFlow. + +If you wish to build TensorFlow with GPU, `configure.py` will ask you to specify +the version numbers of CUDA and cuDNN. If several versions of CUDA or cuDNN are +installed on your system, explicitly select the desired version instead of +relying on the default. + +One of the questions that `configure.py` will ask is as follows: + +
+Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is /arch:AVX]:
+
+ +Here is an example execution of the `configure.py` script. Note that your own input +will likely differ from our sample input: + +
+C:\> cd tensorflow  # cd to the top-level directory created
+C:\tensorflow> python ./configure.py
+Starting local Bazel server and connecting to it...
+................
+You have bazel 0.15.0 installed.
+Please specify the location of python. [Default is C:\python36\python.exe]: 
+
+Found possible Python library paths:
+  C:\python36\lib\site-packages
+Please input the desired Python library path to use.  Default is [C:\python36\lib\site-packages]
+
+Do you wish to build TensorFlow with CUDA support? [y/N]: Y
+CUDA support will be enabled for TensorFlow.
+
+Please specify the CUDA SDK version you want to use. [Leave empty to default to CUDA 9.0]:
+
+Please specify the location where CUDA 9.0 toolkit is installed. Refer to README.md for more details. [Default is C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0]:
+
+Please specify the cuDNN version you want to use. [Leave empty to default to cuDNN 7.0]: 7.0
+
+Please specify the location where cuDNN 7 library is installed. Refer to README.md for more details. [Default is C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0]: C:\tools\cuda
+
+Please specify a list of comma-separated Cuda compute capabilities you want to build with.
+You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus.
+Please note that each additional compute capability significantly increases your build time and binary size. [Default is: 3.5,7.0]: 3.7
+
+Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is /arch:AVX]: 
+
+Would you like to override eigen strong inline for some C++ compilation to reduce the compilation time? [Y/n]:
+Eigen strong inline overridden.
+
+Configuration finished
+
+ +## Build the pip package + +### CPU-only support + +To build a pip package for TensorFlow with CPU-only support: + +
+C:\tensorflow> bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package
+
+ +### GPU support + +To build a pip package for TensorFlow with GPU support: + +
+C:\tensorflow> bazel build --config=opt --config=cuda //tensorflow/tools/pip_package:build_pip_package
+
+ +**NOTE :** When building with GPU support, you might want to add `--copt=-nvcc_options=disable-warnings` +to suppress nvcc warning messages. + +The `bazel build` command builds a binary named `build_pip_package` +(an executable binary to launch bash and run a bash script to create the pip package). +Running this binary as follows will build a `.whl` file within the `C:/tmp/tensorflow_pkg` directory: + +
+C:\tensorflow> bazel-bin\tensorflow\tools\pip_package\build_pip_package C:/tmp/tensorflow_pkg
+
+ +## Install the pip package + +Invoke `pip3 install` to install that pip package. The filename of the `.whl` +file depends on the TensorFlow version and your platform. For example, the +following command will install the pip package for TensorFlow 1.11.0rc0: + +
+C:\tensorflow> pip3 install C:/tmp/tensorflow_pkg/tensorflow-1.11.0rc0-cp36-cp36m-win_amd64.whl
+
+ +## Validate your installation + +Validate your TensorFlow installation by doing the following: + +Start a terminal. + +Change directory (`cd`) to any directory on your system other than the +`tensorflow` subdirectory from which you invoked the `configure` command. + +Invoke python: + +
$ python
+ +Enter the following short program inside the python interactive shell: + +```python +# Python +import tensorflow as tf +hello = tf.constant('Hello, TensorFlow!') +sess = tf.Session() +print(sess.run(hello)) +``` + +If the system outputs the following, then you are ready to begin writing +TensorFlow programs: + +
Hello, TensorFlow!
+ +To learn more, see the [TensorFlow tutorials](../tutorials/). + +## Build under MSYS shell +The above instruction assumes you are building under the Windows native command line (`cmd.exe`), but you can also +build TensorFlow from MSYS shell. There are a few things to notice: + +* Disable the path conversion heuristic in MSYS. MSYS automatically converts arguments that look + like a Unix path to Windows path when running a program, this will confuse Bazel. + (eg. A Bazel label `//foo/bar:bin` is considered a Unix absolute path, only because it starts with a slash) + + ```sh +$ export MSYS_NO_PATHCONV=1 +$ export MSYS2_ARG_CONV_EXCL="*" +``` + +* Add the directory where you install Bazel in `$PATH`. Assume you have Bazel + installed at `C:\tools\bazel.exe`, issue the following command: + + ```sh +# `:` is used as path separator, so we have to convert the path to Unix style. +$ export PATH="/c/tools:$PATH" +``` + +* Add the directory where you install Python in `$PATH`. Assume you have + Python installed at `C:\Python36\python.exe`, issue the following command: + + ```sh +$ export PATH="/c/Python36:$PATH" +``` + +* If you have Python in `$PATH`, you can run configure script just by + `./configure`, a shell script will help you invoke python. + +* (For GPU build only) Add Cuda and cuDNN bin directories in `$PATH` in the following way: + + ```sh +$ export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0/bin:$PATH" +$ export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0/extras/CUPTI/libx64:$PATH" +$ export PATH="/c/tools/cuda/bin:$PATH" +``` + +The rest steps should be the same as building under `cmd.exe`. diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md index 6c4f5b85ab2facdb274e9bdd36f6edb9ad79ba4b..0bb0e5aeb9ccdf956c39516297b1f59b9da263de 100644 --- a/tensorflow/docs_src/install/install_windows.md +++ b/tensorflow/docs_src/install/install_windows.md @@ -1,4 +1,4 @@ -# Installing TensorFlow on Windows +# Install TensorFlow on Windows This guide explains how to install TensorFlow on Windows. Although these instructions might also work on other Windows variants, we have only @@ -24,6 +24,8 @@ You must choose one of the following types of TensorFlow to install: and you need to run performance-critical applications, you should ultimately install this version. + + ### Requirements to run TensorFlow with GPU support If you are installing TensorFlow with GPU support using one of the mechanisms @@ -157,12 +159,7 @@ TensorFlow programs: If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). -If you are new to machine learning, we recommend the -[Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course). - -If you are experienced with machine learning but new to TensorFlow, see -@{$get_started/eager}. - +To learn more, see the [TensorFlow tutorials](../tutorials/). ## Common installation problems diff --git a/tensorflow/docs_src/install/leftnav_files b/tensorflow/docs_src/install/leftnav_files index e523e06f67aad508238ee0965f34ebe16c77bf90..59292f71218c5b6eee7b543f0b2a2eaf849a4246 100644 --- a/tensorflow/docs_src/install/leftnav_files +++ b/tensorflow/docs_src/install/leftnav_files @@ -4,7 +4,9 @@ index.md install_linux.md: Ubuntu install_mac.md: MacOS install_windows.md: Windows +install_raspbian.md: Raspbian install_sources.md: From source +install_sources_windows.md: From source on Windows >>> migration.md diff --git a/tensorflow/docs_src/install/migration.md b/tensorflow/docs_src/install/migration.md index d6c31f96bd624f03f0b868a030383851c4e48ef7..19315ace2d76b63da0370cb811729934c801cf11 100644 --- a/tensorflow/docs_src/install/migration.md +++ b/tensorflow/docs_src/install/migration.md @@ -1,5 +1,4 @@ - -# Transitioning to TensorFlow 1.0 +# Transition to TensorFlow 1.0 The APIs in TensorFlow 1.0 have changed in ways that are not all backwards diff --git a/tensorflow/docs_src/javascript/index.md b/tensorflow/docs_src/javascript/index.md deleted file mode 100644 index ad63eeb255d870064567a0de8a28815ce2ae0172..0000000000000000000000000000000000000000 --- a/tensorflow/docs_src/javascript/index.md +++ /dev/null @@ -1,5 +0,0 @@ -# JavaScript - -You may develop TensorFlow programs in JavaScript, training and deploying -models right in your browser. For details, see -[js.tensorflow.org](https://js.tensorflow.org). diff --git a/tensorflow/docs_src/javascript/leftnav_files b/tensorflow/docs_src/javascript/leftnav_files deleted file mode 100644 index fc0ab8a5435943f6442969ec5787305b98c7908b..0000000000000000000000000000000000000000 --- a/tensorflow/docs_src/javascript/leftnav_files +++ /dev/null @@ -1 +0,0 @@ -index.md diff --git a/tensorflow/docs_src/mobile/README.md b/tensorflow/docs_src/mobile/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ecf42672654ab4a8d2ea8c9bb4752ed65d6c8a9a --- /dev/null +++ b/tensorflow/docs_src/mobile/README.md @@ -0,0 +1,3 @@ +# TF Lite subsite + +This subsite directory lives in [tensorflow/contrib/lite/g3doc](../../contrib/lite/g3doc/). diff --git a/tensorflow/docs_src/mobile/index.md b/tensorflow/docs_src/mobile/index.md deleted file mode 100644 index 419ae7094a180fb166eb5b00cc382773b95b91f4..0000000000000000000000000000000000000000 --- a/tensorflow/docs_src/mobile/index.md +++ /dev/null @@ -1,36 +0,0 @@ -# Overview - -TensorFlow was designed to be a good deep learning solution for mobile -platforms. Currently we have two solutions for deploying machine learning -applications on mobile and embedded devices: -@{$mobile/mobile_intro$TensorFlow for Mobile} and @{$mobile/tflite$TensorFlow Lite}. - -## TensorFlow Lite versus TensorFlow Mobile - -Here are a few of the differences between the two: - -- TensorFlow Lite is an evolution of TensorFlow Mobile. In most cases, apps - developed with TensorFlow Lite will have a smaller binary size, fewer - dependencies, and better performance. - -- TensorFlow Lite is in developer preview, so not all use cases are covered yet. - We expect you to use TensorFlow Mobile to cover production cases. - -- TensorFlow Lite supports only a limited set of operators, so not all models - will work on it by default. TensorFlow for Mobile has a fuller set of - supported functionality. - -TensorFlow Lite provides better performance and a small binary size on mobile -platforms as well as the ability to leverage hardware acceleration if available -on their platforms. In addition, it has many fewer dependencies so it can be -built and hosted on simpler, more constrained device scenarios. TensorFlow Lite -also allows targeting accelerators through the [Neural Networks -API](https://developer.android.com/ndk/guides/neuralnetworks/index.html). - -TensorFlow Lite currently has coverage for a limited set of operators. While -TensorFlow for Mobile supports only a constrained set of ops by default, in -principle if you use an arbitrary operator in TensorFlow, it can be customized -to build that kernel. Thus use cases which are not currently supported by -TensorFlow Lite should continue to use TensorFlow for Mobile. As TensorFlow Lite -evolves, it will gain additional operators, and the decision will be easier to -make. diff --git a/tensorflow/docs_src/mobile/leftnav_files b/tensorflow/docs_src/mobile/leftnav_files deleted file mode 100644 index 585470d5f0847716863ba6129bf75c26631fecbd..0000000000000000000000000000000000000000 --- a/tensorflow/docs_src/mobile/leftnav_files +++ /dev/null @@ -1,14 +0,0 @@ -index.md -### TensorFlow Lite -tflite/index.md -tflite/devguide.md -tflite/demo_android.md -tflite/demo_ios.md ->>> -### TensorFlow Mobile -mobile_intro.md -android_build.md -ios_build.md -linking_libs.md -prepare_models.md -optimizing.md diff --git a/tensorflow/docs_src/performance/datasets_performance.md b/tensorflow/docs_src/performance/datasets_performance.md index 46b43b7673c561679e89fff0ae738b0e751fcff5..5d9e4ba392558b6a621808961102e5958e2cbe74 100644 --- a/tensorflow/docs_src/performance/datasets_performance.md +++ b/tensorflow/docs_src/performance/datasets_performance.md @@ -38,9 +38,9 @@ the heavy lifting of training your model. In addition, viewing input pipelines as an ETL process provides structure that facilitates the application of performance optimizations. -When using the @{tf.estimator.Estimator} API, the first two phases (Extract and +When using the `tf.estimator.Estimator` API, the first two phases (Extract and Transform) are captured in the `input_fn` passed to -@{tf.estimator.Estimator.train}. In code, this might look like the following +`tf.estimator.Estimator.train`. In code, this might look like the following (naive, sequential) implementation: ``` @@ -99,7 +99,7 @@ With pipelining, idle time diminishes significantly: ![with pipelining](/images/datasets_with_pipelining.png) The `tf.data` API provides a software pipelining mechanism through the -@{tf.data.Dataset.prefetch} transformation, which can be used to decouple the +`tf.data.Dataset.prefetch` transformation, which can be used to decouple the time data is produced from the time it is consumed. In particular, the transformation uses a background thread and an internal buffer to prefetch elements from the input dataset ahead of the time they are requested. Thus, to @@ -130,7 +130,7 @@ The preceding recommendation is simply the most common application. ### Parallelize Data Transformation When preparing a batch, input elements may need to be pre-processed. To this -end, the `tf.data` API offers the @{tf.data.Dataset.map} transformation, which +end, the `tf.data` API offers the `tf.data.Dataset.map` transformation, which applies a user-defined function (for example, `parse_fn` from the running example) to each element of the input dataset. Because input elements are independent of one another, the pre-processing can be parallelized across @@ -164,7 +164,7 @@ dataset = dataset.map(map_func=parse_fn, num_parallel_calls=FLAGS.num_parallel_c Furthermore, if your batch size is in the hundreds or thousands, your pipeline will likely additionally benefit from parallelizing the batch creation. To this -end, the `tf.data` API provides the @{tf.contrib.data.map_and_batch} +end, the `tf.data` API provides the `tf.contrib.data.map_and_batch` transformation, which effectively "fuses" the map and batch transformations. To apply this change to our running example, change: @@ -205,7 +205,7 @@ is stored locally or remotely, but can be worse in the remote case if data is not prefetched effectively. To mitigate the impact of the various data extraction overheads, the `tf.data` -API offers the @{tf.contrib.data.parallel_interleave} transformation. Use this +API offers the `tf.contrib.data.parallel_interleave` transformation. Use this transformation to parallelize the execution of and interleave the contents of other datasets (such as data file readers). The number of datasets to overlap can be specified by the `cycle_length` argument. @@ -232,7 +232,7 @@ dataset = files.apply(tf.contrib.data.parallel_interleave( The throughput of remote storage systems can vary over time due to load or network events. To account for this variance, the `parallel_interleave` transformation can optionally use prefetching. (See -@{tf.contrib.data.parallel_interleave} for details). +`tf.contrib.data.parallel_interleave` for details). By default, the `parallel_interleave` transformation provides a deterministic ordering of elements to aid reproducibility. As an alternative to prefetching @@ -261,7 +261,7 @@ function (that is, have it operate over a batch of inputs at once) and apply the ### Map and Cache -The @{tf.data.Dataset.cache} transformation can cache a dataset, either in +The `tf.data.Dataset.cache` transformation can cache a dataset, either in memory or on local storage. If the user-defined function passed into the `map` transformation is expensive, apply the cache transformation after the map transformation as long as the resulting dataset can still fit into memory or @@ -281,9 +281,9 @@ performance (for example, to enable fusing of the map and batch transformations) ### Repeat and Shuffle -The @{tf.data.Dataset.repeat} transformation repeats the input data a finite (or +The `tf.data.Dataset.repeat` transformation repeats the input data a finite (or infinite) number of times; each repetition of the data is typically referred to -as an _epoch_. The @{tf.data.Dataset.shuffle} transformation randomizes the +as an _epoch_. The `tf.data.Dataset.shuffle` transformation randomizes the order of the dataset's examples. If the `repeat` transformation is applied before the `shuffle` transformation, @@ -296,7 +296,7 @@ internal state of the `shuffle` transformation. In other words, the former (`shuffle` before `repeat`) provides stronger ordering guarantees. When possible, we recommend using the fused -@{tf.contrib.data.shuffle_and_repeat} transformation, which combines the best of +`tf.contrib.data.shuffle_and_repeat` transformation, which combines the best of both worlds (good performance and strong ordering guarantees). Otherwise, we recommend shuffling before repeating. diff --git a/tensorflow/docs_src/performance/performance_guide.md b/tensorflow/docs_src/performance/performance_guide.md index cb0f5ca9242098d06aa0a9898e4a3774fab527b8..df703095688097123d0c46bdbfcf0c0f92457871 100644 --- a/tensorflow/docs_src/performance/performance_guide.md +++ b/tensorflow/docs_src/performance/performance_guide.md @@ -94,7 +94,7 @@ sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) #### Fused decode and crop If inputs are JPEG images that also require cropping, use fused -@{tf.image.decode_and_crop_jpeg} to speed up preprocessing. +`tf.image.decode_and_crop_jpeg` to speed up preprocessing. `tf.image.decode_and_crop_jpeg` only decodes the part of the image within the crop window. This significantly speeds up the process if the crop window is much smaller than the full image. For imagenet data, this @@ -187,14 +187,14 @@ some models makes up a large percentage of the operation time. Using fused batch norm can result in a 12%-30% speedup. There are two commonly used batch norms and both support fusing. The core -@{tf.layers.batch_normalization} added fused starting in TensorFlow 1.3. +`tf.layers.batch_normalization` added fused starting in TensorFlow 1.3. ```python bn = tf.layers.batch_normalization( input_layer, fused=True, data_format='NCHW') ``` -The contrib @{tf.contrib.layers.batch_norm} method has had fused as an option +The contrib `tf.contrib.layers.batch_norm` method has had fused as an option since before TensorFlow 1.0. ```python @@ -205,43 +205,43 @@ bn = tf.contrib.layers.batch_norm(input_layer, fused=True, data_format='NCHW') There are many ways to specify an RNN computation in TensorFlow and they have trade-offs with respect to model flexibility and performance. The -@{tf.nn.rnn_cell.BasicLSTMCell} should be considered a reference implementation +`tf.nn.rnn_cell.BasicLSTMCell` should be considered a reference implementation and used only as a last resort when no other options will work. When using one of the cells, rather than the fully fused RNN layers, you have a -choice of whether to use @{tf.nn.static_rnn} or @{tf.nn.dynamic_rnn}. There +choice of whether to use `tf.nn.static_rnn` or `tf.nn.dynamic_rnn`. There shouldn't generally be a performance difference at runtime, but large unroll -amounts can increase the graph size of the @{tf.nn.static_rnn} and cause long -compile times. An additional advantage of @{tf.nn.dynamic_rnn} is that it can +amounts can increase the graph size of the `tf.nn.static_rnn` and cause long +compile times. An additional advantage of `tf.nn.dynamic_rnn` is that it can optionally swap memory from the GPU to the CPU to enable training of very long sequences. Depending on the model and hardware configuration, this can come at a performance cost. It is also possible to run multiple iterations of -@{tf.nn.dynamic_rnn} and the underlying @{tf.while_loop} construct in parallel, +`tf.nn.dynamic_rnn` and the underlying `tf.while_loop` construct in parallel, although this is rarely useful with RNN models as they are inherently sequential. -On NVIDIA GPUs, the use of @{tf.contrib.cudnn_rnn} should always be preferred +On NVIDIA GPUs, the use of `tf.contrib.cudnn_rnn` should always be preferred unless you want layer normalization, which it doesn't support. It is often at -least an order of magnitude faster than @{tf.contrib.rnn.BasicLSTMCell} and -@{tf.contrib.rnn.LSTMBlockCell} and uses 3-4x less memory than -@{tf.contrib.rnn.BasicLSTMCell}. +least an order of magnitude faster than `tf.contrib.rnn.BasicLSTMCell` and +`tf.contrib.rnn.LSTMBlockCell` and uses 3-4x less memory than +`tf.contrib.rnn.BasicLSTMCell`. If you need to run one step of the RNN at a time, as might be the case in reinforcement learning with a recurrent policy, then you should use the -@{tf.contrib.rnn.LSTMBlockCell} with your own environment interaction loop -inside a @{tf.while_loop} construct. Running one step of the RNN at a time and +`tf.contrib.rnn.LSTMBlockCell` with your own environment interaction loop +inside a `tf.while_loop` construct. Running one step of the RNN at a time and returning to Python is possible, but it will be slower. -On CPUs, mobile devices, and if @{tf.contrib.cudnn_rnn} is not available on +On CPUs, mobile devices, and if `tf.contrib.cudnn_rnn` is not available on your GPU, the fastest and most memory efficient option is -@{tf.contrib.rnn.LSTMBlockFusedCell}. +`tf.contrib.rnn.LSTMBlockFusedCell`. -For all of the less common cell types like @{tf.contrib.rnn.NASCell}, -@{tf.contrib.rnn.PhasedLSTMCell}, @{tf.contrib.rnn.UGRNNCell}, -@{tf.contrib.rnn.GLSTMCell}, @{tf.contrib.rnn.Conv1DLSTMCell}, -@{tf.contrib.rnn.Conv2DLSTMCell}, @{tf.contrib.rnn.LayerNormBasicLSTMCell}, +For all of the less common cell types like `tf.contrib.rnn.NASCell`, +`tf.contrib.rnn.PhasedLSTMCell`, `tf.contrib.rnn.UGRNNCell`, +`tf.contrib.rnn.GLSTMCell`, `tf.contrib.rnn.Conv1DLSTMCell`, +`tf.contrib.rnn.Conv2DLSTMCell`, `tf.contrib.rnn.LayerNormBasicLSTMCell`, etc., one should be aware that they are implemented in the graph like -@{tf.contrib.rnn.BasicLSTMCell} and as such will suffer from the same poor +`tf.contrib.rnn.BasicLSTMCell` and as such will suffer from the same poor performance and high memory usage. One should consider whether or not those trade-offs are worth it before using these cells. For example, while layer normalization can speed up convergence, because cuDNN is 20x faster the fastest @@ -464,7 +464,7 @@ equal to the number of physical cores rather than logical cores. config = tf.ConfigProto() config.intra_op_parallelism_threads = 44 config.inter_op_parallelism_threads = 44 - tf.session(config=config) + tf.Session(config=config) ``` diff --git a/tensorflow/docs_src/performance/performance_models.md b/tensorflow/docs_src/performance/performance_models.md index 359b0e904dba1aea92f30604ff3b8abb81d432b1..66bf684d5b195a0e303aeaa2534c293777b4709c 100644 --- a/tensorflow/docs_src/performance/performance_models.md +++ b/tensorflow/docs_src/performance/performance_models.md @@ -10,8 +10,8 @@ incorporated into high-level APIs. ## Input Pipeline The @{$performance_guide$Performance Guide} explains how to identify possible -input pipeline issues and best practices. We found that using @{tf.FIFOQueue} -and @{tf.train.queue_runner} could not saturate multiple current generation GPUs +input pipeline issues and best practices. We found that using `tf.FIFOQueue` +and `tf.train.queue_runner` could not saturate multiple current generation GPUs when using large inputs and processing with higher samples per second, such as training ImageNet with [AlexNet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf). This is due to the use of Python threads as its underlying implementation. The @@ -29,7 +29,7 @@ implementation is made up of 3 stages: The dominant part of each stage is executed in parallel with the other stages using `data_flow_ops.StagingArea`. `StagingArea` is a queue-like operator -similar to @{tf.FIFOQueue}. The difference is that `StagingArea` does not +similar to `tf.FIFOQueue`. The difference is that `StagingArea` does not guarantee FIFO ordering, but offers simpler functionality and can be executed on both CPU and GPU in parallel with other stages. Breaking the input pipeline into 3 stages that operate independently in parallel is scalable and takes full @@ -62,10 +62,10 @@ and executed in parallel. The image preprocessing ops include operations such as image decoding, distortion, and resizing. Once the images are through preprocessing, they are concatenated together into 8 -tensors each with a batch-size of 32. Rather than using @{tf.concat} for this +tensors each with a batch-size of 32. Rather than using `tf.concat` for this purpose, which is implemented as a single op that waits for all the inputs to be -ready before concatenating them together, @{tf.parallel_stack} is used. -@{tf.parallel_stack} allocates an uninitialized tensor as an output, and each +ready before concatenating them together, `tf.parallel_stack` is used. +`tf.parallel_stack` allocates an uninitialized tensor as an output, and each input tensor is written to its designated portion of the output tensor as soon as the input is available. @@ -94,7 +94,7 @@ the GPU, all the tensors are already available. With all the stages capable of being driven by different processors, `data_flow_ops.StagingArea` is used between them so they run in parallel. -`StagingArea` is a queue-like operator similar to @{tf.FIFOQueue} that offers +`StagingArea` is a queue-like operator similar to `tf.FIFOQueue` that offers simpler functionalities that can be executed on both CPU and GPU. Before the model starts running all the stages, the input pipeline stages are @@ -153,7 +153,7 @@ weights obtained from training. The default batch-normalization in TensorFlow is implemented as composite operations. This is very general, but often leads to suboptimal performance. An alternative is to use fused batch-normalization which often has much better -performance on GPU. Below is an example of using @{tf.contrib.layers.batch_norm} +performance on GPU. Below is an example of using `tf.contrib.layers.batch_norm` to implement fused batch-normalization. ```python @@ -301,7 +301,7 @@ In order to broadcast variables and aggregate gradients across different GPUs within the same host machine, we can use the default TensorFlow implicit copy mechanism. -However, we can instead use the optional NCCL (@{tf.contrib.nccl}) support. NCCL +However, we can instead use the optional NCCL (`tf.contrib.nccl`) support. NCCL is an NVIDIA® library that can efficiently broadcast and aggregate data across different GPUs. It schedules a cooperating kernel on each GPU that knows how to best utilize the underlying hardware topology; this kernel uses a single SM of diff --git a/tensorflow/docs_src/performance/quantization.md b/tensorflow/docs_src/performance/quantization.md index c97f74139c6ee852bf29724a3ac335d349a73fd3..4499f5715cd58ff846d49f3ed4560ded319c883c 100644 --- a/tensorflow/docs_src/performance/quantization.md +++ b/tensorflow/docs_src/performance/quantization.md @@ -163,7 +163,7 @@ bazel build tensorflow/contrib/lite/toco:toco && \ --std_value=127.5 --mean_value=127.5 ``` -See the documentation for @{tf.contrib.quantize} and +See the documentation for `tf.contrib.quantize` and [TensorFlow Lite](/mobile/tflite/). ## Quantized accuracy diff --git a/tensorflow/docs_src/performance/xla/broadcasting.md b/tensorflow/docs_src/performance/xla/broadcasting.md index eaa709c2f84245341044b93060f932a22fbe54c7..7018ded53f8bc078a43b6af54a9ba13796374458 100644 --- a/tensorflow/docs_src/performance/xla/broadcasting.md +++ b/tensorflow/docs_src/performance/xla/broadcasting.md @@ -99,7 +99,7 @@ dimensions 1 and 2 of the cuboid. This type of broadcast is used in the binary ops in `XlaBuilder`, if the `broadcast_dimensions` argument is given. For example, see -[XlaBuilder::Add](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.cc). +[XlaBuilder::Add](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.cc). In the XLA source code, this type of broadcasting is sometimes called "InDim" broadcasting. diff --git a/tensorflow/docs_src/performance/xla/developing_new_backend.md b/tensorflow/docs_src/performance/xla/developing_new_backend.md index 74ea15bb2bac2014257f0b1719820f7ee313b66b..840f6983c2837771acbd79b221efcb5537ae4d7d 100644 --- a/tensorflow/docs_src/performance/xla/developing_new_backend.md +++ b/tensorflow/docs_src/performance/xla/developing_new_backend.md @@ -44,7 +44,7 @@ It is possible to model a new implementation on the existing [`xla::CPUCompiler`] (https://www.tensorflow.org/code/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc) and [`xla::GPUCompiler`] -(https://www.tensorflow.org/code/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc) +(https://www.tensorflow.org/code/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc) classes, since these already emit LLVM IR. Depending on the nature of the hardware, it is possible that many of the LLVM IR generation aspects will have to be changed, but a lot of code can be shared with the existing backends. diff --git a/tensorflow/docs_src/performance/xla/jit.md b/tensorflow/docs_src/performance/xla/jit.md index 6724d1eaf8f85320b963eddc37947d69dcaa8471..7202ef47f7ae94ca37811f7fab208860410299f0 100644 --- a/tensorflow/docs_src/performance/xla/jit.md +++ b/tensorflow/docs_src/performance/xla/jit.md @@ -19,10 +19,11 @@ on the `XLA_CPU` or `XLA_GPU` TensorFlow devices. Placing operators directly on a TensorFlow XLA device forces the operator to run on that device and is mainly used for testing. -> Note: The XLA CPU backend produces fast single-threaded code (in most cases), -> but does not yet parallelize as well as the TensorFlow CPU backend. The XLA -> GPU backend is competitive with the standard TensorFlow implementation, -> sometimes faster, sometimes slower. +> Note: The XLA CPU backend supports intra-op parallelism (i.e. it can shard a +> single operation across multiple cores) but it does not support inter-op +> parallelism (i.e. it cannot execute independent operations concurrently across +> multiple cores). The XLA GPU backend is competitive with the standard +> TensorFlow implementation, sometimes faster, sometimes slower. ### Turning on JIT compilation @@ -55,8 +56,7 @@ sess = tf.Session(config=config) > Note: Turning on JIT at the session level will not result in operations being > compiled for the CPU. JIT compilation for CPU operations must be done via -> the manual method documented below. This decision was made due to the CPU -> backend being single-threaded. +> the manual method documented below. #### Manual diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index f7e116bf0f85fe94b2167eca8b623207432b38e9..e24a7cda733febd98f0cf7af1c86893d9a8f91dc 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -1,7 +1,7 @@ # Operation Semantics The following describes the semantics of operations defined in the -[`XlaBuilder`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h) +[`XlaBuilder`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h) interface. Typically, these operations map one-to-one to operations defined in the RPC interface in [`xla_data.proto`](https://www.tensorflow.org/code/tensorflow/compiler/xla/xla_data.proto). @@ -13,10 +13,83 @@ arbitrary-dimensional array. For convenience, special cases have more specific and familiar names; for example a *vector* is a 1-dimensional array and a *matrix* is a 2-dimensional array. +## AllToAll + +See also +[`XlaBuilder::AllToAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +Alltoall is a collective operation that sends data from all cores to all cores. +It has two phases: + +1. the scatter phase. On each core, the operand is split into `split_count` + number of blocks along the `split_dimensions`, and the blocks are scattered + to all cores, e.g., the ith block is send to the ith core. +2. the gather phase. Each core concatenates the received blocks along the + `concat_dimension`. + +The participating cores can be configured by: + +- `replica_groups`: each ReplicaGroup contains a list of replica id. If empty, + all replicas belong to one group in the order of 0 - (n-1). Alltoall will be + applied within subgroups in the specified order. For example, replica + groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied within replica + 1, 2, 3, and in the gather phase, the received blocks will be concatenated + in the order of 1, 2, 3; another Alltoall will be applied within replica 4, + 5, 0, and the concatenation order is 4, 5, 0. + +Prerequisites: + +- The dimension size of the operand on the split_dimension is divisible by + split_count. +- The operand's shape is not tuple. + + `AllToAll(operand, split_dimension, concat_dimension, split_count, +replica_groups)` + + +| Arguments | Type | Semantics | +| ------------------ | --------------------- | ------------------------------- | +| `operand` | `XlaOp` | n dimensional input array | +| `split_dimension` | `int64` | A value in the interval `[0, | +: : : n)` that names the dimension : +: : : along which the operand is : +: : : split : +| `concat_dimension` | `int64` | a value in the interval `[0, | +: : : n)` that names the dimension : +: : : along which the split blocks : +: : : are concatenated : +| `split_count` | `int64` | the number of cores that | +: : : participate this operation. If : +: : : `replica_groups` is empty, this : +: : : should be the number of : +: : : replicas; otherwise, this : +: : : should be equal to the number : +: : : of replicas in each group. : +| `replica_groups` | `ReplicaGroup` vector | each group contains a list of | +: : : replica id. : + +Below shows an example of Alltoall. + +``` +XlaBuilder b("alltoall"); +auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x"); +AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4); +``` + +
+ +
+ +In this example, there are 4 cores participating the Alltoall. On each core, the +operand is split into 4 parts along dimension 0, so each part has shape +f32[4,4]. The 4 parts are scattered to all cores. Then each core concatenates +the received parts along dimension 1, in the order or core 0-4. So the output on +each core has shape f32[16,4]. + ## BatchNormGrad See also -[`XlaBuilder::BatchNormGrad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h) +[`XlaBuilder::BatchNormGrad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h) and [the original batch normalization paper](https://arxiv.org/abs/1502.03167) for a detailed description of the algorithm. @@ -80,7 +153,7 @@ The output type is a tuple of three handles: ## BatchNormInference See also -[`XlaBuilder::BatchNormInference`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h) +[`XlaBuilder::BatchNormInference`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h) and [the original batch normalization paper](https://arxiv.org/abs/1502.03167) for a detailed description of the algorithm. @@ -115,7 +188,7 @@ The output is an n-dimensional, normalized array with the same shape as input ## BatchNormTraining See also -[`XlaBuilder::BatchNormTraining`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h) +[`XlaBuilder::BatchNormTraining`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h) and [`the original batch normalization paper`](https://arxiv.org/abs/1502.03167) for a detailed description of the algorithm. @@ -167,7 +240,7 @@ spatial dimensions using the formulas above. ## BitcastConvertType See also -[`XlaBuilder::BitcastConvertType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::BitcastConvertType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). Similar to a `tf.bitcast` in TensorFlow, performs an element-wise bitcast operation from a data shape to a target shape. The dimensions must match, and @@ -189,7 +262,7 @@ and destination element types must not be tuples. ## Broadcast See also -[`XlaBuilder::Broadcast`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Broadcast`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). Adds dimensions to an array by duplicating the data in the array. @@ -217,7 +290,7 @@ For example, if `operand` is a scalar `f32` with value `2.0f`, and ## Call See also -[`XlaBuilder::Call`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Call`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). Invokes a computation with the given arguments. @@ -236,7 +309,7 @@ The arity and types of the `args` must match the parameters of the ## Clamp See also -[`XlaBuilder::Clamp`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Clamp`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). Clamps an operand to within the range between a minimum and maximum value. @@ -269,8 +342,8 @@ Clamp(min, operand, max) = s32[3]{0, 5, 6}; ## Collapse See also -[`XlaBuilder::Collapse`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h) -and the @{tf.reshape} operation. +[`XlaBuilder::Collapse`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h) +and the `tf.reshape` operation. Collapses dimensions of an array into one dimension. @@ -291,7 +364,7 @@ same position in the dimension sequence as those they replace, with the new dimension size equal to the product of original dimension sizes. The lowest dimension number in `dimensions` is the slowest varying dimension (most major) in the loop nest which collapses these dimension, and the highest dimension -number is fastest varying (most minor). See the @{tf.reshape} operator +number is fastest varying (most minor). See the `tf.reshape` operator if more general collapse ordering is needed. For example, let v be an array of 24 elements: @@ -332,7 +405,7 @@ then v12 == f32[8x3] {{10, 11, 12}, ## Concatenate See also -[`XlaBuilder::ConcatInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::ConcatInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). Concatenate composes an array from multiple array operands. The array is of the same rank as each of the input array operands (which must be of the same rank as @@ -388,7 +461,7 @@ Diagram: ## Conditional See also -[`XlaBuilder::Conditional`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Conditional`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). `Conditional(pred, true_operand, true_computation, false_operand, false_computation)` @@ -416,7 +489,7 @@ executed depending on the value of `pred`. ## Conv (convolution) See also -[`XlaBuilder::Conv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Conv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). As ConvWithGeneralPadding, but the padding is specified in a short-hand way as either SAME or VALID. SAME padding pads the input (`lhs`) with zeroes so that @@ -426,7 +499,7 @@ account. VALID padding simply means no padding. ## ConvWithGeneralPadding (convolution) See also -[`XlaBuilder::ConvWithGeneralPadding`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::ConvWithGeneralPadding`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). Computes a convolution of the kind used in neural networks. Here, a convolution can be thought of as a n-dimensional window moving across a n-dimensional base @@ -490,8 +563,8 @@ array. The holes are filled with a no-op value, which for convolution means zeroes. Dilation of the rhs is also called atrous convolution. For more details, see -@{tf.nn.atrous_conv2d}. Dilation of the lhs is also called transposed -convolution. For more details, see @{tf.nn.conv2d_transpose}. +`tf.nn.atrous_conv2d`. Dilation of the lhs is also called transposed +convolution. For more details, see `tf.nn.conv2d_transpose`. The output shape has these dimensions, in this order: @@ -538,7 +611,7 @@ for (b, oz, oy, ox) { // output coordinates ## ConvertElementType See also -[`XlaBuilder::ConvertElementType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::ConvertElementType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). Similar to an element-wise `static_cast` in C++, performs an element-wise conversion operation from a data shape to a target shape. The dimensions must @@ -572,7 +645,7 @@ then b == f32[3]{0.0, 1.0, 2.0} ## CrossReplicaSum See also -[`XlaBuilder::CrossReplicaSum`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::CrossReplicaSum`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). Computes a sum across replicas. @@ -607,7 +680,7 @@ than another. ## CustomCall See also -[`XlaBuilder::CustomCall`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::CustomCall`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). Call a user-provided function within a computation. @@ -668,7 +741,7 @@ idempotent. ## Dot See also -[`XlaBuilder::Dot`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Dot`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). `Dot(lhs, rhs)` @@ -697,7 +770,7 @@ multiplications or matrix/matrix multiplications. ## DotGeneral See also -[`XlaBuilder::DotGeneral`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::DotGeneral`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). `DotGeneral(lhs, rhs, dimension_numbers)` @@ -784,15 +857,13 @@ non-contracting/non-batch dimension. ## DynamicSlice See also -[`XlaBuilder::DynamicSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::DynamicSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). DynamicSlice extracts a sub-array from the input array at dynamic `start_indices`. The size of the slice in each dimension is passed in `size_indices`, which specify the end point of exclusive slice intervals in each dimension: [start, start + size). The shape of `start_indices` must be rank == 1, with dimension size equal to the rank of `operand`. -Note: handling of out-of-bounds slice indices (generated by incorrect runtime -calculation of 'start_indices') is currently implementation-defined. `DynamicSlice(operand, start_indices, size_indices)` @@ -812,6 +883,17 @@ calculation of 'start_indices') is currently implementation-defined. : : : dimension to avoid wrapping modulo : : : : dimension size. : +The effective slice indices are computed by applying the following +transformation for each index `i` in `[1, N)` before performing the slice: + +``` +start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i]) +``` + +This ensures that the extracted slice is always in-bounds with respect to the +operand array. If the slice is in-bounds before the transformation is applied, +the transformation has no effect. + 1-dimensional example: ``` @@ -839,7 +921,7 @@ DynamicSlice(b, s, {2, 2}) produces: ## DynamicUpdateSlice See also -[`XlaBuilder::DynamicUpdateSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::DynamicUpdateSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). DynamicUpdateSlice generates a result which is the value of the input array `operand`, with a slice `update` overwritten at `start_indices`. @@ -847,8 +929,6 @@ The shape of `update` determines the shape of the sub-array of the result which is updated. The shape of `start_indices` must be rank == 1, with dimension size equal to the rank of `operand`. -Note: handling of out-of-bounds slice indices (generated by incorrect runtime -calculation of 'start_indices') is currently implementation-defined. `DynamicUpdateSlice(operand, update, start_indices)` @@ -866,6 +946,17 @@ calculation of 'start_indices') is currently implementation-defined. : : : dimension. Value must be greater than or equal : : : : to zero. : +The effective slice indices are computed by applying the following +transformation for each index `i` in `[1, N)` before performing the slice: + +``` +start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i]) +``` + +This ensures that the updated slice is always in-bounds with respect to the +operand array. If the slice is in-bounds before the transformation is applied, +the transformation has no effect. + 1-dimensional example: ``` @@ -902,7 +993,7 @@ DynamicUpdateSlice(b, u, s) produces: ## Element-wise binary arithmetic operations See also -[`XlaBuilder::Add`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Add`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). A set of element-wise binary arithmetic operations is supported. @@ -947,7 +1038,7 @@ shapes of both operands. The semantics are described in detail on the ## Element-wise comparison operations See also -[`XlaBuilder::Eq`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Eq`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). A set of standard element-wise binary comparison operations is supported. Note that standard IEEE 754 floating-point comparison semantics apply when comparing @@ -1033,7 +1124,7 @@ potentially different runtime offset) of an input tensor into an output tensor. ### General Semantics See also -[`XlaBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). For a more intuitive description, see the "Informal Description" section below. `gather(operand, gather_indices, output_window_dims, elided_window_dims, window_bounds, gather_dims_to_operand_dims)` @@ -1236,7 +1327,7 @@ concatenation of all these rows. ## GetTupleElement See also -[`XlaBuilder::GetTupleElement`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::GetTupleElement`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). Indexes into a tuple with a compile-time-constant value. @@ -1252,12 +1343,12 @@ let t: (f32[10], s32) = tuple(v, s); let element_1: s32 = gettupleelement(t, 1); // Inferred shape matches s32. ``` -See also @{tf.tuple}. +See also `tf.tuple`. ## Infeed See also -[`XlaBuilder::Infeed`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Infeed`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). `Infeed(shape)` @@ -1293,27 +1384,38 @@ Infeed of the device. > which case the compiler will provide information about how the Infeed > operations are serialized in the compiled program. +## Iota + + `Iota()` + +Builds a constant literal on device rather than a potentially large host +transfer. Creates a rank 1 tensor of values starting at zero and incrementing +by one. + +Arguments | Type | Semantics +------------------ | --------------- | --------------------------- +`type` | `PrimitiveType` | type U +`size` | `int64` | The number of elements in the tensor. + ## Map See also -[`XlaBuilder::Map`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Map`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). `Map(operands..., computation)` | Arguments | Type | Semantics | | ----------------- | ---------------------- | ------------------------------ | | `operands` | sequence of N `XlaOp`s | N arrays of types T_0..T_{N-1} | -| `computation` | `XlaComputation` | computation of type `T_0, T_1, | +| `computation` | `XlaComputation` | computation of type `T_0, T_1, | : : : ..., T_{N + M -1} -> S` with N : : : : parameters of type T and M of : : : : arbitrary type : | `dimensions` | `int64` array | array of map dimensions | -| `static_operands` | sequence of M `XlaOp`s | M arrays of arbitrary type | Applies a scalar function over the given `operands` arrays, producing an array of the same dimensions where each element is the result of the mapped function -applied to the corresponding elements in the input arrays with `static_operands` -given as additional input to `computation`. +applied to the corresponding elements in the input arrays. The mapped function is an arbitrary computation with the restriction that it has N inputs of scalar type `T` and a single output with type `S`. The output has @@ -1327,7 +1429,7 @@ input arrays to produce the output array. ## Pad See also -[`XlaBuilder::Pad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Pad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). `Pad(operand, padding_value, padding_config)` @@ -1366,7 +1468,7 @@ are all 0. The figure below shows examples of different `edge_padding` and ## Recv See also -[`XlaBuilder::Recv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Recv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). `Recv(shape, channel_handle)` @@ -1400,21 +1502,31 @@ complete and returns the received data. ## Reduce See also -[`XlaBuilder::Reduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Reduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +Applies a reduction function to one or more arrays in parallel. -Applies a reduction function to an array. + `Reduce(operands..., init_values..., computation, dimensions)` - `Reduce(operand, init_value, computation, dimensions)` +Arguments | Type | Semantics +------------- | --------------------- | --------------------------------------- +`operands` | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_N`. +`init_values` | Sequence of N `XlaOp` | N scalars of types `T_0, ..., T_N`. +`computation` | `XlaComputation` | computation of type + : : `T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N)` +`dimensions` | `int64` array | unordered array of dimensions to reduce -Arguments | Type | Semantics -------------- | ---------------- | --------------------------------------- -`operand` | `XlaOp` | array of type `T` -`init_value` | `XlaOp` | scalar of type `T` -`computation` | `XlaComputation` | computation of type `T, T -> T` -`dimensions` | `int64` array | unordered array of dimensions to reduce +Where: +* N is required to be greater or equal to 1. +* All input arrays must have the same dimensions. +* If `N = 1`, `Collate(T)` is `T`. +* If `N > 1`, `Collate(T_0, ..., T_N)` is a tuple of `N` elements of type `T`. -This operation reduces one or more dimensions of the input array into scalars. -The rank of the returned array is `rank(operand) - len(dimensions)`. +The output of the op is `Collate(Q_0, ..., Q_N)` where `Q_i` is an array of type +`T_i`, the dimensions of which are described below. + +This operation reduces one or more dimensions of each input array into scalars. +The rank of each returned array is `rank(operand) - len(dimensions)`. `init_value` is the initial value used for every reduction and may be inserted anywhere during computation by the back-end. In most cases, `init_value` is an identity of the reduction function (for example, 0 for addition). The applied @@ -1430,9 +1542,9 @@ enough to being associative for most practical uses. It is possible to conceive of some completely non-associative reductions, however, and these will produce incorrect or unpredictable results in XLA reductions. -As an example, when reducing across the one dimension in a 1D array with values -[10, 11, 12, 13], with reduction function `f` (this is `computation`) then that -could be computed as +As an example, when reducing across one dimension in a single 1D array with +values [10, 11, 12, 13], with reduction function `f` (this is `computation`) +then that could be computed as `f(10, f(11, f(12, f(init_value, 13)))` @@ -1514,10 +1626,38 @@ the 1D array `| 20 28 36 |`. Reducing the 3D array over all its dimensions produces the scalar `84`. +When `N > 1`, reduce function application is slightly more complex, as it is +applied simultaneously to all inputs. For example, consider the following +reduction function, which can be used to compute the max and the argmax of a +a 1-D tensor in parallel: + +``` +f: (Float, Int, Float, Int) -> Float, Int +f(max, argmax, value, index): + if value >= argmax: + return (value, index) + else: + return (max, argmax) +``` + +For 1-D Input arrays `V = Float[N], K = Int[N]`, and init values +`I_V = Float, I_K = Int`, the result `f_(N-1)` of reducing across the only +input dimension is equivalent to the following recursive application: +``` +f_0 = f(I_V, I_K, V_0, K_0) +f_1 = f(f_0.first, f_0.second, V_1, K_1) +... +f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1)) +``` + +Applying this reduction to an array of values, and an array of sequential +indices (i.e. iota), will co-iterate over the arrays, and return a tuple +containing the maximal value and the matching index. + ## ReducePrecision See also -[`XlaBuilder::ReducePrecision`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::ReducePrecision`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). Models the effect of converting floating-point values to a lower-precision format (such as IEEE-FP16) and back to the original format. The number of @@ -1548,7 +1688,7 @@ portion of the conversion is then simply a no-op. ## ReduceWindow See also -[`XlaBuilder::ReduceWindow`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::ReduceWindow`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). Applies a reduction function to all elements in each window of the input multi-dimensional array, producing an output multi-dimensional array with the @@ -1631,7 +1771,7 @@ context of [`Reduce`](#reduce) for more details. ## Reshape See also -[`XlaBuilder::Reshape`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h) +[`XlaBuilder::Reshape`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h) and the [`Collapse`](#collapse) operation. Reshapes the dimensions of an array into a new configuration. @@ -1712,7 +1852,7 @@ Reshape(5, {}, {1,1}) == f32[1x1] {{5}}; ## Rev (reverse) See also -[`XlaBuilder::Rev`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Rev`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). `Rev(operand, dimensions)` @@ -1734,33 +1874,35 @@ the two window dimensions during the gradient computation in neural networks. ## RngNormal See also -[`XlaBuilder::RngNormal`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::RngNormal`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). Constructs an output of a given shape with random numbers generated following -the $$N(\mu, \sigma)$$ normal distribution. The parameters `mu` and `sigma`, and -output shape have to have elemental type F32. The parameters furthermore have to -be scalar valued. +the $$N(\mu, \sigma)$$ normal distribution. The parameters $$\mu$$ and +$$\sigma$$, and output shape have to have a floating point elemental type. The +parameters furthermore have to be scalar valued. -`RngNormal(mean, sigma, shape)` +`RngNormal(mu, sigma, shape)` | Arguments | Type | Semantics | | --------- | ------- | --------------------------------------------------- | -| `mu` | `XlaOp` | Scalar of type F32 specifying mean of generated | -: : : numbers : -| `sigma` | `XlaOp` | Scalar of type F32 specifying standard deviation of | +| `mu` | `XlaOp` | Scalar of type T specifying mean of generated | +: : : numbers : +| `sigma` | `XlaOp` | Scalar of type T specifying standard deviation of | : : : generated numbers : -| `shape` | `Shape` | Output shape of type F32 | +| `shape` | `Shape` | Output shape of type T | ## RngUniform See also -[`XlaBuilder::RngUniform`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::RngUniform`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). Constructs an output of a given shape with random numbers generated following the uniform distribution over the interval $$[a,b)$$. The parameters and output -shape may be either F32, S32 or U32, but the types have to be consistent. -Furthermore, the parameters need to be scalar valued. If $$b <= a$$ the result -is implementation-defined. +element type have to be a boolean type, an integral type or a floating point +types, and the types have to be consistent. The CPU and GPU backends currently +only support F64, F32, F16, BF16, S64, U64, S32 and U32. Furthermore, the +parameters need to be scalar valued. If $$b <= a$$ the result is +implementation-defined. `RngUniform(a, b, shape)` @@ -1772,10 +1914,142 @@ is implementation-defined. : : : limit of interval : | `shape` | `Shape` | Output shape of type T | +## Scatter + +The XLA scatter operation generates a result which is the value of the input +tensor `operand`, with several slices (at indices specified by +`scatter_indices`) updated with the values in `updates` using +`update_computation`. + +See also +[`XlaBuilder::Scatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + + `scatter(operand, scatter_indices, updates, update_computation, index_vector_dim, update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)` + +|Arguments | Type | Semantics | +|------------------|------------------------|----------------------------------| +|`operand` | `XlaOp` | Tensor to be scattered into. | +|`scatter_indices` | `XlaOp` | Tensor containing the starting | +: : : indices of the slices that must : +: : : be scattered to. : +|`updates` | `XlaOp` | Tensor containing the values that| +: : : must be used for scattering. : +|`update_computation`| `XlaComputation` | Computation to be used for | +: : : combining the existing values in : +: : : the input tensor and the updates : +: : : during scatter. This computation : +: : : should be of type `T, T -> T`. : +|`index_vector_dim`| `int64` | The dimension in | +: : : `scatter_indices` that contains : +: : : the starting indices. : +|`update_window_dims`| `ArraySlice` | The set of dimensions in | +: : : `updates` shape that are _window : +: : : dimensions_. : +|`inserted_window_dims`| `ArraySlice`| The set of _window dimensions_ | +: : : that must be inserted into : +: : : `updates` shape. : +|`scatter_dims_to_operand_dims`| `ArraySlice` | A dimensions map from | +: : : the scatter indices to the : +: : : operand index space. This array : +: : : is interpreted as mapping `i` to : +: : : `scatter_dims_to_operand_dims[i]`: +: : : . It has to be one-to-one and : +: : : total. : + +If `index_vector_dim` is equal to `scatter_indices.rank` we implicitly consider +`scatter_indices` to have a trailing `1` dimension. + +We define `update_scatter_dims` of type `ArraySlice` as the set of +dimensions in `updates` shape that are not in `update_window_dims`, in ascending +order. + +The arguments of scatter should follow these constraints: + + - `updates` tensor must be of rank `update_window_dims.size + + scatter_indices.rank - 1`. + + - Bounds of dimension `i` in `updates` must conform to the following: + - If `i` is present in `update_window_dims` (i.e. equal to + `update_window_dims`[`k`] for some `k`), then the bound of dimension + `i` in `updates` must not exceed the corresponding bound of `operand` + after accounting for the `inserted_window_dims` (i.e. + `adjusted_window_bounds`[`k`], where `adjusted_window_bounds` contains + the bounds of `operand` with the bounds at indices + `inserted_window_dims` removed). + - If `i` is present in `update_scatter_dims` (i.e. equal to + `update_scatter_dims`[`k`] for some `k`), then the bound of dimension + `i` in `updates` must be equal to the corresponding bound of + `scatter_indices`, skipping `index_vector_dim` (i.e. + `scatter_indices.shape.dims`[`k`], if `k` < `index_vector_dim` and + `scatter_indices.shape.dims`[`k+1`] otherwise). + + - `update_window_dims` must be in ascending order, not have any repeating + dimension numbers, and be in the range `[0, updates.rank)`. + + - `inserted_window_dims` must be in ascending order, not have any + repeating dimension numbers, and be in the range `[0, operand.rank)`. + + - `scatter_dims_to_operand_dims.size` must be equal to + `scatter_indices`[`index_vector_dim`], and its values must be in the range + `[0, operand.rank)`. + +For a given index `U` in the `updates` tensor, the corresponding index `I` in +the `operand` tensor into which this update has to be applied is computed as +follows: + + 1. Let `G` = { `U`[`k`] for `k` in `update_scatter_dims` }. Use `G` to look up + an index vector `S` in the `scatter_indices` tensor such that `S`[`i`] = + `scatter_indices`[Combine(`G`, `i`)] where Combine(A, b) inserts b at + positions `index_vector_dim` into A. + 2. Create an index `S``in` into `operand` using `S` by scattering + `S` using the `scatter_dims_to_operand_dims` map. More formally: + 1. `S``in`[`scatter_dims_to_operand_dims`[`k`]] = `S`[`k`] if + `k` < `scatter_dims_to_operand_dims.size`. + 2. `S``in`[`_`] = `0` otherwise. + 3. Create an index `W``in` into `operand` by scattering the indices + at `update_window_dims` in `U` according to `inserted_window_dims`. + More formally: + 1. `W``in`[`window_dims_to_operand_dims`(`k`)] = `U`[`k`] if + `k` < `update_window_dims.size`, where `window_dims_to_operand_dims` + is the monotonic function with domain [`0`, `update_window_dims.size`) + and range [`0`, `operand.rank`) \\ `inserted_window_dims`. (For + example, if `update_window_dims.size` is `4`, `operand.rank` is `6`, + and `inserted_window_dims` is {`0`, `2`} then + `window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, + `3`→`5`}). + 2. `W``in`[`_`] = `0` otherwise. + 4. `I` is `W``in` + `S``in` where + is element-wise + addition. + +In summary, the scatter operation can be defined as follows. + + - Initialize `output` with `operand`, i.e. for all indices `O` in the + `operand` tensor:\ + `output`[`O`] = `operand`[`O`] + - For every index `U` in the `updates` tensor and the corresponding index `O` + in the `operand` tensor:\ + `output`[`O`] = `update_computation`(`output`[`O`], `updates`[`U`]) + +The order in which updates are applied is non-deterministic. So, when multiple +indices in `updates` refer to the same index in `operand`, the corresponding +value in `output` will be non-deterministic. + +Note that the first parameter that is passed into the `update_computation` will +always be the current value from the `output` tensor and the second parameter +will always be the value from the `updates` tensor. This is important +specifically for cases when the `update_computation` is _not commutative_. + +Informally, the scatter op can be viewed as an _inverse_ of the gather op, i.e. +the scatter op updates the elements in the input that are extracted by the +corresponding gather op. + +For a detailed informal description and examples, refer to the +"Informal Description" section under `Gather`. + ## Select See also -[`XlaBuilder::Select`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Select`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). Constructs an output array from elements of two input arrays, based on the values of a predicate array. @@ -1826,7 +2100,7 @@ the same shape!) then `pred` has to be a scalar of type `PRED`. ## SelectAndScatter See also -[`XlaBuilder::SelectAndScatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::SelectAndScatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). This operation can be considered as a composite operation that first computes `ReduceWindow` on the `operand` array to select an element from each window, and @@ -1906,7 +2180,7 @@ context of [`Reduce`](#reduce) for more details. ## Send See also -[`XlaBuilder::Send`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Send`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). `Send(operand, channel_handle)` @@ -1961,7 +2235,7 @@ computations. For example, below schedules lead to deadlocks. ## Slice See also -[`XlaBuilder::Slice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Slice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). Slicing extracts a sub-array from the input array. The sub-array is of the same rank as the input and contains the values inside a bounding box within the input @@ -2010,19 +2284,48 @@ Slice(b, {2, 1}, {4, 3}) produces: ## Sort See also -[`XlaBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). -Sorts the elements in the operand. +There are two versions of the Sort instruction: a single-operand and a +two-operand version. `Sort(operand)` -Arguments | Type | Semantics ---------- | ------- | ------------------- -`operand` | `XlaOp` | The operand to sort +Arguments | Type | Semantics +----------- | ------- | -------------------- +`operand` | `XlaOp` | The operand to sort. +`dimension` | `int64` | The dimension along which to sort. + +Sorts the elements in the operand in ascending order along the provided +dimension. For example, for a rank-2 (matrix) operand, a `dimension` value of 0 +will sort each column independently, and a `dimension` value of 1 will sort each +row independently. If the operand's elements have floating point type, and the +operand contains NaN elements, the order of elements in the output is +implementation-defined. + +`Sort(key, value)` + +Sorts both the key and the value operands. The keys are sorted as in the +single-operand version. The values are sorted according to the order of their +corresponding keys. For example, if the inputs are `keys = [3, 1]` and +`values = [42, 50]`, then the output of the sort is the tuple +`{[1, 3], [50, 42]}`. + +The sort is not guaranteed to be stable, that is, if the keys array contains +duplicates, the order of their corresponding values may not be preserved. + +Arguments | Type | Semantics +----------- | ------- | ------------------- +`keys` | `XlaOp` | The sort keys. +`values` | `XlaOp` | The values to sort. +`dimension` | `int64` | The dimension along which to sort. + +The `keys` and `values` must have the same dimensions, but may have different +element types. ## Transpose -See also the @{tf.reshape} operation. +See also the `tf.reshape` operation. `Transpose(operand)` @@ -2041,7 +2344,7 @@ This is the same as Reshape(operand, permutation, ## Tuple See also -[`XlaBuilder::Tuple`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::Tuple`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). A tuple containing a variable number of data handles, each of which has its own shape. @@ -2060,7 +2363,7 @@ Tuples can be deconstructed (accessed) via the [`GetTupleElement`] ## While See also -[`XlaBuilder::While`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). +[`XlaBuilder::While`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). `While(condition, body, init)` @@ -2082,8 +2385,6 @@ restrictions listed below. last execution of the `body`. * The shape of the type `T` is statically determined and must be the same across all iterations. -* `While` nodes are not allowed to be nested. (This restriction may be lifted - in the future on some targets.) The T parameters of the computations are initialized with the `init` value in the first iteration and are automatically updated to the new result from `body` diff --git a/tensorflow/docs_src/performance/xla/tfcompile.md b/tensorflow/docs_src/performance/xla/tfcompile.md index 8521d7eacb4a7fec7d187bdd1c4f452b644dc8b2..e4b803164f23038ef219f3333acdc2dc23fa86ed 100644 --- a/tensorflow/docs_src/performance/xla/tfcompile.md +++ b/tensorflow/docs_src/performance/xla/tfcompile.md @@ -205,10 +205,7 @@ representing the inputs, `results` representing the outputs, and `temps` representing temporary buffers used internally to perform the computation. By default, each instance of the generated class allocates and manages all of these buffers for you. The `AllocMode` constructor argument may be used to change this -behavior. A convenience library is provided in -[`tensorflow/compiler/aot/runtime.h`](https://www.tensorflow.org/code/tensorflow/compiler/aot/runtime.h) -to help with manual buffer allocation; usage of this library is optional. All -buffers should be aligned to 32-byte boundaries. +behavior. All buffers are aligned to 64-byte boundaries. The generated C++ class is just a wrapper around the low-level code generated by XLA. diff --git a/tensorflow/docs_src/tutorials/_index.yaml b/tensorflow/docs_src/tutorials/_index.yaml new file mode 100644 index 0000000000000000000000000000000000000000..953411468978846724b52bae73537e80694a78ee --- /dev/null +++ b/tensorflow/docs_src/tutorials/_index.yaml @@ -0,0 +1,202 @@ +project_path: /_project.yaml +book_path: /_book.yaml +description: +landing_page: + custom_css_path: /site-assets/css/style.css + show_side_navs: True + rows: + - description: > +

Get Started with TensorFlow

+

+ TensorFlow is an open-source machine learning library for research and + production. TensorFlow offers APIs for beginners and experts to develop + for desktop, mobile, web, and cloud. See the sections below to get + started. +

+ items: + - custom_html: > +
+

Learn and use ML

+
+

+ The high-level Keras API provides building blocks to create and + train deep learning models. Start with these beginner-friendly + notebook examples, then read the + TensorFlow Keras guide. +

+
    +
  1. Basic classification
  2. +
  3. Text classification
  4. +
  5. Regression
  6. +
  7. Overfitting and underfitting
  8. +
  9. Save and load
  10. +
+
+ +
+ - classname: tfo-landing-row-item-code-block + code_block: | +
+        import tensorflow as tf
+        mnist = tf.keras.datasets.mnist
+
+        (x_train, y_train),(x_test, y_test) = mnist.load_data()
+        x_train, x_test = x_train / 255.0, x_test / 255.0
+
+        model = tf.keras.models.Sequential([
+          tf.keras.layers.Flatten(),
+          tf.keras.layers.Dense(512, activation=tf.nn.relu),
+          tf.keras.layers.Dropout(0.2),
+          tf.keras.layers.Dense(10, activation=tf.nn.softmax)
+        ])
+        model.compile(optimizer='adam',
+                      loss='sparse_categorical_crossentropy',
+                      metrics=['accuracy'])
+
+        model.fit(x_train, y_train, epochs=5)
+        model.evaluate(x_test, y_test)
+        
+ {% dynamic if request.tld != 'cn' %} + Run in a Notebook + {% dynamic endif %} + + - items: + - custom_html: > +
+

Research and experimentation

+
+

+ Eager execution provides an imperative, define-by-run interface for advanced operations. Write custom layers, forward passes, and training loops with auto‑differentiation. Start with + these notebooks, then read the eager execution guide. +

+
    +
  1. + {% dynamic if request.tld == 'cn' %} + Eager execution basics + {% dynamic else %} + Eager execution basics + {% dynamic endif %} +
  2. +
  3. + {% dynamic if request.tld == 'cn' %} + Automatic differentiation and gradient tape + {% dynamic else %} + Automatic differentiation and gradient tape + {% dynamic endif %} +
  4. +
  5. + {% dynamic if request.tld == 'cn' %} + Custom training: basics + {% dynamic else %} + Custom training: basics + {% dynamic endif %} +
  6. +
  7. + {% dynamic if request.tld == 'cn' %} + Custom layers + {% dynamic else %} + Custom layers + {% dynamic endif %} +
  8. +
  9. Custom training: walkthrough
  10. +
  11. + {% dynamic if request.tld == 'cn' %} + Example: Neural machine translation w/ attention + {% dynamic else %} + Example: Neural machine translation w/ attention + {% dynamic endif %} +
  12. +
+
+ +
+ - custom_html: > +
+

ML at production scale

+
+

+ Estimators can train large models on multiple machines in a + production environment. TensorFlow provides a collection of + pre-made Estimators to implement common ML algorithms. See the + Estimators guide. +

+
    +
  1. Build a linear model with Estimators
  2. +
  3. Wide and deep learning with Estimators
  4. +
  5. Boosted trees
  6. +
  7. How to build a simple text classifier with TF-Hub
  8. +
  9. Build a Convolutional Neural Network using Estimators
  10. +
+
+ +
+ + - description: > +

Google Colab: An easy way to learn and use TensorFlow

+

+ Colaboratory + is a Google research project created to help disseminate machine learning + education and research. It's a Jupyter notebook environment that requires + no setup to use and runs entirely in the cloud. + Read the blog post. +

+ + - description: > +

Build your first ML app

+

Create and deploy TensorFlow models on web and mobile.

+ background: grey + items: + - custom_html: > +
+ +

Web developers

+
+
+ TensorFlow.js is a WebGL accelerated, JavaScript library to train and + deploy ML models in the browser and for Node.js. +
+
+ - custom_html: > +
+ +

Mobile developers

+
+
+ TensorFlow Lite is lightweight solution for mobile and embedded devices. +
+
+ + - description: > +

Videos and updates

+

+ Subscribe to the TensorFlow + YouTube channel + and blog for + the latest videos and updates. +

+ items: + - description: > +

Get started with TensorFlow's High-Level APIs

+ youtube_id: tjsHSIG8I08 + buttons: + - label: Watch the video + path: https://www.youtube.com/watch?v=tjsHSIG8I08 + - description: > +

Eager execution

+ youtube_id: T8AW0fKP0Hs + background: grey + buttons: + - label: Watch the video + path: https://www.youtube.com/watch?v=T8AW0fKP0Hs + - description: > +

tf.data: Fast, flexible, and easy-to-use input pipelines

+ youtube_id: uIcqeP7MFH0 + buttons: + - label: Watch the video + path: https://www.youtube.com/watch?v=uIcqeP7MFH0 diff --git a/tensorflow/docs_src/tutorials/_toc.yaml b/tensorflow/docs_src/tutorials/_toc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0e25208a000b7bb196462c2904c3dfba5adead6c --- /dev/null +++ b/tensorflow/docs_src/tutorials/_toc.yaml @@ -0,0 +1,124 @@ +toc: +- title: Get started with TensorFlow + path: /tutorials/ + +- title: Learn and use ML + style: accordion + section: + - title: Overview + path: /tutorials/keras/ + - title: Basic classification + path: /tutorials/keras/basic_classification + - title: Text classification + path: /tutorials/keras/basic_text_classification + - title: Regression + path: /tutorials/keras/basic_regression + - title: Overfitting and underfitting + path: /tutorials/keras/overfit_and_underfit + - title: Save and restore models + path: /tutorials/keras/save_and_restore_models + +- title: Research and experimentation + style: accordion + section: + - title: Overview + path: /tutorials/eager/ + - title: Eager execution + path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb + status: external + - title: Automatic differentiation + path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb + status: external + - title: "Custom training: basics" + path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb + status: external + - title: Custom layers + path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb + status: external + - title: "Custom training: walkthrough" + path: /tutorials/eager/custom_training_walkthrough + - title: Text generation + path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb + status: external + - title: Translation with attention + path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb + status: external + - title: Image captioning + path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb + status: external + - title: Neural Style Transfer + path: https://github.com/tensorflow/models/blob/master/research/nst_blogpost/4_Neural_Style_Transfer_with_Eager_Execution.ipynb + status: external + - title: DCGAN + path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb + status: external + - title: VAE + path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb + status: external + - title: Pix2Pix + path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb + status: external + - title: Image Segmentation + path: https://github.com/tensorflow/models/blob/master/samples/outreach/blogs/segmentation_blogpost/image_segmentation.ipynb + status: external + +- title: ML at production scale + style: accordion + section: + - title: Linear model with Estimators + path: /tutorials/estimators/linear + - title: Wide and deep learning + path: https://github.com/tensorflow/models/tree/master/official/wide_deep + status: external + - title: Boosted trees + path: https://github.com/tensorflow/models/tree/master/official/boosted_trees + status: external + - title: Text classifier with TF-Hub + path: /hub/tutorials/text_classification_with_tf_hub + - title: Build a CNN using Estimators + path: /tutorials/estimators/cnn + +- title: Images + style: accordion + section: + - title: Image recognition + path: /tutorials/images/image_recognition + - title: Image retraining + path: /hub/tutorials/image_retraining + - title: Advanced CNN + path: /tutorials/images/deep_cnn + +- title: Sequences + style: accordion + section: + - title: Recurrent neural network + path: /tutorials/sequences/recurrent + - title: Drawing classification + path: /tutorials/sequences/recurrent_quickdraw + - title: Simple audio recognition + path: /tutorials/sequences/audio_recognition + - title: Neural machine translation + path: https://github.com/tensorflow/nmt + status: external + +- title: Data representation + style: accordion + section: + - title: Vector representations of words + path: /tutorials/representation/word2vec + - title: Kernel methods + path: /tutorials/representation/kernel_methods + - title: Large-scale linear models + path: /tutorials/representation/linear + +- title: Non-ML + style: accordion + section: + - title: Mandelbrot set + path: /tutorials/non-ml/mandelbrot + - title: Partial differential equations + path: /tutorials/non-ml/pdes + +- break: True +- title: Next steps + path: /tutorials/next_steps diff --git a/tensorflow/docs_src/tutorials/eager/custom_training_walkthrough.md b/tensorflow/docs_src/tutorials/eager/custom_training_walkthrough.md new file mode 100644 index 0000000000000000000000000000000000000000..b564a27ecfd1b06c6b977302ba463bb763a6fb38 --- /dev/null +++ b/tensorflow/docs_src/tutorials/eager/custom_training_walkthrough.md @@ -0,0 +1,3 @@ +# Custom training: walkthrough + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/eager/custom_training_walkthrough.ipynb) diff --git a/tensorflow/docs_src/tutorials/eager/index.md b/tensorflow/docs_src/tutorials/eager/index.md new file mode 100644 index 0000000000000000000000000000000000000000..a13b39609435256ded88072ce40c929a1494aad0 --- /dev/null +++ b/tensorflow/docs_src/tutorials/eager/index.md @@ -0,0 +1,13 @@ +# Research and experimentation + +Eager execution provides an imperative, define-by-run interface for advanced +operations. Write custom layers, forward passes, and training loops with +auto differentiation. Start with these notebooks, then read the +[eager execution guide](../../guide/eager). + +1. [Eager execution](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb){:.external} +2. [Automatic differentiation and gradient tape](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb){:.external} +3. [Custom training: basics](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb){:.external} +4. [Custom layers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb){:.external} +5. [Custom training: walkthrough](/tutorials/eager/custom_training_walkthrough) +6. [Advanced example: Neural machine translation with attention](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb){:.external} diff --git a/tensorflow/docs_src/tutorials/layers.md b/tensorflow/docs_src/tutorials/estimators/cnn.md similarity index 92% rename from tensorflow/docs_src/tutorials/layers.md rename to tensorflow/docs_src/tutorials/estimators/cnn.md index 0f17899dae7ccd8686ac159548dec303401b8ad4..100f501cc215a624212cdfe15555cd6db5da1e20 100644 --- a/tensorflow/docs_src/tutorials/layers.md +++ b/tensorflow/docs_src/tutorials/estimators/cnn.md @@ -1,6 +1,6 @@ -# A Guide to TF Layers: Building a Convolutional Neural Network +# Build a Convolutional Neural Network using Estimators -The TensorFlow @{tf.layers$`layers` module} provides a high-level API that makes +The `tf.layers` module provides a high-level API that makes it easy to construct a neural network. It provides methods that facilitate the creation of dense (fully connected) layers and convolutional layers, adding activation functions, and applying dropout regularization. In this tutorial, @@ -118,8 +118,8 @@ output from one layer-creation method and supply it as input to another. Open `cnn_mnist.py` and add the following `cnn_model_fn` function, which conforms to the interface expected by TensorFlow's Estimator API (more on this later in [Create the Estimator](#create-the-estimator)). `cnn_mnist.py` takes -MNIST feature data, labels, and -@{tf.estimator.ModeKeys$model mode} (`TRAIN`, `EVAL`, `PREDICT`) as arguments; +MNIST feature data, labels, and mode (from +`tf.estimator.ModeKeys`: `TRAIN`, `EVAL`, `PREDICT`) as arguments; configures the CNN; and returns predictions, loss, and a training operation: ```python @@ -277,7 +277,7 @@ a 5x5 convolution over a 28x28 tensor will produce a 24x24 tensor, as there are The `activation` argument specifies the activation function to apply to the output of the convolution. Here, we specify ReLU activation with -@{tf.nn.relu}. +`tf.nn.relu`. Our output tensor produced by `conv2d()` has a shape of [batch_size, 28, 28, 32]: the same height and width @@ -423,7 +423,7 @@ raw values into two different formats that our model function can return: For a given example, our predicted class is the element in the corresponding row of the logits tensor with the highest raw value. We can find the index of this -element using the @{tf.argmax} +element using the `tf.argmax` function: ```python @@ -438,7 +438,7 @@ value along the dimension with index of 1, which corresponds to our predictions 10]). We can derive probabilities from our logits layer by applying softmax activation -using @{tf.nn.softmax}: +using `tf.nn.softmax`: ```python tf.nn.softmax(logits, name="softmax_tensor") @@ -470,51 +470,18 @@ as the loss metric. The following code calculates cross entropy when the model runs in either `TRAIN` or `EVAL` mode: ```python -onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10) -loss = tf.losses.softmax_cross_entropy( - onehot_labels=onehot_labels, logits=logits) +loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) ``` Let's take a closer look at what's happening above. -Our `labels` tensor contains a list of predictions for our examples, e.g. `[1, -9, ...]`. In order to calculate cross-entropy, first we need to convert `labels` -to the corresponding -[one-hot encoding](https://www.quora.com/What-is-one-hot-encoding-and-when-is-it-used-in-data-science): +Our `labels` tensor contains a list of prediction indices for our examples, e.g. `[1, +9, ...]`. `logits` contains the linear outputs of our last layer. -```none -[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], - ...] -``` - -We use the @{tf.one_hot} function -to perform this conversion. `tf.one_hot()` has two required arguments: - -* `indices`. The locations in the one-hot tensor that will have "on - values"—i.e., the locations of `1` values in the tensor shown above. -* `depth`. The depth of the one-hot tensor—i.e., the number of target classes. - Here, the depth is `10`. +`tf.losses.sparse_softmax_cross_entropy`, calculates the softmax crossentropy +(aka: categorical crossentropy, negative log-likelihood) from these two inputs +in an efficient, numerically stable way. -The following code creates the one-hot tensor for our labels, `onehot_labels`: - -```python -onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10) -``` - -Because `labels` contains a series of values from 0–9, `indices` is just our -`labels` tensor, with values cast to integers. The `depth` is `10` because we -have 10 possible target classes, one for each digit. - -Next, we compute cross-entropy of `onehot_labels` and the softmax of the -predictions from our logits layer. `tf.losses.softmax_cross_entropy()` takes -`onehot_labels` and `logits` as arguments, performs softmax activation on -`logits`, calculates cross-entropy, and returns our `loss` as a scalar `Tensor`: - -```python -loss = tf.losses.softmax_cross_entropy( - onehot_labels=onehot_labels, logits=logits) -``` ### Configure the Training Op @@ -605,8 +572,8 @@ feel free to change to another directory of your choice). ### Set Up a Logging Hook {#set_up_a_logging_hook} Since CNNs can take a while to train, let's set up some logging so we can track -progress during training. We can use TensorFlow's @{tf.train.SessionRunHook} to create a -@{tf.train.LoggingTensorHook} +progress during training. We can use TensorFlow's `tf.train.SessionRunHook` to create a +`tf.train.LoggingTensorHook` that will log the probability values from the softmax layer of our CNN. Add the following to `main()`: @@ -627,7 +594,7 @@ operation earlier when we generated the probabilities in `cnn_model_fn`. > argument, TensorFlow will assign a default name. A couple easy ways to > discover the names applied to operations are to visualize your graph on > @{$graph_viz$TensorBoard}) or to enable the -> @{$programmers_guide/debugger$TensorFlow Debugger (tfdbg)}. +> @{$guide/debugger$TensorFlow Debugger (tfdbg)}. Next, we create the `LoggingTensorHook`, passing `tensors_to_log` to the `tensors` argument. We set `every_n_iter=50`, which specifies that probabilities diff --git a/tensorflow/docs_src/tutorials/estimators/linear.md b/tensorflow/docs_src/tutorials/estimators/linear.md new file mode 100644 index 0000000000000000000000000000000000000000..067a33ac036ec54826c6e88d0c9dc11b07e95976 --- /dev/null +++ b/tensorflow/docs_src/tutorials/estimators/linear.md @@ -0,0 +1,3 @@ +# Build a linear model with Estimators + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/estimators/linear.ipynb) diff --git a/tensorflow/docs_src/tutorials/image_retraining.md b/tensorflow/docs_src/tutorials/image_retraining.md deleted file mode 100644 index 27784eef9cdb5c6f8b9af44b3fc3f876cda39d13..0000000000000000000000000000000000000000 --- a/tensorflow/docs_src/tutorials/image_retraining.md +++ /dev/null @@ -1,4 +0,0 @@ -# How to Retrain Inception's Final Layer for New Categories - -**NOTE: This tutorial has moved to** -https://github.com/tensorflow/hub/tree/master/docs/tutorials/image_retraining.md diff --git a/tensorflow/docs_src/tutorials/deep_cnn.md b/tensorflow/docs_src/tutorials/images/deep_cnn.md similarity index 84% rename from tensorflow/docs_src/tutorials/deep_cnn.md rename to tensorflow/docs_src/tutorials/images/deep_cnn.md index 6a4c9a9b0727208a158b1b57d13ca70290961ec2..42ad484bbfe0b34383648197a8c88c2fa097c342 100644 --- a/tensorflow/docs_src/tutorials/deep_cnn.md +++ b/tensorflow/docs_src/tutorials/images/deep_cnn.md @@ -1,7 +1,4 @@ -# Convolutional Neural Networks - -> **NOTE:** This tutorial is intended for *advanced* users of TensorFlow -and assumes expertise and experience in machine learning. +# Advanced Convolutional Neural Networks ## Overview @@ -34,26 +31,26 @@ new ideas and experimenting with new techniques. The CIFAR-10 tutorial demonstrates several important constructs for designing larger and more sophisticated models in TensorFlow: -* Core mathematical components including @{tf.nn.conv2d$convolution} +* Core mathematical components including `tf.nn.conv2d` ([wiki](https://en.wikipedia.org/wiki/Convolution)), -@{tf.nn.relu$rectified linear activations} +`tf.nn.relu` ([wiki](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))), -@{tf.nn.max_pool$max pooling} +`tf.nn.max_pool` ([wiki](https://en.wikipedia.org/wiki/Convolutional_neural_network#Pooling_layer)) -and @{tf.nn.local_response_normalization$local response normalization} +and `tf.nn.local_response_normalization` (Chapter 3.3 in [AlexNet paper](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf)). * @{$summaries_and_tensorboard$Visualization} of network activities during training, including input images, losses and distributions of activations and gradients. * Routines for calculating the -@{tf.train.ExponentialMovingAverage$moving average} +`tf.train.ExponentialMovingAverage` of learned parameters and using these averages during evaluation to boost predictive performance. * Implementation of a -@{tf.train.exponential_decay$learning rate schedule} +`tf.train.exponential_decay` that systematically decrements over time. -* Prefetching @{tf.train.shuffle_batch$queues} +* Prefetching `tf.train.shuffle_batch` for input data to isolate the model from disk latency and expensive image pre-processing. @@ -83,21 +80,21 @@ for details. It consists of 1,068,298 learnable parameters and requires about ## Code Organization The code for this tutorial resides in -[`models/tutorials/image/cifar10/`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/). +[`models/tutorials/image/cifar10/`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/). File | Purpose --- | --- -[`cifar10_input.py`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/cifar10_input.py) | Reads the native CIFAR-10 binary file format. -[`cifar10.py`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/cifar10.py) | Builds the CIFAR-10 model. -[`cifar10_train.py`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/cifar10_train.py) | Trains a CIFAR-10 model on a CPU or GPU. -[`cifar10_multi_gpu_train.py`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/cifar10_multi_gpu_train.py) | Trains a CIFAR-10 model on multiple GPUs. -[`cifar10_eval.py`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/cifar10_eval.py) | Evaluates the predictive performance of a CIFAR-10 model. +[`cifar10_input.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_input.py) | Reads the native CIFAR-10 binary file format. +[`cifar10.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10.py) | Builds the CIFAR-10 model. +[`cifar10_train.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_train.py) | Trains a CIFAR-10 model on a CPU or GPU. +[`cifar10_multi_gpu_train.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py) | Trains a CIFAR-10 model on multiple GPUs. +[`cifar10_eval.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_eval.py) | Evaluates the predictive performance of a CIFAR-10 model. ## CIFAR-10 Model The CIFAR-10 network is largely contained in -[`cifar10.py`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/cifar10.py). +[`cifar10.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10.py). The complete training graph contains roughly 765 operations. We find that we can make the code most reusable by constructing the graph with the following modules: @@ -116,27 +113,27 @@ gradients, variable updates and visualization summaries. The input part of the model is built by the functions `inputs()` and `distorted_inputs()` which read images from the CIFAR-10 binary data files. These files contain fixed byte length records, so we use -@{tf.FixedLengthRecordReader}. +`tf.FixedLengthRecordReader`. See @{$reading_data#reading-from-files$Reading Data} to learn more about how the `Reader` class works. The images are processed as follows: * They are cropped to 24 x 24 pixels, centrally for evaluation or - @{tf.random_crop$randomly} for training. -* They are @{tf.image.per_image_standardization$approximately whitened} + `tf.random_crop` for training. +* They are `tf.image.per_image_standardization` to make the model insensitive to dynamic range. For training, we additionally apply a series of random distortions to artificially increase the data set size: -* @{tf.image.random_flip_left_right$Randomly flip} the image from left to right. -* Randomly distort the @{tf.image.random_brightness$image brightness}. -* Randomly distort the @{tf.image.random_contrast$image contrast}. +* `tf.image.random_flip_left_right` the image from left to right. +* Randomly distort the `tf.image.random_brightness`. +* Randomly distort the `tf.image.random_contrast`. Please see the @{$python/image$Images} page for the list of available distortions. We also attach an -@{tf.summary.image} to the images +`tf.summary.image` to the images so that we may visualize them in @{$summaries_and_tensorboard$TensorBoard}. This is a good practice to verify that inputs are built correctly. @@ -147,7 +144,7 @@ This is a good practice to verify that inputs are built correctly. Reading images from disk and distorting them can use a non-trivial amount of processing time. To prevent these operations from slowing down training, we run them inside 16 separate threads which continuously fill a TensorFlow -@{tf.train.shuffle_batch$queue}. +`tf.train.shuffle_batch`. ### Model Prediction @@ -157,12 +154,12 @@ the model is organized as follows: Layer Name | Description --- | --- -`conv1` | @{tf.nn.conv2d$convolution} and @{tf.nn.relu$rectified linear} activation. -`pool1` | @{tf.nn.max_pool$max pooling}. -`norm1` | @{tf.nn.local_response_normalization$local response normalization}. -`conv2` | @{tf.nn.conv2d$convolution} and @{tf.nn.relu$rectified linear} activation. -`norm2` | @{tf.nn.local_response_normalization$local response normalization}. -`pool2` | @{tf.nn.max_pool$max pooling}. +`conv1` | `tf.nn.conv2d` and `tf.nn.relu` activation. +`pool1` | `tf.nn.max_pool`. +`norm1` | `tf.nn.local_response_normalization`. +`conv2` | `tf.nn.conv2d` and `tf.nn.relu` activation. +`norm2` | `tf.nn.local_response_normalization`. +`pool2` | `tf.nn.max_pool`. `local3` | @{$python/nn$fully connected layer with rectified linear activation}. `local4` | @{$python/nn$fully connected layer with rectified linear activation}. `softmax_linear` | linear transformation to produce logits. @@ -175,7 +172,7 @@ Here is a graph generated from TensorBoard describing the inference operation: > **EXERCISE**: The output of `inference` are un-normalized logits. Try editing the network architecture to return normalized predictions using -@{tf.nn.softmax}. +`tf.nn.softmax`. The `inputs()` and `inference()` functions provide all the components necessary to perform an evaluation of a model. We now shift our focus towards @@ -193,16 +190,16 @@ architecture in the top layer. The usual method for training a network to perform N-way classification is [multinomial logistic regression](https://en.wikipedia.org/wiki/Multinomial_logistic_regression), aka. *softmax regression*. Softmax regression applies a -@{tf.nn.softmax$softmax} nonlinearity to the +`tf.nn.softmax` nonlinearity to the output of the network and calculates the -@{tf.nn.sparse_softmax_cross_entropy_with_logits$cross-entropy} +`tf.nn.sparse_softmax_cross_entropy_with_logits` between the normalized predictions and the label index. For regularization, we also apply the usual -@{tf.nn.l2_loss$weight decay} losses to all learned +`tf.nn.l2_loss` losses to all learned variables. The objective function for the model is the sum of the cross entropy loss and all these weight decay terms, as returned by the `loss()` function. -We visualize it in TensorBoard with a @{tf.summary.scalar}: +We visualize it in TensorBoard with a `tf.summary.scalar`: ![CIFAR-10 Loss](https://www.tensorflow.org/images/cifar_loss.png "CIFAR-10 Total Loss") @@ -210,14 +207,14 @@ We train the model using standard [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent) algorithm (see @{$python/train$Training} for other methods) with a learning rate that -@{tf.train.exponential_decay$exponentially decays} +`tf.train.exponential_decay` over time. ![CIFAR-10 Learning Rate Decay](https://www.tensorflow.org/images/cifar_lr_decay.png "CIFAR-10 Learning Rate Decay") The `train()` function adds the operations needed to minimize the objective by calculating the gradient and updating the learned variables (see -@{tf.train.GradientDescentOptimizer} +`tf.train.GradientDescentOptimizer` for details). It returns an operation that executes all the calculations needed to train and update the model for one batch of images. @@ -266,9 +263,9 @@ training step can take so long. Try decreasing the number of images that initially fill up the queue. Search for `min_fraction_of_examples_in_queue` in `cifar10_input.py`. -`cifar10_train.py` periodically @{tf.train.Saver$saves} +`cifar10_train.py` periodically uses a `tf.train.Saver` to save all model parameters in -@{$programmers_guide/saved_model$checkpoint files} +@{$guide/saved_model$checkpoint files} but it does *not* evaluate the model. The checkpoint file will be used by `cifar10_eval.py` to measure the predictive performance (see [Evaluating a Model](#evaluating-a-model) below). @@ -288,7 +285,7 @@ how the model is training. We want more insight into the model during training: @{$summaries_and_tensorboard$TensorBoard} provides this functionality, displaying data exported periodically from `cifar10_train.py` via a -@{tf.summary.FileWriter}. +`tf.summary.FileWriter`. For instance, we can watch how the distribution of activations and degree of sparsity in `local3` features evolve during training: @@ -303,7 +300,7 @@ interesting to track over time. However, the loss exhibits a considerable amount of noise due to the small batch size employed by training. In practice we find it extremely useful to visualize their moving averages in addition to their raw values. See how the scripts use -@{tf.train.ExponentialMovingAverage} +`tf.train.ExponentialMovingAverage` for this purpose. ## Evaluating a Model @@ -339,8 +336,8 @@ exports summaries that may be visualized in TensorBoard. These summaries provide additional insight into the model during evaluation. The training script calculates the -@{tf.train.ExponentialMovingAverage$moving average} -version of all learned variables. The evaluation script substitutes +`tf.train.ExponentialMovingAverage` of all learned variables. +The evaluation script substitutes all learned model parameters with the moving average version. This substitution boosts model performance at evaluation time. @@ -404,17 +401,17 @@ gradients for a single model replica. In the code we term this abstraction a "tower". We must set two attributes for each tower: * A unique name for all operations within a tower. -@{tf.name_scope} provides +`tf.name_scope` provides this unique name by prepending a scope. For instance, all operations in the first tower are prepended with `tower_0`, e.g. `tower_0/conv1/Conv2D`. * A preferred hardware device to run the operation within a tower. -@{tf.device} specifies this. For +`tf.device` specifies this. For instance, all operations in the first tower reside within `device('/device:GPU:0')` scope indicating that they should be run on the first GPU. All variables are pinned to the CPU and accessed via -@{tf.get_variable} +`tf.get_variable` in order to share them in a multi-GPU version. See how-to on @{$variables$Sharing Variables}. @@ -438,9 +435,6 @@ with a batch size of 64 and compare the training speed. ## Next Steps -[Congratulations!](https://www.youtube.com/watch?v=9bZkp7q19f0) You have -completed the CIFAR-10 tutorial. - If you are now interested in developing and training your own image classification system, we recommend forking this tutorial and replacing components to address your image classification problem. diff --git a/tensorflow/docs_src/tutorials/image_recognition.md b/tensorflow/docs_src/tutorials/images/image_recognition.md similarity index 99% rename from tensorflow/docs_src/tutorials/image_recognition.md rename to tensorflow/docs_src/tutorials/images/image_recognition.md index 332bcf54f02e6e3c7d805746011dfab642943cfe..83a8d97cf04ca0442c6b670d144c3dcf5443bfc8 100644 --- a/tensorflow/docs_src/tutorials/image_recognition.md +++ b/tensorflow/docs_src/tutorials/images/image_recognition.md @@ -253,7 +253,7 @@ definition with the `ToGraphDef()` function. TF_RETURN_IF_ERROR(session->Run({}, {output_name}, {}, out_tensors)); return Status::OK(); ``` -Then we create a @{tf.Session} +Then we create a `tf.Session` object, which is the interface to actually running the graph, and run it, specifying which node we want to get the output from, and where to put the output data. @@ -434,7 +434,6 @@ should be able to transfer some of that understanding to solving related problems. One way to perform transfer learning is to remove the final classification layer of the network and extract the [next-to-last layer of the CNN](https://arxiv.org/abs/1310.1531), in this case a 2048 dimensional vector. -There's a guide to doing this @{$image_retraining$in the how-to section}. ## Resources for Learning More @@ -450,7 +449,7 @@ covering them. To find out more about implementing convolutional neural networks, you can jump to the TensorFlow @{$deep_cnn$deep convolutional networks tutorial}, -or start a bit more gently with our @{$layers$MNIST starter tutorial}. +or start a bit more gently with our [Estimator MNIST tutorial](../estimators/cnn.md). Finally, if you want to get up to speed on research in this area, you can read the recent work of all the papers referenced in this tutorial. diff --git a/tensorflow/docs_src/tutorials/index.md b/tensorflow/docs_src/tutorials/index.md deleted file mode 100644 index af01d3eaa12157f82c981de005708509f6652cca..0000000000000000000000000000000000000000 --- a/tensorflow/docs_src/tutorials/index.md +++ /dev/null @@ -1,60 +0,0 @@ -# Tutorials - - -This section contains tutorials demonstrating how to do specific tasks -in TensorFlow. If you are new to TensorFlow, we recommend reading the -documents in the "@{$get_started$Get Started}" section before reading -these tutorials. - -## Images - -These tutorials cover different aspects of image recognition: - - * @{$layers$MNIST}, which introduces convolutional neural networks (CNNs) and - demonstrates how to build a CNN in TensorFlow. - * @{$image_recognition}, which introduces the field of image recognition and - uses a pre-trained model (Inception) for recognizing images. - * @{$image_retraining}, which has a wonderfully self-explanatory title. - * @{$deep_cnn}, which demonstrates how to build a small CNN for recognizing - images. This tutorial is aimed at advanced TensorFlow users. - - -## Sequences - -These tutorials focus on machine learning problems dealing with sequence data. - - * @{$recurrent}, which demonstrates how to use a - recurrent neural network to predict the next word in a sentence. - * @{$seq2seq}, which demonstrates how to use a - sequence-to-sequence model to translate text from English to French. - * @{$recurrent_quickdraw} - builds a classification model for drawings, directly from the sequence of - pen strokes. - * @{$audio_recognition}, which shows how to - build a basic speech recognition network. - -## Data representation - -These tutorials demonstrate various data representations that can be used in -TensorFlow. - - * @{$wide}, uses - @{tf.feature_column$feature columns} to feed a variety of data types - to linear model, to solve a classification problem. - * @{$wide_and_deep}, builds on the - above linear model tutorial, adding a deep feed-forward neural network - component and a DNN-compatible data representation. - * @{$word2vec}, which demonstrates how to - create an embedding for words. - * @{$kernel_methods}, - which shows how to improve the quality of a linear model by using explicit - kernel mappings. - -## Non Machine Learning - -Although TensorFlow specializes in machine learning, the core of TensorFlow is -a powerful numeric computation system which you can also use to solve other -kinds of math problems. For example: - - * @{$mandelbrot} - * @{$pdes} diff --git a/tensorflow/docs_src/tutorials/keras/basic_classification.md b/tensorflow/docs_src/tutorials/keras/basic_classification.md new file mode 100644 index 0000000000000000000000000000000000000000..e028af99b936a92cf359a7b4e561f7bcf3c4bffc --- /dev/null +++ b/tensorflow/docs_src/tutorials/keras/basic_classification.md @@ -0,0 +1,3 @@ +# Basic Classification + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/keras/basic_classification.ipynb) diff --git a/tensorflow/docs_src/tutorials/keras/basic_regression.md b/tensorflow/docs_src/tutorials/keras/basic_regression.md new file mode 100644 index 0000000000000000000000000000000000000000..8721b7aca19e3f37b6989bb1b280ac3b4fdffc8e --- /dev/null +++ b/tensorflow/docs_src/tutorials/keras/basic_regression.md @@ -0,0 +1,3 @@ +# Basic Regression + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/keras/basic_regression.ipynb) diff --git a/tensorflow/docs_src/tutorials/keras/basic_text_classification.md b/tensorflow/docs_src/tutorials/keras/basic_text_classification.md new file mode 100644 index 0000000000000000000000000000000000000000..c2a16bdd204c303cd166f283229cb9eaf73540b0 --- /dev/null +++ b/tensorflow/docs_src/tutorials/keras/basic_text_classification.md @@ -0,0 +1,3 @@ +# Basic Text Classification + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/keras/basic_text_classification.ipynb) diff --git a/tensorflow/docs_src/tutorials/keras/index.md b/tensorflow/docs_src/tutorials/keras/index.md new file mode 100644 index 0000000000000000000000000000000000000000..9d42281c8f97fd8930770c0bc30c9bcf1e50fde6 --- /dev/null +++ b/tensorflow/docs_src/tutorials/keras/index.md @@ -0,0 +1,22 @@ +# Learn and use machine learning + +This notebook collection is inspired by the book +*[Deep Learning with Python](https://books.google.com/books?id=Yo3CAQAACAAJ)*. +These tutorials use `tf.keras`, TensorFlow's high-level Python API for building +and training deep learning models. To learn more about using Keras with +TensorFlow, see the [TensorFlow Keras Guide](../../guide/keras). + +Publisher's note: *Deep Learning with Python* introduces the field of deep +learning using the Python language and the powerful Keras library. Written by +Keras creator and Google AI researcher François Chollet, this book builds your +understanding through intuitive explanations and practical examples. + +To learn about machine learning fundamentals and concepts, consider taking the +[Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/). +Additional TensorFlow and machine learning resources are listed in [next steps](../next_steps). + +1. [Basic classification](./basic_classification) +2. [Text classification](./basic_text_classification) +3. [Regression](./basic_regression) +4. [Overfitting and underfitting](./overfit_and_underfit) +5. [Save and restore models](./save_and_restore_models) diff --git a/tensorflow/docs_src/tutorials/keras/overfit_and_underfit.md b/tensorflow/docs_src/tutorials/keras/overfit_and_underfit.md new file mode 100644 index 0000000000000000000000000000000000000000..f07f3addd82235181cc6c4c5d32d44da2c72107f --- /dev/null +++ b/tensorflow/docs_src/tutorials/keras/overfit_and_underfit.md @@ -0,0 +1,3 @@ +# Overfitting and Underfitting + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/keras/overfit_and_underfit.ipynb) diff --git a/tensorflow/docs_src/tutorials/keras/save_and_restore_models.md b/tensorflow/docs_src/tutorials/keras/save_and_restore_models.md new file mode 100644 index 0000000000000000000000000000000000000000..a799b379a004d545b12d7c1d37b78ee3baeee1fc --- /dev/null +++ b/tensorflow/docs_src/tutorials/keras/save_and_restore_models.md @@ -0,0 +1,3 @@ +# Save and restore Models + +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/tutorials/keras/save_and_restore_models.ipynb) diff --git a/tensorflow/docs_src/tutorials/leftnav_files b/tensorflow/docs_src/tutorials/leftnav_files deleted file mode 100644 index 888052428f951fa1a7cbd9c6d35497a056387097..0000000000000000000000000000000000000000 --- a/tensorflow/docs_src/tutorials/leftnav_files +++ /dev/null @@ -1,23 +0,0 @@ -index.md - -### Images -layers.md: MNIST -image_recognition.md: Image Recognition -image_retraining.md: Image Retraining -deep_cnn.md - -### Sequences -recurrent.md -seq2seq.md: Neural Machine Translation -recurrent_quickdraw.md: Drawing Classification -audio_recognition.md - -### Data Representation -wide.md: Linear Models -wide_and_deep.md: Wide & Deep Learning -word2vec.md -kernel_methods.md: Kernel Methods - -### Non-ML -mandelbrot.md -pdes.md diff --git a/tensorflow/docs_src/tutorials/next_steps.md b/tensorflow/docs_src/tutorials/next_steps.md new file mode 100644 index 0000000000000000000000000000000000000000..01c9f7204a7ddae16bcbd9eb5702516a39f8ce4c --- /dev/null +++ b/tensorflow/docs_src/tutorials/next_steps.md @@ -0,0 +1,36 @@ +# Next steps + +## Learn more about TensorFlow + +* The [TensorFlow Guide](/guide) includes usage guides for the + high-level APIs, as well as advanced TensorFlow operations. +* [Premade Estimators](/guide/premade_estimators) are designed to + get results out of the box. Use TensorFlow without building your own models. +* [TensorFlow.js](https://js.tensorflow.org/) allows web developers to train and + deploy ML models in the browser and using Node.js. +* [TFLite](/mobile/tflite) allows mobile developers to do inference efficiently + on mobile devices. +* [TensorFlow Serving](/serving) is an open-source project that can put + TensorFlow models in production quickly. +* The [ecosystem](/ecosystem) contains more projects, including + [Magenta](https://magenta.tensorflow.org/), [TFX](/tfx), + [Swift for TensorFlow](https://github.com/tensorflow/swift), and more. + +## Learn more about machine learning + +Recommended resources include: + +* [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/), + a course from Google that introduces machine learning concepts. +* [CS 20: Tensorflow for Deep Learning Research](http://web.stanford.edu/class/cs20si/), + notes from an intro course from Stanford. +* [CS231n: Convolutional Neural Networks for Visual Recognition](http://cs231n.stanford.edu/), + a course that teaches how convolutional networks work. +* [Machine Learning Recipes](https://www.youtube.com/watch?v=cKxRvEZd3Mw&list=PLOU2XLYxmsIIuiBfYad6rFYQU_jL2ryal), + a video series that introduces basic machine learning concepts with few prerequisites. +* [Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python), + a book by Francois Chollet about the Keras API, as well as an excellent hands on intro to Deep Learning. +* [Hands-on Machine Learning with Scikit-Learn and TensorFlow](https://github.com/ageron/handson-ml), + a book by Aurélien Geron's that is a clear getting-started guide to data science and deep learning. +* [Deep Learning](https://www.deeplearningbook.org/), a book by Ian Goodfellow et al. + that provides a technical dive into learning machine learning. diff --git a/tensorflow/docs_src/tutorials/mandelbrot.md b/tensorflow/docs_src/tutorials/non-ml/mandelbrot.md old mode 100755 new mode 100644 similarity index 100% rename from tensorflow/docs_src/tutorials/mandelbrot.md rename to tensorflow/docs_src/tutorials/non-ml/mandelbrot.md diff --git a/tensorflow/docs_src/tutorials/pdes.md b/tensorflow/docs_src/tutorials/non-ml/pdes.md old mode 100755 new mode 100644 similarity index 98% rename from tensorflow/docs_src/tutorials/pdes.md rename to tensorflow/docs_src/tutorials/non-ml/pdes.md index 425e8d7084e7f2505b7a3013b431345b72b38cf0..b5a0fa834a8a0a51421657180f8c7817c0e3d140 --- a/tensorflow/docs_src/tutorials/pdes.md +++ b/tensorflow/docs_src/tutorials/non-ml/pdes.md @@ -135,7 +135,6 @@ for i in range(1000): DisplayArray(U.eval(), rng=[-0.1, 0.1]) ``` -![jpeg](../images/pde_output_2.jpg) +![jpeg](../../images/pde_output_2.jpg) Look! Ripples! - diff --git a/tensorflow/docs_src/tutorials/kernel_methods.md b/tensorflow/docs_src/tutorials/representation/kernel_methods.md similarity index 96% rename from tensorflow/docs_src/tutorials/kernel_methods.md rename to tensorflow/docs_src/tutorials/representation/kernel_methods.md index 205e2a2d2c1d1008e62ca4c2caf9f1b0895dff1a..71e87f4d3e986ad552ccabc33d41266c3e0f871b 100644 --- a/tensorflow/docs_src/tutorials/kernel_methods.md +++ b/tensorflow/docs_src/tutorials/representation/kernel_methods.md @@ -1,9 +1,8 @@ # Improving Linear Models Using Explicit Kernel Methods -Note: This document uses a deprecated version of @{tf.estimator}, -which has a @{tf.contrib.learn.Estimator$different interface}. -It also uses other `contrib` methods whose -@{$version_compat#not_covered$API may not be stable}. +Note: This document uses a deprecated version of `tf.estimator`, +`tf.contrib.learn.Estimator`, which has a different interface. It also uses +other `contrib` methods whose @{$version_compat#not_covered$API may not be stable}. In this tutorial, we demonstrate how combining (explicit) kernel methods with linear models can drastically increase the latters' quality of predictions @@ -27,7 +26,7 @@ TensorFlow will provide support for sparse features at a later release. This tutorial uses [tf.contrib.learn](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn) (TensorFlow's high-level Machine Learning API) Estimators for our ML models. -If you are not familiar with this API, [tf.estimator Quickstart](https://www.tensorflow.org/get_started/estimator) +If you are not familiar with this API, The [Estimator guide](../../guide/estimators.md) is a good place to start. We will use the MNIST dataset. The tutorial consists of the following steps: @@ -90,7 +89,7 @@ eval_input_fn = get_input_fn(data.validation, batch_size=5000) ## Training a simple linear model We can now train a linear model over the MNIST dataset. We will use the -@{tf.contrib.learn.LinearClassifier} estimator with 10 classes representing the +`tf.contrib.learn.LinearClassifier` estimator with 10 classes representing the 10 digits. The input features form a 784-dimensional dense vector which can be specified as follows: @@ -195,7 +194,7 @@ much higher dimensional space than the original one. See for more details. ### Kernel classifier -@{tf.contrib.kernel_methods.KernelLinearClassifier} is a pre-packaged +`tf.contrib.kernel_methods.KernelLinearClassifier` is a pre-packaged `tf.contrib.learn` estimator that combines the power of explicit kernel mappings with linear models. Its constructor is almost identical to that of the LinearClassifier estimator with the additional option to specify a list of diff --git a/tensorflow/docs_src/tutorials/linear.md b/tensorflow/docs_src/tutorials/representation/linear.md similarity index 94% rename from tensorflow/docs_src/tutorials/linear.md rename to tensorflow/docs_src/tutorials/representation/linear.md index 3f247ade266d2675eac4d0f59a4744daa61f27ea..014409c617ea7c836e472cab1aa828fa497bd412 100644 --- a/tensorflow/docs_src/tutorials/linear.md +++ b/tensorflow/docs_src/tutorials/representation/linear.md @@ -1,6 +1,6 @@ # Large-scale Linear Models with TensorFlow -@{tf.estimator$Estimators} provides (among other things) a rich set of tools for +`tf.estimator` provides (among other things) a rich set of tools for working with linear models in TensorFlow. This document provides an overview of those tools. It explains: @@ -11,8 +11,9 @@ those tools. It explains: deep learning to get the advantages of both. Read this overview to decide whether the Estimator's linear model tools might -be useful to you. Then do the @{$wide$Linear Models tutorial} to -give it a try. This overview uses code samples from the tutorial, but the +be useful to you. Then work through the +[Estimator wide and deep learning tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep) +to give it a try. This overview uses code samples from the tutorial, but the tutorial walks through the code in greater detail. To understand this overview it will help to have some familiarity @@ -176,7 +177,7 @@ the name of a `FeatureColumn`. Each key's value is a tensor containing the values of that feature for all data instances. See @{$premade_estimators#input_fn} for a more comprehensive look at input functions, and `input_fn` in the -[linear models tutorial code](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py) +[wide and deep learning tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep) for an example implementation of an input function. The input function is passed to the `train()` and `evaluate()` calls that @@ -234,4 +235,5 @@ e = tf.estimator.DNNLinearCombinedClassifier( dnn_feature_columns=deep_columns, dnn_hidden_units=[100, 50]) ``` -For more information, see the @{$wide_and_deep$Wide and Deep Learning tutorial}. +For more information, see the +[wide and deep learning tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep). diff --git a/tensorflow/docs_src/tutorials/word2vec.md b/tensorflow/docs_src/tutorials/representation/word2vec.md similarity index 96% rename from tensorflow/docs_src/tutorials/word2vec.md rename to tensorflow/docs_src/tutorials/representation/word2vec.md index 3fe7352bd2383177ca200a0265dee41dba430144..7964650e199d0d8f156feb74ee95bc0c33593661 100644 --- a/tensorflow/docs_src/tutorials/word2vec.md +++ b/tensorflow/docs_src/tutorials/representation/word2vec.md @@ -23,7 +23,7 @@ straight in, feel free to look at the minimalistic implementation in This basic example contains the code needed to download some data, train on it a bit and visualize the result. Once you get comfortable with reading and running the basic version, you can graduate to -[models/tutorials/embedding/word2vec.py](https://www.tensorflow.org/code/tensorflow_models/tutorials/embedding/word2vec.py) +[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py) which is a more serious implementation that showcases some more advanced TensorFlow principles about how to efficiently use threads to move data into a text model, how to checkpoint during training, etc. @@ -317,7 +317,7 @@ optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0).minimize(loss) Training the model is then as simple as using a `feed_dict` to push data into the placeholders and calling -@{tf.Session.run} with this new data +`tf.Session.run` with this new data in a loop. ```python @@ -341,7 +341,7 @@ t-SNE. Et voila! As expected, words that are similar end up clustering nearby each other. For a more heavyweight implementation of word2vec that showcases more of the advanced features of TensorFlow, see the implementation in -[models/tutorials/embedding/word2vec.py](https://www.tensorflow.org/code/tensorflow_models/tutorials/embedding/word2vec.py). +[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py). ## Evaluating Embeddings: Analogical Reasoning @@ -357,7 +357,7 @@ Download the dataset for this task from To see how we do this evaluation, have a look at the `build_eval_graph()` and `eval()` functions in -[models/tutorials/embedding/word2vec.py](https://www.tensorflow.org/code/tensorflow_models/tutorials/embedding/word2vec.py). +[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py). The choice of hyperparameters can strongly influence the accuracy on this task. To achieve state-of-the-art performance on this task requires training over a @@ -385,13 +385,13 @@ your model is seriously bottlenecked on input data, you may want to implement a custom data reader for your problem, as described in @{$new_data_formats$New Data Formats}. For the case of Skip-Gram modeling, we've actually already done this for you as an example in -[models/tutorials/embedding/word2vec.py](https://www.tensorflow.org/code/tensorflow_models/tutorials/embedding/word2vec.py). +[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py). If your model is no longer I/O bound but you want still more performance, you can take things further by writing your own TensorFlow Ops, as described in @{$adding_an_op$Adding a New Op}. Again we've provided an example of this for the Skip-Gram case -[models/tutorials/embedding/word2vec_optimized.py](https://www.tensorflow.org/code/tensorflow_models/tutorials/embedding/word2vec_optimized.py). +[models/tutorials/embedding/word2vec_optimized.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec_optimized.py). Feel free to benchmark these against each other to measure performance improvements at each stage. diff --git a/tensorflow/docs_src/tutorials/seq2seq.md b/tensorflow/docs_src/tutorials/seq2seq.md deleted file mode 100644 index 8928ba4f7da26ae2e8e9351e2c7c03f0e657f613..0000000000000000000000000000000000000000 --- a/tensorflow/docs_src/tutorials/seq2seq.md +++ /dev/null @@ -1,5 +0,0 @@ -# Sequence-to-Sequence Models - -Please check out the -[tensorflow neural machine translation tutorial](https://github.com/tensorflow/nmt) -for building sequence-to-sequence models with the latest Tensorflow API. diff --git a/tensorflow/docs_src/tutorials/audio_recognition.md b/tensorflow/docs_src/tutorials/sequences/audio_recognition.md similarity index 100% rename from tensorflow/docs_src/tutorials/audio_recognition.md rename to tensorflow/docs_src/tutorials/sequences/audio_recognition.md diff --git a/tensorflow/docs_src/tutorials/recurrent.md b/tensorflow/docs_src/tutorials/sequences/recurrent.md similarity index 95% rename from tensorflow/docs_src/tutorials/recurrent.md rename to tensorflow/docs_src/tutorials/sequences/recurrent.md index 14da2c8785276abb34d6959d738f5b39e6c6a2e8..10d60f7966c7026d0d01b1155e435d8be299734f 100644 --- a/tensorflow/docs_src/tutorials/recurrent.md +++ b/tensorflow/docs_src/tutorials/sequences/recurrent.md @@ -2,8 +2,8 @@ ## Introduction -Take a look at [this great article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) -for an introduction to recurrent neural networks and LSTMs in particular. +See [Understanding LSTM Networks](https://colah.github.io/posts/2015-08-Understanding-LSTMs/){:.external} +for an introduction to recurrent neural networks and LSTMs. ## Language Modeling @@ -77,9 +77,7 @@ The basic pseudocode is as follows: words_in_dataset = tf.placeholder(tf.float32, [time_steps, batch_size, num_features]) lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size) # Initial state of the LSTM memory. -hidden_state = tf.zeros([batch_size, lstm.state_size]) -current_state = tf.zeros([batch_size, lstm.state_size]) -state = hidden_state, current_state +state = lstm.zero_state(batch_size, dtype=tf.float32) probabilities = [] loss = 0.0 for current_batch_of_words in words_in_dataset: @@ -112,7 +110,7 @@ words = tf.placeholder(tf.int32, [batch_size, num_steps]) lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size) # Initial state of the LSTM memory. -initial_state = state = tf.zeros([batch_size, lstm.state_size]) +initial_state = state = lstm.zero_state(batch_size, dtype=tf.float32) for i in range(num_steps): # The value of state is updated after processing each batch of words. diff --git a/tensorflow/docs_src/tutorials/recurrent_quickdraw.md b/tensorflow/docs_src/tutorials/sequences/recurrent_quickdraw.md similarity index 98% rename from tensorflow/docs_src/tutorials/recurrent_quickdraw.md rename to tensorflow/docs_src/tutorials/sequences/recurrent_quickdraw.md index 1afd861738512f20de5171548d539d256f5f5225..37bce5b76d46741dfe04cbf3612f71863adb02c6 100644 --- a/tensorflow/docs_src/tutorials/recurrent_quickdraw.md +++ b/tensorflow/docs_src/tutorials/sequences/recurrent_quickdraw.md @@ -13,7 +13,7 @@ In this tutorial we'll show how to build an RNN-based recognizer for this problem. The model will use a combination of convolutional layers, LSTM layers, and a softmax output layer to classify the drawings: -
![RNN model structure](../images/quickdraw_model.png)
+
![RNN model structure](../../images/quickdraw_model.png)
The figure above shows the structure of the model that we will build in this tutorial. The input is a drawing that is encoded as a sequence of strokes of @@ -208,7 +208,7 @@ This data is then reformatted into a tensor of shape `[num_training_samples, max_length, 3]`. Then we determine the bounding box of the original drawing in screen coordinates and normalize the size such that the drawing has unit height. -
![Size normalization](../images/quickdraw_sizenormalization.png)
+
![Size normalization](../../images/quickdraw_sizenormalization.png)
Finally, we compute the differences between consecutive points and store these as a `VarLenFeature` in a diff --git a/tensorflow/docs_src/tutorials/wide.md b/tensorflow/docs_src/tutorials/wide.md deleted file mode 100644 index 27ce75a30dd2acd5925702611042270e767b0c73..0000000000000000000000000000000000000000 --- a/tensorflow/docs_src/tutorials/wide.md +++ /dev/null @@ -1,461 +0,0 @@ -# TensorFlow Linear Model Tutorial - -In this tutorial, we will use the tf.estimator API in TensorFlow to solve a -binary classification problem: Given census data about a person such as age, -education, marital status, and occupation (the features), we will try to predict -whether or not the person earns more than 50,000 dollars a year (the target -label). We will train a **logistic regression** model, and given an individual's -information our model will output a number between 0 and 1, which can be -interpreted as the probability that the individual has an annual income of over -50,000 dollars. - -## Setup - -To try the code for this tutorial: - -1. @{$install$Install TensorFlow} if you haven't already. - -2. Download [the tutorial code](https://github.com/tensorflow/models/tree/master/official/wide_deep/). - -3. Execute the data download script we provide to you: - - $ python data_download.py - -4. Execute the tutorial code with the following command to train the linear -model described in this tutorial: - - $ python wide_deep.py --model_type=wide - -Read on to find out how this code builds its linear model. - -## Reading The Census Data - -The dataset we'll be using is the -[Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/Census+Income). -We have provided -[data_download.py](https://github.com/tensorflow/models/tree/master/official/wide_deep/data_download.py) -which downloads the code and performs some additional cleanup. - -Since the task is a binary classification problem, we'll construct a label -column named "label" whose value is 1 if the income is over 50K, and 0 -otherwise. For reference, see `input_fn` in -[wide_deep.py](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py). - -Next, let's take a look at the dataframe and see which columns we can use to -predict the target label. The columns can be grouped into two types—categorical -and continuous columns: - -* A column is called **categorical** if its value can only be one of the - categories in a finite set. For example, the relationship status of a person - (wife, husband, unmarried, etc.) or the education level (high school, - college, etc.) are categorical columns. -* A column is called **continuous** if its value can be any numerical value in - a continuous range. For example, the capital gain of a person (e.g. $14,084) - is a continuous column. - -Here's a list of columns available in the Census Income dataset: - -| Column Name | Type | Description | -| -------------- | ----------- | --------------------------------- | -| age | Continuous | The age of the individual | -| workclass | Categorical | The type of employer the | -: : : individual has (government, : -: : : military, private, etc.). : -| fnlwgt | Continuous | The number of people the census | -: : : takers believe that observation : -: : : represents (sample weight). Final : -: : : weight will not be used. : -| education | Categorical | The highest level of education | -: : : achieved for that individual. : -| education_num | Continuous | The highest level of education in | -: : : numerical form. : -| marital_status | Categorical | Marital status of the individual. | -| occupation | Categorical | The occupation of the individual. | -| relationship | Categorical | Wife, Own-child, Husband, | -: : : Not-in-family, Other-relative, : -: : : Unmarried. : -| race | Categorical | Amer-Indian-Eskimo, Asian-Pac- | -: : : Islander, Black, White, Other. : -| gender | Categorical | Female, Male. | -| capital_gain | Continuous | Capital gains recorded. | -| capital_loss | Continuous | Capital Losses recorded. | -| hours_per_week | Continuous | Hours worked per week. | -| native_country | Categorical | Country of origin of the | -: : : individual. : -| income_bracket | Categorical | ">50K" or "<=50K", meaning | -: : : whether the person makes more : -: : : than $50,000 annually. : - -## Converting Data into Tensors - -When building a tf.estimator model, the input data is specified by means of an -Input Builder function. This builder function will not be called until it is -later passed to tf.estimator.Estimator methods such as `train` and `evaluate`. -The purpose of this function is to construct the input data, which is -represented in the form of @{tf.Tensor}s or @{tf.SparseTensor}s. -In more detail, the input builder function returns the following as a pair: - -1. `features`: A dict from feature column names to `Tensors` or - `SparseTensors`. -2. `labels`: A `Tensor` containing the label column. - -The keys of the `features` will be used to construct columns in the next -section. Because we want to call the `train` and `evaluate` methods with -different data, we define a method that returns an input function based on the -given data. Note that the returned input function will be called while -constructing the TensorFlow graph, not while running the graph. What it is -returning is a representation of the input data as the fundamental unit of -TensorFlow computations, a `Tensor` (or `SparseTensor`). - -Each continuous column in the train or test data will be converted into a -`Tensor`, which in general is a good format to represent dense data. For -categorical data, we must represent the data as a `SparseTensor`. This data -format is good for representing sparse data. Our `input_fn` uses the `tf.data` -API, which makes it easy to apply transformations to our dataset: - -```python -def input_fn(data_file, num_epochs, shuffle, batch_size): - """Generate an input function for the Estimator.""" - assert tf.gfile.Exists(data_file), ( - '%s not found. Please make sure you have either run data_download.py or ' - 'set both arguments --train_data and --test_data.' % data_file) - - def parse_csv(value): - print('Parsing', data_file) - columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS) - features = dict(zip(_CSV_COLUMNS, columns)) - labels = features.pop('income_bracket') - return features, tf.equal(labels, '>50K') - - # Extract lines from input files using the Dataset API. - dataset = tf.data.TextLineDataset(data_file) - - if shuffle: - dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER) - - dataset = dataset.map(parse_csv, num_parallel_calls=5) - - # We call repeat after shuffling, rather than before, to prevent separate - # epochs from blending together. - dataset = dataset.repeat(num_epochs) - dataset = dataset.batch(batch_size) - - iterator = dataset.make_one_shot_iterator() - features, labels = iterator.get_next() - return features, labels -``` - -## Selecting and Engineering Features for the Model - -Selecting and crafting the right set of feature columns is key to learning an -effective model. A **feature column** can be either one of the raw columns in -the original dataframe (let's call them **base feature columns**), or any new -columns created based on some transformations defined over one or multiple base -columns (let's call them **derived feature columns**). Basically, "feature -column" is an abstract concept of any raw or derived variable that can be used -to predict the target label. - -### Base Categorical Feature Columns - -To define a feature column for a categorical feature, we can create a -`CategoricalColumn` using the tf.feature_column API. If you know the set of all -possible feature values of a column and there are only a few of them, you can -use `categorical_column_with_vocabulary_list`. Each key in the list will get -assigned an auto-incremental ID starting from 0. For example, for the -`relationship` column we can assign the feature string "Husband" to an integer -ID of 0 and "Not-in-family" to 1, etc., by doing: - -```python -relationship = tf.feature_column.categorical_column_with_vocabulary_list( - 'relationship', [ - 'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried', - 'Other-relative']) -``` - -What if we don't know the set of possible values in advance? Not a problem. We -can use `categorical_column_with_hash_bucket` instead: - -```python -occupation = tf.feature_column.categorical_column_with_hash_bucket( - 'occupation', hash_bucket_size=1000) -``` - -What will happen is that each possible value in the feature column `occupation` -will be hashed to an integer ID as we encounter them in training. See an example -illustration below: - -ID | Feature ---- | ------------- -... | -9 | `"Machine-op-inspct"` -... | -103 | `"Farming-fishing"` -... | -375 | `"Protective-serv"` -... | - -No matter which way we choose to define a `SparseColumn`, each feature string -will be mapped into an integer ID by looking up a fixed mapping or by hashing. -Note that hashing collisions are possible, but may not significantly impact the -model quality. Under the hood, the `LinearModel` class is responsible for -managing the mapping and creating `tf.Variable` to store the model parameters -(also known as model weights) for each feature ID. The model parameters will be -learned through the model training process we'll go through later. - -We'll do the similar trick to define the other categorical features: - -```python -education = tf.feature_column.categorical_column_with_vocabulary_list( - 'education', [ - 'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college', - 'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school', - '5th-6th', '10th', '1st-4th', 'Preschool', '12th']) - -marital_status = tf.feature_column.categorical_column_with_vocabulary_list( - 'marital_status', [ - 'Married-civ-spouse', 'Divorced', 'Married-spouse-absent', - 'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed']) - -relationship = tf.feature_column.categorical_column_with_vocabulary_list( - 'relationship', [ - 'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried', - 'Other-relative']) - -workclass = tf.feature_column.categorical_column_with_vocabulary_list( - 'workclass', [ - 'Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov', - 'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked']) - -# To show an example of hashing: -occupation = tf.feature_column.categorical_column_with_hash_bucket( - 'occupation', hash_bucket_size=1000) -``` - -### Base Continuous Feature Columns - -Similarly, we can define a `NumericColumn` for each continuous feature column -that we want to use in the model: - -```python -age = tf.feature_column.numeric_column('age') -education_num = tf.feature_column.numeric_column('education_num') -capital_gain = tf.feature_column.numeric_column('capital_gain') -capital_loss = tf.feature_column.numeric_column('capital_loss') -hours_per_week = tf.feature_column.numeric_column('hours_per_week') -``` - -### Making Continuous Features Categorical through Bucketization - -Sometimes the relationship between a continuous feature and the label is not -linear. As a hypothetical example, a person's income may grow with age in the -early stage of one's career, then the growth may slow at some point, and finally -the income decreases after retirement. In this scenario, using the raw `age` as -a real-valued feature column might not be a good choice because the model can -only learn one of the three cases: - -1. Income always increases at some rate as age grows (positive correlation), -1. Income always decreases at some rate as age grows (negative correlation), or -1. Income stays the same no matter at what age (no correlation) - -If we want to learn the fine-grained correlation between income and each age -group separately, we can leverage **bucketization**. Bucketization is a process -of dividing the entire range of a continuous feature into a set of consecutive -bins/buckets, and then converting the original numerical feature into a bucket -ID (as a categorical feature) depending on which bucket that value falls into. -So, we can define a `bucketized_column` over `age` as: - -```python -age_buckets = tf.feature_column.bucketized_column( - age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65]) -``` - -where the `boundaries` is a list of bucket boundaries. In this case, there are -10 boundaries, resulting in 11 age group buckets (from age 17 and below, 18-24, -25-29, ..., to 65 and over). - -### Intersecting Multiple Columns with CrossedColumn - -Using each base feature column separately may not be enough to explain the data. -For example, the correlation between education and the label (earning > 50,000 -dollars) may be different for different occupations. Therefore, if we only learn -a single model weight for `education="Bachelors"` and `education="Masters"`, we -won't be able to capture every single education-occupation combination (e.g. -distinguishing between `education="Bachelors" AND occupation="Exec-managerial"` -and `education="Bachelors" AND occupation="Craft-repair"`). To learn the -differences between different feature combinations, we can add **crossed feature -columns** to the model. - -```python -education_x_occupation = tf.feature_column.crossed_column( - ['education', 'occupation'], hash_bucket_size=1000) -``` - -We can also create a `CrossedColumn` over more than two columns. Each -constituent column can be either a base feature column that is categorical -(`SparseColumn`), a bucketized real-valued feature column (`BucketizedColumn`), -or even another `CrossColumn`. Here's an example: - -```python -age_buckets_x_education_x_occupation = tf.feature_column.crossed_column( - [age_buckets, 'education', 'occupation'], hash_bucket_size=1000) -``` - -## Defining The Logistic Regression Model - -After processing the input data and defining all the feature columns, we're now -ready to put them all together and build a Logistic Regression model. In the -previous section we've seen several types of base and derived feature columns, -including: - -* `CategoricalColumn` -* `NumericColumn` -* `BucketizedColumn` -* `CrossedColumn` - -All of these are subclasses of the abstract `FeatureColumn` class, and can be -added to the `feature_columns` field of a model: - -```python -base_columns = [ - education, marital_status, relationship, workclass, occupation, - age_buckets, -] -crossed_columns = [ - tf.feature_column.crossed_column( - ['education', 'occupation'], hash_bucket_size=1000), - tf.feature_column.crossed_column( - [age_buckets, 'education', 'occupation'], hash_bucket_size=1000), -] - -model_dir = tempfile.mkdtemp() -model = tf.estimator.LinearClassifier( - model_dir=model_dir, feature_columns=base_columns + crossed_columns) -``` - -The model also automatically learns a bias term, which controls the prediction -one would make without observing any features (see the section "How Logistic -Regression Works" for more explanations). The learned model files will be stored -in `model_dir`. - -## Training and Evaluating Our Model - -After adding all the features to the model, now let's look at how to actually -train the model. Training a model is just a single command using the -tf.estimator API: - -```python -model.train(input_fn=lambda: input_fn(train_data, num_epochs, True, batch_size)) -``` - -After the model is trained, we can evaluate how good our model is at predicting -the labels of the holdout data: - -```python -results = model.evaluate(input_fn=lambda: input_fn( - test_data, 1, False, batch_size)) -for key in sorted(results): - print('%s: %s' % (key, results[key])) -``` - -The first line of the final output should be something like -`accuracy: 0.83557522`, which means the accuracy is 83.6%. Feel free to try more -features and transformations and see if you can do even better! - -After the model is evaluated, we can use the model to predict whether an individual has an annual income of over -50,000 dollars given an individual's information input. -```python - pred_iter = model.predict(input_fn=lambda: input_fn(FLAGS.test_data, 1, False, 1)) - for pred in pred_iter: - print(pred['classes']) -``` - -The model prediction output would be like `[b'1']` or `[b'0']` which means whether corresponding individual has an annual income of over 50,000 dollars or not. - -If you'd like to see a working end-to-end example, you can download our -[example code](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py) -and set the `model_type` flag to `wide`. - -## Adding Regularization to Prevent Overfitting - -Regularization is a technique used to avoid **overfitting**. Overfitting happens -when your model does well on the data it is trained on, but worse on test data -that the model has not seen before, such as live traffic. Overfitting generally -occurs when a model is excessively complex, such as having too many parameters -relative to the number of observed training data. Regularization allows for you -to control your model's complexity and makes the model more generalizable to -unseen data. - -In the Linear Model library, you can add L1 and L2 regularizations to the model -as: - -``` -model = tf.estimator.LinearClassifier( - model_dir=model_dir, feature_columns=base_columns + crossed_columns, - optimizer=tf.train.FtrlOptimizer( - learning_rate=0.1, - l1_regularization_strength=1.0, - l2_regularization_strength=1.0)) -``` - -One important difference between L1 and L2 regularization is that L1 -regularization tends to make model weights stay at zero, creating sparser -models, whereas L2 regularization also tries to make the model weights closer to -zero but not necessarily zero. Therefore, if you increase the strength of L1 -regularization, you will have a smaller model size because many of the model -weights will be zero. This is often desirable when the feature space is very -large but sparse, and when there are resource constraints that prevent you from -serving a model that is too large. - -In practice, you should try various combinations of L1, L2 regularization -strengths and find the best parameters that best control overfitting and give -you a desirable model size. - -## How Logistic Regression Works - -Finally, let's take a minute to talk about what the Logistic Regression model -actually looks like in case you're not already familiar with it. We'll denote -the label as \\(Y\\), and the set of observed features as a feature vector -\\(\mathbf{x}=[x_1, x_2, ..., x_d]\\). We define \\(Y=1\\) if an individual -earned > 50,000 dollars and \\(Y=0\\) otherwise. In Logistic Regression, the -probability of the label being positive (\\(Y=1\\)) given the features -\\(\mathbf{x}\\) is given as: - -$$ P(Y=1|\mathbf{x}) = \frac{1}{1+\exp(-(\mathbf{w}^T\mathbf{x}+b))}$$ - -where \\(\mathbf{w}=[w_1, w_2, ..., w_d]\\) are the model weights for the -features \\(\mathbf{x}=[x_1, x_2, ..., x_d]\\). \\(b\\) is a constant that is -often called the **bias** of the model. The equation consists of two parts—A -linear model and a logistic function: - -* **Linear Model**: First, we can see that \\(\mathbf{w}^T\mathbf{x}+b = b + - w_1x_1 + ... +w_dx_d\\) is a linear model where the output is a linear - function of the input features \\(\mathbf{x}\\). The bias \\(b\\) is the - prediction one would make without observing any features. The model weight - \\(w_i\\) reflects how the feature \\(x_i\\) is correlated with the positive - label. If \\(x_i\\) is positively correlated with the positive label, the - weight \\(w_i\\) increases, and the probability \\(P(Y=1|\mathbf{x})\\) will - be closer to 1. On the other hand, if \\(x_i\\) is negatively correlated - with the positive label, then the weight \\(w_i\\) decreases and the - probability \\(P(Y=1|\mathbf{x})\\) will be closer to 0. - -* **Logistic Function**: Second, we can see that there's a logistic function - (also known as the sigmoid function) \\(S(t) = 1/(1+\exp(-t))\\) being - applied to the linear model. The logistic function is used to convert the - output of the linear model \\(\mathbf{w}^T\mathbf{x}+b\\) from any real - number into the range of \\([0, 1]\\), which can be interpreted as a - probability. - -Model training is an optimization problem: The goal is to find a set of model -weights (i.e. model parameters) to minimize a **loss function** defined over the -training data, such as logistic loss for Logistic Regression models. The loss -function measures the discrepancy between the ground-truth label and the model's -prediction. If the prediction is very close to the ground-truth label, the loss -value will be low; if the prediction is very far from the label, then the loss -value would be high. - -## Learn Deeper - -If you're interested in learning more, check out our -@{$wide_and_deep$Wide & Deep Learning Tutorial} where we'll show you how to -combine the strengths of linear models and deep neural networks by jointly -training them using the tf.estimator API. diff --git a/tensorflow/docs_src/tutorials/wide_and_deep.md b/tensorflow/docs_src/tutorials/wide_and_deep.md deleted file mode 100644 index 44677a810bc5c253c198d81fae2be723c4f8ae4e..0000000000000000000000000000000000000000 --- a/tensorflow/docs_src/tutorials/wide_and_deep.md +++ /dev/null @@ -1,243 +0,0 @@ -# TensorFlow Wide & Deep Learning Tutorial - -In the previous @{$wide$TensorFlow Linear Model Tutorial}, we trained a logistic -regression model to predict the probability that the individual has an annual -income of over 50,000 dollars using the -[Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/Census+Income). -TensorFlow is great for training deep neural networks too, and you might be -thinking which one you should choose—well, why not both? Would it be possible to -combine the strengths of both in one model? - -In this tutorial, we'll introduce how to use the tf.estimator API to jointly -train a wide linear model and a deep feed-forward neural network. This approach -combines the strengths of memorization and generalization. It's useful for -generic large-scale regression and classification problems with sparse input -features (e.g., categorical features with a large number of possible feature -values). If you're interested in learning more about how Wide & Deep Learning -works, please check out our [research paper](https://arxiv.org/abs/1606.07792). - -![Wide & Deep Spectrum of Models](https://www.tensorflow.org/images/wide_n_deep.svg "Wide & Deep") - -The figure above shows a comparison of a wide model (logistic regression with -sparse features and transformations), a deep model (feed-forward neural network -with an embedding layer and several hidden layers), and a Wide & Deep model -(joint training of both). At a high level, there are only 3 steps to configure a -wide, deep, or Wide & Deep model using the tf.estimator API: - -1. Select features for the wide part: Choose the sparse base columns and - crossed columns you want to use. -1. Select features for the deep part: Choose the continuous columns, the - embedding dimension for each categorical column, and the hidden layer sizes. -1. Put them all together in a Wide & Deep model - (`DNNLinearCombinedClassifier`). - -And that's it! Let's go through a simple example. - -## Setup - -To try the code for this tutorial: - -1. @{$install$Install TensorFlow} if you haven't already. - -2. Download [the tutorial code](https://github.com/tensorflow/models/tree/master/official/wide_deep/). - -3. Execute the data download script we provide to you: - - $ python data_download.py - -4. Execute the tutorial code with the following command to train the wide and -deep model described in this tutorial: - - $ python wide_deep.py - -Read on to find out how this code builds its model. - - -## Define Base Feature Columns - -First, let's define the base categorical and continuous feature columns that -we'll use. These base columns will be the building blocks used by both the wide -part and the deep part of the model. - -```python -import tensorflow as tf - -# Continuous columns -age = tf.feature_column.numeric_column('age') -education_num = tf.feature_column.numeric_column('education_num') -capital_gain = tf.feature_column.numeric_column('capital_gain') -capital_loss = tf.feature_column.numeric_column('capital_loss') -hours_per_week = tf.feature_column.numeric_column('hours_per_week') - -education = tf.feature_column.categorical_column_with_vocabulary_list( - 'education', [ - 'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college', - 'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school', - '5th-6th', '10th', '1st-4th', 'Preschool', '12th']) - -marital_status = tf.feature_column.categorical_column_with_vocabulary_list( - 'marital_status', [ - 'Married-civ-spouse', 'Divorced', 'Married-spouse-absent', - 'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed']) - -relationship = tf.feature_column.categorical_column_with_vocabulary_list( - 'relationship', [ - 'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried', - 'Other-relative']) - -workclass = tf.feature_column.categorical_column_with_vocabulary_list( - 'workclass', [ - 'Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov', - 'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked']) - -# To show an example of hashing: -occupation = tf.feature_column.categorical_column_with_hash_bucket( - 'occupation', hash_bucket_size=1000) - -# Transformations. -age_buckets = tf.feature_column.bucketized_column( - age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65]) -``` - -## The Wide Model: Linear Model with Crossed Feature Columns - -The wide model is a linear model with a wide set of sparse and crossed feature -columns: - -```python -base_columns = [ - education, marital_status, relationship, workclass, occupation, - age_buckets, -] - -crossed_columns = [ - tf.feature_column.crossed_column( - ['education', 'occupation'], hash_bucket_size=1000), - tf.feature_column.crossed_column( - [age_buckets, 'education', 'occupation'], hash_bucket_size=1000), -] -``` - -You can also see the @{$wide$TensorFlow Linear Model Tutorial} for more details. - -Wide models with crossed feature columns can memorize sparse interactions -between features effectively. That being said, one limitation of crossed feature -columns is that they do not generalize to feature combinations that have not -appeared in the training data. Let's add a deep model with embeddings to fix -that. - -## The Deep Model: Neural Network with Embeddings - -The deep model is a feed-forward neural network, as shown in the previous -figure. Each of the sparse, high-dimensional categorical features are first -converted into a low-dimensional and dense real-valued vector, often referred to -as an embedding vector. These low-dimensional dense embedding vectors are -concatenated with the continuous features, and then fed into the hidden layers -of a neural network in the forward pass. The embedding values are initialized -randomly, and are trained along with all other model parameters to minimize the -training loss. If you're interested in learning more about embeddings, check out -the TensorFlow tutorial on @{$word2vec$Vector Representations of Words} or -[Word embedding](https://en.wikipedia.org/wiki/Word_embedding) on Wikipedia. - -Another way to represent categorical columns to feed into a neural network is -via a one-hot or multi-hot representation. This is often appropriate for -categorical columns with only a few possible values. As an example of a one-hot -representation, for the relationship column, `"Husband"` can be represented as -[1, 0, 0, 0, 0, 0], and `"Not-in-family"` as [0, 1, 0, 0, 0, 0], etc. This is a -fixed representation, whereas embeddings are more flexible and calculated at -training time. - -We'll configure the embeddings for the categorical columns using -`embedding_column`, and concatenate them with the continuous columns. -We also use `indicator_column` to create multi-hot representations of some -categorical columns. - -```python -deep_columns = [ - age, - education_num, - capital_gain, - capital_loss, - hours_per_week, - tf.feature_column.indicator_column(workclass), - tf.feature_column.indicator_column(education), - tf.feature_column.indicator_column(marital_status), - tf.feature_column.indicator_column(relationship), - # To show an example of embedding - tf.feature_column.embedding_column(occupation, dimension=8), -] -``` - -The higher the `dimension` of the embedding is, the more degrees of freedom the -model will have to learn the representations of the features. For simplicity, we -set the dimension to 8 for all feature columns here. Empirically, a more -informed decision for the number of dimensions is to start with a value on the -order of \\(\log_2(n)\\) or \\(k\sqrt[4]n\\), where \\(n\\) is the number of -unique features in a feature column and \\(k\\) is a small constant (usually -smaller than 10). - -Through dense embeddings, deep models can generalize better and make predictions -on feature pairs that were previously unseen in the training data. However, it -is difficult to learn effective low-dimensional representations for feature -columns when the underlying interaction matrix between two feature columns is -sparse and high-rank. In such cases, the interaction between most feature pairs -should be zero except a few, but dense embeddings will lead to nonzero -predictions for all feature pairs, and thus can over-generalize. On the other -hand, linear models with crossed features can memorize these “exception rules” -effectively with fewer model parameters. - -Now, let's see how to jointly train wide and deep models and allow them to -complement each other’s strengths and weaknesses. - -## Combining Wide and Deep Models into One - -The wide models and deep models are combined by summing up their final output -log odds as the prediction, then feeding the prediction to a logistic loss -function. All the graph definition and variable allocations have already been -handled for you under the hood, so you simply need to create a -`DNNLinearCombinedClassifier`: - -```python -model = tf.estimator.DNNLinearCombinedClassifier( - model_dir='/tmp/census_model', - linear_feature_columns=base_columns + crossed_columns, - dnn_feature_columns=deep_columns, - dnn_hidden_units=[100, 50]) -``` - -## Training and Evaluating The Model - -Before we train the model, let's read in the Census dataset as we did in the -@{$wide$TensorFlow Linear Model tutorial}. See `data_download.py` as well as -`input_fn` within -[`wide_deep.py`](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py). - -After reading in the data, you can train and evaluate the model: - -```python -# Train and evaluate the model every `FLAGS.epochs_per_eval` epochs. -for n in range(FLAGS.train_epochs // FLAGS.epochs_per_eval): - model.train(input_fn=lambda: input_fn( - FLAGS.train_data, FLAGS.epochs_per_eval, True, FLAGS.batch_size)) - - results = model.evaluate(input_fn=lambda: input_fn( - FLAGS.test_data, 1, False, FLAGS.batch_size)) - - # Display evaluation metrics - print('Results at epoch', (n + 1) * FLAGS.epochs_per_eval) - print('-' * 30) - - for key in sorted(results): - print('%s: %s' % (key, results[key])) -``` - -The final output accuracy should be somewhere around 85.5%. If you'd like to -see a working end-to-end example, you can download our -[example code](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py). - -Note that this tutorial is just a quick example on a small dataset to get you -familiar with the API. Wide & Deep Learning will be even more powerful if you -try it on a large dataset with many sparse feature columns that have a large -number of possible feature values. Again, feel free to take a look at our -[research paper](https://arxiv.org/abs/1606.07792) for more ideas about how to -apply Wide & Deep Learning in real-world large-scale machine learning problems. diff --git a/tensorflow/examples/android/.gitignore b/tensorflow/examples/android/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d245ab61095a6f9b6d2077aac934f9b13e66d85e --- /dev/null +++ b/tensorflow/examples/android/.gitignore @@ -0,0 +1,29 @@ +# This file is based on https://github.com/github/gitignore/blob/master/Android.gitignore +*.iml +.idea/compiler.xml +.idea/copyright +.idea/dictionaries +.idea/gradle.xml +.idea/libraries +.idea/inspectionProfiles +.idea/misc.xml +.idea/modules.xml +.idea/runConfigurations.xml +.idea/tasks.xml +.idea/workspace.xml +.gradle +local.properties +.DS_Store +build/ +gradleBuild/ +*.apk +*.ap_ +*.dex +*.class +bin/ +gen/ +out/ +*.log +.navigation/ +/captures +.externalNativeBuild diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD index 07f096418f53219c9ec7000a4560d78a3ff609e1..f327b645f58f35cedd27baa8ab521e334c8e7b15 100644 --- a/tensorflow/examples/android/BUILD +++ b/tensorflow/examples/android/BUILD @@ -1,6 +1,8 @@ # Description: # TensorFlow camera demo app for Android. +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md index 30a26d13c5734c5cf4a3b565c793db3e093c8271..dac9b7ab82c97d4d694374fea82d4d6fda85e0a0 100644 --- a/tensorflow/examples/android/README.md +++ b/tensorflow/examples/android/README.md @@ -45,11 +45,7 @@ on API >= 14 devices. ## Prebuilt Components: -If you just want the fastest path to trying the demo, you may download the -nightly build -[here](https://ci.tensorflow.org/view/Nightly/job/nightly-android/). Expand the -"View" and then the "out" folders under "Last Successful Artifacts" to find -tensorflow_demo.apk. +The fastest path to trying the demo is to download the [prebuilt demo APK](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk). Also available are precompiled native libraries, and a jcenter package that you may simply drop into your own applications. See @@ -113,8 +109,7 @@ protobuf compilation. NOTE: Bazel does not currently support building for Android on Windows. Full support for gradle/cmake builds is coming soon, but in the meantime we suggest -that Windows users download the [prebuilt -binaries](https://ci.tensorflow.org/view/Nightly/job/nightly-android/) instead. +that Windows users download the [prebuilt demo APK](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) instead. ##### Install Bazel and Android Prerequisites diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java index 614d3c7dd7766bb6eb7cd83deb85064d9522cbe5..9739e580185b316b3cc509e815ac05a28a267b29 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java @@ -137,7 +137,7 @@ public class TensorFlowObjectDetectionAPIModel implements Classifier { Trace.beginSection("recognizeImage"); Trace.beginSection("preprocessBitmap"); - // Preprocess the image data from 0-255 int to normalized float based + // Preprocess the image data to extract R, G and B bytes from int of form 0x00RRGGBB // on the provided parameters. bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py index 307eede5c03780e9244b035f020fc7846290d4d9..740224744860fdd76bea9c4531242a4976b20784 100644 --- a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py +++ b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py @@ -17,7 +17,7 @@ This version is like fully_connected_feed.py but uses data converted to a TFRecords file containing tf.train.Example protocol buffers. See: -https://www.tensorflow.org/programmers_guide/reading_data#reading_from_files +https://www.tensorflow.org/guide/reading_data#reading_from_files for context. YOU MUST run convert_to_records before running this (but you only need to diff --git a/tensorflow/examples/ios/README.md b/tensorflow/examples/ios/README.md index 5d7bd36837b2a2c33ab4bc311a582c174666dcd5..64412d25a00f55543f011b4ae3aaa85f03894ab5 100644 --- a/tensorflow/examples/ios/README.md +++ b/tensorflow/examples/ios/README.md @@ -190,8 +190,5 @@ increase you see in your own app is similar, and if it's larger, look at the "Other Linker Flags" used in the Simple Xcode project settings to strip the executable. -After that, you can manually look at modifying the list of kernels -included in tensorflow/contrib/makefile/tf_op_files.txt to reduce the number of -implementations to the ones you're actually using in your own model. We're -hoping to automate this step in the future, but for now manually removing them -is the best approach. +For further optimization, please refer to the ["Optimization" section](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/makefile#optimization) +of the makefile instructions. diff --git a/tensorflow/examples/saved_model/saved_model_half_plus_two.py b/tensorflow/examples/saved_model/saved_model_half_plus_two.py index 0d6f1ef655bcaba43c0d68e1e924bcb4b29967af..2d1e0c6f6de88ae116fe1951ca24505d41743fa9 100644 --- a/tensorflow/examples/saved_model/saved_model_half_plus_two.py +++ b/tensorflow/examples/saved_model/saved_model_half_plus_two.py @@ -33,6 +33,13 @@ where `a`, `b` and `c` are variables with `a=0.5` and `b=2` and `c=3`. Output from this program is typically used to exercise SavedModel load and execution code. + +To create a CPU model: + bazel run -c opt saved_half_plus_two -- --device=cpu + +To create GPU model: + bazel run --config=cuda -c opt saved_half_plus_two -- \ + --device=gpu """ from __future__ import absolute_import @@ -105,42 +112,52 @@ def _build_classification_signature(input_tensor, scores_tensor): def _generate_saved_model_for_half_plus_two(export_dir, as_text=False, - use_main_op=False): + use_main_op=False, + device_type="cpu"): """Generates SavedModel for half plus two. Args: export_dir: The directory to which the SavedModel should be written. as_text: Writes the SavedModel protocol buffer in text format to disk. use_main_op: Whether to supply a main op during SavedModel build time. + device_name: Device to force ops to run on. """ builder = tf.saved_model.builder.SavedModelBuilder(export_dir) - with tf.Session(graph=tf.Graph()) as sess: - # Set up the model parameters as variables to exercise variable loading - # functionality upon restore. - a = tf.Variable(0.5, name="a") - b = tf.Variable(2.0, name="b") - c = tf.Variable(3.0, name="c") - - # Create a placeholder for serialized tensorflow.Example messages to be fed. - serialized_tf_example = tf.placeholder(tf.string, name="tf_example") - - # Parse the tensorflow.Example looking for a feature named "x" with a single - # floating point value. - feature_configs = { - "x": tf.FixedLenFeature( - [1], dtype=tf.float32), - "x2": tf.FixedLenFeature( - [1], dtype=tf.float32, default_value=[0.0]) - } - tf_example = tf.parse_example(serialized_tf_example, feature_configs) - # Use tf.identity() to assign name - x = tf.identity(tf_example["x"], name="x") - y = tf.add(tf.multiply(a, x), b, name="y") - y2 = tf.add(tf.multiply(a, x), c, name="y2") - - x2 = tf.identity(tf_example["x2"], name="x2") - y3 = tf.add(tf.multiply(a, x2), c, name="y3") + device_name = "/cpu:0" + if device_type == "gpu": + device_name = "/gpu:0" + + with tf.Session( + graph=tf.Graph(), + config=tf.ConfigProto(log_device_placement=True)) as sess: + with tf.device(device_name): + # Set up the model parameters as variables to exercise variable loading + # functionality upon restore. + a = tf.Variable(0.5, name="a") + b = tf.Variable(2.0, name="b") + c = tf.Variable(3.0, name="c") + + # Create a placeholder for serialized tensorflow.Example messages to be + # fed. + serialized_tf_example = tf.placeholder(tf.string, name="tf_example") + + # Parse the tensorflow.Example looking for a feature named "x" with a + # single floating point value. + feature_configs = { + "x": tf.FixedLenFeature([1], dtype=tf.float32), + "x2": tf.FixedLenFeature([1], dtype=tf.float32, default_value=[0.0]) + } + # parse_example only works on CPU + with tf.device("/cpu:0"): + tf_example = tf.parse_example(serialized_tf_example, feature_configs) + # Use tf.identity() to assign name + x = tf.identity(tf_example["x"], name="x") + y = tf.add(tf.multiply(a, x), b, name="y") + y2 = tf.add(tf.multiply(a, x), c, name="y2") + + x2 = tf.identity(tf_example["x2"], name="x2") + y3 = tf.add(tf.multiply(a, x2), c, name="y3") # Create an assets file that can be saved and restored as part of the # SavedModel. @@ -185,20 +202,7 @@ def _generate_saved_model_for_half_plus_two(export_dir, } # Initialize all variables and then save the SavedModel. sess.run(tf.global_variables_initializer()) - signature_def_map = { - "regress_x_to_y": - _build_regression_signature(serialized_tf_example, y), - "regress_x_to_y2": - _build_regression_signature(serialized_tf_example, y2), - "regress_x2_to_y3": - _build_regression_signature(x2, y3), - "classify_x_to_y": - _build_classification_signature(serialized_tf_example, y), - "classify_x2_to_y3": - _build_classification_signature(x2, y3), - tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: - predict_signature_def - } + if use_main_op: builder.add_meta_graph_and_variables( sess, [tf.saved_model.tag_constants.SERVING], @@ -212,19 +216,30 @@ def _generate_saved_model_for_half_plus_two(export_dir, signature_def_map=signature_def_map, assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS), legacy_init_op=tf.group(assign_filename_op)) - builder.save(as_text) + builder.save(as_text) def main(_): - _generate_saved_model_for_half_plus_two(FLAGS.output_dir) - print("SavedModel generated at: %s" % FLAGS.output_dir) + _generate_saved_model_for_half_plus_two( + FLAGS.output_dir, device_type=FLAGS.device) + print("SavedModel generated for %(device)s at: %(dir)s" % { + "device": FLAGS.device, + "dir": FLAGS.output_dir + }) - _generate_saved_model_for_half_plus_two(FLAGS.output_dir_pbtxt, as_text=True) - print("SavedModel generated at: %s" % FLAGS.output_dir_pbtxt) + _generate_saved_model_for_half_plus_two( + FLAGS.output_dir_pbtxt, as_text=True, device_type=FLAGS.device) + print("SavedModel generated for %(device)s at: %(dir)s" % { + "device": FLAGS.device, + "dir": FLAGS.output_dir_pbtxt + }) _generate_saved_model_for_half_plus_two( - FLAGS.output_dir_main_op, use_main_op=True) - print("SavedModel generated at: %s" % FLAGS.output_dir_main_op) + FLAGS.output_dir_main_op, use_main_op=True, device_type=FLAGS.device) + print("SavedModel generated for %(device)s at: %(dir)s " % { + "device": FLAGS.device, + "dir": FLAGS.output_dir_main_op + }) if __name__ == "__main__": @@ -244,5 +259,10 @@ if __name__ == "__main__": type=str, default="/tmp/saved_model_half_plus_two_main_op", help="Directory where to output the SavedModel with a main op.") + parser.add_argument( + "--device", + type=str, + default="cpu", + help="Force model to run on 'cpu' or 'gpu'") FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/examples/speech_commands/BUILD b/tensorflow/examples/speech_commands/BUILD index 13bca34a86b0c2fba7e5e8e3527d13587feacaae..7a44e2ee4fdf690ce576f720bb371785f88779b4 100644 --- a/tensorflow/examples/speech_commands/BUILD +++ b/tensorflow/examples/speech_commands/BUILD @@ -56,6 +56,7 @@ tf_py_test( srcs = ["input_data_test.py"], additional_deps = [ ":input_data", + ":models", "//tensorflow/python:client_testlib", ], ) diff --git a/tensorflow/examples/speech_commands/freeze.py b/tensorflow/examples/speech_commands/freeze.py index c8671d9c41169c07ce3134a49bf81a4ac29a8c60..89e790d4e4436cdc49af0fb2ae53dea8485ae9c5 100644 --- a/tensorflow/examples/speech_commands/freeze.py +++ b/tensorflow/examples/speech_commands/freeze.py @@ -54,7 +54,7 @@ FLAGS = None def create_inference_graph(wanted_words, sample_rate, clip_duration_ms, clip_stride_ms, window_size_ms, window_stride_ms, - dct_coefficient_count, model_architecture): + feature_bin_count, model_architecture, preprocess): """Creates an audio model with the nodes needed for inference. Uses the supplied arguments to create a model, and inserts the input and @@ -67,14 +67,19 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms, clip_stride_ms: How often to run recognition. Useful for models with cache. window_size_ms: Time slice duration to estimate frequencies from. window_stride_ms: How far apart time slices should be. - dct_coefficient_count: Number of frequency bands to analyze. + feature_bin_count: Number of frequency bands to analyze. model_architecture: Name of the kind of model to generate. + preprocess: How the spectrogram is processed to produce features, for + example 'mfcc' or 'average'. + + Raises: + Exception: If the preprocessing mode isn't recognized. """ words_list = input_data.prepare_words_list(wanted_words.split(',')) model_settings = models.prepare_model_settings( len(words_list), sample_rate, clip_duration_ms, window_size_ms, - window_stride_ms, dct_coefficient_count) + window_stride_ms, feature_bin_count, preprocess) runtime_settings = {'clip_stride_ms': clip_stride_ms} wav_data_placeholder = tf.placeholder(tf.string, [], name='wav_data') @@ -88,15 +93,25 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms, window_size=model_settings['window_size_samples'], stride=model_settings['window_stride_samples'], magnitude_squared=True) - fingerprint_input = contrib_audio.mfcc( - spectrogram, - decoded_sample_data.sample_rate, - dct_coefficient_count=dct_coefficient_count) - fingerprint_frequency_size = model_settings['dct_coefficient_count'] - fingerprint_time_size = model_settings['spectrogram_length'] - reshaped_input = tf.reshape(fingerprint_input, [ - -1, fingerprint_time_size * fingerprint_frequency_size - ]) + + if preprocess == 'average': + fingerprint_input = tf.nn.pool( + tf.expand_dims(spectrogram, -1), + window_shape=[1, model_settings['average_window_width']], + strides=[1, model_settings['average_window_width']], + pooling_type='AVG', + padding='SAME') + elif preprocess == 'mfcc': + fingerprint_input = contrib_audio.mfcc( + spectrogram, + sample_rate, + dct_coefficient_count=model_settings['fingerprint_width']) + else: + raise Exception('Unknown preprocess mode "%s" (should be "mfcc" or' + ' "average")' % (preprocess)) + + fingerprint_size = model_settings['fingerprint_size'] + reshaped_input = tf.reshape(fingerprint_input, [-1, fingerprint_size]) logits = models.create_model( reshaped_input, model_settings, model_architecture, is_training=False, @@ -110,10 +125,12 @@ def main(_): # Create the model and load its weights. sess = tf.InteractiveSession() - create_inference_graph(FLAGS.wanted_words, FLAGS.sample_rate, - FLAGS.clip_duration_ms, FLAGS.clip_stride_ms, - FLAGS.window_size_ms, FLAGS.window_stride_ms, - FLAGS.dct_coefficient_count, FLAGS.model_architecture) + create_inference_graph( + FLAGS.wanted_words, FLAGS.sample_rate, FLAGS.clip_duration_ms, + FLAGS.clip_stride_ms, FLAGS.window_size_ms, FLAGS.window_stride_ms, + FLAGS.feature_bin_count, FLAGS.model_architecture, FLAGS.preprocess) + if FLAGS.quantize: + tf.contrib.quantize.create_eval_graph() models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint) # Turn all the variables into inline constants inside the graph and save it. @@ -155,10 +172,11 @@ if __name__ == '__main__': default=10.0, help='How long the stride is between spectrogram timeslices',) parser.add_argument( - '--dct_coefficient_count', + '--feature_bin_count', type=int, default=40, - help='How many bins to use for the MFCC fingerprint',) + help='How many bins to use for the MFCC fingerprint', + ) parser.add_argument( '--start_checkpoint', type=str, @@ -176,5 +194,15 @@ if __name__ == '__main__': help='Words to use (others will be added to an unknown label)',) parser.add_argument( '--output_file', type=str, help='Where to save the frozen graph.') + parser.add_argument( + '--quantize', + type=bool, + default=False, + help='Whether to train the model for eight-bit deployment') + parser.add_argument( + '--preprocess', + type=str, + default='mfcc', + help='Spectrogram processing mode. Can be "mfcc" or "average"') FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/examples/speech_commands/freeze_test.py b/tensorflow/examples/speech_commands/freeze_test.py index 97c6eac675f696d89d069258edf6eec901cfad0b..c8de6c2152909cd6dfca9acc895c25b0ae8e09ca 100644 --- a/tensorflow/examples/speech_commands/freeze_test.py +++ b/tensorflow/examples/speech_commands/freeze_test.py @@ -24,14 +24,62 @@ from tensorflow.python.platform import test class FreezeTest(test.TestCase): - def testCreateInferenceGraph(self): + def testCreateInferenceGraphWithMfcc(self): with self.test_session() as sess: - freeze.create_inference_graph('a,b,c,d', 16000, 1000.0, 30.0, 30.0, 10.0, - 40, 'conv') + freeze.create_inference_graph( + wanted_words='a,b,c,d', + sample_rate=16000, + clip_duration_ms=1000.0, + clip_stride_ms=30.0, + window_size_ms=30.0, + window_stride_ms=10.0, + feature_bin_count=40, + model_architecture='conv', + preprocess='mfcc') self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0')) self.assertIsNotNone( sess.graph.get_tensor_by_name('decoded_sample_data:0')) self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0')) + ops = [node.op for node in sess.graph_def.node] + self.assertEqual(1, ops.count('Mfcc')) + + def testCreateInferenceGraphWithoutMfcc(self): + with self.test_session() as sess: + freeze.create_inference_graph( + wanted_words='a,b,c,d', + sample_rate=16000, + clip_duration_ms=1000.0, + clip_stride_ms=30.0, + window_size_ms=30.0, + window_stride_ms=10.0, + feature_bin_count=40, + model_architecture='conv', + preprocess='average') + self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0')) + self.assertIsNotNone( + sess.graph.get_tensor_by_name('decoded_sample_data:0')) + self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0')) + ops = [node.op for node in sess.graph_def.node] + self.assertEqual(0, ops.count('Mfcc')) + + def testFeatureBinCount(self): + with self.test_session() as sess: + freeze.create_inference_graph( + wanted_words='a,b,c,d', + sample_rate=16000, + clip_duration_ms=1000.0, + clip_stride_ms=30.0, + window_size_ms=30.0, + window_stride_ms=10.0, + feature_bin_count=80, + model_architecture='conv', + preprocess='average') + self.assertIsNotNone(sess.graph.get_tensor_by_name('wav_data:0')) + self.assertIsNotNone( + sess.graph.get_tensor_by_name('decoded_sample_data:0')) + self.assertIsNotNone(sess.graph.get_tensor_by_name('labels_softmax:0')) + ops = [node.op for node in sess.graph_def.node] + self.assertEqual(0, ops.count('Mfcc')) if __name__ == '__main__': diff --git a/tensorflow/examples/speech_commands/generate_streaming_test_wav.py b/tensorflow/examples/speech_commands/generate_streaming_test_wav.py index 053206ae2f144ce05efa7eb490626aef01a6bc49..9858906927737cd520a9fd02f04437d01e0f6d31 100644 --- a/tensorflow/examples/speech_commands/generate_streaming_test_wav.py +++ b/tensorflow/examples/speech_commands/generate_streaming_test_wav.py @@ -87,11 +87,12 @@ def main(_): words_list = input_data.prepare_words_list(FLAGS.wanted_words.split(',')) model_settings = models.prepare_model_settings( len(words_list), FLAGS.sample_rate, FLAGS.clip_duration_ms, - FLAGS.window_size_ms, FLAGS.window_stride_ms, FLAGS.dct_coefficient_count) + FLAGS.window_size_ms, FLAGS.window_stride_ms, FLAGS.feature_bin_count, + 'mfcc') audio_processor = input_data.AudioProcessor( '', FLAGS.data_dir, FLAGS.silence_percentage, 10, FLAGS.wanted_words.split(','), FLAGS.validation_percentage, - FLAGS.testing_percentage, model_settings) + FLAGS.testing_percentage, model_settings, FLAGS.data_dir) output_audio_sample_count = FLAGS.sample_rate * FLAGS.test_duration_seconds output_audio = np.zeros((output_audio_sample_count,), dtype=np.float32) @@ -242,10 +243,11 @@ if __name__ == '__main__': default=10.0, help='How long the stride is between spectrogram timeslices',) parser.add_argument( - '--dct_coefficient_count', + '--feature_bin_count', type=int, default=40, - help='How many bins to use for the MFCC fingerprint',) + help='How many bins to use for the MFCC fingerprint', + ) parser.add_argument( '--wanted_words', type=str, diff --git a/tensorflow/examples/speech_commands/input_data.py b/tensorflow/examples/speech_commands/input_data.py index 63dd18457fea42acb09058b9ddd4623d72d1fd04..30f2cfa9fef7d0b5800c7e557bde4702dbafaf26 100644 --- a/tensorflow/examples/speech_commands/input_data.py +++ b/tensorflow/examples/speech_commands/input_data.py @@ -153,14 +153,14 @@ class AudioProcessor(object): def __init__(self, data_url, data_dir, silence_percentage, unknown_percentage, wanted_words, validation_percentage, testing_percentage, - model_settings): + model_settings, summaries_dir): self.data_dir = data_dir self.maybe_download_and_extract_dataset(data_url, data_dir) self.prepare_data_index(silence_percentage, unknown_percentage, wanted_words, validation_percentage, testing_percentage) self.prepare_background_data() - self.prepare_processing_graph(model_settings) + self.prepare_processing_graph(model_settings, summaries_dir) def maybe_download_and_extract_dataset(self, data_url, dest_directory): """Download and extract data set tar file. @@ -325,7 +325,7 @@ class AudioProcessor(object): if not self.background_data: raise Exception('No background wav files were found in ' + search_path) - def prepare_processing_graph(self, model_settings): + def prepare_processing_graph(self, model_settings, summaries_dir): """Builds a TensorFlow graph to apply the input distortions. Creates a graph that loads a WAVE file, decodes it, scales the volume, @@ -341,48 +341,88 @@ class AudioProcessor(object): - time_shift_offset_placeholder_: How much to move the clip in time. - background_data_placeholder_: PCM sample data for background noise. - background_volume_placeholder_: Loudness of mixed-in background. - - mfcc_: Output 2D fingerprint of processed audio. + - output_: Output 2D fingerprint of processed audio. Args: model_settings: Information about the current model being trained. + summaries_dir: Path to save training summary information to. + + Raises: + ValueError: If the preprocessing mode isn't recognized. """ - desired_samples = model_settings['desired_samples'] - self.wav_filename_placeholder_ = tf.placeholder(tf.string, []) - wav_loader = io_ops.read_file(self.wav_filename_placeholder_) - wav_decoder = contrib_audio.decode_wav( - wav_loader, desired_channels=1, desired_samples=desired_samples) - # Allow the audio sample's volume to be adjusted. - self.foreground_volume_placeholder_ = tf.placeholder(tf.float32, []) - scaled_foreground = tf.multiply(wav_decoder.audio, - self.foreground_volume_placeholder_) - # Shift the sample's start position, and pad any gaps with zeros. - self.time_shift_padding_placeholder_ = tf.placeholder(tf.int32, [2, 2]) - self.time_shift_offset_placeholder_ = tf.placeholder(tf.int32, [2]) - padded_foreground = tf.pad( - scaled_foreground, - self.time_shift_padding_placeholder_, - mode='CONSTANT') - sliced_foreground = tf.slice(padded_foreground, - self.time_shift_offset_placeholder_, - [desired_samples, -1]) - # Mix in background noise. - self.background_data_placeholder_ = tf.placeholder(tf.float32, - [desired_samples, 1]) - self.background_volume_placeholder_ = tf.placeholder(tf.float32, []) - background_mul = tf.multiply(self.background_data_placeholder_, - self.background_volume_placeholder_) - background_add = tf.add(background_mul, sliced_foreground) - background_clamp = tf.clip_by_value(background_add, -1.0, 1.0) - # Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio. - spectrogram = contrib_audio.audio_spectrogram( - background_clamp, - window_size=model_settings['window_size_samples'], - stride=model_settings['window_stride_samples'], - magnitude_squared=True) - self.mfcc_ = contrib_audio.mfcc( - spectrogram, - wav_decoder.sample_rate, - dct_coefficient_count=model_settings['dct_coefficient_count']) + with tf.get_default_graph().name_scope('data'): + desired_samples = model_settings['desired_samples'] + self.wav_filename_placeholder_ = tf.placeholder( + tf.string, [], name='wav_filename') + wav_loader = io_ops.read_file(self.wav_filename_placeholder_) + wav_decoder = contrib_audio.decode_wav( + wav_loader, desired_channels=1, desired_samples=desired_samples) + # Allow the audio sample's volume to be adjusted. + self.foreground_volume_placeholder_ = tf.placeholder( + tf.float32, [], name='foreground_volume') + scaled_foreground = tf.multiply(wav_decoder.audio, + self.foreground_volume_placeholder_) + # Shift the sample's start position, and pad any gaps with zeros. + self.time_shift_padding_placeholder_ = tf.placeholder( + tf.int32, [2, 2], name='time_shift_padding') + self.time_shift_offset_placeholder_ = tf.placeholder( + tf.int32, [2], name='time_shift_offset') + padded_foreground = tf.pad( + scaled_foreground, + self.time_shift_padding_placeholder_, + mode='CONSTANT') + sliced_foreground = tf.slice(padded_foreground, + self.time_shift_offset_placeholder_, + [desired_samples, -1]) + # Mix in background noise. + self.background_data_placeholder_ = tf.placeholder( + tf.float32, [desired_samples, 1], name='background_data') + self.background_volume_placeholder_ = tf.placeholder( + tf.float32, [], name='background_volume') + background_mul = tf.multiply(self.background_data_placeholder_, + self.background_volume_placeholder_) + background_add = tf.add(background_mul, sliced_foreground) + background_clamp = tf.clip_by_value(background_add, -1.0, 1.0) + # Run the spectrogram and MFCC ops to get a 2D 'fingerprint' of the audio. + spectrogram = contrib_audio.audio_spectrogram( + background_clamp, + window_size=model_settings['window_size_samples'], + stride=model_settings['window_stride_samples'], + magnitude_squared=True) + tf.summary.image( + 'spectrogram', tf.expand_dims(spectrogram, -1), max_outputs=1) + # The number of buckets in each FFT row in the spectrogram will depend on + # how many input samples there are in each window. This can be quite + # large, with a 160 sample window producing 127 buckets for example. We + # don't need this level of detail for classification, so we often want to + # shrink them down to produce a smaller result. That's what this section + # implements. One method is to use average pooling to merge adjacent + # buckets, but a more sophisticated approach is to apply the MFCC + # algorithm to shrink the representation. + if model_settings['preprocess'] == 'average': + self.output_ = tf.nn.pool( + tf.expand_dims(spectrogram, -1), + window_shape=[1, model_settings['average_window_width']], + strides=[1, model_settings['average_window_width']], + pooling_type='AVG', + padding='SAME') + tf.summary.image('shrunk_spectrogram', self.output_, max_outputs=1) + elif model_settings['preprocess'] == 'mfcc': + self.output_ = contrib_audio.mfcc( + spectrogram, + wav_decoder.sample_rate, + dct_coefficient_count=model_settings['fingerprint_width']) + tf.summary.image( + 'mfcc', tf.expand_dims(self.output_, -1), max_outputs=1) + else: + raise ValueError('Unknown preprocess mode "%s" (should be "mfcc" or' + ' "average")' % (model_settings['preprocess'])) + + # Merge all the summaries and write them out to /tmp/retrain_logs (by + # default) + self.merged_summaries_ = tf.summary.merge_all(scope='data') + self.summary_writer_ = tf.summary.FileWriter(summaries_dir + '/data', + tf.get_default_graph()) def set_size(self, mode): """Calculates the number of samples in the dataset partition. @@ -418,6 +458,9 @@ class AudioProcessor(object): Returns: List of sample data for the transformed samples, and list of label indexes + + Raises: + ValueError: If background samples are too short. """ # Pick one of the partitions to choose samples from. candidates = self.data_index[mode] @@ -460,6 +503,11 @@ class AudioProcessor(object): if use_background or sample['label'] == SILENCE_LABEL: background_index = np.random.randint(len(self.background_data)) background_samples = self.background_data[background_index] + if len(background_samples) <= model_settings['desired_samples']: + raise ValueError( + 'Background sample is too short! Need more than %d' + ' samples but only %d were found' % + (model_settings['desired_samples'], len(background_samples))) background_offset = np.random.randint( 0, len(background_samples) - model_settings['desired_samples']) background_clipped = background_samples[background_offset:( @@ -482,7 +530,10 @@ class AudioProcessor(object): else: input_dict[self.foreground_volume_placeholder_] = 1 # Run the graph to produce the output audio. - data[i - offset, :] = sess.run(self.mfcc_, feed_dict=input_dict).flatten() + summary, data_tensor = sess.run( + [self.merged_summaries_, self.output_], feed_dict=input_dict) + self.summary_writer_.add_summary(summary) + data[i - offset, :] = data_tensor.flatten() label_index = self.word_to_index[sample['label']] labels[i - offset] = label_index return data, labels diff --git a/tensorflow/examples/speech_commands/input_data_test.py b/tensorflow/examples/speech_commands/input_data_test.py index 13f294d39dbf89367496d2a16f466f8e2195d900..2e551be9a208221dc8b788e4d795e68bde21c9e5 100644 --- a/tensorflow/examples/speech_commands/input_data_test.py +++ b/tensorflow/examples/speech_commands/input_data_test.py @@ -25,6 +25,7 @@ import tensorflow as tf from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio from tensorflow.examples.speech_commands import input_data +from tensorflow.examples.speech_commands import models from tensorflow.python.platform import test @@ -32,7 +33,7 @@ class InputDataTest(test.TestCase): def _getWavData(self): with self.test_session() as sess: - sample_data = tf.zeros([1000, 2]) + sample_data = tf.zeros([32000, 2]) wav_encoder = contrib_audio.encode_wav(sample_data, 16000) wav_data = sess.run(wav_encoder) return wav_data @@ -57,9 +58,31 @@ class InputDataTest(test.TestCase): "label_count": 4, "window_size_samples": 100, "window_stride_samples": 100, - "dct_coefficient_count": 40, + "fingerprint_width": 40, + "preprocess": "mfcc", } + def _runGetDataTest(self, preprocess, window_length_ms): + tmp_dir = self.get_temp_dir() + wav_dir = os.path.join(tmp_dir, "wavs") + os.mkdir(wav_dir) + self._saveWavFolders(wav_dir, ["a", "b", "c"], 100) + background_dir = os.path.join(wav_dir, "_background_noise_") + os.mkdir(background_dir) + wav_data = self._getWavData() + for i in range(10): + file_path = os.path.join(background_dir, "background_audio_%d.wav" % i) + self._saveTestWavFile(file_path, wav_data) + model_settings = models.prepare_model_settings( + 4, 16000, 1000, window_length_ms, 20, 40, preprocess) + with self.test_session() as sess: + audio_processor = input_data.AudioProcessor( + "", wav_dir, 10, 10, ["a", "b"], 10, 10, model_settings, tmp_dir) + result_data, result_labels = audio_processor.get_data( + 10, 0, model_settings, 0.3, 0.1, 100, "training", sess) + self.assertEqual(10, len(result_data)) + self.assertEqual(10, len(result_labels)) + def testPrepareWordsList(self): words_list = ["a", "b"] self.assertGreater( @@ -76,8 +99,9 @@ class InputDataTest(test.TestCase): def testPrepareDataIndex(self): tmp_dir = self.get_temp_dir() self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100) - audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], - 10, 10, self._model_settings()) + audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10, + ["a", "b"], 10, 10, + self._model_settings(), tmp_dir) self.assertLess(0, audio_processor.set_size("training")) self.assertTrue("training" in audio_processor.data_index) self.assertTrue("validation" in audio_processor.data_index) @@ -90,7 +114,7 @@ class InputDataTest(test.TestCase): self._saveWavFolders(tmp_dir, ["a", "b", "c"], 0) with self.assertRaises(Exception) as e: _ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], 10, 10, - self._model_settings()) + self._model_settings(), tmp_dir) self.assertTrue("No .wavs found" in str(e.exception)) def testPrepareDataIndexMissing(self): @@ -98,7 +122,7 @@ class InputDataTest(test.TestCase): self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100) with self.assertRaises(Exception) as e: _ = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b", "d"], 10, - 10, self._model_settings()) + 10, self._model_settings(), tmp_dir) self.assertTrue("Expected to find" in str(e.exception)) def testPrepareBackgroundData(self): @@ -110,8 +134,9 @@ class InputDataTest(test.TestCase): file_path = os.path.join(background_dir, "background_audio_%d.wav" % i) self._saveTestWavFile(file_path, wav_data) self._saveWavFolders(tmp_dir, ["a", "b", "c"], 100) - audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10, ["a", "b"], - 10, 10, self._model_settings()) + audio_processor = input_data.AudioProcessor("", tmp_dir, 10, 10, + ["a", "b"], 10, 10, + self._model_settings(), tmp_dir) self.assertEqual(10, len(audio_processor.background_data)) def testLoadWavFile(self): @@ -148,44 +173,27 @@ class InputDataTest(test.TestCase): "label_count": 4, "window_size_samples": 100, "window_stride_samples": 100, - "dct_coefficient_count": 40, + "fingerprint_width": 40, + "preprocess": "mfcc", } audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"], - 10, 10, model_settings) + 10, 10, model_settings, tmp_dir) self.assertIsNotNone(audio_processor.wav_filename_placeholder_) self.assertIsNotNone(audio_processor.foreground_volume_placeholder_) self.assertIsNotNone(audio_processor.time_shift_padding_placeholder_) self.assertIsNotNone(audio_processor.time_shift_offset_placeholder_) self.assertIsNotNone(audio_processor.background_data_placeholder_) self.assertIsNotNone(audio_processor.background_volume_placeholder_) - self.assertIsNotNone(audio_processor.mfcc_) + self.assertIsNotNone(audio_processor.output_) - def testGetData(self): - tmp_dir = self.get_temp_dir() - wav_dir = os.path.join(tmp_dir, "wavs") - os.mkdir(wav_dir) - self._saveWavFolders(wav_dir, ["a", "b", "c"], 100) - background_dir = os.path.join(wav_dir, "_background_noise_") - os.mkdir(background_dir) - wav_data = self._getWavData() - for i in range(10): - file_path = os.path.join(background_dir, "background_audio_%d.wav" % i) - self._saveTestWavFile(file_path, wav_data) - model_settings = { - "desired_samples": 160, - "fingerprint_size": 40, - "label_count": 4, - "window_size_samples": 100, - "window_stride_samples": 100, - "dct_coefficient_count": 40, - } - audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"], - 10, 10, model_settings) - with self.test_session() as sess: - result_data, result_labels = audio_processor.get_data( - 10, 0, model_settings, 0.3, 0.1, 100, "training", sess) - self.assertEqual(10, len(result_data)) - self.assertEqual(10, len(result_labels)) + def testGetDataAverage(self): + self._runGetDataTest("average", 10) + + def testGetDataAverageLongWindow(self): + self._runGetDataTest("average", 30) + + def testGetDataMfcc(self): + self._runGetDataTest("mfcc", 30) def testGetUnprocessedData(self): tmp_dir = self.get_temp_dir() @@ -198,10 +206,11 @@ class InputDataTest(test.TestCase): "label_count": 4, "window_size_samples": 100, "window_stride_samples": 100, - "dct_coefficient_count": 40, + "fingerprint_width": 40, + "preprocess": "mfcc", } audio_processor = input_data.AudioProcessor("", wav_dir, 10, 10, ["a", "b"], - 10, 10, model_settings) + 10, 10, model_settings, tmp_dir) result_data, result_labels = audio_processor.get_unprocessed_data( 10, model_settings, "training") self.assertEqual(10, len(result_data)) diff --git a/tensorflow/examples/speech_commands/models.py b/tensorflow/examples/speech_commands/models.py index ab611f414a8afa1f08b955918071b04ae0ef88db..4d1454be0d733ccf6ea41f822030f139135fb895 100644 --- a/tensorflow/examples/speech_commands/models.py +++ b/tensorflow/examples/speech_commands/models.py @@ -24,9 +24,21 @@ import math import tensorflow as tf +def _next_power_of_two(x): + """Calculates the smallest enclosing power of two for an input. + + Args: + x: Positive float or integer number. + + Returns: + Next largest power of two integer. + """ + return 1 if x == 0 else 2**(int(x) - 1).bit_length() + + def prepare_model_settings(label_count, sample_rate, clip_duration_ms, - window_size_ms, window_stride_ms, - dct_coefficient_count): + window_size_ms, window_stride_ms, feature_bin_count, + preprocess): """Calculates common settings needed for all models. Args: @@ -35,10 +47,14 @@ def prepare_model_settings(label_count, sample_rate, clip_duration_ms, clip_duration_ms: Length of each audio clip to be analyzed. window_size_ms: Duration of frequency analysis window. window_stride_ms: How far to move in time between frequency windows. - dct_coefficient_count: Number of frequency bins to use for analysis. + feature_bin_count: Number of frequency bins to use for analysis. + preprocess: How the spectrogram is processed to produce features. Returns: Dictionary containing common settings. + + Raises: + ValueError: If the preprocessing mode isn't recognized. """ desired_samples = int(sample_rate * clip_duration_ms / 1000) window_size_samples = int(sample_rate * window_size_ms / 1000) @@ -48,16 +64,28 @@ def prepare_model_settings(label_count, sample_rate, clip_duration_ms, spectrogram_length = 0 else: spectrogram_length = 1 + int(length_minus_window / window_stride_samples) - fingerprint_size = dct_coefficient_count * spectrogram_length + if preprocess == 'average': + fft_bin_count = 1 + (_next_power_of_two(window_size_samples) / 2) + average_window_width = int(math.floor(fft_bin_count / feature_bin_count)) + fingerprint_width = int(math.ceil(fft_bin_count / average_window_width)) + elif preprocess == 'mfcc': + average_window_width = -1 + fingerprint_width = feature_bin_count + else: + raise ValueError('Unknown preprocess mode "%s" (should be "mfcc" or' + ' "average")' % (preprocess)) + fingerprint_size = fingerprint_width * spectrogram_length return { 'desired_samples': desired_samples, 'window_size_samples': window_size_samples, 'window_stride_samples': window_stride_samples, 'spectrogram_length': spectrogram_length, - 'dct_coefficient_count': dct_coefficient_count, + 'fingerprint_width': fingerprint_width, 'fingerprint_size': fingerprint_size, 'label_count': label_count, 'sample_rate': sample_rate, + 'preprocess': preprocess, + 'average_window_width': average_window_width, } @@ -106,10 +134,14 @@ def create_model(fingerprint_input, model_settings, model_architecture, elif model_architecture == 'low_latency_svdf': return create_low_latency_svdf_model(fingerprint_input, model_settings, is_training, runtime_settings) + elif model_architecture == 'tiny_conv': + return create_tiny_conv_model(fingerprint_input, model_settings, + is_training) else: raise Exception('model_architecture argument "' + model_architecture + '" not recognized, should be one of "single_fc", "conv",' + - ' "low_latency_conv, or "low_latency_svdf"') + ' "low_latency_conv, "low_latency_svdf",' + + ' or "tiny_conv"') def load_variables_from_checkpoint(sess, start_checkpoint): @@ -152,9 +184,12 @@ def create_single_fc_model(fingerprint_input, model_settings, is_training): dropout_prob = tf.placeholder(tf.float32, name='dropout_prob') fingerprint_size = model_settings['fingerprint_size'] label_count = model_settings['label_count'] - weights = tf.Variable( - tf.truncated_normal([fingerprint_size, label_count], stddev=0.001)) - bias = tf.Variable(tf.zeros([label_count])) + weights = tf.get_variable( + name='weights', + initializer=tf.truncated_normal_initializer(stddev=0.001), + shape=[fingerprint_size, label_count]) + bias = tf.get_variable( + name='bias', initializer=tf.zeros_initializer, shape=[label_count]) logits = tf.matmul(fingerprint_input, weights) + bias if is_training: return logits, dropout_prob @@ -212,18 +247,21 @@ def create_conv_model(fingerprint_input, model_settings, is_training): """ if is_training: dropout_prob = tf.placeholder(tf.float32, name='dropout_prob') - input_frequency_size = model_settings['dct_coefficient_count'] + input_frequency_size = model_settings['fingerprint_width'] input_time_size = model_settings['spectrogram_length'] fingerprint_4d = tf.reshape(fingerprint_input, [-1, input_time_size, input_frequency_size, 1]) first_filter_width = 8 first_filter_height = 20 first_filter_count = 64 - first_weights = tf.Variable( - tf.truncated_normal( - [first_filter_height, first_filter_width, 1, first_filter_count], - stddev=0.01)) - first_bias = tf.Variable(tf.zeros([first_filter_count])) + first_weights = tf.get_variable( + name='first_weights', + initializer=tf.truncated_normal_initializer(stddev=0.01), + shape=[first_filter_height, first_filter_width, 1, first_filter_count]) + first_bias = tf.get_variable( + name='first_bias', + initializer=tf.zeros_initializer, + shape=[first_filter_count]) first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [1, 1, 1, 1], 'SAME') + first_bias first_relu = tf.nn.relu(first_conv) @@ -235,14 +273,17 @@ def create_conv_model(fingerprint_input, model_settings, is_training): second_filter_width = 4 second_filter_height = 10 second_filter_count = 64 - second_weights = tf.Variable( - tf.truncated_normal( - [ - second_filter_height, second_filter_width, first_filter_count, - second_filter_count - ], - stddev=0.01)) - second_bias = tf.Variable(tf.zeros([second_filter_count])) + second_weights = tf.get_variable( + name='second_weights', + initializer=tf.truncated_normal_initializer(stddev=0.01), + shape=[ + second_filter_height, second_filter_width, first_filter_count, + second_filter_count + ]) + second_bias = tf.get_variable( + name='second_bias', + initializer=tf.zeros_initializer, + shape=[second_filter_count]) second_conv = tf.nn.conv2d(max_pool, second_weights, [1, 1, 1, 1], 'SAME') + second_bias second_relu = tf.nn.relu(second_conv) @@ -259,10 +300,14 @@ def create_conv_model(fingerprint_input, model_settings, is_training): flattened_second_conv = tf.reshape(second_dropout, [-1, second_conv_element_count]) label_count = model_settings['label_count'] - final_fc_weights = tf.Variable( - tf.truncated_normal( - [second_conv_element_count, label_count], stddev=0.01)) - final_fc_bias = tf.Variable(tf.zeros([label_count])) + final_fc_weights = tf.get_variable( + name='final_fc_weights', + initializer=tf.truncated_normal_initializer(stddev=0.01), + shape=[second_conv_element_count, label_count]) + final_fc_bias = tf.get_variable( + name='final_fc_bias', + initializer=tf.zeros_initializer, + shape=[label_count]) final_fc = tf.matmul(flattened_second_conv, final_fc_weights) + final_fc_bias if is_training: return final_fc, dropout_prob @@ -318,7 +363,7 @@ def create_low_latency_conv_model(fingerprint_input, model_settings, """ if is_training: dropout_prob = tf.placeholder(tf.float32, name='dropout_prob') - input_frequency_size = model_settings['dct_coefficient_count'] + input_frequency_size = model_settings['fingerprint_width'] input_time_size = model_settings['spectrogram_length'] fingerprint_4d = tf.reshape(fingerprint_input, [-1, input_time_size, input_frequency_size, 1]) @@ -327,11 +372,14 @@ def create_low_latency_conv_model(fingerprint_input, model_settings, first_filter_count = 186 first_filter_stride_x = 1 first_filter_stride_y = 1 - first_weights = tf.Variable( - tf.truncated_normal( - [first_filter_height, first_filter_width, 1, first_filter_count], - stddev=0.01)) - first_bias = tf.Variable(tf.zeros([first_filter_count])) + first_weights = tf.get_variable( + name='first_weights', + initializer=tf.truncated_normal_initializer(stddev=0.01), + shape=[first_filter_height, first_filter_width, 1, first_filter_count]) + first_bias = tf.get_variable( + name='first_bias', + initializer=tf.zeros_initializer, + shape=[first_filter_count]) first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [ 1, first_filter_stride_y, first_filter_stride_x, 1 ], 'VALID') + first_bias @@ -351,30 +399,42 @@ def create_low_latency_conv_model(fingerprint_input, model_settings, flattened_first_conv = tf.reshape(first_dropout, [-1, first_conv_element_count]) first_fc_output_channels = 128 - first_fc_weights = tf.Variable( - tf.truncated_normal( - [first_conv_element_count, first_fc_output_channels], stddev=0.01)) - first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels])) + first_fc_weights = tf.get_variable( + name='first_fc_weights', + initializer=tf.truncated_normal_initializer(stddev=0.01), + shape=[first_conv_element_count, first_fc_output_channels]) + first_fc_bias = tf.get_variable( + name='first_fc_bias', + initializer=tf.zeros_initializer, + shape=[first_fc_output_channels]) first_fc = tf.matmul(flattened_first_conv, first_fc_weights) + first_fc_bias if is_training: second_fc_input = tf.nn.dropout(first_fc, dropout_prob) else: second_fc_input = first_fc second_fc_output_channels = 128 - second_fc_weights = tf.Variable( - tf.truncated_normal( - [first_fc_output_channels, second_fc_output_channels], stddev=0.01)) - second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels])) + second_fc_weights = tf.get_variable( + name='second_fc_weights', + initializer=tf.truncated_normal_initializer(stddev=0.01), + shape=[first_fc_output_channels, second_fc_output_channels]) + second_fc_bias = tf.get_variable( + name='second_fc_bias', + initializer=tf.zeros_initializer, + shape=[second_fc_output_channels]) second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias if is_training: final_fc_input = tf.nn.dropout(second_fc, dropout_prob) else: final_fc_input = second_fc label_count = model_settings['label_count'] - final_fc_weights = tf.Variable( - tf.truncated_normal( - [second_fc_output_channels, label_count], stddev=0.01)) - final_fc_bias = tf.Variable(tf.zeros([label_count])) + final_fc_weights = tf.get_variable( + name='final_fc_weights', + initializer=tf.truncated_normal_initializer(stddev=0.01), + shape=[second_fc_output_channels, label_count]) + final_fc_bias = tf.get_variable( + name='final_fc_bias', + initializer=tf.zeros_initializer, + shape=[label_count]) final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias if is_training: return final_fc, dropout_prob @@ -422,7 +482,7 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings, Args: fingerprint_input: TensorFlow node that will output audio feature vectors. The node is expected to produce a 2D Tensor of shape: - [batch, model_settings['dct_coefficient_count'] * + [batch, model_settings['fingerprint_width'] * model_settings['spectrogram_length']] with the features corresponding to the same time slot arranged contiguously, and the oldest slot at index [:, 0], and newest at [:, -1]. @@ -440,7 +500,7 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings, if is_training: dropout_prob = tf.placeholder(tf.float32, name='dropout_prob') - input_frequency_size = model_settings['dct_coefficient_count'] + input_frequency_size = model_settings['fingerprint_width'] input_time_size = model_settings['spectrogram_length'] # Validation. @@ -462,8 +522,11 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings, num_filters = rank * num_units # Create the runtime memory: [num_filters, batch, input_time_size] batch = 1 - memory = tf.Variable(tf.zeros([num_filters, batch, input_time_size]), - trainable=False, name='runtime-memory') + memory = tf.get_variable( + initializer=tf.zeros_initializer, + shape=[num_filters, batch, input_time_size], + trainable=False, + name='runtime-memory') # Determine the number of new frames in the input, such that we only operate # on those. For training we do not use the memory, and thus use all frames # provided in the input. @@ -483,8 +546,10 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings, new_fingerprint_input = tf.expand_dims(new_fingerprint_input, 2) # Create the frequency filters. - weights_frequency = tf.Variable( - tf.truncated_normal([input_frequency_size, num_filters], stddev=0.01)) + weights_frequency = tf.get_variable( + name='weights_frequency', + initializer=tf.truncated_normal_initializer(stddev=0.01), + shape=[input_frequency_size, num_filters]) # Expand to add input channels dimensions. # weights_frequency: [input_frequency_size, 1, num_filters] weights_frequency = tf.expand_dims(weights_frequency, 1) @@ -506,8 +571,10 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings, activations_time = new_memory # Create the time filters. - weights_time = tf.Variable( - tf.truncated_normal([num_filters, input_time_size], stddev=0.01)) + weights_time = tf.get_variable( + name='weights_time', + initializer=tf.truncated_normal_initializer(stddev=0.01), + shape=[num_filters, input_time_size]) # Apply the time filter on the outputs of the feature filters. # weights_time: [num_filters, input_time_size, 1] # outputs: [num_filters, batch, 1] @@ -524,7 +591,8 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings, units_output = tf.transpose(units_output) # Appy bias. - bias = tf.Variable(tf.zeros([num_units])) + bias = tf.get_variable( + name='bias', initializer=tf.zeros_initializer, shape=[num_units]) first_bias = tf.nn.bias_add(units_output, bias) # Relu. @@ -536,31 +604,135 @@ def create_low_latency_svdf_model(fingerprint_input, model_settings, first_dropout = first_relu first_fc_output_channels = 256 - first_fc_weights = tf.Variable( - tf.truncated_normal([num_units, first_fc_output_channels], stddev=0.01)) - first_fc_bias = tf.Variable(tf.zeros([first_fc_output_channels])) + first_fc_weights = tf.get_variable( + name='first_fc_weights', + initializer=tf.truncated_normal_initializer(stddev=0.01), + shape=[num_units, first_fc_output_channels]) + first_fc_bias = tf.get_variable( + name='first_fc_bias', + initializer=tf.zeros_initializer, + shape=[first_fc_output_channels]) first_fc = tf.matmul(first_dropout, first_fc_weights) + first_fc_bias if is_training: second_fc_input = tf.nn.dropout(first_fc, dropout_prob) else: second_fc_input = first_fc second_fc_output_channels = 256 - second_fc_weights = tf.Variable( - tf.truncated_normal( - [first_fc_output_channels, second_fc_output_channels], stddev=0.01)) - second_fc_bias = tf.Variable(tf.zeros([second_fc_output_channels])) + second_fc_weights = tf.get_variable( + name='second_fc_weights', + initializer=tf.truncated_normal_initializer(stddev=0.01), + shape=[first_fc_output_channels, second_fc_output_channels]) + second_fc_bias = tf.get_variable( + name='second_fc_bias', + initializer=tf.zeros_initializer, + shape=[second_fc_output_channels]) second_fc = tf.matmul(second_fc_input, second_fc_weights) + second_fc_bias if is_training: final_fc_input = tf.nn.dropout(second_fc, dropout_prob) else: final_fc_input = second_fc label_count = model_settings['label_count'] - final_fc_weights = tf.Variable( - tf.truncated_normal( - [second_fc_output_channels, label_count], stddev=0.01)) - final_fc_bias = tf.Variable(tf.zeros([label_count])) + final_fc_weights = tf.get_variable( + name='final_fc_weights', + initializer=tf.truncated_normal(stddev=0.01), + shape=[second_fc_output_channels, label_count]) + final_fc_bias = tf.get_variable( + name='final_fc_bias', + initializer=tf.zeros_initializer, + shape=[label_count]) final_fc = tf.matmul(final_fc_input, final_fc_weights) + final_fc_bias if is_training: return final_fc, dropout_prob else: return final_fc + + +def create_tiny_conv_model(fingerprint_input, model_settings, is_training): + """Builds a convolutional model aimed at microcontrollers. + + Devices like DSPs and microcontrollers can have very small amounts of + memory and limited processing power. This model is designed to use less + than 20KB of working RAM, and fit within 32KB of read-only (flash) memory. + + Here's the layout of the graph: + + (fingerprint_input) + v + [Conv2D]<-(weights) + v + [BiasAdd]<-(bias) + v + [Relu] + v + [MatMul]<-(weights) + v + [BiasAdd]<-(bias) + v + + This doesn't produce particularly accurate results, but it's designed to be + used as the first stage of a pipeline, running on a low-energy piece of + hardware that can always be on, and then wake higher-power chips when a + possible utterance has been found, so that more accurate analysis can be done. + + During training, a dropout node is introduced after the relu, controlled by a + placeholder. + + Args: + fingerprint_input: TensorFlow node that will output audio feature vectors. + model_settings: Dictionary of information about the model. + is_training: Whether the model is going to be used for training. + + Returns: + TensorFlow node outputting logits results, and optionally a dropout + placeholder. + """ + if is_training: + dropout_prob = tf.placeholder(tf.float32, name='dropout_prob') + input_frequency_size = model_settings['fingerprint_width'] + input_time_size = model_settings['spectrogram_length'] + fingerprint_4d = tf.reshape(fingerprint_input, + [-1, input_time_size, input_frequency_size, 1]) + first_filter_width = 8 + first_filter_height = 10 + first_filter_count = 8 + first_weights = tf.get_variable( + name='first_weights', + initializer=tf.truncated_normal_initializer(stddev=0.01), + shape=[first_filter_height, first_filter_width, 1, first_filter_count]) + first_bias = tf.get_variable( + name='first_bias', + initializer=tf.zeros_initializer, + shape=[first_filter_count]) + first_conv_stride_x = 2 + first_conv_stride_y = 2 + first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, + [1, first_conv_stride_y, first_conv_stride_x, 1], + 'SAME') + first_bias + first_relu = tf.nn.relu(first_conv) + if is_training: + first_dropout = tf.nn.dropout(first_relu, dropout_prob) + else: + first_dropout = first_relu + first_dropout_shape = first_dropout.get_shape() + first_dropout_output_width = first_dropout_shape[2] + first_dropout_output_height = first_dropout_shape[1] + first_dropout_element_count = int( + first_dropout_output_width * first_dropout_output_height * + first_filter_count) + flattened_first_dropout = tf.reshape(first_dropout, + [-1, first_dropout_element_count]) + label_count = model_settings['label_count'] + final_fc_weights = tf.get_variable( + name='final_fc_weights', + initializer=tf.truncated_normal_initializer(stddev=0.01), + shape=[first_dropout_element_count, label_count]) + final_fc_bias = tf.get_variable( + name='final_fc_bias', + initializer=tf.zeros_initializer, + shape=[label_count]) + final_fc = ( + tf.matmul(flattened_first_dropout, final_fc_weights) + final_fc_bias) + if is_training: + return final_fc, dropout_prob + else: + return final_fc diff --git a/tensorflow/examples/speech_commands/models_test.py b/tensorflow/examples/speech_commands/models_test.py index 80c795367fa01f214d78d3fa7df7864b6b243b97..0c373967ed8fb9cddcc82972e0fc8bba186add2e 100644 --- a/tensorflow/examples/speech_commands/models_test.py +++ b/tensorflow/examples/speech_commands/models_test.py @@ -26,12 +26,29 @@ from tensorflow.python.platform import test class ModelsTest(test.TestCase): + def _modelSettings(self): + return models.prepare_model_settings( + label_count=10, + sample_rate=16000, + clip_duration_ms=1000, + window_size_ms=20, + window_stride_ms=10, + feature_bin_count=40, + preprocess="mfcc") + def testPrepareModelSettings(self): self.assertIsNotNone( - models.prepare_model_settings(10, 16000, 1000, 20, 10, 40)) + models.prepare_model_settings( + label_count=10, + sample_rate=16000, + clip_duration_ms=1000, + window_size_ms=20, + window_stride_ms=10, + feature_bin_count=40, + preprocess="mfcc")) def testCreateModelConvTraining(self): - model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40) + model_settings = self._modelSettings() with self.test_session() as sess: fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]]) logits, dropout_prob = models.create_model(fingerprint_input, @@ -42,7 +59,7 @@ class ModelsTest(test.TestCase): self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name)) def testCreateModelConvInference(self): - model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40) + model_settings = self._modelSettings() with self.test_session() as sess: fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]]) logits = models.create_model(fingerprint_input, model_settings, "conv", @@ -51,7 +68,7 @@ class ModelsTest(test.TestCase): self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name)) def testCreateModelLowLatencyConvTraining(self): - model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40) + model_settings = self._modelSettings() with self.test_session() as sess: fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]]) logits, dropout_prob = models.create_model( @@ -62,7 +79,7 @@ class ModelsTest(test.TestCase): self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name)) def testCreateModelFullyConnectedTraining(self): - model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40) + model_settings = self._modelSettings() with self.test_session() as sess: fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]]) logits, dropout_prob = models.create_model( @@ -73,7 +90,7 @@ class ModelsTest(test.TestCase): self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name)) def testCreateModelBadArchitecture(self): - model_settings = models.prepare_model_settings(10, 16000, 1000, 20, 10, 40) + model_settings = self._modelSettings() with self.test_session(): fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]]) with self.assertRaises(Exception) as e: @@ -81,6 +98,17 @@ class ModelsTest(test.TestCase): "bad_architecture", True) self.assertTrue("not recognized" in str(e.exception)) + def testCreateModelTinyConvTraining(self): + model_settings = self._modelSettings() + with self.test_session() as sess: + fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]]) + logits, dropout_prob = models.create_model( + fingerprint_input, model_settings, "tiny_conv", True) + self.assertIsNotNone(logits) + self.assertIsNotNone(dropout_prob) + self.assertIsNotNone(sess.graph.get_tensor_by_name(logits.name)) + self.assertIsNotNone(sess.graph.get_tensor_by_name(dropout_prob.name)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/examples/speech_commands/train.py b/tensorflow/examples/speech_commands/train.py index fc28eb0631dc5e1947c2a31a6acdb02ed8d28f3a..eca34f8812b76a60168c97a745f5890bf3ee0269 100644 --- a/tensorflow/examples/speech_commands/train.py +++ b/tensorflow/examples/speech_commands/train.py @@ -98,12 +98,12 @@ def main(_): model_settings = models.prepare_model_settings( len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))), FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms, - FLAGS.window_stride_ms, FLAGS.dct_coefficient_count) + FLAGS.window_stride_ms, FLAGS.feature_bin_count, FLAGS.preprocess) audio_processor = input_data.AudioProcessor( - FLAGS.data_url, FLAGS.data_dir, FLAGS.silence_percentage, - FLAGS.unknown_percentage, + FLAGS.data_url, FLAGS.data_dir, + FLAGS.silence_percentage, FLAGS.unknown_percentage, FLAGS.wanted_words.split(','), FLAGS.validation_percentage, - FLAGS.testing_percentage, model_settings) + FLAGS.testing_percentage, model_settings, FLAGS.summaries_dir) fingerprint_size = model_settings['fingerprint_size'] label_count = model_settings['label_count'] time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000) @@ -122,8 +122,25 @@ def main(_): 'lists, but are %d and %d long instead' % (len(training_steps_list), len(learning_rates_list))) - fingerprint_input = tf.placeholder( + input_placeholder = tf.placeholder( tf.float32, [None, fingerprint_size], name='fingerprint_input') + if FLAGS.quantize: + # TODO(petewarden): These values have been derived from the observed ranges + # of spectrogram and MFCC inputs. If the preprocessing pipeline changes, + # they may need to be updated. + if FLAGS.preprocess == 'average': + fingerprint_min = 0.0 + fingerprint_max = 2048.0 + elif FLAGS.preprocess == 'mfcc': + fingerprint_min = -247.0 + fingerprint_max = 30.0 + else: + raise Exception('Unknown preprocess mode "%s" (should be "mfcc" or' + ' "average")' % (FLAGS.preprocess)) + fingerprint_input = tf.fake_quant_with_min_max_args( + input_placeholder, fingerprint_min, fingerprint_max) + else: + fingerprint_input = input_placeholder logits, dropout_prob = models.create_model( fingerprint_input, @@ -146,7 +163,8 @@ def main(_): with tf.name_scope('cross_entropy'): cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy( labels=ground_truth_input, logits=logits) - tf.summary.scalar('cross_entropy', cross_entropy_mean) + if FLAGS.quantize: + tf.contrib.quantize.create_training_graph(quant_delay=0) with tf.name_scope('train'), tf.control_dependencies(control_dependencies): learning_rate_input = tf.placeholder( tf.float32, [], name='learning_rate_input') @@ -157,7 +175,9 @@ def main(_): confusion_matrix = tf.confusion_matrix( ground_truth_input, predicted_indices, num_classes=label_count) evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) - tf.summary.scalar('accuracy', evaluation_step) + with tf.get_default_graph().name_scope('eval'): + tf.summary.scalar('cross_entropy', cross_entropy_mean) + tf.summary.scalar('accuracy', evaluation_step) global_step = tf.train.get_or_create_global_step() increment_global_step = tf.assign(global_step, global_step + 1) @@ -165,7 +185,7 @@ def main(_): saver = tf.train.Saver(tf.global_variables()) # Merge all the summaries and write them out to /tmp/retrain_logs (by default) - merged_summaries = tf.summary.merge_all() + merged_summaries = tf.summary.merge_all(scope='eval') train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train', sess.graph) validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation') @@ -207,8 +227,11 @@ def main(_): # Run the graph with this batch of training data. train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run( [ - merged_summaries, evaluation_step, cross_entropy_mean, train_step, - increment_global_step + merged_summaries, + evaluation_step, + cross_entropy_mean, + train_step, + increment_global_step, ], feed_dict={ fingerprint_input: train_fingerprints, @@ -364,10 +387,11 @@ if __name__ == '__main__': default=10.0, help='How far to move in time between spectogram timeslices.',) parser.add_argument( - '--dct_coefficient_count', + '--feature_bin_count', type=int, default=40, - help='How many bins to use for the MFCC fingerprint',) + help='How many bins to use for the MFCC fingerprint', + ) parser.add_argument( '--how_many_training_steps', type=str, @@ -423,6 +447,16 @@ if __name__ == '__main__': type=bool, default=False, help='Whether to check for invalid numbers during processing') + parser.add_argument( + '--quantize', + type=bool, + default=False, + help='Whether to train the model for eight-bit deployment') + parser.add_argument( + '--preprocess', + type=str, + default='mfcc', + help='Spectrogram processing mode. Can be "mfcc" or "average"') FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/examples/tutorials/mnist/BUILD b/tensorflow/examples/tutorials/mnist/BUILD index d7bc6a5a7d1e4cd3927c7c5067ccc22993885994..d4070fdd1e015fb78dcf2ff72fe30b6f1746c8fb 100644 --- a/tensorflow/examples/tutorials/mnist/BUILD +++ b/tensorflow/examples/tutorials/mnist/BUILD @@ -97,7 +97,7 @@ py_binary( py_test( name = "fully_connected_feed_test", - size = "small", + size = "medium", srcs = [ "fully_connected_feed.py", ], diff --git a/tensorflow/examples/tutorials/mnist/mnist_deep.py b/tensorflow/examples/tutorials/mnist/mnist_deep.py index 1e0294db27bc675870afceca77a2cdcd4b3f5ad3..5d8d8d84fe26c0a3ec69791885f3c7ce5e0fba15 100644 --- a/tensorflow/examples/tutorials/mnist/mnist_deep.py +++ b/tensorflow/examples/tutorials/mnist/mnist_deep.py @@ -34,6 +34,8 @@ from tensorflow.examples.tutorials.mnist import input_data import tensorflow as tf +import numpy + FLAGS = None @@ -164,8 +166,15 @@ def main(_): print('step %d, training accuracy %g' % (i, train_accuracy)) train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) - print('test accuracy %g' % accuracy.eval(feed_dict={ - x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})) + # compute in batches to avoid OOM on GPUs + accuracy_l = [] + for _ in range(20): + batch = mnist.test.next_batch(500, shuffle=False) + accuracy_l.append(accuracy.eval(feed_dict={x: batch[0], + y_: batch[1], + keep_prob: 1.0})) + print('test accuracy %g' % numpy.mean(accuracy_l)) + if __name__ == '__main__': parser = argparse.ArgumentParser() diff --git a/tensorflow/g3doc/README.txt b/tensorflow/g3doc/README.txt index ed648f8b6b8895010be84becd4fda25ded5859fb..515a9e9a025d9b974d4ba0cf81c3c8319f38a877 100644 --- a/tensorflow/g3doc/README.txt +++ b/tensorflow/g3doc/README.txt @@ -22,12 +22,12 @@ When authoring docs, note that we have some new syntax for references -- at least for docs coming from Python docstrings or tensorflow/docs_src/. Use: -* @{tf.symbol} to make a link to the reference page for a Python +* `tf.symbol` to make a link to the reference page for a Python symbol. Note that class members don't get their own page, but the - syntax still works, since @{tf.MyClass.method} links to the right + syntax still works, since `tf.MyClass.method` links to the right part of the tf.MyClass page. -* @{tensorflow::symbol} to make a link to the reference page for a C++ +* `tensorflow::symbol` to make a link to the reference page for a C++ symbol. (This only works for a few symbols but will work for more soon.) * @{$doc_page} to make a link to another (not an API reference) doc diff --git a/tensorflow/go/README.md b/tensorflow/go/README.md index e251356ec8e97311affaf752c0a515be97013fa8..288a32530a7ed2f4d773912591907395c82db34e 100644 --- a/tensorflow/go/README.md +++ b/tensorflow/go/README.md @@ -46,7 +46,7 @@ from source. ```sh cd ${GOPATH}/src/github.com/tensorflow/tensorflow ./configure - bazel build --config opt //tensorflow:libtensorflow.so + bazel build -c opt //tensorflow:libtensorflow.so ``` This can take a while (tens of minutes, more if also building for GPU). diff --git a/tensorflow/go/attrs.go b/tensorflow/go/attrs.go index bfa60d2aa83339a5be6020e876900c330b7174a1..f86c5737bc79f1e349e442669615598949ecd333 100644 --- a/tensorflow/go/attrs.go +++ b/tensorflow/go/attrs.go @@ -33,7 +33,8 @@ func makeCShape(shape []C.int64_t) Shape { return s } -// Attr returns the value of an attribute on op. +// Attr returns the value of an attribute on op. It returns an error if the +// attribute does not exist. func (op *Operation) Attr(name string) (interface{}, error) { cname := C.CString(name) defer C.free(unsafe.Pointer(cname)) @@ -55,9 +56,13 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf switch meta._type { case C.TF_ATTR_STRING: + if meta.list_size == 0 { + return []string(nil), nil + } values := make([]unsafe.Pointer, meta.list_size) lengths := make([]C.size_t, meta.list_size) - storage := make([]C.char, meta.total_size) + // Add one element in case total_size is zero. + storage := make([]C.char, meta.total_size+1) C.TF_OperationGetAttrStringList(op.c, cname, &values[0], &lengths[0], C.int(meta.list_size), unsafe.Pointer(&storage[0]), C.size_t(meta.total_size), status.c) if err := status.Err(); err != nil { return nil, err @@ -70,6 +75,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf return list, nil case C.TF_ATTR_INT: + if meta.list_size == 0 { + return []int64(nil), nil + } list := make([]C.int64_t, meta.list_size) C.TF_OperationGetAttrIntList(op.c, cname, &list[0], C.int(meta.list_size), status.c) if err := status.Err(); err != nil { @@ -82,6 +90,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf return vals, nil case C.TF_ATTR_FLOAT: + if meta.list_size == 0 { + return []float32(nil), nil + } list := make([]C.float, meta.list_size) C.TF_OperationGetAttrFloatList(op.c, cname, &list[0], C.int(meta.list_size), status.c) if err := status.Err(); err != nil { @@ -94,6 +105,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf return vals, nil case C.TF_ATTR_BOOL: + if meta.list_size == 0 { + return []bool(nil), nil + } list := make([]C.uchar, meta.list_size) C.TF_OperationGetAttrBoolList(op.c, cname, &list[0], C.int(meta.list_size), status.c) if err := status.Err(); err != nil { @@ -106,6 +120,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf return vals, nil case C.TF_ATTR_TYPE: + if meta.list_size == 0 { + return []DataType(nil), nil + } list := make([]C.TF_DataType, meta.list_size) C.TF_OperationGetAttrTypeList(op.c, cname, &list[0], C.int(meta.list_size), status.c) if err := status.Err(); err != nil { @@ -118,6 +135,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf return vals, nil case C.TF_ATTR_TENSOR: + if meta.list_size == 0 { + return []*Tensor(nil), nil + } list := make([]*C.TF_Tensor, meta.list_size) C.TF_OperationGetAttrTensorList(op.c, cname, &list[0], C.int(meta.list_size), status.c) if err := status.Err(); err != nil { @@ -130,9 +150,13 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf return vals, nil case C.TF_ATTR_SHAPE: + if meta.list_size == 0 { + return []Shape(nil), nil + } dims := make([]*C.int64_t, meta.list_size) numDims := make([]C.int, meta.list_size) - storage := make([]C.int64_t, meta.total_size) + // Add one element in case total_size is zero. + storage := make([]C.int64_t, meta.total_size+1) C.TF_OperationGetAttrShapeList(op.c, cname, &dims[0], &numDims[0], C.int(meta.list_size), &storage[0], C.int(meta.total_size), status.c) if err := status.Err(); err != nil { return nil, err @@ -161,6 +185,9 @@ func scalarAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (inte switch meta._type { case C.TF_ATTR_STRING: + if meta.total_size == 0 { + return "", nil + } v := make([]C.char, meta.total_size) C.TF_OperationGetAttrString(op.c, cname, unsafe.Pointer(&v[0]), C.size_t(meta.total_size), status.c) if err := status.Err(); err != nil { @@ -202,6 +229,9 @@ func scalarAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (inte if numDims < 0 { return Shape{}, nil } + if numDims == 0 { + return ScalarShape(), nil + } dims := make([]C.int64_t, numDims) C.TF_OperationGetAttrShape(op.c, cname, (*C.int64_t)(unsafe.Pointer(&dims[0])), C.int(numDims), status.c) if err := status.Err(); err != nil { diff --git a/tensorflow/go/attrs_test.go b/tensorflow/go/attrs_test.go index 18fc0de90a1470e6d7adc93e4ca0b94ee526784b..ea8af221aeef3bf1d2edeab4372ae00f0cc7e92d 100644 --- a/tensorflow/go/attrs_test.go +++ b/tensorflow/go/attrs_test.go @@ -17,31 +17,177 @@ limitations under the License. package tensorflow import ( + "fmt" "reflect" "testing" ) func TestOperationAttrs(t *testing.T) { - attrs := map[string]interface{}{ - "dtype": Float, + g := NewGraph() + + i := 0 + makeConst := func(v interface{}) Output { + op, err := Const(g, fmt.Sprintf("const/%d/%+v", i, v), v) + i++ + if err != nil { + t.Fatal(err) + } + return op } - g := NewGraph() - op, err := g.AddOperation(OpSpec{ - Type: "Placeholder", - Name: "placeholder", - Attrs: attrs, - }) - if err != nil { - t.Fatal(err) + makeTensor := func(v interface{}) *Tensor { + tensor, err := NewTensor(v) + if err != nil { + t.Fatal(err) + } + return tensor } - for key, want := range attrs { - out, err := op.Attr(key) + + cases := []OpSpec{ + { + Name: "type", + Type: "Placeholder", + Attrs: map[string]interface{}{ + "dtype": Float, + }, + }, + { + Name: "list(float)", + Type: "Bucketize", + Input: []Input{ + makeConst([]float32{1, 2, 3, 4}), + }, + Attrs: map[string]interface{}{ + "boundaries": []float32{0, 1, 2, 3, 4, 5}, + }, + }, + { + Name: "list(float) empty", + Type: "Bucketize", + Input: []Input{ + makeConst([]float32{}), + }, + Attrs: map[string]interface{}{ + "boundaries": []float32(nil), + }, + }, + /* TODO(ashankar): debug this issue and add it back later. + { + Name: "list(type),list(shape)", + Type: "InfeedEnqueueTuple", + Input: []Input{ + OutputList([]Output{ + makeConst(float32(1)), + makeConst([][]int32{{2}}), + }), + }, + Attrs: map[string]interface{}{ + "dtypes": []DataType{Float, Int32}, + "shapes": []Shape{ScalarShape(), MakeShape(1, 1)}, + }, + }, + { + Name: "list(type),list(shape) empty", + Type: "InfeedEnqueueTuple", + Input: []Input{ + OutputList([]Output{ + makeConst([][]int32{{2}}), + }), + }, + Attrs: map[string]interface{}{ + "dtypes": []DataType{Int32}, + "shapes": []Shape(nil), + }, + }, + { + Name: "list(type) empty,string empty,int", + Type: "_XlaSendFromHost", + Input: []Input{ + OutputList([]Output{}), + makeConst(""), + }, + Attrs: map[string]interface{}{ + "Tinputs": []DataType(nil), + "key": "", + "device_ordinal": int64(0), + }, + }, + */ + { + Name: "list(int),int", + Type: "StringToHashBucketStrong", + Input: []Input{ + makeConst(""), + }, + Attrs: map[string]interface{}{ + "num_buckets": int64(2), + "key": []int64{1, 2}, + }, + }, + { + Name: "list(int) empty,int", + Type: "StringToHashBucketStrong", + Input: []Input{ + makeConst(""), + }, + Attrs: map[string]interface{}{ + "num_buckets": int64(2), + "key": ([]int64)(nil), + }, + }, + { + Name: "list(string),type", + Type: "TensorSummary", + Input: []Input{ + makeConst(""), + }, + Attrs: map[string]interface{}{ + "T": String, + "labels": []string{"foo", "bar"}, + }, + }, + { + Name: "list(string) empty,type", + Type: "TensorSummary", + Input: []Input{ + makeConst(""), + }, + Attrs: map[string]interface{}{ + "T": String, + "labels": ([]string)(nil), + }, + }, + { + Name: "tensor", + Type: "Const", + Attrs: map[string]interface{}{ + "dtype": String, + "value": makeTensor("foo"), + }, + }, + } + + for i, spec := range cases { + op, err := g.AddOperation(spec) if err != nil { t.Fatal(err) } - if !reflect.DeepEqual(out, want) { - t.Fatalf("%q: Got %+v, wanted %+v", key, out, want) + for key, want := range spec.Attrs { + out, err := op.Attr(key) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(out, want) { + t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, out, want) + } + wantT, ok := want.(*Tensor) + if ok { + wantVal := wantT.Value() + outVal := out.(*Tensor).Value() + if !reflect.DeepEqual(outVal, wantVal) { + t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, outVal, wantVal) + } + } } } } diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index 08943a527cbdc072b12b066240c213be45ffd54c..32a77550ee2fa5606b402600aa6429950d8e72a5 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -177,7 +177,14 @@ type OpSpec struct { // being added. ControlDependencies []*Operation - // Other possible fields: Device, ColocateWith. + // The device on which the operation should be executed. + // If omitted, an appropriate device will automatically be selected. + // + // For example, if set of "/device:GPU:0", then the operation will + // execute on GPU #0. + Device string + + // Other possible fields: ColocateWith. } // AddOperation adds an operation to g. @@ -225,6 +232,11 @@ func (g *Graph) AddOperation(args OpSpec) (*Operation, error) { return nil, fmt.Errorf("%v (memory will be leaked)", err) } } + if len(args.Device) > 0 { + cdevice := C.CString(args.Device) + C.TF_SetDevice(cdesc, cdevice) + C.free(unsafe.Pointer(cdevice)) + } c := C.TF_FinishOperation(cdesc, status.c) if err := status.Err(); err != nil { return nil, err diff --git a/tensorflow/go/op/scope.go b/tensorflow/go/op/scope.go index 13de4294dc2ebdfff9bb68d277c09239d0bc8593..ac39808d838f4737b81b170d3f540d10ed38fe42 100644 --- a/tensorflow/go/op/scope.go +++ b/tensorflow/go/op/scope.go @@ -37,6 +37,7 @@ type Scope struct { namemap map[string]int namespace string controlDependencies []*tf.Operation + device string err *scopeErr } @@ -82,6 +83,7 @@ func (s *Scope) AddOperation(args tf.OpSpec) *tf.Operation { args.Name = s.namespace + "/" + args.Name } args.ControlDependencies = append(args.ControlDependencies, s.controlDependencies...) + args.Device = s.device op, err := s.graph.AddOperation(args) if err != nil { s.UpdateErr(args.Type, err) @@ -98,10 +100,12 @@ func (s *Scope) SubScope(namespace string) *Scope { namespace = s.namespace + "/" + namespace } return &Scope{ - graph: s.graph, - namemap: make(map[string]int), - namespace: namespace, - err: s.err, + graph: s.graph, + namemap: make(map[string]int), + namespace: namespace, + controlDependencies: s.controlDependencies, + device: s.device, + err: s.err, } } @@ -123,6 +127,25 @@ func (s *Scope) WithControlDependencies(ops ...*tf.Operation) *Scope { namemap: s.namemap, namespace: s.namespace, controlDependencies: deps, + device: s.device, + err: s.err, + } +} + +// WithDevice returns a new Scope which will cause all operations added to the +// graph to execute on devices that match the provided device specification. +// +// For example, WithDevice("/device:GPU:0") will cause operations added to +// the graph to execute on GPU #0. +// +// An empty string removes any device restrictions. +func (s *Scope) WithDevice(device string) *Scope { + return &Scope{ + graph: s.graph, + namemap: s.namemap, + namespace: s.namespace, + controlDependencies: s.controlDependencies, + device: device, err: s.err, } } diff --git a/tensorflow/go/op/scope_test.go b/tensorflow/go/op/scope_test.go index b58a61de98b0f5b04959e1eca35c6b6c4d77e42b..be7b0ad8926aadac47218b7625036d7e12b9554b 100644 --- a/tensorflow/go/op/scope_test.go +++ b/tensorflow/go/op/scope_test.go @@ -112,6 +112,21 @@ func TestControlDependencies(t *testing.T) { } } +func TestDevice(t *testing.T) { + s := NewScope() + matrix := Const(s, [][]float32{{3.0}}) + s = s.WithDevice("/device:GPU:0") + square := MatMul(s.SubScope("square"), matrix, matrix) + s = s.WithDevice("") + cube := MatMul(s.SubScope("cube"), square, matrix) + if got, want := square.Op.Device(), "/device:GPU:0"; got != want { + t.Errorf("Got %q, want %q", got, want) + } + if got, want := cube.Op.Device(), ""; got != want { + t.Errorf("Got %q, want %q", got, want) + } +} + func TestScopeFinalize(t *testing.T) { var ( root = NewScope() diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index a5224fbda04fdfb2460fce96efffb6eab4f08551..3e0ea619e3596123870aca7bc45cdba3736684ce 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -327,15 +327,19 @@ func FakeQuantWithMinMaxArgs(scope *Scope, inputs tf.Output, optional ...FakeQua return op.Output(0) } -// Scatter `updates` into a new (initially zero) tensor according to `indices`. +// Scatter `updates` into a new tensor according to `indices`. // -// Creates a new tensor by applying sparse `updates` to individual -// values or slices within a zero tensor of the given `shape` according to -// indices. This operator is the inverse of the @{tf.gather_nd} operator which -// extracts values or slices from a given tensor. +// Creates a new tensor by applying sparse `updates` to individual values or +// slices within a tensor (initially zero for numeric, empty for string) of +// the given `shape` according to indices. This operator is the inverse of the +// @{tf.gather_nd} operator which extracts values or slices from a given tensor. +// +// If `indices` contains duplicates, then their updates are accumulated (summed). // // **WARNING**: The order in which updates are applied is nondeterministic, so the -// output will be nondeterministic if `indices` contains duplicates. +// output will be nondeterministic if `indices` contains duplicates -- because +// of some numerical approximation issues, numbers summed in different order +// may yield different results. // // `indices` is an integer tensor containing indices into a new tensor of shape // `shape`. The last dimension of `indices` can be at most the rank of `shape`: @@ -430,7 +434,8 @@ type QuantizeAndDequantizeV2Attr func(optionalAttr) // QuantizeAndDequantizeV2SignedInput sets the optional signed_input attribute to value. // -// value: If the quantization is signed or unsigned. +// value: Whether the quantization is signed or unsigned. (actually this parameter should +// have been called `signed_output`) // If not specified, defaults to true func QuantizeAndDequantizeV2SignedInput(value bool) QuantizeAndDequantizeV2Attr { return func(m optionalAttr) { @@ -450,7 +455,7 @@ func QuantizeAndDequantizeV2NumBits(value int64) QuantizeAndDequantizeV2Attr { // QuantizeAndDequantizeV2RangeGiven sets the optional range_given attribute to value. // -// value: If the range is given or should be computed from the tensor. +// value: Whether the range is given or should be determined from the `input` tensor. // If not specified, defaults to false func QuantizeAndDequantizeV2RangeGiven(value bool) QuantizeAndDequantizeV2Attr { return func(m optionalAttr) { @@ -461,61 +466,64 @@ func QuantizeAndDequantizeV2RangeGiven(value bool) QuantizeAndDequantizeV2Attr { // Quantizes then dequantizes a tensor. // // This op simulates the precision loss from the quantized forward pass by: +// // 1. Quantizing the tensor to fixed point numbers, which should match the target // quantization method when it is used in inference. // 2. Dequantizing it back to floating point numbers for the following ops, most // likely matmul. // -// There are different ways to quantize. This version does not use the full range -// of the output type, choosing to elide the lowest possible value for symmetry -// (e.g., output range is -127 to 127, not -128 to 127 for signed 8 bit -// quantization), so that 0.0 maps to 0. -// -// To perform this op, we first find the range of values in our tensor. The range -// we use is always centered on 0, so we find m such that -// -// 1. m = max(abs(input_min), abs(input_max)) if range_given is true, -// 2. m = max(abs(min_elem(input)), abs(max_elem(input))) otherwise. +// There are different ways to quantize. This version uses only scaling, so 0.0 +// maps to 0. // -// Our input tensor range is then [-m, m]. +// From the specified 'num_bits' in the quantized output type, it determines +// minimum and maximum representable quantized values. // -// Next, we choose our fixed-point quantization buckets, [min_fixed, max_fixed]. -// If signed_input is true, this is +// e.g. // -// [min_fixed, max_fixed ] = -// [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1]. +// * [-128, 127] for signed, num_bits = 8, or +// * [0, 255] for unsigned, num_bits = 8. // -// Otherwise, if signed_input is false, the fixed-point range is +// If range_given == False, the initial input_min, input_max will be determined +// automatically as the minimum and maximum values in the input tensor, otherwise +// the specified values of input_min, input_max are used. // -// [min_fixed, max_fixed] = [0, (1 << num_bits) - 1]. +// Note: If the input_min, input_max are specified, they do not need to equal the +// actual minimum and maximum values in the tensor. e.g. in some cases it may be +// beneficial to specify these values such that the low probability extremes of the +// input distribution are clipped. // -// From this we compute our scaling factor, s: -// -// s = (max_fixed - min_fixed) / (2 * m). +// This op determines the maximum scale_factor that would map the initial +// [input_min, input_max] range to a range that lies within the representable +// quantized range. // -// Now we can quantize and dequantize the elements of our tensor. An element e -// is transformed into e': +// It determines the scale from one of input_min and input_max, then updates the +// other one to maximize the respresentable range. // -// e' = (e * s).round_to_nearest() / s. +// e.g. // -// Note that we have a different number of buckets in the signed vs. unsigned -// cases. For example, if num_bits == 8, we get 254 buckets in the signed case -// vs. 255 in the unsigned case. +// * if the output is signed, num_bits = 8, [input_min, input_max] = [-10.0, +// 5.0]: it would use a scale_factor of -128 / -10.0 = 12.8 In this case, it +// would update input_max to be 127 / 12.8 = 9.921875 +// * if the output is signed, num_bits = 8, [input_min, input_max] = [-10.0, +// 10.0]: it would use a scale_factor of 127 / 10.0 = 12.7 In this case, it +// would update input_min to be 128.0 / 12.7 = -10.07874 +// * if the output is unsigned, input_min is forced to be 0, and only the +// specified input_max is used. // -// For example, suppose num_bits = 8 and m = 1. Then +// After determining the scale_factor and updating the input range, it applies the +// following to each value in the 'input' tensor. // -// [min_fixed, max_fixed] = [-127, 127], and -// s = (127 + 127) / 2 = 127. +// output = round(clamp(value, input_min, input_max) * scale_factor) / scale_factor. // -// Given the vector {-1, -0.5, 0, 0.3}, this is quantized to -// {-127, -63, 0, 38}, and dequantized to {-1, -63.0/127, 0, 38.0/127}. // // Arguments: // input: Tensor to quantize and then dequantize. -// input_min: If range_given, this is the min of the range, otherwise this input -// will be ignored. -// input_max: If range_given, this is the max of the range, otherwise this input -// will be ignored. +// input_min: If `range_given == True`, this specifies the minimum input value that needs to +// be represented, otherwise it is determined from the min value of the `input` +// tensor. +// input_max: If `range_given == True`, this specifies the maximum input value that needs to +// be represented, otherwise it is determined from the max value of the `input` +// tensor. func QuantizeAndDequantizeV2(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, optional ...QuantizeAndDequantizeV2Attr) (output tf.Output) { if scope.Err() != nil { return @@ -2249,7 +2257,7 @@ func CheckNumerics(scope *Scope, tensor tf.Output, message string) (output tf.Ou // (K-1)-dimensional tensor of indices into `params`, where each element defines a // slice of `params`: // -// output[i_0, ..., i_{K-2}] = params[indices[i0, ..., i_{K-2}]] +// output[\\(i_0, ..., i_{K-2}\\)] = params[indices[\\(i_0, ..., i_{K-2}\\)]] // // Whereas in @{tf.gather} `indices` defines slices into the first // dimension of `params`, in `tf.gather_nd`, `indices` defines slices into the @@ -2610,70 +2618,6 @@ func Reverse(scope *Scope, tensor tf.Output, dims tf.Output) (output tf.Output) return op.Output(0) } -// Copy a tensor setting everything outside a central band in each innermost matrix -// -// to zero. -// -// The `band` part is computed as follows: -// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a -// tensor with the same shape where -// -// `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`. -// -// The indicator function -// -// `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && -// (num_upper < 0 || (n-m) <= num_upper)`. -// -// For example: -// -// ``` -// # if 'input' is [[ 0, 1, 2, 3] -// [-1, 0, 1, 2] -// [-2, -1, 0, 1] -// [-3, -2, -1, 0]], -// -// tf.matrix_band_part(input, 1, -1) ==> [[ 0, 1, 2, 3] -// [-1, 0, 1, 2] -// [ 0, -1, 0, 1] -// [ 0, 0, -1, 0]], -// -// tf.matrix_band_part(input, 2, 1) ==> [[ 0, 1, 0, 0] -// [-1, 0, 1, 0] -// [-2, -1, 0, 1] -// [ 0, -2, -1, 0]] -// ``` -// -// Useful special cases: -// -// ``` -// tf.matrix_band_part(input, 0, -1) ==> Upper triangular part. -// tf.matrix_band_part(input, -1, 0) ==> Lower triangular part. -// tf.matrix_band_part(input, 0, 0) ==> Diagonal. -// ``` -// -// Arguments: -// input: Rank `k` tensor. -// num_lower: 0-D tensor. Number of subdiagonals to keep. If negative, keep entire -// lower triangle. -// num_upper: 0-D tensor. Number of superdiagonals to keep. If negative, keep -// entire upper triangle. -// -// Returns Rank `k` tensor of the same shape as input. The extracted banded tensor. -func MatrixBandPart(scope *Scope, input tf.Output, num_lower tf.Output, num_upper tf.Output) (band tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MatrixBandPart", - Input: []tf.Input{ - input, num_lower, num_upper, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Returns the batched diagonal part of a batched tensor. // // This operation returns a tensor with the `diagonal` part @@ -2990,170 +2934,161 @@ func Split(scope *Scope, axis tf.Output, value tf.Output, num_split int64) (outp return output } -// Creates a sequence of numbers. -// -// This operation creates a sequence of numbers that begins at `start` and -// extends by increments of `delta` up to but not including `limit`. -// -// For example: -// -// ``` -// # 'start' is 3 -// # 'limit' is 18 -// # 'delta' is 3 -// tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] -// ``` +// Concatenates tensors along one dimension. // // Arguments: -// start: 0-D (scalar). First entry in the sequence. -// limit: 0-D (scalar). Upper limit of sequence, exclusive. -// delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`. +// concat_dim: 0-D. The dimension along which to concatenate. Must be in the +// range [0, rank(values)). +// values: The `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. // -// Returns 1-D. -func Range(scope *Scope, start tf.Output, limit tf.Output, delta tf.Output) (output tf.Output) { +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes. +func Concat(scope *Scope, concat_dim tf.Output, values []tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Range", + Type: "Concat", Input: []tf.Input{ - start, limit, delta, + concat_dim, tf.OutputList(values), }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes gradients for SparseSegmentSqrtN. +// Broadcast an array for a compatible shape. // -// Returns tensor "output" with same shape as grad, except for dimension 0 whose -// value is output_dim0. +// Broadcasting is the process of making arrays to have compatible shapes +// for arithmetic operations. Two shapes are compatible if for each +// dimension pair they are either equal or one of them is one. When trying +// to broadcast a Tensor to a shape, it starts with the trailing dimensions, +// and works its way forward. +// +// For example, +// ``` +// >>> x = tf.constant([1, 2, 3]) +// >>> y = tf.broadcast_to(x, [3, 3]) +// >>> sess.run(y) +// array([[1, 2, 3], +// [1, 2, 3], +// [1, 2, 3]], dtype=int32) +// ``` +// In the above example, the input Tensor with the shape of `[1, 3]` +// is broadcasted to output Tensor with shape of `[3, 3]`. // // Arguments: -// grad: gradient propagated to the SparseSegmentSqrtN op. -// indices: indices passed to the corresponding SparseSegmentSqrtN op. -// segment_ids: segment_ids passed to the corresponding SparseSegmentSqrtN op. -// output_dim0: dimension 0 of "data" passed to SparseSegmentSqrtN op. -func SparseSegmentSqrtNGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) { +// input: A Tensor to broadcast. +// shape: An 1-D `int` Tensor. The shape of the desired output. +// +// Returns A Tensor. +func BroadcastTo(scope *Scope, input tf.Output, shape tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseSegmentSqrtNGrad", + Type: "BroadcastTo", Input: []tf.Input{ - grad, indices, segment_ids, output_dim0, + input, shape, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the mean along sparse segments of a tensor. +// Converts a flat index or array of flat indices into a tuple of // -// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of -// segments. +// coordinate arrays. // -// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first -// dimension, selecting a subset of dimension 0, specified by `indices`. +// @compatibility(numpy) +// Equivalent to np.unravel_index +// @end_compatibility // // Arguments: +// indices: An 0-D or 1-D `int` Tensor whose elements are indices into the +// flattened version of an array of dimensions dims. +// dims: An 1-D `int` Tensor. The shape of the array to use for unraveling +// indices. // -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { +// Returns An 2-D (or 1-D if indices is 0-D) tensor where each row has the +// same shape as the indices array. +func UnravelIndex(scope *Scope, indices tf.Output, dims tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseSegmentMean", + Type: "UnravelIndex", Input: []tf.Input{ - data, indices, segment_ids, + indices, dims, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Pop the element at the top of the stack. +// Subtracts `v` into specified rows of `x`. +// +// Computes y = x; y[i, :] -= v; return y. // // Arguments: -// handle: The handle to a stack. -// elem_type: The type of the elem that is popped. +// x: A `Tensor` of type T. +// i: A vector. Indices into the left-most dimension of `x`. +// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size. // -// Returns The tensor that is popped from the top of the stack. -func StackPopV2(scope *Scope, handle tf.Output, elem_type tf.DataType) (elem tf.Output) { +// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`. +func InplaceSub(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"elem_type": elem_type} opspec := tf.OpSpec{ - Type: "StackPopV2", + Type: "InplaceSub", Input: []tf.Input{ - handle, + x, i, v, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// PreventGradientAttr is an optional argument to PreventGradient. -type PreventGradientAttr func(optionalAttr) - -// PreventGradientMessage sets the optional message attribute to value. -// -// value: Will be printed in the error when anyone tries to differentiate -// this operation. -// If not specified, defaults to "" -func PreventGradientMessage(value string) PreventGradientAttr { - return func(m optionalAttr) { - m["message"] = value - } -} - -// An identity op that triggers an error if a gradient is requested. -// -// When executed in a graph, this op outputs its input tensor as-is. +// Updates specified rows with values in `v`. // -// When building ops to compute gradients, the TensorFlow gradient system -// will return an error when trying to lookup the gradient of this op, -// because no gradient must ever be registered for this function. This -// op exists to prevent subtle bugs from silently returning unimplemented -// gradients in some corner cases. +// Computes `x[i, :] = v; return x`. // // Arguments: -// input: any tensor. +// x: A tensor of type `T`. +// i: A vector. Indices into the left-most dimension of `x`. +// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size. // -// Returns the same input tensor. -func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) { +// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`. +func InplaceUpdate(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "PreventGradient", + Type: "InplaceUpdate", Input: []tf.Input{ - input, + x, i, v, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes asin of x element-wise. -func Asin(scope *Scope, x tf.Output) (y tf.Output) { +// Makes a copy of `x`. +// +// Arguments: +// x: The source tensor of type `T`. +// +// Returns y: A `Tensor` of type `T`. A copy of `x`. Guaranteed that `y` +// is not an alias of `x`. +func DeepCopy(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Asin", + Type: "DeepCopy", Input: []tf.Input{ x, }, @@ -3162,342 +3097,309 @@ func Asin(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Computes the sum along sparse segments of a tensor. +// PackAttr is an optional argument to Pack. +type PackAttr func(optionalAttr) + +// PackAxis sets the optional axis attribute to value. // -// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of -// segments. +// value: Dimension along which to pack. Negative values wrap around, so the +// valid range is `[-(R+1), R+1)`. +// If not specified, defaults to 0 +func PackAxis(value int64) PackAttr { + return func(m optionalAttr) { + m["axis"] = value + } +} + +// Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor. // -// Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first -// dimension, selecting a subset of dimension 0, specified by `indices`. +// Packs the `N` tensors in `values` into a tensor with rank one higher than each +// tensor in `values`, by packing them along the `axis` dimension. +// Given a list of tensors of shape `(A, B, C)`; +// +// if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`. +// if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`. +// Etc. // // For example: // -// ```python -// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) +// ``` +// # 'x' is [1, 4] +// # 'y' is [2, 5] +// # 'z' is [3, 6] +// pack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. +// pack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]] +// ``` // -// # Select two rows, one segment. -// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0])) -// # => [[0 0 0 0]] -// -// # Select two rows, two segment. -// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1])) -// # => [[ 1 2 3 4] -// # [-1 -2 -3 -4]] -// -// # Select all rows, two segments. -// tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1])) -// # => [[0 0 0 0] -// # [5 6 7 8]] -// -// # Which is equivalent to: -// tf.segment_sum(c, tf.constant([0, 0, 1])) -// ``` +// This is the opposite of `unpack`. // // Arguments: +// values: Must be of same shape and type. // -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SparseSegmentSum(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { +// Returns The packed tensor. +func Pack(scope *Scope, values []tf.Output, optional ...PackAttr) (output tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "SparseSegmentSum", - Input: []tf.Input{ - data, indices, segment_ids, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes hyperbolic sine of x element-wise. -func Sinh(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) } opspec := tf.OpSpec{ - Type: "Sinh", + Type: "Pack", Input: []tf.Input{ - x, + tf.OutputList(values), }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the minimum along segments of a tensor. +// Concatenates a list of `N` tensors along the first dimension. // -// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of -// segments. +// The input tensors are all required to have size 1 in the first dimension. // -// This operator is similar to the unsorted segment sum operator found -// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). -// Instead of computing the sum over segments, it computes the minimum such that: +// For example: // -// \\(output_i = \min_j data_j\\) where min is over `j` such -// that `segment_ids[j] == i`. +// ``` +// # 'x' is [[1, 4]] +// # 'y' is [[2, 5]] +// # 'z' is [[3, 6]] +// parallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. +// ``` // -// If the minimum is empty for a given segment ID `i`, it outputs the largest -// possible value for the specific numeric type, -// `output[i] = numeric_limits::max()`. +// The difference between concat and parallel_concat is that concat requires all +// of the inputs be computed before the operation will begin but doesn't require +// that the input shapes be known during graph construction. Parallel concat +// will copy pieces of the input into the output as they become available, in +// some situations this can provide a performance benefit. // // Arguments: +// values: Tensors to be concatenated. All must have size 1 in the first dimension +// and same shape. +// shape: the final shape of the result; should be equal to the shapes of any input +// but with the number of input values in the first dimension. // -// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s -// first dimension. -// -// -// Returns Has same shape as data, except for dimension 0 which -// has size `num_segments`. -func UnsortedSegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { +// Returns The concatenated tensor. +func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"shape": shape} opspec := tf.OpSpec{ - Type: "UnsortedSegmentMin", + Type: "ParallelConcat", Input: []tf.Input{ - data, segment_ids, num_segments, + tf.OutputList(values), }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes rectified linear 6: `min(max(features, 0), 6)`. -func Relu6(scope *Scope, features tf.Output) (activations tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Relu6", - Input: []tf.Input{ - features, - }, +// DecodeWavAttr is an optional argument to DecodeWav. +type DecodeWavAttr func(optionalAttr) + +// DecodeWavDesiredChannels sets the optional desired_channels attribute to value. +// +// value: Number of sample channels wanted. +// If not specified, defaults to -1 +func DecodeWavDesiredChannels(value int64) DecodeWavAttr { + return func(m optionalAttr) { + m["desired_channels"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Computes the sum along segments of a tensor. +// DecodeWavDesiredSamples sets the optional desired_samples attribute to value. // -// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of -// segments. +// value: Length of audio requested. +// If not specified, defaults to -1 +func DecodeWavDesiredSamples(value int64) DecodeWavAttr { + return func(m optionalAttr) { + m["desired_samples"] = value + } +} + +// Decode a 16-bit PCM WAV file to a float tensor. // -// Computes a tensor such that -// `(output[i] = sum_{j...} data[j...]` where the sum is over tuples `j...` such -// that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids` -// need not be sorted and need not cover all values in the full -// range of valid values. +// The -32768 to 32767 signed 16-bit values will be scaled to -1.0 to 1.0 in float. // -// If the sum is empty for a given segment ID `i`, `output[i] = 0`. -// If the given segment ID `i` is negative, the value is dropped and will not be -// added to the sum of the segment. +// When desired_channels is set, if the input contains fewer channels than this +// then the last channel will be duplicated to give the requested number, else if +// the input has more channels than requested then the additional channels will be +// ignored. // -// `num_segments` should equal the number of distinct segment IDs. +// If desired_samples is set, then the audio will be cropped or padded with zeroes +// to the requested length. // -//
-// -//
+// The first output contains a Tensor with the content of the audio samples. The +// lowest dimension will be the number of channels, and the second will be the +// number of samples. For example, a ten-sample-long stereo WAV file should give an +// output shape of [10, 2]. // // Arguments: +// contents: The WAV-encoded audio, usually from a file. // -// segment_ids: A tensor whose shape is a prefix of `data.shape`. -// -// -// Returns Has same shape as data, except for the first `segment_ids.rank` -// dimensions, which are replaced with a single dimension which has size -// `num_segments`. -func UnsortedSegmentSum(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { +// Returns 2-D with shape `[length, channels]`.Scalar holding the sample rate found in the WAV header. +func DecodeWav(scope *Scope, contents tf.Output, optional ...DecodeWavAttr) (audio tf.Output, sample_rate tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "UnsortedSegmentSum", + Type: "DecodeWav", Input: []tf.Input{ - data, segment_ids, num_segments, + contents, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceStridedSliceAssignAttr is an optional argument to ResourceStridedSliceAssign. -type ResourceStridedSliceAssignAttr func(optionalAttr) - -// ResourceStridedSliceAssignBeginMask sets the optional begin_mask attribute to value. -// If not specified, defaults to 0 -func ResourceStridedSliceAssignBeginMask(value int64) ResourceStridedSliceAssignAttr { - return func(m optionalAttr) { - m["begin_mask"] = value - } -} - -// ResourceStridedSliceAssignEndMask sets the optional end_mask attribute to value. -// If not specified, defaults to 0 -func ResourceStridedSliceAssignEndMask(value int64) ResourceStridedSliceAssignAttr { - return func(m optionalAttr) { - m["end_mask"] = value - } + return op.Output(0), op.Output(1) } -// ResourceStridedSliceAssignEllipsisMask sets the optional ellipsis_mask attribute to value. -// If not specified, defaults to 0 -func ResourceStridedSliceAssignEllipsisMask(value int64) ResourceStridedSliceAssignAttr { - return func(m optionalAttr) { - m["ellipsis_mask"] = value - } -} +// UnbatchAttr is an optional argument to Unbatch. +type UnbatchAttr func(optionalAttr) -// ResourceStridedSliceAssignNewAxisMask sets the optional new_axis_mask attribute to value. -// If not specified, defaults to 0 -func ResourceStridedSliceAssignNewAxisMask(value int64) ResourceStridedSliceAssignAttr { +// UnbatchContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func UnbatchContainer(value string) UnbatchAttr { return func(m optionalAttr) { - m["new_axis_mask"] = value + m["container"] = value } } -// ResourceStridedSliceAssignShrinkAxisMask sets the optional shrink_axis_mask attribute to value. -// If not specified, defaults to 0 -func ResourceStridedSliceAssignShrinkAxisMask(value int64) ResourceStridedSliceAssignAttr { +// UnbatchSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func UnbatchSharedName(value string) UnbatchAttr { return func(m optionalAttr) { - m["shrink_axis_mask"] = value + m["shared_name"] = value } } -// Assign `value` to the sliced l-value reference of `ref`. -// -// The values of `value` are assigned to the positions in the variable -// `ref` that are selected by the slice parameters. The slice parameters -// `begin, `end`, `strides`, etc. work exactly as in `StridedSlice`. +// Reverses the operation of Batch for a single output Tensor. // -// NOTE this op currently does not support broadcasting and so `value`'s -// shape must be exactly the shape produced by the slice of `ref`. +// An instance of Unbatch either receives an empty batched_tensor, in which case it +// asynchronously waits until the values become available from a concurrently +// running instance of Unbatch with the same container and shared_name, or receives +// a non-empty batched_tensor in which case it finalizes all other concurrently +// running instances and outputs its own element from the batch. // -// Returns the created operation. -func ResourceStridedSliceAssign(scope *Scope, ref tf.Output, begin tf.Output, end tf.Output, strides tf.Output, value tf.Output, optional ...ResourceStridedSliceAssignAttr) (o *tf.Operation) { +// batched_tensor: The possibly transformed output of Batch. The size of the first +// dimension should remain unchanged by the transformations for the operation to +// work. +// batch_index: The matching batch_index obtained from Batch. +// id: The id scalar emitted by Batch. +// unbatched_tensor: The Tensor corresponding to this execution. +// timeout_micros: Maximum amount of time (in microseconds) to wait to receive the +// batched input tensor associated with a given invocation of the op. +// container: Container to control resource sharing. +// shared_name: Instances of Unbatch with the same container and shared_name are +// assumed to possibly belong to the same batch. If left empty, the op name will +// be used as the shared name. +func Unbatch(scope *Scope, batched_tensor tf.Output, batch_index tf.Output, id tf.Output, timeout_micros int64, optional ...UnbatchAttr) (unbatched_tensor tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"timeout_micros": timeout_micros} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceStridedSliceAssign", + Type: "Unbatch", Input: []tf.Input{ - ref, begin, end, strides, value, + batched_tensor, batch_index, id, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// ArgMaxAttr is an optional argument to ArgMax. -type ArgMaxAttr func(optionalAttr) - -// ArgMaxOutputType sets the optional output_type attribute to value. -// If not specified, defaults to DT_INT64 -func ArgMaxOutputType(value tf.DataType) ArgMaxAttr { - return func(m optionalAttr) { - m["output_type"] = value - } -} - -// Returns the index with the largest value across dimensions of a tensor. -// -// Note that in case of ties the identity of the return value is not guaranteed. -// -// Arguments: +// Elementwise computes the bitwise left-shift of `x` and `y`. // -// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`. -// Describes which dimension of the input Tensor to reduce across. For vectors, -// use dimension = 0. -func ArgMax(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMaxAttr) (output tf.Output) { +// If `y` is negative, or greater than or equal to the width of `x` in bits the +// result is implementation defined. +func LeftShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ArgMax", + Type: "LeftShift", Input: []tf.Input{ - input, dimension, + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns which elements of x are finite. +// Elementwise computes the bitwise XOR of `x` and `y`. // -// @compatibility(numpy) -// Equivalent to np.isfinite -// @end_compatibility -func IsFinite(scope *Scope, x tf.Output) (y tf.Output) { +// The result will have those bits set, that are different in `x` and `y`. The +// computation is performed on the underlying representations of `x` and `y`. +func BitwiseXor(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "IsFinite", + Type: "BitwiseXor", Input: []tf.Input{ - x, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// MatMulAttr is an optional argument to MatMul. -type MatMulAttr func(optionalAttr) - -// MatMulTransposeA sets the optional transpose_a attribute to value. +// Computes the mean along sparse segments of a tensor. // -// value: If true, "a" is transposed before multiplication. -// If not specified, defaults to false -func MatMulTransposeA(value bool) MatMulAttr { - return func(m optionalAttr) { - m["transpose_a"] = value - } -} - -// MatMulTransposeB sets the optional transpose_b attribute to value. +// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of +// segments. // -// value: If true, "b" is transposed before multiplication. -// If not specified, defaults to false -func MatMulTransposeB(value bool) MatMulAttr { - return func(m optionalAttr) { - m["transpose_b"] = value +// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first +// dimension, selecting a subset of dimension 0, specified by `indices`. +// +// Arguments: +// +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { + if scope.Err() != nil { + return } + opspec := tf.OpSpec{ + Type: "SparseSegmentMean", + Input: []tf.Input{ + data, indices, segment_ids, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Multiply the matrix "a" by the matrix "b". +// Pop the element at the top of the stack. // -// The inputs must be two-dimensional matrices and the inner dimension of -// "a" (after being transposed if transpose_a is true) must match the -// outer dimension of "b" (after being transposed if transposed_b is -// true). +// Arguments: +// handle: The handle to a stack. +// elem_type: The type of the elem that is popped. // -// *Note*: The default kernel implementation for MatMul on GPUs uses -// cublas. -func MatMul(scope *Scope, a tf.Output, b tf.Output, optional ...MatMulAttr) (product tf.Output) { +// Returns The tensor that is popped from the top of the stack. +func StackPopV2(scope *Scope, handle tf.Output, elem_type tf.DataType) (elem tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"elem_type": elem_type} opspec := tf.OpSpec{ - Type: "MatMul", + Type: "StackPopV2", Input: []tf.Input{ - a, b, + handle, }, Attrs: attrs, } @@ -3505,335 +3407,336 @@ func MatMul(scope *Scope, a tf.Output, b tf.Output, optional ...MatMulAttr) (pro return op.Output(0) } -// Selects elements from `x` or `y`, depending on `condition`. -// -// The `x`, and `y` tensors must all have the same shape, and the -// output will also have that shape. -// -// The `condition` tensor must be a scalar if `x` and `y` are scalars. -// If `x` and `y` are vectors or higher rank, then `condition` must be either a -// scalar, a vector with size matching the first dimension of `x`, or must have -// the same shape as `x`. +// Computes the sum along sparse segments of a tensor. // -// The `condition` tensor acts as a mask that chooses, based on the value at each -// element, whether the corresponding element / row in the output should be -// taken from `x` (if true) or `y` (if false). +// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is +// misisng, the `output` tensor at that position will be zeroed. // -// If `condition` is a vector and `x` and `y` are higher rank matrices, then -// it chooses which row (outer dimension) to copy from `x` and `y`. -// If `condition` has the same shape as `x` and `y`, then it chooses which -// element to copy from `x` and `y`. +// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of +// segments. // // For example: // // ```python -// # 'condition' tensor is [[True, False] -// # [False, True]] -// # 't' is [[1, 2], -// # [3, 4]] -// # 'e' is [[5, 6], -// # [7, 8]] -// select(condition, t, e) # => [[1, 6], [7, 4]] -// +// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) // -// # 'condition' tensor is [True, False] -// # 't' is [[1, 2], -// # [3, 4]] -// # 'e' is [[5, 6], -// # [7, 8]] -// select(condition, t, e) ==> [[1, 2], -// [7, 8]] +// tf.sparse_segment_sum_with_num_segments( +// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3) +// # => [[0 0 0 0] +// # [0 0 0 0] +// # [0 0 0 0]] // +// tf.sparse_segment_sum_with_num_segments(c, +// tf.constant([0, 1]), +// tf.constant([0, 2], +// num_segments=4)) +// # => [[ 1 2 3 4] +// # [ 0 0 0 0] +// # [-1 -2 -3 -4] +// # [ 0 0 0 0]] // ``` // // Arguments: // -// x: = A `Tensor` which may have the same shape as `condition`. -// If `condition` is rank 1, `x` may have higher rank, -// but its first dimension must match the size of `condition`. -// y: = A `Tensor` with the same type and shape as `x`. +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// num_segments: Should equal the number of distinct segment IDs. // -// Returns = A `Tensor` with the same type and shape as `x` and `y`. -func Select(scope *Scope, condition tf.Output, x tf.Output, y tf.Output) (output tf.Output) { +// Returns Has same shape as data, except for dimension 0 which +// has size `num_segments`. +func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Select", + Type: "SparseSegmentSumWithNumSegments", Input: []tf.Input{ - condition, x, y, + data, indices, segment_ids, num_segments, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the truth value of x OR y element-wise. +// PreventGradientAttr is an optional argument to PreventGradient. +type PreventGradientAttr func(optionalAttr) + +// PreventGradientMessage sets the optional message attribute to value. // -// *NOTE*: `LogicalOr` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func LogicalOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LogicalOr", - Input: []tf.Input{ - x, y, - }, +// value: Will be printed in the error when anyone tries to differentiate +// this operation. +// If not specified, defaults to "" +func PreventGradientMessage(value string) PreventGradientAttr { + return func(m optionalAttr) { + m["message"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Compute the regularized incomplete beta integral \\(I_x(a, b)\\). -// -// The regularized incomplete beta integral is defined as: -// -// -// \\(I_x(a, b) = \frac{B(x; a, b)}{B(a, b)}\\) -// -// where +// An identity op that triggers an error if a gradient is requested. // +// When executed in a graph, this op outputs its input tensor as-is. // -// \\(B(x; a, b) = \int_0^x t^{a-1} (1 - t)^{b-1} dt\\) +// When building ops to compute gradients, the TensorFlow gradient system +// will return an error when trying to lookup the gradient of this op, +// because no gradient must ever be registered for this function. This +// op exists to prevent subtle bugs from silently returning unimplemented +// gradients in some corner cases. // +// Arguments: +// input: any tensor. // -// is the incomplete beta function and \\(B(a, b)\\) is the *complete* -// beta function. -func Betainc(scope *Scope, a tf.Output, b tf.Output, x tf.Output) (z tf.Output) { +// Returns the same input tensor. +func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "Betainc", - Input: []tf.Input{ - a, b, x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Return a tensor with the same shape and contents as the input tensor or value. -func Identity(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) } opspec := tf.OpSpec{ - Type: "Identity", + Type: "PreventGradient", Input: []tf.Input{ input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes arctangent of `y/x` element-wise, respecting signs of the arguments. -// -// This is the angle \( \theta \in [-\pi, \pi] \) such that -// \[ x = r \cos(\theta) \] -// and -// \[ y = r \sin(\theta) \] -// where \(r = \sqrt(x^2 + y^2) \). -func Atan2(scope *Scope, y tf.Output, x tf.Output) (z tf.Output) { +// Computes asin of x element-wise. +func Asin(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Atan2", + Type: "Asin", Input: []tf.Input{ - y, x, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Creates a dataset that passes a sliding window over `input_dataset`. +// Computes the sum along sparse segments of a tensor. // -// Arguments: +// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of +// segments. // -// window_size: A scalar representing the number of elements in the -// sliding window. -// stride: A scalar representing the steps moving the sliding window -// forward in one iteration. It must be in `[1, window_size)`. +// Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first +// dimension, selecting a subset of dimension 0, specified by `indices`. // +// For example: // -func SlideDataset(scope *Scope, input_dataset tf.Output, window_size tf.Output, stride tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "SlideDataset", - Input: []tf.Input{ - input_dataset, window_size, stride, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the sum along sparse segments of a tensor divided by the sqrt of N. +// ```python +// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) // -// N is the size of the segment being reduced. +// # Select two rows, one segment. +// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0])) +// # => [[0 0 0 0]] // -// Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is -// misisng, the `output` tensor at that position will be zeroed. +// # Select two rows, two segment. +// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1])) +// # => [[ 1 2 3 4] +// # [-1 -2 -3 -4]] // -// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of -// segments. +// # Select all rows, two segments. +// tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1])) +// # => [[0 0 0 0] +// # [5 6 7 8]] +// +// # Which is equivalent to: +// tf.segment_sum(c, tf.constant([0, 0, 1])) +// ``` // // Arguments: // // indices: A 1-D tensor. Has same rank as `segment_ids`. // segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -// num_segments: Should equal the number of distinct segment IDs. // // Returns Has same shape as data, except for dimension 0 which // has size `k`, the number of segments. -func SparseSegmentSqrtNWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { +func SparseSegmentSum(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseSegmentSqrtNWithNumSegments", + Type: "SparseSegmentSum", Input: []tf.Input{ - data, indices, segment_ids, num_segments, + data, indices, segment_ids, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Compute the upper regularized incomplete Gamma function `Q(a, x)`. -// -// The upper regularized incomplete Gamma function is defined as: -// -// \\(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\\) -// -// where -// -// \\(Gamma(a, x) = int_{x}^{\infty} t^{a-1} exp(-t) dt\\) -// -// is the upper incomplete Gama function. -// -// Note, above `P(a, x)` (`Igamma`) is the lower regularized complete -// Gamma function. -func Igammac(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { +// Computes hyperbolic sine of x element-wise. +func Sinh(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Igammac", + Type: "Sinh", Input: []tf.Input{ - a, x, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ApproximateEqualAttr is an optional argument to ApproximateEqual. -type ApproximateEqualAttr func(optionalAttr) - -// ApproximateEqualTolerance sets the optional tolerance attribute to value. -// If not specified, defaults to 1e-05 -func ApproximateEqualTolerance(value float32) ApproximateEqualAttr { - return func(m optionalAttr) { - m["tolerance"] = value - } -} - -// Returns the truth value of abs(x-y) < tolerance element-wise. -func ApproximateEqual(scope *Scope, x tf.Output, y tf.Output, optional ...ApproximateEqualAttr) (z tf.Output) { +// Computes the minimum along segments of a tensor. +// +// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of +// segments. +// +// This operator is similar to the unsorted segment sum operator found +// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). +// Instead of computing the sum over segments, it computes the minimum such that: +// +// \\(output_i = \min_j data_j\\) where min is over `j` such +// that `segment_ids[j] == i`. +// +// If the minimum is empty for a given segment ID `i`, it outputs the largest +// possible value for the specific numeric type, +// `output[i] = numeric_limits::max()`. +// +// Arguments: +// +// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s +// first dimension. +// +// +// Returns Has same shape as data, except for dimension 0 which +// has size `num_segments`. +func UnsortedSegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ApproximateEqual", + Type: "UnsortedSegmentMin", Input: []tf.Input{ - x, y, + data, segment_ids, num_segments, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns x / y element-wise. -// -// *NOTE*: `Div` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Div(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Computes rectified linear 6: `min(max(features, 0), 6)`. +func Relu6(scope *Scope, features tf.Output) (activations tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Div", + Type: "Relu6", Input: []tf.Input{ - x, y, + features, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns x * y element-wise. +// Computes the sum along segments of a tensor. // -// *NOTE*: `Multiply` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Mul(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of +// segments. +// +// Computes a tensor such that +// \\(output[i] = sum_{j...} data[j...]\\) where the sum is over tuples `j...` such +// that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids` +// need not be sorted and need not cover all values in the full +// range of valid values. +// +// If the sum is empty for a given segment ID `i`, `output[i] = 0`. +// If the given segment ID `i` is negative, the value is dropped and will not be +// added to the sum of the segment. +// +// `num_segments` should equal the number of distinct segment IDs. +// +//
+// +//
+// +// Arguments: +// +// segment_ids: A tensor whose shape is a prefix of `data.shape`. +// +// +// Returns Has same shape as data, except for the first `segment_ids.rank` +// dimensions, which are replaced with a single dimension which has size +// `num_segments`. +func UnsortedSegmentSum(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Mul", + Type: "UnsortedSegmentSum", Input: []tf.Input{ - x, y, + data, segment_ids, num_segments, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// BiasAddAttr is an optional argument to BiasAdd. -type BiasAddAttr func(optionalAttr) +// ResourceStridedSliceAssignAttr is an optional argument to ResourceStridedSliceAssign. +type ResourceStridedSliceAssignAttr func(optionalAttr) -// BiasAddDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the bias tensor will be added to the last dimension -// of the value tensor. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// The tensor will be added to "in_channels", the third-to-the-last -// dimension. -// If not specified, defaults to "NHWC" -func BiasAddDataFormat(value string) BiasAddAttr { +// ResourceStridedSliceAssignBeginMask sets the optional begin_mask attribute to value. +// If not specified, defaults to 0 +func ResourceStridedSliceAssignBeginMask(value int64) ResourceStridedSliceAssignAttr { return func(m optionalAttr) { - m["data_format"] = value + m["begin_mask"] = value } } -// Adds `bias` to `value`. +// ResourceStridedSliceAssignEndMask sets the optional end_mask attribute to value. +// If not specified, defaults to 0 +func ResourceStridedSliceAssignEndMask(value int64) ResourceStridedSliceAssignAttr { + return func(m optionalAttr) { + m["end_mask"] = value + } +} + +// ResourceStridedSliceAssignEllipsisMask sets the optional ellipsis_mask attribute to value. +// If not specified, defaults to 0 +func ResourceStridedSliceAssignEllipsisMask(value int64) ResourceStridedSliceAssignAttr { + return func(m optionalAttr) { + m["ellipsis_mask"] = value + } +} + +// ResourceStridedSliceAssignNewAxisMask sets the optional new_axis_mask attribute to value. +// If not specified, defaults to 0 +func ResourceStridedSliceAssignNewAxisMask(value int64) ResourceStridedSliceAssignAttr { + return func(m optionalAttr) { + m["new_axis_mask"] = value + } +} + +// ResourceStridedSliceAssignShrinkAxisMask sets the optional shrink_axis_mask attribute to value. +// If not specified, defaults to 0 +func ResourceStridedSliceAssignShrinkAxisMask(value int64) ResourceStridedSliceAssignAttr { + return func(m optionalAttr) { + m["shrink_axis_mask"] = value + } +} + +// Assign `value` to the sliced l-value reference of `ref`. // -// This is a special case of `tf.add` where `bias` is restricted to be 1-D. -// Broadcasting is supported, so `value` may have any number of dimensions. +// The values of `value` are assigned to the positions in the variable +// `ref` that are selected by the slice parameters. The slice parameters +// `begin, `end`, `strides`, etc. work exactly as in `StridedSlice`. // -// Arguments: -// value: Any number of dimensions. -// bias: 1-D with size the last dimension of `value`. +// NOTE this op currently does not support broadcasting and so `value`'s +// shape must be exactly the shape produced by the slice of `ref`. // -// Returns Broadcasted sum of `value` and `bias`. -func BiasAdd(scope *Scope, value tf.Output, bias tf.Output, optional ...BiasAddAttr) (output tf.Output) { +// Returns the created operation. +func ResourceStridedSliceAssign(scope *Scope, ref tf.Output, begin tf.Output, end tf.Output, strides tf.Output, value tf.Output, optional ...ResourceStridedSliceAssignAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -3842,51 +3745,36 @@ func BiasAdd(scope *Scope, value tf.Output, bias tf.Output, optional ...BiasAddA a(attrs) } opspec := tf.OpSpec{ - Type: "BiasAdd", + Type: "ResourceStridedSliceAssign", Input: []tf.Input{ - value, bias, + ref, begin, end, strides, value, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// SparseReduceSumSparseAttr is an optional argument to SparseReduceSumSparse. -type SparseReduceSumSparseAttr func(optionalAttr) +// ArgMaxAttr is an optional argument to ArgMax. +type ArgMaxAttr func(optionalAttr) -// SparseReduceSumSparseKeepDims sets the optional keep_dims attribute to value. -// -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func SparseReduceSumSparseKeepDims(value bool) SparseReduceSumSparseAttr { +// ArgMaxOutputType sets the optional output_type attribute to value. +// If not specified, defaults to DT_INT64 +func ArgMaxOutputType(value tf.DataType) ArgMaxAttr { return func(m optionalAttr) { - m["keep_dims"] = value + m["output_type"] = value } } -// Computes the sum of elements across dimensions of a SparseTensor. +// Returns the index with the largest value across dimensions of a tensor. // -// This Op takes a SparseTensor and is the sparse counterpart to -// `tf.reduce_sum()`. In contrast to SparseReduceSum, this Op returns a -// SparseTensor. -// -// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained -// with length 1. -// -// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor -// with a single element is returned. Additionally, the axes can be negative, -// which are interpreted according to the indexing rules in Python. +// Note that in case of ties the identity of the return value is not guaranteed. // // Arguments: -// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. -// input_shape: 1-D. Shape of the input SparseTensor. -// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. -func SparseReduceSumSparse(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceSumSparseAttr) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { +// +// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`. +// Describes which dimension of the input Tensor to reduce across. For vectors, +// use dimension = 0. +func ArgMax(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMaxAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -3895,83 +3783,68 @@ func SparseReduceSumSparse(scope *Scope, input_indices tf.Output, input_values t a(attrs) } opspec := tf.OpSpec{ - Type: "SparseReduceSumSparse", + Type: "ArgMax", Input: []tf.Input{ - input_indices, input_values, input_shape, reduction_axes, + input, dimension, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Returns x + y element-wise. +// Returns which elements of x are finite. // -// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func AddV2(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// @compatibility(numpy) +// Equivalent to np.isfinite +// @end_compatibility +func IsFinite(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "AddV2", + Type: "IsFinite", Input: []tf.Input{ - x, y, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns x + y element-wise. +// MatMulAttr is an optional argument to MatMul. +type MatMulAttr func(optionalAttr) + +// MatMulTransposeA sets the optional transpose_a attribute to value. // -// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Add", - Input: []tf.Input{ - x, y, - }, +// value: If true, "a" is transposed before multiplication. +// If not specified, defaults to false +func MatMulTransposeA(value bool) MatMulAttr { + return func(m optionalAttr) { + m["transpose_a"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// NthElementAttr is an optional argument to NthElement. -type NthElementAttr func(optionalAttr) - -// NthElementReverse sets the optional reverse attribute to value. +// MatMulTransposeB sets the optional transpose_b attribute to value. // -// value: When set to True, find the nth-largest value in the vector and vice -// versa. +// value: If true, "b" is transposed before multiplication. // If not specified, defaults to false -func NthElementReverse(value bool) NthElementAttr { +func MatMulTransposeB(value bool) MatMulAttr { return func(m optionalAttr) { - m["reverse"] = value + m["transpose_b"] = value } } -// Finds values of the `n`-th order statistic for the last dimension. -// -// If the input is a vector (rank-1), finds the entries which is the nth-smallest -// value in the vector and outputs their values as scalar tensor. -// -// For matrices (resp. higher rank input), computes the entries which is the -// nth-smallest value in each row (resp. vector along the last dimension). Thus, -// -// values.shape = input.shape[:-1] +// Multiply the matrix "a" by the matrix "b". // -// Arguments: -// input: 1-D or higher with last dimension at least `n+1`. -// n: 0-D. Position of sorted vector to select along the last dimension (along -// each row for matrices). Valid range of n is `[0, input.shape[:-1])` +// The inputs must be two-dimensional matrices and the inner dimension of +// "a" (after being transposed if transpose_a is true) must match the +// outer dimension of "b" (after being transposed if transposed_b is +// true). // -// Returns The `n`-th order statistic along each last dimensional slice. -func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthElementAttr) (values tf.Output) { +// *Note*: The default kernel implementation for MatMul on GPUs uses +// cublas. +func MatMul(scope *Scope, a tf.Output, b tf.Output, optional ...MatMulAttr) (product tf.Output) { if scope.Err() != nil { return } @@ -3980,9 +3853,9 @@ func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthEleme a(attrs) } opspec := tf.OpSpec{ - Type: "NthElement", + Type: "MatMul", Input: []tf.Input{ - input, n, + a, b, }, Attrs: attrs, } @@ -3990,294 +3863,360 @@ func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthEleme return op.Output(0) } -// Computes the maximum along segments of a tensor. +// Selects elements from `x` or `y`, depending on `condition`. // -// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of -// segments. +// The `x`, and `y` tensors must all have the same shape, and the +// output will also have that shape. // -// This operator is similar to the unsorted segment sum operator found -// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). -// Instead of computing the sum over segments, it computes the maximum such that: +// The `condition` tensor must be a scalar if `x` and `y` are scalars. +// If `x` and `y` are vectors or higher rank, then `condition` must be either a +// scalar, a vector with size matching the first dimension of `x`, or must have +// the same shape as `x`. // -// \\(output_i = \max_j data_j\\) where max is over `j` such -// that `segment_ids[j] == i`. +// The `condition` tensor acts as a mask that chooses, based on the value at each +// element, whether the corresponding element / row in the output should be +// taken from `x` (if true) or `y` (if false). // -// If the maximum is empty for a given segment ID `i`, it outputs the smallest -// possible value for the specific numeric type, -// `output[i] = numeric_limits::lowest()`. +// If `condition` is a vector and `x` and `y` are higher rank matrices, then +// it chooses which row (outer dimension) to copy from `x` and `y`. +// If `condition` has the same shape as `x` and `y`, then it chooses which +// element to copy from `x` and `y`. // -//
-// -//
+// For example: // -// Arguments: +// ```python +// # 'condition' tensor is [[True, False] +// # [False, True]] +// # 't' is [[1, 2], +// # [3, 4]] +// # 'e' is [[5, 6], +// # [7, 8]] +// select(condition, t, e) # => [[1, 6], [7, 4]] // -// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s -// first dimension. // +// # 'condition' tensor is [True, False] +// # 't' is [[1, 2], +// # [3, 4]] +// # 'e' is [[5, 6], +// # [7, 8]] +// select(condition, t, e) ==> [[1, 2], +// [7, 8]] +// +// ``` +// +// Arguments: // -// Returns Has same shape as data, except for dimension 0 which -// has size `num_segments`. -func UnsortedSegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { +// x: = A `Tensor` which may have the same shape as `condition`. +// If `condition` is rank 1, `x` may have higher rank, +// but its first dimension must match the size of `condition`. +// y: = A `Tensor` with the same type and shape as `x`. +// +// Returns = A `Tensor` with the same type and shape as `x` and `y`. +func Select(scope *Scope, condition tf.Output, x tf.Output, y tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "UnsortedSegmentMax", + Type: "Select", Input: []tf.Input{ - data, segment_ids, num_segments, + condition, x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes exponential of x element-wise. \\(y = e^x\\). -func Exp(scope *Scope, x tf.Output) (y tf.Output) { +// Returns the truth value of x OR y element-wise. +// +// *NOTE*: `LogicalOr` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func LogicalOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Exp", + Type: "LogicalOr", Input: []tf.Input{ - x, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns an element-wise indication of the sign of a number. +// Compute the regularized incomplete beta integral \\(I_x(a, b)\\). // -// `y = sign(x) = -1` if `x < 0`; 0 if `x == 0`; 1 if `x > 0`. +// The regularized incomplete beta integral is defined as: // -// For complex numbers, `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`. -func Sign(scope *Scope, x tf.Output) (y tf.Output) { +// +// \\(I_x(a, b) = \frac{B(x; a, b)}{B(a, b)}\\) +// +// where +// +// +// \\(B(x; a, b) = \int_0^x t^{a-1} (1 - t)^{b-1} dt\\) +// +// +// is the incomplete beta function and \\(B(a, b)\\) is the *complete* +// beta function. +func Betainc(scope *Scope, a tf.Output, b tf.Output, x tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Sign", + Type: "Betainc", Input: []tf.Input{ - x, + a, b, x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// ArgMinAttr is an optional argument to ArgMin. -type ArgMinAttr func(optionalAttr) - -// ArgMinOutputType sets the optional output_type attribute to value. -// If not specified, defaults to DT_INT64 -func ArgMinOutputType(value tf.DataType) ArgMinAttr { - return func(m optionalAttr) { - m["output_type"] = value +// Return a tensor with the same shape and contents as the input tensor or value. +func Identity(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Identity", + Input: []tf.Input{ + input, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Returns the index with the smallest value across dimensions of a tensor. -// -// Note that in case of ties the identity of the return value is not guaranteed. -// -// Arguments: +// Computes arctangent of `y/x` element-wise, respecting signs of the arguments. // -// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`. -// Describes which dimension of the input Tensor to reduce across. For vectors, -// use dimension = 0. -func ArgMin(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMinAttr) (output tf.Output) { +// This is the angle \( \theta \in [-\pi, \pi] \) such that +// \[ x = r \cos(\theta) \] +// and +// \[ y = r \sin(\theta) \] +// where \(r = \sqrt(x^2 + y^2) \). +func Atan2(scope *Scope, y tf.Output, x tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ArgMin", + Type: "Atan2", Input: []tf.Input{ - input, dimension, + y, x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Convert the quantized 'input' tensor into a lower-precision 'output', using the -// -// output range specified with 'requested_output_min' and 'requested_output_max'. -// -// [input_min, input_max] are scalar floats that specify the range for the float -// interpretation of the 'input' data. For example, if input_min is -1.0f and -// input_max is 1.0f, and we are dealing with quint16 quantized data, then a 0 -// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f. +// Creates a dataset that passes a sliding window over `input_dataset`. // // Arguments: // -// input_min: The float value that the minimum quantized input value represents. -// input_max: The float value that the maximum quantized input value represents. -// requested_output_min: The float value that the minimum quantized output value represents. -// requested_output_max: The float value that the maximum quantized output value represents. -// out_type: The type of the output. Should be a lower bit depth than Tinput. +// window_size: A scalar representing the number of elements in the +// sliding window. +// window_shift: A scalar representing the steps moving the sliding window +// forward in one iteration. It must be positive. +// window_stride: A scalar representing the stride of the input elements of the sliding window. +// It must be positive. // -// Returns The requested_output_min value is copied into this output.The requested_output_max value is copied into this output. -func Requantize(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, requested_output_min tf.Output, requested_output_max tf.Output, out_type tf.DataType) (output tf.Output, output_min tf.Output, output_max tf.Output) { +// +func SlideDataset(scope *Scope, input_dataset tf.Output, window_size tf.Output, window_shift tf.Output, window_stride tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"out_type": out_type} + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Requantize", + Type: "SlideDataset", Input: []tf.Input{ - input, input_min, input_max, requested_output_min, requested_output_max, + input_dataset, window_size, window_shift, window_stride, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Computes the determinant of one or more square matrices. +// FusedBatchNormAttr is an optional argument to FusedBatchNorm. +type FusedBatchNormAttr func(optionalAttr) + +// FusedBatchNormEpsilon sets the optional epsilon attribute to value. // -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. The output is a tensor containing the determinants -// for all input submatrices `[..., :, :]`. +// value: A small float number added to the variance of x. +// If not specified, defaults to 0.0001 +func FusedBatchNormEpsilon(value float32) FusedBatchNormAttr { + return func(m optionalAttr) { + m["epsilon"] = value + } +} + +// FusedBatchNormDataFormat sets the optional data_format attribute to value. +// +// value: The data format for x and y. Either "NHWC" (default) or "NCHW". +// If not specified, defaults to "NHWC" +func FusedBatchNormDataFormat(value string) FusedBatchNormAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// FusedBatchNormIsTraining sets the optional is_training attribute to value. +// +// value: A bool value to indicate the operation is for training (default) +// or inference. +// If not specified, defaults to true +func FusedBatchNormIsTraining(value bool) FusedBatchNormAttr { + return func(m optionalAttr) { + m["is_training"] = value + } +} + +// Batch normalization. +// +// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +// The size of 1D Tensors matches the dimension C of the 4D Tensors. // // Arguments: -// input: Shape is `[..., M, M]`. +// x: A 4D Tensor for input data. +// scale: A 1D Tensor for scaling factor, to scale the normalized x. +// offset: A 1D Tensor for offset, to shift to the normalized x. +// mean: A 1D Tensor for population mean. Used for inference only; +// must be empty for training. +// variance: A 1D Tensor for population variance. Used for inference only; +// must be empty for training. // -// Returns Shape is `[...]`. -func MatrixDeterminant(scope *Scope, input tf.Output) (output tf.Output) { +// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow +// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by +// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused +// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance +// in the cuDNN case), to be reused in the gradient computation. +func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormAttr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "MatrixDeterminant", + Type: "FusedBatchNorm", Input: []tf.Input{ - input, + x, scale, offset, mean, variance, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// Computes sin of x element-wise. -func Sin(scope *Scope, x tf.Output) (y tf.Output) { +// ApproximateEqualAttr is an optional argument to ApproximateEqual. +type ApproximateEqualAttr func(optionalAttr) + +// ApproximateEqualTolerance sets the optional tolerance attribute to value. +// If not specified, defaults to 1e-05 +func ApproximateEqualTolerance(value float32) ApproximateEqualAttr { + return func(m optionalAttr) { + m["tolerance"] = value + } +} + +// Returns the truth value of abs(x-y) < tolerance element-wise. +func ApproximateEqual(scope *Scope, x tf.Output, y tf.Output, optional ...ApproximateEqualAttr) (z tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Sin", + Type: "ApproximateEqual", Input: []tf.Input{ - x, + x, y, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the complementary error function of `x` element-wise. -func Erfc(scope *Scope, x tf.Output) (y tf.Output) { +// Returns x / y element-wise. +// +// *NOTE*: `Div` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Div(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Erfc", + Type: "Div", Input: []tf.Input{ - x, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes Psi, the derivative of Lgamma (the log of the absolute value of +// Returns x * y element-wise. // -// `Gamma(x)`), element-wise. -func Digamma(scope *Scope, x tf.Output) (y tf.Output) { +// *NOTE*: `Multiply` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Mul(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Digamma", + Type: "Mul", Input: []tf.Input{ - x, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Conv2DBackpropFilterAttr is an optional argument to Conv2DBackpropFilter. -type Conv2DBackpropFilterAttr func(optionalAttr) - -// Conv2DBackpropFilterUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. -// If not specified, defaults to true -func Conv2DBackpropFilterUseCudnnOnGpu(value bool) Conv2DBackpropFilterAttr { - return func(m optionalAttr) { - m["use_cudnn_on_gpu"] = value - } -} +// BiasAddAttr is an optional argument to BiasAdd. +type BiasAddAttr func(optionalAttr) -// Conv2DBackpropFilterDataFormat sets the optional data_format attribute to value. +// BiasAddDataFormat sets the optional data_format attribute to value. // // value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. +// default format "NHWC", the bias tensor will be added to the last dimension +// of the value tensor. // Alternatively, the format could be "NCHW", the data storage order of: // [batch, in_channels, in_height, in_width]. +// The tensor will be added to "in_channels", the third-to-the-last +// dimension. // If not specified, defaults to "NHWC" -func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { +func BiasAddDataFormat(value string) BiasAddAttr { return func(m optionalAttr) { m["data_format"] = value } } -// Conv2DBackpropFilterDilations sets the optional dilations attribute to value. +// Adds `bias` to `value`. // -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each filter -// element on that dimension. The dimension order is determined by the value of -// `data_format`, see above for details. Dilations in the batch and depth -// dimensions must be 1. -// If not specified, defaults to -func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes the gradients of convolution with respect to the filter. +// This is a special case of `tf.add` where `bias` is restricted to be 1-D. +// Broadcasting is supported, so `value` may have any number of dimensions. // // Arguments: -// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. -// filter_sizes: An integer vector representing the tensor shape of `filter`, -// where `filter` is a 4-D -// `[filter_height, filter_width, in_channels, out_channels]` tensor. -// out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. -// Gradients w.r.t. the output of the convolution. -// strides: The stride of the sliding window for each dimension of the input -// of the convolution. Must be in the same order as the dimension specified with -// format. -// padding: The type of padding algorithm to use. +// value: Any number of dimensions. +// bias: 1-D with size the last dimension of `value`. // -// Returns 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. -// the `filter` input of the convolution. -func Conv2DBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv2DBackpropFilterAttr) (output tf.Output) { +// Returns Broadcasted sum of `value` and `bias`. +func BiasAdd(scope *Scope, value tf.Output, bias tf.Output, optional ...BiasAddAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Conv2DBackpropFilter", + Type: "BiasAdd", Input: []tf.Input{ - input, filter_sizes, out_backprop, + value, bias, }, Attrs: attrs, } @@ -4285,104 +4224,175 @@ func Conv2DBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, return op.Output(0) } -// Returns the number of work units this Reader has finished processing. +// SparseReduceSumSparseAttr is an optional argument to SparseReduceSumSparse. +type SparseReduceSumSparseAttr func(optionalAttr) + +// SparseReduceSumSparseKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func SparseReduceSumSparseKeepDims(value bool) SparseReduceSumSparseAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the sum of elements across dimensions of a SparseTensor. +// +// This Op takes a SparseTensor and is the sparse counterpart to +// `tf.reduce_sum()`. In contrast to SparseReduceSum, this Op returns a +// SparseTensor. +// +// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained +// with length 1. +// +// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor +// with a single element is returned. Additionally, the axes can be negative, +// which are interpreted according to the indexing rules in Python. // // Arguments: -// reader_handle: Handle to a Reader. -func ReaderNumWorkUnitsCompletedV2(scope *Scope, reader_handle tf.Output) (units_completed tf.Output) { +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. +// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. +func SparseReduceSumSparse(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceSumSparseAttr) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ReaderNumWorkUnitsCompletedV2", + Type: "SparseReduceSumSparse", Input: []tf.Input{ - reader_handle, + input_indices, input_values, input_shape, reduction_axes, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Computes the log of the absolute value of `Gamma(x)` element-wise. -func Lgamma(scope *Scope, x tf.Output) (y tf.Output) { +// Returns x + y element-wise. +// +// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func AddV2(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Lgamma", + Type: "AddV2", Input: []tf.Input{ - x, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the reverse mode backpropagated gradient of the Cholesky algorithm. +// NthElementAttr is an optional argument to NthElement. +type NthElementAttr func(optionalAttr) + +// NthElementReverse sets the optional reverse attribute to value. // -// For an explanation see "Differentiation of the Cholesky algorithm" by -// Iain Murray http://arxiv.org/abs/1602.07527. +// value: When set to True, find the nth-largest value in the vector and vice +// versa. +// If not specified, defaults to false +func NthElementReverse(value bool) NthElementAttr { + return func(m optionalAttr) { + m["reverse"] = value + } +} + +// Finds values of the `n`-th order statistic for the last dimension. +// +// If the input is a vector (rank-1), finds the entries which is the nth-smallest +// value in the vector and outputs their values as scalar tensor. +// +// For matrices (resp. higher rank input), computes the entries which is the +// nth-smallest value in each row (resp. vector along the last dimension). Thus, +// +// values.shape = input.shape[:-1] // // Arguments: -// l: Output of batch Cholesky algorithm l = cholesky(A). Shape is `[..., M, M]`. -// Algorithm depends only on lower triangular part of the innermost matrices of -// this tensor. -// grad: df/dl where f is some scalar function. Shape is `[..., M, M]`. -// Algorithm depends only on lower triangular part of the innermost matrices of -// this tensor. +// input: 1-D or higher with last dimension at least `n+1`. +// n: 0-D. Position of sorted vector to select along the last dimension (along +// each row for matrices). Valid range of n is `[0, input.shape[:-1])` // -// Returns Symmetrized version of df/dA . Shape is `[..., M, M]` -func CholeskyGrad(scope *Scope, l tf.Output, grad tf.Output) (output tf.Output) { +// Returns The `n`-th order statistic along each last dimensional slice. +func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthElementAttr) (values tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "CholeskyGrad", + Type: "NthElement", Input: []tf.Input{ - l, grad, + input, n, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the mean along sparse segments of a tensor. -// -// Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is -// misisng, the `output` tensor at that position will be zeroed. +// Computes the maximum along segments of a tensor. // // Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of // segments. // +// This operator is similar to the unsorted segment sum operator found +// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). +// Instead of computing the sum over segments, it computes the maximum such that: +// +// \\(output_i = \max_j data_j\\) where max is over `j` such +// that `segment_ids[j] == i`. +// +// If the maximum is empty for a given segment ID `i`, it outputs the smallest +// possible value for the specific numeric type, +// `output[i] = numeric_limits::lowest()`. +// +//
+// +//
+// // Arguments: // -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -// num_segments: Should equal the number of distinct segment IDs. +// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s +// first dimension. // -// Returns Has same shape as data, except for dimension 0 which has size -// `num_segments`. -func SparseSegmentMeanWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { +// +// Returns Has same shape as data, except for dimension 0 which +// has size `num_segments`. +func UnsortedSegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseSegmentMeanWithNumSegments", + Type: "UnsortedSegmentMax", Input: []tf.Input{ - data, indices, segment_ids, num_segments, + data, segment_ids, num_segments, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes hyperbolic cosine of x element-wise. -func Cosh(scope *Scope, x tf.Output) (y tf.Output) { +// Computes exponential of x element-wise. \\(y = e^x\\). +func Exp(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Cosh", + Type: "Exp", Input: []tf.Input{ x, }, @@ -4391,116 +4401,129 @@ func Cosh(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Creates a dataset that emits each dim-0 slice of `components` once. -func TensorSliceDataset(scope *Scope, components []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { +// Returns an element-wise indication of the sign of a number. +// +// `y = sign(x) = -1` if `x < 0`; 0 if `x == 0`; 1 if `x > 0`. +// +// For complex numbers, `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`. +func Sign(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "TensorSliceDataset", + Type: "Sign", Input: []tf.Input{ - tf.OutputList(components), + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes natural logarithm of (1 + x) element-wise. +// ArgMinAttr is an optional argument to ArgMin. +type ArgMinAttr func(optionalAttr) + +// ArgMinOutputType sets the optional output_type attribute to value. +// If not specified, defaults to DT_INT64 +func ArgMinOutputType(value tf.DataType) ArgMinAttr { + return func(m optionalAttr) { + m["output_type"] = value + } +} + +// Returns the index with the smallest value across dimensions of a tensor. // -// I.e., \\(y = \log_e (1 + x)\\). -func Log1p(scope *Scope, x tf.Output) (y tf.Output) { +// Note that in case of ties the identity of the return value is not guaranteed. +// +// Arguments: +// +// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`. +// Describes which dimension of the input Tensor to reduce across. For vectors, +// use dimension = 0. +func ArgMin(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMinAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Log1p", + Type: "ArgMin", Input: []tf.Input{ - x, + input, dimension, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes rectified linear 6 gradients for a Relu6 operation. +// Convert the quantized 'input' tensor into a lower-precision 'output', using the +// +// output range specified with 'requested_output_min' and 'requested_output_max'. +// +// [input_min, input_max] are scalar floats that specify the range for the float +// interpretation of the 'input' data. For example, if input_min is -1.0f and +// input_max is 1.0f, and we are dealing with quint16 quantized data, then a 0 +// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f. // // Arguments: -// gradients: The backpropagated gradients to the corresponding Relu6 operation. -// features: The features passed as input to the corresponding Relu6 operation, or -// its output; using either one produces the same result. // -// Returns The gradients: -// `gradients * (features > 0) * (features < 6)`. -func Relu6Grad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { +// input_min: The float value that the minimum quantized input value represents. +// input_max: The float value that the maximum quantized input value represents. +// requested_output_min: The float value that the minimum quantized output value represents. +// requested_output_max: The float value that the maximum quantized output value represents. +// out_type: The type of the output. Should be a lower bit depth than Tinput. +// +// Returns The requested_output_min value is copied into this output.The requested_output_max value is copied into this output. +func Requantize(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, requested_output_min tf.Output, requested_output_max tf.Output, out_type tf.DataType) (output tf.Output, output_min tf.Output, output_max tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"out_type": out_type} opspec := tf.OpSpec{ - Type: "Relu6Grad", + Type: "Requantize", Input: []tf.Input{ - gradients, features, + input, input_min, input_max, requested_output_min, requested_output_max, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResizeBicubicAttr is an optional argument to ResizeBicubic. -type ResizeBicubicAttr func(optionalAttr) - -// ResizeBicubicAlignCorners sets the optional align_corners attribute to value. -// -// value: If true, the centers of the 4 corner pixels of the input and output tensors are -// aligned, preserving the values at the corner pixels. Defaults to false. -// If not specified, defaults to false -func ResizeBicubicAlignCorners(value bool) ResizeBicubicAttr { - return func(m optionalAttr) { - m["align_corners"] = value - } + return op.Output(0), op.Output(1), op.Output(2) } -// Resize `images` to `size` using bicubic interpolation. +// Computes the determinant of one or more square matrices. // -// Input images can be of different types but output images are always float. +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. The output is a tensor containing the determinants +// for all input submatrices `[..., :, :]`. // // Arguments: -// images: 4-D with shape `[batch, height, width, channels]`. -// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The -// new size for the images. +// input: Shape is `[..., M, M]`. // -// Returns 4-D with shape -// `[batch, new_height, new_width, channels]`. -func ResizeBicubic(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBicubicAttr) (resized_images tf.Output) { +// Returns Shape is `[...]`. +func MatrixDeterminant(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ResizeBicubic", + Type: "MatrixDeterminant", Input: []tf.Input{ - images, size, + input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes natural logarithm of x element-wise. -// -// I.e., \\(y = \log_e x\\). -func Log(scope *Scope, x tf.Output) (y tf.Output) { +// Computes sin of x element-wise. +func Sin(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Log", + Type: "Sin", Input: []tf.Input{ x, }, @@ -4509,16 +4532,13 @@ func Log(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Rounds the values of a tensor to the nearest integer, element-wise. -// -// Rounds half to even. Also known as bankers rounding. If you want to round -// according to the current system rounding mode use std::cint. -func Round(scope *Scope, x tf.Output) (y tf.Output) { +// Computes the complementary error function of `x` element-wise. +func Erfc(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Round", + Type: "Erfc", Input: []tf.Input{ x, }, @@ -4527,344 +4547,314 @@ func Round(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// RecordInputAttr is an optional argument to RecordInput. -type RecordInputAttr func(optionalAttr) - -// RecordInputFileRandomSeed sets the optional file_random_seed attribute to value. +// Computes Psi, the derivative of Lgamma (the log of the absolute value of // -// value: Random seeds used to produce randomized records. -// If not specified, defaults to 301 -func RecordInputFileRandomSeed(value int64) RecordInputAttr { - return func(m optionalAttr) { - m["file_random_seed"] = value +// `Gamma(x)`), element-wise. +func Digamma(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Digamma", + Input: []tf.Input{ + x, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// RecordInputFileShuffleShiftRatio sets the optional file_shuffle_shift_ratio attribute to value. -// -// value: Shifts the list of files after the list is randomly -// shuffled. -// If not specified, defaults to 0 -func RecordInputFileShuffleShiftRatio(value float32) RecordInputAttr { +// Conv2DBackpropFilterAttr is an optional argument to Conv2DBackpropFilter. +type Conv2DBackpropFilterAttr func(optionalAttr) + +// Conv2DBackpropFilterUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. +// If not specified, defaults to true +func Conv2DBackpropFilterUseCudnnOnGpu(value bool) Conv2DBackpropFilterAttr { return func(m optionalAttr) { - m["file_shuffle_shift_ratio"] = value + m["use_cudnn_on_gpu"] = value } } -// RecordInputFileBufferSize sets the optional file_buffer_size attribute to value. +// Conv2DBackpropFilterDataFormat sets the optional data_format attribute to value. // -// value: The randomization shuffling buffer. -// If not specified, defaults to 10000 -func RecordInputFileBufferSize(value int64) RecordInputAttr { - return func(m optionalAttr) { - m["file_buffer_size"] = value - } -} - -// RecordInputFileParallelism sets the optional file_parallelism attribute to value. -// -// value: How many sstables are opened and concurrently iterated over. -// If not specified, defaults to 16 -func RecordInputFileParallelism(value int64) RecordInputAttr { - return func(m optionalAttr) { - m["file_parallelism"] = value - } -} - -// RecordInputBatchSize sets the optional batch_size attribute to value. -// -// value: The batch size. -// If not specified, defaults to 32 -func RecordInputBatchSize(value int64) RecordInputAttr { +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { return func(m optionalAttr) { - m["batch_size"] = value + m["data_format"] = value } } -// RecordInputCompressionType sets the optional compression_type attribute to value. +// Conv2DBackpropFilterDilations sets the optional dilations attribute to value. // -// value: The type of compression for the file. Currently ZLIB and -// GZIP are supported. Defaults to none. -// If not specified, defaults to "" -func RecordInputCompressionType(value string) RecordInputAttr { +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to +func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { return func(m optionalAttr) { - m["compression_type"] = value + m["dilations"] = value } } -// Emits randomized records. +// Computes the gradients of convolution with respect to the filter. // // Arguments: -// file_pattern: Glob pattern for the data files. +// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. +// filter_sizes: An integer vector representing the tensor shape of `filter`, +// where `filter` is a 4-D +// `[filter_height, filter_width, in_channels, out_channels]` tensor. +// out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. +// Gradients w.r.t. the output of the convolution. +// strides: The stride of the sliding window for each dimension of the input +// of the convolution. Must be in the same order as the dimension specified with +// format. +// padding: The type of padding algorithm to use. // -// Returns A tensor of shape [batch_size]. -func RecordInput(scope *Scope, file_pattern string, optional ...RecordInputAttr) (records tf.Output) { +// Returns 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. +// the `filter` input of the convolution. +func Conv2DBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv2DBackpropFilterAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"file_pattern": file_pattern} + attrs := map[string]interface{}{"strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RecordInput", - + Type: "Conv2DBackpropFilter", + Input: []tf.Input{ + input, filter_sizes, out_backprop, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes reciprocal of square root of x element-wise. +// Returns the number of work units this Reader has finished processing. // -// I.e., \\(y = 1 / \sqrt{x}\\). -func Rsqrt(scope *Scope, x tf.Output) (y tf.Output) { +// Arguments: +// reader_handle: Handle to a Reader. +func ReaderNumWorkUnitsCompletedV2(scope *Scope, reader_handle tf.Output) (units_completed tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Rsqrt", + Type: "ReaderNumWorkUnitsCompletedV2", Input: []tf.Input{ - x, + reader_handle, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// MatrixInverseAttr is an optional argument to MatrixInverse. -type MatrixInverseAttr func(optionalAttr) - -// MatrixInverseAdjoint sets the optional adjoint attribute to value. -// If not specified, defaults to false -func MatrixInverseAdjoint(value bool) MatrixInverseAttr { - return func(m optionalAttr) { - m["adjoint"] = value +// Computes the log of the absolute value of `Gamma(x)` element-wise. +func Lgamma(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Lgamma", + Input: []tf.Input{ + x, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Computes the inverse of one or more square invertible matrices or their -// -// adjoints (conjugate transposes). -// -// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions -// form square matrices. The output is a tensor of the same shape as the input -// containing the inverse for all input submatrices `[..., :, :]`. -// -// The op uses LU decomposition with partial pivoting to compute the inverses. +// Computes the reverse mode backpropagated gradient of the Cholesky algorithm. // -// If a matrix is not invertible there is no guarantee what the op does. It -// may detect the condition and raise an exception or it may simply return a -// garbage result. +// For an explanation see "Differentiation of the Cholesky algorithm" by +// Iain Murray http://arxiv.org/abs/1602.07527. // // Arguments: -// input: Shape is `[..., M, M]`. -// -// Returns Shape is `[..., M, M]`. +// l: Output of batch Cholesky algorithm l = cholesky(A). Shape is `[..., M, M]`. +// Algorithm depends only on lower triangular part of the innermost matrices of +// this tensor. +// grad: df/dl where f is some scalar function. Shape is `[..., M, M]`. +// Algorithm depends only on lower triangular part of the innermost matrices of +// this tensor. // -// @compatibility(numpy) -// Equivalent to np.linalg.inv -// @end_compatibility -func MatrixInverse(scope *Scope, input tf.Output, optional ...MatrixInverseAttr) (output tf.Output) { +// Returns Symmetrized version of df/dA . Shape is `[..., M, M]` +func CholeskyGrad(scope *Scope, l tf.Output, grad tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "MatrixInverse", + Type: "CholeskyGrad", Input: []tf.Input{ - input, + l, grad, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes square of x element-wise. +// Computes the mean along sparse segments of a tensor. // -// I.e., \\(y = x * x = x^2\\). -func Square(scope *Scope, x tf.Output) (y tf.Output) { +// Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is +// misisng, the `output` tensor at that position will be zeroed. +// +// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of +// segments. +// +// Arguments: +// +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// num_segments: Should equal the number of distinct segment IDs. +// +// Returns Has same shape as data, except for dimension 0 which has size +// `num_segments`. +func SparseSegmentMeanWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Square", + Type: "SparseSegmentMeanWithNumSegments", Input: []tf.Input{ - x, + data, indices, segment_ids, num_segments, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise. -// -// See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) -// ](http://arxiv.org/abs/1511.07289) -func Elu(scope *Scope, features tf.Output) (activations tf.Output) { +// Computes hyperbolic cosine of x element-wise. +func Cosh(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Elu", + Type: "Cosh", Input: []tf.Input{ - features, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the reciprocal of x element-wise. -// -// I.e., \\(y = 1 / x\\). -func Reciprocal(scope *Scope, x tf.Output) (y tf.Output) { +// Creates a dataset that emits each dim-0 slice of `components` once. +func TensorSliceDataset(scope *Scope, components []tf.Output, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Reciprocal", + Type: "TensorSliceDataset", Input: []tf.Input{ - x, + tf.OutputList(components), }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns a batched matrix tensor with new batched diagonal values. -// -// Given `input` and `diagonal`, this operation returns a tensor with the -// same shape and values as `input`, except for the main diagonal of the -// innermost matrices. These will be overwritten by the values in `diagonal`. -// -// The output is computed as follows: -// -// Assume `input` has `k+1` dimensions `[I, J, K, ..., M, N]` and `diagonal` has -// `k` dimensions `[I, J, K, ..., min(M, N)]`. Then the output is a -// tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where: -// -// * `output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]` for `m == n`. -// * `output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n]` for `m != n`. -// -// Arguments: -// input: Rank `k+1`, where `k >= 1`. -// diagonal: Rank `k`, where `k >= 1`. +// Computes natural logarithm of (1 + x) element-wise. // -// Returns Rank `k+1`, with `output.shape = input.shape`. -func MatrixSetDiag(scope *Scope, input tf.Output, diagonal tf.Output) (output tf.Output) { +// I.e., \\(y = \log_e (1 + x)\\). +func Log1p(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "MatrixSetDiag", + Type: "Log1p", Input: []tf.Input{ - input, diagonal, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the element-wise max of two SparseTensors. -// -// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. +// Computes rectified linear 6 gradients for a Relu6 operation. // // Arguments: -// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, in the canonical lexicographic ordering. -// a_values: 1-D. `N` non-empty values corresponding to `a_indices`. -// a_shape: 1-D. Shape of the input SparseTensor. -// b_indices: counterpart to `a_indices` for the other operand. -// b_values: counterpart to `a_values` for the other operand; must be of the same dtype. -// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal. +// gradients: The backpropagated gradients to the corresponding Relu6 operation. +// features: The features passed as input to the corresponding Relu6 operation, or +// its output; using either one produces the same result. // -// Returns 2-D. The indices of the output SparseTensor.1-D. The values of the output SparseTensor. -func SparseSparseMaximum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { +// Returns The gradients: +// `gradients * (features > 0) * (features < 6)`. +func Relu6Grad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseSparseMaximum", + Type: "Relu6Grad", Input: []tf.Input{ - a_indices, a_values, a_shape, b_indices, b_values, b_shape, + gradients, features, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// OrderedMapClearAttr is an optional argument to OrderedMapClear. -type OrderedMapClearAttr func(optionalAttr) +// ResizeBicubicAttr is an optional argument to ResizeBicubic. +type ResizeBicubicAttr func(optionalAttr) -// OrderedMapClearCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// ResizeBicubicAlignCorners sets the optional align_corners attribute to value. // -// REQUIRES: value >= 0 -func OrderedMapClearCapacity(value int64) OrderedMapClearAttr { +// value: If true, the centers of the 4 corner pixels of the input and output tensors are +// aligned, preserving the values at the corner pixels. Defaults to false. +// If not specified, defaults to false +func ResizeBicubicAlignCorners(value bool) ResizeBicubicAttr { return func(m optionalAttr) { - m["capacity"] = value + m["align_corners"] = value } } -// OrderedMapClearMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// Resize `images` to `size` using bicubic interpolation. // -// REQUIRES: value >= 0 -func OrderedMapClearMemoryLimit(value int64) OrderedMapClearAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// OrderedMapClearContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func OrderedMapClearContainer(value string) OrderedMapClearAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// OrderedMapClearSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func OrderedMapClearSharedName(value string) OrderedMapClearAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op removes all elements in the underlying container. +// Input images can be of different types but output images are always float. // -// Returns the created operation. -func OrderedMapClear(scope *Scope, dtypes []tf.DataType, optional ...OrderedMapClearAttr) (o *tf.Operation) { +// Arguments: +// images: 4-D with shape `[batch, height, width, channels]`. +// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The +// new size for the images. +// +// Returns 4-D with shape +// `[batch, new_height, new_width, channels]`. +func ResizeBicubic(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeBicubicAttr) (resized_images tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "OrderedMapClear", - + Type: "ResizeBicubic", + Input: []tf.Input{ + images, size, + }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Computes the reciprocal of x element-wise. +// Computes natural logarithm of x element-wise. // -// I.e., \\(y = 1 / x\\). -func Inv(scope *Scope, x tf.Output) (y tf.Output) { +// I.e., \\(y = \log_e x\\). +func Log(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Inv", + Type: "Log", Input: []tf.Input{ x, }, @@ -4873,135 +4863,190 @@ func Inv(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// ComplexAbsAttr is an optional argument to ComplexAbs. -type ComplexAbsAttr func(optionalAttr) - -// ComplexAbsTout sets the optional Tout attribute to value. -// If not specified, defaults to DT_FLOAT -func ComplexAbsTout(value tf.DataType) ComplexAbsAttr { - return func(m optionalAttr) { - m["Tout"] = value - } -} - -// Computes the complex absolute value of a tensor. +// Rounds the values of a tensor to the nearest integer, element-wise. // -// Given a tensor `x` of complex numbers, this operation returns a tensor of type -// `float` or `double` that is the absolute value of each element in `x`. All -// elements in `x` must be complex numbers of the form \\(a + bj\\). The absolute -// value is computed as \\( \sqrt{a^2 + b^2}\\). -func ComplexAbs(scope *Scope, x tf.Output, optional ...ComplexAbsAttr) (y tf.Output) { +// Rounds half to even. Also known as bankers rounding. If you want to round +// according to the current system rounding mode use std::cint. +func Round(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ComplexAbs", + Type: "Round", Input: []tf.Input{ x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the truth value of x AND y element-wise. +// RecordInputAttr is an optional argument to RecordInput. +type RecordInputAttr func(optionalAttr) + +// RecordInputFileRandomSeed sets the optional file_random_seed attribute to value. // -// *NOTE*: `LogicalAnd` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func LogicalAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return +// value: Random seeds used to produce randomized records. +// If not specified, defaults to 301 +func RecordInputFileRandomSeed(value int64) RecordInputAttr { + return func(m optionalAttr) { + m["file_random_seed"] = value } - opspec := tf.OpSpec{ - Type: "LogicalAnd", - Input: []tf.Input{ - x, y, - }, +} + +// RecordInputFileShuffleShiftRatio sets the optional file_shuffle_shift_ratio attribute to value. +// +// value: Shifts the list of files after the list is randomly +// shuffled. +// If not specified, defaults to 0 +func RecordInputFileShuffleShiftRatio(value float32) RecordInputAttr { + return func(m optionalAttr) { + m["file_shuffle_shift_ratio"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Checks whether a tree ensemble has been initialized. +// RecordInputFileBufferSize sets the optional file_buffer_size attribute to value. +// +// value: The randomization shuffling buffer. +// If not specified, defaults to 10000 +func RecordInputFileBufferSize(value int64) RecordInputAttr { + return func(m optionalAttr) { + m["file_buffer_size"] = value + } +} + +// RecordInputFileParallelism sets the optional file_parallelism attribute to value. +// +// value: How many sstables are opened and concurrently iterated over. +// If not specified, defaults to 16 +func RecordInputFileParallelism(value int64) RecordInputAttr { + return func(m optionalAttr) { + m["file_parallelism"] = value + } +} + +// RecordInputBatchSize sets the optional batch_size attribute to value. +// +// value: The batch size. +// If not specified, defaults to 32 +func RecordInputBatchSize(value int64) RecordInputAttr { + return func(m optionalAttr) { + m["batch_size"] = value + } +} + +// RecordInputCompressionType sets the optional compression_type attribute to value. +// +// value: The type of compression for the file. Currently ZLIB and +// GZIP are supported. Defaults to none. +// If not specified, defaults to "" +func RecordInputCompressionType(value string) RecordInputAttr { + return func(m optionalAttr) { + m["compression_type"] = value + } +} + +// Emits randomized records. // // Arguments: -// tree_ensemble_handle: Handle to the tree ensemble resouce. +// file_pattern: Glob pattern for the data files. // -// Returns output boolean on whether it is initialized or not. -func IsBoostedTreesEnsembleInitialized(scope *Scope, tree_ensemble_handle tf.Output) (is_initialized tf.Output) { +// Returns A tensor of shape [batch_size]. +func RecordInput(scope *Scope, file_pattern string, optional ...RecordInputAttr) (records tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"file_pattern": file_pattern} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "IsBoostedTreesEnsembleInitialized", - Input: []tf.Input{ - tree_ensemble_handle, - }, + Type: "RecordInput", + + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Cast x of type SrcT to y of DstT. -func Cast(scope *Scope, x tf.Output, DstT tf.DataType) (y tf.Output) { +// Computes reciprocal of square root of x element-wise. +// +// I.e., \\(y = 1 / \sqrt{x}\\). +func Rsqrt(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"DstT": DstT} opspec := tf.OpSpec{ - Type: "Cast", + Type: "Rsqrt", Input: []tf.Input{ x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// MaxAttr is an optional argument to Max. -type MaxAttr func(optionalAttr) +// AudioSpectrogramAttr is an optional argument to AudioSpectrogram. +type AudioSpectrogramAttr func(optionalAttr) -// MaxKeepDims sets the optional keep_dims attribute to value. +// AudioSpectrogramMagnitudeSquared sets the optional magnitude_squared attribute to value. // -// value: If true, retain reduced dimensions with length 1. +// value: Whether to return the squared magnitude or just the +// magnitude. Using squared magnitude can avoid extra calculations. // If not specified, defaults to false -func MaxKeepDims(value bool) MaxAttr { +func AudioSpectrogramMagnitudeSquared(value bool) AudioSpectrogramAttr { return func(m optionalAttr) { - m["keep_dims"] = value + m["magnitude_squared"] = value } } -// Computes the maximum of elements across dimensions of a tensor. +// Produces a visualization of audio data over time. // -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. +// Spectrograms are a standard way of representing audio information as a series of +// slices of frequency information, one slice for each window of time. By joining +// these together into a sequence, they form a distinctive fingerprint of the sound +// over time. +// +// This op expects to receive audio data as an input, stored as floats in the range +// -1 to 1, together with a window width in samples, and a stride specifying how +// far to move the window between slices. From this it generates a three +// dimensional output. The lowest dimension has an amplitude value for each +// frequency during that time slice. The next dimension is time, with successive +// frequency slices. The final dimension is for the channels in the input, so a +// stereo audio input would have two here for example. +// +// This means the layout when converted and saved as an image is rotated 90 degrees +// clockwise from a typical spectrogram. Time is descending down the Y axis, and +// the frequency decreases from left to right. +// +// Each value in the result represents the square root of the sum of the real and +// imaginary parts of an FFT on the current window of samples. In this way, the +// lowest dimension represents the power of each frequency in the current window, +// and adjacent windows are concatenated in the next dimension. +// +// To get a more intuitive and visual look at what this operation does, you can run +// tensorflow/examples/wav_to_spectrogram to read in an audio file and save out the +// resulting spectrogram as a PNG image. // // Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. +// input: Float representation of audio data. +// window_size: How wide the input window is in samples. For the highest efficiency +// this should be a power of two, but other values are accepted. +// stride: How widely apart the center of adjacent sample windows should be. // -// Returns The reduced tensor. -func Max(scope *Scope, input tf.Output, axis tf.Output, optional ...MaxAttr) (output tf.Output) { +// Returns 3D representation of the audio frequencies as an image. +func AudioSpectrogram(scope *Scope, input tf.Output, window_size int64, stride int64, optional ...AudioSpectrogramAttr) (spectrogram tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"window_size": window_size, "stride": stride} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Max", + Type: "AudioSpectrogram", Input: []tf.Input{ - input, axis, + input, }, Attrs: attrs, } @@ -5009,311 +5054,370 @@ func Max(scope *Scope, input tf.Output, axis tf.Output, optional ...MaxAttr) (ou return op.Output(0) } -// Quantized Batch normalization. -// -// This op is deprecated and will be removed in the future. Prefer -// `tf.nn.batch_normalization`. +// CTCBeamSearchDecoderAttr is an optional argument to CTCBeamSearchDecoder. +type CTCBeamSearchDecoderAttr func(optionalAttr) + +// CTCBeamSearchDecoderMergeRepeated sets the optional merge_repeated attribute to value. // -// Arguments: -// t: A 4D input Tensor. -// t_min: The value represented by the lowest quantized input. -// t_max: The value represented by the highest quantized input. -// m: A 1D mean Tensor with size matching the last dimension of t. -// This is the first output from tf.nn.moments, -// or a saved moving average thereof. -// m_min: The value represented by the lowest quantized mean. -// m_max: The value represented by the highest quantized mean. -// v: A 1D variance Tensor with size matching the last dimension of t. -// This is the second output from tf.nn.moments, -// or a saved moving average thereof. -// v_min: The value represented by the lowest quantized variance. -// v_max: The value represented by the highest quantized variance. -// beta: A 1D beta Tensor with size matching the last dimension of t. -// An offset to be added to the normalized tensor. -// beta_min: The value represented by the lowest quantized offset. -// beta_max: The value represented by the highest quantized offset. -// gamma: A 1D gamma Tensor with size matching the last dimension of t. -// If "scale_after_normalization" is true, this tensor will be multiplied -// with the normalized tensor. -// gamma_min: The value represented by the lowest quantized gamma. -// gamma_max: The value represented by the highest quantized gamma. +// value: If true, merge repeated classes in output. +// If not specified, defaults to true +func CTCBeamSearchDecoderMergeRepeated(value bool) CTCBeamSearchDecoderAttr { + return func(m optionalAttr) { + m["merge_repeated"] = value + } +} + +// Performs beam search decoding on the logits given in input. // -// variance_epsilon: A small float number to avoid dividing by 0. -// scale_after_normalization: A bool indicating whether the resulted tensor -// needs to be multiplied with gamma. -func QuantizedBatchNormWithGlobalNormalization(scope *Scope, t tf.Output, t_min tf.Output, t_max tf.Output, m tf.Output, m_min tf.Output, m_max tf.Output, v tf.Output, v_min tf.Output, v_max tf.Output, beta tf.Output, beta_min tf.Output, beta_max tf.Output, gamma tf.Output, gamma_min tf.Output, gamma_max tf.Output, out_type tf.DataType, variance_epsilon float32, scale_after_normalization bool) (result tf.Output, result_min tf.Output, result_max tf.Output) { +// A note about the attribute merge_repeated: For the beam search decoder, +// this means that if consecutive entries in a beam are the same, only +// the first of these is emitted. That is, when the top path is "A B B B B", +// "A B" is returned if merge_repeated = True but "A B B B B" is +// returned if merge_repeated = False. +// +// Arguments: +// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. +// sequence_length: A vector containing sequence lengths, size `(batch)`. +// beam_width: A scalar >= 0 (beam search beam width). +// top_paths: A scalar >= 0, <= beam_width (controls output size). +// +// Returns A list (length: top_paths) of indices matrices. Matrix j, +// size `(total_decoded_outputs[j] x 2)`, has indices of a +// `SparseTensor`. The rows store: [batch, time].A list (length: top_paths) of values vectors. Vector j, +// size `(length total_decoded_outputs[j])`, has the values of a +// `SparseTensor`. The vector stores the decoded classes for beam j.A list (length: top_paths) of shape vector. Vector j, +// size `(2)`, stores the shape of the decoded `SparseTensor[j]`. +// Its values are: `[batch_size, max_decoded_length[j]]`.A matrix, shaped: `(batch_size x top_paths)`. The +// sequence log-probabilities. +func CTCBeamSearchDecoder(scope *Scope, inputs tf.Output, sequence_length tf.Output, beam_width int64, top_paths int64, optional ...CTCBeamSearchDecoderAttr) (decoded_indices []tf.Output, decoded_values []tf.Output, decoded_shape []tf.Output, log_probability tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"out_type": out_type, "variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} + attrs := map[string]interface{}{"beam_width": beam_width, "top_paths": top_paths} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "QuantizedBatchNormWithGlobalNormalization", + Type: "CTCBeamSearchDecoder", Input: []tf.Input{ - t, t_min, t_max, m, m_min, m_max, v, v_min, v_max, beta, beta_min, beta_max, gamma, gamma_min, gamma_max, + inputs, sequence_length, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + if scope.Err() != nil { + return + } + var idx int + var err error + if decoded_indices, idx, err = makeOutputList(op, idx, "decoded_indices"); err != nil { + scope.UpdateErr("CTCBeamSearchDecoder", err) + return + } + if decoded_values, idx, err = makeOutputList(op, idx, "decoded_values"); err != nil { + scope.UpdateErr("CTCBeamSearchDecoder", err) + return + } + if decoded_shape, idx, err = makeOutputList(op, idx, "decoded_shape"); err != nil { + scope.UpdateErr("CTCBeamSearchDecoder", err) + return + } + log_probability = op.Output(idx) + return decoded_indices, decoded_values, decoded_shape, log_probability } -// Adds Tensor 'bias' to Tensor 'input' for Quantized types. +// MatrixInverseAttr is an optional argument to MatrixInverse. +type MatrixInverseAttr func(optionalAttr) + +// MatrixInverseAdjoint sets the optional adjoint attribute to value. +// If not specified, defaults to false +func MatrixInverseAdjoint(value bool) MatrixInverseAttr { + return func(m optionalAttr) { + m["adjoint"] = value + } +} + +// Computes the inverse of one or more square invertible matrices or their // -// Broadcasts the values of bias on dimensions 0..N-2 of 'input'. +// adjoints (conjugate transposes). // -// Arguments: +// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +// form square matrices. The output is a tensor of the same shape as the input +// containing the inverse for all input submatrices `[..., :, :]`. // -// bias: A 1D bias Tensor with size matching the last dimension of 'input'. -// min_input: The float value that the lowest quantized input value represents. -// max_input: The float value that the highest quantized input value represents. -// min_bias: The float value that the lowest quantized bias value represents. -// max_bias: The float value that the highest quantized bias value represents. +// The op uses LU decomposition with partial pivoting to compute the inverses. +// +// If a matrix is not invertible there is no guarantee what the op does. It +// may detect the condition and raise an exception or it may simply return a +// garbage result. // +// Arguments: +// input: Shape is `[..., M, M]`. // -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -func QuantizedBiasAdd(scope *Scope, input tf.Output, bias tf.Output, min_input tf.Output, max_input tf.Output, min_bias tf.Output, max_bias tf.Output, out_type tf.DataType) (output tf.Output, min_out tf.Output, max_out tf.Output) { +// Returns Shape is `[..., M, M]`. +// +// @compatibility(numpy) +// Equivalent to np.linalg.inv +// @end_compatibility +func MatrixInverse(scope *Scope, input tf.Output, optional ...MatrixInverseAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"out_type": out_type} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "QuantizedBiasAdd", + Type: "MatrixInverse", Input: []tf.Input{ - input, bias, min_input, max_input, min_bias, max_bias, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Produces the average pool of the input tensor for quantized types. -// -// Arguments: -// input: 4-D with shape `[batch, height, width, channels]`. -// min_input: The float value that the lowest quantized input value represents. -// max_input: The float value that the highest quantized input value represents. -// ksize: The size of the window for each dimension of the input tensor. -// The length must be 4 to match the number of dimensions of the input. -// strides: The stride of the sliding window for each dimension of the input -// tensor. The length must be 4 to match the number of dimensions of the input. -// padding: The type of padding algorithm to use. +// Returns x + y element-wise. // -// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. -func QuantizedAvgPool(scope *Scope, input tf.Output, min_input tf.Output, max_input tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output, min_output tf.Output, max_output tf.Output) { +// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "QuantizedAvgPool", + Type: "Add", Input: []tf.Input{ - input, min_input, max_input, + x, y, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// FractionalAvgPoolAttr is an optional argument to FractionalAvgPool. -type FractionalAvgPoolAttr func(optionalAttr) - -// FractionalAvgPoolPseudoRandom sets the optional pseudo_random attribute to value. -// -// value: When set to True, generates the pooling sequence in a -// pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin -// Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for -// difference between pseudorandom and random. -// If not specified, defaults to false -func FractionalAvgPoolPseudoRandom(value bool) FractionalAvgPoolAttr { - return func(m optionalAttr) { - m["pseudo_random"] = value +// Computes the derivative of a Gamma random sample w.r.t. `alpha`. +func RandomGammaGrad(scope *Scope, alpha tf.Output, sample tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RandomGammaGrad", + Input: []tf.Input{ + alpha, sample, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// FractionalAvgPoolOverlapping sets the optional overlapping attribute to value. -// -// value: When set to True, it means when pooling, the values at the boundary -// of adjacent pooling cells are used by both cells. For example: -// -// `index 0 1 2 3 4` -// -// `value 20 5 16 3 7` +// Computes square of x element-wise. // -// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. -// The result would be [41/3, 26/3] for fractional avg pooling. -// If not specified, defaults to false -func FractionalAvgPoolOverlapping(value bool) FractionalAvgPoolAttr { - return func(m optionalAttr) { - m["overlapping"] = value +// I.e., \\(y = x * x = x^2\\). +func Square(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Square", + Input: []tf.Input{ + x, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// FractionalAvgPoolDeterministic sets the optional deterministic attribute to value. +// Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise. // -// value: When set to True, a fixed pooling region will be used when -// iterating over a FractionalAvgPool node in the computation graph. Mainly used -// in unit test to make FractionalAvgPool deterministic. -// If not specified, defaults to false -func FractionalAvgPoolDeterministic(value bool) FractionalAvgPoolAttr { - return func(m optionalAttr) { - m["deterministic"] = value +// See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) +// ](http://arxiv.org/abs/1511.07289) +func Elu(scope *Scope, features tf.Output) (activations tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Elu", + Input: []tf.Input{ + features, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// FractionalAvgPoolSeed sets the optional seed attribute to value. +// Computes the reciprocal of x element-wise. // -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func FractionalAvgPoolSeed(value int64) FractionalAvgPoolAttr { - return func(m optionalAttr) { - m["seed"] = value +// I.e., \\(y = 1 / x\\). +func Reciprocal(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Reciprocal", + Input: []tf.Input{ + x, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// FractionalAvgPoolSeed2 sets the optional seed2 attribute to value. +// Returns a batched matrix tensor with new batched diagonal values. // -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func FractionalAvgPoolSeed2(value int64) FractionalAvgPoolAttr { - return func(m optionalAttr) { - m["seed2"] = value +// Given `input` and `diagonal`, this operation returns a tensor with the +// same shape and values as `input`, except for the main diagonal of the +// innermost matrices. These will be overwritten by the values in `diagonal`. +// +// The output is computed as follows: +// +// Assume `input` has `k+1` dimensions `[I, J, K, ..., M, N]` and `diagonal` has +// `k` dimensions `[I, J, K, ..., min(M, N)]`. Then the output is a +// tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where: +// +// * `output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n]` for `m == n`. +// * `output[i, j, k, ..., m, n] = input[i, j, k, ..., m, n]` for `m != n`. +// +// Arguments: +// input: Rank `k+1`, where `k >= 1`. +// diagonal: Rank `k`, where `k >= 1`. +// +// Returns Rank `k+1`, with `output.shape = input.shape`. +func MatrixSetDiag(scope *Scope, input tf.Output, diagonal tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MatrixSetDiag", + Input: []tf.Input{ + input, diagonal, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Performs fractional average pooling on the input. +// Returns the element-wise max of two SparseTensors. // -// Fractional average pooling is similar to Fractional max pooling in the pooling -// region generation step. The only difference is that after pooling regions are -// generated, a mean operation is performed instead of a max operation in each -// pooling region. +// Assumes the two SparseTensors have the same shape, i.e., no broadcasting. // // Arguments: -// value: 4-D with shape `[batch, height, width, channels]`. -// pooling_ratio: Pooling ratio for each dimension of `value`, currently only -// supports row and col dimension and should be >= 1.0. For example, a valid -// pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements -// must be 1.0 because we don't allow pooling on batch and channels -// dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions -// respectively. +// a_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, in the canonical lexicographic ordering. +// a_values: 1-D. `N` non-empty values corresponding to `a_indices`. +// a_shape: 1-D. Shape of the input SparseTensor. +// b_indices: counterpart to `a_indices` for the other operand. +// b_values: counterpart to `a_values` for the other operand; must be of the same dtype. +// b_shape: counterpart to `a_shape` for the other operand; the two shapes must be equal. // -// Returns output tensor after fractional avg pooling.row pooling sequence, needed to calculate gradient.column pooling sequence, needed to calculate gradient. -func FractionalAvgPool(scope *Scope, value tf.Output, pooling_ratio []float32, optional ...FractionalAvgPoolAttr) (output tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output) { +// Returns 2-D. The indices of the output SparseTensor.1-D. The values of the output SparseTensor. +func SparseSparseMaximum(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output) (output_indices tf.Output, output_values tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"pooling_ratio": pooling_ratio} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "FractionalAvgPool", + Type: "SparseSparseMaximum", Input: []tf.Input{ - value, + a_indices, a_values, a_shape, b_indices, b_values, b_shape, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0), op.Output(1) } -// RandomCropAttr is an optional argument to RandomCrop. -type RandomCropAttr func(optionalAttr) +// OrderedMapClearAttr is an optional argument to OrderedMapClear. +type OrderedMapClearAttr func(optionalAttr) -// RandomCropSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. +// OrderedMapClearCapacity sets the optional capacity attribute to value. // If not specified, defaults to 0 -func RandomCropSeed(value int64) RandomCropAttr { +// +// REQUIRES: value >= 0 +func OrderedMapClearCapacity(value int64) OrderedMapClearAttr { return func(m optionalAttr) { - m["seed"] = value + m["capacity"] = value } } -// RandomCropSeed2 sets the optional seed2 attribute to value. -// -// value: An second seed to avoid seed collision. +// OrderedMapClearMemoryLimit sets the optional memory_limit attribute to value. // If not specified, defaults to 0 -func RandomCropSeed2(value int64) RandomCropAttr { +// +// REQUIRES: value >= 0 +func OrderedMapClearMemoryLimit(value int64) OrderedMapClearAttr { return func(m optionalAttr) { - m["seed2"] = value + m["memory_limit"] = value } } -// Randomly crop `image`. -// -// DEPRECATED at GraphDef version 8: Random crop is now pure Python -// -// `size` is a 1-D int64 tensor with 2 elements representing the crop height and -// width. The values must be non negative. -// -// This Op picks a random location in `image` and crops a `height` by `width` -// rectangle from that location. The random location is picked so the cropped -// area will fit inside the original image. -// -// Arguments: -// image: 3-D of shape `[height, width, channels]`. -// size: 1-D of length 2 containing: `crop_height`, `crop_width`.. +// OrderedMapClearContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func OrderedMapClearContainer(value string) OrderedMapClearAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// OrderedMapClearSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func OrderedMapClearSharedName(value string) OrderedMapClearAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op removes all elements in the underlying container. // -// Returns 3-D of shape `[crop_height, crop_width, channels].` -func RandomCrop(scope *Scope, image tf.Output, size tf.Output, optional ...RandomCropAttr) (output tf.Output) { +// Returns the created operation. +func OrderedMapClear(scope *Scope, dtypes []tf.DataType, optional ...OrderedMapClearAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RandomCrop", + Type: "OrderedMapClear", + + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Computes the reciprocal of x element-wise. +// +// I.e., \\(y = 1 / x\\). +func Inv(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Inv", Input: []tf.Input{ - image, size, + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// TopKV2Attr is an optional argument to TopKV2. -type TopKV2Attr func(optionalAttr) +// ComplexAbsAttr is an optional argument to ComplexAbs. +type ComplexAbsAttr func(optionalAttr) -// TopKV2Sorted sets the optional sorted attribute to value. -// -// value: If true the resulting `k` elements will be sorted by the values in -// descending order. -// If not specified, defaults to true -func TopKV2Sorted(value bool) TopKV2Attr { +// ComplexAbsTout sets the optional Tout attribute to value. +// If not specified, defaults to DT_FLOAT +func ComplexAbsTout(value tf.DataType) ComplexAbsAttr { return func(m optionalAttr) { - m["sorted"] = value + m["Tout"] = value } } -// Finds values and indices of the `k` largest elements for the last dimension. -// -// If the input is a vector (rank-1), finds the `k` largest entries in the vector -// and outputs their values and indices as vectors. Thus `values[j]` is the -// `j`-th largest entry in `input`, and its index is `indices[j]`. -// -// For matrices (resp. higher rank input), computes the top `k` entries in each -// row (resp. vector along the last dimension). Thus, -// -// values.shape = indices.shape = input.shape[:-1] + [k] -// -// If two elements are equal, the lower-index element appears first. -// -// Arguments: -// input: 1-D or higher with last dimension at least `k`. -// k: 0-D. Number of top elements to look for along the last dimension (along each -// row for matrices). +// Computes the complex absolute value of a tensor. // -// Returns The `k` largest elements along each last dimensional slice.The indices of `values` within the last dimension of `input`. -func TopKV2(scope *Scope, input tf.Output, k tf.Output, optional ...TopKV2Attr) (values tf.Output, indices tf.Output) { +// Given a tensor `x` of complex numbers, this operation returns a tensor of type +// `float` or `double` that is the absolute value of each element in `x`. All +// elements in `x` must be complex numbers of the form \\(a + bj\\). The absolute +// value is computed as \\( \sqrt{a^2 + b^2}\\). +func ComplexAbs(scope *Scope, x tf.Output, optional ...ComplexAbsAttr) (y tf.Output) { if scope.Err() != nil { return } @@ -5322,26 +5426,26 @@ func TopKV2(scope *Scope, input tf.Output, k tf.Output, optional ...TopKV2Attr) a(attrs) } opspec := tf.OpSpec{ - Type: "TopKV2", + Type: "ComplexAbs", Input: []tf.Input{ - input, k, + x, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Returns x // y element-wise. +// Returns the truth value of x AND y element-wise. // -// *NOTE*: `FloorDiv` supports broadcasting. More about broadcasting +// *NOTE*: `LogicalAnd` supports broadcasting. More about broadcasting // [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func FloorDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +func LogicalAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "FloorDiv", + Type: "LogicalAnd", Input: []tf.Input{ x, y, }, @@ -5350,325 +5454,407 @@ func FloorDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// Computes the inverse permutation of a tensor. -// -// This operation computes the inverse of an index permutation. It takes a 1-D -// integer tensor `x`, which represents the indices of a zero-based array, and -// swaps each value with its index position. In other words, for an output tensor -// `y` and an input tensor `x`, this operation computes the following: -// -// `y[x[i]] = i for i in [0, 1, ..., len(x) - 1]` -// -// The values must include 0. There can be no duplicate values or negative values. -// -// For example: -// -// ``` -// # tensor `x` is [3, 4, 0, 2, 1] -// invert_permutation(x) ==> [2, 4, 3, 0, 1] -// ``` +// Checks whether a tree ensemble has been initialized. // // Arguments: -// x: 1-D. +// tree_ensemble_handle: Handle to the tree ensemble resouce. // -// Returns 1-D. -func InvertPermutation(scope *Scope, x tf.Output) (y tf.Output) { +// Returns output boolean on whether it is initialized or not. +func IsBoostedTreesEnsembleInitialized(scope *Scope, tree_ensemble_handle tf.Output) (is_initialized tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "InvertPermutation", + Type: "IsBoostedTreesEnsembleInitialized", Input: []tf.Input{ - x, + tree_ensemble_handle, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes log softmax activations. -// -// For each batch `i` and class `j` we have -// -// logsoftmax[i, j] = logits[i, j] - log(sum(exp(logits[i]))) -// -// Arguments: -// logits: 2-D with shape `[batch_size, num_classes]`. -// -// Returns Same shape as `logits`. -func LogSoftmax(scope *Scope, logits tf.Output) (logsoftmax tf.Output) { +// CastAttr is an optional argument to Cast. +type CastAttr func(optionalAttr) + +// CastTruncate sets the optional Truncate attribute to value. +// If not specified, defaults to false +func CastTruncate(value bool) CastAttr { + return func(m optionalAttr) { + m["Truncate"] = value + } +} + +// Cast x of type SrcT to y of DstT. +func Cast(scope *Scope, x tf.Output, DstT tf.DataType, optional ...CastAttr) (y tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"DstT": DstT} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "LogSoftmax", + Type: "Cast", Input: []tf.Input{ - logits, + x, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the truth value of (x <= y) element-wise. +// MaxAttr is an optional argument to Max. +type MaxAttr func(optionalAttr) + +// MaxKeepDims sets the optional keep_dims attribute to value. // -// *NOTE*: `LessEqual` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func LessEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LessEqual", - Input: []tf.Input{ - x, y, - }, +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func MaxKeepDims(value bool) MaxAttr { + return func(m optionalAttr) { + m["keep_dims"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Computes softmax activations. -// -// For each batch `i` and class `j` we have +// Computes the maximum of elements across dimensions of a tensor. // -// softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j])) +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. // // Arguments: -// logits: 2-D with shape `[batch_size, num_classes]`. +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. // -// Returns Same shape as `logits`. -func Softmax(scope *Scope, logits tf.Output) (softmax tf.Output) { +// Returns The reduced tensor. +func Max(scope *Scope, input tf.Output, axis tf.Output, optional ...MaxAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Softmax", + Type: "Max", Input: []tf.Input{ - logits, + input, axis, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// DecodeBmpAttr is an optional argument to DecodeBmp. -type DecodeBmpAttr func(optionalAttr) - -// DecodeBmpChannels sets the optional channels attribute to value. -// If not specified, defaults to 0 -func DecodeBmpChannels(value int64) DecodeBmpAttr { - return func(m optionalAttr) { - m["channels"] = value - } -} - -// Decode the first frame of a BMP-encoded image to a uint8 tensor. -// -// The attr `channels` indicates the desired number of color channels for the -// decoded image. -// -// Accepted values are: +// Quantized Batch normalization. // -// * 0: Use the number of channels in the BMP-encoded image. -// * 3: output an RGB image. -// * 4: output an RGBA image. +// This op is deprecated and will be removed in the future. Prefer +// `tf.nn.batch_normalization`. // // Arguments: -// contents: 0-D. The BMP-encoded image. +// t: A 4D input Tensor. +// t_min: The value represented by the lowest quantized input. +// t_max: The value represented by the highest quantized input. +// m: A 1D mean Tensor with size matching the last dimension of t. +// This is the first output from tf.nn.moments, +// or a saved moving average thereof. +// m_min: The value represented by the lowest quantized mean. +// m_max: The value represented by the highest quantized mean. +// v: A 1D variance Tensor with size matching the last dimension of t. +// This is the second output from tf.nn.moments, +// or a saved moving average thereof. +// v_min: The value represented by the lowest quantized variance. +// v_max: The value represented by the highest quantized variance. +// beta: A 1D beta Tensor with size matching the last dimension of t. +// An offset to be added to the normalized tensor. +// beta_min: The value represented by the lowest quantized offset. +// beta_max: The value represented by the highest quantized offset. +// gamma: A 1D gamma Tensor with size matching the last dimension of t. +// If "scale_after_normalization" is true, this tensor will be multiplied +// with the normalized tensor. +// gamma_min: The value represented by the lowest quantized gamma. +// gamma_max: The value represented by the highest quantized gamma. // -// Returns 3-D with shape `[height, width, channels]`. RGB order -func DecodeBmp(scope *Scope, contents tf.Output, optional ...DecodeBmpAttr) (image tf.Output) { +// variance_epsilon: A small float number to avoid dividing by 0. +// scale_after_normalization: A bool indicating whether the resulted tensor +// needs to be multiplied with gamma. +func QuantizedBatchNormWithGlobalNormalization(scope *Scope, t tf.Output, t_min tf.Output, t_max tf.Output, m tf.Output, m_min tf.Output, m_max tf.Output, v tf.Output, v_min tf.Output, v_max tf.Output, beta tf.Output, beta_min tf.Output, beta_max tf.Output, gamma tf.Output, gamma_min tf.Output, gamma_max tf.Output, out_type tf.DataType, variance_epsilon float32, scale_after_normalization bool) (result tf.Output, result_min tf.Output, result_max tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"out_type": out_type, "variance_epsilon": variance_epsilon, "scale_after_normalization": scale_after_normalization} opspec := tf.OpSpec{ - Type: "DecodeBmp", + Type: "QuantizedBatchNormWithGlobalNormalization", Input: []tf.Input{ - contents, + t, t_min, t_max, m, m_min, m_max, v, v_min, v_max, beta, beta_min, beta_max, gamma, gamma_min, gamma_max, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Computes softsign gradients for a softsign operation. +// Adds Tensor 'bias' to Tensor 'input' for Quantized types. +// +// Broadcasts the values of bias on dimensions 0..N-2 of 'input'. // // Arguments: -// gradients: The backpropagated gradients to the corresponding softsign operation. -// features: The features passed as input to the corresponding softsign operation. // -// Returns The gradients: `gradients / (1 + abs(features)) ** 2`. -func SoftsignGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { +// bias: A 1D bias Tensor with size matching the last dimension of 'input'. +// min_input: The float value that the lowest quantized input value represents. +// max_input: The float value that the highest quantized input value represents. +// min_bias: The float value that the lowest quantized bias value represents. +// max_bias: The float value that the highest quantized bias value represents. +// +// +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +func QuantizedBiasAdd(scope *Scope, input tf.Output, bias tf.Output, min_input tf.Output, max_input tf.Output, min_bias tf.Output, max_bias tf.Output, out_type tf.DataType) (output tf.Output, min_out tf.Output, max_out tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"out_type": out_type} opspec := tf.OpSpec{ - Type: "SoftsignGrad", + Type: "QuantizedBiasAdd", Input: []tf.Input{ - gradients, features, + input, bias, min_input, max_input, min_bias, max_bias, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Provides the time since epoch in seconds. +// Produces the average pool of the input tensor for quantized types. // -// Returns the timestamp as a `float64` for seconds since the Unix epoch. +// Arguments: +// input: 4-D with shape `[batch, height, width, channels]`. +// min_input: The float value that the lowest quantized input value represents. +// max_input: The float value that the highest quantized input value represents. +// ksize: The size of the window for each dimension of the input tensor. +// The length must be 4 to match the number of dimensions of the input. +// strides: The stride of the sliding window for each dimension of the input +// tensor. The length must be 4 to match the number of dimensions of the input. +// padding: The type of padding algorithm to use. // -// Note: the timestamp is computed when the op is executed, not when it is added -// to the graph. -func Timestamp(scope *Scope) (ts tf.Output) { +// Returns The float value that the lowest quantized output value represents.The float value that the highest quantized output value represents. +func QuantizedAvgPool(scope *Scope, input tf.Output, min_input tf.Output, max_input tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output, min_output tf.Output, max_output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "Timestamp", + Type: "QuantizedAvgPool", + Input: []tf.Input{ + input, min_input, max_input, + }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// BatchMatMulAttr is an optional argument to BatchMatMul. -type BatchMatMulAttr func(optionalAttr) +// FractionalAvgPoolAttr is an optional argument to FractionalAvgPool. +type FractionalAvgPoolAttr func(optionalAttr) -// BatchMatMulAdjX sets the optional adj_x attribute to value. +// FractionalAvgPoolPseudoRandom sets the optional pseudo_random attribute to value. // -// value: If `True`, adjoint the slices of `x`. Defaults to `False`. +// value: When set to True, generates the pooling sequence in a +// pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin +// Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for +// difference between pseudorandom and random. // If not specified, defaults to false -func BatchMatMulAdjX(value bool) BatchMatMulAttr { +func FractionalAvgPoolPseudoRandom(value bool) FractionalAvgPoolAttr { return func(m optionalAttr) { - m["adj_x"] = value + m["pseudo_random"] = value } } -// BatchMatMulAdjY sets the optional adj_y attribute to value. +// FractionalAvgPoolOverlapping sets the optional overlapping attribute to value. // -// value: If `True`, adjoint the slices of `y`. Defaults to `False`. +// value: When set to True, it means when pooling, the values at the boundary +// of adjacent pooling cells are used by both cells. For example: +// +// `index 0 1 2 3 4` +// +// `value 20 5 16 3 7` +// +// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. +// The result would be [41/3, 26/3] for fractional avg pooling. // If not specified, defaults to false -func BatchMatMulAdjY(value bool) BatchMatMulAttr { +func FractionalAvgPoolOverlapping(value bool) FractionalAvgPoolAttr { return func(m optionalAttr) { - m["adj_y"] = value + m["overlapping"] = value } } -// Multiplies slices of two tensors in batches. -// -// Multiplies all slices of `Tensor` `x` and `y` (each slice can be -// viewed as an element of a batch), and arranges the individual results -// in a single output tensor of the same batch size. Each of the -// individual slices can optionally be adjointed (to adjoint a matrix -// means to transpose and conjugate it) before multiplication by setting -// the `adj_x` or `adj_y` flag to `True`, which are by default `False`. -// -// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` -// and `[..., r_y, c_y]`. +// FractionalAvgPoolDeterministic sets the optional deterministic attribute to value. // -// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: +// value: When set to True, a fixed pooling region will be used when +// iterating over a FractionalAvgPool node in the computation graph. Mainly used +// in unit test to make FractionalAvgPool deterministic. +// If not specified, defaults to false +func FractionalAvgPoolDeterministic(value bool) FractionalAvgPoolAttr { + return func(m optionalAttr) { + m["deterministic"] = value + } +} + +// FractionalAvgPoolSeed sets the optional seed attribute to value. // -// r_o = c_x if adj_x else r_x -// c_o = r_y if adj_y else c_y +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func FractionalAvgPoolSeed(value int64) FractionalAvgPoolAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// FractionalAvgPoolSeed2 sets the optional seed2 attribute to value. // -// It is computed as: +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func FractionalAvgPoolSeed2(value int64) FractionalAvgPoolAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Performs fractional average pooling on the input. // -// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) +// Fractional average pooling is similar to Fractional max pooling in the pooling +// region generation step. The only difference is that after pooling regions are +// generated, a mean operation is performed instead of a max operation in each +// pooling region. // // Arguments: -// x: 2-D or higher with shape `[..., r_x, c_x]`. -// y: 2-D or higher with shape `[..., r_y, c_y]`. +// value: 4-D with shape `[batch, height, width, channels]`. +// pooling_ratio: Pooling ratio for each dimension of `value`, currently only +// supports row and col dimension and should be >= 1.0. For example, a valid +// pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements +// must be 1.0 because we don't allow pooling on batch and channels +// dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions +// respectively. // -// Returns 3-D or higher with shape `[..., r_o, c_o]` -func BatchMatMul(scope *Scope, x tf.Output, y tf.Output, optional ...BatchMatMulAttr) (output tf.Output) { +// Returns output tensor after fractional avg pooling.row pooling sequence, needed to calculate gradient.column pooling sequence, needed to calculate gradient. +func FractionalAvgPool(scope *Scope, value tf.Output, pooling_ratio []float32, optional ...FractionalAvgPoolAttr) (output tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"pooling_ratio": pooling_ratio} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "BatchMatMul", + Type: "FractionalAvgPool", Input: []tf.Input{ - x, y, + value, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Returns which elements of x are NaN. +// RandomCropAttr is an optional argument to RandomCrop. +type RandomCropAttr func(optionalAttr) + +// RandomCropSeed sets the optional seed attribute to value. // -// @compatibility(numpy) -// Equivalent to np.isnan -// @end_compatibility -func IsNan(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomCropSeed(value int64) RandomCropAttr { + return func(m optionalAttr) { + m["seed"] = value } - opspec := tf.OpSpec{ - Type: "IsNan", - Input: []tf.Input{ - x, - }, +} + +// RandomCropSeed2 sets the optional seed2 attribute to value. +// +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomCropSeed2(value int64) RandomCropAttr { + return func(m optionalAttr) { + m["seed2"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Identity op for gradient debugging. +// Randomly crop `image`. // -// This op is hidden from public in Python. It is used by TensorFlow Debugger to -// register gradient tensors for gradient debugging. -// This op operates on non-reference-type tensors. -func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) { +// DEPRECATED at GraphDef version 8: Random crop is now pure Python +// +// `size` is a 1-D int64 tensor with 2 elements representing the crop height and +// width. The values must be non negative. +// +// This Op picks a random location in `image` and crops a `height` by `width` +// rectangle from that location. The random location is picked so the cropped +// area will fit inside the original image. +// +// Arguments: +// image: 3-D of shape `[height, width, channels]`. +// size: 1-D of length 2 containing: `crop_height`, `crop_width`.. +// +// Returns 3-D of shape `[crop_height, crop_width, channels].` +func RandomCrop(scope *Scope, image tf.Output, size tf.Output, optional ...RandomCropAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "DebugGradientIdentity", + Type: "RandomCrop", Input: []tf.Input{ - input, + image, size, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// ResourceSparseApplyAdadeltaAttr is an optional argument to ResourceSparseApplyAdadelta. -type ResourceSparseApplyAdadeltaAttr func(optionalAttr) +// TopKV2Attr is an optional argument to TopKV2. +type TopKV2Attr func(optionalAttr) -// ResourceSparseApplyAdadeltaUseLocking sets the optional use_locking attribute to value. +// TopKV2Sorted sets the optional sorted attribute to value. // -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceSparseApplyAdadeltaUseLocking(value bool) ResourceSparseApplyAdadeltaAttr { +// value: If true the resulting `k` elements will be sorted by the values in +// descending order. +// If not specified, defaults to true +func TopKV2Sorted(value bool) TopKV2Attr { return func(m optionalAttr) { - m["use_locking"] = value + m["sorted"] = value } } -// var: Should be from a Variable(). +// Finds values and indices of the `k` largest elements for the last dimension. // -// Arguments: +// If the input is a vector (rank-1), finds the `k` largest entries in the vector +// and outputs their values and indices as vectors. Thus `values[j]` is the +// `j`-th largest entry in `input`, and its index is `indices[j]`. // -// accum: Should be from a Variable(). -// accum_update: : Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// rho: Decay factor. Must be a scalar. -// epsilon: Constant factor. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. +// For matrices (resp. higher rank input), computes the top `k` entries in each +// row (resp. vector along the last dimension). Thus, // -// Returns the created operation. -func ResourceSparseApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdadeltaAttr) (o *tf.Operation) { +// values.shape = indices.shape = input.shape[:-1] + [k] +// +// If two elements are equal, the lower-index element appears first. +// +// Arguments: +// input: 1-D or higher with last dimension at least `k`. +// k: 0-D. Number of top elements to look for along the last dimension (along each +// row for matrices). +// +// Returns The `k` largest elements along each last dimensional slice.The indices of `values` within the last dimension of `input`. +func TopKV2(scope *Scope, input tf.Output, k tf.Output, optional ...TopKV2Attr) (values tf.Output, indices tf.Output) { if scope.Err() != nil { return } @@ -5677,199 +5863,174 @@ func ResourceSparseApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyAdadelta", + Type: "TopKV2", Input: []tf.Input{ - var_, accum, accum_update, lr, rho, epsilon, grad, indices, + input, k, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } -// Computes rectified linear gradients for a Relu operation. -// -// Arguments: -// gradients: The backpropagated gradients to the corresponding Relu operation. -// features: The features passed as input to the corresponding Relu operation, OR -// the outputs of that operation (both work equivalently). +// Returns x // y element-wise. // -// Returns `gradients * (features > 0)`. -func ReluGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { +// *NOTE*: `FloorDiv` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func FloorDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ReluGrad", + Type: "FloorDiv", Input: []tf.Input{ - gradients, features, + x, y, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the gradient of morphological 2-D dilation with respect to the input. +// Computes the inverse permutation of a tensor. +// +// This operation computes the inverse of an index permutation. It takes a 1-D +// integer tensor `x`, which represents the indices of a zero-based array, and +// swaps each value with its index position. In other words, for an output tensor +// `y` and an input tensor `x`, this operation computes the following: +// +// `y[x[i]] = i for i in [0, 1, ..., len(x) - 1]` +// +// The values must include 0. There can be no duplicate values or negative values. +// +// For example: +// +// ``` +// # tensor `x` is [3, 4, 0, 2, 1] +// invert_permutation(x) ==> [2, 4, 3, 0, 1] +// ``` // // Arguments: -// input: 4-D with shape `[batch, in_height, in_width, depth]`. -// filter: 3-D with shape `[filter_height, filter_width, depth]`. -// out_backprop: 4-D with shape `[batch, out_height, out_width, depth]`. -// strides: 1-D of length 4. The stride of the sliding window for each dimension of -// the input tensor. Must be: `[1, stride_height, stride_width, 1]`. -// rates: 1-D of length 4. The input stride for atrous morphological dilation. -// Must be: `[1, rate_height, rate_width, 1]`. -// padding: The type of padding algorithm to use. +// x: 1-D. // -// Returns 4-D with shape `[batch, in_height, in_width, depth]`. -func Dilation2DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, rates []int64, padding string) (in_backprop tf.Output) { +// Returns 1-D. +func InvertPermutation(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} opspec := tf.OpSpec{ - Type: "Dilation2DBackpropInput", + Type: "InvertPermutation", Input: []tf.Input{ - input, filter, out_backprop, + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// CTCBeamSearchDecoderAttr is an optional argument to CTCBeamSearchDecoder. -type CTCBeamSearchDecoderAttr func(optionalAttr) - -// CTCBeamSearchDecoderMergeRepeated sets the optional merge_repeated attribute to value. +// Computes log softmax activations. // -// value: If true, merge repeated classes in output. -// If not specified, defaults to true -func CTCBeamSearchDecoderMergeRepeated(value bool) CTCBeamSearchDecoderAttr { - return func(m optionalAttr) { - m["merge_repeated"] = value - } -} - -// Performs beam search decoding on the logits given in input. +// For each batch `i` and class `j` we have // -// A note about the attribute merge_repeated: For the beam search decoder, -// this means that if consecutive entries in a beam are the same, only -// the first of these is emitted. That is, when the top path is "A B B B B", -// "A B" is returned if merge_repeated = True but "A B B B B" is -// returned if merge_repeated = False. +// logsoftmax[i, j] = logits[i, j] - log(sum(exp(logits[i]))) // // Arguments: -// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. -// sequence_length: A vector containing sequence lengths, size `(batch)`. -// beam_width: A scalar >= 0 (beam search beam width). -// top_paths: A scalar >= 0, <= beam_width (controls output size). +// logits: 2-D with shape `[batch_size, num_classes]`. // -// Returns A list (length: top_paths) of indices matrices. Matrix j, -// size `(total_decoded_outputs[j] x 2)`, has indices of a -// `SparseTensor`. The rows store: [batch, time].A list (length: top_paths) of values vectors. Vector j, -// size `(length total_decoded_outputs[j])`, has the values of a -// `SparseTensor`. The vector stores the decoded classes for beam j.A list (length: top_paths) of shape vector. Vector j, -// size `(2)`, stores the shape of the decoded `SparseTensor[j]`. -// Its values are: `[batch_size, max_decoded_length[j]]`.A matrix, shaped: `(batch_size x top_paths)`. The -// sequence log-probabilities. -func CTCBeamSearchDecoder(scope *Scope, inputs tf.Output, sequence_length tf.Output, beam_width int64, top_paths int64, optional ...CTCBeamSearchDecoderAttr) (decoded_indices []tf.Output, decoded_values []tf.Output, decoded_shape []tf.Output, log_probability tf.Output) { +// Returns Same shape as `logits`. +func LogSoftmax(scope *Scope, logits tf.Output) (logsoftmax tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"beam_width": beam_width, "top_paths": top_paths} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "CTCBeamSearchDecoder", + Type: "LogSoftmax", Input: []tf.Input{ - inputs, sequence_length, + logits, }, - Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the truth value of (x <= y) element-wise. +// +// *NOTE*: `LessEqual` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func LessEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } - var idx int - var err error - if decoded_indices, idx, err = makeOutputList(op, idx, "decoded_indices"); err != nil { - scope.UpdateErr("CTCBeamSearchDecoder", err) - return - } - if decoded_values, idx, err = makeOutputList(op, idx, "decoded_values"); err != nil { - scope.UpdateErr("CTCBeamSearchDecoder", err) - return - } - if decoded_shape, idx, err = makeOutputList(op, idx, "decoded_shape"); err != nil { - scope.UpdateErr("CTCBeamSearchDecoder", err) - return + opspec := tf.OpSpec{ + Type: "LessEqual", + Input: []tf.Input{ + x, y, + }, } - log_probability = op.Output(idx) - return decoded_indices, decoded_values, decoded_shape, log_probability + op := scope.AddOperation(opspec) + return op.Output(0) } -// AudioSpectrogramAttr is an optional argument to AudioSpectrogram. -type AudioSpectrogramAttr func(optionalAttr) - -// AudioSpectrogramMagnitudeSquared sets the optional magnitude_squared attribute to value. +// Computes softmax activations. // -// value: Whether to return the squared magnitude or just the -// magnitude. Using squared magnitude can avoid extra calculations. -// If not specified, defaults to false -func AudioSpectrogramMagnitudeSquared(value bool) AudioSpectrogramAttr { +// For each batch `i` and class `j` we have +// +// $$softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))$$ +// +// Arguments: +// logits: 2-D with shape `[batch_size, num_classes]`. +// +// Returns Same shape as `logits`. +func Softmax(scope *Scope, logits tf.Output) (softmax tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Softmax", + Input: []tf.Input{ + logits, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DecodeBmpAttr is an optional argument to DecodeBmp. +type DecodeBmpAttr func(optionalAttr) + +// DecodeBmpChannels sets the optional channels attribute to value. +// If not specified, defaults to 0 +func DecodeBmpChannels(value int64) DecodeBmpAttr { return func(m optionalAttr) { - m["magnitude_squared"] = value + m["channels"] = value } } -// Produces a visualization of audio data over time. -// -// Spectrograms are a standard way of representing audio information as a series of -// slices of frequency information, one slice for each window of time. By joining -// these together into a sequence, they form a distinctive fingerprint of the sound -// over time. -// -// This op expects to receive audio data as an input, stored as floats in the range -// -1 to 1, together with a window width in samples, and a stride specifying how -// far to move the window between slices. From this it generates a three -// dimensional output. The lowest dimension has an amplitude value for each -// frequency during that time slice. The next dimension is time, with successive -// frequency slices. The final dimension is for the channels in the input, so a -// stereo audio input would have two here for example. +// Decode the first frame of a BMP-encoded image to a uint8 tensor. // -// This means the layout when converted and saved as an image is rotated 90 degrees -// clockwise from a typical spectrogram. Time is descending down the Y axis, and -// the frequency decreases from left to right. +// The attr `channels` indicates the desired number of color channels for the +// decoded image. // -// Each value in the result represents the square root of the sum of the real and -// imaginary parts of an FFT on the current window of samples. In this way, the -// lowest dimension represents the power of each frequency in the current window, -// and adjacent windows are concatenated in the next dimension. +// Accepted values are: // -// To get a more intuitive and visual look at what this operation does, you can run -// tensorflow/examples/wav_to_spectrogram to read in an audio file and save out the -// resulting spectrogram as a PNG image. +// * 0: Use the number of channels in the BMP-encoded image. +// * 3: output an RGB image. +// * 4: output an RGBA image. // // Arguments: -// input: Float representation of audio data. -// window_size: How wide the input window is in samples. For the highest efficiency -// this should be a power of two, but other values are accepted. -// stride: How widely apart the center of adjacent sample windows should be. +// contents: 0-D. The BMP-encoded image. // -// Returns 3D representation of the audio frequencies as an image. -func AudioSpectrogram(scope *Scope, input tf.Output, window_size int64, stride int64, optional ...AudioSpectrogramAttr) (spectrogram tf.Output) { +// Returns 3-D with shape `[height, width, channels]`. RGB order +func DecodeBmp(scope *Scope, contents tf.Output, optional ...DecodeBmpAttr) (image tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"window_size": window_size, "stride": stride} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "AudioSpectrogram", + Type: "DecodeBmp", Input: []tf.Input{ - input, + contents, }, Attrs: attrs, } @@ -5877,98 +6038,105 @@ func AudioSpectrogram(scope *Scope, input tf.Output, window_size int64, stride i return op.Output(0) } -// Compute the polygamma function \\(\psi^{(n)}(x)\\). -// -// The polygamma function is defined as: -// +// Computes softsign gradients for a softsign operation. // -// \\(\psi^{(n)}(x) = \frac{d^n}{dx^n} \psi(x)\\) +// Arguments: +// gradients: The backpropagated gradients to the corresponding softsign operation. +// features: The features passed as input to the corresponding softsign operation. // -// where \\(\psi(x)\\) is the digamma function. -func Polygamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { +// Returns The gradients: `gradients / (1 + abs(features)) ** 2`. +func SoftsignGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Polygamma", + Type: "SoftsignGrad", Input: []tf.Input{ - a, x, + gradients, features, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes second-order gradients of the maxpooling function. +// Provides the time since epoch in seconds. // -// Arguments: -// input: The original input. -// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the -// input of `max_pool`. -// argmax: The indices of the maximum values chosen for each output of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. +// Returns the timestamp as a `float64` for seconds since the Unix epoch. // -// Returns Gradients of gradients w.r.t. the input of `max_pool`. -func MaxPoolGradGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output) { +// Note: the timestamp is computed when the op is executed, not when it is added +// to the graph. +func Timestamp(scope *Scope) (ts tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "MaxPoolGradGradWithArgmax", - Input: []tf.Input{ - input, grad, argmax, - }, - Attrs: attrs, + Type: "Timestamp", } op := scope.AddOperation(opspec) return op.Output(0) } -// MaxPoolGradGradV2Attr is an optional argument to MaxPoolGradGradV2. -type MaxPoolGradGradV2Attr func(optionalAttr) +// BatchMatMulAttr is an optional argument to BatchMatMul. +type BatchMatMulAttr func(optionalAttr) -// MaxPoolGradGradV2DataFormat sets the optional data_format attribute to value. +// BatchMatMulAdjX sets the optional adj_x attribute to value. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolGradGradV2DataFormat(value string) MaxPoolGradGradV2Attr { +// value: If `True`, adjoint the slices of `x`. Defaults to `False`. +// If not specified, defaults to false +func BatchMatMulAdjX(value bool) BatchMatMulAttr { return func(m optionalAttr) { - m["data_format"] = value + m["adj_x"] = value } } -// Computes second-order gradients of the maxpooling function. +// BatchMatMulAdjY sets the optional adj_y attribute to value. +// +// value: If `True`, adjoint the slices of `y`. Defaults to `False`. +// If not specified, defaults to false +func BatchMatMulAdjY(value bool) BatchMatMulAttr { + return func(m optionalAttr) { + m["adj_y"] = value + } +} + +// Multiplies slices of two tensors in batches. +// +// Multiplies all slices of `Tensor` `x` and `y` (each slice can be +// viewed as an element of a batch), and arranges the individual results +// in a single output tensor of the same batch size. Each of the +// individual slices can optionally be adjointed (to adjoint a matrix +// means to transpose and conjugate it) before multiplication by setting +// the `adj_x` or `adj_y` flag to `True`, which are by default `False`. +// +// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` +// and `[..., r_y, c_y]`. +// +// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: +// +// r_o = c_x if adj_x else r_x +// c_o = r_y if adj_y else c_y +// +// It is computed as: +// +// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) // // Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. +// x: 2-D or higher with shape `[..., r_x, c_x]`. +// y: 2-D or higher with shape `[..., r_y, c_y]`. // -// Returns Gradients of gradients w.r.t. the input to `max_pool`. -func MaxPoolGradGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize tf.Output, strides tf.Output, padding string, optional ...MaxPoolGradGradV2Attr) (output tf.Output) { +// Returns 3-D or higher with shape `[..., r_o, c_o]` +func BatchMatMul(scope *Scope, x tf.Output, y tf.Output, optional ...BatchMatMulAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MaxPoolGradGradV2", + Type: "BatchMatMul", Input: []tf.Input{ - orig_input, orig_output, grad, ksize, strides, + x, y, }, Attrs: attrs, } @@ -5976,64 +6144,72 @@ func MaxPoolGradGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output return op.Output(0) } -// Computes gradients of the maxpooling function. +// Returns which elements of x are NaN. // -// Arguments: -// input: The original input. -// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the -// output of `max_pool`. -// argmax: The indices of the maximum values chosen for each output of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. +// @compatibility(numpy) +// Equivalent to np.isnan +// @end_compatibility +func IsNan(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IsNan", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Identity op for gradient debugging. // -// Returns Gradients w.r.t. the input of `max_pool`. -func MaxPoolGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output) { +// This op is hidden from public in Python. It is used by TensorFlow Debugger to +// register gradient tensors for gradient debugging. +// This op operates on non-reference-type tensors. +func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "MaxPoolGradWithArgmax", + Type: "DebugGradientIdentity", Input: []tf.Input{ - input, grad, argmax, + input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// MutexV2Attr is an optional argument to MutexV2. -type MutexV2Attr func(optionalAttr) +// ResourceSparseApplyAdadeltaAttr is an optional argument to ResourceSparseApplyAdadelta. +type ResourceSparseApplyAdadeltaAttr func(optionalAttr) -// MutexV2Container sets the optional container attribute to value. +// ResourceSparseApplyAdadeltaUseLocking sets the optional use_locking attribute to value. // -// value: If non-empty, this variable is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func MutexV2Container(value string) MutexV2Attr { +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceSparseApplyAdadeltaUseLocking(value bool) ResourceSparseApplyAdadeltaAttr { return func(m optionalAttr) { - m["container"] = value + m["use_locking"] = value } } -// MutexV2SharedName sets the optional shared_name attribute to value. +// var: Should be from a Variable(). // -// value: If non-empty, this variable is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func MutexV2SharedName(value string) MutexV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Creates a Mutex resource that can be locked by `MutexLock`. +// Arguments: // -// Returns The mutex resource. -func MutexV2(scope *Scope, optional ...MutexV2Attr) (resource tf.Output) { +// accum: Should be from a Variable(). +// accum_update: : Should be from a Variable(). +// lr: Learning rate. Must be a scalar. +// rho: Decay factor. Must be a scalar. +// epsilon: Constant factor. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// +// Returns the created operation. +func ResourceSparseApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdadeltaAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -6042,54 +6218,187 @@ func MutexV2(scope *Scope, optional ...MutexV2Attr) (resource tf.Output) { a(attrs) } opspec := tf.OpSpec{ - Type: "MutexV2", + Type: "ResourceSparseApplyAdadelta", + Input: []tf.Input{ + var_, accum, accum_update, lr, rho, epsilon, grad, indices, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Computes rectified linear gradients for a Relu operation. +// +// Arguments: +// gradients: The backpropagated gradients to the corresponding Relu operation. +// features: The features passed as input to the corresponding Relu operation, OR +// the outputs of that operation (both work equivalently). +// +// Returns `gradients * (features > 0)`. +func ReluGrad(scope *Scope, gradients tf.Output, features tf.Output) (backprops tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReluGrad", + Input: []tf.Input{ + gradients, features, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} +// Computes the gradient of morphological 2-D dilation with respect to the input. +// +// Arguments: +// input: 4-D with shape `[batch, in_height, in_width, depth]`. +// filter: 3-D with shape `[filter_height, filter_width, depth]`. +// out_backprop: 4-D with shape `[batch, out_height, out_width, depth]`. +// strides: 1-D of length 4. The stride of the sliding window for each dimension of +// the input tensor. Must be: `[1, stride_height, stride_width, 1]`. +// rates: 1-D of length 4. The input stride for atrous morphological dilation. +// Must be: `[1, rate_height, rate_width, 1]`. +// padding: The type of padding algorithm to use. +// +// Returns 4-D with shape `[batch, in_height, in_width, depth]`. +func Dilation2DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, rates []int64, padding string) (in_backprop tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "rates": rates, "padding": padding} + opspec := tf.OpSpec{ + Type: "Dilation2DBackpropInput", + Input: []tf.Input{ + input, filter, out_backprop, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// AvgPool3DAttr is an optional argument to AvgPool3D. -type AvgPool3DAttr func(optionalAttr) +// Compute the polygamma function \\(\psi^{(n)}(x)\\). +// +// The polygamma function is defined as: +// +// +// \\(\psi^{(n)}(x) = \frac{d^n}{dx^n} \psi(x)\\) +// +// where \\(\psi(x)\\) is the digamma function. +func Polygamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Polygamma", + Input: []tf.Input{ + a, x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} -// AvgPool3DDataFormat sets the optional data_format attribute to value. +// Computes second-order gradients of the maxpooling function. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func AvgPool3DDataFormat(value string) AvgPool3DAttr { +// Arguments: +// input: The original input. +// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the +// input of `max_pool`. +// argmax: The indices of the maximum values chosen for each output of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. +// +// Returns Gradients of gradients w.r.t. the input of `max_pool`. +func MaxPoolGradGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + opspec := tf.OpSpec{ + Type: "MaxPoolGradGradWithArgmax", + Input: []tf.Input{ + input, grad, argmax, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MaxPoolGradGradV2Attr is an optional argument to MaxPoolGradGradV2. +type MaxPoolGradGradV2Attr func(optionalAttr) + +// MaxPoolGradGradV2DataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolGradGradV2DataFormat(value string) MaxPoolGradGradV2Attr { return func(m optionalAttr) { m["data_format"] = value } } -// Performs 3D average pooling on the input. +// Computes second-order gradients of the maxpooling function. // // Arguments: -// input: Shape `[batch, depth, rows, cols, channels]` tensor to pool over. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. // padding: The type of padding algorithm to use. // -// Returns The average pooled output tensor. -func AvgPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DAttr) (output tf.Output) { +// Returns Gradients of gradients w.r.t. the input to `max_pool`. +func MaxPoolGradGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize tf.Output, strides tf.Output, padding string, optional ...MaxPoolGradGradV2Attr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{"padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "AvgPool3D", + Type: "MaxPoolGradGradV2", Input: []tf.Input{ - input, + orig_input, orig_output, grad, ksize, strides, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes gradients of the maxpooling function. +// +// Arguments: +// input: The original input. +// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the +// output of `max_pool`. +// argmax: The indices of the maximum values chosen for each output of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. +// +// Returns Gradients w.r.t. the input of `max_pool`. +func MaxPoolGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + opspec := tf.OpSpec{ + Type: "MaxPoolGradWithArgmax", + Input: []tf.Input{ + input, grad, argmax, }, Attrs: attrs, } @@ -6097,6 +6406,51 @@ func AvgPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, pa return op.Output(0) } +// MutexV2Attr is an optional argument to MutexV2. +type MutexV2Attr func(optionalAttr) + +// MutexV2Container sets the optional container attribute to value. +// +// value: If non-empty, this variable is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func MutexV2Container(value string) MutexV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MutexV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this variable is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func MutexV2SharedName(value string) MutexV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Creates a Mutex resource that can be locked by `MutexLock`. +// +// Returns The mutex resource. +func MutexV2(scope *Scope, optional ...MutexV2Attr) (resource tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MutexV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Returns element-wise remainder of division. This emulates C semantics in that // // the result here is consistent with a truncating divide. E.g. @@ -6657,8 +7011,9 @@ type CropAndResizeAttr func(optionalAttr) // CropAndResizeMethod sets the optional method attribute to value. // -// value: A string specifying the interpolation method. Only 'bilinear' is -// supported for now. +// value: A string specifying the sampling method for resizing. It can be either +// `"bilinear"` or `"nearest"` and default to `"bilinear"`. Currently two sampling +// methods are supported: Bilinear and Nearest Neighbor. // If not specified, defaults to "bilinear" func CropAndResizeMethod(value string) CropAndResizeAttr { return func(m optionalAttr) { @@ -6676,19 +7031,23 @@ func CropAndResizeExtrapolationValue(value float32) CropAndResizeAttr { } } -// Extracts crops from the input image tensor and bilinearly resizes them (possibly +// Extracts crops from the input image tensor and resizes them. // -// with aspect ratio change) to a common output size specified by `crop_size`. This -// is more general than the `crop_to_bounding_box` op which extracts a fixed size -// slice from the input image and does not allow resizing or aspect ratio change. +// Extracts crops from the input image tensor and resizes them using bilinear +// sampling or nearest neighbor sampling (possibly with aspect ratio change) to a +// common output size specified by `crop_size`. This is more general than the +// `crop_to_bounding_box` op which extracts a fixed size slice from the input image +// and does not allow resizing or aspect ratio change. // // Returns a tensor with `crops` from the input `image` at positions defined at the // bounding box locations in `boxes`. The cropped boxes are all resized (with -// bilinear interpolation) to a fixed `size = [crop_height, crop_width]`. The -// result is a 4-D tensor `[num_boxes, crop_height, crop_width, depth]`. The -// resizing is corner aligned. In particular, if `boxes = [[0, 0, 1, 1]]`, the -// method will give identical results to using `tf.image.resize_bilinear()` -// with `align_corners=True`. +// bilinear or nearest neighbor interpolation) to a fixed +// `size = [crop_height, crop_width]`. The result is a 4-D tensor +// `[num_boxes, crop_height, crop_width, depth]`. The resizing is corner aligned. +// In particular, if `boxes = [[0, 0, 1, 1]]`, the method will give identical +// results to using `tf.image.resize_bilinear()` or +// `tf.image.resize_nearest_neighbor()`(depends on the `method` argument) with +// `align_corners=True`. // // Arguments: // image: A 4-D tensor of shape `[batch, image_height, image_width, depth]`. @@ -7071,88 +7430,97 @@ func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (ou return op.Output(0) } -// Transforms a Tensor into a serialized TensorProto proto. +// Computes the Bessel i1e function of `x` element-wise. // -// Arguments: -// tensor: A Tensor of type `T`. +// Exponentially scaled modified Bessel function of order 0 defined as +// `bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`. // -// Returns A serialized TensorProto proto of the input tensor. -func SerializeTensor(scope *Scope, tensor tf.Output) (serialized tf.Output) { +// This function is faster and numerically stabler than `bessel_i1(x)`. +func BesselI1e(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SerializeTensor", + Type: "BesselI1e", Input: []tf.Input{ - tensor, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes acos of x element-wise. -func Acos(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } +// Transforms a Tensor into a serialized TensorProto proto. +// +// Arguments: +// tensor: A Tensor of type `T`. +// +// Returns A serialized TensorProto proto of the input tensor. +func SerializeTensor(scope *Scope, tensor tf.Output) (serialized tf.Output) { + if scope.Err() != nil { + return + } opspec := tf.OpSpec{ - Type: "Acos", + Type: "SerializeTensor", Input: []tf.Input{ - x, + tensor, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// QuantizeAndDequantizeAttr is an optional argument to QuantizeAndDequantize. -type QuantizeAndDequantizeAttr func(optionalAttr) - -// QuantizeAndDequantizeSignedInput sets the optional signed_input attribute to value. -// If not specified, defaults to true -func QuantizeAndDequantizeSignedInput(value bool) QuantizeAndDequantizeAttr { - return func(m optionalAttr) { - m["signed_input"] = value +// Computes acos of x element-wise. +func Acos(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return } -} - -// QuantizeAndDequantizeNumBits sets the optional num_bits attribute to value. -// If not specified, defaults to 8 -func QuantizeAndDequantizeNumBits(value int64) QuantizeAndDequantizeAttr { - return func(m optionalAttr) { - m["num_bits"] = value + opspec := tf.OpSpec{ + Type: "Acos", + Input: []tf.Input{ + x, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// QuantizeAndDequantizeRangeGiven sets the optional range_given attribute to value. -// If not specified, defaults to false -func QuantizeAndDequantizeRangeGiven(value bool) QuantizeAndDequantizeAttr { - return func(m optionalAttr) { - m["range_given"] = value - } -} +// UnbatchGradAttr is an optional argument to UnbatchGrad. +type UnbatchGradAttr func(optionalAttr) -// QuantizeAndDequantizeInputMin sets the optional input_min attribute to value. -// If not specified, defaults to 0 -func QuantizeAndDequantizeInputMin(value float32) QuantizeAndDequantizeAttr { +// UnbatchGradContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func UnbatchGradContainer(value string) UnbatchGradAttr { return func(m optionalAttr) { - m["input_min"] = value + m["container"] = value } } -// QuantizeAndDequantizeInputMax sets the optional input_max attribute to value. -// If not specified, defaults to 0 -func QuantizeAndDequantizeInputMax(value float32) QuantizeAndDequantizeAttr { +// UnbatchGradSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func UnbatchGradSharedName(value string) UnbatchGradAttr { return func(m optionalAttr) { - m["input_max"] = value + m["shared_name"] = value } } -// Use QuantizeAndDequantizeV2 instead. +// Gradient of Unbatch. // -// DEPRECATED at GraphDef version 22: Replaced by QuantizeAndDequantizeV2 -func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAndDequantizeAttr) (output tf.Output) { +// Acts like Batch but using the given batch_index index of batching things as they +// become available. This ensures that the gradients are propagated back in the +// same session which did the forward pass. +// +// original_input: The input to the Unbatch operation this is the gradient of. +// batch_index: The batch_index given to the Unbatch operation this is the gradient +// of. +// grad: The downstream gradient. +// id: The id scalar emitted by Batch. +// batched_grad: The return value, either an empty tensor or the batched gradient. +// container: Container to control resource sharing. +// shared_name: Instances of UnbatchGrad with the same container and shared_name +// are assumed to possibly belong to the same batch. If left empty, the op name +// will be used as the shared name. +func UnbatchGrad(scope *Scope, original_input tf.Output, batch_index tf.Output, grad tf.Output, id tf.Output, optional ...UnbatchGradAttr) (batched_grad tf.Output) { if scope.Err() != nil { return } @@ -7161,9 +7529,9 @@ func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAn a(attrs) } opspec := tf.OpSpec{ - Type: "QuantizeAndDequantize", + Type: "UnbatchGrad", Input: []tf.Input{ - input, + original_input, batch_index, grad, id, }, Attrs: attrs, } @@ -7171,280 +7539,1244 @@ func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAn return op.Output(0) } -// Returns locations of nonzero / true values in a tensor. -// -// This operation returns the coordinates of true elements in `condition`. The -// coordinates are returned in a 2-D tensor where the first dimension (rows) -// represents the number of true elements, and the second dimension (columns) -// represents the coordinates of the true elements. Keep in mind, the shape of -// the output tensor can vary depending on how many true values there are in -// `condition`. Indices are output in row-major order. -// -// For example: -// -// ``` -// # 'input' tensor is [[True, False] -// # [True, False]] -// # 'input' has two true values, so output has two coordinates. -// # 'input' has rank of 2, so coordinates have two indices. -// where(input) ==> [[0, 0], -// [1, 0]] -// -// # `condition` tensor is [[[True, False] -// # [True, False]] -// # [[False, True] -// # [False, True]] -// # [[False, False] -// # [False, True]]] -// # 'input' has 5 true values, so output has 5 coordinates. -// # 'input' has rank of 3, so coordinates have three indices. -// where(input) ==> [[0, 0, 0], -// [0, 1, 0], -// [1, 0, 1], -// [1, 1, 1], -// [2, 1, 1]] -// -// # `condition` tensor is [[[1.5, 0.0] -// # [-0.5, 0.0]] -// # [[0.0, 0.25] -// # [0.0, 0.75]] -// # [[0.0, 0.0] -// # [0.0, 0.01]]] -// # 'input' has 5 nonzero values, so output has 5 coordinates. -// # 'input' has rank of 3, so coordinates have three indices. -// where(input) ==> [[0, 0, 0], -// [0, 1, 0], -// [1, 0, 1], -// [1, 1, 1], -// [2, 1, 1]] -// -// # `condition` tensor is [[[1.5 + 0.0j, 0.0 + 0.0j] -// # [0.0 + 0.5j, 0.0 + 0.0j]] -// # [[0.0 + 0.0j, 0.25 + 1.5j] -// # [0.0 + 0.0j, 0.75 + 0.0j]] -// # [[0.0 + 0.0j, 0.0 + 0.0j] -// # [0.0 + 0.0j, 0.01 + 0.0j]]] -// # 'input' has 5 nonzero magnitude values, so output has 5 coordinates. -// # 'input' has rank of 3, so coordinates have three indices. -// where(input) ==> [[0, 0, 0], -// [0, 1, 0], -// [1, 0, 1], -// [1, 1, 1], -// [2, 1, 1]] -// ``` -func Where(scope *Scope, condition tf.Output) (index tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Where", - Input: []tf.Input{ - condition, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// QueueDequeueV2Attr is an optional argument to QueueDequeueV2. -type QueueDequeueV2Attr func(optionalAttr) +// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad. +type AvgPool3DGradAttr func(optionalAttr) -// QueueDequeueV2TimeoutMs sets the optional timeout_ms attribute to value. +// AvgPool3DGradDataFormat sets the optional data_format attribute to value. // -// value: If the queue is empty, this operation will block for up to -// timeout_ms milliseconds. -// Note: This option is not supported yet. -// If not specified, defaults to -1 -func QueueDequeueV2TimeoutMs(value int64) QueueDequeueV2Attr { +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr { return func(m optionalAttr) { - m["timeout_ms"] = value + m["data_format"] = value } } -// Dequeues a tuple of one or more tensors from the given queue. -// -// This operation has k outputs, where k is the number of components -// in the tuples stored in the given queue, and output i is the ith -// component of the dequeued tuple. -// -// N.B. If the queue is empty, this operation will block until an element -// has been dequeued (or 'timeout_ms' elapses, if specified). +// Computes gradients of average pooling function. // // Arguments: -// handle: The handle to a queue. -// component_types: The type of each component in a tuple. +// orig_input_shape: The original input dimensions. +// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. // -// Returns One or more tensors that were dequeued as a tuple. -func QueueDequeueV2(scope *Scope, handle tf.Output, component_types []tf.DataType, optional ...QueueDequeueV2Attr) (components []tf.Output) { +// Returns The backprop for input. +func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"component_types": component_types} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QueueDequeueV2", + Type: "AvgPool3DGrad", Input: []tf.Input{ - handle, + orig_input_shape, grad, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("QueueDequeueV2", err) - return - } - return components -} - -// Computes the Gauss error function of `x` element-wise. -func Erf(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Erf", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns element-wise largest integer not greater than x. -func Floor(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Floor", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) return op.Output(0) } -// OneHotAttr is an optional argument to OneHot. -type OneHotAttr func(optionalAttr) +// ParseSingleSequenceExampleAttr is an optional argument to ParseSingleSequenceExample. +type ParseSingleSequenceExampleAttr func(optionalAttr) -// OneHotAxis sets the optional axis attribute to value. +// ParseSingleSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value. +// +// value: A list of Ncontext_sparse types; the data types of data in +// each context Feature given in context_sparse_keys. +// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleContextSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["context_sparse_types"] = value + } +} + +// ParseSingleSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["feature_list_dense_types"] = value + } +} + +// ParseSingleSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value. +// +// value: A list of Ncontext_dense shapes; the shapes of data in +// each context Feature given in context_dense_keys. +// The number of elements in the Feature corresponding to context_dense_key[j] +// must always equal context_dense_shapes[j].NumEntries(). +// The shape of context_dense_values[j] will match context_dense_shapes[j]. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleContextDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["context_dense_shapes"] = value + } +} + +// ParseSingleSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value. +// +// value: A list of Nfeature_list_sparse types; the data types +// of data in each FeatureList given in feature_list_sparse_keys. +// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["feature_list_sparse_types"] = value + } +} + +// ParseSingleSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value. +// +// value: A list of Nfeature_list_dense shapes; the shapes of +// data in each FeatureList given in feature_list_dense_keys. +// The shape of each Feature in the FeatureList corresponding to +// feature_list_dense_key[j] must always equal +// feature_list_dense_shapes[j].NumEntries(). +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func ParseSingleSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { + return func(m optionalAttr) { + m["feature_list_dense_shapes"] = value + } +} + +// Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors. +// +// Arguments: +// serialized: A scalar containing a binary serialized SequenceExample proto. +// feature_list_dense_missing_assumed_empty: A vector listing the +// FeatureList keys which may be missing from the SequenceExample. If the +// associated FeatureList is missing, it is treated as empty. By default, +// any FeatureList not listed in this vector must exist in the SequenceExample. +// context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars). +// The keys expected in the Examples' features associated with context_sparse +// values. +// context_dense_keys: A list of Ncontext_dense string Tensors (scalars). +// The keys expected in the SequenceExamples' context features associated with +// dense values. +// feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors +// (scalars). The keys expected in the FeatureLists associated with sparse +// values. +// feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars). +// The keys expected in the SequenceExamples' feature_lists associated +// with lists of dense values. +// context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty). +// context_dense_defaults[j] provides default values +// when the SequenceExample's context map lacks context_dense_key[j]. +// If an empty Tensor is provided for context_dense_defaults[j], +// then the Feature context_dense_keys[j] is required. +// The input type is inferred from context_dense_defaults[j], even when it's +// empty. If context_dense_defaults[j] is not empty, its shape must match +// context_dense_shapes[j]. +// debug_name: A scalar containing the name of the serialized proto. +// May contain, for example, table key (descriptive) name for the +// corresponding serialized proto. This is purely useful for debugging +// purposes, and the presence of values here has no effect on the output. +// May also be an empty scalar if no name is available. +func ParseSingleSequenceExample(scope *Scope, serialized tf.Output, feature_list_dense_missing_assumed_empty tf.Output, context_sparse_keys []tf.Output, context_dense_keys []tf.Output, feature_list_sparse_keys []tf.Output, feature_list_dense_keys []tf.Output, context_dense_defaults []tf.Output, debug_name tf.Output, optional ...ParseSingleSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ParseSingleSequenceExample", + Input: []tf.Input{ + serialized, feature_list_dense_missing_assumed_empty, tf.OutputList(context_sparse_keys), tf.OutputList(context_dense_keys), tf.OutputList(feature_list_sparse_keys), tf.OutputList(feature_list_dense_keys), tf.OutputList(context_dense_defaults), debug_name, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil { + scope.UpdateErr("ParseSingleSequenceExample", err) + return + } + return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values +} + +// QuantizeAndDequantizeAttr is an optional argument to QuantizeAndDequantize. +type QuantizeAndDequantizeAttr func(optionalAttr) + +// QuantizeAndDequantizeSignedInput sets the optional signed_input attribute to value. +// If not specified, defaults to true +func QuantizeAndDequantizeSignedInput(value bool) QuantizeAndDequantizeAttr { + return func(m optionalAttr) { + m["signed_input"] = value + } +} + +// QuantizeAndDequantizeNumBits sets the optional num_bits attribute to value. +// If not specified, defaults to 8 +func QuantizeAndDequantizeNumBits(value int64) QuantizeAndDequantizeAttr { + return func(m optionalAttr) { + m["num_bits"] = value + } +} + +// QuantizeAndDequantizeRangeGiven sets the optional range_given attribute to value. +// If not specified, defaults to false +func QuantizeAndDequantizeRangeGiven(value bool) QuantizeAndDequantizeAttr { + return func(m optionalAttr) { + m["range_given"] = value + } +} + +// QuantizeAndDequantizeInputMin sets the optional input_min attribute to value. +// If not specified, defaults to 0 +func QuantizeAndDequantizeInputMin(value float32) QuantizeAndDequantizeAttr { + return func(m optionalAttr) { + m["input_min"] = value + } +} + +// QuantizeAndDequantizeInputMax sets the optional input_max attribute to value. +// If not specified, defaults to 0 +func QuantizeAndDequantizeInputMax(value float32) QuantizeAndDequantizeAttr { + return func(m optionalAttr) { + m["input_max"] = value + } +} + +// Use QuantizeAndDequantizeV2 instead. +// +// DEPRECATED at GraphDef version 22: Replaced by QuantizeAndDequantizeV2 +func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAndDequantizeAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizeAndDequantize", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns locations of nonzero / true values in a tensor. +// +// This operation returns the coordinates of true elements in `condition`. The +// coordinates are returned in a 2-D tensor where the first dimension (rows) +// represents the number of true elements, and the second dimension (columns) +// represents the coordinates of the true elements. Keep in mind, the shape of +// the output tensor can vary depending on how many true values there are in +// `condition`. Indices are output in row-major order. +// +// For example: +// +// ``` +// # 'input' tensor is [[True, False] +// # [True, False]] +// # 'input' has two true values, so output has two coordinates. +// # 'input' has rank of 2, so coordinates have two indices. +// where(input) ==> [[0, 0], +// [1, 0]] +// +// # `condition` tensor is [[[True, False] +// # [True, False]] +// # [[False, True] +// # [False, True]] +// # [[False, False] +// # [False, True]]] +// # 'input' has 5 true values, so output has 5 coordinates. +// # 'input' has rank of 3, so coordinates have three indices. +// where(input) ==> [[0, 0, 0], +// [0, 1, 0], +// [1, 0, 1], +// [1, 1, 1], +// [2, 1, 1]] +// +// # `condition` tensor is [[[1.5, 0.0] +// # [-0.5, 0.0]] +// # [[0.0, 0.25] +// # [0.0, 0.75]] +// # [[0.0, 0.0] +// # [0.0, 0.01]]] +// # 'input' has 5 nonzero values, so output has 5 coordinates. +// # 'input' has rank of 3, so coordinates have three indices. +// where(input) ==> [[0, 0, 0], +// [0, 1, 0], +// [1, 0, 1], +// [1, 1, 1], +// [2, 1, 1]] +// +// # `condition` tensor is [[[1.5 + 0.0j, 0.0 + 0.0j] +// # [0.0 + 0.5j, 0.0 + 0.0j]] +// # [[0.0 + 0.0j, 0.25 + 1.5j] +// # [0.0 + 0.0j, 0.75 + 0.0j]] +// # [[0.0 + 0.0j, 0.0 + 0.0j] +// # [0.0 + 0.0j, 0.01 + 0.0j]]] +// # 'input' has 5 nonzero magnitude values, so output has 5 coordinates. +// # 'input' has rank of 3, so coordinates have three indices. +// where(input) ==> [[0, 0, 0], +// [0, 1, 0], +// [1, 0, 1], +// [1, 1, 1], +// [2, 1, 1]] +// ``` +func Where(scope *Scope, condition tf.Output) (index tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Where", + Input: []tf.Input{ + condition, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// QueueDequeueV2Attr is an optional argument to QueueDequeueV2. +type QueueDequeueV2Attr func(optionalAttr) + +// QueueDequeueV2TimeoutMs sets the optional timeout_ms attribute to value. +// +// value: If the queue is empty, this operation will block for up to +// timeout_ms milliseconds. +// Note: This option is not supported yet. +// If not specified, defaults to -1 +func QueueDequeueV2TimeoutMs(value int64) QueueDequeueV2Attr { + return func(m optionalAttr) { + m["timeout_ms"] = value + } +} + +// Dequeues a tuple of one or more tensors from the given queue. +// +// This operation has k outputs, where k is the number of components +// in the tuples stored in the given queue, and output i is the ith +// component of the dequeued tuple. +// +// N.B. If the queue is empty, this operation will block until an element +// has been dequeued (or 'timeout_ms' elapses, if specified). +// +// Arguments: +// handle: The handle to a queue. +// component_types: The type of each component in a tuple. +// +// Returns One or more tensors that were dequeued as a tuple. +func QueueDequeueV2(scope *Scope, handle tf.Output, component_types []tf.DataType, optional ...QueueDequeueV2Attr) (components []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"component_types": component_types} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QueueDequeueV2", + Input: []tf.Input{ + handle, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("QueueDequeueV2", err) + return + } + return components +} + +// Computes the Gauss error function of `x` element-wise. +func Erf(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Erf", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns element-wise largest integer not greater than x. +func Floor(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Floor", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// OneHotAttr is an optional argument to OneHot. +type OneHotAttr func(optionalAttr) + +// OneHotAxis sets the optional axis attribute to value. // // value: The axis to fill (default: -1, a new inner-most axis). // If not specified, defaults to -1 func OneHotAxis(value int64) OneHotAttr { return func(m optionalAttr) { - m["axis"] = value + m["axis"] = value + } +} + +// Returns a one-hot tensor. +// +// The locations represented by indices in `indices` take value `on_value`, +// while all other locations take value `off_value`. +// +// If the input `indices` is rank `N`, the output will have rank `N+1`, +// The new axis is created at dimension `axis` (default: the new axis is +// appended at the end). +// +// If `indices` is a scalar the output shape will be a vector of length `depth`. +// +// If `indices` is a vector of length `features`, the output shape will be: +// ``` +// features x depth if axis == -1 +// depth x features if axis == 0 +// ``` +// +// If `indices` is a matrix (batch) with shape `[batch, features]`, +// the output shape will be: +// ``` +// batch x features x depth if axis == -1 +// batch x depth x features if axis == 1 +// depth x batch x features if axis == 0 +// ``` +// +// +// Examples +// ========= +// +// Suppose that +// +// ``` +// indices = [0, 2, -1, 1] +// depth = 3 +// on_value = 5.0 +// off_value = 0.0 +// axis = -1 +// ``` +// +// Then output is `[4 x 3]`: +// +// ```output = +// [5.0 0.0 0.0] // one_hot(0) +// [0.0 0.0 5.0] // one_hot(2) +// [0.0 0.0 0.0] // one_hot(-1) +// [0.0 5.0 0.0] // one_hot(1) +// ``` +// +// Suppose that +// +// ``` +// indices = [0, 2, -1, 1] +// depth = 3 +// on_value = 0.0 +// off_value = 3.0 +// axis = 0 +// ``` +// +// Then output is `[3 x 4]`: +// +// ```output = +// [0.0 3.0 3.0 3.0] +// [3.0 3.0 3.0 0.0] +// [3.0 3.0 3.0 3.0] +// [3.0 0.0 3.0 3.0] +// // ^ one_hot(0) +// // ^ one_hot(2) +// // ^ one_hot(-1) +// // ^ one_hot(1) +// ``` +// Suppose that +// +// ``` +// indices = [[0, 2], [1, -1]] +// depth = 3 +// on_value = 1.0 +// off_value = 0.0 +// axis = -1 +// ``` +// +// Then output is `[2 x 2 x 3]`: +// +// ```output = +// [ +// [1.0, 0.0, 0.0] // one_hot(0) +// [0.0, 0.0, 1.0] // one_hot(2) +// ][ +// [0.0, 1.0, 0.0] // one_hot(1) +// [0.0, 0.0, 0.0] // one_hot(-1) +// ]``` +// +// Arguments: +// indices: A tensor of indices. +// depth: A scalar defining the depth of the one hot dimension. +// on_value: A scalar defining the value to fill in output when `indices[j] = i`. +// off_value: A scalar defining the value to fill in output when `indices[j] != i`. +// +// Returns The one-hot tensor. +func OneHot(scope *Scope, indices tf.Output, depth tf.Output, on_value tf.Output, off_value tf.Output, optional ...OneHotAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "OneHot", + Input: []tf.Input{ + indices, depth, on_value, off_value, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Transforms a vector of brain.Example protos (as strings) into typed tensors. +// +// Arguments: +// serialized: A vector containing a batch of binary serialized Example protos. +// names: A vector containing the names of the serialized protos. +// May contain, for example, table key (descriptive) names for the +// corresponding serialized protos. These are purely useful for debugging +// purposes, and the presence of values here has no effect on the output. +// May also be an empty vector if no names are available. +// If non-empty, this vector must be the same length as "serialized". +// sparse_keys: A list of Nsparse string Tensors (scalars). +// The keys expected in the Examples' features associated with sparse values. +// dense_keys: A list of Ndense string Tensors (scalars). +// The keys expected in the Examples' features associated with dense values. +// dense_defaults: A list of Ndense Tensors (some may be empty). +// dense_defaults[j] provides default values +// when the example's feature_map lacks dense_key[j]. If an empty Tensor is +// provided for dense_defaults[j], then the Feature dense_keys[j] is required. +// The input type is inferred from dense_defaults[j], even when it's empty. +// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, +// then the shape of dense_defaults[j] must match that of dense_shapes[j]. +// If dense_shapes[j] has an undefined major dimension (variable strides dense +// feature), dense_defaults[j] must contain a single element: +// the padding element. +// sparse_types: A list of Nsparse types; the data types of data in each Feature +// given in sparse_keys. +// Currently the ParseExample supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// dense_shapes: A list of Ndense shapes; the shapes of data in each Feature +// given in dense_keys. +// The number of elements in the Feature corresponding to dense_key[j] +// must always equal dense_shapes[j].NumEntries(). +// If dense_shapes[j] == (D0, D1, ..., DN) then the shape of output +// Tensor dense_values[j] will be (|serialized|, D0, D1, ..., DN): +// The dense outputs are just the inputs row-stacked by batch. +// This works for dense_shapes[j] = (-1, D1, ..., DN). In this case +// the shape of the output Tensor dense_values[j] will be +// (|serialized|, M, D1, .., DN), where M is the maximum number of blocks +// of elements of length D1 * .... * DN, across all minibatch entries +// in the input. Any minibatch entry with less than M blocks of elements of +// length D1 * ... * DN will be padded with the corresponding default_value +// scalar element along the second dimension. +func ParseExample(scope *Scope, serialized tf.Output, names tf.Output, sparse_keys []tf.Output, dense_keys []tf.Output, dense_defaults []tf.Output, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"sparse_types": sparse_types, "dense_shapes": dense_shapes} + opspec := tf.OpSpec{ + Type: "ParseExample", + Input: []tf.Input{ + serialized, names, tf.OutputList(sparse_keys), tf.OutputList(dense_keys), tf.OutputList(dense_defaults), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil { + scope.UpdateErr("ParseExample", err) + return + } + if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil { + scope.UpdateErr("ParseExample", err) + return + } + if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil { + scope.UpdateErr("ParseExample", err) + return + } + if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil { + scope.UpdateErr("ParseExample", err) + return + } + return sparse_indices, sparse_values, sparse_shapes, dense_values +} + +// DecodeRawAttr is an optional argument to DecodeRaw. +type DecodeRawAttr func(optionalAttr) + +// DecodeRawLittleEndian sets the optional little_endian attribute to value. +// +// value: Whether the input `bytes` are in little-endian order. +// Ignored for `out_type` values that are stored in a single byte like +// `uint8`. +// If not specified, defaults to true +func DecodeRawLittleEndian(value bool) DecodeRawAttr { + return func(m optionalAttr) { + m["little_endian"] = value + } +} + +// Reinterpret the bytes of a string as a vector of numbers. +// +// Arguments: +// bytes: All the elements must have the same length. +// +// +// Returns A Tensor with one more dimension than the input `bytes`. The +// added dimension will have size equal to the length of the elements +// of `bytes` divided by the number of bytes to represent `out_type`. +func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ...DecodeRawAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"out_type": out_type} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodeRaw", + Input: []tf.Input{ + bytes, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RandomPoissonAttr is an optional argument to RandomPoisson. +type RandomPoissonAttr func(optionalAttr) + +// RandomPoissonSeed sets the optional seed attribute to value. +// If not specified, defaults to 0 +func RandomPoissonSeed(value int64) RandomPoissonAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomPoissonSeed2 sets the optional seed2 attribute to value. +// If not specified, defaults to 0 +func RandomPoissonSeed2(value int64) RandomPoissonAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Use RandomPoissonV2 instead. +// +// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2 +func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RandomPoisson", + Input: []tf.Input{ + shape, rate, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns the element-wise sum of a list of tensors. +// +// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not +// wait for all of its inputs to be ready before beginning to sum. This can +// save memory if inputs are ready at different times, since minimum temporary +// storage is proportional to the output size rather than the inputs size. +// +// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. +// +// Returns a `Tensor` of same shape and type as the elements of `inputs`. +// +// Arguments: +// inputs: A list of `Tensor` objects, each with same shape and type. +// shape: Shape of elements of `inputs`. +func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"shape": shape} + opspec := tf.OpSpec{ + Type: "AccumulateNV2", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RandomShuffleAttr is an optional argument to RandomShuffle. +type RandomShuffleAttr func(optionalAttr) + +// RandomShuffleSeed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomShuffleSeed(value int64) RandomShuffleAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomShuffleSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomShuffleSeed2(value int64) RandomShuffleAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Randomly shuffles a tensor along its first dimension. +// +// The tensor is shuffled along dimension 0, such that each `value[j]` is mapped +// to one and only one `output[i]`. For example, a mapping that might occur for a +// 3x2 tensor is: +// +// ``` +// [[1, 2], [[5, 6], +// [3, 4], ==> [1, 2], +// [5, 6]] [3, 4]] +// ``` +// +// Arguments: +// value: The tensor to be shuffled. +// +// Returns A tensor of same shape and type as `value`, shuffled along its first +// dimension. +func RandomShuffle(scope *Scope, value tf.Output, optional ...RandomShuffleAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RandomShuffle", + Input: []tf.Input{ + value, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// OrderedMapIncompleteSizeAttr is an optional argument to OrderedMapIncompleteSize. +type OrderedMapIncompleteSizeAttr func(optionalAttr) + +// OrderedMapIncompleteSizeCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func OrderedMapIncompleteSizeCapacity(value int64) OrderedMapIncompleteSizeAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// OrderedMapIncompleteSizeMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func OrderedMapIncompleteSizeMemoryLimit(value int64) OrderedMapIncompleteSizeAttr { + return func(m optionalAttr) { + m["memory_limit"] = value + } +} + +// OrderedMapIncompleteSizeContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func OrderedMapIncompleteSizeContainer(value string) OrderedMapIncompleteSizeAttr { + return func(m optionalAttr) { + m["container"] = value } } -// Returns a one-hot tensor. +// OrderedMapIncompleteSizeSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func OrderedMapIncompleteSizeSharedName(value string) OrderedMapIncompleteSizeAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op returns the number of incomplete elements in the underlying container. +func OrderedMapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...OrderedMapIncompleteSizeAttr) (size tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "OrderedMapIncompleteSize", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter. +type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr) + +// DepthwiseConv2dNativeBackpropFilterDataFormat sets the optional data_format attribute to value. // -// The locations represented by indices in `indices` take value `on_value`, -// while all other locations take value `off_value`. +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, height, width, channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, channels, height, width]. +// If not specified, defaults to "NHWC" +func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2dNativeBackpropFilterAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// DepthwiseConv2dNativeBackpropFilterDilations sets the optional dilations attribute to value. // -// If the input `indices` is rank `N`, the output will have rank `N+1`, -// The new axis is created at dimension `axis` (default: the new axis is -// appended at the end). +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to +func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of depthwise convolution with respect to the filter. // -// If `indices` is a scalar the output shape will be a vector of length `depth`. +// Arguments: +// input: 4-D with shape based on `data_format`. For example, if +// `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height, +// in_width, in_channels]` tensor. +// filter_sizes: An integer vector representing the tensor shape of `filter`, +// where `filter` is a 4-D +// `[filter_height, filter_width, in_channels, depthwise_multiplier]` tensor. +// out_backprop: 4-D with shape based on `data_format`. +// For example, if `data_format` is 'NHWC' then +// out_backprop shape is `[batch, out_height, out_width, out_channels]`. +// Gradients w.r.t. the output of the convolution. +// strides: The stride of the sliding window for each dimension of the input +// of the convolution. +// padding: The type of padding algorithm to use. // -// If `indices` is a vector of length `features`, the output shape will be: -// ``` -// features x depth if axis == -1 -// depth x features if axis == 0 -// ``` +// Returns 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. +// the `filter` input of the convolution. +func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropFilterAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DepthwiseConv2dNativeBackpropFilter", + Input: []tf.Input{ + input, filter_sizes, out_backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns immutable tensor from memory region. // -// If `indices` is a matrix (batch) with shape `[batch, features]`, -// the output shape will be: -// ``` -// batch x features x depth if axis == -1 -// batch x depth x features if axis == 1 -// depth x batch x features if axis == 0 -// ``` +// The current implementation memmaps the tensor from a file. +// +// Arguments: +// dtype: Type of the returned tensor. +// shape: Shape of the returned tensor. +// memory_region_name: Name of readonly memory region used by the tensor, see +// NewReadOnlyMemoryRegionFromFile in tensorflow::Env. +func ImmutableConst(scope *Scope, dtype tf.DataType, shape tf.Shape, memory_region_name string) (tensor tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype, "shape": shape, "memory_region_name": memory_region_name} + opspec := tf.OpSpec{ + Type: "ImmutableConst", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StringJoinAttr is an optional argument to StringJoin. +type StringJoinAttr func(optionalAttr) + +// StringJoinSeparator sets the optional separator attribute to value. +// +// value: string, an optional join separator. +// If not specified, defaults to "" +func StringJoinSeparator(value string) StringJoinAttr { + return func(m optionalAttr) { + m["separator"] = value + } +} + +// Joins the strings in the given list of string tensors into one tensor; // +// with the given separator (default is an empty separator). // -// Examples -// ========= +// Arguments: +// inputs: A list of string tensors. The tensors must all have the same shape, +// or be scalars. Scalars may be mixed in; these will be broadcast to the shape +// of non-scalar inputs. +func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StringJoin", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl. +type ResourceApplyFtrlAttr func(optionalAttr) + +// ResourceApplyFtrlUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyFtrlUseLocking(value bool) ResourceApplyFtrlAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the Ftrl-proximal scheme. +// +// accum_new = accum + grad * grad +// linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +// accum = accum_new +// +// Arguments: +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// linear: Should be from a Variable(). +// grad: The gradient. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regulariation. Must be a scalar. +// l2: L2 regulariation. Must be a scalar. +// lr_power: Scaling factor. Must be a scalar. +// +// Returns the created operation. +func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyFtrl", + Input: []tf.Input{ + var_, accum, linear, grad, lr, l1, l2, lr_power, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// RandomUniformAttr is an optional argument to RandomUniform. +type RandomUniformAttr func(optionalAttr) + +// RandomUniformSeed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomUniformSeed(value int64) RandomUniformAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomUniformSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomUniformSeed2(value int64) RandomUniformAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Outputs random values from a uniform distribution. +// +// The generated values follow a uniform distribution in the range `[0, 1)`. The +// lower bound 0 is included in the range, while the upper bound 1 is excluded. // -// Suppose that +// Arguments: +// shape: The shape of the output tensor. +// dtype: The type of the output. // -// ``` -// indices = [0, 2, -1, 1] -// depth = 3 -// on_value = 5.0 -// off_value = 0.0 -// axis = -1 -// ``` +// Returns A tensor of the specified shape filled with uniform random values. +func RandomUniform(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomUniformAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RandomUniform", + Input: []tf.Input{ + shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Encode audio data using the WAV file format. // -// Then output is `[4 x 3]`: +// This operation will generate a string suitable to be saved out to create a .wav +// audio file. It will be encoded in the 16-bit PCM format. It takes in float +// values in the range -1.0f to 1.0f, and any outside that value will be clamped to +// that range. // -// ```output = -// [5.0 0.0 0.0] // one_hot(0) -// [0.0 0.0 5.0] // one_hot(2) -// [0.0 0.0 0.0] // one_hot(-1) -// [0.0 5.0 0.0] // one_hot(1) -// ``` +// `audio` is a 2-D float Tensor of shape `[length, channels]`. +// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100). // -// Suppose that +// Arguments: +// audio: 2-D with shape `[length, channels]`. +// sample_rate: Scalar containing the sample frequency. // -// ``` -// indices = [0, 2, -1, 1] -// depth = 3 -// on_value = 0.0 -// off_value = 3.0 -// axis = 0 -// ``` +// Returns 0-D. WAV-encoded file contents. +func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "EncodeWav", + Input: []tf.Input{ + audio, sample_rate, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes atan of x element-wise. +func Atan(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Atan", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyAdaMaxAttr is an optional argument to ResourceApplyAdaMax. +type ResourceApplyAdaMaxAttr func(optionalAttr) + +// ResourceApplyAdaMaxUseLocking sets the optional use_locking attribute to value. // -// Then output is `[3 x 4]`: +// value: If `True`, updating of the var, m, and v tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyAdaMaxUseLocking(value bool) ResourceApplyAdaMaxAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the AdaMax algorithm. // -// ```output = -// [0.0 3.0 3.0 3.0] -// [3.0 3.0 3.0 0.0] -// [3.0 3.0 3.0 3.0] -// [3.0 0.0 3.0 3.0] -// // ^ one_hot(0) -// // ^ one_hot(2) -// // ^ one_hot(-1) -// // ^ one_hot(1) -// ``` -// Suppose that +// m_t <- beta1 * m_{t-1} + (1 - beta1) * g +// v_t <- max(beta2 * v_{t-1}, abs(g)) +// variable <- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon) // -// ``` -// indices = [[0, 2], [1, -1]] -// depth = 3 -// on_value = 1.0 -// off_value = 0.0 -// axis = -1 -// ``` +// Arguments: +// var_: Should be from a Variable(). +// m: Should be from a Variable(). +// v: Should be from a Variable(). +// beta1_power: Must be a scalar. +// lr: Scaling factor. Must be a scalar. +// beta1: Momentum factor. Must be a scalar. +// beta2: Momentum factor. Must be a scalar. +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. // -// Then output is `[2 x 2 x 3]`: +// Returns the created operation. +func ResourceApplyAdaMax(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdaMaxAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyAdaMax", + Input: []tf.Input{ + var_, m, v, beta1_power, lr, beta1, beta2, epsilon, grad, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// AssertAttr is an optional argument to Assert. +type AssertAttr func(optionalAttr) + +// AssertSummarize sets the optional summarize attribute to value. // -// ```output = -// [ -// [1.0, 0.0, 0.0] // one_hot(0) -// [0.0, 0.0, 1.0] // one_hot(2) -// ][ -// [0.0, 1.0, 0.0] // one_hot(1) -// [0.0, 0.0, 0.0] // one_hot(-1) -// ]``` +// value: Print this many entries of each tensor. +// If not specified, defaults to 3 +func AssertSummarize(value int64) AssertAttr { + return func(m optionalAttr) { + m["summarize"] = value + } +} + +// Asserts that the given condition is true. +// +// If `condition` evaluates to false, print the list of tensors in `data`. +// `summarize` determines how many entries of the tensors to print. // // Arguments: -// indices: A tensor of indices. -// depth: A scalar defining the depth of the one hot dimension. -// on_value: A scalar defining the value to fill in output when `indices[j] = i`. -// off_value: A scalar defining the value to fill in output when `indices[j] != i`. +// condition: The condition to evaluate. +// data: The tensors to print out when condition is false. // -// Returns The one-hot tensor. -func OneHot(scope *Scope, indices tf.Output, depth tf.Output, on_value tf.Output, off_value tf.Output, optional ...OneHotAttr) (output tf.Output) { +// Returns the created operation. +func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...AssertAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -7453,67 +8785,96 @@ func OneHot(scope *Scope, indices tf.Output, depth tf.Output, on_value tf.Output a(attrs) } opspec := tf.OpSpec{ - Type: "OneHot", + Type: "Assert", Input: []tf.Input{ - indices, depth, on_value, off_value, + condition, tf.OutputList(data), }, Attrs: attrs, } + return scope.AddOperation(opspec) +} + +// Computes element-wise population count (a.k.a. popcount, bitsum, bitcount). +// +// For each entry in `x`, calculates the number of `1` (on) bits in the binary +// representation of that entry. +// +// **NOTE**: It is more efficient to first `tf.bitcast` your tensors into +// `int32` or `int64` and perform the bitcount on the result, than to feed in +// 8- or 16-bit inputs and then aggregate the resulting counts. +func PopulationCount(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "PopulationCount", + Input: []tf.Input{ + x, + }, + } op := scope.AddOperation(opspec) return op.Output(0) } -// Transforms a vector of brain.Example protos (as strings) into typed tensors. -// -// Arguments: -// serialized: A vector containing a batch of binary serialized Example protos. -// names: A vector containing the names of the serialized protos. -// May contain, for example, table key (descriptive) names for the -// corresponding serialized protos. These are purely useful for debugging -// purposes, and the presence of values here has no effect on the output. -// May also be an empty vector if no names are available. -// If non-empty, this vector must be the same length as "serialized". -// sparse_keys: A list of Nsparse string Tensors (scalars). -// The keys expected in the Examples' features associated with sparse values. -// dense_keys: A list of Ndense string Tensors (scalars). -// The keys expected in the Examples' features associated with dense values. -// dense_defaults: A list of Ndense Tensors (some may be empty). -// dense_defaults[j] provides default values -// when the example's feature_map lacks dense_key[j]. If an empty Tensor is -// provided for dense_defaults[j], then the Feature dense_keys[j] is required. -// The input type is inferred from dense_defaults[j], even when it's empty. -// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, -// then the shape of dense_defaults[j] must match that of dense_shapes[j]. -// If dense_shapes[j] has an undefined major dimension (variable strides dense -// feature), dense_defaults[j] must contain a single element: -// the padding element. -// sparse_types: A list of Nsparse types; the data types of data in each Feature -// given in sparse_keys. -// Currently the ParseExample supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// dense_shapes: A list of Ndense shapes; the shapes of data in each Feature -// given in dense_keys. -// The number of elements in the Feature corresponding to dense_key[j] -// must always equal dense_shapes[j].NumEntries(). -// If dense_shapes[j] == (D0, D1, ..., DN) then the shape of output -// Tensor dense_values[j] will be (|serialized|, D0, D1, ..., DN): -// The dense outputs are just the inputs row-stacked by batch. -// This works for dense_shapes[j] = (-1, D1, ..., DN). In this case -// the shape of the output Tensor dense_values[j] will be -// (|serialized|, M, D1, .., DN), where M is the maximum number of blocks -// of elements of length D1 * .... * DN, across all minibatch entries -// in the input. Any minibatch entry with less than M blocks of elements of -// length D1 * ... * DN will be padded with the corresponding default_value -// scalar element along the second dimension. -func ParseExample(scope *Scope, serialized tf.Output, names tf.Output, sparse_keys []tf.Output, dense_keys []tf.Output, dense_defaults []tf.Output, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { +// Broadcasts a tensor value to one or more other devices. +func CollectiveBcastSend(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape} + opspec := tf.OpSpec{ + Type: "CollectiveBcastSend", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Split a `SparseTensor` into `num_split` tensors along one dimension. +// +// If the `shape[split_dim]` is not an integer multiple of `num_split`. Slices +// `[0 : shape[split_dim] % num_split]` gets one extra dimension. +// For example, if `split_dim = 1` and `num_split = 2` and the input is +// +// input_tensor = shape = [2, 7] +// [ a d e ] +// [b c ] +// +// Graphically the output tensors are: +// +// output_tensor[0] = shape = [2, 4] +// [ a ] +// [b c ] +// +// output_tensor[1] = shape = [2, 3] +// [ d e ] +// [ ] +// +// Arguments: +// split_dim: 0-D. The dimension along which to split. Must be in the range +// `[0, rank(shape))`. +// indices: 2-D tensor represents the indices of the sparse tensor. +// values: 1-D tensor represents the values of the sparse tensor. +// shape: 1-D. tensor represents the shape of the sparse tensor. +// output indices: A list of 1-D tensors represents the indices of the output +// sparse tensors. +// num_split: The number of ways to split. +// +// Returns A list of 1-D tensors represents the values of the output sparse +// tensors.A list of 1-D tensors represents the shape of the output sparse +// tensors. +func SparseSplit(scope *Scope, split_dim tf.Output, indices tf.Output, values tf.Output, shape tf.Output, num_split int64) (output_indices []tf.Output, output_values []tf.Output, output_shape []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"sparse_types": sparse_types, "dense_shapes": dense_shapes} + attrs := map[string]interface{}{"num_split": num_split} opspec := tf.OpSpec{ - Type: "ParseExample", + Type: "SparseSplit", Input: []tf.Input{ - serialized, names, tf.OutputList(sparse_keys), tf.OutputList(dense_keys), tf.OutputList(dense_defaults), + split_dim, indices, values, shape, }, Attrs: attrs, } @@ -7523,91 +8884,110 @@ func ParseExample(scope *Scope, serialized tf.Output, names tf.Output, sparse_ke } var idx int var err error - if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil { - scope.UpdateErr("ParseExample", err) - return - } - if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil { - scope.UpdateErr("ParseExample", err) + if output_indices, idx, err = makeOutputList(op, idx, "output_indices"); err != nil { + scope.UpdateErr("SparseSplit", err) return } - if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil { - scope.UpdateErr("ParseExample", err) + if output_values, idx, err = makeOutputList(op, idx, "output_values"); err != nil { + scope.UpdateErr("SparseSplit", err) return } - if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil { - scope.UpdateErr("ParseExample", err) + if output_shape, idx, err = makeOutputList(op, idx, "output_shape"); err != nil { + scope.UpdateErr("SparseSplit", err) return } - return sparse_indices, sparse_values, sparse_shapes, dense_values + return output_indices, output_values, output_shape } -// DecodeRawAttr is an optional argument to DecodeRaw. -type DecodeRawAttr func(optionalAttr) +// ResourceSparseApplyFtrlV2Attr is an optional argument to ResourceSparseApplyFtrlV2. +type ResourceSparseApplyFtrlV2Attr func(optionalAttr) -// DecodeRawLittleEndian sets the optional little_endian attribute to value. +// ResourceSparseApplyFtrlV2UseLocking sets the optional use_locking attribute to value. // -// value: Whether the input `bytes` are in little-endian order. -// Ignored for `out_type` values that are stored in a single byte like -// `uint8`. -// If not specified, defaults to true -func DecodeRawLittleEndian(value bool) DecodeRawAttr { +// value: If `True`, updating of the var and accum tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceSparseApplyFtrlV2UseLocking(value bool) ResourceSparseApplyFtrlV2Attr { return func(m optionalAttr) { - m["little_endian"] = value + m["use_locking"] = value } } -// Reinterpret the bytes of a string as a vector of numbers. +// Update relevant entries in '*var' according to the Ftrl-proximal scheme. +// +// That is for rows we have grad for, we update var, accum and linear as follows: +// grad_with_shrinkage = grad + 2 * l2_shrinkage * var +// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage +// linear += grad_with_shrinkage + +// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 +// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 +// accum = accum_new // // Arguments: -// bytes: All the elements must have the same length. +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// linear: Should be from a Variable(). +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 shrinkage regulariation. Must be a scalar. // +// lr_power: Scaling factor. Must be a scalar. // -// Returns A Tensor with one more dimension than the input `bytes`. The -// added dimension will have size equal to the length of the elements -// of `bytes` divided by the number of bytes to represent `out_type`. -func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ...DecodeRawAttr) (output tf.Output) { +// Returns the created operation. +func ResourceSparseApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, l2_shrinkage tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"out_type": out_type} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "DecodeRaw", + Type: "ResourceSparseApplyFtrlV2", Input: []tf.Input{ - bytes, + var_, accum, linear, grad, indices, lr, l1, l2, l2_shrinkage, lr_power, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// RandomPoissonAttr is an optional argument to RandomPoisson. -type RandomPoissonAttr func(optionalAttr) +// DataFormatVecPermuteAttr is an optional argument to DataFormatVecPermute. +type DataFormatVecPermuteAttr func(optionalAttr) -// RandomPoissonSeed sets the optional seed attribute to value. -// If not specified, defaults to 0 -func RandomPoissonSeed(value int64) RandomPoissonAttr { +// DataFormatVecPermuteSrcFormat sets the optional src_format attribute to value. +// +// value: source data format. +// If not specified, defaults to "NHWC" +func DataFormatVecPermuteSrcFormat(value string) DataFormatVecPermuteAttr { return func(m optionalAttr) { - m["seed"] = value + m["src_format"] = value } } -// RandomPoissonSeed2 sets the optional seed2 attribute to value. -// If not specified, defaults to 0 -func RandomPoissonSeed2(value int64) RandomPoissonAttr { +// DataFormatVecPermuteDstFormat sets the optional dst_format attribute to value. +// +// value: destination data format. +// If not specified, defaults to "NCHW" +func DataFormatVecPermuteDstFormat(value string) DataFormatVecPermuteAttr { return func(m optionalAttr) { - m["seed2"] = value + m["dst_format"] = value } } -// Use RandomPoissonV2 instead. +// Returns the permuted vector/tensor in the destination data format given the // -// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2 -func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) { +// one in the source data format. +// +// Arguments: +// x: Vector of size 4 or Tensor of shape (4, 2) in source data format. +// +// Returns Vector of size 4 or Tensor of shape (4, 2) in destination data format. +func DataFormatVecPermute(scope *Scope, x tf.Output, optional ...DataFormatVecPermuteAttr) (y tf.Output) { if scope.Err() != nil { return } @@ -7616,9 +8996,9 @@ func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...Ra a(attrs) } opspec := tf.OpSpec{ - Type: "RandomPoisson", + Type: "DataFormatVecPermute", Input: []tf.Input{ - shape, rate, + x, }, Attrs: attrs, } @@ -7626,29 +9006,44 @@ func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...Ra return op.Output(0) } -// Returns the element-wise sum of a list of tensors. -// -// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not -// wait for all of its inputs to be ready before beginning to sum. This can -// save memory if inputs are ready at different times, since minimum temporary -// storage is proportional to the output size rather than the inputs size. +// Computes the gradient of `igamma(a, x)` wrt `a`. +func IgammaGradA(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "IgammaGradA", + Input: []tf.Input{ + a, x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Converts each string in the input Tensor to its hash mod by a number of buckets. // -// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. +// The hash function is deterministic on the content of the string within the +// process. // -// Returns a `Tensor` of same shape and type as the elements of `inputs`. +// Note that the hash function may change from time to time. +// This functionality will be deprecated and it's recommended to use +// `tf.string_to_hash_bucket_fast()` or `tf.string_to_hash_bucket_strong()`. // // Arguments: -// inputs: A list of `Tensor` objects, each with same shape and type. -// shape: Shape of elements of `inputs`. -func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) { +// +// num_buckets: The number of buckets. +// +// Returns A Tensor of the same shape as the input `string_tensor`. +func StringToHashBucket(scope *Scope, string_tensor tf.Output, num_buckets int64) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"shape": shape} + attrs := map[string]interface{}{"num_buckets": num_buckets} opspec := tf.OpSpec{ - Type: "AccumulateNV2", + Type: "StringToHashBucket", Input: []tf.Input{ - tf.OutputList(inputs), + string_tensor, }, Attrs: attrs, } @@ -7656,69 +9051,46 @@ func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Out return op.Output(0) } -// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter. -type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr) - -// DepthwiseConv2dNativeBackpropFilterDataFormat sets the optional data_format attribute to value. +// Computes gradients for the exponential linear (Elu) operation. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, height, width, channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, channels, height, width]. -// If not specified, defaults to "NHWC" -func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2dNativeBackpropFilterAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// DepthwiseConv2dNativeBackpropFilterDilations sets the optional dilations attribute to value. +// Arguments: +// gradients: The backpropagated gradients to the corresponding Elu operation. +// outputs: The outputs of the corresponding Elu operation. // -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each filter -// element on that dimension. The dimension order is determined by the value of -// `data_format`, see above for details. Dilations in the batch and depth -// dimensions must be 1. -// If not specified, defaults to -func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr { - return func(m optionalAttr) { - m["dilations"] = value +// Returns The gradients: `gradients * (outputs + 1)` if outputs < 0, +// `gradients` otherwise. +func EluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { + if scope.Err() != nil { + return } -} - -// Computes the gradients of depthwise convolution with respect to the filter. -// -// Arguments: -// input: 4-D with shape based on `data_format`. For example, if -// `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height, -// in_width, in_channels]` tensor. -// filter_sizes: An integer vector representing the tensor shape of `filter`, -// where `filter` is a 4-D -// `[filter_height, filter_width, in_channels, depthwise_multiplier]` tensor. -// out_backprop: 4-D with shape based on `data_format`. -// For example, if `data_format` is 'NHWC' then -// out_backprop shape is `[batch, out_height, out_width, out_channels]`. -// Gradients w.r.t. the output of the convolution. -// strides: The stride of the sliding window for each dimension of the input -// of the convolution. -// padding: The type of padding algorithm to use. + opspec := tf.OpSpec{ + Type: "EluGrad", + Input: []tf.Input{ + gradients, outputs, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that contains `count` elements from the `input_dataset`. // -// Returns 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. -// the `filter` input of the convolution. -func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropFilterAttr) (output tf.Output) { +// Arguments: +// +// count: A scalar representing the number of elements from the `input_dataset` +// that should be taken. A value of `-1` indicates that all of `input_dataset` +// is taken. +// +// +func TakeDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "DepthwiseConv2dNativeBackpropFilter", + Type: "TakeDataset", Input: []tf.Input{ - input, filter_sizes, out_backprop, + input_dataset, count, }, Attrs: attrs, } @@ -7726,219 +9098,207 @@ func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_s return op.Output(0) } -// Returns immutable tensor from memory region. +// Reads the value of a variable. // -// The current implementation memmaps the tensor from a file. +// The tensor returned by this operation is immutable. +// +// The value returned by this operation is guaranteed to be influenced by all the +// writes on which this operation depends directly or indirectly, and to not be +// influenced by any of the writes which depend directly or indirectly on this +// operation. // // Arguments: -// dtype: Type of the returned tensor. -// shape: Shape of the returned tensor. -// memory_region_name: Name of readonly memory region used by the tensor, see -// NewReadOnlyMemoryRegionFromFile in tensorflow::Env. -func ImmutableConst(scope *Scope, dtype tf.DataType, shape tf.Shape, memory_region_name string) (tensor tf.Output) { +// resource: handle to the resource in which to store the variable. +// dtype: the dtype of the value. +func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype, "shape": shape, "memory_region_name": memory_region_name} + attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "ImmutableConst", - + Type: "ReadVariableOp", + Input: []tf.Input{ + resource, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// StringJoinAttr is an optional argument to StringJoin. -type StringJoinAttr func(optionalAttr) - -// StringJoinSeparator sets the optional separator attribute to value. -// -// value: string, an optional join separator. -// If not specified, defaults to "" -func StringJoinSeparator(value string) StringJoinAttr { - return func(m optionalAttr) { - m["separator"] = value +// Computes tan of x element-wise. +func Tan(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Tan", + Input: []tf.Input{ + x, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Joins the strings in the given list of string tensors into one tensor; +// Updates the tree ensemble by either adding a layer to the last tree being grown // -// with the given separator (default is an empty separator). +// or by starting a new tree. // // Arguments: -// inputs: A list of string tensors. The tensors must all have the same shape, -// or be scalars. Scalars may be mixed in; these will be broadcast to the shape -// of non-scalar inputs. -func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (output tf.Output) { +// tree_ensemble_handle: Handle to the ensemble variable. +// feature_ids: Rank 1 tensor with ids for each feature. This is the real id of +// the feature that will be used in the split. +// node_ids: List of rank 1 tensors representing the nodes for which this feature +// has a split. +// gains: List of rank 1 tensors representing the gains for each of the feature's +// split. +// thresholds: List of rank 1 tensors representing the thesholds for each of the +// feature's split. +// left_node_contribs: List of rank 2 tensors with left leaf contribs for each of +// the feature's splits. Will be added to the previous node values to constitute +// the values of the left nodes. +// right_node_contribs: List of rank 2 tensors with right leaf contribs for each +// of the feature's splits. Will be added to the previous node values to constitute +// the values of the right nodes. +// max_depth: Max depth of the tree to build. +// learning_rate: shrinkage const for each new tree. +// pruning_mode: 0-No pruning, 1-Pre-pruning, 2-Post-pruning. +// +// Returns the created operation. +func BoostedTreesUpdateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, feature_ids tf.Output, node_ids []tf.Output, gains []tf.Output, thresholds []tf.Output, left_node_contribs []tf.Output, right_node_contribs []tf.Output, max_depth tf.Output, learning_rate tf.Output, pruning_mode int64) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"pruning_mode": pruning_mode} opspec := tf.OpSpec{ - Type: "StringJoin", + Type: "BoostedTreesUpdateEnsemble", Input: []tf.Input{ - tf.OutputList(inputs), + tree_ensemble_handle, feature_ids, tf.OutputList(node_ids), tf.OutputList(gains), tf.OutputList(thresholds), tf.OutputList(left_node_contribs), tf.OutputList(right_node_contribs), max_depth, learning_rate, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// LRNGradAttr is an optional argument to LRNGrad. -type LRNGradAttr func(optionalAttr) +// EncodeJpegAttr is an optional argument to EncodeJpeg. +type EncodeJpegAttr func(optionalAttr) -// LRNGradDepthRadius sets the optional depth_radius attribute to value. +// EncodeJpegFormat sets the optional format attribute to value. // -// value: A depth radius. -// If not specified, defaults to 5 -func LRNGradDepthRadius(value int64) LRNGradAttr { +// value: Per pixel image format. +// If not specified, defaults to "" +func EncodeJpegFormat(value string) EncodeJpegAttr { return func(m optionalAttr) { - m["depth_radius"] = value + m["format"] = value } } -// LRNGradBias sets the optional bias attribute to value. +// EncodeJpegQuality sets the optional quality attribute to value. // -// value: An offset (usually > 0 to avoid dividing by 0). -// If not specified, defaults to 1 -func LRNGradBias(value float32) LRNGradAttr { +// value: Quality of the compression from 0 to 100 (higher is better and slower). +// If not specified, defaults to 95 +func EncodeJpegQuality(value int64) EncodeJpegAttr { return func(m optionalAttr) { - m["bias"] = value + m["quality"] = value } } -// LRNGradAlpha sets the optional alpha attribute to value. +// EncodeJpegProgressive sets the optional progressive attribute to value. // -// value: A scale factor, usually positive. -// If not specified, defaults to 1 -func LRNGradAlpha(value float32) LRNGradAttr { +// value: If True, create a JPEG that loads progressively (coarse to fine). +// If not specified, defaults to false +func EncodeJpegProgressive(value bool) EncodeJpegAttr { return func(m optionalAttr) { - m["alpha"] = value + m["progressive"] = value } } -// LRNGradBeta sets the optional beta attribute to value. +// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value. // -// value: An exponent. -// If not specified, defaults to 0.5 -func LRNGradBeta(value float32) LRNGradAttr { +// value: If True, spend CPU/RAM to reduce size with no quality change. +// If not specified, defaults to false +func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr { return func(m optionalAttr) { - m["beta"] = value + m["optimize_size"] = value } } -// Gradients for Local Response Normalization. -// -// Arguments: -// input_grads: 4-D with shape `[batch, height, width, channels]`. -// input_image: 4-D with shape `[batch, height, width, channels]`. -// output_image: 4-D with shape `[batch, height, width, channels]`. +// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value. // -// Returns The gradients for LRN. -func LRNGrad(scope *Scope, input_grads tf.Output, input_image tf.Output, output_image tf.Output, optional ...LRNGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "LRNGrad", - Input: []tf.Input{ - input_grads, input_image, output_image, - }, - Attrs: attrs, +// value: See http://en.wikipedia.org/wiki/Chroma_subsampling. +// If not specified, defaults to true +func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr { + return func(m optionalAttr) { + m["chroma_downsampling"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// AnyAttr is an optional argument to Any. -type AnyAttr func(optionalAttr) - -// AnyKeepDims sets the optional keep_dims attribute to value. +// EncodeJpegDensityUnit sets the optional density_unit attribute to value. // -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func AnyKeepDims(value bool) AnyAttr { +// value: Unit used to specify `x_density` and `y_density`: +// pixels per inch (`'in'`) or centimeter (`'cm'`). +// If not specified, defaults to "in" +func EncodeJpegDensityUnit(value string) EncodeJpegAttr { return func(m optionalAttr) { - m["keep_dims"] = value + m["density_unit"] = value } } -// Computes the "logical or" of elements across dimensions of a tensor. -// -// Reduces `input` along the dimensions given in `axis`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `axis`. If `keep_dims` is true, the reduced dimensions are -// retained with length 1. -// -// Arguments: -// input: The tensor to reduce. -// axis: The dimensions to reduce. Must be in the range -// `[-rank(input), rank(input))`. +// EncodeJpegXDensity sets the optional x_density attribute to value. // -// Returns The reduced tensor. -func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Any", - Input: []tf.Input{ - input, axis, - }, - Attrs: attrs, +// value: Horizontal pixels per density unit. +// If not specified, defaults to 300 +func EncodeJpegXDensity(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["x_density"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl. -type ResourceApplyFtrlAttr func(optionalAttr) +// EncodeJpegYDensity sets the optional y_density attribute to value. +// +// value: Vertical pixels per density unit. +// If not specified, defaults to 300 +func EncodeJpegYDensity(value int64) EncodeJpegAttr { + return func(m optionalAttr) { + m["y_density"] = value + } +} -// ResourceApplyFtrlUseLocking sets the optional use_locking attribute to value. +// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value. // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyFtrlUseLocking(value bool) ResourceApplyFtrlAttr { +// value: If not empty, embed this XMP metadata in the image header. +// If not specified, defaults to "" +func EncodeJpegXmpMetadata(value string) EncodeJpegAttr { return func(m optionalAttr) { - m["use_locking"] = value + m["xmp_metadata"] = value } } -// Update '*var' according to the Ftrl-proximal scheme. +// JPEG-encode an image. // -// accum_new = accum + grad * grad -// linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var -// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 -// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 -// accum = accum_new +// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`. +// +// The attr `format` can be used to override the color format of the encoded +// output. Values can be: +// +// * `''`: Use a default format based on the number of channels in the image. +// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension +// of `image` must be 1. +// * `rgb`: Output an RGB JPEG image. The `channels` dimension +// of `image` must be 3. +// +// If `format` is not specified or is the empty string, a default format is picked +// in function of the number of channels in `image`: +// +// * 1: Output a grayscale image. +// * 3: Output an RGB image. // // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// linear: Should be from a Variable(). -// grad: The gradient. -// lr: Scaling factor. Must be a scalar. -// l1: L1 regulariation. Must be a scalar. -// l2: L2 regulariation. Must be a scalar. -// lr_power: Scaling factor. Must be a scalar. +// image: 3-D with shape `[height, width, channels]`. // -// Returns the created operation. -func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlAttr) (o *tf.Operation) { +// Returns 0-D. JPEG-encoded image. +func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) { if scope.Err() != nil { return } @@ -7947,62 +9307,69 @@ func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf. a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceApplyFtrl", + Type: "EncodeJpeg", Input: []tf.Input{ - var_, accum, linear, grad, lr, l1, l2, lr_power, + image, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// RandomUniformAttr is an optional argument to RandomUniform. -type RandomUniformAttr func(optionalAttr) +// MultinomialAttr is an optional argument to Multinomial. +type MultinomialAttr func(optionalAttr) -// RandomUniformSeed sets the optional seed attribute to value. +// MultinomialSeed sets the optional seed attribute to value. // -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. +// value: If either seed or seed2 is set to be non-zero, the internal random number +// generator is seeded by the given seed. Otherwise, a random seed is used. // If not specified, defaults to 0 -func RandomUniformSeed(value int64) RandomUniformAttr { +func MultinomialSeed(value int64) MultinomialAttr { return func(m optionalAttr) { m["seed"] = value } } -// RandomUniformSeed2 sets the optional seed2 attribute to value. +// MultinomialSeed2 sets the optional seed2 attribute to value. // // value: A second seed to avoid seed collision. // If not specified, defaults to 0 -func RandomUniformSeed2(value int64) RandomUniformAttr { +func MultinomialSeed2(value int64) MultinomialAttr { return func(m optionalAttr) { m["seed2"] = value } } -// Outputs random values from a uniform distribution. -// -// The generated values follow a uniform distribution in the range `[0, 1)`. The -// lower bound 0 is included in the range, while the upper bound 1 is excluded. +// MultinomialOutputDtype sets the optional output_dtype attribute to value. +// If not specified, defaults to DT_INT64 +func MultinomialOutputDtype(value tf.DataType) MultinomialAttr { + return func(m optionalAttr) { + m["output_dtype"] = value + } +} + +// Draws samples from a multinomial distribution. // // Arguments: -// shape: The shape of the output tensor. -// dtype: The type of the output. +// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` +// represents the unnormalized log probabilities for all classes. +// num_samples: 0-D. Number of independent samples to draw for each row slice. // -// Returns A tensor of the specified shape filled with uniform random values. -func RandomUniform(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomUniformAttr) (output tf.Output) { +// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]` +// contains the drawn class labels with range `[0, num_classes)`. +func Multinomial(scope *Scope, logits tf.Output, num_samples tf.Output, optional ...MultinomialAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RandomUniform", + Type: "Multinomial", Input: []tf.Input{ - shape, + logits, num_samples, }, Attrs: attrs, } @@ -8010,30 +9377,35 @@ func RandomUniform(scope *Scope, shape tf.Output, dtype tf.DataType, optional .. return op.Output(0) } -// AssertAttr is an optional argument to Assert. -type AssertAttr func(optionalAttr) +// ResourceSparseApplyAdagradDAAttr is an optional argument to ResourceSparseApplyAdagradDA. +type ResourceSparseApplyAdagradDAAttr func(optionalAttr) -// AssertSummarize sets the optional summarize attribute to value. +// ResourceSparseApplyAdagradDAUseLocking sets the optional use_locking attribute to value. // -// value: Print this many entries of each tensor. -// If not specified, defaults to 3 -func AssertSummarize(value int64) AssertAttr { +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceSparseApplyAdagradDAUseLocking(value bool) ResourceSparseApplyAdagradDAAttr { return func(m optionalAttr) { - m["summarize"] = value + m["use_locking"] = value } } -// Asserts that the given condition is true. -// -// If `condition` evaluates to false, print the list of tensors in `data`. -// `summarize` determines how many entries of the tensors to print. +// Update entries in '*var' and '*accum' according to the proximal adagrad scheme. // // Arguments: -// condition: The condition to evaluate. -// data: The tensors to print out when condition is false. +// var_: Should be from a Variable(). +// gradient_accumulator: Should be from a Variable(). +// gradient_squared_accumulator: Should be from a Variable(). +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// lr: Learning rate. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// global_step: Training step number. Must be a scalar. // // Returns the created operation. -func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...AssertAttr) (o *tf.Operation) { +func ResourceSparseApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceSparseApplyAdagradDAAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -8042,151 +9414,25 @@ func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...Ass a(attrs) } opspec := tf.OpSpec{ - Type: "Assert", + Type: "ResourceSparseApplyAdagradDA", Input: []tf.Input{ - condition, tf.OutputList(data), + var_, gradient_accumulator, gradient_squared_accumulator, grad, indices, lr, l1, l2, global_step, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// Computes element-wise population count (a.k.a. popcount, bitsum, bitcount). -// -// For each entry in `x`, calculates the number of `1` (on) bits in the binary -// representation of that entry. -// -// **NOTE**: It is more efficient to first `tf.bitcast` your tensors into -// `int32` or `int64` and perform the bitcount on the result, than to feed in -// 8- or 16-bit inputs and then aggregate the resulting counts. -func PopulationCount(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "PopulationCount", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Broadcasts a tensor value to one or more other devices. -func CollectiveBcastSend(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape} - opspec := tf.OpSpec{ - Type: "CollectiveBcastSend", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Makes a copy of `x`. -// -// Arguments: -// x: The source tensor of type `T`. -// -// Returns y: A `Tensor` of type `T`. A copy of `x`. Guaranteed that `y` -// is not an alias of `x`. -func DeepCopy(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DeepCopy", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Split a `SparseTensor` into `num_split` tensors along one dimension. -// -// If the `shape[split_dim]` is not an integer multiple of `num_split`. Slices -// `[0 : shape[split_dim] % num_split]` gets one extra dimension. -// For example, if `split_dim = 1` and `num_split = 2` and the input is -// -// input_tensor = shape = [2, 7] -// [ a d e ] -// [b c ] -// -// Graphically the output tensors are: -// -// output_tensor[0] = shape = [2, 4] -// [ a ] -// [b c ] -// -// output_tensor[1] = shape = [2, 3] -// [ d e ] -// [ ] -// -// Arguments: -// split_dim: 0-D. The dimension along which to split. Must be in the range -// `[0, rank(shape))`. -// indices: 2-D tensor represents the indices of the sparse tensor. -// values: 1-D tensor represents the values of the sparse tensor. -// shape: 1-D. tensor represents the shape of the sparse tensor. -// output indices: A list of 1-D tensors represents the indices of the output -// sparse tensors. -// num_split: The number of ways to split. -// -// Returns A list of 1-D tensors represents the values of the output sparse -// tensors.A list of 1-D tensors represents the shape of the output sparse -// tensors. -func SparseSplit(scope *Scope, split_dim tf.Output, indices tf.Output, values tf.Output, shape tf.Output, num_split int64) (output_indices []tf.Output, output_values []tf.Output, output_shape []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_split": num_split} - opspec := tf.OpSpec{ - Type: "SparseSplit", - Input: []tf.Input{ - split_dim, indices, values, shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output_indices, idx, err = makeOutputList(op, idx, "output_indices"); err != nil { - scope.UpdateErr("SparseSplit", err) - return - } - if output_values, idx, err = makeOutputList(op, idx, "output_values"); err != nil { - scope.UpdateErr("SparseSplit", err) - return - } - if output_shape, idx, err = makeOutputList(op, idx, "output_shape"); err != nil { - scope.UpdateErr("SparseSplit", err) - return - } - return output_indices, output_values, output_shape -} - -// ResourceSparseApplyFtrlV2Attr is an optional argument to ResourceSparseApplyFtrlV2. -type ResourceSparseApplyFtrlV2Attr func(optionalAttr) +// ResourceSparseApplyFtrlAttr is an optional argument to ResourceSparseApplyFtrl. +type ResourceSparseApplyFtrlAttr func(optionalAttr) -// ResourceSparseApplyFtrlV2UseLocking sets the optional use_locking attribute to value. +// ResourceSparseApplyFtrlUseLocking sets the optional use_locking attribute to value. // // value: If `True`, updating of the var and accum tensors will be protected // by a lock; otherwise the behavior is undefined, but may exhibit less // contention. // If not specified, defaults to false -func ResourceSparseApplyFtrlV2UseLocking(value bool) ResourceSparseApplyFtrlV2Attr { +func ResourceSparseApplyFtrlUseLocking(value bool) ResourceSparseApplyFtrlAttr { return func(m optionalAttr) { m["use_locking"] = value } @@ -8195,10 +9441,8 @@ func ResourceSparseApplyFtrlV2UseLocking(value bool) ResourceSparseApplyFtrlV2At // Update relevant entries in '*var' according to the Ftrl-proximal scheme. // // That is for rows we have grad for, we update var, accum and linear as follows: -// grad_with_shrinkage = grad + 2 * l2_shrinkage * var -// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage -// linear += grad_with_shrinkage + -// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var +// accum_new = accum + grad * grad +// linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var // quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 // var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 // accum = accum_new @@ -8211,12 +9455,11 @@ func ResourceSparseApplyFtrlV2UseLocking(value bool) ResourceSparseApplyFtrlV2At // indices: A vector of indices into the first dimension of var and accum. // lr: Scaling factor. Must be a scalar. // l1: L1 regularization. Must be a scalar. -// l2: L2 shrinkage regulariation. Must be a scalar. -// +// l2: L2 regularization. Must be a scalar. // lr_power: Scaling factor. Must be a scalar. // // Returns the created operation. -func ResourceSparseApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, l2_shrinkage tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlV2Attr) (o *tf.Operation) { +func ResourceSparseApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -8225,290 +9468,335 @@ func ResourceSparseApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, li a(attrs) } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyFtrlV2", + Type: "ResourceSparseApplyFtrl", Input: []tf.Input{ - var_, accum, linear, grad, indices, lr, l1, l2, l2_shrinkage, lr_power, + var_, accum, linear, grad, indices, lr, l1, l2, lr_power, }, Attrs: attrs, } return scope.AddOperation(opspec) } -// DataFormatVecPermuteAttr is an optional argument to DataFormatVecPermute. -type DataFormatVecPermuteAttr func(optionalAttr) - -// DataFormatVecPermuteSrcFormat sets the optional src_format attribute to value. +// Returns which elements of x are Inf. // -// value: source data format. -// If not specified, defaults to "NHWC" -func DataFormatVecPermuteSrcFormat(value string) DataFormatVecPermuteAttr { - return func(m optionalAttr) { - m["src_format"] = value +// @compatibility(numpy) +// Equivalent to np.isinf +// @end_compatibility +func IsInf(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return } -} - -// DataFormatVecPermuteDstFormat sets the optional dst_format attribute to value. -// -// value: destination data format. -// If not specified, defaults to "NCHW" -func DataFormatVecPermuteDstFormat(value string) DataFormatVecPermuteAttr { - return func(m optionalAttr) { - m["dst_format"] = value + opspec := tf.OpSpec{ + Type: "IsInf", + Input: []tf.Input{ + x, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Returns the permuted vector/tensor in the destination data format given the +// Computes the sum along sparse segments of a tensor divided by the sqrt of N. // -// one in the source data format. +// N is the size of the segment being reduced. +// +// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of +// segments. // // Arguments: -// x: Vector of size 4 or Tensor of shape (4, 2) in source data format. // -// Returns Vector of size 4 or Tensor of shape (4, 2) in destination data format. -func DataFormatVecPermute(scope *Scope, x tf.Output, optional ...DataFormatVecPermuteAttr) (y tf.Output) { +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SparseSegmentSqrtN(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "DataFormatVecPermute", + Type: "SparseSegmentSqrtN", Input: []tf.Input{ - x, + data, indices, segment_ids, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Reads the value of a variable. -// -// The tensor returned by this operation is immutable. +// Adds up a `SparseTensor` and a dense `Tensor`, producing a dense `Tensor`. // -// The value returned by this operation is guaranteed to be influenced by all the -// writes on which this operation depends directly or indirectly, and to not be -// influenced by any of the writes which depend directly or indirectly on this -// operation. +// This Op does not require `a_indices` be sorted in standard lexicographic order. // // Arguments: -// resource: handle to the resource in which to store the variable. -// dtype: the dtype of the value. -func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value tf.Output) { +// a_indices: 2-D. The `indices` of the `SparseTensor`, with shape `[nnz, ndims]`. +// a_values: 1-D. The `values` of the `SparseTensor`, with shape `[nnz]`. +// a_shape: 1-D. The `shape` of the `SparseTensor`, with shape `[ndims]`. +// b: `ndims`-D Tensor. With shape `a_shape`. +func SparseTensorDenseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} opspec := tf.OpSpec{ - Type: "ReadVariableOp", + Type: "SparseTensorDenseAdd", Input: []tf.Input{ - resource, + a_indices, a_values, a_shape, b, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes tan of x element-wise. -func Tan(scope *Scope, x tf.Output) (y tf.Output) { +// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal. +type StatelessTruncatedNormalAttr func(optionalAttr) + +// StatelessTruncatedNormalDtype sets the optional dtype attribute to value. +// +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Outputs deterministic pseudorandom values from a truncated normal distribution. +// +// The generated values follow a normal distribution with mean 0 and standard +// deviation 1, except that values whose magnitude is more than 2 standard +// deviations from the mean are dropped and re-picked. +// +// The outputs are a deterministic function of `shape` and `seed`. +// +// Arguments: +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). +// +// Returns Random values with specified shape. +func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessTruncatedNormalAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Tan", + Type: "StatelessTruncatedNormal", Input: []tf.Input{ - x, + shape, seed, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Updates the tree ensemble by either adding a layer to the last tree being grown +// RestoreSliceAttr is an optional argument to RestoreSlice. +type RestoreSliceAttr func(optionalAttr) + +// RestoreSlicePreferredShard sets the optional preferred_shard attribute to value. // -// or by starting a new tree. +// value: Index of file to open first if multiple files match +// `file_pattern`. See the documentation for `Restore`. +// If not specified, defaults to -1 +func RestoreSlicePreferredShard(value int64) RestoreSliceAttr { + return func(m optionalAttr) { + m["preferred_shard"] = value + } +} + +// Restores a tensor from checkpoint files. +// +// This is like `Restore` except that restored tensor can be listed as filling +// only a slice of a larger tensor. `shape_and_slice` specifies the shape of the +// larger tensor and the slice that the restored tensor covers. +// +// The `shape_and_slice` input has the same format as the +// elements of the `shapes_and_slices` input of the `SaveSlices` op. // // Arguments: -// tree_ensemble_handle: Handle to the ensemble variable. -// feature_ids: Rank 1 tensor with ids for each feature. This is the real id of -// the feature that will be used in the split. -// node_ids: List of rank 1 tensors representing the nodes for which this feature -// has a split. -// gains: List of rank 1 tensors representing the gains for each of the feature's -// split. -// thresholds: List of rank 1 tensors representing the thesholds for each of the -// feature's split. -// left_node_contribs: List of rank 2 tensors with left leaf contribs for each of -// the feature's splits. Will be added to the previous node values to constitute -// the values of the left nodes. -// right_node_contribs: List of rank 2 tensors with right leaf contribs for each -// of the feature's splits. Will be added to the previous node values to constitute -// the values of the right nodes. -// max_depth: Max depth of the tree to build. -// learning_rate: shrinkage const for each new tree. -// pruning_mode: 0-No pruning, 1-Pre-pruning, 2-Post-pruning. +// file_pattern: Must have a single element. The pattern of the files from +// which we read the tensor. +// tensor_name: Must have a single element. The name of the tensor to be +// restored. +// shape_and_slice: Scalar. The shapes and slice specifications to use when +// restoring a tensors. +// dt: The type of the tensor to be restored. // -// Returns the created operation. -func BoostedTreesUpdateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, feature_ids tf.Output, node_ids []tf.Output, gains []tf.Output, thresholds []tf.Output, left_node_contribs []tf.Output, right_node_contribs []tf.Output, max_depth tf.Output, learning_rate tf.Output, pruning_mode int64) (o *tf.Operation) { +// Returns The restored tensor. +func RestoreSlice(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, shape_and_slice tf.Output, dt tf.DataType, optional ...RestoreSliceAttr) (tensor tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"pruning_mode": pruning_mode} + attrs := map[string]interface{}{"dt": dt} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "BoostedTreesUpdateEnsemble", + Type: "RestoreSlice", Input: []tf.Input{ - tree_ensemble_handle, feature_ids, tf.OutputList(node_ids), tf.OutputList(gains), tf.OutputList(thresholds), tf.OutputList(left_node_contribs), tf.OutputList(right_node_contribs), max_depth, learning_rate, + file_pattern, tensor_name, shape_and_slice, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// ResourceSparseApplyFtrlAttr is an optional argument to ResourceSparseApplyFtrl. -type ResourceSparseApplyFtrlAttr func(optionalAttr) - -// ResourceSparseApplyFtrlUseLocking sets the optional use_locking attribute to value. +// Divides sparse updates into the variable referenced by `resource`. // -// value: If `True`, updating of the var and accum tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceSparseApplyFtrlUseLocking(value bool) ResourceSparseApplyFtrlAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update relevant entries in '*var' according to the Ftrl-proximal scheme. +// This operation computes // -// That is for rows we have grad for, we update var, accum and linear as follows: -// accum_new = accum + grad * grad -// linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var -// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2 -// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0 -// accum = accum_new +// # Scalar indices +// ref[indices, ...] /= updates[...] +// +// # Vector indices (for each i) +// ref[indices[i], ...] /= updates[i, ...] +// +// # High rank indices (for each i, ..., j) +// ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] +// +// Duplicate entries are handled correctly: if multiple `indices` reference +// the same location, their contributions multiply. +// +// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// +//
+// +//
// // Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// linear: Should be from a Variable(). -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// lr_power: Scaling factor. Must be a scalar. +// resource: Should be from a `Variable` node. +// indices: A tensor of indices into the first dimension of `ref`. +// updates: A tensor of updated values to add to `ref`. // // Returns the created operation. -func ResourceSparseApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceSparseApplyFtrlAttr) (o *tf.Operation) { +func ResourceScatterDiv(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "ResourceSparseApplyFtrl", + Type: "ResourceScatterDiv", Input: []tf.Input{ - var_, accum, linear, grad, indices, lr, l1, l2, lr_power, + resource, indices, updates, }, - Attrs: attrs, } return scope.AddOperation(opspec) } -// Returns which elements of x are Inf. +// ResourceScatterNdAddAttr is an optional argument to ResourceScatterNdAdd. +type ResourceScatterNdAddAttr func(optionalAttr) + +// ResourceScatterNdAddUseLocking sets the optional use_locking attribute to value. // -// @compatibility(numpy) -// Equivalent to np.isinf -// @end_compatibility -func IsInf(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "IsInf", - Input: []tf.Input{ - x, - }, +// value: An optional bool. Defaults to True. If True, the assignment will +// be protected by a lock; otherwise the behavior is undefined, +// but may exhibit less contention. +// If not specified, defaults to true +func ResourceScatterNdAddUseLocking(value bool) ResourceScatterNdAddAttr { + return func(m optionalAttr) { + m["use_locking"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Computes the sum along sparse segments of a tensor divided by the sqrt of N. +// Adds sparse `updates` to individual values or slices within a given // -// N is the size of the segment being reduced. +// variable according to `indices`. // -// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of -// segments. +// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. // -// Arguments: +// `indices` must be integer tensor, containing indices into `ref`. +// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. // -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// The innermost dimension of `indices` (with length `K`) corresponds to +// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +// dimension of `ref`. // -// Returns Has same shape as data, except for dimension 0 which -// has size `k`, the number of segments. -func SparseSegmentSqrtN(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) { +// `updates` is `Tensor` of rank `Q-1+P-K` with shape: +// +// ``` +// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. +// ``` +// +// For example, say we want to update 4 scattered elements to a rank-1 tensor to +// 8 elements. In Python, that update would look like this: +// +// ```python +// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True) +// indices = tf.constant([[4], [3], [1] ,[7]]) +// updates = tf.constant([9, 10, 11, 12]) +// update = tf.scatter_nd_add(ref, indices, updates) +// with tf.Session() as sess: +// print sess.run(update) +// ``` +// +// The resulting update to ref would look like this: +// +// [1, 12, 3, 14, 14, 6, 7, 20] +// +// See @{tf.scatter_nd} for more details about how to make updates to +// slices. +// +// Arguments: +// ref: A resource handle. Must be from a VarHandleOp. +// indices: A Tensor. Must be one of the following types: int32, int64. +// A tensor of indices into ref. +// updates: A Tensor. Must have the same type as ref. A tensor of +// values to add to ref. +// +// Returns the created operation. +func ResourceScatterNdAdd(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdAddAttr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "SparseSegmentSqrtN", + Type: "ResourceScatterNdAdd", Input: []tf.Input{ - data, indices, segment_ids, + ref, indices, updates, }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Adds up a `SparseTensor` and a dense `Tensor`, producing a dense `Tensor`. -// -// This Op does not require `a_indices` be sorted in standard lexicographic order. -// -// Arguments: -// a_indices: 2-D. The `indices` of the `SparseTensor`, with shape `[nnz, ndims]`. -// a_values: 1-D. The `values` of the `SparseTensor`, with shape `[nnz]`. -// a_shape: 1-D. The `shape` of the `SparseTensor`, with shape `[ndims]`. -// b: `ndims`-D Tensor. With shape `a_shape`. -func SparseTensorDenseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b tf.Output) (output tf.Output) { +// Mutually reduces multiple tensors of identical type and shape. +func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64) (data tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets} opspec := tf.OpSpec{ - Type: "SparseTensorDenseAdd", + Type: "CollectiveReduce", Input: []tf.Input{ - a_indices, a_values, a_shape, b, + input, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal. -type StatelessTruncatedNormalAttr func(optionalAttr) +// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal. +type StatelessRandomNormalAttr func(optionalAttr) -// StatelessTruncatedNormalDtype sets the optional dtype attribute to value. +// StatelessRandomNormalDtype sets the optional dtype attribute to value. // // value: The type of the output. // If not specified, defaults to DT_FLOAT -func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAttr { +func StatelessRandomNormalDtype(value tf.DataType) StatelessRandomNormalAttr { return func(m optionalAttr) { m["dtype"] = value } } -// Outputs deterministic pseudorandom values from a truncated normal distribution. +// Outputs deterministic pseudorandom values from a normal distribution. // -// The generated values follow a normal distribution with mean 0 and standard -// deviation 1, except that values whose magnitude is more than 2 standard -// deviations from the mean are dropped and re-picked. +// The generated values will have mean 0 and standard deviation 1. // // The outputs are a deterministic function of `shape` and `seed`. // @@ -8517,7 +9805,7 @@ func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAt // seed: 2 seeds (shape [2]). // // Returns Random values with specified shape. -func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessTruncatedNormalAttr) (output tf.Output) { +func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomNormalAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -8526,7 +9814,7 @@ func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, opt a(attrs) } opspec := tf.OpSpec{ - Type: "StatelessTruncatedNormal", + Type: "StatelessRandomNormal", Input: []tf.Input{ shape, seed, }, @@ -8536,157 +9824,248 @@ func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, opt return op.Output(0) } -// RestoreSliceAttr is an optional argument to RestoreSlice. -type RestoreSliceAttr func(optionalAttr) +// StringSplitV2Attr is an optional argument to StringSplitV2. +type StringSplitV2Attr func(optionalAttr) -// RestoreSlicePreferredShard sets the optional preferred_shard attribute to value. +// StringSplitV2Maxsplit sets the optional maxsplit attribute to value. // -// value: Index of file to open first if multiple files match -// `file_pattern`. See the documentation for `Restore`. +// value: An `int`. If `maxsplit > 0`, limit of the split of the result. // If not specified, defaults to -1 -func RestoreSlicePreferredShard(value int64) RestoreSliceAttr { +func StringSplitV2Maxsplit(value int64) StringSplitV2Attr { return func(m optionalAttr) { - m["preferred_shard"] = value + m["maxsplit"] = value } } -// Restores a tensor from checkpoint files. +// Split elements of `source` based on `sep` into a `SparseTensor`. // -// This is like `Restore` except that restored tensor can be listed as filling -// only a slice of a larger tensor. `shape_and_slice` specifies the shape of the -// larger tensor and the slice that the restored tensor covers. +// Let N be the size of source (typically N will be the batch size). Split each +// element of `source` based on `sep` and return a `SparseTensor` +// containing the split tokens. Empty tokens are ignored. // -// The `shape_and_slice` input has the same format as the -// elements of the `shapes_and_slices` input of the `SaveSlices` op. +// For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c', +// then the output will be +// ``` +// st.indices = [0, 0; +// 0, 1; +// 1, 0; +// 1, 1; +// 1, 2] +// st.shape = [2, 3] +// st.values = ['hello', 'world', 'a', 'b', 'c'] +// ``` // -// Arguments: -// file_pattern: Must have a single element. The pattern of the files from -// which we read the tensor. -// tensor_name: Must have a single element. The name of the tensor to be -// restored. -// shape_and_slice: Scalar. The shapes and slice specifications to use when -// restoring a tensors. -// dt: The type of the tensor to be restored. +// If `sep` is given, consecutive delimiters are not grouped together and are +// deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and +// sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty +// string, consecutive whitespace are regarded as a single separator, and the +// result will contain no empty strings at the startor end if the string has +// leading or trailing whitespace. // -// Returns The restored tensor. -func RestoreSlice(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, shape_and_slice tf.Output, dt tf.DataType, optional ...RestoreSliceAttr) (tensor tf.Output) { +// Note that the above mentioned behavior matches python's str.split. +// +// Arguments: +// input: `1-D` string `Tensor`, the strings to split. +// sep: `0-D` string `Tensor`, the delimiter character. +func StringSplitV2(scope *Scope, input tf.Output, sep tf.Output, optional ...StringSplitV2Attr) (indices tf.Output, values tf.Output, shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dt": dt} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RestoreSlice", + Type: "StringSplitV2", Input: []tf.Input{ - file_pattern, tensor_name, shape_and_slice, + input, sep, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// Divides sparse updates into the variable referenced by `resource`. -// -// This operation computes -// -// # Scalar indices -// ref[indices, ...] /= updates[...] -// -// # Vector indices (for each i) -// ref[indices[i], ...] /= updates[i, ...] -// -// # High rank indices (for each i, ..., j) -// ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] +// MaxPoolAttr is an optional argument to MaxPool. +type MaxPoolAttr func(optionalAttr) + +// MaxPoolDataFormat sets the optional data_format attribute to value. // -// Duplicate entries are handled correctly: if multiple `indices` reference -// the same location, their contributions multiply. +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolDataFormat(value string) MaxPoolAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Performs max pooling on the input. // -// Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`. +// Arguments: +// input: 4-D input to pool over. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. // -//
-// -//
+// Returns The max pooled output tensor. +func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPool", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// SparseMatMulAttr is an optional argument to SparseMatMul. +type SparseMatMulAttr func(optionalAttr) + +// SparseMatMulTransposeA sets the optional transpose_a attribute to value. +// If not specified, defaults to false +func SparseMatMulTransposeA(value bool) SparseMatMulAttr { + return func(m optionalAttr) { + m["transpose_a"] = value + } +} + +// SparseMatMulTransposeB sets the optional transpose_b attribute to value. +// If not specified, defaults to false +func SparseMatMulTransposeB(value bool) SparseMatMulAttr { + return func(m optionalAttr) { + m["transpose_b"] = value + } +} + +// SparseMatMulAIsSparse sets the optional a_is_sparse attribute to value. +// If not specified, defaults to false +func SparseMatMulAIsSparse(value bool) SparseMatMulAttr { + return func(m optionalAttr) { + m["a_is_sparse"] = value + } +} + +// SparseMatMulBIsSparse sets the optional b_is_sparse attribute to value. +// If not specified, defaults to false +func SparseMatMulBIsSparse(value bool) SparseMatMulAttr { + return func(m optionalAttr) { + m["b_is_sparse"] = value + } +} + +// Multiply matrix "a" by matrix "b". // -// Arguments: -// resource: Should be from a `Variable` node. -// indices: A tensor of indices into the first dimension of `ref`. -// updates: A tensor of updated values to add to `ref`. +// The inputs must be two-dimensional matrices and the inner dimension of "a" must +// match the outer dimension of "b". Both "a" and "b" must be `Tensor`s not +// `SparseTensor`s. This op is optimized for the case where at least one of "a" or +// "b" is sparse, in the sense that they have a large proportion of zero values. +// The breakeven for using this versus a dense matrix multiply on one platform was +// 30% zero values in the sparse matrix. // -// Returns the created operation. -func ResourceScatterDiv(scope *Scope, resource tf.Output, indices tf.Output, updates tf.Output) (o *tf.Operation) { +// The gradient computation of this operation will only take advantage of sparsity +// in the input gradient when that gradient comes from a Relu. +func SparseMatMul(scope *Scope, a tf.Output, b tf.Output, optional ...SparseMatMulAttr) (product tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ResourceScatterDiv", + Type: "SparseMatMul", Input: []tf.Input{ - resource, indices, updates, + a, b, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Mutually reduces multiple tensors of identical type and shape. -func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64) (data tf.Output) { +// Concatenates quantized tensors along one dimension. +// +// Arguments: +// concat_dim: 0-D. The dimension along which to concatenate. Must be in the +// range [0, rank(values)). +// values: The `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. +// input_mins: The minimum scalar values for each of the input tensors. +// input_maxes: The maximum scalar values for each of the input tensors. +// +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes.The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. +func QuantizedConcat(scope *Scope, concat_dim tf.Output, values []tf.Output, input_mins []tf.Output, input_maxes []tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets} opspec := tf.OpSpec{ - Type: "CollectiveReduce", + Type: "QuantizedConcat", Input: []tf.Input{ - input, + concat_dim, tf.OutputList(values), tf.OutputList(input_mins), tf.OutputList(input_maxes), }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal. -type StatelessRandomNormalAttr func(optionalAttr) - -// StatelessRandomNormalDtype sets the optional dtype attribute to value. +// Slice a `SparseTensor` based on the `start` and `size`. // -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatelessRandomNormalDtype(value tf.DataType) StatelessRandomNormalAttr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Outputs deterministic pseudorandom values from a normal distribution. +// For example, if the input is // -// The generated values will have mean 0 and standard deviation 1. +// input_tensor = shape = [2, 7] +// [ a d e ] +// [b c ] // -// The outputs are a deterministic function of `shape` and `seed`. +// Graphically the output tensors are: +// +// sparse_slice([0, 0], [2, 4]) = shape = [2, 4] +// [ a ] +// [b c ] +// +// sparse_slice([0, 4], [2, 3]) = shape = [2, 3] +// [ d e ] +// [ ] // // Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). +// indices: 2-D tensor represents the indices of the sparse tensor. +// values: 1-D tensor represents the values of the sparse tensor. +// shape: 1-D. tensor represents the shape of the sparse tensor. +// start: 1-D. tensor represents the start of the slice. +// size: 1-D. tensor represents the size of the slice. +// output indices: A list of 1-D tensors represents the indices of the output +// sparse tensors. // -// Returns Random values with specified shape. -func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomNormalAttr) (output tf.Output) { +// Returns A list of 1-D tensors represents the values of the output sparse +// tensors.A list of 1-D tensors represents the shape of the output sparse +// tensors. +func SparseSlice(scope *Scope, indices tf.Output, values tf.Output, shape tf.Output, start tf.Output, size tf.Output) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "StatelessRandomNormal", + Type: "SparseSlice", Input: []tf.Input{ - shape, seed, + indices, values, shape, start, size, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } // Reduces sparse updates into the variable referenced by `resource` using the `min` operation. @@ -9226,6 +10605,51 @@ func AvgPoolGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize return op.Output(0) } +// Greedily selects a subset of bounding boxes in descending order of score, +// +// pruning away boxes that have high overlaps +// with previously selected boxes. Bounding boxes with score less than +// `score_threshold` are removed. N-by-n overlap values are supplied as square matrix, +// which allows for defining a custom overlap criterium (eg. intersection over union, +// intersection over area, etc.). +// +// The output of this operation is a set of integers indexing into the input +// collection of bounding boxes representing the selected boxes. The bounding +// box coordinates corresponding to the selected indices can then be obtained +// using the `tf.gather operation`. For example: +// +// selected_indices = tf.image.non_max_suppression_with_overlaps( +// overlaps, scores, max_output_size, overlap_threshold, score_threshold) +// selected_boxes = tf.gather(boxes, selected_indices) +// +// Arguments: +// overlaps: A 2-D float tensor of shape `[num_boxes, num_boxes]` representing +// the n-by-n box overlap values. +// scores: A 1-D float tensor of shape `[num_boxes]` representing a single +// score corresponding to each box (each row of boxes). +// max_output_size: A scalar integer tensor representing the maximum number of +// boxes to be selected by non max suppression. +// overlap_threshold: A 0-D float tensor representing the threshold for deciding whether +// boxes overlap too. +// score_threshold: A 0-D float tensor representing the threshold for deciding when to remove +// boxes based on score. +// +// Returns A 1-D integer tensor of shape `[M]` representing the selected +// indices from the boxes tensor, where `M <= max_output_size`. +func NonMaxSuppressionWithOverlaps(scope *Scope, overlaps tf.Output, scores tf.Output, max_output_size tf.Output, overlap_threshold tf.Output, score_threshold tf.Output) (selected_indices tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "NonMaxSuppressionWithOverlaps", + Input: []tf.Input{ + overlaps, scores, max_output_size, overlap_threshold, score_threshold, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // StageClearAttr is an optional argument to StageClear. type StageClearAttr func(optionalAttr) @@ -9471,83 +10895,6 @@ func FixedLengthRecordReaderV2(scope *Scope, record_bytes int64, optional ...Fix return op.Output(0) } -// Converts each string in the input Tensor to its hash mod by a number of buckets. -// -// The hash function is deterministic on the content of the string within the -// process. -// -// Note that the hash function may change from time to time. -// This functionality will be deprecated and it's recommended to use -// `tf.string_to_hash_bucket_fast()` or `tf.string_to_hash_bucket_strong()`. -// -// Arguments: -// -// num_buckets: The number of buckets. -// -// Returns A Tensor of the same shape as the input `string_tensor`. -func StringToHashBucket(scope *Scope, string_tensor tf.Output, num_buckets int64) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_buckets": num_buckets} - opspec := tf.OpSpec{ - Type: "StringToHashBucket", - Input: []tf.Input{ - string_tensor, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes gradients for the exponential linear (Elu) operation. -// -// Arguments: -// gradients: The backpropagated gradients to the corresponding Elu operation. -// outputs: The outputs of the corresponding Elu operation. -// -// Returns The gradients: `gradients * (outputs + 1)` if outputs < 0, -// `gradients` otherwise. -func EluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "EluGrad", - Input: []tf.Input{ - gradients, outputs, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that contains `count` elements from the `input_dataset`. -// -// Arguments: -// -// count: A scalar representing the number of elements from the `input_dataset` -// that should be taken. A value of `-1` indicates that all of `input_dataset` -// is taken. -// -// -func TakeDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "TakeDataset", - Input: []tf.Input{ - input_dataset, count, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // The gradient operator for the SparseAdd op. // // The SparseAdd op calculates A + B, where A, B, and the sum are all represented @@ -9560,68 +10907,24 @@ func TakeDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_ // the non-empty values of the sum. // a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`. // b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`. -// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size -// `[nnz(sum), ndims]`. -// -// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the -// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the -// non-empty values of B. -func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseAddGrad", - Input: []tf.Input{ - backprop_val_grad, a_indices, b_indices, sum_indices, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Computes atan of x element-wise. -func Atan(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Atan", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Encode audio data using the WAV file format. -// -// This operation will generate a string suitable to be saved out to create a .wav -// audio file. It will be encoded in the 16-bit PCM format. It takes in float -// values in the range -1.0f to 1.0f, and any outside that value will be clamped to -// that range. -// -// `audio` is a 2-D float Tensor of shape `[length, channels]`. -// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100). -// -// Arguments: -// audio: 2-D with shape `[length, channels]`. -// sample_rate: Scalar containing the sample frequency. +// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size +// `[nnz(sum), ndims]`. // -// Returns 0-D. WAV-encoded file contents. -func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) { +// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the +// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the +// non-empty values of B. +func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "EncodeWav", + Type: "SparseAddGrad", Input: []tf.Input{ - audio, sample_rate, + backprop_val_grad, a_indices, b_indices, sum_indices, }, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } // Converts each string in the input Tensor to its hash mod by a number of buckets. @@ -10203,6 +11506,120 @@ func ResourceApplyPowerSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Out return scope.AddOperation(opspec) } +// CudnnRNNBackpropV2Attr is an optional argument to CudnnRNNBackpropV2. +type CudnnRNNBackpropV2Attr func(optionalAttr) + +// CudnnRNNBackpropV2RnnMode sets the optional rnn_mode attribute to value. +// If not specified, defaults to "lstm" +func CudnnRNNBackpropV2RnnMode(value string) CudnnRNNBackpropV2Attr { + return func(m optionalAttr) { + m["rnn_mode"] = value + } +} + +// CudnnRNNBackpropV2InputMode sets the optional input_mode attribute to value. +// If not specified, defaults to "linear_input" +func CudnnRNNBackpropV2InputMode(value string) CudnnRNNBackpropV2Attr { + return func(m optionalAttr) { + m["input_mode"] = value + } +} + +// CudnnRNNBackpropV2Direction sets the optional direction attribute to value. +// If not specified, defaults to "unidirectional" +func CudnnRNNBackpropV2Direction(value string) CudnnRNNBackpropV2Attr { + return func(m optionalAttr) { + m["direction"] = value + } +} + +// CudnnRNNBackpropV2Dropout sets the optional dropout attribute to value. +// If not specified, defaults to 0 +func CudnnRNNBackpropV2Dropout(value float32) CudnnRNNBackpropV2Attr { + return func(m optionalAttr) { + m["dropout"] = value + } +} + +// CudnnRNNBackpropV2Seed sets the optional seed attribute to value. +// If not specified, defaults to 0 +func CudnnRNNBackpropV2Seed(value int64) CudnnRNNBackpropV2Attr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// CudnnRNNBackpropV2Seed2 sets the optional seed2 attribute to value. +// If not specified, defaults to 0 +func CudnnRNNBackpropV2Seed2(value int64) CudnnRNNBackpropV2Attr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Backprop step of CudnnRNN. +// +// Compute the backprop of both data and weights in a RNN. Takes an extra +// "host_reserved" inupt than CudnnRNNBackprop, which is used to determine RNN +// cudnnRNNAlgo_t and cudnnMathType_t. +// +// rnn_mode: Indicates the type of the RNN model. +// input_mode: Indicates whether there is a linear projection between the input and +// the actual computation before the first layer. 'skip_input' is only allowed +// when input_size == num_units; 'auto_select' implies 'skip_input' when +// input_size == num_units; otherwise, it implies 'linear_input'. +// direction: Indicates whether a bidirectional model will be used. Should be +// "unidirectional" or "bidirectional". +// dropout: Dropout probability. When set to 0., dropout is disabled. +// seed: The 1st part of a seed to initialize dropout. +// seed2: The 2nd part of a seed to initialize dropout. +// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. +// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, +// num_units]. +// input_c: For LSTM, a 3-D tensor with the shape of +// [num_layer * dir, batch, num_units]. For other models, it is ignored. +// params: A 1-D tensor that contains the weights and biases in an opaque layout. +// The size must be created through CudnnRNNParamsSize, and initialized +// separately. Note that they might not be compatible across different +// generations. So it is a good idea to save and restore +// output: A 3-D tensor with the shape of [seq_length, batch_size, +// dir * num_units]. +// output_h: The same shape has input_h. +// output_c: The same shape as input_c for LSTM. An empty tensor for other models. +// output_backprop: A 3-D tensor with the same shape as output in the forward pass. +// output_h_backprop: A 3-D tensor with the same shape as output_h in the forward +// pass. +// output_c_backprop: A 3-D tensor with the same shape as output_c in the forward +// pass. +// reserve_space: The same reserve_space produced in the forward operation. +// host_reserved: The same host_reserved produced in the forward operation. +// input_backprop: The backprop to input in the forward pass. Has the same shape +// as input. +// input_h_backprop: The backprop to input_h in the forward pass. Has the same +// shape as input_h. +// input_c_backprop: The backprop to input_c in the forward pass. Has the same +// shape as input_c. +// params_backprop: The backprop to the params buffer in the forward pass. Has the +// same shape as params. +func CudnnRNNBackpropV2(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Output, params tf.Output, output tf.Output, output_h tf.Output, output_c tf.Output, output_backprop tf.Output, output_h_backprop tf.Output, output_c_backprop tf.Output, reserve_space tf.Output, host_reserved tf.Output, optional ...CudnnRNNBackpropV2Attr) (input_backprop tf.Output, input_h_backprop tf.Output, input_c_backprop tf.Output, params_backprop tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "CudnnRNNBackpropV2", + Input: []tf.Input{ + input, input_h, input_c, params, output, output_h, output_c, output_backprop, output_h_backprop, output_c_backprop, reserve_space, host_reserved, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) +} + // Locks a mutex resource. The output is the lock. So long as the lock tensor // // is alive, any other request to use `MutexLock` with this mutex will wait. @@ -10390,6 +11807,34 @@ func BatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, o return op.Output(0) } +// Check if the input matches the regex pattern. +// +// The input is a string tensor of any shape. The pattern is a scalar +// string tensor which is applied to every element of the input tensor. +// The boolean values (True or False) of the output tensor indicate +// if the input matches the regex pattern provided. +// +// The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) +// +// Arguments: +// input: A string tensor of the text to be processed. +// pattern: A 1-D string tensor of the regular expression to match the input. +// +// Returns A bool tensor with the same shape as `input`. +func RegexFullMatch(scope *Scope, input tf.Output, pattern tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RegexFullMatch", + Input: []tf.Input{ + input, pattern, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Says whether the targets are in the top `K` predictions. // // This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the @@ -10754,79 +12199,6 @@ func IFFT(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } -// Generates values in an interval. -// -// A sequence of `num` evenly-spaced values are generated beginning at `start`. -// If `num > 1`, the values in the sequence increase by `stop - start / num - 1`, -// so that the last one is exactly `stop`. -// -// For example: -// -// ``` -// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] -// ``` -// -// Arguments: -// start: First entry in the range. -// stop: Last entry in the range. -// num: Number of values to generate. -// -// Returns 1-D. The generated values. -func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "LinSpace", - Input: []tf.Input{ - start, stop, num, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// DestroyResourceOpAttr is an optional argument to DestroyResourceOp. -type DestroyResourceOpAttr func(optionalAttr) - -// DestroyResourceOpIgnoreLookupError sets the optional ignore_lookup_error attribute to value. -// -// value: whether to ignore the error when the resource -// doesn't exist. -// If not specified, defaults to true -func DestroyResourceOpIgnoreLookupError(value bool) DestroyResourceOpAttr { - return func(m optionalAttr) { - m["ignore_lookup_error"] = value - } -} - -// Deletes the resource specified by the handle. -// -// All subsequent operations using the resource will result in a NotFound -// error status. -// -// Arguments: -// resource: handle to the resource to delete. -// -// Returns the created operation. -func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyResourceOpAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DestroyResourceOp", - Input: []tf.Input{ - resource, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - // ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp. type ResourceSparseApplyRMSPropAttr func(optionalAttr) @@ -11206,6 +12578,65 @@ func ResourceSparseApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, l return scope.AddOperation(opspec) } +// Elementwise computes the bitwise right-shift of `x` and `y`. +// +// Performs a logical shift for unsigned integer types, and an arithmetic shift +// for signed integer types. +// +// If `y` is negative, or greater than or equal to than the width of `x` in bits +// the result is implementation defined. +func RightShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RightShift", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// TensorListStackAttr is an optional argument to TensorListStack. +type TensorListStackAttr func(optionalAttr) + +// TensorListStackNumElements sets the optional num_elements attribute to value. +// If not specified, defaults to -1 +func TensorListStackNumElements(value int64) TensorListStackAttr { + return func(m optionalAttr) { + m["num_elements"] = value + } +} + +// Stacks all tensors in the list. +// +// Requires that all tensors have the same shape. +// +// input_handle: the input list +// tensor: the gathered result +// num_elements: optional. If not -1, the number of elements in the list. +// +func TensorListStack(scope *Scope, input_handle tf.Output, element_dtype tf.DataType, optional ...TensorListStackAttr) (tensor tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"element_dtype": element_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TensorListStack", + Input: []tf.Input{ + input_handle, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // StatelessRandomUniformAttr is an optional argument to StatelessRandomUniform. type StatelessRandomUniformAttr func(optionalAttr) @@ -11282,24 +12713,6 @@ func Fact(scope *Scope) (fact tf.Output) { return op.Output(0) } -// Elementwise computes the bitwise XOR of `x` and `y`. -// -// The result will have those bits set, that are different in `x` and `y`. The -// computation is performed on the underlying representations of `x` and `y`. -func BitwiseXor(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BitwiseXor", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Deserialize `SparseTensor` objects. // // The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where @@ -11402,7 +12815,7 @@ func ResourceScatterNdUpdateUseLocking(value bool) ResourceScatterNdUpdateAttr { // 8 elements. In Python, that update would look like this: // // ```python -// ref = tfe.Variable([1, 2, 3, 4, 5, 6, 7, 8]) +// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) // indices = tf.constant([[4], [3], [1] ,[7]]) // updates = tf.constant([9, 10, 11, 12]) // update = tf.scatter_nd_update(ref, indices, updates) @@ -11727,6 +13140,7 @@ func RFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Outp // [0, 0, 2, 2, 0, 0] // [0, 0, 0, 0, 0, 0]] // ``` +// func Pad(scope *Scope, input tf.Output, paddings tf.Output) (output tf.Output) { if scope.Err() != nil { return @@ -12305,337 +13719,95 @@ func SparseSoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, label // // Returns A complex64 tensor of the same shape as `input`. The inner-most // dimension of `input` is replaced with its 1D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.fft -// @end_compatibility -func FFT(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "FFT", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Transforms a serialized tensorflow.TensorProto proto into a Tensor. -// -// Arguments: -// serialized: A scalar string containing a serialized TensorProto proto. -// out_type: The type of the serialized tensor. The provided type must match the -// type of the serialized tensor and no implicit conversion will take place. -// -// Returns A Tensor of type `out_type`. -func ParseTensor(scope *Scope, serialized tf.Output, out_type tf.DataType) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"out_type": out_type} - opspec := tf.OpSpec{ - Type: "ParseTensor", - Input: []tf.Input{ - serialized, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// MaxPoolWithArgmaxAttr is an optional argument to MaxPoolWithArgmax. -type MaxPoolWithArgmaxAttr func(optionalAttr) - -// MaxPoolWithArgmaxTargmax sets the optional Targmax attribute to value. -// If not specified, defaults to DT_INT64 -func MaxPoolWithArgmaxTargmax(value tf.DataType) MaxPoolWithArgmaxAttr { - return func(m optionalAttr) { - m["Targmax"] = value - } -} - -// Performs max pooling on the input and outputs both max values and indices. -// -// The indices in `argmax` are flattened, so that a maximum value at position -// `[b, y, x, c]` becomes flattened index -// `((b * height + y) * width + x) * channels + c`. -// -// The indices returned are always in `[0, height) x [0, width)` before flattening, -// even if padding is involved and the mathematically correct answer is outside -// (either negative or too large). This is a bug, but fixing it is difficult to do -// in a safe backwards compatible way, especially due to flattening. -// -// Arguments: -// input: 4-D with shape `[batch, height, width, channels]`. Input to pool over. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns The max pooled output tensor.4-D. The flattened indices of the max values chosen for each output. -func MaxPoolWithArgmax(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolWithArgmaxAttr) (output tf.Output, argmax tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPoolWithArgmax", - Input: []tf.Input{ - input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// ResourceSparseApplyAdagradDAAttr is an optional argument to ResourceSparseApplyAdagradDA. -type ResourceSparseApplyAdagradDAAttr func(optionalAttr) - -// ResourceSparseApplyAdagradDAUseLocking sets the optional use_locking attribute to value. -// -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceSparseApplyAdagradDAUseLocking(value bool) ResourceSparseApplyAdagradDAAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update entries in '*var' and '*accum' according to the proximal adagrad scheme. -// -// Arguments: -// var_: Should be from a Variable(). -// gradient_accumulator: Should be from a Variable(). -// gradient_squared_accumulator: Should be from a Variable(). -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// lr: Learning rate. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// global_step: Training step number. Must be a scalar. -// -// Returns the created operation. -func ResourceSparseApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceSparseApplyAdagradDAAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyAdagradDA", - Input: []tf.Input{ - var_, gradient_accumulator, gradient_squared_accumulator, grad, indices, lr, l1, l2, global_step, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// EncodeJpegAttr is an optional argument to EncodeJpeg. -type EncodeJpegAttr func(optionalAttr) - -// EncodeJpegFormat sets the optional format attribute to value. -// -// value: Per pixel image format. -// If not specified, defaults to "" -func EncodeJpegFormat(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["format"] = value - } -} - -// EncodeJpegQuality sets the optional quality attribute to value. -// -// value: Quality of the compression from 0 to 100 (higher is better and slower). -// If not specified, defaults to 95 -func EncodeJpegQuality(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["quality"] = value - } -} - -// EncodeJpegProgressive sets the optional progressive attribute to value. -// -// value: If True, create a JPEG that loads progressively (coarse to fine). -// If not specified, defaults to false -func EncodeJpegProgressive(value bool) EncodeJpegAttr { - return func(m optionalAttr) { - m["progressive"] = value - } -} - -// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value. -// -// value: If True, spend CPU/RAM to reduce size with no quality change. -// If not specified, defaults to false -func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr { - return func(m optionalAttr) { - m["optimize_size"] = value - } -} - -// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value. -// -// value: See http://en.wikipedia.org/wiki/Chroma_subsampling. -// If not specified, defaults to true -func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr { - return func(m optionalAttr) { - m["chroma_downsampling"] = value - } -} - -// EncodeJpegDensityUnit sets the optional density_unit attribute to value. -// -// value: Unit used to specify `x_density` and `y_density`: -// pixels per inch (`'in'`) or centimeter (`'cm'`). -// If not specified, defaults to "in" -func EncodeJpegDensityUnit(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["density_unit"] = value - } -} - -// EncodeJpegXDensity sets the optional x_density attribute to value. -// -// value: Horizontal pixels per density unit. -// If not specified, defaults to 300 -func EncodeJpegXDensity(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["x_density"] = value - } -} - -// EncodeJpegYDensity sets the optional y_density attribute to value. -// -// value: Vertical pixels per density unit. -// If not specified, defaults to 300 -func EncodeJpegYDensity(value int64) EncodeJpegAttr { - return func(m optionalAttr) { - m["y_density"] = value - } -} - -// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value. -// -// value: If not empty, embed this XMP metadata in the image header. -// If not specified, defaults to "" -func EncodeJpegXmpMetadata(value string) EncodeJpegAttr { - return func(m optionalAttr) { - m["xmp_metadata"] = value - } -} - -// JPEG-encode an image. -// -// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`. -// -// The attr `format` can be used to override the color format of the encoded -// output. Values can be: -// -// * `''`: Use a default format based on the number of channels in the image. -// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension -// of `image` must be 1. -// * `rgb`: Output an RGB JPEG image. The `channels` dimension -// of `image` must be 3. -// -// If `format` is not specified or is the empty string, a default format is picked -// in function of the number of channels in `image`: -// -// * 1: Output a grayscale image. -// * 3: Output an RGB image. -// -// Arguments: -// image: 3-D with shape `[height, width, channels]`. -// -// Returns 0-D. JPEG-encoded image. -func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) { +// +// @compatibility(numpy) +// Equivalent to np.fft.fft +// @end_compatibility +func FFT(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "EncodeJpeg", + Type: "FFT", Input: []tf.Input{ - image, + input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// MultinomialAttr is an optional argument to Multinomial. -type MultinomialAttr func(optionalAttr) - -// MultinomialSeed sets the optional seed attribute to value. +// Transforms a serialized tensorflow.TensorProto proto into a Tensor. // -// value: If either seed or seed2 is set to be non-zero, the internal random number -// generator is seeded by the given seed. Otherwise, a random seed is used. -// If not specified, defaults to 0 -func MultinomialSeed(value int64) MultinomialAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// MultinomialSeed2 sets the optional seed2 attribute to value. +// Arguments: +// serialized: A scalar string containing a serialized TensorProto proto. +// out_type: The type of the serialized tensor. The provided type must match the +// type of the serialized tensor and no implicit conversion will take place. // -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func MultinomialSeed2(value int64) MultinomialAttr { - return func(m optionalAttr) { - m["seed2"] = value +// Returns A Tensor of type `out_type`. +func ParseTensor(scope *Scope, serialized tf.Output, out_type tf.DataType) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"out_type": out_type} + opspec := tf.OpSpec{ + Type: "ParseTensor", + Input: []tf.Input{ + serialized, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// MultinomialOutputDtype sets the optional output_dtype attribute to value. +// MaxPoolWithArgmaxAttr is an optional argument to MaxPoolWithArgmax. +type MaxPoolWithArgmaxAttr func(optionalAttr) + +// MaxPoolWithArgmaxTargmax sets the optional Targmax attribute to value. // If not specified, defaults to DT_INT64 -func MultinomialOutputDtype(value tf.DataType) MultinomialAttr { +func MaxPoolWithArgmaxTargmax(value tf.DataType) MaxPoolWithArgmaxAttr { return func(m optionalAttr) { - m["output_dtype"] = value + m["Targmax"] = value } } -// Draws samples from a multinomial distribution. +// Performs max pooling on the input and outputs both max values and indices. +// +// The indices in `argmax` are flattened, so that a maximum value at position +// `[b, y, x, c]` becomes flattened index +// `((b * height + y) * width + x) * channels + c`. +// +// The indices returned are always in `[0, height) x [0, width)` before flattening, +// even if padding is involved and the mathematically correct answer is outside +// (either negative or too large). This is a bug, but fixing it is difficult to do +// in a safe backwards compatible way, especially due to flattening. // // Arguments: -// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` -// represents the unnormalized log probabilities for all classes. -// num_samples: 0-D. Number of independent samples to draw for each row slice. +// input: 4-D with shape `[batch, height, width, channels]`. Input to pool over. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. // -// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]` -// contains the drawn class labels with range `[0, num_classes)`. -func Multinomial(scope *Scope, logits tf.Output, num_samples tf.Output, optional ...MultinomialAttr) (output tf.Output) { +// Returns The max pooled output tensor.4-D. The flattened indices of the max values chosen for each output. +func MaxPoolWithArgmax(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolWithArgmaxAttr) (output tf.Output, argmax tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Multinomial", + Type: "MaxPoolWithArgmax", Input: []tf.Input{ - logits, num_samples, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } // Returns the truth value of NOT x element-wise. @@ -12815,6 +13987,24 @@ func BoostedTreesPredict(scope *Scope, tree_ensemble_handle tf.Output, bucketize return op.Output(0) } +// Elementwise computes the bitwise OR of `x` and `y`. +// +// The result will have those bits set, that are set in `x`, `y` or both. The +// computation is performed on the underlying representations of `x` and `y`. +func BitwiseOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BitwiseOr", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // MatrixSolveLsAttr is an optional argument to MatrixSolveLs. type MatrixSolveLsAttr func(optionalAttr) @@ -12892,24 +14082,6 @@ func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer return op.Output(0) } -// Elementwise computes the bitwise OR of `x` and `y`. -// -// The result will have those bits set, that are set in `x`, `y` or both. The -// computation is performed on the underlying representations of `x` and `y`. -func BitwiseOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "BitwiseOr", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // MaxPool3DAttr is an optional argument to MaxPool3D. type MaxPool3DAttr func(optionalAttr) @@ -12950,166 +14122,50 @@ func MaxPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, pa Type: "MaxPool3D", Input: []tf.Input{ input, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Conv3DBackpropInputAttr is an optional argument to Conv3DBackpropInput. -type Conv3DBackpropInputAttr func(optionalAttr) - -// Conv3DBackpropInputDilations sets the optional dilations attribute to value. -// If not specified, defaults to -func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes the gradients of 3-D convolution with respect to the input. -// -// DEPRECATED at GraphDef version 10: Use Conv3DBackpropInputV2 -// -// Arguments: -// input: Shape `[batch, depth, rows, cols, in_channels]`. -// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. -// `in_channels` must match between `input` and `filter`. -// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, -// out_channels]`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -func Conv3DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropInputAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Conv3DBackpropInput", - Input: []tf.Input{ - input, filter, out_backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad. -type ResourceApplyProximalAdagradAttr func(optionalAttr) - -// ResourceApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. -// -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceApplyProximalAdagradUseLocking(value bool) ResourceApplyProximalAdagradAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate. -// -// accum += grad * grad -// prox_v = var - lr * grad * (1 / sqrt(accum)) -// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} -// -// Arguments: -// var_: Should be from a Variable(). -// accum: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// l1: L1 regularization. Must be a scalar. -// l2: L2 regularization. Must be a scalar. -// grad: The gradient. -// -// Returns the created operation. -func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, optional ...ResourceApplyProximalAdagradAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyProximalAdagrad", - Input: []tf.Input{ - var_, accum, lr, l1, l2, grad, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2. -type MutableHashTableOfTensorsV2Attr func(optionalAttr) - -// MutableHashTableOfTensorsV2Container sets the optional container attribute to value. -// -// value: If non-empty, this table is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this table is shared under the given name across -// multiple sessions. -// If not specified, defaults to "" -func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. -// If not specified, defaults to false -func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr { - return func(m optionalAttr) { - m["use_node_name_sharing"] = value + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value. -// If not specified, defaults to <> -func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr { +// Conv3DBackpropInputAttr is an optional argument to Conv3DBackpropInput. +type Conv3DBackpropInputAttr func(optionalAttr) + +// Conv3DBackpropInputDilations sets the optional dilations attribute to value. +// If not specified, defaults to +func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr { return func(m optionalAttr) { - m["value_shape"] = value + m["dilations"] = value } } -// Creates an empty hash table. +// Computes the gradients of 3-D convolution with respect to the input. // -// This op creates a mutable hash table, specifying the type of its keys and -// values. Each value must be a vector. Data can be inserted into the table using -// the insert operations. It does not support the initialization operation. +// DEPRECATED at GraphDef version 10: Use Conv3DBackpropInputV2 // // Arguments: -// key_dtype: Type of the table keys. -// value_dtype: Type of the table values. -// -// Returns Handle to a table. -func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) { +// input: Shape `[batch, depth, rows, cols, in_channels]`. +// filter: Shape `[depth, rows, cols, in_channels, out_channels]`. +// `in_channels` must match between `input` and `filter`. +// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, +// out_channels]`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +func Conv3DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropInputAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + attrs := map[string]interface{}{"strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MutableHashTableOfTensorsV2", - + Type: "Conv3DBackpropInput", + Input: []tf.Input{ + input, filter, out_backprop, + }, Attrs: attrs, } op := scope.AddOperation(opspec) @@ -13157,62 +14213,6 @@ func ResourceScatterSub(scope *Scope, resource tf.Output, indices tf.Output, upd return scope.AddOperation(opspec) } -// Inverse 2D fast Fourier transform. -// -// Computes the inverse 2-dimensional discrete Fourier transform over the -// inner-most 2 dimensions of `input`. -// -// Arguments: -// input: A complex64 tensor. -// -// Returns A complex64 tensor of the same shape as `input`. The inner-most 2 -// dimensions of `input` are replaced with their inverse 2D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.ifft2 -// @end_compatibility -func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "IFFT2D", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// 2D fast Fourier transform. -// -// Computes the 2-dimensional discrete Fourier transform over the inner-most -// 2 dimensions of `input`. -// -// Arguments: -// input: A complex64 tensor. -// -// Returns A complex64 tensor of the same shape as `input`. The inner-most 2 -// dimensions of `input` are replaced with their 2D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.fft2 -// @end_compatibility -func FFT2D(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "FFT2D", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // ResourceApplyProximalGradientDescentAttr is an optional argument to ResourceApplyProximalGradientDescent. type ResourceApplyProximalGradientDescentAttr func(optionalAttr) @@ -13459,9 +14459,11 @@ func ReduceJoinSeparator(value string) ReduceJoinAttr { // Joins a string Tensor across the given dimensions. // // Computes the string join across dimensions in the given string Tensor of shape -// `[d_0, d_1, ..., d_n-1]`. Returns a new Tensor created by joining the input +// `[\\(d_0, d_1, ..., d_{n-1}\\)]`. Returns a new Tensor created by joining the input // strings with the given separator (default: empty string). Negative indices are -// counted backwards from the end, with `-1` being equivalent to `n - 1`. +// counted backwards from the end, with `-1` being equivalent to `n - 1`. If +// indices are not specified, joins across all dimensions beginning from `n - 1` +// through `0`. // // For example: // @@ -13474,9 +14476,10 @@ func ReduceJoinSeparator(value string) ReduceJoinAttr { // tf.reduce_join(a, 0, keep_dims=True) ==> [["ac", "bd"]] // tf.reduce_join(a, 1, keep_dims=True) ==> [["ab"], ["cd"]] // tf.reduce_join(a, 0, separator=".") ==> ["a.c", "b.d"] -// tf.reduce_join(a, [0, 1]) ==> ["acbd"] -// tf.reduce_join(a, [1, 0]) ==> ["abcd"] -// tf.reduce_join(a, []) ==> ["abcd"] +// tf.reduce_join(a, [0, 1]) ==> "acbd" +// tf.reduce_join(a, [1, 0]) ==> "abcd" +// tf.reduce_join(a, []) ==> [["a", "b"], ["c", "d"]] +// tf.reduce_join(a) = tf.reduce_join(a, [1, 0]) ==> "abcd" // ``` // // Arguments: @@ -13917,6 +14920,83 @@ func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } +// MfccAttr is an optional argument to Mfcc. +type MfccAttr func(optionalAttr) + +// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value. +// +// value: The highest frequency to use when calculating the +// ceptstrum. +// If not specified, defaults to 4000 +func MfccUpperFrequencyLimit(value float32) MfccAttr { + return func(m optionalAttr) { + m["upper_frequency_limit"] = value + } +} + +// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value. +// +// value: The lowest frequency to use when calculating the +// ceptstrum. +// If not specified, defaults to 20 +func MfccLowerFrequencyLimit(value float32) MfccAttr { + return func(m optionalAttr) { + m["lower_frequency_limit"] = value + } +} + +// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value. +// +// value: Resolution of the Mel bank used internally. +// If not specified, defaults to 40 +func MfccFilterbankChannelCount(value int64) MfccAttr { + return func(m optionalAttr) { + m["filterbank_channel_count"] = value + } +} + +// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value. +// +// value: How many output channels to produce per time slice. +// If not specified, defaults to 13 +func MfccDctCoefficientCount(value int64) MfccAttr { + return func(m optionalAttr) { + m["dct_coefficient_count"] = value + } +} + +// Transforms a spectrogram into a form that's useful for speech recognition. +// +// Mel Frequency Cepstral Coefficients are a way of representing audio data that's +// been effective as an input feature for machine learning. They are created by +// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the +// higher frequencies that are less significant to the human ear. They have a long +// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum +// is a good resource to learn more. +// +// Arguments: +// spectrogram: Typically produced by the Spectrogram op, with magnitude_squared +// set to true. +// sample_rate: How many samples per second the source audio used. +func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Mfcc", + Input: []tf.Input{ + spectrogram, sample_rate, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // AudioSummaryAttr is an optional argument to AudioSummary. type AudioSummaryAttr func(optionalAttr) @@ -14278,117 +15358,58 @@ func QuantizedRelu(scope *Scope, features tf.Output, min_features tf.Output, max // Reshapes a SparseTensor to represent values in a new dense shape. // -// This operation has the same semantics as reshape on the represented dense -// tensor. The `input_indices` are recomputed based on the requested `new_shape`. -// -// If one component of `new_shape` is the special value -1, the size of that -// dimension is computed so that the total dense size remains constant. At -// most one component of `new_shape` can be -1. The number of dense elements -// implied by `new_shape` must be the same as the number of dense elements -// originally implied by `input_shape`. -// -// Reshaping does not affect the order of values in the SparseTensor. -// -// If the input tensor has rank `R_in` and `N` non-empty values, and `new_shape` -// has length `R_out`, then `input_indices` has shape `[N, R_in]`, -// `input_shape` has length `R_in`, `output_indices` has shape `[N, R_out]`, and -// `output_shape` has length `R_out`. -// -// Arguments: -// input_indices: 2-D. `N x R_in` matrix with the indices of non-empty values in a -// SparseTensor. -// input_shape: 1-D. `R_in` vector with the input SparseTensor's dense shape. -// new_shape: 1-D. `R_out` vector with the requested new dense shape. -// -// Returns 2-D. `N x R_out` matrix with the updated indices of non-empty -// values in the output SparseTensor.1-D. `R_out` vector with the full dense shape of the output -// SparseTensor. This is the same as `new_shape` but with any -1 dimensions -// filled in. -func SparseReshape(scope *Scope, input_indices tf.Output, input_shape tf.Output, new_shape tf.Output) (output_indices tf.Output, output_shape tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseReshape", - Input: []tf.Input{ - input_indices, input_shape, new_shape, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Deprecated. Use TensorArraySplitV3 -// -// DEPRECATED at GraphDef version 26: Use TensorArraySplitV3 -func TensorArraySplitV2(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "TensorArraySplitV2", - Input: []tf.Input{ - handle, value, lengths, flow_in, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// PackAttr is an optional argument to Pack. -type PackAttr func(optionalAttr) - -// PackAxis sets the optional axis attribute to value. -// -// value: Dimension along which to pack. Negative values wrap around, so the -// valid range is `[-(R+1), R+1)`. -// If not specified, defaults to 0 -func PackAxis(value int64) PackAttr { - return func(m optionalAttr) { - m["axis"] = value - } -} - -// Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor. -// -// Packs the `N` tensors in `values` into a tensor with rank one higher than each -// tensor in `values`, by packing them along the `axis` dimension. -// Given a list of tensors of shape `(A, B, C)`; -// -// if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`. -// if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`. -// Etc. +// This operation has the same semantics as reshape on the represented dense +// tensor. The `input_indices` are recomputed based on the requested `new_shape`. // -// For example: +// If one component of `new_shape` is the special value -1, the size of that +// dimension is computed so that the total dense size remains constant. At +// most one component of `new_shape` can be -1. The number of dense elements +// implied by `new_shape` must be the same as the number of dense elements +// originally implied by `input_shape`. // -// ``` -// # 'x' is [1, 4] -// # 'y' is [2, 5] -// # 'z' is [3, 6] -// pack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. -// pack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]] -// ``` +// Reshaping does not affect the order of values in the SparseTensor. // -// This is the opposite of `unpack`. +// If the input tensor has rank `R_in` and `N` non-empty values, and `new_shape` +// has length `R_out`, then `input_indices` has shape `[N, R_in]`, +// `input_shape` has length `R_in`, `output_indices` has shape `[N, R_out]`, and +// `output_shape` has length `R_out`. // // Arguments: -// values: Must be of same shape and type. +// input_indices: 2-D. `N x R_in` matrix with the indices of non-empty values in a +// SparseTensor. +// input_shape: 1-D. `R_in` vector with the input SparseTensor's dense shape. +// new_shape: 1-D. `R_out` vector with the requested new dense shape. // -// Returns The packed tensor. -func Pack(scope *Scope, values []tf.Output, optional ...PackAttr) (output tf.Output) { +// Returns 2-D. `N x R_out` matrix with the updated indices of non-empty +// values in the output SparseTensor.1-D. `R_out` vector with the full dense shape of the output +// SparseTensor. This is the same as `new_shape` but with any -1 dimensions +// filled in. +func SparseReshape(scope *Scope, input_indices tf.Output, input_shape tf.Output, new_shape tf.Output) (output_indices tf.Output, output_shape tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) + opspec := tf.OpSpec{ + Type: "SparseReshape", + Input: []tf.Input{ + input_indices, input_shape, new_shape, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// Deprecated. Use TensorArraySplitV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArraySplitV3 +func TensorArraySplitV2(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) { + if scope.Err() != nil { + return } opspec := tf.OpSpec{ - Type: "Pack", + Type: "TensorArraySplitV2", Input: []tf.Input{ - tf.OutputList(values), + handle, value, lengths, flow_in, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -14548,27 +15569,27 @@ func CudnnRNNBackpropSeed2(value int64) CudnnRNNBackpropAttr { // // rnn_mode: Indicates the type of the RNN model. // input_mode: Indicate whether there is a linear projection between the input and -// The actual computation before the first layer. 'skip_input' is only allowed +// the actual computation before the first layer. 'skip_input' is only allowed // when input_size == num_units; 'auto_select' implies 'skip_input' when // input_size == num_units; otherwise, it implies 'linear_input'. -// direction: Indicates whether a bidirectional model will be used. -// dir = (direction == bidirectional) ? 2 : 1 -// dropout: dropout probability. When set to 0., dropout is disabled. -// seed: the 1st part of a seed to initialize dropout. -// seed2: the 2nd part of a seed to initialize dropout. -// input: a 3-D tensor with the shape of [seq_length, batch_size, input_size]. -// input_h: a 3-D tensor with the shape of [num_layer * dir, batch_size, +// direction: Indicates whether a bidirectional model will be used. Should be +// "unidirectional" or "bidirectional". +// dropout: Dropout probability. When set to 0., dropout is disabled. +// seed: The 1st part of a seed to initialize dropout. +// seed2: The 2nd part of a seed to initialize dropout. +// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. +// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, // num_units]. // input_c: For LSTM, a 3-D tensor with the shape of // [num_layer * dir, batch, num_units]. For other models, it is ignored. -// params: a 1-D tensor that contains the weights and biases in an opaque layout. +// params: A 1-D tensor that contains the weights and biases in an opaque layout. // The size must be created through CudnnRNNParamsSize, and initialized // separately. Note that they might not be compatible across different // generations. So it is a good idea to save and restore -// output: a 3-D tensor with the shape of [seq_length, batch_size, +// output: A 3-D tensor with the shape of [seq_length, batch_size, // dir * num_units]. -// output_h: the same shape has input_h. -// output_c: the same shape as input_c for LSTM. An empty tensor for other models. +// output_h: The same shape has input_h. +// output_c: The same shape as input_c for LSTM. An empty tensor for other models. // output_backprop: A 3-D tensor with the same shape as output in the forward pass. // output_h_backprop: A 3-D tensor with the same shape as output_h in the forward // pass. @@ -15053,102 +16074,6 @@ func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Updates specified rows with values in `v`. -// -// Computes `x[i, :] = v; return x`. -// -// Arguments: -// x: A tensor of type `T`. -// i: A vector. Indices into the left-most dimension of `x`. -// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size. -// -// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`. -func InplaceUpdate(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "InplaceUpdate", - Input: []tf.Input{ - x, i, v, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// FusedBatchNormAttr is an optional argument to FusedBatchNorm. -type FusedBatchNormAttr func(optionalAttr) - -// FusedBatchNormEpsilon sets the optional epsilon attribute to value. -// -// value: A small float number added to the variance of x. -// If not specified, defaults to 0.0001 -func FusedBatchNormEpsilon(value float32) FusedBatchNormAttr { - return func(m optionalAttr) { - m["epsilon"] = value - } -} - -// FusedBatchNormDataFormat sets the optional data_format attribute to value. -// -// value: The data format for x and y. Either "NHWC" (default) or "NCHW". -// If not specified, defaults to "NHWC" -func FusedBatchNormDataFormat(value string) FusedBatchNormAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// FusedBatchNormIsTraining sets the optional is_training attribute to value. -// -// value: A bool value to indicate the operation is for training (default) -// or inference. -// If not specified, defaults to true -func FusedBatchNormIsTraining(value bool) FusedBatchNormAttr { - return func(m optionalAttr) { - m["is_training"] = value - } -} - -// Batch normalization. -// -// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -// The size of 1D Tensors matches the dimension C of the 4D Tensors. -// -// Arguments: -// x: A 4D Tensor for input data. -// scale: A 1D Tensor for scaling factor, to scale the normalized x. -// offset: A 1D Tensor for offset, to shift to the normalized x. -// mean: A 1D Tensor for population mean. Used for inference only; -// must be empty for training. -// variance: A 1D Tensor for population variance. Used for inference only; -// must be empty for training. -// -// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow -// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by -// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused -// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance -// in the cuDNN case), to be reused in the gradient computation. -func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormAttr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FusedBatchNorm", - Input: []tf.Input{ - x, scale, offset, mean, variance, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) -} - // RandomStandardNormalAttr is an optional argument to RandomStandardNormal. type RandomStandardNormalAttr func(optionalAttr) @@ -15324,31 +16249,6 @@ func BoostedTreesEnsembleResourceHandleOp(scope *Scope, optional ...BoostedTrees return op.Output(0) } -// Concatenates tensors along one dimension. -// -// Arguments: -// concat_dim: 0-D. The dimension along which to concatenate. Must be in the -// range [0, rank(values)). -// values: The `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. -// -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes. -func Concat(scope *Scope, concat_dim tf.Output, values []tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Concat", - Input: []tf.Input{ - concat_dim, tf.OutputList(values), - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // ResourceApplyMomentumAttr is an optional argument to ResourceApplyMomentum. type ResourceApplyMomentumAttr func(optionalAttr) @@ -15578,6 +16478,30 @@ func OrderedMapUnstageNoKey(scope *Scope, indices tf.Output, dtypes []tf.DataTyp return key, values } +// Calculates the prior from the training data (the bias) and fills in the first node with the logits' prior. Returns a boolean indicating whether to continue centering. +// +// Arguments: +// tree_ensemble_handle: Handle to the tree ensemble. +// mean_gradients: A tensor with shape=[logits_dimension] with mean of gradients for a first node. +// mean_hessians: A tensor with shape=[logits_dimension] mean of hessians for a first node. +// l1: l1 regularization factor on leaf weights, per instance based. +// l2: l2 regularization factor on leaf weights, per instance based. +// +// Returns Bool, whether to continue bias centering. +func BoostedTreesCenterBias(scope *Scope, tree_ensemble_handle tf.Output, mean_gradients tf.Output, mean_hessians tf.Output, l1 tf.Output, l2 tf.Output) (continue_centering tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BoostedTreesCenterBias", + Input: []tf.Input{ + tree_ensemble_handle, mean_gradients, mean_hessians, l1, l2, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // SerializeManySparseAttr is an optional argument to SerializeManySparse. type SerializeManySparseAttr func(optionalAttr) @@ -16233,35 +17157,91 @@ func MutableDenseHashTableV2MaxLoadFactor(value float32) MutableDenseHashTableV2 } } -// Creates an empty hash table that uses tensors as the backing store. -// -// It uses "open addressing" with quadratic reprobing to resolve -// collisions. +// Creates an empty hash table that uses tensors as the backing store. +// +// It uses "open addressing" with quadratic reprobing to resolve +// collisions. +// +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a scalar. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. +// +// Arguments: +// empty_key: The key used to represent empty key buckets internally. Must not +// be used in insert or lookup operations. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. +func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, value_dtype tf.DataType, optional ...MutableDenseHashTableV2Attr) (table_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"value_dtype": value_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MutableDenseHashTableV2", + Input: []tf.Input{ + empty_key, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// 2D fast Fourier transform. +// +// Computes the 2-dimensional discrete Fourier transform over the inner-most +// 2 dimensions of `input`. +// +// Arguments: +// input: A complex64 tensor. +// +// Returns A complex64 tensor of the same shape as `input`. The inner-most 2 +// dimensions of `input` are replaced with their 2D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.fft2 +// @end_compatibility +func FFT2D(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "FFT2D", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Inverse 2D fast Fourier transform. // -// This op creates a mutable hash table, specifying the type of its keys and -// values. Each value must be a scalar. Data can be inserted into the table using -// the insert operations. It does not support the initialization operation. +// Computes the inverse 2-dimensional discrete Fourier transform over the +// inner-most 2 dimensions of `input`. // // Arguments: -// empty_key: The key used to represent empty key buckets internally. Must not -// be used in insert or lookup operations. -// value_dtype: Type of the table values. +// input: A complex64 tensor. // -// Returns Handle to a table. -func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, value_dtype tf.DataType, optional ...MutableDenseHashTableV2Attr) (table_handle tf.Output) { +// Returns A complex64 tensor of the same shape as `input`. The inner-most 2 +// dimensions of `input` are replaced with their inverse 2D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.ifft2 +// @end_compatibility +func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"value_dtype": value_dtype} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "MutableDenseHashTableV2", + Type: "IFFT2D", Input: []tf.Input{ - empty_key, + input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -17090,6 +18070,7 @@ func QuantizeV2RoundMode(value string) QuantizeV2Attr { // out[i] = (in[i] - min_range) * range(T) / (max_range - min_range) // if T == qint8, out[i] -= (range(T) + 1) / 2.0 // ``` +// // here `range(T) = numeric_limits::max() - numeric_limits::min()` // // *MIN_COMBINED Mode Example* @@ -17133,6 +18114,7 @@ func QuantizeV2RoundMode(value string) QuantizeV2Attr { // // We first find the range of values in our tensor. The // range we use is always centered on 0, so we find m such that +// // ```c++ // m = max(abs(input_min), abs(input_max)) // ``` @@ -17141,6 +18123,7 @@ func QuantizeV2RoundMode(value string) QuantizeV2Attr { // // Next, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`. // If T is signed, this is +// // ``` // num_bits = sizeof(T) * 8 // [min_fixed, max_fixed] = @@ -17148,16 +18131,19 @@ func QuantizeV2RoundMode(value string) QuantizeV2Attr { // ``` // // Otherwise, if T is unsigned, the fixed-point range is +// // ``` // [min_fixed, max_fixed] = [0, (1 << num_bits) - 1] // ``` // // From this we compute our scaling factor, s: +// // ```c++ // s = (max_fixed - min_fixed) / (2 * m) // ``` // // Now we can quantize the elements of our tensor: +// // ```c++ // result = round(input * s) // ``` @@ -17254,6 +18240,31 @@ func QuantizedReluX(scope *Scope, features tf.Output, max_value tf.Output, min_f return op.Output(0), op.Output(1), op.Output(2) } +// Creates a dataset that batches `batch_size` elements from `input_dataset`. +// +// Arguments: +// +// batch_size: A scalar representing the number of elements to accumulate in a batch. +// drop_remainder: A scalar representing whether the last batch should be dropped in case its size +// is smaller than desired. +// +// +func BatchDatasetV2(scope *Scope, input_dataset tf.Output, batch_size tf.Output, drop_remainder tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "BatchDatasetV2", + Input: []tf.Input{ + input_dataset, batch_size, drop_remainder, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // QuantizedConv2DAttr is an optional argument to QuantizedConv2D. type QuantizedConv2DAttr func(optionalAttr) @@ -17777,75 +18788,238 @@ func SparseCross(scope *Scope, indices []tf.Output, values []tf.Output, shapes [ return op.Output(0), op.Output(1), op.Output(2) } -// Concatenates quantized tensors along one dimension. +// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad. +type ResourceApplyProximalAdagradAttr func(optionalAttr) + +// ResourceApplyProximalAdagradUseLocking sets the optional use_locking attribute to value. +// +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceApplyProximalAdagradUseLocking(value bool) ResourceApplyProximalAdagradAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate. +// +// accum += grad * grad +// prox_v = var - lr * grad * (1 / sqrt(accum)) +// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} // // Arguments: -// concat_dim: 0-D. The dimension along which to concatenate. Must be in the -// range [0, rank(values)). -// values: The `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. -// input_mins: The minimum scalar values for each of the input tensors. -// input_maxes: The maximum scalar values for each of the input tensors. +// var_: Should be from a Variable(). +// accum: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// l1: L1 regularization. Must be a scalar. +// l2: L2 regularization. Must be a scalar. +// grad: The gradient. // -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes.The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents. -func QuantizedConcat(scope *Scope, concat_dim tf.Output, values []tf.Output, input_mins []tf.Output, input_maxes []tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) { +// Returns the created operation. +func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, optional ...ResourceApplyProximalAdagradAttr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "QuantizedConcat", + Type: "ResourceApplyProximalAdagrad", Input: []tf.Input{ - concat_dim, tf.OutputList(values), tf.OutputList(input_mins), tf.OutputList(input_maxes), + var_, accum, lr, l1, l2, grad, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2. +type MutableHashTableOfTensorsV2Attr func(optionalAttr) + +// MutableHashTableOfTensorsV2Container sets the optional container attribute to value. +// +// value: If non-empty, this table is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this table is shared under the given name across +// multiple sessions. +// If not specified, defaults to "" +func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value. +// If not specified, defaults to false +func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["use_node_name_sharing"] = value + } +} + +// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value. +// If not specified, defaults to <> +func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr { + return func(m optionalAttr) { + m["value_shape"] = value + } +} + +// Creates an empty hash table. +// +// This op creates a mutable hash table, specifying the type of its keys and +// values. Each value must be a vector. Data can be inserted into the table using +// the insert operations. It does not support the initialization operation. +// +// Arguments: +// key_dtype: Type of the table keys. +// value_dtype: Type of the table values. +// +// Returns Handle to a table. +func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MutableHashTableOfTensorsV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// The gradient operator for the SparseSlice op. +// +// This op takes in the upstream gradient w.r.t. non-empty values of +// the sliced `SparseTensor`, and outputs the gradients w.r.t. +// the non-empty values of input `SparseTensor`. +// +// Arguments: +// backprop_val_grad: 1-D. The gradient with respect to +// the non-empty values of the sliced `SparseTensor`. +// input_indices: 2-D. The `indices` of the input `SparseTensor`. +// input_start: 1-D. tensor represents the start of the slice. +// output_indices: 2-D. The `indices` of the sliced `SparseTensor`. +// +// Returns 1-D. The gradient with respect to the non-empty values of input `SparseTensor`. +func SparseSliceGrad(scope *Scope, backprop_val_grad tf.Output, input_indices tf.Output, input_start tf.Output, output_indices tf.Output) (val_grad tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSliceGrad", + Input: []tf.Input{ + backprop_val_grad, input_indices, input_start, output_indices, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Slice a `SparseTensor` based on the `start` and `size`. +// Computes the gradient of the sigmoid of `x` wrt its input. // -// For example, if the input is +// Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and +// `dy` is the corresponding input gradient. +func SigmoidGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SigmoidGrad", + Input: []tf.Input{ + y, dy, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Convert one or more images from HSV to RGB. // -// input_tensor = shape = [2, 7] -// [ a d e ] -// [b c ] +// Outputs a tensor of the same shape as the `images` tensor, containing the RGB +// value of the pixels. The output is only well defined if the value in `images` +// are in `[0,1]`. // -// Graphically the output tensors are: +// See `rgb_to_hsv` for a description of the HSV encoding. // -// sparse_slice([0, 0], [2, 4]) = shape = [2, 4] -// [ a ] -// [b c ] +// Arguments: +// images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3. // -// sparse_slice([0, 4], [2, 3]) = shape = [2, 3] -// [ d e ] -// [ ] +// Returns `images` converted to RGB. +func HSVToRGB(scope *Scope, images tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "HSVToRGB", + Input: []tf.Input{ + images, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset by applying optimizations to `input_dataset`. +// +// Creates a dataset by applying optimizations to `input_dataset`. +// +// Arguments: +// input_dataset: A variant tensor representing the input dataset. +// optimizations: A `tf.string` vector `tf.Tensor` identifying optimizations to use. +// +// +func OptimizeDataset(scope *Scope, input_dataset tf.Output, optimizations tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "OptimizeDataset", + Input: []tf.Input{ + input_dataset, optimizations, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics. // // Arguments: -// indices: 2-D tensor represents the indices of the sparse tensor. -// values: 1-D tensor represents the values of the sparse tensor. -// shape: 1-D. tensor represents the shape of the sparse tensor. -// start: 1-D. tensor represents the start of the slice. -// size: 1-D. tensor represents the size of the slice. -// output indices: A list of 1-D tensors represents the indices of the output -// sparse tensors. +// tree_ensemble_handle: Handle to the tree ensemble. // -// Returns A list of 1-D tensors represents the values of the output sparse -// tensors.A list of 1-D tensors represents the shape of the output sparse -// tensors. -func SparseSlice(scope *Scope, indices tf.Output, values tf.Output, shape tf.Output, start tf.Output, size tf.Output) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { +// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest +// layer. +func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "SparseSlice", + Type: "BoostedTreesGetEnsembleStates", Input: []tf.Input{ - indices, values, shape, start, size, + tree_ensemble_handle, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } // Returns the element-wise min of two SparseTensors. @@ -17978,73 +19152,47 @@ func TakeManySparseFromTensorsMap(scope *Scope, sparse_handles tf.Output, dtype return op.Output(0), op.Output(1), op.Output(2) } -// MaxPoolAttr is an optional argument to MaxPool. -type MaxPoolAttr func(optionalAttr) - -// MaxPoolDataFormat sets the optional data_format attribute to value. +// Assigns a new value to a variable. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolDataFormat(value string) MaxPoolAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Performs max pooling on the input. +// Any ReadVariableOp with a control dependency on this op is guaranteed to return +// this value or a subsequent newer value of the variable. // // Arguments: -// input: 4-D input to pool over. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. +// resource: handle to the resource in which to store the variable. +// value: the value to set the new tensor to use. // -// Returns The max pooled output tensor. -func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolAttr) (output tf.Output) { +// Returns the created operation. +func AssignVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "MaxPool", + Type: "AssignVariableOp", Input: []tf.Input{ - input, + resource, value, }, - Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Assigns a new value to a variable. -// -// Any ReadVariableOp with a control dependency on this op is guaranteed to return -// this value or a subsequent newer value of the variable. +// Strip leading and trailing whitespaces from the Tensor. // // Arguments: -// resource: handle to the resource in which to store the variable. -// value: the value to set the new tensor to use. +// input: A string `Tensor` of any shape. // -// Returns the created operation. -func AssignVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf.Operation) { +// Returns A string `Tensor` of the same shape as the input. +func StringStrip(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "AssignVariableOp", + Type: "StringStrip", Input: []tf.Input{ - resource, value, + input, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } // Returns a tensor of ones with the same shape and type as x. @@ -18098,9 +19246,12 @@ func SparseFillEmptyRowsGrad(scope *Scope, reverse_index_map tf.Output, grad_val } // Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)` +// // if < 0, `scale * features` otherwise. // -// Assumes weights to have zero mean and variance 1.0 / fan_in. +// To be used together with +// `initializer = tf.variance_scaling_initializer(factor=1.0, mode='FAN_IN')`. +// For correct dropout, use `tf.contrib.nn.alpha_dropout`. // // See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) func Selu(scope *Scope, features tf.Output) (activations tf.Output) { @@ -18195,6 +19346,70 @@ func LogMatrixDeterminant(scope *Scope, input tf.Output) (sign tf.Output, log_ab return op.Output(0), op.Output(1) } +// Copy a tensor setting everything outside a central band in each innermost matrix +// +// to zero. +// +// The `band` part is computed as follows: +// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a +// tensor with the same shape where +// +// `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`. +// +// The indicator function +// +// `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && +// (num_upper < 0 || (n-m) <= num_upper)`. +// +// For example: +// +// ``` +// # if 'input' is [[ 0, 1, 2, 3] +// [-1, 0, 1, 2] +// [-2, -1, 0, 1] +// [-3, -2, -1, 0]], +// +// tf.matrix_band_part(input, 1, -1) ==> [[ 0, 1, 2, 3] +// [-1, 0, 1, 2] +// [ 0, -1, 0, 1] +// [ 0, 0, -1, 0]], +// +// tf.matrix_band_part(input, 2, 1) ==> [[ 0, 1, 0, 0] +// [-1, 0, 1, 0] +// [-2, -1, 0, 1] +// [ 0, -2, -1, 0]] +// ``` +// +// Useful special cases: +// +// ``` +// tf.matrix_band_part(input, 0, -1) ==> Upper triangular part. +// tf.matrix_band_part(input, -1, 0) ==> Lower triangular part. +// tf.matrix_band_part(input, 0, 0) ==> Diagonal. +// ``` +// +// Arguments: +// input: Rank `k` tensor. +// num_lower: 0-D tensor. Number of subdiagonals to keep. If negative, keep entire +// lower triangle. +// num_upper: 0-D tensor. Number of superdiagonals to keep. If negative, keep +// entire upper triangle. +// +// Returns Rank `k` tensor of the same shape as input. The extracted banded tensor. +func MatrixBandPart(scope *Scope, input tf.Output, num_lower tf.Output, num_upper tf.Output) (band tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MatrixBandPart", + Input: []tf.Input{ + input, num_lower, num_upper, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // SumAttr is an optional argument to Sum. type SumAttr func(optionalAttr) @@ -18539,134 +19754,71 @@ func SdcaOptimizerAdaptative(value bool) SdcaOptimizerAttr { // // [Adding vs. Averaging in Distributed Primal-Dual Optimization](http://arxiv.org/abs/1502.03508).
// Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, -// Peter Richtarik, Martin Takac. 2015 -// -// [Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).
-// Dominik Csiba, Zheng Qu, Peter Richtarik. 2015 -// -// Arguments: -// sparse_example_indices: a list of vectors which contain example indices. -// sparse_feature_indices: a list of vectors which contain feature indices. -// sparse_feature_values: a list of vectors which contains feature value -// associated with each feature group. -// dense_features: a list of matrices which contains the dense feature values. -// example_weights: a vector which contains the weight associated with each -// example. -// example_labels: a vector which contains the label/target associated with each -// example. -// sparse_indices: a list of vectors where each value is the indices which has -// corresponding weights in sparse_weights. This field maybe omitted for the -// dense approach. -// sparse_weights: a list of vectors where each value is the weight associated with -// a sparse feature group. -// dense_weights: a list of vectors where the values are the weights associated -// with a dense feature group. -// example_state_data: a list of vectors containing the example state data. -// loss_type: Type of the primal loss. Currently SdcaSolver supports logistic, -// squared and hinge losses. -// l1: Symmetric l1 regularization strength. -// l2: Symmetric l2 regularization strength. -// num_loss_partitions: Number of partitions of the global loss function. -// num_inner_iterations: Number of iterations per mini-batch. -// -// Returns a list of vectors containing the updated example state -// data.a list of vectors where each value is the delta -// weights associated with a sparse feature group.a list of vectors where the values are the delta -// weights associated with a dense feature group. -func SdcaOptimizer(scope *Scope, sparse_example_indices []tf.Output, sparse_feature_indices []tf.Output, sparse_feature_values []tf.Output, dense_features []tf.Output, example_weights tf.Output, example_labels tf.Output, sparse_indices []tf.Output, sparse_weights []tf.Output, dense_weights []tf.Output, example_state_data tf.Output, loss_type string, l1 float32, l2 float32, num_loss_partitions int64, num_inner_iterations int64, optional ...SdcaOptimizerAttr) (out_example_state_data tf.Output, out_delta_sparse_weights []tf.Output, out_delta_dense_weights []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"loss_type": loss_type, "l1": l1, "l2": l2, "num_loss_partitions": num_loss_partitions, "num_inner_iterations": num_inner_iterations} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "SdcaOptimizer", - Input: []tf.Input{ - tf.OutputList(sparse_example_indices), tf.OutputList(sparse_feature_indices), tf.OutputList(sparse_feature_values), tf.OutputList(dense_features), example_weights, example_labels, tf.OutputList(sparse_indices), tf.OutputList(sparse_weights), tf.OutputList(dense_weights), example_state_data, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - out_example_state_data = op.Output(idx) - if out_delta_sparse_weights, idx, err = makeOutputList(op, idx, "out_delta_sparse_weights"); err != nil { - scope.UpdateErr("SdcaOptimizer", err) - return - } - if out_delta_dense_weights, idx, err = makeOutputList(op, idx, "out_delta_dense_weights"); err != nil { - scope.UpdateErr("SdcaOptimizer", err) - return - } - return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights -} - -// SparseMatMulAttr is an optional argument to SparseMatMul. -type SparseMatMulAttr func(optionalAttr) - -// SparseMatMulTransposeA sets the optional transpose_a attribute to value. -// If not specified, defaults to false -func SparseMatMulTransposeA(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["transpose_a"] = value - } -} - -// SparseMatMulTransposeB sets the optional transpose_b attribute to value. -// If not specified, defaults to false -func SparseMatMulTransposeB(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["transpose_b"] = value - } -} - -// SparseMatMulAIsSparse sets the optional a_is_sparse attribute to value. -// If not specified, defaults to false -func SparseMatMulAIsSparse(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["a_is_sparse"] = value - } -} - -// SparseMatMulBIsSparse sets the optional b_is_sparse attribute to value. -// If not specified, defaults to false -func SparseMatMulBIsSparse(value bool) SparseMatMulAttr { - return func(m optionalAttr) { - m["b_is_sparse"] = value - } -} - -// Multiply matrix "a" by matrix "b". +// Peter Richtarik, Martin Takac. 2015 // -// The inputs must be two-dimensional matrices and the inner dimension of "a" must -// match the outer dimension of "b". This op is optimized for the case where at -// least one of "a" or "b" is sparse. The breakeven for using this versus a dense -// matrix multiply on one platform was 30% zero values in the sparse matrix. +// [Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).
+// Dominik Csiba, Zheng Qu, Peter Richtarik. 2015 // -// The gradient computation of this operation will only take advantage of sparsity -// in the input gradient when that gradient comes from a Relu. -func SparseMatMul(scope *Scope, a tf.Output, b tf.Output, optional ...SparseMatMulAttr) (product tf.Output) { +// Arguments: +// sparse_example_indices: a list of vectors which contain example indices. +// sparse_feature_indices: a list of vectors which contain feature indices. +// sparse_feature_values: a list of vectors which contains feature value +// associated with each feature group. +// dense_features: a list of matrices which contains the dense feature values. +// example_weights: a vector which contains the weight associated with each +// example. +// example_labels: a vector which contains the label/target associated with each +// example. +// sparse_indices: a list of vectors where each value is the indices which has +// corresponding weights in sparse_weights. This field maybe omitted for the +// dense approach. +// sparse_weights: a list of vectors where each value is the weight associated with +// a sparse feature group. +// dense_weights: a list of vectors where the values are the weights associated +// with a dense feature group. +// example_state_data: a list of vectors containing the example state data. +// loss_type: Type of the primal loss. Currently SdcaSolver supports logistic, +// squared and hinge losses. +// l1: Symmetric l1 regularization strength. +// l2: Symmetric l2 regularization strength. +// num_loss_partitions: Number of partitions of the global loss function. +// num_inner_iterations: Number of iterations per mini-batch. +// +// Returns a list of vectors containing the updated example state +// data.a list of vectors where each value is the delta +// weights associated with a sparse feature group.a list of vectors where the values are the delta +// weights associated with a dense feature group. +func SdcaOptimizer(scope *Scope, sparse_example_indices []tf.Output, sparse_feature_indices []tf.Output, sparse_feature_values []tf.Output, dense_features []tf.Output, example_weights tf.Output, example_labels tf.Output, sparse_indices []tf.Output, sparse_weights []tf.Output, dense_weights []tf.Output, example_state_data tf.Output, loss_type string, l1 float32, l2 float32, num_loss_partitions int64, num_inner_iterations int64, optional ...SdcaOptimizerAttr) (out_example_state_data tf.Output, out_delta_sparse_weights []tf.Output, out_delta_dense_weights []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"loss_type": loss_type, "l1": l1, "l2": l2, "num_loss_partitions": num_loss_partitions, "num_inner_iterations": num_inner_iterations} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "SparseMatMul", + Type: "SdcaOptimizer", Input: []tf.Input{ - a, b, + tf.OutputList(sparse_example_indices), tf.OutputList(sparse_feature_indices), tf.OutputList(sparse_feature_values), tf.OutputList(dense_features), example_weights, example_labels, tf.OutputList(sparse_indices), tf.OutputList(sparse_weights), tf.OutputList(dense_weights), example_state_data, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + out_example_state_data = op.Output(idx) + if out_delta_sparse_weights, idx, err = makeOutputList(op, idx, "out_delta_sparse_weights"); err != nil { + scope.UpdateErr("SdcaOptimizer", err) + return + } + if out_delta_dense_weights, idx, err = makeOutputList(op, idx, "out_delta_dense_weights"); err != nil { + scope.UpdateErr("SdcaOptimizer", err) + return + } + return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights } // ShapeAttr is an optional argument to Shape. @@ -18847,7 +19999,7 @@ func MatrixTriangularSolveLower(value bool) MatrixTriangularSolveAttr { // adjoint. // // @compatibility(numpy) -// Equivalent to np.linalg.triangular_solve +// Equivalent to scipy.linalg.solve_triangular // @end_compatibility // If not specified, defaults to false func MatrixTriangularSolveAdjoint(value bool) MatrixTriangularSolveAttr { @@ -19372,12 +20524,211 @@ func RandomUniformIntSeed2(value int64) RandomUniformIntAttr { // smaller than the range of the output (either `2^32` or `2^64`). // // Arguments: -// shape: The shape of the output tensor. -// minval: 0-D. Inclusive lower bound on the generated integers. -// maxval: 0-D. Exclusive upper bound on the generated integers. +// shape: The shape of the output tensor. +// minval: 0-D. Inclusive lower bound on the generated integers. +// maxval: 0-D. Exclusive upper bound on the generated integers. +// +// Returns A tensor of the specified shape filled with uniform random integers. +func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf.Output, optional ...RandomUniformIntAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RandomUniformInt", + Input: []tf.Input{ + shape, minval, maxval, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Compute the upper regularized incomplete Gamma function `Q(a, x)`. +// +// The upper regularized incomplete Gamma function is defined as: +// +// \\(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\\) +// +// where +// +// \\(Gamma(a, x) = int_{x}^{\infty} t^{a-1} exp(-t) dt\\) +// +// is the upper incomplete Gama function. +// +// Note, above `P(a, x)` (`Igamma`) is the lower regularized complete +// Gamma function. +func Igammac(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Igammac", + Input: []tf.Input{ + a, x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes the sum along sparse segments of a tensor divided by the sqrt of N. +// +// N is the size of the segment being reduced. +// +// Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is +// misisng, the `output` tensor at that position will be zeroed. +// +// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of +// segments. +// +// Arguments: +// +// indices: A 1-D tensor. Has same rank as `segment_ids`. +// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. +// num_segments: Should equal the number of distinct segment IDs. +// +// Returns Has same shape as data, except for dimension 0 which +// has size `k`, the number of segments. +func SparseSegmentSqrtNWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSegmentSqrtNWithNumSegments", + Input: []tf.Input{ + data, indices, segment_ids, num_segments, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes gradients for SparseSegmentSqrtN. +// +// Returns tensor "output" with same shape as grad, except for dimension 0 whose +// value is output_dim0. +// +// Arguments: +// grad: gradient propagated to the SparseSegmentSqrtN op. +// indices: indices passed to the corresponding SparseSegmentSqrtN op. +// segment_ids: segment_ids passed to the corresponding SparseSegmentSqrtN op. +// output_dim0: dimension 0 of "data" passed to SparseSegmentSqrtN op. +func SparseSegmentSqrtNGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseSegmentSqrtNGrad", + Input: []tf.Input{ + grad, indices, segment_ids, output_dim0, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// LRNGradAttr is an optional argument to LRNGrad. +type LRNGradAttr func(optionalAttr) + +// LRNGradDepthRadius sets the optional depth_radius attribute to value. +// +// value: A depth radius. +// If not specified, defaults to 5 +func LRNGradDepthRadius(value int64) LRNGradAttr { + return func(m optionalAttr) { + m["depth_radius"] = value + } +} + +// LRNGradBias sets the optional bias attribute to value. +// +// value: An offset (usually > 0 to avoid dividing by 0). +// If not specified, defaults to 1 +func LRNGradBias(value float32) LRNGradAttr { + return func(m optionalAttr) { + m["bias"] = value + } +} + +// LRNGradAlpha sets the optional alpha attribute to value. +// +// value: A scale factor, usually positive. +// If not specified, defaults to 1 +func LRNGradAlpha(value float32) LRNGradAttr { + return func(m optionalAttr) { + m["alpha"] = value + } +} + +// LRNGradBeta sets the optional beta attribute to value. +// +// value: An exponent. +// If not specified, defaults to 0.5 +func LRNGradBeta(value float32) LRNGradAttr { + return func(m optionalAttr) { + m["beta"] = value + } +} + +// Gradients for Local Response Normalization. +// +// Arguments: +// input_grads: 4-D with shape `[batch, height, width, channels]`. +// input_image: 4-D with shape `[batch, height, width, channels]`. +// output_image: 4-D with shape `[batch, height, width, channels]`. +// +// Returns The gradients for LRN. +func LRNGrad(scope *Scope, input_grads tf.Output, input_image tf.Output, output_image tf.Output, optional ...LRNGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "LRNGrad", + Input: []tf.Input{ + input_grads, input_image, output_image, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// AnyAttr is an optional argument to Any. +type AnyAttr func(optionalAttr) + +// AnyKeepDims sets the optional keep_dims attribute to value. +// +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func AnyKeepDims(value bool) AnyAttr { + return func(m optionalAttr) { + m["keep_dims"] = value + } +} + +// Computes the "logical or" of elements across dimensions of a tensor. +// +// Reduces `input` along the dimensions given in `axis`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `axis`. If `keep_dims` is true, the reduced dimensions are +// retained with length 1. +// +// Arguments: +// input: The tensor to reduce. +// axis: The dimensions to reduce. Must be in the range +// `[-rank(input), rank(input))`. // -// Returns A tensor of the specified shape filled with uniform random integers. -func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf.Output, optional ...RandomUniformIntAttr) (output tf.Output) { +// Returns The reduced tensor. +func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -19386,9 +20737,9 @@ func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf a(attrs) } opspec := tf.OpSpec{ - Type: "RandomUniformInt", + Type: "Any", Input: []tf.Input{ - shape, minval, maxval, + input, axis, }, Attrs: attrs, } @@ -19396,119 +20747,108 @@ func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf return op.Output(0) } -// RandomShuffleAttr is an optional argument to RandomShuffle. -type RandomShuffleAttr func(optionalAttr) - -// RandomShuffleSeed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomShuffleSeed(value int64) RandomShuffleAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomShuffleSeed2 sets the optional seed2 attribute to value. +// Creates a sequence of numbers. // -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomShuffleSeed2(value int64) RandomShuffleAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Randomly shuffles a tensor along its first dimension. +// This operation creates a sequence of numbers that begins at `start` and +// extends by increments of `delta` up to but not including `limit`. // -// The tensor is shuffled along dimension 0, such that each `value[j]` is mapped -// to one and only one `output[i]`. For example, a mapping that might occur for a -// 3x2 tensor is: +// For example: // // ``` -// [[1, 2], [[5, 6], -// [3, 4], ==> [1, 2], -// [5, 6]] [3, 4]] +// # 'start' is 3 +// # 'limit' is 18 +// # 'delta' is 3 +// tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] // ``` // // Arguments: -// value: The tensor to be shuffled. +// start: 0-D (scalar). First entry in the sequence. +// limit: 0-D (scalar). Upper limit of sequence, exclusive. +// delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`. // -// Returns A tensor of same shape and type as `value`, shuffled along its first -// dimension. -func RandomShuffle(scope *Scope, value tf.Output, optional ...RandomShuffleAttr) (output tf.Output) { +// Returns 1-D. +func Range(scope *Scope, start tf.Output, limit tf.Output, delta tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "RandomShuffle", + Type: "Range", Input: []tf.Input{ - value, + start, limit, delta, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// OrderedMapIncompleteSizeAttr is an optional argument to OrderedMapIncompleteSize. -type OrderedMapIncompleteSizeAttr func(optionalAttr) +// DestroyResourceOpAttr is an optional argument to DestroyResourceOp. +type DestroyResourceOpAttr func(optionalAttr) -// OrderedMapIncompleteSizeCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// DestroyResourceOpIgnoreLookupError sets the optional ignore_lookup_error attribute to value. // -// REQUIRES: value >= 0 -func OrderedMapIncompleteSizeCapacity(value int64) OrderedMapIncompleteSizeAttr { +// value: whether to ignore the error when the resource +// doesn't exist. +// If not specified, defaults to true +func DestroyResourceOpIgnoreLookupError(value bool) DestroyResourceOpAttr { return func(m optionalAttr) { - m["capacity"] = value + m["ignore_lookup_error"] = value } } -// OrderedMapIncompleteSizeMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// Deletes the resource specified by the handle. // -// REQUIRES: value >= 0 -func OrderedMapIncompleteSizeMemoryLimit(value int64) OrderedMapIncompleteSizeAttr { - return func(m optionalAttr) { - m["memory_limit"] = value +// All subsequent operations using the resource will result in a NotFound +// error status. +// +// Arguments: +// resource: handle to the resource to delete. +// +// Returns the created operation. +func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyResourceOpAttr) (o *tf.Operation) { + if scope.Err() != nil { + return } -} - -// OrderedMapIncompleteSizeContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func OrderedMapIncompleteSizeContainer(value string) OrderedMapIncompleteSizeAttr { - return func(m optionalAttr) { - m["container"] = value + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) } -} - -// OrderedMapIncompleteSizeSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func OrderedMapIncompleteSizeSharedName(value string) OrderedMapIncompleteSizeAttr { - return func(m optionalAttr) { - m["shared_name"] = value + opspec := tf.OpSpec{ + Type: "DestroyResourceOp", + Input: []tf.Input{ + resource, + }, + Attrs: attrs, } + return scope.AddOperation(opspec) } -// Op returns the number of incomplete elements in the underlying container. -func OrderedMapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...OrderedMapIncompleteSizeAttr) (size tf.Output) { +// Generates values in an interval. +// +// A sequence of `num` evenly-spaced values are generated beginning at `start`. +// If `num > 1`, the values in the sequence increase by `stop - start / num - 1`, +// so that the last one is exactly `stop`. +// +// For example: +// +// ``` +// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] +// ``` +// +// Arguments: +// start: 0-D tensor. First entry in the range. +// stop: 0-D tensor. Last entry in the range. +// num: 0-D tensor. Number of values to generate. +// +// Returns 1-D. The generated values. +func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "OrderedMapIncompleteSize", - - Attrs: attrs, + Type: "LinSpace", + Input: []tf.Input{ + start, stop, num, + }, } op := scope.AddOperation(opspec) return op.Output(0) @@ -20327,83 +21667,6 @@ func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x return op.Output(0), op.Output(1), op.Output(2) } -// MfccAttr is an optional argument to Mfcc. -type MfccAttr func(optionalAttr) - -// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value. -// -// value: The highest frequency to use when calculating the -// ceptstrum. -// If not specified, defaults to 4000 -func MfccUpperFrequencyLimit(value float32) MfccAttr { - return func(m optionalAttr) { - m["upper_frequency_limit"] = value - } -} - -// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value. -// -// value: The lowest frequency to use when calculating the -// ceptstrum. -// If not specified, defaults to 20 -func MfccLowerFrequencyLimit(value float32) MfccAttr { - return func(m optionalAttr) { - m["lower_frequency_limit"] = value - } -} - -// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value. -// -// value: Resolution of the Mel bank used internally. -// If not specified, defaults to 40 -func MfccFilterbankChannelCount(value int64) MfccAttr { - return func(m optionalAttr) { - m["filterbank_channel_count"] = value - } -} - -// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value. -// -// value: How many output channels to produce per time slice. -// If not specified, defaults to 13 -func MfccDctCoefficientCount(value int64) MfccAttr { - return func(m optionalAttr) { - m["dct_coefficient_count"] = value - } -} - -// Transforms a spectrogram into a form that's useful for speech recognition. -// -// Mel Frequency Cepstral Coefficients are a way of representing audio data that's -// been effective as an input feature for machine learning. They are created by -// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the -// higher frequencies that are less significant to the human ear. They have a long -// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum -// is a good resource to learn more. -// -// Arguments: -// spectrogram: Typically produced by the Spectrogram op, with magnitude_squared -// set to true. -// sample_rate: How many samples per second the source audio used. -func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Mfcc", - Input: []tf.Input{ - spectrogram, sample_rate, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Given a quantized tensor described by (input, input_min, input_max), outputs a // // range that covers the actual values present in that tensor. This op is @@ -20737,25 +22000,56 @@ func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (conten // The tensor `values` must be of the type of the table values. // // Arguments: -// table_handle: Handle to the table. -// keys: Any shape. Keys to look up. -// values: Values to associate with keys. +// table_handle: Handle to the table. +// keys: Any shape. Keys to look up. +// values: Values to associate with keys. +// +// Returns the created operation. +func LookupTableInsertV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LookupTableInsertV2", + Input: []tf.Input{ + table_handle, keys, values, + }, + } + return scope.AddOperation(opspec) +} + +// Creates a dataset that batches and pads `batch_size` elements from the input. +// +// Arguments: // -// Returns the created operation. -func LookupTableInsertV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) { +// batch_size: A scalar representing the number of elements to accumulate in a +// batch. +// padded_shapes: A list of int64 tensors representing the desired padded shapes +// of the corresponding output components. These shapes may be partially +// specified, using `-1` to indicate that a particular dimension should be +// padded to the maximum size of all batch elements. +// padding_values: A list of scalars containing the padding value to use for +// each of the outputs. +// drop_remainder: A scalar representing whether the last batch should be dropped in case its size +// is smaller than desired. +// +func PaddedBatchDatasetV2(scope *Scope, input_dataset tf.Output, batch_size tf.Output, padded_shapes []tf.Output, padding_values []tf.Output, drop_remainder tf.Output, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "LookupTableInsertV2", + Type: "PaddedBatchDatasetV2", Input: []tf.Input{ - table_handle, keys, values, + input_dataset, batch_size, tf.OutputList(padded_shapes), tf.OutputList(padding_values), drop_remainder, }, + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Returns element-wise smallest integer in not less than x. +// Returns element-wise smallest integer not less than x. func Ceil(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return @@ -22084,7 +23378,7 @@ func TensorListSetItem(scope *Scope, input_handle tf.Output, index tf.Output, it // Computes the matrix exponential of one or more square matrices: // -// exp(A) = \sum_{n=0}^\infty A^n/n! +// \\(exp(A) = \sum_{n=0}^\infty A^n/n!\\) // // The exponential is computed using a combination of the scaling and squaring // method and the Pade approximation. Details can be founds in: @@ -22464,6 +23758,28 @@ func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...Matr return op.Output(0) } +// Returns a serialized GraphDef representing `input_dataset`. +// +// Returns a graph representation for `input_dataset`. +// +// Arguments: +// input_dataset: A variant tensor representing the dataset to return the graph representation for. +// +// Returns The graph representation of the dataset (as serialized GraphDef). +func DatasetToGraph(scope *Scope, input_dataset tf.Output) (graph tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DatasetToGraph", + Input: []tf.Input{ + input_dataset, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // SvdAttr is an optional argument to Svd. type SvdAttr func(optionalAttr) @@ -23460,6 +24776,145 @@ func ReaderReadUpToV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Out return op.Output(0), op.Output(1) } +// BatchAttr is an optional argument to Batch. +type BatchAttr func(optionalAttr) + +// BatchMaxEnqueuedBatches sets the optional max_enqueued_batches attribute to value. +// If not specified, defaults to 10 +func BatchMaxEnqueuedBatches(value int64) BatchAttr { + return func(m optionalAttr) { + m["max_enqueued_batches"] = value + } +} + +// BatchAllowedBatchSizes sets the optional allowed_batch_sizes attribute to value. +// If not specified, defaults to <> +func BatchAllowedBatchSizes(value []int64) BatchAttr { + return func(m optionalAttr) { + m["allowed_batch_sizes"] = value + } +} + +// BatchContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func BatchContainer(value string) BatchAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// BatchSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func BatchSharedName(value string) BatchAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// BatchBatchingQueue sets the optional batching_queue attribute to value. +// If not specified, defaults to "" +func BatchBatchingQueue(value string) BatchAttr { + return func(m optionalAttr) { + m["batching_queue"] = value + } +} + +// Batches all input tensors nondeterministically. +// +// When many instances of this Op are being run concurrently with the same +// container/shared_name in the same device, some will output zero-shaped Tensors +// and others will output Tensors of size up to max_batch_size. +// +// All Tensors in in_tensors are batched together (so, for example, labels and +// features should be batched with a single instance of this operation. +// +// Each invocation of batch emits an `id` scalar which will be used to identify +// this particular invocation when doing unbatch or its gradient. +// +// Each op which emits a non-empty batch will also emit a non-empty batch_index +// Tensor, which, is a [K, 3] matrix where each row contains the invocation's id, +// start, and length of elements of each set of Tensors present in batched_tensors. +// +// Batched tensors are concatenated along the first dimension, and all tensors in +// in_tensors must have the first dimension of the same size. +// +// in_tensors: The tensors to be batched. +// num_batch_threads: Number of scheduling threads for processing batches of work. +// Determines the number of batches processed in parallel. +// max_batch_size: Batch sizes will never be bigger than this. +// batch_timeout_micros: Maximum number of microseconds to wait before outputting +// an incomplete batch. +// allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, does +// nothing. Otherwise, supplies a list of batch sizes, causing the op to pad +// batches up to one of those sizes. The entries must increase monotonically, and +// the final entry must equal max_batch_size. +// grad_timeout_micros: The timeout to use for the gradient. See Unbatch. +// batched_tensors: Either empty tensors or a batch of concatenated Tensors. +// batch_index: If out_tensors is non-empty, has information to invert it. +// container: Controls the scope of sharing of this batch. +// id: always contains a scalar with a unique ID for this invocation of Batch. +// shared_name: Concurrently running instances of batch in the same device with the +// same container and shared_name will batch their elements together. If left +// empty, the op name will be used as the shared name. +// T: the types of tensors to be batched. +func Batch(scope *Scope, in_tensors []tf.Output, num_batch_threads int64, max_batch_size int64, batch_timeout_micros int64, grad_timeout_micros int64, optional ...BatchAttr) (batched_tensors []tf.Output, batch_index tf.Output, id tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_batch_threads": num_batch_threads, "max_batch_size": max_batch_size, "batch_timeout_micros": batch_timeout_micros, "grad_timeout_micros": grad_timeout_micros} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Batch", + Input: []tf.Input{ + tf.OutputList(in_tensors), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if batched_tensors, idx, err = makeOutputList(op, idx, "batched_tensors"); err != nil { + scope.UpdateErr("Batch", err) + return + } + batch_index = op.Output(idx) + id = op.Output(idx) + return batched_tensors, batch_index, id +} + +// Adjust the hue of one or more images. +// +// `images` is a tensor of at least 3 dimensions. The last dimension is +// interpretted as channels, and must be three. +// +// The input image is considered in the RGB colorspace. Conceptually, the RGB +// colors are first mapped into HSV. A delta is then applied all the hue values, +// and then remapped back to RGB colorspace. +// +// Arguments: +// images: Images to adjust. At least 3-D. +// delta: A float delta to add to the hue. +// +// Returns The hue-adjusted image or images. +func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "AdjustHue", + Input: []tf.Input{ + images, delta, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ResourceApplyAdamAttr is an optional argument to ResourceApplyAdam. type ResourceApplyAdamAttr func(optionalAttr) @@ -23487,10 +24942,10 @@ func ResourceApplyAdamUseNesterov(value bool) ResourceApplyAdamAttr { // Update '*var' according to the Adam algorithm. // -// lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) -// m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t -// v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t -// variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon) +// $$lr_t := \text{learning_rate} * \sqrt{(1 - beta_2^t) / (1 - beta_1^t)}$$ +// $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ +// $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ +// $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$ // // Arguments: // var_: Should be from a Variable(). @@ -23827,126 +25282,61 @@ func DecodePngDtype(value tf.DataType) DecodePngAttr { // * 3: output an RGB image. // * 4: output an RGBA image. // -// If needed, the PNG-encoded image is transformed to match the requested number -// of color channels. -// -// This op also supports decoding JPEGs and non-animated GIFs since the interface -// is the same, though it is cleaner to use `tf.image.decode_image`. -// -// Arguments: -// contents: 0-D. The PNG-encoded image. -// -// Returns 3-D with shape `[height, width, channels]`. -func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (image tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DecodePng", - Input: []tf.Input{ - contents, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Decode the first frame of a GIF-encoded image to a uint8 tensor. -// -// GIF with frame or transparency compression are not supported -// convert animated GIF from compressed to uncompressed by: -// -// convert $src.gif -coalesce $dst.gif -// -// This op also supports decoding JPEGs and PNGs, though it is cleaner to use -// `tf.image.decode_image`. -// -// Arguments: -// contents: 0-D. The GIF-encoded image. -// -// Returns 4-D with shape `[num_frames, height, width, 3]`. RGB order -func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DecodeGif", - Input: []tf.Input{ - contents, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes the gradient of the sigmoid of `x` wrt its input. -// -// Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and -// `dy` is the corresponding input gradient. -func SigmoidGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SigmoidGrad", - Input: []tf.Input{ - y, dy, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Convert one or more images from HSV to RGB. -// -// Outputs a tensor of the same shape as the `images` tensor, containing the RGB -// value of the pixels. The output is only well defined if the value in `images` -// are in `[0,1]`. -// -// See `rgb_to_hsv` for a description of the HSV encoding. +// If needed, the PNG-encoded image is transformed to match the requested number +// of color channels. +// +// This op also supports decoding JPEGs and non-animated GIFs since the interface +// is the same, though it is cleaner to use `tf.image.decode_image`. // // Arguments: -// images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3. +// contents: 0-D. The PNG-encoded image. // -// Returns `images` converted to RGB. -func HSVToRGB(scope *Scope, images tf.Output) (output tf.Output) { +// Returns 3-D with shape `[height, width, channels]`. +func DecodePng(scope *Scope, contents tf.Output, optional ...DecodePngAttr) (image tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "HSVToRGB", + Type: "DecodePng", Input: []tf.Input{ - images, + contents, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics. +// Decode the first frame of a GIF-encoded image to a uint8 tensor. +// +// GIF with frame or transparency compression are not supported +// convert animated GIF from compressed to uncompressed by: +// +// convert $src.gif -coalesce $dst.gif +// +// This op also supports decoding JPEGs and PNGs, though it is cleaner to use +// `tf.image.decode_image`. // // Arguments: -// tree_ensemble_handle: Handle to the tree ensemble. +// contents: 0-D. The GIF-encoded image. // -// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest -// layer. -func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) { +// Returns 4-D with shape `[num_frames, height, width, 3]`. RGB order +func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "BoostedTreesGetEnsembleStates", + Type: "DecodeGif", Input: []tf.Input{ - tree_ensemble_handle, + contents, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) + return op.Output(0) } // Gets the next output from the given iterator. @@ -24528,10 +25918,124 @@ func NonMaxSuppressionV2(scope *Scope, boxes tf.Output, scores tf.Output, max_ou return op.Output(0) } +// Greedily selects a subset of bounding boxes in descending order of score, +// +// pruning away boxes that have high intersection-over-union (IOU) overlap +// with previously selected boxes. Bounding boxes with score less than +// `score_threshold` are removed. Bounding boxes are supplied as +// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +// diagonal pair of box corners and the coordinates can be provided as normalized +// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +// is agnostic to where the origin is in the coordinate system and more +// generally is invariant to orthogonal transformations and translations +// of the coordinate system; thus translating or reflections of the coordinate +// system result in the same boxes being selected by the algorithm. +// The output of this operation is a set of integers indexing into the input +// collection of bounding boxes representing the selected boxes. The bounding +// box coordinates corresponding to the selected indices can then be obtained +// using the `tf.gather operation`. For example: +// selected_indices = tf.image.non_max_suppression_v2( +// boxes, scores, max_output_size, iou_threshold, score_threshold) +// selected_boxes = tf.gather(boxes, selected_indices) +// +// Arguments: +// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. +// scores: A 1-D float tensor of shape `[num_boxes]` representing a single +// score corresponding to each box (each row of boxes). +// max_output_size: A scalar integer tensor representing the maximum number of +// boxes to be selected by non max suppression. +// iou_threshold: A 0-D float tensor representing the threshold for deciding whether +// boxes overlap too much with respect to IOU. +// score_threshold: A 0-D float tensor representing the threshold for deciding when to remove +// boxes based on score. +// +// Returns A 1-D integer tensor of shape `[M]` representing the selected +// indices from the boxes tensor, where `M <= max_output_size`. +func NonMaxSuppressionV3(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output, score_threshold tf.Output) (selected_indices tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "NonMaxSuppressionV3", + Input: []tf.Input{ + boxes, scores, max_output_size, iou_threshold, score_threshold, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// NonMaxSuppressionV4Attr is an optional argument to NonMaxSuppressionV4. +type NonMaxSuppressionV4Attr func(optionalAttr) + +// NonMaxSuppressionV4PadToMaxOutputSize sets the optional pad_to_max_output_size attribute to value. +// +// value: If true, the output `selected_indices` is padded to be of length +// `max_output_size`. Defaults to false. +// If not specified, defaults to false +func NonMaxSuppressionV4PadToMaxOutputSize(value bool) NonMaxSuppressionV4Attr { + return func(m optionalAttr) { + m["pad_to_max_output_size"] = value + } +} + +// Greedily selects a subset of bounding boxes in descending order of score, +// +// pruning away boxes that have high intersection-over-union (IOU) overlap +// with previously selected boxes. Bounding boxes with score less than +// `score_threshold` are removed. Bounding boxes are supplied as +// [y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +// diagonal pair of box corners and the coordinates can be provided as normalized +// (i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +// is agnostic to where the origin is in the coordinate system and more +// generally is invariant to orthogonal transformations and translations +// of the coordinate system; thus translating or reflections of the coordinate +// system result in the same boxes being selected by the algorithm. +// The output of this operation is a set of integers indexing into the input +// collection of bounding boxes representing the selected boxes. The bounding +// box coordinates corresponding to the selected indices can then be obtained +// using the `tf.gather operation`. For example: +// selected_indices = tf.image.non_max_suppression_v2( +// boxes, scores, max_output_size, iou_threshold, score_threshold) +// selected_boxes = tf.gather(boxes, selected_indices) +// +// Arguments: +// boxes: A 2-D float tensor of shape `[num_boxes, 4]`. +// scores: A 1-D float tensor of shape `[num_boxes]` representing a single +// score corresponding to each box (each row of boxes). +// max_output_size: A scalar integer tensor representing the maximum number of +// boxes to be selected by non max suppression. +// iou_threshold: A 0-D float tensor representing the threshold for deciding whether +// boxes overlap too much with respect to IOU. +// score_threshold: A 0-D float tensor representing the threshold for deciding when to remove +// boxes based on score. +// +// Returns A 1-D integer tensor of shape `[M]` representing the selected +// indices from the boxes tensor, where `M <= max_output_size`.A 0-D integer tensor representing the number of valid elements in +// `selected_indices`, with the valid elements appearing first. +func NonMaxSuppressionV4(scope *Scope, boxes tf.Output, scores tf.Output, max_output_size tf.Output, iou_threshold tf.Output, score_threshold tf.Output, optional ...NonMaxSuppressionV4Attr) (selected_indices tf.Output, valid_outputs tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "NonMaxSuppressionV4", + Input: []tf.Input{ + boxes, scores, max_output_size, iou_threshold, score_threshold, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + // Computes the matrix logarithm of one or more square matrices: // // -// log(exp(A)) = A +// \\(log(exp(A)) = A\\) // // This op is only defined for complex matrices. If A is positive-definite and // real, then casting to a complex matrix, taking the logarithm and casting back @@ -24568,6 +26072,31 @@ func MatrixLogarithm(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } +// This op is used as a placeholder in If branch functions. It doesn't provide a +// valid output when run, so must either be removed (e.g. replaced with a +// function input) or guaranteed not to be used (e.g. if mirroring an +// intermediate output needed for the gradient computation of the other branch). +// +// Arguments: +// dtype: The type of the output. +// shape: The purported shape of the output. This is only used for shape inference; +// the output will not necessarily have this shape. Can be a partial shape. +// +// Returns \"Fake\" output value. This should not be consumed by another op. +func FakeParam(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype, "shape": shape} + opspec := tf.OpSpec{ + Type: "FakeParam", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // EncodeProtoAttr is an optional argument to EncodeProto. type EncodeProtoAttr func(optionalAttr) @@ -24674,160 +26203,35 @@ func EncodeProto(scope *Scope, sizes tf.Output, values []tf.Output, field_names // gradient nodes may accidentally flow through the same accumulator TensorArray. // This double counts and generally breaks the TensorArray gradient flow. // -// The solution is to identify which gradient call this particular -// TensorArray gradient is being called in. This is performed by identifying -// a unique string (e.g. "gradients", "gradients_1", ...) from the input -// gradient Tensor's name. This string is used as a suffix when creating -// the TensorArray gradient object here (the attribute `source`). -// -// The attribute `source` is added as a suffix to the forward TensorArray's -// name when performing the creation / lookup, so that each separate gradient -// calculation gets its own TensorArray accumulator. -// -// Arguments: -// handle: The handle to the forward TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. -// source: The gradient source string, used to decide which gradient TensorArray -// to return. -func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"source": source} - opspec := tf.OpSpec{ - Type: "TensorArrayGradV3", - Input: []tf.Input{ - handle, flow_in, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// DecodeProtoV2Attr is an optional argument to DecodeProtoV2. -type DecodeProtoV2Attr func(optionalAttr) - -// DecodeProtoV2DescriptorSource sets the optional descriptor_source attribute to value. -// -// value: Either the special value `local://` or a path to a file containing -// a serialized `FileDescriptorSet`. -// If not specified, defaults to "local://" -func DecodeProtoV2DescriptorSource(value string) DecodeProtoV2Attr { - return func(m optionalAttr) { - m["descriptor_source"] = value } -} - -// DecodeProtoV2MessageFormat sets the optional message_format attribute to value. -// -// value: Either `binary` or `text`. -// If not specified, defaults to "binary" -func DecodeProtoV2MessageFormat(value string) DecodeProtoV2Attr { - return func(m optionalAttr) { - m["message_format"] = value - } -} - -// DecodeProtoV2Sanitize sets the optional sanitize attribute to value. -// -// value: Whether to sanitize the result or not. -// If not specified, defaults to false -func DecodeProtoV2Sanitize(value bool) DecodeProtoV2Attr { - return func(m optionalAttr) { - m["sanitize"] = value - } -} - -// The op extracts fields from a serialized protocol buffers message into tensors. -// -// The `decode_proto` op extracts fields from a serialized protocol buffers -// message into tensors. The fields in `field_names` are decoded and converted -// to the corresponding `output_types` if possible. -// -// A `message_type` name must be provided to give context for the field -// names. The actual message descriptor can be looked up either in the -// linked-in descriptor pool or a filename provided by the caller using -// the `descriptor_source` attribute. -// -// Each output tensor is a dense tensor. This means that it is padded to -// hold the largest number of repeated elements seen in the input -// minibatch. (The shape is also padded by one to prevent zero-sized -// dimensions). The actual repeat counts for each example in the -// minibatch can be found in the `sizes` output. In many cases the output -// of `decode_proto` is fed immediately into tf.squeeze if missing values -// are not a concern. When using tf.squeeze, always pass the squeeze -// dimension explicitly to avoid surprises. -// -// For the most part, the mapping between Proto field types and -// TensorFlow dtypes is straightforward. However, there are a few -// special cases: -// -// - A proto field that contains a submessage or group can only be converted -// to `DT_STRING` (the serialized submessage). This is to reduce the -// complexity of the API. The resulting string can be used as input -// to another instance of the decode_proto op. -// -// - TensorFlow lacks support for unsigned integers. The ops represent uint64 -// types as a `DT_INT64` with the same twos-complement bit pattern -// (the obvious way). Unsigned int32 values can be represented exactly by -// specifying type `DT_INT64`, or using twos-complement if the caller -// specifies `DT_INT32` in the `output_types` attribute. -// -// The `descriptor_source` attribute selects a source of protocol -// descriptors to consult when looking up `message_type`. This may be a -// filename containing a serialized `FileDescriptorSet` message, -// or the special value `local://`, in which case only descriptors linked -// into the code will be searched; the filename can be on any filesystem -// accessible to TensorFlow. -// -// You can build a `descriptor_source` file using the `--descriptor_set_out` -// and `--include_imports` options to the protocol compiler `protoc`. -// -// The `local://` database only covers descriptors linked into the -// code via C++ libraries, not Python imports. You can link in a proto descriptor -// by creating a cc_library target with alwayslink=1. +// The solution is to identify which gradient call this particular +// TensorArray gradient is being called in. This is performed by identifying +// a unique string (e.g. "gradients", "gradients_1", ...) from the input +// gradient Tensor's name. This string is used as a suffix when creating +// the TensorArray gradient object here (the attribute `source`). // -// Both binary and text proto serializations are supported, and can be -// chosen using the `format` attribute. +// The attribute `source` is added as a suffix to the forward TensorArray's +// name when performing the creation / lookup, so that each separate gradient +// calculation gets its own TensorArray accumulator. // // Arguments: -// bytes: Tensor of serialized protos with shape `batch_shape`. -// message_type: Name of the proto message type to decode. -// field_names: List of strings containing proto field names. -// output_types: List of TF types to use for the respective field in field_names. -// -// Returns Tensor of int32 with shape `[batch_shape, len(field_names)]`. -// Each entry is the number of values found for the corresponding field. -// Optional fields may have 0 or 1 values.List of tensors containing values for the corresponding field. -// `values[i]` has datatype `output_types[i]` -// and shape `[batch_shape, max(sizes[...,i])]`. -func DecodeProtoV2(scope *Scope, bytes tf.Output, message_type string, field_names []string, output_types []tf.DataType, optional ...DecodeProtoV2Attr) (sizes tf.Output, values []tf.Output) { +// handle: The handle to the forward TensorArray. +// flow_in: A float scalar that enforces proper chaining of operations. +// source: The gradient source string, used to decide which gradient TensorArray +// to return. +func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"message_type": message_type, "field_names": field_names, "output_types": output_types} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"source": source} opspec := tf.OpSpec{ - Type: "DecodeProtoV2", + Type: "TensorArrayGradV3", Input: []tf.Input{ - bytes, + handle, flow_in, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - sizes = op.Output(idx) - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("DecodeProtoV2", err) - return - } - return sizes, values + return op.Output(0), op.Output(1) } // Creates a dataset that splits a SparseTensor into elements row-wise. @@ -24908,6 +26312,23 @@ func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) { return scope.AddOperation(opspec) } +// A dataset that splits the elements of its input into multiple elements. +func UnbatchDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "UnbatchDataset", + Input: []tf.Input{ + input_dataset, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // RpcAttr is an optional argument to Rpc. type RpcAttr func(optionalAttr) @@ -25160,6 +26581,36 @@ func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset t return op.Output(0) } +// Debugging/model interpretability outputs for each example. +// +// It traverses all the trees and computes debug metrics for individual examples, +// such as getting split feature ids and logits after each split along the decision +// path used to compute directional feature contributions. +// +// Arguments: +// +// bucketized_features: A list of rank 1 Tensors containing bucket id for each +// feature. +// logits_dimension: scalar, dimension of the logits, to be used for constructing the protos in +// examples_debug_outputs_serialized. +// +// Returns Output rank 1 Tensor containing a proto serialized as a string for each example. +func BoostedTreesExampleDebugOutputs(scope *Scope, tree_ensemble_handle tf.Output, bucketized_features []tf.Output, logits_dimension int64) (examples_debug_outputs_serialized tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"logits_dimension": logits_dimension} + opspec := tf.OpSpec{ + Type: "BoostedTreesExampleDebugOutputs", + Input: []tf.Input{ + tree_ensemble_handle, tf.OutputList(bucketized_features), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Adds a value to the current value of a variable. // // Any ReadVariableOp with a control dependency on this op is guaranteed to @@ -25748,81 +27199,6 @@ func CacheDataset(scope *Scope, input_dataset tf.Output, filename tf.Output, out return op.Output(0) } -// Computes the sum along sparse segments of a tensor. -// -// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is -// misisng, the `output` tensor at that position will be zeroed. -// -// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of -// segments. -// -// For example: -// -// ```python -// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]) -// -// tf.sparse_segment_sum_with_num_segments( -// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3) -// # => [[0 0 0 0] -// # [0 0 0 0] -// # [0 0 0 0]] -// -// tf.sparse_segment_sum_with_num_segments(c, -// tf.constant([0, 1]), -// tf.constant([0, 2], -// num_segments=4)) -// # => [[ 1 2 3 4] -// # [ 0 0 0 0] -// # [-1 -2 -3 -4] -// # [ 0 0 0 0]] -// ``` -// -// Arguments: -// -// indices: A 1-D tensor. Has same rank as `segment_ids`. -// segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -// num_segments: Should equal the number of distinct segment IDs. -// -// Returns Has same shape as data, except for dimension 0 which -// has size `num_segments`. -func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SparseSegmentSumWithNumSegments", - Input: []tf.Input{ - data, indices, segment_ids, num_segments, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that executes a SQL query and emits rows of the result set. -// -// Arguments: -// driver_name: The database type. Currently, the only supported type is 'sqlite'. -// data_source_name: A connection string to connect to the database. -// query: A SQL query to execute. -// -// -func SqlDataset(scope *Scope, driver_name tf.Output, data_source_name tf.Output, query tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "SqlDataset", - Input: []tf.Input{ - driver_name, data_source_name, query, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Creates a dataset that emits the records from one or more binary files. // // Arguments: @@ -25910,6 +27286,26 @@ func TFRecordDataset(scope *Scope, filenames tf.Output, compression_type tf.Outp return op.Output(0) } +// A container for an iterator resource. +// +// Returns A handle to the iterator that can be passed to a "MakeIterator" or +// "IteratorGetNext" op. In contrast to Iterator, AnonymousIterator prevents +// resource sharing by name, and does not keep a reference to the resource +// container. +func AnonymousIterator(scope *Scope, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "AnonymousIterator", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // BatchToSpace for 4-D tensors of type T. // // This is a legacy version of the more general BatchToSpaceND. @@ -26094,7 +27490,7 @@ func AdjustContrastv2(scope *Scope, images tf.Output, contrast_factor tf.Output) return op.Output(0) } -// Gets the next output from the given iterator. +// Gets the next output from the given iterator . func IteratorGetNext(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { if scope.Err() != nil { return @@ -26413,113 +27809,68 @@ func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) { return op.Output(0) } -// Performs a padding as a preprocess during a convolution. -// -// Similar to FusedResizeAndPadConv2d, this op allows for an optimized -// implementation where the spatial padding transformation stage is fused with the -// im2col lookup, but in this case without the bilinear filtering required for -// resizing. Fusing the padding prevents the need to write out the intermediate -// results as whole tensors, reducing memory pressure, and we can get some latency -// gains by merging the transformation calculations. -// The data_format attribute for Conv2D isn't supported by this op, and 'NHWC' -// order is used instead. -// Internally this op uses a single per-graph scratch buffer, which means that it -// will block if multiple versions are being run in parallel. This is because this -// operator is primarily an optimization to minimize memory usage. +// Writes the given dataset to the given file using the TFRecord format. // // Arguments: -// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. -// paddings: A two-column matrix specifying the padding sizes. The number of -// rows must be the same as the rank of `input`. -// filter: 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. +// input_dataset: A variant tensor representing the dataset to write. +// filename: A scalar string tensor representing the filename to use. +// compression_type: A scalar string tensor containing either (i) the empty string (no +// compression), (ii) "ZLIB", or (iii) "GZIP". // -// strides: 1-D of length 4. The stride of the sliding window for each dimension -// of `input`. Must be in the same order as the dimension specified with format. -// padding: The type of padding algorithm to use. -func FusedPadConv2D(scope *Scope, input tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string) (output tf.Output) { +// Returns the created operation. +func DatasetToTFRecord(scope *Scope, input_dataset tf.Output, filename tf.Output, compression_type tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "FusedPadConv2D", + Type: "DatasetToTFRecord", Input: []tf.Input{ - input, paddings, filter, + input_dataset, filename, compression_type, }, - Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Conv2DBackpropInputAttr is an optional argument to Conv2DBackpropInput. -type Conv2DBackpropInputAttr func(optionalAttr) - -// Conv2DBackpropInputUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. -// If not specified, defaults to true -func Conv2DBackpropInputUseCudnnOnGpu(value bool) Conv2DBackpropInputAttr { - return func(m optionalAttr) { - m["use_cudnn_on_gpu"] = value - } -} +// AvgPool3DAttr is an optional argument to AvgPool3D. +type AvgPool3DAttr func(optionalAttr) -// Conv2DBackpropInputDataFormat sets the optional data_format attribute to value. +// AvgPool3DDataFormat sets the optional data_format attribute to value. // -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func AvgPool3DDataFormat(value string) AvgPool3DAttr { return func(m optionalAttr) { m["data_format"] = value } } -// Conv2DBackpropInputDilations sets the optional dilations attribute to value. -// -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each filter -// element on that dimension. The dimension order is determined by the value of -// `data_format`, see above for details. Dilations in the batch and depth -// dimensions must be 1. -// If not specified, defaults to -func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes the gradients of convolution with respect to the input. +// Performs 3D average pooling on the input. // // Arguments: -// input_sizes: An integer vector representing the shape of `input`, -// where `input` is a 4-D `[batch, height, width, channels]` tensor. -// filter: 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. -// out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. -// Gradients w.r.t. the output of the convolution. -// strides: The stride of the sliding window for each dimension of the input -// of the convolution. Must be in the same order as the dimension specified with -// format. +// input: Shape `[batch, depth, rows, cols, channels]` tensor to pool over. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. // padding: The type of padding algorithm to use. // -// Returns 4-D with shape `[batch, in_height, in_width, in_channels]`. Gradient -// w.r.t. the input of the convolution. -func Conv2DBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv2DBackpropInputAttr) (output tf.Output) { +// Returns The average pooled output tensor. +func AvgPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Conv2DBackpropInput", + Type: "AvgPool3D", Input: []tf.Input{ - input_sizes, filter, out_backprop, + input, }, Attrs: attrs, } @@ -26527,386 +27878,229 @@ func Conv2DBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, return op.Output(0) } -// Interleave the values from the `data` tensors into a single tensor. -// -// Builds a merged tensor such that -// -// ```python -// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] -// ``` -// -// For example, if each `indices[m]` is scalar or vector, we have -// -// ```python -// # Scalar indices: -// merged[indices[m], ...] = data[m][...] -// -// # Vector indices: -// merged[indices[m][i], ...] = data[m][i, ...] -// ``` +// A placeholder for input pipeline graph optimizations. // -// Each `data[i].shape` must start with the corresponding `indices[i].shape`, -// and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we -// must have `data[i].shape = indices[i].shape + constant`. In terms of this -// `constant`, the output shape is -// -// merged.shape = [max(indices)] + constant -// -// Values are merged in order, so if an index appears in both `indices[m][i]` and -// `indices[n][j]` for `(m,i) < (n,j)` the slice `data[n][j]` will appear in the -// merged result. If you do not need this guarantee, ParallelDynamicStitch might -// perform better on some devices. -// -// For example: -// -// ```python -// indices[0] = 6 -// indices[1] = [4, 1] -// indices[2] = [[5, 2], [0, 3]] -// data[0] = [61, 62] -// data[1] = [[41, 42], [11, 12]] -// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] -// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], -// [51, 52], [61, 62]] -// ``` -// -// This method can be used to merge partitions created by `dynamic_partition` -// as illustrated on the following example: -// -// ```python -// # Apply function (increments x_i) on elements for which a certain condition -// # apply (x_i != -1 in this example). -// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) -// condition_mask=tf.not_equal(x,tf.constant(-1.)) -// partitioned_data = tf.dynamic_partition( -// x, tf.cast(condition_mask, tf.int32) , 2) -// partitioned_data[1] = partitioned_data[1] + 1.0 -// condition_indices = tf.dynamic_partition( -// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) -// x = tf.dynamic_stitch(condition_indices, partitioned_data) -// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain -// # unchanged. -// ``` +// A placeholder for input pipeline graph optimizations. // -//
-// -//
-func DynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { +// Arguments: +// input_dataset: A variant tensor representing the input dataset. +func SinkDataset(scope *Scope, input_dataset tf.Output) (handle tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "DynamicStitch", + Type: "SinkDataset", Input: []tf.Input{ - tf.OutputList(indices), tf.OutputList(data), + input_dataset, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns the truth value of (x == y) element-wise. -// -// *NOTE*: `Equal` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Equal(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Constructs an Optional variant from a tuple of tensors. +func OptionalFromValue(scope *Scope, components []tf.Output) (optional tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Equal", + Type: "OptionalFromValue", Input: []tf.Input{ - x, y, + tf.OutputList(components), }, } op := scope.AddOperation(opspec) return op.Output(0) } -// TensorArrayGatherV2Attr is an optional argument to TensorArrayGatherV2. -type TensorArrayGatherV2Attr func(optionalAttr) +// DecodeProtoV2Attr is an optional argument to DecodeProtoV2. +type DecodeProtoV2Attr func(optionalAttr) -// TensorArrayGatherV2ElementShape sets the optional element_shape attribute to value. -// If not specified, defaults to -func TensorArrayGatherV2ElementShape(value tf.Shape) TensorArrayGatherV2Attr { +// DecodeProtoV2DescriptorSource sets the optional descriptor_source attribute to value. +// +// value: Either the special value `local://` or a path to a file containing +// a serialized `FileDescriptorSet`. +// If not specified, defaults to "local://" +func DecodeProtoV2DescriptorSource(value string) DecodeProtoV2Attr { return func(m optionalAttr) { - m["element_shape"] = value + m["descriptor_source"] = value } } -// Deprecated. Use TensorArrayGatherV3 +// DecodeProtoV2MessageFormat sets the optional message_format attribute to value. // -// DEPRECATED at GraphDef version 26: Use TensorArrayGatherV3 -func TensorArrayGatherV2(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV2Attr) (value tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TensorArrayGatherV2", - Input: []tf.Input{ - handle, indices, flow_in, - }, - Attrs: attrs, +// value: Either `binary` or `text`. +// If not specified, defaults to "binary" +func DecodeProtoV2MessageFormat(value string) DecodeProtoV2Attr { + return func(m optionalAttr) { + m["message_format"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Interleave the values from the `data` tensors into a single tensor. +// DecodeProtoV2Sanitize sets the optional sanitize attribute to value. // -// Builds a merged tensor such that +// value: Whether to sanitize the result or not. +// If not specified, defaults to false +func DecodeProtoV2Sanitize(value bool) DecodeProtoV2Attr { + return func(m optionalAttr) { + m["sanitize"] = value + } +} + +// The op extracts fields from a serialized protocol buffers message into tensors. // -// ```python -// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] -// ``` +// The `decode_proto` op extracts fields from a serialized protocol buffers +// message into tensors. The fields in `field_names` are decoded and converted +// to the corresponding `output_types` if possible. // -// For example, if each `indices[m]` is scalar or vector, we have +// A `message_type` name must be provided to give context for the field +// names. The actual message descriptor can be looked up either in the +// linked-in descriptor pool or a filename provided by the caller using +// the `descriptor_source` attribute. // -// ```python -// # Scalar indices: -// merged[indices[m], ...] = data[m][...] +// Each output tensor is a dense tensor. This means that it is padded to +// hold the largest number of repeated elements seen in the input +// minibatch. (The shape is also padded by one to prevent zero-sized +// dimensions). The actual repeat counts for each example in the +// minibatch can be found in the `sizes` output. In many cases the output +// of `decode_proto` is fed immediately into tf.squeeze if missing values +// are not a concern. When using tf.squeeze, always pass the squeeze +// dimension explicitly to avoid surprises. // -// # Vector indices: -// merged[indices[m][i], ...] = data[m][i, ...] -// ``` +// For the most part, the mapping between Proto field types and +// TensorFlow dtypes is straightforward. However, there are a few +// special cases: // -// Each `data[i].shape` must start with the corresponding `indices[i].shape`, -// and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we -// must have `data[i].shape = indices[i].shape + constant`. In terms of this -// `constant`, the output shape is +// - A proto field that contains a submessage or group can only be converted +// to `DT_STRING` (the serialized submessage). This is to reduce the +// complexity of the API. The resulting string can be used as input +// to another instance of the decode_proto op. // -// merged.shape = [max(indices)] + constant +// - TensorFlow lacks support for unsigned integers. The ops represent uint64 +// types as a `DT_INT64` with the same twos-complement bit pattern +// (the obvious way). Unsigned int32 values can be represented exactly by +// specifying type `DT_INT64`, or using twos-complement if the caller +// specifies `DT_INT32` in the `output_types` attribute. // -// Values may be merged in parallel, so if an index appears in both `indices[m][i]` -// and `indices[n][j]`, the result may be invalid. This differs from the normal -// DynamicStitch operator that defines the behavior in that case. +// The `descriptor_source` attribute selects a source of protocol +// descriptors to consult when looking up `message_type`. This may be a +// filename containing a serialized `FileDescriptorSet` message, +// or the special value `local://`, in which case only descriptors linked +// into the code will be searched; the filename can be on any filesystem +// accessible to TensorFlow. // -// For example: +// You can build a `descriptor_source` file using the `--descriptor_set_out` +// and `--include_imports` options to the protocol compiler `protoc`. // -// ```python -// indices[0] = 6 -// indices[1] = [4, 1] -// indices[2] = [[5, 2], [0, 3]] -// data[0] = [61, 62] -// data[1] = [[41, 42], [11, 12]] -// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] -// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], -// [51, 52], [61, 62]] -// ``` +// The `local://` database only covers descriptors linked into the +// code via C++ libraries, not Python imports. You can link in a proto descriptor +// by creating a cc_library target with alwayslink=1. // -// This method can be used to merge partitions created by `dynamic_partition` -// as illustrated on the following example: +// Both binary and text proto serializations are supported, and can be +// chosen using the `format` attribute. // -// ```python -// # Apply function (increments x_i) on elements for which a certain condition -// # apply (x_i != -1 in this example). -// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) -// condition_mask=tf.not_equal(x,tf.constant(-1.)) -// partitioned_data = tf.dynamic_partition( -// x, tf.cast(condition_mask, tf.int32) , 2) -// partitioned_data[1] = partitioned_data[1] + 1.0 -// condition_indices = tf.dynamic_partition( -// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) -// x = tf.dynamic_stitch(condition_indices, partitioned_data) -// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain -// # unchanged. -// ``` +// Arguments: +// bytes: Tensor of serialized protos with shape `batch_shape`. +// message_type: Name of the proto message type to decode. +// field_names: List of strings containing proto field names. +// output_types: List of TF types to use for the respective field in field_names. // -//
-// -//
-func ParallelDynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { +// Returns Tensor of int32 with shape `[batch_shape, len(field_names)]`. +// Each entry is the number of values found for the corresponding field. +// Optional fields may have 0 or 1 values.List of tensors containing values for the corresponding field. +// `values[i]` has datatype `output_types[i]` +// and shape `[batch_shape, max(sizes[...,i])]`. +func DecodeProtoV2(scope *Scope, bytes tf.Output, message_type string, field_names []string, output_types []tf.DataType, optional ...DecodeProtoV2Attr) (sizes tf.Output, values []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"message_type": message_type, "field_names": field_names, "output_types": output_types} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ParallelDynamicStitch", + Type: "DecodeProtoV2", Input: []tf.Input{ - tf.OutputList(indices), tf.OutputList(data), + bytes, }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + sizes = op.Output(idx) + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("DecodeProtoV2", err) + return } - op := scope.AddOperation(opspec) - return op.Output(0) + return sizes, values } -// Computes the gradient for the inverse of `x` wrt its input. -// -// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` -// is the corresponding input gradient. -func InvGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { +// Creates an Optional variant with no value. +func OptionalNone(scope *Scope) (optional tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "InvGrad", - Input: []tf.Input{ - y, dy, - }, + Type: "OptionalNone", } op := scope.AddOperation(opspec) return op.Output(0) } -// List of the given size with empty elements. -// -// element_shape: the shape of the future elements of the list -// num_elements: the number of elements to reserve -// handle: the output list -// element_dtype: the desired type of elements in the list. -func TensorListReserve(scope *Scope, element_shape tf.Output, num_elements tf.Output, element_dtype tf.DataType) (handle tf.Output) { +// Returns true if and only if the given Optional variant has a value. +func OptionalHasValue(scope *Scope, optional tf.Output) (has_value tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"element_dtype": element_dtype} opspec := tf.OpSpec{ - Type: "TensorListReserve", + Type: "OptionalHasValue", Input: []tf.Input{ - element_shape, num_elements, + optional, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// PriorityQueueV2Attr is an optional argument to PriorityQueueV2. -type PriorityQueueV2Attr func(optionalAttr) - -// PriorityQueueV2ComponentTypes sets the optional component_types attribute to value. -// -// value: The type of each component in a value. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func PriorityQueueV2ComponentTypes(value []tf.DataType) PriorityQueueV2Attr { - return func(m optionalAttr) { - m["component_types"] = value - } -} - -// PriorityQueueV2Capacity sets the optional capacity attribute to value. -// -// value: The upper bound on the number of elements in this queue. -// Negative numbers mean no limit. -// If not specified, defaults to -1 -func PriorityQueueV2Capacity(value int64) PriorityQueueV2Attr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// PriorityQueueV2Container sets the optional container attribute to value. -// -// value: If non-empty, this queue is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func PriorityQueueV2Container(value string) PriorityQueueV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// PriorityQueueV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this queue will be shared under the given name -// across multiple sessions. -// If not specified, defaults to "" -func PriorityQueueV2SharedName(value string) PriorityQueueV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A queue that produces elements sorted by the first component value. -// -// Note that the PriorityQueue requires the first component of any element -// to be a scalar int64, in addition to the other elements declared by -// component_types. Therefore calls to Enqueue and EnqueueMany (resp. Dequeue -// and DequeueMany) on a PriorityQueue will all require (resp. output) one extra -// entry in their input (resp. output) lists. +// Creates a dataset that executes a SQL query and emits rows of the result set. // // Arguments: -// shapes: The shape of each component in a value. The length of this attr must -// be either 0 or the same as the length of component_types. If the length of -// this attr is 0, the shapes of queue elements are not constrained, and -// only one element may be dequeued at a time. +// driver_name: The database type. Currently, the only supported type is 'sqlite'. +// data_source_name: A connection string to connect to the database. +// query: A SQL query to execute. // -// Returns The handle to the queue. -func PriorityQueueV2(scope *Scope, shapes []tf.Shape, optional ...PriorityQueueV2Attr) (handle tf.Output) { +// +func SqlDataset(scope *Scope, driver_name tf.Output, data_source_name tf.Output, query tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"shapes": shapes} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "PriorityQueueV2", - + Type: "SqlDataset", + Input: []tf.Input{ + driver_name, data_source_name, query, + }, Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// UnstageAttr is an optional argument to Unstage. -type UnstageAttr func(optionalAttr) - -// UnstageCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func UnstageCapacity(value int64) UnstageAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// UnstageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func UnstageMemoryLimit(value int64) UnstageAttr { - return func(m optionalAttr) { - m["memory_limit"] = value - } -} - -// UnstageContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func UnstageContainer(value string) UnstageAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// UnstageSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func UnstageSharedName(value string) UnstageAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op is similar to a lightweight Dequeue. -// -// The basic functionality is similar to dequeue with many fewer -// capabilities and options. This Op is optimized for performance. -func Unstage(scope *Scope, dtypes []tf.DataType, optional ...UnstageAttr) (values []tf.Output) { +// Returns the value stored in an Optional variant or raises an error if none exists. +func OptionalGetValue(scope *Scope, optional tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "Unstage", - + Type: "OptionalGetValue", + Input: []tf.Input{ + optional, + }, Attrs: attrs, } op := scope.AddOperation(opspec) @@ -26915,162 +28109,137 @@ func Unstage(scope *Scope, dtypes []tf.DataType, optional ...UnstageAttr) (value } var idx int var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("Unstage", err) + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("OptionalGetValue", err) return } - return values -} - -// QueueEnqueueV2Attr is an optional argument to QueueEnqueueV2. -type QueueEnqueueV2Attr func(optionalAttr) - -// QueueEnqueueV2TimeoutMs sets the optional timeout_ms attribute to value. -// -// value: If the queue is full, this operation will block for up to -// timeout_ms milliseconds. -// Note: This option is not supported yet. -// If not specified, defaults to -1 -func QueueEnqueueV2TimeoutMs(value int64) QueueEnqueueV2Attr { - return func(m optionalAttr) { - m["timeout_ms"] = value - } + return components } -// Enqueues a tuple of one or more tensors in the given queue. -// -// The components input has k elements, which correspond to the components of -// tuples stored in the given queue. -// -// N.B. If the queue is full, this operation will block until the given -// element has been enqueued (or 'timeout_ms' elapses, if specified). -// -// Arguments: -// handle: The handle to a queue. -// components: One or more tensors from which the enqueued tensors should be taken. -// -// Returns the created operation. -func QueueEnqueueV2(scope *Scope, handle tf.Output, components []tf.Output, optional ...QueueEnqueueV2Attr) (o *tf.Operation) { +// Gets the next output from the given iterator as an Optional variant. +func IteratorGetNextAsOptional(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (optional tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "QueueEnqueueV2", + Type: "IteratorGetNextAsOptional", Input: []tf.Input{ - handle, tf.OutputList(components), + iterator, }, Attrs: attrs, } - return scope.AddOperation(opspec) -} - -// QueueDequeueManyV2Attr is an optional argument to QueueDequeueManyV2. -type QueueDequeueManyV2Attr func(optionalAttr) - -// QueueDequeueManyV2TimeoutMs sets the optional timeout_ms attribute to value. -// -// value: If the queue has fewer than n elements, this operation -// will block for up to timeout_ms milliseconds. -// Note: This option is not supported yet. -// If not specified, defaults to -1 -func QueueDequeueManyV2TimeoutMs(value int64) QueueDequeueManyV2Attr { - return func(m optionalAttr) { - m["timeout_ms"] = value - } + op := scope.AddOperation(opspec) + return op.Output(0) } -// Dequeues `n` tuples of one or more tensors from the given queue. -// -// If the queue is closed and there are fewer than `n` elements, then an -// OutOfRange error is returned. -// -// This operation concatenates queue-element component tensors along the -// 0th dimension to make a single component tensor. All of the components -// in the dequeued tuple will have size `n` in the 0th dimension. -// -// This operation has `k` outputs, where `k` is the number of components in -// the tuples stored in the given queue, and output `i` is the ith -// component of the dequeued tuple. +// Performs a padding as a preprocess during a convolution. // -// N.B. If the queue is empty, this operation will block until `n` elements -// have been dequeued (or 'timeout_ms' elapses, if specified). +// Similar to FusedResizeAndPadConv2d, this op allows for an optimized +// implementation where the spatial padding transformation stage is fused with the +// im2col lookup, but in this case without the bilinear filtering required for +// resizing. Fusing the padding prevents the need to write out the intermediate +// results as whole tensors, reducing memory pressure, and we can get some latency +// gains by merging the transformation calculations. +// The data_format attribute for Conv2D isn't supported by this op, and 'NHWC' +// order is used instead. +// Internally this op uses a single per-graph scratch buffer, which means that it +// will block if multiple versions are being run in parallel. This is because this +// operator is primarily an optimization to minimize memory usage. // // Arguments: -// handle: The handle to a queue. -// n: The number of tuples to dequeue. -// component_types: The type of each component in a tuple. +// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. +// paddings: A two-column matrix specifying the padding sizes. The number of +// rows must be the same as the rank of `input`. +// filter: 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. // -// Returns One or more tensors that were dequeued as a tuple. -func QueueDequeueManyV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueManyV2Attr) (components []tf.Output) { +// strides: 1-D of length 4. The stride of the sliding window for each dimension +// of `input`. Must be in the same order as the dimension specified with format. +// padding: The type of padding algorithm to use. +func FusedPadConv2D(scope *Scope, input tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"component_types": component_types} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "QueueDequeueManyV2", + Type: "FusedPadConv2D", Input: []tf.Input{ - handle, n, + input, paddings, filter, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("QueueDequeueManyV2", err) - return - } - return components + return op.Output(0) } -// EncodeBase64Attr is an optional argument to EncodeBase64. -type EncodeBase64Attr func(optionalAttr) +// Conv2DBackpropInputAttr is an optional argument to Conv2DBackpropInput. +type Conv2DBackpropInputAttr func(optionalAttr) -// EncodeBase64Pad sets the optional pad attribute to value. -// -// value: Bool whether padding is applied at the ends. -// If not specified, defaults to false -func EncodeBase64Pad(value bool) EncodeBase64Attr { +// Conv2DBackpropInputUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. +// If not specified, defaults to true +func Conv2DBackpropInputUseCudnnOnGpu(value bool) Conv2DBackpropInputAttr { return func(m optionalAttr) { - m["pad"] = value + m["use_cudnn_on_gpu"] = value } } -// Encode strings into web-safe base64 format. +// Conv2DBackpropInputDataFormat sets the optional data_format attribute to value. // -// Refer to the following article for more information on base64 format: -// en.wikipedia.org/wiki/Base64. Base64 strings may have padding with '=' at the -// end so that the encoded has length multiple of 4. See Padding section of the -// link above. +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Conv2DBackpropInputDilations sets the optional dilations attribute to value. // -// Web-safe means that the encoder uses - and _ instead of + and /. +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to +func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of convolution with respect to the input. // // Arguments: -// input: Strings to be encoded. +// input_sizes: An integer vector representing the shape of `input`, +// where `input` is a 4-D `[batch, height, width, channels]` tensor. +// filter: 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. +// out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. +// Gradients w.r.t. the output of the convolution. +// strides: The stride of the sliding window for each dimension of the input +// of the convolution. Must be in the same order as the dimension specified with +// format. +// padding: The type of padding algorithm to use. // -// Returns Input strings encoded in base64. -func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) (output tf.Output) { +// Returns 4-D with shape `[batch, in_height, in_width, in_channels]`. Gradient +// w.r.t. the input of the convolution. +func Conv2DBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv2DBackpropInputAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "EncodeBase64", + Type: "Conv2DBackpropInput", Input: []tf.Input{ - input, + input_sizes, filter, out_backprop, }, Attrs: attrs, } @@ -27078,183 +28247,245 @@ func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) ( return op.Output(0) } -// Deprecated. Use TensorArrayCloseV3 +// Interleave the values from the `data` tensors into a single tensor. // -// DEPRECATED at GraphDef version 26: Use TensorArrayCloseV3 +// Builds a merged tensor such that // -// Returns the created operation. -func TensorArrayCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { +// ```python +// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] +// ``` +// +// For example, if each `indices[m]` is scalar or vector, we have +// +// ```python +// # Scalar indices: +// merged[indices[m], ...] = data[m][...] +// +// # Vector indices: +// merged[indices[m][i], ...] = data[m][i, ...] +// ``` +// +// Each `data[i].shape` must start with the corresponding `indices[i].shape`, +// and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we +// must have `data[i].shape = indices[i].shape + constant`. In terms of this +// `constant`, the output shape is +// +// merged.shape = [max(indices)] + constant +// +// Values are merged in order, so if an index appears in both `indices[m][i]` and +// `indices[n][j]` for `(m,i) < (n,j)` the slice `data[n][j]` will appear in the +// merged result. If you do not need this guarantee, ParallelDynamicStitch might +// perform better on some devices. +// +// For example: +// +// ```python +// indices[0] = 6 +// indices[1] = [4, 1] +// indices[2] = [[5, 2], [0, 3]] +// data[0] = [61, 62] +// data[1] = [[41, 42], [11, 12]] +// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] +// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], +// [51, 52], [61, 62]] +// ``` +// +// This method can be used to merge partitions created by `dynamic_partition` +// as illustrated on the following example: +// +// ```python +// # Apply function (increments x_i) on elements for which a certain condition +// # apply (x_i != -1 in this example). +// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) +// condition_mask=tf.not_equal(x,tf.constant(-1.)) +// partitioned_data = tf.dynamic_partition( +// x, tf.cast(condition_mask, tf.int32) , 2) +// partitioned_data[1] = partitioned_data[1] + 1.0 +// condition_indices = tf.dynamic_partition( +// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) +// x = tf.dynamic_stitch(condition_indices, partitioned_data) +// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain +// # unchanged. +// ``` +// +//
+// +//
+func DynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorArrayCloseV2", + Type: "DynamicStitch", Input: []tf.Input{ - handle, + tf.OutputList(indices), tf.OutputList(data), }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Forwards the value of an available tensor from `inputs` to `output`. -// -// `Merge` waits for at least one of the tensors in `inputs` to become available. -// It is usually combined with `Switch` to implement branching. -// -// `Merge` forwards the first tensor to become available to `output`, and sets -// `value_index` to its index in `inputs`. -// -// Arguments: -// inputs: The input tensors, exactly one of which will become available. +// Returns the truth value of (x == y) element-wise. // -// Returns Will be set to the available input tensor.The index of the chosen input tensor in `inputs`. -func Merge(scope *Scope, inputs []tf.Output) (output tf.Output, value_index tf.Output) { +// *NOTE*: `Equal` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Equal(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Merge", + Type: "Equal", Input: []tf.Input{ - tf.OutputList(inputs), + x, y, }, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// QueueCloseV2Attr is an optional argument to QueueCloseV2. -type QueueCloseV2Attr func(optionalAttr) +// TensorArrayGatherV2Attr is an optional argument to TensorArrayGatherV2. +type TensorArrayGatherV2Attr func(optionalAttr) -// QueueCloseV2CancelPendingEnqueues sets the optional cancel_pending_enqueues attribute to value. -// -// value: If true, all pending enqueue requests that are -// blocked on the given queue will be canceled. -// If not specified, defaults to false -func QueueCloseV2CancelPendingEnqueues(value bool) QueueCloseV2Attr { +// TensorArrayGatherV2ElementShape sets the optional element_shape attribute to value. +// If not specified, defaults to +func TensorArrayGatherV2ElementShape(value tf.Shape) TensorArrayGatherV2Attr { return func(m optionalAttr) { - m["cancel_pending_enqueues"] = value + m["element_shape"] = value } } -// Closes the given queue. -// -// This operation signals that no more elements will be enqueued in the -// given queue. Subsequent Enqueue(Many) operations will fail. -// Subsequent Dequeue(Many) operations will continue to succeed if -// sufficient elements remain in the queue. Subsequent Dequeue(Many) -// operations that would block will fail immediately. -// -// Arguments: -// handle: The handle to a queue. +// Deprecated. Use TensorArrayGatherV3 // -// Returns the created operation. -func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr) (o *tf.Operation) { +// DEPRECATED at GraphDef version 26: Use TensorArrayGatherV3 +func TensorArrayGatherV2(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV2Attr) (value tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtype": dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "QueueCloseV2", + Type: "TensorArrayGatherV2", Input: []tf.Input{ - handle, + handle, indices, flow_in, }, Attrs: attrs, } - return scope.AddOperation(opspec) -} - -// Computes inverse hyperbolic tangent of x element-wise. -func Atanh(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Atanh", - Input: []tf.Input{ - x, - }, - } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns true if queue is closed. +// Interleave the values from the `data` tensors into a single tensor. +// +// Builds a merged tensor such that +// +// ```python +// merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...] +// ``` +// +// For example, if each `indices[m]` is scalar or vector, we have +// +// ```python +// # Scalar indices: +// merged[indices[m], ...] = data[m][...] +// +// # Vector indices: +// merged[indices[m][i], ...] = data[m][i, ...] +// ``` +// +// Each `data[i].shape` must start with the corresponding `indices[i].shape`, +// and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we +// must have `data[i].shape = indices[i].shape + constant`. In terms of this +// `constant`, the output shape is // -// This operation returns true if the queue is closed and false if the queue -// is open. +// merged.shape = [max(indices)] + constant // -// Arguments: -// handle: The handle to a queue. -func QueueIsClosedV2(scope *Scope, handle tf.Output) (is_closed tf.Output) { +// Values may be merged in parallel, so if an index appears in both `indices[m][i]` +// and `indices[n][j]`, the result may be invalid. This differs from the normal +// DynamicStitch operator that defines the behavior in that case. +// +// For example: +// +// ```python +// indices[0] = 6 +// indices[1] = [4, 1] +// indices[2] = [[5, 2], [0, 3]] +// data[0] = [61, 62] +// data[1] = [[41, 42], [11, 12]] +// data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]] +// merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], +// [51, 52], [61, 62]] +// ``` +// +// This method can be used to merge partitions created by `dynamic_partition` +// as illustrated on the following example: +// +// ```python +// # Apply function (increments x_i) on elements for which a certain condition +// # apply (x_i != -1 in this example). +// x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4]) +// condition_mask=tf.not_equal(x,tf.constant(-1.)) +// partitioned_data = tf.dynamic_partition( +// x, tf.cast(condition_mask, tf.int32) , 2) +// partitioned_data[1] = partitioned_data[1] + 1.0 +// condition_indices = tf.dynamic_partition( +// tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2) +// x = tf.dynamic_stitch(condition_indices, partitioned_data) +// # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain +// # unchanged. +// ``` +// +//
+// +//
+func ParallelDynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "QueueIsClosedV2", + Type: "ParallelDynamicStitch", Input: []tf.Input{ - handle, + tf.OutputList(indices), tf.OutputList(data), }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Computes the absolute value of a tensor. +// Computes the gradient for the inverse of `x` wrt its input. // -// Given a tensor `x`, this operation returns a tensor containing the absolute -// value of each element in `x`. For example, if x is an input element and y is -// an output element, this operation computes \\(y = |x|\\). -func Abs(scope *Scope, x tf.Output) (y tf.Output) { +// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy` +// is the corresponding input gradient. +func InvGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "Abs", + Type: "InvGrad", Input: []tf.Input{ - x, + y, dy, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// StackV2Attr is an optional argument to StackV2. -type StackV2Attr func(optionalAttr) - -// StackV2StackName sets the optional stack_name attribute to value. -// -// value: Overrides the name used for the temporary stack resource. Default -// value is the name of the 'Stack' op (which is guaranteed unique). -// If not specified, defaults to "" -func StackV2StackName(value string) StackV2Attr { - return func(m optionalAttr) { - m["stack_name"] = value - } -} - -// A stack that produces elements in first-in last-out order. -// -// Arguments: -// max_size: The maximum size of the stack if non-negative. If negative, the stack -// size is unlimited. -// elem_type: The type of the elements on the stack. +// List of the given size with empty elements. // -// Returns The handle to the stack. -func StackV2(scope *Scope, max_size tf.Output, elem_type tf.DataType, optional ...StackV2Attr) (handle tf.Output) { +// element_shape: the shape of the future elements of the list +// num_elements: the number of elements to reserve +// handle: the output list +// element_dtype: the desired type of elements in the list. +func TensorListReserve(scope *Scope, element_shape tf.Output, num_elements tf.Output, element_dtype tf.DataType) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"elem_type": elem_type} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"element_dtype": element_dtype} opspec := tf.OpSpec{ - Type: "StackV2", + Type: "TensorListReserve", Input: []tf.Input{ - max_size, + element_shape, num_elements, }, Attrs: attrs, } @@ -27262,823 +28493,642 @@ func StackV2(scope *Scope, max_size tf.Output, elem_type tf.DataType, optional . return op.Output(0) } -// FusedBatchNormGradV2Attr is an optional argument to FusedBatchNormGradV2. -type FusedBatchNormGradV2Attr func(optionalAttr) +// PriorityQueueV2Attr is an optional argument to PriorityQueueV2. +type PriorityQueueV2Attr func(optionalAttr) -// FusedBatchNormGradV2Epsilon sets the optional epsilon attribute to value. +// PriorityQueueV2ComponentTypes sets the optional component_types attribute to value. // -// value: A small float number added to the variance of x. -// If not specified, defaults to 0.0001 -func FusedBatchNormGradV2Epsilon(value float32) FusedBatchNormGradV2Attr { - return func(m optionalAttr) { - m["epsilon"] = value - } -} - -// FusedBatchNormGradV2DataFormat sets the optional data_format attribute to value. +// value: The type of each component in a value. +// If not specified, defaults to <> // -// value: The data format for y_backprop, x, x_backprop. -// Either "NHWC" (default) or "NCHW". -// If not specified, defaults to "NHWC" -func FusedBatchNormGradV2DataFormat(value string) FusedBatchNormGradV2Attr { +// REQUIRES: len(value) >= 0 +func PriorityQueueV2ComponentTypes(value []tf.DataType) PriorityQueueV2Attr { return func(m optionalAttr) { - m["data_format"] = value + m["component_types"] = value } } -// FusedBatchNormGradV2IsTraining sets the optional is_training attribute to value. +// PriorityQueueV2Capacity sets the optional capacity attribute to value. // -// value: A bool value to indicate the operation is for training (default) -// or inference. -// If not specified, defaults to true -func FusedBatchNormGradV2IsTraining(value bool) FusedBatchNormGradV2Attr { +// value: The upper bound on the number of elements in this queue. +// Negative numbers mean no limit. +// If not specified, defaults to -1 +func PriorityQueueV2Capacity(value int64) PriorityQueueV2Attr { return func(m optionalAttr) { - m["is_training"] = value + m["capacity"] = value } } -// Gradient for batch normalization. -// -// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -// The size of 1D Tensors matches the dimension C of the 4D Tensors. -// -// Arguments: -// y_backprop: A 4D Tensor for the gradient with respect to y. -// x: A 4D Tensor for input data. -// scale: A 1D Tensor for scaling factor, to scale the normalized x. -// reserve_space_1: When is_training is True, a 1D Tensor for the computed batch -// mean to be reused in gradient computation. When is_training is -// False, a 1D Tensor for the population mean to be reused in both -// 1st and 2nd order gradient computation. -// reserve_space_2: When is_training is True, a 1D Tensor for the computed batch -// variance (inverted variance in the cuDNN case) to be reused in -// gradient computation. When is_training is False, a 1D Tensor -// for the population variance to be reused in both 1st and 2nd -// order gradient computation. +// PriorityQueueV2Container sets the optional container attribute to value. // -// Returns A 4D Tensor for the gradient with respect to x.A 1D Tensor for the gradient with respect to scale.A 1D Tensor for the gradient with respect to offset.Unused placeholder to match the mean input in FusedBatchNorm.Unused placeholder to match the variance input -// in FusedBatchNorm. -func FusedBatchNormGradV2(scope *Scope, y_backprop tf.Output, x tf.Output, scale tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output, optional ...FusedBatchNormGradV2Attr) (x_backprop tf.Output, scale_backprop tf.Output, offset_backprop tf.Output, reserve_space_3 tf.Output, reserve_space_4 tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FusedBatchNormGradV2", - Input: []tf.Input{ - y_backprop, x, scale, reserve_space_1, reserve_space_2, - }, - Attrs: attrs, +// value: If non-empty, this queue is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func PriorityQueueV2Container(value string) PriorityQueueV2Attr { + return func(m optionalAttr) { + m["container"] = value } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// DecodeCompressedAttr is an optional argument to DecodeCompressed. -type DecodeCompressedAttr func(optionalAttr) - -// DecodeCompressedCompressionType sets the optional compression_type attribute to value. +// PriorityQueueV2SharedName sets the optional shared_name attribute to value. // -// value: A scalar containing either (i) the empty string (no -// compression), (ii) "ZLIB", or (iii) "GZIP". +// value: If non-empty, this queue will be shared under the given name +// across multiple sessions. // If not specified, defaults to "" -func DecodeCompressedCompressionType(value string) DecodeCompressedAttr { +func PriorityQueueV2SharedName(value string) PriorityQueueV2Attr { return func(m optionalAttr) { - m["compression_type"] = value + m["shared_name"] = value } } -// Decompress strings. -// -// This op decompresses each element of the `bytes` input `Tensor`, which -// is assumed to be compressed using the given `compression_type`. +// A queue that produces elements sorted by the first component value. // -// The `output` is a string `Tensor` of the same shape as `bytes`, -// each element containing the decompressed data from the corresponding -// element in `bytes`. +// Note that the PriorityQueue requires the first component of any element +// to be a scalar int64, in addition to the other elements declared by +// component_types. Therefore calls to Enqueue and EnqueueMany (resp. Dequeue +// and DequeueMany) on a PriorityQueue will all require (resp. output) one extra +// entry in their input (resp. output) lists. // // Arguments: -// bytes: A Tensor of string which is compressed. +// shapes: The shape of each component in a value. The length of this attr must +// be either 0 or the same as the length of component_types. If the length of +// this attr is 0, the shapes of queue elements are not constrained, and +// only one element may be dequeued at a time. // -// Returns A Tensor with the same shape as input `bytes`, uncompressed -// from bytes. -func DecodeCompressed(scope *Scope, bytes tf.Output, optional ...DecodeCompressedAttr) (output tf.Output) { +// Returns The handle to the queue. +func PriorityQueueV2(scope *Scope, shapes []tf.Shape, optional ...PriorityQueueV2Attr) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"shapes": shapes} for _, a := range optional { a(attrs) } - opspec := tf.OpSpec{ - Type: "DecodeCompressed", - Input: []tf.Input{ - bytes, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// CudnnRNNAttr is an optional argument to CudnnRNN. -type CudnnRNNAttr func(optionalAttr) - -// CudnnRNNRnnMode sets the optional rnn_mode attribute to value. -// If not specified, defaults to "lstm" -func CudnnRNNRnnMode(value string) CudnnRNNAttr { - return func(m optionalAttr) { - m["rnn_mode"] = value - } -} - -// CudnnRNNInputMode sets the optional input_mode attribute to value. -// If not specified, defaults to "linear_input" -func CudnnRNNInputMode(value string) CudnnRNNAttr { - return func(m optionalAttr) { - m["input_mode"] = value - } -} - -// CudnnRNNDirection sets the optional direction attribute to value. -// If not specified, defaults to "unidirectional" -func CudnnRNNDirection(value string) CudnnRNNAttr { - return func(m optionalAttr) { - m["direction"] = value + opspec := tf.OpSpec{ + Type: "PriorityQueueV2", + + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// CudnnRNNDropout sets the optional dropout attribute to value. +// UnstageAttr is an optional argument to Unstage. +type UnstageAttr func(optionalAttr) + +// UnstageCapacity sets the optional capacity attribute to value. // If not specified, defaults to 0 -func CudnnRNNDropout(value float32) CudnnRNNAttr { +// +// REQUIRES: value >= 0 +func UnstageCapacity(value int64) UnstageAttr { return func(m optionalAttr) { - m["dropout"] = value + m["capacity"] = value } } -// CudnnRNNSeed sets the optional seed attribute to value. +// UnstageMemoryLimit sets the optional memory_limit attribute to value. // If not specified, defaults to 0 -func CudnnRNNSeed(value int64) CudnnRNNAttr { +// +// REQUIRES: value >= 0 +func UnstageMemoryLimit(value int64) UnstageAttr { return func(m optionalAttr) { - m["seed"] = value + m["memory_limit"] = value } } -// CudnnRNNSeed2 sets the optional seed2 attribute to value. -// If not specified, defaults to 0 -func CudnnRNNSeed2(value int64) CudnnRNNAttr { +// UnstageContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func UnstageContainer(value string) UnstageAttr { return func(m optionalAttr) { - m["seed2"] = value + m["container"] = value } } -// CudnnRNNIsTraining sets the optional is_training attribute to value. -// If not specified, defaults to true -func CudnnRNNIsTraining(value bool) CudnnRNNAttr { +// UnstageSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func UnstageSharedName(value string) UnstageAttr { return func(m optionalAttr) { - m["is_training"] = value + m["shared_name"] = value } } -// A RNN backed by cuDNN. -// -// Computes the RNN from the input and initial states, with respect to the params -// buffer. +// Op is similar to a lightweight Dequeue. // -// rnn_mode: Indicates the type of the RNN model. -// input_mode: Indicate whether there is a linear projection between the input and -// The actual computation before the first layer. 'skip_input' is only allowed -// when input_size == num_units; 'auto_select' implies 'skip_input' when -// input_size == num_units; otherwise, it implies 'linear_input'. -// direction: Indicates whether a bidirectional model will be used. -// dir = (direction == bidirectional) ? 2 : 1 -// dropout: dropout probability. When set to 0., dropout is disabled. -// seed: the 1st part of a seed to initialize dropout. -// seed2: the 2nd part of a seed to initialize dropout. -// input: a 3-D tensor with the shape of [seq_length, batch_size, input_size]. -// input_h: a 3-D tensor with the shape of [num_layer * dir, batch_size, -// num_units]. -// input_c: For LSTM, a 3-D tensor with the shape of -// [num_layer * dir, batch, num_units]. For other models, it is ignored. -// params: a 1-D tensor that contains the weights and biases in an opaque layout. -// The size must be created through CudnnRNNParamsSize, and initialized -// separately. Note that they might not be compatible across different -// generations. So it is a good idea to save and restore -// output: a 3-D tensor with the shape of [seq_length, batch_size, -// dir * num_units]. -// output_h: the same shape has input_h. -// output_c: the same shape as input_c for LSTM. An empty tensor for other models. -// is_training: Indicates whether this operation is used for inferenece or -// training. -// reserve_space: an opaque tensor that can be used in backprop calculation. It -// is only produced if is_training is false. -func CudnnRNN(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Output, params tf.Output, optional ...CudnnRNNAttr) (output tf.Output, output_h tf.Output, output_c tf.Output, reserve_space tf.Output) { +// The basic functionality is similar to dequeue with many fewer +// capabilities and options. This Op is optimized for performance. +func Unstage(scope *Scope, dtypes []tf.DataType, optional ...UnstageAttr) (values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "CudnnRNN", - Input: []tf.Input{ - input, input_h, input_c, params, - }, + Type: "Unstage", + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) -} - -// Compare values of `input` to `threshold` and pack resulting bits into a `uint8`. -// -// Each comparison returns a boolean `true` (if `input_value > threshold`) -// or and `false` otherwise. -// -// This operation is useful for Locality-Sensitive-Hashing (LSH) and other -// algorithms that use hashing approximations of cosine and `L2` distances; -// codes can be generated from an input via: -// -// ```python -// codebook_size = 50 -// codebook_bits = codebook_size * 32 -// codebook = tf.get_variable('codebook', [x.shape[-1].value, codebook_bits], -// dtype=x.dtype, -// initializer=tf.orthogonal_initializer()) -// codes = compare_and_threshold(tf.matmul(x, codebook), threshold=0.) -// codes = tf.bitcast(codes, tf.int32) # go from uint8 to int32 -// # now codes has shape x.shape[:-1] + [codebook_size] -// ``` -// -// **NOTE**: Currently, the innermost dimension of the tensor must be divisible -// by 8. -// -// Given an `input` shaped `[s0, s1, ..., s_n]`, the output is -// a `uint8` tensor shaped `[s0, s1, ..., s_n / 8]`. -// -// Arguments: -// input: Values to compare against `threshold` and bitpack. -// threshold: Threshold to compare against. -// -// Returns The bitpacked comparisons. -func CompareAndBitpack(scope *Scope, input tf.Output, threshold tf.Output) (output tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "CompareAndBitpack", - Input: []tf.Input{ - input, threshold, - }, + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("Unstage", err) + return } - op := scope.AddOperation(opspec) - return op.Output(0) + return values } -// Push an element onto the tensor_array. +// QueueEnqueueV2Attr is an optional argument to QueueEnqueueV2. +type QueueEnqueueV2Attr func(optionalAttr) + +// QueueEnqueueV2TimeoutMs sets the optional timeout_ms attribute to value. +// +// value: If the queue is full, this operation will block for up to +// timeout_ms milliseconds. +// Note: This option is not supported yet. +// If not specified, defaults to -1 +func QueueEnqueueV2TimeoutMs(value int64) QueueEnqueueV2Attr { + return func(m optionalAttr) { + m["timeout_ms"] = value + } +} + +// Enqueues a tuple of one or more tensors in the given queue. +// +// The components input has k elements, which correspond to the components of +// tuples stored in the given queue. +// +// N.B. If the queue is full, this operation will block until the given +// element has been enqueued (or 'timeout_ms' elapses, if specified). // // Arguments: -// handle: The handle to a TensorArray. -// index: The position to write to inside the TensorArray. -// value: The tensor to write to the TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. +// handle: The handle to a queue. +// components: One or more tensors from which the enqueued tensors should be taken. // -// Returns A float scalar that enforces proper chaining of operations. -func TensorArrayWriteV3(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// Returns the created operation. +func QueueEnqueueV2(scope *Scope, handle tf.Output, components []tf.Output, optional ...QueueEnqueueV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TensorArrayWriteV3", + Type: "QueueEnqueueV2", Input: []tf.Input{ - handle, index, value, flow_in, + handle, tf.OutputList(components), }, + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Scatter the data from the input value into specific TensorArray elements. +// Computes the Bessel i0e function of `x` element-wise. // -// `indices` must be a vector, its length must match the first dim of `value`. -// -// Arguments: -// handle: The handle to a TensorArray. -// indices: The locations at which to write the tensor elements. -// value: The concatenated tensor to write to the TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. +// Exponentially scaled modified Bessel function of order 0 defined as +// `bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`. // -// Returns A float scalar that enforces proper chaining of operations. -func TensorArrayScatterV3(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// This function is faster and numerically stabler than `bessel_i0(x)`. +func BesselI0e(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorArrayScatterV3", + Type: "BesselI0e", Input: []tf.Input{ - handle, indices, value, flow_in, + x, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// EmptyAttr is an optional argument to Empty. -type EmptyAttr func(optionalAttr) +// QueueDequeueManyV2Attr is an optional argument to QueueDequeueManyV2. +type QueueDequeueManyV2Attr func(optionalAttr) -// EmptyInit sets the optional init attribute to value. +// QueueDequeueManyV2TimeoutMs sets the optional timeout_ms attribute to value. // -// value: If True, initialize the returned tensor with the default value of dtype. Otherwise, the implementation is free not to initializethe tensor's content. -// If not specified, defaults to false -func EmptyInit(value bool) EmptyAttr { +// value: If the queue has fewer than n elements, this operation +// will block for up to timeout_ms milliseconds. +// Note: This option is not supported yet. +// If not specified, defaults to -1 +func QueueDequeueManyV2TimeoutMs(value int64) QueueDequeueManyV2Attr { return func(m optionalAttr) { - m["init"] = value + m["timeout_ms"] = value } } -// Creates a tensor with the given shape. +// Dequeues `n` tuples of one or more tensors from the given queue. // -// This operation creates a tensor of `shape` and `dtype`. +// If the queue is closed and there are fewer than `n` elements, then an +// OutOfRange error is returned. // -// Arguments: -// shape: 1-D. Represents the shape of the output tensor. +// This operation concatenates queue-element component tensors along the +// 0th dimension to make a single component tensor. All of the components +// in the dequeued tuple will have size `n` in the 0th dimension. // +// This operation has `k` outputs, where `k` is the number of components in +// the tuples stored in the given queue, and output `i` is the ith +// component of the dequeued tuple. // -// Returns A `Tensor` of type `T`. -func Empty(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...EmptyAttr) (output tf.Output) { +// N.B. If the queue is empty, this operation will block until `n` elements +// have been dequeued (or 'timeout_ms' elapses, if specified). +// +// Arguments: +// handle: The handle to a queue. +// n: The number of tuples to dequeue. +// component_types: The type of each component in a tuple. +// +// Returns One or more tensors that were dequeued as a tuple. +func QueueDequeueManyV2(scope *Scope, handle tf.Output, n tf.Output, component_types []tf.DataType, optional ...QueueDequeueManyV2Attr) (components []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{"component_types": component_types} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Empty", + Type: "QueueDequeueManyV2", Input: []tf.Input{ - shape, + handle, n, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("QueueDequeueManyV2", err) + return + } + return components } -// TensorArrayConcatV3Attr is an optional argument to TensorArrayConcatV3. -type TensorArrayConcatV3Attr func(optionalAttr) +// EncodeBase64Attr is an optional argument to EncodeBase64. +type EncodeBase64Attr func(optionalAttr) -// TensorArrayConcatV3ElementShapeExcept0 sets the optional element_shape_except0 attribute to value. +// EncodeBase64Pad sets the optional pad attribute to value. // -// value: The expected shape of an element, if known, -// excluding the first dimension. Used to validate the shapes of -// TensorArray elements. If this shape is not fully specified, concatenating -// zero-size TensorArrays is an error. -// If not specified, defaults to -func TensorArrayConcatV3ElementShapeExcept0(value tf.Shape) TensorArrayConcatV3Attr { +// value: Bool whether padding is applied at the ends. +// If not specified, defaults to false +func EncodeBase64Pad(value bool) EncodeBase64Attr { return func(m optionalAttr) { - m["element_shape_except0"] = value + m["pad"] = value } } -// Concat the elements from the TensorArray into value `value`. -// -// Takes `T` elements of shapes -// -// ``` -// (n0 x d0 x d1 x ...), (n1 x d0 x d1 x ...), ..., (n(T-1) x d0 x d1 x ...) -// ``` -// -// and concatenates them into a Tensor of shape: +// Encode strings into web-safe base64 format. // -// ```(n0 + n1 + ... + n(T-1) x d0 x d1 x ...)``` +// Refer to the following article for more information on base64 format: +// en.wikipedia.org/wiki/Base64. Base64 strings may have padding with '=' at the +// end so that the encoded has length multiple of 4. See Padding section of the +// link above. // -// All elements must have the same shape (excepting the first dimension). +// Web-safe means that the encoder uses - and _ instead of + and /. // // Arguments: -// handle: The handle to a TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. -// dtype: The type of the elem that is returned. +// input: Strings to be encoded. // -// Returns All of the elements in the TensorArray, concatenated along the first -// axis.A vector of the row sizes of the original T elements in the -// value output. In the example above, this would be the values: -// `(n1, n2, ..., n(T-1))`. -func TensorArrayConcatV3(scope *Scope, handle tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayConcatV3Attr) (value tf.Output, lengths tf.Output) { +// Returns Input strings encoded in base64. +func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtype": dtype} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TensorArrayConcatV3", + Type: "EncodeBase64", Input: []tf.Input{ - handle, flow_in, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Split the data from the input value into TensorArray elements. -// -// Assuming that `lengths` takes on values -// -// ```(n0, n1, ..., n(T-1))``` -// -// and that `value` has shape -// -// ```(n0 + n1 + ... + n(T-1) x d0 x d1 x ...)```, +// A dataset that creates window datasets from the input dataset. // -// this splits values into a TensorArray with T tensors. -// -// TensorArray index t will be the subtensor of values with starting position -// -// ```(n0 + n1 + ... + n(t-1), 0, 0, ...)``` -// -// and having size +// Arguments: // -// ```nt x d0 x d1 x ...``` +// window_size: A scalar representing the number of elements to accumulate in a window. // -// Arguments: -// handle: The handle to a TensorArray. -// value: The concatenated tensor to write to the TensorArray. -// lengths: The vector of lengths, how to split the rows of value into the -// TensorArray. -// flow_in: A float scalar that enforces proper chaining of operations. // -// Returns A float scalar that enforces proper chaining of operations. -func TensorArraySplitV3(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) { +func WindowDataset(scope *Scope, input_dataset tf.Output, window_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "TensorArraySplitV3", + Type: "WindowDataset", Input: []tf.Input{ - handle, value, lengths, flow_in, + input_dataset, window_size, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// SerializeSparseAttr is an optional argument to SerializeSparse. -type SerializeSparseAttr func(optionalAttr) - -// SerializeSparseOutType sets the optional out_type attribute to value. +// Deprecated. Use TensorArrayCloseV3 // -// value: The `dtype` to use for serialization; the supported types are `string` -// (default) and `variant`. -// If not specified, defaults to DT_STRING -func SerializeSparseOutType(value tf.DataType) SerializeSparseAttr { - return func(m optionalAttr) { - m["out_type"] = value - } -} - -// Serialize a `SparseTensor` into a `[3]` `Tensor` object. +// DEPRECATED at GraphDef version 26: Use TensorArrayCloseV3 // -// Arguments: -// sparse_indices: 2-D. The `indices` of the `SparseTensor`. -// sparse_values: 1-D. The `values` of the `SparseTensor`. -// sparse_shape: 1-D. The `shape` of the `SparseTensor`. -func SerializeSparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeSparseAttr) (serialized_sparse tf.Output) { +// Returns the created operation. +func TensorArrayCloseV2(scope *Scope, handle tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "SerializeSparse", + Type: "TensorArrayCloseV2", Input: []tf.Input{ - sparse_indices, sparse_values, sparse_shape, + handle, }, - Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// RandomShuffleQueueV2Attr is an optional argument to RandomShuffleQueueV2. -type RandomShuffleQueueV2Attr func(optionalAttr) - -// RandomShuffleQueueV2Shapes sets the optional shapes attribute to value. +// Forwards the value of an available tensor from `inputs` to `output`. // -// value: The shape of each component in a value. The length of this attr must -// be either 0 or the same as the length of component_types. If the length of -// this attr is 0, the shapes of queue elements are not constrained, and -// only one element may be dequeued at a time. -// If not specified, defaults to <> +// `Merge` waits for at least one of the tensors in `inputs` to become available. +// It is usually combined with `Switch` to implement branching. // -// REQUIRES: len(value) >= 0 -func RandomShuffleQueueV2Shapes(value []tf.Shape) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["shapes"] = value - } -} - -// RandomShuffleQueueV2Capacity sets the optional capacity attribute to value. +// `Merge` forwards the first tensor to become available to `output`, and sets +// `value_index` to its index in `inputs`. // -// value: The upper bound on the number of elements in this queue. -// Negative numbers mean no limit. -// If not specified, defaults to -1 -func RandomShuffleQueueV2Capacity(value int64) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// RandomShuffleQueueV2MinAfterDequeue sets the optional min_after_dequeue attribute to value. +// Arguments: +// inputs: The input tensors, exactly one of which will become available. // -// value: Dequeue will block unless there would be this -// many elements after the dequeue or the queue is closed. This -// ensures a minimum level of mixing of elements. -// If not specified, defaults to 0 -func RandomShuffleQueueV2MinAfterDequeue(value int64) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["min_after_dequeue"] = value +// Returns Will be set to the available input tensor.The index of the chosen input tensor in `inputs`. +func Merge(scope *Scope, inputs []tf.Output) (output tf.Output, value_index tf.Output) { + if scope.Err() != nil { + return } -} - -// RandomShuffleQueueV2Seed sets the optional seed attribute to value. -// -// value: If either seed or seed2 is set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, a random seed is used. -// If not specified, defaults to 0 -func RandomShuffleQueueV2Seed(value int64) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["seed"] = value + opspec := tf.OpSpec{ + Type: "Merge", + Input: []tf.Input{ + tf.OutputList(inputs), + }, } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } -// RandomShuffleQueueV2Seed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomShuffleQueueV2Seed2(value int64) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["seed2"] = value - } -} +// QueueCloseV2Attr is an optional argument to QueueCloseV2. +type QueueCloseV2Attr func(optionalAttr) -// RandomShuffleQueueV2Container sets the optional container attribute to value. +// QueueCloseV2CancelPendingEnqueues sets the optional cancel_pending_enqueues attribute to value. // -// value: If non-empty, this queue is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func RandomShuffleQueueV2Container(value string) RandomShuffleQueueV2Attr { +// value: If true, all pending enqueue requests that are +// blocked on the given queue will be canceled. +// If not specified, defaults to false +func QueueCloseV2CancelPendingEnqueues(value bool) QueueCloseV2Attr { return func(m optionalAttr) { - m["container"] = value + m["cancel_pending_enqueues"] = value } } -// RandomShuffleQueueV2SharedName sets the optional shared_name attribute to value. +// Closes the given queue. // -// value: If non-empty, this queue will be shared under the given name -// across multiple sessions. -// If not specified, defaults to "" -func RandomShuffleQueueV2SharedName(value string) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A queue that randomizes the order of elements. +// This operation signals that no more elements will be enqueued in the +// given queue. Subsequent Enqueue(Many) operations will fail. +// Subsequent Dequeue(Many) operations will continue to succeed if +// sufficient elements remain in the queue. Subsequent Dequeue(Many) +// operations that would block will fail immediately. // // Arguments: -// component_types: The type of each component in a value. +// handle: The handle to a queue. // -// Returns The handle to the queue. -func RandomShuffleQueueV2(scope *Scope, component_types []tf.DataType, optional ...RandomShuffleQueueV2Attr) (handle tf.Output) { +// Returns the created operation. +func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"component_types": component_types} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "RandomShuffleQueueV2", - + Type: "QueueCloseV2", + Input: []tf.Input{ + handle, + }, Attrs: attrs, } + return scope.AddOperation(opspec) +} + +// Computes inverse hyperbolic tangent of x element-wise. +func Atanh(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Atanh", + Input: []tf.Input{ + x, + }, + } op := scope.AddOperation(opspec) return op.Output(0) } -// Draw bounding boxes on a batch of images. -// -// Outputs a copy of `images` but draws on top of the pixels zero or more bounding -// boxes specified by the locations in `boxes`. The coordinates of the each -// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The -// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and -// height of the underlying image. -// -// For example, if an image is 100 x 200 pixels (height x width) and the bounding -// box is `[0.1, 0.2, 0.5, 0.9]`, the upper-left and bottom-right coordinates of -// the bounding box will be `(40, 10)` to `(100, 50)` (in (x,y) coordinates). +// Returns true if queue is closed. // -// Parts of the bounding box may fall outside the image. +// This operation returns true if the queue is closed and false if the queue +// is open. // // Arguments: -// images: 4-D with shape `[batch, height, width, depth]`. A batch of images. -// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding -// boxes. -// -// Returns 4-D with the same shape as `images`. The batch of input images with -// bounding boxes drawn on the images. -func DrawBoundingBoxes(scope *Scope, images tf.Output, boxes tf.Output) (output tf.Output) { +// handle: The handle to a queue. +func QueueIsClosedV2(scope *Scope, handle tf.Output) (is_closed tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "DrawBoundingBoxes", + Type: "QueueIsClosedV2", Input: []tf.Input{ - images, boxes, + handle, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// LearnedUnigramCandidateSamplerAttr is an optional argument to LearnedUnigramCandidateSampler. -type LearnedUnigramCandidateSamplerAttr func(optionalAttr) - -// LearnedUnigramCandidateSamplerSeed sets the optional seed attribute to value. +// Computes the absolute value of a tensor. // -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func LearnedUnigramCandidateSamplerSeed(value int64) LearnedUnigramCandidateSamplerAttr { - return func(m optionalAttr) { - m["seed"] = value +// Given a tensor `x`, this operation returns a tensor containing the absolute +// value of each element in `x`. For example, if x is an input element and y is +// an output element, this operation computes \\(y = |x|\\). +func Abs(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Abs", + Input: []tf.Input{ + x, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// LearnedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// StackV2Attr is an optional argument to StackV2. +type StackV2Attr func(optionalAttr) + +// StackV2StackName sets the optional stack_name attribute to value. // -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func LearnedUnigramCandidateSamplerSeed2(value int64) LearnedUnigramCandidateSamplerAttr { +// value: Overrides the name used for the temporary stack resource. Default +// value is the name of the 'Stack' op (which is guaranteed unique). +// If not specified, defaults to "" +func StackV2StackName(value string) StackV2Attr { return func(m optionalAttr) { - m["seed2"] = value + m["stack_name"] = value } } -// Generates labels for candidate sampling with a learned unigram distribution. -// -// See explanations of candidate sampling and the data formats at -// go/candidate-sampling. -// -// For each batch, this op picks a single set of sampled candidate labels. -// -// The advantages of sampling candidates per-batch are simplicity and the -// possibility of efficient dense matrix multiplication. The disadvantage is that -// the sampled candidates must be chosen independently of the context and of the -// true labels. +// A stack that produces elements in first-in last-out order. // // Arguments: -// true_classes: A batch_size * num_true matrix, in which each row contains the -// IDs of the num_true target_classes in the corresponding original label. -// num_true: Number of true labels per context. -// num_sampled: Number of candidates to randomly sample. -// unique: If unique is true, we sample with rejection, so that all sampled -// candidates in a batch are unique. This requires some approximation to -// estimate the post-rejection sampling probabilities. -// range_max: The sampler will sample integers from the interval [0, range_max). +// max_size: The maximum size of the stack if non-negative. If negative, the stack +// size is unlimited. +// elem_type: The type of the elements on the stack. // -// Returns A vector of length num_sampled, in which each element is -// the ID of a sampled candidate.A batch_size * num_true matrix, representing -// the number of times each candidate is expected to occur in a batch -// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled -// candidate representing the number of times the candidate is expected -// to occur in a batch of sampled candidates. If unique=true, then this is a -// probability. -func LearnedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LearnedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { +// Returns The handle to the stack. +func StackV2(scope *Scope, max_size tf.Output, elem_type tf.DataType, optional ...StackV2Attr) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} + attrs := map[string]interface{}{"elem_type": elem_type} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "LearnedUnigramCandidateSampler", + Type: "StackV2", Input: []tf.Input{ - true_classes, + max_size, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// Computes gradients for the scaled exponential linear (Selu) operation. -// -// Arguments: -// gradients: The backpropagated gradients to the corresponding Selu operation. -// outputs: The outputs of the corresponding Selu operation. +// FusedBatchNormGradV2Attr is an optional argument to FusedBatchNormGradV2. +type FusedBatchNormGradV2Attr func(optionalAttr) + +// FusedBatchNormGradV2Epsilon sets the optional epsilon attribute to value. // -// Returns The gradients: `gradients * (outputs + scale * alpha)` -// if outputs < 0, `scale * gradients` otherwise. -func SeluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "SeluGrad", - Input: []tf.Input{ - gradients, outputs, - }, +// value: A small float number added to the variance of x. +// If not specified, defaults to 0.0001 +func FusedBatchNormGradV2Epsilon(value float32) FusedBatchNormGradV2Attr { + return func(m optionalAttr) { + m["epsilon"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Get the current size of the TensorArray. -// -// Arguments: -// handle: The handle to a TensorArray (output of TensorArray or TensorArrayGrad). -// flow_in: A float scalar that enforces proper chaining of operations. +// FusedBatchNormGradV2DataFormat sets the optional data_format attribute to value. // -// Returns The current size of the TensorArray. -func TensorArraySizeV3(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) { - if scope.Err() != nil { - return +// value: The data format for y_backprop, x, x_backprop. +// Either "NHWC" (default) or "NCHW". +// If not specified, defaults to "NHWC" +func FusedBatchNormGradV2DataFormat(value string) FusedBatchNormGradV2Attr { + return func(m optionalAttr) { + m["data_format"] = value } - opspec := tf.OpSpec{ - Type: "TensorArraySizeV3", - Input: []tf.Input{ - handle, flow_in, - }, +} + +// FusedBatchNormGradV2IsTraining sets the optional is_training attribute to value. +// +// value: A bool value to indicate the operation is for training (default) +// or inference. +// If not specified, defaults to true +func FusedBatchNormGradV2IsTraining(value bool) FusedBatchNormGradV2Attr { + return func(m optionalAttr) { + m["is_training"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Deprecated. Use TensorArrayGradV3 +// Gradient for batch normalization. // -// DEPRECATED at GraphDef version 26: Use TensorArrayWriteV3 -func TensorArrayWriteV2(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +// The size of 1D Tensors matches the dimension C of the 4D Tensors. +// +// Arguments: +// y_backprop: A 4D Tensor for the gradient with respect to y. +// x: A 4D Tensor for input data. +// scale: A 1D Tensor for scaling factor, to scale the normalized x. +// reserve_space_1: When is_training is True, a 1D Tensor for the computed batch +// mean to be reused in gradient computation. When is_training is +// False, a 1D Tensor for the population mean to be reused in both +// 1st and 2nd order gradient computation. +// reserve_space_2: When is_training is True, a 1D Tensor for the computed batch +// variance (inverted variance in the cuDNN case) to be reused in +// gradient computation. When is_training is False, a 1D Tensor +// for the population variance to be reused in both 1st and 2nd +// order gradient computation. +// +// Returns A 4D Tensor for the gradient with respect to x.A 1D Tensor for the gradient with respect to scale.A 1D Tensor for the gradient with respect to offset.Unused placeholder to match the mean input in FusedBatchNorm.Unused placeholder to match the variance input +// in FusedBatchNorm. +func FusedBatchNormGradV2(scope *Scope, y_backprop tf.Output, x tf.Output, scale tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output, optional ...FusedBatchNormGradV2Attr) (x_backprop tf.Output, scale_backprop tf.Output, offset_backprop tf.Output, reserve_space_3 tf.Output, reserve_space_4 tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "TensorArrayWriteV2", + Type: "FusedBatchNormGradV2", Input: []tf.Input{ - handle, index, value, flow_in, + y_backprop, x, scale, reserve_space_1, reserve_space_2, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// SparseReduceMaxAttr is an optional argument to SparseReduceMax. -type SparseReduceMaxAttr func(optionalAttr) +// DecodeCompressedAttr is an optional argument to DecodeCompressed. +type DecodeCompressedAttr func(optionalAttr) -// SparseReduceMaxKeepDims sets the optional keep_dims attribute to value. +// DecodeCompressedCompressionType sets the optional compression_type attribute to value. // -// value: If true, retain reduced dimensions with length 1. -// If not specified, defaults to false -func SparseReduceMaxKeepDims(value bool) SparseReduceMaxAttr { +// value: A scalar containing either (i) the empty string (no +// compression), (ii) "ZLIB", or (iii) "GZIP". +// If not specified, defaults to "" +func DecodeCompressedCompressionType(value string) DecodeCompressedAttr { return func(m optionalAttr) { - m["keep_dims"] = value + m["compression_type"] = value } } -// Computes the max of elements across dimensions of a SparseTensor. -// -// This Op takes a SparseTensor and is the sparse counterpart to -// `tf.reduce_max()`. In particular, this Op also returns a dense `Tensor` -// instead of a sparse one. +// Decompress strings. // -// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained -// with length 1. +// This op decompresses each element of the `bytes` input `Tensor`, which +// is assumed to be compressed using the given `compression_type`. // -// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor -// with a single element is returned. Additionally, the axes can be negative, -// which are interpreted according to the indexing rules in Python. +// The `output` is a string `Tensor` of the same shape as `bytes`, +// each element containing the decompressed data from the corresponding +// element in `bytes`. // // Arguments: -// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. -// input_shape: 1-D. Shape of the input SparseTensor. -// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. +// bytes: A Tensor of string which is compressed. // -// Returns `R-K`-D. The reduced Tensor. -func SparseReduceMax(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceMaxAttr) (output tf.Output) { +// Returns A Tensor with the same shape as input `bytes`, uncompressed +// from bytes. +func DecodeCompressed(scope *Scope, bytes tf.Output, optional ...DecodeCompressedAttr) (output tf.Output) { if scope.Err() != nil { return } @@ -28087,9 +29137,9 @@ func SparseReduceMax(scope *Scope, input_indices tf.Output, input_values tf.Outp a(attrs) } opspec := tf.OpSpec{ - Type: "SparseReduceMax", + Type: "DecodeCompressed", Input: []tf.Input{ - input_indices, input_values, input_shape, reduction_axes, + bytes, }, Attrs: attrs, } @@ -28097,68 +29147,98 @@ func SparseReduceMax(scope *Scope, input_indices tf.Output, input_values tf.Outp return op.Output(0) } -// AsStringAttr is an optional argument to AsString. -type AsStringAttr func(optionalAttr) +// CudnnRNNAttr is an optional argument to CudnnRNN. +type CudnnRNNAttr func(optionalAttr) -// AsStringPrecision sets the optional precision attribute to value. -// -// value: The post-decimal precision to use for floating point numbers. -// Only used if precision > -1. -// If not specified, defaults to -1 -func AsStringPrecision(value int64) AsStringAttr { +// CudnnRNNRnnMode sets the optional rnn_mode attribute to value. +// If not specified, defaults to "lstm" +func CudnnRNNRnnMode(value string) CudnnRNNAttr { + return func(m optionalAttr) { + m["rnn_mode"] = value + } +} + +// CudnnRNNInputMode sets the optional input_mode attribute to value. +// If not specified, defaults to "linear_input" +func CudnnRNNInputMode(value string) CudnnRNNAttr { + return func(m optionalAttr) { + m["input_mode"] = value + } +} + +// CudnnRNNDirection sets the optional direction attribute to value. +// If not specified, defaults to "unidirectional" +func CudnnRNNDirection(value string) CudnnRNNAttr { return func(m optionalAttr) { - m["precision"] = value + m["direction"] = value } } -// AsStringScientific sets the optional scientific attribute to value. -// -// value: Use scientific notation for floating point numbers. -// If not specified, defaults to false -func AsStringScientific(value bool) AsStringAttr { +// CudnnRNNDropout sets the optional dropout attribute to value. +// If not specified, defaults to 0 +func CudnnRNNDropout(value float32) CudnnRNNAttr { return func(m optionalAttr) { - m["scientific"] = value + m["dropout"] = value } } -// AsStringShortest sets the optional shortest attribute to value. -// -// value: Use shortest representation (either scientific or standard) for -// floating point numbers. -// If not specified, defaults to false -func AsStringShortest(value bool) AsStringAttr { +// CudnnRNNSeed sets the optional seed attribute to value. +// If not specified, defaults to 0 +func CudnnRNNSeed(value int64) CudnnRNNAttr { return func(m optionalAttr) { - m["shortest"] = value + m["seed"] = value } } -// AsStringWidth sets the optional width attribute to value. -// -// value: Pad pre-decimal numbers to this width. -// Applies to both floating point and integer numbers. -// Only used if width > -1. -// If not specified, defaults to -1 -func AsStringWidth(value int64) AsStringAttr { +// CudnnRNNSeed2 sets the optional seed2 attribute to value. +// If not specified, defaults to 0 +func CudnnRNNSeed2(value int64) CudnnRNNAttr { return func(m optionalAttr) { - m["width"] = value + m["seed2"] = value } } -// AsStringFill sets the optional fill attribute to value. -// -// value: The value to pad if width > -1. If empty, pads with spaces. -// Another typical value is '0'. String cannot be longer than 1 character. -// If not specified, defaults to "" -func AsStringFill(value string) AsStringAttr { +// CudnnRNNIsTraining sets the optional is_training attribute to value. +// If not specified, defaults to true +func CudnnRNNIsTraining(value bool) CudnnRNNAttr { return func(m optionalAttr) { - m["fill"] = value + m["is_training"] = value } } -// Converts each entry in the given tensor to strings. Supports many numeric +// A RNN backed by cuDNN. // -// types and boolean. -func AsString(scope *Scope, input tf.Output, optional ...AsStringAttr) (output tf.Output) { +// Computes the RNN from the input and initial states, with respect to the params +// buffer. +// +// rnn_mode: Indicates the type of the RNN model. +// input_mode: Indicate whether there is a linear projection between the input and +// the actual computation before the first layer. 'skip_input' is only allowed +// when input_size == num_units; 'auto_select' implies 'skip_input' when +// input_size == num_units; otherwise, it implies 'linear_input'. +// direction: Indicates whether a bidirectional model will be used. Should be +// "unidirectional" or "bidirectional". +// dropout: Dropout probability. When set to 0., dropout is disabled. +// seed: The 1st part of a seed to initialize dropout. +// seed2: The 2nd part of a seed to initialize dropout. +// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. +// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, +// num_units]. +// input_c: For LSTM, a 3-D tensor with the shape of +// [num_layer * dir, batch, num_units]. For other models, it is ignored. +// params: A 1-D tensor that contains the weights and biases in an opaque layout. +// The size must be created through CudnnRNNParamsSize, and initialized +// separately. Note that they might not be compatible across different +// generations. So it is a good idea to save and restore +// output: A 3-D tensor with the shape of [seq_length, batch_size, +// dir * num_units]. +// output_h: The same shape has input_h. +// output_c: The same shape as input_c for LSTM. An empty tensor for other models. +// is_training: Indicates whether this operation is used for inferenece or +// training. +// reserve_space: An opaque tensor that can be used in backprop calculation. It +// is only produced if is_training is false. +func CudnnRNN(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Output, params tf.Output, optional ...CudnnRNNAttr) (output tf.Output, output_h tf.Output, output_c tf.Output, reserve_space tf.Output) { if scope.Err() != nil { return } @@ -28167,850 +29247,1127 @@ func AsString(scope *Scope, input tf.Output, optional ...AsStringAttr) (output t a(attrs) } opspec := tf.OpSpec{ - Type: "AsString", + Type: "CudnnRNN", Input: []tf.Input{ - input, + input, input_h, input_c, params, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// Deprecated. Use TensorArrayScatterV3 +// Creates a TensorArray for storing multiple gradients of values in the given handle. // -// DEPRECATED at GraphDef version 26: Use TensorArrayScatterV3 -func TensorArrayScatterV2(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { +// Similar to TensorArrayGradV3. However it creates an accumulator with an +// expanded shape compared to the input TensorArray whose gradient is being +// computed. This enables multiple gradients for the same TensorArray to be +// calculated using the same accumulator. +// +// Arguments: +// handle: The handle to the forward TensorArray. +// flow_in: A float scalar that enforces proper chaining of operations. +// shape_to_prepend: An int32 vector representing a shape. Elements in the gradient accumulator will +// have shape which is this shape_to_prepend value concatenated with shape of the +// elements in the TensorArray corresponding to the input handle. +// source: The gradient source string, used to decide which gradient TensorArray +// to return. +func TensorArrayGradWithShape(scope *Scope, handle tf.Output, flow_in tf.Output, shape_to_prepend tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"source": source} opspec := tf.OpSpec{ - Type: "TensorArrayScatterV2", + Type: "TensorArrayGradWithShape", Input: []tf.Input{ - handle, indices, value, flow_in, + handle, flow_in, shape_to_prepend, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Creates a tree ensemble model and returns a handle to it. +// Compare values of `input` to `threshold` and pack resulting bits into a `uint8`. +// +// Each comparison returns a boolean `true` (if `input_value > threshold`) +// or and `false` otherwise. +// +// This operation is useful for Locality-Sensitive-Hashing (LSH) and other +// algorithms that use hashing approximations of cosine and `L2` distances; +// codes can be generated from an input via: +// +// ```python +// codebook_size = 50 +// codebook_bits = codebook_size * 32 +// codebook = tf.get_variable('codebook', [x.shape[-1].value, codebook_bits], +// dtype=x.dtype, +// initializer=tf.orthogonal_initializer()) +// codes = compare_and_threshold(tf.matmul(x, codebook), threshold=0.) +// codes = tf.bitcast(codes, tf.int32) # go from uint8 to int32 +// # now codes has shape x.shape[:-1] + [codebook_size] +// ``` +// +// **NOTE**: Currently, the innermost dimension of the tensor must be divisible +// by 8. +// +// Given an `input` shaped `[s0, s1, ..., s_n]`, the output is +// a `uint8` tensor shaped `[s0, s1, ..., s_n / 8]`. // // Arguments: -// tree_ensemble_handle: Handle to the tree ensemble resource to be created. -// stamp_token: Token to use as the initial value of the resource stamp. -// tree_ensemble_serialized: Serialized proto of the tree ensemble. +// input: Values to compare against `threshold` and bitpack. +// threshold: Threshold to compare against. // -// Returns the created operation. -func BoostedTreesCreateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) { +// Returns The bitpacked comparisons. +func CompareAndBitpack(scope *Scope, input tf.Output, threshold tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "BoostedTreesCreateEnsemble", + Type: "CompareAndBitpack", Input: []tf.Input{ - tree_ensemble_handle, stamp_token, tree_ensemble_serialized, + input, threshold, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// Applies sparse addition to `input` using individual values or slices -// -// from `updates` according to indices `indices`. The updates are non-aliasing: -// `input` is only modified in-place if no other operations will use it. -// Otherwise, a copy of `input` is made. This operation has a gradient with -// respect to both `input` and `updates`. -// -// `input` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. -// -// `indices` must be integer tensor, containing indices into `input`. -// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. -// -// The innermost dimension of `indices` (with length `K`) corresponds to -// indices into elements (if `K = P`) or `(P-K)`-dimensional slices -// (if `K < P`) along the `K`th dimension of `input`. -// -// `updates` is `Tensor` of rank `Q-1+P-K` with shape: -// -// ``` -// [d_0, ..., d_{Q-2}, input.shape[K], ..., input.shape[P-1]]. -// ``` -// -// For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 -// elements. In Python, that addition would look like this: -// -// input = tf.constant([1, 2, 3, 4, 5, 6, 7, 8]) -// indices = tf.constant([[4], [3], [1], [7]]) -// updates = tf.constant([9, 10, 11, 12]) -// output = tf.scatter_nd_non_aliasing_add(input, indices, updates) -// with tf.Session() as sess: -// print(sess.run(output)) +// Push an element onto the tensor_array. // -// The resulting value `output` would look like this: +// Arguments: +// handle: The handle to a TensorArray. +// index: The position to write to inside the TensorArray. +// value: The tensor to write to the TensorArray. +// flow_in: A float scalar that enforces proper chaining of operations. // -// [1, 13, 3, 14, 14, 6, 7, 20] +// Returns A float scalar that enforces proper chaining of operations. +func TensorArrayWriteV3(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorArrayWriteV3", + Input: []tf.Input{ + handle, index, value, flow_in, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Scatter the data from the input value into specific TensorArray elements. // -// See @{tf.scatter_nd} for more details about how to make updates to slices. +// `indices` must be a vector, its length must match the first dim of `value`. // // Arguments: -// input: A Tensor. -// indices: A Tensor. Must be one of the following types: `int32`, `int64`. -// A tensor of indices into `input`. -// updates: A Tensor. Must have the same type as ref. A tensor of updated values -// to add to `input`. +// handle: The handle to a TensorArray. +// indices: The locations at which to write the tensor elements. +// value: The concatenated tensor to write to the TensorArray. +// flow_in: A float scalar that enforces proper chaining of operations. // -// Returns A `Tensor` with the same shape as `input`, containing values of `input` -// updated with `updates`. -func ScatterNdNonAliasingAdd(scope *Scope, input tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) { +// Returns A float scalar that enforces proper chaining of operations. +func TensorArrayScatterV3(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "ScatterNdNonAliasingAdd", + Type: "TensorArrayScatterV3", Input: []tf.Input{ - input, indices, updates, + handle, indices, value, flow_in, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// FractionalMaxPoolAttr is an optional argument to FractionalMaxPool. -type FractionalMaxPoolAttr func(optionalAttr) +// EmptyAttr is an optional argument to Empty. +type EmptyAttr func(optionalAttr) -// FractionalMaxPoolPseudoRandom sets the optional pseudo_random attribute to value. +// EmptyInit sets the optional init attribute to value. // -// value: When set to True, generates the pooling sequence in a -// pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin -// Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for -// difference between pseudorandom and random. +// value: If True, initialize the returned tensor with the default value of dtype. Otherwise, the implementation is free not to initializethe tensor's content. // If not specified, defaults to false -func FractionalMaxPoolPseudoRandom(value bool) FractionalMaxPoolAttr { +func EmptyInit(value bool) EmptyAttr { return func(m optionalAttr) { - m["pseudo_random"] = value + m["init"] = value } } -// FractionalMaxPoolOverlapping sets the optional overlapping attribute to value. +// Creates a tensor with the given shape. // -// value: When set to True, it means when pooling, the values at the boundary -// of adjacent pooling cells are used by both cells. For example: +// This operation creates a tensor of `shape` and `dtype`. // -// `index 0 1 2 3 4` +// Arguments: +// shape: 1-D. Represents the shape of the output tensor. // -// `value 20 5 16 3 7` // -// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. -// The result would be [20, 16] for fractional max pooling. -// If not specified, defaults to false -func FractionalMaxPoolOverlapping(value bool) FractionalMaxPoolAttr { - return func(m optionalAttr) { - m["overlapping"] = value +// Returns A `Tensor` of type `T`. +func Empty(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...EmptyAttr) (output tf.Output) { + if scope.Err() != nil { + return } -} - -// FractionalMaxPoolDeterministic sets the optional deterministic attribute to value. -// -// value: When set to True, a fixed pooling region will be used when -// iterating over a FractionalMaxPool node in the computation graph. Mainly used -// in unit test to make FractionalMaxPool deterministic. -// If not specified, defaults to false -func FractionalMaxPoolDeterministic(value bool) FractionalMaxPoolAttr { - return func(m optionalAttr) { - m["deterministic"] = value + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Empty", + Input: []tf.Input{ + shape, + }, + Attrs: attrs, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// FractionalMaxPoolSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func FractionalMaxPoolSeed(value int64) FractionalMaxPoolAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} +// TensorArrayConcatV3Attr is an optional argument to TensorArrayConcatV3. +type TensorArrayConcatV3Attr func(optionalAttr) -// FractionalMaxPoolSeed2 sets the optional seed2 attribute to value. +// TensorArrayConcatV3ElementShapeExcept0 sets the optional element_shape_except0 attribute to value. // -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func FractionalMaxPoolSeed2(value int64) FractionalMaxPoolAttr { +// value: The expected shape of an element, if known, +// excluding the first dimension. Used to validate the shapes of +// TensorArray elements. If this shape is not fully specified, concatenating +// zero-size TensorArrays is an error. +// If not specified, defaults to +func TensorArrayConcatV3ElementShapeExcept0(value tf.Shape) TensorArrayConcatV3Attr { return func(m optionalAttr) { - m["seed2"] = value + m["element_shape_except0"] = value } } -// Performs fractional max pooling on the input. -// -// Fractional max pooling is slightly different than regular max pooling. In -// regular max pooling, you downsize an input set by taking the maximum value of -// smaller N x N subsections of the set (often 2x2), and try to reduce the set by -// a factor of N, where N is an integer. Fractional max pooling, as you might -// expect from the word "fractional", means that the overall reduction ratio N -// does not have to be an integer. -// -// The sizes of the pooling regions are generated randomly but are fairly uniform. -// For example, let's look at the height dimension, and the constraints on the -// list of rows that will be pool boundaries. +// Concat the elements from the TensorArray into value `value`. // -// First we define the following: +// Takes `T` elements of shapes // -// 1. input_row_length : the number of rows from the input set -// 2. output_row_length : which will be smaller than the input -// 3. alpha = input_row_length / output_row_length : our reduction ratio -// 4. K = floor(alpha) -// 5. row_pooling_sequence : this is the result list of pool boundary rows +// ``` +// (n0 x d0 x d1 x ...), (n1 x d0 x d1 x ...), ..., (n(T-1) x d0 x d1 x ...) +// ``` // -// Then, row_pooling_sequence should satisfy: +// and concatenates them into a Tensor of shape: // -// 1. a[0] = 0 : the first value of the sequence is 0 -// 2. a[end] = input_row_length : the last value of the sequence is the size -// 3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size -// 4. length(row_pooling_sequence) = output_row_length+1 +// ```(n0 + n1 + ... + n(T-1) x d0 x d1 x ...)``` // -// For more details on fractional max pooling, see this paper: -// [Benjamin Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) +// All elements must have the same shape (excepting the first dimension). // // Arguments: -// value: 4-D with shape `[batch, height, width, channels]`. -// pooling_ratio: Pooling ratio for each dimension of `value`, currently only -// supports row and col dimension and should be >= 1.0. For example, a valid -// pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements -// must be 1.0 because we don't allow pooling on batch and channels -// dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions -// respectively. +// handle: The handle to a TensorArray. +// flow_in: A float scalar that enforces proper chaining of operations. +// dtype: The type of the elem that is returned. // -// Returns output tensor after fractional max pooling.row pooling sequence, needed to calculate gradient.column pooling sequence, needed to calculate gradient. -func FractionalMaxPool(scope *Scope, value tf.Output, pooling_ratio []float32, optional ...FractionalMaxPoolAttr) (output tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output) { +// Returns All of the elements in the TensorArray, concatenated along the first +// axis.A vector of the row sizes of the original T elements in the +// value output. In the example above, this would be the values: +// `(n1, n2, ..., n(T-1))`. +func TensorArrayConcatV3(scope *Scope, handle tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayConcatV3Attr) (value tf.Output, lengths tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"pooling_ratio": pooling_ratio} + attrs := map[string]interface{}{"dtype": dtype} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "FractionalMaxPool", + Type: "TensorArrayConcatV3", Input: []tf.Input{ - value, + handle, flow_in, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0), op.Output(1) } -// Deprecated. Use TensorArraySizeV3 +// Split the data from the input value into TensorArray elements. // -// DEPRECATED at GraphDef version 26: Use TensorArraySizeV3 -func TensorArraySizeV2(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) { +// Assuming that `lengths` takes on values +// +// ```(n0, n1, ..., n(T-1))``` +// +// and that `value` has shape +// +// ```(n0 + n1 + ... + n(T-1) x d0 x d1 x ...)```, +// +// this splits values into a TensorArray with T tensors. +// +// TensorArray index t will be the subtensor of values with starting position +// +// ```(n0 + n1 + ... + n(t-1), 0, 0, ...)``` +// +// and having size +// +// ```nt x d0 x d1 x ...``` +// +// Arguments: +// handle: The handle to a TensorArray. +// value: The concatenated tensor to write to the TensorArray. +// lengths: The vector of lengths, how to split the rows of value into the +// TensorArray. +// flow_in: A float scalar that enforces proper chaining of operations. +// +// Returns A float scalar that enforces proper chaining of operations. +func TensorArraySplitV3(scope *Scope, handle tf.Output, value tf.Output, lengths tf.Output, flow_in tf.Output) (flow_out tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "TensorArraySizeV2", + Type: "TensorArraySplitV3", Input: []tf.Input{ - handle, flow_in, + handle, value, lengths, flow_in, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Conv2DAttr is an optional argument to Conv2D. -type Conv2DAttr func(optionalAttr) - -// Conv2DUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. -// If not specified, defaults to true -func Conv2DUseCudnnOnGpu(value bool) Conv2DAttr { - return func(m optionalAttr) { - m["use_cudnn_on_gpu"] = value - } -} - -// Conv2DDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, height, width, channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, channels, height, width]. -// If not specified, defaults to "NHWC" -func Conv2DDataFormat(value string) Conv2DAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} +// SerializeSparseAttr is an optional argument to SerializeSparse. +type SerializeSparseAttr func(optionalAttr) -// Conv2DDilations sets the optional dilations attribute to value. +// SerializeSparseOutType sets the optional out_type attribute to value. // -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each -// filter element on that dimension. The dimension order is determined by the -// value of `data_format`, see above for details. Dilations in the batch and -// depth dimensions must be 1. -// If not specified, defaults to -func Conv2DDilations(value []int64) Conv2DAttr { +// value: The `dtype` to use for serialization; the supported types are `string` +// (default) and `variant`. +// If not specified, defaults to DT_STRING +func SerializeSparseOutType(value tf.DataType) SerializeSparseAttr { return func(m optionalAttr) { - m["dilations"] = value + m["out_type"] = value } } -// Computes a 2-D convolution given 4-D `input` and `filter` tensors. -// -// Given an input tensor of shape `[batch, in_height, in_width, in_channels]` -// and a filter / kernel tensor of shape -// `[filter_height, filter_width, in_channels, out_channels]`, this op -// performs the following: -// -// 1. Flattens the filter to a 2-D matrix with shape -// `[filter_height * filter_width * in_channels, output_channels]`. -// 2. Extracts image patches from the input tensor to form a *virtual* -// tensor of shape `[batch, out_height, out_width, -// filter_height * filter_width * in_channels]`. -// 3. For each patch, right-multiplies the filter matrix and the image patch -// vector. -// -// In detail, with the default NHWC format, -// -// output[b, i, j, k] = -// sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] * -// filter[di, dj, q, k] -// -// Must have `strides[0] = strides[3] = 1`. For the most common case of the same -// horizontal and vertices strides, `strides = [1, stride, stride, 1]`. +// Serialize a `SparseTensor` into a `[3]` `Tensor` object. // // Arguments: -// input: A 4-D tensor. The dimension order is interpreted according to the value -// of `data_format`, see below for details. -// filter: A 4-D tensor of shape -// `[filter_height, filter_width, in_channels, out_channels]` -// strides: 1-D tensor of length 4. The stride of the sliding window for each -// dimension of `input`. The dimension order is determined by the value of -// `data_format`, see below for details. -// padding: The type of padding algorithm to use. -// -// Returns A 4-D tensor. The dimension order is determined by the value of -// `data_format`, see below for details. -func Conv2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...Conv2DAttr) (output tf.Output) { +// sparse_indices: 2-D. The `indices` of the `SparseTensor`. +// sparse_values: 1-D. The `values` of the `SparseTensor`. +// sparse_shape: 1-D. The `shape` of the `SparseTensor`. +func SerializeSparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeSparseAttr) (serialized_sparse tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Conv2D", + Type: "SerializeSparse", Input: []tf.Input{ - input, filter, + sparse_indices, sparse_values, sparse_shape, }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RandomShuffleQueueV2Attr is an optional argument to RandomShuffleQueueV2. +type RandomShuffleQueueV2Attr func(optionalAttr) + +// RandomShuffleQueueV2Shapes sets the optional shapes attribute to value. +// +// value: The shape of each component in a value. The length of this attr must +// be either 0 or the same as the length of component_types. If the length of +// this attr is 0, the shapes of queue elements are not constrained, and +// only one element may be dequeued at a time. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func RandomShuffleQueueV2Shapes(value []tf.Shape) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["shapes"] = value + } +} + +// RandomShuffleQueueV2Capacity sets the optional capacity attribute to value. +// +// value: The upper bound on the number of elements in this queue. +// Negative numbers mean no limit. +// If not specified, defaults to -1 +func RandomShuffleQueueV2Capacity(value int64) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// RandomShuffleQueueV2MinAfterDequeue sets the optional min_after_dequeue attribute to value. +// +// value: Dequeue will block unless there would be this +// many elements after the dequeue or the queue is closed. This +// ensures a minimum level of mixing of elements. +// If not specified, defaults to 0 +func RandomShuffleQueueV2MinAfterDequeue(value int64) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["min_after_dequeue"] = value + } } -// StageAttr is an optional argument to Stage. -type StageAttr func(optionalAttr) - -// StageCapacity sets the optional capacity attribute to value. +// RandomShuffleQueueV2Seed sets the optional seed attribute to value. // -// value: Maximum number of elements in the Staging Area. If > 0, inserts -// on the container will block when the capacity is reached. +// value: If either seed or seed2 is set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, a random seed is used. // If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func StageCapacity(value int64) StageAttr { +func RandomShuffleQueueV2Seed(value int64) RandomShuffleQueueV2Attr { return func(m optionalAttr) { - m["capacity"] = value + m["seed"] = value } } -// StageMemoryLimit sets the optional memory_limit attribute to value. +// RandomShuffleQueueV2Seed2 sets the optional seed2 attribute to value. // -// value: The maximum number of bytes allowed for Tensors in the Staging Area. -// If > 0, inserts will block until sufficient space is available. +// value: A second seed to avoid seed collision. // If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func StageMemoryLimit(value int64) StageAttr { +func RandomShuffleQueueV2Seed2(value int64) RandomShuffleQueueV2Attr { return func(m optionalAttr) { - m["memory_limit"] = value + m["seed2"] = value } } -// StageContainer sets the optional container attribute to value. +// RandomShuffleQueueV2Container sets the optional container attribute to value. // -// value: If non-empty, this queue is placed in the given container. Otherwise, -// a default container is used. +// value: If non-empty, this queue is placed in the given container. +// Otherwise, a default container is used. // If not specified, defaults to "" -func StageContainer(value string) StageAttr { +func RandomShuffleQueueV2Container(value string) RandomShuffleQueueV2Attr { return func(m optionalAttr) { m["container"] = value } } -// StageSharedName sets the optional shared_name attribute to value. +// RandomShuffleQueueV2SharedName sets the optional shared_name attribute to value. // -// value: It is necessary to match this name to the matching Unstage Op. +// value: If non-empty, this queue will be shared under the given name +// across multiple sessions. // If not specified, defaults to "" -func StageSharedName(value string) StageAttr { +func RandomShuffleQueueV2SharedName(value string) RandomShuffleQueueV2Attr { return func(m optionalAttr) { m["shared_name"] = value } } -// Stage values similar to a lightweight Enqueue. -// -// The basic functionality of this Op is similar to a queue with many -// fewer capabilities and options. This Op is optimized for performance. +// A queue that randomizes the order of elements. // // Arguments: -// values: a list of tensors -// dtypes A list of data types that inserted values should adhere to. +// component_types: The type of each component in a value. // -// Returns the created operation. -func Stage(scope *Scope, values []tf.Output, optional ...StageAttr) (o *tf.Operation) { +// Returns The handle to the queue. +func RandomShuffleQueueV2(scope *Scope, component_types []tf.DataType, optional ...RandomShuffleQueueV2Attr) (handle tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"component_types": component_types} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Stage", - Input: []tf.Input{ - tf.OutputList(values), - }, + Type: "RandomShuffleQueueV2", + Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// StagePeekAttr is an optional argument to StagePeek. -type StagePeekAttr func(optionalAttr) - -// StagePeekCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// Draw bounding boxes on a batch of images. // -// REQUIRES: value >= 0 -func StagePeekCapacity(value int64) StagePeekAttr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// StagePeekMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// Outputs a copy of `images` but draws on top of the pixels zero or more bounding +// boxes specified by the locations in `boxes`. The coordinates of the each +// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The +// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and +// height of the underlying image. // -// REQUIRES: value >= 0 -func StagePeekMemoryLimit(value int64) StagePeekAttr { - return func(m optionalAttr) { - m["memory_limit"] = value +// For example, if an image is 100 x 200 pixels (height x width) and the bounding +// box is `[0.1, 0.2, 0.5, 0.9]`, the upper-left and bottom-right coordinates of +// the bounding box will be `(40, 10)` to `(180, 50)` (in (x,y) coordinates). +// +// Parts of the bounding box may fall outside the image. +// +// Arguments: +// images: 4-D with shape `[batch, height, width, depth]`. A batch of images. +// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding +// boxes. +// +// Returns 4-D with the same shape as `images`. The batch of input images with +// bounding boxes drawn on the images. +func DrawBoundingBoxes(scope *Scope, images tf.Output, boxes tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DrawBoundingBoxes", + Input: []tf.Input{ + images, boxes, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// StagePeekContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func StagePeekContainer(value string) StagePeekAttr { +// LearnedUnigramCandidateSamplerAttr is an optional argument to LearnedUnigramCandidateSampler. +type LearnedUnigramCandidateSamplerAttr func(optionalAttr) + +// LearnedUnigramCandidateSamplerSeed sets the optional seed attribute to value. +// +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func LearnedUnigramCandidateSamplerSeed(value int64) LearnedUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["container"] = value + m["seed"] = value } } -// StagePeekSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func StagePeekSharedName(value string) StagePeekAttr { +// LearnedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func LearnedUnigramCandidateSamplerSeed2(value int64) LearnedUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["seed2"] = value } } -// Op peeks at the values at the specified index. If the +// Generates labels for candidate sampling with a learned unigram distribution. // -// underlying container does not contain sufficient elements -// this op will block until it does. This Op is optimized for -// performance. -func StagePeek(scope *Scope, index tf.Output, dtypes []tf.DataType, optional ...StagePeekAttr) (values []tf.Output) { +// See explanations of candidate sampling and the data formats at +// go/candidate-sampling. +// +// For each batch, this op picks a single set of sampled candidate labels. +// +// The advantages of sampling candidates per-batch are simplicity and the +// possibility of efficient dense matrix multiplication. The disadvantage is that +// the sampled candidates must be chosen independently of the context and of the +// true labels. +// +// Arguments: +// true_classes: A batch_size * num_true matrix, in which each row contains the +// IDs of the num_true target_classes in the corresponding original label. +// num_true: Number of true labels per context. +// num_sampled: Number of candidates to randomly sample. +// unique: If unique is true, we sample with rejection, so that all sampled +// candidates in a batch are unique. This requires some approximation to +// estimate the post-rejection sampling probabilities. +// range_max: The sampler will sample integers from the interval [0, range_max). +// +// Returns A vector of length num_sampled, in which each element is +// the ID of a sampled candidate.A batch_size * num_true matrix, representing +// the number of times each candidate is expected to occur in a batch +// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled +// candidate representing the number of times the candidate is expected +// to occur in a batch of sampled candidates. If unique=true, then this is a +// probability. +func LearnedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LearnedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "StagePeek", + Type: "LearnedUnigramCandidateSampler", Input: []tf.Input{ - index, + true_classes, }, Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// Computes gradients for the scaled exponential linear (Selu) operation. +// +// Arguments: +// gradients: The backpropagated gradients to the corresponding Selu operation. +// outputs: The outputs of the corresponding Selu operation. +// +// Returns The gradients: `gradients * (outputs + scale * alpha)` +// if outputs < 0, `scale * gradients` otherwise. +func SeluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) { if scope.Err() != nil { return } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("StagePeek", err) - return + opspec := tf.OpSpec{ + Type: "SeluGrad", + Input: []tf.Input{ + gradients, outputs, + }, } - return values + op := scope.AddOperation(opspec) + return op.Output(0) } -// MapStageAttr is an optional argument to MapStage. -type MapStageAttr func(optionalAttr) - -// MapStageCapacity sets the optional capacity attribute to value. +// Get the current size of the TensorArray. // -// value: Maximum number of elements in the Staging Area. If > 0, inserts -// on the container will block when the capacity is reached. -// If not specified, defaults to 0 +// Arguments: +// handle: The handle to a TensorArray (output of TensorArray or TensorArrayGrad). +// flow_in: A float scalar that enforces proper chaining of operations. // -// REQUIRES: value >= 0 -func MapStageCapacity(value int64) MapStageAttr { - return func(m optionalAttr) { - m["capacity"] = value +// Returns The current size of the TensorArray. +func TensorArraySizeV3(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) { + if scope.Err() != nil { + return } -} - -// MapStageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func MapStageMemoryLimit(value int64) MapStageAttr { - return func(m optionalAttr) { - m["memory_limit"] = value + opspec := tf.OpSpec{ + Type: "TensorArraySizeV3", + Input: []tf.Input{ + handle, flow_in, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// MapStageContainer sets the optional container attribute to value. +// Deprecated. Use TensorArrayGradV3 // -// value: If non-empty, this queue is placed in the given container. Otherwise, -// a default container is used. -// If not specified, defaults to "" -func MapStageContainer(value string) MapStageAttr { - return func(m optionalAttr) { - m["container"] = value +// DEPRECATED at GraphDef version 26: Use TensorArrayWriteV3 +func TensorArrayWriteV2(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorArrayWriteV2", + Input: []tf.Input{ + handle, index, value, flow_in, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// MapStageSharedName sets the optional shared_name attribute to value. +// SparseReduceMaxAttr is an optional argument to SparseReduceMax. +type SparseReduceMaxAttr func(optionalAttr) + +// SparseReduceMaxKeepDims sets the optional keep_dims attribute to value. // -// value: It is necessary to match this name to the matching Unstage Op. -// If not specified, defaults to "" -func MapStageSharedName(value string) MapStageAttr { +// value: If true, retain reduced dimensions with length 1. +// If not specified, defaults to false +func SparseReduceMaxKeepDims(value bool) SparseReduceMaxAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["keep_dims"] = value } } -// Stage (key, values) in the underlying container which behaves like a hashtable. +// Computes the max of elements across dimensions of a SparseTensor. // -// Arguments: -// key: int64 +// This Op takes a SparseTensor and is the sparse counterpart to +// `tf.reduce_max()`. In particular, this Op also returns a dense `Tensor` +// instead of a sparse one. // -// values: a list of tensors -// dtypes A list of data types that inserted values should adhere to. +// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained +// with length 1. +// +// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor +// with a single element is returned. Additionally, the axes can be negative, +// which are interpreted according to the indexing rules in Python. // +// Arguments: +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. +// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. // -// Returns the created operation. -func MapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf.Output, dtypes []tf.DataType, optional ...MapStageAttr) (o *tf.Operation) { +// Returns `R-K`-D. The reduced Tensor. +func SparseReduceMax(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceMaxAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MapStage", + Type: "SparseReduceMax", Input: []tf.Input{ - key, indices, tf.OutputList(values), + input_indices, input_values, input_shape, reduction_axes, }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0) } -// MapUnstageAttr is an optional argument to MapUnstage. -type MapUnstageAttr func(optionalAttr) +// AsStringAttr is an optional argument to AsString. +type AsStringAttr func(optionalAttr) -// MapUnstageCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// AsStringPrecision sets the optional precision attribute to value. // -// REQUIRES: value >= 0 -func MapUnstageCapacity(value int64) MapUnstageAttr { +// value: The post-decimal precision to use for floating point numbers. +// Only used if precision > -1. +// If not specified, defaults to -1 +func AsStringPrecision(value int64) AsStringAttr { return func(m optionalAttr) { - m["capacity"] = value + m["precision"] = value } } -// MapUnstageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// AsStringScientific sets the optional scientific attribute to value. // -// REQUIRES: value >= 0 -func MapUnstageMemoryLimit(value int64) MapUnstageAttr { +// value: Use scientific notation for floating point numbers. +// If not specified, defaults to false +func AsStringScientific(value bool) AsStringAttr { return func(m optionalAttr) { - m["memory_limit"] = value + m["scientific"] = value } } -// MapUnstageContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func MapUnstageContainer(value string) MapUnstageAttr { +// AsStringShortest sets the optional shortest attribute to value. +// +// value: Use shortest representation (either scientific or standard) for +// floating point numbers. +// If not specified, defaults to false +func AsStringShortest(value bool) AsStringAttr { return func(m optionalAttr) { - m["container"] = value + m["shortest"] = value } } -// MapUnstageSharedName sets the optional shared_name attribute to value. +// AsStringWidth sets the optional width attribute to value. +// +// value: Pad pre-decimal numbers to this width. +// Applies to both floating point and integer numbers. +// Only used if width > -1. +// If not specified, defaults to -1 +func AsStringWidth(value int64) AsStringAttr { + return func(m optionalAttr) { + m["width"] = value + } +} + +// AsStringFill sets the optional fill attribute to value. +// +// value: The value to pad if width > -1. If empty, pads with spaces. +// Another typical value is '0'. String cannot be longer than 1 character. // If not specified, defaults to "" -func MapUnstageSharedName(value string) MapUnstageAttr { +func AsStringFill(value string) AsStringAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["fill"] = value } } -// Op removes and returns the values associated with the key +// Converts each entry in the given tensor to strings. Supports many numeric // -// from the underlying container. If the underlying container -// does not contain this key, the op will block until it does. -func MapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageAttr) (values []tf.Output) { +// types and boolean. +func AsString(scope *Scope, input tf.Output, optional ...AsStringAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MapUnstage", + Type: "AsString", Input: []tf.Input{ - key, indices, + input, }, Attrs: attrs, } op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Deprecated. Use TensorArrayScatterV3 +// +// DEPRECATED at GraphDef version 26: Use TensorArrayScatterV3 +func TensorArrayScatterV2(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorArrayScatterV2", + Input: []tf.Input{ + handle, indices, value, flow_in, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a tree ensemble model and returns a handle to it. +// +// Arguments: +// tree_ensemble_handle: Handle to the tree ensemble resource to be created. +// stamp_token: Token to use as the initial value of the resource stamp. +// tree_ensemble_serialized: Serialized proto of the tree ensemble. +// +// Returns the created operation. +func BoostedTreesCreateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("MapUnstage", err) + opspec := tf.OpSpec{ + Type: "BoostedTreesCreateEnsemble", + Input: []tf.Input{ + tree_ensemble_handle, stamp_token, tree_ensemble_serialized, + }, + } + return scope.AddOperation(opspec) +} + +// Applies sparse addition to `input` using individual values or slices +// +// from `updates` according to indices `indices`. The updates are non-aliasing: +// `input` is only modified in-place if no other operations will use it. +// Otherwise, a copy of `input` is made. This operation has a gradient with +// respect to both `input` and `updates`. +// +// `input` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. +// +// `indices` must be integer tensor, containing indices into `input`. +// It must be shape \\([d_0, ..., d_{Q-2}, K]\\) where `0 < K <= P`. +// +// The innermost dimension of `indices` (with length `K`) corresponds to +// indices into elements (if `K = P`) or `(P-K)`-dimensional slices +// (if `K < P`) along the `K`th dimension of `input`. +// +// `updates` is `Tensor` of rank `Q-1+P-K` with shape: +// +// $$[d_0, ..., d_{Q-2}, input.shape[K], ..., input.shape[P-1]].$$ +// +// For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 +// elements. In Python, that addition would look like this: +// +// input = tf.constant([1, 2, 3, 4, 5, 6, 7, 8]) +// indices = tf.constant([[4], [3], [1], [7]]) +// updates = tf.constant([9, 10, 11, 12]) +// output = tf.scatter_nd_non_aliasing_add(input, indices, updates) +// with tf.Session() as sess: +// print(sess.run(output)) +// +// The resulting value `output` would look like this: +// +// [1, 13, 3, 14, 14, 6, 7, 20] +// +// See @{tf.scatter_nd} for more details about how to make updates to slices. +// +// Arguments: +// input: A Tensor. +// indices: A Tensor. Must be one of the following types: `int32`, `int64`. +// A tensor of indices into `input`. +// updates: A Tensor. Must have the same type as ref. A tensor of updated values +// to add to `input`. +// +// Returns A `Tensor` with the same shape as `input`, containing values of `input` +// updated with `updates`. +func ScatterNdNonAliasingAdd(scope *Scope, input tf.Output, indices tf.Output, updates tf.Output) (output tf.Output) { + if scope.Err() != nil { return } - return values + opspec := tf.OpSpec{ + Type: "ScatterNdNonAliasingAdd", + Input: []tf.Input{ + input, indices, updates, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) } -// MapIncompleteSizeAttr is an optional argument to MapIncompleteSize. -type MapIncompleteSizeAttr func(optionalAttr) +// FractionalMaxPoolAttr is an optional argument to FractionalMaxPool. +type FractionalMaxPoolAttr func(optionalAttr) -// MapIncompleteSizeCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// FractionalMaxPoolPseudoRandom sets the optional pseudo_random attribute to value. // -// REQUIRES: value >= 0 -func MapIncompleteSizeCapacity(value int64) MapIncompleteSizeAttr { +// value: When set to True, generates the pooling sequence in a +// pseudorandom fashion, otherwise, in a random fashion. Check paper [Benjamin +// Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) for +// difference between pseudorandom and random. +// If not specified, defaults to false +func FractionalMaxPoolPseudoRandom(value bool) FractionalMaxPoolAttr { return func(m optionalAttr) { - m["capacity"] = value + m["pseudo_random"] = value } } -// MapIncompleteSizeMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 +// FractionalMaxPoolOverlapping sets the optional overlapping attribute to value. +// +// value: When set to True, it means when pooling, the values at the boundary +// of adjacent pooling cells are used by both cells. For example: +// +// `index 0 1 2 3 4` +// +// `value 20 5 16 3 7` +// +// If the pooling sequence is [0, 2, 4], then 16, at index 2 will be used twice. +// The result would be [20, 16] for fractional max pooling. +// If not specified, defaults to false +func FractionalMaxPoolOverlapping(value bool) FractionalMaxPoolAttr { + return func(m optionalAttr) { + m["overlapping"] = value + } +} + +// FractionalMaxPoolDeterministic sets the optional deterministic attribute to value. // -// REQUIRES: value >= 0 -func MapIncompleteSizeMemoryLimit(value int64) MapIncompleteSizeAttr { +// value: When set to True, a fixed pooling region will be used when +// iterating over a FractionalMaxPool node in the computation graph. Mainly used +// in unit test to make FractionalMaxPool deterministic. +// If not specified, defaults to false +func FractionalMaxPoolDeterministic(value bool) FractionalMaxPoolAttr { return func(m optionalAttr) { - m["memory_limit"] = value + m["deterministic"] = value } } -// MapIncompleteSizeContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func MapIncompleteSizeContainer(value string) MapIncompleteSizeAttr { +// FractionalMaxPoolSeed sets the optional seed attribute to value. +// +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func FractionalMaxPoolSeed(value int64) FractionalMaxPoolAttr { return func(m optionalAttr) { - m["container"] = value + m["seed"] = value } } -// MapIncompleteSizeSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func MapIncompleteSizeSharedName(value string) MapIncompleteSizeAttr { +// FractionalMaxPoolSeed2 sets the optional seed2 attribute to value. +// +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func FractionalMaxPoolSeed2(value int64) FractionalMaxPoolAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["seed2"] = value } } -// Op returns the number of incomplete elements in the underlying container. -func MapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...MapIncompleteSizeAttr) (size tf.Output) { +// Performs fractional max pooling on the input. +// +// Fractional max pooling is slightly different than regular max pooling. In +// regular max pooling, you downsize an input set by taking the maximum value of +// smaller N x N subsections of the set (often 2x2), and try to reduce the set by +// a factor of N, where N is an integer. Fractional max pooling, as you might +// expect from the word "fractional", means that the overall reduction ratio N +// does not have to be an integer. +// +// The sizes of the pooling regions are generated randomly but are fairly uniform. +// For example, let's look at the height dimension, and the constraints on the +// list of rows that will be pool boundaries. +// +// First we define the following: +// +// 1. input_row_length : the number of rows from the input set +// 2. output_row_length : which will be smaller than the input +// 3. alpha = input_row_length / output_row_length : our reduction ratio +// 4. K = floor(alpha) +// 5. row_pooling_sequence : this is the result list of pool boundary rows +// +// Then, row_pooling_sequence should satisfy: +// +// 1. a[0] = 0 : the first value of the sequence is 0 +// 2. a[end] = input_row_length : the last value of the sequence is the size +// 3. K <= (a[i+1] - a[i]) <= K+1 : all intervals are K or K+1 size +// 4. length(row_pooling_sequence) = output_row_length+1 +// +// For more details on fractional max pooling, see this paper: +// [Benjamin Graham, Fractional Max-Pooling](http://arxiv.org/abs/1412.6071) +// +// Arguments: +// value: 4-D with shape `[batch, height, width, channels]`. +// pooling_ratio: Pooling ratio for each dimension of `value`, currently only +// supports row and col dimension and should be >= 1.0. For example, a valid +// pooling ratio looks like [1.0, 1.44, 1.73, 1.0]. The first and last elements +// must be 1.0 because we don't allow pooling on batch and channels +// dimensions. 1.44 and 1.73 are pooling ratio on height and width dimensions +// respectively. +// +// Returns output tensor after fractional max pooling.row pooling sequence, needed to calculate gradient.column pooling sequence, needed to calculate gradient. +func FractionalMaxPool(scope *Scope, value tf.Output, pooling_ratio []float32, optional ...FractionalMaxPoolAttr) (output tf.Output, row_pooling_sequence tf.Output, col_pooling_sequence tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{"pooling_ratio": pooling_ratio} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "MapIncompleteSize", - + Type: "FractionalMaxPool", + Input: []tf.Input{ + value, + }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// OrderedMapUnstageAttr is an optional argument to OrderedMapUnstage. -type OrderedMapUnstageAttr func(optionalAttr) - -// OrderedMapUnstageCapacity sets the optional capacity attribute to value. -// If not specified, defaults to 0 +// Deprecated. Use TensorArraySizeV3 // -// REQUIRES: value >= 0 -func OrderedMapUnstageCapacity(value int64) OrderedMapUnstageAttr { - return func(m optionalAttr) { - m["capacity"] = value +// DEPRECATED at GraphDef version 26: Use TensorArraySizeV3 +func TensorArraySizeV2(scope *Scope, handle tf.Output, flow_in tf.Output) (size tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TensorArraySizeV2", + Input: []tf.Input{ + handle, flow_in, + }, } + op := scope.AddOperation(opspec) + return op.Output(0) } -// OrderedMapUnstageMemoryLimit sets the optional memory_limit attribute to value. -// If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func OrderedMapUnstageMemoryLimit(value int64) OrderedMapUnstageAttr { +// Conv2DAttr is an optional argument to Conv2D. +type Conv2DAttr func(optionalAttr) + +// Conv2DUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. +// If not specified, defaults to true +func Conv2DUseCudnnOnGpu(value bool) Conv2DAttr { return func(m optionalAttr) { - m["memory_limit"] = value + m["use_cudnn_on_gpu"] = value } } -// OrderedMapUnstageContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func OrderedMapUnstageContainer(value string) OrderedMapUnstageAttr { +// Conv2DDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, height, width, channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, channels, height, width]. +// If not specified, defaults to "NHWC" +func Conv2DDataFormat(value string) Conv2DAttr { return func(m optionalAttr) { - m["container"] = value + m["data_format"] = value } } -// OrderedMapUnstageSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func OrderedMapUnstageSharedName(value string) OrderedMapUnstageAttr { +// Conv2DDilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each +// filter element on that dimension. The dimension order is determined by the +// value of `data_format`, see above for details. Dilations in the batch and +// depth dimensions must be 1. +// If not specified, defaults to +func Conv2DDilations(value []int64) Conv2DAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["dilations"] = value } } -// Op removes and returns the values associated with the key +// Computes a 2-D convolution given 4-D `input` and `filter` tensors. // -// from the underlying container. If the underlying container -// does not contain this key, the op will block until it does. -func OrderedMapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapUnstageAttr) (values []tf.Output) { +// Given an input tensor of shape `[batch, in_height, in_width, in_channels]` +// and a filter / kernel tensor of shape +// `[filter_height, filter_width, in_channels, out_channels]`, this op +// performs the following: +// +// 1. Flattens the filter to a 2-D matrix with shape +// `[filter_height * filter_width * in_channels, output_channels]`. +// 2. Extracts image patches from the input tensor to form a *virtual* +// tensor of shape `[batch, out_height, out_width, +// filter_height * filter_width * in_channels]`. +// 3. For each patch, right-multiplies the filter matrix and the image patch +// vector. +// +// In detail, with the default NHWC format, +// +// output[b, i, j, k] = +// sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] * +// filter[di, dj, q, k] +// +// Must have `strides[0] = strides[3] = 1`. For the most common case of the same +// horizontal and vertices strides, `strides = [1, stride, stride, 1]`. +// +// Arguments: +// input: A 4-D tensor. The dimension order is interpreted according to the value +// of `data_format`, see below for details. +// filter: A 4-D tensor of shape +// `[filter_height, filter_width, in_channels, out_channels]` +// strides: 1-D tensor of length 4. The stride of the sliding window for each +// dimension of `input`. The dimension order is determined by the value of +// `data_format`, see below for details. +// padding: The type of padding algorithm to use. +// +// Returns A 4-D tensor. The dimension order is determined by the value of +// `data_format`, see below for details. +func Conv2D(scope *Scope, input tf.Output, filter tf.Output, strides []int64, padding string, optional ...Conv2DAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"dtypes": dtypes} + attrs := map[string]interface{}{"strides": strides, "padding": padding} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "OrderedMapUnstage", + Type: "Conv2D", Input: []tf.Input{ - key, indices, + input, filter, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if values, idx, err = makeOutputList(op, idx, "values"); err != nil { - scope.UpdateErr("OrderedMapUnstage", err) - return - } - return values + return op.Output(0) } -// OrderedMapSizeAttr is an optional argument to OrderedMapSize. -type OrderedMapSizeAttr func(optionalAttr) +// StageAttr is an optional argument to Stage. +type StageAttr func(optionalAttr) -// OrderedMapSizeCapacity sets the optional capacity attribute to value. +// StageCapacity sets the optional capacity attribute to value. +// +// value: Maximum number of elements in the Staging Area. If > 0, inserts +// on the container will block when the capacity is reached. // If not specified, defaults to 0 // // REQUIRES: value >= 0 -func OrderedMapSizeCapacity(value int64) OrderedMapSizeAttr { +func StageCapacity(value int64) StageAttr { return func(m optionalAttr) { m["capacity"] = value } } -// OrderedMapSizeMemoryLimit sets the optional memory_limit attribute to value. +// StageMemoryLimit sets the optional memory_limit attribute to value. +// +// value: The maximum number of bytes allowed for Tensors in the Staging Area. +// If > 0, inserts will block until sufficient space is available. // If not specified, defaults to 0 // // REQUIRES: value >= 0 -func OrderedMapSizeMemoryLimit(value int64) OrderedMapSizeAttr { +func StageMemoryLimit(value int64) StageAttr { return func(m optionalAttr) { m["memory_limit"] = value } } -// OrderedMapSizeContainer sets the optional container attribute to value. +// StageContainer sets the optional container attribute to value. +// +// value: If non-empty, this queue is placed in the given container. Otherwise, +// a default container is used. // If not specified, defaults to "" -func OrderedMapSizeContainer(value string) OrderedMapSizeAttr { +func StageContainer(value string) StageAttr { return func(m optionalAttr) { m["container"] = value } } -// OrderedMapSizeSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func OrderedMapSizeSharedName(value string) OrderedMapSizeAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Op returns the number of elements in the underlying container. -func OrderedMapSize(scope *Scope, dtypes []tf.DataType, optional ...OrderedMapSizeAttr) (size tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtypes": dtypes} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "OrderedMapSize", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ShapeNAttr is an optional argument to ShapeN. -type ShapeNAttr func(optionalAttr) - -// ShapeNOutType sets the optional out_type attribute to value. -// If not specified, defaults to DT_INT32 -func ShapeNOutType(value tf.DataType) ShapeNAttr { +// StageSharedName sets the optional shared_name attribute to value. +// +// value: It is necessary to match this name to the matching Unstage Op. +// If not specified, defaults to "" +func StageSharedName(value string) StageAttr { return func(m optionalAttr) { - m["out_type"] = value + m["shared_name"] = value } } -// Returns shape of tensors. +// Stage values similar to a lightweight Enqueue. // -// This operation returns N 1-D integer tensors representing shape of `input[i]s`. -func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []tf.Output) { +// The basic functionality of this Op is similar to a queue with many +// fewer capabilities and options. This Op is optimized for performance. +// +// Arguments: +// values: a list of tensors +// dtypes A list of data types that inserted values should adhere to. +// +// Returns the created operation. +func Stage(scope *Scope, values []tf.Output, optional ...StageAttr) (o *tf.Operation) { if scope.Err() != nil { return } @@ -29019,119 +30376,71 @@ func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []t a(attrs) } opspec := tf.OpSpec{ - Type: "ShapeN", + Type: "Stage", Input: []tf.Input{ - tf.OutputList(input), + tf.OutputList(values), }, Attrs: attrs, } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - scope.UpdateErr("ShapeN", err) - return - } - return output -} - -// CudnnRNNParamsToCanonicalAttr is an optional argument to CudnnRNNParamsToCanonical. -type CudnnRNNParamsToCanonicalAttr func(optionalAttr) - -// CudnnRNNParamsToCanonicalRnnMode sets the optional rnn_mode attribute to value. -// If not specified, defaults to "lstm" -func CudnnRNNParamsToCanonicalRnnMode(value string) CudnnRNNParamsToCanonicalAttr { - return func(m optionalAttr) { - m["rnn_mode"] = value - } + return scope.AddOperation(opspec) } -// CudnnRNNParamsToCanonicalInputMode sets the optional input_mode attribute to value. -// If not specified, defaults to "linear_input" -func CudnnRNNParamsToCanonicalInputMode(value string) CudnnRNNParamsToCanonicalAttr { - return func(m optionalAttr) { - m["input_mode"] = value - } -} +// StagePeekAttr is an optional argument to StagePeek. +type StagePeekAttr func(optionalAttr) -// CudnnRNNParamsToCanonicalDirection sets the optional direction attribute to value. -// If not specified, defaults to "unidirectional" -func CudnnRNNParamsToCanonicalDirection(value string) CudnnRNNParamsToCanonicalAttr { +// StagePeekCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 +// +// REQUIRES: value >= 0 +func StagePeekCapacity(value int64) StagePeekAttr { return func(m optionalAttr) { - m["direction"] = value + m["capacity"] = value } } -// CudnnRNNParamsToCanonicalDropout sets the optional dropout attribute to value. +// StagePeekMemoryLimit sets the optional memory_limit attribute to value. // If not specified, defaults to 0 -func CudnnRNNParamsToCanonicalDropout(value float32) CudnnRNNParamsToCanonicalAttr { +// +// REQUIRES: value >= 0 +func StagePeekMemoryLimit(value int64) StagePeekAttr { return func(m optionalAttr) { - m["dropout"] = value + m["memory_limit"] = value } } -// CudnnRNNParamsToCanonicalSeed sets the optional seed attribute to value. -// If not specified, defaults to 0 -func CudnnRNNParamsToCanonicalSeed(value int64) CudnnRNNParamsToCanonicalAttr { +// StagePeekContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func StagePeekContainer(value string) StagePeekAttr { return func(m optionalAttr) { - m["seed"] = value + m["container"] = value } } -// CudnnRNNParamsToCanonicalSeed2 sets the optional seed2 attribute to value. -// If not specified, defaults to 0 -func CudnnRNNParamsToCanonicalSeed2(value int64) CudnnRNNParamsToCanonicalAttr { +// StagePeekSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func StagePeekSharedName(value string) StagePeekAttr { return func(m optionalAttr) { - m["seed2"] = value + m["shared_name"] = value } } -// Retrieves CudnnRNN params in canonical form. -// -// Retrieves a set of weights from the opaque params buffer that can be saved and -// restored in a way compatible with future runs. -// -// Note that the params buffer may not be compatible across different GPUs. So any -// save and restoration should be converted to and from the canonical weights and -// biases. +// Op peeks at the values at the specified index. If the // -// num_layers: Specifies the number of layers in the RNN model. -// num_units: Specifies the size of the hidden state. -// input_size: Specifies the size of the input state. -// num_params: number of parameter sets for all layers. -// Each layer may contain multiple parameter sets, with each set consisting of -// a weight matrix and a bias vector. -// weights: the canonical form of weights that can be used for saving -// and restoration. They are more likely to be compatible across different -// generations. -// biases: the canonical form of biases that can be used for saving -// and restoration. They are more likely to be compatible across different -// generations. -// rnn_mode: Indicates the type of the RNN model. -// input_mode: Indicate whether there is a linear projection between the input and -// The actual computation before the first layer. 'skip_input' is only allowed -// when input_size == num_units; 'auto_select' implies 'skip_input' when -// input_size == num_units; otherwise, it implies 'linear_input'. -// direction: Indicates whether a bidirectional model will be used. -// dir = (direction == bidirectional) ? 2 : 1 -// dropout: dropout probability. When set to 0., dropout is disabled. -// seed: the 1st part of a seed to initialize dropout. -// seed2: the 2nd part of a seed to initialize dropout. -func CudnnRNNParamsToCanonical(scope *Scope, num_layers tf.Output, num_units tf.Output, input_size tf.Output, params tf.Output, num_params int64, optional ...CudnnRNNParamsToCanonicalAttr) (weights []tf.Output, biases []tf.Output) { +// underlying container does not contain sufficient elements +// this op will block until it does. This Op is optimized for +// performance. +func StagePeek(scope *Scope, index tf.Output, dtypes []tf.DataType, optional ...StagePeekAttr) (values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_params": num_params} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "CudnnRNNParamsToCanonical", + Type: "StagePeek", Input: []tf.Input{ - num_layers, num_units, input_size, params, + index, }, Attrs: attrs, } @@ -29141,544 +30450,472 @@ func CudnnRNNParamsToCanonical(scope *Scope, num_layers tf.Output, num_units tf. } var idx int var err error - if weights, idx, err = makeOutputList(op, idx, "weights"); err != nil { - scope.UpdateErr("CudnnRNNParamsToCanonical", err) - return - } - if biases, idx, err = makeOutputList(op, idx, "biases"); err != nil { - scope.UpdateErr("CudnnRNNParamsToCanonical", err) + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("StagePeek", err) return } - return weights, biases + return values } -// UniformCandidateSamplerAttr is an optional argument to UniformCandidateSampler. -type UniformCandidateSamplerAttr func(optionalAttr) +// MapStageAttr is an optional argument to MapStage. +type MapStageAttr func(optionalAttr) -// UniformCandidateSamplerSeed sets the optional seed attribute to value. +// MapStageCapacity sets the optional capacity attribute to value. // -// value: If either seed or seed2 are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. +// value: Maximum number of elements in the Staging Area. If > 0, inserts +// on the container will block when the capacity is reached. // If not specified, defaults to 0 -func UniformCandidateSamplerSeed(value int64) UniformCandidateSamplerAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// UniformCandidateSamplerSeed2 sets the optional seed2 attribute to value. // -// value: An second seed to avoid seed collision. -// If not specified, defaults to 0 -func UniformCandidateSamplerSeed2(value int64) UniformCandidateSamplerAttr { +// REQUIRES: value >= 0 +func MapStageCapacity(value int64) MapStageAttr { return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Generates labels for candidate sampling with a uniform distribution. -// -// See explanations of candidate sampling and the data formats at -// go/candidate-sampling. -// -// For each batch, this op picks a single set of sampled candidate labels. -// -// The advantages of sampling candidates per-batch are simplicity and the -// possibility of efficient dense matrix multiplication. The disadvantage is that -// the sampled candidates must be chosen independently of the context and of the -// true labels. -// -// Arguments: -// true_classes: A batch_size * num_true matrix, in which each row contains the -// IDs of the num_true target_classes in the corresponding original label. -// num_true: Number of true labels per context. -// num_sampled: Number of candidates to randomly sample. -// unique: If unique is true, we sample with rejection, so that all sampled -// candidates in a batch are unique. This requires some approximation to -// estimate the post-rejection sampling probabilities. -// range_max: The sampler will sample integers from the interval [0, range_max). -// -// Returns A vector of length num_sampled, in which each element is -// the ID of a sampled candidate.A batch_size * num_true matrix, representing -// the number of times each candidate is expected to occur in a batch -// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled -// candidate representing the number of times the candidate is expected -// to occur in a batch of sampled candidates. If unique=true, then this is a -// probability. -func UniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...UniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "UniformCandidateSampler", - Input: []tf.Input{ - true_classes, - }, - Attrs: attrs, + m["capacity"] = value } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) } -// CTCLossAttr is an optional argument to CTCLoss. -type CTCLossAttr func(optionalAttr) - -// CTCLossPreprocessCollapseRepeated sets the optional preprocess_collapse_repeated attribute to value. +// MapStageMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// value: Scalar, if true then repeated labels are -// collapsed prior to the CTC calculation. -// If not specified, defaults to false -func CTCLossPreprocessCollapseRepeated(value bool) CTCLossAttr { +// REQUIRES: value >= 0 +func MapStageMemoryLimit(value int64) MapStageAttr { return func(m optionalAttr) { - m["preprocess_collapse_repeated"] = value + m["memory_limit"] = value } } -// CTCLossCtcMergeRepeated sets the optional ctc_merge_repeated attribute to value. +// MapStageContainer sets the optional container attribute to value. // -// value: Scalar. If set to false, *during* CTC calculation -// repeated non-blank labels will not be merged and are interpreted as -// individual labels. This is a simplified version of CTC. -// If not specified, defaults to true -func CTCLossCtcMergeRepeated(value bool) CTCLossAttr { +// value: If non-empty, this queue is placed in the given container. Otherwise, +// a default container is used. +// If not specified, defaults to "" +func MapStageContainer(value string) MapStageAttr { return func(m optionalAttr) { - m["ctc_merge_repeated"] = value + m["container"] = value } } -// CTCLossIgnoreLongerOutputsThanInputs sets the optional ignore_longer_outputs_than_inputs attribute to value. -// -// value: Scalar. If set to true, during CTC -// calculation, items that have longer output sequences than input sequences -// are skipped: they don't contribute to the loss term and have zero-gradient. -// If not specified, defaults to false -func CTCLossIgnoreLongerOutputsThanInputs(value bool) CTCLossAttr { +// MapStageSharedName sets the optional shared_name attribute to value. +// +// value: It is necessary to match this name to the matching Unstage Op. +// If not specified, defaults to "" +func MapStageSharedName(value string) MapStageAttr { return func(m optionalAttr) { - m["ignore_longer_outputs_than_inputs"] = value + m["shared_name"] = value } } -// Calculates the CTC Loss (log probability) for each batch entry. Also calculates -// -// the gradient. This class performs the softmax operation for you, so inputs -// should be e.g. linear projections of outputs by an LSTM. +// Stage (key, values) in the underlying container which behaves like a hashtable. // // Arguments: -// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. -// labels_indices: The indices of a `SparseTensor`. -// `labels_indices(i, :) == [b, t]` means `labels_values(i)` stores the id for -// `(batch b, time t)`. -// labels_values: The values (labels) associated with the given batch and time. -// sequence_length: A vector containing sequence lengths (batch). +// key: int64 // -// Returns A vector (batch) containing log-probabilities.The gradient of `loss`. 3-D, shape: -// `(max_time x batch_size x num_classes)`. -func CTCLoss(scope *Scope, inputs tf.Output, labels_indices tf.Output, labels_values tf.Output, sequence_length tf.Output, optional ...CTCLossAttr) (loss tf.Output, gradient tf.Output) { +// values: a list of tensors +// dtypes A list of data types that inserted values should adhere to. +// +// +// Returns the created operation. +func MapStage(scope *Scope, key tf.Output, indices tf.Output, values []tf.Output, dtypes []tf.DataType, optional ...MapStageAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "CTCLoss", + Type: "MapStage", Input: []tf.Input{ - inputs, labels_indices, labels_values, sequence_length, + key, indices, tf.OutputList(values), }, Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return scope.AddOperation(opspec) } -// CTCGreedyDecoderAttr is an optional argument to CTCGreedyDecoder. -type CTCGreedyDecoderAttr func(optionalAttr) +// MapUnstageAttr is an optional argument to MapUnstage. +type MapUnstageAttr func(optionalAttr) -// CTCGreedyDecoderMergeRepeated sets the optional merge_repeated attribute to value. +// MapUnstageCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// value: If True, merge repeated classes in output. -// If not specified, defaults to false -func CTCGreedyDecoderMergeRepeated(value bool) CTCGreedyDecoderAttr { +// REQUIRES: value >= 0 +func MapUnstageCapacity(value int64) MapUnstageAttr { return func(m optionalAttr) { - m["merge_repeated"] = value + m["capacity"] = value } } -// Performs greedy decoding on the logits given in inputs. -// -// A note about the attribute merge_repeated: if enabled, when -// consecutive logits' maximum indices are the same, only the first of -// these is emitted. Labeling the blank '*', the sequence "A B B * B B" -// becomes "A B B" if merge_repeated = True and "A B B B B" if -// merge_repeated = False. -// -// Regardless of the value of merge_repeated, if the maximum index of a given -// time and batch corresponds to the blank, index `(num_classes - 1)`, no new -// element is emitted. -// -// Arguments: -// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. -// sequence_length: A vector containing sequence lengths, size `(batch_size)`. +// MapUnstageMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// Returns Indices matrix, size `(total_decoded_outputs x 2)`, -// of a `SparseTensor`. The rows store: [batch, time].Values vector, size: `(total_decoded_outputs)`, -// of a `SparseTensor`. The vector stores the decoded classes.Shape vector, size `(2)`, of the decoded SparseTensor. -// Values are: `[batch_size, max_decoded_length]`.Matrix, size `(batch_size x 1)`, containing sequence -// log-probabilities. -func CTCGreedyDecoder(scope *Scope, inputs tf.Output, sequence_length tf.Output, optional ...CTCGreedyDecoderAttr) (decoded_indices tf.Output, decoded_values tf.Output, decoded_shape tf.Output, log_probability tf.Output) { - if scope.Err() != nil { - return +// REQUIRES: value >= 0 +func MapUnstageMemoryLimit(value int64) MapUnstageAttr { + return func(m optionalAttr) { + m["memory_limit"] = value } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) +} + +// MapUnstageContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func MapUnstageContainer(value string) MapUnstageAttr { + return func(m optionalAttr) { + m["container"] = value } - opspec := tf.OpSpec{ - Type: "CTCGreedyDecoder", - Input: []tf.Input{ - inputs, sequence_length, - }, - Attrs: attrs, +} + +// MapUnstageSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapUnstageSharedName(value string) MapUnstageAttr { + return func(m optionalAttr) { + m["shared_name"] = value } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// Forwards `data` to the output port determined by `pred`. -// -// If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise, -// the data goes to `output_false`. -// -// See also `RefSwitch` and `Merge`. -// -// Arguments: -// data: The tensor to be forwarded to the appropriate output. -// pred: A scalar that specifies which output port will receive data. +// Op removes and returns the values associated with the key // -// Returns If `pred` is false, data will be forwarded to this output.If `pred` is true, data will be forwarded to this output. -func Switch(scope *Scope, data tf.Output, pred tf.Output) (output_false tf.Output, output_true tf.Output) { +// from the underlying container. If the underlying container +// does not contain this key, the op will block until it does. +func MapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageAttr) (values []tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Switch", + Type: "MapUnstage", Input: []tf.Input{ - data, pred, + key, indices, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Add all input tensors element wise. -// -// Arguments: -// inputs: Must all be the same size and shape. -func AddN(scope *Scope, inputs []tf.Output) (sum tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "AddN", - Input: []tf.Input{ - tf.OutputList(inputs), - }, + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("MapUnstage", err) + return } - op := scope.AddOperation(opspec) - return op.Output(0) + return values } -// TryRpcAttr is an optional argument to TryRpc. -type TryRpcAttr func(optionalAttr) +// MapIncompleteSizeAttr is an optional argument to MapIncompleteSize. +type MapIncompleteSizeAttr func(optionalAttr) -// TryRpcProtocol sets the optional protocol attribute to value. +// MapIncompleteSizeCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// value: RPC protocol to use. Empty string means use the default protocol. -// Options include 'grpc'. -// If not specified, defaults to "" -func TryRpcProtocol(value string) TryRpcAttr { +// REQUIRES: value >= 0 +func MapIncompleteSizeCapacity(value int64) MapIncompleteSizeAttr { return func(m optionalAttr) { - m["protocol"] = value + m["capacity"] = value } } -// TryRpcFailFast sets the optional fail_fast attribute to value. +// MapIncompleteSizeMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// value: `boolean`. If `true` (default), then failures to connect -// (i.e., the server does not immediately respond) cause an RPC failure. -// If not specified, defaults to true -func TryRpcFailFast(value bool) TryRpcAttr { +// REQUIRES: value >= 0 +func MapIncompleteSizeMemoryLimit(value int64) MapIncompleteSizeAttr { return func(m optionalAttr) { - m["fail_fast"] = value + m["memory_limit"] = value } } -// TryRpcTimeoutInMs sets the optional timeout_in_ms attribute to value. -// -// value: `int`. If `0` (default), then the kernel will run the RPC -// request and only time out if the RPC deadline passes or the session times out. -// If this value is greater than `0`, then the op will raise an exception if -// the RPC takes longer than `timeout_in_ms`. -// If not specified, defaults to 0 -func TryRpcTimeoutInMs(value int64) TryRpcAttr { +// MapIncompleteSizeContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func MapIncompleteSizeContainer(value string) MapIncompleteSizeAttr { return func(m optionalAttr) { - m["timeout_in_ms"] = value + m["container"] = value } } - -// Perform batches of RPC requests. -// -// This op asynchronously performs either a single RPC request, or a batch -// of requests. RPC requests are defined by three main parameters: -// -// - `address` (the host+port or BNS address of the request) -// - `method` (the method name for the request) -// - `request` (the serialized proto string, or vector of strings, -// of the RPC request argument). -// -// For example, if you have an RPC service running on port localhost:2345, -// and its interface is configured with the following proto declaration: -// -// ``` -// service MyService { -// rpc MyMethod(MyRequestProto) returns (MyResponseProto) { -// } -// }; -// ``` -// -// then call this op with arguments: -// -// ``` -// address = "localhost:2345" -// method = "MyService/MyMethod" -// ``` -// -// The `request` tensor is a string tensor representing serialized `MyRequestProto` -// strings; and the output string tensor `response` will have the same shape -// and contain (upon successful completion) corresponding serialized -// `MyResponseProto` strings. -// -// For example, to send a single, empty, `MyRequestProto`, call -// this op with `request = ""`. To send 5 **parallel** empty requests, -// call this op with `request = ["", "", "", "", ""]`. -// -// More generally, one can create a batch of `MyRequestProto` serialized protos -// from regular batched tensors using the `encode_proto` op, and convert -// the response `MyResponseProto` serialized protos to batched tensors -// using the `decode_proto` op. -// -// **NOTE** Working with serialized proto strings is faster than instantiating -// actual proto objects in memory, so no performance degradation is expected -// compared to writing custom kernels for this workflow. -// -// Unlike the standard `Rpc` op, if the connection fails or the remote worker -// returns an error status, this op does **not** reraise the exception. -// Instead, the `status_code` and `status_message` entry for the corresponding RPC -// call is set with the error returned from the RPC call. The `response` tensor -// will contain valid response values for those minibatch entries whose RPCs did -// not fail; the rest of the entries will have empty strings. -// -// Arguments: -// address: `0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server. -// If this tensor has more than 1 element, then multiple parallel rpc requests -// are sent. This argument broadcasts with `method` and `request`. -// method: `0-D` or `1-D`. The method address on the RPC server. -// If this tensor has more than 1 element, then multiple parallel rpc requests -// are sent. This argument broadcasts with `address` and `request`. -// request: `0-D` or `1-D`. Serialized proto strings: the rpc request argument. -// If this tensor has more than 1 element, then multiple parallel rpc requests -// are sent. This argument broadcasts with `address` and `method`. -// -// Returns Same shape as `request`. Serialized proto strings: the rpc responses.Same shape as `request`. Values correspond to tensorflow Status enum codes.Same shape as `request`. Values correspond to Status messages -// returned from the RPC calls. -func TryRpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, optional ...TryRpcAttr) (response tf.Output, status_code tf.Output, status_message tf.Output) { + +// MapIncompleteSizeSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func MapIncompleteSizeSharedName(value string) MapIncompleteSizeAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op returns the number of incomplete elements in the underlying container. +func MapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...MapIncompleteSizeAttr) (size tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TryRpc", - Input: []tf.Input{ - address, method, request, - }, + Type: "MapIncompleteSize", + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) + return op.Output(0) } -// EnterAttr is an optional argument to Enter. -type EnterAttr func(optionalAttr) +// OrderedMapUnstageAttr is an optional argument to OrderedMapUnstage. +type OrderedMapUnstageAttr func(optionalAttr) -// EnterIsConstant sets the optional is_constant attribute to value. +// OrderedMapUnstageCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// value: If true, the output is constant within the child frame. -// If not specified, defaults to false -func EnterIsConstant(value bool) EnterAttr { +// REQUIRES: value >= 0 +func OrderedMapUnstageCapacity(value int64) OrderedMapUnstageAttr { return func(m optionalAttr) { - m["is_constant"] = value + m["capacity"] = value } } -// EnterParallelIterations sets the optional parallel_iterations attribute to value. +// OrderedMapUnstageMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// value: The number of iterations allowed to run in parallel. -// If not specified, defaults to 10 -func EnterParallelIterations(value int64) EnterAttr { +// REQUIRES: value >= 0 +func OrderedMapUnstageMemoryLimit(value int64) OrderedMapUnstageAttr { return func(m optionalAttr) { - m["parallel_iterations"] = value + m["memory_limit"] = value } } -// Creates or finds a child frame, and makes `data` available to the child frame. -// -// This op is used together with `Exit` to create loops in the graph. -// The unique `frame_name` is used by the `Executor` to identify frames. If -// `is_constant` is true, `output` is a constant in the child frame; otherwise -// it may be changed in the child frame. At most `parallel_iterations` iterations -// are run in parallel in the child frame. -// -// Arguments: -// data: The tensor to be made available to the child frame. -// frame_name: The name of the child frame. +// OrderedMapUnstageContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func OrderedMapUnstageContainer(value string) OrderedMapUnstageAttr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// OrderedMapUnstageSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func OrderedMapUnstageSharedName(value string) OrderedMapUnstageAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op removes and returns the values associated with the key // -// Returns The same tensor as `data`. -func Enter(scope *Scope, data tf.Output, frame_name string, optional ...EnterAttr) (output tf.Output) { +// from the underlying container. If the underlying container +// does not contain this key, the op will block until it does. +func OrderedMapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapUnstageAttr) (values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"frame_name": frame_name} + attrs := map[string]interface{}{"dtypes": dtypes} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "Enter", + Type: "OrderedMapUnstage", Input: []tf.Input{ - data, + key, indices, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Produce a string tensor that encodes the state of a Reader. -// -// Not all Readers support being serialized, so this can produce an -// Unimplemented error. -// -// Arguments: -// reader_handle: Handle to a Reader. -func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Output) { if scope.Err() != nil { return } - opspec := tf.OpSpec{ - Type: "ReaderSerializeStateV2", - Input: []tf.Input{ - reader_handle, - }, + var idx int + var err error + if values, idx, err = makeOutputList(op, idx, "values"); err != nil { + scope.UpdateErr("OrderedMapUnstage", err) + return } - op := scope.AddOperation(opspec) - return op.Output(0) + return values } -// Exits the current frame to its parent frame. -// -// Exit makes its input `data` available to the parent frame. +// OrderedMapSizeAttr is an optional argument to OrderedMapSize. +type OrderedMapSizeAttr func(optionalAttr) + +// OrderedMapSizeCapacity sets the optional capacity attribute to value. +// If not specified, defaults to 0 // -// Arguments: -// data: The tensor to be made available to the parent frame. +// REQUIRES: value >= 0 +func OrderedMapSizeCapacity(value int64) OrderedMapSizeAttr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// OrderedMapSizeMemoryLimit sets the optional memory_limit attribute to value. +// If not specified, defaults to 0 // -// Returns The same tensor as `data`. -func Exit(scope *Scope, data tf.Output) (output tf.Output) { - if scope.Err() != nil { - return +// REQUIRES: value >= 0 +func OrderedMapSizeMemoryLimit(value int64) OrderedMapSizeAttr { + return func(m optionalAttr) { + m["memory_limit"] = value } - opspec := tf.OpSpec{ - Type: "Exit", - Input: []tf.Input{ - data, - }, +} + +// OrderedMapSizeContainer sets the optional container attribute to value. +// If not specified, defaults to "" +func OrderedMapSizeContainer(value string) OrderedMapSizeAttr { + return func(m optionalAttr) { + m["container"] = value } - op := scope.AddOperation(opspec) - return op.Output(0) } -// Returns a copy of the input tensor. -func Snapshot(scope *Scope, input tf.Output) (output tf.Output) { +// OrderedMapSizeSharedName sets the optional shared_name attribute to value. +// If not specified, defaults to "" +func OrderedMapSizeSharedName(value string) OrderedMapSizeAttr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// Op returns the number of elements in the underlying container. +func OrderedMapSize(scope *Scope, dtypes []tf.DataType, optional ...OrderedMapSizeAttr) (size tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"dtypes": dtypes} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Snapshot", - Input: []tf.Input{ - input, - }, + Type: "OrderedMapSize", + + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns a tensor of zeros with the same shape and type as x. +// CudnnRNNV2Attr is an optional argument to CudnnRNNV2. +type CudnnRNNV2Attr func(optionalAttr) + +// CudnnRNNV2RnnMode sets the optional rnn_mode attribute to value. +// If not specified, defaults to "lstm" +func CudnnRNNV2RnnMode(value string) CudnnRNNV2Attr { + return func(m optionalAttr) { + m["rnn_mode"] = value + } +} + +// CudnnRNNV2InputMode sets the optional input_mode attribute to value. +// If not specified, defaults to "linear_input" +func CudnnRNNV2InputMode(value string) CudnnRNNV2Attr { + return func(m optionalAttr) { + m["input_mode"] = value + } +} + +// CudnnRNNV2Direction sets the optional direction attribute to value. +// If not specified, defaults to "unidirectional" +func CudnnRNNV2Direction(value string) CudnnRNNV2Attr { + return func(m optionalAttr) { + m["direction"] = value + } +} + +// CudnnRNNV2Dropout sets the optional dropout attribute to value. +// If not specified, defaults to 0 +func CudnnRNNV2Dropout(value float32) CudnnRNNV2Attr { + return func(m optionalAttr) { + m["dropout"] = value + } +} + +// CudnnRNNV2Seed sets the optional seed attribute to value. +// If not specified, defaults to 0 +func CudnnRNNV2Seed(value int64) CudnnRNNV2Attr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// CudnnRNNV2Seed2 sets the optional seed2 attribute to value. +// If not specified, defaults to 0 +func CudnnRNNV2Seed2(value int64) CudnnRNNV2Attr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// CudnnRNNV2IsTraining sets the optional is_training attribute to value. +// If not specified, defaults to true +func CudnnRNNV2IsTraining(value bool) CudnnRNNV2Attr { + return func(m optionalAttr) { + m["is_training"] = value + } +} + +// A RNN backed by cuDNN. // -// Arguments: -// x: a tensor of type T. +// Computes the RNN from the input and initial states, with respect to the params +// buffer. Produces one extra output "host_reserved" than CudnnRNN. // -// Returns a tensor of the same shape and type as x but filled with zeros. -func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) { +// rnn_mode: Indicates the type of the RNN model. +// input_mode: Indicates whether there is a linear projection between the input and +// the actual computation before the first layer. 'skip_input' is only allowed +// when input_size == num_units; 'auto_select' implies 'skip_input' when +// input_size == num_units; otherwise, it implies 'linear_input'. +// direction: Indicates whether a bidirectional model will be used. Should be +// "unidirectional" or "bidirectional". +// dropout: Dropout probability. When set to 0., dropout is disabled. +// seed: The 1st part of a seed to initialize dropout. +// seed2: The 2nd part of a seed to initialize dropout. +// input: A 3-D tensor with the shape of [seq_length, batch_size, input_size]. +// input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size, +// num_units]. +// input_c: For LSTM, a 3-D tensor with the shape of +// [num_layer * dir, batch, num_units]. For other models, it is ignored. +// params: A 1-D tensor that contains the weights and biases in an opaque layout. +// The size must be created through CudnnRNNParamsSize, and initialized +// separately. Note that they might not be compatible across different +// generations. So it is a good idea to save and restore +// output: A 3-D tensor with the shape of [seq_length, batch_size, +// dir * num_units]. +// output_h: The same shape has input_h. +// output_c: The same shape as input_c for LSTM. An empty tensor for other models. +// is_training: Indicates whether this operation is used for inferenece or +// training. +// reserve_space: An opaque tensor that can be used in backprop calculation. It +// is only produced if is_training is true. +// host_reserved: An opaque tensor that can be used in backprop calculation. It is +// only produced if is_training is true. It is output on host memory rather than +// device memory. +func CudnnRNNV2(scope *Scope, input tf.Output, input_h tf.Output, input_c tf.Output, params tf.Output, optional ...CudnnRNNV2Attr) (output tf.Output, output_h tf.Output, output_c tf.Output, reserve_space tf.Output, host_reserved tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ZerosLike", + Type: "CudnnRNNV2", Input: []tf.Input{ - x, + input, input_h, input_c, params, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) -} - -// AbortAttr is an optional argument to Abort. -type AbortAttr func(optionalAttr) - -// AbortErrorMsg sets the optional error_msg attribute to value. -// -// value: A string which is the message associated with the exception. -// If not specified, defaults to "" -func AbortErrorMsg(value string) AbortAttr { - return func(m optionalAttr) { - m["error_msg"] = value - } + return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } -// AbortExitWithoutError sets the optional exit_without_error attribute to value. -// If not specified, defaults to false -func AbortExitWithoutError(value bool) AbortAttr { +// ShapeNAttr is an optional argument to ShapeN. +type ShapeNAttr func(optionalAttr) + +// ShapeNOutType sets the optional out_type attribute to value. +// If not specified, defaults to DT_INT32 +func ShapeNOutType(value tf.DataType) ShapeNAttr { return func(m optionalAttr) { - m["exit_without_error"] = value + m["out_type"] = value } } -// Raise a exception to abort the process when called. -// -// If exit_without_error is true, the process will exit normally, -// otherwise it will exit with a SIGABORT signal. -// -// Returns nothing but an exception. +// Returns shape of tensors. // -// Returns the created operation. -func Abort(scope *Scope, optional ...AbortAttr) (o *tf.Operation) { +// This operation returns N 1-D integer tensors representing shape of `input[i]s`. +func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []tf.Output) { if scope.Err() != nil { return } @@ -29687,128 +30924,168 @@ func Abort(scope *Scope, optional ...AbortAttr) (o *tf.Operation) { a(attrs) } opspec := tf.OpSpec{ - Type: "Abort", - + Type: "ShapeN", + Input: []tf.Input{ + tf.OutputList(input), + }, Attrs: attrs, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if output, idx, err = makeOutputList(op, idx, "output"); err != nil { + scope.UpdateErr("ShapeN", err) + return + } + return output } -// FixedUnigramCandidateSamplerAttr is an optional argument to FixedUnigramCandidateSampler. -type FixedUnigramCandidateSamplerAttr func(optionalAttr) +// CudnnRNNParamsToCanonicalAttr is an optional argument to CudnnRNNParamsToCanonical. +type CudnnRNNParamsToCanonicalAttr func(optionalAttr) -// FixedUnigramCandidateSamplerVocabFile sets the optional vocab_file attribute to value. -// -// value: Each valid line in this file (which should have a CSV-like format) -// corresponds to a valid word ID. IDs are in sequential order, starting from -// num_reserved_ids. The last entry in each line is expected to be a value -// corresponding to the count or relative probability. Exactly one of vocab_file -// and unigrams needs to be passed to this op. -// If not specified, defaults to "" -func FixedUnigramCandidateSamplerVocabFile(value string) FixedUnigramCandidateSamplerAttr { +// CudnnRNNParamsToCanonicalRnnMode sets the optional rnn_mode attribute to value. +// If not specified, defaults to "lstm" +func CudnnRNNParamsToCanonicalRnnMode(value string) CudnnRNNParamsToCanonicalAttr { return func(m optionalAttr) { - m["vocab_file"] = value + m["rnn_mode"] = value } } -// FixedUnigramCandidateSamplerDistortion sets the optional distortion attribute to value. -// -// value: The distortion is used to skew the unigram probability distribution. -// Each weight is first raised to the distortion's power before adding to the -// internal unigram distribution. As a result, distortion = 1.0 gives regular -// unigram sampling (as defined by the vocab file), and distortion = 0.0 gives -// a uniform distribution. -// If not specified, defaults to 1 -func FixedUnigramCandidateSamplerDistortion(value float32) FixedUnigramCandidateSamplerAttr { +// CudnnRNNParamsToCanonicalInputMode sets the optional input_mode attribute to value. +// If not specified, defaults to "linear_input" +func CudnnRNNParamsToCanonicalInputMode(value string) CudnnRNNParamsToCanonicalAttr { return func(m optionalAttr) { - m["distortion"] = value + m["input_mode"] = value } } -// FixedUnigramCandidateSamplerNumReservedIds sets the optional num_reserved_ids attribute to value. -// -// value: Optionally some reserved IDs can be added in the range [0, -// ..., num_reserved_ids) by the users. One use case is that a special unknown -// word token is used as ID 0. These IDs will have a sampling probability of 0. +// CudnnRNNParamsToCanonicalDirection sets the optional direction attribute to value. +// If not specified, defaults to "unidirectional" +func CudnnRNNParamsToCanonicalDirection(value string) CudnnRNNParamsToCanonicalAttr { + return func(m optionalAttr) { + m["direction"] = value + } +} + +// CudnnRNNParamsToCanonicalDropout sets the optional dropout attribute to value. // If not specified, defaults to 0 -func FixedUnigramCandidateSamplerNumReservedIds(value int64) FixedUnigramCandidateSamplerAttr { +func CudnnRNNParamsToCanonicalDropout(value float32) CudnnRNNParamsToCanonicalAttr { return func(m optionalAttr) { - m["num_reserved_ids"] = value + m["dropout"] = value } } -// FixedUnigramCandidateSamplerNumShards sets the optional num_shards attribute to value. -// -// value: A sampler can be used to sample from a subset of the original range -// in order to speed up the whole computation through parallelism. This parameter -// (together with 'shard') indicates the number of partitions that are being -// used in the overall computation. -// If not specified, defaults to 1 -// -// REQUIRES: value >= 1 -func FixedUnigramCandidateSamplerNumShards(value int64) FixedUnigramCandidateSamplerAttr { +// CudnnRNNParamsToCanonicalSeed sets the optional seed attribute to value. +// If not specified, defaults to 0 +func CudnnRNNParamsToCanonicalSeed(value int64) CudnnRNNParamsToCanonicalAttr { return func(m optionalAttr) { - m["num_shards"] = value + m["seed"] = value } } -// FixedUnigramCandidateSamplerShard sets the optional shard attribute to value. -// -// value: A sampler can be used to sample from a subset of the original range -// in order to speed up the whole computation through parallelism. This parameter -// (together with 'num_shards') indicates the particular partition number of a -// sampler op, when partitioning is being used. +// CudnnRNNParamsToCanonicalSeed2 sets the optional seed2 attribute to value. // If not specified, defaults to 0 -// -// REQUIRES: value >= 0 -func FixedUnigramCandidateSamplerShard(value int64) FixedUnigramCandidateSamplerAttr { +func CudnnRNNParamsToCanonicalSeed2(value int64) CudnnRNNParamsToCanonicalAttr { return func(m optionalAttr) { - m["shard"] = value + m["seed2"] = value } } -// FixedUnigramCandidateSamplerUnigrams sets the optional unigrams attribute to value. +// Retrieves CudnnRNN params in canonical form. // -// value: A list of unigram counts or probabilities, one per ID in sequential -// order. Exactly one of vocab_file and unigrams should be passed to this op. -// If not specified, defaults to <> -func FixedUnigramCandidateSamplerUnigrams(value []float32) FixedUnigramCandidateSamplerAttr { - return func(m optionalAttr) { - m["unigrams"] = value +// Retrieves a set of weights from the opaque params buffer that can be saved and +// restored in a way compatible with future runs. +// +// Note that the params buffer may not be compatible across different GPUs. So any +// save and restoration should be converted to and from the canonical weights and +// biases. +// +// num_layers: Specifies the number of layers in the RNN model. +// num_units: Specifies the size of the hidden state. +// input_size: Specifies the size of the input state. +// num_params: number of parameter sets for all layers. +// Each layer may contain multiple parameter sets, with each set consisting of +// a weight matrix and a bias vector. +// weights: the canonical form of weights that can be used for saving +// and restoration. They are more likely to be compatible across different +// generations. +// biases: the canonical form of biases that can be used for saving +// and restoration. They are more likely to be compatible across different +// generations. +// rnn_mode: Indicates the type of the RNN model. +// input_mode: Indicate whether there is a linear projection between the input and +// The actual computation before the first layer. 'skip_input' is only allowed +// when input_size == num_units; 'auto_select' implies 'skip_input' when +// input_size == num_units; otherwise, it implies 'linear_input'. +// direction: Indicates whether a bidirectional model will be used. +// dir = (direction == bidirectional) ? 2 : 1 +// dropout: dropout probability. When set to 0., dropout is disabled. +// seed: the 1st part of a seed to initialize dropout. +// seed2: the 2nd part of a seed to initialize dropout. +func CudnnRNNParamsToCanonical(scope *Scope, num_layers tf.Output, num_units tf.Output, input_size tf.Output, params tf.Output, num_params int64, optional ...CudnnRNNParamsToCanonicalAttr) (weights []tf.Output, biases []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_params": num_params} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "CudnnRNNParamsToCanonical", + Input: []tf.Input{ + num_layers, num_units, input_size, params, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if weights, idx, err = makeOutputList(op, idx, "weights"); err != nil { + scope.UpdateErr("CudnnRNNParamsToCanonical", err) + return + } + if biases, idx, err = makeOutputList(op, idx, "biases"); err != nil { + scope.UpdateErr("CudnnRNNParamsToCanonical", err) + return } + return weights, biases } -// FixedUnigramCandidateSamplerSeed sets the optional seed attribute to value. +// UniformCandidateSamplerAttr is an optional argument to UniformCandidateSampler. +type UniformCandidateSamplerAttr func(optionalAttr) + +// UniformCandidateSamplerSeed sets the optional seed attribute to value. // // value: If either seed or seed2 are set to be non-zero, the random number // generator is seeded by the given seed. Otherwise, it is seeded by a // random seed. // If not specified, defaults to 0 -func FixedUnigramCandidateSamplerSeed(value int64) FixedUnigramCandidateSamplerAttr { +func UniformCandidateSamplerSeed(value int64) UniformCandidateSamplerAttr { return func(m optionalAttr) { m["seed"] = value } } -// FixedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// UniformCandidateSamplerSeed2 sets the optional seed2 attribute to value. // // value: An second seed to avoid seed collision. // If not specified, defaults to 0 -func FixedUnigramCandidateSamplerSeed2(value int64) FixedUnigramCandidateSamplerAttr { +func UniformCandidateSamplerSeed2(value int64) UniformCandidateSamplerAttr { return func(m optionalAttr) { m["seed2"] = value } } -// Generates labels for candidate sampling with a learned unigram distribution. -// -// A unigram sampler could use a fixed unigram distribution read from a -// file or passed in as an in-memory array instead of building up the distribution -// from data on the fly. There is also an option to skew the distribution by -// applying a distortion power to the weights. +// Generates labels for candidate sampling with a uniform distribution. // -// The vocabulary file should be in CSV-like format, with the last field -// being the weight associated with the word. +// See explanations of candidate sampling and the data formats at +// go/candidate-sampling. // // For each batch, this op picks a single set of sampled candidate labels. // @@ -29834,7 +31111,7 @@ func FixedUnigramCandidateSamplerSeed2(value int64) FixedUnigramCandidateSampler // candidate representing the number of times the candidate is expected // to occur in a batch of sampled candidates. If unique=true, then this is a // probability. -func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...FixedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { +func UniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...UniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { if scope.Err() != nil { return } @@ -29843,7 +31120,7 @@ func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true a(attrs) } opspec := tf.OpSpec{ - Type: "FixedUnigramCandidateSampler", + Type: "UniformCandidateSampler", Input: []tf.Input{ true_classes, }, @@ -29853,38 +31130,60 @@ func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true return op.Output(0), op.Output(1), op.Output(2) } -// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2. -type WholeFileReaderV2Attr func(optionalAttr) +// CTCLossAttr is an optional argument to CTCLoss. +type CTCLossAttr func(optionalAttr) -// WholeFileReaderV2Container sets the optional container attribute to value. +// CTCLossPreprocessCollapseRepeated sets the optional preprocess_collapse_repeated attribute to value. // -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr { +// value: Scalar, if true then repeated labels are +// collapsed prior to the CTC calculation. +// If not specified, defaults to false +func CTCLossPreprocessCollapseRepeated(value bool) CTCLossAttr { return func(m optionalAttr) { - m["container"] = value + m["preprocess_collapse_repeated"] = value } } -// WholeFileReaderV2SharedName sets the optional shared_name attribute to value. +// CTCLossCtcMergeRepeated sets the optional ctc_merge_repeated attribute to value. // -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr { +// value: Scalar. If set to false, *during* CTC calculation +// repeated non-blank labels will not be merged and are interpreted as +// individual labels. This is a simplified version of CTC. +// If not specified, defaults to true +func CTCLossCtcMergeRepeated(value bool) CTCLossAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["ctc_merge_repeated"] = value } } -// A Reader that outputs the entire contents of a file as a value. +// CTCLossIgnoreLongerOutputsThanInputs sets the optional ignore_longer_outputs_than_inputs attribute to value. +// +// value: Scalar. If set to true, during CTC +// calculation, items that have longer output sequences than input sequences +// are skipped: they don't contribute to the loss term and have zero-gradient. +// If not specified, defaults to false +func CTCLossIgnoreLongerOutputsThanInputs(value bool) CTCLossAttr { + return func(m optionalAttr) { + m["ignore_longer_outputs_than_inputs"] = value + } +} + +// Calculates the CTC Loss (log probability) for each batch entry. Also calculates +// +// the gradient. This class performs the softmax operation for you, so inputs +// should be e.g. linear projections of outputs by an LSTM. // -// To use, enqueue filenames in a Queue. The output of ReaderRead will -// be a filename (key) and the contents of that file (value). +// Arguments: +// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. +// labels_indices: The indices of a `SparseTensor`. +// `labels_indices(i, :) == [b, t]` means `labels_values(i)` stores the id for +// `(batch b, time t)`. +// labels_values: The values (labels) associated with the given batch and time. +// sequence_length: A vector containing sequence lengths (batch). // -// Returns The handle to reference the Reader. -func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) { +// Returns A vector (batch) containing log-probabilities.The gradient of `loss`. 3-D, shape: +// `(max_time x batch_size x num_classes)`. +func CTCLoss(scope *Scope, inputs tf.Output, labels_indices tf.Output, labels_values tf.Output, sequence_length tf.Output, optional ...CTCLossAttr) (loss tf.Output, gradient tf.Output) { if scope.Err() != nil { return } @@ -29893,174 +31192,283 @@ func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_ a(attrs) } opspec := tf.OpSpec{ - Type: "WholeFileReaderV2", - + Type: "CTCLoss", + Input: []tf.Input{ + inputs, labels_indices, labels_values, sequence_length, + }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1) } -// Transforms a tf.Example proto (as a string) into typed tensors. +// CTCGreedyDecoderAttr is an optional argument to CTCGreedyDecoder. +type CTCGreedyDecoderAttr func(optionalAttr) + +// CTCGreedyDecoderMergeRepeated sets the optional merge_repeated attribute to value. +// +// value: If True, merge repeated classes in output. +// If not specified, defaults to false +func CTCGreedyDecoderMergeRepeated(value bool) CTCGreedyDecoderAttr { + return func(m optionalAttr) { + m["merge_repeated"] = value + } +} + +// Performs greedy decoding on the logits given in inputs. +// +// A note about the attribute merge_repeated: if enabled, when +// consecutive logits' maximum indices are the same, only the first of +// these is emitted. Labeling the blank '*', the sequence "A B B * B B" +// becomes "A B B" if merge_repeated = True and "A B B B B" if +// merge_repeated = False. +// +// Regardless of the value of merge_repeated, if the maximum index of a given +// time and batch corresponds to the blank, index `(num_classes - 1)`, no new +// element is emitted. // // Arguments: -// serialized: A vector containing a batch of binary serialized Example protos. -// dense_defaults: A list of Tensors (some may be empty), whose length matches -// the length of `dense_keys`. dense_defaults[j] provides default values -// when the example's feature_map lacks dense_key[j]. If an empty Tensor is -// provided for dense_defaults[j], then the Feature dense_keys[j] is required. -// The input type is inferred from dense_defaults[j], even when it's empty. -// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, -// then the shape of dense_defaults[j] must match that of dense_shapes[j]. -// If dense_shapes[j] has an undefined major dimension (variable strides dense -// feature), dense_defaults[j] must contain a single element: -// the padding element. -// num_sparse: The number of sparse features to be parsed from the example. This -// must match the lengths of `sparse_keys` and `sparse_types`. -// sparse_keys: A list of `num_sparse` strings. -// The keys expected in the Examples' features associated with sparse values. -// dense_keys: The keys expected in the Examples' features associated with dense -// values. -// sparse_types: A list of `num_sparse` types; the data types of data in each -// Feature given in sparse_keys. -// Currently the ParseSingleExample op supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// dense_shapes: The shapes of data in each Feature given in dense_keys. -// The length of this list must match the length of `dense_keys`. The -// number of elements in the Feature corresponding to dense_key[j] must -// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] == -// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j] -// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1, -// ..., DN), the shape of the output Tensor dense_values[j] will be (M, -// D1, .., DN), where M is the number of blocks of elements of length -// D1 * .... * DN, in the input. -func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.Output, num_sparse int64, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { +// inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. +// sequence_length: A vector containing sequence lengths, size `(batch_size)`. +// +// Returns Indices matrix, size `(total_decoded_outputs x 2)`, +// of a `SparseTensor`. The rows store: [batch, time].Values vector, size: `(total_decoded_outputs)`, +// of a `SparseTensor`. The vector stores the decoded classes.Shape vector, size `(2)`, of the decoded SparseTensor. +// Values are: `[batch_size, max_decoded_length]`.Matrix, size `(batch_size x 1)`, containing sequence +// log-probabilities. +func CTCGreedyDecoder(scope *Scope, inputs tf.Output, sequence_length tf.Output, optional ...CTCGreedyDecoderAttr) (decoded_indices tf.Output, decoded_values tf.Output, decoded_shape tf.Output, log_probability tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_sparse": num_sparse, "sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "ParseSingleExample", + Type: "CTCGreedyDecoder", Input: []tf.Input{ - serialized, tf.OutputList(dense_defaults), + inputs, sequence_length, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil { - scope.UpdateErr("ParseSingleExample", err) - return - } - if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil { - scope.UpdateErr("ParseSingleExample", err) - return - } - if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil { - scope.UpdateErr("ParseSingleExample", err) - return - } - if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil { - scope.UpdateErr("ParseSingleExample", err) - return - } - return sparse_indices, sparse_values, sparse_shapes, dense_values + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } -// Deserializes a serialized tree ensemble config and replaces current tree +// Forwards `data` to the output port determined by `pred`. // -// ensemble. +// If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise, +// the data goes to `output_false`. +// +// See also `RefSwitch` and `Merge`. // // Arguments: -// tree_ensemble_handle: Handle to the tree ensemble. -// stamp_token: Token to use as the new value of the resource stamp. -// tree_ensemble_serialized: Serialized proto of the ensemble. +// data: The tensor to be forwarded to the appropriate output. +// pred: A scalar that specifies which output port will receive data. // -// Returns the created operation. -func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) { +// Returns If `pred` is false, data will be forwarded to this output.If `pred` is true, data will be forwarded to this output. +func Switch(scope *Scope, data tf.Output, pred tf.Output) (output_false tf.Output, output_true tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "BoostedTreesDeserializeEnsemble", + Type: "Switch", Input: []tf.Input{ - tree_ensemble_handle, stamp_token, tree_ensemble_serialized, + data, pred, }, } - return scope.AddOperation(opspec) + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } -// Elementwise computes the bitwise AND of `x` and `y`. +// Add all input tensors element wise. // -// The result will have those bits set, that are set in both `x` and `y`. The -// computation is performed on the underlying representations of `x` and `y`. -func BitwiseAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Arguments: +// inputs: Must all be the same size and shape. +func AddN(scope *Scope, inputs []tf.Output) (sum tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "BitwiseAnd", + Type: "AddN", Input: []tf.Input{ - x, y, + tf.OutputList(inputs), }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Elementwise computes the bitwise left-shift of `x` and `y`. +// TryRpcAttr is an optional argument to TryRpc. +type TryRpcAttr func(optionalAttr) + +// TryRpcProtocol sets the optional protocol attribute to value. +// +// value: RPC protocol to use. Empty string means use the default protocol. +// Options include 'grpc'. +// If not specified, defaults to "" +func TryRpcProtocol(value string) TryRpcAttr { + return func(m optionalAttr) { + m["protocol"] = value + } +} + +// TryRpcFailFast sets the optional fail_fast attribute to value. +// +// value: `boolean`. If `true` (default), then failures to connect +// (i.e., the server does not immediately respond) cause an RPC failure. +// If not specified, defaults to true +func TryRpcFailFast(value bool) TryRpcAttr { + return func(m optionalAttr) { + m["fail_fast"] = value + } +} + +// TryRpcTimeoutInMs sets the optional timeout_in_ms attribute to value. +// +// value: `int`. If `0` (default), then the kernel will run the RPC +// request and only time out if the RPC deadline passes or the session times out. +// If this value is greater than `0`, then the op will raise an exception if +// the RPC takes longer than `timeout_in_ms`. +// If not specified, defaults to 0 +func TryRpcTimeoutInMs(value int64) TryRpcAttr { + return func(m optionalAttr) { + m["timeout_in_ms"] = value + } +} + +// Perform batches of RPC requests. +// +// This op asynchronously performs either a single RPC request, or a batch +// of requests. RPC requests are defined by three main parameters: +// +// - `address` (the host+port or BNS address of the request) +// - `method` (the method name for the request) +// - `request` (the serialized proto string, or vector of strings, +// of the RPC request argument). +// +// For example, if you have an RPC service running on port localhost:2345, +// and its interface is configured with the following proto declaration: +// +// ``` +// service MyService { +// rpc MyMethod(MyRequestProto) returns (MyResponseProto) { +// } +// }; +// ``` +// +// then call this op with arguments: +// +// ``` +// address = "localhost:2345" +// method = "MyService/MyMethod" +// ``` +// +// The `request` tensor is a string tensor representing serialized `MyRequestProto` +// strings; and the output string tensor `response` will have the same shape +// and contain (upon successful completion) corresponding serialized +// `MyResponseProto` strings. +// +// For example, to send a single, empty, `MyRequestProto`, call +// this op with `request = ""`. To send 5 **parallel** empty requests, +// call this op with `request = ["", "", "", "", ""]`. +// +// More generally, one can create a batch of `MyRequestProto` serialized protos +// from regular batched tensors using the `encode_proto` op, and convert +// the response `MyResponseProto` serialized protos to batched tensors +// using the `decode_proto` op. +// +// **NOTE** Working with serialized proto strings is faster than instantiating +// actual proto objects in memory, so no performance degradation is expected +// compared to writing custom kernels for this workflow. +// +// Unlike the standard `Rpc` op, if the connection fails or the remote worker +// returns an error status, this op does **not** reraise the exception. +// Instead, the `status_code` and `status_message` entry for the corresponding RPC +// call is set with the error returned from the RPC call. The `response` tensor +// will contain valid response values for those minibatch entries whose RPCs did +// not fail; the rest of the entries will have empty strings. +// +// Arguments: +// address: `0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server. +// If this tensor has more than 1 element, then multiple parallel rpc requests +// are sent. This argument broadcasts with `method` and `request`. +// method: `0-D` or `1-D`. The method address on the RPC server. +// If this tensor has more than 1 element, then multiple parallel rpc requests +// are sent. This argument broadcasts with `address` and `request`. +// request: `0-D` or `1-D`. Serialized proto strings: the rpc request argument. +// If this tensor has more than 1 element, then multiple parallel rpc requests +// are sent. This argument broadcasts with `address` and `method`. // -// If `y` is negative, or greater than or equal to the width of `x` in bits the -// result is implementation defined. -func LeftShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Returns Same shape as `request`. Serialized proto strings: the rpc responses.Same shape as `request`. Values correspond to tensorflow Status enum codes.Same shape as `request`. Values correspond to Status messages +// returned from the RPC calls. +func TryRpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, optional ...TryRpcAttr) (response tf.Output, status_code tf.Output, status_message tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "LeftShift", + Type: "TryRpc", Input: []tf.Input{ - x, y, + address, method, request, }, + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// TensorListStackAttr is an optional argument to TensorListStack. -type TensorListStackAttr func(optionalAttr) +// EnterAttr is an optional argument to Enter. +type EnterAttr func(optionalAttr) -// TensorListStackNumElements sets the optional num_elements attribute to value. -// If not specified, defaults to -1 -func TensorListStackNumElements(value int64) TensorListStackAttr { +// EnterIsConstant sets the optional is_constant attribute to value. +// +// value: If true, the output is constant within the child frame. +// If not specified, defaults to false +func EnterIsConstant(value bool) EnterAttr { return func(m optionalAttr) { - m["num_elements"] = value + m["is_constant"] = value } } -// Stacks all tensors in the list. +// EnterParallelIterations sets the optional parallel_iterations attribute to value. // -// Requires that all tensors have the same shape. +// value: The number of iterations allowed to run in parallel. +// If not specified, defaults to 10 +func EnterParallelIterations(value int64) EnterAttr { + return func(m optionalAttr) { + m["parallel_iterations"] = value + } +} + +// Creates or finds a child frame, and makes `data` available to the child frame. // -// input_handle: the input list -// tensor: the gathered result -// num_elements: optional. If not -1, the number of elements in the list. +// This op is used together with `Exit` to create loops in the graph. +// The unique `frame_name` is used by the `Executor` to identify frames. If +// `is_constant` is true, `output` is a constant in the child frame; otherwise +// it may be changed in the child frame. At most `parallel_iterations` iterations +// are run in parallel in the child frame. // -func TensorListStack(scope *Scope, input_handle tf.Output, element_dtype tf.DataType, optional ...TensorListStackAttr) (tensor tf.Output) { +// Arguments: +// data: The tensor to be made available to the child frame. +// frame_name: The name of the child frame. +// +// Returns The same tensor as `data`. +func Enter(scope *Scope, data tf.Output, frame_name string, optional ...EnterAttr) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"element_dtype": element_dtype} + attrs := map[string]interface{}{"frame_name": frame_name} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "TensorListStack", + Type: "Enter", Input: []tf.Input{ - input_handle, + data, }, Attrs: attrs, } @@ -30068,535 +31476,320 @@ func TensorListStack(scope *Scope, input_handle tf.Output, element_dtype tf.Data return op.Output(0) } -// Elementwise computes the bitwise right-shift of `x` and `y`. +// Produce a string tensor that encodes the state of a Reader. // -// Performs a logical shift for unsigned integer types, and an arithmetic shift -// for signed integer types. +// Not all Readers support being serialized, so this can produce an +// Unimplemented error. // -// If `y` is negative, or greater than or equal to than the width of `x` in bits -// the result is implementation defined. -func RightShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Arguments: +// reader_handle: Handle to a Reader. +func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "RightShift", + Type: "ReaderSerializeStateV2", Input: []tf.Input{ - x, y, + reader_handle, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Adjust the hue of one or more images. -// -// `images` is a tensor of at least 3 dimensions. The last dimension is -// interpretted as channels, and must be three. +// Exits the current frame to its parent frame. // -// The input image is considered in the RGB colorspace. Conceptually, the RGB -// colors are first mapped into HSV. A delta is then applied all the hue values, -// and then remapped back to RGB colorspace. +// Exit makes its input `data` available to the parent frame. // // Arguments: -// images: Images to adjust. At least 3-D. -// delta: A float delta to add to the hue. +// data: The tensor to be made available to the parent frame. // -// Returns The hue-adjusted image or images. -func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) { +// Returns The same tensor as `data`. +func Exit(scope *Scope, data tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "AdjustHue", + Type: "Exit", Input: []tf.Input{ - images, delta, + data, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// BatchAttr is an optional argument to Batch. -type BatchAttr func(optionalAttr) - -// BatchMaxEnqueuedBatches sets the optional max_enqueued_batches attribute to value. -// If not specified, defaults to 10 -func BatchMaxEnqueuedBatches(value int64) BatchAttr { - return func(m optionalAttr) { - m["max_enqueued_batches"] = value - } -} - -// BatchAllowedBatchSizes sets the optional allowed_batch_sizes attribute to value. -// If not specified, defaults to <> -func BatchAllowedBatchSizes(value []int64) BatchAttr { - return func(m optionalAttr) { - m["allowed_batch_sizes"] = value - } -} - -// BatchContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func BatchContainer(value string) BatchAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// BatchSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func BatchSharedName(value string) BatchAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// BatchBatchingQueue sets the optional batching_queue attribute to value. -// If not specified, defaults to "" -func BatchBatchingQueue(value string) BatchAttr { - return func(m optionalAttr) { - m["batching_queue"] = value - } -} - -// Batches all input tensors nondeterministically. -// -// When many instances of this Op are being run concurrently with the same -// container/shared_name in the same device, some will output zero-shaped Tensors -// and others will output Tensors of size up to max_batch_size. -// -// All Tensors in in_tensors are batched together (so, for example, labels and -// features should be batched with a single instance of this operation. -// -// Each invocation of batch emits an `id` scalar which will be used to identify -// this particular invocation when doing unbatch or its gradient. -// -// Each op which emits a non-empty batch will also emit a non-empty batch_index -// Tensor, which, is a [K, 3] matrix where each row contains the invocation's id, -// start, and length of elements of each set of Tensors present in batched_tensors. -// -// Batched tensors are concatenated along the first dimension, and all tensors in -// in_tensors must have the first dimension of the same size. -// -// in_tensors: The tensors to be batched. -// num_batch_threads: Number of scheduling threads for processing batches of work. -// Determines the number of batches processed in parallel. -// max_batch_size: Batch sizes will never be bigger than this. -// batch_timeout_micros: Maximum number of microseconds to wait before outputting -// an incomplete batch. -// allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, does -// nothing. Otherwise, supplies a list of batch sizes, causing the op to pad -// batches up to one of those sizes. The entries must increase monotonically, and -// the final entry must equal max_batch_size. -// grad_timeout_micros: The timeout to use for the gradient. See Unbatch. -// batched_tensors: Either empty tensors or a batch of concatenated Tensors. -// batch_index: If out_tensors is non-empty, has information to invert it. -// container: Controls the scope of sharing of this batch. -// id: always contains a scalar with a unique ID for this invocation of Batch. -// shared_name: Concurrently running instances of batch in the same device with the -// same container and shared_name will batch their elements together. If left -// empty, the op name will be used as the shared name. -// T: the types of tensors to be batched. -func Batch(scope *Scope, in_tensors []tf.Output, num_batch_threads int64, max_batch_size int64, batch_timeout_micros int64, grad_timeout_micros int64, optional ...BatchAttr) (batched_tensors []tf.Output, batch_index tf.Output, id tf.Output) { +// Returns a copy of the input tensor. +func Snapshot(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_batch_threads": num_batch_threads, "max_batch_size": max_batch_size, "batch_timeout_micros": batch_timeout_micros, "grad_timeout_micros": grad_timeout_micros} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Batch", + Type: "Snapshot", Input: []tf.Input{ - tf.OutputList(in_tensors), + input, }, - Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if batched_tensors, idx, err = makeOutputList(op, idx, "batched_tensors"); err != nil { - scope.UpdateErr("Batch", err) - return - } - batch_index = op.Output(idx) - id = op.Output(idx) - return batched_tensors, batch_index, id -} - -// UnbatchAttr is an optional argument to Unbatch. -type UnbatchAttr func(optionalAttr) - -// UnbatchContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func UnbatchContainer(value string) UnbatchAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// UnbatchSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func UnbatchSharedName(value string) UnbatchAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } + return op.Output(0) } -// Reverses the operation of Batch for a single output Tensor. +// Returns a tensor of zeros with the same shape and type as x. // -// An instance of Unbatch either receives an empty batched_tensor, in which case it -// asynchronously waits until the values become available from a concurrently -// running instance of Unbatch with the same container and shared_name, or receives -// a non-empty batched_tensor in which case it finalizes all other concurrently -// running instances and outputs its own element from the batch. +// Arguments: +// x: a tensor of type T. // -// batched_tensor: The possibly transformed output of Batch. The size of the first -// dimension should remain unchanged by the transformations for the operation to -// work. -// batch_index: The matching batch_index obtained from Batch. -// id: The id scalar emitted by Batch. -// unbatched_tensor: The Tensor corresponding to this execution. -// timeout_micros: Maximum amount of time (in microseconds) to wait to receive the -// batched input tensor associated with a given invocation of the op. -// container: Container to control resource sharing. -// shared_name: Instances of Unbatch with the same container and shared_name are -// assumed to possibly belong to the same batch. If left empty, the op name will -// be used as the shared name. -func Unbatch(scope *Scope, batched_tensor tf.Output, batch_index tf.Output, id tf.Output, timeout_micros int64, optional ...UnbatchAttr) (unbatched_tensor tf.Output) { +// Returns a tensor of the same shape and type as x but filled with zeros. +func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"timeout_micros": timeout_micros} - for _, a := range optional { - a(attrs) - } opspec := tf.OpSpec{ - Type: "Unbatch", + Type: "ZerosLike", Input: []tf.Input{ - batched_tensor, batch_index, id, + x, }, - Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// AvgPool3DGradAttr is an optional argument to AvgPool3DGrad. -type AvgPool3DGradAttr func(optionalAttr) +// AbortAttr is an optional argument to Abort. +type AbortAttr func(optionalAttr) -// AvgPool3DGradDataFormat sets the optional data_format attribute to value. +// AbortErrorMsg sets the optional error_msg attribute to value. // -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func AvgPool3DGradDataFormat(value string) AvgPool3DGradAttr { +// value: A string which is the message associated with the exception. +// If not specified, defaults to "" +func AbortErrorMsg(value string) AbortAttr { return func(m optionalAttr) { - m["data_format"] = value + m["error_msg"] = value } } -// Computes gradients of average pooling function. +// AbortExitWithoutError sets the optional exit_without_error attribute to value. +// If not specified, defaults to false +func AbortExitWithoutError(value bool) AbortAttr { + return func(m optionalAttr) { + m["exit_without_error"] = value + } +} + +// Raise a exception to abort the process when called. // -// Arguments: -// orig_input_shape: The original input dimensions. -// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. +// If exit_without_error is true, the process will exit normally, +// otherwise it will exit with a SIGABORT signal. // -// Returns The backprop for input. -func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DGradAttr) (output tf.Output) { +// Returns nothing but an exception. +// +// Returns the created operation. +func Abort(scope *Scope, optional ...AbortAttr) (o *tf.Operation) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + attrs := map[string]interface{}{} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "AvgPool3DGrad", - Input: []tf.Input{ - orig_input_shape, grad, - }, + Type: "Abort", + Attrs: attrs, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// ParseSingleSequenceExampleAttr is an optional argument to ParseSingleSequenceExample. -type ParseSingleSequenceExampleAttr func(optionalAttr) +// FixedUnigramCandidateSamplerAttr is an optional argument to FixedUnigramCandidateSampler. +type FixedUnigramCandidateSamplerAttr func(optionalAttr) -// ParseSingleSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value. -// -// value: A list of Ncontext_sparse types; the data types of data in -// each context Feature given in context_sparse_keys. -// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// If not specified, defaults to <> +// FixedUnigramCandidateSamplerVocabFile sets the optional vocab_file attribute to value. // -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleContextSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { +// value: Each valid line in this file (which should have a CSV-like format) +// corresponds to a valid word ID. IDs are in sequential order, starting from +// num_reserved_ids. The last entry in each line is expected to be a value +// corresponding to the count or relative probability. Exactly one of vocab_file +// and unigrams needs to be passed to this op. +// If not specified, defaults to "" +func FixedUnigramCandidateSamplerVocabFile(value string) FixedUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["context_sparse_types"] = value + m["vocab_file"] = value } } -// ParseSingleSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value. -// If not specified, defaults to <> +// FixedUnigramCandidateSamplerDistortion sets the optional distortion attribute to value. // -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { +// value: The distortion is used to skew the unigram probability distribution. +// Each weight is first raised to the distortion's power before adding to the +// internal unigram distribution. As a result, distortion = 1.0 gives regular +// unigram sampling (as defined by the vocab file), and distortion = 0.0 gives +// a uniform distribution. +// If not specified, defaults to 1 +func FixedUnigramCandidateSamplerDistortion(value float32) FixedUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["feature_list_dense_types"] = value + m["distortion"] = value } } -// ParseSingleSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value. -// -// value: A list of Ncontext_dense shapes; the shapes of data in -// each context Feature given in context_dense_keys. -// The number of elements in the Feature corresponding to context_dense_key[j] -// must always equal context_dense_shapes[j].NumEntries(). -// The shape of context_dense_values[j] will match context_dense_shapes[j]. -// If not specified, defaults to <> +// FixedUnigramCandidateSamplerNumReservedIds sets the optional num_reserved_ids attribute to value. // -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleContextDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { +// value: Optionally some reserved IDs can be added in the range [0, +// ..., num_reserved_ids) by the users. One use case is that a special unknown +// word token is used as ID 0. These IDs will have a sampling probability of 0. +// If not specified, defaults to 0 +func FixedUnigramCandidateSamplerNumReservedIds(value int64) FixedUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["context_dense_shapes"] = value + m["num_reserved_ids"] = value } } -// ParseSingleSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value. +// FixedUnigramCandidateSamplerNumShards sets the optional num_shards attribute to value. // -// value: A list of Nfeature_list_sparse types; the data types -// of data in each FeatureList given in feature_list_sparse_keys. -// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList), -// DT_INT64 (Int64List), and DT_STRING (BytesList). -// If not specified, defaults to <> +// value: A sampler can be used to sample from a subset of the original range +// in order to speed up the whole computation through parallelism. This parameter +// (together with 'shard') indicates the number of partitions that are being +// used in the overall computation. +// If not specified, defaults to 1 // -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSingleSequenceExampleAttr { +// REQUIRES: value >= 1 +func FixedUnigramCandidateSamplerNumShards(value int64) FixedUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["feature_list_sparse_types"] = value + m["num_shards"] = value } } -// ParseSingleSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value. +// FixedUnigramCandidateSamplerShard sets the optional shard attribute to value. // -// value: A list of Nfeature_list_dense shapes; the shapes of -// data in each FeatureList given in feature_list_dense_keys. -// The shape of each Feature in the FeatureList corresponding to -// feature_list_dense_key[j] must always equal -// feature_list_dense_shapes[j].NumEntries(). -// If not specified, defaults to <> +// value: A sampler can be used to sample from a subset of the original range +// in order to speed up the whole computation through parallelism. This parameter +// (together with 'num_shards') indicates the particular partition number of a +// sampler op, when partitioning is being used. +// If not specified, defaults to 0 // -// REQUIRES: len(value) >= 0 -func ParseSingleSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSingleSequenceExampleAttr { +// REQUIRES: value >= 0 +func FixedUnigramCandidateSamplerShard(value int64) FixedUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["feature_list_dense_shapes"] = value + m["shard"] = value } } -// Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors. +// FixedUnigramCandidateSamplerUnigrams sets the optional unigrams attribute to value. // -// Arguments: -// serialized: A scalar containing a binary serialized SequenceExample proto. -// feature_list_dense_missing_assumed_empty: A vector listing the -// FeatureList keys which may be missing from the SequenceExample. If the -// associated FeatureList is missing, it is treated as empty. By default, -// any FeatureList not listed in this vector must exist in the SequenceExample. -// context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars). -// The keys expected in the Examples' features associated with context_sparse -// values. -// context_dense_keys: A list of Ncontext_dense string Tensors (scalars). -// The keys expected in the SequenceExamples' context features associated with -// dense values. -// feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors -// (scalars). The keys expected in the FeatureLists associated with sparse -// values. -// feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars). -// The keys expected in the SequenceExamples' feature_lists associated -// with lists of dense values. -// context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty). -// context_dense_defaults[j] provides default values -// when the SequenceExample's context map lacks context_dense_key[j]. -// If an empty Tensor is provided for context_dense_defaults[j], -// then the Feature context_dense_keys[j] is required. -// The input type is inferred from context_dense_defaults[j], even when it's -// empty. If context_dense_defaults[j] is not empty, its shape must match -// context_dense_shapes[j]. -// debug_name: A scalar containing the name of the serialized proto. -// May contain, for example, table key (descriptive) name for the -// corresponding serialized proto. This is purely useful for debugging -// purposes, and the presence of values here has no effect on the output. -// May also be an empty scalar if no name is available. -func ParseSingleSequenceExample(scope *Scope, serialized tf.Output, feature_list_dense_missing_assumed_empty tf.Output, context_sparse_keys []tf.Output, context_dense_keys []tf.Output, feature_list_sparse_keys []tf.Output, feature_list_dense_keys []tf.Output, context_dense_defaults []tf.Output, debug_name tf.Output, optional ...ParseSingleSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ParseSingleSequenceExample", - Input: []tf.Input{ - serialized, feature_list_dense_missing_assumed_empty, tf.OutputList(context_sparse_keys), tf.OutputList(context_dense_keys), tf.OutputList(feature_list_sparse_keys), tf.OutputList(feature_list_dense_keys), tf.OutputList(context_dense_defaults), debug_name, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return - } - if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil { - scope.UpdateErr("ParseSingleSequenceExample", err) - return +// value: A list of unigram counts or probabilities, one per ID in sequential +// order. Exactly one of vocab_file and unigrams should be passed to this op. +// If not specified, defaults to <> +func FixedUnigramCandidateSamplerUnigrams(value []float32) FixedUnigramCandidateSamplerAttr { + return func(m optionalAttr) { + m["unigrams"] = value } - return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values } -// UnbatchGradAttr is an optional argument to UnbatchGrad. -type UnbatchGradAttr func(optionalAttr) - -// UnbatchGradContainer sets the optional container attribute to value. -// If not specified, defaults to "" -func UnbatchGradContainer(value string) UnbatchGradAttr { +// FixedUnigramCandidateSamplerSeed sets the optional seed attribute to value. +// +// value: If either seed or seed2 are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func FixedUnigramCandidateSamplerSeed(value int64) FixedUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["container"] = value + m["seed"] = value } } -// UnbatchGradSharedName sets the optional shared_name attribute to value. -// If not specified, defaults to "" -func UnbatchGradSharedName(value string) UnbatchGradAttr { +// FixedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value. +// +// value: An second seed to avoid seed collision. +// If not specified, defaults to 0 +func FixedUnigramCandidateSamplerSeed2(value int64) FixedUnigramCandidateSamplerAttr { return func(m optionalAttr) { - m["shared_name"] = value + m["seed2"] = value } } -// Gradient of Unbatch. +// Generates labels for candidate sampling with a learned unigram distribution. // -// Acts like Batch but using the given batch_index index of batching things as they -// become available. This ensures that the gradients are propagated back in the -// same session which did the forward pass. +// A unigram sampler could use a fixed unigram distribution read from a +// file or passed in as an in-memory array instead of building up the distribution +// from data on the fly. There is also an option to skew the distribution by +// applying a distortion power to the weights. // -// original_input: The input to the Unbatch operation this is the gradient of. -// batch_index: The batch_index given to the Unbatch operation this is the gradient -// of. -// grad: The downstream gradient. -// id: The id scalar emitted by Batch. -// batched_grad: The return value, either an empty tensor or the batched gradient. -// container: Container to control resource sharing. -// shared_name: Instances of UnbatchGrad with the same container and shared_name -// are assumed to possibly belong to the same batch. If left empty, the op name -// will be used as the shared name. -func UnbatchGrad(scope *Scope, original_input tf.Output, batch_index tf.Output, grad tf.Output, id tf.Output, optional ...UnbatchGradAttr) (batched_grad tf.Output) { +// The vocabulary file should be in CSV-like format, with the last field +// being the weight associated with the word. +// +// For each batch, this op picks a single set of sampled candidate labels. +// +// The advantages of sampling candidates per-batch are simplicity and the +// possibility of efficient dense matrix multiplication. The disadvantage is that +// the sampled candidates must be chosen independently of the context and of the +// true labels. +// +// Arguments: +// true_classes: A batch_size * num_true matrix, in which each row contains the +// IDs of the num_true target_classes in the corresponding original label. +// num_true: Number of true labels per context. +// num_sampled: Number of candidates to randomly sample. +// unique: If unique is true, we sample with rejection, so that all sampled +// candidates in a batch are unique. This requires some approximation to +// estimate the post-rejection sampling probabilities. +// range_max: The sampler will sample integers from the interval [0, range_max). +// +// Returns A vector of length num_sampled, in which each element is +// the ID of a sampled candidate.A batch_size * num_true matrix, representing +// the number of times each candidate is expected to occur in a batch +// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled +// candidate representing the number of times the candidate is expected +// to occur in a batch of sampled candidates. If unique=true, then this is a +// probability. +func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...FixedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} + attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max} for _, a := range optional { a(attrs) } opspec := tf.OpSpec{ - Type: "UnbatchGrad", + Type: "FixedUnigramCandidateSampler", Input: []tf.Input{ - original_input, batch_index, grad, id, + true_classes, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + return op.Output(0), op.Output(1), op.Output(2) } -// DecodeWavAttr is an optional argument to DecodeWav. -type DecodeWavAttr func(optionalAttr) +// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2. +type WholeFileReaderV2Attr func(optionalAttr) -// DecodeWavDesiredChannels sets the optional desired_channels attribute to value. +// WholeFileReaderV2Container sets the optional container attribute to value. // -// value: Number of sample channels wanted. -// If not specified, defaults to -1 -func DecodeWavDesiredChannels(value int64) DecodeWavAttr { +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr { return func(m optionalAttr) { - m["desired_channels"] = value + m["container"] = value } } -// DecodeWavDesiredSamples sets the optional desired_samples attribute to value. +// WholeFileReaderV2SharedName sets the optional shared_name attribute to value. // -// value: Length of audio requested. -// If not specified, defaults to -1 -func DecodeWavDesiredSamples(value int64) DecodeWavAttr { +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. +// If not specified, defaults to "" +func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr { return func(m optionalAttr) { - m["desired_samples"] = value + m["shared_name"] = value } } -// Decode a 16-bit PCM WAV file to a float tensor. -// -// The -32768 to 32767 signed 16-bit values will be scaled to -1.0 to 1.0 in float. -// -// When desired_channels is set, if the input contains fewer channels than this -// then the last channel will be duplicated to give the requested number, else if -// the input has more channels than requested then the additional channels will be -// ignored. -// -// If desired_samples is set, then the audio will be cropped or padded with zeroes -// to the requested length. -// -// The first output contains a Tensor with the content of the audio samples. The -// lowest dimension will be the number of channels, and the second will be the -// number of samples. For example, a ten-sample-long stereo WAV file should give an -// output shape of [10, 2]. +// A Reader that outputs the entire contents of a file as a value. // -// Arguments: -// contents: The WAV-encoded audio, usually from a file. +// To use, enqueue filenames in a Queue. The output of ReaderRead will +// be a filename (key) and the contents of that file (value). // -// Returns 2-D with shape `[length, channels]`.Scalar holding the sample rate found in the WAV header. -func DecodeWav(scope *Scope, contents tf.Output, optional ...DecodeWavAttr) (audio tf.Output, sample_rate tf.Output) { +// Returns The handle to reference the Reader. +func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) { if scope.Err() != nil { return } @@ -30605,106 +31798,119 @@ func DecodeWav(scope *Scope, contents tf.Output, optional ...DecodeWavAttr) (aud a(attrs) } opspec := tf.OpSpec{ - Type: "DecodeWav", - Input: []tf.Input{ - contents, - }, + Type: "WholeFileReaderV2", + Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) + return op.Output(0) } -// Concatenates a list of `N` tensors along the first dimension. -// -// The input tensors are all required to have size 1 in the first dimension. -// -// For example: -// -// ``` -// # 'x' is [[1, 4]] -// # 'y' is [[2, 5]] -// # 'z' is [[3, 6]] -// parallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. -// ``` -// -// The difference between concat and parallel_concat is that concat requires all -// of the inputs be computed before the operation will begin but doesn't require -// that the input shapes be known during graph construction. Parallel concat -// will copy pieces of the input into the output as they become available, in -// some situations this can provide a performance benefit. +// Transforms a tf.Example proto (as a string) into typed tensors. // // Arguments: -// values: Tensors to be concatenated. All must have size 1 in the first dimension -// and same shape. -// shape: the final shape of the result; should be equal to the shapes of any input -// but with the number of input values in the first dimension. -// -// Returns The concatenated tensor. -func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf.Output) { +// serialized: A vector containing a batch of binary serialized Example protos. +// dense_defaults: A list of Tensors (some may be empty), whose length matches +// the length of `dense_keys`. dense_defaults[j] provides default values +// when the example's feature_map lacks dense_key[j]. If an empty Tensor is +// provided for dense_defaults[j], then the Feature dense_keys[j] is required. +// The input type is inferred from dense_defaults[j], even when it's empty. +// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, +// then the shape of dense_defaults[j] must match that of dense_shapes[j]. +// If dense_shapes[j] has an undefined major dimension (variable strides dense +// feature), dense_defaults[j] must contain a single element: +// the padding element. +// num_sparse: The number of sparse features to be parsed from the example. This +// must match the lengths of `sparse_keys` and `sparse_types`. +// sparse_keys: A list of `num_sparse` strings. +// The keys expected in the Examples' features associated with sparse values. +// dense_keys: The keys expected in the Examples' features associated with dense +// values. +// sparse_types: A list of `num_sparse` types; the data types of data in each +// Feature given in sparse_keys. +// Currently the ParseSingleExample op supports DT_FLOAT (FloatList), +// DT_INT64 (Int64List), and DT_STRING (BytesList). +// dense_shapes: The shapes of data in each Feature given in dense_keys. +// The length of this list must match the length of `dense_keys`. The +// number of elements in the Feature corresponding to dense_key[j] must +// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] == +// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j] +// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1, +// ..., DN), the shape of the output Tensor dense_values[j] will be (M, +// D1, .., DN), where M is the number of blocks of elements of length +// D1 * .... * DN, in the input. +func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.Output, num_sparse int64, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"shape": shape} + attrs := map[string]interface{}{"num_sparse": num_sparse, "sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes} opspec := tf.OpSpec{ - Type: "ParallelConcat", + Type: "ParseSingleExample", Input: []tf.Input{ - tf.OutputList(values), + serialized, tf.OutputList(dense_defaults), }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return + } + if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return + } + if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return + } + if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil { + scope.UpdateErr("ParseSingleExample", err) + return + } + return sparse_indices, sparse_values, sparse_shapes, dense_values } -// Subtracts `v` into specified rows of `x`. +// Deserializes a serialized tree ensemble config and replaces current tree // -// Computes y = x; y[i, :] -= v; return y. +// ensemble. // // Arguments: -// x: A `Tensor` of type T. -// i: A vector. Indices into the left-most dimension of `x`. -// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size. +// tree_ensemble_handle: Handle to the tree ensemble. +// stamp_token: Token to use as the new value of the resource stamp. +// tree_ensemble_serialized: Serialized proto of the ensemble. // -// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`. -func InplaceSub(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) { +// Returns the created operation. +func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "InplaceSub", + Type: "BoostedTreesDeserializeEnsemble", Input: []tf.Input{ - x, i, v, + tree_ensemble_handle, stamp_token, tree_ensemble_serialized, }, } - op := scope.AddOperation(opspec) - return op.Output(0) + return scope.AddOperation(opspec) } -// Converts a flat index or array of flat indices into a tuple of -// -// coordinate arrays. -// -// @compatibility(numpy) -// Equivalent to np.unravel_index -// @end_compatibility -// -// Arguments: -// indices: An 0-D or 1-D `int` Tensor whose elements are indices into the -// flattened version of an array of dimensions dims. -// dims: An 1-D `int` Tensor. The shape of the array to use for unraveling -// indices. +// Elementwise computes the bitwise AND of `x` and `y`. // -// Returns An 2-D (or 1-D if indices is 0-D) tensor where each row has the -// same shape as the indices array. -func UnravelIndex(scope *Scope, indices tf.Output, dims tf.Output) (output tf.Output) { +// The result will have those bits set, that are set in both `x` and `y`. The +// computation is performed on the underlying representations of `x` and `y`. +func BitwiseAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "UnravelIndex", + Type: "BitwiseAnd", Input: []tf.Input{ - indices, dims, + x, y, }, } op := scope.AddOperation(opspec) diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go index baaac41f4ed9cff285cad0c7e63084102eb9a67c..d6a37e0a8633f936fda7ec9612c6c097c9029c31 100644 --- a/tensorflow/go/operation.go +++ b/tensorflow/go/operation.go @@ -45,6 +45,12 @@ func (op *Operation) NumOutputs() int { return int(C.TF_OperationNumOutputs(op.c)) } +// Device returns a specification of the device on which this operation +// will be executed, or the empty string if there is no such specification. +func (op *Operation) Device() string { + return C.GoString(C.TF_OperationDevice(op.c)) +} + // OutputListSize returns the size of the list of Outputs that is produced by a // named output of op. // @@ -131,6 +137,9 @@ func (p Output) canBeAnInput() {} // Consumers returns the inputs that consume this output. func (p Output) Consumers() []Consumer { max := int(C.TF_OperationOutputNumConsumers(p.c())) + if max == 0 { + return nil + } inputs := make([]C.TF_Input, max) n := C.TF_OperationOutputConsumers(p.c(), (*C.TF_Input)(unsafe.Pointer(&inputs[0])), C.int(max)) inputs = inputs[:int(n)] diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go index 0672e8ecc75629f114a52b9c677afeea96b213b4..4af9e33ad0aea5d269d876f154f96cbc99243cad 100644 --- a/tensorflow/go/operation_test.go +++ b/tensorflow/go/operation_test.go @@ -222,6 +222,33 @@ func TestOperationConsumers(t *testing.T) { t.Fatalf("%d. Got op name %q, wanted %q", i, got, want) } } + + if len(b.Consumers()) != 0 { + t.Fatalf("expected %+v to have no consumers", b) + } +} + +func TestOperationDevice(t *testing.T) { + graph := NewGraph() + v, err := NewTensor(float32(1.0)) + if err != nil { + t.Fatal(err) + } + op, err := graph.AddOperation(OpSpec{ + Type: "Const", + Name: "Const", + Attrs: map[string]interface{}{ + "dtype": v.DataType(), + "value": v, + }, + Device: "/device:GPU:0", + }) + if err != nil { + t.Fatal(err) + } + if got, want := op.Device(), "/device:GPU:0"; got != want { + t.Errorf("Got %q, want %q", got, want) + } } func forceGC() { diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go index 2d25c04dc9b1d0bc2ae831f98c0879e73a6bfafa..f3338f6595793df82380f4ce63058ba4285c91dd 100644 --- a/tensorflow/go/tensor.go +++ b/tensorflow/go/tensor.go @@ -131,13 +131,9 @@ func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error) } runtime.SetFinalizer(t, (*Tensor).finalize) raw := tensorData(t.c) - n, err := r.Read(raw) - if err != nil { + if _, err := io.ReadFull(r, raw); err != nil { return nil, err } - if uintptr(n) != nbytes { - return nil, fmt.Errorf("expected serialized tensor to be %v bytes, read %v", nbytes, n) - } return t, nil } diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go index 793c36dd4db28fc5fdb713095c6d1d6713367a7a..dc533cd3e1c7198f902b2db850e8daff50f4cdeb 100644 --- a/tensorflow/go/tensor_test.go +++ b/tensorflow/go/tensor_test.go @@ -18,6 +18,7 @@ package tensorflow import ( "bytes" + "io" "reflect" "testing" ) @@ -226,6 +227,54 @@ func TestTensorSerializationErrors(t *testing.T) { } } +func TestReadTensorReadAll(t *testing.T) { + // Get the bytes of a tensor. + a := []float32{1.1, 1.2, 1.3} + ats, err := NewTensor(a) + if err != nil { + t.Fatal(err) + } + abuf := new(bytes.Buffer) + if _, err := ats.WriteContentsTo(abuf); err != nil { + t.Fatal(err) + } + + // Get the bytes of another tensor. + b := []float32{1.1, 1.2, 1.3} + bts, err := NewTensor(b) + if err != nil { + t.Fatal(err) + } + bbuf := new(bytes.Buffer) + if _, err := bts.WriteContentsTo(bbuf); err != nil { + t.Fatal(err) + } + + // Check that ReadTensor reads all bytes of both tensors, when the situation + // requires one than reads. + abbuf := io.MultiReader(abuf, bbuf) + abts, err := ReadTensor(Float, []int64{2, 3}, abbuf) + if err != nil { + t.Fatal(err) + } + abtsf32 := abts.Value().([][]float32) + expected := [][]float32{a, b} + + if len(abtsf32) != 2 { + t.Fatalf("first dimension %d is not 2", len(abtsf32)) + } + for i := 0; i < 2; i++ { + if len(abtsf32[i]) != 3 { + t.Fatalf("second dimension %d is not 3", len(abtsf32[i])) + } + for j := 0; j < 3; j++ { + if abtsf32[i][j] != expected[i][j] { + t.Errorf("value at %d %d not equal %f %f", i, j, abtsf32[i][j], expected[i][j]) + } + } + } +} + func benchmarkNewTensor(b *testing.B, v interface{}) { for i := 0; i < b.N; i++ { if t, err := NewTensor(v); err != nil || t == nil { diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 19d2133a55f347cfc3d4dc766e0593a0e188c967..9dce78b9a367cdf5243dfab621cc6fc77d732ee5 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -56,6 +56,10 @@ java_library( srcs = glob(["src/gen/java/org/tensorflow/processor/**/*.java"]), javacopts = JAVACOPTS, resources = glob(["src/gen/resources/META-INF/services/javax.annotation.processing.Processor"]), + deps = [ + "@com_google_guava", + "@com_squareup_javapoet", + ], ) filegroup( @@ -70,6 +74,7 @@ tf_java_op_gen_srcjar( name = "java_op_gen_sources", api_def_srcs = [ "//tensorflow/core/api_def:base_api_def", + "//tensorflow/core/api_def:java_api_def", ], base_package = "org.tensorflow.op", gen_tool = ":java_op_gen_tool", @@ -81,7 +86,10 @@ tf_cc_binary( "src/gen/cc/op_gen_main.cc", ], copts = tf_copts(), - linkopts = ["-lm"], + linkopts = select({ + "//tensorflow:windows": [], + "//conditions:default": ["-lm"], + }), linkstatic = 1, deps = [ ":java_op_gen_lib", @@ -287,6 +295,32 @@ tf_java_test( ], ) +tf_java_test( + name = "GradientsTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/op/core/GradientsTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.op.core.GradientsTest", + deps = [ + ":tensorflow", + ":testutil", + "@junit", + ], +) + +tf_java_test( + name = "ZerosTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/op/core/ZerosTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.op.core.ZerosTest", + deps = [ + ":tensorflow", + ":testutil", + "@junit", + ], +) + filegroup( name = "processor_test_resources", srcs = glob([ @@ -337,7 +371,6 @@ tf_cc_binary( "$(location {})".format(LINKER_EXPORTED_SYMBOLS), ], "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", "-s", diff --git a/tensorflow/java/README.md b/tensorflow/java/README.md index 2f1ce253b2facb6d86d5c44b60668823f660ae7e..c7382ff23138cd8121718d0b7552da0f0a2d78af 100644 --- a/tensorflow/java/README.md +++ b/tensorflow/java/README.md @@ -1,7 +1,7 @@ # TensorFlow for Java > *WARNING*: The TensorFlow Java API is not currently covered by the TensorFlow -> [API stability guarantees](https://www.tensorflow.org/programmers_guide/version_semantics). +> [API stability guarantees](https://www.tensorflow.org/guide/version_semantics). > > For using TensorFlow on Android refer instead to > [contrib/android](https://www.tensorflow.org/code/tensorflow/contrib/android), @@ -23,8 +23,7 @@ native libraries will need to be built from source. 2. Setup the environment to build TensorFlow from source code ([Linux](https://www.tensorflow.org/install/install_sources#PrepareLinux) - or [Mac OS - X](https://www.tensorflow.org/install/install_sources#PrepareMac)). + or [macOS](https://www.tensorflow.org/install/install_sources#PrepareMac)). If you'd like to skip reading those details and do not care about GPU support, try the following: diff --git a/tensorflow/java/maven/README.md b/tensorflow/java/maven/README.md index 3e030dcd09c886983540b95640230eae3a6f2c0f..cbc64a284fc0d977d4540d80d96a0901207c34aa 100644 --- a/tensorflow/java/maven/README.md +++ b/tensorflow/java/maven/README.md @@ -151,16 +151,6 @@ conducted in a [Docker](https://www.docker.com) container. 7. Upon successful release, commit changes to all the `pom.xml` files (which should have the updated version number). -### Snapshots - -If the `TF_VERSION` provided to the `release.sh` script ends in `-SNAPSHOT`, -then instead of using official release files, the nightly build artifacts from -https://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/, -https://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow-windows/ and -https://ci.tensorflow.org/view/Nightly/job/nightly-android -will be used to upload to the Maven Central snapshots repository. (Note that -snapshots are only uploaded to Maven Central, not Bintray.) - ### Skip deploying to a repository Should you need, setting environment variables `DEPLOY_OSSRH=0` or @@ -173,12 +163,12 @@ cannot skip deploying to OSSRH for a `-SNAPSHOT` version. This section provides some pointers around how artifacts are currently assembled. -All native and java code is first built and tested on -a [Tensorflow Jenkins server](https://ci.tensorflow.org/) which run various -scripts under the [`tools/ci_build`](../../tools/ci_build/) directory. Of -particular interest may be `tools/ci_build/builds/libtensorflow.sh` which -bundles Java-related build sources and outputs into archives, and -`tools/ci_build/builds/android_full.sh` which produces an Android AAR package. +All native and java code is first built and tested by the release process +which run various scripts under the [`tools/ci_build`](../../tools/ci_build/) +directory. Of particular interest may be +`tools/ci_build/builds/libtensorflow.sh` which bundles Java-related build +sources and outputs into archives, and `tools/ci_build/builds/android_full.sh` +which produces an Android AAR package. Maven artifacts however are not created in Jenkins. Instead, artifacts are created and deployed externally on-demand, when a maintainer runs the diff --git a/tensorflow/java/maven/hadoop/pom.xml b/tensorflow/java/maven/hadoop/pom.xml index 0642be06fa148933902ab450c5cf2f771e268828..e0409fa41b4a3be018cd052c63635886a46ed3d1 100644 --- a/tensorflow/java/maven/hadoop/pom.xml +++ b/tensorflow/java/maven/hadoop/pom.xml @@ -1,12 +1,30 @@ - - + 4.0.0 - TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop + org.tensorflow hadoop jar + 1.10.0 + tensorflow-hadoop + https://www.tensorflow.org + TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop + + + UTF-8 + 1.6 + 1.6 + 2.6.0 + 3.3.1 + 4.11 + + + + + Apache License Version 2.0 + http://www.apache.org/licenses/LICENSE-2.0.txt + + https://github.com/tensorflow/ecosystem.git @@ -14,11 +32,161 @@ scm:git:https://github.com/tensorflow/ecosystem.git - https://github.com/tensorflow/ecosystem/ - - org.tensorflow - parentpom - 1.9.0-rc0 - ../ - - \ No newline at end of file + + + + + org.apache.maven.plugins + maven-gpg-plugin + 1.5 + + + sign-artifacts + verify + + sign + + + + + + + + + org.apache.maven.plugins + maven-source-plugin + 2.2.1 + + + attach-sources + + jar-no-fork + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + 2.9.1 + + + attach-javadocs + + jar + + + + + + + + + + org.tensorflow + proto + ${project.version} + + + org.apache.hadoop + hadoop-common + ${hadoop.version} + + + com.google.protobuf + protobuf-java + + + + + org.apache.hadoop + hadoop-mapreduce-client-core + ${hadoop.version} + + + com.google.protobuf + protobuf-java + + + + + com.google.protobuf + protobuf-java + ${protobuf.version} + + + junit + junit + ${junit.version} + test + + + org.apache.hadoop + hadoop-mapreduce-client-jobclient + ${hadoop.version} + test-jar + true + test + + + com.google.protobuf + protobuf-java + + + + + + + + + ossrh + + + + ossrh + https://oss.sonatype.org/content/repositories/snapshots + + + ossrh + https://oss.sonatype.org/service/local/staging/deploy/maven2/ + + + + + + org.apache.maven.plugins + maven-gpg-plugin + + + + + + bintray + + + + bintray + https://api.bintray.com/maven/google/tensorflow/tensorflow/;publish=0 + + + + + + org.apache.maven.plugins + maven-gpg-plugin + + + + + + + + + TensorFlowers + TensorFlow + http://www.tensorflow.org + + + diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml index 38e87b16399e7b9344f654b553fa1623b6b2d9cd..f9093ce385408d6df5cd2b6730ddb31cd3c21f54 100644 --- a/tensorflow/java/maven/libtensorflow/pom.xml +++ b/tensorflow/java/maven/libtensorflow/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.9.0-rc0 + 1.10.0 ../ libtensorflow diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml index 36c984e280199c97a6e07516cc84290fa91e6b27..1208956decf4909f76411e2a524b6154d8b1fb4f 100644 --- a/tensorflow/java/maven/libtensorflow_jni/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.9.0-rc0 + 1.10.0 ../ libtensorflow_jni diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml index 4c846de05ad415bedf2c14b6a07ff8d5bc6f11b8..755449cb3c0fb3c27b96271d38d520855a605c59 100644 --- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.9.0-rc0 + 1.10.0 ../ libtensorflow_jni_gpu diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml index acab08f58c074144472f80eeaf894e25ad8389f8..035077e1e0140ef21921995a33a176f1d84a9208 100644 --- a/tensorflow/java/maven/pom.xml +++ b/tensorflow/java/maven/pom.xml @@ -6,7 +6,7 @@ 4.0.0 org.tensorflow parentpom - 1.9.0-rc0 + 1.10.0 pom https://www.tensorflow.org diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml index eb0a952c7d2f8960a387fa63e6c257e33b80bcbb..b89f0425677adcb5eb4f6c803ea643765b5d13cc 100644 --- a/tensorflow/java/maven/proto/pom.xml +++ b/tensorflow/java/maven/proto/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.9.0-rc0 + 1.10.0 ../ proto diff --git a/tensorflow/java/maven/run_inside_container.sh b/tensorflow/java/maven/run_inside_container.sh index 2e771064e4a9a9ca4c50f5ecf8cae91cf8b5ce6c..8c4c9d498c008a78f9eed68e0fcdca9f141b9b25 100644 --- a/tensorflow/java/maven/run_inside_container.sh +++ b/tensorflow/java/maven/run_inside_container.sh @@ -26,12 +26,6 @@ TF_ECOSYSTEM_URL="https://github.com/tensorflow/ecosystem.git" DEPLOY_BINTRAY="${DEPLOY_BINTRAY:-true}" DEPLOY_OSSRH="${DEPLOY_OSSRH:-true}" -IS_SNAPSHOT="false" -if [[ "${TF_VERSION}" == *"-SNAPSHOT" ]]; then - IS_SNAPSHOT="true" - # Bintray does not allow snapshots. - DEPLOY_BINTRAY="false" -fi PROTOC_RELEASE_URL="https://github.com/google/protobuf/releases/download/v3.5.1/protoc-3.5.1-linux-x86_64.zip" if [[ "${DEPLOY_BINTRAY}" != "true" && "${DEPLOY_OSSRH}" != "true" ]]; then echo "Must deploy to at least one of Bintray or OSSRH" >&2 @@ -69,11 +63,7 @@ mvn_property() { } download_libtensorflow() { - if [[ "${IS_SNAPSHOT}" == "true" ]]; then - URL="http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=cpu-slave/lastSuccessfulBuild/artifact/lib_package/libtensorflow-src.jar" - else - URL="${RELEASE_URL_PREFIX}/libtensorflow-src-${TF_VERSION}.jar" - fi + URL="${RELEASE_URL_PREFIX}/libtensorflow-src-${TF_VERSION}.jar" curl -L "${URL}" -o /tmp/src.jar cd "${DIR}/libtensorflow" jar -xvf /tmp/src.jar @@ -101,17 +91,9 @@ download_libtensorflow_jni() { mkdir windows-x86_64 mkdir darwin-x86_64 - if [[ "${IS_SNAPSHOT}" == "true" ]]; then - # Nightly builds from http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/ - # and http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow-windows/ - curl -L "http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=cpu-slave/lastSuccessfulBuild/artifact/lib_package/libtensorflow_jni-cpu-linux-x86_64.tar.gz" | tar -xvz -C linux-x86_64 - curl -L "http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=mac-slave/lastSuccessfulBuild/artifact/lib_package/libtensorflow_jni-cpu-darwin-x86_64.tar.gz" | tar -xvz -C darwin-x86_64 - curl -L "http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow-windows/lastSuccessfulBuild/artifact/lib_package/libtensorflow_jni-cpu-windows-x86_64.zip" -o /tmp/windows.zip - else - curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-linux-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C linux-x86_64 - curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-darwin-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C darwin-x86_64 - curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-windows-x86_64-${TF_VERSION}.zip" -o /tmp/windows.zip - fi + curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-linux-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C linux-x86_64 + curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-darwin-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C darwin-x86_64 + curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-windows-x86_64-${TF_VERSION}.zip" -o /tmp/windows.zip unzip /tmp/windows.zip -d windows-x86_64 rm -f /tmp/windows.zip @@ -128,17 +110,17 @@ download_libtensorflow_jni_gpu() { cd "${NATIVE_DIR}" mkdir linux-x86_64 + mkdir windows-x86_64 - if [[ "${IS_SNAPSHOT}" == "true" ]]; then - # Nightly builds from http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/ - # and http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow-windows/ - curl -L "http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=gpu-linux/lastSuccessfulBuild/artifact/lib_package/libtensorflow_jni-gpu-linux-x86_64.tar.gz" | tar -xvz -C linux-x86_64 - else - curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-gpu-linux-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C linux-x86_64 - fi + curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-gpu-linux-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C linux-x86_64 + curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-gpu-windows-x86_64-${TF_VERSION}.zip" -o /tmp/windows.zip + + unzip /tmp/windows.zip -d windows-x86_64 + rm -f /tmp/windows.zip # Updated timestamps seem to be required to get Maven to pick up the file. touch linux-x86_64/* + touch windows-x86_64/* cd "${DIR}" } @@ -165,11 +147,7 @@ generate_java_protos() { rm -f "/tmp/protoc.zip" # Download the release archive of TensorFlow protos. - if [[ "${IS_SNAPSHOT}" == "true" ]]; then - URL="http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=cpu-slave/lastSuccessfulBuild/artifact/lib_package/libtensorflow_proto.zip" - else - URL="${RELEASE_URL_PREFIX}/libtensorflow_proto-${TF_VERSION}.zip" - fi + URL="${RELEASE_URL_PREFIX}/libtensorflow_proto-${TF_VERSION}.zip" curl -L "${URL}" -o /tmp/libtensorflow_proto.zip mkdir -p "${DIR}/proto/tmp/src" unzip -d "${DIR}/proto/tmp/src" "/tmp/libtensorflow_proto.zip" @@ -203,7 +181,10 @@ download_tf_ecosystem() { cd "${ECOSYSTEM_DIR}" git clone "${TF_ECOSYSTEM_URL}" cd ecosystem - git checkout r${TF_VERSION} + # TF_VERSION is a semver string (..[-suffix]) + # but the branch is just (r.). + RELEASE_BRANCH=$(echo "${TF_VERSION}" | sed -e 's/\([0-9]\+\.[0-9]\+\)\.[0-9]\+.*/\1/') + git checkout r${RELEASE_BRANCH} # Copy the TensorFlow Hadoop source cp -r "${ECOSYSTEM_DIR}/ecosystem/hadoop/src" "${HADOOP_DIR}" @@ -235,11 +216,7 @@ deploy_profile() { # Determine the correct pom file property to use # for the repository url. local rtype - if [[ "${IS_SNAPSHOT}" == "true" ]]; then - rtype='snapshotRepository' - else - rtype='repository' - fi + rtype='repository' local url=$(mvn_property "${profile}" "project.distributionManagement.${rtype}.url") local repositoryId=$(mvn_property "${profile}" "project.distributionManagement.${rtype}.id") mvn gpg:sign-and-deploy-file \ @@ -297,17 +274,13 @@ mvn verify deploy_artifacts set +ex -if [[ "${IS_SNAPSHOT}" == "false" ]]; then - echo "Uploaded to the staging repository" - echo "After validating the release: " - if [[ "${DEPLOY_OSSRH}" == "true" ]]; then - echo "* Login to https://oss.sonatype.org/#stagingRepositories" - echo "* Find the 'org.tensorflow' staging release and click either 'Release' to release or 'Drop' to abort" - fi - if [[ "${DEPLOY_BINTRAY}" == "true" ]]; then - echo "* Login to https://bintray.com/google/tensorflow/tensorflow" - echo "* Either 'Publish' unpublished items to release, or 'Discard' to abort" - fi -else - echo "Uploaded to the snapshot repository" +echo "Uploaded to the staging repository" +echo "After validating the release: " +if [[ "${DEPLOY_OSSRH}" == "true" ]]; then + echo "* Login to https://oss.sonatype.org/#stagingRepositories" + echo "* Find the 'org.tensorflow' staging release and click either 'Release' to release or 'Drop' to abort" +fi +if [[ "${DEPLOY_BINTRAY}" == "true" ]]; then + echo "* Login to https://bintray.com/google/tensorflow/tensorflow" + echo "* Either 'Publish' unpublished items to release, or 'Discard' to abort" fi diff --git a/tensorflow/java/maven/spark-connector/pom.xml b/tensorflow/java/maven/spark-connector/pom.xml index 19c752d08be1deec40042bc1aa8fd1159b2f2be9..31e39c588a065cb152ea871ce131d145bd88596a 100644 --- a/tensorflow/java/maven/spark-connector/pom.xml +++ b/tensorflow/java/maven/spark-connector/pom.xml @@ -1,12 +1,23 @@ - - + + 4.0.0 - TensorFlow TFRecord connector for Apache Spark DataFrames - spark-connector + org.tensorflow + spark-connector_2.11 jar + 1.10.0 + spark-tensorflow-connector + https://www.tensorflow.org + TensorFlow TFRecord connector for Apache Spark DataFrames + + + + The Apache Software License, Version 2.0 + http://www.apache.org/licenses/LICENSE-2.0.txt + repo + + https://github.com/tensorflow/ecosystem.git @@ -14,11 +25,325 @@ scm:git:https://github.com/tensorflow/ecosystem.git - https://github.com/tensorflow/ecosystem/ - - org.tensorflow - parentpom - 1.9.0-rc0 - ../ - - \ No newline at end of file + + UTF-8 + 3.2.2 + 2.11 + 1.0 + 2.2.6 + 3.0 + 1.8 + 2.3.0 + 2.7.3 + 4.11 + + + + + + + true + net.alchim31.maven + scala-maven-plugin + ${scala.maven.version} + + + compile + + add-source + compile + + + + -Xms256m + -Xmx512m + + + -g:vars + -deprecation + -feature + -unchecked + -Xfatal-warnings + -language:implicitConversions + -language:existentials + + + + + test + + add-source + testCompile + + + + attach-javadocs + + doc-jar + + + + + incremental + true + ${scala.binary.version} + false + + + + true + org.scalatest + scalatest-maven-plugin + ${scalatest.maven.version} + + + scalaTest + test + + test + + + + + + + maven-shade-plugin + 3.1.0 + + + package + + shade + + + true + + + com.google.protobuf:protobuf-java + org.tensorflow:hadoop + org.tensorflow:proto + + + + + + com.google.protobuf:protobuf-java + + **/*.java + + + + + + com.google.protobuf + + org.tensorflow.spark.shaded.com.google.protobuf + + + + + + + + + + org.apache.maven.plugins + maven-gpg-plugin + 1.5 + + + sign-artifacts + verify + + sign + + + + + + + + + net.alchim31.maven + scala-maven-plugin + + + org.apache.maven.plugins + maven-shade-plugin + + + org.scalatest + scalatest-maven-plugin + + + org.apache.maven.plugins + maven-compiler-plugin + ${maven.compiler.version} + + ${java.version} + ${java.version} + + + + org.apache.maven.plugins + maven-source-plugin + 2.2.1 + + + attach-sources + + jar-no-fork + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + 2.9.1 + + + attach-javadocs + + jar + + + + + + + + + + test + + true + + !NEVERSETME + + + + + + net.alchim31.maven + scala-maven-plugin + + + + + + + org.scalatest + scalatest_${scala.binary.version} + ${scala.test.version} + test + + + + + + org.scalatest + scalatest_${scala.binary.version} + test + + + + + + + ossrh + + + + ossrh + https://oss.sonatype.org/content/repositories/snapshots + + + ossrh + https://oss.sonatype.org/service/local/staging/deploy/maven2/ + + + + + + org.apache.maven.plugins + maven-gpg-plugin + + + + + + bintray + + + + bintray + https://api.bintray.com/maven/google/tensorflow/tensorflow/;publish=0 + + + + + + org.apache.maven.plugins + maven-gpg-plugin + + + + + + + + + TensorFlowers + TensorFlow + http://www.tensorflow.org + + + + + + org.tensorflow + hadoop + ${project.version} + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-mllib_${scala.binary.version} + ${spark.version} + provided + + + org.apache.hadoop + hadoop-yarn-api + ${yarn.api.version} + provided + + + org.apache.spark + spark-mllib_${scala.binary.version} + ${spark.version} + test-jar + test + + + junit + junit + ${junit.version} + test + + + diff --git a/tensorflow/java/maven/tensorflow-android/update.py b/tensorflow/java/maven/tensorflow-android/update.py index 2206d800ca1fe82c5596ff39e56518bc5aea6211..c620564072cb6b5f35415e7c5844bddcdd78cdc7 100644 --- a/tensorflow/java/maven/tensorflow-android/update.py +++ b/tensorflow/java/maven/tensorflow-android/update.py @@ -86,19 +86,10 @@ def read_template(path): def main(): args = get_args() - # Artifacts are downloaded from the ci build. A SNAPSHOT release is - # associated with artifacts from the last successful nightly build. Otherwise, - # it comes from the officially blessed release artifacts. - if args.version.endswith('SNAPSHOT'): - info_url = ('https://ci.tensorflow.org/view/Nightly/job/nightly-android' - '/lastSuccessfulBuild/api/json') - aar_url = None - build_type = 'nightly-android' - else: - release_prefix = 'https://storage.googleapis.com/tensorflow/libtensorflow' - info_url = '%s/android_buildinfo-%s.json' % (release_prefix, args.version) - aar_url = '%s/tensorflow-%s.aar' % (release_prefix, args.version) - build_type = 'release-android' + release_prefix = 'https://storage.googleapis.com/tensorflow/libtensorflow' + info_url = '%s/android_buildinfo-%s.json' % (release_prefix, args.version) + aar_url = '%s/tensorflow-%s.aar' % (release_prefix, args.version) + build_type = 'release-android' # Retrieve build information build_info = get_json(info_url) diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml index 48668a47f2839e06eba774cae75b03663985ea28..0de90244b11d64b59e8bca51fb422af4564fa67e 100644 --- a/tensorflow/java/maven/tensorflow/pom.xml +++ b/tensorflow/java/maven/tensorflow/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.9.0-rc0 + 1.10.0 ../ tensorflow diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h index f5f54bf4d31af159624c668f1abb106f68944737..d9d6f8adc8ac9e58dbfe3609171803b55e76e42d 100644 --- a/tensorflow/java/src/gen/cc/java_defs.h +++ b/tensorflow/java/src/gen/cc/java_defs.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_ #define TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_ -#include #include #include +#include #include namespace tensorflow { diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 9b171f66ec741aa747fa365c13144a4b8b1aecbc..d5bd99bdd9d71f73288661380ec45e76c797fa75 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -35,7 +35,7 @@ namespace tensorflow { namespace java { namespace { -const char* kLicense = +constexpr const char kLicense[] = "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n" "\n" "Licensed under the Apache License, Version 2.0 (the \"License\");\n" @@ -391,9 +391,12 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, } if (!op.hidden()) { // expose the op in the Ops Graph API only if it is visible - op_class.add_annotation( - Annotation::Create("Operator", "org.tensorflow.op.annotation") - .attributes("group = \"" + endpoint.package() + "\"")); + Annotation oper_annot = + Annotation::Create("Operator", "org.tensorflow.op.annotation"); + if (endpoint.package() != kDefaultEndpointPackage) { + oper_annot.attributes("group = \"" + endpoint.package() + "\""); + } + op_class.add_annotation(oper_annot); } // create op class file const string op_dir_name = io::JoinPath( diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index 759d800ecfb5bec10b7bf8454baf5fc4c389e990..05decd6b54944f18205cce4d2341d7009ce7d806 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -19,10 +19,10 @@ limitations under the License. #include #include -#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/java/src/gen/cc/op_specs.h" namespace tensorflow { diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc index 4bcfc7fe011423df71a899d18815d3558e01b35f..941ab2699cb887375987f14200664b9bfaf6815a 100644 --- a/tensorflow/java/src/gen/cc/op_specs.cc +++ b/tensorflow/java/src/gen/cc/op_specs.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include -#include #include #include +#include #include "re2/re2.h" #include "tensorflow/core/framework/op.h" @@ -50,7 +50,7 @@ class TypeResolver { // For example, if the argument's datatype is DT_STRING, this method will // return "java.lang.String", so the argument can become "Operand" // in the Ops API - Type TypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out); + Type TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out); // Returns types of an input attribute // @@ -62,7 +62,7 @@ class TypeResolver { // , so the attribute can be used as a "Float" object // in the Ops API and casted to a "float" when passing through the JNI layer. std::pair TypesOf(const OpDef_AttrDef& attr_def, - bool *iterable_out); + bool* iterable_out); // Returns true if the type of this attribute has already been resolved bool IsAttributeVisited(const string& attr_name) { @@ -89,14 +89,14 @@ class TypeResolver { } }; -Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, - bool* iterable_out) { +Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out) { *iterable_out = false; if (!arg_def.number_attr().empty()) { // when number_attr is set, argument has to be a list of tensors *iterable_out = true; visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int())); } + Type type = Type::Wildcard(); if (arg_def.type() != DataType::DT_INVALID) { // resolve type from DataType @@ -153,13 +153,13 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, } else { LOG(FATAL) << "Cannot resolve data type of argument \"" << arg_def.name() - << "\" in operation \"" << op_def_.name() << "\""; + << "\" in operation \"" << op_def_.name() << "\""; } return type; } std::pair TypeResolver::TypesOf(const OpDef_AttrDef& attr_def, - bool* iterable_out) { + bool* iterable_out) { std::pair types = MakeTypePair(Type::Wildcard()); *iterable_out = false; StringPiece attr_type = attr_def.type(); @@ -184,7 +184,7 @@ std::pair TypeResolver::TypesOf(const OpDef_AttrDef& attr_def, } else if (attr_type == "tensor") { types = MakeTypePair(Type::Class("Tensor", "org.tensorflow") - .add_parameter(Type::Wildcard())); + .add_parameter(Type::Wildcard())); } else if (attr_type == "type") { Type type = *iterable_out ? Type::Wildcard() : NextGeneric(); @@ -195,7 +195,7 @@ std::pair TypeResolver::TypesOf(const OpDef_AttrDef& attr_def, } else { LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type - << "\" in operation \"" << op_def_.name() << "\""; + << "\" in operation \"" << op_def_.name() << "\""; } visited_attrs_.insert(std::make_pair(attr_def.name(), types.first)); return types; @@ -218,47 +218,43 @@ string SnakeToCamelCase(const string& str, bool upper = false) { return result; } -bool FindAndCut(re2::StringPiece* input, const RE2& expr, - re2::StringPiece* before_match, re2::StringPiece* ret_match = nullptr) { - re2::StringPiece match; - if (!expr.Match(*input, 0, input->size(), RE2::UNANCHORED, &match, 1)) { - return false; - } - before_match->set(input->data(), match.begin() - input->begin()); - input->remove_prefix(match.end() - before_match->begin()); - if (ret_match != nullptr) { - *ret_match = match; - } +bool FindAndCut(string* input, const RE2& expr, string* before_match, + string* ret_match = nullptr) { + string match; + if (!RE2::PartialMatch(*input, expr, &match)) return false; + *before_match = input->substr(0, input->find(match)); + *input = input->substr(before_match->size() + match.size()); + if (ret_match != nullptr) *ret_match = match; return true; } -string ParseDocumentation(re2::StringPiece input) { +string ParseDocumentation(const string& inp) { std::stringstream javadoc_text; // TODO(karllessard) This is a very minimalist utility method for converting // markdown syntax, as found in ops descriptions, to Javadoc/html tags. Check // for alternatives to increase the level of support for markups. std::vector markups_subexpr; - markups_subexpr.push_back("\n+\\*\\s+"); // lists - markups_subexpr.push_back("\n{2,}"); // paragraphs + markups_subexpr.push_back("\n+\\*\\s+"); // lists + markups_subexpr.push_back("\n{2,}"); // paragraphs markups_subexpr.push_back("`{3,}\\s*[^\\s\n]*\\s*\n"); // code blocks - markups_subexpr.push_back("`+"); // inlined code and code blocks + markups_subexpr.push_back("`+"); // inlined code and code blocks markups_subexpr.push_back("\\*{1,2}\\b"); // text emphasis - markups_subexpr.push_back("\\["); // hyperlinks - const RE2 markup_expr(str_util::Join(markups_subexpr, "|")); + markups_subexpr.push_back("\\["); // hyperlinks + const RE2 markup_expr("(" + str_util::Join(markups_subexpr, "|") + ")"); bool in_list = false; + string input = inp; while (true) { - re2::StringPiece text; - re2::StringPiece markup; + string text, markup; if (!FindAndCut(&input, markup_expr, &text, &markup)) { javadoc_text << input; break; // end of loop } javadoc_text << text; - if (markup.starts_with("\n")) { + if (str_util::StartsWith(markup, "\n")) { javadoc_text << "\n"; - if (markup.contains("*")) { + if (str_util::StrContains(markup, "*")) { // new list item javadoc_text << (in_list ? "\n" : "
    \n") << "
  • \n"; in_list = true; @@ -266,18 +262,18 @@ string ParseDocumentation(re2::StringPiece input) { // end of list javadoc_text << "
  • \n
\n"; in_list = false; - } else if (!input.starts_with("```")) { + } else if (!str_util::StartsWith(input, "```")) { // new paragraph (not required if a
 block follows)
         javadoc_text << "

\n"; } - } else if (markup.starts_with("```")) { + } else if (str_util::StartsWith(markup, "```")) { // code blocks - if (FindAndCut(&input, "```\\s*\n*", &text)) { + if (FindAndCut(&input, "(```\\s*\n*)", &text)) { javadoc_text << "

{@code\n" << text << "}
\n"; } else { javadoc_text << markup; } - } else if (markup.starts_with("`")) { + } else if (str_util::StartsWith("(" + markup + ")", "`")) { // inlined code if (FindAndCut(&input, markup, &text)) { javadoc_text << "{@code " << text << "}"; @@ -286,26 +282,28 @@ string ParseDocumentation(re2::StringPiece input) { } } else if (markup == "**") { // text emphasis (strong) - if (FindAndCut(&input, "\\b\\*{2}", &text)) { + if (FindAndCut(&input, "(\\b\\*{2})", &text)) { javadoc_text << "" << ParseDocumentation(text) << ""; } else { javadoc_text << markup; } } else if (markup == "*") { // text emphasis (normal) - if (FindAndCut(&input, "\\b\\*{1}", &text)) { + if (FindAndCut(&input, "(\\b\\*{1})", &text)) { javadoc_text << "" << ParseDocumentation(text) << ""; } else { javadoc_text << markup; } - } else if (markup.starts_with("[")) { + } else if (str_util::StartsWith(markup, "[")) { // hyperlinks string label; string link; - if (RE2::Consume(&input, "([^\\[]+)\\]\\((http.+)\\)", &label, &link)) { + if (RE2::PartialMatch(input, "([^\\[]+)\\]\\((http.+)\\)", &label, + &link) && + str_util::StartsWith(input, label + link)) { + input = input.substr(label.size() + link.size()); javadoc_text << "" - << ParseDocumentation(label) - << ""; + << ParseDocumentation(label) << ""; } else { javadoc_text << markup; } @@ -318,57 +316,56 @@ string ParseDocumentation(re2::StringPiece input) { } ArgumentSpec CreateInput(const OpDef_ArgDef& input_def, - const ApiDef::Arg& input_api_def, TypeResolver* type_resolver) { + const ApiDef::Arg& input_api_def, + TypeResolver* type_resolver) { bool iterable = false; Type type = type_resolver->TypeOf(input_def, &iterable); - Type var_type = Type::Interface("Operand", "org.tensorflow") - .add_parameter(type); + Type var_type = + Type::Interface("Operand", "org.tensorflow").add_parameter(type); if (iterable) { var_type = Type::IterableOf(var_type); } - return ArgumentSpec(input_api_def.name(), + return ArgumentSpec( + input_api_def.name(), Variable::Create(SnakeToCamelCase(input_api_def.rename_to()), var_type), - type, - ParseDocumentation(input_api_def.description()), - iterable); + type, ParseDocumentation(input_api_def.description()), iterable); } AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def, - const ApiDef::Attr& attr_api_def, TypeResolver* type_resolver) { + const ApiDef::Attr& attr_api_def, + TypeResolver* type_resolver) { bool iterable = false; std::pair types = type_resolver->TypesOf(attr_def, &iterable); - Type var_type = types.first.kind() == Type::GENERIC ? - Type::Class("Class").add_parameter(types.first) : types.first; + Type var_type = types.first.kind() == Type::GENERIC + ? Type::Class("Class").add_parameter(types.first) + : types.first; if (iterable) { var_type = Type::ListOf(var_type); } - return AttributeSpec(attr_api_def.name(), + return AttributeSpec( + attr_api_def.name(), Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type), - types.first, - types.second, - ParseDocumentation(attr_api_def.description()), - iterable, - attr_api_def.has_default_value()); + types.first, types.second, ParseDocumentation(attr_api_def.description()), + iterable, attr_api_def.has_default_value()); } ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def, - const ApiDef::Arg& output_api, TypeResolver* type_resolver) { + const ApiDef::Arg& output_api, + TypeResolver* type_resolver) { bool iterable = false; Type type = type_resolver->TypeOf(output_def, &iterable); - Type var_type = Type::Class("Output", "org.tensorflow") - .add_parameter(type); + Type var_type = Type::Class("Output", "org.tensorflow").add_parameter(type); if (iterable) { var_type = Type::ListOf(var_type); } - return ArgumentSpec(output_api.name(), + return ArgumentSpec( + output_api.name(), Variable::Create(SnakeToCamelCase(output_api.rename_to()), var_type), - type, - ParseDocumentation(output_api.description()), - iterable); + type, ParseDocumentation(output_api.description()), iterable); } EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def, - const ApiDef_Endpoint& endpoint_def) { + const ApiDef_Endpoint& endpoint_def) { std::vector name_tokens = str_util::Split(endpoint_def.name(), "."); string package; string name; @@ -379,24 +376,22 @@ EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def, package = "core"; // generate unclassified ops in the 'core' package name = name_tokens.at(0); } - return EndpointSpec(package, - name, - Javadoc::Create(ParseDocumentation(api_def.summary())) - .details(ParseDocumentation(api_def.description()))); + return EndpointSpec(package, name, + Javadoc::Create(ParseDocumentation(api_def.summary())) + .details(ParseDocumentation(api_def.description()))); } } // namespace OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) { - OpSpec op(api_def.graph_op_name(), - api_def.visibility() == ApiDef::HIDDEN, - op_def.deprecation().explanation()); + OpSpec op(api_def.graph_op_name(), api_def.visibility() == ApiDef::HIDDEN, + op_def.deprecation().explanation()); TypeResolver type_resolver(op_def); for (const string& next_input_name : api_def.arg_order()) { for (int i = 0; i < op_def.input_arg().size(); ++i) { if (op_def.input_arg(i).name() == next_input_name) { op.inputs_.push_back(CreateInput(op_def.input_arg(i), api_def.in_arg(i), - &type_resolver)); + &type_resolver)); break; } } @@ -405,8 +400,8 @@ OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) { // do not parse attributes already visited, they have probably been inferred // before as an input argument type if (!type_resolver.IsAttributeVisited(op_def.attr(i).name())) { - AttributeSpec attr = CreateAttribute(op_def.attr(i), api_def.attr(i), - &type_resolver); + AttributeSpec attr = + CreateAttribute(op_def.attr(i), api_def.attr(i), &type_resolver); // attributes with a default value are optional if (attr.has_default_value() && attr.type().kind() != Type::GENERIC) { op.optional_attributes_.push_back(attr); @@ -416,8 +411,8 @@ OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) { } } for (int i = 0; i < op_def.output_arg().size(); ++i) { - op.outputs_.push_back(CreateOutput(op_def.output_arg(i), api_def.out_arg(i), - &type_resolver)); + op.outputs_.push_back( + CreateOutput(op_def.output_arg(i), api_def.out_arg(i), &type_resolver)); } for (const auto& endpoint_def : api_def.endpoint()) { op.endpoints_.push_back(CreateEndpoint(op_def, api_def, endpoint_def)); diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h index 034cf636ed071a9dccac643d0f89988b070a1efc..30ecb8ce53d15372606981639183d3ba0e4466a4 100644 --- a/tensorflow/java/src/gen/cc/op_specs.h +++ b/tensorflow/java/src/gen/cc/op_specs.h @@ -19,14 +19,16 @@ limitations under the License. #include #include -#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/java/src/gen/cc/java_defs.h" namespace tensorflow { namespace java { +constexpr const char kDefaultEndpointPackage[] = "core"; + class EndpointSpec { public: // A specification for an operation endpoint @@ -36,9 +38,8 @@ class EndpointSpec { // javadoc: the endpoint class documentation // TODO(annarev): hardcode depcreated to false until deprecated is possible EndpointSpec(const string& package, const string& name, - const Javadoc& javadoc) - : package_(package), name_(name), javadoc_(javadoc), - deprecated_(false) {} + const Javadoc& javadoc) + : package_(package), name_(name), javadoc_(javadoc), deprecated_(false) {} const string& package() const { return package_; } const string& name() const { return name_; } @@ -61,10 +62,13 @@ class ArgumentSpec { // type: the tensor type of this argument // description: a description of this argument, in javadoc // iterable: true if this argument is a list - ArgumentSpec(const string& op_def_name, const Variable& var, - const Type& type, const string& description, bool iterable) - : op_def_name_(op_def_name), var_(var), type_(type), - description_(description), iterable_(iterable) {} + ArgumentSpec(const string& op_def_name, const Variable& var, const Type& type, + const string& description, bool iterable) + : op_def_name_(op_def_name), + var_(var), + type_(type), + description_(description), + iterable_(iterable) {} const string& op_def_name() const { return op_def_name_; } const Variable& var() const { return var_; } @@ -92,11 +96,16 @@ class AttributeSpec { // iterable: true if this attribute is a list // has_default_value: true if this attribute has a default value if not set AttributeSpec(const string& op_def_name, const Variable& var, - const Type& type, const Type& jni_type, const string& description, - bool iterable, bool has_default_value) - : op_def_name_(op_def_name), var_(var), type_(type), - description_(description), iterable_(iterable), - jni_type_(jni_type), has_default_value_(has_default_value) {} + const Type& type, const Type& jni_type, + const string& description, bool iterable, + bool has_default_value) + : op_def_name_(op_def_name), + var_(var), + type_(type), + description_(description), + iterable_(iterable), + jni_type_(jni_type), + has_default_value_(has_default_value) {} const string& op_def_name() const { return op_def_name_; } const Variable& var() const { return var_; } @@ -145,9 +154,10 @@ class OpSpec { // hidden: true if this op should not be visible through the Graph Ops API // deprecation_explanation: message to show if all endpoints are deprecated explicit OpSpec(const string& graph_op_name, bool hidden, - const string& deprecation_explanation) - : graph_op_name_(graph_op_name), hidden_(hidden), - deprecation_explanation_(deprecation_explanation) {} + const string& deprecation_explanation) + : graph_op_name_(graph_op_name), + hidden_(hidden), + deprecation_explanation_(deprecation_explanation) {} const string graph_op_name_; const bool hidden_; diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java index 11fda4fc22aeec9c2d94b5e884c11ceb2a66d29e..1b7bcdab35f45142aefdc9e9635b398090e60b17 100644 --- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java +++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java @@ -15,19 +15,44 @@ limitations under the License. package org.tensorflow.processor; +import com.google.common.base.CaseFormat; +import com.google.common.base.Strings; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Multimap; +import com.squareup.javapoet.ClassName; +import com.squareup.javapoet.FieldSpec; +import com.squareup.javapoet.JavaFile; +import com.squareup.javapoet.MethodSpec; +import com.squareup.javapoet.ParameterSpec; +import com.squareup.javapoet.TypeName; +import com.squareup.javapoet.TypeSpec; +import com.squareup.javapoet.TypeVariableName; import java.io.IOException; -import java.io.PrintWriter; +import java.util.Collection; import java.util.Collections; -import java.util.HashSet; +import java.util.HashMap; +import java.util.Map; import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import javax.annotation.processing.AbstractProcessor; import javax.annotation.processing.Filer; import javax.annotation.processing.Messager; import javax.annotation.processing.ProcessingEnvironment; import javax.annotation.processing.RoundEnvironment; import javax.lang.model.SourceVersion; +import javax.lang.model.element.AnnotationMirror; +import javax.lang.model.element.AnnotationValue; import javax.lang.model.element.Element; +import javax.lang.model.element.ExecutableElement; +import javax.lang.model.element.Modifier; import javax.lang.model.element.TypeElement; +import javax.lang.model.element.TypeParameterElement; +import javax.lang.model.element.VariableElement; +import javax.lang.model.type.TypeMirror; +import javax.lang.model.type.TypeVariable; +import javax.lang.model.util.ElementFilter; +import javax.lang.model.util.Elements; import javax.tools.Diagnostic.Kind; /** @@ -55,6 +80,7 @@ public final class OperatorProcessor extends AbstractProcessor { super.init(processingEnv); messager = processingEnv.getMessager(); filer = processingEnv.getFiler(); + elements = processingEnv.getElementUtils(); } @Override @@ -98,42 +124,77 @@ public final class OperatorProcessor extends AbstractProcessor { } // Collect all classes tagged with our annotation. - Set opClasses = new HashSet(); - if (!collectOpClasses(roundEnv, opClasses, annotation)) { + Multimap groupedMethods = HashMultimap.create(); + if (!collectOpsMethods(roundEnv, groupedMethods, annotation)) { return true; } // Nothing to do when there are no tagged classes. - if (opClasses.isEmpty()) { + if (groupedMethods.isEmpty()) { return true; } - // TODO:(kbsriram) validate operator classes and generate Op API. - writeApi(); + // Validate operator classes and generate Op API. + writeApi(groupedMethods); + hasRun = true; return true; } @Override public Set getSupportedAnnotationTypes() { - return Collections.singleton(String.format("%s.annotation.Operator", OP_PACKAGE)); + return Collections.singleton("org.tensorflow.op.annotation.Operator"); + } + + private static final Pattern JAVADOC_TAG_PATTERN = + Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*"); + private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops"); + private static final TypeName T_OPERATOR = + ClassName.get("org.tensorflow.op.annotation", "Operator"); + private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope"); + private static final TypeName T_GRAPH = ClassName.get("org.tensorflow", "Graph"); + private static final TypeName T_STRING = ClassName.get(String.class); + + private Filer filer; + private Messager messager; + private Elements elements; + private boolean hasRun = false; + + private void error(Element e, String message, Object... args) { + if (args != null && args.length > 0) { + message = String.format(message, args); + } + messager.printMessage(Kind.ERROR, message, e); } - private void writeApi() { - // Generate an empty class for now and get the build working correctly. This will be changed to - // generate the actual API once we've done with build-related changes. - // TODO:(kbsriram) - try (PrintWriter writer = - new PrintWriter(filer.createSourceFile(String.format("%s.Ops", OP_PACKAGE)).openWriter())) { - writer.println(String.format("package %s;", OP_PACKAGE)); - writer.println("public class Ops{}"); + private void write(TypeSpec spec) { + try { + JavaFile.builder("org.tensorflow.op", spec).skipJavaLangImports(true).build().writeTo(filer); } catch (IOException e) { - error(null, "Unexpected failure generating API: %s", e.getMessage()); + throw new AssertionError(e); + } + } + + private void writeApi(Multimap groupedMethods) { + Map groups = new HashMap<>(); + + // Generate a API class for each group collected other than the default one (= empty string) + for (Map.Entry> entry : groupedMethods.asMap().entrySet()) { + if (!entry.getKey().isEmpty()) { + TypeSpec groupClass = buildGroupClass(entry.getKey(), entry.getValue()); + write(groupClass); + groups.put(entry.getKey(), ClassName.get("org.tensorflow.op", groupClass.name)); + } } + // Generate the top API class, adding any methods added to the default group + TypeSpec topClass = buildTopClass(groups, groupedMethods.get("")); + write(topClass); } - private boolean collectOpClasses( - RoundEnvironment roundEnv, Set opClasses, TypeElement annotation) { + private boolean collectOpsMethods( + RoundEnvironment roundEnv, + Multimap groupedMethods, + TypeElement annotation) { boolean result = true; for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) { // @Operator can only apply to types, so e must be a TypeElement. @@ -145,20 +206,251 @@ public final class OperatorProcessor extends AbstractProcessor { result = false; continue; } - opClasses.add((TypeElement) e); + TypeElement opClass = (TypeElement) e; + // Skip deprecated operations for now, as we do not guarantee API stability yet + if (opClass.getAnnotation(Deprecated.class) == null) { + collectOpMethods(groupedMethods, opClass, annotation); + } } return result; } - private void error(Element e, String message, Object... args) { - if (args != null && args.length > 0) { - message = String.format(message, args); + private void collectOpMethods( + Multimap groupedMethods, TypeElement opClass, TypeElement annotation) { + AnnotationMirror am = getAnnotationMirror(opClass, annotation); + String groupName = getAnnotationElementValueAsString("group", am); + String methodName = getAnnotationElementValueAsString("name", am); + ClassName opClassName = ClassName.get(opClass); + if (Strings.isNullOrEmpty(methodName)) { + methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName()); + } + // Build a method for each @Operator found in the class path. There should be one method per + // operation factory called + // "create", which takes in parameter a scope and, optionally, a list of arguments + for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) { + if (opMethod.getModifiers().contains(Modifier.STATIC) + && opMethod.getSimpleName().contentEquals("create")) { + MethodSpec method = buildOpMethod(methodName, opClassName, opMethod); + groupedMethods.put(groupName, method); + } } - messager.printMessage(Kind.ERROR, message, e); } - private Filer filer; - private Messager messager; - private boolean hasRun = false; - private static final String OP_PACKAGE = "org.tensorflow.op"; + private MethodSpec buildOpMethod( + String methodName, ClassName opClassName, ExecutableElement factoryMethod) { + MethodSpec.Builder builder = + MethodSpec.methodBuilder(methodName) + .addModifiers(Modifier.PUBLIC) + .returns(TypeName.get(factoryMethod.getReturnType())) + .varargs(factoryMethod.isVarArgs()) + .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod)); + + for (TypeParameterElement tp : factoryMethod.getTypeParameters()) { + TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType()); + builder.addTypeVariable(tvn); + } + for (TypeMirror thrownType : factoryMethod.getThrownTypes()) { + builder.addException(TypeName.get(thrownType)); + } + StringBuilder call = new StringBuilder("return $T.create(scope"); + boolean first = true; + for (VariableElement param : factoryMethod.getParameters()) { + ParameterSpec p = ParameterSpec.get(param); + if (first) { + first = false; + continue; + } + call.append(", "); + call.append(p.name); + builder.addParameter(p); + } + call.append(")"); + builder.addStatement(call.toString(), opClassName); + return builder.build(); + } + + private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod) { + StringBuilder javadoc = new StringBuilder(); + javadoc + .append("Adds an {@link ") + .append(opClassName.simpleName()) + .append("} operation to the graph\n\n"); + + // Add all javadoc tags found in the operator factory method but the first one, which should be + // in all cases the + // 'scope' parameter that is implicitly passed by this API + Matcher tagMatcher = JAVADOC_TAG_PATTERN.matcher(elements.getDocComment(factoryMethod)); + boolean firstParam = true; + + while (tagMatcher.find()) { + String tag = tagMatcher.group(); + if (tag.startsWith("@param") && firstParam) { + firstParam = false; + } else { + javadoc.append(tag).append('\n'); + } + } + javadoc.append("@see ").append(opClassName).append("\n"); + + return javadoc.toString(); + } + + private static TypeSpec buildGroupClass(String group, Collection methods) { + MethodSpec.Builder ctorBuilder = + MethodSpec.constructorBuilder() + .addParameter(T_SCOPE, "scope") + .addStatement("this.scope = scope"); + + TypeSpec.Builder builder = + TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops") + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .addJavadoc( + "An API for adding {@code $L} operations to a {@link $T Graph}\n\n" + + "@see {@link $T}\n", + group, + T_GRAPH, + T_OPS) + .addMethods(methods) + .addMethod(ctorBuilder.build()); + + builder.addField( + FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); + + return builder.build(); + } + + private static TypeSpec buildTopClass( + Map groupToClass, Collection methods) { + MethodSpec.Builder ctorBuilder = + MethodSpec.constructorBuilder() + .addModifiers(Modifier.PRIVATE) + .addParameter(T_SCOPE, "scope") + .addStatement("this.scope = scope", T_SCOPE); + + for (Map.Entry entry : groupToClass.entrySet()) { + ctorBuilder.addStatement("$L = new $T(scope)", entry.getKey(), entry.getValue()); + } + + TypeSpec.Builder opsBuilder = + TypeSpec.classBuilder("Ops") + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .addJavadoc( + "An API for building a {@link $T} with operation wrappers\n

\n" + + "Any operation wrapper found in the classpath properly annotated as an" + + "{@link $T @Operator} is exposed\n" + + "by this API or one of its subgroup.\n

Example usage:\n

{@code\n"
+                    + "try (Graph g = new Graph()) {\n"
+                    + "  Ops ops = new Ops(g);\n"
+                    + "  // Operations are typed classes with convenience\n"
+                    + "  // builders in Ops.\n"
+                    + "  Constant three = ops.constant(3);\n"
+                    + "  // Single-result operations implement the Operand\n"
+                    + "  // interface, so this works too.\n"
+                    + "  Operand four = ops.constant(4);\n"
+                    + "  // Most builders are found within a group, and accept\n"
+                    + "  // Operand types as operands\n"
+                    + "  Operand nine = ops.math().add(four, ops.constant(5));\n"
+                    + "  // Multi-result operations however offer methods to\n"
+                    + "  // select a particular result for use.\n"
+                    + "  Operand result = \n"
+                    + "      ops.math().add(ops.array().unique(s, a).y(), b);\n"
+                    + "  // Optional attributes\n"
+                    + "  ops.math().matMul(a, b, MatMul.transposeA(true));\n"
+                    + "  // Naming operators\n"
+                    + "  ops.withName(“foo”).constant(5); // name “foo”\n"
+                    + "  // Names can exist in a hierarchy\n"
+                    + "  Ops sub = ops.withSubScope(“sub”);\n"
+                    + "  sub.withName(“bar”).constant(4); // “sub/bar”\n"
+                    + "}\n"
+                    + "}
\n", + T_GRAPH, + T_OPERATOR) + .addMethods(methods) + .addMethod(ctorBuilder.build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder("withSubScope") + .addModifiers(Modifier.PUBLIC) + .addParameter(T_STRING, "childScopeName") + .returns(T_OPS) + .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS) + .addJavadoc( + "Returns an API that adds operations to the graph with the provided name prefix.\n" + + "\n@see {@link $T#withSubScope(String)}\n", + T_SCOPE) + .build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder("withName") + .addModifiers(Modifier.PUBLIC) + .addParameter(T_STRING, "opName") + .returns(T_OPS) + .addStatement("return new Ops(scope.withName(opName))") + .addJavadoc( + "Returns an API that uses the provided name for an op.\n\n" + + "@see {@link $T#withName(String)}\n", + T_SCOPE) + .build()); + + opsBuilder.addField( + FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder("scope") + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .returns(T_SCOPE) + .addStatement("return scope") + .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE) + .build()); + + for (Map.Entry entry : groupToClass.entrySet()) { + opsBuilder.addField( + FieldSpec.builder(entry.getValue(), entry.getKey()) + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder(entry.getKey()) + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .returns(entry.getValue()) + .addStatement("return $L", entry.getKey()) + .addJavadoc( + "Returns an API for adding {@code $L} operations to the graph\n", entry.getKey()) + .build()); + } + + opsBuilder.addMethod( + MethodSpec.methodBuilder("create") + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) + .addParameter(T_GRAPH, "graph") + .returns(T_OPS) + .addStatement("return new Ops(new $T(graph))", T_SCOPE) + .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n") + .build()); + + return opsBuilder.build(); + } + + private static AnnotationMirror getAnnotationMirror(Element element, TypeElement annotation) { + for (AnnotationMirror am : element.getAnnotationMirrors()) { + if (am.getAnnotationType().asElement().equals(annotation)) { + return am; + } + } + throw new IllegalArgumentException( + "Annotation " + + annotation.getSimpleName() + + " not present on element " + + element.getSimpleName()); + } + + private static String getAnnotationElementValueAsString(String elementName, AnnotationMirror am) { + for (Map.Entry entry : + am.getElementValues().entrySet()) { + if (entry.getKey().getSimpleName().contentEquals(elementName)) { + return entry.getValue().getValue().toString(); + } + } + return ""; + } } diff --git a/tensorflow/java/src/main/java/org/tensorflow/DataType.java b/tensorflow/java/src/main/java/org/tensorflow/DataType.java index 7b92be6d385765c749a6aaed2f1d29df2710c247..516655040baccb66f33f351226361032b126a87b 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/DataType.java +++ b/tensorflow/java/src/main/java/org/tensorflow/DataType.java @@ -17,40 +17,54 @@ package org.tensorflow; import java.util.HashMap; import java.util.Map; + import org.tensorflow.types.UInt8; /** Represents the type of elements in a {@link Tensor} as an enum. */ public enum DataType { /** 32-bit single precision floating point. */ - FLOAT(1), + FLOAT(1, 4), /** 64-bit double precision floating point. */ - DOUBLE(2), + DOUBLE(2, 8), /** 32-bit signed integer. */ - INT32(3), + INT32(3, 4), /** 8-bit unsigned integer. */ - UINT8(4), + UINT8(4, 1), /** * A sequence of bytes. * *

TensorFlow uses the STRING type for an arbitrary sequence of bytes. */ - STRING(7), + STRING(7, -1), /** 64-bit signed integer. */ - INT64(9), + INT64(9, 8), /** Boolean. */ - BOOL(10); + BOOL(10, 1); private final int value; + + private final int byteSize; - // The integer value must match the corresponding TF_* value in the TensorFlow C API. - DataType(int value) { + /** + * @param value must match the corresponding TF_* value in the TensorFlow C API. + * @param byteSize size of an element of this type, in bytes, -1 if unknown + */ + DataType(int value, int byteSize) { this.value = value; + this.byteSize = byteSize; + } + + /** + * Returns the size of an element of this type, in bytes, or -1 if element size is variable. + */ + public int byteSize() { + return byteSize; } /** Corresponding value of the TF_DataType enum in the TensorFlow C API. */ diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java index d4fd3db5f7325ae891832ff7b658f5d3ea0789a6..752b49af040268d7e3355b12e4ae6aae310789bd 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java @@ -143,6 +143,99 @@ public final class Graph implements AutoCloseable { } } + /** + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e., + * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} + * + *

{@code dx} are used as initial gradients (which represent the symbolic partial derivatives + * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of + * {@code y}. + * + *

If {@code dx} is null, the implementation will use dx of {@link + * org.tensorflow.op.core.OnesLike OnesLike} for all shapes in {@code y}. + * + *

{@code prefix} is used as the name prefix applied to all nodes added to the graph to compute + * gradients. It must be unique within the provided graph or the operation will fail. + * + *

If {@code prefix} is null, then one will be chosen automatically. + * + * @param prefix unique string prefix applied before the names of nodes added to the graph to + * compute gradients. If null, a default one will be chosen. + * @param y output of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y} + * @return the partial derivatives {@code dy} with the size of {@code x} + */ + public Output[] addGradients(String prefix, Output[] y, Output[] x, Output[] dx) { + Output[] dy = new Output[x.length]; + final long[] yHandles = new long[y.length]; + final int[] yIndices = new int[y.length]; + final long[] xHandles = new long[x.length]; + final int[] xIndices = new int[x.length]; + long[] dxHandles = null; + int[] dxIndices = null; + + try (Reference ref = ref()) { + for (int i = 0; i < y.length; ++i) { + yHandles[i] = y[i].op().getUnsafeNativeHandle(); + yIndices[i] = y[i].index(); + } + for (int i = 0; i < x.length; ++i) { + xHandles[i] = x[i].op().getUnsafeNativeHandle(); + xIndices[i] = x[i].index(); + } + if (dx != null && dx.length > 0) { + dxHandles = new long[dx.length]; + dxIndices = new int[dx.length]; + + for (int i = 0; i < dx.length; ++i) { + dxHandles[i] = dx[i].op().getUnsafeNativeHandle(); + dxIndices[i] = dx[i].index(); + } + } + // Gradient outputs are returned in two continuous arrays concatenated into one. The first + // holds the native handles of the gradient operations while the second holds the index of + // their output e.g. given + // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain + // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...] + long[] dyHandlesAndIndices = + addGradients( + ref.nativeHandle(), + prefix, + yHandles, + yIndices, + xHandles, + xIndices, + dxHandles, + dxIndices); + int ndy = dyHandlesAndIndices.length >> 1; + if (ndy != dy.length) { + throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length + + " were expected"); + } + for (int i = 0, j = ndy; i < ndy; ++i, ++j) { + Operation op = new Operation(this, dyHandlesAndIndices[i]); + dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]); + } + } + return dy; + } + + /** + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, + * i.e., {@code dy/dx_1, dy/dx_2...} + *

+ * This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is + * a single output, {@code dx} is null and {@code prefix} is null. + * + * @param y output of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @return the partial derivatives {@code dy} with the size of {@code x} + */ + public Output[] addGradients(Output y, Output[] x) { + return addGradients(null, new Output[] {y}, x, null); + } + private final Object nativeHandleLock = new Object(); private long nativeHandle; private int refcount = 0; @@ -254,6 +347,16 @@ public final class Graph implements AutoCloseable { private static native byte[] toGraphDef(long handle); + private static native long[] addGradients( + long handle, + String prefix, + long[] inputHandles, + int[] inputIndices, + long[] outputHandles, + int[] outputIndices, + long[] gradInputHandles, + int[] gradInputIndices); + static { TensorFlow.init(); } diff --git a/tensorflow/java/src/main/java/org/tensorflow/Input.java b/tensorflow/java/src/main/java/org/tensorflow/Input.java new file mode 100644 index 0000000000000000000000000000000000000000..13bc463e7d6a991858332a353681b24fff417547 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/Input.java @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +/** + * Interface implemented by operands of a TensorFlow operation. + * + *

Example usage: + * + *

{@code
+ * // The "decodeJpeg" operation can be used as input to the "cast" operation
+ * Input decodeJpeg = ops.image().decodeJpeg(...);
+ * ops.math().cast(decodeJpeg, DataType.FLOAT);
+ *
+ * // The output "y" of the "unique" operation can be used as input to the "cast" operation
+ * Output y = ops.array().unique(...).y();
+ * ops.math().cast(y, DataType.FLOAT);
+ *
+ * // The "split" operation can be used as input list to the "concat" operation
+ * Iterable split = ops.array().split(...);
+ * ops.array().concat(0, split);
+ * }
+ */ +public interface Input { + + /** + * Returns the symbolic handle of a tensor. + * + *

Inputs to TensorFlow operations are outputs of another TensorFlow operation. This method is + * used to obtain a symbolic handle that represents the computation of the input. + * + * @see OperationBuilder#addInput(Output) + */ + Output asOutput(); +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java index c8b9126f033685c0320dfd2d8594061510bdd1e5..49594e6b47b9295d164a1823386b0981776e66f4 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java @@ -25,18 +25,86 @@ package org.tensorflow; * protocol buffer). */ public class SavedModelBundle implements AutoCloseable { + /** Options for loading a SavedModel. */ + public static final class Loader { + /** Load a SavedModelBundle with the configured options. */ + public SavedModelBundle load() { + return SavedModelBundle.load(exportDir, tags, configProto, runOptions); + } + + /** + * Sets options to use when executing model initialization operations. + * + * @param options Serialized RunOptions + * protocol buffer. + */ + public Loader withRunOptions(byte[] options) { + this.runOptions = options; + return this; + } + + /** + * Set configuration of the Session object created when loading the model. + * + * @param configProto Serialized ConfigProto + * protocol buffer. + */ + public Loader withConfigProto(byte[] configProto) { + this.configProto = configProto; + return this; + } + + /** + * Sets the set of tags that identify the specific graph in the saved model to load. + * + * @param tags the tags identifying the specific MetaGraphDef to load. + */ + public Loader withTags(String... tags) { + this.tags = tags; + return this; + } + + private Loader(String exportDir) { + this.exportDir = exportDir; + } + + private String exportDir = null; + private String[] tags = null; + private byte[] configProto = null; + private byte[] runOptions = null; + } /** * Load a saved model from an export directory. The model that is being loaded should be created * using the Saved Model * API. * + *

This method is a shorthand for: + * + *

{@code
+   * SavedModelBundler.loader().withTags(tags).load();
+   * }
+ * * @param exportDir the directory path containing a saved model. * @param tags the tags identifying the specific metagraphdef to load. * @return a bundle containing the graph and associated session. */ public static SavedModelBundle load(String exportDir, String... tags) { - return load(exportDir, tags, null); + return loader(exportDir).withTags(tags).load(); + } + + /** + * Load a saved model. + * + *

Returns a Loader object that can set configuration options before actually + * loading the model, + * + * @param exportDir the directory path containing a saved model. + */ + public static Loader loader(String exportDir) { + return new Loader(exportDir); } /** @@ -95,7 +163,8 @@ public class SavedModelBundle implements AutoCloseable { return new SavedModelBundle(graph, session, metaGraphDef); } - private static native SavedModelBundle load(String exportDir, String[] tags, byte[] runOptions); + private static native SavedModelBundle load( + String exportDir, String[] tags, byte[] config, byte[] runOptions); static { TensorFlow.init(); diff --git a/tensorflow/java/src/main/java/org/tensorflow/Session.java b/tensorflow/java/src/main/java/org/tensorflow/Session.java index 73324f23e6e3b79f2c3785bea1990b5317f16a52..a660d25f98ec961ac2ba1a48bced13803c00096b 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Session.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Session.java @@ -185,11 +185,20 @@ public final class Session implements AutoCloseable { return this; } - /** Makes {@link #run()} return the Tensor referred to by {@code output}. */ + /** + * Makes {@link #run()} return the Tensor referred to by {@code output}. + */ public Runner fetch(Output output) { outputs.add(output); return this; } + + /** + * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}. + */ + public Runner fetch(Operand operand) { + return fetch(operand.asOutput()); + } /** * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor}s. @@ -209,6 +218,13 @@ public final class Session implements AutoCloseable { targets.add(operation); return this; } + + /** + * Make {@link #run()} execute {@code operand}, but not return any evaluated {@link Tensor}s. + */ + public Runner addTarget(Operand operand) { + return addTarget(operand.asOutput().op()); + } /** * (Experimental method): set options (typically for debugging) for this run. diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java index 24a3775db625d3224628ee2d4c6fea9c56ff94fd..89872537689815924a070c282c34c4a2baf175c7 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java @@ -595,20 +595,11 @@ public final class Tensor implements AutoCloseable { } private static int elemByteSize(DataType dataType) { - switch (dataType) { - case FLOAT: - case INT32: - return 4; - case DOUBLE: - case INT64: - return 8; - case BOOL: - case UINT8: - return 1; - case STRING: + int size = dataType.byteSize(); + if (size < 0) { throw new IllegalArgumentException("STRING tensors do not have a fixed element size"); } - throw new IllegalArgumentException("DataType " + dataType + " is not supported yet"); + return size; } private static void throwExceptionIfNotByteOfByteArrays(Object array) { diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java index 8de2eaeb797628d1239eba46f6befb46de100af9..5a233bcc98469133e4e47551435b27ab2890ec22 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java @@ -135,17 +135,8 @@ public final class Scope { * }

* *

Note: if you provide a composite operator building class (i.e, a class that adds a - * set of related operations to the graph by calling other operator building code) you should also - * create a {@link #withSubScope(String)} scope for the underlying operators to group them under a - * meaningful name. - * - *

{@code
-   * public static Stddev create(Scope scope, ...) {
-   *   // group sub-operations under a common name
-   *   Scope group = scope.withSubScope("stddev");
-   *   ... Sqrt.create(group, Mean.create(group, ...))
-   * }
-   * }
+ * set of related operations to the graph by calling other operator building code), the provided + * name will act as a subscope to all underlying operators. * * @param defaultName name for the underlying operator. * @return unique name for the operator. diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java index de4049f66b2a88ff086a3319fb1c3b8b3b6143d9..00b6726be346e98c00e83ccc3c1e9bbde736fed7 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java @@ -15,11 +15,15 @@ limitations under the License. package org.tensorflow.op.core; +import static java.nio.charset.StandardCharsets.UTF_8; + import java.nio.ByteBuffer; import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.nio.LongBuffer; +import java.nio.charset.Charset; + import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; @@ -32,25 +36,82 @@ import org.tensorflow.op.annotation.Operator; /** An operator producing a constant value. */ @Operator public final class Constant extends PrimitiveOp implements Operand { + /** - * Create a constant from a Java object. + * Creates a constant containing a single {@code int} element. * - *

The argument {@code object} is first converted into a Tensor using {@link - * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be - * provided. For example: + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + * @return an integer constant + */ + public static Constant create(Scope scope, int data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-1 constant of {@code int} elements. * - *

{@code
-   * Constant.create(scope, 7); // returns a constant scalar tensor 7
-   * }
+ * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, int[] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-2 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param object a Java object representing the constant. - * @see org.tensorflow.Tensor#create(Object) Tensor.create + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. */ - public static Constant create(Scope scope, Object object, Class type) { - try (Tensor value = Tensor.create(object, type)) { - return createWithTensor(scope, value); - } + public static Constant create(Scope scope, int[][] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-3 constant of {@code int} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, int[][][] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-4 constant of {@code int} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, int[][][][] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-5 constant of {@code int} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, int[][][][][] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-6 constant of {@code int} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, int[][][][][][] data) { + return create(scope, data, Integer.class); } /** @@ -64,6 +125,7 @@ public final class Constant extends PrimitiveOp implements Operand { * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return an integer constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Constant create(Scope scope, long[] shape, IntBuffer data) { @@ -72,6 +134,83 @@ public final class Constant extends PrimitiveOp implements Operand { } } + /** + * Creates a constant containing a single {@code float} element. + * + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + * @return a float constant + */ + public static Constant create(Scope scope, float data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-1 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, float[] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-2 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, float[][] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-3 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, float[][][] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-4 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, float[][][][] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-5 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, float[][][][][] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-6 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, float[][][][][][] data) { + return create(scope, data, Float.class); + } + /** * Create a {@link DataType#FLOAT} constant with data from the given buffer. * @@ -83,6 +222,7 @@ public final class Constant extends PrimitiveOp implements Operand { * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a float constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Constant create(Scope scope, long[] shape, FloatBuffer data) { @@ -91,6 +231,83 @@ public final class Constant extends PrimitiveOp implements Operand { } } + /** + * Creates a constant containing a single {@code double} element. + * + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + * @return a double constant + */ + public static Constant create(Scope scope, double data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-1 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, double[] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-2 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, double[][] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-3 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, double[][][] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-4 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, double[][][][] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-5 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, double[][][][][] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-6 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, double[][][][][][] data) { + return create(scope, data, Double.class); + } + /** * Create a {@link DataType#DOUBLE} constant with data from the given buffer. * @@ -102,6 +319,7 @@ public final class Constant extends PrimitiveOp implements Operand { * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a double constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Constant create(Scope scope, long[] shape, DoubleBuffer data) { @@ -110,6 +328,83 @@ public final class Constant extends PrimitiveOp implements Operand { } } + /** + * Creates a constant containing a single {@code long} element. + * + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + * @return a long constant + */ + public static Constant create(Scope scope, long data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-1 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, long[] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-2 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, long[][] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-3 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, long[][][] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-4 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, long[][][][] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-5 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, long[][][][][] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-6 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, long[][][][][][] data) { + return create(scope, data, Long.class); + } + /** * Create a {@link DataType#INT64} constant with data from the given buffer. * @@ -121,6 +416,7 @@ public final class Constant extends PrimitiveOp implements Operand { * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a long constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Constant create(Scope scope, long[] shape, LongBuffer data) { @@ -129,6 +425,174 @@ public final class Constant extends PrimitiveOp implements Operand { } } + /** + * Creates a constant containing a single {@code boolean} element. + * + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + * @return a boolean constant + */ + public static Constant create(Scope scope, boolean data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-1 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, boolean[] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-2 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, boolean[][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-3 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, boolean[][][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-4 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, boolean[][][][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-5 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, boolean[][][][][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-6 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, boolean[][][][][][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a {@code String} constant using the default, UTF-8 encoding. + * + * @param scope is a scope used to add the underlying operation. + * @param data The string to put into the new constant. + * @return a string constant + */ + public static Constant create(Scope scope, String data) { + return create(scope, data, UTF_8); + } + + /** + * Creates a {@code String} constant using a specified encoding. + * + * @param scope is a scope used to add the underlying operation. + * @param charset The encoding from String to bytes. + * @param data The string to put into the new constant. + * @return a string constant + */ + public static Constant create(Scope scope, String data, Charset charset) { + try (Tensor value = Tensor.create(data.getBytes(charset), String.class)) { + return createWithTensor(scope, Tensor.create(data.getBytes(charset), String.class)); + } + } + + /** + * Creates a constant containing a single {@code String} element, represented as an array of {@code byte}s. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant create(Scope scope, byte[] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-1 constant of {@code String} elements, each represented as an array of {@code byte}s. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant create(Scope scope, byte[][] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-2 constant of {@code String} elements, each represented as an array of {@code byte}s. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant create(Scope scope, byte[][][] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-3 constant of {@code String} elements, each represented as an array of {@code byte}s. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant create(Scope scope, byte[][][][] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-4 constant of {@code String} elements, each represented as an array of {@code byte}s. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant create(Scope scope, byte[][][][][] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-5 constant of {@code String} elements, each represented as an array of {@code byte}s. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant create(Scope scope, byte[][][][][][] data) { + return create(scope, data, String.class); + } + /** * Create a constant with data from the given buffer. * @@ -141,6 +605,7 @@ public final class Constant extends PrimitiveOp implements Operand { * @param type the tensor datatype. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a constant of type `type` * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the * buffer */ @@ -150,6 +615,28 @@ public final class Constant extends PrimitiveOp implements Operand { } } + /** + * Create a constant from a Java object. + * + *

The argument {@code object} is first converted into a Tensor using {@link + * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be + * provided. For example: + * + *

{@code
+   * Constant.create(scope, new int[]{{1, 2}, {3, 4}}, Integer.class); // returns a 2x2 integer matrix
+   * }
+ * + * @param scope is a scope used to add the underlying operation. + * @param object a Java object representing the constant. + * @return a constant of type `type` + * @see org.tensorflow.Tensor#create(Object) Tensor.create + */ + public static Constant create(Scope scope, Object object, Class type) { + try (Tensor value = Tensor.create(object, type)) { + return createWithTensor(scope, value); + } + } + private static Constant createWithTensor(Scope scope, Tensor value) { return new Constant( scope diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java new file mode 100644 index 0000000000000000000000000000000000000000..eea9dc1c47c925df3b359059b70c95440b66c009 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java @@ -0,0 +1,161 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.Op; +import org.tensorflow.op.Operands; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Operator; + +/** + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, + * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} + *

+ * If {@code Options.dx()} values are set, they are as the initial symbolic partial derivatives of some loss + * function {@code L} w.r.t. {@code y}. {@code Options.dx()} must have the size of {@code y}. + *

+ * If {@code Options.dx()} is not set, the implementation will use dx of {@code OnesLike} for all + * shapes in {@code y}. + *

+ * The partial derivatives are returned in output {@code dy}, with the size of {@code x}. + *

+ * Example of usage: + *

{@code
+ * Gradients gradients = Gradients.create(scope, Arrays.asList(loss), Arrays.asList(w, b));
+ * 
+ * Constant alpha = ops.constant(1.0f, Float.class);
+ * ApplyGradientDescent.create(scope, w, alpha, gradients.dy(0));
+ * ApplyGradientDescent.create(scope, b, alpha, gradients.dy(1));
+ * }
+ */ +@Operator +public class Gradients implements Op, Iterable> { + + /** + * Optional attributes for {@link Gradients} + */ + public static class Options { + + /** + * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y} + * @return this option builder + */ + public Options dx(Iterable> dx) { + this.dx = dx; + return this; + } + + private Iterable> dx; + + private Options() { + } + } + + /** + * Adds gradients computation ops to the graph according to scope. + * + * @param scope current graph scope + * @param y outputs of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @param options carries optional attributes values + * @return a new instance of {@code Gradients} + */ + public static Gradients create( + Scope scope, + Iterable> y, + Iterable> x, + Options... options) { + Output[] dx = null; + if (options != null) { + for (Options opts : options) { + if (opts.dx != null) { + dx = Operands.asOutputs(opts.dx); + } + } + } + Output[] dy = + scope + .graph() + .addGradients( + scope.makeOpName("Gradients"), Operands.asOutputs(y), Operands.asOutputs(x), dx); + return new Gradients(Arrays.asList(dy)); + } + + /** + * Adds gradients computation ops to the graph according to scope. + * + *

This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where + * {@code y} is a single output. + * + * @param scope current graph scope + * @param y output of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @param options carries optional attributes values + * @return a new instance of {@code Gradients} + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + public static Gradients create( + Scope scope, Operand y, Iterable> x, Options... options) { + return create(scope, (Iterable) Arrays.asList(y), x, options); + } + + /** + * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y} + * @return builder to add more options to this operation + */ + public static Options dx(Iterable> dx) { + return new Options().dx(dx); + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator> iterator() { + return (Iterator) dy.iterator(); + } + + /** + * Partial derivatives of {@code y}s w.r.t. {@code x}s, with the size of {@code x} + */ + public List> dy() { + return dy; + } + + /** + * Returns a symbolic handle to one of the gradient operation output + * + *

Warning: Does not check that the type of the tensor matches T. It is recommended to call + * this method with an explicit type parameter rather than letting it be inferred, e.g. {@code + * gradients.dy(0)} + * + * @param The expected element type of the tensors produced by this output. + * @param index The index of the output among the gradients added by this operation + */ + @SuppressWarnings("unchecked") + public Output dy(int index) { + return (Output) dy.get(index); + } + + private List> dy; + + private Gradients(List> dy) { + this.dy = dy; + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java new file mode 100644 index 0000000000000000000000000000000000000000..b7c6beb9bcc118243d68c21c4232c591af2210aa --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.op.core; + +import java.nio.ByteBuffer; + +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.Op; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Operator; + +/** + * An operator creating a constant initialized with zeros of the shape given by `dims`. + * + *

For example, the following expression + *

{@code ops.zeros(ops.constant(new long[]{2, 2}), Float.class)
+ * is the equivalent of + *
{@code ops.fill(ops.constant(new long[]{2, 2}), ops.constant(0.0f))
+ * + * @param constant type + */ +@Operator +public class Zeros implements Op, Operand { + + /** + * Creates a zeroed tensor given its type and shape. + * + * @param scope is a scope used to add the underlying operation + * @param dims a 1-D operand that represents the shape of the output tensor + * @param type the output tensor datatype + * @return a constant tensor initialized with zeros + * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with zeros. + */ + public static Zeros create(Scope scope, Operand dims, Class type) { + Scope childScope = scope.withSubScope("Zeros"); // If scope had an op name set, it will prevail on "Zeros" + int zeroSize = DataType.fromClass(type).byteSize(); + if (zeroSize < 0) { + throw new IllegalArgumentException(type.getSimpleName() + " tensors cannot be initialized with zeros"); + } + Constant zero = Constant.create(childScope.withName("Zero"), type, new long[]{}, ByteBuffer.allocate(zeroSize)); + return new Zeros(Fill.create(childScope, dims, zero)); + } + + @Override + public Output asOutput() { + return fill.asOutput(); + } + + private final Fill fill; + + private Zeros(Fill fill) { + this.fill = fill; + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/package-info.java b/tensorflow/java/src/main/java/org/tensorflow/package-info.java index 521c5c610c1f775cf9174664f5b786786ce1181d..f353ee31459806eb2db98d23ac030c15258a77fb 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/package-info.java +++ b/tensorflow/java/src/main/java/org/tensorflow/package-info.java @@ -17,7 +17,7 @@ limitations under the License. * Defines classes to build, save, load and execute TensorFlow models. * *

WARNING: The API is currently experimental and is not covered by TensorFlow API stability + * href="https://www.tensorflow.org/guide/version_semantics">API stability * guarantees. See README.md for installation * instructions. diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc index 0fef15527586555e7d3fc2c76403c6e5888fb236..f1744d87693ae8f43c032b24622aaecb41a30cb2 100644 --- a/tensorflow/java/src/main/native/graph_jni.cc +++ b/tensorflow/java/src/main/native/graph_jni.cc @@ -16,7 +16,9 @@ limitations under the License. #include "tensorflow/java/src/main/native/graph_jni.h" #include +#include #include "tensorflow/c/c_api.h" +#include "tensorflow/java/src/main/native/utils_jni.h" #include "tensorflow/java/src/main/native/exception_jni.h" namespace { @@ -130,3 +132,60 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) { TF_DeleteBuffer(buf); return ret; } + +JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients( + JNIEnv* env, jclass clazz, jlong handle, jstring prefix, + jlongArray y_handles, jintArray y_indices, jlongArray x_handles, + jintArray x_indices, jlongArray dx_handles, jintArray dx_indices) { + TF_Graph* g = requireHandle(env, handle); + if (g == nullptr) return nullptr; + + const jint ny = env->GetArrayLength(y_handles); + const jint nx = env->GetArrayLength(x_handles); + + std::unique_ptr y(new TF_Output[ny]); + std::unique_ptr x(new TF_Output[nx]); + std::unique_ptr dx(nullptr); + std::unique_ptr dy(new TF_Output[nx]); + + resolveOutputs(env, "y", y_handles, y_indices, y.get(), ny); + resolveOutputs(env, "x", x_handles, x_indices, x.get(), nx); + if (dx_handles != nullptr) { + if (env->GetArrayLength(dx_handles) != ny) { + throwException(env, kIllegalArgumentException, + "expected %d, got %d dx handles", ny, + env->GetArrayLength(dx_handles)); + } + dx.reset(new TF_Output[ny]); + resolveOutputs(env, "dx", dx_handles, dx_indices, dx.get(), ny); + } + if (env->ExceptionCheck()) return nullptr; + + const char* cprefix = nullptr; + if (prefix != nullptr) { + cprefix = env->GetStringUTFChars(prefix, nullptr); + } + TF_Status* status = TF_NewStatus(); + TF_AddGradientsWithPrefix(g, cprefix, y.get(), ny, x.get(), nx, dx.get(), + status, dy.get()); + if (prefix != nullptr) { + env->ReleaseStringUTFChars(prefix, cprefix); + } + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return nullptr; + } + TF_DeleteStatus(status); + + // returned array contains both op handles and output indices, in pair + jlongArray dy_handles_and_indices = env->NewLongArray(nx << 1); + jlong* dy_elems = env->GetLongArrayElements(dy_handles_and_indices, nullptr); + for (int i = 0, j = nx; i < nx; ++i, ++j) { + TF_Output dy_output = dy.get()[i]; + dy_elems[i] = reinterpret_cast(dy_output.oper); + dy_elems[j] = static_cast(dy_output.index); + } + env->ReleaseLongArrayElements(dy_handles_and_indices, dy_elems, 0); + + return dy_handles_and_indices; +} diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h index dd2e038332f7d39e6460d6cfef40a9df7e348758..215695cdfde4a5c0c7bef0e329b48d632f6de87a 100644 --- a/tensorflow/java/src/main/native/graph_jni.h +++ b/tensorflow/java/src/main/native/graph_jni.h @@ -73,6 +73,15 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *, jclass, jlong); +/* + * Class: org_tensorflow_Graph + * Method: name + * Signature: (JLjava/lang/String;[J[I[J[I[J[I)[J + */ +JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients( + JNIEnv *, jclass, jlong, jstring, jlongArray, jintArray, jlongArray, + jintArray, jlongArray, jintArray); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/java/src/main/native/saved_model_bundle_jni.cc b/tensorflow/java/src/main/native/saved_model_bundle_jni.cc index de6382a79c484bac1c8c6746562199c4abdc52de..68999fb2da8b9bd6e2df1f76abfa4f0d86952a0c 100644 --- a/tensorflow/java/src/main/native/saved_model_bundle_jni.cc +++ b/tensorflow/java/src/main/native/saved_model_bundle_jni.cc @@ -22,12 +22,25 @@ limitations under the License. JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load( JNIEnv* env, jclass clazz, jstring export_dir, jobjectArray tags, - jbyteArray run_options) { + jbyteArray config, jbyteArray run_options) { TF_Status* status = TF_NewStatus(); jobject bundle = nullptr; // allocate parameters for TF_LoadSessionFromSavedModel TF_SessionOptions* opts = TF_NewSessionOptions(); + if (config != nullptr) { + size_t sz = env->GetArrayLength(config); + if (sz > 0) { + jbyte* config_data = env->GetByteArrayElements(config, nullptr); + TF_SetConfig(opts, static_cast(config_data), sz, status); + env->ReleaseByteArrayElements(config, config_data, JNI_ABORT); + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteSessionOptions(opts); + TF_DeleteStatus(status); + return nullptr; + } + } + } TF_Buffer* crun_options = nullptr; if (run_options != nullptr) { size_t sz = env->GetArrayLength(run_options); diff --git a/tensorflow/java/src/main/native/saved_model_bundle_jni.h b/tensorflow/java/src/main/native/saved_model_bundle_jni.h index 6cce6a81bd195842d4c2bb86fddbfbb21e0c8f5b..a4b05d0409797e8aa712d22f247fedc2ffbefdf1 100644 --- a/tensorflow/java/src/main/native/saved_model_bundle_jni.h +++ b/tensorflow/java/src/main/native/saved_model_bundle_jni.h @@ -26,10 +26,10 @@ extern "C" { * Class: org_tensorflow_SavedModelBundle * Method: load * Signature: - * (Ljava/lang/String;[Ljava/lang/String;[B)Lorg/tensorflow/SavedModelBundle; + * (Ljava/lang/String;[Ljava/lang/String;[B;[B)Lorg/tensorflow/SavedModelBundle; */ JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load( - JNIEnv *, jclass, jstring, jobjectArray, jbyteArray); + JNIEnv *, jclass, jstring, jobjectArray, jbyteArray, jbyteArray); #ifdef __cplusplus } // extern "C" diff --git a/tensorflow/java/src/main/native/session_jni.cc b/tensorflow/java/src/main/native/session_jni.cc index 2cd542d3c9be536a42037e9ef533ed629dd3ac9f..8b1152578555c0d9b5b4b383460116050c89c3d5 100644 --- a/tensorflow/java/src/main/native/session_jni.cc +++ b/tensorflow/java/src/main/native/session_jni.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "tensorflow/c/c_api.h" +#include "tensorflow/java/src/main/native/utils_jni.h" #include "tensorflow/java/src/main/native/exception_jni.h" #include "tensorflow/java/src/main/native/session_jni.h" @@ -55,37 +56,6 @@ void resolveHandles(JNIEnv* env, const char* type, jlongArray src_array, env->ReleaseLongArrayElements(src_array, src_start, JNI_ABORT); } -void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op, - jintArray src_index, TF_Output* dst, jint n) { - if (env->ExceptionCheck()) return; - jint len = env->GetArrayLength(src_op); - if (len != n) { - throwException(env, kIllegalArgumentException, - "expected %d, got %d %s Operations", n, len, type); - return; - } - len = env->GetArrayLength(src_index); - if (len != n) { - throwException(env, kIllegalArgumentException, - "expected %d, got %d %s Operation output indices", n, len, - type); - return; - } - jlong* op_handles = env->GetLongArrayElements(src_op, nullptr); - jint* indices = env->GetIntArrayElements(src_index, nullptr); - for (int i = 0; i < n; ++i) { - if (op_handles[i] == 0) { - throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type, - i, n); - break; - } - dst[i] = TF_Output{reinterpret_cast(op_handles[i]), - static_cast(indices[i])}; - } - env->ReleaseIntArrayElements(src_index, indices, JNI_ABORT); - env->ReleaseLongArrayElements(src_op, op_handles, JNI_ABORT); -} - void TF_MaybeDeleteBuffer(TF_Buffer* buf) { if (buf == nullptr) return; TF_DeleteBuffer(buf); @@ -116,20 +86,22 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Session_allocate2( TF_Graph* graph = reinterpret_cast(graph_handle); TF_Status* status = TF_NewStatus(); TF_SessionOptions* opts = TF_NewSessionOptions(); - const char* ctarget = nullptr; jbyte* cconfig = nullptr; - if (target != nullptr) { - ctarget = env->GetStringUTFChars(target, nullptr); - } if (config != nullptr) { cconfig = env->GetByteArrayElements(config, nullptr); TF_SetConfig(opts, cconfig, static_cast(env->GetArrayLength(config)), status); if (!throwExceptionIfNotOK(env, status)) { env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT); + TF_DeleteSessionOptions(opts); + TF_DeleteStatus(status); return 0; } } + const char* ctarget = nullptr; + if (target != nullptr) { + ctarget = env->GetStringUTFChars(target, nullptr); + } TF_Session* session = TF_NewSession(graph, opts, status); if (config != nullptr) { env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT); diff --git a/tensorflow/java/src/main/native/utils_jni.cc b/tensorflow/java/src/main/native/utils_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..069ac05a1c39408dc02f5bbf9a7fc50fd095cc96 --- /dev/null +++ b/tensorflow/java/src/main/native/utils_jni.cc @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/java/src/main/native/utils_jni.h" + +#include "tensorflow/java/src/main/native/exception_jni.h" + +void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op, + jintArray src_index, TF_Output* dst, jint n) { + if (env->ExceptionCheck()) return; + jint len = env->GetArrayLength(src_op); + if (len != n) { + throwException(env, kIllegalArgumentException, + "expected %d, got %d %s Operations", n, len, type); + return; + } + len = env->GetArrayLength(src_index); + if (len != n) { + throwException(env, kIllegalArgumentException, + "expected %d, got %d %s Operation output indices", n, len, + type); + return; + } + jlong* op_handles = env->GetLongArrayElements(src_op, nullptr); + jint* indices = env->GetIntArrayElements(src_index, nullptr); + for (int i = 0; i < n; ++i) { + if (op_handles[i] == 0) { + throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type, + i, n); + break; + } + dst[i] = TF_Output{reinterpret_cast(op_handles[i]), + static_cast(indices[i])}; + } + env->ReleaseIntArrayElements(src_index, indices, JNI_ABORT); + env->ReleaseLongArrayElements(src_op, op_handles, JNI_ABORT); +} + + + + diff --git a/tensorflow/java/src/main/native/utils_jni.h b/tensorflow/java/src/main/native/utils_jni.h new file mode 100644 index 0000000000000000000000000000000000000000..352298e7de1d07cebc1a287774c9bef85c9a6ae4 --- /dev/null +++ b/tensorflow/java/src/main/native/utils_jni.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_UTILS_JNI_H_ +#define TENSORFLOW_JAVA_UTILS_JNI_H_ + +#include + +#include "tensorflow/c/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op, + jintArray src_index, TF_Output* dst, jint n); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif /* TENSORFLOW_JAVA_UTILS_JNI_H_ */ diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java index c540299bdcfcd7bc5969caf82b29144bad24201f..7c05c1deafeea5d0b482a70f528d997a3394b365 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java @@ -129,4 +129,133 @@ public class GraphTest { // expected exception. } } + + @Test + public void addGradientsToGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output x1 = TestUtil.placeholder(g, "x1", Float.class); + Output x2 = TestUtil.placeholder(g, "x2", Float.class); + Output y0 = TestUtil.square(g, "y0", x1); + Output y1 = TestUtil.square(g, "y1", y0); + Output y2 = TestUtil.addN(g, y0, x2); + + Output[] grads0 = g.addGradients(y1, toArray(x1)); + assertNotNull(grads0); + assertEquals(1, grads0.length); + assertEquals(DataType.FLOAT, grads0[0].dataType()); + + Output[] grads1 = g.addGradients(y2, toArray(x1, x2)); + assertNotNull(grads1); + assertEquals(2, grads1.length); + assertEquals(DataType.FLOAT, grads1[0].dataType()); + assertEquals(DataType.FLOAT, grads1[1].dataType()); + + try (Tensor c1 = Tensors.create(3.0f); + Tensor c2 = Tensors.create(2.0f); + TestUtil.AutoCloseableList> outputs = new TestUtil.AutoCloseableList<>( + s.runner() + .feed(x1, c1) + .feed(x2, c2) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) + .run())) { + + assertEquals(3, outputs.size()); + assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f); + assertEquals(6.0f, outputs.get(1).floatValue(), 0.0f); + assertEquals(1.0f, outputs.get(2).floatValue(), 0.0f); + } + } + } + + @Test + public void addGradientSumsToGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output x = TestUtil.placeholder(g, "x", Float.class); + Output y0 = TestUtil.square(g, "y0", x); + Output y1 = TestUtil.square(g, "y1", y0); + + Output[] grad = g.addGradients(null, toArray(y0, y1), toArray(x), null); + assertNotNull(grad); + assertEquals(1, grad.length); + assertEquals(DataType.FLOAT, grad[0].dataType()); + + try (Tensor c = Tensors.create(3.0f); + Tensor output = s.runner() + .feed(x, c) + .fetch(grad[0]) + .run() + .get(0)) { + + assertEquals(114.0f, output.floatValue(), 0.0f); + } + } + } + + @Test + public void addGradientsWithInitialValuesToGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output x = TestUtil.placeholder(g, "x", Float.class); + Output y0 = TestUtil.square(g, "y0", x); + Output y1 = TestUtil.square(g, "y1", y0); + + Output[] grad0 = g.addGradients(y1, toArray(y0)); + assertNotNull(grad0); + assertEquals(1, grad0.length); + assertEquals(DataType.FLOAT, grad0[0].dataType()); + + Output[] grad1 = g.addGradients(null, toArray(y0), toArray(x), toArray(grad0[0])); + assertNotNull(grad1); + assertEquals(1, grad1.length); + assertEquals(DataType.FLOAT, grad1[0].dataType()); + + try (Tensor c = Tensors.create(3.0f); + Tensor output = s.runner() + .feed(x, c) + .fetch(grad1[0]) + .run() + .get(0)) { + + assertEquals(108.0f, output.floatValue(), 0.0f); + } + } + } + + @Test + public void validateGradientsNames() { + try (Graph g = new Graph()) { + + Output x = TestUtil.placeholder(g, "x", Float.class); + Output y0 = TestUtil.square(g, "y0", x); + + Output[] grad0 = g.addGradients(null, toArray(y0), toArray(x), null); + assertTrue(grad0[0].op().name().startsWith("gradients/")); + + Output[] grad1 = g.addGradients(null, toArray(y0), toArray(x), null); + assertTrue(grad1[0].op().name().startsWith("gradients_1/")); + + Output[] grad2 = g.addGradients("more_gradients", toArray(y0), toArray(x), null); + assertTrue(grad2[0].op().name().startsWith("more_gradients/")); + + Output[] grad3 = g.addGradients("even_more_gradients", toArray(y0), toArray(x), null); + assertTrue(grad3[0].op().name().startsWith("even_more_gradients/")); + + try { + g.addGradients("even_more_gradients", toArray(y0), toArray(x), null); + } catch (IllegalArgumentException e) { + // expected exception + } + } + } + + private static Output[] toArray(Output... outputs) { + return outputs; + } } diff --git a/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java index 7922f3329c7d7276edd139d6e3cc741c9c01cf2a..7d936867a785483442203098166664daf7a77b49 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -47,7 +47,61 @@ public class SavedModelBundleTest { fail("not expected"); } catch (org.tensorflow.TensorFlowException e) { // expected exception - assertTrue(e.getMessage().contains("SavedModel not found")); + assertTrue(e.getMessage().contains("Could not find SavedModel")); } } + + @Test + public void loader() { + try (SavedModelBundle bundle = SavedModelBundle.loader(SAVED_MODEL_PATH) + .withTags("serve") + .withConfigProto(sillyConfigProto()) + .withRunOptions(sillyRunOptions()) + .load()) { + assertNotNull(bundle.session()); + assertNotNull(bundle.graph()); + assertNotNull(bundle.metaGraphDef()); + } + } + + private static byte[] sillyRunOptions() { + // Ideally this would use the generated Java sources for protocol buffers + // and end up with something like the snippet below. However, generating + // the Java files for the .proto files in tensorflow/core:protos_all is + // a bit cumbersome in bazel until the proto_library rule is setup. + // + // See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866 + // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362 + // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558 + // + // For this test, for now, the use of specific bytes suffices. + return new byte[] {0x08, 0x03}; + /* + return org.tensorflow.framework.RunOptions.newBuilder() + .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE) + .build() + .toByteArray(); + */ + } + + public static byte[] sillyConfigProto() { + // Ideally this would use the generated Java sources for protocol buffers + // and end up with something like the snippet below. However, generating + // the Java files for the .proto files in tensorflow/core:protos_all is + // a bit cumbersome in bazel until the proto_library rule is setup. + // + // See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866 + // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362 + // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558 + // + // For this test, for now, the use of specific bytes suffices. + return new byte[] {0x10, 0x01, 0x28, 0x01}; + /* + return org.tensorflow.framework.ConfigProto.newBuilder() + .setInterOpParallelismThreads(1) + .setIntraOpParallelismThreads(1) + .build() + .toByteArray(); + */ + } } diff --git a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java index e8cc76c2a6458193161a98e17483fe73de107b77..7d5980bcdedebedcd2fa4722e85abc1d598fb4fd 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java @@ -20,8 +20,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import java.util.ArrayList; -import java.util.Collection; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -36,8 +34,8 @@ public class SessionTest { Session s = new Session(g)) { TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}}); try (Tensor x = Tensors.create(new int[][] {{5}, {7}}); - AutoCloseableList> outputs = - new AutoCloseableList>(s.runner().feed("X", x).fetch("Y").run())) { + TestUtil.AutoCloseableList> outputs = + new TestUtil.AutoCloseableList>(s.runner().feed("X", x).fetch("Y").run())) { assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -53,8 +51,8 @@ public class SessionTest { Output feed = g.operation("X").output(0); Output fetch = g.operation("Y").output(0); try (Tensor x = Tensors.create(new int[][] {{5}, {7}}); - AutoCloseableList> outputs = - new AutoCloseableList>(s.runner().feed(feed, x).fetch(fetch).run())) { + TestUtil.AutoCloseableList> outputs = + new TestUtil.AutoCloseableList>(s.runner().feed(feed, x).fetch(fetch).run())) { assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -112,7 +110,7 @@ public class SessionTest { .setOptions(fullTraceRunOptions()) .runAndFetchMetadata(); // Sanity check on outputs. - AutoCloseableList> outputs = new AutoCloseableList>(result.outputs); + TestUtil.AutoCloseableList> outputs = new TestUtil.AutoCloseableList>(result.outputs); assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -135,8 +133,8 @@ public class SessionTest { Session s = new Session(g)) { TestUtil.constant(g, "c1", 2718); TestUtil.constant(g, "c2", 31415); - AutoCloseableList> outputs = - new AutoCloseableList>(s.runner().fetch("c2").fetch("c1").run()); + TestUtil.AutoCloseableList> outputs = + new TestUtil.AutoCloseableList>(s.runner().fetch("c2").fetch("c1").run()); assertEquals(2, outputs.size()); assertEquals(31415, outputs.get(0).intValue()); assertEquals(2718, outputs.get(1).intValue()); @@ -164,28 +162,6 @@ public class SessionTest { Session s = new Session(g, singleThreadConfigProto())) {} } - private static final class AutoCloseableList extends ArrayList - implements AutoCloseable { - AutoCloseableList(Collection c) { - super(c); - } - - @Override - public void close() { - Exception toThrow = null; - for (AutoCloseable c : this) { - try { - c.close(); - } catch (Exception e) { - toThrow = e; - } - } - if (toThrow != null) { - throw new RuntimeException(toThrow); - } - } - } - private static byte[] fullTraceRunOptions() { // Ideally this would use the generated Java sources for protocol buffers // and end up with something like the snippet below. However, generating diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java index c973b5a3d8b2be8ee21710d65732bc1e5c3b520a..f984c508ee9e64e32796f410ac65305015ec956b 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java @@ -16,9 +16,34 @@ limitations under the License. package org.tensorflow; import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.Collection; /** Static utility functions. */ public class TestUtil { + + public static final class AutoCloseableList extends ArrayList + implements AutoCloseable { + public AutoCloseableList(Collection c) { + super(c); + } + + @Override + public void close() { + Exception toThrow = null; + for (AutoCloseable c : this) { + try { + c.close(); + } catch (Exception e) { + toThrow = e; + } + } + if (toThrow != null) { + throw new RuntimeException(toThrow); + } + } + } + public static Output constant(Graph g, String name, Object value) { try (Tensor t = Tensor.create(value)) { return g.opBuilder("Const", name) @@ -36,7 +61,7 @@ public class TestUtil { .output(0); } - public static Output addN(Graph g, Output... inputs) { + public static Output addN(Graph g, Output... inputs) { return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0); } @@ -58,6 +83,13 @@ public class TestUtil { .setAttr("num_split", numSplit) .build(); } + + public static Output square(Graph g, String name, Output value) { + return g.opBuilder("Square", name) + .addInput(value) + .build() + .output(0); + } public static void transpose_A_times_X(Graph g, int[][] a) { Output aa = constant(g, "A", a); diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java index ca54214e0673fbb0308a9eed1e321b7f33bd7fc0..7d3b26de8dcbc099d28f2533bb1283ef7dd579e9 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -16,6 +16,7 @@ limitations under the License. package org.tensorflow.op.core; import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import java.io.ByteArrayOutputStream; @@ -26,6 +27,7 @@ import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.nio.LongBuffer; + import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -37,6 +39,20 @@ import org.tensorflow.op.Scope; @RunWith(JUnit4.class) public class ConstantTest { private static final float EPSILON = 1e-7f; + + @Test + public void createInt() { + int value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, value); + try (Tensor result = sess.runner().fetch(op).run().get(0).expect(Integer.class)) { + assertEquals(value, result.intValue()); + } + } + } @Test public void createIntBuffer() { @@ -47,10 +63,24 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant op = Constant.create(scope, shape, IntBuffer.wrap(ints)); - Tensor result = sess.runner().fetch(op.asOutput()) - .run().get(0).expect(Integer.class); - int[] actual = new int[ints.length]; - assertArrayEquals(ints, result.copyTo(actual)); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + int[] actual = new int[ints.length]; + assertArrayEquals(ints, result.expect(Integer.class).copyTo(actual)); + } + } + } + + @Test + public void createFloat() { + float value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, value); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + assertEquals(value, result.expect(Float.class).floatValue(), 0.0f); + } } } @@ -63,9 +93,24 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant op = Constant.create(scope, shape, FloatBuffer.wrap(floats)); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class); - float[] actual = new float[floats.length]; - assertArrayEquals(floats, result.copyTo(actual), EPSILON); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + float[] actual = new float[floats.length]; + assertArrayEquals(floats, result.expect(Float.class).copyTo(actual), EPSILON); + } + } + } + + @Test + public void createDouble() { + double value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, value); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + assertEquals(value, result.expect(Double.class).doubleValue(), 0.0); + } } } @@ -78,9 +123,24 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles)); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class); - double[] actual = new double[doubles.length]; - assertArrayEquals(doubles, result.copyTo(actual), EPSILON); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + double[] actual = new double[doubles.length]; + assertArrayEquals(doubles, result.expect(Double.class).copyTo(actual), EPSILON); + } + } + } + + @Test + public void createLong() { + long value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, value); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + assertEquals(value, result.expect(Long.class).longValue()); + } } } @@ -93,15 +153,29 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant op = Constant.create(scope, shape, LongBuffer.wrap(longs)); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class); - long[] actual = new long[longs.length]; - assertArrayEquals(longs, result.copyTo(actual)); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + long[] actual = new long[longs.length]; + assertArrayEquals(longs, result.expect(Long.class).copyTo(actual)); + } } } @Test - public void createStringBuffer() throws IOException { + public void createBoolean() { + boolean value = true; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, value); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + assertEquals(value, result.expect(Boolean.class).booleanValue()); + } + } + } + @Test + public void createStringBuffer() throws IOException { byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4}; long[] shape = {}; @@ -124,8 +198,9 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant op = Constant.create(scope, String.class, shape, ByteBuffer.wrap(content)); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(String.class); - assertArrayEquals(data, result.bytesValue()); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + assertArrayEquals(data, result.expect(String.class).bytesValue()); + } } } } diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java new file mode 100644 index 0000000000000000000000000000000000000000..3f49790b291a4bf3872678eb9464d160ce0470bd --- /dev/null +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java @@ -0,0 +1,131 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.core; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.tensorflow.Graph; +import org.tensorflow.Output; +import org.tensorflow.Session; +import org.tensorflow.Tensor; +import org.tensorflow.Tensors; +import org.tensorflow.TestUtil; +import org.tensorflow.op.Scope; + +@RunWith(JUnit4.class) +public class GradientsTest { + + @Test + public void createGradients() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Output x = TestUtil.placeholder(g, "x1", Float.class); + Output y0 = TestUtil.square(g, "y0", x); + Output y1 = TestUtil.square(g, "y1", y0); + + Gradients grads = Gradients.create(scope, y1, Arrays.asList(x, y0)); + + assertNotNull(grads); + assertNotNull(grads.dy()); + assertEquals(2, grads.dy().size()); + + try (Tensor c = Tensors.create(3.0f); + TestUtil.AutoCloseableList> outputs = + new TestUtil.AutoCloseableList<>( + sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) { + + assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f); + assertEquals(18.0f, outputs.get(1).floatValue(), 0.0f); + } + } + } + + @Test + public void createGradientsWithSum() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Output x = TestUtil.placeholder(g, "x1", Float.class); + Output y0 = TestUtil.square(g, "y0", x); + Output y1 = TestUtil.square(g, "y1", y0); + + Gradients grads = Gradients.create(scope, Arrays.asList(y0, y1), Arrays.asList(x)); + + assertNotNull(grads); + assertNotNull(grads.dy()); + assertEquals(1, grads.dy().size()); + + try (Tensor c = Tensors.create(3.0f); + TestUtil.AutoCloseableList> outputs = + new TestUtil.AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) { + + assertEquals(114.0f, outputs.get(0).floatValue(), 0.0f); + } + } + } + + @Test + public void createGradientsWithInitialValues() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Output x = TestUtil.placeholder(g, "x1", Float.class); + Output y0 = TestUtil.square(g, "y0", x); + Output y1 = TestUtil.square(g, "y1", y0); + + Gradients grads0 = Gradients.create(scope, y1, Arrays.asList(y0)); + Gradients grads1 = Gradients.create(scope, y0, Arrays.asList(x), Gradients.dx(grads0.dy())); + + assertNotNull(grads1); + assertNotNull(grads1.dy()); + assertEquals(1, grads1.dy().size()); + + try (Tensor c = Tensors.create(3.0f); + TestUtil.AutoCloseableList> outputs = + new TestUtil.AutoCloseableList<>( + sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) { + + assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f); + } + } + } + + @Test + public void validateGradientsNames() { + try (Graph g = new Graph()) { + Scope scope = new Scope(g).withSubScope("sub"); + + Output x = TestUtil.placeholder(g, "x1", Float.class); + Output y = TestUtil.square(g, "y", x); + + Gradients grad0 = Gradients.create(scope, y, Arrays.asList(x)); + assertTrue(grad0.dy(0).op().name().startsWith("sub/Gradients/")); + + Gradients grad1 = Gradients.create(scope.withName("MyGradients"), y, Arrays.asList(x)); + assertTrue(grad1.dy(0).op().name().startsWith("sub/MyGradients/")); + } + } +} diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java new file mode 100644 index 0000000000000000000000000000000000000000..cf3910b594fad87c8dbff9df92ba23da22e13c04 --- /dev/null +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -0,0 +1,165 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.core; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +import java.util.List; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.tensorflow.Graph; +import org.tensorflow.Session; +import org.tensorflow.Tensor; +import org.tensorflow.op.Scope; +import org.tensorflow.types.UInt8; + +@RunWith(JUnit4.class) +public class ZerosTest { + private static final float EPSILON = 1e-7f; + + @Test + public void createIntZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros op = Zeros.create(scope, Constant.create(scope, shape), Integer.class); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + int[][] actual = result.expect(Integer.class).copyTo(new int[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0, actual[i][j]); + } + } + } + } + } + + @Test + public void createFloatZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros op = Zeros.create(scope, Constant.create(scope, shape), Float.class); + try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0)) { + float[][] actual = result.expect(Float.class).copyTo(new float[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0.0f, actual[i][j], EPSILON); + } + } + } + } + } + + @Test + public void createDoubleZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros op = Zeros.create(scope, Constant.create(scope, shape), Double.class); + try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0)) { + double[][] actual = result.expect(Double.class).copyTo(new double[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0.0, actual[i][j], EPSILON); + } + } + } + } + } + + @Test + public void createLongZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros op = Zeros.create(scope, Constant.create(scope, shape), Long.class); + try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0)) { + long[][] actual = result.expect(Long.class).copyTo(new long[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0L, actual[i][j]); + } + } + } + } + } + + @Test + public void createBooleanZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros op = Zeros.create(scope, Constant.create(scope, shape), Boolean.class); + try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0)) { + boolean[][] actual = result.expect(Boolean.class).copyTo(new boolean[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertFalse(actual[i][j]); + } + } + } + } + } + + @Test + public void createUInt8Zeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros op = Zeros.create(scope, Constant.create(scope, shape), UInt8.class); + try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0)) { + byte[][] actual = result.expect(UInt8.class).copyTo(new byte[(int)shape[0]][(int)shape[1]]); + result.copyTo(actual); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0, actual[i][j]); + } + } + } + } + } + + @Test(expected = IllegalArgumentException.class) + public void cannotCreateStringZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros.create(scope, Constant.create(scope, shape), String.class); + } + } + + @Test + public void operationsComposingZerosAreCorrectlyNamed() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros zeros = Zeros.create(scope.withSubScope("test"), Constant.create(scope, shape), Float.class); + List> results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); + } + } +} diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index a06b536f5b24334963004ef9c321d59dc1685a44..7c2e8b003b93fe43de8429aacc841893f17b7442 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -4,14 +4,16 @@ # Public targets: # ":platform" - Low-level and platform-specific Python code. -package(default_visibility = [ +visibility = [ "//engedu/ml/tf_from_scratch:__pkg__", "//tensorflow:internal", "//tensorflow/contrib/lite/toco/python:__pkg__", "//tensorflow_models:__subpackages__", # TODO(aselle): to pass open source test. "//bazel_pip/tensorflow/contrib/lite/toco/python:__pkg__", -]) +] + +package(default_visibility = visibility) licenses(["notice"]) # Apache 2.0 @@ -55,12 +57,12 @@ py_library( "//tensorflow/contrib/lite/toco/python:__pkg__", # TODO(b/34059704): remove when fixed "//tensorflow/python/debug:__pkg__", # TODO(b/34059704): remove when fixed "//tensorflow/python/tools:__pkg__", # TODO(b/34059704): remove when fixed - "//tensorflow/tools/api/generator:__pkg__", "//tensorflow/tools/quantization:__pkg__", # TODO(b/34059704): remove when fixed ], deps = [ ":no_contrib", "//tensorflow/contrib:contrib_py", + "//tensorflow/python/estimator:estimator_py", ], ) @@ -71,7 +73,7 @@ py_library( visibility = [ "//tensorflow:__pkg__", "//tensorflow/python/tools:__pkg__", - "//tensorflow/tools/api/generator:__pkg__", + "//tensorflow/python/tools/api/generator:__pkg__", ], deps = [ ":array_ops", @@ -94,6 +96,7 @@ py_library( ":image_ops", ":initializers_ns", ":io_ops", + ":kernels", ":layers", ":lib", ":list_ops", @@ -125,13 +128,14 @@ py_library( ":util", ":weights_broadcast_ops", "//tensorflow/core:protos_all_py", + "//tensorflow/python/compat", "//tensorflow/python/data", - "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/feature_column:feature_column_py", "//tensorflow/python/keras", "//tensorflow/python/ops/distributions", "//tensorflow/python/ops/linalg", "//tensorflow/python/ops/losses", + "//tensorflow/python/ops/parallel_for", "//tensorflow/python/profiler", "//tensorflow/python/saved_model", "//third_party/py/numpy", @@ -278,6 +282,9 @@ cc_library( name = "ndarray_tensor_bridge", srcs = ["lib/core/ndarray_tensor_bridge.cc"], hdrs = ["lib/core/ndarray_tensor_bridge.h"], + visibility = visibility + [ + "//learning/deepmind/courier:__subpackages__", + ], deps = [ ":bfloat16_lib", ":numpy_lib", @@ -358,6 +365,9 @@ cc_library( name = "ndarray_tensor", srcs = ["lib/core/ndarray_tensor.cc"], hdrs = ["lib/core/ndarray_tensor.h"], + visibility = visibility + [ + "//learning/deepmind/courier:__subpackages__", + ], deps = [ ":bfloat16_lib", ":ndarray_tensor_bridge", @@ -690,12 +700,24 @@ py_library( ], ) +py_library( + name = "error_interpolation", + srcs = [ + "framework/error_interpolation.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":util", + ], +) + py_library( name = "function", srcs = ["framework/function.py"], srcs_version = "PY2AND3", deps = [ ":array_ops", + ":cond_v2_impl", ":dtypes", ":framework_ops", ":graph_to_function_def", @@ -712,6 +734,7 @@ py_library( srcs = ["framework/graph_to_function_def.py"], srcs_version = "PY2AND3", deps = [ + ":cond_v2_impl", ":op_def_registry", "//tensorflow/core:protos_all_py", ], @@ -723,8 +746,8 @@ py_library( srcs_version = "PY2AND3", deps = [ ":framework", + ":framework_ops", ":function", - ":op_def_registry", ":tensor_shape", ":versions", "//tensorflow/core:protos_all_py", @@ -740,8 +763,10 @@ py_test( deps = [ ":array_ops", ":client_testlib", + ":constant_op", ":dtypes", ":framework_ops", + ":function", ":function_def_to_graph", ":graph_to_function_def", ":math_ops", @@ -765,6 +790,19 @@ py_library( ], ) +py_library( + name = "kernels", + srcs = [ + "framework/kernels.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":pywrap_tensorflow", + ":util", + "//tensorflow/core:protos_all_py", + ], +) + py_library( name = "op_def_library", srcs = ["framework/op_def_library.py"], @@ -796,12 +834,15 @@ py_library( deps = [ ":c_api_util", ":control_flow_util", + ":cpp_shape_inference_proto_py", ":device", ":dtypes", + ":error_interpolation", ":op_def_registry", ":platform", ":registry", ":tensor_shape", + ":traceable_stack", ":util", ":versions", "//tensorflow/core:protos_all_py", @@ -867,6 +908,17 @@ py_library( ], ) +# This target is maintained separately from :util to provide separate visibility +# for legacy users who were granted visibility when the functions were private +# members of ops.Graph. +py_library( + name = "tf_stack", + srcs = ["util/tf_stack.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [], +) + py_library( name = "tensor_shape", srcs = ["framework/tensor_shape.py"], @@ -901,6 +953,16 @@ py_library( ], ) +py_library( + name = "traceable_stack", + srcs = ["framework/traceable_stack.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":util", + ], +) + py_library( name = "versions", srcs = ["framework/versions.py"], @@ -990,6 +1052,20 @@ py_test( ], ) +py_test( + name = "framework_error_interpolation_test", + size = "small", + srcs = ["framework/error_interpolation_test.py"], + main = "framework/error_interpolation_test.py", + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":constant_op", + ":error_interpolation", + ":traceable_stack", + ], +) + py_test( name = "framework_subscribe_test", size = "small", @@ -1052,7 +1128,6 @@ tf_gen_op_wrapper_private_py( name = "functional_ops_gen", visibility = [ "//learning/brain/python/ops:__pkg__", - "//tensorflow/contrib/control_flow:__pkg__", ], ) @@ -1176,6 +1251,21 @@ py_test( ], ) +py_test( + name = "framework_traceable_stack_test", + size = "small", + srcs = ["framework/traceable_stack_test.py"], + main = "framework/traceable_stack_test.py", + srcs_version = "PY2AND3", + deps = [ + ":framework_test_lib", + ":platform_test", + ":test_ops", + ":traceable_stack", + ":util", + ], +) + tf_gen_op_wrapper_py( name = "test_ops", out = "framework/test_ops.py", @@ -1408,6 +1498,20 @@ py_test( ], ) +py_test( + name = "framework_kernels_test", + size = "small", + srcs = ["framework/kernels_test.py"], + main = "framework/kernels_test.py", + srcs_version = "PY2AND3", + deps = [ + ":framework_test_lib", + ":kernels", + ":platform_test", + ":test_ops", + ], +) + tf_gen_op_wrapper_private_py( name = "array_ops_gen", visibility = [ @@ -1600,6 +1704,9 @@ tf_gen_op_wrapper_private_py( tf_gen_op_wrapper_private_py( name = "resource_variable_ops_gen", + visibility = [ + "//tensorflow/compiler/tf2xla:internal", + ], ) tf_gen_op_wrapper_private_py( @@ -1763,6 +1870,7 @@ py_library( ":framework_for_generated_wrappers", ":math_ops", ":nn_ops_gen", + ":numerics", "@six_archive//:six", ], ) @@ -1776,7 +1884,6 @@ py_test( ":client_testlib", ":clip_ops", ":framework_for_generated_wrappers", - ":numerics", "//third_party/py/numpy", ], ) @@ -1827,6 +1934,7 @@ py_library( "tensor_shape", ":array_ops", ":array_ops_gen", + ":cond_v2_impl", ":constant_op", ":control_flow_ops_gen", ":control_flow_util", @@ -1855,6 +1963,37 @@ py_library( ], ) +py_library( + name = "cond_v2", + srcs = [ + "ops/cond_v2.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":cond_v2_impl", + ":function", + ":function_def_to_graph", + ":gradients", + ], +) + +py_library( + name = "cond_v2_impl", + srcs = [ + "ops/cond_v2_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":array_ops", + ":c_api_util", + ":framework_ops", + ":functional_ops_gen", + ":pywrap_tensorflow", + ":util", + "//tensorflow/core:protos_all_py", + ], +) + py_library( name = "ctc_ops", srcs = ["ops/ctc_ops.py"], @@ -1921,6 +2060,8 @@ py_library( ":math_ops", ":platform", ":resource_variable_ops", + ":sparse_ops", + ":tensor_shape", ":variables", ], ) @@ -1937,6 +2078,7 @@ py_library( ":array_grad", ":array_ops", ":bitwise_ops", + ":cond_v2_impl", ":control_flow_grad", ":control_flow_ops", ":control_flow_util", @@ -1953,6 +2095,7 @@ py_library( ":math_grad", ":math_ops", ":platform", + ":random_grad", ":resource_variable_ops", ":spectral_grad", ":util", @@ -2026,8 +2169,8 @@ py_library( ":linalg_ops_gen", ":linalg_ops_impl", ":math_ops", - ":nn_ops", ":random_ops", + ":util", "//third_party/py/numpy", ], ) @@ -2331,6 +2474,19 @@ py_library( ], ) +py_library( + name = "random_grad", + srcs = ["ops/random_grad.py"], + srcs_version = "PY2AND3", + deps = [ + ":array_ops", + ":dtypes", + ":framework_ops", + ":math_ops", + ":random_ops_gen", + ], +) + py_library( name = "random_ops", srcs = ["ops/random_ops.py"], @@ -2391,6 +2547,7 @@ py_library( srcs = ["ops/script_ops.py"], srcs_version = "PY2AND3", deps = [ + ":array_ops", ":framework_for_generated_wrappers", ":script_ops_gen", "//third_party/py/numpy", @@ -2921,6 +3078,20 @@ cuda_py_test( shard_count = 5, ) +cuda_py_test( + name = "init_ops_test", + size = "small", + srcs = ["ops/init_ops_test.py"], + additional_deps = [ + ":client_testlib", + ":init_ops", + ":framework_ops", + ":resource_variable_ops", + "//third_party/py/numpy", + "//tensorflow/python/eager:context", + ], +) + cuda_py_test( name = "math_grad_test", size = "small", @@ -3002,6 +3173,7 @@ cuda_py_test( ":partitioned_variables", ":variable_scope", ":variables", + "@absl_py//absl/testing:parameterized", "//third_party/py/numpy", ], tags = ["no_windows"], @@ -3046,14 +3218,18 @@ py_library( "training/checkpointable/**/*.py", # The following targets have their own build rules (same name as the # file): + "training/checkpoint_management.py", "training/saveable_object.py", + "training/saver.py", "training/training_util.py", ], ), srcs_version = "PY2AND3", deps = [ + "saver", ":array_ops", ":array_ops_gen", + ":checkpoint_management", ":checkpoint_ops_gen", ":client", ":control_flow_ops", @@ -3065,24 +3241,20 @@ py_library( ":framework_ops", ":gradients", ":init_ops", - ":distribute", ":io_ops", - ":io_ops_gen", ":layers_base", - ":lib", ":lookup_ops", ":math_ops", ":platform", - ":protos_all_py", ":pywrap_tensorflow", ":random_ops", ":resource_variable_ops", ":resources", - ":saveable_object", ":sdca_ops", + ":session", ":sparse_ops", + ":sparse_tensor", ":state_ops", - ":string_ops", ":summary", ":training_ops_gen", ":training_util", @@ -3092,6 +3264,7 @@ py_library( "//third_party/py/numpy", "@six_archive//:six", "//tensorflow/core:protos_all_py", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", # `layers` dependency only exists due to the use of a small utility. @@ -3108,6 +3281,52 @@ py_library( srcs_version = "PY2AND3", ) +py_library( + name = "checkpoint_management", + srcs = ["training/checkpoint_management.py"], + deps = [ + ":errors", + ":lib", + ":platform", + ":protos_all_py", + ":util", + "//tensorflow/core:protos_all_py", + ], +) + +py_library( + name = "saver", + srcs = ["training/saver.py"], + srcs_version = "PY2AND3", + deps = [ + ":array_ops", + ":checkpoint_management", + ":constant_op", + ":control_flow_ops", + ":device", + ":errors", + ":framework", + ":framework_ops", + ":io_ops", + ":io_ops_gen", + ":platform", + ":pywrap_tensorflow", + ":resource_variable_ops", + ":saveable_object", + ":session", + ":state_ops", + ":string_ops", + ":training_util", + ":util", + ":variables", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/eager:context", + "//tensorflow/python/training/checkpointable:base", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_library( name = "device_util", srcs = ["training/device_util.py"], @@ -3121,7 +3340,10 @@ py_library( py_library( name = "distribute", - srcs = ["training/distribute.py"], + srcs = [ + "training/distribute.py", + "training/distribution_strategy_context.py", + ], srcs_version = "PY2AND3", deps = [ ":array_ops", @@ -3213,6 +3435,9 @@ py_library( ], ), srcs_version = "PY2AND3", + visibility = visibility + [ + "//tensorflow:__pkg__", + ], deps = [ "//third_party/py/numpy", "@org_python_pypi_backports_weakref", @@ -3235,6 +3460,7 @@ py_test( ":math_ops", ":util", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -3338,6 +3564,19 @@ py_library( ], ) +py_test( + name = "lock_util_test", + size = "small", + srcs = ["util/lock_util_test.py"], + main = "util/lock_util_test.py", + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":util", + "@absl_py//absl/testing:parameterized", + ], +) + tf_proto_library( name = "protos_all", srcs = glob( @@ -3472,6 +3711,7 @@ tf_cuda_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_ref", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", ], @@ -3656,6 +3896,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":c_api_util", + ":error_interpolation", ":errors", ":framework", ":framework_for_generated_wrappers", @@ -3856,7 +4097,7 @@ tf_cuda_library( tf_py_test( name = "session_test", - size = "small", + size = "medium", srcs = ["client/session_test.py"], additional_deps = [ ":array_ops", @@ -3967,7 +4208,6 @@ cuda_py_test( ":math_ops", "//tensorflow/core:protos_all_py", ], - tags = ["no_windows"], ) cuda_py_test( @@ -3981,6 +4221,7 @@ cuda_py_test( ":math_ops", "//tensorflow/core:protos_all_py", ], + tags = ["no_windows_gpu"], ) py_test( @@ -4038,6 +4279,19 @@ py_test( ], ) +py_test( + name = "tf_record_test", + size = "small", + srcs = ["lib/io/tf_record_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":errors", + ":lib", + ":util", + ], +) + cuda_py_test( name = "adam_test", size = "small", @@ -4184,6 +4438,42 @@ cuda_py_test( tags = ["multi_gpu"], ) +cuda_py_test( + name = "checkpoint_management_test", + size = "small", + srcs = [ + "training/checkpoint_management_test.py", + ], + additional_deps = [ + ":array_ops", + ":client_testlib", + ":control_flow_ops", + ":data_flow_ops", + ":errors", + ":gradients", + ":math_ops", + ":nn_grad", + ":nn_ops", + ":saver_test_utils", + ":partitioned_variables", + ":platform", + ":platform_test", + ":pywrap_tensorflow", + ":random_ops", + ":resource_variable_ops", + ":sparse_ops", + ":summary", + ":training", + ":util", + ":variable_scope", + ":variables", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "saver_large_variable_test", size = "medium", @@ -4211,7 +4501,6 @@ py_test( srcs = ["training/saver_large_partitioned_variable_test.py"], srcs_version = "PY2AND3", tags = [ - "no_windows", "noasan", # http://b/30782289 "notsan", # http://b/30782289 ], @@ -4250,6 +4539,7 @@ tf_py_test( srcs = ["training/supervisor_test.py"], additional_deps = [ ":array_ops", + ":checkpoint_management", ":client_testlib", ":errors", ":framework", @@ -4257,6 +4547,7 @@ tf_py_test( ":io_ops", ":parsing_ops", ":platform", + ":saver", ":summary", ":training", ":variables", @@ -4370,10 +4661,13 @@ py_test( tags = ["notsan"], # b/67945581 deps = [ ":array_ops", + ":checkpoint_management", ":client_testlib", ":control_flow_ops", ":errors", ":framework_for_generated_wrappers", + ":resource_variable_ops", + ":saver", ":session", ":state_ops", ":summary", diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index cf707fb2c731c0db57c2335d3ffd49b292c811cc..a2ab63bb48799d5b93882bb87ab40b02dbb96621 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -79,7 +79,6 @@ from tensorflow.python.ops import initializers_ns as initializers # Bring in subpackages. from tensorflow.python import data from tensorflow.python import keras -from tensorflow.python.estimator import estimator_lib as estimator from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.layers import layers from tensorflow.python.ops import bitwise_ops as bitwise diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 35aa37ac6dd721750cd72b54d1b8ef6a70402038..28f26ad27e151d579a43d9d282bfb51066f79d38 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import functools import re import threading @@ -28,6 +29,7 @@ import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.python import pywrap_tensorflow as tf_session from tensorflow.python.framework import device +from tensorflow.python.framework import error_interpolation from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -243,7 +245,7 @@ class _FetchMapper(object): elif isinstance(fetch, (list, tuple)): # NOTE(touts): This is also the code path for namedtuples. return _ListFetchMapper(fetch) - elif isinstance(fetch, dict): + elif isinstance(fetch, collections.Mapping): return _DictFetchMapper(fetch) else: # Look for a handler in the registered expansions. @@ -361,7 +363,7 @@ class _ListFetchMapper(_FetchMapper): for m, vi in zip(self._mappers, self._value_indices): results.append(m.build_results([values[j] for j in vi])) # Return a value of the original type of the fetches. - if self._fetch_type == list: + if issubclass(self._fetch_type, list): return results elif self._fetch_type == tuple: return tuple(results) @@ -540,10 +542,11 @@ class _DeviceAttributes(object): (in bytes). """ - def __init__(self, name, device_type, memory_limit_bytes): + def __init__(self, name, device_type, memory_limit_bytes, incarnation): self._name = device.canonical_name(name) self._device_type = device_type self._memory_limit_bytes = memory_limit_bytes + self._incarnation = incarnation @property def name(self): @@ -557,11 +560,16 @@ class _DeviceAttributes(object): def memory_limit_bytes(self): return self._memory_limit_bytes + @property + def incarnation(self): + return self._incarnation + def __repr__(self): - return '_DeviceAttributes(%s, %s, %d)' % ( + return '_DeviceAttributes(%s, %s, %d, %d)' % ( self.name, self.device_type, self.memory_limit_bytes, + self.incarnation, ) @@ -623,7 +631,7 @@ class BaseSession(SessionInterface): opts = tf_session.TF_NewSessionOptions(target=self._target, config=config) try: # pylint: disable=protected-access - self._session = tf_session.TF_NewSession(self._graph._c_graph, opts) + self._session = tf_session.TF_NewSessionRef(self._graph._c_graph, opts) # pylint: enable=protected-access finally: tf_session.TF_DeleteSessionOptions(opts) @@ -658,7 +666,9 @@ class BaseSession(SessionInterface): name = tf_session.TF_DeviceListName(raw_device_list, i) device_type = tf_session.TF_DeviceListType(raw_device_list, i) memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i) - device_list.append(_DeviceAttributes(name, device_type, memory)) + incarnation = tf_session.TF_DeviceListIncarnation(raw_device_list, i) + device_list.append( + _DeviceAttributes(name, device_type, memory, incarnation)) tf_session.TF_DeleteDeviceList(raw_device_list) return device_list @@ -714,7 +724,7 @@ class BaseSession(SessionInterface): """Returns a context manager that makes this object the default session. Use with the `with` keyword to specify that calls to - @{tf.Operation.run} or @{tf.Tensor.eval} should be executed in + `tf.Operation.run` or `tf.Tensor.eval` should be executed in this session. ```python @@ -726,7 +736,7 @@ class BaseSession(SessionInterface): print(c.eval()) ``` - To get the current default session, use @{tf.get_default_session}. + To get the current default session, use `tf.get_default_session`. *N.B.* The `as_default` context manager *does not* close the session when you exit the context, and you must close the session @@ -755,7 +765,7 @@ class BaseSession(SessionInterface): *N.B.* Entering a `with sess.as_default():` block does not affect the current default graph. If you are using multiple graphs, and - `sess.graph` is different from the value of @{tf.get_default_graph}, + `sess.graph` is different from the value of `tf.get_default_graph`, you must explicitly enter a `with sess.graph.as_default():` block to make `sess.graph` the default graph. @@ -776,14 +786,14 @@ class BaseSession(SessionInterface): nested list, tuple, namedtuple, dict, or OrderedDict containing graph elements at its leaves. A graph element can be one of the following types: - * An @{tf.Operation}. + * An `tf.Operation`. The corresponding fetched value will be `None`. - * A @{tf.Tensor}. + * A `tf.Tensor`. The corresponding fetched value will be a numpy ndarray containing the value of that tensor. - * A @{tf.SparseTensor}. + * A `tf.SparseTensor`. The corresponding fetched value will be a - @{tf.SparseTensorValue} + `tf.SparseTensorValue` containing the value of that sparse tensor. * A `get_tensor_handle` op. The corresponding fetched value will be a numpy ndarray containing the handle of that tensor. @@ -819,16 +829,16 @@ class BaseSession(SessionInterface): the value of tensors in the graph. Each key in `feed_dict` can be one of the following types: - * If the key is a @{tf.Tensor}, the + * If the key is a `tf.Tensor`, the value may be a Python scalar, string, list, or numpy ndarray that can be converted to the same `dtype` as that tensor. Additionally, if the key is a - @{tf.placeholder}, the shape of + `tf.placeholder`, the shape of the value will be checked for compatibility with the placeholder. * If the key is a - @{tf.SparseTensor}, + `tf.SparseTensor`, the value should be a - @{tf.SparseTensorValue}. + `tf.SparseTensorValue`. * If the key is a nested tuple of `Tensor`s or `SparseTensor`s, the value should be a nested tuple with the same structure that maps to their corresponding values as above. @@ -1110,7 +1120,7 @@ class BaseSession(SessionInterface): For example, if element `i` of `feed_list` is a `tf.Tensor`, the `i`th argument to the returned callable must be a numpy ndarray (or something convertible to an ndarray) with matching element type and shape. See - @{tf.Session.run} for details of the allowable feed key and value types. + `tf.Session.run` for details of the allowable feed key and value types. The returned callable will have the same return type as `tf.Session.run(fetches, ...)`. For example, if `fetches` is a `tf.Tensor`, @@ -1118,14 +1128,14 @@ class BaseSession(SessionInterface): it will return `None`. Args: - fetches: A value or list of values to fetch. See @{tf.Session.run} + fetches: A value or list of values to fetch. See `tf.Session.run` for details of the allowable fetch types. feed_list: (Optional.) A list of `feed_dict` keys. See - @{tf.Session.run} for details of the allowable feed key types. + `tf.Session.run` for details of the allowable feed key types. accept_options: (Optional.) Iff `True`, the returned `Callable` will be - able to accept @{tf.RunOptions} and @{tf.RunMetadata} as optional + able to accept `tf.RunOptions` and `tf.RunMetadata` as optional keyword arguments `options` and `run_metadata`, respectively, with - the same syntax and semantics as @{tf.Session.run}, which is useful + the same syntax and semantics as `tf.Session.run`, which is useful for certain use cases (profiling and debugging) but will result in measurable slowdown of the `Callable`'s performance. Default: `False`. @@ -1135,7 +1145,7 @@ class BaseSession(SessionInterface): Raises: TypeError: If `fetches` or `feed_list` cannot be interpreted - as arguments to @{tf.Session.run}. + as arguments to `tf.Session.run`. """ if feed_list is not None: if not isinstance(feed_list, (list, tuple)): @@ -1226,8 +1236,12 @@ class BaseSession(SessionInterface): return _fetch_handler_run - # Captures the name of a node in an error status. - _NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =') + # Captures the name of a node in an error status. The regex below matches + # both the old and the new formats: + # Old format: [[Node: = ...]] + # New format: [[{{node }} = ...]] + _NODEDEF_NAME_RE = re.compile( + r'\[\[(Node: )?(\{\{node )?([^\} ]*)(\}\})?\s*=') def _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata): @@ -1282,16 +1296,19 @@ class BaseSession(SessionInterface): node_def = None op = None if m is not None: - node_name = m.group(1) + node_name = m.group(3) try: op = self._graph.get_operation_by_name(node_name) node_def = op.node_def except KeyError: pass + if (self._config is not None and + self._config.experimental.client_handles_error_formatting): + message = error_interpolation.interpolate(message, self._graph) raise type(e)(node_def, op, message) def _extend_graph(self): - with self._graph._lock: # pylint: disable=protected-access + with self._graph._session_run_lock(): # pylint: disable=protected-access tf_session.ExtendSession(self._session) # The threshold to run garbage collection to delete dead tensors. @@ -1436,10 +1453,10 @@ class Session(BaseSession): ``` A session may own resources, such as - @{tf.Variable}, @{tf.QueueBase}, - and @{tf.ReaderBase}. It is important to release + `tf.Variable`, `tf.QueueBase`, + and `tf.ReaderBase`. It is important to release these resources when they are no longer required. To do this, either - invoke the @{tf.Session.close} method on the session, or use + invoke the `tf.Session.close` method on the session, or use the session as a context manager. The following two examples are equivalent: @@ -1575,8 +1592,8 @@ class InteractiveSession(BaseSession): The only difference with a regular `Session` is that an `InteractiveSession` installs itself as the default session on construction. - The methods @{tf.Tensor.eval} - and @{tf.Operation.run} + The methods `tf.Tensor.eval` + and `tf.Operation.run` will use that session to run ops. This is convenient in interactive shells and [IPython diff --git a/tensorflow/python/client/session_list_devices_test.py b/tensorflow/python/client/session_list_devices_test.py index c5d82c213ac890ac4c968eba506695c3a2ce93c4..dd381c689fde31531668d83441b5ee92bd1ab9ec 100644 --- a/tensorflow/python/client/session_list_devices_test.py +++ b/tensorflow/python/client/session_list_devices_test.py @@ -37,6 +37,8 @@ class SessionListDevicesTest(test_util.TensorFlowTestCase): devices = sess.list_devices() self.assertTrue('/job:localhost/replica:0/task:0/device:CPU:0' in set( [d.name for d in devices]), devices) + # All valid device incarnations must be non-zero. + self.assertTrue(all(d.incarnation != 0 for d in devices)) def testInvalidDeviceNumber(self): opts = tf_session.TF_NewSessionOptions() @@ -54,6 +56,8 @@ class SessionListDevicesTest(test_util.TensorFlowTestCase): devices = sess.list_devices() self.assertTrue('/job:local/replica:0/task:0/device:CPU:0' in set( [d.name for d in devices]), devices) + # All valid device incarnations must be non-zero. + self.assertTrue(all(d.incarnation != 0 for d in devices)) def testListDevicesClusterSpecPropagation(self): server1 = server_lib.Server.create_local_server() @@ -67,11 +71,13 @@ class SessionListDevicesTest(test_util.TensorFlowTestCase): config = config_pb2.ConfigProto(cluster_def=cluster_def) with session.Session(server1.target, config=config) as sess: devices = sess.list_devices() - device_names = set([d.name for d in devices]) + device_names = set(d.name for d in devices) self.assertTrue( '/job:worker/replica:0/task:0/device:CPU:0' in device_names) self.assertTrue( '/job:worker/replica:0/task:1/device:CPU:0' in device_names) + # All valid device incarnations must be non-zero. + self.assertTrue(all(d.incarnation != 0 for d in devices)) if __name__ == '__main__': diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index e49d0671050f557842ad1d3305331d61cd8c9672..052be683856beb41ab572e808c260817b05ef5ae 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import collections +import random import os import sys import threading @@ -34,6 +35,7 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op +from tensorflow.python.framework import device as framework_device_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import function @@ -103,18 +105,20 @@ class SessionTest(test_util.TensorFlowTestCase): copy_val) def testManyCPUs(self): - # TODO(keveman): Implement ListDevices and test for the number of - # devices returned by ListDevices. with session.Session( config=config_pb2.ConfigProto(device_count={ - 'CPU': 2 - })): + 'CPU': 2, 'GPU': 0 + })) as sess: inp = constant_op.constant(10.0, name='W1') self.assertAllEqual(inp.eval(), 10.0) + devices = sess.list_devices() + self.assertEqual(2, len(devices)) + for device in devices: + self.assertEqual('CPU', framework_device_lib.DeviceSpec.from_string( + device.name).device_type) + def testPerSessionThreads(self): - # TODO(keveman): Implement ListDevices and test for the number of - # devices returned by ListDevices. with session.Session( config=config_pb2.ConfigProto(use_per_session_threads=True)): inp = constant_op.constant(10.0, name='W1') @@ -1040,40 +1044,72 @@ class SessionTest(test_util.TensorFlowTestCase): for t in threads: t.join() - def testParallelRunAndBuild(self): + @staticmethod + def _build_graph(): + time.sleep(random.random() * 0.1) + # Do some graph construction. Try to exercise non-trivial paths. + graph = ops.get_default_graph() + gdef = None + for _ in range(10): + x = array_ops.placeholder(dtype=dtypes.float32) + with ops.colocate_with(x): + y = array_ops.placeholder(dtype=dtypes.float32) + with ops.device('/cpu:0'): + z = control_flow_ops.while_loop( + lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y]) + with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}): + gradients_impl.gradients(z, [x, y]) + if gdef is None: + gdef = graph.as_graph_def() + else: + importer.import_graph_def(gdef, name='import') + + def testParallelRunAndSingleBuild(self): with session.Session() as sess: c = constant_op.constant(5.0) stop = threading.Event() def run_loop(): while not stop.is_set(): + time.sleep(random.random() * 0.1) self.assertEqual(sess.run(c), 5.0) - threads = [self.checkedThread(target=run_loop) for _ in range(100)] + threads = [self.checkedThread(target=run_loop) for _ in range(10)] for t in threads: t.start() - # Do some graph construction. Try to exercise non-trivial paths. - graph = ops.get_default_graph() - gdef = None - for _ in range(10): - x = array_ops.placeholder(dtype=dtypes.float32) - with ops.colocate_with(x): - y = array_ops.placeholder(dtype=dtypes.float32) - with ops.device('/cpu:0'): - z = control_flow_ops.while_loop( - lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y]) - with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}): - gradients_impl.gradients(z, [x, y]) - if gdef is None: - gdef = graph.as_graph_def() - else: - importer.import_graph_def(gdef, name='import') + SessionTest._build_graph() stop.set() for t in threads: t.join() + def testParallelRunAndParallelBuild(self): + with session.Session() as sess: + c = constant_op.constant(5.0) + stop = threading.Event() + + def run_loop(): + while not stop.is_set(): + time.sleep(random.random() * 0.1) + self.assertEqual(sess.run(c), 5.0) + + run_threads = [self.checkedThread(target=run_loop) for _ in range(10)] + for t in run_threads: + t.start() + + build_threads = [self.checkedThread(target=SessionTest._build_graph) + for _ in range(10)] + for t in build_threads: + t.start() + for t in build_threads: + t.join() + + # Let the run_threads run until the build threads are finished. + stop.set() + for t in run_threads: + t.join() + def testRunFeedDict(self): with session.Session() as s: x = array_ops.zeros([2]) @@ -1835,19 +1871,21 @@ class SessionTest(test_util.TensorFlowTestCase): def testDeviceAttributes(self): attrs = session._DeviceAttributes( - '/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337) + '/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337, 1000000) self.assertEqual(1337, attrs.memory_limit_bytes) self.assertEqual('/job:worker/replica:0/task:3/device:CPU:2', attrs.name) self.assertEqual('TYPE', attrs.device_type) + self.assertEqual(1000000, attrs.incarnation) str_repr = '%s' % attrs self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr) def testDeviceAttributesCanonicalization(self): attrs = session._DeviceAttributes('/job:worker/replica:0/task:3/cpu:1', - 'TYPE', 1337) + 'TYPE', 1337, 1000000) self.assertEqual(1337, attrs.memory_limit_bytes) self.assertEqual('/job:worker/replica:0/task:3/device:CPU:1', attrs.name) self.assertEqual('TYPE', attrs.device_type) + self.assertEqual(1000000, attrs.incarnation) str_repr = '%s' % attrs self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr) diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 1db1432d6521bb5f48558081916158792010b1c5..39a2922ac0e54367f454c36921a029a9a7d7e82e 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -135,7 +135,12 @@ tensorflow::ImportNumpy(); // Convert TF_DeviceListMemoryBytes and TF_Dim int64_t output to Python integers %typemap(out) int64_t { - $result = PyInt_FromLong($1); + $result = PyLong_FromLongLong($1); +} + +// Convert TF_DeviceListIncarnation uint64_t output to Python integer +%typemap(out) uint64_t { + $result = PyLong_FromUnsignedLongLong($1); } // We use TF_OperationGetControlInputs_wrapper instead of @@ -610,7 +615,7 @@ def TF_Reset(target, containers=None, config=None): } for (size_t i = 0; i < $1.size(); ++i) { - PyList_SET_ITEM($result, i, PyInt_FromLong($1[i])); + PyList_SET_ITEM($result, i, PyLong_FromLongLong($1[i])); } } @@ -673,7 +678,7 @@ def TF_Reset(target, containers=None, config=None): } for (size_t i = 0; i < $1.size(); ++i) { - PyList_SET_ITEM($result, i, PyInt_FromLong($1[i])); + PyList_SET_ITEM($result, i, PyLong_FromLongLong($1[i])); } } @@ -772,6 +777,7 @@ def TF_Reset(target, containers=None, config=None): $1 = &types_local; } +%unignore TF_NewSessionRef; %unignore SetRequireShapeInferenceFns; %unignore TF_TryEvaluateConstant_wrapper; %noexception TF_TryEvaluateConstant_wrapper; diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index b6481e7e29e4057f08e1c78b310bf5581afc5411..bcd4af291282bbefda3db0309bb9f0a913f186ce 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/common_runtime/session_ref.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" @@ -42,6 +43,19 @@ static const char* kFeedDictErrorMsg = "feed_dict must be a dictionary mapping strings to NumPy arrays."; } // end namespace +TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts, + TF_Status* status) { + TF_Session* tf_session = TF_NewSession(graph, opts, status); + if (tf_session == nullptr) { + return nullptr; + } + + Session* session = reinterpret_cast(tf_session->session); + SessionRef* session_ref = new SessionRef(session); + tf_session->session = session_ref; + return tf_session; +} + void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle, const TF_Buffer* run_options, PyObject* feed_dict, const NameVector& output_names, diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index cfd27c2bee990ab4e2829652a532761e674ed8e0..dab7e71aac5a7f4cbf9f8825ad6dd5d3f556bd43 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -40,6 +40,9 @@ typedef tensorflow::gtl::InlinedVector PyObjectVector; // A TF_TensorVector is a vector of borrowed pointers to TF_Tensors. typedef gtl::InlinedVector TF_TensorVector; +TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts, + TF_Status* status); + // Run the graph associated with the session starting with the // supplied inputs[]. Regardless of success or failure, inputs[] are // stolen by the implementation (i.e. the implementation will diff --git a/tensorflow/python/compat/BUILD b/tensorflow/python/compat/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..e0a1c8e0571879e9661cdb0714cc6a794b7ea455 --- /dev/null +++ b/tensorflow/python/compat/BUILD @@ -0,0 +1,23 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +py_library( + name = "compat", + srcs = ["compat.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = ["//tensorflow/python:util"], +) + +tf_py_test( + name = "compat_test", + size = "small", + srcs = ["compat_test.py"], + additional_deps = [ + ":compat", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py new file mode 100644 index 0000000000000000000000000000000000000000..cf4b78a71a1bf3a809e55e1704841f7a060bc7fc --- /dev/null +++ b/tensorflow/python/compat/compat.py @@ -0,0 +1,132 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for API compatibility between TensorFlow release versions. + +See +@{$guide/version_compat#backward_and_partial_forward_compatibility} +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import datetime +from tensorflow.python.util import tf_contextlib +from tensorflow.python.util.tf_export import tf_export + +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 14) + + +@tf_export("compat.forward_compatible") +def forward_compatible(year, month, day): + """Return true if the forward compatibility window has expired. + + See @{$guide/version_compat#backward_and_partial_forward_compatibility}. + + Forward-compatibility refers to scenarios where the producer of a TensorFlow + model (a GraphDef or SavedModel) is compiled against a version of the + TensorFlow library newer than what the consumer was compiled against. The + "producer" is typically a Python program that constructs and trains a model + while the "consumer" is typically another program that loads and serves the + model. + + TensorFlow has been supporting a 3 week forward-compatibility window for + programs compiled from source at HEAD. + + For example, consider the case where a new operation `MyNewAwesomeAdd` is + created with the intent of replacing the implementation of an existing Python + wrapper - `tf.add`. The Python wrapper implementation should change from + something like: + + ```python + def add(inputs, name=None): + return gen_math_ops.add(inputs, name) + ``` + + to: + + ```python + from tensorflow.python.compat import compat + + def add(inputs, name=None): + if compat.forward_compatible(year, month, day): + # Can use the awesome new implementation. + return gen_math_ops.my_new_awesome_add(inputs, name) + # To maintain forward compatibiltiy, use the old implementation. + return gen_math_ops.add(inputs, name) + ``` + + Where `year`, `month`, and `day` specify the date beyond which binaries + that consume a model are expected to have been updated to include the + new operations. This date is typically at least 3 weeks beyond the date + the code that adds the new operation is committed. + + Args: + year: A year (e.g., 2018). + month: A month (1 <= month <= 12) in year. + day: A day (1 <= day <= 31, or 30, or 29, or 28) in month. + + Returns: + True if the caller can expect that serialized TensorFlow graphs produced + can be consumed by programs that are compiled with the TensorFlow library + source code after (year, month, day). + """ + return _FORWARD_COMPATIBILITY_HORIZON > datetime.date(year, month, day) + + +@tf_export("compat.forward_compatibility_horizon") +@tf_contextlib.contextmanager +def forward_compatibility_horizon(year, month, day): + """Context manager for testing forward compatibility of generated graphs. + + See @{$guide/version_compat#backward_and_partial_forward_compatibility}. + + To ensure forward compatibility of generated graphs (see `forward_compatible`) + with older binaries, new features can be gated with: + + ```python + if compat.forward_compatible(year=2018, month=08, date=01): + generate_graph_with_new_features() + else: + generate_graph_so_older_binaries_can_consume_it() + ``` + + However, when adding new features, one may want to unittest it before + the forward compatibility window expires. This context manager enables + such tests. For example: + + ```python + from tensorflow.python.compat import compat + + def testMyNewFeature(self): + with compat.forward_compatibility_horizon(2018, 08, 02): + # Test that generate_graph_with_new_features() has an effect + ``` + + Args : + year: A year (e.g. 2018). + month: A month (1 <= month <= 12) in year. + day: A day (1 <= day <= 31, or 30, or 29, or 28) in month. + + Yields: + Nothing. + """ + global _FORWARD_COMPATIBILITY_HORIZON + try: + old_compat_date = _FORWARD_COMPATIBILITY_HORIZON + _FORWARD_COMPATIBILITY_HORIZON = datetime.date(year, month, day) + yield + finally: + _FORWARD_COMPATIBILITY_HORIZON = old_compat_date diff --git a/tensorflow/python/compat/compat_test.py b/tensorflow/python/compat/compat_test.py new file mode 100644 index 0000000000000000000000000000000000000000..946abbb300d66e7be5ea317e365bc75cbcf6941c --- /dev/null +++ b/tensorflow/python/compat/compat_test.py @@ -0,0 +1,70 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for forward and backwards compatibility utilties.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import datetime +from tensorflow.python.compat import compat +from tensorflow.python.platform import test + + +class CompatTest(test.TestCase): + + def _compatibility_date(self): + date = compat._FORWARD_COMPATIBILITY_HORIZON # pylint: disable=protected-access + return (date.year, date.month, date.day) + + def _n_days_after(self, n): + date = compat._FORWARD_COMPATIBILITY_HORIZON + datetime.timedelta(days=n) # pylint: disable=protected-access + return (date.year, date.month, date.day) + + def test_basic(self): + compatibility_date = self._compatibility_date() + one_day_before = self._n_days_after(-1) + self.assertTrue(compat.forward_compatible(*one_day_before)) + self.assertFalse(compat.forward_compatible(*compatibility_date)) + + def test_decorator(self): + compatibility_date = self._compatibility_date() + one_day_after = self._n_days_after(1) + with compat.forward_compatibility_horizon(*one_day_after): + self.assertTrue(compat.forward_compatible(*compatibility_date)) + self.assertFalse(compat.forward_compatible(*one_day_after)) + + # After exiting context manager, value should be reset. + self.assertFalse(compat.forward_compatible(*compatibility_date)) + + def test_decorator_with_failure(self): + compatibility_date = self._compatibility_date() + one_day_after = self._n_days_after(1) + + class DummyError(Exception): + pass + + try: + with compat.forward_compatibility_horizon(*one_day_after): + raise DummyError() + except DummyError: + pass # silence DummyError + + # After exiting context manager, value should be reset. + self.assertFalse(compat.forward_compatible(*compatibility_date)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/data/__init__.py b/tensorflow/python/data/__init__.py index 7efe0948e7729c398f972977b51426d80b8cd83e..3b9bf2469e6d41fd0e8c5199af677e60bedf93f9 100644 --- a/tensorflow/python/data/__init__.py +++ b/tensorflow/python/data/__init__.py @@ -14,7 +14,7 @@ # ============================================================================== """`tf.data.Dataset` API for input pipelines. -See the @{$datasets$Importing Data} Programmer's Guide for an overview. +See @{$guide/datasets$Importing Data} for an overview. """ from __future__ import absolute_import diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index e86c2f69935bc0d66b18b7dbc987bb5966c7ed8d..23c98247bf38643c38d137f23e9cd8bc4ecacbba 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -179,6 +179,7 @@ tf_py_test( size = "small", srcs = ["prefetch_dataset_op_test.py"], additional_deps = [ + "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dataset_ops_gen", @@ -317,7 +318,7 @@ tf_py_test( ], ) -tf_py_test( +cuda_py_test( name = "iterator_ops_test", size = "small", srcs = ["iterator_ops_test.py"], @@ -328,6 +329,8 @@ tf_py_test( "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/util:sparse", + "//tensorflow/python/eager:context", + "//tensorflow/python/training/checkpointable:util", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -348,6 +351,9 @@ tf_py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:tensor_shape", "//tensorflow/python:training", + "//tensorflow/python/compat:compat", + "//tensorflow/python:util", + "//tensorflow/python:variables", ], grpc_enabled = True, ) @@ -379,3 +385,22 @@ tf_py_test( "no_windows", ], ) + +cuda_py_test( + name = "optional_ops_test", + size = "small", + srcs = ["optional_ops_test.py"], + additional_deps = [ + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:optional_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:tensor_shape", + ], +) diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py index 50bb0837b7052d67ced4fdf5c9c7e96212bdb415..89de55dd4f9fdc612663c839b926684d27d48c54 100644 --- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py @@ -18,9 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import time + from absl.testing import parameterized import numpy as np +from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -275,7 +278,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase): result = sess.run(get_next) padded_len = padded_shapes[0] if padded_len is None or padded_len == -1: - padded_len = np.max(result) + padded_len = np.max(result) if result.size > 0 else 0 self.assertEqual((batch_size, padded_len), result.shape) for j in range(batch_size): seq_len = seq_lens[(i * batch_size) + j] @@ -285,7 +288,7 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase): if not drop_remainder and len(seq_lens) % batch_size > 0: result = sess.run(get_next) - padded_len = np.max(result) + padded_len = np.max(result) if result.size > 0 else 0 self.assertEqual((len(seq_lens) % batch_size, padded_len), result.shape) for j in range(len(seq_lens) % batch_size): @@ -461,5 +464,55 @@ class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase): 5, padded_shapes=shape_as_tensor) +class BatchDatasetBenchmark(test.Benchmark): + + def benchmarkBatchSparse(self): + non_zeros_per_row_values = [0, 1, 5, 10, 100] + batch_size_values = [1, 32, 64, 128, 1024] + + sparse_placeholder = array_ops.sparse_placeholder(dtype=dtypes.int64) + batch_size_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[]) + + dataset = dataset_ops.Dataset.from_tensors(sparse_placeholder).repeat( + ).batch(batch_size_placeholder) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + for non_zeros_per_row in non_zeros_per_row_values: + + sparse_value = sparse_tensor.SparseTensorValue( + indices=np.arange(non_zeros_per_row, dtype=np.int64)[:, np.newaxis], + values=np.arange(non_zeros_per_row, dtype=np.int64), + dense_shape=[1000]) + + for batch_size in batch_size_values: + + with session.Session() as sess: + sess.run(iterator.initializer, feed_dict={ + sparse_placeholder: sparse_value, + batch_size_placeholder: batch_size}) + # Run five steps to warm up the session caches before taking the + # first measurement. + for _ in range(5): + sess.run(next_element.indices.op) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.indices.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100.0 + + print('Batch sparse dataset non-zeros per row: %d batch_size: %d ' + 'wall time: %f' + % (non_zeros_per_row, batch_size, median_wall_time)) + self.report_benchmark( + iters=10000, wall_time=median_wall_time, + name='benchmark_batch_sparse_dataset_nnz_%d_batch_size_%d' % ( + non_zeros_per_row, batch_size)) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py index 25269dc810ae2e3107f8b5317496a35a8ff59d0c..4f7fd3566ef5ed9389df670a3796e63abf3bfce9 100644 --- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py @@ -34,7 +34,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test -class FilesystemCacheDatasetTest(test.TestCase): +class FileCacheDatasetTest(test.TestCase): def setUp(self): self.tmp_dir = tempfile.mkdtemp() diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_test.py index 820c167b6bb9dc3b1c25d9c6156cef17ad20eb1b..352424514e63220dce4218b625bd8b142a0d4107 100644 --- a/tensorflow/python/data/kernel_tests/iterator_ops_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_ops_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import os import warnings @@ -25,6 +26,7 @@ import numpy as np from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session +from tensorflow.python.compat import compat as forward_compat from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers @@ -45,7 +47,9 @@ from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import script_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import server_lib +from tensorflow.python.training.checkpointable import util as checkpointable_utils from tensorflow.python.util import compat @@ -415,6 +419,69 @@ class IteratorTest(test.TestCase): sess.run( next_element, feed_dict={handle_placeholder: iterator_4_handle}) + def testIteratorStringHandleFuture(self): + with forward_compat.forward_compatibility_horizon(2018, 8, 4): + dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) + dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40]) + + iterator_3 = dataset_3.make_one_shot_iterator() + iterator_4 = dataset_4.make_one_shot_iterator() + + handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) + feedable_iterator = iterator_ops.Iterator.from_string_handle( + handle_placeholder, dataset_3.output_types, dataset_3.output_shapes) + next_element = feedable_iterator.get_next() + + self.assertEqual(dataset_3.output_types, feedable_iterator.output_types) + self.assertEqual(dataset_4.output_types, feedable_iterator.output_types) + self.assertEqual([], feedable_iterator.output_shapes) + + with self.test_session() as sess: + iterator_3_handle = sess.run(iterator_3.string_handle()) + iterator_4_handle = sess.run(iterator_4.string_handle()) + + self.assertEqual( + 10, + sess.run( + next_element, + feed_dict={handle_placeholder: iterator_4_handle})) + self.assertEqual( + 1, + sess.run( + next_element, + feed_dict={handle_placeholder: iterator_3_handle})) + self.assertEqual( + 20, + sess.run( + next_element, + feed_dict={handle_placeholder: iterator_4_handle})) + self.assertEqual( + 2, + sess.run( + next_element, + feed_dict={handle_placeholder: iterator_3_handle})) + self.assertEqual( + 30, + sess.run( + next_element, + feed_dict={handle_placeholder: iterator_4_handle})) + self.assertEqual( + 3, + sess.run( + next_element, + feed_dict={handle_placeholder: iterator_3_handle})) + self.assertEqual( + 40, + sess.run( + next_element, + feed_dict={handle_placeholder: iterator_4_handle})) + with self.assertRaises(errors.OutOfRangeError): + sess.run( + next_element, feed_dict={handle_placeholder: iterator_3_handle}) + with self.assertRaises(errors.OutOfRangeError): + sess.run( + next_element, feed_dict={handle_placeholder: iterator_4_handle}) + def testIteratorStringHandleReuseTensorObject(self): dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) one_shot_iterator = dataset.make_one_shot_iterator() @@ -724,5 +791,98 @@ class IteratorTest(test.TestCase): val += 1 +class IteratorCheckpointingTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes + def testSaveRestoreOneShotIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map( + math_ops.square).batch(2) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next if context.executing_eagerly( + ) else functools.partial(self.evaluate, iterator.get_next()) + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + with self.test_session() as sess: + self.assertAllEqual([1, 4], get_next()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual([9, 16], get_next()) + self.assertAllEqual([25, 36], get_next()) + checkpoint.restore(save_path).run_restore_ops(sess) + self.assertAllEqual([9, 16], get_next()) + self.assertAllEqual([25, 36], get_next()) + with self.assertRaises(errors.OutOfRangeError): + get_next() + + @test_util.run_in_graph_and_eager_modes + def testSaveRestoreMultipleIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + dataset = dataset_ops.Dataset.from_tensor_slices( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + dataset = dataset.map(math_ops.square).batch(2) + iterator_1 = dataset.make_one_shot_iterator() + get_next_1 = iterator_1.get_next if context.executing_eagerly( + ) else functools.partial(self.evaluate, iterator_1.get_next()) + iterator_2 = dataset.make_one_shot_iterator() + get_next_2 = iterator_2.get_next if context.executing_eagerly( + ) else functools.partial(self.evaluate, iterator_2.get_next()) + dataset_2 = dataset_ops.Dataset.range(10) + iterator_3 = dataset_2.make_one_shot_iterator() + get_next_3 = iterator_3.get_next if context.executing_eagerly( + ) else functools.partial(self.evaluate, iterator_3.get_next()) + checkpoint = checkpointable_utils.Checkpoint( + iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3) + with self.test_session() as sess: + self.assertAllEqual([1, 4], get_next_1()) + self.assertAllEqual(0, get_next_3()) + self.assertAllEqual(1, get_next_3()) + self.assertAllEqual(2, get_next_3()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual([1, 4], get_next_2()) + self.assertAllEqual([9, 16], get_next_2()) + self.assertAllEqual(3, get_next_3()) + checkpoint.restore(save_path).run_restore_ops(sess) + self.assertAllEqual([9, 16], get_next_1()) + self.assertAllEqual([1, 4], get_next_2()) + self.assertAllEqual(3, get_next_3()) + + @test_util.run_in_graph_and_eager_modes + def testRestoreExhaustedIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + dataset = dataset_ops.Dataset.range(3) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next if context.executing_eagerly( + ) else functools.partial(self.evaluate, iterator.get_next()) + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + with self.test_session() as sess: + self.assertAllEqual(0, get_next()) + self.assertAllEqual(1, get_next()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual(2, get_next()) + checkpoint.restore(save_path).run_restore_ops(sess) + self.assertAllEqual(2, get_next()) + save_path = checkpoint.save(checkpoint_prefix) + checkpoint.restore(save_path).run_restore_ops(sess) + with self.assertRaises(errors.OutOfRangeError): + get_next() + + def testRestoreInReconstructedIteratorInitializable(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + dataset = dataset_ops.Dataset.range(10) + iterator = dataset.make_initializable_iterator() + get_next = iterator.get_next() + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + for i in range(5): + with self.test_session() as sess: + checkpoint.restore(checkpoint_management.latest_checkpoint( + checkpoint_directory)).initialize_or_restore(sess) + for j in range(2): + self.assertEqual(i * 2 + j, sess.run(get_next)) + checkpoint.save(file_prefix=checkpoint_prefix) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py index f7d7d085c974fa217ed30708723cb1b887034ca0..579096f88097ad9a724b029b7dfd74d04b75f90a 100644 --- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py @@ -123,13 +123,11 @@ class ListFilesDatasetOpTest(test.TestCase): with self.test_session() as sess: itr = dataset.make_initializable_iterator() - next_element = itr.get_next() - sess.run( - itr.initializer, - feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')}) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, 'No files matched pattern: '): + sess.run( + itr.initializer, + feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')}) def testSimpleDirectoryInitializer(self): filenames = ['a', 'b', 'c'] diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py index 768d4ac82cce4e78e9ff493cb5b4401614ecd1c0..637bde9ae4eb839e2b983ceec082f868f3ed2728 100644 --- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py @@ -659,6 +659,20 @@ class MapDatasetTest(test.TestCase): break self.assertTrue(found_warning) + def testNestedDatasetError(self): + dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0]) + with self.assertRaisesRegexp( + NotImplementedError, r"The Dataset.map\(\) transformation does not " + "currently support nested datasets as outputs."): + _ = dataset.map(dataset_ops.Dataset.from_tensor_slices) + + def testReturnValueError(self): + dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0]) + with self.assertRaisesRegexp( + TypeError, r"Unsupported return value from function passed to " + r"Dataset.map\(\): None."): + _ = dataset.map(lambda x: None) + class MapDatasetBenchmark(test.Benchmark): diff --git a/tensorflow/python/data/kernel_tests/optional_ops_test.py b/tensorflow/python/data/kernel_tests/optional_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a32527af8d6becfdfc1bd7283984c486e8602b92 --- /dev/null +++ b/tensorflow/python/data/kernel_tests/optional_ops_test.py @@ -0,0 +1,186 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the Optional data type wrapper.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.ops import optional_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class OptionalTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes + def testFromValue(self): + opt = optional_ops.Optional.from_value(constant_op.constant(37.0)) + self.assertEqual(dtypes.float32, opt.output_types) + self.assertEqual([], opt.output_shapes) + self.assertEqual(ops.Tensor, opt.output_classes) + self.assertTrue(self.evaluate(opt.has_value())) + self.assertEqual(37.0, self.evaluate(opt.get_value())) + + @test_util.run_in_graph_and_eager_modes + def testFromStructuredValue(self): + opt = optional_ops.Optional.from_value({ + "a": constant_op.constant(37.0), + "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar")) + }) + self.assertEqual({ + "a": dtypes.float32, + "b": (dtypes.string, dtypes.string) + }, opt.output_types) + self.assertEqual({"a": [], "b": ([1], [])}, opt.output_shapes) + self.assertEqual({ + "a": ops.Tensor, + "b": (ops.Tensor, ops.Tensor) + }, opt.output_classes) + self.assertTrue(self.evaluate(opt.has_value())) + self.assertEqual({ + "a": 37.0, + "b": ([b"Foo"], b"Bar") + }, self.evaluate(opt.get_value())) + + @test_util.run_in_graph_and_eager_modes + def testFromSparseTensor(self): + st_0 = sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0], dtype=np.int64), + dense_shape=np.array([1])) + st_1 = sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 1]]), + values=np.array([-1., 1.], dtype=np.float32), + dense_shape=np.array([2, 2])) + opt = optional_ops.Optional.from_value((st_0, st_1)) + self.assertEqual((dtypes.int64, dtypes.float32), opt.output_types) + self.assertEqual(([1], [2, 2]), opt.output_shapes) + self.assertEqual((sparse_tensor.SparseTensor, sparse_tensor.SparseTensor), + opt.output_classes) + + @test_util.run_in_graph_and_eager_modes + def testFromNone(self): + opt = optional_ops.Optional.none_from_structure(tensor_shape.scalar(), + dtypes.float32, ops.Tensor) + self.assertEqual(dtypes.float32, opt.output_types) + self.assertEqual([], opt.output_shapes) + self.assertEqual(ops.Tensor, opt.output_classes) + self.assertFalse(self.evaluate(opt.has_value())) + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(opt.get_value()) + + def testStructureMismatchError(self): + tuple_output_shapes = (tensor_shape.scalar(), tensor_shape.scalar()) + tuple_output_types = (dtypes.float32, dtypes.float32) + tuple_output_classes = (ops.Tensor, ops.Tensor) + + dict_output_shapes = { + "a": tensor_shape.scalar(), + "b": tensor_shape.scalar() + } + dict_output_types = {"a": dtypes.float32, "b": dtypes.float32} + dict_output_classes = {"a": ops.Tensor, "b": ops.Tensor} + + with self.assertRaises(TypeError): + optional_ops.Optional.none_from_structure( + tuple_output_shapes, tuple_output_types, dict_output_classes) + + with self.assertRaises(TypeError): + optional_ops.Optional.none_from_structure( + tuple_output_shapes, dict_output_types, tuple_output_classes) + + with self.assertRaises(TypeError): + optional_ops.Optional.none_from_structure( + dict_output_shapes, tuple_output_types, tuple_output_classes) + + @test_util.run_in_graph_and_eager_modes + def testCopyToGPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + with ops.device("/cpu:0"): + optional_with_value = optional_ops.Optional.from_value( + (constant_op.constant(37.0), constant_op.constant("Foo"), + constant_op.constant(42))) + optional_none = optional_ops.Optional.none_from_structure( + tensor_shape.scalar(), dtypes.float32, ops.Tensor) + + with ops.device("/gpu:0"): + gpu_optional_with_value = optional_ops._OptionalImpl( + array_ops.identity(optional_with_value._variant_tensor), + optional_with_value.output_shapes, optional_with_value.output_types, + optional_with_value.output_classes) + gpu_optional_none = optional_ops._OptionalImpl( + array_ops.identity(optional_none._variant_tensor), + optional_none.output_shapes, optional_none.output_types, + optional_none.output_classes) + + gpu_optional_with_value_has_value = gpu_optional_with_value.has_value() + gpu_optional_with_value_values = gpu_optional_with_value.get_value() + + gpu_optional_none_has_value = gpu_optional_none.has_value() + + self.assertTrue(self.evaluate(gpu_optional_with_value_has_value)) + self.assertEqual((37.0, b"Foo", 42), + self.evaluate(gpu_optional_with_value_values)) + self.assertFalse(self.evaluate(gpu_optional_none_has_value)) + + def testIteratorGetNextAsOptional(self): + ds = dataset_ops.Dataset.range(3) + iterator = ds.make_initializable_iterator() + next_elem = iterator_ops.get_next_as_optional(iterator) + self.assertTrue(isinstance(next_elem, optional_ops.Optional)) + self.assertEqual(ds.output_types, next_elem.output_types) + self.assertEqual(ds.output_shapes, next_elem.output_shapes) + self.assertEqual(ds.output_classes, next_elem.output_classes) + elem_has_value_t = next_elem.has_value() + elem_value_t = next_elem.get_value() + with self.test_session() as sess: + # Before initializing the iterator, evaluating the optional fails with + # a FailedPreconditionError. + with self.assertRaises(errors.FailedPreconditionError): + sess.run(elem_has_value_t) + with self.assertRaises(errors.FailedPreconditionError): + sess.run(elem_value_t) + + # For each element of the dataset, assert that the optional evaluates to + # the expected value. + sess.run(iterator.initializer) + for i in range(3): + elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t]) + self.assertTrue(elem_has_value) + self.assertEqual(i, elem_value) + + # After exhausting the iterator, `next_elem.has_value()` will evaluate to + # false, and attempting to get the value will fail. + for _ in range(2): + self.assertFalse(sess.run(elem_has_value_t)) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(elem_value_t) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py index 646324cb95df6fc1fa0a901ebdccc8d4ef74a66c..63a0830272dca254866c1609fec3677ab28749d5 100644 --- a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -24,35 +26,33 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class PrefetchDatasetTest(test.TestCase): +class PrefetchDatasetTest(test.TestCase, parameterized.TestCase): - def testBufferSize(self): - buffer_size = array_ops.placeholder(dtypes.int64, shape=[]) + @parameterized.parameters((-1), (0), (5)) + def testBufferSize(self, buffer_size): + buffer_size_t = array_ops.placeholder(dtypes.int64, shape=[]) iterator = dataset_ops.Dataset.range(10).prefetch( - buffer_size=buffer_size).make_initializable_iterator() + buffer_size=buffer_size_t).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() with self.test_session() as sess: - sess.run(init_op, feed_dict={buffer_size: 5}) + sess.run(init_op, feed_dict={buffer_size_t: buffer_size}) for m in range(10): self.assertEqual(m, sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testInvalidBufferSize(self): - buffer_size = array_ops.placeholder(dtypes.int64, shape=[]) + @parameterized.parameters((-2), (-42)) + def testInvalidBufferSize(self, buffer_size): + buffer_size_t = array_ops.placeholder(dtypes.int64, shape=[]) iterator = dataset_ops.Dataset.range(10).prefetch( - buffer_size=buffer_size).make_initializable_iterator() + buffer_size=buffer_size_t).make_initializable_iterator() init_op = iterator.initializer with self.assertRaisesRegexp(errors.InvalidArgumentError, "buffer_size"): with self.test_session() as sess: - sess.run(init_op, feed_dict={buffer_size: 0}) - - with self.assertRaisesRegexp(errors.InvalidArgumentError, "buffer_size"): - with self.test_session() as sess: - sess.run(init_op, feed_dict={buffer_size: -5}) + sess.run(init_op, feed_dict={buffer_size_t: buffer_size}) if __name__ == "__main__": diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD index fa2e86eab18b0b97ea01a96e309b0ea82d91b267..50ba5f403e4a16361920b27f317fcf31fd1798fc 100644 --- a/tensorflow/python/data/ops/BUILD +++ b/tensorflow/python/data/ops/BUILD @@ -11,6 +11,7 @@ py_library( deps = [ ":iterator_ops", "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -19,6 +20,7 @@ py_library( "//tensorflow/python:random_seed", "//tensorflow/python:script_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", "//tensorflow/python:util", @@ -40,6 +42,7 @@ py_library( "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:tensor_shape", + "//tensorflow/python/compat", "//tensorflow/python/data/util:convert", ], ) @@ -49,13 +52,33 @@ py_library( srcs = ["iterator_ops.py"], srcs_version = "PY2AND3", deps = [ + ":optional_ops", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:saver", "//tensorflow/python:tensor_shape", + "//tensorflow/python/compat", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", "//tensorflow/python/eager:context", + "//tensorflow/python/training/checkpointable:base", + ], +) + +py_library( + name = "optional_ops", + srcs = ["optional_ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", ], ) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index d0deed5edef81a03a5ee69a01c16b98a799189c7..8ba98cb88d8d879762cf00cfd8fa19f00b09f82e 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -24,6 +24,7 @@ import warnings import numpy as np import six +from tensorflow.python.compat import compat from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import random_seed @@ -38,10 +39,12 @@ from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_io_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops +from tensorflow.python.ops import string_ops from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export @@ -107,13 +110,12 @@ class Dataset(object): "execution is enabled.") if shared_name is None: shared_name = "" - iterator_resource = gen_dataset_ops.iterator( - container="", - shared_name=shared_name, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + if compat.forward_compatible(2018, 8, 3): + iterator_resource = gen_dataset_ops.iterator_v2( + container="", shared_name=shared_name, **flat_structure(self)) + else: + iterator_resource = gen_dataset_ops.iterator( + container="", shared_name=shared_name, **flat_structure(self)) with ops.colocate_with(iterator_resource): initializer = gen_dataset_ops.make_iterator(self._as_variant_tensor(), iterator_resource) @@ -171,13 +173,8 @@ class Dataset(object): return iterator_ops.Iterator( gen_dataset_ops.one_shot_iterator( - dataset_factory=_make_dataset, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, - self.output_classes))), None, - self.output_types, self.output_shapes, self.output_classes) + dataset_factory=_make_dataset, **flat_structure(self)), + None, self.output_types, self.output_shapes, self.output_classes) @abc.abstractproperty def output_classes(self): @@ -225,10 +222,10 @@ class Dataset(object): Note that if `tensors` contains a NumPy array, and eager execution is not enabled, the values will be embedded in the graph as one or more - @{tf.constant} operations. For large datasets (> 1 GB), this can waste + `tf.constant` operations. For large datasets (> 1 GB), this can waste memory and run into byte limits of graph serialization. If tensors contains one or more large NumPy arrays, consider the alternative described in - @{$programmers_guide/datasets#consuming_numpy_arrays$this guide}. + @{$guide/datasets#consuming_numpy_arrays$this guide}. Args: tensors: A nested structure of tensors. @@ -244,10 +241,10 @@ class Dataset(object): Note that if `tensors` contains a NumPy array, and eager execution is not enabled, the values will be embedded in the graph as one or more - @{tf.constant} operations. For large datasets (> 1 GB), this can waste + `tf.constant` operations. For large datasets (> 1 GB), this can waste memory and run into byte limits of graph serialization. If tensors contains one or more large NumPy arrays, consider the alternative described in - @{$programmers_guide/datasets#consuming_numpy_arrays$this guide}. + @{$guide/datasets#consuming_numpy_arrays$this guide}. Args: tensors: A nested structure of tensors, each having the same size in the @@ -334,7 +331,7 @@ class Dataset(object): ``` NOTE: The current implementation of `Dataset.from_generator()` uses - @{tf.py_func} and inherits the same constraints. In particular, it + `tf.py_func` and inherits the same constraints. In particular, it requires the `Dataset`- and `Iterator`-related operations to be placed on a device in the same process as the Python program that called `Dataset.from_generator()`. The body of `generator` will not be @@ -644,22 +641,39 @@ class Dataset(object): Defaults to `True`. seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random seed that will be used to create the distribution. See - @{tf.set_random_seed} for behavior. + `tf.set_random_seed` for behavior. Returns: Dataset: A `Dataset` of strings corresponding to file names. """ - if shuffle is None: - shuffle = True - matching_files = gen_io_ops.matching_files(file_pattern) - dataset = Dataset.from_tensor_slices(matching_files) - if shuffle: - # NOTE(mrry): The shuffle buffer size must be greater than zero, but the - # list of files might be empty. - buffer_size = math_ops.maximum( - array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1) - dataset = dataset.shuffle(buffer_size, seed=seed) - return dataset + with ops.name_scope("list_files"): + if shuffle is None: + shuffle = True + file_pattern = ops.convert_to_tensor( + file_pattern, dtype=dtypes.string, name="file_pattern") + matching_files = gen_io_ops.matching_files(file_pattern) + + # Raise an exception if `file_pattern` does not match any files. + condition = math_ops.greater(array_ops.shape(matching_files)[0], 0, + name="match_not_empty") + + message = math_ops.add( + "No files matched pattern: ", + string_ops.reduce_join(file_pattern, separator=", "), name="message") + + assert_not_empty = control_flow_ops.Assert( + condition, [message], summarize=1, name="assert_not_empty") + with ops.control_dependencies([assert_not_empty]): + matching_files = array_ops.identity(matching_files) + + dataset = Dataset.from_tensor_slices(matching_files) + if shuffle: + # NOTE(mrry): The shuffle buffer size must be greater than zero, but the + # list of files might be empty. + buffer_size = math_ops.maximum( + array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1) + dataset = dataset.shuffle(buffer_size, seed=seed) + return dataset def repeat(self, count=None): """Repeats this dataset `count` times. @@ -692,7 +706,7 @@ class Dataset(object): dataset will sample. seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random seed that will be used to create the distribution. See - @{tf.set_random_seed} for behavior. + `tf.set_random_seed` for behavior. reshuffle_each_iteration: (Optional.) A boolean, which if true indicates that the dataset should be pseudorandomly reshuffled each time it is iterated over. (Defaults to `True`.) @@ -819,11 +833,12 @@ class Dataset(object): def batch(self, batch_size, drop_remainder=False): """Combines consecutive elements of this dataset into batches. - NOTE: If the number of elements (`N`) in this dataset is not an exact - multiple of `batch_size`, the final batch contain smaller tensors with - shape `N % batch_size` in the batch dimension. If your program depends on - the batches having the same shape, consider using the - @{tf.contrib.data.batch_and_drop_remainder} transformation instead. + The tensors in the resulting element will have an additional outer + dimension, which will be `batch_size` (or `N % batch_size` for the last + element if `batch_size` does not divide the number of input elements `N` + evenly and `drop_remainder` is `False`). If your program depends on the + batches having the same outer dimension, you should set the `drop_remainder` + argument to `True` to prevent the smaller batch from being produced. Args: batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of @@ -846,13 +861,19 @@ class Dataset(object): """Combines consecutive elements of this dataset into padded batches. This transformation combines multiple consecutive elements of the input - dataset into a single element. Like @{tf.data.Dataset.batch}, the tensors - in the resulting element have an additional outer dimension, which will be - `batch_size` for all but the last element, and `N % batch_size` for the - last element (where `N` is the number of elements in this dataset). Unlike - @{tf.data.Dataset.batch}, the elements may have different shapes for some - of their components, and this transformation will pad each component to - the respective shape in `padding_shapes`. The `padding_shapes` argument + dataset into a single element. + + Like `tf.data.Dataset.batch`, the tensors in the resulting element will + have an additional outer dimension, which will be `batch_size` (or + `N % batch_size` for the last element if `batch_size` does not divide the + number of input elements `N` evenly and `drop_remainder` is `False`). If + your program depends on the batches having the same outer dimension, you + should set the `drop_remainder` argument to `True` to prevent the smaller + batch from being produced. + + Unlike `tf.data.Dataset.batch`, the input elements to be batched may have + different shapes, and this transformation will pad each component to the + respective shape in `padding_shapes`. The `padding_shapes` argument determines the resulting shape for each dimension of each component in an output element: @@ -862,14 +883,8 @@ class Dataset(object): will be padded out to the maximum length of all elements in that dimension. - NOTE: If the number of elements (`N`) in this dataset is not an exact - multiple of `batch_size`, the final batch contain smaller tensors with - shape `N % batch_size` in the batch dimension. If your program depends on - the batches having the same shape, consider using the - @{tf.contrib.data.padded_batch_and_drop_remainder} transformation instead. - - See also @{tf.contrib.data.dense_to_sparse_batch}, which combines elements - that may have different shapes into a @{tf.SparseTensor}. + See also `tf.contrib.data.dense_to_sparse_batch`, which combines elements + that may have different shapes into a `tf.SparseTensor`. Args: batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of @@ -897,7 +912,83 @@ class Dataset(object): drop_remainder) def map(self, map_func, num_parallel_calls=None): - """Maps `map_func` across this dataset. + """Maps `map_func` across the elements of this dataset. + + This transformation applies `map_func` to each element of this dataset, and + returns a new dataset containing the transformed elements, in the same + order as they appeared in the input. + + For example: + + ```python + # NOTE: The following examples use `{ ... }` to represent the + # contents of a dataset. + a = { 1, 2, 3, 4, 5 } + + a.map(lambda x: x + 1) = { 2, 3, 4, 5, 6 } + ``` + + The input signature of `map_func` is determined by the structure of each + element in this dataset. For example: + + ```python + # Each element is a `tf.Tensor` object. + a = { 1, 2, 3, 4, 5 } + # `map_func` takes a single argument of type `tf.Tensor` with the same + # shape and dtype. + result = a.map(lambda x: ...) + + # Each element is a tuple containing two `tf.Tensor` objects. + b = { (1, "foo"), (2, "bar"), (3, "baz") } + # `map_func` takes two arguments of type `tf.Tensor`. + result = b.map(lambda x_int, y_str: ...) + + # Each element is a dictionary mapping strings to `tf.Tensor` objects. + c = { {"a": 1, "b": "foo"}, {"a": 2, "b": "bar"}, {"a": 3, "b": "baz"} } + # `map_func` takes a single argument of type `dict` with the same keys as + # the elements. + result = c.map(lambda d: ...) + ``` + + The value or values returned by `map_func` determine the structure of each + element in the returned dataset. + + ```python + # `map_func` returns a scalar `tf.Tensor` of type `tf.float32`. + def f(...): + return tf.constant(37.0) + result = dataset.map(f) + result.output_classes == tf.Tensor + result.output_types == tf.float32 + result.output_shapes == [] # scalar + + # `map_func` returns two `tf.Tensor` objects. + def g(...): + return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"]) + result = dataset.map(g) + result.output_classes == (tf.Tensor, tf.Tensor) + result.output_types == (tf.float32, tf.string) + result.output_shapes == ([], [3]) + + # Python primitives, lists, and NumPy arrays are implicitly converted to + # `tf.Tensor`. + def h(...): + return 37.0, ["Foo", "Bar", "Baz"], np.array([1.0, 2.0] dtype=np.float64) + result = dataset.map(h) + result.output_classes == (tf.Tensor, tf.Tensor, tf.Tensor) + result.output_types == (tf.float32, tf.string, tf.float64) + result.output_shapes == ([], [3], [2]) + + # `map_func` can return nested structures. + def i(...): + return {"a": 37.0, "b": [42, 16]}, "foo" + result.output_classes == ({"a": tf.Tensor, "b": tf.Tensor}, tf.Tensor) + result.output_types == ({"a": tf.float32, "b": tf.int32}, tf.string) + result.output_shapes == ({"a": [], "b": [2]}, []) + ``` + + In addition to `tf.Tensor` objects, `map_func` can accept as arguments and + return `tf.SparseTensor` objects. Args: map_func: A function mapping a nested structure of tensors (having @@ -948,7 +1039,7 @@ class Dataset(object): elements are produced. `cycle_length` controls the number of input elements that are processed concurrently. If you set `cycle_length` to 1, this transformation will handle one input element at a time, and will produce - identical results = to @{tf.data.Dataset.flat_map}. In general, + identical results = to `tf.data.Dataset.flat_map`. In general, this transformation will apply `map_func` to `cycle_length` input elements, open iterators on the returned `Dataset` objects, and cycle through them producing `block_length` consecutive elements from each iterator, and @@ -1158,6 +1249,270 @@ class SparseTensorSliceDataset(Dataset): return (dtypes.int64, self._sparse_tensor.dtype, dtypes.int64) +class _NestedDatasetComponent(object): + """The structure of a `Dataset` nested in a component of another `Dataset`. + + A `StructuredFunctionWrapper` around a function that returns a `Dataset` as + one of its components will have a `NestedDatasetComponent` in the + corresponding position in the `output_classes`, `output_shapes`, and + `output_types` properties. + + NOTE(mrry): This class is not currently exposed via the public API. Support + for nested datasets can be enabled on a function-by-function basis by setting + `experimental_nested_dataset_support=True` in the `StructuredFunctionWrapper` + initializer. + + TODO(b/110122868): Add this class, or something equivalent, to the public API. + We are considering revising the public API for accessing Dataset structure + (`output_classes` etc.) based on experience with nested datasets and other + custom component types. + """ + + def __init__(self, + dataset=None, + output_shapes=None, + output_types=None, + output_classes=None): + if dataset is None: + if (output_classes is None or output_shapes is None or + output_types is None): + raise ValueError( + "Either `dataset`, or all of `output_classes`, " + "`output_shapes`, and `output_types` must be specified.") + self._output_classes = output_classes + self._output_shapes = output_shapes + self._output_types = output_types + else: + if not (output_classes is None and output_shapes is None and + output_types is None): + raise ValueError( + "Either `dataset`, or all of `output_classes`, " + "`output_shapes`, and `output_types` must be specified.") + self._output_classes = dataset.output_classes + self._output_shapes = dataset.output_shapes + self._output_types = dataset.output_types + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + +class _VariantDataset(Dataset): + """A Dataset wrapper around a `tf.variant`-typed function argument.""" + + def __init__(self, dataset_variant, structure): + super(_VariantDataset, self).__init__() + self._dataset_variant = dataset_variant + self._structure = structure + + def _as_variant_tensor(self): + return self._dataset_variant + + @property + def output_classes(self): + return self._structure.output_classes + + @property + def output_shapes(self): + return self._structure.output_shapes + + @property + def output_types(self): + return self._structure.output_types + + +class StructuredFunctionWrapper(object): + """A wrapper for `Defun` that supports structured arguments and return values. + """ + + def __init__(self, func, transformation_name, dataset=None, + input_classes=None, input_shapes=None, input_types=None, + add_to_graph=True, experimental_nested_dataset_support=False): + """Creates a new `StructuredFunctionWrapper` for the given function. + + Args: + func: A function from a nested structure to another nested structure. + transformation_name: Human-readable name of the transformation in which + this function is being instantiated, for error messages. + dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of this + dataset will be assumed as the structure for `func` arguments; otherwise + `input_classes`, `input_shapes`, and `input_types` must be defined. + input_classes: (Optional.) A nested structure of `type`. If given, this + argument defines the Python types for `func` arguments. + input_shapes: (Optional.) A nested structure of `tf.TensorShape`. If + given, this argument defines the shapes and structure for `func` + arguments. + input_types: (Optional.) A nested structure of `tf.DType`. If given, this + argument defines the element types and structure for `func` arguments. + add_to_graph: (Optional.) If `True`, the function will be added to the + default graph. + experimental_nested_dataset_support: (Optional.) If `True`, the function + will support `tf.data.Dataset` objects as arguments and return values. + + Raises: + ValueError: If an invalid combination of `dataset`, `input_classes`, + `input_shapes`, and `input_types` is passed. + """ + if dataset is None: + if input_classes is None or input_shapes is None or input_types is None: + raise ValueError("Either `dataset`, or all of `input_classes`, " + "`input_shapes`, and `input_types` must be specified.") + self._input_shapes = input_shapes + self._input_types = input_types + self._input_classes = input_classes + else: + if not (input_classes is None and input_shapes is None and + input_types is None): + raise ValueError("Either `dataset`, or all of `input_classes`, " + "`input_shapes`, and `input_types` must be specified.") + self._input_shapes = dataset.output_shapes + self._input_types = dataset.output_types + self._input_classes = dataset.output_classes + + self._transformation_name = transformation_name + + # TODO(b/110122868): Enable this support for all `tf.data` functions. + self._nested_dataset_support = experimental_nested_dataset_support + + @function.Defun(*self._defun_args()) + def tf_data_structured_function_wrapper(*args): + """Wrapper for passing nested structures to and from tf.data functions.""" + flat_args = [] + for arg, arg_class, arg_shape, arg_type in zip( + args, + nest.flatten(self._input_classes), + nest.flatten(self._input_shapes), + nest.flatten(self._input_types)): + # TODO(b/110122868): Add a registration mechanism for new component + # types. + if arg_class is sparse_tensor_lib.SparseTensor: + arg = sparse.deserialize_sparse_tensors( + arg, arg_type, arg_shape, arg_class) + arg.indices.set_shape([None, arg_shape.ndims]) + arg.dense_shape.set_shape([arg_shape.ndims]) + elif isinstance(arg_class, _NestedDatasetComponent): + assert self._nested_dataset_support + arg = _VariantDataset(arg, arg_class) + else: + arg.set_shape(arg_shape) + flat_args.append(arg) + nested_args = nest.pack_sequence_as(self._input_classes, flat_args) + if not _should_unpack_args(nested_args): + nested_args = (nested_args,) + + ret = func(*nested_args) + # If `func` returns a list of tensors, `nest.flatten()` and + # `ops.convert_to_tensor()` would conspire to attempt to stack + # those tensors into a single tensor, because the customized + # version of `nest.flatten()` does not recurse into lists. Since + # it is more likely that the list arose from returning the + # result of an operation (such as `tf.py_func()`) that returns a + # list of not-necessarily-stackable tensors, we treat the + # returned value is a `tuple` instead. A user wishing to pack + # the return value into a single tensor can use an explicit + # `tf.stack()` before returning. + if isinstance(ret, list): + ret = tuple(ret) + + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. + flat_ret = [] + flat_classes = [] + flat_shapes = [] + flat_types = [] + for t in nest.flatten(ret): + # TODO(b/110122868): Add a registration mechanism for new component + # types. + if sparse_tensor_lib.is_sparse(t): + t = sparse_tensor_lib.SparseTensor.from_value(t) + flat_ret.append(sparse.serialize_sparse_tensors(t)) + flat_classes.append(sparse_tensor_lib.SparseTensor) + flat_shapes.append(t.get_shape()) + flat_types.append(t.dtype) + elif isinstance(t, Dataset): + if not self._nested_dataset_support: + raise NotImplementedError( + "The %s transformation does not currently support nested " + "datasets as outputs." % self._transformation_name) + + flat_ret.append(t._as_variant_tensor()) # pylint: disable=protected-access + component = _NestedDatasetComponent(t) + flat_classes.append(component) + flat_shapes.append(component) + flat_types.append(component) + else: + try: + t = ops.convert_to_tensor(t) + except (ValueError, TypeError): + raise TypeError("Unsupported return value from function passed to " + "%s: %s." % (transformation_name, t)) + flat_ret.append(t) + flat_classes.append(ops.Tensor) + flat_shapes.append(t.get_shape()) + flat_types.append(t.dtype) + + ret = nest.pack_sequence_as(ret, flat_ret) + self._output_classes = nest.pack_sequence_as(ret, flat_classes) + self._output_shapes = nest.pack_sequence_as(ret, flat_shapes) + self._output_types = nest.pack_sequence_as(ret, flat_types) + + _warn_if_collections(transformation_name) + + return flat_ret + + self._function = tf_data_structured_function_wrapper + if add_to_graph: + self._function.add_to_graph(ops.get_default_graph()) + else: + # Use the private method that will execute + # `tf_data_structured_function_wrapper` but delay adding it to the graph + # in case (e.g.) we need to rerun the function. + self._function._create_definition_if_needed() # pylint: disable=protected-access + + def _defun_args(self): + """Returns a flat list of `tf.DType` for the input element structure.""" + ret = [] + for input_type, input_class in zip(nest.flatten(self._input_types), + nest.flatten(self._input_classes)): + # TODO(b/110122868): Add a registration mechanism for new component types. + if input_class is sparse_tensor_lib.SparseTensor: + ret.append(dtypes.variant) + elif isinstance(input_class, _NestedDatasetComponent): + if not self._nested_dataset_support: + raise NotImplementedError( + "The %s transformation does not currently support nested " + "datasets as inputs." % self._transformation_name) + ret.append(dtypes.variant) + else: + assert isinstance(input_type, dtypes.DType) + ret.append(input_type) + return ret + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + @property + def function(self): + return self._function + + def flat_structure(dataset): """Helper for setting `output_shapes` and `output_types` attrs of Dataset ops. @@ -1168,17 +1523,36 @@ def flat_structure(dataset): `**flat_structure(self)` to the op constructor. Args: - dataset: A @{tf.data.Dataset}. + dataset: A `tf.data.Dataset`. Returns: A dictionary of keyword arguments that can be passed to many Dataset op constructors. """ + output_classes = [] + output_shapes = [] + output_types = [] + for output_class, output_shape, output_type in zip( + nest.flatten(dataset.output_classes), nest.flatten(dataset.output_shapes), + nest.flatten(dataset.output_types)): + if isinstance(output_class, _NestedDatasetComponent): + output_classes.append(output_class.output_classes) + output_shapes.append(output_shape.output_shapes) + output_types.append(output_type.output_types) + else: + output_classes.append(output_class) + output_shapes.append(output_shape) + output_types.append(output_type) + + output_classes = nest.pack_sequence_as(dataset.output_classes, output_classes) + output_shapes = nest.pack_sequence_as(dataset.output_shapes, output_shapes) + output_types = nest.pack_sequence_as(dataset.output_types, output_types) + return { - "output_shapes": nest.flatten(sparse.as_dense_shapes( - dataset.output_shapes, dataset.output_classes)), - "output_types": nest.flatten(sparse.as_dense_types( - dataset.output_types, dataset.output_classes)), + "output_shapes": + nest.flatten(sparse.as_dense_shapes(output_shapes, output_classes)), + "output_types": + nest.flatten(sparse.as_dense_types(output_types, output_classes)), } @@ -1214,137 +1588,26 @@ class _GeneratorDataset(Dataset): init_args_types = nest.pack_sequence_as( init_args, [t.dtype for t in nest.flatten(init_args)]) - @function.Defun(*nest.flatten( - sparse.as_dense_types(init_args_types, init_args_classes))) - def tf_init_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - dense_shapes = sparse.as_dense_shapes(init_args_shapes, init_args_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(init_args_classes, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, init_args_types, init_args_shapes, init_args_classes) - if _should_unpack_args(nested_args): - ret = init_func(*nested_args) - else: - ret = init_func(nested_args) - - # If `init_func` returns a list of tensors, `nest.flatten()` and - # `ops.convert_to_tensor()` would conspire to attempt to stack - # those tensors into a single tensor, because the customized - # version of `nest.flatten()` does not recurse into lists. Since - # it is more likely that the list arose from returning the - # result of an operation (such as `tf.py_func()`) that returns a - # list of not-necessarily-stackable tensors, we treat the - # returned value is a `tuple` instead. A user wishing to pack - # the return value into a single tensor can use an explicit - # `tf.stack()` before returning. - if isinstance(ret, list): - ret = tuple(ret) - - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor_lib.SparseTensor.from_value(t) - if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - - self._state_classes = sparse.get_classes(ret) - self._state_shapes = nest.pack_sequence_as( - ret, [t.get_shape() for t in nest.flatten(ret)]) - self._state_types = nest.pack_sequence_as( - ret, [t.dtype for t in nest.flatten(ret)]) - - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) - - self._init_func = tf_init_func - self._init_func.add_to_graph(ops.get_default_graph()) - - # These members will be initialized by `tf_next_func`. - self._output_classes = None - self._output_shapes = None - self._output_types = None - - @function.Defun(*nest.flatten( - sparse.as_dense_types(self._state_types, self._state_classes))) - def tf_next_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - dense_shapes = sparse.as_dense_shapes(self._state_shapes, - self._state_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(self._state_classes, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, self._state_types, self._state_shapes, - self._state_classes) - if _should_unpack_args(nested_args): - ret = next_func(*nested_args) - else: - ret = next_func(nested_args) - - # If `next_func` returns a list of tensors, `nest.flatten()` and - # `ops.convert_to_tensor()` would conspire to attempt to stack - # those tensors into a single tensor, because the customized - # version of `nest.flatten()` does not recurse into lists. Since - # it is more likely that the list arose from returning the - # result of an operation (such as `tf.py_func()`) that returns a - # list of not-necessarily-stackable tensors, we treat the - # returned value is a `tuple` instead. A user wishing to pack - # the return value into a single tensor can use an explicit - # `tf.stack()` before returning. - if isinstance(ret, list): - ret = tuple(ret) - - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor_lib.SparseTensor.from_value(t) - if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - - self._output_classes = sparse.get_classes(ret) - self._output_shapes = nest.pack_sequence_as( - ret, [t.get_shape() for t in nest.flatten(ret)]) - self._output_types = nest.pack_sequence_as( - ret, [t.dtype for t in nest.flatten(ret)]) - - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) - - self._next_func = tf_next_func - self._next_func.add_to_graph(ops.get_default_graph()) - - @function.Defun(*nest.flatten( - sparse.as_dense_types(self._state_types, self._state_classes))) - def tf_finalize_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the state. - dense_shapes = sparse.as_dense_shapes(self._state_shapes, - self._state_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(self._state_classes, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, self._state_types, self._state_shapes, - self._state_classes) - if _should_unpack_args(nested_args): - return finalize_func(*nested_args) - else: - return finalize_func(nested_args) - - self._finalize_func = tf_finalize_func - self._finalize_func.add_to_graph(ops.get_default_graph()) + wrapped_init_func = StructuredFunctionWrapper( + init_func, "GeneratorDataset", input_classes=init_args_classes, + input_shapes=init_args_shapes, input_types=init_args_types) + self._state_classes = wrapped_init_func.output_classes + self._state_shapes = wrapped_init_func.output_shapes + self._state_types = wrapped_init_func.output_types + self._init_func = wrapped_init_func.function + + wrapped_next_func = StructuredFunctionWrapper( + next_func, "GeneratorDataset", input_classes=self._state_classes, + input_shapes=self._state_shapes, input_types=self._state_types) + self._output_classes = wrapped_next_func.output_classes + self._output_shapes = wrapped_next_func.output_shapes + self._output_types = wrapped_next_func.output_types + self._next_func = wrapped_next_func.function + + wrapped_finalize_func = StructuredFunctionWrapper( + finalize_func, "GeneratorDataset", input_classes=self._state_classes, + input_shapes=self._state_shapes, input_types=self._state_types) + self._finalize_func = wrapped_finalize_func.function def _as_variant_tensor(self): return gen_dataset_ops.generator_dataset( @@ -1498,6 +1761,7 @@ class RangeDataset(Dataset): self._parse_args(*args) def _parse_args(self, *args): + """Parse arguments according to the same rules as the `range()` builtin.""" if len(args) == 1: self._start = self._build_tensor(0, "start") self._stop = self._build_tensor(args[0], "stop") @@ -1582,7 +1846,7 @@ class ShuffleDataset(Dataset): dataset will sample. seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random seed that will be used to create the distribution. See - @{tf.set_random_seed} for behavior. + `tf.set_random_seed` for behavior. reshuffle_each_iteration: (Optional.) A boolean, which if true indicates that the dataset should be pseudorandomly reshuffled each time it is iterated over. (Defaults to `True`.) @@ -1823,7 +2087,7 @@ def _padding_value_to_tensor(value, output_type): def _default_padding(input_dataset): - + """Returns default padding tensors in a structure matching `input_dataset`.""" def make_zero(t): if t.base_dtype == dtypes.string: return "" @@ -1949,66 +2213,12 @@ class MapDataset(Dataset): super(MapDataset, self).__init__() self._input_dataset = input_dataset - self._output_classes = None - self._output_shapes = None - self._output_types = None - - @function.Defun(*nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes))) - def tf_map_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) - if _should_unpack_args(nested_args): - ret = map_func(*nested_args) - else: - ret = map_func(nested_args) - - # If `map_func` returns a list of tensors, `nest.flatten()` and - # `ops.convert_to_tensor()` would conspire to attempt to stack - # those tensors into a single tensor, because the customized - # version of `nest.flatten()` does not recurse into lists. Since - # it is more likely that the list arose from returning the - # result of an operation (such as `tf.py_func()`) that returns a - # list of not-necessarily-stackable tensors, we treat the - # returned value is a `tuple` instead. A user wishing to pack - # the return value into a single tensor can use an explicit - # `tf.stack()` before returning. - if isinstance(ret, list): - ret = tuple(ret) - - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor_lib.SparseTensor.from_value(t) - if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - - self._output_classes = sparse.get_classes(ret) - self._output_shapes = nest.pack_sequence_as( - ret, [t.get_shape() for t in nest.flatten(ret)]) - self._output_types = nest.pack_sequence_as( - ret, [t.dtype for t in nest.flatten(ret)]) - - _warn_if_collections("Dataset.map()") - - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) - - self._map_func = tf_map_func - self._map_func.add_to_graph(ops.get_default_graph()) + wrapped_func = StructuredFunctionWrapper( + map_func, "Dataset.map()", input_dataset) + self._output_classes = wrapped_func.output_classes + self._output_shapes = wrapped_func.output_shapes + self._output_types = wrapped_func.output_types + self._map_func = wrapped_func.function def _as_variant_tensor(self): input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access @@ -2061,39 +2271,15 @@ class FlatMapDataset(Dataset): super(FlatMapDataset, self).__init__() self._input_dataset = input_dataset - @function.Defun(*nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes))) - def tf_map_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) - if _should_unpack_args(nested_args): - dataset = map_func(*nested_args) - else: - dataset = map_func(nested_args) - - if not isinstance(dataset, Dataset): - raise TypeError("`map_func` must return a `Dataset` object.") - - _warn_if_collections(self._transformation_name()) - - self._output_classes = dataset.output_classes - self._output_types = dataset.output_types - self._output_shapes = dataset.output_shapes - - return dataset._as_variant_tensor() # pylint: disable=protected-access - - self._map_func = tf_map_func - self._map_func.add_to_graph(ops.get_default_graph()) + wrapped_func = StructuredFunctionWrapper( + map_func, self._transformation_name(), input_dataset, + experimental_nested_dataset_support=True) + if not isinstance(wrapped_func.output_classes, _NestedDatasetComponent): + raise TypeError("`map_func` must return a `Dataset` object.") + self._output_classes = wrapped_func.output_classes.output_classes + self._output_types = wrapped_func.output_types.output_types + self._output_shapes = wrapped_func.output_shapes.output_shapes + self._map_func = wrapped_func.function def _as_variant_tensor(self): return gen_dataset_ops.flat_map_dataset( @@ -2150,38 +2336,13 @@ class FilterDataset(Dataset): """See `Dataset.filter()` for details.""" super(FilterDataset, self).__init__() self._input_dataset = input_dataset - - @function.Defun(*nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes))) - def tf_predicate(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) - if _should_unpack_args(nested_args): - ret = predicate(*nested_args) - else: - ret = predicate(nested_args) - - ret = ops.convert_to_tensor(ret, dtype=dtypes.bool) - if not (ret.dtype == dtypes.bool and - ret.shape.is_compatible_with(tensor_shape.scalar())): - raise ValueError("`predicate` must return a scalar boolean tensor.") - - _warn_if_collections("Dataset.filter()") - - return ret - - self._predicate = tf_predicate - self._predicate.add_to_graph(ops.get_default_graph()) + wrapped_func = StructuredFunctionWrapper( + predicate, "Dataset.filter()", input_dataset) + if not ( + wrapped_func.output_types == dtypes.bool and + wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): + raise ValueError("`predicate` must return a scalar boolean tensor.") + self._predicate = wrapped_func.function def _as_variant_tensor(self): return gen_dataset_ops.filter_dataset( diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index b6dba4e3ca3874b8e9bc3b7ea92fb91fe41759d8..8f8e026df92c3fd430a2c1d6211668cad2a20a4c 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -20,6 +20,8 @@ from __future__ import print_function import threading import warnings +from tensorflow.python.compat import compat +from tensorflow.python.data.ops import optional_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.eager import context @@ -29,6 +31,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.saver import BaseSaverBuilder from tensorflow.python.util.tf_export import tf_export @@ -56,8 +60,15 @@ GET_NEXT_CALL_WARNING_MESSAGE = ( GLOBAL_ITERATORS = "iterators" +def _device_stack_is_empty(): + # pylint: disable=protected-access + device_stack = ops.get_default_graph()._device_functions_outer_to_inner + # pylint: enable=protected-access + return not bool(device_stack) + + @tf_export("data.Iterator") -class Iterator(object): +class Iterator(checkpointable.CheckpointableBase): """Represents the state of iterating through a `Dataset`.""" def __init__(self, iterator_resource, initializer, output_types, @@ -172,13 +183,32 @@ class Iterator(object): nest.assert_same_structure(output_types, output_shapes) if shared_name is None: shared_name = "" - iterator_resource = gen_dataset_ops.iterator( - container="", - shared_name=shared_name, - output_types=nest.flatten( - sparse.as_dense_types(output_types, output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(output_shapes, output_classes))) + if compat.forward_compatible(2018, 8, 3): + if _device_stack_is_empty(): + with ops.device("/cpu:0"): + iterator_resource = gen_dataset_ops.iterator_v2( + container="", + shared_name=shared_name, + output_types=nest.flatten( + sparse.as_dense_types(output_types, output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(output_shapes, output_classes))) + else: + iterator_resource = gen_dataset_ops.iterator_v2( + container="", + shared_name=shared_name, + output_types=nest.flatten( + sparse.as_dense_types(output_types, output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(output_shapes, output_classes))) + else: + iterator_resource = gen_dataset_ops.iterator( + container="", + shared_name=shared_name, + output_types=nest.flatten( + sparse.as_dense_types(output_types, output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(output_shapes, output_classes))) return Iterator(iterator_resource, None, output_types, output_shapes, output_classes) @@ -190,9 +220,9 @@ class Iterator(object): """Creates a new, uninitialized `Iterator` based on the given handle. This method allows you to define a "feedable" iterator where you can choose - between concrete iterators by feeding a value in a @{tf.Session.run} call. - In that case, `string_handle` would a @{tf.placeholder}, and you would feed - it with the value of @{tf.data.Iterator.string_handle} in each step. + between concrete iterators by feeding a value in a `tf.Session.run` call. + In that case, `string_handle` would be a `tf.placeholder`, and you would + feed it with the value of `tf.data.Iterator.string_handle` in each step. For example, if you had two iterators that marked the current position in a training dataset and a test dataset, you could choose which to use in @@ -242,12 +272,29 @@ class Iterator(object): output_classes = nest.map_structure(lambda _: ops.Tensor, output_types) nest.assert_same_structure(output_types, output_shapes) string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string) - iterator_resource = gen_dataset_ops.iterator_from_string_handle( - string_handle, - output_types=nest.flatten( - sparse.as_dense_types(output_types, output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(output_shapes, output_classes))) + if compat.forward_compatible(2018, 8, 3): + if _device_stack_is_empty(): + with ops.device("/cpu:0"): + iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2( + string_handle, + output_types=nest.flatten( + sparse.as_dense_types(output_types, output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(output_shapes, output_classes))) + else: + iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2( + string_handle, + output_types=nest.flatten( + sparse.as_dense_types(output_types, output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(output_shapes, output_classes))) + else: + iterator_resource = gen_dataset_ops.iterator_from_string_handle( + string_handle, + output_types=nest.flatten( + sparse.as_dense_types(output_types, output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(output_shapes, output_classes))) return Iterator(iterator_resource, None, output_types, output_shapes, output_classes) @@ -315,9 +362,9 @@ class Iterator(object): In graph mode, you should typically call this method *once* and use its result as the input to another computation. A typical loop will then call - @{tf.Session.run} on the result of that computation. The loop will terminate + `tf.Session.run` on the result of that computation. The loop will terminate when the `Iterator.get_next()` operation raises - @{tf.errors.OutOfRangeError}. The following skeleton shows how to use + `tf.errors.OutOfRangeError`. The following skeleton shows how to use this method when building a training loop: ```python @@ -420,6 +467,13 @@ class Iterator(object): """ return self._output_types + def _gather_saveables_for_checkpoint(self): + + def _saveable_factory(name): + return _IteratorSaveable(self._iterator_resource, name) + + return {"ITERATOR": _saveable_factory} + _uid_counter = 0 _uid_lock = threading.Lock() @@ -433,7 +487,7 @@ def _generate_shared_name(prefix): return "{}{}".format(prefix, uid) -class EagerIterator(object): +class EagerIterator(checkpointable.CheckpointableBase): """An iterator producing tf.Tensor objects from a tf.data.Dataset.""" def __init__(self, dataset): @@ -462,7 +516,8 @@ class EagerIterator(object): "tf.data.Dataset.make_initializable_iterator or " "tf.data.Dataset.make_one_shot_iterator for graph construction". format(type(self))) - with ops.device("/device:CPU:0"): + self._device = context.context().device_name + with ops.device("/cpu:0"): ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access self._output_classes = dataset.output_classes self._output_types = dataset.output_types @@ -471,14 +526,14 @@ class EagerIterator(object): sparse.as_dense_types(self._output_types, self._output_classes)) self._flat_output_shapes = nest.flatten( sparse.as_dense_shapes(self._output_shapes, self._output_classes)) - self._resource = gen_dataset_ops.anonymous_iterator( - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - gen_dataset_ops.make_iterator(ds_variant, self._resource) - # Delete the resource when this object is deleted - self._resource_deleter = resource_variable_ops.EagerResourceDeleter( - handle=self._resource, handle_device="/device:CPU:0") - self._device = context.context().device_name + with ops.colocate_with(ds_variant): + self._resource = gen_dataset_ops.anonymous_iterator( + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + gen_dataset_ops.make_iterator(ds_variant, self._resource) + # Delete the resource when this object is deleted + self._resource_deleter = resource_variable_ops.EagerResourceDeleter( + handle=self._resource, handle_device=self._device) def __iter__(self): return self @@ -565,3 +620,56 @@ class EagerIterator(object): """ del name return self._next_internal() + + def _gather_saveables_for_checkpoint(self): + + def _saveable_factory(name): + return _IteratorSaveable(self._resource, name) + + return {"ITERATOR": _saveable_factory} + + +# TODO(b/71645805): Expose checkpointable stateful objects from dataset +# attributes(potential). +class _IteratorSaveable(BaseSaverBuilder.SaveableObject): + """SaveableObject for saving/restoring iterator state.""" + + def __init__(self, iterator_resource, name): + serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource) + specs = [ + BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE") + ] + # pylint: disable=protected-access + super(_IteratorSaveable, self).__init__(iterator_resource, specs, name) + + def restore(self, restored_tensors, restored_shapes): + with ops.colocate_with(self.op): + return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0]) + + +def get_next_as_optional(iterator): + """Returns an `Optional` that contains the next value from the iterator. + + If `iterator` has reached the end of the sequence, the returned `Optional` + will have no value. + + Args: + iterator: A `tf.data.Iterator` object. + + Returns: + An `Optional` object representing the next value from the iterator (if it + has one) or no value. + """ + # pylint: disable=protected-access + return optional_ops._OptionalImpl( + gen_dataset_ops.iterator_get_next_as_optional( + iterator._iterator_resource, + output_types=nest.flatten( + sparse.as_dense_types(iterator.output_types, + iterator.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(iterator.output_shapes, + iterator.output_classes))), + output_shapes=iterator.output_shapes, + output_types=iterator.output_types, + output_classes=iterator.output_classes) diff --git a/tensorflow/python/data/ops/optional_ops.py b/tensorflow/python/data/ops/optional_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..b75b98dc72975bb30cfb3e56f3ed1845b4d5c370 --- /dev/null +++ b/tensorflow/python/data/ops/optional_ops.py @@ -0,0 +1,209 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An Optional type for representing potentially missing values.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops + + +class Optional(object): + """Wraps a nested structure of tensors that may/may not be present at runtime. + + An `Optional` can represent the result of an operation that may fail as a + value, rather than raising an exception and halting execution. For example, + `tf.contrib.data.get_next_as_optional` returns an `Optional` that either + contains the next value from a `tf.data.Iterator` if one exists, or a "none" + value that indicates the end of the sequence has been reached. + """ + + @abc.abstractmethod + def has_value(self, name=None): + """Returns a tensor that evaluates to `True` if this optional has a value. + + Args: + name: (Optional.) A name for the created operation. + + Returns: + A scalar `tf.Tensor` of type `tf.bool`. + """ + raise NotImplementedError("Optional.has_value()") + + @abc.abstractmethod + def get_value(self, name=None): + """Returns a nested structure of values wrapped by this optional. + + If this optional does not have a value (i.e. `self.has_value()` evaluates + to `False`), this operation will raise `tf.errors.InvalidArgumentError` + at runtime. + + Args: + name: (Optional.) A name for the created operation. + + Returns: + A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects. + """ + raise NotImplementedError("Optional.get_value()") + + @abc.abstractproperty + def output_classes(self): + """Returns the class of each component of this optional. + + The expected values are `tf.Tensor` and `tf.SparseTensor`. + + Returns: + A nested structure of Python `type` objects corresponding to each + component of this optional. + """ + raise NotImplementedError("Optional.output_classes") + + @abc.abstractproperty + def output_shapes(self): + """Returns the shape of each component of this optional. + + Returns: + A nested structure of `tf.TensorShape` objects corresponding to each + component of this optional. + """ + raise NotImplementedError("Optional.output_shapes") + + @abc.abstractproperty + def output_types(self): + """Returns the type of each component of this optional. + + Returns: + A nested structure of `tf.DType` objects corresponding to each component + of this optional. + """ + raise NotImplementedError("Optional.output_types") + + @staticmethod + def from_value(value): + """Returns an `Optional` that wraps the given value. + + Args: + value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects. + + Returns: + An `Optional` that wraps `value`. + """ + # TODO(b/110122868): Consolidate this destructuring logic with the + # similar code in `Dataset.from_tensors()`. + with ops.name_scope("optional") as scope: + with ops.name_scope("value"): + value = nest.pack_sequence_as(value, [ + sparse_tensor_lib.SparseTensor.from_value(t) + if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor( + t, name="component_%d" % i) + for i, t in enumerate(nest.flatten(value)) + ]) + + encoded_value = nest.flatten(sparse.serialize_sparse_tensors(value)) + output_classes = sparse.get_classes(value) + output_shapes = nest.pack_sequence_as( + value, [t.get_shape() for t in nest.flatten(value)]) + output_types = nest.pack_sequence_as( + value, [t.dtype for t in nest.flatten(value)]) + + return _OptionalImpl( + gen_dataset_ops.optional_from_value(encoded_value, name=scope), + output_shapes, output_types, output_classes) + + @staticmethod + def none_from_structure(output_shapes, output_types, output_classes): + """Returns an `Optional` that has no value. + + NOTE: This method takes arguments that define the structure of the value + that would be contained in the returned `Optional` if it had a value. + + Args: + output_shapes: A nested structure of `tf.TensorShape` objects + corresponding to each component of this optional. + output_types: A nested structure of `tf.DType` objects corresponding to + each component of this optional. + output_classes: A nested structure of Python `type` objects corresponding + to each component of this optional. + + Returns: + An `Optional` that has no value. + """ + return _OptionalImpl(gen_dataset_ops.optional_none(), output_shapes, + output_types, output_classes) + + +class _OptionalImpl(Optional): + """Concrete implementation of `tf.contrib.data.Optional`. + + NOTE(mrry): This implementation is kept private, to avoid defining + `Optional.__init__()` in the public API. + """ + + def __init__(self, variant_tensor, output_shapes, output_types, + output_classes): + # TODO(b/110122868): Consolidate the structure validation logic with the + # similar logic in `Iterator.from_structure()` and + # `Dataset.from_generator()`. + output_types = nest.map_structure(dtypes.as_dtype, output_types) + output_shapes = nest.map_structure_up_to( + output_types, tensor_shape.as_shape, output_shapes) + nest.assert_same_structure(output_types, output_shapes) + nest.assert_same_structure(output_types, output_classes) + self._variant_tensor = variant_tensor + self._output_shapes = output_shapes + self._output_types = output_types + self._output_classes = output_classes + + def has_value(self, name=None): + return gen_dataset_ops.optional_has_value(self._variant_tensor, name=name) + + def get_value(self, name=None): + # TODO(b/110122868): Consolidate the restructuring logic with similar logic + # in `Iterator.get_next()` and `StructuredFunctionWrapper`. + with ops.name_scope(name, "OptionalGetValue", + [self._variant_tensor]) as scope: + return sparse.deserialize_sparse_tensors( + nest.pack_sequence_as( + self._output_types, + gen_dataset_ops.optional_get_value( + self._variant_tensor, + name=scope, + output_types=nest.flatten( + sparse.as_dense_types(self._output_types, + self._output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self._output_shapes, + self._output_classes)))), + self._output_types, self._output_shapes, self._output_classes) + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types diff --git a/tensorflow/python/data/util/convert.py b/tensorflow/python/data/util/convert.py index 746b3d66de082d59e8c1e316c51e2a9ab7670e6d..ba297900b0c9834d856d1fea866c01313473ad0a 100644 --- a/tensorflow/python/data/util/convert.py +++ b/tensorflow/python/data/util/convert.py @@ -36,11 +36,11 @@ def optional_param_to_tensor(argument_name, def partial_shape_to_tensor(shape_like): - """Returns a @{tf.Tensor} that represents the given shape. + """Returns a `tf.Tensor` that represents the given shape. Args: - shape_like: A value that can be converted to a @{tf.TensorShape} or a - @{tf.Tensor}. + shape_like: A value that can be converted to a `tf.TensorShape` or a + `tf.Tensor`. Returns: A 1-D `tf.Tensor` of `tf.int64` elements representing the given shape, where diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py index 32e08021dc80d11baaead68ea062b6dab7a8dfdd..1b596bdfc0e7cc18be8ffbb96a5e3a797d7cf619 100644 --- a/tensorflow/python/data/util/nest.py +++ b/tensorflow/python/data/util/nest.py @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================== -# TODO(shivaniagrawal): Merge with core nest """## Functions for working with arbitrarily nested sequences of elements. NOTE(mrry): This fork of the `tensorflow.python.util.nest` module diff --git a/tensorflow/python/data/util/random_seed.py b/tensorflow/python/data/util/random_seed.py index e2c9d8672f94587fd3164f25f97b44a97526be07..d5169f7a53e815f7ab4e1a2e973414ead4b7c71d 100644 --- a/tensorflow/python/data/util/random_seed.py +++ b/tensorflow/python/data/util/random_seed.py @@ -29,14 +29,14 @@ from tensorflow.python.ops import math_ops def get_seed(seed): """Returns the local seeds an operation should use given an op-specific seed. - See @{tf.get_seed} for more details. This wrapper adds support for the case + See `tf.get_seed` for more details. This wrapper adds support for the case where `seed` may be a tensor. Args: - seed: An integer or a @{tf.int64} scalar tensor. + seed: An integer or a `tf.int64` scalar tensor. Returns: - A tuple of two @{tf.int64} scalar tensors that should be used for the local + A tuple of two `tf.int64` scalar tensors that should be used for the local seed of the calling dataset. """ seed, seed2 = random_seed.get_seed(seed) diff --git a/tensorflow/python/data/util/random_seed_test.py b/tensorflow/python/data/util/random_seed_test.py index 33227e82afe6fe1c748693d107d4e9844abb8e09..a809151e6ef57de8a39806b8164f818d94b8a783 100644 --- a/tensorflow/python/data/util/random_seed_test.py +++ b/tensorflow/python/data/util/random_seed_test.py @@ -30,7 +30,7 @@ from tensorflow.python.platform import test class RandomSeedTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRandomSeed(self): zero_t = constant_op.constant(0, dtype=dtypes.int64, name='zero') one_t = constant_op.constant(1, dtype=dtypes.int64, name='one') diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 09062abd7446628ede12e782e202ee0e55905879..8a4ac6aaef4ca1a13c4041a217fbdafa192fb8fb 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -5,7 +5,7 @@ # # ":debug_py": Public Python methods and classes of tfdbg. # For API documentation, see https://www.tensorflow.org/api_docs/python/tfdbg -# For a user interface walkthrough, see https://www.tensorflow.org/programmers_guide/debugger +# For a user interface walkthrough, see https://www.tensorflow.org/guide/debugger # ":grpc_debug_server": Server interface for grpc:// debug URLs. package( @@ -167,6 +167,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:platform", + "//third_party/py/numpy", "@six_archive//:six", ], ) @@ -403,6 +404,7 @@ py_library( deps = [ ":debug_errors", ":debug_fibonacci", + ":debug_keras", ":debug_mnist", ":debug_tflearn_iris", ], @@ -453,6 +455,17 @@ py_binary( ], ) +py_binary( + name = "debug_keras", + srcs = ["examples/debug_keras.py"], + srcs_version = "PY2AND3", + deps = [ + ":debug_py", + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) + py_test( name = "common_test", size = "small", @@ -790,6 +803,7 @@ cuda_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:variables", ], + tags = ["no_windows_gpu"], ) py_test( @@ -802,6 +816,7 @@ py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform", "//tensorflow/python:platform_test", + "//third_party/py/numpy", ], ) @@ -921,7 +936,6 @@ py_test( size = "small", srcs = ["cli/profile_analyzer_cli_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], deps = [ ":debugger_cli_common", ":profile_analyzer_cli", @@ -1084,6 +1098,7 @@ py_test( "//tensorflow/python:state_ops", "//tensorflow/python:training", "//tensorflow/python:variables", + "//third_party/py/numpy", ], ) @@ -1094,6 +1109,7 @@ sh_test( data = [ ":debug_errors", ":debug_fibonacci", + ":debug_keras", ":debug_mnist", ":debug_tflearn_iris", ":offline_analyzer", diff --git a/tensorflow/python/debug/README.md b/tensorflow/python/debug/README.md index 269bbb19bdb898d1d81d0b9c618a284a437e68b9..9c16af4d79754cee5d77158d5c2466412c6b9e68 100644 --- a/tensorflow/python/debug/README.md +++ b/tensorflow/python/debug/README.md @@ -28,7 +28,7 @@ models: * Easy access through session wrappers * Easy integration with common high-level APIs, such as - [TensorFlow Estimators](https://www.tensorflow.org/programmers_guide/estimators) and + [TensorFlow Estimators](https://www.tensorflow.org/guide/estimators) and [Keras](https://keras.io/) * Inspection of runtime tensor values and node connections * Conditional breaking after runs that generate tensors satisfying given @@ -43,7 +43,7 @@ models: ## How to use TFDBG? -* For a walkthrough of TFDBG command-line interface, see https://www.tensorflow.org/programmers_guide/debugger. +* For a walkthrough of TFDBG command-line interface, see https://www.tensorflow.org/guide/debugger. * For information on the web GUI of TFDBG (TensorBoard Debugger Plugin), see [this README](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/debugger/README.md). * For programmatic use of the API of TFDBG, see https://www.tensorflow.org/api_docs/python/tfdbg. diff --git a/tensorflow/python/debug/cli/debugger_cli_common.py b/tensorflow/python/debug/cli/debugger_cli_common.py index 12e79ab07a4655c7d41f41d2e71906273e154a08..02563fde845e7951046a8bcd65899ef5e1fcc35f 100644 --- a/tensorflow/python/debug/cli/debugger_cli_common.py +++ b/tensorflow/python/debug/cli/debugger_cli_common.py @@ -23,9 +23,11 @@ import re import sre_constants import traceback +import numpy as np import six from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python import pywrap_tensorflow_internal from tensorflow.python.platform import gfile HELP_INDENT = " " @@ -131,6 +133,25 @@ def rich_text_lines_from_rich_line_list(rich_text_list, annotations=None): return RichTextLines(lines, font_attr_segs, annotations=annotations) +def get_tensorflow_version_lines(include_dependency_versions=False): + """Generate RichTextLines with TensorFlow version info. + + Args: + include_dependency_versions: Include the version of TensorFlow's key + dependencies, such as numpy. + + Returns: + A formatted, multi-line `RichTextLines` object. + """ + lines = ["TensorFlow version: %s" % pywrap_tensorflow_internal.__version__] + lines.append("") + if include_dependency_versions: + lines.append("Dependency version(s):") + lines.append(" numpy: %s" % np.__version__) + lines.append("") + return RichTextLines(lines) + + class RichTextLines(object): """Rich multi-line text. @@ -538,6 +559,8 @@ class CommandHandlerRegistry(object): HELP_COMMAND = "help" HELP_COMMAND_ALIASES = ["h"] + VERSION_COMMAND = "version" + VERSION_COMMAND_ALIASES = ["ver"] def __init__(self): # A dictionary from command prefix to handler. @@ -562,6 +585,13 @@ class CommandHandlerRegistry(object): "Print this help message.", prefix_aliases=self.HELP_COMMAND_ALIASES) + # Register a default handler for the command "version". + self.register_command_handler( + self.VERSION_COMMAND, + self._version_handler, + "Print the versions of TensorFlow and its key dependencies.", + prefix_aliases=self.VERSION_COMMAND_ALIASES) + def register_command_handler(self, prefix, handler, @@ -763,6 +793,11 @@ class CommandHandlerRegistry(object): else: return RichTextLines(["ERROR: help takes only 0 or 1 input argument."]) + def _version_handler(self, args, screen_info=None): + del args # Unused currently. + del screen_info # Unused currently. + return get_tensorflow_version_lines(include_dependency_versions=True) + def _resolve_prefix(self, token): """Resolve command prefix from the prefix itself or its alias. diff --git a/tensorflow/python/debug/cli/debugger_cli_common_test.py b/tensorflow/python/debug/cli/debugger_cli_common_test.py index 1b7a5962fe7dc4e19446c3e3b0aeab672eb30f1f..aba95e5820b1d8c6b3811fc69328317ce2c3ac64 100644 --- a/tensorflow/python/debug/cli/debugger_cli_common_test.py +++ b/tensorflow/python/debug/cli/debugger_cli_common_test.py @@ -21,6 +21,9 @@ import os import stat import tempfile +import numpy as np + +from tensorflow.python import pywrap_tensorflow_internal from tensorflow.python.debug.cli import debugger_cli_common from tensorflow.python.framework import test_util from tensorflow.python.platform import gfile @@ -547,7 +550,10 @@ class CommandHandlerRegistryTest(test_util.TensorFlowTestCase): " Show screen width in number of columns.", "", "", "help", " Aliases: h", "", " Print this help message.", "", "", "noop", " Aliases: n, NOOP", "", - " No operation.", " I.e., do nothing.", "", ""], + " No operation.", " I.e., do nothing.", "", "", + "version", " Aliases: ver", "", + " Print the versions of TensorFlow and its key " + "dependencies.", "", ""], output.lines) # Get help for one specific command prefix. @@ -575,7 +581,9 @@ class CommandHandlerRegistryTest(test_util.TensorFlowTestCase): self.assertEqual(help_intro.lines + [ "help", " Aliases: h", "", " Print this help message.", "", "", "noop", " Aliases: n, NOOP", "", " No operation.", - " I.e., do nothing.", "", "" + " I.e., do nothing.", "", "", + "version", " Aliases: ver", "", + " Print the versions of TensorFlow and its key dependencies.", "", "" ], output.lines) @@ -1147,5 +1155,22 @@ class MenuTest(test_util.TensorFlowTestCase): self.assertEqual((40, 50, ["bold"]), output.font_attr_segs[0][2]) +class GetTensorFlowVersionLinesTest(test_util.TensorFlowTestCase): + + def testGetVersionWithoutDependencies(self): + out = debugger_cli_common.get_tensorflow_version_lines() + self.assertEqual(2, len(out.lines)) + self.assertEqual( + "TensorFlow version: %s" % pywrap_tensorflow_internal.__version__, + out.lines[0]) + + def testGetVersionWithDependencies(self): + out = debugger_cli_common.get_tensorflow_version_lines(True) + self.assertIn( + "TensorFlow version: %s" % pywrap_tensorflow_internal.__version__, + out.lines) + self.assertIn(" numpy: %s" % np.__version__, out.lines) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/debug/examples/README.md b/tensorflow/python/debug/examples/README.md index cb4d484092fe39698de1ff11e4d50d4879960e0c..3b431e04dc3565037dc018991bea68ab019e8af0 100644 --- a/tensorflow/python/debug/examples/README.md +++ b/tensorflow/python/debug/examples/README.md @@ -3,7 +3,7 @@ Hi, there! The documentation of **TensorFlow Debugger (tfdbg)** has moved. See the source version at -[this new location](../../../docs_src/programmers_guide/debugger.md). +[this new location](../../../docs_src/guide/debugger.md). See the public website version at -[https://www.tensorflow.org/programmers_guide/debugger](https://www.tensorflow.org/programmers_guide/debugger). +[https://www.tensorflow.org/guide/debugger](https://www.tensorflow.org/guide/debugger). diff --git a/tensorflow/python/debug/examples/debug_keras.py b/tensorflow/python/debug/examples/debug_keras.py new file mode 100644 index 0000000000000000000000000000000000000000..3272d85ade957b254b2c1a0977156179cd71bb9d --- /dev/null +++ b/tensorflow/python/debug/examples/debug_keras.py @@ -0,0 +1,89 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""tfdbg example: debugging tf.keras models training on tf.data.Dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +import numpy as np +import tensorflow as tf + +from tensorflow.python import debug as tf_debug + + +def main(_): + # Create a dummy dataset. + num_examples = 8 + steps_per_epoch = 2 + input_dims = 3 + output_dims = 1 + xs = np.zeros([num_examples, input_dims]) + ys = np.zeros([num_examples, output_dims]) + dataset = tf.data.Dataset.from_tensor_slices( + (xs, ys)).repeat(num_examples).batch(int(num_examples / steps_per_epoch)) + + sess = tf.Session() + if FLAGS.debug: + # Use the command-line interface (CLI) of tfdbg. + sess = tf_debug.LocalCLIDebugWrapperSession(sess, ui_type=FLAGS.ui_type) + elif FLAGS.tensorboard_debug_address: + # Use the TensorBoard Debugger Plugin (GUI of tfdbg). + sess = tf_debug.TensorBoardDebugWrapperSession( + sess, FLAGS.tensorboard_debug_address) + tf.keras.backend.set_session(sess) + + # Create a dummy model. + model = tf.keras.Sequential([ + tf.keras.layers.Dense(1, input_shape=[input_dims])]) + model.compile(loss="mse", optimizer="sgd") + + # Train the model using the dummy dataset created above. + model.fit(dataset, epochs=FLAGS.epochs, steps_per_epoch=steps_per_epoch) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--debug", + type="bool", + nargs="?", + const=True, + default=False, + help="Use debugger to track down bad values during training. " + "Mutually exclusive with the --tensorboard_debug_address flag.") + parser.add_argument( + "--ui_type", + type=str, + default="curses", + help="Command-line user interface type (curses | readline).") + parser.add_argument( + "--tensorboard_debug_address", + type=str, + default=None, + help="Connect to the TensorBoard Debugger Plugin backend specified by " + "the gRPC address (e.g., localhost:1234). Mutually exclusive with the " + "--debug flag.") + parser.add_argument( + "--epochs", + type=int, + default=2, + help="Number of epochs to train the model for.") + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/debug/examples/examples_test.sh b/tensorflow/python/debug/examples/examples_test.sh index e9c45a7e6e92d069f51648647620f7a7c3a5aadc..f7d597c8c065ced5efe95031a83877a92d7ccae1 100755 --- a/tensorflow/python/debug/examples/examples_test.sh +++ b/tensorflow/python/debug/examples/examples_test.sh @@ -48,12 +48,14 @@ if [[ -z "${PYTHON_BIN_PATH}" ]]; then DEBUG_ERRORS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_errors" DEBUG_MNIST_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_mnist" DEBUG_TFLEARN_IRIS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_tflearn_iris" + DEBUG_KERAS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_keras" OFFLINE_ANALYZER_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/offline_analyzer" else DEBUG_FIBONACCI_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_fibonacci" DEBUG_ERRORS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_errors" DEBUG_MNIST_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_mnist" DEBUG_TFLEARN_IRIS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_tflearn_iris" + DEBUG_KERAS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_keras" OFFLINE_ANALYZER_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.cli.offline_analyzer" fi @@ -96,6 +98,11 @@ if [[ -d "${CUSTOM_DUMP_ROOT}" ]]; then exit 1 fi +# Test debugging of tf.keras. +cat << EOF | ${DEBUG_KERAS_BIN} --debug --ui_type=readline +run -f has_inf_or_nan +EOF + # Test offline_analyzer. echo echo "Testing offline_analyzer" diff --git a/tensorflow/python/debug/lib/debug_gradients.py b/tensorflow/python/debug/lib/debug_gradients.py index 589a13db7f798aef3bb82dfbd442deabfbcf2a41..5e95bcba479a4365d3a140ab85ad7492a13a2482 100644 --- a/tensorflow/python/debug/lib/debug_gradients.py +++ b/tensorflow/python/debug/lib/debug_gradients.py @@ -69,7 +69,7 @@ class GradientsDebugger(object): """Gradients Debugger. Allows retrieval of gradient tensors created by TensorFlow's automatic - differentiation algorithm, i.e., @{tf.gradients} and optimizer classes that + differentiation algorithm, i.e., `tf.gradients` and optimizer classes that use it. """ # TODO(cais): Add examples code in the doc string? @@ -142,8 +142,8 @@ class GradientsDebugger(object): Args: input_tensor: the input `tf.Tensor` object whose related gradient tensors are to be reigstered with this `GradientsDebugger` instance when they - are created, e.g., during @{tf.gradients} calls or the construction - of optimization (training) op that uses @{tf.gradients}. + are created, e.g., during `tf.gradients` calls or the construction + of optimization (training) op that uses `tf.gradients`. Returns: A forwarded identity of `input_tensor`, as a `tf.Tensor`. diff --git a/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py b/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py index bd00f738610627a4b3bc7c61476164188a7b460c..676097fde95e2e5a685e8e43f8f38d3e62e7084a 100644 --- a/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py +++ b/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py @@ -44,7 +44,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase): def _no_rewrite_session_config(self): rewriter_config = rewriter_config_pb2.RewriterConfig( - dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF, + min_graph_nodes=-1) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) return config_pb2.ConfigProto(graph_options=graph_options) diff --git a/tensorflow/python/debug/wrappers/dumping_wrapper.py b/tensorflow/python/debug/wrappers/dumping_wrapper.py index 3fac2e59717a828424a808b770812afc7772bfe2..c02d5f66ec96d3428ee36e68b69d103af8fc1352 100644 --- a/tensorflow/python/debug/wrappers/dumping_wrapper.py +++ b/tensorflow/python/debug/wrappers/dumping_wrapper.py @@ -45,7 +45,7 @@ class DumpingDebugWrapperSession(framework.NonInteractiveDebugWrapperSession): session_root: (`str`) Path to the session root directory. Must be a directory that does not exist or an empty directory. If the directory does not exist, it will be created by the debugger core during debug - @{tf.Session.run} + `tf.Session.run` calls. As the `run()` calls occur, subdirectories will be added to `session_root`. The subdirectories' names has the following pattern: diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py index c530204bbf6959f56a72c6e67add91f1e575f067..b9524ce649c7d6d888affacc22cfadd41dbe2e40 100644 --- a/tensorflow/python/debug/wrappers/framework.py +++ b/tensorflow/python/debug/wrappers/framework.py @@ -392,6 +392,9 @@ class BaseDebugWrapperSession(session.SessionInterface): self._default_session_context_manager = None + # A cache for callables created from CallableOptions. + self._cached_callables_from_options = dict() + @property def graph(self): return self._sess.graph @@ -414,7 +417,8 @@ class BaseDebugWrapperSession(session.SessionInterface): options=None, run_metadata=None, callable_runner=None, - callable_runner_args=None): + callable_runner_args=None, + callable_options=None): """Wrapper around Session.run() that inserts tensor watch options. Args: @@ -424,7 +428,12 @@ class BaseDebugWrapperSession(session.SessionInterface): run_metadata: Same as the `run_metadata` arg to regular `Session.run()`. callable_runner: A `callable` returned by `Session.make_callable()`. If not `None`, `fetches` and `feed_dict` must both be `None`. - callable_runner_args: An optional list of arguments to `callable_runner`. + Mutually exclusive with `callable_options`. + callable_runner_args: An optional list of arguments to `callable_runner` + or for `callable_options`. + callable_options: An instance of `config_pb2.CallableOptions`, to be + used with `Session._make_callable_from_options()`. Mutually exclusive + with `callable_runner`. Returns: Simply forwards the output of the wrapped `Session.run()` call. @@ -433,13 +442,17 @@ class BaseDebugWrapperSession(session.SessionInterface): ValueError: On invalid `OnRunStartAction` value. Or if `callable_runner` is not `None` and either or both of `fetches` and `feed_dict` is `None`. """ - if not callable_runner: + if callable_runner and callable_options: + raise ValueError( + "callable_runner and callable_options are mutually exclusive, but " + "are both specified in this call to BaseDebugWrapperSession.run().") + + if not (callable_runner or callable_options): self.increment_run_call_count() - else: - if fetches or feed_dict: - raise ValueError( - "callable_runner and fetches/feed_dict are mutually exclusive, but " - "are used simultaneously.") + elif callable_runner and (fetches or feed_dict): + raise ValueError( + "callable_runner and fetches/feed_dict are mutually exclusive, " + "but are used simultaneously.") empty_fetches = not nest.flatten(fetches) if empty_fetches: @@ -449,6 +462,11 @@ class BaseDebugWrapperSession(session.SessionInterface): if self._is_disabled_thread() or empty_fetches: if callable_runner: return callable_runner(*callable_runner_args) + elif callable_options: + # pylint:disable=protected-access + return self._sess._make_callable_from_options( + callable_options)(*callable_runner_args) + # pylint:enable=protected-access else: return self._sess.run(fetches, feed_dict=feed_dict, @@ -464,19 +482,30 @@ class BaseDebugWrapperSession(session.SessionInterface): if run_start_resp.action == OnRunStartAction.DEBUG_RUN: # Decorate RunOption to fill in debugger tensor watch specifications. - decorated_run_options = options or config_pb2.RunOptions() + decorated_run_options = None + if callable_options: + callable_options_id = id(callable_options) + if callable_options_id not in self._cached_callables_from_options: + # Make a copy of callable_options to avoid mutating it. + new_callable_options = config_pb2.CallableOptions() + new_callable_options.CopyFrom(callable_options) + decorated_run_options = new_callable_options.run_options + else: + decorated_run_options = options or config_pb2.RunOptions() + run_metadata = run_metadata or config_pb2.RunMetadata() - self._decorate_run_options_for_debug( - decorated_run_options, - run_start_resp.debug_urls, - debug_ops=run_start_resp.debug_ops, - node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist, - op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist, - tensor_dtype_regex_whitelist=( - run_start_resp.tensor_dtype_regex_whitelist), - tolerate_debug_op_creation_failures=( - run_start_resp.tolerate_debug_op_creation_failures)) + if decorated_run_options: + self._decorate_run_options_for_debug( + decorated_run_options, + run_start_resp.debug_urls, + debug_ops=run_start_resp.debug_ops, + node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist, + op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist, + tensor_dtype_regex_whitelist=( + run_start_resp.tensor_dtype_regex_whitelist), + tolerate_debug_op_creation_failures=( + run_start_resp.tolerate_debug_op_creation_failures)) # Invoke the run() method of the wrapped Session. Catch any TensorFlow # runtime errors. @@ -486,6 +515,19 @@ class BaseDebugWrapperSession(session.SessionInterface): retvals = callable_runner(*callable_runner_args, options=decorated_run_options, run_metadata=run_metadata) + elif callable_options: + # pylint:disable=protected-access + if callable_options_id in self._cached_callables_from_options: + callable_object = self._cached_callables_from_options[ + callable_options_id] + else: + callable_object = self._sess._make_callable_from_options( + new_callable_options) + self._cached_callables_from_options[ + callable_options_id] = callable_object + # pylint:enable=protected-access + retvals = callable_object( + *callable_runner_args, run_metadata=run_metadata) else: retvals = self._sess.run(fetches, feed_dict=feed_dict, @@ -590,7 +632,14 @@ class BaseDebugWrapperSession(session.SessionInterface): run_metadata=kwargs.get("run_metadata", None), callable_runner=runner, callable_runner_args=runner_args) + return wrapped_runner + def _make_callable_from_options(self, callable_options): + def wrapped_runner(*feed_values, **kwargs): + return self.run(None, + run_metadata=kwargs.get("run_metadata", None), + callable_options=callable_options, + callable_runner_args=feed_values) return wrapped_runner @property diff --git a/tensorflow/python/debug/wrappers/grpc_wrapper.py b/tensorflow/python/debug/wrappers/grpc_wrapper.py index 1f9c8fa5a96b4d6826fae0870608e0e737c7cd88..85944fa61118114cc73f9288f3f974f0a5a8a839 100644 --- a/tensorflow/python/debug/wrappers/grpc_wrapper.py +++ b/tensorflow/python/debug/wrappers/grpc_wrapper.py @@ -215,7 +215,8 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): options=None, run_metadata=None, callable_runner=None, - callable_runner_args=None): + callable_runner_args=None, + callable_options=None): if self._send_traceback_and_source_code: self._sent_graph_version = publish_traceback( self._grpc_debug_server_urls, self.graph, feed_dict, fetches, @@ -226,4 +227,5 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): options=options, run_metadata=run_metadata, callable_runner=callable_runner, - callable_runner_args=callable_runner_args) + callable_runner_args=callable_runner_args, + callable_options=callable_options) diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py index c8625655e51a43a222addedd4beecdd3515d7fb6..668ffb57f10a69ce7e11e889fe613afbd618e823 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py @@ -290,6 +290,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): if self._run_call_count == 1: # Show logo at the onset of the first run. help_intro.extend(cli_shared.get_tfdbg_logo()) + help_intro.extend(debugger_cli_common.get_tensorflow_version_lines()) help_intro.extend(debugger_cli_common.RichTextLines("Upcoming run:")) help_intro.extend(self._run_info) @@ -466,6 +467,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): if self._run_call_count == 1: output.extend(cli_shared.get_tfdbg_logo()) + output.extend(debugger_cli_common.get_tensorflow_version_lines()) output.extend(self._run_info) if (not self._is_run_start and @@ -594,7 +596,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): # Register tab completion for the filter names. curses_cli.register_tab_comp_context(["run", "r"], list(self._tensor_filters.keys())) - if self._feed_dict: + if self._feed_dict and hasattr(self._feed_dict, "keys"): # Register tab completion for feed_dict keys. feed_keys = [common.get_graph_element_name(key) for key in self._feed_dict.keys()] diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py index b06fa26a935b42709575f8e400e0bda951ffbbc7..05c9eaa4d27319ecf5e12fdeb0a973246c61704a 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py @@ -21,7 +21,10 @@ import os import shutil import tempfile +import numpy as np + from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.debug.cli import cli_shared from tensorflow.python.debug.cli import debugger_cli_common @@ -149,7 +152,13 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase): dtypes.float32, shape=([5, 5]), name="sparse_placeholder") self.sparse_add = sparse_ops.sparse_add(self.sparse_ph, self.sparse_ph) - self.sess = session.Session() + rewriter_config = rewriter_config_pb2.RewriterConfig( + disable_model_pruning=True, + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) + graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) + config_proto = config_pb2.ConfigProto(graph_options=graph_options) + self.sess = session.Session(config=config_proto) # Initialize variable. self.sess.run(variables.global_variables_initializer()) @@ -393,6 +402,113 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase): self.assertAllClose(42.0, tensor_runner(41.0, 1.0)) self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"])) + def testDebuggingMakeCallableFromOptionsWithZeroFeedWorks(self): + variable_1 = variables.Variable( + 10.5, dtype=dtypes.float32, name="variable_1") + a = math_ops.add(variable_1, variable_1, "callable_a") + math_ops.add(a, a, "callable_b") + self.sess.run(variable_1.initializer) + + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"]] * 3, self.sess, dump_root=self._tmp_dir) + callable_options = config_pb2.CallableOptions() + callable_options.fetch.append("callable_b") + sess_callable = wrapped_sess._make_callable_from_options(callable_options) + + for _ in range(2): + callable_output = sess_callable() + self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0]) + + debug_dumps = wrapped_sess.observers["debug_dumps"] + self.assertEqual(2, len(debug_dumps)) + for debug_dump in debug_dumps: + node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data] + self.assertItemsEqual( + ["callable_a", "callable_b", "variable_1", "variable_1/read"], + node_names) + + def testDebuggingMakeCallableFromOptionsWithOneFeedWorks(self): + ph1 = array_ops.placeholder(dtypes.float32, name="callable_ph1") + a = math_ops.add(ph1, ph1, "callable_a") + math_ops.add(a, a, "callable_b") + + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"]] * 3, self.sess, dump_root=self._tmp_dir) + callable_options = config_pb2.CallableOptions() + callable_options.feed.append("callable_ph1") + callable_options.fetch.append("callable_b") + sess_callable = wrapped_sess._make_callable_from_options(callable_options) + + ph1_value = np.array([10.5, -10.5], dtype=np.float32) + + for _ in range(2): + callable_output = sess_callable(ph1_value) + self.assertAllClose( + np.array([42.0, -42.0], dtype=np.float32), callable_output[0]) + + debug_dumps = wrapped_sess.observers["debug_dumps"] + self.assertEqual(2, len(debug_dumps)) + for debug_dump in debug_dumps: + node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data] + self.assertItemsEqual(["callable_a", "callable_b"], node_names) + + def testDebuggingMakeCallableFromOptionsWithTwoFeedsWorks(self): + ph1 = array_ops.placeholder(dtypes.float32, name="callable_ph1") + ph2 = array_ops.placeholder(dtypes.float32, name="callable_ph2") + a = math_ops.add(ph1, ph2, "callable_a") + math_ops.add(a, a, "callable_b") + + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"]] * 3, self.sess, dump_root=self._tmp_dir) + callable_options = config_pb2.CallableOptions() + callable_options.feed.append("callable_ph1") + callable_options.feed.append("callable_ph2") + callable_options.fetch.append("callable_b") + sess_callable = wrapped_sess._make_callable_from_options(callable_options) + + ph1_value = np.array(5.0, dtype=np.float32) + ph2_value = np.array(16.0, dtype=np.float32) + + for _ in range(2): + callable_output = sess_callable(ph1_value, ph2_value) + self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0]) + + debug_dumps = wrapped_sess.observers["debug_dumps"] + self.assertEqual(2, len(debug_dumps)) + for debug_dump in debug_dumps: + node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data] + self.assertItemsEqual(["callable_a", "callable_b"], node_names) + + def testDebugMakeCallableFromOptionsWithCustomOptionsAndMetadataWorks(self): + variable_1 = variables.Variable( + 10.5, dtype=dtypes.float32, name="variable_1") + a = math_ops.add(variable_1, variable_1, "callable_a") + math_ops.add(a, a, "callable_b") + self.sess.run(variable_1.initializer) + + wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( + [["run"], ["run"]], self.sess, dump_root=self._tmp_dir) + callable_options = config_pb2.CallableOptions() + callable_options.fetch.append("callable_b") + callable_options.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE + + sess_callable = wrapped_sess._make_callable_from_options(callable_options) + + run_metadata = config_pb2.RunMetadata() + # Call the callable with a custom run_metadata. + callable_output = sess_callable(run_metadata=run_metadata) + # Verify that step_stats is populated in the custom run_metadata. + self.assertTrue(run_metadata.step_stats) + self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0]) + + debug_dumps = wrapped_sess.observers["debug_dumps"] + self.assertEqual(1, len(debug_dumps)) + debug_dump = debug_dumps[0] + node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data] + self.assertItemsEqual( + ["callable_a", "callable_b", "variable_1", "variable_1/read"], + node_names) + def testRuntimeErrorShouldBeCaught(self): wrapped_sess = LocalCLIDebuggerWrapperSessionForTest( [["run"], ["run"]], self.sess, dump_root=self._tmp_dir) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..68d8b8d13b1327cb5c540e868aaddc8bde684082 --- /dev/null +++ b/tensorflow/python/distribute/BUILD @@ -0,0 +1,43 @@ +package( + default_visibility = ["//tensorflow:internal"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "distribute_coordinator", + srcs = [ + "distribute_coordinator.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:training", + ], +) + +py_test( + name = "distribute_coordinator_test", + size = "large", + srcs = ["distribute_coordinator_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":distribute_coordinator", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:distributed_framework_test_lib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:session", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py new file mode 100644 index 0000000000000000000000000000000000000000..fc9ca4ac4a3fb159261cb8b5cbae36f0a4e5c97f --- /dev/null +++ b/tensorflow/python/distribute/distribute_coordinator.py @@ -0,0 +1,490 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A unified and split coordinator for distributed TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import json +import os +import threading + +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.training import server_lib + + +class _TaskType(object): + PS = "ps" + WORKER = "worker" + CHIEF = "chief" + EVALUATOR = "evaluator" + CLIENT = "client" + + +# TODO(yuefengz): support another mode where the client colocates with one +# worker. +class CoordinatorMode(object): + """Specify how distribute coordinator runs.""" + # The default mode where distribute coordinator will run as a standalone + # client and connects to remote servers for training. Each remote server can + # use the distribute coordinator binary with task_type set correctly which + # will then turn into standard servers. + SPLIT_CLIENT = 0 + + # The distribute coordinator runs on each worker. It will run a standard + # server on each worker and optionally run the `worker_fn` that is configured + # to talk to its standard server. + INDEPENDENT_WORKER = 1 + + +_worker_context = threading.local() + + +def get_current_worker_context(): + """Returns the current task context.""" + try: + return _worker_context.current + except AttributeError: + return None + + +class _Barrier(object): + """A reusable barrier class for worker synchronization.""" + + def __init__(self, num_participants): + """Initializes the barrier object. + + Args: + num_participants: an integer which is the expected number of calls of + `wait` pass to through this barrier. + """ + self._num_participants = num_participants + self._counter = 0 + self._flag = False + self._local_sense = threading.local() + self._lock = threading.Lock() + self._condition = threading.Condition() + + def wait(self): + """Waits until all other callers reach the same wait call.""" + if not hasattr(self._local_sense, "value"): + self._local_sense.value = False + self._local_sense.value = not self._flag + with self._lock: + self._counter += 1 + if self._counter == self._num_participants: + self._counter = 0 + self._flag = self._local_sense.value + with self._condition: + while self._flag != self._local_sense.value: + self._condition.wait() + self._condition.notify_all() + + +def _get_num_workers(cluster_spec): + """Gets number of workers including chief.""" + if not cluster_spec: + return 0 + return len(cluster_spec.as_dict().get(_TaskType.WORKER, [])) + len( + cluster_spec.as_dict().get(_TaskType.CHIEF, [])) + + +class _WorkerContext(object): + """The worker context class. + + This context object provides configuration information for each task. One + context manager with a worker context object will be created per + invocation to the `worker_fn` where `get_current_worker_context` can be called + to access the worker context object. + """ + + def __init__(self, + cluster_spec, + task_type, + task_id, + rpc_layer="grpc", + worker_barrier=None): + """Initialize the worker context object. + + Args: + cluster_spec: a ClusterSpec object. It can be empty or None in the local + training case. + task_type: a string indicating the role of the corresponding task, such as + "worker" or "ps". It can be None if it is local training or in-graph + replicated training. + task_id: an integer indicating id of the corresponding task. It can be + None if it is local training or in-graph replicated training. + rpc_layer: optional string specifying the RPC protocol for communication + with worker masters. If None or empty, hosts in the `cluster_spec` will + be used directly. + worker_barrier: optional, the barrier object for worker synchronization. + """ + self._cluster_spec = cluster_spec + self._task_type = task_type + self._task_id = task_id + self._worker_barrier = worker_barrier + self._rpc_layer = rpc_layer + self._master_target = self._get_master_target() + self._num_workers = _get_num_workers(cluster_spec) + self._is_chief_node = self._is_chief() + + def _debug_message(self): + return "[cluster_spec: %r, task_type: %r, task_id: %r]" % ( + self._cluster_spec, self.task_type, self.task_id) + + def __enter__(self): + old_context = get_current_worker_context() + if old_context: + raise ValueError( + "You cannot run distribute coordinator in a `worker_fn`.\t" + + self._debug_message()) + _worker_context.current = self + + def __exit__(self, unused_exception_type, unused_exception_value, + unused_traceback): + _worker_context.current = None + + def _get_master_target(self): + """Return the master target for a task.""" + # If cluster_spec is None or empty, we use local master. + if not self._cluster_spec: + return "local" + + # If task_type is None, then it is in-graph replicated training. In this + # case we use the chief or first worker's master target. + if not self._task_type: + if _TaskType.CHIEF in self._cluster_spec.jobs: + task_type = _TaskType.CHIEF + task_id = 0 + else: + assert _TaskType.WORKER in self._cluster_spec.jobs + task_type = _TaskType.WORKER + task_id = 0 + else: + task_type = self._task_type + task_id = self._task_id + + prefix = "" + if self._rpc_layer: + prefix = self._rpc_layer + "://" + return prefix + self._cluster_spec.job_tasks(task_type)[task_id or 0] + + def _is_chief(self): + """Return whether the task is the chief worker.""" + if (not self._cluster_spec or + self._task_type in [_TaskType.CHIEF, _TaskType.EVALUATOR, None]): + return True + + # If not local and chief not in the cluster_spec, use the first worker as + # chief. + if (_TaskType.CHIEF not in self._cluster_spec.jobs and + self._task_type == _TaskType.WORKER and self._task_id == 0): + return True + return False + + def wait_for_other_workers(self): + """Waits for other workers to reach the same call to this method. + + Raises: + ValueError: if `worker_barrier` is not passed to the __init__ method. + """ + if not self._worker_barrier: + raise ValueError("`worker_barrier is not set in the worker context.` \t" + + self._debug_message()) + self._worker_barrier.wait() + + @property + def has_barrier(self): + """Whether the barrier is set or not.""" + return self._worker_barrier is not None + + @property + def distributed_mode(self): + """Whether it is distributed training or not.""" + return bool(self._cluster_spec) and self._task_type != _TaskType.EVALUATOR + + @property + def cluster_spec(self): + """Returns a copy of the cluster_spec object.""" + return copy.deepcopy(self._cluster_spec) + + @property + def task_type(self): + """Returns the role of the corresponing task.""" + return self._task_type + + @property + def task_id(self): + """Returns the id or index of the corresponing task.""" + return self._task_id + + @property + def master_target(self): + """Returns the session master for the corresponding task to connect to.""" + return self._master_target + + @property + def is_chief(self): + """Returns whether the task is a chief node.""" + return self._is_chief_node + + @property + def num_workers(self): + """Returns number of workers in the cluster, including chief.""" + return self._num_workers + + +def _run_single_worker(worker_fn, + cluster_spec, + task_type, + task_id, + rpc_layer, + worker_barrier=None): + """Runs a single worker by calling `worker_fn` under context.""" + with _WorkerContext( + cluster_spec, + task_type, + task_id, + rpc_layer=rpc_layer, + worker_barrier=worker_barrier): + worker_fn() + + +def _run_std_server(cluster_spec=None, + task_type=None, + task_id=None, + session_config=None, + rpc_layer=None): + """Runs a standard server.""" + server = server_lib.Server( + cluster_spec, + job_name=task_type, + task_index=task_id, + config=session_config, + protocol=rpc_layer) + server.start() + return server + + +def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer): + """Runs a standalone client for between-graph replication.""" + eval_thread = None + if _TaskType.EVALUATOR in cluster_spec.jobs: + eval_thread = threading.Thread( + target=_run_single_worker, + args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0), + kwargs={ + "rpc_layer": rpc_layer, + }) + eval_thread.start() + + threads = [] + worker_barrier = _Barrier(_get_num_workers(cluster_spec)) + for task_type in [_TaskType.CHIEF, _TaskType.WORKER]: + for task_id in range(len(cluster_spec.as_dict().get(task_type, []))): + t = threading.Thread( + target=_run_single_worker, + args=(worker_fn, cluster_spec, task_type, task_id), + kwargs={ + "rpc_layer": rpc_layer, + "worker_barrier": worker_barrier + }) + t.start() + threads.append(t) + + # TODO(yuefengz): wrap threads into thread coordinator? + for t in threads: + t.join() + + # TODO(yuefengz): is it necessary to join eval thread? + if eval_thread: + eval_thread.join() + + +def _run_in_graph_client(worker_fn, cluster_spec, rpc_layer): + """Runs a standalone client for in-graph replication.""" + eval_thread = None + if _TaskType.EVALUATOR in cluster_spec.jobs: + eval_thread = threading.Thread( + target=_run_single_worker, + args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0), + kwargs={ + "rpc_layer": rpc_layer, + }) + eval_thread.start() + + _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer) + if eval_thread: + eval_thread.join() + + +# TODO(yuefengz): propagate cluster_spec in the SPLIT_CLIENT mode. +# TODO(yuefengz): we may need a smart way to figure out whether the current task +# is the special task when we support cluster_spec propagation. +def run_distribute_coordinator(worker_fn, + mode=CoordinatorMode.SPLIT_CLIENT, + cluster_spec=None, + task_type=None, + task_id=None, + between_graph=False, + rpc_layer="grpc"): + """Runs the coordinator for distributed TensorFlow. + + This function runs a split coordinator for distributed TensorFlow in its + default mode, i.e the SPLIT_CLIENT mode. Given a `cluster_spec` specifying + server addresses and their roles in a cluster, this coordinator will figure + out how to set them up, give the underlying function the right targets for + master sessions via a scope object and coordinate their training. The cluster + consisting of standard servers needs to be brought up either with the standard + server binary or with a binary running distribute coordinator with `task_type` + set to non-client type which will then turn into standard servers. + + In addition to be the distribute coordinator, this is also the source of + configurations for each job in the distributed training. As there are multiple + ways to configure a distributed TensorFlow cluster, its context object + provides these configurations so that users or higher-level APIs don't have to + figure out the configuration for each job by themselves. + + In the between-graph replicated training, this coordinator will create + multiple threads and each calls the `worker_fn` which is supposed to create + its own graph and connect to one worker master given by its context object. In + the in-graph replicated training, it has only one thread calling this + `worker_fn`. + + Another mode is the INDEPENDENT_WORKER mode where each server runs a + distribute coordinator which will start a standard server and optionally runs + `worker_fn` depending whether it is between-graph training or in-graph + replicated training. + + The `worker_fn` defines the training logic and is called under a its own + worker context which can be accessed to via `get_current_worker_context`. A + worker context provides access to configurations for each task, e.g. the + task_type, task_id, master target and so on. Since `worker_fn` will be called + in a thread and possibly multiple times, caller should be careful when it + accesses global data. For example, it is unsafe to define flags in a + `worker_fn` or to define different environment variables for different + `worker_fn`s. + + The `worker_fn` for the between-graph replication is defined as if there is + only one worker corresponding to the `worker_fn` and possibly ps jobs. For + example, when training with parameter servers, it assigns variables to + parameter servers and all other operations to that worker. In the in-graph + replication case, the `worker_fn` has to define operations for all worker + jobs. Using a distribution strategy can simplify the `worker_fn` by not having + to worry about the replication and device assignment of variables and + operations. + + This method is intended to be invoked by high-level APIs so that users don't + have to explictly call it to run this coordinator. For those who don't use + high-level APIs, to change a program to use this coordinator, wrap everything + in a the program after global data definitions such as commandline flag + definition into the `worker_fn` and get task-specific configurations from + the worker context. + + The `cluster_spec` can be either passed by the argument or parsed from the + "TF_CONFIG" envrionment variable. Example of a TF_CONFIG: + ``` + cluster = {'chief': ['host0:2222'], + 'ps': ['host1:2222', 'host2:2222'], + 'worker': ['host3:2222', 'host4:2222', 'host5:2222']} + os.environ['TF_CONFIG'] = json.dumps({'cluster': cluster}) + ``` + + If `cluster_spec` is not given in any format, it becomes local training and + this coordinator will connect to a local session. + + For evaluation, if "evaluator" exist in the cluster_spec, a separate thread + will be created with its `task_type` set to "evaluator". If "evaluator" is not + set in the cluster_spec, it entirely depends on the `worker_fn` for how to do + evaluation. + + Args: + worker_fn: the function to be called and given the access to a coordinator + context object. + mode: in which mode this distribute coordinator runs. + cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles + in a cluster. If not set or empty, fall back to local training. + task_type: the current task type, optional if this is a client. + task_id: the current task id, optional if this is a client. + between_graph: a boolean. It is only useful when `cluster_spec` is set and + not empty. If true, it will use between-graph replicated training; + otherwise it will use in-graph replicated training. + rpc_layer: optional string, the protocol for RPC, e.g. "grpc". + + Raises: + ValueError: if `cluster_spec` is supplied but not a dict or a ClusterDef or + a ClusterSpec. + """ + tf_config = json.loads(os.environ.get("TF_CONFIG", "{}")) + if not cluster_spec: + cluster_spec = tf_config.get("cluster", {}) + task_env = tf_config.get("task", {}) + if task_env: + task_type = task_env.get("type", task_type) + task_id = int(task_env.get("index", task_id)) + + if cluster_spec: + if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)): + cluster_spec = server_lib.ClusterSpec(cluster_spec) + elif not isinstance(cluster_spec, server_lib.ClusterSpec): + raise ValueError( + "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " + "`tf.train.ClusterDef` object") + # TODO(yuefengz): validate cluster_spec. + + if not cluster_spec: + # `mode` is ignored in the local case. + _run_single_worker(worker_fn, None, None, None, rpc_layer) + elif mode == CoordinatorMode.SPLIT_CLIENT: + # The client must know the cluster but servers in the cluster don't have to + # know the client. + if task_type in [_TaskType.CLIENT, None]: + if between_graph: + _run_between_graph_client(worker_fn, cluster_spec, rpc_layer) + else: + _run_in_graph_client(worker_fn, cluster_spec, rpc_layer) + else: + # If not a client job, run the standard server. + server = _run_std_server( + cluster_spec=cluster_spec, task_type=task_type, task_id=task_id) + server.join() + else: + if mode != CoordinatorMode.INDEPENDENT_WORKER: + raise ValueError("Unexpected coordinator mode: %r" % mode) + + # Every one starts a standard server. + server = _run_std_server( + cluster_spec=cluster_spec, task_type=task_type, task_id=task_id) + + if task_type in [_TaskType.CHIEF, _TaskType.WORKER]: + if between_graph: + # All jobs run `worker_fn` if between-graph. + _run_single_worker(worker_fn, cluster_spec, task_type, task_id, + rpc_layer) + else: + # Only one node runs `worker_fn` if in-graph. + context = _WorkerContext(cluster_spec, task_type, task_id, rpc_layer) + if context.is_chief: + _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer) + else: + server.join() + elif task_type == _TaskType.EVALUATOR: + _run_single_worker(worker_fn, cluster_spec, task_type, task_id, rpc_layer) + else: + if task_type != _TaskType.PS: + raise ValueError("Unexpected task_type: %r" % task_type) + server.join() diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..319c29ba2fae9829526fde03e02f5a77b8c4d46c --- /dev/null +++ b/tensorflow/python/distribute/distribute_coordinator_test.py @@ -0,0 +1,559 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for distribute coordinator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import copy +import os +import sys +import threading +import six + +# pylint: disable=invalid-name +_portpicker_import_error = None +try: + import portpicker # pylint: disable=g-import-not-at-top +except ImportError as _error: + _portpicker_import_error = _error + portpicker = None +# pylint: enable=invalid-name + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.distribute import distribute_coordinator +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + +CHIEF = distribute_coordinator._TaskType.CHIEF +WORKER = distribute_coordinator._TaskType.WORKER +PS = distribute_coordinator._TaskType.PS +EVALUATOR = distribute_coordinator._TaskType.EVALUATOR + +SPLIT_CLIENT = distribute_coordinator.CoordinatorMode.SPLIT_CLIENT +INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER + +RUN_STD_SERVER_METHOD = "tensorflow.python.distribute.distribute_coordinator._run_std_server" + +NUM_WORKERS = 3 +NUM_PS = 2 + + +def _bytes_to_str(maybe_bytes): + if isinstance(maybe_bytes, six.string_types): + return maybe_bytes + else: + return str(maybe_bytes, "utf-8") + + +def _strip_protocol(target): + # cluster_spec expects "host:port" strings. + if "//" in target: + return target.split("//")[1] + else: + return target + + +class MockServer(object): + + def __init__(self): + self._joined = False + + def join(self): + assert not self._joined + self._joined = True + + @property + def joined(self): + return self._joined + + +class DistributeCoordinatorTestBase(test.TestCase): + + @classmethod + def setUpClass(cls): + # We have to create a global in-process cluster because once an in-process + # tensorflow server is created, there is no way to terminate it. Please see + # multi_worker_test_base.py for more details. + cls._workers, cls._ps = test_util.create_local_cluster( + NUM_WORKERS, num_ps=NUM_PS) + cls._cluster_spec = { + WORKER: [ + _strip_protocol(_bytes_to_str(w.target)) for w in cls._workers + ], + PS: [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps] + } + + def setUp(self): + self._result_correct = 0 + self._lock = threading.Lock() + self._worker_context = {} + self._std_servers = {} + self._barrier = distribute_coordinator._Barrier(NUM_WORKERS) + + @contextlib.contextmanager + def _test_session(self, target): + config = config_pb2.ConfigProto(allow_soft_placement=True) + config.graph_options.optimizer_options.opt_level = -1 + with session.Session(graph=None, config=config, target=target) as sess: + yield sess + + def _create_cluster_spec(self, + has_chief=False, + num_workers=1, + num_ps=0, + has_eval=False): + if _portpicker_import_error: + raise _portpicker_import_error # pylint: disable=raising-bad-type + + cluster_spec = {} + if has_chief: + cluster_spec[CHIEF] = ["localhost:%s" % portpicker.pick_unused_port()] + if num_workers: + cluster_spec[WORKER] = [ + "localhost:%s" % portpicker.pick_unused_port() + for _ in range(num_workers) + ] + if num_ps: + cluster_spec[PS] = [ + "localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps) + ] + if has_eval: + cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()] + return cluster_spec + + def _in_graph_worker_fn(self): + context = distribute_coordinator.get_current_worker_context() + self.assertTrue(context is not None) + with self._test_session(target=context.master_target) as sess: + xs = [] + expected = 0.0 + for i in range(context.num_workers): + with ops.device("/job:worker/task:%d" % i): + x = variable_scope.get_variable("x_%d" % i, initializer=10.0) + x_add = x.assign_add(float(i)) + xs.append(x_add) + expected += i + 10.0 + + with ops.device("/job:worker/task:0"): + result = math_ops.add_n(xs) + + variables.global_variables_initializer().run() + result_value = sess.run(result) + self.assertEqual(result_value, expected) + if result_value == expected: + self._result_correct += 1 + + def _run_coordinator_in_thread(self, worker_fn, **kwargs): + t = threading.Thread( + target=distribute_coordinator.run_distribute_coordinator, + args=(worker_fn,), + kwargs=kwargs) + t.start() + return t + + def _run_multiple_coordinator_in_threads(self, worker_fn, cluster_spec, + **kwargs): + threads = {} + for task_type in cluster_spec.keys(): + threads[task_type] = [] + for task_id in range(len(cluster_spec[task_type])): + t = self._run_coordinator_in_thread( + worker_fn, + cluster_spec=cluster_spec, + task_type=task_type, + task_id=task_id, + **kwargs) + threads[task_type].append(t) + return threads + + def _between_graph_worker_fn(self): + context = distribute_coordinator.get_current_worker_context() + self.assertTrue(context is not None) + with self._test_session(target=context.master_target) as sess: + with ops.device("/job:ps/task:0"): + # TODO(yuefengz): investigate why not using resource variable will make + # the test flaky. + x = variable_scope.get_variable( + "x", initializer=10.0, use_resource=True) + with ops.device("/job:ps/task:1"): + y = variable_scope.get_variable( + "y", initializer=20.0, use_resource=True) + + x_add = x.assign_add(2.0) + y_sub = y.assign_sub(2.0) + train_op = control_flow_ops.group([x_add, y_sub]) + + if context.is_chief: + variables.global_variables_initializer().run() + + # Synchronize workers after initializaton. + if context.has_barrier: + context.wait_for_other_workers() + else: + while True: + uninit_vars = sess.run(variables.report_uninitialized_variables()) + # pylint: disable=g-explicit-length-test + if len(uninit_vars) == 0: + break + + sess.run(train_op) + + # Synchronize workers after one step to make sure they all have finished + # training. + if context.has_barrier: + context.wait_for_other_workers() + else: + self._barrier.wait() + + x_val, y_val = sess.run([x, y]) + + self.assertEqual(x_val, 16.0) + self.assertEqual(y_val, 14.0) + if x_val == 16.0 and y_val == 14.0: + with self._lock: + self._result_correct += 1 + + def _dump_worker_context(self): + """Dumps the propoerties of each worker context. + + It dumps the context properties to a dict mapping from task_type to a list + of tuples of master_target, num_workers, is_chief and distribute_mode, where + the list is indexed by the task_id. + """ + context = distribute_coordinator.get_current_worker_context() + self.assertTrue(context is not None) + task_type = str(context.task_type) + task_id = context.task_id or 0 + with self._lock: + if task_type not in self._worker_context: + self._worker_context[task_type] = [] + while len(self._worker_context[task_type]) <= task_id: + self._worker_context[task_type].append(None) + self._worker_context[task_type][task_id] = (context.master_target, + context.num_workers, + context.is_chief, + context.distributed_mode) + + def _run_mock_std_server(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None, + rpc_layer=None): + task_type = str(task_type) + task_id = task_id or 0 + with self._lock: + if task_type not in self._std_servers: + self._std_servers[task_type] = [] + while len(self._std_servers[task_type]) <= task_id: + self._std_servers[task_type].append(None) + + server = MockServer() + self._std_servers[task_type][task_id] = server + return server + + +class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase): + + def testInGraphSplitMode(self): + """Test it runs in-graph replication in split client mode.""" + distribute_coordinator.run_distribute_coordinator( + self._in_graph_worker_fn, + cluster_spec=self._cluster_spec, + between_graph=False) + self.assertEqual(self._result_correct, 1) + + def testBetweenGraph(self): + """Test it runs between-graph replication in split client mode.""" + distribute_coordinator.run_distribute_coordinator( + self._between_graph_worker_fn, + cluster_spec=self._cluster_spec, + between_graph=True) + + # Each finished worker will increment self._result_correct. + self.assertEqual(self._result_correct, NUM_WORKERS) + + def testBetweenGraphContext(self): + # Dumps the task contexts to the self._worker_context dict. + distribute_coordinator.run_distribute_coordinator( + self._dump_worker_context, + cluster_spec=self._cluster_spec, + between_graph=True) + + # There is only one type of task and there three such tasks. + self.assertEqual(len(self._worker_context), 1) + self.assertTrue(WORKER in self._worker_context) + self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS) + + # Check whether each task has the right master_target, num_workers, is_chief + # and distributed_mode. + self.assertEqual( + self._worker_context[WORKER][0], + (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True)) + self.assertEqual( + self._worker_context[WORKER][1], + (_bytes_to_str(self._workers[1].target), NUM_WORKERS, False, True)) + self.assertEqual( + self._worker_context[WORKER][2], + (_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True)) + + def testInGraphContext(self): + # Dumps the task contexts to the self._worker_context dict. + distribute_coordinator.run_distribute_coordinator( + self._dump_worker_context, + cluster_spec=self._cluster_spec, + between_graph=False) + + # There is only a "None" task in the dumped task context. + self.assertEqual(len(self._worker_context), 1) + self.assertTrue("None" in self._worker_context) + self.assertEqual(len(self._worker_context["None"]), 1) + + # Check whether each task has the right master_target, num_workers, is_chief + # and distributed_mode. + self.assertEqual( + self._worker_context["None"][0], + (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True)) + + def testLocalContext(self): + # Dumps the task contexts to the self._worker_context dict. + distribute_coordinator.run_distribute_coordinator( + self._dump_worker_context, cluster_spec=None, between_graph=True) + + # There is only a "None" task. + self.assertEqual(len(self._worker_context), 1) + self.assertTrue("None" in self._worker_context) + self.assertEqual(len(self._worker_context["None"]), 1) + + # Check whether each task has the right master_target, num_workers, is_chief + # and distributed_mode. + self.assertEqual(self._worker_context["None"][0], ("local", 0, True, False)) + + def testBetweenGraphContextWithChief(self): + # Adds a chief node, so there are NUM_WORKERS + 1 workers in total. + cluster_spec = copy.deepcopy(self._cluster_spec) + cluster_spec[CHIEF] = ["fake_chief"] + + # Dumps the task contexts to the self._worker_context dict. + distribute_coordinator.run_distribute_coordinator( + self._dump_worker_context, + cluster_spec=cluster_spec, + between_graph=True, + rpc_layer="grpc") + + # There are one CHIEF and three workers. + self.assertEqual(len(self._worker_context), 2) + self.assertTrue(CHIEF in self._worker_context) + self.assertTrue(WORKER in self._worker_context) + self.assertEqual(len(self._worker_context[CHIEF]), 1) + self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS) + + # Check whether each task has the right master_target, num_workers, is_chief + # and distributed_mode. + self.assertEqual(self._worker_context[CHIEF][0], + ("grpc://fake_chief", 4, True, True)) + self.assertEqual( + self._worker_context[WORKER][0], + (_bytes_to_str(self._workers[0].target), NUM_WORKERS + 1, False, True)) + self.assertEqual( + self._worker_context[WORKER][1], + (_bytes_to_str(self._workers[1].target), NUM_WORKERS + 1, False, True)) + self.assertEqual( + self._worker_context[WORKER][2], + (_bytes_to_str(self._workers[2].target), NUM_WORKERS + 1, False, True)) + + def testInGraphContextWithEval(self): + # Adds a EVALUATOR job. + cluster_spec = copy.deepcopy(self._cluster_spec) + cluster_spec[EVALUATOR] = ["fake_evaluator"] + + # Dumps the task contexts to the self._worker_context dict. + distribute_coordinator.run_distribute_coordinator( + self._dump_worker_context, + cluster_spec=cluster_spec, + between_graph=False, + rpc_layer=None) + + # There are one "None" task and one EVALUATOR task. + self.assertEqual(len(self._worker_context), 2) + self.assertTrue("None" in self._worker_context) + self.assertTrue(EVALUATOR in self._worker_context) + self.assertEqual(len(self._worker_context["None"]), 1) + self.assertEqual(len(self._worker_context[EVALUATOR]), 1) + + # Check whether each task has the right master_target, num_workers, is_chief + # and distributed_mode. + self.assertEqual(self._worker_context["None"][0], (_strip_protocol( + _bytes_to_str(self._workers[0].target)), 3, True, True)) + self.assertEqual(self._worker_context[EVALUATOR][0], + ("fake_evaluator", 3, True, False)) + + +class DistributeCoordinatorTestInpendentWorkerMode( + DistributeCoordinatorTestBase): + + def testInGraph(self): + cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS) + threads = self._run_multiple_coordinator_in_threads( + self._in_graph_worker_fn, + cluster_spec, + between_graph=False, + mode=INDEPENDENT_WORKER) + threads[WORKER][0].join() + self.assertEqual(self._result_correct, 1) + + def testBetweenGraph(self): + cluster_spec = self._create_cluster_spec( + num_workers=NUM_WORKERS, num_ps=NUM_PS) + threads = self._run_multiple_coordinator_in_threads( + self._between_graph_worker_fn, + cluster_spec, + between_graph=True, + mode=INDEPENDENT_WORKER) + for task_id in range(NUM_WORKERS): + threads[WORKER][task_id].join() + + # Each finished worker will increment self._result_correct. + self.assertEqual(self._result_correct, NUM_WORKERS) + + def testBetweenGraphContext(self): + cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS) + # Dumps the task contexts and std server arguments. + with test.mock.patch.object(distribute_coordinator, "_run_std_server", + self._run_mock_std_server): + threads = self._run_multiple_coordinator_in_threads( + self._dump_worker_context, + cluster_spec, + mode=INDEPENDENT_WORKER, + between_graph=True, + rpc_layer=None) + for task_id in range(NUM_WORKERS): + threads[WORKER][task_id].join() + + # There is only one type of task and three such tasks. + self.assertEqual(len(self._worker_context), 1) + self.assertTrue(WORKER in self._worker_context) + self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS) + + # Check whether each task has the right master_target, num_workers, is_chief + # and distributed_mode. + self.assertEqual( + self._worker_context[WORKER][0], + (_bytes_to_str(cluster_spec[WORKER][0]), NUM_WORKERS, True, True)) + self.assertEqual( + self._worker_context[WORKER][1], + (_bytes_to_str(cluster_spec[WORKER][1]), NUM_WORKERS, False, True)) + self.assertEqual( + self._worker_context[WORKER][2], + (_bytes_to_str(cluster_spec[WORKER][2]), NUM_WORKERS, False, True)) + + # Make sure each worker runs a std server. + self.assertEqual(len(self._std_servers), 1) + self.assertTrue(WORKER in self._std_servers) + self.assertEqual(len(self._std_servers[WORKER]), 3) + self.assertFalse(self._std_servers[WORKER][0].joined) + self.assertFalse(self._std_servers[WORKER][1].joined) + self.assertFalse(self._std_servers[WORKER][2].joined) + + def testInGraphContext(self): + cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS) + # Dumps the task contexts and std server arguments. + with test.mock.patch.object(distribute_coordinator, "_run_std_server", + self._run_mock_std_server): + threads = self._run_multiple_coordinator_in_threads( + self._dump_worker_context, + cluster_spec, + mode=INDEPENDENT_WORKER, + between_graph=False, + rpc_layer=None) + for task_id in range(NUM_WORKERS): + threads[WORKER][task_id].join() + + # There is only a "None" task in the dumped task context. + self.assertEqual(len(self._worker_context), 1) + self.assertTrue("None" in self._worker_context) + self.assertEqual(len(self._worker_context["None"]), 1) + + # Check whether each task has the right master_target, num_workers, is_chief + # and distributed_mode. + self.assertEqual( + self._worker_context["None"][0], + (_bytes_to_str(cluster_spec[WORKER][0]), NUM_WORKERS, True, True)) + + # Make sure each worker runs a std server. + self.assertEqual(len(self._std_servers), 1) + self.assertTrue(WORKER in self._std_servers) + self.assertEqual(len(self._std_servers[WORKER]), 3) + self.assertFalse(self._std_servers[WORKER][0].joined) + self.assertTrue(self._std_servers[WORKER][1].joined) + self.assertTrue(self._std_servers[WORKER][2].joined) + + def testInGraphContextWithEval(self): + # Adds a EVALUATOR job. + cluster_spec = self._create_cluster_spec( + num_workers=NUM_WORKERS, has_eval=True) + + # Dumps the task contexts and std server arguments. + with test.mock.patch.object(distribute_coordinator, "_run_std_server", + self._run_mock_std_server): + threads = self._run_multiple_coordinator_in_threads( + self._dump_worker_context, + cluster_spec, + mode=INDEPENDENT_WORKER, + between_graph=False, + rpc_layer=None) + for task_id in range(NUM_WORKERS): + threads[WORKER][task_id].join() + threads[EVALUATOR][0].join() + + # There are one "None" task and one EVALUATOR task. + self.assertEqual(len(self._worker_context), 2) + self.assertTrue("None" in self._worker_context) + self.assertTrue(EVALUATOR in self._worker_context) + self.assertEqual(len(self._worker_context["None"]), 1) + self.assertEqual(len(self._worker_context[EVALUATOR]), 1) + + # Check whether each task has the right master_target, num_workers, is_chief + # and distributed_mode. + self.assertEqual(self._worker_context["None"][0], + (_bytes_to_str(cluster_spec[WORKER][0]), 3, True, True)) + self.assertEqual(self._worker_context[EVALUATOR][0], + (cluster_spec[EVALUATOR][0], 3, True, False)) + + # Make sure each worker runs a std server. + self.assertEqual(len(self._std_servers), 2) + self.assertTrue(WORKER in self._std_servers) + self.assertTrue(EVALUATOR in self._std_servers) + self.assertEqual(len(self._std_servers[WORKER]), 3) + self.assertEqual(len(self._std_servers[EVALUATOR]), 1) + self.assertFalse(self._std_servers[WORKER][0].joined) + self.assertTrue(self._std_servers[WORKER][1].joined) + self.assertTrue(self._std_servers[WORKER][2].joined) + self.assertFalse(self._std_servers[EVALUATOR][0].joined) + + +if __name__ == "__main__": + # TODO(yuefengz): find a smart way to terminite std server threads. + with test.mock.patch.object(sys, "exit", os._exit): + test.main() diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index dee86966f1bb08540c69f158e13ce6a288bd9821..de93b1e2e15037a295aa47c9de50378bdece5769 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -32,6 +32,7 @@ cc_library( "//tensorflow/python:numpy_lib", "//tensorflow/python:py_seq_tensor", "//tensorflow/python:safe_ptr", + "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", ], ) @@ -248,6 +249,7 @@ py_library( "//tensorflow/python/eager:execute", "//tensorflow/python/eager:tape", "//third_party/py/numpy", + "@six_archive//:six", ], ) @@ -321,6 +323,7 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:random_ops", + "//tensorflow/python/keras", ], ) @@ -391,3 +394,21 @@ py_library( srcs = ["imperative_grad.py"], srcs_version = "PY2AND3", ) + +cuda_py_test( + name = "memory_test", + size = "medium", + srcs = ["memory_test.py"], + additional_deps = [ + "//tensorflow/python/eager:backprop", + "//tensorflow/python/keras", + "//tensorflow/python/eager:test", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "@six_archive//:six", + ], + tags = [ + "optonly", # The test is too slow in non-opt mode + ], +) diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index bd97b181ff7fa5a38ea8ab16e55b3ade7b599261..553f761a1409112773537b0a1eb9b7b5399533b9 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -276,7 +276,7 @@ def implicit_grad(f): def _get_arg_spec(f, params, param_args): """The positions of the parameters of f to be differentiated in param_args.""" try: - args = tf_inspect.getargspec(f).args + args = tf_inspect.getfullargspec(f).args except TypeError as e: # TypeError can happen when f is a callable object. if params is None: @@ -591,29 +591,34 @@ def _num_elements(grad): raise ValueError("`grad` not a Tensor or IndexedSlices.") -_zeros_cache = context._TensorCache() # pylint: disable=protected-access - - def _fast_fill(value, shape, dtype): return array_ops.fill(shape, constant_op.constant(value, dtype=dtype)) def _zeros(shape, dtype): - """Wraps array_ops.zeros to cache last zero for a given shape and dtype.""" - device = context.context().device_name + """Helper to return (possibly cached) zero tensors in eager mode.""" if dtype == dtypes.variant: # TODO(apassos): need to save enough information about variant tensors to do # a zeros return None + + ctx = context.context() + if not ctx.executing_eagerly(): + return array_ops.zeros(shape, dtype) + + device = ctx.device_name cache_key = shape, dtype, device - cached = _zeros_cache.get(cache_key) + cached = ctx.zeros_cache().get(cache_key) if cached is None: cached = _fast_fill(0, shape, dtype) - _zeros_cache.put(cache_key, cached) + ctx.zeros_cache().put(cache_key, cached) return cached def _ones(shape, dtype): + if not context.context().executing_eagerly(): + return array_ops.ones(shape, dtype) + if shape == (): # pylint: disable=g-explicit-bool-comparison return constant_op.constant(1, dtype=dtype) return _fast_fill(1, shape, dtype) @@ -641,10 +646,10 @@ class GradientTape(object): Operations are recorded if they are executed within this context manager and at least one of their inputs is being "watched". - Trainable variables (created by `tf.contrib.eager.Variable` or - @{tf.get_variable}, trainable=True is default in both cases) are automatically - watched. Tensors can be manually watched by invoking the `watch` method on - this context manager. + Trainable variables (created by `tf.Variable` or `tf.get_variable`, + trainable=True is default in both cases) are automatically watched. Tensors + can be manually watched by invoking the `watch` method on this context + manager. For example, consider the function `y = x * x`. The gradient at `x = 3.0` can be computed as: @@ -700,6 +705,7 @@ class GradientTape(object): self._tape = None self._persistent = persistent self._recording = False + context.context().start_step() def __enter__(self): """Enters a context inside which operations are recorded on this tape.""" @@ -711,10 +717,15 @@ class GradientTape(object): if self._recording: self._pop_tape() - def _push_tape(self): + def _push_tape(self, existing_tape=False): if self._recording: raise ValueError("Tape is already recording.") - self._tape = tape.push_new_tape(persistent=self._persistent) + if existing_tape: + if self._tape is None: + raise ValueError("There is no existing tape.") + tape.push_tape(self._tape) + else: + self._tape = tape.push_new_tape(persistent=self._persistent) self._recording = True def _pop_tape(self): @@ -723,6 +734,9 @@ class GradientTape(object): tape.pop_tape(self._tape) self._recording = False + def __del__(self): + context.context().end_step() + def watch(self, tensor): """Ensures that `tensor` is being traced by this tape. @@ -762,7 +776,7 @@ class GradientTape(object): try: yield finally: - self._push_tape() + self._push_tape(existing_tape=True) def reset(self): """Clears all information stored in this tape. diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 826c6683b9668ab892883119a533ee8d497d7b58..3d3f54b9c468fa6e47838a2d440c0330651402da 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -46,7 +46,7 @@ from tensorflow.python.training import training class BackpropTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAggregateGradients(self): def fn(x): @@ -96,6 +96,19 @@ class BackpropTest(test.TestCase): self.assertAllEqual(grads_and_vars[0][0], 1.0) self.assertAllEqual(id(grads_and_vars[0][1]), id(x)) + def testGradientInsideLoop(self): + with ops.Graph().as_default(): + v = resource_variable_ops.ResourceVariable(1.0) + + def body(_): + _ = v + 1.0 # This reads the variable inside the loop context + with backprop.GradientTape() as t: + result = v * 2 + self.assertTrue(t.gradient(result, v) is not None) + return 1.0 + + control_flow_ops.while_loop(lambda i: False, body, [1.0]) + def testWhereGradient(self): # Note: where is special because only some of its arguments are of # differentiable dtypes. @@ -223,11 +236,23 @@ class BackpropTest(test.TestCase): def testTapeStopRecording(self): with backprop.GradientTape() as t: - x = constant_op.constant(1.0) + x = resource_variable_ops.ResourceVariable(1.0) with t.stop_recording(): y = x * x self.assertEqual(t.gradient(y, x), None) + def testTapeStopStartRecording(self): + with backprop.GradientTape(persistent=True) as t: + x = resource_variable_ops.ResourceVariable(1.0) + x2 = x * 2 # This should be differentiated through. + with t.stop_recording(): + y = x2 * x2 + z = x2 * x2 + self.assertEqual(t.gradient(y, x2), None) + + # If the x*2 was not differentiated through, this would be 2.0, not 4.0 + self.assertEqual(t.gradient(z, x2).numpy(), 4.0) + def testTapeReset(self): with backprop.GradientTape() as t: v = resource_variable_ops.ResourceVariable(1.0) @@ -251,7 +276,7 @@ class BackpropTest(test.TestCase): g, = backprop.gradients_function(loss, [0])(logits, labels) self.assertAllEqual(g.numpy(), [[-0.5, 0.5]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientWithinTapeBlock(self): v1 = resource_variable_ops.ResourceVariable(1.) self.evaluate(v1.initializer) @@ -265,7 +290,7 @@ class BackpropTest(test.TestCase): grad = t.gradient(loss, v1) self.assertAllEqual(self.evaluate(grad), 2.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestedSelfContexts(self): v1 = resource_variable_ops.ResourceVariable(1.) self.evaluate(v1.initializer) @@ -435,7 +460,7 @@ class BackpropTest(test.TestCase): self.assertEqual(backprop.implicit_grad(f)()[0][0], None) @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientTapeRepeatedSource(self): with backprop.GradientTape(persistent=False) as g: x = constant_op.constant(3.0) @@ -445,7 +470,7 @@ class BackpropTest(test.TestCase): self.assertEqual(self.evaluate(grad), [2.0, 2.0]) @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPersistentGradientTapeRepeatedSource(self): with backprop.GradientTape(persistent=True) as g: x = constant_op.constant(3.0) @@ -459,7 +484,7 @@ class BackpropTest(test.TestCase): self.assertEqual(self.evaluate(grad), [3.0, 11.0]) @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientTapeStructure(self): with backprop.GradientTape(persistent=True) as g: # Using different constant values because constant tensors are @@ -482,7 +507,7 @@ class BackpropTest(test.TestCase): [1.0, {'x2': 2.0, 'x3': 3.0}]) @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientTape(self): with backprop.GradientTape() as g: x = constant_op.constant(3.0) @@ -497,7 +522,7 @@ class BackpropTest(test.TestCase): grad = g.gradient(y, [x])[0] self.assertEqual(self.evaluate(grad), 6.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientTapeWithCond(self): x = constant_op.constant(3.0) @@ -518,7 +543,7 @@ class BackpropTest(test.TestCase): dy = g.gradient(y, [x])[0] self.assertEqual(self.evaluate(dy), 6.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientTapeWithWhileLoop(self): i = constant_op.constant(1) x = constant_op.constant(2.) @@ -553,7 +578,7 @@ class BackpropTest(test.TestCase): g.gradient(y, [x]) @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPersistentTape(self): with backprop.GradientTape(persistent=True) as g: x = constant_op.constant(3.0) @@ -567,7 +592,7 @@ class BackpropTest(test.TestCase): del g @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testHigherOrderGradient(self): with backprop.GradientTape(persistent=True) as g: x = constant_op.constant(3.0) @@ -584,7 +609,7 @@ class BackpropTest(test.TestCase): del g @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPersistentNestedTape(self): with backprop.GradientTape(persistent=True) as g: x = constant_op.constant(3.0) @@ -605,7 +630,7 @@ class BackpropTest(test.TestCase): del g @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientTapeVariable(self): v = resource_variable_ops.ResourceVariable(1.0, name='v') self.evaluate(v.initializer) @@ -615,7 +640,7 @@ class BackpropTest(test.TestCase): self.assertAllEqual(self.evaluate(grad), 2.0) @test_util.assert_no_new_tensors - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNestedGradients(self): x = constant_op.constant(3.0) with backprop.GradientTape() as g: @@ -900,6 +925,24 @@ class BackpropTest(test.TestCase): 'did you forget to return a value from fn?'): val_and_grads_fn(x, y) + def testZerosCacheDoesntLeakAcrossGraphs(self): + with context.graph_mode(): + def get_grad(): + with ops.Graph().as_default(), self.test_session(): + t = constant_op.constant(1, dtype=dtypes.float32, shape=(10, 4)) + x = constant_op.constant(2, dtype=dtypes.float32, shape=(10, 4)) + with backprop.GradientTape() as gt: + tape.watch(x) + x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1) + y1 = x1**2 + y = array_ops.concat([y1, t], axis=1) + return self.evaluate(gt.gradient(y, x)) + + grad1 = get_grad() + grad2 = get_grad() + + self.assertAllEqual(grad1, grad2) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index 3aad4a114a710280b5046666256b6b43dc0d5523..e2b1890c2f6d2e797d2498fb239fbed1e55b00ac 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -31,14 +31,17 @@ import numpy as np import six from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python import keras from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import backprop # pylint: disable=unused-import from tensorflow.python.eager import context from tensorflow.python.eager import core from tensorflow.python.eager import function from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops @@ -70,6 +73,60 @@ def c_tfe_py_fastpath_execute(a, six.raise_from(core._status_to_exception(e.code, message), None) +class SubclassedKerasModel(keras.Model): + + def __init__(self): + super(SubclassedKerasModel, self).__init__() + self.layer_a = keras.layers.Dense( + 64, kernel_initializer="ones", bias_initializer="zeros") + self.layer_b = keras.layers.Dense( + 128, kernel_initializer="ones", bias_initializer="zeros") + self.layer_c = keras.layers.Dense( + 256, kernel_initializer="ones", bias_initializer="zeros") + self.layer_d = keras.layers.Dense( + 256, kernel_initializer="ones", bias_initializer="zeros") + self.layer_e = keras.layers.Dense( + 10, kernel_initializer="ones", bias_initializer="zeros") + + def call(self, x): + x = self.layer_a(x) + x = self.layer_b(x) + x = self.layer_c(x) + x = self.layer_d(x) + return self.layer_e(x) + + +def make_keras_model(): + model_input = keras.Input(shape=(10,)) + x = keras.layers.Dense( + 64, kernel_initializer="ones", bias_initializer="zeros")(model_input) + x = keras.layers.Dense( + 128, kernel_initializer="ones", bias_initializer="zeros")(x) + x = keras.layers.Dense( + 256, kernel_initializer="ones", bias_initializer="zeros")(x) + x = keras.layers.Dense( + 256, kernel_initializer="ones", bias_initializer="zeros")(x) + x = keras.layers.Dense( + 10, kernel_initializer="ones", bias_initializer="zeros")(x) + return keras.Model(inputs=model_input, outputs=x) + + +def make_sequential_keras_model(): + model = keras.models.Sequential() + model.add(keras.layers.Dense( + 64, kernel_initializer="ones", bias_initializer="zeros", + input_shape=(10,))) + model.add(keras.layers.Dense( + 128, kernel_initializer="ones", bias_initializer="zeros")) + model.add(keras.layers.Dense( + 256, kernel_initializer="ones", bias_initializer="zeros")) + model.add(keras.layers.Dense( + 256, kernel_initializer="ones", bias_initializer="zeros")) + model.add(keras.layers.Dense( + 10, kernel_initializer="ones", bias_initializer="zeros")) + return model + + class MicroBenchmarks(test.Benchmark): def __init__(self): @@ -115,6 +172,7 @@ class MicroBenchmarks(test.Benchmark): def func(): ops.EagerTensor(value, context=handle, device=device, dtype=dtype) + self._run(func, 30000) def benchmark_create_float_tensor_from_list_CPU(self): @@ -211,8 +269,8 @@ class MicroBenchmarks(test.Benchmark): inputs = [m] def f(): - pywrap_tensorflow.TFE_Py_Execute( - ctx_handle, None, "Identity", inputs, attrs, 1) + pywrap_tensorflow.TFE_Py_Execute(ctx_handle, None, "Identity", inputs, + attrs, 1) self._run(f, 30000) @@ -234,14 +292,13 @@ class MicroBenchmarks(test.Benchmark): def f(): with backprop.GradientTape(): pass + self._run(f, 30000) def benchmark_tf_gradient_function_no_op(self): with context.device(CPU): m = gen_array_ops.identity(self._m_2) - self._run( - lambda: backprop.gradients_function(lambda x: x, [0])(m), - 30000) + self._run(lambda: backprop.gradients_function(lambda x: x, [0])(m), 30000) def _benchmark_np_matmul(self, m, transpose_b, num_iters): a = m.cpu().numpy() @@ -255,6 +312,7 @@ class MicroBenchmarks(test.Benchmark): self._run(func, num_iters, execution_mode=execution_mode) def _benchmark_gen_math_ops_matmul(self, m, transpose_b, num_iters): + def func(): gen_math_ops.mat_mul(m, m, transpose_b=transpose_b) @@ -276,9 +334,10 @@ class MicroBenchmarks(test.Benchmark): device = context.context().device_name attrs = ("transpose_a", False, "transpose_b", transpose_b, "T", m.dtype.as_datatype_enum) + def func(): - pywrap_tensorflow.TFE_Py_Execute(ctx_handle, device, "MatMul", - inputs, attrs, 1) + pywrap_tensorflow.TFE_Py_Execute(ctx_handle, device, "MatMul", inputs, + attrs, 1) self._run(func, num_iters) @@ -505,6 +564,54 @@ class MicroBenchmarks(test.Benchmark): self._benchmark_defun_matmul( m, transpose_b=True, num_iters=self._num_iters_100_by_784) + def benchmark_defun_without_signature(self): + + def func(t1, t2, t3, t4, t5, t6, t7, t8): + del t1, t2, t3, t4, t5, t6, t7, t8 + return None + + defined = function.defun(func) + t = constant_op.constant(0.0) + cache_computation = lambda: defined(t, t, t, t, t, t, t, t) + self._run(cache_computation, 30000) + + def benchmark_defun_without_signature_and_with_kwargs(self): + + def func(t1, t2, t3, t4, t5, t6, t7, t8): + del t1, t2, t3, t4, t5, t6, t7, t8 + return None + + defined = function.defun(func) + t = constant_op.constant(0.0) + def cache_computation(): + return defined(t1=t, t2=t, t3=t, t4=t, t5=t, t6=t, t7=t, t8=t) + self._run(cache_computation, 30000) + + def benchmark_defun_with_signature(self): + + def func(t1, t2, t3, t4, t5, t6, t7, t8): + del t1, t2, t3, t4, t5, t6, t7, t8 + return None + + defined = function.defun( + func, input_signature=[tensor_spec.TensorSpec([], dtypes.float32)] * 8) + t = constant_op.constant(0.0) + signature_computation = lambda: defined(t, t, t, t, t, t, t, t) + self._run(signature_computation, 30000) + + def benchmark_defun_with_signature_and_kwargs(self): + + def func(t1, t2, t3, t4, t5, t6, t7, t8): + del t1, t2, t3, t4, t5, t6, t7, t8 + return None + + defined = function.defun( + func, input_signature=[tensor_spec.TensorSpec([], dtypes.float32)] * 8) + t = constant_op.constant(0.0) + def signature_computation(): + return defined(t1=t, t2=t, t3=t, t4=t, t5=t, t6=t, t7=t, t8=t) + self._run(signature_computation, 30000) + def benchmark_matmul_read_variable_op_2_by_2_CPU(self): with context.device(CPU): m = resource_variable_ops.ResourceVariable(self._m_2_by_2) @@ -542,6 +649,39 @@ class MicroBenchmarks(test.Benchmark): self._benchmark_read_variable_with_tape( m, num_iters=self._num_iters_2_by_2) + def benchmark_keras_model_subclassed(self): + model = SubclassedKerasModel() + data = random_ops.random_uniform((10, 10)) + + func = lambda: model(data) + # First call is more expensive (creates variables etc.), discount that. + func() + + # The whole point of this test is to contrast subclassing with + # the functional style of keras model building, so validate that + # the models are equivalent. + assert np.equal(func(), make_keras_model()(data)).all() + + self._run(func, 30000) + + def benchmark_keras_model_functional(self): + model = make_keras_model() + data = random_ops.random_uniform((10, 10)) + func = lambda: model(data) + # Symmetry with benchmark_keras_model_subclassed + func() + assert np.equal(func(), SubclassedKerasModel()(data)).all() + self._run(func, 30000) + + def benchmark_keras_model_sequential(self): + model = make_sequential_keras_model() + data = random_ops.random_uniform((10, 10)) + func = lambda: model(data) + # Symmetry with benchmark_keras_model_functional + func() + assert np.equal(func(), make_keras_model()(data)).all() + self._run(func, 30000) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 9e146f021e813886b42ca72b07122b485901a24b..6a327bd010f5d00c403c09a4f5b6fe6572fc6d9a 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -91,6 +91,7 @@ class _EagerContext(threading.local): self.summary_writer_resource = None self.scalar_cache = {} self.ones_rank_cache = _TensorCache() + self.zeros_cache = _TensorCache() self.execution_mode = None @@ -143,7 +144,11 @@ class Context(object): # TODO(agarwal): create and link in some documentation for `execution_mode`. # pylint: disable=redefined-outer-name - def __init__(self, config=None, device_policy=None, execution_mode=None): + def __init__(self, + config=None, + device_policy=None, + execution_mode=None, + server_def=None): """Creates a new Context. Args: @@ -173,6 +178,11 @@ class Context(object): - tf.contrib.eager.SYNC: executes each operation synchronously. - tf.contrib.eager.ASYNC: executes each operation asynchronously. These operations may return "non-ready" handles. + server_def: (Optional.) A tensorflow::ServerDef proto. + Enables execution on remote devices. GrpcServers need to be started by + creating an identical server_def to this, and setting the appropriate + task_indexes, so that the servers can communicate. It will then be + possible to execute operations on remote devices. Raises: ValueError: If execution_mode is not valid. @@ -192,6 +202,7 @@ class Context(object): if execution_mode is None: execution_mode = SYNC self._execution_mode = execution_mode + self._server_def = server_def # pylint: enable=redefined-outer-name @@ -215,6 +226,24 @@ class Context(object): """ return self._rng.randint(0, _MAXINT32) + def _initialize_devices(self): + """Helper to initialize devices.""" + # Store list of devices + self._context_devices = [] + device_list = pywrap_tensorflow.TFE_ContextListDevices( + self._context_handle) + try: + self._num_gpus = 0 + for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)): + dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i) + self._context_devices.append(pydev.canonical_name(dev_name)) + dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i) + if dev_type == "GPU": + self._num_gpus += 1 + + finally: + pywrap_tensorflow.TF_DeleteDeviceList(device_list) + def _initialize_handle_and_devices(self): """Initialize handle and devices.""" with self._initialize_lock: @@ -234,21 +263,50 @@ class Context(object): self._context_handle = pywrap_tensorflow.TFE_NewContext(opts) finally: pywrap_tensorflow.TFE_DeleteContextOptions(opts) - # Store list of devices - self._context_devices = [] - device_list = pywrap_tensorflow.TFE_ContextListDevices( - self._context_handle) - try: - self._num_gpus = 0 - for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)): - dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i) - self._context_devices.append(pydev.canonical_name(dev_name)) - dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i) - if dev_type == "GPU": - self._num_gpus += 1 + if self._server_def is not None: + server_def_str = self._server_def.SerializeToString() + pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle, 600, + server_def_str) - finally: - pywrap_tensorflow.TF_DeleteDeviceList(device_list) + self._initialize_devices() + + def _clear_caches(self): + self.scalar_cache().clear() + self.ones_rank_cache().flush() + self.zeros_cache().flush() + + def set_server_def(self, server_def, keep_alive_secs=600): + """Allow setting a server_def on the context. + + When a server def is replaced, it effectively clears a bunch of caches + within the context. If you attempt to use a tensor object that was pointing + to a tensor on the remote device, it will raise an error. + + Args: + server_def: A tensorflow::ServerDef proto. + Enables execution on remote devices. + keep_alive_secs: Num. seconds after which the remote end will hang up. + As long as the client is still alive, the server state for the context + will be kept alive. If the client is killed (or there is some failure), + the server will clean up its context keep_alive_secs after the final RPC + it receives. + + Raises: + ValueError: if server_def is None. + """ + if not server_def: + raise ValueError("server_def is None.") + if not self._context_handle: + self._server_def = server_def + else: + server_def_str = server_def.SerializeToString() + pywrap_tensorflow.TFE_ContextSetServerDef(self._context_handle, + keep_alive_secs, server_def_str) + + # Clear all the caches in case there are remote tensors in them. + self._clear_caches() + + self._initialize_devices() @property def _handle(self): @@ -311,6 +369,10 @@ class Context(object): """Per-device cache for scalars.""" return self._eager_context.ones_rank_cache + def zeros_cache(self): + """Per-device cache for scalars.""" + return self._eager_context.zeros_cache + @property def scope_name(self): """Returns scope name for the current thread.""" @@ -546,6 +608,12 @@ class Context(object): """Returns a stack of context switches.""" return self._context_switches + def start_step(self): + pywrap_tensorflow.TFE_ContextStartStep(self._handle) + + def end_step(self): + pywrap_tensorflow.TFE_ContextEndStep(self._handle) + _context = None _context_lock = threading.Lock() @@ -595,7 +663,7 @@ def internal_operation_seed(): def executing_eagerly(): """Returns True if the current thread has eager execution enabled. - Eager execution is typically enabled via @{tf.enable_eager_execution}, + Eager execution is typically enabled via `tf.enable_eager_execution`, but may also be enabled within the context of a Python function via tf.contrib.eager.py_func. """ @@ -722,6 +790,10 @@ def export_run_metadata(): return context().export_run_metadata() +def set_server_def(server_def): + context().set_server_def(server_def) + + # Not every user creates a Context via context.context() # (for example, enable_eager_execution in python/framework/ops.py), # but they do all import this file. Note that IS_IN_GRAPH_MODE and diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index 3fabe7060e980423268eb6f52ab4043cc4a4847c..cc765725a48631f0c50662dedf7fe7af7b30f9a3 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -610,6 +610,14 @@ class TFETest(test_util.TensorFlowTestCase): self.assertEquals(typ, dtypes.float32) self.assertIsInstance(t, ops.EagerTensor) + def testConvertMixedEagerTensorsWithVariables(self): + var = resource_variable_ops.ResourceVariable(1.0) + types, tensors = execute_lib.convert_to_mixed_eager_tensors( + ['foo', var], context.context()) + self.assertAllEqual([dtypes.string, dtypes.float32], types) + for t in tensors: + self.assertIsInstance(t, ops.EagerTensor) + class SendRecvTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py index 2ff5b8d8f489731c14d8abb81652a17026ed4935..f9b8d2cb5db9aedcd834afcde00dac3afa4008bb 100644 --- a/tensorflow/python/eager/execute.py +++ b/tensorflow/python/eager/execute.py @@ -198,11 +198,7 @@ def args_to_matching_eager(l, ctx, default_dtype=None): def convert_to_mixed_eager_tensors(values, ctx): - v = [ - t if isinstance(t, ops.EagerTensor) else ops.EagerTensor( - t, context=ctx._handle, device=ctx.device_name) # pylint: disable=protected-access - for t in values - ] + v = [ops.internal_convert_to_tensor(t, ctx=ctx) for t in values] types = [t._datatype_enum() for t in v] # pylint: disable=protected-access return types, v diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 03393bcd46c4670fcac8f01f274137ee60df0bea..7e51293724d0aad197de77b5418d05e911f8da2d 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -20,8 +20,11 @@ from __future__ import division from __future__ import print_function import collections +import functools +import threading import numpy as np +import six from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 @@ -33,64 +36,78 @@ from tensorflow.python.eager.graph_only_ops import graph_placeholder from tensorflow.python.framework import c_api_util from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.util import compat from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect + + +def create_substitute_placeholder(value, name, dtype=None): + """Creates a placeholder for `value` and propagates shape info to it.""" + # Note: setting ops.control_dependencies(None) ensures we always put + # capturing placeholders outside of any control flow context. + with ops.control_dependencies(None): + placeholder = graph_placeholder( + dtype=dtype or value.dtype, shape=value.shape, name=name) + if placeholder.dtype == dtypes_module.resource: + if isinstance(value, ops.EagerTensor): + handle_data = value._handle_data # pylint: disable=protected-access + else: + handle_data = resource_variable_ops.get_resource_handle_data(value) + if handle_data is not None and handle_data.is_set: + # pylint: disable=protected-access + pywrap_tensorflow.SetResourceHandleShapeAndType( + placeholder.graph._c_graph, placeholder._as_tf_output(), + handle_data.SerializeToString()) + # pylint: enable=protected-access + # Ensure that shapes and dtypes are propagated. + shapes, types = zip(*[(pair.shape, pair.dtype) + for pair in handle_data.shape_and_type]) + ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes] + shapes = [[d.size for d in s.dim] + if not s.unknown_rank else None for s in shapes] + pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper( + placeholder._op._graph._c_graph, # pylint: disable=protected-access + placeholder._as_tf_output(), # pylint: disable=protected-access + shapes, ranks, types) + + return placeholder def capture_value(tensor_map, value, dtype, name): """Capture a value from outside the function, to pass in as an extra arg.""" - captured_value = tensor_map.get(ops.tensor_id(value), None) + captured_value = tensor_map.get(value, None) if captured_value is None: - captured_value = graph_placeholder( - dtype=dtype or value.dtype, shape=value.shape, name=name) - if captured_value.dtype == dtypes_module.resource: - if ops._USE_C_SHAPES: # pylint: disable=protected-access - if isinstance(value, ops.EagerTensor): - handle_data = value._handle_data # pylint: disable=protected-access - else: - handle_data = resource_variable_ops.get_resource_handle_data(value) - else: - handle_data = value._handle_data # pylint: disable=protected-access - if handle_data is not None and handle_data.is_set: - # pylint: disable=protected-access - if ops._USE_C_SHAPES: - pywrap_tensorflow.SetResourceHandleShapeAndType( - captured_value.graph._c_graph, captured_value._as_tf_output(), - handle_data.SerializeToString()) - else: - captured_value._handle_data = handle_data - # pylint: enable=protected-access - # Ensure that shapes and dtypes are propagated. - shapes, types = zip(*[(pair.shape, pair.dtype) - for pair in handle_data.shape_and_type]) - ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes] - shapes = [[d.size for d in s.dim] - if not s.unknown_rank else None for s in shapes] - pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper( - captured_value._op._graph._c_graph, # pylint: disable=protected-access - captured_value._as_tf_output(), # pylint: disable=protected-access - shapes, ranks, types) - - tensor_map[ops.tensor_id(value)] = (value, captured_value) - else: - captured_value = captured_value[1] + captured_value = create_substitute_placeholder(value, name=name, + dtype=dtype) + tensor_map[value] = captured_value tape.record_operation("captured_value", [captured_value], [value], lambda x: [x]) return captured_value class CapturingGraph(ops.Graph): - """Graph used when constructing eager functions.""" + """Graph that can capture tensors from other graphs. + + Attributes: + captures: Maps external tensor -> internal tensor (e.g. input placeholder). + The entries are in the order they were captured. + """ - def __init__(self, captures): + def __init__(self): super(CapturingGraph, self).__init__() + + self.captures = collections.OrderedDict() self._building_function = True - self.captures = captures + # Map from resource tensor name to last op (in program order) which uses # this tensor. Used to enforce that execution order matches program order # for resource tensors. @@ -103,7 +120,22 @@ class CapturingGraph(ops.Graph): def clear_resource_control_flow_state(self): self._last_op_using_resource_tensor = {} + # TODO(skyewm): get rid of name and use the name of `tensor`. def capture(self, tensor, name=None): + """Capture `tensor` if it's external to this graph. + + If `tensor` is from a different graph, returns a placeholder for it. + `tensor` and the placeholder will also appears in self.captures. Multiple + calls to this method with the same `tensor` argument will return the same + placeholder. If `tensor` is from this graph, returns `tensor`. + + Args: + tensor: Tensor. May be from this FuncGraph or a different graph. + name: Optional name if a placeholder is created. + + Returns: + Tensor from this FuncGraph. + """ if isinstance(tensor, ops.EagerTensor): if name is None: name = str(ops.uid()) @@ -125,86 +157,94 @@ class CapturingGraph(ops.Graph): op_def=None, compute_shapes=True, compute_device=True): - # TODO(apassos) this should do some form of alias analysis as ops which - # forward the resources such as Identity and Switch can cause serialization - # to fail. + """Captures an external inputs before calling Graph.capture_op.""" + # This capturing logic interacts poorly with control flow contexts which + # want to replace inputs of ops far too late in the process. This can lead + # the context to get confused and try to create an Enter for an Enter. We + # can detect this here and skip the additional Enter which can confuse loop + # validation logic. + if op_type == "Enter" and inputs[0].op.type == "Enter": + if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s: + return inputs[0].op + # Calling AddValue on the control flow contexts to force creation of the + # backward accumulators in the original graph before we create placeholders + # to capture the inputs. + ctxt = ops.get_default_graph()._control_flow_context # pylint: disable=protected-access for i, inp in enumerate(inputs): - inputs[i] = self.capture(inp) + if ctxt is not None and hasattr(ctxt, "AddValue"): + inp = ctxt.AddValue(inp) + inp = self.capture(inp) + inputs[i] = inp return super(CapturingGraph, self).create_op( op_type, inputs, dtypes, input_types, name, attrs, op_def, - compute_shapes, compute_device) - - -# pylint: disable=invalid-name -class HelperContext(object): - """ControlFlowContext with a customizable AddOp method.""" - - def __init__(self, add_op_internal): - self._add_op_internal = add_op_internal - self._values = set() # control flow code sometimes updates this. - - def _AddOpInternal(self, op): - self._add_op_internal(op) - - @property - def outer_context(self): - return self._outer_context - - def GetWhileContext(self): - if self._outer_context: - return self._outer_context.GetWhileContext() - - def IsWhileContext(self): - return False - - def IsCondContext(self): - return False - - def IsXLAContext(self): - return False - - def AddOp(self, op): # pylint: disable=invalid-name - self._AddOpInternal(op) - if self._outer_context: - self._outer_context.AddOp(op) - - def AddName(self, _): - pass - - def AddInnerOp(self, op): - self._AddOpInternal(op) - if self._outer_context: - self._outer_context.AddInnerOp(op) + compute_device=compute_device) + + +class FuncGraph(CapturingGraph): + """Graph representing a function body. + + Attributes: + name: The name of the function. + + inputs: Placeholder tensors representing the inputs to this function. The + tensors are in this FuncGraph. This represents "regular" inputs as well as + captured inputs (i.e. the values of self.captures), with the regular + inputs coming first. + outputs: Tensors that will be returned by this function. The tensors are in + this FuncGraph. + structured_outputs: A possibly-nested python object which will be returned + by this function. The Tensors in this structure are the same as those of + self.outputs. Note that this structure might contain Python `None`s. + variables: Variables that should be watched during function execution. + seed: The graph-level random seed. + """ - def AddValue(self, val): - if self._outer_context: - return self._outer_context.AddValue(val) - else: - return val + def __init__(self, name, graph=None): + """Construct a new FuncGraph. - def EnterGradientColocation(self, op, gradient_uid): - """Start building a gradient colocated with an op.""" - if self._outer_context: - self._outer_context.EnterGradientColocation(op, gradient_uid) + Args: + name: the name of the function. + graph: if specified, this FuncGraph will inherit its graph key, + collections, and seed from `graph`. + """ + super(FuncGraph, self).__init__() + + self.name = name + self.inputs = [] + self.outputs = [] + self.structured_outputs = None + self.variables = [] + + if graph is not None: + # Inherit the graph key, since this is used for matching variables in + # optimizers. + self._graph_key = graph._graph_key # pylint: disable=protected-access + + # Copy the graph collections to ensure summaries and other things work. + # This lets the function access (but not mutate) collections of the + # containing graph, such as the global step and the summary writer + # collections. + for collection in graph.collections: + self.get_collection_ref(collection)[:] = graph.get_collection( + collection) + + # Copy distribution strategy scope from the containing graph as well. + self._distribution_strategy_stack = graph._distribution_strategy_stack # pylint: disable=protected-access + + if context.executing_eagerly(): + self.seed = context.global_seed() + else: + self.seed = graph.seed - def ExitGradientColocation(self, op, gradient_uid): - """Start building a gradient colocated with an op.""" - if self._outer_context: - self._outer_context.ExitGradientColocation(op, gradient_uid) + def capture(self, tensor, name=None): + """Calls CapturingGraph.capture and updates self.inputs if necessary.""" + new_capture = tensor not in self.captures + internal_tensor = super(FuncGraph, self).capture(tensor, name) - def __enter__(self): - # pylint: disable=protected-access - self._g = ops.get_default_graph() - self._outer_context = self._g._get_control_flow_context() - self._g._set_control_flow_context(self) - self._nested_contexts = ( - self._outer_context._nested_contexts - if self._outer_context is not None else None) - # pylint: enable=protected-access + if new_capture and tensor is not internal_tensor: + self.inputs.append(internal_tensor) - def __exit__(self, *_): - self._g._set_control_flow_context(self._outer_context) # pylint: disable=protected-access -# pylint: enable=invalid-name + return internal_tensor def _forward_name(n): @@ -222,11 +262,25 @@ def _inference_name(n): return "__inference_%s_%s" % (n, ops.uid()) +def _register(fn): + """Registers the function `fn`.""" + context.context().add_function(fn) + + +_xla_compile_attr = "_XlaCompile" + + # TODO(apassos) get rid of this by splitting framework.function._DefinedFunction # so it doesn't have the definition-generating logic and is just a container for # an already-defined function. class _EagerDefinedFunction(object): - """Function object with the interface of tf _DefinedFunction.""" + """Callable with the interface of `framework.function._DefinedFunction.` + + `_EagerDefinedFunction` encapsulates a function definition and its properties, + and it provides a method for calling the encapsulated function. Some Ops + take functions as attributes, which have type `func`; an instance of this + class may be provided as the value of these `func` attributes. + """ def __init__(self, name, graph, operations, inputs, outputs, attrs): """Initializes an eager defined function. @@ -257,6 +311,7 @@ class _EagerDefinedFunction(object): # It might be worth creating a convenient way to re-use status. pywrap_tensorflow.TF_FunctionSetAttrValueProto( fn, compat.as_str(name), serialized) + self._xla_compile = _xla_compile_attr in attrs # TODO(apassos) avoid creating a FunctionDef (specially to grab the # signature, but also in general it's nice not to depend on it. @@ -268,12 +323,92 @@ class _EagerDefinedFunction(object): if context.executing_eagerly(): _register(fn) self.definition = function_def - self.name = function_def.signature.name + self.name = compat.as_bytes(function_def.signature.name) self.signature = function_def.signature + self._num_outputs = len(self.signature.output_arg) + self._output_types = [o.type for o in self.signature.output_arg] self.grad_func_name = None self.python_grad_func = None self._c_func = c_api_util.ScopedTFFunction(fn) self._grad_func = None + self._graph = graph + self._stateful_ops = tuple(op for op in operations if op.op_def.is_stateful) + + def add_to_graph(self, g): + # pylint: disable=protected-access + if self.name not in g._functions: + g._add_function(self) + for f in self._graph._functions.values(): + if f.name not in g._functions: + g._add_function(f) + # pylint: enable=protected-access + + @property + def stateful_ops(self): + return self._stateful_ops + + def call(self, ctx, args, output_shapes): + """Calls this function with `args` as inputs. + + Function execution respects device annotations only if the function won't + be compiled with xla. + + Args: + ctx: a Context object + args: a list of arguments to supply this function with. + output_shapes: shapes to which outputs should be set; ignored when + executing eagerly. + + Returns: + The outputs of the function call. + """ + + executing_eagerly = ctx.executing_eagerly() + + xla_compile = self._xla_compile or (executing_eagerly and + ctx.device_spec.device_type == "TPU") + + if xla_compile: + # XLA compilation relies upon a custom kernel creator to run functions. + signature = self.signature + if executing_eagerly: + outputs = execute.execute( + str(signature.name), + num_outputs=self._num_outputs, + inputs=args, + attrs=None, + ctx=ctx) + else: + g = ops.get_default_graph() + self.add_to_graph(g) + op = g.create_op( + signature.name, + [ops.internal_convert_to_tensor(x, ctx=ctx) for x in args], + tuple(dtypes_module.DType(x.type) for x in signature.output_arg), + op_def=signature, + name="FunctionCall", + compute_shapes=False) + outputs = op.outputs + if not outputs: + return op + outputs = [outputs] if isinstance( + outputs, (ops.Tensor, type(None))) else list(outputs) + else: + # TODO(akshayka): Either remove this if the FunctionLibraryRuntime + # creates `PartitionedCallOp` kernels by default, or remove the previous + # branch if a TPU kernel is registered for `PartitionedCall`. + outputs = functional_ops.partitioned_call( + args=args, + f=self, + tout=self._output_types, + executing_eagerly=executing_eagerly) + + if executing_eagerly: + return outputs + else: + for i, shape in enumerate(output_shapes): + outputs[i].set_shape(shape) + return outputs def _map_sequence_obj_to_idx(sequence): @@ -297,8 +432,12 @@ def _flatten(sequence): return outputs +# TODO(akshayka): Perhaps rename to something more appropriate. class GraphModeFunction(object): - """Callable object representing a graph-mode function. + """Callable object encapsulating a function definition and its gradient. + + `GraphModeFunction` is a callable that encapsulates a function definition and + is differentiable under `tf.GradientTape` objects. """ def __init__(self, @@ -308,7 +447,7 @@ class GraphModeFunction(object): graph, operations, outputs, - func_outputs, + python_func_outputs, output_shapes, variables=None, attrs=None): @@ -327,9 +466,10 @@ class GraphModeFunction(object): definition. outputs: a flat list of the Tensors in the graph used as outputs to the function - func_outputs: a possibly nested python object which will be returned by - this function. The Tensors in this structure will be replaced by their - corresponding values in outputs. + python_func_outputs: a possibly nested python object which will be + returned by this function. The Tensors in this structure will be + replaced by their corresponding values in outputs. Note that this + structure might contain Python `None`s. output_shapes: List of shapes of all tensors in outputs variables: (optional) List of variables to watch during function execution. @@ -350,50 +490,68 @@ class GraphModeFunction(object): self._func_name = name self._function_def = defined_function self._num_outputs = len(defined_function.signature.output_arg) - self._ops = operations - self._func_outputs = func_outputs - self._returns = [func_outputs] if isinstance( - func_outputs, (ops.Tensor, type(None))) else _flatten(func_outputs) + self._python_func_outputs = python_func_outputs + self._python_returns = [python_func_outputs] if isinstance( + python_func_outputs, + (ops.Tensor, type(None))) else _flatten(python_func_outputs) self._output_shapes = output_shapes self._variables = variables if variables is not None else [] + # Find the variables that are components of something distributed and + # put them into a {handle_tensor -> distributed variable object} map. + self._distributed_variables = {} + strategy = distribution_strategy_context.get_distribution_strategy() + for variable in self._variables: + # If variable is not distributed, unwrap returns [variable]. + component_variables = strategy.unwrap(variable) + # Only add to the dictionary when the variable is actually distributed, + # i.e. more than one component or the component is different from the + # variable itself. component_variables cannot be empty. + if (len(component_variables) > 1 or component_variables[0] != variable): + for component_variable in component_variables: + self._distributed_variables[component_variable.handle] = variable + @property def variables(self): return self._variables def _construct_backprop_function(self): """Constructs the backprop function object for this function.""" - with self._graph.as_default(), context.graph_mode(): - c_known_ops = set() - c_captured_tensors = set() - - existing_op_len = len(self._graph.get_operations()) - filtered_outputs = [x for x in self._returns if x is not None] + filtered_outputs = [x for x in self._python_returns if x is not None] + # TODO(skyewm): use FuncGraph + backwards_graph = CapturingGraph() + backwards_graph._graph_key = self._graph._graph_key # pylint: disable=protected-access + for collection in self._graph.collections: + backwards_graph.get_collection_ref( + collection)[:] = self._graph.get_collection(collection) + backwards_graph.seed = self._graph.seed + with backwards_graph.as_default(): self._out_grad_placeholders = [ graph_placeholder(x.dtype, x.shape) for x in filtered_outputs] - in_gradients = gradients_impl.gradients( + in_gradients = gradients_impl._GradientsHelper( # pylint: disable=protected-access filtered_outputs, self._input_placeholders, - grad_ys=self._out_grad_placeholders) - for op in self._graph.get_operations()[existing_op_len:]: - if op.type in ["Variable", "VariableV2", "VarHandleOp"]: - raise ValueError("tfe.defun cannot capture variables created without " - "using tf.get_variable. Op: %s" % op) - c_known_ops.add(op) - for i in op.inputs: - if i.op not in c_known_ops: - c_captured_tensors.add(i) + grad_ys=self._out_grad_placeholders, + src_graph=self._graph) backward_outputs = tuple( grad for grad in _flatten(in_gradients) if grad is not None) output_shapes = tuple(grad.shape for grad in backward_outputs) - captures = list(sorted(c_captured_tensors, key=lambda x: x.name)) + extra_inputs = backwards_graph.captures.keys() + extra_placeholders = backwards_graph.captures.values() + forward_name = _forward_name(self._func_name) + # Note: we cannot have placeholder ops in the graph or the TPU compilation + # pass fails. + placeholder_ops = set([y.op for y in self._input_placeholders]) + function_ops = [x for x in self._graph.get_operations() + if x not in placeholder_ops] self._forward_fdef = _EagerDefinedFunction( - forward_name, self._graph, self._ops, self._input_placeholders, - filtered_outputs + captures, self._attrs) - all_inputs = self._out_grad_placeholders + captures + forward_name, self._graph, function_ops, + self._input_placeholders, filtered_outputs + list(extra_inputs), + self._attrs) + all_inputs = self._out_grad_placeholders + list(extra_placeholders) # Excluding input ops from the body as we do not intend to execute these # operations when the function is executed. all_ignored_ops = frozenset(x.op for x in all_inputs) @@ -401,11 +559,12 @@ class GraphModeFunction(object): # means rerunning the function-defining code will always define the same # function, which is useful if we serialize this etc. function_def_ops = tuple(x - for x in sorted(c_known_ops, key=lambda x: x.name) + for x in sorted(backwards_graph.get_operations(), + key=lambda x: x.name) if x not in all_ignored_ops) bname = _backward_name(self._func_name) self._backward_function = GraphModeFunction( - bname, all_inputs, [], self._graph, function_def_ops, + bname, all_inputs, [], backwards_graph, function_def_ops, backward_outputs, in_gradients, output_shapes, attrs=self._attrs) def _backprop_call(self, args): @@ -414,48 +573,28 @@ class GraphModeFunction(object): (Only records results on a tape if the function has outputs) Args: - args: The tensor inputs to the function. + args: All inputs to the function, including resolved extra inputs Returns: The call output. """ - all_args = args + self._extra_inputs - signature = self._forward_fdef.signature ctx = context.context() - if ctx.executing_eagerly(): - outputs = execute.execute( - str(signature.name), - num_outputs=len(signature.output_arg), - inputs=all_args, - attrs=None, - ctx=ctx) - if not outputs: - return None - else: - g = ops.get_default_graph() - g._add_function(self._forward_fdef) # pylint: disable=protected-access - op = g.create_op( - signature.name, - [ops.internal_convert_to_tensor(x, ctx=ctx) for x in all_args], - tuple(dtypes_module.DType(x.type) for x in signature.output_arg), - op_def=signature, - name="FunctionCall", - compute_shapes=False) - outputs = op.outputs - if not outputs: - return op - outputs = [outputs] if isinstance(outputs, ops.Tensor) else list(outputs) - for i, s in enumerate(self._output_shapes): - outputs[i].set_shape(s) - real_outputs = outputs[:len(self._returns)] - side_outputs = outputs[len(self._returns):] + outputs = self._forward_fdef.call(ctx, args, self._output_shapes) + if isinstance(outputs, ops.Operation) or outputs is None: + return outputs + + # `real_outputs` are the actual outputs of the inference graph function; + # `side_outputs` are the intermediate Tensors that were added as outputs to + # the forward graph function so that we can compute its gradient. + real_outputs = outputs[:self._num_outputs] + side_outputs = outputs[self._num_outputs:] def backward_function(*args): return self._backward_function(*(list(args) + side_outputs)) # pylint: disable=not-callable tape.record_operation( - signature.name, + self._forward_fdef.signature.name, real_outputs, - (args + self._extra_inputs), + args, backward_function) return self._build_call_outputs(real_outputs) @@ -464,8 +603,8 @@ class GraphModeFunction(object): def output_shapes(self): """The function's output shapes.""" # TODO(ebrevdo): Should we only keep the output shapes associated - # with len(self._returns) outputs? - outputs_list = nest.flatten(self._func_outputs) + # with len(self._python_returns) outputs? + outputs_list = nest.flatten(self._python_func_outputs) j = 0 for i, o in enumerate(outputs_list): if o is not None: @@ -479,12 +618,12 @@ class GraphModeFunction(object): else: outputs_list[i] = self._output_shapes[j] j += 1 - return nest.pack_sequence_as(self._func_outputs, outputs_list) + return nest.pack_sequence_as(self._python_func_outputs, outputs_list) @property def output_dtypes(self): return nest.map_structure( - lambda x: x.dtype if x is not None else None, self._func_outputs) + lambda x: x.dtype if x is not None else None, self._python_func_outputs) @property def captured_inputs(self): @@ -495,12 +634,32 @@ class GraphModeFunction(object): """Returns the name of the function in Eager-compatible format.""" return self._function_def.name.encode("utf-8") - def add_to_graph(self, g): - if self._function_def.name not in g._functions: # pylint: disable=protected-access - g._add_function(self._function_def) # pylint: disable=protected-access - for f in self._graph._functions.values(): # pylint: disable=protected-access - if f.name not in g._functions: # pylint: disable=protected-access - g._add_function(f) # pylint: disable=protected-access + def _resolve_extra_inputs(self): + """Resolve captured distributed variables to their current values. + + Some inputs can be distributed variables. Such variables yield a different + component (i.e. actual tf.Variable) variables depending on the context of + execution. + + Returns: + a list of resolved extra input tensors. + """ + if self._distributed_variables: + # Loop over each extra_inputs and check if it corresponds to something + # distributed. If so, get its _distributed_container and fetch the + # component appropriate for the current execution context. + resolved_extra_inputs = self._extra_inputs[:] + for i, extra_input in enumerate(self._extra_inputs): + distributed_var = self._distributed_variables.get(extra_input, None) + if distributed_var is not None: + # distributed variables override __getattr__ and substitute the + # right component variable. In here, `distributed_var.handle` + # actually does the equivalent of + # distributed_var.get_current_component_var().handle. + resolved_extra_inputs[i] = distributed_var.handle + return resolved_extra_inputs + + return self._extra_inputs def __call__(self, *args): """Executes the passed function in eager mode.""" @@ -508,40 +667,19 @@ class GraphModeFunction(object): if v.trainable: tape.watch_variable(v) + resolved_extra_inputs = self._resolve_extra_inputs() + tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)] + args = tensor_inputs + resolved_extra_inputs if tape.should_record(tensor_inputs) or tape.should_record( - self._extra_inputs): + resolved_extra_inputs): if self._backward_function is None: self._construct_backprop_function() - return self._backprop_call(tensor_inputs) + return self._backprop_call(args) ctx = context.context() - if ctx.executing_eagerly(): - result = execute.execute( - str(self._func_name), - num_outputs=self._num_outputs, - inputs=tensor_inputs + self._extra_inputs, - attrs=None, - ctx=ctx) - else: - g = ops.get_default_graph() - self.add_to_graph(g) - signature = self._function_def.definition.signature - args = list(tensor_inputs) + self._extra_inputs - op = g.create_op( - signature.name, - [ops.internal_convert_to_tensor(x, ctx=ctx) for x in args], - tuple(dtypes_module.DType(x.type) for x in signature.output_arg), - op_def=signature, - name="FunctionCall", - compute_shapes=False) - result = op.outputs - if not result: - return op - for i, s in enumerate(self._output_shapes): - result[i].set_shape(s) - - return self._build_call_outputs(result) + outputs = self._function_def.call(ctx, args, self._output_shapes) + return self._build_call_outputs(outputs) def _build_call_outputs(self, result): """Maps the fdef output list to actual output structure. @@ -551,11 +689,12 @@ class GraphModeFunction(object): Returns: The actual call output. """ - if self._func_outputs is None: - return None + if self._python_func_outputs is None: + return result + # Use `nest.flatten` instead of `_flatten` in order to preserve any - # IndexedSlices in `self._func_outputs`. - outputs_list = nest.flatten(self._func_outputs) + # IndexedSlices in `self._python_func_outputs`. + outputs_list = nest.flatten(self._python_func_outputs) j = 0 for i, o in enumerate(outputs_list): if o is not None: @@ -575,192 +714,478 @@ class GraphModeFunction(object): else: outputs_list[i] = result[j] j += 1 - ret = nest.pack_sequence_as(self._func_outputs, outputs_list) + ret = nest.pack_sequence_as(self._python_func_outputs, outputs_list) return ret -def _get_defun_inputs(args): - """Maps the inputs args to graph inputs.""" - ret = [] - flat_args = nest.flatten(args) - for a in flat_args: - if isinstance(a, ops.Tensor): - ret.append(graph_placeholder(a.dtype, a.shape)) - else: - ret.append(a) - return nest.pack_sequence_as(args, ret) - - -def _defun_internal(name, func, compiled, args, kwds): - """Defines and returns graph-mode version of func.""" - graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access - with context.graph_mode(): - captures = {} - tmp_graph = CapturingGraph(captures) - # Inherit the graph key, since this is used for matching variables in - # optimizers. - tmp_graph._graph_key = graph_key # pylint: disable=protected-access - # Copy the graph collections to ensure summaries and other things work. This - # lets the function access (but not mutate) collections of the containing - # graph, such as the global step and the summary writer collections. - curr_graph = ops.get_default_graph() - for collection in curr_graph.collections: - tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection( - collection) - with tmp_graph.as_default(), AutomaticControlDependencies() as a: - func_inputs = _get_defun_inputs(args) - - def convert(x): - if x is None: - return None - x = ops.convert_to_tensor_or_indexed_slices(x) - x = a.mark_as_return(x) - return x +def _get_defun_inputs_from_signature(signature): + """Maps a signature to graph-construction inputs.""" + function_inputs = [ + graph_placeholder(spec.dtype, spec.shape) + for spec in nest.flatten(signature) + ] + return nest.pack_sequence_as(signature, function_inputs) + + +def _get_defun_inputs_from_args(args): + """Maps python function args to graph-construction inputs.""" + function_inputs = [ + graph_placeholder(arg.dtype, arg.shape) if isinstance(arg, ops.Tensor) + else arg for arg in nest.flatten(args) + ] + return nest.pack_sequence_as(args, function_inputs) + + +def _trace_and_define_function(name, python_func, compiled, args, kwds, + signature=None): + """Defines and returns graph-mode version of `python_func`. + + Args: + name: an identifier for the function. + python_func: the Python function to trace. + compiled: whether the graph function should be compiled through XLA. + args: the positional args with which the Python function should be called; + ignored if a signature is provided. + kwds: the keyword args with which the Python function should be called; + ignored if a signature is provided. + signature: a possibly nested sequence of `TensorSpecs` specifying the shapes + and dtypes of the arguments. When a signature is provided, `args` and + `kwds` are ignored, and `python_func` is traced with Tensors conforming + to `signature`. If `None`, the shapes and dtypes are inferred from the + inputs. - this_tape = tape.push_new_tape() + Returns: + A GraphModeFunction. + + Raises: + TypeError: If any of `python_func`'s return values is neither `None` nor a + `Tensor`. + """ + func_graph = FuncGraph(_inference_name(name), graph=ops.get_default_graph()) + + with func_graph.as_default(), AutomaticControlDependencies() as a: + variable_scope.get_variable_scope().set_use_resource(True) + + if signature is None: + func_args = _get_defun_inputs_from_args(args) + func_kwds = _get_defun_inputs_from_args(kwds) + else: + func_args = _get_defun_inputs_from_signature(signature) + func_kwds = {} + + # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`. + func_graph.inputs.extend( + x for x in nest.flatten(func_args) + nest.flatten(func_kwds) + if isinstance(x, ops.Tensor) + ) + + # Variables to help check whether mutation happens in calling the function + # Copy the recursive list, tuple and map structure, but not base objects + func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args)) + func_kwds_before = nest.pack_sequence_as(func_kwds, nest.flatten(func_kwds)) + + def convert(x): + """Converts an argument to a Tensor.""" + if x is None: + return None try: - func_outputs = func(*func_inputs, **kwds) - func_outputs = nest.map_structure(convert, func_outputs) - finally: - tape.pop_tape(this_tape) - variables = this_tape.watched_variables() - - # Returning a closed-over tensor as an output does not trigger a - # call to convert_to_tensor, so we manually capture all such tensors. - outputs_list = _flatten(func_outputs) - func_def_outputs = [ - tmp_graph.capture(x) for x in outputs_list - if x is not None - ] - - ids = list(sorted(captures.keys())) - if ids: - extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids]) - else: - extra_inputs = [] - extra_placeholders = [] - output_shapes = tuple( - x.shape if isinstance(x, ops.Tensor) else None - for x in outputs_list) - - flat_inputs = [x for x in nest.flatten(func_inputs) - if isinstance(x, ops.Tensor)] - all_inputs = flat_inputs + list(extra_placeholders) - all_ignored_ops = frozenset(x.op for x in all_inputs) - fname = _inference_name(name) - operations = tuple(x for x in tmp_graph.get_operations() + x = ops.convert_to_tensor_or_indexed_slices(x) + except (ValueError, TypeError): + raise TypeError( + "To be compatible with tf.contrib.eager.defun, Python functions " + "must return zero or more Tensors; in compilation of %s, found " + "return value of type %s, which is not a Tensor." % + (str(python_func), type(x))) + x = a.mark_as_return(x) + return x + + this_tape = tape.push_new_tape() + try: + func_outputs = python_func(*func_args, **func_kwds) + func_outputs = nest.map_structure(convert, func_outputs) + + def check_mutation(n1, n2): + """Check if two list of arguments are exactly the same.""" + errmsg = ("Function to be traced should not modify structure of input " + "arguments. Check if your function has list and dictionary " + "operations that alter input arguments, " + "such as `list.pop`, `list.append`") + try: + nest.assert_same_structure(n1, n2) + except ValueError: + raise ValueError(errmsg) + + for arg1, arg2 in zip(nest.flatten(n1), nest.flatten(n2)): + if arg1 is not arg2: + raise ValueError(errmsg) + + check_mutation(func_args_before, func_args) + check_mutation(func_kwds_before, func_kwds) + + finally: + tape.pop_tape(this_tape) + func_graph.structured_outputs = func_outputs + variables = list(this_tape.watched_variables()) + + # Some variables captured by the tape can come from a DistributedValue. + # At call time, DistributedValue can return another variable (e.g. if + # the function is run on a different device). Thus, instead of storing + # the specific captured variable, we replace it with its distributed + # container. + strategy = distribution_strategy_context.get_distribution_strategy() + for i, variable in enumerate(variables): + # If variable is not distributed value_container returns itself. + variables[i] = strategy.value_container(variable) + + func_graph.variables = variables + + # Returning a closed-over tensor as an output does not trigger a + # call to convert_to_tensor, so we manually capture all such tensors. + func_graph.outputs.extend( + func_graph.capture(x) for x in _flatten(func_graph.structured_outputs) + if x is not None + ) + + output_shapes = tuple( + x.shape if isinstance(x, ops.Tensor) else None + for x in func_graph.outputs) + + all_ignored_ops = frozenset(x.op for x in func_graph.inputs) + operations = tuple(x for x in func_graph.get_operations() if x not in all_ignored_ops) # Register any other functions defined in the graph # TODO(ashankar): Oh lord, forgive me for this lint travesty. if context.executing_eagerly(): - for f in tmp_graph._functions.values(): # pylint: disable=protected-access + for f in func_graph._functions.values(): # pylint: disable=protected-access # TODO(ashankar): What about the gradient registry? _register(f._c_func.func) # pylint: disable=protected-access attrs = {} if compiled: - attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=True) + attrs[_xla_compile_attr] = attr_value_pb2.AttrValue(b=True) return GraphModeFunction( - fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs, - func_outputs, output_shapes, variables, attrs) + func_graph.name, func_graph.inputs, func_graph.captures.keys(), + func_graph, operations, func_graph.outputs, func_graph.structured_outputs, + output_shapes, func_graph.variables, attrs) + +_TensorType = collections.namedtuple("_TensorType", ["dtype", "shape"]) -# Defun uses this instead of Tensor as a cache key. Using dtype because -# TensorFlow graphs are not parametric wrt dtypes, and using shapes for -# performance reasons, as much TensorFlow code specializes on known shapes to -# produce slimmer graphs. -_TensorDtype = collections.namedtuple("_TensorDtype", ["dtype", "shape"]) -_ZeroDtype = collections.namedtuple("_ZeroDtype", ["dtype", "shape"]) +def _encode_arg(arg): + """A canonical representation for this argument, for use in a cache key.""" -def _cache_key(x): - """Cache key for tfe functions.""" - if isinstance(x, ops.Tensor): - return _TensorDtype(x.dtype, x._shape_tuple()) # pylint: disable=protected-access - if isinstance(x, ops.IndexedSlices): - if x.dense_shape is not None: + # `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes + # are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes + # are used for both performance reasons, as much TensorFlow code specializes + # on known shapes to produce slimmer graphs, and correctness, as some + # high-level APIs require shapes to be fully-known. + # + # TODO(akshayka): Add support for sparse tensors. + # + # pylint: disable=protected-access + if isinstance(arg, ops.Tensor): + return _TensorType(arg.dtype, arg._shape_tuple()) + elif isinstance(arg, ops.IndexedSlices): + if arg.dense_shape is not None: return tuple([ - _TensorDtype(x.values.dtype, x.values._shape_tuple()), # pylint: disable=protected-access - _TensorDtype(x.indices.dtype, x.indices._shape_tuple()), # pylint: disable=protected-access - _TensorDtype(x.dense_shape.dtype, x.dense_shape._shape_tuple()) # pylint: disable=protected-access + _TensorType(arg.values.dtype, arg.values._shape_tuple()), + _TensorType(arg.indices.dtype, arg.indices._shape_tuple()), + _TensorType(arg.dense_shape.dtype, arg.dense_shape._shape_tuple()), ]) else: return tuple([ - _TensorDtype(x.values.dtype, x.values._shape_tuple()), # pylint: disable=protected-access - _TensorDtype(x.indices.dtype, x.indices._shape_tuple()) # pylint: disable=protected-access + _TensorType(arg.values.dtype, arg.values._shape_tuple()), + _TensorType(arg.indices.dtype, arg.indices._shape_tuple()), ]) - if isinstance(x, np.ndarray): - return ("array", x.shape, tuple(x.reshape(-1))) - if isinstance(x, (list, tuple)): - return tuple([_cache_key(a) for a in x]) - if isinstance(x, dict): - return tuple(tuple([_cache_key(k), _cache_key(v)]) for k, v in x.items()) - return x + elif isinstance(arg, np.ndarray): + tensor = ops.convert_to_tensor(arg) + return _TensorType(tensor.dtype, tensor._shape_tuple()) + # pylint: enable=protected-access + elif isinstance(arg, (list, tuple)): + return tuple([_encode_arg(elem) for elem in arg]) + elif isinstance(arg, dict): + return tuple( + (_encode_arg(key), _encode_arg(arg[key])) for key in sorted(arg)) + else: + return arg -def _register(fn): - """Registers the function `fn`.""" - context.context().add_function(fn) +def _deterministic_dict_values(dictionary): + return tuple(dictionary[key] for key in sorted(dictionary)) -# TODO(apassos): better error messages for non-hashable arguments. -def named_defun(func, name, compiled=False): - """Defines a function with a given name. +class _PolymorphicFunction(object): + """Wrapper class for the graph functions defined for a Python function. See the documentation for `defun` for more information on the semantics of - this function. + defined functions. - Args: - func: the function to be wrapped. - name: the name given to it. - compiled: if true, the framework will attempt to compile func with XLA. - - Returns: - the wrapped function. + _PolymorphicFunction class is thread-compatible meaning that minimal + usage of defuns (defining and calling) is thread-safe, but if users call other + methods or invoke the base `python_function` themselves, external + synchronization is necessary. """ - arguments_to_functions = {} - def decorated(*args, **kwds): - """Decorated version of func.""" - # Macroexpand on non-Tensor arguments - cache_key = tuple(_cache_key(x) for x in args) - if any(isinstance(x, ops.EagerTensor) for x in kwds.values()): - raise ValueError("Tensor keyword arguments are not supported.") - cache_key = (cache_key, tuple(kwds.items())) + def __init__(self, + python_function, + name, + input_signature=None, + compiled=False): + """Initializes a polymorphic function. - if cache_key not in arguments_to_functions: - arguments_to_functions[cache_key] = _defun_internal( - name, func, compiled, args, kwds) - return arguments_to_functions[cache_key](*args) + Args: + python_function: the function to be wrapped. + name: the name given to it. + input_signature: a possibly nested sequence of `TensorSpec` objects + specifying the input signature of this function. If `None`, a separate + function is instantiated for each inferred input signature. + compiled: if True, the framework will attempt to compile func with XLA. + + Raises: + ValueError: if `input_signature` is not None and the `python_function`'s + argspec has keyword arguments. + TypeError: if `input_signature` contains anything other than + `TensorSpec` objects, or (if not None) is anything other than a tuple or + list. + """ - return decorated + if isinstance(python_function, functools.partial): + self._python_function = python_function.func + self._args_to_prepend = python_function.args or tuple() + self._kwds_to_include = python_function.keywords or {} + else: + self._python_function = python_function + self._args_to_prepend = tuple() + self._kwds_to_include = {} + self._name = name + self._compiled = compiled + self._arguments_to_functions = {} + self._variables = [] + + self._lock = threading.Lock() + + fullargspec = tf_inspect.getfullargspec(self._python_function) + if tf_inspect.ismethod(self._python_function): + # Remove `self`: default arguments shouldn't be matched to it. + args = fullargspec.args[1:] + else: + args = fullargspec.args + + # A cache mapping from argument name to index, for canonicalizing + # arguments that are called in a keyword-like fashion. + self._args_to_indices = {arg: i for i, arg in enumerate(args)} + # A cache mapping from arg index to default value, for canonicalization. + offset = len(args) - len(fullargspec.defaults or []) + self._arg_indices_to_default_values = { + offset + index: default + for index, default in enumerate(fullargspec.defaults or []) + } + if input_signature is None: + self._input_signature = None + else: + if fullargspec.varkw is not None or fullargspec.kwonlyargs: + raise ValueError("Cannot define a TensorFlow function from a Python " + "function with keyword arguments when " + "input_signature is provided.") + + if not isinstance(input_signature, (tuple, list)): + raise TypeError("input_signature must be either a tuple or a " + "list, received " + str(type(input_signature))) + + self._input_signature = tuple(input_signature) + self._flat_input_signature = tuple(nest.flatten(input_signature)) + if any(not isinstance(arg, tensor_spec.TensorSpec) + for arg in self._flat_input_signature): + raise TypeError("Invalid input_signature %s; input_signature must be " + "a possibly nested sequence of TensorSpec objects.") + + def __get__(self, instance, owner): + """Makes it possible to defun instance methods.""" + del owner + # `instance` here is the instance that this `_PolymorphicFunction` was + # accessed through; e.g., for + # + # class Foo(object): + # + # @function.defun + # def bar(self): + # ... + # + # foo = Foo() + # foo.bar() # `foo.bar` is a `_PolymorphicFunction` instance + # + # then `instance` will be `foo` (and `owner` will be `Foo`). + return functools.partial(self.__call__, instance) + + def _cache_key(self, args, kwds): + """Computes the cache key given inputs.""" + if self._input_signature is None: + inputs = (args, kwds) if kwds else args + cache_key = tuple(_encode_arg(arg) for arg in inputs) + else: + del args, kwds + cache_key = self._flat_input_signature + # The graph, or whether we're executing eagerly, should be a part of the + # cache key so we don't improperly capture tensors such as variables. + return cache_key + (context.executing_eagerly() or ops.get_default_graph(),) + + def _canonicalize_function_inputs(self, *args, **kwds): + """Canonicalizes `args` and `kwds`. + + Canonicalize the inputs to the Python function using its fullargspec. In + particular, we parse the varags and kwargs that this + `_PolymorphicFunction` was called with into a tuple corresponding to the + Python function's positional (named) arguments and a dictionary + corresponding to its kwargs. + + Args: + *args: The varargs this object was called with. + **kwds: The keyword args this function was called with. + + Returns: + A canonicalized ordering of the inputs. + + Raises: + ValueError: If a keyword in `kwds` cannot be matched with a positional + argument when an input signature is specified, or when the inputs + do not conform to the input signature. + """ + args = self._args_to_prepend + args + kwds = dict(kwds, **self._kwds_to_include) + # Maps from index of arg to its corresponding value, according to `args` + # and `kwds`; seeded with the default values for the named args that aren't + # in `args`. + arg_indices_to_values = { + index: default + for index, default in six.iteritems(self._arg_indices_to_default_values) + if index >= len(args) + } + consumed_args = [] + for arg, value in six.iteritems(kwds): + index = self._args_to_indices.get(arg, None) + if index is not None: + arg_indices_to_values[index] = value + consumed_args.append(arg) + elif self._input_signature is not None: + raise ValueError("Cannot define a TensorFlow function from a Python " + "function with keyword arguments when " + "input_signature is provided.") + for arg in consumed_args: + # After this loop, `kwds` will only contain true keyword arguments, as + # opposed to named arguments called in a keyword-like fashion. + kwds.pop(arg) + inputs = args + _deterministic_dict_values(arg_indices_to_values) + if self._input_signature is None: + return inputs, kwds + else: + assert not kwds + try: + nest.assert_same_structure(self._input_signature, inputs) + except (ValueError, TypeError): + raise ValueError("Structure of Python function inputs does not match " + "input_signature.") + flat_inputs = nest.flatten(inputs) + if any(not isinstance(arg, ops.Tensor) for arg in flat_inputs): + raise ValueError("When input_signature is provided, all inputs to " + "the Python function must be Tensors.") + tensor_specs = [tensor_spec.TensorSpec.from_tensor(tensor) + for tensor in flat_inputs] + if any(not spec.is_compatible_with(other) + for spec, other in zip(self._flat_input_signature, tensor_specs)): + raise ValueError("Python inputs incompatible with input_signature: " + "inputs (%s), input_signature (%s)" % + (str(inputs), str(self._input_signature))) + return inputs, {} + + def _maybe_define_function(self, *args, **kwds): + """Gets a function for these inputs, defining it if necessary. + + Args: + *args: args for the Python function. + **kwds: keywords for the Python function. + + Returns: + A graph function corresponding to the input signature implied by args and + kwds, as well as the inputs that the object should be called with. + + Raises: + ValueError: If inputs are incompatible with the input signature. + TypeError: If the function inputs include non-hashable objects + """ + + args, kwds = self._canonicalize_function_inputs(*args, **kwds) + cache_key = self._cache_key(args, kwds) + with self._lock: + try: + graph_function = self._arguments_to_functions.get(cache_key, None) + except TypeError: + raise TypeError("Arguments supplied to `defun`-generated functions " + "must be hashable.") + + if graph_function is None: + graph_function = _trace_and_define_function( + self._name, self._python_function, self._compiled, args, kwds, + self._input_signature) + self._variables.extend( + [v for v in graph_function.variables if v not in self._variables]) + self._arguments_to_functions[cache_key] = graph_function + return graph_function, (args, kwds) + + def __call__(self, *args, **kwds): + """Calls a graph function specialized for this input signature.""" + graph_function, inputs = self._maybe_define_function(*args, **kwds) + return graph_function(*inputs) + + def call_python_function(self, *args, **kwargs): + """Directly calls the wrapped python function.""" + return self._python_function(*args, **kwargs) + + @property + def variables(self): + """Returns a list of variables used in any of the defined functions.""" + return self._variables # TODO(akshayka): Remove the `compiled` flag and create a separate # API for xla compilation (`defun` is already complicated enough # as it is, and the keyword argument makes 'compiled' an overloaded concept) -def defun(func=None, compiled=False): +def defun(func=None, input_signature=None, compiled=False): """Compiles a Python function into a callable TensorFlow graph. `defun` (short for "define function") trace-compiles a Python function - composed of TensorFlow operations into a callable that executes a @{tf.Graph} - containing those operations. When eager execution is enabled, the ability to - create graphs from Python functions makes it possible to incrementally trade - off debugability and interactivity for performance. Functions compiled with - `defun` cannot be inspected with `pdb` and `print` statements; however, - executing a graph generated by `defun` sometimes takes less time and memory - than eagerly executing the corresponding Python function, since specifying - computations as graphs allows for optimizations like automatic buffer reuse - and parallelization among ops. Note that executing a `defun`-compiled function + composed of TensorFlow operations into a callable that executes a `tf.Graph` + containing those operations. The callable produced by `defun` contains only + the subgraph of TensorFlow operations that were executed when the Python + function was called with a particular input signature, defined as a list + of the shapes and dtypes of the Python function's Tensor-valued arguments and + the values of its non-Tensor Python objects. In particular, `defun` is _not_ a + compiler for arbitrary Python code. + + When eager execution is enabled, the ability to create graphs from Python + functions makes it possible to incrementally trade off debugability and + interactivity for performance. Functions compiled with `defun` cannot be + inspected with `pdb` and `print` statements; however, executing a graph + generated by `defun` sometimes takes less time and memory than eagerly + executing the corresponding Python function, since specifying computations as + graphs allows for optimizations like automatic buffer reuse and + parallelization among ops. Note that executing a `defun`-compiled function incurs a small constant overhead, so eagerly executing sufficiently small Python functions might take less time than executing their corresponding `defun`-generated graphs. - For a Python function to be compatible with `defun`, the values of its keyword - arguments cannot be Tensors and all of its arguments, including its keyword - arguments, must be hashable Python objects or lists thereof. Additionally, it - must return zero or more @{tf.Tensor} objects. + For a Python function to be compatible with `defun`, all of its arguments must + be hashable Python objects or lists thereof. The function itself may not + modify the list/map structure of its arguments. Additionally, it must return + zero or more `tf.Tensor` objects. If the Python function returns + a `tf.Variable`, its compiled version will return the value of that variable + as a `tf.Tensor`. + + Executing a graph generated by `defun` respects device annotations (i.e., + all `with tf.device` directives present in a Python function will also be + present in its corresponding graph), but it is not yet possible to execute the + generated graphs across multiple machines. _Example Usage_ @@ -824,29 +1249,74 @@ def defun(func=None, compiled=False): When using `defun`, there are subtleties regarding inputs, Python control flow, and variable creation that one should be aware of. For concreteness, let - `f` be a Python function that returns zero or more @{tf.Tensor} objects and + `f` be a Python function that returns zero or more `tf.Tensor` objects and let `F = defun(f)`. `F` builds a graph for each unique input signature it sees, Python control flow is baked into graphs, and operations related to variable initialization are automatically lifted out of the graphs that `F` generates and placed in the eager context if executing eagerly or into an outer graph otherwise. - _Tracing and Input Signatures_. - The signature of inputs supplied to `F` is defined to be a tuple of the shapes - and dtypes of Tensor-typed arguments and the values of non-Tensor arguments - and keyword arguments. Every time `F` is invoked, the signature of its inputs - are inferred. The first time `F(*args, **kwargs)` is invoked with a particular - signature, `f(*args, **kwargs)` is executed and all the TensorFlow operations - that `f` executes, along with the Tensors that flow between them, are recorded - in a TensorFlow graph. `F` caches this graph and binds it to the inputs' - signature; every subsequent invocation of `F` with inputs conforming to this - signature will immediately retrieve the cached graph and pass it to the - TensorFlow runtime for execution. - - Be aware that because `F` only logs TensorFlow operations, all non-TensorFlow - operations that `f` executes will only shape the _construction_ of the graphs - that `F` executes: They won't be executed when the graphs themselves are - executed. For example, whereas the Python function + _Input Signatures_ + By default, `F = tf.contrib.eager.defun(f)` instantiates a separate graph + for every unique sequence of the shapes and dtypes of Tensor arguments and + the values of Python objects it is invoked with. For example, calling + `F(tf.random_uniform([2])` will execute a different graph than + `F(tf.random_uniform([3])` because the two inputs have different shapes. + The first time that `F(*args, **kwargs)` is called with a particular sequence + of Tensor shapes and dtypes and Python values, it constructs a graph by + tracing the execution of `f(*args, **kwargs)`; this graph is bound to an + input signature inferred from `(*args, **kwargs)` and cached for future reuse. + + `tf.contrib.eager.defun` caches graphs for your convenience, letting you + define TensorFlow functions without explicitly specifying their signatures. + However, this policy is conservative and potentially expensive; for example, + when different invocations of your function have differently-shaped Tensor + inputs, this policy might generate more graph functions than necessary. To + eliminate such costs, `tf.contrib.eager.defun` allows you to supply an + optional `input_signature` argument specifying the shapes and dtypes of the + inputs. In particular, the shapes may be partially unspecified, with `None`s + in the unknown dimensions. When an input signature is provided, + `tf.contrib.eager.defun` will only instantiate a single graph for the + decorated Python function. The following is an example: + + ```python + import tensorflow as tf + + # The first `TensorSpec` below describes the shape and dtype of `words`, + # and the second describes the shape and dtype of `another_tensor`. Note that + # the last dimension of the `words` `TensorSpec` is left unspecified. + @tf.contrib.eager.defun(input_signature=[ + tf.contrib.eager.TensorSpec(shape=[50, 300, None], dtype=tf.float32), + tf.contrib.eager.TensorSpec(shape=[300, 100], dtype=tf.float32) + ]) + def my_sequence_model(words, another_tensor): + ... + + # Note how the third dimension of the first input can vary freely. + words = tf.random_uniform(([50, 300, 10]) + second_input = tf.random_uniform([300, 100]) + my_sequence_model(words, second_input) + + words = tf.random_uniform(([50, 300, 20]) + my_sequence_model(words, second_input) + + # Passing an input with an incompatible shape will raise an error. + words = tf.random_uniform(([50, 100, 20]) + my_sequence_model(words, second_input) # <---- This will raise an error. + + ``` + + Python functions that are compiled with an `input_signature` must only accept + Tensors as arguments and must not take unnamed keyword arguments (**kwargs). + + _Tracing_ + Be aware that because `F` only logs TensorFlow operations, all the other + Python code that `f` executes will only shape the _construction_ of the graphs + that `F` executes: the Python code won't be executed when the graphs + themselves are executed, though it will be executed every time the Python + function is traced (and a given Python function might be traced multiple + times, once for each input signature it is invoked with). For example, whereas + the Python function ```python import tensorflow as tf @@ -854,17 +1324,23 @@ def defun(func=None, compiled=False): tf.enable_eager_execution() - matrix = tf.eye(5) - # `matrix` is assumed to be a Tensor def add_noise(): - return matrix + np.random.randn(matrix.shape[0], matrix.shape[1]) + return tf.eye(5) + np.random.randn(5, 5) ``` will return a different output everytime it is invoked, the compiled function `compiled = tf.contrib.eager.defun(add_noise)` will return the same value every time it is called, since a particular random offset generated by NumPy will be inserted into the graph as a TensorFlow constant. The solution is to - replace the call to `np.random.randn` with `tf.random_normal(matrix.shape)`. + replace the call to `np.random.randn` with `tf.random_normal((5, 5))`. + + _Python Side-Effects_ + A corollary of the previous discussion on tracing is the following: If a + Python function `f` has Python side-effects, then executing `f` multiple times + will not necessarily be semantically equivalent to executing `F = + tf.contrib.eager.defun(f)` multiple times; this difference is due to the fact + that `defun` only captures the subgraph of TensorFlow operations that is + constructed when `f` is called in a graph-building context. _Python Control Flow_. The structure of many machine learning computations depend upon whether one is @@ -898,10 +1374,10 @@ def defun(func=None, compiled=False): On the other hand, because `defun` generates graphs by tracing and not by source code analysis, it fully unrolls Python `for` and `while` loops, potentially creating large graphs. If your Python function has native loops - that run for many iterations, consider replacing them with @{tf.while_loop} + that run for many iterations, consider replacing them with `tf.while_loop` operations. - When constructing graphs, @{tf.Tensor} objects cannot be used as Python + When constructing graphs, `tf.Tensor` objects cannot be used as Python `bool` objects. This means, for example, that you should replace code in `f` resembling @@ -920,7 +1396,7 @@ def defun(func=None, compiled=False): automatically lifted out of the graphs generated by `defun`. In practice, this implies that variable creation and initialization only happen the first time `F` is called, and that variables are reused every time thereafter. Many - TensorFlow APIs, like @{tf.keras.layers.Layer} objects, create variables the + TensorFlow APIs, like `tf.keras.layers.Layer` objects, create variables the first time they are called and reuse them thereafter. Automatic variable lifting makes it possible to compile these APIs without extra effort, at the cost of introducing a discrepancy between the semantics of executing Python @@ -932,7 +1408,7 @@ def defun(func=None, compiled=False): tf.enable_eager_execution() def fn(): - x = tf.contrib.eager.Variable(0.0) + x = tf.Variable(0.0) x.assign_add(1.0) return x.read_value() @@ -949,19 +1425,18 @@ def defun(func=None, compiled=False): ``` Finally, because each input signature is bound to a unique graph, if your - Python function constructs `tf.contrib.eager.Variable` objects, then each - graph constructed for that Python function will reference a unique set of - variables. To circumvent this problem, we recommend against compiling Python - functions that create `tf.contrib.eager.Variable` objects. Instead, Python - functions should either lexically close over `tf.contrib.eager.Variable` - objects or accept them as arguments, preferably encapsulated in an - object-oriented container. If you must create variables inside your Python - function and you want each graph generated for it to reference the same set of - variables, add logic to your Python function that ensures that variables are - only created the first time it is called and are reused for every subsequent - invocation; note that this is precisely what @{tf.keras.layers.Layer} objects - do, so we recommend using them to represent variable-bearing computations - whenever possible. + Python function constructs `tf.Variable` objects, then each graph constructed + for that Python function will reference a unique set of variables. To + circumvent this problem, we recommend against compiling Python functions that + create `tf.Variable` objects. Instead, Python functions should either + lexically close over `tf.Variable` objects or accept them as arguments, + preferably encapsulated in an object-oriented container. If you must create + variables inside your Python function and you want each graph generated for it + to reference the same set of variables, add logic to your Python function that + ensures that variables are only created the first time it is called and are + reused for every subsequent invocation; note that this is precisely what + `tf.keras.layers.Layer` objects do, so we recommend using them to represent + variable-bearing computations whenever possible. Args: func: function to be compiled. If `func` is None, returns a @@ -973,6 +1448,13 @@ def defun(func=None, compiled=False): def foo(...): ... + input_signature: A possibly nested sequence of + `tf.contrib.eager.TensorSpec` objects specifying the shapes and dtypes of + the Tensors that will be supplied to this function. If `None`, a separate + function is instantiated for each inferred input signature. If a + signature is specified, every input to `func` must be a `Tensor`, and + `func` cannot accept `**kwargs`. + compiled: If True, an attempt to compile `func` with XLA will be made. If it fails, function will be run normally. Experimental. Currently supported only for execution on TPUs. For the vast majority of users, @@ -991,7 +1473,9 @@ def defun(func=None, compiled=False): except AttributeError: name = "function" return tf_decorator.make_decorator( - function, named_defun(function, name, compiled=compiled)) + function, + _PolymorphicFunction( + function, name, input_signature=input_signature, compiled=compiled)) # This code path is for the `foo = tfe.defun(foo, ...)` use case if func is not None: @@ -1048,15 +1532,8 @@ def make_defun_op(func, *args, **kwds): A wrapper object which can be queried for its output properties, and which can be called directly the way a `@defun` wrapped function can. - - Raises: - ValueError: if any of the keyword arguments to `func` are `EagerTensor` - objects (not yet supported). """ - name = func.__name__ - if any(isinstance(x, ops.EagerTensor) for x in kwds.values()): - raise ValueError("Tensor keyword arguments are not supported.") - return _defun_internal(name, func, False, args, kwds) + return _trace_and_define_function(func.__name__, func, False, args, kwds) class AutomaticControlDependencies(object): @@ -1170,7 +1647,7 @@ class AutomaticControlDependencies(object): # Ensures the merge always runs ops_which_must_run.add(new_merge[0].op) if inp in last_op_using_resource_tensor: - # Ensures the switch exectutes after the previous op using the resource. + # Ensures the switch executes after the previous op using the resource. switch_op._add_control_input(last_op_using_resource_tensor[inp]) # pylint: disable=protected-access # Ensure the next op outside the cond happens after the merge. last_op_using_resource_tensor[inp] = new_merge[0].op diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index cfdbe5f079058906d18268c924819e0cc428bca7..380bcf763f076d0e669c0b0c902724c9db4ebf77 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -18,28 +18,42 @@ from __future__ import division from __future__ import print_function import collections +import functools +from multiprocessing.pool import ThreadPool +import sys +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager import tape -from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import function as tf_function from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.layers import convolutional from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import gradient_descent +from tensorflow.python.platform import test +from tensorflow.python.training import adam +from tensorflow.python.training import momentum +from tensorflow.python.training import training_ops +from tensorflow.python.util import compat +from tensorflow.python.util import nest @test_util.with_c_shapes @@ -90,6 +104,32 @@ class FunctionTest(test.TestCase): self.assertAllEqual(step(), 2.0) + def testGraphGradientVariable(self): + with ops.Graph().as_default(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + + @function.defun + def f(): + return 2.0 * v + + node = f() + grads, = gradients_impl.gradients(node, v) + v.initializer.run() + self.assertAllEqual(grads.eval(), 2.0) + self.assertEqual(grads.shape, v.shape) + + def testGraphEagerIsolation(self): + + @function.defun + def f(): + v = resource_variable_ops.ResourceVariable(1.0) + return v.read_value() + + self.assertAllEqual(f(), 1.0) + + with ops.Graph().as_default(): + self.assertEqual(f().shape, ()) + def testBasicDefunOpGraphMode(self): matmul = function.defun(math_ops.matmul) @@ -104,6 +144,73 @@ class FunctionTest(test.TestCase): out = sq_op(t) self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) + def testExecutingStatelessDefunConcurrently(self): + + @function.defun + def stateless(x): + return math_ops.multiply(2.0, x) + + pool = ThreadPool() + inputs = [constant_op.constant(1.0 * x) for x in range(100)] + outputs = [float(out) for out in pool.map(stateless, inputs)] + expected = [float(2.0 * x) for x in inputs] + self.assertSequenceEqual(outputs, expected) + + def testExecutingManyStatelessDefunsConcurrently(self): + + @function.defun + def stateless(x): + del x + return math_ops.multiply(2.0, 2.0) + + pool = ThreadPool() + # `pool.map` below instantiates 100 functions, one for each object. + outputs = [ + float(out) + for out in pool.map(stateless, [object() for _ in range(100)]) + ] + expected = [4.0] * 100 + self.assertSequenceEqual(outputs, expected) + + def testExecutingStatefulDefunConcurrently(self): + + v = resource_variable_ops.ResourceVariable(1.0) + + @function.defun + def stateful(x): + v.assign(x) + + pool = ThreadPool() + inputs = [constant_op.constant(0.0)] * 100 + pool.map(stateful, inputs) + self.assertEqual(float(v.read_value()), 0.0) + + def testExecutingManyStatefulDefunsConcurrently(self): + + v = resource_variable_ops.ResourceVariable(1.0) + + @function.defun + def stateful(x): + del x + return v.assign(0.0) + + pool = ThreadPool() + # `pool.map` below instantiates 100 functions, one for each object. + pool.map(stateful, [object() for _ in range(100)]) + self.assertEqual(float(v.read_value()), 0.0) + + def disabled_testRandomSeed(self): + + @function.defun + def f(): + return random_ops.random_normal(()) + + random_seed.set_random_seed(1) + x = f() + self.assertNotEqual(x, f()) + random_seed.set_random_seed(1) + self.assertAllEqual(f(), x) + def testNestedInputsDefunOpGraphMode(self): matmul = function.defun(math_ops.matmul) @@ -166,6 +273,59 @@ class FunctionTest(test.TestCase): self.assertEqual(fn_op.output_shapes, None) self.assertAllEqual(fn_op(x, x), None) + @test_util.run_in_graph_and_eager_modes() + def testDefunCondGradient(self): + + @function.defun + def f(x): + return control_flow_ops.cond(x > 0.5, lambda: 2 * x, lambda: 3 * x) + + with backprop.GradientTape() as t: + x = constant_op.constant(1.0) + t.watch(x) + y = f(x) + self.assertAllEqual(self.evaluate(t.gradient(y, x)), 2.0) + + @test_util.run_in_graph_and_eager_modes() + def testGraphLoopGradient(self): + + @function.defun + def f(x): + return control_flow_ops.while_loop(lambda _, i: i < 2, + lambda x, i: (2*x, i + 1), + [x, 0])[0] + + with backprop.GradientTape() as t: + x = constant_op.constant(1.0) + t.watch(x) + y = f(x) + self.assertAllEqual(self.evaluate(t.gradient(y, x)), 4.0) + + def testDefunNumpyArraysConvertedToTensors(self): + + def f(x): + return x + + x = random_ops.random_uniform([2, 2]).numpy() + defined = function.defun(f) + defined(x) + self.assertEqual(len(defined._arguments_to_functions), 1) + + x = random_ops.random_uniform([2, 2]).numpy() + defined(x) + # A NumPy array with different values but the same shape and dtype + # shouldn't trigger another function definition. + self.assertEqual(len(defined._arguments_to_functions), 1) + + def testDefunCapturedInt32(self): + x = constant_op.constant(1, dtype=dtypes.int32) + + @function.defun + def add_int32s(): + return x + x + + self.assertEqual(2, int(add_int32s())) + def testDefunReadVariable(self): v = resource_variable_ops.ResourceVariable(1.0) @@ -177,13 +337,14 @@ class FunctionTest(test.TestCase): def testDefunAssignAddVariable(self): v = resource_variable_ops.ResourceVariable(1.0) + x = constant_op.constant(2.0) @function.defun - def f(): - v.assign_add(2.0) + def test_assign_add(): + v.assign_add(x) return v.read_value() - self.assertEqual(3.0, float(f())) + self.assertEqual(3.0, float(test_assign_add())) def testDefunShapeInferenceWithCapturedResourceVariable(self): v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) @@ -196,6 +357,21 @@ class FunctionTest(test.TestCase): compiled = function.defun(f) compiled() + def testVariableInLoopInFunction(self): + + @function.defun + def test_function(): + + def loop_test(_): + return False + + def loop_body(_): + return variable_scope.get_variable('a', shape=()) + + return control_flow_ops.while_loop(loop_test, loop_body, [0.0]) + + self.assertEqual(test_function().shape, []) + def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self): with context.graph_mode(): v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) @@ -221,6 +397,18 @@ class FunctionTest(test.TestCase): compiled = function.defun(f) compiled() + @test_util.run_in_graph_and_eager_modes + def testDefunForcesResourceVariables(self): + + def variable_creator(): + return variables.Variable(0.0).read_value() + + defined = function.defun(variable_creator) + defined() # Create the variable. + self.assertEqual(len(defined.variables), 1) + self.assertIsInstance( + defined.variables[0], resource_variable_ops.ResourceVariable) + def testDefunDifferentiable(self): v = resource_variable_ops.ResourceVariable(1.0) @@ -258,6 +446,22 @@ class FunctionTest(test.TestCase): op = call() self.assertAllEqual(sess.run(op), 2.0) + def testSymbolicGradientVariableZerosLike(self): + with ops.Graph().as_default(): + v = resource_variable_ops.ResourceVariable(1.0) + + @function.defun + def f(x, v): + v.read_value() + return x * x + + x = constant_op.constant(1.0) + l = f(x, v) + _, dv = gradients_impl.gradients(l, [x, v]) + with self.test_session(): + v.initializer.run() + self.assertAllEqual(dv.eval(), 0.0) + def testGraphModeManyFunctions(self): with context.graph_mode(), self.test_session(): @@ -398,24 +602,33 @@ class FunctionTest(test.TestCase): self.assertAllEqual(f(constant_op.constant(1.0)), 2.0) - def testGradientOfGatherWithDefun(self): + def testGatherResourceWithDefun(self): with ops.device('cpu:0'): v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) - def sum_gather(): - return math_ops.reduce_sum(array_ops.gather(v, [1, 2])) + def sum_gather(): + return math_ops.reduce_sum(array_ops.gather(v, [1, 2])) + + defined = function.defun(sum_gather) + self.assertAllEqual(sum_gather(), defined()) + + def testGradientOfGatherWithDefun(self): + v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) + + def sum_gather(): + return math_ops.reduce_sum(array_ops.gather(v, [1, 2])) - grad_fn = backprop.implicit_grad(sum_gather) - gradient = grad_fn() - defun_grad_fn = backprop.implicit_grad(function.defun(sum_gather)) - defun_gradient = defun_grad_fn() - self.assertEqual(len(gradient), len(defun_gradient)) + grad_fn = backprop.implicit_grad(sum_gather) + gradient = grad_fn() + defun_grad_fn = backprop.implicit_grad(function.defun(sum_gather)) + defun_gradient = defun_grad_fn() + self.assertEqual(len(gradient), len(defun_gradient)) - gradient = gradient[0][0] - defun_gradient = defun_gradient[0][0] - self.assertAllEqual(gradient.values, defun_gradient.values) - self.assertAllEqual(gradient.indices, defun_gradient.indices) - self.assertAllEqual(gradient.dense_shape, defun_gradient.dense_shape) + gradient = gradient[0][0] + defun_gradient = defun_gradient[0][0] + self.assertAllEqual(gradient.values, defun_gradient.values) + self.assertAllEqual(gradient.indices, defun_gradient.indices) + self.assertAllEqual(gradient.dense_shape, defun_gradient.dense_shape) def testReturningIndexedSlicesWithDefun(self): @@ -479,6 +692,66 @@ class FunctionTest(test.TestCase): y = f(x, x).cpu() self.assertAllEqual(y, [2.]) + @test_util.run_in_graph_and_eager_modes + def testFunctionWithResourcesOnDifferentDevices(self): + if not context.context().num_gpus(): + self.skipTest('No GPUs found.') + + with ops.device('/cpu:0'): + v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) + + with ops.device('/gpu:0'): + v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) + + def sum_gather(): + cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu, [1, 2])) + gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2])) + return cpu_result, gpu_result + + defined = function.defun(sum_gather) + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + expected = self.evaluate(sum_gather()) + self.assertAllEqual(expected, self.evaluate(defined())) + + @test_util.run_in_graph_and_eager_modes + def testOpInFunctionWithConflictingResourceInputs(self): + if not context.context().num_gpus(): + self.skipTest('No GPUs found.') + + with ops.device('/cpu:0'): + v_cpu = resource_variable_ops.ResourceVariable( + [0.0, 1.0, 2.0], name='cpu') + v_also_cpu = resource_variable_ops.ResourceVariable( + [0.0, 1.0, 2.0], name='also_cpu') + + with ops.device('/gpu:0'): + v_gpu = resource_variable_ops.ResourceVariable( + [0.0, 1.0, 2.0], name='gpu') + + @function.defun + def resource_apply_adam(): + training_ops.resource_apply_adam( + v_cpu.handle, + v_gpu.handle, + v_also_cpu.handle, + 1.0, # beta1_power + 1.0, # beta2_power + 1.0, # learning_rate + 1.0, # beta1 + 1.0, # beta2 + 1.0, # epsilon, + [1.0, 1.0, 1.0], # grad + False) # use_locking + return None + + with self.assertRaisesRegexp( + errors.InvalidArgumentError, 'Could not colocate node with its ' + 'resource and reference inputs.*'): + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + self.evaluate(resource_apply_adam()) + def testFunctionHandlesInputsOnDifferentDevices(self): if not context.context().num_gpus(): self.skipTest('No GPUs found') @@ -512,6 +785,60 @@ class FunctionTest(test.TestCase): g = backprop.gradients_function(wrapper, [0])(constant_op.constant(0.0)) self.assertAllEqual(g[0], 1.) + @function.defun + def foo(a): + return None, a * a + + x = constant_op.constant(5.0) + with backprop.GradientTape() as tp: + tp.watch(x) + none, r = foo(x) + g = tp.gradient(r, x) + + self.assertIs(none, None) + self.assertAllEqual(r, 25.0) + self.assertAllEqual(g, 2 * 5.0) + + def testNestedDifferentiableFunction(self): + @function.defun + def inner_fn(a, b): + return a * math_ops.add(a, b) + + @function.defun + def outer_fn(x): + return inner_fn(x, 1.0) + + x = constant_op.constant(5.0) + with backprop.GradientTape() as tp: + tp.watch(x) + result = outer_fn(x) + grad = tp.gradient(result, x) + + self.assertAllEqual(grad, 2 * 5.0 + 1.0) + + def testNestedDifferentiableFunctionNoneOutputs(self): + @function.defun + def foo(a, b): + return None, a * math_ops.add(a, b), None, 2*a + + @function.defun + def bar(x): + return foo(x, 1.0) + + x = constant_op.constant(5.0) + with backprop.GradientTape(persistent=True) as tp: + tp.watch(x) + none1, r1, none2, r2 = bar(x) + g1 = tp.gradient(r1, x) + g2 = tp.gradient(r2, x) + + self.assertAllEqual(r1, 30.0) + self.assertAllEqual(r2, 10.0) + self.assertIs(none1, None) + self.assertIs(none2, None) + self.assertAllEqual(g1, 2 * 5.0 + 1.0) + self.assertAllEqual(g2, 2.0) + def testNoneOutput(self): @function.defun @@ -534,15 +861,15 @@ class FunctionTest(test.TestCase): self.assertAllEqual(3, add_one(constant_op.constant(2))) def testVariableCaptureInNestedFunctions(self): - v = resource_variable_ops.ResourceVariable(1) + v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32) @function.defun - def read(): + def inner_read(): return v.read_value() @function.defun def outer(): - return read() + return inner_read() self.assertEqual(1, int(outer())) @@ -633,6 +960,412 @@ class FunctionTest(test.TestCase): y = model(x) self.assertAllEqual([[[[4.0]]]], y.numpy()) + # Note: The ConfigProto below unfortunately only configures graph + # construction. Eager's configuration is controlled in `__main__`. + @test_util.run_in_graph_and_eager_modes( + config=config_pb2.ConfigProto(device_count={'CPU': 4})) + def testDeviceAnnotationsRespected(self): + + @function.defun + def multi_device_fn(): + with ops.device('/cpu:0'): + s1 = iterator_ops.Iterator.from_structure( + (dtypes.float32,)).string_handle() + with ops.device('/cpu:1'): + s2 = iterator_ops.Iterator.from_structure( + (dtypes.float32,)).string_handle() + with ops.device('/cpu:2'): + s3 = iterator_ops.Iterator.from_structure( + (dtypes.float32,)).string_handle() + with ops.device(''): + # TODO(akshayka): This is unfortunate and brittle. It prevents + # `Iterator.from_structure` from assigning the iterator op to 'cpu:0'. + # Remove this hack once we have a way of obtaining metadata about + # function execution. + s4 = iterator_ops.Iterator.from_structure( + (dtypes.float32,)).string_handle() + return s1, s2, s3, s4 + + with ops.device('/cpu:3'): + outputs = self.evaluate(multi_device_fn()) + self.assertIn(compat.as_bytes('CPU:0'), outputs[0]) + self.assertIn(compat.as_bytes('CPU:1'), outputs[1]) + self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) + self.assertIn(compat.as_bytes('CPU:3'), outputs[3]) + + with ops.device('/cpu:0'): + outputs = self.evaluate(multi_device_fn()) + self.assertIn(compat.as_bytes('CPU:0'), outputs[0]) + self.assertIn(compat.as_bytes('CPU:1'), outputs[1]) + self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) + self.assertIn(compat.as_bytes('CPU:0'), outputs[3]) + + def testVariablesAreTracked(self): + v = resource_variable_ops.ResourceVariable(1.0) + + def foo(x): + return v * x + + defined = function.defun(foo) + + x = constant_op.constant([1.0]) + self.assertAllEqual(defined.variables, []) + _ = defined(x) + self.assertAllEqual(defined.variables, [v]) + + x = constant_op.constant([1.0, 2.0]) + _ = defined(x) # ensure the variables list remains the same + self.assertAllEqual(defined.variables, [v]) + + def testPythonFunctionWithDefaultArgs(self): + + def func(foo, bar=1, baz=2): + del foo + del bar + del baz + return + + defined = function.defun(func) + defined(0, baz=20) + # `True` corresponds to the fact that we're executing eagerly + self.assertIn((0, 1, 20, True), defined._arguments_to_functions) + + defined(1) # bar=1, baz=2 + self.assertIn((1, 1, 2, True), defined._arguments_to_functions) + + # This matches the previous call. + defined(foo=1) + self.assertEqual(len(defined._arguments_to_functions), 2) + + defined(1, 2, 3) + self.assertIn((1, 2, 3, True), defined._arguments_to_functions) + + # This matches the previous call. + defined(1, bar=2, baz=3) + self.assertEqual(len(defined._arguments_to_functions), 3) + + # This matches the previous call. + defined(1, baz=3, bar=2) + self.assertEqual(len(defined._arguments_to_functions), 3) + + def testFunctoolsPartialUnwrappedCorrectly(self): + + def full_function(a, b, c=3): + return a, b, c + + partial = functools.partial(full_function, 1, c=3) + a, b, c = partial(2) + + defined = function.defun(partial) + func_a, func_b, func_c = defined(2) + self.assertEqual(func_a.numpy(), a) + self.assertEqual(func_b.numpy(), b) + self.assertEqual(func_c.numpy(), c) + + def testInputSignatureWithCompatibleInputs(self): + + def foo(a): + self.assertEqual(a.shape, (2,)) + return a + + signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)] + defined = function.defun(foo, input_signature=signature) + a = array_ops.ones([2]) + out = defined(a) + self.assertEqual(len(defined._arguments_to_functions), 1) + self.assertAllEqual(out, a) + + def bar(a): + self.assertEqual(a._shape_tuple(), (2, None)) + return a + + signature = [tensor_spec.TensorSpec((2, None), dtypes.float32)] + defined = function.defun(bar, input_signature=signature) + a = array_ops.ones([2, 1]) + out = defined(a) + self.assertEqual(len(defined._arguments_to_functions), 1) + self.assertAllEqual(out, a) + + # Changing the second dimension shouldn't create a new function. + b = array_ops.ones([2, 3]) + out = defined(b) + self.assertEqual(len(defined._arguments_to_functions), 1) + self.assertAllEqual(out, b) + + def testNestedInputSignatures(self): + + def foo(a, b): + self.assertEqual(a[0]._shape_tuple(), (2, None)) + self.assertEqual(a[1]._shape_tuple(), (2, None)) + self.assertEqual(b._shape_tuple(), (1,)) + return [a, b] + + signature = [[tensor_spec.TensorSpec((2, None), dtypes.float32)] * 2, + tensor_spec.TensorSpec((1,), dtypes.float32)] + defined = function.defun(foo, input_signature=signature) + a = array_ops.ones([2, 1]) + b = array_ops.ones([1]) + out = defined([a, a], b) + self.assertEqual(len(defined._arguments_to_functions), 1) + nest.assert_same_structure(out, [[a, a], b]) + self.assertAllEqual(out[0][0], a) + self.assertAllEqual(out[0][1], a) + self.assertAllEqual(out[1], b) + + # Changing the unspecified dimensions shouldn't create a new function. + a = array_ops.ones([2, 3]) + b = array_ops.ones([2, 5]) + c = array_ops.ones([1]) + out = defined([a, b], c) + self.assertEqual(len(defined._arguments_to_functions), 1) + nest.assert_same_structure(out, [[a, b], c]) + self.assertAllEqual(out[0][0], a) + self.assertAllEqual(out[0][1], b) + self.assertAllEqual(out[1], c) + + def bar(a): + self.assertEqual(a['a']._shape_tuple(), (2, None)) + self.assertEqual(a['b']._shape_tuple(), (2, None)) + self.assertEqual(a['c']._shape_tuple(), (1,)) + return a + + signature = [{ + 'a': tensor_spec.TensorSpec((2, None), dtypes.float32), + 'b': tensor_spec.TensorSpec((2, None), dtypes.float32), + 'c': tensor_spec.TensorSpec((1,), dtypes.float32) + }] + a = array_ops.ones([2, 3]) + b = array_ops.ones([1]) + inputs = {'a': a, 'b': a, 'c': b} + defined = function.defun(bar, input_signature=signature) + out = defined(inputs) + nest.assert_same_structure(out, inputs) + self.assertAllEqual(out['a'], inputs['a']) + self.assertAllEqual(out['b'], inputs['b']) + self.assertAllEqual(out['c'], inputs['c']) + + def testInputSignatureMustBeSequenceOfTensorSpecs(self): + + def foo(a, b): + del a + del b + + # Signatures must consist exclusively of `TensorSpec` objects. + signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)] + with self.assertRaisesRegexp(TypeError, 'Invalid input_signature.*'): + function.defun(foo, input_signature=signature)(1, 2) + + # Signatures must be either lists or tuples on their outermost levels. + signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)} + with self.assertRaisesRegexp(TypeError, 'input_signature must be either a ' + 'tuple or a list.*'): + function.defun(foo, input_signature=signature)(1, 2) + + def testInputsIncompatibleWithSignatureRaisesError(self): + + def foo(a): + return a + + signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)] + defined = function.defun(foo, input_signature=signature) + + # Invalid shapes. + with self.assertRaisesRegexp(ValueError, 'Python inputs incompatible.*'): + defined(array_ops.ones([3])) + + with self.assertRaisesRegexp(ValueError, 'Python inputs incompatible.*'): + defined(array_ops.ones([2, 1])) + + # Wrong number of arguments. + with self.assertRaisesRegexp(ValueError, + 'Structure of Python function inputs.*'): + defined(array_ops.ones([2]), array_ops.ones([2])) + with self.assertRaisesRegexp(ValueError, + 'Structure of Python function inputs.*'): + defined() + + def testInputSignatureForFunctionWithNonTensorInputsNotAllowed(self): + + def foo(a, training=True): + if training: + return a + else: + return -1.0 * a + + signature = [tensor_spec.TensorSpec([], dtypes.float32)] * 2 + defined = function.defun(foo, input_signature=signature) + a = constant_op.constant(1.0) + with self.assertRaisesRegexp( + ValueError, 'When input_signature is provided, ' + 'all inputs to the Python function must be Tensors.'): + defined(a, training=True) + + def testInputSignatureWithKeywordPositionalArgs(self): + + @function.defun(input_signature=[ + tensor_spec.TensorSpec([], dtypes.float32), + tensor_spec.TensorSpec([], dtypes.int64) + ]) + def foo(flt, integer): + return flt, integer + + flt = constant_op.constant(1.0) + integer = constant_op.constant(2, dtypes.int64) + + out1, out2 = foo(flt, integer) + self.assertEqual(len(foo._arguments_to_functions), 1) + self.assertEqual(out1.numpy(), 1.0) + self.assertEqual(out2.numpy(), 2) + + out1, out2 = foo(flt=flt, integer=integer) + self.assertEqual(len(foo._arguments_to_functions), 1) + self.assertEqual(out1.numpy(), 1.0) + self.assertEqual(out2.numpy(), 2) + + out1, out2 = foo(integer=integer, flt=flt) + self.assertEqual(len(foo._arguments_to_functions), 1) + self.assertEqual(out1.numpy(), 1.0) + self.assertEqual(out2.numpy(), 2) + + out1, out2 = foo(flt, integer=integer) + self.assertEqual(len(foo._arguments_to_functions), 1) + self.assertEqual(out1.numpy(), 1.0) + self.assertEqual(out2.numpy(), 2) + + def testInputSignatureWithKeywordArgsFails(self): + + def foo(a, **kwargs): + del a + del kwargs + + with self.assertRaisesRegexp( + ValueError, 'Cannot define a TensorFlow function from a Python ' + 'function with keyword arguments when input_signature.*'): + function.defun( + foo, + input_signature=[ + tensor_spec.TensorSpec([], dtypes.float32), + tensor_spec.TensorSpec([], dtypes.int64) + ]) + + def testTensorKeywordArguments(self): + + def foo(a, b): + del a + return b + + defined = function.defun(foo) + a = constant_op.constant(2.0) + b = constant_op.constant([1.0, 2.0]) + one = defined(a, b) + self.assertEqual(len(defined._arguments_to_functions), 1) + + two = defined(a=a, b=b) + self.assertEqual(len(defined._arguments_to_functions), 1) + + three = defined(b=b, a=a) + self.assertEqual(len(defined._arguments_to_functions), 1) + + four = defined(a, b=b) + self.assertEqual(len(defined._arguments_to_functions), 1) + + # The next call corresponds to a new input signature, hence + # we expect another function to be defined. + five = defined(b, a) + self.assertEqual(len(defined._arguments_to_functions), 2) + + six = defined(a=b, b=a) + self.assertEqual(len(defined._arguments_to_functions), 2) + + seven = defined(b=a, a=b) + self.assertEqual(len(defined._arguments_to_functions), 2) + + self.assertAllEqual(one, [1.0, 2.0]) + self.assertAllEqual(two, [1.0, 2.0]) + self.assertAllEqual(three, [1.0, 2.0]) + self.assertAllEqual(four, [1.0, 2.0]) + self.assertAllEqual(five, 2.0) + self.assertAllEqual(six, 2.0) + self.assertAllEqual(seven, 2.0) + + def testGradientWithKeywordArguments(self): + matmul = function.defun(math_ops.matmul) + + def sq(x): + return matmul(a=x, b=x, transpose_a=True) + + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + grad_t, = backprop.gradients_function(sq, [0])(t) + self.assertAllEqual(grad_t, [[6, 6], [14, 14]]) + + with backprop.GradientTape(persistent=True) as gtape: + gtape.watch(t) + one = matmul(t, b=t, transpose_a=True) + two = matmul(b=t, a=t, transpose_a=True) + three = matmul(a=t, b=t, transpose_a=True) + + for output in [one, two, three]: + self.assertAllEqual(gtape.gradient(output, t), [[6, 6], [14, 14]]) + + def testGradientInFunctionWithKeywordArguments(self): + + @function.defun + def f(x): + return backprop.gradients_function(lambda y: y * y, [0])(x)[0] + + self.assertAllEqual(f(x=constant_op.constant(1.0)), 2.0) + + def testDefuningInstanceMethod(self): + + integer = constant_op.constant(2, dtypes.int64) + + class Foo(object): + + def one(self, tensor): + return tensor + + @function.defun + def two(self, tensor, other=integer): + return self.one(tensor), other + + foo = Foo() + t = constant_op.constant(1.0) + one, two = foo.two(t) + self.assertEqual(one.numpy(), 1.0) + self.assertEqual(two.numpy(), 2) + + def testDefuningInstanceMethodWithDefaultArgument(self): + + integer = constant_op.constant(2, dtypes.int64) + + class Foo(object): + + @function.defun + def func(self, other=integer): + return other + + foo = Foo() + self.assertEqual(foo.func().numpy(), int(integer)) + + def testPythonCallWithSideEffects(self): + state = [] + + @function.defun + def side_effecting_function(): + state.append(0) + + side_effecting_function() + self.assertAllEqual(state, [0]) + + # The second invocation should call the graph function, which shouldn't + # trigger the list append. + side_effecting_function() + self.assertAllEqual(state, [0]) + + # Whereas calling the python function directly should create a side-effect. + side_effecting_function.call_python_function() + self.assertAllEqual(state, [0, 0]) + @test_util.with_c_shapes class AutomaticControlDependenciesTest(test.TestCase): @@ -820,7 +1553,7 @@ class AutomaticControlDependenciesTest(test.TestCase): def loss(v): return v**2 - optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0) + optimizer = momentum.MomentumOptimizer(learning_rate=1.0, momentum=1.0) @function.defun def train(): @@ -832,12 +1565,41 @@ class AutomaticControlDependenciesTest(test.TestCase): value = train() self.assertEqual(value.numpy(), -1.0) + def testReturningNonTensorRaisesError(self): + optimizer = momentum.MomentumOptimizer(learning_rate=1.0, momentum=1.0) + optimizer.apply_gradients = function.defun(optimizer.apply_gradients) + v = resource_variable_ops.ResourceVariable(1.0) + grad = backprop.implicit_grad(lambda v: v**2)(v) + + with self.assertRaisesRegexp(TypeError, + '.*must return zero or more Tensors.*'): + # TODO(akshayka): We might want to allow defun-ing Python functions + # that return operations (and just execute the op instead of running it). + optimizer.apply_gradients(grad) + + # TODO(b/111663004): This should work when the outer context is graph + # building. + def testOptimizerNonSlotVarsInDefunNoError(self): + def loss(v): + return v**2 + + optimizer = adam.AdamOptimizer(learning_rate=1.0) + + @function.defun + def train(): + v = resource_variable_ops.ResourceVariable(1.0) + grad = backprop.implicit_grad(loss)(v) + optimizer.apply_gradients(grad) + return v.read_value() + + train() + def testOptimizerInDefunWithCapturedVariable(self): v = resource_variable_ops.ResourceVariable(1.0) def loss(): return v**2 - optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0) + optimizer = momentum.MomentumOptimizer(learning_rate=1.0, momentum=1.0) @function.defun def train(): @@ -847,6 +1609,176 @@ class AutomaticControlDependenciesTest(test.TestCase): train() self.assertEqual(v.numpy(), -1.0) + def testFunctionModifiesInputList(self): + # Tests on `list` methods that do in place modification, except `list.sort` + # since it cannot even be "defunned" in the first place + + def get_list(): + return [constant_op.constant(0.), constant_op.constant(1.)] + + expected_msg = ( + 'Function to be traced should not modify structure of input ' + 'arguments. Check if your function has list and dictionary ' + 'operations that alter input arguments, ' + 'such as `list.pop`, `list.append`') + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def append(l): + l.append(constant_op.constant(0.)) + + append(get_list()) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def extend(l): + l.extend([constant_op.constant(0.)]) + + extend(get_list()) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def insert(l): + l.insert(0, constant_op.constant(0.)) + + insert(get_list()) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def pop(l): + l.pop() + + pop(get_list()) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def reverse(l): + l.reverse() + + reverse(get_list()) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def remove(l): + l.remove(l[0]) + + remove(get_list()) + + # `list.clear` is a method that is in Py3 but not Py2 + if sys.version.startswith('3'): + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def clear(l): + l.clear() + + clear(get_list()) + + # One last test for keyword arguments + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def kwdappend(**kwargs): + l = kwargs['l'] + l.append(constant_op.constant(0.)) + + kwdappend(l=get_list()) + + def testFunctionModifiesInputDict(self): + + def get_dict(): + return {'t1': constant_op.constant(0.), 't2': constant_op.constant(1.)} + + expected_msg = ( + 'Function to be traced should not modify structure of input ' + 'arguments. Check if your function has list and dictionary ' + 'operations that alter input arguments, ' + 'such as `list.pop`, `list.append`') + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def clear(m): + m.clear() + + clear(get_dict()) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def pop(m): + m.pop('t1') + + pop(get_dict()) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def popitem(m): + m.popitem() + + popitem(get_dict()) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def update(m): + m.update({'t1': constant_op.constant(3.)}) + + update(get_dict()) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def setdefault(m): + m.setdefault('t3', constant_op.constant(3.)) + + setdefault(get_dict()) + + def testFunctionModifiesInputNest(self): + # Test on functions that modify structure of nested input arguments + expected_msg = ( + 'Function to be traced should not modify structure of input ' + 'arguments. Check if your function has list and dictionary ' + 'operations that alter input arguments, ' + 'such as `list.pop`, `list.append`') + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def modify(n): + n[0]['t1'].append(constant_op.constant(1.)) + + nested_input = [{ + 't1': [constant_op.constant(0.), + constant_op.constant(1.)], + }, + constant_op.constant(2.)] + + modify(nested_input) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + # The flat list doesn't change whereas the true structure changes + @function.defun + def modify_same_flat(n): + n[0].append(n[1].pop(0)) + + nested_input = [[constant_op.constant(0.)], + [constant_op.constant(1.), + constant_op.constant(2.)]] + + modify_same_flat(nested_input) + if __name__ == '__main__': + ops.enable_eager_execution( + config=config_pb2.ConfigProto(device_count={'CPU': 4})) test.main() diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py index 760a1485523798c6587e95804488a14b42a69bc0..7105d2e399b373e962526802228f7ffe5af4a55a 100644 --- a/tensorflow/python/eager/graph_callable.py +++ b/tensorflow/python/eager/graph_callable.py @@ -110,13 +110,25 @@ class _VariableCapturingScope(object): """ # TODO(apassos) ignoring the regularizer and partitioner here; figure out # how to deal with these. - def _custom_getter(getter=None, name=None, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring - initializer=None, regularizer=None, reuse=None, - trainable=True, collections=None, caching_device=None, # pylint: disable=redefined-outer-name - partitioner=None, validate_shape=True, - use_resource=None): + def _custom_getter( # pylint: disable=missing-docstring + getter=None, + name=None, + shape=None, + dtype=dtypes.float32, + initializer=None, + regularizer=None, + reuse=None, + trainable=None, + collections=None, + caching_device=None, # pylint: disable=redefined-outer-name + partitioner=None, + validate_shape=True, + use_resource=None, + aggregation=variable_scope.VariableAggregation.NONE, + synchronization=variable_scope.VariableSynchronization.AUTO): del getter, regularizer, partitioner, validate_shape, use_resource, dtype - del collections, initializer, trainable, reuse, caching_device, shape, + del collections, initializer, trainable, reuse, caching_device, shape + del aggregation, synchronization assert name in self.variables v = self.variables[name] return v.variable @@ -136,13 +148,24 @@ class _VariableCapturingScope(object): """ # TODO(apassos) ignoring the regularizer and partitioner here; figure out # how to deal with these. - def _custom_getter(getter=None, name=None, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring - initializer=None, regularizer=None, reuse=None, - trainable=True, collections=None, caching_device=None, # pylint: disable=redefined-outer-name - partitioner=None, validate_shape=True, - use_resource=None): + def _custom_getter( # pylint: disable=missing-docstring + getter=None, + name=None, + shape=None, + dtype=dtypes.float32, + initializer=None, + regularizer=None, + reuse=None, + trainable=None, + collections=None, + caching_device=None, # pylint: disable=redefined-outer-name + partitioner=None, + validate_shape=True, + use_resource=None, + aggregation=variable_scope.VariableAggregation.NONE, + synchronization=variable_scope.VariableSynchronization.AUTO): del getter, regularizer, collections, caching_device, partitioner - del use_resource, validate_shape + del use_resource, validate_shape, aggregation, synchronization if name in self.tf_variables: if reuse: return self.tf_variables[name].initialized_value() @@ -257,8 +280,7 @@ def _graph_callable_internal(func, shape_and_dtypes): # This graph will store both the initialization and the call version of the # wrapped function. It will later be used by the backprop code to build the # backprop graph, if necessary. - captures = {} - tmp_graph = function.CapturingGraph(captures) + tmp_graph = function.CapturingGraph() # Inherit the graph key from the original graph to ensure optimizers don't # misbehave. tmp_graph._container = container # pylint: disable=protected-access @@ -266,7 +288,7 @@ def _graph_callable_internal(func, shape_and_dtypes): with tmp_graph.as_default(): # Placeholders for the non-variable inputs. func_inputs = _get_graph_callable_inputs(shape_and_dtypes) - func_num_args = len(tf_inspect.getargspec(func).args) + func_num_args = len(tf_inspect.getfullargspec(func).args) if len(func_inputs) != func_num_args: raise TypeError("The number of arguments accepted by the decorated " "function `%s` (%d) must match the number of " @@ -308,12 +330,9 @@ def _graph_callable_internal(func, shape_and_dtypes): sorted_variables = sorted(variable_captures.variables.values(), key=lambda x: x.name) - ids = list(sorted(captures.keys())) - if ids: - extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids]) - else: - extra_inputs = [] - extra_placeholders = [] + + extra_inputs = tmp_graph.captures.keys() + extra_placeholders = tmp_graph.captures.values() flat_inputs = [x for x in nest.flatten(func_inputs) if isinstance(x, tf_ops.Tensor)] diff --git a/tensorflow/python/eager/memory_test.py b/tensorflow/python/eager/memory_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a59d511fdd4b831ea853b1f1cb3212322a3b84 --- /dev/null +++ b/tensorflow/python/eager/memory_test.py @@ -0,0 +1,110 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for memory leaks in eager execution. + +It is possible that this test suite will eventually become flaky due to taking +too long to run (since the tests iterate many times), but for now they are +helpful for finding memory leaks since not all PyObject leaks are found by +introspection (test_util decorators). Please be careful adding new tests here. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.python import keras +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops + +# memory_profiler might not be available in the OSS version of TensorFlow. +try: + import memory_profiler # pylint:disable=g-import-not-at-top +except ImportError: + memory_profiler = None + + +class SingleLayerNet(keras.Model): + """Simple keras model used to ensure that there are no leaks.""" + + def __init__(self): + super(SingleLayerNet, self).__init__() + self.fc1 = keras.layers.Dense(5) + + def call(self, x): + return self.fc1(x) + + +class MemoryTest(test.TestCase): + + def assertNotIncreasingMemory(self, + f, + num_iters=100000, + increase_threshold_absolute_mb=10): + """Assert memory usage doesn't increase beyond given threshold for f.""" + + with context.eager_mode(): + # Warm up. + f() + + initial = memory_profiler.memory_usage(-1)[0] + + for _ in six.moves.range(num_iters): + f() + + increase = memory_profiler.memory_usage(-1)[0] - initial + + assert increase < increase_threshold_absolute_mb, ( + "Increase is too high. Initial memory usage: %f MB. Increase: %f MB. " + "Maximum allowed increase: %f") % (initial, increase, + increase_threshold_absolute_mb) + + def testMemoryLeakInSimpleModelForwardOnly(self): + if memory_profiler is None: + self.skipTest("memory_profiler required to run this test") + + inputs = array_ops.zeros([32, 100], dtypes.float32) + net = SingleLayerNet() + + def f(): + with backprop.GradientTape(): + net(inputs) + + self.assertNotIncreasingMemory(f) + + def testMemoryLeakInSimpleModelForwardAndBackward(self): + if memory_profiler is None: + self.skipTest("memory_profiler required to run this test") + + inputs = array_ops.zeros([32, 100], dtypes.float32) + net = SingleLayerNet() + + def f(): + with backprop.GradientTape() as tape: + result = net(inputs) + + tape.gradient(result, net.variables) + + del tape + + self.assertNotIncreasingMemory(f) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py index fc76ede4c502ae8b554c925a921e419bf003c40c..17a090d5262f790c92dfa1a92d47f9b5ac6c07d9 100644 --- a/tensorflow/python/eager/ops_test.py +++ b/tensorflow/python/eager/ops_test.py @@ -370,6 +370,10 @@ class OpsTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError): float(x) + def testRange(self): + x = constant_op.constant(2) + self.assertEqual([0, 1], list(range(x))) + def testFormatString(self): x = constant_op.constant(3.1415) self.assertEqual('3.14', '{:.2f}'.format(x)) diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index ea604647faede0e5b86a17938d0a7c8a7621dec1..15d2ccf9d2b533ced7fd0d104d7a3be3c2ad4dd3 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -154,6 +154,7 @@ TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle, if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR TFE_OpSetAttrType(op, "SrcT", src_type_enum); TFE_OpSetAttrType(op, "DstT", dst_type_enum); + TFE_OpSetAttrBool(op, "Truncate", false); TFE_TensorHandle* output = nullptr; int num_outputs = 1; TFE_Execute(op, &output, &num_outputs, out_status); @@ -620,10 +621,6 @@ static PyType_Slot EagerTensor_Type_slots[] = { {Py_tp_init, reinterpret_cast(EagerTensor_init)}, {0, nullptr}, }; - -PyType_Spec EagerTensor_Type_spec = {"EagerTensor", sizeof(EagerTensor), 0, - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, - EagerTensor_Type_slots}; #else // TODO(agarwal): support active_trace. static PyTypeObject _EagerTensorType = { @@ -754,6 +751,34 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) { #if PY_MAJOR_VERSION >= 3 PyObject* bases = PyTuple_New(1); PyTuple_SET_ITEM(bases, 0, base_class); + + tensorflow::Safe_PyObjectPtr base_class_module( + PyObject_GetAttrString(base_class, "__module__")); + const char* module = nullptr; + if (PyErr_Occurred()) { + PyErr_Clear(); + module = "__builtin__"; + } else { + module = PyBytes_AsString(base_class_module.get()); + if (module == nullptr) { + PyErr_Clear(); + module = PyUnicode_AsUTF8(base_class_module.get()); + if (module == nullptr) { + PyErr_Clear(); + module = "__builtin__"; + } + } + } + + // NOTE: The c_str from this string needs to outlast the function, hence is + // static. + static tensorflow::string fully_qualified_name = + tensorflow::strings::StrCat(module, ".EagerTensor"); + + static PyType_Spec EagerTensor_Type_spec = { + fully_qualified_name.c_str(), sizeof(EagerTensor), 0, + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, EagerTensor_Type_slots}; + EagerTensorType = reinterpret_cast( PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases)); if (PyErr_Occurred()) { diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 6c9481c3af31a96308bb47f6e2aa35988f16f709..2d54555cd37e630697b1721d526e50c285687e88 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -205,14 +205,20 @@ bool ParseDimensionValue(const string& key, PyObject* py_value, } bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status, - const char** value) { + tensorflow::StringPiece* value) { if (PyBytes_Check(py_value)) { - *value = PyBytes_AsString(py_value); + Py_ssize_t size = 0; + char* buf = nullptr; + if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false; + *value = tensorflow::StringPiece(buf, size); return true; } #if PY_MAJOR_VERSION >= 3 if (PyUnicode_Check(py_value)) { - *value = PyUnicode_AsUTF8(py_value); + Py_ssize_t size = 0; + char* buf = PyUnicode_AsUTF8AndSize(py_value, &size); + if (buf == nullptr) return false; + *value = tensorflow::StringPiece(buf, size); return true; } #endif @@ -275,8 +281,16 @@ bool SetOpAttrList( } if (type == TF_ATTR_STRING) { - PARSE_LIST(const char*, ParseStringValue); - TFE_OpSetAttrStringList(op, key, values.get(), num_values); + std::unique_ptr values(new const void*[num_values]); + std::unique_ptr lengths(new size_t[num_values]); + for (int i = 0; i < num_values; ++i) { + tensorflow::StringPiece value; + tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); + if (!ParseStringValue(key, py_value.get(), status, &value)) return false; + values[i] = value.data(); + lengths[i] = value.size(); + } + TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values); } else if (type == TF_ATTR_INT) { PARSE_LIST(int64_t, ParseInt64Value); TFE_OpSetAttrIntList(op, key, values.get(), num_values); @@ -379,12 +393,15 @@ void SetOpAttrListDefault( TF_Status* status) { if (type == TF_ATTR_STRING) { int num_values = attr.default_value().list().s_size(); - std::unique_ptr values(new const char*[num_values]); + std::unique_ptr values(new const void*[num_values]); + std::unique_ptr lengths(new size_t[num_values]); (*attr_list_sizes)[key] = num_values; for (int i = 0; i < num_values; i++) { - values[i] = attr.default_value().list().s(i).data(); + const string& v = attr.default_value().list().s(i); + values[i] = v.data(); + lengths[i] = v.size(); } - TFE_OpSetAttrStringList(op, key, values.get(), num_values); + TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values); } else if (type == TF_ATTR_INT) { int num_values = attr.default_value().list().i_size(); std::unique_ptr values(new int64_t[num_values]); @@ -470,9 +487,9 @@ bool SetOpAttrScalar( tensorflow::gtl::FlatMap* attr_list_sizes, TF_Status* status) { if (type == TF_ATTR_STRING) { - const char* value; + tensorflow::StringPiece value; if (!ParseStringValue(key, py_value, status, &value)) return false; - TFE_OpSetAttrString(op, key, value); + TFE_OpSetAttrString(op, key, value.data(), value.size()); } else if (type == TF_ATTR_INT) { int64_t value; if (!ParseInt64Value(key, py_value, status, &value)) return false; @@ -533,7 +550,7 @@ bool SetOpAttrScalar( // (which is what the various "defun" or "Defun" decorators do). // And in the future also allow an object that can encapsulate // the function name and its attribute values. - const char* func_name = nullptr; + tensorflow::StringPiece func_name; if (!ParseStringValue(key, py_value, status, &func_name)) { PyObject* name_attr = PyObject_GetAttrString(py_value, "name"); if (name_attr == nullptr || @@ -549,7 +566,8 @@ bool SetOpAttrScalar( return false; } } - TFE_Op* func = TFE_NewOp(ctx, func_name, status); + TFE_Op* func = TFE_NewOp( + ctx, string(func_name.data(), func_name.size()).c_str(), status); if (TF_GetCode(status) != TF_OK) return false; TFE_OpSetAttrFunction(op, key, func); TFE_DeleteOp(func); @@ -827,11 +845,9 @@ int64_t get_uid() { PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); } void TFE_DeleteContextCapsule(PyObject* context) { - TF_Status* status = TF_NewStatus(); TFE_Context* ctx = reinterpret_cast(PyCapsule_GetPointer(context, nullptr)); - TFE_DeleteContext(ctx, status); - TF_DeleteStatus(status); + TFE_DeleteContext(ctx); } static tensorflow::int64 MakeInt(PyObject* integer) { @@ -930,7 +946,7 @@ class GradientTape : id(id), variable(variable) {} }; struct CompareById { - bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) { + bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const { return lhs.id < rhs.id; } }; @@ -1155,14 +1171,14 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { TFE_TensorHandle* t = EagerTensor_Handle(tensor); tensorflow::int64 id = EagerTensor_id(tensor); - const tensorflow::Tensor* tensor = nullptr; - const tensorflow::Status status = t->handle->Tensor(&tensor); + tensorflow::TensorShape tensor_shape; + const tensorflow::Status status = t->handle->Shape(&tensor_shape); + if (MaybeRaiseExceptionFromStatus(status, nullptr)) { return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensorflow::TensorShape({})}; } else { - return tensorflow::eager::TapeTensor{id, t->handle->dtype, - tensor->shape()}; + return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensor_shape}; } } tensorflow::int64 id = FastTensorId(tensor); @@ -1710,7 +1726,6 @@ bool OpDoesntRequireOutput(const string& op_name) { "BiasAdd", "BiasAddV1", "BiasAddGrad", - "Relu6", "Softplus", "SoftplusGrad", "Softsign", @@ -1783,6 +1798,7 @@ bool OpDoesntRequireInput(const string& op_name) { "LogSoftmax", "BiasAdd", "Relu", + "Relu6", "Elu", "Selu", "SparseSoftmaxCrossEntropyWithLogits", @@ -1880,14 +1896,39 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, void MaybeWatchVariable(PyObject* input) { DCHECK(CheckResourceVariable(input)); - DCHECK(PyObject_HasAttrString(input, "trainable")); + DCHECK(PyObject_HasAttrString(input, "_trainable")); tensorflow::Safe_PyObjectPtr trainable( - PyObject_GetAttrString(input, "trainable")); + PyObject_GetAttrString(input, "_trainable")); if (trainable.get() == Py_False) return; TFE_Py_TapeSetWatchVariable(input); } +bool CastTensor(const FastPathOpExecInfo& op_exec_info, + const TF_DataType& desired_dtype, + tensorflow::Safe_TFE_TensorHandlePtr* handle, + TF_Status* status) { + TF_DataType input_dtype = TFE_TensorHandleDataType(handle->get()); + TF_DataType output_dtype = input_dtype; + + if (desired_dtype >= 0 && desired_dtype != input_dtype) { + *handle = tensorflow::make_safe( + tensorflow::EagerCast(op_exec_info.ctx, handle->get(), input_dtype, + static_cast(desired_dtype), status)); + if (!status->status.ok()) return false; + output_dtype = desired_dtype; + } + + if (output_dtype != TF_INT32) { + // Note that this is a shallow copy and will share the underlying buffer + // if copying to the same device. + *handle = tensorflow::make_safe(TFE_TensorHandleCopyToDevice( + handle->get(), op_exec_info.ctx, op_exec_info.device_name, status)); + if (!status->status.ok()) return false; + } + return true; +} + bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info, PyObject* input, tensorflow::Safe_PyObjectPtr* output, TF_Status* status) { @@ -1920,9 +1961,31 @@ bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info, TFE_Execute(op, &output_handle, &num_retvals, status); if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false; - // Always create the py object (and correctly DECREF it) from the returned - // value, else the data will leak. - output->reset(EagerTensorFromHandle(output_handle)); + if (!PyObject_HasAttrString(input, "_read_dtype")) { + // Always create the py object (and correctly DECREF it) from the returned + // value, else the data will leak. + output->reset(EagerTensorFromHandle(output_handle)); + } else { + // This is a _MixedPrecisionVariable which potentially does casting when + // being read. + tensorflow::Safe_PyObjectPtr read_dtype( + PyObject_GetAttrString(input, "_read_dtype")); + int desired_dtype = -1; + if (!ParseTypeValue("_read_dtype", read_dtype.get(), status, + &desired_dtype)) { + return false; + } + + auto safe_output_handle = tensorflow::make_safe(output_handle); + // Retires output_handle in the future. + output_handle = nullptr; + if (!CastTensor(parent_op_exec_info, + static_cast(desired_dtype), + &safe_output_handle, status)) { + return false; + } + output->reset(EagerTensorFromHandle(safe_output_handle.release())); + } // TODO(nareshmodi): Should we run post exec callbacks here? if (parent_op_exec_info.run_gradient_callback) { @@ -1992,27 +2055,13 @@ bool ConvertToTensor( } } - TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get()); - if (desired_dtype >= 0 && desired_dtype != handle_dtype) { - handle = tensorflow::make_safe( - tensorflow::EagerCast(op_exec_info.ctx, handle.get(), handle_dtype, - static_cast(desired_dtype), status)); - if (!status->status.ok()) return false; - - handle_dtype = TFE_TensorHandleDataType(handle.get()); - } - - if (handle_dtype != TF_INT32) { - // Note that this is a shallow copy and will share the underlying buffer - // if copying to the same device. - handle = tensorflow::make_safe(TFE_TensorHandleCopyToDevice( - handle.get(), op_exec_info.ctx, op_exec_info.device_name, status)); - if (!status->status.ok()) return false; + if (!CastTensor(op_exec_info, static_cast(desired_dtype), + &handle, status)) { + return false; } - + TF_DataType output_dtype = TFE_TensorHandleDataType(handle.get()); output_handle->reset(EagerTensorFromHandle(handle.release())); - - dtype_setter(handle_dtype); + dtype_setter(output_dtype); return true; } diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py index faaae40b3f1ef02984a7a75c23ae4acae65ac335..fd8ab695b8fbb732bb853cd4affadf98d4861cc2 100644 --- a/tensorflow/python/eager/pywrap_tfe_test.py +++ b/tensorflow/python/eager/pywrap_tfe_test.py @@ -23,6 +23,7 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -69,6 +70,25 @@ class Tests(test.TestCase): self.assertAllEqual(x, y) + @test_util.assert_no_new_tensors + @test_util.assert_no_garbage_created + def testFastpathExecute_MixedPrecisionVariableMatMulCorrectResponse(self): + ctx = context.context() + a_2_by_2 = constant_op.constant(1.0, shape=[2, 2]) + a_2_by_2_fp16 = math_ops.cast(a_2_by_2, dtype=dtypes.float16) + m = resource_variable_ops.ResourceVariable(a_2_by_2) + m = resource_variable_ops._MixedPrecisionVariable( + m, read_dtype=dtypes.float16) + x = pywrap_tensorflow.TFE_Py_FastPathExecute( + ctx._handle, ctx.device_name, "MatMul", None, None, m, m, "transpose_a", + False, "transpose_b", False) + y = pywrap_tensorflow.TFE_Py_FastPathExecute( + ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2_fp16, + a_2_by_2_fp16, "transpose_a", False, "transpose_b", False) + + self.assertEqual(x.dtype, dtypes.float16) + self.assertAllEqual(x, y) + @test_util.assert_no_new_tensors @test_util.assert_no_garbage_created def testFastpathExecute_TapeWrite(self): @@ -98,6 +118,29 @@ class Tests(test.TestCase): self.assertAllEqual(dz_dy.numpy(), constant_op.constant(4.0, shape=[2, 2]).numpy()) + @test_util.assert_no_new_tensors + @test_util.assert_no_garbage_created + def testFastpathExecute_MixedPrecisionVariableTapeWrite(self): + ctx = context.context() + with backprop.GradientTape(persistent=True) as tape: + a_2_by_2 = constant_op.constant( + [[1.0, 2.0], [3.0, 4.0]], dtype=dtypes.float32) + a_2_by_2_fp16 = math_ops.cast(a_2_by_2, dtype=dtypes.float16) + m1 = resource_variable_ops.ResourceVariable(a_2_by_2) + m2 = resource_variable_ops._MixedPrecisionVariable( + m1, read_dtype=dtypes.float16) + tape.watch(m2) + z = pywrap_tensorflow.TFE_Py_FastPathExecute( + ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2_fp16, m2, + "transpose_a", False, "transpose_b", False) + dz_dy = tape.gradient(z, [m2])[0] + self.assertEqual(dz_dy.dtype, dtypes.float16) + + expected_grads = math_ops.matmul( + array_ops.transpose(a_2_by_2_fp16), + constant_op.constant(1., shape=[2, 2], dtype=dtypes.float16)).numpy() + self.assertAllEqual(dz_dy.numpy(), expected_grads) + # Tests homogeneous list op @test_util.assert_no_new_tensors @test_util.assert_no_garbage_created diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index 626a4eb1eee9bda6c910c9dfa9cfff27b04444c1..871136e2c893ff92bc13caa9405b0a8f3fd1385d 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -278,7 +278,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp( TypeError, - r"tensors argument must be a list or a tuple. Got \"EagerTensor\""): + r"tensors argument must be a list or a tuple. Got.*EagerTensor"): pywrap_tensorflow.TFE_Py_TensorShapeSlice(t1, -2) def testNegativeSliceDim(self): diff --git a/tensorflow/python/eager/test.py b/tensorflow/python/eager/test.py index f6a46e7eb3d03982f07bf4162d94c6038217bf61..33ee797678ed73c52ebb17723f688cec4feca402 100644 --- a/tensorflow/python/eager/test.py +++ b/tensorflow/python/eager/test.py @@ -23,6 +23,7 @@ from tensorflow.python.platform import test as _test from tensorflow.python.platform.test import * # pylint: disable=wildcard-import +# TODO(akshayka): Do away with this file. def main(argv=None): _ops.enable_eager_execution() _test.main(argv) diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 9e716e81f40a2395b6ff04989f695dc5c0d91d15..817c8e6848dcab5e04770e33dda803ed2a4a0c9a 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -1,8 +1,4 @@ -package( - default_visibility = [ - "//tensorflow:internal", - ], -) +package(default_visibility = ["//tensorflow:internal"]) licenses(["notice"]) # Apache 2.0 @@ -10,7 +6,10 @@ load("//tensorflow:tensorflow.bzl", "py_test") py_library( name = "estimator_py", - srcs = ["estimator_lib.py"], + srcs = [ + "__init__.py", + "estimator_lib.py", + ], srcs_version = "PY2AND3", visibility = [ "//tensorflow:__pkg__", @@ -31,7 +30,7 @@ py_library( ":parsing_utils", ":run_config", ":training", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -41,12 +40,9 @@ py_library( srcs_version = "PY2AND3", deps = [ ":gc", - "//tensorflow/python:errors", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - "//tensorflow/python:util", - "//tensorflow/python/estimator:metric_keys", - "//tensorflow/python/estimator:util", + ":metric_keys", + ":util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -58,10 +54,7 @@ py_test( deps = [ ":estimator", ":exporter", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:platform", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -70,8 +63,7 @@ py_library( srcs = ["gc.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:platform", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -82,10 +74,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":gc", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -95,12 +84,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":export_output", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/saved_model:signature_constants", - "//tensorflow/python/saved_model:tag_constants", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -113,12 +97,7 @@ py_test( deps = [ ":export_output", ":model_fn", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", - "//tensorflow/python/saved_model:signature_constants", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -130,11 +109,7 @@ py_library( ":estimator", ":exporter", ":run_config", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:training", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -153,13 +128,7 @@ py_test( ":inputs", ":run_config", ":training", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -168,7 +137,7 @@ py_library( srcs = ["run_config.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/core:protos_all_py", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -180,8 +149,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":run_config", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -194,14 +162,7 @@ py_library( ":head", ":model_fn", ":optimizers", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:nn", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -210,6 +171,7 @@ py_test( name = "baseline_test", size = "medium", srcs = ["canned/baseline_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = [ "no_pip", @@ -225,26 +187,7 @@ py_test( ":numpy_io", ":pandas_io", ":run_config", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -257,20 +200,7 @@ py_library( ":estimator", ":head", ":model_fn", - "//tensorflow/python:array_ops", - "//tensorflow/python:boosted_trees_ops", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:distribute", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -278,22 +208,15 @@ py_test( name = "boosted_trees_test", size = "medium", srcs = ["canned/boosted_trees_test.py"], + shard_count = 2, srcs_version = "PY2AND3", + tags = [ + "optonly", + ], deps = [ ":boosted_trees", - "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:platform_test", - "//tensorflow/python:resources", - "//tensorflow/python:training", - "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/feature_column", + ":inputs", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -306,14 +229,7 @@ py_library( ":head", ":model_fn", ":optimizers", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:nn", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:summary", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -330,22 +246,7 @@ py_library( ":model_fn", ":numpy_io", ":prediction_keys", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:distribute", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "//third_party/py/numpy", "@six_archive//:six", ], @@ -368,16 +269,7 @@ py_test( ":numpy_io", ":pandas_io", ":prediction_keys", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -393,19 +285,7 @@ py_library( ":linear", ":model_fn", ":optimizers", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:distribute", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:nn", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -428,17 +308,7 @@ py_test( ":numpy_io", ":pandas_io", ":prediction_keys", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:nn", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -450,10 +320,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:platform", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/data", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -464,10 +331,7 @@ py_test( tags = ["notsan"], # b/67510291 deps = [ ":util", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:training", - "//tensorflow/python/data", + "//tensorflow:tensorflow_py_no_contrib", "//third_party/py/numpy", "@six_archive//:six", ], @@ -484,21 +348,7 @@ py_library( ":model_fn", ":run_config", ":util", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:distribute", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:metrics", - "//tensorflow/python:platform", - "//tensorflow/python:random_seed", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/data", - "//tensorflow/python/saved_model:builder", - "//tensorflow/python/saved_model:constants", - "//tensorflow/python/saved_model:tag_constants", + "//tensorflow:tensorflow_py_no_contrib", "//third_party/py/numpy", "@six_archive//:six", ], @@ -517,29 +367,7 @@ py_test( ":model_fn", ":numpy_io", ":run_config", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:lib", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:saver_test_utils", - "//tensorflow/python:session", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variables", - "//tensorflow/python/data", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model:loader", - "//tensorflow/python/saved_model:tag_constants", + "//tensorflow:tensorflow_py_no_contrib", "//third_party/py/numpy", "@six_archive//:six", ], @@ -552,9 +380,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:dtypes", - "//tensorflow/python:parsing_ops", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -565,10 +391,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":parsing_utils", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:parsing_ops", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -577,9 +400,7 @@ py_library( srcs = ["export/export_output.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python/saved_model:signature_def_utils", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -591,13 +412,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":export_output", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python/saved_model:signature_constants", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -610,7 +425,7 @@ py_library( deps = [ ":export_export", ":export_output", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -622,13 +437,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":util", - "//tensorflow/python:array_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -641,17 +450,8 @@ py_test( deps = [ ":export_export", ":export_output", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python/saved_model:signature_constants", - "//tensorflow/python/saved_model:signature_def_utils", + ":util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -664,24 +464,7 @@ py_library( ":metric_keys", ":model_fn", ":prediction_keys", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:nn", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:string_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:weights_broadcast_ops", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model:signature_constants", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -700,23 +483,7 @@ py_test( ":model_fn", ":numpy_io", ":prediction_keys", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:string_ops", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model:signature_constants", + "//tensorflow:tensorflow_py_no_contrib", "//third_party/py/numpy", "@six_archive//:six", ], @@ -729,7 +496,7 @@ py_library( deps = [ ":numpy_io", ":pandas_io", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -741,11 +508,7 @@ py_library( ":estimator", ":head", ":optimizers", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -763,25 +526,7 @@ py_library( ":numpy_io", ":pandas_io", ":run_config", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:check_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:distribute", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -799,7 +544,7 @@ py_test( deps = [ ":linear", ":linear_testing_utils", - "//tensorflow/python:client_testlib", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -828,9 +573,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":numpy_io", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -839,7 +582,7 @@ py_library( srcs = ["canned/optimizers.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -851,8 +594,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":optimizers", - "//tensorflow/python:client_testlib", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -870,9 +612,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":pandas_io", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -892,15 +632,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "@six_archive//:six", ], ) @@ -914,7 +646,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":inputs_queues", - "//tensorflow/python:client_testlib", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -925,10 +657,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":inputs_queues", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:session", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -941,32 +670,7 @@ py_library( ":export_export", ":model_fn", ":run_config", - "//tensorflow/python:check_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:math_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:nn", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:platform", - "//tensorflow/python:random_seed", - "//tensorflow/python:session", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:summary", - "//tensorflow/python:tensor_util", - "//tensorflow/python:training", - "//tensorflow/python:training_util", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/feature_column", - "//tensorflow/python/keras:backend", - "//tensorflow/python/keras:engine", - "//tensorflow/python/keras:layers", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model", - "//tensorflow/python/saved_model:signature_constants", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -974,6 +678,7 @@ py_test( name = "keras_test", size = "large", srcs = ["keras_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = [ "no_windows", @@ -981,18 +686,49 @@ py_test( ], deps = [ ":keras", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/estimator:run_config", - "//tensorflow/python/keras", - "//tensorflow/python/keras:backend", - "//tensorflow/python/keras:engine", + ":numpy_io", + ":run_config", + "//tensorflow:tensorflow_py_no_contrib", "//third_party/py/numpy", ], ) + +py_library( + name = "expect_numpy_installed", + # This is a dummy rule used as a numpy dependency in open-source. + # We expect numpy to already be installed on the system, e.g. via + # `pip install numpy` + visibility = ["//visibility:public"], +) + +py_library( + name = "expect_pandas_installed", + # This is a dummy rule used as a numpy dependency in open-source. + # We expect pandas to already be installed on the system, e.g. via + # `pip install pandas` + visibility = ["//visibility:public"], +) + +py_library( + name = "expect_h5py_installed", + # This is a dummy rule used as a numpy dependency in open-source. + # We expect h5py to already be installed on the system, e.g. via + # `pip install h5py' + visibility = ["//visibility:public"], +) + +py_library( + name = "expect_six_installed", + # This is a dummy rule used as a numpy dependency in open-source. + # We expect six to already be installed on the system, e.g. via + # `pip install six` + visibility = ["//visibility:public"], +) + +py_library( + name = "expect_tensorflow_installed", + # This is a dummy rule used as a numpy dependency in open-source. + # We expect tensorflow to already be installed on the system, e.g. via + # `pip install tensorflow` or `pip install tensorflow_gpu` + visibility = ["//visibility:public"], +) diff --git a/tensorflow/python/estimator/__init__.py b/tensorflow/python/estimator/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..8cf8df567f0e36604b5c3f6fe992b572d6632954 100644 --- a/tensorflow/python/estimator/__init__.py +++ b/tensorflow/python/estimator/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Import Estimator APIs. + +Note: This file is imported by the create_estimator_api genrule. It must +transitively import all Estimator modules/packages for their @estimator_export +annotations to generate the public Estimator python API. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.python.estimator.estimator_lib diff --git a/tensorflow/python/estimator/api/BUILD b/tensorflow/python/estimator/api/BUILD index cddee9b8f30555da63a9aad1190a7644d02e5392..a75fa7d0aee56c4fd4faccfaf2fa07c399cedcc9 100644 --- a/tensorflow/python/estimator/api/BUILD +++ b/tensorflow/python/estimator/api/BUILD @@ -6,12 +6,14 @@ package( licenses(["notice"]) # Apache 2.0 -load("//tensorflow/tools/api/generator:api_gen.bzl", "gen_api_init_files") -load("//tensorflow/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES") +load("//tensorflow/python/tools/api/generator:api_gen.bzl", "gen_api_init_files") +load("//tensorflow/python/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES") gen_api_init_files( name = "estimator_python_api_gen", api_name = "estimator", output_files = ESTIMATOR_API_INIT_FILES, + output_package = "tensorflow.python.estimator.api", package = "tensorflow.python.estimator", + package_dep = "//tensorflow/python/estimator:estimator_py", ) diff --git a/tensorflow/python/estimator/canned/baseline.py b/tensorflow/python/estimator/canned/baseline.py index 3c6816cb03f180fa0ab230cba38d45ef162d58db..20c7a69b7cb071365e5442b512c1a858a7e0b246 100644 --- a/tensorflow/python/estimator/canned/baseline.py +++ b/tensorflow/python/estimator/canned/baseline.py @@ -24,10 +24,10 @@ Example: classifier = BaselineClassifier(n_classes=3) # Input builders -def input_fn_train: # returns x, y (where y represents label's class index). +def input_fn_train(): # returns x, y (where y represents label's class index). pass -def input_fn_eval: # returns x, y (where y represents label's class index). +def input_fn_eval(): # returns x, y (where y represents label's class index). pass # Fit model. @@ -215,6 +215,13 @@ class BaselineClassifier(estimator.Estimator): * if `weight_column` is not `None`, a feature with `key=weight_column` whose value is a `Tensor`. + + @compatibility(eager) + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. + @end_compatibility """ def __init__(self, @@ -313,6 +320,13 @@ class BaselineRegressor(estimator.Estimator): * if `weight_column` is not `None`, a feature with `key=weight_column` whose value is a `Tensor`. + + @compatibility(eager) + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. + @end_compatibility """ def __init__(self, diff --git a/tensorflow/python/estimator/canned/baseline_test.py b/tensorflow/python/estimator/canned/baseline_test.py index 7bf2e62da9c4598c28ad38825aac2031c9d51905..e46a3a156dfd546b733067299906857fbd705736 100644 --- a/tensorflow/python/estimator/canned/baseline_test.py +++ b/tensorflow/python/estimator/canned/baseline_test.py @@ -154,6 +154,8 @@ class BaselineRegressorEvaluationTest(test.TestCase): self.assertDictEqual({ metric_keys.MetricKeys.LOSS: 9., metric_keys.MetricKeys.LOSS_MEAN: 9., + metric_keys.MetricKeys.PREDICTION_MEAN: 13., + metric_keys.MetricKeys.LABEL_MEAN: 10., ops.GraphKeys.GLOBAL_STEP: 100 }, eval_metrics) @@ -176,6 +178,8 @@ class BaselineRegressorEvaluationTest(test.TestCase): self.assertDictEqual({ metric_keys.MetricKeys.LOSS: 18., metric_keys.MetricKeys.LOSS_MEAN: 9., + metric_keys.MetricKeys.PREDICTION_MEAN: 13., + metric_keys.MetricKeys.LABEL_MEAN: 10., ops.GraphKeys.GLOBAL_STEP: 100 }, eval_metrics) @@ -204,6 +208,8 @@ class BaselineRegressorEvaluationTest(test.TestCase): self.assertDictEqual({ metric_keys.MetricKeys.LOSS: 27., metric_keys.MetricKeys.LOSS_MEAN: 9., + metric_keys.MetricKeys.PREDICTION_MEAN: 13., + metric_keys.MetricKeys.LABEL_MEAN: 10., ops.GraphKeys.GLOBAL_STEP: 100 }, eval_metrics) @@ -229,7 +235,9 @@ class BaselineRegressorEvaluationTest(test.TestCase): self.assertItemsEqual( (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN, - ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys()) + metric_keys.MetricKeys.PREDICTION_MEAN, + metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP), + eval_metrics.keys()) # Logit is bias which is [46, 58] self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS]) diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py index 6b54f51ca6f9cc7fed644383bec432cc6c8013bd..8b423f76de8e3b685e0426ef48fdd9f195925ff1 100644 --- a/tensorflow/python/estimator/canned/boosted_trees.py +++ b/tensorflow/python/estimator/canned/boosted_trees.py @@ -17,7 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import collections +import functools from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn @@ -44,12 +46,13 @@ from tensorflow.python.util.tf_export import estimator_export # TODO(nponomareva): Reveal pruning params here. _TreeHParams = collections.namedtuple('TreeHParams', [ 'n_trees', 'max_depth', 'learning_rate', 'l1', 'l2', 'tree_complexity', - 'min_node_weight' + 'min_node_weight', 'center_bias', 'pruning_mode' ]) _HOLD_FOR_MULTI_CLASS_SUPPORT = object() _HOLD_FOR_MULTI_DIM_SUPPORT = object() _DUMMY_NUM_BUCKETS = -1 +_DUMMY_NODE_ID = -1 def _get_transformed_features(features, sorted_feature_columns): @@ -168,9 +171,10 @@ def _group_features_by_num_buckets(sorted_feature_columns): # pylint:enable=protected-access # Replace the dummy key with the real max num of buckets for all bucketized # columns. - bucket_size_to_feature_ids_dict[ - max_buckets_for_bucketized] = bucket_size_to_feature_ids_dict[ - _DUMMY_NUM_BUCKETS] + if max_buckets_for_bucketized not in bucket_size_to_feature_ids_dict: + bucket_size_to_feature_ids_dict[max_buckets_for_bucketized] = [] + bucket_size_to_feature_ids_dict[max_buckets_for_bucketized].extend( + bucket_size_to_feature_ids_dict[_DUMMY_NUM_BUCKETS]) del bucket_size_to_feature_ids_dict[_DUMMY_NUM_BUCKETS] feature_ids_list = list(bucket_size_to_feature_ids_dict.values()) @@ -278,7 +282,9 @@ class _CacheTrainingStatesUsingHashTable(object): """Returns cached_tree_ids, cached_node_ids, cached_logits.""" cached_tree_ids, cached_node_ids, cached_logits = array_ops.split( lookup_ops.lookup_table_find_v2( - self._table_ref, self._example_ids, default_value=[0.0, 0.0, 0.0]), + self._table_ref, + self._example_ids, + default_value=[0.0, _DUMMY_NODE_ID, 0.0]), [1, 1, self._logits_dimension], axis=1) cached_tree_ids = array_ops.squeeze( @@ -329,7 +335,7 @@ class _CacheTrainingStatesUsingVariables(object): array_ops.zeros([batch_size], dtype=dtypes.int32), name='tree_ids_cache') self._node_ids = _local_variable( - array_ops.zeros([batch_size], dtype=dtypes.int32), + _DUMMY_NODE_ID*array_ops.ones([batch_size], dtype=dtypes.int32), name='node_ids_cache') self._logits = _local_variable( array_ops.zeros([batch_size, logits_dimension], dtype=dtypes.float32), @@ -379,6 +385,260 @@ class _StopAtAttemptsHook(session_run_hook.SessionRunHook): run_context.request_stop() +def _get_max_splits(tree_hparams): + """Calculates the max possible number of splits based on tree params.""" + # maximum number of splits possible in the whole tree =2^(D-1)-1 + max_splits = (1 << tree_hparams.max_depth) - 1 + return max_splits + + +class _EnsembleGrower(object): + """Abstract base class for different types of ensemble growers. + + Use it to receive training ops for growing and centering bias, depending + on the implementation (for example, in memory or accumulator-based + distributed): + grower = ...create subclass grower(tree_ensemble, tree_hparams) + grow_op = grower.grow_tree(stats_summaries_list, feature_ids_list, + last_layer_nodes_range) + training_ops.append(grow_op) + """ + + def __init__(self, tree_ensemble, tree_hparams): + """Initializes a grower object. + + Args: + tree_ensemble: A TreeEnsemble variable. + tree_hparams: TODO. collections.namedtuple for hyper parameters. + Raises: + ValueError: when pruning mode is invalid or pruning is used and no tree + complexity is set. + """ + self._tree_ensemble = tree_ensemble + self._tree_hparams = tree_hparams + # pylint: disable=protected-access + self._pruning_mode_parsed = boosted_trees_ops.PruningMode.from_str( + tree_hparams.pruning_mode) + + if (self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING + and tree_hparams.tree_complexity <= 0): + raise ValueError('For pruning, tree_complexity must be positive.') + # pylint: enable=protected-access + + @abc.abstractmethod + def center_bias(self, center_bias_var, gradients, hessians): + """Centers bias, if ready, based on statistics. + + Args: + center_bias_var: A variable that will be updated when bias centering + finished. + gradients: A rank 2 tensor of gradients. + hessians: A rank 2 tensor of hessians. + + Returns: + An operation for centering bias. + """ + + @abc.abstractmethod + def grow_tree(self, stats_summaries_list, feature_ids_list, + last_layer_nodes_range): + """Grows a tree, if ready, based on provided statistics. + + Args: + stats_summaries_list: List of stats summary tensors, representing sums of + gradients and hessians for each feature bucket. + feature_ids_list: a list of lists of feature ids for each bucket size. + last_layer_nodes_range: A tensor representing ids of the nodes in the + current layer, to be split. + + Returns: + An op for growing a tree. + """ + + # ============= Helper methods =========== + + def _center_bias_fn(self, center_bias_var, mean_gradients, mean_hessians): + """Updates the ensembles and cache (if needed) with logits prior.""" + continue_centering = boosted_trees_ops.center_bias( + self._tree_ensemble.resource_handle, + mean_gradients=mean_gradients, + mean_hessians=mean_hessians, + l1=self._tree_hparams.l1, + l2=self._tree_hparams.l2) + return center_bias_var.assign(continue_centering) + + def _grow_tree_from_stats_summaries(self, stats_summaries_list, + feature_ids_list, last_layer_nodes_range): + """Updates ensemble based on the best gains from stats summaries.""" + node_ids_per_feature = [] + gains_list = [] + thresholds_list = [] + left_node_contribs_list = [] + right_node_contribs_list = [] + all_feature_ids = [] + assert len(stats_summaries_list) == len(feature_ids_list) + + max_splits = _get_max_splits(self._tree_hparams) + + for i, feature_ids in enumerate(feature_ids_list): + (numeric_node_ids_per_feature, numeric_gains_list, + numeric_thresholds_list, numeric_left_node_contribs_list, + numeric_right_node_contribs_list) = ( + boosted_trees_ops.calculate_best_gains_per_feature( + node_id_range=last_layer_nodes_range, + stats_summary_list=stats_summaries_list[i], + l1=self._tree_hparams.l1, + l2=self._tree_hparams.l2, + tree_complexity=self._tree_hparams.tree_complexity, + min_node_weight=self._tree_hparams.min_node_weight, + max_splits=max_splits)) + + all_feature_ids += feature_ids + node_ids_per_feature += numeric_node_ids_per_feature + gains_list += numeric_gains_list + thresholds_list += numeric_thresholds_list + left_node_contribs_list += numeric_left_node_contribs_list + right_node_contribs_list += numeric_right_node_contribs_list + + grow_op = boosted_trees_ops.update_ensemble( + # Confirm if local_tree_ensemble or tree_ensemble should be used. + self._tree_ensemble.resource_handle, + feature_ids=all_feature_ids, + node_ids=node_ids_per_feature, + gains=gains_list, + thresholds=thresholds_list, + left_node_contribs=left_node_contribs_list, + right_node_contribs=right_node_contribs_list, + learning_rate=self._tree_hparams.learning_rate, + max_depth=self._tree_hparams.max_depth, + pruning_mode=self._pruning_mode_parsed) + return grow_op + + +class _InMemoryEnsembleGrower(_EnsembleGrower): + """A base class for ensemble growers.""" + + def __init__(self, tree_ensemble, tree_hparams): + + super(_InMemoryEnsembleGrower, self).__init__( + tree_ensemble=tree_ensemble, tree_hparams=tree_hparams) + + def center_bias(self, center_bias_var, gradients, hessians): + # For in memory, we already have a full batch of gradients and hessians, + # so just take a mean and proceed with centering. + mean_gradients = array_ops.expand_dims( + math_ops.reduce_mean(gradients, 0), 0) + mean_heassians = array_ops.expand_dims(math_ops.reduce_mean(hessians, 0), 0) + return self._center_bias_fn(center_bias_var, mean_gradients, mean_heassians) + + def grow_tree(self, stats_summaries_list, feature_ids_list, + last_layer_nodes_range): + # For in memory, we already have full data in one batch, so we can grow the + # tree immediately. + return self._grow_tree_from_stats_summaries( + stats_summaries_list, feature_ids_list, last_layer_nodes_range) + + +class _AccumulatorEnsembleGrower(_EnsembleGrower): + """A base class for ensemble growers.""" + + def __init__(self, tree_ensemble, tree_hparams, stamp_token, + n_batches_per_layer, bucket_size_list, is_chief): + super(_AccumulatorEnsembleGrower, self).__init__( + tree_ensemble=tree_ensemble, tree_hparams=tree_hparams) + self._stamp_token = stamp_token + self._n_batches_per_layer = n_batches_per_layer + self._bucket_size_list = bucket_size_list + self._is_chief = is_chief + + def center_bias(self, center_bias_var, gradients, hessians): + # For not in memory situation, we need to accumulate enough of batches first + # before proceeding with centering bias. + + # Create an accumulator. + bias_dependencies = [] + bias_accumulator = data_flow_ops.ConditionalAccumulator( + dtype=dtypes.float32, + # The stats consist of grads and hessians means only. + # TODO(nponomareva): this will change for a multiclass + shape=[2, 1], + shared_name='bias_accumulator') + + grads_and_hess = array_ops.stack([gradients, hessians], axis=0) + grads_and_hess = math_ops.reduce_mean(grads_and_hess, axis=1) + + apply_grad = bias_accumulator.apply_grad(grads_and_hess, self._stamp_token) + bias_dependencies.append(apply_grad) + + # Center bias if enough batches were processed. + with ops.control_dependencies(bias_dependencies): + if not self._is_chief: + return control_flow_ops.no_op() + + def center_bias_from_accumulator(): + accumulated = array_ops.unstack(bias_accumulator.take_grad(1), axis=0) + return self._center_bias_fn(center_bias_var, + array_ops.expand_dims(accumulated[0], 0), + array_ops.expand_dims(accumulated[1], 0)) + + center_bias_op = control_flow_ops.cond( + math_ops.greater_equal(bias_accumulator.num_accumulated(), + self._n_batches_per_layer), + center_bias_from_accumulator, + control_flow_ops.no_op, + name='wait_until_n_batches_for_bias_accumulated') + return center_bias_op + + def grow_tree(self, stats_summaries_list, feature_ids_list, + last_layer_nodes_range): + # For not in memory situation, we need to accumulate enough of batches first + # before proceeding with building a tree layer. + max_splits = _get_max_splits(self._tree_hparams) + + # Prepare accumulators. + accumulators = [] + dependencies = [] + for i, feature_ids in enumerate(feature_ids_list): + stats_summaries = stats_summaries_list[i] + accumulator = data_flow_ops.ConditionalAccumulator( + dtype=dtypes.float32, + # The stats consist of grads and hessians (the last dimension). + shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2], + shared_name='numeric_stats_summary_accumulator_' + str(i)) + accumulators.append(accumulator) + + apply_grad = accumulator.apply_grad( + array_ops.stack(stats_summaries, axis=0), self._stamp_token) + dependencies.append(apply_grad) + + # Grow the tree if enough batches is accumulated. + with ops.control_dependencies(dependencies): + if not self._is_chief: + return control_flow_ops.no_op() + + min_accumulated = math_ops.reduce_min( + array_ops.stack([acc.num_accumulated() for acc in accumulators])) + + def grow_tree_from_accumulated_summaries_fn(): + """Updates tree with the best layer from accumulated summaries.""" + # Take out the accumulated summaries from the accumulator and grow. + stats_summaries_list = [] + stats_summaries_list = [ + array_ops.unstack(accumulator.take_grad(1), axis=0) + for accumulator in accumulators + ] + grow_op = self._grow_tree_from_stats_summaries( + stats_summaries_list, feature_ids_list, last_layer_nodes_range) + return grow_op + + grow_model = control_flow_ops.cond( + math_ops.greater_equal(min_accumulated, self._n_batches_per_layer), + grow_tree_from_accumulated_summaries_fn, + control_flow_ops.no_op, + name='wait_until_n_batches_accumulated') + return grow_model + + def _bt_model_fn( features, labels, @@ -424,8 +684,9 @@ def _bt_model_fn( ValueError: mode or params are invalid, or features has the wrong type. """ is_single_machine = (config.num_worker_replicas <= 1) - sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name) + center_bias = tree_hparams.center_bias + if train_in_memory: assert n_batches_per_layer == 1, ( 'When train_in_memory is enabled, input_fn should return the entire ' @@ -436,11 +697,6 @@ def _bt_model_fn( raise ValueError('train_in_memory is supported only for ' 'non-distributed training.') worker_device = control_flow_ops.no_op().device - # maximum number of splits possible in the whole tree =2^(D-1)-1 - # TODO(youngheek): perhaps storage could be optimized by storing stats with - # the dimension max_splits_per_layer, instead of max_splits (for the entire - # tree). - max_splits = (1 << tree_hparams.max_depth) - 1 train_op = [] with ops.name_scope(name) as name: # Prepare. @@ -468,6 +724,9 @@ def _bt_model_fn( # Create Ensemble resources. tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name) + # Variable that determines whether bias centering is needed. + center_bias_var = variable_scope.variable( + initial_value=center_bias, name='center_bias_needed', trainable=False) # Create logits. if mode != model_fn.ModeKeys.TRAIN: logits = boosted_trees_ops.predict( @@ -488,6 +747,7 @@ def _bt_model_fn( # TODO(soroush): Do partial updates if this becomes a bottleneck. ensemble_reload = local_tree_ensemble.deserialize( *tree_ensemble.serialize()) + if training_state_cache: cached_tree_ids, cached_node_ids, cached_logits = ( training_state_cache.lookup()) @@ -496,9 +756,10 @@ def _bt_model_fn( batch_size = array_ops.shape(labels)[0] cached_tree_ids, cached_node_ids, cached_logits = ( array_ops.zeros([batch_size], dtype=dtypes.int32), - array_ops.zeros([batch_size], dtype=dtypes.int32), + _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32), array_ops.zeros( [batch_size, head.logits_dimension], dtype=dtypes.float32)) + with ops.control_dependencies([ensemble_reload]): (stamp_token, num_trees, num_finalized_trees, num_attempted_layers, last_layer_nodes_range) = local_tree_ensemble.get_states() @@ -512,13 +773,20 @@ def _bt_model_fn( cached_node_ids=cached_node_ids, bucketized_features=input_feature_list, logits_dimension=head.logits_dimension) + logits = cached_logits + partial_logits # Create training graph. def _train_op_fn(loss): """Run one training iteration.""" if training_state_cache: - train_op.append(training_state_cache.insert(tree_ids, node_ids, logits)) + # Cache logits only after center_bias is complete, if it's in progress. + train_op.append( + control_flow_ops.cond( + center_bias_var, control_flow_ops.no_op, + lambda: training_state_cache.insert(tree_ids, node_ids, logits)) + ) + if closed_form_grad_and_hess_fn: gradients, hessians = closed_form_grad_and_hess_fn(logits, labels) else: @@ -526,6 +794,11 @@ def _bt_model_fn( hessians = gradients_impl.gradients( gradients, logits, name='Hessians')[0] + # TODO(youngheek): perhaps storage could be optimized by storing stats + # with the dimension max_splits_per_layer, instead of max_splits (for the + # entire tree). + max_splits = _get_max_splits(tree_hparams) + stats_summaries_list = [] for i, feature_ids in enumerate(feature_ids_list): num_buckets = bucket_size_list[i] @@ -542,103 +815,28 @@ def _bt_model_fn( ] stats_summaries_list.append(summaries) - accumulators = [] - - def grow_tree_from_stats_summaries(stats_summaries_list, - feature_ids_list): - """Updates ensemble based on the best gains from stats summaries.""" - node_ids_per_feature = [] - gains_list = [] - thresholds_list = [] - left_node_contribs_list = [] - right_node_contribs_list = [] - all_feature_ids = [] - - assert len(stats_summaries_list) == len(feature_ids_list) - - for i, feature_ids in enumerate(feature_ids_list): - (numeric_node_ids_per_feature, numeric_gains_list, - numeric_thresholds_list, numeric_left_node_contribs_list, - numeric_right_node_contribs_list) = ( - boosted_trees_ops.calculate_best_gains_per_feature( - node_id_range=last_layer_nodes_range, - stats_summary_list=stats_summaries_list[i], - l1=tree_hparams.l1, - l2=tree_hparams.l2, - tree_complexity=tree_hparams.tree_complexity, - min_node_weight=tree_hparams.min_node_weight, - max_splits=max_splits)) - - all_feature_ids += feature_ids - node_ids_per_feature += numeric_node_ids_per_feature - gains_list += numeric_gains_list - thresholds_list += numeric_thresholds_list - left_node_contribs_list += numeric_left_node_contribs_list - right_node_contribs_list += numeric_right_node_contribs_list - - grow_op = boosted_trees_ops.update_ensemble( - # Confirm if local_tree_ensemble or tree_ensemble should be used. - tree_ensemble.resource_handle, - feature_ids=all_feature_ids, - node_ids=node_ids_per_feature, - gains=gains_list, - thresholds=thresholds_list, - left_node_contribs=left_node_contribs_list, - right_node_contribs=right_node_contribs_list, - learning_rate=tree_hparams.learning_rate, - max_depth=tree_hparams.max_depth, - pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING) - return grow_op - if train_in_memory and is_single_machine: - train_op.append(distribute_lib.increment_var(global_step)) - train_op.append( - grow_tree_from_stats_summaries(stats_summaries_list, - feature_ids_list)) + grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams) else: - dependencies = [] - - for i, feature_ids in enumerate(feature_ids_list): - stats_summaries = stats_summaries_list[i] - accumulator = data_flow_ops.ConditionalAccumulator( - dtype=dtypes.float32, - # The stats consist of grads and hessians (the last dimension). - shape=[len(feature_ids), max_splits, bucket_size_list[i], 2], - shared_name='numeric_stats_summary_accumulator_' + str(i)) - accumulators.append(accumulator) - - apply_grad = accumulator.apply_grad( - array_ops.stack(stats_summaries, axis=0), stamp_token) - dependencies.append(apply_grad) - - def grow_tree_from_accumulated_summaries_fn(): - """Updates the tree with the best layer from accumulated summaries.""" - # Take out the accumulated summaries from the accumulator and grow. - stats_summaries_list = [] - - stats_summaries_list = [ - array_ops.unstack(accumulator.take_grad(1), axis=0) - for accumulator in accumulators - ] - - grow_op = grow_tree_from_stats_summaries(stats_summaries_list, - feature_ids_list) - return grow_op - - with ops.control_dependencies(dependencies): - train_op.append(distribute_lib.increment_var(global_step)) - if config.is_chief: - min_accumulated = math_ops.reduce_min( - array_ops.stack( - [acc.num_accumulated() for acc in accumulators])) - - train_op.append( - control_flow_ops.cond( - math_ops.greater_equal(min_accumulated, - n_batches_per_layer), - grow_tree_from_accumulated_summaries_fn, - control_flow_ops.no_op, - name='wait_until_n_batches_accumulated')) + grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams, + stamp_token, n_batches_per_layer, + bucket_size_list, config.is_chief) + + update_model = control_flow_ops.cond( + center_bias_var, + functools.partial( + grower.center_bias, + center_bias_var, + gradients, + hessians, + ), + functools.partial(grower.grow_tree, stats_summaries_list, + feature_ids_list, last_layer_nodes_range)) + train_op.append(update_model) + + with ops.control_dependencies([update_model]): + increment_global = distribute_lib.increment_var(global_step) + train_op.append(increment_global) return control_flow_ops.group(train_op, name='train_op') @@ -714,7 +912,15 @@ def _create_regression_head(label_dimension, weight_column=None): @estimator_export('estimator.BoostedTreesClassifier') class BoostedTreesClassifier(estimator.Estimator): - """A Classifier for Tensorflow Boosted Trees models.""" + """A Classifier for Tensorflow Boosted Trees models. + + @compatibility(eager) + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. + @end_compatibility + """ def __init__(self, feature_columns, @@ -730,7 +936,9 @@ class BoostedTreesClassifier(estimator.Estimator): l2_regularization=0., tree_complexity=0., min_node_weight=0., - config=None): + config=None, + center_bias=False, + pruning_mode='none'): """Initializes a `BoostedTreesClassifier` instance. Example: @@ -798,6 +1006,17 @@ class BoostedTreesClassifier(estimator.Estimator): split to be considered. The value will be compared with sum(leaf_hessian)/(batch_size * n_batches_per_layer). config: `RunConfig` object to configure the runtime settings. + center_bias: Whether bias centering needs to occur. Bias centering refers + to the first node in the very first tree returning the prediction that + is aligned with the original labels distribution. For example, for + regression problems, the first node will return the mean of the labels. + For binary classification problems, it will return a logit for a prior + probability of label 1. + pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre- + pruning (do not split a node if not enough gain is observed) and post + pruning (build the tree up to a max depth and then prune branches with + negative gain). For pre and post pruning, you MUST provide + tree_complexity >0. Raises: ValueError: when wrong arguments are given or unsupported functionalities @@ -810,9 +1029,9 @@ class BoostedTreesClassifier(estimator.Estimator): n_classes, weight_column, label_vocabulary=label_vocabulary) # HParams for the model. - tree_hparams = _TreeHParams(n_trees, max_depth, learning_rate, - l1_regularization, l2_regularization, - tree_complexity, min_node_weight) + tree_hparams = _TreeHParams( + n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, + tree_complexity, min_node_weight, center_bias, pruning_mode) def _model_fn(features, labels, mode, config): return _bt_model_fn( # pylint: disable=protected-access @@ -832,7 +1051,15 @@ class BoostedTreesClassifier(estimator.Estimator): @estimator_export('estimator.BoostedTreesRegressor') class BoostedTreesRegressor(estimator.Estimator): - """A Regressor for Tensorflow Boosted Trees models.""" + """A Regressor for Tensorflow Boosted Trees models. + + @compatibility(eager) + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. + @end_compatibility + """ def __init__(self, feature_columns, @@ -847,7 +1074,9 @@ class BoostedTreesRegressor(estimator.Estimator): l2_regularization=0., tree_complexity=0., min_node_weight=0., - config=None): + config=None, + center_bias=False, + pruning_mode='none'): """Initializes a `BoostedTreesRegressor` instance. Example: @@ -908,6 +1137,17 @@ class BoostedTreesRegressor(estimator.Estimator): split to be considered. The value will be compared with sum(leaf_hessian)/(batch_size * n_batches_per_layer). config: `RunConfig` object to configure the runtime settings. + center_bias: Whether bias centering needs to occur. Bias centering refers + to the first node in the very first tree returning the prediction that + is aligned with the original labels distribution. For example, for + regression problems, the first node will return the mean of the labels. + For binary classification problems, it will return a logit for a prior + probability of label 1. + pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre- + pruning (do not split a node if not enough gain is observed) and post + pruning (build the tree up to a max depth and then prune branches with + negative gain). For pre and post pruning, you MUST provide + tree_complexity >0. Raises: ValueError: when wrong arguments are given or unsupported functionalities @@ -919,9 +1159,9 @@ class BoostedTreesRegressor(estimator.Estimator): head = _create_regression_head(label_dimension, weight_column) # HParams for the model. - tree_hparams = _TreeHParams(n_trees, max_depth, learning_rate, - l1_regularization, l2_regularization, - tree_complexity, min_node_weight) + tree_hparams = _TreeHParams( + n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, + tree_complexity, min_node_weight, center_bias, pruning_mode) def _model_fn(features, labels, mode, config): return _bt_model_fn( # pylint: disable=protected-access diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py index 9ea4f484744762a98c67207d582bcc5b7be8d850..ec597e468615e0bb16b832e82d4049ef5961c4c3 100644 --- a/tensorflow/python/estimator/canned/boosted_trees_test.py +++ b/tensorflow/python/estimator/canned/boosted_trees_test.py @@ -500,6 +500,50 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): self.assertEqual(2, ensemble.trees[0].nodes[0].bucketized_split.feature_id) self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold) + def testTrainEvaluateAndPredictWithOnlyIndicatorColumn(self): + categorical = feature_column.categorical_column_with_vocabulary_list( + key='categorical', vocabulary_list=('bad', 'good', 'ok')) + feature_indicator = feature_column.indicator_column(categorical) + + labels = np.array([[0.], [5.7], [5.7], [0.], [0.]], dtype=np.float32) + # Our categorical feature defines the labels perfectly + input_fn = numpy_io.numpy_input_fn( + x={ + 'categorical': np.array(['bad', 'good', 'good', 'ok', 'bad']), + }, + y=labels, + batch_size=5, + shuffle=False) + + # Train depth 1 tree. + est = boosted_trees.BoostedTreesRegressor( + feature_columns=[feature_indicator], + n_batches_per_layer=1, + n_trees=1, + learning_rate=1.0, + max_depth=1) + + num_steps = 1 + est.train(input_fn, steps=num_steps) + ensemble = self._assert_checkpoint_and_return_model( + est.model_dir, global_step=1, finalized_trees=1, attempted_layers=1) + + # We learnt perfectly. + eval_res = est.evaluate(input_fn=input_fn, steps=1) + self.assertAllClose(eval_res['loss'], 0) + + predictions = list(est.predict(input_fn)) + self.assertAllClose( + labels, + [pred['predictions'] for pred in predictions]) + + self.assertEqual(3, len(ensemble.trees[0].nodes)) + + # Check that the split happened on 'good' value, which will be encoded as + # feature with index 1 (0 - 'bad', 2 - 'ok') + self.assertEqual(1, ensemble.trees[0].nodes[0].bucketized_split.feature_id) + self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold) + class ModelFnTests(test_util.TensorFlowTestCase): """Tests bt_model_fn including unexposed internal functionalities.""" @@ -510,37 +554,495 @@ class ModelFnTests(test_util.TensorFlowTestCase): feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32), BUCKET_BOUNDARIES) for i in range(NUM_FEATURES) } - self._tree_hparams = boosted_trees._TreeHParams( # pylint:disable=protected-access - n_trees=2, - max_depth=2, - learning_rate=0.1, - l1=0., - l2=0.01, - tree_complexity=0., - min_node_weight=0.) - def _get_expected_ensembles_for_classification(self): + def _get_expected_ensembles_for_classification(self): + first_round = """ + trees { + nodes { + bucketized_split { + feature_id: 2 + threshold: 2 + left_id: 1 + right_id: 2 + } + metadata { + gain: 0.387675 + } + } + nodes { + leaf { + scalar: -0.181818 + } + } + nodes { + leaf { + scalar: 0.0625 + } + } + } + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 1 + is_finalized: false + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 1 + last_layer_node_start: 1 + last_layer_node_end: 3 + } + """ + second_round = """ + trees { + nodes { + bucketized_split { + feature_id: 2 + threshold: 2 + left_id: 1 + right_id: 2 + } + metadata { + gain: 0.387675 + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 3 + left_id: 3 + right_id: 4 + } + metadata { + gain: 0.0 + original_leaf { + scalar: -0.181818 + } + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 0 + left_id: 5 + right_id: 6 + } + metadata { + gain: 0.105518 + original_leaf { + scalar: 0.0625 + } + } + } + nodes { + leaf { + scalar: -0.348397 + } + } + nodes { + leaf { + scalar: -0.181818 + } + } + nodes { + leaf { + scalar: 0.224091 + } + } + nodes { + leaf { + scalar: 0.056815 + } + } + } + trees { + nodes { + leaf { + scalar: 0.0 + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 2 + is_finalized: true + } + tree_metadata { + num_layers_grown: 0 + is_finalized: false + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 2 + last_layer_node_start: 0 + last_layer_node_end: 1 + } + """ + third_round = """ + trees { + nodes { + bucketized_split { + feature_id: 2 + threshold: 2 + left_id: 1 + right_id: 2 + } + metadata { + gain: 0.387675 + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 3 + left_id: 3 + right_id: 4 + } + metadata { + gain: 0.0 + original_leaf { + scalar: -0.181818 + } + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 0 + left_id: 5 + right_id: 6 + } + metadata { + gain: 0.105518 + original_leaf { + scalar: 0.0625 + } + } + } + nodes { + leaf { + scalar: -0.348397 + } + } + nodes { + leaf { + scalar: -0.181818 + } + } + nodes { + leaf { + scalar: 0.224091 + } + } + nodes { + leaf { + scalar: 0.056815 + } + } + } + trees { + nodes { + bucketized_split { + feature_id: 1 + threshold: 0 + left_id: 1 + right_id: 2 + } + metadata { + gain: 0.287131 + } + } + nodes { + leaf { + scalar: 0.162042 + } + } + nodes { + leaf { + scalar: -0.086986 + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 2 + is_finalized: true + } + tree_metadata { + num_layers_grown: 1 + is_finalized: false + } + growing_metadata { + num_trees_attempted: 2 + num_layers_attempted: 3 + last_layer_node_start: 1 + last_layer_node_end: 3 + } + """ + return (first_round, second_round, third_round) + + def _get_expected_ensembles_for_classification_with_bias(self): + first_round = """ + trees { + nodes { + leaf { + scalar: -0.405086 + } + } + } + tree_weights: 1.0 + tree_metadata { + } + """ + second_round = """ + trees { + nodes { + bucketized_split { + feature_id: 2 + threshold: 2 + left_id: 1 + right_id: 2 + } + metadata { + gain: 0.407711 + original_leaf { + scalar: -0.405086 + } + } + } + nodes { + leaf { + scalar: -0.556054 + } + } + nodes { + leaf { + scalar: -0.301233 + } + } + } + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 1 + is_finalized: false + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 1 + last_layer_node_start: 1 + last_layer_node_end: 3 + } + """ + third_round = """ + trees { + nodes { + bucketized_split { + feature_id: 2 + threshold: 2 + left_id: 1 + right_id: 2 + } + metadata { + gain: 0.407711 + original_leaf { + scalar: -0.405086 + } + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 3 + left_id: 3 + right_id: 4 + } + metadata { + original_leaf { + scalar: -0.556054 + } + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 0 + left_id: 5 + right_id: 6 + } + metadata { + gain: 0.09876 + original_leaf { + scalar: -0.301233 + } + } + } + nodes { + leaf { + scalar: -0.698072 + } + } + nodes { + leaf { + scalar: -0.556054 + } + } + nodes { + leaf { + scalar: -0.106016 + } + } + nodes { + leaf { + scalar: -0.27349 + } + } + } + trees { + nodes { + leaf { + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 2 + is_finalized: true + } + tree_metadata { + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 2 + last_layer_node_end: 1 + } + """ + forth_round = """ + trees { + nodes { + bucketized_split { + feature_id: 2 + threshold: 2 + left_id: 1 + right_id: 2 + } + metadata { + gain: 0.4077113 + original_leaf { + scalar: -0.405086 + } + } + } + nodes { + bucketized_split { + threshold: 3 + left_id: 3 + right_id: 4 + } + metadata { + original_leaf { + scalar: -0.556054 + } + } + } + nodes { + bucketized_split { + threshold: 0 + left_id: 5 + right_id: 6 + } + metadata { + gain: 0.09876 + original_leaf { + scalar: -0.301233 + } + } + } + nodes { + leaf { + scalar: -0.698072 + } + } + nodes { + leaf { + scalar: -0.556054 + } + } + nodes { + leaf { + scalar: -0.106016 + } + } + nodes { + leaf { + scalar: -0.27349 + } + } + } + trees { + nodes { + bucketized_split { + feature_id: 2 + threshold: 2 + left_id: 1 + right_id: 2 + } + metadata { + gain: 0.289927 + } + } + nodes { + leaf { + scalar: -0.134588 + } + } + nodes { + leaf { + scalar: 0.083838 + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 2 + is_finalized: true + } + tree_metadata { + num_layers_grown: 1 + } + growing_metadata { + num_trees_attempted: 2 + num_layers_attempted: 3 + last_layer_node_start: 1 + last_layer_node_end: 3 + } + """ + return (first_round, second_round, third_round, forth_round) + + def _get_expected_ensembles_for_regression(self): first_round = """ trees { nodes { bucketized_split { - feature_id: 2 - threshold: 2 + feature_id: 1 + threshold: 1 left_id: 1 right_id: 2 } metadata { - gain: 0.387675 + gain: 1.169714 } } nodes { leaf { - scalar: -0.181818 + scalar: 0.241322 } } nodes { leaf { - scalar: 0.0625 + scalar: 0.083951 } } } @@ -560,26 +1062,26 @@ class ModelFnTests(test_util.TensorFlowTestCase): trees { nodes { bucketized_split { - feature_id: 2 - threshold: 2 + feature_id: 1 + threshold: 1 left_id: 1 right_id: 2 } metadata { - gain: 0.387675 + gain: 1.169714 } } nodes { bucketized_split { feature_id: 0 - threshold: 3 + threshold: 1 left_id: 3 right_id: 4 } metadata { - gain: 0.0 + gain: 2.673407 original_leaf { - scalar: -0.181818 + scalar: 0.241322 } } } @@ -591,30 +1093,30 @@ class ModelFnTests(test_util.TensorFlowTestCase): right_id: 6 } metadata { - gain: 0.105518 + gain: 0.324102 original_leaf { - scalar: 0.0625 + scalar: 0.083951 } } } nodes { leaf { - scalar: -0.348397 + scalar: 0.563167 } } nodes { leaf { - scalar: -0.181818 + scalar: 0.247047 } } nodes { leaf { - scalar: 0.224091 + scalar: 0.095273 } } nodes { leaf { - scalar: 0.056815 + scalar: 0.222102 } } } @@ -646,26 +1148,26 @@ class ModelFnTests(test_util.TensorFlowTestCase): trees { nodes { bucketized_split { - feature_id: 2 - threshold: 2 + feature_id: 1 + threshold: 1 left_id: 1 right_id: 2 } metadata { - gain: 0.387675 + gain: 1.169714 } } nodes { bucketized_split { feature_id: 0 - threshold: 3 + threshold: 1 left_id: 3 right_id: 4 } metadata { - gain: 0.0 + gain: 2.673407 original_leaf { - scalar: -0.181818 + scalar: 0.241322 } } } @@ -677,30 +1179,30 @@ class ModelFnTests(test_util.TensorFlowTestCase): right_id: 6 } metadata { - gain: 0.105518 + gain: 0.324102 original_leaf { - scalar: 0.0625 + scalar: 0.083951 } } } nodes { leaf { - scalar: -0.348397 + scalar: 0.563167 } } nodes { leaf { - scalar: -0.181818 + scalar: 0.247047 } } nodes { leaf { - scalar: 0.224091 + scalar: 0.095273 } } nodes { leaf { - scalar: 0.056815 + scalar: 0.222102 } } } @@ -713,17 +1215,17 @@ class ModelFnTests(test_util.TensorFlowTestCase): right_id: 2 } metadata { - gain: 0.287131 + gain: 0.981026 } } nodes { leaf { - scalar: 0.162042 + scalar: 0.005166 } } nodes { leaf { - scalar: -0.086986 + scalar: 0.180281 } } } @@ -746,8 +1248,20 @@ class ModelFnTests(test_util.TensorFlowTestCase): """ return (first_round, second_round, third_round) - def _get_expected_ensembles_for_regression(self): + def _get_expected_ensembles_for_regression_with_bias(self): first_round = """ + trees { + nodes { + leaf { + scalar: 1.799974 + } + } + } + tree_weights: 1.0 + tree_metadata { + } + """ + second_round = """ trees { nodes { bucketized_split { @@ -757,17 +1271,20 @@ class ModelFnTests(test_util.TensorFlowTestCase): right_id: 2 } metadata { - gain: 1.169714 + gain: 1.190442 + original_leaf { + scalar: 1.799974 + } } } nodes { leaf { - scalar: 0.241322 + scalar: 1.862786 } } nodes { leaf { - scalar: 0.083951 + scalar: 1.706149 } } } @@ -783,7 +1300,7 @@ class ModelFnTests(test_util.TensorFlowTestCase): last_layer_node_end: 3 } """ - second_round = """ + third_round = """ trees { nodes { bucketized_split { @@ -793,7 +1310,10 @@ class ModelFnTests(test_util.TensorFlowTestCase): right_id: 2 } metadata { - gain: 1.169714 + gain: 1.190442 + original_leaf { + scalar: 1.799974 + } } } nodes { @@ -804,9 +1324,9 @@ class ModelFnTests(test_util.TensorFlowTestCase): right_id: 4 } metadata { - gain: 2.673407 + gain: 2.683594 original_leaf { - scalar: 0.241322 + scalar: 1.862786 } } } @@ -818,30 +1338,30 @@ class ModelFnTests(test_util.TensorFlowTestCase): right_id: 6 } metadata { - gain: 0.324102 + gain: 0.322693 original_leaf { - scalar: 0.083951 + scalar: 1.706149 } } } nodes { leaf { - scalar: 0.563167 + scalar: 2.024487 } } nodes { leaf { - scalar: 0.247047 + scalar: 1.710319 } } nodes { leaf { - scalar: 0.095273 + scalar: 1.559208 } } nodes { leaf { - scalar: 0.222102 + scalar: 1.686037 } } } @@ -869,7 +1389,7 @@ class ModelFnTests(test_util.TensorFlowTestCase): last_layer_node_end: 1 } """ - third_round = """ + forth_round = """ trees { nodes { bucketized_split { @@ -879,55 +1399,55 @@ class ModelFnTests(test_util.TensorFlowTestCase): right_id: 2 } metadata { - gain: 1.169714 + gain: 1.190442 + original_leaf { + scalar: 1.799974 + } } } nodes { bucketized_split { - feature_id: 0 threshold: 1 left_id: 3 right_id: 4 } metadata { - gain: 2.673407 + gain: 2.683594 original_leaf { - scalar: 0.241322 + scalar: 1.8627863 } } } nodes { bucketized_split { - feature_id: 0 - threshold: 0 left_id: 5 right_id: 6 } metadata { - gain: 0.324102 + gain: 0.322693 original_leaf { - scalar: 0.083951 + scalar: 1.706149 } } } nodes { leaf { - scalar: 0.563167 + scalar: 2.024487 } } nodes { leaf { - scalar: 0.247047 + scalar: 1.710319 } } nodes { leaf { - scalar: 0.095273 + scalar: 1.5592078 } } nodes { leaf { - scalar: 0.222102 + scalar: 1.686037 } } } @@ -935,22 +1455,21 @@ class ModelFnTests(test_util.TensorFlowTestCase): nodes { bucketized_split { feature_id: 1 - threshold: 0 left_id: 1 right_id: 2 } metadata { - gain: 0.981026 + gain: 0.972589 } } nodes { leaf { - scalar: 0.005166 + scalar: -0.137592 } } nodes { leaf { - scalar: 0.180281 + scalar: 0.034926 } } } @@ -962,7 +1481,6 @@ class ModelFnTests(test_util.TensorFlowTestCase): } tree_metadata { num_layers_grown: 1 - is_finalized: false } growing_metadata { num_trees_attempted: 2 @@ -971,19 +1489,35 @@ class ModelFnTests(test_util.TensorFlowTestCase): last_layer_node_end: 3 } """ - return (first_round, second_round, third_round) - - def _get_train_op_and_ensemble(self, head, config, is_classification, - train_in_memory): + return (first_round, second_round, third_round, forth_round) + + def _get_train_op_and_ensemble(self, + head, + config, + is_classification, + train_in_memory, + center_bias=False): """Calls bt_model_fn() and returns the train_op and ensemble_serialzed.""" features, labels = _make_train_input_fn(is_classification)() + + tree_hparams = boosted_trees._TreeHParams( # pylint:disable=protected-access + n_trees=2, + max_depth=2, + learning_rate=0.1, + l1=0., + l2=0.01, + tree_complexity=0., + min_node_weight=0., + center_bias=center_bias, + pruning_mode='none') + estimator_spec = boosted_trees._bt_model_fn( # pylint:disable=protected-access features=features, labels=labels, mode=model_fn.ModeKeys.TRAIN, head=head, feature_columns=self._feature_columns, - tree_hparams=self._tree_hparams, + tree_hparams=tree_hparams, example_id_column_name=EXAMPLE_ID_COLUMN, n_batches_per_layer=1, config=config, @@ -1032,6 +1566,49 @@ class ModelFnTests(test_util.TensorFlowTestCase): ensemble_proto.ParseFromString(serialized) self.assertProtoEquals(expected_third, ensemble_proto) + def testTrainClassifierWithCenterBiasInMemory(self): + ops.reset_default_graph() + + # When bias centering is on, we expect the very first node to have the + expected_first, expected_second, expected_third, expected_forth = ( + self._get_expected_ensembles_for_classification_with_bias()) + + with self.test_session() as sess: + with sess.graph.as_default(): + train_op, ensemble_serialized = self._get_train_op_and_ensemble( + boosted_trees._create_classification_head(n_classes=2), + run_config.RunConfig(), + is_classification=True, + train_in_memory=True, + center_bias=True) + + # 4 iterations to center bias. + for _ in range(4): + _, serialized = sess.run([train_op, ensemble_serialized]) + + # Validate the trained ensemble. + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_first, ensemble_proto) + + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_second, ensemble_proto) + + # Third round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_third, ensemble_proto) + + # Forth round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + + self.assertProtoEquals(expected_forth, ensemble_proto) + def testTrainClassifierNonInMemory(self): ops.reset_default_graph() expected_first, expected_second, expected_third = ( @@ -1062,6 +1639,47 @@ class ModelFnTests(test_util.TensorFlowTestCase): ensemble_proto.ParseFromString(serialized) self.assertProtoEquals(expected_third, ensemble_proto) + def testTrainClassifierWithCenterBiasNonInMemory(self): + ops.reset_default_graph() + + # When bias centering is on, we expect the very first node to have the + expected_first, expected_second, expected_third, expected_forth = ( + self._get_expected_ensembles_for_classification_with_bias()) + + with self.test_session() as sess: + with sess.graph.as_default(): + train_op, ensemble_serialized = self._get_train_op_and_ensemble( + boosted_trees._create_classification_head(n_classes=2), + run_config.RunConfig(), + is_classification=True, + train_in_memory=False, + center_bias=True) + # 4 iterations to center bias. + for _ in range(4): + _, serialized = sess.run([train_op, ensemble_serialized]) + # Validate the trained ensemble. + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_first, ensemble_proto) + + # Run one more time and validate the trained ensemble. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_second, ensemble_proto) + + # Third round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_third, ensemble_proto) + + # Forth round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_forth, ensemble_proto) + def testTrainRegressorInMemory(self): ops.reset_default_graph() expected_first, expected_second, expected_third = ( @@ -1092,6 +1710,46 @@ class ModelFnTests(test_util.TensorFlowTestCase): ensemble_proto.ParseFromString(serialized) self.assertProtoEquals(expected_third, ensemble_proto) + def testTrainRegressorInMemoryWithCenterBias(self): + ops.reset_default_graph() + expected_first, expected_second, expected_third, expected_forth = ( + self._get_expected_ensembles_for_regression_with_bias()) + with self.test_session() as sess: + # Train with train_in_memory mode. + with sess.graph.as_default(): + train_op, ensemble_serialized = self._get_train_op_and_ensemble( + boosted_trees._create_regression_head(label_dimension=1), + run_config.RunConfig(), + is_classification=False, + train_in_memory=True, + center_bias=True) + # 3 iterations to center bias. + for _ in range(3): + _, serialized = sess.run([train_op, ensemble_serialized]) + # Validate the trained ensemble. + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + + self.assertProtoEquals(expected_first, ensemble_proto) + + # Run one more time and validate the trained ensemble. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_second, ensemble_proto) + + # Third round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_third, ensemble_proto) + + # Forth round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_forth, ensemble_proto) + def testTrainRegressorNonInMemory(self): ops.reset_default_graph() expected_first, expected_second, expected_third = ( @@ -1122,6 +1780,46 @@ class ModelFnTests(test_util.TensorFlowTestCase): ensemble_proto.ParseFromString(serialized) self.assertProtoEquals(expected_third, ensemble_proto) + def testTrainRegressorNotInMemoryWithCenterBias(self): + ops.reset_default_graph() + expected_first, expected_second, expected_third, expected_forth = ( + self._get_expected_ensembles_for_regression_with_bias()) + with self.test_session() as sess: + # Train with train_in_memory mode. + with sess.graph.as_default(): + train_op, ensemble_serialized = self._get_train_op_and_ensemble( + boosted_trees._create_regression_head(label_dimension=1), + run_config.RunConfig(), + is_classification=False, + train_in_memory=False, + center_bias=True) + # 3 iterations to center the bias (because we are using regularization). + for _ in range(3): + _, serialized = sess.run([train_op, ensemble_serialized]) + + # Validate the trained ensemble. + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_first, ensemble_proto) + + # Run one more time and validate the trained ensemble. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_second, ensemble_proto) + + # Third round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_third, ensemble_proto) + + # Forth round training and validation. + _, serialized = sess.run([train_op, ensemble_serialized]) + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertProtoEquals(expected_forth, ensemble_proto) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py index b924ad5df43fac68696abf5fe815b5e7fa20908c..c08cf61220716730fa495c6e327b91e8f3c69cd5 100644 --- a/tensorflow/python/estimator/canned/dnn.py +++ b/tensorflow/python/estimator/canned/dnn.py @@ -26,6 +26,7 @@ from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import optimizers from tensorflow.python.feature_column import feature_column as feature_column_lib from tensorflow.python.layers import core as core_layers +from tensorflow.python.layers import normalization from tensorflow.python.ops import init_ops from tensorflow.python.ops import nn from tensorflow.python.ops import partitioned_variables @@ -45,7 +46,7 @@ def _add_hidden_layer_summary(value, tag): def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn, - dropout, input_layer_partitioner): + dropout, input_layer_partitioner, batch_norm): """Function builder for a dnn logit_fn. Args: @@ -58,6 +59,7 @@ def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn, dropout: When not `None`, the probability we will drop out a given coordinate. input_layer_partitioner: Partitioner for input layer. + batch_norm: Whether to use batch normalization after each hidden layer. Returns: A logit_fn (see below). @@ -83,6 +85,7 @@ def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn, A `Tensor` representing the logits, or a list of `Tensor`'s representing multiple logits in the MultiHead case. """ + is_training = mode == model_fn.ModeKeys.TRAIN with variable_scope.variable_scope( 'input_from_feature_columns', values=tuple(six.itervalues(features)), @@ -98,8 +101,20 @@ def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn, activation=activation_fn, kernel_initializer=init_ops.glorot_uniform_initializer(), name=hidden_layer_scope) - if dropout is not None and mode == model_fn.ModeKeys.TRAIN: + if dropout is not None and is_training: net = core_layers.dropout(net, rate=dropout, training=True) + if batch_norm: + # TODO(hjm): In future, if this becomes popular, we can enable + # customization of the batch normalization params by accepting a + # list of `BatchNormalization` instances as `batch_norm`. + net = normalization.batch_normalization( + net, + # The default momentum 0.99 actually crashes on certain + # problem, so here we use 0.999, which is the default of + # tf.contrib.layers.batch_norm. + momentum=0.999, + training=is_training, + name='batchnorm_%d' % layer_id) _add_hidden_layer_summary(net, hidden_layer_scope.name) with variable_scope.variable_scope('logits', values=(net,)) as logits_scope: @@ -127,7 +142,8 @@ def _dnn_model_fn(features, dropout=None, input_layer_partitioner=None, config=None, - tpu_estimator_spec=False): + tpu_estimator_spec=False, + batch_norm=False): """Deep Neural Net model_fn. Args: @@ -150,6 +166,7 @@ def _dnn_model_fn(features, config: `RunConfig` object to configure the runtime settings. tpu_estimator_spec: Whether to return a `_TPUEstimatorSpec` or or `model_fn.EstimatorSpec` instance. + batch_norm: Whether to use batch normalization after each hidden layer. Returns: An `EstimatorSpec` instance. @@ -182,7 +199,8 @@ def _dnn_model_fn(features, feature_columns=feature_columns, activation_fn=activation_fn, dropout=dropout, - input_layer_partitioner=input_layer_partitioner) + input_layer_partitioner=input_layer_partitioner, + batch_norm=batch_norm) logits = logit_fn(features=features, mode=mode) if tpu_estimator_spec: @@ -230,6 +248,17 @@ class DNNClassifier(estimator.Estimator): l1_regularization_strength=0.001 )) + # Or estimator using an optimizer with a learning rate decay. + estimator = DNNClassifier( + feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], + hidden_units=[1024, 512, 256], + optimizer=lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator with warm-starting from a previous checkpoint. estimator = DNNClassifier( feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], @@ -266,7 +295,10 @@ class DNNClassifier(estimator.Estimator): Loss is calculated by using softmax cross entropy. @compatibility(eager) - Estimators are not compatible with eager execution. + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. @end_compatibility """ @@ -285,6 +317,7 @@ class DNNClassifier(estimator.Estimator): config=None, warm_start_from=None, loss_reduction=losses.Reduction.SUM, + batch_norm=False, ): """Initializes a `DNNClassifier` instance. @@ -314,8 +347,9 @@ class DNNClassifier(estimator.Estimator): encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also there will be errors if vocabulary is not provided and labels are string. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to Adagrad optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to Adagrad optimizer. activation_fn: Activation function applied to each layer. If `None`, will use `tf.nn.relu`. dropout: When not `None`, the probability we will drop out a given @@ -330,6 +364,7 @@ class DNNClassifier(estimator.Estimator): names are unchanged. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. + batch_norm: Whether to use batch normalization after each hidden layer. """ head = head_lib._binary_logistic_or_multi_class_head( # pylint: disable=protected-access n_classes, weight_column, label_vocabulary, loss_reduction) @@ -346,7 +381,8 @@ class DNNClassifier(estimator.Estimator): activation_fn=activation_fn, dropout=dropout, input_layer_partitioner=input_layer_partitioner, - config=config) + config=config, + batch_norm=batch_norm) super(DNNClassifier, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config, @@ -382,6 +418,17 @@ class DNNRegressor(estimator.Estimator): l1_regularization_strength=0.001 )) + # Or estimator using an optimizer with a learning rate decay. + estimator = DNNRegressor( + feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], + hidden_units=[1024, 512, 256], + optimizer=lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator with warm-starting from a previous checkpoint. estimator = DNNRegressor( feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb], @@ -418,7 +465,10 @@ class DNNRegressor(estimator.Estimator): Loss is calculated by using mean squared error. @compatibility(eager) - Estimators are not compatible with eager execution. + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. @end_compatibility """ @@ -436,6 +486,7 @@ class DNNRegressor(estimator.Estimator): config=None, warm_start_from=None, loss_reduction=losses.Reduction.SUM, + batch_norm=False, ): """Initializes a `DNNRegressor` instance. @@ -459,8 +510,9 @@ class DNNRegressor(estimator.Estimator): used as a key to fetch weight tensor from the `features`. If it is a `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then weight_column.normalizer_fn is applied on it to get weight tensor. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to Adagrad optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to Adagrad optimizer. activation_fn: Activation function applied to each layer. If `None`, will use `tf.nn.relu`. dropout: When not `None`, the probability we will drop out a given @@ -475,6 +527,7 @@ class DNNRegressor(estimator.Estimator): names are unchanged. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. + batch_norm: Whether to use batch normalization after each hidden layer. """ def _model_fn(features, labels, mode, config): @@ -492,7 +545,8 @@ class DNNRegressor(estimator.Estimator): activation_fn=activation_fn, dropout=dropout, input_layer_partitioner=input_layer_partitioner, - config=config) + config=config, + batch_norm=batch_norm) super(DNNRegressor, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config, diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py index 64d81c46ce6f6cec3f095b8bde82a3039f3377fc..4945c3ba11cadf95d8dd4f066a115ba9791461c8 100644 --- a/tensorflow/python/estimator/canned/dnn_linear_combined.py +++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py @@ -88,7 +88,9 @@ def _dnn_linear_combined_model_fn(features, dnn_activation_fn=nn.relu, dnn_dropout=None, input_layer_partitioner=None, - config=None): + config=None, + batch_norm=False, + linear_sparse_combiner='sum'): """Deep Neural Net and Linear combined model_fn. Args: @@ -115,7 +117,10 @@ def _dnn_linear_combined_model_fn(features, coordinate. input_layer_partitioner: Partitioner for input layer. config: `RunConfig` object to configure the runtime settings. - + batch_norm: Whether to use batch normalization after each hidden layer. + linear_sparse_combiner: A string specifying how to reduce the linear model + if a categorical column is multivalent. One of "mean", "sqrtn", and + "sum". Returns: An `EstimatorSpec` instance. @@ -164,7 +169,8 @@ def _dnn_linear_combined_model_fn(features, feature_columns=dnn_feature_columns, activation_fn=dnn_activation_fn, dropout=dnn_dropout, - input_layer_partitioner=input_layer_partitioner) + input_layer_partitioner=input_layer_partitioner, + batch_norm=batch_norm) dnn_logits = dnn_logit_fn(features=features, mode=mode) linear_parent_scope = 'linear' @@ -182,7 +188,8 @@ def _dnn_linear_combined_model_fn(features, partitioner=input_layer_partitioner) as scope: logit_fn = linear._linear_logit_fn_builder( # pylint: disable=protected-access units=head.logits_dimension, - feature_columns=linear_feature_columns) + feature_columns=linear_feature_columns, + sparse_combiner=linear_sparse_combiner) linear_logits = logit_fn(features=features) _add_layer_summary(linear_logits, scope.name) @@ -257,12 +264,19 @@ class DNNLinearCombinedClassifier(estimator.Estimator): # warm-start settings warm_start_from="/path/to/checkpoint/dir") - # To apply L1 and L2 regularization, you can set optimizers as follows: + # To apply L1 and L2 regularization, you can set dnn_optimizer to: tf.train.ProximalAdagradOptimizer( learning_rate=0.1, l1_regularization_strength=0.001, l2_regularization_strength=0.001) - # It is same for FtrlOptimizer. + # To apply learning rate decay, you can set dnn_optimizer to a callable: + lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96) + # It is the same for linear_optimizer. # Input builders def input_fn_train: # returns x, y @@ -292,7 +306,10 @@ class DNNLinearCombinedClassifier(estimator.Estimator): Loss is calculated by using softmax cross entropy. @compatibility(eager) - Estimators are not compatible with eager execution. + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. @end_compatibility """ @@ -311,7 +328,9 @@ class DNNLinearCombinedClassifier(estimator.Estimator): input_layer_partitioner=None, config=None, warm_start_from=None, - loss_reduction=losses.Reduction.SUM): + loss_reduction=losses.Reduction.SUM, + batch_norm=False, + linear_sparse_combiner='sum'): """Initializes a DNNLinearCombinedClassifier instance. Args: @@ -322,12 +341,16 @@ class DNNLinearCombinedClassifier(estimator.Estimator): used by linear part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the linear part of the model. Defaults to FTRL optimizer. + the linear part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL + optimizer. dnn_feature_columns: An iterable containing all the feature columns used by deep part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the deep part of the model. Defaults to Adagrad optimizer. + the deep part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad + optimizer. dnn_hidden_units: List of hidden units per layer. All layers are fully connected. dnn_activation_fn: Activation function applied to each layer. If None, @@ -360,6 +383,12 @@ class DNNLinearCombinedClassifier(estimator.Estimator): names are unchanged. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. + batch_norm: Whether to use batch normalization after each hidden layer. + linear_sparse_combiner: A string specifying how to reduce the linear model + if a categorical column is multivalent. One of "mean", "sqrtn", and + "sum" -- these are effectively different ways to do example-level + normalization, which can be useful for bag-of-words features. For more + details, see `tf.feature_column.linear_model`. Raises: ValueError: If both linear_feature_columns and dnn_features_columns are @@ -399,7 +428,9 @@ class DNNLinearCombinedClassifier(estimator.Estimator): dnn_activation_fn=dnn_activation_fn, dnn_dropout=dnn_dropout, input_layer_partitioner=input_layer_partitioner, - config=config) + config=config, + batch_norm=batch_norm, + linear_sparse_combiner=linear_sparse_combiner) super(DNNLinearCombinedClassifier, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config, @@ -438,12 +469,19 @@ class DNNLinearCombinedRegressor(estimator.Estimator): # warm-start settings warm_start_from="/path/to/checkpoint/dir") - # To apply L1 and L2 regularization, you can set optimizers as follows: + # To apply L1 and L2 regularization, you can set dnn_optimizer to: tf.train.ProximalAdagradOptimizer( learning_rate=0.1, l1_regularization_strength=0.001, l2_regularization_strength=0.001) - # It is same for FtrlOptimizer. + # To apply learning rate decay, you can set dnn_optimizer to a callable: + lambda: tf.AdamOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96) + # It is the same for linear_optimizer. # Input builders def input_fn_train: # returns x, y @@ -473,7 +511,10 @@ class DNNLinearCombinedRegressor(estimator.Estimator): Loss is calculated by using mean squared error. @compatibility(eager) - Estimators are not compatible with eager execution. + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. @end_compatibility """ @@ -491,7 +532,9 @@ class DNNLinearCombinedRegressor(estimator.Estimator): input_layer_partitioner=None, config=None, warm_start_from=None, - loss_reduction=losses.Reduction.SUM): + loss_reduction=losses.Reduction.SUM, + batch_norm=False, + linear_sparse_combiner='sum'): """Initializes a DNNLinearCombinedRegressor instance. Args: @@ -502,12 +545,16 @@ class DNNLinearCombinedRegressor(estimator.Estimator): used by linear part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the linear part of the model. Defaults to FTRL optimizer. + the linear part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL + optimizer. dnn_feature_columns: An iterable containing all the feature columns used by deep part of the model. All items in the set must be instances of classes derived from `FeatureColumn`. dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the deep part of the model. Defaults to Adagrad optimizer. + the deep part of the model. Can also be a string (one of 'Adagrad', + 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad + optimizer. dnn_hidden_units: List of hidden units per layer. All layers are fully connected. dnn_activation_fn: Activation function applied to each layer. If None, @@ -534,6 +581,12 @@ class DNNLinearCombinedRegressor(estimator.Estimator): names are unchanged. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. + batch_norm: Whether to use batch normalization after each hidden layer. + linear_sparse_combiner: A string specifying how to reduce the linear model + if a categorical column is multivalent. One of "mean", "sqrtn", and + "sum" -- these are effectively different ways to do example-level + normalization, which can be useful for bag-of-words features. For more + details, see `tf.feature_column.linear_model`. Raises: ValueError: If both linear_feature_columns and dnn_features_columns are @@ -564,7 +617,9 @@ class DNNLinearCombinedRegressor(estimator.Estimator): dnn_activation_fn=dnn_activation_fn, dnn_dropout=dnn_dropout, input_layer_partitioner=input_layer_partitioner, - config=config) + config=config, + batch_norm=batch_norm, + linear_sparse_combiner=linear_sparse_combiner) super(DNNLinearCombinedRegressor, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config, diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py index d275695eb319117cf94aefd7038ab5ee685e05a9..d16318659ba8fac70486e88fff07d71e060eac9b 100644 --- a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py +++ b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py @@ -100,7 +100,8 @@ def _linear_regressor_fn(feature_columns, weight_column=None, optimizer='Ftrl', config=None, - partitioner=None): + partitioner=None, + sparse_combiner='sum'): return dnn_linear_combined.DNNLinearCombinedRegressor( model_dir=model_dir, linear_feature_columns=feature_columns, @@ -108,7 +109,8 @@ def _linear_regressor_fn(feature_columns, label_dimension=label_dimension, weight_column=weight_column, input_layer_partitioner=partitioner, - config=config) + config=config, + linear_sparse_combiner=sparse_combiner) class LinearOnlyRegressorPartitionerTest( @@ -163,7 +165,8 @@ def _linear_classifier_fn(feature_columns, label_vocabulary=None, optimizer='Ftrl', config=None, - partitioner=None): + partitioner=None, + sparse_combiner='sum'): return dnn_linear_combined.DNNLinearCombinedClassifier( model_dir=model_dir, linear_feature_columns=feature_columns, @@ -172,7 +175,8 @@ def _linear_classifier_fn(feature_columns, weight_column=weight_column, label_vocabulary=label_vocabulary, input_layer_partitioner=partitioner, - config=config) + config=config, + linear_sparse_combiner=sparse_combiner) class LinearOnlyClassifierTrainingTest( diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py index 06a648777f8f730b4c739a69528090c5821f2681..de226ed0ef28e6a026e5df6ce128e178254a8c93 100644 --- a/tensorflow/python/estimator/canned/dnn_testing_utils.py +++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py @@ -65,6 +65,11 @@ from tensorflow.python.training import training_util LEARNING_RATE_NAME = 'dnn/regression_head/dnn/learning_rate' HIDDEN_WEIGHTS_NAME_PATTERN = 'dnn/hiddenlayer_%d/kernel' HIDDEN_BIASES_NAME_PATTERN = 'dnn/hiddenlayer_%d/bias' +BATCH_NORM_BETA_NAME_PATTERN = 'dnn/hiddenlayer_%d/batchnorm_%d/beta' +BATCH_NORM_GAMMA_NAME_PATTERN = 'dnn/hiddenlayer_%d/batchnorm_%d/gamma' +BATCH_NORM_MEAN_NAME_PATTERN = 'dnn/hiddenlayer_%d/batchnorm_%d/moving_mean' +BATCH_NORM_VARIANCE_NAME_PATTERN = ( + 'dnn/hiddenlayer_%d/batchnorm_%d/moving_variance') LOGITS_WEIGHTS_NAME = 'dnn/logits/kernel' LOGITS_BIASES_NAME = 'dnn/logits/bias' OCCUPATION_EMBEDDING_NAME = ('dnn/input_from_feature_columns/input_layer/' @@ -89,7 +94,10 @@ def assert_close(expected, actual, rtol=1e-04, message='', name='assert_close'): name=scope) -def create_checkpoint(weights_and_biases, global_step, model_dir): +def create_checkpoint(weights_and_biases, + global_step, + model_dir, + batch_norm_vars=None): """Create checkpoint file with provided model weights. Args: @@ -98,12 +106,20 @@ def create_checkpoint(weights_and_biases, global_step, model_dir): model_dir: Directory into which checkpoint is saved. """ weights, biases = zip(*weights_and_biases) + if batch_norm_vars: + assert len(batch_norm_vars) == len(weights_and_biases) - 1 + (bn_betas, bn_gammas, bn_means, bn_variances) = zip(*batch_norm_vars) model_weights = {} # Hidden layer weights. for i in range(0, len(weights) - 1): model_weights[HIDDEN_WEIGHTS_NAME_PATTERN % i] = weights[i] model_weights[HIDDEN_BIASES_NAME_PATTERN % i] = biases[i] + if batch_norm_vars: + model_weights[BATCH_NORM_BETA_NAME_PATTERN % (i, i)] = bn_betas[i] + model_weights[BATCH_NORM_GAMMA_NAME_PATTERN % (i, i)] = bn_gammas[i] + model_weights[BATCH_NORM_MEAN_NAME_PATTERN % (i, i)] = bn_means[i] + model_weights[BATCH_NORM_VARIANCE_NAME_PATTERN % (i, i)] = bn_variances[i] # Output layer weights. model_weights[LOGITS_WEIGHTS_NAME] = weights[-1] @@ -503,8 +519,13 @@ class BaseDNNLogitFnTest(object): writer_cache.FileWriterCache.clear() shutil.rmtree(self._model_dir) - def _test_logits(self, mode, hidden_units, logits_dimension, inputs, - expected_logits): + def _test_logits(self, + mode, + hidden_units, + logits_dimension, + inputs, + expected_logits, + batch_norm=False): """Tests that the expected logits are calculated.""" with ops.Graph().as_default(): # Global step needed for MonitoredSession, which is in turn used to @@ -525,7 +546,8 @@ class BaseDNNLogitFnTest(object): ], activation_fn=nn.relu, dropout=None, - input_layer_partitioner=input_layer_partitioner) + input_layer_partitioner=input_layer_partitioner, + batch_norm=batch_norm) logits = logit_fn( features={'age': constant_op.constant(inputs)}, mode=mode) with monitored_session.MonitoredTrainingSession( @@ -556,6 +578,69 @@ class BaseDNNLogitFnTest(object): inputs=[[10.]], expected_logits=[[-2.08]]) + def test_one_dim_logits_with_batch_norm(self): + """Tests one-dimensional logits. + + input_layer = [[10]] + hidden_layer_0 = [[relu(0.6*10 +1), relu(0.5*10 -1)]] = [[7, 4]] + hidden_layer_0 = [[relu(0.6*20 +1), relu(0.5*20 -1)]] = [[13, 9]] + + batch_norm_0, training (epsilon = 0.001): + mean1 = 1/2*(7+13) = 10, + variance1 = 1/2*(3^2+3^2) = 9 + x11 = (7-10)/sqrt(9+0.001) = -0.999944449, + x21 = (13-10)/sqrt(9+0.001) = 0.999944449, + + mean2 = 1/2*(4+9) = 6.5, + variance2 = 1/2*(2.5^2+.2.5^2) = 6.25 + x12 = (4-6.5)/sqrt(6.25+0.001) = -0.99992001, + x22 = (9-6.5)/sqrt(6.25+0.001) = 0.99992001, + + logits = [[-1*(-0.999944449) + 2*(-0.99992001) + 0.3], + [-1*0.999944449 + 2*0.99992001 + 0.3]] + = [[-0.699895571],[1.299895571]] + + batch_norm_0, not training (epsilon = 0.001): + moving_mean1 = 0, moving_variance1 = 1 + x11 = (7-0)/sqrt(1+0.001) = 6.996502623, + x21 = (13-0)/sqrt(1+0.001) = 12.993504871, + moving_mean2 = 0, moving_variance2 = 1 + x12 = (4-0)/sqrt(1+0.001) = 3.998001499, + x22 = (9-0)/sqrt(1+0.001) = 8.995503372, + + logits = [[-1*6.996502623 + 2*3.998001499 + 0.3], + [-1*12.993504871 + 2*8.995503372 + 0.3]] + = [[1.299500375],[5.297501873]] + """ + base_global_step = 100 + create_checkpoint( + ( + ([[.6, .5]], [1., -1.]), + ([[-1.], [2.]], [.3]), + ), + base_global_step, + self._model_dir, + batch_norm_vars=([[0, 0], # beta. + [1, 1], # gamma. + [0, 0], # moving mean. + [1, 1], # moving variance. + ],)) + self._test_logits( + model_fn.ModeKeys.TRAIN, + hidden_units=[2], + logits_dimension=1, + inputs=[[10.], [20.]], + expected_logits=[[-0.699895571], [1.299895571]], + batch_norm=True) + for mode in [model_fn.ModeKeys.EVAL, model_fn.ModeKeys.PREDICT]: + self._test_logits( + mode, + hidden_units=[2], + logits_dimension=1, + inputs=[[10.], [20.]], + expected_logits=[[1.299500375], [5.297501873]], + batch_norm=True) + def test_multi_dim_logits(self): """Tests multi-dimensional logits. @@ -706,7 +791,8 @@ class BaseDNNLogitFnTest(object): ], activation_fn=nn.relu, dropout=None, - input_layer_partitioner=input_layer_partitioner) + input_layer_partitioner=input_layer_partitioner, + batch_norm=False) logits = logit_fn( features={ 'age': constant_op.constant(inputs[0]), @@ -1185,6 +1271,8 @@ class BaseDNNRegressorEvaluateTest(object): self.assertAllClose({ metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS_MEAN: expected_loss, + metric_keys.MetricKeys.PREDICTION_MEAN: -2.08, + metric_keys.MetricKeys.LABEL_MEAN: 1.0, ops.GraphKeys.GLOBAL_STEP: global_step }, dnn_regressor.evaluate(input_fn=_input_fn, steps=1)) @@ -1215,6 +1303,8 @@ class BaseDNNRegressorEvaluateTest(object): self.assertAllClose({ metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS_MEAN: expected_loss / label_dimension, + metric_keys.MetricKeys.PREDICTION_MEAN: 0.39 / 3.0, + metric_keys.MetricKeys.LABEL_MEAN: 0.5 / 3.0, ops.GraphKeys.GLOBAL_STEP: global_step }, dnn_regressor.evaluate(input_fn=_input_fn, steps=1)) diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py index b74ef1015cc564c20370e17e94e3a09d460c4f85..da9a64c2bc9f6b6797ef6cc115f36a73616b2e1e 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -1398,15 +1398,21 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): weights=weights, processed_labels=labels) - def _eval_metric_ops(self, weights, unreduced_loss, regularization_loss): + def _eval_metric_ops(self, predicted_value, labels, weights, unreduced_loss, + regularization_loss): """Returns the Eval metric ops.""" keys = metric_keys.MetricKeys # Estimator already adds a metric for loss. eval_metric_ops = { _summary_key(self._name, keys.LOSS_MEAN): - metrics_lib.mean( - values=unreduced_loss, - weights=weights) + metrics_lib.mean(values=unreduced_loss, weights=weights), + _summary_key(self._name, keys.PREDICTION_MEAN): + _predictions_mean( + predictions=predicted_value, + weights=weights, + name=keys.PREDICTION_MEAN), + _summary_key(self._name, keys.LABEL_MEAN): + metrics_lib.mean(values=labels, weights=weights) } if regularization_loss is not None: regularization_loss_key = _summary_key( @@ -1489,13 +1495,13 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): predictions=predictions, loss=regularized_training_loss, eval_metrics=_create_eval_metrics_tuple( - self._eval_metric_ops, - { + self._eval_metric_ops, { + 'predicted_value': predicted_value, + 'labels': labels, 'weights': weights, 'unreduced_loss': unreduced_loss, 'regularization_loss': regularization_loss, - } - )) + })) # Train. if optimizer is not None: diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py index 08ce5ca8e833fdd88f9c45b668f0914fcc70acd0..bd2e0ae943fb4da2acc09b120db59cf08e4ed9e6 100644 --- a/tensorflow/python/estimator/canned/head_test.py +++ b/tensorflow/python/estimator/canned/head_test.py @@ -3103,8 +3103,10 @@ class RegressionHead(test.TestCase): self.assertItemsEqual((prediction_key,), spec.predictions.keys()) self.assertEqual(dtypes.float32, spec.predictions[prediction_key].dtype) self.assertEqual(dtypes.float32, spec.loss.dtype) - self.assertItemsEqual( - (metric_keys.MetricKeys.LOSS_MEAN,), spec.eval_metric_ops.keys()) + self.assertItemsEqual((metric_keys.MetricKeys.LOSS_MEAN, + metric_keys.MetricKeys.PREDICTION_MEAN, + metric_keys.MetricKeys.LABEL_MEAN), + spec.eval_metric_ops.keys()) self.assertIsNone(spec.train_op) self.assertIsNone(spec.export_outputs) _assert_no_hooks(self, spec) @@ -3140,6 +3142,9 @@ class RegressionHead(test.TestCase): expected_metric_keys = [ '{}/some_regression_head'.format(metric_keys.MetricKeys.LOSS_MEAN), + '{}/some_regression_head'.format( + metric_keys.MetricKeys.PREDICTION_MEAN), + '{}/some_regression_head'.format(metric_keys.MetricKeys.LABEL_MEAN), ] self.assertItemsEqual(expected_metric_keys, spec.eval_metric_ops.keys()) @@ -3170,6 +3175,8 @@ class RegressionHead(test.TestCase): expected_metrics = { keys.LOSS_MEAN: expected_unregularized_loss, keys.LOSS_REGULARIZATION: expected_regularization_loss, + keys.PREDICTION_MEAN: (45 + 41) / 2.0, + keys.LABEL_MEAN: (43 + 44) / 2.0, } # Assert predictions, loss, and metrics. @@ -3471,8 +3478,10 @@ class RegressionHead(test.TestCase): self.assertItemsEqual((prediction_key,), spec.predictions.keys()) self.assertEqual(dtypes.float32, spec.predictions[prediction_key].dtype) self.assertEqual(dtypes.float32, spec.loss.dtype) - self.assertItemsEqual( - (metric_keys.MetricKeys.LOSS_MEAN,), spec.eval_metric_ops.keys()) + self.assertItemsEqual((metric_keys.MetricKeys.LOSS_MEAN, + metric_keys.MetricKeys.PREDICTION_MEAN, + metric_keys.MetricKeys.LABEL_MEAN), + spec.eval_metric_ops.keys()) self.assertIsNone(spec.train_op) self.assertIsNone(spec.export_outputs) _assert_no_hooks(self, spec) @@ -3700,8 +3709,10 @@ class RegressionHead(test.TestCase): self.assertItemsEqual((prediction_key,), spec.predictions.keys()) self.assertEqual(dtypes.float32, spec.predictions[prediction_key].dtype) self.assertEqual(dtypes.float32, spec.loss.dtype) - self.assertItemsEqual( - (metric_keys.MetricKeys.LOSS_MEAN,), spec.eval_metric_ops.keys()) + self.assertItemsEqual((metric_keys.MetricKeys.LOSS_MEAN, + metric_keys.MetricKeys.PREDICTION_MEAN, + metric_keys.MetricKeys.LABEL_MEAN), + spec.eval_metric_ops.keys()) self.assertIsNone(spec.train_op) self.assertIsNone(spec.export_outputs) _assert_no_hooks(self, spec) @@ -3832,7 +3843,13 @@ class RegressionHead(test.TestCase): # losses = [1*(35-45)^2, .1*(42-41)^2, 1.5*(45-44)^2] = [100, .1, 1.5] # loss = sum(losses) = 100+.1+1.5 = 101.6 # loss_mean = loss/(1+.1+1.5) = 101.6/2.6 = 39.076923 - expected_metrics = {metric_keys.MetricKeys.LOSS_MEAN: 39.076923} + expected_metrics = { + metric_keys.MetricKeys.LOSS_MEAN: + 39.076923, + metric_keys.MetricKeys.PREDICTION_MEAN: + (45 + 41 * 0.1 + 44 * 1.5) / 2.6, + metric_keys.MetricKeys.LABEL_MEAN: (35 + 42 * 0.1 + 45 * 1.5) / 2.6, + } # Assert spec contains expected tensors. self.assertEqual(dtypes.float32, spec.loss.dtype) diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py index 705fc3ce06ee9b1c75b3d7adefdfb2ec8459335c..115dd185185adb049d7ce04592fa8dac1e7e4f82 100644 --- a/tensorflow/python/estimator/canned/linear.py +++ b/tensorflow/python/estimator/canned/linear.py @@ -66,13 +66,15 @@ def _compute_fraction_of_zero(cols_to_vars): return nn.zero_fraction(array_ops.concat(all_weight_vars, axis=0)) -def _linear_logit_fn_builder(units, feature_columns): +def _linear_logit_fn_builder(units, feature_columns, sparse_combiner='sum'): """Function builder for a linear logit_fn. Args: units: An int indicating the dimension of the logit layer. feature_columns: An iterable containing all the feature columns used by the model. + sparse_combiner: A string specifying how to reduce if a categorical column + is multivalent. One of "mean", "sqrtn", and "sum". Returns: A logit_fn (see below). @@ -95,6 +97,7 @@ def _linear_logit_fn_builder(units, feature_columns): features=features, feature_columns=feature_columns, units=units, + sparse_combiner=sparse_combiner, cols_to_vars=cols_to_vars) bias = cols_to_vars.pop('bias') if units > 1: @@ -111,7 +114,7 @@ def _linear_logit_fn_builder(units, feature_columns): def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer, - partitioner, config): + partitioner, config, sparse_combiner='sum'): """A model_fn for linear models that use a gradient-based optimizer. Args: @@ -126,6 +129,8 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer, optimizer to use for training. If `None`, will use a FTRL optimizer. partitioner: Partitioner for variables. config: `RunConfig` object to configure the runtime settings. + sparse_combiner: A string specifying how to reduce if a categorical column + is multivalent. One of "mean", "sqrtn", and "sum". Returns: An `EstimatorSpec` instance. @@ -153,7 +158,8 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer, partitioner=partitioner): logit_fn = _linear_logit_fn_builder( - units=head.logits_dimension, feature_columns=feature_columns) + units=head.logits_dimension, feature_columns=feature_columns, + sparse_combiner=sparse_combiner) logits = logit_fn(features=features) return head.create_estimator_spec( @@ -193,6 +199,17 @@ class LinearClassifier(estimator.Estimator): l1_regularization_strength=0.001 )) + # Or estimator using an optimizer with a learning rate decay. + estimator = LinearClassifier( + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b], + optimizer=lambda: tf.train.FtrlOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator with warm-starting from a previous checkpoint. estimator = LinearClassifier( feature_columns=[categorical_column_a, @@ -227,7 +244,10 @@ class LinearClassifier(estimator.Estimator): Loss is calculated by using softmax cross entropy. @compatibility(eager) - Estimators are not compatible with eager execution. + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. @end_compatibility """ @@ -241,7 +261,8 @@ class LinearClassifier(estimator.Estimator): config=None, partitioner=None, warm_start_from=None, - loss_reduction=losses.Reduction.SUM): + loss_reduction=losses.Reduction.SUM, + sparse_combiner='sum'): """Construct a `LinearClassifier` estimator object. Args: @@ -269,8 +290,9 @@ class LinearClassifier(estimator.Estimator): encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also there will be errors if vocabulary is not provided and labels are string. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to FTRL optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to FTRL optimizer. config: `RunConfig` object to configure the runtime settings. partitioner: Optional. Partitioner for input layer. warm_start_from: A string filepath to a checkpoint to warm-start from, or @@ -280,6 +302,11 @@ class LinearClassifier(estimator.Estimator): and Tensor names are unchanged. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. + sparse_combiner: A string specifying how to reduce if a categorical column + is multivalent. One of "mean", "sqrtn", and "sum" -- these are + effectively different ways to do example-level normalization, which can + be useful for bag-of-words features. for more details, see + `tf.feature_column.linear_model`. Returns: A `LinearClassifier` estimator. @@ -308,7 +335,8 @@ class LinearClassifier(estimator.Estimator): feature_columns=tuple(feature_columns or []), optimizer=optimizer, partitioner=partitioner, - config=config) + config=config, + sparse_combiner=sparse_combiner) super(LinearClassifier, self).__init__( model_fn=_model_fn, @@ -332,10 +360,31 @@ class LinearRegressor(estimator.Estimator): categorical_feature_a_x_categorical_feature_b = crossed_column(...) + # Estimator using the default optimizer. estimator = LinearRegressor( feature_columns=[categorical_column_a, categorical_feature_a_x_categorical_feature_b]) + # Or estimator using the FTRL optimizer with regularization. + estimator = LinearRegressor( + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b], + optimizer=tf.train.FtrlOptimizer( + learning_rate=0.1, + l1_regularization_strength=0.001 + )) + + # Or estimator using an optimizer with a learning rate decay. + estimator = LinearRegressor( + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b], + optimizer=lambda: tf.train.FtrlOptimizer( + learning_rate=tf.exponential_decay( + learning_rate=0.1, + global_step=tf.get_global_step(), + decay_steps=10000, + decay_rate=0.96)) + # Or estimator with warm-starting from a previous checkpoint. estimator = LinearRegressor( feature_columns=[categorical_column_a, @@ -370,7 +419,10 @@ class LinearRegressor(estimator.Estimator): Loss is calculated by using mean squared error. @compatibility(eager) - Estimators are not compatible with eager execution. + Estimators can be used while eager execution is enabled. Note that `input_fn` + and all hooks are executed inside a graph context, so they have to be written + to be compatible with graph mode. Note that `input_fn` code using `tf.data` + generally works in both graph and eager modes. @end_compatibility """ @@ -383,7 +435,8 @@ class LinearRegressor(estimator.Estimator): config=None, partitioner=None, warm_start_from=None, - loss_reduction=losses.Reduction.SUM): + loss_reduction=losses.Reduction.SUM, + sparse_combiner='sum'): """Initializes a `LinearRegressor` instance. Args: @@ -403,8 +456,9 @@ class LinearRegressor(estimator.Estimator): used as a key to fetch weight tensor from the `features`. If it is a `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then weight_column.normalizer_fn is applied on it to get weight tensor. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to FTRL optimizer. + optimizer: An instance of `tf.Optimizer` used to train the model. Can also + be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or + callable. Defaults to FTRL optimizer. config: `RunConfig` object to configure the runtime settings. partitioner: Optional. Partitioner for input layer. warm_start_from: A string filepath to a checkpoint to warm-start from, or @@ -414,6 +468,11 @@ class LinearRegressor(estimator.Estimator): and Tensor names are unchanged. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. + sparse_combiner: A string specifying how to reduce if a categorical column + is multivalent. One of "mean", "sqrtn", and "sum" -- these are + effectively different ways to do example-level normalization, which can + be useful for bag-of-words features. for more details, see + `tf.feature_column.linear_model`. """ head = head_lib._regression_head( # pylint: disable=protected-access label_dimension=label_dimension, weight_column=weight_column, @@ -429,7 +488,8 @@ class LinearRegressor(estimator.Estimator): feature_columns=tuple(feature_columns or []), optimizer=optimizer, partitioner=partitioner, - config=config) + config=config, + sparse_combiner=sparse_combiner) super(LinearRegressor, self).__init__( model_fn=_model_fn, diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py index 0e6436b42143f4b136165d47c41e143dacb4d476..c3934c7a801033d587465f0926301f30d4257fc7 100644 --- a/tensorflow/python/estimator/canned/linear_testing_utils.py +++ b/tensorflow/python/estimator/canned/linear_testing_utils.py @@ -29,6 +29,7 @@ import six from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 from tensorflow.python.client import session as tf_session +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator from tensorflow.python.estimator import run_config from tensorflow.python.estimator.canned import linear @@ -260,6 +261,8 @@ class BaseLinearRegressorEvaluationTest(object): self.assertDictEqual({ metric_keys.MetricKeys.LOSS: 9., metric_keys.MetricKeys.LOSS_MEAN: 9., + metric_keys.MetricKeys.PREDICTION_MEAN: 13., + metric_keys.MetricKeys.LABEL_MEAN: 10., ops.GraphKeys.GLOBAL_STEP: 100 }, eval_metrics) @@ -285,6 +288,8 @@ class BaseLinearRegressorEvaluationTest(object): self.assertDictEqual({ metric_keys.MetricKeys.LOSS: 18., metric_keys.MetricKeys.LOSS_MEAN: 9., + metric_keys.MetricKeys.PREDICTION_MEAN: 13., + metric_keys.MetricKeys.LABEL_MEAN: 10., ops.GraphKeys.GLOBAL_STEP: 100 }, eval_metrics) @@ -315,6 +320,8 @@ class BaseLinearRegressorEvaluationTest(object): self.assertDictEqual({ metric_keys.MetricKeys.LOSS: 27., metric_keys.MetricKeys.LOSS_MEAN: 9., + metric_keys.MetricKeys.PREDICTION_MEAN: 13., + metric_keys.MetricKeys.LABEL_MEAN: 10., ops.GraphKeys.GLOBAL_STEP: 100 }, eval_metrics) @@ -345,7 +352,9 @@ class BaseLinearRegressorEvaluationTest(object): self.assertItemsEqual( (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN, - ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys()) + metric_keys.MetricKeys.PREDICTION_MEAN, + metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP), + eval_metrics.keys()) # Logit is # [2., 4., 5.] * [1.0, 2.0] + [7.0, 8.0] = [39, 50] + [7.0, 8.0] @@ -382,7 +391,9 @@ class BaseLinearRegressorEvaluationTest(object): eval_metrics = est.evaluate(input_fn=input_fn, steps=1) self.assertItemsEqual( (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN, - ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys()) + metric_keys.MetricKeys.PREDICTION_MEAN, + metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP), + eval_metrics.keys()) # Logit is [(20. * 10.0 + 4 * 2.0 + 5.0), (40. * 10.0 + 8 * 2.0 + 5.0)] = # [213.0, 421.0], while label is [213., 421.]. Loss = 0. @@ -484,6 +495,69 @@ class BaseLinearRegressorPredictTest(object): # x0 * weight0 + x1 * weight1 + bias = 2. * 10. + 3. * 20 + .2 = 80.2 self.assertAllClose([[80.2]], predicted_scores) + def testSparseCombiner(self): + w_a = 2.0 + w_b = 3.0 + w_c = 5.0 + bias = 5.0 + with ops.Graph().as_default(): + variables_lib.Variable([[w_a], [w_b], [w_c]], name=LANGUAGE_WEIGHT_NAME) + variables_lib.Variable([bias], name=BIAS_NAME) + variables_lib.Variable(1, name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + def _input_fn(): + return dataset_ops.Dataset.from_tensors({ + 'language': sparse_tensor.SparseTensor( + values=['a', 'c', 'b', 'c'], + indices=[[0, 0], [0, 1], [1, 0], [1, 1]], + dense_shape=[2, 2]), + }) + + feature_columns = ( + feature_column_lib.categorical_column_with_vocabulary_list( + 'language', vocabulary_list=['a', 'b', 'c']),) + + # Check prediction for each sparse_combiner. + # With sparse_combiner = 'sum', we have + # logits_1 = w_a + w_c + bias + # = 2.0 + 5.0 + 5.0 = 12.0 + # logits_2 = w_b + w_c + bias + # = 3.0 + 5.0 + 5.0 = 13.0 + linear_regressor = self._linear_regressor_fn( + feature_columns=feature_columns, + model_dir=self._model_dir) + predictions = linear_regressor.predict(input_fn=_input_fn) + predicted_scores = list([x['predictions'] for x in predictions]) + self.assertAllClose([[12.0], [13.0]], predicted_scores) + + # With sparse_combiner = 'mean', we have + # logits_1 = 1/2 * (w_a + w_c) + bias + # = 1/2 * (2.0 + 5.0) + 5.0 = 8.5 + # logits_2 = 1/2 * (w_b + w_c) + bias + # = 1/2 * (3.0 + 5.0) + 5.0 = 9.0 + linear_regressor = self._linear_regressor_fn( + feature_columns=feature_columns, + model_dir=self._model_dir, + sparse_combiner='mean') + predictions = linear_regressor.predict(input_fn=_input_fn) + predicted_scores = list([x['predictions'] for x in predictions]) + self.assertAllClose([[8.5], [9.0]], predicted_scores) + + # With sparse_combiner = 'sqrtn', we have + # logits_1 = sqrt(2)/2 * (w_a + w_c) + bias + # = sqrt(2)/2 * (2.0 + 5.0) + 5.0 = 9.94974 + # logits_2 = sqrt(2)/2 * (w_b + w_c) + bias + # = sqrt(2)/2 * (3.0 + 5.0) + 5.0 = 10.65685 + linear_regressor = self._linear_regressor_fn( + feature_columns=feature_columns, + model_dir=self._model_dir, + sparse_combiner='sqrtn') + predictions = linear_regressor.predict(input_fn=_input_fn) + predicted_scores = list([x['predictions'] for x in predictions]) + self.assertAllClose([[9.94974], [10.65685]], predicted_scores) + class BaseLinearRegressorIntegrationTest(object): @@ -1636,6 +1710,69 @@ class BaseLinearClassifierPredictTest(object): for i in range(n_classes)], label_output_fn=lambda x: ('class_vocab_%s' % x).encode()) + def testSparseCombiner(self): + w_a = 2.0 + w_b = 3.0 + w_c = 5.0 + bias = 5.0 + with ops.Graph().as_default(): + variables_lib.Variable([[w_a], [w_b], [w_c]], name=LANGUAGE_WEIGHT_NAME) + variables_lib.Variable([bias], name=BIAS_NAME) + variables_lib.Variable(1, name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + def _input_fn(): + return dataset_ops.Dataset.from_tensors({ + 'language': sparse_tensor.SparseTensor( + values=['a', 'c', 'b', 'c'], + indices=[[0, 0], [0, 1], [1, 0], [1, 1]], + dense_shape=[2, 2]), + }) + + feature_columns = ( + feature_column_lib.categorical_column_with_vocabulary_list( + 'language', vocabulary_list=['a', 'b', 'c']),) + + # Check prediction for each sparse_combiner. + # With sparse_combiner = 'sum', we have + # logits_1 = w_a + w_c + bias + # = 2.0 + 5.0 + 5.0 = 12.0 + # logits_2 = w_b + w_c + bias + # = 3.0 + 5.0 + 5.0 = 13.0 + linear_classifier = self._linear_classifier_fn( + feature_columns=feature_columns, + model_dir=self._model_dir) + predictions = linear_classifier.predict(input_fn=_input_fn) + predicted_scores = list([x['logits'] for x in predictions]) + self.assertAllClose([[12.0], [13.0]], predicted_scores) + + # With sparse_combiner = 'mean', we have + # logits_1 = 1/2 * (w_a + w_c) + bias + # = 1/2 * (2.0 + 5.0) + 5.0 = 8.5 + # logits_2 = 1/2 * (w_b + w_c) + bias + # = 1/2 * (3.0 + 5.0) + 5.0 = 9.0 + linear_classifier = self._linear_classifier_fn( + feature_columns=feature_columns, + model_dir=self._model_dir, + sparse_combiner='mean') + predictions = linear_classifier.predict(input_fn=_input_fn) + predicted_scores = list([x['logits'] for x in predictions]) + self.assertAllClose([[8.5], [9.0]], predicted_scores) + + # With sparse_combiner = 'sqrtn', we have + # logits_1 = sqrt(2)/2 * (w_a + w_c) + bias + # = sqrt(2)/2 * (2.0 + 5.0) + 5.0 = 9.94974 + # logits_2 = sqrt(2)/2 * (w_b + w_c) + bias + # = sqrt(2)/2 * (3.0 + 5.0) + 5.0 = 10.65685 + linear_classifier = self._linear_classifier_fn( + feature_columns=feature_columns, + model_dir=self._model_dir, + sparse_combiner='sqrtn') + predictions = linear_classifier.predict(input_fn=_input_fn) + predicted_scores = list([x['logits'] for x in predictions]) + self.assertAllClose([[9.94974], [10.65685]], predicted_scores) + class BaseLinearClassifierIntegrationTest(object): diff --git a/tensorflow/python/estimator/canned/metric_keys.py b/tensorflow/python/estimator/canned/metric_keys.py index 4f7c849ba4b058492c55dd27e0bf79f8d540ece9..9d49240fea4579fffe25172092080560ccd1d35d 100644 --- a/tensorflow/python/estimator/canned/metric_keys.py +++ b/tensorflow/python/estimator/canned/metric_keys.py @@ -47,3 +47,8 @@ class MetricKeys(object): PROBABILITY_MEAN_AT_CLASS = 'probability_mean/class%d' AUC_AT_CLASS = 'auc/class%d' AUC_PR_AT_CLASS = 'auc_precision_recall/class%d' + + # The following require a class name applied. + PROBABILITY_MEAN_AT_NAME = 'probability_mean/%s' + AUC_AT_NAME = 'auc/%s' + AUC_PR_AT_NAME = 'auc_precision_recall/%s' diff --git a/tensorflow/python/estimator/canned/optimizers.py b/tensorflow/python/estimator/canned/optimizers.py index f72c5ca5cbb2721d967ad9ef9dfa896f7ccce240..8f51cc3a80dd9b91eb24a83577b7d0614615e008 100644 --- a/tensorflow/python/estimator/canned/optimizers.py +++ b/tensorflow/python/estimator/canned/optimizers.py @@ -72,6 +72,8 @@ def get_optimizer_instance(opt, learning_rate=None): raise ValueError( 'Unsupported optimizer name: {}. Supported names are: {}'.format( opt, tuple(sorted(six.iterkeys(_OPTIMIZER_CLS_NAMES))))) + if callable(opt): + opt = opt() if not isinstance(opt, optimizer_lib.Optimizer): raise ValueError( 'The given object is not an Optimizer instance. Given: {}'.format(opt)) diff --git a/tensorflow/python/estimator/canned/optimizers_test.py b/tensorflow/python/estimator/canned/optimizers_test.py index ee28756155afd5ae3421475c3d41542db9411345..eadabdbc496334270cd792f5b8d5ff39a446bcf7 100644 --- a/tensorflow/python/estimator/canned/optimizers_test.py +++ b/tensorflow/python/estimator/canned/optimizers_test.py @@ -28,6 +28,13 @@ from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import rmsprop +class _TestOptimizer(optimizer_lib.Optimizer): + + def __init__(self): + super(_TestOptimizer, self).__init__( + use_locking=False, name='TestOptimizer') + + class GetOptimizerInstance(test.TestCase): def test_unsupported_name(self): @@ -66,12 +73,6 @@ class GetOptimizerInstance(test.TestCase): self.assertAlmostEqual(0.1, opt._learning_rate) def test_object(self): - class _TestOptimizer(optimizer_lib.Optimizer): - - def __init__(self): - super(_TestOptimizer, self).__init__( - use_locking=False, name='TestOptimizer') - opt = optimizers.get_optimizer_instance(_TestOptimizer()) self.assertIsInstance(opt, _TestOptimizer) @@ -80,6 +81,23 @@ class GetOptimizerInstance(test.TestCase): ValueError, 'The given object is not an Optimizer instance'): optimizers.get_optimizer_instance((1, 2, 3)) + def test_callable(self): + def _optimizer_fn(): + return _TestOptimizer() + opt = optimizers.get_optimizer_instance(_optimizer_fn) + self.assertIsInstance(opt, _TestOptimizer) + + def test_lambda(self): + opt = optimizers.get_optimizer_instance(lambda: _TestOptimizer()) # pylint: disable=unnecessary-lambda + self.assertIsInstance(opt, _TestOptimizer) + + def test_callable_returns_invalid(self): + def _optimizer_fn(): + return (1, 2, 3) + with self.assertRaisesRegexp( + ValueError, 'The given object is not an Optimizer instance'): + optimizers.get_optimizer_instance(_optimizer_fn) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 41c25f1c73295aceb542887df4fb9164c17e770a..ee3c3bba7bb124fcfe3b3507bde212a7f21098da 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -29,8 +29,6 @@ import six from google.protobuf import message from tensorflow.core.framework import summary_pb2 -from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session as tf_session from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib @@ -38,6 +36,7 @@ from tensorflow.python.estimator import run_config from tensorflow.python.estimator import util as estimator_util from tensorflow.python.estimator.export import export as export_helpers from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -51,9 +50,11 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import builder as saved_model_builder -from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import utils_impl as saved_model_utils from tensorflow.python.summary import summary from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import device_setter from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import evaluation @@ -85,14 +86,15 @@ class Estimator(object): subdirectory thereof. If `model_dir` is not set, a temporary directory is used. - The `config` argument can be passed `RunConfig` object containing information - about the execution environment. It is passed on to the `model_fn`, if the - `model_fn` has a parameter named "config" (and input functions in the same - manner). If the `config` parameter is not passed, it is instantiated by the - `Estimator`. Not passing config means that defaults useful for local execution - are used. `Estimator` makes config available to the model (for instance, to - allow specialization based on the number of workers available), and also uses - some of its fields to control internals, especially regarding checkpointing. + The `config` argument can be passed `tf.estimator.RunConfig` object containing + information about the execution environment. It is passed on to the + `model_fn`, if the `model_fn` has a parameter named "config" (and input + functions in the same manner). If the `config` parameter is not passed, it is + instantiated by the `Estimator`. Not passing config means that defaults useful + for local execution are used. `Estimator` makes config available to the model + (for instance, to allow specialization based on the number of workers + available), and also uses some of its fields to control internals, especially + regarding checkpointing. The `params` argument contains hyperparameters. It is passed to the `model_fn`, if the `model_fn` has a parameter named "params", and to the input @@ -103,6 +105,15 @@ class Estimator(object): None of `Estimator`'s methods can be overridden in subclasses (its constructor enforces this). Subclasses should use `model_fn` to configure the base class, and may add methods implementing specialized functionality. + + @compatibility(eager) + Calling methods of `Estimator` will work while eager execution is enabled. + However, the `model_fn` and `input_fn` is not executed eagerly, `Estimator` + will switch to graph model before calling all user-provided functions (incl. + hooks), so their code has to be compatible with graph mode execution. Note + that `input_fn` code using `tf.data` generally works in both graph and eager + modes. + @end_compatibility """ def __init__(self, model_fn, model_dir=None, config=None, params=None, @@ -119,7 +130,7 @@ class Estimator(object): ``` For more details on warm-start configuration, see - @{tf.estimator.WarmStartSettings$WarmStartSettings}. + `tf.estimator.WarmStartSettings`. Args: model_fn: Model function. Follows the signature: @@ -128,15 +139,16 @@ class Estimator(object): * `features`: This is the first item returned from the `input_fn` passed to `train`, `evaluate`, and `predict`. This should be a - single `Tensor` or `dict` of same. + single `tf.Tensor` or `dict` of same. * `labels`: This is the second item returned from the `input_fn` passed to `train`, `evaluate`, and `predict`. This should be a - single `Tensor` or `dict` of same (for multi-head models). If - mode is `ModeKeys.PREDICT`, `labels=None` will be passed. If - the `model_fn`'s signature does not accept `mode`, the - `model_fn` must still be able to handle `labels=None`. + single `tf.Tensor` or `dict` of same (for multi-head models). + If mode is @{tf.estimator.ModeKeys.PREDICT}, `labels=None` will + be passed. If the `model_fn`'s signature does not accept + `mode`, the `model_fn` must still be able to handle + `labels=None`. * `mode`: Optional. Specifies if this training, evaluation or - prediction. See `ModeKeys`. + prediction. See `tf.estimator.ModeKeys`. * `params`: Optional `dict` of hyperparameters. Will receive what is passed to Estimator in `params` parameter. This allows to configure Estimators from hyper parameter tuning. @@ -146,10 +158,10 @@ class Estimator(object): configuration such as `num_ps_replicas`, or `model_dir`. * Returns: - `EstimatorSpec` + `tf.estimator.EstimatorSpec` model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator to + also be used to load checkpoints from the directory into an estimator to continue training a previously saved model. If `PathLike` object, the path will be resolved. If `None`, the model_dir in `config` will be used if set. If both are set, they must be same. If both are `None`, a @@ -160,9 +172,10 @@ class Estimator(object): warm_start_from: Optional string filepath to a checkpoint or SavedModel to warm-start from, or a `tf.estimator.WarmStartSettings` object to fully configure warm-starting. If the string - filepath is provided instead of a `WarmStartSettings`, - then all variables are warm-started, and it is assumed - that vocabularies and Tensor names are unchanged. + filepath is provided instead of a + `tf.estimator.WarmStartSettings`, then all variables are + warm-started, and it is assumed that vocabularies + and `tf.Tensor` names are unchanged. Raises: ValueError: parameters of `model_fn` don't match `params`. @@ -171,49 +184,17 @@ class Estimator(object): """ Estimator._assert_members_are_not_overridden(self) - if config is None: - self._config = run_config.RunConfig() - logging.info('Using default config.') - else: - if not isinstance(config, run_config.RunConfig): - raise ValueError( - 'config must be an instance of RunConfig, but provided %s.' % - config) - self._config = config + config = maybe_overwrite_model_dir_and_session_config(config, model_dir) + self._config = config # The distribute field contains an instance of DistributionStrategy. - self._distribution = self._config.train_distribute - + self._train_distribution = self._config.train_distribute + self._eval_distribution = self._config.eval_distribute # Model directory. - model_dir = compat_internal.path_to_str(model_dir) - if (model_dir is not None) and (self._config.model_dir is not None): - if model_dir != self._config.model_dir: - # TODO(alanyee): remove this suppression after it is no longer needed - # pylint: disable=g-doc-exception - raise ValueError( - "model_dir are set both in constructor and RunConfig, but with " - "different values. In constructor: '{}', in RunConfig: " - "'{}' ".format(model_dir, self._config.model_dir)) - # pylint: enable=g-doc-exception - - self._model_dir = model_dir or self._config.model_dir - if self._model_dir is None: - self._model_dir = tempfile.mkdtemp() - logging.warning('Using temporary folder as model directory: %s', - self._model_dir) - if self._config.model_dir is None: - self._config = self._config.replace(model_dir=self._model_dir) + self._model_dir = self._config.model_dir + self._session_config = self._config.session_config logging.info('Using config: %s', str(vars(self._config))) - if self._config.session_config is None: - rewrite_opts = rewriter_config_pb2.RewriterConfig( - meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE) - graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts) - self._session_config = config_pb2.ConfigProto( - allow_soft_placement=True, graph_options=graph_opts) - else: - self._session_config = self._config.session_config - self._device_fn = ( self._config.device_fn or _get_replica_device_setter(self._config)) @@ -242,10 +223,10 @@ class Estimator(object): @property def model_fn(self): - """Returns the model_fn which is bound to self.params. + """Returns the `model_fn` which is bound to `self.params`. Returns: - The model_fn with following signature: + The `model_fn` with following signature: `def model_fn(features, labels, mode, config)` """ @@ -265,7 +246,7 @@ class Estimator(object): Numpy array - value of the tensor. Raises: - ValueError: If the Estimator has not produced a checkpoint yet. + ValueError: If the `Estimator` has not produced a checkpoint yet. """ _check_checkpoint_available(self.model_dir) with context.graph_mode(): @@ -278,21 +259,21 @@ class Estimator(object): List of names. Raises: - ValueError: If the Estimator has not produced a checkpoint yet. + ValueError: If the `Estimator` has not produced a checkpoint yet. """ _check_checkpoint_available(self.model_dir) with context.graph_mode(): return [name for name, _ in training.list_variables(self.model_dir)] def latest_checkpoint(self): - """Finds the filename of latest saved checkpoint file in `model_dir`. + """Finds the filename of the latest saved checkpoint file in `model_dir`. Returns: The full path to the latest checkpoint or `None` if no checkpoint was found. """ with context.graph_mode(): - return saver.latest_checkpoint(self.model_dir) + return checkpoint_management.latest_checkpoint(self.model_dir) def train(self, input_fn, @@ -300,40 +281,36 @@ class Estimator(object): steps=None, max_steps=None, saving_listeners=None): - """Trains a model given training data input_fn. + """Trains a model given training data `input_fn`. Args: input_fn: A function that provides input data for training as minibatches. - See @{$premade_estimators#create_input_functions} for more - information. The function should construct and return one of - the following: - - * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a - tuple (features, labels) with same constraints as below. - * A tuple (features, labels): Where `features` is a `Tensor` or a - dictionary of string feature name to `Tensor` and `labels` is a - `Tensor` or a dictionary of string label name to `Tensor`. Both - `features` and `labels` are consumed by `model_fn`. They should - satisfy the expectation of `model_fn` from inputs. - - hooks: List of `SessionRunHook` subclass instances. Used for callbacks - inside the training loop. - steps: Number of steps for which to train model. If `None`, train forever - or train until input_fn generates the `OutOfRange` error or - `StopIteration` exception. 'steps' works incrementally. If you call two - times train(steps=10) then training occurs in total 20 steps. If - `OutOfRange` or `StopIteration` occurs in the middle, training stops + See @{$premade_estimators#create_input_functions} for more information. + The function should construct and return one of the following: * A + `tf.data.Dataset` object: Outputs of `Dataset` object must be a tuple + `(features, labels)` with same constraints as below. * A tuple + `(features, labels)`: Where `features` is a `tf.Tensor` or a dictionary + of string feature name to `Tensor` and `labels` is a `Tensor` or a + dictionary of string label name to `Tensor`. Both `features` and + `labels` are consumed by `model_fn`. They should satisfy the expectation + of `model_fn` from inputs. + hooks: List of `tf.train.SessionRunHook` subclass instances. Used for + callbacks inside the training loop. + steps: Number of steps for which to train the model. If `None`, train + forever or train until `input_fn` generates the `tf.errors.OutOfRange` + error or `StopIteration` exception. `steps` works incrementally. If you + call two times `train(steps=10)` then training occurs in total 20 steps. + If `OutOfRange` or `StopIteration` occurs in the middle, training stops before 20 steps. If you don't want to have incremental behavior please set `max_steps` instead. If set, `max_steps` must be `None`. max_steps: Number of total steps for which to train model. If `None`, - train forever or train until input_fn generates the `OutOfRange` error - or `StopIteration` exception. If set, `steps` must be `None`. If - `OutOfRange` or `StopIteration` occurs in the middle, training stops - before `max_steps` steps. - Two calls to `train(steps=100)` means 200 training - iterations. On the other hand, two calls to `train(max_steps=100)` means - that the second call will not do any iteration since first call did - all 100 steps. + train forever or train until `input_fn` generates the + `tf.errors.OutOfRange` error or `StopIteration` exception. If set, + `steps` must be `None`. If `OutOfRange` or `StopIteration` occurs in the + middle, training stops before `max_steps` steps. Two calls to + `train(steps=100)` means 200 training iterations. On the other hand, two + calls to `train(max_steps=100)` means that the second call will not do + any iteration since first call did all 100 steps. saving_listeners: list of `CheckpointSaverListener` objects. Used for callbacks that run immediately before or after checkpoint savings. @@ -342,7 +319,7 @@ class Estimator(object): Raises: ValueError: If both `steps` and `max_steps` are not `None`. - ValueError: If either `steps` or `max_steps` is <= 0. + ValueError: If either `steps` or `max_steps <= 0`. """ with context.graph_mode(): if (steps is not None) and (max_steps is not None): @@ -368,13 +345,29 @@ class Estimator(object): return self def _convert_train_steps_to_hooks(self, steps, max_steps): + """Create hooks to run correct number of steps in training. + + Args: + steps: number of steps to run during training. + max_steps: maximum number of steps to be run during training. It'll be + the maximum number of steps the model will train to after restoring + from checkpoint even across multiple estimator.train calls. + + Returns: + List of hooks to be passed to the estimator. + """ if steps is not None or max_steps is not None: + if self._train_distribution: + steps_per_run = getattr(self._train_distribution, 'steps_per_run', 1) + if steps_per_run > 1: + return [basic_session_run_hooks._MultiStepStopAtStepHook( # pylint: disable=protected-access + steps, max_steps, steps_per_run)] return [training.StopAtStepHook(steps, max_steps)] else: return [] def eval_dir(self, name=None): - """Shows directory name where evaluation metrics are dumped. + """Shows the directory name where evaluation metrics are dumped. Args: name: Name of the evaluation if user needs to run multiple evaluations on @@ -390,36 +383,34 @@ class Estimator(object): def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, name=None): - """Evaluates the model given evaluation data input_fn. + """Evaluates the model given evaluation data `input_fn`. For each step, calls `input_fn`, which returns one batch of data. Evaluates until: - `steps` batches are processed, or - - `input_fn` raises an end-of-input exception (`OutOfRangeError` or + - `input_fn` raises an end-of-input exception (`tf.errors.OutOfRangeError` + or `StopIteration`). Args: - input_fn: A function that constructs the input data for evaluation. - See @{$premade_estimators#create_input_functions} for more - information. The function should construct and return one of - the following: - - * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a - tuple (features, labels) with same constraints as below. - * A tuple (features, labels): Where `features` is a `Tensor` or a - dictionary of string feature name to `Tensor` and `labels` is a - `Tensor` or a dictionary of string label name to `Tensor`. Both - `features` and `labels` are consumed by `model_fn`. They should - satisfy the expectation of `model_fn` from inputs. - + input_fn: A function that constructs the input data for evaluation. See + @{$premade_estimators#create_input_functions} for more information. The + function should construct and return one of the following: * A + `tf.data.Dataset` object: Outputs of `Dataset` object must be a tuple + `(features, labels)` with same constraints as below. * A tuple + `(features, labels)`: Where `features` is a `tf.Tensor` or a dictionary + of string feature name to `Tensor` and `labels` is a `Tensor` or a + dictionary of string label name to `Tensor`. Both `features` and + `labels` are consumed by `model_fn`. They should satisfy the expectation + of `model_fn` from inputs. steps: Number of steps for which to evaluate model. If `None`, evaluates until `input_fn` raises an end-of-input exception. - hooks: List of `SessionRunHook` subclass instances. Used for callbacks - inside the evaluation call. + hooks: List of `tf.train.SessionRunHook` subclass instances. Used for + callbacks inside the evaluation call. checkpoint_path: Path of a specific checkpoint to evaluate. If `None`, the latest checkpoint in `model_dir` is used. If there are no checkpoints in `model_dir`, evaluation is run with newly initialized `Variables` - instead of restored from checkpoint. + instead of ones restored from checkpoint. name: Name of the evaluation if user needs to run multiple evaluations on different data sets, such as on training data vs test data. Metrics for different evaluations are saved in separate folders, and appear @@ -441,16 +432,15 @@ class Estimator(object): # Check that model has been trained (if nothing has been set explicitly). if not checkpoint_path: - latest_path = saver.latest_checkpoint(self._model_dir) + latest_path = checkpoint_management.latest_checkpoint(self._model_dir) if not latest_path: logging.info('Could not find trained model in model_dir: {}, running ' 'initialization to evaluate.'.format(self._model_dir)) checkpoint_path = latest_path - with ops.Graph().as_default(): - (scaffold, update_op, - eval_dict, all_hooks) = self._evaluate_build_graph( - input_fn, hooks, checkpoint_path) + def _evaluate(): + (scaffold, update_op, eval_dict, all_hooks) = ( + self._evaluate_build_graph(input_fn, hooks, checkpoint_path)) return self._evaluate_run( checkpoint_path=checkpoint_path, scaffold=scaffold, @@ -459,6 +449,15 @@ class Estimator(object): all_hooks=all_hooks, output_dir=self.eval_dir(name)) + with ops.Graph().as_default(): + # TODO(priyag): Support distributed eval on TPUs. + if (self._eval_distribution + and self._eval_distribution.__class__.__name__ != 'TPUStrategy'): + with self._eval_distribution.scope(): + return _evaluate() + else: + return _evaluate() + def _convert_eval_steps_to_hooks(self, steps): if steps is None: return [] @@ -477,33 +476,33 @@ class Estimator(object): Args: input_fn: A function that constructs the features. Prediction continues - until `input_fn` raises an end-of-input exception (`OutOfRangeError` or - `StopIteration`). + until `input_fn` raises an end-of-input exception + (`tf.errors.OutOfRangeError` or `StopIteration`). See @{$premade_estimators#create_input_functions} for more information. The function should construct and return one of the following: - * A 'tf.data.Dataset' object: Outputs of `Dataset` object must have + * A `tf.data.Dataset` object: Outputs of `Dataset` object must have same constraints as below. - * features: A `Tensor` or a dictionary of string feature name to + * features: A `tf.Tensor` or a dictionary of string feature name to `Tensor`. features are consumed by `model_fn`. They should satisfy the expectation of `model_fn` from inputs. * A tuple, in which case the first item is extracted as features. predict_keys: list of `str`, name of the keys to predict. It is used if - the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used - then rest of the predictions will be filtered from the dictionary. If - `None`, returns all. - hooks: List of `SessionRunHook` subclass instances. Used for callbacks - inside the prediction call. + the `tf.estimator.EstimatorSpec.predictions` is a `dict`. If + `predict_keys` is used then rest of the predictions will be filtered + from the dictionary. If `None`, returns all. + hooks: List of `tf.train.SessionRunHook` subclass instances. Used for + callbacks inside the prediction call. checkpoint_path: Path of a specific checkpoint to predict. If `None`, the latest checkpoint in `model_dir` is used. If there are no checkpoints in `model_dir`, prediction is run with newly initialized `Variables` - instead of restored from checkpoint. - yield_single_examples: If False, yield the whole batch as returned by the - `model_fn` instead of decomposing the batch into individual elements. - This is useful if `model_fn` returns some tensors whose first dimension - is not equal to the batch size. + instead of ones restored from checkpoint. + yield_single_examples: If `False`, yields the whole batch as returned by + the `model_fn` instead of decomposing the batch into individual + elements. This is useful if `model_fn` returns some tensors whose first + dimension is not equal to the batch size. Yields: Evaluated values of `predictions` tensors. @@ -511,16 +510,17 @@ class Estimator(object): Raises: ValueError: Could not find a trained model in `model_dir`. ValueError: If batch length of predictions is not the same and - `yield_single_examples` is True. + `yield_single_examples` is `True`. ValueError: If there is a conflict between `predict_keys` and `predictions`. For example if `predict_keys` is not `None` but - `EstimatorSpec.predictions` is not a `dict`. + `tf.estimator.EstimatorSpec.predictions` is not a `dict`. """ with context.graph_mode(): hooks = _check_hooks_type(hooks) # Check that model has been trained. if not checkpoint_path: - checkpoint_path = saver.latest_checkpoint(self._model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint( + self._model_dir) if not checkpoint_path: logging.info('Could not find trained model in model_dir: {}, running ' 'initialization to predict.'.format(self._model_dir)) @@ -563,12 +563,19 @@ class Estimator(object): def _assert_members_are_not_overridden(self): """Asserts members of `Estimator` are not overridden.""" + # TPUEstimator is special cased (owned by TF). + if self.__class__.__name__ == 'TPUEstimator': + return + allowed_overrides = set([ - '_call_input_fn', '_create_global_step', + '_call_input_fn', '_call_model_fn', '_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks', - '_tf_api_names', '_estimator_api_names', '_estimator_api_constants', + '_create_global_step', '_create_and_assert_global_step', + '_tf_api_names', '_tf_api_names_v1', '_estimator_api_names', + '_estimator_api_names_v1', '_estimator_api_constants', + '_estimator_api_constants_v1', '_validate_features_in_predict_input', - '_call_model_fn', '_add_meta_graph_for_mode' + '_add_meta_graph_for_mode' ]) estimator_members = set([m for m in Estimator.__dict__.keys() if not m.startswith('__')]) @@ -589,30 +596,34 @@ class Estimator(object): checkpoint_path=None, strip_default_attrs=False): # pylint: disable=line-too-long - """Exports inference graph as a SavedModel into given dir. + """Exports inference graph as a `SavedModel` into the given dir. For a detailed guide, see - @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}. + @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with + Estimators}. This method builds a new graph by first calling the - serving_input_receiver_fn to obtain feature `Tensor`s, and then calling - this `Estimator`'s model_fn to generate the model graph based on those + `serving_input_receiver_fn` to obtain feature `Tensor`s, and then calling + this `Estimator`'s `model_fn` to generate the model graph based on those features. It restores the given checkpoint (or, lacking that, the most recent checkpoint) into this graph in a fresh session. Finally it creates - a timestamped export directory below the given export_dir_base, and writes - a `SavedModel` into it containing a single `MetaGraphDef` saved from this + a timestamped export directory below the given `export_dir_base`, and writes + a `SavedModel` into it containing a single `tf.MetaGraphDef` saved from this session. The exported `MetaGraphDef` will provide one `SignatureDef` for each - element of the export_outputs dict returned from the model_fn, named using + element of the `export_outputs` dict returned from the `model_fn`, named + using the same keys. One of these keys is always - signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which + `tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`, + indicating which signature will be served when a serving request does not specify one. For each signature, the outputs are provided by the corresponding - `ExportOutput`s, and the inputs are always the input receivers provided by - the serving_input_receiver_fn. + `tf.estimator.export.ExportOutput`s, and the inputs are always the input + receivers provided by + the `serving_input_receiver_fn`. - Extra assets may be written into the SavedModel via the assets_extra + Extra assets may be written into the `SavedModel` via the `assets_extra` argument. This should be a dict, where each key gives a destination path (including the filename) relative to the assets.extra directory. The corresponding value gives the full path of the source file to be copied. @@ -621,23 +632,27 @@ class Estimator(object): Args: export_dir_base: A string containing a directory in which to create - timestamped subdirectories containing exported SavedModels. - serving_input_receiver_fn: A function that takes no argument and - returns a `ServingInputReceiver` or `TensorServingInputReceiver`. + timestamped subdirectories containing exported `SavedModel`s. + serving_input_receiver_fn: A function that takes no argument and returns a + `tf.estimator.export.ServingInputReceiver` or + `tf.estimator.export.TensorServingInputReceiver`. assets_extra: A dict specifying how to populate the assets.extra directory - within the exported SavedModel, or `None` if no extra assets are needed. - as_text: whether to write the SavedModel proto in text format. + within the exported `SavedModel`, or `None` if no extra assets are + needed. + as_text: whether to write the `SavedModel` proto in text format. checkpoint_path: The checkpoint path to export. If `None` (the default), the most recent checkpoint found within the model directory is chosen. strip_default_attrs: Boolean. If `True`, default-valued attributes will be - removed from the NodeDefs. For a detailed guide, see - [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + removed from the `NodeDef`s. For a detailed guide, see [Stripping + Default-Valued + Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). Returns: The string path to the exported directory. Raises: - ValueError: if no serving_input_receiver_fn is provided, no export_outputs + ValueError: if no `serving_input_receiver_fn` is provided, no + `export_outputs` are provided, or no checkpoint can be found. """ # pylint: enable=line-too-long @@ -658,35 +673,37 @@ class Estimator(object): strip_default_attrs=False, mode=model_fn_lib.ModeKeys.PREDICT): # pylint: disable=line-too-long - """Exports a single train/eval/predict graph as a SavedModel. + """Exports a single train/eval/predict graph as a `SavedModel`. - This method is a wrapper for _export_all_saved_models, and wraps a raw - input_receiver_fn in a dictionary to pass in to that function. - See _export_all_saved_models for full docs. + This method is a wrapper for `_export_all_saved_models`, and wraps a raw + `input_receiver_fn` in a dictionary to pass in to that function. + See `_export_all_saved_models` for full docs. - See tf.contrib.estimator.export_saved_model_for_mode for the currently + See `tf.contrib.estimator.export_saved_model_for_mode` for the currently exposed version of this function. Args: export_dir_base: A string containing a directory in which to create - timestamped subdirectories containing exported SavedModels. - input_receiver_fn: a function that takes no argument and - returns the appropriate subclass of `InputReceiver`. + timestamped subdirectories containing exported `SavedModel`s. + input_receiver_fn: a function that takes no argument and returns the + appropriate subclass of `InputReceiver`. assets_extra: A dict specifying how to populate the assets.extra directory - within the exported SavedModel, or `None` if no extra assets are needed. - as_text: whether to write the SavedModel proto in text format. + within the exported `SavedModel`, or `None` if no extra assets are + needed. + as_text: whether to write the `SavedModel` proto in text format. checkpoint_path: The checkpoint path to export. If `None` (the default), the most recent checkpoint found within the model directory is chosen. strip_default_attrs: Boolean. If `True`, default-valued attributes will be - removed from the NodeDefs. For a detailed guide, see - [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). - mode: tf.estimator.ModeKeys value indicating with mode will be exported. + removed from the `NodeDef`s. For a detailed guide, see [Stripping + Default-Valued + Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + mode: `tf.estimator.ModeKeys` value indicating with mode will be exported. Returns: The string path to the exported directory. Raises: - ValueError: if input_receiver_fn is None, no export_outputs + ValueError: if `input_receiver_fn` is `None`, no `export_outputs` are provided, or no checkpoint can be found. """ # pylint: enable=line-too-long @@ -710,40 +727,46 @@ class Estimator(object): checkpoint_path=None, strip_default_attrs=False): # pylint: disable=line-too-long - """Exports a SavedModel containing MetaGraphDefs for each requested mode. + """Exports a `SavedModel` containing `tf.MetaGraphDefs` for each requested mode. - See tf.contrib.estimator.export_all_saved_models for the currently + See `tf.contrib.estimator.export_all_saved_models` for the currently exposed version of this function. - For each mode passed in via the input_receiver_fn_map, - this method builds a new graph by calling the input_receiver_fn to obtain + For each mode passed in via the `input_receiver_fn_map`, + this method builds a new graph by calling the `input_receiver_fn` to obtain feature and label `Tensor`s. Next, this method calls the `Estimator`'s - model_fn in the passed mode to generate the model graph based on + `model_fn` in the passed mode to generate the model graph based on those features and labels, and restores the given checkpoint (or, lacking that, the most recent checkpoint) into the graph. - Only one of the modes is used for saving variables to the SavedModel - (order of preference: TRAIN, EVAL, then PREDICT), such that up to three - MetaGraphDefs are saved with a single set of variables in a single - SavedModel directory. - - For the variables and MetaGraphDefs, a timestamped export directory below - export_dir_base, and writes a `SavedModel` into it containing - the `MetaGraphDef` for the given mode and its associated signatures. + Only one of the modes is used for saving variables to the `SavedModel` + (order of preference: @{tf.estimator.ModeKeys#TRAIN$TRAIN}, + @{tf.estimator.ModeKeys#EVAL$EVAL}, then + @{tf.estimator.ModeKeys#PREDICT$PREDICT}), such that up to three + `tf.MetaGraphDefs` are saved with a single set of variables in a single + `SavedModel` directory. + + For the variables and `tf.MetaGraphDefs`, a timestamped export directory + below + `export_dir_base`, and writes a `SavedModel` into it containing + the `tf.MetaGraphDef` for the given mode and its associated signatures. For prediction, the exported `MetaGraphDef` will provide one `SignatureDef` - for each element of the export_outputs dict returned from the model_fn, + for each element of the `export_outputs` dict returned from the `model_fn`, named using the same keys. One of these keys is always - signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which + `tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`, + indicating which signature will be served when a serving request does not specify one. For each signature, the outputs are provided by the corresponding - `ExportOutput`s, and the inputs are always the input receivers provided by - the serving_input_receiver_fn. + `tf.estimator.export.ExportOutput`s, and the inputs are always the input + receivers provided by + the `serving_input_receiver_fn`. - For training and evaluation, the train_op is stored in an extra collection, - and loss, metrics, and predictions are included in a SignatureDef for the + For training and evaluation, the `train_op` is stored in an extra + collection, + and loss, metrics, and predictions are included in a `SignatureDef` for the mode in question. - Extra assets may be written into the SavedModel via the assets_extra + Extra assets may be written into the `SavedModel` via the `assets_extra` argument. This should be a dict, where each key gives a destination path (including the filename) relative to the assets.extra directory. The corresponding value gives the full path of the source file to be copied. @@ -752,25 +775,28 @@ class Estimator(object): Args: export_dir_base: A string containing a directory in which to create - timestamped subdirectories containing exported SavedModels. - input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn - mappings, where the input_receiver_fn is a function that takes no - argument and returns the appropriate subclass of `InputReceiver`. + timestamped subdirectories containing exported `SavedModel`s. + input_receiver_fn_map: dict of `tf.estimator.ModeKeys` to + `input_receiver_fn` mappings, where the `input_receiver_fn` is a + function that takes no arguments and returns the appropriate subclass of + `InputReceiver`. assets_extra: A dict specifying how to populate the assets.extra directory - within the exported SavedModel, or `None` if no extra assets are needed. - as_text: whether to write the SavedModel proto in text format. + within the exported `SavedModel`, or `None` if no extra assets are + needed. + as_text: whether to write the `SavedModel` proto in text format. checkpoint_path: The checkpoint path to export. If `None` (the default), the most recent checkpoint found within the model directory is chosen. strip_default_attrs: Boolean. If `True`, default-valued attributes will be - removed from the NodeDefs. For a detailed guide, see - [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + removed from the `NodeDef`s. For a detailed guide, see [Stripping + Default-Valued + Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). Returns: - A dict of tf.estimator.ModeKeys value to string path for each exported + A dict of `tf.estimator.ModeKeys` value to string path for each exported directory. Raises: - ValueError: if any input_receiver_fn is None, no export_outputs + ValueError: if any `input_receiver_fn` is `None`, no `export_outputs` are provided, or no checkpoint can be found. """ # pylint: enable=line-too-long @@ -778,7 +804,8 @@ class Estimator(object): with context.graph_mode(): if not checkpoint_path: # Locate the latest checkpoint - checkpoint_path = saver.latest_checkpoint(self._model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint( + self._model_dir) if not checkpoint_path: raise ValueError("Couldn't find trained model at %s." % self._model_dir) @@ -839,27 +866,36 @@ class Estimator(object): strip_default_attrs, save_variables=True, mode=model_fn_lib.ModeKeys.PREDICT, - export_tags=None): + export_tags=None, + check_variables=True): # pylint: disable=line-too-long - """Loads variables and adds them along with a MetaGraphDef for saving. + """Loads variables and adds them along with a `tf.MetaGraphDef` for saving. Args: - builder: instance of SavedModelBuilder that will be used for saving. - input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn - mappings, where the input_receiver_fn is a function that takes no - argument and returns the appropriate subclass of `InputReceiver`. + builder: instance of `tf.saved_modle.builder.SavedModelBuilder` that will + be used for saving. + input_receiver_fn_map: dict of `tf.estimator.ModeKeys` to + `input_receiver_fn` mappings, where the `input_receiver_fn` is a + function that takes no argument and returns the appropriate subclass of + `InputReceiver`. checkpoint_path: The checkpoint path to export. If `None` (the default), the most recent checkpoint found within the model directory is chosen. strip_default_attrs: Boolean. If `True`, default-valued attributes will be - removed from the NodeDefs. For a detailed guide, see - [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). - save_variables: bool, whether variables should be saved. If False, just - the MetaGraphDef will be saved. Note that save_variables should only be - True for the first call to this function, and the SavedModelBuilder will - raise an error if that is not the case. - mode: tf.estimator.ModeKeys value indicating which mode will be exported. - export_tags: The set of tags with which to save `MetaGraphDef`. If None, - a default set will be selected to matched the passed mode. + removed from the `NodeDef`s. For a detailed guide, see [Stripping + Default-Valued + Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + save_variables: bool, whether variables should be saved. If `False`, just + the `tf.MetaGraphDef` will be saved. Note that `save_variables` should + only be `True` for the first call to this function, and the + `SavedModelBuilder` will raise an error if that is not the case. + mode: `tf.estimator.ModeKeys` value indicating which mode will be + exported. + export_tags: The set of tags with which to save `tf.MetaGraphDef`. If + `None`, a default set will be selected to matched the passed mode. + check_variables: bool, whether to check the checkpoint has all variables. + + Raises: + ValueError: if `save_variables` is `True` and `check_variable` is `False`. """ # pylint: enable=line-too-long if export_tags is None: @@ -890,9 +926,10 @@ class Estimator(object): with tf_session.Session(config=self._session_config) as session: - local_init_op = ( - estimator_spec.scaffold.local_init_op or - monitored_session.Scaffold.default_local_init_op()) + if estimator_spec.scaffold.local_init_op is not None: + local_init_op = estimator_spec.scaffold.local_init_op + else: + local_init_op = monitored_session.Scaffold.default_local_init_op() # This saver will be used both for restoring variables now, # and in saving out the metagraph below. This ensures that any @@ -900,16 +937,20 @@ class Estimator(object): # SavedModel for restore later. graph_saver = estimator_spec.scaffold.saver or saver.Saver(sharded=True) - try: - graph_saver.restore(session, checkpoint_path) - except errors.NotFoundError as e: - msg = ('Could not load all requested variables from the checkpoint. ' - 'Please make sure your model_fn does not expect variables ' - 'that were not saved in the checkpoint.\n\n' - 'Encountered error with mode `{}` while restoring checkpoint ' - 'from: `{}`. Full Traceback:\n\n{}').format( - mode, checkpoint_path, e) - raise ValueError(msg) + if save_variables and not check_variables: + raise ValueError('If `save_variables` is `True, `check_variables`' + 'must not be `False`.') + if check_variables: + try: + graph_saver.restore(session, checkpoint_path) + except errors.NotFoundError as e: + msg = ('Could not load all requested variables from checkpoint. ' + 'Please make sure your model_fn does not expect variables ' + 'that were not saved in the checkpoint.\n\n' + 'Encountered error with mode `{}` while restoring ' + 'checkpoint from: `{}`. Full Traceback:\n\n{}').format( + mode, checkpoint_path, e) + raise ValueError(msg) # We add the train op explicitly for now, so that we don't have to # change the Builder public interface. Note that this is a no-op @@ -932,21 +973,23 @@ class Estimator(object): builder.add_meta_graph(**meta_graph_kwargs) def _get_export_outputs_for_spec(self, estimator_spec): - """Given an EstimatorSpec, determine what our export outputs should be. + """Given an `EstimatorSpec`, determine what our export outputs should be. - EstimatorSpecs contain export_outputs that are used for serving, but for + `EstimatorSpecs` contains `export_outputs` that are used for serving, but + for training and eval graphs, we must wrap the tensors of interest in - appropriate ExportOutput objects. + appropriate `tf.estimator.export.ExportOutput` objects. Args: - estimator_spec: EstimatorSpec object that will be exported. + estimator_spec: `tf.estimator.EstimatorSpec` object that will be exported. Returns: - a dict mapping export_output_name to ExportOutput object. + a dict mapping `export_output_name` to `tf.estimator.export.ExportOutput` + object. Raises: - ValueError: if an appropriate ExportOutput cannot be found for the - passed EstimatorSpec.mode + ValueError: if an appropriate `ExportOutput` cannot be found for the + passed `EstimatorSpec.mode` """ mode = estimator_spec.mode if mode == model_fn_lib.ModeKeys.PREDICT: @@ -981,10 +1024,11 @@ class Estimator(object): 'QueueRunner. That means predict yields forever. ' 'This is probably a mistake.') - def _get_features_and_labels_from_input_fn(self, input_fn, mode): + def _get_features_and_labels_from_input_fn(self, input_fn, mode, + distribution=None): """Extracts the `features` and labels from return values of `input_fn`.""" - if self._distribution is not None and mode == model_fn_lib.ModeKeys.TRAIN: - result = self._distribution.distribute_dataset( + if distribution is not None: + result = distribution.distribute_dataset( lambda: self._call_input_fn(input_fn, mode)) else: result = self._call_input_fn(input_fn, mode) @@ -1022,13 +1066,13 @@ class Estimator(object): """Creates the global step tensor in graph. The global step tensor must be an integer type with name 'global_step' and - be added to the collection @{tf.GraphKeys.GLOBAL_STEP}. + be added to the collection @{tf.GraphKeys#GLOBAL_STEP$GLOBAL_STEP}. Args: graph: The graph in which to create the global step tensor. Returns: - The global step `Tensor`. + The global step `tf.Tensor`. """ return training.create_global_step(graph) @@ -1039,7 +1083,7 @@ class Estimator(object): graph: The graph in which to create the global step tensor. Returns: - The global step `Tensor`. + The global step `tf.Tensor`. """ step = self._create_global_step(graph) assert step == training.get_global_step() @@ -1051,21 +1095,21 @@ class Estimator(object): Args: input_fn: The input function. - mode: ModeKeys + mode: `tf.estimator.ModeKeys` Returns: - The return value of the passed input_fn, which should be one of: + The return value of the passed `input_fn`, which should be one of: * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a - tuple (features, labels) with same constraints as below. - * A tuple (features, labels): Where `features` is a `Tensor` or a + tuple `(features, labels)` with same constraints as below. + * A tuple `(features, labels)`: Where `features` is a `Tensor` or a dictionary of string feature name to `Tensor` and `labels` is a `Tensor` or a dictionary of string label name to `Tensor`. Both `features` and `labels` are consumed by `model_fn`. They should satisfy the expectation of `model_fn` from inputs. Raises: - ValueError: if input_fn takes invalid arguments. + ValueError: if `input_fn` takes invalid arguments. """ input_fn_args = function_utils.fn_args(input_fn) kwargs = {} @@ -1084,14 +1128,14 @@ class Estimator(object): Args: features: features dict. labels: labels dict. - mode: ModeKeys - config: RunConfig + mode: `tf.estimator.ModeKeys` + config: `tf.estimator.RunConfig` Returns: - An `EstimatorSpec` object. + An `tf.estimator.EstimatorSpec` object. Raises: - ValueError: if model_fn returns invalid objects. + ValueError: if `model_fn` returns invalid objects. """ model_fn_args = function_utils.fn_args(self._model_fn) kwargs = {} @@ -1118,123 +1162,142 @@ class Estimator(object): return model_fn_results def _train_model(self, input_fn, hooks, saving_listeners): - if self._distribution: + if self._train_distribution: return self._train_model_distributed(input_fn, hooks, saving_listeners) else: return self._train_model_default(input_fn, hooks, saving_listeners) def _train_model_default(self, input_fn, hooks, saving_listeners): + """Initiate training with `input_fn`, without `DistributionStrategies`. + + Args: + input_fn: A function that provides input data for training as minibatches. + hooks: List of `tf.train.SessionRunHook` subclass instances. Used for + callbacks inside the training loop. + saving_listeners: list of `tf.train.CheckpointSaverListener` objects. Used + for callbacks that run immediately before or after checkpoint savings. + + Returns: + Loss from training + """ worker_hooks = [] with ops.Graph().as_default() as g, g.device(self._device_fn): random_seed.set_random_seed(self._config.tf_random_seed) global_step_tensor = self._create_and_assert_global_step(g) - training_util._get_or_create_global_step_read() # pylint: disable=protected-access + + # Skip creating a read variable if _create_and_assert_global_step + # returns None (e.g. tf.contrib.estimator.SavedModelEstimator). + if global_step_tensor is not None: + training_util._get_or_create_global_step_read(g) # pylint: disable=protected-access + features, labels, input_hooks = ( self._get_features_and_labels_from_input_fn( input_fn, model_fn_lib.ModeKeys.TRAIN)) worker_hooks.extend(input_hooks) estimator_spec = self._call_model_fn( features, labels, model_fn_lib.ModeKeys.TRAIN, self.config) + global_step_tensor = training_util.get_global_step(g) return self._train_with_estimator_spec(estimator_spec, worker_hooks, hooks, global_step_tensor, saving_listeners) def _train_model_distributed(self, input_fn, hooks, saving_listeners): - self._distribution.configure(self._session_config) - worker_hooks = [] - with ops.Graph().as_default() as g: - with self._distribution.scope(): - random_seed.set_random_seed(self._config.tf_random_seed) - features, labels, input_hooks = ( - self._get_features_and_labels_from_input_fn( - input_fn, model_fn_lib.ModeKeys.TRAIN)) - worker_hooks.extend(input_hooks) - global_step_tensor = self._create_and_assert_global_step(g) - # The default destination for the global_step_tensor fetch call is the - # CPU. - global_step_read_tensor = self._distribution.fetch(global_step_tensor) - # we want to add to the global collection in the main thread not the - # tower threads. - ops.add_to_collection(training_util.GLOBAL_STEP_READ_KEY, - global_step_read_tensor) - grouped_estimator_spec = self._distribution.call_for_each_tower( - self._call_model_fn, - features, - labels, # although this will be None it seems - model_fn_lib.ModeKeys.TRAIN, - self.config) - - # TODO(anjalisridhar): Figure out how to resolve the following scaffold - # parameters: init_feed_dict, init_fn. - scaffold_list = self._distribution.unwrap( - grouped_estimator_spec.scaffold) - init_feed_dict = [ - s.init_feed_dict - for s in scaffold_list - if s.init_feed_dict is not None - ] - if init_feed_dict: - init_feed_dict = self._distribution.group(init_feed_dict) - else: - init_feed_dict = None + """Initiate training with `input_fn`, using `DistributionStrategies`. - init_fn = [s.init_fn for s in scaffold_list if s.init_fn is not None] - if init_fn: - init_fn = self._distribution.group(init_fn) - else: - init_fn = None + Args: + input_fn: A function that provides input data for training as minibatches. + hooks: List of `tf.train.SessionRunHook` subclass instances. Used for + callbacks inside the training loop. + saving_listeners: list of `tf.train.CheckpointSaverListener` objects. Used + for callbacks that run immediately before or after checkpoint savings. - init_op = [s.init_op for s in scaffold_list if s.init_op is not None] - if init_op: - init_op = self._distribution.group(init_op) - else: - init_op = None + Returns: + Loss from training + """ + self._train_distribution.configure(self._session_config) - ready_op = self._distribution.call_for_each_tower( - create_per_tower_ready_op, grouped_estimator_spec.scaffold) - if ready_op is not None: - ready_op = self._distribution.group(ready_op) - else: - ready_op = None - - ready_for_local_init_op = self._distribution.call_for_each_tower( - create_per_tower_ready_for_local_init_op, - grouped_estimator_spec.scaffold) - if ready_for_local_init_op is not None: - ready_for_local_init_op = self._distribution.group( - ready_for_local_init_op) - else: - ready_for_local_init_op = None + # TODO(sourabhbajaj): Remove this hack once we migrate the other strategies + # to use the new API + is_tpu_strategy = ( + self._train_distribution.__class__.__name__ == 'TPUStrategy') - local_init_op = [ - s.local_init_op - for s in scaffold_list - if s.local_init_op is not None - ] - if local_init_op: - local_init_op = self._distribution.group(local_init_op) - else: - local_init_op = None + worker_hooks = [] + with ops.Graph().as_default() as g: + # We want to create the iterations variable outside the distribution scope + # as that is just stored on the host and mainly used to drive the loop + # and doesn't need to be a Mirrored/Device variable. + steps_per_run_variable = training.get_or_create_steps_per_run_variable() + with self._train_distribution.scope(): + random_seed.set_random_seed(self._config.tf_random_seed) - summary_op = [ - s.summary_op for s in scaffold_list if s.summary_op is not None - ] - if summary_op: - summary_op = self._distribution.group(summary_op) + if is_tpu_strategy: + # Create the iterator for run_on_dataset function + # TODO(sourabhbajaj): refactor this out to call a function on the + # strategy + dataset = self._train_distribution.distribute_dataset( + lambda: self._call_input_fn(input_fn, # pylint: disable=g-long-lambda + model_fn_lib.ModeKeys.TRAIN)) + iterator = dataset.make_initializable_iterator() + worker_hooks.append( + estimator_util._DatasetInitializerHook(iterator)) # pylint: disable=protected-access + + global_step_tensor = self._create_and_assert_global_step(g) + # we want to add to the global collection in the main thread not the + # tower threads. + ops.add_to_collection( + training_util.GLOBAL_STEP_READ_KEY, + self._train_distribution.read_var(global_step_tensor)) + + # Create a step_fn from the train_op of grouped_estimator_spec + def step_fn(ctx, inputs): + """A single step that is passed to run_on_dataset.""" + features, labels = inputs + estimator_spec = self._train_distribution.call_for_each_tower( + self._call_model_fn, + features, + labels, + model_fn_lib.ModeKeys.TRAIN, + self.config) + ctx.set_last_step_output( + name='loss', + output=estimator_spec.loss, + aggregation=distribute_lib.get_loss_reduction()) + ctx.set_non_tensor_output( + name='estimator_spec', output=estimator_spec) + return estimator_spec.train_op + + # Create new train_op post graph rewrites + initial_training_loss = constant_op.constant(1e7) + ctx = self._train_distribution.run_steps_on_dataset( + step_fn, iterator, iterations=steps_per_run_variable, + initial_loop_values={'loss': initial_training_loss}) + distributed_train_op = ctx.run_op + tpu_result = ctx.last_step_outputs + grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec'] else: - summary_op = None - - scaffold = monitored_session.Scaffold( - init_op=init_op, - ready_op=ready_op, - ready_for_local_init_op=ready_for_local_init_op, - local_init_op=local_init_op, - summary_op=summary_op, - init_feed_dict=init_feed_dict, - init_fn=init_fn) + features, labels, input_hooks = ( + self._get_features_and_labels_from_input_fn( + input_fn, model_fn_lib.ModeKeys.TRAIN, + self._train_distribution)) + worker_hooks.extend(input_hooks) + global_step_tensor = self._create_and_assert_global_step(g) + # we want to add to the global collection in the main thread not the + # tower threads. + ops.add_to_collection( + training_util.GLOBAL_STEP_READ_KEY, + self._train_distribution.read_var(global_step_tensor)) + grouped_estimator_spec = self._train_distribution.call_for_each_tower( + self._call_model_fn, + features, + labels, # although this will be None it seems + model_fn_lib.ModeKeys.TRAIN, + self.config) + + scaffold = _combine_distributed_scaffold( + grouped_estimator_spec.scaffold, self._train_distribution) def get_hooks_from_the_first_device(per_device_hooks): - hooks_list = self._distribution.unwrap(per_device_hooks) + hooks_list = self._train_distribution.unwrap(per_device_hooks) assert hooks_list return hooks_list[0] @@ -1243,18 +1306,30 @@ class Estimator(object): training_chief_hooks = get_hooks_from_the_first_device( grouped_estimator_spec.training_chief_hooks) + # TODO(sourabhbajaj): Merge the two code paths and clean up the code + if is_tpu_strategy: + loss = tpu_result['loss'] + worker_hooks.append( + estimator_util.StrategyInitFinalizeHook( + self._train_distribution.initialize, + self._train_distribution.finalize)) + else: + loss = self._train_distribution.unwrap( + self._train_distribution.reduce( + distribute_lib.get_loss_reduction(), + grouped_estimator_spec.loss, + destinations='/device:CPU:0'))[0] + distributed_train_op = grouped_estimator_spec.train_op + estimator_spec = model_fn_lib.EstimatorSpec( mode=grouped_estimator_spec.mode, - loss=self._distribution.unwrap( - self._distribution.reduce(distribute_lib.get_loss_reduction(), - grouped_estimator_spec.loss, - destinations='/device:CPU:0'))[0], - train_op=self._distribution.group(grouped_estimator_spec.train_op), + loss=loss, + train_op=self._train_distribution.group(distributed_train_op), training_hooks=training_hooks, training_chief_hooks=training_chief_hooks, scaffold=scaffold) return self._train_with_estimator_spec(estimator_spec, worker_hooks, - hooks, global_step_read_tensor, + hooks, global_step_tensor, saving_listeners) def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks, @@ -1344,27 +1419,31 @@ class Estimator(object): def _evaluate_build_graph(self, input_fn, hooks=None, checkpoint_path=None): """Builds the graph and related hooks to run evaluation.""" random_seed.set_random_seed(self._config.tf_random_seed) - global_step_tensor = self._create_and_assert_global_step( - ops.get_default_graph()) + self._create_and_assert_global_step(ops.get_default_graph()) features, labels, input_hooks = ( - self._get_features_and_labels_from_input_fn(input_fn, - model_fn_lib.ModeKeys.EVAL)) - estimator_spec = self._call_model_fn( - features, labels, model_fn_lib.ModeKeys.EVAL, self.config) + self._get_features_and_labels_from_input_fn( + input_fn, model_fn_lib.ModeKeys.EVAL, self._eval_distribution)) + + if self._eval_distribution: + (loss_metric, scaffold, evaluation_hooks, eval_metric_ops) = ( + self._call_model_fn_eval_distributed(features, labels, self.config)) + else: + (loss_metric, scaffold, evaluation_hooks, eval_metric_ops) = ( + self._call_model_fn_eval(features, labels, self.config)) + global_step_tensor = training_util.get_global_step(ops.get_default_graph()) # Call to warm_start has to be after model_fn is called. self._maybe_warm_start(checkpoint_path) - if model_fn_lib.LOSS_METRIC_KEY in estimator_spec.eval_metric_ops: + if model_fn_lib.LOSS_METRIC_KEY in eval_metric_ops: raise ValueError( 'Metric with name "%s" is not allowed, because Estimator ' % (model_fn_lib.LOSS_METRIC_KEY) + 'already defines a default metric with the same name.') - estimator_spec.eval_metric_ops[ - model_fn_lib.LOSS_METRIC_KEY] = metrics_lib.mean(estimator_spec.loss) + eval_metric_ops[model_fn_lib.LOSS_METRIC_KEY] = loss_metric - update_op, eval_dict = _extract_metric_update_ops( - estimator_spec.eval_metric_ops) + update_op, eval_dict = _extract_metric_update_ops(eval_metric_ops, + self._eval_distribution) if ops.GraphKeys.GLOBAL_STEP in eval_dict: raise ValueError( @@ -1374,9 +1453,42 @@ class Estimator(object): all_hooks = list(input_hooks) all_hooks.extend(hooks) - all_hooks.extend(list(estimator_spec.evaluation_hooks or [])) + all_hooks.extend(list(evaluation_hooks or [])) + # New local variables have been added, so update the estimator spec's + # local init op if it was defined. + if scaffold and scaffold.local_init_op: + # Ensure that eval step has been created before updating local init op. + evaluation._get_or_create_eval_step() # pylint: disable=protected-access + + scaffold = monitored_session.Scaffold( + local_init_op=control_flow_ops.group( + scaffold.local_init_op, + monitored_session.Scaffold.default_local_init_op()), + copy_from_scaffold=scaffold + ) + + return scaffold, update_op, eval_dict, all_hooks - return estimator_spec.scaffold, update_op, eval_dict, all_hooks + def _call_model_fn_eval(self, features, labels, config): + estimator_spec = self._call_model_fn( + features, labels, model_fn_lib.ModeKeys.EVAL, config) + loss_metric = metrics_lib.mean(estimator_spec.loss) + return (loss_metric, estimator_spec.scaffold, + estimator_spec.evaluation_hooks, estimator_spec.eval_metric_ops) + + def _call_model_fn_eval_distributed(self, features, labels, config): + """Call model_fn in distribution mode and handle return values.""" + grouped_estimator_spec = self._eval_distribution.call_for_each_tower( + self._call_model_fn, features, labels, + model_fn_lib.ModeKeys.EVAL, config) + scaffold = _combine_distributed_scaffold( + grouped_estimator_spec.scaffold, self._eval_distribution) + evaluation_hooks = self._eval_distribution.unwrap( + grouped_estimator_spec.evaluation_hooks)[0] + loss_metric = self._eval_distribution.call_for_each_tower( + metrics_lib.mean, grouped_estimator_spec.loss) + return (loss_metric, scaffold, + evaluation_hooks, grouped_estimator_spec.eval_metric_ops) def _evaluate_run(self, checkpoint_path, scaffold, update_op, eval_dict, all_hooks, output_dir): @@ -1412,8 +1524,51 @@ class Estimator(object): warm_starting_util.warm_start(*self._warm_start_settings) +def maybe_overwrite_model_dir_and_session_config(config, model_dir): + """Overwrite estimator config by `model_dir` and `session_config` if needed. + + Args: + config: Original estimator config. + model_dir: Estimator model checkpoint directory. + + Returns: + Overwritten estimator config. + + Raises: + ValueError: Model directory inconsistent between `model_dir` and `config`. + """ + + if config is None: + config = run_config.RunConfig() + logging.info('Using default config.') + if not isinstance(config, run_config.RunConfig): + raise ValueError( + 'config must be an instance of `RunConfig`, but provided %s.' % config) + + if config.session_config is None: + session_config = run_config.get_default_session_config() + config = run_config.RunConfig.replace(config, session_config=session_config) + + model_dir = compat_internal.path_to_str(model_dir) + if model_dir is not None: + if (getattr(config, 'model_dir', None) is not None and + config.model_dir != model_dir): + raise ValueError( + "`model_dir` are set both in constructor and `RunConfig`, but with " + "different values. In constructor: '{}', in `RunConfig`: " + "'{}' ".format(model_dir, config.model_dir)) + if model_dir: + config = run_config.RunConfig.replace(config, model_dir=model_dir) + elif getattr(config, 'model_dir', None) is None: + model_dir = tempfile.mkdtemp() + logging.warning('Using temporary folder as model directory: %s', model_dir) + config = run_config.RunConfig.replace(config, model_dir=model_dir) + + return config + + def create_per_tower_ready_op(scaffold): - """Create a Scaffold.ready_op inside a tower.""" + """Create a `tf.train.Scaffold.ready_op` inside a tower.""" if scaffold.ready_op: return scaffold.ready_op @@ -1428,7 +1583,7 @@ def create_per_tower_ready_op(scaffold): def create_per_tower_ready_for_local_init_op(scaffold): - """Create a Scaffold.ready_for_local_init_op inside a tower.""" + """Create a `tf.train.Scaffold.ready_for_local_init_op` inside a tower.""" if scaffold.ready_for_local_init_op: return scaffold.ready_for_local_init_op @@ -1441,15 +1596,92 @@ def create_per_tower_ready_for_local_init_op(scaffold): default_ready_for_local_init_op) +def _combine_distributed_scaffold(grouped_scaffold, distribution): + """Combines scaffold(s) returned from `distribution.call_for_each_tower`.""" + + # TODO(anjalisridhar): Figure out how to resolve the following scaffold + # parameters: init_feed_dict, init_fn. + scaffold_list = distribution.unwrap(grouped_scaffold) + init_feed_dict = [ + s.init_feed_dict + for s in scaffold_list + if s.init_feed_dict is not None + ] + if init_feed_dict: + init_feed_dict = distribution.group(init_feed_dict) + else: + init_feed_dict = None + + init_fn = [s.init_fn for s in scaffold_list if s.init_fn is not None] + if init_fn: + init_fn = distribution.group(init_fn) + else: + init_fn = None + + init_op = [s.init_op for s in scaffold_list if s.init_op is not None] + if init_op: + init_op = distribution.group(init_op) + else: + init_op = None + + def _unwrap_and_concat(value): + value = nest.flatten(distribution.unwrap(value)) + if len(value) != 1: + return array_ops.concat(value) + return value[0] + + ready_op = distribution.call_for_each_tower( + create_per_tower_ready_op, grouped_scaffold) + if ready_op is not None: + ready_op = _unwrap_and_concat(ready_op) + else: + ready_op = None + + ready_for_local_init_op = distribution.call_for_each_tower( + create_per_tower_ready_for_local_init_op, grouped_scaffold) + if ready_for_local_init_op is not None: + ready_for_local_init_op = _unwrap_and_concat(ready_for_local_init_op) + else: + ready_for_local_init_op = None + + local_init_op = [ + s.local_init_op + for s in scaffold_list + if s.local_init_op is not None + ] + if local_init_op: + local_init_op = distribution.group(local_init_op) + else: + local_init_op = None + + summary_op = [ + s.summary_op for s in scaffold_list if s.summary_op is not None + ] + if summary_op: + summary_op = distribution.group(summary_op) + else: + summary_op = None + + scaffold = monitored_session.Scaffold( + init_op=init_op, + ready_op=ready_op, + ready_for_local_init_op=ready_for_local_init_op, + local_init_op=local_init_op, + summary_op=summary_op, + init_feed_dict=init_feed_dict, + init_fn=init_fn) + return scaffold + + def _check_checkpoint_available(model_dir): - latest_path = saver.latest_checkpoint(model_dir) + latest_path = checkpoint_management.latest_checkpoint(model_dir) if not latest_path: raise ValueError( 'Could not find trained model in model_dir: {}.'.format(model_dir)) def _check_hooks_type(hooks): - """Returns hooks if all are SessionRunHook, raises TypeError otherwise.""" + """Returns hooks if all are `SessionRunHook`, raises TypeError otherwise.""" hooks = list(hooks or []) for h in hooks: if not isinstance(h, training.SessionRunHook): @@ -1469,17 +1701,18 @@ def _check_listeners_type(saving_listeners): def _get_replica_device_setter(config): - """Creates a replica device setter if required as a default device_fn. + """Creates a replica device setter if required as a default `device_fn`. - `Estimator` uses ReplicaDeviceSetter as a default device placer. It sets the - distributed related arguments such as number of ps_replicas based on given - config. + `Estimator` uses `tf.train.ReplicaDeviceSetter` as a default device placer. It + sets the + distributed related arguments such as number of `ps_replicas` based on given + `config`. Args: - config: A `RunConfig` instance. + config: A `tf.estimator.RunConfig` instance. Returns: - A replica device setter, or None. + A replica device setter, or `None`. """ if config.task_type: worker_device = '/job:%s/task:%d' % (config.task_type, config.task_id) @@ -1498,7 +1731,7 @@ def _get_replica_device_setter(config): def _verify_model_fn_args(model_fn, params): - """Verifies model fn arguments.""" + """Verifies `model_fn` arguments.""" args = set(function_utils.fn_args(model_fn)) if 'features' not in args: raise ValueError('model_fn (%s) must include features argument.' % model_fn) @@ -1525,14 +1758,18 @@ def _load_global_step_from_checkpoint_dir(checkpoint_dir): return 0 -def _extract_metric_update_ops(eval_dict): +def _extract_metric_update_ops(eval_dict, distribution=None): """Separate update operations from metric value operations.""" update_ops = [] value_ops = {} # Sort metrics lexicographically so graph is identical every time. for name, metric_ops in sorted(six.iteritems(eval_dict)): value_ops[name] = metric_ops[0] - update_ops.append(metric_ops[1]) + if distribution: + update_op = distribution.group(metric_ops[1]) + else: + update_op = metric_ops[1] + update_ops.append(update_op) if update_ops: update_op = control_flow_ops.group(*update_ops) @@ -1592,10 +1829,24 @@ def _write_dict_to_summary(output_dir, logging.warn('Skipping summary for %s, cannot parse string to Summary.', key) continue + elif isinstance(dictionary[key], np.ndarray): + value = summary_proto.value.add() + value.tag = key + value.node_name = key + tensor_proto = tensor_util.make_tensor_proto(dictionary[key]) + value.tensor.CopyFrom(tensor_proto) + # pylint: disable=line-too-long + logging.info( + 'Summary for np.ndarray is not visible in Tensorboard by default. ' + 'Consider using a Tensorboard plugin for visualization (see ' + 'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md' + ' for more information).') + # pylint: enable=line-too-long else: logging.warn( 'Skipping summary for %s, must be a float, np.float32, np.int64, ' - 'np.int32 or int or a serialized string of Summary.', key) + 'np.int32 or int or np.ndarray or a serialized string of Summary.', + key) summary_writer.add_summary(summary_proto, current_global_step) summary_writer.flush() @@ -1625,7 +1876,7 @@ def _write_checkpoint_path_to_summary(output_dir, checkpoint_path, def _has_dataset_or_queue_runner(maybe_tensor): - """Returns True if TF dataset or QueueRunner has been used.""" + """Returns `True` if `Dataset` or `QueueRunner` has been used.""" # Check TF dataset first. Here, we use a simple algorithm to check the top # level Tensors only, which should be sufficient for most users. tensors = [x for x in nest.flatten(maybe_tensor) if isinstance(x, ops.Tensor)] @@ -1648,9 +1899,9 @@ class WarmStartSettings( 'var_name_to_vocab_info', 'var_name_to_prev_var_name', ])): - """Settings for warm-starting in Estimators. + """Settings for warm-starting in `tf.estimator.Estimators`. - Example Use with canned `DNNEstimator`: + Example Use with canned `tf.estimator.DNNEstimator`: ``` emb_vocab_file = tf.feature_column.embedding_column( @@ -1767,23 +2018,19 @@ class WarmStartSettings( ckpt_to_initialize_from: [Required] A string specifying the directory with checkpoint file(s) or path to checkpoint from which to warm-start the model parameters. - vars_to_warm_start: [Optional] One of the following: - - - A regular expression (string) that captures which variables to - warm-start (see tf.get_collection). This expression will only consider - variables in the TRAINABLE_VARIABLES collection. - - A list of Variables to warm-start. - - A list of strings, each representing a full variable name to warm-start. - - `None`, in which case only variables specified in - `var_name_to_vocab_info` will be warm-started. - - Defaults to `'.*'`, which warm-starts all variables in the - TRAINABLE_VARIABLES collection. Note that this excludes variables such as - accumulators and moving statistics from batch norm. + vars_to_warm_start: [Optional] One of the following: - A regular expression + (string) that captures which variables to warm-start (see + `tf.get_collection`). This expression will only consider variables in the + `TRAINABLE_VARIABLES` collection. - A list of Variables to warm-start. - A + list of strings, each representing a full variable name to warm-start. - + `None`, in which case only variables specified in `var_name_to_vocab_info` + will be warm-started. Defaults to `'.*'`, which warm-starts all variables + in the `TRAINABLE_VARIABLES` collection. Note that this excludes + variables such as accumulators and moving statistics from batch norm. var_name_to_vocab_info: [Optional] Dict of variable names (strings) to - VocabInfo. The variable names should be "full" variables, not the names - of the partitions. If not explicitly provided, the variable is assumed to - have no vocabulary. + `tf.estimator.VocabInfo`. The variable names should be "full" variables, + not the names of the partitions. If not explicitly provided, the variable + is assumed to have no vocabulary. var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to name of the previously-trained variable in `ckpt_to_initialize_from`. If not explicitly provided, the name of the variable is assumed to be same @@ -1807,33 +2054,46 @@ class WarmStartSettings( ) +def _get_saved_model_ckpt(saved_model_dir): + """Return path to variables checkpoint in a `SavedModel` directory.""" + if not gfile.Exists( + os.path.join(saved_model_utils.get_variables_dir(saved_model_dir), + compat.as_text('variables.index'))): + raise ValueError('Directory provided has an invalid SavedModel format: %s' + % saved_model_dir) + return saved_model_utils.get_variables_path(saved_model_dir) + + def _get_default_warm_start_settings(warm_start_from): - """Returns default WarmStartSettings. + """Returns default `tf.estimator.WarmStartSettings`. Args: warm_start_from: Either a string representing the filepath of a checkpoint - or SavedModel to initialize from, or an instance of WarmStartSettings. + or `SavedModel` to initialize from, or an instance of + `tf.estimator.WarmStartSettings`. Returns: - Either None or an instance of WarmStartSettings. + Either None or an instance of `WarmStartSettings`. Raises: - ValueError: If warm_start_from is not None but is neither a string nor an - instance of WarmStartSettings. + ValueError: If `warm_start_from` is not `None` but is neither a string nor + an + instance of `WarmStartSettings`. """ if warm_start_from is None: return None if isinstance(warm_start_from, (six.string_types, six.binary_type)): # Infer that this is a SavedModel if export_path + # 'variables/variables.index' exists, and if so, construct the - # WarmStartSettings pointing to export_path + 'variables/variables'. - if gfile.Exists(os.path.join(compat.as_bytes(warm_start_from), - compat.as_bytes('variables/variables.index'))): + # WarmStartSettings pointing to the variables path + # (export_path + 'variables/variables'). + if gfile.Exists(os.path.join( + saved_model_utils.get_variables_dir(warm_start_from), + compat.as_text('variables.index'))): logging.info('Warm-starting from a SavedModel') - return WarmStartSettings(ckpt_to_initialize_from=os.path.join( - compat.as_bytes(warm_start_from), - compat.as_bytes('{}/{}'.format(constants.VARIABLES_DIRECTORY, - constants.VARIABLES_FILENAME)))) + return WarmStartSettings( + ckpt_to_initialize_from=saved_model_utils.get_variables_path( + warm_start_from)) return WarmStartSettings(ckpt_to_initialize_from=warm_start_from) elif isinstance(warm_start_from, WarmStartSettings): return warm_start_from diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index a43b820f322d70093a5015457fea294e436daeea..e3f22d9010cccb9c9f4a7937270e4aff030df910 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -28,6 +28,7 @@ import six from google.protobuf import text_format +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator @@ -38,6 +39,7 @@ from tensorflow.python.estimator.export import export_output from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util @@ -61,11 +63,13 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import loader_impl +from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary from tensorflow.python.summary import summary_iterator from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import checkpoint_state_pb2 from tensorflow.python.training import saver from tensorflow.python.training import saver_test_utils @@ -172,7 +176,7 @@ class EstimatorInheritanceConstraintTest(test.TestCase): class EstimatorConstructorTest(test.TestCase): def test_config_must_be_a_run_config(self): - with self.assertRaisesRegexp(ValueError, 'an instance of RunConfig'): + with self.assertRaisesRegexp(ValueError, 'an instance of `RunConfig`'): estimator.Estimator(model_fn=None, config='NotARunConfig') def test_model_fn_must_be_provided(self): @@ -201,6 +205,10 @@ class EstimatorConstructorTest(test.TestCase): est = estimator.Estimator(model_fn=model_fn) self.assertTrue(isinstance(est.config, run_config.RunConfig)) + self.assertTrue(est._session_config.allow_soft_placement) + rewrite_options = est._session_config.graph_options.rewrite_options + self.assertEqual(rewrite_options.meta_optimizer_iterations, + rewriter_config_pb2.RewriterConfig.ONE) def test_default_model_dir(self): @@ -221,6 +229,15 @@ class EstimatorConstructorTest(test.TestCase): self.assertEqual(_TMP_DIR, est.config.model_dir) self.assertEqual(_TMP_DIR, est.model_dir) + def test_empty_model_dir(self): + def model_fn(features, labels): + _, _ = features, labels + + with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR): + est = estimator.Estimator(model_fn=model_fn, model_dir='') + self.assertEqual(_TMP_DIR, est.config.model_dir) + self.assertEqual(_TMP_DIR, est.model_dir) + def test_model_dir_in_run_config(self): class FakeConfig(run_config.RunConfig): @@ -265,7 +282,7 @@ class EstimatorConstructorTest(test.TestCase): with self.assertRaisesRegexp( ValueError, - 'model_dir are set both in constructor and RunConfig, but ' + '`model_dir` are set both in constructor and `RunConfig`, but ' 'with different values'): estimator.Estimator( model_fn=model_fn, config=FakeConfig(), model_dir=_ANOTHER_TMP_DIR) @@ -1295,6 +1312,31 @@ class EstimatorEvaluateTest(test.TestCase): dummy_input_fn, steps=1, checkpoint_path=est1.latest_checkpoint()) self.assertEqual(5, scores['global_step']) + def test_wrong_shape_throws_reasonable_error(self): + """Make sure we are helpful when model_fns change. See b/110263146.""" + def _get_model_fn(val=1): + def _model_fn(features, labels, mode): + del features, labels # unused + variables.Variable(val, name='weight') + return model_fn_lib.EstimatorSpec( + mode=mode, + predictions=constant_op.constant([[1.]]), + loss=constant_op.constant(0.), + train_op=state_ops.assign_add(training.get_global_step(), 1)) + return _model_fn + + model_fn_1 = _get_model_fn() + model_fn_2 = _get_model_fn(val=[1]) + + est1 = estimator.Estimator(model_fn=model_fn_1) + est1.train(dummy_input_fn, steps=5) + est2 = estimator.Estimator( + model_fn=model_fn_2, model_dir=est1.model_dir) + + expected_msg = 'Restoring from checkpoint failed.*a mismatch between' + with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_msg): + est2.train(dummy_input_fn, steps=1,) + def test_scaffold_is_used(self): def _model_fn_scaffold(features, labels, mode): @@ -1416,6 +1458,48 @@ class EstimatorEvaluateTest(test.TestCase): self.assertProtoEquals(expected_tensor_proto, next(summaries).value[0].tensor) + def test_summary_writing_with_tensor(self): + + def model_fn_with_prediction_mean_tensor_eval_metric_ops( + features, labels, mode, params): + _, _ = features, labels + global_step = training.get_global_step() + + metric_name = params.get('metric_name') or 'metric' + predictions = constant_op.constant([1., .5, 0.]) + eval_metric_ops = {metric_name: metrics_lib.mean_tensor(predictions)} + return model_fn_lib.EstimatorSpec( + mode, + loss=constant_op.constant(1.), + predictions={'predictions': predictions}, + train_op=state_ops.assign_add(global_step, 1), + eval_metric_ops=eval_metric_ops) + + metric_key = 'PMT' + params = { + 'metric_name': metric_key, + } + est = estimator.Estimator( + model_fn=model_fn_with_prediction_mean_tensor_eval_metric_ops, + params=params, + config=run_config.RunConfig(save_summary_steps=1)) + est.train(input_fn=dummy_input_fn, steps=10) + est.evaluate( + input_fn=dummy_input_fn, + steps=10, + ) + + writer_cache.FileWriterCache.clear() + + self.assertTrue( + check_eventfile_for_keyword(metric_key, est.eval_dir()), + '{} should be part of reported summaries.'.format(metric_key)) + + summaries = summaries_with_matching_keyword(metric_key, est.eval_dir()) + for value in next(summaries).value: + if value.tag == metric_key: + self.assertTrue(value.HasField('tensor')) + class EstimatorPredictTest(test.TestCase): @@ -1507,7 +1591,8 @@ class EstimatorPredictTest(test.TestCase): next( est.predict( dummy_input_fn, - checkpoint_path=saver.latest_checkpoint('fakedir'))) + checkpoint_path= + checkpoint_management.latest_checkpoint('fakedir'))) def test_tensor_predictions(self): @@ -2277,6 +2362,43 @@ class EstimatorExportTest(test.TestCase): with self.assertRaisesRegexp(ValueError, err_regex): est._export_all_saved_models(export_dir_base, input_receiver_fn_map) + def test_export_all_saved_models_metric_operation(self): + """Ensures metrics ops.Operations can be expoerted (b/109740581).""" + + def _model_fn(features, labels, mode): + del features, labels # Unused + metrics = {'metrics': (constant_op.constant([0]), + control_flow_ops.no_op())} + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + eval_metric_ops=metrics) + + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn) + est.train(input_fn=dummy_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('metric_operation_export')) + + input_receiver_fn_map = { + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn()} + + export_dir = est._export_all_saved_models( + export_dir_base, input_receiver_fn_map) + + # Restore, to validate that the export was well-formed. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + meta_graph = loader.load(sess, [tag_constants.EVAL], export_dir) + sig_outputs = meta_graph.signature_def[ + model_fn_lib.ModeKeys.EVAL].outputs + self.assertEqual( + sig_outputs['metrics/update_op'].name, 'metric_op_wrapper:0') + def test_export_savedmodel_with_saveables_proto_roundtrip(self): tmpdir = tempfile.mkdtemp() est = estimator.Estimator( @@ -2561,6 +2683,7 @@ class EstimatorExportTest(test.TestCase): _, _ = features, labels my_int = variables.Variable(1, name='my_int', collections=[ops.GraphKeys.LOCAL_VARIABLES]) + _ = training.get_or_create_steps_per_run_variable() scores = constant_op.constant([3.]) with ops.control_dependencies([ variables.local_variables_initializer(), @@ -2829,6 +2952,45 @@ class EstimatorExportTest(test.TestCase): # Clean up. gfile.DeleteRecursively(tmpdir) + def test_export_savedmodel_no_export_outputs(self): + """Ensure that an EstimatorSpec without outputs defined can be exported.""" + + def _model_fn(features, labels, mode): + _, _ = features, labels + variables.Variable(1., name='weight') + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1)) + + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn) + est.train(input_fn=dummy_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('no_export_outputs')) + export_dir = est.export_savedmodel( + export_dir_base, _get_serving_input_receiver_fn()) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + meta_graph = loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('weight' in graph_ops) + + sig_def = meta_graph.signature_def + self.assertEqual(len(sig_def), 1) + sig_outputs = sig_def[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs + self.assertEqual(sig_outputs['output'].name, 'Const:0') + class EstimatorHookOrderingTest(test.TestCase): @@ -2873,7 +3035,7 @@ class EstimatorHookOrderingTest(test.TestCase): class EstimatorIntegrationTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_complete_flow_with_a_simple_linear_model(self): def _model_fn(features, labels, mode): diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index ca26341445e86ad554ac2e7cbf643c7775dd9825..3d171f78119e10f700e2b98811ba169a8e037938 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -40,29 +40,38 @@ _SINGLE_FEATURE_DEFAULT_NAME = 'feature' _SINGLE_RECEIVER_DEFAULT_NAME = 'input' _SINGLE_LABEL_DEFAULT_NAME = 'label' +_SINGLE_TENSOR_DEFAULT_NAMES = { + 'feature': _SINGLE_FEATURE_DEFAULT_NAME, + 'label': _SINGLE_LABEL_DEFAULT_NAME, + 'receiver_tensor': _SINGLE_RECEIVER_DEFAULT_NAME, + 'receiver_tensors_alternative': _SINGLE_RECEIVER_DEFAULT_NAME +} + -def _wrap_and_check_receiver_tensors(receiver_tensors): - """Ensure that receiver_tensors is a dict of str to Tensor mappings. +def _wrap_and_check_input_tensors(tensors, field_name): + """Ensure that tensors is a dict of str to Tensor mappings. Args: - receiver_tensors: dict of str to Tensors, or a single Tensor. + tensors: dict of str to Tensors, or a single Tensor. + field_name: name of the member field of `ServingInputReceiver` + whose value is being passed to `tensors`. Returns: dict of str to Tensors; this is the original dict if one was passed, or the original tensor wrapped in a dictionary. Raises: - ValueError: if receiver_tensors is None, or has non-string keys, + ValueError: if tensors is None, or has non-string keys, or non-Tensor values """ - if receiver_tensors is None: - raise ValueError('receiver_tensors must be defined.') - if not isinstance(receiver_tensors, dict): - receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} - for name, tensor in receiver_tensors.items(): - _check_tensor_key(name, error_label='receiver_tensors') - _check_tensor(tensor, name, error_label='receiver_tensor') - return receiver_tensors + if tensors is None: + raise ValueError('{}s must be defined.'.format(field_name)) + if not isinstance(tensors, dict): + tensors = {_SINGLE_TENSOR_DEFAULT_NAMES[field_name]: tensors} + for name, tensor in tensors.items(): + _check_tensor_key(name, error_label=field_name) + _check_tensor(tensor, name, error_label=field_name) + return tensors def _check_tensor(tensor, name, error_label='feature'): @@ -125,15 +134,10 @@ class ServingInputReceiver( features, receiver_tensors, receiver_tensors_alternatives=None): - if features is None: - raise ValueError('features must be defined.') - if not isinstance(features, dict): - features = {_SINGLE_FEATURE_DEFAULT_NAME: features} - for name, tensor in features.items(): - _check_tensor_key(name) - _check_tensor(tensor, name) + features = _wrap_and_check_input_tensors(features, 'feature') - receiver_tensors = _wrap_and_check_receiver_tensors(receiver_tensors) + receiver_tensors = _wrap_and_check_input_tensors(receiver_tensors, + 'receiver_tensor') if receiver_tensors_alternatives is not None: if not isinstance(receiver_tensors_alternatives, dict): @@ -142,17 +146,10 @@ class ServingInputReceiver( receiver_tensors_alternatives)) for alternative_name, receiver_tensors_alt in ( six.iteritems(receiver_tensors_alternatives)): - if not isinstance(receiver_tensors_alt, dict): - receiver_tensors_alt = { - _SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors_alt - } - # Updating dict during iteration is OK in this case. - receiver_tensors_alternatives[alternative_name] = ( - receiver_tensors_alt) - for name, tensor in receiver_tensors_alt.items(): - _check_tensor_key(name, error_label='receiver_tensors_alternative') - _check_tensor( - tensor, name, error_label='receiver_tensors_alternative') + # Updating dict during iteration is OK in this case. + receiver_tensors_alternatives[alternative_name] = ( + _wrap_and_check_input_tensors( + receiver_tensors_alt, 'receiver_tensors_alternative')) return super(ServingInputReceiver, cls).__new__( cls, @@ -245,16 +242,12 @@ class SupervisedInputReceiver( def __new__(cls, features, labels, receiver_tensors): # Both features and labels can be dicts or raw tensors. for input_vals, error_label in ((features, 'feature'), (labels, 'label')): - if input_vals is None: - raise ValueError('{}s must be defined.'.format(error_label)) - if isinstance(input_vals, dict): - for name, tensor in input_vals.items(): - _check_tensor_key(name, error_label=error_label) - _check_tensor(tensor, name, error_label=error_label) - else: - _check_tensor(input_vals, None, error_label=error_label) + # _wrap_and_check_input_tensors is called here only to validate the + # tensors. The wrapped dict that is returned is deliberately discarded. + _wrap_and_check_input_tensors(input_vals, error_label) - receiver_tensors = _wrap_and_check_receiver_tensors(receiver_tensors) + receiver_tensors = _wrap_and_check_input_tensors(receiver_tensors, + 'receiver_tensor') return super(SupervisedInputReceiver, cls).__new__( cls, @@ -295,9 +288,8 @@ def build_parsing_serving_input_receiver_fn(feature_spec, def _placeholder_from_tensor(t, default_batch_size=None): - shape_list = t.get_shape().as_list() - shape_list[0] = default_batch_size - shape = tensor_shape.TensorShape(shape_list) + batch_shape = tensor_shape.TensorShape([default_batch_size]) + shape = batch_shape.concatenate(t.get_shape()[1:]) # Reuse the feature tensor's op name (t.op.name) for the placeholder, # excluding the index from the tensor's name (t.name): diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py index 6c26d299851eaea74f1e564d0fac217f238d76a2..20382a58d8d6fa5be938ee08fcf1487043868301 100644 --- a/tensorflow/python/estimator/export/export_output.py +++ b/tensorflow/python/estimator/export/export_output.py @@ -23,6 +23,7 @@ import abc import six +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.saved_model import signature_def_utils @@ -338,8 +339,16 @@ class _SupervisedOutput(ExportOutput): raise ValueError( '{} update_op must be a Tensor or Operation; got {}.'.format( key, metric_op)) + + # We must wrap any ops in a Tensor before export, as the SignatureDef + # proto expects tensors only. See b/109740581 + metric_op_tensor = metric_op + if isinstance(metric_op, ops.Operation): + with ops.control_dependencies([metric_op]): + metric_op_tensor = constant_op.constant([], name='metric_op_wrapper') + outputs[val_name] = metric_val - outputs[op_name] = metric_op + outputs[op_name] = metric_op_tensor return outputs diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py index b21ba91b0fbb7e14df5eb74dbabace57d3596cc9..d94c764fd7c353a5eeb13c5272b7fe0c4ebdfe07 100644 --- a/tensorflow/python/estimator/export/export_output_test.py +++ b/tensorflow/python/estimator/export/export_output_test.py @@ -24,8 +24,10 @@ from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -335,5 +337,18 @@ class SupervisedOutputTest(test.TestCase): self.assertTrue("predictions/output1" in sig_def.outputs) self.assertTrue("features" in sig_def.inputs) + def test_metric_op_is_operation(self): + """Tests that ops.Operation is wrapped by a tensor for metric_ops.""" + loss = {"my_loss": constant_op.constant([0])} + predictions = {u"output1": constant_op.constant(["foo"])} + metrics = {"metrics": (constant_op.constant([0]), control_flow_ops.no_op())} + + outputter = MockSupervisedOutput(loss, predictions, metrics) + self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0]) + self.assertEqual( + outputter.metrics["metrics/update_op"].name, "metric_op_wrapper:0") + self.assertTrue( + isinstance(outputter.metrics["metrics/update_op"], ops.Tensor)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py index a7074712c25532a1d2156a11d2314150d9efabc1..1d475adb4396e58abd16c5b4bec1ad5ede925335 100644 --- a/tensorflow/python/estimator/export/export_test.py +++ b/tensorflow/python/estimator/export/export_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops @@ -107,7 +108,7 @@ class ServingInputReceiverTest(test_util.TensorFlowTestCase): receiver_tensors=None) with self.assertRaisesRegexp( - ValueError, "receiver_tensors keys must be strings"): + ValueError, "receiver_tensor keys must be strings"): export.ServingInputReceiver( features=features, receiver_tensors={ @@ -271,7 +272,7 @@ class SupervisedInputReceiverTest(test_util.TensorFlowTestCase): receiver_tensors=None) with self.assertRaisesRegexp( - ValueError, "receiver_tensors keys must be strings"): + ValueError, "receiver_tensor keys must be strings"): export.SupervisedInputReceiver( features=features, labels=labels, @@ -378,6 +379,20 @@ class ExportTest(test_util.TensorFlowTestCase): v = serving_input_receiver_fn() self.assertTrue(isinstance(v, export.ServingInputReceiver)) + def test_build_raw_serving_input_receiver_fn_without_shape(self): + """Test case for issue #21178.""" + f = {"feature_1": array_ops.placeholder(dtypes.float32), + "feature_2": array_ops.placeholder(dtypes.int32)} + serving_input_receiver_fn = export.build_raw_serving_input_receiver_fn(f) + v = serving_input_receiver_fn() + self.assertTrue(isinstance(v, export.ServingInputReceiver)) + self.assertEqual( + tensor_shape.unknown_shape(), + v.receiver_tensors["feature_1"].shape) + self.assertEqual( + tensor_shape.unknown_shape(), + v.receiver_tensors["feature_2"].shape) + def test_build_raw_serving_input_receiver_fn(self): features = {"feature_1": constant_op.constant(["hello"]), "feature_2": constant_op.constant([42])} @@ -740,7 +755,7 @@ class TensorServingReceiverTest(test_util.TensorFlowTestCase): receiver_tensors=None) with self.assertRaisesRegexp( - ValueError, "receiver_tensors keys must be strings"): + ValueError, "receiver_tensor keys must be strings"): export.TensorServingInputReceiver( features=features, receiver_tensors={ diff --git a/tensorflow/python/estimator/inputs/numpy_io_test.py b/tensorflow/python/estimator/inputs/numpy_io_test.py index 81b201cc5c5f3d6b8211030d17006f89a545793e..4e7b00b3075fc10b9d8320008be8d23bd5092755 100644 --- a/tensorflow/python/estimator/inputs/numpy_io_test.py +++ b/tensorflow/python/estimator/inputs/numpy_io_test.py @@ -19,9 +19,15 @@ from __future__ import division from __future__ import print_function import numpy as np - +from tensorflow.python.client import session as session_lib from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column_lib as fc +from tensorflow.python.feature_column.feature_column import _LinearModel from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test from tensorflow.python.training import coordinator from tensorflow.python.training import monitored_session @@ -456,5 +462,159 @@ class NumpyIoTest(test.TestCase): self.assertAllEqual(res_arr[1], res_dict[1]) +class FeatureColumnIntegrationTest(test.TestCase): + + def _initialized_session(self, config=None): + sess = session_lib.Session(config=config) + sess.run(variables_lib.global_variables_initializer()) + sess.run(lookup_ops.tables_initializer()) + return sess + + def _get_linear_model_bias(self, name='linear_model'): + with variable_scope.variable_scope(name, reuse=True): + return variable_scope.get_variable('bias_weights') + + def _get_linear_model_column_var(self, column, name='linear_model'): + return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, + name + '/' + column.name)[0] + + def _get_keras_linear_model_predictions( + self, + features, + feature_columns, + units=1, + sparse_combiner='sum', + weight_collections=None, + trainable=True, + cols_to_vars=None): + keras_linear_model = _LinearModel( + feature_columns, + units, + sparse_combiner, + weight_collections, + trainable, + name='linear_model') + retval = keras_linear_model(features) # pylint: disable=not-callable + if cols_to_vars is not None: + cols_to_vars.update(keras_linear_model.cols_to_vars()) + return retval + + def test_linear_model_numpy_input_fn(self): + price = fc.numeric_column('price') + price_buckets = fc.bucketized_column(price, boundaries=[0., 10., 100.,]) + body_style = fc.categorical_column_with_vocabulary_list( + 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) + + input_fn = numpy_io.numpy_input_fn( + x={ + 'price': np.array([-1., 2., 13., 104.]), + 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']), + }, + batch_size=2, + shuffle=False) + features = input_fn() + net = fc.linear_model(features, [price_buckets, body_style]) + # self.assertEqual(1 + 3 + 5, net.shape[1]) + with self._initialized_session() as sess: + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(sess, coord=coord) + + bias = self._get_linear_model_bias() + price_buckets_var = self._get_linear_model_column_var(price_buckets) + body_style_var = self._get_linear_model_column_var(body_style) + + sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]])) + sess.run(bias.assign([5.])) + + self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net)) + + coord.request_stop() + coord.join(threads) + + def test_linear_model_impl_numpy_input_fn(self): + price = fc.numeric_column('price') + price_buckets = fc.bucketized_column( + price, boundaries=[ + 0., + 10., + 100., + ]) + body_style = fc.categorical_column_with_vocabulary_list( + 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) + + input_fn = numpy_io.numpy_input_fn( + x={ + 'price': np.array([-1., 2., 13., 104.]), + 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']), + }, + batch_size=2, + shuffle=False) + features = input_fn() + net = self._get_keras_linear_model_predictions( + features, [price_buckets, body_style]) + # self.assertEqual(1 + 3 + 5, net.shape[1]) + with self._initialized_session() as sess: + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(sess, coord=coord) + + bias = self._get_linear_model_bias() + price_buckets_var = self._get_linear_model_column_var(price_buckets) + body_style_var = self._get_linear_model_column_var(body_style) + + sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]])) + sess.run(bias.assign([5.])) + + self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net)) + + coord.request_stop() + coord.join(threads) + + def test_functional_input_layer_with_numpy_input_fn(self): + embedding_values = ( + (1., 2., 3., 4., 5.), # id 0 + (6., 7., 8., 9., 10.), # id 1 + (11., 12., 13., 14., 15.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + del shape, dtype, partition_info + return embedding_values + + # price has 1 dimension in input_layer + price = fc.numeric_column('price') + body_style = fc.categorical_column_with_vocabulary_list( + 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) + # one_hot_body_style has 3 dims in input_layer. + one_hot_body_style = fc.indicator_column(body_style) + # embedded_body_style has 5 dims in input_layer. + embedded_body_style = fc.embedding_column(body_style, dimension=5, + initializer=_initializer) + + input_fn = numpy_io.numpy_input_fn( + x={ + 'price': np.array([11., 12., 13., 14.]), + 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']), + }, + batch_size=2, + shuffle=False) + features = input_fn() + net = fc.input_layer(features, + [price, one_hot_body_style, embedded_body_style]) + self.assertEqual(1 + 3 + 5, net.shape[1]) + with self._initialized_session() as sess: + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(sess, coord=coord) + + # Each row is formed by concatenating `embedded_body_style`, + # `one_hot_body_style`, and `price` in order. + self.assertAllEqual( + [[11., 12., 13., 14., 15., 0., 0., 1., 11.], + [1., 2., 3., 4., 5., 1., 0., 0., 12]], + sess.run(net)) + + coord.request_stop() + coord.join(threads) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/estimator/inputs/pandas_io.py b/tensorflow/python/estimator/inputs/pandas_io.py index 57f8e5fd6aff366ad1f574d5dd40a6c457966b52..616bcb410f8119e170e991f8320c5b6448ee85c9 100644 --- a/tensorflow/python/estimator/inputs/pandas_io.py +++ b/tensorflow/python/estimator/inputs/pandas_io.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import six +import uuid import numpy as np from tensorflow.python.estimator.inputs.queues import feeding_functions @@ -35,6 +37,22 @@ except ImportError: HAS_PANDAS = False +def _get_unique_target_key(features, target_column_name): + """Returns a key that does not exist in the input DataFrame `features`. + + Args: + features: DataFrame + target_column_name: Name of the target column as a `str` + + Returns: + A unique key that can be used to insert the target into + features. + """ + if target_column_name in features: + target_column_name += '_' + str(uuid.uuid4()) + return target_column_name + + @estimator_export('estimator.inputs.pandas_input_fn') def pandas_input_fn(x, y=None, @@ -50,7 +68,7 @@ def pandas_input_fn(x, Args: x: pandas `DataFrame` object. - y: pandas `Series` object. `None` if absent. + y: pandas `Series` object or `DataFrame`. `None` if absent. batch_size: int, size of batches to return. num_epochs: int, number of epochs to iterate over data. If not `None`, read attempts that would exceed this value will raise `OutOfRangeError`. @@ -60,7 +78,8 @@ def pandas_input_fn(x, num_threads: Integer, number of threads used for reading and enqueueing. In order to have predicted and repeatable order of reading and enqueueing, such as in prediction and evaluation mode, `num_threads` should be 1. - target_column: str, name to give the target column `y`. + target_column: str, name to give the target column `y`. This parameter + is not used when `y` is a `DataFrame`. Returns: Function, that has signature of ()->(dict of `features`, `target`) @@ -79,6 +98,9 @@ def pandas_input_fn(x, '(it is recommended to set it as True for training); ' 'got {}'.format(shuffle)) + if not isinstance(target_column, six.string_types): + raise TypeError('target_column must be a string type') + x = x.copy() if y is not None: if target_column in x: @@ -88,7 +110,13 @@ def pandas_input_fn(x, if not np.array_equal(x.index, y.index): raise ValueError('Index for x and y are mismatched.\nIndex for x: %s\n' 'Index for y: %s\n' % (x.index, y.index)) - x[target_column] = y + if isinstance(y, pd.DataFrame): + y_columns = [(column, _get_unique_target_key(x, column)) + for column in list(y)] + target_column = [v for _, v in y_columns] + x[target_column] = y + else: + x[target_column] = y # TODO(mdan): These are memory copies. We probably don't need 4x slack space. # The sizes below are consistent with what I've seen elsewhere. @@ -118,7 +146,12 @@ def pandas_input_fn(x, features = features[1:] features = dict(zip(list(x.columns), features)) if y is not None: - target = features.pop(target_column) + if isinstance(target_column, list): + keys = [k for k, _ in y_columns] + values = [features.pop(column) for column in target_column] + target = {k: v for k, v in zip(keys, values)} + else: + target = features.pop(target_column) return features, target return features return input_fn diff --git a/tensorflow/python/estimator/inputs/pandas_io_test.py b/tensorflow/python/estimator/inputs/pandas_io_test.py index dcecf6dd61c4d24a36b2be8f054c066050d088fc..6f13bc95d2d315ad1aabfd89d5d479d65fe08502 100644 --- a/tensorflow/python/estimator/inputs/pandas_io_test.py +++ b/tensorflow/python/estimator/inputs/pandas_io_test.py @@ -47,6 +47,16 @@ class PandasIoTest(test.TestCase): y = pd.Series(np.arange(-32, -28), index=index) return x, y + def makeTestDataFrameWithYAsDataFrame(self): + index = np.arange(100, 104) + a = np.arange(4) + b = np.arange(32, 36) + a_label = np.arange(10, 14) + b_label = np.arange(50, 54) + x = pd.DataFrame({'a': a, 'b': b}, index=index) + y = pd.DataFrame({'a_target': a_label, 'b_target': b_label}, index=index) + return x, y + def callInputFnOnce(self, input_fn, session): results = input_fn() coord = coordinator.Coordinator() @@ -65,6 +75,19 @@ class PandasIoTest(test.TestCase): pandas_io.pandas_input_fn( x, y_noindex, batch_size=2, shuffle=False, num_epochs=1) + def testPandasInputFn_RaisesWhenTargetColumnIsAList(self): + if not HAS_PANDAS: + return + + x, y = self.makeTestDataFrame() + + with self.assertRaisesRegexp(TypeError, + 'target_column must be a string type'): + pandas_io.pandas_input_fn(x, y, batch_size=2, + shuffle=False, + num_epochs=1, + target_column=['one', 'two']) + def testPandasInputFn_NonBoolShuffle(self): if not HAS_PANDAS: return @@ -90,6 +113,53 @@ class PandasIoTest(test.TestCase): self.assertAllEqual(features['b'], [32, 33]) self.assertAllEqual(target, [-32, -31]) + def testPandasInputFnWhenYIsDataFrame_ProducesExpectedOutput(self): + if not HAS_PANDAS: + return + with self.test_session() as session: + x, y = self.makeTestDataFrameWithYAsDataFrame() + input_fn = pandas_io.pandas_input_fn( + x, y, batch_size=2, shuffle=False, num_epochs=1) + + features, targets = self.callInputFnOnce(input_fn, session) + + self.assertAllEqual(features['a'], [0, 1]) + self.assertAllEqual(features['b'], [32, 33]) + self.assertAllEqual(targets['a_target'], [10, 11]) + self.assertAllEqual(targets['b_target'], [50, 51]) + + def testPandasInputFnYIsDataFrame_HandlesOverlappingColumns(self): + if not HAS_PANDAS: + return + with self.test_session() as session: + x, y = self.makeTestDataFrameWithYAsDataFrame() + y = y.rename(columns={'a_target': 'a', 'b_target': 'b'}) + input_fn = pandas_io.pandas_input_fn( + x, y, batch_size=2, shuffle=False, num_epochs=1) + + features, targets = self.callInputFnOnce(input_fn, session) + + self.assertAllEqual(features['a'], [0, 1]) + self.assertAllEqual(features['b'], [32, 33]) + self.assertAllEqual(targets['a'], [10, 11]) + self.assertAllEqual(targets['b'], [50, 51]) + + def testPandasInputFnYIsDataFrame_HandlesOverlappingColumnsInTargets(self): + if not HAS_PANDAS: + return + with self.test_session() as session: + x, y = self.makeTestDataFrameWithYAsDataFrame() + y = y.rename(columns={'a_target': 'a', 'b_target': 'a_n'}) + input_fn = pandas_io.pandas_input_fn( + x, y, batch_size=2, shuffle=False, num_epochs=1) + + features, targets = self.callInputFnOnce(input_fn, session) + + self.assertAllEqual(features['a'], [0, 1]) + self.assertAllEqual(features['b'], [32, 33]) + self.assertAllEqual(targets['a'], [10, 11]) + self.assertAllEqual(targets['a_n'], [50, 51]) + def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self): if not HAS_PANDAS: return diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py index 2f439f765e6811335667b62437f7aafc934904dc..e4ce5339d09e13fd9b58c882723f6f85d3fe2c60 100644 --- a/tensorflow/python/estimator/keras.py +++ b/tensorflow/python/estimator/keras.py @@ -21,11 +21,11 @@ from __future__ import print_function import os import re + from tensorflow.python.client import session from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import export as export_lib from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib @@ -39,13 +39,16 @@ from tensorflow.python.keras.utils.generic_utils import CustomObjectScope from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_module -from tensorflow.python.ops import variables as variables_module +from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import checkpoint_management +from tensorflow.python.training import distribution_strategy_context +from tensorflow.python.training import optimizer as tf_optimizer_module from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import data_structures _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -70,16 +73,22 @@ def _convert_tensor(x): return x -def _any_variable_initialized(): - """Check if any variable has been initialized in the Keras model. +def _any_weight_initialized(keras_model): + """Check if any weights has been initialized in the Keras model. + + Args: + keras_model: An instance of compiled keras model. Returns: - boolean, True if at least one variable has been initialized, else False. + boolean, True if at least one weight has been initialized, else False. + Currently keras initialize all weights at get_session(). """ - variables = variables_module.global_variables() - for v in variables: - if getattr(v, '_keras_initialized', False): - return True + if keras_model is None: + return False + for layer in keras_model.layers: + for weight in layer.weights: + if hasattr(weight, '_keras_initialized'): + return True return False @@ -123,8 +132,8 @@ def _create_ordered_io(keras_model, estimator_io, is_input=True): 'It needs to match one ' 'of the following: %s' % ('input' if is_input else 'output', key, ', '.join(keras_io_names))) - tensors = [_convert_tensor(estimator_io[io_name]) - for io_name in keras_io_names] + tensors = [_convert_tensor(estimator_io[io_name]) + for io_name in keras_io_names] return tensors else: # Plain array. @@ -174,7 +183,7 @@ def _in_place_subclassed_model_reset(model): # Replace layers on the model with fresh layers layers_to_names = {value: key for key, value in attributes_cache.items()} original_layers = model._layers[:] - model._layers = [] + model._layers = data_structures.NoDependency([]) for layer in original_layers: # We preserve layer order. config = layer.get_config() # This will not work for nested subclassed models used as layers. @@ -222,7 +231,8 @@ def _in_place_subclassed_model_reset(model): ] for name in attributes_to_cache: attributes_cache[name] = getattr(model, name) - model._original_attributes_cache = attributes_cache + model._original_attributes_cache = data_structures.NoDependency( + attributes_cache) # Reset built state model.built = False model.inputs = None @@ -242,8 +252,17 @@ def _in_place_subclassed_model_state_restoration(model): # Restore layers and build attributes if (hasattr(model, '_original_attributes_cache') and model._original_attributes_cache is not None): - model._layers = [] + # Models have sticky attribute assignment, so we want to be careful to add + # back the previous attributes and track Layers by their original names + # without adding dependencies on "utility" attributes which Models exempt + # when they're constructed. + model._layers = data_structures.NoDependency([]) for name, value in model._original_attributes_cache.items(): + if not isinstance(value, checkpointable.CheckpointableBase): + # If this value is not already checkpointable, it's probably that way + # for a reason; we don't want to start tracking data structures that the + # original Model didn't. + value = data_structures.NoDependency(value) setattr(model, name, value) model._original_attributes_cache = None else: @@ -340,13 +359,21 @@ def _create_keras_model_fn(keras_model, custom_objects=None): def model_fn(features, labels, mode): """model_fn for keras Estimator.""" + # Raise an error when users use DistributionStrategy with native Keras + # optimizers. Currently we only support native TensorFlow optimizers. + if distribution_strategy_context.has_distribution_strategy() and \ + not isinstance(keras_model.optimizer, + (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)): + raise ValueError('Only TensorFlow native optimizers are supported with ' + 'DistributionStrategy.') + model = _clone_and_build_model(mode, keras_model, custom_objects, features, labels) model_output_names = [] # We need to make sure that the output names of the last layer in the model # is the same for each of the cloned models. This is required for mirrored # strategy when we call regroup. - if distribute_lib.has_distribution_strategy(): + if distribution_strategy_context.has_distribution_strategy(): for name in model.output_names: name = re.compile(r'_\d$').sub('', name) model_output_names.append(name) @@ -369,7 +396,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None): loss = model.total_loss if model.metrics: - # TODO(fchollet): support stateful metrics + # TODO(psv/fchollet): support stateful metrics eval_metric_ops = {} # When each metric maps to an output if isinstance(model.metrics, dict): @@ -411,29 +438,34 @@ def _create_keras_model_fn(keras_model, custom_objects=None): return model_fn -def _save_first_checkpoint(keras_model, estimator, custom_objects, - keras_weights): +def _save_first_checkpoint(keras_model, custom_objects, config): """Save first checkpoint for the keras Estimator. Args: keras_model: an instance of compiled keras model. - estimator: keras estimator. custom_objects: Dictionary for custom objects. - keras_weights: A flat list of Numpy arrays for weights of given keras_model. + config: Estimator config. Returns: - The model_fn for a keras Estimator. + The path where keras model checkpoint is saved. """ + # save checkpoint into subdirectory to allow warm start + keras_model_dir = os.path.join(config.model_dir, 'keras') # Load weights and save to checkpoint if there is no checkpoint - latest_path = saver_lib.latest_checkpoint(estimator.model_dir) + latest_path = checkpoint_management.latest_checkpoint(keras_model_dir) if not latest_path: + keras_weights = None + if _any_weight_initialized(keras_model): + keras_weights = keras_model.get_weights() + if not gfile.IsDirectory(keras_model_dir): + gfile.MakeDirs(keras_model_dir) with ops.Graph().as_default(): - random_seed.set_random_seed(estimator.config.tf_random_seed) + random_seed.set_random_seed(config.tf_random_seed) training_util.create_global_step() model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model, custom_objects) # save to checkpoint - with session.Session(config=estimator._session_config) as sess: + with session.Session(config=config.session_config) as sess: if keras_weights: model.set_weights(keras_weights) # Make update ops and initialize all variables. @@ -443,10 +475,11 @@ def _save_first_checkpoint(keras_model, estimator, custom_objects, K._initialize_variables(sess) # pylint: enable=protected-access saver = saver_lib.Saver() - saver.save(sess, os.path.join(estimator.model_dir, 'keras_model.ckpt')) + latest_path = os.path.join(keras_model_dir, 'keras_model.ckpt') + saver.save(sess, latest_path) + return latest_path -@tf_export('keras.estimator.model_to_estimator') def model_to_estimator(keras_model=None, keras_model_path=None, custom_objects=None, @@ -455,7 +488,7 @@ def model_to_estimator(keras_model=None, """Constructs an `Estimator` instance from given keras model. For usage example, please see - @{$programmers_guide/estimators$creating_estimators_from_keras_models}. + @{$guide/estimators$creating_estimators_from_keras_models}. Args: keras_model: A compiled Keras model object. This argument is mutually @@ -464,9 +497,9 @@ def model_to_estimator(keras_model=None, format, which can be generated with the `save()` method of a Keras model. This argument is mutually exclusive with `keras_model`. custom_objects: Dictionary for custom objects. - model_dir: Directory to save Estimator model parameters, graph, summary + model_dir: Directory to save `Estimator` model parameters, graph, summary files for TensorBoard, etc. - config: Configuration object. + config: `RunConfig` to config `Estimator`. Returns: An Estimator from given keras model. @@ -503,45 +536,40 @@ def model_to_estimator(keras_model=None, 'Please compile the model with `model.compile()` ' 'before calling `model_to_estimator()`.') - if isinstance(config, dict): - config = run_config_lib.RunConfig(**config) + config = estimator_lib.maybe_overwrite_model_dir_and_session_config(config, + model_dir) keras_model_fn = _create_keras_model_fn(keras_model, custom_objects) - estimator = estimator_lib.Estimator( - keras_model_fn, model_dir=model_dir, config=config) - - # Check if we need to call get_weights: - if _any_variable_initialized(): - keras_weights = keras_model.get_weights() + if _any_weight_initialized(keras_model): # Warn if config passed to estimator tries to update GPUOptions. If a # session has already been created, the GPUOptions passed to the first # session sticks. - if estimator._session_config.HasField('gpu_options'): + if config.session_config.HasField('gpu_options'): logging.warning( 'The Keras backend session has already been set. ' 'The _session_config passed to model_to_estimator will not be used.') else: # Pass the config into keras backend's default session. - sess = session.Session(config=estimator._session_config) + sess = session.Session(config=config.session_config) K.set_session(sess) - keras_weights = None + warm_start_path = None if keras_model._is_graph_network: - # TODO(yifeif): move checkpoint initialization to scaffold.init_fn - _save_first_checkpoint(keras_model, - estimator, - custom_objects, - keras_weights) + warm_start_path = _save_first_checkpoint(keras_model, custom_objects, + config) elif keras_model.built: - logging.warning('You are creating an Estimator from a Keras model ' - 'manually subclassed from `Model`, that was ' - 'already called on some inputs (and thus already had ' - 'weights). We are currently unable to preserve ' - 'the model\'s state (its weights) ' - 'as part of the estimator ' - 'in this case. Be warned that the estimator ' - 'has been created using ' - 'a freshly initialized version of your model.\n' - 'Note that this doesn\'t affect the state of the ' - 'model instance you passed as `keras_model` argument.') + logging.warning('You are creating an Estimator from a Keras model manually ' + 'subclassed from `Model`, that was already called on some ' + 'inputs (and thus already had weights). We are currently ' + 'unable to preserve the model\'s state (its weights) as ' + 'part of the estimator in this case. Be warned that the ' + 'estimator has been created using a freshly initialized ' + 'version of your model.\n' + 'Note that this doesn\'t affect the state of the model ' + 'instance you passed as `keras_model` argument.') + + estimator = estimator_lib.Estimator(keras_model_fn, + config=config, + warm_start_from=warm_start_path) + return estimator diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py index 5e094ae92bcf88a48d7afe3fb88bbced4971b587..332e38572622d230616dd55fc7a7f1bcf62dbe96 100644 --- a/tensorflow/python/estimator/keras_test.py +++ b/tensorflow/python/estimator/keras_test.py @@ -32,13 +32,14 @@ from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils -from tensorflow.python.keras.applications import mobilenet from tensorflow.python.keras.optimizers import SGD +from tensorflow.python.ops import variable_scope from tensorflow.python.ops.parsing_ops import gen_parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import rmsprop +from tensorflow.python.training import session_run_hook try: @@ -51,6 +52,8 @@ _TRAIN_SIZE = 200 _INPUT_SIZE = (10,) _NUM_CLASS = 2 +_TMP_DIR = '/tmp' + def simple_sequential_model(): model = keras.models.Sequential() @@ -60,9 +63,9 @@ def simple_sequential_model(): return model -def simple_functional_model(): +def simple_functional_model(activation='relu'): a = keras.layers.Input(shape=_INPUT_SIZE) - b = keras.layers.Dense(16, activation='relu')(a) + b = keras.layers.Dense(16, activation=activation)(a) b = keras.layers.Dropout(0.1)(b) b = keras.layers.Dense(_NUM_CLASS, activation='softmax')(b) model = keras.models.Model(inputs=[a], outputs=[b]) @@ -168,6 +171,12 @@ def multi_inputs_multi_outputs_model(): return model +class MyHook(session_run_hook.SessionRunHook): + + def begin(self): + _ = variable_scope.get_variable('temp', [1]) + + class TestKerasEstimator(test_util.TensorFlowTestCase): def setUp(self): @@ -204,6 +213,55 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): writer_cache.FileWriterCache.clear() gfile.DeleteRecursively(self._config.model_dir) + # see b/109935364 + @test_util.run_in_graph_and_eager_modes + def test_train_with_hooks(self): + for model_type in ['sequential', 'functional']: + keras_model, (_, _), ( + _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model( + model_type=model_type, is_evaluate=True) + keras_model.compile( + loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(1e-3), + metrics=['mse', keras.metrics.categorical_accuracy]) + + my_hook = MyHook() + with self.test_session(): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, config=self._config) + before_eval_results = est_keras.evaluate( + input_fn=eval_input_fn, steps=1) + est_keras.train(input_fn=train_input_fn, hooks=[my_hook], + steps=_TRAIN_SIZE / 16) + after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1) + self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + + @test_util.run_in_graph_and_eager_modes + def test_train_with_model_fit_and_hooks(self): + keras_model, (x_train, y_train), _, \ + train_input_fn, eval_input_fn = get_resource_for_simple_model( + model_type='sequential', is_evaluate=True) + + keras_model.compile( + loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(1e-3), + metrics=['mse', keras.metrics.categorical_accuracy]) + my_hook = MyHook() + with self.test_session(): + keras_model.fit(x_train, y_train, epochs=1) + + keras_est = keras_lib.model_to_estimator( + keras_model=keras_model, config=self._config) + before_eval_results = keras_est.evaluate(input_fn=eval_input_fn) + keras_est.train(input_fn=train_input_fn, hooks=[my_hook], + steps=_TRAIN_SIZE / 16) + after_eval_results = keras_est.evaluate(input_fn=eval_input_fn, steps=1) + self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + + @test_util.run_in_graph_and_eager_modes def test_train_with_tf_optimizer(self): for model_type in ['sequential', 'functional']: keras_model, (_, _), ( @@ -217,11 +275,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): with self.test_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, - # Also use dict config argument to get test coverage for that line. - config={ - 'tf_random_seed': _RANDOM_SEED, - 'model_dir': self._base_dir, - }) + config=self._config) before_eval_results = est_keras.evaluate( input_fn=eval_input_fn, steps=1) est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16) @@ -231,6 +285,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): writer_cache.FileWriterCache.clear() gfile.DeleteRecursively(self._config.model_dir) + @test_util.run_in_graph_and_eager_modes def test_train_with_subclassed_model(self): keras_model, (_, _), ( _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model( @@ -472,23 +527,43 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): est_keras.train(input_fn=invald_output_name_input_fn, steps=100) def test_custom_objects(self): - keras_mobile = mobilenet.MobileNet(weights=None) - keras_mobile.compile(loss='categorical_crossentropy', optimizer='adam') + + def relu6(x): + return keras.backend.relu(x, max_value=6) + + keras_model = simple_functional_model(activation=relu6) + keras_model.compile(loss='categorical_crossentropy', optimizer='adam') custom_objects = { - 'relu6': mobilenet.relu6, - 'DepthwiseConv2D': mobilenet.DepthwiseConv2D + 'relu6': relu6 } + + (x_train, y_train), _ = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=(10,), + num_classes=2) + y_train = keras.utils.to_categorical(y_train, 2) + input_name = keras_model.input_names[0] + output_name = keras_model.output_names[0] + train_input_fn = numpy_io.numpy_input_fn( + x=randomize_io_type(x_train, input_name), + y=randomize_io_type(y_train, output_name), + shuffle=False, + num_epochs=None, + batch_size=16) with self.assertRaisesRegexp(ValueError, 'relu6'): with self.test_session(): - keras_lib.model_to_estimator( - keras_model=keras_mobile, + est = keras_lib.model_to_estimator( + keras_model=keras_model, model_dir=tempfile.mkdtemp(dir=self._base_dir)) + est.train(input_fn=train_input_fn, steps=1) with self.test_session(): - keras_lib.model_to_estimator( - keras_model=keras_mobile, + est = keras_lib.model_to_estimator( + keras_model=keras_model, model_dir=tempfile.mkdtemp(dir=self._base_dir), custom_objects=custom_objects) + est.train(input_fn=train_input_fn, steps=1) def test_tf_config(self): keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model() @@ -525,12 +600,73 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3) sess_config = config_pb2.ConfigProto(gpu_options=gpu_options) self._config._session_config = sess_config - keras_lib.model_to_estimator( - keras_model=keras_model, config=self._config) - self.assertEqual( - keras.backend.get_session() - ._config.gpu_options.per_process_gpu_memory_fraction, - gpu_options.per_process_gpu_memory_fraction) + with self.test_session(): + keras_lib.model_to_estimator( + keras_model=keras_model, config=self._config) + self.assertEqual( + keras.backend.get_session() + ._config.gpu_options.per_process_gpu_memory_fraction, + gpu_options.per_process_gpu_memory_fraction) + + def test_with_empty_config(self): + keras_model, _, _, _, _ = get_resource_for_simple_model( + model_type='sequential', is_evaluate=True) + keras_model.compile( + loss='categorical_crossentropy', + optimizer='rmsprop', + metrics=['mse', keras.metrics.categorical_accuracy]) + + with self.test_session(): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, model_dir=self._base_dir, + config=run_config_lib.RunConfig()) + self.assertEqual(run_config_lib.get_default_session_config(), + est_keras._session_config) + self.assertEqual(est_keras._session_config, + est_keras._config.session_config) + self.assertEqual(self._base_dir, est_keras._config.model_dir) + self.assertEqual(self._base_dir, est_keras._model_dir) + + with self.test_session(): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, model_dir=self._base_dir, + config=None) + self.assertEqual(run_config_lib.get_default_session_config(), + est_keras._session_config) + self.assertEqual(est_keras._session_config, + est_keras._config.session_config) + self.assertEqual(self._base_dir, est_keras._config.model_dir) + self.assertEqual(self._base_dir, est_keras._model_dir) + + def test_with_empty_config_and_empty_model_dir(self): + keras_model, _, _, _, _ = get_resource_for_simple_model( + model_type='sequential', is_evaluate=True) + keras_model.compile( + loss='categorical_crossentropy', + optimizer='rmsprop', + metrics=['mse', keras.metrics.categorical_accuracy]) + + with self.test_session(): + with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, + config=run_config_lib.RunConfig()) + self.assertEqual(est_keras._model_dir, _TMP_DIR) + + def test_with_conflicting_model_dir_and_config(self): + keras_model, _, _, _, _ = get_resource_for_simple_model( + model_type='sequential', is_evaluate=True) + keras_model.compile( + loss='categorical_crossentropy', + optimizer='rmsprop', + metrics=['mse', keras.metrics.categorical_accuracy]) + + with self.test_session(): + with self.assertRaisesRegexp(ValueError, '`model_dir` are set both in ' + 'constructor and `RunConfig`'): + keras_lib.model_to_estimator( + keras_model=keras_model, model_dir=self._base_dir, + config=run_config_lib.RunConfig(model_dir=_TMP_DIR)) def test_pretrained_weights(self): keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model() diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index c60c7f63bacc810e447bcafe954c55cb49ede7e0..9db9ccd01d3fe996c2b047d65cc2f35747b005e1 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -23,7 +23,7 @@ import collections import six -from tensorflow.python.estimator.export.export_output import ExportOutput +from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -99,7 +99,7 @@ class EstimatorSpec( ignored in eval and infer modes. Example: ```python - def my_model_fn(mode, features, labels): + def my_model_fn(features, labels, mode): predictions = ... loss = ... train_op = ... @@ -114,7 +114,7 @@ class EstimatorSpec( given mode. Example: ```python - def my_model_fn(mode, features, labels): + def my_model_fn(features, labels, mode): if (mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL): loss = ... @@ -158,6 +158,8 @@ class EstimatorSpec( Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY. + If no entry is provided, a default `PredictOutput` mapping to + `predictions` will be created. training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to run on the chief worker during training. training_hooks: Iterable of `tf.train.SessionRunHook` objects to run @@ -232,29 +234,9 @@ class EstimatorSpec( _check_is_tensor_or_operation(metric_update, 'eval_metric_ops[{}]'.format(key)) - # Validate export_outputs. - if export_outputs is not None: - if not isinstance(export_outputs, dict): - raise TypeError('export_outputs must be dict, given: {}'.format( - export_outputs)) - for v in six.itervalues(export_outputs): - if not isinstance(v, ExportOutput): - raise TypeError( - 'Values in export_outputs must be ExportOutput objects. ' - 'Given: {}'.format(export_outputs)) - # Note export_outputs is allowed to be empty. - if len(export_outputs) == 1: - (key, value), = export_outputs.items() - if key != signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: - export_outputs[ - signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = value - if len(export_outputs) > 1: - if (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - not in export_outputs): - raise ValueError( - 'Multiple export_outputs were provided, but none of them is ' - 'specified as the default. Do this by naming one of them with ' - 'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.') + # Validate the passed export outputs, or generate defaults. + if mode == ModeKeys.PREDICT: + export_outputs = _get_export_outputs(export_outputs, predictions) # Validate that all tensors and ops are from the default graph. default_graph = ops.get_default_graph() @@ -286,11 +268,11 @@ class EstimatorSpec( raise ValueError(error_message_template.format('train_op', train_op.name)) for key, value in list(six.iteritems(eval_metric_ops)): values = nest.flatten(value) - for value in values: - if value.graph is not default_graph: + for val in values: + if val.graph is not default_graph: raise ValueError(error_message_template.format( 'eval_metric_ops', - '{0}: {1}'.format(key, value.name))) + '{0}: {1}'.format(key, val.name))) # Validate hooks. training_chief_hooks = tuple(training_chief_hooks or []) @@ -334,15 +316,76 @@ class EstimatorSpec( return EstimatorSpec(*new_fields) -class _TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [ - 'mode', - 'predictions', - 'loss', - 'train_op', - 'eval_metrics', - 'export_outputs', - 'scaffold_fn', - 'host_call'])): +def _get_export_outputs(export_outputs, predictions): + """Validate export_outputs or create default export_outputs. + + Args: + export_outputs: Describes the output signatures to be exported to + `SavedModel` and used during serving. Should be a dict or None. + predictions: Predictions `Tensor` or dict of `Tensor`. + + Returns: + Valid export_outputs dict + + Raises: + TypeError: if export_outputs is not a dict or its values are not + ExportOutput instances. + """ + if export_outputs is None: + default_output = export_output_lib.PredictOutput(predictions) + export_outputs = { + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: default_output} + + if not isinstance(export_outputs, dict): + raise TypeError('export_outputs must be dict, given: {}'.format( + export_outputs)) + for v in six.itervalues(export_outputs): + if not isinstance(v, export_output_lib.ExportOutput): + raise TypeError( + 'Values in export_outputs must be ExportOutput objects. ' + 'Given: {}'.format(export_outputs)) + + _maybe_add_default_serving_output(export_outputs) + + return export_outputs + + +def _maybe_add_default_serving_output(export_outputs): + """Add a default serving output to the export_outputs if not present. + + Args: + export_outputs: Describes the output signatures to be exported to + `SavedModel` and used during serving. Should be a dict. + + Returns: + export_outputs dict with default serving signature added if necessary + + Raises: + ValueError: if multiple export_outputs were provided without a default + serving key. + """ + if len(export_outputs) == 1: + (key, value), = export_outputs.items() + if key != signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + export_outputs[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = value + if len(export_outputs) > 1: + if (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + not in export_outputs): + raise ValueError( + 'Multiple export_outputs were provided, but none of them is ' + 'specified as the default. Do this by naming one of them with ' + 'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.') + + return export_outputs + + +class _TPUEstimatorSpec( + collections.namedtuple('TPUEstimatorSpec', [ + 'mode', 'predictions', 'loss', 'train_op', 'eval_metrics', + 'export_outputs', 'scaffold_fn', 'host_call', 'training_hooks', + 'evaluation_hooks', 'prediction_hooks' + ])): """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`. This is a simplified implementation of `tf.contrib.tpu.EstimatorSpec`. See @@ -358,17 +401,24 @@ class _TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [ eval_metrics=None, export_outputs=None, scaffold_fn=None, - host_call=None): + host_call=None, + training_hooks=None, + evaluation_hooks=None, + prediction_hooks=None): """Creates a `_TPUEstimatorSpec` instance.""" - return super(_TPUEstimatorSpec, cls).__new__(cls, - mode=mode, - predictions=predictions, - loss=loss, - train_op=train_op, - eval_metrics=eval_metrics, - export_outputs=export_outputs, - scaffold_fn=scaffold_fn, - host_call=host_call) + return super(_TPUEstimatorSpec, cls).__new__( + cls, + mode=mode, + predictions=predictions, + loss=loss, + train_op=train_op, + eval_metrics=eval_metrics, + export_outputs=export_outputs, + scaffold_fn=scaffold_fn, + host_call=host_call, + training_hooks=training_hooks, + evaluation_hooks=evaluation_hooks, + prediction_hooks=prediction_hooks) def as_estimator_spec(self): """Creates an equivalent `EstimatorSpec` used by CPU train/eval.""" @@ -377,12 +427,16 @@ class _TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [ else: metric_fn, tensors = self.eval_metrics eval_metric_ops = metric_fn(**tensors) - return EstimatorSpec(mode=self.mode, - predictions=self.predictions, - loss=self.loss, - train_op=self.train_op, - eval_metric_ops=eval_metric_ops, - export_outputs=self.export_outputs) + return EstimatorSpec( + mode=self.mode, + predictions=self.predictions, + loss=self.loss, + train_op=self.train_op, + eval_metric_ops=eval_metric_ops, + export_outputs=self.export_outputs, + training_hooks=self.training_hooks, + evaluation_hooks=self.evaluation_hooks, + prediction_hooks=self.prediction_hooks) def _check_is_tensor_or_operation(x, name): diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py index b7eeeb437cb4a624cdee552be3032364b18a8290..08e41fd4146e9254fc8cc7da6bc809e80d053a5b 100644 --- a/tensorflow/python/estimator/model_fn_test.py +++ b/tensorflow/python/estimator/model_fn_test.py @@ -592,6 +592,27 @@ class EstimatorSpecInferTest(test.TestCase): predictions=predictions, export_outputs=export_outputs) + def testDefaultExportOutputCreated(self): + """Ensure that a default PredictOutput is created for export.""" + with ops.Graph().as_default(), self.test_session(): + predictions = constant_op.constant(1.) + self._assertDefaultExportOutputForPredictions(predictions) + + def testDefaultExportOutputCreatedDict(self): + """Ensure that a default PredictOutput is created for export for dicts.""" + with ops.Graph().as_default(), self.test_session(): + predictions = {'loss': constant_op.constant(1.), + 'score': constant_op.constant(10.)} + self._assertDefaultExportOutputForPredictions(predictions) + + def _assertDefaultExportOutputForPredictions(self, predictions): + spec = model_fn.EstimatorSpec( + mode=model_fn.ModeKeys.PREDICT, predictions=predictions) + + expected = export_output.PredictOutput(predictions).outputs + serving_output = spec.export_outputs[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + self.assertEqual(serving_output.outputs, expected) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py index b948ce96e0c09c0537619366403658408cf17895..220c3e58ca5c2c268bf8d4bb889793c5d9151c95 100644 --- a/tensorflow/python/estimator/run_config.py +++ b/tensorflow/python/estimator/run_config.py @@ -25,6 +25,7 @@ import os import six from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib from tensorflow.python.util import compat_internal @@ -47,7 +48,9 @@ _DEFAULT_REPLACEABLE_LIST = [ 'keep_checkpoint_every_n_hours', 'log_step_count_steps', 'train_distribute', - 'device_fn' + 'device_fn', + 'protocol', + 'eval_distribute', ] _SAVE_CKPT_ERR = ( @@ -287,6 +290,21 @@ def _validate_properties(run_config): message='device_fn must be callable with exactly' ' one argument "op".') + _validate('protocol', + lambda protocol: protocol in (None, "grpc", "grpc+verbs"), + message='protocol should be grpc or grpc+verbs') + + +def get_default_session_config(): + """Returns tf.ConfigProto instance.""" + + rewrite_opts = rewriter_config_pb2.RewriterConfig( + meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE) + graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts) + + return config_pb2.ConfigProto(allow_soft_placement=True, + graph_options=graph_opts) + class TaskType(object): MASTER = 'master' @@ -311,7 +329,9 @@ class RunConfig(object): keep_checkpoint_every_n_hours=10000, log_step_count_steps=100, train_distribute=None, - device_fn=None): + device_fn=None, + protocol=None, + eval_distribute=None): """Constructs a RunConfig. All distributed training related properties `cluster_spec`, `is_chief`, @@ -435,7 +455,7 @@ class RunConfig(object): the feature. log_step_count_steps: The frequency, in number of global steps, that the global step/sec and the loss will be logged during training. - train_distribute: an optional instance of + train_distribute: An optional instance of `tf.contrib.distribute.DistributionStrategy`. If specified, then Estimator will distribute the user's model during training, according to the policy specified by that strategy. @@ -443,6 +463,12 @@ class RunConfig(object): `Operation` and returns the device string. If `None`, defaults to the device function returned by `tf.train.replica_device_setter` with round-robin strategy. + protocol: An optional argument which specifies the protocol used when + starting server. None means default to grpc. + eval_distribute: An optional instance of + `tf.contrib.distribute.DistributionStrategy`. If specified, + then Estimator will distribute the user's model during evaluation, + according to the policy specified by that strategy. Raises: ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs` @@ -480,10 +506,58 @@ class RunConfig(object): keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, log_step_count_steps=log_step_count_steps, train_distribute=train_distribute, - device_fn=device_fn) + device_fn=device_fn, + protocol=protocol, + eval_distribute=eval_distribute) self._init_distributed_setting_from_environment_var(tf_config) + self._maybe_overwrite_session_config_for_distributed_training() + + def _maybe_overwrite_session_config_for_distributed_training(self): + """Overwrites the session_config for distributed training. + + The default overwrite is optimized for between-graph training. Subclass + should override this method if necessary. + """ + # Get session_config only for between-graph distributed mode (cluster_spec + # is present). + if not self._session_config and self._cluster_spec: + RunConfig._replace( + self, + allowed_properties_list=_DEFAULT_REPLACEABLE_LIST, + session_config=self._get_default_session_config_distributed()) + + def _get_default_session_config_distributed(self): + """Returns None or tf.ConfigProto instance with default device_filters set. + + Device filters are set such that chief/master and worker communicates with + only ps. session_config=None for evaluators or any other TaskType. + """ + + rewrite_opts = rewriter_config_pb2.RewriterConfig( + meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE) + graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts) + + device_filters = None + if self._task_type == TaskType.MASTER: + device_filters = ['/job:ps', '/job:master'] + elif self._task_type == TaskType.CHIEF: + device_filters = ['/job:ps', '/job:chief'] + elif self._task_type == TaskType.WORKER: + device_filters = ['/job:ps', '/job:worker/task:%d' % self._task_id] + elif self._task_type == TaskType.PS: + device_filters = ['/job:ps', '/job:worker', '/job:master'] + else: + # If the task_type is `EVALUATOR` or something other than the ones in + # TaskType then don't set any device filters. + return None + + return config_pb2.ConfigProto( + allow_soft_placement=True, + graph_options=graph_opts, + device_filters=device_filters) + def _init_distributed_setting_from_environment_var(self, tf_config): """Initialize distributed properties based on `tf_config`.""" @@ -703,10 +777,21 @@ class RunConfig(object): @property def train_distribute(self): - """Returns the optional `tf.contrib.distribute.DistributionStrategy` object. + """Optional `tf.contrib.distribute.DistributionStrategy` for training. """ return self._train_distribute + @property + def eval_distribute(self): + """Optional `tf.contrib.distribute.DistributionStrategy` for evaluation. + """ + return self._eval_distribute + + @property + def protocol(self): + """Returns the optional protocol value.""" + return self._protocol + def replace(self, **kwargs): """Returns a new instance of `RunConfig` replacing specified properties. @@ -722,7 +807,9 @@ class RunConfig(object): - `keep_checkpoint_every_n_hours`, - `log_step_count_steps`, - `train_distribute`, - - `device_fn`. + - `device_fn`, + - `protocol`. + - `eval_distribute`, In addition, either `save_checkpoints_steps` or `save_checkpoints_secs` can be set (should not be both). diff --git a/tensorflow/python/estimator/run_config_test.py b/tensorflow/python/estimator/run_config_test.py index c8b12605e1aaad11e114e4ace63697b93f3b2b92..06df7cb9dd4ae3d167d622601e551079b64e80a2 100644 --- a/tensorflow/python/estimator/run_config_test.py +++ b/tensorflow/python/estimator/run_config_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import json from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.platform import test @@ -290,6 +291,7 @@ class RunConfigDistributedSettingTest(test.TestCase): expected_num_worker_replicas=1, expected_num_ps_replicas=0) self.assertEqual(0, run_config.global_id_in_cluster) + self.assertIsNone(run_config.session_config, None) def test_session_master_for_local(self): tf_config = {'session_master': '_my_master'} @@ -1119,5 +1121,115 @@ class RunConfigModelDirTest(test.TestCase): _create_run_config_with_cluster_spec(tf_config) +class RunConfigSessionConfigTest(test.TestCase): + + def _assert_equal_session_config(self, session_config, + expected_device_filters): + + rewrite_opts = rewriter_config_pb2.RewriterConfig( + meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE) + graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts) + expected_session_config = config_pb2.ConfigProto( + allow_soft_placement=True, + graph_options=graph_opts, + device_filters=expected_device_filters) + self.assertEqual(session_config, expected_session_config) + + def test_master_session_config(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.MASTER: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.MASTER, + 'index': 0 + } + } + run_config = _create_run_config_with_cluster_spec(tf_config) + self._assert_equal_session_config(run_config.session_config, + ['/job:ps', '/job:master']) + + def test_chief_session_config(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + run_config = _create_run_config_with_cluster_spec(tf_config) + self._assert_equal_session_config(run_config.session_config, + ['/job:ps', '/job:chief']) + + def test_worker_session_config(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.MASTER: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.WORKER, + 'index': 1 + } + } + run_config = _create_run_config_with_cluster_spec(tf_config) + self._assert_equal_session_config(run_config.session_config, + ['/job:ps', '/job:worker/task:1']) + + def test_ps_session_config(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.MASTER: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.PS, + 'index': 1 + } + } + run_config = _create_run_config_with_cluster_spec(tf_config) + self._assert_equal_session_config(run_config.session_config, + ['/job:ps', '/job:worker', '/job:master']) + + def test_evaluator_session_config(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.EVALUATOR, + 'index': 0 + } + } + run_config = _create_run_config_with_cluster_spec(tf_config) + self.assertIsNone(run_config.session_config) + + def test_other_type_session_config(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.MASTER: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], + 'other_type': ['host3:1', 'host4:2'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': 'other_type', + 'index': 0 + } + } + run_config = _create_run_config_with_cluster_spec(tf_config) + self.assertIsNone(run_config.session_config) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index 1572af579b964e8cf5cdb3d5d11a56d80b965b5c..bb1305767f8d8d565943bf1aa6f5e1784463d58e 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -278,10 +278,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec): supported distributed training configuration is between-graph replication. Overfitting: In order to avoid overfitting, it is recommended to set up the - training `input_fn` to shuffle the training data properly. It is also - recommended to train the model a little longer, say multiple epochs, before - performing evaluation, as the input pipeline starts from scratch for each - training. It is particularly important for local training and evaluation. + training `input_fn` to shuffle the training data properly. Stop condition: In order to support both distributed and non-distributed configuration reliably, the only supported stop condition for model @@ -315,10 +312,10 @@ def train_and_evaluate(estimator, train_spec, eval_spec): # hidden_units=[1024, 512, 256]) # Input pipeline for train and evaluate. - def train_input_fn: # returns x, y + def train_input_fn(): # returns x, y # please shuffle the data. pass - def eval_input_fn_eval: # returns x, y + def eval_input_fn(): # returns x, y pass train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000) @@ -326,6 +323,10 @@ def train_and_evaluate(estimator, train_spec, eval_spec): tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) ``` + Note that in current implementation `estimator.evaluate` will be called + multiple times. This means that evaluation graph (including eval_input_fn) + will be re-created for each `evaluate` call. `estimator.train` will be called + only once. Example of distributed training: @@ -470,6 +471,61 @@ class _StopAtSecsHook(session_run_hook.SessionRunHook): run_context.request_stop() +class _NewCheckpointListenerForEvaluate( + basic_session_run_hooks.CheckpointSaverListener): + """A saver listener to run evaluate with every checkpoint.""" + + def __init__(self, evaluator, eval_throttle_secs, continuous_eval_listener): + self._evaluator = evaluator + self._eval_throttle_secs = eval_throttle_secs + self._continuous_eval_listener = continuous_eval_listener + self.eval_result, self.export_results = None, None + + def begin(self): + self._timer = basic_session_run_hooks.SecondOrStepTimer( + every_secs=self._eval_throttle_secs) + self._is_first_run = True + + def after_save(self, session, global_step_value): + del session # unused; required by signature. + # skip first run model is not trained yet. + if self._is_first_run: + self._is_first_run = False + return + + if not self._continuous_eval_listener.before_eval(): + logging.info('Exiting training and evaluation loop, as requested by ' + '_ContinuousEvalListener.before_eval.') + return True + if self._timer.should_trigger_for_step(global_step_value): + self._evaluate(global_step_value) # updates self.eval_result + if not self._continuous_eval_listener.after_eval(self.eval_result): + logging.info('Exiting evaluation, as requested by ' + '_ContinuousEvalListener.after_eval.') + return True + else: + # TODO(ispir): add remaining time in the log. + logging.info('Skip the current checkpoint eval due to throttle secs ' + '({} secs).'.format(self._eval_throttle_secs)) + + def end(self, session, global_step_value): + # Evaluate if the last step has not been evaluated, yet. + if global_step_value != self._timer.last_triggered_step(): + if self._continuous_eval_listener.before_eval(): + self._evaluate(global_step_value) + self._continuous_eval_listener.after_eval(self.eval_result) + + def _evaluate(self, global_step_value): + self._timer.update_last_triggered_step(global_step_value) + self.eval_result, self.export_results = ( + self._evaluator.evaluate_and_export()) + if self.eval_result.status != _EvalStatus.EVALUATED: + # This is unexpected; should never happen. + # Training should always end with a new checkpoint. + raise RuntimeError('There was no new checkpoint after the training. ' + 'Eval status: {}'.format(self.eval_result.status)) + + class _TrainingExecutor(object): """The executor to run `Estimator` training and evaluation. @@ -576,28 +632,6 @@ class _TrainingExecutor(object): def run_master(self): """Runs task master.""" - - class NewCheckpointListener( - basic_session_run_hooks.CheckpointSaverListener): - - def __init__(self, evaluator, eval_throttle_secs): - self._evaluator = evaluator - self._eval_throttle_secs = eval_throttle_secs - - def begin(self): - self._timer = basic_session_run_hooks.SecondOrStepTimer( - every_secs=self._eval_throttle_secs) - - def after_save(self, session, global_step_value): - del session # unused; required by signature. - - if self._timer.should_trigger_for_step(global_step_value): - self._timer.update_last_triggered_step(global_step_value) - self._evaluator.evaluate_and_export() - else: - logging.info('Skip the current checkpoint eval due to throttle secs ' - '({} secs).'.format(self._eval_throttle_secs)) - _assert_eval_spec(self._eval_spec) # Final export signal: For any eval result with global_step >= train @@ -617,16 +651,12 @@ class _TrainingExecutor(object): # When the underlying `Estimator` object saves a new checkpoint, we would # like this callback to be called so that evaluation and export can trigger. saving_listeners = [ - NewCheckpointListener(evaluator, self._eval_spec.throttle_secs) + _NewCheckpointListenerForEvaluate(evaluator, + self._eval_spec.throttle_secs, + _ContinuousEvalListener()) ] self._start_distributed_training(saving_listeners=saving_listeners) - if not evaluator.is_final_export_triggered: - logging.info('Training has already ended. But the last eval is skipped ' - 'due to eval throttle_secs. Now evaluating the final ' - 'checkpoint.') - evaluator.evaluate_and_export() - def run_evaluator(self): """Runs task evaluator.""" # TODO(xiejw): To allow execution framework to add continuous eval listener. @@ -640,68 +670,33 @@ class _TrainingExecutor(object): def run_local(self): """Runs training and evaluation locally (non-distributed).""" - - def _should_stop_local_train(global_step): - if self._train_spec.max_steps is None: - return False - if global_step >= self._train_spec.max_steps: - return True - return False - _assert_eval_spec(self._eval_spec) - if self._eval_spec.throttle_secs <= 0: - raise ValueError('eval_spec.throttle_secs should be positive, given: {}.' - 'It is used do determine how long each training ' - 'iteration should go when train and evaluate ' - 'locally.'.format(self._eval_spec.throttle_secs)) - - stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs) - train_hooks = ( - list(self._train_spec.hooks) + [stop_hook] + list(self._train_hooks)) + train_hooks = list(self._train_spec.hooks) + list(self._train_hooks) logging.info('Start train and evaluate loop. The evaluate will happen ' - 'after {} secs (eval_spec.throttle_secs) or training is ' - 'finished.'.format(self._eval_spec.throttle_secs)) + 'after every checkpoint. Checkpoint frequency is determined ' + 'based on RunConfig arguments: save_checkpoints_steps {} or ' + 'save_checkpoints_secs {}.'.format( + self._estimator.config.save_checkpoints_steps, + self._estimator.config.save_checkpoints_secs)) evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec, self._train_spec.max_steps) - eval_result = _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT) - export_results = [] - - while True: - self._estimator.train( - input_fn=self._train_spec.input_fn, - max_steps=self._train_spec.max_steps, - hooks=train_hooks) - - if not self._continuous_eval_listener.before_eval(): - logging.info('Exiting training and evaluation loop, as requested by ' - '_ContinuousEvalListener.before_eval.') - break - - # Final export signal: For any eval result with global_step >= train - # max_steps, the evaluator will send the final export signal. The - # _should_stop_local_train will then end the while True as the stopping - # condition is satisfied (both checks use the same global_step value, - # i.e., no race condition) - eval_result, export_results = evaluator.evaluate_and_export() - - if eval_result.status != _EvalStatus.EVALUATED: - # This is unexpected; should never happen. - # Training should always end with a new checkpoint. - raise RuntimeError('There was no new checkpoint after the training. ' - 'Eval status: {}'.format(eval_result.status)) - - if not self._continuous_eval_listener.after_eval(eval_result): - logging.info('Exiting evaluation, as requested by ' - '_ContinuousEvalListener.after_eval.') - break + listener_for_eval = _NewCheckpointListenerForEvaluate( + evaluator, self._eval_spec.throttle_secs, + self._continuous_eval_listener) + saving_listeners = [listener_for_eval] + + self._estimator.train( + input_fn=self._train_spec.input_fn, + max_steps=self._train_spec.max_steps, + hooks=train_hooks, + saving_listeners=saving_listeners) - if _should_stop_local_train( - eval_result.metrics[ops.GraphKeys.GLOBAL_STEP]): - break - return eval_result.metrics, export_results + eval_result = listener_for_eval.eval_result or _EvalResult( + status=_EvalStatus.MISSING_CHECKPOINT) + return eval_result.metrics, listener_for_eval.export_results def _start_std_server(self, config): """Creates, starts, and returns a server_lib.Server.""" @@ -741,7 +736,8 @@ class _TrainingExecutor(object): job_name=config.task_type, task_index=config.task_id, config=session_config, - start=False) + start=False, + protocol=config.protocol) server.start() return server diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py index 2c838db7a4de98d941752ce9d5ddf8f2b47a46f1..dc106c7d3baf561a203341a2063c1a9b86fa2b5b 100644 --- a/tensorflow/python/estimator/training_test.py +++ b/tensorflow/python/estimator/training_test.py @@ -29,17 +29,21 @@ import time import numpy as np +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import exporter as exporter_lib +from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.estimator import training from tensorflow.python.estimator.canned import dnn from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export as export_lib -from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import state_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging @@ -49,6 +53,7 @@ from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import monitored_session from tensorflow.python.training import server_lib from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training_util from tensorflow.python.util import compat _DEFAULT_EVAL_STEPS = 100 @@ -467,6 +472,7 @@ class _TrainingExecutorTrainingTest(object): job_name=mock_est.config.task_type, task_index=mock_est.config.task_id, config=test.mock.ANY, + protocol=None, start=False) self.assertTrue(mock_server_instance.start.called) @@ -497,6 +503,7 @@ class _TrainingExecutorTrainingTest(object): job_name=mock_est.config.task_type, task_index=mock_est.config.task_id, config=test.mock.ANY, + protocol=None, start=False) self.assertTrue(mock_server_instance.start.called) @@ -724,6 +731,7 @@ class TrainingExecutorRunMasterTest(test.TestCase): job_name=mock_est.config.task_type, task_index=mock_est.config.task_id, config=test.mock.ANY, + protocol=None, start=False) self.assertTrue(mock_server_instance.start.called) @@ -885,7 +893,8 @@ class TrainingExecutorRunMasterTest(test.TestCase): # `after_save`. del args, kwargs saving_listeners[0].begin() - saving_listeners[0].after_save(session=None, global_step_value=None) + saving_listeners[0].after_save(session=None, global_step_value=0) + saving_listeners[0].after_save(session=None, global_step_value=10) mock_est = test.mock.Mock( spec=estimator_lib.Estimator, model_dir='path/', train=estimator_train) @@ -930,7 +939,10 @@ class TrainingExecutorRunMasterTest(test.TestCase): del args, kwargs saving_listeners[0].begin() - # Call three times. + # Call four times. + mock_timer.should_trigger_for_step.return_value = True + saving_listeners[0].after_save(session=None, global_step_value=None) + mock_timer.should_trigger_for_step.return_value = True saving_listeners[0].after_save(session=None, global_step_value=None) @@ -979,14 +991,19 @@ class TrainingExecutorRunMasterTest(test.TestCase): del args, kwargs saving_listeners[0].begin() - # Call two times. + # Call tree times (one for first saving). mock_timer.should_trigger_for_step.return_value = True - saving_listeners[0].after_save(session=None, global_step_value=None) + saving_listeners[0].after_save(session=None, global_step_value=0) + + mock_timer.should_trigger_for_step.return_value = True + saving_listeners[0].after_save(session=None, global_step_value=125) - # The final ckpt is skipped by the timer. It will be picked up the final - # export check in the code. mock_timer.should_trigger_for_step.return_value = False - saving_listeners[0].after_save(session=None, global_step_value=None) + saving_listeners[0].after_save(session=None, global_step_value=250) + + # At the end evaluate should be called even if throttle secs prevents it. + mock_timer.should_trigger_for_step.return_value = False + saving_listeners[0].end(session=None, global_step_value=300) mock_est.train = estimator_train mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2'] @@ -1467,6 +1484,7 @@ class TrainingExecutorRunPsTest(test.TestCase): job_name=mock_est.config.task_type, task_index=mock_est.config.task_id, config=test.mock.ANY, + protocol=None, start=False) self.assertTrue(mock_server_instance.start.called) @@ -1566,28 +1584,31 @@ class StopAtSecsHookTest(test.TestCase): class TrainingExecutorRunLocalTest(test.TestCase): """Tests run_local of _TrainingExecutor.""" + def _model_fn(self, features, labels, mode): + del labels + with ops.control_dependencies([features]): + train_op = state_ops.assign_add(training_util.get_global_step(), 1) + return model_fn_lib.EstimatorSpec( + mode, + loss=constant_op.constant(0.), + train_op=train_op, + predictions=constant_op.constant([[10.]]), + eval_metric_ops={'mean_of_features': metrics_lib.mean(features)}) + + def _input_fn(self, repeat=True): + ds = dataset_ops.Dataset.from_tensors([1]) + if repeat: + return ds.repeat() + return ds + def unique_checkpoint_every_time_fn(self): return 'checkpoint_path_%s/' % random.random() - def test_send_stop_at_secs_to_train(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn - train_spec = training.TrainSpec( - input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()]) - eval_spec = training.EvalSpec( - input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100) - mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} - - executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) - executor.run_local() - - stop_hook = mock_est.train.call_args[1]['hooks'][-1] - self.assertIsInstance(stop_hook, training._StopAtSecsHook) - self.assertEqual(eval_spec.throttle_secs, stop_hook._stop_after_secs) - - def test_runs_in_a_loop_until_max_steps(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn + def test_runs_evaluate_with_every_new_checkpoint(self): + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=10)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) mock_est.times_export_was_called = 0 mock_est.times_final_export_was_called = 0 @@ -1604,42 +1625,30 @@ class TrainingExecutorRunLocalTest(test.TestCase): exporter.name = 'see_how_many_times_export_is_called' exporter.export = export - train_spec = training.TrainSpec( - input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) + train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=22) eval_spec = training.EvalSpec( - input_fn=lambda: 1, - hooks=[_FakeHook()], - throttle_secs=100, + input_fn=lambda: self._input_fn(repeat=False), + throttle_secs=0, exporters=exporter) - # should be called 3 times. - mock_est.evaluate.side_effect = [{ - _GLOBAL_STEP_KEY: train_spec.max_steps - 100 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - 50 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - }] executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) executor.run_local() - self.assertEqual(3, mock_est.train.call_count) + self.assertEqual(1, mock_est.train.call_count) self.assertEqual(3, mock_est.evaluate.call_count) self.assertEqual(3, mock_est.times_export_was_called) self.assertEqual(1, mock_est.times_final_export_was_called) def test_runs_with_eval_listener_before_eval(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=10)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn - train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300) - eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=100) - # should be called 2 times without the evallistener - mock_est.evaluate.side_effect = [{ - _GLOBAL_STEP_KEY: train_spec.max_steps - 50 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - }] + train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=12) + eval_spec = training.EvalSpec(input_fn=lambda: self._input_fn(repeat=False)) + mock_est.evaluate.side_effect = [{_GLOBAL_STEP_KEY: train_spec.max_steps}] class _Listener(training._ContinuousEvalListener): @@ -1658,67 +1667,61 @@ class TrainingExecutorRunLocalTest(test.TestCase): self.assertEqual(1, mock_est.train.call_count) self.assertEqual(0, mock_est.evaluate.call_count) - self.assertEqual(1, listener.call_count) def test_runs_with_eval_listener_after_eval(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=10)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) - train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300) - eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=100) - # should be called 2 times without the evallistener - mock_est.evaluate.side_effect = [{ - _GLOBAL_STEP_KEY: train_spec.max_steps - 50 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - }] + train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=3000) + eval_spec = training.EvalSpec( + input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0) class _Listener(training._ContinuousEvalListener): - def __init__(self, test_case): + def __init__(self): self.call_count = 0 - self._test_case = test_case def after_eval(self, eval_result): self.call_count += 1 - self._test_case.assertEqual( - train_spec.max_steps - 50, eval_result.metrics[_GLOBAL_STEP_KEY]) return False # Will stop the run_local after first eval. - listener = _Listener(test_case=self) + listener = _Listener() executor = training._TrainingExecutor( mock_est, train_spec, eval_spec, continuous_eval_listener=listener) - executor.run_local() + metrics, _ = executor.run_local() # pylint: disable=assignment-from-no-return self.assertEqual(1, mock_est.train.call_count) self.assertEqual(1, mock_est.evaluate.call_count) self.assertEqual(1, listener.call_count) + # Should be less than max_steps since listener did early stopping. + self.assertLess(metrics[_GLOBAL_STEP_KEY], train_spec.max_steps) def test_handles_no_new_checkpoint_found(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint.return_value = ( - 'no_new_checkpoints_after_the_first_train_step') + est = estimator_lib.Estimator( + model_fn=self._model_fn, + # disable saving checkpoint + config=run_config_lib.RunConfig( + save_checkpoints_steps=None, save_checkpoints_secs=None)) train_spec = training.TrainSpec( - input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) + input_fn=self._input_fn, max_steps=300, hooks=[_FakeHook()]) eval_spec = training.EvalSpec( - input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100) - # It was going to be called 3 times. - mock_est.evaluate.side_effect = [{ - _GLOBAL_STEP_KEY: train_spec.max_steps - 100 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - 50 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - }] + input_fn=lambda: self._input_fn(repeat=False), + hooks=[_FakeHook()], + throttle_secs=100) - executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) - with self.assertRaisesRegexp(RuntimeError, _STALE_CHECKPOINT_MSG): + executor = training._TrainingExecutor(est, train_spec, eval_spec) + with self.assertRaisesRegexp(ValueError, + 'There should be a CheckpointSaverHook'): executor.run_local() def test_final_export_is_true_in_the_end(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=10)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) mock_est.times_export_fn_was_called = 0 mock_est.times_the_final_export_was_true = 0 @@ -1734,37 +1737,29 @@ class TrainingExecutorRunLocalTest(test.TestCase): exporter.export = export train_spec = training.TrainSpec( - input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) + input_fn=self._input_fn, max_steps=12, hooks=[_FakeHook()]) eval_spec = training.EvalSpec( - input_fn=lambda: 1, - hooks=[_FakeHook()], - throttle_secs=100, + input_fn=lambda: self._input_fn(repeat=False), + throttle_secs=0, exporters=exporter) - # should be called 3 times. - mock_est.evaluate.side_effect = [{ - _GLOBAL_STEP_KEY: train_spec.max_steps - 100 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - 50 - }, { - _GLOBAL_STEP_KEY: train_spec.max_steps - }] - executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) executor.run_local() - self.assertEqual(3, mock_est.train.call_count) - self.assertEqual(3, mock_est.evaluate.call_count) - self.assertEqual(3, mock_est.times_export_fn_was_called) + self.assertEqual(1, mock_est.train.call_count) + self.assertEqual(2, mock_est.evaluate.call_count) + self.assertEqual(2, mock_est.times_export_fn_was_called) self.assertEqual(1, mock_est.times_the_final_export_was_true) def test_train_and_evaluate_args(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint.return_value = 'checkpoint_path/' + est = estimator_lib.Estimator(model_fn=self._model_fn) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) train_spec = training.TrainSpec( - input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) + input_fn=self._input_fn, max_steps=300, hooks=[_FakeHook()]) eval_spec = training.EvalSpec( - input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='local_eval') - mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} + input_fn=lambda: self._input_fn(repeat=False), + steps=2, + hooks=[_FakeHook()], + name='local_eval') executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) executor.run_local() @@ -1773,11 +1768,11 @@ class TrainingExecutorRunLocalTest(test.TestCase): name=eval_spec.name, input_fn=eval_spec.input_fn, steps=eval_spec.steps, - checkpoint_path='checkpoint_path/', + checkpoint_path=est.latest_checkpoint(), hooks=eval_spec.hooks) train_args = mock_est.train.call_args[1] - self.assertEqual(list(train_spec.hooks), list(train_args['hooks'][:-1])) + self.assertEqual(list(train_spec.hooks), list(train_args['hooks'])) self.assertEqual(train_spec.input_fn, train_args['input_fn']) self.assertEqual(train_spec.max_steps, train_args['max_steps']) @@ -1812,25 +1807,11 @@ class TrainingExecutorRunLocalTest(test.TestCase): if not isinstance(h, training._StopAtSecsHook) ]) - def test_errors_out_if_throttle_secs_is_zero(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - train_spec = training.TrainSpec(input_fn=lambda: 1) - eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=0) - - executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) - with self.assertRaisesRegexp(ValueError, 'throttle_secs'): - executor.run_local() - def test_that_export_is_called_with_run_local(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - mock_train_spec = test.mock.Mock(spec=training.TrainSpec) - mock_train_spec.max_steps = 200 - mock_est.evaluate.return_value = { - _GLOBAL_STEP_KEY: mock_train_spec.max_steps - } - # _validate_hooks would have made sure that train_spec.hooks is [], when - # None were passed. - mock_train_spec.hooks = [] + est = estimator_lib.Estimator(model_fn=self._model_fn) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) + train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=12) + mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} def export(estimator, *args, **kwargs): del args, kwargs @@ -1842,13 +1823,13 @@ class TrainingExecutorRunLocalTest(test.TestCase): exporter.export = export eval_spec = training.EvalSpec( - input_fn=lambda: 1, + input_fn=lambda: self._input_fn(repeat=False), steps=2, start_delay_secs=0, throttle_secs=213, exporters=exporter) - executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) + executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) # pylint: disable=assignment-from-no-return _, export_results = executor.run_local() # pylint: enable=assignment-from-no-return @@ -1857,9 +1838,13 @@ class TrainingExecutorRunLocalTest(test.TestCase): self.assertEqual(export_results, ['path_to_export']) def test_errors_out_if_evaluate_returns_empty_dict(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - train_spec = training.TrainSpec(input_fn=lambda: 1) - eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123) + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=2)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) + train_spec = training.TrainSpec(input_fn=self._input_fn) + eval_spec = training.EvalSpec( + input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0) mock_est.evaluate.return_value = {} executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) @@ -1867,18 +1852,26 @@ class TrainingExecutorRunLocalTest(test.TestCase): executor.run_local() def test_errors_out_if_evaluate_returns_non_dict(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - train_spec = training.TrainSpec(input_fn=lambda: 1) - eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123) + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=2)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) + train_spec = training.TrainSpec(input_fn=self._input_fn) + eval_spec = training.EvalSpec( + input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0) mock_est.evaluate.return_value = 123 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR): executor.run_local() def test_errors_out_if_evaluate_returns_dict_without_global_step(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator) - train_spec = training.TrainSpec(input_fn=lambda: 1) - eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123) + est = estimator_lib.Estimator( + model_fn=self._model_fn, + config=run_config_lib.RunConfig(save_checkpoints_steps=2)) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) + train_spec = training.TrainSpec(input_fn=self._input_fn) + eval_spec = training.EvalSpec( + input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0) mock_est.evaluate.return_value = {'loss': 123} executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) @@ -1887,19 +1880,21 @@ class TrainingExecutorRunLocalTest(test.TestCase): executor.run_local() def test_train_and_evaluate_return_metrics(self): - mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') - mock_est.latest_checkpoint.return_value = 'checkpoint_path/' + est = estimator_lib.Estimator(model_fn=self._model_fn) + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est) train_spec = training.TrainSpec( - input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) + input_fn=self._input_fn, max_steps=12, hooks=[_FakeHook()]) eval_spec = training.EvalSpec( - input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='local_eval') - mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} + input_fn=lambda: self._input_fn(repeat=False), + steps=2, + hooks=[_FakeHook()], + name='local_eval') executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) # pylint: disable=assignment-from-no-return metrics, _ = executor.run_local() # pylint: enable=assignment-from-no-return - self.assertEqual(metrics['global_step'], 300) + self.assertEqual(metrics['global_step'], 12) class TrainAndEvaluateRunTest(test.TestCase): @@ -2096,7 +2091,7 @@ class TrainAndEvaluateIntegrationTest(test.TestCase): # max_steps should be larger than save_summary_steps max_steps = 10 - save_summary_steps = 2 + save_summary_steps = 9 data = np.linspace( 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32) @@ -2104,24 +2099,20 @@ class TrainAndEvaluateIntegrationTest(test.TestCase): y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1)) # learn y = x - train_input_fn = numpy_io.numpy_input_fn( - x={'x': x_data}, - y=y_data, - batch_size=batch_size, - num_epochs=None, - shuffle=True) - - eval_input_fn = numpy_io.numpy_input_fn( - x={'x': x_data}, - y=y_data, - batch_size=batch_size, - num_epochs=1, - shuffle=False) - - predict_input_fn = numpy_io.numpy_input_fn( - x={'x': x_data}, - batch_size=batch_size, - shuffle=False) + def train_input_fn(): + return dataset_ops.Dataset.from_tensor_slices(({ + 'x': x_data + }, y_data)).batch(batch_size).repeat().shuffle(1000) + + def eval_input_fn(): + return dataset_ops.Dataset.from_tensor_slices(({ + 'x': x_data + }, y_data)).batch(batch_size) + + def predict_input_fn(): + return dataset_ops.Dataset.from_tensor_slices({ + 'x': x_data + }).batch(batch_size) feature_columns = [ feature_column.numeric_column('x', shape=(input_dimension,))] @@ -2137,9 +2128,11 @@ class TrainAndEvaluateIntegrationTest(test.TestCase): max_steps=max_steps) eval_spec = training.EvalSpec( - name=eval_name, input_fn=eval_input_fn, steps=None, + name=eval_name, + input_fn=eval_input_fn, + steps=None, exporters=self._get_exporter(exporter_name, feature_columns), - throttle_secs=2) + throttle_secs=0) training.train_and_evaluate(est, train_spec, eval_spec) @@ -2148,15 +2141,12 @@ class TrainAndEvaluateIntegrationTest(test.TestCase): # Examine the training events. Use a range to check global step to avoid # flakyness due to global step race condition. - training_loss, training_global_step = self._extract_loss_and_global_step( - est.model_dir) + training_loss, _ = self._extract_loss_and_global_step(est.model_dir) self.assertIsNotNone(training_loss) - self.assertTrue( - max_steps - save_summary_steps < training_global_step <= max_steps) # Examine the eval events. The global step should be accurate. eval_loss, eval_global_step = self._extract_loss_and_global_step( - event_folder=os.path.join(est.model_dir, 'eval_' + eval_name)) + event_folder=est.eval_dir(eval_name)) self.assertIsNotNone(eval_loss) self.assertEqual(max_steps, eval_global_step) diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py index 924ca309ff0455d3bb06be61ce65bb0a61e84fb0..d4a75478d53f5b3dc8e66df98a78b51a6d25aab8 100644 --- a/tensorflow/python/estimator/util.py +++ b/tensorflow/python/estimator/util.py @@ -22,6 +22,7 @@ from __future__ import print_function import os import time +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import training @@ -129,3 +130,24 @@ class _DatasetInitializerHook(training.SessionRunHook): def after_create_session(self, session, coord): del coord session.run(self._initializer) + + +class StrategyInitFinalizeHook(training.SessionRunHook): + """Creates a SessionRunHook that initializes and shutsdown devices.""" + + def __init__(self, initialization_fn, finalize_fn): + self._initialization_fn = initialization_fn + self._finalize_fn = finalize_fn + + def begin(self): + self._init_ops = self._initialization_fn() + self._finalize_ops = self._finalize_fn() + + def after_create_session(self, session, coord): + logging.info('Initialize system') + session.run(self._init_ops, + options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) + + def end(self, session): + logging.info('Finalize system.') + session.run(self._finalize_ops) diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD index 295d4ca094cc8cb85c0f1f7fd47c20b910c270df..1017d4ba475bc0c1f74c1628fc2a23d9195fde27 100644 --- a/tensorflow/python/feature_column/BUILD +++ b/tensorflow/python/feature_column/BUILD @@ -48,6 +48,39 @@ py_library( ], ) +py_library( + name = "feature_column_v2", + srcs = ["feature_column_v2.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:embedding_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:lookup_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:platform", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:string_ops", + "//tensorflow/python:template", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/keras", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + filegroup( name = "vocabulary_testdata", srcs = [ @@ -89,6 +122,40 @@ py_test( "//tensorflow/python:variables", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", + ], +) + +py_test( + name = "feature_column_v2_test", + srcs = ["feature_column_v2_test.py"], + data = [":vocabulary_testdata"], + srcs_version = "PY2AND3", + tags = [ + "no_cuda_on_cpu_tap", + "no_pip", + ], + deps = [ + ":feature_column_py", + ":feature_column_v2", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:lookup_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:partitioned_variables", + "//tensorflow/python:session", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", "//tensorflow/python/estimator:numpy_io", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index af2ead9b8405696977cc9a40be750b939541006a..2246d2f3e99a2a80311e7e5b5b4f97f3b6ccfd45 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -16,7 +16,7 @@ FeatureColumns provide a high level abstraction for ingesting and representing features. FeatureColumns are also the primary way of encoding features for -canned @{tf.estimator.Estimator}s. +canned `tf.estimator.Estimator`s. When using FeatureColumns with `Estimators`, the type of feature column you should choose depends on (1) the feature type and (2) the model type. @@ -172,7 +172,7 @@ def _internal_input_layer(features, scope=None): """See input_layer. `scope` is a name or variable scope to use.""" - feature_columns = _clean_feature_columns(feature_columns) + feature_columns = _normalize_feature_columns(feature_columns) for column in feature_columns: if not isinstance(column, _DenseColumn): raise ValueError( @@ -350,10 +350,23 @@ def linear_model(features, prediction itself for linear regression problems. Note on supported columns: `linear_model` treats categorical columns as - `indicator_column`s while `input_layer` explicitly requires wrapping each - of them with an `embedding_column` or an `indicator_column`. + `indicator_column`s. To be specific, assume the input as `SparseTensor` looks + like: - Example: + ```python + shape = [2, 2] + { + [0, 0]: "a" + [1, 0]: "b" + [1, 1]: "c" + } + ``` + `linear_model` assigns weights for the presence of "a", "b", "c' implicitly, + just like `indicator_column`, while `input_layer` explicitly requires wrapping + each of categorical columns with an `embedding_column` or an + `indicator_column`. + + Example of usage: ```python price = numeric_column('price') @@ -374,13 +387,44 @@ def linear_model(features, to your model. All items should be instances of classes derived from `_FeatureColumn`s. units: An integer, dimensionality of the output space. Default value is 1. - sparse_combiner: A string specifying how to reduce if a sparse column is - multivalent. Currently "mean", "sqrtn" and "sum" are supported, with "sum" - the default. "sqrtn" often achieves good accuracy, in particular with - bag-of-words columns. It combines each sparse columns independently. + sparse_combiner: A string specifying how to reduce if a categorical column + is multivalent. Except `numeric_column`, almost all columns passed to + `linear_model` are considered as categorical columns. It combines each + categorical column independently. Currently "mean", "sqrtn" and "sum" are + supported, with "sum" the default for linear model. "sqrtn" often achieves + good accuracy, in particular with bag-of-words columns. * "sum": do not normalize features in the column * "mean": do l1 normalization on features in the column * "sqrtn": do l2 normalization on features in the column + For example, for two features represented as the categorical columns: + + ```python + # Feature 1 + + shape = [2, 2] + { + [0, 0]: "a" + [0, 1]: "b" + [1, 0]: "c" + } + + # Feature 2 + + shape = [2, 3] + { + [0, 0]: "d" + [1, 0]: "e" + [1, 1]: "f" + [1, 2]: "g" + } + ``` + with `sparse_combiner` as "mean", the linear model outputs conceptly are: + ``` + y_0 = 1.0 / 2.0 * ( w_a + w_ b) + w_c + b_0 + y_1 = w_d + 1.0 / 3.0 * ( w_e + w_ f + w_g) + b_1 + ``` + where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight + assigned to the presence of `x` in the input features. weight_collections: A list of collection names to which the Variable will be added. Note that, variables will also be added to collections `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`. @@ -408,13 +452,15 @@ def linear_model(features, ValueError: if an item in `feature_columns` is neither a `_DenseColumn` nor `_CategoricalColumn`. """ + with variable_scope.variable_scope(None, 'linear_model') as vs: + model_name = _strip_leading_slashes(vs.name) linear_model_layer = _LinearModel( feature_columns=feature_columns, units=units, sparse_combiner=sparse_combiner, weight_collections=weight_collections, trainable=trainable, - name='linear_model') + name=model_name) retval = linear_model_layer(features) # pylint: disable=not-callable if cols_to_vars is not None: cols_to_vars.update(linear_model_layer.cols_to_vars()) @@ -422,13 +468,25 @@ def linear_model(features, def _add_to_collections(var, weight_collections): - # TODO(rohanj): Explore adding a _get_variable_list method on `Variable` - # so that we don't have to do this check. - if isinstance(var, variables.PartitionedVariable): - for constituent_var in list(var): - ops.add_to_collections(weight_collections, constituent_var) - else: - ops.add_to_collections(weight_collections, var) + """Adds a var to the list of weight_collections provided. + + Handles the case for partitioned and non-partitioned variables. + + Args: + var: A variable or Partitioned Variable. + weight_collections: List of collections to add variable to. + """ + for weight_collection in weight_collections: + # The layer self.add_variable call already adds it to GLOBAL_VARIABLES. + if weight_collection == ops.GraphKeys.GLOBAL_VARIABLES: + continue + # TODO(rohanj): Explore adding a _get_variable_list method on `Variable` + # so that we don't have to do this check. + if isinstance(var, variables.PartitionedVariable): + for constituent_var in list(var): + ops.add_to_collection(weight_collection, constituent_var) + else: + ops.add_to_collection(weight_collection, var) class _FCLinearWrapper(base.Layer): @@ -536,8 +594,11 @@ class _LinearModel(training.Model): name=None, **kwargs): super(_LinearModel, self).__init__(name=name, **kwargs) - self._feature_columns = _clean_feature_columns(feature_columns) + self._feature_columns = _normalize_feature_columns( + feature_columns) self._weight_collections = list(weight_collections or []) + if ops.GraphKeys.GLOBAL_VARIABLES not in self._weight_collections: + self._weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES) if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections: self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES) @@ -643,7 +704,7 @@ def _transform_features(features, feature_columns): Returns: A `dict` mapping `_FeatureColumn` to `Tensor` and `SparseTensor` values. """ - feature_columns = _clean_feature_columns(feature_columns) + feature_columns = _normalize_feature_columns(feature_columns) outputs = {} with ops.name_scope( None, default_name='transform_features', values=features.values()): @@ -911,7 +972,8 @@ def shared_embedding_columns( tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which to restore the column weights. Required if `ckpt_to_load_from` is not `None`. - max_norm: If not `None`, embedding values are l2-normalized to this value. + max_norm: If not `None`, each embedding is clipped if its l2-norm is + larger than this value, before combining. trainable: Whether or not the embedding is trainable. Default is True. Returns: @@ -925,7 +987,12 @@ def shared_embedding_columns( ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt` is specified. ValueError: if `initializer` is specified and is not callable. + RuntimeError: if eager execution is enabled. """ + if context.executing_eagerly(): + raise RuntimeError('shared_embedding_columns are not supported when eager ' + 'execution is enabled.') + if (dimension is None) or (dimension < 1): raise ValueError('Invalid dimension {}.'.format(dimension)) if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None): @@ -970,16 +1037,6 @@ def shared_embedding_columns( shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns) shared_embedding_collection_name += '_shared_embedding' - # Create the state (_SharedEmbeddingColumnLayer) here. - embedding_shape = num_buckets, dimension - - shared_embedding_column_layer = _EmbeddingColumnLayer( - embedding_shape=embedding_shape, - initializer=initializer, - weight_collections=[], - trainable=trainable, - name=shared_embedding_collection_name) - result = [] for column in categorical_columns: result.append( @@ -988,16 +1045,12 @@ def shared_embedding_columns( initializer=initializer, dimension=dimension, combiner=combiner, - var_scope_name=shared_embedding_collection_name, + shared_embedding_collection_name=shared_embedding_collection_name, ckpt_to_load_from=ckpt_to_load_from, tensor_name_in_ckpt=tensor_name_in_ckpt, max_norm=max_norm, trainable=trainable)) - for single_result in result: - single_result._set_layer(shared_embedding_column_layer) # pylint: disable=protected-access - single_result._set_all_columns(result) # pylint: disable=protected-access - return result @@ -1182,12 +1235,13 @@ def categorical_column_with_hash_bucket(key, Use this when your sparse features are in string or integer format, and you want to distribute your inputs into a finite number of buckets by hashing. - output_id = Hash(input_feature_string) % bucket_size + output_id = Hash(input_feature_string) % bucket_size for string type input. + For int type input, the value is converted to its string representation first + and then hashed by the same formula. For input dictionary `features`, `features[key]` is either `Tensor` or `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int - and `''` for string. Note that these values are independent of the - `default_value` argument. + and `''` for string, which will be dropped by this feature column. Example: @@ -1249,8 +1303,7 @@ def categorical_column_with_vocabulary_file(key, For input dictionary `features`, `features[key]` is either `Tensor` or `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int - and `''` for string. Note that these values are independent of the - `default_value` argument. + and `''` for string, which will be dropped by this feature column. Example with `num_oov_buckets`: File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state @@ -1366,8 +1419,7 @@ def categorical_column_with_vocabulary_list( For input dictionary `features`, `features[key]` is either `Tensor` or `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int - and `''` for string. Note that these values are independent of the - `default_value` argument. + and `''` for string, which will be dropped by this feature column. Example with `num_oov_buckets`: In the following example, each input in `vocabulary_list` is assigned an ID @@ -1480,8 +1532,7 @@ def categorical_column_with_identity(key, num_buckets, default_value=None): For input dictionary `features`, `features[key]` is either `Tensor` or `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int - and `''` for string. Note that these values are independent of the - `default_value` argument. + and `''` for string, which will be dropped by this feature column. In the following examples, each input in the range `[0, 1000000)` is assigned the same value. All other inputs are assigned `default_value` 0. Note that a @@ -1538,8 +1589,14 @@ def categorical_column_with_identity(key, num_buckets, default_value=None): def indicator_column(categorical_column): """Represents multi-hot representation of given categorical column. - Used to wrap any `categorical_column_*` (e.g., to feed to DNN). Use - `embedding_column` if the inputs are sparse. + - For DNN model, `indicator_column` can be used to wrap any + `categorical_column_*` (e.g., to feed to DNN). Consider to Use + `embedding_column` if the number of buckets/unique(values) are large. + + - For Wide (aka linear) model, `indicator_column` is the internal + representation for categorical column when passing categorical column + directly (as any element in feature_columns) to `linear_model`. See + `linear_model` for details. ```python name = indicator_column(categorical_column_with_vocabulary_list( @@ -1813,11 +1870,8 @@ class _EmbeddingColumnLayer(base.Layer): dtype=dtypes.float32, initializer=self._initializer, trainable=self.trainable) - # self.add_variable already appends to GLOBAL_VARIABLES collection. if self._weight_collections and not context.executing_eagerly(): - for weight_collection in self._weight_collections: - if weight_collection != ops.GraphKeys.GLOBAL_VARIABLES: - _add_to_collections(self._embedding_weight_var, [weight_collection]) + _add_to_collections(self._embedding_weight_var, self._weight_collections) self.built = True def call(self, _): @@ -1882,7 +1936,7 @@ class _FeatureColumn(object): It is used for get_parsing_spec for `tf.parse_example`. Returned spec is a dict from keys ('string') to `VarLenFeature`, `FixedLenFeature`, and other - supported objects. Please check documentation of @{tf.parse_example} for all + supported objects. Please check documentation of `tf.parse_example` for all supported spec objects. Let's say a Feature column depends on raw feature ('raw') and another @@ -1941,7 +1995,7 @@ class _DenseColumn(_FeatureColumn): weight_collections: List of graph collections to which Variables (if any will be created) are added. trainable: If `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see @{tf.Variable}). + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). Returns: `Tensor` of shape [batch_size] + `_variable_shape`. @@ -1956,7 +2010,7 @@ def _create_weighted_sum(column, weight_collections, trainable, weight_var=None): - """Creates a weighted sum for a dense or sparse column for linear_model.""" + """Creates a weighted sum for a dense/categorical column for linear_model.""" if isinstance(column, _CategoricalColumn): return _create_categorical_column_weighted_sum( column=column, @@ -2008,7 +2062,7 @@ class _CategoricalColumn(_FeatureColumn): WARNING: Do not subclass this layer unless you know what you are doing: the API is subject to future changes. - A categorical feature typically handled with a @{tf.SparseTensor} of IDs. + A categorical feature typically handled with a `tf.SparseTensor` of IDs. """ __metaclass__ = abc.ABCMeta @@ -2043,7 +2097,7 @@ class _CategoricalColumn(_FeatureColumn): weight_collections: List of graph collections to which variables (if any will be created) are added. trainable: If `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see @{tf.get_variable}). + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.get_variable`). """ pass @@ -2055,7 +2109,34 @@ def _create_categorical_column_weighted_sum(column, weight_collections, trainable, weight_var=None): - """Create a weighted sum of a categorical column for linear_model.""" + # pylint: disable=g-doc-return-or-yield,g-doc-args + """Create a weighted sum of a categorical column for linear_model. + + Note to maintainer: As implementation details, the weighted sum is + implemented via embedding_lookup_sparse toward efficiency. Mathematically, + they are the same. + + To be specific, conceptually, categorical column can be treated as multi-hot + vector. Say: + + ```python + x = [0 0 1] # categorical column input + w = [a b c] # weights + ``` + The weighted sum is `c` in this case, which is same as `w[2]`. + + Another example is + + ```python + x = [0 1 1] # categorical column input + w = [a b c] # weights + ``` + The weighted sum is `b + c` in this case, which is same as `w[2] + w[3]`. + + For both cases, we can implement weighted sum via embedding_lookup with + sparse_combiner = "sum". + """ + sparse_tensors = column._get_sparse_tensors( # pylint: disable=protected-access builder, weight_collections=weight_collections, @@ -2077,7 +2158,7 @@ def _create_categorical_column_weighted_sum(column, initializer=init_ops.zeros_initializer(), trainable=trainable, collections=weight_collections) - return _safe_embedding_lookup_sparse( + return embedding_ops.safe_embedding_lookup_sparse( weight, id_tensor, sparse_weights=weight_tensor, @@ -2249,7 +2330,7 @@ def _shape_offsets(shape): # TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py -def _to_sparse_input(input_tensor, ignore_value=None): +def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None): """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells. If `input_tensor` is already a `SparseTensor`, just return it. @@ -2293,8 +2374,22 @@ def _to_sparse_input(input_tensor, ignore_value=None): input_tensor, out_type=dtypes.int64, name='dense_shape')) -def _clean_feature_columns(feature_columns): - """Verifies and normalizes `feature_columns` input.""" +def _normalize_feature_columns(feature_columns): + """Normalizes the `feature_columns` input. + + This method converts the `feature_columns` to list type as best as it can. In + addition, verifies the type and other parts of feature_columns, required by + downstream library. + + Args: + feature_columns: The raw feature columns, usually passed by users. + + Returns: + The normalized feature column list. + + Raises: + ValueError: for any invalid inputs, such as empty, duplicated names, etc. + """ if isinstance(feature_columns, _FeatureColumn): feature_columns = [feature_columns] @@ -2420,6 +2515,7 @@ class _BucketizedColumn(_DenseColumn, _CategoricalColumn, def _get_sparse_tensors(self, inputs, weight_collections=None, trainable=None): + """Converts dense inputs to SparseTensor so downstream code can use it.""" input_tensor = inputs.get(self) batch_size = array_ops.shape(input_tensor)[0] # By construction, source_column is always one-dimensional. @@ -2498,7 +2594,7 @@ class _EmbeddingColumn( }) # Return embedding lookup result. - return _safe_embedding_lookup_sparse( + return embedding_ops.safe_embedding_lookup_sparse( embedding_weights=embedding_weights, sparse_ids=sparse_ids, sparse_weights=sparse_weights, @@ -2553,12 +2649,12 @@ def _get_graph_for_variable(var): class _SharedEmbeddingColumn( - _DenseColumn, + _DenseColumn, _SequenceDenseColumn, collections.namedtuple( '_SharedEmbeddingColumn', ('categorical_column', 'dimension', 'combiner', 'initializer', - 'var_scope_name', 'ckpt_to_load_from', 'tensor_name_in_ckpt', - 'max_norm', 'trainable'))): + 'shared_embedding_collection_name', 'ckpt_to_load_from', + 'tensor_name_in_ckpt', 'max_norm', 'trainable'))): """See `embedding_column`.""" @property @@ -2569,7 +2665,7 @@ class _SharedEmbeddingColumn( @property def _var_scope_name(self): - return self.var_scope_name + return self.shared_embedding_collection_name @property def _parse_example_spec(self): @@ -2578,29 +2674,17 @@ class _SharedEmbeddingColumn( def _transform_feature(self, inputs): return inputs.get(self.categorical_column) - def _set_layer(self, layer): - self._layer = layer - - def _set_all_columns(self, all_columns): - self._all_columns = all_columns - - def _reset_config(self): - config = self._layer.get_config() - config['embedding_shape'] = ( - self.categorical_column._num_buckets, # pylint: disable=protected-access - self.dimension) - config['initializer'] = self.initializer - self._layer = self._layer.__class__.from_config(config) - for column in self._all_columns: - column._set_layer(self._layer) # pylint: disable=protected-access - @property def _variable_shape(self): if not hasattr(self, '_shape'): self._shape = tensor_shape.vector(self.dimension) return self._shape - def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + def _get_dense_tensor_internal(self, + inputs, + weight_collections=None, + trainable=None): + """Private method that follows the signature of _get_dense_tensor.""" # This method is called from a variable_scope with name _var_scope_name, # which is shared among all shared embeddings. Open a name_scope here, so # that the ops for different columns have distinct names. @@ -2611,19 +2695,38 @@ class _SharedEmbeddingColumn( sparse_ids = sparse_tensors.id_tensor sparse_weights = sparse_tensors.weight_tensor - self._layer.set_weight_collections(weight_collections) - embedding_weights = self._layer( - None, scope=variable_scope.get_variable_scope()) - # If we're in graph mode and this is called with a different graph, - # then we should reset. - if not context.executing_eagerly() and ( - ops.get_default_graph() != - _get_graph_for_variable(embedding_weights)): - self._reset_config() - self._layer.set_weight_collections(weight_collections) - embedding_weights = self._layer( - None, scope=variable_scope.get_variable_scope()) - + embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access + shared_embedding_collection = ops.get_collection( + self.shared_embedding_collection_name) + if shared_embedding_collection: + if len(shared_embedding_collection) > 1: + raise ValueError( + 'Collection {} can only contain one variable. ' + 'Suggested fix A: Choose a unique name for this collection. ' + 'Suggested fix B: Do not add any variables to this collection. ' + 'The feature_column library already adds a variable under the ' + 'hood.'.format(shared_embedding_collection)) + embedding_weights = shared_embedding_collection[0] + if embedding_weights.get_shape() != embedding_shape: + raise ValueError( + 'Shared embedding collection {} contains variable {} of ' + 'unexpected shape {}. Expected shape is {}. ' + 'Suggested fix A: Choose a unique name for this collection. ' + 'Suggested fix B: Do not add any variables to this collection. ' + 'The feature_column library already adds a variable under the ' + 'hood.'.format(self.shared_embedding_collection_name, + embedding_weights.name, + embedding_weights.get_shape(), embedding_shape)) + else: + embedding_weights = variable_scope.get_variable( + name='embedding_weights', + shape=embedding_shape, + dtype=dtypes.float32, + initializer=self.initializer, + trainable=self.trainable and trainable, + collections=weight_collections) + ops.add_to_collection(self.shared_embedding_collection_name, + embedding_weights) if self.ckpt_to_load_from is not None: to_restore = embedding_weights if isinstance(to_restore, variables.PartitionedVariable): @@ -2633,7 +2736,7 @@ class _SharedEmbeddingColumn( }) # Return embedding lookup result. - return _safe_embedding_lookup_sparse( + return embedding_ops.safe_embedding_lookup_sparse( embedding_weights=embedding_weights, sparse_ids=sparse_ids, sparse_weights=sparse_weights, @@ -2641,6 +2744,44 @@ class _SharedEmbeddingColumn( name='%s_weights' % self.name, max_norm=self.max_norm) + def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + if isinstance(self.categorical_column, _SequenceCategoricalColumn): + raise ValueError( + 'In embedding_column: {}. ' + 'categorical_column must not be of type _SequenceCategoricalColumn. ' + 'Suggested fix A: If you wish to use input_layer, use a ' + 'non-sequence categorical_column_with_*. ' + 'Suggested fix B: If you wish to create sequence input, use ' + 'sequence_input_layer instead of input_layer. ' + 'Given (type {}): {}'.format(self.name, type(self.categorical_column), + self.categorical_column)) + return self._get_dense_tensor_internal( + inputs=inputs, + weight_collections=weight_collections, + trainable=trainable) + + def _get_sequence_dense_tensor(self, + inputs, + weight_collections=None, + trainable=None): + if not isinstance(self.categorical_column, _SequenceCategoricalColumn): + raise ValueError( + 'In embedding_column: {}. ' + 'categorical_column must be of type _SequenceCategoricalColumn ' + 'to use sequence_input_layer. ' + 'Suggested fix: Use one of sequence_categorical_column_with_*. ' + 'Given (type {}): {}'.format(self.name, type(self.categorical_column), + self.categorical_column)) + dense_tensor = self._get_dense_tensor_internal( # pylint: disable=protected-access + inputs=inputs, + weight_collections=weight_collections, + trainable=trainable) + sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access + sequence_length = _sequence_length_from_sparse_tensor( + sparse_tensors.id_tensor) + return _SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) + def _create_tuple(shape, value): """Returns a tuple with given shape and filled with value.""" @@ -2762,7 +2903,7 @@ class _HashedCategoricalColumn( return {self.key: parsing_ops.VarLenFeature(self.dtype)} def _transform_feature(self, inputs): - input_tensor = _to_sparse_input(inputs.get(self.key)) + input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor): raise ValueError('SparseColumn input must be a SparseTensor.') @@ -2813,7 +2954,7 @@ class _VocabularyFileCategoricalColumn( return {self.key: parsing_ops.VarLenFeature(self.dtype)} def _transform_feature(self, inputs): - input_tensor = _to_sparse_input(inputs.get(self.key)) + input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) if self.dtype.is_integer != input_tensor.dtype.is_integer: raise ValueError( @@ -2865,7 +3006,7 @@ class _VocabularyListCategoricalColumn( return {self.key: parsing_ops.VarLenFeature(self.dtype)} def _transform_feature(self, inputs): - input_tensor = _to_sparse_input(inputs.get(self.key)) + input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) if self.dtype.is_integer != input_tensor.dtype.is_integer: raise ValueError( @@ -2917,7 +3058,7 @@ class _IdentityCategoricalColumn( return {self.key: parsing_ops.VarLenFeature(dtypes.int64)} def _transform_feature(self, inputs): - input_tensor = _to_sparse_input(inputs.get(self.key)) + input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key)) if not input_tensor.dtype.is_integer: raise ValueError( @@ -2999,7 +3140,8 @@ class _WeightedCategoricalColumn( self.dtype, weight_tensor.dtype)) if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor): # The weight tensor can be a regular Tensor. In this case, sparsify it. - weight_tensor = _to_sparse_input(weight_tensor, ignore_value=0.0) + weight_tensor = _to_sparse_input_and_drop_ignore_values( + weight_tensor, ignore_value=0.0) if not weight_tensor.dtype.is_floating: weight_tensor = math_ops.to_float(weight_tensor) return (inputs.get(self.categorical_column), weight_tensor) @@ -3086,161 +3228,6 @@ def _collect_leaf_level_keys(cross): return leaf_level_keys -# TODO(zakaria): Move this to embedding_ops and make it public. -def _safe_embedding_lookup_sparse(embedding_weights, - sparse_ids, - sparse_weights=None, - combiner='mean', - default_id=None, - name=None, - partition_strategy='div', - max_norm=None): - """Lookup embedding results, accounting for invalid IDs and empty features. - - The partitioned embedding in `embedding_weights` must all be the same shape - except for the first dimension. The first dimension is allowed to vary as the - vocabulary size is not necessarily a multiple of `P`. `embedding_weights` - may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a - partitioner. - - Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs - with non-positive weight. For an entry with no features, the embedding vector - for `default_id` is returned, or the 0-vector if `default_id` is not supplied. - - The ids and weights may be multi-dimensional. Embeddings are always aggregated - along the last dimension. - - Args: - embedding_weights: A list of `P` float `Tensor`s or values representing - partitioned embedding `Tensor`s. Alternatively, a `PartitionedVariable` - created by partitioning along dimension 0. The total unpartitioned - shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the - vocab size and `e_1, ..., e_m` are the embedding dimensions. - sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the - ids. `d_0` is typically batch size. - sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing - float weights corresponding to `sparse_ids`, or `None` if all weights - are be assumed to be 1.0. - combiner: A string specifying how to combine embedding results for each - entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" - the default. - default_id: The id to use for an entry with no features. - name: A name for this operation (optional). - partition_strategy: A string specifying the partitioning strategy. - Currently `"div"` and `"mod"` are supported. Default is `"div"`. - max_norm: If not `None`, all embeddings are l2-normalized to max_norm before - combining. - - - Returns: - Dense `Tensor` of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`. - - Raises: - ValueError: if `embedding_weights` is empty. - """ - if embedding_weights is None: - raise ValueError('Missing embedding_weights %s.' % embedding_weights) - if isinstance(embedding_weights, variables.PartitionedVariable): - embedding_weights = list(embedding_weights) # get underlying Variables. - if not isinstance(embedding_weights, list): - embedding_weights = [embedding_weights] - if len(embedding_weights) < 1: - raise ValueError('Missing embedding_weights %s.' % embedding_weights) - - dtype = sparse_weights.dtype if sparse_weights is not None else None - embedding_weights = [ - ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights - ] - - with ops.name_scope(name, 'embedding_lookup', - embedding_weights + [sparse_ids, - sparse_weights]) as scope: - # Reshape higher-rank sparse ids and weights to linear segment ids. - original_shape = sparse_ids.dense_shape - original_rank_dim = sparse_ids.dense_shape.get_shape()[0] - original_rank = ( - array_ops.size(original_shape) - if original_rank_dim.value is None - else original_rank_dim.value) - sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [ - math_ops.reduce_prod( - array_ops.slice(original_shape, [0], [original_rank - 1])), - array_ops.gather(original_shape, original_rank - 1)]) - if sparse_weights is not None: - sparse_weights = sparse_tensor_lib.SparseTensor( - sparse_ids.indices, - sparse_weights.values, sparse_ids.dense_shape) - - # Prune invalid ids and weights. - sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights) - if combiner != 'sum': - sparse_ids, sparse_weights = _prune_invalid_weights( - sparse_ids, sparse_weights) - - # Fill in dummy values for empty features, if necessary. - sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids, - default_id or - 0) - if sparse_weights is not None: - sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0) - - result = embedding_ops.embedding_lookup_sparse( - embedding_weights, - sparse_ids, - sparse_weights, - combiner=combiner, - partition_strategy=partition_strategy, - name=None if default_id is None else scope, - max_norm=max_norm) - - if default_id is None: - # Broadcast is_row_empty to the same shape as embedding_lookup_result, - # for use in Select. - is_row_empty = array_ops.tile( - array_ops.reshape(is_row_empty, [-1, 1]), - array_ops.stack([1, array_ops.shape(result)[1]])) - - result = array_ops.where(is_row_empty, - array_ops.zeros_like(result), - result, - name=scope) - - # Reshape back from linear ids back into higher-dimensional dense result. - final_result = array_ops.reshape( - result, - array_ops.concat([ - array_ops.slice( - math_ops.cast(original_shape, dtypes.int32), [0], - [original_rank - 1]), - array_ops.slice(array_ops.shape(result), [1], [-1]) - ], 0)) - final_result.set_shape(tensor_shape.unknown_shape( - (original_rank_dim - 1).value).concatenate(result.get_shape()[1:])) - return final_result - - -def _prune_invalid_ids(sparse_ids, sparse_weights): - """Prune invalid IDs (< 0) from the input ids and weights.""" - is_id_valid = math_ops.greater_equal(sparse_ids.values, 0) - if sparse_weights is not None: - is_id_valid = math_ops.logical_and( - is_id_valid, - array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool)) - sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid) - if sparse_weights is not None: - sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid) - return sparse_ids, sparse_weights - - -def _prune_invalid_weights(sparse_ids, sparse_weights): - """Prune invalid weights (< 0) from the input ids and weights.""" - if sparse_weights is not None: - is_weights_valid = math_ops.greater(sparse_weights.values, 0) - sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid) - sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid) - return sparse_ids, sparse_weights - - class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn, collections.namedtuple('_IndicatorColumn', ['categorical_column'])): @@ -3277,10 +3264,14 @@ class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn, sp_ids=id_tensor, sp_values=weight_tensor, vocab_size=int(self._variable_shape[-1])) - # Remove (?, -1) index + # Remove (?, -1) index. weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0], weighted_column.dense_shape) - return sparse_ops.sparse_tensor_to_dense(weighted_column) + # Use scatter_nd to merge duplicated indices if existed, + # instead of sparse_tensor_to_dense. + return array_ops.scatter_nd(weighted_column.indices, + weighted_column.values, + weighted_column.dense_shape) dense_id_tensor = sparse_ops.sparse_tensor_to_dense( id_tensor, default_value=-1) diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index 627430d6bc5995cf054482ac3004098b8a2472ab..6be930be87b5db4e8b0ce261bf2956dd2999c606 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -30,7 +30,6 @@ from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import context -from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column_lib as fc from tensorflow.python.feature_column.feature_column import _CategoricalColumn from tensorflow.python.feature_column.feature_column import _DenseColumn @@ -52,8 +51,6 @@ from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test -from tensorflow.python.training import coordinator -from tensorflow.python.training import queue_runner_impl def _initialized_session(config=None): @@ -1257,14 +1254,14 @@ class CrossedColumnTest(test.TestCase): }, (crossed,)) -def get_linear_model_bias(): - with variable_scope.variable_scope('linear_model', reuse=True): +def get_linear_model_bias(name='linear_model'): + with variable_scope.variable_scope(name, reuse=True): return variable_scope.get_variable('bias_weights') -def get_linear_model_column_var(column): +def get_linear_model_column_var(column, name='linear_model'): return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, - 'linear_model/' + column.name)[0] + name + '/' + column.name)[0] def get_keras_linear_model_predictions(features, @@ -1803,39 +1800,6 @@ class LinearModelTest(test.TestCase): features['price2']: [[1.], [5.]], }) - def test_with_numpy_input_fn(self): - price = fc.numeric_column('price') - price_buckets = fc.bucketized_column(price, boundaries=[0., 10., 100.,]) - body_style = fc.categorical_column_with_vocabulary_list( - 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) - - input_fn = numpy_io.numpy_input_fn( - x={ - 'price': np.array([-1., 2., 13., 104.]), - 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']), - }, - batch_size=2, - shuffle=False) - features = input_fn() - net = fc.linear_model(features, [price_buckets, body_style]) - # self.assertEqual(1 + 3 + 5, net.shape[1]) - with _initialized_session() as sess: - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(sess, coord=coord) - - bias = get_linear_model_bias() - price_buckets_var = get_linear_model_column_var(price_buckets) - body_style_var = get_linear_model_column_var(body_style) - - sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]])) - sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]])) - sess.run(bias.assign([5.])) - - self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net)) - - coord.request_stop() - coord.join(threads) - def test_with_1d_sparse_tensor(self): price = fc.numeric_column('price') price_buckets = fc.bucketized_column(price, boundaries=[0., 10., 100.,]) @@ -1928,6 +1892,27 @@ class LinearModelTest(test.TestCase): with self.assertRaisesOpError('Feature .* cannot have rank 0'): sess.run(net, feed_dict={features['price']: np.array(1)}) + def test_multiple_linear_models(self): + price = fc.numeric_column('price') + with ops.Graph().as_default(): + features1 = {'price': [[1.], [5.]]} + features2 = {'price': [[2.], [10.]]} + predictions1 = fc.linear_model(features1, [price]) + predictions2 = fc.linear_model(features2, [price]) + bias1 = get_linear_model_bias(name='linear_model') + bias2 = get_linear_model_bias(name='linear_model_1') + price_var1 = get_linear_model_column_var(price, name='linear_model') + price_var2 = get_linear_model_column_var(price, name='linear_model_1') + with _initialized_session() as sess: + self.assertAllClose([0.], bias1.eval()) + sess.run(price_var1.assign([[10.]])) + sess.run(bias1.assign([5.])) + self.assertAllClose([[15.], [55.]], predictions1.eval()) + self.assertAllClose([0.], bias2.eval()) + sess.run(price_var2.assign([[10.]])) + sess.run(bias2.assign([5.])) + self.assertAllClose([[25.], [105.]], predictions2.eval()) + class _LinearModelTest(test.TestCase): @@ -2437,45 +2422,6 @@ class _LinearModelTest(test.TestCase): features['price2']: [[1.], [5.]], }) - def test_with_numpy_input_fn(self): - price = fc.numeric_column('price') - price_buckets = fc.bucketized_column( - price, boundaries=[ - 0., - 10., - 100., - ]) - body_style = fc.categorical_column_with_vocabulary_list( - 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) - - input_fn = numpy_io.numpy_input_fn( - x={ - 'price': np.array([-1., 2., 13., 104.]), - 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']), - }, - batch_size=2, - shuffle=False) - features = input_fn() - net = get_keras_linear_model_predictions(features, - [price_buckets, body_style]) - # self.assertEqual(1 + 3 + 5, net.shape[1]) - with _initialized_session() as sess: - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(sess, coord=coord) - - bias = get_linear_model_bias() - price_buckets_var = get_linear_model_column_var(price_buckets) - body_style_var = get_linear_model_column_var(body_style) - - sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]])) - sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]])) - sess.run(bias.assign([5.])) - - self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net)) - - coord.request_stop() - coord.join(threads) - def test_with_1d_sparse_tensor(self): price = fc.numeric_column('price') price_buckets = fc.bucketized_column( @@ -2586,7 +2532,7 @@ class _LinearModelTest(test.TestCase): class InputLayerTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_retrieving_input(self): features = {'a': [0.]} input_layer = InputLayer(fc.numeric_column('a')) @@ -3022,51 +2968,6 @@ class FunctionalInputLayerTest(test.TestCase): ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'], [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) - def test_with_numpy_input_fn(self): - embedding_values = ( - (1., 2., 3., 4., 5.), # id 0 - (6., 7., 8., 9., 10.), # id 1 - (11., 12., 13., 14., 15.) # id 2 - ) - def _initializer(shape, dtype, partition_info): - del shape, dtype, partition_info - return embedding_values - - # price has 1 dimension in input_layer - price = fc.numeric_column('price') - body_style = fc.categorical_column_with_vocabulary_list( - 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) - # one_hot_body_style has 3 dims in input_layer. - one_hot_body_style = fc.indicator_column(body_style) - # embedded_body_style has 5 dims in input_layer. - embedded_body_style = fc.embedding_column(body_style, dimension=5, - initializer=_initializer) - - input_fn = numpy_io.numpy_input_fn( - x={ - 'price': np.array([11., 12., 13., 14.]), - 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']), - }, - batch_size=2, - shuffle=False) - features = input_fn() - net = fc.input_layer(features, - [price, one_hot_body_style, embedded_body_style]) - self.assertEqual(1 + 3 + 5, net.shape[1]) - with _initialized_session() as sess: - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(sess, coord=coord) - - # Each row is formed by concatenating `embedded_body_style`, - # `one_hot_body_style`, and `price` in order. - self.assertAllEqual( - [[11., 12., 13., 14., 15., 0., 0., 1., 11.], - [1., 2., 3., 4., 5., 1., 0., 0., 12]], - sess.run(net)) - - coord.request_stop() - coord.join(threads) - def test_with_1d_sparse_tensor(self): embedding_values = ( (1., 2., 3., 4., 5.), # id 0 @@ -4559,12 +4460,12 @@ class IndicatorColumnTest(test.TestCase): weights = fc.weighted_categorical_column(ids, 'weights') indicator = fc.indicator_column(weights) features = { - 'ids': constant_op.constant([['c', 'b', 'a']]), - 'weights': constant_op.constant([[2., 4., 6.]]) + 'ids': constant_op.constant([['c', 'b', 'a', 'c']]), + 'weights': constant_op.constant([[2., 4., 6., 1.]]) } indicator_tensor = _transform_features(features, [indicator])[indicator] with _initialized_session(): - self.assertAllEqual([[6., 4., 2.]], indicator_tensor.eval()) + self.assertAllEqual([[6., 4., 3.]], indicator_tensor.eval()) def test_transform_with_missing_value_in_weighted_column(self): # Github issue 12583 @@ -5329,9 +5230,9 @@ class SharedEmbeddingColumnTest(test.TestCase): self.assertIsNone(embedding_column_a.ckpt_to_load_from) self.assertIsNone(embedding_column_b.ckpt_to_load_from) self.assertEqual('aaa_bbb_shared_embedding', - embedding_column_a.var_scope_name) + embedding_column_a.shared_embedding_collection_name) self.assertEqual('aaa_bbb_shared_embedding', - embedding_column_b.var_scope_name) + embedding_column_b.shared_embedding_collection_name) self.assertIsNone(embedding_column_a.tensor_name_in_ckpt) self.assertIsNone(embedding_column_b.tensor_name_in_ckpt) self.assertIsNone(embedding_column_a.max_norm) @@ -5378,9 +5279,9 @@ class SharedEmbeddingColumnTest(test.TestCase): self.assertEqual('my_combiner', embedding_column_a.combiner) self.assertEqual('my_combiner', embedding_column_b.combiner) self.assertEqual('shared_embedding_collection_name', - embedding_column_a.var_scope_name) + embedding_column_a.shared_embedding_collection_name) self.assertEqual('shared_embedding_collection_name', - embedding_column_b.var_scope_name) + embedding_column_b.shared_embedding_collection_name) self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from) self.assertEqual('my_ckpt', embedding_column_b.ckpt_to_load_from) self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt) @@ -5431,7 +5332,7 @@ class SharedEmbeddingColumnTest(test.TestCase): self.assertEqual(embedding_dimension, embedding_column_a.dimension) self.assertEqual('my_combiner', embedding_column_a.combiner) self.assertEqual('shared_embedding_collection_name', - embedding_column_a.var_scope_name) + embedding_column_a.shared_embedding_collection_name) self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from) self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt) self.assertEqual(42., embedding_column_a.max_norm) diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..b6bf516286a824ac829ef230c78dd1bbd432fbac --- /dev/null +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -0,0 +1,3600 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""This API defines FeatureColumn abstraction. + +FeatureColumns provide a high level abstraction for ingesting and representing +features. FeatureColumns are also the primary way of encoding features for +canned `tf.estimator.Estimator`s. + +When using FeatureColumns with `Estimators`, the type of feature column you +should choose depends on (1) the feature type and (2) the model type. + +1. Feature type: + + * Continuous features can be represented by `numeric_column`. + * Categorical features can be represented by any `categorical_column_with_*` + column: + - `categorical_column_with_vocabulary_list` + - `categorical_column_with_vocabulary_file` + - `categorical_column_with_hash_bucket` + - `categorical_column_with_identity` + - `weighted_categorical_column` + +2. Model type: + + * Deep neural network models (`DNNClassifier`, `DNNRegressor`). + + Continuous features can be directly fed into deep neural network models. + + age_column = numeric_column("age") + + To feed sparse features into DNN models, wrap the column with + `embedding_column` or `indicator_column`. `indicator_column` is recommended + for features with only a few possible values. For features with many + possible values, to reduce the size of your model, `embedding_column` is + recommended. + + embedded_dept_column = embedding_column( + categorical_column_with_vocabulary_list( + "department", ["math", "philosophy", ...]), dimension=10) + + * Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`). + + Sparse features can be fed directly into linear models. They behave like an + indicator column but with an efficient implementation. + + dept_column = categorical_column_with_vocabulary_list("department", + ["math", "philosophy", "english"]) + + It is recommended that continuous features be bucketized before being + fed into linear models. + + bucketized_age_column = bucketized_column( + source_column=age_column, + boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65]) + + Sparse features can be crossed (also known as conjuncted or combined) in + order to form non-linearities, and then fed into linear models. + + cross_dept_age_column = crossed_column( + columns=["department", bucketized_age_column], + hash_bucket_size=1000) + +Example of building canned `Estimator`s using FeatureColumns: + + ```python + # Define features and transformations + deep_feature_columns = [age_column, embedded_dept_column] + wide_feature_columns = [dept_column, bucketized_age_column, + cross_dept_age_column] + + # Build deep model + estimator = DNNClassifier( + feature_columns=deep_feature_columns, + hidden_units=[500, 250, 50]) + estimator.train(...) + + # Or build a wide model + estimator = LinearClassifier( + feature_columns=wide_feature_columns) + estimator.train(...) + + # Or build a wide and deep model! + estimator = DNNLinearCombinedClassifier( + linear_feature_columns=wide_feature_columns, + dnn_feature_columns=deep_feature_columns, + dnn_hidden_units=[500, 250, 50]) + estimator.train(...) + ``` + + +FeatureColumns can also be transformed into a generic input layer for +custom models using `input_layer`. + +Example of building model using FeatureColumns, this can be used in a +`model_fn` which is given to the {tf.estimator.Estimator}: + + ```python + # Building model via layers + + deep_feature_columns = [age_column, embedded_dept_column] + columns_to_tensor = parse_feature_columns_from_examples( + serialized=my_data, + feature_columns=deep_feature_columns) + first_layer = input_layer( + features=columns_to_tensor, + feature_columns=deep_feature_columns) + second_layer = fully_connected(first_layer, ...) + ``` + +NOTE: Functions prefixed with "_" indicate experimental or private parts of +the API subject to change, and should not be relied upon! +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import collections +import math + +import numpy as np +import six + + +from tensorflow.python.eager import context +from tensorflow.python.feature_column import feature_column as fc_old +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras.engine import training +from tensorflow.python.layers import base +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import string_ops +from tensorflow.python.ops import template +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.util import nest + + +def _internal_input_layer(features, + feature_columns, + weight_collections=None, + trainable=True, + cols_to_vars=None, + scope=None): + """See input_layer. `scope` is a name or variable scope to use.""" + + feature_columns = fc_old._normalize_feature_columns(feature_columns) # pylint: disable=protected-access + for column in feature_columns: + if not isinstance(column, fc_old._DenseColumn): # pylint: disable=protected-access + raise ValueError( + 'Items of feature_columns must be a _DenseColumn. ' + 'You can wrap a categorical column with an ' + 'embedding_column or indicator_column. Given: {}'.format(column)) + weight_collections = list(weight_collections or []) + if ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections: + weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES) + if ops.GraphKeys.MODEL_VARIABLES not in weight_collections: + weight_collections.append(ops.GraphKeys.MODEL_VARIABLES) + + # a non-None `scope` can allow for variable reuse, when, e.g., this function + # is wrapped by a `make_template`. + with variable_scope.variable_scope( + scope, default_name='input_layer', values=features.values()): + builder = fc_old._LazyBuilder(features) # pylint: disable=protected-access + output_tensors = [] + ordered_columns = [] + for column in sorted(feature_columns, key=lambda x: x.name): + ordered_columns.append(column) + with variable_scope.variable_scope( + None, default_name=column._var_scope_name): # pylint: disable=protected-access + tensor = column._get_dense_tensor( # pylint: disable=protected-access + builder, + weight_collections=weight_collections, + trainable=trainable) + num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access + batch_size = array_ops.shape(tensor)[0] + output_tensors.append( + array_ops.reshape(tensor, shape=(batch_size, num_elements))) + if cols_to_vars is not None: + # Retrieve any variables created (some _DenseColumn's don't create + # variables, in which case an empty list is returned). + cols_to_vars[column] = ops.get_collection( + ops.GraphKeys.GLOBAL_VARIABLES, + scope=variable_scope.get_variable_scope().name) + _verify_static_batch_size_equality(output_tensors, ordered_columns) + return array_ops.concat(output_tensors, 1) + + +def input_layer(features, + feature_columns, + weight_collections=None, + trainable=True, + cols_to_vars=None): + """Returns a dense `Tensor` as input layer based on given `feature_columns`. + + Generally a single example in training data is described with FeatureColumns. + At the first layer of the model, this column oriented data should be converted + to a single `Tensor`. + + Example: + + ```python + price = numeric_column('price') + keywords_embedded = embedding_column( + categorical_column_with_hash_bucket("keywords", 10K), dimensions=16) + columns = [price, keywords_embedded, ...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + dense_tensor = input_layer(features, columns) + for units in [128, 64, 32]: + dense_tensor = tf.layers.dense(dense_tensor, units, tf.nn.relu) + prediction = tf.layers.dense(dense_tensor, 1) + ``` + + Args: + features: A mapping from key to tensors. `_FeatureColumn`s look up via these + keys. For example `numeric_column('price')` will look at 'price' key in + this dict. Values can be a `SparseTensor` or a `Tensor` depends on + corresponding `_FeatureColumn`. + feature_columns: An iterable containing the FeatureColumns to use as inputs + to your model. All items should be instances of classes derived from + `_DenseColumn` such as `numeric_column`, `embedding_column`, + `bucketized_column`, `indicator_column`. If you have categorical features, + you can wrap them with an `embedding_column` or `indicator_column`. + weight_collections: A list of collection names to which the Variable will be + added. Note that variables will also be added to collections + `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`. + trainable: If `True` also add the variable to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + cols_to_vars: If not `None`, must be a dictionary that will be filled with a + mapping from `_FeatureColumn` to list of `Variable`s. For example, after + the call, we might have cols_to_vars = + {_EmbeddingColumn( + categorical_column=_HashedCategoricalColumn( + key='sparse_feature', hash_bucket_size=5, dtype=tf.string), + dimension=10): [], + 'bias': [], + _NumericColumn( + key='numeric_feature2', shape=(2,)): + []} + If a column creates no variables, its value will be an empty list. Note + that cols_to_vars will also contain a string key 'bias' that maps to a + list of Variables. + + Returns: + A `Tensor` which represents predictions/logits of a linear model. Its shape + is (batch_size, units) and its dtype is `float32`. + + Raises: + ValueError: if an item in `feature_columns` is neither a `_DenseColumn` + nor `_CategoricalColumn`. + """ + with variable_scope.variable_scope(None, 'linear_model') as vs: + model_name = _strip_leading_slashes(vs.name) + linear_model_layer = _LinearModel( + feature_columns=feature_columns, + units=units, + sparse_combiner=sparse_combiner, + weight_collections=weight_collections, + trainable=trainable, + name=model_name) + retval = linear_model_layer(features) # pylint: disable=not-callable + if cols_to_vars is not None: + cols_to_vars.update(linear_model_layer.cols_to_vars()) + return retval + + +def _add_to_collections(var, weight_collections): + """Adds a var to the list of weight_collections provided. + + Handles the case for partitioned and non-partitioned variables. + + Args: + var: A variable or Partitioned Variable. + weight_collections: List of collections to add variable to. + """ + for weight_collection in weight_collections: + # The layer self.add_variable call already adds it to GLOBAL_VARIABLES. + if weight_collection == ops.GraphKeys.GLOBAL_VARIABLES: + continue + # TODO(rohanj): Explore adding a _get_variable_list method on `Variable` + # so that we don't have to do this check. + if isinstance(var, variables.PartitionedVariable): + for constituent_var in list(var): + ops.add_to_collection(weight_collection, constituent_var) + else: + ops.add_to_collection(weight_collection, var) + + +class _FCLinearWrapper(base.Layer): + """Wraps a _FeatureColumn in a layer for use in a linear model. + + See `linear_model` above. + """ + + def __init__(self, + feature_column, + units=1, + sparse_combiner='sum', + weight_collections=None, + trainable=True, + name=None, + **kwargs): + super(_FCLinearWrapper, self).__init__( + trainable=trainable, name=name, **kwargs) + self._feature_column = feature_column + self._units = units + self._sparse_combiner = sparse_combiner + self._weight_collections = weight_collections + + def build(self, _): + if isinstance(self._feature_column, fc_old._CategoricalColumn): # pylint: disable=protected-access + weight = self.add_variable( + name='weights', + shape=(self._feature_column._num_buckets, self._units), # pylint: disable=protected-access + initializer=init_ops.zeros_initializer(), + trainable=self.trainable) + else: + num_elements = self._feature_column._variable_shape.num_elements() # pylint: disable=protected-access + weight = self.add_variable( + name='weights', + shape=[num_elements, self._units], + initializer=init_ops.zeros_initializer(), + trainable=self.trainable) + _add_to_collections(weight, self._weight_collections) + self._weight_var = weight + self.built = True + + def call(self, builder): + weighted_sum = fc_old._create_weighted_sum( # pylint: disable=protected-access + column=self._feature_column, + builder=builder, + units=self._units, + sparse_combiner=self._sparse_combiner, + weight_collections=self._weight_collections, + trainable=self.trainable, + weight_var=self._weight_var) + return weighted_sum + + +class _BiasLayer(base.Layer): + """A layer for the bias term. + """ + + def __init__(self, + units=1, + trainable=True, + weight_collections=None, + name=None, + **kwargs): + super(_BiasLayer, self).__init__(trainable=trainable, name=name, **kwargs) + self._units = units + self._weight_collections = weight_collections + + def build(self, _): + self._bias_variable = self.add_variable( + 'bias_weights', + shape=[self._units], + initializer=init_ops.zeros_initializer(), + trainable=self.trainable) + _add_to_collections(self._bias_variable, self._weight_collections) + self.built = True + + def call(self, _): + return self._bias_variable + + +def _get_expanded_variable_list(variable): + if (isinstance(variable, variables.Variable) or + resource_variable_ops.is_resource_variable(variable)): + return [variable] # Single variable case. + else: # Must be a PartitionedVariable, so convert into a list. + return list(variable) + + +def _strip_leading_slashes(name): + return name.rsplit('/', 1)[-1] + + +class _LinearModel(training.Model): + """Creates a linear model using feature columns. + + See `linear_model` for details. + """ + + def __init__(self, + feature_columns, + units=1, + sparse_combiner='sum', + weight_collections=None, + trainable=True, + name=None, + **kwargs): + super(_LinearModel, self).__init__(name=name, **kwargs) + self._feature_columns = fc_old._normalize_feature_columns( # pylint: disable=protected-access + feature_columns) + self._weight_collections = list(weight_collections or []) + if ops.GraphKeys.GLOBAL_VARIABLES not in self._weight_collections: + self._weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES) + if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections: + self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES) + + column_layers = {} + for column in sorted(self._feature_columns, key=lambda x: x.name): + with variable_scope.variable_scope( + None, default_name=column._var_scope_name) as vs: # pylint: disable=protected-access + # Having the fully expressed variable scope name ends up doubly + # expressing the outer scope (scope with which this method was called) + # in the name of the variable that would get created. + column_name = _strip_leading_slashes(vs.name) + column_layer = _FCLinearWrapper(column, units, sparse_combiner, + self._weight_collections, trainable, + column_name, **kwargs) + column_layers[column_name] = column_layer + self._column_layers = self._add_layers(column_layers) + self._bias_layer = _BiasLayer( + units=units, + trainable=trainable, + weight_collections=self._weight_collections, + name='bias_layer', + **kwargs) + self._cols_to_vars = {} + + def cols_to_vars(self): + """Returns a dict mapping _FeatureColumns to variables. + + See `linear_model` for more information. + This is not populated till `call` is called i.e. layer is built. + """ + return self._cols_to_vars + + def call(self, features): + with variable_scope.variable_scope(self.name): + for column in self._feature_columns: + if not isinstance( + column, + ( + fc_old._DenseColumn, # pylint: disable=protected-access + fc_old._CategoricalColumn)): # pylint: disable=protected-access + raise ValueError( + 'Items of feature_columns must be either a ' + '_DenseColumn or _CategoricalColumn. Given: {}'.format(column)) + weighted_sums = [] + ordered_columns = [] + builder = fc_old._LazyBuilder(features) # pylint: disable=protected-access + for layer in sorted(self._column_layers.values(), key=lambda x: x.name): + column = layer._feature_column # pylint: disable=protected-access + ordered_columns.append(column) + weighted_sum = layer(builder) + weighted_sums.append(weighted_sum) + self._cols_to_vars[column] = ops.get_collection( + ops.GraphKeys.GLOBAL_VARIABLES, scope=layer.scope_name) + + _verify_static_batch_size_equality(weighted_sums, ordered_columns) + predictions_no_bias = math_ops.add_n( + weighted_sums, name='weighted_sum_no_bias') + predictions = nn_ops.bias_add( + predictions_no_bias, + self._bias_layer( # pylint: disable=not-callable + builder, + scope=variable_scope.get_variable_scope()), # pylint: disable=not-callable + name='weighted_sum') + bias = self._bias_layer.variables[0] + self._cols_to_vars['bias'] = _get_expanded_variable_list(bias) + return predictions + + def _add_layers(self, layers): + # "Magic" required for keras.Model classes to track all the variables in + # a list of layers.Layer objects. + # TODO(ashankar): Figure out API so user code doesn't have to do this. + for name, layer in layers.items(): + setattr(self, 'layer-%s' % name, layer) + return layers + + +def _transform_features(features, feature_columns, state_manager): + """Returns transformed features based on features columns passed in. + + Please note that most probably you would not need to use this function. Please + check `input_layer` and `linear_model` to see whether they will + satisfy your use case or not. + + Example: + + ```python + # Define features and transformations + crosses_a_x_b = crossed_column( + columns=["sparse_feature_a", "sparse_feature_b"], hash_bucket_size=10000) + price_buckets = bucketized_column( + source_column=numeric_column("price"), boundaries=[...]) + + columns = [crosses_a_x_b, price_buckets] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + transformed = transform_features(features=features, feature_columns=columns) + + assertCountEqual(columns, transformed.keys()) + ``` + + Args: + features: A mapping from key to tensors. `FeatureColumn`s look up via these + keys. For example `numeric_column('price')` will look at 'price' key in + this dict. Values can be a `SparseTensor` or a `Tensor` depends on + corresponding `FeatureColumn`. + feature_columns: An iterable containing all the `FeatureColumn`s. + state_manager: A StateManager object that holds the FeatureColumn state. + + Returns: + A `dict` mapping `FeatureColumn` to `Tensor` and `SparseTensor` values. + """ + feature_columns = _normalize_feature_columns(feature_columns) + outputs = {} + with ops.name_scope( + None, default_name='transform_features', values=features.values()): + transformation_cache = FeatureTransformationCache(features) + for column in sorted(feature_columns, key=lambda x: x.name): + with ops.name_scope(None, default_name=column.name): + outputs[column] = transformation_cache.get(column, state_manager) + return outputs + + +def make_parse_example_spec(feature_columns): + """Creates parsing spec dictionary from input feature_columns. + + The returned dictionary can be used as arg 'features' in `tf.parse_example`. + + Typical usage example: + + ```python + # Define features and transformations + feature_a = categorical_column_with_vocabulary_file(...) + feature_b = numeric_column(...) + feature_c_bucketized = bucketized_column(numeric_column("feature_c"), ...) + feature_a_x_feature_c = crossed_column( + columns=["feature_a", feature_c_bucketized], ...) + + feature_columns = set( + [feature_b, feature_c_bucketized, feature_a_x_feature_c]) + features = tf.parse_example( + serialized=serialized_examples, + features=make_parse_example_spec(feature_columns)) + ``` + + For the above example, make_parse_example_spec would return the dict: + + ```python + { + "feature_a": parsing_ops.VarLenFeature(tf.string), + "feature_b": parsing_ops.FixedLenFeature([1], dtype=tf.float32), + "feature_c": parsing_ops.FixedLenFeature([1], dtype=tf.float32) + } + ``` + + Args: + feature_columns: An iterable containing all feature columns. All items + should be instances of classes derived from `FeatureColumn`. + + Returns: + A dict mapping each feature key to a `FixedLenFeature` or `VarLenFeature` + value. + + Raises: + ValueError: If any of the given `feature_columns` is not a `FeatureColumn` + instance. + """ + result = {} + for column in feature_columns: + if not isinstance(column, FeatureColumn): + raise ValueError('All feature_columns must be FeatureColumn instances. ' + 'Given: {}'.format(column)) + config = column.parse_example_spec + for key, value in six.iteritems(config): + if key in result and value != result[key]: + raise ValueError( + 'feature_columns contain different parse_spec for key ' + '{}. Given {} and {}'.format(key, value, result[key])) + result.update(config) + return result + + +def embedding_column( + categorical_column, dimension, combiner='mean', initializer=None, + ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None, + trainable=True): + """`_DenseColumn` that converts from sparse, categorical input. + + Use this when your inputs are sparse, but you want to convert them to a dense + representation (e.g., to feed to a DNN). + + Inputs must be a `_CategoricalColumn` created by any of the + `categorical_column_*` function. Here is an example of using + `embedding_column` with `DNNClassifier`: + + ```python + video_id = categorical_column_with_identity( + key='video_id', num_buckets=1000000, default_value=0) + columns = [embedding_column(video_id, 9),...] + + estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...) + + label_column = ... + def input_fn(): + features = tf.parse_example( + ..., features=make_parse_example_spec(columns + [label_column])) + labels = features.pop(label_column.name) + return features, labels + + estimator.train(input_fn=input_fn, steps=100) + ``` + + Here is an example using `embedding_column` with model_fn: + + ```python + def model_fn(features, ...): + video_id = categorical_column_with_identity( + key='video_id', num_buckets=1000000, default_value=0) + columns = [embedding_column(video_id, 9),...] + dense_tensor = input_layer(features, columns) + # Form DNN layers, calculate loss, and return EstimatorSpec. + ... + ``` + + Args: + categorical_column: A `_CategoricalColumn` created by a + `categorical_column_with_*` function. This column produces the sparse IDs + that are inputs to the embedding lookup. + dimension: An integer specifying dimension of the embedding, must be > 0. + combiner: A string specifying how to reduce if there are multiple entries + in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with + 'mean' the default. 'sqrtn' often achieves good accuracy, in particular + with bag-of-words columns. Each of this can be thought as example level + normalizations on the column. For more information, see + `tf.embedding_lookup_sparse`. + initializer: A variable initializer function to be used in embedding + variable initialization. If not specified, defaults to + `tf.truncated_normal_initializer` with mean `0.0` and standard deviation + `1/sqrt(dimension)`. + ckpt_to_load_from: String representing checkpoint name/pattern from which to + restore column weights. Required if `tensor_name_in_ckpt` is not `None`. + tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from + which to restore the column weights. Required if `ckpt_to_load_from` is + not `None`. + max_norm: If not `None`, embedding values are l2-normalized to this value. + trainable: Whether or not the embedding is trainable. Default is True. + + Returns: + `_DenseColumn` that converts from sparse input. + + Raises: + ValueError: if `dimension` not > 0. + ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt` + is specified. + ValueError: if `initializer` is specified and is not callable. + RuntimeError: If eager execution is enabled. + """ + if (dimension is None) or (dimension < 1): + raise ValueError('Invalid dimension {}.'.format(dimension)) + if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None): + raise ValueError('Must specify both `ckpt_to_load_from` and ' + '`tensor_name_in_ckpt` or none of them.') + + if (initializer is not None) and (not callable(initializer)): + raise ValueError('initializer must be callable if specified. ' + 'Embedding of column_name: {}'.format( + categorical_column.name)) + if initializer is None: + initializer = init_ops.truncated_normal_initializer( + mean=0.0, stddev=1 / math.sqrt(dimension)) + + return EmbeddingColumn( + categorical_column=categorical_column, + dimension=dimension, + combiner=combiner, + initializer=initializer, + ckpt_to_load_from=ckpt_to_load_from, + tensor_name_in_ckpt=tensor_name_in_ckpt, + max_norm=max_norm, + trainable=trainable) + + +def shared_embedding_columns( + categorical_columns, dimension, combiner='mean', initializer=None, + shared_embedding_collection_name=None, ckpt_to_load_from=None, + tensor_name_in_ckpt=None, max_norm=None, trainable=True): + """List of dense columns that convert from sparse, categorical input. + + This is similar to `embedding_column`, except that it produces a list of + embedding columns that share the same embedding weights. + + Use this when your inputs are sparse and of the same type (e.g. watched and + impression video IDs that share the same vocabulary), and you want to convert + them to a dense representation (e.g., to feed to a DNN). + + Inputs must be a list of categorical columns created by any of the + `categorical_column_*` function. They must all be of the same type and have + the same arguments except `key`. E.g. they can be + categorical_column_with_vocabulary_file with the same vocabulary_file. Some or + all columns could also be weighted_categorical_column. + + Here is an example embedding of two features for a DNNClassifier model: + + ```python + watched_video_id = categorical_column_with_vocabulary_file( + 'watched_video_id', video_vocabulary_file, video_vocabulary_size) + impression_video_id = categorical_column_with_vocabulary_file( + 'impression_video_id', video_vocabulary_file, video_vocabulary_size) + columns = shared_embedding_columns( + [watched_video_id, impression_video_id], dimension=10) + + estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...) + + label_column = ... + def input_fn(): + features = tf.parse_example( + ..., features=make_parse_example_spec(columns + [label_column])) + labels = features.pop(label_column.name) + return features, labels + + estimator.train(input_fn=input_fn, steps=100) + ``` + + Here is an example using `shared_embedding_columns` with model_fn: + + ```python + def model_fn(features, ...): + watched_video_id = categorical_column_with_vocabulary_file( + 'watched_video_id', video_vocabulary_file, video_vocabulary_size) + impression_video_id = categorical_column_with_vocabulary_file( + 'impression_video_id', video_vocabulary_file, video_vocabulary_size) + columns = shared_embedding_columns( + [watched_video_id, impression_video_id], dimension=10) + dense_tensor = input_layer(features, columns) + # Form DNN layers, calculate loss, and return EstimatorSpec. + ... + ``` + + Args: + categorical_columns: List of categorical columns created by a + `categorical_column_with_*` function. These columns produce the sparse IDs + that are inputs to the embedding lookup. All columns must be of the same + type and have the same arguments except `key`. E.g. they can be + categorical_column_with_vocabulary_file with the same vocabulary_file. + Some or all columns could also be weighted_categorical_column. + dimension: An integer specifying dimension of the embedding, must be > 0. + combiner: A string specifying how to reduce if there are multiple entries + in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with + 'mean' the default. 'sqrtn' often achieves good accuracy, in particular + with bag-of-words columns. Each of this can be thought as example level + normalizations on the column. For more information, see + `tf.embedding_lookup_sparse`. + initializer: A variable initializer function to be used in embedding + variable initialization. If not specified, defaults to + `tf.truncated_normal_initializer` with mean `0.0` and standard deviation + `1/sqrt(dimension)`. + shared_embedding_collection_name: Optional collective name of these columns. + If not given, a reasonable name will be chosen based on the names of + `categorical_columns`. + ckpt_to_load_from: String representing checkpoint name/pattern from which to + restore column weights. Required if `tensor_name_in_ckpt` is not `None`. + tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from + which to restore the column weights. Required if `ckpt_to_load_from` is + not `None`. + max_norm: If not `None`, each embedding is clipped if its l2-norm is + larger than this value, before combining. + trainable: Whether or not the embedding is trainable. Default is True. + + Returns: + A list of dense columns that converts from sparse input. The order of + results follows the ordering of `categorical_columns`. + + Raises: + ValueError: if `dimension` not > 0. + ValueError: if any of the given `categorical_columns` is of different type + or has different arguments than the others. + ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt` + is specified. + ValueError: if `initializer` is specified and is not callable. + RuntimeError: if eager execution is enabled. + """ + if context.executing_eagerly(): + raise RuntimeError('shared_embedding_columns are not supported when eager ' + 'execution is enabled.') + + if (dimension is None) or (dimension < 1): + raise ValueError('Invalid dimension {}.'.format(dimension)) + if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None): + raise ValueError('Must specify both `ckpt_to_load_from` and ' + '`tensor_name_in_ckpt` or none of them.') + + if (initializer is not None) and (not callable(initializer)): + raise ValueError('initializer must be callable if specified.') + if initializer is None: + initializer = init_ops.truncated_normal_initializer( + mean=0.0, stddev=1. / math.sqrt(dimension)) + + # Sort the columns so the default collection name is deterministic even if the + # user passes columns from an unsorted collection, such as dict.values(). + sorted_columns = sorted(categorical_columns, key=lambda x: x.name) + + c0 = sorted_columns[0] + num_buckets = c0.num_buckets + if not isinstance(c0, CategoricalColumn): + raise ValueError( + 'All categorical_columns must be subclasses of CategoricalColumn. ' + 'Given: {}, of type: {}'.format(c0, type(c0))) + if isinstance(c0, WeightedCategoricalColumn): + c0 = c0.categorical_column + for c in sorted_columns[1:]: + if isinstance(c, WeightedCategoricalColumn): + c = c.categorical_column + if not isinstance(c, type(c0)): + raise ValueError( + 'To use shared_embedding_column, all categorical_columns must have ' + 'the same type, or be weighted_categorical_column of the same type. ' + 'Given column: {} of type: {} does not match given column: {} of ' + 'type: {}'.format(c0, type(c0), c, type(c))) + if num_buckets != c.num_buckets: + raise ValueError( + 'To use shared_embedding_column, all categorical_columns must have ' + 'the same number of buckets. Given column: {} with buckets: {} does ' + 'not match column: {} with buckets: {}'.format( + c0, num_buckets, c, c.num_buckets)) + + if not shared_embedding_collection_name: + shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns) + shared_embedding_collection_name += '_shared_embedding' + + result = [] + for column in categorical_columns: + result.append( + SharedEmbeddingColumn( + categorical_column=column, + initializer=initializer, + dimension=dimension, + combiner=combiner, + shared_embedding_collection_name=shared_embedding_collection_name, + ckpt_to_load_from=ckpt_to_load_from, + tensor_name_in_ckpt=tensor_name_in_ckpt, + max_norm=max_norm, + trainable=trainable)) + + return result + + +def numeric_column(key, + shape=(1,), + default_value=None, + dtype=dtypes.float32, + normalizer_fn=None): + """Represents real valued or numerical features. + + Example: + + ```python + price = numeric_column('price') + columns = [price, ...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + dense_tensor = input_layer(features, columns) + + # or + bucketized_price = bucketized_column(price, boundaries=[...]) + columns = [bucketized_price, ...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + linear_prediction = linear_model(features, columns) + ``` + + Args: + key: A unique string identifying the input feature. It is used as the + column name and the dictionary key for feature parsing configs, feature + `Tensor` objects, and feature columns. + shape: An iterable of integers specifies the shape of the `Tensor`. An + integer can be given which means a single dimension `Tensor` with given + width. The `Tensor` representing the column will have the shape of + [batch_size] + `shape`. + default_value: A single value compatible with `dtype` or an iterable of + values compatible with `dtype` which the column takes on during + `tf.Example` parsing if data is missing. A default value of `None` will + cause `tf.parse_example` to fail if an example does not contain this + column. If a single value is provided, the same value will be applied as + the default value for every item. If an iterable of values is provided, + the shape of the `default_value` should be equal to the given `shape`. + dtype: defines the type of values. Default value is `tf.float32`. Must be a + non-quantized, real integer or floating point type. + normalizer_fn: If not `None`, a function that can be used to normalize the + value of the tensor after `default_value` is applied for parsing. + Normalizer function takes the input `Tensor` as its argument, and returns + the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that + even though the most common use case of this function is normalization, it + can be used for any kind of Tensorflow transformations. + + Returns: + A `NumericColumn`. + + Raises: + TypeError: if any dimension in shape is not an int + ValueError: if any dimension in shape is not a positive integer + TypeError: if `default_value` is an iterable but not compatible with `shape` + TypeError: if `default_value` is not compatible with `dtype`. + ValueError: if `dtype` is not convertible to `tf.float32`. + """ + shape = _check_shape(shape, key) + if not (dtype.is_integer or dtype.is_floating): + raise ValueError('dtype must be convertible to float. ' + 'dtype: {}, key: {}'.format(dtype, key)) + default_value = _check_default_value(shape, default_value, dtype, key) + + if normalizer_fn is not None and not callable(normalizer_fn): + raise TypeError( + 'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn)) + + _assert_key_is_string(key) + return NumericColumn( + key, + shape=shape, + default_value=default_value, + dtype=dtype, + normalizer_fn=normalizer_fn) + + +def bucketized_column(source_column, boundaries): + """Represents discretized dense input. + + Buckets include the left boundary, and exclude the right boundary. Namely, + `boundaries=[0., 1., 2.]` generates buckets `(-inf, 0.)`, `[0., 1.)`, + `[1., 2.)`, and `[2., +inf)`. + + For example, if the inputs are + + ```python + boundaries = [0, 10, 100] + input tensor = [[-5, 10000] + [150, 10] + [5, 100]] + ``` + + then the output will be + + ```python + output = [[0, 3] + [3, 2] + [1, 3]] + ``` + + Example: + + ```python + price = numeric_column('price') + bucketized_price = bucketized_column(price, boundaries=[...]) + columns = [bucketized_price, ...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + linear_prediction = linear_model(features, columns) + + # or + columns = [bucketized_price, ...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + dense_tensor = input_layer(features, columns) + ``` + + `bucketized_column` can also be crossed with another categorical column using + `crossed_column`: + + ```python + price = numeric_column('price') + # bucketized_column converts numerical feature to a categorical one. + bucketized_price = bucketized_column(price, boundaries=[...]) + # 'keywords' is a string feature. + price_x_keywords = crossed_column([bucketized_price, 'keywords'], 50K) + columns = [price_x_keywords, ...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + linear_prediction = linear_model(features, columns) + ``` + + Args: + source_column: A one-dimensional dense column which is generated with + `numeric_column`. + boundaries: A sorted list or tuple of floats specifying the boundaries. + + Returns: + A `BucketizedColumn`. + + Raises: + ValueError: If `source_column` is not a numeric column, or if it is not + one-dimensional. + ValueError: If `boundaries` is not a sorted list or tuple. + """ + if not isinstance(source_column, NumericColumn): + raise ValueError( + 'source_column must be a column generated with numeric_column(). ' + 'Given: {}'.format(source_column)) + if len(source_column.shape) > 1: + raise ValueError( + 'source_column must be one-dimensional column. ' + 'Given: {}'.format(source_column)) + if (not boundaries or + not (isinstance(boundaries, list) or isinstance(boundaries, tuple))): + raise ValueError('boundaries must be a sorted list.') + for i in range(len(boundaries) - 1): + if boundaries[i] >= boundaries[i + 1]: + raise ValueError('boundaries must be a sorted list.') + return BucketizedColumn(source_column, tuple(boundaries)) + + +def _assert_string_or_int(dtype, prefix): + if (dtype != dtypes.string) and (not dtype.is_integer): + raise ValueError( + '{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype)) + + +def _assert_key_is_string(key): + if not isinstance(key, six.string_types): + raise ValueError( + 'key must be a string. Got: type {}. Given key: {}.'.format( + type(key), key)) + + +def categorical_column_with_hash_bucket(key, + hash_bucket_size, + dtype=dtypes.string): + """Represents sparse feature where ids are set by hashing. + + Use this when your sparse features are in string or integer format, and you + want to distribute your inputs into a finite number of buckets by hashing. + output_id = Hash(input_feature_string) % bucket_size for string type input. + For int type input, the value is converted to its string representation first + and then hashed by the same formula. + + For input dictionary `features`, `features[key]` is either `Tensor` or + `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int + and `''` for string, which will be dropped by this feature column. + + Example: + + ```python + keywords = categorical_column_with_hash_bucket("keywords", 10K) + columns = [keywords, ...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + linear_prediction = linear_model(features, columns) + + # or + keywords_embedded = embedding_column(keywords, 16) + columns = [keywords_embedded, ...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + dense_tensor = input_layer(features, columns) + ``` + + Args: + key: A unique string identifying the input feature. It is used as the + column name and the dictionary key for feature parsing configs, feature + `Tensor` objects, and feature columns. + hash_bucket_size: An int > 1. The number of buckets. + dtype: The type of features. Only string and integer types are supported. + + Returns: + A `HashedCategoricalColumn`. + + Raises: + ValueError: `hash_bucket_size` is not greater than 1. + ValueError: `dtype` is neither string nor integer. + """ + if hash_bucket_size is None: + raise ValueError('hash_bucket_size must be set. ' 'key: {}'.format(key)) + + if hash_bucket_size < 1: + raise ValueError('hash_bucket_size must be at least 1. ' + 'hash_bucket_size: {}, key: {}'.format( + hash_bucket_size, key)) + + _assert_key_is_string(key) + _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) + + return HashedCategoricalColumn(key, hash_bucket_size, dtype) + + +def categorical_column_with_vocabulary_file(key, + vocabulary_file, + vocabulary_size=None, + num_oov_buckets=0, + default_value=None, + dtype=dtypes.string): + """A `CategoricalColumn` with a vocabulary file. + + Use this when your inputs are in string or integer format, and you have a + vocabulary file that maps each value to an integer ID. By default, + out-of-vocabulary values are ignored. Use either (but not both) of + `num_oov_buckets` and `default_value` to specify how to include + out-of-vocabulary values. + + For input dictionary `features`, `features[key]` is either `Tensor` or + `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int + and `''` for string, which will be dropped by this feature column. + + Example with `num_oov_buckets`: + File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state + abbreviation. All inputs with values in that file are assigned an ID 0-49, + corresponding to its line number. All other values are hashed and assigned an + ID 50-54. + + ```python + states = categorical_column_with_vocabulary_file( + key='states', vocabulary_file='/us/states.txt', vocabulary_size=50, + num_oov_buckets=5) + columns = [states, ...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + linear_prediction = linear_model(features, columns) + ``` + + Example with `default_value`: + File '/us/states.txt' contains 51 lines - the first line is 'XX', and the + other 50 each have a 2-character U.S. state abbreviation. Both a literal 'XX' + in input, and other values missing from the file, will be assigned ID 0. All + others are assigned the corresponding line number 1-50. + + ```python + states = categorical_column_with_vocabulary_file( + key='states', vocabulary_file='/us/states.txt', vocabulary_size=51, + default_value=0) + columns = [states, ...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + linear_prediction, _, _ = linear_model(features, columns) + ``` + + And to make an embedding with either: + + ```python + columns = [embedding_column(states, 3),...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + dense_tensor = input_layer(features, columns) + ``` + + Args: + key: A unique string identifying the input feature. It is used as the + column name and the dictionary key for feature parsing configs, feature + `Tensor` objects, and feature columns. + vocabulary_file: The vocabulary file name. + vocabulary_size: Number of the elements in the vocabulary. This must be no + greater than length of `vocabulary_file`, if less than length, later + values are ignored. If None, it is set to the length of `vocabulary_file`. + num_oov_buckets: Non-negative integer, the number of out-of-vocabulary + buckets. All out-of-vocabulary inputs will be assigned IDs in the range + `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of + the input value. A positive `num_oov_buckets` can not be specified with + `default_value`. + default_value: The integer ID value to return for out-of-vocabulary feature + values, defaults to `-1`. This can not be specified with a positive + `num_oov_buckets`. + dtype: The type of features. Only string and integer types are supported. + + Returns: + A `CategoricalColumn` with a vocabulary file. + + Raises: + ValueError: `vocabulary_file` is missing or cannot be opened. + ValueError: `vocabulary_size` is missing or < 1. + ValueError: `num_oov_buckets` is a negative integer. + ValueError: `num_oov_buckets` and `default_value` are both specified. + ValueError: `dtype` is neither string nor integer. + """ + if not vocabulary_file: + raise ValueError('Missing vocabulary_file in {}.'.format(key)) + + if vocabulary_size is None: + if not gfile.Exists(vocabulary_file): + raise ValueError('vocabulary_file in {} does not exist.'.format(key)) + + with gfile.GFile(vocabulary_file) as f: + vocabulary_size = sum(1 for _ in f) + logging.info( + 'vocabulary_size = %d in %s is inferred from the number of elements ' + 'in the vocabulary_file %s.', vocabulary_size, key, vocabulary_file) + + # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`. + if vocabulary_size < 1: + raise ValueError('Invalid vocabulary_size in {}.'.format(key)) + if num_oov_buckets: + if default_value is not None: + raise ValueError( + 'Can\'t specify both num_oov_buckets and default_value in {}.'.format( + key)) + if num_oov_buckets < 0: + raise ValueError('Invalid num_oov_buckets {} in {}.'.format( + num_oov_buckets, key)) + _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) + _assert_key_is_string(key) + return VocabularyFileCategoricalColumn( + key=key, + vocabulary_file=vocabulary_file, + vocabulary_size=vocabulary_size, + num_oov_buckets=0 if num_oov_buckets is None else num_oov_buckets, + default_value=-1 if default_value is None else default_value, + dtype=dtype) + + +def categorical_column_with_vocabulary_list( + key, vocabulary_list, dtype=None, default_value=-1, num_oov_buckets=0): + """A `_CategoricalColumn` with in-memory vocabulary. + + Use this when your inputs are in string or integer format, and you have an + in-memory vocabulary mapping each value to an integer ID. By default, + out-of-vocabulary values are ignored. Use either (but not both) of + `num_oov_buckets` and `default_value` to specify how to include + out-of-vocabulary values. + + For input dictionary `features`, `features[key]` is either `Tensor` or + `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int + and `''` for string, which will be dropped by this feature column. + + Example with `num_oov_buckets`: + In the following example, each input in `vocabulary_list` is assigned an ID + 0-3 corresponding to its index (e.g., input 'B' produces output 2). All other + inputs are hashed and assigned an ID 4-5. + + ```python + colors = categorical_column_with_vocabulary_list( + key='colors', vocabulary_list=('R', 'G', 'B', 'Y'), + num_oov_buckets=2) + columns = [colors, ...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + linear_prediction, _, _ = linear_model(features, columns) + ``` + + Example with `default_value`: + In the following example, each input in `vocabulary_list` is assigned an ID + 0-4 corresponding to its index (e.g., input 'B' produces output 3). All other + inputs are assigned `default_value` 0. + + + ```python + colors = categorical_column_with_vocabulary_list( + key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0) + columns = [colors, ...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + linear_prediction, _, _ = linear_model(features, columns) + ``` + + And to make an embedding with either: + + ```python + columns = [embedding_column(colors, 3),...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + dense_tensor = input_layer(features, columns) + ``` + + Args: + key: A unique string identifying the input feature. It is used as the + column name and the dictionary key for feature parsing configs, feature + `Tensor` objects, and feature columns. + vocabulary_list: An ordered iterable defining the vocabulary. Each feature + is mapped to the index of its value (if present) in `vocabulary_list`. + Must be castable to `dtype`. + dtype: The type of features. Only string and integer types are supported. + If `None`, it will be inferred from `vocabulary_list`. + default_value: The integer ID value to return for out-of-vocabulary feature + values, defaults to `-1`. This can not be specified with a positive + `num_oov_buckets`. + num_oov_buckets: Non-negative integer, the number of out-of-vocabulary + buckets. All out-of-vocabulary inputs will be assigned IDs in the range + `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a + hash of the input value. A positive `num_oov_buckets` can not be specified + with `default_value`. + + Returns: + A `CategoricalColumn` with in-memory vocabulary. + + Raises: + ValueError: if `vocabulary_list` is empty, or contains duplicate keys. + ValueError: `num_oov_buckets` is a negative integer. + ValueError: `num_oov_buckets` and `default_value` are both specified. + ValueError: if `dtype` is not integer or string. + """ + if (vocabulary_list is None) or (len(vocabulary_list) < 1): + raise ValueError( + 'vocabulary_list {} must be non-empty, column_name: {}'.format( + vocabulary_list, key)) + if len(set(vocabulary_list)) != len(vocabulary_list): + raise ValueError( + 'Duplicate keys in vocabulary_list {}, column_name: {}'.format( + vocabulary_list, key)) + vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype) + if num_oov_buckets: + if default_value != -1: + raise ValueError( + 'Can\'t specify both num_oov_buckets and default_value in {}.'.format( + key)) + if num_oov_buckets < 0: + raise ValueError('Invalid num_oov_buckets {} in {}.'.format( + num_oov_buckets, key)) + _assert_string_or_int( + vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key)) + if dtype is None: + dtype = vocabulary_dtype + elif dtype.is_integer != vocabulary_dtype.is_integer: + raise ValueError( + 'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format( + dtype, vocabulary_dtype, key)) + _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) + _assert_key_is_string(key) + + return VocabularyListCategoricalColumn( + key=key, + vocabulary_list=tuple(vocabulary_list), + dtype=dtype, + default_value=default_value, + num_oov_buckets=num_oov_buckets) + + +def categorical_column_with_identity(key, num_buckets, default_value=None): + """A `CategoricalColumn` that returns identity values. + + Use this when your inputs are integers in the range `[0, num_buckets)`, and + you want to use the input value itself as the categorical ID. Values outside + this range will result in `default_value` if specified, otherwise it will + fail. + + Typically, this is used for contiguous ranges of integer indexes, but + it doesn't have to be. This might be inefficient, however, if many of IDs + are unused. Consider `categorical_column_with_hash_bucket` in that case. + + For input dictionary `features`, `features[key]` is either `Tensor` or + `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int + and `''` for string, which will be dropped by this feature column. + + In the following examples, each input in the range `[0, 1000000)` is assigned + the same value. All other inputs are assigned `default_value` 0. Note that a + literal 0 in inputs will result in the same default ID. + + Linear model: + + ```python + video_id = categorical_column_with_identity( + key='video_id', num_buckets=1000000, default_value=0) + columns = [video_id, ...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + linear_prediction, _, _ = linear_model(features, columns) + ``` + + Embedding for a DNN model: + + ```python + columns = [embedding_column(video_id, 9),...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + dense_tensor = input_layer(features, columns) + ``` + + Args: + key: A unique string identifying the input feature. It is used as the + column name and the dictionary key for feature parsing configs, feature + `Tensor` objects, and feature columns. + num_buckets: Range of inputs and outputs is `[0, num_buckets)`. + default_value: If `None`, this column's graph operations will fail for + out-of-range inputs. Otherwise, this value must be in the range + `[0, num_buckets)`, and will replace inputs in that range. + + Returns: + A `CategoricalColumn` that returns identity values. + + Raises: + ValueError: if `num_buckets` is less than one. + ValueError: if `default_value` is not in range `[0, num_buckets)`. + """ + if num_buckets < 1: + raise ValueError( + 'num_buckets {} < 1, column_name {}'.format(num_buckets, key)) + if (default_value is not None) and ( + (default_value < 0) or (default_value >= num_buckets)): + raise ValueError( + 'default_value {} not in range [0, {}), column_name {}'.format( + default_value, num_buckets, key)) + _assert_key_is_string(key) + return IdentityCategoricalColumn( + key=key, number_buckets=num_buckets, default_value=default_value) + + +def indicator_column(categorical_column): + """Represents multi-hot representation of given categorical column. + + - For DNN model, `indicator_column` can be used to wrap any + `categorical_column_*` (e.g., to feed to DNN). Consider to Use + `embedding_column` if the number of buckets/unique(values) are large. + + - For Wide (aka linear) model, `indicator_column` is the internal + representation for categorical column when passing categorical column + directly (as any element in feature_columns) to `linear_model`. See + `linear_model` for details. + + ```python + name = indicator_column(categorical_column_with_vocabulary_list( + 'name', ['bob', 'george', 'wanda']) + columns = [name, ...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + dense_tensor = input_layer(features, columns) + + dense_tensor == [[1, 0, 0]] # If "name" bytes_list is ["bob"] + dense_tensor == [[1, 0, 1]] # If "name" bytes_list is ["bob", "wanda"] + dense_tensor == [[2, 0, 0]] # If "name" bytes_list is ["bob", "bob"] + ``` + + Args: + categorical_column: A `CategoricalColumn` which is created by + `categorical_column_with_*` or `crossed_column` functions. + + Returns: + An `IndicatorColumn`. + """ + return IndicatorColumn(categorical_column) + + +def weighted_categorical_column( + categorical_column, weight_feature_key, dtype=dtypes.float32): + """Applies weight values to a `_CategoricalColumn`. + + Use this when each of your sparse inputs has both an ID and a value. For + example, if you're representing text documents as a collection of word + frequencies, you can provide 2 parallel sparse input features ('terms' and + 'frequencies' below). + + Example: + + Input `tf.Example` objects: + + ```proto + [ + features { + feature { + key: "terms" + value {bytes_list {value: "very" value: "model"}} + } + feature { + key: "frequencies" + value {float_list {value: 0.3 value: 0.1}} + } + }, + features { + feature { + key: "terms" + value {bytes_list {value: "when" value: "course" value: "human"}} + } + feature { + key: "frequencies" + value {float_list {value: 0.4 value: 0.1 value: 0.2}} + } + } + ] + ``` + + ```python + categorical_column = categorical_column_with_hash_bucket( + column_name='terms', hash_bucket_size=1000) + weighted_column = weighted_categorical_column( + categorical_column=categorical_column, weight_feature_key='frequencies') + columns = [weighted_column, ...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + linear_prediction, _, _ = linear_model(features, columns) + ``` + + This assumes the input dictionary contains a `SparseTensor` for key + 'terms', and a `SparseTensor` for key 'frequencies'. These 2 tensors must have + the same indices and dense shape. + + Args: + categorical_column: A `_CategoricalColumn` created by + `categorical_column_with_*` functions. + weight_feature_key: String key for weight values. + dtype: Type of weights, such as `tf.float32`. Only float and integer weights + are supported. + + Returns: + A `CategoricalColumn` composed of two sparse features: one represents id, + the other represents weight (value) of the id feature in that example. + + Raises: + ValueError: if `dtype` is not convertible to float. + """ + if (dtype is None) or not (dtype.is_integer or dtype.is_floating): + raise ValueError('dtype {} is not convertible to float.'.format(dtype)) + return WeightedCategoricalColumn( + categorical_column=categorical_column, + weight_feature_key=weight_feature_key, + dtype=dtype) + + +def crossed_column(keys, hash_bucket_size, hash_key=None): + """Returns a column for performing crosses of categorical features. + + Crossed features will be hashed according to `hash_bucket_size`. Conceptually, + the transformation can be thought of as: + Hash(cartesian product of features) % `hash_bucket_size` + + For example, if the input features are: + + * SparseTensor referred by first key: + + ```python + shape = [2, 2] + { + [0, 0]: "a" + [1, 0]: "b" + [1, 1]: "c" + } + ``` + + * SparseTensor referred by second key: + + ```python + shape = [2, 1] + { + [0, 0]: "d" + [1, 0]: "e" + } + ``` + + then crossed feature will look like: + + ```python + shape = [2, 2] + { + [0, 0]: Hash64("d", Hash64("a")) % hash_bucket_size + [1, 0]: Hash64("e", Hash64("b")) % hash_bucket_size + [1, 1]: Hash64("e", Hash64("c")) % hash_bucket_size + } + ``` + + Here is an example to create a linear model with crosses of string features: + + ```python + keywords_x_doc_terms = crossed_column(['keywords', 'doc_terms'], 50K) + columns = [keywords_x_doc_terms, ...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + linear_prediction = linear_model(features, columns) + ``` + + You could also use vocabulary lookup before crossing: + + ```python + keywords = categorical_column_with_vocabulary_file( + 'keywords', '/path/to/vocabulary/file', vocabulary_size=1K) + keywords_x_doc_terms = crossed_column([keywords, 'doc_terms'], 50K) + columns = [keywords_x_doc_terms, ...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + linear_prediction = linear_model(features, columns) + ``` + + If an input feature is of numeric type, you can use + `categorical_column_with_identity`, or `bucketized_column`, as in the example: + + ```python + # vertical_id is an integer categorical feature. + vertical_id = categorical_column_with_identity('vertical_id', 10K) + price = numeric_column('price') + # bucketized_column converts numerical feature to a categorical one. + bucketized_price = bucketized_column(price, boundaries=[...]) + vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50K) + columns = [vertical_id_x_price, ...] + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + linear_prediction = linear_model(features, columns) + ``` + + To use crossed column in DNN model, you need to add it in an embedding column + as in this example: + + ```python + vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50K) + vertical_id_x_price_embedded = embedding_column(vertical_id_x_price, 10) + dense_tensor = input_layer(features, [vertical_id_x_price_embedded, ...]) + ``` + + Args: + keys: An iterable identifying the features to be crossed. Each element can + be either: + * string: Will use the corresponding feature which must be of string type. + * `CategoricalColumn`: Will use the transformed tensor produced by this + column. Does not support hashed categorical column. + hash_bucket_size: An int > 1. The number of buckets. + hash_key: Specify the hash_key that will be used by the `FingerprintCat64` + function to combine the crosses fingerprints on SparseCrossOp (optional). + + Returns: + A `CrossedColumn`. + + Raises: + ValueError: If `len(keys) < 2`. + ValueError: If any of the keys is neither a string nor `CategoricalColumn`. + ValueError: If any of the keys is `HashedCategoricalColumn`. + ValueError: If `hash_bucket_size < 1`. + """ + if not hash_bucket_size or hash_bucket_size < 1: + raise ValueError('hash_bucket_size must be > 1. ' + 'hash_bucket_size: {}'.format(hash_bucket_size)) + if not keys or len(keys) < 2: + raise ValueError( + 'keys must be a list with length > 1. Given: {}'.format(keys)) + for key in keys: + if (not isinstance(key, six.string_types) and + not isinstance(key, CategoricalColumn)): + raise ValueError( + 'Unsupported key type. All keys must be either string, or ' + 'categorical column except HashedCategoricalColumn. ' + 'Given: {}'.format(key)) + if isinstance(key, HashedCategoricalColumn): + raise ValueError( + 'categorical_column_with_hash_bucket is not supported for crossing. ' + 'Hashing before crossing will increase probability of collision. ' + 'Instead, use the feature name as a string. Given: {}'.format(key)) + return CrossedColumn( + keys=tuple(keys), hash_bucket_size=hash_bucket_size, hash_key=hash_key) + + +class StateManager(object): + """Manages the state associated with FeatureColumns. + + Some `FeatureColumn`s create variables or resources to assist their + computation. The `StateManager` is responsible for creating and storing these + objects since `FeatureColumn`s are supposed to be stateless configuration + only. + """ + + def get_variable(self, + feature_column, + name, + shape, + dtype=None, + initializer=None): + """Creates a new variable or returns an existing one. + + Args: + feature_column: A `FeatureColumn` object this variable corresponds to. + name: variable name. + shape: variable shape. + dtype: The type of the variable. Defaults to `self.dtype` or `float32`. + initializer: initializer instance (callable). + + Returns: + The variable. + """ + raise NotImplementedError('StateManager.get_variable') + + def get_resource(self, feature_column, name, resource_creator): + """Creates a new resource or returns an existing one. + + Resources can be things such as tables etc. + + Args: + feature_column: A `FeatureColumn` object this variable corresponds to. + name: Name of the resource. + resource_creator: A callable that can create the resource. + + Returns: + The resource. + """ + raise NotImplementedError('StateManager.get_resource') + + +class FeatureColumn(object): + """Represents a feature column abstraction. + + WARNING: Do not subclass this layer unless you know what you are doing: + the API is subject to future changes. + + To distinguish between the concept of a feature family and a specific binary + feature within a family, we refer to a feature family like "country" as a + feature column. For example, we can have a feature in a `tf.Example` format: + {key: "country", value: [ "US" ]} + In this example the value of feature is "US" and "country" refers to the + column of the feature. + + This class is an abstract class. Users should not create instances of this. + """ + __metaclass__ = abc.ABCMeta + + @abc.abstractproperty + def name(self): + """Returns string. Used for naming.""" + pass + + @abc.abstractmethod + def transform_feature(self, transformation_cache, state_manager): + """Returns intermediate representation (usually a `Tensor`). + + Uses `transformation_cache` to create an intermediate representation + (usually a `Tensor`) that other feature columns can use. + + Example usage of `transformation_cache`: + Let's say a Feature column depends on raw feature ('raw') and another + `FeatureColumn` (input_fc). To access corresponding `Tensor`s, + transformation_cache will be used as follows: + + ```python + raw_tensor = transformation_cache.get('raw', state_manager) + fc_tensor = transformation_cache.get(input_fc, state_manager) + ``` + + Args: + transformation_cache: A `FeatureTransformationCache` object to access + features. + state_manager: A `StateManager` to create / access resources such as + lookup tables. + + Returns: + Transformed feature `Tensor`. + """ + pass + + @abc.abstractproperty + def parse_example_spec(self): + """Returns a `tf.Example` parsing spec as dict. + + It is used for get_parsing_spec for `tf.parse_example`. Returned spec is a + dict from keys ('string') to `VarLenFeature`, `FixedLenFeature`, and other + supported objects. Please check documentation of `tf.parse_example` for all + supported spec objects. + + Let's say a Feature column depends on raw feature ('raw') and another + `FeatureColumn` (input_fc). One possible implementation of + parse_example_spec is as follows: + + ```python + spec = {'raw': tf.FixedLenFeature(...)} + spec.update(input_fc.parse_example_spec) + return spec + ``` + """ + pass + + def create_state(self, state_manager): + """Uses the `state_manager` to create state for the FeatureColumn. + + Args: + state_manager: A `StateManager` to create / access resources such as + lookup tables and variables. + """ + pass + + +class DenseColumn(FeatureColumn): + """Represents a column which can be represented as `Tensor`. + + Some examples of this type are: numeric_column, embedding_column, + indicator_column. + """ + + __metaclass__ = abc.ABCMeta + + @abc.abstractproperty + def variable_shape(self): + """`TensorShape` of `get_dense_tensor`, without batch dimension.""" + pass + + @abc.abstractmethod + def get_dense_tensor(self, transformation_cache, state_manager): + """Returns a `Tensor`. + + The output of this function will be used by model-builder-functions. For + example the pseudo code of `input_layer` will be like: + + ```python + def input_layer(features, feature_columns, ...): + outputs = [fc.get_dense_tensor(...) for fc in feature_columns] + return tf.concat(outputs) + ``` + + Args: + transformation_cache: A `FeatureTransformationCache` object to access + features. + state_manager: A `StateManager` to create / access resources such as + lookup tables. + + Returns: + `Tensor` of shape [batch_size] + `variable_shape`. + """ + pass + + +def _create_weighted_sum(column, + transformation_cache, + state_manager, + units, + sparse_combiner, + weight_collections, + trainable, + weight_var=None): + """Creates a weighted sum for a dense/categorical column for linear_model.""" + if isinstance(column, CategoricalColumn): + return _create_categorical_column_weighted_sum( + column=column, + transformation_cache=transformation_cache, + state_manager=state_manager, + units=units, + sparse_combiner=sparse_combiner, + weight_collections=weight_collections, + trainable=trainable, + weight_var=weight_var) + else: + return _create_dense_column_weighted_sum( + column=column, + transformation_cache=transformation_cache, + state_manager=state_manager, + units=units, + weight_collections=weight_collections, + trainable=trainable, + weight_var=weight_var) + + +def _create_dense_column_weighted_sum(column, + transformation_cache, + state_manager, + units, + weight_collections, + trainable, + weight_var=None): + """Create a weighted sum of a dense column for linear_model.""" + tensor = column.get_dense_tensor(transformation_cache, state_manager) + num_elements = column.variable_shape.num_elements() + batch_size = array_ops.shape(tensor)[0] + tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements)) + if weight_var is not None: + weight = weight_var + else: + weight = variable_scope.get_variable( + name='weights', + shape=[num_elements, units], + initializer=init_ops.zeros_initializer(), + trainable=trainable, + collections=weight_collections) + return math_ops.matmul(tensor, weight, name='weighted_sum') + + +class CategoricalColumn(FeatureColumn): + """Represents a categorical feature. + + A categorical feature typically handled with a `tf.SparseTensor` of IDs. + """ + __metaclass__ = abc.ABCMeta + + IdWeightPair = collections.namedtuple( # pylint: disable=invalid-name + 'IdWeightPair', ('id_tensor', 'weight_tensor')) + + @abc.abstractproperty + def num_buckets(self): + """Returns number of buckets in this sparse feature.""" + pass + + @abc.abstractmethod + def get_sparse_tensors(self, transformation_cache, state_manager): + """Returns an IdWeightPair. + + `IdWeightPair` is a pair of `SparseTensor`s which represents ids and + weights. + + `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets` + `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a + `SparseTensor` of `float` or `None` to indicate all weights should be + taken to be 1. If specified, `weight_tensor` must have exactly the same + shape and indices as `sp_ids`. Expected `SparseTensor` is same as parsing + output of a `VarLenFeature` which is a ragged matrix. + + Args: + transformation_cache: A `FeatureTransformationCache` object to access + features. + state_manager: A `StateManager` to create / access resources such as + lookup tables. + """ + pass + + +def _create_categorical_column_weighted_sum(column, + transformation_cache, + state_manager, + units, + sparse_combiner, + weight_collections, + trainable, + weight_var=None): + # pylint: disable=g-doc-return-or-yield,g-doc-args + """Create a weighted sum of a categorical column for linear_model. + + Note to maintainer: As implementation details, the weighted sum is + implemented via embedding_lookup_sparse toward efficiency. Mathematically, + they are the same. + + To be specific, conceptually, categorical column can be treated as multi-hot + vector. Say: + + ```python + x = [0 0 1] # categorical column input + w = [a b c] # weights + ``` + The weighted sum is `c` in this case, which is same as `w[2]`. + + Another example is + + ```python + x = [0 1 1] # categorical column input + w = [a b c] # weights + ``` + The weighted sum is `b + c` in this case, which is same as `w[2] + w[3]`. + + For both cases, we can implement weighted sum via embedding_lookup with + sparse_combiner = "sum". + """ + + sparse_tensors = column.get_sparse_tensors(transformation_cache, + state_manager) + id_tensor = sparse_ops.sparse_reshape(sparse_tensors.id_tensor, [ + array_ops.shape(sparse_tensors.id_tensor)[0], -1 + ]) + weight_tensor = sparse_tensors.weight_tensor + if weight_tensor is not None: + weight_tensor = sparse_ops.sparse_reshape( + weight_tensor, [array_ops.shape(weight_tensor)[0], -1]) + + if weight_var is not None: + weight = weight_var + else: + weight = variable_scope.get_variable( + name='weights', + shape=(column.num_buckets, units), + initializer=init_ops.zeros_initializer(), + trainable=trainable, + collections=weight_collections) + return _safe_embedding_lookup_sparse( + weight, + id_tensor, + sparse_weights=weight_tensor, + combiner=sparse_combiner, + name='weighted_sum') + + +class SequenceDenseColumn(FeatureColumn): + """Represents dense sequence data.""" + + __metaclass__ = abc.ABCMeta + + TensorSequenceLengthPair = collections.namedtuple( # pylint: disable=invalid-name + 'TensorSequenceLengthPair', ('dense_tensor', 'sequence_length')) + + @abc.abstractmethod + def get_sequence_dense_tensor(self, transformation_cache, state_manager): + """Returns a `TensorSequenceLengthPair`. + + Args: + transformation_cache: A `FeatureTransformationCache` object to access + features. + state_manager: A `StateManager` to create / access resources such as + lookup tables. + """ + pass + + +class FeatureTransformationCache(object): + """Handles caching of transformations while building the model. + + `FeatureColumn` specifies how to digest an input column to the network. Some + feature columns require data transformations. This class caches those + transformations. + + Some features may be used in more than one place. For example, one can use a + bucketized feature by itself and a cross with it. In that case we + should create only one bucketization op instead of creating ops for each + feature column separately. To handle re-use of transformed columns, + `FeatureTransformationCache` caches all previously transformed columns. + + Example: + We're trying to use the following `FeatureColumn`s: + + ```python + bucketized_age = fc.bucketized_column(fc.numeric_column("age"), ...) + keywords = fc.categorical_column_with_hash_buckets("keywords", ...) + age_X_keywords = fc.crossed_column([bucketized_age, "keywords"]) + ... = linear_model(features, + [bucketized_age, keywords, age_X_keywords] + ``` + + If we transform each column independently, then we'll get duplication of + bucketization (one for cross, one for bucketization itself). + The `FeatureTransformationCache` eliminates this duplication. + """ + + def __init__(self, features): + """Creates a `FeatureTransformationCache`. + + Args: + features: A mapping from feature column to objects that are `Tensor` or + `SparseTensor`, or can be converted to same via + `sparse_tensor.convert_to_tensor_or_sparse_tensor`. A `string` key + signifies a base feature (not-transformed). A `FeatureColumn` key + means that this `Tensor` is the output of an existing `FeatureColumn` + which can be reused. + """ + self._features = features.copy() + self._feature_tensors = {} + + def get(self, key, state_manager): + """Returns a `Tensor` for the given key. + + A `str` key is used to access a base feature (not-transformed). When a + `FeatureColumn` is passed, the transformed feature is returned if it + already exists, otherwise the given `FeatureColumn` is asked to provide its + transformed output, which is then cached. + + Args: + key: a `str` or a `FeatureColumn`. + state_manager: A StateManager object that holds the FeatureColumn state. + + Returns: + The transformed `Tensor` corresponding to the `key`. + + Raises: + ValueError: if key is not found or a transformed `Tensor` cannot be + computed. + """ + if key in self._feature_tensors: + # FeatureColumn is already transformed or converted. + return self._feature_tensors[key] + + if key in self._features: + feature_tensor = self._get_raw_feature_as_tensor(key) + self._feature_tensors[key] = feature_tensor + return feature_tensor + + if isinstance(key, six.string_types): + raise ValueError('Feature {} is not in features dictionary.'.format(key)) + + if not isinstance(key, FeatureColumn): + raise TypeError('"key" must be either a "str" or "FeatureColumn". ' + 'Provided: {}'.format(key)) + + column = key + logging.debug('Transforming feature_column %s.', column) + transformed = column.transform_feature(self, state_manager) + if transformed is None: + raise ValueError('Column {} is not supported.'.format(column.name)) + self._feature_tensors[column] = transformed + return transformed + + def _get_raw_feature_as_tensor(self, key): + """Gets the raw_feature (keyed by `key`) as `tensor`. + + The raw feature is converted to (sparse) tensor and maybe expand dim. + + For both `Tensor` and `SparseTensor`, the rank will be expanded (to 2) if + the rank is 1. This supports dynamic rank also. For rank 0 raw feature, will + error out as it is not supported. + + Args: + key: A `str` key to access the raw feature. + + Returns: + A `Tensor` or `SparseTensor`. + + Raises: + ValueError: if the raw feature has rank 0. + """ + raw_feature = self._features[key] + feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor( + raw_feature) + + def expand_dims(input_tensor): + # Input_tensor must have rank 1. + if isinstance(input_tensor, sparse_tensor_lib.SparseTensor): + return sparse_ops.sparse_reshape( + input_tensor, [array_ops.shape(input_tensor)[0], -1]) + else: + return array_ops.expand_dims(input_tensor, -1) + + rank = feature_tensor.get_shape().ndims + if rank is not None: + if rank == 0: + raise ValueError( + 'Feature (key: {}) cannot have rank 0. Give: {}'.format( + key, feature_tensor)) + return feature_tensor if rank != 1 else expand_dims(feature_tensor) + + # Handle dynamic rank. + with ops.control_dependencies([ + check_ops.assert_positive( + array_ops.rank(feature_tensor), + message='Feature (key: {}) cannot have rank 0. Given: {}'.format( + key, feature_tensor))]): + return control_flow_ops.cond( + math_ops.equal(1, array_ops.rank(feature_tensor)), + lambda: expand_dims(feature_tensor), + lambda: feature_tensor) + + +# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py +def _shape_offsets(shape): + """Returns moving offset for each dimension given shape.""" + offsets = [] + for dim in reversed(shape): + if offsets: + offsets.append(dim * offsets[-1]) + else: + offsets.append(dim) + offsets.reverse() + return offsets + + +# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py +def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None): + """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells. + + If `input_tensor` is already a `SparseTensor`, just return it. + + Args: + input_tensor: A string or integer `Tensor`. + ignore_value: Entries in `dense_tensor` equal to this value will be + absent from the resulting `SparseTensor`. If `None`, default value of + `dense_tensor`'s dtype will be used ('' for `str`, -1 for `int`). + + Returns: + A `SparseTensor` with the same shape as `input_tensor`. + + Raises: + ValueError: when `input_tensor`'s rank is `None`. + """ + input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor( + input_tensor) + if isinstance(input_tensor, sparse_tensor_lib.SparseTensor): + return input_tensor + with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)): + if ignore_value is None: + if input_tensor.dtype == dtypes.string: + # Exception due to TF strings are converted to numpy objects by default. + ignore_value = '' + elif input_tensor.dtype.is_integer: + ignore_value = -1 # -1 has a special meaning of missing feature + else: + # NOTE: `as_numpy_dtype` is a property, so with the parentheses this is + # constructing a new numpy object of the given type, which yields the + # default value for that type. + ignore_value = input_tensor.dtype.as_numpy_dtype() + ignore_value = math_ops.cast( + ignore_value, input_tensor.dtype, name='ignore_value') + indices = array_ops.where( + math_ops.not_equal(input_tensor, ignore_value), name='indices') + return sparse_tensor_lib.SparseTensor( + indices=indices, + values=array_ops.gather_nd(input_tensor, indices, name='values'), + dense_shape=array_ops.shape( + input_tensor, out_type=dtypes.int64, name='dense_shape')) + + +def _normalize_feature_columns(feature_columns): + """Normalizes the `feature_columns` input. + + This method converts the `feature_columns` to list type as best as it can. In + addition, verifies the type and other parts of feature_columns, required by + downstream library. + + Args: + feature_columns: The raw feature columns, usually passed by users. + + Returns: + The normalized feature column list. + + Raises: + ValueError: for any invalid inputs, such as empty, duplicated names, etc. + """ + if isinstance(feature_columns, FeatureColumn): + feature_columns = [feature_columns] + + if isinstance(feature_columns, collections.Iterator): + feature_columns = list(feature_columns) + + if isinstance(feature_columns, dict): + raise ValueError('Expected feature_columns to be iterable, found dict.') + + for column in feature_columns: + if not isinstance(column, FeatureColumn): + raise ValueError('Items of feature_columns must be a FeatureColumn. ' + 'Given (type {}): {}.'.format(type(column), column)) + if not feature_columns: + raise ValueError('feature_columns must not be empty.') + name_to_column = dict() + for column in feature_columns: + if column.name in name_to_column: + raise ValueError('Duplicate feature column name found for columns: {} ' + 'and {}. This usually means that these columns refer to ' + 'same base feature. Either one must be discarded or a ' + 'duplicated but renamed item must be inserted in ' + 'features dict.'.format(column, + name_to_column[column.name])) + name_to_column[column.name] = column + + return feature_columns + + +class NumericColumn( + DenseColumn, + collections.namedtuple( + 'NumericColumn', + ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))): + """see `numeric_column`.""" + + @property + def name(self): + """See `FeatureColumn` base class.""" + return self.key + + @property + def parse_example_spec(self): + """See `FeatureColumn` base class.""" + return { + self.key: + parsing_ops.FixedLenFeature(self.shape, self.dtype, + self.default_value) + } + + def transform_feature(self, transformation_cache, state_manager): + """See `FeatureColumn` base class. + + In this case, we apply the `normalizer_fn` to the input tensor. + + Args: + transformation_cache: A `FeatureTransformationCache` object to access + features. + state_manager: A `StateManager` to create / access resources such as + lookup tables. + + Returns: + Normalized input tensor. + Raises: + ValueError: If a SparseTensor is passed in. + """ + input_tensor = transformation_cache.get(self.key, state_manager) + if isinstance(input_tensor, sparse_tensor_lib.SparseTensor): + raise ValueError( + 'The corresponding Tensor of numerical column must be a Tensor. ' + 'SparseTensor is not supported. key: {}'.format(self.key)) + if self.normalizer_fn is not None: + input_tensor = self.normalizer_fn(input_tensor) + return math_ops.to_float(input_tensor) + + @property + def variable_shape(self): + """See `DenseColumn` base class.""" + return tensor_shape.TensorShape(self.shape) + + def get_dense_tensor(self, transformation_cache, state_manager): + """Returns dense `Tensor` representing numeric feature. + + Args: + transformation_cache: A `FeatureTransformationCache` object to access + features. + state_manager: A `StateManager` to create / access resources such as + lookup tables. + + Returns: + Dense `Tensor` created within `transform_feature`. + """ + # Feature has been already transformed. Return the intermediate + # representation created by _transform_feature. + return transformation_cache.get(self, state_manager) + + +class BucketizedColumn(DenseColumn, CategoricalColumn, + collections.namedtuple('BucketizedColumn', + ('source_column', 'boundaries'))): + """See `bucketized_column`.""" + + @property + def name(self): + """See `FeatureColumn` base class.""" + return '{}_bucketized'.format(self.source_column.name) + + @property + def parse_example_spec(self): + """See `FeatureColumn` base class.""" + return self.source_column.parse_example_spec + + def transform_feature(self, transformation_cache, state_manager): + """Returns bucketized categorical `source_column` tensor.""" + source_tensor = transformation_cache.get(self.source_column, state_manager) + return math_ops._bucketize( # pylint: disable=protected-access + source_tensor, + boundaries=self.boundaries) + + @property + def variable_shape(self): + """See `DenseColumn` base class.""" + return tensor_shape.TensorShape( + tuple(self.source_column.shape) + (len(self.boundaries) + 1,)) + + def get_dense_tensor(self, transformation_cache, state_manager): + """Returns one hot encoded dense `Tensor`.""" + input_tensor = transformation_cache.get(self, state_manager) + return array_ops.one_hot( + indices=math_ops.to_int64(input_tensor), + depth=len(self.boundaries) + 1, + on_value=1., + off_value=0.) + + @property + def num_buckets(self): + """See `CategoricalColumn` base class.""" + # By construction, source_column is always one-dimensional. + return (len(self.boundaries) + 1) * self.source_column.shape[0] + + def get_sparse_tensors(self, transformation_cache, state_manager): + """Converts dense inputs to SparseTensor so downstream code can use it.""" + input_tensor = transformation_cache.get(self, state_manager) + batch_size = array_ops.shape(input_tensor)[0] + # By construction, source_column is always one-dimensional. + source_dimension = self.source_column.shape[0] + + i1 = array_ops.reshape( + array_ops.tile( + array_ops.expand_dims(math_ops.range(0, batch_size), 1), + [1, source_dimension]), + (-1,)) + i2 = array_ops.tile(math_ops.range(0, source_dimension), [batch_size]) + # Flatten the bucket indices and unique them across dimensions + # E.g. 2nd dimension indices will range from k to 2*k-1 with k buckets + bucket_indices = ( + array_ops.reshape(input_tensor, (-1,)) + + (len(self.boundaries) + 1) * i2) + + indices = math_ops.to_int64(array_ops.transpose(array_ops.stack((i1, i2)))) + dense_shape = math_ops.to_int64(array_ops.stack( + [batch_size, source_dimension])) + sparse_tensor = sparse_tensor_lib.SparseTensor( + indices=indices, + values=bucket_indices, + dense_shape=dense_shape) + return CategoricalColumn.IdWeightPair(sparse_tensor, None) + + +class EmbeddingColumn( + DenseColumn, SequenceDenseColumn, + collections.namedtuple( + 'EmbeddingColumn', + ('categorical_column', 'dimension', 'combiner', 'initializer', + 'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable'))): + """See `embedding_column`.""" + + @property + def name(self): + """See `FeatureColumn` base class.""" + return '{}_embedding'.format(self.categorical_column.name) + + @property + def parse_example_spec(self): + """See `FeatureColumn` base class.""" + return self.categorical_column.parse_example_spec + + def transform_feature(self, transformation_cache, state_manager): + """Transforms underlying `categorical_column`.""" + return transformation_cache.get(self.categorical_column, state_manager) + + @property + def variable_shape(self): + """See `DenseColumn` base class.""" + return tensor_shape.vector(self.dimension) + + def _get_dense_tensor_internal(self, transformation_cache, state_manager): + """Private method that follows the signature of _get_dense_tensor.""" + # Get sparse IDs and weights. + sparse_tensors = self.categorical_column.get_sparse_tensors( + transformation_cache, state_manager) + sparse_ids = sparse_tensors.id_tensor + sparse_weights = sparse_tensors.weight_tensor + + embedding_shape = (self.categorical_column.num_buckets, self.dimension) + embedding_weights = state_manager.get_variable( + self, + name='embedding_weights', + shape=embedding_shape, + dtype=dtypes.float32, + initializer=self.initializer) + + if self.ckpt_to_load_from is not None: + to_restore = embedding_weights + if isinstance(to_restore, variables.PartitionedVariable): + to_restore = to_restore._get_variable_list() # pylint: disable=protected-access + checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, { + self.tensor_name_in_ckpt: to_restore + }) + + # Return embedding lookup result. + return _safe_embedding_lookup_sparse( + embedding_weights=embedding_weights, + sparse_ids=sparse_ids, + sparse_weights=sparse_weights, + combiner=self.combiner, + name='%s_weights' % self.name, + max_norm=self.max_norm) + + def get_dense_tensor(self, transformation_cache, state_manager): + """Returns tensor after doing the embedding lookup. + + Args: + transformation_cache: A `FeatureTransformationCache` object to access + features. + state_manager: A `StateManager` to create / access resources such as + lookup tables. + + Returns: + Embedding lookup tensor. + + Raises: + ValueError: `categorical_column` is SequenceCategoricalColumn. + """ + if isinstance(self.categorical_column, SequenceCategoricalColumn): + raise ValueError( + 'In embedding_column: {}. ' + 'categorical_column must not be of type SequenceCategoricalColumn. ' + 'Suggested fix A: If you wish to use input_layer, use a ' + 'non-sequence categorical_column_with_*. ' + 'Suggested fix B: If you wish to create sequence input, use ' + 'sequence_input_layer instead of input_layer. ' + 'Given (type {}): {}'.format(self.name, type(self.categorical_column), + self.categorical_column)) + return self._get_dense_tensor_internal(transformation_cache, state_manager) + + def get_sequence_dense_tensor(self, transformation_cache, state_manager): + """See `SequenceDenseColumn` base class.""" + if not isinstance(self.categorical_column, SequenceCategoricalColumn): + raise ValueError( + 'In embedding_column: {}. ' + 'categorical_column must be of type SequenceCategoricalColumn ' + 'to use sequence_input_layer. ' + 'Suggested fix: Use one of sequence_categorical_column_with_*. ' + 'Given (type {}): {}'.format(self.name, type(self.categorical_column), + self.categorical_column)) + dense_tensor = self._get_dense_tensor_internal( # pylint: disable=protected-access + transformation_cache, state_manager) + sparse_tensors = self.categorical_column.get_sparse_tensors( + transformation_cache, state_manager) + sequence_length = _sequence_length_from_sparse_tensor( + sparse_tensors.id_tensor) + return SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) + + +def _get_graph_for_variable(var): + if isinstance(var, variables.PartitionedVariable): + return list(var)[0].graph + else: + return var.graph + + +class SharedEmbeddingColumn( + DenseColumn, SequenceDenseColumn, + collections.namedtuple( + 'SharedEmbeddingColumn', + ('categorical_column', 'dimension', 'combiner', 'initializer', + 'shared_embedding_collection_name', 'ckpt_to_load_from', + 'tensor_name_in_ckpt', 'max_norm', 'trainable'))): + """See `embedding_column`.""" + + @property + def name(self): + """See `FeatureColumn` base class.""" + return '{}_shared_embedding'.format(self.categorical_column.name) + + @property + def shared_collection_name(self): + """Returns the shared name of this column. + + A group of columns share an embedding. Each one of those columns would have + the same `shared_collection_name` by which they could be collectively + referred to. + """ + return self.shared_embedding_collection_name + + @property + def parse_example_spec(self): + """See `FeatureColumn` base class.""" + return self.categorical_column.parse_example_spec + + def transform_feature(self, transformation_cache, state_manager): + """See `FeatureColumn` base class.""" + return transformation_cache.get(self.categorical_column, state_manager) + + @property + def variable_shape(self): + """See `DenseColumn` base class.""" + return tensor_shape.vector(self.dimension) + + def _get_dense_tensor_internal(self, transformation_cache, state_manager): + """Private method that follows the signature of _get_dense_tensor.""" + # This method is called from a variable_scope with name _var_scope_name, + # which is shared among all shared embeddings. Open a name_scope here, so + # that the ops for different columns have distinct names. + with ops.name_scope(None, default_name=self.name): + # Get sparse IDs and weights. + sparse_tensors = self.categorical_column.get_sparse_tensors( + transformation_cache, state_manager) + sparse_ids = sparse_tensors.id_tensor + sparse_weights = sparse_tensors.weight_tensor + + embedding_shape = (self.categorical_column.num_buckets, self.dimension) + embedding_weights = state_manager.get_variable( + self, + name='embedding_weights', + shape=embedding_shape, + dtype=dtypes.float32, + initializer=self.initializer) + + if self.ckpt_to_load_from is not None: + to_restore = embedding_weights + if isinstance(to_restore, variables.PartitionedVariable): + to_restore = to_restore._get_variable_list() # pylint: disable=protected-access + checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, { + self.tensor_name_in_ckpt: to_restore + }) + + # Return embedding lookup result. + return _safe_embedding_lookup_sparse( + embedding_weights=embedding_weights, + sparse_ids=sparse_ids, + sparse_weights=sparse_weights, + combiner=self.combiner, + name='%s_weights' % self.name, + max_norm=self.max_norm) + + def get_dense_tensor(self, transformation_cache, state_manager): + """Returns the embedding lookup result.""" + if isinstance(self.categorical_column, SequenceCategoricalColumn): + raise ValueError( + 'In embedding_column: {}. ' + 'categorical_column must not be of type SequenceCategoricalColumn. ' + 'Suggested fix A: If you wish to use input_layer, use a ' + 'non-sequence categorical_column_with_*. ' + 'Suggested fix B: If you wish to create sequence input, use ' + 'sequence_input_layer instead of input_layer. ' + 'Given (type {}): {}'.format(self.name, type(self.categorical_column), + self.categorical_column)) + return self._get_dense_tensor_internal(transformation_cache, state_manager) + + def get_sequence_dense_tensor(self, transformation_cache, state_manager): + """See `SequenceDenseColumn` base class.""" + if not isinstance(self.categorical_column, SequenceCategoricalColumn): + raise ValueError( + 'In embedding_column: {}. ' + 'categorical_column must be of type SequenceCategoricalColumn ' + 'to use sequence_input_layer. ' + 'Suggested fix: Use one of sequence_categorical_column_with_*. ' + 'Given (type {}): {}'.format(self.name, type(self.categorical_column), + self.categorical_column)) + dense_tensor = self.get_dense_tensor_internal(transformation_cache, + state_manager) + sparse_tensors = self.categorical_column.get_sparse_tensors( + transformation_cache, state_manager) + sequence_length = _sequence_length_from_sparse_tensor( + sparse_tensors.id_tensor) + return SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) + + +def _create_tuple(shape, value): + """Returns a tuple with given shape and filled with value.""" + if shape: + return tuple([_create_tuple(shape[1:], value) for _ in range(shape[0])]) + return value + + +def _as_tuple(value): + if not nest.is_sequence(value): + return value + return tuple([_as_tuple(v) for v in value]) + + +def _check_shape(shape, key): + """Returns shape if it's valid, raises error otherwise.""" + assert shape is not None + if not nest.is_sequence(shape): + shape = [shape] + shape = tuple(shape) + for dimension in shape: + if not isinstance(dimension, int): + raise TypeError('shape dimensions must be integer. ' + 'shape: {}, key: {}'.format(shape, key)) + if dimension < 1: + raise ValueError('shape dimensions must be greater than 0. ' + 'shape: {}, key: {}'.format(shape, key)) + return shape + + +def _is_shape_and_default_value_compatible(default_value, shape): + """Verifies compatibility of shape and default_value.""" + # Invalid condition: + # * if default_value is not a scalar and shape is empty + # * or if default_value is an iterable and shape is not empty + if nest.is_sequence(default_value) != bool(shape): + return False + if not shape: + return True + if len(default_value) != shape[0]: + return False + for i in range(shape[0]): + if not _is_shape_and_default_value_compatible(default_value[i], shape[1:]): + return False + return True + + +def _check_default_value(shape, default_value, dtype, key): + """Returns default value as tuple if it's valid, otherwise raises errors. + + This function verifies that `default_value` is compatible with both `shape` + and `dtype`. If it is not compatible, it raises an error. If it is compatible, + it casts default_value to a tuple and returns it. `key` is used only + for error message. + + Args: + shape: An iterable of integers specifies the shape of the `Tensor`. + default_value: If a single value is provided, the same value will be applied + as the default value for every item. If an iterable of values is + provided, the shape of the `default_value` should be equal to the given + `shape`. + dtype: defines the type of values. Default value is `tf.float32`. Must be a + non-quantized, real integer or floating point type. + key: Column name, used only for error messages. + + Returns: + A tuple which will be used as default value. + + Raises: + TypeError: if `default_value` is an iterable but not compatible with `shape` + TypeError: if `default_value` is not compatible with `dtype`. + ValueError: if `dtype` is not convertible to `tf.float32`. + """ + if default_value is None: + return None + + if isinstance(default_value, int): + return _create_tuple(shape, default_value) + + if isinstance(default_value, float) and dtype.is_floating: + return _create_tuple(shape, default_value) + + if callable(getattr(default_value, 'tolist', None)): # Handles numpy arrays + default_value = default_value.tolist() + + if nest.is_sequence(default_value): + if not _is_shape_and_default_value_compatible(default_value, shape): + raise ValueError( + 'The shape of default_value must be equal to given shape. ' + 'default_value: {}, shape: {}, key: {}'.format( + default_value, shape, key)) + # Check if the values in the list are all integers or are convertible to + # floats. + is_list_all_int = all( + isinstance(v, int) for v in nest.flatten(default_value)) + is_list_has_float = any( + isinstance(v, float) for v in nest.flatten(default_value)) + if is_list_all_int: + return _as_tuple(default_value) + if is_list_has_float and dtype.is_floating: + return _as_tuple(default_value) + raise TypeError('default_value must be compatible with dtype. ' + 'default_value: {}, dtype: {}, key: {}'.format( + default_value, dtype, key)) + + +class HashedCategoricalColumn( + CategoricalColumn, + collections.namedtuple('HashedCategoricalColumn', + ('key', 'hash_bucket_size', 'dtype'))): + """see `categorical_column_with_hash_bucket`.""" + + @property + def name(self): + """See `FeatureColumn` base class.""" + return self.key + + @property + def parse_example_spec(self): + """See `FeatureColumn` base class.""" + return {self.key: parsing_ops.VarLenFeature(self.dtype)} + + def transform_feature(self, transformation_cache, state_manager): + """Hashes the values in the feature_column.""" + input_tensor = _to_sparse_input_and_drop_ignore_values( + transformation_cache.get(self.key, state_manager)) + if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor): + raise ValueError('SparseColumn input must be a SparseTensor.') + + _assert_string_or_int( + input_tensor.dtype, + prefix='column_name: {} input_tensor'.format(self.key)) + + if self.dtype.is_integer != input_tensor.dtype.is_integer: + raise ValueError( + 'Column dtype and SparseTensors dtype must be compatible. ' + 'key: {}, column dtype: {}, tensor dtype: {}'.format( + self.key, self.dtype, input_tensor.dtype)) + + if self.dtype == dtypes.string: + sparse_values = input_tensor.values + else: + sparse_values = string_ops.as_string(input_tensor.values) + + sparse_id_values = string_ops.string_to_hash_bucket_fast( + sparse_values, self.hash_bucket_size, name='lookup') + return sparse_tensor_lib.SparseTensor( + input_tensor.indices, sparse_id_values, input_tensor.dense_shape) + + @property + def num_buckets(self): + """Returns number of buckets in this sparse feature.""" + return self.hash_bucket_size + + def get_sparse_tensors(self, transformation_cache, state_manager): + """See `CategoricalColumn` base class.""" + return CategoricalColumn.IdWeightPair( + transformation_cache.get(self, state_manager), None) + + +class VocabularyFileCategoricalColumn( + CategoricalColumn, + collections.namedtuple('VocabularyFileCategoricalColumn', + ('key', 'vocabulary_file', 'vocabulary_size', + 'num_oov_buckets', 'dtype', 'default_value'))): + """See `categorical_column_with_vocabulary_file`.""" + + @property + def name(self): + """See `FeatureColumn` base class.""" + return self.key + + @property + def parse_example_spec(self): + """See `FeatureColumn` base class.""" + return {self.key: parsing_ops.VarLenFeature(self.dtype)} + + def transform_feature(self, transformation_cache, state_manager): + """Creates a lookup table for the vocabulary.""" + input_tensor = _to_sparse_input_and_drop_ignore_values( + transformation_cache.get(self.key, state_manager)) + + if self.dtype.is_integer != input_tensor.dtype.is_integer: + raise ValueError( + 'Column dtype and SparseTensors dtype must be compatible. ' + 'key: {}, column dtype: {}, tensor dtype: {}'.format( + self.key, self.dtype, input_tensor.dtype)) + + _assert_string_or_int( + input_tensor.dtype, + prefix='column_name: {} input_tensor'.format(self.key)) + + key_dtype = self.dtype + if input_tensor.dtype.is_integer: + # `index_table_from_file` requires 64-bit integer keys. + key_dtype = dtypes.int64 + input_tensor = math_ops.to_int64(input_tensor) + + # TODO(rohanj): Use state manager to manage the index table creation. + return lookup_ops.index_table_from_file( + vocabulary_file=self.vocabulary_file, + num_oov_buckets=self.num_oov_buckets, + vocab_size=self.vocabulary_size, + default_value=self.default_value, + key_dtype=key_dtype, + name='{}_lookup'.format(self.key)).lookup(input_tensor) + + @property + def num_buckets(self): + """Returns number of buckets in this sparse feature.""" + return self.vocabulary_size + self.num_oov_buckets + + def get_sparse_tensors(self, transformation_cache, state_manager): + """See `CategoricalColumn` base class.""" + return CategoricalColumn.IdWeightPair( + transformation_cache.get(self, state_manager), None) + + +class VocabularyListCategoricalColumn( + CategoricalColumn, + collections.namedtuple( + 'VocabularyListCategoricalColumn', + ('key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets')) +): + """See `categorical_column_with_vocabulary_list`.""" + + @property + def name(self): + """See `FeatureColumn` base class.""" + return self.key + + @property + def parse_example_spec(self): + """See `FeatureColumn` base class.""" + return {self.key: parsing_ops.VarLenFeature(self.dtype)} + + def transform_feature(self, transformation_cache, state_manager): + """Creates a lookup table for the vocabulary list.""" + input_tensor = _to_sparse_input_and_drop_ignore_values( + transformation_cache.get(self.key, state_manager)) + + if self.dtype.is_integer != input_tensor.dtype.is_integer: + raise ValueError( + 'Column dtype and SparseTensors dtype must be compatible. ' + 'key: {}, column dtype: {}, tensor dtype: {}'.format( + self.key, self.dtype, input_tensor.dtype)) + + _assert_string_or_int( + input_tensor.dtype, + prefix='column_name: {} input_tensor'.format(self.key)) + + key_dtype = self.dtype + if input_tensor.dtype.is_integer: + # `index_table_from_tensor` requires 64-bit integer keys. + key_dtype = dtypes.int64 + input_tensor = math_ops.to_int64(input_tensor) + + # TODO(rohanj): Use state manager to manage the index table creation. + return lookup_ops.index_table_from_tensor( + vocabulary_list=tuple(self.vocabulary_list), + default_value=self.default_value, + num_oov_buckets=self.num_oov_buckets, + dtype=key_dtype, + name='{}_lookup'.format(self.key)).lookup(input_tensor) + + @property + def num_buckets(self): + """Returns number of buckets in this sparse feature.""" + return len(self.vocabulary_list) + self.num_oov_buckets + + def get_sparse_tensors(self, transformation_cache, state_manager): + """See `CategoricalColumn` base class.""" + return CategoricalColumn.IdWeightPair( + transformation_cache.get(self, state_manager), None) + + +class IdentityCategoricalColumn( + CategoricalColumn, + collections.namedtuple('IdentityCategoricalColumn', + ('key', 'number_buckets', 'default_value'))): + + """See `categorical_column_with_identity`.""" + + @property + def name(self): + """See `FeatureColumn` base class.""" + return self.key + + @property + def parse_example_spec(self): + """See `FeatureColumn` base class.""" + return {self.key: parsing_ops.VarLenFeature(dtypes.int64)} + + def transform_feature(self, transformation_cache, state_manager): + """Returns a SparseTensor with identity values.""" + input_tensor = _to_sparse_input_and_drop_ignore_values( + transformation_cache.get(self.key, state_manager)) + + if not input_tensor.dtype.is_integer: + raise ValueError( + 'Invalid input, not integer. key: {} dtype: {}'.format( + self.key, input_tensor.dtype)) + + values = math_ops.to_int64(input_tensor.values, name='values') + num_buckets = math_ops.to_int64(self.num_buckets, name='num_buckets') + zero = math_ops.to_int64(0, name='zero') + if self.default_value is None: + # Fail if values are out-of-range. + assert_less = check_ops.assert_less( + values, num_buckets, data=(values, num_buckets), + name='assert_less_than_num_buckets') + assert_greater = check_ops.assert_greater_equal( + values, zero, data=(values,), + name='assert_greater_or_equal_0') + with ops.control_dependencies((assert_less, assert_greater)): + values = array_ops.identity(values) + else: + # Assign default for out-of-range values. + values = array_ops.where( + math_ops.logical_or( + values < zero, values >= num_buckets, name='out_of_range'), + array_ops.fill( + dims=array_ops.shape(values), + value=math_ops.to_int64(self.default_value), + name='default_values'), + values) + + return sparse_tensor_lib.SparseTensor( + indices=input_tensor.indices, + values=values, + dense_shape=input_tensor.dense_shape) + + @property + def num_buckets(self): + """Returns number of buckets in this sparse feature.""" + return self.number_buckets + + def get_sparse_tensors(self, transformation_cache, state_manager): + """See `CategoricalColumn` base class.""" + return CategoricalColumn.IdWeightPair( + transformation_cache.get(self, state_manager), None) + + +class WeightedCategoricalColumn( + CategoricalColumn, + collections.namedtuple( + 'WeightedCategoricalColumn', + ('categorical_column', 'weight_feature_key', 'dtype'))): + """See `weighted_categorical_column`.""" + + @property + def name(self): + """See `FeatureColumn` base class.""" + return '{}_weighted_by_{}'.format( + self.categorical_column.name, self.weight_feature_key) + + @property + def parse_example_spec(self): + """See `FeatureColumn` base class.""" + config = self.categorical_column.parse_example_spec + if self.weight_feature_key in config: + raise ValueError('Parse config {} already exists for {}.'.format( + config[self.weight_feature_key], self.weight_feature_key)) + config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype) + return config + + @property + def num_buckets(self): + """See `DenseColumn` base class.""" + return self.categorical_column.num_buckets + + def transform_feature(self, transformation_cache, state_manager): + """Applies weights to tensor generated from `categorical_column`'.""" + weight_tensor = transformation_cache.get(self.weight_feature_key, + state_manager) + if weight_tensor is None: + raise ValueError('Missing weights {}.'.format(self.weight_feature_key)) + weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor( + weight_tensor) + if self.dtype != weight_tensor.dtype.base_dtype: + raise ValueError('Bad dtype, expected {}, but got {}.'.format( + self.dtype, weight_tensor.dtype)) + if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor): + # The weight tensor can be a regular Tensor. In this case, sparsify it. + weight_tensor = _to_sparse_input_and_drop_ignore_values( + weight_tensor, ignore_value=0.0) + if not weight_tensor.dtype.is_floating: + weight_tensor = math_ops.to_float(weight_tensor) + return (transformation_cache.get(self.categorical_column, state_manager), + weight_tensor) + + def get_sparse_tensors(self, transformation_cache, state_manager): + """See `CategoricalColumn` base class.""" + tensors = transformation_cache.get(self, state_manager) + return CategoricalColumn.IdWeightPair(tensors[0], tensors[1]) + + +class CrossedColumn( + CategoricalColumn, + collections.namedtuple('CrossedColumn', + ('keys', 'hash_bucket_size', 'hash_key'))): + """See `crossed_column`.""" + + @property + def name(self): + """See `FeatureColumn` base class.""" + feature_names = [] + for key in _collect_leaf_level_keys(self): + if isinstance(key, FeatureColumn): + feature_names.append(key.name) + else: # key must be a string + feature_names.append(key) + return '_X_'.join(sorted(feature_names)) + + @property + def parse_example_spec(self): + """See `FeatureColumn` base class.""" + config = {} + for key in self.keys: + if isinstance(key, FeatureColumn): + config.update(key.parse_example_spec) + else: # key must be a string + config.update({key: parsing_ops.VarLenFeature(dtypes.string)}) + return config + + def transform_feature(self, transformation_cache, state_manager): + """Generates a hashed sparse cross from the input tensors.""" + feature_tensors = [] + for key in _collect_leaf_level_keys(self): + if isinstance(key, six.string_types): + feature_tensors.append(transformation_cache.get(key, state_manager)) + elif isinstance(key, CategoricalColumn): + ids_and_weights = key.get_sparse_tensors(transformation_cache, + state_manager) + if ids_and_weights.weight_tensor is not None: + raise ValueError( + 'crossed_column does not support weight_tensor, but the given ' + 'column populates weight_tensor. ' + 'Given column: {}'.format(key.name)) + feature_tensors.append(ids_and_weights.id_tensor) + else: + raise ValueError('Unsupported column type. Given: {}'.format(key)) + return sparse_ops.sparse_cross_hashed( + inputs=feature_tensors, + num_buckets=self.hash_bucket_size, + hash_key=self.hash_key) + + @property + def num_buckets(self): + """Returns number of buckets in this sparse feature.""" + return self.hash_bucket_size + + def get_sparse_tensors(self, transformation_cache, state_manager): + """See `CategoricalColumn` base class.""" + return CategoricalColumn.IdWeightPair( + transformation_cache.get(self, state_manager), None) + + +def _collect_leaf_level_keys(cross): + """Collects base keys by expanding all nested crosses. + + Args: + cross: A `CrossedColumn`. + + Returns: + A list of strings or `CategoricalColumn` instances. + """ + leaf_level_keys = [] + for k in cross.keys: + if isinstance(k, CrossedColumn): + leaf_level_keys.extend(_collect_leaf_level_keys(k)) + else: + leaf_level_keys.append(k) + return leaf_level_keys + + +# TODO(zakaria): Move this to embedding_ops and make it public. +def _safe_embedding_lookup_sparse(embedding_weights, + sparse_ids, + sparse_weights=None, + combiner='mean', + default_id=None, + name=None, + partition_strategy='div', + max_norm=None): + """Lookup embedding results, accounting for invalid IDs and empty features. + + The partitioned embedding in `embedding_weights` must all be the same shape + except for the first dimension. The first dimension is allowed to vary as the + vocabulary size is not necessarily a multiple of `P`. `embedding_weights` + may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a + partitioner. + + Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs + with non-positive weight. For an entry with no features, the embedding vector + for `default_id` is returned, or the 0-vector if `default_id` is not supplied. + + The ids and weights may be multi-dimensional. Embeddings are always aggregated + along the last dimension. + + Args: + embedding_weights: A list of `P` float `Tensor`s or values representing + partitioned embedding `Tensor`s. Alternatively, a `PartitionedVariable` + created by partitioning along dimension 0. The total unpartitioned + shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the + vocab size and `e_1, ..., e_m` are the embedding dimensions. + sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the + ids. `d_0` is typically batch size. + sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing + float weights corresponding to `sparse_ids`, or `None` if all weights + are be assumed to be 1.0. + combiner: A string specifying how to combine embedding results for each + entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" + the default. + default_id: The id to use for an entry with no features. + name: A name for this operation (optional). + partition_strategy: A string specifying the partitioning strategy. + Currently `"div"` and `"mod"` are supported. Default is `"div"`. + max_norm: If not `None`, all embeddings are l2-normalized to max_norm before + combining. + + + Returns: + Dense `Tensor` of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`. + + Raises: + ValueError: if `embedding_weights` is empty. + """ + if embedding_weights is None: + raise ValueError('Missing embedding_weights %s.' % embedding_weights) + if isinstance(embedding_weights, variables.PartitionedVariable): + embedding_weights = list(embedding_weights) # get underlying Variables. + if not isinstance(embedding_weights, list): + embedding_weights = [embedding_weights] + if len(embedding_weights) < 1: + raise ValueError('Missing embedding_weights %s.' % embedding_weights) + + dtype = sparse_weights.dtype if sparse_weights is not None else None + embedding_weights = [ + ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights + ] + + with ops.name_scope(name, 'embedding_lookup', + embedding_weights + [sparse_ids, + sparse_weights]) as scope: + # Reshape higher-rank sparse ids and weights to linear segment ids. + original_shape = sparse_ids.dense_shape + original_rank_dim = sparse_ids.dense_shape.get_shape()[0] + original_rank = ( + array_ops.size(original_shape) + if original_rank_dim.value is None + else original_rank_dim.value) + sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [ + math_ops.reduce_prod( + array_ops.slice(original_shape, [0], [original_rank - 1])), + array_ops.gather(original_shape, original_rank - 1)]) + if sparse_weights is not None: + sparse_weights = sparse_tensor_lib.SparseTensor( + sparse_ids.indices, + sparse_weights.values, sparse_ids.dense_shape) + + # Prune invalid ids and weights. + sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights) + if combiner != 'sum': + sparse_ids, sparse_weights = _prune_invalid_weights( + sparse_ids, sparse_weights) + + # Fill in dummy values for empty features, if necessary. + sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids, + default_id or + 0) + if sparse_weights is not None: + sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0) + + result = embedding_ops.embedding_lookup_sparse( + embedding_weights, + sparse_ids, + sparse_weights, + combiner=combiner, + partition_strategy=partition_strategy, + name=None if default_id is None else scope, + max_norm=max_norm) + + if default_id is None: + # Broadcast is_row_empty to the same shape as embedding_lookup_result, + # for use in Select. + is_row_empty = array_ops.tile( + array_ops.reshape(is_row_empty, [-1, 1]), + array_ops.stack([1, array_ops.shape(result)[1]])) + + result = array_ops.where(is_row_empty, + array_ops.zeros_like(result), + result, + name=scope) + + # Reshape back from linear ids back into higher-dimensional dense result. + final_result = array_ops.reshape( + result, + array_ops.concat([ + array_ops.slice( + math_ops.cast(original_shape, dtypes.int32), [0], + [original_rank - 1]), + array_ops.slice(array_ops.shape(result), [1], [-1]) + ], 0)) + final_result.set_shape(tensor_shape.unknown_shape( + (original_rank_dim - 1).value).concatenate(result.get_shape()[1:])) + return final_result + + +def _prune_invalid_ids(sparse_ids, sparse_weights): + """Prune invalid IDs (< 0) from the input ids and weights.""" + is_id_valid = math_ops.greater_equal(sparse_ids.values, 0) + if sparse_weights is not None: + is_id_valid = math_ops.logical_and( + is_id_valid, + array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool)) + sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid) + if sparse_weights is not None: + sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid) + return sparse_ids, sparse_weights + + +def _prune_invalid_weights(sparse_ids, sparse_weights): + """Prune invalid weights (< 0) from the input ids and weights.""" + if sparse_weights is not None: + is_weights_valid = math_ops.greater(sparse_weights.values, 0) + sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid) + sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid) + return sparse_ids, sparse_weights + + +class IndicatorColumn(DenseColumn, SequenceDenseColumn, + collections.namedtuple('IndicatorColumn', + ('categorical_column'))): + """Represents a one-hot column for use in deep networks. + + Args: + categorical_column: A `CategoricalColumn` which is created by + `categorical_column_with_*` function. + """ + + @property + def name(self): + """See `FeatureColumn` base class.""" + return '{}_indicator'.format(self.categorical_column.name) + + def transform_feature(self, transformation_cache, state_manager): + """Returns dense `Tensor` representing feature. + + Args: + transformation_cache: A `FeatureTransformationCache` object to access + features. + state_manager: A `StateManager` to create / access resources such as + lookup tables. + + Returns: + Transformed feature `Tensor`. + + Raises: + ValueError: if input rank is not known at graph building time. + """ + id_weight_pair = self.categorical_column.get_sparse_tensors( + transformation_cache, state_manager) + id_tensor = id_weight_pair.id_tensor + weight_tensor = id_weight_pair.weight_tensor + + # If the underlying column is weighted, return the input as a dense tensor. + if weight_tensor is not None: + weighted_column = sparse_ops.sparse_merge( + sp_ids=id_tensor, + sp_values=weight_tensor, + vocab_size=int(self.variable_shape[-1])) + # Remove (?, -1) index + weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0], + weighted_column.dense_shape) + return sparse_ops.sparse_tensor_to_dense(weighted_column) + + dense_id_tensor = sparse_ops.sparse_tensor_to_dense( + id_tensor, default_value=-1) + + # One hot must be float for tf.concat reasons since all other inputs to + # input_layer are float32. + one_hot_id_tensor = array_ops.one_hot( + dense_id_tensor, + depth=self.variable_shape[-1], + on_value=1.0, + off_value=0.0) + + # Reduce to get a multi-hot per example. + return math_ops.reduce_sum(one_hot_id_tensor, axis=[-2]) + + @property + def parse_example_spec(self): + """See `FeatureColumn` base class.""" + return self.categorical_column.parse_example_spec + + @property + def variable_shape(self): + """Returns a `TensorShape` representing the shape of the dense `Tensor`.""" + return tensor_shape.TensorShape([1, self.categorical_column.num_buckets]) + + def get_dense_tensor(self, transformation_cache, state_manager): + """Returns dense `Tensor` representing feature. + + Args: + transformation_cache: A `FeatureTransformationCache` object to access + features. + state_manager: A `StateManager` to create / access resources such as + lookup tables. + + Returns: + Dense `Tensor` created within `transform_feature`. + + Raises: + ValueError: If `categorical_column` is a `SequenceCategoricalColumn`. + """ + if isinstance(self.categorical_column, SequenceCategoricalColumn): + raise ValueError( + 'In indicator_column: {}. ' + 'categorical_column must not be of type SequenceCategoricalColumn. ' + 'Suggested fix A: If you wish to use input_layer, use a ' + 'non-sequence categorical_column_with_*. ' + 'Suggested fix B: If you wish to create sequence input, use ' + 'sequence_input_layer instead of input_layer. ' + 'Given (type {}): {}'.format(self.name, type(self.categorical_column), + self.categorical_column)) + # Feature has been already transformed. Return the intermediate + # representation created by transform_feature. + return transformation_cache.get(self, state_manager) + + def get_sequence_dense_tensor(self, transformation_cache, state_manager): + """See `SequenceDenseColumn` base class.""" + if not isinstance(self.categorical_column, SequenceCategoricalColumn): + raise ValueError( + 'In indicator_column: {}. ' + 'categorical_column must be of type SequenceCategoricalColumn ' + 'to use sequence_input_layer. ' + 'Suggested fix: Use one of sequence_categorical_column_with_*. ' + 'Given (type {}): {}'.format(self.name, type(self.categorical_column), + self.categorical_column)) + # Feature has been already transformed. Return the intermediate + # representation created by transform_feature. + dense_tensor = transformation_cache.get(self, state_manager) + sparse_tensors = self.categorical_column.get_sparse_tensors( + transformation_cache, state_manager) + sequence_length = _sequence_length_from_sparse_tensor( + sparse_tensors.id_tensor) + return SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) + + +def _verify_static_batch_size_equality(tensors, columns): + # bath_size is a tf.Dimension object. + expected_batch_size = None + for i in range(0, len(tensors)): + if tensors[i].shape[0].value is not None: + if expected_batch_size is None: + bath_size_column_index = i + expected_batch_size = tensors[i].shape[0] + elif not expected_batch_size.is_compatible_with(tensors[i].shape[0]): + raise ValueError( + 'Batch size (first dimension) of each feature must be same. ' + 'Batch size of columns ({}, {}): ({}, {})'.format( + columns[bath_size_column_index].name, columns[i].name, + expected_batch_size, tensors[i].shape[0])) + + +def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1): + """Returns a [batch_size] Tensor with per-example sequence length.""" + with ops.name_scope(None, 'sequence_length') as name_scope: + row_ids = sp_tensor.indices[:, 0] + column_ids = sp_tensor.indices[:, 1] + column_ids += array_ops.ones_like(column_ids) + seq_length = math_ops.to_int64( + math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements) + # If the last n rows do not have ids, seq_length will have shape + # [batch_size - n]. Pad the remaining values with zeros. + n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1] + padding = array_ops.zeros(n_pad, dtype=seq_length.dtype) + return array_ops.concat([seq_length, padding], axis=0, name=name_scope) + + +class SequenceCategoricalColumn(FeatureColumn, + collections.namedtuple( + 'SequenceCategoricalColumn', + ('categorical_column'))): + """Represents sequences of categorical data.""" + + @property + def name(self): + """See `FeatureColumn` base class.""" + return self.categorical_column.name + + @property + def parse_example_spec(self): + """See `FeatureColumn` base class.""" + return self.categorical_column.parse_example_spec + + def transform_feature(self, transformation_cache, state_manager): + """See `FeatureColumn` base class.""" + return self.categorical_column.transform_feature(transformation_cache, + state_manager) + + @property + def num_buckets(self): + """Returns number of buckets in this sparse feature.""" + return self.categorical_column.num_buckets + + def get_sequence_sparse_tensors(self, transformation_cache, state_manager): + """Returns an IdWeightPair. + + `IdWeightPair` is a pair of `SparseTensor`s which represents ids and + weights. + + `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets` + `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a + `SparseTensor` of `float` or `None` to indicate all weights should be + taken to be 1. If specified, `weight_tensor` must have exactly the same + shape and indices as `sp_ids`. Expected `SparseTensor` is same as parsing + output of a `VarLenFeature` which is a ragged matrix. + + Args: + transformation_cache: A `FeatureTransformationCache` object to access + features. + state_manager: A `StateManager` to create / access resources such as + lookup tables. + """ + sparse_tensors = self.categorical_column.get_sparse_tensors( + transformation_cache, state_manager) + id_tensor = sparse_tensors.id_tensor + weight_tensor = sparse_tensors.weight_tensor + # Expands final dimension, so that embeddings are not combined during + # embedding lookup. + check_id_rank = check_ops.assert_equal( + array_ops.rank(id_tensor), 2, + data=[ + 'Column {} expected ID tensor of rank 2. '.format(self.name), + 'id_tensor shape: ', array_ops.shape(id_tensor)]) + with ops.control_dependencies([check_id_rank]): + id_tensor = sparse_ops.sparse_reshape( + id_tensor, + shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0)) + if weight_tensor is not None: + check_weight_rank = check_ops.assert_equal( + array_ops.rank(weight_tensor), 2, + data=[ + 'Column {} expected weight tensor of rank 2.'.format(self.name), + 'weight_tensor shape:', array_ops.shape(weight_tensor)]) + with ops.control_dependencies([check_weight_rank]): + weight_tensor = sparse_ops.sparse_reshape( + weight_tensor, + shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0)) + return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor) diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..80a9d5d40e275fce664ef52e5d5413930432d683 --- /dev/null +++ b/tensorflow/python/feature_column/feature_column_v2_test.py @@ -0,0 +1,6583 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for feature_column.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import copy + +import numpy as np + +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.client import session +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column as fc_old +from tensorflow.python.feature_column import feature_column_v2 as fc +from tensorflow.python.feature_column.feature_column_v2 import FeatureColumn +from tensorflow.python.feature_column.feature_column_v2 import FeatureTransformationCache +from tensorflow.python.feature_column.feature_column_v2 import InputLayer +from tensorflow.python.feature_column.feature_column_v2 import StateManager +from tensorflow.python.feature_column.feature_column_v2 import _LinearModel +from tensorflow.python.feature_column.feature_column_v2 import _transform_features +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.platform import test +from tensorflow.python.training import coordinator +from tensorflow.python.training import queue_runner_impl + + +def _initialized_session(config=None): + sess = session.Session(config=config) + sess.run(variables_lib.global_variables_initializer()) + sess.run(lookup_ops.tables_initializer()) + return sess + + +class LazyColumnTest(test.TestCase): + + def test_transformations_called_once(self): + + class TransformCounter(FeatureColumn): + + def __init__(self): + self.num_transform = 0 + + @property + def name(self): + return 'TransformCounter' + + def transform_feature(self, transformation_cache, state_manager): + self.num_transform += 1 # Count transform calls. + return transformation_cache.get('a', state_manager) + + @property + def parse_example_spec(self): + pass + + transformation_cache = FeatureTransformationCache( + features={'a': [[2], [3.]]}) + column = TransformCounter() + self.assertEqual(0, column.num_transform) + transformation_cache.get(column, None) + self.assertEqual(1, column.num_transform) + transformation_cache.get(column, None) + self.assertEqual(1, column.num_transform) + + def test_returns_transform_output(self): + + class Transformer(FeatureColumn): + + @property + def name(self): + return 'Transformer' + + def transform_feature(self, transformation_cache, state_manager): + return 'Output' + + @property + def parse_example_spec(self): + pass + + transformation_cache = FeatureTransformationCache( + features={'a': [[2], [3.]]}) + column = Transformer() + self.assertEqual('Output', transformation_cache.get(column, None)) + self.assertEqual('Output', transformation_cache.get(column, None)) + + def test_does_not_pollute_given_features_dict(self): + + class Transformer(FeatureColumn): + + @property + def name(self): + return 'Transformer' + + def transform_feature(self, transformation_cache, state_manager): + return 'Output' + + @property + def parse_example_spec(self): + pass + + features = {'a': [[2], [3.]]} + transformation_cache = FeatureTransformationCache(features=features) + transformation_cache.get(Transformer(), None) + self.assertEqual(['a'], list(features.keys())) + + def test_error_if_feature_is_not_found(self): + transformation_cache = FeatureTransformationCache( + features={'a': [[2], [3.]]}) + with self.assertRaisesRegexp(ValueError, + 'bbb is not in features dictionary'): + transformation_cache.get('bbb', None) + with self.assertRaisesRegexp(ValueError, + 'bbb is not in features dictionary'): + transformation_cache.get(u'bbb', None) + + def test_not_supported_feature_column(self): + + class NotAProperColumn(FeatureColumn): + + @property + def name(self): + return 'NotAProperColumn' + + def transform_feature(self, transformation_cache, state_manager): + # It should return not None. + pass + + @property + def parse_example_spec(self): + pass + + transformation_cache = FeatureTransformationCache( + features={'a': [[2], [3.]]}) + with self.assertRaisesRegexp(ValueError, + 'NotAProperColumn is not supported'): + transformation_cache.get(NotAProperColumn(), None) + + def test_key_should_be_string_or_feature_colum(self): + + class NotAFeatureColumn(object): + pass + + transformation_cache = FeatureTransformationCache( + features={'a': [[2], [3.]]}) + with self.assertRaisesRegexp( + TypeError, '"key" must be either a "str" or "FeatureColumn".'): + transformation_cache.get(NotAFeatureColumn(), None) + + +class NumericColumnTest(test.TestCase): + + def test_defaults(self): + a = fc.numeric_column('aaa') + self.assertEqual('aaa', a.key) + self.assertEqual('aaa', a.name) + self.assertEqual((1,), a.shape) + self.assertIsNone(a.default_value) + self.assertEqual(dtypes.float32, a.dtype) + self.assertIsNone(a.normalizer_fn) + + def test_key_should_be_string(self): + with self.assertRaisesRegexp(ValueError, 'key must be a string.'): + fc.numeric_column(key=('aaa',)) + + def test_shape_saved_as_tuple(self): + a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3, 2.]]) + self.assertEqual((1, 2), a.shape) + + def test_default_value_saved_as_tuple(self): + a = fc.numeric_column('aaa', default_value=4.) + self.assertEqual((4.,), a.default_value) + a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3, 2.]]) + self.assertEqual(((3., 2.),), a.default_value) + + def test_shape_and_default_value_compatibility(self): + fc.numeric_column('aaa', shape=[2], default_value=[1, 2.]) + with self.assertRaisesRegexp(ValueError, 'The shape of default_value'): + fc.numeric_column('aaa', shape=[2], default_value=[1, 2, 3.]) + fc.numeric_column( + 'aaa', shape=[3, 2], default_value=[[2, 3], [1, 2], [2, 3.]]) + with self.assertRaisesRegexp(ValueError, 'The shape of default_value'): + fc.numeric_column( + 'aaa', shape=[3, 1], default_value=[[2, 3], [1, 2], [2, 3.]]) + with self.assertRaisesRegexp(ValueError, 'The shape of default_value'): + fc.numeric_column( + 'aaa', shape=[3, 3], default_value=[[2, 3], [1, 2], [2, 3.]]) + + def test_default_value_type_check(self): + fc.numeric_column( + 'aaa', shape=[2], default_value=[1, 2.], dtype=dtypes.float32) + fc.numeric_column( + 'aaa', shape=[2], default_value=[1, 2], dtype=dtypes.int32) + with self.assertRaisesRegexp(TypeError, 'must be compatible with dtype'): + fc.numeric_column( + 'aaa', shape=[2], default_value=[1, 2.], dtype=dtypes.int32) + with self.assertRaisesRegexp(TypeError, + 'default_value must be compatible with dtype'): + fc.numeric_column('aaa', default_value=['string']) + + def test_shape_must_be_positive_integer(self): + with self.assertRaisesRegexp(TypeError, 'shape dimensions must be integer'): + fc.numeric_column( + 'aaa', shape=[ + 1.0, + ]) + + with self.assertRaisesRegexp(ValueError, + 'shape dimensions must be greater than 0'): + fc.numeric_column( + 'aaa', shape=[ + 0, + ]) + + def test_dtype_is_convertible_to_float(self): + with self.assertRaisesRegexp(ValueError, + 'dtype must be convertible to float'): + fc.numeric_column('aaa', dtype=dtypes.string) + + def test_scalar_default_value_fills_the_shape(self): + a = fc.numeric_column('aaa', shape=[2, 3], default_value=2.) + self.assertEqual(((2., 2., 2.), (2., 2., 2.)), a.default_value) + + def test_parse_spec(self): + a = fc.numeric_column('aaa', shape=[2, 3], dtype=dtypes.int32) + self.assertEqual({ + 'aaa': parsing_ops.FixedLenFeature((2, 3), dtype=dtypes.int32) + }, a.parse_example_spec) + + def test_parse_example_no_default_value(self): + price = fc.numeric_column('price', shape=[2]) + data = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'price': + feature_pb2.Feature(float_list=feature_pb2.FloatList( + value=[20., 110.])) + })) + features = parsing_ops.parse_example( + serialized=[data.SerializeToString()], + features=fc.make_parse_example_spec([price])) + self.assertIn('price', features) + with self.test_session(): + self.assertAllEqual([[20., 110.]], features['price'].eval()) + + def test_parse_example_with_default_value(self): + price = fc.numeric_column('price', shape=[2], default_value=11.) + data = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'price': + feature_pb2.Feature(float_list=feature_pb2.FloatList( + value=[20., 110.])) + })) + no_data = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'something_else': + feature_pb2.Feature(float_list=feature_pb2.FloatList( + value=[20., 110.])) + })) + features = parsing_ops.parse_example( + serialized=[data.SerializeToString(), + no_data.SerializeToString()], + features=fc.make_parse_example_spec([price])) + self.assertIn('price', features) + with self.test_session(): + self.assertAllEqual([[20., 110.], [11., 11.]], features['price'].eval()) + + def test_normalizer_fn_must_be_callable(self): + with self.assertRaisesRegexp(TypeError, 'must be a callable'): + fc.numeric_column('price', normalizer_fn='NotACallable') + + def test_normalizer_fn_transform_feature(self): + + def _increment_two(input_tensor): + return input_tensor + 2. + + price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two) + output = _transform_features({'price': [[1., 2.], [5., 6.]]}, [price], None) + with self.test_session(): + self.assertAllEqual([[3., 4.], [7., 8.]], output[price].eval()) + + def test_get_dense_tensor(self): + + def _increment_two(input_tensor): + return input_tensor + 2. + + price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two) + transformation_cache = FeatureTransformationCache({ + 'price': [[1., 2.], [5., 6.]] + }) + self.assertEqual( + transformation_cache.get(price, None), + price.get_dense_tensor(transformation_cache, None)) + + def test_sparse_tensor_not_supported(self): + price = fc.numeric_column('price') + transformation_cache = FeatureTransformationCache({ + 'price': + sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[0.3], dense_shape=[1, 1]) + }) + with self.assertRaisesRegexp(ValueError, 'must be a Tensor'): + price.transform_feature(transformation_cache, None) + + def test_deep_copy(self): + a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3., 2.]]) + a_copy = copy.deepcopy(a) + self.assertEqual(a_copy.name, 'aaa') + self.assertEqual(a_copy.shape, (1, 2)) + self.assertEqual(a_copy.default_value, ((3., 2.),)) + + def test_numpy_default_value(self): + a = fc.numeric_column( + 'aaa', shape=[1, 2], default_value=np.array([[3., 2.]])) + self.assertEqual(a.default_value, ((3., 2.),)) + + def test_linear_model(self): + price = fc_old.numeric_column('price') + with ops.Graph().as_default(): + features = {'price': [[1.], [5.]]} + predictions = fc.linear_model(features, [price]) + bias = get_linear_model_bias() + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + self.assertAllClose([[0.]], price_var.eval()) + self.assertAllClose([[0.], [0.]], predictions.eval()) + sess.run(price_var.assign([[10.]])) + self.assertAllClose([[10.], [50.]], predictions.eval()) + + def test_keras_linear_model(self): + price = fc_old.numeric_column('price') + with ops.Graph().as_default(): + features = {'price': [[1.], [5.]]} + predictions = get_keras_linear_model_predictions(features, [price]) + bias = get_linear_model_bias() + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + self.assertAllClose([[0.]], price_var.eval()) + self.assertAllClose([[0.], [0.]], predictions.eval()) + sess.run(price_var.assign([[10.]])) + self.assertAllClose([[10.], [50.]], predictions.eval()) + + +class BucketizedColumnTest(test.TestCase): + + def test_invalid_source_column_type(self): + a = fc.categorical_column_with_hash_bucket('aaa', hash_bucket_size=10) + with self.assertRaisesRegexp( + ValueError, + 'source_column must be a column generated with numeric_column'): + fc.bucketized_column(a, boundaries=[0, 1]) + + def test_invalid_source_column_shape(self): + a = fc.numeric_column('aaa', shape=[2, 3]) + with self.assertRaisesRegexp( + ValueError, 'source_column must be one-dimensional column'): + fc.bucketized_column(a, boundaries=[0, 1]) + + def test_invalid_boundaries(self): + a = fc.numeric_column('aaa') + with self.assertRaisesRegexp( + ValueError, 'boundaries must be a sorted list'): + fc.bucketized_column(a, boundaries=None) + with self.assertRaisesRegexp( + ValueError, 'boundaries must be a sorted list'): + fc.bucketized_column(a, boundaries=1.) + with self.assertRaisesRegexp( + ValueError, 'boundaries must be a sorted list'): + fc.bucketized_column(a, boundaries=[1, 0]) + with self.assertRaisesRegexp( + ValueError, 'boundaries must be a sorted list'): + fc.bucketized_column(a, boundaries=[1, 1]) + + def test_name(self): + a = fc.numeric_column('aaa', dtype=dtypes.int32) + b = fc.bucketized_column(a, boundaries=[0, 1]) + self.assertEqual('aaa_bucketized', b.name) + + def test_parse_spec(self): + a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32) + b = fc.bucketized_column(a, boundaries=[0, 1]) + self.assertEqual({ + 'aaa': parsing_ops.FixedLenFeature((2,), dtype=dtypes.int32) + }, b.parse_example_spec) + + def test_variable_shape(self): + a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32) + b = fc.bucketized_column(a, boundaries=[0, 1]) + # Column 'aaa` has shape [2] times three buckets -> variable_shape=[2, 3]. + self.assertAllEqual((2, 3), b.variable_shape) + + def test_num_buckets(self): + a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32) + b = fc.bucketized_column(a, boundaries=[0, 1]) + # Column 'aaa` has shape [2] times three buckets -> num_buckets=6. + self.assertEqual(6, b.num_buckets) + + def test_parse_example(self): + price = fc.numeric_column('price', shape=[2]) + bucketized_price = fc.bucketized_column(price, boundaries=[0, 50]) + data = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'price': + feature_pb2.Feature(float_list=feature_pb2.FloatList( + value=[20., 110.])) + })) + features = parsing_ops.parse_example( + serialized=[data.SerializeToString()], + features=fc.make_parse_example_spec([bucketized_price])) + self.assertIn('price', features) + with self.test_session(): + self.assertAllEqual([[20., 110.]], features['price'].eval()) + + def test_transform_feature(self): + price = fc.numeric_column('price', shape=[2]) + bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6]) + with ops.Graph().as_default(): + transformed_tensor = _transform_features({ + 'price': [[-1., 1.], [5., 6.]] + }, [bucketized_price], None) + with _initialized_session(): + self.assertAllEqual([[0, 1], [3, 4]], + transformed_tensor[bucketized_price].eval()) + + def test_get_dense_tensor_one_input_value(self): + """Tests _get_dense_tensor() for input with shape=[1].""" + price = fc.numeric_column('price', shape=[1]) + bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6]) + with ops.Graph().as_default(): + transformation_cache = FeatureTransformationCache({ + 'price': [[-1.], [1.], [5.], [6.]] + }) + with _initialized_session(): + bucketized_price_tensor = bucketized_price.get_dense_tensor( + transformation_cache, None) + self.assertAllClose( + # One-hot tensor. + [[[1., 0., 0., 0., 0.]], + [[0., 1., 0., 0., 0.]], + [[0., 0., 0., 1., 0.]], + [[0., 0., 0., 0., 1.]]], + bucketized_price_tensor.eval()) + + def test_get_dense_tensor_two_input_values(self): + """Tests _get_dense_tensor() for input with shape=[2].""" + price = fc.numeric_column('price', shape=[2]) + bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6]) + with ops.Graph().as_default(): + transformation_cache = FeatureTransformationCache({ + 'price': [[-1., 1.], [5., 6.]] + }) + with _initialized_session(): + bucketized_price_tensor = bucketized_price.get_dense_tensor( + transformation_cache, None) + self.assertAllClose( + # One-hot tensor. + [[[1., 0., 0., 0., 0.], [0., 1., 0., 0., 0.]], + [[0., 0., 0., 1., 0.], [0., 0., 0., 0., 1.]]], + bucketized_price_tensor.eval()) + + def test_get_sparse_tensors_one_input_value(self): + """Tests _get_sparse_tensors() for input with shape=[1].""" + price = fc.numeric_column('price', shape=[1]) + bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6]) + with ops.Graph().as_default(): + transformation_cache = FeatureTransformationCache({ + 'price': [[-1.], [1.], [5.], [6.]] + }) + with _initialized_session() as sess: + id_weight_pair = bucketized_price.get_sparse_tensors( + transformation_cache, None) + self.assertIsNone(id_weight_pair.weight_tensor) + id_tensor_value = sess.run(id_weight_pair.id_tensor) + self.assertAllEqual( + [[0, 0], [1, 0], [2, 0], [3, 0]], id_tensor_value.indices) + self.assertAllEqual([0, 1, 3, 4], id_tensor_value.values) + self.assertAllEqual([4, 1], id_tensor_value.dense_shape) + + def test_get_sparse_tensors_two_input_values(self): + """Tests _get_sparse_tensors() for input with shape=[2].""" + price = fc.numeric_column('price', shape=[2]) + bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6]) + with ops.Graph().as_default(): + transformation_cache = FeatureTransformationCache({ + 'price': [[-1., 1.], [5., 6.]] + }) + with _initialized_session() as sess: + id_weight_pair = bucketized_price.get_sparse_tensors( + transformation_cache, None) + self.assertIsNone(id_weight_pair.weight_tensor) + id_tensor_value = sess.run(id_weight_pair.id_tensor) + self.assertAllEqual( + [[0, 0], [0, 1], [1, 0], [1, 1]], id_tensor_value.indices) + # Values 0-4 correspond to the first column of the input price. + # Values 5-9 correspond to the second column of the input price. + self.assertAllEqual([0, 6, 3, 9], id_tensor_value.values) + self.assertAllEqual([2, 2], id_tensor_value.dense_shape) + + def test_sparse_tensor_input_not_supported(self): + price = fc.numeric_column('price') + bucketized_price = fc.bucketized_column(price, boundaries=[0, 1]) + transformation_cache = FeatureTransformationCache({ + 'price': + sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[0.3], dense_shape=[1, 1]) + }) + with self.assertRaisesRegexp(ValueError, 'must be a Tensor'): + bucketized_price.transform_feature(transformation_cache, None) + + def test_deep_copy(self): + a = fc.numeric_column('aaa', shape=[2]) + a_bucketized = fc.bucketized_column(a, boundaries=[0, 1]) + a_bucketized_copy = copy.deepcopy(a_bucketized) + self.assertEqual(a_bucketized_copy.name, 'aaa_bucketized') + self.assertAllEqual(a_bucketized_copy.variable_shape, (2, 3)) + self.assertEqual(a_bucketized_copy.boundaries, (0, 1)) + + def test_linear_model_one_input_value(self): + """Tests linear_model() for input with shape=[1].""" + price = fc_old.numeric_column('price', shape=[1]) + bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6]) + with ops.Graph().as_default(): + features = {'price': [[-1.], [1.], [5.], [6.]]} + predictions = fc.linear_model(features, [bucketized_price]) + bias = get_linear_model_bias() + bucketized_price_var = get_linear_model_column_var(bucketized_price) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + # One weight variable per bucket, all initialized to zero. + self.assertAllClose( + [[0.], [0.], [0.], [0.], [0.]], bucketized_price_var.eval()) + self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval()) + sess.run(bucketized_price_var.assign( + [[10.], [20.], [30.], [40.], [50.]])) + # price -1. is in the 0th bucket, whose weight is 10. + # price 1. is in the 1st bucket, whose weight is 20. + # price 5. is in the 3rd bucket, whose weight is 40. + # price 6. is in the 4th bucket, whose weight is 50. + self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval()) + sess.run(bias.assign([1.])) + self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval()) + + def test_linear_model_two_input_values(self): + """Tests linear_model() for input with shape=[2].""" + price = fc_old.numeric_column('price', shape=[2]) + bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6]) + with ops.Graph().as_default(): + features = {'price': [[-1., 1.], [5., 6.]]} + predictions = fc.linear_model(features, [bucketized_price]) + bias = get_linear_model_bias() + bucketized_price_var = get_linear_model_column_var(bucketized_price) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + # One weight per bucket per input column, all initialized to zero. + self.assertAllClose( + [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], + bucketized_price_var.eval()) + self.assertAllClose([[0.], [0.]], predictions.eval()) + sess.run(bucketized_price_var.assign( + [[10.], [20.], [30.], [40.], [50.], + [60.], [70.], [80.], [90.], [100.]])) + # 1st example: + # price -1. is in the 0th bucket, whose weight is 10. + # price 1. is in the 6th bucket, whose weight is 70. + # 2nd example: + # price 5. is in the 3rd bucket, whose weight is 40. + # price 6. is in the 9th bucket, whose weight is 100. + self.assertAllClose([[80.], [140.]], predictions.eval()) + sess.run(bias.assign([1.])) + self.assertAllClose([[81.], [141.]], predictions.eval()) + + def test_keras_linear_model_one_input_value(self): + """Tests _LinearModel for input with shape=[1].""" + price = fc_old.numeric_column('price', shape=[1]) + bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6]) + with ops.Graph().as_default(): + features = {'price': [[-1.], [1.], [5.], [6.]]} + predictions = get_keras_linear_model_predictions(features, + [bucketized_price]) + bias = get_linear_model_bias() + bucketized_price_var = get_linear_model_column_var(bucketized_price) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + # One weight variable per bucket, all initialized to zero. + self.assertAllClose([[0.], [0.], [0.], [0.], [0.]], + bucketized_price_var.eval()) + self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval()) + sess.run( + bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.]])) + # price -1. is in the 0th bucket, whose weight is 10. + # price 1. is in the 1st bucket, whose weight is 20. + # price 5. is in the 3rd bucket, whose weight is 40. + # price 6. is in the 4th bucket, whose weight is 50. + self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval()) + sess.run(bias.assign([1.])) + self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval()) + + def test_keras_linear_model_two_input_values(self): + """Tests _LinearModel for input with shape=[2].""" + price = fc_old.numeric_column('price', shape=[2]) + bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6]) + with ops.Graph().as_default(): + features = {'price': [[-1., 1.], [5., 6.]]} + predictions = get_keras_linear_model_predictions(features, + [bucketized_price]) + bias = get_linear_model_bias() + bucketized_price_var = get_linear_model_column_var(bucketized_price) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + # One weight per bucket per input column, all initialized to zero. + self.assertAllClose( + [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], + bucketized_price_var.eval()) + self.assertAllClose([[0.], [0.]], predictions.eval()) + sess.run( + bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.], + [60.], [70.], [80.], [90.], [100.]])) + # 1st example: + # price -1. is in the 0th bucket, whose weight is 10. + # price 1. is in the 6th bucket, whose weight is 70. + # 2nd example: + # price 5. is in the 3rd bucket, whose weight is 40. + # price 6. is in the 9th bucket, whose weight is 100. + self.assertAllClose([[80.], [140.]], predictions.eval()) + sess.run(bias.assign([1.])) + self.assertAllClose([[81.], [141.]], predictions.eval()) + + +class HashedCategoricalColumnTest(test.TestCase): + + def test_defaults(self): + a = fc.categorical_column_with_hash_bucket('aaa', 10) + self.assertEqual('aaa', a.name) + self.assertEqual('aaa', a.key) + self.assertEqual(10, a.hash_bucket_size) + self.assertEqual(dtypes.string, a.dtype) + + def test_key_should_be_string(self): + with self.assertRaisesRegexp(ValueError, 'key must be a string.'): + fc.categorical_column_with_hash_bucket(('key',), 10) + + def test_bucket_size_should_be_given(self): + with self.assertRaisesRegexp(ValueError, 'hash_bucket_size must be set.'): + fc.categorical_column_with_hash_bucket('aaa', None) + + def test_bucket_size_should_be_positive(self): + with self.assertRaisesRegexp(ValueError, + 'hash_bucket_size must be at least 1'): + fc.categorical_column_with_hash_bucket('aaa', 0) + + def test_dtype_should_be_string_or_integer(self): + fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.string) + fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.int32) + with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'): + fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.float32) + + def test_deep_copy(self): + original = fc.categorical_column_with_hash_bucket('aaa', 10) + for column in (original, copy.deepcopy(original)): + self.assertEqual('aaa', column.name) + self.assertEqual(10, column.hash_bucket_size) + self.assertEqual(10, column.num_buckets) + self.assertEqual(dtypes.string, column.dtype) + + def test_parse_spec_string(self): + a = fc.categorical_column_with_hash_bucket('aaa', 10) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.string) + }, a.parse_example_spec) + + def test_parse_spec_int(self): + a = fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.int32) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int32) + }, a.parse_example_spec) + + def test_parse_example(self): + a = fc.categorical_column_with_hash_bucket('aaa', 10) + data = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'aaa': + feature_pb2.Feature(bytes_list=feature_pb2.BytesList( + value=[b'omar', b'stringer'])) + })) + features = parsing_ops.parse_example( + serialized=[data.SerializeToString()], + features=fc.make_parse_example_spec([a])) + self.assertIn('aaa', features) + with self.test_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [0, 1]], + values=np.array([b'omar', b'stringer'], dtype=np.object_), + dense_shape=[1, 2]), + features['aaa'].eval()) + + def test_strings_should_be_hashed(self): + hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10) + wire_tensor = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + outputs = _transform_features({'wire': wire_tensor}, [hashed_sparse], None) + output = outputs[hashed_sparse] + # Check exact hashed output. If hashing changes this test will break. + expected_values = [6, 4, 1] + with self.test_session(): + self.assertEqual(dtypes.int64, output.values.dtype) + self.assertAllEqual(expected_values, output.values.eval()) + self.assertAllEqual(wire_tensor.indices.eval(), output.indices.eval()) + self.assertAllEqual(wire_tensor.dense_shape.eval(), + output.dense_shape.eval()) + + def test_tensor_dtype_should_be_string_or_integer(self): + string_fc = fc.categorical_column_with_hash_bucket( + 'a_string', 10, dtype=dtypes.string) + int_fc = fc.categorical_column_with_hash_bucket( + 'a_int', 10, dtype=dtypes.int32) + float_fc = fc.categorical_column_with_hash_bucket( + 'a_float', 10, dtype=dtypes.string) + int_tensor = sparse_tensor.SparseTensor( + values=[101], + indices=[[0, 0]], + dense_shape=[1, 1]) + string_tensor = sparse_tensor.SparseTensor( + values=['101'], + indices=[[0, 0]], + dense_shape=[1, 1]) + float_tensor = sparse_tensor.SparseTensor( + values=[101.], + indices=[[0, 0]], + dense_shape=[1, 1]) + transformation_cache = FeatureTransformationCache({ + 'a_int': int_tensor, + 'a_string': string_tensor, + 'a_float': float_tensor + }) + transformation_cache.get(string_fc, None) + transformation_cache.get(int_fc, None) + with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'): + transformation_cache.get(float_fc, None) + + def test_dtype_should_match_with_tensor(self): + hashed_sparse = fc.categorical_column_with_hash_bucket( + 'wire', 10, dtype=dtypes.int64) + wire_tensor = sparse_tensor.SparseTensor( + values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) + transformation_cache = FeatureTransformationCache({'wire': wire_tensor}) + with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): + transformation_cache.get(hashed_sparse, None) + + def test_ints_should_be_hashed(self): + hashed_sparse = fc.categorical_column_with_hash_bucket( + 'wire', 10, dtype=dtypes.int64) + wire_tensor = sparse_tensor.SparseTensor( + values=[101, 201, 301], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + transformation_cache = FeatureTransformationCache({'wire': wire_tensor}) + output = transformation_cache.get(hashed_sparse, None) + # Check exact hashed output. If hashing changes this test will break. + expected_values = [3, 7, 5] + with self.test_session(): + self.assertAllEqual(expected_values, output.values.eval()) + + def test_int32_64_is_compatible(self): + hashed_sparse = fc.categorical_column_with_hash_bucket( + 'wire', 10, dtype=dtypes.int64) + wire_tensor = sparse_tensor.SparseTensor( + values=constant_op.constant([101, 201, 301], dtype=dtypes.int32), + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + transformation_cache = FeatureTransformationCache({'wire': wire_tensor}) + output = transformation_cache.get(hashed_sparse, None) + # Check exact hashed output. If hashing changes this test will break. + expected_values = [3, 7, 5] + with self.test_session(): + self.assertAllEqual(expected_values, output.values.eval()) + + def test_get_sparse_tensors(self): + hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10) + transformation_cache = FeatureTransformationCache({ + 'wire': + sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + }) + id_weight_pair = hashed_sparse.get_sparse_tensors(transformation_cache, + None) + self.assertIsNone(id_weight_pair.weight_tensor) + self.assertEqual( + transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor) + + def DISABLED_test_get_sparse_tensors_weight_collections(self): + column = fc.categorical_column_with_hash_bucket('aaa', 10) + inputs = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + column._get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), + weight_collections=('my_weights',)) + + self.assertItemsEqual( + [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) + self.assertItemsEqual([], ops.get_collection('my_weights')) + + def test_get_sparse_tensors_dense_input(self): + hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10) + transformation_cache = FeatureTransformationCache({ + 'wire': (('omar', ''), ('stringer', 'marlo')) + }) + id_weight_pair = hashed_sparse.get_sparse_tensors(transformation_cache, + None) + self.assertIsNone(id_weight_pair.weight_tensor) + self.assertEqual( + transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor) + + def test_linear_model(self): + wire_column = fc_old.categorical_column_with_hash_bucket('wire', 4) + self.assertEqual(4, wire_column._num_buckets) + with ops.Graph().as_default(): + predictions = fc.linear_model({ + wire_column.name: sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + }, (wire_column,)) + bias = get_linear_model_bias() + wire_var = get_linear_model_column_var(wire_column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval() + # 'marlo' -> 3: wire_var[3] = 4 + # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6 + self.assertAllClose(((4.,), (6.,)), predictions.eval()) + + def test_keras_linear_model(self): + wire_column = fc_old.categorical_column_with_hash_bucket('wire', 4) + self.assertEqual(4, wire_column._num_buckets) + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + wire_column.name: + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + }, (wire_column,)) + bias = get_linear_model_bias() + wire_var = get_linear_model_column_var(wire_column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval() + # 'marlo' -> 3: wire_var[3] = 4 + # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6 + self.assertAllClose(((4.,), (6.,)), predictions.eval()) + + +class CrossedColumnTest(test.TestCase): + + def test_keys_empty(self): + with self.assertRaisesRegexp( + ValueError, 'keys must be a list with length > 1'): + fc.crossed_column([], 10) + + def test_keys_length_one(self): + with self.assertRaisesRegexp( + ValueError, 'keys must be a list with length > 1'): + fc.crossed_column(['a'], 10) + + def test_key_type_unsupported(self): + with self.assertRaisesRegexp(ValueError, 'Unsupported key type'): + fc.crossed_column(['a', fc.numeric_column('c')], 10) + + with self.assertRaisesRegexp( + ValueError, 'categorical_column_with_hash_bucket is not supported'): + fc.crossed_column( + ['a', fc.categorical_column_with_hash_bucket('c', 10)], 10) + + def test_hash_bucket_size_negative(self): + with self.assertRaisesRegexp( + ValueError, 'hash_bucket_size must be > 1'): + fc.crossed_column(['a', 'c'], -1) + + def test_hash_bucket_size_zero(self): + with self.assertRaisesRegexp( + ValueError, 'hash_bucket_size must be > 1'): + fc.crossed_column(['a', 'c'], 0) + + def test_hash_bucket_size_none(self): + with self.assertRaisesRegexp( + ValueError, 'hash_bucket_size must be > 1'): + fc.crossed_column(['a', 'c'], None) + + def test_name(self): + a = fc.numeric_column('a', dtype=dtypes.int32) + b = fc.bucketized_column(a, boundaries=[0, 1]) + crossed1 = fc.crossed_column(['d1', 'd2'], 10) + + crossed2 = fc.crossed_column([b, 'c', crossed1], 10) + self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name) + + def test_name_ordered_alphabetically(self): + """Tests that the name does not depend on the order of given columns.""" + a = fc.numeric_column('a', dtype=dtypes.int32) + b = fc.bucketized_column(a, boundaries=[0, 1]) + crossed1 = fc.crossed_column(['d1', 'd2'], 10) + + crossed2 = fc.crossed_column([crossed1, 'c', b], 10) + self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name) + + def test_name_leaf_keys_ordered_alphabetically(self): + """Tests that the name does not depend on the order of given columns.""" + a = fc.numeric_column('a', dtype=dtypes.int32) + b = fc.bucketized_column(a, boundaries=[0, 1]) + crossed1 = fc.crossed_column(['d2', 'c'], 10) + + crossed2 = fc.crossed_column([crossed1, 'd1', b], 10) + self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name) + + def test_parse_spec(self): + a = fc.numeric_column('a', shape=[2], dtype=dtypes.int32) + b = fc.bucketized_column(a, boundaries=[0, 1]) + crossed = fc.crossed_column([b, 'c'], 10) + self.assertEqual({ + 'a': parsing_ops.FixedLenFeature((2,), dtype=dtypes.int32), + 'c': parsing_ops.VarLenFeature(dtypes.string), + }, crossed.parse_example_spec) + + def test_num_buckets(self): + a = fc.numeric_column('a', shape=[2], dtype=dtypes.int32) + b = fc.bucketized_column(a, boundaries=[0, 1]) + crossed = fc.crossed_column([b, 'c'], 15) + self.assertEqual(15, crossed.num_buckets) + + def test_deep_copy(self): + a = fc.numeric_column('a', dtype=dtypes.int32) + b = fc.bucketized_column(a, boundaries=[0, 1]) + crossed1 = fc.crossed_column(['d1', 'd2'], 10) + crossed2 = fc.crossed_column([b, 'c', crossed1], 15, hash_key=5) + crossed2_copy = copy.deepcopy(crossed2) + self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2_copy.name,) + self.assertEqual(15, crossed2_copy.hash_bucket_size) + self.assertEqual(5, crossed2_copy.hash_key) + + def test_parse_example(self): + price = fc.numeric_column('price', shape=[2]) + bucketized_price = fc.bucketized_column(price, boundaries=[0, 50]) + price_cross_wire = fc.crossed_column([bucketized_price, 'wire'], 10) + data = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'price': + feature_pb2.Feature(float_list=feature_pb2.FloatList( + value=[20., 110.])), + 'wire': + feature_pb2.Feature(bytes_list=feature_pb2.BytesList( + value=[b'omar', b'stringer'])), + })) + features = parsing_ops.parse_example( + serialized=[data.SerializeToString()], + features=fc.make_parse_example_spec([price_cross_wire])) + self.assertIn('price', features) + self.assertIn('wire', features) + with self.test_session(): + self.assertAllEqual([[20., 110.]], features['price'].eval()) + wire_sparse = features['wire'] + self.assertAllEqual([[0, 0], [0, 1]], wire_sparse.indices.eval()) + # Use byte constants to pass the open-source test. + self.assertAllEqual([b'omar', b'stringer'], wire_sparse.values.eval()) + self.assertAllEqual([1, 2], wire_sparse.dense_shape.eval()) + + def test_transform_feature(self): + price = fc.numeric_column('price', shape=[2]) + bucketized_price = fc.bucketized_column(price, boundaries=[0, 50]) + hash_bucket_size = 10 + price_cross_wire = fc.crossed_column( + [bucketized_price, 'wire'], hash_bucket_size) + features = { + 'price': constant_op.constant([[1., 2.], [5., 6.]]), + 'wire': sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]), + } + outputs = _transform_features(features, [price_cross_wire], None) + output = outputs[price_cross_wire] + with self.test_session() as sess: + output_val = sess.run(output) + self.assertAllEqual( + [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices) + for val in output_val.values: + self.assertIn(val, list(range(hash_bucket_size))) + self.assertAllEqual([2, 4], output_val.dense_shape) + + def test_get_sparse_tensors(self): + a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,)) + b = fc.bucketized_column(a, boundaries=(0, 1)) + crossed1 = fc.crossed_column(['d1', 'd2'], 10) + crossed2 = fc.crossed_column([b, 'c', crossed1], 15, hash_key=5) + with ops.Graph().as_default(): + transformation_cache = FeatureTransformationCache({ + 'a': + constant_op.constant(((-1., .5), (.5, 1.))), + 'c': + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=['cA', 'cB', 'cC'], + dense_shape=(2, 2)), + 'd1': + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=['d1A', 'd1B', 'd1C'], + dense_shape=(2, 2)), + 'd2': + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=['d2A', 'd2B', 'd2C'], + dense_shape=(2, 2)), + }) + id_weight_pair = crossed2.get_sparse_tensors(transformation_cache, None) + with _initialized_session(): + id_tensor_eval = id_weight_pair.id_tensor.eval() + self.assertAllEqual( + ((0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), + (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11), (1, 12), (1, 13), + (1, 14), (1, 15)), + id_tensor_eval.indices) + # Check exact hashed output. If hashing changes this test will break. + # All values are within [0, hash_bucket_size). + expected_values = ( + 6, 14, 0, 13, 8, 8, 10, 12, 2, 0, 1, 9, 8, 12, 2, 0, 10, 11) + self.assertAllEqual(expected_values, id_tensor_eval.values) + self.assertAllEqual((2, 16), id_tensor_eval.dense_shape) + + def test_get_sparse_tensors_simple(self): + """Same as test_get_sparse_tensors, but with simpler values.""" + a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,)) + b = fc.bucketized_column(a, boundaries=(0, 1)) + crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5) + with ops.Graph().as_default(): + transformation_cache = FeatureTransformationCache({ + 'a': + constant_op.constant(((-1., .5), (.5, 1.))), + 'c': + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=['cA', 'cB', 'cC'], + dense_shape=(2, 2)), + }) + id_weight_pair = crossed.get_sparse_tensors(transformation_cache, None) + with _initialized_session(): + id_tensor_eval = id_weight_pair.id_tensor.eval() + self.assertAllEqual( + ((0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (1, 3)), + id_tensor_eval.indices) + # Check exact hashed output. If hashing changes this test will break. + # All values are within [0, hash_bucket_size). + expected_values = (1, 0, 1, 3, 4, 2) + self.assertAllEqual(expected_values, id_tensor_eval.values) + self.assertAllEqual((2, 4), id_tensor_eval.dense_shape) + + def test_linear_model(self): + """Tests linear_model. + + Uses data from test_get_sparse_tesnsors_simple. + """ + a = fc_old.numeric_column('a', dtype=dtypes.int32, shape=(2,)) + b = fc_old.bucketized_column(a, boundaries=(0, 1)) + crossed = fc_old.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5) + with ops.Graph().as_default(): + predictions = fc.linear_model({ + 'a': constant_op.constant(((-1., .5), (.5, 1.))), + 'c': sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=['cA', 'cB', 'cC'], + dense_shape=(2, 2)), + }, (crossed,)) + bias = get_linear_model_bias() + crossed_var = get_linear_model_column_var(crossed) + with _initialized_session() as sess: + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose( + ((0.,), (0.,), (0.,), (0.,), (0.,)), crossed_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,)))) + # Expected ids after cross = (1, 0, 1, 3, 4, 2) + self.assertAllClose(((3.,), (14.,)), predictions.eval()) + sess.run(bias.assign((.1,))) + self.assertAllClose(((3.1,), (14.1,)), predictions.eval()) + + def test_linear_model_with_weights(self): + + class _TestColumnWithWeights(fc_old._CategoricalColumn): + """Produces sparse IDs and sparse weights.""" + + @property + def name(self): + return 'test_column' + + @property + def _parse_example_spec(self): + return { + self.name: parsing_ops.VarLenFeature(dtypes.int32), + '{}_weights'.format(self.name): parsing_ops.VarLenFeature( + dtypes.float32), + } + + @property + def _num_buckets(self): + return 5 + + def _transform_feature(self, inputs): + return (inputs.get(self.name), + inputs.get('{}_weights'.format(self.name))) + + def _get_sparse_tensors(self, inputs, weight_collections=None, + trainable=None): + """Populates both id_tensor and weight_tensor.""" + ids_and_weights = inputs.get(self) + return fc_old._CategoricalColumn.IdWeightPair( + id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1]) + + t = _TestColumnWithWeights() + crossed = fc_old.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5) + with ops.Graph().as_default(): + with self.assertRaisesRegexp( + ValueError, + 'crossed_column does not support weight_tensor.*{}'.format(t.name)): + fc.linear_model({ + t.name: sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=[0, 1, 2], + dense_shape=(2, 2)), + '{}_weights'.format(t.name): sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=[1., 10., 2.], + dense_shape=(2, 2)), + 'c': sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=['cA', 'cB', 'cC'], + dense_shape=(2, 2)), + }, (crossed,)) + + def test_keras_linear_model(self): + """Tests _LinearModel. + + Uses data from test_get_sparse_tesnsors_simple. + """ + a = fc_old.numeric_column('a', dtype=dtypes.int32, shape=(2,)) + b = fc_old.bucketized_column(a, boundaries=(0, 1)) + crossed = fc_old.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5) + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + 'a': + constant_op.constant(((-1., .5), (.5, 1.))), + 'c': + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=['cA', 'cB', 'cC'], + dense_shape=(2, 2)), + }, (crossed,)) + bias = get_linear_model_bias() + crossed_var = get_linear_model_column_var(crossed) + with _initialized_session() as sess: + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,), (0.,), (0.,)), + crossed_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,)))) + # Expected ids after cross = (1, 0, 1, 3, 4, 2) + self.assertAllClose(((3.,), (14.,)), predictions.eval()) + sess.run(bias.assign((.1,))) + self.assertAllClose(((3.1,), (14.1,)), predictions.eval()) + + def test_keras_linear_model_with_weights(self): + + class _TestColumnWithWeights(fc_old._CategoricalColumn): + """Produces sparse IDs and sparse weights.""" + + @property + def name(self): + return 'test_column' + + @property + def _parse_example_spec(self): + return { + self.name: + parsing_ops.VarLenFeature(dtypes.int32), + '{}_weights'.format(self.name): + parsing_ops.VarLenFeature(dtypes.float32), + } + + @property + def _num_buckets(self): + return 5 + + def _transform_feature(self, inputs): + return (inputs.get(self.name), + inputs.get('{}_weights'.format(self.name))) + + def _get_sparse_tensors(self, + inputs, + weight_collections=None, + trainable=None): + """Populates both id_tensor and weight_tensor.""" + ids_and_weights = inputs.get(self) + return fc_old._CategoricalColumn.IdWeightPair( + id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1]) + + t = _TestColumnWithWeights() + crossed = fc_old.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5) + with ops.Graph().as_default(): + with self.assertRaisesRegexp( + ValueError, + 'crossed_column does not support weight_tensor.*{}'.format(t.name)): + get_keras_linear_model_predictions({ + t.name: + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=[0, 1, 2], + dense_shape=(2, 2)), + '{}_weights'.format(t.name): + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=[1., 10., 2.], + dense_shape=(2, 2)), + 'c': + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=['cA', 'cB', 'cC'], + dense_shape=(2, 2)), + }, (crossed,)) + + +def get_linear_model_bias(name='linear_model'): + with variable_scope.variable_scope(name, reuse=True): + return variable_scope.get_variable('bias_weights') + + +def get_linear_model_column_var(column, name='linear_model'): + return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, + name + '/' + column.name)[0] + + +def get_keras_linear_model_predictions(features, + feature_columns, + units=1, + sparse_combiner='sum', + weight_collections=None, + trainable=True, + cols_to_vars=None): + keras_linear_model = _LinearModel( + feature_columns, + units, + sparse_combiner, + weight_collections, + trainable, + name='linear_model') + retval = keras_linear_model(features) # pylint: disable=not-callable + if cols_to_vars is not None: + cols_to_vars.update(keras_linear_model.cols_to_vars()) + return retval + + +class LinearModelTest(test.TestCase): + + def test_raises_if_empty_feature_columns(self): + with self.assertRaisesRegexp(ValueError, + 'feature_columns must not be empty'): + fc.linear_model(features={}, feature_columns=[]) + + def test_should_be_feature_column(self): + with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'): + fc.linear_model(features={'a': [[0]]}, feature_columns='NotSupported') + + def test_should_be_dense_or_categorical_column(self): + + class NotSupportedColumn(fc_old._FeatureColumn): + + @property + def name(self): + return 'NotSupportedColumn' + + def _transform_feature(self, cache): + pass + + @property + def _parse_example_spec(self): + pass + + with self.assertRaisesRegexp( + ValueError, 'must be either a _DenseColumn or _CategoricalColumn'): + fc.linear_model( + features={'a': [[0]]}, feature_columns=[NotSupportedColumn()]) + + def test_does_not_support_dict_columns(self): + with self.assertRaisesRegexp( + ValueError, 'Expected feature_columns to be iterable, found dict.'): + fc.linear_model( + features={'a': [[0]]}, + feature_columns={'a': fc_old.numeric_column('a')}) + + def test_raises_if_duplicate_name(self): + with self.assertRaisesRegexp( + ValueError, 'Duplicate feature column name found for columns'): + fc.linear_model( + features={'a': [[0]]}, + feature_columns=[ + fc_old.numeric_column('a'), + fc_old.numeric_column('a') + ]) + + def test_dense_bias(self): + price = fc_old.numeric_column('price') + with ops.Graph().as_default(): + features = {'price': [[1.], [5.]]} + predictions = fc.linear_model(features, [price]) + bias = get_linear_model_bias() + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + sess.run(price_var.assign([[10.]])) + sess.run(bias.assign([5.])) + self.assertAllClose([[15.], [55.]], predictions.eval()) + + def test_sparse_bias(self): + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default(): + wire_tensor = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + features = {'wire_cast': wire_tensor} + predictions = fc.linear_model(features, [wire_cast]) + bias = get_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval()) + sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(bias.assign([5.])) + self.assertAllClose([[1005.], [10015.]], predictions.eval()) + + def test_dense_and_sparse_bias(self): + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + price = fc_old.numeric_column('price') + with ops.Graph().as_default(): + wire_tensor = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]} + predictions = fc.linear_model(features, [wire_cast, price]) + bias = get_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(bias.assign([5.])) + sess.run(price_var.assign([[10.]])) + self.assertAllClose([[1015.], [10065.]], predictions.eval()) + + def test_dense_and_sparse_column(self): + """When the column is both dense and sparse, uses sparse tensors.""" + + class _DenseAndSparseColumn(fc_old._DenseColumn, fc_old._CategoricalColumn): + + @property + def name(self): + return 'dense_and_sparse_column' + + @property + def _parse_example_spec(self): + return {self.name: parsing_ops.VarLenFeature(self.dtype)} + + def _transform_feature(self, inputs): + return inputs.get(self.name) + + @property + def _variable_shape(self): + raise ValueError('Should not use this method.') + + def _get_dense_tensor(self, inputs, weight_collections=None, + trainable=None): + raise ValueError('Should not use this method.') + + @property + def _num_buckets(self): + return 4 + + def _get_sparse_tensors(self, inputs, weight_collections=None, + trainable=None): + sp_tensor = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 0], [1, 1]], + values=[2, 0, 3], + dense_shape=[2, 2]) + return fc_old._CategoricalColumn.IdWeightPair(sp_tensor, None) + + dense_and_sparse_column = _DenseAndSparseColumn() + with ops.Graph().as_default(): + sp_tensor = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + features = {dense_and_sparse_column.name: sp_tensor} + predictions = fc.linear_model(features, [dense_and_sparse_column]) + bias = get_linear_model_bias() + dense_and_sparse_column_var = get_linear_model_column_var( + dense_and_sparse_column) + with _initialized_session() as sess: + sess.run(dense_and_sparse_column_var.assign( + [[10.], [100.], [1000.], [10000.]])) + sess.run(bias.assign([5.])) + self.assertAllClose([[1005.], [10015.]], predictions.eval()) + + def test_dense_multi_output(self): + price = fc_old.numeric_column('price') + with ops.Graph().as_default(): + features = {'price': [[1.], [5.]]} + predictions = fc.linear_model(features, [price], units=3) + bias = get_linear_model_bias() + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + self.assertAllClose(np.zeros((3,)), bias.eval()) + self.assertAllClose(np.zeros((1, 3)), price_var.eval()) + sess.run(price_var.assign([[10., 100., 1000.]])) + sess.run(bias.assign([5., 6., 7.])) + self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]], + predictions.eval()) + + def test_sparse_multi_output(self): + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default(): + wire_tensor = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + features = {'wire_cast': wire_tensor} + predictions = fc.linear_model(features, [wire_cast], units=3) + bias = get_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + with _initialized_session() as sess: + self.assertAllClose(np.zeros((3,)), bias.eval()) + self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval()) + sess.run( + wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.], [ + 1000., 1100., 1200. + ], [10000., 11000., 12000.]])) + sess.run(bias.assign([5., 6., 7.])) + self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]], + predictions.eval()) + + def test_dense_multi_dimension(self): + price = fc_old.numeric_column('price', shape=2) + with ops.Graph().as_default(): + features = {'price': [[1., 2.], [5., 6.]]} + predictions = fc.linear_model(features, [price]) + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + self.assertAllClose([[0.], [0.]], price_var.eval()) + sess.run(price_var.assign([[10.], [100.]])) + self.assertAllClose([[210.], [650.]], predictions.eval()) + + def test_sparse_multi_rank(self): + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default(): + wire_tensor = array_ops.sparse_placeholder(dtypes.string) + wire_value = sparse_tensor.SparseTensorValue( + values=['omar', 'stringer', 'marlo', 'omar'], # hashed = [2, 0, 3, 2] + indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]], + dense_shape=[2, 2, 2]) + features = {'wire_cast': wire_tensor} + predictions = fc.linear_model(features, [wire_cast]) + wire_cast_var = get_linear_model_column_var(wire_cast) + with _initialized_session() as sess: + self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval()) + self.assertAllClose( + np.zeros((2, 1)), + predictions.eval(feed_dict={wire_tensor: wire_value})) + sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) + self.assertAllClose( + [[1010.], [11000.]], + predictions.eval(feed_dict={wire_tensor: wire_value})) + + def test_sparse_combiner(self): + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default(): + wire_tensor = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + features = {'wire_cast': wire_tensor} + predictions = fc.linear_model( + features, [wire_cast], sparse_combiner='mean') + bias = get_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + with _initialized_session() as sess: + sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(bias.assign([5.])) + self.assertAllClose([[1005.], [5010.]], predictions.eval()) + + def test_sparse_combiner_with_negative_weights(self): + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + wire_cast_weights = fc_old.weighted_categorical_column(wire_cast, 'weights') + + with ops.Graph().as_default(): + wire_tensor = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + features = { + 'wire_cast': wire_tensor, + 'weights': constant_op.constant([[1., 1., -1.0]]) + } + predictions = fc.linear_model( + features, [wire_cast_weights], sparse_combiner='sum') + bias = get_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + with _initialized_session() as sess: + sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(bias.assign([5.])) + self.assertAllClose([[1005.], [-9985.]], predictions.eval()) + + def test_dense_multi_dimension_multi_output(self): + price = fc_old.numeric_column('price', shape=2) + with ops.Graph().as_default(): + features = {'price': [[1., 2.], [5., 6.]]} + predictions = fc.linear_model(features, [price], units=3) + bias = get_linear_model_bias() + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + self.assertAllClose(np.zeros((3,)), bias.eval()) + self.assertAllClose(np.zeros((2, 3)), price_var.eval()) + sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]])) + sess.run(bias.assign([2., 3., 4.])) + self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]], + predictions.eval()) + + def test_raises_if_shape_mismatch(self): + price = fc_old.numeric_column('price', shape=2) + with ops.Graph().as_default(): + features = {'price': [[1.], [5.]]} + with self.assertRaisesRegexp( + Exception, + r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'): + fc.linear_model(features, [price]) + + def test_dense_reshaping(self): + price = fc_old.numeric_column('price', shape=[1, 2]) + with ops.Graph().as_default(): + features = {'price': [[[1., 2.]], [[5., 6.]]]} + predictions = fc.linear_model(features, [price]) + bias = get_linear_model_bias() + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + self.assertAllClose([[0.], [0.]], price_var.eval()) + self.assertAllClose([[0.], [0.]], predictions.eval()) + sess.run(price_var.assign([[10.], [100.]])) + self.assertAllClose([[210.], [650.]], predictions.eval()) + + def test_dense_multi_column(self): + price1 = fc_old.numeric_column('price1', shape=2) + price2 = fc_old.numeric_column('price2') + with ops.Graph().as_default(): + features = { + 'price1': [[1., 2.], [5., 6.]], + 'price2': [[3.], [4.]] + } + predictions = fc.linear_model(features, [price1, price2]) + bias = get_linear_model_bias() + price1_var = get_linear_model_column_var(price1) + price2_var = get_linear_model_column_var(price2) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + self.assertAllClose([[0.], [0.]], price1_var.eval()) + self.assertAllClose([[0.]], price2_var.eval()) + self.assertAllClose([[0.], [0.]], predictions.eval()) + sess.run(price1_var.assign([[10.], [100.]])) + sess.run(price2_var.assign([[1000.]])) + sess.run(bias.assign([7.])) + self.assertAllClose([[3217.], [4657.]], predictions.eval()) + + def test_fills_cols_to_vars(self): + price1 = fc_old.numeric_column('price1', shape=2) + price2 = fc_old.numeric_column('price2') + with ops.Graph().as_default(): + features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]} + cols_to_vars = {} + fc.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars) + bias = get_linear_model_bias() + price1_var = get_linear_model_column_var(price1) + price2_var = get_linear_model_column_var(price2) + self.assertAllEqual(cols_to_vars['bias'], [bias]) + self.assertAllEqual(cols_to_vars[price1], [price1_var]) + self.assertAllEqual(cols_to_vars[price2], [price2_var]) + + def test_fills_cols_to_vars_partitioned_variables(self): + price1 = fc_old.numeric_column('price1', shape=2) + price2 = fc_old.numeric_column('price2', shape=3) + with ops.Graph().as_default(): + features = { + 'price1': [[1., 2.], [6., 7.]], + 'price2': [[3., 4., 5.], [8., 9., 10.]] + } + cols_to_vars = {} + with variable_scope.variable_scope( + 'linear', + partitioner=partitioned_variables.fixed_size_partitioner(2, axis=0)): + fc.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars) + with _initialized_session(): + self.assertEqual([0.], cols_to_vars['bias'][0].eval()) + # Partitioning shards the [2, 1] price1 var into 2 [1, 1] Variables. + self.assertAllEqual([[0.]], cols_to_vars[price1][0].eval()) + self.assertAllEqual([[0.]], cols_to_vars[price1][1].eval()) + # Partitioning shards the [3, 1] price2 var into a [2, 1] Variable and + # a [1, 1] Variable. + self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval()) + self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval()) + + def test_dense_collection(self): + price = fc_old.numeric_column('price') + with ops.Graph().as_default() as g: + features = {'price': [[1.], [5.]]} + fc.linear_model(features, [price], weight_collections=['my-vars']) + my_vars = g.get_collection('my-vars') + bias = get_linear_model_bias() + price_var = get_linear_model_column_var(price) + self.assertIn(bias, my_vars) + self.assertIn(price_var, my_vars) + + def test_sparse_collection(self): + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default() as g: + wire_tensor = sparse_tensor.SparseTensor( + values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) + features = {'wire_cast': wire_tensor} + fc.linear_model( + features, [wire_cast], weight_collections=['my-vars']) + my_vars = g.get_collection('my-vars') + bias = get_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + self.assertIn(bias, my_vars) + self.assertIn(wire_cast_var, my_vars) + + def test_dense_trainable_default(self): + price = fc_old.numeric_column('price') + with ops.Graph().as_default() as g: + features = {'price': [[1.], [5.]]} + fc.linear_model(features, [price]) + bias = get_linear_model_bias() + price_var = get_linear_model_column_var(price) + trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertIn(bias, trainable_vars) + self.assertIn(price_var, trainable_vars) + + def test_sparse_trainable_default(self): + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default() as g: + wire_tensor = sparse_tensor.SparseTensor( + values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) + features = {'wire_cast': wire_tensor} + fc.linear_model(features, [wire_cast]) + trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + bias = get_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + self.assertIn(bias, trainable_vars) + self.assertIn(wire_cast_var, trainable_vars) + + def test_dense_trainable_false(self): + price = fc_old.numeric_column('price') + with ops.Graph().as_default() as g: + features = {'price': [[1.], [5.]]} + fc.linear_model(features, [price], trainable=False) + trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertEqual([], trainable_vars) + + def test_sparse_trainable_false(self): + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default() as g: + wire_tensor = sparse_tensor.SparseTensor( + values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) + features = {'wire_cast': wire_tensor} + fc.linear_model(features, [wire_cast], trainable=False) + trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertEqual([], trainable_vars) + + def test_column_order(self): + price_a = fc_old.numeric_column('price_a') + price_b = fc_old.numeric_column('price_b') + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default() as g: + features = { + 'price_a': [[1.]], + 'price_b': [[3.]], + 'wire_cast': + sparse_tensor.SparseTensor( + values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) + } + fc.linear_model( + features, [price_a, wire_cast, price_b], + weight_collections=['my-vars']) + my_vars = g.get_collection('my-vars') + self.assertIn('price_a', my_vars[0].name) + self.assertIn('price_b', my_vars[1].name) + self.assertIn('wire_cast', my_vars[2].name) + + with ops.Graph().as_default() as g: + features = { + 'price_a': [[1.]], + 'price_b': [[3.]], + 'wire_cast': + sparse_tensor.SparseTensor( + values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) + } + fc.linear_model( + features, [wire_cast, price_b, price_a], + weight_collections=['my-vars']) + my_vars = g.get_collection('my-vars') + self.assertIn('price_a', my_vars[0].name) + self.assertIn('price_b', my_vars[1].name) + self.assertIn('wire_cast', my_vars[2].name) + + def test_static_batch_size_mismatch(self): + price1 = fc_old.numeric_column('price1') + price2 = fc_old.numeric_column('price2') + with ops.Graph().as_default(): + features = { + 'price1': [[1.], [5.], [7.]], # batchsize = 3 + 'price2': [[3.], [4.]] # batchsize = 2 + } + with self.assertRaisesRegexp( + ValueError, + 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string + fc.linear_model(features, [price1, price2]) + + def test_subset_of_static_batch_size_mismatch(self): + price1 = fc_old.numeric_column('price1') + price2 = fc_old.numeric_column('price2') + price3 = fc_old.numeric_column('price3') + with ops.Graph().as_default(): + features = { + 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3 + 'price2': [[3.], [4.]], # batchsize = 2 + 'price3': [[3.], [4.], [5.]] # batchsize = 3 + } + with self.assertRaisesRegexp( + ValueError, + 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string + fc.linear_model(features, [price1, price2, price3]) + + def test_runtime_batch_size_mismatch(self): + price1 = fc_old.numeric_column('price1') + price2 = fc_old.numeric_column('price2') + with ops.Graph().as_default(): + features = { + 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3 + 'price2': [[3.], [4.]] # batchsize = 2 + } + predictions = fc.linear_model(features, [price1, price2]) + with _initialized_session() as sess: + with self.assertRaisesRegexp(errors.OpError, + 'must have the same size and shape'): + sess.run( + predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]}) + + def test_runtime_batch_size_matches(self): + price1 = fc_old.numeric_column('price1') + price2 = fc_old.numeric_column('price2') + with ops.Graph().as_default(): + features = { + 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2 + 'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2 + } + predictions = fc.linear_model(features, [price1, price2]) + with _initialized_session() as sess: + sess.run( + predictions, + feed_dict={ + features['price1']: [[1.], [5.]], + features['price2']: [[1.], [5.]], + }) + + def test_with_numpy_input_fn(self): + price = fc_old.numeric_column('price') + price_buckets = fc_old.bucketized_column( + price, boundaries=[ + 0., + 10., + 100., + ]) + body_style = fc_old.categorical_column_with_vocabulary_list( + 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) + + input_fn = numpy_io.numpy_input_fn( + x={ + 'price': np.array([-1., 2., 13., 104.]), + 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']), + }, + batch_size=2, + shuffle=False) + features = input_fn() + net = fc.linear_model(features, [price_buckets, body_style]) + # self.assertEqual(1 + 3 + 5, net.shape[1]) + with _initialized_session() as sess: + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(sess, coord=coord) + + bias = get_linear_model_bias() + price_buckets_var = get_linear_model_column_var(price_buckets) + body_style_var = get_linear_model_column_var(body_style) + + sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]])) + sess.run(bias.assign([5.])) + + self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net)) + + coord.request_stop() + coord.join(threads) + + def test_with_1d_sparse_tensor(self): + price = fc_old.numeric_column('price') + price_buckets = fc_old.bucketized_column( + price, boundaries=[ + 0., + 10., + 100., + ]) + body_style = fc_old.categorical_column_with_vocabulary_list( + 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) + + # Provides 1-dim tensor and dense tensor. + features = { + 'price': constant_op.constant([-1., 12.,]), + 'body-style': sparse_tensor.SparseTensor( + indices=((0,), (1,)), + values=('sedan', 'hardtop'), + dense_shape=(2,)), + } + self.assertEqual(1, features['price'].shape.ndims) + self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0]) + + net = fc.linear_model(features, [price_buckets, body_style]) + with _initialized_session() as sess: + bias = get_linear_model_bias() + price_buckets_var = get_linear_model_column_var(price_buckets) + body_style_var = get_linear_model_column_var(body_style) + + sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]])) + sess.run(bias.assign([5.])) + + self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net)) + + def test_with_1d_unknown_shape_sparse_tensor(self): + price = fc_old.numeric_column('price') + price_buckets = fc_old.bucketized_column( + price, boundaries=[ + 0., + 10., + 100., + ]) + body_style = fc_old.categorical_column_with_vocabulary_list( + 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) + country = fc_old.categorical_column_with_vocabulary_list( + 'country', vocabulary_list=['US', 'JP', 'CA']) + + # Provides 1-dim tensor and dense tensor. + features = { + 'price': array_ops.placeholder(dtypes.float32), + 'body-style': array_ops.sparse_placeholder(dtypes.string), + 'country': array_ops.placeholder(dtypes.string), + } + self.assertIsNone(features['price'].shape.ndims) + self.assertIsNone(features['body-style'].get_shape().ndims) + + price_data = np.array([-1., 12.]) + body_style_data = sparse_tensor.SparseTensorValue( + indices=((0,), (1,)), + values=('sedan', 'hardtop'), + dense_shape=(2,)) + country_data = np.array(['US', 'CA']) + + net = fc.linear_model(features, [price_buckets, body_style, country]) + bias = get_linear_model_bias() + price_buckets_var = get_linear_model_column_var(price_buckets) + body_style_var = get_linear_model_column_var(body_style) + with _initialized_session() as sess: + sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]])) + sess.run(bias.assign([5.])) + + self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], + sess.run( + net, + feed_dict={ + features['price']: price_data, + features['body-style']: body_style_data, + features['country']: country_data + })) + + def test_with_rank_0_feature(self): + price = fc_old.numeric_column('price') + features = { + 'price': constant_op.constant(0), + } + self.assertEqual(0, features['price'].shape.ndims) + + # Static rank 0 should fail + with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'): + fc.linear_model(features, [price]) + + # Dynamic rank 0 should fail + features = { + 'price': array_ops.placeholder(dtypes.float32), + } + net = fc.linear_model(features, [price]) + self.assertEqual(1, net.shape[1]) + with _initialized_session() as sess: + with self.assertRaisesOpError('Feature .* cannot have rank 0'): + sess.run(net, feed_dict={features['price']: np.array(1)}) + + def test_multiple_linear_models(self): + price = fc_old.numeric_column('price') + with ops.Graph().as_default(): + features1 = {'price': [[1.], [5.]]} + features2 = {'price': [[2.], [10.]]} + predictions1 = fc.linear_model(features1, [price]) + predictions2 = fc.linear_model(features2, [price]) + bias1 = get_linear_model_bias(name='linear_model') + bias2 = get_linear_model_bias(name='linear_model_1') + price_var1 = get_linear_model_column_var(price, name='linear_model') + price_var2 = get_linear_model_column_var(price, name='linear_model_1') + with _initialized_session() as sess: + self.assertAllClose([0.], bias1.eval()) + sess.run(price_var1.assign([[10.]])) + sess.run(bias1.assign([5.])) + self.assertAllClose([[15.], [55.]], predictions1.eval()) + self.assertAllClose([0.], bias2.eval()) + sess.run(price_var2.assign([[10.]])) + sess.run(bias2.assign([5.])) + self.assertAllClose([[25.], [105.]], predictions2.eval()) + + +class _LinearModelTest(test.TestCase): + + def test_raises_if_empty_feature_columns(self): + with self.assertRaisesRegexp(ValueError, + 'feature_columns must not be empty'): + get_keras_linear_model_predictions(features={}, feature_columns=[]) + + def test_should_be_feature_column(self): + with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'): + get_keras_linear_model_predictions( + features={'a': [[0]]}, feature_columns='NotSupported') + + def test_should_be_dense_or_categorical_column(self): + + class NotSupportedColumn(fc_old._FeatureColumn): + + @property + def name(self): + return 'NotSupportedColumn' + + def _transform_feature(self, cache): + pass + + @property + def _parse_example_spec(self): + pass + + with self.assertRaisesRegexp( + ValueError, 'must be either a _DenseColumn or _CategoricalColumn'): + get_keras_linear_model_predictions( + features={'a': [[0]]}, feature_columns=[NotSupportedColumn()]) + + def test_does_not_support_dict_columns(self): + with self.assertRaisesRegexp( + ValueError, 'Expected feature_columns to be iterable, found dict.'): + fc.linear_model( + features={'a': [[0]]}, + feature_columns={'a': fc_old.numeric_column('a')}) + + def test_raises_if_duplicate_name(self): + with self.assertRaisesRegexp( + ValueError, 'Duplicate feature column name found for columns'): + get_keras_linear_model_predictions( + features={'a': [[0]]}, + feature_columns=[ + fc_old.numeric_column('a'), + fc_old.numeric_column('a') + ]) + + def test_dense_bias(self): + price = fc_old.numeric_column('price') + with ops.Graph().as_default(): + features = {'price': [[1.], [5.]]} + predictions = get_keras_linear_model_predictions(features, [price]) + bias = get_linear_model_bias() + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + sess.run(price_var.assign([[10.]])) + sess.run(bias.assign([5.])) + self.assertAllClose([[15.], [55.]], predictions.eval()) + + def test_sparse_bias(self): + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default(): + wire_tensor = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + features = {'wire_cast': wire_tensor} + predictions = get_keras_linear_model_predictions(features, [wire_cast]) + bias = get_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval()) + sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(bias.assign([5.])) + self.assertAllClose([[1005.], [10015.]], predictions.eval()) + + def test_dense_and_sparse_bias(self): + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + price = fc_old.numeric_column('price') + with ops.Graph().as_default(): + wire_tensor = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]} + predictions = get_keras_linear_model_predictions(features, + [wire_cast, price]) + bias = get_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(bias.assign([5.])) + sess.run(price_var.assign([[10.]])) + self.assertAllClose([[1015.], [10065.]], predictions.eval()) + + def test_dense_and_sparse_column(self): + """When the column is both dense and sparse, uses sparse tensors.""" + + class _DenseAndSparseColumn(fc_old._DenseColumn, fc_old._CategoricalColumn): + + @property + def name(self): + return 'dense_and_sparse_column' + + @property + def _parse_example_spec(self): + return {self.name: parsing_ops.VarLenFeature(self.dtype)} + + def _transform_feature(self, inputs): + return inputs.get(self.name) + + @property + def _variable_shape(self): + raise ValueError('Should not use this method.') + + def _get_dense_tensor(self, + inputs, + weight_collections=None, + trainable=None): + raise ValueError('Should not use this method.') + + @property + def _num_buckets(self): + return 4 + + def _get_sparse_tensors(self, + inputs, + weight_collections=None, + trainable=None): + sp_tensor = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 0], [1, 1]], + values=[2, 0, 3], + dense_shape=[2, 2]) + return fc_old._CategoricalColumn.IdWeightPair(sp_tensor, None) + + dense_and_sparse_column = _DenseAndSparseColumn() + with ops.Graph().as_default(): + sp_tensor = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + features = {dense_and_sparse_column.name: sp_tensor} + predictions = get_keras_linear_model_predictions( + features, [dense_and_sparse_column]) + bias = get_linear_model_bias() + dense_and_sparse_column_var = get_linear_model_column_var( + dense_and_sparse_column) + with _initialized_session() as sess: + sess.run( + dense_and_sparse_column_var.assign([[10.], [100.], [1000.], + [10000.]])) + sess.run(bias.assign([5.])) + self.assertAllClose([[1005.], [10015.]], predictions.eval()) + + def test_dense_multi_output(self): + price = fc_old.numeric_column('price') + with ops.Graph().as_default(): + features = {'price': [[1.], [5.]]} + predictions = get_keras_linear_model_predictions( + features, [price], units=3) + bias = get_linear_model_bias() + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + self.assertAllClose(np.zeros((3,)), bias.eval()) + self.assertAllClose(np.zeros((1, 3)), price_var.eval()) + sess.run(price_var.assign([[10., 100., 1000.]])) + sess.run(bias.assign([5., 6., 7.])) + self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]], + predictions.eval()) + + def test_sparse_multi_output(self): + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default(): + wire_tensor = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + features = {'wire_cast': wire_tensor} + predictions = get_keras_linear_model_predictions( + features, [wire_cast], units=3) + bias = get_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + with _initialized_session() as sess: + self.assertAllClose(np.zeros((3,)), bias.eval()) + self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval()) + sess.run( + wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.], + [1000., 1100., + 1200.], [10000., 11000., 12000.]])) + sess.run(bias.assign([5., 6., 7.])) + self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]], + predictions.eval()) + + def test_dense_multi_dimension(self): + price = fc_old.numeric_column('price', shape=2) + with ops.Graph().as_default(): + features = {'price': [[1., 2.], [5., 6.]]} + predictions = get_keras_linear_model_predictions(features, [price]) + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + self.assertAllClose([[0.], [0.]], price_var.eval()) + sess.run(price_var.assign([[10.], [100.]])) + self.assertAllClose([[210.], [650.]], predictions.eval()) + + def test_sparse_multi_rank(self): + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default(): + wire_tensor = array_ops.sparse_placeholder(dtypes.string) + wire_value = sparse_tensor.SparseTensorValue( + values=['omar', 'stringer', 'marlo', 'omar'], # hashed = [2, 0, 3, 2] + indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]], + dense_shape=[2, 2, 2]) + features = {'wire_cast': wire_tensor} + predictions = get_keras_linear_model_predictions(features, [wire_cast]) + wire_cast_var = get_linear_model_column_var(wire_cast) + with _initialized_session() as sess: + self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval()) + self.assertAllClose( + np.zeros((2, 1)), + predictions.eval(feed_dict={wire_tensor: wire_value})) + sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) + self.assertAllClose( + [[1010.], [11000.]], + predictions.eval(feed_dict={wire_tensor: wire_value})) + + def test_sparse_combiner(self): + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default(): + wire_tensor = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + features = {'wire_cast': wire_tensor} + predictions = get_keras_linear_model_predictions( + features, [wire_cast], sparse_combiner='mean') + bias = get_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + with _initialized_session() as sess: + sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(bias.assign([5.])) + self.assertAllClose([[1005.], [5010.]], predictions.eval()) + + def test_dense_multi_dimension_multi_output(self): + price = fc_old.numeric_column('price', shape=2) + with ops.Graph().as_default(): + features = {'price': [[1., 2.], [5., 6.]]} + predictions = get_keras_linear_model_predictions( + features, [price], units=3) + bias = get_linear_model_bias() + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + self.assertAllClose(np.zeros((3,)), bias.eval()) + self.assertAllClose(np.zeros((2, 3)), price_var.eval()) + sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]])) + sess.run(bias.assign([2., 3., 4.])) + self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]], + predictions.eval()) + + def test_raises_if_shape_mismatch(self): + price = fc_old.numeric_column('price', shape=2) + with ops.Graph().as_default(): + features = {'price': [[1.], [5.]]} + with self.assertRaisesRegexp( + Exception, + r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'): + get_keras_linear_model_predictions(features, [price]) + + def test_dense_reshaping(self): + price = fc_old.numeric_column('price', shape=[1, 2]) + with ops.Graph().as_default(): + features = {'price': [[[1., 2.]], [[5., 6.]]]} + predictions = get_keras_linear_model_predictions(features, [price]) + bias = get_linear_model_bias() + price_var = get_linear_model_column_var(price) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + self.assertAllClose([[0.], [0.]], price_var.eval()) + self.assertAllClose([[0.], [0.]], predictions.eval()) + sess.run(price_var.assign([[10.], [100.]])) + self.assertAllClose([[210.], [650.]], predictions.eval()) + + def test_dense_multi_column(self): + price1 = fc_old.numeric_column('price1', shape=2) + price2 = fc_old.numeric_column('price2') + with ops.Graph().as_default(): + features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]} + predictions = get_keras_linear_model_predictions(features, + [price1, price2]) + bias = get_linear_model_bias() + price1_var = get_linear_model_column_var(price1) + price2_var = get_linear_model_column_var(price2) + with _initialized_session() as sess: + self.assertAllClose([0.], bias.eval()) + self.assertAllClose([[0.], [0.]], price1_var.eval()) + self.assertAllClose([[0.]], price2_var.eval()) + self.assertAllClose([[0.], [0.]], predictions.eval()) + sess.run(price1_var.assign([[10.], [100.]])) + sess.run(price2_var.assign([[1000.]])) + sess.run(bias.assign([7.])) + self.assertAllClose([[3217.], [4657.]], predictions.eval()) + + def test_fills_cols_to_vars(self): + price1 = fc_old.numeric_column('price1', shape=2) + price2 = fc_old.numeric_column('price2') + with ops.Graph().as_default(): + features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]} + cols_to_vars = {} + get_keras_linear_model_predictions( + features, [price1, price2], cols_to_vars=cols_to_vars) + bias = get_linear_model_bias() + price1_var = get_linear_model_column_var(price1) + price2_var = get_linear_model_column_var(price2) + self.assertAllEqual(cols_to_vars['bias'], [bias]) + self.assertAllEqual(cols_to_vars[price1], [price1_var]) + self.assertAllEqual(cols_to_vars[price2], [price2_var]) + + def test_fills_cols_to_vars_partitioned_variables(self): + price1 = fc_old.numeric_column('price1', shape=2) + price2 = fc_old.numeric_column('price2', shape=3) + with ops.Graph().as_default(): + features = { + 'price1': [[1., 2.], [6., 7.]], + 'price2': [[3., 4., 5.], [8., 9., 10.]] + } + cols_to_vars = {} + with variable_scope.variable_scope( + 'linear', + partitioner=partitioned_variables.fixed_size_partitioner(2, axis=0)): + get_keras_linear_model_predictions( + features, [price1, price2], cols_to_vars=cols_to_vars) + with _initialized_session(): + self.assertEqual([0.], cols_to_vars['bias'][0].eval()) + # Partitioning shards the [2, 1] price1 var into 2 [1, 1] Variables. + self.assertAllEqual([[0.]], cols_to_vars[price1][0].eval()) + self.assertAllEqual([[0.]], cols_to_vars[price1][1].eval()) + # Partitioning shards the [3, 1] price2 var into a [2, 1] Variable and + # a [1, 1] Variable. + self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval()) + self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval()) + + def test_dense_collection(self): + price = fc_old.numeric_column('price') + with ops.Graph().as_default() as g: + features = {'price': [[1.], [5.]]} + get_keras_linear_model_predictions( + features, [price], weight_collections=['my-vars']) + my_vars = g.get_collection('my-vars') + bias = get_linear_model_bias() + price_var = get_linear_model_column_var(price) + self.assertIn(bias, my_vars) + self.assertIn(price_var, my_vars) + + def test_sparse_collection(self): + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default() as g: + wire_tensor = sparse_tensor.SparseTensor( + values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) + features = {'wire_cast': wire_tensor} + get_keras_linear_model_predictions( + features, [wire_cast], weight_collections=['my-vars']) + my_vars = g.get_collection('my-vars') + bias = get_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + self.assertIn(bias, my_vars) + self.assertIn(wire_cast_var, my_vars) + + def test_dense_trainable_default(self): + price = fc_old.numeric_column('price') + with ops.Graph().as_default() as g: + features = {'price': [[1.], [5.]]} + get_keras_linear_model_predictions(features, [price]) + bias = get_linear_model_bias() + price_var = get_linear_model_column_var(price) + trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertIn(bias, trainable_vars) + self.assertIn(price_var, trainable_vars) + + def test_sparse_trainable_default(self): + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default() as g: + wire_tensor = sparse_tensor.SparseTensor( + values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) + features = {'wire_cast': wire_tensor} + get_keras_linear_model_predictions(features, [wire_cast]) + trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + bias = get_linear_model_bias() + wire_cast_var = get_linear_model_column_var(wire_cast) + self.assertIn(bias, trainable_vars) + self.assertIn(wire_cast_var, trainable_vars) + + def test_dense_trainable_false(self): + price = fc_old.numeric_column('price') + with ops.Graph().as_default() as g: + features = {'price': [[1.], [5.]]} + get_keras_linear_model_predictions(features, [price], trainable=False) + trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertEqual([], trainable_vars) + + def test_sparse_trainable_false(self): + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default() as g: + wire_tensor = sparse_tensor.SparseTensor( + values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) + features = {'wire_cast': wire_tensor} + get_keras_linear_model_predictions(features, [wire_cast], trainable=False) + trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertEqual([], trainable_vars) + + def test_column_order(self): + price_a = fc_old.numeric_column('price_a') + price_b = fc_old.numeric_column('price_b') + wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + with ops.Graph().as_default() as g: + features = { + 'price_a': [[1.]], + 'price_b': [[3.]], + 'wire_cast': + sparse_tensor.SparseTensor( + values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) + } + get_keras_linear_model_predictions( + features, [price_a, wire_cast, price_b], + weight_collections=['my-vars']) + my_vars = g.get_collection('my-vars') + self.assertIn('price_a', my_vars[0].name) + self.assertIn('price_b', my_vars[1].name) + self.assertIn('wire_cast', my_vars[2].name) + + with ops.Graph().as_default() as g: + features = { + 'price_a': [[1.]], + 'price_b': [[3.]], + 'wire_cast': + sparse_tensor.SparseTensor( + values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) + } + get_keras_linear_model_predictions( + features, [wire_cast, price_b, price_a], + weight_collections=['my-vars']) + my_vars = g.get_collection('my-vars') + self.assertIn('price_a', my_vars[0].name) + self.assertIn('price_b', my_vars[1].name) + self.assertIn('wire_cast', my_vars[2].name) + + def test_static_batch_size_mismatch(self): + price1 = fc_old.numeric_column('price1') + price2 = fc_old.numeric_column('price2') + with ops.Graph().as_default(): + features = { + 'price1': [[1.], [5.], [7.]], # batchsize = 3 + 'price2': [[3.], [4.]] # batchsize = 2 + } + with self.assertRaisesRegexp( + ValueError, + 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string + get_keras_linear_model_predictions(features, [price1, price2]) + + def test_subset_of_static_batch_size_mismatch(self): + price1 = fc_old.numeric_column('price1') + price2 = fc_old.numeric_column('price2') + price3 = fc_old.numeric_column('price3') + with ops.Graph().as_default(): + features = { + 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3 + 'price2': [[3.], [4.]], # batchsize = 2 + 'price3': [[3.], [4.], [5.]] # batchsize = 3 + } + with self.assertRaisesRegexp( + ValueError, + 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string + get_keras_linear_model_predictions(features, [price1, price2, price3]) + + def test_runtime_batch_size_mismatch(self): + price1 = fc_old.numeric_column('price1') + price2 = fc_old.numeric_column('price2') + with ops.Graph().as_default(): + features = { + 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3 + 'price2': [[3.], [4.]] # batchsize = 2 + } + predictions = get_keras_linear_model_predictions(features, + [price1, price2]) + with _initialized_session() as sess: + with self.assertRaisesRegexp(errors.OpError, + 'must have the same size and shape'): + sess.run( + predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]}) + + def test_runtime_batch_size_matches(self): + price1 = fc_old.numeric_column('price1') + price2 = fc_old.numeric_column('price2') + with ops.Graph().as_default(): + features = { + 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2 + 'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2 + } + predictions = get_keras_linear_model_predictions(features, + [price1, price2]) + with _initialized_session() as sess: + sess.run( + predictions, + feed_dict={ + features['price1']: [[1.], [5.]], + features['price2']: [[1.], [5.]], + }) + + def test_with_numpy_input_fn(self): + price = fc_old.numeric_column('price') + price_buckets = fc_old.bucketized_column( + price, boundaries=[ + 0., + 10., + 100., + ]) + body_style = fc_old.categorical_column_with_vocabulary_list( + 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) + + input_fn = numpy_io.numpy_input_fn( + x={ + 'price': np.array([-1., 2., 13., 104.]), + 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']), + }, + batch_size=2, + shuffle=False) + features = input_fn() + net = get_keras_linear_model_predictions(features, + [price_buckets, body_style]) + # self.assertEqual(1 + 3 + 5, net.shape[1]) + with _initialized_session() as sess: + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(sess, coord=coord) + + bias = get_linear_model_bias() + price_buckets_var = get_linear_model_column_var(price_buckets) + body_style_var = get_linear_model_column_var(body_style) + + sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]])) + sess.run(bias.assign([5.])) + + self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net)) + + coord.request_stop() + coord.join(threads) + + def test_with_1d_sparse_tensor(self): + price = fc_old.numeric_column('price') + price_buckets = fc_old.bucketized_column( + price, boundaries=[ + 0., + 10., + 100., + ]) + body_style = fc_old.categorical_column_with_vocabulary_list( + 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) + + # Provides 1-dim tensor and dense tensor. + features = { + 'price': + constant_op.constant([ + -1., + 12., + ]), + 'body-style': + sparse_tensor.SparseTensor( + indices=((0,), (1,)), + values=('sedan', 'hardtop'), + dense_shape=(2,)), + } + self.assertEqual(1, features['price'].shape.ndims) + self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0]) + + net = get_keras_linear_model_predictions(features, + [price_buckets, body_style]) + with _initialized_session() as sess: + bias = get_linear_model_bias() + price_buckets_var = get_linear_model_column_var(price_buckets) + body_style_var = get_linear_model_column_var(body_style) + + sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]])) + sess.run(bias.assign([5.])) + + self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net)) + + def test_with_1d_unknown_shape_sparse_tensor(self): + price = fc_old.numeric_column('price') + price_buckets = fc_old.bucketized_column( + price, boundaries=[ + 0., + 10., + 100., + ]) + body_style = fc_old.categorical_column_with_vocabulary_list( + 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) + country = fc_old.categorical_column_with_vocabulary_list( + 'country', vocabulary_list=['US', 'JP', 'CA']) + + # Provides 1-dim tensor and dense tensor. + features = { + 'price': array_ops.placeholder(dtypes.float32), + 'body-style': array_ops.sparse_placeholder(dtypes.string), + 'country': array_ops.placeholder(dtypes.string), + } + self.assertIsNone(features['price'].shape.ndims) + self.assertIsNone(features['body-style'].get_shape().ndims) + + price_data = np.array([-1., 12.]) + body_style_data = sparse_tensor.SparseTensorValue( + indices=((0,), (1,)), values=('sedan', 'hardtop'), dense_shape=(2,)) + country_data = np.array(['US', 'CA']) + + net = get_keras_linear_model_predictions( + features, [price_buckets, body_style, country]) + bias = get_linear_model_bias() + price_buckets_var = get_linear_model_column_var(price_buckets) + body_style_var = get_linear_model_column_var(body_style) + with _initialized_session() as sess: + sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]])) + sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]])) + sess.run(bias.assign([5.])) + + self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], + sess.run( + net, + feed_dict={ + features['price']: price_data, + features['body-style']: body_style_data, + features['country']: country_data + })) + + def test_with_rank_0_feature(self): + price = fc_old.numeric_column('price') + features = { + 'price': constant_op.constant(0), + } + self.assertEqual(0, features['price'].shape.ndims) + + # Static rank 0 should fail + with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'): + get_keras_linear_model_predictions(features, [price]) + + # Dynamic rank 0 should fail + features = { + 'price': array_ops.placeholder(dtypes.float32), + } + net = get_keras_linear_model_predictions(features, [price]) + self.assertEqual(1, net.shape[1]) + with _initialized_session() as sess: + with self.assertRaisesOpError('Feature .* cannot have rank 0'): + sess.run(net, feed_dict={features['price']: np.array(1)}) + + +class InputLayerTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_retrieving_input(self): + features = {'a': [0.]} + input_layer = InputLayer(fc_old.numeric_column('a')) + inputs = self.evaluate(input_layer(features)) + self.assertAllClose([[0.]], inputs) + + def test_reuses_variables(self): + with context.eager_mode(): + sparse_input = sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (2, 0)), + values=(0, 1, 2), + dense_shape=(3, 3)) + + # Create feature columns (categorical and embedding). + categorical_column = fc_old.categorical_column_with_identity( + key='a', num_buckets=3) + embedding_dimension = 2 + def _embedding_column_initializer(shape, dtype, partition_info): + del shape # unused + del dtype # unused + del partition_info # unused + embedding_values = ( + (1, 0), # id 0 + (0, 1), # id 1 + (1, 1)) # id 2 + return embedding_values + + embedding_column = fc_old.embedding_column( + categorical_column, + dimension=embedding_dimension, + initializer=_embedding_column_initializer) + + input_layer = InputLayer([embedding_column]) + features = {'a': sparse_input} + + inputs = input_layer(features) + variables = input_layer.variables + + # Sanity check: test that the inputs are correct. + self.assertAllEqual([[1, 0], [0, 1], [1, 1]], inputs) + + # Check that only one variable was created. + self.assertEqual(1, len(variables)) + + # Check that invoking input_layer on the same features does not create + # additional variables + _ = input_layer(features) + self.assertEqual(1, len(variables)) + self.assertEqual(variables[0], input_layer.variables[0]) + + def test_feature_column_input_layer_gradient(self): + with context.eager_mode(): + sparse_input = sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (2, 0)), + values=(0, 1, 2), + dense_shape=(3, 3)) + + # Create feature columns (categorical and embedding). + categorical_column = fc_old.categorical_column_with_identity( + key='a', num_buckets=3) + embedding_dimension = 2 + + def _embedding_column_initializer(shape, dtype, partition_info): + del shape # unused + del dtype # unused + del partition_info # unused + embedding_values = ( + (1, 0), # id 0 + (0, 1), # id 1 + (1, 1)) # id 2 + return embedding_values + + embedding_column = fc_old.embedding_column( + categorical_column, + dimension=embedding_dimension, + initializer=_embedding_column_initializer) + + input_layer = InputLayer([embedding_column]) + features = {'a': sparse_input} + + def scale_matrix(): + matrix = input_layer(features) + return 2 * matrix + + # Sanity check: Verify that scale_matrix returns the correct output. + self.assertAllEqual([[2, 0], [0, 2], [2, 2]], scale_matrix()) + + # Check that the returned gradient is correct. + grad_function = backprop.implicit_grad(scale_matrix) + grads_and_vars = grad_function() + indexed_slice = grads_and_vars[0][0] + gradient = grads_and_vars[0][0].values + + self.assertAllEqual([0, 1, 2], indexed_slice.indices) + self.assertAllEqual([[2, 2], [2, 2], [2, 2]], gradient) + + +class FunctionalInputLayerTest(test.TestCase): + + def test_raises_if_empty_feature_columns(self): + with self.assertRaisesRegexp(ValueError, + 'feature_columns must not be empty'): + fc.input_layer(features={}, feature_columns=[]) + + def test_should_be_dense_column(self): + with self.assertRaisesRegexp(ValueError, 'must be a _DenseColumn'): + fc.input_layer( + features={'a': [[0]]}, + feature_columns=[ + fc_old.categorical_column_with_hash_bucket('wire_cast', 4) + ]) + + def test_does_not_support_dict_columns(self): + with self.assertRaisesRegexp( + ValueError, 'Expected feature_columns to be iterable, found dict.'): + fc.input_layer( + features={'a': [[0]]}, + feature_columns={'a': fc_old.numeric_column('a')}) + + def test_bare_column(self): + with ops.Graph().as_default(): + features = features = {'a': [0.]} + net = fc.input_layer(features, fc_old.numeric_column('a')) + with _initialized_session(): + self.assertAllClose([[0.]], net.eval()) + + def test_column_generator(self): + with ops.Graph().as_default(): + features = features = {'a': [0.], 'b': [1.]} + columns = (fc_old.numeric_column(key) for key in features) + net = fc.input_layer(features, columns) + with _initialized_session(): + self.assertAllClose([[0., 1.]], net.eval()) + + def test_raises_if_duplicate_name(self): + with self.assertRaisesRegexp( + ValueError, 'Duplicate feature column name found for columns'): + fc.input_layer( + features={'a': [[0]]}, + feature_columns=[ + fc_old.numeric_column('a'), + fc_old.numeric_column('a') + ]) + + def test_one_column(self): + price = fc_old.numeric_column('price') + with ops.Graph().as_default(): + features = {'price': [[1.], [5.]]} + net = fc.input_layer(features, [price]) + with _initialized_session(): + self.assertAllClose([[1.], [5.]], net.eval()) + + def test_multi_dimension(self): + price = fc_old.numeric_column('price', shape=2) + with ops.Graph().as_default(): + features = {'price': [[1., 2.], [5., 6.]]} + net = fc.input_layer(features, [price]) + with _initialized_session(): + self.assertAllClose([[1., 2.], [5., 6.]], net.eval()) + + def test_raises_if_shape_mismatch(self): + price = fc_old.numeric_column('price', shape=2) + with ops.Graph().as_default(): + features = {'price': [[1.], [5.]]} + with self.assertRaisesRegexp( + Exception, + r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'): + fc.input_layer(features, [price]) + + def test_reshaping(self): + price = fc_old.numeric_column('price', shape=[1, 2]) + with ops.Graph().as_default(): + features = {'price': [[[1., 2.]], [[5., 6.]]]} + net = fc.input_layer(features, [price]) + with _initialized_session(): + self.assertAllClose([[1., 2.], [5., 6.]], net.eval()) + + def test_multi_column(self): + price1 = fc_old.numeric_column('price1', shape=2) + price2 = fc_old.numeric_column('price2') + with ops.Graph().as_default(): + features = { + 'price1': [[1., 2.], [5., 6.]], + 'price2': [[3.], [4.]] + } + net = fc.input_layer(features, [price1, price2]) + with _initialized_session(): + self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], net.eval()) + + def test_fills_cols_to_vars(self): + # Provide three _DenseColumn's to input_layer: a _NumericColumn, a + # _BucketizedColumn, and an _EmbeddingColumn. Only the _EmbeddingColumn + # creates a Variable. + price1 = fc_old.numeric_column('price1') + dense_feature = fc_old.numeric_column('dense_feature') + dense_feature_bucketized = fc_old.bucketized_column( + dense_feature, boundaries=[0.]) + some_sparse_column = fc_old.categorical_column_with_hash_bucket( + 'sparse_feature', hash_bucket_size=5) + some_embedding_column = fc_old.embedding_column( + some_sparse_column, dimension=10) + with ops.Graph().as_default(): + features = { + 'price1': [[3.], [4.]], + 'dense_feature': [[-1.], [4.]], + 'sparse_feature': [['a'], ['x']], + } + cols_to_vars = {} + all_cols = [price1, dense_feature_bucketized, some_embedding_column] + fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars) + self.assertItemsEqual(list(cols_to_vars.keys()), all_cols) + self.assertEqual(0, len(cols_to_vars[price1])) + self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized])) + self.assertEqual(1, len(cols_to_vars[some_embedding_column])) + self.assertIsInstance(cols_to_vars[some_embedding_column][0], + variables_lib.Variable) + self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10]) + + def test_fills_cols_to_vars_partitioned_variables(self): + price1 = fc_old.numeric_column('price1') + dense_feature = fc_old.numeric_column('dense_feature') + dense_feature_bucketized = fc_old.bucketized_column( + dense_feature, boundaries=[0.]) + some_sparse_column = fc_old.categorical_column_with_hash_bucket( + 'sparse_feature', hash_bucket_size=5) + some_embedding_column = fc_old.embedding_column( + some_sparse_column, dimension=10) + with ops.Graph().as_default(): + features = { + 'price1': [[3.], [4.]], + 'dense_feature': [[-1.], [4.]], + 'sparse_feature': [['a'], ['x']], + } + cols_to_vars = {} + all_cols = [price1, dense_feature_bucketized, some_embedding_column] + with variable_scope.variable_scope( + 'input_from_feature_columns', + partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0)): + fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars) + self.assertItemsEqual(list(cols_to_vars.keys()), all_cols) + self.assertEqual(0, len(cols_to_vars[price1])) + self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized])) + self.assertEqual(3, len(cols_to_vars[some_embedding_column])) + self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [2, 10]) + self.assertAllEqual(cols_to_vars[some_embedding_column][1].shape, [2, 10]) + self.assertAllEqual(cols_to_vars[some_embedding_column][2].shape, [1, 10]) + + def test_column_order(self): + price_a = fc_old.numeric_column('price_a') + price_b = fc_old.numeric_column('price_b') + with ops.Graph().as_default(): + features = { + 'price_a': [[1.]], + 'price_b': [[3.]], + } + net1 = fc.input_layer(features, [price_a, price_b]) + net2 = fc.input_layer(features, [price_b, price_a]) + with _initialized_session(): + self.assertAllClose([[1., 3.]], net1.eval()) + self.assertAllClose([[1., 3.]], net2.eval()) + + def test_fails_for_categorical_column(self): + animal = fc_old.categorical_column_with_identity('animal', num_buckets=4) + with ops.Graph().as_default(): + features = { + 'animal': + sparse_tensor.SparseTensor( + indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2]) + } + with self.assertRaisesRegexp(Exception, 'must be a _DenseColumn'): + fc.input_layer(features, [animal]) + + def test_static_batch_size_mismatch(self): + price1 = fc_old.numeric_column('price1') + price2 = fc_old.numeric_column('price2') + with ops.Graph().as_default(): + features = { + 'price1': [[1.], [5.], [7.]], # batchsize = 3 + 'price2': [[3.], [4.]] # batchsize = 2 + } + with self.assertRaisesRegexp( + ValueError, + 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string + fc.input_layer(features, [price1, price2]) + + def test_subset_of_static_batch_size_mismatch(self): + price1 = fc_old.numeric_column('price1') + price2 = fc_old.numeric_column('price2') + price3 = fc_old.numeric_column('price3') + with ops.Graph().as_default(): + features = { + 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3 + 'price2': [[3.], [4.]], # batchsize = 2 + 'price3': [[3.], [4.], [5.]] # batchsize = 3 + } + with self.assertRaisesRegexp( + ValueError, + 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string + fc.input_layer(features, [price1, price2, price3]) + + def test_runtime_batch_size_mismatch(self): + price1 = fc_old.numeric_column('price1') + price2 = fc_old.numeric_column('price2') + with ops.Graph().as_default(): + features = { + 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3 + 'price2': [[3.], [4.]] # batchsize = 2 + } + net = fc.input_layer(features, [price1, price2]) + with _initialized_session() as sess: + with self.assertRaisesRegexp(errors.OpError, + 'Dimensions of inputs should match'): + sess.run(net, feed_dict={features['price1']: [[1.], [5.], [7.]]}) + + def test_runtime_batch_size_matches(self): + price1 = fc_old.numeric_column('price1') + price2 = fc_old.numeric_column('price2') + with ops.Graph().as_default(): + features = { + 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2 + 'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2 + } + net = fc.input_layer(features, [price1, price2]) + with _initialized_session() as sess: + sess.run( + net, + feed_dict={ + features['price1']: [[1.], [5.]], + features['price2']: [[1.], [5.]], + }) + + def test_multiple_layers_with_same_embedding_column(self): + some_sparse_column = fc_old.categorical_column_with_hash_bucket( + 'sparse_feature', hash_bucket_size=5) + some_embedding_column = fc_old.embedding_column( + some_sparse_column, dimension=10) + + with ops.Graph().as_default(): + features = { + 'sparse_feature': [['a'], ['x']], + } + all_cols = [some_embedding_column] + fc.input_layer(features, all_cols) + fc.input_layer(features, all_cols) + # Make sure that 2 variables get created in this case. + self.assertEqual(2, len( + ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))) + expected_var_names = [ + 'input_layer/sparse_feature_embedding/embedding_weights:0', + 'input_layer_1/sparse_feature_embedding/embedding_weights:0' + ] + self.assertItemsEqual( + expected_var_names, + [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) + + def test_multiple_layers_with_same_shared_embedding_column(self): + categorical_column_a = fc_old.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc_old.categorical_column_with_identity( + key='bbb', num_buckets=3) + embedding_dimension = 2 + embedding_column_b, embedding_column_a = fc_old.shared_embedding_columns( + [categorical_column_b, categorical_column_a], + dimension=embedding_dimension) + + with ops.Graph().as_default(): + features = { + 'aaa': + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 1, 0), + dense_shape=(2, 2)), + 'bbb': + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 1), + dense_shape=(2, 2)), + } + all_cols = [embedding_column_a, embedding_column_b] + fc.input_layer(features, all_cols) + fc.input_layer(features, all_cols) + # Make sure that only 1 variable gets created in this case. + self.assertEqual(1, len( + ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))) + self.assertItemsEqual( + ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'], + [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) + + def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(self): + categorical_column_a = fc_old.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc_old.categorical_column_with_identity( + key='bbb', num_buckets=3) + embedding_dimension = 2 + embedding_column_b, embedding_column_a = fc_old.shared_embedding_columns( + [categorical_column_b, categorical_column_a], + dimension=embedding_dimension) + all_cols = [embedding_column_a, embedding_column_b] + + with ops.Graph().as_default(): + features = { + 'aaa': + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 1, 0), + dense_shape=(2, 2)), + 'bbb': + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 1), + dense_shape=(2, 2)), + } + fc.input_layer(features, all_cols) + # Make sure that only 1 variable gets created in this case. + self.assertEqual(1, len( + ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))) + + with ops.Graph().as_default(): + features1 = { + 'aaa': + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 1, 0), + dense_shape=(2, 2)), + 'bbb': + sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 1), + dense_shape=(2, 2)), + } + + fc.input_layer(features1, all_cols) + # Make sure that only 1 variable gets created in this case. + self.assertEqual(1, len( + ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))) + self.assertItemsEqual( + ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'], + [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) + + def test_with_numpy_input_fn(self): + embedding_values = ( + (1., 2., 3., 4., 5.), # id 0 + (6., 7., 8., 9., 10.), # id 1 + (11., 12., 13., 14., 15.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + del shape, dtype, partition_info + return embedding_values + + # price has 1 dimension in input_layer + price = fc_old.numeric_column('price') + body_style = fc_old.categorical_column_with_vocabulary_list( + 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) + # one_hot_body_style has 3 dims in input_layer. + one_hot_body_style = fc_old.indicator_column(body_style) + # embedded_body_style has 5 dims in input_layer. + embedded_body_style = fc_old.embedding_column( + body_style, dimension=5, initializer=_initializer) + + input_fn = numpy_io.numpy_input_fn( + x={ + 'price': np.array([11., 12., 13., 14.]), + 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']), + }, + batch_size=2, + shuffle=False) + features = input_fn() + net = fc.input_layer(features, + [price, one_hot_body_style, embedded_body_style]) + self.assertEqual(1 + 3 + 5, net.shape[1]) + with _initialized_session() as sess: + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(sess, coord=coord) + + # Each row is formed by concatenating `embedded_body_style`, + # `one_hot_body_style`, and `price` in order. + self.assertAllEqual( + [[11., 12., 13., 14., 15., 0., 0., 1., 11.], + [1., 2., 3., 4., 5., 1., 0., 0., 12]], + sess.run(net)) + + coord.request_stop() + coord.join(threads) + + def test_with_1d_sparse_tensor(self): + embedding_values = ( + (1., 2., 3., 4., 5.), # id 0 + (6., 7., 8., 9., 10.), # id 1 + (11., 12., 13., 14., 15.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + del shape, dtype, partition_info + return embedding_values + + # price has 1 dimension in input_layer + price = fc_old.numeric_column('price') + + # one_hot_body_style has 3 dims in input_layer. + body_style = fc_old.categorical_column_with_vocabulary_list( + 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) + one_hot_body_style = fc_old.indicator_column(body_style) + + # embedded_body_style has 5 dims in input_layer. + country = fc_old.categorical_column_with_vocabulary_list( + 'country', vocabulary_list=['US', 'JP', 'CA']) + embedded_country = fc_old.embedding_column( + country, dimension=5, initializer=_initializer) + + # Provides 1-dim tensor and dense tensor. + features = { + 'price': constant_op.constant([11., 12.,]), + 'body-style': sparse_tensor.SparseTensor( + indices=((0,), (1,)), + values=('sedan', 'hardtop'), + dense_shape=(2,)), + # This is dense tensor for the categorical_column. + 'country': constant_op.constant(['CA', 'US']), + } + self.assertEqual(1, features['price'].shape.ndims) + self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0]) + self.assertEqual(1, features['country'].shape.ndims) + + net = fc.input_layer(features, + [price, one_hot_body_style, embedded_country]) + self.assertEqual(1 + 3 + 5, net.shape[1]) + with _initialized_session() as sess: + + # Each row is formed by concatenating `embedded_body_style`, + # `one_hot_body_style`, and `price` in order. + self.assertAllEqual( + [[0., 0., 1., 11., 12., 13., 14., 15., 11.], + [1., 0., 0., 1., 2., 3., 4., 5., 12.]], + sess.run(net)) + + def test_with_1d_unknown_shape_sparse_tensor(self): + embedding_values = ( + (1., 2.), # id 0 + (6., 7.), # id 1 + (11., 12.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + del shape, dtype, partition_info + return embedding_values + + # price has 1 dimension in input_layer + price = fc_old.numeric_column('price') + + # one_hot_body_style has 3 dims in input_layer. + body_style = fc_old.categorical_column_with_vocabulary_list( + 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) + one_hot_body_style = fc_old.indicator_column(body_style) + + # embedded_body_style has 5 dims in input_layer. + country = fc_old.categorical_column_with_vocabulary_list( + 'country', vocabulary_list=['US', 'JP', 'CA']) + embedded_country = fc_old.embedding_column( + country, dimension=2, initializer=_initializer) + + # Provides 1-dim tensor and dense tensor. + features = { + 'price': array_ops.placeholder(dtypes.float32), + 'body-style': array_ops.sparse_placeholder(dtypes.string), + # This is dense tensor for the categorical_column. + 'country': array_ops.placeholder(dtypes.string), + } + self.assertIsNone(features['price'].shape.ndims) + self.assertIsNone(features['body-style'].get_shape().ndims) + self.assertIsNone(features['country'].shape.ndims) + + price_data = np.array([11., 12.]) + body_style_data = sparse_tensor.SparseTensorValue( + indices=((0,), (1,)), + values=('sedan', 'hardtop'), + dense_shape=(2,)) + country_data = np.array([['US'], ['CA']]) + + net = fc.input_layer(features, + [price, one_hot_body_style, embedded_country]) + self.assertEqual(1 + 3 + 2, net.shape[1]) + with _initialized_session() as sess: + + # Each row is formed by concatenating `embedded_body_style`, + # `one_hot_body_style`, and `price` in order. + self.assertAllEqual( + [[0., 0., 1., 1., 2., 11.], [1., 0., 0., 11., 12., 12.]], + sess.run( + net, + feed_dict={ + features['price']: price_data, + features['body-style']: body_style_data, + features['country']: country_data + })) + + def test_with_rank_0_feature(self): + # price has 1 dimension in input_layer + price = fc_old.numeric_column('price') + features = { + 'price': constant_op.constant(0), + } + self.assertEqual(0, features['price'].shape.ndims) + + # Static rank 0 should fail + with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'): + fc.input_layer(features, [price]) + + # Dynamic rank 0 should fail + features = { + 'price': array_ops.placeholder(dtypes.float32), + } + net = fc.input_layer(features, [price]) + self.assertEqual(1, net.shape[1]) + with _initialized_session() as sess: + with self.assertRaisesOpError('Feature .* cannot have rank 0'): + sess.run(net, feed_dict={features['price']: np.array(1)}) + + +class MakeParseExampleSpecTest(test.TestCase): + + class _TestFeatureColumn(FeatureColumn, + collections.namedtuple('_TestFeatureColumn', + ('parse_spec'))): + + @property + def name(self): + return "_TestFeatureColumn" + + def transform_feature(self, transformation_cache, state_manager): + pass + + @property + def parse_example_spec(self): + return self.parse_spec + + def test_no_feature_columns(self): + actual = fc.make_parse_example_spec([]) + self.assertDictEqual({}, actual) + + def test_invalid_type(self): + key1 = 'key1' + parse_spec1 = parsing_ops.FixedLenFeature( + shape=(2,), dtype=dtypes.float32, default_value=0.) + with self.assertRaisesRegexp( + ValueError, + 'All feature_columns must be FeatureColumn instances.*invalid_column'): + fc.make_parse_example_spec( + (self._TestFeatureColumn({key1: parse_spec1}), 'invalid_column')) + + def test_one_feature_column(self): + key1 = 'key1' + parse_spec1 = parsing_ops.FixedLenFeature( + shape=(2,), dtype=dtypes.float32, default_value=0.) + actual = fc.make_parse_example_spec( + (self._TestFeatureColumn({key1: parse_spec1}),)) + self.assertDictEqual({key1: parse_spec1}, actual) + + def test_two_feature_columns(self): + key1 = 'key1' + parse_spec1 = parsing_ops.FixedLenFeature( + shape=(2,), dtype=dtypes.float32, default_value=0.) + key2 = 'key2' + parse_spec2 = parsing_ops.VarLenFeature(dtype=dtypes.string) + actual = fc.make_parse_example_spec( + (self._TestFeatureColumn({key1: parse_spec1}), + self._TestFeatureColumn({key2: parse_spec2}))) + self.assertDictEqual({key1: parse_spec1, key2: parse_spec2}, actual) + + def test_equal_keys_different_parse_spec(self): + key1 = 'key1' + parse_spec1 = parsing_ops.FixedLenFeature( + shape=(2,), dtype=dtypes.float32, default_value=0.) + parse_spec2 = parsing_ops.VarLenFeature(dtype=dtypes.string) + with self.assertRaisesRegexp( + ValueError, + 'feature_columns contain different parse_spec for key key1'): + fc.make_parse_example_spec( + (self._TestFeatureColumn({key1: parse_spec1}), + self._TestFeatureColumn({key1: parse_spec2}))) + + def test_equal_keys_equal_parse_spec(self): + key1 = 'key1' + parse_spec1 = parsing_ops.FixedLenFeature( + shape=(2,), dtype=dtypes.float32, default_value=0.) + actual = fc.make_parse_example_spec( + (self._TestFeatureColumn({key1: parse_spec1}), + self._TestFeatureColumn({key1: parse_spec1}))) + self.assertDictEqual({key1: parse_spec1}, actual) + + def test_multiple_features_dict(self): + """parse_spc for one column is a dict with length > 1.""" + key1 = 'key1' + parse_spec1 = parsing_ops.FixedLenFeature( + shape=(2,), dtype=dtypes.float32, default_value=0.) + key2 = 'key2' + parse_spec2 = parsing_ops.VarLenFeature(dtype=dtypes.string) + key3 = 'key3' + parse_spec3 = parsing_ops.VarLenFeature(dtype=dtypes.int32) + actual = fc.make_parse_example_spec( + (self._TestFeatureColumn({key1: parse_spec1}), + self._TestFeatureColumn({key2: parse_spec2, key3: parse_spec3}))) + self.assertDictEqual( + {key1: parse_spec1, key2: parse_spec2, key3: parse_spec3}, actual) + + +def _assert_sparse_tensor_value(test_case, expected, actual): + test_case.assertEqual(np.int64, np.array(actual.indices).dtype) + test_case.assertAllEqual(expected.indices, actual.indices) + + test_case.assertEqual( + np.array(expected.values).dtype, np.array(actual.values).dtype) + test_case.assertAllEqual(expected.values, actual.values) + + test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype) + test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) + + +class VocabularyFileCategoricalColumnTest(test.TestCase): + + def setUp(self): + super(VocabularyFileCategoricalColumnTest, self).setUp() + + # Contains ints, Golden State Warriors jersey numbers: 30, 35, 11, 23, 22 + self._warriors_vocabulary_file_name = test.test_src_dir_path( + 'python/feature_column/testdata/warriors_vocabulary.txt') + self._warriors_vocabulary_size = 5 + + # Contains strings, character names from 'The Wire': omar, stringer, marlo + self._wire_vocabulary_file_name = test.test_src_dir_path( + 'python/feature_column/testdata/wire_vocabulary.txt') + self._wire_vocabulary_size = 3 + + def test_defaults(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='path_to_file', vocabulary_size=3) + self.assertEqual('aaa', column.name) + self.assertEqual('aaa', column.key) + self.assertEqual(3, column.num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.string) + }, column.parse_example_spec) + + def test_key_should_be_string(self): + with self.assertRaisesRegexp(ValueError, 'key must be a string.'): + fc.categorical_column_with_vocabulary_file( + key=('aaa',), vocabulary_file='path_to_file', vocabulary_size=3) + + def test_all_constructor_args(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='path_to_file', vocabulary_size=3, + num_oov_buckets=4, dtype=dtypes.int32) + self.assertEqual(7, column.num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int32) + }, column.parse_example_spec) + + def test_deep_copy(self): + original = fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='path_to_file', vocabulary_size=3, + num_oov_buckets=4, dtype=dtypes.int32) + for column in (original, copy.deepcopy(original)): + self.assertEqual('aaa', column.name) + self.assertEqual(7, column.num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int32) + }, column.parse_example_spec) + + def test_vocabulary_file_none(self): + with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file=None, vocabulary_size=3) + + def test_vocabulary_file_empty_string(self): + with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='', vocabulary_size=3) + + def test_invalid_vocabulary_file(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='file_does_not_exist', vocabulary_size=10) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + column.get_sparse_tensors(FeatureTransformationCache({'aaa': inputs}), None) + with self.assertRaisesRegexp(errors.OpError, 'file_does_not_exist'): + with self.test_session(): + lookup_ops.tables_initializer().run() + + def test_invalid_vocabulary_size(self): + with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=-1) + with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=0) + + def test_too_large_vocabulary_size(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size + 1) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + column.get_sparse_tensors(FeatureTransformationCache({'aaa': inputs}), None) + with self.assertRaisesRegexp(errors.OpError, 'Invalid vocab_size'): + with self.test_session(): + lookup_ops.tables_initializer().run() + + def test_invalid_num_oov_buckets(self): + with self.assertRaisesRegexp(ValueError, 'Invalid num_oov_buckets'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='path', vocabulary_size=3, + num_oov_buckets=-1) + + def test_invalid_dtype(self): + with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'): + fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='path', vocabulary_size=3, + dtype=dtypes.float64) + + def test_invalid_buckets_and_default_value(self): + with self.assertRaisesRegexp( + ValueError, 'both num_oov_buckets and default_value'): + fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size, + num_oov_buckets=100, + default_value=2) + + def test_invalid_input_dtype_int32(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size, + dtype=dtypes.string) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(12, 24, 36), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): + column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + + def test_invalid_input_dtype_string(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._warriors_vocabulary_file_name, + vocabulary_size=self._warriors_vocabulary_size, + dtype=dtypes.int32) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('omar', 'stringer', 'marlo'), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): + column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + + def test_parse_example(self): + a = fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file='path_to_file', vocabulary_size=3) + data = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'aaa': + feature_pb2.Feature(bytes_list=feature_pb2.BytesList( + value=[b'omar', b'stringer'])) + })) + features = parsing_ops.parse_example( + serialized=[data.SerializeToString()], + features=fc.make_parse_example_spec([a])) + self.assertIn('aaa', features) + with self.test_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [0, 1]], + values=np.array([b'omar', b'stringer'], dtype=np.object_), + dense_shape=[1, 2]), + features['aaa'].eval()) + + def test_get_sparse_tensors(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_none_vocabulary_size(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file=self._wire_vocabulary_file_name) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value(self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array( + (2, -1, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_transform_feature(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + id_tensor = _transform_features({'aaa': inputs}, [column], None)[column] + with _initialized_session(): + _assert_sparse_tensor_value(self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array( + (2, -1, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_tensor.eval()) + + def DISABLED_test_get_sparse_tensors_weight_collections(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size) + inputs = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), + weight_collections=('my_weights',)) + + self.assertItemsEqual( + [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) + self.assertItemsEqual([], ops.get_collection('my_weights')) + + def test_get_sparse_tensors_dense_input(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': (('marlo', ''), ('skywalker', 'omar')) + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=(2, 2)), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_default_value_in_vocabulary(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size, + default_value=2) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, 2, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_with_oov_buckets(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size, + num_oov_buckets=100) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (1, 2)), + values=('marlo', 'skywalker', 'omar', 'heisenberg'), + dense_shape=(2, 3)) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, 33, 0, 62), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_small_vocabulary_size(self): + # 'marlo' is the last entry in our vocabulary file, so be setting + # `vocabulary_size` to 1 less than number of entries in file, we take + # 'marlo' out of the vocabulary. + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size - 1) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((-1, -1, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_int32(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._warriors_vocabulary_file_name, + vocabulary_size=self._warriors_vocabulary_size, + dtype=dtypes.int32) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (2, 2)), + values=(11, 100, 30, 22), + dense_shape=(3, 3)) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, -1, 0, 4), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_int32_dense_input(self): + default_value = -100 + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._warriors_vocabulary_file_name, + vocabulary_size=self._warriors_vocabulary_size, + dtype=dtypes.int32, + default_value=default_value) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': ((11, -1, -1), (100, 30, -1), (-1, -1, 22)) + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (2, 2)), + values=np.array((2, default_value, 0, 4), dtype=np.int64), + dense_shape=(3, 3)), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_int32_with_oov_buckets(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._warriors_vocabulary_file_name, + vocabulary_size=self._warriors_vocabulary_size, + dtype=dtypes.int32, + num_oov_buckets=100) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (2, 2)), + values=(11, 100, 30, 22), + dense_shape=(3, 3)) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, 60, 0, 4), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_linear_model(self): + wire_column = fc_old.categorical_column_with_vocabulary_file( + key='wire', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size, + num_oov_buckets=1) + self.assertEqual(4, wire_column._num_buckets) + with ops.Graph().as_default(): + predictions = fc.linear_model({ + wire_column.name: sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + }, (wire_column,)) + bias = get_linear_model_bias() + wire_var = get_linear_model_column_var(wire_column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval() + # 'marlo' -> 2: wire_var[2] = 3 + # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5 + self.assertAllClose(((3.,), (5.,)), predictions.eval()) + + def test_keras_linear_model(self): + wire_column = fc_old.categorical_column_with_vocabulary_file( + key='wire', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size, + num_oov_buckets=1) + self.assertEqual(4, wire_column._num_buckets) + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + wire_column.name: + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + }, (wire_column,)) + bias = get_linear_model_bias() + wire_var = get_linear_model_column_var(wire_column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval() + # 'marlo' -> 2: wire_var[2] = 3 + # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5 + self.assertAllClose(((3.,), (5.,)), predictions.eval()) + + +class VocabularyListCategoricalColumnTest(test.TestCase): + + def test_defaults_string(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) + self.assertEqual('aaa', column.name) + self.assertEqual('aaa', column.key) + self.assertEqual(3, column.num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.string) + }, column.parse_example_spec) + + def test_key_should_be_string(self): + with self.assertRaisesRegexp(ValueError, 'key must be a string.'): + fc.categorical_column_with_vocabulary_list( + key=('aaa',), vocabulary_list=('omar', 'stringer', 'marlo')) + + def test_defaults_int(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 36)) + self.assertEqual('aaa', column.name) + self.assertEqual('aaa', column.key) + self.assertEqual(3, column.num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, column.parse_example_spec) + + def test_all_constructor_args(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32, + default_value=-99) + self.assertEqual(3, column.num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int32) + }, column.parse_example_spec) + + def test_deep_copy(self): + original = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32) + for column in (original, copy.deepcopy(original)): + self.assertEqual('aaa', column.name) + self.assertEqual(3, column.num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int32) + }, column.parse_example_spec) + + def test_invalid_dtype(self): + with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'), + dtype=dtypes.float32) + + def test_invalid_mapping_dtype(self): + with self.assertRaisesRegexp( + ValueError, r'vocabulary dtype must be string or integer'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12., 24., 36.)) + + def test_mismatched_int_dtype(self): + with self.assertRaisesRegexp( + ValueError, r'dtype.*and vocabulary dtype.*do not match'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'), + dtype=dtypes.int32) + + def test_mismatched_string_dtype(self): + with self.assertRaisesRegexp( + ValueError, r'dtype.*and vocabulary dtype.*do not match'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.string) + + def test_none_mapping(self): + with self.assertRaisesRegexp( + ValueError, r'vocabulary_list.*must be non-empty'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=None) + + def test_empty_mapping(self): + with self.assertRaisesRegexp( + ValueError, r'vocabulary_list.*must be non-empty'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=tuple([])) + + def test_duplicate_mapping(self): + with self.assertRaisesRegexp(ValueError, 'Duplicate keys'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 12)) + + def test_invalid_num_oov_buckets(self): + with self.assertRaisesRegexp(ValueError, 'Invalid num_oov_buckets'): + fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(12, 24, 36), + num_oov_buckets=-1) + + def test_invalid_buckets_and_default_value(self): + with self.assertRaisesRegexp( + ValueError, 'both num_oov_buckets and default_value'): + fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=(12, 24, 36), + num_oov_buckets=100, + default_value=2) + + def test_invalid_input_dtype_int32(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(12, 24, 36), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): + column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + + def test_invalid_input_dtype_string(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=(12, 24, 36)) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('omar', 'stringer', 'marlo'), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'): + column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + + def test_parse_example_string(self): + a = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) + data = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'aaa': + feature_pb2.Feature(bytes_list=feature_pb2.BytesList( + value=[b'omar', b'stringer'])) + })) + features = parsing_ops.parse_example( + serialized=[data.SerializeToString()], + features=fc.make_parse_example_spec([a])) + self.assertIn('aaa', features) + with self.test_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [0, 1]], + values=np.array([b'omar', b'stringer'], dtype=np.object_), + dense_shape=[1, 2]), + features['aaa'].eval()) + + def test_parse_example_int(self): + a = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=(11, 21, 31)) + data = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'aaa': + feature_pb2.Feature(int64_list=feature_pb2.Int64List( + value=[11, 21])) + })) + features = parsing_ops.parse_example( + serialized=[data.SerializeToString()], + features=fc.make_parse_example_spec([a])) + self.assertIn('aaa', features) + with self.test_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [0, 1]], + values=[11, 21], + dense_shape=[1, 2]), + features['aaa'].eval()) + + def test_get_sparse_tensors(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_transform_feature(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + id_tensor = _transform_features({'aaa': inputs}, [column], None)[column] + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_tensor.eval()) + + def DISABLED_test_get_sparse_tensors_weight_collections(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + inputs = sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), + weight_collections=('my_weights',)) + + self.assertItemsEqual( + [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) + self.assertItemsEqual([], ops.get_collection('my_weights')) + + def test_get_sparse_tensors_dense_input(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': (('marlo', ''), ('skywalker', 'omar')) + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=(2, 2)), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_default_value_in_vocabulary(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo'), + default_value=2) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, 2, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_with_oov_buckets(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo'), + num_oov_buckets=100) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (1, 2)), + values=('marlo', 'skywalker', 'omar', 'heisenberg'), + dense_shape=(2, 3)) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, 33, 0, 62), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_int32(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32), + dtype=dtypes.int32) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (2, 2)), + values=np.array((11, 100, 30, 22), dtype=np.int32), + dense_shape=(3, 3)) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, -1, 0, 4), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_int32_dense_input(self): + default_value = -100 + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32), + dtype=dtypes.int32, + default_value=default_value) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': + np.array( + ((11, -1, -1), (100, 30, -1), (-1, -1, 22)), dtype=np.int32) + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (2, 2)), + values=np.array((2, default_value, 0, 4), dtype=np.int64), + dense_shape=(3, 3)), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_int32_with_oov_buckets(self): + column = fc.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32), + dtype=dtypes.int32, + num_oov_buckets=100) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1), (2, 2)), + values=(11, 100, 30, 22), + dense_shape=(3, 3)) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((2, 60, 0, 4), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_linear_model(self): + wire_column = fc_old.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo'), + num_oov_buckets=1) + self.assertEqual(4, wire_column._num_buckets) + with ops.Graph().as_default(): + predictions = fc.linear_model({ + wire_column.name: sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + }, (wire_column,)) + bias = get_linear_model_bias() + wire_var = get_linear_model_column_var(wire_column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval() + # 'marlo' -> 2: wire_var[2] = 3 + # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5 + self.assertAllClose(((3.,), (5.,)), predictions.eval()) + + def test_keras_linear_model(self): + wire_column = fc_old.categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo'), + num_oov_buckets=1) + self.assertEqual(4, wire_column._num_buckets) + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + wire_column.name: + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + }, (wire_column,)) + bias = get_linear_model_bias() + wire_var = get_linear_model_column_var(wire_column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval() + # 'marlo' -> 2: wire_var[2] = 3 + # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5 + self.assertAllClose(((3.,), (5.,)), predictions.eval()) + + +class IdentityCategoricalColumnTest(test.TestCase): + + def test_constructor(self): + column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + self.assertEqual('aaa', column.name) + self.assertEqual('aaa', column.key) + self.assertEqual(3, column.num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, column.parse_example_spec) + + def test_key_should_be_string(self): + with self.assertRaisesRegexp(ValueError, 'key must be a string.'): + fc.categorical_column_with_identity(key=('aaa',), num_buckets=3) + + def test_deep_copy(self): + original = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + for column in (original, copy.deepcopy(original)): + self.assertEqual('aaa', column.name) + self.assertEqual(3, column.num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, column.parse_example_spec) + + def test_invalid_num_buckets_zero(self): + with self.assertRaisesRegexp(ValueError, 'num_buckets 0 < 1'): + fc.categorical_column_with_identity(key='aaa', num_buckets=0) + + def test_invalid_num_buckets_negative(self): + with self.assertRaisesRegexp(ValueError, 'num_buckets -1 < 1'): + fc.categorical_column_with_identity(key='aaa', num_buckets=-1) + + def test_invalid_default_value_too_small(self): + with self.assertRaisesRegexp(ValueError, 'default_value -1 not in range'): + fc.categorical_column_with_identity( + key='aaa', num_buckets=3, default_value=-1) + + def test_invalid_default_value_too_big(self): + with self.assertRaisesRegexp(ValueError, 'default_value 3 not in range'): + fc.categorical_column_with_identity( + key='aaa', num_buckets=3, default_value=3) + + def test_invalid_input_dtype(self): + column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('omar', 'stringer', 'marlo'), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'Invalid input, not integer'): + column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + + def test_parse_example(self): + a = fc.categorical_column_with_identity(key='aaa', num_buckets=30) + data = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'aaa': + feature_pb2.Feature(int64_list=feature_pb2.Int64List( + value=[11, 21])) + })) + features = parsing_ops.parse_example( + serialized=[data.SerializeToString()], + features=fc.make_parse_example_spec([a])) + self.assertIn('aaa', features) + with self.test_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [0, 1]], + values=np.array([11, 21], dtype=np.int64), + dense_shape=[1, 2]), + features['aaa'].eval()) + + def test_get_sparse_tensors(self): + column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 1, 0), + dense_shape=(2, 2)) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((0, 1, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_transform_feature(self): + column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 1, 0), + dense_shape=(2, 2)) + id_tensor = _transform_features({'aaa': inputs}, [column], None)[column] + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((0, 1, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_tensor.eval()) + + def DISABLED_test_get_sparse_tensors_weight_collections(self): + column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 1, 0), + dense_shape=(2, 2)) + column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), + weight_collections=('my_weights',)) + + self.assertItemsEqual( + [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) + self.assertItemsEqual([], ops.get_collection('my_weights')) + + def test_get_sparse_tensors_dense_input(self): + column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': ((0, -1), (1, 0)) + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=np.array((0, 1, 0), dtype=np.int64), + dense_shape=(2, 2)), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_with_inputs_too_small(self): + column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, -1, 0), + dense_shape=(2, 2)) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + with self.assertRaisesRegexp( + errors.OpError, 'assert_greater_or_equal_0'): + id_weight_pair.id_tensor.eval() + + def test_get_sparse_tensors_with_inputs_too_big(self): + column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 99, 0), + dense_shape=(2, 2)) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + with self.assertRaisesRegexp( + errors.OpError, 'assert_less_than_num_buckets'): + id_weight_pair.id_tensor.eval() + + def test_get_sparse_tensors_with_default_value(self): + column = fc.categorical_column_with_identity( + key='aaa', num_buckets=4, default_value=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, -1, 99), + dense_shape=(2, 2)) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((1, 3, 3), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_with_default_value_and_placeholder_inputs(self): + column = fc.categorical_column_with_identity( + key='aaa', num_buckets=4, default_value=3) + input_indices = array_ops.placeholder(dtype=dtypes.int64) + input_values = array_ops.placeholder(dtype=dtypes.int32) + input_shape = array_ops.placeholder(dtype=dtypes.int64) + inputs = sparse_tensor.SparseTensorValue( + indices=input_indices, + values=input_values, + dense_shape=input_shape) + id_weight_pair = column.get_sparse_tensors( + FeatureTransformationCache({ + 'aaa': inputs + }), None) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=np.array(((0, 0), (1, 0), (1, 1)), dtype=np.int64), + values=np.array((1, 3, 3), dtype=np.int64), + dense_shape=np.array((2, 2), dtype=np.int64)), + id_weight_pair.id_tensor.eval(feed_dict={ + input_indices: ((0, 0), (1, 0), (1, 1)), + input_values: (1, -1, 99), + input_shape: (2, 2), + })) + + def test_linear_model(self): + column = fc_old.categorical_column_with_identity(key='aaa', num_buckets=3) + self.assertEqual(3, column.num_buckets) + with ops.Graph().as_default(): + predictions = fc.linear_model({ + column.name: sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 2, 1), + dense_shape=(2, 2)) + }, (column,)) + bias = get_linear_model_bias() + weight_var = get_linear_model_column_var(column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + weight_var.assign(((1.,), (2.,), (3.,))).eval() + # weight_var[0] = 1 + # weight_var[2] + weight_var[1] = 3+2 = 5 + self.assertAllClose(((1.,), (5.,)), predictions.eval()) + + def test_keras_linear_model(self): + column = fc_old.categorical_column_with_identity(key='aaa', num_buckets=3) + self.assertEqual(3, column.num_buckets) + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + column.name: + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 2, 1), + dense_shape=(2, 2)) + }, (column,)) + bias = get_linear_model_bias() + weight_var = get_linear_model_column_var(column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + weight_var.assign(((1.,), (2.,), (3.,))).eval() + # weight_var[0] = 1 + # weight_var[2] + weight_var[1] = 3+2 = 5 + self.assertAllClose(((1.,), (5.,)), predictions.eval()) + + +class TransformFeaturesTest(test.TestCase): + + # All transform tests are distributed in column test. + # Here we only test multi column case and naming + def transform_multi_column(self): + bucketized_price = fc.bucketized_column( + fc.numeric_column('price'), boundaries=[0, 2, 4, 6]) + hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10) + with ops.Graph().as_default(): + features = { + 'price': [[-1.], [5.]], + 'wire': + sparse_tensor.SparseTensor( + values=['omar', 'stringer', 'marlo'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + } + transformed = _transform_features(features, + [bucketized_price, hashed_sparse], None) + with _initialized_session(): + self.assertIn(bucketized_price.name, transformed[bucketized_price].name) + self.assertAllEqual([[0], [3]], transformed[bucketized_price].eval()) + self.assertIn(hashed_sparse.name, transformed[hashed_sparse].name) + self.assertAllEqual([6, 4, 1], transformed[hashed_sparse].values.eval()) + + def test_column_order(self): + """When the column is both dense and sparse, uses sparse tensors.""" + + class _LoggerColumn(FeatureColumn): + + def __init__(self, name): + self._name = name + + @property + def name(self): + return self._name + + def transform_feature(self, transformation_cache, state_manager): + self.call_order = call_logger['count'] + call_logger['count'] += 1 + return 'Anything' + + @property + def parse_example_spec(self): + pass + + with ops.Graph().as_default(): + column1 = _LoggerColumn('1') + column2 = _LoggerColumn('2') + call_logger = {'count': 0} + _transform_features({}, [column1, column2], None) + self.assertEqual(0, column1.call_order) + self.assertEqual(1, column2.call_order) + + call_logger = {'count': 0} + _transform_features({}, [column2, column1], None) + self.assertEqual(0, column1.call_order) + self.assertEqual(1, column2.call_order) + + +class IndicatorColumnTest(test.TestCase): + + def test_indicator_column(self): + a = fc.categorical_column_with_hash_bucket('a', 4) + indicator_a = fc.indicator_column(a) + self.assertEqual(indicator_a.categorical_column.name, 'a') + self.assertEqual(indicator_a.name, 'a_indicator') + self.assertEqual(indicator_a.variable_shape, [1, 4]) + + b = fc.categorical_column_with_hash_bucket('b', hash_bucket_size=100) + indicator_b = fc.indicator_column(b) + self.assertEqual(indicator_b.categorical_column.name, 'b') + self.assertEqual(indicator_b.name, 'b_indicator') + self.assertEqual(indicator_b.variable_shape, [1, 100]) + + def test_1D_shape_succeeds(self): + animal = fc.indicator_column( + fc.categorical_column_with_hash_bucket('animal', 4)) + transformation_cache = FeatureTransformationCache({ + 'animal': ['fox', 'fox'] + }) + output = transformation_cache.get(animal, None) + with self.test_session(): + self.assertAllEqual([[0., 0., 1., 0.], [0., 0., 1., 0.]], output.eval()) + + def test_2D_shape_succeeds(self): + # TODO(ispir/cassandrax): Swith to categorical_column_with_keys when ready. + animal = fc.indicator_column( + fc.categorical_column_with_hash_bucket('animal', 4)) + transformation_cache = FeatureTransformationCache({ + 'animal': + sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 0]], + values=['fox', 'fox'], + dense_shape=[2, 1]) + }) + output = transformation_cache.get(animal, None) + with self.test_session(): + self.assertAllEqual([[0., 0., 1., 0.], [0., 0., 1., 0.]], output.eval()) + + def test_multi_hot(self): + animal = fc.indicator_column( + fc.categorical_column_with_identity('animal', num_buckets=4)) + + transformation_cache = FeatureTransformationCache({ + 'animal': + sparse_tensor.SparseTensor( + indices=[[0, 0], [0, 1]], values=[1, 1], dense_shape=[1, 2]) + }) + output = transformation_cache.get(animal, None) + with self.test_session(): + self.assertAllEqual([[0., 2., 0., 0.]], output.eval()) + + def test_multi_hot2(self): + animal = fc.indicator_column( + fc.categorical_column_with_identity('animal', num_buckets=4)) + transformation_cache = FeatureTransformationCache({ + 'animal': + sparse_tensor.SparseTensor( + indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2]) + }) + output = transformation_cache.get(animal, None) + with self.test_session(): + self.assertAllEqual([[0., 1., 1., 0.]], output.eval()) + + def test_deep_copy(self): + a = fc.categorical_column_with_hash_bucket('a', 4) + column = fc.indicator_column(a) + column_copy = copy.deepcopy(column) + self.assertEqual(column_copy.categorical_column.name, 'a') + self.assertEqual(column.name, 'a_indicator') + self.assertEqual(column.variable_shape, [1, 4]) + + def test_parse_example(self): + a = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) + a_indicator = fc.indicator_column(a) + data = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'aaa': + feature_pb2.Feature(bytes_list=feature_pb2.BytesList( + value=[b'omar', b'stringer'])) + })) + features = parsing_ops.parse_example( + serialized=[data.SerializeToString()], + features=fc.make_parse_example_spec([a_indicator])) + self.assertIn('aaa', features) + with self.test_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [0, 1]], + values=np.array([b'omar', b'stringer'], dtype=np.object_), + dense_shape=[1, 2]), + features['aaa'].eval()) + + def test_transform(self): + a = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) + a_indicator = fc.indicator_column(a) + features = { + 'aaa': sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + } + indicator_tensor = _transform_features(features, [a_indicator], + None)[a_indicator] + with _initialized_session(): + self.assertAllEqual([[0, 0, 1], [1, 0, 0]], indicator_tensor.eval()) + + def test_transform_with_weighted_column(self): + # Github issue 12557 + ids = fc.categorical_column_with_vocabulary_list( + key='ids', vocabulary_list=('a', 'b', 'c')) + weights = fc.weighted_categorical_column(ids, 'weights') + indicator = fc.indicator_column(weights) + features = { + 'ids': constant_op.constant([['c', 'b', 'a']]), + 'weights': constant_op.constant([[2., 4., 6.]]) + } + indicator_tensor = _transform_features(features, [indicator], + None)[indicator] + with _initialized_session(): + self.assertAllEqual([[6., 4., 2.]], indicator_tensor.eval()) + + def test_transform_with_missing_value_in_weighted_column(self): + # Github issue 12583 + ids = fc.categorical_column_with_vocabulary_list( + key='ids', vocabulary_list=('a', 'b', 'c')) + weights = fc.weighted_categorical_column(ids, 'weights') + indicator = fc.indicator_column(weights) + features = { + 'ids': constant_op.constant([['c', 'b', 'unknown']]), + 'weights': constant_op.constant([[2., 4., 6.]]) + } + indicator_tensor = _transform_features(features, [indicator], + None)[indicator] + with _initialized_session(): + self.assertAllEqual([[0., 4., 2.]], indicator_tensor.eval()) + + def test_transform_with_missing_value_in_categorical_column(self): + # Github issue 12583 + ids = fc.categorical_column_with_vocabulary_list( + key='ids', vocabulary_list=('a', 'b', 'c')) + indicator = fc.indicator_column(ids) + features = { + 'ids': constant_op.constant([['c', 'b', 'unknown']]), + } + indicator_tensor = _transform_features(features, [indicator], + None)[indicator] + with _initialized_session(): + self.assertAllEqual([[0., 1., 1.]], indicator_tensor.eval()) + + def test_linear_model(self): + animal = fc_old.indicator_column( + fc_old.categorical_column_with_identity('animal', num_buckets=4)) + with ops.Graph().as_default(): + features = { + 'animal': + sparse_tensor.SparseTensor( + indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2]) + } + + predictions = fc.linear_model(features, [animal]) + weight_var = get_linear_model_column_var(animal) + with _initialized_session(): + # All should be zero-initialized. + self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval()) + self.assertAllClose([[0.]], predictions.eval()) + weight_var.assign([[1.], [2.], [3.], [4.]]).eval() + self.assertAllClose([[2. + 3.]], predictions.eval()) + + def test_keras_linear_model(self): + animal = fc_old.indicator_column( + fc_old.categorical_column_with_identity('animal', num_buckets=4)) + with ops.Graph().as_default(): + features = { + 'animal': + sparse_tensor.SparseTensor( + indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2]) + } + + predictions = get_keras_linear_model_predictions(features, [animal]) + weight_var = get_linear_model_column_var(animal) + with _initialized_session(): + # All should be zero-initialized. + self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval()) + self.assertAllClose([[0.]], predictions.eval()) + weight_var.assign([[1.], [2.], [3.], [4.]]).eval() + self.assertAllClose([[2. + 3.]], predictions.eval()) + + def test_input_layer(self): + animal = fc_old.indicator_column( + fc_old.categorical_column_with_identity('animal', num_buckets=4)) + with ops.Graph().as_default(): + features = { + 'animal': + sparse_tensor.SparseTensor( + indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2]) + } + net = fc.input_layer(features, [animal]) + with _initialized_session(): + self.assertAllClose([[0., 1., 1., 0.]], net.eval()) + + +class _TestStateManager(StateManager): + + def __init__(self, trainable=True): + # Dict of feature_column to a dict of variables. + self._all_variables = {} + self._trainable = trainable + + def get_variable(self, + feature_column, + name, + shape, + dtype=None, + initializer=None): + if feature_column not in self._all_variables: + self._all_variables[feature_column] = {} + var_dict = self._all_variables[feature_column] + if name in var_dict: + return var_dict[name] + else: + var = variable_scope.get_variable( + name=name, + shape=shape, + initializer=initializer, + trainable=self._trainable) + var_dict[name] = var + return var + + +class EmbeddingColumnTest(test.TestCase): + + def test_defaults(self): + categorical_column = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + embedding_dimension = 2 + embedding_column = fc.embedding_column( + categorical_column, dimension=embedding_dimension) + self.assertIs(categorical_column, embedding_column.categorical_column) + self.assertEqual(embedding_dimension, embedding_column.dimension) + self.assertEqual('mean', embedding_column.combiner) + self.assertIsNone(embedding_column.ckpt_to_load_from) + self.assertIsNone(embedding_column.tensor_name_in_ckpt) + self.assertIsNone(embedding_column.max_norm) + self.assertTrue(embedding_column.trainable) + self.assertEqual('aaa_embedding', embedding_column.name) + self.assertEqual((embedding_dimension,), embedding_column.variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column.parse_example_spec) + + def test_all_constructor_args(self): + categorical_column = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + embedding_dimension = 2 + embedding_column = fc.embedding_column( + categorical_column, dimension=embedding_dimension, + combiner='my_combiner', initializer=lambda: 'my_initializer', + ckpt_to_load_from='my_ckpt', tensor_name_in_ckpt='my_ckpt_tensor', + max_norm=42., trainable=False) + self.assertIs(categorical_column, embedding_column.categorical_column) + self.assertEqual(embedding_dimension, embedding_column.dimension) + self.assertEqual('my_combiner', embedding_column.combiner) + self.assertEqual('my_ckpt', embedding_column.ckpt_to_load_from) + self.assertEqual('my_ckpt_tensor', embedding_column.tensor_name_in_ckpt) + self.assertEqual(42., embedding_column.max_norm) + self.assertFalse(embedding_column.trainable) + self.assertEqual('aaa_embedding', embedding_column.name) + self.assertEqual((embedding_dimension,), embedding_column.variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column.parse_example_spec) + + def test_deep_copy(self): + categorical_column = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + embedding_dimension = 2 + original = fc.embedding_column( + categorical_column, dimension=embedding_dimension, + combiner='my_combiner', initializer=lambda: 'my_initializer', + ckpt_to_load_from='my_ckpt', tensor_name_in_ckpt='my_ckpt_tensor', + max_norm=42., trainable=False) + for embedding_column in (original, copy.deepcopy(original)): + self.assertEqual('aaa', embedding_column.categorical_column.name) + self.assertEqual(3, embedding_column.categorical_column.num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column.categorical_column.parse_example_spec) + + self.assertEqual(embedding_dimension, embedding_column.dimension) + self.assertEqual('my_combiner', embedding_column.combiner) + self.assertEqual('my_ckpt', embedding_column.ckpt_to_load_from) + self.assertEqual('my_ckpt_tensor', embedding_column.tensor_name_in_ckpt) + self.assertEqual(42., embedding_column.max_norm) + self.assertFalse(embedding_column.trainable) + self.assertEqual('aaa_embedding', embedding_column.name) + self.assertEqual((embedding_dimension,), embedding_column.variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column.parse_example_spec) + + def test_invalid_initializer(self): + categorical_column = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + with self.assertRaisesRegexp(ValueError, 'initializer must be callable'): + fc.embedding_column(categorical_column, dimension=2, initializer='not_fn') + + def test_parse_example(self): + a = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) + a_embedded = fc.embedding_column(a, dimension=2) + data = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'aaa': + feature_pb2.Feature(bytes_list=feature_pb2.BytesList( + value=[b'omar', b'stringer'])) + })) + features = parsing_ops.parse_example( + serialized=[data.SerializeToString()], + features=fc.make_parse_example_spec([a_embedded])) + self.assertIn('aaa', features) + with self.test_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [0, 1]], + values=np.array([b'omar', b'stringer'], dtype=np.object_), + dense_shape=[1, 2]), + features['aaa'].eval()) + + def test_transform_feature(self): + a = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + a_embedded = fc.embedding_column(a, dimension=2) + features = { + 'aaa': sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 1, 0), + dense_shape=(2, 2)) + } + outputs = _transform_features(features, [a, a_embedded], None) + output_a = outputs[a] + output_embedded = outputs[a_embedded] + with _initialized_session(): + _assert_sparse_tensor_value( + self, output_a.eval(), output_embedded.eval()) + + def test_get_dense_tensor(self): + # Inputs. + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 5)) + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups = ( + # example 0, ids [2], embedding = [7, 11] + (7., 11.), + # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + (2., 3.5), + # example 2, ids [], embedding = [0, 0] + (0., 0.), + # example 3, ids [1], embedding = [3, 5] + (3., 5.), + ) + + # Build columns. + categorical_column = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, dimension=embedding_dimension, + initializer=_initializer) + state_manager = _TestStateManager() + + # Provide sparse input and get dense result. + embedding_lookup = embedding_column.get_dense_tensor( + FeatureTransformationCache({ + 'aaa': sparse_input + }), state_manager) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual(('embedding_weights:0',), + tuple([v.name for v in global_vars])) + with _initialized_session(): + self.assertAllEqual(embedding_values, global_vars[0].eval()) + self.assertAllEqual(expected_lookups, embedding_lookup.eval()) + + def test_get_dense_tensor_3d(self): + # Inputs. + vocabulary_size = 4 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0, 0), (1, 1, 0), (1, 1, 4), (3, 0, 0), (3, 1, 2)), + values=(2, 0, 1, 1, 2), + dense_shape=(4, 2, 5)) + + # Embedding variable. + embedding_dimension = 3 + embedding_values = ( + (1., 2., 4.), # id 0 + (3., 5., 1.), # id 1 + (7., 11., 2.), # id 2 + (2., 7., 12.) # id 3 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups = ( + # example 0, ids [[2], []], embedding = [[7, 11, 2], [0, 0, 0]] + ((7., 11., 2.), (0., 0., 0.)), + # example 1, ids [[], [0, 1]], embedding + # = mean([[], [1, 2, 4] + [3, 5, 1]]) = [[0, 0, 0], [2, 3.5, 2.5]] + ((0., 0., 0.), (2., 3.5, 2.5)), + # example 2, ids [[], []], embedding = [[0, 0, 0], [0, 0, 0]] + ((0., 0., 0.), (0., 0., 0.)), + # example 3, ids [[1], [2]], embedding = [[3, 5, 1], [7, 11, 2]] + ((3., 5., 1.), (7., 11., 2.)), + ) + + # Build columns. + categorical_column = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, dimension=embedding_dimension, + initializer=_initializer) + state_manager = _TestStateManager() + + # Provide sparse input and get dense result. + embedding_lookup = embedding_column.get_dense_tensor( + FeatureTransformationCache({ + 'aaa': sparse_input + }), state_manager) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual(('embedding_weights:0',), + tuple([v.name for v in global_vars])) + with _initialized_session(): + self.assertAllEqual(embedding_values, global_vars[0].eval()) + self.assertAllEqual(expected_lookups, embedding_lookup.eval()) + + def DISABLED_test_get_dense_tensor_weight_collections(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 5)) + + # Build columns. + categorical_column = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + embedding_column = fc.embedding_column(categorical_column, dimension=2) + + # Provide sparse input and get dense result. + embedding_column.get_dense_tensor( + FeatureTransformationCache({ + 'aaa': sparse_input + }), + weight_collections=('my_vars',)) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual(('embedding_weights:0',), + tuple([v.name for v in global_vars])) + my_vars = ops.get_collection('my_vars') + self.assertItemsEqual( + ('embedding_weights:0',), tuple([v.name for v in my_vars])) + + def test_get_dense_tensor_placeholder_inputs(self): + # Inputs. + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 5)) + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups = ( + # example 0, ids [2], embedding = [7, 11] + (7., 11.), + # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + (2., 3.5), + # example 2, ids [], embedding = [0, 0] + (0., 0.), + # example 3, ids [1], embedding = [3, 5] + (3., 5.), + ) + + # Build columns. + categorical_column = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, dimension=embedding_dimension, + initializer=_initializer) + state_manager = _TestStateManager() + + # Provide sparse input and get dense result. + input_indices = array_ops.placeholder(dtype=dtypes.int64) + input_values = array_ops.placeholder(dtype=dtypes.int64) + input_shape = array_ops.placeholder(dtype=dtypes.int64) + embedding_lookup = embedding_column.get_dense_tensor( + FeatureTransformationCache({ + 'aaa': + sparse_tensor.SparseTensorValue( + indices=input_indices, + values=input_values, + dense_shape=input_shape) + }), state_manager) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('embedding_weights:0',), tuple([v.name for v in global_vars])) + with _initialized_session(): + self.assertAllEqual(embedding_values, global_vars[0].eval()) + self.assertAllEqual(expected_lookups, embedding_lookup.eval( + feed_dict={ + input_indices: sparse_input.indices, + input_values: sparse_input.values, + input_shape: sparse_input.dense_shape, + })) + + def test_get_dense_tensor_restore_from_ckpt(self): + # Inputs. + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 5)) + + # Embedding variable. The checkpoint file contains _embedding_values. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + ckpt_path = test.test_src_dir_path( + 'python/feature_column/testdata/embedding.ckpt') + ckpt_tensor = 'my_embedding' + + # Expected lookup result, using combiner='mean'. + expected_lookups = ( + # example 0, ids [2], embedding = [7, 11] + (7., 11.), + # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + (2., 3.5), + # example 2, ids [], embedding = [0, 0] + (0., 0.), + # example 3, ids [1], embedding = [3, 5] + (3., 5.), + ) + + # Build columns. + categorical_column = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, dimension=embedding_dimension, + ckpt_to_load_from=ckpt_path, + tensor_name_in_ckpt=ckpt_tensor) + state_manager = _TestStateManager() + + # Provide sparse input and get dense result. + embedding_lookup = embedding_column.get_dense_tensor( + FeatureTransformationCache({ + 'aaa': sparse_input + }), state_manager) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('embedding_weights:0',), tuple([v.name for v in global_vars])) + with _initialized_session(): + self.assertAllEqual(embedding_values, global_vars[0].eval()) + self.assertAllEqual(expected_lookups, embedding_lookup.eval()) + + def test_linear_model(self): + # Inputs. + batch_size = 4 + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(batch_size, 5)) + + # Embedding variable. + embedding_dimension = 2 + embedding_shape = (vocabulary_size, embedding_dimension) + zeros_embedding_values = np.zeros(embedding_shape) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual(embedding_shape, shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return zeros_embedding_values + + # Build columns. + categorical_column = fc_old.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc_old.embedding_column( + categorical_column, + dimension=embedding_dimension, + initializer=_initializer) + + with ops.Graph().as_default(): + predictions = fc.linear_model({ + categorical_column.name: sparse_input + }, (embedding_column,)) + expected_var_names = ( + 'linear_model/bias_weights:0', + 'linear_model/aaa_embedding/weights:0', + 'linear_model/aaa_embedding/embedding_weights:0', + ) + self.assertItemsEqual( + expected_var_names, + [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) + trainable_vars = { + v.name: v for v in ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES) + } + self.assertItemsEqual(expected_var_names, trainable_vars.keys()) + bias = trainable_vars['linear_model/bias_weights:0'] + embedding_weights = trainable_vars[ + 'linear_model/aaa_embedding/embedding_weights:0'] + linear_weights = trainable_vars[ + 'linear_model/aaa_embedding/weights:0'] + with _initialized_session(): + # Predictions with all zero weights. + self.assertAllClose(np.zeros((1,)), bias.eval()) + self.assertAllClose(zeros_embedding_values, embedding_weights.eval()) + self.assertAllClose( + np.zeros((embedding_dimension, 1)), linear_weights.eval()) + self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval()) + + # Predictions with all non-zero weights. + embedding_weights.assign(( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + )).eval() + linear_weights.assign(((4.,), (6.,))).eval() + # example 0, ids [2], embedding[0] = [7, 11] + # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5] + # example 2, ids [], embedding[2] = [0, 0] + # example 3, ids [1], embedding[3] = [3, 5] + # sum(embeddings * linear_weights) + # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42] + self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval()) + + def test_keras_linear_model(self): + # Inputs. + batch_size = 4 + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(batch_size, 5)) + + # Embedding variable. + embedding_dimension = 2 + embedding_shape = (vocabulary_size, embedding_dimension) + zeros_embedding_values = np.zeros(embedding_shape) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual(embedding_shape, shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return zeros_embedding_values + + # Build columns. + categorical_column = fc_old.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc_old.embedding_column( + categorical_column, + dimension=embedding_dimension, + initializer=_initializer) + + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + categorical_column.name: sparse_input + }, (embedding_column,)) + expected_var_names = ( + 'linear_model/bias_weights:0', + 'linear_model/aaa_embedding/weights:0', + 'linear_model/aaa_embedding/embedding_weights:0', + ) + self.assertItemsEqual( + expected_var_names, + [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) + trainable_vars = { + v.name: v + for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + } + self.assertItemsEqual(expected_var_names, trainable_vars.keys()) + bias = trainable_vars['linear_model/bias_weights:0'] + embedding_weights = trainable_vars[ + 'linear_model/aaa_embedding/embedding_weights:0'] + linear_weights = trainable_vars['linear_model/aaa_embedding/weights:0'] + with _initialized_session(): + # Predictions with all zero weights. + self.assertAllClose(np.zeros((1,)), bias.eval()) + self.assertAllClose(zeros_embedding_values, embedding_weights.eval()) + self.assertAllClose( + np.zeros((embedding_dimension, 1)), linear_weights.eval()) + self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval()) + + # Predictions with all non-zero weights. + embedding_weights.assign(( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + )).eval() + linear_weights.assign(((4.,), (6.,))).eval() + # example 0, ids [2], embedding[0] = [7, 11] + # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5] + # example 2, ids [], embedding[2] = [0, 0] + # example 3, ids [1], embedding[3] = [3, 5] + # sum(embeddings * linear_weights) + # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42] + self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval()) + + def test_input_layer(self): + # Inputs. + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 5)) + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups = ( + # example 0, ids [2], embedding = [7, 11] + (7., 11.), + # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + (2., 3.5), + # example 2, ids [], embedding = [0, 0] + (0., 0.), + # example 3, ids [1], embedding = [3, 5] + (3., 5.), + ) + + # Build columns. + categorical_column = fc_old.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc_old.embedding_column( + categorical_column, + dimension=embedding_dimension, + initializer=_initializer) + + # Provide sparse input and get dense result. + input_layer = fc.input_layer({'aaa': sparse_input}, (embedding_column,)) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('input_layer/aaa_embedding/embedding_weights:0',), + tuple([v.name for v in global_vars])) + trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertItemsEqual( + ('input_layer/aaa_embedding/embedding_weights:0',), + tuple([v.name for v in trainable_vars])) + with _initialized_session(): + self.assertAllEqual(embedding_values, trainable_vars[0].eval()) + self.assertAllEqual(expected_lookups, input_layer.eval()) + + def test_input_layer_not_trainable(self): + # Inputs. + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 5)) + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups = ( + # example 0, ids [2], embedding = [7, 11] + (7., 11.), + # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + (2., 3.5), + # example 2, ids [], embedding = [0, 0] + (0., 0.), + # example 3, ids [1], embedding = [3, 5] + (3., 5.), + ) + + # Build columns. + categorical_column = fc_old.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc_old.embedding_column( + categorical_column, + dimension=embedding_dimension, + initializer=_initializer, + trainable=False) + + # Provide sparse input and get dense result. + input_layer = fc.input_layer({'aaa': sparse_input}, (embedding_column,)) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('input_layer/aaa_embedding/embedding_weights:0',), + tuple([v.name for v in global_vars])) + self.assertItemsEqual( + [], ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) + with _initialized_session(): + self.assertAllEqual(embedding_values, global_vars[0].eval()) + self.assertAllEqual(expected_lookups, input_layer.eval()) + + +class _TestSharedEmbeddingStateManager(StateManager): + """Manages the state for shared embedding columns. + + This can handle multiple groups of shared embedding columns. + """ + + def __init__(self, trainable=True): + # Dict of shared_embedding_collection_name to a dict of variables. + self._all_variables = {} + self._trainable = trainable + + def get_variable(self, + feature_column, + name, + shape, + dtype=None, + initializer=None): + if not isinstance(feature_column, fc.SharedEmbeddingColumn): + raise ValueError( + 'SharedEmbeddingStateManager can only handle SharedEmbeddingColumns. ' + 'Given type: {} '.format(type(feature_column))) + + collection_name = feature_column.shared_collection_name + if collection_name not in self._all_variables: + self._all_variables[collection_name] = {} + var_dict = self._all_variables[collection_name] + if name in var_dict: + return var_dict[name] + else: + var = variable_scope.get_variable( + name=name, + shape=shape, + initializer=initializer, + trainable=self._trainable) + var_dict[name] = var + return var + + +class SharedEmbeddingColumnTest(test.TestCase): + + def test_defaults(self): + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=3) + embedding_dimension = 2 + embedding_column_b, embedding_column_a = fc.shared_embedding_columns( + [categorical_column_b, categorical_column_a], + dimension=embedding_dimension) + self.assertIs(categorical_column_a, embedding_column_a.categorical_column) + self.assertIs(categorical_column_b, embedding_column_b.categorical_column) + self.assertEqual(embedding_dimension, embedding_column_a.dimension) + self.assertEqual(embedding_dimension, embedding_column_b.dimension) + self.assertEqual('mean', embedding_column_a.combiner) + self.assertEqual('mean', embedding_column_b.combiner) + self.assertIsNone(embedding_column_a.ckpt_to_load_from) + self.assertIsNone(embedding_column_b.ckpt_to_load_from) + self.assertEqual('aaa_bbb_shared_embedding', + embedding_column_a.shared_collection_name) + self.assertEqual('aaa_bbb_shared_embedding', + embedding_column_b.shared_collection_name) + self.assertIsNone(embedding_column_a.tensor_name_in_ckpt) + self.assertIsNone(embedding_column_b.tensor_name_in_ckpt) + self.assertIsNone(embedding_column_a.max_norm) + self.assertIsNone(embedding_column_b.max_norm) + self.assertTrue(embedding_column_a.trainable) + self.assertTrue(embedding_column_b.trainable) + self.assertEqual('aaa_shared_embedding', embedding_column_a.name) + self.assertEqual('bbb_shared_embedding', embedding_column_b.name) + self.assertEqual((embedding_dimension,), embedding_column_a.variable_shape) + self.assertEqual((embedding_dimension,), embedding_column_b.variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_a.parse_example_spec) + self.assertEqual({ + 'bbb': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_b.parse_example_spec) + + def test_all_constructor_args(self): + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=3) + embedding_dimension = 2 + embedding_column_a, embedding_column_b = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + combiner='my_combiner', + initializer=lambda: 'my_initializer', + shared_embedding_collection_name='shared_embedding_collection_name', + ckpt_to_load_from='my_ckpt', + tensor_name_in_ckpt='my_ckpt_tensor', + max_norm=42., + trainable=False) + self.assertIs(categorical_column_a, embedding_column_a.categorical_column) + self.assertIs(categorical_column_b, embedding_column_b.categorical_column) + self.assertEqual(embedding_dimension, embedding_column_a.dimension) + self.assertEqual(embedding_dimension, embedding_column_b.dimension) + self.assertEqual('my_combiner', embedding_column_a.combiner) + self.assertEqual('my_combiner', embedding_column_b.combiner) + self.assertEqual('shared_embedding_collection_name', + embedding_column_a.shared_collection_name) + self.assertEqual('shared_embedding_collection_name', + embedding_column_b.shared_collection_name) + self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from) + self.assertEqual('my_ckpt', embedding_column_b.ckpt_to_load_from) + self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt) + self.assertEqual('my_ckpt_tensor', embedding_column_b.tensor_name_in_ckpt) + self.assertEqual(42., embedding_column_a.max_norm) + self.assertEqual(42., embedding_column_b.max_norm) + self.assertFalse(embedding_column_a.trainable) + self.assertFalse(embedding_column_b.trainable) + self.assertEqual('aaa_shared_embedding', embedding_column_a.name) + self.assertEqual('bbb_shared_embedding', embedding_column_b.name) + self.assertEqual((embedding_dimension,), embedding_column_a.variable_shape) + self.assertEqual((embedding_dimension,), embedding_column_b.variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_a.parse_example_spec) + self.assertEqual({ + 'bbb': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_b.parse_example_spec) + + def test_deep_copy(self): + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=3) + embedding_dimension = 2 + original_a, _ = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + combiner='my_combiner', + initializer=lambda: 'my_initializer', + shared_embedding_collection_name='shared_embedding_collection_name', + ckpt_to_load_from='my_ckpt', + tensor_name_in_ckpt='my_ckpt_tensor', + max_norm=42., trainable=False) + for embedding_column_a in (original_a, copy.deepcopy(original_a)): + self.assertEqual('aaa', embedding_column_a.categorical_column.name) + self.assertEqual(3, embedding_column_a.categorical_column.num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_a.categorical_column.parse_example_spec) + + self.assertEqual(embedding_dimension, embedding_column_a.dimension) + self.assertEqual('my_combiner', embedding_column_a.combiner) + self.assertEqual('shared_embedding_collection_name', + embedding_column_a.shared_collection_name) + self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from) + self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt) + self.assertEqual(42., embedding_column_a.max_norm) + self.assertFalse(embedding_column_a.trainable) + self.assertEqual('aaa_shared_embedding', embedding_column_a.name) + self.assertEqual((embedding_dimension,), + embedding_column_a.variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_a.parse_example_spec) + + def test_invalid_initializer(self): + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=3) + with self.assertRaisesRegexp(ValueError, 'initializer must be callable'): + fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2, + initializer='not_fn') + + def test_incompatible_column_type(self): + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=3) + categorical_column_c = fc.categorical_column_with_hash_bucket( + key='ccc', hash_bucket_size=3) + with self.assertRaisesRegexp( + ValueError, 'all categorical_columns must have the same type.*' + 'IdentityCategoricalColumn.*HashedCategoricalColumn'): + fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b, categorical_column_c], + dimension=2) + + def test_weighted_categorical_column_ok(self): + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + weighted_categorical_column_a = fc.weighted_categorical_column( + categorical_column_a, weight_feature_key='aaa_weights') + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=3) + weighted_categorical_column_b = fc.weighted_categorical_column( + categorical_column_b, weight_feature_key='bbb_weights') + fc.shared_embedding_columns( + [weighted_categorical_column_a, categorical_column_b], dimension=2) + fc.shared_embedding_columns( + [categorical_column_a, weighted_categorical_column_b], dimension=2) + fc.shared_embedding_columns( + [weighted_categorical_column_a, weighted_categorical_column_b], + dimension=2) + + def test_parse_example(self): + a = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) + b = fc.categorical_column_with_vocabulary_list( + key='bbb', vocabulary_list=('omar', 'stringer', 'marlo')) + a_embedded, b_embedded = fc.shared_embedding_columns( + [a, b], dimension=2) + data = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'aaa': + feature_pb2.Feature(bytes_list=feature_pb2.BytesList( + value=[b'omar', b'stringer'])), + 'bbb': + feature_pb2.Feature(bytes_list=feature_pb2.BytesList( + value=[b'stringer', b'marlo'])), + })) + features = parsing_ops.parse_example( + serialized=[data.SerializeToString()], + features=fc.make_parse_example_spec([a_embedded, b_embedded])) + self.assertIn('aaa', features) + self.assertIn('bbb', features) + with self.test_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [0, 1]], + values=np.array([b'omar', b'stringer'], dtype=np.object_), + dense_shape=[1, 2]), + features['aaa'].eval()) + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [0, 1]], + values=np.array([b'stringer', b'marlo'], dtype=np.object_), + dense_shape=[1, 2]), + features['bbb'].eval()) + + def test_transform_feature(self): + a = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + b = fc.categorical_column_with_identity(key='bbb', num_buckets=3) + a_embedded, b_embedded = fc.shared_embedding_columns( + [a, b], dimension=2) + features = { + 'aaa': sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 1, 0), + dense_shape=(2, 2)), + 'bbb': sparse_tensor.SparseTensor( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 1), + dense_shape=(2, 2)), + } + outputs = _transform_features(features, [a, a_embedded, b, b_embedded], + None) + output_a = outputs[a] + output_a_embedded = outputs[a_embedded] + output_b = outputs[b] + output_b_embedded = outputs[b_embedded] + with _initialized_session(): + _assert_sparse_tensor_value( + self, output_a.eval(), output_a_embedded.eval()) + _assert_sparse_tensor_value( + self, output_b.eval(), output_b_embedded.eval()) + + def test_get_dense_tensor(self): + # Inputs. + vocabulary_size = 3 + # -1 values are ignored. + input_a = np.array( + [[2, -1, -1], # example 0, ids [2] + [0, 1, -1]]) # example 1, ids [0, 1] + input_b = np.array( + [[0, -1, -1], # example 0, ids [0] + [-1, -1, -1]]) # example 1, ids [] + input_features = { + 'aaa': input_a, + 'bbb': input_b + } + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups_a = ( + # example 0: + (7., 11.), # ids [2], embedding = [7, 11] + # example 1: + (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + ) + expected_lookups_b = ( + # example 0: + (1., 2.), # ids [0], embedding = [1, 2] + # example 1: + (0., 0.), # ids [], embedding = [0, 0] + ) + + # Build columns. + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_a, embedding_column_b = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, initializer=_initializer) + state_manager = _TestSharedEmbeddingStateManager() + + # Provide sparse input and get dense result. + embedding_lookup_a = embedding_column_a.get_dense_tensor( + FeatureTransformationCache(input_features), state_manager) + embedding_lookup_b = embedding_column_b.get_dense_tensor( + FeatureTransformationCache(input_features), state_manager) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual(('embedding_weights:0',), + tuple([v.name for v in global_vars])) + embedding_var = global_vars[0] + with _initialized_session(): + self.assertAllEqual(embedding_values, embedding_var.eval()) + self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval()) + self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval()) + + def DISABLED_test_get_dense_tensor_weight_collections(self): + # Inputs. + vocabulary_size = 3 + # -1 values are ignored. + input_a = np.array([ + [2, -1, -1], # example 0, ids [2] + [0, 1, -1] + ]) # example 1, ids [0, 1] + input_b = np.array([ + [0, -1, -1], # example 0, ids [0] + [-1, -1, -1] + ]) # example 1, ids [] + input_features = {'aaa': input_a, 'bbb': input_b} + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Build columns. + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_a, embedding_column_b = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer) + + fc.input_layer( + input_features, [embedding_column_a, embedding_column_b], + weight_collections=('my_vars',)) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), + tuple(v.name for v in global_vars)) + my_vars = ops.get_collection('my_vars') + self.assertItemsEqual( + ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), + tuple(v.name for v in my_vars)) + + def test_get_dense_tensor_placeholder_inputs(self): + # Inputs. + vocabulary_size = 3 + # -1 values are ignored. + input_a = np.array( + [[2, -1, -1], # example 0, ids [2] + [0, 1, -1]]) # example 1, ids [0, 1] + input_b = np.array( + [[0, -1, -1], # example 0, ids [0] + [-1, -1, -1]]) # example 1, ids [] + # Specify shape, because dense input must have rank specified. + input_a_placeholder = array_ops.placeholder( + dtype=dtypes.int64, shape=[None, 3]) + input_b_placeholder = array_ops.placeholder( + dtype=dtypes.int64, shape=[None, 3]) + input_features = { + 'aaa': input_a_placeholder, + 'bbb': input_b_placeholder, + } + feed_dict = { + input_a_placeholder: input_a, + input_b_placeholder: input_b, + } + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Build columns. + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_a, embedding_column_b = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, initializer=_initializer) + state_manager = _TestSharedEmbeddingStateManager() + + # Provide sparse input and get dense result. + embedding_lookup_a = embedding_column_a.get_dense_tensor( + FeatureTransformationCache(input_features), state_manager) + embedding_lookup_b = embedding_column_b.get_dense_tensor( + FeatureTransformationCache(input_features), state_manager) + + with _initialized_session() as sess: + sess.run([embedding_lookup_a, embedding_lookup_b], feed_dict=feed_dict) + + def test_linear_model(self): + # Inputs. + batch_size = 2 + vocabulary_size = 3 + # -1 values are ignored. + input_a = np.array( + [[2, -1, -1], # example 0, ids [2] + [0, 1, -1]]) # example 1, ids [0, 1] + input_b = np.array( + [[0, -1, -1], # example 0, ids [0] + [-1, -1, -1]]) # example 1, ids [] + + # Embedding variable. + embedding_dimension = 2 + embedding_shape = (vocabulary_size, embedding_dimension) + zeros_embedding_values = np.zeros(embedding_shape) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual(embedding_shape, shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return zeros_embedding_values + + # Build columns. + categorical_column_a = fc_old.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc_old.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer) + + with ops.Graph().as_default(): + predictions = fc.linear_model({ + categorical_column_a.name: input_a, + categorical_column_b.name: input_b, + }, (embedding_column_a, embedding_column_b)) + # Linear weights do not follow the column name. But this is a rare use + # case, and fixing it would add too much complexity to the code. + expected_var_names = ( + 'linear_model/bias_weights:0', + 'linear_model/aaa_bbb_shared_embedding/weights:0', + 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0', + 'linear_model/aaa_bbb_shared_embedding_1/weights:0', + ) + self.assertItemsEqual( + expected_var_names, + [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) + trainable_vars = { + v.name: v for v in ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES) + } + self.assertItemsEqual(expected_var_names, trainable_vars.keys()) + bias = trainable_vars['linear_model/bias_weights:0'] + embedding_weights = trainable_vars[ + 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0'] + linear_weights_a = trainable_vars[ + 'linear_model/aaa_bbb_shared_embedding/weights:0'] + linear_weights_b = trainable_vars[ + 'linear_model/aaa_bbb_shared_embedding_1/weights:0'] + with _initialized_session(): + # Predictions with all zero weights. + self.assertAllClose(np.zeros((1,)), bias.eval()) + self.assertAllClose(zeros_embedding_values, embedding_weights.eval()) + self.assertAllClose( + np.zeros((embedding_dimension, 1)), linear_weights_a.eval()) + self.assertAllClose( + np.zeros((embedding_dimension, 1)), linear_weights_b.eval()) + self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval()) + + # Predictions with all non-zero weights. + embedding_weights.assign(( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + )).eval() + linear_weights_a.assign(((4.,), (6.,))).eval() + # example 0, ids [2], embedding[0] = [7, 11] + # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5] + # sum(embeddings * linear_weights) + # = [4*7 + 6*11, 4*2 + 6*3.5] = [94, 29] + linear_weights_b.assign(((3.,), (5.,))).eval() + # example 0, ids [0], embedding[0] = [1, 2] + # example 1, ids [], embedding[1] = 0, 0] + # sum(embeddings * linear_weights) + # = [3*1 + 5*2, 3*0 +5*0] = [13, 0] + self.assertAllClose([[94. + 13.], [29.]], predictions.eval()) + + def test_keras_linear_model(self): + # Inputs. + batch_size = 2 + vocabulary_size = 3 + # -1 values are ignored. + input_a = np.array([ + [2, -1, -1], # example 0, ids [2] + [0, 1, -1] + ]) # example 1, ids [0, 1] + input_b = np.array([ + [0, -1, -1], # example 0, ids [0] + [-1, -1, -1] + ]) # example 1, ids [] + + # Embedding variable. + embedding_dimension = 2 + embedding_shape = (vocabulary_size, embedding_dimension) + zeros_embedding_values = np.zeros(embedding_shape) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual(embedding_shape, shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return zeros_embedding_values + + # Build columns. + categorical_column_a = fc_old.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc_old.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer) + + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + categorical_column_a.name: input_a, + categorical_column_b.name: input_b, + }, (embedding_column_a, embedding_column_b)) + # Linear weights do not follow the column name. But this is a rare use + # case, and fixing it would add too much complexity to the code. + expected_var_names = ( + 'linear_model/bias_weights:0', + 'linear_model/aaa_bbb_shared_embedding/weights:0', + 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0', + 'linear_model/aaa_bbb_shared_embedding_1/weights:0', + ) + self.assertItemsEqual( + expected_var_names, + [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) + trainable_vars = { + v.name: v + for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + } + self.assertItemsEqual(expected_var_names, trainable_vars.keys()) + bias = trainable_vars['linear_model/bias_weights:0'] + embedding_weights = trainable_vars[ + 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0'] + linear_weights_a = trainable_vars[ + 'linear_model/aaa_bbb_shared_embedding/weights:0'] + linear_weights_b = trainable_vars[ + 'linear_model/aaa_bbb_shared_embedding_1/weights:0'] + with _initialized_session(): + # Predictions with all zero weights. + self.assertAllClose(np.zeros((1,)), bias.eval()) + self.assertAllClose(zeros_embedding_values, embedding_weights.eval()) + self.assertAllClose( + np.zeros((embedding_dimension, 1)), linear_weights_a.eval()) + self.assertAllClose( + np.zeros((embedding_dimension, 1)), linear_weights_b.eval()) + self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval()) + + # Predictions with all non-zero weights. + embedding_weights.assign(( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + )).eval() + linear_weights_a.assign(((4.,), (6.,))).eval() + # example 0, ids [2], embedding[0] = [7, 11] + # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5] + # sum(embeddings * linear_weights) + # = [4*7 + 6*11, 4*2 + 6*3.5] = [94, 29] + linear_weights_b.assign(((3.,), (5.,))).eval() + # example 0, ids [0], embedding[0] = [1, 2] + # example 1, ids [], embedding[1] = 0, 0] + # sum(embeddings * linear_weights) + # = [3*1 + 5*2, 3*0 +5*0] = [13, 0] + self.assertAllClose([[94. + 13.], [29.]], predictions.eval()) + + def _test_input_layer(self, trainable=True): + # Inputs. + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 4)), + values=(2, 0, 1), + dense_shape=(2, 5)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [0] + # example 1, ids [] + indices=((0, 0),), + values=(0,), + dense_shape=(2, 5)) + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups = ( + # example 0: + # A ids [2], embedding = [7, 11] + # B ids [0], embedding = [1, 2] + (7., 11., 1., 2.), + # example 1: + # A ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + # B ids [], embedding = [0, 0] + (2., 3.5, 0., 0.), + ) + + # Build columns. + categorical_column_a = fc_old.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc_old.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer, + trainable=trainable) + + # Provide sparse input and get dense result. + input_layer = fc.input_layer( + features={'aaa': sparse_input_a, 'bbb': sparse_input_b}, + feature_columns=(embedding_column_b, embedding_column_a)) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'], + tuple([v.name for v in global_vars])) + trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + if trainable: + self.assertItemsEqual( + ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'], + tuple([v.name for v in trainable_vars])) + else: + self.assertItemsEqual([], tuple([v.name for v in trainable_vars])) + shared_embedding_vars = global_vars + with _initialized_session(): + self.assertAllEqual(embedding_values, shared_embedding_vars[0].eval()) + self.assertAllEqual(expected_lookups, input_layer.eval()) + + def test_input_layer(self): + self._test_input_layer() + + def test_input_layer_no_trainable(self): + self._test_input_layer(trainable=False) + + +class WeightedCategoricalColumnTest(test.TestCase): + + def test_defaults(self): + column = fc.weighted_categorical_column( + categorical_column=fc.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + self.assertEqual('ids_weighted_by_values', column.name) + self.assertEqual(3, column.num_buckets) + self.assertEqual({ + 'ids': parsing_ops.VarLenFeature(dtypes.int64), + 'values': parsing_ops.VarLenFeature(dtypes.float32) + }, column.parse_example_spec) + + def test_deep_copy(self): + """Tests deepcopy of categorical_column_with_hash_bucket.""" + original = fc.weighted_categorical_column( + categorical_column=fc.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + for column in (original, copy.deepcopy(original)): + self.assertEqual('ids_weighted_by_values', column.name) + self.assertEqual(3, column.num_buckets) + self.assertEqual({ + 'ids': parsing_ops.VarLenFeature(dtypes.int64), + 'values': parsing_ops.VarLenFeature(dtypes.float32) + }, column.parse_example_spec) + + def test_invalid_dtype_none(self): + with self.assertRaisesRegexp(ValueError, 'is not convertible to float'): + fc.weighted_categorical_column( + categorical_column=fc.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values', + dtype=None) + + def test_invalid_dtype_string(self): + with self.assertRaisesRegexp(ValueError, 'is not convertible to float'): + fc.weighted_categorical_column( + categorical_column=fc.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values', + dtype=dtypes.string) + + def test_invalid_input_dtype(self): + column = fc.weighted_categorical_column( + categorical_column=fc.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + strings = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('omar', 'stringer', 'marlo'), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'Bad dtype'): + _transform_features({'ids': strings, 'values': strings}, (column,), None) + + def test_column_name_collision(self): + with self.assertRaisesRegexp(ValueError, r'Parse config.*already exists'): + fc.weighted_categorical_column( + categorical_column=fc.categorical_column_with_identity( + key='aaa', num_buckets=3), + weight_feature_key='aaa').parse_example_spec() + + def test_missing_weights(self): + column = fc.weighted_categorical_column( + categorical_column=fc.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('omar', 'stringer', 'marlo'), + dense_shape=(2, 2)) + with self.assertRaisesRegexp( + ValueError, 'values is not in features dictionary'): + _transform_features({'ids': inputs}, (column,), None) + + def test_parse_example(self): + a = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) + a_weighted = fc.weighted_categorical_column(a, weight_feature_key='weights') + data = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'aaa': + feature_pb2.Feature(bytes_list=feature_pb2.BytesList( + value=[b'omar', b'stringer'])), + 'weights': + feature_pb2.Feature(float_list=feature_pb2.FloatList( + value=[1., 10.])) + })) + features = parsing_ops.parse_example( + serialized=[data.SerializeToString()], + features=fc.make_parse_example_spec([a_weighted])) + self.assertIn('aaa', features) + self.assertIn('weights', features) + with self.test_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [0, 1]], + values=np.array([b'omar', b'stringer'], dtype=np.object_), + dense_shape=[1, 2]), + features['aaa'].eval()) + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [0, 1]], + values=np.array([1., 10.], dtype=np.float32), + dense_shape=[1, 2]), + features['weights'].eval()) + + def test_transform_features(self): + column = fc.weighted_categorical_column( + categorical_column=fc.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 1, 0), + dense_shape=(2, 2)) + weights = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0.5, 1.0, 0.1), + dense_shape=(2, 2)) + id_tensor, weight_tensor = _transform_features({ + 'ids': inputs, + 'values': weights, + }, (column,), None)[column] + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array(inputs.values, dtype=np.int64), + dense_shape=inputs.dense_shape), + id_tensor.eval()) + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=weights.indices, + values=np.array(weights.values, dtype=np.float32), + dense_shape=weights.dense_shape), + weight_tensor.eval()) + + def test_transform_features_dense_input(self): + column = fc.weighted_categorical_column( + categorical_column=fc.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + weights = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0.5, 1.0, 0.1), + dense_shape=(2, 2)) + id_tensor, weight_tensor = _transform_features({ + 'ids': ((0, -1), (1, 0)), + 'values': weights, + }, (column,), None)[column] + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=np.array((0, 1, 0), dtype=np.int64), + dense_shape=(2, 2)), + id_tensor.eval()) + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=weights.indices, + values=np.array(weights.values, dtype=np.float32), + dense_shape=weights.dense_shape), + weight_tensor.eval()) + + def test_transform_features_dense_weights(self): + column = fc.weighted_categorical_column( + categorical_column=fc.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 1, 0), + dense_shape=(2, 2)) + id_tensor, weight_tensor = _transform_features({ + 'ids': inputs, + 'values': ((.5, 0.), (1., .1)), + }, (column,), None)[column] + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array(inputs.values, dtype=np.int64), + dense_shape=inputs.dense_shape), + id_tensor.eval()) + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=np.array((.5, 1., .1), dtype=np.float32), + dense_shape=(2, 2)), + weight_tensor.eval()) + + def test_keras_linear_model(self): + column = fc_old.weighted_categorical_column( + categorical_column=fc_old.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + 'ids': + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 2, 1), + dense_shape=(2, 2)), + 'values': + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(.5, 1., .1), + dense_shape=(2, 2)) + }, (column,)) + bias = get_linear_model_bias() + weight_var = get_linear_model_column_var(column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + weight_var.assign(((1.,), (2.,), (3.,))).eval() + # weight_var[0] * weights[0, 0] = 1 * .5 = .5 + # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1] + # = 3*1 + 2*.1 = 3+.2 = 3.2 + self.assertAllClose(((.5,), (3.2,)), predictions.eval()) + + def test_keras_linear_model_mismatched_shape(self): + column = fc_old.weighted_categorical_column( + categorical_column=fc_old.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + with ops.Graph().as_default(): + with self.assertRaisesRegexp(ValueError, + r'Dimensions.*are not compatible'): + get_keras_linear_model_predictions({ + 'ids': + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 2, 1), + dense_shape=(2, 2)), + 'values': + sparse_tensor.SparseTensorValue( + indices=((0, 0), (0, 1), (1, 0), (1, 1)), + values=(.5, 11., 1., .1), + dense_shape=(2, 2)) + }, (column,)) + + def test_keras_linear_model_mismatched_dense_values(self): + column = fc_old.weighted_categorical_column( + categorical_column=fc_old.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions( + { + 'ids': + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 2, 1), + dense_shape=(2, 2)), + 'values': ((.5,), (1.,)) + }, (column,), + sparse_combiner='mean') + # Disabling the constant folding optimizer here since it changes the + # error message differently on CPU and GPU. + config = config_pb2.ConfigProto() + config.graph_options.rewrite_options.constant_folding = ( + rewriter_config_pb2.RewriterConfig.OFF) + with _initialized_session(config): + with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'): + predictions.eval() + + def test_keras_linear_model_mismatched_dense_shape(self): + column = fc_old.weighted_categorical_column( + categorical_column=fc_old.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + with ops.Graph().as_default(): + predictions = get_keras_linear_model_predictions({ + 'ids': + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 2, 1), + dense_shape=(2, 2)), + 'values': ((.5,), (1.,), (.1,)) + }, (column,)) + bias = get_linear_model_bias() + weight_var = get_linear_model_column_var(column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + weight_var.assign(((1.,), (2.,), (3.,))).eval() + # weight_var[0] * weights[0, 0] = 1 * .5 = .5 + # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1] + # = 3*1 + 2*.1 = 3+.2 = 3.2 + self.assertAllClose(((.5,), (3.2,)), predictions.eval()) + + def test_linear_model(self): + column = fc_old.weighted_categorical_column( + categorical_column=fc_old.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + with ops.Graph().as_default(): + predictions = fc.linear_model({ + 'ids': sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 2, 1), + dense_shape=(2, 2)), + 'values': sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(.5, 1., .1), + dense_shape=(2, 2)) + }, (column,)) + bias = get_linear_model_bias() + weight_var = get_linear_model_column_var(column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + weight_var.assign(((1.,), (2.,), (3.,))).eval() + # weight_var[0] * weights[0, 0] = 1 * .5 = .5 + # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1] + # = 3*1 + 2*.1 = 3+.2 = 3.2 + self.assertAllClose(((.5,), (3.2,)), predictions.eval()) + + def test_linear_model_mismatched_shape(self): + column = fc_old.weighted_categorical_column( + categorical_column=fc_old.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + with ops.Graph().as_default(): + with self.assertRaisesRegexp( + ValueError, r'Dimensions.*are not compatible'): + fc.linear_model({ + 'ids': sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 2, 1), + dense_shape=(2, 2)), + 'values': sparse_tensor.SparseTensorValue( + indices=((0, 0), (0, 1), (1, 0), (1, 1)), + values=(.5, 11., 1., .1), + dense_shape=(2, 2)) + }, (column,)) + + def test_linear_model_mismatched_dense_values(self): + column = fc_old.weighted_categorical_column( + categorical_column=fc_old.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + with ops.Graph().as_default(): + predictions = fc.linear_model( + { + 'ids': + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 2, 1), + dense_shape=(2, 2)), + 'values': ((.5,), (1.,)) + }, (column,), + sparse_combiner='mean') + # Disabling the constant folding optimizer here since it changes the + # error message differently on CPU and GPU. + config = config_pb2.ConfigProto() + config.graph_options.rewrite_options.constant_folding = ( + rewriter_config_pb2.RewriterConfig.OFF) + with _initialized_session(config): + with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'): + predictions.eval() + + def test_linear_model_mismatched_dense_shape(self): + column = fc_old.weighted_categorical_column( + categorical_column=fc_old.categorical_column_with_identity( + key='ids', num_buckets=3), + weight_feature_key='values') + with ops.Graph().as_default(): + predictions = fc.linear_model({ + 'ids': sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 2, 1), + dense_shape=(2, 2)), + 'values': ((.5,), (1.,), (.1,)) + }, (column,)) + bias = get_linear_model_bias() + weight_var = get_linear_model_column_var(column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + weight_var.assign(((1.,), (2.,), (3.,))).eval() + # weight_var[0] * weights[0, 0] = 1 * .5 = .5 + # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1] + # = 3*1 + 2*.1 = 3+.2 = 3.2 + self.assertAllClose(((.5,), (3.2,)), predictions.eval()) + + # TODO(ptucker): Add test with embedding of weighted categorical. + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py index 3c5aebbce8af117aa1e216f1ef07ded181c997ea..40788e24c486c4357042672e3697063a4c7fb381 100644 --- a/tensorflow/python/framework/common_shapes.py +++ b/tensorflow/python/framework/common_shapes.py @@ -28,6 +28,18 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +def has_fully_defined_shape(tensor): + """Returns true if tensor has a fully defined shape.""" + return isinstance(tensor, ops.EagerTensor) or tensor.shape.is_fully_defined() + + +def rank(tensor): + """Return a rank if it is a tensor, else return None.""" + if isinstance(tensor, ops.Tensor): + return tensor._rank() # pylint: disable=protected-access + return None + + def scalar_shape(unused_op): """Shape function for ops that output a scalar value.""" return [tensor_shape.scalar()] diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py new file mode 100644 index 0000000000000000000000000000000000000000..6e844e14b959dbf8b6fd8086a89d2e7aa51a111d --- /dev/null +++ b/tensorflow/python/framework/error_interpolation.py @@ -0,0 +1,313 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Function for interpolating formatted errors from the TensorFlow runtime. + +Exposes the function `interpolate` to interpolate messages with tags of the form +^^type:name:format^^. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import itertools +import os +import re +import string + +import six + +from tensorflow.python.util import tf_stack + + +_NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?" +_FORMAT_REGEX = r"[A-Za-z0-9_.\-/${}:]+" +_TAG_REGEX = r"\^\^({name}):({name}):({fmt})\^\^".format( + name=_NAME_REGEX, fmt=_FORMAT_REGEX) +_INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX) +_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX) + +_ParseTag = collections.namedtuple("_ParseTag", ["type", "name", "format"]) + +_BAD_FILE_SUBSTRINGS = [ + os.path.join("tensorflow", "python"), + " + file: Replaced with the filename in which the node was defined. + line: Replaced by the line number at which the node was defined. + colocations: Replaced by a multi-line message describing the file and + line numbers at which this node was colocated with other nodes. + + Args: + message: String to parse + + Returns: + (list of separator strings, list of _ParseTags). + + For example, if message is "123^^node:Foo:${file}^^456" then this function + returns (["123", "456"], [_ParseTag("node", "Foo", "${file}")]) + """ + seps = [] + tags = [] + pos = 0 + while pos < len(message): + match = re.match(_INTERPOLATION_PATTERN, message[pos:]) + if match: + seps.append(match.group(1)) + tags.append(_ParseTag(match.group(3), match.group(4), match.group(5))) + pos += match.end() + else: + break + seps.append(message[pos:]) + return seps, tags + + +def _compute_device_summary_from_list(name, device_assignment_list, prefix=""): + """Return a summary of an op's device function stack. + + Args: + name: The name of the op. + device_assignment_list: The op._device_assignments list. + prefix: An optional string prefix used before each line of the multi- + line string returned by this function. + + Returns: + A multi-line string similar to: + Device assignments active during op 'foo' creation: + with tf.device(/cpu:0): + with tf.device(some_func): + The first line will have no padding to its left by default. Subsequent + lines will have two spaces of left-padding. Use the prefix argument + to increase indentation. + """ + if not device_assignment_list: + message = "No device assignments were active during op '%s' creation." + message %= name + return prefix + message + + str_list = [] + str_list.append("%sDevice assignments active during op '%s' creation:" + % (prefix, name)) + + for traceable_obj in device_assignment_list: + location_summary = "<{file}:{line}>".format(file=traceable_obj.filename, + line=traceable_obj.lineno) + subs = { + "prefix": prefix, + "indent": " ", + "dev_name": traceable_obj.obj, + "loc": location_summary, + } + str_list.append( + "{prefix}{indent}with tf.device({dev_name}): {loc}".format(**subs)) + + return "\n".join(str_list) + + +def _compute_device_assignment_summary_from_op(op, prefix=""): + # pylint: disable=protected-access + return _compute_device_summary_from_list(op.name, op._device_assignments, + prefix) + # pylint: enable=protected-access + + +def _compute_colocation_summary_from_dict(name, colocation_dict, prefix=""): + """Return a summary of an op's colocation stack. + + Args: + name: The op name. + colocation_dict: The op._colocation_dict. + prefix: An optional string prefix used before each line of the multi- + line string returned by this function. + + Returns: + A multi-line string similar to: + Node-device colocations active during op creation: + with tf.colocate_with(test_node_1): + with tf.colocate_with(test_node_2): + The first line will have no padding to its left by default. Subsequent + lines will have two spaces of left-padding. Use the prefix argument + to increase indentation. + """ + if not colocation_dict: + message = "No node-device colocations were active during op '%s' creation." + message %= name + return prefix + message + + str_list = [] + str_list.append("%sNode-device colocations active during op '%s' creation:" + % (prefix, name)) + + for coloc_name, location in colocation_dict.items(): + location_summary = "<{file}:{line}>".format(file=location.filename, + line=location.lineno) + subs = { + "prefix": prefix, + "indent": " ", + "name": coloc_name, + "loc": location_summary, + } + str_list.append( + "{prefix}{indent}with tf.colocate_with({name}): {loc}".format(**subs)) + + return "\n".join(str_list) + + +def _compute_colocation_summary_from_op(op, prefix=""): + """Fetch colocation file, line, and nesting and return a summary string.""" + return _compute_colocation_summary_from_dict( + op.name, op._colocation_dict, prefix) # pylint: disable=protected-access + + +def _find_index_of_defining_frame_for_op(op): + """Return index in op._traceback with first 'useful' frame. + + This method reads through the stack stored in op._traceback looking for the + innermost frame which (hopefully) belongs to the caller. It accomplishes this + by rejecting frames whose filename appears to come from TensorFlow (see + error_interpolation._BAD_FILE_SUBSTRINGS for the list of rejected substrings). + + Args: + op: the Operation object for which we would like to find the defining + location. + + Returns: + Integer index into op._traceback where the first non-TF file was found + (innermost to outermost), or 0 (for the outermost stack frame) if all files + came from TensorFlow. + """ + # pylint: disable=protected-access + # Index 0 of tf_traceback is the outermost frame. + tf_traceback = tf_stack.convert_stack(op._traceback) + size = len(tf_traceback) + # pylint: enable=protected-access + filenames = [frame[tf_stack.TB_FILENAME] for frame in tf_traceback] + # We process the filenames from the innermost frame to outermost. + for idx, filename in enumerate(reversed(filenames)): + contains_bad_substrings = [ss in filename for ss in _BAD_FILE_SUBSTRINGS] + if not any(contains_bad_substrings): + return size - idx - 1 + return 0 + + +def _get_defining_frame_from_op(op): + """Find and return stack frame where op was defined.""" + frame_index = _find_index_of_defining_frame_for_op(op) + # pylint: disable=protected-access + frame = op._traceback[frame_index] + # pylint: enable=protected-access + return frame + + +def compute_field_dict(op): + """Return a dictionary mapping interpolation tokens to values. + + Args: + op: op.Operation object having a _traceback member. + + Returns: + A dictionary mapping string tokens to string values. The keys are shown + below along with example values. + { + "file": "tool_utils.py", + "line": "124", + "defined_at": " (defined at tool_utils.py:124)", + "colocations": + '''Node-device colocations active during op creation: + with tf.colocate_with(test_node_1): + with tf.colocate_with(test_node_2): ''' + "devices": + '''Device assignments active during op 'foo' creation: + with tf.device(/cpu:0): + with tf.device(some_func): ''' + "devs_and_colocs": A concatenation of colocations and devices, e.g. + '''Node-device colocations active during op creation: + with tf.colocate_with(test_node_1): + with tf.colocate_with(test_node_2): ''' + Device assignments active during op 'foo' creation: + with tf.device(/cpu:0): + with tf.device(some_func): ''' + } + """ + frame = _get_defining_frame_from_op(op) + filename = frame[tf_stack.TB_FILENAME] + lineno = frame[tf_stack.TB_LINENO] + defined_at = " (defined at %s:%d)" % (filename, lineno) + colocation_summary = _compute_colocation_summary_from_op(op) + device_summary = _compute_device_assignment_summary_from_op(op) + combined_summary = "\n".join([colocation_summary, device_summary]) + + field_dict = { + "file": filename, + "line": lineno, + "defined_at": defined_at, + "colocations": colocation_summary, + "devices": device_summary, + "devs_and_colocs": combined_summary, + } + return field_dict + + +def interpolate(error_message, graph): + """Interpolates an error message. + + The error message can contain tags of the form ^^type:name:format^^ which will + be replaced. + + Args: + error_message: A string to interpolate. + graph: ops.Graph object containing all nodes referenced in the error + message. + + Returns: + The string with tags of the form ^^type:name:format^^ interpolated. + """ + seps, tags = _parse_message(error_message) + + node_name_to_substitution_dict = {} + for name in [t.name for t in tags]: + if name in node_name_to_substitution_dict: + continue + try: + op = graph.get_operation_by_name(name) + except KeyError: + op = None + + if op is not None: + field_dict = compute_field_dict(op) + else: + msg = "" + field_dict = collections.defaultdict(lambda s=msg: s) + node_name_to_substitution_dict[name] = field_dict + + subs = [ + string.Template(tag.format).safe_substitute( + node_name_to_substitution_dict[tag.name]) for tag in tags + ] + return "".join( + itertools.chain(*six.moves.zip_longest(seps, subs, fillvalue=""))) diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0427156b2bdeff5b8cdfc0fd2b1dc51350e597b5 --- /dev/null +++ b/tensorflow/python/framework/error_interpolation_test.py @@ -0,0 +1,298 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.python.framework.errors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import error_interpolation +from tensorflow.python.framework import ops +from tensorflow.python.framework import traceable_stack +from tensorflow.python.platform import test +from tensorflow.python.util import tf_stack + + +def _make_frame_with_filename(op, idx, filename): + """Return a copy of an existing stack frame with a new filename.""" + stack_frame = list(op._traceback[idx]) + stack_frame[tf_stack.TB_FILENAME] = filename + return tuple(stack_frame) + + +def _modify_op_stack_with_filenames(op, num_user_frames, user_filename, + num_inner_tf_frames): + """Replace op._traceback with a new traceback using special filenames.""" + tf_filename = "%d" + error_interpolation._BAD_FILE_SUBSTRINGS[0] + user_filename = os.path.join("%d", "my_favorite_file.py") + + num_requested_frames = num_user_frames + num_inner_tf_frames + num_actual_frames = len(op._traceback) + num_outer_frames = num_actual_frames - num_requested_frames + assert num_requested_frames <= num_actual_frames, "Too few real frames." + + # The op's traceback has outermost frame at index 0. + stack = [] + for idx in range(0, num_outer_frames): + stack.append(op._traceback[idx]) + for idx in range(len(stack), len(stack)+num_user_frames): + stack.append(_make_frame_with_filename(op, idx, user_filename % idx)) + for idx in range(len(stack), len(stack)+num_inner_tf_frames): + stack.append(_make_frame_with_filename(op, idx, tf_filename % idx)) + op._traceback = stack + + +class ComputeDeviceSummaryFromOpTest(test.TestCase): + + def testCorrectFormatWithActiveDeviceAssignments(self): + assignments = [] + assignments.append( + traceable_stack.TraceableObject("/cpu:0", + filename="hope.py", + lineno=24)) + assignments.append( + traceable_stack.TraceableObject("/gpu:2", + filename="please.py", + lineno=42)) + + summary = error_interpolation._compute_device_summary_from_list( + "nodename", assignments, prefix=" ") + + self.assertIn("nodename", summary) + self.assertIn("tf.device(/cpu:0)", summary) + self.assertIn("", summary) + self.assertIn("tf.device(/gpu:2)", summary) + self.assertIn("", summary) + + def testCorrectFormatWhenNoColocationsWereActive(self): + device_assignment_list = [] + summary = error_interpolation._compute_device_summary_from_list( + "nodename", device_assignment_list, prefix=" ") + self.assertIn("nodename", summary) + self.assertIn("No device assignments", summary) + + +class ComputeColocationSummaryFromOpTest(test.TestCase): + + def testCorrectFormatWithActiveColocations(self): + t_obj_1 = traceable_stack.TraceableObject(None, + filename="test_1.py", + lineno=27) + t_obj_2 = traceable_stack.TraceableObject(None, + filename="test_2.py", + lineno=38) + colocation_dict = { + "test_node_1": t_obj_1, + "test_node_2": t_obj_2, + } + summary = error_interpolation._compute_colocation_summary_from_dict( + "node_name", colocation_dict, prefix=" ") + self.assertIn("node_name", summary) + self.assertIn("colocate_with(test_node_1)", summary) + self.assertIn("", summary) + self.assertIn("colocate_with(test_node_2)", summary) + self.assertIn("", summary) + + def testCorrectFormatWhenNoColocationsWereActive(self): + colocation_dict = {} + summary = error_interpolation._compute_colocation_summary_from_dict( + "node_name", colocation_dict, prefix=" ") + self.assertIn("node_name", summary) + self.assertIn("No node-device colocations", summary) + + +class InterpolateFilenamesAndLineNumbersTest(test.TestCase): + + def setUp(self): + ops.reset_default_graph() + # Add nodes to the graph for retrieval by name later. + constant_op.constant(1, name="One") + constant_op.constant(2, name="Two") + three = constant_op.constant(3, name="Three") + self.graph = three.graph + + # Change the list of bad file substrings so that constant_op.py is chosen + # as the defining stack frame for constant_op.constant ops. + self.old_bad_strings = error_interpolation._BAD_FILE_SUBSTRINGS + error_interpolation._BAD_FILE_SUBSTRINGS = [ + "%sops.py" % os.sep, + "%sutil" % os.sep, + ] + + def tearDown(self): + error_interpolation._BAD_FILE_SUBSTRINGS = self.old_bad_strings + + def testFindIndexOfDefiningFrameForOp(self): + local_op = constant_op.constant(42).op + user_filename = "hope.py" + _modify_op_stack_with_filenames(local_op, + num_user_frames=3, + user_filename=user_filename, + num_inner_tf_frames=5) + idx = error_interpolation._find_index_of_defining_frame_for_op(local_op) + # Expected frame is 6th from the end because there are 5 inner frames witih + # TF filenames. + expected_frame = len(local_op._traceback) - 6 + self.assertEqual(expected_frame, idx) + + def testFindIndexOfDefiningFrameForOpReturnsZeroOnError(self): + local_op = constant_op.constant(43).op + # Truncate stack to known length. + local_op._traceback = local_op._traceback[:7] + # Ensure all frames look like TF frames. + _modify_op_stack_with_filenames(local_op, + num_user_frames=0, + user_filename="user_file.py", + num_inner_tf_frames=7) + idx = error_interpolation._find_index_of_defining_frame_for_op(local_op) + self.assertEqual(0, idx) + + def testNothingToDo(self): + normal_string = "This is just a normal string" + interpolated_string = error_interpolation.interpolate(normal_string, + self.graph) + self.assertEqual(interpolated_string, normal_string) + + def testOneTag(self): + one_tag_string = "^^node:Two:${file}^^" + interpolated_string = error_interpolation.interpolate(one_tag_string, + self.graph) + self.assertTrue(interpolated_string.endswith("constant_op.py"), + "interpolated_string '%s' did not end with constant_op.py" + % interpolated_string) + + def testOneTagWithAFakeNameResultsInPlaceholders(self): + one_tag_string = "^^node:MinusOne:${file}^^" + interpolated_string = error_interpolation.interpolate(one_tag_string, + self.graph) + self.assertEqual("", interpolated_string) + + def testTwoTagsNoSeps(self): + two_tags_no_seps = "^^node:One:${file}^^^^node:Three:${line}^^" + interpolated_string = error_interpolation.interpolate(two_tags_no_seps, + self.graph) + self.assertRegexpMatches(interpolated_string, "constant_op.py[0-9]+") + + def testTwoTagsWithSeps(self): + two_tags_with_seps = ";;;^^node:Two:${file}^^,,,^^node:Three:${line}^^;;;" + interpolated_string = error_interpolation.interpolate(two_tags_with_seps, + self.graph) + expected_regex = "^;;;.*constant_op.py,,,[0-9]*;;;$" + self.assertRegexpMatches(interpolated_string, expected_regex) + + +class InterpolateDeviceSummaryTest(test.TestCase): + + def _fancy_device_function(self, unused_op): + return "/cpu:*" + + def setUp(self): + ops.reset_default_graph() + self.zero = constant_op.constant([0.0], name="zero") + with ops.device("/cpu"): + self.one = constant_op.constant([1.0], name="one") + with ops.device("/cpu:0"): + self.two = constant_op.constant([2.0], name="two") + with ops.device(self._fancy_device_function): + self.three = constant_op.constant(3.0, name="three") + + self.graph = self.three.graph + + def testNodeZeroHasNoDeviceSummaryInfo(self): + message = "^^node:zero:${devices}^^" + result = error_interpolation.interpolate(message, self.graph) + self.assertIn("No device assignments were active", result) + + def testNodeOneHasExactlyOneInterpolatedDevice(self): + message = "^^node:one:${devices}^^" + result = error_interpolation.interpolate(message, self.graph) + num_devices = result.count("tf.device") + self.assertEqual(1, num_devices) + self.assertIn("tf.device(/cpu)", result) + + def testNodeTwoHasTwoInterpolatedDevice(self): + message = "^^node:two:${devices}^^" + result = error_interpolation.interpolate(message, self.graph) + num_devices = result.count("tf.device") + self.assertEqual(2, num_devices) + self.assertIn("tf.device(/cpu)", result) + self.assertIn("tf.device(/cpu:0)", result) + + def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self): + message = "^^node:three:${devices}^^" + result = error_interpolation.interpolate(message, self.graph) + num_devices = result.count("tf.device") + self.assertEqual(1, num_devices) + name_re = r"_fancy_device_function<.*error_interpolation_test.py, [0-9]+>" + expected_re = r"with tf.device\(.*%s\)" % name_re + self.assertRegexpMatches(result, expected_re) + + +class InterpolateColocationSummaryTest(test.TestCase): + + def setUp(self): + ops.reset_default_graph() + # Add nodes to the graph for retrieval by name later. + node_one = constant_op.constant(1, name="One") + node_two = constant_op.constant(2, name="Two") + + # node_three has one colocation group, obviously. + with ops.colocate_with(node_one): + node_three = constant_op.constant(3, name="Three_with_one") + + # node_four has one colocation group even though three is (transitively) + # colocated with one. + with ops.colocate_with(node_three): + constant_op.constant(4, name="Four_with_three") + + # node_five has two colocation groups because one and two are not colocated. + with ops.colocate_with(node_two): + with ops.colocate_with(node_one): + constant_op.constant(5, name="Five_with_one_with_two") + + self.graph = node_three.graph + + def testNodeThreeHasColocationInterpolation(self): + message = "^^node:Three_with_one:${colocations}^^" + result = error_interpolation.interpolate(message, self.graph) + self.assertIn("colocate_with(One)", result) + + def testNodeFourHasColocationInterpolationForNodeThreeOnly(self): + message = "^^node:Four_with_three:${colocations}^^" + result = error_interpolation.interpolate(message, self.graph) + self.assertIn("colocate_with(Three_with_one)", result) + self.assertNotIn( + "One", result, + "Node One should not appear in Four_with_three's summary:\n%s" + % result) + + def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self): + message = "^^node:Five_with_one_with_two:${colocations}^^" + result = error_interpolation.interpolate(message, self.graph) + self.assertIn("colocate_with(One)", result) + self.assertIn("colocate_with(Two)", result) + + def testColocationInterpolationForNodeLackingColocation(self): + message = "^^node:One:${colocations}^^" + result = error_interpolation.interpolate(message, self.graph) + self.assertIn("No node-device colocations", result) + self.assertNotIn("Two", result) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/framework/errors_impl.py b/tensorflow/python/framework/errors_impl.py index 84106c32c673e15832ff747a7fededdfbfb94ed8..9f973de4004cf69921b551c90a7c4068edaa2029 100644 --- a/tensorflow/python/framework/errors_impl.py +++ b/tensorflow/python/framework/errors_impl.py @@ -63,9 +63,9 @@ class OpError(Exception): *N.B.* If the failed op was synthesized at runtime, e.g. a `Send` or `Recv` op, there will be no corresponding - @{tf.Operation} + `tf.Operation` object. In that case, this will return `None`, and you should - instead use the @{tf.OpError.node_def} to + instead use the `tf.OpError.node_def` to discover information about the op. Returns: @@ -181,10 +181,10 @@ class CancelledError(OpError): """Raised when an operation or step is cancelled. For example, a long-running operation (e.g. - @{tf.QueueBase.enqueue} may be + `tf.QueueBase.enqueue` may be cancelled by running another operation (e.g. - @{tf.QueueBase.close}, - or by @{tf.Session.close}. + `tf.QueueBase.close`, + or by `tf.Session.close`. A step that is running such a long-running operation will fail by raising `CancelledError`. @@ -221,9 +221,9 @@ class InvalidArgumentError(OpError): This may occur, for example, if an operation is receives an input tensor that has an invalid value or shape. For example, the - @{tf.matmul} op will raise this + `tf.matmul` op will raise this error if it receives an input that is not a matrix, and the - @{tf.reshape} op will raise + `tf.reshape` op will raise this error if the new shape does not match the number of elements in the input tensor. @@ -256,7 +256,7 @@ class NotFoundError(OpError): """Raised when a requested entity (e.g., a file or directory) was not found. For example, running the - @{tf.WholeFileReader.read} + `tf.WholeFileReader.read` operation could raise `NotFoundError` if it receives the name of a file that does not exist. @@ -273,7 +273,7 @@ class AlreadyExistsError(OpError): """Raised when an entity that we attempted to create already exists. For example, running an operation that saves a file - (e.g. @{tf.train.Saver.save}) + (e.g. `tf.train.Saver.save`) could potentially raise this exception if an explicit filename for an existing file was passed. @@ -291,7 +291,7 @@ class PermissionDeniedError(OpError): """Raised when the caller does not have permission to run an operation. For example, running the - @{tf.WholeFileReader.read} + `tf.WholeFileReader.read` operation could raise `PermissionDeniedError` if it receives the name of a file for which the user does not have the read file permission. @@ -340,7 +340,7 @@ class FailedPreconditionError(OpError): """Operation was rejected because the system is not in a state to execute it. This exception is most commonly raised when running an operation - that reads a @{tf.Variable} + that reads a `tf.Variable` before it has been initialized. @@__init__ @@ -357,9 +357,9 @@ class AbortedError(OpError): """The operation was aborted, typically due to a concurrent action. For example, running a - @{tf.QueueBase.enqueue} + `tf.QueueBase.enqueue` operation may raise `AbortedError` if a - @{tf.QueueBase.close} operation + `tf.QueueBase.close` operation previously ran. @@__init__ @@ -375,9 +375,9 @@ class OutOfRangeError(OpError): """Raised when an operation iterates past the valid input range. This exception is raised in "end-of-file" conditions, such as when a - @{tf.QueueBase.dequeue} + `tf.QueueBase.dequeue` operation is blocked on an empty queue, and a - @{tf.QueueBase.close} + `tf.QueueBase.close` operation executes. @@__init__ @@ -395,7 +395,7 @@ class UnimplementedError(OpError): Some operations may raise this error when passed otherwise-valid arguments that it does not currently support. For example, running - the @{tf.nn.max_pool} operation + the `tf.nn.max_pool` operation would raise this error if pooling was requested on the batch dimension, because this is not yet supported. @@ -443,7 +443,7 @@ class DataLossError(OpError): """Raised when unrecoverable data loss or corruption is encountered. For example, this may be raised by running a - @{tf.WholeFileReader.read} + `tf.WholeFileReader.read` operation, if the file is truncated while it is being read. @@__init__ @@ -475,8 +475,8 @@ _CODE_TO_EXCEPTION_CLASS = { c_api.PyExceptionRegistry_Init(_CODE_TO_EXCEPTION_CLASS) -_EXCEPTION_CLASS_TO_CODE = dict(( - (class_, code) for (code, class_) in _CODE_TO_EXCEPTION_CLASS.items())) +_EXCEPTION_CLASS_TO_CODE = { + class_: code for code, class_ in _CODE_TO_EXCEPTION_CLASS.items()} @tf_export("errors.exception_type_from_error_code") diff --git a/tensorflow/python/framework/fast_tensor_util.pyx b/tensorflow/python/framework/fast_tensor_util.pyx index 17d112a1ece9ae3d121b894d9b246d24b46d84e2..2e3e15f53a919bac669b56e4a8f27c1808da345a 100644 --- a/tensorflow/python/framework/fast_tensor_util.pyx +++ b/tensorflow/python/framework/fast_tensor_util.pyx @@ -6,6 +6,13 @@ cimport numpy as np from tensorflow.python.util import compat +def AppendBFloat16ArrayToTensorProto( + tensor_proto, np.ndarray[np.uint16_t, ndim=1] nparray): + cdef long i, n + n = nparray.size + for i in range(n): + tensor_proto.half_val.append(nparray[i]) + def AppendFloat16ArrayToTensorProto( # For numpy, npy_half is a typedef for npy_uint16, diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 82ecba310ba4cbad62b8fda073a006b28be2c3ad..f47c0d8a5e3867e9a1bdf2fb8bd00a635fd8d622 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -23,6 +23,7 @@ from __future__ import print_function import collections import hashlib +import sys from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import function_pb2 @@ -33,12 +34,17 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_to_function_def from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import cond_v2_impl from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.util import compat -from tensorflow.python.util import tf_decorator +from tensorflow.python.util import function_utils +from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_inspect +# This is to avoid a circular dependency with cond_v2_impl. +cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-access + class Defun(object): """Decorator used to define TensorFlow functions. @@ -249,9 +255,12 @@ class _DefinedFunction(object): # Constructed only when C API is enabled, lazily self._c_func = None self._sub_functions = dict() # Constructed with _definition or _c_func - device_stack = ops.get_default_graph()._device_function_stack # pylint: disable=protected-access + # pylint: disable=protected-access + device_funcs = ops.get_default_graph()._device_functions_outer_to_inner + # pylint: enable=protected-access + # Get the innermost device if possbile. - self._caller_device = device_stack[-1] if device_stack else None + self._caller_device = device_funcs[-1] if device_funcs else None # Cached OpDef for this function. When C API is enabled, this is # the only part of FunctionDef that we cache in Python. When C API @@ -348,7 +357,7 @@ class _DefinedFunction(object): if self._func_name: base_func_name = self._func_name else: - base_func_name = _get_func_name(self._func) + base_func_name = function_utils.get_func_name(self._func) if self._grad_func: base_func_name += ("_%s" % self._grad_func.name) kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs) @@ -650,6 +659,41 @@ class _FuncGraph(ops.Graph): # TODO(skyewm): is this needed? self.extra_vars = [] + # pylint: disable=g-doc-return-or-yield + + @tf_contextlib.contextmanager + def container(self, container_name): + """Returns a context manager that specifies the resource container to use. + + Overridden from `tf.Graph` to update both the init_scope container + and the present inner container. This is necessary to make sure setting + containers applies correctly both to created variables and to stateful + ops. + + Args: + container_name: container name string. + + Returns: + A context manager for defining resource containers for stateful ops, + yields the container name. + """ + original_container = self._container + # pylint: disable=protected-access + with ops.init_scope(): + original_init_container = ops.get_default_graph()._container + try: + self._container = container_name + with ops.init_scope(): + ops.get_default_graph()._container = container_name + yield self._container + finally: + self._container = original_container + with ops.init_scope(): + ops.get_default_graph()._container = original_init_container + # pylint: enable=protected-access + + # pylint: enable=g-doc-return-or-yield + def getvar( self, getter, @@ -773,7 +817,9 @@ class _FuncGraph(ops.Graph): def func_graph_from_py_func(func, arg_names, arg_types, name=None, - capture_by_value=False, device=None): + capture_by_value=False, device=None, + colocation_stack=None, container=None, + collections_ref=None, arg_shapes=None): """Returns a _FuncGraph generated from `func`. Args: @@ -786,6 +832,11 @@ def func_graph_from_py_func(func, arg_names, arg_types, name=None, capture_by_value: boolean. If True, captured values will be copied into the function body. device: device name or function. + colocation_stack: A colocation stack (list) the _FuncGraph should use. + container: A container name the _FuncGraph should start with. + collections_ref: A reference to a collections dict the _FuncGraph should + use internally. + arg_shapes: A sequence of the function's argument shapes. Returns: A _FuncGraph. @@ -794,12 +845,25 @@ def func_graph_from_py_func(func, arg_names, arg_types, name=None, ValueError: if func returns None. """ if not name: - name = _get_func_name(func) + name = function_utils.get_func_name(func) func_graph = _FuncGraph(name, capture_by_value) + with func_graph.as_default(), ops.device(device): + # pylint: disable=protected-access + if collections_ref is not None: + func_graph._collections = collections_ref + if container is not None: + func_graph._container = container + if colocation_stack is not None: + func_graph._colocation_stack = colocation_stack + # pylint: enable=protected-access + + if arg_shapes is None: + arg_shapes = [None] * len(arg_types) + # Create placeholders for the function arguments. - for (argname, argtype) in zip(arg_names, arg_types): - argholder = array_ops.placeholder(argtype, name=argname) + for (argname, argtype, argshape) in zip(arg_names, arg_types, arg_shapes): + argholder = array_ops.placeholder(argtype, shape=argshape, name=argname) func_graph.inputs.append(argholder) # Call func and gather the output tensors. with vs.variable_scope("", custom_getter=func_graph.getvar): @@ -1082,19 +1146,6 @@ def _parse_kwargs_as_attrs(func_name, **kwargs): return attrs -def _get_func_name(func): - _, func = tf_decorator.unwrap(func) - if callable(func): - if tf_inspect.isfunction(func): - return func.__name__ - elif tf_inspect.ismethod(func): - return "%s.%s" % (func.__self__.__name__, func.__name__) - else: # Probably a class instance with __call__ - return type(func) - else: - raise ValueError("Argument must be callable") - - def get_extra_vars(): """Returns the captured variables by the function. diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py index 4fecc41343f26291dac8455f6c972a755b65ecfc..1b09506662d26a4d0be8e7d77f7f0fea61d6835b 100644 --- a/tensorflow/python/framework/function_def_to_graph.py +++ b/tensorflow/python/framework/function_def_to_graph.py @@ -18,13 +18,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys + from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.framework import versions_pb2 from tensorflow.python.framework import function from tensorflow.python.framework import importer -from tensorflow.python.framework import op_def_registry +from tensorflow.python.framework import ops from tensorflow.python.framework import versions +from tensorflow.python.ops import cond_v2_impl + +# This is to avoid a circular dependency with cond_v2_impl. +cond_v2_impl._function_def_to_graph = sys.modules[__name__] # pylint: disable=protected-access def function_def_to_graph(fdef, input_shapes=None): @@ -108,6 +114,10 @@ def function_def_to_graph_def(fdef, input_shapes=None): producer=versions.GRAPH_DEF_VERSION, min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)) + # Copy *all* functions from outer graph to `graph_def` so that both direct + # and indirect references are safely handled. + ops.get_default_graph()._copy_functions_to_graph_def(graph_def, 0) # pylint: disable=protected-access + if input_shapes and len(input_shapes) != len(fdef.signature.input_arg): raise ValueError("Length of input_shapes must match the number of " + "input_args. len(input_shapes): {} len(input_arg): {}". @@ -136,24 +146,18 @@ def function_def_to_graph_def(fdef, input_shapes=None): nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name) for node_def in fdef.node_def: - op_def = op_def_registry.get_registered_ops().get(node_def.op) - if not op_def: - # TODO(b/80470245): Support functions which refer other functions. - raise NotImplementedError( - "No op registered for {},".format(node_def.op) + - " it may be a function. function_def_to_graph_def " + - "currently does not support converting functions with " + - "references to other graph functions.") + op_def = ops.get_default_graph()._get_op_def(node_def.op) # pylint: disable=protected-access for attr in op_def.attr: - if attr.type in ("func", "list(func)"): - # TODO(b/80470245): Support functions which refer other functions. - raise NotImplementedError("Unsupported attr {} ".format(attr.name) + - " with type {}".format(attr.type) + - " in op {}. ".format(op_def.name) + - "function_def_to_graph_def currently does " + - "not support converting functions with " + - "references to other graph functions.") + if attr.type == "func": + fname = node_def.attr[attr.name].func.name + if not ops.get_default_graph()._is_function(fname): # pylint: disable=protected-access + raise ValueError("%s function not found." % fname) + elif attr.type == "list(func)": + for fn in node_def.attr[attr.name].list.func: + fname = fn.name + if not ops.get_default_graph()._is_function(fname): # pylint: disable=protected-access + raise ValueError("%s function not found." % fname) # Iterate over output_args in op_def to build the map. # Index of the output tensor in the flattened list of *all* output diff --git a/tensorflow/python/framework/function_def_to_graph_test.py b/tensorflow/python/framework/function_def_to_graph_test.py index 0f4e6ef54fb02cc6ba52c9de2ccabea982fd2323..cd2a16ed5a83f8ffcd8288967815b7ad2761c52f 100644 --- a/tensorflow/python/framework/function_def_to_graph_test.py +++ b/tensorflow/python/framework/function_def_to_graph_test.py @@ -18,7 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function from tensorflow.python.framework import function_def_to_graph from tensorflow.python.framework import graph_to_function_def from tensorflow.python.framework import ops @@ -79,7 +81,6 @@ class FunctionDefToGraphTest(test.TestCase): g = function_def_to_graph.function_def_to_graph( fdef, input_shapes=[None, tensor_shape.matrix(5, 7)]) - print(g.as_graph_def()) self.assertIsNone(g.inputs[0].shape.dims) self.assertSequenceEqual(g.inputs[1].shape.dims, [5, 7]) self.assertSequenceEqual(g.outputs[0].shape.dims, [5, 7]) @@ -179,6 +180,37 @@ class FunctionDefToGraphDefTest(test.TestCase): self.assertEqual(g.node[0].attr["shape"].shape.unknown_rank, False) self.assertFalse("shape" in g.node[2].attr) + def testFunctionCallsFromFunction(self): + x = constant_op.constant(5.0) + y = constant_op.constant(10.0) + + @function.Defun() + def fn(): + + @function.Defun() + def inner_fn(): + return x + y + + return inner_fn() + + # Instantiate the function in this graph so that + # `function_def_to_graph` can find it. + fn() + + def fn2(): + return 2 * fn() + + fdef = function._DefinedFunction(fn2, [], []).definition + func_graph = function_def_to_graph.function_def_to_graph(fdef) + with func_graph.as_default(): + x_ph, y_ph = func_graph.inputs + with self.test_session(graph=func_graph) as sess: + self.assertEqual( + sess.run(func_graph.outputs[0], feed_dict={ + x_ph: 5.0, + y_ph: 10.0 + }), 30.0) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 15e41ba91f9ae121d3d4ea48e3e71eace7cd9a3e..1707f929b89203e1890ee96fd153ace2063b449c 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -537,19 +537,25 @@ class FunctionTest(test.TestCase): def testResourceVarAsImplicitInput(self): g = ops.Graph() with g.as_default(), ops.device("cpu:0"): + expected_type = dtypes.float32 + expected_shape = tensor_shape.TensorShape((4, 4)) v = variable_scope.get_variable( - "var", (4, 4), dtypes.float32, use_resource=True) + "var", expected_shape, expected_type, use_resource=True) @function.Defun() def Foo(): - return array_ops.identity(v) + captured = array_ops.identity(v) + self.assertEqual(expected_type, captured.dtype) + self.assertEqual(expected_shape, captured.shape) + return captured, array_ops.shape(captured) - y = v.value() - z = Foo() + expected_val = v.value() + actual_val, actual_shape = Foo() with self.test_session(graph=g): v.initializer.run() - self.assertAllEqual(y.eval(), z.eval()) + self.assertAllEqual(expected_val.eval(), actual_val.eval()) + self.assertAllEqual(expected_shape, actual_shape.eval()) def testDefineErrors(self): with ops.Graph().as_default(): diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index 72eb7e0eeb73fb1f8725ab2cbd4182e543c79b9f..e48e67c8a13aea7bb070f4b216cdc8081c711da4 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -205,7 +205,7 @@ def _PopulateTFImportGraphDefOptions(options, prefix, input_map, for input_src, input_dst in input_map.items(): input_src = compat.as_str(input_src) if input_src.startswith('^'): - src_name = compat.as_bytes(input_src[1:]) + src_name = compat.as_str(input_src[1:]) dst_op = input_dst._as_tf_output().oper # pylint: disable=protected-access c_api.TF_ImportGraphDefOptionsRemapControlDependency( options, src_name, dst_op) @@ -344,9 +344,9 @@ def import_graph_def(graph_def, This function provides a way to import a serialized TensorFlow [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto) protocol buffer, and extract individual objects in the `GraphDef` as - @{tf.Tensor} and @{tf.Operation} objects. Once extracted, + `tf.Tensor` and `tf.Operation` objects. Once extracted, these objects are placed into the current default `Graph`. See - @{tf.Graph.as_graph_def} for a way to create a `GraphDef` + `tf.Graph.as_graph_def` for a way to create a `GraphDef` proto. Args: @@ -407,11 +407,11 @@ def import_graph_def(graph_def, _PopulateTFImportGraphDefOptions(options, prefix, input_map, return_elements) - # _ProcessNewOps mutates the new operations. _lock ensures a Session.run - # call cannot occur between creating the TF_Operations in the + # _ProcessNewOps mutates the new operations. _mutation_lock ensures a + # Session.run call cannot occur between creating the TF_Operations in the # TF_GraphImportGraphDefWithResults call and mutating the them in # _ProcessNewOps. - with graph._lock: # pylint: disable=protected-access + with graph._mutation_lock(): # pylint: disable=protected-access with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized: try: results = c_api.TF_GraphImportGraphDefWithResults( diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index c5a54470d27b5949fd642b057feda7f3f1a4347f..7182c28666e8ab81962ed12f6490bc67144aa75d 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import importer from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_ops # pylint: disable=unused-import from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops @@ -419,6 +420,46 @@ class ImportGraphDefTest(test.TestCase): with self.test_session() as sess: self.assertEqual(sess.run(imported_r), 10) + def testImportWhileLoopInCond(self): + # Produce GraphDef containing while loop. + graph = ops.Graph() + with graph.as_default(): + r = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [0]) + graph_def = graph.as_graph_def() + + # Import the GraphDef inside a cond and make sure it runs. + with ops.Graph().as_default(): + + def ImportFn(): + return importer.import_graph_def(graph_def, return_elements=[r.name])[0] + + pred = array_ops.placeholder(dtypes.bool) + out = control_flow_ops.cond(pred, ImportFn, + lambda: constant_op.constant(1)) + with self.test_session() as sess: + self.assertEqual(sess.run(out, {pred: True}), 10) + self.assertEqual(sess.run(out, {pred: False}), 1) + + def testImportWhileLoopInWhileLoop(self): + self.skipTest("b/111757448") + # Produce GraphDef containing while loop. + graph = ops.Graph() + with graph.as_default(): + r = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [0]) + graph_def = graph.as_graph_def() + + # Import the GraphDef inside another loop and make sure it runs. + with ops.Graph().as_default(): + + def ImportFn(_): + return importer.import_graph_def(graph_def, return_elements=[r.name])[0] + + out = control_flow_ops.while_loop( + lambda i: i < 2, ImportFn, [0], + shape_invariants=[tensor_shape.TensorShape(None)]) + with self.test_session() as sess: + self.assertEqual(sess.run(out), 10) + def testTypeMismatchInGraphDef(self): # TODO(skyewm): improve error message error_msg = ("Input 0 of node import/B was passed int32 from import/A:0 " diff --git a/tensorflow/python/framework/kernels.py b/tensorflow/python/framework/kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..f7641f3442e4c5a6508a3463c700ade97ce202a9 --- /dev/null +++ b/tensorflow/python/framework/kernels.py @@ -0,0 +1,46 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functions for querying registered kernels.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.framework import kernel_def_pb2 +from tensorflow.python import pywrap_tensorflow as c_api +from tensorflow.python.util import compat + + +def get_all_registered_kernels(): + """Returns a KernelList proto of all registered kernels. + """ + buf = c_api.TF_GetAllRegisteredKernels() + data = c_api.TF_GetBuffer(buf) + kernel_list = kernel_def_pb2.KernelList() + kernel_list.ParseFromString(compat.as_bytes(data)) + return kernel_list + + +def get_registered_kernels_for_op(name): + """Returns a KernelList proto of registered kernels for a given op. + + Args: + name: A string representing the name of the op whose kernels to retrieve. + """ + buf = c_api.TF_GetRegisteredKernelsForOp(name) + data = c_api.TF_GetBuffer(buf) + kernel_list = kernel_def_pb2.KernelList() + kernel_list.ParseFromString(compat.as_bytes(data)) + return kernel_list diff --git a/tensorflow/python/framework/kernels_test.py b/tensorflow/python/framework/kernels_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c53500be73a05b2d9b379fd61e899a091b7db9b1 --- /dev/null +++ b/tensorflow/python/framework/kernels_test.py @@ -0,0 +1,41 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for querying registered kernels.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import kernels +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + + +class GetAllRegisteredKernelsTest(test_util.TensorFlowTestCase): + + def testFindsAtLeastOneKernel(self): + kernel_list = kernels.get_all_registered_kernels() + self.assertGreater(len(kernel_list.kernel), 0) + + +class GetRegisteredKernelsForOp(test_util.TensorFlowTestCase): + + def testFindsAtLeastOneKernel(self): + kernel_list = kernels.get_registered_kernels_for_op("KernelLabel") + self.assertGreater(len(kernel_list.kernel), 0) + self.assertEqual(kernel_list.kernel[0].op, "KernelLabel") + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py index 923e76fc9c8f231cc9a43bc05280dac1ea458d3c..33631282bd03a15daddb334e6f40e6b52f84c750 100644 --- a/tensorflow/python/framework/meta_graph.py +++ b/tensorflow/python/framework/meta_graph.py @@ -696,6 +696,67 @@ def import_scoped_meta_graph(meta_graph_or_file, Raises: ValueError: If the graph_def contains unbound inputs. """ + return import_scoped_meta_graph_with_return_elements( + meta_graph_or_file, clear_devices, graph, import_scope, input_map, + unbound_inputs_col_name, restore_collections_predicate)[0] + + +def import_scoped_meta_graph_with_return_elements( + meta_graph_or_file, + clear_devices=False, + graph=None, + import_scope=None, + input_map=None, + unbound_inputs_col_name="unbound_inputs", + restore_collections_predicate=(lambda key: True), + return_elements=None): + """Imports graph from `MetaGraphDef` and returns vars and return elements. + + This function takes a `MetaGraphDef` protocol buffer as input. If + the argument is a file containing a `MetaGraphDef` protocol buffer , + it constructs a protocol buffer from the file content. The function + then adds all the nodes from the `graph_def` field to the + current graph, recreates the desired collections, and returns a dictionary of + all the Variables imported into the name scope. + + In combination with `export_scoped_meta_graph()`, this function can be used to + + * Serialize a graph along with other Python objects such as `QueueRunner`, + `Variable` into a `MetaGraphDef`. + + * Restart training from a saved graph and checkpoints. + + * Run inference from a saved graph and checkpoints. + + Args: + meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including + the path) containing a `MetaGraphDef`. + clear_devices: Boolean which controls whether to clear device information + from graph_def. Default false. + graph: The `Graph` to import into. If `None`, use the default graph. + import_scope: Optional `string`. Name scope into which to import the + subgraph. If `None`, the graph is imported to the root name scope. + input_map: A dictionary mapping input names (as strings) in `graph_def` to + `Tensor` objects. The values of the named input tensors in the imported + graph will be re-mapped to the respective `Tensor` values. + unbound_inputs_col_name: Collection name for looking up unbound inputs. + restore_collections_predicate: a predicate on collection names. A collection + named c (i.e whose key is c) will be restored iff + 1) `restore_collections_predicate(c)` is True, and + 2) `c != unbound_inputs_col_name`. + return_elements: A list of strings containing operation names in the + `MetaGraphDef` that will be returned as `Operation` objects; and/or + tensor names in `MetaGraphDef` that will be returned as `Tensor` objects. + + Returns: + A tuple of ( + dictionary of all the `Variables` imported into the name scope, + list of `Operation` or `Tensor` objects from the `return_elements` list). + + Raises: + ValueError: If the graph_def contains unbound inputs. + + """ if context.executing_eagerly(): raise ValueError("Exporting/importing meta graphs is not supported when " "eager execution is enabled.") @@ -737,11 +798,12 @@ def import_scoped_meta_graph(meta_graph_or_file, scope_to_prepend_to_names = graph.unique_name( import_scope or "", mark_as_used=False) - importer.import_graph_def( + imported_return_elements = importer.import_graph_def( input_graph_def, name=(import_scope or scope_to_prepend_to_names), input_map=input_map, - producer_op_list=producer_op_list) + producer_op_list=producer_op_list, + return_elements=return_elements) # Restores all the other collections. variable_objects = {} @@ -806,7 +868,7 @@ def import_scoped_meta_graph(meta_graph_or_file, for v in variables: var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v - return var_list + return var_list, imported_return_elements def export_scoped_meta_graph(filename=None, diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index b440cde3ad4268b9f53da77af0c570f7c0f51d65..5527f5286074baefaa29c6aec8a7c027a759f377 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -20,7 +20,6 @@ from __future__ import print_function import collections import copy -import linecache import os import re import sys @@ -45,17 +44,22 @@ from tensorflow.python.framework import c_api_util from tensorflow.python.framework import cpp_shape_inference_pb2 from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes +from tensorflow.python.framework import error_interpolation from tensorflow.python.framework import errors from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import registry from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import traceable_stack from tensorflow.python.framework import versions from tensorflow.python.ops import control_flow_util from tensorflow.python.platform import app from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import decorator_utils +from tensorflow.python.util import function_utils +from tensorflow.python.util import lock_util from tensorflow.python.util import tf_contextlib +from tensorflow.python.util import tf_stack from tensorflow.python.util.deprecation import deprecated_args from tensorflow.python.util.tf_export import tf_export @@ -63,7 +67,7 @@ from tensorflow.python.util.tf_export import tf_export # Temporary global switches determining if we should enable the work-in-progress # calls to the C API. These will be removed once all functionality is supported. _USE_C_API = True -_USE_C_SHAPES = os.getenv("TF_C_API_GRAPH_CONSTRUCTION_SHAPES", "0") is not "0" +_USE_C_SHAPES = os.getenv("TF_C_API_GRAPH_CONSTRUCTION_SHAPES", "1") != "0" def tensor_id(tensor): @@ -71,6 +75,31 @@ def tensor_id(tensor): return tensor._id # pylint: disable=protected-access +class _UserDeviceSpec(object): + """Store user-specified device and provide computation of merged device.""" + + def __init__(self, device_name_or_function): + self._device_name_or_function = device_name_or_function + + self.display_name = str(self._device_name_or_function) + if callable(self._device_name_or_function): + dev_func = self._device_name_or_function + func_name = function_utils.get_func_name(dev_func) + func_code = function_utils.get_func_code(dev_func) + if func_code: + fname = func_code.co_filename + lineno = func_code.co_firstlineno + else: + fname = "unknown" + lineno = -1 + self.display_name = "%s<%s, %d>" % (func_name, fname, lineno) + + self.function = self._device_name_or_function + if not (self._device_name_or_function is None or + callable(self._device_name_or_function)): + self.function = pydev.merge_device(self._device_name_or_function) + + class _NullContextmanager(object): def __enter__(self): @@ -200,7 +229,7 @@ class Tensor(_TensorLike): A `Tensor` is a symbolic handle to one of the outputs of an `Operation`. It does not hold the values of that operation's output, but instead provides a means of computing those values in a - TensorFlow @{tf.Session}. + TensorFlow `tf.Session`. This class has two primary purposes: @@ -211,7 +240,7 @@ class Tensor(_TensorLike): 2. After the graph has been launched in a session, the value of the `Tensor` can be computed by passing it to - @{tf.Session.run}. + `tf.Session.run`. `t.eval()` is a shortcut for calling `tf.get_default_session().run(t)`. @@ -336,7 +365,7 @@ class Tensor(_TensorLike): The shape is computed using shape inference functions that are registered in the Op for each `Operation`. See - @{tf.TensorShape} + `tf.TensorShape` for more details of what a shape represents. The inferred shape of a tensor is used to provide shape @@ -426,7 +455,7 @@ class Tensor(_TensorLike): def __iter__(self): if not context.executing_eagerly(): raise TypeError( - "Tensor objects are not iterable when eager execution is not " + "Tensor objects are only iterable when eager execution is " "enabled. To iterate over this tensor use tf.map_fn.") shape = self._shape_tuple() if shape is None: @@ -666,7 +695,7 @@ class Tensor(_TensorLike): Args: feed_dict: A dictionary that maps `Tensor` objects to feed values. - See @{tf.Session.run} for a + See `tf.Session.run` for a description of the valid feed values. session: (Optional.) The `Session` to be used to evaluate this tensor. If none, the default session will be used. @@ -705,9 +734,9 @@ class _EagerTensorBase(Tensor): """ if self.dtype == dtypes.resource: raise ValueError("Resource handles are not convertible to numpy.") - return self.cpu()._numpy() # pylint: disable=protected-access + return self._cpu_nograd()._numpy() # pylint: disable=protected-access - # __int__ and __float__ may copy the tensor to CPU and + # __int__, __float__ and __index__ may copy the tensor to CPU and # only work for scalars; values are cast as per numpy. def __int__(self): return int(self.numpy()) @@ -715,6 +744,9 @@ class _EagerTensorBase(Tensor): def __float__(self): return float(self.numpy()) + def __index__(self): + return int(self.numpy()) + def __array__(self, dtype=None): return np.array(self.numpy(), dtype=dtype) @@ -779,8 +811,8 @@ class _EagerTensorBase(Tensor): def _override_operator(name, func): setattr(_EagerTensorBase, name, func) - def _copy(self, ctx=None, device_name=None): - """Copies tensor to dest device.""" + def _copy_nograd(self, ctx=None, device_name=None): + """Copies tensor to dest device, but doesn't record the operation.""" # pylint: disable=protected-access # Creates a new tensor on the dest device. if ctx is None: @@ -792,7 +824,11 @@ class _EagerTensorBase(Tensor): new_tensor = self._copy_to_device(context=ctx._handle, device=device_name) except core._NotOkStatusException as e: six.raise_from(core._status_to_exception(e.code, e.message), None) + return new_tensor + def _copy(self, ctx=None, device_name=None): + """Copies tensor to dest device.""" + new_tensor = self._copy_nograd(ctx, device_name) # Record the copy on tape and define backprop copy as well. if context.executing_eagerly(): self_device = self.device @@ -823,6 +859,16 @@ class _EagerTensorBase(Tensor): """Returns the number of Tensor dimensions.""" return self.shape.ndims + def _cpu_nograd(self): + """A copy of this Tensor with contents backed by host memory. + + The copy cannot be differentiated through. + + Returns: + A CPU-memory backed Tensor object with the same contents as this Tensor. + """ + return self._copy_nograd(context.context(), "CPU:0") + def cpu(self): """A copy of this Tensor with contents backed by host memory.""" return self._copy(context.context(), "CPU:0") @@ -1409,10 +1455,10 @@ class IndexedSlices(_TensorLike): The `IndexedSlices` class is used principally in the definition of gradients for operations that have sparse gradients - (e.g. @{tf.gather}). + (e.g. `tf.gather`). Contrast this representation with - @{tf.SparseTensor}, + `tf.SparseTensor`, which uses multi-dimensional indices and scalar values. """ @@ -1573,8 +1619,8 @@ class Operation(object): more `Tensor` objects as input, and produces zero or more `Tensor` objects as output. Objects of type `Operation` are created by calling a Python op constructor (such as - @{tf.matmul}) - or @{tf.Graph.create_op}. + `tf.matmul`) + or `tf.Graph.create_op`. For example `c = tf.matmul(a, b)` creates an `Operation` of type "MatMul" that takes tensors `a` and `b` as input, and produces `c` @@ -1582,7 +1628,7 @@ class Operation(object): After the graph has been launched in a session, an `Operation` can be executed by passing it to - @{tf.Session.run}. + `tf.Session.run`. `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. """ @@ -1696,10 +1742,19 @@ class Operation(object): # This will be set by self.inputs. self._inputs_val = None - self._id_value = self._graph._next_id() # pylint: disable=protected-access + # pylint: disable=protected-access + self._id_value = self._graph._next_id() self._original_op = original_op - self._traceback = self._graph._extract_stack() # pylint: disable=protected-access - self._control_flow_context = self.graph._get_control_flow_context() # pylint: disable=protected-access + self._traceback = tf_stack.extract_stack() + + # List of _UserDevSpecs holding code location of device context manager + # invocations and the users original argument to them. + self._device_code_locations = None + # Dict mapping op name to file and line information for op colocation + # context managers. + self._colocation_code_locations = None + self._control_flow_context = self.graph._get_control_flow_context() + # pylint: enable=protected-access # Initialize self._c_op. if c_op: @@ -1837,6 +1892,72 @@ class Operation(object): """ return c_api.TF_OperationDevice(self._c_op) + @property + def _device_assignments(self): + """Code locations for device context managers active at op creation. + + This property will return a list of traceable_stack.TraceableObject + instances where .obj is a string representing the assigned device + (or information about the function that would be applied to this op + to compute the desired device) and the filename and lineno members + record the location of the relevant device context manager. + + For example, suppose file_a contained these lines: + + file_a.py: + 15: with tf.device('/gpu:0'): + 16: node_b = tf.constant(4, name='NODE_B') + + Then a TraceableObject t_obj representing the device context manager + would have these member values: + + t_obj.obj -> '/gpu:0' + t_obj.filename = 'file_a.py' + t_obj.lineno = 15 + + and node_b.op._device_assignments would return the list [t_obj]. + + Returns: + [str: traceable_stack.TraceableObject, ...] as per this method's + description, above. + """ + return self._device_code_locations or [] + + @property + def _colocation_dict(self): + """Code locations for colocation context managers active at op creation. + + This property will return a dictionary for which the keys are nodes with + which this Operation is colocated, and for which the values are + traceable_stack.TraceableObject instances. The TraceableObject instances + record the location of the relevant colocation context manager but have the + "obj" field set to None to prevent leaking private data. + + For example, suppose file_a contained these lines: + + file_a.py: + 14: node_a = tf.constant(3, name='NODE_A') + 15: with tf.colocate_with(node_a): + 16: node_b = tf.constant(4, name='NODE_B') + + Then a TraceableObject t_obj representing the colocation context manager + would have these member values: + + t_obj.obj -> None + t_obj.filename = 'file_a.py' + t_obj.lineno = 15 + + and node_b.op._colocation_dict would return the dictionary + + { 'NODE_A': t_obj } + + Returns: + {str: traceable_stack.TraceableObject} as per this method's description, + above. + """ + locations_dict = self._colocation_code_locations or {} + return locations_dict.copy() + @property def _output_types(self): """List this operation's output types. @@ -2139,7 +2260,7 @@ class Operation(object): @property def traceback(self): """Returns the call stack from when this operation was constructed.""" - return self._graph._convert_stack(self._traceback) # pylint: disable=protected-access + return tf_stack.convert_stack(self._traceback) @property def traceback_with_start_lines(self): @@ -2148,9 +2269,8 @@ class Operation(object): Returns: A list of 5-tuples (filename, lineno, name, code, func_start_lineno). """ - return self._graph._convert_stack( # pylint: disable=protected-access - self._traceback, - include_func_start_lineno=True) + return tf_stack.convert_stack(self._traceback, + include_func_start_lineno=True) def _set_attr(self, attr_name, attr_value): """Private method used to set an attribute in the node_def.""" @@ -2218,7 +2338,7 @@ class Operation(object): Args: feed_dict: A dictionary that maps `Tensor` objects to feed values. - See @{tf.Session.run} + See `tf.Session.run` for a description of the valid feed values. session: (Optional.) The `Session` to be used to run to this operation. If none, the default session will be used. @@ -2599,18 +2719,21 @@ def _name_from_scope_name(name): return name[:-1] if (name and name[-1] == "/") else name +_MUTATION_LOCK_GROUP = 0 +_SESSION_RUN_LOCK_GROUP = 1 + @tf_export("Graph") class Graph(object): """A TensorFlow computation, represented as a dataflow graph. A `Graph` contains a set of - @{tf.Operation} objects, + `tf.Operation` objects, which represent units of computation; and - @{tf.Tensor} objects, which represent + `tf.Tensor` objects, which represent the units of data that flow between operations. A default `Graph` is always registered, and accessible by calling - @{tf.get_default_graph}. + `tf.get_default_graph`. To add an operation to the default graph, simply call one of the functions that defines a new `Operation`: @@ -2620,7 +2743,7 @@ class Graph(object): ``` Another typical usage involves the - @{tf.Graph.as_default} + `tf.Graph.as_default` context manager, which overrides the current default graph for the lifetime of the context: @@ -2641,27 +2764,28 @@ class Graph(object): that are identified by name. For convenience when building a large graph, collections can store groups of related objects: for example, the `tf.Variable` uses a collection (named - @{tf.GraphKeys.GLOBAL_VARIABLES}) for + `tf.GraphKeys.GLOBAL_VARIABLES`) for all variables that are created during the construction of a graph. The caller may define additional collections by specifying a new name. """ def __init__(self): """Creates a new, empty Graph.""" - # Protects core state that can be returned via public accessors, as well as - # synchronizes Session.run calls with methods that create and mutate ops - # (e.g. Graph.create_op()). This synchronization is necessary because it's - # illegal to modify an operation after it's been run. Thread-safety is - # provided on a best-effort basis to support buggy programs, and is not - # guaranteed by the public `tf.Graph` API. - # - # The lock must be reentrant because create_op can be called recursively due - # to control flow. Without a reentrant lock, many methods would also need a - # "locked" version or parameter (including generated code). + # Protects core state that can be returned via public accessors. + # Thread-safety is provided on a best-effort basis to support buggy + # programs, and is not guaranteed by the public `tf.Graph` API. # # NOTE(mrry): This does not protect the various stacks. A warning will # be reported if these are used from multiple threads self._lock = threading.RLock() + # The group lock synchronizes Session.run calls with methods that create + # and mutate ops (e.g. Graph.create_op()). This synchronization is + # necessary because it's illegal to modify an operation after it's been run. + # The group lock allows any number of threads to mutate ops at the same time + # but if any modification is going on, all Session.run calls have to wait. + # Similarly, if one or more Session.run calls are going on, all mutate ops + # have to wait until all Session.run calls have finished. + self._group_lock = lock_util.GroupLock(num_groups=2) self._nodes_by_id = dict() # GUARDED_BY(self._lock) self._next_id_counter = 0 # GUARDED_BY(self._lock) self._nodes_by_name = dict() # GUARDED_BY(self._lock) @@ -2673,7 +2797,7 @@ class Graph(object): # Functions that will be applied to choose a device if none is specified. # After switch_to_thread_local(), self._thread_local._device_function_stack # is used instead. - self._graph_device_function_stack = [] + self._graph_device_function_stack = traceable_stack.TraceableStack() # Default original_op applied to new ops. self._default_original_op = None # Current control flow context. It could be either CondContext or @@ -2706,7 +2830,7 @@ class Graph(object): self._building_function = False # Stack of colocate_with ops. After switch_to_thread_local(), # self._thread_local._colocation_stack is used instead. - self._graph_colocation_stack = [] + self._graph_colocation_stack = traceable_stack.TraceableStack() # Set of tensors that are dangerous to feed! self._unfeedable_tensors = set() # Set of operations that are dangerous to fetch! @@ -2746,36 +2870,6 @@ class Graph(object): """Temporary hack; can be overridden to force C API usage.""" return _USE_C_API - def _convert_stack(self, stack, include_func_start_lineno=False): - """Converts a stack extracted using _extract_stack() to a traceback stack. - - Args: - stack: A list of n 5-tuples, - (filename, lineno, name, frame_globals, func_start_lineno). - include_func_start_lineno: True if function start line number should be - included as the 5th entry in return tuples. - - Returns: - A list of n 4-tuples or 5-tuples - (filename, lineno, name, code, [optional: func_start_lineno]), where the - code tuple element is calculated from the corresponding elements of the - input tuple. - """ - ret = [] - for (filename, lineno, name, frame_globals, func_start_lineno, - unused_frame_info) in stack: - linecache.checkcache(filename) - line = linecache.getline(filename, lineno, frame_globals) - if line: - line = line.strip() - else: - line = None - if include_func_start_lineno: - ret.append((filename, lineno, name, line, func_start_lineno)) - else: - ret.append((filename, lineno, name, line)) - return ret - # Note: this method is private because the API of tf.Graph() is public and # frozen, and this functionality is still not ready for public visibility. @tf_contextlib.contextmanager @@ -2783,63 +2877,23 @@ class Graph(object): # This step makes a copy of the existing stack, and it also initializes # self._thread_local._variable_creator_stack if it doesn't exist yet. old = list(self._variable_creator_stack) - self._thread_local._variable_creator_stack.append(creator) + self._thread_local._variable_creator_stack.append(creator) # pylint: disable=protected-access try: yield finally: - self._thread_local._variable_creator_stack = old + self._thread_local._variable_creator_stack = old # pylint: disable=protected-access # Note: this method is private because the API of tf.Graph() is public and # frozen, and this functionality is still not ready for public visibility. @property def _variable_creator_stack(self): if not hasattr(self._thread_local, "_variable_creator_stack"): - self._thread_local._variable_creator_stack = [] - return list(self._thread_local._variable_creator_stack) + self._thread_local._variable_creator_stack = [] # pylint: disable=protected-access + return list(self._thread_local._variable_creator_stack) # pylint: disable=protected-access @_variable_creator_stack.setter def _variable_creator_stack(self, variable_creator_stack): - self._thread_local._variable_creator_stack = variable_creator_stack - - def _extract_stack(self): - """A lightweight, extensible re-implementation of traceback.extract_stack. - - NOTE(mrry): traceback.extract_stack eagerly retrieves the line of code for - each stack frame using linecache, which results in an abundance of stat() - calls. This implementation does not retrieve the code, and any consumer - should apply _convert_stack to the result to obtain a traceback that can - be formatted etc. using traceback methods. - - Derived classes can implement _extract_frame_info() to add extra information - to the traceback. - - Returns: - A list of 6-tuples - (filename, lineno, name, frame_globals, func_start_lineno, custom_info) - corresponding to the call stack of the current thread. - """ - try: - raise ZeroDivisionError - except ZeroDivisionError: - f = sys.exc_info()[2].tb_frame.f_back - ret = [] - while f is not None: - lineno = f.f_lineno - co = f.f_code - filename = co.co_filename - name = co.co_name - frame_globals = f.f_globals - func_start_lineno = co.co_firstlineno - frame_info = self._extract_frame_info(f) - ret.append((filename, lineno, name, frame_globals, func_start_lineno, - frame_info)) - f = f.f_back - ret.reverse() - return ret - - def _extract_frame_info(self, frame): # pylint: disable=unused-argument - """Extracts custom information from a frame in an op traceback.""" - return None + self._thread_local._variable_creator_stack = variable_creator_stack # pylint: disable=protected-access def _check_not_finalized(self): """Check if the graph is finalized. @@ -2887,7 +2941,7 @@ class Graph(object): """Returns a version number that increases as ops are added to the graph. Note that this is unrelated to the - @{tf.Graph.graph_def_versions}. + `tf.Graph.graph_def_versions`. Returns: An integer version that increases as ops are added to the graph. @@ -2937,7 +2991,7 @@ class Graph(object): After calling `g.finalize()`, no new operations can be added to `g`. This method is used to ensure that no operations are added to a graph when it is shared between multiple threads, for example - when using a @{tf.train.QueueRunner}. + when using a `tf.train.QueueRunner`. """ self._finalized = True @@ -2986,7 +3040,7 @@ class Graph(object): """Returns a serialized `GraphDef` representation of this graph. The serialized `GraphDef` can be imported into another `Graph` - (using @{tf.import_graph_def}) or used with the + (using `tf.import_graph_def`) or used with the [C++ Session API](../../../../api_docs/cc/index.md). This method is thread-safe. @@ -3032,7 +3086,7 @@ class Graph(object): """Returns a serialized `GraphDef` representation of this graph. The serialized `GraphDef` can be imported into another `Graph` - (using @{tf.import_graph_def}) or used with the + (using `tf.import_graph_def`) or used with the [C++ Session API](../../api_docs/cc/index.md). This method is thread-safe. @@ -3192,9 +3246,9 @@ class Graph(object): input_ops = set([t.op for t in inputs]) control_inputs = self._control_dependencies_for_inputs(input_ops) - # _create_op_helper mutates the new Operation. _lock ensures a Session.run - # call cannot occur between creating and mutating the op. - with self._lock: + # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a + # Session.run call cannot occur between creating and mutating the op. + with self._mutation_lock(): ret = Operation( node_def, self, @@ -3239,6 +3293,36 @@ class Graph(object): self._create_op_helper(ret, compute_device=compute_device) return ret + def _make_colocation_conflict_message(self, op, colocation_op): + """Return detailed error message about device conflict due to colocation.""" + # Example error message: + # Tried to colocate op 'a' (defined at file1.py:149) having device + # '/device:GPU:0' with op 'b' (defined at file2:96) which had an + # incompatible device '/device:CPU:0'. + # + # No node-device colocations were active during op 'a' creation. + # Device assignments active during op 'a' creation: + # with tf.device(/device:GPU:0): file1.py:148> + # + # Node-device colocations active during op 'b' creation: + # with tf.colocate_with(a): file2.py:93> + # Device assignments active during op 'b' creation: + # with tf.device(/cpu:0): file2.py:94 + op_info = error_interpolation.compute_field_dict(op) + coloc_op_info = error_interpolation.compute_field_dict(colocation_op) + msg = ("Tried to colocate op '{op_name}'{op_loc} having device '{op_dev}' " + "with op '{coloc_op_name}'{coloc_op_loc} which had an incompatible " + "device '{coloc_op_dev}'.\n\n{op_summary}\n\n{coloc_op_summary}" + .format(op_name=op.name, + op_loc=op_info["defined_at"], + op_dev=op.device, + op_summary=op_info["devs_and_colocs"], + coloc_op_name=colocation_op.name, + coloc_op_loc=coloc_op_info["defined_at"], + coloc_op_dev=colocation_op.device, + coloc_op_summary=coloc_op_info["devs_and_colocs"])) + return msg + def _create_op_helper(self, op, compute_device=True): """Common logic for creating an op in this graph.""" # Apply any additional attributes requested. Do not overwrite any existing @@ -3279,20 +3363,22 @@ class Graph(object): if compute_device: self._apply_device_functions(op) + # Snapshot the colocation stack metadata before we might generate error + # messages using it. Note that this snapshot depends on the actual stack + # and is independent of the op's _class attribute. + # pylint: disable=protected-access + op._colocation_code_locations = self._snapshot_colocation_stack_metadata() + # pylint: enable=protected-access + if self._colocation_stack: all_colocation_groups = [] - for colocation_op in self._colocation_stack: + for colocation_op in self._colocation_stack.peek_objs(): all_colocation_groups.extend(colocation_op.colocation_groups()) if colocation_op.device: - # Make this device match the device of the colocated op, to provide - # consistency between the device and the colocation property. if (op.device and pydev.canonical_name(op.device) != pydev.canonical_name(colocation_op.device)): - logging.warning("Tried to colocate %s with an op %s that had " - "a different device: %s vs %s. Postponing " - "error-checking until all devices are assigned.", - op.name, colocation_op.name, op.device, - colocation_op.device) + msg = self._make_colocation_conflict_message(op, colocation_op) + logging.warning(msg) else: op._set_device(colocation_op.device) # pylint: disable=protected-access @@ -3609,9 +3695,13 @@ class Graph(object): This method should be used if you want to create multiple graphs in the same process. For convenience, a global default graph is provided, and all ops will be added to this graph if you do not - create a new graph explicitly. Use this method with the `with` keyword - to specify that ops created within the scope of a block should be - added to this graph. + create a new graph explicitly. + + Use this method with the `with` keyword to specify that ops created within + the scope of a block should be added to this graph. In this case, once + the scope of the `with` is exited, the previous default graph is set again + as default. There is a stack, so it's ok to have multiple nested levels + of `as_default` calls. The default graph is a property of the current thread. If you create a new thread, and wish to use the default graph in that @@ -3782,8 +3872,8 @@ class Graph(object): Nothing. """ old_original_op = self._default_original_op + self._default_original_op = op try: - self._default_original_op = op yield finally: self._default_original_op = old_original_op @@ -3900,15 +3990,15 @@ class Graph(object): # op name regex, which constrains the initial character. if not _VALID_OP_NAME_REGEX.match(name): raise ValueError("'%s' is not a valid scope name" % name) + old_stack = self._name_stack + if not name: # Both for name=None and name="" we re-set to empty scope. + new_stack = None + elif name[-1] == "/": + new_stack = _name_from_scope_name(name) + else: + new_stack = self.unique_name(name) + self._name_stack = new_stack try: - old_stack = self._name_stack - if not name: # Both for name=None and name="" we re-set to empty scope. - new_stack = None - elif name[-1] == "/": - new_stack = _name_from_scope_name(name) - else: - new_stack = self.unique_name(name) - self._name_stack = new_stack yield "" if new_stack is None else new_stack + "/" finally: self._name_stack = old_stack @@ -3989,8 +4079,8 @@ class Graph(object): ignore_existing=False): with self.colocate_with(op, ignore_existing): if gradient_uid is not None and self._control_flow_context is not None: + self._control_flow_context.EnterGradientColocation(op, gradient_uid) try: - self._control_flow_context.EnterGradientColocation(op, gradient_uid) yield finally: self._control_flow_context.ExitGradientColocation(op, gradient_uid) @@ -4032,7 +4122,6 @@ class Graph(object): Yields: A context manager that specifies the op with which to colocate newly created ops. - """ if op is None and not ignore_existing: raise ValueError("Trying to reset colocation (op is None) but " @@ -4050,14 +4139,17 @@ class Graph(object): # In the future, a caller may specify that device_functions win # over colocation, in which case we can add support. device_fn_tmp = self._device_function_stack - self._device_function_stack = [] + self._device_function_stack = traceable_stack.TraceableStack() if ignore_existing: current_stack = self._colocation_stack - self._colocation_stack = [] + self._colocation_stack = traceable_stack.TraceableStack() if op is not None: - self._colocation_stack.append(op) + # offset refers to the stack frame used for storing code location. + # We use 4, the sum of 1 to use our caller's stack frame and 3 + # to jump over layers of context managers above us. + self._colocation_stack.push_obj(op, offset=4) try: yield @@ -4065,12 +4157,19 @@ class Graph(object): # Restore device function stack self._device_function_stack = device_fn_tmp if op is not None: - self._colocation_stack.pop() + self._colocation_stack.pop_obj() # Reset the colocation stack if requested. if ignore_existing: self._colocation_stack = current_stack + def _add_device_to_stack(self, device_name_or_function, offset=0): + """Add device to stack manually, separate from a context manager.""" + total_offset = 1 + offset + spec = _UserDeviceSpec(device_name_or_function) + self._device_function_stack.push_obj(spec, offset=total_offset) + return spec + @tf_contextlib.contextmanager def device(self, device_name_or_function): # pylint: disable=line-too-long @@ -4128,31 +4227,26 @@ class Graph(object): Yields: A context manager that specifies the default device to use for newly created ops. - """ - # pylint: enable=line-too-long - if (device_name_or_function is not None and - not callable(device_name_or_function)): - device_function = pydev.merge_device(device_name_or_function) - else: - device_function = device_name_or_function - + self._add_device_to_stack(device_name_or_function, offset=2) try: - self._device_function_stack.append(device_function) yield finally: - self._device_function_stack.pop() + self._device_function_stack.pop_obj() def _apply_device_functions(self, op): """Applies the current device function stack to the given operation.""" - # Apply any device functions in reverse order, so that the most recently + # Apply any device functions in LIFO order, so that the most recently # pushed function has the first chance to apply a device to the op. # We apply here because the result can depend on the Operation's # signature, which is computed in the Operation constructor. - for device_function in reversed(self._device_function_stack): - if device_function is None: + # pylint: disable=protected-access + for device_spec in self._device_function_stack.peek_objs(): + if device_spec.function is None: break - op._set_device(device_function(op)) # pylint: disable=protected-access + op._set_device(device_spec.function(op)) + op._device_code_locations = self._snapshot_device_function_stack_metadata() + # pylint: enable=protected-access # pylint: disable=g-doc-return-or-yield @tf_contextlib.contextmanager @@ -4201,8 +4295,8 @@ class Graph(object): yields the container name. """ original_container = self._container + self._container = container_name try: - self._container = container_name yield self._container finally: self._container = original_container @@ -4676,35 +4770,74 @@ class Graph(object): if self._stack_state_is_thread_local: # This may be called from a thread where device_function_stack doesn't yet # exist. + # pylint: disable=protected-access if not hasattr(self._thread_local, "_device_function_stack"): - self._thread_local._device_function_stack = ( - self._graph_device_function_stack[:]) + stack_copy_for_this_thread = self._graph_device_function_stack.copy() + self._thread_local._device_function_stack = stack_copy_for_this_thread return self._thread_local._device_function_stack + # pylint: enable=protected-access else: return self._graph_device_function_stack + @property + def _device_functions_outer_to_inner(self): + user_device_specs = self._device_function_stack.peek_objs() + device_functions = [spec.function for spec in user_device_specs] + device_functions_outer_to_inner = list(reversed(device_functions)) + return device_functions_outer_to_inner + + def _snapshot_device_function_stack_metadata(self): + """Return device function stack as a list of TraceableObjects. + + Returns: + [traceable_stack.TraceableObject, ...] where each TraceableObject's .obj + member is a displayable name for the user's argument to Graph.device, and + the filename and lineno members point to the code location where + Graph.device was called directly or indirectly by the user. + """ + traceable_objects = self._device_function_stack.peek_traceable_objs() + snapshot = [] + for obj in traceable_objects: + obj_copy = obj.copy_metadata() + obj_copy.obj = obj.obj.display_name + snapshot.append(obj_copy) + return snapshot + @_device_function_stack.setter def _device_function_stack(self, device_function_stack): if self._stack_state_is_thread_local: + # pylint: disable=protected-access self._thread_local._device_function_stack = device_function_stack + # pylint: enable=protected-access else: self._graph_device_function_stack = device_function_stack @property def _colocation_stack(self): + """Return thread-local copy of colocation stack.""" if self._stack_state_is_thread_local: # This may be called from a thread where colocation_stack doesn't yet # exist. + # pylint: disable=protected-access if not hasattr(self._thread_local, "_colocation_stack"): - self._thread_local._colocation_stack = self._graph_colocation_stack[:] + stack_copy_for_this_thread = self._graph_colocation_stack.copy() + self._thread_local._colocation_stack = stack_copy_for_this_thread return self._thread_local._colocation_stack + # pylint: enable=protected-access else: return self._graph_colocation_stack + def _snapshot_colocation_stack_metadata(self): + """Return colocation stack metadata as a dictionary.""" + traceable_objects = self._colocation_stack.peek_traceable_objs() + return {obj.obj.name: obj.copy_metadata() for obj in traceable_objects} + @_colocation_stack.setter def _colocation_stack(self, colocation_stack): if self._stack_state_is_thread_local: + # pylint: disable=protected-access self._thread_local._colocation_stack = colocation_stack + # pylint: enable=protected-access else: self._graph_colocation_stack = colocation_stack @@ -4727,6 +4860,32 @@ class Graph(object): else: self._graph_control_dependencies_stack = control_dependencies + @property + def _distribution_strategy_stack(self): + """A stack to maintain distribution strategy context for each thread.""" + if not hasattr(self._thread_local, "_distribution_strategy_stack"): + self._thread_local._distribution_strategy_stack = [] # pylint: disable=protected-access + return self._thread_local._distribution_strategy_stack # pylint: disable=protected-access + + @_distribution_strategy_stack.setter + def _distribution_strategy_stack(self, _distribution_strategy_stack): + self._thread_local._distribution_strategy_stack = ( # pylint: disable=protected-access + _distribution_strategy_stack) + + def _mutation_lock(self): + """Returns a lock to guard code that creates & mutates ops. + + See the comment for self._group_lock for more info. + """ + return self._group_lock.group(_MUTATION_LOCK_GROUP) + + def _session_run_lock(self): + """Returns a lock to guard code for Session.run. + + See the comment for self._group_lock for more info. + """ + return self._group_lock.group(_SESSION_RUN_LOCK_GROUP) + # TODO(agarwal): currently device directives in an outer eager scope will not # apply to inner graph mode code. Fix that. @@ -4737,7 +4896,7 @@ def device(device_name_or_function): """Wrapper for `Graph.device()` using the default graph. See - @{tf.Graph.device} + `tf.Graph.device` for more details. Args: @@ -4803,7 +4962,7 @@ def colocate_with(op, ignore_existing=False): def control_dependencies(control_inputs): """Wrapper for `Graph.control_dependencies()` using the default graph. - See @{tf.Graph.control_dependencies} + See `tf.Graph.control_dependencies` for more details. When eager execution is enabled, any callable object in the `control_inputs` @@ -4859,8 +5018,8 @@ class _DefaultStack(threading.local): @tf_contextlib.contextmanager def get_controller(self, default): """A context manager for manipulating a default stack.""" + self.stack.append(default) try: - self.stack.append(default) yield default finally: # stack may be empty if reset() was called @@ -5048,13 +5207,15 @@ class _DefaultGraphStack(_DefaultStack): # pylint: disable=protected-access @tf_contextlib.contextmanager def get_controller(self, default): + context.context().context_switches.push( + default.building_function, default.as_default) try: - context.context().context_switches.push( - default.building_function, default.as_default) with super(_DefaultGraphStack, self).get_controller( default) as g, context.graph_mode(): yield g finally: + # If an exception is raised here it may be hiding a related exception in + # the try-block (just above). context.context().context_switches.pop() @@ -5090,6 +5251,9 @@ def init_scope(): `init_scope` will simply install a fresh graph as the default one. (3) The gradient tape is paused while the scope is active. + + Raises: + RuntimeError: if graph state is incompatible with this initialization. """ # pylint: enable=g-doc-return-or-yield,line-too-long @@ -5102,10 +5266,10 @@ def init_scope(): # the name scope of the current context. default_graph = get_default_graph() scope = default_graph.get_name_scope() - if scope and scope[-1] != '/': + if scope and scope[-1] != "/": # Names that end with trailing slashes are treated by `name_scope` as # absolute. - scope = scope + '/' + scope = scope + "/" inner_device_stack = default_graph._device_function_stack # pylint: disable=protected-access outer_context = None @@ -5150,18 +5314,21 @@ def init_scope(): outer_graph._device_function_stack = inner_device_stack # pylint: disable=protected-access yield finally: + # If an exception is raised here it may be hiding a related exception in + # try-block (just above). if outer_graph is not None: outer_graph._device_function_stack = outer_device_stack # pylint: disable=protected-access @tf_export("enable_eager_execution") -def enable_eager_execution(config=None, device_policy=None, +def enable_eager_execution(config=None, + device_policy=None, execution_mode=None): """Enables eager execution for the lifetime of this program. Eager execution provides an imperative interface to TensorFlow. With eager execution enabled, TensorFlow functions execute operations immediately (as - opposed to adding to a graph to be executed later in a @{tf.Session}) and + opposed to adding to a graph to be executed later in a `tf.Session`) and return concrete values (as opposed to symbolic references to a node in a computational graph). @@ -5181,9 +5348,9 @@ def enable_eager_execution(config=None, device_policy=None, both with and without eager execution). Args: - config: (Optional.) A @{tf.ConfigProto} to use to configure the environment - in which operations are executed. Note that @{tf.ConfigProto} is also - used to configure graph execution (via @{tf.Session}) and many options + config: (Optional.) A `tf.ConfigProto` to use to configure the environment + in which operations are executed. Note that `tf.ConfigProto` is also + used to configure graph execution (via `tf.Session`) and many options within `tf.ConfigProto` are not implemented (or are irrelevant) when eager execution is enabled. device_policy: (Optional.) Policy controlling how operations requiring @@ -5215,6 +5382,34 @@ def enable_eager_execution(config=None, device_policy=None, TensorFlow graph, or if options provided conflict with a previous call to this function. """ + return enable_eager_execution_internal( + config=config, + device_policy=device_policy, + execution_mode=execution_mode, + server_def=None) + + +def enable_eager_execution_internal(config=None, + device_policy=None, + execution_mode=None, + server_def=None): + """Enables eager execution for the lifetime of this program. + + Most of the doc string for enable_eager_execution is relevant here as well. + Args: + config: See enable_eager_execution doc string + device_policy: See enable_eager_execution doc string + execution_mode: See enable_eager_execution doc string + server_def: (Optional.) A tensorflow::ServerDef proto. + Enables execution on remote devices. GrpcServers need to be started by + creating an identical server_def to this, and setting the appropriate + task_indexes, so that the servers can communicate. It will then be + possible to execute operations on remote devices. + + Raises: + ValueError + + """ if config is not None and not isinstance(config, config_pb2.ConfigProto): raise TypeError( "config must be a tf.ConfigProto, but got %s" % type(config)) @@ -5242,7 +5437,8 @@ def enable_eager_execution(config=None, device_policy=None, context._context = context.Context( config=config, device_policy=device_policy, - execution_mode=execution_mode) + execution_mode=execution_mode, + server_def=server_def) elif ((config is not None and config is not context._context._config) or (device_policy is not None and device_policy is not context._context._device_policy) or @@ -5454,7 +5650,7 @@ class GraphKeys(object): * `GLOBAL_VARIABLES`: the default collection of `Variable` objects, shared across distributed environment (model variables are subset of these). See - @{tf.global_variables} + `tf.global_variables` for more details. Commonly, all `TRAINABLE_VARIABLES` variables will be in `MODEL_VARIABLES`, and all `MODEL_VARIABLES` variables will be in `GLOBAL_VARIABLES`. @@ -5466,19 +5662,19 @@ class GraphKeys(object): `tf.contrib.framework.model_variable` to add to this collection. * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will be trained by an optimizer. See - @{tf.trainable_variables} + `tf.trainable_variables` for more details. * `SUMMARIES`: the summary `Tensor` objects that have been created in the graph. See - @{tf.summary.merge_all} + `tf.summary.merge_all` for more details. * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to produce input for a computation. See - @{tf.train.start_queue_runners} + `tf.train.start_queue_runners` for more details. * `MOVING_AVERAGE_VARIABLES`: the subset of `Variable` objects that will also keep moving averages. See - @{tf.moving_average_variables} + `tf.moving_average_variables` for more details. * `REGULARIZATION_LOSSES`: regularization losses collected during graph construction. @@ -5592,7 +5788,7 @@ class GraphKeys(object): def add_to_collection(name, value): """Wrapper for `Graph.add_to_collection()` using the default graph. - See @{tf.Graph.add_to_collection} + See `tf.Graph.add_to_collection` for more details. Args: @@ -5611,7 +5807,7 @@ def add_to_collection(name, value): def add_to_collections(names, value): """Wrapper for `Graph.add_to_collections()` using the default graph. - See @{tf.Graph.add_to_collections} + See `tf.Graph.add_to_collections` for more details. Args: @@ -5631,7 +5827,7 @@ def add_to_collections(names, value): def get_collection_ref(key): """Wrapper for `Graph.get_collection_ref()` using the default graph. - See @{tf.Graph.get_collection_ref} + See `tf.Graph.get_collection_ref` for more details. Args: @@ -5655,7 +5851,7 @@ def get_collection_ref(key): def get_collection(key, scope=None): """Wrapper for `Graph.get_collection()` using the default graph. - See @{tf.Graph.get_collection} + See `tf.Graph.get_collection` for more details. Args: @@ -5698,7 +5894,7 @@ class name_scope(object): # pylint: disable=invalid-name This context manager validates that the given `values` are from the same graph, makes that graph the default graph, and pushes a name scope in that graph (see - @{tf.Graph.name_scope} + `tf.Graph.name_scope` for more details on that). For example, to define a new Python op called `my_op`: diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 750df4d8e3926fbaee7a38978457e448c21d64c7..318387c61b29c0ed63e6b9f77d7617ffecbea780 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import gc +import os import threading import weakref @@ -1690,7 +1691,7 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): # e should be dominated by c. self.assertEqual(e.op.control_inputs, []) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEager(self): def future(): future.calls += 1 @@ -1875,7 +1876,7 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): class OpScopeTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNames(self): with ops.name_scope("foo") as foo: self.assertEqual("foo/", foo) @@ -1906,7 +1907,7 @@ class OpScopeTest(test_util.TensorFlowTestCase): with ops.name_scope("a//b/c") as foo10: self.assertEqual("a//b/c/", foo10) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerDefaultScopeName(self): with ops.name_scope(None, "default") as scope: self.assertEqual(scope, "default/") @@ -2542,6 +2543,56 @@ class StatisticsTest(test_util.TensorFlowTestCase): self.assertEqual(3, flops_total.value) +class DeviceStackTest(test_util.TensorFlowTestCase): + + def testBasicDeviceAssignmentMetadata(self): + + def device_func(unused_op): + return "/cpu:*" + + const_zero = constant_op.constant([0.0], name="zero") + with ops.device("/cpu"): + const_one = constant_op.constant([1.0], name="one") + with ops.device("/cpu:0"): + const_two = constant_op.constant([2.0], name="two") + with ops.device(device_func): + const_three = constant_op.constant(3.0, name="three") + + self.assertEqual(0, len(const_zero.op._device_assignments)) + + one_list = const_one.op._device_assignments + self.assertEqual(1, len(one_list)) + self.assertEqual("/cpu", one_list[0].obj) + self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename)) + + two_list = const_two.op._device_assignments + self.assertEqual(2, len(two_list)) + devices = [t.obj for t in two_list] + self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices)) + + three_list = const_three.op._device_assignments + self.assertEqual(1, len(three_list)) + func_description = three_list[0].obj + expected_regex = r"device_func<.*ops_test.py, [0-9]+" + self.assertRegexpMatches(func_description, expected_regex) + + def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self): + + with ops.device("/cpu"): + const_one = constant_op.constant([1.0], name="one") + with ops.get_default_graph().device("/cpu"): + const_two = constant_op.constant([2.0], name="two") + + one_metadata = const_one.op._device_assignments[0] + two_metadata = const_two.op._device_assignments[0] + + # Verify both types of device assignment return the right stack info. + self.assertRegexpMatches("ops_test.py", + os.path.basename(one_metadata.filename)) + self.assertEqual(one_metadata.filename, two_metadata.filename) + self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno) + + class ColocationGroupTest(test_util.TensorFlowTestCase): def testBasic(self): @@ -2554,6 +2605,18 @@ class ColocationGroupTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): c.op.get_attr("_class") + def testBasicColocationMetadata(self): + const_two = constant_op.constant([2.0], name="two") + with ops.colocate_with(const_two.op): + const_three = constant_op.constant(3.0, name="three") + locations_dict = const_three.op._colocation_dict + self.assertIn("two", locations_dict) + metadata = locations_dict["two"] + self.assertIsNone(metadata.obj) + # Check that this test's filename is recorded as the file containing the + # colocation statement. + self.assertEqual("ops_test.py", os.path.basename(metadata.filename)) + def testColocationDeviceInteraction(self): with ops.device("/cpu:0"): with ops.device("/device:GPU:0"): @@ -2665,6 +2728,28 @@ class ColocationGroupTest(test_util.TensorFlowTestCase): self.assertEqual("/device:CPU:0", b.device) + def testMakeColocationConflictMessage(self): + """Test that provides an example of a complicated error message.""" + # We could test the message with any ops, but this test will be more + # instructive with a real colocation conflict. + with ops.device("/device:GPU:0"): + a = constant_op.constant([2.0], name="a") + with ops.colocate_with(a.op): + with ops.device("/cpu:0"): + b = constant_op.constant([3.0], name="b") + # The definition-location of the nodes will be wrong because of running + # from within a TF unittest. The rest of the info should be correct. + message = ops.get_default_graph()._make_colocation_conflict_message(a.op, + b.op) + self.assertRegexpMatches(message, + r"Tried to colocate op 'a' \(defined at.*\)") + self.assertRegexpMatches(message, "No node-device.*'a'") + self.assertRegexpMatches(message, "Device assignments active.*'a'") + self.assertRegexpMatches(message, "GPU:0") + self.assertRegexpMatches(message, "Node-device colocations active.*'b'") + self.assertRegexpMatches(message, "Device assignments active.*'b'") + self.assertRegexpMatches(message, "cpu:0") + class DeprecatedTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index ec3748b40ec53814f036ca3463c1840d31bc1140..76d4c2017cac46761c53cabddd4a6506e519f136 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -943,6 +943,7 @@ from tensorflow.python.framework import common_shapes as _common_shapes from tensorflow.python.framework import op_def_registry as _op_def_registry from tensorflow.python.framework import ops as _ops from tensorflow.python.framework import op_def_library as _op_def_library +from tensorflow.python.util.deprecation import deprecated_endpoints from tensorflow.python.util.tf_export import tf_export )"); diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc index 940bffb906db753f3699b6a8d2401741bc50a517..031b4a384ea23033afc5d5e1b3318ee02037842c 100644 --- a/tensorflow/python/framework/python_op_gen_internal.cc +++ b/tensorflow/python/framework/python_op_gen_internal.cc @@ -588,10 +588,12 @@ void GenPythonOp::AddExport() { return; } + // Add @tf_export decorator. strings::StrAppend(&result_, "@tf_export("); // Add all endpoint names to tf_export. bool first_endpoint = true; + std::vector deprecated_endpoints; for (const auto& endpoint : api_def_.endpoint()) { if (!first_endpoint) { strings::StrAppend(&result_, ", "); @@ -601,9 +603,32 @@ void GenPythonOp::AddExport() { string endpoint_name; python_op_gen_internal::GenerateLowerCaseOpName(endpoint.name(), &endpoint_name); + if (endpoint.deprecated()) { + deprecated_endpoints.push_back(endpoint_name); + } strings::StrAppend(&result_, "'", endpoint_name, "'"); } strings::StrAppend(&result_, ")\n"); + + // If all endpoints are deprecated, add @deprecated decorator. + if (!api_def_.deprecation_message().empty()) { + const string instructions = api_def_.deprecation_message(); + strings::StrAppend(&result_, "@deprecated(None, '", instructions, "')\n"); + } + // Add @deprecated_endpoints decorator. + if (!deprecated_endpoints.empty()) { + strings::StrAppend(&result_, "@deprecated_endpoints("); + bool first_endpoint = true; + for (auto& endpoint_name : deprecated_endpoints) { + if (first_endpoint) { + first_endpoint = false; + } else { + strings::StrAppend(&result_, ", "); + } + strings::StrAppend(&result_, "'", endpoint_name, "'"); + } + strings::StrAppend(&result_, ")\n"); + } } void GenPythonOp::AddDefLine(const string& function_name, diff --git a/tensorflow/python/framework/random_seed.py b/tensorflow/python/framework/random_seed.py index b724432e00b0d11de86a0fff9ff31758ad36479f..2f9504889afd07dd9e3fa73e3290efa4b3e0b752 100644 --- a/tensorflow/python/framework/random_seed.py +++ b/tensorflow/python/framework/random_seed.py @@ -43,7 +43,7 @@ def get_seed(op_seed): graph, or for only specific operations. For details on how the graph-level seed interacts with op seeds, see - @{tf.set_random_seed}. + `tf.set_random_seed`. Args: op_seed: integer. diff --git a/tensorflow/python/framework/random_seed_test.py b/tensorflow/python/framework/random_seed_test.py index 194492268631abfa911bd45f13a302c09a2c8bda..6696bffc6c553f3fcf458f52cb9cd386e2711ff4 100644 --- a/tensorflow/python/framework/random_seed_test.py +++ b/tensorflow/python/framework/random_seed_test.py @@ -26,7 +26,7 @@ from tensorflow.python.platform import test class RandomSeedTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRandomSeed(self): test_cases = [ # Each test case is a tuple with input to get_seed: diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py index 6a5c6468f77382b2b7e62a6a49d4fb637fed4dc0..a45581190fc1db3dd0ca2df88763f1cd6ae11537 100644 --- a/tensorflow/python/framework/sparse_tensor.py +++ b/tensorflow/python/framework/sparse_tensor.py @@ -205,7 +205,7 @@ class SparseTensor(_TensorLike): Args: feed_dict: A dictionary that maps `Tensor` objects to feed values. - See @{tf.Session.run} for a + See `tf.Session.run` for a description of the valid feed values. session: (Optional.) The `Session` to be used to evaluate this sparse tensor. If none, the default session will be used. diff --git a/tensorflow/python/framework/subscribe.py b/tensorflow/python/framework/subscribe.py index 7797d991da7c1c3a429bbf9e60772f0a1952c723..cee73989743ed123b2c9a6ab4e3021dc5d44a98f 100644 --- a/tensorflow/python/framework/subscribe.py +++ b/tensorflow/python/framework/subscribe.py @@ -47,7 +47,7 @@ def _recursive_apply(tensors, apply_fn): tensors_type = type(tensors) if tensors_type is ops.Tensor: return apply_fn(tensors) - elif tensors_type is variables.Variable: + elif isinstance(tensors, variables.Variable): return apply_fn(tensors.value()) elif isinstance(tensors, (list, tuple)): tensors = [_recursive_apply(t, apply_fn) for t in tensors] diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py index c9be3d50056b2838e8cf39c3a17e1cff14e67ea0..bd0f691a619e35a59d152046cfb1e80b74dc1f66 100644 --- a/tensorflow/python/framework/tensor_shape.py +++ b/tensorflow/python/framework/tensor_shape.py @@ -500,7 +500,7 @@ class TensorShape(object): may be inferred if there is a registered shape function for `"Foo"`. See @{$adding_an_op#shape-functions-in-c$`Shape functions in C++`} for details of shape functions and how to register them. Alternatively, - the shape may be set explicitly using @{tf.Tensor.set_shape}. + the shape may be set explicitly using `tf.Tensor.set_shape`. """ def __init__(self, dims): diff --git a/tensorflow/python/framework/tensor_spec.py b/tensorflow/python/framework/tensor_spec.py index 6676cfcaa334e02208d9ec346de7d266c4700f24..fbea930fe0e6a4545b9a5ac55c0a7684b3cd8e28 100644 --- a/tensorflow/python/framework/tensor_spec.py +++ b/tensorflow/python/framework/tensor_spec.py @@ -34,7 +34,7 @@ class TensorSpec(object): construction and configuration. """ - __slots__ = ["_shape", "_dtype", "_name"] + __slots__ = ["_shape", "_shape_tuple", "_dtype", "_name"] def __init__(self, shape, dtype, name=None): """Creates a TensorSpec. @@ -49,6 +49,10 @@ class TensorSpec(object): not convertible to a `tf.DType`. """ self._shape = tensor_shape.TensorShape(shape) + try: + self._shape_tuple = tuple(self.shape.as_list()) + except ValueError: + self._shape_tuple = None self._dtype = dtypes.as_dtype(dtype) self._name = name @@ -104,6 +108,9 @@ class TensorSpec(object): return "TensorSpec(shape={}, dtype={}, name={})".format( self.shape, repr(self.dtype), repr(self.name)) + def __hash__(self): + return hash((self._shape_tuple, self.dtype)) + def __eq__(self, other): return self.shape == other.shape and self.dtype == other.dtype diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index ca63efbc84dab20850845841e9e212a681b6bb06..b14290c203aa04a4ee77fa4f7a0ab55665bece67 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -67,10 +67,16 @@ def SlowAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values): [ExtractBitsFromBFloat16(x) for x in proto_values]) +def FastAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values): + fast_tensor_util.AppendBFloat16ArrayToTensorProto( + tensor_proto, np.asarray( + proto_values, dtype=dtypes.bfloat16.as_numpy_dtype).view(np.uint16)) + + if _FAST_TENSOR_UTIL_AVAILABLE: _NP_TO_APPEND_FN = { dtypes.bfloat16.as_numpy_dtype: - SlowAppendBFloat16ArrayToTensorProto, + FastAppendBFloat16ArrayToTensorProto, np.float16: _MediumAppendFloat16ArrayToTensorProto, np.float32: @@ -935,8 +941,10 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name def is_tensor(x): # pylint: disable=invalid-name """Check whether `x` is of tensor type. - Check whether an object is a tensor. Equivalent to - `isinstance(x, [tf.Tensor, tf.SparseTensor, tf.Variable])`. + Check whether an object is a tensor. This check is equivalent to calling + `isinstance(x, (tf.Tensor, tf.SparseTensor, tf.Variable))` and also checks + if all the component variables of a MirroredVariable or a TowerLocalVariable + are tensors. Args: x: A python object to check. @@ -944,4 +952,5 @@ def is_tensor(x): # pylint: disable=invalid-name Returns: `True` if `x` is a tensor, `False` if not. """ - return isinstance(x, ops._TensorLike) or ops.is_dense_tensor_like(x) # pylint: disable=protected-access + return (isinstance(x, ops._TensorLike) or ops.is_dense_tensor_like(x) or # pylint: disable=protected-access + (hasattr(x, "is_tensor_like") and x.is_tensor_like)) diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index 35fff80c61b98e7603d3b7b5df3cabdb59059a72..395cf43b3f189e7ed61ab4bcf479d24de801f3ef 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -50,13 +50,13 @@ class TensorUtilTest(test.TestCase): def testFloatN(self): t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0]) if sys.byteorder == "big": - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_FLOAT tensor_shape { dim { size: 3 } } tensor_content: "A \000\000A\240\000\000A\360\000\000" """, t) else: - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_FLOAT tensor_shape { dim { size: 3 } } tensor_content: "\000\000 A\000\000\240A\000\000\360A" @@ -68,13 +68,13 @@ class TensorUtilTest(test.TestCase): def testFloatTyped(self): t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=dtypes.float32) if sys.byteorder == "big": - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_FLOAT tensor_shape { dim { size: 3 } } tensor_content: "A \000\000A\240\000\000A\360\000\000" """, t) else: - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_FLOAT tensor_shape { dim { size: 3 } } tensor_content: "\000\000 A\000\000\240A\000\000\360A" @@ -86,13 +86,13 @@ class TensorUtilTest(test.TestCase): def testFloatTypeCoerce(self): t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtypes.float32) if sys.byteorder == "big": - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_FLOAT tensor_shape { dim { size: 3 } } tensor_content: "A \000\000A\240\000\000A\360\000\000" """, t) else: - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_FLOAT tensor_shape { dim { size: 3 } } tensor_content: "\000\000 A\000\000\240A\000\000\360A" @@ -105,13 +105,13 @@ class TensorUtilTest(test.TestCase): arr = np.asarray([10, 20, 30], dtype="int") t = tensor_util.make_tensor_proto(arr, dtype=dtypes.float32) if sys.byteorder == "big": - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_FLOAT tensor_shape { dim { size: 3 } } tensor_content: "A \000\000A\240\000\000A\360\000\000" """, t) else: - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_FLOAT tensor_shape { dim { size: 3 } } tensor_content: "\000\000 A\000\000\240A\000\000\360A" @@ -123,13 +123,13 @@ class TensorUtilTest(test.TestCase): def testFloatSizes(self): t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[1, 3]) if sys.byteorder == "big": - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_FLOAT tensor_shape { dim { size: 1 } dim { size: 3 } } tensor_content: "A \000\000A\240\000\000A\360\000\000" """, t) else: - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_FLOAT tensor_shape { dim { size: 1 } dim { size: 3 } } tensor_content: "\000\000 A\000\000\240A\000\000\360A" @@ -141,13 +141,13 @@ class TensorUtilTest(test.TestCase): def testFloatSizes2(self): t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[3, 1]) if sys.byteorder == "big": - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_FLOAT tensor_shape { dim { size: 3 } dim { size: 1 } } tensor_content: "A \000\000A\240\000\000A\360\000\000" """, t) else: - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_FLOAT tensor_shape { dim { size: 3 } dim { size: 1 } } tensor_content: "\000\000 A\000\000\240A\000\000\360A" @@ -169,13 +169,13 @@ class TensorUtilTest(test.TestCase): t = tensor_util.make_tensor_proto( np.array([[10.0, 20.0, 30.0]], dtype=np.float64)) if sys.byteorder == "big": - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_DOUBLE tensor_shape { dim { size: 1 } dim { size: 3 } } tensor_content: "@$\000\000\000\000\000\000@4\000\000\000\000\000\000@>\000\000\000\000\000\000" """, t) else: - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_DOUBLE tensor_shape { dim { size: 1 } dim { size: 3 } } tensor_content: "\000\000\000\000\000\000$@\000\000\000\000\000\0004@\000\000\000\000\000\000>@" @@ -206,13 +206,13 @@ class TensorUtilTest(test.TestCase): self.assertEquals(np.float32, a.dtype) self.assertAllClose(np.array([5.0, 20.0, 30.0], dtype=np.float32), a) if sys.byteorder == "big": - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_FLOAT tensor_shape { dim { size: 3 } } tensor_content: "A \000\000A\240\000\000A\360\000\000" """, t) else: - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_FLOAT tensor_shape { dim { size: 3 } } tensor_content: "\000\000 A\000\000\240A\000\000\360A" @@ -299,16 +299,16 @@ class TensorUtilTest(test.TestCase): def testIntNDefaultType(self): t = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2]) if sys.byteorder == "big": - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_INT32 tensor_shape { dim { size: 2 } dim { size: 2 } } - tensor_content: "\000\000\000\\n\000\000\000\024\000\000\000\036\000\000\000(" + tensor_content: "\000\000\000\n\000\000\000\024\000\000\000\036\000\000\000(" """, t) else: - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_INT32 tensor_shape { dim { size: 2 } dim { size: 2 } } - tensor_content: "\\n\000\000\000\024\000\000\000\036\000\000\000(\000\000\000" + tensor_content: "\n\000\000\000\024\000\000\000\036\000\000\000(\000\000\000" """, t) a = tensor_util.MakeNdarray(t) self.assertEquals(np.int32, a.dtype) @@ -380,16 +380,16 @@ class TensorUtilTest(test.TestCase): t = tensor_util.make_tensor_proto( [10, 20, 30], shape=[1, 3], dtype=dtypes.int64) if sys.byteorder == "big": - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_INT64 tensor_shape { dim { size: 1 } dim { size: 3 } } - tensor_content: "\000\000\000\000\000\000\000\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036" + tensor_content: "\000\000\000\000\000\000\000\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036" """, t) else: - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_INT64 tensor_shape { dim { size: 1 } dim { size: 3 } } - tensor_content: "\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000" + tensor_content: "\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000" """, t) a = tensor_util.MakeNdarray(t) self.assertEquals(np.int64, a.dtype) @@ -398,16 +398,16 @@ class TensorUtilTest(test.TestCase): def testLongNpArray(self): t = tensor_util.make_tensor_proto(np.array([10, 20, 30])) if sys.byteorder == "big": - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_INT64 tensor_shape { dim { size: 3 } } - tensor_content: "\000\000\000\000\000\000\000\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036" + tensor_content: "\000\000\000\000\000\000\000\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036" """, t) else: - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_INT64 tensor_shape { dim { size: 3 } } - tensor_content: "\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000" + tensor_content: "\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000" """, t) a = tensor_util.MakeNdarray(t) self.assertEquals(np.int64, a.dtype) @@ -419,13 +419,13 @@ class TensorUtilTest(test.TestCase): t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint32) if sys.byteorder == "big": - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_QINT32 tensor_shape { dim { size: 3 } } tensor_content: "\000\000\000\025\000\000\000\026\000\000\000\027" """, t) else: - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_QINT32 tensor_shape { dim { size: 3 } } tensor_content: "\025\000\000\000\026\000\000\000\027\000\000\000" @@ -435,7 +435,7 @@ class TensorUtilTest(test.TestCase): self.assertAllEqual(np.array(data, dtype=a.dtype), a) t = tensor_util.make_tensor_proto(data, dtype=dtypes.quint8) - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_QUINT8 tensor_shape { dim { size: 3 } } tensor_content: "\025\026\027" @@ -445,7 +445,7 @@ class TensorUtilTest(test.TestCase): self.assertAllEqual(np.array(data, dtype=a.dtype), a) t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint8) - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_QINT8 tensor_shape { dim { size: 3 } } tensor_content: "\025\026\027" @@ -456,13 +456,13 @@ class TensorUtilTest(test.TestCase): t = tensor_util.make_tensor_proto(data, dtype=dtypes.quint16) if sys.byteorder == "big": - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_QUINT16 tensor_shape { dim { size: 3 } } tensor_content: "\000\025\000\026\000\027" """, t) else: - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_QUINT16 tensor_shape { dim { size: 3 } } tensor_content: "\025\000\026\000\027\000" @@ -473,13 +473,13 @@ class TensorUtilTest(test.TestCase): t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint16) if sys.byteorder == "big": - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_QINT16 tensor_shape { dim { size: 3 } } tensor_content: "\000\025\000\026\000\027" """, t) else: - self.assertProtoEquals(""" + self.assertProtoEquals(r""" dtype: DT_QINT16 tensor_shape { dim { size: 3 } } tensor_content: "\025\000\026\000\027\000" @@ -941,7 +941,7 @@ class ConstantValueTest(test.TestCase): class ConstantValueAsShapeTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConstant(self): np_val = np.random.rand(3).astype(np.int32) tf_val = constant_op.constant(np_val) @@ -954,13 +954,13 @@ class ConstantValueAsShapeTest(test.TestCase): tensor_shape.TensorShape([]), tensor_util.constant_value_as_shape(tf_val)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testShape(self): tf_val = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3])) c_val = tensor_util.constant_value_as_shape(tf_val) self.assertEqual(tensor_shape.TensorShape([1, 2, 3]), c_val) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMinusOneBecomesNone(self): tf_val = constant_op.constant([-1, 1, -1], shape=[3]) c_val = tensor_util.constant_value_as_shape(tf_val) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 5582b14249599f96453c0686fd89d56d4985531d..9be6391b0436b9508d0d6930b5fd5de560c05dd3 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -19,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections +from collections import OrderedDict import contextlib import gc import itertools @@ -27,6 +29,7 @@ import random import re import tempfile import threading +import unittest import numpy as np import six @@ -48,7 +51,6 @@ from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.client import device_lib from tensorflow.python.client import session -from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import tape # pylint: disable=unused-import from tensorflow.python.framework import device as pydev @@ -61,13 +63,13 @@ from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops -from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib from tensorflow.python.util import compat from tensorflow.python.util import nest +from tensorflow.python.util import tf_inspect from tensorflow.python.util.protobuf import compare from tensorflow.python.util.tf_export import tf_export @@ -414,8 +416,28 @@ def assert_no_new_pyobjects_executing_eagerly(f): f(self, **kwargs) gc.collect() previous_count = len(gc.get_objects()) + collection_sizes_before = { + collection: len(ops.get_collection(collection)) + for collection in ops.get_default_graph().collections} for _ in range(3): f(self, **kwargs) + # Note that gc.get_objects misses anything that isn't subject to garbage + # collection (C types). Collections are a common source of leaks, so we + # test for collection sizes explicitly. + for collection_key in ops.get_default_graph().collections: + collection = ops.get_collection(collection_key) + size_before = collection_sizes_before.get(collection_key, 0) + if len(collection) > size_before: + raise AssertionError( + ("Collection %s increased in size from " + "%d to %d (current items %s).") + % (collection_key, size_before, len(collection), collection)) + # Make sure our collection checks don't show up as leaked memory by + # removing references to temporary variables. + del collection + del collection_key + del size_before + del collection_sizes_before gc.collect() # There should be no new Python objects hanging around. new_count = len(gc.get_objects()) @@ -475,9 +497,7 @@ def assert_no_new_tensors(f): f(self, **kwargs) # Make an effort to clear caches, which would otherwise look like leaked # Tensors. - backprop._zeros_cache.flush() - context.get_default_context().ones_rank_cache().flush() - context.get_default_context().scalar_cache().clear() + context.get_default_context()._clear_caches() # pylint: disable=protected-access gc.collect() tensors_after = [ obj for obj in gc.get_objects() @@ -550,16 +570,88 @@ def assert_no_garbage_created(f): return decorator +def _combine_named_parameters(**kwargs): + """Generate combinations based on its keyword arguments. + + Two sets of returned combinations can be concatenated using +. Their product + can be computed using `times()`. + + Args: + **kwargs: keyword arguments of form `option=[possibilities, ...]` + or `option=the_only_possibility`. + + Returns: + a list of dictionaries for each combination. Keys in the dictionaries are + the keyword argument names. Each key has one value - one of the + corresponding keyword argument values. + """ + if not kwargs: + return [OrderedDict()] + + sort_by_key = lambda k: k[0][0] + kwargs = OrderedDict(sorted(kwargs.items(), key=sort_by_key)) + first = list(kwargs.items())[0] + + rest = dict(list(kwargs.items())[1:]) + rest_combined = _combine_named_parameters(**rest) + + key = first[0] + values = first[1] + if not isinstance(values, list): + values = [values] + + combinations = [ + OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key)) + for v in values + for combined in rest_combined + ] + return combinations + + +def generate_combinations_with_testcase_name(**kwargs): + """Generate combinations based on its keyword arguments using combine(). + + This function calls combine() and appends a testcase name to the list of + dictionaries returned. The 'testcase_name' key is a required for named + parameterized tests. + + Args: + **kwargs: keyword arguments of form `option=[possibilities, ...]` + or `option=the_only_possibility`. + + Returns: + a list of dictionaries for each combination. Keys in the dictionaries are + the keyword argument names. Each key has one value - one of the + corresponding keyword argument values. + """ + combinations = _combine_named_parameters(**kwargs) + named_combinations = [] + for combination in combinations: + assert isinstance(combination, OrderedDict) + name = "".join([ + "_{}_{}".format( + "".join(filter(str.isalnum, key)), + "".join(filter(str.isalnum, str(value)))) + for key, value in combination.items() + ]) + named_combinations.append( + OrderedDict( + list(combination.items()) + [("testcase_name", + "_test{}".format(name))])) + + return named_combinations + + def run_all_in_graph_and_eager_modes(cls): """Execute all test methods in the given class with and without eager.""" - base_decorator = run_in_graph_and_eager_modes() + base_decorator = run_in_graph_and_eager_modes for name, value in cls.__dict__.copy().items(): if callable(value) and name.startswith("test"): setattr(cls, name, base_decorator(value)) return cls -def run_in_graph_and_eager_modes(__unused__=None, +def run_in_graph_and_eager_modes(func=None, config=None, use_gpu=True, reset_test=True, @@ -567,17 +659,17 @@ def run_in_graph_and_eager_modes(__unused__=None, """Execute the decorated test with and without enabling eager execution. This function returns a decorator intended to be applied to test methods in - a @{tf.test.TestCase} class. Doing so will cause the contents of the test + a `tf.test.TestCase` class. Doing so will cause the contents of the test method to be executed twice - once normally, and once with eager execution enabled. This allows unittests to confirm the equivalence between eager - and graph execution (see @{tf.enable_eager_execution}). + and graph execution (see `tf.enable_eager_execution`). For example, consider the following unittest: ```python class MyTests(tf.test.TestCase): - @run_in_graph_and_eager_modes() + @run_in_graph_and_eager_modes def test_foo(self): x = tf.constant([1, 2]) y = tf.constant([3, 4]) @@ -594,7 +686,9 @@ def run_in_graph_and_eager_modes(__unused__=None, Args: - __unused__: Prevents silently skipping tests. + func: function to be annotated. If `func` is None, this method returns a + decorator the can be applied to a function. If `func` is not None this + returns the decorator applied to `func`. config: An optional config_pb2.ConfigProto to use to configure the session when executing graphs. use_gpu: If True, attempt to run as many operations as possible on GPU. @@ -616,20 +710,19 @@ def run_in_graph_and_eager_modes(__unused__=None, eager execution enabled. """ - assert not __unused__, "Add () after run_in_graph_and_eager_modes." - def decorator(f): - def decorated(self, **kwargs): - with context.graph_mode(): - with self.test_session(use_gpu=use_gpu): - f(self, **kwargs) + if tf_inspect.isclass(f): + raise ValueError( + "`run_test_in_graph_and_eager_modes` only supports test methods. " + "Did you mean to use `run_all_tests_in_graph_and_eager_modes`?") - if reset_test: - # This decorator runs the wrapped test twice. - # Reset the test environment between runs. - self.tearDown() - self._tempdir = None - self.setUp() + def decorated(self, **kwargs): + try: + with context.graph_mode(): + with self.test_session(use_gpu=use_gpu, config=config): + f(self, **kwargs) + except unittest.case.SkipTest: + pass def run_eagerly(self, **kwargs): if not use_gpu: @@ -644,10 +737,20 @@ def run_in_graph_and_eager_modes(__unused__=None, assert_no_garbage_created(run_eagerly)) with context.eager_mode(): + if reset_test: + # This decorator runs the wrapped test twice. + # Reset the test environment between runs. + self.tearDown() + self._tempdir = None + self.setUp() + run_eagerly(self, **kwargs) return decorated + if func is not None: + return decorator(func) + return decorator @@ -830,14 +933,13 @@ class TensorFlowTestCase(googletest.TestCase): def _eval_tensor(self, tensor): if tensor is None: return None - elif isinstance(tensor, ops.EagerTensor): - return tensor.numpy() - elif isinstance(tensor, resource_variable_ops.ResourceVariable): - return tensor.read_value().numpy() elif callable(tensor): return self._eval_helper(tensor()) else: - raise ValueError("Unsupported type %s." % type(tensor)) + try: + return tensor.numpy() + except AttributeError as e: + six.raise_from(ValueError("Unsupported type %s." % type(tensor)), e) def _eval_helper(self, tensors): if tensors is None: @@ -1196,8 +1298,8 @@ class TensorFlowTestCase(googletest.TestCase): a = a._asdict() if hasattr(b, "_asdict"): b = b._asdict() - a_is_dict = isinstance(a, dict) - if a_is_dict != isinstance(b, dict): + a_is_dict = isinstance(a, collections.Mapping) + if a_is_dict != isinstance(b, collections.Mapping): raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" % (path_str, path_str, msg)) if a_is_dict: diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index 0178908bcc9c0613353e3beea8e1eb11638f9531..f983cbef0415de98fd6cae717f12fbcde3fe78bd 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -73,7 +73,7 @@ class TestUtilTest(test_util.TensorFlowTestCase): test_util.assert_equal_graph_def(def_57, def_75) # Compare two unequal graphs with self.assertRaisesRegexp(AssertionError, - r"^Found unexpected node 'seven"): + r"^Found unexpected node '{{node seven}}"): test_util.assert_equal_graph_def(def_57, def_empty) def testIsGoogleCudaEnabled(self): @@ -569,7 +569,7 @@ class TestUtilTest(test_util.TensorFlowTestCase): self.assertEqual(a_np_rand, b_np_rand) self.assertEqual(a_rand, b_rand) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_callable_evaluate(self): def model(): return resource_variable_ops.ResourceVariable( @@ -578,7 +578,7 @@ class TestUtilTest(test_util.TensorFlowTestCase): with context.eager_mode(): self.assertEqual(2, self.evaluate(model)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_nested_tensors_evaluate(self): expected = {"a": 1, "b": 2, "nested": {"d": 3, "e": 4}} nested = {"a": constant_op.constant(1), @@ -588,6 +588,27 @@ class TestUtilTest(test_util.TensorFlowTestCase): self.assertEqual(expected, self.evaluate(nested)) + def test_run_in_graph_and_eager_modes(self): + l = [] + def inc(self, with_brackets): + del self # self argument is required by run_in_graph_and_eager_modes. + mode = "eager" if context.executing_eagerly() else "graph" + with_brackets = "with_brackets" if with_brackets else "without_brackets" + l.append((with_brackets, mode)) + + f = test_util.run_in_graph_and_eager_modes(inc) + f(self, with_brackets=False) + f = test_util.run_in_graph_and_eager_modes()(inc) + f(self, with_brackets=True) + + self.assertEqual(len(l), 4) + self.assertEqual(set(l), { + ("with_brackets", "graph"), + ("with_brackets", "eager"), + ("without_brackets", "graph"), + ("without_brackets", "eager"), + }) + def test_get_node_def_from_graph(self): graph_def = graph_pb2.GraphDef() node_foo = graph_def.node.add() @@ -595,6 +616,55 @@ class TestUtilTest(test_util.TensorFlowTestCase): self.assertIs(test_util.get_node_def_from_graph("foo", graph_def), node_foo) self.assertIsNone(test_util.get_node_def_from_graph("bar", graph_def)) + def test_run_in_eager_and_graph_modes_test_class(self): + msg = "`run_test_in_graph_and_eager_modes` only supports test methods.*" + with self.assertRaisesRegexp(ValueError, msg): + @test_util.run_in_graph_and_eager_modes() + class Foo(object): + pass + del Foo # Make pylint unused happy. + + def test_run_in_eager_and_graph_modes_skip_graph_runs_eager(self): + modes = [] + def _test(self): + if not context.executing_eagerly(): + self.skipTest("Skipping in graph mode") + modes.append("eager" if context.executing_eagerly() else "graph") + test_util.run_in_graph_and_eager_modes(_test)(self) + self.assertEqual(modes, ["eager"]) + + def test_run_in_eager_and_graph_modes_skip_eager_runs_graph(self): + modes = [] + def _test(self): + if context.executing_eagerly(): + self.skipTest("Skipping in eager mode") + modes.append("eager" if context.executing_eagerly() else "graph") + test_util.run_in_graph_and_eager_modes(_test)(self) + self.assertEqual(modes, ["graph"]) + + def test_run_in_graph_and_eager_modes_setup_in_same_mode(self): + modes = [] + mode_name = lambda: "eager" if context.executing_eagerly() else "graph" + + class ExampleTest(test_util.TensorFlowTestCase): + + def runTest(self): + pass + + def setUp(self): + modes.append("setup_" + mode_name()) + + @test_util.run_in_graph_and_eager_modes + def testBody(self): + modes.append("run_" + mode_name()) + + e = ExampleTest() + e.setUp() + e.testBody() + + self.assertEqual(modes[0:2], ["setup_graph", "run_graph"]) + self.assertEqual(modes[2:], ["setup_eager", "run_eager"]) + class GarbageCollectionTest(test_util.TensorFlowTestCase): @@ -619,7 +689,7 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase): ReferenceCycleTest().test_has_no_cycle() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_no_leaked_tensor_decorator(self): class LeakedTensorTest(object): diff --git a/tensorflow/python/framework/traceable_stack.py b/tensorflow/python/framework/traceable_stack.py new file mode 100644 index 0000000000000000000000000000000000000000..7f4d28237ffba80e5aa604b880fccf00482a9ca5 --- /dev/null +++ b/tensorflow/python/framework/traceable_stack.py @@ -0,0 +1,132 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A simple stack that associates filename and line numbers with each object.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.util import tf_stack + + +class TraceableObject(object): + """Wrap an object together with its the code definition location.""" + + # Return codes for the set_filename_and_line_from_caller() method. + SUCCESS, HEURISTIC_USED, FAILURE = (0, 1, 2) + + def __init__(self, obj, filename=None, lineno=None): + self.obj = obj + self.filename = filename + self.lineno = lineno + + def set_filename_and_line_from_caller(self, offset=0): + """Set filename and line using the caller's stack frame. + + If the requested stack information is not available, a heuristic may + be applied and self.HEURISTIC USED will be returned. If the heuristic + fails then no change will be made to the filename and lineno members + (None by default) and self.FAILURE will be returned. + + Args: + offset: Integer. If 0, the caller's stack frame is used. If 1, + the caller's caller's stack frame is used. Larger values are + permissible but if out-of-range (larger than the number of stack + frames available) the outermost stack frame will be used. + + Returns: + TraceableObject.SUCCESS if appropriate stack information was found, + TraceableObject.HEURISTIC_USED if the offset was larger than the stack, + and TraceableObject.FAILURE if the stack was empty. + """ + # Offset is defined in "Args" as relative to the caller. We are one frame + # beyond the caller. + local_offset = offset + 1 + + frame_records = tf_stack.extract_stack() + if not frame_records: + return self.FAILURE + if len(frame_records) >= local_offset: + # Negative indexing is one-indexed instead of zero-indexed. + negative_offset = -(local_offset + 1) + self.filename, self.lineno = frame_records[negative_offset][:2] + return self.SUCCESS + else: + # If the offset is too large then we use the largest offset possible, + # meaning we use the outermost stack frame at index 0. + self.filename, self.lineno = frame_records[0][:2] + return self.HEURISTIC_USED + + def copy_metadata(self): + """Return a TraceableObject like this one, but without the object.""" + return self.__class__(None, filename=self.filename, lineno=self.lineno) + + +class TraceableStack(object): + """A stack of TraceableObjects.""" + + def __init__(self, existing_stack=None): + """Constructor. + + Args: + existing_stack: [TraceableObject, ...] If provided, this object will + set its new stack to a SHALLOW COPY of existing_stack. + """ + self._stack = existing_stack[:] if existing_stack else [] + + def push_obj(self, obj, offset=0): + """Add object to the stack and record its filename and line information. + + Args: + obj: An object to store on the stack. + offset: Integer. If 0, the caller's stack frame is used. If 1, + the caller's caller's stack frame is used. + + Returns: + TraceableObject.SUCCESS if appropriate stack information was found, + TraceableObject.HEURISTIC_USED if the stack was smaller than expected, + and TraceableObject.FAILURE if the stack was empty. + """ + traceable_obj = TraceableObject(obj) + self._stack.append(traceable_obj) + # Offset is defined in "Args" as relative to the caller. We are 1 frame + # beyond the caller and need to compensate. + return traceable_obj.set_filename_and_line_from_caller(offset + 1) + + def pop_obj(self): + """Remove last-inserted object and return it, without filename/line info.""" + return self._stack.pop().obj + + def peek_objs(self): + """Return list of stored objects ordered newest to oldest.""" + return [t_obj.obj for t_obj in reversed(self._stack)] + + def peek_traceable_objs(self): + """Return list of stored TraceableObjects ordered newest to oldest.""" + return list(reversed(self._stack)) + + def __len__(self): + """Return number of items on the stack, and used for truth-value testing.""" + return len(self._stack) + + def copy(self): + """Return a copy of self referencing the same objects but in a new list. + + This method is implemented to support thread-local stacks. + + Returns: + TraceableStack with a new list that holds existing objects. + """ + return TraceableStack(self._stack) diff --git a/tensorflow/python/framework/traceable_stack_test.py b/tensorflow/python/framework/traceable_stack_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3e7876f6318da368a373ca554e674a21b0d869c3 --- /dev/null +++ b/tensorflow/python/framework/traceable_stack_test.py @@ -0,0 +1,133 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.python.framework.traceable_stack.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import test_util +from tensorflow.python.framework import traceable_stack +from tensorflow.python.platform import googletest +from tensorflow.python.util import tf_inspect as inspect + +_LOCAL_OBJECT = lambda x: x +_THIS_FILENAME = inspect.getsourcefile(_LOCAL_OBJECT) + + +class TraceableObjectTest(test_util.TensorFlowTestCase): + + def testSetFilenameAndLineFromCallerUsesCallersStack(self): + t_obj = traceable_stack.TraceableObject(17) + + # Do not separate placeholder from the set_filename_and_line_from_caller() + # call one line below it as it is used to calculate the latter's line + # number. + placeholder = lambda x: x + result = t_obj.set_filename_and_line_from_caller() + + expected_lineno = inspect.getsourcelines(placeholder)[1] + 1 + self.assertEqual(expected_lineno, t_obj.lineno) + self.assertEqual(_THIS_FILENAME, t_obj.filename) + self.assertEqual(t_obj.SUCCESS, result) + + def testSetFilenameAndLineFromCallerRespectsOffset(self): + + def call_set_filename_and_line_from_caller(t_obj): + # We expect to retrieve the line number from _our_ caller. + return t_obj.set_filename_and_line_from_caller(offset=1) + + t_obj = traceable_stack.TraceableObject(None) + # Do not separate placeholder from the + # call_set_filename_and_line_from_caller() call one line below it as it is + # used to calculate the latter's line number. + placeholder = lambda x: x + result = call_set_filename_and_line_from_caller(t_obj) + + expected_lineno = inspect.getsourcelines(placeholder)[1] + 1 + self.assertEqual(expected_lineno, t_obj.lineno) + self.assertEqual(t_obj.SUCCESS, result) + + def testSetFilenameAndLineFromCallerHandlesRidiculousOffset(self): + t_obj = traceable_stack.TraceableObject('The quick brown fox.') + # This line shouldn't die. + result = t_obj.set_filename_and_line_from_caller(offset=300) + + # We expect a heuristic to be used because we are not currently 300 frames + # down on the stack. The filename and lineno of the outermost frame are not + # predictable -- in some environments the filename is this test file, but in + # other environments it is not (e.g. due to a test runner calling this + # file). Therefore we only test that the called function knows it applied a + # heuristic for the ridiculous stack offset. + self.assertEqual(t_obj.HEURISTIC_USED, result) + + +class TraceableStackTest(test_util.TensorFlowTestCase): + + def testPushPeekPopObj(self): + t_stack = traceable_stack.TraceableStack() + t_stack.push_obj(42.0) + t_stack.push_obj('hope') + + expected_lifo_peek = ['hope', 42.0] + self.assertEqual(expected_lifo_peek, t_stack.peek_objs()) + + self.assertEqual('hope', t_stack.pop_obj()) + self.assertEqual(42.0, t_stack.pop_obj()) + + def testPushPopPreserveLifoOrdering(self): + t_stack = traceable_stack.TraceableStack() + t_stack.push_obj(0) + t_stack.push_obj(1) + t_stack.push_obj(2) + t_stack.push_obj(3) + + obj_3 = t_stack.pop_obj() + obj_2 = t_stack.pop_obj() + obj_1 = t_stack.pop_obj() + obj_0 = t_stack.pop_obj() + + self.assertEqual(3, obj_3) + self.assertEqual(2, obj_2) + self.assertEqual(1, obj_1) + self.assertEqual(0, obj_0) + + def testPushObjSetsFilenameAndLineInfoForCaller(self): + t_stack = traceable_stack.TraceableStack() + + # We expect that the line number recorded for the 1-object will come from + # the call to t_stack.push_obj(1). Do not separate the next two lines! + placeholder_1 = lambda x: x + t_stack.push_obj(1) + + # We expect that the line number recorded for the 2-object will come from + # the call to call_push_obj() and _not_ the call to t_stack.push_obj(). + def call_push_obj(obj): + t_stack.push_obj(obj, offset=1) + + # Do not separate the next two lines! + placeholder_2 = lambda x: x + call_push_obj(2) + + expected_lineno_1 = inspect.getsourcelines(placeholder_1)[1] + 1 + expected_lineno_2 = inspect.getsourcelines(placeholder_2)[1] + 1 + + t_obj_2, t_obj_1 = t_stack.peek_traceable_objs() + self.assertEqual(expected_lineno_2, t_obj_2.lineno) + self.assertEqual(expected_lineno_1, t_obj_1.lineno) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index af5d709f7e936e0438d5e03f60b44bc0017cb4b6..8cc971c61d5964d0fad1bfa843c3ef8d3407599f 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -158,6 +158,7 @@ def _get_config(layout_optimizer=True): layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF, # do not remove duplicated nodes arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF) + rewrite_options.min_graph_nodes = -1 graph_options = config_pb2.GraphOptions( rewrite_options=rewrite_options, build_cost_model=1) config = config_pb2.ConfigProto(graph_options=graph_options) @@ -1339,7 +1340,7 @@ class LayoutOptimizerTest(test.TestCase): expected_num_transposes = 2 self.assertEqual(expected_num_transposes, num_transposes) self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) - self.assertAllEqual(output_val_ref, output_val) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) def testLoop(self): if test.is_gpu_available(cuda_only=True): @@ -1443,7 +1444,8 @@ class LayoutOptimizerTest(test.TestCase): def testGradient(self): meta_graph = _simple_metagraph() rewrite_options = rewriter_config_pb2.RewriterConfig( - layout_optimizer=rewriter_config_pb2.RewriterConfig.ON) + layout_optimizer=rewriter_config_pb2.RewriterConfig.ON, + min_graph_nodes=-1) optimized_graph = tf_optimizer.OptimizeGraph( rewrite_options, meta_graph, cluster=_get_cluster()) @@ -1457,7 +1459,8 @@ class LayoutOptimizerTest(test.TestCase): def testDepthwise(self): meta_graph = _simple_metagraph(depthwise=True) rewrite_options = rewriter_config_pb2.RewriterConfig( - layout_optimizer=rewriter_config_pb2.RewriterConfig.ON) + layout_optimizer=rewriter_config_pb2.RewriterConfig.ON, + min_graph_nodes=-1) optimized_graph = tf_optimizer.OptimizeGraph( rewrite_options, meta_graph, cluster=_get_cluster()) diff --git a/tensorflow/python/grappler/memory_optimizer_test.py b/tensorflow/python/grappler/memory_optimizer_test.py index 7ed4b128e495c484d294ece40541427f21856cf1..b658edff2dffac9856432c575b9af0d2f0b1986b 100644 --- a/tensorflow/python/grappler/memory_optimizer_test.py +++ b/tensorflow/python/grappler/memory_optimizer_test.py @@ -76,7 +76,8 @@ class MemoryOptimizerSwapTest(test.TestCase): disable_model_pruning=True, meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE, constant_folding=rewriter_config_pb2.RewriterConfig.OFF, - memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL) + memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL, + min_graph_nodes=-1) graph = tf_optimizer.OptimizeGraph(rewriter_config, mg) self.assertEqual(len(graph.node), graph_size + 2) @@ -133,6 +134,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase): dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF, layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF, arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + min_graph_nodes=-1, memory_optimization=rewriter_config_pb2.RewriterConfig. RECOMPUTATION_HEURISTICS), original_metagraph) self.assertGreater( @@ -158,6 +160,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase): dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF, layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF, arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + min_graph_nodes=-1, memory_optimization=rewriter_config_pb2.RewriterConfig. RECOMPUTATION_HEURISTICS, # Checks that name scope "gradients/" also match sub-scope. @@ -297,6 +300,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase): if 'Recomputed/' in node.name])) rewritten_graph_def = tf_optimizer.OptimizeGraph( rewriter_config_pb2.RewriterConfig( + min_graph_nodes=-1, memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL), metagraph) self.assertEqual( diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py index 1c0f072dd32d38f048cfa48d38b45264951d095e..5a9afe725753749ea42d53382731ab14a3cf24f5 100644 --- a/tensorflow/python/grappler/tf_optimizer_test.py +++ b/tensorflow/python/grappler/tf_optimizer_test.py @@ -47,6 +47,7 @@ class PyWrapOptimizeGraphTest(test.TestCase): rewriter_config = rewriter_config_pb2.RewriterConfig() rewriter_config.optimizers.append('constfold') + rewriter_config.min_graph_nodes = -1 graph = tf_optimizer.OptimizeGraph(rewriter_config, mg) @@ -68,6 +69,7 @@ class PyWrapOptimizeGraphTest(test.TestCase): # Optimize the graph. mg = meta_graph.create_meta_graph_def(graph=g) rewriter_config = rewriter_config_pb2.RewriterConfig() + rewriter_config.min_graph_nodes = -1 optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg) # Check that the nodes referenced in various collections have been preserved @@ -109,6 +111,7 @@ class PyWrapOptimizeGraphTest(test.TestCase): # Optimize the graph. mg = meta_graph.create_meta_graph_def(graph=g) rewriter_config = rewriter_config_pb2.RewriterConfig() + rewriter_config.min_graph_nodes = -1 optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg) mg.graph_def.CopyFrom(optimized_graph) diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index fe40c9fbed7c041ad6b6dc8cdb1c50b80f57a48f..7eb7884d1d4542fa7b76fc2f8d4267eb09808441 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -25,6 +25,7 @@ py_library( "applications/inception_resnet_v2.py", "applications/inception_v3.py", "applications/mobilenet.py", + "applications/mobilenet_v2.py", "applications/nasnet.py", "applications/resnet50.py", "applications/vgg16.py", @@ -39,6 +40,7 @@ py_library( "datasets/imdb.py", "datasets/mnist.py", "datasets/reuters.py", + "estimator/__init__.py", "preprocessing/__init__.py", "preprocessing/image.py", "preprocessing/sequence.py", @@ -113,12 +115,14 @@ py_library( "constraints.py", "engine/__init__.py", "engine/base_layer.py", + "engine/distributed_training_utils.py", "engine/input_layer.py", "engine/network.py", "engine/saving.py", "engine/sequential.py", "engine/training.py", "engine/training_arrays.py", + "engine/training_distributed.py", "engine/training_eager.py", "engine/training_generator.py", "engine/training_utils.py", @@ -135,7 +139,7 @@ py_library( deps = [ ":backend", "//tensorflow/python/data", - "//tensorflow/python/training/checkpointable:data_structures_base", + "//tensorflow/python/training/checkpointable:data_structures", "@six_archive//:six", ], ) @@ -292,109 +296,15 @@ py_test( ) py_test( - name = "densenet_test", - size = "large", - srcs = ["applications/densenet_test.py"], - srcs_version = "PY2AND3", - tags = ["nomsan"], # times out, http://b/78650237 - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "inception_resnet_v2_test", - size = "medium", - srcs = ["applications/inception_resnet_v2_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "inception_v3_test", - size = "medium", - srcs = ["applications/inception_v3_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "mobilenet_test", - size = "medium", - srcs = ["applications/mobilenet_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "nasnet_test", - size = "large", - srcs = ["applications/nasnet_test.py"], - srcs_version = "PY2AND3", - tags = ["nomsan"], # times out, http://b/78573625 - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "resnet50_test", - size = "medium", - srcs = ["applications/resnet50_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "vgg16_test", - size = "small", - srcs = ["applications/vgg16_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "vgg19_test", - size = "small", - srcs = ["applications/vgg19_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "xception_test", - size = "medium", - srcs = ["applications/xception_test.py"], + name = "applications_test", + size = "enormous", + srcs = ["applications/applications_test.py"], + shard_count = 2, srcs_version = "PY2AND3", deps = [ ":keras", "//tensorflow/python:client_testlib", - "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -450,6 +360,7 @@ cuda_py_test( "//tensorflow/python:client_testlib", ], shard_count = 2, + tags = ["no_windows_gpu"], ) py_test( @@ -549,7 +460,7 @@ py_test( py_test( name = "gru_test", - size = "medium", + size = "large", srcs = ["layers/gru_test.py"], srcs_version = "PY2AND3", tags = ["notsan"], # http://b/62136390 @@ -702,15 +613,14 @@ cuda_py_test( ], ) -py_test( - name = "imagenet_utils_test", +cuda_py_test( + name = "training_gpu_test", size = "small", - srcs = ["applications/imagenet_utils_test.py"], - srcs_version = "PY2AND3", - deps = [ + srcs = ["engine/training_gpu_test.py"], + additional_deps = [ ":keras", - "//tensorflow/python:client_testlib", "//third_party/py/numpy", + "//tensorflow/python:client_testlib", ], ) @@ -765,7 +675,7 @@ py_test( py_test( name = "training_test", - size = "medium", + size = "large", srcs = ["engine/training_test.py"], srcs_version = "PY2AND3", tags = ["notsan"], @@ -789,6 +699,19 @@ py_test( ], ) +py_test( + name = "training_utils_test", + size = "medium", + srcs = ["engine/training_utils_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":keras", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + py_test( name = "model_subclassing_test", size = "medium", @@ -832,19 +755,20 @@ py_test( py_test( name = "sequential_test", - size = "small", + size = "medium", srcs = ["engine/sequential_test.py"], srcs_version = "PY2AND3", deps = [ ":keras", "//tensorflow/python:client_testlib", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) py_test( name = "models_test", - size = "small", + size = "medium", srcs = ["models_test.py"], srcs_version = "PY2AND3", tags = ["notsan"], # b/67509773 @@ -858,7 +782,7 @@ py_test( py_test( name = "backend_test", - size = "small", + size = "medium", srcs = ["backend_test.py"], srcs_version = "PY2AND3", deps = [ @@ -866,6 +790,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:util", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/python/keras/__init__.py b/tensorflow/python/keras/__init__.py index 3493069a5bf53ffbfe6447f2c1b3df7ac64cbf3a..198c66d9e184c82423e529540b92ad447b947cf8 100644 --- a/tensorflow/python/keras/__init__.py +++ b/tensorflow/python/keras/__init__.py @@ -27,6 +27,7 @@ from tensorflow.python.keras import backend from tensorflow.python.keras import callbacks from tensorflow.python.keras import constraints from tensorflow.python.keras import datasets +from tensorflow.python.keras import estimator from tensorflow.python.keras import initializers from tensorflow.python.keras import layers from tensorflow.python.keras import losses diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py index f608dea430f0573503713f0cbc60f8921e6df51e..99645de736fc9e3f34c3ea29171cde0f91d8345a 100644 --- a/tensorflow/python/keras/activations.py +++ b/tensorflow/python/keras/activations.py @@ -128,20 +128,26 @@ def softsign(x): @tf_export('keras.activations.relu') -def relu(x, alpha=0., max_value=None): +def relu(x, alpha=0., max_value=None, threshold=0): """Rectified Linear Unit. + With default values, it returns element-wise `max(x, 0)`. + + Otherwise, it follows: + `f(x) = max_value` for `x >= max_value`, + `f(x) = x` for `threshold <= x < max_value`, + `f(x) = alpha * (x - threshold)` otherwise. + Arguments: - x: Input tensor. - alpha: Slope of the negative part. Defaults to zero. - max_value: Maximum value for the output. + x: A tensor or variable. + alpha: A scalar, slope of negative section (default=`0.`). + max_value: float. Saturation threshold. + threshold: float. Threshold value for thresholded activation. Returns: - The (leaky) rectified linear unit activation: `x` if `x > 0`, - `alpha * x` if `x < 0`. If `max_value` is defined, the result - is truncated to this value. + A tensor. """ - return K.relu(x, alpha=alpha, max_value=max_value) + return K.relu(x, alpha=alpha, max_value=max_value, threshold=threshold) @tf_export('keras.activations.tanh') diff --git a/tensorflow/python/keras/applications/__init__.py b/tensorflow/python/keras/applications/__init__.py index 062135266dd8b11c489b7dff83b46ae29a0d21e6..cd9462d6b557c9ab3c484c5b98b3f749cadf7ce6 100644 --- a/tensorflow/python/keras/applications/__init__.py +++ b/tensorflow/python/keras/applications/__init__.py @@ -13,17 +13,33 @@ # limitations under the License. # ============================================================================== """Keras Applications are canned architectures with pre-trained weights.""" - +# pylint: disable=g-import-not-at-top from __future__ import absolute_import from __future__ import division from __future__ import print_function +import keras_applications + +from tensorflow.python.keras import backend +from tensorflow.python.keras import engine +from tensorflow.python.keras import layers +from tensorflow.python.keras import models +from tensorflow.python.keras import utils + +keras_applications.set_keras_submodules( + backend=backend, + engine=engine, + layers=layers, + models=models, + utils=utils) + from tensorflow.python.keras.applications.densenet import DenseNet121 from tensorflow.python.keras.applications.densenet import DenseNet169 from tensorflow.python.keras.applications.densenet import DenseNet201 from tensorflow.python.keras.applications.inception_resnet_v2 import InceptionResNetV2 from tensorflow.python.keras.applications.inception_v3 import InceptionV3 from tensorflow.python.keras.applications.mobilenet import MobileNet +# TODO(fchollet): enable MobileNetV2 in next version. from tensorflow.python.keras.applications.nasnet import NASNetLarge from tensorflow.python.keras.applications.nasnet import NASNetMobile from tensorflow.python.keras.applications.resnet50 import ResNet50 diff --git a/tensorflow/python/keras/applications/applications_test.py b/tensorflow/python/keras/applications/applications_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ef3198a937db705e2825eb900e011ab9dac1627d --- /dev/null +++ b/tensorflow/python/keras/applications/applications_test.py @@ -0,0 +1,58 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Integration tests for Keras applications.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.python.keras import applications +from tensorflow.python.platform import test + + +MODEL_LIST = [ + (applications.ResNet50, 2048), + (applications.VGG16, 512), + (applications.VGG19, 512), + (applications.Xception, 2048), + (applications.InceptionV3, 2048), + (applications.InceptionResNetV2, 1536), + (applications.MobileNet, 1024), + # TODO(fchollet): enable MobileNetV2 in next version. + (applications.DenseNet121, 1024), + (applications.DenseNet169, 1664), + (applications.DenseNet201, 1920), + (applications.NASNetMobile, 1056), + (applications.NASNetLarge, 4032), +] + + +class ApplicationsTest(test.TestCase, parameterized.TestCase): + + @parameterized.parameters(*MODEL_LIST) + def test_classification_model(self, model_fn, _): + model = model_fn(classes=1000, weights=None) + self.assertEqual(model.output_shape[-1], 1000) + + @parameterized.parameters(*MODEL_LIST) + def test_feature_extration_model(self, model_fn, output_dim): + model = model_fn(include_top=False, weights=None) + self.assertEqual(model.output_shape, (None, None, None, output_dim)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/applications/densenet.py b/tensorflow/python/keras/applications/densenet.py index f81f10719a31e2e79589d3b389049353c992091c..fbdcc66d2d75bb1d580c106a2fb3101522d75095 100644 --- a/tensorflow/python/keras/applications/densenet.py +++ b/tensorflow/python/keras/applications/densenet.py @@ -13,342 +13,25 @@ # limitations under the License. # ============================================================================== # pylint: disable=invalid-name -# pylint: disable=unused-import """DenseNet models for Keras. - -# Reference paper - -- [Densely Connected Convolutional Networks] - (https://arxiv.org/abs/1608.06993) (CVPR 2017 Best Paper Award) """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - -from tensorflow.python.keras import backend as K -from tensorflow.python.keras.applications import imagenet_utils -from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras.engine.network import get_source_inputs -from tensorflow.python.keras.layers import Activation -from tensorflow.python.keras.layers import AveragePooling2D -from tensorflow.python.keras.layers import BatchNormalization -from tensorflow.python.keras.layers import Concatenate -from tensorflow.python.keras.layers import Conv2D -from tensorflow.python.keras.layers import Dense -from tensorflow.python.keras.layers import GlobalAveragePooling2D -from tensorflow.python.keras.layers import GlobalMaxPooling2D -from tensorflow.python.keras.layers import Input -from tensorflow.python.keras.layers import MaxPooling2D -from tensorflow.python.keras.layers import ZeroPadding2D -from tensorflow.python.keras.models import Model -from tensorflow.python.keras.utils.data_utils import get_file +from keras_applications import densenet from tensorflow.python.util.tf_export import tf_export - -DENSENET121_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet121_weights_tf_dim_ordering_tf_kernels.h5' -DENSENET121_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5' -DENSENET169_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet169_weights_tf_dim_ordering_tf_kernels.h5' -DENSENET169_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet169_weights_tf_dim_ordering_tf_kernels_notop.h5' -DENSENET201_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet201_weights_tf_dim_ordering_tf_kernels.h5' -DENSENET201_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5' - - -def dense_block(x, blocks, name): - """A dense block. - - Arguments: - x: input tensor. - blocks: integer, the number of building blocks. - name: string, block label. - - Returns: - output tensor for the block. - """ - for i in range(blocks): - x = conv_block(x, 32, name=name + '_block' + str(i + 1)) - return x - - -def transition_block(x, reduction, name): - """A transition block. - - Arguments: - x: input tensor. - reduction: float, compression rate at transition layers. - name: string, block label. - - Returns: - output tensor for the block. - """ - bn_axis = 3 if K.image_data_format() == 'channels_last' else 1 - x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=name + '_bn')(x) - x = Activation('relu', name=name + '_relu')(x) - x = Conv2D( - int(K.int_shape(x)[bn_axis] * reduction), - 1, - use_bias=False, - name=name + '_conv')( - x) - x = AveragePooling2D(2, strides=2, name=name + '_pool')(x) - return x - - -def conv_block(x, growth_rate, name): - """A building block for a dense block. - - Arguments: - x: input tensor. - growth_rate: float, growth rate at dense layers. - name: string, block label. - - Returns: - output tensor for the block. - """ - bn_axis = 3 if K.image_data_format() == 'channels_last' else 1 - x1 = BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, name=name + '_0_bn')( - x) - x1 = Activation('relu', name=name + '_0_relu')(x1) - x1 = Conv2D(4 * growth_rate, 1, use_bias=False, name=name + '_1_conv')(x1) - x1 = BatchNormalization( - axis=bn_axis, epsilon=1.001e-5, name=name + '_1_bn')( - x1) - x1 = Activation('relu', name=name + '_1_relu')(x1) - x1 = Conv2D( - growth_rate, 3, padding='same', use_bias=False, name=name + '_2_conv')( - x1) - x = Concatenate(axis=bn_axis, name=name + '_concat')([x, x1]) - return x - - -def DenseNet(blocks, - include_top=True, - weights='imagenet', - input_tensor=None, - input_shape=None, - pooling=None, - classes=1000): - """Instantiates the DenseNet architecture. - - Optionally loads weights pre-trained - on ImageNet. Note that when using TensorFlow, - for best performance you should set - `image_data_format='channels_last'` in your Keras config - at ~/.keras/keras.json. - - The model and the weights are compatible with - TensorFlow, Theano, and CNTK. The data format - convention used by the model is the one - specified in your Keras config file. - - Arguments: - blocks: numbers of building blocks for the four dense layers. - include_top: whether to include the fully-connected - layer at the top of the network. - weights: one of `None` (random initialization), - 'imagenet' (pre-training on ImageNet), - or the path to the weights file to be loaded. - input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) - to use as image input for the model. - input_shape: optional shape tuple, only to be specified - if `include_top` is False (otherwise the input shape - has to be `(224, 224, 3)` (with `channels_last` data format) - or `(3, 224, 224)` (with `channels_first` data format). - It should have exactly 3 inputs channels. - pooling: optional pooling mode for feature extraction - when `include_top` is `False`. - - `None` means that the output of the model will be - the 4D tensor output of the - last convolutional layer. - - `avg` means that global average pooling - will be applied to the output of the - last convolutional layer, and thus - the output of the model will be a 2D tensor. - - `max` means that global max pooling will - be applied. - classes: optional number of classes to classify images - into, only to be specified if `include_top` is True, and - if no `weights` argument is specified. - - Returns: - A Keras model instance. - - Raises: - ValueError: in case of invalid argument for `weights`, - or invalid input shape. - """ - if not (weights in {'imagenet', None} or os.path.exists(weights)): - raise ValueError('The `weights` argument should be either ' - '`None` (random initialization), `imagenet` ' - '(pre-training on ImageNet), ' - 'or the path to the weights file to be loaded.') - - if weights == 'imagenet' and include_top and classes != 1000: - raise ValueError('If using `weights` as imagenet with `include_top`' - ' as true, `classes` should be 1000') - - # Determine proper input shape - input_shape = _obtain_input_shape( - input_shape, - default_size=224, - min_size=221, - data_format=K.image_data_format(), - require_flatten=include_top, - weights=weights) - - if input_tensor is None: - img_input = Input(shape=input_shape) - else: - if not K.is_keras_tensor(input_tensor): - img_input = Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - - bn_axis = 3 if K.image_data_format() == 'channels_last' else 1 - - x = ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input) - x = Conv2D(64, 7, strides=2, use_bias=False, name='conv1/conv')(x) - x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='conv1/bn')(x) - x = Activation('relu', name='conv1/relu')(x) - x = ZeroPadding2D(padding=((1, 1), (1, 1)))(x) - x = MaxPooling2D(3, strides=2, name='pool1')(x) - - x = dense_block(x, blocks[0], name='conv2') - x = transition_block(x, 0.5, name='pool2') - x = dense_block(x, blocks[1], name='conv3') - x = transition_block(x, 0.5, name='pool3') - x = dense_block(x, blocks[2], name='conv4') - x = transition_block(x, 0.5, name='pool4') - x = dense_block(x, blocks[3], name='conv5') - - x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='bn')(x) - - if include_top: - x = GlobalAveragePooling2D(name='avg_pool')(x) - x = Dense(classes, activation='softmax', name='fc1000')(x) - else: - if pooling == 'avg': - x = GlobalAveragePooling2D(name='avg_pool')(x) - elif pooling == 'max': - x = GlobalMaxPooling2D(name='max_pool')(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if input_tensor is not None: - inputs = get_source_inputs(input_tensor) - else: - inputs = img_input - - # Create model. - if blocks == [6, 12, 24, 16]: - model = Model(inputs, x, name='densenet121') - elif blocks == [6, 12, 32, 32]: - model = Model(inputs, x, name='densenet169') - elif blocks == [6, 12, 48, 32]: - model = Model(inputs, x, name='densenet201') - else: - model = Model(inputs, x, name='densenet') - - # Load weights. - if weights == 'imagenet': - if include_top: - if blocks == [6, 12, 24, 16]: - weights_path = get_file( - 'densenet121_weights_tf_dim_ordering_tf_kernels.h5', - DENSENET121_WEIGHT_PATH, - cache_subdir='models', - file_hash='0962ca643bae20f9b6771cb844dca3b0') - elif blocks == [6, 12, 32, 32]: - weights_path = get_file( - 'densenet169_weights_tf_dim_ordering_tf_kernels.h5', - DENSENET169_WEIGHT_PATH, - cache_subdir='models', - file_hash='bcf9965cf5064a5f9eb6d7dc69386f43') - elif blocks == [6, 12, 48, 32]: - weights_path = get_file( - 'densenet201_weights_tf_dim_ordering_tf_kernels.h5', - DENSENET201_WEIGHT_PATH, - cache_subdir='models', - file_hash='7bb75edd58cb43163be7e0005fbe95ef') - else: - if blocks == [6, 12, 24, 16]: - weights_path = get_file( - 'densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5', - DENSENET121_WEIGHT_PATH_NO_TOP, - cache_subdir='models', - file_hash='4912a53fbd2a69346e7f2c0b5ec8c6d3') - elif blocks == [6, 12, 32, 32]: - weights_path = get_file( - 'densenet169_weights_tf_dim_ordering_tf_kernels_notop.h5', - DENSENET169_WEIGHT_PATH_NO_TOP, - cache_subdir='models', - file_hash='50662582284e4cf834ce40ab4dfa58c6') - elif blocks == [6, 12, 48, 32]: - weights_path = get_file( - 'densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5', - DENSENET201_WEIGHT_PATH_NO_TOP, - cache_subdir='models', - file_hash='1c2de60ee40562448dbac34a0737e798') - model.load_weights(weights_path) - elif weights is not None: - model.load_weights(weights) - - return model - - -@tf_export('keras.applications.DenseNet121', - 'keras.applications.densenet.DenseNet121') -def DenseNet121(include_top=True, - weights='imagenet', - input_tensor=None, - input_shape=None, - pooling=None, - classes=1000): - return DenseNet([6, 12, 24, 16], include_top, weights, input_tensor, - input_shape, pooling, classes) - - -@tf_export('keras.applications.DenseNet169', - 'keras.applications.densenet.DenseNet169') -def DenseNet169(include_top=True, - weights='imagenet', - input_tensor=None, - input_shape=None, - pooling=None, - classes=1000): - return DenseNet([6, 12, 32, 32], include_top, weights, input_tensor, - input_shape, pooling, classes) - - -@tf_export('keras.applications.DenseNet201', - 'keras.applications.densenet.DenseNet201') -def DenseNet201(include_top=True, - weights='imagenet', - input_tensor=None, - input_shape=None, - pooling=None, - classes=1000): - return DenseNet([6, 12, 48, 32], include_top, weights, input_tensor, - input_shape, pooling, classes) - - -@tf_export('keras.applications.densenet.preprocess_input') -def preprocess_input(x, data_format=None): - """Preprocesses a numpy array encoding a batch of images. - - Arguments: - x: a 3D or 4D numpy array consists of RGB values within [0, 255]. - data_format: data format of the image tensor. - - Returns: - Preprocessed array. - """ - return imagenet_utils.preprocess_input(x, data_format, mode='torch') - - -setattr(DenseNet121, '__doc__', DenseNet.__doc__) -setattr(DenseNet169, '__doc__', DenseNet.__doc__) -setattr(DenseNet201, '__doc__', DenseNet.__doc__) +DenseNet121 = densenet.DenseNet121 +DenseNet169 = densenet.DenseNet169 +DenseNet201 = densenet.DenseNet201 +decode_predictions = densenet.decode_predictions +preprocess_input = densenet.preprocess_input + +tf_export('keras.applications.densenet.DenseNet121', + 'keras.applications.DenseNet121')(DenseNet121) +tf_export('keras.applications.densenet.DenseNet169', + 'keras.applications.DenseNet169')(DenseNet169) +tf_export('keras.applications.densenet.DenseNet201', + 'keras.applications.DenseNet201')(DenseNet201) +tf_export('keras.applications.densenet.preprocess_input')(preprocess_input) diff --git a/tensorflow/python/keras/applications/densenet_test.py b/tensorflow/python/keras/applications/densenet_test.py deleted file mode 100644 index 8b6aa281ad0e2d0798952b7489c89892709cda29..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/applications/densenet_test.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for DenseNet application.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python import keras -from tensorflow.python.platform import test - - -class DenseNet121Test(test.TestCase): - - def test_with_top(self): - model = keras.applications.DenseNet121(weights=None) - self.assertEqual(model.output_shape, (None, 1000)) - - def test_no_top(self): - model = keras.applications.DenseNet121(weights=None, include_top=False) - self.assertEqual(model.output_shape, (None, None, None, 1024)) - - def test_with_pooling(self): - model = keras.applications.DenseNet121(weights=None, - include_top=False, - pooling='avg') - self.assertEqual(model.output_shape, (None, 1024)) - - def test_weight_loading(self): - with self.assertRaises(ValueError): - keras.applications.DenseNet121(weights='unknown', - include_top=False) - with self.assertRaises(ValueError): - keras.applications.DenseNet121(weights='imagenet', - classes=2000) - - -class DenseNet169Test(test.TestCase): - - def test_with_top(self): - model = keras.applications.DenseNet169(weights=None) - self.assertEqual(model.output_shape, (None, 1000)) - - def test_no_top(self): - model = keras.applications.DenseNet169(weights=None, include_top=False) - self.assertEqual(model.output_shape, (None, None, None, 1664)) - - def test_with_pooling(self): - model = keras.applications.DenseNet169(weights=None, - include_top=False, - pooling='max') - self.assertEqual(model.output_shape, (None, 1664)) - - def test_weight_loading(self): - with self.assertRaises(ValueError): - keras.applications.DenseNet169(weights='unknown', - include_top=False) - with self.assertRaises(ValueError): - keras.applications.DenseNet169(weights='imagenet', - classes=2000) - - -class DenseNet201(test.TestCase): - - def test_with_top(self): - model = keras.applications.DenseNet201(weights=None) - self.assertEqual(model.output_shape, (None, 1000)) - - def test_no_top(self): - model = keras.applications.DenseNet201(weights=None, include_top=False) - self.assertEqual(model.output_shape, (None, None, None, 1920)) - - def test_with_pooling(self): - model = keras.applications.DenseNet201(weights=None, - include_top=False, - pooling='avg') - self.assertEqual(model.output_shape, (None, 1920)) - - def test_weight_loading(self): - with self.assertRaises(ValueError): - keras.applications.DenseNet201(weights='unknown', - include_top=False) - with self.assertRaises(ValueError): - keras.applications.DenseNet201(weights='imagenet', - classes=2000) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/keras/applications/imagenet_utils.py b/tensorflow/python/keras/applications/imagenet_utils.py index 0d8ccca1b5c2a6c05f0d933a8f0fe176ea62c2a3..70f8f6fb32cfd0fe397c75ad4d3237919e7b0fad 100644 --- a/tensorflow/python/keras/applications/imagenet_utils.py +++ b/tensorflow/python/keras/applications/imagenet_utils.py @@ -18,322 +18,28 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import json - -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.keras import backend as K -from tensorflow.python.keras.utils.data_utils import get_file -from tensorflow.python.ops import math_ops -from tensorflow.python.platform import tf_logging as logging +from keras_applications import imagenet_utils from tensorflow.python.util.tf_export import tf_export - -CLASS_INDEX = None -CLASS_INDEX_PATH = 'https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json' - -# Global tensor of imagenet mean for preprocessing symbolic inputs -_IMAGENET_MEAN = None - - -def _preprocess_numpy_input(x, data_format, mode): - """Preprocesses a Numpy array encoding a batch of images. - - Arguments: - x: Input array, 3D or 4D. - data_format: Data format of the image array. - mode: One of "caffe", "tf" or "torch". - - caffe: will convert the images from RGB to BGR, - then will zero-center each color channel with - respect to the ImageNet dataset, - without scaling. - - tf: will scale pixels between -1 and 1, - sample-wise. - - torch: will scale pixels between 0 and 1 and then - will normalize each channel with respect to the - ImageNet dataset. - - Returns: - Preprocessed Numpy array. - """ - if mode == 'tf': - x /= 127.5 - x -= 1. - return x - - if mode == 'torch': - x /= 255. - mean = [0.485, 0.456, 0.406] - std = [0.229, 0.224, 0.225] - else: - if data_format == 'channels_first': - # 'RGB'->'BGR' - if x.ndim == 3: - x = x[::-1, ...] - else: - x = x[:, ::-1, ...] - else: - # 'RGB'->'BGR' - x = x[..., ::-1] - mean = [103.939, 116.779, 123.68] - std = None - - # Zero-center by mean pixel - if data_format == 'channels_first': - if x.ndim == 3: - x[0, :, :] -= mean[0] - x[1, :, :] -= mean[1] - x[2, :, :] -= mean[2] - if std is not None: - x[0, :, :] /= std[0] - x[1, :, :] /= std[1] - x[2, :, :] /= std[2] - else: - x[:, 0, :, :] -= mean[0] - x[:, 1, :, :] -= mean[1] - x[:, 2, :, :] -= mean[2] - if std is not None: - x[:, 0, :, :] /= std[0] - x[:, 1, :, :] /= std[1] - x[:, 2, :, :] /= std[2] - else: - x[..., 0] -= mean[0] - x[..., 1] -= mean[1] - x[..., 2] -= mean[2] - if std is not None: - x[..., 0] /= std[0] - x[..., 1] /= std[1] - x[..., 2] /= std[2] - return x - - -def _preprocess_symbolic_input(x, data_format, mode): - """Preprocesses a tensor encoding a batch of images. - - Arguments: - x: Input tensor, 3D or 4D. - data_format: Data format of the image tensor. - mode: One of "caffe", "tf" or "torch". - - caffe: will convert the images from RGB to BGR, - then will zero-center each color channel with - respect to the ImageNet dataset, - without scaling. - - tf: will scale pixels between -1 and 1, - sample-wise. - - torch: will scale pixels between 0 and 1 and then - will normalize each channel with respect to the - ImageNet dataset. - - Returns: - Preprocessed tensor. - """ - global _IMAGENET_MEAN - - if mode == 'tf': - x /= 127.5 - x -= 1. - return x - - if mode == 'torch': - x /= 255. - mean = [0.485, 0.456, 0.406] - std = [0.229, 0.224, 0.225] - else: - if data_format == 'channels_first': - # 'RGB'->'BGR' - if K.ndim(x) == 3: - x = x[::-1, ...] - else: - x = x[:, ::-1, ...] - else: - # 'RGB'->'BGR' - x = x[..., ::-1] - mean = [103.939, 116.779, 123.68] - std = None - - if _IMAGENET_MEAN is None: - _IMAGENET_MEAN = constant_op.constant(-np.array(mean), dtype=K.floatx()) - - # Zero-center by mean pixel - if K.dtype(x) != K.dtype(_IMAGENET_MEAN): - x = K.bias_add(x, math_ops.cast(_IMAGENET_MEAN, K.dtype(x)), data_format) - else: - x = K.bias_add(x, _IMAGENET_MEAN, data_format) - if std is not None: - x /= std - return x - - -@tf_export('keras.applications.resnet50.preprocess_input', - 'keras.applications.vgg19.preprocess_input', - 'keras.applications.vgg16.preprocess_input') -def preprocess_input(x, data_format=None, mode='caffe'): - """Preprocesses a tensor or Numpy array encoding a batch of images. - - Arguments: - x: Input Numpy or symbolic tensor, 3D or 4D. - data_format: Data format of the image tensor/array. - mode: One of "caffe", "tf". - - caffe: will convert the images from RGB to BGR, - then will zero-center each color channel with - respect to the ImageNet dataset, - without scaling. - - tf: will scale pixels between -1 and 1, - sample-wise. - - Returns: - Preprocessed tensor or Numpy array. - - Raises: - ValueError: In case of unknown `data_format` argument. - """ - if data_format is None: - data_format = K.image_data_format() - if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) - - if isinstance(x, np.ndarray): - return _preprocess_numpy_input(x, data_format=data_format, mode=mode) - else: - return _preprocess_symbolic_input(x, data_format=data_format, mode=mode) - - -@tf_export('keras.applications.nasnet.decode_predictions', - 'keras.applications.resnet50.decode_predictions', - 'keras.applications.vgg19.decode_predictions', - 'keras.applications.vgg16.decode_predictions', - 'keras.applications.inception_resnet_v2.decode_predictions', - 'keras.applications.inception_v3.decode_predictions', - 'keras.applications.densenet.decode_predictions', - 'keras.applications.mobilenet.decode_predictions', - 'keras.applications.xception.decode_predictions') -def decode_predictions(preds, top=5): - """Decodes the prediction of an ImageNet model. - - Arguments: - preds: Numpy tensor encoding a batch of predictions. - top: Integer, how many top-guesses to return. - - Returns: - A list of lists of top class prediction tuples - `(class_name, class_description, score)`. - One list of tuples per sample in batch input. - - Raises: - ValueError: In case of invalid shape of the `pred` array - (must be 2D). - """ - global CLASS_INDEX - if len(preds.shape) != 2 or preds.shape[1] != 1000: - raise ValueError('`decode_predictions` expects ' - 'a batch of predictions ' - '(i.e. a 2D array of shape (samples, 1000)). ' - 'Found array with shape: ' + str(preds.shape)) - if CLASS_INDEX is None: - fpath = get_file( - 'imagenet_class_index.json', - CLASS_INDEX_PATH, - cache_subdir='models', - file_hash='c2c37ea517e94d9795004a39431a14cb') - with open(fpath) as f: - CLASS_INDEX = json.load(f) - results = [] - for pred in preds: - top_indices = pred.argsort()[-top:][::-1] - result = [tuple(CLASS_INDEX[str(i)]) + (pred[i],) for i in top_indices] - result.sort(key=lambda x: x[2], reverse=True) - results.append(result) - return results - - -def _obtain_input_shape(input_shape, - default_size, - min_size, - data_format, - require_flatten, - weights=None): - """Internal utility to compute/validate a model's input shape. - - Arguments: - input_shape: Either None (will return the default network input shape), - or a user-provided shape to be validated. - default_size: Default input width/height for the model. - min_size: Minimum input width/height accepted by the model. - data_format: Image data format to use. - require_flatten: Whether the model is expected to - be linked to a classifier via a Flatten layer. - weights: One of `None` (random initialization) - or 'imagenet' (pre-training on ImageNet). - If weights='imagenet' input channels must be equal to 3. - - Returns: - An integer shape tuple (may include None entries). - - Raises: - ValueError: In case of invalid argument values. - """ - if weights != 'imagenet' and input_shape and len(input_shape) == 3: - if data_format == 'channels_first': - if input_shape[0] not in {1, 3}: - logging.warning('This model usually expects 1 or 3 input channels. ' - 'However, it was passed an input_shape with ' + - str(input_shape[0]) + ' input channels.') - default_shape = (input_shape[0], default_size, default_size) - else: - if input_shape[-1] not in {1, 3}: - logging.warning('This model usually expects 1 or 3 input channels. ' - 'However, it was passed an input_shape with ' + - str(input_shape[-1]) + ' input channels.') - default_shape = (default_size, default_size, input_shape[-1]) - else: - if data_format == 'channels_first': - default_shape = (3, default_size, default_size) - else: - default_shape = (default_size, default_size, 3) - if weights == 'imagenet' and require_flatten: - if input_shape is not None: - if input_shape != default_shape: - raise ValueError('When setting`include_top=True` ' - 'and loading `imagenet` weights, ' - '`input_shape` should be ' + str(default_shape) + '.') - return default_shape - if input_shape: - if data_format == 'channels_first': - if input_shape is not None: - if len(input_shape) != 3: - raise ValueError('`input_shape` must be a tuple of three integers.') - if input_shape[0] != 3 and weights == 'imagenet': - raise ValueError('The input must have 3 channels; got ' - '`input_shape=' + str(input_shape) + '`') - if ((input_shape[1] is not None and input_shape[1] < min_size) or - (input_shape[2] is not None and input_shape[2] < min_size)): - raise ValueError('Input size must be at least ' + str(min_size) + - 'x' + str(min_size) + '; got ' - '`input_shape=' + str(input_shape) + '`') - else: - if input_shape is not None: - if len(input_shape) != 3: - raise ValueError('`input_shape` must be a tuple of three integers.') - if input_shape[-1] != 3 and weights == 'imagenet': - raise ValueError('The input must have 3 channels; got ' - '`input_shape=' + str(input_shape) + '`') - if ((input_shape[0] is not None and input_shape[0] < min_size) or - (input_shape[1] is not None and input_shape[1] < min_size)): - raise ValueError('Input size must be at least ' + str(min_size) + - 'x' + str(min_size) + '; got ' - '`input_shape=' + str(input_shape) + '`') - else: - if require_flatten: - input_shape = default_shape - else: - if data_format == 'channels_first': - input_shape = (3, None, None) - else: - input_shape = (None, None, 3) - if require_flatten: - if None in input_shape: - raise ValueError('If `include_top` is True, ' - 'you should specify a static `input_shape`. ' - 'Got `input_shape=' + str(input_shape) + '`') - return input_shape +decode_predictions = imagenet_utils.decode_predictions +preprocess_input = imagenet_utils.preprocess_input + +tf_export( + 'keras.applications.imagenet_utils.decode_predictions', + 'keras.applications.densenet.decode_predictions', + 'keras.applications.inception_resnet_v2.decode_predictions', + 'keras.applications.inception_v3.decode_predictions', + 'keras.applications.mobilenet.decode_predictions', + 'keras.applications.mobilenet_v2.decode_predictions', + 'keras.applications.nasnet.decode_predictions', + 'keras.applications.resnet50.decode_predictions', + 'keras.applications.vgg16.decode_predictions', + 'keras.applications.vgg19.decode_predictions', + 'keras.applications.xception.decode_predictions', +)(decode_predictions) +tf_export( + 'keras.applications.imagenet_utils.preprocess_input', + 'keras.applications.resnet50.preprocess_input', + 'keras.applications.vgg16.preprocess_input', + 'keras.applications.vgg19.preprocess_input', +)(preprocess_input) diff --git a/tensorflow/python/keras/applications/imagenet_utils_test.py b/tensorflow/python/keras/applications/imagenet_utils_test.py deleted file mode 100644 index 349339309017f3e9e3a9922d95188f1954ed8634..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/applications/imagenet_utils_test.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Inception V3 application.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python import keras -from tensorflow.python.keras.applications.imagenet_utils import preprocess_input -from tensorflow.python.platform import test - - -class ImageNetUtilsTest(test.TestCase): - - def test_preprocess_input(self): - # Test batch of images - x = np.random.uniform(0, 255, (2, 10, 10, 3)) - self.assertEqual(preprocess_input(x).shape, x.shape) - out1 = preprocess_input(x, 'channels_last') - out2 = preprocess_input(np.transpose(x, (0, 3, 1, 2)), 'channels_first') - self.assertAllClose(out1, out2.transpose(0, 2, 3, 1)) - - # Test single image - x = np.random.uniform(0, 255, (10, 10, 3)) - self.assertEqual(preprocess_input(x).shape, x.shape) - out1 = preprocess_input(x, 'channels_last') - out2 = preprocess_input(np.transpose(x, (2, 0, 1)), 'channels_first') - self.assertAllClose(out1, out2.transpose(1, 2, 0)) - - def test_preprocess_input_symbolic(self): - # Test image batch - x = np.random.uniform(0, 255, (2, 10, 10, 3)) - inputs = keras.layers.Input(shape=x.shape[1:]) - outputs = keras.layers.Lambda( - preprocess_input, output_shape=x.shape[1:])(inputs) - model = keras.models.Model(inputs, outputs) - assert model.predict(x).shape == x.shape - # pylint: disable=g-long-lambda - outputs1 = keras.layers.Lambda(lambda x: - preprocess_input(x, 'channels_last'), - output_shape=x.shape[1:])(inputs) - model1 = keras.models.Model(inputs, outputs1) - out1 = model1.predict(x) - x2 = np.transpose(x, (0, 3, 1, 2)) - inputs2 = keras.layers.Input(shape=x2.shape[1:]) - # pylint: disable=g-long-lambda - outputs2 = keras.layers.Lambda(lambda x: - preprocess_input(x, 'channels_first'), - output_shape=x2.shape[1:])(inputs2) - model2 = keras.models.Model(inputs2, outputs2) - out2 = model2.predict(x2) - self.assertAllClose(out1, out2.transpose(0, 2, 3, 1)) - - # Test single image - x = np.random.uniform(0, 255, (10, 10, 3)) - inputs = keras.layers.Input(shape=x.shape) - outputs = keras.layers.Lambda(preprocess_input, - output_shape=x.shape)(inputs) - model = keras.models.Model(inputs, outputs) - assert model.predict(x[np.newaxis])[0].shape == x.shape - # pylint: disable=g-long-lambda - outputs1 = keras.layers.Lambda(lambda x: - preprocess_input(x, 'channels_last'), - output_shape=x.shape)(inputs) - model1 = keras.models.Model(inputs, outputs1) - out1 = model1.predict(x[np.newaxis])[0] - x2 = np.transpose(x, (2, 0, 1)) - inputs2 = keras.layers.Input(shape=x2.shape) - outputs2 = keras.layers.Lambda(lambda x: - preprocess_input(x, 'channels_first'), - output_shape=x2.shape)(inputs2) # pylint: disable=g-long-lambda - model2 = keras.models.Model(inputs2, outputs2) - out2 = model2.predict(x2[np.newaxis])[0] - self.assertAllClose(out1, out2.transpose(1, 2, 0)) - - def test_obtain_input_shape(self): - # input_shape and default_size are not identical. - with self.assertRaises(ValueError): - keras.applications.imagenet_utils._obtain_input_shape( - input_shape=(224, 224, 3), - default_size=299, - min_size=139, - data_format='channels_last', - require_flatten=True, - weights='imagenet') - - # Test invalid use cases - for data_format in ['channels_last', 'channels_first']: - # input_shape is smaller than min_size. - shape = (100, 100) - if data_format == 'channels_last': - input_shape = shape + (3,) - else: - input_shape = (3,) + shape - with self.assertRaises(ValueError): - keras.applications.imagenet_utils._obtain_input_shape( - input_shape=input_shape, - default_size=None, - min_size=139, - data_format=data_format, - require_flatten=False) - - # shape is 1D. - shape = (100,) - if data_format == 'channels_last': - input_shape = shape + (3,) - else: - input_shape = (3,) + shape - with self.assertRaises(ValueError): - keras.applications.imagenet_utils._obtain_input_shape( - input_shape=input_shape, - default_size=None, - min_size=139, - data_format=data_format, - require_flatten=False) - - # the number of channels is 5 not 3. - shape = (100, 100) - if data_format == 'channels_last': - input_shape = shape + (5,) - else: - input_shape = (5,) + shape - with self.assertRaises(ValueError): - keras.applications.imagenet_utils._obtain_input_shape( - input_shape=input_shape, - default_size=None, - min_size=139, - data_format=data_format, - require_flatten=False) - - # require_flatten=True with dynamic input shape. - with self.assertRaises(ValueError): - keras.applications.imagenet_utils._obtain_input_shape( - input_shape=None, - default_size=None, - min_size=139, - data_format='channels_first', - require_flatten=True) - - assert keras.applications.imagenet_utils._obtain_input_shape( - input_shape=(3, 200, 200), - default_size=None, - min_size=139, - data_format='channels_first', - require_flatten=True) == (3, 200, 200) - - assert keras.applications.imagenet_utils._obtain_input_shape( - input_shape=None, - default_size=None, - min_size=139, - data_format='channels_last', - require_flatten=False) == (None, None, 3) - - assert keras.applications.imagenet_utils._obtain_input_shape( - input_shape=None, - default_size=None, - min_size=139, - data_format='channels_first', - require_flatten=False) == (3, None, None) - - assert keras.applications.imagenet_utils._obtain_input_shape( - input_shape=None, - default_size=None, - min_size=139, - data_format='channels_last', - require_flatten=False) == (None, None, 3) - - assert keras.applications.imagenet_utils._obtain_input_shape( - input_shape=(150, 150, 3), - default_size=None, - min_size=139, - data_format='channels_last', - require_flatten=False) == (150, 150, 3) - - assert keras.applications.imagenet_utils._obtain_input_shape( - input_shape=(3, None, None), - default_size=None, - min_size=139, - data_format='channels_first', - require_flatten=False) == (3, None, None) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/applications/inception_resnet_v2.py index fe1d0f2d4fb47f7ebab38f94afc8ace2f7b73cbc..63debb4e0df0da5d3a229332d3a2d473b1e1a23e 100644 --- a/tensorflow/python/keras/applications/inception_resnet_v2.py +++ b/tensorflow/python/keras/applications/inception_resnet_v2.py @@ -13,372 +13,20 @@ # limitations under the License. # ============================================================================== # pylint: disable=invalid-name -# pylint: disable=unused-import """Inception-ResNet V2 model for Keras. - -# Reference -- [Inception-v4, Inception-ResNet and the Impact of - Residual Connections on Learning](https://arxiv.org/abs/1602.07261) - """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - -from tensorflow.python.keras import backend as K -from tensorflow.python.keras.applications import imagenet_utils -from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras.engine.network import get_source_inputs -from tensorflow.python.keras.layers import Activation -from tensorflow.python.keras.layers import AveragePooling2D -from tensorflow.python.keras.layers import BatchNormalization -from tensorflow.python.keras.layers import Concatenate -from tensorflow.python.keras.layers import Conv2D -from tensorflow.python.keras.layers import Dense -from tensorflow.python.keras.layers import GlobalAveragePooling2D -from tensorflow.python.keras.layers import GlobalMaxPooling2D -from tensorflow.python.keras.layers import Input -from tensorflow.python.keras.layers import Lambda -from tensorflow.python.keras.layers import MaxPooling2D -from tensorflow.python.keras.models import Model -from tensorflow.python.keras.utils.data_utils import get_file -from tensorflow.python.platform import tf_logging as logging +from keras_applications import inception_resnet_v2 from tensorflow.python.util.tf_export import tf_export +InceptionResNetV2 = inception_resnet_v2.InceptionResNetV2 +decode_predictions = inception_resnet_v2.decode_predictions +preprocess_input = inception_resnet_v2.preprocess_input -BASE_WEIGHT_URL = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.7/' - - -@tf_export('keras.applications.inception_resnet_v2.preprocess_input') -def preprocess_input(x): - """Preprocesses a numpy array encoding a batch of images. - - Arguments: - x: a 4D numpy array consists of RGB values within [0, 255]. - - Returns: - Preprocessed array. - """ - return imagenet_utils.preprocess_input(x, mode='tf') - - -def conv2d_bn(x, - filters, - kernel_size, - strides=1, - padding='same', - activation='relu', - use_bias=False, - name=None): - """Utility function to apply conv + BN. - - Arguments: - x: input tensor. - filters: filters in `Conv2D`. - kernel_size: kernel size as in `Conv2D`. - strides: strides in `Conv2D`. - padding: padding mode in `Conv2D`. - activation: activation in `Conv2D`. - use_bias: whether to use a bias in `Conv2D`. - name: name of the ops; will become `name + '_ac'` for the activation - and `name + '_bn'` for the batch norm layer. - - Returns: - Output tensor after applying `Conv2D` and `BatchNormalization`. - """ - x = Conv2D( - filters, - kernel_size, - strides=strides, - padding=padding, - use_bias=use_bias, - name=name)( - x) - if not use_bias: - bn_axis = 1 if K.image_data_format() == 'channels_first' else 3 - bn_name = None if name is None else name + '_bn' - x = BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x) - if activation is not None: - ac_name = None if name is None else name + '_ac' - x = Activation(activation, name=ac_name)(x) - return x - - -def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'): - """Adds a Inception-ResNet block. - - This function builds 3 types of Inception-ResNet blocks mentioned - in the paper, controlled by the `block_type` argument (which is the - block name used in the official TF-slim implementation): - - Inception-ResNet-A: `block_type='block35'` - - Inception-ResNet-B: `block_type='block17'` - - Inception-ResNet-C: `block_type='block8'` - - Arguments: - x: input tensor. - scale: scaling factor to scale the residuals (i.e., the output of - passing `x` through an inception module) before adding them - to the shortcut branch. Let `r` be the output from the residual - branch, - the output of this block will be `x + scale * r`. - block_type: `'block35'`, `'block17'` or `'block8'`, determines - the network structure in the residual branch. - block_idx: an `int` used for generating layer names. The Inception-ResNet - blocks - are repeated many times in this network. We use `block_idx` to - identify - each of the repetitions. For example, the first Inception-ResNet-A - block - will have `block_type='block35', block_idx=0`, ane the layer names - will have - a common prefix `'block35_0'`. - activation: activation function to use at the end of the block. - When `activation=None`, no activation is applied - (i.e., "linear" activation: `a(x) = x`). - - Returns: - Output tensor for the block. - - Raises: - ValueError: if `block_type` is not one of `'block35'`, - `'block17'` or `'block8'`. - """ - if block_type == 'block35': - branch_0 = conv2d_bn(x, 32, 1) - branch_1 = conv2d_bn(x, 32, 1) - branch_1 = conv2d_bn(branch_1, 32, 3) - branch_2 = conv2d_bn(x, 32, 1) - branch_2 = conv2d_bn(branch_2, 48, 3) - branch_2 = conv2d_bn(branch_2, 64, 3) - branches = [branch_0, branch_1, branch_2] - elif block_type == 'block17': - branch_0 = conv2d_bn(x, 192, 1) - branch_1 = conv2d_bn(x, 128, 1) - branch_1 = conv2d_bn(branch_1, 160, [1, 7]) - branch_1 = conv2d_bn(branch_1, 192, [7, 1]) - branches = [branch_0, branch_1] - elif block_type == 'block8': - branch_0 = conv2d_bn(x, 192, 1) - branch_1 = conv2d_bn(x, 192, 1) - branch_1 = conv2d_bn(branch_1, 224, [1, 3]) - branch_1 = conv2d_bn(branch_1, 256, [3, 1]) - branches = [branch_0, branch_1] - else: - raise ValueError('Unknown Inception-ResNet block type. ' - 'Expects "block35", "block17" or "block8", ' - 'but got: ' + str(block_type)) - - block_name = block_type + '_' + str(block_idx) - channel_axis = 1 if K.image_data_format() == 'channels_first' else 3 - mixed = Concatenate(axis=channel_axis, name=block_name + '_mixed')(branches) - up = conv2d_bn( - mixed, - K.int_shape(x)[channel_axis], - 1, - activation=None, - use_bias=True, - name=block_name + '_conv') - - x = Lambda( - lambda inputs, scale: inputs[0] + inputs[1] * scale, - output_shape=K.int_shape(x)[1:], - arguments={'scale': scale}, - name=block_name)([x, up]) - if activation is not None: - x = Activation(activation, name=block_name + '_ac')(x) - return x - - -@tf_export('keras.applications.InceptionResNetV2', - 'keras.applications.inception_resnet_v2.InceptionResNetV2') -def InceptionResNetV2(include_top=True, - weights='imagenet', - input_tensor=None, - input_shape=None, - pooling=None, - classes=1000): - """Instantiates the Inception-ResNet v2 architecture. - - Optionally loads weights pre-trained on ImageNet. - Note that when using TensorFlow, for best performance you should - set `"image_data_format": "channels_last"` in your Keras config - at `~/.keras/keras.json`. - - The model and the weights are compatible with TensorFlow, Theano and - CNTK backends. The data format convention used by the model is - the one specified in your Keras config file. - - Note that the default input image size for this model is 299x299, instead - of 224x224 as in the VGG16 and ResNet models. Also, the input preprocessing - function is different (i.e., do not use `imagenet_utils.preprocess_input()` - with this model. Use `preprocess_input()` defined in this module instead). - - Arguments: - include_top: whether to include the fully-connected - layer at the top of the network. - weights: one of `None` (random initialization), - 'imagenet' (pre-training on ImageNet), - or the path to the weights file to be loaded. - input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) - to use as image input for the model. - input_shape: optional shape tuple, only to be specified - if `include_top` is `False` (otherwise the input shape - has to be `(299, 299, 3)` (with `'channels_last'` data format) - or `(3, 299, 299)` (with `'channels_first'` data format). - It should have exactly 3 inputs channels, - and width and height should be no smaller than 139. - E.g. `(150, 150, 3)` would be one valid value. - pooling: Optional pooling mode for feature extraction - when `include_top` is `False`. - - `None` means that the output of the model will be - the 4D tensor output of the last convolutional layer. - - `'avg'` means that global average pooling - will be applied to the output of the - last convolutional layer, and thus - the output of the model will be a 2D tensor. - - `'max'` means that global max pooling will be applied. - classes: optional number of classes to classify images - into, only to be specified if `include_top` is `True`, and - if no `weights` argument is specified. - - Returns: - A Keras `Model` instance. - - Raises: - ValueError: in case of invalid argument for `weights`, - or invalid input shape. - """ - if not (weights in {'imagenet', None} or os.path.exists(weights)): - raise ValueError('The `weights` argument should be either ' - '`None` (random initialization), `imagenet` ' - '(pre-training on ImageNet), ' - 'or the path to the weights file to be loaded.') - - if weights == 'imagenet' and include_top and classes != 1000: - raise ValueError('If using `weights` as imagenet with `include_top`' - ' as true, `classes` should be 1000') - - # Determine proper input shape - input_shape = _obtain_input_shape( - input_shape, - default_size=299, - min_size=139, - data_format=K.image_data_format(), - require_flatten=False, - weights=weights) - - if input_tensor is None: - img_input = Input(shape=input_shape) - else: - if not K.is_keras_tensor(input_tensor): - img_input = Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - - # Stem block: 35 x 35 x 192 - x = conv2d_bn(img_input, 32, 3, strides=2, padding='valid') - x = conv2d_bn(x, 32, 3, padding='valid') - x = conv2d_bn(x, 64, 3) - x = MaxPooling2D(3, strides=2)(x) - x = conv2d_bn(x, 80, 1, padding='valid') - x = conv2d_bn(x, 192, 3, padding='valid') - x = MaxPooling2D(3, strides=2)(x) - - # Mixed 5b (Inception-A block): 35 x 35 x 320 - branch_0 = conv2d_bn(x, 96, 1) - branch_1 = conv2d_bn(x, 48, 1) - branch_1 = conv2d_bn(branch_1, 64, 5) - branch_2 = conv2d_bn(x, 64, 1) - branch_2 = conv2d_bn(branch_2, 96, 3) - branch_2 = conv2d_bn(branch_2, 96, 3) - branch_pool = AveragePooling2D(3, strides=1, padding='same')(x) - branch_pool = conv2d_bn(branch_pool, 64, 1) - branches = [branch_0, branch_1, branch_2, branch_pool] - channel_axis = 1 if K.image_data_format() == 'channels_first' else 3 - x = Concatenate(axis=channel_axis, name='mixed_5b')(branches) - - # 10x block35 (Inception-ResNet-A block): 35 x 35 x 320 - for block_idx in range(1, 11): - x = inception_resnet_block( - x, scale=0.17, block_type='block35', block_idx=block_idx) - - # Mixed 6a (Reduction-A block): 17 x 17 x 1088 - branch_0 = conv2d_bn(x, 384, 3, strides=2, padding='valid') - branch_1 = conv2d_bn(x, 256, 1) - branch_1 = conv2d_bn(branch_1, 256, 3) - branch_1 = conv2d_bn(branch_1, 384, 3, strides=2, padding='valid') - branch_pool = MaxPooling2D(3, strides=2, padding='valid')(x) - branches = [branch_0, branch_1, branch_pool] - x = Concatenate(axis=channel_axis, name='mixed_6a')(branches) - - # 20x block17 (Inception-ResNet-B block): 17 x 17 x 1088 - for block_idx in range(1, 21): - x = inception_resnet_block( - x, scale=0.1, block_type='block17', block_idx=block_idx) - - # Mixed 7a (Reduction-B block): 8 x 8 x 2080 - branch_0 = conv2d_bn(x, 256, 1) - branch_0 = conv2d_bn(branch_0, 384, 3, strides=2, padding='valid') - branch_1 = conv2d_bn(x, 256, 1) - branch_1 = conv2d_bn(branch_1, 288, 3, strides=2, padding='valid') - branch_2 = conv2d_bn(x, 256, 1) - branch_2 = conv2d_bn(branch_2, 288, 3) - branch_2 = conv2d_bn(branch_2, 320, 3, strides=2, padding='valid') - branch_pool = MaxPooling2D(3, strides=2, padding='valid')(x) - branches = [branch_0, branch_1, branch_2, branch_pool] - x = Concatenate(axis=channel_axis, name='mixed_7a')(branches) - - # 10x block8 (Inception-ResNet-C block): 8 x 8 x 2080 - for block_idx in range(1, 10): - x = inception_resnet_block( - x, scale=0.2, block_type='block8', block_idx=block_idx) - x = inception_resnet_block( - x, scale=1., activation=None, block_type='block8', block_idx=10) - - # Final convolution block: 8 x 8 x 1536 - x = conv2d_bn(x, 1536, 1, name='conv_7b') - - if include_top: - # Classification block - x = GlobalAveragePooling2D(name='avg_pool')(x) - x = Dense(classes, activation='softmax', name='predictions')(x) - else: - if pooling == 'avg': - x = GlobalAveragePooling2D()(x) - elif pooling == 'max': - x = GlobalMaxPooling2D()(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor` - if input_tensor is not None: - inputs = get_source_inputs(input_tensor) - else: - inputs = img_input - - # Create model - model = Model(inputs, x, name='inception_resnet_v2') - - # Load weights - if weights == 'imagenet': - if include_top: - fname = 'inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5' - weights_path = get_file( - fname, - BASE_WEIGHT_URL + fname, - cache_subdir='models', - file_hash='e693bd0210a403b3192acc6073ad2e96') - else: - fname = 'inception_resnet_v2_weights_tf_dim_ordering_tf_kernels_notop.h5' - weights_path = get_file( - fname, - BASE_WEIGHT_URL + fname, - cache_subdir='models', - file_hash='d19885ff4a710c122648d3b5c3b684e4') - model.load_weights(weights_path) - elif weights is not None: - model.load_weights(weights) - - return model +tf_export('keras.applications.inception_resnet_v2.InceptionResNetV2', + 'keras.applications.InceptionResNetV2')(InceptionResNetV2) +tf_export( + 'keras.applications.inception_resnet_v2.preprocess_input')(preprocess_input) diff --git a/tensorflow/python/keras/applications/inception_resnet_v2_test.py b/tensorflow/python/keras/applications/inception_resnet_v2_test.py deleted file mode 100644 index 0a12f885052ae9530e82190f7580c8288860c9a8..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/applications/inception_resnet_v2_test.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Inception V3 application.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python import keras -from tensorflow.python.platform import test - - -class InceptionResNetV2Test(test.TestCase): - - def test_with_top(self): - model = keras.applications.InceptionResNetV2(weights=None) - self.assertEqual(model.output_shape, (None, 1000)) - - def test_no_top(self): - model = keras.applications.InceptionResNetV2(weights=None, - include_top=False) - self.assertEqual(model.output_shape, (None, None, None, 1536)) - - def test_with_pooling(self): - model = keras.applications.InceptionResNetV2(weights=None, - include_top=False, - pooling='avg') - self.assertEqual(model.output_shape, (None, 1536)) - - def test_weight_loading(self): - with self.assertRaises(ValueError): - keras.applications.InceptionResNetV2(weights='unknown', - include_top=False) - with self.assertRaises(ValueError): - keras.applications.InceptionResNetV2(weights='imagenet', - classes=2000) - - def test_preprocess_input(self): - x = np.random.uniform(0, 255, (2, 300, 200, 3)) - out1 = keras.applications.inception_resnet_v2.preprocess_input(x) - self.assertAllClose(np.mean(out1), 0., atol=0.1) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/keras/applications/inception_v3.py b/tensorflow/python/keras/applications/inception_v3.py index 857ad49dae9ef234fe7d8251601ee122de39c947..87534086c8fbe818a4196ffa34ad3ee33fdf27f4 100644 --- a/tensorflow/python/keras/applications/inception_v3.py +++ b/tensorflow/python/keras/applications/inception_v3.py @@ -13,404 +13,19 @@ # limitations under the License. # ============================================================================== # pylint: disable=invalid-name -# pylint: disable=unused-import """Inception V3 model for Keras. - -Note that the input image format for this model is different than for -the VGG16 and ResNet models (299x299 instead of 224x224), -and that the input preprocessing function is also different (same as Xception). - -# Reference - -- [Rethinking the Inception Architecture for Computer -Vision](http://arxiv.org/abs/1512.00567) - """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - -from tensorflow.python.keras import backend as K -from tensorflow.python.keras import layers -from tensorflow.python.keras.applications import imagenet_utils -from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras.engine.network import get_source_inputs -from tensorflow.python.keras.layers import Activation -from tensorflow.python.keras.layers import AveragePooling2D -from tensorflow.python.keras.layers import BatchNormalization -from tensorflow.python.keras.layers import Conv2D -from tensorflow.python.keras.layers import Dense -from tensorflow.python.keras.layers import GlobalAveragePooling2D -from tensorflow.python.keras.layers import GlobalMaxPooling2D -from tensorflow.python.keras.layers import Input -from tensorflow.python.keras.layers import MaxPooling2D -from tensorflow.python.keras.models import Model -from tensorflow.python.keras.utils.data_utils import get_file -from tensorflow.python.platform import tf_logging as logging +from keras_applications import inception_v3 from tensorflow.python.util.tf_export import tf_export +InceptionV3 = inception_v3.InceptionV3 +decode_predictions = inception_v3.decode_predictions +preprocess_input = inception_v3.preprocess_input -WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.5/inception_v3_weights_tf_dim_ordering_tf_kernels.h5' -WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.5/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5' - - -def conv2d_bn(x, - filters, - num_row, - num_col, - padding='same', - strides=(1, 1), - name=None): - """Utility function to apply conv + BN. - - Arguments: - x: input tensor. - filters: filters in `Conv2D`. - num_row: height of the convolution kernel. - num_col: width of the convolution kernel. - padding: padding mode in `Conv2D`. - strides: strides in `Conv2D`. - name: name of the ops; will become `name + '_conv'` - for the convolution and `name + '_bn'` for the - batch norm layer. - - Returns: - Output tensor after applying `Conv2D` and `BatchNormalization`. - """ - if name is not None: - bn_name = name + '_bn' - conv_name = name + '_conv' - else: - bn_name = None - conv_name = None - if K.image_data_format() == 'channels_first': - bn_axis = 1 - else: - bn_axis = 3 - x = Conv2D( - filters, (num_row, num_col), - strides=strides, - padding=padding, - use_bias=False, - name=conv_name)( - x) - x = BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x) - x = Activation('relu', name=name)(x) - return x - - -@tf_export('keras.applications.InceptionV3', - 'keras.applications.inception_v3.InceptionV3') -def InceptionV3(include_top=True, - weights='imagenet', - input_tensor=None, - input_shape=None, - pooling=None, - classes=1000): - """Instantiates the Inception v3 architecture. - - Optionally loads weights pre-trained - on ImageNet. Note that when using TensorFlow, - for best performance you should set - `image_data_format='channels_last'` in your Keras config - at ~/.keras/keras.json. - The model and the weights are compatible with both - TensorFlow and Theano. The data format - convention used by the model is the one - specified in your Keras config file. - Note that the default input image size for this model is 299x299. - - Arguments: - include_top: whether to include the fully-connected - layer at the top of the network. - weights: one of `None` (random initialization), - 'imagenet' (pre-training on ImageNet), - or the path to the weights file to be loaded. - input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) - to use as image input for the model. - input_shape: optional shape tuple, only to be specified - if `include_top` is False (otherwise the input shape - has to be `(299, 299, 3)` (with `channels_last` data format) - or `(3, 299, 299)` (with `channels_first` data format). - It should have exactly 3 inputs channels, - and width and height should be no smaller than 139. - E.g. `(150, 150, 3)` would be one valid value. - pooling: Optional pooling mode for feature extraction - when `include_top` is `False`. - - `None` means that the output of the model will be - the 4D tensor output of the - last convolutional layer. - - `avg` means that global average pooling - will be applied to the output of the - last convolutional layer, and thus - the output of the model will be a 2D tensor. - - `max` means that global max pooling will - be applied. - classes: optional number of classes to classify images - into, only to be specified if `include_top` is True, and - if no `weights` argument is specified. - - Returns: - A Keras model instance. - - Raises: - ValueError: in case of invalid argument for `weights`, - or invalid input shape. - """ - if not (weights in {'imagenet', None} or os.path.exists(weights)): - raise ValueError('The `weights` argument should be either ' - '`None` (random initialization), `imagenet` ' - '(pre-training on ImageNet), ' - 'or the path to the weights file to be loaded.') - - if weights == 'imagenet' and include_top and classes != 1000: - raise ValueError('If using `weights` as imagenet with `include_top`' - ' as true, `classes` should be 1000') - - # Determine proper input shape - input_shape = _obtain_input_shape( - input_shape, - default_size=299, - min_size=139, - data_format=K.image_data_format(), - require_flatten=False, - weights=weights) - - if input_tensor is None: - img_input = Input(shape=input_shape) - else: - if not K.is_keras_tensor(input_tensor): - img_input = Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - - if K.image_data_format() == 'channels_first': - channel_axis = 1 - else: - channel_axis = 3 - - x = conv2d_bn(img_input, 32, 3, 3, strides=(2, 2), padding='valid') - x = conv2d_bn(x, 32, 3, 3, padding='valid') - x = conv2d_bn(x, 64, 3, 3) - x = MaxPooling2D((3, 3), strides=(2, 2))(x) - - x = conv2d_bn(x, 80, 1, 1, padding='valid') - x = conv2d_bn(x, 192, 3, 3, padding='valid') - x = MaxPooling2D((3, 3), strides=(2, 2))(x) - - # mixed 0, 1, 2: 35 x 35 x 256 - branch1x1 = conv2d_bn(x, 64, 1, 1) - - branch5x5 = conv2d_bn(x, 48, 1, 1) - branch5x5 = conv2d_bn(branch5x5, 64, 5, 5) - - branch3x3dbl = conv2d_bn(x, 64, 1, 1) - branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) - branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) - - branch_pool = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x) - branch_pool = conv2d_bn(branch_pool, 32, 1, 1) - x = layers.concatenate( - [branch1x1, branch5x5, branch3x3dbl, branch_pool], - axis=channel_axis, - name='mixed0') - - # mixed 1: 35 x 35 x 256 - branch1x1 = conv2d_bn(x, 64, 1, 1) - - branch5x5 = conv2d_bn(x, 48, 1, 1) - branch5x5 = conv2d_bn(branch5x5, 64, 5, 5) - - branch3x3dbl = conv2d_bn(x, 64, 1, 1) - branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) - branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) - - branch_pool = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x) - branch_pool = conv2d_bn(branch_pool, 64, 1, 1) - x = layers.concatenate( - [branch1x1, branch5x5, branch3x3dbl, branch_pool], - axis=channel_axis, - name='mixed1') - - # mixed 2: 35 x 35 x 256 - branch1x1 = conv2d_bn(x, 64, 1, 1) - - branch5x5 = conv2d_bn(x, 48, 1, 1) - branch5x5 = conv2d_bn(branch5x5, 64, 5, 5) - - branch3x3dbl = conv2d_bn(x, 64, 1, 1) - branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) - branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) - - branch_pool = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x) - branch_pool = conv2d_bn(branch_pool, 64, 1, 1) - x = layers.concatenate( - [branch1x1, branch5x5, branch3x3dbl, branch_pool], - axis=channel_axis, - name='mixed2') - - # mixed 3: 17 x 17 x 768 - branch3x3 = conv2d_bn(x, 384, 3, 3, strides=(2, 2), padding='valid') - - branch3x3dbl = conv2d_bn(x, 64, 1, 1) - branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) - branch3x3dbl = conv2d_bn( - branch3x3dbl, 96, 3, 3, strides=(2, 2), padding='valid') - - branch_pool = MaxPooling2D((3, 3), strides=(2, 2))(x) - x = layers.concatenate( - [branch3x3, branch3x3dbl, branch_pool], axis=channel_axis, name='mixed3') - - # mixed 4: 17 x 17 x 768 - branch1x1 = conv2d_bn(x, 192, 1, 1) - - branch7x7 = conv2d_bn(x, 128, 1, 1) - branch7x7 = conv2d_bn(branch7x7, 128, 1, 7) - branch7x7 = conv2d_bn(branch7x7, 192, 7, 1) - - branch7x7dbl = conv2d_bn(x, 128, 1, 1) - branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 7, 1) - branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 1, 7) - branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 7, 1) - branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) - - branch_pool = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x) - branch_pool = conv2d_bn(branch_pool, 192, 1, 1) - x = layers.concatenate( - [branch1x1, branch7x7, branch7x7dbl, branch_pool], - axis=channel_axis, - name='mixed4') - - # mixed 5, 6: 17 x 17 x 768 - for i in range(2): - branch1x1 = conv2d_bn(x, 192, 1, 1) - - branch7x7 = conv2d_bn(x, 160, 1, 1) - branch7x7 = conv2d_bn(branch7x7, 160, 1, 7) - branch7x7 = conv2d_bn(branch7x7, 192, 7, 1) - - branch7x7dbl = conv2d_bn(x, 160, 1, 1) - branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 7, 1) - branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 1, 7) - branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 7, 1) - branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) - - branch_pool = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x) - branch_pool = conv2d_bn(branch_pool, 192, 1, 1) - x = layers.concatenate( - [branch1x1, branch7x7, branch7x7dbl, branch_pool], - axis=channel_axis, - name='mixed' + str(5 + i)) - - # mixed 7: 17 x 17 x 768 - branch1x1 = conv2d_bn(x, 192, 1, 1) - - branch7x7 = conv2d_bn(x, 192, 1, 1) - branch7x7 = conv2d_bn(branch7x7, 192, 1, 7) - branch7x7 = conv2d_bn(branch7x7, 192, 7, 1) - - branch7x7dbl = conv2d_bn(x, 192, 1, 1) - branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 7, 1) - branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) - branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 7, 1) - branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) - - branch_pool = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x) - branch_pool = conv2d_bn(branch_pool, 192, 1, 1) - x = layers.concatenate( - [branch1x1, branch7x7, branch7x7dbl, branch_pool], - axis=channel_axis, - name='mixed7') - - # mixed 8: 8 x 8 x 1280 - branch3x3 = conv2d_bn(x, 192, 1, 1) - branch3x3 = conv2d_bn(branch3x3, 320, 3, 3, strides=(2, 2), padding='valid') - - branch7x7x3 = conv2d_bn(x, 192, 1, 1) - branch7x7x3 = conv2d_bn(branch7x7x3, 192, 1, 7) - branch7x7x3 = conv2d_bn(branch7x7x3, 192, 7, 1) - branch7x7x3 = conv2d_bn( - branch7x7x3, 192, 3, 3, strides=(2, 2), padding='valid') - - branch_pool = MaxPooling2D((3, 3), strides=(2, 2))(x) - x = layers.concatenate( - [branch3x3, branch7x7x3, branch_pool], axis=channel_axis, name='mixed8') - - # mixed 9: 8 x 8 x 2048 - for i in range(2): - branch1x1 = conv2d_bn(x, 320, 1, 1) - - branch3x3 = conv2d_bn(x, 384, 1, 1) - branch3x3_1 = conv2d_bn(branch3x3, 384, 1, 3) - branch3x3_2 = conv2d_bn(branch3x3, 384, 3, 1) - branch3x3 = layers.concatenate( - [branch3x3_1, branch3x3_2], axis=channel_axis, name='mixed9_' + str(i)) - - branch3x3dbl = conv2d_bn(x, 448, 1, 1) - branch3x3dbl = conv2d_bn(branch3x3dbl, 384, 3, 3) - branch3x3dbl_1 = conv2d_bn(branch3x3dbl, 384, 1, 3) - branch3x3dbl_2 = conv2d_bn(branch3x3dbl, 384, 3, 1) - branch3x3dbl = layers.concatenate( - [branch3x3dbl_1, branch3x3dbl_2], axis=channel_axis) - - branch_pool = AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x) - branch_pool = conv2d_bn(branch_pool, 192, 1, 1) - x = layers.concatenate( - [branch1x1, branch3x3, branch3x3dbl, branch_pool], - axis=channel_axis, - name='mixed' + str(9 + i)) - if include_top: - # Classification block - x = GlobalAveragePooling2D(name='avg_pool')(x) - x = Dense(classes, activation='softmax', name='predictions')(x) - else: - if pooling == 'avg': - x = GlobalAveragePooling2D()(x) - elif pooling == 'max': - x = GlobalMaxPooling2D()(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if input_tensor is not None: - inputs = get_source_inputs(input_tensor) - else: - inputs = img_input - # Create model. - model = Model(inputs, x, name='inception_v3') - - # load weights - if weights == 'imagenet': - if include_top: - weights_path = get_file( - 'inception_v3_weights_tf_dim_ordering_tf_kernels.h5', - WEIGHTS_PATH, - cache_subdir='models', - file_hash='9a0d58056eeedaa3f26cb7ebd46da564') - else: - weights_path = get_file( - 'inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5', - WEIGHTS_PATH_NO_TOP, - cache_subdir='models', - file_hash='bcbd6486424b2319ff4ef7d526e38f63') - model.load_weights(weights_path) - elif weights is not None: - model.load_weights(weights) - - return model - - -@tf_export('keras.applications.nasnet.preprocess_input', - 'keras.applications.inception_v3.preprocess_input') -def preprocess_input(x): - """Preprocesses a numpy array encoding a batch of images. - - Arguments: - x: a 4D numpy array consists of RGB values within [0, 255]. - - Returns: - Preprocessed array. - """ - return imagenet_utils.preprocess_input(x, mode='tf') +tf_export('keras.applications.inception_v3.InceptionV3', + 'keras.applications.InceptionV3')(InceptionV3) +tf_export('keras.applications.inception_v3.preprocess_input')(preprocess_input) diff --git a/tensorflow/python/keras/applications/inception_v3_test.py b/tensorflow/python/keras/applications/inception_v3_test.py deleted file mode 100644 index a3fcdd55644af5a2211b58169d87ab4fba996b19..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/applications/inception_v3_test.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Inception V3 application.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python import keras -from tensorflow.python.platform import test - - -class InceptionV3Test(test.TestCase): - - def test_with_top(self): - model = keras.applications.InceptionV3(weights=None) - self.assertEqual(model.output_shape, (None, 1000)) - - def test_no_top(self): - model = keras.applications.InceptionV3(weights=None, include_top=False) - self.assertEqual(model.output_shape, (None, None, None, 2048)) - - def test_with_pooling(self): - model = keras.applications.InceptionV3(weights=None, - include_top=False, - pooling='avg') - self.assertEqual(model.output_shape, (None, 2048)) - - def test_weight_loading(self): - with self.assertRaises(ValueError): - keras.applications.InceptionV3(weights='unknown', - include_top=False) - with self.assertRaises(ValueError): - keras.applications.InceptionV3(weights='imagenet', - classes=2000) - - def test_preprocess_input(self): - x = np.random.uniform(0, 255, (2, 300, 200, 3)) - out1 = keras.applications.inception_v3.preprocess_input(x) - self.assertAllClose(np.mean(out1), 0., atol=0.1) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/keras/applications/mobilenet.py b/tensorflow/python/keras/applications/mobilenet.py index 9d845be0d5b1ab06dd8a41bc04f75ae7b5f00789..3528f027b3fc0849ffa8e35d84928da9f55d33a7 100644 --- a/tensorflow/python/keras/applications/mobilenet.py +++ b/tensorflow/python/keras/applications/mobilenet.py @@ -13,480 +13,19 @@ # limitations under the License. # ============================================================================== # pylint: disable=invalid-name -# pylint: disable=unused-import """MobileNet v1 models for Keras. - -MobileNet is a general architecture and can be used for multiple use cases. -Depending on the use case, it can use different input layer size and -different width factors. This allows different width models to reduce -the number of multiply-adds and thereby -reduce inference cost on mobile devices. - -MobileNets support any input size greater than 32 x 32, with larger image sizes -offering better performance. -The number of parameters and number of multiply-adds -can be modified by using the `alpha` parameter, -which increases/decreases the number of filters in each layer. -By altering the image size and `alpha` parameter, -all 16 models from the paper can be built, with ImageNet weights provided. - -The paper demonstrates the performance of MobileNets using `alpha` values of -1.0 (also called 100 % MobileNet), 0.75, 0.5 and 0.25. -For each of these `alpha` values, weights for 4 different input image sizes -are provided (224, 192, 160, 128). - -The following table describes the size and accuracy of the 100% MobileNet -on size 224 x 224: ----------------------------------------------------------------------------- -Width Multiplier (alpha) | ImageNet Acc | Multiply-Adds (M) | Params (M) ----------------------------------------------------------------------------- -| 1.0 MobileNet-224 | 70.6 % | 529 | 4.2 | -| 0.75 MobileNet-224 | 68.4 % | 325 | 2.6 | -| 0.50 MobileNet-224 | 63.7 % | 149 | 1.3 | -| 0.25 MobileNet-224 | 50.6 % | 41 | 0.5 | ----------------------------------------------------------------------------- - -The following table describes the performance of -the 100 % MobileNet on various input sizes: ------------------------------------------------------------------------- - Resolution | ImageNet Acc | Multiply-Adds (M) | Params (M) ------------------------------------------------------------------------- -| 1.0 MobileNet-224 | 70.6 % | 529 | 4.2 | -| 1.0 MobileNet-192 | 69.1 % | 529 | 4.2 | -| 1.0 MobileNet-160 | 67.2 % | 529 | 4.2 | -| 1.0 MobileNet-128 | 64.4 % | 529 | 4.2 | ------------------------------------------------------------------------- - -The weights for all 16 models are obtained and translated -from TensorFlow checkpoints found at -https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md - -# Reference -- [MobileNets: Efficient Convolutional Neural Networks for - Mobile Vision Applications](https://arxiv.org/pdf/1704.04861.pdf)) """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - -from tensorflow.python.keras import backend as K -from tensorflow.python.keras import constraints -from tensorflow.python.keras import initializers -from tensorflow.python.keras import regularizers -from tensorflow.python.keras.applications import imagenet_utils -from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine.network import get_source_inputs -from tensorflow.python.keras.layers import Activation -from tensorflow.python.keras.layers import BatchNormalization -from tensorflow.python.keras.layers import Conv2D -from tensorflow.python.keras.layers import DepthwiseConv2D -from tensorflow.python.keras.layers import Dropout -from tensorflow.python.keras.layers import GlobalAveragePooling2D -from tensorflow.python.keras.layers import GlobalMaxPooling2D -from tensorflow.python.keras.layers import Input -from tensorflow.python.keras.layers import Reshape -from tensorflow.python.keras.layers import ZeroPadding2D -from tensorflow.python.keras.models import Model -from tensorflow.python.keras.utils import conv_utils -from tensorflow.python.keras.utils.data_utils import get_file -from tensorflow.python.platform import tf_logging as logging +from keras_applications import mobilenet from tensorflow.python.util.tf_export import tf_export +MobileNet = mobilenet.MobileNet +decode_predictions = mobilenet.decode_predictions +preprocess_input = mobilenet.preprocess_input -BASE_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.6/' - - -def relu6(x): - return K.relu(x, max_value=6) - - -@tf_export('keras.applications.mobilenet.preprocess_input') -def preprocess_input(x): - """Preprocesses a numpy array encoding a batch of images. - - Arguments: - x: a 4D numpy array consists of RGB values within [0, 255]. - - Returns: - Preprocessed array. - """ - return imagenet_utils.preprocess_input(x, mode='tf') - - -@tf_export('keras.applications.MobileNet', - 'keras.applications.mobilenet.MobileNet') -def MobileNet(input_shape=None, - alpha=1.0, - depth_multiplier=1, - dropout=1e-3, - include_top=True, - weights='imagenet', - input_tensor=None, - pooling=None, - classes=1000): - """Instantiates the MobileNet architecture. - - To load a MobileNet model via `load_model`, import the custom - objects `relu6` and pass them to the `custom_objects` parameter. - E.g. - model = load_model('mobilenet.h5', custom_objects={ - 'relu6': mobilenet.relu6}) - - Arguments: - input_shape: optional shape tuple, only to be specified - if `include_top` is False (otherwise the input shape - has to be `(224, 224, 3)` (with `channels_last` data format) - or (3, 224, 224) (with `channels_first` data format). - It should have exactly 3 inputs channels, - and width and height should be no smaller than 32. - E.g. `(200, 200, 3)` would be one valid value. - alpha: controls the width of the network. - - If `alpha` < 1.0, proportionally decreases the number - of filters in each layer. - - If `alpha` > 1.0, proportionally increases the number - of filters in each layer. - - If `alpha` = 1, default number of filters from the paper - are used at each layer. - depth_multiplier: depth multiplier for depthwise convolution - (also called the resolution multiplier) - dropout: dropout rate - include_top: whether to include the fully-connected - layer at the top of the network. - weights: one of `None` (random initialization), - 'imagenet' (pre-training on ImageNet), - or the path to the weights file to be loaded. - input_tensor: optional Keras tensor (i.e. output of - `layers.Input()`) - to use as image input for the model. - pooling: Optional pooling mode for feature extraction - when `include_top` is `False`. - - `None` means that the output of the model - will be the 4D tensor output of the - last convolutional layer. - - `avg` means that global average pooling - will be applied to the output of the - last convolutional layer, and thus - the output of the model will be a - 2D tensor. - - `max` means that global max pooling will - be applied. - classes: optional number of classes to classify images - into, only to be specified if `include_top` is True, and - if no `weights` argument is specified. - - Returns: - A Keras model instance. - - Raises: - ValueError: in case of invalid argument for `weights`, - or invalid input shape. - RuntimeError: If attempting to run this model with a - backend that does not support separable convolutions. - """ - - if not (weights in {'imagenet', None} or os.path.exists(weights)): - raise ValueError('The `weights` argument should be either ' - '`None` (random initialization), `imagenet` ' - '(pre-training on ImageNet), ' - 'or the path to the weights file to be loaded.') - - if weights == 'imagenet' and include_top and classes != 1000: - raise ValueError('If using `weights` as ImageNet with `include_top` ' - 'as true, `classes` should be 1000') - - # Determine proper input shape and default size. - if input_shape is None: - default_size = 224 - else: - if K.image_data_format() == 'channels_first': - rows = input_shape[1] - cols = input_shape[2] - else: - rows = input_shape[0] - cols = input_shape[1] - - if rows == cols and rows in [128, 160, 192, 224]: - default_size = rows - else: - default_size = 224 - - input_shape = _obtain_input_shape( - input_shape, - default_size=default_size, - min_size=32, - data_format=K.image_data_format(), - require_flatten=include_top, - weights=weights) - - if K.image_data_format() == 'channels_last': - row_axis, col_axis = (0, 1) - else: - row_axis, col_axis = (1, 2) - rows = input_shape[row_axis] - cols = input_shape[col_axis] - - if weights == 'imagenet': - if depth_multiplier != 1: - raise ValueError('If imagenet weights are being loaded, ' - 'depth multiplier must be 1') - - if alpha not in [0.25, 0.50, 0.75, 1.0]: - raise ValueError('If imagenet weights are being loaded, ' - 'alpha can be one of' - '`0.25`, `0.50`, `0.75` or `1.0` only.') - - if rows != cols or rows not in [128, 160, 192, 224]: - if rows is None: - rows = 224 - logging.warning('MobileNet shape is undefined.' - ' Weights for input shape (224, 224) will be loaded.') - else: - raise ValueError('If imagenet weights are being loaded, ' - 'input must have a static square shape (one of ' - '(128, 128), (160, 160), (192, 192), or (224, 224)).' - ' Input shape provided = %s' % (input_shape,)) - - if K.image_data_format() != 'channels_last': - logging.warning('The MobileNet family of models is only available ' - 'for the input data format "channels_last" ' - '(width, height, channels). ' - 'However your settings specify the default ' - 'data format "channels_first" (channels, width, height).' - ' You should set `image_data_format="channels_last"` ' - 'in your Keras config located at ~/.keras/keras.json. ' - 'The model being returned right now will expect inputs ' - 'to follow the "channels_last" data format.') - K.set_image_data_format('channels_last') - old_data_format = 'channels_first' - else: - old_data_format = None - - if input_tensor is None: - img_input = Input(shape=input_shape) - else: - if not K.is_keras_tensor(input_tensor): - img_input = Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - - x = _conv_block(img_input, 32, alpha, strides=(2, 2)) - x = _depthwise_conv_block(x, 64, alpha, depth_multiplier, block_id=1) - - x = _depthwise_conv_block( - x, 128, alpha, depth_multiplier, strides=(2, 2), block_id=2) - x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, block_id=3) - - x = _depthwise_conv_block( - x, 256, alpha, depth_multiplier, strides=(2, 2), block_id=4) - x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, block_id=5) - - x = _depthwise_conv_block( - x, 512, alpha, depth_multiplier, strides=(2, 2), block_id=6) - x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=7) - x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=8) - x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=9) - x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=10) - x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=11) - - x = _depthwise_conv_block( - x, 1024, alpha, depth_multiplier, strides=(2, 2), block_id=12) - x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, block_id=13) - - if include_top: - if K.image_data_format() == 'channels_first': - shape = (int(1024 * alpha), 1, 1) - else: - shape = (1, 1, int(1024 * alpha)) - - x = GlobalAveragePooling2D()(x) - x = Reshape(shape, name='reshape_1')(x) - x = Dropout(dropout, name='dropout')(x) - x = Conv2D(classes, (1, 1), padding='same', name='conv_preds')(x) - x = Activation('softmax', name='act_softmax')(x) - x = Reshape((classes,), name='reshape_2')(x) - else: - if pooling == 'avg': - x = GlobalAveragePooling2D()(x) - elif pooling == 'max': - x = GlobalMaxPooling2D()(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if input_tensor is not None: - inputs = get_source_inputs(input_tensor) - else: - inputs = img_input - - # Create model. - model = Model(inputs, x, name='mobilenet_%0.2f_%s' % (alpha, rows)) - - # load weights - if weights == 'imagenet': - if K.image_data_format() == 'channels_first': - raise ValueError('Weights for "channels_first" format ' - 'are not available.') - if alpha == 1.0: - alpha_text = '1_0' - elif alpha == 0.75: - alpha_text = '7_5' - elif alpha == 0.50: - alpha_text = '5_0' - else: - alpha_text = '2_5' - - if include_top: - model_name = 'mobilenet_%s_%d_tf.h5' % (alpha_text, rows) - weigh_path = BASE_WEIGHT_PATH + model_name - weights_path = get_file(model_name, weigh_path, cache_subdir='models') - else: - model_name = 'mobilenet_%s_%d_tf_no_top.h5' % (alpha_text, rows) - weigh_path = BASE_WEIGHT_PATH + model_name - weights_path = get_file(model_name, weigh_path, cache_subdir='models') - model.load_weights(weights_path) - elif weights is not None: - model.load_weights(weights) - - if old_data_format: - K.set_image_data_format(old_data_format) - return model - - -def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)): - """Adds an initial convolution layer (with batch normalization and relu6). - - Arguments: - inputs: Input tensor of shape `(rows, cols, 3)` - (with `channels_last` data format) or - (3, rows, cols) (with `channels_first` data format). - It should have exactly 3 inputs channels, - and width and height should be no smaller than 32. - E.g. `(224, 224, 3)` would be one valid value. - filters: Integer, the dimensionality of the output space - (i.e. the number of output filters in the convolution). - alpha: controls the width of the network. - - If `alpha` < 1.0, proportionally decreases the number - of filters in each layer. - - If `alpha` > 1.0, proportionally increases the number - of filters in each layer. - - If `alpha` = 1, default number of filters from the paper - are used at each layer. - kernel: An integer or tuple/list of 2 integers, specifying the - width and height of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the width and height. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - - Input shape: - 4D tensor with shape: - `(samples, channels, rows, cols)` if data_format='channels_first' - or 4D tensor with shape: - `(samples, rows, cols, channels)` if data_format='channels_last'. - - Output shape: - 4D tensor with shape: - `(samples, filters, new_rows, new_cols)` if data_format='channels_first' - or 4D tensor with shape: - `(samples, new_rows, new_cols, filters)` if data_format='channels_last'. - `rows` and `cols` values might have changed due to stride. - - Returns: - Output tensor of block. - """ - channel_axis = 1 if K.image_data_format() == 'channels_first' else -1 - filters = int(filters * alpha) - x = ZeroPadding2D(padding=(1, 1), name='conv1_pad')(inputs) - x = Conv2D( - filters, - kernel, - padding='valid', - use_bias=False, - strides=strides, - name='conv1')(x) - x = BatchNormalization(axis=channel_axis, name='conv1_bn')(x) - return Activation(relu6, name='conv1_relu')(x) - - -def _depthwise_conv_block(inputs, - pointwise_conv_filters, - alpha, - depth_multiplier=1, - strides=(1, 1), - block_id=1): - """Adds a depthwise convolution block. - - A depthwise convolution block consists of a depthwise conv, - batch normalization, relu6, pointwise convolution, - batch normalization and relu6 activation. - - Arguments: - inputs: Input tensor of shape `(rows, cols, channels)` - (with `channels_last` data format) or - (channels, rows, cols) (with `channels_first` data format). - pointwise_conv_filters: Integer, the dimensionality of the output space - (i.e. the number of output filters in the pointwise convolution). - alpha: controls the width of the network. - - If `alpha` < 1.0, proportionally decreases the number - of filters in each layer. - - If `alpha` > 1.0, proportionally increases the number - of filters in each layer. - - If `alpha` = 1, default number of filters from the paper - are used at each layer. - depth_multiplier: The number of depthwise convolution output channels - for each input channel. - The total number of depthwise convolution output - channels will be equal to `filters_in * depth_multiplier`. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the width and height. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - block_id: Integer, a unique identification designating the block number. - - Input shape: - 4D tensor with shape: - `(batch, channels, rows, cols)` if data_format='channels_first' - or 4D tensor with shape: - `(batch, rows, cols, channels)` if data_format='channels_last'. - - Output shape: - 4D tensor with shape: - `(batch, filters, new_rows, new_cols)` if data_format='channels_first' - or 4D tensor with shape: - `(batch, new_rows, new_cols, filters)` if data_format='channels_last'. - `rows` and `cols` values might have changed due to stride. - - Returns: - Output tensor of block. - """ - channel_axis = 1 if K.image_data_format() == 'channels_first' else -1 - pointwise_conv_filters = int(pointwise_conv_filters * alpha) - x = ZeroPadding2D(padding=(1, 1), name='conv_pad_%d' % block_id)(inputs) - x = DepthwiseConv2D( # pylint: disable=not-callable - (3, 3), - padding='valid', - depth_multiplier=depth_multiplier, - strides=strides, - use_bias=False, - name='conv_dw_%d' % block_id)(x) - x = BatchNormalization(axis=channel_axis, name='conv_dw_%d_bn' % block_id)(x) - x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x) - - x = Conv2D( - pointwise_conv_filters, (1, 1), - padding='same', - use_bias=False, - strides=(1, 1), - name='conv_pw_%d' % block_id)( - x) - x = BatchNormalization(axis=channel_axis, name='conv_pw_%d_bn' % block_id)(x) - return Activation(relu6, name='conv_pw_%d_relu' % block_id)(x) +tf_export('keras.applications.mobilenet.MobileNet', + 'keras.applications.MobileNet')(MobileNet) +tf_export('keras.applications.mobilenet.preprocess_input')(preprocess_input) diff --git a/tensorflow/python/keras/applications/mobilenet_test.py b/tensorflow/python/keras/applications/mobilenet_test.py deleted file mode 100644 index 5661ed7856ad6e307cf3e388ea3db98c69db983f..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/applications/mobilenet_test.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for MobileNet application.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python import keras -from tensorflow.python.platform import test - - -class MobileNetTest(test.TestCase): - - def test_with_top(self): - model = keras.applications.MobileNet(weights=None) - self.assertEqual(model.output_shape, (None, 1000)) - - def test_no_top(self): - model = keras.applications.MobileNet(weights=None, include_top=False) - self.assertEqual(model.output_shape, (None, None, None, 1024)) - - def test_with_pooling(self): - model = keras.applications.MobileNet(weights=None, - include_top=False, - pooling='avg') - self.assertEqual(model.output_shape, (None, 1024)) - - def test_weight_loading(self): - with self.assertRaises(ValueError): - keras.applications.MobileNet(weights='unknown', - include_top=False) - with self.assertRaises(ValueError): - keras.applications.MobileNet(weights='imagenet', - classes=2000) - - def test_preprocess_input(self): - x = np.random.uniform(0, 255, (2, 300, 200, 3)) - out1 = keras.applications.mobilenet.preprocess_input(x) - self.assertAllClose(np.mean(out1), 0., atol=0.1) - - def test_invalid_use_cases(self): - keras.backend.set_image_data_format('channels_first') - model = keras.applications.MobileNet(weights=None) - self.assertEqual(model.output_shape, (None, 1000)) - keras.backend.set_image_data_format('channels_last') - - def test_mobilenet_variable_input_channels(self): - input_shape = (None, None, 1) - model = keras.applications.MobileNet(weights=None, - include_top=False, - input_shape=input_shape) - self.assertEqual(model.output_shape, (None, None, None, 1024)) - - input_shape = (None, None, 4) - model = keras.applications.MobileNet(weights=None, - include_top=False, - input_shape=input_shape) - self.assertEqual(model.output_shape, (None, None, None, 1024)) - - def test_mobilenet_image_size(self): - with self.test_session(): - valid_image_sizes = [128, 160, 192, 224] - for size in valid_image_sizes: - keras.backend.set_image_data_format('channels_last') - input_shape = (size, size, 3) - model = keras.applications.MobileNet(input_shape=input_shape, - weights=None, - include_top=True) - self.assertEqual(model.input_shape, (None,) + input_shape) - - keras.backend.set_image_data_format('channels_first') - input_shape = (3, size, size) - model = keras.applications.MobileNet(input_shape=input_shape, - weights=None, - include_top=True) - self.assertEqual(model.input_shape, (None,) + input_shape) - - keras.backend.set_image_data_format('channels_last') - invalid_image_shape = (112, 112, 3) - with self.assertRaises(ValueError): - model = keras.applications.MobileNet(input_shape=invalid_image_shape, - weights='imagenet', - include_top=True) - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/keras/applications/mobilenet_v2.py b/tensorflow/python/keras/applications/mobilenet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..9194c3ee14840f74e56a0290d2febdf1b3458cc6 --- /dev/null +++ b/tensorflow/python/keras/applications/mobilenet_v2.py @@ -0,0 +1,22 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=invalid-name +"""MobileNet v2 models for Keras. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# TODO(fchollet): export MobileNetV2 as part of the public API in next version. diff --git a/tensorflow/python/keras/applications/nasnet.py b/tensorflow/python/keras/applications/nasnet.py index b521bc673139403dcdecbba8e35b5bafec2d42bf..26ff5db53f913d78ad4f054fee30ff66d5f7dcbf 100644 --- a/tensorflow/python/keras/applications/nasnet.py +++ b/tensorflow/python/keras/applications/nasnet.py @@ -12,784 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -# pylint: disable=line-too-long # pylint: disable=invalid-name -# pylint: disable=unused-import """NASNet-A models for Keras. - -NASNet refers to Neural Architecture Search Network, a family of models -that were designed automatically by learning the model architectures -directly on the dataset of interest. - -Here we consider NASNet-A, the highest performance model that was found -for the CIFAR-10 dataset, and then extended to ImageNet 2012 dataset, -obtaining state of the art performance on CIFAR-10 and ImageNet 2012. -Only the NASNet-A models, and their respective weights, which are suited -for ImageNet 2012 are provided. - -The below table describes the performance on ImageNet 2012: --------------------------------------------------------------------------------- - Architecture | Top-1 Acc | Top-5 Acc | Multiply-Adds | Params (M) --------------------------------------------------------------------------------- -| NASNet-A (4 @ 1056) | 74.0 % | 91.6 % | 564 M | 5.3 | -| NASNet-A (6 @ 4032) | 82.7 % | 96.2 % | 23.8 B | 88.9 | --------------------------------------------------------------------------------- - -References: - - [Learning Transferable Architectures for Scalable Image Recognition] - (https://arxiv.org/abs/1707.07012) """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - -from tensorflow.python.keras import backend as K -from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras.applications.inception_v3 import preprocess_input -from tensorflow.python.keras.engine.network import get_source_inputs -from tensorflow.python.keras.layers import Activation -from tensorflow.python.keras.layers import add -from tensorflow.python.keras.layers import AveragePooling2D -from tensorflow.python.keras.layers import BatchNormalization -from tensorflow.python.keras.layers import concatenate -from tensorflow.python.keras.layers import Conv2D -from tensorflow.python.keras.layers import Cropping2D -from tensorflow.python.keras.layers import Dense -from tensorflow.python.keras.layers import GlobalAveragePooling2D -from tensorflow.python.keras.layers import GlobalMaxPooling2D -from tensorflow.python.keras.layers import Input -from tensorflow.python.keras.layers import MaxPooling2D -from tensorflow.python.keras.layers import SeparableConv2D -from tensorflow.python.keras.layers import ZeroPadding2D -from tensorflow.python.keras.models import Model -from tensorflow.python.keras.utils.data_utils import get_file -from tensorflow.python.platform import tf_logging as logging +from keras_applications import nasnet from tensorflow.python.util.tf_export import tf_export +NASNetMobile = nasnet.NASNetMobile +NASNetLarge = nasnet.NASNetLarge +decode_predictions = nasnet.decode_predictions +preprocess_input = nasnet.preprocess_input -NASNET_MOBILE_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-mobile.h5' -NASNET_MOBILE_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-mobile-no-top.h5' -NASNET_LARGE_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-large.h5' -NASNET_LARGE_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-large-no-top.h5' - - -def NASNet(input_shape=None, - penultimate_filters=4032, - num_blocks=6, - stem_block_filters=96, - skip_reduction=True, - filter_multiplier=2, - include_top=True, - weights=None, - input_tensor=None, - pooling=None, - classes=1000, - default_size=None): - """Instantiates a NASNet model. - - Note that only TensorFlow is supported for now, - therefore it only works with the data format - `image_data_format='channels_last'` in your Keras config - at `~/.keras/keras.json`. - - Arguments: - input_shape: Optional shape tuple, the input shape - is by default `(331, 331, 3)` for NASNetLarge and - `(224, 224, 3)` for NASNetMobile. - It should have exactly 3 inputs channels, - and width and height should be no smaller than 32. - E.g. `(224, 224, 3)` would be one valid value. - penultimate_filters: Number of filters in the penultimate layer. - NASNet models use the notation `NASNet (N @ P)`, where: - - N is the number of blocks - - P is the number of penultimate filters - num_blocks: Number of repeated blocks of the NASNet model. - NASNet models use the notation `NASNet (N @ P)`, where: - - N is the number of blocks - - P is the number of penultimate filters - stem_block_filters: Number of filters in the initial stem block - skip_reduction: Whether to skip the reduction step at the tail - end of the network. Set to `False` for CIFAR models. - filter_multiplier: Controls the width of the network. - - If `filter_multiplier` < 1.0, proportionally decreases the number - of filters in each layer. - - If `filter_multiplier` > 1.0, proportionally increases the number - of filters in each layer. - - If `filter_multiplier` = 1, default number of filters from the - paper are used at each layer. - include_top: Whether to include the fully-connected - layer at the top of the network. - weights: `None` (random initialization) or - `imagenet` (ImageNet weights) - input_tensor: Optional Keras tensor (i.e. output of - `layers.Input()`) - to use as image input for the model. - pooling: Optional pooling mode for feature extraction - when `include_top` is `False`. - - `None` means that the output of the model - will be the 4D tensor output of the - last convolutional layer. - - `avg` means that global average pooling - will be applied to the output of the - last convolutional layer, and thus - the output of the model will be a - 2D tensor. - - `max` means that global max pooling will - be applied. - classes: Optional number of classes to classify images - into, only to be specified if `include_top` is True, and - if no `weights` argument is specified. - default_size: Specifies the default image size of the model - - Returns: - A Keras model instance. - - Raises: - ValueError: In case of invalid argument for `weights`, - invalid input shape or invalid `penultimate_filters` value. - RuntimeError: If attempting to run this model with a - backend that does not support separable convolutions. - """ - if K.backend() != 'tensorflow': - raise RuntimeError('Only Tensorflow backend is currently supported, ' - 'as other backends do not support ' - 'separable convolution.') - - if not (weights in {'imagenet', None} or os.path.exists(weights)): - raise ValueError('The `weights` argument should be either ' - '`None` (random initialization), `imagenet` ' - '(pre-training on ImageNet), ' - 'or the path to the weights file to be loaded.') - - if weights == 'imagenet' and include_top and classes != 1000: - raise ValueError('If using `weights` as ImageNet with `include_top` ' - 'as true, `classes` should be 1000') - - if (isinstance(input_shape, tuple) and None in input_shape and - weights == 'imagenet'): - raise ValueError('When specifying the input shape of a NASNet' - ' and loading `ImageNet` weights, ' - 'the input_shape argument must be static ' - '(no None entries). Got: `input_shape=' + - str(input_shape) + '`.') - - if default_size is None: - default_size = 331 - - # Determine proper input shape and default size. - input_shape = _obtain_input_shape( - input_shape, - default_size=default_size, - min_size=32, - data_format=K.image_data_format(), - require_flatten=False, - weights=weights) - - if K.image_data_format() != 'channels_last': - logging.warning('The NASNet family of models is only available ' - 'for the input data format "channels_last" ' - '(width, height, channels). ' - 'However your settings specify the default ' - 'data format "channels_first" (channels, width, height).' - ' You should set `image_data_format="channels_last"` ' - 'in your Keras config located at ~/.keras/keras.json. ' - 'The model being returned right now will expect inputs ' - 'to follow the "channels_last" data format.') - K.set_image_data_format('channels_last') - old_data_format = 'channels_first' - else: - old_data_format = None - - if input_tensor is None: - img_input = Input(shape=input_shape) - else: - if not K.is_keras_tensor(input_tensor): - img_input = Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - - if penultimate_filters % 24 != 0: - raise ValueError( - 'For NASNet-A models, the value of `penultimate_filters` ' - 'needs to be divisible by 24. Current value: %d' % penultimate_filters) - - channel_dim = 1 if K.image_data_format() == 'channels_first' else -1 - filters = penultimate_filters // 24 - - if not skip_reduction: - x = Conv2D( - stem_block_filters, (3, 3), - strides=(2, 2), - padding='valid', - use_bias=False, - name='stem_conv1', - kernel_initializer='he_normal')( - img_input) - else: - x = Conv2D( - stem_block_filters, (3, 3), - strides=(1, 1), - padding='same', - use_bias=False, - name='stem_conv1', - kernel_initializer='he_normal')( - img_input) - - x = BatchNormalization( - axis=channel_dim, momentum=0.9997, epsilon=1e-3, name='stem_bn1')( - x) - - p = None - if not skip_reduction: # imagenet / mobile mode - x, p = _reduction_a_cell( - x, p, filters // (filter_multiplier**2), block_id='stem_1') - x, p = _reduction_a_cell( - x, p, filters // filter_multiplier, block_id='stem_2') - - for i in range(num_blocks): - x, p = _normal_a_cell(x, p, filters, block_id='%d' % (i)) - - x, p0 = _reduction_a_cell( - x, p, filters * filter_multiplier, block_id='reduce_%d' % (num_blocks)) - - p = p0 if not skip_reduction else p - - for i in range(num_blocks): - x, p = _normal_a_cell( - x, p, filters * filter_multiplier, block_id='%d' % (num_blocks + i + 1)) - - x, p0 = _reduction_a_cell( - x, - p, - filters * filter_multiplier**2, - block_id='reduce_%d' % (2 * num_blocks)) - - p = p0 if not skip_reduction else p - - for i in range(num_blocks): - x, p = _normal_a_cell( - x, - p, - filters * filter_multiplier**2, - block_id='%d' % (2 * num_blocks + i + 1)) - - x = Activation('relu')(x) - - if include_top: - x = GlobalAveragePooling2D()(x) - x = Dense(classes, activation='softmax', name='predictions')(x) - else: - if pooling == 'avg': - x = GlobalAveragePooling2D()(x) - elif pooling == 'max': - x = GlobalMaxPooling2D()(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if input_tensor is not None: - inputs = get_source_inputs(input_tensor) - else: - inputs = img_input - - model = Model(inputs, x, name='NASNet') - - # load weights - if weights == 'imagenet': - if default_size == 224: # mobile version - if include_top: - weight_path = NASNET_MOBILE_WEIGHT_PATH - model_name = 'nasnet_mobile.h5' - else: - weight_path = NASNET_MOBILE_WEIGHT_PATH_NO_TOP - model_name = 'nasnet_mobile_no_top.h5' - - weights_file = get_file(model_name, weight_path, cache_subdir='models') - model.load_weights(weights_file) - - elif default_size == 331: # large version - if include_top: - weight_path = NASNET_LARGE_WEIGHT_PATH - model_name = 'nasnet_large.h5' - else: - weight_path = NASNET_LARGE_WEIGHT_PATH_NO_TOP - model_name = 'nasnet_large_no_top.h5' - - weights_file = get_file(model_name, weight_path, cache_subdir='models') - model.load_weights(weights_file) - else: - raise ValueError('ImageNet weights can only be loaded with NASNetLarge' - ' or NASNetMobile') - elif weights is not None: - model.load_weights(weights) - - if old_data_format: - K.set_image_data_format(old_data_format) - - return model - - -@tf_export('keras.applications.NASNetLarge', - 'keras.applications.nasnet.NASNetLarge') -def NASNetLarge(input_shape=None, - include_top=True, - weights='imagenet', - input_tensor=None, - pooling=None, - classes=1000): - """Instantiates a NASNet model in ImageNet mode. - - Note that only TensorFlow is supported for now, - therefore it only works with the data format - `image_data_format='channels_last'` in your Keras config - at `~/.keras/keras.json`. - - Arguments: - input_shape: Optional shape tuple, only to be specified - if `include_top` is False (otherwise the input shape - has to be `(331, 331, 3)` for NASNetLarge. - It should have exactly 3 inputs channels, - and width and height should be no smaller than 32. - E.g. `(224, 224, 3)` would be one valid value. - include_top: Whether to include the fully-connected - layer at the top of the network. - weights: `None` (random initialization) or - `imagenet` (ImageNet weights) - input_tensor: Optional Keras tensor (i.e. output of - `layers.Input()`) - to use as image input for the model. - pooling: Optional pooling mode for feature extraction - when `include_top` is `False`. - - `None` means that the output of the model - will be the 4D tensor output of the - last convolutional layer. - - `avg` means that global average pooling - will be applied to the output of the - last convolutional layer, and thus - the output of the model will be a - 2D tensor. - - `max` means that global max pooling will - be applied. - classes: Optional number of classes to classify images - into, only to be specified if `include_top` is True, and - if no `weights` argument is specified. - - Returns: - A Keras model instance. - - Raises: - ValueError: in case of invalid argument for `weights`, - or invalid input shape. - RuntimeError: If attempting to run this model with a - backend that does not support separable convolutions. - """ - return NASNet( - input_shape, - penultimate_filters=4032, - num_blocks=6, - stem_block_filters=96, - skip_reduction=False, - filter_multiplier=2, - include_top=include_top, - weights=weights, - input_tensor=input_tensor, - pooling=pooling, - classes=classes, - default_size=331) - - -@tf_export('keras.applications.NASNetMobile', - 'keras.applications.nasnet.NASNetMobile') -def NASNetMobile(input_shape=None, - include_top=True, - weights='imagenet', - input_tensor=None, - pooling=None, - classes=1000): - """Instantiates a Mobile NASNet model in ImageNet mode. - - Note that only TensorFlow is supported for now, - therefore it only works with the data format - `image_data_format='channels_last'` in your Keras config - at `~/.keras/keras.json`. - - Arguments: - input_shape: Optional shape tuple, only to be specified - if `include_top` is False (otherwise the input shape - has to be `(224, 224, 3)` for NASNetMobile - It should have exactly 3 inputs channels, - and width and height should be no smaller than 32. - E.g. `(224, 224, 3)` would be one valid value. - include_top: Whether to include the fully-connected - layer at the top of the network. - weights: `None` (random initialization) or - `imagenet` (ImageNet weights) - input_tensor: Optional Keras tensor (i.e. output of - `layers.Input()`) - to use as image input for the model. - pooling: Optional pooling mode for feature extraction - when `include_top` is `False`. - - `None` means that the output of the model - will be the 4D tensor output of the - last convolutional layer. - - `avg` means that global average pooling - will be applied to the output of the - last convolutional layer, and thus - the output of the model will be a - 2D tensor. - - `max` means that global max pooling will - be applied. - classes: Optional number of classes to classify images - into, only to be specified if `include_top` is True, and - if no `weights` argument is specified. - - Returns: - A Keras model instance. - - Raises: - ValueError: In case of invalid argument for `weights`, - or invalid input shape. - RuntimeError: If attempting to run this model with a - backend that does not support separable convolutions. - """ - return NASNet( - input_shape, - penultimate_filters=1056, - num_blocks=4, - stem_block_filters=32, - skip_reduction=False, - filter_multiplier=2, - include_top=include_top, - weights=weights, - input_tensor=input_tensor, - pooling=pooling, - classes=classes, - default_size=224) - - -def _separable_conv_block(ip, - filters, - kernel_size=(3, 3), - strides=(1, 1), - block_id=None): - """Adds 2 blocks of [relu-separable conv-batchnorm]. - - Arguments: - ip: Input tensor - filters: Number of output filters per layer - kernel_size: Kernel size of separable convolutions - strides: Strided convolution for downsampling - block_id: String block_id - - Returns: - A Keras tensor - """ - channel_dim = 1 if K.image_data_format() == 'channels_first' else -1 - - with K.name_scope('separable_conv_block_%s' % block_id): - x = Activation('relu')(ip) - x = SeparableConv2D( - filters, - kernel_size, - strides=strides, - name='separable_conv_1_%s' % block_id, - padding='same', - use_bias=False, - kernel_initializer='he_normal')( - x) - x = BatchNormalization( - axis=channel_dim, - momentum=0.9997, - epsilon=1e-3, - name='separable_conv_1_bn_%s' % (block_id))( - x) - x = Activation('relu')(x) - x = SeparableConv2D( - filters, - kernel_size, - name='separable_conv_2_%s' % block_id, - padding='same', - use_bias=False, - kernel_initializer='he_normal')( - x) - x = BatchNormalization( - axis=channel_dim, - momentum=0.9997, - epsilon=1e-3, - name='separable_conv_2_bn_%s' % (block_id))( - x) - return x - - -def _adjust_block(p, ip, filters, block_id=None): - """Adjusts the input `previous path` to match the shape of the `input`. - - Used in situations where the output number of filters needs to be changed. - - Arguments: - p: Input tensor which needs to be modified - ip: Input tensor whose shape needs to be matched - filters: Number of output filters to be matched - block_id: String block_id - - Returns: - Adjusted Keras tensor - """ - channel_dim = 1 if K.image_data_format() == 'channels_first' else -1 - img_dim = 2 if K.image_data_format() == 'channels_first' else -2 - - ip_shape = K.int_shape(ip) - - if p is not None: - p_shape = K.int_shape(p) - - with K.name_scope('adjust_block'): - if p is None: - p = ip - - elif p_shape[img_dim] != ip_shape[img_dim]: - with K.name_scope('adjust_reduction_block_%s' % block_id): - p = Activation('relu', name='adjust_relu_1_%s' % block_id)(p) - - p1 = AveragePooling2D( - (1, 1), - strides=(2, 2), - padding='valid', - name='adjust_avg_pool_1_%s' % block_id)( - p) - p1 = Conv2D( - filters // 2, (1, 1), - padding='same', - use_bias=False, - name='adjust_conv_1_%s' % block_id, - kernel_initializer='he_normal')( - p1) - - p2 = ZeroPadding2D(padding=((0, 1), (0, 1)))(p) - p2 = Cropping2D(cropping=((1, 0), (1, 0)))(p2) - p2 = AveragePooling2D( - (1, 1), - strides=(2, 2), - padding='valid', - name='adjust_avg_pool_2_%s' % block_id)( - p2) - p2 = Conv2D( - filters // 2, (1, 1), - padding='same', - use_bias=False, - name='adjust_conv_2_%s' % block_id, - kernel_initializer='he_normal')( - p2) - - p = concatenate([p1, p2], axis=channel_dim) - p = BatchNormalization( - axis=channel_dim, - momentum=0.9997, - epsilon=1e-3, - name='adjust_bn_%s' % block_id)( - p) - - elif p_shape[channel_dim] != filters: - with K.name_scope('adjust_projection_block_%s' % block_id): - p = Activation('relu')(p) - p = Conv2D( - filters, (1, 1), - strides=(1, 1), - padding='same', - name='adjust_conv_projection_%s' % block_id, - use_bias=False, - kernel_initializer='he_normal')( - p) - p = BatchNormalization( - axis=channel_dim, - momentum=0.9997, - epsilon=1e-3, - name='adjust_bn_%s' % block_id)( - p) - return p - - -def _normal_a_cell(ip, p, filters, block_id=None): - """Adds a Normal cell for NASNet-A (Fig. 4 in the paper). - - Arguments: - ip: Input tensor `x` - p: Input tensor `p` - filters: Number of output filters - block_id: String block_id - - Returns: - A Keras tensor - """ - channel_dim = 1 if K.image_data_format() == 'channels_first' else -1 - - with K.name_scope('normal_A_block_%s' % block_id): - p = _adjust_block(p, ip, filters, block_id) - - h = Activation('relu')(ip) - h = Conv2D( - filters, (1, 1), - strides=(1, 1), - padding='same', - name='normal_conv_1_%s' % block_id, - use_bias=False, - kernel_initializer='he_normal')( - h) - h = BatchNormalization( - axis=channel_dim, - momentum=0.9997, - epsilon=1e-3, - name='normal_bn_1_%s' % block_id)( - h) - - with K.name_scope('block_1'): - x1_1 = _separable_conv_block( - h, filters, kernel_size=(5, 5), block_id='normal_left1_%s' % block_id) - x1_2 = _separable_conv_block( - p, filters, block_id='normal_right1_%s' % block_id) - x1 = add([x1_1, x1_2], name='normal_add_1_%s' % block_id) - - with K.name_scope('block_2'): - x2_1 = _separable_conv_block( - p, filters, (5, 5), block_id='normal_left2_%s' % block_id) - x2_2 = _separable_conv_block( - p, filters, (3, 3), block_id='normal_right2_%s' % block_id) - x2 = add([x2_1, x2_2], name='normal_add_2_%s' % block_id) - - with K.name_scope('block_3'): - x3 = AveragePooling2D( - (3, 3), - strides=(1, 1), - padding='same', - name='normal_left3_%s' % (block_id))( - h) - x3 = add([x3, p], name='normal_add_3_%s' % block_id) - - with K.name_scope('block_4'): - x4_1 = AveragePooling2D( - (3, 3), - strides=(1, 1), - padding='same', - name='normal_left4_%s' % (block_id))( - p) - x4_2 = AveragePooling2D( - (3, 3), - strides=(1, 1), - padding='same', - name='normal_right4_%s' % (block_id))( - p) - x4 = add([x4_1, x4_2], name='normal_add_4_%s' % block_id) - - with K.name_scope('block_5'): - x5 = _separable_conv_block( - h, filters, block_id='normal_left5_%s' % block_id) - x5 = add([x5, h], name='normal_add_5_%s' % block_id) - - x = concatenate( - [p, x1, x2, x3, x4, x5], - axis=channel_dim, - name='normal_concat_%s' % block_id) - return x, ip - - -def _reduction_a_cell(ip, p, filters, block_id=None): - """Adds a Reduction cell for NASNet-A (Fig. 4 in the paper). - - Arguments: - ip: Input tensor `x` - p: Input tensor `p` - filters: Number of output filters - block_id: String block_id - - Returns: - A Keras tensor - """ - channel_dim = 1 if K.image_data_format() == 'channels_first' else -1 - - with K.name_scope('reduction_A_block_%s' % block_id): - p = _adjust_block(p, ip, filters, block_id) - - h = Activation('relu')(ip) - h = Conv2D( - filters, (1, 1), - strides=(1, 1), - padding='same', - name='reduction_conv_1_%s' % block_id, - use_bias=False, - kernel_initializer='he_normal')( - h) - h = BatchNormalization( - axis=channel_dim, - momentum=0.9997, - epsilon=1e-3, - name='reduction_bn_1_%s' % block_id)( - h) - - with K.name_scope('block_1'): - x1_1 = _separable_conv_block( - h, - filters, (5, 5), - strides=(2, 2), - block_id='reduction_left1_%s' % block_id) - x1_2 = _separable_conv_block( - p, - filters, (7, 7), - strides=(2, 2), - block_id='reduction_1_%s' % block_id) - x1 = add([x1_1, x1_2], name='reduction_add_1_%s' % block_id) - - with K.name_scope('block_2'): - x2_1 = MaxPooling2D( - (3, 3), - strides=(2, 2), - padding='same', - name='reduction_left2_%s' % block_id)( - h) - x2_2 = _separable_conv_block( - p, - filters, (7, 7), - strides=(2, 2), - block_id='reduction_right2_%s' % block_id) - x2 = add([x2_1, x2_2], name='reduction_add_2_%s' % block_id) - - with K.name_scope('block_3'): - x3_1 = AveragePooling2D( - (3, 3), - strides=(2, 2), - padding='same', - name='reduction_left3_%s' % block_id)( - h) - x3_2 = _separable_conv_block( - p, - filters, (5, 5), - strides=(2, 2), - block_id='reduction_right3_%s' % block_id) - x3 = add([x3_1, x3_2], name='reduction_add3_%s' % block_id) - - with K.name_scope('block_4'): - x4 = AveragePooling2D( - (3, 3), - strides=(1, 1), - padding='same', - name='reduction_left4_%s' % block_id)( - x1) - x4 = add([x2, x4]) - - with K.name_scope('block_5'): - x5_1 = _separable_conv_block( - x1, filters, (3, 3), block_id='reduction_left4_%s' % block_id) - x5_2 = MaxPooling2D( - (3, 3), - strides=(2, 2), - padding='same', - name='reduction_right5_%s' % block_id)( - h) - x5 = add([x5_1, x5_2], name='reduction_add4_%s' % block_id) - - x = concatenate( - [x2, x3, x4, x5], - axis=channel_dim, - name='reduction_concat_%s' % block_id) - return x, ip +tf_export('keras.applications.nasnet.NASNetMobile', + 'keras.applications.NASNetMobile')(NASNetMobile) +tf_export('keras.applications.nasnet.NASNetLarge', + 'keras.applications.NASNetLarge')(NASNetLarge) +tf_export('keras.applications.nasnet.preprocess_input')(preprocess_input) diff --git a/tensorflow/python/keras/applications/nasnet_test.py b/tensorflow/python/keras/applications/nasnet_test.py deleted file mode 100644 index f96c3aa51c17ff3a123ad1a22ceff6c23f69d311..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/applications/nasnet_test.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Nasnet application.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python import keras -from tensorflow.python.platform import test - - -class NASNetMobileTest(test.TestCase): - - def test_with_top(self): - model = keras.applications.NASNetMobile(weights=None) - self.assertEqual(model.output_shape, (None, 1000)) - - def test_no_top(self): - model = keras.applications.NASNetMobile(weights=None, include_top=False) - self.assertEqual(model.output_shape, (None, None, None, 1056)) - - def test_with_pooling(self): - model = keras.applications.NASNetMobile(weights=None, - include_top=False, - pooling='avg') - self.assertEqual(model.output_shape, (None, 1056)) - - def test_weight_loading(self): - with self.assertRaises(ValueError): - keras.applications.NASNetMobile(weights='unknown', - include_top=False) - with self.assertRaises(ValueError): - keras.applications.NASNetMobile(weights='imagenet', - classes=2000) - - -class NASNetLargeTest(test.TestCase): - - def test_with_top(self): - model = keras.applications.NASNetLarge(weights=None) - self.assertEqual(model.output_shape, (None, 1000)) - - def test_no_top(self): - model = keras.applications.NASNetLarge(weights=None, include_top=False) - self.assertEqual(model.output_shape, (None, None, None, 4032)) - - def test_with_pooling(self): - model = keras.applications.NASNetLarge(weights=None, - include_top=False, - pooling='avg') - self.assertEqual(model.output_shape, (None, 4032)) - - def test_weight_loading(self): - with self.assertRaises(ValueError): - keras.applications.NASNetLarge(weights='unknown', - include_top=False) - with self.assertRaises(ValueError): - keras.applications.NASNetLarge(weights='imagenet', - classes=2000) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/keras/applications/resnet50.py b/tensorflow/python/keras/applications/resnet50.py index 508550f445e39dcf2a249bc91aaee289abfe3d1f..4d804a3c440287f44ad50b34e575474cfbf051e1 100644 --- a/tensorflow/python/keras/applications/resnet50.py +++ b/tensorflow/python/keras/applications/resnet50.py @@ -13,292 +13,18 @@ # limitations under the License. # ============================================================================== # pylint: disable=invalid-name -# pylint: disable=unused-import """ResNet50 model for Keras. - -# Reference: - -- [Deep Residual Learning for Image -Recognition](https://arxiv.org/abs/1512.03385) - -Adapted from code contributed by BigMoyan. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - -from tensorflow.python.keras import backend as K -from tensorflow.python.keras import layers -from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras.applications.imagenet_utils import preprocess_input -from tensorflow.python.keras.engine.network import get_source_inputs -from tensorflow.python.keras.layers import Activation -from tensorflow.python.keras.layers import AveragePooling2D -from tensorflow.python.keras.layers import BatchNormalization -from tensorflow.python.keras.layers import Conv2D -from tensorflow.python.keras.layers import Dense -from tensorflow.python.keras.layers import Flatten -from tensorflow.python.keras.layers import GlobalAveragePooling2D -from tensorflow.python.keras.layers import GlobalMaxPooling2D -from tensorflow.python.keras.layers import Input -from tensorflow.python.keras.layers import MaxPooling2D -from tensorflow.python.keras.layers import ZeroPadding2D -from tensorflow.python.keras.models import Model -from tensorflow.python.keras.utils import layer_utils -from tensorflow.python.keras.utils.data_utils import get_file -from tensorflow.python.platform import tf_logging as logging +from keras_applications import resnet50 from tensorflow.python.util.tf_export import tf_export +ResNet50 = resnet50.ResNet50 +decode_predictions = resnet50.decode_predictions +preprocess_input = resnet50.preprocess_input -WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5' -WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5' - - -def identity_block(input_tensor, kernel_size, filters, stage, block): - """The identity block is the block that has no conv layer at shortcut. - - Arguments: - input_tensor: input tensor - kernel_size: default 3, the kernel size of middle conv layer at main path - filters: list of integers, the filters of 3 conv layer at main path - stage: integer, current stage label, used for generating layer names - block: 'a','b'..., current block label, used for generating layer names - - Returns: - Output tensor for the block. - """ - filters1, filters2, filters3 = filters - if K.image_data_format() == 'channels_last': - bn_axis = 3 - else: - bn_axis = 1 - conv_name_base = 'res' + str(stage) + block + '_branch' - bn_name_base = 'bn' + str(stage) + block + '_branch' - - x = Conv2D(filters1, (1, 1), name=conv_name_base + '2a')(input_tensor) - x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) - x = Activation('relu')(x) - - x = Conv2D( - filters2, kernel_size, padding='same', name=conv_name_base + '2b')( - x) - x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) - x = Activation('relu')(x) - - x = Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x) - x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) - - x = layers.add([x, input_tensor]) - x = Activation('relu')(x) - return x - - -def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, - 2)): - """A block that has a conv layer at shortcut. - - Arguments: - input_tensor: input tensor - kernel_size: default 3, the kernel size of middle conv layer at main path - filters: list of integers, the filters of 3 conv layer at main path - stage: integer, current stage label, used for generating layer names - block: 'a','b'..., current block label, used for generating layer names - strides: Strides for the first conv layer in the block. - - Returns: - Output tensor for the block. - - Note that from stage 3, - the first conv layer at main path is with strides=(2, 2) - And the shortcut should have strides=(2, 2) as well - """ - filters1, filters2, filters3 = filters - if K.image_data_format() == 'channels_last': - bn_axis = 3 - else: - bn_axis = 1 - conv_name_base = 'res' + str(stage) + block + '_branch' - bn_name_base = 'bn' + str(stage) + block + '_branch' - - x = Conv2D( - filters1, (1, 1), strides=strides, name=conv_name_base + '2a')( - input_tensor) - x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) - x = Activation('relu')(x) - - x = Conv2D( - filters2, kernel_size, padding='same', name=conv_name_base + '2b')( - x) - x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) - x = Activation('relu')(x) - - x = Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x) - x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) - - shortcut = Conv2D( - filters3, (1, 1), strides=strides, name=conv_name_base + '1')( - input_tensor) - shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut) - - x = layers.add([x, shortcut]) - x = Activation('relu')(x) - return x - - -@tf_export('keras.applications.ResNet50', - 'keras.applications.resnet50.ResNet50') -def ResNet50(include_top=True, - weights='imagenet', - input_tensor=None, - input_shape=None, - pooling=None, - classes=1000): - """Instantiates the ResNet50 architecture. - - Optionally loads weights pre-trained - on ImageNet. Note that when using TensorFlow, - for best performance you should set - `image_data_format='channels_last'` in your Keras config - at ~/.keras/keras.json. - - The model and the weights are compatible with both - TensorFlow and Theano. The data format - convention used by the model is the one - specified in your Keras config file. - - Arguments: - include_top: whether to include the fully-connected - layer at the top of the network. - weights: one of `None` (random initialization), - 'imagenet' (pre-training on ImageNet), - or the path to the weights file to be loaded. - input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) - to use as image input for the model. - input_shape: optional shape tuple, only to be specified - if `include_top` is False (otherwise the input shape - has to be `(224, 224, 3)` (with `channels_last` data format) - or `(3, 224, 224)` (with `channels_first` data format). - It should have exactly 3 inputs channels, - and width and height should be no smaller than 197. - E.g. `(200, 200, 3)` would be one valid value. - pooling: Optional pooling mode for feature extraction - when `include_top` is `False`. - - `None` means that the output of the model will be - the 4D tensor output of the - last convolutional layer. - - `avg` means that global average pooling - will be applied to the output of the - last convolutional layer, and thus - the output of the model will be a 2D tensor. - - `max` means that global max pooling will - be applied. - classes: optional number of classes to classify images - into, only to be specified if `include_top` is True, and - if no `weights` argument is specified. - - Returns: - A Keras model instance. - - Raises: - ValueError: in case of invalid argument for `weights`, - or invalid input shape. - """ - if not (weights in {'imagenet', None} or os.path.exists(weights)): - raise ValueError('The `weights` argument should be either ' - '`None` (random initialization), `imagenet` ' - '(pre-training on ImageNet), ' - 'or the path to the weights file to be loaded.') - - if weights == 'imagenet' and include_top and classes != 1000: - raise ValueError('If using `weights` as imagenet with `include_top`' - ' as true, `classes` should be 1000') - - # Determine proper input shape - input_shape = _obtain_input_shape( - input_shape, - default_size=224, - min_size=197, - data_format=K.image_data_format(), - require_flatten=include_top, - weights=weights) - - if input_tensor is None: - img_input = Input(shape=input_shape) - else: - if not K.is_keras_tensor(input_tensor): - img_input = Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - if K.image_data_format() == 'channels_last': - bn_axis = 3 - else: - bn_axis = 1 - - x = Conv2D( - 64, (7, 7), strides=(2, 2), padding='same', name='conv1')(img_input) - x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x) - x = Activation('relu')(x) - x = MaxPooling2D((3, 3), strides=(2, 2))(x) - - x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) - x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') - x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') - - x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') - x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') - x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') - x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') - - x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') - x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') - - x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') - x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') - x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') - - x = AveragePooling2D((7, 7), name='avg_pool')(x) - - if include_top: - x = Flatten()(x) - x = Dense(classes, activation='softmax', name='fc1000')(x) - else: - if pooling == 'avg': - x = GlobalAveragePooling2D()(x) - elif pooling == 'max': - x = GlobalMaxPooling2D()(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if input_tensor is not None: - inputs = get_source_inputs(input_tensor) - else: - inputs = img_input - # Create model. - model = Model(inputs, x, name='resnet50') - - # load weights - if weights == 'imagenet': - if include_top: - weights_path = get_file( - 'resnet50_weights_tf_dim_ordering_tf_kernels.h5', - WEIGHTS_PATH, - cache_subdir='models', - md5_hash='a7b3fe01876f51b976af0dea6bc144eb') - else: - weights_path = get_file( - 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5', - WEIGHTS_PATH_NO_TOP, - cache_subdir='models', - md5_hash='a268eb855778b3df3c7506639542a6af') - model.load_weights(weights_path) - elif weights is not None: - model.load_weights(weights) - - return model +tf_export('keras.applications.resnet50.ResNet50', + 'keras.applications.ResNet50')(ResNet50) diff --git a/tensorflow/python/keras/applications/resnet50_test.py b/tensorflow/python/keras/applications/resnet50_test.py deleted file mode 100644 index 22a3f055805f48bb27ad75db664b142d7916b654..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/applications/resnet50_test.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for ResNet50 application.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python import keras -from tensorflow.python.platform import test - - -class ResNet50Test(test.TestCase): - - def test_with_top(self): - model = keras.applications.ResNet50(weights=None) - self.assertEqual(model.output_shape, (None, 1000)) - - def test_no_top(self): - model = keras.applications.ResNet50(weights=None, include_top=False) - self.assertEqual(model.output_shape, (None, None, None, 2048)) - - def test_with_pooling(self): - model = keras.applications.ResNet50(weights=None, - include_top=False, - pooling='avg') - self.assertEqual(model.output_shape, (None, 2048)) - - def test_weight_loading(self): - with self.assertRaises(ValueError): - keras.applications.ResNet50(weights='unknown', - include_top=False) - - with self.assertRaises(ValueError): - keras.applications.ResNet50(weights='imagenet', - classes=2000) - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/keras/applications/vgg16.py b/tensorflow/python/keras/applications/vgg16.py index 659a6533e6772402663aee891ed90df792b12f09..c420d9b81e7d6e11d9237949dabae56cd848e0d0 100644 --- a/tensorflow/python/keras/applications/vgg16.py +++ b/tensorflow/python/keras/applications/vgg16.py @@ -13,218 +13,18 @@ # limitations under the License. # ============================================================================== # pylint: disable=invalid-name -# pylint: disable=unused-import """VGG16 model for Keras. - -# Reference - -- [Very Deep Convolutional Networks for Large-Scale Image -Recognition](https://arxiv.org/abs/1409.1556) - """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - -from tensorflow.python.keras import backend as K -from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras.applications.imagenet_utils import preprocess_input -from tensorflow.python.keras.engine.network import get_source_inputs -from tensorflow.python.keras.layers import Conv2D -from tensorflow.python.keras.layers import Dense -from tensorflow.python.keras.layers import Flatten -from tensorflow.python.keras.layers import GlobalAveragePooling2D -from tensorflow.python.keras.layers import GlobalMaxPooling2D -from tensorflow.python.keras.layers import Input -from tensorflow.python.keras.layers import MaxPooling2D -from tensorflow.python.keras.models import Model -from tensorflow.python.keras.utils import layer_utils -from tensorflow.python.keras.utils.data_utils import get_file -from tensorflow.python.platform import tf_logging as logging +from keras_applications import vgg16 from tensorflow.python.util.tf_export import tf_export +VGG16 = vgg16.VGG16 +decode_predictions = vgg16.decode_predictions +preprocess_input = vgg16.preprocess_input -WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5' -WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5' - - -@tf_export('keras.applications.VGG16', 'keras.applications.vgg16.VGG16') -def VGG16(include_top=True, - weights='imagenet', - input_tensor=None, - input_shape=None, - pooling=None, - classes=1000): - """Instantiates the VGG16 architecture. - - Optionally loads weights pre-trained - on ImageNet. Note that when using TensorFlow, - for best performance you should set - `image_data_format='channels_last'` in your Keras config - at ~/.keras/keras.json. - - The model and the weights are compatible with both - TensorFlow and Theano. The data format - convention used by the model is the one - specified in your Keras config file. - - Arguments: - include_top: whether to include the 3 fully-connected - layers at the top of the network. - weights: one of `None` (random initialization), - 'imagenet' (pre-training on ImageNet), - or the path to the weights file to be loaded. - input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) - to use as image input for the model. - input_shape: optional shape tuple, only to be specified - if `include_top` is False (otherwise the input shape - has to be `(224, 224, 3)` (with `channels_last` data format) - or `(3, 224, 224)` (with `channels_first` data format). - It should have exactly 3 input channels, - and width and height should be no smaller than 48. - E.g. `(200, 200, 3)` would be one valid value. - pooling: Optional pooling mode for feature extraction - when `include_top` is `False`. - - `None` means that the output of the model will be - the 4D tensor output of the - last convolutional layer. - - `avg` means that global average pooling - will be applied to the output of the - last convolutional layer, and thus - the output of the model will be a 2D tensor. - - `max` means that global max pooling will - be applied. - classes: optional number of classes to classify images - into, only to be specified if `include_top` is True, and - if no `weights` argument is specified. - - Returns: - A Keras model instance. - - Raises: - ValueError: in case of invalid argument for `weights`, - or invalid input shape. - """ - if not (weights in {'imagenet', None} or os.path.exists(weights)): - raise ValueError('The `weights` argument should be either ' - '`None` (random initialization), `imagenet` ' - '(pre-training on ImageNet), ' - 'or the path to the weights file to be loaded.') - - if weights == 'imagenet' and include_top and classes != 1000: - raise ValueError('If using `weights` as imagenet with `include_top`' - ' as true, `classes` should be 1000') - # Determine proper input shape - input_shape = _obtain_input_shape( - input_shape, - default_size=224, - min_size=48, - data_format=K.image_data_format(), - require_flatten=include_top, - weights=weights) - - if input_tensor is None: - img_input = Input(shape=input_shape) - else: - if not K.is_keras_tensor(input_tensor): - img_input = Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - # Block 1 - x = Conv2D( - 64, (3, 3), activation='relu', padding='same', name='block1_conv1')( - img_input) - x = Conv2D( - 64, (3, 3), activation='relu', padding='same', name='block1_conv2')( - x) - x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x) - - # Block 2 - x = Conv2D( - 128, (3, 3), activation='relu', padding='same', name='block2_conv1')( - x) - x = Conv2D( - 128, (3, 3), activation='relu', padding='same', name='block2_conv2')( - x) - x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x) - - # Block 3 - x = Conv2D( - 256, (3, 3), activation='relu', padding='same', name='block3_conv1')( - x) - x = Conv2D( - 256, (3, 3), activation='relu', padding='same', name='block3_conv2')( - x) - x = Conv2D( - 256, (3, 3), activation='relu', padding='same', name='block3_conv3')( - x) - x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x) - - # Block 4 - x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block4_conv1')( - x) - x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block4_conv2')( - x) - x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block4_conv3')( - x) - x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x) - - # Block 5 - x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block5_conv1')( - x) - x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block5_conv2')( - x) - x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block5_conv3')( - x) - x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x) - - if include_top: - # Classification block - x = Flatten(name='flatten')(x) - x = Dense(4096, activation='relu', name='fc1')(x) - x = Dense(4096, activation='relu', name='fc2')(x) - x = Dense(classes, activation='softmax', name='predictions')(x) - else: - if pooling == 'avg': - x = GlobalAveragePooling2D()(x) - elif pooling == 'max': - x = GlobalMaxPooling2D()(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if input_tensor is not None: - inputs = get_source_inputs(input_tensor) - else: - inputs = img_input - # Create model. - model = Model(inputs, x, name='vgg16') - - # load weights - if weights == 'imagenet': - if include_top: - weights_path = get_file( - 'vgg16_weights_tf_dim_ordering_tf_kernels.h5', - WEIGHTS_PATH, - cache_subdir='models', - file_hash='64373286793e3c8b2b4e3219cbf3544b') - else: - weights_path = get_file( - 'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5', - WEIGHTS_PATH_NO_TOP, - cache_subdir='models', - file_hash='6d6bbae143d832006294945121d1f1fc') - model.load_weights(weights_path) - - elif weights is not None: - model.load_weights(weights) - - return model +tf_export('keras.applications.vgg16.VGG16', + 'keras.applications.VGG16')(VGG16) diff --git a/tensorflow/python/keras/applications/vgg16_test.py b/tensorflow/python/keras/applications/vgg16_test.py deleted file mode 100644 index cad65765f3d18c5a458c802a6b1aed688468d444..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/applications/vgg16_test.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for VGG16 application.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python import keras -from tensorflow.python.platform import test - - -class VGG16Test(test.TestCase): - - def test_with_top(self): - model = keras.applications.VGG16(weights=None) - self.assertEqual(model.output_shape, (None, 1000)) - - def test_no_top(self): - model = keras.applications.VGG16(weights=None, include_top=False) - self.assertEqual(model.output_shape, (None, None, None, 512)) - - def test_with_pooling(self): - model = keras.applications.VGG16(weights=None, - include_top=False, - pooling='avg') - self.assertEqual(model.output_shape, (None, 512)) - - def test_weight_loading(self): - with self.assertRaises(ValueError): - keras.applications.VGG16(weights='unknown', - include_top=False) - with self.assertRaises(ValueError): - keras.applications.VGG16(weights='imagenet', - classes=2000) - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/keras/applications/vgg19.py b/tensorflow/python/keras/applications/vgg19.py index 5e27ab8fb1fb99c65566cc4519798e3b8e0e1b0b..73d3d1d1c369b2b6665477eb51f37b967b303ac5 100644 --- a/tensorflow/python/keras/applications/vgg19.py +++ b/tensorflow/python/keras/applications/vgg19.py @@ -13,227 +13,18 @@ # limitations under the License. # ============================================================================== # pylint: disable=invalid-name -# pylint: disable=unused-import """VGG19 model for Keras. - -# Reference - -- [Very Deep Convolutional Networks for Large-Scale Image -Recognition](https://arxiv.org/abs/1409.1556) - """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - -from tensorflow.python.keras import backend as K -from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras.applications.imagenet_utils import preprocess_input -from tensorflow.python.keras.engine.network import get_source_inputs -from tensorflow.python.keras.layers import Conv2D -from tensorflow.python.keras.layers import Dense -from tensorflow.python.keras.layers import Flatten -from tensorflow.python.keras.layers import GlobalAveragePooling2D -from tensorflow.python.keras.layers import GlobalMaxPooling2D -from tensorflow.python.keras.layers import Input -from tensorflow.python.keras.layers import MaxPooling2D -from tensorflow.python.keras.models import Model -from tensorflow.python.keras.utils import layer_utils -from tensorflow.python.keras.utils.data_utils import get_file -from tensorflow.python.platform import tf_logging as logging +from keras_applications import vgg19 from tensorflow.python.util.tf_export import tf_export +VGG19 = vgg19.VGG19 +decode_predictions = vgg19.decode_predictions +preprocess_input = vgg19.preprocess_input -WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels.h5' -WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5' - - -@tf_export('keras.applications.VGG19', 'keras.applications.vgg19.VGG19') -def VGG19(include_top=True, - weights='imagenet', - input_tensor=None, - input_shape=None, - pooling=None, - classes=1000): - """Instantiates the VGG19 architecture. - - Optionally loads weights pre-trained - on ImageNet. Note that when using TensorFlow, - for best performance you should set - `image_data_format='channels_last'` in your Keras config - at ~/.keras/keras.json. - - The model and the weights are compatible with both - TensorFlow and Theano. The data format - convention used by the model is the one - specified in your Keras config file. - - Arguments: - include_top: whether to include the 3 fully-connected - layers at the top of the network. - weights: one of `None` (random initialization), - 'imagenet' (pre-training on ImageNet), - or the path to the weights file to be loaded. - input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) - to use as image input for the model. - input_shape: optional shape tuple, only to be specified - if `include_top` is False (otherwise the input shape - has to be `(224, 224, 3)` (with `channels_last` data format) - or `(3, 224, 224)` (with `channels_first` data format). - It should have exactly 3 inputs channels, - and width and height should be no smaller than 48. - E.g. `(200, 200, 3)` would be one valid value. - pooling: Optional pooling mode for feature extraction - when `include_top` is `False`. - - `None` means that the output of the model will be - the 4D tensor output of the - last convolutional layer. - - `avg` means that global average pooling - will be applied to the output of the - last convolutional layer, and thus - the output of the model will be a 2D tensor. - - `max` means that global max pooling will - be applied. - classes: optional number of classes to classify images - into, only to be specified if `include_top` is True, and - if no `weights` argument is specified. - - Returns: - A Keras model instance. - - Raises: - ValueError: in case of invalid argument for `weights`, - or invalid input shape. - """ - if not (weights in {'imagenet', None} or os.path.exists(weights)): - raise ValueError('The `weights` argument should be either ' - '`None` (random initialization), `imagenet` ' - '(pre-training on ImageNet), ' - 'or the path to the weights file to be loaded.') - - if weights == 'imagenet' and include_top and classes != 1000: - raise ValueError('If using `weights` as imagenet with `include_top`' - ' as true, `classes` should be 1000') - # Determine proper input shape - input_shape = _obtain_input_shape( - input_shape, - default_size=224, - min_size=48, - data_format=K.image_data_format(), - require_flatten=include_top, - weights=weights) - - if input_tensor is None: - img_input = Input(shape=input_shape) - else: - if not K.is_keras_tensor(input_tensor): - img_input = Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - # Block 1 - x = Conv2D( - 64, (3, 3), activation='relu', padding='same', name='block1_conv1')( - img_input) - x = Conv2D( - 64, (3, 3), activation='relu', padding='same', name='block1_conv2')( - x) - x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x) - - # Block 2 - x = Conv2D( - 128, (3, 3), activation='relu', padding='same', name='block2_conv1')( - x) - x = Conv2D( - 128, (3, 3), activation='relu', padding='same', name='block2_conv2')( - x) - x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x) - - # Block 3 - x = Conv2D( - 256, (3, 3), activation='relu', padding='same', name='block3_conv1')( - x) - x = Conv2D( - 256, (3, 3), activation='relu', padding='same', name='block3_conv2')( - x) - x = Conv2D( - 256, (3, 3), activation='relu', padding='same', name='block3_conv3')( - x) - x = Conv2D( - 256, (3, 3), activation='relu', padding='same', name='block3_conv4')( - x) - x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x) - - # Block 4 - x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block4_conv1')( - x) - x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block4_conv2')( - x) - x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block4_conv3')( - x) - x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block4_conv4')( - x) - x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x) - - # Block 5 - x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block5_conv1')( - x) - x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block5_conv2')( - x) - x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block5_conv3')( - x) - x = Conv2D( - 512, (3, 3), activation='relu', padding='same', name='block5_conv4')( - x) - x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x) - - if include_top: - # Classification block - x = Flatten(name='flatten')(x) - x = Dense(4096, activation='relu', name='fc1')(x) - x = Dense(4096, activation='relu', name='fc2')(x) - x = Dense(classes, activation='softmax', name='predictions')(x) - else: - if pooling == 'avg': - x = GlobalAveragePooling2D()(x) - elif pooling == 'max': - x = GlobalMaxPooling2D()(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if input_tensor is not None: - inputs = get_source_inputs(input_tensor) - else: - inputs = img_input - # Create model. - model = Model(inputs, x, name='vgg19') - - # load weights - if weights == 'imagenet': - if include_top: - weights_path = get_file( - 'vgg19_weights_tf_dim_ordering_tf_kernels.h5', - WEIGHTS_PATH, - cache_subdir='models', - file_hash='cbe5617147190e668d6c5d5026f83318') - else: - weights_path = get_file( - 'vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5', - WEIGHTS_PATH_NO_TOP, - cache_subdir='models', - file_hash='253f8cb515780f3b799900260a226db6') - model.load_weights(weights_path) - - elif weights is not None: - model.load_weights(weights) - - return model +tf_export('keras.applications.vgg19.VGG19', + 'keras.applications.VGG19')(VGG19) diff --git a/tensorflow/python/keras/applications/vgg19_test.py b/tensorflow/python/keras/applications/vgg19_test.py deleted file mode 100644 index 61dccc0c5cc315cc0e5c0284cf829ac2034c69d2..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/applications/vgg19_test.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for VGG19 application.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python import keras -from tensorflow.python.platform import test - - -class VGG19Test(test.TestCase): - - def test_with_top(self): - model = keras.applications.VGG19(weights=None) - self.assertEqual(model.output_shape, (None, 1000)) - - def test_no_top(self): - model = keras.applications.VGG19(weights=None, include_top=False) - self.assertEqual(model.output_shape, (None, None, None, 512)) - - def test_with_pooling(self): - model = keras.applications.VGG19(weights=None, - include_top=False, - pooling='avg') - self.assertEqual(model.output_shape, (None, 512)) - - def test_weight_loading(self): - with self.assertRaises(ValueError): - keras.applications.VGG19(weights='unknown', - include_top=False) - with self.assertRaises(ValueError): - keras.applications.VGG19(weights='imagenet', - classes=2000) - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/keras/applications/xception.py b/tensorflow/python/keras/applications/xception.py index e1be8a3c46e6eafa43405f1472a2f0292b73aa0c..5b221ac8e05bb7317d06f4a2ce0569d8219f99a4 100644 --- a/tensorflow/python/keras/applications/xception.py +++ b/tensorflow/python/keras/applications/xception.py @@ -13,332 +13,19 @@ # limitations under the License. # ============================================================================== # pylint: disable=invalid-name -# pylint: disable=unused-import """Xception V1 model for Keras. - -On ImageNet, this model gets to a top-1 validation accuracy of 0.790 -and a top-5 validation accuracy of 0.945. - -Do note that the input image format for this model is different than for -the VGG16 and ResNet models (299x299 instead of 224x224), -and that the input preprocessing function -is also different (same as Inception V3). - -Also do note that this model is only available for the TensorFlow backend, -due to its reliance on `SeparableConvolution` layers. - -# Reference - -- [Xception: Deep Learning with Depthwise Separable -Convolutions](https://arxiv.org/abs/1610.02357) - """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - -from tensorflow.python.keras import backend as K -from tensorflow.python.keras import layers -from tensorflow.python.keras.applications import imagenet_utils -from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape -from tensorflow.python.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras.engine.network import get_source_inputs -from tensorflow.python.keras.layers import Activation -from tensorflow.python.keras.layers import BatchNormalization -from tensorflow.python.keras.layers import Conv2D -from tensorflow.python.keras.layers import Dense -from tensorflow.python.keras.layers import GlobalAveragePooling2D -from tensorflow.python.keras.layers import GlobalMaxPooling2D -from tensorflow.python.keras.layers import Input -from tensorflow.python.keras.layers import MaxPooling2D -from tensorflow.python.keras.layers import SeparableConv2D -from tensorflow.python.keras.models import Model -from tensorflow.python.keras.utils.data_utils import get_file -from tensorflow.python.platform import tf_logging as logging +from keras_applications import xception from tensorflow.python.util.tf_export import tf_export +Xception = xception.Xception +decode_predictions = xception.decode_predictions +preprocess_input = xception.preprocess_input -TF_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.4/xception_weights_tf_dim_ordering_tf_kernels.h5' -TF_WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.4/xception_weights_tf_dim_ordering_tf_kernels_notop.h5' - - -@tf_export('keras.applications.Xception', - 'keras.applications.xception.Xception') -def Xception(include_top=True, - weights='imagenet', - input_tensor=None, - input_shape=None, - pooling=None, - classes=1000): - """Instantiates the Xception architecture. - - Optionally loads weights pre-trained - on ImageNet. This model is available for TensorFlow only, - and can only be used with inputs following the TensorFlow - data format `(width, height, channels)`. - You should set `image_data_format='channels_last'` in your Keras config - located at ~/.keras/keras.json. - - Note that the default input image size for this model is 299x299. - - Arguments: - include_top: whether to include the fully-connected - layer at the top of the network. - weights: one of `None` (random initialization), - 'imagenet' (pre-training on ImageNet), - or the path to the weights file to be loaded. - input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) - to use as image input for the model. - input_shape: optional shape tuple, only to be specified - if `include_top` is False (otherwise the input shape - has to be `(299, 299, 3)`. - It should have exactly 3 inputs channels, - and width and height should be no smaller than 71. - E.g. `(150, 150, 3)` would be one valid value. - pooling: Optional pooling mode for feature extraction - when `include_top` is `False`. - - `None` means that the output of the model will be - the 4D tensor output of the - last convolutional layer. - - `avg` means that global average pooling - will be applied to the output of the - last convolutional layer, and thus - the output of the model will be a 2D tensor. - - `max` means that global max pooling will - be applied. - classes: optional number of classes to classify images - into, only to be specified if `include_top` is True, and - if no `weights` argument is specified. - - Returns: - A Keras model instance. - - Raises: - ValueError: in case of invalid argument for `weights`, - or invalid input shape. - RuntimeError: If attempting to run this model with a - backend that does not support separable convolutions. - """ - if not (weights in {'imagenet', None} or os.path.exists(weights)): - raise ValueError('The `weights` argument should be either ' - '`None` (random initialization), `imagenet` ' - '(pre-training on ImageNet), ' - 'or the path to the weights file to be loaded.') - - if weights == 'imagenet' and include_top and classes != 1000: - raise ValueError('If using `weights` as imagenet with `include_top`' - ' as true, `classes` should be 1000') - - if K.image_data_format() != 'channels_last': - logging.warning( - 'The Xception model is only available for the ' - 'input data format "channels_last" ' - '(width, height, channels). ' - 'However your settings specify the default ' - 'data format "channels_first" (channels, width, height). ' - 'You should set `image_data_format="channels_last"` in your Keras ' - 'config located at ~/.keras/keras.json. ' - 'The model being returned right now will expect inputs ' - 'to follow the "channels_last" data format.') - K.set_image_data_format('channels_last') - old_data_format = 'channels_first' - else: - old_data_format = None - - # Determine proper input shape - input_shape = _obtain_input_shape( - input_shape, - default_size=299, - min_size=71, - data_format=K.image_data_format(), - require_flatten=False, - weights=weights) - - if input_tensor is None: - img_input = Input(shape=input_shape) - else: - if not K.is_keras_tensor(input_tensor): - img_input = Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - - x = Conv2D( - 32, (3, 3), strides=(2, 2), use_bias=False, name='block1_conv1')( - img_input) - x = BatchNormalization(name='block1_conv1_bn')(x) - x = Activation('relu', name='block1_conv1_act')(x) - x = Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x) - x = BatchNormalization(name='block1_conv2_bn')(x) - x = Activation('relu', name='block1_conv2_act')(x) - - residual = Conv2D( - 128, (1, 1), strides=(2, 2), padding='same', use_bias=False)( - x) - residual = BatchNormalization()(residual) - - x = SeparableConv2D( - 128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')( - x) - x = BatchNormalization(name='block2_sepconv1_bn')(x) - x = Activation('relu', name='block2_sepconv2_act')(x) - x = SeparableConv2D( - 128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')( - x) - x = BatchNormalization(name='block2_sepconv2_bn')(x) - - x = MaxPooling2D( - (3, 3), strides=(2, 2), padding='same', name='block2_pool')( - x) - x = layers.add([x, residual]) - - residual = Conv2D( - 256, (1, 1), strides=(2, 2), padding='same', use_bias=False)( - x) - residual = BatchNormalization()(residual) - - x = Activation('relu', name='block3_sepconv1_act')(x) - x = SeparableConv2D( - 256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')( - x) - x = BatchNormalization(name='block3_sepconv1_bn')(x) - x = Activation('relu', name='block3_sepconv2_act')(x) - x = SeparableConv2D( - 256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')( - x) - x = BatchNormalization(name='block3_sepconv2_bn')(x) - - x = MaxPooling2D( - (3, 3), strides=(2, 2), padding='same', name='block3_pool')( - x) - x = layers.add([x, residual]) - - residual = Conv2D( - 728, (1, 1), strides=(2, 2), padding='same', use_bias=False)( - x) - residual = BatchNormalization()(residual) - - x = Activation('relu', name='block4_sepconv1_act')(x) - x = SeparableConv2D( - 728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')( - x) - x = BatchNormalization(name='block4_sepconv1_bn')(x) - x = Activation('relu', name='block4_sepconv2_act')(x) - x = SeparableConv2D( - 728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')( - x) - x = BatchNormalization(name='block4_sepconv2_bn')(x) - - x = MaxPooling2D( - (3, 3), strides=(2, 2), padding='same', name='block4_pool')( - x) - x = layers.add([x, residual]) - - for i in range(8): - residual = x - prefix = 'block' + str(i + 5) - - x = Activation('relu', name=prefix + '_sepconv1_act')(x) - x = SeparableConv2D( - 728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv1')( - x) - x = BatchNormalization(name=prefix + '_sepconv1_bn')(x) - x = Activation('relu', name=prefix + '_sepconv2_act')(x) - x = SeparableConv2D( - 728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv2')( - x) - x = BatchNormalization(name=prefix + '_sepconv2_bn')(x) - x = Activation('relu', name=prefix + '_sepconv3_act')(x) - x = SeparableConv2D( - 728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv3')( - x) - x = BatchNormalization(name=prefix + '_sepconv3_bn')(x) - - x = layers.add([x, residual]) - - residual = Conv2D( - 1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)( - x) - residual = BatchNormalization()(residual) - - x = Activation('relu', name='block13_sepconv1_act')(x) - x = SeparableConv2D( - 728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')( - x) - x = BatchNormalization(name='block13_sepconv1_bn')(x) - x = Activation('relu', name='block13_sepconv2_act')(x) - x = SeparableConv2D( - 1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')( - x) - x = BatchNormalization(name='block13_sepconv2_bn')(x) - - x = MaxPooling2D( - (3, 3), strides=(2, 2), padding='same', name='block13_pool')( - x) - x = layers.add([x, residual]) - - x = SeparableConv2D( - 1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')( - x) - x = BatchNormalization(name='block14_sepconv1_bn')(x) - x = Activation('relu', name='block14_sepconv1_act')(x) - - x = SeparableConv2D( - 2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')( - x) - x = BatchNormalization(name='block14_sepconv2_bn')(x) - x = Activation('relu', name='block14_sepconv2_act')(x) - - if include_top: - x = GlobalAveragePooling2D(name='avg_pool')(x) - x = Dense(classes, activation='softmax', name='predictions')(x) - else: - if pooling == 'avg': - x = GlobalAveragePooling2D()(x) - elif pooling == 'max': - x = GlobalMaxPooling2D()(x) - - # Ensure that the model takes into account - # any potential predecessors of `input_tensor`. - if input_tensor is not None: - inputs = get_source_inputs(input_tensor) - else: - inputs = img_input - # Create model. - model = Model(inputs, x, name='xception') - - # load weights - if weights == 'imagenet': - if include_top: - weights_path = get_file( - 'xception_weights_tf_dim_ordering_tf_kernels.h5', - TF_WEIGHTS_PATH, - cache_subdir='models', - file_hash='0a58e3b7378bc2990ea3b43d5981f1f6') - else: - weights_path = get_file( - 'xception_weights_tf_dim_ordering_tf_kernels_notop.h5', - TF_WEIGHTS_PATH_NO_TOP, - cache_subdir='models', - file_hash='b0042744bf5b25fce3cb969f33bebb97') - model.load_weights(weights_path) - elif weights is not None: - model.load_weights(weights) - - if old_data_format: - K.set_image_data_format(old_data_format) - return model - - -@tf_export('keras.applications.xception.preprocess_input') -def preprocess_input(x): - """Preprocesses a numpy array encoding a batch of images. - - Arguments: - x: a 4D numpy array consists of RGB values within [0, 255]. - - Returns: - Preprocessed array. - """ - return imagenet_utils.preprocess_input(x, mode='tf') +tf_export('keras.applications.xception.Xception', + 'keras.applications.Xception')(Xception) +tf_export('keras.applications.xception.preprocess_input')(preprocess_input) diff --git a/tensorflow/python/keras/applications/xception_test.py b/tensorflow/python/keras/applications/xception_test.py deleted file mode 100644 index 7e2efd0017836ae671d88b561385b6e61be9fa0b..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/applications/xception_test.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Xception application.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python import keras -from tensorflow.python.platform import test - - -class XceptionTest(test.TestCase): - - def test_with_top(self): - model = keras.applications.Xception(weights=None) - self.assertEqual(model.output_shape, (None, 1000)) - - def test_no_top(self): - model = keras.applications.Xception(weights=None, include_top=False) - self.assertEqual(model.output_shape, (None, None, None, 2048)) - - def test_with_pooling(self): - model = keras.applications.Xception(weights=None, - include_top=False, - pooling='avg') - self.assertEqual(model.output_shape, (None, 2048)) - - def test_weight_loading(self): - with self.assertRaises(ValueError): - keras.applications.Xception(weights='unknown', - include_top=False) - with self.assertRaises(ValueError): - keras.applications.Xception(weights='imagenet', - classes=2000) - - def test_preprocess_input(self): - x = np.random.uniform(0, 255, (2, 300, 200, 3)) - out1 = keras.applications.xception.preprocess_input(x) - self.assertAllClose(np.mean(out1), 0., atol=0.1) - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 84821918bfe6160d7ee1b4556e00f533a07f5ebd..418586b85f01cc30210f58d2ee04097936407e54 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -22,6 +22,7 @@ from __future__ import division from __future__ import print_function import collections +import itertools import json import os import weakref @@ -647,7 +648,7 @@ def variable(value, dtype=None, name=None, constraint=None): constraint=constraint) if isinstance(value, np.ndarray): v._keras_shape = value.shape - elif hasattr(value, 'get_shape'): + elif hasattr(value, 'shape'): v._keras_shape = int_shape(value) v._uses_learning_phase = False return v @@ -735,9 +736,10 @@ def is_keras_tensor(x): True ``` """ - if not isinstance(x, (ops.Tensor, - variables_module.Variable, - sparse_tensor.SparseTensor)): + if (not isinstance(x, (ops.Tensor, + variables_module.Variable, + sparse_tensor.SparseTensor)) and + x.__class__.__name__ != 'DeferredTensor'): raise ValueError('Unexpectedly found an instance of type `' + str(type(x)) + '`. Expected a symbolic tensor instance.') return hasattr(x, '_keras_history') @@ -852,7 +854,10 @@ def int_shape(x): ``` """ try: - return tuple(x.get_shape().as_list()) + shape = x.shape + if not isinstance(shape, tuple): + shape = tuple(shape.as_list()) + return shape except ValueError: return None @@ -879,7 +884,7 @@ def ndim(x): 2 ``` """ - dims = x.get_shape()._dims + dims = x.shape._dims if dims is not None: return len(dims) return None @@ -962,13 +967,14 @@ def zeros(shape, dtype=None, name=None): [ 0., 0., 0., 0.]], dtype=float32) ``` """ - if dtype is None: - dtype = floatx() - tf_dtype = dtypes_module.as_dtype(dtype) - v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name) - if py_all(v.get_shape().as_list()): - return variable(v, dtype=dtype, name=name) - return v + with ops.init_scope(): + if dtype is None: + dtype = floatx() + tf_dtype = dtypes_module.as_dtype(dtype) + v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name) + if py_all(v.shape.as_list()): + return variable(v, dtype=dtype, name=name) + return v @tf_export('keras.backend.ones') @@ -995,13 +1001,14 @@ def ones(shape, dtype=None, name=None): [ 1., 1., 1., 1.]], dtype=float32) ``` """ - if dtype is None: - dtype = floatx() - tf_dtype = dtypes_module.as_dtype(dtype) - v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name) - if py_all(v.get_shape().as_list()): - return variable(v, dtype=dtype, name=name) - return v + with ops.init_scope(): + if dtype is None: + dtype = floatx() + tf_dtype = dtypes_module.as_dtype(dtype) + v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name) + if py_all(v.shape.as_list()): + return variable(v, dtype=dtype, name=name) + return v @tf_export('keras.backend.eye') @@ -1193,7 +1200,7 @@ def count_params(x): [ 0., 0., 0.]], dtype=float32) ``` """ - return np.prod(x.get_shape().as_list()) + return np.prod(x.shape.as_list()) @tf_export('keras.backend.cast') @@ -2112,10 +2119,10 @@ def _fused_normalize_batch_in_training(x, if gamma is None: gamma = constant_op.constant( - 1.0, dtype=x.dtype, shape=[x.get_shape()[normalization_axis]]) + 1.0, dtype=x.dtype, shape=[x.shape[normalization_axis]]) if beta is None: beta = constant_op.constant( - 0.0, dtype=x.dtype, shape=[x.get_shape()[normalization_axis]]) + 0.0, dtype=x.dtype, shape=[x.shape[normalization_axis]]) return nn.fused_batch_norm( x, gamma, beta, epsilon=epsilon, data_format=tf_data_format) @@ -2320,7 +2327,7 @@ def repeat_elements(x, rep, axis): Returns: A tensor. """ - x_shape = x.get_shape().as_list() + x_shape = x.shape.as_list() # For static axis if x_shape[axis] is not None: # slices along the repeat axis @@ -2340,7 +2347,7 @@ def repeat_elements(x, rep, axis): auxiliary_axis = axis + 1 x_shape = array_ops.shape(x) x_rep = array_ops.expand_dims(x, axis=auxiliary_axis) - reps = np.ones(len(x.get_shape()) + 1) + reps = np.ones(len(x.shape) + 1) reps[auxiliary_axis] = rep x_rep = array_ops.tile(x_rep, reps) @@ -2352,7 +2359,7 @@ def repeat_elements(x, rep, axis): x_rep = array_ops.reshape(x_rep, x_shape) # Fix shape representation - x_shape = x.get_shape().as_list() + x_shape = x.shape.as_list() x_rep.set_shape(x_shape) x_rep._keras_shape = tuple(x_shape) return x_rep @@ -2794,10 +2801,15 @@ class Function(object): if not isinstance(self.fetches, list): self.fetches = [self.fetches] # The main use case of `fetches` being passed to a model is the ability - # to run custom updates (since the outputs of fetches are never returned). + # to run custom updates # This requires us to wrap fetches in `identity` ops. self.fetches = [array_ops.identity(x) for x in self.fetches] self.session_kwargs = session_kwargs + # This mapping keeps track of the function that should receive the + # output from a fetch in `fetches`: { fetch: function(fetch_output) } + # A Callback can use this to register a function with access to the + # output values for a fetch it added. + self.fetch_callbacks = dict() if session_kwargs: raise ValueError('Some keys in session_kwargs are not supported at this ' @@ -2807,6 +2819,7 @@ class Function(object): self._feed_arrays = None self._feed_symbols = None self._symbol_vals = None + self._fetches = None self._session = None def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session): @@ -2852,8 +2865,14 @@ class Function(object): self._feed_arrays = feed_arrays self._feed_symbols = feed_symbols self._symbol_vals = symbol_vals + self._fetches = list(self.fetches) self._session = session + def _call_fetch_callbacks(self, fetches_output): + for fetch, output in zip(self._fetches, fetches_output): + if fetch in self.fetch_callbacks: + self.fetch_callbacks[fetch](output) + def __call__(self, inputs): if not isinstance(inputs, (list, tuple)): raise TypeError('`inputs` should be a list or tuple.') @@ -2880,21 +2899,24 @@ class Function(object): feed_arrays.append(tensor) # We need to do array conversion and type casting at this level, since # `callable_fn` only supports exact matches. - array_vals.append(np.asarray(value, dtype=tensor.dtype.base_dtype.name)) + tensor_type = dtypes_module.as_dtype(tensor.dtype) + array_vals.append(np.asarray(value, + dtype=tensor_type.as_numpy_dtype)) + if self.feed_dict: for key in sorted(self.feed_dict.keys()): array_vals.append( np.asarray(self.feed_dict[key], dtype=key.dtype.base_dtype.name)) # Refresh callable if anything has changed. - if (self._callable_fn is None or - feed_arrays != self._feed_arrays or + if (self._callable_fn is None or feed_arrays != self._feed_arrays or symbol_vals != self._symbol_vals or - feed_symbols != self._feed_symbols or + feed_symbols != self._feed_symbols or self.fetches != self._fetches or session != self._session): self._make_callable(feed_arrays, feed_symbols, symbol_vals, session) fetched = self._callable_fn(*array_vals) + self._call_fetch_callbacks(fetched[-len(self._fetches):]) return fetched[:len(self.outputs)] @@ -2916,8 +2938,8 @@ def function(inputs, outputs, updates=None, **kwargs): """ if kwargs: for key in kwargs: - if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and - key not in tf_inspect.getargspec(Function.__init__)[0]): + if (key not in tf_inspect.getfullargspec(session_module.Session.run)[0] + and key not in tf_inspect.getfullargspec(Function.__init__)[0]): msg = ('Invalid argument "%s" passed to K.function with TensorFlow ' 'backend') % key raise ValueError(msg) @@ -3014,17 +3036,17 @@ def rnn(step_function, ValueError: if `mask` is provided (not `None`) but states is not provided (`len(states)` == 0). """ - ndim = len(inputs.get_shape()) + ndim = len(inputs.shape) if ndim < 3: raise ValueError('Input should be at least 3D.') - inputs_shape = inputs.get_shape() + inputs_shape = inputs.shape axes = [1, 0] + list(range(2, ndim)) inputs = array_ops.transpose(inputs, (axes)) if mask is not None: if mask.dtype != dtypes_module.bool: mask = math_ops.cast(mask, dtypes_module.bool) - if len(mask.get_shape()) == ndim - 1: + if len(mask.shape) == ndim - 1: mask = expand_dims(mask) mask = array_ops.transpose(mask, axes) @@ -3035,7 +3057,7 @@ def rnn(step_function, uses_learning_phase = False if unroll: - if not inputs.get_shape()[0]: + if not inputs.shape[0]: raise ValueError('Unrolling requires a fixed number of timesteps.') states = initial_states successive_states = [] @@ -3152,15 +3174,21 @@ def rnn(step_function, global uses_learning_phase # pylint: disable=global-variable-undefined uses_learning_phase = True for state, new_state in zip(states, new_states): - new_state.set_shape(state.get_shape()) + new_state.set_shape(state.shape) tiled_mask_t = array_ops.tile(mask_t, array_ops.stack( [1, array_ops.shape(output)[1]])) output = array_ops.where(tiled_mask_t, output, states[0]) - new_states = [ - array_ops.where(tiled_mask_t, new_states[i], states[i]) - for i in range(len(states)) - ] + + masked_states = [] + for i in range(len(states)): + states_dim = array_ops.shape(new_states[i])[1] + stacked_states_dim = array_ops.stack([1, states_dim]) + tiled_mask = array_ops.tile(mask_t, stacked_states_dim) + masked_state = array_ops.where(tiled_mask, new_states[i], states[i]) + masked_states.append(masked_state) + new_states = masked_states + output_ta_t = output_ta_t.write(time, output) return (time + 1, output_ta_t) + tuple(new_states) else: @@ -3183,7 +3211,7 @@ def rnn(step_function, global uses_learning_phase # pylint: disable=global-variable-undefined uses_learning_phase = True for state, new_state in zip(states, new_states): - new_state.set_shape(state.get_shape()) + new_state.set_shape(state.shape) output_ta_t = output_ta_t.write(time, output) return (time + 1, output_ta_t) + tuple(new_states) @@ -3201,11 +3229,11 @@ def rnn(step_function, outputs = output_ta.stack() last_output = output_ta.read(last_time - 1) - axes = [1, 0] + list(range(2, len(outputs.get_shape()))) + axes = [1, 0] + list(range(2, len(outputs.shape))) outputs = array_ops.transpose(outputs, axes) # Static shape inference: (samples, time, ...) - outputs_shape = outputs.get_shape().as_list() + outputs_shape = outputs.shape.as_list() outputs_shape[0] = inputs_shape[0] outputs_shape[1] = inputs_shape[1] outputs.set_shape(outputs_shape) @@ -3348,26 +3376,48 @@ def in_test_phase(x, alt, training=None): @tf_export('keras.backend.relu') -def relu(x, alpha=0., max_value=None): +def relu(x, alpha=0., max_value=None, threshold=0): """Rectified linear unit. With default values, it returns element-wise `max(x, 0)`. + Otherwise, it follows: + `f(x) = max_value` for `x >= max_value`, + `f(x) = x` for `threshold <= x < max_value`, + `f(x) = alpha * (x - threshold)` otherwise. + Arguments: x: A tensor or variable. alpha: A scalar, slope of negative section (default=`0.`). - max_value: Saturation threshold. + max_value: float. Saturation threshold. + threshold: float. Threshold value for thresholded activation. Returns: A tensor. """ + clip_max = max_value is not None + if alpha != 0.: - negative_part = nn.relu(-x) - x = nn.relu(x) - if max_value is not None: + if threshold != 0: + negative_part = nn.relu(-x + threshold) + else: + negative_part = nn.relu(-x) + + if threshold != 0: + # computes x for x > threshold else 0 + x = x * math_ops.cast(math_ops.greater(x, threshold), floatx()) + elif max_value == 6: + # if no threshold, then can use nn.relu6 native TF op for performance + x = nn.relu6(x) + clip_max = False + else: + x = nn.relu(x) + + if clip_max: max_value = _to_tensor(max_value, x.dtype.base_dtype) zero = _to_tensor(0., x.dtype.base_dtype) x = clip_ops.clip_by_value(x, zero, max_value) + if alpha != 0.: alpha = _to_tensor(alpha, x.dtype.base_dtype) x -= alpha * negative_part @@ -3434,7 +3484,7 @@ def softsign(x): @tf_export('keras.backend.categorical_crossentropy') -def categorical_crossentropy(target, output, from_logits=False): +def categorical_crossentropy(target, output, from_logits=False, axis=-1): """Categorical crossentropy between an output tensor and a target tensor. Arguments: @@ -3444,28 +3494,33 @@ def categorical_crossentropy(target, output, from_logits=False): case `output` is expected to be the logits). from_logits: Boolean, whether `output` is the result of a softmax, or is a tensor of logits. + axis: Int specifying the channels axis. `axis=-1` corresponds to data + format `channels_last', and `axis=1` corresponds to data format + `channels_first`. Returns: Output tensor. + + Raises: + ValueError: if `axis` is neither -1 nor one of the axes of `output`. """ + rank = len(output.shape) + axis = axis % rank # Note: nn.softmax_cross_entropy_with_logits_v2 # expects logits, Keras expects probabilities. if not from_logits: # scale preds so that the class probas of each sample sum to 1 - output = output / math_ops.reduce_sum( # pylint: disable=g-no-augmented-assignment - output, len(output.get_shape()) - 1, True) + output = output / math_ops.reduce_sum(output, axis, True) # manual computation of crossentropy epsilon_ = _to_tensor(epsilon(), output.dtype.base_dtype) output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_) - return -math_ops.reduce_sum( - target * math_ops.log(output), - axis=len(output.get_shape()) - 1) + return -math_ops.reduce_sum(target * math_ops.log(output), axis) else: return nn.softmax_cross_entropy_with_logits_v2(labels=target, logits=output) @tf_export('keras.backend.sparse_categorical_crossentropy') -def sparse_categorical_crossentropy(target, output, from_logits=False): +def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): """Categorical crossentropy with integer targets. Arguments: @@ -3475,10 +3530,22 @@ def sparse_categorical_crossentropy(target, output, from_logits=False): case `output` is expected to be the logits). from_logits: Boolean, whether `output` is the result of a softmax, or is a tensor of logits. + axis: Int specifying the channels axis. `axis=-1` corresponds to data + format `channels_last', and `axis=1` corresponds to data format + `channels_first`. Returns: Output tensor. + + Raises: + ValueError: if `axis` is neither -1 nor one of the axes of `output`. """ + rank = len(output.shape) + axis = axis % rank + if axis != rank - 1: + permutation = list(range(axis)) + list(range(axis + 1, rank)) + [axis] + output = array_ops.transpose(output, perm=permutation) + # Note: nn.sparse_softmax_cross_entropy_with_logits # expects logits, Keras expects probabilities. if not from_logits: @@ -3486,7 +3553,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False): output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_) output = math_ops.log(output) - output_shape = output.get_shape() + output_shape = output.shape targets = cast(flatten(target), 'int64') logits = array_ops.reshape(output, [-1, int(output_shape[-1])]) res = nn.sparse_softmax_cross_entropy_with_logits( @@ -3733,7 +3800,7 @@ def conv1d(x, if data_format not in {'channels_first', 'channels_last'}: raise ValueError('Unknown data_format: ' + str(data_format)) - kernel_shape = kernel.get_shape().as_list() + kernel_shape = kernel.shape.as_list() if padding == 'causal': # causal (dilated) convolution: left_pad = dilation_rate * (kernel_shape[0] - 1) @@ -4242,58 +4309,115 @@ def pool3d(x, return x -def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None): - """Apply 1D conv with un-shared weights. - - Arguments: - inputs: 3D tensor with shape: - (batch_size, steps, input_dim) - if data_format is "channels_last" or - (batch_size, input_dim, steps) - if data_format is "channels_first". - kernel: the unshared weight for convolution, - with shape (output_length, feature_dim, filters) - kernel_size: a tuple of a single integer, - specifying the length of the 1D convolution window - strides: a tuple of a single integer, - specifying the stride length of the convolution - data_format: the data format, channels_first or channels_last - - Returns: - the tensor after 1d conv with un-shared weights, with shape (batch_size, - output_length, filters) +def local_conv(inputs, + kernel, + kernel_size, + strides, + output_shape, + data_format=None): + """Apply N-D convolution with un-shared weights. + + Arguments: + inputs: (N+2)-D tensor with shape + (batch_size, channels_in, d_in1, ..., d_inN) + if data_format='channels_first', or + (batch_size, d_in1, ..., d_inN, channels_in) + if data_format='channels_last'. + kernel: the unshared weight for N-D convolution, + with shape (output_items, feature_dim, channels_out), where + feature_dim = np.prod(kernel_size) * channels_in, + output_items = np.prod(output_shape). + kernel_size: a tuple of N integers, specifying the + spatial dimensions of the N-D convolution window. + strides: a tuple of N integers, specifying the strides + of the convolution along the spatial dimensions. + output_shape: a tuple of (d_out1, ..., d_outN) specifying the spatial + dimensionality of the output. + data_format: string, "channels_first" or "channels_last". + + Returns: + An (N+2)-D tensor with shape: + (batch_size, channels_out) + output_shape + if data_format='channels_first', or: + (batch_size,) + output_shape + (channels_out,) + if data_format='channels_last'. Raises: - ValueError: if `data_format` is neither `channels_last` or - `channels_first`. + ValueError: if `data_format` is neither + `channels_last` nor `channels_first`. """ if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: raise ValueError('Unknown data_format: ' + str(data_format)) - stride = strides[0] kernel_shape = int_shape(kernel) - output_length = kernel_shape[0] feature_dim = kernel_shape[1] + channels_out = kernel_shape[-1] + ndims = len(output_shape) + spatial_dimensions = list(range(ndims)) xs = [] - for i in range(output_length): - slice_length = slice(i * stride, i * stride + kernel_size[0]) + output_axes_ticks = [range(axis_max) for axis_max in output_shape] + for position in itertools.product(*output_axes_ticks): + slices = [slice(None)] + if data_format == 'channels_first': - xs.append(reshape(inputs[:, :, slice_length], (1, -1, feature_dim))) - else: - xs.append(reshape(inputs[:, slice_length, :], (1, -1, feature_dim))) + slices.append(slice(None)) + + slices.extend([slice(position[d] * strides[d], + position[d] * strides[d] + kernel_size[d]) + for d in spatial_dimensions]) + + if data_format == 'channels_last': + slices.append(slice(None)) + + xs.append(reshape(inputs[slices], (1, -1, feature_dim))) x_aggregate = concatenate(xs, axis=0) - # Shape: `(output_length, batch_size, filters)`. output = batch_dot(x_aggregate, kernel) + output = reshape(output, output_shape + (-1, channels_out)) if data_format == 'channels_first': - output = permute_dimensions(output, (1, 2, 0)) + permutation = [ndims, ndims + 1] + spatial_dimensions else: - output = permute_dimensions(output, (1, 0, 2)) - return output + permutation = [ndims] + spatial_dimensions + [ndims + 1] + + return permute_dimensions(output, permutation) + + +def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None): + """Apply 1D conv with un-shared weights. + + Arguments: + inputs: 3D tensor with shape: + (batch_size, steps, input_dim) + if data_format is "channels_last" or + (batch_size, input_dim, steps) + if data_format is "channels_first". + kernel: the unshared weight for convolution, + with shape (output_length, feature_dim, filters). + kernel_size: a tuple of a single integer, + specifying the length of the 1D convolution window. + strides: a tuple of a single integer, + specifying the stride length of the convolution. + data_format: the data format, channels_first or channels_last. + + Returns: + A 3d tensor with shape: + (batch_size, output_length, filters) + if data_format='channels_first' + or 3D tensor with shape: + (batch_size, filters, output_length) + if data_format='channels_last'. + """ + output_shape = (kernel.shape[0],) + return local_conv(inputs, + kernel, + kernel_size, + strides, + output_shape, + data_format) def local_conv2d(inputs, @@ -4306,64 +4430,34 @@ def local_conv2d(inputs, Arguments: inputs: 4D tensor with shape: - (batch_size, filters, new_rows, new_cols) - if data_format='channels_first' - or 4D tensor with shape: - (batch_size, new_rows, new_cols, filters) - if data_format='channels_last'. + (batch_size, filters, new_rows, new_cols) + if data_format='channels_first' + or 4D tensor with shape: + (batch_size, new_rows, new_cols, filters) + if data_format='channels_last'. kernel: the unshared weight for convolution, - with shape (output_items, feature_dim, filters) + with shape (output_items, feature_dim, filters). kernel_size: a tuple of 2 integers, specifying the - width and height of the 2D convolution window. + width and height of the 2D convolution window. strides: a tuple of 2 integers, specifying the strides - of the convolution along the width and height. - output_shape: a tuple with (output_row, output_col) - data_format: the data format, channels_first or channels_last + of the convolution along the width and height. + output_shape: a tuple with (output_row, output_col). + data_format: the data format, channels_first or channels_last. Returns: - A 4d tensor with shape: + A 4D tensor with shape: (batch_size, filters, new_rows, new_cols) if data_format='channels_first' or 4D tensor with shape: (batch_size, new_rows, new_cols, filters) if data_format='channels_last'. - - Raises: - ValueError: if `data_format` is neither - `channels_last` or `channels_first`. """ - if data_format is None: - data_format = image_data_format() - if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format: ' + str(data_format)) - - stride_row, stride_col = strides - output_row, output_col = output_shape - kernel_shape = int_shape(kernel) - feature_dim = kernel_shape[1] - filters = kernel_shape[2] - - xs = [] - for i in range(output_row): - for j in range(output_col): - slice_row = slice(i * stride_row, i * stride_row + kernel_size[0]) - slice_col = slice(j * stride_col, j * stride_col + kernel_size[1]) - if data_format == 'channels_first': - xs.append( - reshape(inputs[:, :, slice_row, slice_col], (1, -1, feature_dim))) - else: - xs.append( - reshape(inputs[:, slice_row, slice_col, :], (1, -1, feature_dim))) - - x_aggregate = concatenate(xs, axis=0) - output = batch_dot(x_aggregate, kernel) - output = reshape(output, (output_row, output_col, -1, filters)) - - if data_format == 'channels_first': - output = permute_dimensions(output, (2, 3, 0, 1)) - else: - output = permute_dimensions(output, (2, 0, 1, 3)) - return output + return local_conv(inputs, + kernel, + kernel_size, + strides, + output_shape, + data_format) @tf_export('keras.backend.bias_add') diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py index 53e30e0e4aeda1847c4f1867cb37b87841ad1ee7..40e79100613daef6731cfc71d5aad0cb0b6cc275 100644 --- a/tensorflow/python/keras/backend_test.py +++ b/tensorflow/python/keras/backend_test.py @@ -17,10 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np import scipy.sparse from tensorflow.python import keras +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -274,6 +277,36 @@ class BackendUtilsTest(test.TestCase): self.assertEqual( keras.backend.get_session().run(fetches=[x, y]), [30., 40.]) + def test_function_fetch_callbacks(self): + + class CallbackStub(object): + + def __init__(self): + self.times_called = 0 + self.callback_result = 0 + + def _fetch_callback(self, result): + self.times_called += 1 + self.callback_result = result + + with self.test_session(): + callback = CallbackStub() + x_placeholder = keras.backend.placeholder(shape=()) + y_placeholder = keras.backend.placeholder(shape=()) + + callback_op = x_placeholder * y_placeholder + + f = keras.backend.function( + inputs=[x_placeholder, y_placeholder], + outputs=[x_placeholder + y_placeholder]) + f.fetches.append(callback_op) + f.fetch_callbacks[callback_op] = callback._fetch_callback + + _ = f([10., 20.]) + + self.assertEqual(callback.times_called, 1) + self.assertEqual(callback.callback_result, 200) + class BackendVariableTest(test.TestCase): @@ -458,6 +491,66 @@ class BackendLinearAlgebraTest(test.TestCase): input_shape_a=(4, 7), input_shape_b=(4, 7)) + def test_relu(self): + x = ops.convert_to_tensor([[-4, 0], [2, 7]], 'float32') + with self.test_session(): + # standard relu + relu_op = keras.backend.relu(x) + self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 7]]) + + # alpha + relu_op = keras.backend.relu(x, alpha=0.5) + self.assertAllClose(keras.backend.eval(relu_op), [[-2, 0], [2, 7]]) + + # max_value < some elements + relu_op = keras.backend.relu(x, max_value=5) + self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 5]]) + + # nn.relu6 used + relu_op = keras.backend.relu(x, max_value=6) + self.assertTrue('Relu6' in relu_op.name) # uses tf.nn.relu6 + self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 6]]) + + # max value > 6 + relu_op = keras.backend.relu(x, max_value=10) + self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 7]]) + + # max value is float + relu_op = keras.backend.relu(x, max_value=4.3) + self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 4.3]]) + + # max value == 0 + relu_op = keras.backend.relu(x, max_value=0) + self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [0, 0]]) + + # alpha and max_value + relu_op = keras.backend.relu(x, alpha=0.25, max_value=3) + self.assertAllClose(keras.backend.eval(relu_op), [[-1, 0], [2, 3]]) + + # threshold + relu_op = keras.backend.relu(x, threshold=3) + self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [0, 7]]) + + # threshold is float + relu_op = keras.backend.relu(x, threshold=1.5) + self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 7]]) + + # threshold is negative + relu_op = keras.backend.relu(x, threshold=-5) + self.assertAllClose(keras.backend.eval(relu_op), [[-4, 0], [2, 7]]) + + # threshold and max_value + relu_op = keras.backend.relu(x, threshold=3, max_value=5) + self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [0, 5]]) + + # threshold and alpha + relu_op = keras.backend.relu(x, alpha=0.25, threshold=4) + self.assertAllClose(keras.backend.eval(relu_op), [[-2, -1], [-0.5, 7]]) + + # threshold, alpha, and max_value + relu_op = keras.backend.relu(x, alpha=0.25, threshold=4, max_value=5) + self.assertAllClose(keras.backend.eval(relu_op), [[-2, -1], [-0.5, 5]]) + class BackendShapeOpsTest(test.TestCase): @@ -661,7 +754,7 @@ class BackendShapeOpsTest(test.TestCase): np_kwargs={'data_format': 'channels_first'}) -class BackendNNOpsTest(test.TestCase): +class BackendNNOpsTest(test.TestCase, parameterized.TestCase): def test_bias_add(self): with self.test_session(): @@ -810,52 +903,117 @@ class BackendNNOpsTest(test.TestCase): padding='same', data_format='channels_last') self.assertEqual(y.get_shape().as_list(), [10, 5, 5]) - def test_local_conv1d_channels_dim(self): - input_length = 5 - input_dim = 3 + def test_local_conv_channels_dim(self): + filters = 3 batch_size = 2 - inputs = np.random.normal(0, 1, (batch_size, input_dim, input_length)) - inputs_cf = keras.backend.variable(inputs) + for input_shape in [(3, 5), (2, 3, 5), (2, 5, 3, 4)]: + channels_in = input_shape[0] + input_spatial_shape = input_shape[1:] + dim = len(input_spatial_shape) + + inputs = np.random.normal(0, 1, (batch_size,) + input_shape) + inputs_cf = keras.backend.variable(inputs) + + for kernel_size in [1, 2]: + for stride in [1, 2]: + kernel_sizes = (kernel_size,) * dim + strides = (stride,) * dim - filters = 4 - for kernel_size in [(1,), (2,), (3,)]: - for strides in [(1,), (2,), (3,)]: - output_length = (input_length - kernel_size[0] - + strides[0]) // strides[0] + output_shape = tuple([(i - kernel_size + stride) // stride + for i in input_spatial_shape]) - kernel_shape = (output_length, kernel_size[0] * input_dim, filters) - kernel = np.random.normal(0, 1, (output_length, - input_dim, - kernel_size[0], - filters)) - kernel_cf = np.reshape(kernel, kernel_shape) - kernel_cf = keras.backend.variable(kernel_cf) + kernel_shape = (np.prod(output_shape), + np.prod(kernel_sizes) * channels_in, + filters) - conv_cf = keras.backend.local_conv1d(inputs_cf, + kernel = np.random.normal( + 0, + 1, + output_shape + (channels_in, np.prod(kernel_sizes), filters) + ) + + kernel_cf = np.reshape(kernel, kernel_shape) + kernel_cf = keras.backend.variable(kernel_cf) + + conv_cf = keras.backend.local_conv(inputs_cf, kernel_cf, - kernel_size, + kernel_sizes, strides, + output_shape, 'channels_first') - inputs_cl = np.transpose(inputs, (0, 2, 1)) - inputs_cl = keras.backend.variable(inputs_cl) + inputs_cl = np.transpose(inputs, [0, 2] + list(range(3, dim + 2)) + + [1]) + inputs_cl = keras.backend.variable(inputs_cl) - kernel_cl = np.reshape(np.transpose(kernel, (0, 2, 1, 3)), - kernel_shape) - kernel_cl = keras.backend.variable(kernel_cl) + kernel_cl = np.reshape( + np.transpose(kernel, list(range(dim)) + [dim + 1, dim, dim + 2]), + kernel_shape + ) + kernel_cl = keras.backend.variable(kernel_cl) - conv_cl = keras.backend.local_conv1d(inputs_cl, + conv_cl = keras.backend.local_conv(inputs_cl, kernel_cl, - kernel_size, + kernel_sizes, strides, + output_shape, 'channels_last') - with self.test_session(): - conv_cf = keras.backend.eval(conv_cf) - conv_cl = keras.backend.eval(conv_cl) + with self.test_session(): + conv_cf = keras.backend.eval(conv_cf) + conv_cl = keras.backend.eval(conv_cl) + + self.assertAllCloseAccordingToType( + conv_cf, + np.transpose(conv_cl, + [0, dim + 1] + list(range(1, dim + 1))), + atol=1e-5 + ) + + @parameterized.named_parameters( + ('local_conv1d', (5, 6), (3,), (1,), (3,)), + ('local_conv2d', (4, 5, 6), (3, 3), (1, 1), (2, 3))) + def test_local_conv_1d_and_2d(self, + input_shape, + kernel_sizes, + strides, + output_shape): + filters = 3 + batch_size = 2 + + inputs = np.random.normal(0, 1, (batch_size,) + input_shape) + inputs = keras.backend.variable(inputs) + + kernel = np.random.normal(0, 1, (np.prod(output_shape), + np.prod(kernel_sizes) * input_shape[-1], + filters)) + kernel = keras.backend.variable(kernel) + + local_conv = keras.backend.local_conv(inputs, + kernel, + kernel_sizes, + strides, + output_shape, + 'channels_last') + if len(output_shape) == 1: + local_conv_dim = keras.backend.local_conv1d(inputs, + kernel, + kernel_sizes, + strides, + 'channels_last') + else: + local_conv_dim = keras.backend.local_conv2d(inputs, + kernel, + kernel_sizes, + strides, + output_shape, + 'channels_last') - self.assertAllCloseAccordingToType(conv_cf, - np.transpose(conv_cl, (0, 2, 1))) + with self.test_session(): + local_conv = keras.backend.eval(local_conv) + local_conv_dim = keras.backend.eval(local_conv_dim) + + self.assertAllCloseAccordingToType(local_conv, local_conv_dim) def test_conv2d(self): val = np.random.random((10, 4, 10, 10)) @@ -1010,7 +1168,7 @@ class BackendNNOpsTest(test.TestCase): {'go_backwards': False, 'mask': mask, 'unroll': True}, ] with self.test_session(): - for (i, kwargs) in enumerate(kwargs_list): + for i, kwargs in enumerate(kwargs_list): last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs, initial_states, **kwargs) @@ -1057,6 +1215,115 @@ class BackendNNOpsTest(test.TestCase): for b_s, b_u_s in zip(state_list[2], state_list[3]): self.assertAllClose(b_s, b_u_s, atol=1e-04) + def test_rnn_additional_states(self): + # implement a simple RNN + num_samples = 4 + input_dim = 5 + output_dim = 3 + timesteps = 6 + + input_val = np.random.random( + (num_samples, timesteps, input_dim)).astype(np.float32) + init_state_val = np.random.random( + (num_samples, output_dim)).astype(np.float32) + w_i_val = np.random.random((input_dim, output_dim)).astype(np.float32) + w_o_val = np.random.random((output_dim, output_dim)).astype(np.float32) + np_mask = np.random.randint(2, size=(num_samples, timesteps)) + + def rnn_step_fn(): + w_i = keras.backend.variable(w_i_val) + w_o = keras.backend.variable(w_o_val) + + def step_function(x, states): + assert len(states) == 2 + prev_output = states[0] + output = keras.backend.dot(x, w_i) + keras.backend.dot(prev_output, w_o) + return output, [output, + keras.backend.concatenate([output, output], axis=-1)] + + return step_function + + # test default setup + last_output_list = [[], [], [], [], [], []] + outputs_list = [[], [], [], [], [], []] + state_list = [[], [], [], [], [], []] + additional_state_list = [[], [], [], [], [], []] + + rnn_fn = rnn_step_fn() + inputs = keras.backend.variable(input_val) + initial_states = [keras.backend.variable(init_state_val), + np.concatenate([init_state_val, init_state_val], axis=-1)] + mask = keras.backend.variable(np_mask) + + kwargs_list = [ + {'go_backwards': False, 'mask': None}, + {'go_backwards': False, 'mask': None, 'unroll': True}, + {'go_backwards': True, 'mask': None}, + {'go_backwards': True, 'mask': None, 'unroll': True}, + {'go_backwards': False, 'mask': mask}, + {'go_backwards': False, 'mask': mask, 'unroll': True}, + ] + with self.test_session(): + for i, kwargs in enumerate(kwargs_list): + last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs, + initial_states, + **kwargs) + # check static shape inference + self.assertEqual(last_output.get_shape().as_list(), + [num_samples, output_dim]) + self.assertEqual(outputs.get_shape().as_list(), + [num_samples, timesteps, output_dim]) + # for state in new_states: + # self.assertEquals(state.get_shape().as_list(), + # [num_samples, output_dim]) + self.assertEqual(new_states[0].get_shape().as_list(), + [num_samples, output_dim]) + self.assertEqual(new_states[1].get_shape().as_list(), + [num_samples, 2 * output_dim]) + + last_output_list[i].append(keras.backend.eval(last_output)) + outputs_list[i].append(keras.backend.eval(outputs)) + self.assertEqual(len(new_states), 2) + state_list[i].append(keras.backend.eval(new_states[0])) + additional_state_list[i].append(keras.backend.eval(new_states[1])) + + def assert_list_pairwise(z_list, atol=1e-05): + for (z1, z2) in zip(z_list[1:], z_list[:-1]): + self.assertAllClose(z1, z2, atol=atol) + + assert_list_pairwise(last_output_list[0], atol=1e-04) + assert_list_pairwise(outputs_list[0], atol=1e-04) + assert_list_pairwise(state_list[0], atol=1e-04) + assert_list_pairwise(additional_state_list[0], atol=1e-04) + assert_list_pairwise(last_output_list[2], atol=1e-04) + assert_list_pairwise(outputs_list[2], atol=1e-04) + assert_list_pairwise(state_list[2], atol=1e-04) + assert_list_pairwise(additional_state_list[2], atol=1e-04) + + for l, u_l in zip(last_output_list[0], last_output_list[1]): + self.assertAllClose(l, u_l, atol=1e-04) + + for o, u_o in zip(outputs_list[0], outputs_list[1]): + self.assertAllClose(o, u_o, atol=1e-04) + + for s, u_s in zip(state_list[0], state_list[1]): + self.assertAllClose(s, u_s, atol=1e-04) + + for s, u_s in zip(additional_state_list[0], additional_state_list[1]): + self.assertAllClose(s, u_s, atol=1e-04) + + for b_l, b_u_l in zip(last_output_list[2], last_output_list[3]): + self.assertAllClose(b_l, b_u_l, atol=1e-04) + + for b_o, b_u_o in zip(outputs_list[2], outputs_list[3]): + self.assertAllClose(b_o, b_u_o, atol=1e-04) + + for b_s, b_u_s in zip(state_list[2], state_list[3]): + self.assertAllClose(b_s, b_u_s, atol=1e-04) + + for s, u_s in zip(additional_state_list[2], additional_state_list[3]): + self.assertAllClose(s, u_s, atol=1e-04) + def test_normalize_batch_in_training(self): val = np.random.random((10, 3, 10, 10)) x = keras.backend.variable(val) @@ -1212,6 +1479,13 @@ class TestRandomOps(test.TestCase): self.assertAllClose(np.max(y), 2., atol=0.1) self.assertAllClose(np.min(y), -2., atol=0.1) + def test_string_input(self): + seq = keras.Sequential([ + keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string), + keras.layers.Lambda(lambda x: x[0]) + ]) + preds = seq.predict([['tensorflow eager']]) + self.assertEqual(preds.shape, (1,)) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 9f91368e5bd772b47ac951a600f458126c1e12a6..befe82f4eccbcde55e8a620e51285d27eba03a7e 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -22,19 +22,30 @@ from __future__ import print_function from collections import deque from collections import Iterable from collections import OrderedDict +import copy import csv import json +import math import os import time import numpy as np import six +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes from tensorflow.python.keras import backend as K +from tensorflow.python.keras.engine.training_utils import standardize_input_data +from tensorflow.python.keras.utils.data_utils import Sequence from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.ops import array_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import summary_ops_v2 +from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary as tf_summary +from tensorflow.python.training import saver from tensorflow.python.util.tf_export import tf_export @@ -44,6 +55,110 @@ except ImportError: requests = None +def configure_callbacks(callbacks, + model, + do_validation=False, + val_inputs=None, + val_targets=None, + val_sample_weights=None, + batch_size=None, + epochs=None, + steps_per_epoch=None, + samples=None, + validation_steps=None, + verbose=1, + count_mode='steps'): + """Configures callbacks for use in various training loops. + + Arguments: + callbacks: List of Callbacks. + model: Model being trained. + do_validation: Whether or not validation loop will be run. + val_inputs: Inputs to Model for validation loop. Can be any + data format Keras accepts. + val_targets: Targets for Model for validation loop. Can be any + data format Keras accepts. + val_sample_weights: Sample weights for Model for validation loop. + Can be any data format Keras accepts. + batch_size: Number of samples per batch. + epochs: Number of epoch to train. + steps_per_epoch: Number of batches to run per training epoch. + samples: Number of training samples. + validation_steps: Number of batches to run per validation epoch. + verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger. + count_mode: One of 'steps' or 'samples'. Per-batch or per-sample count. + + Returns: + Instance of CallbackList used to control all Callbacks. + """ + + # Add additional callbacks + model.history = History() + stateful_metric_names = None + if hasattr(model, 'stateful_metric_names'): + stateful_metric_names = model.stateful_metric_names + callbacks = [BaseLogger(stateful_metrics=stateful_metric_names) + ] + (callbacks or []) + [model.history] + if verbose: + callbacks.append( + ProgbarLogger(count_mode, stateful_metrics=stateful_metric_names)) + callback_list = CallbackList(callbacks) + + # Set callback model + callback_model = model._get_callback_model() # pylint: disable=protected-access + if do_validation and val_inputs and not context.executing_eagerly(): + # Need to create the test_function before start of the first epoch + # because TensorBoard callback on_epoch_begin adds summary to the + # list of fetches of the test_function + callback_model._make_test_function() # pylint: disable=protected-access + callback_list.set_model(callback_model) + + # Set callback parameters + callback_metrics = [] + # When we have deferred build scenario with iterator input, we will compile + # when we standardize first batch of data. + if model._is_compiled: # pylint: disable=protected-access + callback_metrics = copy.copy(model.metrics_names) + if do_validation: + callback_metrics += ['val_' + n for n in model.metrics_names] + if validation_steps is None and isinstance(val_inputs, Sequence): + validation_steps = len(val_inputs) + callback_params = { + 'batch_size': batch_size, + 'epochs': epochs, + 'steps': steps_per_epoch, + 'samples': samples, + 'verbose': verbose, + 'do_validation': do_validation, + 'metrics': callback_metrics, + 'validation_steps': validation_steps + } + callback_list.set_params(callback_params) + + # Pass validation data to callbacks + if not val_inputs: + val_data = [] + elif _is_generator_like(val_inputs): + val_data = val_inputs + else: + val_data = val_inputs + val_targets + if val_sample_weights: + val_data += val_sample_weights + if model.uses_learning_phase and not isinstance(K.learning_phase(), int): + val_data += [0.] + for cbk in callbacks: + cbk.validation_data = val_data + + callback_list.model.stop_training = False + return callback_list + + +def _is_generator_like(data): + """Checks if data is a generator, Sequence, or Iterator.""" + return (hasattr(data, 'next') or hasattr(data, '__next__') or isinstance( + data, (Sequence, iterator_ops.Iterator, iterator_ops.EagerIterator))) + + class CallbackList(object): """Container abstracting a list of callbacks. @@ -57,15 +172,19 @@ class CallbackList(object): callbacks = callbacks or [] self.callbacks = [c for c in callbacks] self.queue_length = queue_length + self.params = {} + self.model = None def append(self, callback): self.callbacks.append(callback) def set_params(self, params): + self.params = params for callback in self.callbacks: callback.set_params(params) def set_model(self, model): + self.model = model for callback in self.callbacks: callback.set_model(model) @@ -496,6 +615,9 @@ class EarlyStopping(Callback): monitored has stopped increasing; in `auto` mode, the direction is automatically inferred from the name of the monitored quantity. + baseline: baseline value for the monitored quantity. + Training will stop if the model doesn't show improvement over the + baseline. """ def __init__(self, @@ -503,13 +625,15 @@ class EarlyStopping(Callback): min_delta=0, patience=0, verbose=0, - mode='auto'): + mode='auto', + baseline=None): super(EarlyStopping, self).__init__() self.monitor = monitor self.patience = patience self.verbose = verbose - self.min_delta = min_delta + self.baseline = baseline + self.min_delta = abs(min_delta) self.wait = 0 self.stopped_epoch = 0 @@ -537,7 +661,10 @@ class EarlyStopping(Callback): # Allow instances to be re-used self.wait = 0 self.stopped_epoch = 0 - self.best = np.Inf if self.monitor_op == np.less else -np.Inf + if self.baseline is not None: + self.best = self.baseline + else: + self.best = np.Inf if self.monitor_op == np.less else -np.Inf def on_epoch_end(self, epoch, logs=None): current = logs.get(self.monitor) @@ -688,7 +815,9 @@ class TensorBoard(Callback): write_images: whether to write model weights to visualize as image in TensorBoard. embeddings_freq: frequency (in epochs) at which selected embedding - layers will be saved. + layers will be saved. If set to 0, embeddings won't be computed. + Data to be visualized in TensorBoard's Embedding tab must be passed + as `embeddings_data`. embeddings_layer_names: a list of names of layers to keep eye on. If None or empty list all the embedding layer will be watched. embeddings_metadata: a dictionary which maps layer name to a file name @@ -696,6 +825,19 @@ class TensorBoard(Callback): [details](https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional) about metadata files format. In case if the same metadata file is used for all embedding layers, string can be passed. + embeddings_data: data to be embedded at layers specified in + `embeddings_layer_names`. Numpy array (if the model has a single + input) or list of Numpy arrays (if the model has multiple inputs). + Learn [more about embeddings](https://www.tensorflow.org/programmers_guide/embedding) + + Raises: + ValueError: If histogram_freq is set and no validation data is provided. + + @compatibility(eager) + Using `Tensorboard` callback will work while eager execution is enabled, + however outputting histogram summaries of weights and gradients is not + supported, and thus `histogram_freq` will be ignored. + @end_compatibility """ # pylint: enable=line-too-long @@ -706,19 +848,43 @@ class TensorBoard(Callback): batch_size=32, write_graph=True, write_grads=False, - write_images=False): + write_images=False, + embeddings_freq=0, + embeddings_layer_names=None, + embeddings_metadata=None, + embeddings_data=None): super(TensorBoard, self).__init__() self.log_dir = log_dir self.histogram_freq = histogram_freq + if self.histogram_freq and context.executing_eagerly(): + logging.warning( + UserWarning('Weight and gradient histograms not supported for eager' + 'execution, setting `histogram_freq` to `0`.')) + self.histogram_freq = 0 self.merged = None self.write_graph = write_graph self.write_grads = write_grads self.write_images = write_images self.batch_size = batch_size + self._current_batch = 0 + self._total_batches_seen = 0 + self.embeddings_freq = embeddings_freq + self.embeddings_layer_names = embeddings_layer_names + self.embeddings_metadata = embeddings_metadata + self.embeddings_data = embeddings_data + + def _init_writer(self): + """Sets file writer.""" + if context.executing_eagerly(): + self.writer = summary_ops_v2.create_file_writer(self.log_dir) + elif self.write_graph: + self.writer = tf_summary.FileWriter(self.log_dir, K.get_session().graph) + else: + self.writer = tf_summary.FileWriter(self.log_dir) - def set_model(self, model): - self.model = model - self.sess = K.get_session() + def _make_histogram_ops(self, model): + """Defines histogram ops when histogram_freq > 0.""" + # only make histogram summary op if it hasn't already been made if self.histogram_freq and self.merged is None: for layer in self.model.layers: for weight in layer.weights: @@ -758,73 +924,223 @@ class TensorBoard(Callback): def is_indexed_slices(grad): return type(grad).__name__ == 'IndexedSlices' - grads = [grad.values if is_indexed_slices(grad) else grad - for grad in grads] + grads = [ + grad.values if is_indexed_slices(grad) else grad + for grad in grads + ] tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads) if hasattr(layer, 'output'): - tf_summary.histogram('{}_out'.format(layer.name), layer.output) - self.merged = tf_summary.merge_all() + if isinstance(layer.output, list): + for i, output in enumerate(layer.output): + tf_summary.histogram('{}_out_{}'.format(layer.name, i), output) + else: + tf_summary.histogram('{}_out'.format(layer.name), layer.output) - if self.write_graph: - self.writer = tf_summary.FileWriter(self.log_dir, self.sess.graph) - else: - self.writer = tf_summary.FileWriter(self.log_dir) + def set_model(self, model): + """Sets Keras model and creates summary ops.""" - def on_epoch_end(self, epoch, logs=None): + self.model = model + self._init_writer() + # histogram summaries only enabled in graph mode + if not context.executing_eagerly(): + self._make_histogram_ops(model) + self.merged = tf_summary.merge_all() + + # If both embedding_freq and embeddings_data are available, we will + # visualize embeddings. + if self.embeddings_freq and self.embeddings_data is not None: + self.embeddings_data = standardize_input_data(self.embeddings_data, + model.input_names) + + # If embedding_layer_names are not provided, get all of the embedding + # layers from the model. + embeddings_layer_names = self.embeddings_layer_names + if not embeddings_layer_names: + embeddings_layer_names = [ + layer.name + for layer in self.model.layers + if type(layer).__name__ == 'Embedding' + ] + + self.assign_embeddings = [] + embeddings_vars = {} + + self.batch_id = batch_id = array_ops.placeholder(dtypes.int32) + self.step = step = array_ops.placeholder(dtypes.int32) + + for layer in self.model.layers: + if layer.name in embeddings_layer_names: + embedding_input = self.model.get_layer(layer.name).output + embedding_size = np.prod(embedding_input.shape[1:]) + embedding_input = array_ops.reshape(embedding_input, + (step, int(embedding_size))) + shape = (self.embeddings_data[0].shape[0], int(embedding_size)) + embedding = variables.Variable( + array_ops.zeros(shape), name=layer.name + '_embedding') + embeddings_vars[layer.name] = embedding + batch = state_ops.assign(embedding[batch_id:batch_id + step], + embedding_input) + self.assign_embeddings.append(batch) + + self.saver = saver.Saver(list(embeddings_vars.values())) + + # Create embeddings_metadata dictionary + if isinstance(self.embeddings_metadata, str): + embeddings_metadata = { + layer_name: self.embeddings_metadata + for layer_name in embeddings_vars.keys() + } + else: + # If embedding_metadata is already a dictionary + embeddings_metadata = self.embeddings_metadata + + try: + from tensorboard.plugins import projector + except ImportError: + raise ImportError('Failed to import TensorBoard. Please make sure that ' + 'TensorBoard integration is complete."') + + # TODO(psv): Add integration tests to test embedding visualization + # with TensorBoard callback. We are unable to write a unit test for this + # because TensorBoard dependency assumes TensorFlow package is installed. + config = projector.ProjectorConfig() + for layer_name, tensor in embeddings_vars.items(): + embedding = config.embeddings.add() + embedding.tensor_name = tensor.name + + if (embeddings_metadata is not None and + layer_name in embeddings_metadata): + embedding.metadata_path = embeddings_metadata[layer_name] + + projector.visualize_embeddings(self.writer, config) + + def _fetch_callback(self, summary): + self.writer.add_summary( + summary, + self._epoch + self._current_val_batch / self._validation_batches) + self._current_val_batch += 1 + + def _write_custom_summaries(self, step, logs=None): + """Writes metrics out as custom scalar summaries. + + Arguments: + step: the global step to use for Tensorboard. + logs: dict. Keys are scalar summary names, values are + NumPy scalars. + + """ logs = logs or {} + if context.executing_eagerly(): + # use v2 summary ops + with self.writer.as_default(), summary_ops_v2.always_record_summaries(): + for name, value in logs.items(): + summary_ops_v2.scalar(name, value.item(), step=step) + else: + # use FileWriter from v1 summary + for name, value in logs.items(): + summary = tf_summary.Summary() + summary_value = summary.value.add() + summary_value.simple_value = value.item() + summary_value.tag = name + self.writer.add_summary(summary, step) + self.writer.flush() - if not self.validation_data and self.histogram_freq: - raise ValueError('If printing histograms, validation_data must be ' - 'provided, and cannot be a generator.') - if self.validation_data and self.histogram_freq: - if epoch % self.histogram_freq == 0: + def on_train_begin(self, logs=None): + """Checks if histogram summaries can be run.""" + # will never be set when in eager + if self.histogram_freq: + if self.params.get('validation_steps', None) is not None: + self._validation_batches = self.params['validation_steps'] + elif self.validation_data: + self._validation_batches = math.ceil( + self.validation_data[0].shape[0] / self.batch_size) + else: + raise ValueError('If printing histograms, validation data must be ' + 'provided.') + if self._validation_batches == 0: + raise ValueError( + 'If printing histograms, validation data must have length > 0.') - val_data = self.validation_data - tensors = ( - self.model.inputs + self.model.targets + self.model.sample_weights) + def on_batch_end(self, batch, logs=None): + """Writes scalar summaries for metrics on every training batch.""" + # Don't output batch_size and batch number as Tensorboard summaries + logs = logs or {} + batch_logs = {('batch_' + k): v + for k, v in logs.items() + if k not in ['batch', 'size']} + self._write_custom_summaries(self._total_batches_seen, batch_logs) + self._total_batches_seen += 1 - if self.model.uses_learning_phase: - tensors += [K.learning_phase()] + def on_epoch_begin(self, epoch, logs=None): + """Add histogram op to Model test_function callbacks, reset batch count.""" + + # check if histogram summary should be run for this epoch + if self.histogram_freq and epoch % self.histogram_freq == 0: + self._epoch = epoch + self._current_val_batch = 0 + # add the histogram summary op if it should run this epoch + if self.merged not in self.model.test_function.fetches: + self.model.test_function.fetches.append(self.merged) + self.model.test_function.fetch_callbacks[ + self.merged] = self._fetch_callback - assert len(val_data) == len(tensors) - val_size = val_data[0].shape[0] + def on_epoch_end(self, epoch, logs=None): + """Checks if summary ops should run next epoch, logs scalar summaries.""" + + # don't output batch_size and + # batch number as Tensorboard summaries + logs = {('epoch_' + k): v + for k, v in logs.items() + if k not in ['batch', 'size']} + self._write_custom_summaries(epoch, logs) + + # pop the histogram summary op after each epoch + if self.histogram_freq: + if self.merged in self.model.test_function.fetches: + self.model.test_function.fetches.remove(self.merged) + if self.merged in self.model.test_function.fetch_callbacks: + self.model.test_function.fetch_callbacks.pop(self.merged) + + if self.embeddings_data is None and self.embeddings_freq: + raise ValueError('To visualize embeddings, embeddings_data must ' + 'be provided.') + + if self.embeddings_freq and self.embeddings_data is not None: + if epoch % self.embeddings_freq == 0: + # We need a second forward-pass here because we're passing + # the `embeddings_data` explicitly. This design allows to pass + # arbitrary data as `embeddings_data` and results from the fact + # that we need to know the size of the `tf.Variable`s which + # hold the embeddings in `set_model`. At this point, however, + # the `validation_data` is not yet set. + + embeddings_data = self.embeddings_data + n_samples = embeddings_data[0].shape[0] i = 0 - while i < val_size: - step = min(self.batch_size, val_size - i) - batch_val = [] - batch_val.append(val_data[0][i:i + step] - if val_data[0] is not None else None) - batch_val.append(val_data[1][i:i + step] - if val_data[1] is not None else None) - batch_val.append(val_data[2][i:i + step] - if val_data[2] is not None else None) - if self.model.uses_learning_phase: - # do not slice the learning phase - batch_val = [x[i:i + step] if x is not None else None - for x in val_data[:-1]] - batch_val.append(val_data[-1]) + while i < n_samples: + step = min(self.batch_size, n_samples - i) + batch = slice(i, i + step) + + if isinstance(self.model.input, list): + feed_dict = { + model_input: embeddings_data[idx][batch] + for idx, model_input in enumerate(self.model.input) + } else: - batch_val = [x[i:i + step] if x is not None else None - for x in val_data] - feed_dict = {} - for key, val in zip(tensors, batch_val): - if val is not None: - feed_dict[key] = val - result = self.sess.run([self.merged], feed_dict=feed_dict) - summary_str = result[0] - self.writer.add_summary(summary_str, epoch) - i += self.batch_size + feed_dict = {self.model.input: embeddings_data[0][batch]} - for name, value in logs.items(): - if name in ['batch', 'size']: - continue - summary = tf_summary.Summary() - summary_value = summary.value.add() - summary_value.simple_value = value.item() - summary_value.tag = name - self.writer.add_summary(summary, epoch) - self.writer.flush() + feed_dict.update({self.batch_id: i, self.step: step}) + + if self.model.uses_learning_phase: + feed_dict[K.learning_phase()] = False + + self.sess.run(self.assign_embeddings, feed_dict=feed_dict) + self.saver.save(self.sess, + os.path.join(self.log_dir, 'keras_embedding.ckpt'), + epoch) + + i += self.batch_size def on_train_end(self, logs=None): self.writer.close() diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 5062a26580ddb10011fd04f9a6e75ee6d2adbc68..e84e023384d22bcf93d5da7d0c4037cf685f6bf3 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -22,16 +22,21 @@ import csv import os import re import shutil +import tempfile import threading import unittest import numpy as np +from tensorflow.core.framework import summary_pb2 from tensorflow.python import keras +from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import adam try: import h5py # pylint:disable=g-import-not-at-top @@ -62,7 +67,7 @@ class KerasCallbacksTest(test.TestCase): np.random.seed(1337) temp_dir = self.get_temp_dir() - self.addCleanup(shutil.rmtree, temp_dir) + self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) filepath = os.path.join(temp_dir, 'checkpoint.h5') (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( @@ -273,16 +278,43 @@ class KerasCallbacksTest(test.TestCase): 1, activation='sigmoid'),)) model.compile( optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy']) - stopper = keras.callbacks.EarlyStopping(monitor='acc', patience=patience) weights = model.get_weights() + stopper = keras.callbacks.EarlyStopping(monitor='acc', patience=patience) hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) assert len(hist.epoch) >= patience # This should allow training to go for at least `patience` epochs model.set_weights(weights) hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) - assert len(hist.epoch) >= patience + assert len(hist.epoch) >= patience + + def test_EarlyStopping_with_baseline(self): + with self.test_session(): + np.random.seed(1337) + baseline = 0.5 + (data, labels), _ = testing_utils.get_test_data( + train_samples=100, + test_samples=50, + input_shape=(1,), + num_classes=NUM_CLASSES) + model = keras.models.Sequential((keras.layers.Dense( + 1, input_dim=1, activation='relu'), keras.layers.Dense( + 1, activation='sigmoid'),)) + model.compile( + optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy']) + + stopper = keras.callbacks.EarlyStopping(monitor='acc', + baseline=baseline) + hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) + assert len(hist.epoch) == 1 + + patience = 3 + stopper = keras.callbacks.EarlyStopping(monitor='acc', + patience=patience, + baseline=baseline) + hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) + assert len(hist.epoch) >= patience def test_RemoteMonitor(self): if requests is None: @@ -354,6 +386,7 @@ class KerasCallbacksTest(test.TestCase): y_train = keras.utils.to_categorical(y_train) def make_model(): + random_seed.set_random_seed(1234) np.random.seed(1337) model = keras.models.Sequential() model.add( @@ -451,7 +484,7 @@ class KerasCallbacksTest(test.TestCase): with self.test_session(): np.random.seed(1337) temp_dir = self.get_temp_dir() - self.addCleanup(shutil.rmtree, temp_dir) + self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) filepath = os.path.join(temp_dir, 'log.tsv') sep = '\t' @@ -529,7 +562,7 @@ class KerasCallbacksTest(test.TestCase): # does not result in invalid CSVs. np.random.seed(1337) tmpdir = self.get_temp_dir() - self.addCleanup(shutil.rmtree, tmpdir) + self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True) with self.test_session(): fp = os.path.join(tmpdir, 'test.csv') @@ -621,7 +654,7 @@ class KerasCallbacksTest(test.TestCase): np.random.seed(1337) temp_dir = self.get_temp_dir() - self.addCleanup(shutil.rmtree, temp_dir) + self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( train_samples=TRAIN_SAMPLES, @@ -695,6 +728,8 @@ class KerasCallbacksTest(test.TestCase): verbose=0) # fit generator without validation data + # histogram_freq must be zero + tsb.histogram_freq = 0 model.fit_generator( data_generator(True), len(x_train), @@ -703,6 +738,7 @@ class KerasCallbacksTest(test.TestCase): verbose=0) # fit generator with validation data and accuracy + tsb.histogram_freq = 1 model.fit_generator( data_generator(True), len(x_train), @@ -712,6 +748,7 @@ class KerasCallbacksTest(test.TestCase): verbose=0) # fit generator without validation data and accuracy + tsb.histogram_freq = 0 model.fit_generator( data_generator(True), len(x_train), epochs=2, callbacks=cbks) assert os.path.exists(temp_dir) @@ -719,7 +756,7 @@ class KerasCallbacksTest(test.TestCase): def test_TensorBoard_histogram_freq_must_have_validation_data(self): np.random.seed(1337) tmpdir = self.get_temp_dir() - self.addCleanup(shutil.rmtree, tmpdir) + self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True) with self.test_session(): filepath = os.path.join(tmpdir, 'logs') @@ -785,28 +822,13 @@ class KerasCallbacksTest(test.TestCase): for cb in cbs: cb.on_train_end() - # fit generator with validation data generator should raise ValueError if - # histogram_freq > 0 - cbs = callbacks_factory(histogram_freq=1) - with self.assertRaises(ValueError): - model.fit_generator( - data_generator(True), - len(x_train), - epochs=2, - validation_data=data_generator(False), - validation_steps=1, - callbacks=cbs) - - for cb in cbs: - cb.on_train_end() - # Make sure file writer cache is clear to avoid failures during cleanup. writer_cache.FileWriterCache.clear() def test_TensorBoard_multi_input_output(self): np.random.seed(1337) tmpdir = self.get_temp_dir() - self.addCleanup(shutil.rmtree, tmpdir) + self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True) with self.test_session(): filepath = os.path.join(tmpdir, 'logs') @@ -874,6 +896,133 @@ class KerasCallbacksTest(test.TestCase): callbacks=callbacks_factory(histogram_freq=1)) assert os.path.isdir(filepath) + def test_Tensorboard_histogram_summaries_in_test_function(self): + + class FileWriterStub(object): + + def __init__(self, logdir, graph=None): + self.logdir = logdir + self.graph = graph + self.steps_seen = [] + + def add_summary(self, summary, global_step): + summary_obj = summary_pb2.Summary() + + # ensure a valid Summary proto is being sent + if isinstance(summary, bytes): + summary_obj.ParseFromString(summary) + else: + assert isinstance(summary, summary_pb2.Summary) + summary_obj = summary + + # keep track of steps seen for the merged_summary op, + # which contains the histogram summaries + if len(summary_obj.value) > 1: + self.steps_seen.append(global_step) + + def flush(self): + pass + + def close(self): + pass + + def _init_writer(obj): + obj.writer = FileWriterStub(obj.log_dir) + + np.random.seed(1337) + tmpdir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True) + (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( + train_samples=TRAIN_SAMPLES, + test_samples=TEST_SAMPLES, + input_shape=(INPUT_DIM,), + num_classes=NUM_CLASSES) + y_test = keras.utils.to_categorical(y_test) + y_train = keras.utils.to_categorical(y_train) + + with self.test_session(): + model = keras.models.Sequential() + model.add( + keras.layers.Dense( + NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu')) + # non_trainable_weights: moving_variance, moving_mean + model.add(keras.layers.BatchNormalization()) + model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax')) + model.compile( + loss='categorical_crossentropy', + optimizer='sgd', + metrics=['accuracy']) + keras.callbacks.TensorBoard._init_writer = _init_writer + tsb = keras.callbacks.TensorBoard( + log_dir=tmpdir, + histogram_freq=1, + write_images=True, + write_grads=True, + batch_size=5) + cbks = [tsb] + + # fit with validation data + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=3, + verbose=0) + + self.assertAllEqual(tsb.writer.steps_seen, [0, 0.5, 1, 1.5, 2, 2.5]) + + def test_Tensorboard_histogram_summaries_with_generator(self): + np.random.seed(1337) + tmpdir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True) + + def generator(): + x = np.random.randn(10, 100).astype(np.float32) + y = np.random.randn(10, 10).astype(np.float32) + while True: + yield x, y + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(10, input_dim=100, activation='relu')) + model.add(keras.layers.Dense(10, activation='softmax')) + model.compile( + loss='categorical_crossentropy', + optimizer='sgd', + metrics=['accuracy']) + tsb = keras.callbacks.TensorBoard( + log_dir=tmpdir, + histogram_freq=1, + write_images=True, + write_grads=True, + batch_size=5) + cbks = [tsb] + + # fit with validation generator + model.fit_generator( + generator(), + steps_per_epoch=2, + epochs=2, + validation_data=generator(), + validation_steps=2, + callbacks=cbks, + verbose=0) + + with self.assertRaises(ValueError): + # fit with validation generator but no + # validation_steps + model.fit_generator( + generator(), + steps_per_epoch=2, + epochs=2, + validation_data=generator(), + callbacks=cbks, + verbose=0) + + self.assertTrue(os.path.exists(tmpdir)) + @unittest.skipIf( os.name == 'nt', 'use_multiprocessing=True does not work on windows properly.') @@ -924,7 +1073,7 @@ class KerasCallbacksTest(test.TestCase): def test_TensorBoard_with_ReduceLROnPlateau(self): with self.test_session(): temp_dir = self.get_temp_dir() - self.addCleanup(shutil.rmtree, temp_dir) + self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( train_samples=TRAIN_SAMPLES, @@ -959,6 +1108,112 @@ class KerasCallbacksTest(test.TestCase): assert os.path.exists(temp_dir) + def test_Tensorboard_batch_logging(self): + + class FileWriterStub(object): + + def __init__(self, logdir, graph=None): + self.logdir = logdir + self.graph = graph + self.batches_logged = [] + self.summary_values = [] + self.summary_tags = [] + + def add_summary(self, summary, step): + self.summary_values.append(summary.value[0].simple_value) + self.summary_tags.append(summary.value[0].tag) + self.batches_logged.append(step) + + def flush(self): + pass + + def close(self): + pass + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) + + tb_cbk = keras.callbacks.TensorBoard(temp_dir) + tb_cbk.writer = FileWriterStub(temp_dir) + + for batch in range(5): + tb_cbk.on_batch_end(batch, {'acc': np.float32(batch)}) + self.assertEqual(tb_cbk.writer.batches_logged, [0, 1, 2, 3, 4]) + self.assertEqual(tb_cbk.writer.summary_values, [0., 1., 2., 3., 4.]) + self.assertEqual(tb_cbk.writer.summary_tags, ['batch_acc'] * 5) + + def test_Tensorboard_epoch_and_batch_logging(self): + + class FileWriterStub(object): + + def __init__(self, logdir, graph=None): + self.logdir = logdir + self.graph = graph + + def add_summary(self, summary, step): + if 'batch_' in summary.value[0].tag: + self.batch_summary = (step, summary) + elif 'epoch_' in summary.value[0].tag: + self.epoch_summary = (step, summary) + + def flush(self): + pass + + def close(self): + pass + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) + + tb_cbk = keras.callbacks.TensorBoard(temp_dir) + tb_cbk.writer = FileWriterStub(temp_dir) + + tb_cbk.on_batch_end(0, {'acc': np.float32(5.0)}) + tb_cbk.on_epoch_end(0, {'acc': np.float32(10.0)}) + batch_step, batch_summary = tb_cbk.writer.batch_summary + self.assertEqual(batch_step, 0) + self.assertEqual(batch_summary.value[0].simple_value, 5.0) + epoch_step, epoch_summary = tb_cbk.writer.epoch_summary + self.assertEqual(epoch_step, 0) + self.assertEqual(epoch_summary.value[0].simple_value, 10.0) + + @test_util.run_in_graph_and_eager_modes + def test_Tensorboard_eager(self): + with self.test_session(): + temp_dir = tempfile.mkdtemp(dir=self.get_temp_dir()) + self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) + + (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( + train_samples=TRAIN_SAMPLES, + test_samples=TEST_SAMPLES, + input_shape=(INPUT_DIM,), + num_classes=NUM_CLASSES) + y_test = keras.utils.to_categorical(y_test) + y_train = keras.utils.to_categorical(y_train) + + model = keras.models.Sequential() + model.add( + keras.layers.Dense( + NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu')) + model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax')) + model.compile( + loss='binary_crossentropy', + optimizer=adam.AdamOptimizer(0.01), + metrics=['accuracy']) + + cbks = [keras.callbacks.TensorBoard(log_dir=temp_dir)] + + model.fit( + x_train, + y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=2, + verbose=0) + + self.assertTrue(os.path.exists(temp_dir)) + def test_RemoteMonitorWithJsonPayload(self): if requests is None: self.skipTest('`requests` required to run this test') diff --git a/tensorflow/python/keras/datasets/boston_housing.py b/tensorflow/python/keras/datasets/boston_housing.py index 4c4cab8c0865098ebed1a7fbe29246ef51bb9833..eeb7cbc44a72a5c624f8d1d1d9dbfab1fcd1b225 100644 --- a/tensorflow/python/keras/datasets/boston_housing.py +++ b/tensorflow/python/keras/datasets/boston_housing.py @@ -45,10 +45,9 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113): origin=origin_folder + 'boston_housing.npz', file_hash= 'f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5') - f = np.load(path) - x = f['x'] - y = f['y'] - f.close() + with np.load(path) as f: + x = f['x'] + y = f['y'] np.random.seed(seed) indices = np.arange(len(x)) diff --git a/tensorflow/python/keras/datasets/mnist.py b/tensorflow/python/keras/datasets/mnist.py index 03564accc74507713d198b8ba1ed8c08bd597e8d..a96b581960f3d5f60994fe92a1424e793d7e39c7 100644 --- a/tensorflow/python/keras/datasets/mnist.py +++ b/tensorflow/python/keras/datasets/mnist.py @@ -47,8 +47,8 @@ def load_data(path='mnist.npz'): path, origin=origin_folder + 'mnist.npz', file_hash='8a61469f7ea1b51cbae51d4f78837e45') - f = np.load(path) - x_train, y_train = f['x_train'], f['y_train'] - x_test, y_test = f['x_test'], f['y_test'] - f.close() - return (x_train, y_train), (x_test, y_test) + with np.load(path) as f: + x_train, y_train = f['x_train'], f['y_train'] + x_test, y_test = f['x_test'], f['y_test'] + + return (x_train, y_train), (x_test, y_test) diff --git a/tensorflow/python/keras/datasets/reuters.py b/tensorflow/python/keras/datasets/reuters.py index 2120b4b2421c652c9587a2e644bf008c3ece3980..cb796bb06cf09157cc510b55e3981d518fd8b433 100644 --- a/tensorflow/python/keras/datasets/reuters.py +++ b/tensorflow/python/keras/datasets/reuters.py @@ -130,7 +130,5 @@ def get_word_index(path='reuters_word_index.json'): path, origin=origin_folder + 'reuters_word_index.json', file_hash='4d44cc38712099c9e383dc6e5f11a921') - f = open(path) - data = json.load(f) - f.close() - return data + with open(path) as f: + return json.load(f) diff --git a/tensorflow/python/keras/engine/__init__.py b/tensorflow/python/keras/engine/__init__.py index ec7c0831992b2691c442bbd30445dbff8dba662f..26aed34766f9e1e2094db7a4c8b66ff057dacc4b 100644 --- a/tensorflow/python/keras/engine/__init__.py +++ b/tensorflow/python/keras/engine/__init__.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +# TODO(fchollet): Remove hourglass imports once external code is done importing +# non-public APIs. from tensorflow.python.keras.engine.base_layer import InputSpec from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.engine.input_layer import Input from tensorflow.python.keras.engine.input_layer import InputLayer -from tensorflow.python.keras.engine.network import get_source_inputs -from tensorflow.python.keras.engine.network import Network -from tensorflow.python.keras.engine.training import Model +from tensorflow.python.keras.utils.layer_utils import get_source_inputs del absolute_import del division diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 4814275fd5ba53e7845f383b3447a6ef9f47f6c2..d6d3db21fbbd2e498fbedca514d119c7110b46f8 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -26,6 +26,7 @@ import numpy as np from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.python.eager import context +from tensorflow.python.eager import function as eager_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -116,6 +117,7 @@ class Layer(checkpointable.CheckpointableBase): constraints on inputs that can be accepted by the layer. """ + @checkpointable.no_automatic_dependency_tracking def __init__(self, trainable=True, name=None, dtype=None, **kwargs): # These properties should be set by the user via keyword arguments. # note that 'dtype', 'input_shape' and 'batch_input_shape' @@ -173,6 +175,12 @@ class Layer(checkpointable.CheckpointableBase): self.supports_masking = False + call_argspec = tf_inspect.getfullargspec(self.call) + if 'training' in call_argspec.args: + self._expects_training_arg = True + else: + self._expects_training_arg = False + # Manage input shape information if passed. if 'input_shape' in kwargs or 'batch_input_shape' in kwargs: # In this case we will later create an input layer @@ -217,7 +225,7 @@ class Layer(checkpointable.CheckpointableBase): @activity_regularizer.setter def activity_regularizer(self, regularizer): """Optional regularizer function for the output of this layer.""" - self._activity_regularizer = regularizer + self._activity_regularizer = self._no_dependency(regularizer) @property def trainable_weights(self): @@ -459,14 +467,18 @@ class Layer(checkpointable.CheckpointableBase): """Alias for `add_weight`.""" return self.add_weight(*args, **kwargs) - def add_weight(self, name, shape, + def add_weight(self, + name, + shape, dtype=None, initializer=None, regularizer=None, - trainable=True, + trainable=None, constraint=None, partitioner=None, use_resource=None, + synchronization=vs.VariableSynchronization.AUTO, + aggregation=vs.VariableAggregation.NONE, getter=None): """Adds a new variable to the layer, or gets an existing one; returns it. @@ -481,10 +493,20 @@ class Layer(checkpointable.CheckpointableBase): or "non_trainable_variables" (e.g. BatchNorm mean, stddev). Note, if the current variable scope is marked as non-trainable then this parameter is ignored and any added variables are also - marked as non-trainable. + marked as non-trainable. `trainable` defaults to `True` unless + `synchronization` is set to `ON_READ`. constraint: constraint instance (callable). partitioner: Partitioner to be passed to the `Checkpointable` API. use_resource: Whether to use `ResourceVariable`. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + `tf.VariableSynchronization`. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. If `synchronization` is set to `ON_READ`, + `trainable` must not be set to `True`. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + `tf.VariableAggregation`. getter: Variable getter argument to be passed to the `Checkpointable` API. Returns: @@ -495,7 +517,8 @@ class Layer(checkpointable.CheckpointableBase): Raises: RuntimeError: If called with partioned variable regularization and eager execution is enabled. - ValueError: When giving unsupported dtype and no initializer. + ValueError: When giving unsupported dtype and no initializer or when + trainable has been set to True with synchronization set as `ON_READ`. """ if dtype is None: dtype = self.dtype or backend.floatx() @@ -504,6 +527,19 @@ class Layer(checkpointable.CheckpointableBase): regularizer = regularizers.get(regularizer) constraint = constraints.get(constraint) + if synchronization == vs.VariableSynchronization.ON_READ: + if trainable: + raise ValueError( + 'Synchronization value can be set to ' + 'VariableSynchronization.ON_READ only for non-trainable variables. ' + 'You have specified trainable=True and ' + 'synchronization=VariableSynchronization.ON_READ.') + else: + # Set trainable to be false when variable is to be synced on read. + trainable = False + elif trainable is None: + trainable = True + # Initialize variable when no initializer provided if initializer is None: # If dtype is DT_FLOAT, provide a uniform unit scaling initializer @@ -531,7 +567,9 @@ class Layer(checkpointable.CheckpointableBase): constraint=constraint, trainable=trainable and self.trainable, partitioner=partitioner, - use_resource=use_resource) + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation) if regularizer is not None: # TODO(fchollet): in the future, this should be handled at the @@ -654,11 +692,12 @@ class Layer(checkpointable.CheckpointableBase): # Handle Keras mask propagation from previous layer to current layer. previous_mask = None - if (not hasattr(self, '_compute_previous_mask') or - self._compute_previous_mask): + if build_graph and (not hasattr(self, '_compute_previous_mask') or + self._compute_previous_mask): previous_mask = collect_previous_mask(inputs) if not hasattr(self, '_call_fn_args'): - self._call_fn_args = function_utils.fn_args(self.call) + self._call_fn_args = self._no_dependency( + function_utils.fn_args(self.call)) if ('mask' in self._call_fn_args and 'mask' not in kwargs and not generic_utils.is_all_none(previous_mask)): # The previous layer generated a mask, and mask was not explicitly pass @@ -691,9 +730,20 @@ class Layer(checkpointable.CheckpointableBase): self._dtype = input_list[0].dtype.base_dtype.name except AttributeError: pass - if all(hasattr(x, 'get_shape') for x in input_list): - input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs) - self.build(input_shapes) + + if all(hasattr(x, 'shape') for x in input_list): + input_shapes = nest.map_structure(lambda x: x.shape, inputs) + + if (not hasattr(self, '_is_graph_network') or + self.__class__.__name__ == 'Sequential' or + not hasattr(self.build, '_is_default')): + # Only if self is a layer, an instance of a sequential model, or + # the user has manually overwritten the build method do we need to + # build it. + self.build(input_shapes) + # We must set self.built since user defined build functions are not + # constrained to set self.built. + self.built = True # Check input assumptions set after layer building, e.g. input shape. if build_graph or in_deferred_mode: @@ -709,7 +759,7 @@ class Layer(checkpointable.CheckpointableBase): # Deferred mode behavior: use `compute_output_shape` to # infer the number of outputs of the layer and their shapes. if input_shapes is None: - input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs) + input_shapes = nest.map_structure(lambda x: x.shape, inputs) output_shapes = self.compute_output_shape(input_shapes) output_shapes = nest.flatten(output_shapes) @@ -723,14 +773,11 @@ class Layer(checkpointable.CheckpointableBase): if build_graph: self._handle_activity_regularization(inputs, outputs) - # TODO(fchollet): consider enabling masking for Eager mode. self._set_mask_metadata(inputs, outputs, previous_mask) if in_deferred_mode or build_graph and have_all_keras_metadata(inputs): inputs, outputs = self._set_connectivity_metadata_( inputs, outputs, args, kwargs) - - self.built = True if context.executing_eagerly(): return outputs @@ -747,17 +794,8 @@ class Layer(checkpointable.CheckpointableBase): if hasattr(self, '_initial_weights') and self._initial_weights is not None: self.set_weights(self._initial_weights) del self._initial_weights - self._post_build_cleanup() return outputs - def _post_build_cleanup(self): - """Hooks to run after all sub-Layers are built.""" - # Note that in addition to Layer.__call__, this method is called by Model - # after building a graph network (which skips __call__). It should be called - # when possible if self.built may have switched from False to True, and is - # idempotent. - pass # No-op for Layers which don't override this method. - def apply(self, inputs, *args, **kwargs): """Apply the layer on a input. @@ -791,21 +829,27 @@ class Layer(checkpointable.CheckpointableBase): pass def _set_mask_metadata(self, inputs, outputs, previous_mask): - if hasattr(self, 'compute_mask'): + # In some cases the mask of the outputs has already been computed by + # inner layers and does not need to be recomputed by this layer. + mask_already_computed = all( + hasattr(x, '_keras_mask') for x in generic_utils.to_list(outputs)) + if hasattr(self, 'compute_mask') and not mask_already_computed: output_mask = self.compute_mask(inputs, previous_mask) - if isinstance(outputs, (list, tuple)): - if output_mask is None: - output_mask = [None for _ in range(len(outputs))] - for x, m in zip(outputs, output_mask): - try: - x._keras_mask = m # pylint: disable=protected-access - except AttributeError: - pass # C type such as dict. Masking not supported in this case. - else: + else: + output_mask = None + if isinstance(outputs, (list, tuple)): + if output_mask is None: + output_mask = [None for _ in range(len(outputs))] + for x, m in zip(outputs, output_mask): try: - outputs._keras_mask = output_mask # pylint: disable=protected-access + x._keras_mask = m # pylint: disable=protected-access except AttributeError: pass # C type such as dict. Masking not supported in this case. + else: + try: + outputs._keras_mask = output_mask # pylint: disable=protected-access + except AttributeError: + pass # C type such as dict. Masking not supported in this case. def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs): call_convention = getattr(self, '_call_convention', @@ -867,7 +911,7 @@ class Layer(checkpointable.CheckpointableBase): assert len(call_args) == 1 # TypeError raised earlier in __call__. return call_args[0], call_kwargs else: - call_arg_spec = tf_inspect.getargspec(self.call) + call_arg_spec = tf_inspect.getfullargspec(self.call) # There is no explicit "inputs" argument expected or provided to # call(). Arguments which have default values are considered non-inputs, # and arguments without are considered inputs. @@ -887,8 +931,8 @@ class Layer(checkpointable.CheckpointableBase): _, unwrapped_call = tf_decorator.unwrap(self.call) bound_args = inspect.getcallargs( unwrapped_call, *call_args, **call_kwargs) - if call_arg_spec.keywords is not None: - var_kwargs = bound_args.pop(call_arg_spec.keywords) + if call_arg_spec.varkw is not None: + var_kwargs = bound_args.pop(call_arg_spec.varkw) bound_args.update(var_kwargs) keyword_arg_names = keyword_arg_names.union(var_kwargs.keys()) all_args = call_arg_spec.args @@ -931,6 +975,39 @@ class Layer(checkpointable.CheckpointableBase): Returns: An input shape tuple. """ + if context.executing_eagerly(): + # In this case we build the model first in order to do shape inference. + # This is acceptable because the framework only calls + # `compute_output_shape` on shape values that the layer would later be + # built for. It would however cause issues in case a user attempts to + # use `compute_output_shape` manually (these users will have to + # implement `compute_output_shape` themselves). + self.build(input_shape) + + with context.graph_mode(): + graph = eager_function.CapturingGraph() + with graph.as_default(): + if isinstance(input_shape, list): + inputs = [generate_placeholders_from_shape(shape) + for shape in input_shape] + else: + inputs = generate_placeholders_from_shape(input_shape) + + try: + if self._expects_training_arg: + outputs = self(inputs, training=False) + else: + outputs = self(inputs) + except TypeError: + raise NotImplementedError('We could not automatically infer ' + 'the static shape of the layer\'s output.' + ' Please implement the ' + '`compute_output_shape` method on your ' + 'layer (%s).' % self.__class__.__name__) + if isinstance(outputs, list): + return [output.shape for output in outputs] + else: + return outputs.shape raise NotImplementedError def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument @@ -1293,7 +1370,7 @@ class Layer(checkpointable.CheckpointableBase): ', but the layer isn\'t built. ' 'You can build it manually via: `' + self.name + '.build(batch_input_shape)`.') - weight_shapes = [w.get_shape().as_list() for w in self.weights] + weight_shapes = [w.shape.as_list() for w in self.weights] return int(sum([np.prod(w) for w in weight_shapes])) @property @@ -1376,7 +1453,7 @@ class Layer(checkpointable.CheckpointableBase): if (spec.ndim is not None or spec.min_ndim is not None or spec.max_ndim is not None): - if x.get_shape().ndims is None: + if x.shape.ndims is None: raise ValueError('Input ' + str(input_index) + ' of layer ' + self.name + ' is incompatible with the layer: ' 'its rank is undefined, but the layer requires a ' @@ -1384,29 +1461,29 @@ class Layer(checkpointable.CheckpointableBase): # Check ndim. if spec.ndim is not None: - ndim = x.get_shape().ndims + ndim = x.shape.ndims if ndim != spec.ndim: raise ValueError('Input ' + str(input_index) + ' of layer ' + self.name + ' is incompatible with the layer: ' 'expected ndim=' + str(spec.ndim) + ', found ndim=' + str(ndim) + '. Full shape received: ' + - str(x.get_shape().as_list())) + str(x.shape.as_list())) if spec.max_ndim is not None: - ndim = x.get_shape().ndims + ndim = x.shape.ndims if ndim is not None and ndim > spec.max_ndim: raise ValueError('Input ' + str(input_index) + ' of layer ' + self.name + ' is incompatible with the layer: ' 'expected max_ndim=' + str(spec.max_ndim) + ', found ndim=' + str(ndim)) if spec.min_ndim is not None: - ndim = x.get_shape().ndims + ndim = x.shape.ndims if ndim is not None and ndim < spec.min_ndim: raise ValueError('Input ' + str(input_index) + ' of layer ' + self.name + ' is incompatible with the layer: ' ': expected min_ndim=' + str(spec.min_ndim) + ', found ndim=' + str(ndim) + '. Full shape received: ' + - str(x.get_shape().as_list())) + str(x.shape.as_list())) # Check dtype. if spec.dtype is not None: if x.dtype != spec.dtype: @@ -1416,7 +1493,7 @@ class Layer(checkpointable.CheckpointableBase): ', found dtype=' + str(x.dtype)) # Check specific shape axes. if spec.axes: - shape = x.get_shape().as_list() + shape = x.shape.as_list() if shape is not None: for axis, value in spec.axes.items(): if hasattr(value, 'value'): @@ -1429,7 +1506,7 @@ class Layer(checkpointable.CheckpointableBase): ' but received input with shape ' + str(shape)) # Check shape. if spec.shape is not None: - shape = x.get_shape().as_list() + shape = x.shape.as_list() if shape is not None: for spec_dim, dim in zip(spec.shape, shape): if spec_dim is not None and dim is not None: @@ -1704,12 +1781,12 @@ class DeferredTensor(object): def __str__(self): return "DeferredTensor('%s', shape=%s, dtype=%s)" % (self.name, - self.get_shape(), + self.shape, self.dtype.name) def __repr__(self): return "" % (self.name, - self.get_shape(), + self.shape, self.dtype.name) @@ -1804,11 +1881,13 @@ def make_variable(name, dtype=dtypes.float32, initializer=None, partition_info=None, - trainable=True, + trainable=None, caching_device=None, validate_shape=True, constraint=None, use_resource=None, + synchronization=vs.VariableSynchronization.AUTO, + aggregation=vs.VariableAggregation.NONE, partitioner=None): # pylint: disable=unused-argument """Temporary util to create a variable (relies on `variable_scope.variable`). @@ -1834,11 +1913,21 @@ def make_variable(name, or "non_trainable_variables" (e.g. BatchNorm mean, stddev). Note, if the current variable scope is marked as non-trainable then this parameter is ignored and any added variables are also - marked as non-trainable. + marked as non-trainable. `trainable` defaults to `True` unless + `synchronization` is set to `ON_READ`. caching_device: Passed to `vs.variable`. validate_shape: Passed to `vs.variable`. constraint: Constraint instance (callable). use_resource: Whether to use a `ResourceVariable`. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + `tf.VariableSynchronization`. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. If `synchronization` is set to `ON_READ`, + `trainable` must not be set to `True`. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + `tf.VariableAggregation`. partitioner: Not handled at this time. Returns: @@ -1870,5 +1959,17 @@ def make_variable(name, dtype=variable_dtype, validate_shape=validate_shape, constraint=constraint, - use_resource=use_resource) + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation) return v + + +def default(method): + """Decorates a method to detect overrides in subclasses.""" + method._is_default = True + return method + + +def generate_placeholders_from_shape(shape): + return array_ops.placeholder(shape=shape, dtype=backend.floatx()) diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb073322c76c0494b01a50486a69a2125b61d2c --- /dev/null +++ b/tensorflow/python/keras/engine/distributed_training_utils.py @@ -0,0 +1,271 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities related to distributed training.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import tensor_util +from tensorflow.python.keras import backend +from tensorflow.python.keras import callbacks +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.util import nest + + +def set_weights(distribution_strategy, dist_model, weights): + """Sets the weights of the replicated models. + + The weights of the replicated models are set to the weights of the original + model. The weights of the replicated model are Mirrored variables and hence + we need to use the `update` call within a DistributionStrategy scope. + + Args: + distribution_strategy: DistributionStrategy used to distribute training + and validation. + dist_model: The replicated models on the different devices. + weights: The weights of the original model. + """ + assign_ops = [] + for layer in dist_model.layers: + num_param = len(layer.weights) + layer_weights = weights[:num_param] + for sw, w in zip(layer.weights, layer_weights): + assign_ops.append(distribution_strategy.unwrap(sw.assign(w))) + + weights = weights[num_param:] + backend.get_session().run(assign_ops) + + +def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs, + grouped_updates, grouped_session_args, + with_loss_tensor=False): + """Unwrap and return the list of values contained in the PerDevice parameters. + + This function calls `flatten_perdevice_values` to parse each of the input + parameters into a list of values on the different devices. If we set + `with_loss_tensor` to be True, we also call `reduce` on the list of losses on + the different devices to give us one loss tensor. + + Args: + distribution_strategy: DistributionStrategy used to distribute training and + validation. + grouped_inputs: PerDevice inputs returned from the train or test function + that we ran on each device. + grouped_outputs: PerDevice outputs returned from the train or test function + that we ran on each device. + grouped_updates: PerDevice updates returned from the train or test function + that we ran on each device. + grouped_session_args: PerDevice session args returned from the train or + test function that we ran on each device. + with_loss_tensor: Boolean that indicates if we need to add the reduced loss + tensor as one of the outputs. + + Returns: + Values of each of the PerDevice parameters. + + """ + # Unwrap per device values returned from each model's train function. + # This will be used to construct the main train function. + all_inputs = flatten_perdevice_values(distribution_strategy, + grouped_inputs) + if with_loss_tensor: + # reduce loss tensor before adding it to the list of fetches + loss = distribution_strategy.unwrap( + distribution_strategy.reduce(distribute_lib.get_loss_reduction(), + grouped_outputs[0], + destinations='/device:CPU:0'))[0] + + all_outputs = flatten_perdevice_values(distribution_strategy, + grouped_outputs[1:]) + all_outputs = [loss] + all_outputs + else: + all_outputs = flatten_perdevice_values(distribution_strategy, + grouped_outputs) + + all_updates = flatten_perdevice_values(distribution_strategy, + grouped_updates) + + all_session_args = {} + grouped_feed_dict = grouped_session_args.get('feed_dict') + if grouped_feed_dict: + all_session_args['feed_dict'] = flatten_perdevice_values( + distribution_strategy, grouped_feed_dict) + + grouped_fetches = grouped_session_args.get('fetches') + if grouped_fetches: + all_session_args['fetches'] = flatten_perdevice_values( + distribution_strategy, grouped_fetches) + + return all_inputs, all_outputs, all_updates, all_session_args + + +def flatten_perdevice_values(distribution_strategy, perdevice_values): + """Unwraps and flattens a nest of PerDevice parameters. + + PerDevice values have one value associated with each device. Each entry in + the PerDevice dict has a device `key` and the corresponding value on the + device as the `value`. In this function we take a PerDevice value or a list of + PerDevice values and return all the values in the PerDevice dict. + + Args: + distribution_strategy: DistributionStrategy used to distribute training and + validation. + perdevice_values: List of PerDevice object or a single PerDevice object. + + Returns: + List of values of all the PerDevice objects. + + """ + # This function takes a PerDevice object or a list of PerDevice objects and + # returns all the values associated with it. + return [e for flattened in nest.flatten(perdevice_values) + for e in distribution_strategy.unwrap(flattened)] + + +def validate_callbacks(input_callbacks): + """Validate whether given callbacks are supported by DistributionStrategy. + + Args: + input_callbacks: List of callbacks passed by the user to fit. + + Raises: + ValueError: If `LearningRateScheduler` or `ReduceLROnPlateau` is one of the + callbacks passed. + ValueError: If `histogram_freq` or `write_grads` is one of the parameters + passed as part of the TensorBoard callback. + """ + if input_callbacks: + for callback in input_callbacks: + if callback not in [callbacks.TensorBoard, callbacks.ReduceLROnPlateau, + callbacks.LearningRateScheduler, callbacks.CSVLogger, + callbacks.EarlyStopping, callbacks.ModelCheckpoint, + callbacks.TerminateOnNaN, callbacks.ProgbarLogger, + callbacks.History, callbacks.RemoteMonitor]: + logging.warning('Your input callback is not one of the predefined ' + 'Callbacks that supports DistributionStrategy. You ' + 'might encounter an error if you access one of the ' + 'model\'s attributes as part of the callback since ' + 'these attributes are not set. You can access each of ' + 'the individual distributed models using the ' + '`_grouped_model` attribute of your original model.') + if isinstance(callback, callbacks.LearningRateScheduler): + raise ValueError('LearningRateScheduler callback is not supported with ' + 'DistributionStrategy.') + if isinstance(callback, callbacks.ReduceLROnPlateau): + raise ValueError('ReduceLROnPlateau callback is not supported with ' + 'DistributionStrategy.') + + # If users want to use the TensorBoard callback they cannot use certain + # features of the callback that involve accessing model attributes and + # running ops. + if isinstance(callback, callbacks.TensorBoard): + if callback.__getattribute__('histogram_freq'): + raise ValueError('histogram_freq in the TensorBoard callback is not ' + 'supported when using DistributionStrategy.') + if callback.__getattribute__('write_grads'): + raise ValueError('write_grads in the TensorBoard callback is not ' + 'supported when using DistributionStrategy.') + + +def validate_distributed_dataset_inputs(distribution_strategy, x, y): + """Validate all the components of a DistributedValue Dataset input. + + Args: + distribution_strategy: The current DistributionStrategy used to call + `fit`/`evaluate`. + x: Input Dataset DistributedValue object. For example, when we use + `MirroredStrategy` this is a PerDevice object with a tensor for each + device set in the dict. x can also be a tuple or dict. The keys of the + dict should match the names of the input layers of the model. + y: Target Dataset DistributedValue object. For example, when we use + `MirroredStrategy` this is a PerDevice object with a tensor for each + device set in the dict. y can also be a tuple or dict. The keys of the + dict should match the names of the output layers of the model. + + Returns: + The unwrapped values list of the x and y DistributedValues inputs. + + Raises: + ValueError: If x and y do not have support for being evaluated as tensors. + or if x and y contain elements that are not tensors or if x and y + contain elements that have a shape or dtype mismatch. + """ + # If the input and target used to call the model are not dataset tensors, + # we need to raise an error. When using a DistributionStrategy, the input + # and targets to a model should be from a `tf.data.Dataset`. + + # If each element of x and y are not tensors, we cannot standardize and + # validate the input and targets. + x_values_list = validate_per_device_inputs(distribution_strategy, x) + + y_values_list = validate_per_device_inputs(distribution_strategy, y) + + # Return the unwrapped values to avoid calling `unwrap` a second time. + return x_values_list, y_values_list + + +def validate_per_device_inputs(distribution_strategy, x): + """Validates PerDevice dataset input list. + + Args: + distribution_strategy: The current DistributionStrategy used to call + `fit`, `evaluate` and `predict`. + x: A list of PerDevice objects that represent the input or + target values. + + Returns: + List containing the first element of each of the PerDevice objects in + the input list. + + Raises: + ValueError: If any of the objects in the `per_device_list` is not a tensor. + + """ + # Convert the inputs and targets into a list of PerDevice objects. + per_device_list = nest.flatten(x) + x_values_list = [] + for x in per_device_list: + if not tensor_util.is_tensor(x): + raise ValueError('Dataset input to the model should be tensors instead ' + 'they are of type {}'.format(type(x))) + + # At this point both x and y contain tensors in the `DistributedValues` + # structure. + x_values = distribution_strategy.unwrap(x) + + # Validate that the shape and dtype of all the elements in x are the same. + validate_all_tensor_shapes(x, x_values) + validate_all_tensor_types(x, x_values) + + x_values_list.append(x_values[0]) + return x_values_list + + +def validate_all_tensor_types(x, x_values): + x_dtype = x_values[0].dtype + for i in range(1, len(x_values)): + if x_dtype != x_values[i].dtype: + raise ValueError('Input tensor dtypes do not match for distributed tensor' + ' inputs {}'.format(x)) + + +def validate_all_tensor_shapes(x, x_values): + # Validate that the shape of all the elements in x have the same shape + x_shape = x_values[0].get_shape().as_list() + for i in range(1, len(x_values)): + if x_shape != x_values[i].get_shape().as_list(): + raise ValueError('Input tensor shapes do not match for distributed tensor' + ' inputs {}'.format(x)) diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py index 7996110829b56b6f7f0b3e4c5b6c5b9f35affb64..8a4018a0df50b8d4c9df5900ffddfcdc093f161f 100644 --- a/tensorflow/python/keras/engine/input_layer.py +++ b/tensorflow/python/keras/engine/input_layer.py @@ -215,7 +215,7 @@ def Input( # pylint: disable=invalid-name if dtype is None: dtype = K.floatx() - if not shape and tensor is None: + if shape is None and tensor is None: raise ValueError('Please provide to Input either a `shape`' ' or a `tensor` argument. Note that ' '`shape` does not include the batch ' diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 427efaaf11ecd964d72b3e34233920d2cdfeaeeb..708fa1c807d40ac6d1315d8e14451e021d2517f7 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -20,7 +20,6 @@ from __future__ import division from __future__ import print_function import copy -import functools import json import os import weakref @@ -30,6 +29,8 @@ from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context +from tensorflow.python.eager import function as eager_function +from tensorflow.python.framework import errors from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -42,10 +43,11 @@ from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpoint_management from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.checkpointable import data_structures_base +from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.training.checkpointable import layer_utils as checkpointable_layer_utils from tensorflow.python.training.checkpointable import util as checkpointable_utils -from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect @@ -80,6 +82,20 @@ class Network(base_layer.Layer): # Subclassed network self._init_subclassed_network(**kwargs) + # Several Network methods have "no_automatic_dependency_tracking" + # annotations. Since Network does automatic dependency tracking on attribute + # assignment, including for common data structures such as lists, by default + # we'd have quite a few empty dependencies which users don't care about (or + # would need some way to ignore dependencies automatically, which is confusing + # when applied to user code). Some attributes, such as _layers, would cause + # structural issues (_layers being the place where Layers assigned to tracked + # attributes are stored). + # + # Aside from these aesthetic and structural issues, useless dependencies on + # empty lists shouldn't cause issues; adding or removing them will not break + # checkpoints, but may cause "all Python objects matched" assertions to fail + # (in which case less strict assertions may be substituted if necessary). + @checkpointable.no_automatic_dependency_tracking def _base_init(self, name=None): # The following are implemented as property functions: # self.trainable_weights @@ -101,6 +117,16 @@ class Network(base_layer.Layer): # included in base_init to avoid excessive special casing when retrieving # the value). self._extra_variables = [] + # In many internal cases one needs to compute both the model's output + # and its output mask without relying on `__call__` (which would do both and + # set mask metadata), but for models, computing the mask requires to + # recompute the output. + # Hence the pattern `output = model.call(); mask = model.compute_mask()` + # would be redundant, and internal logic + # (susceptible to use `call` directly) should prefer using the + # internal method `output, mask = _call_and_compute_mask()`. + # This is True for Sequential networks and graph networks. + self._compute_output_and_mask_jointly = False self.supports_masking = False if not hasattr(self, 'optimizer'): @@ -129,11 +155,8 @@ class Network(base_layer.Layer): self._checkpointable_saver = checkpointable_utils.CheckpointableSaver( weakref.ref(self)) - # A zero-argument function which should be called and set back to None as - # soon as the network is built (only applicable to subclassed Models). Runs - # restore operations when graph building. - self._in_progress_restore_finalizer = None + @checkpointable.no_automatic_dependency_tracking def _init_graph_network(self, inputs, outputs, name=None): self._call_convention = base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT # Normalize and set self.inputs, self.outputs. @@ -202,11 +225,12 @@ class Network(base_layer.Layer): self._base_init(name=name) self._compute_previous_mask = ( - 'mask' in tf_inspect.getargspec(self.call).args or + 'mask' in tf_inspect.getfullargspec(self.call).args or hasattr(self, 'compute_mask')) # A Network does not create weights of its own, thus it is already # built. self.built = True + self._compute_output_and_mask_jointly = True self._is_graph_network = True self._input_layers = [] @@ -258,23 +282,6 @@ class Network(base_layer.Layer): input_tensors=self.inputs, output_tensors=self.outputs) - # Fill in the output mask cache. - masks = [] - for x in self.inputs: - mask = x._keras_mask if hasattr(x, '_keras_mask') else None # pylint: disable=protected-access - masks.append(mask) - mask_cache_key = (generic_utils.object_list_uid(self.inputs) + '_' + - generic_utils.object_list_uid(masks)) - masks = [] - for x in self.outputs: - mask = x._keras_mask if hasattr(x, '_keras_mask') else None # pylint: disable=protected-access - masks.append(mask) - if len(masks) == 1: - mask = masks[0] - else: - mask = masks - self._output_mask_cache[mask_cache_key] = mask - # Build self.input_names and self.output_names. self.input_names = [] self.output_names = [] @@ -292,17 +299,18 @@ class Network(base_layer.Layer): for layer in self._output_layers: self.output_names.append(layer.name) + @checkpointable.no_automatic_dependency_tracking def _init_subclassed_network(self, name=None): self._base_init(name=name) self._is_graph_network = False - call_argspec = tf_inspect.getargspec(self.call) + call_argspec = tf_inspect.getfullargspec(self.call) if 'training' in call_argspec.args: self._expects_training_arg = True else: self._expects_training_arg = False self._call_convention = self._determine_call_convention(call_argspec) - self.outputs = None - self.inputs = None + self.outputs = [] + self.inputs = [] self.built = False def _determine_call_convention(self, call_argspec): @@ -361,14 +369,35 @@ class Network(base_layer.Layer): self._track_checkpointable( layer, name='layer-%d' % layer_index, overwrite=True) + def _no_dependency(self, value): + """Override to allow `Layer` to disable dependency tracking. + + `CheckpointableBase` defines this method, whose semantics are "if a subclass + does dependency tracking, this method exempts `value`." Layer uses + `_no_dependency` to exempt some of its attribute assignments (conditional on + attribute assignment causing tracking in the subclass). + + Args: + value: An object which will be assigned to an object attribute, whose + value should not be tracked. + + Returns: + A wrapped object which, when assigned to an attribute, will not be + tracked (`value` will be stored in the attribute). + """ + return data_structures.NoDependency(value) + def __setattr__(self, name, value): - no_dependency = isinstance(value, checkpointable.NoDependency) - if no_dependency: - value = value.value + if not getattr(self, '_setattr_tracking', True): + super(Network, self).__setattr__(name, value) + return + no_dependency = isinstance(value, data_structures.NoDependency) + value = data_structures.sticky_attribute_assignment( + checkpointable=self, value=value, name=name) if isinstance(value, ( base_layer.Layer, Network, - data_structures_base.CheckpointableDataStructureBase)): + data_structures.CheckpointableDataStructure)): try: is_graph_network = self._is_graph_network except AttributeError: @@ -376,7 +405,9 @@ class Network(base_layer.Layer): 'forgot to call `super(YourClass, self).__init__()`.' ' Always start with this line.') if not is_graph_network: - if value not in self._layers: + # We need to check object identity to avoid de-duplicating empty + # container types which compare equal. + if not any((layer is value for layer in self._layers)): self._layers.append(value) if hasattr(value, '_use_resource_variables'): # In subclassed models, legacy layers (tf.layers) must always use @@ -384,12 +415,6 @@ class Network(base_layer.Layer): value._use_resource_variables = True if (not no_dependency and isinstance(value, checkpointable.CheckpointableBase)): - # Layer (and therefore Network/Model) inherit from CheckpointableBase - # rather than Checkpointable, which means there is no Checkpointable - # __setattr__ override (it would be a performance issue for functional - # layers). Therefore Model tracks Checkpointable objects itself. - self._track_checkpointable( - checkpointable=value, name=name, overwrite=True) if ( # For subclassed models only, users may add extra weights/variables # simply by assigning them to attributes. not self._is_graph_network @@ -482,17 +507,14 @@ class Network(base_layer.Layer): masks = [None for _ in range(len(inputs))] else: masks = generic_utils.to_list(mask) - cache_key = (generic_utils.object_list_uid(inputs) - + '_' + generic_utils.object_list_uid(masks)) - if cache_key in self._output_mask_cache: - return self._output_mask_cache[cache_key] - else: - _, output_masks = self._run_internal_graph(inputs, mask=masks) - return output_masks + + _, output_masks = self._run_internal_graph(inputs, mask=masks) + return output_masks @property def layers(self): - return self._layers + return checkpointable_layer_utils.filter_empty_layer_containers( + self._layers) def get_layer(self, name=None, index=None): """Retrieves a layer based on either its name (unique) or index. @@ -527,6 +549,28 @@ class Network(base_layer.Layer): return layer raise ValueError('No such layer: ' + name) + @property + def _unfiltered_updates(self): + if context.executing_eagerly(): + return [] + updates = [] + for layer in self.layers: + if isinstance(layer, Network): + updates += layer._unfiltered_updates + else: + updates += layer.updates + return updates + + @property + def _unfiltered_losses(self): + losses = [] + for layer in self.layers: + if isinstance(layer, Network): + losses += layer._unfiltered_losses + else: + losses += layer.losses + return losses + @property def updates(self): """Retrieves the network's updates. @@ -536,6 +580,8 @@ class Network(base_layer.Layer): (e.g. will not include updates that were created by layers of this model outside of the model). + When the network has no registered inputs, all updates are returned. + Effectively, `network.updates` behaves like `layer.updates`. Concrete example: @@ -581,22 +627,20 @@ class Network(base_layer.Layer): if not self.trainable and not self.stateful: return [] - updates = [] - for layer in self.layers: - updates += layer.updates + updates = self._unfiltered_updates # `updates` might contain irrelevant updates, so it needs to be filtered # with respect to inputs the model has been called on. - if self.inputs: - relevant_inputs = self.inputs[:] - else: - relevant_inputs = [] - for i in range(1, len(self._inbound_nodes)): + relevant_inputs = [] + for i in range(0, len(self._inbound_nodes)): inputs = self.get_input_at(i) if isinstance(inputs, list): relevant_inputs += inputs else: relevant_inputs.append(inputs) + if not relevant_inputs: + return updates + reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, updates) relevant_conditional_updates = [x for x in updates if x in reachable] unconditional_updates = [ @@ -615,25 +659,25 @@ class Network(base_layer.Layer): (e.g. will not include losses that depend on tensors that aren't inputs to this model). + When the network has no registered inputs, all losses are returned. + Returns: A list of loss tensors. """ - losses = [] - for layer in self.layers: - losses += layer.losses + losses = self._unfiltered_losses if context.executing_eagerly(): return losses - if self.inputs: - relevant_inputs = self.inputs[:] - else: - relevant_inputs = [] - for i in range(1, len(self._inbound_nodes)): + relevant_inputs = [] + for i in range(0, len(self._inbound_nodes)): inputs = self.get_input_at(i) if isinstance(inputs, list): relevant_inputs += inputs else: relevant_inputs.append(inputs) + if not relevant_inputs: + return losses + reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, losses) relevant_conditional_losses = [x for x in losses if x in reachable] unconditional_losses = [ @@ -643,14 +687,14 @@ class Network(base_layer.Layer): @property def trainable_weights(self): - return layer_utils.gather_trainable_weights( + return checkpointable_layer_utils.gather_trainable_weights( trainable=self.trainable, sub_layers=self.layers, extra_variables=self._extra_variables) @property def non_trainable_weights(self): - return layer_utils.gather_non_trainable_weights( + return checkpointable_layer_utils.gather_non_trainable_weights( trainable=self.trainable, sub_layers=self.layers, extra_variables=self._extra_variables) @@ -682,6 +726,93 @@ class Network(base_layer.Layer): return specs[0] return specs + @base_layer.default + def build(self, input_shape): + """Builds the model based on input shapes received. + + This is to be used for subclassed models, which do not know at instantiation + time what their inputs look like. + + Args: + input_shape: Single tuple, TensorShape, or list of shapes, where shapes + are tuples, integers, or TensorShapes. + + Raises: + ValueError: + 1. In case of invalid user-provided data (not of type tuple, + list, or TensorShape). + 2. If the model requires call arguments that are agnostic + to the input shapes (positional or kwarg in call signature). + 3. If not all layers were properly built. + 4. If float type inputs are not supported within the layers. + + In each of these cases, the user should build their model by calling it + on real tensor data. + """ + if self._is_graph_network: + self.built = True + return + + # If subclass network + if input_shape is None: + raise ValueError('Input shape must be defined when calling build on a ' + 'model subclass network.') + valid_types = (tuple, list, tensor_shape.TensorShape) + if not isinstance(input_shape, valid_types): + raise ValueError('Specified input shape is not one of the valid types. ' + 'Please specify a batch input shape of type tuple or ' + 'list of input shapes. User provided ' + 'input type: {}'.format(type(input_shape))) + + if input_shape and not self.inputs: + # We create placeholders for the `None`s in the shape and build the model + # in a Graph. Since tf.Variable is compatible with both eager execution + # and graph building, the variables created after building the model in + # a Graph are still valid when executing eagerly. + with context.graph_mode(): + graph = eager_function.CapturingGraph() + with graph.as_default(): + if isinstance(input_shape, list): + x = [base_layer.generate_placeholders_from_shape(shape) + for shape in input_shape] + else: + x = base_layer.generate_placeholders_from_shape(input_shape) + + kwargs = {} + num_call_args = len(tf_inspect.getfullargspec(self.call).args) + if self._expects_training_arg and num_call_args == 3: + # Has call signature of call(self, input, training) + kwargs['training'] = False + elif num_call_args > 2: + # Has invalid call signature of call(self, input, *args, **kwargs) + raise ValueError('Currently, you cannot build your model if it has ' + 'positional or keyword arguments that are not ' + 'inputs to the model, but are required for its ' + '`call` method. Instead, in order to instantiate ' + 'and build your model, `call` your model on real ' + 'tensor data with all expected call arguments.') + + try: + self.call(x, **kwargs) + except (errors.InvalidArgumentError, TypeError): + raise ValueError('You cannot build your model by calling `build` ' + 'if your layers do not support float type inputs. ' + 'Instead, in order to instantiate and build your ' + 'model, `call` your model on real tensor data (of ' + 'the correct dtype).') + + if self._layers: + self._track_layers(self._layers) + if self.layers: + for layer in self.layers: + if not layer.built: + raise ValueError('Layer: {} was not built in your model. Calling ' + '`build` manually on a subclassed model is only ' + 'allowed for models with a static topology. ' + 'In this case, you can build your model by ' + 'calling it on real tensor data.'.format(layer)) + self.built = True + def call(self, inputs, training=None, mask=None): """Calls the model on new inputs. @@ -700,28 +831,34 @@ class Network(base_layer.Layer): A tensor if there is a single output, or a list of tensors if there are more than one outputs. """ - inputs = nest.flatten(inputs) + if not self._is_graph_network: + raise NotImplementedError('When subclassing the `Model` class, you should' + ' implement a `call` method.') + + inputs = generic_utils.to_list(inputs) if mask is None: masks = [None for _ in range(len(inputs))] else: - masks = nest.flatten(mask) - - if not context.executing_eagerly(): - # Try to retrieve cached outputs if the layer has already been called - # on these exact inputs. - cache_key = (generic_utils.object_list_uid(inputs) - + '_' + generic_utils.object_list_uid(masks)) - if cache_key in self._output_tensor_cache: - # Cache hit. - return self._output_tensor_cache[cache_key] - # Actually apply the network graph to the new inputs. + masks = generic_utils.to_list(mask) outputs, _ = self._run_internal_graph(inputs, training=training, mask=masks) return outputs + def _call_and_compute_mask(self, inputs, training=None, mask=None): + inputs = generic_utils.to_list(inputs) + if mask is None: + masks = [None for _ in range(len(inputs))] + else: + masks = generic_utils.to_list(mask) + return self._run_internal_graph(inputs, + training=training, + mask=masks) + def compute_output_shape(self, input_shape): if not self._is_graph_network: + if context.executing_eagerly(): + return super(Network, self).compute_output_shape(input_shape) raise NotImplementedError if isinstance(input_shape, list): @@ -743,9 +880,10 @@ class Network(base_layer.Layer): ' tensor inputs.') cache_key = generic_utils.object_list_uid(input_shapes) - if cache_key not in self._output_shape_cache: - # Cache miss. We have to run the network graph manually (recursive calls - # to `compute_output_shape`). + if cache_key in self._output_shape_cache: + # Cache hit. + output_shapes = self._output_shape_cache[cache_key] + else: layers_to_output_shapes = {} for i in range(len(input_shapes)): layer = self._input_layers[i] @@ -807,9 +945,6 @@ class Network(base_layer.Layer): output_shapes.append(layers_to_output_shapes[shape_key]) # Store in cache. self._output_shape_cache[cache_key] = output_shapes - else: - # Cache hit. - output_shapes = self._output_shape_cache[cache_key] if isinstance(output_shapes, list): if len(output_shapes) == 1: @@ -832,7 +967,7 @@ class Network(base_layer.Layer): mask: List of masks (tensors or None). Returns: - Three lists: output_tensors, output_masks, output_shapes + Two lists: output_tensors, output_masks """ # Note: masking support is relevant mainly for Keras. # It cannot be factored out without having the fully reimplement the network @@ -849,8 +984,6 @@ class Network(base_layer.Layer): # Dictionary mapping reference tensors to tuples # (computed tensor, compute mask) # we assume a 1:1 mapping from tensor to mask - # TODO(fchollet): raise exception when a `.compute_mask()` call - # does not return a list the same size as `call` tensor_map = {} for x, y, mask in zip(self.inputs, inputs, masks): tensor_map[str(id(x))] = (y, mask) @@ -879,54 +1012,69 @@ class Network(base_layer.Layer): kwargs = node.arguments else: kwargs = {} + # Ensure `training` arg propagation if applicable. + if 'training' in tf_inspect.getfullargspec(layer.call).args: + kwargs.setdefault('training', training) + if len(computed_data) == 1: computed_tensor, computed_mask = computed_data[0] # Ensure mask propagation if applicable. - if 'mask' in tf_inspect.getargspec(layer.call).args: + if 'mask' in tf_inspect.getfullargspec(layer.call).args: kwargs.setdefault('mask', computed_mask) - if 'training' in tf_inspect.getargspec(layer.call).args: - kwargs.setdefault('training', training) - - output_tensors = nest.flatten( - layer.call(computed_tensor, **kwargs)) - if hasattr(layer, 'compute_mask'): - output_masks = layer.compute_mask(computed_tensor, - computed_mask) - if output_masks is None: - output_masks = [None for _ in output_tensors] - else: - output_masks = nest.flatten(output_masks) + + # Compute outputs and masks. + if (isinstance(layer, Network) and + layer._compute_output_and_mask_jointly): + output_tensors, output_masks = layer._call_and_compute_mask( + computed_tensor, **kwargs) else: - output_masks = [None for _ in output_tensors] + output_tensors = layer.call(computed_tensor, **kwargs) + if hasattr(layer, 'compute_mask'): + output_masks = layer.compute_mask(computed_tensor, + computed_mask) + else: + output_masks = [None for _ in output_tensors] computed_tensors = [computed_tensor] - computed_masks = [computed_mask] + else: computed_tensors = [x[0] for x in computed_data] computed_masks = [x[1] for x in computed_data] - if 'mask' in tf_inspect.getargspec(layer.call).args: + # Ensure mask propagation if applicable. + if 'mask' in tf_inspect.getfullargspec(layer.call).args: kwargs.setdefault('mask', computed_masks) - if 'training' in tf_inspect.getargspec(layer.call).args: - kwargs.setdefault('training', training) - output_tensors = nest.flatten( - layer.call(computed_tensors, **kwargs)) - - if hasattr(layer, 'compute_mask'): - output_masks = layer.compute_mask(computed_tensors, - computed_masks) - if output_masks is None: - output_masks = [None for _ in output_tensors] - else: - output_masks = nest.flatten(output_masks) + # Compute outputs and masks. + if (isinstance(layer, Network) and + layer._compute_output_and_mask_jointly): + output_tensors, output_masks = layer._call_and_compute_mask( + computed_tensors, **kwargs) else: - output_masks = [None for _ in output_tensors] + output_tensors = layer.call(computed_tensors, **kwargs) + if hasattr(layer, 'compute_mask'): + output_masks = layer.compute_mask(computed_tensors, + computed_masks) + else: + output_masks = [None for _ in output_tensors] + + output_tensors = generic_utils.to_list(output_tensors) + if output_masks is None: + output_masks = [None for _ in output_tensors] + else: + output_masks = generic_utils.to_list(output_masks) if not context.executing_eagerly(): + # Set mask metadata. + for x, m in zip(output_tensors, output_masks): + try: + x._keras_mask = m + except AttributeError: + pass + + # Apply activity regularizer if any. if layer.activity_regularizer is not None: regularization_losses = [ layer.activity_regularizer(x) for x in output_tensors ] - # Apply activity regularizer if any: layer.add_loss(regularization_losses, computed_tensors) # Update tensor_map. @@ -951,18 +1099,10 @@ class Network(base_layer.Layer): if output_masks is not None: output_masks = output_masks[0] - if not context.executing_eagerly(): - # Update cache; - # keys are based on ids on input tensors and inputs masks. - cache_key = (generic_utils.object_list_uid(inputs) - + '_' + generic_utils.object_list_uid(masks)) - self._output_tensor_cache[cache_key] = output_tensors - self._output_mask_cache[cache_key] = output_masks - - if output_shapes is not None: - input_shapes = [backend.int_shape(x) for x in inputs] - cache_key = generic_utils.object_list_uid(input_shapes) - self._output_shape_cache[cache_key] = output_shapes + if output_shapes is not None: + input_shapes = [backend.int_shape(x) for x in inputs] + cache_key = generic_utils.object_list_uid(input_shapes) + self._output_shape_cache[cache_key] = output_shapes return output_tensors, output_masks @@ -1305,7 +1445,22 @@ class Network(base_layer.Layer): session = None else: session = backend.get_session() + optimizer = getattr(self, 'optimizer', None) + if (optimizer + and not isinstance(optimizer, checkpointable.CheckpointableBase)): + logging.warning( + ('This model was compiled with a Keras optimizer (%s) but is being ' + 'saved in TensorFlow format with `save_weights`. The model\'s ' + 'weights will be saved, but unlike with TensorFlow optimizers in ' + 'the TensorFlow format the optimizer\'s state will not be ' + 'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.') + % (optimizer,)) self._checkpointable_saver.save(filepath, session=session) + # Record this checkpoint so it's visible from tf.train.latest_checkpoint. + checkpoint_management.update_checkpoint_state( + save_dir=os.path.dirname(filepath), + model_checkpoint_path=filepath, + all_model_checkpoint_paths=[filepath]) def load_weights(self, filepath, by_name=False): """Loads all layer weights, either from a TensorFlow or an HDF5 weight file. @@ -1366,13 +1521,9 @@ class Network(base_layer.Layer): 'load_weights).') if not context.executing_eagerly(): session = backend.get_session() - finalizer = functools.partial(status.run_restore_ops, session=session) - if self.built: - finalizer() - else: - # Hold on to this status object until the network is built (for - # subclassed Models). Then we'll run restore ops if necessary. - self._in_progress_restore_finalizer = finalizer + # Restore existing variables (if any) immediately, and set up a + # streaming restore for any variables created in the future. + checkpointable_utils.streaming_restore(status=status, session=session) return status if h5py is None: raise ImportError( @@ -1390,14 +1541,6 @@ class Network(base_layer.Layer): else: saving.load_weights_from_hdf5_group(f, self.layers) - def _post_build_cleanup(self): - super(Network, self)._post_build_cleanup() - if self._in_progress_restore_finalizer is not None: - # Runs queued restore operations left over from load_weights when graph - # building. - self._in_progress_restore_finalizer() - self._in_progress_restore_finalizer = None - def _updated_config(self): """Util shared between different serialization methods. @@ -1496,47 +1639,6 @@ class Network(base_layer.Layer): print_fn=print_fn) -def get_source_inputs(tensor, layer=None, node_index=None): - """Returns the list of input tensors necessary to compute `tensor`. - - Output will always be a list of tensors - (potentially with 1 element). - - Arguments: - tensor: The tensor to start from. - layer: Origin layer of the tensor. Will be - determined via tensor._keras_history if not provided. - node_index: Origin node index of the tensor. - - Returns: - List of input tensors. - """ - if not hasattr(tensor, '_keras_history'): - return tensor - - if layer is None or node_index: - layer, node_index, _ = tensor._keras_history - if not layer._inbound_nodes: - return [tensor] - else: - node = layer._inbound_nodes[node_index] - if not node.inbound_layers: - # Reached an Input layer, stop recursion. - return node.input_tensors - else: - source_tensors = [] - for i in range(len(node.inbound_layers)): - x = node.input_tensors[i] - layer = node.inbound_layers[i] - node_index = node.node_indices[i] - previous_sources = get_source_inputs(x, layer, node_index) - # Avoid input redundancy. - for x in previous_sources: - if x not in source_tensors: - source_tensors.append(x) - return source_tensors - - def _is_hdf5_filepath(filepath): return filepath.endswith('.h5') or filepath.endswith('.keras') diff --git a/tensorflow/python/keras/engine/saving.py b/tensorflow/python/keras/engine/saving.py index b9a2e1f25f637dc8017f751bbdd400c1e5c9dd44..a2eed7cb462c57da2468c418d04108fb274b7fb6 100644 --- a/tensorflow/python/keras/engine/saving.py +++ b/tensorflow/python/keras/engine/saving.py @@ -127,6 +127,7 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True): }, 'loss': model.loss, 'metrics': model.metrics, + 'weighted_metrics': model.weighted_metrics, 'sample_weight_mode': model.sample_weight_mode, 'loss_weights': model.loss_weights, }, @@ -246,6 +247,8 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable= # Recover loss functions and metrics. loss = convert_custom_objects(training_config['loss']) metrics = convert_custom_objects(training_config['metrics']) + weighted_metrics = convert_custom_objects( + training_config['weighted_metrics']) sample_weight_mode = training_config['sample_weight_mode'] loss_weights = training_config['loss_weights'] @@ -254,6 +257,7 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable= optimizer=optimizer, loss=loss, metrics=metrics, + weighted_metrics=weighted_metrics, loss_weights=loss_weights, sample_weight_mode=sample_weight_mode) @@ -351,7 +355,10 @@ def preprocess_weights_for_loading(layer, weights, original_keras_version=None, original_backend=None): - """Converts layers weights from Keras 1 format to Keras 2. + """Preprocess layer weights between different Keras formats. + + Converts layers weights from Keras 1 format to Keras 2 and also weights of + CuDNN layers in Keras 2. Arguments: layer: Layer instance. @@ -363,7 +370,18 @@ def preprocess_weights_for_loading(layer, Returns: A list of weights values (Numpy arrays). """ - if layer.__class__.__name__ == 'Bidirectional': + def convert_nested_bidirectional(weights): + """Converts layers nested in `Bidirectional` wrapper. + + This function uses `preprocess_weights_for_loading()` for converting + layers. + + Arguments: + weights: List of weights values (Numpy arrays). + + Returns: + A list of weights values (Numpy arrays). + """ num_weights_per_layer = len(weights) // 2 forward_weights = preprocess_weights_for_loading( layer.forward_layer, weights[:num_weights_per_layer], @@ -371,7 +389,69 @@ def preprocess_weights_for_loading(layer, backward_weights = preprocess_weights_for_loading( layer.backward_layer, weights[num_weights_per_layer:], original_keras_version, original_backend) - weights = forward_weights + backward_weights + return forward_weights + backward_weights + + def convert_nested_time_distributed(weights): + """Converts layers nested in `TimeDistributed` wrapper. + + This function uses `preprocess_weights_for_loading()` for converting nested + layers. + + Arguments: + weights: List of weights values (Numpy arrays). + + Returns: + A list of weights values (Numpy arrays). + """ + return preprocess_weights_for_loading( + layer.layer, weights, original_keras_version, original_backend) + + def convert_nested_model(weights): + """Converts layers nested in `Model` or `Sequential`. + + This function uses `preprocess_weights_for_loading()` for converting nested + layers. + + Arguments: + weights: List of weights values (Numpy arrays). + + Returns: + A list of weights values (Numpy arrays). + """ + new_weights = [] + # trainable weights + for sublayer in layer.layers: + num_weights = len(sublayer.trainable_weights) + if num_weights > 0: + new_weights.extend(preprocess_weights_for_loading( + layer=sublayer, + weights=weights[:num_weights], + original_keras_version=original_keras_version, + original_backend=original_backend)) + weights = weights[num_weights:] + + # non-trainable weights + for sublayer in layer.layers: + num_weights = len([l for l in sublayer.weights + if l not in sublayer.trainable_weights]) + if num_weights > 0: + new_weights.extend(preprocess_weights_for_loading( + layer=sublayer, + weights=weights[:num_weights], + original_keras_version=original_keras_version, + original_backend=original_backend)) + weights = weights[num_weights:] + return new_weights + + # Convert layers nested in Bidirectional/Model/Sequential. + # Both transformation should be ran for both Keras 1->2 conversion + # and for conversion of CuDNN layers. + if layer.__class__.__name__ == 'Bidirectional': + weights = convert_nested_bidirectional(weights) + if layer.__class__.__name__ == 'TimeDistributed': + weights = convert_nested_time_distributed(weights) + elif layer.__class__.__name__ in ['Model', 'Sequential']: + weights = convert_nested_model(weights) if original_keras_version == '1': if layer.__class__.__name__ == 'TimeDistributed': @@ -446,35 +526,6 @@ def preprocess_weights_for_loading(layer, recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0)) weights = [kernel, recurrent_kernel, bias] - if layer.__class__.__name__ in ['Model', 'Sequential']: - new_weights = [] - # trainable weights - for sublayer in layer.layers: - num_weights = len(sublayer.trainable_weights) - if num_weights > 0: - new_weights.extend( - preprocess_weights_for_loading( - layer=sublayer, - weights=weights[:num_weights], - original_keras_version=original_keras_version, - original_backend=original_backend)) - weights = weights[num_weights:] - - # non-trainable weights - for sublayer in layer.layers: - num_weights = len([ - l for l in sublayer.weights if l not in sublayer.trainable_weights - ]) - if num_weights > 0: - new_weights.extend( - preprocess_weights_for_loading( - layer=sublayer, - weights=weights[:num_weights], - original_keras_version=original_keras_version, - original_backend=original_backend)) - weights = weights[num_weights:] - weights = new_weights - conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D'] if layer.__class__.__name__ in conv_layers: if original_backend == 'theano': @@ -486,6 +537,7 @@ def preprocess_weights_for_loading(layer, if layer.__class__.__name__ == 'ConvLSTM2D': weights[1] = np.transpose(weights[1], (3, 2, 0, 1)) + # convert CuDNN layers return _convert_rnn_weights(layer, weights) @@ -624,7 +676,7 @@ def _convert_rnn_weights(layer, weights): kernels = transform_kernels(weights[0], transpose_input(from_cudnn), n_gates) recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates) - biases = weights[2].reshape((2, -1) if from_cudnn else -1) + biases = np.array(weights[2]).reshape((2, -1) if from_cudnn else -1) return [kernels, recurrent_kernels, biases] if bias_shape == (2 * units * n_gates,): @@ -806,7 +858,16 @@ def load_weights_from_hdf5_group_by_name(f, layers): str(len(weight_values)) + ' element(s).') # Set values. for i in range(len(weight_values)): - weight_value_tuples.append((symbolic_weights[i], weight_values[i])) + if K.int_shape(symbolic_weights[i]) != weight_values[i].shape: + raise ValueError('Layer #' + str(k) +' (named "' + layer.name + + '"), weight ' + str(symbolic_weights[i]) + + ' has shape {}'.format(K.int_shape( + symbolic_weights[i])) + + ', but the saved weight has shape ' + + str(weight_values[i].shape) + '.') + + else: + weight_value_tuples.append((symbolic_weights[i], weight_values[i])) K.batch_set_value(weight_value_tuples) diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py index 7e82db028b8db01d6fe2b693e2087cbdfac55314..b7c2e9cb53c118072825f286ec2a05a6b0dcbd5c 100644 --- a/tensorflow/python/keras/engine/saving_test.py +++ b/tensorflow/python/keras/engine/saving_test.py @@ -21,7 +21,6 @@ from __future__ import print_function import os import shutil import tempfile - from absl.testing import parameterized import numpy as np @@ -31,10 +30,13 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import saving from tensorflow.python.keras.engine import training from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import training as training_module try: @@ -248,6 +250,82 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): self.assertAllClose(y, ref_y) + def test_sequential_weight_loading_group_name_with_incorrect_length(self): + if h5py is None: + return + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + h5_path = os.path.join(temp_dir, 'test.h5') + + num_hidden = 5 + input_dim = 3 + num_classes = 2 + with self.test_session(): + ref_model = keras.models.Sequential() + ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim, + name='d1')) + ref_model.add(keras.layers.Dense(num_classes, name='d2')) + ref_model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + + f_ref_model = h5py.File(h5_path, 'w') + saving.save_weights_to_hdf5_group(f_ref_model, ref_model.layers) + + f_model = h5py.File(h5_path, 'r') + model = keras.models.Sequential() + model.add(keras.layers.Dense(num_hidden, use_bias=False, + input_dim=input_dim, name='d1')) + model.add(keras.layers.Dense(num_classes, name='d2')) + model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + with self.assertRaisesRegexp(ValueError, + r'Layer #0 \(named \"d1\"\) expects 1 ' + r'weight\(s\), but the saved weights have 2 ' + r'element\(s\)\.'): + saving.load_weights_from_hdf5_group_by_name(f_model, model.layers) + + def test_sequential_weight_loading_group_name_with_incorrect_shape(self): + if h5py is None: + return + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + h5_path = os.path.join(temp_dir, 'test.h5') + + num_hidden = 5 + input_dim = 3 + num_classes = 2 + with self.test_session(): + ref_model = keras.models.Sequential() + ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim, + name='d1')) + ref_model.add(keras.layers.Dense(num_classes, name='d2')) + ref_model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + + f_ref_model = h5py.File(h5_path, 'w') + saving.save_weights_to_hdf5_group(f_ref_model, ref_model.layers) + + f_model = h5py.File(h5_path, 'r') + model = keras.models.Sequential() + model.add(keras.layers.Dense(num_hidden + 5, input_dim=input_dim, + name='d1')) + model.add(keras.layers.Dense(num_classes, name='d2')) + model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + with self.assertRaisesRegexp(ValueError, + r'Layer #0 \(named "d1"\), weight ' + r' has ' + r'shape \(3, 10\), but the saved weight has ' + r'shape \(3, 5\)\.'): + saving.load_weights_from_hdf5_group_by_name(f_model, model.layers) + class TestWholeModelSaving(test.TestCase): @@ -260,10 +338,18 @@ class TestWholeModelSaving(test.TestCase): model.add(keras.layers.Dense(2, input_shape=(3,))) model.add(keras.layers.RepeatVector(3)) model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) - model.compile(loss=keras.losses.MSE, - optimizer=keras.optimizers.RMSprop(lr=0.0001), - metrics=[keras.metrics.categorical_accuracy], - sample_weight_mode='temporal') + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[ + keras.metrics.categorical_accuracy, + keras.metrics.CategoricalAccuracy() + ], + weighted_metrics=[ + keras.metrics.categorical_accuracy, + keras.metrics.CategoricalAccuracy() + ], + sample_weight_mode='temporal') x = np.random.random((1, 3)) y = np.random.random((1, 3, 3)) model.train_on_batch(x, y) @@ -358,9 +444,17 @@ class TestWholeModelSaving(test.TestCase): output = keras.layers.Dense(3)(x) model = keras.models.Model(inputs, output) - model.compile(loss=keras.losses.MSE, - optimizer=keras.optimizers.RMSprop(lr=0.0001), - metrics=[keras.metrics.categorical_accuracy]) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[ + keras.metrics.categorical_accuracy, + keras.metrics.CategoricalAccuracy() + ], + weighted_metrics=[ + keras.metrics.categorical_accuracy, + keras.metrics.CategoricalAccuracy() + ]) x = np.random.random((1, 3)) y = np.random.random((1, 3)) model.train_on_batch(x, y) @@ -546,9 +640,13 @@ class TestWholeModelSaving(test.TestCase): outputs = keras.layers.Dense(3)(x) model = keras.Model(inputs, outputs) - model.compile(loss=keras.losses.MSE, - optimizer=keras.optimizers.Adam(), - metrics=[keras.metrics.categorical_accuracy]) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.Adam(), + metrics=[ + keras.metrics.categorical_accuracy, + keras.metrics.CategoricalAccuracy() + ]) x = np.random.random((1, 3)) y = np.random.random((1, 3)) model.train_on_batch(x, y) @@ -587,7 +685,23 @@ class SubclassedModel(training.Model): class TestWeightSavingAndLoadingTFFormat(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + def test_keras_optimizer_warning(self): + graph = ops.Graph() + with graph.as_default(), self.test_session(graph): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.Dense(3)) + model.compile(loss='mse', optimizer='adam', metrics=['acc']) + model._make_train_function() + temp_dir = self.get_temp_dir() + prefix = os.path.join(temp_dir, 'ckpt') + with test.mock.patch.object(logging, 'warning') as mock_log: + model.save_weights(prefix) + self.assertRegexpMatches( + str(mock_log.call_args), + 'Keras optimizer') + + @test_util.run_in_graph_and_eager_modes def test_tensorflow_format_overwrite(self): with self.test_session() as session: model = SubclassedModel() @@ -646,18 +760,23 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): self.assertEqual(len(graph.get_operations()), op_count) def _weight_loading_test_template(self, make_model_fn): - with self.test_session() as session: + with self.test_session(): model = make_model_fn() + model.compile( + loss='mse', + optimizer=training_module.RMSPropOptimizer(0.1), + metrics=['acc', keras.metrics.CategoricalAccuracy()]) temp_dir = self.get_temp_dir() prefix = os.path.join(temp_dir, 'ckpt') + train_x = np.random.random((3, 2)) + train_y = np.random.random((3,)) + x = constant_op.constant(train_x, dtype=dtypes.float32) - x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32) - executing_eagerly = context.executing_eagerly() - ref_y_tensor = model(x) - if not executing_eagerly: - session.run([v.initializer for v in model.variables]) - ref_y = self.evaluate(ref_y_tensor) + model.train_on_batch(train_x, train_y) model.save_weights(prefix, save_format='tf') + ref_y_before_train = model.predict(train_x) + model.train_on_batch(train_x, train_y) + ref_y_after_train = model.predict(train_x) for v in model.variables: self.evaluate( v.assign(random_ops.random_normal(shape=array_ops.shape(v)))) @@ -665,18 +784,29 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): self.addCleanup(shutil.rmtree, temp_dir) model.load_weights(prefix) - y = self.evaluate(model(x)) - self.assertAllClose(ref_y, y) + self.assertAllClose(ref_y_before_train, self.evaluate(model(x))) # Test restore-on-create if this is a subclassed Model (graph Networks # will have already created their variables). load_model = make_model_fn() load_model.load_weights(prefix) - restore_on_create_y_tensor = load_model(x) - restore_on_create_y = self.evaluate(restore_on_create_y_tensor) - self.assertAllClose(ref_y, restore_on_create_y) - - @test_util.run_in_graph_and_eager_modes() + self.assertAllClose( + ref_y_before_train, + self.evaluate(load_model(x))) + load_model = make_model_fn() + load_model.load_weights(prefix) + # We need to run some of the restore ops for predict(), but not all + # variables have been created yet (optimizer slot variables). Tests + # incremental restore. + load_model.predict(train_x) + load_model.compile( + loss='mse', + optimizer=training_module.RMSPropOptimizer(0.1), + metrics=['acc', keras.metrics.CategoricalAccuracy()]) + load_model.train_on_batch(train_x, train_y) + self.assertAllClose(ref_y_after_train, self.evaluate(load_model(x))) + + @test_util.run_in_graph_and_eager_modes def test_weight_loading_graph_model(self): def _make_graph_model(): a = keras.layers.Input(shape=(2,)) @@ -686,7 +816,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): self._weight_loading_test_template(_make_graph_model) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_weight_loading_subclassed_model(self): self._weight_loading_test_template(SubclassedModel) @@ -704,6 +834,9 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): session.run([v.initializer for v in model.variables]) ref_y = self.evaluate(ref_y_tensor) model.save_weights(prefix) + self.assertEqual( + prefix, + checkpoint_management.latest_checkpoint(temp_dir)) for v in model.variables: self.evaluate( v.assign(random_ops.random_normal(shape=array_ops.shape(v)))) @@ -720,7 +853,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): y = self.evaluate(model(x)) self.assertAllClose(ref_y, y) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_weight_loading_graph_model_added_layer(self): def _save_graph_model(): a = keras.layers.Input(shape=(2,)) @@ -740,7 +873,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): _save_graph_model, _restore_graph_model, _restore_init_fn) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_weight_loading_graph_model_added_no_weight_layer(self): def _save_graph_model(): a = keras.layers.Input(shape=(2,)) @@ -761,7 +894,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): _save_graph_model, _restore_graph_model, _restore_init_fn) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_weight_loading_subclassed_model_added_layer(self): class SubclassedModelRestore(training.Model): @@ -782,5 +915,6 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase): SubclassedModel, SubclassedModelRestore, _restore_init_fn) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py index 3ca8fdd3260d964442c18bc30c3925f252e8a304..415b15fde1655f43f3c317cfc6e7756859fc9da1 100644 --- a/tensorflow/python/keras/engine/sequential.py +++ b/tensorflow/python/keras/engine/sequential.py @@ -21,14 +21,18 @@ from __future__ import print_function import copy -from tensorflow.python.keras import backend as K +from tensorflow.python.eager import context +from tensorflow.python.framework import ops from tensorflow.python.keras import layers as layer_module from tensorflow.python.keras.engine import base_layer -from tensorflow.python.keras.engine import network from tensorflow.python.keras.engine.input_layer import Input from tensorflow.python.keras.engine.input_layer import InputLayer +from tensorflow.python.keras.engine.network import Network from tensorflow.python.keras.engine.training import Model +from tensorflow.python.keras.utils import layer_utils from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export @@ -91,8 +95,12 @@ class Sequential(Model): ``` """ + @checkpointable.no_automatic_dependency_tracking def __init__(self, layers=None, name=None): super(Sequential, self).__init__(name=name) + self.supports_masking = True + self._build_input_shape = None + self._compute_output_and_mask_jointly = True # Add to the model any layers passed to the constructor. if layers: @@ -104,10 +112,14 @@ class Sequential(Model): # Historically, `sequential.layers` only returns layers that were added # via `add`, and omits the auto-generated `InputLayer` that comes at the # bottom of the stack. - if self._layers and isinstance(self._layers[0], InputLayer): - return self._layers[1:] - return self._layers - + # `CheckpointableBase` manages the `_layers` attributes and does filtering + # over it. + layers = super(Sequential, self).layers + if layers and isinstance(layers[0], InputLayer): + return layers[1:] + return layers[:] + + @checkpointable.no_automatic_dependency_tracking def add(self, layer): """Adds a layer instance on top of the layer stack. @@ -127,32 +139,16 @@ class Sequential(Model): 'an instance of class Layer. ' 'Found: ' + str(layer)) self.built = False + set_inputs = False if not self._layers: - set_inputs = False - # First layer in model: check that it is an input layer. - if not isinstance(layer, InputLayer): - # Create an input tensor and call `layer` on the input tensor. - # First, we need to infer the expected input shape and dtype. - first_layer = layer - if isinstance(layer, (Model, Sequential)): - # We were passed a model as first layer. - # This requires a specific way to figure out the - # input shape and dtype. - if not layer.layers: - raise ValueError('Cannot add an empty model ' - 'to a `Sequential` model.') - # In case of nested models: recover the first layer - # of the deepest model to infer input shape and dtype. - first_layer = layer.layers[0] - while isinstance(first_layer, (Model, Sequential)): - first_layer = first_layer.layers[0] - batch_shape = first_layer._batch_input_shape - dtype = first_layer.dtype - - if hasattr(first_layer, '_batch_input_shape'): - batch_shape = first_layer._batch_input_shape - dtype = first_layer.dtype - # Instantiate the input layer. + if isinstance(layer, InputLayer): + # Corner case where the user passes an InputLayer layer via `add`. + assert len(layer._inbound_nodes[-1].output_tensors) == 1 + set_inputs = True + else: + batch_shape, dtype = get_input_shape_and_dtype(layer) + if batch_shape: + # Instantiate an input layer. x = Input( batch_shape=batch_shape, dtype=dtype, @@ -162,25 +158,20 @@ class Sequential(Model): # to the input layer we just created. layer(x) set_inputs = True - else: - # The layer doesn't know about its expected shape. We will have to - # build the model lazily on `fit`/etc. - batch_shape = None - else: - # Corner case where the user passes an InputLayer layer via `add`. - assert len(layer._inbound_nodes[-1].output_tensors) == 1 - set_inputs = True if set_inputs: + # If an input layer (placeholder) is available. if len(layer._inbound_nodes[-1].output_tensors) != 1: raise ValueError('All layers in a Sequential model ' 'should have a single output tensor. ' 'For multi-output layers, ' 'use the functional API.') - self.outputs = [layer._inbound_nodes[-1].output_tensors[0]] - self.inputs = network.get_source_inputs(self.outputs[0]) + self.inputs = layer_utils.get_source_inputs(self.outputs[0]) + elif self.outputs: + # If the model is being built continuously on top of an input layer: + # refresh its output. output_tensor = layer(self.outputs[0]) if isinstance(output_tensor, list): raise TypeError('All layers in a Sequential model ' @@ -188,11 +179,15 @@ class Sequential(Model): 'For multi-output layers, ' 'use the functional API.') self.outputs = [output_tensor] - if self.inputs: - self.build() + if set_inputs or self._is_graph_network: + self._init_graph_network(self.inputs, self.outputs, name=self.name) + self.built = True else: self._layers.append(layer) + if self._layers: + self._track_layers(self._layers) + @checkpointable.no_automatic_dependency_tracking def pop(self): """Removes the last layer in the model. @@ -203,35 +198,73 @@ class Sequential(Model): raise TypeError('There are no layers in the model.') self._layers.pop() - self.built = False if not self.layers: self.outputs = None self.inputs = None - elif self.outputs: + self.built = False + elif self._is_graph_network: self.layers[-1]._outbound_nodes = [] self.outputs = [self.layers[-1].output] - self.build() + self._init_graph_network(self.inputs, self.outputs, name=self.name) + self.built = True def build(self, input_shape=None): - if input_shape and not self.inputs: - batch_shape = tuple(input_shape) - dtype = K.floatx() - x = Input( - batch_shape=batch_shape, dtype=dtype, name=self.name + '_input') - self.inputs = [x] - for layer in self._layers: - x = layer(x) - self.outputs = [x] - # Make sure that the model's input shape will be preserved during - # serialization. - if self._layers: - self._layers[0]._batch_input_shape = batch_shape - - if self.inputs: + if self._is_graph_network: self._init_graph_network(self.inputs, self.outputs, name=self.name) - self.built = True - if self._layers: - self._track_layers(self._layers) + else: + if input_shape is None: + raise ValueError('You must provide an `input_shape` argument.') + self._build_input_shape = input_shape + shape = input_shape + for layer in self.layers: + if not layer.built: + with ops.name_scope(layer._name_scope()): + layer.build(shape) + layer.built = True + shape = layer.compute_output_shape(shape) + self.built = True + + def call(self, inputs, training=None, mask=None): + if self._is_graph_network: + return super(Sequential, self).call(inputs, training=training, mask=mask) + + outputs, _ = self._call_and_compute_mask( + inputs, training=training, mask=mask) + return outputs + + def _call_and_compute_mask(self, inputs, training=None, mask=None): + if not self.built: + self.build(inputs.shape) + + x = inputs + for layer in self.layers: + kwargs = {} + if 'mask' in tf_inspect.getargspec(layer.call).args: + kwargs['mask'] = mask + if 'training' in tf_inspect.getargspec(layer.call).args: + kwargs['training'] = training + + if isinstance(layer, Network) and layer._compute_output_and_mask_jointly: + x, mask = layer._call_and_compute_mask(x, **kwargs) + else: + x = layer.call(x, **kwargs) + if layer.supports_masking: + mask = layer.compute_mask(x, mask) + else: + mask = None + if not context.executing_eagerly(): + x._keras_mask = mask + return x, mask + + def compute_output_shape(self, input_shape): + shape = input_shape + for layer in self.layers: + shape = layer.compute_output_shape(shape) + return shape + + def compute_mask(self, inputs, mask): + _, mask = self._call_and_compute_mask(inputs, mask=mask) + return mask def predict_proba(self, x, batch_size=32, verbose=0): """Generates class probability predictions for the input samples. @@ -276,18 +309,69 @@ class Sequential(Model): return (proba > 0.5).astype('int32') def get_config(self): - config = [] + layer_configs = [] for layer in self.layers: - config.append({ + layer_configs.append({ 'class_name': layer.__class__.__name__, 'config': layer.get_config() }) - return copy.deepcopy(config) + config = { + 'name': self.name, + 'layers': copy.deepcopy(layer_configs) + } + if self._build_input_shape: + config['build_input_shape'] = self._build_input_shape + return config @classmethod def from_config(cls, config, custom_objects=None): - model = cls() - for conf in config: - layer = layer_module.deserialize(conf, custom_objects=custom_objects) + if 'name' in config: + name = config['name'] + build_input_shape = config.get('build_input_shape') + layer_configs = config['layers'] + else: + name = None + build_input_shape = None + model = cls(name=name) + for layer_config in layer_configs: + layer = layer_module.deserialize(layer_config, + custom_objects=custom_objects) model.add(layer) + if not model.inputs and build_input_shape: + model.build(build_input_shape) return model + + +def get_input_shape_and_dtype(layer): + """Retrieve input shape and input dtype of layer if applicable. + + Args: + layer: Layer (or model) instance. + + Returns: + Tuple (input_shape, input_dtype). Both could be None if the layer + does not have a defined input shape. + + Raises: + ValueError: in case an empty Sequential or Graph Network is passed. + """ + if ((isinstance(layer, Model) and layer._is_graph_network) + or isinstance(layer, Sequential)): + # We were passed a model as first layer. + # This requires a specific way to figure out the + # input shape and dtype. + if not layer.layers: + raise ValueError('Cannot add an empty model ' + 'to a `Sequential` model.') + # In case of nested models: recover the first layer + # of the deepest model to infer input shape and dtype. + layer = layer.layers[0] + while ((isinstance(layer, Model) and layer._is_graph_network) + or isinstance(layer, Sequential)): + layer = layer.layers[0] + + if hasattr(layer, '_batch_input_shape'): + batch_shape = layer._batch_input_shape + dtype = layer.dtype + return batch_shape, dtype + return None, None diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py index cdaf9162de8c5318c7e092077d409f05a7edc717..3f8e120df056bad06e432654e2743e32bb271661 100644 --- a/tensorflow/python/keras/engine/sequential_test.py +++ b/tensorflow/python/keras/engine/sequential_test.py @@ -18,22 +18,34 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import test from tensorflow.python.training import rmsprop -class TestSequential(test.TestCase): +def _get_small_mlp(num_hidden, num_classes, input_dim=None): + model = keras.models.Sequential() + if input_dim: + model.add(keras.layers.Dense(num_hidden, activation='relu', + input_dim=input_dim)) + else: + model.add(keras.layers.Dense(num_hidden, activation='relu')) + model.add(keras.layers.Dense(num_classes, activation='softmax')) + return model + + +class TestSequential(test.TestCase, parameterized.TestCase): """Most Sequential model API tests are covered in `training_test.py`. """ - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_basic_methods(self): model = keras.models.Sequential() model.add(keras.layers.Dense(1, input_dim=2)) @@ -44,16 +56,14 @@ class TestSequential(test.TestCase): self.assertEqual(len(model.weights), 2 * 2) self.assertEqual(model.get_layer(name='dp').name, 'dp') - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_sequential_pop(self): num_hidden = 5 input_dim = 3 batch_size = 5 num_classes = 2 - model = keras.models.Sequential() - model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) - model.add(keras.layers.Dense(num_classes)) + model = _get_small_mlp(num_hidden, num_classes, input_dim) model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3)) x = np.random.random((batch_size, input_dim)) y = np.random.random((batch_size, num_classes)) @@ -77,18 +87,18 @@ class TestSequential(test.TestCase): with self.assertRaises(TypeError): model.pop() - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_sequential_deferred_build_with_np_arrays(self): num_hidden = 5 input_dim = 3 batch_size = 5 num_classes = 2 - model = keras.models.Sequential() - # We don't specify the input shape. - model.add(keras.layers.Dense(num_hidden)) - model.add(keras.layers.Dense(num_classes)) - model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3)) + model = _get_small_mlp(num_hidden, num_classes) + model.compile( + loss='mse', + optimizer=rmsprop.RMSPropOptimizer(1e-3), + metrics=[keras.metrics.CategoricalAccuracy()]) self.assertEqual(len(model.layers), 2) self.assertEqual(len(model.weights), 0) self.assertFalse(model.built) @@ -97,27 +107,22 @@ class TestSequential(test.TestCase): y = np.random.random((batch_size, num_classes)) model.fit(x, y, epochs=1) self.assertTrue(model.built) - self.assertEqual(model.inputs[0].get_shape().as_list(), [None, input_dim]) - self.assertEqual(model.outputs[0].get_shape().as_list(), - [None, num_classes]) + self.assertFalse(model._is_graph_network) self.assertEqual(len(model.weights), 2 * 2) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_sequential_deferred_build_with_dataset_iterators(self): - if not context.executing_eagerly(): - # TODO(psv/fchollet): Add support for this use case in graph mode. - return num_hidden = 5 input_dim = 3 num_classes = 2 num_samples = 50 steps_per_epoch = 10 - model = keras.models.Sequential() - # We don't specify the input shape. - model.add(keras.layers.Dense(num_hidden)) - model.add(keras.layers.Dense(num_classes)) - model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3)) + model = _get_small_mlp(num_hidden, num_classes) + model.compile( + loss='mse', + optimizer=rmsprop.RMSPropOptimizer(1e-3), + metrics=[keras.metrics.CategoricalAccuracy()]) self.assertEqual(len(model.layers), 2) self.assertEqual(len(model.weights), 0) self.assertFalse(model.built) @@ -131,12 +136,53 @@ class TestSequential(test.TestCase): model.fit(iterator, epochs=1, steps_per_epoch=steps_per_epoch) self.assertTrue(model.built) - self.assertEqual(model.inputs[0].get_shape().as_list(), [None, input_dim]) - self.assertEqual(model.outputs[0].get_shape().as_list(), - [None, num_classes]) self.assertEqual(len(model.weights), 2 * 2) + self.assertFalse(model._is_graph_network) + + @parameterized.parameters((True,), (False,)) + def test_training_and_eval_methods_on_symbolic_tensors(self, deferred): + with self.test_session(): - @tf_test_util.run_in_graph_and_eager_modes() + def get_model(): + if deferred: + model = _get_small_mlp(10, 4) + else: + model = _get_small_mlp(10, 4, input_dim=3) + model.compile( + optimizer=rmsprop.RMSPropOptimizer(1e-3), + loss='categorical_crossentropy', + metrics=['accuracy']) + return model + + inputs = keras.backend.zeros(shape=(10, 3)) + targets = keras.backend.zeros(shape=(10, 4)) + + model = get_model() + model.fit(inputs, targets, epochs=10, steps_per_epoch=30) + + model = get_model() + model.evaluate(inputs, targets, steps=2, verbose=0) + + model = get_model() + model.predict(inputs, steps=2) + + model = get_model() + model.train_on_batch(inputs, targets) + + model = get_model() + model.test_on_batch(inputs, targets) + + model = get_model() + model.fit( + inputs, + targets, + epochs=1, + steps_per_epoch=2, + verbose=0, + validation_data=(inputs, targets), + validation_steps=2) + + @tf_test_util.run_in_graph_and_eager_modes def test_invalid_use_cases(self): # Added objects must be layer instances with self.assertRaises(TypeError): @@ -160,7 +206,7 @@ class TestSequential(test.TestCase): model.add(keras.layers.Dense(1, input_dim=1)) model.add(MyLayer()) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_nested_sequential_trainability(self): input_dim = 20 num_units = 10 @@ -209,17 +255,18 @@ class TestSequential(test.TestCase): x2 = model.predict(val_a) assert np.abs(np.sum(x1 - x2)) > 1e-5 + @tf_test_util.run_in_graph_and_eager_modes def test_sequential_deferred_build_serialization(self): num_hidden = 5 input_dim = 3 batch_size = 5 num_classes = 2 - model = keras.models.Sequential() - # We don't specify the input shape. - model.add(keras.layers.Dense(num_hidden)) - model.add(keras.layers.Dense(num_classes)) - model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3)) + model = _get_small_mlp(num_hidden, num_classes) + model.compile( + loss='mse', + optimizer=rmsprop.RMSPropOptimizer(1e-3), + metrics=[keras.metrics.CategoricalAccuracy()]) self.assertFalse(model.built) x = np.random.random((batch_size, input_dim)) @@ -228,11 +275,93 @@ class TestSequential(test.TestCase): self.assertTrue(model.built) config = model.get_config() + self.assertIn('build_input_shape', config) + new_model = keras.models.Sequential.from_config(config) self.assertTrue(new_model.built) self.assertEqual(len(model.layers), 2) self.assertEqual(len(model.weights), 4) + @tf_test_util.run_in_graph_and_eager_modes + def test_sequential_shape_inference_deferred(self): + model = _get_small_mlp(4, 5) + output_shape = model.compute_output_shape((None, 7)) + self.assertEqual(tuple(output_shape.as_list()), (None, 5)) + + @tf_test_util.run_in_graph_and_eager_modes + def test_sequential_build_deferred(self): + model = _get_small_mlp(4, 5) + + model.build((None, 10)) + self.assertTrue(model.built) + self.assertEqual(len(model.weights), 4) + + # Test with nested model + model = _get_small_mlp(4, 3) + inner_model = _get_small_mlp(4, 5) + model.add(inner_model) + + model.build((None, 10)) + self.assertTrue(model.built) + self.assertTrue(model.layers[-1].built) + self.assertEqual(len(model.weights), 8) + + @tf_test_util.run_in_graph_and_eager_modes + def test_sequential_nesting(self): + model = _get_small_mlp(4, 3) + inner_model = _get_small_mlp(4, 5) + model.add(inner_model) + + model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3)) + x = np.random.random((2, 6)) + y = np.random.random((2, 5)) + model.fit(x, y, epochs=1) + + @tf_test_util.run_in_graph_and_eager_modes + def test_variable_names(self): + model = keras.models.Sequential([keras.layers.Dense(3)]) + model.add(keras.layers.Dense(2)) + model(array_ops.ones([2, 4])) + self.assertEqual( + ['sequential/dense/kernel:0', 'sequential/dense/bias:0', + 'sequential/dense_1/kernel:0', 'sequential/dense_1/bias:0'], + [v.name for v in model.variables]) + + +class TestSequentialEagerIntegration(test.TestCase): + + @tf_test_util.run_in_graph_and_eager_modes + def test_defun_on_call(self): + # Check that one can subclass Sequential and place the `call` in a `defun`. + + class MySequential(keras.Sequential): + + def __init__(self, name=None): + super(MySequential, self).__init__(name=name) + self.call = function.defun(self.call) + + model = MySequential() + model.add(keras.layers.Dense(4, activation='relu')) + model.add(keras.layers.Dense(5, activation='softmax')) + + model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3)) + + x = np.random.random((2, 6)) + y = np.random.random((2, 5)) + model.fit(x, y, epochs=1) + + @tf_test_util.run_in_graph_and_eager_modes + def test_build_before_fit(self): + # Fix for b/112433577 + model = _get_small_mlp(4, 5) + model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3)) + + model.build((None, 6)) + + x = np.random.random((2, 6)) + y = np.random.random((2, 5)) + model.fit(x, y, epochs=1) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py index 183e26e8bf813ec0a8c84920a93dcb79a291ca9d..079c8dae71153e597d8be119a685085864fcae83 100644 --- a/tensorflow/python/keras/engine/topology_test.py +++ b/tensorflow/python/keras/engine/topology_test.py @@ -24,8 +24,11 @@ from tensorflow.python import keras from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.engine import input_layer as input_layer_lib +from tensorflow.python.keras.engine import network as network_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops @@ -62,7 +65,7 @@ class TopologyConstructionTest(test.TestCase): inputs=True) return inputs + 1 - x1 = keras.Input(shape=(1,)) + x1 = input_layer_lib.Input(shape=(1,)) layer = MyLayer() _ = layer.apply(x1) @@ -70,7 +73,7 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(len(layer.get_updates_for(x1)), 1) self.assertEqual(len(layer.get_updates_for(None)), 1) - x2 = keras.Input(shape=(1,)) + x2 = input_layer_lib.Input(shape=(1,)) y2 = layer.apply(x2) self.assertEqual(len(layer.updates), 3) @@ -78,17 +81,17 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(len(layer.get_updates_for(x2)), 1) self.assertEqual(len(layer.get_updates_for(None)), 1) - network = keras.engine.Network(x2, y2) + network = network_lib.Network(x2, y2) self.assertEqual(len(network.updates), 2) self.assertEqual(len(network.get_updates_for(x1)), 0) self.assertEqual(len(network.get_updates_for(x2)), 1) self.assertEqual(len(network.get_updates_for(None)), 1) - x3 = keras.Input(shape=(1,)) + x3 = input_layer_lib.Input(shape=(1,)) _ = layer.apply(x3) self.assertEqual(len(network.updates), 2) - x4 = keras.Input(shape=(1,)) + x4 = input_layer_lib.Input(shape=(1,)) _ = network(x4) self.assertEqual(len(network.updates), 3) self.assertEqual(len(network.get_updates_for(x2)), 1) @@ -104,11 +107,10 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(len(network.get_updates_for(x4)), 2) def test_get_updates_bn(self): - x1 = keras.Input(shape=(1,)) + x1 = input_layer_lib.Input(shape=(1,)) layer = keras.layers.BatchNormalization() _ = layer.apply(x1) - print('BN updates', layer._updates) self.assertEqual(len(layer.updates), 2) self.assertEqual(len(layer.get_updates_for(x1)), 2) self.assertEqual(len(layer.get_updates_for(None)), 0) @@ -134,7 +136,7 @@ class TopologyConstructionTest(test.TestCase): inputs=True) return inputs + 1 - x1 = keras.Input(shape=(1,)) + x1 = input_layer_lib.Input(shape=(1,)) layer = MyLayer() _ = layer.apply(x1) @@ -142,7 +144,7 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(len(layer.get_losses_for(x1)), 1) self.assertEqual(len(layer.get_losses_for(None)), 1) - x2 = keras.Input(shape=(1,)) + x2 = input_layer_lib.Input(shape=(1,)) y2 = layer.apply(x2) self.assertEqual(len(layer.losses), 3) @@ -150,17 +152,17 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(len(layer.get_losses_for(x2)), 1) self.assertEqual(len(layer.get_losses_for(None)), 1) - network = keras.engine.Network(x2, y2) + network = network_lib.Network(x2, y2) self.assertEqual(len(network.losses), 2) self.assertEqual(len(network.get_losses_for(x1)), 0) self.assertEqual(len(network.get_losses_for(x2)), 1) self.assertEqual(len(network.get_losses_for(None)), 1) - x3 = keras.Input(shape=(1,)) + x3 = input_layer_lib.Input(shape=(1,)) _ = layer.apply(x3) self.assertEqual(len(network.losses), 2) - x4 = keras.Input(shape=(1,)) + x4 = input_layer_lib.Input(shape=(1,)) _ = network(x4) self.assertEqual(len(network.losses), 3) self.assertEqual(len(network.get_losses_for(x2)), 1) @@ -177,8 +179,8 @@ class TopologyConstructionTest(test.TestCase): def testTopologicalAttributes(self): # test layer attributes / methods related to cross-layer connectivity. - a = keras.Input(shape=(32,), name='input_a') - b = keras.Input(shape=(32,), name='input_b') + a = input_layer_lib.Input(shape=(32,), name='input_a') + b = input_layer_lib.Input(shape=(32,), name='input_b') # test input, output, input_shape, output_shape test_layer = keras.layers.Dense(16, name='test_layer') @@ -219,15 +221,15 @@ class TopologyConstructionTest(test.TestCase): _ = new_dense.input_shape with self.assertRaises(AttributeError): new_dense = keras.layers.Dense(16) - a = keras.Input(shape=(3, 32)) - a = keras.Input(shape=(5, 32)) + a = input_layer_lib.Input(shape=(3, 32)) + a = input_layer_lib.Input(shape=(5, 32)) a_2 = dense(a) b_2 = dense(b) _ = new_dense.input_shape with self.assertRaises(AttributeError): new_dense = keras.layers.Dense(16) - a = keras.Input(shape=(3, 32)) - a = keras.Input(shape=(5, 32)) + a = input_layer_lib.Input(shape=(3, 32)) + a = input_layer_lib.Input(shape=(5, 32)) a_2 = dense(a) b_2 = dense(b) _ = new_dense.output_shape @@ -239,7 +241,7 @@ class TopologyConstructionTest(test.TestCase): def call(self, inputs): return [inputs**2, inputs**3] - x = keras.Input(shape=(32,)) + x = input_layer_lib.Input(shape=(32,)) test_layer = PowersLayer() p1, p2 = test_layer(x) # pylint: disable=not-callable @@ -256,8 +258,8 @@ class TopologyConstructionTest(test.TestCase): assert len(inputs) == 2 return inputs[0] + inputs[1] - a = keras.Input(shape=(32,)) - b = keras.Input(shape=(32,)) + a = input_layer_lib.Input(shape=(32,)) + b = input_layer_lib.Input(shape=(32,)) test_layer = AddLayer() y = test_layer([a, b]) # pylint: disable=not-callable @@ -268,10 +270,10 @@ class TopologyConstructionTest(test.TestCase): def testBasicNetwork(self): # minimum viable network - x = keras.Input(shape=(32,)) + x = input_layer_lib.Input(shape=(32,)) dense = keras.layers.Dense(2) y = dense(x) - network = keras.engine.Network(x, y, name='dense_network') + network = network_lib.Network(x, y, name='dense_network') # test basic attributes self.assertEqual(network.name, 'dense_network') @@ -282,7 +284,7 @@ class TopologyConstructionTest(test.TestCase): self.assertEqual(network.non_trainable_weights, dense.non_trainable_weights) # test callability on Input - x_2 = keras.Input(shape=(32,)) + x_2 = input_layer_lib.Input(shape=(32,)) y_2 = network(x_2) self.assertEqual(y_2.get_shape().as_list(), [None, 2]) @@ -506,7 +508,7 @@ class TopologyConstructionTest(test.TestCase): self.assertListEqual([x.shape for x in fn_outputs], [(10, 64), (10, 5)]) # test get_source_inputs - self.assertListEqual(keras.engine.network.get_source_inputs(c), [a, b]) + self.assertListEqual(keras.engine.get_source_inputs(c), [a, b]) # serialization / deserialization json_config = model.to_json() @@ -778,12 +780,12 @@ class TopologyConstructionTest(test.TestCase): self.evaluate(getattr(b, '_keras_mask'))) self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b)) else: - x = keras.Input(shape=(32,)) + x = input_layer_lib.Input(shape=(32,)) y = MaskedLayer()(x) # pylint: disable=not-callable - network = keras.engine.Network(x, y) + network = network_lib.Network(x, y) # test callability on Input - x_2 = keras.Input(shape=(32,)) + x_2 = input_layer_lib.Input(shape=(32,)) y_2 = network(x_2) self.assertEqual(y_2.get_shape().as_list(), [None, 32]) @@ -797,14 +799,14 @@ class TopologyConstructionTest(test.TestCase): def reg(x): return math_ops.reduce_sum(x) - net_a_input = keras.Input((2,)) + net_a_input = input_layer_lib.Input((2,)) net_a = net_a_input net_a = keras.layers.Dense(2, kernel_initializer='ones', use_bias=False, activity_regularizer=reg)(net_a) model_a = keras.Model([net_a_input], [net_a]) - net_b_input = keras.Input((2,)) + net_b_input = input_layer_lib.Input((2,)) net_b = model_a(net_b_input) model_b = keras.Model([net_b_input], [net_b]) @@ -817,7 +819,7 @@ class TopologyConstructionTest(test.TestCase): with self.test_session(): x_val = np.random.random((10, 5)) - x = keras.Input(shape=(5,)) + x = input_layer_lib.Input(shape=(5,)) a = keras.layers.Dense(5, name='A') b = keras.layers.Dense(5, name='B') output = a(b(a(b(x)))) @@ -837,7 +839,7 @@ class TopologyConstructionTest(test.TestCase): def test_layer_sharing_at_heterogenous_depth_with_concat(self): with self.test_session(): input_shape = (16, 9, 3) - input_layer = keras.Input(shape=input_shape) + input_layer = input_layer_lib.Input(shape=input_shape) a = keras.layers.Dense(3, name='dense_A') b = keras.layers.Dense(3, name='dense_B') @@ -924,7 +926,7 @@ class DeferredModeTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testSimpleNetworkBuilding(self): - inputs = keras.engine.Input(shape=(32,)) + inputs = input_layer_lib.Input(shape=(32,)) if context.executing_eagerly(): self.assertIsInstance(inputs, base_layer.DeferredTensor) self.assertEqual(inputs.dtype.name, 'float32') @@ -937,8 +939,8 @@ class DeferredModeTest(test.TestCase): self.assertEqual(x.shape.as_list(), [None, 2]) outputs = keras.layers.Dense(4)(x) - network = keras.engine.Network(inputs, outputs) - self.assertIsInstance(network, keras.engine.Network) + network = network_lib.Network(inputs, outputs) + self.assertIsInstance(network, network_lib.Network) if context.executing_eagerly(): # It should be possible to call such a network on EagerTensors. @@ -949,8 +951,8 @@ class DeferredModeTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testMultiIONetworkbuilding(self): - input_a = keras.engine.Input(shape=(32,)) - input_b = keras.engine.Input(shape=(16,)) + input_a = input_layer_lib.Input(shape=(32,)) + input_b = input_layer_lib.Input(shape=(16,)) a = keras.layers.Dense(16)(input_a) class AddLayer(keras.layers.Layer): @@ -958,13 +960,10 @@ class DeferredModeTest(test.TestCase): def call(self, inputs): return inputs[0] + inputs[1] - def compute_output_shape(self, input_shape): - return input_shape[0] - c = AddLayer()([a, input_b]) # pylint: disable=not-callable c = keras.layers.Dense(2)(c) - network = keras.engine.Network([input_a, input_b], [a, c]) + network = network_lib.Network([input_a, input_b], [a, c]) if context.executing_eagerly(): a_val = constant_op.constant( np.random.random((10, 32)).astype('float32')) @@ -976,6 +975,196 @@ class DeferredModeTest(test.TestCase): self.assertEqual(outputs[1].shape.as_list(), [10, 2]) +class DefaultShapeInferenceBehaviorTest(test.TestCase): + + def _testShapeInference(self, model, input_shape, expected_output_shape): + input_value = np.random.random(input_shape) + output_value = model.predict(input_value) + self.assertEqual(output_value.shape, expected_output_shape) + + @test_util.run_in_graph_and_eager_modes() + def testSingleInputCase(self): + + class LayerWithOneInput(keras.layers.Layer): + + def build(self, input_shape): + self.w = array_ops.ones(shape=(3, 4)) + + def call(self, inputs): + return keras.backend.dot(inputs, self.w) + + inputs = input_layer_lib.Input(shape=(3,)) + layer = LayerWithOneInput() + + if context.executing_eagerly(): + self.assertEqual( + layer.compute_output_shape((None, 3)).as_list(), [None, 4]) + # As a side-effect, compute_output_shape builds the layer. + self.assertTrue(layer.built) + # We can still query the layer's compute_output_shape with compatible + # input shapes. + self.assertEqual( + layer.compute_output_shape((6, 3)).as_list(), [6, 4]) + + outputs = layer(inputs) + model = keras.Model(inputs, outputs) + self._testShapeInference(model, (2, 3), (2, 4)) + + @test_util.run_in_graph_and_eager_modes() + def testMultiInputOutputCase(self): + + class MultiInputOutputLayer(keras.layers.Layer): + + def build(self, input_shape): + self.w = array_ops.ones(shape=(3, 4)) + + def call(self, inputs): + a = keras.backend.dot(inputs[0], self.w) + b = a + inputs[1] + return [a, b] + + input_a = input_layer_lib.Input(shape=(3,)) + input_b = input_layer_lib.Input(shape=(4,)) + output_a, output_b = MultiInputOutputLayer()([input_a, input_b]) + model = keras.Model([input_a, input_b], [output_a, output_b]) + output_a_val, output_b_val = model.predict( + [np.random.random((2, 3)), np.random.random((2, 4))]) + self.assertEqual(output_a_val.shape, (2, 4)) + self.assertEqual(output_b_val.shape, (2, 4)) + + @test_util.run_in_graph_and_eager_modes() + def testTrainingArgument(self): + + class LayerWithTrainingArg(keras.layers.Layer): + + def build(self, input_shape): + self.w = array_ops.ones(shape=(3, 4)) + + def call(self, inputs, training): + return keras.backend.dot(inputs, self.w) + + inputs = input_layer_lib.Input(shape=(3,)) + outputs = LayerWithTrainingArg()(inputs, training=False) + model = keras.Model(inputs, outputs) + self._testShapeInference(model, (2, 3), (2, 4)) + + @test_util.run_in_graph_and_eager_modes() + def testUnsupportedSignature(self): + + class LayerWithAdditionalArg(keras.layers.Layer): + + def build(self, input_shape): + self.w = array_ops.ones(shape=(3, 4)) + + def call(self, inputs, some_arg): + return keras.backend.dot(inputs, self.w) + some_arg + + inputs = input_layer_lib.Input(shape=(3,)) + if context.executing_eagerly(): + with self.assertRaises(NotImplementedError): + outputs = LayerWithAdditionalArg()(inputs, some_arg=0) + else: + # Works with graph mode because the graph of ops is built together with + # the graph of layers. + outputs = LayerWithAdditionalArg()(inputs, some_arg=0) + _ = keras.Model(inputs, outputs) + + @test_util.run_in_graph_and_eager_modes() + def testNoneInShape(self): + + class Model(keras.Model): + + def __init__(self): + super(Model, self).__init__() + self.conv1 = keras.layers.Conv2D(8, 3) + self.pool = keras.layers.GlobalAveragePooling2D() + self.fc = keras.layers.Dense(3) + + def call(self, x): + x = self.conv1(x) + x = self.pool(x) + x = self.fc(x) + return x + + model = Model() + model.build(tensor_shape.TensorShape((None, None, None, 1))) + self.assertTrue(model.built, 'Model should be built') + self.assertTrue(model.weights, + 'Model should have its weights created as it ' + 'has been built') + sample_input = array_ops.ones((1, 10, 10, 1)) + output = model(sample_input) + self.assertEqual(output.shape, (1, 3)) + + @test_util.run_in_graph_and_eager_modes() + def testNoneInShapeWithCompoundModel(self): + + class BasicBlock(keras.Model): + + def __init__(self): + super(BasicBlock, self).__init__() + self.conv1 = keras.layers.Conv2D(8, 3) + self.pool = keras.layers.GlobalAveragePooling2D() + self.dense = keras.layers.Dense(3) + + def call(self, x): + x = self.conv1(x) + x = self.pool(x) + x = self.dense(x) + return x + + class CompoundModel(keras.Model): + + def __init__(self): + super(CompoundModel, self).__init__() + self.block = BasicBlock() + + def call(self, x): + x = self.block(x) # pylint: disable=not-callable + return x + + model = CompoundModel() + model.build(tensor_shape.TensorShape((None, None, None, 1))) + self.assertTrue(model.built, 'Model should be built') + self.assertTrue(model.weights, + 'Model should have its weights created as it ' + 'has been built') + sample_input = array_ops.ones((1, 10, 10, 1)) + output = model(sample_input) # pylint: disable=not-callable + self.assertEqual(output.shape, (1, 3)) + + @test_util.run_in_graph_and_eager_modes() + def testNoneInShapeWithFunctinalAPI(self): + + class BasicBlock(keras.Model): + # Inherting from keras.layers.Layer since we are calling this layer + # inside a model created using functional API. + + def __init__(self): + super(BasicBlock, self).__init__() + self.conv1 = keras.layers.Conv2D(8, 3) + + def call(self, x): + x = self.conv1(x) + return x + + input_layer = keras.layers.Input(shape=(None, None, 1)) + x = BasicBlock()(input_layer) + x = keras.layers.GlobalAveragePooling2D()(x) + output_layer = keras.layers.Dense(3)(x) + + model = keras.Model(inputs=input_layer, outputs=output_layer) + + model.build(tensor_shape.TensorShape((None, None, None, 1))) + self.assertTrue(model.built, 'Model should be built') + self.assertTrue(model.weights, + 'Model should have its weights created as it ' + 'has been built') + sample_input = array_ops.ones((1, 10, 10, 1)) + output = model(sample_input) + self.assertEqual(output.shape, (1, 3)) + + class GraphUtilsTest(test.TestCase): def testGetReachableFromInputs(self): diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index fce6cbdb7a09d5a4daa70d0483732735e70c28b7..f71388cadb94d8215dddb8bc4a7cb2d38d7823a4 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -24,7 +24,6 @@ import numpy as np from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util @@ -33,15 +32,19 @@ from tensorflow.python.keras import losses from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import optimizers from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.engine import distributed_training_utils from tensorflow.python.keras.engine import training_arrays +from tensorflow.python.keras.engine import training_distributed from tensorflow.python.keras.engine import training_eager from tensorflow.python.keras.engine import training_generator from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.engine.network import Network from tensorflow.python.keras.utils.generic_utils import slice_arrays -from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import optimizer as tf_optimizer_module +from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util.tf_export import tf_export @@ -74,6 +77,7 @@ class Model(Network): class MyModel(tf.keras.Model): def __init__(self): + super(MyModel, self).__init__() self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) @@ -94,6 +98,7 @@ class Model(Network): class MyModel(tf.keras.Model): def __init__(self): + super(MyModel, self).__init__() self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) self.dropout = tf.keras.layers.Dropout(0.5) @@ -114,7 +119,190 @@ class Model(Network): self._iterator_get_next = weakref.WeakKeyDictionary() # Create a cache for dataset - uninitialized iterators self._dataset_iterator_cache = weakref.WeakKeyDictionary() + # initializing _distribution_strategy here since it is possible to call + # predict on a model without compiling it. + self._distribution_strategy = None + + def _set_sample_weight_attributes(self, sample_weight_mode, + skip_target_weighing_indices): + """Sets sample weight related attributes on the model.""" + sample_weights, sample_weight_modes = training_utils.prepare_sample_weights( + self.output_names, sample_weight_mode, skip_target_weighing_indices) + self.sample_weights = sample_weights + self.sample_weight_modes = sample_weight_modes + self._feed_sample_weight_modes = [ + sample_weight_modes[i] + for i in range(len(self.outputs)) + if i not in skip_target_weighing_indices + ] + self._feed_sample_weights = [ + sample_weights[i] + for i in range(len(sample_weights)) + if i not in skip_target_weighing_indices + ] + + def _get_metric_name(self, metric, output_index, weighted=False): + """Returns the metric name corresponding to the given metric input. + + Arguments: + metric: Metric function name or reference. + output_index: Index of the current output. + weighted: Boolean indicating if the given metric is weighted. + + Returns: + A metric name. + """ + metric_name_prefix = 'weighted_' if weighted else '' + if metric in ('accuracy', 'acc', 'crossentropy', 'ce'): + if metric in ('accuracy', 'acc'): + suffix = 'acc' + elif metric in ('crossentropy', 'ce'): + suffix = 'ce' + else: + metric_fn = metrics_module.get(metric) + # Get metric name as string + if hasattr(metric_fn, 'name'): + suffix = metric_fn.name + else: + suffix = metric_fn.__name__ + metric_name = metric_name_prefix + suffix + + if len(self.output_names) > 1: + metric_name = '%s_%s' % (self.output_names[output_index], metric_name) + j = 1 + base_metric_name = metric_name + while metric_name in self.metrics_names: + metric_name = '%s_%d' % (base_metric_name, j) + j += 1 + + return metric_name + + def _handle_per_output_metrics(self, + metrics, + y_true, + y_pred, + output_index, + output_shape, + loss_fn, + mask, + weights=None): + """Calls metric functions and sets metric attributes for a single output. + + Arguments: + metrics: List of metrics. + y_true: Target output. + y_pred: Predicted output. + output_index: Index of the current output. + output_shape: Shape of the current output. + loss_fn: Loss function corresponding to the current output. + mask: Computed mask value for the current output. + weights: Weights to be applied on the current output. + + Returns: + A list of metric result tensors. + """ + metric_results = [] + for metric in metrics: + metric_fn = training_utils.get_metric_function( + metric, output_shape=output_shape, loss_fn=loss_fn) + metric_name = self._get_metric_name( + metric, output_index, weighted=weights is not None) + + with K.name_scope(metric_name): + # If both outputs and targets are available, call the metric function. + if y_true is not None and y_pred is not None: + if isinstance(metric_fn, metrics_module.Metric): + # Call the stateful metric function. + if mask is not None: + mask = math_ops.cast(mask, y_pred.dtype) + # Update weights with mask. + if weights is None: + weights = mask + else: + # Update shape of weights if possible before adding mask. + # Update dimensions of weights to match with mask if possible. + mask, _, weights = metrics_module.squeeze_or_expand_dimensions( + mask, None, weights) + try: + # Broadcast weights if possible. + weights = weights_broadcast_ops.broadcast_weights( + weights, mask) + except ValueError: + pass + # TODO(psv): Handle case when mask and weight shapes are not + # compatible. + weights *= mask + + metric_result = metric_fn(y_true, y_pred, weights) + else: + # Call the stateless metric function. + weighted_metric_fn = training_utils.weighted_masked_objective( + metric_fn) + metric_result = weighted_metric_fn( + y_true, y_pred, weights=weights, mask=mask) + + if not context.executing_eagerly(): + # Keep track of metric result tensor. + self.metrics_tensors.append(metric_result) + metric_results.append(metric_result) + + # Keep track of metric name. + self.metrics_names.append(metric_name) + + # Keep track of stateful metric attributes (name and metric function). + if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful: + self.stateful_metric_names.append(metric_name) + self.stateful_metric_functions.append(metric_fn) + if not context.executing_eagerly(): + # Keep track of updates created by stateful metrics. + self.metrics_updates += metric_fn.updates + return metric_results + + def _handle_metrics(self, + outputs, + skip_target_indices=None, + targets=None, + sample_weights=None, + masks=None): + """Handles calling metric functions and setting model metric attributes. + + Arguments: + outputs: List of outputs (predictions). + skip_target_indices: Optional. List of target ids to skip. + targets: List of targets. + sample_weights: Optional list of sample weight arrays. + masks: List of computed output mask values. + Returns: + A list of metric result tensors. + """ + skip_target_indices = skip_target_indices or [] + metric_results = [] + with K.name_scope('metrics'): + for i in range(len(outputs)): + if i in skip_target_indices: + continue + output = outputs[i] if outputs else None + target = targets[i] if targets else None + output_shape = None if output is None else output.get_shape().as_list() + output_mask = masks[i] if masks else None + metric_results.extend( + self._handle_per_output_metrics( + self.nested_metrics[i], target, output, i, output_shape, + self.loss_functions[i], output_mask)) + metric_results.extend( + self._handle_per_output_metrics( + self.nested_weighted_metrics[i], + target, + output, + i, + output_shape, + self.loss_functions[i], + output_mask, + weights=sample_weights[i])) + return metric_results + + @checkpointable.no_automatic_dependency_tracking def compile(self, optimizer, loss=None, @@ -123,14 +311,15 @@ class Model(Network): sample_weight_mode=None, weighted_metrics=None, target_tensors=None, + distribute=None, **kwargs): """Configures the model for training. Arguments: optimizer: String (name of optimizer) or optimizer instance. - See [optimizers](/optimizers). + See [optimizers](/api_docs/python/tf/keras/optimizers). loss: String (name of objective function) or objective function. - See [losses](/losses). + See [losses](/api_docs/python/tf/losses). If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses. The loss value that will be minimized by the model @@ -166,31 +355,70 @@ class Model(Network): can specify them via the `target_tensors` argument. It can be a single tensor (for a single-output model), a list of tensors, or a dict mapping output names to target tensors. + distribute: The DistributionStrategy instance that we want to use to + distribute the training of the model. **kwargs: These arguments are passed to `tf.Session.run`. Raises: ValueError: In case of invalid arguments for `optimizer`, `loss`, `metrics` or `sample_weight_mode`. """ + # Validate that arguments passed by the user to `compile` are supported by + # DistributionStrategy. + if distribute and not isinstance( + optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)): + raise NotImplementedError('Only TF native optimizers are supported with ' + 'DistributionStrategy.') + if distribute and context.executing_eagerly(): + raise NotImplementedError('DistributionStrategy is not supported in ' + 'Eager mode.') + if distribute and sample_weight_mode: + raise NotImplementedError('sample_weight_mode is not supported with ' + 'DistributionStrategy.') + if distribute and weighted_metrics: + raise NotImplementedError('weighted_metrics is not supported with ' + 'DistributionStrategy.') + if distribute and target_tensors: + raise ValueError('target_tensors is not supported with ' + 'DistributionStrategy.') + loss = loss or {} if context.executing_eagerly() and not isinstance( optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)): raise ValueError('Only TF native optimizers are supported in Eager mode.') self.optimizer = optimizers.get(optimizer) + # We've disabled automatic dependency tracking for this method, but do want + # to add a checkpoint dependency on the optimizer if it's checkpointable. + if isinstance(self.optimizer, checkpointable.CheckpointableBase): + self._track_checkpointable( + self.optimizer, name='optimizer', overwrite=True) self.loss = loss self.metrics = metrics or [] self.loss_weights = loss_weights - if context.executing_eagerly() and sample_weight_mode is not None: - raise ValueError('sample_weight_mode is not supported in Eager mode.') self.sample_weight_mode = sample_weight_mode - if context.executing_eagerly() and weighted_metrics is not None: - raise ValueError('weighted_metrics is not supported in Eager mode.') self.weighted_metrics = weighted_metrics if context.executing_eagerly() and target_tensors is not None: raise ValueError('target_tensors is not supported in Eager mode.') self.target_tensors = target_tensors + # Set DistributionStrategy specific parameters. + self._distribution_strategy = distribute + if self._distribution_strategy is not None: + self._grouped_model = self._compile_distributed_model( + self._distribution_strategy) + with self._distribution_strategy.scope(): + first_replicated_model = self._distribution_strategy.unwrap( + self._grouped_model)[0] + # If the specified metrics in `compile` are stateful, raise an error + # since we currently don't support stateful metrics. + if first_replicated_model.stateful_metric_names: + raise NotImplementedError('Stateful metrics are not supported with ' + 'DistributionStrategy.') + + # We initialize the callback model with the first replicated model. + self._replicated_model = DistributedCallbackModel(first_replicated_model) + self._replicated_model.set_original_model(self) if not self.built: # Model is not compilable because it does not know its number of inputs # and outputs, nor their shapes and names. We will compile after the first @@ -210,10 +438,9 @@ class Model(Network): for name in self.output_names: if name not in loss: logging.warning( - 'Output "' + name + '" missing from loss dictionary. ' - 'We assume this was done on purpose, ' - 'and we will not be expecting ' - 'any data to be passed to "' + name + '" during training.') + 'Output "' + name + '" missing from loss dictionary. We assume ' + 'this was done on purpose. The fit and evaluate APIs will not be ' + 'expecting any data to be passed to "' + name + '".') loss_functions.append(losses.get(loss.get(name))) elif isinstance(loss, list): if len(loss) != len(self.outputs): @@ -242,9 +469,7 @@ class Model(Network): # Prepare output masks. if not context.executing_eagerly(): - masks = self.compute_mask(self.inputs, mask=None) - if masks is None: - masks = [None for _ in self.outputs] + masks = [getattr(x, '_keras_mask', None) for x in self.outputs] if not isinstance(masks, list): masks = [masks] @@ -274,29 +499,40 @@ class Model(Network): str(loss_weights) + ' - expected a list of dicts.') self.loss_weights_list = loss_weights_list - # initialization for Eager mode execution + # Initialize model metric attributes. + self.metrics_names = ['loss'] + self.metrics_tensors = [] + self.metrics_updates = [] + self.stateful_metric_names = [] + self.stateful_metric_functions = [] + + # Nested metrics is a list of list of metrics. + # One list per output of the model. + self.nested_metrics = training_utils.collect_metrics( + metrics, self.output_names) + self.nested_weighted_metrics = training_utils.collect_metrics( + weighted_metrics, self.output_names) + + # Initialization for Eager mode execution. if context.executing_eagerly(): + # Prepare sample weights. + self._set_sample_weight_attributes(sample_weight_mode, + skip_target_weighing_indices) + if target_tensors is not None: raise ValueError('target_tensors are not currently supported in Eager ' 'mode.') self.total_loss = None - self.metrics_tensors = [] - self.metrics_names = ['loss'] for i in range(len(self.outputs)): if len(self.outputs) > 1: self.metrics_names.append(self.output_names[i] + '_loss') - self.nested_metrics = training_utils.collect_metrics(metrics, - self.output_names) - # TODO(fchollet): support stateful metrics in eager execution. - self.stateful_metric_functions = [] - self.stateful_metric_names = [] - - with K.name_scope('metrics'): - training_utils.populate_metric_names(self) - self._feed_sample_weight_modes = [] - for i in range(len(self.outputs)): - self._feed_sample_weight_modes.append(None) - self.sample_weights = [] + + # Set metric attributes on model. + self._handle_metrics( + self.outputs, + skip_target_indices=skip_target_indices, + sample_weights=self.sample_weights) + self.targets = [] for i in range(len(self.outputs)): self._feed_output_names.append(self.output_names[i]) @@ -356,78 +592,8 @@ class Model(Network): self.targets.append(target) # Prepare sample weights. - sample_weights = [] - sample_weight_modes = [] - if isinstance(sample_weight_mode, dict): - for name in sample_weight_mode: - if name not in self.output_names: - raise ValueError( - 'Unknown entry in ' - 'sample_weight_mode dictionary: "' + name + '". ' - 'Only expected the following keys: ' + str(self.output_names)) - for i, name in enumerate(self.output_names): - if i in skip_target_weighing_indices: - weight = None - sample_weight_modes.append(None) - else: - if name not in sample_weight_mode: - raise ValueError( - 'Output "' + name + '" missing from sample_weight_modes ' - 'dictionary') - if sample_weight_mode.get(name) == 'temporal': - weight = K.placeholder(ndim=2, name=name + '_sample_weights') - sample_weight_modes.append('temporal') - else: - weight = K.placeholder(ndim=1, name=name + 'sample_weights') - sample_weight_modes.append(None) - sample_weights.append(weight) - elif isinstance(sample_weight_mode, list): - if len(sample_weight_mode) != len(self.outputs): - raise ValueError('When passing a list as sample_weight_mode, ' - 'it should have one entry per model output. ' - 'The model has ' + str(len(self.outputs)) + - ' outputs, but you passed ' - 'sample_weight_mode=' + str(sample_weight_mode)) - for i in range(len(self.output_names)): - if i in skip_target_weighing_indices: - weight = None - sample_weight_modes.append(None) - else: - mode = sample_weight_mode[i] - name = self.output_names[i] - if mode == 'temporal': - weight = K.placeholder(ndim=2, name=name + '_sample_weights') - sample_weight_modes.append('temporal') - else: - weight = K.placeholder(ndim=1, name=name + '_sample_weights') - sample_weight_modes.append(None) - sample_weights.append(weight) - else: - for i, name in enumerate(self.output_names): - if i in skip_target_weighing_indices: - sample_weight_modes.append(None) - sample_weights.append(None) - else: - if sample_weight_mode == 'temporal': - sample_weights.append(array_ops.placeholder_with_default( - constant_op.constant([[1.]], dtype=K.floatx()), - shape=[None, None], name=name + '_sample_weights')) - sample_weight_modes.append('temporal') - else: - sample_weights.append(array_ops.placeholder_with_default( - constant_op.constant([1.], dtype=K.floatx()), - shape=[None], name=name + '_sample_weights')) - sample_weight_modes.append(None) - self.sample_weight_modes = sample_weight_modes - self._feed_sample_weight_modes = [] - for i in range(len(self.outputs)): - if i not in skip_target_weighing_indices: - self._feed_sample_weight_modes.append(self.sample_weight_modes[i]) - - # Prepare metrics. - self.weighted_metrics = weighted_metrics - self.metrics_names = ['loss'] - self.metrics_tensors = [] + self._set_sample_weight_attributes(sample_weight_mode, + skip_target_weighing_indices) # Compute total loss. total_loss = None @@ -438,7 +604,7 @@ class Model(Network): y_true = self.targets[i] y_pred = self.outputs[i] weighted_loss = weighted_losses[i] - sample_weight = sample_weights[i] + sample_weight = self.sample_weights[i] mask = masks[i] loss_weight = loss_weights_list[i] with K.name_scope(self.output_names[i] + '_loss'): @@ -462,84 +628,16 @@ class Model(Network): for loss_tensor in self.losses: total_loss += loss_tensor - # List of same size as output_names. - # contains tuples (metrics for output, names of metrics). - nested_metrics = training_utils.collect_metrics(metrics, self.output_names) - nested_weighted_metrics = training_utils.collect_metrics(weighted_metrics, - self.output_names) - self.metrics_updates = [] - self.stateful_metric_names = [] - self.stateful_metric_functions = [] - with K.name_scope('metrics'): - for i in range(len(self.outputs)): - if i in skip_target_indices: - continue - - y_true = self.targets[i] - y_pred = self.outputs[i] - weights = sample_weights[i] - output_metrics = nested_metrics[i] - output_weighted_metrics = nested_weighted_metrics[i] - - def handle_metrics(metrics, weights=None): - - for metric in metrics: - if metric in ('accuracy', 'acc', 'crossentropy', 'ce'): - # custom handling of accuracy/crossentropy - # (because of class mode duality) - output_shape = self.outputs[i].get_shape().as_list() - if (output_shape[-1] == 1 or - self.loss_functions[i] == losses.binary_crossentropy): - # case: binary accuracy/crossentropy - if metric in ('accuracy', 'acc'): - metric_fn = metrics_module.binary_accuracy - elif metric in ('crossentropy', 'ce'): - metric_fn = metrics_module.binary_crossentropy - elif self.loss_functions[ - i] == losses.sparse_categorical_crossentropy: - # case: categorical accuracy/crossentropy with sparse targets - if metric in ('accuracy', 'acc'): - metric_fn = metrics_module.sparse_categorical_accuracy - elif metric in ('crossentropy', 'ce'): - metric_fn = metrics_module.sparse_categorical_crossentropy - else: - # case: categorical accuracy/crossentropy - if metric in ('accuracy', 'acc'): - metric_fn = metrics_module.categorical_accuracy - elif metric in ('crossentropy', 'ce'): - metric_fn = metrics_module.categorical_crossentropy - weighted_metric_fn = training_utils.weighted_masked_objective( - metric_fn) - else: - metric_fn = metrics_module.get(metric) - weighted_metric_fn = training_utils.weighted_masked_objective( - metric_fn) - metric_name = training_utils.get_base_metric_name( - metric, weighted=weights is not None) - with K.name_scope(metric_name): - metric_result = weighted_metric_fn( - y_true, y_pred, weights=weights, mask=masks[i]) - - training_utils.add_metric_name(self, metric_name, i) - self.metrics_tensors.append(metric_result) - - # Keep track of state updates created by - # stateful metrics (i.e. metrics layers). - if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful: - self.stateful_metric_names.append(metric_name) - self.stateful_metric_functions.append(metric_fn) - self.metrics_updates += metric_fn.updates - - handle_metrics(output_metrics) - handle_metrics(output_weighted_metrics, weights=weights) + # Invoke metric functions for all the outputs. + self._handle_metrics( + self.outputs, + masks=masks, + targets=self.targets, + skip_target_indices=skip_target_indices, + sample_weights=self.sample_weights) # Prepare gradient updates and state updates. self.total_loss = total_loss - self.sample_weights = sample_weights - self._feed_sample_weights = [] - for i in range(len(self.sample_weights)): - if i not in skip_target_weighing_indices: - self._feed_sample_weights.append(self.sample_weights[i]) # Functions for train, test and predict will # be compiled lazily when required. @@ -554,6 +652,19 @@ class Model(Network): trainable_weights = self.trainable_weights self._collected_trainable_weights = trainable_weights + def _compile_distributed_model(self, distribution_strategy): + # TODO(anjalisridhar): Can we move the clone_and_build_model to outside the + # model? + def _clone_model_per_tower(model): + new_model = training_distributed.clone_and_build_model(model) + return new_model + + with distribution_strategy.scope(): + # Create a copy of this model on each of the devices. + grouped_models = distribution_strategy.call_for_each_tower( + _clone_model_per_tower, self) + return grouped_models + def _check_trainable_weights_consistency(self): """Check trainable weights count consistency. @@ -592,7 +703,7 @@ class Model(Network): # Unconditional updates updates += self.get_updates_for(None) # Conditional updates relevant to this model - updates += self.get_updates_for(self._feed_inputs) + updates += self.get_updates_for(self.inputs) # Stateful metrics updates updates += self.metrics_updates # Gets loss and metrics. Updates weights at each call. @@ -601,7 +712,6 @@ class Model(Network): updates=updates, name='train_function', **self._function_kwargs) - self._post_build_cleanup() def _make_test_function(self): if not hasattr(self, 'test_function'): @@ -619,7 +729,6 @@ class Model(Network): updates=self.state_updates + self.metrics_updates, name='test_function', **self._function_kwargs) - self._post_build_cleanup() def _make_predict_function(self): if not hasattr(self, 'predict_function'): @@ -638,7 +747,6 @@ class Model(Network): updates=self.state_updates, name='predict_function', **kwargs) - self._post_build_cleanup() def _get_iterator_get_next_tensors(self, iterator): get_next_op = self._iterator_get_next.get(iterator, None) @@ -647,6 +755,103 @@ class Model(Network): self._iterator_get_next[iterator] = get_next_op return get_next_op + def _distribution_standardize_user_data(self, + x, + y=None, + sample_weight=None, + class_weight=None, + batch_size=None, + check_steps=False, + steps_name='steps', + steps=None, + validation_split=0): + """Runs validation checks on input and target data passed by the user. + + This is called when using DistributionStrategy to train, evaluate or serve + the model. + + Args: + x: Input data. A `tf.data` dataset. + y: Since `x` is a dataset, `y` should not be specified + (since targets will be obtained from the iterator). + sample_weight: An optional sample-weight array passed by the user to + weight the importance of each sample in `x`. + class_weight: An optional class-weight array by the user to + weight the importance of samples in `x` based on the class they belong + to, as conveyed by `y`. + batch_size: Integer batch size. If provided, it is used to run additional + validation checks on stateful models. + check_steps: boolean, True if we want to check for validity of `steps` and + False, otherwise. + steps_name: The public API's parameter name for `steps`. + steps: Integer or `None`. Total number of steps (batches of samples) to + execute. + validation_split: Float between 0 and 1. + Fraction of the training data to be used as validation data. + + Returns: + A tuple of 3 lists: input arrays, target arrays, sample-weight arrays. + If the model's input and targets are symbolic, these lists are empty + (since the model takes no user-provided data, instead the data comes + from the symbolic inputs/targets). + + Raises: + ValueError: In case of invalid user-provided data. + RuntimeError: If the model was never compiled. + """ + if sample_weight is not None and sample_weight.all(): + raise NotImplementedError('sample_weight is currently not supported when ' + 'using DistributionStrategy.') + if class_weight: + raise NotImplementedError('class_weight is currently not supported when ' + 'using DistributionStrategy.') + + # TODO(anjalisridhar): Can we use the iterator and getnext op cache? + # We require users to pass Datasets since we distribute the dataset across + # multiple devices. + if not isinstance(x, dataset_ops.Dataset): + raise ValueError('When using DistributionStrategy you must specify a ' + 'Dataset object instead of a %s.' % type(x)) + # TODO(anjalisridhar): We want distribute_dataset() to accept a Dataset or a + # function which returns a Dataset. Currently distribute_dataset() only + # accepts a function that returns a Dataset. Once we add support for being + # able to clone a Dataset on multiple workers we can remove this lambda. + result = self._distribution_strategy.distribute_dataset(lambda: x) + iterator = result.make_initializable_iterator() + K.get_session().run(iterator.initializer) + # Validates `steps` argument based on x's type. + if check_steps: + if steps is None: + raise ValueError('When using a Dataset instance as input to a model, ' + 'you should specify the `{steps_name}` argument.' + .format(steps_name=steps_name)) + + training_utils.validate_iterator_input(x, y, sample_weight, + validation_split) + # x an y may be PerDevice objects with an input and output tensor + # corresponding to each device. For example, x could be + # PerDevice:{device: get_next tensor,...}. + next_element = iterator.get_next() + + if not isinstance(next_element, (list, tuple)) or len(next_element) != 2: + raise ValueError('Please provide data as a list or tuple of 2 elements ' + ' - input and target pair. Received %s' % next_element) + x, y = next_element + # Validate that all the elements in x and y are of the same type and shape. + # We can then pass the first element of x and y to `_standardize_weights` + # below and be confident of the output. We need to reopen the scope since + # we unwrap values when we validate x and y. + with self._distribution_strategy.scope(): + x_values, y_values = distributed_training_utils.\ + validate_distributed_dataset_inputs(self._distribution_strategy, x, y) + + _, _, sample_weights = self._standardize_weights(x_values, + y_values, + sample_weight, + class_weight, + batch_size) + return x, y, sample_weights + def _standardize_user_data(self, x, y=None, @@ -709,6 +914,18 @@ class Model(Network): ValueError: In case of invalid user-provided data. RuntimeError: If the model was never compiled. """ + if self._distribution_strategy: + return self._distribution_standardize_user_data( + x, + y, + sample_weight=sample_weight, + class_weight=class_weight, + batch_size=batch_size, + check_steps=check_steps, + steps_name=steps_name, + steps=steps, + validation_split=validation_split) + if isinstance(x, dataset_ops.Dataset): if context.executing_eagerly(): x = x.make_one_shot_iterator() @@ -757,12 +974,17 @@ class Model(Network): raise ValueError('Please provide data as a list or tuple of 2 elements ' ' - input and target pair. Received %s' % next_element) x, y = next_element + x, y, sample_weights = self._standardize_weights(x, y, sample_weight, + class_weight, batch_size) + return x, y, sample_weights + def _standardize_weights(self, x, y, sample_weight=None, class_weight=None, + batch_size=None,): # First, we build/compile the model on the fly if necessary. all_inputs = [] is_build_called = False is_compile_called = False - if not self.built: + if not self.inputs: # We need to use `x` to set the model inputs. # We type-check that `x` and `y` are either single arrays # or lists of arrays. @@ -871,13 +1093,7 @@ class Model(Network): exception_prefix='input') if y is not None: - if context.executing_eagerly(): - feed_output_names = self.output_names - feed_output_shapes = None - # Sample weighting not supported in this case. - # TODO(fchollet): consider supporting it. - feed_sample_weight_modes = [None for _ in self.outputs] - elif not self._is_graph_network: + if not self._is_graph_network: feed_output_names = self._feed_output_names feed_output_shapes = None # Sample weighting not supported in this case. @@ -890,7 +1106,11 @@ class Model(Network): for output_shape, loss_fn in zip(self._feed_output_shapes, self._feed_loss_fns): if loss_fn is losses.sparse_categorical_crossentropy: - feed_output_shapes.append(output_shape[:-1] + (1,)) + if K.image_data_format() == 'channels_first': + feed_output_shapes.append( + (output_shape[0], 1) + output_shape[2:]) + else: + feed_output_shapes.append(output_shape[:-1] + (1,)) elif (not hasattr(loss_fn, '__name__') or getattr(losses, loss_fn.__name__, None) is None): # If `loss_fn` is not a function (e.g. callable class) @@ -921,11 +1141,12 @@ class Model(Network): feed_sample_weight_modes) ] # Check that all arrays have the same length. - training_utils.check_array_lengths(x, y, sample_weights) - if self._is_graph_network and not context.executing_eagerly(): - # Additional checks to avoid users mistakenly using improper loss fns. - training_utils.check_loss_and_target_compatibility( - y, self._feed_loss_fns, feed_output_shapes) + if not self._distribution_strategy: + training_utils.check_array_lengths(x, y, sample_weights) + if self._is_graph_network and not context.executing_eagerly(): + # Additional checks to avoid users mistakenly using improper loss fns. + training_utils.check_loss_and_target_compatibility( + y, self._feed_loss_fns, feed_output_shapes) else: y = [] sample_weights = [] @@ -941,6 +1162,7 @@ class Model(Network): str(x[0].shape[0]) + ' samples') return x, y, sample_weights + @checkpointable.no_automatic_dependency_tracking def _set_inputs(self, inputs, training=None): """Set model's input and output specs based on the input data received. @@ -973,22 +1195,18 @@ class Model(Network): 'in their call() signatures do not yet support shape inference. File ' 'a feature request if this limitation bothers you.') if self.__class__.__name__ == 'Sequential': - # Note: we can't test whether the model is `Sequential` via `isinstance` - # since `Sequential` depends on `Model`. - if isinstance(inputs, list): - assert len(inputs) == 1 - inputs = inputs[0] - if tensor_util.is_tensor(inputs): input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:]) + self.build(input_shape=input_shape) else: input_shape = (None,) + inputs.shape[1:] - self.build(input_shape=input_shape) - elif context.executing_eagerly(): + self.build(input_shape=input_shape) + if context.executing_eagerly(): self._eager_set_inputs(inputs) else: self._symbolic_set_inputs(inputs, training=training) + @checkpointable.no_automatic_dependency_tracking def _eager_set_inputs(self, inputs): """Set model's input and output specs based on the input data received. @@ -1041,6 +1259,7 @@ class Model(Network): 'output_%d' % (i + 1) for i in range(len(dummy_output_values))] self.built = True + @checkpointable.no_automatic_dependency_tracking def _symbolic_set_inputs(self, inputs, outputs=None, training=None): """Set model's inputs and output specs based. @@ -1173,7 +1392,7 @@ class Model(Network): 0 = silent, 1 = progress bar, 2 = one line per epoch. callbacks: List of `keras.callbacks.Callback` instances. List of callbacks to apply during training. - See [callbacks](/callbacks). + See [callbacks](/api_docs/python/tf/keras/callbacks). validation_split: Float between 0 and 1. Fraction of the training data to be used as validation data. The model will set apart this fraction of the training data, @@ -1256,6 +1475,9 @@ class Model(Network): raise TypeError('Unrecognized keyword arguments: ' + str(kwargs)) # Validate and standardize user data. + if self._distribution_strategy: + distributed_training_utils.validate_callbacks(callbacks) + x, y, sample_weights = self._standardize_user_data( x, y, @@ -1336,6 +1558,17 @@ class Model(Network): initial_epoch=initial_epoch, steps_per_epoch=steps_per_epoch, validation_steps=validation_steps) + elif self._distribution_strategy: + return training_distributed.fit_loop( + self, x, y, + epochs=epochs, + verbose=verbose, + callbacks=callbacks, + val_inputs=val_x, + val_targets=val_y, + initial_epoch=initial_epoch, + steps_per_epoch=steps_per_epoch, + validation_steps=validation_steps) else: return training_arrays.fit_loop( self, x, y, @@ -1428,12 +1661,29 @@ class Model(Network): if context.executing_eagerly(): return training_eager.test_loop( - self, inputs=x, targets=y, sample_weights=sample_weights, - batch_size=batch_size, verbose=verbose, steps=steps) + self, + inputs=x, + targets=y, + sample_weights=sample_weights, + batch_size=batch_size, + verbose=verbose, + steps=steps) + elif self._distribution_strategy: + return training_distributed.test_loop( + self, + inputs=x, + targets=y, + verbose=verbose, + steps=steps) else: return training_arrays.test_loop( - self, inputs=x, targets=y, sample_weights=sample_weights, - batch_size=batch_size, verbose=verbose, steps=steps) + self, + inputs=x, + targets=y, + sample_weights=sample_weights, + batch_size=batch_size, + verbose=verbose, + steps=steps) def predict(self, x, batch_size=None, verbose=0, steps=None): """Generates output predictions for the input samples. @@ -1478,6 +1728,9 @@ class Model(Network): if context.executing_eagerly(): return training_eager.predict_loop( self, x, batch_size=batch_size, verbose=verbose, steps=steps) + elif self._distribution_strategy: + return training_distributed.predict_loop( + self, x, verbose=verbose, steps=steps) else: return training_arrays.predict_loop( self, x, batch_size=batch_size, verbose=verbose, steps=steps) @@ -1525,6 +1778,9 @@ class Model(Network): Raises: ValueError: In case of invalid user-provided arguments. """ + if self._distribution_strategy: + raise NotImplementedError('`train_on_batch` is not supported for models ' + 'compiled with DistributionStrategy.') # Validate and standardize user data. x, y, sample_weights = self._standardize_user_data( x, y, sample_weight=sample_weight, class_weight=class_weight) @@ -1581,6 +1837,9 @@ class Model(Network): Raises: ValueError: In case of invalid user-provided arguments. """ + if self._distribution_strategy: + raise NotImplementedError('`test_on_batch` is not supported for models ' + 'compiled with DistributionStrategy.') # Validate and standardize user data. x, y, sample_weights = self._standardize_user_data( x, y, sample_weight=sample_weight) @@ -1618,6 +1877,9 @@ class Model(Network): ValueError: In case of mismatch between given number of inputs and expectations of the model. """ + if self._distribution_strategy: + raise NotImplementedError('`predict_on_batch` is not supported for ' + 'models compiled with DistributionStrategy.') # Validate and standardize user data. inputs, _, _ = self._standardize_user_data(x) if context.executing_eagerly(): @@ -1748,6 +2010,10 @@ class Model(Network): Raises: ValueError: In case the generator yields data in an invalid format. """ + if self._distribution_strategy: + raise NotImplementedError('`fit_generator` is not supported for ' + 'models compiled with DistributionStrategy.') + if not self.built and not self._is_graph_network: raise NotImplementedError( '`fit_generator` is not yet enabled for unbuilt Model subclasses') @@ -1815,6 +2081,10 @@ class Model(Network): Raises: ValueError: In case the generator yields data in an invalid format. """ + if self._distribution_strategy: + raise NotImplementedError('`evaluate_generator` is not supported for ' + 'models compiled with DistributionStrategy.') + if not self.built and not self._is_graph_network: raise NotImplementedError( '`evaluate_generator` is not yet enabled for ' @@ -1869,6 +2139,10 @@ class Model(Network): Raises: ValueError: In case the generator yields data in an invalid format. """ + if self._distribution_strategy: + raise NotImplementedError('`predict_generator` is not supported for ' + 'models compiled with DistributionStrategy.') + if not self.built and not self._is_graph_network: raise NotImplementedError( '`predict_generator` is not yet enabled for unbuilt Model subclasses') @@ -1881,3 +2155,59 @@ class Model(Network): workers=workers, use_multiprocessing=use_multiprocessing, verbose=verbose) + + def _get_callback_model(self): + """Returns the Callback Model for this Model.""" + + if hasattr(self, '_replicated_model') and self._replicated_model: + # When using training_distributed, we set the callback model + # to an instance of the `DistributedModel` that we create in + # the `compile` call. The `DistributedModel` is initialized + # with the first replicated model. We need to set the callback + # model to a DistributedModel to allow us to override saving + # and loading weights when we checkpoint the model during training. + return self._replicated_model + if hasattr(self, 'callback_model') and self.callback_model: + return self.callback_model + return self + + +class DistributedCallbackModel(Model): + """Model that is used for callbacks with DistributionStrategy.""" + + def __init__(self, model): + super(DistributedCallbackModel, self).__init__() + # TODO(anjalisridhar): Right now the only attributes set are the layer and + # weights. We may need to set additional attributes as needed since we have + # not called compile on this model. + + def set_original_model(self, orig_model): + self._original_model = orig_model + + def save_weights(self, filepath, overwrite=True, save_format=None): + self._replicated_model.save_weights(filepath, overwrite=overwrite, + save_format=save_format) + + def save(self, filepath, overwrite=True, include_optimizer=True): + # save weights from the distributed model to the original model + distributed_model_weights = self.get_weights() + self._original_model.set_weights(distributed_model_weights) + # TODO(anjalisridhar): Do we need to save the original model here? + # Saving the first replicated model works as well. + self._original_model.save(filepath, overwrite=True, include_optimizer=False) + + def load_weights(self, filepath, by_name=False): + self._original_model.load_weights(filepath, by_name=False) + # Copy the weights from the original model to each of the replicated models. + orig_model_weights = self._original_model.get_weights() + distributed_training_utils.set_weights( + self._original_model._distribution_strategy, self, # pylint: disable=protected-access + orig_model_weights) + + def __getattr__(self, item): + # Whitelisted atttributes of the model that can be accessed by the user + # during a callback. + if item not in ['_setattr_tracking']: + logging.warning('You are accessing attribute ' + item + 'of the' + 'DistributedCallbackModel that may not have been set' + 'correctly.') diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py index 281ad9bd50edf519f520fe4aa664ae05b72528d8..e2c458c65f27c5802acd9186e9bcedd4062e5a2a 100644 --- a/tensorflow/python/keras/engine/training_arrays.py +++ b/tensorflow/python/keras/engine/training_arrays.py @@ -19,8 +19,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy - import numpy as np from tensorflow.python.framework import errors @@ -50,7 +48,6 @@ def fit_loop(model, val_targets=None, val_sample_weights=None, shuffle=True, - callback_metrics=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None): @@ -69,8 +66,6 @@ def fit_loop(model, val_targets: List of target arrays. val_sample_weights: Optional list of sample weight arrays. shuffle: Whether to shuffle the data at the beginning of each epoch - callback_metrics: List of strings, the display names of the metrics - passed to the callbacks. They should be the concatenation of list the display names of the outputs of `f` and the list of display names of the outputs of `f_val`. initial_epoch: Epoch at which to start training @@ -95,14 +90,8 @@ def fit_loop(model, val_sample_weights = val_sample_weights or [] if model.uses_learning_phase and not isinstance(K.learning_phase(), int): ins = inputs + targets + sample_weights + [1] - if val_inputs: - val_ins = val_inputs + val_targets + val_sample_weights + [1] else: ins = inputs + targets + sample_weights - if val_inputs: - val_ins = val_inputs + val_targets + val_sample_weights - if not val_inputs: - val_ins = [] do_validation = False if val_inputs: @@ -119,57 +108,27 @@ def fit_loop(model, 'training, i.e. `steps_per_epoch` ' 'must be set.') - out_labels = model.metrics_names - if do_validation: - callback_metrics = copy.copy(out_labels) + [ - 'val_' + n for n in out_labels - ] - else: - callback_metrics = copy.copy(out_labels) - num_train_samples = training_utils.check_num_samples( ins, batch_size, steps_per_epoch, 'steps_per_epoch') + count_mode = 'steps' if steps_per_epoch else 'samples' + callbacks = cbks.configure_callbacks( + callbacks, + model, + do_validation=do_validation, + val_inputs=val_inputs, + val_targets=val_targets, + val_sample_weights=val_sample_weights, + batch_size=batch_size, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + samples=num_train_samples, + validation_steps=validation_steps, + verbose=verbose, + count_mode=count_mode) + if num_train_samples is not None: index_array = np.arange(num_train_samples) - model.history = cbks.History() - all_callbacks = [cbks.BaseLogger( - stateful_metrics=model.stateful_metric_names)] - if verbose: - if steps_per_epoch is not None: - count_mode = 'steps' - else: - count_mode = 'samples' - all_callbacks.append( - cbks.ProgbarLogger( - count_mode, stateful_metrics=model.stateful_metric_names)) - all_callbacks += (callbacks or []) + [model.history] - callbacks = cbks.CallbackList(all_callbacks) - out_labels = out_labels or [] - - # it's possible to callback a different model than self - # (used by Sequential models) - if hasattr(model, 'callback_model') and model.callback_model: - callback_model = model.callback_model - else: - callback_model = model - - callbacks.set_model(callback_model) - - callbacks.set_params({ - 'batch_size': batch_size, - 'epochs': epochs, - 'steps': steps_per_epoch, - 'samples': num_train_samples, - 'verbose': verbose, - 'do_validation': do_validation, - 'metrics': callback_metrics or [], - }) - callbacks.on_train_begin() - callback_model.stop_training = False - for cbk in callbacks: - cbk.validation_data = val_ins - # To prevent a slowdown, we find beforehand the arrays that need conversion. feed = model._feed_inputs + model._feed_targets + model._feed_sample_weights indices_for_conversion_to_dense = [] @@ -177,6 +136,7 @@ def fit_loop(model, if issparse is not None and issparse(ins[i]) and not K.is_sparse(feed[i]): indices_for_conversion_to_dense.append(i) + callbacks.on_train_begin() for epoch in range(initial_epoch, epochs): # Reset stateful metrics for m in model.stateful_metric_functions: @@ -187,9 +147,7 @@ def fit_loop(model, if steps_per_epoch is not None: # Step-wise fit loop. for step_index in range(steps_per_epoch): - batch_logs = {} - batch_logs['batch'] = step_index - batch_logs['size'] = 1 + batch_logs = {'batch': step_index, 'size': 1} callbacks.on_batch_begin(step_index, batch_logs) try: outs = f(ins) @@ -197,17 +155,19 @@ def fit_loop(model, logging.warning('Your dataset iterator ran out of data; ' 'interrupting training. Make sure that your dataset ' 'can generate at least `steps_per_epoch * epochs` ' - 'batches (in this case, %d batches).' % + 'batches (in this case, %d batches). You may need to' + 'use the repeat() function when building your ' + 'dataset.' % steps_per_epoch * epochs) break if not isinstance(outs, list): outs = [outs] - for l, o in zip(out_labels, outs): + for l, o in zip(model.metrics_names, outs): batch_logs[l] = o callbacks.on_batch_end(step_index, batch_logs) - if callback_model.stop_training: + if callbacks.model.stop_training: break if do_validation: @@ -221,7 +181,7 @@ def fit_loop(model, if not isinstance(val_outs, list): val_outs = [val_outs] # Same labels assumed. - for l, o in zip(out_labels, val_outs): + for l, o in zip(model.metrics_names, val_outs): epoch_logs['val_' + l] = o else: # Sample-wise fit loop. @@ -254,11 +214,11 @@ def fit_loop(model, outs = f(ins_batch) if not isinstance(outs, list): outs = [outs] - for l, o in zip(out_labels, outs): + for l, o in zip(model.metrics_names, outs): batch_logs[l] = o callbacks.on_batch_end(batch_index, batch_logs) - if callback_model.stop_training: + if callbacks.model.stop_training: break if batch_index == len(batches) - 1: # Last batch. @@ -273,10 +233,10 @@ def fit_loop(model, if not isinstance(val_outs, list): val_outs = [val_outs] # Same labels assumed. - for l, o in zip(out_labels, val_outs): + for l, o in zip(model.metrics_names, val_outs): epoch_logs['val_' + l] = o callbacks.on_epoch_end(epoch, epoch_logs) - if callback_model.stop_training: + if callbacks.model.stop_training: break callbacks.on_train_end() return model.history @@ -378,7 +338,9 @@ def predict_loop(model, inputs, batch_size=32, verbose=0, steps=None): return outs -def test_loop(model, inputs, targets, +def test_loop(model, + inputs, + targets, sample_weights=None, batch_size=None, verbose=0, @@ -475,8 +437,7 @@ def test_loop(model, inputs, targets, if isinstance(batch_outs, list): if batch_index == 0: - for batch_out in enumerate(batch_outs): - outs.append(0.) + outs.extend([0.] * len(batch_outs)) for i, batch_out in enumerate(batch_outs): if i in stateful_metric_indices: outs[i] = batch_out diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..5feedc43a587b82633dd960baeca5b60d3970f41 --- /dev/null +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -0,0 +1,421 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Part of the Keras training engine related to distributed training. +""" +# pylint: disable=protected-access +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +from tensorflow.python.framework import errors +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import callbacks as cbks +from tensorflow.python.keras import optimizers +from tensorflow.python.keras.engine import distributed_training_utils +from tensorflow.python.keras.utils.generic_utils import Progbar +from tensorflow.python.platform import tf_logging as logging + + +def fit_loop( + model, + inputs, + targets, + epochs=100, + verbose=1, + callbacks=None, + val_inputs=None, + val_targets=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None): + """fit function when using DistributionStrategy for training. + + Arguments: + model: Keras Model instance. + inputs: List of input arrays. + targets: List of target arrays. + epochs: Number of times to iterate over the data + verbose: Verbosity mode, 0, 1 or 2 + callbacks: List of callbacks to be called during training + val_inputs: List of input arrays. + val_targets: List of target arrays. + initial_epoch: Epoch at which to start training + (useful for resuming a previous training run) + steps_per_epoch: Total number of steps (batches of samples) + before declaring one epoch finished and starting the + next epoch. Ignored with the default value of `None`. + validation_steps: Number of steps to run validation for + (only if doing validation from data tensors). + Ignored with the default value of `None`. + + Returns: + `History` object. + + Raises: + ValueError: in case of invalid arguments. + """ + current_strategy = model._distribution_strategy + def _per_device_train_function(model): + model._make_train_function() + return (model.train_function.inputs, + model.train_function.outputs, + model.train_function.updates_op, + model.train_function.session_kwargs) + + with current_strategy.scope(): + # Create train ops on each of the devices when we call + # `_per_device_train_function`. + (grouped_inputs, grouped_outputs, grouped_updates, + grouped_session_args) = current_strategy.call_for_each_tower( + _per_device_train_function, model._grouped_model) + # Unwrap all the per device values returned from `call_for_each_tower`. + # Unwrapping per device values gives you a list of values that can be + # used to construct a new train function that is composed of update ops on + # all the devices over which the model is distributed. + (all_inputs, all_outputs, all_updates, + all_session_args) = distributed_training_utils.unwrap_values( + current_strategy, grouped_inputs, grouped_outputs, + grouped_updates, grouped_session_args, with_loss_tensor=True) + + # Dataset inputs and targets are also per devices values that need to be + # unwrapped. + dataset_inputs = distributed_training_utils.flatten_perdevice_values( + current_strategy, inputs) + dataset_targets = distributed_training_utils.flatten_perdevice_values( + current_strategy, targets) + + # Create a train function that is composed of all the parameters above. + distributed_train_function = K.Function( + all_inputs, all_outputs, + updates=all_updates, + name='distributed_train_function', + **all_session_args) + + # We need to set sample_weights to None since there are sample weight + # placeholders that are created with default values. + sample_weights = [None for _ in range(len(model.outputs) * + current_strategy.num_towers)] + if model.uses_learning_phase and not isinstance(K.learning_phase(), int): + ins = dataset_inputs + dataset_targets + sample_weights + [1] + else: + ins = dataset_inputs + dataset_targets + + do_validation = False + if validation_steps: + do_validation = True + if steps_per_epoch is None: + raise ValueError('Can only use `validation_steps` ' + 'when doing step-wise ' + 'training, i.e. `steps_per_epoch` ' + 'must be set.') + + # Copy the weights from the original model to each of the replicated models. + orig_model_weights = model.get_weights() + with current_strategy.scope(): + distributed_model = current_strategy.unwrap(model._grouped_model)[0] + distributed_training_utils.set_weights( + current_strategy, distributed_model, orig_model_weights) + + callbacks = cbks.configure_callbacks( + callbacks, + model, + do_validation=do_validation, + val_inputs=None, + val_targets=None, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + verbose=verbose) + out_labels = model.metrics_names or [] + callbacks.on_train_begin() + for epoch in range(initial_epoch, epochs): + callbacks.on_epoch_begin(epoch) + if steps_per_epoch is not None: + epoch_logs = {} + for step_index in range(steps_per_epoch): + batch_logs = {'batch': step_index, 'size': 1} + callbacks.on_batch_begin(step_index, batch_logs) + try: + outs = distributed_train_function(ins) + except errors.OutOfRangeError: + logging.warning('Your dataset iterator ran out of data; ' + 'interrupting training. Make sure that your dataset ' + 'can generate at least `steps_per_epoch * epochs` ' + 'batches (in this case, %d batches).' % + steps_per_epoch * epochs) + break + + if not isinstance(outs, list): + outs = [outs] + + outs = _aggregate_metrics_across_towers( + len(current_strategy._devices), out_labels, outs) + for l, o in zip(out_labels, outs): + batch_logs[l] = o + callbacks.on_batch_end(step_index, batch_logs) + if callbacks.model.stop_training: + break + if do_validation: + val_outs = test_loop( + model, + val_inputs, + val_targets, + steps=validation_steps, + verbose=0) + if not isinstance(val_outs, list): + val_outs = [val_outs] + # Same labels assumed. + for l, o in zip(out_labels, val_outs): + epoch_logs['val_' + l] = o + + callbacks.on_epoch_end(epoch, epoch_logs) + if callbacks.model.stop_training: + break + callbacks.on_train_end() + + # Copy the weights back from the replicated model to the original model. + with current_strategy.scope(): + updated_weights = current_strategy.unwrap( + model._grouped_model)[0].get_weights() + model.set_weights(updated_weights) + return model.history + + +def test_loop(model, inputs, targets, verbose=0, steps=None): + """evaluate method to validate a model that uses DistributionStrategy. + + Arguments: + model: Keras Model instance. + inputs: List of input arrays. + targets: List of target arrays. + verbose: verbosity mode. + steps: Total number of steps (batches of samples) + before declaring predictions finished. + Ignored with the default value of `None`. + + Returns: + Scalar loss (if the model has a single output and no metrics) + or list of scalars (if the model has multiple outputs + and/or metrics). The attribute `model.metrics_names` will give you + the display labels for the scalar outputs. + """ + current_strategy = model._distribution_strategy + def _per_device_test_function(model): + model._make_test_function() + return (model.test_function.inputs, + model.test_function.outputs, + model.test_function.updates_op, + model.test_function.session_kwargs) + + with current_strategy.scope(): + (grouped_inputs, grouped_outputs, grouped_updates, + grouped_session_args) = current_strategy.call_for_each_tower( + _per_device_test_function, model._grouped_model) + + (all_inputs, all_outputs, all_updates, + all_session_args) = distributed_training_utils.unwrap_values( + current_strategy, grouped_inputs, grouped_outputs, grouped_updates, + grouped_session_args, with_loss_tensor=True) + + dataset_inputs = distributed_training_utils.flatten_perdevice_values( + current_strategy, inputs) + dataset_targets = distributed_training_utils.flatten_perdevice_values( + current_strategy, targets) + + distributed_test_function = K.Function( + all_inputs, all_outputs, + updates=all_updates, + name='distributed_test_function', + **all_session_args) + + # We need to set sample_weights to None since there are sample weight + # placeholders that are created with default values. + sample_weights = [None for _ in range(len(model.outputs) * + current_strategy.num_towers)] + if model.uses_learning_phase and not isinstance(K.learning_phase(), int): + ins = dataset_inputs + dataset_targets + sample_weights + [0] + else: + ins = dataset_inputs + dataset_targets + + outs = [] + if verbose == 1: + progbar = Progbar(target=steps) + + # Copy the weights from the original model to each of the replicated models. + orig_model_weights = model.get_weights() + with current_strategy.scope(): + distributed_model = current_strategy.unwrap(model._grouped_model)[0] + distributed_training_utils.set_weights( + current_strategy, distributed_model, orig_model_weights) + + if steps is not None: + for step in range(steps): + batch_outs = distributed_test_function(ins) + batch_outs = _aggregate_metrics_across_towers( + len(current_strategy._devices), model.metrics_names, batch_outs) + if isinstance(batch_outs, list): + if step == 0: + for _ in enumerate(batch_outs): + outs.append(0.) + for i, batch_out in enumerate(batch_outs): + outs[i] += batch_out + else: + if step == 0: + outs.append(0.) + outs[0] += batch_outs + if verbose == 1: + progbar.update(step + 1) + for i in range(len(outs)): + outs[i] /= steps + + if len(outs) == 1: + return outs[0] + return outs + + +def predict_loop(model, inputs, verbose=0, steps=None): + """Abstract method to loop over some data in batches. + + Arguments: + model: Keras Model instance. + inputs: list of tensors to be fed to `f`. + verbose: verbosity mode. + steps: Total number of steps (batches of samples) + before declaring `_predict_loop` finished. + Ignored with the default value of `None`. + + Returns: + Array of predictions (if the model has a single output) + or list of arrays of predictions + (if the model has multiple outputs). + """ + current_strategy = model._distribution_strategy + def _per_device_predict_function(model): + model._make_predict_function() + return (model.predict_function.inputs, + model.predict_function.outputs, + model.predict_function.updates_op, + model.predict_function.session_kwargs) + + with current_strategy.scope(): + (grouped_inputs, grouped_outputs, grouped_updates, + grouped_session_args) = current_strategy.call_for_each_tower( + _per_device_predict_function, model._grouped_model) + + (all_inputs, all_outputs, all_updates, + all_session_args) = distributed_training_utils.unwrap_values( + current_strategy, grouped_inputs, grouped_outputs, grouped_updates, + grouped_session_args) + + dataset_inputs = distributed_training_utils.flatten_perdevice_values( + current_strategy, inputs) + + distributed_predict_function = K.Function( + all_inputs, all_outputs, + updates=all_updates, + name='distributed_predict_function', + **all_session_args) + + if model.uses_learning_phase and not isinstance(K.learning_phase(), int): + ins = dataset_inputs + [0] + else: + ins = dataset_inputs + + if verbose == 1: + progbar = Progbar(target=steps) + + # Copy the weights from the original model to each of the replicated models. + orig_model_weights = model.get_weights() + with current_strategy.scope(): + distributed_model = current_strategy.unwrap(model._grouped_model)[0] + distributed_training_utils.set_weights( + current_strategy, distributed_model, orig_model_weights) + + if steps is not None: + # Since we do not know how many samples we will see, we cannot pre-allocate + # the returned Numpy arrays. Instead, we store one array per batch seen + # and concatenate them upon returning. + unconcatenated_outs = [] + for step in range(steps): + batch_outs = distributed_predict_function(ins) + if not isinstance(batch_outs, list): + batch_outs = [batch_outs] + if step == 0: + for _ in batch_outs: + unconcatenated_outs.append([]) + for i, batch_out in enumerate(batch_outs): + unconcatenated_outs[i].append(batch_out) + if verbose == 1: + progbar.update(step + 1) + if len(unconcatenated_outs) == 1: + return np.concatenate(unconcatenated_outs[0], axis=0) + return [ + np.concatenate(unconcatenated_outs[i], axis=0) + for i in range(len(unconcatenated_outs)) + ] + + +def clone_and_build_model(model): + """Clone and build the given keras_model.""" + # We need to set the import here since we run into a circular dependency + # error. + from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top + cloned_model = models.clone_model(model, input_tensors=None) + + # Compile and build model. + if isinstance(model.optimizer, optimizers.TFOptimizer): + optimizer = model.optimizer + else: + optimizer_config = model.optimizer.get_config() + optimizer = model.optimizer.__class__.from_config(optimizer_config) + + cloned_model.compile( + optimizer, + model.loss, + metrics=model.metrics, + loss_weights=model.loss_weights, + sample_weight_mode=model.sample_weight_mode, + weighted_metrics=model.weighted_metrics) + return cloned_model + + +def _aggregate_metrics_across_towers(num_devices, out_labels, outs): + """Aggregate metrics values across all towers. + + When using `MirroredStrategy`, the number of towers is equal to the + number of devices over which training is distributed. This may not always be + the case. + + Args: + num_devices: Number of devices over which the model is being distributed. + out_labels: The list of metric names passed to `compile`. + outs: The output from all the towers. + + Returns: + The average value of each metric across the towers. + """ + # TODO(anjalisridhar): Temporary workaround for aggregating metrics + # across towers. Replace with the new metrics module eventually. + merged_output = [] + # The first output is the total loss. + merged_output.append(outs[0]) + current_index = 1 + # Each label in `out_labels` corresponds to one set of metrics. The + # number of metric values corresponds to the number of devices. We + # currently take the mean of the values. + for _ in out_labels[1:]: + m = np.mean(outs[current_index:current_index + num_devices]) + merged_output.append(m) + current_index += num_devices + return merged_output diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py index e8838cd3bca7b3afba80504f9e705943474423c5..1e377149b64ff6d810d59809eee5a3f1630ecdd6 100644 --- a/tensorflow/python/keras/engine/training_eager.py +++ b/tensorflow/python/keras/engine/training_eager.py @@ -30,79 +30,36 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend from tensorflow.python.keras import callbacks as cbks -from tensorflow.python.keras import losses -from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.utils import generic_utils -from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging -def _get_metrics_info(metric, internal_output_shapes=None, loss_func=None): - if metric == 'accuracy' or metric == 'acc': - # custom handling of accuracy - # (because of class mode duality) - output_shape = internal_output_shapes - if output_shape[-1] == 1 or loss_func == losses.binary_crossentropy: - # case: binary accuracy - acc_fn = metrics_module.binary_accuracy - elif loss_func == losses.sparse_categorical_crossentropy: - # case: categorical accuracy with sparse targets - acc_fn = metrics_module.sparse_categorical_accuracy - else: - acc_fn = metrics_module.categorical_accuracy - - metric_name = 'acc' - return metric_name, acc_fn - else: - metric_fn = metrics_module.get(metric) - metric_name = metric_fn.__name__ - return metric_name, metric_fn - - def _eager_loss_fn(outputs, targets, loss_fn, output_name): with backend.name_scope(output_name + '_loss'): loss = loss_fn(targets, outputs) return loss -def _eager_metrics_fn(model, outputs, targets): +def _eager_metrics_fn(model, outputs, targets, sample_weights=None, masks=None): """Calculates the metrics for each output of the given model. Arguments: model: The model on which metrics are being calculated. outputs: The outputs of the given model. targets: The predictions or targets of the given model. + sample_weights: Optional list of sample weights for each output. + masks: Optional list of masks for each output. Returns: - Returns the metric names and metric results for each output of the model. + Returns the metric results for each output of the model. """ - metric_names = [] - metric_results = [] - if not isinstance(outputs, list): - outputs = [outputs] - - if not isinstance(targets, list): - targets = [targets] - - for i in range(len(model.outputs)): - output_metrics = model.nested_metrics[i] - for nested_output_metric in output_metrics: - metric_name, metric_fn = _get_metrics_info( - nested_output_metric, backend.int_shape(model.outputs[i]), - model.loss_functions[i]) - - if len(model.output_names) > 1: - metric_name = model.output_names[i] + '_' + metric_name - if metric_name not in model.metrics_names: - model.metrics_names.append(metric_name) - - with backend.name_scope(metric_name): - metric_result = metric_fn(targets[i], outputs[i]) - metric_names.append(metric_name) - metric_results.append(backend.mean(metric_result)) - - return metric_results + outputs = generic_utils.to_list(outputs) + targets = generic_utils.to_list(targets) + # TODO(psv): Consider supporting skip target indices in eager mode? + metric_results = model._handle_metrics( + outputs, targets=targets, sample_weights=sample_weights, masks=masks) + return [backend.mean(t) for t in metric_results] def _model_loss(model, inputs, targets, sample_weights=None, training=False): @@ -116,26 +73,29 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False): training: Whether the model should be run in inference or training mode. Returns: - Returns the model output, total loss and loss value calculated using the - specified loss function. The total loss includes regularization losses and - applies masking and sample weighting to the loss value. + Returns the model output, total loss, loss value calculated using the + specified loss function and masks for each output. The total loss includes + regularization losses and applies masking and sample weighting + to the loss value. """ total_loss = 0 + kwargs = {} + if model._expects_training_arg: + kwargs['training'] = training if len(inputs) == 1: - if model._expects_training_arg: - outs = model.call(inputs[0], training=training) - else: - outs = model.call(inputs[0]) + inputs = inputs[0] + + if model._compute_output_and_mask_jointly: + outs, masks = model._call_and_compute_mask(inputs, **kwargs) + masks = generic_utils.to_list(masks) else: - if model._expects_training_arg: - outs = model.call(inputs, training=training) - else: - outs = model.call(inputs) - if not isinstance(outs, list): - outs = [outs] + outs = model.call(inputs, **kwargs) + masks = None - if not isinstance(targets, list): - targets = [targets] + outs = generic_utils.to_list(outs) + if masks is None: + masks = [None for _ in outs] + targets = generic_utils.to_list(targets) loss_metrics = [] with backend.name_scope('loss'): @@ -144,10 +104,7 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False): weights = sample_weights[i] else: weights = None - - # TODO(fchollet): support masking; in practice `_keras_mask` is never - # set in this context currently. - mask = outs[i]._keras_mask + mask = masks[i] weighted_masked_fn = training_utils.weighted_masked_objective(loss_fn) with backend.name_scope(model.output_names[i] + '_loss'): @@ -176,15 +133,13 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False): if custom_losses: total_loss += sum(custom_losses) - return outs, total_loss, loss_metrics + return outs, total_loss, loss_metrics, masks def iterator_fit_loop(model, inputs, class_weight, steps_per_epoch, - callback_model, - out_labels, epoch_logs, val_inputs=None, val_targets=None, @@ -192,9 +147,9 @@ def iterator_fit_loop(model, epochs=1, verbose=1, callbacks=None, - callback_metrics=None, validation_steps=None, - do_validation=False): + do_validation=False, + batch_size=None): """Fit function for eager execution when input is given as dataset iterator. Updates the given epoch logs. @@ -208,32 +163,33 @@ def iterator_fit_loop(model, steps_per_epoch: Total number of steps (batches of samples) before declaring one epoch finished and starting the next epoch. - callback_model: Instance of `Model` to callback. - out_labels: Output labels generated from model metric names. epoch_logs: Dictionary of logs from every epoch. val_inputs: Input data for validation. val_targets: Target data for validation. val_sample_weights: Sample weight data for validation. epochs: Number of times to iterate over the data verbose: Verbosity mode, 0, 1 or 2 - callbacks: List of callbacks to be called during training - callback_metrics: List of strings, the display names of the metrics - passed to the callbacks. They should be the - concatenation of list the display names of the outputs of - `f` and the list of display names of the outputs of `f_val`. + callbacks: CallbackList instance. Controls callbacks during training. validation_steps: Number of steps to run validation for (only if doing validation from data tensors). Ignored with default value of `None`. do_validation: Boolean value indicating whether we should do validation. + batch_size: int, val_inputs and val_targets will be evaled batch by + batch with size batch_size if they are array. Raises: ValueError: In case of mismatch between given number of inputs and expectations of the model. """ assert isinstance(inputs, iterator_ops.EagerIterator) + + # make sure either x,y or x,y,sample_weights is provided + if (not isinstance(inputs.output_shapes, (list, tuple)) or + len(inputs.output_shapes) not in (2, 3)): + raise ValueError('Please provide either inputs and targets' + 'or inputs, targets, and sample_weights') + for step_index in range(steps_per_epoch): - batch_logs = {} - batch_logs['batch'] = step_index - batch_logs['size'] = 1 + batch_logs = {'batch': step_index, 'size': 1} callbacks.on_batch_begin(step_index, batch_logs) # Get data from the iterator. @@ -241,62 +197,72 @@ def iterator_fit_loop(model, next_element = inputs.get_next() except errors.OutOfRangeError: logging.warning( - 'Your dataset iterator ran out of data; ' - 'interrupting training. Make sure that your dataset' - ' can generate at least `steps_per_epoch * epochs` ' - 'batches (in this case, %d batches).' % steps_per_epoch * epochs) + 'Your dataset iterator ran out of data; interrupting training. Make ' + 'sure that your dataset can generate at least ' + '`steps_per_epoch * epochs` batches (in this case, %d batches). You ' + 'may need to use the repeat() function when building your ' + 'dataset.' % steps_per_epoch * epochs) break - if not isinstance(next_element, (list, tuple)) or len(next_element) != 2: - raise ValueError('Please provide data as a list or tuple of 2 elements ' - ' - input and target pair. Received %s' % next_element) - x, y = next_element + if len(inputs.output_shapes) == 2: + x, y = next_element + sample_weights = None + else: + x, y, sample_weights = next_element # Validate and standardize data. x, y, sample_weights = model._standardize_user_data( - x, y, class_weight=class_weight) + x, y, sample_weight=sample_weights, class_weight=class_weight) x = training_utils.cast_if_floating_dtype(x) y = training_utils.cast_if_floating_dtype(y) if sample_weights: sample_weights = [ - ops.convert_to_tensor(val, dtype=backend.floatx()) + training_utils.cast_if_floating_dtype( + ops.convert_to_tensor(val, dtype=backend.floatx())) if val is not None else None for val in sample_weights ] - if step_index == 0 and not callback_metrics: - out_labels = model.metrics_names + # Set stateful_metrics in callbacks. We do not do this before the + # `steps_per_epoch` loop because model will be compiled only in the first + # iteration of this loop in the deferred build scenario. + if step_index == 0: + for cbk in callbacks: + if (isinstance(cbk, cbks.BaseLogger) or + isinstance(cbk, cbks.ProgbarLogger)): + cbk.stateful_metrics = model.stateful_metric_names + + if step_index == 0 and not callbacks.params['metrics']: + callback_metrics = copy.copy(model.metrics_names) if do_validation: - callback_metrics = copy.copy(out_labels) + [ - 'val_' + n for n in out_labels - ] - else: - callback_metrics = copy.copy(out_labels) + callback_metrics += ['val_' + n for n in model.metrics_names] callbacks.set_params({ + 'batch_size': batch_size, 'epochs': epochs, 'steps': steps_per_epoch, 'verbose': verbose, 'do_validation': do_validation, 'metrics': callback_metrics or [], + 'validation_steps': validation_steps }) # Train model. - outs, loss, loss_metrics = _process_single_batch( + outs, loss, loss_metrics, masks = _process_single_batch( model, x, y, sample_weights=sample_weights, training=True) - if not isinstance(outs, list): - outs = [outs] + outs = generic_utils.to_list(outs) # Calculate metrics. - for l, o in zip(out_labels, outs): + for l, o in zip(model.metrics_names, outs): batch_logs[l] = o # Required for eager execution - metrics_results = _eager_metrics_fn(model, outs, y) + metrics_results = _eager_metrics_fn( + model, outs, y, sample_weights=sample_weights, masks=masks) batch_logs['loss'] = tensor_util.constant_value(backend.mean(loss)) for k, v in zip(model.metrics_names, [backend.mean(loss)] + loss_metrics + metrics_results): batch_logs[k] = tensor_util.constant_value(v) callbacks.on_batch_end(step_index, batch_logs) - if callback_model.stop_training: + if callbacks.model.stop_training: break if step_index == steps_per_epoch - 1: @@ -307,126 +273,12 @@ def iterator_fit_loop(model, val_targets, sample_weights=val_sample_weights, steps=validation_steps, - verbose=0) - if not isinstance(val_outs, list): - val_outs = [val_outs] - # Same labels assumed. - for l, o in zip(out_labels, val_outs): - epoch_logs['val_' + l] = o - - -def batch_fit_loop(model, - inputs, - targets, - epoch_logs, - index_array, - out_labels, - callback_model, - batch_size, - sample_weights=None, - val_inputs=None, - val_targets=None, - val_sample_weights=None, - callbacks=None, - shuffle=True, - num_train_samples=None, - do_validation=False): - """Fit function for eager execution when input is given as arrays or tensors. - - Updates the given epoch logs. - - Arguments: - model: Instance of the `Model`. - inputs: List of input arrays. - targets: List of target arrays. - epoch_logs: Dictionary of logs from every epoch. - index_array: Index array generated from number of training samples. - out_labels: Output labels generated from model metric names. - callback_model: Instance of `Model` to callback. - batch_size: Integer batch size or None if unknown. - sample_weights: Optional list of sample weight arrays. - val_inputs: Input data for validation. - val_targets: Target data for validation. - val_sample_weights: Sample weight data for validation. - callbacks: List of callbacks to be called during training. - shuffle: Whether to shuffle the data at the beginning of each epoch. - num_train_samples: Integer number of training samples. - do_validation: Boolean value indicating whether we should do validation. - """ - # TODO(psv): Create a dataset iterator instead of manually creating batches - # here and in batch_test_loop, batch_predict_loop. - if shuffle == 'batch': - index_array = model._batch_shuffle(index_array, batch_size) - elif shuffle: - np.random.shuffle(index_array) - - batches = generic_utils.make_batches(num_train_samples, batch_size) - - for batch_index, (batch_start, batch_end) in enumerate(batches): - batch_ids = index_array[batch_start:batch_end] - inputs_batch = slice_arrays(inputs, batch_ids, contiguous=not shuffle) - targets_batch = slice_arrays(targets, batch_ids, contiguous=not shuffle) - if sample_weights: - sample_weights_batch = slice_arrays( - sample_weights, batch_ids, contiguous=not shuffle) - else: - sample_weights_batch = None - batch_logs = {} - batch_logs['batch'] = batch_index - batch_logs['size'] = len(batch_ids) - - callbacks.on_batch_begin(batch_index, batch_logs) - - inputs_batch = [ - ops.convert_to_tensor(val, dtype=backend.floatx()) - for val in inputs_batch - ] - targets_batch = [ - ops.convert_to_tensor(val, dtype=backend.floatx()) - for val in targets_batch - ] - if sample_weights: - sample_weights_batch = [ - ops.convert_to_tensor(val, dtype=backend.floatx()) - if val is not None else None for val in sample_weights_batch - ] - - outs, loss, loss_metrics = _process_single_batch( - model, - inputs_batch, - targets_batch, - sample_weights=sample_weights_batch, - training=True) - - if not isinstance(outs, list): - outs = [outs] - - for l, o in zip(out_labels, outs): - batch_logs[l] = o - # Required for eager execution - metrics_results = _eager_metrics_fn(model, outs, targets_batch) - batch_logs['loss'] = tensor_util.constant_value(backend.mean(loss)) - - for k, v in zip(model.metrics_names, - [backend.mean(loss)] + loss_metrics + metrics_results): - batch_logs[k] = tensor_util.constant_value(v) - callbacks.on_batch_end(batch_index, batch_logs) - if callback_model.stop_training: - break - - if batch_index == len(batches) - 1: # Last batch. - if do_validation: - val_outs = test_loop( - model, - val_inputs, - val_targets, - sample_weights=val_sample_weights, - batch_size=batch_size, - verbose=0) + verbose=0, + batch_size=batch_size) if not isinstance(val_outs, list): val_outs = [val_outs] # Same labels assumed. - for l, o in zip(out_labels, val_outs): + for l, o in zip(model.metrics_names, val_outs): epoch_logs['val_' + l] = o @@ -451,6 +303,11 @@ def iterator_test_loop(model, inputs, steps, verbose=0): expectations of the model. """ assert isinstance(inputs, iterator_ops.EagerIterator) + # make sure either x,y or x,y,sample_weights is provided + if (not isinstance(inputs.output_shapes, (list, tuple)) or + len(inputs.output_shapes) < 2 or len(inputs.output_shapes) > 3): + raise ValueError('Please provide either inputs and targets' + 'or inputs, targets, and sample_weights') outs = [] num_samples = 0 if verbose == 1: @@ -463,23 +320,47 @@ def iterator_test_loop(model, inputs, steps, verbose=0): logging.warning( 'Your dataset iterator ran out of data interrupting testing. ' 'Make sure that your dataset can generate at least `steps` batches ' - '(in this case, %d batches).', steps) + '(in this case, %d batches). You may need to use the repeat() ' + 'function when building your dataset.', steps) break - if not isinstance(next_element, (list, tuple)) or len(next_element) != 2: - raise ValueError('Please provide data as a list or tuple of 2 elements ' - ' - input and target pair. Received %s' % next_element) - x, y = next_element + if len(inputs.output_shapes) == 2: + x, y = next_element + sample_weights = None + else: + x, y, sample_weights = next_element # Validate and standardize data. - x, y, sample_weights = model._standardize_user_data(x, y) + x, y, sample_weights = model._standardize_user_data( + x, y, sample_weight=sample_weights) x = training_utils.cast_if_floating_dtype(x) y = training_utils.cast_if_floating_dtype(y) + if sample_weights: + sample_weights = [ + training_utils.cast_if_floating_dtype( + ops.convert_to_tensor(val, dtype=backend.floatx())) + if val is not None else None for val in sample_weights + ] + + if step_index == 0: + # Get stateful metrics indices. We do not do this before the `steps` loop + # because model will be compiled only in the first iteration of this loop + # in the deferred build scenario. + if hasattr(model, 'metrics'): + for m in model.stateful_metric_functions: + m.reset_states() + stateful_metric_indices = [ + i for i, name in enumerate(model.metrics_names) + if str(name) in model.stateful_metric_names + ] + else: + stateful_metric_indices = [] # Calculate model output, loss values. - loss_outs, loss, loss_metrics = _model_loss( + loss_outs, loss, loss_metrics, masks = _model_loss( model, x, y, sample_weights=sample_weights, training=False) - metrics_results = _eager_metrics_fn(model, loss_outs, y) + metrics_results = _eager_metrics_fn( + model, loss_outs, y, sample_weights=sample_weights, masks=masks) batch_outs = [] for _, v in zip(model.metrics_names, [backend.mean(loss)] + loss_metrics + metrics_results): @@ -498,7 +379,10 @@ def iterator_test_loop(model, inputs, steps, verbose=0): for _ in enumerate(batch_outs): outs.append(0.) for i, batch_out in enumerate(batch_outs): - outs[i] += batch_out * step_size + if i in stateful_metric_indices: + outs[i] = batch_out + else: + outs[i] += batch_out * step_size # Calculate sample size. num_samples += step_size @@ -506,95 +390,8 @@ def iterator_test_loop(model, inputs, steps, verbose=0): progbar.update(step_index + 1) for i in range(len(outs)): - outs[i] /= num_samples - if len(outs) == 1: - return outs[0] - return outs - - -def batch_test_loop(model, - inputs, - targets, - batch_size, - sample_weights=None, - verbose=0): - """Test function for eager execution when input is given as arrays or tensors. - - Arguments: - model: Model instance that is being evaluated in Eager mode. - inputs: List of input arrays. - targets: List of target arrays. - batch_size: Integer batch size. - sample_weights: Optional list of sample weight arrays. - verbose: Verbosity mode. - - Returns: - Scalar loss (if the model has a single output and no metrics) - or list of scalars (if the model has multiple outputs - and/or metrics). The attribute `model.metrics_names` will give you - the display labels for the scalar outputs. - """ - outs = [] - feed_data = inputs + targets - if sample_weights: - feed_data += sample_weights - num_samples = training_utils.check_num_samples( - feed_data, batch_size=batch_size) - if verbose == 1: - progbar = generic_utils.Progbar(target=num_samples) - batches = generic_utils.make_batches(num_samples, batch_size) - index_array = np.arange(num_samples) - for batch_index, (batch_start, batch_end) in enumerate(batches): - batch_ids = index_array[batch_start:batch_end] - inputs_batch = slice_arrays(inputs, batch_ids) - targets_batch = slice_arrays(targets, batch_ids) - if sample_weights: - sample_weights_batch = slice_arrays(sample_weights, batch_ids) - else: - sample_weights_batch = None - - inputs_batch = [ - ops.convert_to_tensor(val, dtype=backend.floatx()) - for val in inputs_batch - ] - targets_batch = [ - ops.convert_to_tensor(val, dtype=backend.floatx()) - for val in targets_batch - ] - if sample_weights: - sample_weights_batch = [ - ops.convert_to_tensor(val, dtype=backend.floatx()) - if val is not None else None for val in sample_weights_batch - ] - - loss_outs, loss, loss_metrics = _model_loss( - model, - inputs_batch, - targets_batch, - sample_weights=sample_weights_batch, - training=False) - metrics_results = _eager_metrics_fn(model, loss_outs, targets_batch) - batch_outs = [] - for _, v in zip(model.metrics_names, - [backend.mean(loss)] + loss_metrics + metrics_results): - batch_outs.append(tensor_util.constant_value(v)) - - if isinstance(batch_outs, list): - if batch_index == 0: - for _ in enumerate(batch_outs): - outs.append(0.) - for i, batch_out in enumerate(batch_outs): - outs[i] += batch_out * len(batch_ids) - else: - if batch_index == 0: - outs.append(0.) - outs[0] += batch_outs * len(batch_ids) - - if verbose == 1: - progbar.update(batch_end) - - for i in range(len(outs)): - outs[i] /= num_samples + if i not in stateful_metric_indices: + outs[i] /= num_samples if len(outs) == 1: return outs[0] return outs @@ -619,6 +416,12 @@ def iterator_predict_loop(model, inputs, steps, verbose=0): expectations of the model. """ assert isinstance(inputs, iterator_ops.EagerIterator) + if not isinstance(inputs.output_shapes, + (list, tuple)) or len(inputs.output_shapes) > 2: + raise ValueError( + 'Please provide data as a list or tuple of 1 or 2 elements ' + ' - input or input and target pair. Received %s. We do not use the ' + '`target` value here.' % inputs.output_shapes) outs = [] if verbose == 1: progbar = generic_utils.Progbar(target=steps) @@ -628,18 +431,14 @@ def iterator_predict_loop(model, inputs, steps, verbose=0): next_element = inputs.get_next() except errors.OutOfRangeError: logging.warning( - 'Your dataset iterator ran out of data; ' - 'interrupting prediction. Make sure that your ' - 'dataset can generate at least `steps` ' - 'batches (in this case, %d batches).', steps) + 'Your dataset iterator ran out of data; interrupting prediction. ' + 'Make sure that your dataset can generate at least `steps` batches ' + '(in this case, %d batches). You may need to use the repeat() ' + 'function when building your dataset.', steps) break - if not isinstance(next_element, (list, tuple)) or len(next_element) != 2: - raise ValueError( - 'Please provide data as a list or tuple of 2 elements ' - ' - input and target pair. Received %s. We do not use the ' - '`target` value here.' % next_element) - x, _ = next_element + # expects a tuple, where first element of tuple represents inputs + x = next_element[0] # Validate and standardize data. x, _, _ = model._standardize_user_data(x) @@ -670,99 +469,6 @@ def iterator_predict_loop(model, inputs, steps, verbose=0): return outs -def batch_predict_loop(model, inputs, batch_size, verbose=0): - """Predict function for eager execution when input is arrays or tensors. - - Arguments: - model: Instance of `Model`. - inputs: List of input arrays. - batch_size: Integer batch size. - verbose: Verbosity mode. - - Returns: - Array of predictions (if the model has a single output) - or list of arrays of predictions (if the model has multiple outputs). - """ - outs = [] - num_samples = training_utils.check_num_samples(inputs, batch_size) - if verbose == 1: - progbar = generic_utils.Progbar(target=num_samples) - batches = generic_utils.make_batches(num_samples, batch_size) - index_array = np.arange(num_samples) - for batch_index, (batch_start, batch_end) in enumerate(batches): - batch_ids = index_array[batch_start:batch_end] - inputs_batch = slice_arrays(inputs, batch_ids) - - inputs_batch = [ - ops.convert_to_tensor(val, dtype=backend.floatx()) - for val in inputs_batch - ] - - if len(inputs_batch) == 1: - if model._expects_training_arg: - batch_outs = model.call(inputs_batch[0], training=False) - else: - batch_outs = model.call(inputs_batch[0]) - else: - if model._expects_training_arg: - batch_outs = model.call(inputs_batch, training=False) - else: - batch_outs = model.call(inputs_batch) - - if not isinstance(batch_outs, list): - batch_outs = [batch_outs] - if batch_index == 0: - # Pre-allocate the results arrays. - for batch_out in batch_outs: - dims = batch_out.shape[1:].dims - dims_list = [d.value for d in dims] - shape = (num_samples,) + tuple(dims_list) - outs.append(np.zeros(shape, dtype=batch_out.dtype.as_numpy_dtype)) - for i, batch_out in enumerate(batch_outs): - outs[i][batch_start:batch_end] = batch_out - if verbose == 1: - progbar.update(batch_end) - - if len(outs) == 1: - return outs[0] - return outs - - -def slice_arrays(arrays, indices, contiguous=True): - """Slices batches out of provided arrays (workaround for eager tensors). - - Unfortunately eager tensors don't have the same slicing behavior as - Numpy arrays (they follow the same slicing behavior as symbolic TF tensors), - hence we cannot use `generic_utils.slice_arrays` directly - and we have to implement this workaround based on `concat`. This has a - performance cost. - - Arguments: - arrays: Single array or list of arrays. - indices: List of indices in the array that should be included in the output - batch. - contiguous: Boolean flag indicating whether the indices are contiguous. - - Returns: - Slice of data (either single array or list of arrays). - """ - if any(tensor_util.is_tensor(x) for x in arrays): - converted_to_list = False - if not isinstance(arrays, list): - converted_to_list = True - arrays = [arrays] - if not contiguous: - entries = [[x[i:i + 1] for i in indices] for x in arrays] - slices = [array_ops.concat(x, axis=0) for x in entries] - else: - slices = [x[indices[0]:indices[-1] + 1] for x in arrays] - if converted_to_list: - slices = slices[0] - return slices - else: - return generic_utils.slice_arrays(arrays, indices) - - def _process_single_batch(model, inputs, targets, @@ -782,16 +488,20 @@ def _process_single_batch(model, set this to False. Returns: - output of the model, total loss and the loss associated with each output. + output of the model, total loss, the loss and the mask + associated with each output. Raises: ValueError: If the model has no loss to optimize. """ with backend.learning_phase_scope(1 if training else 0): with GradientTape() as tape: - outs, loss, loss_metrics = _model_loss(model, inputs, targets, - sample_weights=sample_weights, - training=training) + outs, loss, loss_metrics, masks = _model_loss( + model, + inputs, + targets, + sample_weights=sample_weights, + training=training) if loss is None: raise ValueError('The model cannot be run ' 'because it has no loss to optimize.') @@ -804,7 +514,7 @@ def _process_single_batch(model, grads = tape.gradient(loss, model._collected_trainable_weights) model.optimizer.apply_gradients(zip(grads, model._collected_trainable_weights)) - return outs, loss, loss_metrics + return outs, loss, loss_metrics, masks def train_on_batch(model, inputs, targets, sample_weights=None): @@ -835,14 +545,18 @@ def train_on_batch(model, inputs, targets, sample_weights=None): if val is not None else None for val in sample_weights ] - outs, loss, _ = _process_single_batch( + outs, loss, loss_metrics, masks = _process_single_batch( model, inputs, targets, sample_weights=sample_weights, training=True) if not isinstance(outs, list): outs = [outs] - metrics_results = _eager_metrics_fn(model, outs, targets) - if not isinstance(loss, list): - loss = [loss] - return loss + metrics_results + metrics_results = _eager_metrics_fn( + model, outs, targets, sample_weights=sample_weights, masks=masks) + loss = generic_utils.to_list(loss) + + return [ + tensor_util.constant_value(v) + for v in loss + loss_metrics + metrics_results + ] def test_on_batch(model, inputs, targets, sample_weights=None): @@ -872,14 +586,18 @@ def test_on_batch(model, inputs, targets, sample_weights=None): ops.convert_to_tensor(val, dtype=backend.floatx()) if val is not None else None for val in sample_weights ] - outs, loss, loss_metrics = _model_loss( + outs, loss, loss_metrics, masks = _model_loss( model, inputs, targets, sample_weights=sample_weights, training=False) if not isinstance(outs, list): outs = [outs] - metrics_results = _eager_metrics_fn(model, outs, targets) - if not isinstance(loss, list): - loss = [loss] - return loss + loss_metrics + metrics_results + metrics_results = _eager_metrics_fn( + model, outs, targets, sample_weights=sample_weights, masks=masks) + loss = generic_utils.to_list(loss) + + return [ + tensor_util.constant_value(v) + for v in loss + loss_metrics + metrics_results + ] def fit_loop(model, @@ -895,7 +613,6 @@ def fit_loop(model, verbose=1, callbacks=None, shuffle=True, - callback_metrics=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None): @@ -917,10 +634,6 @@ def fit_loop(model, verbose: Verbosity mode, 0, 1 or 2 callbacks: List of callbacks to be called during training shuffle: Whether to shuffle the data at the beginning of each epoch - callback_metrics: List of strings, the display names of the metrics - passed to the callbacks. They should be the - concatenation of list the display names of the outputs of - `f` and the list of display names of the outputs of `f_val`. initial_epoch: Epoch at which to start training (useful for resuming a previous training run) steps_per_epoch: Total number of steps (batches of samples) @@ -935,123 +648,56 @@ def fit_loop(model, Raises: ValueError: In case of invalid argument values. """ + # Convert training inputs to an EagerIterator + inputs, steps_per_epoch = training_utils.convert_to_iterator( + x=inputs, + y=targets, + sample_weights=sample_weights, + batch_size=batch_size, + steps_per_epoch=steps_per_epoch, + epochs=epochs, + shuffle=shuffle) # Required for eager execution with backend.learning_phase_scope(1): - do_validation = False - if val_inputs: - do_validation = True - if (steps_per_epoch is None and verbose and inputs and - hasattr(inputs[0], 'shape') and hasattr(val_inputs[0], 'shape')): - print('Train on %d samples, validate on %d samples' % - (inputs[0].shape[0], val_inputs[0].shape[0])) - - num_train_samples = None - out_labels = None - if steps_per_epoch is None or model._is_compiled: - out_labels = model.metrics_names - if do_validation: - callback_metrics = copy.copy(out_labels) + [ - 'val_' + n for n in out_labels - ] - else: - callback_metrics = copy.copy(out_labels) + do_validation = val_inputs is not None + callbacks = cbks.configure_callbacks( + callbacks, + model, + do_validation=do_validation, + batch_size=batch_size, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + val_inputs=val_inputs, + val_targets=val_targets, + val_sample_weights=val_sample_weights, + validation_steps=validation_steps, + verbose=verbose) - if steps_per_epoch is None: - if sample_weights: - feed_data = inputs + targets + sample_weights - else: - feed_data = inputs + targets - num_train_samples = training_utils.check_num_samples( - feed_data, - batch_size=batch_size, - steps=steps_per_epoch, - steps_name='steps_per_epoch') - - if num_train_samples is not None: - index_array = np.arange(num_train_samples) - - model.history = cbks.History() - callbacks = [cbks.BaseLogger()] + (callbacks or []) + [model.history] - if verbose: - if steps_per_epoch is not None: - count_mode = 'steps' - else: - count_mode = 'samples' - callbacks += [cbks.ProgbarLogger(count_mode)] - callbacks = cbks.CallbackList(callbacks) - - # it's possible to callback a different model than self - # (used by Sequential models) - if hasattr(model, 'callback_model') and model.callback_model: - callback_model = model.callback_model - else: - callback_model = model - - callbacks.set_model(callback_model) - - callbacks.set_params({ - 'batch_size': batch_size, - 'epochs': epochs, - 'steps': steps_per_epoch, - 'samples': num_train_samples, - 'verbose': verbose, - 'do_validation': do_validation, - 'metrics': callback_metrics or [], - }) callbacks.on_train_begin() - callback_model.stop_training = False - for cbk in callbacks: - if not val_inputs: - cbk.validation_data = [] - elif isinstance(val_inputs, iterator_ops.EagerIterator): - cbk.validation_data = val_inputs - elif val_sample_weights: - cbk.validation_data = val_inputs + val_targets + val_sample_weights - else: - cbk.validation_data = val_inputs + val_targets - for epoch in range(initial_epoch, epochs): + if model._is_compiled: # Model may not be compiled the first time. + # Reset stateful metrics + for m in model.stateful_metric_functions: + m.reset_states() callbacks.on_epoch_begin(epoch) epoch_logs = {} - - if steps_per_epoch is not None: - iterator_fit_loop( - model, - inputs, - class_weight, - steps_per_epoch=steps_per_epoch, - callback_model=callback_model, - out_labels=out_labels, - epoch_logs=epoch_logs, - val_inputs=val_inputs, - val_targets=val_targets, - val_sample_weights=val_sample_weights, - epochs=epochs, - verbose=verbose, - callbacks=callbacks, - callback_metrics=callback_metrics, - validation_steps=validation_steps, - do_validation=do_validation) - else: - batch_fit_loop( - model, - inputs, - targets, - epoch_logs=epoch_logs, - index_array=index_array, - out_labels=out_labels, - callback_model=callback_model, - batch_size=batch_size, - sample_weights=sample_weights, - val_inputs=val_inputs, - val_targets=val_targets, - val_sample_weights=val_sample_weights, - callbacks=callbacks, - shuffle=shuffle, - num_train_samples=num_train_samples, - do_validation=do_validation) + iterator_fit_loop( + model, + inputs, + class_weight, + steps_per_epoch=steps_per_epoch, + epoch_logs=epoch_logs, + val_inputs=val_inputs, + val_targets=val_targets, + val_sample_weights=val_sample_weights, + epochs=epochs, + verbose=verbose, + callbacks=callbacks, + validation_steps=validation_steps, + do_validation=do_validation, + batch_size=batch_size) callbacks.on_epoch_end(epoch, epoch_logs) - if callback_model.stop_training: + if callbacks.model.stop_training: break callbacks.on_train_end() return model.history @@ -1081,23 +727,17 @@ def test_loop(model, inputs, targets, and/or metrics). The attribute `model.metrics_names` will give you the display labels for the scalar outputs. """ + inputs, steps = training_utils.convert_to_iterator( + x=inputs, + y=targets, + sample_weights=sample_weights, + batch_size=batch_size, + steps_per_epoch=steps) with backend.learning_phase_scope(0): - if steps is not None: - return iterator_test_loop(model, inputs, steps, verbose=verbose) - else: - return batch_test_loop( - model, - inputs, - targets, - batch_size=batch_size, - sample_weights=sample_weights, - verbose=verbose) + return iterator_test_loop(model, inputs, steps, verbose=verbose) -def predict_loop(model, inputs, - batch_size=32, - verbose=0, - steps=None): +def predict_loop(model, inputs, batch_size=32, verbose=0, steps=None): """Predict function for eager execution. Arguments: @@ -1115,8 +755,6 @@ def predict_loop(model, inputs, (if the model has multiple outputs). """ with backend.learning_phase_scope(0): - if steps is not None: - return iterator_predict_loop(model, inputs, steps, verbose=verbose) - else: - return batch_predict_loop( - model, inputs, batch_size=batch_size, verbose=verbose) + inputs, steps = training_utils.convert_to_iterator( + x=inputs, batch_size=batch_size, steps_per_epoch=steps) + return iterator_predict_loop(model, inputs, steps, verbose=verbose) diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py index 1571a7782aaae3c3be233c7d4fdd91456a96e3e3..db7ccb181fb5d4c0f151a2736eed461fc4855446 100644 --- a/tensorflow/python/keras/engine/training_eager_test.py +++ b/tensorflow/python/keras/engine/training_eager_test.py @@ -24,291 +24,13 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python import keras from tensorflow.python.framework import ops from tensorflow.python.framework import test_util as tf_test_util -from tensorflow.python.keras import testing_utils +from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.platform import test from tensorflow.python.training.rmsprop import RMSPropOptimizer class TrainingTest(test.TestCase): - def test_fit_on_arrays(self): - a = keras.layers.Input(shape=(3,), name='input_a') - b = keras.layers.Input(shape=(3,), name='input_b') - - dense = keras.layers.Dense(4, name='dense') - c = dense(a) - d = dense(b) - e = keras.layers.Dropout(0.5, name='dropout')(c) - - model = keras.models.Model([a, b], [d, e]) - - optimizer = RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - loss_weights = [1., 0.5] - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights) - - input_a_np = np.random.random((10, 3)) - input_b_np = np.random.random((10, 3)) - - output_d_np = np.random.random((10, 4)) - output_e_np = np.random.random((10, 4)) - - # Test fit at different verbosity - model.fit( - [input_a_np, input_b_np], [output_d_np, output_e_np], - epochs=1, - batch_size=5, - verbose=0) - model.fit( - [input_a_np, input_b_np], [output_d_np, output_e_np], - epochs=1, - batch_size=5, - verbose=1) - model.fit( - [input_a_np, input_b_np], [output_d_np, output_e_np], - epochs=2, - batch_size=5, - verbose=2) - - # Test with validation data - model.fit( - [input_a_np, input_b_np], [output_d_np, output_e_np], - validation_data=([input_a_np, input_b_np], [output_d_np, - output_e_np]), - epochs=1, - batch_size=5, - verbose=0) - model.fit( - [input_a_np, input_b_np], [output_d_np, output_e_np], - validation_data=([input_a_np, input_b_np], [output_d_np, - output_e_np]), - epochs=2, - batch_size=5, - verbose=1) - model.fit( - [input_a_np, input_b_np], [output_d_np, output_e_np], - validation_data=([input_a_np, input_b_np], [output_d_np, - output_e_np]), - epochs=2, - batch_size=5, - verbose=2) - model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np]) - - # Test with validation split - model.fit( - [input_a_np, input_b_np], [output_d_np, output_e_np], - epochs=2, - batch_size=5, - verbose=0, - validation_split=0.2) - - # Test with dictionary inputs - model.fit( - { - 'input_a': input_a_np, - 'input_b': input_b_np - }, {'dense': output_d_np, - 'dropout': output_e_np}, - epochs=1, - batch_size=5, - verbose=0) - model.fit( - { - 'input_a': input_a_np, - 'input_b': input_b_np - }, {'dense': output_d_np, - 'dropout': output_e_np}, - epochs=1, - batch_size=5, - verbose=1) - model.fit( - { - 'input_a': input_a_np, - 'input_b': input_b_np - }, {'dense': output_d_np, - 'dropout': output_e_np}, - validation_data=({'input_a': input_a_np, - 'input_b': input_b_np - }, - { - 'dense': output_d_np, - 'dropout': output_e_np - }), - epochs=1, - batch_size=5, - verbose=0) - model.train_on_batch({ - 'input_a': input_a_np, - 'input_b': input_b_np - }, {'dense': output_d_np, - 'dropout': output_e_np}) - # Test with lists for loss, metrics - loss = ['mae', 'mse'] - metrics = ['acc', 'mae'] - model.compile(optimizer, loss, metrics=metrics) - model.fit( - [input_a_np, input_b_np], [output_d_np, output_e_np], - epochs=1, - batch_size=5, - verbose=0) - - # Test with dictionaries for loss, metrics, loss weights - loss = {'dense': 'mse', 'dropout': 'mae'} - loss_weights = {'dense': 1., 'dropout': 0.5} - metrics = {'dense': 'mse', 'dropout': 'mae'} - model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights) - model.fit( - [input_a_np, input_b_np], [output_d_np, output_e_np], - epochs=1, - batch_size=5, - verbose=0) - - # Invalid use cases - with self.assertRaises(AttributeError): - model.fit( - [input_a_np, input_b_np], [output_d_np, output_e_np], - epochs=1, - validation_data=([input_a_np, input_b_np], 0, 0), - verbose=0) - with self.assertRaises(ValueError): - model.train_on_batch({'input_a': input_a_np}, - [output_d_np, output_e_np]) - with self.assertRaises(ValueError): - model.train_on_batch([input_a_np], [output_d_np, output_e_np]) - with self.assertRaises(AttributeError): - model.train_on_batch(1, [output_d_np, output_e_np]) - with self.assertRaises(ValueError): - model.train_on_batch(input_a_np, [output_d_np, output_e_np]) - with self.assertRaises(ValueError): - bad_input = np.random.random((11, 3)) - model.train_on_batch([bad_input, input_b_np], - [output_d_np, output_e_np]) - with self.assertRaises(ValueError): - bad_target = np.random.random((11, 4)) - model.train_on_batch([input_a_np, input_b_np], - [bad_target, output_e_np]) - - # Build single-input model - x = keras.layers.Input(shape=(3,), name='input_a') - y = keras.layers.Dense(4)(x) - model = keras.models.Model(x, y) - model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse') - # This will work - model.fit([input_a_np], output_d_np, epochs=1) - with self.assertRaises(ValueError): - model.fit([input_a_np, input_a_np], output_d_np, epochs=1) - - def test_evaluate_predict_on_arrays(self): - a = keras.layers.Input(shape=(3,), name='input_a') - b = keras.layers.Input(shape=(3,), name='input_b') - - dense = keras.layers.Dense(4, name='dense') - c = dense(a) - d = dense(b) - e = keras.layers.Dropout(0.5, name='dropout')(c) - - model = keras.models.Model([a, b], [d, e]) - - optimizer = RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - loss_weights = [1., 0.5] - metrics = ['acc', 'mae'] - model.compile( - optimizer, - loss, - metrics=metrics, - loss_weights=loss_weights, - sample_weight_mode=None) - - input_a_np = np.random.random((10, 3)) - input_b_np = np.random.random((10, 3)) - - output_d_np = np.random.random((10, 4)) - output_e_np = np.random.random((10, 4)) - - # Test evaluate at different verbosity - out = model.evaluate( - [input_a_np, input_b_np], [output_d_np, output_e_np], - batch_size=5, - verbose=0) - self.assertEqual(len(out), 7) - out = model.evaluate( - [input_a_np, input_b_np], [output_d_np, output_e_np], - batch_size=5, - verbose=1) - self.assertEqual(len(out), 7) - out = model.evaluate( - [input_a_np, input_b_np], [output_d_np, output_e_np], - batch_size=5, - verbose=2) - self.assertEqual(len(out), 7) - out = model.test_on_batch([input_a_np, input_b_np], - [output_d_np, output_e_np]) - self.assertEqual(len(out), 7) - - # Test evaluate with dictionary inputs - model.evaluate( - { - 'input_a': input_a_np, - 'input_b': input_b_np - }, {'dense': output_d_np, - 'dropout': output_e_np}, - batch_size=5, - verbose=0) - model.evaluate( - { - 'input_a': input_a_np, - 'input_b': input_b_np - }, {'dense': output_d_np, - 'dropout': output_e_np}, - batch_size=5, - verbose=1) - - # Test predict - out = model.predict([input_a_np, input_b_np], batch_size=5) - self.assertEqual(len(out), 2) - out = model.predict({'input_a': input_a_np, 'input_b': input_b_np}) - self.assertEqual(len(out), 2) - out = model.predict_on_batch({ - 'input_a': input_a_np, - 'input_b': input_b_np - }) - self.assertEqual(len(out), 2) - - def test_invalid_loss_or_metrics(self): - num_classes = 5 - train_samples = 1000 - test_samples = 1000 - input_dim = 5 - - model = keras.models.Sequential() - model.add(keras.layers.Dense(10, input_shape=(input_dim,))) - model.add(keras.layers.Activation('relu')) - model.add(keras.layers.Dense(num_classes)) - model.add(keras.layers.Activation('softmax')) - model.compile(loss='categorical_crossentropy', - optimizer=RMSPropOptimizer(learning_rate=0.001)) - np.random.seed(1337) - - (x_train, y_train), (_, _) = testing_utils.get_test_data( - train_samples=train_samples, - test_samples=test_samples, - input_shape=(input_dim,), - num_classes=num_classes) - - with self.assertRaises(ValueError): - model.fit(x_train, np.concatenate([y_train, y_train], axis=-1)) - - with self.assertRaises(TypeError): - model.compile(loss='categorical_crossentropy', - optimizer=RMSPropOptimizer(learning_rate=0.001), - metrics=set(0)) - - with self.assertRaises(ValueError): - model.compile(loss=None, - optimizer='rms') - def test_model_methods_with_eager_tensors_multi_io(self): a = keras.layers.Input(shape=(3,), name='input_a') b = keras.layers.Input(shape=(3,), name='input_b') @@ -323,7 +45,7 @@ class TrainingTest(test.TestCase): optimizer = RMSPropOptimizer(learning_rate=0.001) loss = 'mse' loss_weights = [1., 0.5] - metrics = ['mae'] + metrics = ['mae', metrics_module.CategoricalAccuracy()] model.compile( optimizer, loss, @@ -388,7 +110,7 @@ class TrainingTest(test.TestCase): optimizer = RMSPropOptimizer(learning_rate=0.001) loss = 'mse' - metrics = ['mae'] + metrics = ['mae', metrics_module.CategoricalAccuracy()] model.compile(optimizer, loss, metrics=metrics) inputs = keras.backend.zeros(shape=(10, 3)) @@ -407,7 +129,9 @@ class TrainingTest(test.TestCase): model = keras.Sequential() model.add(keras.layers.Dense(4, input_shape=(3,))) optimizer = RMSPropOptimizer(learning_rate=0.001) - model.compile(optimizer, 'mse', metrics=['mae']) + model.compile( + optimizer, 'mse', metrics=['mae', + metrics_module.CategoricalAccuracy()]) x = np.random.random((10, 3)) y = np.random.random((10, 4)) @@ -422,232 +146,9 @@ class TrainingTest(test.TestCase): self.assertEqual(out.shape, (30, 4)) -class LossWeightingTest(test.TestCase): - - def test_class_weights(self): - num_classes = 5 - batch_size = 5 - weighted_class = 3 - train_samples = 300 - test_samples = 300 - input_dim = 5 - - model = keras.models.Sequential() - model.add(keras.layers.Dense(10, input_shape=(input_dim,))) - model.add(keras.layers.Activation('relu')) - model.add(keras.layers.Dense(num_classes)) - model.add(keras.layers.Activation('softmax')) - model.compile(loss='categorical_crossentropy', - optimizer=RMSPropOptimizer(learning_rate=0.001)) - - np.random.seed(1337) - (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( - train_samples=train_samples, - test_samples=test_samples, - input_shape=(input_dim,), - num_classes=num_classes) - int_y_test = y_test.copy() - int_y_train = y_train.copy() - # convert class vectors to binary class matrices - y_train = keras.utils.to_categorical(y_train, num_classes) - y_test = keras.utils.to_categorical(y_test, num_classes) - test_ids = np.where(int_y_test == np.array(weighted_class))[0] - - class_weight = dict([(i, 1.) for i in range(num_classes)]) - class_weight[weighted_class] = 4. - - sample_weight = np.ones((y_train.shape[0])) - sample_weight[int_y_train == weighted_class] = 4. - - model.fit( - x_train, - y_train, - batch_size=batch_size, - epochs=2, - verbose=0, - class_weight=class_weight, - validation_data=(x_train, y_train, sample_weight)) - model.fit( - x_train, - y_train, - batch_size=batch_size, - epochs=2, - verbose=0, - class_weight=class_weight) - model.fit( - x_train, - y_train, - batch_size=batch_size, - epochs=2, - verbose=0, - class_weight=class_weight, - validation_split=0.1) - - model.train_on_batch( - x_train[:batch_size], y_train[:batch_size], class_weight=class_weight) - ref_score = model.evaluate(x_test, y_test, verbose=0) - score = model.evaluate( - x_test[test_ids, :], y_test[test_ids, :], verbose=0) - self.assertLess(score, ref_score) - - def test_sample_weights(self): - num_classes = 5 - batch_size = 5 - weighted_class = 3 - train_samples = 300 - test_samples = 300 - input_dim = 5 - - model = keras.models.Sequential() - model.add(keras.layers.Dense(10, input_shape=(input_dim,))) - model.add(keras.layers.Activation('relu')) - model.add(keras.layers.Dense(num_classes)) - model.add(keras.layers.Activation('softmax')) - model.compile(loss='categorical_crossentropy', - optimizer=RMSPropOptimizer(learning_rate=0.001)) - - np.random.seed(43) - (x_train, y_train), _ = testing_utils.get_test_data( - train_samples=train_samples, - test_samples=test_samples, - input_shape=(input_dim,), - num_classes=num_classes) - int_y_train = y_train.copy() - y_train = keras.utils.to_categorical(y_train, num_classes) - - class_weight = dict([(i, 1.) for i in range(num_classes)]) - class_weight[weighted_class] = 4. - - sample_weight = np.ones((y_train.shape[0])) - sample_weight[int_y_train == weighted_class] = 4. - - model.fit( - x_train, - y_train, - batch_size=batch_size, - epochs=2, - verbose=0, - sample_weight=sample_weight) - model.fit( - x_train, - y_train, - batch_size=batch_size, - epochs=2, - verbose=0, - sample_weight=sample_weight, - validation_split=0.1) - model.train_on_batch( - x_train[:batch_size], - y_train[:batch_size], - sample_weight=sample_weight[:batch_size]) - model.test_on_batch( - x_train[:batch_size], - y_train[:batch_size], - sample_weight=sample_weight[:batch_size]) - - def test_temporal_sample_weights(self): - num_classes = 5 - weighted_class = 3 - train_samples = 1000 - test_samples = 1000 - input_dim = 5 - timesteps = 3 - - model = keras.models.Sequential() - model.add( - keras.layers.TimeDistributed( - keras.layers.Dense(num_classes), - input_shape=(timesteps, input_dim))) - model.add(keras.layers.Activation('softmax')) - - np.random.seed(1337) - (_, y_train), _ = testing_utils.get_test_data( - train_samples=train_samples, - test_samples=test_samples, - input_shape=(input_dim,), - num_classes=num_classes) - int_y_train = y_train.copy() - # convert class vectors to binary class matrices - y_train = keras.utils.to_categorical(y_train, num_classes) - - class_weight = dict([(i, 1.) for i in range(num_classes)]) - class_weight[weighted_class] = 2. - - sample_weight = np.ones((y_train.shape[0])) - sample_weight[int_y_train == weighted_class] = 2. - with self.assertRaises(ValueError): - model.compile( - loss='binary_crossentropy', - optimizer=RMSPropOptimizer(learning_rate=0.001), - sample_weight_mode='temporal') - - def test_class_weight_invalid_use_case(self): - num_classes = 5 - train_samples = 1000 - test_samples = 1000 - input_dim = 5 - timesteps = 3 - - model = keras.models.Sequential() - model.add( - keras.layers.TimeDistributed( - keras.layers.Dense(num_classes), - input_shape=(timesteps, input_dim))) - model.add(keras.layers.Activation('softmax')) - model.compile( - loss='binary_crossentropy', - optimizer=RMSPropOptimizer(learning_rate=0.001)) - - (x_train, y_train), _ = testing_utils.get_test_data( - train_samples=train_samples, - test_samples=test_samples, - input_shape=(input_dim,), - num_classes=num_classes) - # convert class vectors to binary class matrices - y_train = keras.utils.to_categorical(y_train, num_classes) - class_weight = dict([(i, 1.) for i in range(num_classes)]) - - del class_weight[1] - with self.assertRaises(ValueError): - model.fit(x_train, y_train, - epochs=0, verbose=0, class_weight=class_weight) - - with self.assertRaises(ValueError): - model.compile( - loss='binary_crossentropy', - optimizer=RMSPropOptimizer(learning_rate=0.001), - sample_weight_mode=[]) - - # Build multi-output model - x = keras.Input((3,)) - y1 = keras.layers.Dense(4, name='1')(x) - y2 = keras.layers.Dense(4, name='2')(x) - model = keras.models.Model(x, [y1, y2]) - model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse') - x_np = np.random.random((10, 3)) - y_np = np.random.random((10, 4)) - w_np = np.random.random((10,)) - # This will work - model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': w_np}) - # These will not - with self.assertRaises(ValueError): - model.fit(x_np, [y_np, y_np], epochs=1, sample_weight=[w_np]) - with self.assertRaises(TypeError): - model.fit(x_np, [y_np, y_np], epochs=1, sample_weight=w_np) - with self.assertRaises(ValueError): - bad_w_np = np.random.random((11,)) - model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np}) - with self.assertRaises(ValueError): - bad_w_np = np.random.random((10, 2)) - model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np}) - with self.assertRaises(ValueError): - bad_w_np = np.random.random((10, 2, 2)) - model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np}) - - class CorrectnessTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_loss_correctness(self): # Test that training loss is the same in eager and graph # (by comparing it to a reference value in a deterministic case) @@ -668,28 +169,7 @@ class CorrectnessTest(test.TestCase): self.assertEqual( np.around(history.history['loss'][-1], decimals=4), 0.6173) - @tf_test_util.run_in_graph_and_eager_modes() - def test_metrics_correctness(self): - model = keras.Sequential() - model.add(keras.layers.Dense(3, - activation='relu', - input_dim=4, - kernel_initializer='ones')) - model.add(keras.layers.Dense(1, - activation='sigmoid', - kernel_initializer='ones')) - model.compile(loss='mae', - metrics=['acc'], - optimizer=RMSPropOptimizer(learning_rate=0.001)) - x = np.ones((100, 4)) - y = np.ones((100, 1)) - outs = model.evaluate(x, y) - self.assertEqual(outs[1], 1.) - y = np.zeros((100, 1)) - outs = model.evaluate(x, y) - self.assertEqual(outs[1], 0.) - - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_loss_correctness_with_iterator(self): # Test that training loss is the same in eager and graph # (by comparing it to a reference value in a deterministic case) @@ -712,35 +192,6 @@ class CorrectnessTest(test.TestCase): history = model.fit(iterator, epochs=1, steps_per_epoch=10) self.assertEqual(np.around(history.history['loss'][-1], decimals=4), 0.6173) - @tf_test_util.run_in_graph_and_eager_modes() - def test_metrics_correctness_with_iterator(self): - model = keras.Sequential() - model.add( - keras.layers.Dense( - 8, activation='relu', input_dim=4, kernel_initializer='ones')) - model.add( - keras.layers.Dense(1, activation='sigmoid', kernel_initializer='ones')) - model.compile( - loss='binary_crossentropy', - metrics=['accuracy'], - optimizer=RMSPropOptimizer(learning_rate=0.001)) - np.random.seed(123) - x = np.random.randint(10, size=(100, 4)).astype(np.float32) - y = np.random.randint(2, size=(100, 1)).astype(np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) - dataset = dataset.batch(10) - iterator = dataset.make_one_shot_iterator() - outs = model.evaluate(iterator, steps=10) - self.assertEqual(np.around(outs[1], decimals=1), 0.5) - - y = np.zeros((100, 1), dtype=np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) - dataset = dataset.repeat(100) - dataset = dataset.batch(10) - iterator = dataset.make_one_shot_iterator() - outs = model.evaluate(iterator, steps=10) - self.assertEqual(outs[1], 0.) - if __name__ == '__main__': ops.enable_eager_execution() diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py index d81b384f0e1810614bd98e3861b4324f0f8a4dca..413c1f4fbaba63d173de2c1d1c9943e919b05719 100644 --- a/tensorflow/python/keras/engine/training_generator.py +++ b/tensorflow/python/keras/engine/training_generator.py @@ -21,7 +21,6 @@ from __future__ import print_function import numpy as np -from tensorflow.python.keras import backend as K from tensorflow.python.keras import callbacks as cbks from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer from tensorflow.python.keras.utils.data_utils import OrderedEnqueuer @@ -79,55 +78,37 @@ def fit_generator(model, ' class. Please specify `validation_steps` or use' ' the `keras.utils.Sequence` class.') - # Prepare display labels. - out_labels = model.metrics_names - callback_metrics = out_labels + ['val_%s' % n for n in out_labels] - - # prepare callbacks - model.history = cbks.History() - callbacks = [cbks.BaseLogger()] + (callbacks or []) + [model.history] - if verbose: - callbacks += [cbks.ProgbarLogger(count_mode='steps')] - callbacks = cbks.CallbackList(callbacks) - - # it's possible to callback a different model than self: - if hasattr(model, 'callback_model') and model.callback_model: - callback_model = model.callback_model - else: - callback_model = model - callbacks.set_model(callback_model) - callbacks.set_params({ - 'epochs': epochs, - 'steps': steps_per_epoch, - 'verbose': verbose, - 'do_validation': do_validation, - 'metrics': callback_metrics, - }) - callbacks.on_train_begin() - enqueuer = None val_enqueuer = None try: + val_x, val_y, val_sample_weights = validation_data, None, None if do_validation and not val_gen: # Prepare data for validation if len(validation_data) == 2: val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence - val_sample_weight = None + val_sample_weights = None elif len(validation_data) == 3: - val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence + val_x, val_y, val_sample_weights = validation_data # pylint: disable=unpacking-non-sequence else: raise ValueError( '`validation_data` should be a tuple ' '`(val_x, val_y, val_sample_weight)` ' 'or `(val_x, val_y)`. Found: ' + str(validation_data)) val_x, val_y, val_sample_weights = model._standardize_user_data( - val_x, val_y, val_sample_weight) - val_data = val_x + val_y + val_sample_weights - if model.uses_learning_phase and not isinstance(K.learning_phase(), int): - val_data += [0.] - for cbk in callbacks: - cbk.validation_data = val_data + val_x, val_y, val_sample_weights) + + callbacks = cbks.configure_callbacks( + callbacks, + model, + do_validation=do_validation, + val_inputs=val_x, + val_targets=val_y, + val_sample_weights=val_sample_weights, + epochs=epochs, + validation_steps=validation_steps, + steps_per_epoch=steps_per_epoch, + verbose=verbose) if workers > 0: if is_sequence: @@ -148,7 +129,7 @@ def fit_generator(model, else: output_generator = generator - callback_model.stop_training = False + callbacks.on_train_begin() # Construct epoch logs. epoch_logs = {} while epoch < epochs: @@ -191,7 +172,7 @@ def fit_generator(model, if not isinstance(outs, list): outs = [outs] - for l, o in zip(out_labels, outs): + for l, o in zip(model.metrics_names, outs): batch_logs[l] = o callbacks.on_batch_end(batch_index, batch_logs) @@ -221,15 +202,15 @@ def fit_generator(model, if not isinstance(val_outs, list): val_outs = [val_outs] # Same labels assumed. - for l, o in zip(out_labels, val_outs): + for l, o in zip(model.metrics_names, val_outs): epoch_logs['val_' + l] = o - if callback_model.stop_training: + if callbacks.model.stop_training: break callbacks.on_epoch_end(epoch, epoch_logs) epoch += 1 - if callback_model.stop_training: + if callbacks.model.stop_training: break finally: @@ -252,7 +233,6 @@ def evaluate_generator(model, use_multiprocessing=False, verbose=0): """See docstring for `Model.evaluate_generator`.""" - stateful_metric_indices = [] if hasattr(model, 'metrics'): for m in model.stateful_metric_functions: m.reset_states() @@ -350,7 +330,7 @@ def evaluate_generator(model, averages.append( np.average([out[i] for out in all_outs], weights=batch_sizes)) else: - averages.append(float(all_outs[-1][i])) + averages.append(np.float64(all_outs[-1][i])) return averages diff --git a/tensorflow/python/keras/engine/training_gpu_test.py b/tensorflow/python/keras/engine/training_gpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5825ce814fd84bf59637f6079e7402d752e2b77b --- /dev/null +++ b/tensorflow/python/keras/engine/training_gpu_test.py @@ -0,0 +1,125 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for training routines.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python import keras +from tensorflow.python.framework import test_util +from tensorflow.python.keras import backend as K +from tensorflow.python.keras.layers.convolutional import Conv2D +from tensorflow.python.platform import test +from tensorflow.python.training import rmsprop + + +class TrainingGPUTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes + def test_model_with_crossentropy_losses_channels_first(self): + """Tests use of all crossentropy losses with `channels_first`. + + Tests `sparse_categorical_crossentropy`, `categorical_crossentropy`, + and `binary_crossentropy`. + Verifies that evaluate gives the same result with either `channels_first` + or `channels_last` image_data_format. + """ + def prepare_simple_model(input_tensor, loss_name, target): + axis = 1 if K.image_data_format() == 'channels_first' else -1 + loss = None + num_channels = None + activation = None + if loss_name == 'sparse_categorical_crossentropy': + loss = lambda y_true, y_pred: K.sparse_categorical_crossentropy( # pylint: disable=g-long-lambda + y_true, y_pred, axis=axis) + num_channels = np.amax(target) + 1 + activation = 'softmax' + elif loss_name == 'categorical_crossentropy': + loss = lambda y_true, y_pred: K.categorical_crossentropy( # pylint: disable=g-long-lambda + y_true, y_pred, axis=axis) + num_channels = target.shape[axis] + activation = 'softmax' + elif loss_name == 'binary_crossentropy': + loss = lambda y_true, y_pred: K.binary_crossentropy(y_true, y_pred) # pylint: disable=unnecessary-lambda + num_channels = target.shape[axis] + activation = 'sigmoid' + predictions = Conv2D(num_channels, + 1, + activation=activation, + kernel_initializer='ones', + bias_initializer='ones')(input_tensor) + simple_model = keras.models.Model(inputs=input_tensor, + outputs=predictions) + simple_model.compile(optimizer=rmsprop.RMSPropOptimizer(1e-3), loss=loss) + return simple_model + + if test.is_gpu_available(cuda_only=True): + with self.test_session(use_gpu=True): + losses_to_test = ['sparse_categorical_crossentropy', + 'categorical_crossentropy', 'binary_crossentropy'] + + data_channels_first = np.array([[[[8., 7.1, 0.], [4.5, 2.6, 0.55], + [0.9, 4.2, 11.2]]]], dtype=np.float32) + # Labels for testing 4-class sparse_categorical_crossentropy, 4-class + # categorical_crossentropy, and 2-class binary_crossentropy: + labels_channels_first = [np.array([[[[0, 1, 3], [2, 1, 0], [2, 2, 1]]]], dtype=np.float32), # pylint: disable=line-too-long + np.array([[[[0, 1, 0], [0, 1, 0], [0, 0, 0]], + [[1, 0, 0], [0, 0, 1], [0, 1, 0]], + [[0, 0, 0], [1, 0, 0], [0, 0, 1]], + [[0, 0, 1], [0, 0, 0], [1, 0, 0]]]], dtype=np.float32), # pylint: disable=line-too-long + np.array([[[[0, 1, 0], [0, 1, 0], [0, 0, 1]], + [[1, 0, 1], [1, 0, 1], [1, 1, 0]]]], dtype=np.float32)] # pylint: disable=line-too-long + # Compute one loss for each loss function in the list `losses_to_test`: + loss_channels_last = [0., 0., 0.] + loss_channels_first = [0., 0., 0.] + + old_data_format = K.image_data_format() + + # Evaluate a simple network with channels last, with all three loss + # functions: + K.set_image_data_format('channels_last') + data = np.moveaxis(data_channels_first, 1, -1) + for index, loss_function in enumerate(losses_to_test): + labels = np.moveaxis(labels_channels_first[index], 1, -1) + inputs = keras.Input(shape=(3, 3, 1)) + model = prepare_simple_model(inputs, loss_function, labels) + loss_channels_last[index] = model.evaluate(x=data, y=labels, + batch_size=1, verbose=0) + + # Evaluate the same network with channels first, with all three loss + # functions: + K.set_image_data_format('channels_first') + data = data_channels_first + for index, loss_function in enumerate(losses_to_test): + labels = labels_channels_first[index] + inputs = keras.Input(shape=(1, 3, 3)) + model = prepare_simple_model(inputs, loss_function, labels) + loss_channels_first[index] = model.evaluate(x=data, y=labels, + batch_size=1, verbose=0) + + K.set_image_data_format(old_data_format) + + np.testing.assert_allclose(loss_channels_first, + loss_channels_last, + err_msg='{}{}'.format( + 'Computed different losses for ', + 'channels_first and channels_last')) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index a1ab72018957a26984f9b7b1ccba9a128a136866..15e7d725dea80a746867d769baff9ec77f4f3fe1 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import logging import os import unittest @@ -25,9 +26,11 @@ import numpy as np from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util as tf_test_util +from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine.training_utils import weighted_masked_objective from tensorflow.python.keras.utils.generic_utils import slice_arrays @@ -44,6 +47,7 @@ except ImportError: class TrainingTest(test.TestCase): + @tf_test_util.run_in_graph_and_eager_modes def test_fit_on_arrays(self): with self.test_session(): a = keras.layers.Input(shape=(3,), name='input_a') @@ -56,11 +60,14 @@ class TrainingTest(test.TestCase): model = keras.models.Model([a, b], [d, e]) - optimizer = 'rmsprop' + optimizer = RMSPropOptimizer(learning_rate=0.001) loss = 'mse' loss_weights = [1., 0.5] - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights) + model.compile( + optimizer, + loss, + metrics=[metrics_module.CategoricalAccuracy(), 'mae'], + loss_weights=loss_weights) input_a_np = np.random.random((10, 3)) input_b_np = np.random.random((10, 3)) @@ -175,8 +182,10 @@ class TrainingTest(test.TestCase): # Test with lists for loss, metrics loss = ['mae', 'mse'] - metrics = ['acc', 'mae'] - model.compile(optimizer, loss, metrics=metrics) + model.compile( + optimizer, + loss, + metrics=[metrics_module.CategoricalAccuracy(), 'mae']) model.fit( [input_a_np, input_b_np], [output_d_np, output_e_np], epochs=1, @@ -186,7 +195,10 @@ class TrainingTest(test.TestCase): # Test with dictionaries for loss, metrics, loss weights loss = {'dense': 'mse', 'dropout': 'mae'} loss_weights = {'dense': 1., 'dropout': 0.5} - metrics = {'dense': 'mse', 'dropout': 'mae'} + metrics = { + 'dense': 'mse', + 'dropout': metrics_module.CategoricalAccuracy() + } model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights) model.fit( [input_a_np, input_b_np], [output_d_np, output_e_np], @@ -223,7 +235,7 @@ class TrainingTest(test.TestCase): x = keras.layers.Input(shape=(3,), name='input_a') y = keras.layers.Dense(4)(x) model = keras.models.Model(x, y) - model.compile(optimizer='rmsprop', loss='mse') + model.compile(optimizer, loss='mse') # This will work model.fit([input_a_np], output_d_np, epochs=1) with self.assertRaises(ValueError): @@ -239,6 +251,7 @@ class TrainingTest(test.TestCase): batch_size=5, verbose=2) + @tf_test_util.run_in_graph_and_eager_modes def test_evaluate_predict_on_arrays(self): with self.test_session(): a = keras.layers.Input(shape=(3,), name='input_a') @@ -251,14 +264,13 @@ class TrainingTest(test.TestCase): model = keras.models.Model([a, b], [d, e]) - optimizer = 'rmsprop' + optimizer = RMSPropOptimizer(learning_rate=0.001) loss = 'mse' loss_weights = [1., 0.5] - metrics = ['mae'] model.compile( optimizer, loss, - metrics=metrics, + metrics=['mae', metrics_module.CategoricalAccuracy()], loss_weights=loss_weights, sample_weight_mode=None) @@ -273,20 +285,20 @@ class TrainingTest(test.TestCase): [input_a_np, input_b_np], [output_d_np, output_e_np], batch_size=5, verbose=0) - self.assertEqual(len(out), 5) + self.assertEqual(len(out), 7) out = model.evaluate( [input_a_np, input_b_np], [output_d_np, output_e_np], batch_size=5, verbose=1) - self.assertEqual(len(out), 5) + self.assertEqual(len(out), 7) out = model.evaluate( [input_a_np, input_b_np], [output_d_np, output_e_np], batch_size=5, verbose=2) - self.assertEqual(len(out), 5) + self.assertEqual(len(out), 7) out = model.test_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np]) - self.assertEqual(len(out), 5) + self.assertEqual(len(out), 7) # Test evaluate with dictionary inputs model.evaluate( @@ -321,7 +333,8 @@ class TrainingTest(test.TestCase): }) self.assertEqual(len(out), 2) - def test_invalid_loss_or_metrics(self): + @tf_test_util.run_in_graph_and_eager_modes + def test_invalid_loss(self): num_classes = 5 train_samples = 1000 test_samples = 1000 @@ -333,27 +346,25 @@ class TrainingTest(test.TestCase): model.add(keras.layers.Activation('relu')) model.add(keras.layers.Dense(num_classes)) model.add(keras.layers.Activation('softmax')) - model.compile(loss='categorical_crossentropy', optimizer='rmsprop') + optimizer = RMSPropOptimizer(learning_rate=0.001) + model.compile(optimizer, loss='categorical_crossentropy') np.random.seed(1337) (x_train, y_train), (_, _) = testing_utils.get_test_data( train_samples=train_samples, test_samples=test_samples, input_shape=(input_dim,), num_classes=num_classes) - with self.assertRaises(ValueError): - model.fit(x_train, y_train) with self.assertRaises(ValueError): model.fit(x_train, np.concatenate([y_train, y_train], axis=-1)) - with self.assertRaises(TypeError): - model.compile(loss='categorical_crossentropy', - optimizer='rmsprop', - metrics=set(0)) + if not context.executing_eagerly(): + # TODO(psv): Investigate these use cases in eager mode. + with self.assertRaises(ValueError): + model.fit(x_train, y_train) - with self.assertRaises(ValueError): - model.compile(loss=None, - optimizer='rmsprop') + with self.assertRaises(ValueError): + model.compile(optimizer, loss=None) def test_training_on_sparse_data_with_dense_placeholders(self): if scipy_sparse is None: @@ -372,7 +383,11 @@ class TrainingTest(test.TestCase): out2 = keras.layers.Dense(4, name='dense_1')(in2) model = keras.Model([in1, in2], [out1, out2]) model.predict(test_inputs, batch_size=2) - model.compile('rmsprop', 'mse') + optimizer = RMSPropOptimizer(learning_rate=0.001) + model.compile( + optimizer, + 'mse', + metrics=['mae', metrics_module.CategoricalAccuracy()]) model.fit(test_inputs, test_outputs, epochs=1, batch_size=2, validation_split=0.5) model.evaluate(test_inputs, test_outputs, batch_size=2) @@ -415,9 +430,34 @@ class TrainingTest(test.TestCase): x2 = model.predict(val_a) self.assertAllClose(x1, x2, atol=1e-7) + @tf_test_util.run_in_graph_and_eager_modes + def test_compile_warning_for_loss_missing_output(self): + with self.test_session(): + inp = keras.layers.Input(shape=(16,), name='input_a') + out_1 = keras.layers.Dense(8, name='dense_1')(inp) + out_2 = keras.layers.Dense(3, activation='softmax', name='dense_2')(out_1) + model = keras.models.Model(inputs=[inp], outputs=[out_1, out_2]) + optimizer = RMSPropOptimizer(learning_rate=0.001) + + with test.mock.patch.object(logging, 'warning') as mock_log: + model.compile( + optimizer, + loss={ + 'dense_2': 'categorical_crossentropy', + }, + metrics={ + 'dense_2': 'categorical_accuracy', + 'dense_1': metrics_module.CategoricalAccuracy(), + }) + msg = ('Output "dense_1" missing from loss dictionary. We assume this ' + 'was done on purpose. The fit and evaluate APIs will not be ' + 'expecting any data to be passed to "dense_1".') + self.assertRegexpMatches(str(mock_log.call_args), msg) + class LossWeightingTest(test.TestCase): + @tf_test_util.run_in_graph_and_eager_modes def test_class_weights(self): num_classes = 5 batch_size = 5 @@ -426,6 +466,7 @@ class LossWeightingTest(test.TestCase): train_samples = 1000 test_samples = 1000 input_dim = 5 + learning_rate = 0.001 with self.test_session(): model = keras.models.Sequential() @@ -433,7 +474,11 @@ class LossWeightingTest(test.TestCase): model.add(keras.layers.Activation('relu')) model.add(keras.layers.Dense(num_classes)) model.add(keras.layers.Activation('softmax')) - model.compile(loss='categorical_crossentropy', optimizer='rmsprop') + model.compile( + loss='categorical_crossentropy', + metrics=['acc'], + weighted_metrics=['mae'], + optimizer=RMSPropOptimizer(learning_rate=learning_rate)) np.random.seed(1337) (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( @@ -483,8 +528,9 @@ class LossWeightingTest(test.TestCase): ref_score = model.evaluate(x_test, y_test, verbose=0) score = model.evaluate( x_test[test_ids, :], y_test[test_ids, :], verbose=0) - self.assertLess(score, ref_score) + self.assertLess(score[0], ref_score[0]) + @tf_test_util.run_in_graph_and_eager_modes def test_sample_weights(self): num_classes = 5 batch_size = 5 @@ -493,6 +539,7 @@ class LossWeightingTest(test.TestCase): train_samples = 1000 test_samples = 1000 input_dim = 5 + learning_rate = 0.001 with self.test_session(): model = keras.models.Sequential() @@ -500,7 +547,11 @@ class LossWeightingTest(test.TestCase): model.add(keras.layers.Activation('relu')) model.add(keras.layers.Dense(num_classes)) model.add(keras.layers.Activation('softmax')) - model.compile(loss='categorical_crossentropy', optimizer='rmsprop') + model.compile( + RMSPropOptimizer(learning_rate=learning_rate), + metrics=['acc'], + weighted_metrics=['mae'], + loss='categorical_crossentropy') np.random.seed(43) (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( @@ -515,9 +566,6 @@ class LossWeightingTest(test.TestCase): y_test = keras.utils.to_categorical(y_test, num_classes) test_ids = np.where(int_y_test == np.array(weighted_class))[0] - class_weight = dict([(i, 1.) for i in range(num_classes)]) - class_weight[weighted_class] = 2. - sample_weight = np.ones((y_train.shape[0])) sample_weight[int_y_train == weighted_class] = 2. @@ -546,10 +594,12 @@ class LossWeightingTest(test.TestCase): y_train[:batch_size], sample_weight=sample_weight[:batch_size]) ref_score = model.evaluate(x_test, y_test, verbose=0) - score = model.evaluate( - x_test[test_ids, :], y_test[test_ids, :], verbose=0) - self.assertLess(score, ref_score) + if not context.executing_eagerly(): + score = model.evaluate( + x_test[test_ids, :], y_test[test_ids, :], verbose=0) + self.assertLess(score[0], ref_score[0]) + @tf_test_util.run_in_graph_and_eager_modes def test_temporal_sample_weights(self): num_classes = 5 batch_size = 5 @@ -559,6 +609,7 @@ class LossWeightingTest(test.TestCase): test_samples = 1000 input_dim = 5 timesteps = 3 + learning_rate = 0.001 with self.test_session(): model = keras.models.Sequential() @@ -581,9 +632,6 @@ class LossWeightingTest(test.TestCase): y_test = keras.utils.to_categorical(y_test, num_classes) test_ids = np.where(int_y_test == np.array(weighted_class))[0] - class_weight = dict([(i, 1.) for i in range(num_classes)]) - class_weight[weighted_class] = 2. - sample_weight = np.ones((y_train.shape[0])) sample_weight[int_y_train == weighted_class] = 2. @@ -605,8 +653,10 @@ class LossWeightingTest(test.TestCase): temporal_sample_weight, timesteps, axis=1) model.compile( + RMSPropOptimizer(learning_rate=learning_rate), loss='binary_crossentropy', - optimizer='rmsprop', + metrics=['acc'], + weighted_metrics=['mae'], sample_weight_mode='temporal') model.fit( @@ -634,16 +684,19 @@ class LossWeightingTest(test.TestCase): temporal_y_train[:batch_size], sample_weight=temporal_sample_weight[:batch_size]) ref_score = model.evaluate(temporal_x_test, temporal_y_test, verbose=0) - score = model.evaluate( - temporal_x_test[test_ids], temporal_y_test[test_ids], verbose=0) - self.assertLess(score, ref_score) + if not context.executing_eagerly(): + score = model.evaluate( + temporal_x_test[test_ids], temporal_y_test[test_ids], verbose=0) + self.assertLess(score[0], ref_score[0]) + @tf_test_util.run_in_graph_and_eager_modes def test_class_weight_invalid_use_case(self): num_classes = 5 train_samples = 1000 test_samples = 1000 input_dim = 5 timesteps = 3 + learning_rate = 0.001 with self.test_session(): model = keras.models.Sequential() @@ -652,9 +705,8 @@ class LossWeightingTest(test.TestCase): keras.layers.Dense(num_classes), input_shape=(timesteps, input_dim))) model.add(keras.layers.Activation('softmax')) - model.compile( - loss='binary_crossentropy', - optimizer='rmsprop') + optimizer = RMSPropOptimizer(learning_rate=learning_rate) + model.compile(optimizer, loss='binary_crossentropy') (x_train, y_train), _ = testing_utils.get_test_data( train_samples=train_samples, @@ -672,16 +724,14 @@ class LossWeightingTest(test.TestCase): with self.assertRaises(ValueError): model.compile( - loss='binary_crossentropy', - optimizer='rmsprop', - sample_weight_mode=[]) + optimizer, loss='binary_crossentropy', sample_weight_mode=[]) # Build multi-output model x = keras.Input((3,)) y1 = keras.layers.Dense(4, name='1')(x) y2 = keras.layers.Dense(4, name='2')(x) model = keras.models.Model(x, [y1, y2]) - model.compile(optimizer='rmsprop', loss='mse') + model.compile(optimizer, loss='mse') x_np = np.random.random((10, 3)) y_np = np.random.random((10, 4)) w_np = np.random.random((10,)) @@ -708,22 +758,127 @@ class LossWeightingTest(test.TestCase): model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np}) + @tf_test_util.run_in_graph_and_eager_modes + def test_default_sample_weight(self): + """Verifies that fit works without having to set sample_weight.""" + + num_classes = 5 + input_dim = 5 + timesteps = 3 + learning_rate = 0.001 + + with self.test_session(): + model = keras.models.Sequential() + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(num_classes), + input_shape=(timesteps, input_dim))) + + x = np.random.random((10, timesteps, input_dim)) + y = np.random.random((10, timesteps, num_classes)) + optimizer = RMSPropOptimizer(learning_rate=learning_rate) + + # sample_weight_mode is a list and mode value is None + model.compile(optimizer, loss='mse', sample_weight_mode=[None]) + model.fit(x, y, epochs=1, batch_size=10) + + # sample_weight_mode is a list and mode value is `temporal` + model.compile(optimizer, loss='mse', sample_weight_mode=['temporal']) + model.fit(x, y, epochs=1, batch_size=10) + + # sample_weight_mode is a dict and mode value is None + model.compile( + optimizer, loss='mse', sample_weight_mode={'time_distributed': None}) + model.fit(x, y, epochs=1, batch_size=10) + + # sample_weight_mode is a dict and mode value is `temporal` + model.compile( + optimizer, + loss='mse', + sample_weight_mode={'time_distributed': 'temporal'}) + model.fit(x, y, epochs=1, batch_size=10) + + # sample_weight_mode is a not a list/dict and mode value is None + model.compile(optimizer, loss='mse', sample_weight_mode=None) + model.fit(x, y, epochs=1, batch_size=10) + + # sample_weight_mode is a not a list/dict and mode value is `temporal` + model.compile(optimizer, loss='mse', sample_weight_mode='temporal') + model.fit(x, y, epochs=1, batch_size=10) + class LossMaskingTest(test.TestCase): - def test_masking(self): + @tf_test_util.run_in_graph_and_eager_modes + def test_masking_graph_sequential(self): with self.test_session(): - np.random.seed(1337) x = np.array([[[1], [1]], [[0], [0]]]) model = keras.models.Sequential() model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) model.add( keras.layers.TimeDistributed( keras.layers.Dense(1, kernel_initializer='one'))) - model.compile(loss='mse', optimizer='sgd') + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) y = np.array([[[1], [1]], [[1], [1]]]) loss = model.train_on_batch(x, y) - self.assertEqual(loss, 0) + self.assertEqual(float(loss), 0.) + + @tf_test_util.run_in_graph_and_eager_modes + def test_masking_deferred_sequential(self): + with self.test_session(): + x = np.array([[[1], [1]], [[0], [0]]]) + model = keras.models.Sequential() + model.add(keras.layers.Masking(mask_value=0)) + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(1, kernel_initializer='one'))) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + y = np.array([[[1], [1]], [[1], [1]]]) + loss = model.train_on_batch(x, y) + self.assertEqual(float(loss), 0.) + + @tf_test_util.run_in_graph_and_eager_modes + def test_masking_functional(self): + with self.test_session(): + x = np.array([[[1], [1]], [[0], [0]]]) + inputs = keras.layers.Input((2, 1)) + outputs = keras.layers.Masking(mask_value=0)(inputs) + outputs = keras.layers.TimeDistributed( + keras.layers.Dense(1, kernel_initializer='one'))(outputs) + model = keras.Model(inputs, outputs) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + y = np.array([[[1], [1]], [[1], [1]]]) + loss = model.train_on_batch(x, y) + self.assertEqual(float(loss), 0.) + + @tf_test_util.run_in_graph_and_eager_modes + def test_mask_argument_in_layer(self): + # Test that the mask argument gets correctly passed to a layer in the + # functional API. + + class CustomMaskedLayer(keras.layers.Layer): + + def __init__(self): + super(CustomMaskedLayer, self).__init__() + self.supports_masking = True + + def call(self, inputs, mask=None): + assert mask is not None + return inputs + + def compute_output_shape(self, input_shape): + return input_shape + + with self.test_session(): + x = np.random.random((5, 3)) + inputs = keras.layers.Input((3,)) + masked = keras.layers.Masking(mask_value=0)(inputs) + outputs = CustomMaskedLayer()(masked) + + model = keras.Model(inputs, outputs) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + y = np.random.random((5, 3)) + model.train_on_batch(x, y) def test_loss_masking(self): with self.test_session(): @@ -744,6 +899,22 @@ class LossMaskingTest(test.TestCase): keras.backend.variable(weights), keras.backend.variable(mask))) +class LearningPhaseTest(test.TestCase): + + def test_empty_model_no_learning_phase(self): + with self.test_session(): + model = keras.models.Sequential() + self.assertFalse(model.uses_learning_phase) + + def test_dropout_has_learning_phase(self): + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_dim=3)) + model.add(keras.layers.Dropout(0.5)) + model.add(keras.layers.Dense(2)) + self.assertTrue(model.uses_learning_phase) + + class TestDynamicTrainability(test.TestCase): def test_trainable_warning(self): @@ -911,7 +1082,10 @@ class TestGeneratorMethods(test.TestCase): x = keras.Input((2,)) y = keras.layers.Dense(1)(x) fn_model = keras.models.Model(x, y) - fn_model.compile(loss='mse', optimizer='sgd') + fn_model.compile( + loss='mse', + optimizer='sgd', + metrics=['mae', metrics_module.CategoricalAccuracy()]) seq_model = keras.models.Sequential() seq_model.add(keras.layers.Dense(1, input_shape=(2,))) @@ -993,7 +1167,10 @@ class TestGeneratorMethods(test.TestCase): with self.test_session(): model = keras.models.Sequential() model.add(keras.layers.Dense(1, input_shape=(2,))) - model.compile(loss='mse', optimizer='sgd') + model.compile( + loss='mse', + optimizer='sgd', + metrics=['mae', metrics_module.CategoricalAccuracy()]) model.fit_generator(custom_generator(), steps_per_epoch=5, @@ -1145,10 +1322,12 @@ class TestTrainingWithDataTensors(test.TestCase): y = keras.layers.Dense(4, name='dense')(x) model = keras.Model(x, y) - optimizer = 'rmsprop' + optimizer = RMSPropOptimizer(learning_rate=0.001) loss = 'mse' - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics) + model.compile( + optimizer, + loss, + metrics=['mae', metrics_module.CategoricalAccuracy()]) inputs = keras.backend.zeros(shape=(10, 3)) targets = keras.backend.zeros(shape=(10, 4)) @@ -1192,8 +1371,11 @@ class TestTrainingWithDataTensors(test.TestCase): optimizer = 'rmsprop' loss = 'mse' loss_weights = [1., 0.5] - metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights) + model.compile( + optimizer, + loss, + metrics=['mae', metrics_module.CategoricalAccuracy()], + loss_weights=loss_weights) input_a_tf = keras.backend.zeros(shape=(10, 3)) input_b_tf = keras.backend.zeros(shape=(10, 3)) @@ -1631,8 +1813,11 @@ class TestTrainingWithDataTensors(test.TestCase): model.train_on_batch(input_val, None) # test with sample weights - model.compile(optimizer='rmsprop', loss='mse', - target_tensors=[target_a, target_b]) + model.compile( + optimizer='rmsprop', + loss='mse', + metrics=['mae', metrics_module.CategoricalAccuracy()], + target_tensors=[target_a, target_b]) model.train_on_batch(input_val, None, sample_weight={'dense_a': np.random.random((10,))}) @@ -1696,34 +1881,10 @@ class TestTrainingWithDataTensors(test.TestCase): model.train_on_batch([input_a_np, input_b_np], [output_a_np, output_b_np]) - @tf_test_util.run_in_graph_and_eager_modes() - def test_metric_names_are_identical_in_graph_and_eager(self): - a = keras.layers.Input(shape=(3,), name='input_a') - b = keras.layers.Input(shape=(3,), name='input_b') - - dense = keras.layers.Dense(4, name='dense') - c = dense(a) - d = dense(b) - e = keras.layers.Dropout(0.5, name='dropout')(c) - - model = keras.models.Model([a, b], [d, e]) - - optimizer = RMSPropOptimizer(learning_rate=0.001) - loss = 'mse' - loss_weights = [1., 0.5] - metrics = ['mae', 'acc'] - model.compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights) - reference_metric_names = ['loss', 'dense_loss', 'dropout_loss', - 'dense_mean_absolute_error', - 'dense_acc', - 'dropout_mean_absolute_error', - 'dropout_acc'] - self.assertEqual(reference_metric_names, model.metrics_names) - class TestTrainingWithDatasetIterators(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_training_and_eval_methods_on_iterators_single_io(self): with self.test_session(): x = keras.layers.Input(shape=(3,), name='input') @@ -1732,7 +1893,7 @@ class TestTrainingWithDatasetIterators(test.TestCase): optimizer = RMSPropOptimizer(learning_rate=0.001) loss = 'mse' - metrics = ['mae'] + metrics = ['mae', metrics_module.CategoricalAccuracy()] model.compile(optimizer, loss, metrics=metrics) inputs = np.zeros((10, 3)) @@ -1789,6 +1950,7 @@ class TestTrainingWithDatasetIterators(test.TestCase): 'you should specify the `steps` argument'): model.predict(iterator, verbose=0) + @tf_test_util.run_in_graph_and_eager_modes def test_get_next_op_created_once(self): with self.test_session(): x = keras.layers.Input(shape=(3,), name='input') @@ -1813,7 +1975,7 @@ class TestTrainingWithDatasetIterators(test.TestCase): ops.get_default_graph().finalize() model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_iterators_running_out_of_data(self): with self.test_session(): x = keras.layers.Input(shape=(3,), name='input') @@ -1841,6 +2003,7 @@ class TestTrainingWithDatasetIterators(test.TestCase): class TestTrainingWithDataset(test.TestCase): + @tf_test_util.run_in_graph_and_eager_modes def test_calling_model_on_same_dataset(self): with self.test_session(): x = keras.layers.Input(shape=(3,), name='input') @@ -1867,7 +2030,7 @@ class TestTrainingWithDataset(test.TestCase): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, validation_data=dataset, validation_steps=2) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_training_and_eval_methods_on_dataset(self): with self.test_session(): x = keras.layers.Input(shape=(3,), name='input') @@ -1876,7 +2039,7 @@ class TestTrainingWithDataset(test.TestCase): optimizer = RMSPropOptimizer(learning_rate=0.001) loss = 'mse' - metrics = ['mae'] + metrics = ['mae', metrics_module.CategoricalAccuracy()] model.compile(optimizer, loss, metrics=metrics) inputs = np.zeros((10, 3)) @@ -1963,5 +2126,188 @@ class TestTrainingWithDataset(test.TestCase): model.train_on_batch(dataset) +class TestTrainingWithMetrics(test.TestCase): + """Training tests related to metrics.""" + + @tf_test_util.run_in_graph_and_eager_modes + def test_metrics_names(self): + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(3,), name='input_b') + + dense = keras.layers.Dense(4, name='dense') + c = dense(a) + d = dense(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + + model = keras.models.Model([a, b], [d, e]) + + optimizer = RMSPropOptimizer(learning_rate=0.001) + metrics = ['mse', metrics_module.BinaryAccuracy()] + model.compile(optimizer, loss='mae', metrics=metrics) + reference_metric_names = [ + 'loss', 'dense_loss', 'dropout_loss', 'dense_mean_squared_error', + 'dense_binary_accuracy', 'dropout_mean_squared_error', + 'dropout_binary_accuracy' + ] + self.assertEqual(reference_metric_names, model.metrics_names) + + @tf_test_util.run_in_graph_and_eager_modes + def test_metrics_correctness(self): + with self.test_session(): + model = keras.Sequential() + model.add( + keras.layers.Dense( + 3, activation='relu', input_dim=4, kernel_initializer='ones')) + model.add( + keras.layers.Dense( + 1, activation='sigmoid', kernel_initializer='ones')) + model.compile( + loss='mae', + metrics=['accuracy', metrics_module.BinaryAccuracy()], + optimizer=RMSPropOptimizer(learning_rate=0.001)) + + # verify correctness of stateful and stateless metrics. + x = np.ones((100, 4)) + y = np.ones((100, 1)) + outs = model.evaluate(x, y) + self.assertEqual(outs[1], 1.) + self.assertEqual(outs[2], 1.) + + y = np.zeros((100, 1)) + outs = model.evaluate(x, y) + self.assertEqual(outs[1], 0.) + self.assertEqual(outs[2], 0.) + + @tf_test_util.run_in_graph_and_eager_modes + def test_metrics_correctness_with_iterator(self): + with self.test_session(): + model = keras.Sequential() + model.add( + keras.layers.Dense( + 8, activation='relu', input_dim=4, kernel_initializer='ones')) + model.add( + keras.layers.Dense( + 1, activation='sigmoid', kernel_initializer='ones')) + model.compile( + loss='binary_crossentropy', + metrics=['accuracy', metrics_module.BinaryAccuracy()], + optimizer=RMSPropOptimizer(learning_rate=0.001)) + + np.random.seed(123) + x = np.random.randint(10, size=(100, 4)).astype(np.float32) + y = np.random.randint(2, size=(100, 1)).astype(np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + dataset = dataset.batch(10) + iterator = dataset.make_one_shot_iterator() + outs = model.evaluate(iterator, steps=10) + self.assertEqual(np.around(outs[1], decimals=1), 0.5) + self.assertEqual(np.around(outs[2], decimals=1), 0.5) + + y = np.zeros((100, 1), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + iterator = dataset.make_one_shot_iterator() + outs = model.evaluate(iterator, steps=10) + self.assertEqual(outs[1], 0.) + self.assertEqual(outs[2], 0.) + + @tf_test_util.run_in_graph_and_eager_modes + def test_metrics_correctness_with_weighted_metrics(self): + with self.test_session(): + np.random.seed(1337) + x = np.array([[[1.], [1.]], [[0.], [0.]]]) + model = keras.models.Sequential() + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(1, kernel_initializer='ones'), + input_shape=(2, 1))) + model.compile( + RMSPropOptimizer(learning_rate=0.001), + loss='mse', + sample_weight_mode='temporal', + weighted_metrics=['accuracy', + metrics_module.BinaryAccuracy()]) + y = np.array([[[1.], [1.]], [[1.], [1.]]]) + + outs = model.evaluate(x, y) + self.assertEqual(outs, [0.5, 0.5, 0.5]) + + w = np.array([[0., 0.], [0., 0.]]) + outs = model.evaluate(x, y, sample_weight=w) + self.assertEqual(outs, [0., 0., 0.]) + + w = np.array([[3., 4.], [1., 2.]]) + outs = model.evaluate(x, y, sample_weight=w) + self.assertArrayNear(outs, [0.3, 0.7, 0.7], .001) + + @tf_test_util.run_in_graph_and_eager_modes + def test_metric_state_reset_between_fit_and_evaluate(self): + with self.test_session(): + model = keras.Sequential() + model.add(keras.layers.Dense(3, activation='relu', input_dim=4)) + model.add(keras.layers.Dense(1, activation='sigmoid')) + acc_obj = metrics_module.BinaryAccuracy() + model.compile( + loss='mae', + metrics=[acc_obj], + optimizer=RMSPropOptimizer(learning_rate=0.001)) + + x_train = np.random.random((100, 4)) + y_train = np.random.random((100, 1)) + model.fit(x_train, y_train, batch_size=5, epochs=2) + self.assertEqual(self.evaluate(acc_obj.count), 100) + + x_test = np.random.random((10, 4)) + y_test = np.random.random((10, 1)) + model.evaluate(x_test, y_test, batch_size=5) + self.assertEqual(self.evaluate(acc_obj.count), 10) + + @tf_test_util.run_in_graph_and_eager_modes + def test_invalid_metrics(self): + num_classes = 5 + input_dim = 5 + + with self.test_session(): + model = keras.models.Sequential() + model.add( + keras.layers.Dense(10, activation='relu', input_shape=(input_dim,))) + model.add(keras.layers.Dense(num_classes, activation='softmax')) + + with self.assertRaisesRegexp( + TypeError, 'Type of `metrics` argument not understood. ' + 'Expected a list or dictionary, found: '): + model.compile( + RMSPropOptimizer(learning_rate=0.001), + loss='categorical_crossentropy', + metrics=metrics_module.CategoricalAccuracy()) + + @tf_test_util.run_in_graph_and_eager_modes + def test_metrics_masking(self): + with self.test_session(): + np.random.seed(1337) + model = keras.models.Sequential() + model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(1, kernel_initializer='ones'))) + model.compile( + RMSPropOptimizer(learning_rate=0.001), + loss='mse', + weighted_metrics=['accuracy', + metrics_module.BinaryAccuracy()]) + + # verify that masking is applied for stateless and stateful metrics. + x = np.array([[[1], [1]], [[1], [1]], [[0], [0]]]) + y = np.array([[[1], [1]], [[0], [1]], [[1], [1]]]) + scores = model.train_on_batch(x, y) + self.assertArrayNear(scores, [0.25, 0.75, 0.75], 0.1) + + # verify that masking is combined with sample weights. + w = np.array([3, 2, 4]) + scores = model.train_on_batch(x, y, sample_weight=w) + self.assertArrayNear(scores, [0.2, 0.8, 0.8], 0.1) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py index 728a2b493b9f076cc2942766d2677c1f24fb3c15..f94697c91389e67d1766459e3b27eb1ad8c8523c 100644 --- a/tensorflow/python/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -19,16 +19,150 @@ from __future__ import division from __future__ import print_function import copy +import math import numpy as np +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend as K from tensorflow.python.keras import losses from tensorflow.python.keras import metrics as metrics_module +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import weights_broadcast_ops + + +def _map_nested(data, func): + """Maps each nested element using func.""" + if isinstance(data, list): + return [_map_nested(nested_data, func) for nested_data in data] + elif isinstance(data, tuple): + return tuple(_map_nested(nested_data, func) for nested_data in data) + elif isinstance(data, dict): + return { + k: _map_nested(nested_data, func) for k, nested_data in data.items() + } + else: + return func(data) + + +def _nested_all(data, cond_func): + """Checks if all elements in a nested structure satisfy cond_func.""" + if isinstance(data, (tuple, list)): + return all([_nested_all(nested_data, cond_func) for nested_data in data]) + elif isinstance(data, dict): + return all( + [_nested_all(nested_data, cond_func) for nested_data in data.values()]) + else: + return cond_func(data) + + +def _nested_any(data, cond_func): + """Checks if any nested_elements in a nested structure satisfy cond_func.""" + if isinstance(data, (tuple, list)): + return any([_nested_any(nested_data, cond_func) for nested_data in data]) + elif isinstance(data, dict): + return any( + [_nested_any(nested_data, cond_func) for nested_data in data.values()]) + else: + return cond_func(data) + + +def _convert_lists_to_tuples(data): + """Converts all lists to tuples, since Datasets expect tuples.""" + if isinstance(data, (tuple, list)): + return tuple(_convert_lists_to_tuples(nested_data) for nested_data in data) + elif isinstance(data, dict): + return { + k: _convert_lists_to_tuples(nested_data) + for k, nested_data in data.items() + } + else: + return data + + +def _get_batch_axis_size(data): + """Returns batch axis shape for nested data.""" + if isinstance(data, (tuple, list)): + return _get_batch_axis_size(data[0]) + elif isinstance(data, dict): + return _get_batch_axis_size(list(data.values())) + else: + return int(data.shape[0]) + + +def convert_to_iterator(x=None, + y=None, + sample_weights=None, + batch_size=None, + steps_per_epoch=None, + epochs=1, + shuffle=False): + """Converts NumPy arrays or EagerTensors to an EagerIterator. + + Combines all provided data into a single EagerIterator. + + Arguments: + x: NumPy array or EagerTensor, or list of Numpy arrays or EagerTensors + representing inputs to a model. + y: Optional. NumPy array or EagerTensor, or list of Numpy arrays or + EagerTensors representing targets of a model. + sample_weights: Optional NumPy array or EagerTensor representing sample + weights. + batch_size: Used to batch data and calculate how many steps EagerIterator + should take per epoch. + steps_per_epoch: If provided, how many steps EagerIterator should take per + epoch. + epochs: Epochs to repeat iterator for. + shuffle: Whether to shuffle data after each epoch. + + Raises: + ValueError: if steps_per_epoch cannot be calculated from the data + provided. + + Returns: + (Iterator, steps_per_epoch). + + """ + if isinstance(x, iterator_ops.EagerIterator): + return x, steps_per_epoch + + if not _nested_any(sample_weights, lambda x: x is None): + data = (x, y, sample_weights) + elif not _nested_any(y, lambda x: x is None): + data = (x, y) + else: + # always wrap in a tuple, so we know y, sample_weights weren't set + # even when x has multiple elements + data = (x,) + + data = _convert_lists_to_tuples(data) + if steps_per_epoch is None and batch_size is not None: + num_samples = _get_batch_axis_size(data) + steps_per_epoch = int(math.ceil(num_samples / batch_size)) + + if steps_per_epoch is None: + raise ValueError('Could not determine steps_per_epoch.' + 'Please provide either batch_size or' + 'steps_per_epoch.') + + # TODO(omalleyt) for NumPy arrays in graph mode + # placeholder ops should be used + # this is only ideal for eager mode + dataset = dataset_ops.Dataset.from_tensor_slices(data) + + if batch_size is not None: + dataset = dataset.batch(batch_size) + if shuffle: + dataset = dataset.shuffle(buffer_size=10000) + dataset = dataset.repeat(epochs) + iterator = dataset.make_one_shot_iterator() + + return iterator, steps_per_epoch def check_num_samples(ins, @@ -128,8 +262,8 @@ def standardize_input_data(data, except KeyError as e: raise ValueError('No data provided for "' + e.args[0] + '". Need data ' 'for each key in: ' + str(names)) - elif isinstance(data, list): - if isinstance(data[0], list): + elif isinstance(data, (list, tuple)): + if isinstance(data[0], (list, tuple)): data = [np.asarray(d) for d in data] elif len(names) == 1 and isinstance(data[0], (float, int)): data = [np.asarray(data)] @@ -436,23 +570,44 @@ def weighted_masked_objective(fn): # score_array has ndim >= 2 score_array = fn(y_true, y_pred) if mask is not None: - # Cast the mask to floatX to avoid float64 upcasting in theano - mask = math_ops.cast(mask, K.floatx()) - # mask should have the same shape as score_array - score_array *= mask - # the loss per batch should be proportional - # to the number of unmasked samples. - score_array /= K.mean(mask) - - # apply sample weighting + mask = math_ops.cast(mask, y_pred.dtype) + # Update weights with mask. + if weights is None: + weights = mask + else: + # Update shape of weights if possible before adding mask. + # Update dimensions of weights to match with mask if possible. + mask, _, weights = metrics_module.squeeze_or_expand_dimensions( + mask, None, weights) + try: + # Broadcast weights if possible. + weights = weights_broadcast_ops.broadcast_weights(weights, mask) + weights *= mask + except ValueError: + score_array *= mask + score_array /= K.mean(mask) + # TODO(psv): Handle case when mask and weight shapes are not + # compatible. + + # Apply sample weighting. if weights is not None: - # reduce score_array to same ndim as weight array - ndim = K.ndim(score_array) - weight_ndim = K.ndim(weights) - score_array = K.mean(score_array, axis=list(range(weight_ndim, ndim))) - score_array *= weights - score_array /= K.mean( - math_ops.cast(math_ops.not_equal(weights, 0), K.floatx())) + + # Update dimensions of weights to match with values if possible. + score_array, _, weights = metrics_module.squeeze_or_expand_dimensions( + score_array, None, weights) + try: + # Broadcast weights if possible. + weights = weights_broadcast_ops.broadcast_weights(weights, score_array) + except ValueError: + # Reduce values to same ndim as weight array. + ndim = K.ndim(score_array) + weight_ndim = K.ndim(weights) + score_array = K.mean(score_array, axis=list(range(weight_ndim, ndim))) + + score_array = math_ops.multiply(score_array, weights) + score_array = math_ops.reduce_sum(score_array) + weights = math_ops.reduce_sum(weights) + score_array = metrics_module.safe_div(score_array, weights) return K.mean(score_array) return weighted @@ -482,6 +637,9 @@ def standardize_weights(y, Raises: ValueError: In case of invalid user-provided arguments. """ + # Iterator may return sample_weight as 1-tuple + if isinstance(sample_weight, tuple): + sample_weight = sample_weight[0] if sample_weight_mode is not None: if sample_weight_mode != 'temporal': raise ValueError('"sample_weight_mode ' @@ -562,65 +720,34 @@ def has_tensors(ls): return tensor_util.is_tensor(ls) -def populate_metric_names(model): - for i in range(len(model.outputs)): - metrics = model.nested_metrics[i] - for metric in metrics: - base_metric_name = get_base_metric_name(metric) - add_metric_name(model, base_metric_name, i) - - -def get_base_metric_name(metric, weighted=False): - """Returns the metric name given the metric function. +def get_metric_function(metric, output_shape=None, loss_fn=None): + """Returns the metric function corresponding to the given metric input. Arguments: metric: Metric function name or reference. - weighted: Boolean indicating if the metric for which we are adding - names is weighted. + output_shape: The shape of the output that this metric + will be calculated for. + loss_fn: The loss function used. Returns: - a metric name. - """ - metric_name_prefix = 'weighted_' if weighted else '' - if metric in ('accuracy', 'acc', 'crossentropy', 'ce'): - if metric in ('accuracy', 'acc'): - suffix = 'acc' - elif metric in ('crossentropy', 'ce'): - suffix = 'ce' - metric_name = metric_name_prefix + suffix - else: - metric_fn = metrics_module.get(metric) - # Get metric name as string - if hasattr(metric_fn, 'name'): - metric_name = metric_fn.name - else: - metric_name = metric_fn.__name__ - metric_name = metric_name_prefix + metric_name - - return metric_name - - -def add_metric_name(model, metric_name, index): - """Makes the metric name unique and adds it to the model's metric name list. - - If there are multiple outputs for which the metrics are calculated, the - metric names have to be made unique by appending an integer. - - Arguments: - model: Model to which we are adding metric names. - metric_name: Metric name that corresponds to the metric specified by the - user. For example: 'acc' - index: The index of the model output for which the metric name is being - added. + The metric function. """ - if len(model.output_names) > 1: - metric_name = '%s_%s' % (model.output_names[index], metric_name) - j = 1 - base_metric_name = metric_name - while metric_name in model.metrics_names: - metric_name = '%s_%d' % (base_metric_name, j) - j += 1 - model.metrics_names.append(metric_name) + if metric in ['accuracy', 'acc']: + if output_shape[-1] == 1 or loss_fn == losses.binary_crossentropy: + return metrics_module.binary_accuracy # case: binary accuracy + elif loss_fn == losses.sparse_categorical_crossentropy: + # case: categorical accuracy with sparse targets + return metrics_module.sparse_categorical_accuracy + return metrics_module.categorical_accuracy # case: categorical accuracy + elif metric in ['crossentropy', 'ce']: + if output_shape[-1] == 1 or loss_fn == losses.binary_crossentropy: + return metrics_module.binary_crossentropy # case: binary cross-entropy + elif loss_fn == losses.sparse_categorical_crossentropy: + # case: categorical cross-entropy with sparse targets + return metrics_module.sparse_categorical_crossentropy + # case: categorical cross-entropy + return metrics_module.categorical_crossentropy + return metrics_module.get(metric) def validate_iterator_input(x, y, sample_weight, validation_split=None): @@ -722,3 +849,83 @@ def cast_if_floating_dtype(x): for val in x ] return math_ops.cast(x, dtype=K.floatx()) if x.dtype.is_floating else x + + +def get_output_sample_weight_and_mode(skip_target_weighing_indices, + sample_weight_mode, output_name, + output_index): + """Returns the sample weight and weight mode for a single output.""" + if output_index in skip_target_weighing_indices: + return None, None + + if sample_weight_mode == 'temporal': + default_value = [[1.]] + shape = [None, None] + mode = 'temporal' + else: + default_value = [1.] + shape = [None] + mode = None + if context.executing_eagerly(): + weight = None + else: + weight = array_ops.placeholder_with_default( + constant_op.constant(default_value, dtype=K.floatx()), + shape=shape, + name=output_name + '_sample_weights') + return weight, mode + + +def prepare_sample_weights(output_names, sample_weight_mode, + skip_target_weighing_indices): + """Prepares sample weights for the model. + + Args: + output_names: List of model output names. + sample_weight_mode: sample weight mode user input passed from compile API. + skip_target_weighing_indices: Indices of output for which sample weights + should be skipped. + + Returns: + A pair of list of sample weights and sample weight modes + (one for each output). + + Raises: + ValueError: In case of invalid `sample_weight_mode` input. + """ + sample_weights = [] + sample_weight_modes = [] + if isinstance(sample_weight_mode, dict): + unknown_output = set(sample_weight_mode.keys()) - set(output_names) + if unknown_output: + raise ValueError('Unknown entry in ' + 'sample_weight_mode dictionary: "' + unknown_output + + '". Only expected the following keys: ' + + str(output_names)) + for i, name in enumerate(output_names): + if (i not in skip_target_weighing_indices and + name not in sample_weight_mode): + raise ValueError('Output missing from sample_weight_modes dictionary') + weight, mode = get_output_sample_weight_and_mode( + skip_target_weighing_indices, sample_weight_mode.get(name), name, i) + sample_weights.append(weight) + sample_weight_modes.append(mode) + elif isinstance(sample_weight_mode, list): + if len(sample_weight_mode) != len(output_names): + raise ValueError('When passing a list as sample_weight_mode, ' + 'it should have one entry per model output. ' + 'The model has ' + str(len(output_names)) + + ' outputs, but you passed ' + + str(len(sample_weight_mode)) + 'sample_weight_modes') + for i, name in enumerate(output_names): + weight, mode = get_output_sample_weight_and_mode( + skip_target_weighing_indices, sample_weight_mode[i], name, i) + sample_weights.append(weight) + sample_weight_modes.append(mode) + else: + for i, name in enumerate(output_names): + weight, mode = get_output_sample_weight_and_mode( + skip_target_weighing_indices, sample_weight_mode, name, i) + sample_weights.append(weight) + sample_weight_modes.append(mode) + return sample_weights, sample_weight_modes diff --git a/tensorflow/python/keras/engine/training_utils_test.py b/tensorflow/python/keras/engine/training_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..297a1ae494f8c55265a98a60490a8b0d240b3969 --- /dev/null +++ b/tensorflow/python/keras/engine/training_utils_test.py @@ -0,0 +1,150 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for training utility functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import training_utils +from tensorflow.python.platform import test + + +class TrainingUtilTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes + def test_convert_to_iterator_single_numpy(self): + batch_size = 2 + a = np.ones([10, 10]) + iterator, steps_per_epoch = training_utils.convert_to_iterator( + x=a, batch_size=batch_size) + self.assertEquals(steps_per_epoch, 5) + + expected_batch = a[:batch_size, :] + actual_batch, = iterator.get_next() + self.assertAllEqual(expected_batch, actual_batch) + + @test_util.run_in_graph_and_eager_modes + def test_convert_to_iterator_single_tensor(self): + batch_size = 2 + a = ops.convert_to_tensor(np.ones([10, 10])) + iterator, steps_per_epoch = training_utils.convert_to_iterator( + x=a, batch_size=batch_size) + self.assertEquals(steps_per_epoch, 5) + + expected_batch = a[:batch_size, :] + actual_batch, = iterator.get_next() + self.assertAllEqual(expected_batch, actual_batch) + + @test_util.run_in_graph_and_eager_modes + def test_convert_to_iterator_y(self): + batch_size = 2 + a = np.ones([10, 100]) + b = np.ones([10, 10]) + iterator, steps_per_epoch = training_utils.convert_to_iterator( + x=a, y=b, batch_size=batch_size) + self.assertEquals(steps_per_epoch, 5) + + expected_x = a[:batch_size, :] + expected_y = b[:batch_size, :] + actual_x, actual_y = iterator.get_next() + self.assertAllEqual(expected_x, actual_x) + self.assertAllEqual(expected_y, actual_y) + + @test_util.run_in_graph_and_eager_modes + def test_convert_to_iterator_sample_weights(self): + batch_size = 2 + a = ops.convert_to_tensor(np.ones([10, 100])) + b = ops.convert_to_tensor(np.ones([10, 10])) + sw = ops.convert_to_tensor(np.ones([10])) + iterator, steps_per_epoch = training_utils.convert_to_iterator( + x=a, y=b, sample_weights=sw, batch_size=batch_size) + self.assertEquals(steps_per_epoch, 5) + + expected_x = a[:batch_size, :] + expected_y = b[:batch_size, :] + expected_sw = sw[:batch_size] + actual_x, actual_y, actual_sw = iterator.get_next() + self.assertAllEqual(expected_x, actual_x) + self.assertAllEqual(expected_y, actual_y) + self.assertAllEqual(expected_sw, actual_sw) + + @test_util.run_in_graph_and_eager_modes + def test_convert_to_iterator_nested(self): + batch_size = 2 + x = {'1': np.ones([10, 100]), '2': [np.zeros([10, 10]), np.ones([10, 20])]} + iterator, steps_per_epoch = training_utils.convert_to_iterator( + x=x, batch_size=batch_size) + self.assertEquals(steps_per_epoch, 5) + + expected_x1 = x['1'][:batch_size, :] + expected_x2_0 = x['2'][0][:batch_size, :] + expected_x2_1 = x['2'][1][:batch_size, :] + + actual_x, = iterator.get_next() + actual_x1 = actual_x['1'][:batch_size, :] + actual_x2_0 = actual_x['2'][0][:batch_size, :] + actual_x2_1 = actual_x['2'][1][:batch_size, :] + + self.assertAllEqual(expected_x1, actual_x1) + self.assertAllEqual(expected_x2_0, actual_x2_0) + self.assertAllEqual(expected_x2_1, actual_x2_1) + + @test_util.run_in_graph_and_eager_modes + def test_convert_to_iterator_epochs(self): + batch_size = 2 + a = np.ones([10, 10]) + iterator, steps_per_epoch = training_utils.convert_to_iterator( + x=a, batch_size=batch_size, epochs=2) + self.assertEquals(steps_per_epoch, 5) + + expected_batch = a[:batch_size, :] + # loop through one whole epoch + for _ in range(6): + actual_batch, = iterator.get_next() + self.assertAllEqual(expected_batch, actual_batch) + + @test_util.run_in_graph_and_eager_modes + def test_convert_to_iterator_insufficient_info(self): + # with batch_size and steps_per_epoch not set + with self.assertRaises(ValueError): + a = np.ones([10, 10]) + _ = training_utils.convert_to_iterator(x=a) + + def test_nested_all(self): + nested_data = {'a': True, 'b': [True, True, (False, True)]} + all_true = training_utils._nested_all(nested_data, lambda x: x) + self.assertEquals(all_true, False) + + nested_data = {'a': True, 'b': [True, True, (True, True)]} + all_true = training_utils._nested_all(nested_data, lambda x: x) + self.assertEquals(all_true, True) + + def test_nested_any(self): + nested_data = [False, {'a': False, 'b': (False, True)}] + any_true = training_utils._nested_any(nested_data, lambda x: x) + self.assertEquals(any_true, True) + + nested_data = [False, {'a': False, 'b': (False, False)}] + any_true = training_utils._nested_any(nested_data, lambda x: x) + self.assertEquals(any_true, False) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/estimator/__init__.py b/tensorflow/python/keras/estimator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b244beb5b58cf339a4687216b87418c88b953c17 --- /dev/null +++ b/tensorflow/python/keras/estimator/__init__.py @@ -0,0 +1,46 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras estimator API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.util.tf_export import tf_export + +# Keras has undeclared dependency on tensorflow/estimator:estimator_py. +# As long as you depend //third_party/py/tensorflow:tensorflow target +# everything will work as normal. + +try: + from tensorflow.python.estimator import keras as keras_lib # pylint: disable=g-import-not-at-top + model_to_estimator = tf_export('keras.estimator.model_to_estimator')( + keras_lib.model_to_estimator) +except Exception: # pylint: disable=broad-except + + # pylint: disable=unused-argument + def stub_model_to_estimator(keras_model=None, + keras_model_path=None, + custom_objects=None, + model_dir=None, + config=None): + raise NotImplementedError( + 'tf.keras.estimator.model_to_estimator function not available in your ' + 'installation.') + # pylint: enable=unused-argument + + model_to_estimator = tf_export('keras.estimator.model_to_estimator')( + stub_model_to_estimator) + diff --git a/tensorflow/python/keras/initializers.py b/tensorflow/python/keras/initializers.py index b9b2e9ad598fabe8cbfbbcbd57d4d71ddf630df7..b9d856efa8f20500595a2f2a49447c724b9a563e 100644 --- a/tensorflow/python/keras/initializers.py +++ b/tensorflow/python/keras/initializers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Keras initializer classes (soon to be replaced with core TF initializers). +"""Keras initializer serialization / deserialization. """ from __future__ import absolute_import from __future__ import division @@ -22,150 +22,27 @@ import six from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras.utils.generic_utils import serialize_keras_object + +# These imports are brought in so that keras.initializers.deserialize +# has them available in module_objects. from tensorflow.python.ops.init_ops import Constant +from tensorflow.python.ops.init_ops import glorot_normal_initializer +from tensorflow.python.ops.init_ops import glorot_uniform_initializer +from tensorflow.python.ops.init_ops import he_normal # pylint: disable=unused-import +from tensorflow.python.ops.init_ops import he_uniform # pylint: disable=unused-import from tensorflow.python.ops.init_ops import Identity from tensorflow.python.ops.init_ops import Initializer # pylint: disable=unused-import +from tensorflow.python.ops.init_ops import lecun_normal # pylint: disable=unused-import +from tensorflow.python.ops.init_ops import lecun_uniform # pylint: disable=unused-import from tensorflow.python.ops.init_ops import Ones from tensorflow.python.ops.init_ops import Orthogonal from tensorflow.python.ops.init_ops import RandomNormal from tensorflow.python.ops.init_ops import RandomUniform from tensorflow.python.ops.init_ops import TruncatedNormal -from tensorflow.python.ops.init_ops import VarianceScaling +from tensorflow.python.ops.init_ops import VarianceScaling # pylint: disable=unused-import from tensorflow.python.ops.init_ops import Zeros -from tensorflow.python.util.tf_export import tf_export - - -@tf_export('keras.initializers.lecun_normal') -def lecun_normal(seed=None): - """LeCun normal initializer. - - It draws samples from a truncated normal distribution centered on 0 - with `stddev = sqrt(1 / fan_in)` - where `fan_in` is the number of input units in the weight tensor. - - Arguments: - seed: A Python integer. Used to seed the random generator. - - Returns: - An initializer. - - References: - - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) - - [Efficient - Backprop](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf) - """ - return VarianceScaling( - scale=1., mode='fan_in', distribution='normal', seed=seed) - - -@tf_export('keras.initializers.lecun_uniform') -def lecun_uniform(seed=None): - """LeCun uniform initializer. - - It draws samples from a uniform distribution within [-limit, limit] - where `limit` is `sqrt(3 / fan_in)` - where `fan_in` is the number of input units in the weight tensor. - - Arguments: - seed: A Python integer. Used to seed the random generator. - - Returns: - An initializer. - - References: - LeCun 98, Efficient Backprop, - http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf - """ - return VarianceScaling( - scale=1., mode='fan_in', distribution='uniform', seed=seed) - - -@tf_export('keras.initializers.glorot_normal') -def glorot_normal(seed=None): - """Glorot normal initializer, also called Xavier normal initializer. - - It draws samples from a truncated normal distribution centered on 0 - with `stddev = sqrt(2 / (fan_in + fan_out))` - where `fan_in` is the number of input units in the weight tensor - and `fan_out` is the number of output units in the weight tensor. - - Arguments: - seed: A Python integer. Used to seed the random generator. - - Returns: - An initializer. - References: - Glorot & Bengio, AISTATS 2010 - http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf - """ - return VarianceScaling( - scale=1., mode='fan_avg', distribution='normal', seed=seed) - - -@tf_export('keras.initializers.glorot_uniform') -def glorot_uniform(seed=None): - """Glorot uniform initializer, also called Xavier uniform initializer. - - It draws samples from a uniform distribution within [-limit, limit] - where `limit` is `sqrt(6 / (fan_in + fan_out))` - where `fan_in` is the number of input units in the weight tensor - and `fan_out` is the number of output units in the weight tensor. - - Arguments: - seed: A Python integer. Used to seed the random generator. - - Returns: - An initializer. - - References: - Glorot & Bengio, AISTATS 2010 - http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf - """ - return VarianceScaling( - scale=1., mode='fan_avg', distribution='uniform', seed=seed) - - -@tf_export('keras.initializers.he_normal') -def he_normal(seed=None): - """He normal initializer. - - It draws samples from a truncated normal distribution centered on 0 - with `stddev = sqrt(2 / fan_in)` - where `fan_in` is the number of input units in the weight tensor. - - Arguments: - seed: A Python integer. Used to seed the random generator. - - Returns: - An initializer. - - References: - He et al., http://arxiv.org/abs/1502.01852 - """ - return VarianceScaling( - scale=2., mode='fan_in', distribution='normal', seed=seed) - - -@tf_export('keras.initializers.he_uniform') -def he_uniform(seed=None): - """He uniform variance scaling initializer. - - It draws samples from a uniform distribution within [-limit, limit] - where `limit` is `sqrt(6 / fan_in)` - where `fan_in` is the number of input units in the weight tensor. - - Arguments: - seed: A Python integer. Used to seed the random generator. - - Returns: - An initializer. - - References: - He et al., http://arxiv.org/abs/1502.01852 - """ - return VarianceScaling( - scale=2., mode='fan_in', distribution='uniform', seed=seed) +from tensorflow.python.util.tf_export import tf_export # Compatibility aliases @@ -179,6 +56,8 @@ normal = random_normal = RandomNormal truncated_normal = TruncatedNormal identity = Identity orthogonal = Orthogonal +glorot_normal = glorot_normal_initializer +glorot_uniform = glorot_uniform_initializer # pylint: enable=invalid-name diff --git a/tensorflow/python/keras/initializers_test.py b/tensorflow/python/keras/initializers_test.py index c519e194bdc21692025f259533b8b75e2dc48c09..51725e03f201db092a515456b280e7eca2927ac3 100644 --- a/tensorflow/python/keras/initializers_test.py +++ b/tensorflow/python/keras/initializers_test.py @@ -31,16 +31,6 @@ class KerasInitializersTest(test.TestCase): target_max=None, target_min=None): variable = keras.backend.variable(init(shape)) output = keras.backend.get_value(variable) - lim = 3e-2 - if target_std is not None: - self.assertGreater(lim, abs(output.std() - target_std)) - if target_mean is not None: - self.assertGreater(lim, abs(output.mean() - target_mean)) - if target_max is not None: - self.assertGreater(lim, abs(output.max() - target_max)) - if target_min is not None: - self.assertGreater(lim, abs(output.min() - target_min)) - # Test serialization (assumes deterministic behavior). config = init.get_config() reconstructed_init = init.__class__.from_config(config) diff --git a/tensorflow/python/keras/integration_test.py b/tensorflow/python/keras/integration_test.py index 2a05699407cc608c1ed0dd97d230beeb6e99e0ef..a103b9fbf2d5e3a6c792021c846a562d677e75b8 100644 --- a/tensorflow/python/keras/integration_test.py +++ b/tensorflow/python/keras/integration_test.py @@ -21,9 +21,11 @@ from __future__ import print_function import numpy as np from tensorflow.python import keras +from tensorflow.python.framework import dtypes from tensorflow.python.keras import testing_utils from tensorflow.python.layers import core as tf_core_layers from tensorflow.python.ops import nn +from tensorflow.python.ops import rnn_cell from tensorflow.python.platform import test @@ -103,6 +105,30 @@ class KerasIntegrationTest(test.TestCase): verbose=2) self.assertGreater(history.history['val_acc'][-1], 0.7) + def test_temporal_classification_sequential_tf_rnn(self): + with self.test_session(): + np.random.seed(1337) + (x_train, y_train), _ = testing_utils.get_test_data( + train_samples=100, + test_samples=0, + input_shape=(4, 10), + num_classes=2) + y_train = keras.utils.to_categorical(y_train) + + model = keras.models.Sequential() + model.add(keras.layers.RNN(rnn_cell.LSTMCell(5), return_sequences=True, + input_shape=x_train.shape[1:])) + model.add(keras.layers.RNN(rnn_cell.GRUCell(y_train.shape[-1], + activation='softmax', + dtype=dtypes.float32))) + model.compile(loss='categorical_crossentropy', + optimizer=keras.optimizers.Adam(lr=0.1), + metrics=['accuracy']) + history = model.fit(x_train, y_train, epochs=15, batch_size=16, + validation_data=(x_train, y_train), + verbose=2) + self.assertGreater(history.history['val_acc'][-1], 0.7) + def test_image_classification_sequential(self): with self.test_session(): np.random.seed(1337) diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py index ce0cdb2e1b34d1133f64965818b0a2bcef108d86..e3a686f45d92dde8ea90d496b3cb5099f6b84b58 100644 --- a/tensorflow/python/keras/layers/__init__.py +++ b/tensorflow/python/keras/layers/__init__.py @@ -20,15 +20,16 @@ from __future__ import print_function # Generic layers. # pylint: disable=g-bad-import-order -from tensorflow.python.keras.engine import Input -from tensorflow.python.keras.engine import InputLayer -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.input_layer import Input +from tensorflow.python.keras.engine.input_layer import InputLayer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer # Advanced activations. from tensorflow.python.keras.layers.advanced_activations import LeakyReLU from tensorflow.python.keras.layers.advanced_activations import PReLU from tensorflow.python.keras.layers.advanced_activations import ELU +from tensorflow.python.keras.layers.advanced_activations import ReLU from tensorflow.python.keras.layers.advanced_activations import ThresholdedReLU from tensorflow.python.keras.layers.advanced_activations import Softmax diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py index 8ade3c317456a88181f6005c620953817463595b..61ab69c16f14b8d734a306ab3ad18c73eaf160ca 100644 --- a/tensorflow/python/keras/layers/advanced_activations.py +++ b/tensorflow/python/keras/layers/advanced_activations.py @@ -23,8 +23,8 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import tf_export @@ -278,3 +278,65 @@ class Softmax(Layer): @tf_utils.shape_type_conversion def compute_output_shape(self, input_shape): return input_shape + + +@tf_export('keras.layers.ReLU') +class ReLU(Layer): + """Rectified Linear Unit activation function. + + With default values, it returns element-wise `max(x, 0)`. + + Otherwise, it follows: + `f(x) = max_value` for `x >= max_value`, + `f(x) = x` for `threshold <= x < max_value`, + `f(x) = negative_slope * (x - threshold)` otherwise. + + Input shape: + Arbitrary. Use the keyword argument `input_shape` + (tuple of integers, does not include the samples axis) + when using this layer as the first layer in a model. + + Output shape: + Same shape as the input. + + Arguments: + max_value: float >= 0. Maximum activation value. + negative_slope: float >= 0. Negative slope coefficient. + threshold: float. Threshold value for thresholded activation. + """ + + def __init__(self, max_value=None, negative_slope=0, threshold=0, **kwargs): + super(ReLU, self).__init__(**kwargs) + if max_value is not None and max_value < 0.: + raise ValueError('max_value of Relu layer ' + 'cannot be negative value: ' + str(max_value)) + if negative_slope < 0.: + raise ValueError('negative_slope of Relu layer ' + 'cannot be negative value: ' + str(negative_slope)) + + self.support_masking = True + self.max_value = K.cast_to_floatx(max_value) + self.negative_slope = K.cast_to_floatx(negative_slope) + self.threshold = K.cast_to_floatx(threshold) + + def call(self, inputs): + # alpha is used for leaky relu slope in activations instead of + # negative_slope. + return activations.relu( + inputs, + alpha=self.negative_slope, + max_value=self.max_value, + threshold=self.threshold) + + def get_config(self): + config = { + 'max_value': self.max_value, + 'negative_slope': self.negative_slope, + 'threshold': self.threshold + } + base_config = super(ReLU, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @tf_utils.shape_type_conversion + def compute_output_shape(self, input_shape): + return input_shape diff --git a/tensorflow/python/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/layers/advanced_activations_test.py index 81c76db14cd3741687bf5e2bec66e5354e9f6312..53c1baa2bbd4367eb09d5bc792de9f20baa981ef 100644 --- a/tensorflow/python/keras/layers/advanced_activations_test.py +++ b/tensorflow/python/keras/layers/advanced_activations_test.py @@ -62,6 +62,28 @@ class AdvancedActivationsTest(test.TestCase): kwargs={'axis': 1}, input_shape=(2, 3, 4)) + def test_relu(self): + with self.test_session(): + testing_utils.layer_test(keras.layers.ReLU, + kwargs={'max_value': 10}, + input_shape=(2, 3, 4)) + + def test_relu_with_invalid_arg(self): + with self.assertRaisesRegexp( + ValueError, 'max_value of Relu layer cannot be negative value: -10'): + with self.test_session(): + testing_utils.layer_test(keras.layers.ReLU, + kwargs={'max_value': -10}, + input_shape=(2, 3, 4)) + with self.assertRaisesRegexp( + ValueError, + 'negative_slope of Relu layer cannot be negative value: -2'): + with self.test_session(): + testing_utils.layer_test( + keras.layers.ReLU, + kwargs={'negative_slope': -2}, + input_shape=(2, 3, 4)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py index 720b386c4dd5fbed95b5275c0ca7d3885afda422..a57ac121ed7486a9beb64e6dd7ed3b132ca258df 100644 --- a/tensorflow/python/keras/layers/convolutional.py +++ b/tensorflow/python/keras/layers/convolutional.py @@ -26,8 +26,8 @@ from tensorflow.python.keras import backend from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer # imports for backwards namespace compatibility # pylint: disable=unused-import from tensorflow.python.keras.layers.pooling import AveragePooling1D @@ -1195,6 +1195,7 @@ class SeparableConv(Conv): dilation_rate=dilation_rate, activation=activations.get(activation), use_bias=use_bias, + bias_initializer=initializers.get(bias_initializer), bias_regularizer=regularizers.get(bias_regularizer), activity_regularizer=regularizers.get(activity_regularizer), bias_constraint=bias_constraint, @@ -1729,7 +1730,7 @@ class DepthwiseConv2D(Conv2D): dilation_rate=self.dilation_rate, data_format=self.data_format) - if self.bias: + if self.use_bias: outputs = backend.bias_add( outputs, self.bias, diff --git a/tensorflow/python/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/layers/convolutional_recurrent.py index c731508b3c32d93895432fd5174c1f57557b10dc..e61dd3043d96e69f76cb5bb041de304f5c1c2642 100644 --- a/tensorflow/python/keras/layers/convolutional_recurrent.py +++ b/tensorflow/python/keras/layers/convolutional_recurrent.py @@ -26,8 +26,8 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.layers.recurrent import _generate_dropout_mask from tensorflow.python.keras.layers.recurrent import _standardize_args from tensorflow.python.keras.layers.recurrent import RNN @@ -788,7 +788,7 @@ class ConvLSTM2D(ConvRNN2D): Arguments: filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of n integers, specifying the dimensions of the convolution window. strides: An integer or tuple/list of n integers, diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py index 167cabaeecb0c4ce9a785e7a990aa715f2d1a5b3..f904744422a4b1296e8f5e8a34373fd0344dc643 100644 --- a/tensorflow/python/keras/layers/convolutional_test.py +++ b/tensorflow/python/keras/layers/convolutional_test.py @@ -45,7 +45,7 @@ class Convolution1DTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, length, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_conv1d(self): kwargs = { 'filters': 2, @@ -117,7 +117,7 @@ class Conv2DTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, num_row, num_col, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_conv2d(self): kwargs = { 'filters': 2, @@ -192,7 +192,7 @@ class Conv2DTransposeTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, num_row, num_col, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_conv2dtranspose(self): kwargs = { 'filters': 2, @@ -258,7 +258,7 @@ class Conv3DTransposeTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, depth, num_row, num_col, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_conv3dtranspose(self): kwargs = { 'filters': 2, @@ -322,7 +322,7 @@ class SeparableConv1DTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, length, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_separable_conv1d(self): kwargs = { 'filters': 2, @@ -398,7 +398,7 @@ class SeparableConv2DTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, num_row, num_col, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_separable_conv2d(self): kwargs = { 'filters': 2, @@ -477,7 +477,7 @@ class Conv3DTest(test.TestCase): kwargs=test_kwargs, input_shape=(num_samples, depth, num_row, num_col, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_conv3d(self): kwargs = { 'filters': 2, @@ -529,7 +529,7 @@ class Conv3DTest(test.TestCase): class ZeroPaddingTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_zero_padding_1d(self): num_samples = 2 input_dim = 2 @@ -581,7 +581,7 @@ class ZeroPaddingTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.ZeroPadding1D(padding=None) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_zero_padding_2d(self): num_samples = 2 stack_size = 2 @@ -660,7 +660,7 @@ class ZeroPaddingTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.ZeroPadding2D(padding=None) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_zero_padding_3d(self): num_samples = 2 stack_size = 2 @@ -702,13 +702,13 @@ class ZeroPaddingTest(test.TestCase): class UpSamplingTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_upsampling_1d(self): with self.test_session(use_gpu=True): testing_utils.layer_test( keras.layers.UpSampling1D, kwargs={'size': 2}, input_shape=(3, 5, 4)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_upsampling_2d(self): num_samples = 2 stack_size = 2 @@ -758,7 +758,7 @@ class UpSamplingTest(test.TestCase): np.testing.assert_allclose(np_output, expected_out) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_upsampling_3d(self): num_samples = 2 stack_size = 2 @@ -818,7 +818,7 @@ class UpSamplingTest(test.TestCase): class CroppingTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_cropping_1d(self): num_samples = 2 time_length = 4 @@ -837,7 +837,7 @@ class CroppingTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.Cropping1D(cropping=None) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_cropping_2d(self): num_samples = 2 stack_size = 2 @@ -905,7 +905,7 @@ class CroppingTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.Cropping2D(cropping=None) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_cropping_3d(self): num_samples = 2 stack_size = 2 @@ -995,6 +995,7 @@ class DepthwiseConv2DTest(test.TestCase): 'bias_regularizer': 'l2', 'activity_regularizer': 'l2', 'depthwise_constraint': 'unit_norm', + 'use_bias': True, 'strides': (2, 2), } self._run_test(kwargs, 'depth_multiplier', [1]) diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index f60064ed6363d36731795d08bb42e75398628283..4032202986d64047ebde194f812d99924b1a4630 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -26,6 +26,7 @@ import warnings import numpy as np from tensorflow.python.eager import context +from tensorflow.python.framework import common_shapes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import activations @@ -33,8 +34,8 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import conv_utils from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import tf_utils @@ -465,7 +466,7 @@ class Permute(Layer): Arguments: dims: Tuple of integers. Permutation pattern, does not include the samples dimension. Indexing starts at 1. - For instance, `(2, 1)` permutes the first and second dimension + For instance, `(2, 1)` permutes the first and second dimensions of the input. Input shape: @@ -481,6 +482,11 @@ class Permute(Layer): def __init__(self, dims, **kwargs): super(Permute, self).__init__(**kwargs) self.dims = tuple(dims) + if sorted(dims) != list(range(1, len(dims) + 1)): + raise ValueError( + 'Invalid permutation `dims` for Permute Layer: %s. ' + 'The set of indices in `dims` must be consecutive and start from 1.' % + (dims,)) self.input_spec = InputSpec(ndim=len(self.dims) + 1) def compute_output_shape(self, input_shape): @@ -675,9 +681,8 @@ class Lambda(Layer): 'must be a list, a tuple, or a function.') self._output_shape = output_shape + @tf_utils.shape_type_conversion def compute_output_shape(self, input_shape): - input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list()) - if self._output_shape is None: if context.executing_eagerly(): raise NotImplementedError @@ -929,13 +934,13 @@ class Dense(Layer): def call(self, inputs): inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) - shape = inputs.get_shape().as_list() - if len(shape) > 2: + rank = common_shapes.rank(inputs) + if rank > 2: # Broadcasting is required for the inputs. - outputs = standard_ops.tensordot(inputs, self.kernel, [[len(shape) - 1], - [0]]) + outputs = standard_ops.tensordot(inputs, self.kernel, [[rank - 1], [0]]) # Reshape the output back to the original ndim of the input. if not context.executing_eagerly(): + shape = inputs.get_shape().as_list() output_shape = shape[:-1] + [self.units] outputs.set_shape(output_shape) else: diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py index ff8af976b99376b037af81ed81707332ccf9937e..49ca68ee9e2ac5a59e38586daaf1757bf458a9c4 100644 --- a/tensorflow/python/keras/layers/core_test.py +++ b/tensorflow/python/keras/layers/core_test.py @@ -51,7 +51,7 @@ class CoreLayersTest(test.TestCase): dropout = keras.layers.Dropout(0.5) self.assertEqual(True, dropout.supports_masking) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_spatial_dropout(self): testing_utils.layer_test( keras.layers.SpatialDropout1D, @@ -78,7 +78,7 @@ class CoreLayersTest(test.TestCase): kwargs={'rate': 0.5, 'data_format': 'channels_first'}, input_shape=(2, 3, 4, 4, 5)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_activation(self): # with string argument testing_utils.layer_test( @@ -92,7 +92,7 @@ class CoreLayersTest(test.TestCase): kwargs={'activation': keras.backend.relu}, input_shape=(3, 2)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_reshape(self): testing_utils.layer_test( keras.layers.Reshape, @@ -114,12 +114,26 @@ class CoreLayersTest(test.TestCase): kwargs={'target_shape': (-1, 1)}, input_shape=(None, None, 2)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_permute(self): testing_utils.layer_test( keras.layers.Permute, kwargs={'dims': (2, 1)}, input_shape=(3, 2, 4)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes + def test_permute_errors_on_invalid_starting_dims_index(self): + with self.assertRaisesRegexp(ValueError, r'Invalid permutation .*dims.*'): + testing_utils.layer_test( + keras.layers.Permute, + kwargs={'dims': (0, 1, 2)}, input_shape=(3, 2, 4)) + + @tf_test_util.run_in_graph_and_eager_modes + def test_permute_errors_on_invalid_set_of_dims_indices(self): + with self.assertRaisesRegexp(ValueError, r'Invalid permutation .*dims.*'): + testing_utils.layer_test( + keras.layers.Permute, + kwargs={'dims': (1, 4, 2)}, input_shape=(3, 2, 4)) + + @tf_test_util.run_in_graph_and_eager_modes def test_flatten(self): testing_utils.layer_test( keras.layers.Flatten, kwargs={}, input_shape=(3, 2, 4)) @@ -134,7 +148,7 @@ class CoreLayersTest(test.TestCase): np.transpose(inputs, (0, 2, 3, 1)), (-1, 5 * 5 * 3)) self.assertAllClose(outputs, target_outputs) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_repeat_vector(self): testing_utils.layer_test( keras.layers.RepeatVector, kwargs={'n': 3}, input_shape=(3, 2)) @@ -173,7 +187,15 @@ class CoreLayersTest(test.TestCase): config = ld.get_config() ld = keras.layers.Lambda.from_config(config) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes + def test_lambda_multiple_inputs(self): + ld = keras.layers.Lambda(lambda x: x[0], output_shape=lambda x: x[0]) + x1 = np.ones([3, 2], np.float32) + x2 = np.ones([3, 5], np.float32) + out = ld([x1, x2]) + self.assertAllEqual(out.shape, [3, 2]) + + @tf_test_util.run_in_graph_and_eager_modes def test_dense(self): testing_utils.layer_test( keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 2)) diff --git a/tensorflow/python/keras/layers/cudnn_recurrent.py b/tensorflow/python/keras/layers/cudnn_recurrent.py index ad6594279d037c8dc0e1408955d2a2eebd51ce1d..cf2b0c476c7229a288f4b4f7b31de09388ade40f 100644 --- a/tensorflow/python/keras/layers/cudnn_recurrent.py +++ b/tensorflow/python/keras/layers/cudnn_recurrent.py @@ -25,7 +25,7 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import InputSpec +from tensorflow.python.keras.engine.base_layer import InputSpec from tensorflow.python.keras.layers.recurrent import RNN from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_cudnn_rnn_ops diff --git a/tensorflow/python/keras/layers/cudnn_recurrent_test.py b/tensorflow/python/keras/layers/cudnn_recurrent_test.py index 9d186f8c586bd9f626e142a855be6d2cf00d7121..2ed0aa8f2684009251e61c92a1ac167f1ba2f0af 100644 --- a/tensorflow/python/keras/layers/cudnn_recurrent_test.py +++ b/tensorflow/python/keras/layers/cudnn_recurrent_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os +import tempfile from absl.testing import parameterized import numpy as np @@ -30,7 +32,7 @@ from tensorflow.python.training.rmsprop import RMSPropOptimizer class CuDNNTest(test.TestCase, parameterized.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_cudnn_rnn_basics(self): if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True): @@ -58,7 +60,7 @@ class CuDNNTest(test.TestCase, parameterized.TestCase): 'go_backwards': go_backwards}, input_shape=(num_samples, timesteps, input_size)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_trainability(self): if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True): @@ -217,27 +219,14 @@ class CuDNNTest(test.TestCase, parameterized.TestCase): out5 = model.predict(np.ones((num_samples, timesteps))) self.assertNotEqual(out4.max(), out5.max()) - # TODO(psv): Add generic cross product helper function for parametrized tests. @parameterized.named_parameters( - ('cudnnlstm_to_lstm_unidirectional_impl_1', 'LSTM', False, False, 1), - ('cudnnlstm_to_lstm_bidirectional_impl_1', 'LSTM', False, True, 1), - ('lstm_to_cudnnlstm_unidirectional_impl_1', 'LSTM', True, False, 1), - ('lstm_to_cudnnlstm_bidirectional_impl_1', 'LSTM', True, True, 1), - ('cudnngru_to_gru_unidirectional_impl_1', 'GRU', False, False, 1), - ('cudnngru_to_gru_bidirectional_impl_1', 'GRU', False, True, 1), - ('gru_to_cudnngru_unidirectional_impl_1', 'GRU', True, False, 1), - ('gru_to_cudnngru_bidirectional_impl_1', 'GRU', True, True, 1), - ('cudnnlstm_to_lstm_unidirectional_impl_2', 'LSTM', False, False, 2), - ('cudnnlstm_to_lstm_bidirectional_impl_2', 'LSTM', False, True, 2), - ('lstm_to_cudnnlstm_unidirectional_impl_2', 'LSTM', True, False, 2), - ('lstm_to_cudnnlstm_bidirectional_impl_2', 'LSTM', True, True, 2), - ('cudnngru_to_gru_unidirectional_impl_2', 'GRU', False, False, 2), - ('cudnngru_to_gru_bidirectional_impl_2', 'GRU', False, True, 2), - ('gru_to_cudnngru_unidirectional_impl_2', 'GRU', True, False, 2), - ('gru_to_cudnngru_bidirectional_impl_2', 'GRU', True, True, 2), - ) + *test_util.generate_combinations_with_testcase_name( + rnn_type=['LSTM', 'GRU'], to_cudnn=[True, False], + bidirectional=[True, False], implementation=[1, 2], + model_nest_level=[1, 2], model_type=['seq', 'func'])) def test_load_weights_between_noncudnn_rnn(self, rnn_type, to_cudnn, - bidirectional, implementation): + bidirectional, implementation, + model_nest_level, model_type): if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True): input_size = 10 @@ -261,14 +250,6 @@ class CuDNNTest(test.TestCase, parameterized.TestCase): cudnn_rnn_layer_class = keras.layers.CuDNNGRU rnn_layer_kwargs['reset_after'] = True - def convert_weights(source_layer, target_layer): - weights = source_layer.get_weights() - weights = keras.engine.saving.preprocess_weights_for_loading( - target_layer, weights) - target_layer.set_weights(weights) - - input_layer = keras.layers.InputLayer(input_shape) - layer = rnn_layer_class(units, **rnn_layer_kwargs) if bidirectional: layer = keras.layers.Bidirectional(layer) @@ -277,18 +258,96 @@ class CuDNNTest(test.TestCase, parameterized.TestCase): if bidirectional: cudnn_layer = keras.layers.Bidirectional(cudnn_layer) - model = keras.models.Sequential([input_layer, layer]) - cudnn_model = keras.models.Sequential([input_layer, cudnn_layer]) + model = self._make_nested_model(input_shape, layer, model_nest_level, + model_type) + cudnn_model = self._make_nested_model(input_shape, cudnn_layer, + model_nest_level, model_type) + + if to_cudnn: + self._convert_model_weights(model, cudnn_model) + else: + self._convert_model_weights(cudnn_model, model) + + self.assertAllClose(model.predict(inputs), cudnn_model.predict(inputs), + atol=1e-4) + + def _make_nested_model(self, input_shape, layer, level=1, model_type='func'): + # example: make_nested_seq_model((1,), Dense(10), level=2).summary() + def make_nested_seq_model(input_shape, layer, level=1): + model = layer + for i in range(1, level + 1): + layers = [keras.layers.InputLayer(input_shape), + model] if (i == 1) else [model] + model = keras.models.Sequential(layers) + return model + + # example: make_nested_func_model((1,), Dense(10), level=2).summary() + def make_nested_func_model(input_shape, layer, level=1): + model_input = keras.layers.Input(input_shape) + model = layer + for _ in range(level): + model = keras.models.Model(model_input, model(model_input)) + return model + + if model_type == 'func': + return make_nested_func_model(input_shape, layer, level) + elif model_type == 'seq': + return make_nested_seq_model(input_shape, layer, level) + + def _convert_model_weights(self, source_model, target_model): + _, fname = tempfile.mkstemp('.h5') + source_model.save_weights(fname) + target_model.load_weights(fname) + os.remove(fname) + + @parameterized.named_parameters( + *test_util.generate_combinations_with_testcase_name( + rnn_type=['LSTM', 'GRU'], to_cudnn=[True, False])) + def test_load_weights_between_noncudnn_rnn_time_distributed(self, rnn_type, + to_cudnn): + # Similar test as test_load_weights_between_noncudnn_rnn() but has different + # rank of input due to usage of TimeDistributed. Issue: #10356. + if test.is_gpu_available(cuda_only=True): + with self.test_session(use_gpu=True): + input_size = 10 + steps = 6 + timesteps = 6 + input_shape = (timesteps, steps, input_size) + units = 2 + num_samples = 32 + inputs = np.random.random((num_samples, timesteps, steps, input_size)) + + rnn_layer_kwargs = { + 'recurrent_activation': 'sigmoid', + # ensure biases are non-zero and properly converted + 'bias_initializer': 'random_uniform', + } + if rnn_type == 'LSTM': + rnn_layer_class = keras.layers.LSTM + cudnn_rnn_layer_class = keras.layers.CuDNNLSTM + else: + rnn_layer_class = keras.layers.GRU + cudnn_rnn_layer_class = keras.layers.CuDNNGRU + rnn_layer_kwargs['reset_after'] = True + + layer = rnn_layer_class(units, **rnn_layer_kwargs) + layer = keras.layers.TimeDistributed(layer) + + cudnn_layer = cudnn_rnn_layer_class(units) + cudnn_layer = keras.layers.TimeDistributed(cudnn_layer) + + model = self._make_nested_model(input_shape, layer) + cudnn_model = self._make_nested_model(input_shape, cudnn_layer) if to_cudnn: - convert_weights(layer, cudnn_layer) + self._convert_model_weights(model, cudnn_model) else: - convert_weights(cudnn_layer, layer) + self._convert_model_weights(cudnn_model, model) - self.assertAllClose( - model.predict(inputs), cudnn_model.predict(inputs), atol=1e-4) + self.assertAllClose(model.predict(inputs), cudnn_model.predict(inputs), + atol=1e-4) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_cudnnrnn_bidirectional(self): if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True): diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py index 25eeeee9529bcb52e608eeb9468c210eea8bd8be..629a9ec9a10c8afd4d98174a9183a2e9b08269ea 100644 --- a/tensorflow/python/keras/layers/embeddings.py +++ b/tensorflow/python/keras/layers/embeddings.py @@ -22,7 +22,7 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops @@ -112,6 +112,7 @@ class Embedding(Layer): self.activity_regularizer = regularizers.get(activity_regularizer) self.embeddings_constraint = constraints.get(embeddings_constraint) self.mask_zero = mask_zero + self.supports_masking = mask_zero self.input_length = input_length @tf_utils.shape_type_conversion @@ -127,8 +128,8 @@ class Embedding(Layer): def compute_mask(self, inputs, mask=None): if not self.mask_zero: return None - else: - return math_ops.not_equal(inputs, 0) + + return math_ops.not_equal(inputs, 0) @tf_utils.shape_type_conversion def compute_output_shape(self, input_shape): diff --git a/tensorflow/python/keras/layers/gru_test.py b/tensorflow/python/keras/layers/gru_test.py index 234434f7a0205c7dda80d308e4780cd761352d77..afef997b0049bb7dbf0cce4ac78f3c417d7d5846 100644 --- a/tensorflow/python/keras/layers/gru_test.py +++ b/tensorflow/python/keras/layers/gru_test.py @@ -29,7 +29,7 @@ from tensorflow.python.training.rmsprop import RMSPropOptimizer class GRULayerTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_return_sequences_GRU(self): num_samples = 2 timesteps = 3 @@ -41,7 +41,7 @@ class GRULayerTest(test.TestCase): 'return_sequences': True}, input_shape=(num_samples, timesteps, embedding_dim)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dynamic_behavior_GRU(self): num_samples = 2 timesteps = 3 @@ -55,7 +55,7 @@ class GRULayerTest(test.TestCase): y = np.random.random((num_samples, units)) model.train_on_batch(x, y) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dropout_GRU(self): num_samples = 2 timesteps = 3 @@ -68,7 +68,7 @@ class GRULayerTest(test.TestCase): 'recurrent_dropout': 0.1}, input_shape=(num_samples, timesteps, embedding_dim)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_implementation_mode_GRU(self): num_samples = 2 timesteps = 3 @@ -183,6 +183,7 @@ class GRULayerTest(test.TestCase): self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint) self.assertEqual(layer.cell.bias.constraint, b_constraint) + @tf_test_util.run_in_graph_and_eager_modes def test_with_masking_layer_GRU(self): layer_class = keras.layers.GRU with self.test_session(): @@ -192,7 +193,8 @@ class GRULayerTest(test.TestCase): model = keras.models.Sequential() model.add(keras.layers.Masking(input_shape=(3, 4))) model.add(layer_class(units=5, return_sequences=True, unroll=False)) - model.compile(loss='categorical_crossentropy', optimizer='adam') + model.compile(loss='categorical_crossentropy', + optimizer=RMSPropOptimizer(0.01)) model.fit(inputs, targets, epochs=1, batch_size=2, verbose=1) def test_from_config_GRU(self): diff --git a/tensorflow/python/keras/layers/local.py b/tensorflow/python/keras/layers/local.py index f222ea3083bad48094fbec7fe6750921f0233e35..0ebafe07cc45698200d0b1fa858a436c7a08820e 100644 --- a/tensorflow/python/keras/layers/local.py +++ b/tensorflow/python/keras/layers/local.py @@ -23,8 +23,8 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import conv_utils from tensorflow.python.keras.utils import tf_utils from tensorflow.python.util.tf_export import tf_export @@ -140,9 +140,9 @@ class LocallyConnected1D(Layer): if input_dim is None: raise ValueError('Axis 2 of input should be fully-defined. ' 'Found shape:', input_shape) - output_length = conv_utils.conv_output_length( + self.output_length = conv_utils.conv_output_length( input_length, self.kernel_size[0], self.padding, self.strides[0]) - self.kernel_shape = (output_length, self.kernel_size[0] * input_dim, + self.kernel_shape = (self.output_length, self.kernel_size[0] * input_dim, self.filters) self.kernel = self.add_weight( shape=self.kernel_shape, @@ -152,7 +152,7 @@ class LocallyConnected1D(Layer): constraint=self.kernel_constraint) if self.use_bias: self.bias = self.add_weight( - shape=(output_length, self.filters), + shape=(self.output_length, self.filters), initializer=self.bias_initializer, name='bias', regularizer=self.bias_regularizer, @@ -182,12 +182,13 @@ class LocallyConnected1D(Layer): return (input_shape[0], length, self.filters) def call(self, inputs): - output = K.local_conv1d(inputs, self.kernel, self.kernel_size, - self.strides, self.data_format) + output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides, + (self.output_length,), self.data_format) + if self.use_bias: output = K.bias_add(output, self.bias, data_format=self.data_format) - if self.activation is not None: - output = self.activation(output) + + output = self.activation(output) return output def get_config(self): @@ -400,9 +401,8 @@ class LocallyConnected2D(Layer): return (input_shape[0], rows, cols, self.filters) def call(self, inputs): - output = K.local_conv2d(inputs, self.kernel, self.kernel_size, self.strides, - (self.output_row, self.output_col), - self.data_format) + output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides, + (self.output_row, self.output_col), self.data_format) if self.use_bias: output = K.bias_add(output, self.bias, data_format=self.data_format) diff --git a/tensorflow/python/keras/layers/local_test.py b/tensorflow/python/keras/layers/local_test.py index 8df3f6b7bd741ad0b698fe500f0ac72e73985421..9639e0251f5a56e4130b13c0185792fe11da2532 100644 --- a/tensorflow/python/keras/layers/local_test.py +++ b/tensorflow/python/keras/layers/local_test.py @@ -28,7 +28,7 @@ from tensorflow.python.platform import test class LocallyConnectedLayersTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_locallyconnected_1d(self): num_samples = 2 num_steps = 8 @@ -92,7 +92,7 @@ class LocallyConnectedLayersTest(test.TestCase): self.assertEqual(layer.kernel.constraint, k_constraint) self.assertEqual(layer.bias.constraint, b_constraint) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_locallyconnected_2d(self): num_samples = 8 filters = 3 @@ -118,7 +118,7 @@ class LocallyConnectedLayersTest(test.TestCase): }, input_shape=(num_samples, num_row, num_col, stack_size)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_locallyconnected_2d_channels_first(self): num_samples = 8 filters = 3 diff --git a/tensorflow/python/keras/layers/lstm_test.py b/tensorflow/python/keras/layers/lstm_test.py index 87cb344bf82b73b6af9830a4428a5ba099135324..9802820fd0bf813e43a3bd017f5cb6d18f7f4d24 100644 --- a/tensorflow/python/keras/layers/lstm_test.py +++ b/tensorflow/python/keras/layers/lstm_test.py @@ -29,7 +29,7 @@ from tensorflow.python.training.rmsprop import RMSPropOptimizer class LSTMLayerTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_return_sequences_LSTM(self): num_samples = 2 timesteps = 3 @@ -56,7 +56,7 @@ class LSTMLayerTest(test.TestCase): outputs = model.layers[-1].output self.assertEquals(outputs.get_shape().as_list(), [None, timesteps, units]) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dynamic_behavior_LSTM(self): num_samples = 2 timesteps = 3 @@ -70,7 +70,7 @@ class LSTMLayerTest(test.TestCase): y = np.random.random((num_samples, units)) model.train_on_batch(x, y) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dropout_LSTM(self): num_samples = 2 timesteps = 3 @@ -83,7 +83,7 @@ class LSTMLayerTest(test.TestCase): 'recurrent_dropout': 0.1}, input_shape=(num_samples, timesteps, embedding_dim)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_implementation_mode_LSTM(self): num_samples = 2 timesteps = 3 @@ -197,6 +197,7 @@ class LSTMLayerTest(test.TestCase): self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint) self.assertEqual(layer.cell.bias.constraint, b_constraint) + @tf_test_util.run_in_graph_and_eager_modes def test_with_masking_layer_LSTM(self): layer_class = keras.layers.LSTM with self.test_session(): @@ -206,7 +207,8 @@ class LSTMLayerTest(test.TestCase): model = keras.models.Sequential() model.add(keras.layers.Masking(input_shape=(3, 4))) model.add(layer_class(units=5, return_sequences=True, unroll=False)) - model.compile(loss='categorical_crossentropy', optimizer='adam') + model.compile(loss='categorical_crossentropy', + optimizer=RMSPropOptimizer(0.01)) model.fit(inputs, targets, epochs=1, batch_size=2, verbose=1) def test_from_config_LSTM(self): @@ -311,7 +313,8 @@ class LSTMLayerTest(test.TestCase): output = keras.layers.LSTM(units)(inputs, initial_state=initial_state) model = keras.models.Model([inputs] + initial_state, output) - model.compile(loss='categorical_crossentropy', optimizer='adam') + model.compile(loss='categorical_crossentropy', + optimizer=RMSPropOptimizer(0.01)) inputs = np.random.random((num_samples, timesteps, embedding_dim)) initial_state = [np.random.random((num_samples, units)) diff --git a/tensorflow/python/keras/layers/merge_test.py b/tensorflow/python/keras/layers/merge_test.py index 8a097cf7f57d06155f26e3099554e34a54186189..39bc98d039624d50788e1b7995dc5fba300a5276 100644 --- a/tensorflow/python/keras/layers/merge_test.py +++ b/tensorflow/python/keras/layers/merge_test.py @@ -28,7 +28,7 @@ from tensorflow.python.platform import test class MergeLayersTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_add(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) @@ -76,7 +76,7 @@ class MergeLayersTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.add([i1]) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_multiply(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) @@ -92,7 +92,7 @@ class MergeLayersTest(test.TestCase): self.assertEqual(out.shape, (2, 4, 5)) self.assertAllClose(out, x1 * x2 * x3, atol=1e-4) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_average(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) @@ -106,7 +106,7 @@ class MergeLayersTest(test.TestCase): self.assertEqual(out.shape, (2, 4, 5)) self.assertAllClose(out, 0.5 * (x1 + x2), atol=1e-4) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_maximum(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) @@ -120,7 +120,7 @@ class MergeLayersTest(test.TestCase): self.assertEqual(out.shape, (2, 4, 5)) self.assertAllClose(out, np.maximum(x1, x2), atol=1e-4) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_minimum(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) @@ -134,7 +134,7 @@ class MergeLayersTest(test.TestCase): self.assertEqual(out.shape, (2, 4, 5)) self.assertAllClose(out, np.minimum(x1, x2), atol=1e-4) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_concatenate(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) @@ -169,7 +169,7 @@ class MergeLayersTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'called on a list'): keras.layers.concatenate([i1], axis=-1) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_dot(self): i1 = keras.layers.Input(shape=(4,)) i2 = keras.layers.Input(shape=(4,)) @@ -215,7 +215,7 @@ class MergeLayersTest(test.TestCase): dot = keras.layers.Dot(1) dot.compute_output_shape(1) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_merge_subtract(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) diff --git a/tensorflow/python/keras/layers/noise.py b/tensorflow/python/keras/layers/noise.py index a895caa25b91702d92002f84fe44b5b5c3a8ca0c..cb7cee3ebc3ebd2413836b876f2aaf21985f1d9c 100644 --- a/tensorflow/python/keras/layers/noise.py +++ b/tensorflow/python/keras/layers/noise.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.python.keras import backend as K -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/python/keras/layers/noise_test.py b/tensorflow/python/keras/layers/noise_test.py index bde2185f03bd45c1c9fecbd6fe5544a17e9c04ef..aa2be62390b0dcf0656a533cba9bdbe9ceee09dd 100644 --- a/tensorflow/python/keras/layers/noise_test.py +++ b/tensorflow/python/keras/layers/noise_test.py @@ -40,7 +40,7 @@ class NoiseLayersTest(test.TestCase): kwargs={'rate': 0.5}, input_shape=(3, 2, 3)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_AlphaDropout(self): testing_utils.layer_test( keras.layers.AlphaDropout, diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index ff51eadee9928902cb310b28569c06a49d8e5ecc..cd26e04c39fca2987e15bb277569853a3bedea42 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -26,16 +26,17 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.util.tf_export import tf_export @@ -180,11 +181,6 @@ class BatchNormalization(Layer): self.renorm_clipping = renorm_clipping self.renorm_momentum = renorm_momentum - def _add_tower_local_variable(self, *args, **kwargs): - tower_context = distribute_lib.get_tower_context() - with tower_context.tower_local_var_scope('mean'): - return self.add_weight(*args, **kwargs) - def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape) if not input_shape.ndims: @@ -312,19 +308,23 @@ class BatchNormalization(Layer): self._scope.set_partitioner(None) else: partitioner = None - self.moving_mean = self._add_tower_local_variable( + self.moving_mean = self.add_weight( name='moving_mean', shape=param_shape, dtype=param_dtype, initializer=self.moving_mean_initializer, - trainable=False) + synchronization=variable_scope.VariableSynchronization.ON_READ, + trainable=False, + aggregation=variable_scope.VariableAggregation.MEAN) - self.moving_variance = self._add_tower_local_variable( + self.moving_variance = self.add_weight( name='moving_variance', shape=param_shape, dtype=param_dtype, initializer=self.moving_variance_initializer, - trainable=False) + synchronization=variable_scope.VariableSynchronization.ON_READ, + trainable=False, + aggregation=variable_scope.VariableAggregation.MEAN) if self.renorm: # Create variables to maintain the moving mean and standard deviation. @@ -335,24 +335,26 @@ class BatchNormalization(Layer): # stack to be cleared. The nested ones use a `lambda` to set the desired # device and ignore any devices that may be set by the custom getter. def _renorm_variable(name, shape): - var = self._add_tower_local_variable( + var = self.add_weight( name=name, shape=shape, dtype=param_dtype, initializer=init_ops.zeros_initializer(), - trainable=False) + synchronization=variable_scope.VariableSynchronization.ON_READ, + trainable=False, + aggregation=variable_scope.VariableAggregation.MEAN) return var - with distribute_lib.get_distribution_strategy().colocate_vars_with( - self.moving_mean): + with distribution_strategy_context.get_distribution_strategy( + ).colocate_vars_with(self.moving_mean): self.renorm_mean = _renorm_variable('renorm_mean', param_shape) self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ()) # We initialize renorm_stddev to 0, and maintain the (0-initialized) # renorm_stddev_weight. This allows us to (1) mix the average # stddev with the minibatch stddev early in training, and (2) compute # the unbiased average stddev by dividing renorm_stddev by the weight. - with distribute_lib.get_distribution_strategy().colocate_vars_with( - self.moving_variance): + with distribution_strategy_context.get_distribution_strategy( + ).colocate_vars_with(self.moving_variance): self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape) self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight', ()) @@ -364,11 +366,12 @@ class BatchNormalization(Layer): def _assign_moving_average(self, variable, value, momentum): with ops.name_scope(None, 'AssignMovingAvg', [variable, value, momentum]) as scope: - decay = ops.convert_to_tensor(1.0 - momentum, name='decay') - if decay.dtype != variable.dtype.base_dtype: - decay = math_ops.cast(decay, variable.dtype.base_dtype) - update_delta = (variable - value) * decay - return state_ops.assign_sub(variable, update_delta, name=scope) + with ops.colocate_with(variable): + decay = ops.convert_to_tensor(1.0 - momentum, name='decay') + if decay.dtype != variable.dtype.base_dtype: + decay = math_ops.cast(decay, variable.dtype.base_dtype) + update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay + return state_ops.assign_sub(variable, update_delta, name=scope) def _fused_batch_norm(self, inputs, training): """Returns the output of fused batch norm.""" @@ -616,6 +619,10 @@ class BatchNormalization(Layer): else: mean, variance = self.moving_mean, self.moving_variance + mean = math_ops.cast(mean, inputs.dtype) + variance = math_ops.cast(variance, inputs.dtype) + if offset is not None: + offset = math_ops.cast(offset, inputs.dtype) outputs = nn.batch_normalization(inputs, _broadcast(mean), _broadcast(variance), diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py index b22f3bd1529812f6b5f63efe5cf6b6133db97f07..a97b4cac469f596112481e1b3b3f93b17ea20074 100644 --- a/tensorflow/python/keras/layers/normalization_test.py +++ b/tensorflow/python/keras/layers/normalization_test.py @@ -95,6 +95,24 @@ class NormalizationLayersTest(test.TestCase): np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1) np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) + def test_batchnorm_mixed_precision(self): + with self.test_session(): + model = keras.models.Sequential() + norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8) + model.add(norm) + model.compile(loss='mse', optimizer='sgd') + + # centered on 5.0, variance 10.0 + x = np.random.normal( + loc=5.0, scale=10.0, size=(1000, 10)).astype(np.float16) + model.fit(x, x, epochs=4, verbose=0) + out = model.predict(x) + out -= keras.backend.eval(norm.beta) + out /= keras.backend.eval(norm.gamma) + + np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1) + np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) + def test_batchnorm_convnet(self): if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True): diff --git a/tensorflow/python/keras/layers/pooling.py b/tensorflow/python/keras/layers/pooling.py index 10a82b285eff6f6b414e67441ceb88976ca2368f..912e8bd619db8b35a54853c0752382479567fd04 100644 --- a/tensorflow/python/keras/layers/pooling.py +++ b/tensorflow/python/keras/layers/pooling.py @@ -20,8 +20,8 @@ from __future__ import print_function from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import backend -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import conv_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn diff --git a/tensorflow/python/keras/layers/pooling_test.py b/tensorflow/python/keras/layers/pooling_test.py index cbd58a22879975b7dbaab8290f59cee573b272cd..2cd9939e66ff869dac5058d2dd00d8d495e40f55 100644 --- a/tensorflow/python/keras/layers/pooling_test.py +++ b/tensorflow/python/keras/layers/pooling_test.py @@ -27,14 +27,14 @@ from tensorflow.python.platform import test class GlobalPoolingTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_globalpooling_1d(self): testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D, input_shape=(3, 4, 5)) testing_utils.layer_test( keras.layers.pooling.GlobalAveragePooling1D, input_shape=(3, 4, 5)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_globalpooling_2d(self): testing_utils.layer_test( keras.layers.pooling.GlobalMaxPooling2D, @@ -53,7 +53,7 @@ class GlobalPoolingTest(test.TestCase): kwargs={'data_format': 'channels_last'}, input_shape=(3, 5, 6, 4)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_globalpooling_3d(self): testing_utils.layer_test( keras.layers.pooling.GlobalMaxPooling3D, @@ -75,7 +75,7 @@ class GlobalPoolingTest(test.TestCase): class Pooling2DTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_maxpooling_2d(self): pool_size = (3, 3) for strides in [(1, 1), (2, 2)]: @@ -88,7 +88,7 @@ class Pooling2DTest(test.TestCase): }, input_shape=(3, 5, 6, 4)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_averagepooling_2d(self): testing_utils.layer_test( keras.layers.AveragePooling2D, @@ -122,7 +122,7 @@ class Pooling2DTest(test.TestCase): class Pooling3DTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_maxpooling_3d(self): pool_size = (3, 3, 3) testing_utils.layer_test( @@ -141,7 +141,7 @@ class Pooling3DTest(test.TestCase): }, input_shape=(3, 4, 11, 12, 10)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_averagepooling_3d(self): pool_size = (3, 3, 3) testing_utils.layer_test( @@ -163,7 +163,7 @@ class Pooling3DTest(test.TestCase): class Pooling1DTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_maxpooling_1d(self): for padding in ['valid', 'same']: for stride in [1, 2]: @@ -173,7 +173,7 @@ class Pooling1DTest(test.TestCase): 'padding': padding}, input_shape=(3, 5, 4)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_averagepooling_1d(self): for padding in ['valid', 'same']: for stride in [1, 2]: diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index 7e509fb45182653d938adfd679e204cc7ea1e900..12c82a53f6eeb0170456e5abfd6e1f2bb4e0f833 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -19,7 +19,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numbers import numpy as np from tensorflow.python.eager import context @@ -29,14 +28,16 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers from tensorflow.python.keras import regularizers -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -86,17 +87,24 @@ class StackedRNNCells(Layer): # (assuming one LSTM has states [h, c]) state_size = [] for cell in self.cells[::-1]: - if hasattr(cell.state_size, '__len__'): + if _is_multiple_state(cell.state_size): state_size += list(cell.state_size) else: state_size.append(cell.state_size) return tuple(state_size) + @property + def output_size(self): + if hasattr(self.cells[-1], 'output_size'): + return self.cells[-1].output_size + else: + return self.state_size[0] + def call(self, inputs, states, constants=None, **kwargs): # Recover per-cell states. nested_states = [] for cell in self.cells[::-1]: - if hasattr(cell.state_size, '__len__'): + if _is_multiple_state(cell.state_size): nested_states.append(states[:len(cell.state_size)]) states = states[len(cell.state_size):] else: @@ -133,11 +141,12 @@ class StackedRNNCells(Layer): cell.build([input_shape] + constants_shape) else: cell.build(input_shape) - if hasattr(cell.state_size, '__len__'): + if _is_multiple_state(cell.state_size): output_dim = cell.state_size[0] else: output_dim = cell.state_size - input_shape = (input_shape[0], output_dim) + input_shape = tuple([input_shape[0]] + + tensor_shape.as_shape(output_dim).as_list()) self.built = True def get_config(self): @@ -235,22 +244,26 @@ class RNN(Layer): """Base class for recurrent layers. Arguments: - cell: A RNN cell instance. A RNN cell is a class that has: + cell: A RNN cell instance or a list of RNN cell instances. + A RNN cell is a class that has: - a `call(input_at_t, states_at_t)` method, returning `(output_at_t, states_at_t_plus_1)`. The call method of the cell can also take the optional argument `constants`, see section "Note on passing external constants" below. - a `state_size` attribute. This can be a single integer - (single state) in which case it is - the size of the recurrent state - (which should be the same as the size of the cell output). - This can also be a list/tuple of integers - (one size per state). In this case, the first entry - (`state_size[0]`) should be the same as - the size of the cell output. - It is also possible for `cell` to be a list of RNN cell instances, - in which cases the cells get stacked on after the other in the RNN, - implementing an efficient stacked RNN. + (single state) in which case it is the size of the recurrent + state. This can also be a list/tuple of integers (one size per + state). + The `state_size` can also be TensorShape or tuple/list of + TensorShape, to represent high dimension state. + - a `output_size` attribute. This can be a single integer or a + TensorShape, which represent the shape of the output. For backward + compatible reason, if this attribute is not available for the + cell, the value will be inferred by the first element of the + `state_size`. + In the case that `cell` is a list of RNN cell instances, the cells + will be stacked on after the other in the RNN, implementing an + efficient stacked RNN. return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence. return_state: Boolean. Whether to return the last state @@ -267,9 +280,8 @@ class RNN(Layer): Unrolling can speed-up a RNN, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences. - input_dim: dimensionality of the input (integer). - This argument (or alternatively, - the keyword argument `input_shape`) + input_dim: dimensionality of the input (integer or tuple of integers). + This argument (or alternatively, the keyword argument `input_shape`) is required when using this layer as the first layer in a model. input_length: Length of input sequences, to be specified when it is constant. @@ -282,15 +294,18 @@ class RNN(Layer): (e.g. via the `input_shape` argument) Input shape: - 3D tensor with shape `(batch_size, timesteps, input_dim)`. + N-D tensor with shape `(batch_size, timesteps, ...)`. Output shape: - if `return_state`: a list of tensors. The first tensor is the output. The remaining tensors are the last states, - each with shape `(batch_size, units)`. - - if `return_sequences`: 3D tensor with shape - `(batch_size, timesteps, units)`. - - else, 2D tensor with shape `(batch_size, units)`. + each with shape `(batch_size, state_size)`, where `state_size` could + be a high dimension tensor shape. + - if `return_sequences`: N-D tensor with shape + `(batch_size, timesteps, output_size)`, where `output_size` could + be a high dimension tensor shape. + - else, N-D tensor with shape `(batch_size, output_size)`, where + `output_size` could be a high dimension tensor shape. # Masking This layer supports masking for input data with a variable number @@ -402,6 +417,8 @@ class RNN(Layer): 'one integer per RNN state).') super(RNN, self).__init__(**kwargs) self.cell = cell + if isinstance(cell, checkpointable.CheckpointableBase): + self._track_checkpointable(self.cell, name='cell') self.return_sequences = return_sequences self.return_state = return_state self.go_backwards = go_backwards @@ -409,7 +426,7 @@ class RNN(Layer): self.unroll = unroll self.supports_masking = True - self.input_spec = [InputSpec(ndim=3)] + self.input_spec = [None] # The input shape is unknown yet, at least rank 3. self.state_spec = None self._states = None self.constants_spec = None @@ -418,11 +435,8 @@ class RNN(Layer): @property def states(self): if self._states is None: - if isinstance(self.cell.state_size, numbers.Integral): - num_states = 1 - else: - num_states = len(self.cell.state_size) - return [None for _ in range(num_states)] + state = nest.map_structure(lambda _: None, self.cell.state_size) + return state if nest.is_sequence(self.cell.state_size) else [state] return self._states @states.setter @@ -434,19 +448,27 @@ class RNN(Layer): if isinstance(input_shape, list): input_shape = input_shape[0] - if hasattr(self.cell.state_size, '__len__'): + if _is_multiple_state(self.cell.state_size): state_size = self.cell.state_size else: state_size = [self.cell.state_size] - output_dim = state_size[0] + + if hasattr(self.cell, 'output_size'): + output_dim = tensor_shape.as_shape(self.cell.output_size).as_list() + else: + # Note that state_size[0] could be a tensor_shape or int. + output_dim = tensor_shape.as_shape(state_size[0]).as_list() if self.return_sequences: - output_shape = (input_shape[0], input_shape[1], output_dim) + output_shape = tuple([input_shape[0], input_shape[1]] + output_dim) else: - output_shape = (input_shape[0], output_dim) + output_shape = tuple([input_shape[0]] + output_dim) if self.return_state: - state_shape = [(input_shape[0], dim) for dim in state_size] + state_shape = [ + tuple([input_shape[0]] + tensor_shape.as_shape(dim).as_list()) + for dim in state_size + ] return [output_shape] + state_shape else: return output_shape @@ -474,49 +496,83 @@ class RNN(Layer): input_shape = input_shape[0] batch_size = input_shape[0] if self.stateful else None - input_dim = input_shape[-1] - self.input_spec[0] = InputSpec(shape=(batch_size, None, input_dim)) + input_dim = input_shape[2:] + self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_dim) # allow cell (if layer) to build before we set or validate state_spec if isinstance(self.cell, Layer): - step_input_shape = (input_shape[0],) + input_shape[2:] + step_input_shape = (input_shape[0],) + input_dim if constants_shape is not None: self.cell.build([step_input_shape] + constants_shape) else: self.cell.build(step_input_shape) # set or validate state_spec - if hasattr(self.cell.state_size, '__len__'): + if _is_multiple_state(self.cell.state_size): state_size = list(self.cell.state_size) else: state_size = [self.cell.state_size] if self.state_spec is not None: # initial_state was passed in call, check compatibility - if [spec.shape[-1] for spec in self.state_spec] != state_size: - raise ValueError( - 'An `initial_state` was passed that is not compatible with ' - '`cell.state_size`. Received `state_spec`={}; ' - 'however `cell.state_size` is ' - '{}'.format(self.state_spec, self.cell.state_size)) + self._validate_state_spec(state_size, self.state_spec) else: - self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size] + self.state_spec = [ + InputSpec(shape=[None] + tensor_shape.as_shape(dim).as_list()) + for dim in state_size + ] if self.stateful: self.reset_states() self.built = True + @staticmethod + def _validate_state_spec(cell_state_sizes, init_state_specs): + """Validate the state spec between the initial_state and the state_size. + + Args: + cell_state_sizes: list, the `state_size` attribute from the cell. + init_state_specs: list, the `state_spec` from the initial_state that is + passed in call() + + Raises: + ValueError: When initial state spec is not compatible with the state size. + """ + validation_error = ValueError( + 'An `initial_state` was passed that is not compatible with ' + '`cell.state_size`. Received `state_spec`={}; ' + 'however `cell.state_size` is ' + '{}'.format(init_state_specs, cell_state_sizes)) + if len(cell_state_sizes) == len(init_state_specs): + for i in range(len(cell_state_sizes)): + if not tensor_shape.TensorShape( + # Ignore the first axis for init_state which is for batch + init_state_specs[i].shape[1:]).is_compatible_with( + tensor_shape.TensorShape(cell_state_sizes[i])): + raise validation_error + else: + raise validation_error + def get_initial_state(self, inputs): - # build an all-zero tensor of shape (samples, output_dim) + # build an all-zero tensor of shape (batch, cell.state_size) initial_state = array_ops.zeros_like(inputs) - # shape of initial_state = (samples, timesteps, input_dim) - initial_state = math_ops.reduce_sum(initial_state, axis=(1, 2)) - # shape of initial_state = (samples,) - initial_state = array_ops.expand_dims(initial_state, axis=-1) - # shape of initial_state = (samples, 1) - if hasattr(self.cell.state_size, '__len__'): - return [K.tile(initial_state, [1, dim]) for dim in self.cell.state_size] + # shape of initial_state = (batch, timesteps, ...) + initial_state = math_ops.reduce_sum( + initial_state, axis=list(range(1, len(inputs.shape)))) + # shape of initial_state = (batch,) + if _is_multiple_state(self.cell.state_size): + states = [] + for dims in self.cell.state_size: + state = initial_state + flat_dims = tensor_shape.as_shape(dims).as_list() + # reshape the state to (batch, 1, 1, ....) and then expand each state. + state = array_ops.reshape(state, [-1,] + [1] * len(flat_dims)) + states.append(K.tile(state, [1] + flat_dims)) + return states else: - return [K.tile(initial_state, [1, self.cell.state_size])] + flat_dims = tensor_shape.as_shape(self.cell.state_size).as_list() + initial_state = array_ops.reshape( + initial_state, [-1] + [1] * len(flat_dims)) + return [K.tile(initial_state, [1] + flat_dims)] def __call__(self, inputs, initial_state=None, constants=None, **kwargs): inputs, initial_state, constants = _standardize_args(inputs, @@ -614,6 +670,8 @@ class RNN(Layer): if generic_utils.has_arg(self.cell.call, 'training'): kwargs['training'] = training + # TF RNN cells expect single tensor as state instead of list wrapped tensor. + is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None if constants: if not generic_utils.has_arg(self.cell.call, 'constants'): raise ValueError('RNN cell does not support constants') @@ -621,11 +679,21 @@ class RNN(Layer): def step(inputs, states): constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type - return self.cell.call(inputs, states, constants=constants, **kwargs) + + states = states[0] if len(states) == 1 and is_tf_rnn_cell else states + output, new_states = self.cell.call( + inputs, states, constants=constants, **kwargs) + if not nest.is_sequence(new_states): + new_states = [new_states] + return output, new_states else: def step(inputs, states): - return self.cell.call(inputs, states, **kwargs) + states = states[0] if len(states) == 1 and is_tf_rnn_cell else states + output, new_states = self.cell.call(inputs, states, **kwargs) + if not nest.is_sequence(new_states): + new_states = [new_states] + return output, new_states last_output, outputs, states = K.rnn( step, @@ -679,19 +747,26 @@ class RNN(Layer): '`batch_shape` argument to your Input layer.') # initialize state if None if self.states[0] is None: - if hasattr(self.cell.state_size, '__len__'): + if _is_multiple_state(self.cell.state_size): self.states = [ - K.zeros((batch_size, dim)) for dim in self.cell.state_size + K.zeros([batch_size] + tensor_shape.as_shape(dim).as_list()) + for dim in self.cell.state_size ] else: - self.states = [K.zeros((batch_size, self.cell.state_size))] + self.states = [ + K.zeros([batch_size] + + tensor_shape.as_shape(self.cell.state_size).as_list()) + ] elif states is None: - if hasattr(self.cell.state_size, '__len__'): + if _is_multiple_state(self.cell.state_size): for state, dim in zip(self.states, self.cell.state_size): - K.set_value(state, np.zeros((batch_size, dim))) + K.set_value(state, + np.zeros([batch_size] + + tensor_shape.as_shape(dim).as_list())) else: - K.set_value(self.states[0], np.zeros((batch_size, - self.cell.state_size))) + K.set_value(self.states[0], np.zeros( + [batch_size] + + tensor_shape.as_shape(self.cell.state_size).as_list())) else: if not isinstance(states, (list, tuple)): states = [states] @@ -701,11 +776,12 @@ class RNN(Layer): 'but it received ' + str(len(states)) + ' state values. Input received: ' + str(states)) for index, (value, state) in enumerate(zip(states, self.states)): - if hasattr(self.cell.state_size, '__len__'): + if _is_multiple_state(self.cell.state_size): dim = self.cell.state_size[index] else: dim = self.cell.state_size - if value.shape != (batch_size, dim): + if value.shape != tuple([batch_size] + + tensor_shape.as_shape(dim).as_list()): raise ValueError( 'State ' + str(index) + ' is incompatible with layer ' + self.name + ': expected shape=' + str( @@ -843,6 +919,7 @@ class SimpleRNNCell(Layer): self.dropout = min(1., max(0., dropout)) self.recurrent_dropout = min(1., max(0., recurrent_dropout)) self.state_size = self.units + self.output_size = self.units self._dropout_mask = None self._recurrent_dropout_mask = None @@ -1246,6 +1323,7 @@ class GRUCell(Layer): self.implementation = implementation self.reset_after = reset_after self.state_size = self.units + self.output_size = self.units self._dropout_mask = None self._recurrent_dropout_mask = None @@ -1791,6 +1869,7 @@ class LSTMCell(Layer): self.recurrent_dropout = min(1., max(0., recurrent_dropout)) self.implementation = implementation self.state_size = (self.units, self.units) + self.output_size = self.units self._dropout_mask = None self._recurrent_dropout_mask = None @@ -2227,342 +2306,6 @@ def _generate_dropout_mask(ones, rate, training=None, count=1): return K.in_train_phase(dropped_inputs, ones, training=training) -class Recurrent(Layer): - """Deprecated abstract base class for recurrent layers. - - It still exists because it is leveraged by the convolutional-recurrent layers. - It will be removed entirely in the future. - It was never part of the public API. - Do not use. - - Arguments: - weights: list of Numpy arrays to set as initial weights. - The list should have 3 elements, of shapes: - `[(input_dim, output_dim), (output_dim, output_dim), (output_dim,)]`. - return_sequences: Boolean. Whether to return the last output - in the output sequence, or the full sequence. - return_state: Boolean. Whether to return the last state - in addition to the output. - go_backwards: Boolean (default False). - If True, process the input sequence backwards and return the - reversed sequence. - stateful: Boolean (default False). If True, the last state - for each sample at index i in a batch will be used as initial - state for the sample of index i in the following batch. - unroll: Boolean (default False). - If True, the network will be unrolled, - else a symbolic loop will be used. - Unrolling can speed-up a RNN, - although it tends to be more memory-intensive. - Unrolling is only suitable for short sequences. - implementation: one of {0, 1, or 2}. - If set to 0, the RNN will use - an implementation that uses fewer, larger matrix products, - thus running faster on CPU but consuming more memory. - If set to 1, the RNN will use more matrix products, - but smaller ones, thus running slower - (may actually be faster on GPU) while consuming less memory. - If set to 2 (LSTM/GRU only), - the RNN will combine the input gate, - the forget gate and the output gate into a single matrix, - enabling more time-efficient parallelization on the GPU. - Note: RNN dropout must be shared for all gates, - resulting in a slightly reduced regularization. - input_dim: dimensionality of the input (integer). - This argument (or alternatively, the keyword argument `input_shape`) - is required when using this layer as the first layer in a model. - input_length: Length of input sequences, to be specified - when it is constant. - This argument is required if you are going to connect - `Flatten` then `Dense` layers upstream - (without it, the shape of the dense outputs cannot be computed). - Note that if the recurrent layer is not the first layer - in your model, you would need to specify the input length - at the level of the first layer - (e.g. via the `input_shape` argument) - - Input shape: - 3D tensor with shape `(batch_size, timesteps, input_dim)`, - (Optional) 2D tensors with shape `(batch_size, output_dim)`. - - Output shape: - - if `return_state`: a list of tensors. The first tensor is - the output. The remaining tensors are the last states, - each with shape `(batch_size, units)`. - - if `return_sequences`: 3D tensor with shape - `(batch_size, timesteps, units)`. - - else, 2D tensor with shape `(batch_size, units)`. - - # Masking - This layer supports masking for input data with a variable number - of timesteps. To introduce masks to your data, - use an `Embedding` layer with the `mask_zero` parameter - set to `True`. - - # Note on using statefulness in RNNs - You can set RNN layers to be 'stateful', which means that the states - computed for the samples in one batch will be reused as initial states - for the samples in the next batch. This assumes a one-to-one mapping - between samples in different successive batches. - - To enable statefulness: - - specify `stateful=True` in the layer constructor. - - specify a fixed batch size for your model, by passing - if sequential model: - `batch_input_shape=(...)` to the first layer in your model. - else for functional model with 1 or more Input layers: - `batch_shape=(...)` to all the first layers in your model. - This is the expected shape of your inputs - *including the batch size*. - It should be a tuple of integers, e.g. `(32, 10, 100)`. - - specify `shuffle=False` when calling fit(). - - To reset the states of your model, call `.reset_states()` on either - a specific layer, or on your entire model. - - # Note on specifying the initial state of RNNs - You can specify the initial state of RNN layers symbolically by - calling them with the keyword argument `initial_state`. The value of - `initial_state` should be a tensor or list of tensors representing - the initial state of the RNN layer. - - You can specify the initial state of RNN layers numerically by - calling `reset_states` with the keyword argument `states`. The value of - `states` should be a numpy array or list of numpy arrays representing - the initial state of the RNN layer. - """ - - def __init__(self, - return_sequences=False, - return_state=False, - go_backwards=False, - stateful=False, - unroll=False, - implementation=0, - **kwargs): - super(Recurrent, self).__init__(**kwargs) - self.return_sequences = return_sequences - self.return_state = return_state - self.go_backwards = go_backwards - self.stateful = stateful - self.unroll = unroll - self.implementation = implementation - self.supports_masking = True - self.input_spec = [InputSpec(ndim=3)] - self.state_spec = None - self.dropout = 0 - self.recurrent_dropout = 0 - - @tf_utils.shape_type_conversion - def compute_output_shape(self, input_shape): - if isinstance(input_shape, list): - input_shape = input_shape[0] - input_shape = tensor_shape.TensorShape(input_shape).as_list() - if self.return_sequences: - output_shape = (input_shape[0], input_shape[1], self.units) - else: - output_shape = (input_shape[0], self.units) - - if self.return_state: - state_shape = [tensor_shape.TensorShape( - (input_shape[0], self.units)) for _ in self.states] - return [tensor_shape.TensorShape(output_shape)] + state_shape - return tensor_shape.TensorShape(output_shape) - - def compute_mask(self, inputs, mask): - if isinstance(mask, list): - mask = mask[0] - output_mask = mask if self.return_sequences else None - if self.return_state: - state_mask = [None for _ in self.states] - return [output_mask] + state_mask - return output_mask - - def step(self, inputs, states): - raise NotImplementedError - - def get_constants(self, inputs, training=None): - return [] - - def get_initial_state(self, inputs): - # build an all-zero tensor of shape (samples, output_dim) - initial_state = array_ops.zeros_like(inputs) - # shape of initial_state = (samples, timesteps, input_dim) - initial_state = math_ops.reduce_sum(initial_state, axis=(1, 2)) - # shape of initial_state = (samples,) - initial_state = array_ops.expand_dims(initial_state, axis=-1) - # shape of initial_state = (samples, 1) - initial_state = K.tile(initial_state, [1, - self.units]) # (samples, output_dim) - initial_state = [initial_state for _ in range(len(self.states))] - return initial_state - - def preprocess_input(self, inputs, training=None): - return inputs - - def __call__(self, inputs, initial_state=None, **kwargs): - if (isinstance(inputs, (list, tuple)) and - len(inputs) > 1 - and initial_state is None): - initial_state = inputs[1:] - inputs = inputs[0] - - # If `initial_state` is specified, - # and if it a Keras tensor, - # then add it to the inputs and temporarily - # modify the input spec to include the state. - if initial_state is None: - return super(Recurrent, self).__call__(inputs, **kwargs) - - if not isinstance(initial_state, (list, tuple)): - initial_state = [initial_state] - - is_keras_tensor = hasattr(initial_state[0], '_keras_history') - for tensor in initial_state: - if hasattr(tensor, '_keras_history') != is_keras_tensor: - raise ValueError('The initial state of an RNN layer cannot be' - ' specified with a mix of Keras tensors and' - ' non-Keras tensors') - - if is_keras_tensor: - # Compute the full input spec, including state - input_spec = self.input_spec - state_spec = self.state_spec - if not isinstance(input_spec, list): - input_spec = [input_spec] - if not isinstance(state_spec, list): - state_spec = [state_spec] - self.input_spec = input_spec + state_spec - - # Compute the full inputs, including state - inputs = [inputs] + list(initial_state) - - # Perform the call - output = super(Recurrent, self).__call__(inputs, **kwargs) - - # Restore original input spec - self.input_spec = input_spec - return output - else: - kwargs['initial_state'] = initial_state - return super(Recurrent, self).__call__(inputs, **kwargs) - - def call(self, inputs, mask=None, training=None, initial_state=None): - # input shape: `(samples, time (padded with zeros), input_dim)` - # note that the .build() method of subclasses MUST define - # self.input_spec and self.state_spec with complete input shapes. - if isinstance(inputs, list): - initial_state = inputs[1:] - inputs = inputs[0] - elif initial_state is not None: - pass - elif self.stateful: - initial_state = self.states - else: - initial_state = self.get_initial_state(inputs) - - if isinstance(mask, list): - mask = mask[0] - - if len(initial_state) != len(self.states): - raise ValueError('Layer has ' + str(len(self.states)) + - ' states but was passed ' + str(len(initial_state)) + - ' initial states.') - input_shape = K.int_shape(inputs) - if self.unroll and input_shape[1] is None: - raise ValueError('Cannot unroll a RNN if the ' - 'time dimension is undefined. \n' - '- If using a Sequential model, ' - 'specify the time dimension by passing ' - 'an `input_shape` or `batch_input_shape` ' - 'argument to your first layer. If your ' - 'first layer is an Embedding, you can ' - 'also use the `input_length` argument.\n' - '- If using the functional API, specify ' - 'the time dimension by passing a `shape` ' - 'or `batch_shape` argument to your Input layer.') - constants = self.get_constants(inputs, training=None) - preprocessed_input = self.preprocess_input(inputs, training=None) - last_output, outputs, states = K.rnn( - self.step, - preprocessed_input, - initial_state, - go_backwards=self.go_backwards, - mask=mask, - constants=constants, - unroll=self.unroll) - if self.stateful: - updates = [] - for i in range(len(states)): - updates.append(state_ops.assign(self.states[i], states[i])) - self.add_update(updates, inputs) - - # Properly set learning phase - if 0 < self.dropout + self.recurrent_dropout: - last_output._uses_learning_phase = True - outputs._uses_learning_phase = True - - if not self.return_sequences: - outputs = last_output - - if self.return_state: - if not isinstance(states, (list, tuple)): - states = [states] - else: - states = list(states) - return [outputs] + states - return outputs - - def reset_states(self, states=None): - if not self.stateful: - raise AttributeError('Layer must be stateful.') - batch_size = self.input_spec[0].shape[0] - if not batch_size: - raise ValueError('If a RNN is stateful, it needs to know ' - 'its batch size. Specify the batch size ' - 'of your input tensors: \n' - '- If using a Sequential model, ' - 'specify the batch size by passing ' - 'a `batch_input_shape` ' - 'argument to your first layer.\n' - '- If using the functional API, specify ' - 'the time dimension by passing a ' - '`batch_shape` argument to your Input layer.') - # initialize state if None - if self.states[0] is None: - self.states = [K.zeros((batch_size, self.units)) for _ in self.states] - elif states is None: - for state in self.states: - K.set_value(state, np.zeros((batch_size, self.units))) - else: - if not isinstance(states, (list, tuple)): - states = [states] - if len(states) != len(self.states): - raise ValueError('Layer ' + self.name + ' expects ' + - str(len(self.states)) + ' states, ' - 'but it received ' + str(len(states)) + - ' state values. Input received: ' + str(states)) - for index, (value, state) in enumerate(zip(states, self.states)): - if value.shape != (batch_size, self.units): - raise ValueError('State ' + str(index) + - ' is incompatible with layer ' + self.name + - ': expected shape=' + str((batch_size, self.units)) + - ', found shape=' + str(value.shape)) - K.set_value(state, value) - - def get_config(self): - config = { - 'return_sequences': self.return_sequences, - 'return_state': self.return_state, - 'go_backwards': self.go_backwards, - 'stateful': self.stateful, - 'unroll': self.unroll, - 'implementation': self.implementation - } - base_config = super(Recurrent, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - def _standardize_args(inputs, initial_state, constants, num_constants): """Standardizes `__call__` to a single list of tensor inputs. @@ -2605,3 +2348,9 @@ def _standardize_args(inputs, initial_state, constants, num_constants): constants = to_list_or_none(constants) return inputs, initial_state, constants + + +def _is_multiple_state(state_size): + """Check whether the state_size contains multiple states.""" + return (hasattr(state_size, '__len__') and + not isinstance(state_size, tensor_shape.TensorShape)) diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py index 802374d2d28d792c1e32bf5095b928f569144b49..13bd07052873d7895ecee8a0227c0a30b933f31d 100644 --- a/tensorflow/python/keras/layers/recurrent_test.py +++ b/tensorflow/python/keras/layers/recurrent_test.py @@ -24,10 +24,13 @@ from __future__ import print_function import numpy as np from tensorflow.python import keras +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import special_math_ops from tensorflow.python.ops import state_ops from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import util as checkpointable_util class RNNTest(test.TestCase): @@ -556,5 +559,179 @@ class RNNTest(test.TestCase): [tuple(o.as_list()) for o in output_shape], expected_output_shape) + def test_checkpointable_dependencies(self): + rnn = keras.layers.SimpleRNN + with self.test_session(): + x = np.random.random((2, 2, 2)) + y = np.random.random((2, 2)) + model = keras.models.Sequential() + model.add(rnn(2)) + model.compile(optimizer='rmsprop', loss='mse') + model.fit(x, y, epochs=1, batch_size=1) + + # check whether the model variables are present in the + # checkpointable list of objects + checkpointed_objects = set(checkpointable_util.list_objects(model)) + for v in model.variables: + self.assertIn(v, checkpointed_objects) + + def test_high_dimension_RNN(self): + with self.test_session(): + # Basic test case. + unit_a = 10 + unit_b = 20 + input_a = 5 + input_b = 10 + batch = 32 + time_step = 4 + + cell = Minimal2DRNNCell(unit_a, unit_b) + x = keras.Input((None, input_a, input_b)) + layer = keras.layers.RNN(cell) + y = layer(x) + + self.assertEqual(cell.state_size.as_list(), [unit_a, unit_b]) + init_state = layer.get_initial_state(x) + self.assertEqual(len(init_state), 1) + self.assertEqual(init_state[0].get_shape().as_list(), + [None, unit_a, unit_b]) + + model = keras.models.Model(x, y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch( + np.zeros((batch, time_step, input_a, input_b)), + np.zeros((batch, unit_a, unit_b))) + self.assertEqual(model.output_shape, (None, unit_a, unit_b)) + + # Test stacking. + cells = [ + Minimal2DRNNCell(unit_a, unit_b), + Minimal2DRNNCell(unit_a * 2, unit_b * 2), + Minimal2DRNNCell(unit_a * 4, unit_b * 4) + ] + layer = keras.layers.RNN(cells) + y = layer(x) + model = keras.models.Model(x, y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch( + np.zeros((batch, time_step, input_a, input_b)), + np.zeros((batch, unit_a * 4, unit_b * 4))) + self.assertEqual(model.output_shape, (None, unit_a * 4, unit_b * 4)) + + def test_high_dimension_RNN_with_init_state(self): + unit_a = 10 + unit_b = 20 + input_a = 5 + input_b = 10 + batch = 32 + time_step = 4 + + with self.test_session(): + # Basic test case. + cell = Minimal2DRNNCell(unit_a, unit_b) + x = keras.Input((None, input_a, input_b)) + s = keras.Input((unit_a, unit_b)) + layer = keras.layers.RNN(cell) + y = layer(x, initial_state=s) + + model = keras.models.Model([x, s], y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch([ + np.zeros((batch, time_step, input_a, input_b)), + np.zeros((batch, unit_a, unit_b)) + ], np.zeros((batch, unit_a, unit_b))) + self.assertEqual(model.output_shape, (None, unit_a, unit_b)) + + with self.test_session(): + # Bad init state shape. + bad_shape_a = unit_a * 2 + bad_shape_b = unit_b * 2 + cell = Minimal2DRNNCell(unit_a, unit_b) + x = keras.Input((None, input_a, input_b)) + s = keras.Input((bad_shape_a, bad_shape_b)) + layer = keras.layers.RNN(cell) + with self.assertRaisesWithPredicateMatch(ValueError, + 'however `cell.state_size` is'): + layer(x, initial_state=s) + + def test_inconsistent_output_state_size(self): + with self.test_session(): + batch = 32 + time_step = 4 + state_size = 5 + input_size = 6 + cell = PlusOneRNNCell(state_size) + x = keras.Input((None, input_size)) + layer = keras.layers.RNN(cell) + y = layer(x) + + self.assertEqual(cell.state_size, state_size) + init_state = layer.get_initial_state(x) + self.assertEqual(len(init_state), 1) + self.assertEqual(init_state[0].get_shape().as_list(), + [None, state_size]) + + model = keras.models.Model(x, y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch( + np.zeros((batch, time_step, input_size)), + np.zeros((batch, input_size))) + self.assertEqual(model.output_shape, (None, input_size)) + + +class Minimal2DRNNCell(keras.layers.Layer): + """The minimal 2D RNN cell is a simple combination of 2 1-D RNN cell. + + Both internal state and output have 2 dimensions and are orthogonal + between each other. + """ + + def __init__(self, unit_a, unit_b, **kwargs): + self.unit_a = unit_a + self.unit_b = unit_b + self.state_size = tensor_shape.as_shape([unit_a, unit_b]) + self.output_size = tensor_shape.as_shape([unit_a, unit_b]) + super(Minimal2DRNNCell, self).__init__(**kwargs) + + def build(self, input_shape): + input_a = input_shape[-2] + input_b = input_shape[-1] + self.kernel = self.add_weight( + shape=(input_a, input_b, self.unit_a, self.unit_b), + initializer='uniform', + name='kernel') + self.recurring_kernel = self.add_weight( + shape=(self.unit_a, self.unit_b, self.unit_a, self.unit_b), + initializer='uniform', + name='recurring_kernel') + self.bias = self.add_weight( + shape=(self.unit_a, self.unit_b), initializer='uniform', name='bias') + self.built = True + + def call(self, inputs, states): + prev_output = states[0] + h = special_math_ops.einsum('bij,ijkl->bkl', inputs, self.kernel) + h += array_ops.expand_dims(self.bias, axis=0) + output = h + special_math_ops.einsum('bij,ijkl->bkl', prev_output, + self.recurring_kernel) + return output, [output] + + +class PlusOneRNNCell(keras.layers.Layer): + """Add one to the input and state. + + This cell is used for testing state_size and output_size.""" + + def __init__(self, num_unit, **kwargs): + self.state_size = num_unit + super(PlusOneRNNCell, self).__init__(**kwargs) + + def build(self, input_shape): + self.output_size = input_shape[-1] + + def call(self, inputs, states): + return inputs + 1, [states[0] + 1] + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py index be306c0af765dd79bcc2b7651d97957c1cf80519..7c45e08b5c48084cc57569a4d1102a0a7c5b29e1 100644 --- a/tensorflow/python/keras/layers/serialization.py +++ b/tensorflow/python/keras/layers/serialization.py @@ -20,8 +20,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras.engine import Input -from tensorflow.python.keras.engine import InputLayer +from tensorflow.python.keras.engine.input_layer import Input +from tensorflow.python.keras.engine.input_layer import InputLayer from tensorflow.python.keras.layers.advanced_activations import * from tensorflow.python.keras.layers.convolutional import * from tensorflow.python.keras.layers.convolutional_recurrent import * diff --git a/tensorflow/python/keras/layers/simplernn_test.py b/tensorflow/python/keras/layers/simplernn_test.py index 3d24b0d5045d9c264f32adedaa0e91cdc5cbb0cf..1429537648d6fb01b4ffcd1e7a1c5447ef442cc0 100644 --- a/tensorflow/python/keras/layers/simplernn_test.py +++ b/tensorflow/python/keras/layers/simplernn_test.py @@ -29,7 +29,7 @@ from tensorflow.python.training.rmsprop import RMSPropOptimizer class SimpleRNNLayerTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_return_sequences_SimpleRNN(self): num_samples = 2 timesteps = 3 @@ -41,7 +41,7 @@ class SimpleRNNLayerTest(test.TestCase): 'return_sequences': True}, input_shape=(num_samples, timesteps, embedding_dim)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dynamic_behavior_SimpleRNN(self): num_samples = 2 timesteps = 3 @@ -55,7 +55,7 @@ class SimpleRNNLayerTest(test.TestCase): y = np.random.random((num_samples, units)) model.train_on_batch(x, y) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_dropout_SimpleRNN(self): num_samples = 2 timesteps = 3 @@ -68,7 +68,7 @@ class SimpleRNNLayerTest(test.TestCase): 'recurrent_dropout': 0.1}, input_shape=(num_samples, timesteps, embedding_dim)) - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_implementation_mode_SimpleRNN(self): num_samples = 2 timesteps = 3 @@ -183,6 +183,7 @@ class SimpleRNNLayerTest(test.TestCase): self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint) self.assertEqual(layer.cell.bias.constraint, b_constraint) + @tf_test_util.run_in_graph_and_eager_modes def test_with_masking_layer_SimpleRNN(self): layer_class = keras.layers.SimpleRNN with self.test_session(): @@ -192,7 +193,8 @@ class SimpleRNNLayerTest(test.TestCase): model = keras.models.Sequential() model.add(keras.layers.Masking(input_shape=(3, 4))) model.add(layer_class(units=5, return_sequences=True, unroll=False)) - model.compile(loss='categorical_crossentropy', optimizer='adam') + model.compile(loss='categorical_crossentropy', + optimizer=RMSPropOptimizer(0.01)) model.fit(inputs, targets, epochs=1, batch_size=2, verbose=1) def test_from_config_SimpleRNN(self): diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py index 7759561ef94c4a81552ef7b40ea71e49bbb743ae..9b8d5fc5cc6937e69c0ea53d5719da76a2a52299 100644 --- a/tensorflow/python/keras/layers/wrappers.py +++ b/tensorflow/python/keras/layers/wrappers.py @@ -23,8 +23,8 @@ import copy from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import backend as K -from tensorflow.python.keras.engine import InputSpec -from tensorflow.python.keras.engine import Layer +from tensorflow.python.keras.engine.base_layer import InputSpec +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.layers.recurrent import _standardize_args from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import tf_utils @@ -45,6 +45,7 @@ class Wrapper(Layer): """ def __init__(self, layer, **kwargs): + assert isinstance(layer, Layer) self.layer = layer # Tracks mapping of Wrapper inputs to inner layer inputs. Useful when # the inner layer has update ops that depend on its inputs (as opposed @@ -154,11 +155,51 @@ class TimeDistributed(Wrapper): Arguments: layer: a layer instance. + + Raises: + ValueError: If not initialized with a `Layer` instance. """ def __init__(self, layer, **kwargs): + if not isinstance(layer, Layer): + raise ValueError( + 'Please initialize `TimeDistributed` layer with a ' + '`Layer` instance. You passed: {input}'.format(input=layer)) super(TimeDistributed, self).__init__(layer, **kwargs) self.supports_masking = True + self._track_checkpointable(layer, name='layer') + + def _get_shape_tuple(self, init_tuple, tensor, start_idx, int_shape=None): + """Finds non-specific dimensions in the static shapes. + + The static shapes are replaced with the corresponding dynamic shapes of the + tensor. + + Arguments: + init_tuple: a tuple, the first part of the output shape + tensor: the tensor from which to get the (static and dynamic) shapes + as the last part of the output shape + start_idx: int, which indicate the first dimension to take from + the static shape of the tensor + int_shape: an alternative static shape to take as the last part + of the output shape + Returns: + The new int_shape with the first part from init_tuple + and the last part from either `int_shape` (if provided) + or `tensor.shape`, where every `None` is replaced by + the corresponding dimension from `tf.shape(tensor)`. + """ + # replace all None in int_shape by K.shape + if int_shape is None: + int_shape = K.int_shape(tensor)[start_idx:] + if not any(not s for s in int_shape): + return init_tuple + tuple(int_shape) + shape = K.shape(tensor) + int_shape = list(int_shape) + for i, s in enumerate(int_shape): + if not s: + int_shape[i] = shape[start_idx + i] + return init_tuple + tuple(int_shape) def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape).as_list() @@ -166,7 +207,10 @@ class TimeDistributed(Wrapper): self.input_spec = InputSpec(shape=input_shape) child_input_shape = [input_shape[0]] + input_shape[2:] if not self.layer.built: - self.layer.build(child_input_shape) + # The base layer class calls a conversion function on the input shape to + # convert it to a TensorShape. The conversion function requires a + # tuple which is why we cast the shape. + self.layer.build(tuple(child_input_shape)) self.layer.built = True super(TimeDistributed, self).build() self.built = True @@ -212,18 +256,24 @@ class TimeDistributed(Wrapper): input_length = input_shape[1] if not input_length: input_length = array_ops.shape(inputs)[1] + inner_input_shape = self._get_shape_tuple((-1,), inputs, 2) # Shape: (num_samples * timesteps, ...). And track the # transformation in self._input_map. input_uid = generic_utils.object_list_uid(inputs) - inputs = array_ops.reshape(inputs, (-1,) + input_shape[2:]) + inputs = array_ops.reshape(inputs, inner_input_shape) self._input_map[input_uid] = inputs # (num_samples * timesteps, ...) + if generic_utils.has_arg(self.layer.call, 'mask') and mask is not None: + inner_mask_shape = self._get_shape_tuple((-1,), mask, 2) + kwargs['mask'] = K.reshape(mask, inner_mask_shape) y = self.layer.call(inputs, **kwargs) if hasattr(y, '_uses_learning_phase'): uses_learning_phase = y._uses_learning_phase # Shape: (num_samples, timesteps, ...) output_shape = self.compute_output_shape(input_shape).as_list() - y = array_ops.reshape(y, (-1, input_length) + tuple(output_shape[2:])) + output_shape = self._get_shape_tuple( + (-1, input_length), y, 1, output_shape[2:]) + y = array_ops.reshape(y, output_shape) # Apply activity regularizer if any: if (hasattr(self.layer, 'activity_regularizer') and @@ -235,6 +285,80 @@ class TimeDistributed(Wrapper): y._uses_learning_phase = True return y + def compute_mask(self, inputs, mask=None): + """Computes an output mask tensor for Embedding layer. + + This is based on the inputs, mask, and the inner layer. + If batch size is specified: + Simply return the input `mask`. (An rnn-based implementation with + more than one rnn inputs is required but not supported in tf.keras yet.) + Otherwise we call `compute_mask` of the inner layer at each time step. + If the output mask at each time step is not `None`: + (E.g., inner layer is Masking or RNN) + Concatenate all of them and return the concatenation. + If the output mask at each time step is `None` and the input mask is not + `None`:(E.g., inner layer is Dense) + Reduce the input_mask to 2 dimensions and return it. + Otherwise (both the output mask and the input mask are `None`): + (E.g., `mask` is not used at all) + Return `None`. + + Arguments: + inputs: Tensor with shape [batch size, timesteps, ...] indicating the + input to TimeDistributed. If static shape information is available for + "batch size", `mask` is returned unmodified. + mask: Either None (indicating no masking) or a Tensor indicating the + input mask for TimeDistributed. The shape can be static or dynamic. + + Returns: + Either None (no masking), or a [batch size, timesteps, ...] Tensor with + an output mask for the TimeDistributed layer with the shape beyond the + second dimension being the value of the input mask shape(if the computed + output mask is none), an output mask with the shape beyond the first + dimension being the value of the mask shape(if mask is not None) or + output mask with the shape beyond the first dimension being the + value of the computed output shape. + + """ + # cases need to call the layer.compute_mask when input_mask is None: + # Masking layer and Embedding layer with mask_zero + input_shape = K.int_shape(inputs) + if input_shape[0]: + # batch size matters, we currently do not handle mask explicitly + return mask + inner_mask = mask + if inner_mask is not None: + inner_mask_shape = self._get_shape_tuple((-1,), mask, 2) + inner_mask = K.reshape(inner_mask, inner_mask_shape) + input_uid = generic_utils.object_list_uid(inputs) + inner_inputs = self._input_map.get(input_uid, inputs) + output_mask = self.layer.compute_mask(inner_inputs, inner_mask) + if output_mask is None: + if mask is None: + return None + # input_mask is not None, and output_mask is None: + # we should return a not-None mask + output_mask = mask + for _ in range(2, len(K.int_shape(mask))): + output_mask = K.any(output_mask, axis=-1) + else: + # output_mask is not None. We need to reshape it + input_length = input_shape[1] + if not input_length: + input_length = K.shape(inputs)[1] + output_mask_int_shape = K.int_shape(output_mask) + if output_mask_int_shape is None: + # if the output_mask does not have a static shape, + # its shape must be the same as mask's + if mask is not None: + output_mask_int_shape = K.int_shape(mask) + else: + output_mask_int_shape = K.compute_output_shape(input_shape)[:-1] + output_mask_shape = self._get_shape_tuple( + (-1, input_length), output_mask, 1, output_mask_int_shape[1:]) + output_mask = K.reshape(output_mask, output_mask_shape) + return output_mask + @tf_export('keras.layers.Bidirectional') class Bidirectional(Wrapper): @@ -249,7 +373,8 @@ class Bidirectional(Wrapper): they will be returned as a list. Raises: - ValueError: In case of invalid `merge_mode` argument. + ValueError: If not initialized with a `Layer` instance or + In case of invalid `merge_mode` argument. Examples: @@ -265,6 +390,10 @@ class Bidirectional(Wrapper): """ def __init__(self, layer, merge_mode='concat', weights=None, **kwargs): + if not isinstance(layer, Layer): + raise ValueError( + 'Please initialize `Bidirectional` layer with a ' + '`Layer` instance. You passed: {input}'.format(input=layer)) if merge_mode not in ['sum', 'mul', 'ave', 'concat', None]: raise ValueError('Invalid merge mode. ' 'Merge mode should be one of ' @@ -288,6 +417,8 @@ class Bidirectional(Wrapper): self._num_constants = None super(Bidirectional, self).__init__(layer, **kwargs) self.input_spec = layer.input_spec + self._track_checkpointable(self.forward_layer, name='forward_layer') + self._track_checkpointable(self.backward_layer, name='backward_layer') @property def trainable(self): @@ -397,7 +528,8 @@ class Bidirectional(Wrapper): else: return super(Bidirectional, self).__call__(inputs, **kwargs) - def call(self, inputs, + def call(self, + inputs, training=None, mask=None, initial_state=None, diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py index 5eab6aba8a5f9a7e70f55685a9cd9ae6e0cf024d..0cd774ef0fa70ede62d496db981817b58666bcfc 100644 --- a/tensorflow/python/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/layers/wrappers_test.py @@ -23,8 +23,10 @@ import copy import numpy as np from tensorflow.python import keras +from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import util as checkpointable_util from tensorflow.python.training.rmsprop import RMSPropOptimizer @@ -69,7 +71,7 @@ class _RNNCellWithConstants(keras.layers.Layer): class TimeDistributedTest(test.TestCase): - @tf_test_util.run_in_graph_and_eager_modes() + @tf_test_util.run_in_graph_and_eager_modes def test_timedistributed_dense(self): model = keras.models.Sequential() model.add( @@ -85,6 +87,12 @@ class TimeDistributedTest(test.TestCase): # test config model.get_config() + # check whether the model variables are present in the + # checkpointable list of objects + checkpointed_objects = set(checkpointable_util.list_objects(model)) + for v in model.variables: + self.assertIn(v, checkpointed_objects) + def test_timedistributed_static_batch_size(self): model = keras.models.Sequential() model.add( @@ -97,6 +105,13 @@ class TimeDistributedTest(test.TestCase): epochs=1, batch_size=10) + def test_timedistributed_invalid_init(self): + x = constant_op.constant(np.zeros((1, 1)).astype('float32')) + with self.assertRaisesRegexp( + ValueError, + 'Please initialize `TimeDistributed` layer with a `Layer` instance.'): + keras.layers.TimeDistributed(x) + def test_timedistributed_conv2d(self): with self.test_session(): model = keras.models.Sequential() @@ -177,8 +192,8 @@ class TimeDistributedTest(test.TestCase): x = keras.layers.Input(shape=(3, 2)) layer = keras.layers.TimeDistributed(keras.layers.BatchNormalization()) _ = layer(x) - assert len(layer.updates) == 2 - assert len(layer.trainable_weights) == 2 + self.assertEquals(len(layer.updates), 2) + self.assertEquals(len(layer.trainable_weights), 2) layer.trainable = False assert not layer.updates assert not layer.trainable_weights @@ -186,6 +201,62 @@ class TimeDistributedTest(test.TestCase): assert len(layer.updates) == 2 assert len(layer.trainable_weights) == 2 + def test_TimeDistributed_with_masked_embedding_and_unspecified_shape(self): + with self.test_session(): + # test with unspecified shape and Embeddings with mask_zero + model = keras.models.Sequential() + model.add(keras.layers.TimeDistributed( + keras.layers.Embedding(5, 6, mask_zero=True), + input_shape=(None, None))) # N by t_1 by t_2 by 6 + model.add(keras.layers.TimeDistributed( + keras.layers.SimpleRNN(7, return_sequences=True))) + model.add(keras.layers.TimeDistributed( + keras.layers.SimpleRNN(8, return_sequences=False))) + model.add(keras.layers.SimpleRNN(1, return_sequences=False)) + model.compile(optimizer='rmsprop', loss='mse') + model_input = np.random.randint(low=1, high=5, size=(10, 3, 4), + dtype='int32') + for i in range(4): + model_input[i, i:, i:] = 0 + model.fit(model_input, + np.random.random((10, 1)), epochs=1, batch_size=10) + mask_outputs = [model.layers[0].compute_mask(model.input)] + for layer in model.layers[1:]: + mask_outputs.append(layer.compute_mask(layer.input, mask_outputs[-1])) + func = keras.backend.function([model.input], mask_outputs[:-1]) + mask_outputs_val = func([model_input]) + ref_mask_val_0 = model_input > 0 # embedding layer + ref_mask_val_1 = ref_mask_val_0 # first RNN layer + ref_mask_val_2 = np.any(ref_mask_val_1, axis=-1) # second RNN layer + ref_mask_val = [ref_mask_val_0, ref_mask_val_1, ref_mask_val_2] + for i in range(3): + self.assertAllEqual(mask_outputs_val[i], ref_mask_val[i]) + self.assertIs(mask_outputs[-1], None) # final layer + + def test_TimeDistributed_with_masking_layer(self): + with self.test_session(): + # test with Masking layer + model = keras.models.Sequential() + model.add(keras.layers.TimeDistributed(keras.layers.Masking( + mask_value=0.,), input_shape=(None, 4))) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(5))) + model.compile(optimizer='rmsprop', loss='mse') + model_input = np.random.randint(low=1, high=5, size=(10, 3, 4)) + for i in range(4): + model_input[i, i:, :] = 0. + model.compile(optimizer='rmsprop', loss='mse') + model.fit(model_input, + np.random.random((10, 3, 5)), epochs=1, batch_size=6) + mask_outputs = [model.layers[0].compute_mask(model.input)] + mask_outputs += [model.layers[1].compute_mask(model.layers[1].input, + mask_outputs[-1])] + func = keras.backend.function([model.input], mask_outputs) + mask_outputs_val = func([model_input]) + self.assertEqual((mask_outputs_val[0]).all(), + model_input.all()) + self.assertEqual((mask_outputs_val[1]).all(), + model_input.all()) + class BidirectionalTest(test.TestCase): @@ -209,6 +280,12 @@ class BidirectionalTest(test.TestCase): model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse') model.fit(x, y, epochs=1, batch_size=1) + # check whether the model variables are present in the + # checkpointable list of objects + checkpointed_objects = set(checkpointable_util.list_objects(model)) + for v in model.variables: + self.assertIn(v, checkpointed_objects) + # test compute output shape ref_shape = model.layers[-1].output.get_shape() shape = model.layers[-1].compute_output_shape( @@ -220,6 +297,13 @@ class BidirectionalTest(test.TestCase): model = keras.models.model_from_json(model.to_json()) model.summary() + def test_bidirectional_invalid_init(self): + x = constant_op.constant(np.zeros((1, 1)).astype('float32')) + with self.assertRaisesRegexp( + ValueError, + 'Please initialize `Bidirectional` layer with a `Layer` instance.'): + keras.layers.Bidirectional(x) + def test_bidirectional_weight_loading(self): rnn = keras.layers.SimpleRNN samples = 2 @@ -424,6 +508,42 @@ class BidirectionalTest(test.TestCase): layer.trainable = True assert len(layer.trainable_weights) == 6 + def test_Bidirectional_updates(self): + with self.test_session(): + x = keras.layers.Input(shape=(3, 2)) + x_reachable_update = x * x + layer = keras.layers.Bidirectional(keras.layers.SimpleRNN(3)) + _ = layer(x) + assert not layer.updates + assert not layer.get_updates_for(None) + assert not layer.get_updates_for(x) + layer.forward_layer.add_update(x_reachable_update, inputs=x) + layer.forward_layer.add_update(1, inputs=None) + layer.backward_layer.add_update(x_reachable_update, inputs=x) + layer.backward_layer.add_update(1, inputs=None) + assert len(layer.updates) == 4 + assert len(layer.get_updates_for(None)) == 2 + assert len(layer.get_updates_for(x)) == 2 + + def test_Bidirectional_losses(self): + with self.test_session(): + x = keras.layers.Input(shape=(3, 2)) + x_reachable_loss = x * x + layer = keras.layers.Bidirectional( + keras.layers.SimpleRNN( + 3, kernel_regularizer='l1', bias_regularizer='l1')) + _ = layer(x) + assert len(layer.losses) == 4 + assert len(layer.get_losses_for(None)) == 4 + assert not layer.get_losses_for(x) + layer.forward_layer.add_loss(x_reachable_loss, inputs=x) + layer.forward_layer.add_loss(1, inputs=None) + layer.backward_layer.add_loss(x_reachable_loss, inputs=x) + layer.backward_layer.add_loss(1, inputs=None) + assert len(layer.losses) == 8 + assert len(layer.get_losses_for(None)) == 6 + assert len(layer.get_losses_for(x)) == 2 + def test_Bidirectional_with_constants(self): with self.test_session(): # Test basic case. diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index e03d7dfe93585efd06f4701a8d20f61fc314d564..9b87170ebe5bead0ea07f7b0ae2062cbb515812b 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -19,9 +19,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from abc import ABCMeta +from abc import abstractmethod + +import types import six +from tensorflow.python.eager import context +from tensorflow.python.eager import function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K +from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.losses import binary_crossentropy from tensorflow.python.keras.losses import categorical_crossentropy from tensorflow.python.keras.losses import cosine_proximity @@ -37,14 +46,531 @@ from tensorflow.python.keras.losses import sparse_categorical_crossentropy from tensorflow.python.keras.losses import squared_hinge from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import confusion_matrix +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import weights_broadcast_ops +from tensorflow.python.training import distribution_strategy_context +from tensorflow.python.util import tf_decorator from tensorflow.python.util.tf_export import tf_export +def check_is_tensor_or_operation(x, name): + """Raises type error if the given input is not a tensor or operation.""" + if not (isinstance(x, ops.Tensor) or isinstance(x, ops.Operation)): + raise TypeError('{0} must be a Tensor or Operation, given: {1}'.format( + name, x)) + + +def update_state_wrapper(update_state_fn): + """Decorator to wrap metric `update_state()` with `add_update()`. + + Args: + update_state_fn: function that accumulates metric statistics. + + Returns: + Decorated function that wraps `update_state_fn()` with `add_update()`. + """ + + def decorated(metric_obj, *args, **kwargs): + """Decorated function with `add_update()`.""" + + update_op = update_state_fn(*args, **kwargs) + if update_op is not None: # update_op will be None in eager execution. + metric_obj.add_update(update_op, inputs=True) + check_is_tensor_or_operation( + update_op, 'Metric {0}\'s update'.format(metric_obj.name)) + return update_op + + return tf_decorator.make_decorator(update_state_fn, decorated) + + +def result_wrapper(result_fn): + """Decorator to wrap metric `result()` function in `merge_call()`. + + Result computation is an idempotent operation that simply calculates the + metric value using the state variables. + + If metric state variables are distributed across towers/devices and + `result()` is requested from the context of one device - This function wraps + `result()` in a distribution strategy `merge_call()`. With this, + the metric state variables will be aggregated across devices. + + Args: + result_fn: function that computes the metric result. + + Returns: + Decorated function that wraps `result_fn()` in distribution strategy + `merge_call()`. + """ + + def decorated(metric_obj, *args): + """Decorated function with merge_call.""" + tower_context = distribution_strategy_context.get_tower_context() + if tower_context is None: # if in cross tower context already + result_t = result_fn(*args) + else: + # TODO(psv): Test distribution of metrics using different distribution + # strategies. + + # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn + # with distribution object as the first parameter. We create a wrapper + # here so that the result function need not have that parameter. + def merge_fn_wrapper(distribution, merge_fn, *args): + # We will get `PerDevice` merge function. Taking the first one as all + # are identical copies of the function that we had passed below. + return distribution.unwrap(merge_fn)[0](*args) + + # Wrapping result in merge_call. merge_call is used when we want to leave + # tower mode and compute a value in cross tower mode. + result_t = tower_context.merge_call(merge_fn_wrapper, result_fn, *args) + check_is_tensor_or_operation(result_t, + 'Metric {0}\'s result'.format(metric_obj.name)) + return result_t + + return tf_decorator.make_decorator(result_fn, decorated) + + +def safe_div(numerator, denominator): + """Divides two tensors element-wise, returning 0 if the denominator is <= 0. + + Args: + numerator: A `Tensor`. + denominator: A `Tensor`, with dtype matching `numerator`. + + Returns: + 0 if `denominator` <= 0, else `numerator` / `denominator` + """ + t = math_ops.truediv(numerator, denominator) + zero = array_ops.zeros_like(t, dtype=denominator.dtype) + condition = math_ops.greater(denominator, zero) + zero = math_ops.cast(zero, t.dtype) + return array_ops.where(condition, t, zero) + + +def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight): + """Squeeze or expand last dimension if needed. + + 1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1 + (using `confusion_matrix.remove_squeezable_dimensions`). + 2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1 + from the new rank of `y_pred`. + If `sample_weight` is scalar, it is kept scalar. + + This will use static shape if available. Otherwise, it will add graph + operations, which could result in a performance hit. + + Args: + y_pred: Predicted values, a `Tensor` of arbitrary dimensions. + y_true: Optional label `Tensor` whose dimensions match `y_pred`. + sample_weight: Optional weight scalar or `Tensor` whose dimensions match + `y_pred`. + + Returns: + Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has + the last dimension squeezed, + `sample_weight` could be extended by one dimension. + """ + if y_true is not None: + # squeeze last dim of `y_pred` or `y_true` if their rank differs by 1 + y_true, y_pred = confusion_matrix.remove_squeezable_dimensions( + y_true, y_pred) + y_pred.get_shape().assert_is_compatible_with(y_true.get_shape()) + + if sample_weight is None: + return y_pred, y_true, None + + sample_weight = ops.convert_to_tensor(sample_weight) + weights_shape = sample_weight.get_shape() + weights_rank = weights_shape.ndims + if weights_rank == 0: # If weights is scalar, do nothing. + return y_pred, y_true, sample_weight + + y_pred_shape = y_pred.get_shape() + y_pred_rank = y_pred_shape.ndims + if (y_pred_rank is not None) and (weights_rank is not None): + # Use static rank. + if weights_rank - y_pred_rank == 1: + sample_weight = array_ops.squeeze(sample_weight, [-1]) + elif y_pred_rank - weights_rank == 1: + sample_weight = array_ops.expand_dims(sample_weight, [-1]) + return y_pred, y_true, sample_weight + + # Use dynamic rank. + weights_rank_tensor = array_ops.rank(sample_weight) + rank_diff = weights_rank_tensor - array_ops.rank(y_pred) + maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1]) + + def _maybe_expand_weights(): + return control_flow_ops.cond( + math_ops.equal(rank_diff, + -1), lambda: array_ops.expand_dims(sample_weight, [-1]), + lambda: sample_weight) + + def _maybe_adjust_weights(): + return control_flow_ops.cond( + math_ops.equal(rank_diff, 1), maybe_squeeze_weights, + _maybe_expand_weights) + + # squeeze or expand last dim of `sample_weight` if its rank differs by 1 + # from the new rank of `y_pred`. + sample_weight = control_flow_ops.cond( + math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight, + _maybe_adjust_weights) + return y_pred, y_true, sample_weight + + +class Metric(Layer): + """Encapsulates metric logic and state. + + Usage with eager execution: + + ```python + m = SomeMetric(...) + for input in ...: + m.update_state(input) + print('Final result: ', m.result().numpy()) + ``` + + Usage with graph execution: + + ```python + m = SomeMetric(...) + init_op = tf.global_variables_initializer() # Initialize variables + with tf.Session() as sess: + sess.run(init_op) + for input in ...: + update_op = m.update_state(input) + sess.run(update_op) + print('Final result: ', sess.run(m.result())) + ``` + + Usage with tf.keras API: + + ```python + model = tf.keras.Sequential() + model.add(tf.keras.layers.Dense(64, activation='relu')) + model.add(tf.keras.layers.Dense(64, activation='relu')) + model.add(tf.keras.layers.Dense(10, activation='softmax')) + + model.compile(optimizer=tf.train.RMSPropOptimizer(0.01), + loss=tf.keras.losses.categorical_crossentropy, + metrics=[tf.keras.metrics.CategoricalAccuracy()]) + + data = np.random.random((1000, 32)) + labels = np.random.random((1000, 10)) + + dataset = tf.data.Dataset.from_tensor_slices((data, labels)) + dataset = dataset.batch(32) + dataset = dataset.repeat() + + model.fit(dataset, epochs=10, steps_per_epoch=30) + ``` + + To be implemented by subclasses: + * `__init__()`: All state variables should be created in this method by + calling `self.add_weight()` like: `self.var = self.add_weight(...)` + * `update_state()`: Has all updates to the state variables like: + self.var.assign_add(...). + * `result()`: Computes and returns a value for the metric + from the state variables. + + Example subclass implementation: + + ``` + class BinaryTruePositives(Metric): + def __init__(self, name='binary_true_positives', dtype=None): + super(BinaryTruePositives, self).__init__(name=name, dtype=dtype) + self.true_positives = self.add_weight( + 'true_positives', initializer=init_ops.zeros_initializer) + + def update_state(self, y_true, y_pred, sample_weight=None): + y_true = math_ops.cast(y_true, dtypes.bool) + y_pred = math_ops.cast(y_pred, dtypes.bool) + y_pred, y_true, sample_weight = squeeze_or_expand_dimensions( + y_pred, y_true, sample_weight) + + values = math_ops.logical_and( + math_ops.equal(y_true, True), math_ops.equal(y_pred, True)) + values = math_ops.cast(values, self._dtype) + if sample_weight is not None: + sample_weight = math_ops.cast(sample_weight, self._dtype) + values = math_ops.multiply(values, sample_weight) + state_ops.assign_add(self.true_positives, math_ops.reduce_sum(values)) + + def result(self): + return array_ops.identity(self.true_positives) + ``` + """ + __metaclass__ = ABCMeta + + def __init__(self, name=None, dtype=None): + super(Metric, self).__init__(name=name, dtype=dtype) + self.stateful = True # All metric layers are stateful. + self.built = True + self._dtype = K.floatx() if dtype is None else dtypes.as_dtype(dtype).name + + def __new__(cls, *args, **kwargs): + obj = super(Metric, cls).__new__(cls) + # TODO(psv): Fix reference cycle issue here. + + # Converting update_state_fn() into a graph function, so that + # we can return a single op that performs all of the variable updates. + defuned_update_state_fn = function.defun(obj.update_state) + obj.update_state = types.MethodType( + update_state_wrapper(defuned_update_state_fn), obj) + obj.result = types.MethodType(result_wrapper(obj.result), obj) + return obj + + def __call__(self, *args, **kwargs): + """Accumulates statistics and then computes metric result value. + + Args: + *args: + **kwargs: A mini-batch of inputs to the Metric, + passed on to `update_state()`. + + Returns: + The metric value tensor. + """ + update_op = self.update_state(*args, **kwargs) # pylint: disable=not-callable + with ops.control_dependencies([update_op]): + return self.result() # pylint: disable=not-callable + + def reset_states(self): + """Resets all of the metric state variables. + + This function is called between epochs/steps, + when a metric is evaluated during training. + """ + for v in self.variables: + K.set_value(v, 0) + + @abstractmethod + def update_state(self, *args, **kwargs): + """Accumulates statistics for the metric. + + Note: This function is executed as a graph function in graph mode. + This means: + a) Operations on the same resource are executed in textual order. + This should make it easier to do things like add the updated + value of a variable to another, for example. + b) You don't need to worry about collecting the update ops to execute. + All update ops added to the graph by this function will be executed. + As a result, code should generally work the same way with graph or + eager execution. + and adds the update op to the metric layer. + + Args: + *args: + **kwargs: A mini-batch of inputs to the Metric. + """ + NotImplementedError('Must be implemented in subclasses.') + + @abstractmethod + def result(self): + """Computes and returns the metric value tensor. + + Result computation is an idempotent operation that simply calculates the + metric value using the state variables. + """ + NotImplementedError('Must be implemented in subclasses.') + + @classmethod + def from_config(cls, config): + if 'trainable' in config: + config.pop('trainable') + return cls(**config) + + ### For use by subclasses ### + def add_weight(self, + name, + shape=(), + aggregation=vs.VariableAggregation.SUM, + synchronization=vs.VariableSynchronization.ON_READ, + initializer=None): + """Adds state variable. Only for use by subclasses.""" + return super(Metric, self).add_weight( + name=name, + shape=shape, + dtype=self._dtype, + trainable=False, + initializer=initializer, + synchronization=synchronization, + aggregation=aggregation) + + ### End: For use by subclasses ### + + +class Mean(Metric): + """Computes the (weighted) mean of the given values. + + This metric creates two variables, `total` and `count` that are used to + compute the average of `values`. This average is ultimately returned as `mean` + which is an idempotent operation that simply divides `total` by `count`. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + """ + + def __init__(self, name='mean', dtype=None): + """Creates a `Mean` instance. + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + """ + super(Mean, self).__init__(name=name, dtype=dtype) + # Create new state variables + self.total = self.add_weight( + 'total', initializer=init_ops.zeros_initializer) + self.count = self.add_weight( + 'count', initializer=init_ops.zeros_initializer) + + def update_state(self, values, sample_weight=None): + """Accumulates statistics for computing the mean. + + For example, if `values` is [1, 3, 5, 7] then the mean is 4. If + the `sample_weight` is specified as [1, 1, 0, 0] then the mean would be 2. + + Args: + values: Per-example value. + sample_weight: Optional weighting of each example. Defaults to 1. + """ + values = math_ops.cast(values, self._dtype) + if sample_weight is None: + num_values = math_ops.cast(array_ops.size(values), self._dtype) + else: + sample_weight = math_ops.cast(sample_weight, self._dtype) + + # Update dimensions of weights to match with values if possible. + values, _, sample_weight = squeeze_or_expand_dimensions( + values, None, sample_weight) + try: + # Broadcast weights if possible. + sample_weight = weights_broadcast_ops.broadcast_weights( + sample_weight, values) + except ValueError: + # Reduce values to same ndim as weight array + ndim = K.ndim(values) + weight_ndim = K.ndim(sample_weight) + values = math_ops.reduce_mean( + values, axis=list(range(weight_ndim, ndim))) + + num_values = math_ops.reduce_sum(sample_weight) + values = math_ops.multiply(values, sample_weight) + values = math_ops.reduce_sum(values) + + # Update state variables + state_ops.assign_add(self.total, values) + state_ops.assign_add(self.count, num_values) + + def result(self): + return safe_div(self.total, self.count) + + +class MeanMetricWrapper(Mean): + """Wraps a stateless metric function with the Mean metric.""" + + def __init__(self, fn, name=None, dtype=None, **kwargs): + """Creates a `MeanMetricWrapper` instance. + + Args: + fn: The metric function to wrap, with signature + `fn(y_true, y_pred, **kwargs)`. + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + **kwargs: The keyword arguments that are passed on to `fn`. + """ + super(MeanMetricWrapper, self).__init__(name=name, dtype=dtype) + self._fn = fn + self._fn_kwargs = kwargs + + def update_state(self, y_true, y_pred, sample_weight=None): + """Accumulates metric statistics. + + `y_true` and `y_pred` should have the same shape. + + Args: + y_true: The ground truth values. + y_pred: The predicted values. + sample_weight: Optional weighting of each example. Defaults to 1. Can be + a `Tensor` whose rank is either 0, or the same rank as `y_true`, + and must be broadcastable to `y_true`. + """ + y_true = math_ops.cast(y_true, self._dtype) + y_pred = math_ops.cast(y_pred, self._dtype) + y_pred, y_true, sample_weight = squeeze_or_expand_dimensions( + y_pred, y_true, sample_weight) + + matches = self._fn(y_true, y_pred, **self._fn_kwargs) + super(MeanMetricWrapper, self).update_state( + matches, sample_weight=sample_weight) + + def get_config(self): + config = self._fn_kwargs + base_config = super(MeanMetricWrapper, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +class BinaryAccuracy(MeanMetricWrapper): + """Calculates how often predictions matches labels. + + This metric creates two local variables, `total` and `count` that are used to + compute the frequency with which `y_pred` matches `y_true`. This frequency is + ultimately returned as `binary accuracy`: an idempotent operation that simply + divides `total` by `count`. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + """ + + def __init__(self, name='binary_accuracy', dtype=None, threshold=0.5): + """Creates a `BinaryAccuracy` instance. + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + threshold: (Optional) Float representing the threshold for deciding + whether prediction values are 1 or 0. + """ + super(BinaryAccuracy, self).__init__( + binary_accuracy, name, dtype=dtype, threshold=threshold) + + +class CategoricalAccuracy(MeanMetricWrapper): + """Calculates how often predictions matches labels. + + This metric creates two local variables, `total` and `count` that are used to + compute the frequency with which `y_pred` matches `y_true`. This frequency is + ultimately returned as `categorical accuracy`: an idempotent operation that + simply divides `total` by `count`. + + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. + """ + + def __init__(self, name='categorical_accuracy', dtype=None): + """Creates a `CategoricalAccuracy` instance. + + Args: + name: (Optional) string name of the metric instance. + dtype: (Optional) data type of the metric result. + """ + super(CategoricalAccuracy, self).__init__( + categorical_accuracy, name, dtype=dtype) + + @tf_export('keras.metrics.binary_accuracy') -def binary_accuracy(y_true, y_pred): - return K.mean(math_ops.equal(y_true, math_ops.round(y_pred)), axis=-1) +def binary_accuracy(y_true, y_pred, threshold=0.5): + threshold = math_ops.cast(threshold, y_pred.dtype) + y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype) + return K.mean(math_ops.equal(y_true, y_pred), axis=-1) @tf_export('keras.metrics.categorical_accuracy') @@ -103,8 +629,7 @@ def deserialize(config, custom_objects=None): @tf_export('keras.metrics.get') def get(identifier): if isinstance(identifier, dict): - config = {'class_name': str(identifier), 'config': {}} - return deserialize(config) + return deserialize(identifier) elif isinstance(identifier, six.string_types): return deserialize(str(identifier)) elif callable(identifier): diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py index 15e793f5fcf0b416978095da370fbdaabd1490a6..2ac74219d4035420afcf3421dd1d2ca803a0f576 100644 --- a/tensorflow/python/keras/metrics_test.py +++ b/tensorflow/python/keras/metrics_test.py @@ -18,67 +18,72 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np -from tensorflow.python import keras +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import layers +from tensorflow.python.keras import metrics +from tensorflow.python.keras.engine.training import Model +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import util as checkpointable_utils class KerasMetricsTest(test.TestCase): def test_metrics(self): with self.test_session(): - y_a = keras.backend.variable(np.random.random((6, 7))) - y_b = keras.backend.variable(np.random.random((6, 7))) - for metric in [keras.metrics.binary_accuracy, - keras.metrics.categorical_accuracy]: + y_a = K.variable(np.random.random((6, 7))) + y_b = K.variable(np.random.random((6, 7))) + for metric in [metrics.binary_accuracy, metrics.categorical_accuracy]: output = metric(y_a, y_b) - self.assertEqual(keras.backend.eval(output).shape, (6,)) + self.assertEqual(K.eval(output).shape, (6,)) def test_sparse_categorical_accuracy(self): with self.test_session(): - metric = keras.metrics.sparse_categorical_accuracy - y_a = keras.backend.variable(np.random.randint(0, 7, (6,))) - y_b = keras.backend.variable(np.random.random((6, 7))) - self.assertEqual(keras.backend.eval(metric(y_a, y_b)).shape, (6,)) + metric = metrics.sparse_categorical_accuracy + y_a = K.variable(np.random.randint(0, 7, (6,))) + y_b = K.variable(np.random.random((6, 7))) + self.assertEqual(K.eval(metric(y_a, y_b)).shape, (6,)) def test_sparse_top_k_categorical_accuracy(self): with self.test_session(): - y_pred = keras.backend.variable(np.array([[0.3, 0.2, 0.1], - [0.1, 0.2, 0.7]])) - y_true = keras.backend.variable(np.array([[1], [0]])) - result = keras.backend.eval( - keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3)) + y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]])) + y_true = K.variable(np.array([[1], [0]])) + result = K.eval( + metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3)) self.assertEqual(result, 1) - result = keras.backend.eval( - keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2)) + result = K.eval( + metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2)) self.assertEqual(result, 0.5) - result = keras.backend.eval( - keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1)) + result = K.eval( + metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1)) self.assertEqual(result, 0.) def test_top_k_categorical_accuracy(self): with self.test_session(): - y_pred = keras.backend.variable(np.array([[0.3, 0.2, 0.1], - [0.1, 0.2, 0.7]])) - y_true = keras.backend.variable(np.array([[0, 1, 0], [1, 0, 0]])) - result = keras.backend.eval( - keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=3)) + y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]])) + y_true = K.variable(np.array([[0, 1, 0], [1, 0, 0]])) + result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=3)) self.assertEqual(result, 1) - result = keras.backend.eval( - keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=2)) + result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=2)) self.assertEqual(result, 0.5) - result = keras.backend.eval( - keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=1)) + result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=1)) self.assertEqual(result, 0.) def test_stateful_metrics(self): with self.test_session(): np.random.seed(1334) - class BinaryTruePositives(keras.layers.Layer): + class BinaryTruePositives(layers.Layer): """Stateful Metric to count the total true positives over all batches. Assumes predictions and targets of shape `(samples, 1)`. @@ -91,11 +96,11 @@ class KerasMetricsTest(test.TestCase): def __init__(self, name='true_positives', **kwargs): super(BinaryTruePositives, self).__init__(name=name, **kwargs) - self.true_positives = keras.backend.variable(value=0, dtype='int32') + self.true_positives = K.variable(value=0, dtype='int32') self.stateful = True def reset_states(self): - keras.backend.set_value(self.true_positives, 0) + K.set_value(self.true_positives, 0) def __call__(self, y_true, y_pred): """Computes the number of true positives in a batch. @@ -120,14 +125,14 @@ class KerasMetricsTest(test.TestCase): return current_true_pos + true_pos metric_fn = BinaryTruePositives() - config = keras.metrics.serialize(metric_fn) - metric_fn = keras.metrics.deserialize( + config = metrics.serialize(metric_fn) + metric_fn = metrics.deserialize( config, custom_objects={'BinaryTruePositives': BinaryTruePositives}) # Test on simple model - inputs = keras.Input(shape=(2,)) - outputs = keras.layers.Dense(1, activation='sigmoid')(inputs) - model = keras.Model(inputs, outputs) + inputs = layers.Input(shape=(2,)) + outputs = layers.Dense(1, activation='sigmoid')(inputs) + model = Model(inputs, outputs) model.compile(optimizer='sgd', loss='binary_crossentropy', metrics=['acc', metric_fn]) @@ -184,6 +189,245 @@ class KerasMetricsTest(test.TestCase): self.assertAllClose( val_outs[2], history.history['val_true_positives'][-1], atol=1e-5) + @test_util.run_in_graph_and_eager_modes + def test_mean(self): + m = metrics.Mean(name='my_mean') + + # check config + self.assertEqual(m.name, 'my_mean') + self.assertTrue(m.stateful) + self.assertEqual(m.dtype, dtypes.float32) + self.assertEqual(len(m.variables), 2) + self.evaluate(variables.global_variables_initializer()) + + # check initial state + self.assertEqual(self.evaluate(m.total), 0) + self.assertEqual(self.evaluate(m.count), 0) + + # check __call__() + self.assertEqual(self.evaluate(m(100)), 100) + self.assertEqual(self.evaluate(m.total), 100) + self.assertEqual(self.evaluate(m.count), 1) + + # check update_state() and result() + state accumulation + tensor input + update_op = m.update_state(ops.convert_n_to_tensor([1, 5])) + self.evaluate(update_op) + self.assertAlmostEqual(self.evaluate(m.result()), 106 / 3, 2) + self.assertEqual(self.evaluate(m.total), 106) # 100 + 1 + 5 + self.assertEqual(self.evaluate(m.count), 3) + + # check reset_states() + m.reset_states() + self.assertEqual(self.evaluate(m.total), 0) + self.assertEqual(self.evaluate(m.count), 0) + + @test_util.run_in_graph_and_eager_modes + def test_mean_with_sample_weight(self): + m = metrics.Mean(dtype=dtypes.float64) + self.assertEqual(m.dtype, dtypes.float64) + self.evaluate(variables.global_variables_initializer()) + + # check scalar weight + result_t = m(100, sample_weight=0.5) + self.assertEqual(self.evaluate(result_t), 50 / 0.5) + self.assertEqual(self.evaluate(m.total), 50) + self.assertEqual(self.evaluate(m.count), 0.5) + + # check weights not scalar and weights rank matches values rank + result_t = m([1, 5], sample_weight=[1, 0.2]) + result = self.evaluate(result_t) + self.assertAlmostEqual(result, 52 / 1.7, 2) + self.assertAlmostEqual(self.evaluate(m.total), 52, 2) # 50 + 1 + 5 * 0.2 + self.assertAlmostEqual(self.evaluate(m.count), 1.7, 2) # 0.5 + 1.2 + + # check weights broadcast + result_t = m([1, 2], sample_weight=0.5) + self.assertAlmostEqual(self.evaluate(result_t), 53.5 / 2.7, 2) + self.assertAlmostEqual(self.evaluate(m.total), 53.5, 2) # 52 + 0.5 + 1 + self.assertAlmostEqual(self.evaluate(m.count), 2.7, 2) # 1.7 + 0.5 + 0.5 + + # check weights squeeze + result_t = m([1, 5], sample_weight=[[1], [0.2]]) + self.assertAlmostEqual(self.evaluate(result_t), 55.5 / 3.9, 2) + self.assertAlmostEqual(self.evaluate(m.total), 55.5, 2) # 53.5 + 1 + 1 + self.assertAlmostEqual(self.evaluate(m.count), 3.9, 2) # 2.7 + 1.2 + + # check weights expand + result_t = m([[1], [5]], sample_weight=[1, 0.2]) + self.assertAlmostEqual(self.evaluate(result_t), 57.5 / 5.1, 2) + self.assertAlmostEqual(self.evaluate(m.total), 57.5, 2) # 55.5 + 1 + 1 + self.assertAlmostEqual(self.evaluate(m.count), 5.1, 2) # 3.9 + 1.2 + + # check values reduced to the dimensions of weight + result_t = m([[[1., 2.], [3., 2.], [0.5, 4.]]], sample_weight=[0.5]) + result = np.round(self.evaluate(result_t), decimals=2) # 58.5 / 5.6 + self.assertEqual(result, 10.45) + self.assertEqual(np.round(self.evaluate(m.total), decimals=2), 58.54) + self.assertEqual(np.round(self.evaluate(m.count), decimals=2), 5.6) + + def test_mean_graph_with_placeholder(self): + with context.graph_mode(), self.test_session() as sess: + m = metrics.Mean() + v = array_ops.placeholder(dtypes.float32) + w = array_ops.placeholder(dtypes.float32) + sess.run(variables.global_variables_initializer()) + + # check __call__() + result_t = m(v, sample_weight=w) + result = sess.run(result_t, feed_dict=({v: 100, w: 0.5})) + self.assertEqual(sess.run(m.total), 50) + self.assertEqual(sess.run(m.count), 0.5) + self.assertEqual(result, 50 / 0.5) + + # check update_state() and result() + result = sess.run(result_t, feed_dict=({v: [1, 5], w: [1, 0.2]})) + self.assertAlmostEqual(sess.run(m.total), 52, 2) # 50 + 1 + 5 * 0.2 + self.assertAlmostEqual(sess.run(m.count), 1.7, 2) # 0.5 + 1.2 + self.assertAlmostEqual(result, 52 / 1.7, 2) + + @test_util.run_in_graph_and_eager_modes + def test_save_restore(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + m = metrics.Mean() + checkpoint = checkpointable_utils.Checkpoint(mean=m) + self.evaluate(variables.global_variables_initializer()) + + # update state + self.evaluate(m(100.)) + self.evaluate(m(200.)) + + # save checkpoint and then add an update + save_path = checkpoint.save(checkpoint_prefix) + self.evaluate(m(1000.)) + + # restore to the same checkpoint mean object + checkpoint.restore(save_path).assert_consumed().run_restore_ops() + self.evaluate(m(300.)) + self.assertEqual(200., self.evaluate(m.result())) + + # restore to a different checkpoint mean object + restore_mean = metrics.Mean() + restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean) + status = restore_checkpoint.restore(save_path) + restore_update = restore_mean(300.) + status.assert_consumed().run_restore_ops() + self.evaluate(restore_update) + self.assertEqual(200., self.evaluate(restore_mean.result())) + self.assertEqual(3, self.evaluate(restore_mean.count)) + + @test_util.run_in_graph_and_eager_modes + def test_binary_accuracy(self): + acc_obj = metrics.BinaryAccuracy(name='my acc') + + # check config + self.assertEqual(acc_obj.name, 'my acc') + self.assertTrue(acc_obj.stateful) + self.assertEqual(len(acc_obj.variables), 2) + self.assertEqual(acc_obj.dtype, dtypes.float32) + self.evaluate(variables.global_variables_initializer()) + + # verify that correct value is returned + update_op = acc_obj.update_state([[1], [0]], [[1], [0]]) + self.evaluate(update_op) + result = self.evaluate(acc_obj.result()) + self.assertEqual(result, 1) # 2/2 + + # check y_pred squeeze + update_op = acc_obj.update_state([[1], [1]], [[[1]], [[0]]]) + self.evaluate(update_op) + result = self.evaluate(acc_obj.result()) + self.assertAlmostEqual(result, 0.75, 2) # 3/4 + + # check y_true squeeze + result_t = acc_obj([[[1]], [[1]]], [[1], [0]]) + result = self.evaluate(result_t) + self.assertAlmostEqual(result, 0.67, 2) # 4/6 + + # check with sample_weight + result_t = acc_obj([[1], [1]], [[1], [0]], [[0.5], [0.2]]) + result = self.evaluate(result_t) + self.assertAlmostEqual(result, 0.67, 2) # 4.5/6.7 + + # check incompatible shapes + with self.assertRaisesRegexp(ValueError, + r'Shapes \(1,\) and \(2,\) are incompatible'): + acc_obj.update_state([1, 1], [1]) + + @test_util.run_in_graph_and_eager_modes + def test_binary_accuracy_threshold(self): + acc_obj = metrics.BinaryAccuracy(threshold=0.7) + self.evaluate(variables.global_variables_initializer()) + result_t = acc_obj([[1], [1], [0], [0]], [[0.9], [0.6], [0.4], [0.8]]) + result = self.evaluate(result_t) + self.assertAlmostEqual(result, 0.5, 2) + + @test_util.run_in_graph_and_eager_modes + def test_categorical_accuracy(self): + acc_obj = metrics.CategoricalAccuracy(name='my acc') + + # check config + self.assertEqual(acc_obj.name, 'my acc') + self.assertTrue(acc_obj.stateful) + self.assertEqual(len(acc_obj.variables), 2) + self.assertEqual(acc_obj.dtype, dtypes.float32) + self.evaluate(variables.global_variables_initializer()) + + # verify that correct value is returned + update_op = acc_obj.update_state([[0, 0, 1], [0, 1, 0]], + [[0.1, 0.1, 0.8], [0.05, 0.95, 0]]) + self.evaluate(update_op) + result = self.evaluate(acc_obj.result()) + self.assertEqual(result, 1) # 2/2 + + # check with sample_weight + result_t = acc_obj([[0, 0, 1], [0, 1, 0]], + [[0.1, 0.1, 0.8], [0.05, 0, 0.95]], [[0.5], [0.2]]) + result = self.evaluate(result_t) + self.assertAlmostEqual(result, 0.93, 2) # 2.5/2.7 + + @test_util.run_in_graph_and_eager_modes + def test_invalid_result(self): + + class InvalidResult(metrics.Metric): + + def __init__(self, name='invalid-result', dtype=dtypes.float64): + super(InvalidResult, self).__init__(name=name, dtype=dtype) + + def update_state(self, *args, **kwargs): + pass + + def result(self): + return 1 + + invalid_result_obj = InvalidResult() + with self.assertRaisesRegexp( + TypeError, + 'Metric invalid-result\'s result must be a Tensor or Operation, given:' + ): + invalid_result_obj.result() + + @test_util.run_in_graph_and_eager_modes + def test_invalid_update(self): + + class InvalidUpdate(metrics.Metric): + + def __init__(self, name='invalid-update', dtype=dtypes.float64): + super(InvalidUpdate, self).__init__(name=name, dtype=dtype) + + def update_state(self, *args, **kwargs): + return [1] + + def result(self): + pass + + invalid_update_obj = InvalidUpdate() + with self.assertRaisesRegexp( + TypeError, + 'Metric invalid-update\'s update must be a Tensor or Operation, given:' + ): + invalid_update_obj.update_state() + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py index 8fb957da439dd490bc3378df96f611733335c809..71c1987cee6c610a19d12d5b9e2389606c5f1c24 100644 --- a/tensorflow/python/keras/model_subclassing_test.py +++ b/tensorflow/python/keras/model_subclassing_test.py @@ -29,9 +29,11 @@ from tensorflow.python.eager import context from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import data_structures from tensorflow.python.training.rmsprop import RMSPropOptimizer try: @@ -65,6 +67,22 @@ class SimpleTestModel(keras.Model): return self.dense2(x) +class SimpleConvTestModel(keras.Model): + + def __init__(self, num_classes=10): + super(SimpleConvTestModel, self).__init__(name='test_model') + self.num_classes = num_classes + + self.conv1 = keras.layers.Conv2D(32, (3, 3), activation='relu') + self.flatten = keras.layers.Flatten() + self.dense1 = keras.layers.Dense(num_classes, activation='softmax') + + def call(self, x): + x = self.conv1(x) + x = self.flatten(x) + return self.dense1(x) + + class MultiIOTestModel(keras.Model): def __init__(self, use_bn=False, use_dp=False, num_classes=(2, 3)): @@ -162,9 +180,6 @@ def get_nested_model_3(input_dim, num_classes): x = self.dense2(x) return self.bn(x) - def compute_output_shape(self, input_shape): - return tensor_shape.TensorShape((input_shape[0], 5)) - test_model = Inner() x = test_model(x) outputs = keras.layers.Dense(num_classes)(x) @@ -173,7 +188,235 @@ def get_nested_model_3(input_dim, num_classes): class ModelSubclassingTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes + def test_custom_build(self): + class DummyModel(keras.Model): + + def __init__(self): + super(DummyModel, self).__init__() + self.dense1 = keras.layers.Dense(32, activation='relu') + self.uses_custom_build = False + + def call(self, inputs): + return self.dense1(inputs) + + def build(self, input_shape): + self.uses_custom_build = True + + test_model = DummyModel() + dummy_data = array_ops.ones((32, 50)) + test_model(dummy_data) + self.assertTrue(test_model.uses_custom_build, 'Model should use user ' + 'defined build when called.') + + @test_util.run_in_graph_and_eager_modes + def test_invalid_input_shape_build(self): + num_classes = 2 + input_dim = 50 + + model = SimpleTestModel(num_classes=num_classes, + use_dp=True, + use_bn=True) + + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + with self.assertRaisesRegexp( + ValueError, 'input shape is not one of the valid types'): + model.build(input_shape=tensor_shape.Dimension(input_dim)) + + @test_util.run_in_graph_and_eager_modes + def test_embed_dtype_with_subclass_build(self): + class Embedding(keras.layers.Layer): + """An Embedding layer.""" + + def __init__(self, vocab_size, embedding_dim, **kwargs): + super(Embedding, self).__init__(**kwargs) + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + + def build(self, _): + self.embedding = self.add_variable( + 'embedding_kernel', + shape=[self.vocab_size, self.embedding_dim], + dtype=np.float32, + initializer=init_ops.random_uniform_initializer(-0.1, 0.1), + trainable=True) + + def call(self, x): + return embedding_ops.embedding_lookup(self.embedding, x) + + class EmbedModel(keras.Model): + + def __init__(self, vocab_size, embed_size): + super(EmbedModel, self).__init__() + self.embed1 = Embedding(vocab_size, embed_size) + + def call(self, inputs): + return self.embed1(inputs) + + model = EmbedModel(100, 20) + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + with self.assertRaisesRegexp( + ValueError, 'if your layers do not support float type inputs'): + model.build(input_shape=(35, 20)) + + @test_util.run_in_graph_and_eager_modes + def test_single_time_step_rnn_build(self): + dim = 4 + timesteps = 1 + batch_input_shape = (None, timesteps, dim) + units = 3 + + class SimpleRNNModel(keras.Model): + + def __init__(self): + super(SimpleRNNModel, self).__init__() + self.lstm = keras.layers.LSTM(units) + + def call(self, inputs): + return self.lstm(inputs) + + model = SimpleRNNModel() + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + model.build(batch_input_shape) + self.assertTrue(model.weights, ('Model should have weights now that it ' + 'has been properly built.')) + self.assertTrue(model.built, 'Model should be built after calling `build`.') + model(array_ops.ones((32, timesteps, dim))) + + @test_util.run_in_graph_and_eager_modes + def test_single_io_subclass_build(self): + num_classes = 2 + input_dim = 50 + batch_size = None + + model = SimpleTestModel(num_classes=num_classes, + use_dp=True, + use_bn=True) + + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + model.build(input_shape=(batch_size, input_dim)) + self.assertTrue(model.weights, ('Model should have weights now that it ' + 'has been properly built.')) + self.assertTrue(model.built, 'Model should be built after calling `build`.') + model(array_ops.ones((32, input_dim))) + + @test_util.run_in_graph_and_eager_modes + def test_single_io_dimension_subclass_build(self): + num_classes = 2 + input_dim = tensor_shape.Dimension(50) + batch_size = tensor_shape.Dimension(None) + + model = SimpleTestModel(num_classes=num_classes, + use_dp=True, + use_bn=True) + + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + model.build(input_shape=(batch_size, input_dim)) + self.assertTrue(model.weights, ('Model should have weights now that it ' + 'has been properly built.')) + self.assertTrue(model.built, 'Model should be built after calling `build`.') + model(array_ops.ones((32, input_dim))) + + @test_util.run_in_graph_and_eager_modes + def test_multidim_io_subclass_build(self): + num_classes = 10 + # Input size, e.g. image + batch_size = 32 + input_shape = (32, 32, 3) + + model = SimpleConvTestModel(num_classes) + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + batch_input_shape = (batch_size,) + input_shape + model.build(input_shape=batch_input_shape) + self.assertTrue(model.weights, ('Model should have weights now that it ' + 'has been properly built.')) + self.assertTrue(model.built, 'Model should be built after calling `build`.') + + model(array_ops.ones(batch_input_shape)) + + @test_util.run_in_graph_and_eager_modes + def test_tensorshape_io_subclass_build(self): + num_classes = 10 + # Input size, e.g. image + batch_size = None + input_shape = (32, 32, 3) + + model = SimpleConvTestModel(num_classes) + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + model.build( + input_shape=tensor_shape.TensorShape((batch_size,) + input_shape)) + self.assertTrue(model.weights, ('Model should have weights now that it ' + 'has been properly built.')) + self.assertTrue(model.built, 'Model should be built after calling `build`.') + + model(array_ops.ones((32,) + input_shape)) + + def test_subclass_save_model(self): + num_classes = 10 + # Input size, e.g. image + batch_size = None + input_shape = (32, 32, 3) + + model = SimpleConvTestModel(num_classes) + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + model.build( + input_shape=tensor_shape.TensorShape((batch_size,) + input_shape)) + self.assertTrue(model.weights, ('Model should have weights now that it ' + 'has been properly built.')) + self.assertTrue(model.built, 'Model should be built after calling `build`.') + weights = model.get_weights() + + tf_format_name = os.path.join(self.get_temp_dir(), 'ckpt') + model.save_weights(tf_format_name) + if h5py is not None: + hdf5_format_name = os.path.join(self.get_temp_dir(), 'weights.h5') + model.save_weights(hdf5_format_name) + + model = SimpleConvTestModel(num_classes) + model.build( + input_shape=tensor_shape.TensorShape((batch_size,) + input_shape)) + if h5py is not None: + model.load_weights(hdf5_format_name) + self.assertAllClose(weights, model.get_weights()) + model.load_weights(tf_format_name) + self.assertAllClose(weights, model.get_weights()) + + @test_util.run_in_graph_and_eager_modes + def test_multi_io_subclass_build(self): + batch_size = None + num_samples = 1000 + input_dim = 50 + model = MultiIOTestModel() + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + batch_input_shape = tensor_shape.TensorShape((batch_size, input_dim)) + model.build( + input_shape=[batch_input_shape, batch_input_shape]) + self.assertTrue(model.weights, ('Model should have weights now that it ' + 'has been properly built.')) + self.assertTrue(model.built, 'Model should be built after calling `build`.') + x1 = array_ops.ones((num_samples, input_dim)) + x2 = array_ops.ones((num_samples, input_dim)) + model([x1, x2]) + + @test_util.run_in_graph_and_eager_modes def test_single_io_workflow_with_np_arrays(self): num_classes = 2 num_samples = 100 @@ -182,9 +425,10 @@ class ModelSubclassingTest(test.TestCase): model = SimpleTestModel(num_classes=num_classes, use_dp=True, use_bn=True) - model.compile(loss='mse', - optimizer=RMSPropOptimizer(learning_rate=0.001), - metrics=['acc']) + model.compile( + loss='mse', + optimizer=RMSPropOptimizer(learning_rate=0.001), + metrics=['acc', keras.metrics.CategoricalAccuracy()]) x = np.ones((num_samples, input_dim)) y = np.zeros((num_samples, num_classes)) @@ -192,7 +436,7 @@ class ModelSubclassingTest(test.TestCase): model.fit(x, y, epochs=2, batch_size=32, verbose=0) _ = model.evaluate(x, y, verbose=0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_multi_io_workflow_with_np_arrays(self): num_classes = (2, 3) num_samples = 1000 @@ -251,7 +495,7 @@ class ModelSubclassingTest(test.TestCase): model.fit([x1, x2], [y1, y2], epochs=2, steps_per_epoch=10, verbose=0) _ = model.evaluate(steps=10, verbose=0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_single_io_workflow_with_dataset_iterators(self): num_classes = 2 num_samples = 10 @@ -325,7 +569,7 @@ class ModelSubclassingTest(test.TestCase): self.assertEqual(len(model.inputs), 2) self.assertEqual(len(model.outputs), 2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_updates(self): # test that updates get run during training num_samples = 100 @@ -352,7 +596,74 @@ class ModelSubclassingTest(test.TestCase): y_new = model.predict(x) self.assertGreater(np.sum(np.abs(y_ref - y_new)), 0.1) - @test_util.run_in_graph_and_eager_modes() + def test_updates_and_losses_for_nested_models_in_subclassed_model(self): + + # Case 1: deferred-build sequential nested in subclass. + class TestModel1(keras.Model): + + def __init__(self): + super(TestModel1, self).__init__() + self.fc = keras.layers.Dense(10, input_shape=(784,), + activity_regularizer='l1') + self.bn = keras.Sequential([keras.layers.BatchNormalization(axis=1)]) + + def call(self, x): + return self.bn(self.fc(x)) + + with self.test_session(): + model = TestModel1() + + x = array_ops.ones(shape=[100, 784], dtype='float32') + model(x) + self.assertEqual(len(model.get_updates_for(x)), 2) + self.assertEqual(len(model.get_losses_for(x)), 1) + + # Case 2: placeholder-sequential nested in subclass. + class TestModel2(keras.Model): + + def __init__(self): + super(TestModel2, self).__init__() + self.fc = keras.layers.Dense(10, input_shape=(784,), + activity_regularizer='l1') + self.bn = keras.Sequential( + [keras.layers.BatchNormalization(axis=1, input_shape=(10,))]) + + def call(self, x): + return self.bn(self.fc(x)) + + with self.test_session(): + model = TestModel2() + + x = array_ops.ones(shape=[100, 784], dtype='float32') + model(x) + self.assertEqual(len(model.get_updates_for(x)), 2) + self.assertEqual(len(model.get_losses_for(x)), 1) + + # Case 3: functional-API model nested in subclass. + inputs = keras.Input((10,)) + outputs = keras.layers.BatchNormalization(axis=1)(inputs) + bn = keras.Model(inputs, outputs) + + class TestModel3(keras.Model): + + def __init__(self): + super(TestModel3, self).__init__() + self.fc = keras.layers.Dense(10, input_shape=(784,), + activity_regularizer='l1') + self.bn = bn + + def call(self, x): + return self.bn(self.fc(x)) + + with self.test_session(): + model = TestModel3() + + x = array_ops.ones(shape=[100, 784], dtype='float32') + model(x) + self.assertEqual(len(model.get_updates_for(x)), 2) + self.assertEqual(len(model.get_losses_for(x)), 1) + + @test_util.run_in_graph_and_eager_modes def test_training_and_inference_behavior(self): # test that dropout is applied in training and not inference @@ -380,7 +691,7 @@ class ModelSubclassingTest(test.TestCase): loss = model.train_on_batch(x, y) self.assertGreater(loss, 0.1) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_training_methods(self): # test fit, train_on_batch # on different input types: list, dict @@ -433,14 +744,14 @@ class ModelSubclassingTest(test.TestCase): model = MultiIOTestModel(num_classes=num_classes, use_bn=True) model.predict_on_batch([x1, x2]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_trainable_mutation(self): # test that you can change `trainable` on a model or layer, and that # it freezes the model state during training # TODO(fchollet): add test after we unify BN behavior in eager and symbolic. pass - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_saving(self): num_classes = (2, 3) @@ -482,7 +793,7 @@ class ModelSubclassingTest(test.TestCase): self.assertAllClose(y_ref_1, y1, atol=1e-5) self.assertAllClose(y_ref_2, y2, atol=1e-5) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_summary(self): class ToString(object): @@ -508,7 +819,7 @@ class ModelSubclassingTest(test.TestCase): model.summary(print_fn=print_fn) self.assertTrue('Trainable params: 587' in print_fn.contents) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_subclass_nested_in_subclass(self): num_classes = 2 num_samples = 100 @@ -531,7 +842,7 @@ class ModelSubclassingTest(test.TestCase): self.assertEqual(len(model.trainable_weights), 6 + len(model.test_net.trainable_weights)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_graph_nested_in_subclass(self): num_classes = 2 num_samples = 100 @@ -554,7 +865,7 @@ class ModelSubclassingTest(test.TestCase): self.assertEqual(len(model.trainable_weights), 6 + len(model.test_net.trainable_weights)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_subclass_nested_in_graph(self): num_classes = 2 num_samples = 100 @@ -576,7 +887,7 @@ class ModelSubclassingTest(test.TestCase): len(model.non_trainable_weights), 4) self.assertEqual(len(model.trainable_weights), 12) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_support_for_manual_training_arg(self): # In most cases, the `training` argument is left unspecified, in which # case it defaults to value corresponding to the Model method being used @@ -612,8 +923,8 @@ class ModelSubclassingTest(test.TestCase): def __init__(self): super(Foo, self).__init__() self.isdep = keras.layers.Dense(1) - self.notdep = checkpointable.NoDependency(keras.layers.Dense(2)) - self.notdep_var = checkpointable.NoDependency( + self.notdep = data_structures.NoDependency(keras.layers.Dense(2)) + self.notdep_var = data_structures.NoDependency( resource_variable_ops.ResourceVariable(1., name='notdep_var')) m = Foo() @@ -683,9 +994,19 @@ class CustomCallModel(keras.Model): return combined +class TrainingNoDefaultModel(keras.Model): + + def __init__(self): + super(TrainingNoDefaultModel, self).__init__() + self.dense1 = keras.layers.Dense(1) + + def call(self, x, training): + return self.dense1(x) + + class CustomCallSignatureTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_no_inputs_in_signature(self): model = CustomCallModel() first = array_ops.ones([2, 3]) @@ -699,7 +1020,33 @@ class CustomCallSignatureTests(test.TestCase): output = model(first, second=second, training=False) self.assertAllClose(expected_output, self.evaluate(output)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes + def test_training_args_call_build(self): + input_dim = 2 + + model = TrainingNoDefaultModel() + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + model.build((None, input_dim)) + self.assertTrue(model.weights, ('Model should have weights now that it ' + 'has been properly built.')) + self.assertTrue(model.built, 'Model should be built after calling `build`.') + + @test_util.run_in_graph_and_eager_modes + def test_custom_call_kwargs_and_build(self): + first_input_shape = (2, 3) + second_input_shape = (2, 5) + + model = CustomCallModel() + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + with self.assertRaisesRegexp( + ValueError, 'cannot build your model if it has positional'): + model.build(input_shape=[first_input_shape, second_input_shape]) + + @test_util.run_in_graph_and_eager_modes def test_inputs_in_signature(self): class HasInputsAndOtherPositional(keras.Model): @@ -716,7 +1063,7 @@ class CustomCallSignatureTests(test.TestCase): x1, x2 = keras.Input((1, 1)), keras.Input((1, 1)) model(x1, x2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_kwargs_in_signature(self): class HasKwargs(keras.Model): @@ -730,7 +1077,7 @@ class CustomCallSignatureTests(test.TestCase): if not context.executing_eagerly(): six.assertCountEqual(self, [arg], model.inputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_args_in_signature(self): class HasArgs(keras.Model): @@ -762,14 +1109,9 @@ class CustomCallSignatureTests(test.TestCase): def test_training_no_default(self): - class TrainingNoDefault(keras.Model): - - def call(self, x, training): - return x - with context.graph_mode(): - model = TrainingNoDefault() - arg = array_ops.ones([]) + model = TrainingNoDefaultModel() + arg = array_ops.ones([1, 1]) model(arg, True) six.assertCountEqual(self, [arg], model.inputs) diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py index 21217fdca14eabaa425903d5370731eb94fdeec6..0bd662022019de7de6abeb12bce8ffb4824f1aff 100644 --- a/tensorflow/python/keras/models.py +++ b/tensorflow/python/keras/models.py @@ -26,7 +26,6 @@ from tensorflow.python.keras.engine import training from tensorflow.python.keras.engine.input_layer import Input from tensorflow.python.keras.engine.input_layer import InputLayer from tensorflow.python.keras.utils import generic_utils -from tensorflow.python.keras.utils.generic_utils import has_arg # API entries importable from `keras.models`: @@ -69,7 +68,7 @@ def _clone_functional_model(model, input_tensors=None): 'got a `Sequential` instance instead:', model) layer_map = {} # Cache for created layers. - tensor_map = {} # Map {reference_tensor: (corresponding_tensor, mask)} + tensor_map = {} # Map {reference_tensor: corresponding_tensor} if input_tensors is None: # Create placeholders to build the model on top of. input_layers = [] @@ -106,7 +105,7 @@ def _clone_functional_model(model, input_tensors=None): input_tensors = input_tensors_ for x, y in zip(model.inputs, input_tensors): - tensor_map[x] = (y, None) # tensor, mask + tensor_map[x] = y # Iterated over every node in the reference model, in depth order. depth_keys = list(model._nodes_by_depth.keys()) @@ -131,55 +130,41 @@ def _clone_functional_model(model, input_tensors=None): continue # Gather inputs to call the new layer. - referenceinput_tensors_ = node.input_tensors + reference_input_tensors = node.input_tensors reference_output_tensors = node.output_tensors # If all previous input tensors are available in tensor_map, # then call node.inbound_layer on them. - computed_data = [] # List of tuples (input, mask). - for x in referenceinput_tensors_: + computed_tensors = [] + for x in reference_input_tensors: if x in tensor_map: - computed_data.append(tensor_map[x]) + computed_tensors.append(tensor_map[x]) - if len(computed_data) == len(referenceinput_tensors_): + if len(computed_tensors) == len(reference_input_tensors): # Call layer. if node.arguments: kwargs = node.arguments else: kwargs = {} - if len(computed_data) == 1: - computed_tensor, computed_mask = computed_data[0] - if has_arg(layer.call, 'mask'): - if 'mask' not in kwargs: - kwargs['mask'] = computed_mask + if len(computed_tensors) == 1: + computed_tensor = computed_tensors[0] output_tensors = generic_utils.to_list(layer(computed_tensor, **kwargs)) - output_masks = generic_utils.to_list( - layer.compute_mask(computed_tensor, computed_mask)) computed_tensors = [computed_tensor] - computed_masks = [computed_mask] else: - computed_tensors = [x[0] for x in computed_data] - computed_masks = [x[1] for x in computed_data] - if has_arg(layer.call, 'mask'): - if 'mask' not in kwargs: - kwargs['mask'] = computed_masks + computed_tensors = computed_tensors output_tensors = generic_utils.to_list(layer(computed_tensors, **kwargs)) - output_masks = generic_utils.to_list( - layer.compute_mask(computed_tensors, computed_masks)) - # Update tensor_map. - for x, y, mask in zip(reference_output_tensors, output_tensors, - output_masks): - tensor_map[x] = (y, mask) + + for x, y in zip(reference_output_tensors, output_tensors): + tensor_map[x] = y # Check that we did compute the model outputs, # then instantiate a new model from inputs and outputs. output_tensors = [] for x in model.outputs: assert x in tensor_map, 'Could not compute output ' + str(x) - tensor, _ = tensor_map[x] - output_tensors.append(tensor) + output_tensors.append(tensor_map[x]) return Model(input_tensors, output_tensors, name=model.name) diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py index e6e45902a8f117e5765249da18afa7cc35aa6b16..1385ad5390f206ce930dcca12535496d7302a1d2 100644 --- a/tensorflow/python/keras/models_test.py +++ b/tensorflow/python/keras/models_test.py @@ -37,6 +37,7 @@ class TestModelCloning(test.TestCase): model = keras.models.Sequential() model.add(keras.layers.Dense(4, input_shape=(4,))) + model.add(keras.layers.BatchNormalization()) model.add(keras.layers.Dropout(0.5)) model.add(keras.layers.Dense(4)) @@ -46,6 +47,8 @@ class TestModelCloning(test.TestCase): with self.test_session(): # With placeholder creation new_model = keras.models.clone_model(model) + # update ops from batch norm needs to be included + self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2) new_model.compile('rmsprop', 'mse') new_model.train_on_batch(val_a, val_out) @@ -53,6 +56,7 @@ class TestModelCloning(test.TestCase): input_a = keras.Input(shape=(4,)) new_model = keras.models.clone_model( model, input_tensors=input_a) + self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2) new_model.compile('rmsprop', 'mse') new_model.train_on_batch(val_a, val_out) @@ -60,6 +64,7 @@ class TestModelCloning(test.TestCase): input_a = keras.backend.variable(val_a) new_model = keras.models.clone_model( model, input_tensors=input_a) + self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2) new_model.compile('rmsprop', 'mse') new_model.train_on_batch(None, val_out) @@ -76,6 +81,7 @@ class TestModelCloning(test.TestCase): x_a = dense_1(input_a) x_a = keras.layers.Dropout(0.5)(x_a) + x_a = keras.layers.BatchNormalization()(x_a) x_b = dense_1(input_b) x_a = dense_2(x_a) outputs = keras.layers.add([x_a, x_b]) @@ -87,6 +93,7 @@ class TestModelCloning(test.TestCase): with self.test_session(): # With placeholder creation new_model = keras.models.clone_model(model) + self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2) new_model.compile('rmsprop', 'mse') new_model.train_on_batch([val_a, val_b], val_out) @@ -95,6 +102,7 @@ class TestModelCloning(test.TestCase): input_b = keras.Input(shape=(4,), name='b') new_model = keras.models.clone_model( model, input_tensors=[input_a, input_b]) + self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2) new_model.compile('rmsprop', 'mse') new_model.train_on_batch([val_a, val_b], val_out) @@ -103,9 +111,26 @@ class TestModelCloning(test.TestCase): input_b = keras.backend.variable(val_b) new_model = keras.models.clone_model( model, input_tensors=[input_a, input_b]) + self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2) new_model.compile('rmsprop', 'mse') new_model.train_on_batch(None, val_out) + @test_util.run_in_graph_and_eager_modes + def test_clone_functional_model_with_masking(self): + with self.test_session(): + x = np.array([[[1], [1]], [[0], [0]]]) + inputs = keras.Input((2, 1)) + outputs = keras.layers.Masking(mask_value=0)(inputs) + outputs = keras.layers.TimeDistributed( + keras.layers.Dense(1, kernel_initializer='one'))(outputs) + model = keras.Model(inputs, outputs) + + model = keras.models.clone_model(model) + model.compile(loss='mse', optimizer=adam.AdamOptimizer(0.01)) + y = np.array([[[1], [1]], [[1], [1]]]) + loss = model.train_on_batch(x, y) + self.assertEqual(float(loss), 0.) + def test_model_cloning_invalid_use_cases(self): seq_model = keras.models.Sequential() seq_model.add(keras.layers.Dense(4, input_shape=(4,))) @@ -129,7 +154,7 @@ class TestModelCloning(test.TestCase): class CheckpointingTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_optimizer_dependency(self): model = keras.models.Sequential() model.add(keras.layers.Dense(1, input_shape=(4,))) diff --git a/tensorflow/python/keras/optimizers.py b/tensorflow/python/keras/optimizers.py index f58aeaea1acae2717f00a0323b5ff297a8cc8b46..f339a7e04761627ad1e5716974481349e8a71f81 100644 --- a/tensorflow/python/keras/optimizers.py +++ b/tensorflow/python/keras/optimizers.py @@ -19,57 +19,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy - import six from six.moves import zip # pylint: disable=redefined-builtin -from tensorflow.python.framework import dtypes as dtypes_module -from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras.utils.generic_utils import serialize_keras_object -from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import clip_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import optimizer as tf_optimizer_module from tensorflow.python.training import training_util from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util.tf_export import tf_export -def clip_norm(g, c, n): - """Clip a tensor by norm. - - Arguments: - g: gradient tensor to clip. - c: clipping threshold. - n: norm of gradient tensor. - - Returns: - Clipped gradient tensor. - """ - if c > 0: - condition = n >= c - then_expression = lambda: math_ops.scalar_mul(c / n, g) - else_expression = lambda: g - - # saving the shape to avoid converting sparse tensor to dense - if isinstance(g, ops.Tensor): - g_shape = copy.copy(g.get_shape()) - elif isinstance(g, ops.IndexedSlices): - g_shape = copy.copy(g.dense_shape) - if condition.dtype != dtypes_module.bool: - condition = math_ops.cast(condition, 'bool') - g = control_flow_ops.cond(condition, then_expression, else_expression) - if isinstance(g, ops.Tensor): - g.set_shape(g_shape) - elif isinstance(g, ops.IndexedSlices): - g._dense_shape = g_shape # pylint: disable=protected-access - return g - - @tf_export('keras.optimizers.Optimizer') class Optimizer(object): """Abstract optimizer base class. @@ -91,6 +56,9 @@ class Optimizer(object): if k not in allowed_kwargs: raise TypeError('Unexpected keyword argument ' 'passed to optimizer: ' + str(k)) + # checks that clipnorm >= 0 and clipvalue >= 0 + if kwargs[k] < 0: + raise ValueError('Expected {} >= 0, received: {}'.format(k, kwargs[k])) self.__dict__.update(kwargs) self.updates = [] self.weights = [] @@ -119,12 +87,13 @@ class Optimizer(object): 'gradient defined (i.e. are differentiable). ' 'Common ops without gradient: ' 'K.argmax, K.round, K.eval.') - if hasattr(self, 'clipnorm') and self.clipnorm > 0: - norm = K.sqrt( - sum([math_ops.reduce_sum(math_ops.square(g)) for g in grads])) - grads = [clip_norm(g, self.clipnorm, norm) for g in grads] - if hasattr(self, 'clipvalue') and self.clipvalue > 0: - grads = [K.clip(g, -self.clipvalue, self.clipvalue) for g in grads] + if hasattr(self, 'clipnorm'): + grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads] + if hasattr(self, 'clipvalue'): + grads = [ + clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue) + for g in grads + ] return grads def set_weights(self, weights): @@ -719,12 +688,13 @@ class Nadam(Optimizer): return dict(list(base_config.items()) + list(config.items())) -class TFOptimizer(Optimizer, checkpointable.Checkpointable): +class TFOptimizer(Optimizer, checkpointable.CheckpointableBase): """Wrapper class for native TensorFlow optimizers. """ def __init__(self, optimizer): # pylint: disable=super-init-not-called self.optimizer = optimizer + self._track_checkpointable(optimizer, name='optimizer') with K.name_scope(self.__class__.__name__): self.iterations = K.variable(0, dtype='int64', name='iterations') @@ -735,7 +705,7 @@ class TFOptimizer(Optimizer, checkpointable.Checkpointable): return self.optimizer.compute_gradients(loss, params) def get_updates(self, loss, params): - if distribute_lib.has_distribution_strategy(): + if distribution_strategy_context.has_distribution_strategy(): self.updates = [] if not params: @@ -748,10 +718,13 @@ class TFOptimizer(Optimizer, checkpointable.Checkpointable): global_step = training_util.get_global_step() opt_update = self.optimizer.apply_gradients(grads, global_step) else: - self.updates = [state_ops.assign_add(self.iterations, 1)] if not params: + self.updates = [state_ops.assign_add(self.iterations, 1)] return self.updates + # Updates list starts out empty because the iterations variable is + # incremented in optimizer.apply_gradients() + self.updates = [] grads = self.optimizer.compute_gradients(loss, params) opt_update = self.optimizer.apply_gradients( grads, global_step=self.iterations) diff --git a/tensorflow/python/keras/optimizers_test.py b/tensorflow/python/keras/optimizers_test.py index 92b0cf326158adb1c6124384571a075196dbd3cc..4d295351f561268bbbadc885308c7d8e8765e042 100644 --- a/tensorflow/python/keras/optimizers_test.py +++ b/tensorflow/python/keras/optimizers_test.py @@ -46,7 +46,11 @@ def _test_optimizer(optimizer, target=0.75): model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy']) + np.testing.assert_equal(keras.backend.get_value(model.optimizer.iterations), + 0) history = model.fit(x_train, y_train, epochs=2, batch_size=16, verbose=0) + np.testing.assert_equal(keras.backend.get_value(model.optimizer.iterations), + 126) # 63 steps per epoch assert history.history['acc'][-1] >= target config = keras.optimizers.serialize(optimizer) optim = keras.optimizers.deserialize(config) @@ -66,7 +70,11 @@ def _test_optimizer(optimizer, target=0.75): model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy']) + np.testing.assert_equal(keras.backend.get_value(model.optimizer.iterations), + 126) # Using same optimizer from before model.train_on_batch(x_train[:10], y_train[:10]) + np.testing.assert_equal(keras.backend.get_value(model.optimizer.iterations), + 127) kernel, bias = dense.get_weights() np.testing.assert_allclose(kernel, 1., atol=1e-3) np.testing.assert_allclose(bias, 2., atol=1e-3) @@ -145,6 +153,34 @@ class KerasOptimizersTest(test.TestCase): with self.assertRaises(NotImplementedError): optimizer.from_config(None) + def test_tfoptimizer_iterations(self): + with self.test_session(): + optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01)) + model = keras.models.Sequential() + model.add(keras.layers.Dense( + 2, input_shape=(3,), kernel_constraint=keras.constraints.MaxNorm(1))) + model.compile(loss='mean_squared_error', optimizer=optimizer) + self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 0) + + model.fit(np.random.random((55, 3)), + np.random.random((55, 2)), + epochs=1, + batch_size=5, + verbose=0) + self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 11) + + model.fit(np.random.random((20, 3)), + np.random.random((20, 2)), + steps_per_epoch=8, + verbose=0) + self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 19) + + def test_negative_clipvalue_or_clipnorm(self): + with self.assertRaises(ValueError): + _ = keras.optimizers.SGD(lr=0.01, clipvalue=-0.5) + with self.assertRaises(ValueError): + _ = keras.optimizers.Adam(clipnorm=-2.0) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/preprocessing/__init__.py b/tensorflow/python/keras/preprocessing/__init__.py index e6704eeaa1f953be68e7ccdbc7e8bd60c62a61d8..2f08f88600f422b3a69ae1969ce5faa5716364f3 100644 --- a/tensorflow/python/keras/preprocessing/__init__.py +++ b/tensorflow/python/keras/preprocessing/__init__.py @@ -13,10 +13,18 @@ # limitations under the License. # ============================================================================== """Keras data preprocessing utils.""" +# pylint: disable=g-import-not-at-top from __future__ import absolute_import from __future__ import division from __future__ import print_function +import keras_preprocessing + +from tensorflow.python.keras import backend +from tensorflow.python.keras import utils + +keras_preprocessing.set_keras_submodules(backend=backend, utils=utils) + from tensorflow.python.keras.preprocessing import image from tensorflow.python.keras.preprocessing import sequence from tensorflow.python.keras.preprocessing import text diff --git a/tensorflow/python/keras/preprocessing/image.py b/tensorflow/python/keras/preprocessing/image.py index aa425df6a8bdb29b90a6d7000d126b771247c19f..ba227385eff30367487bf0d4294875bb6a48bced 100644 --- a/tensorflow/python/keras/preprocessing/image.py +++ b/tensorflow/python/keras/preprocessing/image.py @@ -12,1588 +12,58 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +# pylint: disable=invalid-name # pylint: disable=g-import-not-at-top -"""Fairly basic set of tools for real-time data augmentation on image data. - -Can easily be extended to include new transformations, -new preprocessing methods, etc... +"""Set of tools for real-time data augmentation on image data. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from functools import partial -import multiprocessing.pool -import os -import re -import threading - -import numpy as np -from tensorflow.python.keras import backend as K -from tensorflow.python.keras.utils.data_utils import Sequence -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.util.tf_export import tf_export - +from keras_preprocessing import image try: - from scipy import linalg - import scipy.ndimage as ndi + from scipy import linalg # pylint: disable=unused-import + from scipy import ndimage # pylint: disable=unused-import except ImportError: - linalg = None - ndi = None - - -try: - from PIL import ImageEnhance - from PIL import Image as pil_image -except ImportError: - pil_image = None - -if pil_image is not None: - _PIL_INTERPOLATION_METHODS = { - 'nearest': pil_image.NEAREST, - 'bilinear': pil_image.BILINEAR, - 'bicubic': pil_image.BICUBIC, - } - # These methods were only introduced in version 3.4.0 (2016). - if hasattr(pil_image, 'HAMMING'): - _PIL_INTERPOLATION_METHODS['hamming'] = pil_image.HAMMING - if hasattr(pil_image, 'BOX'): - _PIL_INTERPOLATION_METHODS['box'] = pil_image.BOX - # This method is new in version 1.1.3 (2013). - if hasattr(pil_image, 'LANCZOS'): - _PIL_INTERPOLATION_METHODS['lanczos'] = pil_image.LANCZOS - - -@tf_export('keras.preprocessing.image.random_rotation') -def random_rotation(x, - rg, - row_axis=1, - col_axis=2, - channel_axis=0, - fill_mode='nearest', - cval=0.): - """Performs a random rotation of a Numpy image tensor. - - Arguments: - x: Input tensor. Must be 3D. - rg: Rotation range, in degrees. - row_axis: Index of axis for rows in the input tensor. - col_axis: Index of axis for columns in the input tensor. - channel_axis: Index of axis for channels in the input tensor. - fill_mode: Points outside the boundaries of the input - are filled according to the given mode - (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). - cval: Value used for points outside the boundaries - of the input if `mode='constant'`. - - Returns: - Rotated Numpy image tensor. - """ - theta = np.deg2rad(np.random.uniform(-rg, rg)) - rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], - [np.sin(theta), np.cos(theta), 0], [0, 0, 1]]) - - h, w = x.shape[row_axis], x.shape[col_axis] - transform_matrix = transform_matrix_offset_center(rotation_matrix, h, w) - x = apply_transform(x, transform_matrix, channel_axis, fill_mode, cval) - return x - - -@tf_export('keras.preprocessing.image.random_shift') -def random_shift(x, - wrg, - hrg, - row_axis=1, - col_axis=2, - channel_axis=0, - fill_mode='nearest', - cval=0.): - """Performs a random spatial shift of a Numpy image tensor. - - Arguments: - x: Input tensor. Must be 3D. - wrg: Width shift range, as a float fraction of the width. - hrg: Height shift range, as a float fraction of the height. - row_axis: Index of axis for rows in the input tensor. - col_axis: Index of axis for columns in the input tensor. - channel_axis: Index of axis for channels in the input tensor. - fill_mode: Points outside the boundaries of the input - are filled according to the given mode - (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). - cval: Value used for points outside the boundaries - of the input if `mode='constant'`. - - Returns: - Shifted Numpy image tensor. - """ - h, w = x.shape[row_axis], x.shape[col_axis] - tx = np.random.uniform(-hrg, hrg) * h - ty = np.random.uniform(-wrg, wrg) * w - translation_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]]) - - transform_matrix = translation_matrix # no need to do offset - x = apply_transform(x, transform_matrix, channel_axis, fill_mode, cval) - return x - - -@tf_export('keras.preprocessing.image.random_shear') -def random_shear(x, - intensity, - row_axis=1, - col_axis=2, - channel_axis=0, - fill_mode='nearest', - cval=0.): - """Performs a random spatial shear of a Numpy image tensor. - - Arguments: - x: Input tensor. Must be 3D. - intensity: Transformation intensity in degrees. - row_axis: Index of axis for rows in the input tensor. - col_axis: Index of axis for columns in the input tensor. - channel_axis: Index of axis for channels in the input tensor. - fill_mode: Points outside the boundaries of the input - are filled according to the given mode - (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). - cval: Value used for points outside the boundaries - of the input if `mode='constant'`. - - Returns: - Sheared Numpy image tensor. - """ - shear = np.deg2rad(np.random.uniform(-intensity, intensity)) - shear_matrix = np.array([[1, -np.sin(shear), 0], [0, np.cos(shear), 0], - [0, 0, 1]]) - - h, w = x.shape[row_axis], x.shape[col_axis] - transform_matrix = transform_matrix_offset_center(shear_matrix, h, w) - x = apply_transform(x, transform_matrix, channel_axis, fill_mode, cval) - return x - - -@tf_export('keras.preprocessing.image.random_zoom') -def random_zoom(x, - zoom_range, - row_axis=1, - col_axis=2, - channel_axis=0, - fill_mode='nearest', - cval=0.): - """Performs a random spatial zoom of a Numpy image tensor. - - Arguments: - x: Input tensor. Must be 3D. - zoom_range: Tuple of floats; zoom range for width and height. - row_axis: Index of axis for rows in the input tensor. - col_axis: Index of axis for columns in the input tensor. - channel_axis: Index of axis for channels in the input tensor. - fill_mode: Points outside the boundaries of the input - are filled according to the given mode - (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). - cval: Value used for points outside the boundaries - of the input if `mode='constant'`. - - Returns: - Zoomed Numpy image tensor. - - Raises: - ValueError: if `zoom_range` isn't a tuple. - """ - if len(zoom_range) != 2: - raise ValueError('`zoom_range` should be a tuple or list of two floats. ' - 'Received arg: ', zoom_range) - - if zoom_range[0] == 1 and zoom_range[1] == 1: - zx, zy = 1, 1 - else: - zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2) - zoom_matrix = np.array([[zx, 0, 0], [0, zy, 0], [0, 0, 1]]) - - h, w = x.shape[row_axis], x.shape[col_axis] - transform_matrix = transform_matrix_offset_center(zoom_matrix, h, w) - x = apply_transform(x, transform_matrix, channel_axis, fill_mode, cval) - return x - - -@tf_export('keras.preprocessing.image.random_channel_shift') -def random_channel_shift(x, intensity, channel_axis=0): - """Perform a random channel shift. - - Arguments: - x: Input tensor. Must be 3D. - intensity: Transformation intensity. - channel_axis: Index of axis for channels in the input tensor. - - Returns: - Numpy image tensor. - """ - x = np.rollaxis(x, channel_axis, 0) - min_x, max_x = np.min(x), np.max(x) - channel_images = [ - np.clip(x_channel + np.random.uniform(-intensity, intensity), min_x, - max_x) for x_channel in x - ] - x = np.stack(channel_images, axis=0) - x = np.rollaxis(x, 0, channel_axis + 1) - return x - - -@tf_export('keras.preprocessing.image.random_brightness') -def random_brightness(x, brightness_range): - """Performs a random adjustment of brightness of a Numpy image tensor. - - Arguments: - x: Input tensor. Must be 3D. - brightness_range: Tuple of floats; range to pick a brightness value from. - - Returns: - Brightness adjusted Numpy image tensor. - - Raises: - ValueError: if `brightness_range` isn't a tuple. - """ - if len(brightness_range) != 2: - raise ValueError('`brightness_range should be tuple or list of two floats. ' - 'Received arg: ', brightness_range) - - x = array_to_img(x) - x = ImageEnhance.Brightness(x) - u = np.random.uniform(brightness_range[0], brightness_range[1]) - x = x.enhance(u) - x = img_to_array(x) - return x - - -def transform_matrix_offset_center(matrix, x, y): - o_x = float(x) / 2 + 0.5 - o_y = float(y) / 2 + 0.5 - offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]]) - reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]]) - transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix) - return transform_matrix - - -@tf_export('keras.preprocessing.image.apply_transform') -def apply_transform(x, - transform_matrix, - channel_axis=0, - fill_mode='nearest', - cval=0.): - """Apply the image transformation specified by a matrix. - - Arguments: - x: 2D numpy array, single image. - transform_matrix: Numpy array specifying the geometric transformation. - channel_axis: Index of axis for channels in the input tensor. - fill_mode: Points outside the boundaries of the input - are filled according to the given mode - (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). - cval: Value used for points outside the boundaries - of the input if `mode='constant'`. - - Returns: - The transformed version of the input. - """ - x = np.rollaxis(x, channel_axis, 0) - final_affine_matrix = transform_matrix[:2, :2] - final_offset = transform_matrix[:2, 2] - channel_images = [ - ndi.interpolation.affine_transform( - x_channel, - final_affine_matrix, - final_offset, - order=1, - mode=fill_mode, - cval=cval) for x_channel in x - ] - x = np.stack(channel_images, axis=0) - x = np.rollaxis(x, 0, channel_axis + 1) - return x - - -@tf_export('keras.preprocessing.image.flip_axis') -def flip_axis(x, axis): - x = np.asarray(x).swapaxes(axis, 0) - x = x[::-1, ...] - x = x.swapaxes(0, axis) - return x - - -@tf_export('keras.preprocessing.image.array_to_img') -def array_to_img(x, data_format=None, scale=True): - """Converts a 3D Numpy array to a PIL Image instance. - - Arguments: - x: Input Numpy array. - data_format: Image data format. - scale: Whether to rescale image values - to be within [0, 255]. - - Returns: - A PIL Image instance. - - Raises: - ImportError: if PIL is not available. - ValueError: if invalid `x` or `data_format` is passed. - """ - if pil_image is None: - raise ImportError('Could not import PIL.Image. ' - 'The use of `array_to_img` requires PIL.') - x = np.asarray(x, dtype=K.floatx()) - if x.ndim != 3: - raise ValueError('Expected image array to have rank 3 (single image). ' - 'Got array with shape:', x.shape) - - if data_format is None: - data_format = K.image_data_format() - if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Invalid data_format:', data_format) - - # Original Numpy array x has format (height, width, channel) - # or (channel, height, width) - # but target PIL image has format (width, height, channel) - if data_format == 'channels_first': - x = x.transpose(1, 2, 0) - if scale: - x = x + max(-np.min(x), 0) # pylint: disable=g-no-augmented-assignment - x_max = np.max(x) - if x_max != 0: - x /= x_max - x *= 255 - if x.shape[2] == 3: - # RGB - return pil_image.fromarray(x.astype('uint8'), 'RGB') - elif x.shape[2] == 1: - # grayscale - return pil_image.fromarray(x[:, :, 0].astype('uint8'), 'L') - else: - raise ValueError('Unsupported channel number: ', x.shape[2]) - - -@tf_export('keras.preprocessing.image.img_to_array') -def img_to_array(img, data_format=None): - """Converts a PIL Image instance to a Numpy array. - - Arguments: - img: PIL Image instance. - data_format: Image data format. - - Returns: - A 3D Numpy array. - - Raises: - ValueError: if invalid `img` or `data_format` is passed. - """ - if data_format is None: - data_format = K.image_data_format() - if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format: ', data_format) - # Numpy array x has format (height, width, channel) - # or (channel, height, width) - # but original PIL image has format (width, height, channel) - x = np.asarray(img, dtype=K.floatx()) - if len(x.shape) == 3: - if data_format == 'channels_first': - x = x.transpose(2, 0, 1) - elif len(x.shape) == 2: - if data_format == 'channels_first': - x = x.reshape((1, x.shape[0], x.shape[1])) - else: - x = x.reshape((x.shape[0], x.shape[1], 1)) - else: - raise ValueError('Unsupported image shape: ', x.shape) - return x - - -@tf_export('keras.preprocessing.image.load_img') -def load_img(path, grayscale=False, target_size=None, interpolation='nearest'): - """Loads an image into PIL format. - - Arguments: - path: Path to image file - grayscale: Boolean, whether to load the image as grayscale. - target_size: Either `None` (default to original size) - or tuple of ints `(img_height, img_width)`. - interpolation: Interpolation method used to resample the image if the - target size is different from that of the loaded image. - Supported methods are "nearest", "bilinear", and "bicubic". - If PIL version 1.1.3 or newer is installed, "lanczos" is also - supported. If PIL version 3.4.0 or newer is installed, "box" and - "hamming" are also supported. By default, "nearest" is used. - - Returns: - A PIL Image instance. - - Raises: - ImportError: if PIL is not available. - ValueError: if interpolation method is not supported. - """ - if pil_image is None: - raise ImportError('Could not import PIL.Image. ' - 'The use of `array_to_img` requires PIL.') - img = pil_image.open(path) - if grayscale: - if img.mode != 'L': - img = img.convert('L') - else: - if img.mode != 'RGB': - img = img.convert('RGB') - if target_size is not None: - width_height_tuple = (target_size[1], target_size[0]) - if img.size != width_height_tuple: - if interpolation not in _PIL_INTERPOLATION_METHODS: - raise ValueError('Invalid interpolation method {} specified. Supported ' - 'methods are {}'.format(interpolation, ', '.join( - _PIL_INTERPOLATION_METHODS.keys()))) - resample = _PIL_INTERPOLATION_METHODS[interpolation] - img = img.resize(width_height_tuple, resample) - return img - - -def list_pictures(directory, ext='jpg|jpeg|bmp|png|ppm'): - return [ - os.path.join(root, f) - for root, _, files in os.walk(directory) - for f in files - if re.match(r'([\w]+\.(?:' + ext + '))', f) - ] - - -@tf_export('keras.preprocessing.image.ImageDataGenerator') -class ImageDataGenerator(object): - """Generates batches of tensor image data with real-time data augmentation. - The data will be looped over (in batches). - - Arguments: - featurewise_center: boolean, set input mean to 0 over the dataset, - feature-wise. - samplewise_center: boolean, set each sample mean to 0. - featurewise_std_normalization: boolean, divide inputs by std - of the dataset, feature-wise. - samplewise_std_normalization: boolean, divide each input by its std. - zca_epsilon: epsilon for ZCA whitening. Default is 1e-6. - zca_whitening: boolean, apply ZCA whitening. - rotation_range: int, degree range for random rotations. - width_shift_range: float, 1-D array-like or int - float: fraction of total width, if < 1, or pixels if >= 1. - 1-D array-like: random elements from the array. - int: integer number of pixels from interval - `(-width_shift_range, +width_shift_range)` - With `width_shift_range=2` possible values are integers [-1, 0, +1], - same as with `width_shift_range=[-1, 0, +1]`, - while with `width_shift_range=1.0` possible values are floats in - the interval [-1.0, +1.0). - shear_range: float, shear Intensity - (Shear angle in counter-clockwise direction in degrees) - zoom_range: float or [lower, upper], Range for random zoom. - If a float, `[lower, upper] = [1-zoom_range, 1+zoom_range]`. - channel_shift_range: float, range for random channel shifts. - fill_mode: One of {"constant", "nearest", "reflect" or "wrap"}. - Default is 'nearest'. Points outside the boundaries of the input - are filled according to the given mode: - 'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k) - 'nearest': aaaaaaaa|abcd|dddddddd - 'reflect': abcddcba|abcd|dcbaabcd - 'wrap': abcdabcd|abcd|abcdabcd - cval: float or int, value used for points outside the boundaries - when `fill_mode = "constant"`. - horizontal_flip: boolean, randomly flip inputs horizontally. - vertical_flip: boolean, randomly flip inputs vertically. - rescale: rescaling factor. Defaults to None. If None or 0, no rescaling - is applied, otherwise we multiply the data by the value provided - (before applying any other transformation). - preprocessing_function: function that will be implied on each input. - The function will run after the image is resized and augmented. - The function should take one argument: - one image (Numpy tensor with rank 3), - and should output a Numpy tensor with the same shape. - data_format: One of {"channels_first", "channels_last"}. - "channels_last" mode means that the images should have shape - `(samples, height, width, channels)`, - "channels_first" mode means that the images should have shape - `(samples, channels, height, width)`. - It defaults to the `image_data_format` value found in your - Keras config file at `~/.keras/keras.json`. - If you never set it, then it will be "channels_last". - validation_split: float, fraction of images reserved for validation - (strictly between 0 and 1). - - Examples: - Example of using `.flow(x, y)`: - ```python - (x_train, y_train), (x_test, y_test) = cifar10.load_data() - y_train = np_utils.to_categorical(y_train, num_classes) - y_test = np_utils.to_categorical(y_test, num_classes) - datagen = ImageDataGenerator( - featurewise_center=True, - featurewise_std_normalization=True, - rotation_range=20, - width_shift_range=0.2, - height_shift_range=0.2, - horizontal_flip=True) - # compute quantities required for featurewise normalization - # (std, mean, and principal components if ZCA whitening is applied) - datagen.fit(x_train) - # fits the model on batches with real-time data augmentation: - model.fit_generator(datagen.flow(x_train, y_train, batch_size=32), - steps_per_epoch=len(x_train) / 32, epochs=epochs) - # here's a more "manual" example - for e in range(epochs): - print('Epoch', e) - batches = 0 - for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32): - model.fit(x_batch, y_batch) - batches += 1 - if batches >= len(x_train) / 32: - # we need to break the loop by hand because - # the generator loops indefinitely - break - ``` - Example of using `.flow_from_directory(directory)`: - ```python - train_datagen = ImageDataGenerator( - rescale=1./255, - shear_range=0.2, - zoom_range=0.2, - horizontal_flip=True) - test_datagen = ImageDataGenerator(rescale=1./255) - train_generator = train_datagen.flow_from_directory( - 'data/train', - target_size=(150, 150), - batch_size=32, - class_mode='binary') - validation_generator = test_datagen.flow_from_directory( - 'data/validation', - target_size=(150, 150), - batch_size=32, - class_mode='binary') - model.fit_generator( - train_generator, - steps_per_epoch=2000, - epochs=50, - validation_data=validation_generator, - validation_steps=800) - ``` - Example of transforming images and masks together. - ```python - # we create two instances with the same arguments - data_gen_args = dict(featurewise_center=True, - featurewise_std_normalization=True, - rotation_range=90., - width_shift_range=0.1, - height_shift_range=0.1, - zoom_range=0.2) - image_datagen = ImageDataGenerator(**data_gen_args) - mask_datagen = ImageDataGenerator(**data_gen_args) - # Provide the same seed and keyword arguments to the fit and flow methods - seed = 1 - image_datagen.fit(images, augment=True, seed=seed) - mask_datagen.fit(masks, augment=True, seed=seed) - image_generator = image_datagen.flow_from_directory( - 'data/images', - class_mode=None, - seed=seed) - mask_generator = mask_datagen.flow_from_directory( - 'data/masks', - class_mode=None, - seed=seed) - # combine generators into one which yields image and masks - train_generator = zip(image_generator, mask_generator) - model.fit_generator( - train_generator, - steps_per_epoch=2000, - epochs=50) - ``` - """ - - def __init__(self, - featurewise_center=False, - samplewise_center=False, - featurewise_std_normalization=False, - samplewise_std_normalization=False, - zca_whitening=False, - zca_epsilon=1e-6, - rotation_range=0., - width_shift_range=0., - height_shift_range=0., - brightness_range=None, - shear_range=0., - zoom_range=0., - channel_shift_range=0., - fill_mode='nearest', - cval=0., - horizontal_flip=False, - vertical_flip=False, - rescale=None, - preprocessing_function=None, - data_format=None, - validation_split=0.0): - if data_format is None: - data_format = K.image_data_format() - self.featurewise_center = featurewise_center - self.samplewise_center = samplewise_center - self.featurewise_std_normalization = featurewise_std_normalization - self.samplewise_std_normalization = samplewise_std_normalization - self.zca_whitening = zca_whitening - self.zca_epsilon = zca_epsilon - self.rotation_range = rotation_range - self.width_shift_range = width_shift_range - self.height_shift_range = height_shift_range - self.brightness_range = brightness_range - self.shear_range = shear_range - self.zoom_range = zoom_range - self.channel_shift_range = channel_shift_range - self.fill_mode = fill_mode - self.cval = cval - self.horizontal_flip = horizontal_flip - self.vertical_flip = vertical_flip - self.rescale = rescale - self.preprocessing_function = preprocessing_function - - if data_format not in {'channels_last', 'channels_first'}: - raise ValueError( - '`data_format` should be `"channels_last"` (channel after row and ' - 'column) or `"channels_first"` (channel before row and column). ' - 'Received arg: ', data_format) - self.data_format = data_format - if data_format == 'channels_first': - self.channel_axis = 1 - self.row_axis = 2 - self.col_axis = 3 - if data_format == 'channels_last': - self.channel_axis = 3 - self.row_axis = 1 - self.col_axis = 2 - if validation_split and not 0 < validation_split < 1: - raise ValueError('`validation_split` must be strictly between 0 and 1. ' - 'Received arg: ', validation_split) - self.validation_split = validation_split - - self.mean = None - self.std = None - self.principal_components = None - - if np.isscalar(zoom_range): - self.zoom_range = [1 - zoom_range, 1 + zoom_range] - elif len(zoom_range) == 2: - self.zoom_range = [zoom_range[0], zoom_range[1]] - else: - raise ValueError('`zoom_range` should be a float or ' - 'a tuple or list of two floats. ' - 'Received arg: ', zoom_range) - if zca_whitening: - if not featurewise_center: - self.featurewise_center = True - logging.warning('This ImageDataGenerator specifies ' - '`zca_whitening`, which overrides ' - 'setting of `featurewise_center`.') - if featurewise_std_normalization: - self.featurewise_std_normalization = False - logging.warning('This ImageDataGenerator specifies ' - '`zca_whitening` ' - 'which overrides setting of' - '`featurewise_std_normalization`.') - if featurewise_std_normalization: - if not featurewise_center: - self.featurewise_center = True - logging.warning('This ImageDataGenerator specifies ' - '`featurewise_std_normalization`, ' - 'which overrides setting of ' - '`featurewise_center`.') - if samplewise_std_normalization: - if not samplewise_center: - self.samplewise_center = True - logging.warning('This ImageDataGenerator specifies ' - '`samplewise_std_normalization`, ' - 'which overrides setting of ' - '`samplewise_center`.') - - def flow(self, - x, - y=None, - batch_size=32, - shuffle=True, - seed=None, - save_to_dir=None, - save_prefix='', - save_format='png', - subset=None): - """Generates batches of augmented/normalized data with given numpy arrays. - - Arguments: - x: data. Should have rank 4. - In case of grayscale data, the channels axis should have value 1 - and in case of RGB data, it should have value 3. - y: labels. - batch_size: int (default: 32). - shuffle: boolean (default: True). - seed: int (default: None). - save_to_dir: None or str (default: None). - This allows you to optionally specify a directory - to which to save the augmented pictures being generated - (useful for visualizing what you are doing). - save_prefix: str (default: `''`). Prefix to use for filenames of - saved pictures (only relevant if `save_to_dir` is set). - save_format: one of "png", "jpeg". Default: "png". - (only relevant if `save_to_dir` is set) - subset: Subset of data (`"training"` or `"validation"`) if - `validation_split` is set in `ImageDataGenerator`. - - Returns: - An Iterator yielding tuples of `(x, y)` where `x` is a numpy array of - image data and `y` is a numpy array of corresponding labels. - """ - return NumpyArrayIterator( - x, - y, - self, - batch_size=batch_size, - shuffle=shuffle, - seed=seed, - data_format=self.data_format, - save_to_dir=save_to_dir, - save_prefix=save_prefix, - save_format=save_format, - subset=subset) - - def flow_from_directory(self, - directory, - target_size=(256, 256), - color_mode='rgb', - classes=None, - class_mode='categorical', - batch_size=32, - shuffle=True, - seed=None, - save_to_dir=None, - save_prefix='', - save_format='png', - follow_links=False, - subset=None, - interpolation='nearest'): - """Generates batches of augmented/normalized data given directory path. - - Arguments: - directory: path to the target directory. It should contain one - subdirectory per class. Any PNG, JPG, BMP, PPM or TIF images - inside each of the subdirectories directory tree will be included - in the generator. See [this script] - (https://gist.github.com/fchollet/0830affa1f7f19fd47b06d4cf89ed44d) - for more details. - target_size: tuple of integers `(height, width)`, default: `(256, - 256)`. The dimensions to which all images found will be resized. - color_mode: one of "grayscale", "rbg". Default: "rgb". Whether the - images will be converted to have 1 or 3 color channels. - classes: optional list of class subdirectories (e.g. `['dogs', - 'cats']`). Default: None. If not provided, the list of classes - will be automatically inferred from the subdirectory - names/structure under `directory`, where each subdirectory will be - treated as a different class (and the order of the classes, which - will map to the label indices, will be alphanumeric). The - dictionary containing the mapping from class names to class - indices can be obtained via the attribute `class_indices`. - class_mode: one of "categorical", "binary", "sparse", "input" or - None. Default: "categorical". Determines the type of label arrays - that are returned: "categorical" will be 2D one-hot encoded - labels, "binary" will be 1D binary labels, "sparse" will be 1D - integer labels, "input" will be images identical to input images - (mainly used to work with autoencoders). If None, no labels are - returned (the generator will only yield batches of image data, - which is useful to use `model.predict_generator()`, - `model.evaluate_generator()`, etc.). Please note that in case of - class_mode None, the data still needs to reside in a subdirectory - of `directory` for it to work correctly. - batch_size: size of the batches of data (default: 32). - shuffle: whether to shuffle the data (default: True) - seed: optional random seed for shuffling and transformations. - save_to_dir: None or str (default: None). This allows you to - optionally specify a directory to which to save the augmented - pictures being generated (useful for visualizing what you are doing) - save_prefix: str. Prefix to use for filenames of saved pictures - (only relevant if `save_to_dir` is set). - save_format: one of "png", "jpeg" (only relevant if `save_to_dir` is - set). Default: "png". - follow_links: whether to follow symlinks inside class subdirectories - (default: False). - subset: Subset of data (`"training"` or `"validation"`) if - ` validation_split` is set in `ImageDataGenerator`. - interpolation: Interpolation method used to resample the image if - the target size is different from that of the loaded image. - Supported methods are `"nearest"`, `"bilinear"`, and `"bicubic"`. - If PIL version 1.1.3 or newer is installed, `"lanczos"` is also - supported. If PIL version 3.4.0 or newer is installed, `"box"` and - `"hamming"` are also supported. By default, `"nearest"` is used. - - Returns: - A DirectoryIterator yielding tuples of `(x, y)` where `x` is a - numpy array containing a batch of images with shape - `(batch_size, *target_size, channels)` and `y` is a numpy - array of corresponding labels. - """ - return DirectoryIterator( - directory, - self, - target_size=target_size, - color_mode=color_mode, - classes=classes, - class_mode=class_mode, - data_format=self.data_format, - batch_size=batch_size, - shuffle=shuffle, - seed=seed, - save_to_dir=save_to_dir, - save_prefix=save_prefix, - save_format=save_format, - follow_links=follow_links, - subset=subset, - interpolation=interpolation) - - def standardize(self, x): - """Apply the normalization configuration to a batch of inputs. - - Arguments: - x: batch of inputs to be normalized. - - Returns: - The inputs, normalized. - """ - if self.preprocessing_function: - x = self.preprocessing_function(x) - if self.rescale: - x *= self.rescale - if self.samplewise_center: - x -= np.mean(x, keepdims=True) - if self.samplewise_std_normalization: - x /= (np.std(x, keepdims=True) + K.epsilon()) + pass - if self.featurewise_center: - if self.mean is not None: - x -= self.mean - else: - logging.warning('This ImageDataGenerator specifies ' - '`featurewise_center`, but it hasn\'t ' - 'been fit on any training data. Fit it ' - 'first by calling `.fit(numpy_data)`.') - if self.featurewise_std_normalization: - if self.std is not None: - x /= (self.std + K.epsilon()) - else: - logging.warning('This ImageDataGenerator specifies ' - '`featurewise_std_normalization`, but it hasn\'t ' - 'been fit on any training data. Fit it ' - 'first by calling `.fit(numpy_data)`.') - if self.zca_whitening: - if self.principal_components is not None: - flatx = np.reshape(x, (-1, np.prod(x.shape[-3:]))) - whitex = np.dot(flatx, self.principal_components) - x = np.reshape(whitex, x.shape) - else: - logging.warning('This ImageDataGenerator specifies ' - '`zca_whitening`, but it hasn\'t ' - 'been fit on any training data. Fit it ' - 'first by calling `.fit(numpy_data)`.') - return x - - def random_transform(self, x, seed=None): - """Randomly augment a single image tensor. - - Arguments: - x: 3D tensor, single image. - seed: random seed. - - Returns: - A randomly transformed version of the input (same shape). - - Raises: - ImportError: if Scipy is not available. - """ - if ndi is None: - raise ImportError('Scipy is required for image transformations.') - # x is a single image, so it doesn't have image number at index 0 - img_row_axis = self.row_axis - 1 - img_col_axis = self.col_axis - 1 - img_channel_axis = self.channel_axis - 1 - - if seed is not None: - np.random.seed(seed) - - # use composition of homographies - # to generate final transform that needs to be applied - if self.rotation_range: - theta = np.deg2rad( - np.random.uniform(-self.rotation_range, self.rotation_range)) - else: - theta = 0 - - if self.height_shift_range: - try: # 1-D array-like or int - tx = np.random.choice(self.height_shift_range) - tx *= np.random.choice([-1, 1]) - except ValueError: # floating point - tx = np.random.uniform(-self.height_shift_range, - self.height_shift_range) - if np.max(self.height_shift_range) < 1: - tx *= x.shape[img_row_axis] - else: - tx = 0 - - if self.width_shift_range: - try: # 1-D array-like or int - ty = np.random.choice(self.width_shift_range) - ty *= np.random.choice([-1, 1]) - except ValueError: # floating point - ty = np.random.uniform(-self.width_shift_range, self.width_shift_range) - if np.max(self.width_shift_range) < 1: - ty *= x.shape[img_col_axis] - else: - ty = 0 - - if self.shear_range: - shear = np.deg2rad(np.random.uniform(-self.shear_range, self.shear_range)) - else: - shear = 0 - - if self.zoom_range[0] == 1 and self.zoom_range[1] == 1: - zx, zy = 1, 1 - else: - zx, zy = np.random.uniform(self.zoom_range[0], self.zoom_range[1], 2) - - transform_matrix = None - if theta != 0: - rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], - [np.sin(theta), - np.cos(theta), 0], [0, 0, 1]]) - transform_matrix = rotation_matrix - - if tx != 0 or ty != 0: - shift_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]]) - transform_matrix = shift_matrix if transform_matrix is None else np.dot( - transform_matrix, shift_matrix) - - if shear != 0: - shear_matrix = np.array([[1, -np.sin(shear), 0], [0, np.cos(shear), 0], - [0, 0, 1]]) - transform_matrix = shear_matrix if transform_matrix is None else np.dot( - transform_matrix, shear_matrix) - - if zx != 1 or zy != 1: - zoom_matrix = np.array([[zx, 0, 0], [0, zy, 0], [0, 0, 1]]) - transform_matrix = zoom_matrix if transform_matrix is None else np.dot( - transform_matrix, zoom_matrix) - - if transform_matrix is not None: - h, w = x.shape[img_row_axis], x.shape[img_col_axis] - transform_matrix = transform_matrix_offset_center(transform_matrix, h, w) - x = apply_transform( - x, - transform_matrix, - img_channel_axis, - fill_mode=self.fill_mode, - cval=self.cval) - - if self.channel_shift_range != 0: - x = random_channel_shift(x, self.channel_shift_range, img_channel_axis) - if self.horizontal_flip: - if np.random.random() < 0.5: - x = flip_axis(x, img_col_axis) - - if self.vertical_flip: - if np.random.random() < 0.5: - x = flip_axis(x, img_row_axis) - - if self.brightness_range is not None: - x = random_brightness(x, self.brightness_range) - - return x - - def fit(self, x, augment=False, rounds=1, seed=None): - """Computes the internal data statistics based on an array of sample data. - - These are statistics related to the data-dependent transformations. - Only required if featurewise_center or featurewise_std_normalization or - zca_whitening. - - Arguments: - x: sample data. Should have rank 4. - In case of grayscale data, the channels axis should have value 1 - and in case of RGB data, it should have value 3. - augment: Boolean (default: False). Whether to fit on randomly - augmented samples. - rounds: int (default: 1). If augment, how many augmentation passes - over the data to use. - seed: int (default: None). Random seed. - - Raises: - ValueError: If input rank is not 4. - ImportError: If scipy is not imported. - """ - x = np.asarray(x, dtype=K.floatx()) - if x.ndim != 4: - raise ValueError('Input to `.fit()` should have rank 4. ' - 'Got array with shape: ' + str(x.shape)) - if x.shape[self.channel_axis] not in {1, 3, 4}: - logging.warning( - 'Expected input to be images (as Numpy array) ' - 'following the data format convention "' + self.data_format + '" ' - '(channels on axis ' + str(self.channel_axis) + '), i.e. expected ' - 'either 1, 3 or 4 channels on axis ' + str(self.channel_axis) + '. ' - 'However, it was passed an array with shape ' + str(x.shape) + ' (' + - str(x.shape[self.channel_axis]) + ' channels).') - - if seed is not None: - np.random.seed(seed) - - x = np.copy(x) - if augment: - ax = np.zeros( - tuple([rounds * x.shape[0]] + list(x.shape)[1:]), dtype=K.floatx()) - for r in range(rounds): - for i in range(x.shape[0]): - ax[i + r * x.shape[0]] = self.random_transform(x[i]) - x = ax - - if self.featurewise_center: - self.mean = np.mean(x, axis=(0, self.row_axis, self.col_axis)) - broadcast_shape = [1, 1, 1] - broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis] - self.mean = np.reshape(self.mean, broadcast_shape) - x -= self.mean - - if self.featurewise_std_normalization: - self.std = np.std(x, axis=(0, self.row_axis, self.col_axis)) - broadcast_shape = [1, 1, 1] - broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis] - self.std = np.reshape(self.std, broadcast_shape) - x /= (self.std + K.epsilon()) - - if self.zca_whitening: - if linalg is None: - raise ImportError('Scipy is required for zca_whitening.') - - flat_x = np.reshape(x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])) - sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0] - u, s, _ = linalg.svd(sigma) - s_inv = 1. / np.sqrt(s[np.newaxis] + self.zca_epsilon) - self.principal_components = (u * s_inv).dot(u.T) - - -@tf_export('keras.preprocessing.image.Iterator') -class Iterator(Sequence): - """Base class for image data iterators. - - Every `Iterator` must implement the `_get_batches_of_transformed_samples` - method. - - Arguments: - n: Integer, total number of samples in the dataset to loop over. - batch_size: Integer, size of a batch. - shuffle: Boolean, whether to shuffle the data between epochs. - seed: Random seeding for data shuffling. - """ - - def __init__(self, n, batch_size, shuffle, seed): - self.n = n - self.batch_size = batch_size - self.seed = seed - self.shuffle = shuffle - self.batch_index = 0 - self.total_batches_seen = 0 - self.lock = threading.Lock() - self.index_array = None - self.index_generator = self._flow_index() - - def _set_index_array(self): - self.index_array = np.arange(self.n) - if self.shuffle: - self.index_array = np.random.permutation(self.n) - - def __getitem__(self, idx): - if idx >= len(self): - raise ValueError('Asked to retrieve element {idx}, ' - 'but the Sequence ' - 'has length {length}'.format(idx=idx, length=len(self))) - if self.seed is not None: - np.random.seed(self.seed + self.total_batches_seen) - self.total_batches_seen += 1 - if self.index_array is None: - self._set_index_array() - index_array = self.index_array[self.batch_size * idx:self.batch_size * ( - idx + 1)] - return self._get_batches_of_transformed_samples(index_array) - - def __len__(self): - return (self.n + self.batch_size - 1) // self.batch_size # round up - - def on_epoch_end(self): - self._set_index_array() - - def reset(self): - self.batch_index = 0 - - def _flow_index(self): - # Ensure self.batch_index is 0. - self.reset() - while 1: - if self.seed is not None: - np.random.seed(self.seed + self.total_batches_seen) - if self.batch_index == 0: - self._set_index_array() - - current_index = (self.batch_index * self.batch_size) % self.n - if self.n > current_index + self.batch_size: - self.batch_index += 1 - else: - self.batch_index = 0 - self.total_batches_seen += 1 - yield self.index_array[current_index:current_index + self.batch_size] - - def __iter__(self): # pylint: disable=non-iterator-returned - # Needed if we want to do something like: - # for x, y in data_gen.flow(...): - return self - - def __next__(self, *args, **kwargs): - return self.next(*args, **kwargs) - - def _get_batches_of_transformed_samples(self, index_array): - """Gets a batch of transformed samples. - - Arguments: - index_array: array of sample indices to include in batch. - - Returns: - A batch of transformed samples. - """ - raise NotImplementedError - - -@tf_export('keras.preprocessing.image.NumpyArrayIterator') -class NumpyArrayIterator(Iterator): - """Iterator yielding data from a Numpy array. - - Arguments: - x: Numpy array of input data. - y: Numpy array of targets data. - image_data_generator: Instance of `ImageDataGenerator` - to use for random transformations and normalization. - batch_size: Integer, size of a batch. - shuffle: Boolean, whether to shuffle the data between epochs. - seed: Random seed for data shuffling. - data_format: String, one of `channels_first`, `channels_last`. - save_to_dir: Optional directory where to save the pictures - being yielded, in a viewable format. This is useful - for visualizing the random transformations being - applied, for debugging purposes. - save_prefix: String prefix to use for saving sample - images (if `save_to_dir` is set). - save_format: Format to use for saving sample images - (if `save_to_dir` is set). - subset: Subset of data (`"training"` or `"validation"`) if - validation_split is set in ImageDataGenerator. - """ - - def __init__(self, - x, - y, - image_data_generator, - batch_size=32, - shuffle=False, - seed=None, - data_format=None, - save_to_dir=None, - save_prefix='', - save_format='png', - subset=None): - if y is not None and len(x) != len(y): - raise ValueError('`x` (images tensor) and `y` (labels) ' - 'should have the same length. ' - 'Found: x.shape = %s, y.shape = %s' % - (np.asarray(x).shape, np.asarray(y).shape)) - if subset is not None: - if subset not in {'training', 'validation'}: - raise ValueError('Invalid subset name:', subset, - '; expected "training" or "validation".') - split_idx = int(len(x) * image_data_generator.validation_split) - if subset == 'validation': - x = x[:split_idx] - if y is not None: - y = y[:split_idx] - else: - x = x[split_idx:] - if y is not None: - y = y[split_idx:] - if data_format is None: - data_format = K.image_data_format() - self.x = np.asarray(x, dtype=K.floatx()) - if self.x.ndim != 4: - raise ValueError('Input data in `NumpyArrayIterator` ' - 'should have rank 4. You passed an array ' - 'with shape', self.x.shape) - channels_axis = 3 if data_format == 'channels_last' else 1 - if self.x.shape[channels_axis] not in {1, 3, 4}: - logging.warning( - 'NumpyArrayIterator is set to use the ' - 'data format convention "' + data_format + '" ' - '(channels on axis ' + str(channels_axis) + '), i.e. expected ' - 'either 1, 3 or 4 channels on axis ' + str(channels_axis) + '. ' - 'However, it was passed an array with shape ' + str(self.x.shape) + - ' (' + str(self.x.shape[channels_axis]) + ' channels).') - if y is not None: - self.y = np.asarray(y) - else: - self.y = None - self.image_data_generator = image_data_generator - self.data_format = data_format - self.save_to_dir = save_to_dir - self.save_prefix = save_prefix - self.save_format = save_format - super(NumpyArrayIterator, self).__init__(x.shape[0], batch_size, shuffle, - seed) - - def _get_batches_of_transformed_samples(self, index_array): - batch_x = np.zeros( - tuple([len(index_array)] + list(self.x.shape)[1:]), dtype=K.floatx()) - for i, j in enumerate(index_array): - x = self.x[j] - x = self.image_data_generator.random_transform(x.astype(K.floatx())) - x = self.image_data_generator.standardize(x) - batch_x[i] = x - if self.save_to_dir: - for i, j in enumerate(index_array): - img = array_to_img(batch_x[i], self.data_format, scale=True) - fname = '{prefix}_{index}_{hash}.{format}'.format( - prefix=self.save_prefix, - index=j, - hash=np.random.randint(1e4), - format=self.save_format) - img.save(os.path.join(self.save_to_dir, fname)) - if self.y is None: - return batch_x - batch_y = self.y[index_array] - return batch_x, batch_y - - def next(self): - """For python 2.x. - - Returns: - The next batch. - """ - # Keeps under lock only the mechanism which advances - # the indexing of each batch. - with self.lock: - index_array = next(self.index_generator) - # The transformation of images is not under thread lock - # so it can be done in parallel - return self._get_batches_of_transformed_samples(index_array) - - -def _iter_valid_files(directory, white_list_formats, follow_links): - """Count files with extension in `white_list_formats` contained in directory. - - Arguments: - directory: absolute path to the directory - containing files to be counted - white_list_formats: set of strings containing allowed extensions for - the files to be counted. - follow_links: boolean. - - Yields: - tuple of (root, filename) with extension in `white_list_formats`. - """ - - def _recursive_list(subpath): - return sorted( - os.walk(subpath, followlinks=follow_links), key=lambda x: x[0]) - - for root, _, files in _recursive_list(directory): - for fname in sorted(files): - for extension in white_list_formats: - if fname.lower().endswith('.tiff'): - logging.warning( - 'Using \'.tiff\' files with multiple bands will cause ' - 'distortion. Please verify your output.') - if fname.lower().endswith('.' + extension): - yield root, fname - - -def _count_valid_files_in_directory(directory, white_list_formats, split, - follow_links): - """Count files with extension in `white_list_formats` contained in directory. - - Arguments: - directory: absolute path to the directory - containing files to be counted - white_list_formats: set of strings containing allowed extensions for - the files to be counted. - split: tuple of floats (e.g. `(0.2, 0.6)`) to only take into - account a certain fraction of files in each directory. - E.g.: `segment=(0.6, 1.0)` would only account for last 40 percent - of images in each directory. - follow_links: boolean. - - Returns: - the count of files with extension in `white_list_formats` contained in - the directory. - """ - num_files = len( - list(_iter_valid_files(directory, white_list_formats, follow_links))) - if split: - start, stop = int(split[0] * num_files), int(split[1] * num_files) - else: - start, stop = 0, num_files - return stop - start - - -def _list_valid_filenames_in_directory(directory, white_list_formats, split, - class_indices, follow_links): - """List paths of files in `subdir` with extensions in `white_list_formats`. - - Arguments: - directory: absolute path to a directory containing the files to list. - The directory name is used as class label and must be a key of - `class_indices`. - white_list_formats: set of strings containing allowed extensions for - the files to be counted. - split: tuple of floats (e.g. `(0.2, 0.6)`) to only take into - account a certain fraction of files in each directory. - E.g.: `segment=(0.6, 1.0)` would only account for last 40 percent - of images in each directory. - class_indices: dictionary mapping a class name to its index. - follow_links: boolean. - - Returns: - classes: a list of class indices - filenames: the path of valid files in `directory`, relative from - `directory`'s parent (e.g., if `directory` is "dataset/class1", - the filenames will be ["class1/file1.jpg", "class1/file2.jpg", ...]). - """ - dirname = os.path.basename(directory) - if split: - num_files = len( - list(_iter_valid_files(directory, white_list_formats, follow_links))) - start, stop = int(split[0] * num_files), int(split[1] * num_files) - valid_files = list( - _iter_valid_files(directory, white_list_formats, - follow_links))[start:stop] - else: - valid_files = _iter_valid_files(directory, white_list_formats, follow_links) - - classes = [] - filenames = [] - for root, fname in valid_files: - classes.append(class_indices[dirname]) - absolute_path = os.path.join(root, fname) - relative_path = os.path.join(dirname, - os.path.relpath(absolute_path, directory)) - filenames.append(relative_path) - - return classes, filenames - - -@tf_export('keras.preprocessing.image.DirectoryIterator') -class DirectoryIterator(Iterator): - """Iterator capable of reading images from a directory on disk. - - Arguments: - directory: Path to the directory to read images from. - Each subdirectory in this directory will be - considered to contain images from one class, - or alternatively you could specify class subdirectories - via the `classes` argument. - image_data_generator: Instance of `ImageDataGenerator` - to use for random transformations and normalization. - target_size: tuple of integers, dimensions to resize input images to. - color_mode: One of `"rgb"`, `"grayscale"`. Color mode to read images. - classes: Optional list of strings, names of subdirectories - containing images from each class (e.g. `["dogs", "cats"]`). - It will be computed automatically if not set. - class_mode: Mode for yielding the targets: - `"binary"`: binary targets (if there are only two classes), - `"categorical"`: categorical targets, - `"sparse"`: integer targets, - `"input"`: targets are images identical to input images (mainly - used to work with autoencoders), - `None`: no targets get yielded (only input images are yielded). - batch_size: Integer, size of a batch. - shuffle: Boolean, whether to shuffle the data between epochs. - seed: Random seed for data shuffling. - data_format: String, one of `channels_first`, `channels_last`. - save_to_dir: Optional directory where to save the pictures - being yielded, in a viewable format. This is useful - for visualizing the random transformations being - applied, for debugging purposes. - save_prefix: String prefix to use for saving sample - images (if `save_to_dir` is set). - save_format: Format to use for saving sample images - (if `save_to_dir` is set). - subset: Subset of data (`"training"` or `"validation"`) if - validation_split is set in ImageDataGenerator. - interpolation: Interpolation method used to resample the image if the - target size is different from that of the loaded image. - Supported methods are "nearest", "bilinear", and "bicubic". - If PIL version 1.1.3 or newer is installed, "lanczos" is also - supported. If PIL version 3.4.0 or newer is installed, "box" and - "hamming" are also supported. By default, "nearest" is used. - """ - - def __init__(self, - directory, - image_data_generator, - target_size=(256, 256), - color_mode='rgb', - classes=None, - class_mode='categorical', - batch_size=32, - shuffle=True, - seed=None, - data_format=None, - save_to_dir=None, - save_prefix='', - save_format='png', - follow_links=False, - subset=None, - interpolation='nearest'): - if data_format is None: - data_format = K.image_data_format() - self.directory = directory - self.image_data_generator = image_data_generator - self.target_size = tuple(target_size) - if color_mode not in {'rgb', 'grayscale'}: - raise ValueError('Invalid color mode:', color_mode, - '; expected "rgb" or "grayscale".') - self.color_mode = color_mode - self.data_format = data_format - if self.color_mode == 'rgb': - if self.data_format == 'channels_last': - self.image_shape = self.target_size + (3,) - else: - self.image_shape = (3,) + self.target_size - else: - if self.data_format == 'channels_last': - self.image_shape = self.target_size + (1,) - else: - self.image_shape = (1,) + self.target_size - self.classes = classes - if class_mode not in {'categorical', 'binary', 'sparse', 'input', None}: - raise ValueError('Invalid class_mode:', class_mode, - '; expected one of "categorical", ' - '"binary", "sparse", "input"' - ' or None.') - self.class_mode = class_mode - self.save_to_dir = save_to_dir - self.save_prefix = save_prefix - self.save_format = save_format - self.interpolation = interpolation - - if subset is not None: - validation_split = self.image_data_generator.validation_split - if subset == 'validation': - split = (0, validation_split) - elif subset == 'training': - split = (validation_split, 1) - else: - raise ValueError('Invalid subset name: ', subset, - '; expected "training" or "validation"') - else: - split = None - self.subset = subset - - white_list_formats = {'png', 'jpg', 'jpeg', 'bmp', 'ppm', 'tif', 'tiff'} - - # first, count the number of samples and classes - self.samples = 0 - - if not classes: - classes = [] - for subdir in sorted(os.listdir(directory)): - if os.path.isdir(os.path.join(directory, subdir)): - classes.append(subdir) - self.num_classes = len(classes) - self.class_indices = dict(zip(classes, range(len(classes)))) - - pool = multiprocessing.pool.ThreadPool() - function_partial = partial( - _count_valid_files_in_directory, - white_list_formats=white_list_formats, - follow_links=follow_links, - split=split) - self.samples = sum( - pool.map(function_partial, - (os.path.join(directory, subdir) for subdir in classes))) - - print('Found %d images belonging to %d classes.' % (self.samples, - self.num_classes)) - - # second, build an index of the images in the different class subfolders - results = [] - - self.filenames = [] - self.classes = np.zeros((self.samples,), dtype='int32') - i = 0 - for dirpath in (os.path.join(directory, subdir) for subdir in classes): - results.append( - pool.apply_async(_list_valid_filenames_in_directory, - (dirpath, white_list_formats, split, - self.class_indices, follow_links))) - for res in results: - classes, filenames = res.get() - self.classes[i:i + len(classes)] = classes - self.filenames += filenames - i += len(classes) - - pool.close() - pool.join() - super(DirectoryIterator, self).__init__(self.samples, batch_size, shuffle, - seed) - - def _get_batches_of_transformed_samples(self, index_array): - batch_x = np.zeros((len(index_array),) + self.image_shape, dtype=K.floatx()) - grayscale = self.color_mode == 'grayscale' - # build batch of image data - for i, j in enumerate(index_array): - fname = self.filenames[j] - img = load_img( - os.path.join(self.directory, fname), - grayscale=grayscale, - target_size=self.target_size, - interpolation=self.interpolation) - x = img_to_array(img, data_format=self.data_format) - x = self.image_data_generator.random_transform(x) - x = self.image_data_generator.standardize(x) - batch_x[i] = x - # optionally save augmented images to disk for debugging purposes - if self.save_to_dir: - for i, j in enumerate(index_array): - img = array_to_img(batch_x[i], self.data_format, scale=True) - fname = '{prefix}_{index}_{hash}.{format}'.format( - prefix=self.save_prefix, - index=j, - hash=np.random.randint(1e7), - format=self.save_format) - img.save(os.path.join(self.save_to_dir, fname)) - # build batch of labels - if self.class_mode == 'input': - batch_y = batch_x.copy() - elif self.class_mode == 'sparse': - batch_y = self.classes[index_array] - elif self.class_mode == 'binary': - batch_y = self.classes[index_array].astype(K.floatx()) - elif self.class_mode == 'categorical': - batch_y = np.zeros((len(batch_x), self.num_classes), dtype=K.floatx()) - for i, label in enumerate(self.classes[index_array]): - batch_y[i, label] = 1. - else: - return batch_x - return batch_x, batch_y - - def next(self): - """For python 2.x. +from tensorflow.python.util.tf_export import tf_export - Returns: - The next batch. - """ - with self.lock: - index_array = next(self.index_generator) - # The transformation of images is not under thread lock - # so it can be done in parallel - return self._get_batches_of_transformed_samples(index_array) +random_rotation = image.random_rotation +random_shift = image.random_shift +random_shear = image.random_shear +random_zoom = image.random_zoom +apply_channel_shift = image.apply_channel_shift +random_channel_shift = image.random_channel_shift +apply_brightness_shift = image.apply_brightness_shift +random_brightness = image.random_brightness +apply_affine_transform = image.apply_affine_transform +array_to_img = image.array_to_img +img_to_array = image.img_to_array +save_img = image.save_img +load_img = image.load_img +ImageDataGenerator = image.ImageDataGenerator +Iterator = image.Iterator +NumpyArrayIterator = image.NumpyArrayIterator +DirectoryIterator = image.DirectoryIterator + +tf_export('keras.preprocessing.image.random_rotation')(random_rotation) +tf_export('keras.preprocessing.image.random_shift')(random_shift) +tf_export('keras.preprocessing.image.random_shear')(random_shear) +tf_export('keras.preprocessing.image.random_zoom')(random_zoom) +tf_export('keras.preprocessing.image.apply_channel_shift')(apply_channel_shift) +tf_export( + 'keras.preprocessing.image.random_channel_shift')(random_channel_shift) +tf_export( + 'keras.preprocessing.image.apply_brightness_shift')(apply_brightness_shift) +tf_export('keras.preprocessing.image.random_brightness')(random_brightness) +tf_export( + 'keras.preprocessing.image.apply_affine_transform')(apply_affine_transform) +tf_export('keras.preprocessing.image.array_to_img')(array_to_img) +tf_export('keras.preprocessing.image.img_to_array')(img_to_array) +tf_export('keras.preprocessing.image.save_img')(save_img) +tf_export('keras.preprocessing.image.load_img')(load_img) +tf_export('keras.preprocessing.image.ImageDataGenerator')(ImageDataGenerator) +tf_export('keras.preprocessing.image.Iterator')(Iterator) +tf_export('keras.preprocessing.image.NumpyArrayIterator')(NumpyArrayIterator) +tf_export('keras.preprocessing.image.DirectoryIterator')(DirectoryIterator) diff --git a/tensorflow/python/keras/preprocessing/image_test.py b/tensorflow/python/keras/preprocessing/image_test.py index 275808a6155b26159259584653cb48697af9f318..362cbc1dc9bb2b769c30553b042fc6dde3b23d96 100644 --- a/tensorflow/python/keras/preprocessing/image_test.py +++ b/tensorflow/python/keras/preprocessing/image_test.py @@ -161,9 +161,6 @@ class TestImage(test.TestCase): generator = keras.preprocessing.image.ImageDataGenerator( zoom_range=(2, 2)) - with self.assertRaises(ValueError): - generator = keras.preprocessing.image.ImageDataGenerator( - zoom_range=(2, 2, 2)) def test_image_data_generator_fit(self): generator = keras.preprocessing.image.ImageDataGenerator( diff --git a/tensorflow/python/keras/preprocessing/sequence.py b/tensorflow/python/keras/preprocessing/sequence.py index e0924f837a79dbdf31bee09667b43f70a1273b4b..116d3108d90cdf4aa455c5f25891da51610ff6cc 100644 --- a/tensorflow/python/keras/preprocessing/sequence.py +++ b/tensorflow/python/keras/preprocessing/sequence.py @@ -14,383 +14,25 @@ # ============================================================================== """Utilities for preprocessing sequence data. """ +# pylint: disable=invalid-name from __future__ import absolute_import from __future__ import division from __future__ import print_function -import random +from keras_preprocessing import sequence -import numpy as np -from six.moves import range # pylint: disable=redefined-builtin - -from tensorflow.python.keras.utils.data_utils import Sequence from tensorflow.python.util.tf_export import tf_export - -@tf_export('keras.preprocessing.sequence.pad_sequences') -def pad_sequences(sequences, - maxlen=None, - dtype='int32', - padding='pre', - truncating='pre', - value=0.): - """Pads sequences to the same length. - - This function transforms a list of - `num_samples` sequences (lists of integers) - into a 2D Numpy array of shape `(num_samples, num_timesteps)`. - `num_timesteps` is either the `maxlen` argument if provided, - or the length of the longest sequence otherwise. - - Sequences that are shorter than `num_timesteps` - are padded with `value` at the end. - - Sequences longer than `num_timesteps` are truncated - so that they fit the desired length. - The position where padding or truncation happens is determined by - the arguments `padding` and `truncating`, respectively. - - Pre-padding is the default. - - Arguments: - sequences: List of lists, where each element is a sequence. - maxlen: Int, maximum length of all sequences. - dtype: Type of the output sequences. - padding: String, 'pre' or 'post': - pad either before or after each sequence. - truncating: String, 'pre' or 'post': - remove values from sequences larger than - `maxlen`, either at the beginning or at the end of the sequences. - value: Float, padding value. - - Returns: - x: Numpy array with shape `(len(sequences), maxlen)` - - Raises: - ValueError: In case of invalid values for `truncating` or `padding`, - or in case of invalid shape for a `sequences` entry. - """ - if not hasattr(sequences, '__len__'): - raise ValueError('`sequences` must be iterable.') - lengths = [] - for x in sequences: - if not hasattr(x, '__len__'): - raise ValueError('`sequences` must be a list of iterables. ' - 'Found non-iterable: ' + str(x)) - lengths.append(len(x)) - - num_samples = len(sequences) - if maxlen is None: - maxlen = np.max(lengths) - - # take the sample shape from the first non empty sequence - # checking for consistency in the main loop below. - sample_shape = tuple() - for s in sequences: - if len(s) > 0: # pylint: disable=g-explicit-length-test - sample_shape = np.asarray(s).shape[1:] - break - - x = (np.ones((num_samples, maxlen) + sample_shape) * value).astype(dtype) - for idx, s in enumerate(sequences): - if not len(s): # pylint: disable=g-explicit-length-test - continue # empty list/array was found - if truncating == 'pre': - trunc = s[-maxlen:] # pylint: disable=invalid-unary-operand-type - elif truncating == 'post': - trunc = s[:maxlen] - else: - raise ValueError('Truncating type "%s" not understood' % truncating) - - # check `trunc` has expected shape - trunc = np.asarray(trunc, dtype=dtype) - if trunc.shape[1:] != sample_shape: - raise ValueError('Shape of sample %s of sequence at position %s ' - 'is different from expected shape %s' % - (trunc.shape[1:], idx, sample_shape)) - - if padding == 'post': - x[idx, :len(trunc)] = trunc - elif padding == 'pre': - x[idx, -len(trunc):] = trunc - else: - raise ValueError('Padding type "%s" not understood' % padding) - return x - - -@tf_export('keras.preprocessing.sequence.make_sampling_table') -def make_sampling_table(size, sampling_factor=1e-5): - """Generates a word rank-based probabilistic sampling table. - - Used for generating the `sampling_table` argument for `skipgrams`. - `sampling_table[i]` is the probability of sampling - the word i-th most common word in a dataset - (more common words should be sampled less frequently, for balance). - - The sampling probabilities are generated according - to the sampling distribution used in word2vec: - - `p(word) = min(1, sqrt(word_frequency / sampling_factor) / (word_frequency / - sampling_factor))` - - We assume that the word frequencies follow Zipf's law (s=1) to derive - a numerical approximation of frequency(rank): - - `frequency(rank) ~ 1/(rank * (log(rank) + gamma) + 1/2 - 1/(12*rank))` - where `gamma` is the Euler-Mascheroni constant. - - Arguments: - size: Int, number of possible words to sample. - sampling_factor: The sampling factor in the word2vec formula. - - Returns: - A 1D Numpy array of length `size` where the ith entry - is the probability that a word of rank i should be sampled. - """ - gamma = 0.577 - rank = np.arange(size) - rank[0] = 1 - inv_fq = rank * (np.log(rank) + gamma) + 0.5 - 1. / (12. * rank) - f = sampling_factor * inv_fq - - return np.minimum(1., f / np.sqrt(f)) - - -@tf_export('keras.preprocessing.sequence.skipgrams') -def skipgrams(sequence, - vocabulary_size, - window_size=4, - negative_samples=1., - shuffle=True, - categorical=False, - sampling_table=None, - seed=None): - """Generates skipgram word pairs. - - This function transforms a sequence of word indexes (list of integers) - into tuples of words of the form: - - - (word, word in the same window), with label 1 (positive samples). - - (word, random word from the vocabulary), with label 0 (negative samples). - - Read more about Skipgram in this gnomic paper by Mikolov et al.: - [Efficient Estimation of Word Representations in - Vector Space](http://arxiv.org/pdf/1301.3781v3.pdf) - - Arguments: - sequence: A word sequence (sentence), encoded as a list - of word indices (integers). If using a `sampling_table`, - word indices are expected to match the rank - of the words in a reference dataset (e.g. 10 would encode - the 10-th most frequently occurring token). - Note that index 0 is expected to be a non-word and will be skipped. - vocabulary_size: Int, maximum possible word index + 1 - window_size: Int, size of sampling windows (technically half-window). - The window of a word `w_i` will be - `[i - window_size, i + window_size+1]`. - negative_samples: Float >= 0. 0 for no negative (i.e. random) samples. - 1 for same number as positive samples. - shuffle: Whether to shuffle the word couples before returning them. - categorical: bool. if False, labels will be - integers (eg. `[0, 1, 1 .. ]`), - if `True`, labels will be categorical, e.g. - `[[1,0],[0,1],[0,1] .. ]`. - sampling_table: 1D array of size `vocabulary_size` where the entry i - encodes the probability to sample a word of rank i. - seed: Random seed. - - Returns: - couples, labels: where `couples` are int pairs and - `labels` are either 0 or 1. - - # Note - By convention, index 0 in the vocabulary is - a non-word and will be skipped. - """ - couples = [] - labels = [] - for i, wi in enumerate(sequence): - if not wi: - continue - if sampling_table is not None: - if sampling_table[wi] < random.random(): - continue - - window_start = max(0, i - window_size) - window_end = min(len(sequence), i + window_size + 1) - for j in range(window_start, window_end): - if j != i: - wj = sequence[j] - if not wj: - continue - couples.append([wi, wj]) - if categorical: - labels.append([0, 1]) - else: - labels.append(1) - - if negative_samples > 0: - num_negative_samples = int(len(labels) * negative_samples) - words = [c[0] for c in couples] - random.shuffle(words) - - couples += [[words[i % len(words)], - random.randint(1, vocabulary_size - 1)] - for i in range(num_negative_samples)] - if categorical: - labels += [[1, 0]] * num_negative_samples - else: - labels += [0] * num_negative_samples - - if shuffle: - if seed is None: - seed = random.randint(0, 10e6) - random.seed(seed) - random.shuffle(couples) - random.seed(seed) - random.shuffle(labels) - - return couples, labels - - -def _remove_long_seq(maxlen, seq, label): - """Removes sequences that exceed the maximum length. - - Arguments: - maxlen: Int, maximum length of the output sequences. - seq: List of lists, where each sublist is a sequence. - label: List where each element is an integer. - - Returns: - new_seq, new_label: shortened lists for `seq` and `label`. - """ - new_seq, new_label = [], [] - for x, y in zip(seq, label): - if len(x) < maxlen: - new_seq.append(x) - new_label.append(y) - return new_seq, new_label - - -@tf_export('keras.preprocessing.sequence.TimeseriesGenerator') -class TimeseriesGenerator(Sequence): - """Utility class for generating batches of temporal data. - - This class takes in a sequence of data-points gathered at - equal intervals, along with time series parameters such as - stride, length of history, etc., to produce batches for - training/validation. - - Arguments: - data: Indexable generator (such as list or Numpy array) - containing consecutive data points (timesteps). - The data should be at 2D, and axis 0 is expected - to be the time dimension. - targets: Targets corresponding to timesteps in `data`. - It should have same length as `data`. - length: Length of the output sequences (in number of timesteps). - sampling_rate: Period between successive individual timesteps - within sequences. For rate `r`, timesteps - `data[i]`, `data[i-r]`, ... `data[i - length]` - are used for create a sample sequence. - stride: Period between successive output sequences. - For stride `s`, consecutive output samples would - be centered around `data[i]`, `data[i+s]`, `data[i+2*s]`, etc. - start_index, end_index: Data points earlier than `start_index` - or later than `end_index` will not be used in the output sequences. - This is useful to reserve part of the data for test or validation. - shuffle: Whether to shuffle output samples, - or instead draw them in chronological order. - reverse: Boolean: if `true`, timesteps in each output sample will be - in reverse chronological order. - batch_size: Number of timeseries samples in each batch - (except maybe the last one). - - Returns: - A [Sequence](/utils/#sequence) instance. - - Examples: - - ```python - from keras.preprocessing.sequence import TimeseriesGenerator - import numpy as np - - data = np.array([[i] for i in range(50)]) - targets = np.array([[i] for i in range(50)]) - - data_gen = TimeseriesGenerator(data, targets, - length=10, sampling_rate=2, - batch_size=2) - assert len(data_gen) == 20 - - batch_0 = data_gen[0] - x, y = batch_0 - assert np.array_equal(x, - np.array([[[0], [2], [4], [6], [8]], - [[1], [3], [5], [7], [9]]])) - assert np.array_equal(y, - np.array([[10], [11]])) - ``` - """ - - def __init__(self, - data, - targets, - length, - sampling_rate=1, - stride=1, - start_index=0, - end_index=None, - shuffle=False, - reverse=False, - batch_size=128): - self.data = data - self.targets = targets - self.length = length - self.sampling_rate = sampling_rate - self.stride = stride - self.start_index = start_index + length - if end_index is None: - end_index = len(data) - 1 - self.end_index = end_index - self.shuffle = shuffle - self.reverse = reverse - self.batch_size = batch_size - - if self.start_index > self.end_index: - raise ValueError('`start_index+length=%i > end_index=%i` ' - 'is disallowed, as no part of the sequence ' - 'would be left to be used as current step.' % - (self.start_index, self.end_index)) - - def __len__(self): - length = int( - np.ceil((self.end_index - self.start_index + 1) / - (self.batch_size * self.stride))) - return length if length >= 0 else 0 - - def _empty_batch(self, num_rows): - samples_shape = [num_rows, self.length // self.sampling_rate] - samples_shape.extend(self.data.shape[1:]) - targets_shape = [num_rows] - targets_shape.extend(self.targets.shape[1:]) - return np.empty(samples_shape), np.empty(targets_shape) - - def __getitem__(self, index): - if self.shuffle: - rows = np.random.randint( - self.start_index, self.end_index + 1, size=self.batch_size) - else: - i = self.start_index + self.batch_size * self.stride * index - rows = np.arange( - i, min(i + self.batch_size * self.stride, self.end_index + 1), - self.stride) - - samples, targets = self._empty_batch(len(rows)) - for j in range(len(rows)): - indices = range(rows[j] - self.length, rows[j], self.sampling_rate) - samples[j] = self.data[indices] - targets[j] = self.targets[rows[j]] - if self.reverse: - return samples[:, ::-1, ...], targets - return samples, targets +pad_sequences = sequence.pad_sequences +make_sampling_table = sequence.make_sampling_table +skipgrams = sequence.skipgrams +# TODO(fchollet): consider making `_remove_long_seq` public. +_remove_long_seq = sequence._remove_long_seq # pylint: disable=protected-access +TimeseriesGenerator = sequence.TimeseriesGenerator + +tf_export('keras.preprocessing.sequence.pad_sequences')(pad_sequences) +tf_export( + 'keras.preprocessing.sequence.make_sampling_table')(make_sampling_table) +tf_export('keras.preprocessing.sequence.skipgrams')(skipgrams) +tf_export( + 'keras.preprocessing.sequence.TimeseriesGenerator')(TimeseriesGenerator) diff --git a/tensorflow/python/keras/preprocessing/text.py b/tensorflow/python/keras/preprocessing/text.py index f3b57de257a58663f7eb30efb27638ce16b5c431..57e5d00e0486694f8034453d56247029164f9849 100644 --- a/tensorflow/python/keras/preprocessing/text.py +++ b/tensorflow/python/keras/preprocessing/text.py @@ -14,383 +14,22 @@ # ============================================================================== """Utilities for text input preprocessing. """ +# pylint: disable=invalid-name from __future__ import absolute_import from __future__ import division from __future__ import print_function -from collections import OrderedDict -from hashlib import md5 -import string -import sys +from keras_preprocessing import text -import numpy as np -from six.moves import range # pylint: disable=redefined-builtin -from six.moves import zip # pylint: disable=redefined-builtin - -from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export +text_to_word_sequence = text.text_to_word_sequence +one_hot = text.one_hot +hashing_trick = text.hashing_trick +Tokenizer = text.Tokenizer -if sys.version_info < (3,): - maketrans = string.maketrans -else: - maketrans = str.maketrans - - -@tf_export('keras.preprocessing.text.text_to_word_sequence') -def text_to_word_sequence(text, - filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', - lower=True, - split=' '): - r"""Converts a text to a sequence of words (or tokens). - - Arguments: - text: Input text (string). - filters: list (or concatenation) of characters to filter out, such as - punctuation. Default: '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', - includes basic punctuation, tabs, and newlines. - lower: boolean, whether to convert the input to lowercase. - split: string, separator for word splitting. - - Returns: - A list of words (or tokens). - """ - if lower: - text = text.lower() - - if sys.version_info < (3,): - if isinstance(text, unicode): - translate_map = dict((ord(c), unicode(split)) for c in filters) - text = text.translate(translate_map) - elif len(split) == 1: - translate_map = maketrans(filters, split * len(filters)) - text = text.translate(translate_map) - else: - for c in filters: - text = text.replace(c, split) - else: - translate_dict = dict((c, split) for c in filters) - translate_map = maketrans(translate_dict) - text = text.translate(translate_map) - - seq = text.split(split) - return [i for i in seq if i] - - -@tf_export('keras.preprocessing.text.one_hot') -def one_hot(text, - n, - filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', - lower=True, - split=' '): - r"""One-hot encodes a text into a list of word indexes of size n. - - This is a wrapper to the `hashing_trick` function using `hash` as the - hashing function; unicity of word to index mapping non-guaranteed. - - Arguments: - text: Input text (string). - n: int, size of vocabulary. - filters: list (or concatenation) of characters to filter out, such as - punctuation. Default: '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', - includes basic punctuation, tabs, and newlines. - lower: boolean, whether to set the text to lowercase. - split: string, separator for word splitting. - - Returns: - List of integers in [1, n]. - Each integer encodes a word (unicity non-guaranteed). - """ - return hashing_trick( - text, n, hash_function=hash, filters=filters, lower=lower, split=split) - - -@tf_export('keras.preprocessing.text.hashing_trick') -def hashing_trick(text, - n, - hash_function=None, - filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', - lower=True, - split=' '): - r"""Converts a text to a sequence of indexes in a fixed-size hashing space. - - Arguments: - text: Input text (string). - n: Dimension of the hashing space. - hash_function: defaults to python `hash` function, can be 'md5' or - any function that takes in input a string and returns a int. - Note that 'hash' is not a stable hashing function, so - it is not consistent across different runs, while 'md5' - is a stable hashing function. - filters: list (or concatenation) of characters to filter out, such as - punctuation. Default: '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', - includes basic punctuation, tabs, and newlines. - lower: boolean, whether to set the text to lowercase. - split: string, separator for word splitting. - - Returns: - A list of integer word indices (unicity non-guaranteed). - - `0` is a reserved index that won't be assigned to any word. - - Two or more words may be assigned to the same index, due to possible - collisions by the hashing function. - The - probability - of a collision is in relation to the dimension of the hashing space and - the number of distinct objects. - """ - if hash_function is None: - hash_function = hash - elif hash_function == 'md5': - hash_function = lambda w: int(md5(w.encode()).hexdigest(), 16) - - seq = text_to_word_sequence(text, filters=filters, lower=lower, split=split) - return [(hash_function(w) % (n - 1) + 1) for w in seq] - - -@tf_export('keras.preprocessing.text.Tokenizer') -class Tokenizer(object): - """Text tokenization utility class. - - This class allows to vectorize a text corpus, by turning each - text into either a sequence of integers (each integer being the index - of a token in a dictionary) or into a vector where the coefficient - for each token could be binary, based on word count, based on tf-idf... - - Arguments: - num_words: the maximum number of words to keep, based - on word frequency. Only the most common `num_words` words will - be kept. - filters: a string where each element is a character that will be - filtered from the texts. The default is all punctuation, plus - tabs and line breaks, minus the `'` character. - lower: boolean. Whether to convert the texts to lowercase. - split: string, separator for word splitting. - char_level: if True, every character will be treated as a token. - oov_token: if given, it will be added to word_index and used to - replace out-of-vocabulary words during text_to_sequence calls - - By default, all punctuation is removed, turning the texts into - space-separated sequences of words - (words maybe include the `'` character). These sequences are then - split into lists of tokens. They will then be indexed or vectorized. - - `0` is a reserved index that won't be assigned to any word. - """ - - def __init__(self, - num_words=None, - filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', - lower=True, - split=' ', - char_level=False, - oov_token=None, - **kwargs): - # Legacy support - if 'nb_words' in kwargs: - logging.warning('The `nb_words` argument in `Tokenizer` ' - 'has been renamed `num_words`.') - num_words = kwargs.pop('nb_words') - if kwargs: - raise TypeError('Unrecognized keyword arguments: ' + str(kwargs)) - - self.word_counts = OrderedDict() - self.word_docs = {} - self.filters = filters - self.split = split - self.lower = lower - self.num_words = num_words - self.document_count = 0 - self.char_level = char_level - self.oov_token = oov_token - self.index_docs = {} - - def fit_on_texts(self, texts): - """Updates internal vocabulary based on a list of texts. - - In the case where texts contains lists, we assume each entry of the lists - to be a token. - - Required before using `texts_to_sequences` or `texts_to_matrix`. - - Arguments: - texts: can be a list of strings, - a generator of strings (for memory-efficiency), - or a list of list of strings. - """ - for text in texts: - self.document_count += 1 - if self.char_level or isinstance(text, list): - seq = text - else: - seq = text_to_word_sequence(text, self.filters, self.lower, self.split) - for w in seq: - if w in self.word_counts: - self.word_counts[w] += 1 - else: - self.word_counts[w] = 1 - for w in set(seq): - if w in self.word_docs: - self.word_docs[w] += 1 - else: - self.word_docs[w] = 1 - - wcounts = list(self.word_counts.items()) - wcounts.sort(key=lambda x: x[1], reverse=True) - sorted_voc = [wc[0] for wc in wcounts] - # note that index 0 is reserved, never assigned to an existing word - self.word_index = dict( - list(zip(sorted_voc, list(range(1, - len(sorted_voc) + 1))))) - - if self.oov_token is not None: - i = self.word_index.get(self.oov_token) - if i is None: - self.word_index[self.oov_token] = len(self.word_index) + 1 - - for w, c in list(self.word_docs.items()): - self.index_docs[self.word_index[w]] = c - - def fit_on_sequences(self, sequences): - """Updates internal vocabulary based on a list of sequences. - - Required before using `sequences_to_matrix` - (if `fit_on_texts` was never called). - - Arguments: - sequences: A list of sequence. - A "sequence" is a list of integer word indices. - """ - self.document_count += len(sequences) - for seq in sequences: - seq = set(seq) - for i in seq: - if i not in self.index_docs: - self.index_docs[i] = 1 - else: - self.index_docs[i] += 1 - - def texts_to_sequences(self, texts): - """Transforms each text in texts in a sequence of integers. - - Only top "num_words" most frequent words will be taken into account. - Only words known by the tokenizer will be taken into account. - - Arguments: - texts: A list of texts (strings). - - Returns: - A list of sequences. - """ - res = [] - for vect in self.texts_to_sequences_generator(texts): - res.append(vect) - return res - - def texts_to_sequences_generator(self, texts): - """Transforms each text in `texts` in a sequence of integers. - - Each item in texts can also be a list, in which case we assume each item of - that list - to be a token. - - Only top "num_words" most frequent words will be taken into account. - Only words known by the tokenizer will be taken into account. - - Arguments: - texts: A list of texts (strings). - - Yields: - Yields individual sequences. - """ - num_words = self.num_words - for text in texts: - if self.char_level or isinstance(text, list): - seq = text - else: - seq = text_to_word_sequence(text, self.filters, self.lower, self.split) - vect = [] - for w in seq: - i = self.word_index.get(w) - if i is not None: - if num_words and i >= num_words: - continue - else: - vect.append(i) - elif self.oov_token is not None: - i = self.word_index.get(self.oov_token) - if i is not None: - vect.append(i) - yield vect - - def texts_to_matrix(self, texts, mode='binary'): - """Convert a list of texts to a Numpy matrix. - - Arguments: - texts: list of strings. - mode: one of "binary", "count", "tfidf", "freq". - - Returns: - A Numpy matrix. - """ - sequences = self.texts_to_sequences(texts) - return self.sequences_to_matrix(sequences, mode=mode) - - def sequences_to_matrix(self, sequences, mode='binary'): - """Converts a list of sequences into a Numpy matrix. - - Arguments: - sequences: list of sequences - (a sequence is a list of integer word indices). - mode: one of "binary", "count", "tfidf", "freq" - - Returns: - A Numpy matrix. - - Raises: - ValueError: In case of invalid `mode` argument, - or if the Tokenizer requires to be fit to sample data. - """ - if not self.num_words: - if self.word_index: - num_words = len(self.word_index) + 1 - else: - raise ValueError('Specify a dimension (num_words argument), ' - 'or fit on some text data first.') - else: - num_words = self.num_words - - if mode == 'tfidf' and not self.document_count: - raise ValueError('Fit the Tokenizer on some data ' - 'before using tfidf mode.') - - x = np.zeros((len(sequences), num_words)) - for i, seq in enumerate(sequences): - if not seq: - continue - counts = {} - for j in seq: - if j >= num_words: - continue - if j not in counts: - counts[j] = 1. - else: - counts[j] += 1 - for j, c in list(counts.items()): - if mode == 'count': - x[i][j] = c - elif mode == 'freq': - x[i][j] = c / len(seq) - elif mode == 'binary': - x[i][j] = 1 - elif mode == 'tfidf': - # Use weighting scheme 2 in - # https://en.wikipedia.org/wiki/Tf%E2%80%93idf - tf = 1 + np.log(c) - idf = np.log(1 + self.document_count / - (1 + self.index_docs.get(j, 0))) - x[i][j] = tf * idf - else: - raise ValueError('Unknown vectorization mode:', mode) - return x +tf_export( + 'keras.preprocessing.text.text_to_word_sequence')(text_to_word_sequence) +tf_export('keras.preprocessing.text.one_hot')(one_hot) +tf_export('keras.preprocessing.text.hashing_trick')(hashing_trick) +tf_export('keras.preprocessing.text.Tokenizer')(Tokenizer) diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py index e7cb45d5e110dcb749ae2b1b86dd8dd5b8ded4ef..6e8ee06ff53691c3f439114bd6c4745c03cf9a10 100644 --- a/tensorflow/python/keras/testing_utils.py +++ b/tensorflow/python/keras/testing_utils.py @@ -183,3 +183,4 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None, # for further checks in the caller function return actual_output + diff --git a/tensorflow/python/keras/utils/__init__.py b/tensorflow/python/keras/utils/__init__.py index 69337b6a8d52abd4caf2ada518fde51c407f8103..c442b31116091955335423d2e60eaacf464c568e 100644 --- a/tensorflow/python/keras/utils/__init__.py +++ b/tensorflow/python/keras/utils/__init__.py @@ -31,6 +31,7 @@ from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.keras.utils.io_utils import HDF5Matrix from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model +from tensorflow.python.keras.utils.layer_utils import get_source_inputs from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model from tensorflow.python.keras.utils.np_utils import normalize from tensorflow.python.keras.utils.np_utils import to_categorical diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py index a69893955f4f1cd7d4fafb1746019a59c240dd09..2e56fa2dc5474678ba3ef765bc148f09c4665ec0 100644 --- a/tensorflow/python/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/utils/generic_utils.py @@ -162,7 +162,7 @@ def deserialize_keras_object(identifier, if cls is None: raise ValueError('Unknown ' + printable_module_name + ': ' + class_name) if hasattr(cls, 'from_config'): - arg_spec = tf_inspect.getargspec(cls.from_config) + arg_spec = tf_inspect.getfullargspec(cls.from_config) custom_objects = custom_objects or {} if 'custom_objects' in arg_spec.args: @@ -281,8 +281,8 @@ def has_arg(fn, name, accept_all=False): Returns: bool, whether `fn` accepts a `name` keyword argument. """ - arg_spec = tf_inspect.getargspec(fn) - if accept_all and arg_spec.keywords is not None: + arg_spec = tf_inspect.getfullargspec(fn) + if accept_all and arg_spec.varkw is not None: return True return name in arg_spec.args diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py index 88daff0461593f6270f3be8c06a277c7e6751286..1f28c59ea41a96461a7faba2c41f5e65e6af0180 100644 --- a/tensorflow/python/keras/utils/layer_utils.py +++ b/tensorflow/python/keras/utils/layer_utils.py @@ -26,6 +26,47 @@ from tensorflow.python.keras.utils.conv_utils import convert_kernel from tensorflow.python.util.tf_export import tf_export +def get_source_inputs(tensor, layer=None, node_index=None): + """Returns the list of input tensors necessary to compute `tensor`. + + Output will always be a list of tensors + (potentially with 1 element). + + Arguments: + tensor: The tensor to start from. + layer: Origin layer of the tensor. Will be + determined via tensor._keras_history if not provided. + node_index: Origin node index of the tensor. + + Returns: + List of input tensors. + """ + if not hasattr(tensor, '_keras_history'): + return tensor + + if layer is None or node_index: + layer, node_index, _ = tensor._keras_history + if not layer._inbound_nodes: + return [tensor] + else: + node = layer._inbound_nodes[node_index] + if not node.inbound_layers: + # Reached an Input layer, stop recursion. + return node.input_tensors + else: + source_tensors = [] + for i in range(len(node.inbound_layers)): + x = node.input_tensors[i] + layer = node.inbound_layers[i] + node_index = node.node_indices[i] + previous_sources = get_source_inputs(x, layer, node_index) + # Avoid input redundancy. + for x in previous_sources: + if x not in source_tensors: + source_tensors.append(x) + return source_tensors + + def count_params(weights): """Count the total number of scalars composing the weights. diff --git a/tensorflow/python/keras/utils/np_utils.py b/tensorflow/python/keras/utils/np_utils.py index 9d9c72b162700cb3bca2cf83d56db30f8df1deb9..c24e87308bee20e4ed978514699d4beb2ee4fbb9 100644 --- a/tensorflow/python/keras/utils/np_utils.py +++ b/tensorflow/python/keras/utils/np_utils.py @@ -33,7 +33,8 @@ def to_categorical(y, num_classes=None): num_classes: total number of classes. Returns: - A binary matrix representation of the input. + A binary matrix representation of the input. The classes axis is placed + last. """ y = np.array(y, dtype='int') input_shape = y.shape diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 5d29c2e5f86bd3c4997cc3f18f4cb760dc87d63b..2fce3e9b38decbaf7f248549590f4df7fb52617e 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -566,6 +566,7 @@ tf_py_test( "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:linalg_ops", ], + shard_count = 16, ) tf_py_test( @@ -701,7 +702,7 @@ tf_py_test( tf_py_test( name = "priority_queue_test", - size = "small", + size = "medium", srcs = ["priority_queue_test.py"], additional_deps = [ "//third_party/py/numpy", @@ -893,6 +894,7 @@ tf_py_test( "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", + "//tensorflow/python:sparse_grad", "//tensorflow/python:sparse_ops", ], ) @@ -1524,6 +1526,7 @@ cuda_py_test( "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", ], + tags = ["no_windows_gpu"], ) cuda_py_test( @@ -1716,7 +1719,7 @@ cuda_py_test( cuda_py_test( name = "matmul_op_test", - size = "small", + size = "medium", srcs = ["matmul_op_test.py"], additional_deps = [ "//third_party/py/numpy", @@ -2056,6 +2059,7 @@ cuda_py_test( "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", ], + tags = ["no_windows_gpu"], ) tf_py_test( @@ -2177,7 +2181,6 @@ cuda_py_test( "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:parsing_ops", ], - tags = ["no_windows"], ) cuda_py_test( @@ -2754,6 +2757,7 @@ cuda_py_test( "//tensorflow/python:embedding_ops", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:init_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:partitioned_variables", @@ -3087,3 +3091,23 @@ tf_py_test( data = [":invalid_op.so"], tags = ["no_pip"], ) + +tf_py_test( + name = "cond_v2_test", + size = "medium", + srcs = ["cond_v2_test.py"], + additional_deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:cond_v2", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:training", + ], + grpc_enabled = True, + tags = ["no_gpu"], # TODO(b/111656070) +) diff --git a/tensorflow/python/kernel_tests/argmax_op_test.py b/tensorflow/python/kernel_tests/argmax_op_test.py index ce0676990221fb441b99043083647f9d65722db8..1202c463e80d21b7cf88e5596cfc64eaa38ef8ba 100644 --- a/tensorflow/python/kernel_tests/argmax_op_test.py +++ b/tensorflow/python/kernel_tests/argmax_op_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -115,6 +116,12 @@ class ArgMaxTest(test.TestCase): ans = op([1]).eval() self.assertAllEqual(ans, 0) + def testOutputEmpty(self): + with self.test_session(): + for op in math_ops.argmin, math_ops.argmax: + ret = op(array_ops.zeros(shape=[1, 0, 2]), axis=-1).eval() + self.assertEqual(ret.shape, (1, 0)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 08bf2d9c644bcde2a80e6138557dae6e19383dfd..40567571e6d259eff3f013c67d1d1f9504fcb9e4 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -1006,7 +1006,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase): class ShapeSizeRankTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDenseShape(self): t_value = [[0, 42], [24, 0]] self.assertAllEqual((2, 2), self.evaluate(array_ops.shape(t_value))) @@ -1018,7 +1018,7 @@ class ShapeSizeRankTest(test_util.TensorFlowTestCase): self.assertEqual(4, self.evaluate(array_ops.size(t))) self.assertEqual(2, self.evaluate(array_ops.rank(t))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSparseShape(self): sp_value = sparse_tensor.SparseTensorValue( indices=((0, 1), (1, 0)), values=(42, 24), dense_shape=(2, 2)) @@ -1031,7 +1031,7 @@ class ShapeSizeRankTest(test_util.TensorFlowTestCase): self.assertEqual(4, self.evaluate(array_ops.size(sp))) self.assertEqual(2, self.evaluate(array_ops.rank(sp))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSizeDtype(self): tensor = [1] self.assertEqual(dtypes.int32, self.evaluate(array_ops.size(tensor)).dtype) @@ -1123,7 +1123,7 @@ class SequenceMaskTest(test_util.TensorFlowTestCase): class ConcatSliceResourceTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConcatSlice(self): r1 = test_ops.stub_resource_handle_op(container="a", shared_name="b") r2 = test_ops.stub_resource_handle_op(container="a", shared_name="c") diff --git a/tensorflow/python/kernel_tests/as_string_op_test.py b/tensorflow/python/kernel_tests/as_string_op_test.py index 94ed8ebd31f5874024bb6b0988073ece15d39d87..51aa17babeabdd06f52e6363fb0992e97d7cede0 100644 --- a/tensorflow/python/kernel_tests/as_string_op_test.py +++ b/tensorflow/python/kernel_tests/as_string_op_test.py @@ -160,7 +160,7 @@ class AsStringOpTest(test.TestCase): complex_inputs_ = [(x + (x + 1) * 1j) for x in float_inputs_] with self.test_session(): - for dtype in (dtypes.complex64,): + for dtype in (dtypes.complex64, dtypes.complex128): input_ = array_ops.placeholder(dtype) def clean_nans(s_l): diff --git a/tensorflow/python/kernel_tests/atrous_convolution_test.py b/tensorflow/python/kernel_tests/atrous_convolution_test.py index 0ef08581c9f931b991ef0c1218dc503345e248c2..b98e5fd3866cde007c6c00ae0cf04b1f1c46c6f2 100644 --- a/tensorflow/python/kernel_tests/atrous_convolution_test.py +++ b/tensorflow/python/kernel_tests/atrous_convolution_test.py @@ -124,7 +124,7 @@ class AtrousConvolutionTest(test.TestCase): x, w, "VALID", dilation_rate=[2, 2], data_format="NCHW") self.assertEqual(y.shape.as_list(), [1, 20, None, None]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAtrousConvolution2D(self): with self._delay_checks() as add_check: for padding in ["SAME", "VALID"]: @@ -139,7 +139,7 @@ class AtrousConvolutionTest(test.TestCase): dilation_rate=dilation_rate, ) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAtrousConvolution3D(self): with self._delay_checks() as add_check: for padding in ["SAME", "VALID"]: @@ -158,7 +158,7 @@ class AtrousConvolutionTest(test.TestCase): dilation_rate=dilation_rate, ) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAtrousConvolution1D(self): with self._delay_checks() as add_check: for padding in ["SAME", "VALID"]: @@ -173,7 +173,7 @@ class AtrousConvolutionTest(test.TestCase): dilation_rate=[rate], ) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAtrousConvolutionNC(self): if test.is_gpu_available(cuda_only=True): # "NCW" and "NCHW" formats are currently supported only on CUDA. @@ -197,7 +197,7 @@ class AtrousConvolutionTest(test.TestCase): data_format="NCHW", ) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAtrousSequence(self): """Tests optimization of sequence of atrous convolutions. diff --git a/tensorflow/python/kernel_tests/bitcast_op_test.py b/tensorflow/python/kernel_tests/bitcast_op_test.py index a535468b058d289d5cc6611ff542d89615793834..a2c6b54273f7f617ee78253e6184befd8f81e4ac 100644 --- a/tensorflow/python/kernel_tests/bitcast_op_test.py +++ b/tensorflow/python/kernel_tests/bitcast_op_test.py @@ -76,12 +76,18 @@ class BitcastTest(test.TestCase): datatype = dtypes.int8 array_ops.bitcast(x, datatype, None) - def testQuantizeType(self): + def testQuantizedType(self): shape = [3, 4] x = np.zeros(shape, np.uint16) datatype = dtypes.quint16 self._testBitcast(x, datatype, shape) + def testUnsignedType(self): + shape = [3, 4] + x = np.zeros(shape, np.int64) + datatype = dtypes.uint64 + self._testBitcast(x, datatype, shape) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py index 92cd53a031e73d4ff4ac50c2465f32a2c20545a7..4e31b1ea2a796a2e83696d278cf1b4784d177150 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py +++ b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py @@ -910,7 +910,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): feature_1_values = [11, 27] # Example 1: tree 0: 1.14, tree 1: 5.0, tree 2: 5.0 = > - # logit = 0.1*5.0+0.2*5.0+1*5 + # logit = 0.1*1.14+0.2*5.0+1*5 # Example 2: tree 0: 1.14, tree 1: 7.0, tree 2: -7 = > # logit= 0.1*1.14+0.2*7.0-1*7.0 expected_logits = [[6.114], [-5.486]] @@ -925,5 +925,147 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self.assertAllClose(expected_logits, logits) +class FeatureContribsOpsTest(test_util.TensorFlowTestCase): + """Tests feature contribs ops for model understanding.""" + + def testContribsMultipleTree(self): + """Tests that the contribs work when we have multiple trees.""" + with self.test_session() as session: + tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() + text_format.Merge( + """ + trees { + nodes { + bucketized_split { + feature_id: 2 + threshold: 28 + left_id: 1 + right_id: 2 + } + metadata { + gain: 7.62 + original_leaf: {scalar: 2.1} + } + } + nodes { + leaf { + scalar: 1.14 + } + } + nodes { + leaf { + scalar: 8.79 + } + } + } + trees { + nodes { + bucketized_split { + feature_id: 2 + threshold: 26 + left_id: 1 + right_id: 2 + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 50 + left_id: 3 + right_id: 4 + } + metadata { + original_leaf: {scalar: 5.5} + } + } + nodes { + leaf { + scalar: 7.0 + } + } + nodes { + leaf { + scalar: 5.0 + } + } + nodes { + leaf { + scalar: 6.0 + } + } + } + trees { + nodes { + bucketized_split { + feature_id: 0 + threshold: 34 + left_id: 1 + right_id: 2 + } + } + nodes { + leaf { + scalar: -7.0 + } + } + nodes { + leaf { + scalar: 5.0 + } + } + } + tree_weights: 0.1 + tree_weights: 0.2 + tree_weights: 1.0 + tree_metadata: { + num_layers_grown: 1} + tree_metadata: { + num_layers_grown: 2} + tree_metadata: { + num_layers_grown: 1} + """, tree_ensemble_config) + + tree_ensemble = boosted_trees_ops.TreeEnsemble( + 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) + tree_ensemble_handle = tree_ensemble.resource_handle + resources.initialize_resources(resources.shared_resources()).run() + + feature_0_values = [36, 32] + feature_1_values = [13, -29] # Unused. Feature is not in above ensemble. + feature_2_values = [11, 27] + + # Expected logits are computed by traversing the logit path and + # subtracting child logits from parent logits. + bias = 2.1 * 0.1 # Root node of tree_0. + expected_feature_ids = ((2, 2, 0, 0), (2, 2, 0)) + # example_0 : (bias, 0.1 * 1.14, 0.2 * 5.5 + .114, 0.2 * 5. + .114, + # 1.0 * 5.0 + 0.2 * 5. + .114) + # example_1 : (bias, 0.1 * 1.14, 0.2 * 7 + .114, + # 1.0 * -7. + 0.2 * 7 + .114) + expected_logits_paths = ((bias, 0.114, 1.214, 1.114, 6.114), + (bias, 0.114, 1.514, -5.486)) + + bucketized_features = [ + feature_0_values, feature_1_values, feature_2_values + ] + + debug_op = boosted_trees_ops.example_debug_outputs( + tree_ensemble_handle, + bucketized_features=bucketized_features, + logits_dimension=1) + + serialized_examples_debug_outputs = session.run(debug_op) + feature_ids = [] + logits_paths = [] + for example in serialized_examples_debug_outputs: + example_debug_outputs = boosted_trees_pb2.DebugOutput() + example_debug_outputs.ParseFromString(example) + feature_ids.append(example_debug_outputs.feature_ids) + logits_paths.append(example_debug_outputs.logits_path) + + self.assertAllClose(feature_ids, expected_feature_ids) + self.assertAllClose(logits_paths, expected_logits_paths) + + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py index 13b804875e94a9f8acc9c441ba2525876a3ef58f..d55240297a8b972ea926186c2fa38da5da780612 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py +++ b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py @@ -139,6 +139,49 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): self.assertEqual(new_stamp, 1) self.assertProtoEquals(expected_result, tree_ensemble) + def testBiasCenteringOnEmptyEnsemble(self): + """Test growing with bias centering on an empty ensemble.""" + with self.test_session() as session: + # Create empty ensemble. + tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble') + tree_ensemble_handle = tree_ensemble.resource_handle + resources.initialize_resources(resources.shared_resources()).run() + + gradients = np.array([[5.]], dtype=np.float32) + hessians = np.array([[24.]], dtype=np.float32) + + # Grow tree ensemble. + grow_op = boosted_trees_ops.center_bias( + tree_ensemble_handle, + mean_gradients=gradients, + mean_hessians=hessians, + l1=0.0, + l2=1.0 + ) + session.run(grow_op) + + new_stamp, serialized = session.run(tree_ensemble.serialize()) + + tree_ensemble = boosted_trees_pb2.TreeEnsemble() + tree_ensemble.ParseFromString(serialized) + + expected_result = """ + trees { + nodes { + leaf { + scalar: -0.2 + } + } + } + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 0 + is_finalized: false + } + """ + self.assertEqual(new_stamp, 1) + self.assertProtoEquals(expected_result, tree_ensemble) + def testGrowExistingEnsembleTreeNotFinalized(self): """Test growing an existing ensemble with the last tree not finalized.""" with self.test_session() as session: @@ -666,7 +709,6 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): num_layers_attempted: 1 last_layer_node_start: 1 last_layer_node_end: 3 - } """, tree_ensemble_config) diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index 7ef841c96b5cec9c7ae56c631896231ed663b8be..bda6ca5ca91ab1f55c4586f604a116a9b3fed874 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -34,45 +34,45 @@ from tensorflow.python.platform import test class AssertProperIterableTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_single_tensor_raises(self): tensor = constant_op.constant(1) with self.assertRaisesRegexp(TypeError, "proper"): check_ops.assert_proper_iterable(tensor) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_single_sparse_tensor_raises(self): ten = sparse_tensor.SparseTensor( indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) with self.assertRaisesRegexp(TypeError, "proper"): check_ops.assert_proper_iterable(ten) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_single_ndarray_raises(self): array = np.array([1, 2, 3]) with self.assertRaisesRegexp(TypeError, "proper"): check_ops.assert_proper_iterable(array) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_single_string_raises(self): mystr = "hello" with self.assertRaisesRegexp(TypeError, "proper"): check_ops.assert_proper_iterable(mystr) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_non_iterable_object_raises(self): non_iterable = 1234 with self.assertRaisesRegexp(TypeError, "to be iterable"): check_ops.assert_proper_iterable(non_iterable) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_list_does_not_raise(self): list_of_stuff = [ constant_op.constant([11, 22]), constant_op.constant([1, 2]) ] check_ops.assert_proper_iterable(list_of_stuff) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_generator_does_not_raise(self): generator_of_stuff = (constant_op.constant([11, 22]), constant_op.constant( [1, 2])) @@ -81,14 +81,14 @@ class AssertProperIterableTest(test.TestCase): class AssertEqualTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_equal(self): small = constant_op.constant([1, 2], name="small") with ops.control_dependencies([check_ops.assert_equal(small, small)]): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_scalar_comparison(self): const_true = constant_op.constant(True, name="true") const_false = constant_op.constant(False, name="false") @@ -101,7 +101,7 @@ class AssertEqualTest(test.TestCase): x = check_ops.assert_equal(small, small) assert x is None - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_greater(self): # Static check static_small = constant_op.constant([1, 2], name="small") @@ -179,7 +179,7 @@ First 2 elements of y: check_ops.assert_equal(big, small, message="big does not equal small", summarize=2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_less(self): # Static check static_small = constant_op.constant([3, 1], name="small") @@ -196,7 +196,7 @@ First 2 elements of y: with self.assertRaisesOpError("small.*big"): out.eval(feed_dict={small: [3, 1], big: [4, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_equal_and_broadcastable_shapes(self): small = constant_op.constant([[1, 2], [1, 2]], name="small") small_2 = constant_op.constant([1, 2], name="small_2") @@ -204,7 +204,7 @@ First 2 elements of y: out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_equal_but_non_broadcastable_shapes(self): small = constant_op.constant([1, 1, 1], name="small") small_2 = constant_op.constant([1, 1], name="small_2") @@ -219,13 +219,13 @@ First 2 elements of y: out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_not_equal_and_broadcastable_shapes(self): cond = constant_op.constant([True, False], name="small") with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"): check_ops.assert_equal(cond, False, message="fail") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) @@ -236,7 +236,7 @@ First 2 elements of y: class AssertNoneEqualTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_not_equal(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([10, 20], name="small") @@ -245,7 +245,7 @@ class AssertNoneEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_equal(self): small = constant_op.constant([3, 1], name="small") with self.assertRaisesOpError("x != y did not hold"): @@ -254,7 +254,7 @@ class AssertNoneEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_not_equal_and_broadcastable_shapes(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3], name="big") @@ -263,7 +263,7 @@ class AssertNoneEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_not_equal_but_non_broadcastable_shapes(self): with self.test_session(): small = constant_op.constant([1, 1, 1], name="small") @@ -280,7 +280,7 @@ class AssertNoneEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): with self.test_session(): larry = constant_op.constant([]) @@ -300,7 +300,7 @@ class AssertNoneEqualTest(test.TestCase): class AssertAllCloseTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_equal(self): x = constant_op.constant(1., name="x") y = constant_op.constant(1., name="y") @@ -309,7 +309,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_close_enough_32_bit_due_to_default_rtol(self): eps = np.finfo(np.float32).eps # Default rtol/atol is 10*eps @@ -320,7 +320,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_close_enough_32_bit_due_to_default_atol(self): eps = np.finfo(np.float32).eps # Default rtol/atol is 10*eps @@ -331,7 +331,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_close_enough_64_bit_due_to_default_rtol(self): eps = np.finfo(np.float64).eps # Default rtol/atol is 10*eps @@ -342,7 +342,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_close_enough_64_bit_due_to_default_atol(self): eps = np.finfo(np.float64).eps # Default rtol/atol is 10*eps @@ -353,7 +353,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_close_enough_due_to_custom_rtol(self): x = constant_op.constant(1., name="x") y = constant_op.constant(1.1, name="y") @@ -363,7 +363,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_close_enough_due_to_custom_atol(self): x = constant_op.constant(0., name="x") y = constant_op.constant(0.1, name="y", dtype=np.float32) @@ -373,7 +373,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) @@ -381,7 +381,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(larry) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_atol_violated(self): x = constant_op.constant(10., name="x") y = constant_op.constant(10.2, name="y") @@ -392,7 +392,7 @@ class AssertAllCloseTest(test.TestCase): out = array_ops.identity(x) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_default_rtol_violated(self): x = constant_op.constant(0.1, name="x") y = constant_op.constant(0.0, name="y") @@ -412,7 +412,7 @@ class AssertAllCloseTest(test.TestCase): class AssertLessTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_equal(self): small = constant_op.constant([1, 2], name="small") with self.assertRaisesOpError("failure message.*\n*.* x < y did not hold"): @@ -422,7 +422,7 @@ class AssertLessTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_greater(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3, 4], name="big") @@ -431,7 +431,7 @@ class AssertLessTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_less(self): small = constant_op.constant([3, 1], name="small") big = constant_op.constant([4, 2], name="big") @@ -439,7 +439,7 @@ class AssertLessTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_less_and_broadcastable_shapes(self): small = constant_op.constant([1], name="small") big = constant_op.constant([3, 2], name="big") @@ -447,7 +447,7 @@ class AssertLessTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_less_but_non_broadcastable_shapes(self): small = constant_op.constant([1, 1, 1], name="small") big = constant_op.constant([3, 2], name="big") @@ -462,7 +462,7 @@ class AssertLessTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) @@ -480,7 +480,7 @@ class AssertLessTest(test.TestCase): class AssertLessEqualTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_equal(self): small = constant_op.constant([1, 2], name="small") with ops.control_dependencies( @@ -488,7 +488,7 @@ class AssertLessEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_greater(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3, 4], name="big") @@ -499,7 +499,7 @@ class AssertLessEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_less_equal(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3, 2], name="big") @@ -507,7 +507,7 @@ class AssertLessEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_less_equal_and_broadcastable_shapes(self): small = constant_op.constant([1], name="small") big = constant_op.constant([3, 1], name="big") @@ -515,7 +515,7 @@ class AssertLessEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_less_equal_but_non_broadcastable_shapes(self): small = constant_op.constant([3, 1], name="small") big = constant_op.constant([1, 1, 1], name="big") @@ -531,7 +531,7 @@ class AssertLessEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) @@ -543,7 +543,7 @@ class AssertLessEqualTest(test.TestCase): class AssertGreaterTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_equal(self): small = constant_op.constant([1, 2], name="small") with self.assertRaisesOpError("fail"): @@ -553,7 +553,7 @@ class AssertGreaterTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_less(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3, 4], name="big") @@ -562,7 +562,7 @@ class AssertGreaterTest(test.TestCase): out = array_ops.identity(big) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_greater(self): small = constant_op.constant([3, 1], name="small") big = constant_op.constant([4, 2], name="big") @@ -570,7 +570,7 @@ class AssertGreaterTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_greater_and_broadcastable_shapes(self): small = constant_op.constant([1], name="small") big = constant_op.constant([3, 2], name="big") @@ -578,7 +578,7 @@ class AssertGreaterTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_greater_but_non_broadcastable_shapes(self): small = constant_op.constant([1, 1, 1], name="small") big = constant_op.constant([3, 2], name="big") @@ -593,7 +593,7 @@ class AssertGreaterTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) @@ -604,7 +604,7 @@ class AssertGreaterTest(test.TestCase): class AssertGreaterEqualTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_equal(self): small = constant_op.constant([1, 2], name="small") with ops.control_dependencies( @@ -612,7 +612,7 @@ class AssertGreaterEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_less(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3, 4], name="big") @@ -623,7 +623,7 @@ class AssertGreaterEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_greater_equal(self): small = constant_op.constant([1, 2], name="small") big = constant_op.constant([3, 2], name="big") @@ -632,7 +632,7 @@ class AssertGreaterEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_greater_equal_and_broadcastable_shapes(self): small = constant_op.constant([1], name="small") big = constant_op.constant([3, 1], name="big") @@ -641,7 +641,7 @@ class AssertGreaterEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_less_equal_but_non_broadcastable_shapes(self): small = constant_op.constant([1, 1, 1], name="big") big = constant_op.constant([3, 1], name="small") @@ -657,7 +657,7 @@ class AssertGreaterEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): larry = constant_op.constant([]) curly = constant_op.constant([]) @@ -669,14 +669,14 @@ class AssertGreaterEqualTest(test.TestCase): class AssertNegativeTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_negative(self): frank = constant_op.constant([-1, -2], name="frank") with ops.control_dependencies([check_ops.assert_negative(frank)]): out = array_ops.identity(frank) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_positive(self): doug = constant_op.constant([1, 2], name="doug") with self.assertRaisesOpError("fail"): @@ -686,7 +686,7 @@ class AssertNegativeTest(test.TestCase): out = array_ops.identity(doug) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_zero(self): claire = constant_op.constant([0], name="claire") with self.assertRaisesOpError("x < 0 did not hold"): @@ -694,7 +694,7 @@ class AssertNegativeTest(test.TestCase): out = array_ops.identity(claire) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_empty_tensor_doesnt_raise(self): # A tensor is negative when it satisfies: # For every element x_i in x, x_i < 0 @@ -708,7 +708,7 @@ class AssertNegativeTest(test.TestCase): class AssertPositiveTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_negative(self): freddie = constant_op.constant([-1, -2], name="freddie") with self.assertRaisesOpError("fail"): @@ -718,14 +718,14 @@ class AssertPositiveTest(test.TestCase): out = array_ops.identity(freddie) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_positive(self): remmy = constant_op.constant([1, 2], name="remmy") with ops.control_dependencies([check_ops.assert_positive(remmy)]): out = array_ops.identity(remmy) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_zero(self): meechum = constant_op.constant([0], name="meechum") with self.assertRaisesOpError("x > 0 did not hold"): @@ -733,7 +733,7 @@ class AssertPositiveTest(test.TestCase): out = array_ops.identity(meechum) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_empty_tensor_doesnt_raise(self): # A tensor is positive when it satisfies: # For every element x_i in x, x_i > 0 @@ -747,7 +747,7 @@ class AssertPositiveTest(test.TestCase): class AssertRankTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self): tensor = constant_op.constant(1, name="my_tensor") desired_rank = 1 @@ -768,7 +768,7 @@ class AssertRankTest(test.TestCase): with self.assertRaisesOpError("fail.*my_tensor.*rank"): array_ops.identity(tensor).eval(feed_dict={tensor: 0}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(self): tensor = constant_op.constant(1, name="my_tensor") desired_rank = 0 @@ -784,7 +784,7 @@ class AssertRankTest(test.TestCase): [check_ops.assert_rank(tensor, desired_rank)]): array_ops.identity(tensor).eval(feed_dict={tensor: 0}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_raises_if_rank_too_large_static_rank(self): tensor = constant_op.constant([1, 2], name="my_tensor") desired_rank = 0 @@ -802,7 +802,7 @@ class AssertRankTest(test.TestCase): with self.assertRaisesOpError("my_tensor.*rank"): array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self): tensor = constant_op.constant([1, 2], name="my_tensor") desired_rank = 1 @@ -818,7 +818,7 @@ class AssertRankTest(test.TestCase): [check_ops.assert_rank(tensor, desired_rank)]): array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self): tensor = constant_op.constant([1, 2], name="my_tensor") desired_rank = 2 @@ -836,7 +836,7 @@ class AssertRankTest(test.TestCase): with self.assertRaisesOpError("my_tensor.*rank"): array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_if_rank_is_not_scalar_static(self): tensor = constant_op.constant([1, 2], name="my_tensor") with self.assertRaisesRegexp(ValueError, "Rank must be a scalar"): @@ -852,7 +852,7 @@ class AssertRankTest(test.TestCase): [check_ops.assert_rank(tensor, rank_tensor)]): array_ops.identity(tensor).eval(feed_dict={rank_tensor: [1, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_if_rank_is_not_integer_static(self): tensor = constant_op.constant([1, 2], name="my_tensor") with self.assertRaisesRegexp(TypeError, @@ -873,7 +873,7 @@ class AssertRankTest(test.TestCase): class AssertRankInTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_zero_tensor_raises_if_rank_mismatch_static_rank(self): tensor_rank0 = constant_op.constant(42, name="my_tensor") with self.assertRaisesRegexp( @@ -890,7 +890,7 @@ class AssertRankInTest(test.TestCase): with self.assertRaisesOpError("fail.*my_tensor.*rank"): array_ops.identity(tensor_rank0).eval(feed_dict={tensor_rank0: 42.0}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_zero_tensor_doesnt_raise_if_rank_matches_static_rank(self): tensor_rank0 = constant_op.constant(42, name="my_tensor") for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)): @@ -906,7 +906,7 @@ class AssertRankInTest(test.TestCase): check_ops.assert_rank_in(tensor_rank0, desired_ranks)]): array_ops.identity(tensor_rank0).eval(feed_dict={tensor_rank0: 42.0}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_doesnt_raise_if_rank_matches_static_rank(self): tensor_rank1 = constant_op.constant([42, 43], name="my_tensor") for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)): @@ -924,7 +924,7 @@ class AssertRankInTest(test.TestCase): tensor_rank1: (42.0, 43.0) }) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_raises_if_rank_mismatches_static_rank(self): tensor_rank1 = constant_op.constant((42, 43), name="my_tensor") with self.assertRaisesRegexp(ValueError, "rank"): @@ -942,7 +942,7 @@ class AssertRankInTest(test.TestCase): tensor_rank1: (42.0, 43.0) }) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_if_rank_is_not_scalar_static(self): tensor = constant_op.constant((42, 43), name="my_tensor") desired_ranks = ( @@ -966,7 +966,7 @@ class AssertRankInTest(test.TestCase): desired_ranks[1]: [2, 1], }) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_if_rank_is_not_integer_static(self): tensor = constant_op.constant((42, 43), name="my_tensor") with self.assertRaisesRegexp(TypeError, @@ -987,7 +987,7 @@ class AssertRankInTest(test.TestCase): class AssertRankAtLeastTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self): tensor = constant_op.constant(1, name="my_tensor") desired_rank = 1 @@ -1005,7 +1005,7 @@ class AssertRankAtLeastTest(test.TestCase): with self.assertRaisesOpError("my_tensor.*rank"): array_ops.identity(tensor).eval(feed_dict={tensor: 0}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(self): tensor = constant_op.constant(1, name="my_tensor") desired_rank = 0 @@ -1021,7 +1021,7 @@ class AssertRankAtLeastTest(test.TestCase): [check_ops.assert_rank_at_least(tensor, desired_rank)]): array_ops.identity(tensor).eval(feed_dict={tensor: 0}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_ten_doesnt_raise_raise_if_rank_too_large_static_rank(self): tensor = constant_op.constant([1, 2], name="my_tensor") desired_rank = 0 @@ -1037,7 +1037,7 @@ class AssertRankAtLeastTest(test.TestCase): [check_ops.assert_rank_at_least(tensor, desired_rank)]): array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self): tensor = constant_op.constant([1, 2], name="my_tensor") desired_rank = 1 @@ -1053,7 +1053,7 @@ class AssertRankAtLeastTest(test.TestCase): [check_ops.assert_rank_at_least(tensor, desired_rank)]): array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self): tensor = constant_op.constant([1, 2], name="my_tensor") desired_rank = 2 @@ -1074,7 +1074,7 @@ class AssertRankAtLeastTest(test.TestCase): class AssertNonNegativeTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_negative(self): zoe = constant_op.constant([-1, -2], name="zoe") with self.assertRaisesOpError("x >= 0 did not hold"): @@ -1082,14 +1082,14 @@ class AssertNonNegativeTest(test.TestCase): out = array_ops.identity(zoe) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_zero_and_positive(self): lucas = constant_op.constant([0, 2], name="lucas") with ops.control_dependencies([check_ops.assert_non_negative(lucas)]): out = array_ops.identity(lucas) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_empty_tensor_doesnt_raise(self): # A tensor is non-negative when it satisfies: # For every element x_i in x, x_i >= 0 @@ -1103,14 +1103,14 @@ class AssertNonNegativeTest(test.TestCase): class AssertNonPositiveTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_zero_and_negative(self): tom = constant_op.constant([0, -2], name="tom") with ops.control_dependencies([check_ops.assert_non_positive(tom)]): out = array_ops.identity(tom) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_positive(self): rachel = constant_op.constant([0, 2], name="rachel") with self.assertRaisesOpError("x <= 0 did not hold"): @@ -1118,7 +1118,7 @@ class AssertNonPositiveTest(test.TestCase): out = array_ops.identity(rachel) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_empty_tensor_doesnt_raise(self): # A tensor is non-positive when it satisfies: # For every element x_i in x, x_i <= 0 @@ -1132,14 +1132,14 @@ class AssertNonPositiveTest(test.TestCase): class AssertIntegerTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_integer(self): integers = constant_op.constant([1, 2], name="integers") with ops.control_dependencies([check_ops.assert_integer(integers)]): out = array_ops.identity(integers) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_float(self): floats = constant_op.constant([1.0, 2.0], name="floats") with self.assertRaisesRegexp(TypeError, "Expected.*integer"): @@ -1148,7 +1148,7 @@ class AssertIntegerTest(test.TestCase): class AssertTypeTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_correct_type(self): integers = constant_op.constant([1, 2], dtype=dtypes.int64) with ops.control_dependencies([ @@ -1156,7 +1156,7 @@ class AssertTypeTest(test.TestCase): out = array_ops.identity(integers) self.evaluate(out) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_raises_when_wrong_type(self): floats = constant_op.constant([1.0, 2.0], dtype=dtypes.float16) with self.assertRaisesRegexp(TypeError, "must be of type.*float32"): @@ -1165,74 +1165,74 @@ class AssertTypeTest(test.TestCase): class IsStrictlyIncreasingTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_constant_tensor_is_not_strictly_increasing(self): self.assertFalse(self.evaluate(check_ops.is_strictly_increasing([1, 1, 1]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_decreasing_tensor_is_not_strictly_increasing(self): self.assertFalse(self.evaluate( check_ops.is_strictly_increasing([1, 0, -1]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_2d_decreasing_tensor_is_not_strictly_increasing(self): self.assertFalse( self.evaluate(check_ops.is_strictly_increasing([[1, 3], [2, 4]]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_increasing_tensor_is_increasing(self): self.assertTrue(self.evaluate(check_ops.is_strictly_increasing([1, 2, 3]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_increasing_rank_two_tensor(self): self.assertTrue( self.evaluate(check_ops.is_strictly_increasing([[-1, 2], [3, 4]]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_tensor_with_one_element_is_strictly_increasing(self): self.assertTrue(self.evaluate(check_ops.is_strictly_increasing([1]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_empty_tensor_is_strictly_increasing(self): self.assertTrue(self.evaluate(check_ops.is_strictly_increasing([]))) class IsNonDecreasingTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_constant_tensor_is_non_decreasing(self): self.assertTrue(self.evaluate(check_ops.is_non_decreasing([1, 1, 1]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_decreasing_tensor_is_not_non_decreasing(self): self.assertFalse(self.evaluate(check_ops.is_non_decreasing([3, 2, 1]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_2d_decreasing_tensor_is_not_non_decreasing(self): self.assertFalse(self.evaluate( check_ops.is_non_decreasing([[1, 3], [2, 4]]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_increasing_rank_one_tensor_is_non_decreasing(self): self.assertTrue(self.evaluate(check_ops.is_non_decreasing([1, 2, 3]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_increasing_rank_two_tensor(self): self.assertTrue(self.evaluate( check_ops.is_non_decreasing([[-1, 2], [3, 3]]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_tensor_with_one_element_is_non_decreasing(self): self.assertTrue(self.evaluate(check_ops.is_non_decreasing([1]))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_empty_tensor_is_non_decreasing(self): self.assertTrue(self.evaluate(check_ops.is_non_decreasing([]))) class FloatDTypeTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_assert_same_float_dtype(self): self.assertIs(dtypes.float32, check_ops.assert_same_float_dtype(None, None)) @@ -1286,7 +1286,7 @@ class FloatDTypeTest(test.TestCase): class AssertScalarTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_assert_scalar(self): check_ops.assert_scalar(constant_op.constant(3)) check_ops.assert_scalar(constant_op.constant("foo")) diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py index fb52d10475fa47f37b1ee7de97b49878b5d13341..400d38b9366f8b9c25a2c761e058bc5d3a429db3 100644 --- a/tensorflow/python/kernel_tests/clip_ops_test.py +++ b/tensorflow/python/kernel_tests/clip_ops_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops @@ -369,6 +370,21 @@ class ClipTest(test.TestCase): self.assertAllClose(np_ans_0, tf_ans_1) self.assertAllClose(np_ans_1, tf_ans_2) + def testClipByGlobalNormInf(self): + with self.test_session(use_gpu=True): + x0 = constant_op.constant([-2.0, 0.0, np.inf, 4.0, 0.0, 0.0], + shape=[2, 3]) + x1 = constant_op.constant([1.0, -2.0]) + clip_norm = 6.0 + + ans, norm = clip_ops.clip_by_global_norm([x0, x1], clip_norm) + with self.assertRaisesRegexp(errors.InvalidArgumentError, "global norm"): + norm.eval() + with self.assertRaisesRegexp(errors.InvalidArgumentError, "global norm"): + ans[0].eval() + with self.assertRaisesRegexp(errors.InvalidArgumentError, "global norm"): + ans[1].eval() + def testClipByAverageNormClipped(self): # Norm clipping when average clip_norm < 0.83333333 with self.test_session(use_gpu=True): diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b9910133d8ece1c8c1dcef7093335b4675db6105 --- /dev/null +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -0,0 +1,938 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for cond_v2.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import cond_v2 +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import saver +from tensorflow.python.util import compat + + +class CondV2Test(test.TestCase): + + def _testCond(self, true_fn, false_fn, train_vals, feed_dict=None): + if not feed_dict: + feed_dict = {} + with self.test_session(graph=ops.get_default_graph()) as sess: + pred = array_ops.placeholder(dtypes.bool, name="pred") + + expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected") + actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual") + + expected_grad = gradients_impl.gradients(expected, train_vals) + actual_grad = gradients_impl.gradients(actual, train_vals) + + sess_run_args = {pred: True} + sess_run_args.update(feed_dict) + expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run( + (expected, actual, expected_grad, actual_grad), sess_run_args) + self.assertEqual(expected_val, actual_val) + self.assertEqual(expected_grad_val, actual_grad_val) + + sess_run_args = {pred: False} + sess_run_args.update(feed_dict) + expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run( + (expected, actual, expected_grad, actual_grad), sess_run_args) + self.assertEqual(expected_val, actual_val) + self.assertEqual(expected_grad_val, actual_grad_val) + + def testBasic(self): + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return x * 2.0 + + def false_fn(): + return y * 3.0 + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + + def testMultipleOutputs(self): + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(3.0, name="y") + + def true_fn(): + return x * y, y + + def false_fn(): + return x, y * 3.0 + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + + def testBasic2(self): + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return x * y * 2.0 + + def false_fn(): + return 2.0 + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + + def testNoInputs(self): + with self.test_session() as sess: + pred = array_ops.placeholder(dtypes.bool, name="pred") + + def true_fn(): + return constant_op.constant(1.0) + + def false_fn(): + return constant_op.constant(2.0) + + out = cond_v2.cond_v2(pred, true_fn, false_fn) + + self.assertEqual(sess.run(out, {pred: True}), (1.0,)) + self.assertEqual(sess.run(out, {pred: False}), (2.0,)) + + def _createCond(self, name): + pred = constant_op.constant(True, name="pred") + x = constant_op.constant(1.0, name="x") + + def true_fn(): + return x + + def false_fn(): + return x + 1 + + return cond_v2.cond_v2(pred, true_fn, false_fn, name=name)[0].op + + def testDefaultName(self): + with ops.Graph().as_default(): + cond = self._createCond(None) + self.assertEqual(cond.name, "cond") + self.assertIn("cond_true", ops.get_default_graph()._functions) + self.assertIn("cond_false", ops.get_default_graph()._functions) + + with ops.Graph().as_default(): + with ops.name_scope("foo"): + cond = self._createCond("") + self.assertEqual(cond.name, "foo/cond") + self.assertIn("foo_cond_true", ops.get_default_graph()._functions) + self.assertIn("foo_cond_false", ops.get_default_graph()._functions) + + cond2 = self._createCond(None) + self.assertEqual(cond2.name, "foo/cond_1") + self.assertIn("foo_cond_1_true", ops.get_default_graph()._functions) + self.assertIn("foo_cond_1_false", ops.get_default_graph()._functions) + + def testDefunInCond(self): + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + + @function.Defun() + def fn(): + return x * y * 2.0 + + return fn() + + def false_fn(): + return 2.0 + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + + def testNestedDefunInCond(self): + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return 2.0 + + def false_fn(): + + @function.Defun() + def fn(): + + @function.Defun() + def nested_fn(): + return x * y * 2.0 + + return nested_fn() + + return fn() + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + + def testDoubleNestedDefunInCond(self): + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + + @function.Defun() + def fn(): + + @function.Defun() + def nested_fn(): + + @function.Defun() + def nested_nested_fn(): + return x * y * 2.0 + + return nested_nested_fn() + + return nested_fn() + + return fn() + + def false_fn(): + return 2.0 + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + + def testNestedCond(self): + + def run_test(pred_value): + + def build_graph(): + pred = array_ops.placeholder(dtypes.bool, name="pred") + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return 2.0 + + def false_fn(): + + def false_true_fn(): + return x * y * 2.0 + + def false_false_fn(): + return x * 5.0 + + return _cond(pred, false_true_fn, false_false_fn, "inside_false_fn") + + return x, y, pred, true_fn, false_fn + + with ops.Graph().as_default(): + x, y, pred, true_fn, false_fn = build_graph() + self._testCond(true_fn, false_fn, [x, y], {pred: pred_value}) + self._testCond(true_fn, false_fn, [x], {pred: pred_value}) + self._testCond(true_fn, false_fn, [y], {pred: pred_value}) + + run_test(True) + run_test(False) + + def testNestedCondBothBranches(self): + + def run_test(pred_value): + + def build_graph(): + pred = array_ops.placeholder(dtypes.bool, name="pred") + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return _cond(pred, lambda: x + y, lambda: x * x, name=None) + + def false_fn(): + return _cond(pred, lambda: x - y, lambda: y * y, name=None) + + return x, y, pred, true_fn, false_fn + + with ops.Graph().as_default(): + x, y, pred, true_fn, false_fn = build_graph() + self._testCond(true_fn, false_fn, [x, y], {pred: pred_value}) + self._testCond(true_fn, false_fn, [x], {pred: pred_value}) + self._testCond(true_fn, false_fn, [y], {pred: pred_value}) + + run_test(True) + run_test(False) + + def testDoubleNestedCond(self): + + def run_test(pred1_value, pred2_value): + + def build_graph(): + pred1 = array_ops.placeholder(dtypes.bool, name="pred1") + pred2 = array_ops.placeholder(dtypes.bool, name="pred2") + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return 2.0 + + def false_fn(): + + def false_true_fn(): + + def false_true_true_fn(): + return x * y * 2.0 + + def false_true_false_fn(): + return x * 10.0 + + return _cond( + pred1, + false_true_true_fn, + false_true_false_fn, + name="inside_false_true_fn") + + def false_false_fn(): + return x * 5.0 + + return _cond( + pred2, false_true_fn, false_false_fn, name="inside_false_fn") + + return x, y, pred1, pred2, true_fn, false_fn + + with ops.Graph().as_default(): + x, y, pred1, pred2, true_fn, false_fn = build_graph() + self._testCond(true_fn, false_fn, [x, y], { + pred1: pred1_value, + pred2: pred2_value + }) + x, y, pred1, pred2, true_fn, false_fn = build_graph() + self._testCond(true_fn, false_fn, [x], { + pred1: pred1_value, + pred2: pred2_value + }) + x, y, pred1, pred2, true_fn, false_fn = build_graph() + self._testCond(true_fn, false_fn, [y], { + pred1: pred1_value, + pred2: pred2_value + }) + + run_test(True, True) + run_test(True, False) + run_test(False, False) + run_test(False, True) + + def testGradientFromInsideDefun(self): + + def build_graph(): + pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer") + pred_inner = array_ops.placeholder(dtypes.bool, name="pred_inner") + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return 2.0 + + def false_fn(): + + def inner_true_fn(): + return x * y * 2.0 + + def inner_false_fn(): + return x * 5.0 + + return cond_v2.cond_v2( + pred_inner, inner_true_fn, inner_false_fn, name="inner_cond") + + cond_outer = cond_v2.cond_v2( + pred_outer, true_fn, false_fn, name="outer_cond") + + # Compute grads inside a Defun. + @function.Defun() + def nesting_fn(): + return gradients_impl.gradients(cond_outer, [x, y]) + + grads = nesting_fn() + + return grads, pred_outer, pred_inner + + with ops.Graph().as_default(): + grads, pred_outer, pred_inner = build_graph() + with self.test_session(graph=ops.get_default_graph()) as sess: + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: True, + pred_inner: True + }), [0., 0.]) + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: True, + pred_inner: False + }), [0., 0.]) + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: False, + pred_inner: True + }), [4., 2.]) + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: False, + pred_inner: False + }), [5., 0.]) + + def testGradientFromInsideNestedDefun(self): + + def build_graph(): + pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer") + pred_inner = array_ops.placeholder(dtypes.bool, name="pred_inner") + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return 2.0 + + def false_fn(): + + def inner_true_fn(): + return x * y * 2.0 + + def inner_false_fn(): + return x * 5.0 + + return cond_v2.cond_v2( + pred_inner, inner_true_fn, inner_false_fn, name="inner_cond") + + cond_outer = cond_v2.cond_v2( + pred_outer, true_fn, false_fn, name="outer_cond") + + # Compute grads inside a Defun. + @function.Defun() + def nesting_fn(): + + @function.Defun() + def inner_nesting_fn(): + return gradients_impl.gradients(cond_outer, [x, y]) + + return inner_nesting_fn() + + grads = nesting_fn() + + return grads, pred_outer, pred_inner + + with ops.Graph().as_default(): + grads, pred_outer, pred_inner = build_graph() + with self.test_session(graph=ops.get_default_graph()) as sess: + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: True, + pred_inner: True + }), [0., 0.]) + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: True, + pred_inner: False + }), [0., 0.]) + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: False, + pred_inner: True + }), [4., 2.]) + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: False, + pred_inner: False + }), [5., 0.]) + + def testBuildCondAndGradientInsideDefun(self): + + def build_graph(): + pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer") + pred_inner = array_ops.placeholder(dtypes.bool, name="pred_inner") + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + # Build cond and its gradient inside a Defun. + @function.Defun() + def fn(): + + def true_fn(): + return 2.0 + + def false_fn(): + + def inner_true_fn(): + return x * y * 2.0 + + def inner_false_fn(): + return x * 5.0 + + return cond_v2.cond_v2( + pred_inner, inner_true_fn, inner_false_fn, name="inner_cond") + + cond_outer = cond_v2.cond_v2( + pred_outer, true_fn, false_fn, name="outer_cond") + return gradients_impl.gradients(cond_outer, [x, y]) + + grads = fn() + + return grads, pred_outer, pred_inner + + with ops.Graph().as_default(): + grads, pred_outer, pred_inner = build_graph() + with self.test_session(graph=ops.get_default_graph()) as sess: + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: True, + pred_inner: True + }), [0., 0.]) + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: True, + pred_inner: False + }), [0., 0.]) + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: False, + pred_inner: True + }), [4., 2.]) + self.assertSequenceEqual( + sess.run(grads, { + pred_outer: False, + pred_inner: False + }), [5., 0.]) + + def testSecondDerivative(self): + with self.test_session() as sess: + pred = array_ops.placeholder(dtypes.bool, name="pred") + x = constant_op.constant(3.0, name="x") + + def true_fn(): + return math_ops.pow(x, 3) + + def false_fn(): + return x + + cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond") + cond_grad = gradients_impl.gradients(cond, [x]) + cond_grad_grad = gradients_impl.gradients(cond_grad, [x]) + + # d[x^3]/dx = 3x^2 + true_val = sess.run(cond_grad, {pred: True}) + self.assertEqual(true_val, [27.0]) + # d[x]/dx = 1 + false_val = sess.run(cond_grad, {pred: False}) + self.assertEqual(false_val, [1.0]) + + true_val = sess.run(cond_grad_grad, {pred: True}) + # d2[x^3]/dx2 = 6x + self.assertEqual(true_val, [18.0]) + false_val = sess.run(cond_grad_grad, {pred: False}) + # d2[x]/dx2 = 0 + self.assertEqual(false_val, [0.0]) + + def testGradientOfDeserializedCond(self): + with ops.Graph().as_default(): + pred = array_ops.placeholder(dtypes.bool, name="pred") + x = constant_op.constant(3.0, name="x") + ops.add_to_collection("x", x) + + def true_fn(): + return math_ops.pow(x, 3) + + def false_fn(): + return x + + ops.add_to_collection("pred", pred) + cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond") + for c in cond: + ops.add_to_collection("cond", c) + meta_graph = saver.export_meta_graph() + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + saver.import_meta_graph(meta_graph) + x = ops.get_collection("x")[0] + pred = ops.get_collection("pred")[0] + cond = ops.get_collection("cond") + cond_grad = gradients_impl.gradients(cond, [x], name="cond_grad") + cond_grad_grad = gradients_impl.gradients( + cond_grad, [x], name="cond_grad_grad") + # d[x^3]/dx = 3x^2 + true_val = sess.run(cond_grad, {pred: True}) + self.assertEqual(true_val, [27.0]) + # d[x]/dx = 1 + false_val = sess.run(cond_grad, {pred: False}) + self.assertEqual(false_val, [1.0]) + + true_val = sess.run(cond_grad_grad, {pred: True}) + # d2[x^3]/dx2 = 6x + self.assertEqual(true_val, [18.0]) + false_val = sess.run(cond_grad_grad, {pred: False}) + # d2[x]/dx2 = 0 + self.assertEqual(false_val, [0.0]) + + def testLowering(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + out_cond = self._createCond("cond") + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + run_metadata = config_pb2.RunMetadata() + sess.run(out_cond, options=run_options, run_metadata=run_metadata) + + # If lowering was enabled, there should be a `Switch` node + switch_found = any( + any(node.op == "Switch" for node in graph.node) + for graph in run_metadata.partition_graphs + ) + + self.assertTrue(switch_found, + "A `Switch` op should exist if the graph was lowered.") + + # If lowering was enabled, there should be no `If` node + if_found = any( + any(node.op == "If" for node in graph.node) + for graph in run_metadata.partition_graphs + ) + + self.assertFalse(if_found, + "An `If` op was found, but it should be lowered.") + + def testLoweringDisabledInXLA(self): + with self.test_session(graph=ops.Graph()) as sess: + # Build the cond_v2 in an XLA context + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + out_cond = self._createCond("cond") + xla_context.Exit() + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + run_metadata = config_pb2.RunMetadata() + sess.run(out_cond, options=run_options, run_metadata=run_metadata) + + # Lowering disabled in XLA, there should be no `Switch` node + switch_found = any( + any(node.op == "Switch" for node in graph.node) + for graph in run_metadata.partition_graphs + ) + + self.assertFalse( + switch_found, + "A `Switch` op exists, but the graph should not be lowered.") + + # Lowering disabled in XLA, there should still be an `If` node + if_found = any( + any(node.op == "If" for node in graph.node) + for graph in run_metadata.partition_graphs + ) + + self.assertTrue( + if_found, + "An `If` op was not found, but the graph should not be lowered.") + + +class CondV2CollectionTest(test.TestCase): + + def testCollectionIntValueAccessInCond(self): + """Read values from graph collections inside of cond_v2.""" + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + x = 2 + y = 5 + ops.add_to_collection("x", x) + ops.add_to_collection("y", y) + def fn(): + x_const = constant_op.constant(ops.get_collection("x")[0]) + y_const = constant_op.constant(ops.get_collection("y")[0]) + return math_ops.add(x_const, y_const) + + cnd = cond_v2.cond_v2(True, fn, fn) + self.assertEquals(cnd[0].eval(), 7) + + def testCollectionTensorValueAccessInCond(self): + """Read tensors from collections inside of cond_v2 & use them.""" + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + x = constant_op.constant(2) + y = constant_op.constant(5) + ops.add_to_collection("x", x) + ops.add_to_collection("y", y) + + def fn(): + x_read = ops.get_collection("x")[0] + y_read = ops.get_collection("y")[0] + return math_ops.add(x_read, y_read) + + cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, fn) + self.assertEquals(cnd[0].eval(), 7) + + def testCollectionIntValueWriteInCond(self): + """Make sure Int writes to collections work inside of cond_v2.""" + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + x = constant_op.constant(2) + y = constant_op.constant(5) + def true_fn(): + z = math_ops.add(x, y) + ops.add_to_collection("z", 7) + return math_ops.mul(x, z) + + def false_fn(): + z = math_ops.add(x, y) + return math_ops.mul(x, z) + + cnd = cond_v2.cond_v2( + True, true_fn, + false_fn) + self.assertEquals(cnd[0].eval(), 14) + + read_z_collection = ops.get_collection("z") + self.assertEquals(read_z_collection, [7]) + + +class CondV2ContainerTest(test.TestCase): + + def testContainer(self): + """Set containers outside & inside of cond_v2. + + Make sure the containers are set correctly for both variable creation + (tested by variables.Variable) and for stateful ops (tested by FIFOQueue) + """ + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + + v0 = variables.Variable([0]) + q0 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + def container(node): + return node.op.get_attr("container") + + self.assertEqual(compat.as_bytes(""), container(v0)) + self.assertEqual(compat.as_bytes(""), container(q0.queue_ref)) + + def true_fn(): + # When this branch is created in cond below, + # the container should begin with 'l1' + v1 = variables.Variable([1]) + q1 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + with ops.container("l2t"): + v2 = variables.Variable([2]) + q2 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + v3 = variables.Variable([1]) + q3 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + self.assertEqual(compat.as_bytes("l1"), container(v1)) + self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref)) + self.assertEqual(compat.as_bytes("l2t"), container(v2)) + self.assertEqual(compat.as_bytes("l2t"), container(q2.queue_ref)) + self.assertEqual(compat.as_bytes("l1"), container(v3)) + self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref)) + + return constant_op.constant(2.0) + + def false_fn(): + # When this branch is created in cond below, + # the container should begin with 'l1' + v1 = variables.Variable([1]) + q1 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + with ops.container("l2f"): + v2 = variables.Variable([2]) + q2 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + v3 = variables.Variable([1]) + q3 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + self.assertEqual(compat.as_bytes("l1"), container(v1)) + self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref)) + self.assertEqual(compat.as_bytes("l2f"), container(v2)) + self.assertEqual(compat.as_bytes("l2f"), container(q2.queue_ref)) + self.assertEqual(compat.as_bytes("l1"), container(v3)) + self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref)) + + return constant_op.constant(6.0) + + with ops.container("l1"): + cnd_true = cond_v2.cond_v2(True, true_fn, false_fn) + self.assertEquals(cnd_true[0].eval(), 2) + + cnd_false = cond_v2.cond_v2(False, true_fn, false_fn) + self.assertEquals(cnd_false[0].eval(), 6) + + v4 = variables.Variable([3]) + q4 = data_flow_ops.FIFOQueue(1, dtypes.float32) + v5 = variables.Variable([4]) + q5 = data_flow_ops.FIFOQueue(1, dtypes.float32) + + self.assertEqual(compat.as_bytes("l1"), container(v4)) + self.assertEqual(compat.as_bytes("l1"), container(q4.queue_ref)) + self.assertEqual(compat.as_bytes(""), container(v5)) + self.assertEqual(compat.as_bytes(""), container(q5.queue_ref)) + + +class CondV2ColocationGroupAndDeviceTest(test.TestCase): + + def testColocateWithBeforeCond(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + + a = constant_op.constant([2.0], name="a") + b = constant_op.constant([2.0], name="b") + + def fn(): + c = constant_op.constant(3.0) + self.assertEqual([b"loc:@a"], c.op.colocation_groups()) + return c + + with ops.colocate_with(a.op): + self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3) + + def fn2(): + c = constant_op.constant(3.0) + self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups()) + return c + + with ops.colocate_with(a.op): + with ops.colocate_with(b.op): + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + + def testColocateWithInAndOutOfCond(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + + a = constant_op.constant([2.0], name="a") + b = constant_op.constant([2.0], name="b") + + def fn2(): + with ops.colocate_with(b.op): + c = constant_op.constant(3.0) + self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups()) + return c + + with ops.colocate_with(a.op): + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + + d = constant_op.constant([2.0], name="d") + self.assertEqual([b"loc:@a"], d.op.colocation_groups()) + + def testColocateWithInCondGraphPartitioning(self): + with ops.Graph().as_default() as g: + with self.test_session( + graph=g, + config=config_pb2.ConfigProto(device_count={"CPU": 2}) + ) as sess: + + with ops.device("/device:CPU:0"): + a = constant_op.constant([2.0], name="a") + with ops.device("/device:CPU:1"): + b = constant_op.constant([2.0], name="b") + + def fn(): + with ops.colocate_with(b.op): + c = math_ops.add(a, a, name="c") + return c + out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0] + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + run_metadata = config_pb2.RunMetadata() + sess.run(out_cond_2, options=run_options, run_metadata=run_metadata) + + # We expect there to be two partitions because of the + # colocate_with. We are only running the cond, which has a data + # dependency on `a` but not on `b`. So, without the colocate_with + # we would expect execution on just one device. + self.assertTrue(len(run_metadata.partition_graphs) >= 2) + + def testDeviceBeforeCond(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + def fn(): + c = constant_op.constant(3.0) + self.assertEqual("/device:CPU:0", c.op.device) + return c + + with ops.device("/device:CPU:0"): + self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3) + + def fn2(): + c = constant_op.constant(3.0) + self.assertEqual("/device:GPU:0", c.op.device) + return c + + with ops.device("/device:GPU:0"): + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + + def testDeviceInAndOutOfCond(self): + with ops.Graph().as_default() as g: + with self.test_session(graph=g): + def fn2(): + with ops.device("/device:GPU:0"): + c = constant_op.constant(3.0) + self.assertEqual("/device:GPU:0", c.op.device) + return c + + with ops.device("/device:CPU:0"): + self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3) + + d = constant_op.constant(4.0) + self.assertEqual("/device:CPU:0", d.op.device) + + def testDeviceInCondGraphPartitioning(self): + with ops.Graph().as_default() as g: + with self.test_session( + graph=g, + config=config_pb2.ConfigProto(device_count={"CPU": 2}) + ) as sess: + + def fn(): + with ops.device("/device:CPU:1"): + c = math_ops.add(a, a, name="c") + return c + + with ops.device("/device:CPU:0"): + a = constant_op.constant([2.0], name="a") + out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0] + + run_options = config_pb2.RunOptions(output_partition_graphs=True) + run_metadata = config_pb2.RunMetadata() + sess.run(out_cond_2, options=run_options, run_metadata=run_metadata) + + self.assertTrue(len(run_metadata.partition_graphs) >= 2) + + +def _cond(pred, true_fn, false_fn, name): + if _is_old_cond(): + return control_flow_ops.cond(pred, true_fn, false_fn, name=name) + else: + return cond_v2.cond_v2(pred, true_fn, false_fn, name=name) + + +def _is_old_cond(): + return isinstance(ops.get_default_graph()._get_control_flow_context(), + control_flow_ops.CondContext) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/kernel_tests/confusion_matrix_test.py b/tensorflow/python/kernel_tests/confusion_matrix_test.py index 79e419867d70071280b7c88b6bfa820b935b24cd..ae6875340e776fc6808be3f4afeb59644245c886 100644 --- a/tensorflow/python/kernel_tests/confusion_matrix_test.py +++ b/tensorflow/python/kernel_tests/confusion_matrix_test.py @@ -34,7 +34,7 @@ from tensorflow.python.platform import test class ConfusionMatrixTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testExample(self): """This is a test of the example provided in pydoc.""" with self.test_session(): diff --git a/tensorflow/python/kernel_tests/constant_op_eager_test.py b/tensorflow/python/kernel_tests/constant_op_eager_test.py index 8e9d75667d49bf9e377ccb9290a3a91786b5a1cb..a0d5557b925162b254e34e9fc0971393ec119059 100644 --- a/tensorflow/python/kernel_tests/constant_op_eager_test.py +++ b/tensorflow/python/kernel_tests/constant_op_eager_test.py @@ -32,6 +32,9 @@ from tensorflow.python.util import compat # TODO(josh11b): add tests with lists/tuples, Shape. +# TODO(ashankar): Collapse with tests in constant_op_test.py and use something +# like the test_util.run_in_graph_and_eager_modes decorator to confirm +# equivalence between graph and eager execution. class ConstantTest(test.TestCase): def _testCpu(self, x): @@ -280,6 +283,34 @@ class ConstantTest(test.TestCase): with self.assertRaisesRegexp(ValueError, None): constant_op.constant([[1, 2], [3], [4, 5]]) + # TODO(ashankar): This test fails with graph construction since + # tensor_util.make_tensor_proto (invoked from constant_op.constant) + # does not handle iterables (it relies on numpy conversion). + # For consistency, should graph construction handle Python objects + # that implement the sequence protocol (but not numpy conversion), + # or should eager execution fail on such sequences? + def testCustomSequence(self): + + # This is inspired by how many objects in pandas are implemented: + # - They implement the Python sequence protocol + # - But may raise a KeyError on __getitem__(self, 0) + # See https://github.com/tensorflow/tensorflow/issues/20347 + class MySeq(object): + + def __getitem__(self, key): + if key != 1 and key != 3: + raise KeyError(key) + return key + + def __len__(self): + return 2 + + def __iter__(self): + l = list([1, 3]) + return l.__iter__() + + self.assertAllEqual([1, 3], self.evaluate(constant_op.constant(MySeq()))) + class AsTensorTest(test.TestCase): diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 68873df97ea2a632d98de4936a20a1f81bce93e9..1a29d0816df71c45d282c87dbb08de388ca0a621 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -647,7 +647,8 @@ class ControlFlowTest(test.TestCase): # feeding into the fill is dominated by a Switch. zero = graph.get_operation_by_name("gradients/zeros/Const") self.assertEqual(len(zero.control_inputs), 1) - self.assertEqual(zero.control_inputs[0].type, "Switch") + self.assertEqual(zero.control_inputs[0].type, "Identity") + self.assertEqual(zero.control_inputs[0].inputs[0].op.type, "Switch") def testCondGrad_2(self): with self.test_session(): @@ -734,11 +735,11 @@ class ControlFlowTest(test.TestCase): def body_fn(i): with ops.control_dependencies([increment]): - return i + i + return i + 1 - result = control_flow_ops.while_loop(cond=lambda i: i < 1, + result = control_flow_ops.while_loop(cond=lambda i: i < 2, body=body_fn, loop_vars=[1]) - result.eval() + self.assertAllEqual(result.eval(), 2) self.assertAllEqual(v.eval(), 1.0) def testWhileExternalControlDependenciesNoInput(self): diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index 450428707dab0d5fb88d180b8e54bb2c56958d8d..00de94f0041294c7f6b183c4caf5f92bfe1c25dd 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -345,7 +345,7 @@ class Conv2DTest(test.TestCase): self.assertAllClose(expected, np.ravel(value), atol=tol, rtol=tol) self.assertShapeEqual(value, conv) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D1x1Filter(self): expected_output = [ 30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0, @@ -358,7 +358,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Filter2x1Dilation(self): self._VerifyDilatedConvValues( tensor_in_sizes=[1, 4, 4, 1], @@ -367,7 +367,7 @@ class Conv2DTest(test.TestCase): dilations=[2, 1], padding="VALID") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DEmpty(self): expected_output = [] self._VerifyValues( @@ -377,7 +377,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DEmptyDilation(self): self._VerifyDilatedConvValues( tensor_in_sizes=[0, 2, 3, 3], @@ -386,7 +386,7 @@ class Conv2DTest(test.TestCase): dilations=[2, 1], padding="VALID") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Filter(self): # The outputs are computed using third_party/py/IPython/notebook. expected_output = [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0] @@ -397,7 +397,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2FilterDilation(self): self._VerifyDilatedConvValues( tensor_in_sizes=[1, 2, 3, 3], @@ -406,7 +406,7 @@ class Conv2DTest(test.TestCase): dilations=[1, 2], padding="VALID") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D1x2Filter(self): # The outputs are computed using third_party/py/IPython/notebook. expected_output = [ @@ -420,7 +420,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D1x2FilterDilation(self): self._VerifyDilatedConvValues( tensor_in_sizes=[1, 2, 3, 3], @@ -429,7 +429,7 @@ class Conv2DTest(test.TestCase): dilations=[2, 1], padding="VALID") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2FilterStride2(self): expected_output = [2271.0, 2367.0, 2463.0] self._VerifyValues( @@ -439,7 +439,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2FilterStride2Same(self): expected_output = [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0] self._VerifyValues( @@ -449,7 +449,7 @@ class Conv2DTest(test.TestCase): padding="SAME", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2FilterStride1x2(self): expected_output = [58.0, 78.0, 98.0, 118.0, 138.0, 158.0] self._VerifyValues( @@ -459,7 +459,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DKernelSmallerThanStrideValid(self): expected_output = [65, 95, 275, 305] self._VerifyValues( @@ -469,7 +469,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DKernelSmallerThanStrideSame(self): self._VerifyValues( tensor_in_sizes=[1, 3, 3, 1], @@ -492,7 +492,7 @@ class Conv2DTest(test.TestCase): padding="SAME", expected=[44, 28, 41, 16]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DKernelSizeMatchesInputSize(self): self._VerifyValues( tensor_in_sizes=[1, 2, 2, 1], @@ -501,7 +501,7 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=[50, 60]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DKernelSizeMatchesInputSizeDilation(self): self._VerifyDilatedConvValues( tensor_in_sizes=[1, 3, 3, 1], @@ -587,9 +587,9 @@ class Conv2DTest(test.TestCase): values.append(_GetVal(data_format, use_gpu)) for i in range(1, len(values)): - self.assertAllClose(values[0], values[i], rtol=1e-4, atol=1e-4) + self.assertAllClose(values[0], values[i], rtol=1e-2, atol=1e-2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Depth1ValidBackpropInput(self): expected_output = [1.0, 4.0, 4.0, 3.0, 10.0, 8.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -604,7 +604,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-5) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DEmptyBackpropInput(self): expected_output = [] for (data_format, use_gpu) in GetTestConfigs(): @@ -619,7 +619,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-5) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Depth3ValidBackpropInput(self): expected_output = [ 14.0, 32.0, 50.0, 100.0, 163.0, 226.0, 167.0, 212.0, 257.0, 122.0, @@ -639,7 +639,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-4) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Depth3ValidBackpropInputStride1x2(self): expected_output = [ 1.0, 2.0, 2.0, 4.0, 3.0, 6.0, 7.0, 12.0, 11.0, 18.0, 15.0, 24.0, 12.0, @@ -657,7 +657,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-5) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DStrideTwoFilterOneSameBackpropInput(self): expected_output = [ 1.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 4.0, 0.0, 0.0, 0.0, @@ -675,7 +675,7 @@ class Conv2DTest(test.TestCase): use_gpu=use_gpu, err=1e-5) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DKernelSizeMatchesInputSizeBackpropInput(self): expected_output = [5.0, 11.0, 17.0, 23.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -759,7 +759,7 @@ class Conv2DTest(test.TestCase): for i in range(1, len(values)): self.assertAllClose(values[0], values[i], rtol=1e-4, atol=1e-4) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Depth1ValidBackpropFilter(self): expected = [5.0, 8.0, 14.0, 17.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -773,7 +773,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DEmptyBackpropFilter(self): expected = [] for (data_format, use_gpu) in GetTestConfigs(): @@ -787,7 +787,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DBackpropFilterWithEmptyInput(self): expected = [0, 0, 0, 0] for (data_format, use_gpu) in GetTestConfigs(): @@ -801,7 +801,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Depth3ValidBackpropFilter(self): expected = [ 17.0, 22.0, 27.0, 22.0, 29.0, 36.0, 27.0, 36.0, 45.0, 32.0, 43.0, 54.0, @@ -820,7 +820,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Depth3ValidBackpropFilterStride1x2(self): expected = [161.0, 182.0, 287.0, 308.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -834,7 +834,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DStrideTwoFilterOneSameBackpropFilter(self): expected_output = [78.] for (data_format, use_gpu) in GetTestConfigs(): @@ -848,7 +848,7 @@ class Conv2DTest(test.TestCase): data_format=data_format, use_gpu=use_gpu) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConv2DKernelSizeMatchesInputSizeBackpropFilter(self): expected_output = [1.0, 2.0, 2.0, 4.0, 3.0, 6.0, 4.0, 8.0] for (data_format, use_gpu) in GetTestConfigs(): @@ -1706,7 +1706,7 @@ class SeparableConv2DTest(test.TestCase): def testSeparableConv2D(self): self._testSeparableConv2D("NHWC") - def testSeparableConv2DNCHW(self): + def disabledtestSeparableConv2DNCHW(self): if not test.is_gpu_available(): return self._testSeparableConv2D("NCHW") @@ -1897,19 +1897,19 @@ if __name__ == "__main__": for index, (input_size_, filter_size_, output_size_, stride_, padding_) in enumerate(GetShrunkInceptionShapes()): setattr(Conv2DTest, "testInceptionFwd_" + str(index), - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionFwdTest(input_size_, filter_size_, stride_, padding_))) setattr( Conv2DTest, "testInceptionFwdDilatedConv_" + str(index), - test_util.run_in_graph_and_eager_modes()(GetInceptionFwdDilatedConvTest( + test_util.run_in_graph_and_eager_modes(GetInceptionFwdDilatedConvTest( input_size_, filter_size_, stride_, padding_))) setattr(Conv2DTest, "testInceptionBackInput_" + str(index), - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionBackInputTest(input_size_, filter_size_, output_size_, stride_, padding_))) setattr(Conv2DTest, "testInceptionBackFilter_" + str(index), - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionBackFilterTest(input_size_, filter_size_, output_size_, [stride_, stride_], padding_))) @@ -1924,17 +1924,17 @@ if __name__ == "__main__": fshape = [1, 1, 1, 256] oshape = [1, 400, 400, 256] setattr(Conv2DTest, "testInceptionFwd_No_Winograd_Nonfused", - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True))) setattr(Conv2DTest, "testInceptionFwdDilatedConv_No_Winograd_Nonfused", - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionFwdDilatedConvTest(ishape, fshape, 1, "SAME"))) setattr(Conv2DTest, "testInceptionBackInput_No_Winograd_Nonfused", - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionBackInputTest(ishape, fshape, oshape, 1, "SAME", gpu_only=True))) setattr(Conv2DTest, "testInceptionBackFilter_No_Winograd_Nonfused", - test_util.run_in_graph_and_eager_modes()( + test_util.run_in_graph_and_eager_modes( GetInceptionBackFilterTest(ishape, fshape, oshape, [1, 1], "SAME", gpu_only=True))) test.main() diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 8a3e64b174d9572ce11981a9d8d0e71cd9a336bc..b61232cdedecacf0cc0f9b1661486a52afc86c2e 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -96,7 +96,8 @@ class UnaryOpTest(test.TestCase): np_ans = np_func(x) with self.test_session(use_gpu=False): inx = ops.convert_to_tensor(x) - if x.dtype in (np.float32, np.float64): + if x.dtype in (np.float32, np.float64, + dtypes_lib.bfloat16.as_numpy_dtype): y = 1.1 * tf_func(inx) np_ans *= 1.1 else: @@ -105,6 +106,8 @@ class UnaryOpTest(test.TestCase): self.assertShapeEqual(np_ans, y) if x.dtype == np.float16: self.assertAllClose(np_ans, tf_cpu, rtol=1e-3, atol=1e-3) + elif x.dtype == dtypes_lib.bfloat16.as_numpy_dtype: + self.assertAllClose(np_ans, tf_cpu, rtol=1e-2, atol=1e-2) else: self.assertAllClose(np_ans, tf_cpu) @@ -668,12 +671,11 @@ class BinaryOpTest(test.TestCase): self._compareCpu(x, y, np_func, tf_func, also_compare_variables) if x.dtype in (np.float16, np.float32, np.float64, np.complex64, np.complex128): - if tf_func not in (_FLOORDIV, math_ops.floordiv, math_ops.igamma, - math_ops.igammac, math_ops.zeta, math_ops.polygamma): + if tf_func not in (_FLOORDIV, math_ops.floordiv, math_ops.zeta, + math_ops.polygamma): self._compareGradientX(x, y, np_func, tf_func) self._compareGradientY(x, y, np_func, tf_func) - if tf_func in (math_ops.igamma, math_ops.igammac, math_ops.zeta, - math_ops.polygamma): + if tf_func in (math_ops.zeta, math_ops.polygamma): # These methods only support gradients in the second parameter self._compareGradientY(x, y, np_func, tf_func) self._compareGpu(x, y, np_func, tf_func) diff --git a/tensorflow/python/kernel_tests/dct_ops_test.py b/tensorflow/python/kernel_tests/dct_ops_test.py index 93b2ff4561bcc8fd13855cde444c4b6237d7949b..97d7e2d8f90a620b693e2c81adc616d399e13bd6 100644 --- a/tensorflow/python/kernel_tests/dct_ops_test.py +++ b/tensorflow/python/kernel_tests/dct_ops_test.py @@ -40,50 +40,92 @@ def try_import(name): # pylint: disable=invalid-name fftpack = try_import("scipy.fftpack") +def _np_dct2(signals, norm=None): + """Computes the DCT-II manually with NumPy.""" + # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1 + dct_size = signals.shape[-1] + dct = np.zeros_like(signals) + for k in range(dct_size): + phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size) + dct[..., k] = np.sum(signals * phi, axis=-1) + # SciPy's `dct` has a scaling factor of 2.0 which we follow. + # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src + if norm == "ortho": + # The orthonormal scaling includes a factor of 0.5 which we combine with + # the overall scaling of 2.0 to cancel. + dct[..., 0] *= np.sqrt(1.0 / dct_size) + dct[..., 1:] *= np.sqrt(2.0 / dct_size) + else: + dct *= 2.0 + return dct + + +def _np_dct3(signals, norm=None): + """Computes the DCT-III manually with NumPy.""" + # SciPy's `dct` has a scaling factor of 2.0 which we follow. + # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src + dct_size = signals.shape[-1] + signals = np.array(signals) # make a copy so we can modify + if norm == "ortho": + signals[..., 0] *= np.sqrt(4.0 / dct_size) + signals[..., 1:] *= np.sqrt(2.0 / dct_size) + else: + signals *= 2.0 + dct = np.zeros_like(signals) + # X_k = 0.5 * x_0 + + # sum_{n=1}^{N-1} x_n * cos(\frac{pi}{N} * n * (k + 0.5)) k=0,...,N-1 + half_x0 = 0.5 * signals[..., 0] + for k in range(dct_size): + phi = np.cos(np.pi * np.arange(1, dct_size) * (k + 0.5) / dct_size) + dct[..., k] = half_x0 + np.sum(signals[..., 1:] * phi, axis=-1) + return dct + + +NP_DCT = {2: _np_dct2, 3: _np_dct3} +NP_IDCT = {2: _np_dct3, 3: _np_dct2} + + class DCTOpsTest(test.TestCase): - def _np_dct2(self, signals, norm=None): - """Computes the DCT-II manually with NumPy.""" - # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1 - dct_size = signals.shape[-1] - dct = np.zeros_like(signals) - for k in range(dct_size): - phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size) - dct[..., k] = np.sum(signals * phi, axis=-1) - # SciPy's `dct` has a scaling factor of 2.0 which we follow. - # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src - if norm == "ortho": - # The orthonormal scaling includes a factor of 0.5 which we combine with - # the overall scaling of 2.0 to cancel. - dct[..., 0] *= np.sqrt(1.0 / dct_size) - dct[..., 1:] *= np.sqrt(2.0 / dct_size) - else: - dct *= 2.0 - return dct - - def _compare(self, signals, norm, atol=5e-4, rtol=5e-4): - """Compares the DCT to SciPy (if available) and a NumPy implementation.""" - np_dct = self._np_dct2(signals, norm) - tf_dct = spectral_ops.dct(signals, type=2, norm=norm).eval() + def _compare(self, signals, norm, dct_type, atol=5e-4, rtol=5e-4): + """Compares (I)DCT to SciPy (if available) and a NumPy implementation.""" + np_dct = NP_DCT[dct_type](signals, norm) + tf_dct = spectral_ops.dct(signals, type=dct_type, norm=norm).eval() self.assertAllClose(np_dct, tf_dct, atol=atol, rtol=rtol) + np_idct = NP_IDCT[dct_type](signals, norm) + tf_idct = spectral_ops.idct(signals, type=dct_type, norm=norm).eval() + self.assertAllClose(np_idct, tf_idct, atol=atol, rtol=rtol) if fftpack: - scipy_dct = fftpack.dct(signals, type=2, norm=norm) + scipy_dct = fftpack.dct(signals, type=dct_type, norm=norm) self.assertAllClose(scipy_dct, tf_dct, atol=atol, rtol=rtol) + scipy_idct = fftpack.idct(signals, type=dct_type, norm=norm) + self.assertAllClose(scipy_idct, tf_idct, atol=atol, rtol=rtol) + # Verify inverse(forward(s)) == s, up to a normalization factor. + tf_idct_dct = spectral_ops.idct( + tf_dct, type=dct_type, norm=norm).eval() + tf_dct_idct = spectral_ops.dct( + tf_idct, type=dct_type, norm=norm).eval() + if norm is None: + tf_idct_dct *= 0.5 / signals.shape[-1] + tf_dct_idct *= 0.5 / signals.shape[-1] + self.assertAllClose(signals, tf_idct_dct, atol=atol, rtol=rtol) + self.assertAllClose(signals, tf_dct_idct, atol=atol, rtol=rtol) def test_random(self): """Test randomly generated batches of data.""" with spectral_ops_test_util.fft_kernel_label_map(): with self.test_session(use_gpu=True): - for shape in ([2, 20], [1], [2], [3], [10], [2, 20], [2, 3, 25]): + for shape in ([1], [2], [3], [10], [2, 20], [2, 3, 25]): signals = np.random.rand(*shape).astype(np.float32) for norm in (None, "ortho"): - self._compare(signals, norm) + self._compare(signals, norm, 2) + self._compare(signals, norm, 3) def test_error(self): signals = np.random.rand(10) # Unsupported type. with self.assertRaises(ValueError): - spectral_ops.dct(signals, type=3) + spectral_ops.dct(signals, type=1) # Unknown normalization. with self.assertRaises(ValueError): spectral_ops.dct(signals, norm="bad") diff --git a/tensorflow/python/kernel_tests/decode_jpeg_op_test.py b/tensorflow/python/kernel_tests/decode_jpeg_op_test.py index 510daf79dc4252c3e2943e2ba23c1012370bf456..66b3e0f22fd2ab07311895da5df5448ee4e6e6f0 100644 --- a/tensorflow/python/kernel_tests/decode_jpeg_op_test.py +++ b/tensorflow/python/kernel_tests/decode_jpeg_op_test.py @@ -110,7 +110,8 @@ class DecodeJpegBenchmark(test.Benchmark): start_time = time.time() for _ in xrange(num_iters): sess.run(r) - return time.time() - start_time + end_time = time.time() + return end_time - start_time def benchmarkDecodeJpegSmall(self): """Evaluate single DecodeImageOp for small size image.""" diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py index 5e223b18281ed9c06a3f72a16b6d22290851f37b..58845552db5e22dd4e5e9a6de09de023c58be512 100644 --- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py +++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py @@ -90,7 +90,7 @@ def CheckGradConfigsToTest(): class DepthwiseConv2DTest(test.TestCase): # This is testing that depthwise_conv2d and depthwise_conv2d_native - # produce the same results. It also tests that NCHW and NWHC + # produce the same results. It also tests that NCHW and NHWC # formats agree, by comparing the depthwise_conv2d_native with # 'NCHW' format (with transposition) matches the 'NHWC' format using # the higher level interface. @@ -142,7 +142,7 @@ class DepthwiseConv2DTest(test.TestCase): native_t1 = t1 strides = [1, stride, stride, 1] if data_format == "NCHW": - # Transpose from NWHC input to NCHW + # Transpose from NHWC input to NCHW # Ex. [4, 5, 5, 48] to [4, 48, 5, 5] native_t1 = array_ops.transpose(t1, [0, 3, 1, 2]) strides = [1, 1, stride, stride] @@ -356,7 +356,7 @@ class DepthwiseConv2DTest(test.TestCase): with self.test_session(graph=graph, use_gpu=use_gpu) as sess: tolerance = { dtypes.float16: 4e-0, - dtypes.float32: 5e-4, + dtypes.float32: 8e-4, dtypes.float64: 1e-12, }[data_type] @@ -368,7 +368,7 @@ class DepthwiseConv2DTest(test.TestCase): native_input = input_tensor strides = [1, stride, stride, 1] if data_format == "NCHW": - # Transpose from NWHC input to NCHW + # Transpose from NHWC input to NCHW # Ex. [4, 5, 5, 48] to [4, 48, 5, 5] native_input = array_ops.transpose(input_tensor, [0, 3, 1, 2]) input_shape = [ diff --git a/tensorflow/python/kernel_tests/distributions/BUILD b/tensorflow/python/kernel_tests/distributions/BUILD index cf2e8832fd5225e4d4be617a97b355bb410084c2..14532965d8c2c62139b3cd922acb9f90c0691d53 100644 --- a/tensorflow/python/kernel_tests/distributions/BUILD +++ b/tensorflow/python/kernel_tests/distributions/BUILD @@ -93,6 +93,7 @@ cuda_py_test( size = "small", srcs = ["categorical_test.py"], additional_deps = [ + "@absl_py//absl/testing:parameterized", "//tensorflow/python/ops/distributions", "//third_party/py/numpy", "//tensorflow/python:array_ops", @@ -134,6 +135,10 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], + tags = [ + "noguitar", # b/110489471 + "notap", # b/110489471 + ], ) cuda_py_test( diff --git a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py index 095d1cde1530f15fd2a7ff4cb7f56424f276be5a..9ad77a54cbc730296508e4fe74248d2413029151 100644 --- a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py +++ b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py @@ -22,6 +22,7 @@ import importlib import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util @@ -57,14 +58,14 @@ def entropy(p): class BernoulliTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testP(self): p = [0.2, 0.4] dist = bernoulli.Bernoulli(probs=p) with self.test_session(): self.assertAllClose(p, self.evaluate(dist.probs)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLogits(self): logits = [-42., 42.] dist = bernoulli.Bernoulli(logits=logits) @@ -82,7 +83,7 @@ class BernoulliTest(test.TestCase): with self.test_session(): self.assertAllClose(special.logit(p), self.evaluate(dist.logits)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInvalidP(self): invalid_ps = [1.01, 2.] for p in invalid_ps: @@ -104,7 +105,7 @@ class BernoulliTest(test.TestCase): dist = bernoulli.Bernoulli(probs=p) self.assertEqual(p, self.evaluate(dist.probs)) # Should not fail - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testShapes(self): with self.test_session(): for batch_shape in ([], [1], [2, 3, 4]): @@ -115,7 +116,7 @@ class BernoulliTest(test.TestCase): self.assertAllEqual([], dist.event_shape.as_list()) self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDtype(self): dist = make_bernoulli([]) self.assertEqual(dist.dtype, dtypes.int32) @@ -133,7 +134,7 @@ class BernoulliTest(test.TestCase): self.assertEqual(dist64.dtype, dist64.sample(5).dtype) self.assertEqual(dist64.dtype, dist64.mode().dtype) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def _testPmf(self, **kwargs): dist = bernoulli.Bernoulli(**kwargs) with self.test_session(): @@ -174,7 +175,7 @@ class BernoulliTest(test.TestCase): p: [0.2, 0.3, 0.4] }), [[0.2, 0.7, 0.4]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPmfInvalid(self): p = [0.1, 0.2, 0.7] with self.test_session(): @@ -184,7 +185,7 @@ class BernoulliTest(test.TestCase): with self.assertRaisesOpError("Elements cannot exceed 1."): self.evaluate(dist.prob([2, 0, 1])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPmfWithP(self): p = [[0.2, 0.4], [0.3, 0.6]] self._testPmf(probs=p) @@ -226,21 +227,21 @@ class BernoulliTest(test.TestCase): dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]]) self.assertEqual((2, 1), dist.log_prob(1).get_shape()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBoundaryConditions(self): with self.test_session(): dist = bernoulli.Bernoulli(probs=1.0) self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0))) self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEntropyNoBatch(self): p = 0.2 dist = bernoulli.Bernoulli(probs=p) with self.test_session(): self.assertAllClose(self.evaluate(dist.entropy()), entropy(p)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEntropyWithBatch(self): p = [[0.1, 0.7], [0.2, 0.6]] dist = bernoulli.Bernoulli(probs=p, validate_args=False) @@ -250,7 +251,7 @@ class BernoulliTest(test.TestCase): [[entropy(0.1), entropy(0.7)], [entropy(0.2), entropy(0.6)]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSampleN(self): with self.test_session(): p = [0.2, 0.6] @@ -272,6 +273,16 @@ class BernoulliTest(test.TestCase): dist = bernoulli.Bernoulli(np.log([.2, .4])) self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list()) + @test_util.run_in_graph_and_eager_modes + def testNotReparameterized(self): + p = constant_op.constant([0.2, 0.6]) + with backprop.GradientTape() as tape: + tape.watch(p) + dist = bernoulli.Bernoulli(probs=p) + samples = dist.sample(100) + grad_p = tape.gradient(samples, p) + self.assertIsNone(grad_p) + def testSampleActsLikeSampleN(self): with self.test_session() as sess: p = [0.2, 0.6] @@ -282,18 +293,18 @@ class BernoulliTest(test.TestCase): self.evaluate(dist.sample(n, seed)), self.evaluate(dist.sample(n, seed))) n = array_ops.placeholder(dtypes.int32) - sample, sample = sess.run([dist.sample(n, seed), dist.sample(n, seed)], - feed_dict={n: 1000}) - self.assertAllEqual(sample, sample) + sample1, sample2 = sess.run([dist.sample(n, seed), dist.sample(n, seed)], + feed_dict={n: 1000}) + self.assertAllEqual(sample1, sample2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMean(self): with self.test_session(): p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32) dist = bernoulli.Bernoulli(probs=p) self.assertAllEqual(self.evaluate(dist.mean()), p) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarianceAndStd(self): var = lambda p: p * (1. - p) with self.test_session(): @@ -310,7 +321,7 @@ class BernoulliTest(test.TestCase): [np.sqrt(var(0.5)), np.sqrt(var(0.4))]], dtype=np.float32)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBernoulliBernoulliKL(self): batch_size = 6 a_p = np.array([0.5] * batch_size, dtype=np.float32) diff --git a/tensorflow/python/kernel_tests/distributions/beta_test.py b/tensorflow/python/kernel_tests/distributions/beta_test.py index 4bc8303ebb6939f3f8e2637120b6510c225c2f12..36f3ffc333f74e3f6e672b6ba1591bf8de08a010 100644 --- a/tensorflow/python/kernel_tests/distributions/beta_test.py +++ b/tensorflow/python/kernel_tests/distributions/beta_test.py @@ -21,6 +21,7 @@ import importlib import numpy as np from tensorflow.python.client import session +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape @@ -282,6 +283,18 @@ class BetaTest(test.TestCase): self.assertAllClose( np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1) + def testBetaFullyReparameterized(self): + a = constant_op.constant(1.0) + b = constant_op.constant(2.0) + with backprop.GradientTape() as tape: + tape.watch(a) + tape.watch(b) + beta = beta_lib.Beta(a, b) + samples = beta.sample(100) + grad_a, grad_b = tape.gradient(samples, [a, b]) + self.assertIsNotNone(grad_a) + self.assertIsNotNone(grad_b) + # Test that sampling with the same seed twice gives the same results. def testBetaSampleMultipleTimes(self): with self.test_session(): diff --git a/tensorflow/python/kernel_tests/distributions/categorical_test.py b/tensorflow/python/kernel_tests/distributions/categorical_test.py index ca2358fe99934e110ba743c6085d1f25ff0f5e5e..d8939433ce68ffa561e8e2200826f88dbe283ac2 100644 --- a/tensorflow/python/kernel_tests/distributions/categorical_test.py +++ b/tensorflow/python/kernel_tests/distributions/categorical_test.py @@ -18,8 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util @@ -40,7 +42,7 @@ def make_categorical(batch_shape, num_classes, dtype=dtypes.int32): return categorical.Categorical(logits, dtype=dtype) -class CategoricalTest(test.TestCase): +class CategoricalTest(test.TestCase, parameterized.TestCase): def testP(self): p = [0.2, 0.8] @@ -131,7 +133,7 @@ class CategoricalTest(test.TestCase): with self.test_session(): self.assertAllClose(dist.prob(0).eval(), 0.2) - def testCDFWithDynamicEventShape(self): + def testCDFWithDynamicEventShapeKnownNdims(self): """Test that dynamically-sized events with unknown shape work.""" batch_size = 2 histograms = array_ops.placeholder(dtype=dtypes.float32, @@ -167,6 +169,21 @@ class CategoricalTest(test.TestCase): self.assertAllClose(actual_cdf_one, expected_cdf_one) self.assertAllClose(actual_cdf_two, expected_cdf_two) + @parameterized.named_parameters( + ("test1", [0, 1], [[0.5, 0.3, 0.2], [1.0, 0.0, 0.0]], [0.0, 1.0]), + ("test2", [2, 5], [[0.9, 0.0, 0.0, 0.0, 0.0, 0.1], + [0.15, 0.2, 0.05, 0.35, 0.13, 0.12]], [0.9, 0.88])) + def testCDFWithDynamicEventShapeUnknownNdims( + self, events, histograms, expected_cdf): + """Test that dynamically-sized events with unknown shape work.""" + event_ph = array_ops.placeholder_with_default(events, shape=None) + histograms_ph = array_ops.placeholder_with_default(histograms, shape=None) + dist = categorical.Categorical(probs=histograms_ph) + cdf_op = dist.cdf(event_ph) + + actual_cdf = self.evaluate(cdf_op) + self.assertAllClose(actual_cdf, expected_cdf) + def testCDFWithBatch(self): histograms = [[0.1, 0.2, 0.3, 0.25, 0.15], [0.0, 0.75, 0.2, 0.05, 0.0]] @@ -360,6 +377,15 @@ class CategoricalTest(test.TestCase): self.assertAllClose( [0.4**2 + 0.6**2], [prob_val[:, :, :, 1].mean()], atol=1e-2) + def testNotReparameterized(self): + p = constant_op.constant([0.3, 0.3, 0.4]) + with backprop.GradientTape() as tape: + tape.watch(p) + dist = categorical.Categorical(p) + samples = dist.sample(100) + grad_p = tape.gradient(samples, p) + self.assertIsNone(grad_p) + def testLogPMFBroadcasting(self): with self.test_session(): # 1 x 2 x 2 diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py index 7922fb0606c6f4b475b25da716d5f9a169e213b5..1b9edcc85a7581de1cb1bd93fdbb9d47b8d1b84a 100644 --- a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py +++ b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py @@ -17,6 +17,9 @@ from __future__ import division from __future__ import print_function import numpy as np + +from tensorflow.python.eager import backprop +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -250,10 +253,10 @@ class DirichletMultinomialTest(test.TestCase): dist.variance(), dist.stddev(), ]) - self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.04) - self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.05) - self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.05) - self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02) + self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.) + self.assertAllClose(sample_cov_, analytic_cov, atol=0.05, rtol=0.) + self.assertAllClose(sample_var_, analytic_var, atol=0.05, rtol=0.) + self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.) def testCovariance(self): # Shape [2] @@ -442,7 +445,7 @@ class DirichletMultinomialTest(test.TestCase): dist.covariance(), ]) self.assertAllEqual([4, 3, 2], sample_mean.get_shape()) - self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.15) + self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.20) self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape()) self.assertAllClose( actual_covariance_, sample_covariance_, atol=0., rtol=0.20) @@ -470,10 +473,25 @@ class DirichletMultinomialTest(test.TestCase): dist.covariance(), ]) self.assertAllEqual([4], sample_mean.get_shape()) - self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.05) + self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.20) self.assertAllEqual([4, 4], sample_covariance.get_shape()) self.assertAllClose( - actual_covariance_, sample_covariance_, atol=0., rtol=0.15) + actual_covariance_, sample_covariance_, atol=0., rtol=0.20) + + def testNotReparameterized(self): + total_count = constant_op.constant(5.0) + concentration = constant_op.constant([0.1, 0.1, 0.1]) + with backprop.GradientTape() as tape: + tape.watch(total_count) + tape.watch(concentration) + dist = ds.DirichletMultinomial( + total_count=total_count, + concentration=concentration) + samples = dist.sample(100) + grad_total_count, grad_concentration = tape.gradient( + samples, [total_count, concentration]) + self.assertIsNone(grad_total_count) + self.assertIsNone(grad_concentration) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py index bcec6ef610d0389f4b0f164ff4ab1a1cd1f6d1e5..67ed0447ede39d7f0738c8caf3cc665bcfe5fd0b 100644 --- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py +++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py @@ -20,6 +20,7 @@ import importlib import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util @@ -190,10 +191,10 @@ class DirichletTest(test.TestCase): dist.stddev(), ]) - self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.04) - self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.06) - self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.03) - self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02) + self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.) + self.assertAllClose(sample_cov_, analytic_cov, atol=0.06, rtol=0.) + self.assertAllClose(sample_var_, analytic_var, atol=0.03, rtol=0.) + self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.) def testVariance(self): with self.test_session(): @@ -264,6 +265,15 @@ class DirichletTest(test.TestCase): a=1., b=2.).cdf)[0], 0.01) + def testDirichletFullyReparameterized(self): + alpha = constant_op.constant([1.0, 2.0, 3.0]) + with backprop.GradientTape() as tape: + tape.watch(alpha) + dirichlet = dirichlet_lib.Dirichlet(alpha) + samples = dirichlet.sample(100) + grad_alpha = tape.gradient(samples, alpha) + self.assertIsNotNone(grad_alpha) + def testDirichletDirichletKL(self): conc1 = np.array([[1., 2., 3., 1.5, 2.5, 3.5], [1.5, 2.5, 3.5, 4.5, 5.5, 6.5]]) diff --git a/tensorflow/python/kernel_tests/distributions/exponential_test.py b/tensorflow/python/kernel_tests/distributions/exponential_test.py index ebcd41b0e24ae8093752c84cf5077029f2ac9330..850da3e9697ab5f087761e9988094a3015636c36 100644 --- a/tensorflow/python/kernel_tests/distributions/exponential_test.py +++ b/tensorflow/python/kernel_tests/distributions/exponential_test.py @@ -23,6 +23,7 @@ import importlib import numpy as np from tensorflow.python.client import session +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util from tensorflow.python.ops import nn_ops @@ -163,6 +164,15 @@ class ExponentialTest(test.TestCase): stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01) + def testFullyReparameterized(self): + lam = constant_op.constant([0.1, 1.0]) + with backprop.GradientTape() as tape: + tape.watch(lam) + exponential = exponential_lib.Exponential(rate=lam) + samples = exponential.sample(100) + grad_lam = tape.gradient(samples, lam) + self.assertIsNotNone(grad_lam) + def testExponentialWithSoftplusRate(self): with self.test_session(): lam = [-2.2, -3.4] diff --git a/tensorflow/python/kernel_tests/distributions/gamma_test.py b/tensorflow/python/kernel_tests/distributions/gamma_test.py index 5e4813ac0762d2855d7fbe6754fe1466c29c06c9..297e20264c6d36f5b9098005393302337e3d1315 100644 --- a/tensorflow/python/kernel_tests/distributions/gamma_test.py +++ b/tensorflow/python/kernel_tests/distributions/gamma_test.py @@ -21,9 +21,10 @@ import importlib import numpy as np -from tensorflow.python.client import session +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import gamma as gamma_lib @@ -45,6 +46,7 @@ special = try_import("scipy.special") stats = try_import("scipy.stats") +@test_util.run_all_in_graph_and_eager_modes class GammaTest(test.TestCase): def testGammaShape(self): @@ -53,9 +55,9 @@ class GammaTest(test.TestCase): beta = constant_op.constant(11.0) gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) - self.assertEqual(gamma.batch_shape_tensor().eval(), (5,)) + self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,)) self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5])) - self.assertAllEqual(gamma.event_shape_tensor().eval(), []) + self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), []) self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([])) def testGammaLogPDF(self): @@ -74,8 +76,8 @@ class GammaTest(test.TestCase): if not stats: return expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) - self.assertAllClose(log_pdf.eval(), expected_log_pdf) - self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf)) + self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf) + self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) def testGammaLogPDFMultidimensional(self): with self.test_session(): @@ -87,10 +89,10 @@ class GammaTest(test.TestCase): x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) log_pdf = gamma.log_prob(x) - log_pdf_values = log_pdf.eval() + log_pdf_values = self.evaluate(log_pdf) self.assertEqual(log_pdf.get_shape(), (6, 2)) pdf = gamma.prob(x) - pdf_values = pdf.eval() + pdf_values = self.evaluate(pdf) self.assertEqual(pdf.get_shape(), (6, 2)) if not stats: return @@ -108,10 +110,10 @@ class GammaTest(test.TestCase): x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) log_pdf = gamma.log_prob(x) - log_pdf_values = log_pdf.eval() + log_pdf_values = self.evaluate(log_pdf) self.assertEqual(log_pdf.get_shape(), (6, 2)) pdf = gamma.prob(x) - pdf_values = pdf.eval() + pdf_values = self.evaluate(pdf) self.assertEqual(pdf.get_shape(), (6, 2)) if not stats: @@ -135,7 +137,7 @@ class GammaTest(test.TestCase): if not stats: return expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v) - self.assertAllClose(cdf.eval(), expected_cdf) + self.assertAllClose(self.evaluate(cdf), expected_cdf) def testGammaMean(self): with self.test_session(): @@ -146,7 +148,7 @@ class GammaTest(test.TestCase): if not stats: return expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v) - self.assertAllClose(gamma.mean().eval(), expected_means) + self.assertAllClose(self.evaluate(gamma.mean()), expected_means) def testGammaModeAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self): with self.test_session(): @@ -155,7 +157,7 @@ class GammaTest(test.TestCase): gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) expected_modes = (alpha_v - 1) / beta_v self.assertEqual(gamma.mode().get_shape(), (3,)) - self.assertAllClose(gamma.mode().eval(), expected_modes) + self.assertAllClose(self.evaluate(gamma.mode()), expected_modes) def testGammaModeAllowNanStatsFalseRaisesForUndefinedBatchMembers(self): with self.test_session(): @@ -166,7 +168,7 @@ class GammaTest(test.TestCase): rate=beta_v, allow_nan_stats=False) with self.assertRaisesOpError("x < y"): - gamma.mode().eval() + self.evaluate(gamma.mode()) def testGammaModeAllowNanStatsIsTrueReturnsNaNforUndefinedBatchMembers(self): with self.test_session(): @@ -179,7 +181,7 @@ class GammaTest(test.TestCase): expected_modes = (alpha_v - 1) / beta_v expected_modes[0] = np.nan self.assertEqual(gamma.mode().get_shape(), (3,)) - self.assertAllClose(gamma.mode().eval(), expected_modes) + self.assertAllClose(self.evaluate(gamma.mode()), expected_modes) def testGammaVariance(self): with self.test_session(): @@ -190,7 +192,7 @@ class GammaTest(test.TestCase): if not stats: return expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v) - self.assertAllClose(gamma.variance().eval(), expected_variances) + self.assertAllClose(self.evaluate(gamma.variance()), expected_variances) def testGammaStd(self): with self.test_session(): @@ -201,7 +203,7 @@ class GammaTest(test.TestCase): if not stats: return expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v) - self.assertAllClose(gamma.stddev().eval(), expected_stddev) + self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev) def testGammaEntropy(self): with self.test_session(): @@ -212,10 +214,10 @@ class GammaTest(test.TestCase): if not stats: return expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v) - self.assertAllClose(gamma.entropy().eval(), expected_entropy) + self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy) def testGammaSampleSmallAlpha(self): - with session.Session(): + with self.test_session(): alpha_v = 0.05 beta_v = 1.0 alpha = constant_op.constant(alpha_v) @@ -223,7 +225,7 @@ class GammaTest(test.TestCase): n = 100000 gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) samples = gamma.sample(n, seed=137) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertEqual(samples.get_shape(), (n,)) self.assertEqual(sample_values.shape, (n,)) self.assertTrue(self._kstest(alpha_v, beta_v, sample_values)) @@ -240,7 +242,7 @@ class GammaTest(test.TestCase): atol=.15) def testGammaSample(self): - with session.Session(): + with self.test_session(): alpha_v = 4.0 beta_v = 3.0 alpha = constant_op.constant(alpha_v) @@ -248,7 +250,7 @@ class GammaTest(test.TestCase): n = 100000 gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) samples = gamma.sample(n, seed=137) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertEqual(samples.get_shape(), (n,)) self.assertEqual(sample_values.shape, (n,)) self.assertTrue(self._kstest(alpha_v, beta_v, sample_values)) @@ -264,14 +266,26 @@ class GammaTest(test.TestCase): stats.gamma.var(alpha_v, scale=1 / beta_v), atol=.15) + def testGammaFullyReparameterized(self): + alpha = constant_op.constant(4.0) + beta = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + tape.watch(alpha) + tape.watch(beta) + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) + samples = gamma.sample(100) + grad_alpha, grad_beta = tape.gradient(samples, [alpha, beta]) + self.assertIsNotNone(grad_alpha) + self.assertIsNotNone(grad_beta) + def testGammaSampleMultiDimensional(self): - with session.Session(): + with self.test_session(): alpha_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100 beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1 gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) n = 10000 samples = gamma.sample(n, seed=137) - sample_values = samples.eval() + sample_values = self.evaluate(samples) self.assertEqual(samples.get_shape(), (n, 10, 100)) self.assertEqual(sample_values.shape, (n, 10, 100)) zeros = np.zeros_like(alpha_v + beta_v) # 10 x 100 @@ -283,11 +297,11 @@ class GammaTest(test.TestCase): sample_values.mean(axis=0), stats.gamma.mean( alpha_bc, scale=1 / beta_bc), - rtol=.035) + atol=0., rtol=.05) self.assertAllClose( sample_values.var(axis=0), stats.gamma.var(alpha_bc, scale=1 / beta_bc), - atol=4.5) + atol=10.0, rtol=0.) fails = 0 trials = 0 for ai, a in enumerate(np.reshape(alpha_v, [-1])): @@ -306,12 +320,12 @@ class GammaTest(test.TestCase): return ks < 0.02 def testGammaPdfOfSampleMultiDims(self): - with session.Session() as sess: + with self.test_session(): gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]]) num = 50000 samples = gamma.sample(num, seed=137) pdfs = gamma.prob(samples) - sample_vals, pdf_vals = sess.run([samples, pdfs]) + sample_vals, pdf_vals = self.evaluate([samples, pdfs]) self.assertEqual(samples.get_shape(), (num, 2, 2)) self.assertEqual(pdfs.get_shape(), (num, 2, 2)) self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02) @@ -345,18 +359,18 @@ class GammaTest(test.TestCase): with self.test_session(): alpha_v = constant_op.constant(0.0, name="alpha") beta_v = constant_op.constant(1.0, name="beta") - gamma = gamma_lib.Gamma(concentration=alpha_v, - rate=beta_v, - validate_args=True) - with self.assertRaisesOpError("alpha"): - gamma.mean().eval() + with self.assertRaisesOpError("x > 0"): + gamma = gamma_lib.Gamma(concentration=alpha_v, + rate=beta_v, + validate_args=True) + self.evaluate(gamma.mean()) alpha_v = constant_op.constant(1.0, name="alpha") beta_v = constant_op.constant(0.0, name="beta") - gamma = gamma_lib.Gamma(concentration=alpha_v, - rate=beta_v, - validate_args=True) - with self.assertRaisesOpError("beta"): - gamma.mean().eval() + with self.assertRaisesOpError("x > 0"): + gamma = gamma_lib.Gamma(concentration=alpha_v, + rate=beta_v, + validate_args=True) + self.evaluate(gamma.mean()) def testGammaWithSoftplusConcentrationRate(self): with self.test_session(): @@ -364,10 +378,10 @@ class GammaTest(test.TestCase): beta_v = constant_op.constant([1.0, -3.6], name="beta") gamma = gamma_lib.GammaWithSoftplusConcentrationRate( concentration=alpha_v, rate=beta_v) - self.assertAllEqual(nn_ops.softplus(alpha_v).eval(), - gamma.concentration.eval()) - self.assertAllEqual(nn_ops.softplus(beta_v).eval(), - gamma.rate.eval()) + self.assertAllEqual(self.evaluate(nn_ops.softplus(alpha_v)), + self.evaluate(gamma.concentration)) + self.assertAllEqual(self.evaluate(nn_ops.softplus(beta_v)), + self.evaluate(gamma.rate)) def testGammaGammaKL(self): alpha0 = np.array([3.]) @@ -377,15 +391,15 @@ class GammaTest(test.TestCase): beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.]) # Build graph. - with self.test_session() as sess: + with self.test_session(): g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0) g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1) x = g0.sample(int(1e4), seed=0) kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0) kl_actual = kullback_leibler.kl_divergence(g0, g1) - # Execute graph. - [kl_sample_, kl_actual_] = sess.run([kl_sample, kl_actual]) + # Execute graph. + [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual]) self.assertEqual(beta0.shape, kl_actual.get_shape()) @@ -399,7 +413,7 @@ class GammaTest(test.TestCase): + alpha0 * (beta1 / beta0 - 1.)) self.assertAllClose(kl_expected, kl_actual_, atol=0., rtol=1e-6) - self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-2) + self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-1) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/laplace_test.py b/tensorflow/python/kernel_tests/distributions/laplace_test.py index 918c7f63f2065525338632ba68cb180c7c50dea6..24b243f647e495c47d57f914951263e3ee4ca7a5 100644 --- a/tensorflow/python/kernel_tests/distributions/laplace_test.py +++ b/tensorflow/python/kernel_tests/distributions/laplace_test.py @@ -22,6 +22,7 @@ import importlib import numpy as np from tensorflow.python.client import session +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util @@ -255,6 +256,18 @@ class LaplaceTest(test.TestCase): atol=0.) self.assertTrue(self._kstest(loc_v, scale_v, sample_values)) + def testLaplaceFullyReparameterized(self): + loc = constant_op.constant(4.0) + scale = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + tape.watch(loc) + tape.watch(scale) + laplace = laplace_lib.Laplace(loc=loc, scale=scale) + samples = laplace.sample(100) + grad_loc, grad_scale = tape.gradient(samples, [loc, scale]) + self.assertIsNotNone(grad_loc) + self.assertIsNotNone(grad_scale) + def testLaplaceSampleMultiDimensional(self): with session.Session(): loc_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100 diff --git a/tensorflow/python/kernel_tests/distributions/multinomial_test.py b/tensorflow/python/kernel_tests/distributions/multinomial_test.py index e24e8ade73a7ad762c877214f5ec3ee0848863fe..bfd40ba2b7a5d32e957507b36d44e1198bd3867f 100644 --- a/tensorflow/python/kernel_tests/distributions/multinomial_test.py +++ b/tensorflow/python/kernel_tests/distributions/multinomial_test.py @@ -18,6 +18,8 @@ from __future__ import print_function import numpy as np +from tensorflow.python.eager import backprop +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -310,10 +312,10 @@ class MultinomialTest(test.TestCase): dist.covariance(), ]) self.assertAllEqual([4, 3, 2], sample_mean.get_shape()) - self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.07) + self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.10) self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape()) self.assertAllClose( - actual_covariance_, sample_covariance_, atol=0., rtol=0.10) + actual_covariance_, sample_covariance_, atol=0., rtol=0.20) def testSampleUnbiasedScalarBatch(self): with self.test_session() as sess: @@ -338,10 +340,24 @@ class MultinomialTest(test.TestCase): dist.covariance(), ]) self.assertAllEqual([4], sample_mean.get_shape()) - self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.07) + self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.10) self.assertAllEqual([4, 4], sample_covariance.get_shape()) self.assertAllClose( - actual_covariance_, sample_covariance_, atol=0., rtol=0.10) + actual_covariance_, sample_covariance_, atol=0., rtol=0.20) + + def testNotReparameterized(self): + total_count = constant_op.constant(5.0) + p = constant_op.constant([0.2, 0.6]) + with backprop.GradientTape() as tape: + tape.watch(total_count) + tape.watch(p) + dist = multinomial.Multinomial( + total_count=total_count, + probs=p) + samples = dist.sample(100) + grad_total_count, grad_p = tape.gradient(samples, [total_count, p]) + self.assertIsNone(grad_total_count) + self.assertIsNone(grad_p) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/normal_test.py b/tensorflow/python/kernel_tests/distributions/normal_test.py index d793e03272909cc97543e313041b6ae7f487ae3f..7ff48c0c10f4d2cd18072a22cdcef0fefc530eae 100644 --- a/tensorflow/python/kernel_tests/distributions/normal_test.py +++ b/tensorflow/python/kernel_tests/distributions/normal_test.py @@ -23,6 +23,7 @@ import math import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -77,20 +78,20 @@ class NormalTest(test.TestCase): self.assertEqual(expected, mu_shape) self.assertEqual(expected, sigma_shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testParamShapes(self): sample_shape = [10, 3, 4] self._testParamShapes(sample_shape, sample_shape) self._testParamShapes(constant_op.constant(sample_shape), sample_shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testParamStaticShapes(self): sample_shape = [10, 3, 4] self._testParamStaticShapes(sample_shape, sample_shape) self._testParamStaticShapes( tensor_shape.TensorShape(sample_shape), sample_shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalWithSoftplusScale(self): with self.test_session(): mu = array_ops.zeros((10, 3)) @@ -100,7 +101,7 @@ class NormalTest(test.TestCase): self.assertAllEqual( self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalLogPDF(self): with self.test_session(): batch_size = 6 @@ -134,7 +135,7 @@ class NormalTest(test.TestCase): self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf)) self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalLogPDFMultidimensional(self): with self.test_session(): batch_size = 6 @@ -172,7 +173,7 @@ class NormalTest(test.TestCase): self.assertAllClose(expected_log_pdf, log_pdf_values) self.assertAllClose(np.exp(expected_log_pdf), pdf_values) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalCDF(self): with self.test_session(): batch_size = 50 @@ -194,7 +195,7 @@ class NormalTest(test.TestCase): expected_cdf = stats.norm(mu, sigma).cdf(x) self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalSurvivalFunction(self): with self.test_session(): batch_size = 50 @@ -217,7 +218,7 @@ class NormalTest(test.TestCase): expected_sf = stats.norm(mu, sigma).sf(x) self.assertAllClose(expected_sf, self.evaluate(sf), atol=0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalLogCDF(self): with self.test_session(): batch_size = 50 @@ -239,7 +240,7 @@ class NormalTest(test.TestCase): if not stats: return expected_cdf = stats.norm(mu, sigma).logcdf(x) - self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-5) + self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3) def testFiniteGradientAtDifficultPoints(self): for dtype in [np.float32, np.float64]: @@ -261,7 +262,7 @@ class NormalTest(test.TestCase): self.assertAllFinite(grads[0]) self.assertAllFinite(grads[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalLogSurvivalFunction(self): with self.test_session(): batch_size = 50 @@ -285,7 +286,7 @@ class NormalTest(test.TestCase): expected_sf = stats.norm(mu, sigma).logsf(x) self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalEntropyWithScalarInputs(self): # Scipy.stats.norm cannot deal with the shapes in the other test. with self.test_session(): @@ -307,7 +308,7 @@ class NormalTest(test.TestCase): expected_entropy = stats.norm(mu_v, sigma_v).entropy() self.assertAllClose(expected_entropy, self.evaluate(entropy)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalEntropy(self): with self.test_session(): mu_v = np.array([1.0, 1.0, 1.0]) @@ -328,7 +329,7 @@ class NormalTest(test.TestCase): self.assertAllEqual(normal.batch_shape, entropy.get_shape()) self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalMeanAndMode(self): with self.test_session(): # Mu will be broadcast to [7, 7, 7]. @@ -343,7 +344,7 @@ class NormalTest(test.TestCase): self.assertAllEqual((3,), normal.mode().get_shape()) self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalQuantile(self): with self.test_session(): batch_size = 52 @@ -395,7 +396,7 @@ class NormalTest(test.TestCase): def testQuantileFiniteGradientAtDifficultPointsFloat64(self): self._baseQuantileFiniteGradientAtDifficultPoints(np.float64) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalVariance(self): with self.test_session(): # sigma will be broadcast to [7, 7, 7] @@ -407,7 +408,7 @@ class NormalTest(test.TestCase): self.assertAllEqual((3,), normal.variance().get_shape()) self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalStandardDeviation(self): with self.test_session(): # sigma will be broadcast to [7, 7, 7] @@ -419,7 +420,7 @@ class NormalTest(test.TestCase): self.assertAllEqual((3,), normal.stddev().get_shape()) self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalSample(self): with self.test_session(): mu = constant_op.constant(3.0) @@ -453,7 +454,19 @@ class NormalTest(test.TestCase): self.assertAllEqual(expected_samples_shape, samples.get_shape()) self.assertAllEqual(expected_samples_shape, sample_values.shape) - @test_util.run_in_graph_and_eager_modes() + def testNormalFullyReparameterized(self): + mu = constant_op.constant(4.0) + sigma = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + tape.watch(mu) + tape.watch(sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) + samples = normal.sample(100) + grad_mu, grad_sigma = tape.gradient(samples, [mu, sigma]) + self.assertIsNotNone(grad_mu) + self.assertIsNotNone(grad_sigma) + + @test_util.run_in_graph_and_eager_modes def testNormalSampleMultiDimensional(self): with self.test_session(): batch_size = 2 @@ -489,7 +502,7 @@ class NormalTest(test.TestCase): self.assertAllEqual(expected_samples_shape, samples.get_shape()) self.assertAllEqual(expected_samples_shape, sample_values.shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNegativeSigmaFails(self): with self.test_session(): with self.assertRaisesOpError("Condition x > 0 did not hold"): @@ -497,7 +510,7 @@ class NormalTest(test.TestCase): loc=[1.], scale=[-5.], validate_args=True, name="G") self.evaluate(normal.mean()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalShape(self): with self.test_session(): mu = constant_op.constant([-3.0] * 5) @@ -524,7 +537,7 @@ class NormalTest(test.TestCase): feed_dict={mu: 5.0, sigma: [1.0, 2.0]}), [2]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNormalNormalKL(self): batch_size = 6 mu_a = np.array([3.0] * batch_size) diff --git a/tensorflow/python/kernel_tests/distributions/special_math_test.py b/tensorflow/python/kernel_tests/distributions/special_math_test.py index 4565bf5c4669b4d416049816046f6f8ed187270d..a634194ce5293f4d7e7a68aa661080ed06493297 100644 --- a/tensorflow/python/kernel_tests/distributions/special_math_test.py +++ b/tensorflow/python/kernel_tests/distributions/special_math_test.py @@ -89,7 +89,7 @@ class NdtriTest(test.TestCase): all_true = np.ones_like(is_finite, dtype=np.bool) self.assertAllEqual(all_true, is_finite) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNdtri(self): """Verifies that ndtri computation is correct.""" with self.test_session(): @@ -138,11 +138,11 @@ class NdtriTest(test.TestCase): lambda x: special_math.ndtri(x), p) # pylint: disable=unnecessary-lambda self.assertAllFinite(self.evaluate(grads[0])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNdtriFiniteGradientFloat32(self): self._baseNdtriFiniteGradientTest(np.float32) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNdtriFiniteGradientFloat64(self): self._baseNdtriFiniteGradientTest(np.float64) diff --git a/tensorflow/python/kernel_tests/distributions/student_t_test.py b/tensorflow/python/kernel_tests/distributions/student_t_test.py index a4fdb658e857d832d5bf69485bbfb2517646a7b7..05590542efe2623e608f783233db68240331ba20 100644 --- a/tensorflow/python/kernel_tests/distributions/student_t_test.py +++ b/tensorflow/python/kernel_tests/distributions/student_t_test.py @@ -23,6 +23,7 @@ import math import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util @@ -172,11 +173,11 @@ class StudentTTest(test.TestCase): sample_values = self.evaluate(samples) n_val = 200000 self.assertEqual(sample_values.shape, (n_val,)) - self.assertAllClose(sample_values.mean(), mu_v, rtol=1e-2, atol=0) + self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0) self.assertAllClose( sample_values.var(), sigma_v**2 * df_v / (df_v - 2), - rtol=1e-2, + rtol=0.1, atol=0) self._checkKLApprox(df_v, mu_v, sigma_v, sample_values) @@ -215,11 +216,11 @@ class StudentTTest(test.TestCase): def testStudentSampleMultiDimensional(self): with self.test_session(): batch_size = 7 - df = constant_op.constant([[3., 7.]] * batch_size) + df = constant_op.constant([[5., 7.]] * batch_size) mu = constant_op.constant([[3., -3.]] * batch_size) sigma = constant_op.constant([[math.sqrt(10.), math.sqrt(15.)]] * batch_size) - df_v = [3., 7.] + df_v = [5., 7.] mu_v = [3., -3.] sigma_v = [np.sqrt(10.), np.sqrt(15.)] n = constant_op.constant(200000) @@ -228,21 +229,21 @@ class StudentTTest(test.TestCase): sample_values = self.evaluate(samples) self.assertEqual(samples.get_shape(), (200000, batch_size, 2)) self.assertAllClose( - sample_values[:, 0, 0].mean(), mu_v[0], rtol=1e-2, atol=0) + sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0) self.assertAllClose( sample_values[:, 0, 0].var(), sigma_v[0]**2 * df_v[0] / (df_v[0] - 2), - rtol=1e-1, + rtol=0.2, atol=0) self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0]) self.assertAllClose( - sample_values[:, 0, 1].mean(), mu_v[1], rtol=1e-2, atol=0) + sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0) self.assertAllClose( sample_values[:, 0, 1].var(), sigma_v[1]**2 * df_v[1] / (df_v[1] - 2), - rtol=1e-1, + rtol=0.2, atol=0) - self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 1]) + self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1]) def _checkKLApprox(self, df, mu, sigma, samples): n = samples.size @@ -272,7 +273,7 @@ class StudentTTest(test.TestCase): self.assertEqual(student.entropy().get_shape(), (3,)) self.assertEqual(student.log_prob(2.).get_shape(), (3,)) self.assertEqual(student.prob(2.).get_shape(), (3,)) - self.assertEqual(student.sample(37, seed=123456).get_shape(), (37, 3,)) + self.assertEqual(student.sample(37).get_shape(), (37, 3,)) _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.)) _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.)) @@ -445,15 +446,30 @@ class StudentTTest(test.TestCase): self.assertEqual(samples.get_shape(), (num,)) self.assertEqual(pdfs.get_shape(), (num,)) self.assertEqual(mean.get_shape(), ()) - self.assertNear(np.pi, np.mean(sample_vals), err=0.02) + self.assertNear(np.pi, np.mean(sample_vals), err=0.1) self.assertNear(np.pi, mean_val, err=1e-6) # Verify integral over sample*pdf ~= 1. # Tolerance increased since eager was getting a value of 1.002041. - self._assertIntegral(sample_vals, pdf_vals, err=3e-3) + self._assertIntegral(sample_vals, pdf_vals, err=5e-2) if not stats: return self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6) + def testFullyReparameterized(self): + df = constant_op.constant(2.0) + mu = constant_op.constant(1.0) + sigma = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + tape.watch(df) + tape.watch(mu) + tape.watch(sigma) + student = student_t.StudentT(df=df, loc=mu, scale=sigma) + samples = student.sample(100) + grad_df, grad_mu, grad_sigma = tape.gradient(samples, [df, mu, sigma]) + self.assertIsNotNone(grad_df) + self.assertIsNotNone(grad_mu) + self.assertIsNotNone(grad_sigma) + def testPdfOfSampleMultiDims(self): student = student_t.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.) self.assertAllEqual([], student.event_shape) @@ -466,22 +482,22 @@ class StudentTTest(test.TestCase): sample_vals, pdf_vals = self.evaluate([samples, pdfs]) self.assertEqual(samples.get_shape(), (num, 2, 2)) self.assertEqual(pdfs.get_shape(), (num, 2, 2)) - self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=.03) - self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=.03) - self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02) - self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02) - self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02) - self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02) + self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=0.1) + self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=0.1) + self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.05) + self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.05) + self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.05) + self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.05) if not stats: return self.assertNear( stats.t.var(7., loc=0., scale=3.), # loc d.n. effect var np.var(sample_vals[:, :, 0]), - err=.4) + err=1.0) self.assertNear( stats.t.var(11., loc=0., scale=3.), # loc d.n. effect var np.var(sample_vals[:, :, 1]), - err=.4) + err=1.0) def _assertIntegral(self, sample_vals, pdf_vals, err=1.5e-3): s_p = zip(sample_vals, pdf_vals) diff --git a/tensorflow/python/kernel_tests/distributions/uniform_test.py b/tensorflow/python/kernel_tests/distributions/uniform_test.py index e74051c9013b7d51914868e66022546ae8862b60..bc9c267b9a5eac6fd8c9c4290dcc4b56865ddb50 100644 --- a/tensorflow/python/kernel_tests/distributions/uniform_test.py +++ b/tensorflow/python/kernel_tests/distributions/uniform_test.py @@ -22,6 +22,7 @@ import importlib import numpy as np +from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import tensor_shape @@ -47,7 +48,7 @@ stats = try_import("scipy.stats") class UniformTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformRange(self): with self.test_session(): a = 3.0 @@ -57,7 +58,7 @@ class UniformTest(test.TestCase): self.assertAllClose(b, self.evaluate(uniform.high)) self.assertAllClose(b - a, self.evaluate(uniform.range())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformPDF(self): with self.test_session(): a = constant_op.constant([-3.0] * 5 + [15.0]) @@ -83,7 +84,7 @@ class UniformTest(test.TestCase): log_pdf = uniform.log_prob(x) self.assertAllClose(np.log(expected_pdf), self.evaluate(log_pdf)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformShape(self): with self.test_session(): a = constant_op.constant([-3.0] * 5) @@ -95,7 +96,7 @@ class UniformTest(test.TestCase): self.assertAllEqual(self.evaluate(uniform.event_shape_tensor()), []) self.assertEqual(uniform.event_shape, tensor_shape.TensorShape([])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformPDFWithScalarEndpoint(self): with self.test_session(): a = constant_op.constant([0.0, 5.0]) @@ -108,7 +109,7 @@ class UniformTest(test.TestCase): pdf = uniform.prob(x) self.assertAllClose(expected_pdf, self.evaluate(pdf)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformCDF(self): with self.test_session(): batch_size = 6 @@ -132,7 +133,7 @@ class UniformTest(test.TestCase): log_cdf = uniform.log_cdf(x) self.assertAllClose(np.log(_expected_cdf()), self.evaluate(log_cdf)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformEntropy(self): with self.test_session(): a_v = np.array([1.0, 1.0, 1.0]) @@ -142,7 +143,7 @@ class UniformTest(test.TestCase): expected_entropy = np.log(b_v - a_v) self.assertAllClose(expected_entropy, self.evaluate(uniform.entropy())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformAssertMaxGtMin(self): with self.test_session(): a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32) @@ -153,7 +154,7 @@ class UniformTest(test.TestCase): uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True) self.evaluate(uniform.low) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformSample(self): with self.test_session(): a = constant_op.constant([3.0, 4.0]) @@ -168,15 +169,15 @@ class UniformTest(test.TestCase): sample_values = self.evaluate(samples) self.assertEqual(sample_values.shape, (100000, 2)) self.assertAllClose( - sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-2) + sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-1, rtol=0.) self.assertAllClose( - sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-2) + sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-1, rtol=0.) self.assertFalse( np.any(sample_values[::, 0] < a1_v) or np.any(sample_values >= b_v)) self.assertFalse( np.any(sample_values[::, 1] < a2_v) or np.any(sample_values >= b_v)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def _testUniformSampleMultiDimensional(self): # DISABLED: Please enable this test once b/issues/30149644 is resolved. with self.test_session(): @@ -207,7 +208,7 @@ class UniformTest(test.TestCase): self.assertAllClose( sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2, atol=1e-2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformMean(self): with self.test_session(): a = 10.0 @@ -218,7 +219,7 @@ class UniformTest(test.TestCase): s_uniform = stats.uniform(loc=a, scale=b - a) self.assertAllClose(self.evaluate(uniform.mean()), s_uniform.mean()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformVariance(self): with self.test_session(): a = 10.0 @@ -229,7 +230,7 @@ class UniformTest(test.TestCase): s_uniform = stats.uniform(loc=a, scale=b - a) self.assertAllClose(self.evaluate(uniform.variance()), s_uniform.var()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformStd(self): with self.test_session(): a = 10.0 @@ -240,7 +241,7 @@ class UniformTest(test.TestCase): s_uniform = stats.uniform(loc=a, scale=b - a) self.assertAllClose(self.evaluate(uniform.stddev()), s_uniform.std()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformNans(self): with self.test_session(): a = 10.0 @@ -258,7 +259,7 @@ class UniformTest(test.TestCase): self.assertFalse(is_nan[0]) self.assertTrue(is_nan[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformSamplePdf(self): with self.test_session(): a = 10.0 @@ -268,7 +269,7 @@ class UniformTest(test.TestCase): self.evaluate( math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformBroadcasting(self): with self.test_session(): a = 10.0 @@ -279,7 +280,7 @@ class UniformTest(test.TestCase): expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]]) self.assertAllClose(expected_pdf, self.evaluate(pdf)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUniformSampleWithShape(self): with self.test_session(): a = 10.0 @@ -299,6 +300,18 @@ class UniformTest(test.TestCase): expected_pdf = [1.0, 0.1] self.assertAllClose(expected_pdf, self.evaluate(pdf)) + def testFullyReparameterized(self): + a = constant_op.constant(0.1) + b = constant_op.constant(0.8) + with backprop.GradientTape() as tape: + tape.watch(a) + tape.watch(b) + uniform = uniform_lib.Uniform(a, b) + samples = uniform.sample(100) + grad_a, grad_b = tape.gradient(samples, [a, b]) + self.assertIsNotNone(grad_a) + self.assertIsNotNone(grad_b) + # Eager doesn't pass due to a type mismatch in one of the ops. def testUniformFloat64(self): uniform = uniform_lib.Uniform( diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py index 2f256d3e8beac145a14ca1dd63f267fb5f4ef3a5..61faa8466edcf404dc48fc0596c47cb3c2094f13 100644 --- a/tensorflow/python/kernel_tests/distributions/util_test.py +++ b/tensorflow/python/kernel_tests/distributions/util_test.py @@ -59,65 +59,6 @@ def _logit(x): class AssertCloseTest(test.TestCase): - def testAssertCloseIntegerDtype(self): - x = array_ops.placeholder(dtypes.int32) - y = x - z = array_ops.placeholder(dtypes.int32) - feed_dict = {x: [1, 5, 10, 15, 20], z: [2, 5, 10, 15, 20]} - with self.test_session(): - with ops.control_dependencies([du.assert_close(x, y)]): - array_ops.identity(x).eval(feed_dict=feed_dict) - - with ops.control_dependencies([du.assert_close(y, x)]): - array_ops.identity(x).eval(feed_dict=feed_dict) - - with self.assertRaisesOpError("Condition x ~= y"): - with ops.control_dependencies([du.assert_close(x, z)]): - array_ops.identity(x).eval(feed_dict=feed_dict) - - with self.assertRaisesOpError("Condition x ~= y"): - with ops.control_dependencies([du.assert_close(y, z)]): - array_ops.identity(y).eval(feed_dict=feed_dict) - - def testAssertCloseNonIntegerDtype(self): - x = array_ops.placeholder(dtypes.float32) - y = x + 1e-8 - z = array_ops.placeholder(dtypes.float32) - feed_dict = {x: [1., 5, 10, 15, 20], z: [2., 5, 10, 15, 20]} - with self.test_session(): - with ops.control_dependencies([du.assert_close(x, y)]): - array_ops.identity(x).eval(feed_dict=feed_dict) - - with ops.control_dependencies([du.assert_close(y, x)]): - array_ops.identity(x).eval(feed_dict=feed_dict) - - with self.assertRaisesOpError("Condition x ~= y"): - with ops.control_dependencies([du.assert_close(x, z)]): - array_ops.identity(x).eval(feed_dict=feed_dict) - - with self.assertRaisesOpError("Condition x ~= y"): - with ops.control_dependencies([du.assert_close(y, z)]): - array_ops.identity(y).eval(feed_dict=feed_dict) - - @test_util.run_in_graph_and_eager_modes() - def testAssertCloseEpsilon(self): - x = [0., 5, 10, 15, 20] - # x != y - y = [0.1, 5, 10, 15, 20] - # x = z - z = [1e-8, 5, 10, 15, 20] - with self.test_session(): - with ops.control_dependencies([du.assert_close(x, z)]): - self.evaluate(array_ops.identity(x)) - - with self.assertRaisesOpError("Condition x ~= y"): - with ops.control_dependencies([du.assert_close(x, y)]): - self.evaluate(array_ops.identity(x)) - - with self.assertRaisesOpError("Condition x ~= y"): - with ops.control_dependencies([du.assert_close(y, z)]): - self.evaluate(array_ops.identity(y)) - def testAssertIntegerForm(self): # This should only be detected as an integer. x = array_ops.placeholder(dtypes.float32) @@ -150,21 +91,21 @@ class AssertCloseTest(test.TestCase): class MaybeGetStaticTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetStaticInt(self): x = 2 self.assertEqual(x, du.maybe_get_static_value(x)) self.assertAllClose( np.array(2.), du.maybe_get_static_value(x, dtype=np.float64)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetStaticNumpyArray(self): x = np.array(2, dtype=np.int32) self.assertEqual(x, du.maybe_get_static_value(x)) self.assertAllClose( np.array(2.), du.maybe_get_static_value(x, dtype=np.float64)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetStaticConstant(self): x = constant_op.constant(2, dtype=dtypes.int32) self.assertEqual(np.array(2, dtype=np.int32), du.maybe_get_static_value(x)) @@ -179,7 +120,7 @@ class MaybeGetStaticTest(test.TestCase): class GetLogitsAndProbsTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testImproperArguments(self): with self.test_session(): with self.assertRaises(ValueError): @@ -188,7 +129,7 @@ class GetLogitsAndProbsTest(test.TestCase): with self.assertRaises(ValueError): du.get_logits_and_probs(logits=[0.1], probs=[0.1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLogits(self): p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32) logits = _logit(p) @@ -200,7 +141,7 @@ class GetLogitsAndProbsTest(test.TestCase): self.assertAllClose(p, self.evaluate(new_p), rtol=1e-5, atol=0.) self.assertAllClose(logits, self.evaluate(new_logits), rtol=1e-5, atol=0.) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLogitsMultidimensional(self): p = np.array([0.2, 0.3, 0.5], dtype=np.float32) logits = np.log(p) @@ -212,7 +153,7 @@ class GetLogitsAndProbsTest(test.TestCase): self.assertAllClose(self.evaluate(new_p), p) self.assertAllClose(self.evaluate(new_logits), logits) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testProbability(self): p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32) @@ -223,7 +164,7 @@ class GetLogitsAndProbsTest(test.TestCase): self.assertAllClose(_logit(p), self.evaluate(new_logits)) self.assertAllClose(p, self.evaluate(new_p)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testProbabilityMultidimensional(self): p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32) @@ -234,7 +175,7 @@ class GetLogitsAndProbsTest(test.TestCase): self.assertAllClose(np.log(p), self.evaluate(new_logits)) self.assertAllClose(p, self.evaluate(new_p)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testProbabilityValidateArgs(self): p = [0.01, 0.2, 0.5, 0.7, .99] # Component less than 0. @@ -265,7 +206,7 @@ class GetLogitsAndProbsTest(test.TestCase): probs=p3, validate_args=False) self.evaluate(prob) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testProbabilityValidateArgsMultidimensional(self): p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32) # Component less than 0. Still sums to 1. @@ -367,11 +308,13 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase): param) checked_param.eval(feed_dict={param: np.ones([int(2**11+1)])}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUnsupportedDtype(self): with self.test_session(): + param = ops.convert_to_tensor( + np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype), + dtype=dtypes.qint16) with self.assertRaises(TypeError): - param = array_ops.ones([int(2**11+1)], dtype=dtypes.qint16) du.embed_check_categorical_event_shape(param) @@ -552,7 +495,7 @@ class RotateTransposeTest(test.TestCase): x = np.array(x) return np.transpose(x, np.roll(np.arange(len(x.shape)), shift)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRollStatic(self): with self.test_session(): if context.executing_eagerly(): diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py index 159cba5fa3d69be5e3e3b22a85138c29d03981cc..c4d4ce780be2fa5a2617874ddb608e41edf70c36 100644 --- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py +++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py @@ -27,7 +27,6 @@ from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gradients_impl import tensorflow.python.ops.data_flow_grad # pylint: disable=unused-import from tensorflow.python.platform import test -from tensorflow.python.framework import dtypes class DynamicStitchTestBase(object): diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py index e53ca1dcaa520b6937aefa45e2740f1c94188b09..55d75cb4749d6f1a33d6cf7a993a336d1afcf992 100644 --- a/tensorflow/python/kernel_tests/embedding_ops_test.py +++ b/tensorflow/python/kernel_tests/embedding_ops_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import itertools +import math import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin @@ -31,6 +32,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import init_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables @@ -736,6 +738,222 @@ class EmbeddingLookupSparseTest(test.TestCase): x, sp_ids, sp_weights, combiner="mean") +class SafeEmbeddingLookupSparseTest(test.TestCase): + + def _random_weights(self, vocab_size=4, embed_dim=4, num_shards=1): + assert vocab_size > 0 + assert embed_dim > 0 + assert num_shards > 0 + assert num_shards <= vocab_size + + embedding_weights = partitioned_variables.create_partitioned_variables( + shape=[vocab_size, embed_dim], + slicing=[num_shards, 1], + initializer=init_ops.truncated_normal_initializer( + mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32)) + for w in embedding_weights: + w.initializer.run() + embedding_weights = [w.eval() for w in embedding_weights] + return embedding_weights + + def _ids_and_weights_2d(self): + # Each row demonstrates a test case: + # Row 0: multiple valid ids, 1 invalid id, weighted mean + # Row 1: all ids are invalid (leaving no valid ids after pruning) + # Row 2: no ids to begin with + # Row 3: single id + # Row 4: all ids have <=0 weight + indices = [[0, 0], [0, 1], [0, 2], [1, 0], [3, 0], [4, 0], [4, 1]] + ids = [0, 1, -1, -1, 2, 0, 1] + weights = [1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5] + shape = [5, 4] + + sparse_ids = sparse_tensor.SparseTensor( + constant_op.constant(indices, dtypes.int64), + constant_op.constant(ids, dtypes.int64), + constant_op.constant(shape, dtypes.int64)) + + sparse_weights = sparse_tensor.SparseTensor( + constant_op.constant(indices, dtypes.int64), + constant_op.constant(weights, dtypes.float32), + constant_op.constant(shape, dtypes.int64)) + + return sparse_ids, sparse_weights + + def _ids_and_weights_3d(self): + # Each (2-D) index demonstrates a test case: + # Index 0, 0: multiple valid ids, 1 invalid id, weighted mean + # Index 0, 1: all ids are invalid (leaving no valid ids after pruning) + # Index 0, 2: no ids to begin with + # Index 1, 0: single id + # Index 1, 1: all ids have <=0 weight + # Index 1, 2: no ids to begin with + indices = [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0], [1, 0, 0], [1, 1, 0], + [1, 1, 1]] + ids = [0, 1, -1, -1, 2, 0, 1] + weights = [1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5] + shape = [2, 3, 4] + + sparse_ids = sparse_tensor.SparseTensor( + constant_op.constant(indices, dtypes.int64), + constant_op.constant(ids, dtypes.int64), + constant_op.constant(shape, dtypes.int64)) + + sparse_weights = sparse_tensor.SparseTensor( + constant_op.constant(indices, dtypes.int64), + constant_op.constant(weights, dtypes.float32), + constant_op.constant(shape, dtypes.int64)) + + return sparse_ids, sparse_weights + + def test_safe_embedding_lookup_sparse_return_zero_vector(self): + with self.test_session(): + embedding_weights = self._random_weights() + sparse_ids, sparse_weights = self._ids_and_weights_2d() + + embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, sparse_weights).eval()) + + self.assertAllClose( + embedding_lookup_result, + [(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / + 3.0, [0] * 4, [0] * 4, embedding_weights[0][2], [0] * 4]) + + def test_safe_embedding_lookup_sparse_return_special_vector(self): + with self.test_session(): + embedding_weights = self._random_weights() + sparse_ids, sparse_weights = self._ids_and_weights_2d() + + embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, sparse_weights, default_id=3).eval()) + + self.assertAllClose( + embedding_lookup_result, + [(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / + 3.0, embedding_weights[0][3], embedding_weights[0][3], + embedding_weights[0][2], embedding_weights[0][3]]) + + def test_safe_embedding_lookup_sparse_no_weights(self): + with self.test_session(): + embedding_weights = self._random_weights() + sparse_ids, _ = self._ids_and_weights_2d() + + embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, None).eval()) + + self.assertAllClose( + embedding_lookup_result, + [(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4, + [0] * 4, embedding_weights[0][2], ( + embedding_weights[0][0] + embedding_weights[0][1]) / 2.0]) + + def test_safe_embedding_lookup_sparse_partitioned(self): + with self.test_session(): + embedding_weights = self._random_weights(num_shards=3) + sparse_ids, _ = self._ids_and_weights_2d() + + embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, None).eval()) + + embedding_weights = list(itertools.chain(*embedding_weights)) + self.assertAllClose(embedding_lookup_result, + [(embedding_weights[0] + embedding_weights[1]) / 2.0, + [0] * 4, [0] * 4, embedding_weights[2], + (embedding_weights[0] + embedding_weights[1]) / 2.0]) + + def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self): + with self.test_session(): + embedding_weights = self._random_weights(num_shards=3) + sparse_ids, sparse_weights = self._ids_and_weights_2d() + + embedding_weights[1] = embedding_weights[1].astype(np.float64) + self.assertRaises(TypeError, embedding_ops.safe_embedding_lookup_sparse, + embedding_weights, sparse_ids) + embedding_weights = [ + constant_op.constant(w, dtype=dtypes.float64) + for w in embedding_weights + ] + self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse, + embedding_weights, sparse_ids, sparse_weights) + + def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self): + with self.test_session(): + embedding_weights = self._random_weights() + sparse_ids, sparse_weights = self._ids_and_weights_3d() + + embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, sparse_weights).eval()) + + self.assertAllClose(embedding_lookup_result, [[ + (1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / 3.0, + [0] * 4, [0] * 4 + ], [embedding_weights[0][2], [0] * 4, [0] * 4]]) + + def test_safe_embedding_lookup_sparse_3d_return_special_vector(self): + with self.test_session(): + embedding_weights = self._random_weights() + sparse_ids, sparse_weights = self._ids_and_weights_3d() + + embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, sparse_weights, default_id=3).eval()) + + self.assertAllClose( + embedding_lookup_result, + [[(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / + 3.0, embedding_weights[0][3], embedding_weights[0][3]], [ + embedding_weights[0][2], embedding_weights[0][3], + embedding_weights[0][3] + ]]) + + def test_safe_embedding_lookup_sparse_3d_no_weights(self): + with self.test_session(): + embedding_weights = self._random_weights() + sparse_ids, _ = self._ids_and_weights_3d() + + embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, None).eval()) + + self.assertAllClose(embedding_lookup_result, [[( + embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4, [ + 0 + ] * 4], [ + embedding_weights[0][2], + (embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4 + ]]) + + def test_safe_embedding_lookup_sparse_3d_partitioned(self): + with self.test_session(): + embedding_weights = self._random_weights(num_shards=3) + sparse_ids, _ = self._ids_and_weights_3d() + + embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse( + embedding_weights, sparse_ids, None).eval()) + + embedding_weights = list(itertools.chain(*embedding_weights)) + self.assertAllClose(embedding_lookup_result, [[ + (embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4, [0] * 4 + ], [ + embedding_weights[2], + (embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4 + ]]) + + def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights( + self): + with self.test_session(): + embedding_weights = self._random_weights(num_shards=3) + sparse_ids, sparse_weights = self._ids_and_weights_3d() + + embedding_weights[1] = embedding_weights[1].astype(np.float64) + self.assertRaises(TypeError, embedding_ops.safe_embedding_lookup_sparse, + embedding_weights, sparse_ids) + embedding_weights = [ + constant_op.constant(w, dtype=dtypes.float64) + for w in embedding_weights + ] + self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse, + embedding_weights, sparse_ids, sparse_weights) + + class DynamicStitchOpTest(test.TestCase): def testCint32Cpu(self): diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py index 14a336c6881960cbf03ba767a051835feebf9d04..9e7b5283381dd7bc0725e1ab6fb9d7d13153f02d 100644 --- a/tensorflow/python/kernel_tests/fifo_queue_test.py +++ b/tensorflow/python/kernel_tests/fifo_queue_test.py @@ -126,14 +126,14 @@ class FIFOQueueTest(test.TestCase): q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run() self.assertEqual(4, q.size().eval()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMultipleDequeues(self): q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) self.evaluate(q.enqueue_many([[1, 2, 3]])) a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()]) self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testQueuesDontShare(self): q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) self.evaluate(q.enqueue(1)) diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index facadc971ff516e4f9edea0c4f52ab0953ec5fce..5db2e9821dcfff08f947a92c6097e9be660f8bd7 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import @@ -56,7 +57,7 @@ def simple_scoped_fn(a, x): class FunctionalOpsTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldl_Simple(self): with self.test_session(): elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") @@ -72,7 +73,7 @@ class FunctionalOpsTest(test.TestCase): initializer=10) self.assertAllEqual(880, self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldl_SingleInputMultiOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -83,7 +84,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(22, r_value[0]) self.assertAllEqual(20, r_value[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldl_MultiInputSingleOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -111,7 +112,7 @@ class FunctionalOpsTest(test.TestCase): self.assertEqual(len(variables.trainable_variables()), 1) self.assertAllEqual(880, self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldr_Simple(self): with self.test_session(): elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") @@ -127,7 +128,7 @@ class FunctionalOpsTest(test.TestCase): initializer=10) self.assertAllEqual(1282, self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldr_SingleInputMultiOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -138,7 +139,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(22, r_value[0]) self.assertAllEqual(20, r_value[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldr_MultiInputSingleOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -182,7 +183,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(720.0, self.evaluate(r)) # pylint: enable=unnecessary-lambda - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMap_Simple(self): with self.test_session(): nums = [1, 2, 3, 4, 5, 6] @@ -202,7 +203,7 @@ class FunctionalOpsTest(test.TestCase): values=constant_op.constant([0, 1, 2]), dense_shape=[2, 2])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMapOverScalarErrors(self): with self.assertRaisesRegexp(ValueError, "not scalars"): functional_ops.map_fn(lambda x: x, [1, 2]) @@ -251,7 +252,7 @@ class FunctionalOpsTest(test.TestCase): r = gradients_impl.gradients(y, elems)[0] self.assertAllEqual([4.0, 8.0, 12.0, 16.0, 20.0, 24.0], self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMap_SimpleNotTensor(self): with self.test_session(): nums = np.array([1, 2, 3, 4, 5, 6]) @@ -260,7 +261,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual( np.array([(x + 3) * 2 for x in nums]), self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMap_SingleInputMultiOutput(self): with self.test_session(): nums = np.array([1, 2, 3, 4, 5, 6]) @@ -275,7 +276,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual((nums + 3) * 2, received[0]) self.assertAllEqual(-(nums + 3) * 2, received[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMap_MultiOutputMismatchedDtype(self): with self.test_session(): nums = np.array([1, 2, 3, 4, 5, 6]) @@ -287,7 +288,7 @@ class FunctionalOpsTest(test.TestCase): nums, dtype=[dtypes.int64, dtypes.int64]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMap_MultiInputSingleOutput(self): with self.test_session(): nums = np.array([1, 2, 3, 4, 5, 6]) @@ -298,7 +299,7 @@ class FunctionalOpsTest(test.TestCase): received = self.evaluate(r) self.assertAllEqual(nums * nums + (-nums), received) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMap_MultiInputSameStructureOutput(self): with self.test_session(): nums = np.array([1, 2, 3, 4, 5, 6]) @@ -313,7 +314,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(-nums, received[1]) self.assertAllEqual(nums, received[2]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScan_Simple(self): with self.test_session(): elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") @@ -328,7 +329,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r)) # pylint: enable=unnecessary-lambda - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScan_Reverse(self): with self.test_session(): elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") @@ -345,7 +346,7 @@ class FunctionalOpsTest(test.TestCase): self.evaluate(r)) # pylint: enable=unnecessary-lambda - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScan_SingleInputMultiOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -357,7 +358,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0]) self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScan_MultiInputSingleOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -367,7 +368,7 @@ class FunctionalOpsTest(test.TestCase): (elems + 1, -elems), initializer) self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScan_MultiInputSameTypeOutput(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -377,7 +378,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(np.cumsum(elems), r_value[0]) self.assertAllEqual(np.cumsum(-elems), r_value[1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScan_MultiOutputMismatchedInitializer(self): with self.test_session(): elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) @@ -408,7 +409,7 @@ class FunctionalOpsTest(test.TestCase): results = np.array([6, 16, 38, 84, 178, 368]) self.assertAllEqual(results, self.evaluate(r)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScanFoldl_Nested(self): with self.test_session(): elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data") @@ -467,7 +468,7 @@ class FunctionalOpsTest(test.TestCase): variables.global_variables_initializer().run() sess.run(grad) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFoldShape(self): with self.test_session(): x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) @@ -479,7 +480,7 @@ class FunctionalOpsTest(test.TestCase): y = functional_ops.foldl(fn, x, initializer=initializer) self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMapShape(self): with self.test_session(): x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) @@ -491,7 +492,7 @@ class FunctionalOpsTest(test.TestCase): y = functional_ops.map_fn(lambda e: e, x) self.assertIs(None, y.get_shape().dims) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMapEmptyScalar(self): with self.test_session(): map_return = functional_ops.map_fn(lambda x: 1, constant_op.constant([])) @@ -507,7 +508,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual([0, 3, 2], map_return.get_shape().dims) self.assertAllEqual([0, 3, 2], self.evaluate(map_return).shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScanShape(self): with self.test_session(): x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) @@ -604,6 +605,25 @@ class FunctionalOpsTest(test.TestCase): mul = sess.run(remote_op) self.assertEqual(mul, [6]) + def testRemoteFunctionSameDeviceDirectSession(self): + + @function.Defun(dtypes.int32, dtypes.int32) + def _remote_fn(a, b): + return math_ops.multiply(a, b) + + with ops.device("/cpu:0"): + a = variables.Variable(2, dtype=dtypes.int32) + b = variables.Variable(3, dtype=dtypes.int32) + + with ops.device("/cpu:0"): + remote_op = functional_ops.remote_call( + args=[a, b], Tout=[dtypes.int32], f=_remote_fn, target="/cpu:0") + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + mul = sess.run(remote_op) + self.assertEqual(mul, [6]) + def testRemoteFunctionCPUGPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") @@ -652,6 +672,24 @@ class FunctionalOpsTest(test.TestCase): mul = sess.run(remote_op) self.assertEqual(mul, 9.0) + def testRemoteFunctionGPUCPUStrings(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + @function.Defun(dtypes.string) + def _remote_fn(inp): + return array_ops.identity(inp) + + a = array_ops.constant("a") + + with ops.device("/gpu:0"): + remote_op = functional_ops.remote_call( + args=[a], Tout=[dtypes.string], f=_remote_fn, target="/cpu:0") + + with self.test_session() as sess: + ret = sess.run(remote_op) + self.assertAllEqual(ret, [b"a"]) + def testRemoteFunctionCrossProcess(self): workers, _ = test_util.create_local_cluster(2, 1) @@ -940,6 +978,8 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(sess.run(bvals), [17., 16.]) +# TODO(akshayka): Replace `function.Defun` with tf.contrib.eager.defun` in the +# below test cases. class PartitionedCallTest(test.TestCase): def testBasicSingleDevice(self): @@ -1015,7 +1055,7 @@ class PartitionedCallTest(test.TestCase): self.assertEqual(output, 6.) def testShardsRunOnRequestedDevices(self): - config = config_pb2.ConfigProto(device_count={"CPU": 3}) + config = config_pb2.ConfigProto(device_count={"CPU": 4}) @function.Defun() def Body(): @@ -1035,13 +1075,80 @@ class PartitionedCallTest(test.TestCase): with ops.device("/cpu:2"): s3 = iterator_ops.Iterator.from_structure( (dtypes.float32,)).string_handle() - return s1, s2, s3 + with ops.device(""): + # TODO(akshayka): This is unfortunate and brittle. It prevents + # `Iterator.from_structure` from assigning the iterator op to 'cpu:0'. + # Remove this hack once we have a way of obtaining metadata about + # function execution. + s4 = iterator_ops.Iterator.from_structure( + (dtypes.float32,)).string_handle() + return s1, s2, s3, s4 - with self.test_session(config=config): - outputs = functional_ops.partitioned_call(args=[], f=Body) - self.assertTrue(compat.as_bytes("CPU:0") in outputs[0].eval()) - self.assertTrue(compat.as_bytes("CPU:1") in outputs[1].eval()) - self.assertTrue(compat.as_bytes("CPU:2") in outputs[2].eval()) + with self.test_session(config=config, use_gpu=True) as sess: + with ops.device("/cpu:3"): + outputs = sess.run(functional_ops.partitioned_call(args=[], f=Body)) + self.assertIn(compat.as_bytes("CPU:0"), outputs[0]) + self.assertIn(compat.as_bytes("CPU:1"), outputs[1]) + self.assertIn(compat.as_bytes("CPU:2"), outputs[2]) + self.assertIn(compat.as_bytes("CPU:3"), outputs[3]) + + with self.test_session(config=config, use_gpu=True): + with ops.device("/cpu:0"): + outputs = sess.run(functional_ops.partitioned_call(args=[], f=Body)) + self.assertIn(compat.as_bytes("CPU:0"), outputs[0]) + self.assertIn(compat.as_bytes("CPU:1"), outputs[1]) + self.assertIn(compat.as_bytes("CPU:2"), outputs[2]) + self.assertIn(compat.as_bytes("CPU:0"), outputs[3]) + + def testAssignAddResourceVariable(self): + + v = resource_variable_ops.ResourceVariable(1.0) + + @function.Defun() + def AssignAdd(): + v.assign_add(1.0) + + op = functional_ops.partitioned_call( + args=AssignAdd.captured_inputs, f=AssignAdd) + _ = self.evaluate(variables.global_variables_initializer()) + _ = self.evaluate(op) + value = self.evaluate(v.read_value()) + self.assertEqual(value, 2.0) + + def testFunctionWithResourcesOnDifferentDevices(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPUs available.") + + with ops.device("/cpu:0"): + v_cpu_zero = resource_variable_ops.ResourceVariable( + [0.0, 1.0, 2.0], name="v_cpu_zero") + + with ops.device("/cpu:1"): + v_cpu_one = resource_variable_ops.ResourceVariable( + [0.0, 1.0, 2.0], name="v_cpu_one") + + with ops.device("/gpu:0"): + v_gpu = resource_variable_ops.ResourceVariable( + [0.0, 1.0, 2.0], name="v_gpu") + + def sum_gather(): + cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu_zero, [1, 2])) + also_cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu_one, [1, 2])) + gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2])) + return cpu_result, also_cpu_result, gpu_result + + defined = function.Defun()(sum_gather) + with self.test_session( + config=config_pb2.ConfigProto( + allow_soft_placement=False, + log_device_placement=True, + device_count={"CPU": 2})) as sess: + sess.run(variables.global_variables_initializer()) + expected = sess.run(sum_gather()) + result = sess.run( + functional_ops.partitioned_call( + args=defined.captured_inputs, f=defined)) + self.assertAllEqual(expected, result) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/gather_nd_op_test.py b/tensorflow/python/kernel_tests/gather_nd_op_test.py index 58e2a8ac2a3b827647b1b1176f4b69e6a88b76c6..c0b419e1d13405d04c34fb642cec15760ddcf50f 100644 --- a/tensorflow/python/kernel_tests/gather_nd_op_test.py +++ b/tensorflow/python/kernel_tests/gather_nd_op_test.py @@ -203,8 +203,7 @@ class GatherNdTest(test.TestCase): indices = [[[0], [7]]] # Make this one higher rank gather_nd = array_ops.gather_nd(params, indices) with self.assertRaisesOpError( - r"flat indices\[1, :\] = \[7\] does not index into param " - r"\(shape: \[3\]\)"): + r"indices\[0,1\] = \[7\] does not index into param shape \[3\]"): gather_nd.eval() def _disabledTestBadIndicesGPU(self): @@ -217,8 +216,7 @@ class GatherNdTest(test.TestCase): indices = [[[0], [7]]] # Make this one higher rank gather_nd = array_ops.gather_nd(params, indices) with self.assertRaisesOpError( - r"flat indices\[1, :\] = \[7\] does not index into param " - r"\(shape: \[3\]\)"): + r"indices\[0,1\] = \[7\] does not index into param shape \[3\]"): gather_nd.eval() def testBadIndicesWithSlicesCPU(self): @@ -227,8 +225,7 @@ class GatherNdTest(test.TestCase): indices = [[[0], [0], [1]]] # Make this one higher rank gather_nd = array_ops.gather_nd(params, indices) with self.assertRaisesOpError( - r"flat indices\[2, :\] = \[1\] does not index into param " - r"\(shape: \[1,3\]\)"): + r"indices\[0,2\] = \[1\] does not index into param shape \[1,3\]"): gather_nd.eval() def _disabledTestBadIndicesWithSlicesGPU(self): @@ -241,8 +238,7 @@ class GatherNdTest(test.TestCase): indices = [[[0], [0], [1]]] # Make this one higher rank gather_nd = array_ops.gather_nd(params, indices) with self.assertRaisesOpError( - r"flat indices\[2, :\] = \[1\] does not index into param " - r"\(shape: \[1,3\]\)"): + r"indices\[0,2\] = \[1\] does not index into param shape \[1,3\]"): gather_nd.eval() def testGradientsRank2Elements(self): diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py index 795aa67248f66e72f8f772845c4ca5b2b1b06d3d..f6097ad48984a1bb62708185ebf9782b72036e6a 100644 --- a/tensorflow/python/kernel_tests/init_ops_test.py +++ b/tensorflow/python/kernel_tests/init_ops_test.py @@ -364,14 +364,52 @@ class UniformUnitScalingInitializationTest(test.TestCase): class VarianceScalingInitializationTest(test.TestCase): + def testTruncatedNormalDistribution(self): + shape = [100, 100] + expect_mean = 0. + expect_var = 1. / shape[0] + init = init_ops.variance_scaling_initializer( + distribution='truncated_normal') + + with self.test_session(use_gpu=True), \ + test.mock.patch.object( + random_ops, 'truncated_normal', wraps=random_ops.truncated_normal) \ + as mock_truncated_normal: + x = init(shape).eval() + self.assertTrue(mock_truncated_normal.called) + + self.assertNear(np.mean(x), expect_mean, err=1e-2) + self.assertNear(np.var(x), expect_var, err=1e-2) + def testNormalDistribution(self): shape = [100, 100] expect_mean = 0. expect_var = 1. / shape[0] init = init_ops.variance_scaling_initializer(distribution='normal') - with self.test_session(use_gpu=True): + with self.test_session(use_gpu=True), \ + test.mock.patch.object( + random_ops, 'truncated_normal', wraps=random_ops.truncated_normal) \ + as mock_truncated_normal: + x = init(shape).eval() + self.assertTrue(mock_truncated_normal.called) + + self.assertNear(np.mean(x), expect_mean, err=1e-2) + self.assertNear(np.var(x), expect_var, err=1e-2) + + def testUntruncatedNormalDistribution(self): + shape = [100, 100] + expect_mean = 0. + expect_var = 1. / shape[0] + init = init_ops.variance_scaling_initializer( + distribution='untruncated_normal') + + with self.test_session(use_gpu=True), \ + test.mock.patch.object( + random_ops, 'random_normal', wraps=random_ops.random_normal) \ + as mock_random_normal: x = init(shape).eval() + self.assertTrue(mock_random_normal.called) self.assertNear(np.mean(x), expect_mean, err=1e-2) self.assertNear(np.var(x), expect_var, err=1e-2) @@ -792,7 +830,7 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase): tol = 1e-3 gain = 3.14 # Check orthogonality/isometry by computing the ratio between - # the 2-norms of the inputs and ouputs. + # the 2-norms of the inputs and outputs. for kernel_size in [[1], [2], [3], [4], [5], [6]]: convolution = convolutional.conv1d inputs = random_ops.random_normal(shape, dtype=dtype) @@ -887,7 +925,7 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase): tol = 1e-3 gain = 3.14 # Check orthogonality/isometry by computing the ratio between - # the 2-norms of the inputs and ouputs. + # the 2-norms of the inputs and outputs. for kernel_size in [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]: convolution = convolutional.conv2d inputs = random_ops.random_normal(shape, dtype=dtype) @@ -1012,7 +1050,7 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase): tol = 1e-3 gain = 3.14 # Check orthogonality/isometry by computing the ratio between - # the 2-norms of the inputs and ouputs. + # the 2-norms of the inputs and outputs. for kernel_size in [[1, 1, 1], [2, 2, 2], [3, 3, 3]]: convolution = convolutional.conv3d inputs = random_ops.random_normal(shape, dtype=dtype) diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index 0123adc2c3e5c32fd86ef11e7b1f552964232abd..f4ec3e3996a17405b65d240534d2f2d47973d418 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -107,6 +107,10 @@ cuda_py_test( "//tensorflow/python:random_ops", ], shard_count = 5, + tags = [ + "noasan", + "optonly", + ], ) cuda_py_test( @@ -124,7 +128,10 @@ cuda_py_test( "//tensorflow/python:random_ops", ], shard_count = 5, - tags = ["optonly"], # Test is flaky without optimization. + tags = [ + "noasan", + "optonly", + ], ) cuda_py_test( @@ -141,6 +148,10 @@ cuda_py_test( "//tensorflow/python:platform_test", ], shard_count = 5, + tags = [ + "noasan", + "optonly", + ], ) cuda_py_test( @@ -178,11 +189,15 @@ cuda_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], + tags = [ + "noasan", + "optonly", + ], ) cuda_py_test( name = "linear_operator_low_rank_update_test", - size = "medium", + size = "large", srcs = ["linear_operator_low_rank_update_test.py"], additional_deps = [ "//tensorflow/python/ops/linalg", @@ -214,4 +229,26 @@ cuda_py_test( "//tensorflow/python:platform_test", ], shard_count = 5, + tags = [ + "noasan", + "optonly", + ], +) + +cuda_py_test( + name = "linear_operator_zeros_test", + size = "medium", + srcs = ["linear_operator_zeros_test.py"], + additional_deps = [ + "//tensorflow/python/ops/linalg", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + ], + shard_count = 5, + tags = ["optonly"], # Test is flaky without optimization. ) diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py index 2b80f01b73441185281a3e2ef4db003b150c1e12..3ede2aceaa51c2795029ba13b763fed3e2ddc441 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py @@ -80,7 +80,7 @@ class SquareLinearOperatorBlockDiagTest( build_info((2, 1, 5, 5), blocks=[(2, 1, 2, 2), (1, 3, 3)]), ] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) expected_blocks = ( build_info.__dict__["blocks"] if "blocks" in build_info.__dict__ @@ -91,26 +91,19 @@ class SquareLinearOperatorBlockDiagTest( for block_shape in expected_blocks ] + lin_op_matrices = matrices + if use_placeholder: - matrices_ph = [ - array_ops.placeholder(dtype=dtype) for _ in expected_blocks - ] - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrices = self.evaluate(matrices) - operator = block_diag.LinearOperatorBlockDiag( - [linalg.LinearOperatorFullMatrix( - m_ph, is_square=True) for m_ph in matrices_ph], - is_square=True) - feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)} - else: - operator = block_diag.LinearOperatorBlockDiag( - [linalg.LinearOperatorFullMatrix( - m, is_square=True) for m in matrices]) - feed_dict = None - # Should be auto-set. - self.assertTrue(operator.is_square) + lin_op_matrices = [ + array_ops.placeholder_with_default( + matrix, shape=None) for matrix in matrices] + + operator = block_diag.LinearOperatorBlockDiag( + [linalg.LinearOperatorFullMatrix( + l, is_square=True) for l in lin_op_matrices]) + + # Should be auto-set. + self.assertTrue(operator.is_square) # Broadcast the shapes. expected_shape = list(build_info.shape) @@ -123,7 +116,7 @@ class SquareLinearOperatorBlockDiagTest( block_diag_dense.set_shape( expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]]) - return operator, block_diag_dense, feed_dict + return operator, block_diag_dense def test_is_x_flags(self): # Matrix with two positive eigenvalues, 1, and 1. diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py index 5713d169696c78e996332b7a515a3ee2eedca839..7261d4bb3bc4aa24f51be21f9ac261549dca58d5 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py @@ -95,7 +95,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator( # real, the matrix will not be real. return [dtypes.complex64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = build_info.shape # For this test class, we are creating real spectrums. # We also want the spectrum to have eigenvalues bounded away from zero. @@ -107,22 +107,18 @@ class LinearOperatorCirculantTestSelfAdjointOperator( # zero, so the operator will still be self-adjoint. spectrum = math_ops.cast(spectrum, dtype) + lin_op_spectrum = spectrum + if use_placeholder: - spectrum_ph = array_ops.placeholder(dtypes.complex64) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # it is random and we want the same value used for both mat and feed_dict. - spectrum = spectrum.eval() - operator = linalg.LinearOperatorCirculant( - spectrum_ph, is_self_adjoint=True, input_output_dtype=dtype) - feed_dict = {spectrum_ph: spectrum} - else: - operator = linalg.LinearOperatorCirculant( - spectrum, is_self_adjoint=True, input_output_dtype=dtype) - feed_dict = None + lin_op_spectrum = array_ops.placeholder_with_default( + spectrum, shape=None) + + operator = linalg.LinearOperatorCirculant( + lin_op_spectrum, is_self_adjoint=True, input_output_dtype=dtype) mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype) - return operator, mat, feed_dict + return operator, mat def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self): with self.test_session(): @@ -149,7 +145,7 @@ class LinearOperatorCirculantTestHermitianSpectrum( def _dtypes_to_test(self): return [dtypes.float32, dtypes.complex64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = build_info.shape # For this test class, we are creating Hermitian spectrums. # We also want the spectrum to have eigenvalues bounded away from zero. @@ -172,22 +168,18 @@ class LinearOperatorCirculantTestHermitianSpectrum( spectrum = math_ops.fft(h_c) + lin_op_spectrum = spectrum + if use_placeholder: - spectrum_ph = array_ops.placeholder(dtypes.complex64) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # it is random and we want the same value used for both mat and feed_dict. - spectrum = spectrum.eval() - operator = linalg.LinearOperatorCirculant( - spectrum_ph, input_output_dtype=dtype) - feed_dict = {spectrum_ph: spectrum} - else: - operator = linalg.LinearOperatorCirculant( - spectrum, input_output_dtype=dtype) - feed_dict = None + lin_op_spectrum = array_ops.placeholder_with_default( + spectrum, shape=None) + + operator = linalg.LinearOperatorCirculant( + lin_op_spectrum, input_output_dtype=dtype) mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype) - return operator, mat, feed_dict + return operator, mat def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self): with self.test_session(): @@ -213,7 +205,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum( def _dtypes_to_test(self): return [dtypes.complex64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = build_info.shape # Will be well conditioned enough to get accurate solves. spectrum = linear_operator_test_util.random_sign_uniform( @@ -222,22 +214,18 @@ class LinearOperatorCirculantTestNonHermitianSpectrum( minval=1., maxval=2.) + lin_op_spectrum = spectrum + if use_placeholder: - spectrum_ph = array_ops.placeholder(dtypes.complex64) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # it is random and we want the same value used for both mat and feed_dict. - spectrum = spectrum.eval() - operator = linalg.LinearOperatorCirculant( - spectrum_ph, input_output_dtype=dtype) - feed_dict = {spectrum_ph: spectrum} - else: - operator = linalg.LinearOperatorCirculant( - spectrum, input_output_dtype=dtype) - feed_dict = None + lin_op_spectrum = array_ops.placeholder_with_default( + spectrum, shape=None) + + operator = linalg.LinearOperatorCirculant( + lin_op_spectrum, input_output_dtype=dtype) mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype) - return operator, mat, feed_dict + return operator, mat def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self): with self.test_session(): @@ -432,7 +420,7 @@ class LinearOperatorCirculant2DTestHermitianSpectrum( def _dtypes_to_test(self): return [dtypes.float32, dtypes.complex64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = build_info.shape # For this test class, we are creating Hermitian spectrums. # We also want the spectrum to have eigenvalues bounded away from zero. @@ -455,22 +443,18 @@ class LinearOperatorCirculant2DTestHermitianSpectrum( spectrum = math_ops.fft2d(h_c) + lin_op_spectrum = spectrum + if use_placeholder: - spectrum_ph = array_ops.placeholder(dtypes.complex64) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # it is random and we want the same value used for both mat and feed_dict. - spectrum = spectrum.eval() - operator = linalg.LinearOperatorCirculant2D( - spectrum_ph, input_output_dtype=dtype) - feed_dict = {spectrum_ph: spectrum} - else: - operator = linalg.LinearOperatorCirculant2D( - spectrum, input_output_dtype=dtype) - feed_dict = None + lin_op_spectrum = array_ops.placeholder_with_default( + spectrum, shape=None) + + operator = linalg.LinearOperatorCirculant2D( + lin_op_spectrum, input_output_dtype=dtype) mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype) - return operator, mat, feed_dict + return operator, mat class LinearOperatorCirculant2DTestNonHermitianSpectrum( @@ -486,7 +470,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum( def _dtypes_to_test(self): return [dtypes.complex64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = build_info.shape # Will be well conditioned enough to get accurate solves. spectrum = linear_operator_test_util.random_sign_uniform( @@ -495,22 +479,18 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum( minval=1., maxval=2.) + lin_op_spectrum = spectrum + if use_placeholder: - spectrum_ph = array_ops.placeholder(dtypes.complex64) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # it is random and we want the same value used for both mat and feed_dict. - spectrum = spectrum.eval() - operator = linalg.LinearOperatorCirculant2D( - spectrum_ph, input_output_dtype=dtype) - feed_dict = {spectrum_ph: spectrum} - else: - operator = linalg.LinearOperatorCirculant2D( - spectrum, input_output_dtype=dtype) - feed_dict = None + lin_op_spectrum = array_ops.placeholder_with_default( + spectrum, shape=None) + + operator = linalg.LinearOperatorCirculant2D( + lin_op_spectrum, input_output_dtype=dtype) mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype) - return operator, mat, feed_dict + return operator, mat def test_real_hermitian_spectrum_gives_real_symmetric_operator(self): with self.test_session() as sess: diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py index f96b9ccdaacae7d8e0552ed3d74ce53808fed963..612a50bcec771f8511d20d19b312a797d531f109 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py @@ -44,7 +44,7 @@ class SquareLinearOperatorCompositionTest( self._rtol[dtypes.float32] = 1e-4 self._rtol[dtypes.complex64] = 1e-4 - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): sess = ops.get_default_session() shape = list(build_info.shape) @@ -56,33 +56,23 @@ class SquareLinearOperatorCompositionTest( for _ in range(num_operators) ] + lin_op_matrices = matrices + if use_placeholder: - matrices_ph = [ - array_ops.placeholder(dtype=dtype) for _ in range(num_operators) - ] - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrices = sess.run(matrices) - operator = linalg.LinearOperatorComposition( - [linalg.LinearOperatorFullMatrix(m_ph) for m_ph in matrices_ph], - is_square=True) - feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)} - else: - operator = linalg.LinearOperatorComposition( - [linalg.LinearOperatorFullMatrix(m) for m in matrices]) - feed_dict = None - # Should be auto-set. - self.assertTrue(operator.is_square) - - # Convert back to Tensor. Needed if use_placeholder, since then we have - # already evaluated each matrix to a numpy array. + lin_op_matrices = [ + array_ops.placeholder_with_default( + matrix, shape=None) for matrix in matrices] + + operator = linalg.LinearOperatorComposition( + [linalg.LinearOperatorFullMatrix(l) for l in lin_op_matrices], + is_square=True) + matmul_order_list = list(reversed(matrices)) - mat = ops.convert_to_tensor(matmul_order_list[0]) + mat = matmul_order_list[0] for other_mat in matmul_order_list[1:]: mat = math_ops.matmul(other_mat, mat) - return operator, mat, feed_dict + return operator, mat def test_is_x_flags(self): # Matrix with two positive eigenvalues, 1, and 1. @@ -148,7 +138,7 @@ class NonSquareLinearOperatorCompositionTest( self._rtol[dtypes.float32] = 1e-4 self._rtol[dtypes.complex64] = 1e-4 - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): sess = ops.get_default_session() shape = list(build_info.shape) @@ -170,30 +160,22 @@ class NonSquareLinearOperatorCompositionTest( shape_2, dtype=dtype) ] + lin_op_matrices = matrices + if use_placeholder: - matrices_ph = [ - array_ops.placeholder(dtype=dtype) for _ in range(num_operators) - ] - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrices = sess.run(matrices) - operator = linalg.LinearOperatorComposition( - [linalg.LinearOperatorFullMatrix(m_ph) for m_ph in matrices_ph]) - feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)} - else: - operator = linalg.LinearOperatorComposition( - [linalg.LinearOperatorFullMatrix(m) for m in matrices]) - feed_dict = None - - # Convert back to Tensor. Needed if use_placeholder, since then we have - # already evaluated each matrix to a numpy array. + lin_op_matrices = [ + array_ops.placeholder_with_default( + matrix, shape=None) for matrix in matrices] + + operator = linalg.LinearOperatorComposition( + [linalg.LinearOperatorFullMatrix(l) for l in lin_op_matrices]) + matmul_order_list = list(reversed(matrices)) - mat = ops.convert_to_tensor(matmul_order_list[0]) + mat = matmul_order_list[0] for other_mat in matmul_order_list[1:]: mat = math_ops.matmul(other_mat, mat) - return operator, mat, feed_dict + return operator, mat def test_static_shapes(self): operators = [ diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py index 0a0e31c716ecfa10ed93cff92fa908a240f8495e..83cc8c483f9aec6dd0ddf3f961a8180af7515e40 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py @@ -34,25 +34,21 @@ class LinearOperatorDiagTest( linear_operator_test_util.SquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) diag = linear_operator_test_util.random_sign_uniform( shape[:-1], minval=1., maxval=2., dtype=dtype) + + lin_op_diag = diag + if use_placeholder: - diag_ph = array_ops.placeholder(dtype=dtype) - # Evaluate the diag here because (i) you cannot feed a tensor, and (ii) - # diag is random and we want the same value used for both mat and - # feed_dict. - diag = diag.eval() - operator = linalg.LinearOperatorDiag(diag_ph) - feed_dict = {diag_ph: diag} - else: - operator = linalg.LinearOperatorDiag(diag) - feed_dict = None + lin_op_diag = array_ops.placeholder_with_default(diag, shape=None) + + operator = linalg.LinearOperatorDiag(lin_op_diag) - mat = array_ops.matrix_diag(diag) + matrix = array_ops.matrix_diag(diag) - return operator, mat, feed_dict + return operator, matrix def test_assert_positive_definite_raises_for_zero_eigenvalue(self): # Matrix with one positive eigenvalue and one zero eigenvalue. diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py index b3da623b5e8d8c99c6777e75e2d49f24dab1c96b..1a40a29ec6a040ca3d98e0b27492b1379d30cb4b 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py @@ -20,7 +20,6 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -36,30 +35,20 @@ class SquareLinearOperatorFullMatrixTest( linear_operator_test_util.SquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) matrix = linear_operator_test_util.random_positive_definite_matrix( shape, dtype) + lin_op_matrix = matrix + if use_placeholder: - matrix_ph = array_ops.placeholder(dtype=dtype) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrix = matrix.eval() - operator = linalg.LinearOperatorFullMatrix(matrix_ph, is_square=True) - feed_dict = {matrix_ph: matrix} - else: - # is_square should be auto-detected here. - operator = linalg.LinearOperatorFullMatrix(matrix) - feed_dict = None + lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None) - # Convert back to Tensor. Needed if use_placeholder, since then we have - # already evaluated matrix to a numpy array. - mat = ops.convert_to_tensor(matrix) + operator = linalg.LinearOperatorFullMatrix(lin_op_matrix, is_square=True) - return operator, mat, feed_dict + return operator, matrix def test_is_x_flags(self): # Matrix with two positive eigenvalues. @@ -136,32 +125,20 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest( def _dtypes_to_test(self): return [dtypes.float32, dtypes.float64] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) matrix = linear_operator_test_util.random_positive_definite_matrix( shape, dtype, force_well_conditioned=True) + lin_op_matrix = matrix + if use_placeholder: - matrix_ph = array_ops.placeholder(dtype=dtype) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrix = matrix.eval() - # is_square is auto-set because of self_adjoint/pd. - operator = linalg.LinearOperatorFullMatrix( - matrix_ph, is_self_adjoint=True, is_positive_definite=True) - feed_dict = {matrix_ph: matrix} - else: - operator = linalg.LinearOperatorFullMatrix( - matrix, is_self_adjoint=True, is_positive_definite=True) - feed_dict = None - - # Convert back to Tensor. Needed if use_placeholder, since then we have - # already evaluated matrix to a numpy array. - mat = ops.convert_to_tensor(matrix) - - return operator, mat, feed_dict + lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None) + + operator = linalg.LinearOperatorFullMatrix(lin_op_matrix, is_square=True) + + return operator, matrix def test_is_x_flags(self): # Matrix with two positive eigenvalues. @@ -210,26 +187,18 @@ class NonSquareLinearOperatorFullMatrixTest( linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) matrix = linear_operator_test_util.random_normal(shape, dtype=dtype) + + lin_op_matrix = matrix + if use_placeholder: - matrix_ph = array_ops.placeholder(dtype=dtype) - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrix = matrix.eval() - operator = linalg.LinearOperatorFullMatrix(matrix_ph) - feed_dict = {matrix_ph: matrix} - else: - operator = linalg.LinearOperatorFullMatrix(matrix) - feed_dict = None + lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None) - # Convert back to Tensor. Needed if use_placeholder, since then we have - # already evaluated matrix to a numpy array. - mat = ops.convert_to_tensor(matrix) + operator = linalg.LinearOperatorFullMatrix(lin_op_matrix, is_square=True) - return operator, mat, feed_dict + return operator, matrix def test_is_x_flags(self): matrix = [[3., 2., 1.], [1., 1., 1.]] diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py index 59f63f949e96991193412d3574603e58a75cb6e5..35dcf4417c313f5cbc00c8b66b4c5d1f2e157212 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py @@ -43,7 +43,7 @@ class LinearOperatorIdentityTest( # 16bit. return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) assert shape[-1] == shape[-2] @@ -54,13 +54,7 @@ class LinearOperatorIdentityTest( num_rows, batch_shape=batch_shape, dtype=dtype) mat = linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=dtype) - # Nothing to feed since LinearOperatorIdentity takes no Tensor args. - if use_placeholder: - feed_dict = {} - else: - feed_dict = None - - return operator, mat, feed_dict + return operator, mat def test_assert_positive_definite(self): with self.test_session(): @@ -261,7 +255,7 @@ class LinearOperatorScaledIdentityTest( # 16bit. return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) assert shape[-1] == shape[-2] @@ -274,24 +268,23 @@ class LinearOperatorScaledIdentityTest( multiplier = linear_operator_test_util.random_sign_uniform( shape=batch_shape, minval=1., maxval=2., dtype=dtype) - operator = linalg_lib.LinearOperatorScaledIdentity(num_rows, multiplier) # Nothing to feed since LinearOperatorScaledIdentity takes no Tensor args. + lin_op_multiplier = multiplier + if use_placeholder: - multiplier_ph = array_ops.placeholder(dtype=dtype) - multiplier = multiplier.eval() - operator = linalg_lib.LinearOperatorScaledIdentity( - num_rows, multiplier_ph) - feed_dict = {multiplier_ph: multiplier} - else: - feed_dict = None + lin_op_multiplier = array_ops.placeholder_with_default( + multiplier, shape=None) + + operator = linalg_lib.LinearOperatorScaledIdentity( + num_rows, lin_op_multiplier) multiplier_matrix = array_ops.expand_dims( array_ops.expand_dims(multiplier, -1), -1) - mat = multiplier_matrix * linalg_ops.eye( + matrix = multiplier_matrix * linalg_ops.eye( num_rows, batch_shape=batch_shape, dtype=dtype) - return operator, mat, feed_dict + return operator, matrix def test_assert_positive_definite_does_not_raise_when_positive(self): with self.test_session(): diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py index 784c730bbc8179dd1302294b2d558e8a0c532c0c..e26b946151dd8ddb923e34352feb6b483f9752fc 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py @@ -101,7 +101,7 @@ class SquareLinearOperatorKroneckerTest( def _tests_to_skip(self): return ["det", "solve", "solve_with_broadcast"] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) expected_factors = build_info.__dict__["factors"] matrices = [ @@ -110,26 +110,15 @@ class SquareLinearOperatorKroneckerTest( for block_shape in expected_factors ] + lin_op_matrices = matrices + if use_placeholder: - matrices_ph = [ - array_ops.placeholder(dtype=dtype) for _ in expected_factors - ] - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - matrices = self.evaluate(matrices) - operator = kronecker.LinearOperatorKronecker( - [linalg.LinearOperatorFullMatrix( - m_ph, is_square=True) for m_ph in matrices_ph], - is_square=True) - feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)} - else: - operator = kronecker.LinearOperatorKronecker( - [linalg.LinearOperatorFullMatrix( - m, is_square=True) for m in matrices]) - feed_dict = None - # Should be auto-set. - self.assertTrue(operator.is_square) + lin_op_matrices = [ + array_ops.placeholder_with_default(m, shape=None) for m in matrices] + + operator = kronecker.LinearOperatorKronecker( + [linalg.LinearOperatorFullMatrix( + l, is_square=True) for l in lin_op_matrices]) matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices) @@ -138,7 +127,7 @@ class SquareLinearOperatorKroneckerTest( if not use_placeholder: kronecker_dense.set_shape(shape) - return operator, kronecker_dense, feed_dict + return operator, kronecker_dense def test_is_x_flags(self): # Matrix with two positive eigenvalues, 1, and 1. diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py index 8095f6419ef0d9543339cf1f4ee9cd4783f852b9..0e38dbd48d2252be4b3f0455ad69994ac5814126 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py @@ -48,12 +48,6 @@ class BaseLinearOperatorLowRankUpdatetest(object): # If False, A = L + UDU^H or A = L + UU^H, depending on _use_diag_update _use_v = None - @property - def _dtypes_to_test(self): - # TODO(langmore) Test complex types once cholesky works with them. - # See comment in LinearOperatorLowRankUpdate.__init__. - return [dtypes.float32, dtypes.float64] - @property def _operator_build_infos(self): build_info = linear_operator_test_util.OperatorBuildInfo @@ -68,7 +62,16 @@ class BaseLinearOperatorLowRankUpdatetest(object): build_info((3, 4, 4)), build_info((2, 1, 4, 4))] - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _gen_positive_diag(self, dtype, diag_shape): + if dtype.is_complex: + diag = linear_operator_test_util.random_uniform( + diag_shape, minval=1e-4, maxval=1., dtype=dtypes.float32) + return math_ops.cast(diag, dtype=dtype) + + return linear_operator_test_util.random_uniform( + diag_shape, minval=1e-4, maxval=1., dtype=dtype) + + def _operator_and_matrix(self, build_info, dtype, use_placeholder): # Recall A = L + UDV^H shape = list(build_info.shape) diag_shape = shape[:-1] @@ -78,63 +81,46 @@ class BaseLinearOperatorLowRankUpdatetest(object): # base_operator L will be a symmetric positive definite diagonal linear # operator, with condition number as high as 1e4. - base_diag = linear_operator_test_util.random_uniform( - diag_shape, minval=1e-4, maxval=1., dtype=dtype) - base_diag_ph = array_ops.placeholder(dtype=dtype) + base_diag = self._gen_positive_diag(dtype, diag_shape) + lin_op_base_diag = base_diag # U u = linear_operator_test_util.random_normal_correlated_columns( u_perturbation_shape, dtype=dtype) - u_ph = array_ops.placeholder(dtype=dtype) + lin_op_u = u # V v = linear_operator_test_util.random_normal_correlated_columns( u_perturbation_shape, dtype=dtype) - v_ph = array_ops.placeholder(dtype=dtype) + lin_op_v = v # D if self._is_diag_update_positive: - diag_update = linear_operator_test_util.random_uniform( - diag_update_shape, minval=1e-4, maxval=1., dtype=dtype) + diag_update = self._gen_positive_diag(dtype, diag_update_shape) else: diag_update = linear_operator_test_util.random_normal( diag_update_shape, stddev=1e-4, dtype=dtype) - diag_update_ph = array_ops.placeholder(dtype=dtype) + lin_op_diag_update = diag_update if use_placeholder: - # Evaluate here because (i) you cannot feed a tensor, and (ii) - # values are random and we want the same value used for both mat and - # feed_dict. - base_diag = base_diag.eval() - u = u.eval() - v = v.eval() - diag_update = diag_update.eval() - - # In all cases, set base_operator to be positive definite. - base_operator = linalg.LinearOperatorDiag( - base_diag_ph, is_positive_definite=True) - - operator = linalg.LinearOperatorLowRankUpdate( - base_operator, - u=u_ph, - v=v_ph if self._use_v else None, - diag_update=diag_update_ph if self._use_diag_update else None, - is_diag_update_positive=self._is_diag_update_positive) - feed_dict = { - base_diag_ph: base_diag, - u_ph: u, - v_ph: v, - diag_update_ph: diag_update} - else: - base_operator = linalg.LinearOperatorDiag( - base_diag, is_positive_definite=True) - operator = linalg.LinearOperatorLowRankUpdate( - base_operator, - u, - v=v if self._use_v else None, - diag_update=diag_update if self._use_diag_update else None, - is_diag_update_positive=self._is_diag_update_positive) - feed_dict = None + lin_op_base_diag = array_ops.placeholder_with_default( + base_diag, shape=None) + lin_op_u = array_ops.placeholder_with_default(u, shape=None) + lin_op_v = array_ops.placeholder_with_default(v, shape=None) + lin_op_diag_update = array_ops.placeholder_with_default( + diag_update, shape=None) + + base_operator = linalg.LinearOperatorDiag( + lin_op_base_diag, + is_positive_definite=True, + is_self_adjoint=True) + + operator = linalg.LinearOperatorLowRankUpdate( + base_operator, + lin_op_u, + v=lin_op_v if self._use_v else None, + diag_update=lin_op_diag_update if self._use_diag_update else None, + is_diag_update_positive=self._is_diag_update_positive) # The matrix representing L base_diag_mat = array_ops.matrix_diag(base_diag) @@ -146,28 +132,28 @@ class BaseLinearOperatorLowRankUpdatetest(object): if self._use_v and self._use_diag_update: # In this case, we have L + UDV^H and it isn't symmetric. expect_use_cholesky = False - mat = base_diag_mat + math_ops.matmul( + matrix = base_diag_mat + math_ops.matmul( u, math_ops.matmul(diag_update_mat, v, adjoint_b=True)) elif self._use_v: # In this case, we have L + UDV^H and it isn't symmetric. expect_use_cholesky = False - mat = base_diag_mat + math_ops.matmul(u, v, adjoint_b=True) + matrix = base_diag_mat + math_ops.matmul(u, v, adjoint_b=True) elif self._use_diag_update: # In this case, we have L + UDU^H, which is PD if D > 0, since L > 0. expect_use_cholesky = self._is_diag_update_positive - mat = base_diag_mat + math_ops.matmul( + matrix = base_diag_mat + math_ops.matmul( u, math_ops.matmul(diag_update_mat, u, adjoint_b=True)) else: # In this case, we have L + UU^H, which is PD since L > 0. expect_use_cholesky = True - mat = base_diag_mat + math_ops.matmul(u, u, adjoint_b=True) + matrix = base_diag_mat + math_ops.matmul(u, u, adjoint_b=True) if expect_use_cholesky: self.assertTrue(operator._use_cholesky) else: self.assertFalse(operator._use_cholesky) - return operator, mat, feed_dict + return operator, matrix class LinearOperatorLowRankUpdatetestWithDiagUseCholesky( @@ -186,6 +172,7 @@ class LinearOperatorLowRankUpdatetestWithDiagUseCholesky( self._rtol[dtypes.float32] = 1e-5 self._atol[dtypes.float64] = 1e-10 self._rtol[dtypes.float64] = 1e-10 + self._rtol[dtypes.complex64] = 1e-4 class LinearOperatorLowRankUpdatetestWithDiagCannotUseCholesky( @@ -205,6 +192,7 @@ class LinearOperatorLowRankUpdatetestWithDiagCannotUseCholesky( self._rtol[dtypes.float32] = 1e-4 self._atol[dtypes.float64] = 1e-9 self._rtol[dtypes.float64] = 1e-9 + self._rtol[dtypes.complex64] = 1e-4 class LinearOperatorLowRankUpdatetestNoDiagUseCholesky( @@ -223,6 +211,7 @@ class LinearOperatorLowRankUpdatetestNoDiagUseCholesky( self._rtol[dtypes.float32] = 1e-5 self._atol[dtypes.float64] = 1e-10 self._rtol[dtypes.float64] = 1e-10 + self._rtol[dtypes.complex64] = 1e-4 class LinearOperatorLowRankUpdatetestNoDiagCannotUseCholesky( @@ -242,6 +231,7 @@ class LinearOperatorLowRankUpdatetestNoDiagCannotUseCholesky( self._rtol[dtypes.float32] = 1e-4 self._atol[dtypes.float64] = 1e-9 self._rtol[dtypes.float64] = 1e-9 + self._rtol[dtypes.complex64] = 1e-4 class LinearOperatorLowRankUpdatetestWithDiagNotSquare( diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py index a57d2f085e089fb913f09fdd9b07cf13aa7f3c35..b389e0cbdf72f2cd43751bd75e5b103b313df4b7 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops.linalg import linalg as linalg_lib @@ -32,34 +31,23 @@ class LinearOperatorLowerTriangularTest( linear_operator_test_util.SquareLinearOperatorDerivedClassTest): """Most tests done in the base class LinearOperatorDerivedClassTest.""" - @property - def _dtypes_to_test(self): - # TODO(langmore) Test complex types once supported by - # matrix_triangular_solve. - return [dtypes.float32, dtypes.float64] - - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): shape = list(build_info.shape) # Upper triangle will be nonzero, but ignored. # Use a diagonal that ensures this matrix is well conditioned. tril = linear_operator_test_util.random_tril_matrix( shape, dtype=dtype, force_well_conditioned=True, remove_upper=False) + lin_op_tril = tril + if use_placeholder: - tril_ph = array_ops.placeholder(dtype=dtype) - # Evaluate the tril here because (i) you cannot feed a tensor, and (ii) - # tril is random and we want the same value used for both mat and - # feed_dict. - tril = tril.eval() - operator = linalg.LinearOperatorLowerTriangular(tril_ph) - feed_dict = {tril_ph: tril} - else: - operator = linalg.LinearOperatorLowerTriangular(tril) - feed_dict = None + lin_op_tril = array_ops.placeholder_with_default(lin_op_tril, shape=None) + + operator = linalg.LinearOperatorLowerTriangular(lin_op_tril) - mat = array_ops.matrix_band_part(tril, -1, 0) + matrix = array_ops.matrix_band_part(tril, -1, 0) - return operator, mat, feed_dict + return operator, matrix def test_assert_non_singular(self): # Singlular matrix with one positive eigenvalue and one zero eigenvalue. diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_zeros_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_zeros_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8f60b55e0ad416ea1f09996f633e09a7dc2c3741 --- /dev/null +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_zeros_test.py @@ -0,0 +1,192 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.linalg import linalg as linalg_lib +from tensorflow.python.ops.linalg import linear_operator_test_util +from tensorflow.python.platform import test + + +random_seed.set_random_seed(23) +rng = np.random.RandomState(2016) + + +class LinearOperatorZerosTest( + linear_operator_test_util.SquareLinearOperatorDerivedClassTest): + """Most tests done in the base class LinearOperatorDerivedClassTest.""" + + @property + def _tests_to_skip(self): + return ["log_abs_det", "solve", "solve_with_broadcast"] + + @property + def _operator_build_infos(self): + build_info = linear_operator_test_util.OperatorBuildInfo + return [ + build_info((1, 1)), + build_info((1, 3, 3)), + build_info((3, 4, 4)), + build_info((2, 1, 4, 4))] + + def _operator_and_matrix(self, build_info, dtype, use_placeholder): + del use_placeholder + shape = list(build_info.shape) + assert shape[-1] == shape[-2] + + batch_shape = shape[:-2] + num_rows = shape[-1] + + operator = linalg_lib.LinearOperatorZeros( + num_rows, batch_shape=batch_shape, dtype=dtype) + matrix = array_ops.zeros(shape=shape, dtype=dtype) + + return operator, matrix + + def test_assert_positive_definite(self): + operator = linalg_lib.LinearOperatorZeros(num_rows=2) + with self.assertRaisesOpError("non-positive definite"): + operator.assert_positive_definite() + + def test_assert_non_singular(self): + with self.assertRaisesOpError("non-invertible"): + operator = linalg_lib.LinearOperatorZeros(num_rows=2) + operator.assert_non_singular() + + def test_assert_self_adjoint(self): + with self.test_session(): + operator = linalg_lib.LinearOperatorZeros(num_rows=2) + operator.assert_self_adjoint().run() # Should not fail + + def test_non_scalar_num_rows_raises_static(self): + with self.assertRaisesRegexp(ValueError, "must be a 0-D Tensor"): + linalg_lib.LinearOperatorZeros(num_rows=[2]) + with self.assertRaisesRegexp(ValueError, "must be a 0-D Tensor"): + linalg_lib.LinearOperatorZeros(num_rows=2, num_columns=[2]) + + def test_non_integer_num_rows_raises_static(self): + with self.assertRaisesRegexp(TypeError, "must be integer"): + linalg_lib.LinearOperatorZeros(num_rows=2.) + with self.assertRaisesRegexp(TypeError, "must be integer"): + linalg_lib.LinearOperatorZeros(num_rows=2, num_columns=2.) + + def test_negative_num_rows_raises_static(self): + with self.assertRaisesRegexp(ValueError, "must be non-negative"): + linalg_lib.LinearOperatorZeros(num_rows=-2) + with self.assertRaisesRegexp(ValueError, "must be non-negative"): + linalg_lib.LinearOperatorZeros(num_rows=2, num_columns=-2) + + def test_non_1d_batch_shape_raises_static(self): + with self.assertRaisesRegexp(ValueError, "must be a 1-D"): + linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=2) + + def test_non_integer_batch_shape_raises_static(self): + with self.assertRaisesRegexp(TypeError, "must be integer"): + linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=[2.]) + + def test_negative_batch_shape_raises_static(self): + with self.assertRaisesRegexp(ValueError, "must be non-negative"): + linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=[-2]) + + def test_non_scalar_num_rows_raises_dynamic(self): + with self.test_session(): + num_rows = array_ops.placeholder(dtypes.int32) + operator = linalg_lib.LinearOperatorZeros( + num_rows, assert_proper_shapes=True) + with self.assertRaisesOpError("must be a 0-D Tensor"): + operator.to_dense().eval(feed_dict={num_rows: [2]}) + + def test_negative_num_rows_raises_dynamic(self): + with self.test_session(): + n = array_ops.placeholder(dtypes.int32) + operator = linalg_lib.LinearOperatorZeros( + num_rows=n, assert_proper_shapes=True) + with self.assertRaisesOpError("must be non-negative"): + operator.to_dense().eval(feed_dict={n: -2}) + + operator = linalg_lib.LinearOperatorZeros( + num_rows=2, num_columns=n, assert_proper_shapes=True) + with self.assertRaisesOpError("must be non-negative"): + operator.to_dense().eval(feed_dict={n: -2}) + + def test_non_1d_batch_shape_raises_dynamic(self): + with self.test_session(): + batch_shape = array_ops.placeholder(dtypes.int32) + operator = linalg_lib.LinearOperatorZeros( + num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True) + with self.assertRaisesOpError("must be a 1-D"): + operator.to_dense().eval(feed_dict={batch_shape: 2}) + + def test_negative_batch_shape_raises_dynamic(self): + with self.test_session(): + batch_shape = array_ops.placeholder(dtypes.int32) + operator = linalg_lib.LinearOperatorZeros( + num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True) + with self.assertRaisesOpError("must be non-negative"): + operator.to_dense().eval(feed_dict={batch_shape: [-2]}) + + def test_wrong_matrix_dimensions_raises_static(self): + operator = linalg_lib.LinearOperatorZeros(num_rows=2) + x = rng.randn(3, 3).astype(np.float32) + with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"): + operator.matmul(x) + + def test_wrong_matrix_dimensions_raises_dynamic(self): + num_rows = array_ops.placeholder(dtypes.int32) + x = array_ops.placeholder(dtypes.float32) + + with self.test_session(): + operator = linalg_lib.LinearOperatorZeros( + num_rows, assert_proper_shapes=True) + y = operator.matmul(x) + with self.assertRaisesOpError("Incompatible.*dimensions"): + y.eval(feed_dict={num_rows: 2, x: rng.rand(3, 3)}) + + def test_is_x_flags(self): + # The is_x flags are by default all True. + operator = linalg_lib.LinearOperatorZeros(num_rows=2) + self.assertFalse(operator.is_positive_definite) + self.assertFalse(operator.is_non_singular) + self.assertTrue(operator.is_self_adjoint) + + +class LinearOperatorZerosNotSquareTest( + linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest): + + def _operator_and_matrix(self, build_info, dtype, use_placeholder): + del use_placeholder + shape = list(build_info.shape) + + batch_shape = shape[:-2] + num_rows = shape[-2] + num_columns = shape[-1] + + operator = linalg_lib.LinearOperatorZeros( + num_rows, num_columns, is_square=False, is_self_adjoint=False, + batch_shape=batch_shape, dtype=dtype) + matrix = array_ops.zeros(shape=shape, dtype=dtype) + + return operator, matrix + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py index 6f401358a2519a699488f0372323b5a41621c4cd..0e4e58409ece3437f2489b4c7f46ef9c66602e71 100644 --- a/tensorflow/python/kernel_tests/linalg_grad_test.py +++ b/tensorflow/python/kernel_tests/linalg_grad_test.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops.linalg import linalg_impl from tensorflow.python.platform import test as test_lib @@ -173,6 +174,10 @@ if __name__ == '__main__': _AddTest(MatrixUnaryFunctorGradientTest, 'MatrixInverseGradient', name, _GetMatrixUnaryFunctorGradientTest(linalg_ops.matrix_inverse, dtype, shape)) + _AddTest(MatrixUnaryFunctorGradientTest, 'MatrixExponentialGradient', + name, + _GetMatrixUnaryFunctorGradientTest( + linalg_impl.matrix_exponential, dtype, shape)) _AddTest( MatrixUnaryFunctorGradientTest, 'MatrixDeterminantGradient', name, _GetMatrixUnaryFunctorGradientTest(linalg_ops.matrix_determinant, diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index 49855200c2427a88a4bd582c2ef786c38a6fa76a..bf82e08551e6a276b95bf77f7932c31d7a844a78 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -46,7 +46,7 @@ def scalar_shape(): @test_util.with_c_shapes class ListOpsTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPushPop(self): l = list_ops.empty_tensor_list(element_dtype=dtypes.float32, element_shape=scalar_shape()) @@ -54,14 +54,14 @@ class ListOpsTest(test_util.TensorFlowTestCase): l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) self.assertAllEqual(self.evaluate(e), 1.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPushPopGPU(self): if not context.num_gpus(): return with context.device("gpu:0"): self.testPushPop() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testStack(self): l = list_ops.empty_tensor_list(element_dtype=dtypes.float32, element_shape=scalar_shape()) @@ -70,14 +70,14 @@ class ListOpsTest(test_util.TensorFlowTestCase): t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) self.assertAllEqual(self.evaluate(t), [1.0, 2.0]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testStackGPU(self): if not context.num_gpus(): return with context.device("gpu:0"): self.testStack() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorListFromTensor(self): t = constant_op.constant([1.0, 2.0]) l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape()) @@ -87,14 +87,14 @@ class ListOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual(self.evaluate(e), 1.0) self.assertAllEqual(self.evaluate(list_ops.tensor_list_length(l)), 0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testFromTensorGPU(self): if not context.num_gpus(): return with context.device("gpu:0"): self.testTensorListFromTensor() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetSetItem(self): t = constant_op.constant([1.0, 2.0]) l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape()) @@ -104,14 +104,14 @@ class ListOpsTest(test_util.TensorFlowTestCase): t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32) self.assertAllEqual(self.evaluate(t), [3.0, 2.0]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetSetGPU(self): if not context.num_gpus(): return with context.device("gpu:0"): self.testGetSetItem() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testUnknownShape(self): l = list_ops.empty_tensor_list( element_dtype=dtypes.float32, element_shape=-1) @@ -122,7 +122,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) self.assertAllEqual(self.evaluate(e), 1.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCPUGPUCopy(self): if not context.num_gpus(): return @@ -140,7 +140,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): list_ops.tensor_list_pop_back( l_cpu, element_dtype=dtypes.float32)[1]), 2.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGraphStack(self): with context.graph_mode(), self.test_session(): tl = list_ops.empty_tensor_list( @@ -152,7 +152,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32)), [[1]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGraphStackInLoop(self): with context.graph_mode(), self.test_session(): t1 = list_ops.empty_tensor_list( @@ -170,7 +170,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32) self.assertAllEqual(self.evaluate(s1), [0, 1, 2, 3]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGraphStackSwitchDtype(self): with context.graph_mode(), self.test_session(): list_ = list_ops.empty_tensor_list( @@ -192,7 +192,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): np_s1 = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) self.assertAllEqual(self.evaluate(s1), np_s1) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGraphStackInLoopSwitchDtype(self): with context.graph_mode(), self.test_session(): t1 = list_ops.empty_tensor_list( @@ -216,7 +216,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): np_s1 = np.vstack([np.arange(1, 4) * i for i in range(4)]) self.assertAllEqual(self.evaluate(s1), np_s1) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSerialize(self): # pylint: disable=g-import-not-at-top try: @@ -248,7 +248,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): worker_e = array_ops.identity(e) self.assertAllEqual(self.evaluate(worker_e), [2.0]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPushPopGradients(self): with backprop.GradientTape() as tape: l = list_ops.empty_tensor_list(element_dtype=dtypes.float32, @@ -260,7 +260,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): e = 2 * e self.assertAllEqual(self.evaluate(tape.gradient(e, [c])[0]), 2.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testStackFromTensorGradients(self): with backprop.GradientTape() as tape: c = constant_op.constant([1.0, 2.0]) @@ -272,7 +272,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): grad = tape.gradient(result, [c])[0] self.assertAllEqual(self.evaluate(grad), [2.0, 2.0]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetSetGradients(self): with backprop.GradientTape() as tape: c = constant_op.constant([1.0, 2.0]) @@ -288,14 +288,14 @@ class ListOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual(self.evaluate(grad_c), [0.0, 4.0]) self.assertAllEqual(self.evaluate(grad_c2), 6.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSetOutOfBounds(self): c = constant_op.constant([1.0, 2.0]) l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape()) with self.assertRaises(errors.InvalidArgumentError): self.evaluate(list_ops.tensor_list_set_item(l, 20, 3.0)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testResourceVariableScatterGather(self): c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32) l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape()) @@ -319,7 +319,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): [[1.0, 2.0]] * 4) self.assertAllEqual(self.evaluate(updated_v_stacked), expected) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConcat(self): c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32) l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape()) @@ -379,7 +379,7 @@ class ListOpsTest(test_util.TensorFlowTestCase): list_ops.tensor_list_concat_lists(l_batch_0, l_batch_of_int_tls, element_dtype=dtypes.float32)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPushBackBatch(self): c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32) l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape()) diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py index 28c85fa13ad100c38382d2b787ff965f9e3ca44e..e635a71c78484278b54bfc4de70e232834c37a0a 100644 --- a/tensorflow/python/kernel_tests/logging_ops_test.py +++ b/tensorflow/python/kernel_tests/logging_ops_test.py @@ -59,7 +59,7 @@ class LoggingOpsTest(test.TestCase): class PrintGradientTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPrintShape(self): inp = constant_op.constant(2.0, shape=[100, 32]) inp_printed = logging_ops.Print(inp, [inp]) diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py index 1123c20a165ba93bd380fa471a8be91f7005d7bb..87fc715783b972a20465827d697cf06637588154 100644 --- a/tensorflow/python/kernel_tests/losses_test.py +++ b/tensorflow/python/kernel_tests/losses_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -118,6 +119,14 @@ class AbsoluteDifferenceLossTest(test.TestCase): with self.test_session(): self.assertAlmostEqual(0.0, loss.eval(), 3) + @test_util.assert_no_new_pyobjects_executing_eagerly + def testEagerNoMemoryLeaked(self): + # This is a somewhat convoluted way of testing that nothing gets added to + # a global collection. + predictions = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3)) + labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + losses.absolute_difference(labels, predictions) + class SoftmaxCrossEntropyLossTest(test.TestCase): @@ -246,6 +255,13 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase): self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') self.assertAlmostEqual(loss.eval(), 0.0, 3) + @test_util.assert_no_new_pyobjects_executing_eagerly + def testEagerNoMemoryLeaked(self): + logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = constant_op.constant([[0], [1], [2]], dtype=dtypes.int32) + losses.sparse_softmax_cross_entropy(labels, logits) + def testAllCorrectInt64Labels(self): with self.test_session(): logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], diff --git a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py index a0c66c77d8850d3144678870983730537a253556..0386e91276eb3cd8515c41396e8687a90e27fbca 100644 --- a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py @@ -12,33 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for tensorflow.ops.gen_linalg_ops.matrix_exponential.""" +"""Tests for tensorflow.ops.linalg.linalg_impl.matrix_exponential.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import itertools -import math import numpy as np from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gen_linalg_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables +from tensorflow.python.ops.linalg import linalg_impl from tensorflow.python.platform import test -def np_expm(x): +def np_expm(x): # pylint: disable=invalid-name """Slow but accurate Taylor series matrix exponential.""" y = np.zeros(x.shape, dtype=x.dtype) xn = np.eye(x.shape[0], dtype=x.dtype) for n in range(40): - y += xn / float(math.factorial(n)) + if n > 0: + xn /= float(n) + y += xn xn = np.dot(xn, x) return y @@ -48,7 +50,7 @@ class ExponentialOpTest(test.TestCase): def _verifyExponential(self, x, np_type): inp = x.astype(np_type) with self.test_session(use_gpu=True): - tf_ans = gen_linalg_ops.matrix_exponential(inp) + tf_ans = linalg_impl.matrix_exponential(inp) if x.size == 0: np_ans = np.empty(x.shape, dtype=np_type) else: @@ -76,7 +78,7 @@ class ExponentialOpTest(test.TestCase): matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1]) return matrix_batch - def testNonsymmetric(self): + def testNonsymmetricReal(self): # 2x2 matrices matrix1 = np.array([[1., 2.], [3., 4.]]) matrix2 = np.array([[1., 3.], [3., 5.]]) @@ -84,7 +86,10 @@ class ExponentialOpTest(test.TestCase): self._verifyExponentialReal(matrix2) # A multidimensional batch of 2x2 matrices self._verifyExponentialReal(self._makeBatch(matrix1, matrix2)) - # Complex + + def testNonsymmetricComplex(self): + matrix1 = np.array([[1., 2.], [3., 4.]]) + matrix2 = np.array([[1., 3.], [3., 5.]]) matrix1 = matrix1.astype(np.complex64) matrix1 += 1j * matrix1 matrix2 = matrix2.astype(np.complex64) @@ -94,7 +99,7 @@ class ExponentialOpTest(test.TestCase): # Complex batch self._verifyExponentialComplex(self._makeBatch(matrix1, matrix2)) - def testSymmetricPositiveDefinite(self): + def testSymmetricPositiveDefiniteReal(self): # 2x2 matrices matrix1 = np.array([[2., 1.], [1., 2.]]) matrix2 = np.array([[3., -1.], [-1., 3.]]) @@ -102,7 +107,10 @@ class ExponentialOpTest(test.TestCase): self._verifyExponentialReal(matrix2) # A multidimensional batch of 2x2 matrices self._verifyExponentialReal(self._makeBatch(matrix1, matrix2)) - # Complex + + def testSymmetricPositiveDefiniteComplex(self): + matrix1 = np.array([[2., 1.], [1., 2.]]) + matrix2 = np.array([[3., -1.], [-1., 3.]]) matrix1 = matrix1.astype(np.complex64) matrix1 += 1j * matrix1 matrix2 = matrix2.astype(np.complex64) @@ -116,35 +124,31 @@ class ExponentialOpTest(test.TestCase): # When the exponential of a non-square matrix is attempted we should return # an error with self.assertRaises(ValueError): - gen_linalg_ops.matrix_exponential(np.array([[1., 2., 3.], [3., 4., 5.]])) + linalg_impl.matrix_exponential(np.array([[1., 2., 3.], [3., 4., 5.]])) def testWrongDimensions(self): # The input to the exponential should be at least a 2-dimensional tensor. tensor3 = constant_op.constant([1., 2.]) with self.assertRaises(ValueError): - gen_linalg_ops.matrix_exponential(tensor3) + linalg_impl.matrix_exponential(tensor3) def testEmpty(self): self._verifyExponentialReal(np.empty([0, 2, 2])) self._verifyExponentialReal(np.empty([2, 0, 0])) - def testRandomSmallAndLarge(self): - np.random.seed(42) - for dtype in np.float32, np.float64, np.complex64, np.complex128: - for batch_dims in [(), (1,), (3,), (2, 2)]: - for size in 8, 31, 32: - shape = batch_dims + (size, size) - matrix = np.random.uniform( - low=-1.0, high=1.0, - size=np.prod(shape)).reshape(shape).astype(dtype) - self._verifyExponentialReal(matrix) + def testDynamic(self): + with self.test_session(use_gpu=True) as sess: + inp = array_ops.placeholder(ops.dtypes.float32) + expm = linalg_impl.matrix_exponential(inp) + matrix = np.array([[1., 2.], [3., 4.]]) + sess.run(expm, feed_dict={inp: matrix}) def testConcurrentExecutesWithoutError(self): with self.test_session(use_gpu=True) as sess: matrix1 = random_ops.random_normal([5, 5], seed=42) matrix2 = random_ops.random_normal([5, 5], seed=42) - expm1 = gen_linalg_ops.matrix_exponential(matrix1) - expm2 = gen_linalg_ops.matrix_exponential(matrix2) + expm1 = linalg_impl.matrix_exponential(matrix1) + expm2 = linalg_impl.matrix_exponential(matrix2) expm = sess.run([expm1, expm2]) self.assertAllEqual(expm[0], expm[1]) @@ -180,7 +184,7 @@ class MatrixExponentialBenchmark(test.Benchmark): session.Session() as sess, \ ops.device("/cpu:0"): matrix = self._GenerateMatrix(shape) - expm = gen_linalg_ops.matrix_exponential(matrix) + expm = linalg_impl.matrix_exponential(matrix) variables.global_variables_initializer().run() self.run_op_benchmark( sess, @@ -189,6 +193,66 @@ class MatrixExponentialBenchmark(test.Benchmark): name="matrix_exponential_cpu_{shape}".format( shape=shape)) + if test.is_gpu_available(True): + with ops.Graph().as_default(), \ + session.Session() as sess, \ + ops.device("/gpu:0"): + matrix = self._GenerateMatrix(shape) + expm = linalg_impl.matrix_exponential(matrix) + variables.global_variables_initializer().run() + self.run_op_benchmark( + sess, + control_flow_ops.group(expm), + min_iters=25, + name="matrix_exponential_gpu_{shape}".format( + shape=shape)) + + +def _TestRandomSmall(dtype, batch_dims, size): + + def Test(self): + np.random.seed(42) + shape = batch_dims + (size, size) + matrix = np.random.uniform( + low=-1.0, high=1.0, + size=shape).astype(dtype) + self._verifyExponentialReal(matrix) + + return Test + + +def _TestL1Norms(dtype, shape, scale): + + def Test(self): + np.random.seed(42) + matrix = np.random.uniform( + low=-1.0, high=1.0, + size=np.prod(shape)).reshape(shape).astype(dtype) + print(dtype, shape, scale, matrix) + l1_norm = np.max(np.sum(np.abs(matrix), axis=matrix.ndim-2)) + matrix /= l1_norm + self._verifyExponentialReal(scale * matrix) + + return Test + if __name__ == "__main__": + for dtype_ in [np.float32, np.float64, np.complex64, np.complex128]: + for batch_ in [(), (2,), (2, 2)]: + for size_ in [4, 7]: + name = "%s_%d_%d" % (dtype_.__name__, len(batch_), size_) + setattr(ExponentialOpTest, "testL1Norms_" + name, + _TestRandomSmall(dtype_, batch_, size_)) + + for shape_ in [(3, 3), (2, 3, 3)]: + for dtype_ in [np.float32, np.complex64]: + for scale_ in [0.1, 1.5, 5.0, 20.0]: + name = "%s_%d_%d" % (dtype_.__name__, len(shape_), int(scale_*10)) + setattr(ExponentialOpTest, "testL1Norms_" + name, + _TestL1Norms(dtype_, shape_, scale_)) + for dtype_ in [np.float64, np.complex128]: + for scale_ in [0.01, 0.2, 0.5, 1.5, 6.0, 25.0]: + name = "%s_%d_%d" % (dtype_.__name__, len(shape_), int(scale_*100)) + setattr(ExponentialOpTest, "testL1Norms_" + name, + _TestL1Norms(dtype_, shape_, scale_)) test.main() diff --git a/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py index d8ce9fffbd2bc0d18033339a02e0ad84f8f4c952..3cbbd48c8cb26d5cdb457c9599bfc9131000d174 100644 --- a/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py +++ b/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py @@ -82,7 +82,7 @@ def CheckGradConfigsToTest(): class DepthwiseConv2DTest(test.TestCase): # This is testing that depthwise_conv2d and depthwise_conv2d_native - # produce the same results. It also tests that NCHW and NWHC + # produce the same results. It also tests that NCHW and NHWC # formats agree, by comparing the depthwise_conv2d_native with # 'NCHW' format (with transposition) matches the 'NHWC' format using # the higher level interface. @@ -123,7 +123,7 @@ class DepthwiseConv2DTest(test.TestCase): native_t1 = t1 strides = [1, stride, stride, 1] if data_format == "NCHW": - # Transpose from NWHC input to NCHW + # Transpose from NHWC input to NCHW # Ex. [4, 5, 5, 48] to [4, 48, 5, 5] native_t1 = array_ops.transpose(t1, [0, 3, 1, 2]) strides = [1, 1, stride, stride] diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py index f5c6255c346961fec7245889229ea1c4b89fa388..ba9359d92344f371e27c199a8ba464b6eae24778 100644 --- a/tensorflow/python/kernel_tests/partitioned_variables_test.py +++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py @@ -25,12 +25,15 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import gradient_descent class PartitionerCreatorsTest(test.TestCase): @@ -543,32 +546,6 @@ class PartitionedVariablesTestCase(test.TestCase): partitioned_variables.create_partitioned_variables( [10, 43], [1, 50], rnd.initialized_value()) - def testControlDepsNone(self): - with self.test_session() as session: - c = constant_op.constant(1.0) - with ops.control_dependencies([c]): - # d get the control dependency. - d = constant_op.constant(2.0) - # Partitioned variables do not. - var_x = variable_scope.get_variable( - "x", - shape=[2], - initializer=init_ops.ones_initializer(), - partitioner=partitioned_variables.variable_axis_size_partitioner(4)) - - ops_before_read = session.graph.get_operations() - var_x.as_tensor() # Caches the ops for subsequent reads. - reading_ops = [ - op for op in session.graph.get_operations() - if op not in ops_before_read - ] - - self.assertEqual([c.op], d.op.control_inputs) - # Tests that no control dependencies are added to reading a partitioned - # variable which is similar to reading a variable. - for op in reading_ops: - self.assertEqual([], op.control_inputs) - def testConcat(self): with self.test_session() as session: var_x = variable_scope.get_variable( @@ -594,6 +571,57 @@ class PartitionedVariablesTestCase(test.TestCase): variables.global_variables_initializer().run() self.assertAllClose(value.eval(), var_x.as_tensor().eval()) + def testVariableCreationInALoop(self): + """Tests the variable created inside a loop can be used outside the loop.""" + with self.test_session(): + with variable_scope.variable_scope("ascope") as scope: + def Body(i, _): + var_x = variable_scope.get_variable( + "x", + shape=[2], + initializer=init_ops.ones_initializer(), + partitioner=partitioned_variables.variable_axis_size_partitioner( + 4)) + return (i + 1, var_x.as_tensor()) + + cond = lambda i, _: i < 2 + _, x = control_flow_ops.while_loop( + cond, Body, (0, constant_op.constant([7, 8], dtypes.float32))) + variables.global_variables_initializer().run() + self.assertAllClose([1.0, 1.0], x.eval()) + + scope.reuse_variables() + var_x = variable_scope.get_variable( + "x", + shape=[2], + initializer=init_ops.ones_initializer(), + partitioner=partitioned_variables.variable_axis_size_partitioner(4)) + + self.assertAllClose([1.0, 1.0], var_x.as_tensor().eval()) + + def testReadInWhileLoop(self): + """Tests the value is current (not cached) when read within a loop.""" + with self.test_session(): + var_x = variable_scope.get_variable( + "x", + shape=[2], + initializer=init_ops.ones_initializer(), + partitioner=partitioned_variables.variable_axis_size_partitioner(4)) + + def Body(i, _): + # Use a SGD step to update the variable's value. + loss = math_ops.reduce_sum(var_x) + optimizer = gradient_descent.GradientDescentOptimizer(1.0) + minimize = optimizer.minimize(loss * 0.7) + with ops.control_dependencies([minimize]): + return (i + 1, var_x.as_tensor()) + + cond = lambda i, _: i < 2 + _, x = control_flow_ops.while_loop( + cond, Body, (0, constant_op.constant([7, 8], dtypes.float32))) + variables.global_variables_initializer().run() + self.assertAllClose([-0.4, -0.4], x.eval()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index 4239151070b35ac11511360bd85c5c2424d80f3d..50154a45a8b58f270509e404737c8650cbd2c5ff 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -460,7 +460,7 @@ class PyFuncTest(test.TestCase): self.assertEqual(initial_size, script_ops._py_funcs.size()) # ----- Tests for eager_py_func ----- - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerSingleOutputInt32(self): a = array_ops.ones((3, 3), dtype=dtypes.int32) x = array_ops.ones((3, 1), dtype=dtypes.int32) @@ -468,7 +468,7 @@ class PyFuncTest(test.TestCase): ret = self.evaluate(output) self.assertAllEqual(ret, [[3], [3], [3]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerSingleOutputFloat32(self): with test_util.device(use_gpu=True): a = array_ops.ones((3, 3), dtype=dtypes.float32) @@ -477,7 +477,7 @@ class PyFuncTest(test.TestCase): ret = self.evaluate(output) self.assertAllClose(ret, [[3.0], [3.0], [3.0]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerArrayOutput(self): with test_util.device(use_gpu=True): a = array_ops.ones((3, 3), dtype=dtypes.float32) @@ -487,7 +487,7 @@ class PyFuncTest(test.TestCase): ret = self.evaluate(output) self.assertAllEqual(ret, [[[3.0], [3.0], [3.0]]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerReturnNone(self): with test_util.device(use_gpu=True): def no_return_value(): @@ -500,7 +500,7 @@ class PyFuncTest(test.TestCase): else: self.assertIsNone(ret) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerPyFuncInDefun(self): with test_util.device(use_gpu=True): def wrapper(): @@ -512,7 +512,7 @@ class PyFuncTest(test.TestCase): ret = self.evaluate(wrapped()) self.assertAllEqual(ret, [[3.0], [3.0], [3.0]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerExceptionHandling(self): with test_util.device(use_gpu=True): self._testExceptionHandling( @@ -531,7 +531,7 @@ class PyFuncTest(test.TestCase): self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerReturningVariableRaisesError(self): def return_variable(): return resource_variable_ops.ResourceVariable(0.0) @@ -542,7 +542,7 @@ class PyFuncTest(test.TestCase): return_variable, inp=[], Tout=dtypes.float32) self.evaluate(output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerGradientTape(self): def f(x): @@ -565,7 +565,7 @@ class PyFuncTest(test.TestCase): dy_dx = gradients_impl.gradients(y, x)[0] self.assertEqual(self.evaluate(dy_dx), 6.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerGradientTapeMultipleArgs(self): def f(x, y): @@ -616,6 +616,25 @@ class PyFuncTest(test.TestCase): self.assertEqual(y, 1.0) self.assertEqual(dy_dx, 2.0) + def testEagerRespectsDevicePlacmentOfOp(self): + + def f(x): + return math_ops.square(x) + + def g(x): + return math_ops.add(x, x) + + with ops.device("/CPU:0"): + # Explicitly ask for the py_funcs to execute on CPU, even if + # a GPU is available. + x = array_ops.placeholder(dtypes.float32) + y = script_ops.eager_py_func(func=f, inp=[x], Tout=dtypes.float32) + z = script_ops.eager_py_func(func=g, inp=[y], Tout=dtypes.float32) + + with self.test_session(use_gpu=True) as sess: + output = sess.run(z, feed_dict={x: 3.0}) + self.assertEqual(output, 18.0) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/random/BUILD b/tensorflow/python/kernel_tests/random/BUILD index acd7566eec8e3fffd74db33234b03a0c87427a3e..3b3a28fc9a24104cc9032ab23dfc51e690d3ec94 100644 --- a/tensorflow/python/kernel_tests/random/BUILD +++ b/tensorflow/python/kernel_tests/random/BUILD @@ -107,6 +107,23 @@ cuda_py_test( tags = ["nozapfhahn"], ) +cuda_py_test( + name = "random_grad_test", + size = "small", + srcs = ["random_grad_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:random_grad", + "//tensorflow/python:random_ops", + ], +) + cuda_py_test( name = "random_poisson_test", size = "medium", diff --git a/tensorflow/python/kernel_tests/random/multinomial_op_test.py b/tensorflow/python/kernel_tests/random/multinomial_op_test.py index 051c7d86bf2342f15b587fc350bfbede7fae2285..bd64d61af8e793e71a319b6ac1af95bd7dd16a3d 100644 --- a/tensorflow/python/kernel_tests/random/multinomial_op_test.py +++ b/tensorflow/python/kernel_tests/random/multinomial_op_test.py @@ -54,7 +54,7 @@ native_sampler = random_ops.multinomial class MultinomialTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSmallEntropy(self): random_seed.set_random_seed(1618) for output_dtype in [np.int32, np.int64]: diff --git a/tensorflow/python/kernel_tests/random/random_grad_test.py b/tensorflow/python/kernel_tests/random/random_grad_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c1d455b785bbf562fb41f30cab7e0bb723a7b894 --- /dev/null +++ b/tensorflow/python/kernel_tests/random/random_grad_test.py @@ -0,0 +1,240 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.random_grad.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_grad +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + + +class AddLeadingUnitDimensionsTest(test.TestCase): + + def testBasic(self): + ret = random_grad.add_leading_unit_dimensions(array_ops.ones([3, 2, 1]), 3) + self.assertAllEqual(ret.shape, [1, 1, 1, 3, 2, 1]) + + def testZeroExtraDimensions(self): + ret = random_grad.add_leading_unit_dimensions(array_ops.ones([3, 2, 1]), 0) + self.assertAllEqual(ret.shape, [3, 2, 1]) + + def testScalarInput(self): + ret = random_grad.add_leading_unit_dimensions(1.0, 2) + self.assertAllEqual(ret.shape, [1, 1]) + + def testUnknownShape(self): + x = array_ops.placeholder(dtypes.float32) + num_dimensions = array_ops.placeholder(dtypes.int32) + ret = random_grad.add_leading_unit_dimensions(x, num_dimensions) + with self.test_session() as sess: + ret_val = sess.run(ret, {x: np.ones([2, 2]), num_dimensions: 2}) + self.assertAllEqual(ret_val.shape, [1, 1, 2, 2]) + + +class RandomGammaGradTest(test.TestCase): + """Tests for derivative of a sample ~ Gamma(alpha, beta) wrt alpha and beta. + + The sample is an "implicit" function of alpha, beta and the independent random + noise u. The derivatives we are looking for are + d sample(alpha, beta, u) / dalpha (and dbeta). + + The derivative w.r.t. beta is computed by the standard automatic + differentiation, so we trust that it is computed correctly. + + The derivative w.r.t. alpha is computed by Eigen function, so we test it in + several ways. Unfortunately, the standard derivative checking by perturbing + the parameter is impossible here, because we cannot fix the value of u + in the random sampler. Instead, we compare the derivative for the given pair + of (sample, alpha) to the values computed in various ways, and also check + some statistical properties of the derivative. + """ + + def testGradientsShape(self): + shape = [2, 3] + alpha = array_ops.ones([2, 2]) + beta = array_ops.ones([1, 2]) + sample = random_ops.random_gamma(shape, alpha, beta) + grads_alpha, grads_beta = gradients_impl.gradients(sample, [alpha, beta]) + self.assertAllEqual(grads_alpha.shape, alpha.shape) + self.assertAllEqual(grads_beta.shape, beta.shape) + + def testGradientsShapeWithOneSamplePerParameter(self): + shape = [] + alpha = array_ops.ones([2, 2]) + beta = array_ops.ones([1, 2]) + sample = random_ops.random_gamma(shape, alpha, beta) + grads_alpha, grads_beta = gradients_impl.gradients(sample, [alpha, beta]) + self.assertAllEqual(grads_alpha.shape, alpha.shape) + self.assertAllEqual(grads_beta.shape, beta.shape) + + def testGradientsUnknownShape(self): + shape = array_ops.placeholder(dtypes.int32) + alpha = array_ops.placeholder(dtypes.float32) + beta = array_ops.placeholder(dtypes.float32) + sample = random_ops.random_gamma(shape, alpha, beta) + grads_alpha, grads_beta = gradients_impl.gradients(sample, [alpha, beta]) + + alpha_val = np.ones([1, 2]) + beta_val = np.ones([2, 1]) + with self.test_session() as sess: + grads_alpha_val, grads_beta_val = sess.run( + [grads_alpha, grads_beta], + {alpha: alpha_val, beta: beta_val, shape: [2, 1]}) + self.assertAllEqual(grads_alpha_val.shape, alpha_val.shape) + self.assertAllEqual(grads_beta_val.shape, beta_val.shape) + + def _testCompareToExplicitDerivative(self, dtype): + """Compare to the explicit reparameterization derivative. + + Verifies that the computed derivative satisfies + dsample / dalpha = d igammainv(alpha, u) / dalpha, + where u = igamma(alpha, sample). + + Args: + dtype: TensorFlow dtype to perform the computations in. + """ + delta = 1e-3 + np_dtype = dtype.as_numpy_dtype + try: + from scipy import misc # pylint: disable=g-import-not-at-top + from scipy import special # pylint: disable=g-import-not-at-top + + alpha_val = np.logspace(-2, 3, dtype=np_dtype) + alpha = constant_op.constant(alpha_val) + sample = random_ops.random_gamma([], alpha, np_dtype(1.0), dtype=dtype) + actual = gradients_impl.gradients(sample, alpha)[0] + + (sample_val, actual_val) = self.evaluate((sample, actual)) + + u = special.gammainc(alpha_val, sample_val) + expected_val = misc.derivative( + lambda alpha_prime: special.gammaincinv(alpha_prime, u), + alpha_val, dx=delta * alpha_val) + + self.assertAllClose(actual_val, expected_val, rtol=1e-3, atol=1e-3) + except ImportError as e: + tf_logging.warn("Cannot use special functions in a test: %s" % str(e)) + + def testCompareToExplicitDerivativeFloat(self): + self._testCompareToExplicitDerivative(dtypes.float32) + + def testCompareToExplicitDerivativeDouble(self): + self._testCompareToExplicitDerivative(dtypes.float64) + + def _testCompareToImplicitDerivative(self, dtype): + """Compare to the implicit reparameterization derivative. + + Let's derive the formula we compare to. + + Start from the fact that CDF maps a random variable to the Uniform + random variable: + igamma(alpha, sample) = u, where u ~ Uniform(0, 1). + + Apply d / dalpha to both sides: + d igamma(alpha, sample) / dalpha + + d igamma(alpha, sample) / dsample * dsample/dalpha = 0 + d igamma(alpha, sample) / dalpha + + d igamma(alpha, sample) / dsample * dsample / dalpha = 0 + dsample/dalpha = - (d igamma(alpha, sample) / dalpha) + / d igamma(alpha, sample) / dsample + + This is the equation (8) of https://arxiv.org/abs/1805.08498 + + Args: + dtype: TensorFlow dtype to perform the computations in. + """ + np_dtype = dtype.as_numpy_dtype + alpha = constant_op.constant(np.logspace(-2, 3, dtype=np_dtype)) + sample = random_ops.random_gamma([], alpha, np_dtype(1.0), dtype=dtype) + actual = gradients_impl.gradients(sample, alpha)[0] + + sample_sg = array_ops.stop_gradient(sample) + cdf = math_ops.igamma(alpha, sample_sg) + dcdf_dalpha, dcdf_dsample = gradients_impl.gradients( + cdf, [alpha, sample_sg]) + # Numerically unstable due to division, do not try at home. + expected = -dcdf_dalpha / dcdf_dsample + + (actual_val, expected_val) = self.evaluate((actual, expected)) + + self.assertAllClose(actual_val, expected_val, rtol=1e-3, atol=1e-3) + + def testCompareToImplicitDerivativeFloat(self): + self._testCompareToImplicitDerivative(dtypes.float32) + + def testCompareToImplicitDerivativeDouble(self): + self._testCompareToImplicitDerivative(dtypes.float64) + + def testAverageAlphaGradient(self): + """Statistical test for the gradient. + + Using the equation (5) of https://arxiv.org/abs/1805.08498, we have + 1 = d/dalpha E_{sample ~ Gamma(alpha, 1)} sample + = E_{sample ~ Gamma(alpha, 1)} dsample/dalpha. + Here we verify that the rhs is fairly close to one. + The convergence speed is not great, so we use many samples and loose bounds. + """ + num_samples = 1000 + alpha = constant_op.constant([0.8, 1e1, 1e3], dtype=dtypes.float32) + sample = random_ops.random_gamma([num_samples], alpha) + # We need to average the gradients, which is equivalent to averaging the + # samples and then doing backprop. + mean_sample = math_ops.reduce_mean(sample, axis=0) + dsample_dalpha = gradients_impl.gradients(mean_sample, alpha)[0] + dsample_dalpha_val = self.evaluate(dsample_dalpha) + self.assertAllClose(dsample_dalpha_val, [1.0] * 3, atol=1e-1, rtol=1e-1) + + def testQuadraticLoss(self): + """Statistical test for the gradient. + + The equation (5) of https://arxiv.org/abs/1805.08498 says + d/dalpha E_{sample ~ Gamma(alpha, 1)} f(sample) + = E_{sample ~ Gamma(alpha, 1)} df(sample)/dalpha. + + Choose a quadratic loss function f(sample) = (sample - t)^2. + Then, the lhs can be computed analytically: + d/dalpha E_{sample ~ Gamma(alpha, 1)} f(sample) + = d/dalpha [ (alpha + alpha^2) - 2 * t * alpha + t^2 ] + = 1 + 2 * alpha - 2 * t. + + We compare the Monte-Carlo estimate of the expectation with the + true gradient. + """ + num_samples = 1000 + t = 0.3 + alpha = 0.5 + expected = 1 + 2 * alpha - 2 * t + + alpha = constant_op.constant(alpha) + sample = random_ops.random_gamma([num_samples], alpha, 1.0) + loss = math_ops.reduce_mean(math_ops.square(sample - t)) + dloss_dalpha = gradients_impl.gradients(loss, alpha)[0] + dloss_dalpha_val = self.evaluate(dloss_dalpha) + self.assertAllClose(expected, dloss_dalpha_val, atol=1e-1, rtol=1e-1) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/kernel_tests/random/random_ops_test.py b/tensorflow/python/kernel_tests/random/random_ops_test.py index e4b5c3832a2252aedc8820a650b022cd30b7f285..0ef6a95cfc994ecdfb734f133984fbad774d8691 100644 --- a/tensorflow/python/kernel_tests/random/random_ops_test.py +++ b/tensorflow/python/kernel_tests/random/random_ops_test.py @@ -24,13 +24,42 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test -class RandomNormalTest(test.TestCase): +class RandomOpTestCommon(test.TestCase): + + # Checks that executing the same rng_func multiple times rarely produces the + # same result. + def _testSingleSessionNotConstant(self, + rng_func, + num, + dtype, + min_or_mean, + max_or_stddev, + use_gpu, + op_seed=None, + graph_seed=None): + with self.test_session(use_gpu=use_gpu, graph=ops.Graph()) as sess: + if graph_seed is not None: + random_seed.set_random_seed(graph_seed) + x = rng_func([num], min_or_mean, max_or_stddev, dtype=dtype, seed=op_seed) + + y = sess.run(x) + z = sess.run(x) + w = sess.run(x) + + # We use exact equality here. If the random-number generator is producing + # the same output, all three outputs will be bitwise identical. + self.assertTrue((not np.array_equal(y, z)) or + (not np.array_equal(z, w)) or (not np.array_equal(y, w))) + + +class RandomNormalTest(RandomOpTestCommon): def _Sampler(self, num, mu, sigma, dtype, use_gpu, seed=None): @@ -90,6 +119,36 @@ class RandomNormalTest(test.TestCase): diff = rnd2 - rnd1 self.assertTrue(np.linalg.norm(diff.eval()) > 0.1) + def testSingleSessionNotConstant(self): + for use_gpu in [False, True]: + for dt in dtypes.float16, dtypes.float32, dtypes.float64: + self._testSingleSessionNotConstant( + random_ops.random_normal, 100, dt, 0.0, 1.0, use_gpu=use_gpu) + + def testSingleSessionOpSeedNotConstant(self): + for use_gpu in [False, True]: + for dt in dtypes.float16, dtypes.float32, dtypes.float64: + self._testSingleSessionNotConstant( + random_ops.random_normal, + 100, + dt, + 0.0, + 1.0, + use_gpu=use_gpu, + op_seed=1345) + + def testSingleSessionGraphSeedNotConstant(self): + for use_gpu in [False, True]: + for dt in dtypes.float16, dtypes.float32, dtypes.float64: + self._testSingleSessionNotConstant( + random_ops.random_normal, + 100, + dt, + 0.0, + 1.0, + use_gpu=use_gpu, + graph_seed=965) + class TruncatedNormalTest(test.TestCase): @@ -187,7 +246,7 @@ class TruncatedNormalTest(test.TestCase): self.assertAllEqual(rnd1, rnd2) -class RandomUniformTest(test.TestCase): +class RandomUniformTest(RandomOpTestCommon): def _Sampler(self, num, minv, maxv, dtype, use_gpu, seed=None): @@ -291,6 +350,39 @@ class RandomUniformTest(test.TestCase): diff = (rnd2 - rnd1).eval() self.assertTrue(np.linalg.norm(diff) > 0.1) + def testSingleSessionNotConstant(self): + for use_gpu in [False, True]: + for dt in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32, + dtypes.int64): + self._testSingleSessionNotConstant( + random_ops.random_uniform, 100, dt, 0, 17, use_gpu=use_gpu) + + def testSingleSessionOpSeedNotConstant(self): + for use_gpu in [False, True]: + for dt in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32, + dtypes.int64): + self._testSingleSessionNotConstant( + random_ops.random_uniform, + 100, + dt, + 10, + 20, + use_gpu=use_gpu, + op_seed=1345) + + def testSingleSessionGraphSeedNotConstant(self): + for use_gpu in [False, True]: + for dt in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32, + dtypes.int64): + self._testSingleSessionNotConstant( + random_ops.random_uniform, + 100, + dt, + 20, + 200, + use_gpu=use_gpu, + graph_seed=965) + class RandomShapeTest(test.TestCase): diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py index 7be473a5e750d9d6880f112cb0ca89b3ae61a7fd..8e06e1abfb52244e8c1a9b4ed15a270f6048e028 100644 --- a/tensorflow/python/kernel_tests/reader_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_ops_test.py @@ -25,8 +25,6 @@ import shutil import threading import zlib -import six - from tensorflow.core.protobuf import config_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -703,228 +701,6 @@ class TFRecordReaderTest(TFCompressionTestCase): self.assertAllEqual(self._Record(i, j), v) -class TFRecordWriterTest(TFCompressionTestCase): - - def setUp(self): - super(TFRecordWriterTest, self).setUp() - - def _AssertFilesEqual(self, a, b, equal): - for an, bn in zip(a, b): - with open(an, "rb") as af, open(bn, "rb") as bf: - if equal: - self.assertEqual(af.read(), bf.read()) - else: - self.assertNotEqual(af.read(), bf.read()) - - def testWriteReadZLibFiles(self): - # Write uncompressed then compress manually. - options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE) - files = self._CreateFiles(options, prefix="uncompressed") - zlib_files = [ - self._ZlibCompressFile(fn, "tfrecord_%s.z" % i) - for i, fn in enumerate(files) - ] - self._AssertFilesEqual(files, zlib_files, False) - - # Now write compressd and verify same. - options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) - compressed_files = self._CreateFiles(options, prefix="compressed") - self._AssertFilesEqual(compressed_files, zlib_files, True) - - # Decompress compress and verify same. - uncompressed_files = [ - self._ZlibDecompressFile(fn, "tfrecord_%s.z" % i) - for i, fn in enumerate(compressed_files) - ] - self._AssertFilesEqual(uncompressed_files, files, True) - - def testWriteReadGzipFiles(self): - # Write uncompressed then compress manually. - options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE) - files = self._CreateFiles(options, prefix="uncompressed") - gzip_files = [ - self._GzipCompressFile(fn, "tfrecord_%s.gz" % i) - for i, fn in enumerate(files) - ] - self._AssertFilesEqual(files, gzip_files, False) - - # Now write compressd and verify same. - options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) - compressed_files = self._CreateFiles(options, prefix="compressed") - - # Note: Gzips written by TFRecordWriter add 'tfrecord_0' so - # compressed_files can't be compared with gzip_files - - # Decompress compress and verify same. - uncompressed_files = [ - self._GzipDecompressFile(fn, "tfrecord_%s.gz" % i) - for i, fn in enumerate(compressed_files) - ] - self._AssertFilesEqual(uncompressed_files, files, True) - - -class TFRecordWriterZlibTest(TFCompressionTestCase): - - def testOneEpoch(self): - options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) - files = self._CreateFiles(options) - with self.test_session() as sess: - reader = io_ops.TFRecordReader(name="test_reader", options=options) - queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=()) - key, value = reader.read(queue) - - queue.enqueue_many([files]).run() - queue.close().run() - for i in range(self._num_files): - for j in range(self._num_records): - k, v = sess.run([key, value]) - self.assertTrue(compat.as_text(k).startswith("%s:" % files[i])) - self.assertAllEqual(self._Record(i, j), v) - - with self.assertRaisesOpError("is closed and has insufficient elements " - "\\(requested 1, current size 0\\)"): - k, v = sess.run([key, value]) - - def testZLibFlushRecord(self): - fn = self._WriteRecordsToFile([b"small record"], "small_record") - with open(fn, "rb") as h: - buff = h.read() - - # creating more blocks and trailing blocks shouldn't break reads - compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS) - - output = b"" - for c in buff: - if isinstance(c, int): - c = six.int2byte(c) - output += compressor.compress(c) - output += compressor.flush(zlib.Z_FULL_FLUSH) - - output += compressor.flush(zlib.Z_FULL_FLUSH) - output += compressor.flush(zlib.Z_FULL_FLUSH) - output += compressor.flush(zlib.Z_FINISH) - - # overwrite the original file with the compressed data - with open(fn, "wb") as h: - h.write(output) - - with self.test_session() as sess: - options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) - reader = io_ops.TFRecordReader(name="test_reader", options=options) - queue = data_flow_ops.FIFOQueue(1, [dtypes.string], shapes=()) - key, value = reader.read(queue) - queue.enqueue(fn).run() - queue.close().run() - k, v = sess.run([key, value]) - self.assertTrue(compat.as_text(k).startswith("%s:" % fn)) - self.assertAllEqual(b"small record", v) - - def testZlibReadWrite(self): - """Verify that files produced are zlib compatible.""" - original = [b"foo", b"bar"] - fn = self._WriteRecordsToFile(original, "zlib_read_write.tfrecord") - zfn = self._ZlibCompressFile(fn, "zlib_read_write.tfrecord.z") - - # read the compressed contents and verify. - actual = [] - for r in tf_record.tf_record_iterator( - zfn, options=tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)): - actual.append(r) - self.assertEqual(actual, original) - - def testZlibReadWriteLarge(self): - """Verify that writing large contents also works.""" - - # Make it large (about 5MB) - original = [_TEXT * 10240] - fn = self._WriteRecordsToFile(original, "zlib_read_write_large.tfrecord") - zfn = self._ZlibCompressFile(fn, "zlib_read_write_large.tfrecord.z") - - actual = [] - for r in tf_record.tf_record_iterator( - zfn, options=tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)): - actual.append(r) - self.assertEqual(actual, original) - - def testGzipReadWrite(self): - """Verify that files produced are gzip compatible.""" - original = [b"foo", b"bar"] - fn = self._WriteRecordsToFile(original, "gzip_read_write.tfrecord") - gzfn = self._GzipCompressFile(fn, "tfrecord.gz") - - actual = [] - for r in tf_record.tf_record_iterator( - gzfn, options=tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)): - actual.append(r) - self.assertEqual(actual, original) - - -class TFRecordIteratorTest(TFCompressionTestCase): - - def setUp(self): - super(TFRecordIteratorTest, self).setUp() - self._num_records = 7 - - def testIterator(self): - records = [self._Record(0, i) for i in range(self._num_records)] - options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) - fn = self._WriteRecordsToFile(records, "compressed_records", options) - - reader = tf_record.tf_record_iterator(fn, options) - for expected in records: - record = next(reader) - self.assertAllEqual(expected, record) - with self.assertRaises(StopIteration): - record = next(reader) - - def testWriteZlibRead(self): - """Verify compression with TFRecordWriter is zlib library compatible.""" - original = [b"foo", b"bar"] - options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) - fn = self._WriteRecordsToFile(original, "write_zlib_read.tfrecord.z", - options) - - zfn = self._ZlibDecompressFile(fn, "write_zlib_read.tfrecord") - actual = list(tf_record.tf_record_iterator(zfn)) - self.assertEqual(actual, original) - - def testWriteZlibReadLarge(self): - """Verify compression for large records is zlib library compatible.""" - # Make it large (about 5MB) - original = [_TEXT * 10240] - options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) - fn = self._WriteRecordsToFile(original, "write_zlib_read_large.tfrecord.z", - options) - zfn = self._ZlibDecompressFile(fn, "write_zlib_read_large.tfrecord") - actual = list(tf_record.tf_record_iterator(zfn)) - self.assertEqual(actual, original) - - def testWriteGzipRead(self): - original = [b"foo", b"bar"] - options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) - fn = self._WriteRecordsToFile(original, "write_gzip_read.tfrecord.gz", - options) - - gzfn = self._GzipDecompressFile(fn, "write_gzip_read.tfrecord") - actual = list(tf_record.tf_record_iterator(gzfn)) - self.assertEqual(actual, original) - - def testBadFile(self): - """Verify that tf_record_iterator throws an exception on bad TFRecords.""" - fn = os.path.join(self.get_temp_dir(), "bad_file") - with tf_record.TFRecordWriter(fn) as writer: - writer.write(b"123") - fn_truncated = os.path.join(self.get_temp_dir(), "bad_file_truncated") - with open(fn, "rb") as f: - with open(fn_truncated, "wb") as f2: - # DataLossError requires that we've written the header, so this must - # be at least 12 bytes. - f2.write(f.read(14)) - with self.assertRaises(errors_impl.DataLossError): - for _ in tf_record.tf_record_iterator(fn_truncated): - pass - - class AsyncReaderTest(test.TestCase): def testNoDeadlockFromQueue(self): diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 5267eabf0e4f779a96840069a858609059554c89..b1ef46f2a181f77640ed104b27983e48891c46c4 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -145,7 +145,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertIn("", str(handle)) self.assertIn("", repr(handle)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDtypeSurvivesIdentity(self): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) id_handle = array_ops.identity(handle) @@ -156,7 +156,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): v = resource_variable_ops.ResourceVariable(1.0) self.assertNotEqual(v.name, v.assign_add(1.0).name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCreateRead(self): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) self.evaluate(resource_variable_ops.assign_variable_op( @@ -165,7 +165,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)) self.assertAllEqual(1, value) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testManyAssigns(self): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) create = resource_variable_ops.assign_variable_op( @@ -183,7 +183,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertEqual(f, 1) self.assertEqual(s, 2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAssignAdd(self): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) self.evaluate(resource_variable_ops.assign_variable_op( @@ -194,7 +194,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)) self.assertEqual(read, 2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterAdd(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -207,7 +207,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[3]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterSub(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -220,7 +220,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[-1]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterMul(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -233,7 +233,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[5]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterDiv(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -246,7 +246,16 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[2]]) - @test_util.run_in_graph_and_eager_modes() + def testUseResource(self): + v = variables.Variable(1.0, use_resource=True) + self.assertTrue(isinstance(v, resource_variable_ops.ResourceVariable)) + + def testEagerNoUseResource(self): + with context.eager_mode(): + v = variables.Variable(1.0) + self.assertTrue(isinstance(v, resource_variable_ops.ResourceVariable)) + + @test_util.run_in_graph_and_eager_modes def testScatterMin(self): with ops.device("cpu:0"): handle = resource_variable_ops.var_handle_op( @@ -283,7 +292,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): meta_graph_two = saver.export_meta_graph(graph=graph) self.assertEqual(meta_graph_def, meta_graph_two) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterMax(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -296,7 +305,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[6]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterAddScalar(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -309,7 +318,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[3]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterSubScalar(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -322,7 +331,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[-1]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterMulScalar(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -335,7 +344,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[5]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterDivScalar(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -348,7 +357,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[2]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterMinScalar(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -361,7 +370,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[3]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterMaxScalar(self): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) @@ -426,7 +435,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): state_ops.scatter_update(ref, indices, updates) self.assertAllEqual(ref.read_value(), [True, True, True]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConstraintArg(self): constraint = lambda x: x v = resource_variable_ops.ResourceVariable( @@ -466,32 +475,32 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): with self.assertRaises(errors.OutOfRangeError): state_ops.count_up_to(v, 1) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitFnDtype(self): v = resource_variable_ops.ResourceVariable( initial_value=lambda: 1, dtype=dtypes.float32, name="var0") self.assertEqual(dtypes.float32, v.value().dtype) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitFnNoDtype(self): v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1, name="var2") self.assertEqual(dtypes.int32, v.value().dtype) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitializeAllVariables(self): v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.float32, name="var0") self.evaluate(variables.global_variables_initializer()) self.assertEqual(1.0, self.evaluate(v.value())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testOperatorOverload(self): v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) self.assertEqual(2.0, self.evaluate(v + v)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAssignMethod(self): v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) @@ -509,7 +518,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.evaluate(assign_without_read) self.assertEqual(4.0, self.evaluate(v.value())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLoad(self): v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) @@ -561,7 +570,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): variable_def=trainable_variable.to_proto()) .trainable) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSparseRead(self): with self.test_session(): init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4)) @@ -583,7 +592,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertEquals(v._handle, w._handle) self.assertEquals(v._graph_element, w._graph_element) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAssignAddMethod(self): v = resource_variable_ops.ResourceVariable(1.0, name="var0") self.evaluate(variables.global_variables_initializer()) @@ -601,7 +610,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.evaluate(assign_without_read) self.assertEqual(4.0, self.evaluate(v.value())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAssignSubMethod(self): v = resource_variable_ops.ResourceVariable(3.0, name="var0") self.evaluate(variables.global_variables_initializer()) @@ -619,7 +628,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.evaluate(assign_without_read) self.assertEqual(0.0, self.evaluate(v.value())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDestroyResource(self): v = resource_variable_ops.ResourceVariable(3.0, name="var0") self.evaluate(variables.global_variables_initializer()) @@ -708,7 +717,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype) self.assertEqual(300.0, self.evaluate(w_read)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testShape(self): v = resource_variable_ops.ResourceVariable( name="var4", initial_value=array_ops.ones(shape=[10, 20, 35])) @@ -826,6 +835,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): state_ops.scatter_add(v, [1], [3]) self.assertAllEqual([1.0, 5.0], v.numpy()) + def testScatterSubStateOps(self): + with context.eager_mode(): + v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="sub") + state_ops.scatter_sub(v, [1], [3]) + self.assertAllEqual([1.0, -1.0], v.numpy()) + def testScatterNdAddStateOps(self): with context.eager_mode(): v = resource_variable_ops.ResourceVariable( @@ -842,7 +857,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): state_ops.scatter_update(v, [1], [3]) self.assertAllEqual([1.0, 3.0], v.numpy()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScatterUpdateInvalidArgs(self): v = resource_variable_ops.ResourceVariable([0, 1, 2, 3], name="update") # The exact error and message differ between graph construction (where the @@ -852,5 +867,62 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): state_ops.scatter_update(v, [0, 1], [0, 1, 2]) +class _MixedPrecisionVariableTest(test_util.TensorFlowTestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_dense_var_to_tensor_read_dtype_same_as_var_dtype(self): + # read_dtype is same as dtype + v = resource_variable_ops.ResourceVariable(1.0, dtype=dtypes.float32) + v = resource_variable_ops._MixedPrecisionVariable(v, dtypes.float32) + if not context.executing_eagerly(): + v.initializer.run() + + # dtype is not read_dtype, return NotImplemented + self.assertEqual( + NotImplemented, v._dense_var_to_tensor(dtype=dtypes.float16)) + self.assertEqual(NotImplemented, + v._dense_var_to_tensor(dtype=dtypes.float16, as_ref=True)) + + # as_ref is False + t = v._dense_var_to_tensor(as_ref=False) + self.assertTrue(isinstance(t, ops.Tensor)) + self.assertEqual(t.dtype, dtypes.float32) + self.assertEqual(self.evaluate(t), 1.0) + + t = v._dense_var_to_tensor(dtype=dtypes.float32, as_ref=False) + self.assertTrue(isinstance(t, ops.Tensor)) + self.assertEqual(t.dtype, dtypes.float32) + self.assertEqual(self.evaluate(t), 1.0) + + # as_ref is True + self.assertEqual(NotImplemented, v._dense_var_to_tensor(as_ref=True)) + self.assertEqual(NotImplemented, + v._dense_var_to_tensor(dtype=dtypes.float32, as_ref=True)) + + @test_util.run_in_graph_and_eager_modes() + def test_dense_var_to_tensor_read_dtype_different_from_var_dtype(self): + # read_dtype is different from dtype + v = resource_variable_ops.ResourceVariable(1.0, dtype=dtypes.float32) + v = resource_variable_ops._MixedPrecisionVariable(v, dtypes.float16) + if not context.executing_eagerly(): + v.initializer.run() + + # as_ref is False + t = v._dense_var_to_tensor(as_ref=False) + self.assertTrue(isinstance(t, ops.Tensor)) + self.assertEqual(t.dtype, dtypes.float16) + self.assertEqual(self.evaluate(t), 1.0) + + t = v._dense_var_to_tensor(dtype=dtypes.float16, as_ref=False) + self.assertTrue(isinstance(t, ops.Tensor)) + self.assertEqual(t.dtype, dtypes.float16) + self.assertEqual(self.evaluate(t), 1.0) + + # as_ref is True + self.assertEqual(NotImplemented, v._dense_var_to_tensor(as_ref=True)) + self.assertEqual(NotImplemented, + v._dense_var_to_tensor(dtype=dtypes.float16, as_ref=True)) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index fe5ad84c104502f0e09d3a963b406f49d6b97b71..c72ada11dad5709f92aab03425bac35b8c52690d 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import time import timeit @@ -26,6 +27,7 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib import rnn as contrib_rnn from tensorflow.core.protobuf import config_pb2 +from tensorflow.python import keras from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.framework import constant_op @@ -46,6 +48,7 @@ import tensorflow.python.ops.nn_grad # pylint: disable=unused-import import tensorflow.python.ops.sparse_grad # pylint: disable=unused-import import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import from tensorflow.python.platform import test +from tensorflow.python.training import saver class Plus1RNNCell(rnn_cell_impl.RNNCell): @@ -81,6 +84,25 @@ class ScalarStateRNNCell(rnn_cell_impl.RNNCell): return (input_, state + 1) +class UnbalancedOutputRNNCell(rnn_cell_impl.RNNCell): + """RNN Cell generating (output, new_state) = (input + 1, state + 1).""" + + @property + def output_size(self): + return tensor_shape.TensorShape(1), tensor_shape.TensorShape((2)) + + @property + def state_size(self): + return tensor_shape.TensorShape([]) + + def zero_state(self, batch_size, dtype): + return array_ops.zeros([], dtype=dtypes.int32) + + def call(self, input_, state, scope=None): + concatenated = array_ops.concat((input_, input_), axis=-1) + return (input_, concatenated), state + 1 + + class TensorArrayStateRNNCell(rnn_cell_impl.RNNCell): """RNN Cell its state as a TensorArray.""" @@ -108,7 +130,7 @@ class RNNTest(test.TestCase): self._seed = 23489 np.random.seed(self._seed) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInvalidSequenceLengthShape(self): cell = Plus1RNNCell() if context.executing_eagerly(): @@ -122,7 +144,7 @@ class RNNTest(test.TestCase): dtype=dtypes.float32, sequence_length=[[4]]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBatchSizeFromInput(self): cell = Plus1RNNCell() in_eager_mode = context.executing_eagerly() @@ -162,7 +184,7 @@ class RNNTest(test.TestCase): self.assertEqual(None, outputs.shape[0].value) self.assertEqual(None, state.shape[0].value) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScalarStateIsAccepted(self): cell = ScalarStateRNNCell() in_eager_mode = context.executing_eagerly() @@ -182,7 +204,29 @@ class RNNTest(test.TestCase): self.assertAllEqual([[[1], [2], [3], [4]]], outputs) self.assertAllEqual(4, state) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes + def testUnbalancedOutputIsAccepted(self): + cell = UnbalancedOutputRNNCell() + in_eager_mode = context.executing_eagerly() + + if in_eager_mode: + inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32) + else: + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) + + with self.test_session() as sess: + outputs, state = rnn.dynamic_rnn( + cell, inputs, dtype=dtypes.float32, sequence_length=[4]) + if not in_eager_mode: + outputs, state = sess.run( + [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]}) + + self.assertIsInstance(outputs, tuple) + self.assertAllEqual([[[1], [2], [3], [4]]], outputs[0]) + self.assertAllEqual([[[1, 1], [2, 2], [3, 3], [4, 4]]], outputs[1]) + self.assertAllEqual(4, state) + + @test_util.run_in_graph_and_eager_modes def testTensorArrayStateIsAccepted(self): cell = TensorArrayStateRNNCell() in_eager_mode = context.executing_eagerly() @@ -215,7 +259,7 @@ class RNNTest(test.TestCase): cell_output, _ = cell(array_ops.zeros(in_shape, dtype), state_output) self.assertAllEqual([batch_size, out_size], cell_output.shape.as_list()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCellsBuild(self): f32 = dtypes.float32 f64 = dtypes.float64 @@ -227,7 +271,71 @@ class RNNTest(test.TestCase): self._assert_cell_builds(rnn_cell_impl.GRUCell, f64, 5, 7, 3) self._assert_cell_builds(rnn_cell_impl.LSTMCell, f32, 5, 7, 3) self._assert_cell_builds(rnn_cell_impl.LSTMCell, f64, 5, 7, 3) - + self._assert_cell_builds(contrib_rnn.IndRNNCell, f32, 5, 7, 3) + self._assert_cell_builds(contrib_rnn.IndRNNCell, f64, 5, 7, 3) + self._assert_cell_builds(contrib_rnn.IndyGRUCell, f32, 5, 7, 3) + self._assert_cell_builds(contrib_rnn.IndyGRUCell, f64, 5, 7, 3) + self._assert_cell_builds(contrib_rnn.IndyLSTMCell, f32, 5, 7, 3) + self._assert_cell_builds(contrib_rnn.IndyLSTMCell, f64, 5, 7, 3) + + def testBasicLSTMCellInterchangeWithLSTMCell(self): + with self.test_session(graph=ops_lib.Graph()) as sess: + basic_cell = rnn_cell_impl.BasicLSTMCell(1) + basic_cell(array_ops.ones([1, 1]), + state=basic_cell.zero_state(batch_size=1, + dtype=dtypes.float32)) + self.evaluate([v.initializer for v in basic_cell.variables]) + self.evaluate(basic_cell._bias.assign([10.] * 4)) + save = saver.Saver() + prefix = os.path.join(self.get_temp_dir(), "ckpt") + save_path = save.save(sess, prefix) + + with self.test_session(graph=ops_lib.Graph()) as sess: + lstm_cell = rnn_cell_impl.LSTMCell(1, name="basic_lstm_cell") + lstm_cell(array_ops.ones([1, 1]), + state=lstm_cell.zero_state(batch_size=1, + dtype=dtypes.float32)) + self.evaluate([v.initializer for v in lstm_cell.variables]) + save = saver.Saver() + save.restore(sess, save_path) + self.assertAllEqual([10.] * 4, self.evaluate(lstm_cell._bias)) + + def testRNNCellSerialization(self): + for cell in [ + rnn_cell_impl.LSTMCell(32, use_peepholes=True, cell_clip=True), + rnn_cell_impl.BasicLSTMCell(32, dtype=dtypes.float32), + rnn_cell_impl.BasicRNNCell(32, activation="relu", dtype=dtypes.float32), + rnn_cell_impl.GRUCell( + 32, kernel_initializer="ones", dtype=dtypes.float32) + ]: + with self.test_session(): + x = keras.Input((None, 5)) + layer = keras.layers.RNN(cell) + y = layer(x) + model = keras.models.Model(x, y) + model.compile(optimizer="rmsprop", loss="mse") + + # Test basic case serialization. + x_np = np.random.random((6, 5, 5)) + y_np = model.predict(x_np) + weights = model.get_weights() + config = layer.get_config() + # The custom_objects is important here since rnn_cell_impl is + # not visible as a Keras layer, and also has a name conflict with + # keras.LSTMCell and GRUCell. + layer = keras.layers.RNN.from_config( + config, + custom_objects={ + "BasicRNNCell": rnn_cell_impl.BasicRNNCell, + "GRUCell": rnn_cell_impl.GRUCell, + "LSTMCell": rnn_cell_impl.LSTMCell, + "BasicLSTMCell": rnn_cell_impl.BasicLSTMCell + }) + y = layer(x) + model = keras.models.Model(x, y) + model.set_weights(weights) + y_np_2 = model.predict(x_np) + self.assertAllClose(y_np, y_np_2, atol=1e-4) ######### Benchmarking RNN code diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py index faa4b49a8d7d8b0169f10592845d3d30a3996c41..f2f30234696be7f6c8c98d041bc415ccf5cb4ecf 100644 --- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py @@ -268,12 +268,12 @@ class StatefulScatterNdTest(test.TestCase): # Test some out of range errors. indices = np.array([[-1], [0], [5]]) with self.assertRaisesOpError( - r"Invalid indices: \[0,0\] = \[-1\] does not index into \[6\]"): + r"indices\[0\] = \[-1\] does not index into shape \[6\]"): op(ref, indices, updates).eval() indices = np.array([[2], [0], [6]]) with self.assertRaisesOpError( - r"Invalid indices: \[2,0\] = \[6\] does not index into \[6\]"): + r"indices\[2\] = \[6\] does not index into shape \[6\]"): op(ref, indices, updates).eval() def testRank3ValidShape(self): @@ -369,7 +369,30 @@ class ScatterNdTest(test.TestCase): del input_ # input_ is not used in scatter_nd return array_ops.scatter_nd(indices, updates, shape) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes + def testBool(self): + indices = constant_op.constant( + [[4], [3], [1], [7]], dtype=dtypes.int32) + updates = constant_op.constant( + [False, True, False, True], dtype=dtypes.bool) + expected = np.array( + [False, False, False, True, False, False, False, True]) + scatter = self.scatter_nd(indices, updates, shape=(8,)) + result = self.evaluate(scatter) + self.assertAllEqual(expected, result) + + # Same indice is updated twice by same value. + indices = constant_op.constant( + [[4], [3], [3], [7]], dtype=dtypes.int32) + updates = constant_op.constant( + [False, True, True, True], dtype=dtypes.bool) + expected = np.array([ + False, False, False, True, False, False, False, True]) + scatter = self.scatter_nd(indices, updates, shape=(8,)) + result = self.evaluate(scatter) + self.assertAllEqual(expected, result) + + @test_util.run_in_graph_and_eager_modes def testInvalidShape(self): # TODO(apassos) figure out how to unify these errors with self.assertRaises(errors.InvalidArgumentError diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py index 7368251ab69574cc6cba703e605f108c6ab45649..34e34d9d1b2034d8679844f051358f020a44587a 100644 --- a/tensorflow/python/kernel_tests/shape_ops_test.py +++ b/tensorflow/python/kernel_tests/shape_ops_test.py @@ -642,6 +642,29 @@ class TileTest(test.TestCase): err = gradient_checker.compute_gradient_error(a, [4, 2], tiled, [4, 4]) self.assertLess(err, 1e-3) + def testGradientWithSparseGradWithRank1(self): + inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0], + dtype=dtypes.float32) + outputs = array_ops.gather(array_ops.tile(inputs, [3]), + [1, 5, 9, 3, 7, 2, 2, 2]) + with self.test_session(): + error = gradient_checker.compute_gradient_error( + inputs, inputs.get_shape().as_list(), + outputs, outputs.get_shape().as_list()) + self.assertLess(error, 1e-4) + + def testGradientWithSparseGradWithRank3(self): + inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0], + dtype=dtypes.float32) + inputs = array_ops.reshape(inputs, [-1, 1, 1]) + outputs = array_ops.gather(array_ops.tile(inputs, [3, 4, 2]), + [1, 5, 9, 3, 7, 2, 2, 2]) + with self.test_session(): + error = gradient_checker.compute_gradient_error( + inputs, inputs.get_shape().as_list(), + outputs, outputs.get_shape().as_list()) + self.assertLess(error, 1e-4) + def testShapeFunctionEdgeCases(self): # Unknown multiples shape. inp = constant_op.constant(0.0, shape=[4, 4, 4, 4]) diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py index 5fc9bef21816e3a12f0d274bab1fc82a83546422..4a1fc1d9a9b10a2738a508c7440bb63a32d4e89c 100644 --- a/tensorflow/python/kernel_tests/slice_op_test.py +++ b/tensorflow/python/kernel_tests/slice_op_test.py @@ -225,7 +225,7 @@ class SliceTest(test.TestCase): self.assertAllEqual(m1.get_shape().as_list(), [1, 2, 3]) m2 = array_ops.slice(z, [0, 0, 0], [constant_op.constant(1) + 0, 2, -1]) - self.assertAllEqual(m2.get_shape().as_list(), [None, 2, None]) + self.assertAllEqual(m2.get_shape().as_list(), [1, 2, 3]) def _testGradientSlice(self, input_shape, slice_begin, slice_size): @@ -283,7 +283,7 @@ class SliceTest(test.TestCase): # unintended behavior is prevented. c = constant_op.constant(5.0) with self.assertRaisesWithPredicateMatch( - TypeError, lambda e: "Tensor objects are not iterable" in str(e)): + TypeError, lambda e: "Tensor objects are only iterable" in str(e)): for _ in c: pass diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py index 427c07cfb8e47d65ac013dec2ecc0753ce4f5c05..fbf1adba9b02614720f830c904ce43558339158d 100644 --- a/tensorflow/python/kernel_tests/softmax_op_test.py +++ b/tensorflow/python/kernel_tests/softmax_op_test.py @@ -22,6 +22,7 @@ import unittest import numpy as np +from tensorflow.python.compat import compat from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.ops import array_ops @@ -156,11 +157,17 @@ class SoftmaxTest(test.TestCase): np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64)) self._testOverflow() - def test1DTesnorAsInput(self): + def test1DTensorAsInput(self): self._testSoftmax( np.array([3., 2., 3., 9.]).astype(np.float64), use_gpu=False) self._testOverflow(use_gpu=False) + def test1DTensorAsInputNoReshape(self): + with compat.forward_compatibility_horizon(2018, 8, 27): + self._testSoftmax( + np.array([3., 2., 3., 9.]).astype(np.float64), use_gpu=False) + self._testOverflow(use_gpu=False) + def test3DTensorAsInput(self): self._testSoftmax( np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]], @@ -169,6 +176,15 @@ class SoftmaxTest(test.TestCase): use_gpu=False) self._testOverflow(use_gpu=False) + def test3DTensorAsInputNoReshape(self): + with compat.forward_compatibility_horizon(2018, 8, 27): + self._testSoftmax( + np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]], + [[2., 3., 4., 5.], [6., 7., 8., 9.]], + [[5., 4., 3., 2.], [1., 2., 3., 4.]]]).astype(np.float32), + use_gpu=False) + self._testOverflow(use_gpu=False) + def testAlongFirstDimension(self): self._testSoftmax( np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]], diff --git a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py index 27b39a626fcc6b2705bf9e797b5293ed3f1c7820..3847cebc7dcabd66c26a4e4551e5856c6a927a33 100644 --- a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py +++ b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py @@ -300,6 +300,51 @@ class SerializeSparseTest(test.TestCase): sparse_ops.serialize_many_sparse, sparse_ops.deserialize_sparse, dtypes.variant) + def testVariantSerializeDeserializeScalar(self): + with self.test_session(use_gpu=False) as sess: + indices_value = np.array([[]], dtype=np.int64) + values_value = np.array([37], dtype=np.int32) + shape_value = np.array([], dtype=np.int64) + sparse_tensor = self._SparseTensorPlaceholder() + serialized = sparse_ops.serialize_sparse( + sparse_tensor, out_type=dtypes.variant) + deserialized = sparse_ops.deserialize_sparse( + serialized, dtype=dtypes.int32) + deserialized_value = sess.run( + deserialized, + feed_dict={ + sparse_tensor.indices: indices_value, + sparse_tensor.values: values_value, + sparse_tensor.dense_shape: shape_value + }) + self.assertAllEqual(deserialized_value.indices, indices_value) + self.assertAllEqual(deserialized_value.values, values_value) + self.assertAllEqual(deserialized_value.dense_shape, shape_value) + + def testVariantSerializeDeserializeScalarBatch(self): + with self.test_session(use_gpu=False) as sess: + indices_value = np.array([[]], dtype=np.int64) + values_value = np.array([37], dtype=np.int32) + shape_value = np.array([], dtype=np.int64) + sparse_tensor = self._SparseTensorPlaceholder() + serialized = sparse_ops.serialize_sparse( + sparse_tensor, out_type=dtypes.variant) + stacked = array_ops.stack([serialized, serialized]) + deserialized = sparse_ops.deserialize_sparse(stacked, dtype=dtypes.int32) + deserialized_value = sess.run( + deserialized, + feed_dict={ + sparse_tensor.indices: indices_value, + sparse_tensor.values: values_value, + sparse_tensor.dense_shape: shape_value + }) + self.assertAllEqual(deserialized_value.indices, + np.array([[0], [1]], dtype=np.int64)) + self.assertAllEqual(deserialized_value.values, + np.array([37, 37], dtype=np.int32)) + self.assertAllEqual(deserialized_value.dense_shape, + np.array([2], dtype=np.int64)) + def _testDeserializeFailsWrongTypeHelper(self, serialize_fn, deserialize_fn, diff --git a/tensorflow/python/kernel_tests/sparse_slice_op_test.py b/tensorflow/python/kernel_tests/sparse_slice_op_test.py index da116601f833cc6b471e383e030c5fbe93b52ac5..97f30daf4a9c9615e1b42a1ba94e693e166bbc1c 100644 --- a/tensorflow/python/kernel_tests/sparse_slice_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_slice_op_test.py @@ -21,13 +21,15 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import sparse_ops +import tensorflow.python.ops.sparse_grad # pylint: disable=unused-import from tensorflow.python.platform import test class SparseSliceOpTest(test.TestCase): - def _SparseTensor_4x6(self): + def _SparseTensor_4x6(self, val_dtype=np.int64): # [0 | |2 | |4 |5 ] # [ |11| |13|14| ] # [20| | |23| |25] @@ -37,7 +39,7 @@ class SparseSliceOpTest(test.TestCase): [2, 3], [2, 5], [3, 0], [3, 2], [3, 3], [3, 5]]).astype( np.int64) val = np.array([0, 2, 4, 5, 11, 13, 14, 20, 23, 25, 30, 32, 33, 35]).astype( - np.int64) + val_dtype) shape = np.array([4, 6]).astype(np.int64) return sparse_tensor.SparseTensor(ind, val, shape) @@ -244,6 +246,22 @@ class SparseSliceOpTest(test.TestCase): self.assertAllEqual(sparse_tensor5.values.eval(), [5, 25, 35]) self.assertAllEqual(sparse_tensor5.dense_shape.eval(), [4, 1]) + def testGradients(self): + sp_input = self._SparseTensor_4x6(val_dtype=np.float32) + start_and_size = [([0, 0], [4, 2]), + ([0, 2], [5, 2]), + ([0, 4], [5, 3])] + + with self.test_session(use_gpu=False): + for start, size in start_and_size: + sp_output = sparse_ops.sparse_slice(sp_input, start, size) + nnz_in = len(sp_input.values.eval()) + nnz_out = len(sp_output.values.eval()) + + err = gradient_checker.compute_gradient_error( + [sp_input.values], [(nnz_in,)], sp_output.values, (nnz_out,)) + self.assertLess(err, 1e-3) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/kernel_tests/split_op_test.py b/tensorflow/python/kernel_tests/split_op_test.py index 8cfee3eb933afcea7a58d5632948b87b0c4c10df..3f9b029a6ac777fc97c65ecf3d70ac879bb5d116 100644 --- a/tensorflow/python/kernel_tests/split_op_test.py +++ b/tensorflow/python/kernel_tests/split_op_test.py @@ -95,7 +95,7 @@ class SplitOpTest(test.TestCase): sess.run(array_ops.split(value, size_splits), {size_splits: [2, 2, 6]}) self.assertTrue("Cannot infer num from shape" in str(context.exception)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testExplicitNum(self): size_splits = array_ops.constant([2, 2, 6], dtype=dtypes.int32) value = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] @@ -109,7 +109,7 @@ class SplitOpTest(test.TestCase): self.assertAllEqual(r[1], value[2:4]) self.assertAllEqual(r[2], value[4:]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testListOfScalarTensors(self): a = math_ops.to_int32(5) b = math_ops.to_int32(6) @@ -168,12 +168,32 @@ class SplitOpTest(test.TestCase): offset += size_splits[i] self.assertAllEqual(result[i], inp[slices]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSpecialCasesVariable(self): self._testSpecialCasesVariable() for dtype in _TEST_DTYPES: self._testHugeNumberOfTensorsVariable(dtype) + @test_util.run_in_graph_and_eager_modes + def testDegenerateVariable(self): + inp = np.random.rand(4, 4).astype("f") + with test_util.device(use_gpu=True): + result = self.evaluate(array_ops.split(inp, [-1, 4], 0)) + self.assertAllEqual(result[0], inp[0:0, :]) + self.assertAllEqual(result[1], inp[0:4, :]) + + result = self.evaluate(array_ops.split(inp, [4, -1], 0)) + self.assertAllEqual(result[0], inp[0:4, :]) + self.assertAllEqual(result[1], inp[4:4, :]) + + result = self.evaluate(array_ops.split(inp, [-1, 4], 1)) + self.assertAllEqual(result[0], inp[:, 0:0]) + self.assertAllEqual(result[1], inp[:, 0:4]) + + result = self.evaluate(array_ops.split(inp, [4, -1], 1)) + self.assertAllEqual(result[0], inp[:, 0:4]) + self.assertAllEqual(result[1], inp[:, 4:4]) + def _testGradientsSimpleVariable(self, dtype): inp = self._makeData((4, 4), dtype) with test_util.device(use_gpu=True): @@ -210,13 +230,13 @@ class SplitOpTest(test.TestCase): self.assertAllEqual(np_ans[i], out[i]) self.assertShapeEqual(np_ans[i], tf_ans[i]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSplitRows(self): for dtype in _TEST_DTYPES: inp = self._makeData((4, 4), dtype) self._compare(inp, 0, 4) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSplitCols(self): for dtype in _TEST_DTYPES: inp = self._makeData((4, 4), dtype) @@ -232,7 +252,7 @@ class SplitOpTest(test.TestCase): self.assertEqual(out[i].shape, expected_shape) self.assertEqual(expected_shape, tf_ans[i].get_shape()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEmpty(self): # Note: np.split returns a rank-0 empty ndarray # if the input ndarray is empty. @@ -244,7 +264,7 @@ class SplitOpTest(test.TestCase): self._testEmpty(inp, 2, 3, (8, 0, 7)) self._testEmpty(inp, 2, 7, (8, 0, 3)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testIdentity(self): for dtype in _TEST_DTYPES: inp = self._makeData((2, 2, 2), dtype) @@ -252,7 +272,7 @@ class SplitOpTest(test.TestCase): self._compare(inp, 1, 1) self._compare(inp, 2, 1) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSplitDim0(self): for dtype in _TEST_DTYPES: self._compare(self._makeData((6, 10, 18), dtype), 0, 3) @@ -281,7 +301,7 @@ class SplitOpTest(test.TestCase): offset += length self.assertAllEqual(result[i], inp[slices]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRandom(self): for dtype in _TEST_DTYPES: for _ in range(5): @@ -336,6 +356,16 @@ class SplitOpTest(test.TestCase): for s in splits: self.assertEqual(None, s.get_shape().ndims) + def testVariableShapeFunction(self): + # size_splits too big + with self.assertRaises(ValueError): + array_ops.split([0, 1], [3, -1], axis=0) + + # Correct inference of variable dimension + s0, s1 = array_ops.split([0, 1, 2], [2, -1], axis=0) + assert s0.shape.as_list() == [2] + assert s1.shape.as_list() == [1] + def testNonexistentDimTensor(self): x = array_ops.placeholder(dtypes.int32) values = np.zeros([5, 30]) diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py index 1b935d5286729e9e802c56e90e2ae7ab72a6e080..0b3a396d6bf46fb46416662a9443ed7b5811e15c 100644 --- a/tensorflow/python/kernel_tests/template_test.py +++ b/tensorflow/python/kernel_tests/template_test.py @@ -150,7 +150,7 @@ class TemplateTest(test.TestCase): # Parameters are tied, so the loss should have gone down after training. self.assertLess(final_test_loss.numpy(), initial_test_loss.numpy()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_skip_stack_frames(self): first = traceback.format_stack() second = traceback.format_stack() @@ -158,7 +158,7 @@ class TemplateTest(test.TestCase): self.assertEqual(1, len(result)) self.assertNotEqual(len(first), len(result)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_template_with_name(self): tmpl1 = template.make_template("s1", variable_scoped_function) tmpl2 = template.make_template("s1", variable_scoped_function) @@ -204,7 +204,7 @@ class TemplateTest(test.TestCase): self.assertEqual(v1, v3) self.assertEqual("s1/dummy:0", v1.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_template_in_scope(self): tmpl1 = template.make_template("s1", variable_scoped_function) tmpl2 = template.make_template("s1", variable_scoped_function) @@ -221,7 +221,7 @@ class TemplateTest(test.TestCase): self.assertEqual("scope/s1/dummy:0", v1.name) self.assertEqual("scope/s1_1/dummy:0", v3.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_template_with_internal_reuse(self): tmpl1 = template.make_template("s1", internally_variable_scoped_function) tmpl2 = template.make_template("s1", internally_variable_scoped_function) @@ -237,13 +237,13 @@ class TemplateTest(test.TestCase): with self.assertRaises(ValueError): tmpl1("not_test") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_template_without_name(self): with self.assertRaisesRegexp( ValueError, "name cannot be None."): template.make_template(None, variable_scoped_function) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_make_template(self): # Test both that we can call it with positional and keywords. tmpl1 = template.make_template( @@ -266,7 +266,7 @@ class TemplateTest(test.TestCase): with self.assertRaises(ValueError): tmpl() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_enforces_no_extra_trainable_variables_eager(self): tmpl = template.make_template("s", function_with_side_create, @@ -287,7 +287,7 @@ class TemplateTest(test.TestCase): trainable=False) self.assertEqual(tmpl(name="1"), tmpl(name="2")) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_internal_variable_reuse(self): def nested(): @@ -310,7 +310,7 @@ class TemplateTest(test.TestCase): self.assertEqual("s1/nested/x:0", v1.name) self.assertEqual("s1_1/nested/x:0", v3.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_nested_templates(self): def nested_template(): @@ -360,7 +360,7 @@ class TemplateTest(test.TestCase): self.assertEqual("nested", tmpl1._checkpoint_dependencies[0].name) self.assertEqual("nested_1", tmpl1._checkpoint_dependencies[1].name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_nested_templates_with_defun(self): def variable_scoped_function_no_return_value(trainable=True): @@ -429,7 +429,7 @@ class TemplateTest(test.TestCase): "a", partial, create_graph_function_=True) self.assertAllEqual(tmpl(ops.convert_to_tensor(1.0)), 2.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_immediate_scope_creation(self): # Create templates in scope a then call in scope b. make_template should # capture the scope the first time it is called, and make_immediate_template @@ -454,7 +454,7 @@ class TemplateTest(test.TestCase): self.assertEqual("ctor_scope/a/dummy:0", inner_imm_var.name) self.assertEqual("call_scope/b/dummy:0", inner_defer_var.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_scope_access(self): # Ensure that we can access the scope inside the template, because the name # of that scope may be different from the name we pass to make_template, due @@ -479,7 +479,7 @@ class TemplateTest(test.TestCase): # Template is called at the top level, so there is no preceding "foo_2". self.assertEqual(tc.variable_scope.name, "blah") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_custom_getter(self): # Custom getter that maintains call count and forwards to true getter custom_getter_count = [0] @@ -512,7 +512,7 @@ class TemplateTest(test.TestCase): tmpl2() self.assertEqual(custom_getter_count[0], 2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_fails_gracefully(self): for create_scope_now in [True, False]: def module_function_with_one_arg(inputs): @@ -535,7 +535,7 @@ class TemplateTest(test.TestCase): templatized_function(data) self.assertTrue(templatized_function._variables_created) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_name_scopes_for_variable_scopes(self): # Test that name scopes are not unnecessarily uniquified (but are # still uniquified when necessary). @@ -586,7 +586,7 @@ class TemplateTest(test.TestCase): "Second application of template should also get " "a freshly uniquified name scope.") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_global_variables(self): # Make sure global_variables are created. with variable_scope.variable_scope("foo"): @@ -608,7 +608,7 @@ class TemplateTest(test.TestCase): self.assertEqual(1, len(ta.global_variables)) self.assertEqual(2, len(tb.global_variables)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_trainable_variables(self): # Make sure trainable_variables are created. with variable_scope.variable_scope("foo2"): @@ -632,7 +632,7 @@ class TemplateTest(test.TestCase): self.assertEqual(1, len(ta.variables)) self.assertEqual(1, len(tb.variables)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_non_trainable_variables(self): # Make sure non_trainable_variables are created. with variable_scope.variable_scope("foo2"): @@ -675,7 +675,7 @@ class TemplateTest(test.TestCase): self.assertEqual(0, len(ta.local_variables)) self.assertEqual(1, len(tb.local_variables)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_make_template_with_defun(self): def variable_scoped_function_no_return_value(scope_name): diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py index ea06357804f45bbba3a9e7e847659d47bf52bffb..6de6fbe7679fa8e95d3032b04fb81b43ac3a60d9 100644 --- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py +++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py @@ -75,7 +75,7 @@ class TensorArrayTest(test.TestCase): super(TensorArrayTest, cls).tearDownClass() session_lib.Session.reset(cls._workers[0].target) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayWriteRead(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -123,11 +123,11 @@ class TensorArrayTest(test.TestCase): self._testTensorArrayWritePack(dtypes.complex128) self._testTensorArrayWritePack(dtypes.string) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayWritePack(self): self._testTensorArrayWritePackMaybeLegacy() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEmptyTensorArrayPack(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -161,7 +161,7 @@ class TensorArrayTest(test.TestCase): convert([[4.0, 5.0], [104.0, 105.0], [204.0, 205.0], [6.0, 7.0], [106.0, 107.0], [8.0, 9.0]]), c0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayWriteConcat(self): self._testTensorArrayWriteConcat(dtypes.float32) self._testTensorArrayWriteConcat(dtypes.float64) @@ -184,7 +184,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual([[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]], self.evaluate(ta.write(1, [[4.0, 5.0]]).concat())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros(self): self._testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros() @@ -200,7 +200,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual([[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]], self.evaluate(ta.write(1, [[4.0, 5.0]]).concat())) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros(self): self._testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros() @@ -251,7 +251,7 @@ class TensorArrayTest(test.TestCase): self._testTensorArrayUnpackRead(dtypes.complex128) self._testTensorArrayUnpackRead(dtypes.string) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayUnpackRead(self): self._testTensorArrayUnpackReadMaybeLegacy() @@ -297,7 +297,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual(convert([]).reshape(0, 2), d1) self.assertAllEqual(convert([[3.0, 301.0]]), d2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArraySplitRead(self): self._testTensorArraySplitRead(dtypes.float32) self._testTensorArraySplitRead(dtypes.float64) @@ -397,7 +397,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual(t_g_ta_0, t_g_ta_1) self.assertAllEqual([[4.0, 5.0]], d_r1_0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayWriteWrongIndexOrDataTypeFails(self): with self.test_session(use_gpu=True): ta = _make_ta(3, "foo", dtype=dtypes.float32) @@ -416,7 +416,7 @@ class TensorArrayTest(test.TestCase): "resizeable and size is: 3"): self.evaluate(ta.write(3, 3.0).flow) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayReadWrongIndexOrDataTypeFails(self): with self.test_session(use_gpu=True): ta = _make_ta(3, "foo", dtype=dtypes.float32) @@ -450,7 +450,7 @@ class TensorArrayTest(test.TestCase): "it has already been written to."): self.evaluate(ta.write(2, 3.0).write(2, 3.0).flow) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayConcatIncompatibleShapesFails(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -482,7 +482,7 @@ class TensorArrayTest(test.TestCase): with self.assertRaisesOpError("shape"): self.evaluate(w3.concat()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArraySplitIncompatibleShapesFails(self): with self.test_session(use_gpu=True): in_eager_mode = context.executing_eagerly() @@ -603,7 +603,7 @@ class TensorArrayTest(test.TestCase): self.assertAllClose(fed_value, sess.run(read_value, feed_dict={value: fed_value})) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMultiTensorArray(self): with self.test_session(use_gpu=True): h1 = tensor_array_ops.TensorArray( @@ -706,7 +706,7 @@ class TensorArrayTest(test.TestCase): def testTensorArrayGradientWritePackConcatAndRead(self): self._testTensorArrayGradientWritePackConcatAndRead() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayReadTwice(self): with self.test_session(use_gpu=True): value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) @@ -811,14 +811,14 @@ class TensorArrayTest(test.TestCase): def testTensorArrayGradientDynamicUnpackRead(self): self._testTensorArrayGradientDynamicUnpackRead() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCloseTensorArray(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) self.evaluate(ta.close()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSizeTensorArray(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -826,7 +826,7 @@ class TensorArrayTest(test.TestCase): s = ta.size() self.assertAllEqual(3, self.evaluate(s)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testWriteCloseTensorArray(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -924,7 +924,7 @@ class TensorArrayTest(test.TestCase): self.assertAllClose(grad_val.sum(axis=0), var_grad_t) self.assertAllClose(grad_val.sum(axis=0), state0_grad_t) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testWhileLoopWritePackGradients(self): self._testWhileLoopWritePackGradients( dynamic_size=False, dtype=dtypes.float32) @@ -936,7 +936,7 @@ class TensorArrayTest(test.TestCase): self._testWhileLoopWritePackGradients( dynamic_size=True, dtype=dtypes.float32) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradSerialTwoLoops(self): with self.test_session(use_gpu=True): def loop(x): @@ -1113,7 +1113,7 @@ class TensorArrayTest(test.TestCase): r5 = w5.read(0) self.assertAllEqual([5, 4, 2, 3], r5.get_shape().as_list()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def _testUnpackShape(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -1147,7 +1147,7 @@ class TensorArrayTest(test.TestCase): def testUnpackShape(self): self._testUnpackShape() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSplitShape(self): with self.test_session(use_gpu=True): ta = tensor_array_ops.TensorArray( @@ -1289,7 +1289,7 @@ class TensorArrayTest(test.TestCase): self.assertAllEqual([10.0, -10.0], read_vals[1]) self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayWriteGatherAndGradients(self): with self.test_session(use_gpu=True) as session: ta = tensor_array_ops.TensorArray( @@ -1433,7 +1433,7 @@ class TensorArrayTest(test.TestCase): self.assertFalse( [s for s in dev_stats[d] if "/TensorArray" in s.node_name]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTensorArrayIdentity(self): with self.test_session(use_gpu=True): ta0 = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2, diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index 2ee53df9317331dafd96f7884e9a8728cf443923..ae2a0ab29abed2902c0095f7b0886c1afa704af4 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -57,7 +57,7 @@ class VariableScopeTest(test.TestCase): v1 = vs.get_variable("v", [1]) self.assertEqual(v, v1) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testResource(self): vs = variable_scope._get_default_variable_store() v1 = vs.get_variable("v", [1], use_resource=True) @@ -87,7 +87,7 @@ class VariableScopeTest(test.TestCase): self.assertEqual( set(expected_names), set([v.name for v in vs._vars.values()])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarScopeInitializer(self): init = init_ops.constant_initializer(0.3) with variable_scope.variable_scope("tower0") as tower: @@ -100,7 +100,7 @@ class VariableScopeTest(test.TestCase): self.evaluate(variables_lib.variables_initializer([w])) self.assertAllClose(self.evaluate(w.value()), 0.3) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarScopeConstraint(self): constraint = lambda x: 0. * x with variable_scope.variable_scope("tower1") as tower: @@ -117,7 +117,7 @@ class VariableScopeTest(test.TestCase): variables_lib.global_variables_initializer().run() self.assertAllEqual(compat.as_bytes(v.eval()), b"") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarScopeDType(self): with variable_scope.variable_scope("tower2") as tower: with variable_scope.variable_scope("foo", dtype=dtypes.float16): @@ -197,7 +197,7 @@ class VariableScopeTest(test.TestCase): self.assertAllEqual([v1, v2], [v3, v4]) f() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerVariablesStoreAddsToCollections(self): store = variable_scope.EagerVariableStore() with store.as_default(): @@ -214,7 +214,7 @@ class VariableScopeTest(test.TestCase): self.assertEqual( ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES), [concat]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerVariablesOutsideStoreNotAddedToCollections(self): if not context.executing_eagerly(): return @@ -223,7 +223,7 @@ class VariableScopeTest(test.TestCase): self.assertFalse(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) self.assertFalse(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitFromNonTensorValue(self): v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32) self.evaluate(variables_lib.variables_initializer([v])) @@ -239,7 +239,7 @@ class VariableScopeTest(test.TestCase): with self.assertRaises(error): variable_scope.get_variable("x4", initializer={}) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitFromNonInitializer(self): # Test various dtypes with zeros initializer as following: types = [ @@ -294,7 +294,7 @@ class VariableScopeTest(test.TestCase): v_tower = variable_scope.get_variable("v", []) self.assertFalse(v_tower.value().device.startswith(caching_device)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarScopeRegularizer(self): init = init_ops.constant_initializer(0.3) @@ -339,7 +339,7 @@ class VariableScopeTest(test.TestCase): losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) self.assertEqual(3, len(losses)) # No new loss added. - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInitializeFromValue(self): init = constant_op.constant(0.1) w = variable_scope.get_variable("v", initializer=init) @@ -428,7 +428,7 @@ class VariableScopeTest(test.TestCase): sess.run(v0.initializer) sess.run(add) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetVariableScope(self): # Test the get_variable_scope() function and setting properties of result. init = init_ops.constant_initializer(0.3) @@ -449,7 +449,7 @@ class VariableScopeTest(test.TestCase): new_init = variable_scope.get_variable_scope().initializer self.assertEqual(new_init, None) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarScope(self): with variable_scope.variable_scope("tower4") as tower: self.assertEqual(tower.name, "tower4") @@ -468,7 +468,7 @@ class VariableScopeTest(test.TestCase): with ops.name_scope("scope") as sc: self.assertEqual(sc, "tower6/tower4/scope/") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testVarScopeNameScope(self): with ops.name_scope("testVarScopeNameScope1"): with variable_scope.variable_scope("tower") as tower: @@ -961,7 +961,7 @@ class VariableScopeTest(test.TestCase): self.assertEqual( constant_op.constant([], name="c").name, "another/inner/c:0") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGetLocalVar(self): # Check that local variable respects naming. with variable_scope.variable_scope("outer") as outer: @@ -1054,7 +1054,7 @@ class VariableScopeTest(test.TestCase): "testGetCollection_foo/testGetCollection_a:0" ]) - def testGetTrainableVariables(self): + def testGetTrainableVariablesWithGetVariable(self): with self.test_session(): _ = variable_scope.get_variable("testGetTrainableVariables_a", []) with variable_scope.variable_scope( @@ -1062,10 +1062,72 @@ class VariableScopeTest(test.TestCase): _ = variable_scope.get_variable("testGetTrainableVariables_b", []) _ = variable_scope.get_variable( "testGetTrainableVariables_c", [], trainable=False) + + # sync `ON_READ` sets trainable=False + _ = variable_scope.get_variable( + "testGetTrainableVariables_d", [], + synchronization=variable_scope.VariableSynchronization.ON_READ) self.assertEqual( [v.name for v in scope.trainable_variables()], - ["testGetTrainableVariables_foo/" - "testGetTrainableVariables_b:0"]) + ["testGetTrainableVariables_foo/testGetTrainableVariables_b:0"]) + + # All other sync values sets trainable=True + _ = variable_scope.get_variable( + "testGetTrainableVariables_e", [], + synchronization=variable_scope.VariableSynchronization.ON_WRITE) + self.assertEqual([v.name for v in scope.trainable_variables()], [ + "testGetTrainableVariables_foo/testGetTrainableVariables_b:0", + "testGetTrainableVariables_foo/testGetTrainableVariables_e:0" + ]) + + with self.assertRaisesRegexp( + ValueError, "Synchronization value can be set to " + "VariableSynchronization.ON_READ only for non-trainable variables. " + "You have specified trainable=True and " + "synchronization=VariableSynchronization.ON_READ."): + _ = variable_scope.get_variable( + "testGetTrainableVariables_e", [], + synchronization=variable_scope.VariableSynchronization.ON_READ, + trainable=True) + + def testGetTrainableVariablesWithVariable(self): + with self.test_session(): + _ = variable_scope.variable(1.0, name="testGetTrainableVariables_a") + with variable_scope.variable_scope( + "testGetTrainableVariables_foo") as scope: + _ = variable_scope.variable(1.0, name="testGetTrainableVariables_b") + _ = variable_scope.variable( + 1.0, name="testGetTrainableVariables_c", trainable=False) + + # sync `ON_READ` sets trainable=False + _ = variable_scope.variable( + 1.0, + name="testGetTrainableVariables_d", + synchronization=variable_scope.VariableSynchronization.ON_READ) + self.assertEqual( + [v.name for v in scope.trainable_variables()], + ["testGetTrainableVariables_foo/testGetTrainableVariables_b:0"]) + + # All other sync values sets trainable=True + _ = variable_scope.variable( + 1.0, + name="testGetTrainableVariables_e", + synchronization=variable_scope.VariableSynchronization.ON_WRITE) + self.assertEqual([v.name for v in scope.trainable_variables()], [ + "testGetTrainableVariables_foo/testGetTrainableVariables_b:0", + "testGetTrainableVariables_foo/testGetTrainableVariables_e:0" + ]) + + with self.assertRaisesRegexp( + ValueError, "Synchronization value can be set to " + "VariableSynchronization.ON_READ only for non-trainable variables. " + "You have specified trainable=True and " + "synchronization=VariableSynchronization.ON_READ."): + _ = variable_scope.variable( + 1.0, + name="testGetTrainableVariables_e", + synchronization=variable_scope.VariableSynchronization.ON_READ, + trainable=True) def testGetGlobalVariables(self): with self.test_session(): @@ -1253,6 +1315,31 @@ class VariableScopeWithCustomGetterTest(test.TestCase): self.assertEqual(v3, v4) self.assertEqual(3, called[0]) # skipped one in the first new_scope + def testSynchronizationAndAggregationWithCustomGetter(self): + called = [0] + synchronization = variable_scope.VariableSynchronization.AUTO + aggregation = variable_scope.VariableAggregation.NONE + + def custom_getter(getter, *args, **kwargs): + called[0] += 1 + + # Verify synchronization and aggregation kwargs are as expected. + self.assertEqual(kwargs["synchronization"], synchronization) + self.assertEqual(kwargs["aggregation"], aggregation) + return getter(*args, **kwargs) + + with variable_scope.variable_scope("scope", custom_getter=custom_getter): + variable_scope.get_variable("v", [1]) + self.assertEqual(1, called[0]) + + with variable_scope.variable_scope("scope", custom_getter=custom_getter): + synchronization = variable_scope.VariableSynchronization.ON_READ + aggregation = variable_scope.VariableAggregation.MEAN + variable_scope.get_variable( + "v1", [1], synchronization=synchronization, aggregation=aggregation) + + self.assertEqual(2, called[0]) + def testCustomGetterWithReuse(self): # Custom getter can choose to behave differently on reused variables. def custom_getter(getter, *args, **kwargs): @@ -1355,6 +1442,23 @@ class VariableScopeWithCustomGetterTest(test.TestCase): self.assertAllEqual(variable_names, ["forced_name"]) + called = [False] + + def creater_c(next_creator, **kwargs): + called[0] = True + self.assertEqual(kwargs["synchronization"], + variable_scope.VariableSynchronization.ON_WRITE) + self.assertEqual(kwargs["aggregation"], + variable_scope.VariableAggregation.MEAN) + return next_creator(**kwargs) + + with variable_scope.variable_creator_scope(creater_c): + variable_scope.get_variable( + "v", [], + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation=variable_scope.VariableAggregation.MEAN) + self.assertTrue(called[0]) + class PartitionInfoTest(test.TestCase): diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 62d596da91682c396c04efbc64cf063c8e29e7cc..2b9c62ad6f15aea65bd8d504b2f5e713ee38fc83 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -642,6 +642,8 @@ class PartitionedVariableTest(test.TestCase): iterated_partitions = list(partitioned_variable) self.assertEqual(2, num_partitions) self.assertEqual([v0, v1], iterated_partitions) + self.assertEqual([2], partitioned_variable.get_shape()) + self.assertEqual([2], partitioned_variable.shape) self.assertEqual([2], concatenated.get_shape()) self.assertEqual([2], concatenated.shape) diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py index 17575da6f1bf2c226a67419b4bc8156f70f6dedc..29fb002ef445bcfa483dfc747428c7365860fe37 100644 --- a/tensorflow/python/kernel_tests/where_op_test.py +++ b/tensorflow/python/kernel_tests/where_op_test.py @@ -135,6 +135,15 @@ class WhereOpTest(test.TestCase): tf_val = array_ops.where(constant_op.constant(x) > 0, x * x, -x).eval() self.assertAllEqual(tf_val, np_val) + def testBatchSelect(self): + x = np.array([[-2, 3, -1] * 64, [1, -3, -3] * 64] * 8192) # [16384, 192] + c_mat = np.array([[False] * 192, [True] * 192] * 8192) # [16384, 192] + c_vec = np.array([False, True] * 8192) # [16384] + np_val = np.where(c_mat, x * x, -x) + with self.test_session(use_gpu=True): + tf_val = array_ops.where(c_vec, x * x, -x).eval() + self.assertAllEqual(tf_val, np_val) + class WhereBenchmark(test.Benchmark): @@ -163,5 +172,32 @@ class WhereBenchmark(test.Benchmark): "Throughput: %0.03g GB/s" % (name, r["wall_time"], throughput)) sys.stdout.flush() + def benchmarkBatchSelect(self): + for (m, n, use_gpu) in itertools.product([1000, 10000, 100000], + [10, 100, 1000], [False, True]): + name = "m_%d_n_%d_use_gpu_%s" % (m, n, use_gpu) + device = "/%s:0" % ("gpu" if use_gpu else "cpu") + with ops.Graph().as_default(): + with ops.device(device): + x_gen = random_ops.random_uniform([m, n], dtype=dtypes.float32) + y_gen = random_ops.random_uniform([m, n], dtype=dtypes.float32) + c_gen = random_ops.random_uniform([m], dtype=dtypes.float32) <= 0.5 + x = resource_variable_ops.ResourceVariable(x_gen) + y = resource_variable_ops.ResourceVariable(y_gen) + c = resource_variable_ops.ResourceVariable(c_gen) + op = array_ops.where(c, x, y) + with session.Session() as sess: + x.initializer.run() + y.initializer.run() + c.initializer.run() + r = self.run_op_benchmark(sess, op, min_iters=100, name=name) + # approximate size of output: m*n*2 floats for each axis. + gb_processed = m * n * 8 / 1.0e9 + throughput = gb_processed / r["wall_time"] + print("Benchmark: %s \t wall_time: %0.03g s \t " + "Throughput: %0.03g GB/s" % (name, r["wall_time"], throughput)) + sys.stdout.flush() + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index eda036ece4a7d74e5752e80a4a2f4e4ada1b0a38..ab0886553269424a6d5c6d4e494dddb22cd2d7ce 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -152,10 +152,17 @@ class Layer(base_layer.Layer): scope, default_name=self._base_name) as captured_scope: self._scope = captured_scope - def add_weight(self, name, shape, dtype=None, - initializer=None, regularizer=None, - trainable=True, constraint=None, + def add_weight(self, + name, + shape, + dtype=None, + initializer=None, + regularizer=None, + trainable=None, + constraint=None, use_resource=None, + synchronization=vs.VariableSynchronization.AUTO, + aggregation=vs.VariableAggregation.NONE, partitioner=None): """Adds a new variable to the layer, or gets an existing one; returns it. @@ -170,9 +177,19 @@ class Layer(base_layer.Layer): or "non_trainable_variables" (e.g. BatchNorm mean, stddev). Note, if the current variable scope is marked as non-trainable then this parameter is ignored and any added variables are also - marked as non-trainable. + marked as non-trainable. `trainable` defaults to `True` unless + `synchronization` is set to `ON_READ`. constraint: constraint instance (callable). use_resource: Whether to use `ResourceVariable`. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + `tf.VariableSynchronization`. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. If `synchronization` is set to `ON_READ`, + `trainable` must not be set to `True`. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + `tf.VariableAggregation`. partitioner: (optional) partitioner instance (callable). If provided, when the requested variable is created it will be split into multiple partitions according to `partitioner`. In this case, @@ -190,8 +207,22 @@ class Layer(base_layer.Layer): Raises: RuntimeError: If called with partioned variable regularization and eager execution is enabled. + ValueError: When trainable has been set to True with synchronization + set as `ON_READ`. """ - + if synchronization == vs.VariableSynchronization.ON_READ: + if trainable: + raise ValueError( + 'Synchronization value can be set to ' + 'VariableSynchronization.ON_READ only for non-trainable variables. ' + 'You have specified trainable=True and ' + 'synchronization=VariableSynchronization.ON_READ.') + else: + # Set trainable to be false when variable is to be synced on read. + trainable = False + elif trainable is None: + trainable = True + def _should_add_regularizer(variable, existing_variable_set): if isinstance(variable, tf_variables.PartitionedVariable): for var in variable: @@ -240,6 +271,8 @@ class Layer(base_layer.Layer): constraint=constraint, partitioner=partitioner, use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation, getter=vs.get_variable) if regularizer: diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index ab49e37b90e183034ae7ab720fa92b06f39b2aed..d2443db6651cdab2aaf5fb2b9d678080b48bb254 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -39,7 +39,7 @@ from tensorflow.python.platform import test class BaseLayerTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLayerProperties(self): layer = base_layers.Layer(name='my_layer') self.assertEqual(layer.variables, []) @@ -53,13 +53,13 @@ class BaseLayerTest(test.TestCase): layer = base_layers.Layer(name='my_layer', trainable=False) self.assertEqual(layer.trainable, False) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInt64Layer(self): layer = base_layers.Layer(name='my_layer', dtype='int64') layer.add_variable('my_var', [2, 2]) self.assertEqual(layer.name, 'my_layer') - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAddWeight(self): layer = base_layers.Layer(name='my_layer') @@ -90,12 +90,34 @@ class BaseLayerTest(test.TestCase): # regularizers only supported in GRAPH mode. regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3 - variable = layer.add_variable( + _ = layer.add_variable( 'reg_var', [2, 2], initializer=init_ops.zeros_initializer(), regularizer=regularizer) self.assertEqual(len(layer.losses), 1) + # Test that sync `ON_READ` variables are defaulted to be non-trainable. + variable_3 = layer.add_variable( + 'sync_on_read_var', [2, 2], + initializer=init_ops.zeros_initializer(), + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + self.assertEqual(layer.non_trainable_variables, [variable_2, variable_3]) + + def testInvalidTrainableSynchronizationCombination(self): + layer = base_layers.Layer(name='my_layer') + + with self.assertRaisesRegexp( + ValueError, 'Synchronization value can be set to ' + 'VariableSynchronization.ON_READ only for non-trainable variables. ' + 'You have specified trainable=True and ' + 'synchronization=VariableSynchronization.ON_READ.'): + _ = layer.add_variable( + 'v', [2, 2], + initializer=init_ops.zeros_initializer(), + synchronization=variable_scope.VariableSynchronization.ON_READ, + trainable=True) + def testReusePartitionedVaraiblesAndRegularizers(self): regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3 partitioner = partitioned_variables.fixed_size_partitioner(3) @@ -104,7 +126,7 @@ class BaseLayerTest(test.TestCase): partitioner=partitioner, reuse=reuse): layer = base_layers.Layer(name='my_layer') - variable = layer.add_variable( + _ = layer.add_variable( 'reg_part_var', [4, 4], initializer=init_ops.zeros_initializer(), regularizer=regularizer) @@ -116,7 +138,7 @@ class BaseLayerTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'activity_regularizer'): core_layers.Dense(1, activity_regularizer=lambda *args, **kwargs: 0.) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCall(self): class MyLayer(base_layers.Layer): @@ -132,7 +154,7 @@ class BaseLayerTest(test.TestCase): # op is only supported in GRAPH mode self.assertEqual(outputs.op.name, 'my_layer/Square') - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDeepCopy(self): class MyLayer(base_layers.Layer): @@ -155,7 +177,7 @@ class BaseLayerTest(test.TestCase): self.assertEqual(layer_copy._graph, layer._graph) self.assertEqual(layer_copy._private_tensor, layer._private_tensor) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testScopeNaming(self): class PrivateLayer(base_layers.Layer): @@ -203,7 +225,7 @@ class BaseLayerTest(test.TestCase): my_layer_scoped1.apply(inputs) self.assertEqual(my_layer_scoped1._scope.name, 'var_scope/my_layer_1') - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInputSpecNdimCheck(self): class CustomerLayer(base_layers.Layer): @@ -230,7 +252,7 @@ class BaseLayerTest(test.TestCase): layer = CustomerLayer() layer.apply(constant_op.constant([[1], [2]])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInputSpecMinNdimCheck(self): class CustomerLayer(base_layers.Layer): @@ -258,7 +280,7 @@ class BaseLayerTest(test.TestCase): layer = CustomerLayer() layer.apply(constant_op.constant([[[1], [2]]])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInputSpecMaxNdimCheck(self): class CustomerLayer(base_layers.Layer): @@ -286,7 +308,7 @@ class BaseLayerTest(test.TestCase): layer = CustomerLayer() layer.apply(constant_op.constant([[1], [2]])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInputSpecDtypeCheck(self): class CustomerLayer(base_layers.Layer): @@ -306,7 +328,7 @@ class BaseLayerTest(test.TestCase): layer = CustomerLayer() layer.apply(constant_op.constant(1.0, dtype=dtypes.float32)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInputSpecAxesCheck(self): class CustomerLayer(base_layers.Layer): @@ -328,7 +350,7 @@ class BaseLayerTest(test.TestCase): layer = CustomerLayer() layer.apply(constant_op.constant([[1, 2], [3, 4], [5, 6]])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testInputSpecShapeCheck(self): class CustomerLayer(base_layers.Layer): @@ -348,7 +370,7 @@ class BaseLayerTest(test.TestCase): layer = CustomerLayer() layer.apply(constant_op.constant([[1, 2, 3], [4, 5, 6]])) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoInputSpec(self): class CustomerLayer(base_layers.Layer): @@ -369,7 +391,7 @@ class BaseLayerTest(test.TestCase): layer.apply(array_ops.placeholder('int32')) layer.apply(array_ops.placeholder('int32', shape=(2, 3))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_count_params(self): dense = core_layers.Dense(16) dense.build((None, 4)) @@ -379,7 +401,7 @@ class BaseLayerTest(test.TestCase): with self.assertRaises(ValueError): dense.count_params() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDictInputOutput(self): class DictLayer(base_layers.Layer): @@ -589,6 +611,5 @@ class BaseLayerTest(test.TestCase): ValueError, 'Input graph and Layer graph are not the same'): layer.apply(constant_op.constant([[1.]])) - if __name__ == '__main__': test.main() diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py index 267d78dbcb27392a528bf09414b857d9b1a7c2f9..d40743b0cea29553430a0fc247684f7b182a94ee 100644 --- a/tensorflow/python/layers/convolutional.py +++ b/tensorflow/python/layers/convolutional.py @@ -13,23 +13,15 @@ # limitations under the License. # ============================================================================= -# pylint: disable=unused-import,g-bad-import-order """Contains the convolutional layer classes and their functional aliases. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.eager import context -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import layers as keras_layers from tensorflow.python.layers import base -from tensorflow.python.layers import utils -from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import nn_ops from tensorflow.python.util.tf_export import tf_export @@ -217,7 +209,6 @@ def conv1d(inputs, bias_constraint=bias_constraint, trainable=trainable, name=name, - dtype=inputs.dtype.base_dtype, _reuse=reuse, _scope=name) return layer.apply(inputs) @@ -421,7 +412,6 @@ def conv2d(inputs, bias_constraint=bias_constraint, trainable=trainable, name=name, - dtype=inputs.dtype.base_dtype, _reuse=reuse, _scope=name) return layer.apply(inputs) @@ -627,7 +617,6 @@ def conv3d(inputs, bias_constraint=bias_constraint, trainable=trainable, name=name, - dtype=inputs.dtype.base_dtype, _reuse=reuse, _scope=name) return layer.apply(inputs) @@ -1266,7 +1255,6 @@ def conv2d_transpose(inputs, bias_constraint=bias_constraint, trainable=trainable, name=name, - dtype=inputs.dtype.base_dtype, _reuse=reuse, _scope=name) return layer.apply(inputs) @@ -1438,7 +1426,6 @@ def conv3d_transpose(inputs, bias_constraint=bias_constraint, trainable=trainable, name=name, - dtype=inputs.dtype.base_dtype, _reuse=reuse, _scope=name) return layer.apply(inputs) diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py index abbacac442c5bb20feeb255d4ad3f90626c75327..9879e5020f31286fc342331843472cac08c6f330 100644 --- a/tensorflow/python/layers/core.py +++ b/tensorflow/python/layers/core.py @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================= -# pylint: disable=unused-import,g-bad-import-order """Contains the core layers: Dense, Dropout. Also contains their functional aliases. @@ -23,10 +22,6 @@ from __future__ import division from __future__ import print_function -import six -from six.moves import xrange # pylint: disable=redefined-builtin -import numpy as np - from tensorflow.python.keras import layers as keras_layers from tensorflow.python.layers import base from tensorflow.python.ops import init_ops @@ -132,8 +127,8 @@ def dense( """Functional interface for the densely-connected layer. This layer implements the operation: - `outputs = activation(inputs.kernel + bias)` - Where `activation` is the activation function passed as the `activation` + `outputs = activation(inputs * kernel + bias)` + where `activation` is the activation function passed as the `activation` argument (if not `None`), `kernel` is a weights matrix created by the layer, and `bias` is a bias vector created by the layer (only if `use_bias` is `True`). @@ -184,7 +179,6 @@ def dense( bias_constraint=bias_constraint, trainable=trainable, name=name, - dtype=inputs.dtype.base_dtype, _scope=name, _reuse=reuse) return layer.apply(inputs) @@ -209,7 +203,7 @@ class Dropout(keras_layers.Dropout, base.Layer): to be the same for all timesteps, you can use `noise_shape=[batch_size, 1, features]`. seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed}. + `tf.set_random_seed`. for behavior. name: The name of the layer (string). """ @@ -254,7 +248,7 @@ def dropout(inputs, to be the same for all timesteps, you can use `noise_shape=[batch_size, 1, features]`. seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. training: Either a Python boolean, or a TensorFlow boolean scalar tensor (e.g. a placeholder). Whether to return the output in training mode diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py index cf45b07637108422f1c612390bb01efdad6d5bcf..040c1cddc0f2540eec5fcf3442bed3f4800bec7c 100644 --- a/tensorflow/python/layers/core_test.py +++ b/tensorflow/python/layers/core_test.py @@ -41,7 +41,7 @@ from tensorflow.python.platform import test class DenseTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDenseProperties(self): dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense') self.assertEqual(dense.units, 2) @@ -91,14 +91,14 @@ class DenseTest(test.TestCase): core_layers.Dense(5)(inputs) core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')(inputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testCallTensorDot(self): dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense') inputs = random_ops.random_uniform((5, 4, 3), seed=1) outputs = dense(inputs) self.assertListEqual([5, 4, 2], outputs.get_shape().as_list()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoBias(self): dense = core_layers.Dense(2, use_bias=False, name='my_dense') inputs = random_ops.random_uniform((5, 2), seed=1) @@ -112,7 +112,7 @@ class DenseTest(test.TestCase): self.assertEqual(dense.kernel.name, 'my_dense/kernel:0') self.assertEqual(dense.bias, None) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNonTrainable(self): dense = core_layers.Dense(2, trainable=False, name='my_dense') inputs = random_ops.random_uniform((5, 2), seed=1) @@ -125,7 +125,7 @@ class DenseTest(test.TestCase): self.assertEqual( len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testOutputShape(self): dense = core_layers.Dense(7, activation=nn_ops.relu, name='my_dense') inputs = random_ops.random_uniform((5, 3), seed=1) @@ -165,7 +165,7 @@ class DenseTest(test.TestCase): dense = core_layers.Dense(4, name='my_dense') dense(inputs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testActivation(self): dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1') inputs = random_ops.random_uniform((5, 3), seed=1) @@ -325,7 +325,7 @@ class DenseTest(test.TestCase): var_key = 'test2/dense/kernel' self.assertEqual(var_dict[var_key].name, '%s:0' % var_key) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testComputeOutputShape(self): dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1') ts = tensor_shape.TensorShape @@ -347,7 +347,7 @@ class DenseTest(test.TestCase): dense.compute_output_shape(ts([None, 4, 3])).as_list()) # pylint: enable=protected-access - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testConstraints(self): k_constraint = lambda x: x / math_ops.reduce_sum(x) b_constraint = lambda x: x / math_ops.reduce_max(x) @@ -369,7 +369,7 @@ def _get_variable_dict_from_varstore(): class DropoutTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDropoutProperties(self): dp = core_layers.Dropout(0.5, name='dropout') self.assertEqual(dp.rate, 0.5) @@ -377,7 +377,7 @@ class DropoutTest(test.TestCase): dp.apply(array_ops.ones(())) self.assertEqual(dp.name, 'dropout') - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBooleanLearningPhase(self): dp = core_layers.Dropout(0.5) inputs = array_ops.ones((5, 3)) @@ -402,7 +402,7 @@ class DropoutTest(test.TestCase): np_output = sess.run(dropped, feed_dict={training: False}) self.assertAllClose(np.ones((5, 5)), np_output) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDynamicNoiseShape(self): inputs = array_ops.ones((5, 3, 2)) noise_shape = [None, 1, None] diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index d082e312e9a3750726235a0360ef466fb8915208..691dac69865b6e0ee582071d01c2cf626f7f639a 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -13,16 +13,12 @@ # limitations under the License. # ============================================================================= -# pylint: disable=unused-import,g-bad-import-order """Contains the normalization layer classes and their functional aliases. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import six -from six.moves import xrange # pylint: disable=redefined-builtin -import numpy as np from tensorflow.python.keras import layers as keras_layers from tensorflow.python.layers import base @@ -44,7 +40,7 @@ class BatchNormalization(keras_layers.BatchNormalization, base.Layer): normalized, typically the features axis/axes. For instance, after a `Conv2D` layer with `data_format="channels_first"`, set `axis=1`. If a list of axes is provided, each axis in `axis` will be normalized - simultaneously. Default is `-1` which takes uses last axis. Note: when + simultaneously. Default is `-1` which uses the last axis. Note: when using multi-axis batch norm, the `beta`, `gamma`, `moving_mean`, and `moving_variance` variables are the same rank as the input Tensor, with dimension size 1 in all reduced (non-axis) dimensions). @@ -308,7 +304,6 @@ def batch_normalization(inputs, virtual_batch_size=virtual_batch_size, adjustment=adjustment, name=name, - dtype=inputs.dtype.base_dtype, _reuse=reuse, _scope=name) return layer.apply(inputs, training=training) diff --git a/tensorflow/python/layers/utils.py b/tensorflow/python/layers/utils.py index 3b156c36a2ff35fb9e05af1406d7b3f6cf883394..8e4b274207a9eeb183e5ecd640733d0001093817 100644 --- a/tensorflow/python/layers/utils.py +++ b/tensorflow/python/layers/utils.py @@ -13,19 +13,15 @@ # limitations under the License. # ============================================================================= -# pylint: disable=unused-import,g-bad-import-order """Contains layer utilies for input validation and format conversion. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.eager import context from tensorflow.python.ops import variables from tensorflow.python.ops import control_flow_ops -from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond as smart_module -from tensorflow.python.framework import tensor_util from tensorflow.python.util import nest diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc index 77fa2c1f66d2214dbb08e4d0ad3437fa4fe02822..fde3a83770280038b777a141693d117dace4b41f 100644 --- a/tensorflow/python/lib/core/bfloat16.cc +++ b/tensorflow/python/lib/core/bfloat16.cc @@ -446,6 +446,16 @@ npy_bool NPyBfloat16_NonZero(void* data, void* arr) { return x != static_cast(0); } +int NPyBfloat16_Fill(void* buffer_raw, npy_intp length, void* ignored) { + bfloat16* const buffer = reinterpret_cast(buffer_raw); + const float start(buffer[0]); + const float delta = static_cast(buffer[1]) - start; + for (npy_intp i = 2; i < length; ++i) { + buffer[i] = static_cast(start + i * delta); + } + return 0; +} + // NumPy casts // Performs a NumPy array cast from type 'From' to 'To'. @@ -548,6 +558,7 @@ bool Initialize() { NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN; NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap; NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero; + NPyBfloat16_ArrFuncs.fill = NPyBfloat16_Fill; Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type; npy_bfloat16_ = PyArray_RegisterDataType(&NPyBfloat16_Descr); diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py index 09d4b01fa43babdc09f8f255e79bbed539ddc04c..bc928cd9e5ef4d5a0ec0ce73e853e3e022a1f6fa 100644 --- a/tensorflow/python/lib/core/bfloat16_test.py +++ b/tensorflow/python/lib/core/bfloat16_test.py @@ -245,6 +245,20 @@ class Bfloat16NumPyTest(test.TestCase): np.logaddexp(x.astype(bfloat16), y.astype(bfloat16)), atol=2e-2) + def testArange(self): + self.assertAllEqual( + np.arange(100, dtype=np.float32).astype(bfloat16), + np.arange(100, dtype=bfloat16)) + self.assertAllEqual( + np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16), + np.arange(-10.5, 7.8, 0.5, dtype=bfloat16)) + self.assertAllEqual( + np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16), + np.arange(-0., -7., -0.25, dtype=bfloat16)) + self.assertAllEqual( + np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16), + np.arange(-16384., 16384., 64., dtype=bfloat16)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/lib/core/ndarray_tensor.cc b/tensorflow/python/lib/core/ndarray_tensor.cc index ec1ba7b8f7d611ad659ac483505a7d86bf4b31e5..5765b17594476d13cf9706e79d55350dba474106 100644 --- a/tensorflow/python/lib/core/ndarray_tensor.cc +++ b/tensorflow/python/lib/core/ndarray_tensor.cc @@ -136,6 +136,33 @@ Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array, return Status::OK(); } +Status PyObjectToString(PyObject* obj, const char** ptr, Py_ssize_t* len, + PyObject** ptr_owner) { + *ptr_owner = nullptr; + if (!PyUnicode_Check(obj)) { + char* buf; + if (PyBytes_AsStringAndSize(obj, &buf, len) != 0) { + return errors::Internal("Unable to get element as bytes."); + } + *ptr = buf; + return Status::OK(); + } +#if (PY_MAJOR_VERSION > 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION >= 3)) + *ptr = PyUnicode_AsUTF8AndSize(obj, len); + if (*ptr != nullptr) return Status::OK(); +#else + PyObject* utemp = PyUnicode_AsUTF8String(obj); + char* buf; + if (utemp != nullptr && PyBytes_AsStringAndSize(utemp, &buf, len) != -1) { + *ptr = buf; + *ptr_owner = utemp; + return Status::OK(); + } + Py_XDECREF(utemp); +#endif + return errors::Internal("Unable to convert element to UTF-8."); +} + // Iterate over the string array 'array', extract the ptr and len of each string // element and call f(ptr, len). template @@ -148,33 +175,12 @@ Status PyBytesArrayMap(PyArrayObject* array, F f) { if (!item) { return errors::Internal("Unable to get element from the feed - no item."); } - char* ptr; Py_ssize_t len; - - if (PyUnicode_Check(item.get())) { -#if PY_VERSION_HEX >= 0x03030000 - // Accept unicode by converting to UTF-8 bytes. - ptr = PyUnicode_AsUTF8AndSize(item.get(), &len); - if (!ptr) { - return errors::Internal("Unable to get element as UTF-8."); - } - f(ptr, len); -#else - PyObject* utemp = PyUnicode_AsUTF8String(item.get()); - if (!utemp || PyBytes_AsStringAndSize(utemp, &ptr, &len) == -1) { - Py_XDECREF(utemp); - return errors::Internal("Unable to convert element to UTF-8."); - } - f(ptr, len); - Py_DECREF(utemp); -#endif - } else { - int success = PyBytes_AsStringAndSize(item.get(), &ptr, &len); - if (success != 0) { - return errors::Internal("Unable to get element as bytes."); - } - f(ptr, len); - } + const char* ptr; + PyObject* ptr_owner; + TF_RETURN_IF_ERROR(PyObjectToString(item.get(), &ptr, &len, &ptr_owner)); + f(ptr, len); + Py_XDECREF(ptr_owner); PyArray_ITER_NEXT(iter.get()); } return Status::OK(); @@ -186,10 +192,11 @@ Status EncodePyBytesArray(PyArrayObject* array, tensorflow::int64 nelems, size_t* size, void** buffer) { // Compute bytes needed for encoding. *size = 0; - TF_RETURN_IF_ERROR(PyBytesArrayMap(array, [&size](char* ptr, Py_ssize_t len) { - *size += - sizeof(tensorflow::uint64) + tensorflow::core::VarintLength(len) + len; - })); + TF_RETURN_IF_ERROR( + PyBytesArrayMap(array, [&size](const char* ptr, Py_ssize_t len) { + *size += sizeof(tensorflow::uint64) + + tensorflow::core::VarintLength(len) + len; + })); // Encode all strings. std::unique_ptr base_ptr(new char[*size]); char* base = base_ptr.get(); @@ -198,7 +205,7 @@ Status EncodePyBytesArray(PyArrayObject* array, tensorflow::int64 nelems, tensorflow::uint64* offsets = reinterpret_cast(base); TF_RETURN_IF_ERROR(PyBytesArrayMap( - array, [&base, &data_start, &dst, &offsets](char* ptr, Py_ssize_t len) { + array, [&data_start, &dst, &offsets](const char* ptr, Py_ssize_t len) { *offsets = (dst - data_start); offsets++; dst = tensorflow::core::EncodeVarint64(dst, len); diff --git a/tensorflow/python/lib/core/numpy.h b/tensorflow/python/lib/core/numpy.h index 98354083c7e06103166a6fe535b153eaaf201c17..0098d938a086621a9fd98fa69b48aa78b5341171 100644 --- a/tensorflow/python/lib/core/numpy.h +++ b/tensorflow/python/lib/core/numpy.h @@ -31,6 +31,7 @@ limitations under the License. // Place `` before to avoid build failure in macOS. #include + #include #include "numpy/arrayobject.h" diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 30c1a9c75986f242c6cf5a8aa2ed1b64938d2bda..fc02d6de0e005cb66b4ddf59379f3d2355eb8eaf 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -55,37 +55,35 @@ struct PyCall { string token; // The device on which Tensors are stored; only used for EagerPyFunc. - Device* device; - - // True if and only if the op has been placed on a GPU. - bool gpu; + Device* device = nullptr; // True if the call is associated with an EagerPyFunc. - bool eager; + bool eager = false; // Inputs and outputs of this function invocation. std::vector ins; std::vector out; }; +bool IsCPUDevice(const Device* d) { + return d == nullptr || d->tensorflow_gpu_device_info() == nullptr; +} + // Givens the 'call', prepares the token and inputs as a python tuple // that is appropriate for calling the trampoline. Status MakeArgTuple(const PyCall* call, PyObject** tuple) { int64 n = call->ins.size(); PyObject* lst = PyList_New(n); CHECK(lst); + // TFE_TensorHandle assumes that CPU is identified by nullptr. + Device* device = IsCPUDevice(call->device) ? nullptr : call->device; for (int64 i = 0; i < n; ++i) { PyObject* arg = nullptr; const Tensor& t = call->ins[i]; if (call->eager) { - if (call->gpu) { - arg = EagerTensorFromHandle( - new TFE_TensorHandle(t, call->device, call->device)); - } else { - // TFE_TensorHandle assumes that CPU is identified by `nullptr`. - arg = EagerTensorFromHandle(new TFE_TensorHandle(t, nullptr, nullptr)); - } + arg = EagerTensorFromHandle(new TFE_TensorHandle(t, device, device)); if (arg == nullptr) { + Py_DECREF(lst); return errors::Internal("Unable to procure EagerTensor from Tensor."); } } else { @@ -97,8 +95,9 @@ Status MakeArgTuple(const PyCall* call, PyObject** tuple) { } PyList_SetItem(lst, i, arg); } - *tuple = Py_BuildValue("(sON)", call->token.c_str(), - call->gpu ? Py_True : Py_False, lst); + const char* device_name = + device == nullptr ? nullptr : device->attributes().name().c_str(); + *tuple = Py_BuildValue("(ssN)", call->token.c_str(), device_name, lst); CHECK(*tuple); return Status::OK(); } @@ -167,9 +166,40 @@ bool IsSingleNone(PyObject* obj) { } // Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`. +// Validates that `output_tensor` is backed by memory in `expected_device` +// (which is assumed to be a local device, one on which the kernel was +// executed.) +// +// It may be nice to copy the tensor to the right device instead of failing if +// it isn't already there. This is left as a future exercise. The required +// device-copying logic is implemented in Python at the moment. tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor, + const Device* expected_device, const Tensor** output_tensor) { - return EagerTensor_Handle(eager_tensor)->handle->Tensor(output_tensor); + auto handle = EagerTensor_Handle(eager_tensor)->handle; + Device* actual_device = nullptr; + TF_RETURN_IF_ERROR(handle->Device(&actual_device)); + TF_RETURN_IF_ERROR(handle->Tensor(output_tensor)); + // actual_device may be nullptr, which implies local CPU. + if (expected_device == actual_device) return Status::OK(); + const string& expected_device_name = expected_device->attributes().name(); + if (actual_device == nullptr) { + if (!IsCPUDevice(expected_device)) { + return errors::Internal( + "expected the py_func to return a Tensor backed by memory in ", + expected_device_name, + ", but is actually backed by local host memory. This is a bug."); + } + return Status::OK(); + } + const string& actual_device_name = actual_device->attributes().name(); + if (actual_device_name != expected_device_name) { + return errors::Internal( + "expected the py_func to return a Tensor backed by memory in ", + expected_device_name, ", but is actually in ", actual_device_name, + ". This is a bug."); + } + return Status::OK(); } // Calls the registered py function through the trampoline. @@ -224,7 +254,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { const PyObject* item = PyList_GetItem(result, i); if (EagerTensor_CheckExact(item)) { const Tensor* tensor = nullptr; - s = ExtractTensorFromEagerTensor(item, &tensor); + s = ExtractTensorFromEagerTensor(item, call->device, &tensor); if (s.ok()) t = *tensor; } else { s = errors::FailedPrecondition( @@ -245,7 +275,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { DCHECK(call->eager); if (result != Py_None) { const Tensor* t = nullptr; - s = ExtractTensorFromEagerTensor(result, &t); + s = ExtractTensorFromEagerTensor(result, call->device, &t); if (s.ok()) call->out.push_back(*t); } } else if (PyArray_Check(result)) { @@ -303,6 +333,35 @@ class NumpyTensorBuffer : public TensorBuffer { void* data_; }; +Status PyObjectToString(PyObject* obj, string* str) { + char* py_bytes; + Py_ssize_t size; + if (PyBytes_AsStringAndSize(obj, &py_bytes, &size) != -1) { + str->assign(py_bytes, size); + return Status::OK(); + } +#if PY_MAJOR_VERSION >= 3 + const char* ptr = PyUnicode_AsUTF8AndSize(obj, &size); + if (ptr != nullptr) { + str->assign(ptr, size); + return Status::OK(); + } +#else + if (PyUnicode_Check(obj)) { + PyObject* unicode = PyUnicode_AsUTF8String(obj); + char* ptr; + if (unicode && PyString_AsStringAndSize(unicode, &ptr, &size) != -1) { + str->assign(ptr, size); + Py_DECREF(unicode); + return Status::OK(); + } + Py_XDECREF(unicode); + } +#endif + return errors::Unimplemented("Unsupported object type ", + obj->ob_type->tp_name); +} + Status ConvertNdarrayToTensor(PyObject* obj, Tensor* ret) { PyArrayObject* input = reinterpret_cast(obj); DataType dtype = DT_INVALID; @@ -318,29 +377,7 @@ Status ConvertNdarrayToTensor(PyObject* obj, Tensor* ret) { auto tflat = t.flat(); PyObject** input_data = reinterpret_cast(PyArray_DATA(input)); for (int i = 0; i < tflat.dimension(0); ++i) { - char* el; - Py_ssize_t el_size; - if (PyBytes_AsStringAndSize(input_data[i], &el, &el_size) == -1) { -#if PY_MAJOR_VERSION >= 3 - el = PyUnicode_AsUTF8AndSize(input_data[i], &el_size); -#else - el = nullptr; - if (PyUnicode_Check(input_data[i])) { - PyObject* unicode = PyUnicode_AsUTF8String(input_data[i]); - if (unicode) { - if (PyString_AsStringAndSize(unicode, &el, &el_size) == -1) { - Py_DECREF(unicode); - el = nullptr; - } - } - } -#endif - if (!el) { - return errors::Unimplemented("Unsupported object type ", - input_data[i]->ob_type->tp_name); - } - } - tflat(i) = string(el, el_size); + TF_RETURN_IF_ERROR(PyObjectToString(input_data[i], &tflat(i))); } *ret = t; break; @@ -449,13 +486,11 @@ class PyFuncOp : public OpKernel { explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_)); eager_ = type_string() == "EagerPyFunc"; - gpu_ = ctx->device_type().type_string() == DEVICE_GPU; } void Compute(OpKernelContext* ctx) override { PyCall call; call.token = token_; - call.gpu = gpu_; call.eager = eager_; if (call.eager) { // Eager's C API uses `Device`, whereas `OpKernelContext` stores a @@ -464,6 +499,7 @@ class PyFuncOp : public OpKernel { if (call.device == nullptr) { ctx->CtxFailureWithWarning( errors::Internal("Unrecognized device class")); + return; } } @@ -471,6 +507,17 @@ class PyFuncOp : public OpKernel { call.ins.push_back(ctx->input(i)); } + // NOTE(mrry): There is a potential time-of-check-to-time-of-use race here. + // because it is possible that `Py_Finalize()` could be called in another + // thread between this check and the call to `PyGILState_Ensure()`, which + // will abort the process if `Py_Finalize()` has been called. A more robust + // solution would be welcome, but it is not obvious how to make this work + // using the current Python C API. + OP_REQUIRES(ctx, Py_IsInitialized(), + errors::FailedPrecondition( + "Python interpreter state is not initialized. " + "The process may be terminated.")); + PyGILState_STATE py_threadstate; py_threadstate = PyGILState_Ensure(); bool log_on_error; @@ -508,9 +555,6 @@ class PyFuncOp : public OpKernel { private: string token_; - // True if and only if this op has been placed on a GPU. - bool gpu_; - // True if and only if this op should execute the python function eagerly, // i.e., if and only if the eager attribute is set. bool eager_; diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 386be35ba2ff1fed07d6b6f5ee5d60a0f2039441..3b4f12ae31b9e905ed15e86533e648b4c95736e1 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -88,6 +88,41 @@ bool IsPyDimension(PyObject* obj) { return ret; } +// Sets *elem to a NEW reference to an element in seq on success. +// REQUIRES: PySequence_Check(seq) && PySequence_Length(seq) > 0. +Status SampleElementFromSequence(PyObject* seq, PyObject** elem) { + *elem = PySequence_GetItem(seq, 0); + if (*elem != nullptr) return Status::OK(); + // seq may implement the sequence protocol (i.e., implement __getitem__) + // but may legitimately not have a 0-th element (__getitem__(self, 0) + // raises a KeyError). For example: + // seq = pandas.Series([0, 1, 2], index=[2, 4, 6]) + // + // We don't actually care for the element at key 0, any element will do + // for inferring the element types. All elements are expected to + // have the same type, and this will be validated when converting + // to an EagerTensor. + PyErr_Clear(); + Safe_PyObjectPtr iter(PyObject_GetIter(seq)); + if (PyErr_Occurred()) { + return errors::InvalidArgument("Cannot infer dtype of a ", + Py_TYPE(seq)->tp_name, + " object: ", PyExceptionFetch()); + } + *elem = PyIter_Next(iter.get()); + if (PyErr_Occurred()) { + return errors::InvalidArgument( + "Cannot infer dtype of a ", Py_TYPE(seq)->tp_name, + " object, as iter().next() failed: ", PyExceptionFetch()); + } + if (*elem == nullptr) { + return errors::InvalidArgument("Cannot infer dtype of a ", + Py_TYPE(seq)->tp_name, + " object since it is an empty sequence"); + } + return Status::OK(); +} + Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) { std::vector refs_to_clean; while (true) { @@ -98,7 +133,9 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) { auto length = PySequence_Length(obj); if (length > 0) { shape->AddDim(length); - obj = PySequence_GetItem(obj, 0); + PyObject* elem = nullptr; + TF_RETURN_IF_ERROR(SampleElementFromSequence(obj, &elem)); + obj = elem; refs_to_clean.push_back(make_safe(obj)); continue; } else if (length == 0) { diff --git a/tensorflow/python/lib/core/py_util.cc b/tensorflow/python/lib/core/py_util.cc index 572693b1cfafa04a7716e09464885faa4c92e299..739cab46b10223fde918372af48b7f7a83d4a7a6 100644 --- a/tensorflow/python/lib/core/py_util.cc +++ b/tensorflow/python/lib/core/py_util.cc @@ -17,6 +17,9 @@ limitations under the License. // Place `` before to avoid build failure in macOS. #include + +// The empty line above is on purpose as otherwise clang-format will +// automatically move before . #include #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc index ba749da47a57305a8d414a946c1290f4982cc759..3c64813735139d8aa45fdc8c52f8907fb27e2583 100644 --- a/tensorflow/python/lib/io/py_record_writer.cc +++ b/tensorflow/python/lib/io/py_record_writer.cc @@ -47,6 +47,9 @@ PyRecordWriter* PyRecordWriter::New(const string& filename, } PyRecordWriter::~PyRecordWriter() { + // Writer depends on file during close for zlib flush, so destruct first. + writer_.reset(); + file_.reset(); } bool PyRecordWriter::WriteRecord(tensorflow::StringPiece record) { @@ -56,6 +59,11 @@ bool PyRecordWriter::WriteRecord(tensorflow::StringPiece record) { } void PyRecordWriter::Flush(TF_Status* out_status) { + if (writer_ == nullptr) { + TF_SetStatus(out_status, TF_FAILED_PRECONDITION, + "Writer not initialized or previously closed"); + return; + } Status s = writer_->Flush(); if (!s.ok()) { Set_TF_Status_from_Status(out_status, s); @@ -64,18 +72,22 @@ void PyRecordWriter::Flush(TF_Status* out_status) { } void PyRecordWriter::Close(TF_Status* out_status) { - Status s = writer_->Close(); - if (!s.ok()) { - Set_TF_Status_from_Status(out_status, s); - return; + if (writer_ != nullptr) { + Status s = writer_->Close(); + if (!s.ok()) { + Set_TF_Status_from_Status(out_status, s); + return; + } + writer_.reset(nullptr); } - writer_.reset(nullptr); - s = file_->Close(); - if (!s.ok()) { - Set_TF_Status_from_Status(out_status, s); - return; + if (file_ != nullptr) { + Status s = file_->Close(); + if (!s.ok()) { + Set_TF_Status_from_Status(out_status, s); + return; + } + file_.reset(nullptr); } - file_.reset(nullptr); } } // namespace io diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py index bf2d6f68b55d78f9570d3854804e3d1316176c99..941d6cd67c202f7888ef555e63bf73406d8a1e33 100644 --- a/tensorflow/python/lib/io/tf_record.py +++ b/tensorflow/python/lib/io/tf_record.py @@ -125,6 +125,7 @@ class TFRecordWriter(object): Args: record: str """ + # TODO(sethtroisi): Failures are currently swallowed, change that. self._writer.WriteRecord(record) def flush(self): diff --git a/tensorflow/python/lib/io/tf_record_test.py b/tensorflow/python/lib/io/tf_record_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4743c037ecddd70d6917e8fcce289d310f7335ae --- /dev/null +++ b/tensorflow/python/lib/io/tf_record_test.py @@ -0,0 +1,384 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf_record.TFRecordWriter and tf_record.tf_record_iterator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os +import zlib + +import six + +from tensorflow.python.framework import errors_impl +from tensorflow.python.lib.io import tf_record +from tensorflow.python.platform import test +from tensorflow.python.util import compat + +prefix_path = "third_party/tensorflow/core/lib" + +# pylint: disable=invalid-name +TFRecordCompressionType = tf_record.TFRecordCompressionType +# pylint: enable=invalid-name + +# Edgar Allan Poe's 'Eldorado' +_TEXT = b"""Gaily bedight, + A gallant knight, + In sunshine and in shadow, + Had journeyed long, + Singing a song, + In search of Eldorado. + + But he grew old + This knight so bold + And o'er his heart a shadow + Fell as he found + No spot of ground + That looked like Eldorado. + + And, as his strength + Failed him at length, + He met a pilgrim shadow + 'Shadow,' said he, + 'Where can it be + This land of Eldorado?' + + 'Over the Mountains + Of the Moon' + Down the Valley of the Shadow, + Ride, boldly ride,' + The shade replied, + 'If you seek for Eldorado!' + """ + + +class TFCompressionTestCase(test.TestCase): + + def setUp(self): + super(TFCompressionTestCase, self).setUp() + self._num_files = 2 + self._num_records = 7 + + def _Record(self, f, r): + return compat.as_bytes("Record %d of file %d" % (r, f)) + + def _CreateFiles(self, options=None, prefix=""): + filenames = [] + for i in range(self._num_files): + name = prefix + "tfrecord.%d.txt" % i + records = [self._Record(i, j) for j in range(self._num_records)] + fn = self._WriteRecordsToFile(records, name, options) + filenames.append(fn) + return filenames + + def _WriteRecordsToFile(self, records, name="tfrecord", options=None): + fn = os.path.join(self.get_temp_dir(), name) + with tf_record.TFRecordWriter(fn, options=options) as writer: + for r in records: + writer.write(r) + return fn + + def _ZlibCompressFile(self, infile, name="tfrecord.z"): + # zlib compress the file and write compressed contents to file. + with open(infile, "rb") as f: + cdata = zlib.compress(f.read()) + + zfn = os.path.join(self.get_temp_dir(), name) + with open(zfn, "wb") as f: + f.write(cdata) + return zfn + + def _GzipCompressFile(self, infile, name="tfrecord.gz"): + # gzip compress the file and write compressed contents to file. + with open(infile, "rb") as f: + cdata = f.read() + + gzfn = os.path.join(self.get_temp_dir(), name) + with gzip.GzipFile(gzfn, "wb") as f: + f.write(cdata) + return gzfn + + def _ZlibDecompressFile(self, infile, name="tfrecord"): + with open(infile, "rb") as f: + cdata = zlib.decompress(f.read()) + fn = os.path.join(self.get_temp_dir(), name) + with open(fn, "wb") as f: + f.write(cdata) + return fn + + def _GzipDecompressFile(self, infile, name="tfrecord"): + with gzip.GzipFile(infile, "rb") as f: + cdata = f.read() + fn = os.path.join(self.get_temp_dir(), name) + with open(fn, "wb") as f: + f.write(cdata) + return fn + + +class TFRecordWriterTest(TFCompressionTestCase): + + def setUp(self): + super(TFRecordWriterTest, self).setUp() + + def _AssertFilesEqual(self, a, b, equal): + for an, bn in zip(a, b): + with open(an, "rb") as af, open(bn, "rb") as bf: + if equal: + self.assertEqual(af.read(), bf.read()) + else: + self.assertNotEqual(af.read(), bf.read()) + + def testWriteReadZLibFiles(self): + # Write uncompressed then compress manually. + options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE) + files = self._CreateFiles(options, prefix="uncompressed") + zlib_files = [ + self._ZlibCompressFile(fn, "tfrecord_%s.z" % i) + for i, fn in enumerate(files) + ] + self._AssertFilesEqual(files, zlib_files, False) + + # Now write compressd and verify same. + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + compressed_files = self._CreateFiles(options, prefix="compressed") + self._AssertFilesEqual(compressed_files, zlib_files, True) + + # Decompress compress and verify same. + uncompressed_files = [ + self._ZlibDecompressFile(fn, "tfrecord_%s.z" % i) + for i, fn in enumerate(compressed_files) + ] + self._AssertFilesEqual(uncompressed_files, files, True) + + def testWriteReadGzipFiles(self): + # Write uncompressed then compress manually. + options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE) + files = self._CreateFiles(options, prefix="uncompressed") + gzip_files = [ + self._GzipCompressFile(fn, "tfrecord_%s.gz" % i) + for i, fn in enumerate(files) + ] + self._AssertFilesEqual(files, gzip_files, False) + + # Now write compressd and verify same. + options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) + compressed_files = self._CreateFiles(options, prefix="compressed") + + # Note: Gzips written by TFRecordWriter add 'tfrecord_0' so + # compressed_files can't be compared with gzip_files + + # Decompress compress and verify same. + uncompressed_files = [ + self._GzipDecompressFile(fn, "tfrecord_%s.gz" % i) + for i, fn in enumerate(compressed_files) + ] + self._AssertFilesEqual(uncompressed_files, files, True) + + +class TFRecordWriterZlibTest(TFCompressionTestCase): + + def testZLibFlushRecord(self): + original = [b"small record"] + fn = self._WriteRecordsToFile(original, "small_record") + with open(fn, "rb") as h: + buff = h.read() + + # creating more blocks and trailing blocks shouldn't break reads + compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS) + + output = b"" + for c in buff: + if isinstance(c, int): + c = six.int2byte(c) + output += compressor.compress(c) + output += compressor.flush(zlib.Z_FULL_FLUSH) + + output += compressor.flush(zlib.Z_FULL_FLUSH) + output += compressor.flush(zlib.Z_FULL_FLUSH) + output += compressor.flush(zlib.Z_FINISH) + + # overwrite the original file with the compressed data + with open(fn, "wb") as h: + h.write(output) + + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + actual = list(tf_record.tf_record_iterator(fn, options=options)) + self.assertEqual(actual, original) + + def testZlibReadWrite(self): + """Verify that files produced are zlib compatible.""" + original = [b"foo", b"bar"] + fn = self._WriteRecordsToFile(original, "zlib_read_write.tfrecord") + zfn = self._ZlibCompressFile(fn, "zlib_read_write.tfrecord.z") + + # read the compressed contents and verify. + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + actual = list(tf_record.tf_record_iterator(zfn, options=options)) + self.assertEqual(actual, original) + + def testZlibReadWriteLarge(self): + """Verify that writing large contents also works.""" + + # Make it large (about 5MB) + original = [_TEXT * 10240] + fn = self._WriteRecordsToFile(original, "zlib_read_write_large.tfrecord") + zfn = self._ZlibCompressFile(fn, "zlib_read_write_large.tfrecord.z") + + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + actual = list(tf_record.tf_record_iterator(zfn, options=options)) + self.assertEqual(actual, original) + + def testGzipReadWrite(self): + """Verify that files produced are gzip compatible.""" + original = [b"foo", b"bar"] + fn = self._WriteRecordsToFile(original, "gzip_read_write.tfrecord") + gzfn = self._GzipCompressFile(fn, "tfrecord.gz") + + options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) + actual = list(tf_record.tf_record_iterator(gzfn, options=options)) + self.assertEqual(actual, original) + + +class TFRecordIteratorTest(TFCompressionTestCase): + + def setUp(self): + super(TFRecordIteratorTest, self).setUp() + self._num_records = 7 + + def testIterator(self): + records = [self._Record(0, i) for i in range(self._num_records)] + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + fn = self._WriteRecordsToFile(records, "compressed_records", options) + + reader = tf_record.tf_record_iterator(fn, options) + for expected in records: + record = next(reader) + self.assertAllEqual(expected, record) + with self.assertRaises(StopIteration): + record = next(reader) + + def testWriteZlibRead(self): + """Verify compression with TFRecordWriter is zlib library compatible.""" + original = [b"foo", b"bar"] + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + fn = self._WriteRecordsToFile(original, "write_zlib_read.tfrecord.z", + options) + + zfn = self._ZlibDecompressFile(fn, "write_zlib_read.tfrecord") + actual = list(tf_record.tf_record_iterator(zfn)) + self.assertEqual(actual, original) + + def testWriteZlibReadLarge(self): + """Verify compression for large records is zlib library compatible.""" + # Make it large (about 5MB) + original = [_TEXT * 10240] + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + fn = self._WriteRecordsToFile(original, "write_zlib_read_large.tfrecord.z", + options) + zfn = self._ZlibDecompressFile(fn, "write_zlib_read_large.tfrecord") + actual = list(tf_record.tf_record_iterator(zfn)) + self.assertEqual(actual, original) + + def testWriteGzipRead(self): + original = [b"foo", b"bar"] + options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) + fn = self._WriteRecordsToFile(original, "write_gzip_read.tfrecord.gz", + options) + + gzfn = self._GzipDecompressFile(fn, "write_gzip_read.tfrecord") + actual = list(tf_record.tf_record_iterator(gzfn)) + self.assertEqual(actual, original) + + def testBadFile(self): + """Verify that tf_record_iterator throws an exception on bad TFRecords.""" + fn = os.path.join(self.get_temp_dir(), "bad_file") + with tf_record.TFRecordWriter(fn) as writer: + writer.write(b"123") + fn_truncated = os.path.join(self.get_temp_dir(), "bad_file_truncated") + with open(fn, "rb") as f: + with open(fn_truncated, "wb") as f2: + # DataLossError requires that we've written the header, so this must + # be at least 12 bytes. + f2.write(f.read(14)) + with self.assertRaises(errors_impl.DataLossError): + for _ in tf_record.tf_record_iterator(fn_truncated): + pass + +class TFRecordWriterCloseAndFlushTests(test.TestCase): + + def setUp(self, compression_type=TFRecordCompressionType.NONE): + super(TFRecordWriterCloseAndFlushTests, self).setUp() + self._fn = os.path.join(self.get_temp_dir(), "tf_record_writer_test.txt") + self._options = tf_record.TFRecordOptions(compression_type) + self._writer = tf_record.TFRecordWriter(self._fn, self._options) + self._num_records = 20 + + def _Record(self, r): + return compat.as_bytes("Record %d" % r) + + def testWriteAndLeaveOpen(self): + records = list(map(self._Record, range(self._num_records))) + for record in records: + self._writer.write(record) + + # Verify no segfault if writer isn't explicitly closed. + + def testWriteAndRead(self): + records = list(map(self._Record, range(self._num_records))) + for record in records: + self._writer.write(record) + self._writer.close() + + actual = list(tf_record.tf_record_iterator(self._fn, self._options)) + self.assertListEqual(actual, records) + + def testDoubleClose(self): + self._writer.write(self._Record(0)) + self._writer.close() + self._writer.close() + + def testFlushAfterCloseIsError(self): + self._writer.write(self._Record(0)) + self._writer.close() + + with self.assertRaises(errors_impl.FailedPreconditionError): + self._writer.flush() + + def testWriteAfterClose(self): + self._writer.write(self._Record(0)) + self._writer.close() + + # TODO(sethtroisi): No way to know this failed, changed that. + self._writer.write(self._Record(1)) + + +class TFRecordWriterCloseAndFlushGzipTests(TFRecordWriterCloseAndFlushTests): + + def setUp(self): + super(TFRecordWriterCloseAndFlushGzipTests, + self).setUp(TFRecordCompressionType.GZIP) + + +class TFRecordWriterCloseAndFlushZlibTests(TFRecordWriterCloseAndFlushTests): + + def setUp(self): + super(TFRecordWriterCloseAndFlushZlibTests, + self).setUp(TFRecordCompressionType.ZLIB) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 3678bd4c1f6a4500622b6d9e8334cb1ebae46578..a2b5f77f915404b124b14c5d34996cd85cb2b468 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -568,7 +568,6 @@ ops.NotDifferentiable("Size") @ops.RegisterGradient("Tile") def _TileGrad(op, grad): """Sum reduces grad along the tiled dimensions.""" - assert isinstance(grad, ops.Tensor) input_shape = array_ops.shape(op.inputs[0]) # We interleave multiples and input_shape to get split_shape, # reshape grad to split_shape, and reduce along all even @@ -581,6 +580,13 @@ def _TileGrad(op, grad): split_shape = array_ops.reshape( array_ops.transpose(array_ops.stack([op.inputs[1], input_shape])), [-1]) axes = math_ops.range(0, array_ops.size(split_shape), 2) + # Sum reduces grad along the first dimension for IndexedSlices + if isinstance(grad, ops.IndexedSlices): + grad = math_ops.unsorted_segment_sum( + grad.values, + math_ops.mod(grad.indices, input_shape[0]), + input_shape[0]) + split_shape = array_ops.concat([[1], split_shape[1:]], axis=0) input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes) # Fix shape inference if not context.executing_eagerly(): @@ -784,7 +790,7 @@ def _ExtractImagePatchesGrad(op, grad): sp_mat = sparse_tensor.SparseTensor( array_ops.constant(idx, dtype=ops.dtypes.int64), - array_ops.ones((len(idx),), dtype=ops.dtypes.float32), sp_shape) + array_ops.ones((len(idx),), dtype=grad.dtype), sp_shape) jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index fae63b1132cca527c6bc5d5f9f5c8be2952d8f3c..a917f5108753807465db2c929504b5fb1ba794a2 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -41,6 +41,7 @@ from tensorflow.python.ops import gen_math_ops # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_array_ops import * +from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse # pylint: disable=unused-import from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export # pylint: enable=wildcard-import @@ -537,7 +538,7 @@ def slice(input_, begin, size, name=None): words, `begin[i]` is the offset into the 'i'th dimension of `input` that you want to slice from. - Note that @{tf.Tensor.__getitem__} is typically a more pythonic way to + Note that `tf.Tensor.__getitem__` is typically a more pythonic way to perform slices, as it allows you to write `foo[3:7, :-2]` instead of `tf.slice(foo, [3, 0], [4, foo.get_shape()[1]-2])`. @@ -593,7 +594,7 @@ def strided_slice(input_, **Instead of calling this op directly most users will want to use the NumPy-style slicing syntax (e.g. `tensor[..., 3:4:-1, tf.newaxis, 3]`), which - is supported via @{tf.Tensor.__getitem__} and @{tf.Variable.__getitem__}.** + is supported via `tf.Tensor.__getitem__` and `tf.Variable.__getitem__`.** The interface of this op is a low-level encoding of the slicing syntax. Roughly speaking, this op extracts a slice of size `(end-begin)/stride` @@ -635,10 +636,10 @@ def strided_slice(input_, `foo[:4, tf.newaxis, :2]` would produce a shape `(4, 1, 2)` tensor. If the ith bit of `shrink_axis_mask` is set, it implies that the ith - specification shrinks the dimensionality by 1. `begin[i]`, `end[i]` and - `strides[i]` must imply a slice of size 1 in the dimension. For example in - Python one might do `foo[:, 3, :]` which would result in - `shrink_axis_mask` equal to 2. + specification shrinks the dimensionality by 1, taking on the value at index + `begin[i]`. `end[i]` and `strides[i]` are ignored in this case. For example in + Python one might do `foo[:, 3, :]` which would result in `shrink_axis_mask` + equal to 2. NOTE: `begin` and `end` are zero-indexed. @@ -722,7 +723,7 @@ def _SliceHelperVar(var, slice_spec): """Creates a slice helper object given a variable. This allows creating a sub-tensor from part of the current contents - of a variable. See @{tf.Tensor.__getitem__} for detailed examples + of a variable. See `tf.Tensor.__getitem__` for detailed examples of slicing. This function in addition also allows assignment to a sliced range. @@ -2609,14 +2610,6 @@ def where(condition, x=None, y=None, name=None): raise ValueError("x and y must both be non-None or both be None.") -@tf_export("reverse") -def reverse(tensor, axis, name=None): - return gen_array_ops.reverse_v2(tensor, axis, name) - - -reverse.__doc__ = gen_array_ops.reverse_v2.__doc__ - - # pylint: disable=redefined-builtin @tf_export("reverse_sequence") @deprecation.deprecated_args( diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py index 2a2bcdd9d69b7a0aed1e7f3d3197cf6d7dd98451..f7cbfe0312fe5e7d8d75580af9f362236fd5b79d 100644 --- a/tensorflow/python/ops/boosted_trees_ops.py +++ b/tensorflow/python/ops/boosted_trees_ops.py @@ -25,6 +25,8 @@ from tensorflow.python.ops import resources # Re-exporting ops used by other modules. # pylint: disable=unused-import from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature +from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_center_bias as center_bias +from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_example_debug_outputs as example_debug_outputs from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict @@ -35,8 +37,19 @@ from tensorflow.python.training import saver class PruningMode(object): + """Class for working with Pruning modes.""" NO_PRUNING, PRE_PRUNING, POST_PRUNING = range(0, 3) + _map = {'none': NO_PRUNING, 'pre': PRE_PRUNING, 'post': POST_PRUNING} + + @classmethod + def from_str(cls, mode): + if mode in cls._map: + return cls._map[mode] + else: + raise ValueError('pruning_mode mode must be one of: {}'.format(', '.join( + sorted(cls._map)))) + class _TreeEnsembleSavable(saver.BaseSaverBuilder.SaveableObject): """SaveableObject implementation for TreeEnsemble.""" diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py index 75c459a9cf10a90f6043d304b302e0a0806bf045..78b395a6c185d2f948f78a8a19d1a8eeaa6a93f2 100644 --- a/tensorflow/python/ops/clip_ops.py +++ b/tensorflow/python/ops/clip_ops.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import numerics from tensorflow.python.util.tf_export import tf_export @@ -42,6 +43,9 @@ def clip_by_value(t, clip_value_min, clip_value_max, Any values less than `clip_value_min` are set to `clip_value_min`. Any values greater than `clip_value_max` are set to `clip_value_max`. + Note: `clip_value_min` needs to be smaller or equal to `clip_value_max` for + correct results. + Args: t: A `Tensor`. clip_value_min: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape @@ -54,7 +58,7 @@ def clip_by_value(t, clip_value_min, clip_value_max, A clipped `Tensor`. Raises: - ValueError: if the clip tensors would trigger array broadcasting + ValueError: If the clip tensors would trigger array broadcasting that would make the returned tensor larger than the input. """ with ops.name_scope(name, "clip_by_value", @@ -243,6 +247,7 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None): Raises: TypeError: If `t_list` is not a sequence. + InvalidArgumentError: If global norm is not finite. """ if (not isinstance(t_list, collections.Sequence) or isinstance(t_list, six.string_types)): @@ -250,6 +255,8 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None): t_list = list(t_list) if use_norm is None: use_norm = global_norm(t_list, name) + use_norm = numerics.verify_tensor_all_finite(use_norm, + "Found Inf or NaN global norm.") with ops.name_scope(name, "clip_by_global_norm", t_list + [clip_norm]) as name: diff --git a/tensorflow/python/ops/collective_ops.py b/tensorflow/python/ops/collective_ops.py index a05fd15eca12a423bf02dfb13044dd1f7630b99c..98668facd5bc56892fa00f258dfebcbe93c063da 100644 --- a/tensorflow/python/ops/collective_ops.py +++ b/tensorflow/python/ops/collective_ops.py @@ -22,7 +22,7 @@ from tensorflow.python.ops import gen_collective_ops def all_reduce(t, group_size, group_key, instance_key, merge_op, final_op, - subdiv_offsets=(0)): + subdiv_offsets=(0,)): """Reduces tensors collectively, across devices. Args: diff --git a/tensorflow/python/ops/collective_ops_test.py b/tensorflow/python/ops/collective_ops_test.py index 8e16cffdf4917ba361a3c313047e39af514273bc..9cc64ef9f631faf2f76c3dbb3e70e1f37bbe4b1a 100644 --- a/tensorflow/python/ops/collective_ops_test.py +++ b/tensorflow/python/ops/collective_ops_test.py @@ -37,11 +37,11 @@ class CollectiveOpTest(test.TestCase): with ops.device('/CPU:0'): in0 = constant_op.constant(t0) colred0 = collective_ops.all_reduce(in0, 2, group_key, instance_key, - 'Add', 'Div', [0]) + 'Add', 'Div') with ops.device('/CPU:1'): in1 = constant_op.constant(t1) colred1 = collective_ops.all_reduce(in1, 2, group_key, instance_key, - 'Add', 'Div', [0]) + 'Add', 'Div') run_options = config_pb2.RunOptions() run_options.experimental.collective_graph_key = 1 results = sess.run([colred0, colred1], options=run_options) diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..76173e0f309b80402a15acdab5d2af49f35de741 --- /dev/null +++ b/tensorflow/python/ops/cond_v2.py @@ -0,0 +1,32 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""cond_v2 wrapper module. + +This imports the cond_v2 method and all necessary dependencies (this is to avoid +circular dependencies in the cond_v2 implementation). See cond_v2_impl for more +information. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import +from tensorflow.python.framework import function +from tensorflow.python.framework import function_def_to_graph +from tensorflow.python.ops import gradients_impl + +from tensorflow.python.ops.cond_v2_impl import cond_v2 +# pylint: enable=unused-import diff --git a/tensorflow/contrib/control_flow/python/cond_v2.py b/tensorflow/python/ops/cond_v2_impl.py similarity index 62% rename from tensorflow/contrib/control_flow/python/cond_v2.py rename to tensorflow/python/ops/cond_v2_impl.py index 9ffad9caa92d2d3be8f598758a443b0eceb8d4d8..b3dacff6d66bfabc624f05cf698fd882efd72e74 100644 --- a/tensorflow/contrib/control_flow/python/cond_v2.py +++ b/tensorflow/python/ops/cond_v2_impl.py @@ -17,23 +17,32 @@ This is a version of cond that emits a single If op, as well as the gradient function for If ops produced by cond_v2. This will eventually replace the current tf.cond implementation once it reaches feature and performance parity. + +NOTE: most users of cond_v2 should import cond_v2, not this module! This module +does not contain all the necessary imports to prevent circular dependencies, +while cond_v2 does. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.core.framework import attr_value_pb2 from tensorflow.python import pywrap_tensorflow as c_api from tensorflow.python.framework import c_api_util -from tensorflow.python.framework import function -from tensorflow.python.framework import function_def_to_graph from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import gen_functional_ops -from tensorflow.python.ops import gradients_impl from tensorflow.python.util import compat +# The following modules cannot be imported directly because they cause circular +# dependencies. These are set in each corresponding module. +_function = None +_function_def_to_graph = None +_gradients_impl = None + # NOTE(skyewm): TensorFlow uses protected class methods and fields to signify # that they aren't part of the official public API. These protected members # often need to be used by implementation code however. Rather than litter the @@ -44,11 +53,43 @@ from tensorflow.python.util import compat def cond_v2(pred, true_fn, false_fn, name="cond"): """Like tf.cond, except emits a single If op.""" + if not name: + name = "cond" + with ops.name_scope(name) as scope: - true_graph = function.func_graph_from_py_func(true_fn, [], [], - name="%s_true" % scope) - false_graph = function.func_graph_from_py_func(false_fn, [], [], - name="%s_false" % scope) + # Identify if there is a caller device, & get the innermost if possible. + # pylint: disable=protected-access + device_funcs = ops.get_default_graph()._device_functions_outer_to_inner + caller_device = device_funcs[-1] if device_funcs else None + + caller_colocation_stack = ops.get_default_graph()._colocation_stack + caller_container = ops.get_default_graph()._container + caller_collection_ref = ops.get_default_graph()._collections + + with ops.name_scope(None): + # Find the outer most graph for uniquing function names. + # TODO(jpienaar): Make this work in eager mode. + graph = ops.get_default_graph() + while isinstance(graph, _function._FuncGraph): + graph = graph._outer_graph + + true_name = graph.unique_name(("%strue" % scope).replace("/", "_")) + false_name = graph.unique_name(("%sfalse" % scope).replace("/", "_")) + # pylint: enable=protected-access + true_graph = _function.func_graph_from_py_func( + true_fn, [], [], + name=true_name, + device=caller_device, + colocation_stack=caller_colocation_stack, + collections_ref=caller_collection_ref, + container=caller_container) + false_graph = _function.func_graph_from_py_func( + false_fn, [], [], + name=false_name, + device=caller_device, + colocation_stack=caller_colocation_stack, + collections_ref=caller_collection_ref, + container=caller_container) _check_same_outputs(true_graph, false_graph) # Add inputs to true_graph and false_graph to make them match. Note that @@ -74,19 +115,41 @@ def cond_v2(pred, true_fn, false_fn, name="cond"): false_graph.outputs.extend(extra_false_outputs) # Create the If op. - tensors = gen_functional_ops._if( + tensors = gen_functional_ops._if( # pylint: disable=protected-access pred, cond_inputs, [t.dtype for t in true_graph.outputs], _create_new_tf_function(true_graph), _create_new_tf_function(false_graph), name=scope) - return tensors[:num_cond_outputs] + # Set the flag to enable lowering on the `if` op if necessary + # Lowering allows cond_v2 to avoid some of the limitations of Functions, + # allowing users to specify devices & colocation inside of cond_v2 branches, + # and enabling non-strict evaluation & partial pruning of cond_v2 branches. + # This brings cond_v2 closer to feature parity with tf.cond. + # + # However, we do not lower `If` in the XLA context because it is easier for + # XLA to apply its own optimizations when dealing with un-lowered `If` + # operators than with lowered switch/merge control flow. + # + # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output + if_op = tensors[0].op + if not control_flow_util.IsInXLAContext(if_op): + # pylint: disable=protected-access + if_op._set_attr("_lower_using_switch_merge", + attr_value_pb2.AttrValue(b=True)) + # pylint: enable=protected-access + + return tuple(tensors[:num_cond_outputs]) @ops.RegisterGradient("If") def _IfGrad(op, *grads): # pylint: disable=invalid-name """The gradient of an If op produced by cond_v2.""" true_graph, false_graph = _get_func_graphs(op) + # Note: op.graph != ops.get_default_graph() when we are computing the gradient + # of a nested cond. + assert true_graph._outer_graph == op.graph + assert false_graph._outer_graph == op.graph # Create grad functions that compute the gradient of the true/false forward # graphs. These functions will capture tensors from the forward pass @@ -99,15 +162,16 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name assert ([t.dtype for t in true_grad_graph.outputs] == [t.dtype for t in false_grad_graph.outputs]) - # Match up the captured grad function inputs with outputs of 'op' and other - # external tensors. - true_grad_inputs = _get_grad_inputs(op, true_graph, true_grad_graph) - false_grad_inputs = _get_grad_inputs(op, false_graph, false_grad_graph) + # Resolve references to forward graph tensors in grad graphs and ensure + # they are in-scope, i.e., belong to one of outer graphs of the grad graph. + true_grad_extra_inputs = _resolve_grad_inputs(true_graph, true_grad_graph) + false_grad_extra_inputs = _resolve_grad_inputs(false_graph, false_grad_graph) # Make the inputs to true_grad_graph and false_grad_graph match. Note that # this modifies true_grad_graph and false_grad_graph. grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph, - true_grad_inputs, false_grad_inputs) + true_grad_extra_inputs, + false_grad_extra_inputs) # Add all intermediate tensors as function outputs so they're available for # higher-order gradient computations. @@ -146,14 +210,25 @@ def _get_func_graphs(if_op): A 2-tuple of the `_FuncGraph`s of the then_branch and else_branch. """ def _get_func_graph_for_branch(branch_name): + """Generates and returns a _FuncGraph for the given branch.""" extra_inputs = if_op.inputs[1:] # First input is pred. input_shapes = [t.shape for t in extra_inputs] func_name = if_op.get_attr(branch_name).name fdef = if_op.graph._get_function(func_name).definition - func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes) + # `if_op.graph` may not be the same as `ops.get_default_graph()` e.g. + # in the case of nested if ops or when the gradient is being computed + # from inside a Defun. We build the `func_graph` with `if_op.graph` as its + # `outer_graph`. This resembles how the `_FuncGraph` was built in the + # forward pass. We need this so that we can resolve references to tensors + # in `func_graph` from its gradient graph in `_resolve_grad_inputs`. + with if_op.graph.as_default(): + func_graph = _function_def_to_graph.function_def_to_graph( + fdef, input_shapes) func_graph.extra_inputs = extra_inputs func_graph.extra_args = func_graph.inputs func_graph._captured = dict(zip(extra_inputs, func_graph.inputs)) + # Set the if op so that the gradient code can use it. + func_graph._if = if_op return func_graph return (_get_func_graph_for_branch("then_branch"), @@ -182,7 +257,7 @@ def _grad_fn(func_graph, grads): ys = [] grad_ys = [] for y, grad_y in zip(func_graph.outputs, grads): - if not gradients_impl._IsTrainable(y): + if not _gradients_impl._IsTrainable(y): continue ys.append(y) grad_ys.append(grad_y) @@ -190,8 +265,8 @@ def _grad_fn(func_graph, grads): # Build the gradient graph. Note that this builds the gradient computation of # func_graph in the current graph, which requires capturing tensors from # func_graph. The captured func_graph tensors are resolved to external tensors - # in _get_grad_inputs. - result = gradients_impl._GradientsHelper( + # in _resolve_grad_inputs. + result = _gradients_impl._GradientsHelper( ys, func_graph.inputs, grad_ys=grad_ys, src_graph=func_graph) @@ -207,47 +282,53 @@ def _grad_fn(func_graph, grads): def _create_grad_func(func_graph, grads, name): """Returns the _FuncGraph representation of _grad_fn.""" - return function.func_graph_from_py_func(lambda: _grad_fn(func_graph, grads), - [], [], name) + return _function.func_graph_from_py_func(lambda: _grad_fn(func_graph, grads), + [], [], name) -def _get_grad_inputs(if_op, cond_graph, grad_graph): - """Returns the tensors we should pass to grad_graph. +def _resolve_grad_inputs(cond_graph, grad_graph): + """Returns the tensors to pass as `extra_inputs` to `grad_graph`. - This method handles tensors captured from cond_graph in grad_graph. It - converts these to suitable input tensors from the outer graph. + The `grad_graph` may have external references to + 1. Its outer graph containing the input gradients. These references are kept + as is. + 2. Tensors in the forward pass graph. These tensors may not be "live" + when the gradient is being computed. We replace such references by their + corresponding tensor in the least common ancestor graph of `grad_graph` and + `cond_graph`. Since we export intermediate tensors for all branch + functions, this is always possible. Args: - if_op: Operation. The forward-pass If op that uses cond_graph. cond_graph: function._FuncGraph. The forward-pass function. grad_graph: function._FuncGraph. The gradients function. Returns: A list of inputs tensors to be passed to grad_graph. """ - inputs = [] - - # Maps placeholders in cond_graph -> input tensor in outer graph. - forward_input_map = {v: k for k, v in cond_graph._captured.items()} + new_extra_inputs = [] for t in grad_graph.extra_inputs: - if t.graph == ops.get_default_graph(): - # t is in the outer graph (e.g. one of the input gradients). - inputs.append(t) - elif t in forward_input_map: - # t is an input placeholder in cond_graph. Get the corresponding input - # tensor in the outer graph. - assert t.graph == cond_graph - assert forward_input_map[t].graph == ops.get_default_graph() - inputs.append(forward_input_map[t]) - else: - # t is an intermediate value in cond_graph. Get the corresponding output - # of 'if_op' (note that all intermediate values are outputs). - assert t.graph == cond_graph - output_idx = cond_graph.outputs.index(t) - inputs.append(if_op.outputs[output_idx]) - - return inputs + if t.graph != grad_graph._outer_graph: + # `t` is a tensor in `cond_graph` or one of its ancestors. We bubble this + # tensor to the least common ancestor of the `cond_graph` and + # `grad_graph` so that it is "in-scope" for `grad_graph`. + # TODO(srbs): `_is_ancestor` calls may be expensive. Compute the least + # common ancestor once and re-use. + assert _is_ancestor(cond_graph, t.graph) + while not _is_ancestor(grad_graph, t.graph): + assert isinstance(t.graph, _function._FuncGraph) + if t in t.graph.extra_args: + # TODO(srbs): Consider building a map of extra_args -> extra_inputs. + # instead of searching for `t` twice. + t = t.graph.extra_inputs[t.graph.extra_args.index(t)] + else: + # Note: All intermediate tensors are output by the If op. + # TODO(srbs): .index() calls may be expensive. Optimize. + t = t.graph._if.outputs[t.graph.outputs.index(t)] + assert _is_ancestor(grad_graph, t.graph) + new_extra_inputs.append(t) + + return new_extra_inputs def _create_new_tf_function(func_graph): @@ -274,9 +355,10 @@ def _create_new_tf_function(func_graph): # TODO(b/109833212): this sucks, we're serializing the TF_Function*, # deserializing it into a Python FunctionDef, then reserializing it to create # a new TF_Function that we add to the graph. - fdef = function.function_def_from_tf_function(c_func) - defined_func = function._from_definition(fdef) - defined_func.add_to_graph(ops.get_default_graph()) + fdef = _function.function_def_from_tf_function(c_func) + defined_func = _function._from_definition(fdef) + defined_func._sub_functions = func_graph._functions + defined_func.add_to_graph(func_graph._outer_graph) return func_graph.name @@ -339,7 +421,8 @@ def _pad_params(true_graph, false_graph, true_params, false_params): return new_true_params, new_false_inputs -def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs): +def _make_inputs_match(true_graph, false_graph, true_extra_inputs, + false_extra_inputs): """Modifies true_graph and false_graph so they have the same input signature. This method reorders and/or adds parameters to true_graph and false_graph so @@ -350,9 +433,9 @@ def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs): Args: true_graph: function._FuncGraph false_graph: function._FuncGraph - true_inputs: a list of Tensors in the outer graph. The inputs for + true_extra_inputs: a list of Tensors in the outer graph. The inputs for true_graph. - false_inputs: a list of Tensors in the outer graph. The inputs for + false_extra_inputs: a list of Tensors in the outer graph. The inputs for false_graph. Returns: @@ -361,12 +444,12 @@ def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs): false_inputs. """ shared_inputs, true_only_inputs, false_only_inputs = _separate_unique_inputs( - true_inputs, false_inputs) + true_extra_inputs, false_extra_inputs) new_inputs = shared_inputs + true_only_inputs + false_only_inputs - true_input_to_param = dict(zip(true_inputs, true_graph.inputs)) - false_input_to_param = dict(zip(false_inputs, false_graph.inputs)) + true_input_to_param = dict(zip(true_extra_inputs, true_graph.inputs)) + false_input_to_param = dict(zip(false_extra_inputs, false_graph.inputs)) true_graph.inputs = ( [true_input_to_param[t] for t in shared_inputs] + @@ -382,6 +465,9 @@ def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs): true_graph.extra_inputs = new_inputs false_graph.extra_inputs = new_inputs + true_graph.extra_args = true_graph.inputs + false_graph.extra_args = false_graph.inputs + true_graph._captured = dict(zip(new_inputs, true_graph.inputs)) false_graph._captured = dict(zip(new_inputs, false_graph.inputs)) @@ -404,14 +490,30 @@ def _create_dummy_params(func_graph, template_tensors): def _get_grad_fn_name(func_graph): - """Returns a unique name to use for the grad function of `func_graph`.""" + """Returns a unique name to use for the grad function of `func_graph`. + + Ensures this name is unique in the entire hierarchy. + + Args: + func_graph: The _FuncGraph. + + Returns: + A string, the name to use for the gradient function. + """ name = "%s_grad" % func_graph.name base_name = name counter = 1 - if ops.get_default_graph()._is_function(name): - name = "%s_%s" % (base_name, counter) - counter += 1 + has_conflict = True + while has_conflict: + curr_graph = func_graph._outer_graph + has_conflict = curr_graph._is_function(name) + while not has_conflict and isinstance(curr_graph, _function._FuncGraph): + curr_graph = curr_graph._outer_graph + has_conflict = curr_graph._is_function(name) + if has_conflict: + name = "%s_%s" % (base_name, counter) + counter += 1 return name @@ -427,3 +529,11 @@ def _check_same_outputs(true_graph, false_graph): "arguments, got:\n" " true_fn: %s\n" " false_fn: %s" % (true_output_types, false_output_types)) + + +def _is_ancestor(graph, maybe_ancestor): + if maybe_ancestor == graph: + return True + if isinstance(graph, _function._FuncGraph): + return _is_ancestor(graph._outer_graph, maybe_ancestor) + return False diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 2e5a801f8e96aa1266695a1440d98e6bff53607c..f84ff4ddf03864030013a14be58bafd3a1817882 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -24,6 +24,7 @@ from __future__ import print_function import abc import collections import functools +import os import six @@ -38,6 +39,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import cond_v2_impl from tensorflow.python.ops import control_flow_util as util from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_control_flow_ops @@ -57,6 +59,10 @@ from tensorflow.python.util import nest from tensorflow.python.util import tf_should_use from tensorflow.python.util.tf_export import tf_export + +_ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0" + + # We override the 'tuple' for a control flow op, so we keep python's # existing 'tuple' for later use in this module. _basetuple = tuple @@ -596,7 +602,6 @@ def _EnforceShapeInvariant(merge_var, next_var): enter = merge_var.op.inputs[0].op assert util.IsLoopEnter(enter) input_t = enter.inputs[0] - assert input_t.shape == m_shape raise ValueError( "Input tensor '%s' enters the loop with shape %s, but has shape %s " "after one iteration. To allow the shape to vary across iterations, " @@ -812,11 +817,12 @@ class GradLoopState(object): outer_forward_ctxt = forward_ctxt.outer_context # Add the forward loop counter. - if outer_forward_ctxt: - outer_forward_ctxt.Enter() - cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state) - if outer_forward_ctxt: - outer_forward_ctxt.Exit() + with forward_ctxt._graph.as_default(): # pylint: disable=protected-access + if outer_forward_ctxt: + outer_forward_ctxt.Enter() + cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state) + if outer_forward_ctxt: + outer_forward_ctxt.Exit() self._forward_context = forward_ctxt self._forward_index = forward_index @@ -979,60 +985,61 @@ class GradLoopState(object): for the stack can't be found. """ # curr_ctxt is the context that tf.gradients was called in. - curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access - with ops.control_dependencies(None): - if curr_ctxt: - curr_ctxt.Enter() - with ops.colocate_with(value): - # We only need to pass maximum_iterations to the stack if - # we're inside an XLA context. - if not util.IsInXLAContext(value.op): - max_size = constant_op.constant(-1, dtypes.int32) - else: - max_size = GetMaxSizeFromNestedMaximumIterations( - value, self.forward_context) - acc = gen_data_flow_ops.stack_v2( - max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc") - if curr_ctxt: - curr_ctxt.Exit() - - # Make acc available in the forward context. - enter_acc = self.forward_context.AddValue(acc) - - # Add the stack_push op in the context of value.op. - swap_enabled = self.forward_context.swap_memory - value_ctxt = util.GetOutputContext(value.op) - if value_ctxt == self.forward_context: - # value is not nested in the forward context. - self.forward_context.Enter() - push = gen_data_flow_ops.stack_push_v2( - enter_acc, value, swap_memory=swap_enabled) - self.forward_context.Exit() - # Protect stack push and order it before forward_index. - self.forward_index.op._add_control_input(push.op) - else: - # value is in a cond context within the forward context. - if not isinstance(value_ctxt, CondContext): - raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt) - if dead_branch: - # The special case for creating a zero tensor for a dead - # branch of a switch. See ControlFlowState.ZerosLike(). - value_ctxt.outer_context.Enter() + with self._forward_index.graph.as_default(): + curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access + with ops.control_dependencies(None): + if curr_ctxt: + curr_ctxt.Enter() + with ops.colocate_with(value): + # We only need to pass maximum_iterations to the stack if + # we're inside an XLA context. + if not util.IsInXLAContext(value.op): + max_size = constant_op.constant(-1, dtypes.int32) + else: + max_size = GetMaxSizeFromNestedMaximumIterations( + value, self.forward_context) + acc = gen_data_flow_ops.stack_v2( + max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc") + if curr_ctxt: + curr_ctxt.Exit() + + # Make acc available in the forward context. + enter_acc = self.forward_context.AddValue(acc) + + # Add the stack_push op in the context of value.op. + swap_enabled = self.forward_context.swap_memory + value_ctxt = util.GetOutputContext(value.op) + if value_ctxt == self.forward_context: + # value is not nested in the forward context. + self.forward_context.Enter() push = gen_data_flow_ops.stack_push_v2( enter_acc, value, swap_memory=swap_enabled) - value_ctxt.outer_context.Exit() - push.op._set_control_flow_context(value_ctxt) + self.forward_context.Exit() + # Protect stack push and order it before forward_index. + self.forward_index.op._add_control_input(push.op) else: - value_ctxt.Enter() - push = gen_data_flow_ops.stack_push_v2( - enter_acc, value, swap_memory=swap_enabled) - value_ctxt.Exit() - # Protect stack push and order it before forward_sync. - self.forward_sync._add_control_input(push.op) - # Order stack push after the successor of forward_index - add_op = self.forward_index.op.inputs[0].op - push.op._add_control_input(add_op) - return acc + # value is in a cond context within the forward context. + if not isinstance(value_ctxt, CondContext): + raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt) + if dead_branch: + # The special case for creating a zero tensor for a dead + # branch of a switch. See ControlFlowState.ZerosLike(). + value_ctxt.outer_context.Enter() + push = gen_data_flow_ops.stack_push_v2( + enter_acc, value, swap_memory=swap_enabled) + value_ctxt.outer_context.Exit() + push.op._set_control_flow_context(value_ctxt) + else: + value_ctxt.Enter() + push = gen_data_flow_ops.stack_push_v2( + enter_acc, value, swap_memory=swap_enabled) + value_ctxt.Exit() + # Protect stack push and order it before forward_sync. + self.forward_sync._add_control_input(push.op) + # Order stack push after the successor of forward_index + add_op = self.forward_index.op.inputs[0].op + push.op._add_control_input(add_op) + return acc def AddBackpropAccumulatedValue(self, history_value, value, dead_branch=False): @@ -1442,14 +1449,17 @@ def ZerosLikeOutsideLoop(op, index): pred = op_ctxt.pred branch = op_ctxt.branch switch_val = switch(op.inputs[0], pred)[1 - branch] + # A op is created along the branch taken as control dependencies are on + # the whole op and not on the tensor output. + pivot = array_ops.identity(switch_val) if val.dtype == dtypes.resource: - with ops.control_dependencies([switch_val]): + with ops.control_dependencies([pivot]): return array_ops.zeros( gen_resource_variable_ops.variable_shape(switch_val)) zeros_shape = array_ops.shape_internal(switch_val, optimize=False) # Ensure ops created within array_ops.zeros are dominated by switch in # cond context. - with ops.control_dependencies([switch_val]): + with ops.control_dependencies([pivot]): return array_ops.zeros(zeros_shape, dtype=val.dtype) else: return array_ops.zeros_like(val, optimize=False) @@ -1812,15 +1822,34 @@ class CondContext(ControlFlowContext): def _AddOpInternal(self, op): """Add `op` to the current context.""" if not op.inputs: - # Remove any external control dependency on this op + # If we're in a while loop, remove any control inputs from outside the + # loop. self._RemoveExternalControlEdges(op) - # pylint: disable=protected-access - op._add_control_input(self._pivot.op) - # pylint: enable=protected-access + + if not any(util.OpInContext(input_op, self) + for input_op in op.control_inputs): + # pylint: disable=protected-access + op._add_control_input(self._pivot.op) + # pylint: enable=protected-access else: + # Make each input to 'op' available in this CondContext. If an input is + # already part of this context there's nothing to do, but if it's + # external, AddValue() will handle adding the appropriate Switch node and + # other bookkeeping. for index in range(len(op.inputs)): x = op.inputs[index] - real_x = self.AddValue(x) + if op.type == "Merge" and x.op.type == "NextIteration": + # Edge case: if we're importing a while loop inside this CondContext, + # AddValue() will not correctly handle the NextIteration inputs to + # Merge node. The problem is that the NextIteration should also be + # part of this context, but if we're importing it won't have been + # processed and added to the context yet, so AddValue() will try to + # add a Switch which results in an invalid graph. Instead, we use the + # NextIteration input as-is here, and it will eventually be added to + # the context via AddOp(). + real_x = x + else: + real_x = self.AddValue(x) if real_x != x: # pylint: disable=protected-access op._update_input(index, real_x) @@ -1994,6 +2023,9 @@ def cond(pred, ``` """ + if _ENABLE_COND_V2: + return cond_v2_impl.cond_v2(pred, true_fn, false_fn, name) + # We needed to make true_fn/false_fn keyword arguments for # backwards-compatibility. This check exists so that we can convert back to # having them be positional arguments. @@ -2036,21 +2068,25 @@ def cond(pred, # Build the graph for the true branch in a new context. context_t = CondContext(pred, pivot_1, branch=1) - context_t.Enter() - orig_res_t, res_t = context_t.BuildCondBranch(true_fn) - if orig_res_t is None: - raise ValueError("true_fn must have a return value.") - context_t.ExitResult(res_t) - context_t.Exit() + try: + context_t.Enter() + orig_res_t, res_t = context_t.BuildCondBranch(true_fn) + if orig_res_t is None: + raise ValueError("true_fn must have a return value.") + context_t.ExitResult(res_t) + finally: + context_t.Exit() # Build the graph for the false branch in a new context. context_f = CondContext(pred, pivot_2, branch=0) - context_f.Enter() - orig_res_f, res_f = context_f.BuildCondBranch(false_fn) - if orig_res_f is None: - raise ValueError("false_fn must have a return value.") - context_f.ExitResult(res_f) - context_f.Exit() + try: + context_f.Enter() + orig_res_f, res_f = context_f.BuildCondBranch(false_fn) + if orig_res_f is None: + raise ValueError("false_fn must have a return value.") + context_f.ExitResult(res_f) + finally: + context_f.Exit() if not strict: orig_res_t = _UnpackIfSingleton(orig_res_t) @@ -2188,6 +2224,7 @@ class WhileContext(ControlFlowContext): self._loop_exits = [] # The list of enter tensors for loop variables. self._loop_enters = [] + self._graph = ops.get_default_graph() def _init_from_proto(self, context_def, import_scope=None): """Creates a new `WhileContext` from protocol buffer. @@ -2241,6 +2278,7 @@ class WhileContext(ControlFlowContext): op._set_attr("frame_name", attr_value_pb2.AttrValue(s=compat.as_bytes(self.name))) # pylint: enable=protected-access + self._graph = ops.get_default_graph() @property def maximum_iterations(self): @@ -2565,7 +2603,14 @@ class WhileContext(ControlFlowContext): Returns: The loop index. """ - one = constant_op.constant(1, name="b_count") + in_separate_functions = count.graph is not ops.get_default_graph() + if in_separate_functions: + # Brings the count into this graph + count = array_ops.identity(count) + else: + # TODO(apassos) XLA expects this constant to be created outside the loop, + # so doing that for now. + one = constant_op.constant(1, name="b_count") self.Enter() self.AddName(count.name) @@ -2580,6 +2625,8 @@ class WhileContext(ControlFlowContext): merge_count = merge([enter_count, enter_count])[0] self._pivot_for_pred = merge_count + if in_separate_functions: + one = constant_op.constant(1, name="b_count") pred = math_ops.greater_equal(merge_count, one) self._pivot = loop_cond(pred, name="b_count") switch_count = switch(merge_count, self._pivot) @@ -2924,7 +2971,8 @@ class WhileContext(ControlFlowContext): return original_body_result, exit_vars - def BuildLoop(self, pred, body, loop_vars, shape_invariants): + def BuildLoop(self, pred, body, loop_vars, shape_invariants, + return_same_structure): """Add the loop termination condition and body to the graph.""" # Keep original_loop_vars to identify which are TensorArrays @@ -2935,9 +2983,10 @@ class WhileContext(ControlFlowContext): loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars) try: self.Enter() - # _BuildLoop calls _update_input in several places. _lock ensures a - # Session.run call cannot occur between creating and mutating new ops. - with ops.get_default_graph()._lock: # pylint: disable=protected-access + # _BuildLoop calls _update_input in several places. _mutation_lock() + # ensures a Session.run call cannot occur between creating and mutating + # new ops. + with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access original_body_result, exit_vars = self._BuildLoop( pred, body, original_loop_vars, loop_vars, shape_invariants) finally: @@ -2951,7 +3000,11 @@ class WhileContext(ControlFlowContext): packed_exit_vars = nest.pack_sequence_as( structure=original_body_result, flat_sequence=exit_vars_with_tensor_arrays) - return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars + + if return_same_structure: + return packed_exit_vars + else: + return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars def _FixControlInputsAndContext(self, enters): graph = ops.get_default_graph() @@ -2991,7 +3044,8 @@ def while_loop(cond, back_prop=True, swap_memory=False, name=None, - maximum_iterations=None): + maximum_iterations=None, + return_same_structure=False): """Repeat `body` while the condition `cond` is true. `cond` is a callable returning a boolean scalar tensor. `body` is a callable @@ -3022,7 +3076,7 @@ def while_loop(cond, `loop_vars` is the same in every iteration. The `shape_invariants` argument allows the caller to specify a less specific shape invariant for each loop variable, which is needed if the shape varies between iterations. The - @{tf.Tensor.set_shape} + `tf.Tensor.set_shape` function may also be used in the `body` function to indicate that the output loop variable has a particular shape. The shape invariant for SparseTensor and IndexedSlices are treated specially as follows: @@ -3067,11 +3121,16 @@ def while_loop(cond, to run. If provided, the `cond` output is AND-ed with an additional condition ensuring the number of iterations executed is no greater than `maximum_iterations`. + return_same_structure: If True, output has same structure as `loop_vars`. If + eager execution is enabled, this is ignored (and always treated as True). Returns: - The output tensors for the loop variables after the loop. When the length - of `loop_vars` is 1 this is a Tensor, TensorArray or IndexedSlice and when - the length of `loop_vars` is greater than 1 it returns a list. + The output tensors for the loop variables after the loop. + If `return_same_structure` is True, the return value has the same + structure as `loop_vars`. + If `return_same_structure` is False, the return value is a Tensor, + TensorArray or IndexedSlice if the length of `loop_vars` is 1, or a list + otherwise. Raises: TypeError: if `cond` or `body` is not callable. @@ -3126,6 +3185,7 @@ def while_loop(cond, happen is that the thread updating `x` can never get ahead of the counter thread because the thread incrementing `x` depends on the value of the counter. + ```python import tensorflow as tf @@ -3207,7 +3267,8 @@ def while_loop(cond, # be encapsulated in the root context. if loop_context.outer_context is None: ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context) - result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants) + result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, + return_same_structure) if maximum_iterations is not None: return result[1] else: @@ -3266,7 +3327,7 @@ def with_dependencies(dependencies, output_tensor, name=None): no guarantee that `output_tensor` will be evaluated after any `dependencies` have run. - See also @{tf.tuple$tuple} and @{tf.group$group}. + See also `tf.tuple` and `tf.group`. Args: dependencies: Iterable of operations to run before this op finishes. @@ -3311,8 +3372,8 @@ def group(*inputs, **kwargs): When this op finishes, all ops in `inputs` have finished. This op has no output. - See also @{tf.tuple$tuple} and - @{tf.control_dependencies$control_dependencies}. + See also `tf.tuple` and + `tf.control_dependencies`. Args: *inputs: Zero or more tensors to group. @@ -3340,12 +3401,6 @@ def group(*inputs, **kwargs): if not hasattr(inp, "device"): raise TypeError("Expected tf.group() expected Tensor arguments not " "'%s' with type '%s'" % (inp, type(inp))) - if not hasattr(inp, "device"): - if isinstance(inp, list): - raise TypeError("To call tf.group() with a list, use " - "tf.group(*[...]) not tf.group([...]).") - raise TypeError("Expected tf.group() expected Tensor arguments not " - "'%s' with type '%s'" % (inp, type(inp))) dev = inp.device if dev in ops_on_device: ops_on_device[dev].append(inp) @@ -3387,8 +3442,8 @@ def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined returned by `tuple` are only available after all the parallel computations are done. - See also @{tf.group$group} and - @{tf.control_dependencies$control_dependencies}. + See also `tf.group` and + `tf.control_dependencies`. Args: tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`. diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index 59bb925df0f25b3bf88112bc3eb1b13b21ace414..153548ae92cfecfe5c750746b1425abcf3747b1b 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -939,7 +939,7 @@ class CaseTest(test_util.TensorFlowTestCase): class WhileLoopTestCase(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testWhileLoopWithSingleVariable(self): i = constant_op.constant(0) c = lambda i: math_ops.less(i, 10) @@ -948,7 +948,7 @@ class WhileLoopTestCase(test_util.TensorFlowTestCase): self.assertEqual(self.evaluate(r), 10) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testEagerWhileLoopWithSingleVariable_bodyReturnsTuple(self): i = constant_op.constant(0) c = lambda i: math_ops.less(i, 10) @@ -958,6 +958,28 @@ class WhileLoopTestCase(test_util.TensorFlowTestCase): # Expect a tuple since that is what the body returns. self.assertEqual(self.evaluate(r), (10,)) + def testWhileLoopSameReturnShape_False(self): + i = constant_op.constant(0) + c = lambda i, _: math_ops.less(i, 10) + + # Body returns a [tensor, []] + b = lambda i, _: [math_ops.add(i, 1), []] + + # Should only return the tensor. + r = control_flow_ops.while_loop(c, b, [i, []]) + self.assertEqual(self.evaluate(r), 10) + + def testWhileLoopSameReturnShape_True(self): + i = constant_op.constant(0) + c = lambda i, _: math_ops.less(i, 10) + + # Body returns a [tensor, []] + b = lambda i, _: [math_ops.add(i, 1), []] + + # Should only return the original structure. + r = control_flow_ops.while_loop(c, b, [i, []], return_same_structure=True) + self.assertEqual(self.evaluate(r), [10, []]) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/ops/control_flow_util.py b/tensorflow/python/ops/control_flow_util.py index 7a18986c5b03446d2b0e0a2ecd161dccdc3d70e1..72c074ed1af208da274edd52572961ecaa613b34 100644 --- a/tensorflow/python/ops/control_flow_util.py +++ b/tensorflow/python/ops/control_flow_util.py @@ -214,6 +214,14 @@ def IsContainingContext(ctxt, maybe_containing_ctxt): return True +def OpInContext(op, ctxt): + return IsContainingContext(op._get_control_flow_context(), ctxt) # pylint: disable=protected-access + + +def TensorInContext(tensor, ctxt): + return OpInContext(tensor.op, ctxt) + + def CheckInputFromValidContext(op, input_op): """Returns whether `input_op` can be used from `op`s context. diff --git a/tensorflow/python/ops/conv2d_benchmark.py b/tensorflow/python/ops/conv2d_benchmark.py index 907df85cd954d2a897ba9a0c4b21be8586859380..28111c273059bca3c4cc643b4aa826f9be402308 100644 --- a/tensorflow/python/ops/conv2d_benchmark.py +++ b/tensorflow/python/ops/conv2d_benchmark.py @@ -21,6 +21,8 @@ from __future__ import print_function import itertools import time +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session as session_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -28,22 +30,32 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables +from tensorflow.python.platform import flags from tensorflow.python.platform import test +FLAGS = flags.FLAGS -def build_graph(device, input_shape, filter_shape, strides, padding, dtype, - num_iters, warmup_iters): +flags.DEFINE_boolean( + "enable_layout_optimizer", False, + "If true, enables layout optimizer to update input data format for faster " + "execution of convolution ops.") + + +def build_graph(device, dtype, data_format, input_shape, filter_shape, strides, + padding, num_iters, warmup_iters): """builds a graph containing a sequence of conv2d operations. Args: device: String, the device to run on. + dtype: Data type for the convolution. + data_format: A string from: "NHWC" or "NCHW". Data format for input and + output data. input_shape: Shape of the input tensor. filter_shape: Shape of the filter tensor. strides: A list of ints. 1-D of length 4. The stride of sliding window for each dimension of input. padding: A string from: "SAME", "VALID". The type of padding algorithm to use. - dtype: Data type for the convolution. num_iters: number of iterations to run conv2d. warmup_iters: number of iterations for warmup runs. @@ -57,22 +69,23 @@ def build_graph(device, input_shape, filter_shape, strides, padding, dtype, random_ops.truncated_normal(filter_shape, dtype=dtype)) outputs = [] - conv2d_op = nn_ops.conv2d(inp, filt, strides, padding, data_format="NHWC") + conv2d_op = nn_ops.conv2d( + inp, filt, strides, padding, data_format=data_format) outputs.append(conv2d_op) for _ in range(1, num_iters): with ops.control_dependencies([conv2d_op]): conv2d_op = nn_ops.conv2d( - inp, filt, strides, padding, data_format="NHWC") + inp, filt, strides, padding, data_format=data_format) outputs.append(conv2d_op) warmup_groups = [] warmup_conv2d_op = nn_ops.conv2d( - inp, filt, strides, padding, data_format="NHWC") + inp, filt, strides, padding, data_format=data_format) warmup_groups.append(warmup_conv2d_op) for _ in range(1, warmup_iters): with ops.control_dependencies([warmup_conv2d_op]): warmup_conv2d_op = nn_ops.conv2d( - inp, filt, strides, padding, data_format="NHWC") + inp, filt, strides, padding, data_format=data_format) warmup_groups.append(warmup_conv2d_op) return control_flow_ops.group(*warmup_groups), control_flow_ops.group( *outputs) @@ -81,12 +94,15 @@ def build_graph(device, input_shape, filter_shape, strides, padding, dtype, class Conv2DBenchmark(test.Benchmark): """Benchmark conv2d!""" - def _run_graph(self, device, input_shape, filter_shape, strides, padding, - dtype, num_iters, warmup_iters): + def _run_graph(self, device, dtype, data_format, input_shape, filter_shape, + strides, padding, num_iters, warmup_iters): """runs the graph and print its execution time. Args: device: String, the device to run on. + dtype: Data type for the convolution. + data_format: A string from: "NHWC" or "NCHW". Data format for input and + output data. input_shape: Shape of the input tensor. filter_shape: Shape of the filter tensor. strides: A list of ints. 1-D of length 4. The stride of sliding @@ -94,7 +110,6 @@ class Conv2DBenchmark(test.Benchmark): padding: A string from: "SAME", "VALID". The type of padding algorithm to use. num_iters: Number of iterations to run the benchmark. - dtype: Data type for the convolution. num_iters: number of iterations to run conv2d. warmup_iters: number of iterations for warmup runs. @@ -103,10 +118,27 @@ class Conv2DBenchmark(test.Benchmark): """ graph = ops.Graph() with graph.as_default(): - warmup_outputs, outputs = build_graph(device, input_shape, filter_shape, - strides, padding, dtype, num_iters, - warmup_iters) - with session_lib.Session(graph=graph) as session: + warmup_outputs, outputs = build_graph(device, dtype, data_format, + input_shape, filter_shape, strides, + padding, num_iters, warmup_iters) + + config = config_pb2.ConfigProto() + config.graph_options.optimizer_options.opt_level = -1 + rewrite_options = config.graph_options.rewrite_options + + # Disable layout optimizer to not change input data_format. + rewrite_options.layout_optimizer = ( + rewriter_config_pb2.RewriterConfig.ON if FLAGS.enable_layout_optimizer + else rewriter_config_pb2.RewriterConfig.OFF) + # Convolution ops are effectively noop in the test graph as we are not + # fetching the convolution outputs. Disable dependency optimizer to not + # remove the conv ops. + rewrite_options.dependency_optimization = ( + rewriter_config_pb2.RewriterConfig.OFF) + + with session_lib.Session(graph=graph, config=config) as session: + # TODO(hinsu): Use run_op_benchmark method from test.Benchmark to run + # benchmark along with warmup. variables.global_variables_initializer().run() # warmup runs session.run(warmup_outputs) @@ -114,20 +146,21 @@ class Conv2DBenchmark(test.Benchmark): start_time = time.time() session.run(outputs) duration = (time.time() - start_time) / num_iters - print("%s %s inputshape:%s filtershape:%s strides:%s padding:%s " + print("%s %s %s inputshape:%s filtershape:%s strides:%s padding:%s " "%d iters: %.8f sec" % - (device, str(dtype), str(input_shape).replace(" ", ""), - str(filter_shape).replace(" ", ""), + (device, str(dtype), data_format, str(input_shape).replace( + " ", ""), str(filter_shape).replace(" ", ""), str(strides).replace(" ", ""), padding, num_iters, duration)) name_template = ( - "conv2d_{device}_{datatype}_input_shape_{inputshape}_" + "conv2d_{device}_{datatype}_{data_format}_input_shape_{inputshape}_" "filter_shape_{filtershape}_strides_{strides}_padding_{padding}") self.report_benchmark( name=name_template.format( device=device, datatype=str(dtype), + data_format=str(data_format), inputshape=str(input_shape).replace(" ", ""), filtershape=str(filter_shape).replace(" ", ""), strides=str(strides).replace(" ", ""), @@ -140,24 +173,38 @@ class Conv2DBenchmark(test.Benchmark): def benchmark_conv2d(self): print("conv2d benchmark:") - h = 500 - w = 500 - fh = 3 - fw = 3 - input_shapes = [] - filter_shapes = [] data_types = [dtypes.float32, dtypes.float16] - for b, c in itertools.product([4, 16, 32], [i for i in range(3, 16)]): - input_shapes += [[b, h, w, c]] - filter_shapes += [[fh, fw, c, b]] - strides = [[1, 2, 2, 1]] + data_formats = ["NHWC", "NCHW"] + in_channels = list(range(1, 10)) + list(range(10, 20, 2)) + list( + range(20, 33, 4)) + out_channels = [4, 16, 32] + hw_strides = [[2, 2]] paddings = ["VALID", "SAME"] - for ishape, fshape in zip(input_shapes, filter_shapes): - for dtype in data_types: - for stride in strides: - for padding in paddings: - self._run_graph("gpu", ishape, fshape, stride, padding, dtype, 80, - 2) + + args_lists = [ + data_types, data_formats, in_channels, out_channels, hw_strides, + paddings + ] + for args in itertools.product(*args_lists): + dtype, data_format, in_channel, out_channel, hw_stride, padding = args + + # Keep batch size same as out channels just to reduce the number of + # different configurations to benchmark. + batch_size = out_channel + h, w, fh, fw = 500, 500, 3, 3 + if data_format == "NHWC": + ishape = [batch_size, h, w, in_channel] + stride = [1] + hw_stride + [1] + elif data_format == "NCHW": + ishape = [batch_size, in_channel, h, w] + stride = [1, 1] + hw_stride + else: + raise ValueError("Unknown data_format: " + str(data_format)) + fshape = [fh, fw, in_channel, out_channel] + num_iters = 80 + warmup_iters = 2 + self._run_graph("gpu", dtype, data_format, ishape, fshape, stride, + padding, num_iters, warmup_iters) if __name__ == "__main__": diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py index ca24f11054039472baaefd301e45f57c9444f60d..871f236f783cb9aa13ea8b776dbe5850febcb440 100644 --- a/tensorflow/python/ops/custom_gradient.py +++ b/tensorflow/python/ops/custom_gradient.py @@ -73,7 +73,7 @@ def custom_gradient(f): With this definition, the gradient at x=100 will be correctly evaluated as 1.0. - See also @{tf.RegisterGradient} which registers a gradient function for a + See also `tf.RegisterGradient` which registers a gradient function for a primitive TensorFlow operation. `tf.custom_gradient` on the other hand allows for fine grained control over the gradient computation of a sequence of operations. @@ -100,7 +100,7 @@ def custom_gradient(f): Returns: A function `h(x)` which returns the same value as `f(x)[0]` and whose - gradient (as calculated by @{tf.gradients}) is determined by `f(x)[1]`. + gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`. """ def decorated(*args, **kwargs): @@ -142,9 +142,9 @@ def _graph_mode_decorator(f, *args, **kwargs): # The variables that grad_fn needs to return gradients for are the set of # variables used that are *not* part of the inputs. variables = list(set(tape.watched_variables()) - set(args)) - grad_argspec = tf_inspect.getargspec(grad_fn) + grad_argspec = tf_inspect.getfullargspec(grad_fn) variables_in_signature = ("variables" in grad_argspec.args or - grad_argspec.keywords) + grad_argspec.varkw) if variables and not variables_in_signature: raise TypeError("If using @custom_gradient with a function that " "uses variables, then grad_fn must accept a keyword " @@ -194,9 +194,9 @@ def _eager_mode_decorator(f, *args, **kwargs): # The variables that grad_fn needs to return gradients for are the set of # variables used that are *not* part of the inputs. variables = [v for v in set(tape.watched_variables()) if v not in all_inputs] - grad_argspec = tf_inspect.getargspec(grad_fn) - if (variables and - not ("variables" in grad_argspec.args or grad_argspec.keywords)): + grad_argspec = tf_inspect.getfullargspec(grad_fn) + if (variables and ("variables" not in grad_argspec.args) and + not grad_argspec.varkw): raise TypeError("If using @custom_gradient with a function that " "uses variables, then grad_fn must accept a keyword " "argument 'variables'.") diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index abf597ca55c647cca3f6012ed602a815298e1ed3..7af2ca56be73c7713ac86965b7015a4fc5c957de 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -126,8 +126,8 @@ class QueueBase(object): handle single elements, versions that support enqueuing and dequeuing a batch of elements at once. - See @{tf.FIFOQueue} and - @{tf.RandomShuffleQueue} for concrete + See `tf.FIFOQueue` and + `tf.RandomShuffleQueue` for concrete implementations of this class, and instructions on how to create them. """ @@ -309,12 +309,12 @@ class QueueBase(object): until the element has been enqueued. At runtime, this operation may raise an error if the queue is - @{tf.QueueBase.close} before or during its execution. If the + `tf.QueueBase.close` before or during its execution. If the queue is closed before this operation runs, `tf.errors.CancelledError` will be raised. If this operation is blocked, and either (i) the queue is closed by a close operation with `cancel_pending_enqueues=True`, or (ii) the session is - @{tf.Session.close}, + `tf.Session.close`, `tf.errors.CancelledError` will be raised. Args: @@ -352,12 +352,12 @@ class QueueBase(object): until all of the elements have been enqueued. At runtime, this operation may raise an error if the queue is - @{tf.QueueBase.close} before or during its execution. If the + `tf.QueueBase.close` before or during its execution. If the queue is closed before this operation runs, `tf.errors.CancelledError` will be raised. If this operation is blocked, and either (i) the queue is closed by a close operation with `cancel_pending_enqueues=True`, or (ii) the session is - @{tf.Session.close}, + `tf.Session.close`, `tf.errors.CancelledError` will be raised. Args: @@ -413,11 +413,11 @@ class QueueBase(object): until there is an element to dequeue. At runtime, this operation may raise an error if the queue is - @{tf.QueueBase.close} before or during its execution. If the + `tf.QueueBase.close` before or during its execution. If the queue is closed, the queue is empty, and there are no pending enqueue operations that can fulfill this request, `tf.errors.OutOfRangeError` will be raised. If the session is - @{tf.Session.close}, + `tf.Session.close`, `tf.errors.CancelledError` will be raised. Args: @@ -455,11 +455,11 @@ class QueueBase(object): `OutOfRange` exception is raised. At runtime, this operation may raise an error if the queue is - @{tf.QueueBase.close} before or during its execution. If the + `tf.QueueBase.close` before or during its execution. If the queue is closed, the queue contains fewer than `n` elements, and there are no pending enqueue operations that can fulfill this request, `tf.errors.OutOfRangeError` will be raised. If the - session is @{tf.Session.close}, + session is `tf.Session.close`, `tf.errors.CancelledError` will be raised. Args: @@ -500,7 +500,7 @@ class QueueBase(object): If the queue is closed and there are more than `0` but fewer than `n` elements remaining, then instead of raising a - `tf.errors.OutOfRangeError` like @{tf.QueueBase.dequeue_many}, + `tf.errors.OutOfRangeError` like `tf.QueueBase.dequeue_many`, less than `n` elements are returned immediately. If the queue is closed and there are `0` elements left in the queue, then a `tf.errors.OutOfRangeError` is raised just like in `dequeue_many`. @@ -608,7 +608,7 @@ def _shared_name(shared_name): class RandomShuffleQueue(QueueBase): """A queue implementation that dequeues elements in a random order. - See @{tf.QueueBase} for a description of the methods on + See `tf.QueueBase` for a description of the methods on this class. """ @@ -657,7 +657,7 @@ class RandomShuffleQueue(QueueBase): with the same length as `dtypes`, or `None`. If specified the dequeue methods return a dictionary with the names as keys. seed: A Python integer. Used to create a random seed. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. shared_name: (Optional.) If non-empty, this queue will be shared under the given name across multiple sessions. @@ -693,7 +693,7 @@ class RandomShuffleQueue(QueueBase): class FIFOQueue(QueueBase): """A queue implementation that dequeues elements in first-in first-out order. - See @{tf.QueueBase} for a description of the methods on + See `tf.QueueBase` for a description of the methods on this class. """ @@ -753,7 +753,7 @@ class PaddingFIFOQueue(QueueBase): A `PaddingFIFOQueue` may contain components with dynamic shape, while also supporting `dequeue_many`. See the constructor for more details. - See @{tf.QueueBase} for a description of the methods on + See `tf.QueueBase` for a description of the methods on this class. """ @@ -824,7 +824,7 @@ class PaddingFIFOQueue(QueueBase): class PriorityQueue(QueueBase): """A queue implementation that dequeues elements in prioritized order. - See @{tf.QueueBase} for a description of the methods on + See `tf.QueueBase` for a description of the methods on this class. """ diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py index f28f76b6c42a861c51c1fc06f99fa73b71b625a9..99d30b0bd112b62c625a94b43da589f9717d0774 100644 --- a/tensorflow/python/ops/distributions/beta.py +++ b/tensorflow/python/ops/distributions/beta.py @@ -84,13 +84,24 @@ class Beta(distribution.Distribution): Distribution parameters are automatically broadcast in all functions; see examples for details. + Warning: The samples can be zero due to finite precision. + This happens more often when some of the concentrations are very small. + Make sure to round the samples to `np.finfo(dtype).tiny` before computing the + density. + + Samples of this distribution are reparameterized (pathwise differentiable). + The derivatives are computed using the approach described in the paper + + [Michael Figurnov, Shakir Mohamed, Andriy Mnih. + Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498) + #### Examples ```python # Create a batch of three Beta distributions. alpha = [1, 2, 3] beta = [1, 2, 3] - dist = Beta(alpha, beta) + dist = tf.distributions.Beta(alpha, beta) dist.sample([4, 5]) # Shape [4, 5, 3] @@ -106,7 +117,7 @@ class Beta(distribution.Distribution): # Create batch_shape=[2, 3] via parameter broadcast: alpha = [[1.], [2]] # Shape [2, 1] beta = [3., 4, 5] # Shape [3] - dist = Beta(alpha, beta) + dist = tf.distributions.Beta(alpha, beta) # alpha broadcast as: [[1., 1, 1,], # [2, 2, 2]] @@ -122,6 +133,18 @@ class Beta(distribution.Distribution): dist.prob(x) # Shape [2, 3] ``` + Compute the gradients of samples w.r.t. the parameters: + + ```python + alpha = tf.constant(1.0) + beta = tf.constant(2.0) + dist = tf.distributions.Beta(alpha, beta) + samples = dist.sample(5) # Shape [5] + loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function + # Unbiased stochastic gradients of the loss function + grads = tf.gradients(loss, [alpha, beta]) + ``` + """ def __init__(self, @@ -165,7 +188,7 @@ class Beta(distribution.Distribution): dtype=self._total_concentration.dtype, validate_args=validate_args, allow_nan_stats=allow_nan_stats, - reparameterization_type=distribution.NOT_REPARAMETERIZED, + reparameterization_type=distribution.FULLY_REPARAMETERIZED, parameters=parameters, graph_parents=[self._concentration1, self._concentration0, diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py index b88a0518b6db15021b9917d4c2b5ffb7bcf9484f..dd25fce2ec860456fdbbad903032cf4bcda9daba 100644 --- a/tensorflow/python/ops/distributions/categorical.py +++ b/tensorflow/python/ops/distributions/categorical.py @@ -32,12 +32,8 @@ from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.util.tf_export import tf_export -def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32): +def _broadcast_cat_event_and_params(event, params, base_dtype): """Broadcasts the event or distribution parameters.""" - if event.shape.ndims is None: - raise NotImplementedError( - "Cannot broadcast with an event tensor of unknown rank.") - if event.dtype.is_integer: pass elif event.dtype.is_floating: @@ -47,15 +43,18 @@ def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32): else: raise TypeError("`value` should have integer `dtype` or " "`self.dtype` ({})".format(base_dtype)) - - if params.get_shape()[:-1] == event.get_shape(): - params = params - else: - params *= array_ops.ones_like( - array_ops.expand_dims(event, -1), dtype=params.dtype) + shape_known_statically = ( + params.shape.ndims is not None and + params.shape[:-1].is_fully_defined() and + event.shape.is_fully_defined()) + if not shape_known_statically or params.shape[:-1] != event.shape: + params *= array_ops.ones_like(event[..., array_ops.newaxis], + dtype=params.dtype) params_shape = array_ops.shape(params)[:-1] event *= array_ops.ones(params_shape, dtype=event.dtype) - event.set_shape(tensor_shape.TensorShape(params.get_shape()[:-1])) + if params.shape.ndims is not None: + event.set_shape(tensor_shape.TensorShape(params.shape[:-1])) + return event, params diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py index 72567e62f78665947c001282c9c4f4929e9ea0ef..9104a1d071af3d7b7d40838148f2e49301fa39ba 100644 --- a/tensorflow/python/ops/distributions/dirichlet.py +++ b/tensorflow/python/ops/distributions/dirichlet.py @@ -90,13 +90,24 @@ class Dirichlet(distribution.Distribution): Distribution parameters are automatically broadcast in all functions; see examples for details. + Warning: Some components of the samples can be zero due to finite precision. + This happens more often when some of the concentrations are very small. + Make sure to round the samples to `np.finfo(dtype).tiny` before computing the + density. + + Samples of this distribution are reparameterized (pathwise differentiable). + The derivatives are computed using the approach described in the paper + + [Michael Figurnov, Shakir Mohamed, Andriy Mnih. + Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498) + #### Examples ```python # Create a single trivariate Dirichlet, with the 3rd class being three times # more frequent than the first. I.e., batch_shape=[], event_shape=[3]. alpha = [1., 2, 3] - dist = Dirichlet(alpha) + dist = tf.distributions.Dirichlet(alpha) dist.sample([4, 5]) # shape: [4, 5, 3] @@ -118,7 +129,7 @@ class Dirichlet(distribution.Distribution): # Create batch_shape=[2], event_shape=[3]: alpha = [[1., 2, 3], [4, 5, 6]] # shape: [2, 3] - dist = Dirichlet(alpha) + dist = tf.distributions.Dirichlet(alpha) dist.sample([4, 5]) # shape: [4, 5, 2, 3] @@ -129,6 +140,17 @@ class Dirichlet(distribution.Distribution): dist.prob(x) # shape: [2] ``` + Compute the gradients of samples w.r.t. the parameters: + + ```python + alpha = tf.constant([1.0, 2.0, 3.0]) + dist = tf.distributions.Dirichlet(alpha) + samples = dist.sample(5) # Shape [5, 3] + loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function + # Unbiased stochastic gradients of the loss function + grads = tf.gradients(loss, alpha) + ``` + """ def __init__(self, @@ -165,7 +187,7 @@ class Dirichlet(distribution.Distribution): dtype=self._concentration.dtype, validate_args=validate_args, allow_nan_stats=allow_nan_stats, - reparameterization_type=distribution.NOT_REPARAMETERIZED, + reparameterization_type=distribution.FULLY_REPARAMETERIZED, parameters=parameters, graph_parents=[self._concentration, self._total_concentration], @@ -290,10 +312,8 @@ class Dirichlet(distribution.Distribution): if not self.validate_args: return x return control_flow_ops.with_dependencies([ - check_ops.assert_positive( - x, - message="samples must be positive"), - distribution_util.assert_close( + check_ops.assert_positive(x, message="samples must be positive"), + check_ops.assert_near( array_ops.ones([], dtype=self.dtype), math_ops.reduce_sum(x, -1), message="sample last-dimension must sum to `1`"), diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py index 41dcd401887a124780a35c3dbd84140553860485..ddf9442cd22d68d6ff43bb8017983e774ce9e11b 100644 --- a/tensorflow/python/ops/distributions/distribution.py +++ b/tensorflow/python/ops/distributions/distribution.py @@ -212,7 +212,7 @@ class ReparameterizationType(object): reparameterized, and straight-through gradients are either partially unsupported or are not supported at all. In this case, for purposes of e.g. RL or variational inference, it is generally safest to wrap the - sample results in a `stop_gradients` call and instead use policy + sample results in a `stop_gradients` call and use policy gradients / surrogate loss instead. """ @@ -526,8 +526,8 @@ class Distribution(_BaseDistribution): # Remove "self", "__class__", or other special variables. These can appear # if the subclass used: # `parameters = dict(locals())`. - return dict((k, v) for k, v in self._parameters.items() - if not k.startswith("__") and k != "self") + return {k: v for k, v in self._parameters.items() + if not k.startswith("__") and k != "self"} @property def reparameterization_type(self): diff --git a/tensorflow/python/ops/distributions/exponential.py b/tensorflow/python/ops/distributions/exponential.py index 24bc3f3d3eb06a01d5173cb6c7fb0f09172a0587..4325a14449dd9a13dabb65a240ede452544c761a 100644 --- a/tensorflow/python/ops/distributions/exponential.py +++ b/tensorflow/python/ops/distributions/exponential.py @@ -103,9 +103,6 @@ class Exponential(gamma.Gamma): allow_nan_stats=allow_nan_stats, validate_args=validate_args, name=name) - # While the Gamma distribution is not reparameterizable, the exponential - # distribution is. - self._reparameterization_type = True self._parameters = parameters self._graph_parents += [self._rate] diff --git a/tensorflow/python/ops/distributions/gamma.py b/tensorflow/python/ops/distributions/gamma.py index 163a27f7585518c321dd1ea59b71029e2ae6a1e7..b631f0247c59e518fbd4925065d33345d4ea8e47 100644 --- a/tensorflow/python/ops/distributions/gamma.py +++ b/tensorflow/python/ops/distributions/gamma.py @@ -55,7 +55,7 @@ class Gamma(distribution.Distribution): ```none pdf(x; alpha, beta, x > 0) = x**(alpha - 1) exp(-x beta) / Z - Z = Gamma(alpha) beta**alpha + Z = Gamma(alpha) beta**(-alpha) ``` where: @@ -85,14 +85,35 @@ class Gamma(distribution.Distribution): Distribution parameters are automatically broadcast in all functions; see examples for details. - WARNING: This distribution may draw 0-valued samples for small `concentration` - values. See note in `tf.random_gamma` docstring. + Warning: The samples of this distribution are always non-negative. However, + the samples that are smaller than `np.finfo(dtype).tiny` are rounded + to this value, so it appears more often than it should. + This should only be noticeable when the `concentration` is very small, or the + `rate` is very large. See note in `tf.random_gamma` docstring. + + Samples of this distribution are reparameterized (pathwise differentiable). + The derivatives are computed using the approach described in the paper + + [Michael Figurnov, Shakir Mohamed, Andriy Mnih. + Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498) #### Examples ```python - dist = Gamma(concentration=3.0, rate=2.0) - dist2 = Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) + dist = tf.distributions.Gamma(concentration=3.0, rate=2.0) + dist2 = tf.distributions.Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) + ``` + + Compute the gradients of samples w.r.t. the parameters: + + ```python + concentration = tf.constant(3.0) + rate = tf.constant(2.0) + dist = tf.distributions.Gamma(concentration, rate) + samples = dist.sample(5) # Shape [5] + loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function + # Unbiased stochastic gradients of the loss function + grads = tf.gradients(loss, [concentration, rate]) ``` """ @@ -141,7 +162,7 @@ class Gamma(distribution.Distribution): dtype=self._concentration.dtype, validate_args=validate_args, allow_nan_stats=allow_nan_stats, - reparameterization_type=distribution.NOT_REPARAMETERIZED, + reparameterization_type=distribution.FULLY_REPARAMETERIZED, parameters=parameters, graph_parents=[self._concentration, self._rate], diff --git a/tensorflow/python/ops/distributions/student_t.py b/tensorflow/python/ops/distributions/student_t.py index 20a2d16181442bede797ded5e4d3ebbd3d55ca2b..e0cf6f86f10eec76bf94cd74f64202c452425886 100644 --- a/tensorflow/python/ops/distributions/student_t.py +++ b/tensorflow/python/ops/distributions/student_t.py @@ -80,6 +80,12 @@ class StudentT(distribution.Distribution): variance. However it is not actually the std. deviation; the Student's t-distribution std. dev. is `scale sqrt(df / (df - 2))` when `df > 2`. + Samples of this distribution are reparameterized (pathwise differentiable). + The derivatives are computed using the approach described in the paper + + [Michael Figurnov, Shakir Mohamed, Andriy Mnih. + Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498) + #### Examples Examples of initialization of one or a batch of distributions. @@ -118,6 +124,19 @@ class StudentT(distribution.Distribution): dist.prob(3.0) ``` + Compute the gradients of samples w.r.t. the parameters: + + ```python + df = tf.constant(2.0) + loc = tf.constant(2.0) + scale = tf.constant(11.0) + dist = tf.distributions.StudentT(df=df, loc=loc, scale=scale) + samples = dist.sample(5) # Shape [5] + loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function + # Unbiased stochastic gradients of the loss function + grads = tf.gradients(loss, [df, loc, scale]) + ``` + """ # pylint: enable=line-too-long @@ -168,7 +187,7 @@ class StudentT(distribution.Distribution): (self._df, self._loc, self._scale)) super(StudentT, self).__init__( dtype=self._scale.dtype, - reparameterization_type=distribution.NOT_REPARAMETERIZED, + reparameterization_type=distribution.FULLY_REPARAMETERIZED, validate_args=validate_args, allow_nan_stats=allow_nan_stats, parameters=parameters, diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py index 401676bf842b4dd76fc64b5f4599804a0f3a46f8..3e480a79f52b178789a2d34e98c6af31048c07b1 100644 --- a/tensorflow/python/ops/distributions/util.py +++ b/tensorflow/python/ops/distributions/util.py @@ -36,43 +36,6 @@ from tensorflow.python.ops import nn from tensorflow.python.util import tf_inspect -def assert_close( - x, y, data=None, summarize=None, message=None, name="assert_close"): - """Assert that x and y are within machine epsilon of each other. - - Args: - x: Floating-point `Tensor` - y: Floating-point `Tensor` - data: The tensors to print out if the condition is `False`. Defaults to - error message and first few entries of `x` and `y`. - summarize: Print this many entries of each tensor. - message: A string to prefix to the default message. - name: A name for this operation (optional). - - Returns: - Op raising `InvalidArgumentError` if |x - y| > machine epsilon. - """ - message = message or "" - x = ops.convert_to_tensor(x, name="x") - y = ops.convert_to_tensor(y, name="y") - - if data is None: - data = [ - message, - "Condition x ~= y did not hold element-wise: x = ", x, "y = ", y - ] - - if x.dtype.is_integer: - return check_ops.assert_equal( - x, y, data=data, summarize=summarize, message=message, name=name) - - with ops.name_scope(name, "assert_close", [x, y, data]): - tol = np.finfo(x.dtype.as_numpy_dtype).eps - condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol)) - return control_flow_ops.Assert( - condition, data, summarize=summarize) - - def assert_integer_form( x, data=None, summarize=None, message=None, int_dtype=None, name="assert_integer_form"): @@ -241,8 +204,12 @@ def get_logits_and_probs(logits=None, dependencies = [check_ops.assert_non_negative(probs)] if multidimensional: probs = embed_check_categorical_event_shape(probs) - dependencies += [assert_close(math_ops.reduce_sum(probs, -1), one, - message="probs does not sum to 1.")] + dependencies += [ + check_ops.assert_near( + math_ops.reduce_sum(probs, -1), + one, + message="probs does not sum to 1.") + ] else: dependencies += [check_ops.assert_less_equal( probs, one, message="probs has components greater than 1.")] diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index bcc717b043f226a18344de31b36f09d5064f25a3..7b9e7de1457d8503a9aac793227dbea9675653bb 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops # Imports gradient definitions. @@ -30,6 +31,7 @@ from tensorflow.python.ops import data_flow_grad # pylint: disable=unused-impor from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -43,8 +45,8 @@ def _clip(params, ids, max_norm): Args: params: A `Tensor` of embeddings retrieved by `gather`. ids: The `ids` argument that was passed to `gather`. - max_norm: If provided, the embeddings are l2-normalized to the value of - max_norm. + max_norm: If not `None`, each embedding is clipped if its l2-norm is + larger than this value. Returns: A `Tensor` with the same type as `params`. @@ -251,7 +253,7 @@ def embedding_lookup( This function is used to perform parallel lookups on the list of tensors in `params`. It is a generalization of - @{tf.gather}, where `params` is + `tf.gather`, where `params` is interpreted as a partitioning of a large embedding tensor. `params` may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a partitioner. @@ -290,8 +292,8 @@ def embedding_lookup( in `indices` are always validated to be within range. If assigned to GPU, out-of-bound indices result in safe but unspecified behavior, which may include raising an error. - max_norm: If provided, embedding values are l2-normalized to the value of - max_norm. + max_norm: If not `None`, each embedding is clipped if its l2-norm is + larger than this value. Returns: A `Tensor` with the same type as the tensors in `params`. @@ -346,8 +348,8 @@ def embedding_lookup_sparse(params, "mean" is the weighted sum divided by the total weight. "sqrtn" is the weighted sum divided by the square root of the sum of the squares of the weights. - max_norm: If provided, each embedding is normalized to have l2 norm equal - to max_norm before combining. + max_norm: If not `None`, each embedding is clipped if its l2-norm is + larger than this value, before combining. Returns: A dense tensor representing the combined embeddings for the @@ -479,3 +481,158 @@ def embedding_lookup_sparse(params, assert False, "Unrecognized combiner" return embeddings + + +@tf_export("nn.safe_embedding_lookup_sparse") +def safe_embedding_lookup_sparse(embedding_weights, + sparse_ids, + sparse_weights=None, + combiner='mean', + default_id=None, + name=None, + partition_strategy='div', + max_norm=None): + """Lookup embedding results, accounting for invalid IDs and empty features. + + The partitioned embedding in `embedding_weights` must all be the same shape + except for the first dimension. The first dimension is allowed to vary as the + vocabulary size is not necessarily a multiple of `P`. `embedding_weights` + may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a + partitioner. + + Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs + with non-positive weight. For an entry with no features, the embedding vector + for `default_id` is returned, or the 0-vector if `default_id` is not supplied. + + The ids and weights may be multi-dimensional. Embeddings are always aggregated + along the last dimension. + + Args: + embedding_weights: A list of `P` float `Tensor`s or values representing + partitioned embedding `Tensor`s. Alternatively, a `PartitionedVariable` + created by partitioning along dimension 0. The total unpartitioned + shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the + vocab size and `e_1, ..., e_m` are the embedding dimensions. + sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the + ids. `d_0` is typically batch size. + sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing + float weights corresponding to `sparse_ids`, or `None` if all weights + are be assumed to be 1.0. + combiner: A string specifying how to combine embedding results for each + entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" + the default. + default_id: The id to use for an entry with no features. + name: A name for this operation (optional). + partition_strategy: A string specifying the partitioning strategy. + Currently `"div"` and `"mod"` are supported. Default is `"div"`. + max_norm: If not `None`, all embeddings are l2-normalized to max_norm before + combining. + + + Returns: + Dense `Tensor` of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`. + + Raises: + ValueError: if `embedding_weights` is empty. + """ + if embedding_weights is None: + raise ValueError('Missing embedding_weights %s.' % embedding_weights) + if isinstance(embedding_weights, variables.PartitionedVariable): + embedding_weights = list(embedding_weights) # get underlying Variables. + if not isinstance(embedding_weights, list): + embedding_weights = [embedding_weights] + if len(embedding_weights) < 1: + raise ValueError('Missing embedding_weights %s.' % embedding_weights) + + dtype = sparse_weights.dtype if sparse_weights is not None else None + embedding_weights = [ + ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights + ] + + with ops.name_scope(name, 'embedding_lookup', + embedding_weights + [sparse_ids, + sparse_weights]) as scope: + # Reshape higher-rank sparse ids and weights to linear segment ids. + original_shape = sparse_ids.dense_shape + original_rank_dim = sparse_ids.dense_shape.get_shape()[0] + original_rank = ( + array_ops.size(original_shape) + if original_rank_dim.value is None + else original_rank_dim.value) + sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [ + math_ops.reduce_prod( + array_ops.slice(original_shape, [0], [original_rank - 1])), + array_ops.gather(original_shape, original_rank - 1)]) + if sparse_weights is not None: + sparse_weights = sparse_tensor.SparseTensor( + sparse_ids.indices, + sparse_weights.values, sparse_ids.dense_shape) + + # Prune invalid ids and weights. + sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights) + if combiner != 'sum': + sparse_ids, sparse_weights = _prune_invalid_weights( + sparse_ids, sparse_weights) + + # Fill in dummy values for empty features, if necessary. + sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids, + default_id or + 0) + if sparse_weights is not None: + sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0) + + result = embedding_lookup_sparse( + embedding_weights, + sparse_ids, + sparse_weights, + combiner=combiner, + partition_strategy=partition_strategy, + name=None if default_id is None else scope, + max_norm=max_norm) + + if default_id is None: + # Broadcast is_row_empty to the same shape as embedding_lookup_result, + # for use in Select. + is_row_empty = array_ops.tile( + array_ops.reshape(is_row_empty, [-1, 1]), + array_ops.stack([1, array_ops.shape(result)[1]])) + + result = array_ops.where(is_row_empty, + array_ops.zeros_like(result), + result, + name=scope) + + # Reshape back from linear ids back into higher-dimensional dense result. + final_result = array_ops.reshape( + result, + array_ops.concat([ + array_ops.slice( + math_ops.cast(original_shape, dtypes.int32), [0], + [original_rank - 1]), + array_ops.slice(array_ops.shape(result), [1], [-1]) + ], 0)) + final_result.set_shape(tensor_shape.unknown_shape( + (original_rank_dim - 1).value).concatenate(result.get_shape()[1:])) + return final_result + + +def _prune_invalid_ids(sparse_ids, sparse_weights): + """Prune invalid IDs (< 0) from the input ids and weights.""" + is_id_valid = math_ops.greater_equal(sparse_ids.values, 0) + if sparse_weights is not None: + is_id_valid = math_ops.logical_and( + is_id_valid, + array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool)) + sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid) + if sparse_weights is not None: + sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid) + return sparse_ids, sparse_weights + + +def _prune_invalid_weights(sparse_ids, sparse_weights): + """Prune invalid weights (< 0) from the input ids and weights.""" + if sparse_weights is not None: + is_weights_valid = math_ops.greater(sparse_weights.values, 0) + sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid) + sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid) + return sparse_ids, sparse_weights diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index 30413f289a0674db61406153cc05d12c7cc98f9b..4ecc74675ae673bcc30f18dde75a396ff673bfaa 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -775,7 +775,7 @@ def While(input_, cond, body, name=None, hostmem=None): a string, non-empty means True and empty means False. If the tensor is not a scalar, non-emptiness means True and False otherwise. - body: . A funcion takes a list of tensors and returns another + body: . A function takes a list of tensors and returns another list tensors. Both lists have the same types as specified by T. name: A name for the operation (optional). @@ -945,6 +945,61 @@ def For(start, # pylint: enable=invalid-name,protected-access -def partitioned_call(args, f): - return gen_functional_ops.partitioned_call( - args=args, Tout=[o.type for o in f.definition.signature.output_arg], f=f) +def partitioned_call(args, f, tout=None, executing_eagerly=None): + """Executes a function while respecting device annotations. + + Currently, only those functions that execute within the same address space + can be executed. + + Args: + args: The arguments of the function, including captured inputs. + f: The function to execute; an instance of `_DefinedFunction` or + `_EagerDefinedFunction`. + tout: a list containing the output dtypes enums; if `None`, inferred from + the signature of `f`. + executing_eagerly: (Optional) A boolean indicating whether the context is + executing eagerly. If `None`, fetched from the global context. + + Returns: + The list of `Tensor`s returned by invoking `f(args)`. If the function does + not return anything, then returns `None` if eager execution is enabled, or + the `Operation` if not. + """ + + if tout is None: + tout = tuple(x.type for x in f.definition.signature.output_arg) + + if executing_eagerly is None: + executing_eagerly = context.executing_eagerly() + + if executing_eagerly or len(tout): + if f.stateful_ops: + outputs = gen_functional_ops.stateful_partitioned_call( + args=args, Tout=tout, f=f) + else: + outputs = gen_functional_ops.partitioned_call(args=args, Tout=tout, f=f) + return outputs if outputs else None + + # The generated binding returns an empty list for functions that don't + # return any Tensors, hence the need to use `create_op` directly. + args = [ops.internal_convert_to_tensor(x) for x in args] + tin_attr = attr_value_pb2.AttrValue( + list=attr_value_pb2.AttrValue.ListValue( + type=[x.dtype.as_datatype_enum for x in args])) + tout_attr = attr_value_pb2.AttrValue( + list=attr_value_pb2.AttrValue.ListValue(type=tout)) + func_attr = attr_value_pb2.AttrValue( + func=attr_value_pb2.NameAttrList(name=f.name)) + + graph = ops.get_default_graph() + f.add_to_graph(graph) + op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall" + op = graph.create_op( + op_name, + args, + tout, + compute_shapes=False, + name="PartitionedFunctionCall", + attrs={"Tin": tin_attr, "Tout": tout_attr, "f": func_attr}) + outputs = op.outputs + return outputs if outputs else op diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 7385cb758514e160efec61d731e734d1af126742..a68f680224d4b7281637cda1239f95340a513ef5 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -20,6 +20,7 @@ from __future__ import print_function import collections import contextlib +import sys import warnings import numpy as np @@ -30,12 +31,14 @@ from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_grad # pylint: disable=unused-import from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops # pylint: disable=unused-import +from tensorflow.python.ops import cond_v2_impl from tensorflow.python.ops import control_flow_grad # pylint: disable=unused-import from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util @@ -47,12 +50,17 @@ from tensorflow.python.ops import logging_ops # pylint: disable=unused-import from tensorflow.python.ops import manip_grad # pylint: disable=unused-import from tensorflow.python.ops import math_grad # pylint: disable=unused-import from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_grad # pylint: disable=unused-import from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import spectral_grad # pylint: disable=unused-import from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import compat from tensorflow.python.util.tf_export import tf_export +# This is to avoid a circular dependency with cond_v2_impl. +cond_v2_impl._gradients_impl = sys.modules[__name__] # pylint: disable=protected-access + # Warn the user if we convert a sparse representation to dense with at # least this number of elements. _LARGE_SPARSE_NUM_ELEMENTS = 100000000 @@ -107,12 +115,14 @@ ops.register_tensor_conversion_function(ops.IndexedSlices, _IndexedSlicesToTensor) -def _MarkReachedOps(from_ops, reached_ops): +def _MarkReachedOps(from_ops, reached_ops, func_graphs): """Mark all ops reached from "from_ops". Args: from_ops: list of Operations. reached_ops: set of Operations. + func_graphs: list of function._FuncGraphs. This method will traverse through + these functions if they capture from_ops or any reachable ops. """ queue = collections.deque() queue.extend(from_ops) @@ -122,36 +132,11 @@ def _MarkReachedOps(from_ops, reached_ops): reached_ops.add(op) for output in op.outputs: if _IsBackpropagatable(output): - queue.extend(output.consumers()) - - -def _GatherInputs(to_ops, reached_ops): - """List all inputs of to_ops that are in reached_ops. + queue.extend(_Consumers(output, func_graphs)) - Args: - to_ops: list of Operations. - reached_ops: set of Operations. - Returns: - The list of all inputs of to_ops that are in reached_ops. - That list includes all elements of to_ops. - """ - inputs = [] - queue = collections.deque() - queue.extend(to_ops) - while queue: - op = queue.popleft() - # We are interested in this op. - if op in reached_ops: - inputs.append(op) - # Clear the boolean so we won't add the inputs again. - reached_ops.remove(op) - for inp in op.inputs: - queue.append(inp.op) - return inputs - - -def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops): +def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs, + xs): """Initialize the pending count for ops between two lists of Operations. 'pending_count[op]' indicates the number of backprop inputs @@ -161,6 +146,11 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops): to_ops: list of Operations. from_ops: list of Operations. colocate_gradients_with_ops: Python bool. See docstring of gradients(). + func_graphs: list of function._FuncGraphs. This method will traverse through + these functions if they capture from_ops or any reachable ops. This is + useful if to_ops occur in a function and from_ops are in an outer function + or graph. + xs: list of Tensors. Returns: A tuple containing: (1) the subset of to_ops reachable from from_ops by a @@ -171,7 +161,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops): """ # Mark reachable ops from from_ops. reached_ops = set() - _MarkReachedOps(from_ops, reached_ops) + _MarkReachedOps(from_ops, reached_ops, func_graphs) # X in reached_ops iff X is reachable from from_ops by a path of zero or more # backpropagatable tensors. @@ -190,7 +180,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops): between_op_list.append(op) # Clear the boolean so we won't add the inputs again. reached_ops.remove(op) - for inp in op.inputs: + for inp in _Inputs(op, xs): queue.append(inp.op) # X in between_ops iff X is on a path of zero or more backpropagatable tensors # between from_ops and to_ops @@ -202,7 +192,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops): # Initialize pending count for between ops. pending_count = collections.defaultdict(int) for op in between_op_list: - for x in op.inputs: + for x in _Inputs(op, xs): if x.op in between_ops: pending_count[x.op] += 1 @@ -323,7 +313,7 @@ def _VerifyGeneratedGradients(grads, op): "inputs %d" % (len(grads), op.node_def, len(op.inputs))) -def _StopOps(from_ops, stop_gradient_ops, pending_count): +def _StopOps(from_ops, stop_gradient_ops, pending_count, xs): """The set of ops that terminate the gradient computation. This computes the frontier of the forward graph *before* which backprop @@ -339,6 +329,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count): from_ops: list of Operations. stop_gradient_ops: list of Operations never to backprop through. pending_count: mapping from operation to number of backprop inputs. + xs: list of Tensors. Returns: The set of operations. @@ -346,7 +337,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count): stop_ops = set() for op in from_ops: is_stop_op = True - for inp in op.inputs: + for inp in _Inputs(op, xs): if pending_count[inp.op] > 0: is_stop_op = False break @@ -366,15 +357,26 @@ def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops): # pyli yield -def _SymGrad(op, out_grads): +def _IsPartitionedCall(op): + return op.type == "PartitionedCall" or op.type == "StatefulPartitionedCall" + + +def _SymGrad(op, out_grads, xs): """Backprop through a function call node op given its outputs' gradients.""" - f_in = [x for x in op.inputs] + out_grads - f_types = [x.dtype for x in op.inputs] + f_in = [x for x in _Inputs(op, xs)] + out_grads + f_types = [x.dtype for x in _Inputs(op, xs)] f = attr_value_pb2.NameAttrList() - f.name = op.type + if _IsPartitionedCall(op): + f.name = op.get_attr("f").name + else: + f.name = op.type for k in op.node_def.attr: f.attr[k].CopyFrom(op.node_def.attr[k]) - in_grads = functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f) + # TODO(apassos) use a better dtype here + in_grads = functional_ops.symbolic_gradient( + input=f_in, + Tout=[x if x != dtypes.resource else dtypes.float32 for x in f_types], + f=f) return in_grads @@ -415,7 +417,7 @@ def _MaybeCompile(scope, op, func, grad_fn): return grad_fn() -def _RaiseNoGradWrtInitialLoopValError(op, from_ops): +def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs): """Raises an error if we backprop through a loop var.""" # Find the nearest 'to_op' reachable from 'op' to provide a more helpful error # message. @@ -429,7 +431,7 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops): if curr_op in from_ops: target_op = curr_op break - queue.extend(t.op for t in curr_op.inputs) + queue.extend(t.op for t in _Inputs(curr_op, xs)) assert target_op raise ValueError( "Cannot compute gradient inside while loop with respect to op '%s'. " @@ -439,6 +441,68 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops): % target_op.name) +def _MaybeCaptured(t): + """If t is a captured value placeholder, returns the original captured value. + + Args: + t: Tensor + + Returns: + A tensor, potentially from a different Graph/function._FuncGraph. + """ + # pylint: disable=protected-access + if isinstance(t.op.graph, function._FuncGraph) and t.op.type == "Placeholder": + for input_t, placeholder_t in t.op.graph._captured.items(): + if t == placeholder_t: + return _MaybeCaptured(input_t) + # pylint: enable=protected-access + return t + + +# TODO(skyewm): plumbing xs through everywhere is ugly, consider making +# _GradientsHelper a class with xs as a member variable. +def _Inputs(op, xs): + """Returns the inputs of op, crossing closure boundaries where necessary. + + Args: + op: Operation + xs: list of Tensors we are differentiating w.r.t. + + Returns: + A list of tensors. The tensors may be from multiple + Graph/function._FuncGraphs if op is in a function._FuncGraph and has + captured inputs. + """ + if isinstance(op.graph, function._FuncGraph): # pylint: disable=protected-access + # If we're differentiating w.r.t. `t`, do not attempt to traverse through it + # to a captured value. The algorithm needs to "see" `t` in this case, even + # if it's a function input for a captured value, whereas usually we'd like + # to traverse through these closures as if the captured value was the direct + # input to op. + return [t if (t in xs) else _MaybeCaptured(t) for t in op.inputs] + else: + return op.inputs + + +def _Consumers(t, func_graphs): + """Returns the consumers of t, crossing closure boundaries where necessary. + + Args: + t: Tensor + func_graphs: a list of function._FuncGraphs that may have captured t. + + Returns: + A list of tensors. The tensors will be from the current graph and/or + func_graphs. + """ + consumers = t.consumers() + for func in func_graphs: + for input_t, placeholder in func._captured.items(): # pylint: disable=protected-access + if input_t == t: + consumers.extend(_Consumers(placeholder, func_graphs)) + return consumers + + @tf_export("gradients") def gradients(ys, xs, @@ -524,10 +588,10 @@ def gradients(ys, RuntimeError: if called in Eager mode. """ - # Creating the gradient graph for control flow mutates Operations. _lock - # ensures a Session.run call cannot occur between creating and mutating new - # ops. - with ops.get_default_graph()._lock: # pylint: disable=protected-access + # Creating the gradient graph for control flow mutates Operations. + # _mutation_lock ensures a Session.run call cannot occur between creating and + # mutating new ops. + with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients) @@ -543,12 +607,19 @@ def _GradientsHelper(ys, src_graph=None): """Implementation of gradients().""" if context.executing_eagerly(): - raise RuntimeError("tf.gradients not supported when eager execution " - "is enabled. Use tf.contrib.eager.GradientTape " - "instead.") + raise RuntimeError("tf.gradients is not supported when eager execution " + "is enabled. Use tf.GradientTape instead.") if src_graph is None: src_graph = ops.get_default_graph() + # If src_graph is a _FuncGraph (i.e. a function body), gather it and all + # ancestor graphs. This is necessary for correctly handling captured values. + func_graphs = [] + curr_graph = src_graph + while isinstance(curr_graph, function._FuncGraph): # pylint: disable=protected-access + func_graphs.append(curr_graph) + curr_graph = curr_graph._outer_graph # pylint: disable=protected-access + ys = _AsList(ys) xs = _AsList(xs) stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients) @@ -582,13 +653,11 @@ def _GradientsHelper(ys, # Initialize the pending count for ops in the connected subgraph from ys # to the xs. - if len(ys) > 1: - ys = [array_ops.identity(y) if y.consumers() else y for y in ys] to_ops = [t.op for t in ys] from_ops = [t.op for t in xs] stop_gradient_ops = [t.op for t in stop_gradients] reachable_to_ops, pending_count, loop_state = _PendingCount( - to_ops, from_ops, colocate_gradients_with_ops) + to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs) # Iterate over the collected ops. # @@ -622,7 +691,7 @@ def _GradientsHelper(ys, _SetGrad(grads, y, loop_state.ZerosLikeForExit(y)) queue.append(y.op) - stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count) + stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs) while queue: # generate gradient subgraph for op. op = queue.popleft() @@ -636,13 +705,19 @@ def _GradientsHelper(ys, grad_fn = None func_call = None + is_partitioned_call = _IsPartitionedCall(op) # pylint: disable=protected-access - is_func_call = src_graph._is_function(op.type) + is_func_call = ( + src_graph._is_function(op.type) or is_partitioned_call) # pylint: enable=protected-access has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads) if has_out_grads and (op not in stop_ops): if is_func_call: - func_call = src_graph._get_function(op.type) # pylint: disable=protected-access + if is_partitioned_call: + func_call = src_graph._get_function( # pylint: disable=protected-access + compat.as_bytes(op.get_attr("f").name)) + else: + func_call = src_graph._get_function(op.type) # pylint: disable=protected-access # Note that __defun is not set if the graph is # imported. If it's set, we prefer to access the original # defun. @@ -671,7 +746,7 @@ def _GradientsHelper(ys, op._control_flow_context.IsWhileContext() and op._control_flow_context == ops.get_default_graph()._get_control_flow_context()): - _RaiseNoGradWrtInitialLoopValError(op, from_ops) + _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs) # pylint: enable=protected-access if (grad_fn or is_func_call) and has_out_grads: @@ -703,7 +778,7 @@ def _GradientsHelper(ys, # For function call ops, we add a 'SymbolicGradient' # node to the graph to compute gradients. in_grads = _MaybeCompile(grad_scope, op, func_call, - lambda: _SymGrad(op, out_grads)) + lambda: _SymGrad(op, out_grads, xs)) in_grads = _AsList(in_grads) _VerifyGeneratedGradients(in_grads, op) if gate_gradients and len([x for x in in_grads @@ -718,8 +793,8 @@ def _GradientsHelper(ys, else: # If no grad_fn is defined or none of out_grads is available, # just propagate a list of None backwards. - in_grads = [None] * len(op.inputs) - for i, (t_in, in_grad) in enumerate(zip(op.inputs, in_grads)): + in_grads = [None] * len(_Inputs(op, xs)) + for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)): if in_grad is not None: if (isinstance(in_grad, ops.Tensor) and t_in.dtype != dtypes.resource): @@ -737,7 +812,8 @@ def _GradientsHelper(ys, loop_state.ExitGradWhileContext(op, before=False) # Update pending count for the inputs of op and enqueue ready ops. - _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state) + _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, + xs) if loop_state: loop_state.PostProcessing() @@ -756,9 +832,10 @@ def _HasAnyNotNoneGrads(grads, op): return False -def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state): +def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, + xs): """Update pending count for the inputs of op and enqueue ready ops.""" - for x in op.inputs: + for x in _Inputs(op, xs): pending_count[x.op] -= 1 ready = (pending_count[x.op] == 0) if loop_state and not ready: diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index d81c756f1cbc0a46d094066cda369067f7d3d1f6..d02fcf4ee27c180003e5b026e486a4ec0ad11e7d 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -57,90 +57,8 @@ from tensorflow.python.ops.nn_ops import bias_add from tensorflow.python.platform import googletest -def _OpsBetween(to_ops, from_ops): - """Build the list of operations between two lists of Operations. - - Args: - to_ops: list of Operations. - from_ops: list of Operations. - - Returns: - The list of operations between "from_ops" and "to_ops", sorted by - decreasing operation id. This list contains all elements of to_ops. - - TODO(touts): Think about returning an empty list if from_ops are not - reachable from to_ops. Presently it returns to_ops in that case. - """ - # Ops that are reachable from the output of "input_ops". - reached_ops = set() - # We only care to reach up to "output_ops" so we mark the - # output ops as reached to avoid recursing past them. - for op in to_ops: - reached_ops.add(op) - gradients_impl._MarkReachedOps(from_ops, reached_ops) - between_ops = gradients_impl._GatherInputs(to_ops, reached_ops) - between_ops.sort(key=lambda x: -x._id) - return between_ops - - class GradientsTest(test_util.TensorFlowTestCase): - def _OpNames(self, op_list): - return ["%s/%d" % (str(op.name), op._id) for op in op_list] - - def _assertOpListEqual(self, ops1, ops2): - self.assertEquals(self._OpNames(ops1), self._OpNames(ops2)) - - def testOpsBetweenSimple(self): - with ops.Graph().as_default(): - t1 = constant(1.0) - t2 = constant(2.0) - t3 = array_ops.stack([t1, t2]) - # Full graph - self._assertOpListEqual([t3.op, t2.op, t1.op], - _OpsBetween([t3.op], [t1.op, t2.op])) - # Only t1, t3. - self._assertOpListEqual([t3.op, t1.op], _OpsBetween([t3.op], [t1.op])) - - def testOpsBetweenUnreachable(self): - with ops.Graph().as_default(): - t1 = constant(1.0) - t2 = constant(2.0) - _ = array_ops.stack([t1, t2]) - t4 = constant(1.0) - t5 = constant(2.0) - t6 = array_ops.stack([t4, t5]) - # Elements of to_ops are always listed. - self._assertOpListEqual([t6.op], _OpsBetween([t6.op], [t1.op])) - - def testOpsBetweenCut(self): - with ops.Graph().as_default(): - t1 = constant(1.0) - t2 = constant(2.0) - t3 = array_ops.stack([t1, t2]) - t4 = constant([1.0]) - t5 = array_ops.concat([t4, t3], 0) - t6 = constant([2.0]) - t7 = array_ops.concat([t5, t6], 0) - self._assertOpListEqual([t7.op, t5.op, t4.op], - _OpsBetween([t7.op], [t4.op])) - - def testOpsBetweenCycle(self): - with ops.Graph().as_default(): - t1 = constant(1.0) - t2 = constant(2.0) - t3 = array_ops.stack([t1, t2]) - t4 = array_ops.concat([t3, t3, t3], 0) - t5 = constant([1.0]) - t6 = array_ops.concat([t4, t5], 0) - t7 = array_ops.concat([t6, t3], 0) - self._assertOpListEqual([t6.op, t4.op, t3.op], - _OpsBetween([t6.op], [t3.op])) - self._assertOpListEqual([t7.op, t6.op, t5.op, t4.op, t3.op, t1.op], - _OpsBetween([t7.op], [t1.op, t5.op])) - self._assertOpListEqual([t6.op, t5.op, t4.op, t3.op, t2.op], - _OpsBetween([t6.op], [t2.op, t5.op])) - def testGradients(self): with ops.Graph().as_default(): inp = constant(1.0, shape=[32, 100], name="in") @@ -519,6 +437,96 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): grad_func=grad_func, python_grad_func=self._PythonGradient) f.add_to_graph(ops.Graph()) + def testGradientWrtCaptured(self): + with ops.Graph().as_default(): + x = constant_op.constant(1.0, name="x") + + @function.Defun() + def Foo(): + y = math_ops.multiply(x, 2.0, name="y") + g = gradients_impl.gradients(y, x) + return g[0] + + f = Foo() + with self.test_session() as sess: + self.assertEqual(sess.run(f), 2.0) + + def testGradientOfCaptured(self): + with ops.Graph().as_default(): + x = constant_op.constant(1.0, name="x") + y = math_ops.multiply(x, 2.0, name="y") + + @function.Defun() + def Foo(): + g = gradients_impl.gradients(y, x) + return g[0] + + f = Foo() + with self.test_session() as sess: + self.assertEqual(sess.run(f), 2.0) + + def testCapturedResourceVariable(self): + with ops.Graph().as_default(): + var = resource_variable_ops.ResourceVariable(1.0, name="var") + + @function.Defun() + def Foo(): + y = math_ops.multiply(var, 2.0, name="y") + g = gradients_impl.gradients(y, var) + return g[0] + + f = Foo() + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + self.assertEqual(sess.run(f), 2.0) + + def testCapturedNested(self): + with ops.Graph().as_default(): + x1 = constant_op.constant(1.0, name="x1") + x2 = constant_op.constant(2.0, name="x2") + x3 = math_ops.multiply(x1, x2, name="x3") + + @function.Defun() + def Outer(): + outer1 = array_ops.identity(x1, name="outer1") + + @function.Defun() + def Inner(): + inner1 = array_ops.identity(outer1, name="inner1") + inner2 = array_ops.identity(x2, name="inner2") + inner3 = array_ops.identity(x3, name="inner3") + return gradients_impl.gradients([inner1, inner2, inner3, x1], + [x1, x2]) + + return Inner() + + x1_grad, x2_grad = Outer() + with self.test_session() as sess: + # 1.0 + None + 2.0 + 1.0 = 4.0 + self.assertEqual(sess.run(x1_grad), 4.0) + # None + 1.0 + 1.0 + None = 2.0 + self.assertEqual(sess.run(x2_grad), 2.0) + + def testCapturedFromFunction(self): + with ops.Graph().as_default(): + x = constant_op.constant(1.0, name="x") + + @function.Defun() + def Outer(): + y = math_ops.multiply(x, 2.0, name="y") + + @function.Defun() + def Inner(): + z = math_ops.multiply(y, 3.0, name="z") + g = gradients_impl.gradients(z, y) + return g[0] + + return Inner() + + z_grad = Outer() + with self.test_session() as sess: + self.assertEqual(sess.run(z_grad), 3.0) + class StopGradientTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/ops/histogram_ops_test.py b/tensorflow/python/ops/histogram_ops_test.py index a226ac81bb536934cd191872ffc1aca84925abc0..2e57ae8a2dd5dcc0398955f44d3c46e3097522b1 100644 --- a/tensorflow/python/ops/histogram_ops_test.py +++ b/tensorflow/python/ops/histogram_ops_test.py @@ -84,6 +84,23 @@ class HistogramFixedWidthTest(test.TestCase): def setUp(self): self.rng = np.random.RandomState(0) + def test_with_invalid_value_range(self): + values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15] + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 1 but is rank 0"): + histogram_ops.histogram_fixed_width(values, 1.0) + with self.assertRaisesRegexp(ValueError, "Dimension must be 2 but is 3"): + histogram_ops.histogram_fixed_width(values, [1.0, 2.0, 3.0]) + + def test_with_invalid_nbins(self): + values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15] + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1"): + histogram_ops.histogram_fixed_width(values, [1.0, 5.0], nbins=[1, 2]) + with self.assertRaisesRegexp( + ValueError, "Requires nbins > 0"): + histogram_ops.histogram_fixed_width(values, [1.0, 5.0], nbins=-5) + def test_empty_input_gives_all_zero_counts(self): # Bins will be: # (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf) diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index f27d9224c15cc9c3791d1d56454415ed95d196bd..12356944f8b4be695e90a4f1d978c68faa626e82 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.compat import compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -55,8 +56,10 @@ ops.NotDifferentiable('SampleDistortedBoundingBoxV2') ops.NotDifferentiable('ExtractGlimpse') ops.NotDifferentiable('NonMaxSuppression') ops.NotDifferentiable('NonMaxSuppressionV2') +ops.NotDifferentiable('NonMaxSuppressionWithOverlaps') +# pylint: disable=invalid-name def _assert(cond, ex_type, msg): """A polymorphic assert, works with tensors and boolean expressions. @@ -262,7 +265,7 @@ def random_flip_up_down(image, seed=None): image: 4-D Tensor of shape `[batch, height, width, channels]` or 3-D Tensor of shape `[height, width, channels]`. seed: A Python integer. Used to create a random seed. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. Returns: @@ -284,7 +287,7 @@ def random_flip_left_right(image, seed=None): image: 4-D Tensor of shape `[batch, height, width, channels]` or 3-D Tensor of shape `[height, width, channels]`. seed: A Python integer. Used to create a random seed. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. Returns: @@ -304,7 +307,7 @@ def _random_flip(image, flip_index, seed, scope_name): flip_index: The dimension along which to flip the image. Vertical: 0, Horizontal: 1 seed: A Python integer. Used to create a random seed. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. scope_name: Name of the scope in which the ops are added. @@ -945,7 +948,7 @@ def resize_images(images, Resized images will be distorted if their original aspect ratio is not the same as `size`. To avoid distortions see - @{tf.image.resize_image_with_crop_or_pad}. + `tf.image.resize_image_with_pad`. `method` can be one of: @@ -1069,6 +1072,106 @@ def resize_images(images, return images +@tf_export('image.resize_image_with_pad') +def resize_image_with_pad(image, + target_height, + target_width, + method=ResizeMethod.BILINEAR): + """Resizes and pads an image to a target width and height. + + Resizes an image to a target width and height by keeping + the aspect ratio the same without distortion. If the target + dimensions don't match the image dimensions, the image + is resized and then padded with zeroes to match requested + dimensions. + + Args: + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. + target_height: Target height. + target_width: Target width. + method: Method to use for resizing image. See `resize_images()` + + Raises: + ValueError: if `target_height` or `target_width` are zero or negative. + + Returns: + Resized and padded image. + If `images` was 4-D, a 4-D float Tensor of shape + `[batch, new_height, new_width, channels]`. + If `images` was 3-D, a 3-D float Tensor of shape + `[new_height, new_width, channels]`. + """ + with ops.name_scope(None, 'resize_image_with_pad', [image]): + image = ops.convert_to_tensor(image, name='image') + image_shape = image.get_shape() + is_batch = True + if image_shape.ndims == 3: + is_batch = False + image = array_ops.expand_dims(image, 0) + elif image_shape.ndims is None: + is_batch = False + image = array_ops.expand_dims(image, 0) + image.set_shape([None] * 4) + elif image_shape.ndims != 4: + raise ValueError('\'image\' must have either 3 or 4 dimensions.') + + assert_ops = _CheckAtLeast3DImage(image, require_static=False) + assert_ops += _assert(target_width > 0, ValueError, + 'target_width must be > 0.') + assert_ops += _assert(target_height > 0, ValueError, + 'target_height must be > 0.') + + image = control_flow_ops.with_dependencies(assert_ops, image) + + def max_(x, y): + if _is_tensor(x) or _is_tensor(y): + return math_ops.maximum(x, y) + else: + return max(x, y) + + _, height, width, _ = _ImageDimensions(image, rank=4) + + # convert values to float, to ease divisions + f_height = math_ops.cast(height, dtype=dtypes.float64) + f_width = math_ops.cast(width, dtype=dtypes.float64) + f_target_height = math_ops.cast(target_height, dtype=dtypes.float64) + f_target_width = math_ops.cast(target_width, dtype=dtypes.float64) + + # Find the ratio by which the image must be adjusted + # to fit within the target + ratio = max_(f_width / f_target_width, f_height / f_target_height) + resized_height_float = f_height / ratio + resized_width_float = f_width / ratio + resized_height = math_ops.cast( + math_ops.floor(resized_height_float), dtype=dtypes.int32) + resized_width = math_ops.cast( + math_ops.floor(resized_width_float), dtype=dtypes.int32) + + padding_height = (f_target_height - resized_height_float) / 2 + padding_width = (f_target_width - resized_width_float) / 2 + f_padding_height = math_ops.floor(padding_height) + f_padding_width = math_ops.floor(padding_width) + p_height = max_(0, math_ops.cast(f_padding_height, dtype=dtypes.int32)) + p_width = max_(0, math_ops.cast(f_padding_width, dtype=dtypes.int32)) + + # Resize first, then pad to meet requested dimensions + resized = resize_images(image, [resized_height, resized_width], method) + + padded = pad_to_bounding_box(resized, p_height, p_width, target_height, + target_width) + + if padded.get_shape().ndims is None: + raise ValueError('padded contains no shape.') + + _ImageDimensions(padded, rank=4) + + if not is_batch: + padded = array_ops.squeeze(padded, axis=[0]) + + return padded + + @tf_export('image.per_image_standardization') def per_image_standardization(image): """Linearly scales `image` to have zero mean and unit norm. @@ -1124,7 +1227,7 @@ def random_brightness(image, max_delta, seed=None): image: An image. max_delta: float, must be non-negative. seed: A Python integer. Used to create a random seed. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. Returns: @@ -1152,7 +1255,7 @@ def random_contrast(image, lower, upper, seed=None): lower: float. Lower bound for the random contrast factor. upper: float. Upper bound for the random contrast factor. seed: A Python integer. Used to create a random seed. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. Returns: @@ -1651,6 +1754,22 @@ def is_jpeg(contents, name=None): return math_ops.equal(substr, b'\xff\xd8\xff', name=name) +def _is_png(contents, name=None): + r"""Convenience function to check if the 'contents' encodes a PNG image. + + Args: + contents: 0-D `string`. The encoded image bytes. + name: A name for the operation (optional) + + Returns: + A scalar boolean tensor indicating if 'contents' may be a PNG image. + is_png is susceptible to false positives. + """ + with ops.name_scope(name, 'is_png'): + substr = string_ops.substr(contents, 0, 3) + return math_ops.equal(substr, b'\211PN', name=name) + + @tf_export('image.decode_image') def decode_image(contents, channels=None, dtype=dtypes.uint8, name=None): """Convenience function for `decode_bmp`, `decode_gif`, `decode_jpeg`, @@ -1728,8 +1847,8 @@ def decode_image(contents, channels=None, dtype=dtypes.uint8, name=None): def check_png(): """Checks if an image is PNG.""" - is_png = math_ops.equal(substr, b'\211PN', name='is_png') - return control_flow_ops.cond(is_png, _png, check_gif, name='cond_png') + return control_flow_ops.cond( + _is_png(contents), _png, check_gif, name='cond_png') def _jpeg(): """Decodes a jpeg image.""" @@ -1992,6 +2111,108 @@ def non_max_suppression(boxes, iou_threshold, score_threshold) +@tf_export('image.non_max_suppression_padded') +def non_max_suppression_padded(boxes, + scores, + max_output_size, + iou_threshold=0.5, + score_threshold=float('-inf'), + pad_to_max_output_size=False, + name=None): + """Greedily selects a subset of bounding boxes in descending order of score. + + Performs algorithmically equivalent operation to tf.image.non_max_suppression, + with the addition of an optional parameter which zero-pads the output to + be of size `max_output_size`. + The output of this operation is a tuple containing the set of integers + indexing into the input collection of bounding boxes representing the selected + boxes and the number of valid indices in the index set. The bounding box + coordinates corresponding to the selected indices can then be obtained using + the `tf.slice` and `tf.gather` operations. For example: + selected_indices_padded, num_valid = tf.image.non_max_suppression_padded( + boxes, scores, max_output_size, iou_threshold, + score_threshold, pad_to_max_output_size=True) + selected_indices = tf.slice( + selected_indices_padded, tf.constant([0]), num_valid) + selected_boxes = tf.gather(boxes, selected_indices) + + Args: + boxes: A 2-D float `Tensor` of shape `[num_boxes, 4]`. + scores: A 1-D float `Tensor` of shape `[num_boxes]` representing a single + score corresponding to each box (each row of boxes). + max_output_size: A scalar integer `Tensor` representing the maximum number + of boxes to be selected by non max suppression. + iou_threshold: A float representing the threshold for deciding whether boxes + overlap too much with respect to IOU. + score_threshold: A float representing the threshold for deciding when to + remove boxes based on score. + pad_to_max_output_size: bool. If True, size of `selected_indices` output + is padded to `max_output_size`. + name: A name for the operation (optional). + + Returns: + selected_indices: A 1-D integer `Tensor` of shape `[M]` representing the + selected indices from the boxes tensor, where `M <= max_output_size`. + valid_outputs: A scalar integer `Tensor` denoting how many elements in + `selected_indices` are valid. Valid elements occur first, then padding. + """ + with ops.name_scope(name, 'non_max_suppression_padded'): + iou_threshold = ops.convert_to_tensor(iou_threshold, name='iou_threshold') + score_threshold = ops.convert_to_tensor( + score_threshold, name='score_threshold') + if compat.forward_compatible(2018, 8, 7) or pad_to_max_output_size: + return gen_image_ops.non_max_suppression_v4( + boxes, scores, max_output_size, iou_threshold, score_threshold, + pad_to_max_output_size) + else: + return gen_image_ops.non_max_suppression_v3( + boxes, scores, max_output_size, iou_threshold, score_threshold) + + +@tf_export('image.non_max_suppression_overlaps') +def non_max_suppression_with_overlaps(overlaps, + scores, + max_output_size, + overlap_threshold=0.5, + score_threshold=float('-inf'), + name=None): + """Greedily selects a subset of bounding boxes in descending order of score. + + Prunes away boxes that have high overlap with previously selected boxes. + N-by-n overlap values are supplied as square matrix. + The output of this operation is a set of integers indexing into the input + collection of bounding boxes representing the selected boxes. The bounding + box coordinates corresponding to the selected indices can then be obtained + using the `tf.gather operation`. For example: + selected_indices = tf.image.non_max_suppression_overlaps( + overlaps, scores, max_output_size, iou_threshold) + selected_boxes = tf.gather(boxes, selected_indices) + + Args: + overlaps: A 2-D float `Tensor` of shape `[num_boxes, num_boxes]`. + scores: A 1-D float `Tensor` of shape `[num_boxes]` representing a single + score corresponding to each box (each row of boxes). + max_output_size: A scalar integer `Tensor` representing the maximum number + of boxes to be selected by non max suppression. + overlap_threshold: A float representing the threshold for deciding whether + boxes overlap too much with respect to the provided overlap values. + score_threshold: A float representing the threshold for deciding when to + remove boxes based on score. + name: A name for the operation (optional). + + Returns: + selected_indices: A 1-D integer `Tensor` of shape `[M]` representing the + selected indices from the overlaps tensor, where `M <= max_output_size`. + """ + with ops.name_scope(name, 'non_max_suppression_overlaps'): + overlap_threshold = ops.convert_to_tensor( + overlap_threshold, name='overlap_threshold') + # pylint: disable=protected-access + return gen_image_ops._non_max_suppression_v3( + overlaps, scores, max_output_size, overlap_threshold, score_threshold) + # pylint: enable=protected-access + + _rgb_to_yiq_kernel = [[0.299, 0.59590059, 0.2115], [0.587, -0.27455667, -0.52273617], [0.114, -0.32134392, 0.31119955]] diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 2a6ab26e96b2339f005cbe8712f7b8f15304b1cc..2c61bb232a1484c9180c05da74eb3e56e877dd69 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -1410,6 +1410,14 @@ class AdjustContrastTest(test_util.TensorFlowTestCase): y_tf = self._adjustContrastTf(x_np, contrast_factor) self.assertAllClose(y_tf, y_np, rtol=1e-5, atol=1e-5) + def testContrastFactorShape(self): + x_shape = [1, 2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + with self.assertRaisesRegexp( + ValueError, 'Shape must be rank 0 but is rank 1'): + image_ops.adjust_contrast(x_np, [2.0]) + class AdjustBrightnessTest(test_util.TensorFlowTestCase): @@ -1956,7 +1964,7 @@ class PadToBoundingBoxTest(test_util.TensorFlowTestCase): "all dims of 'image.shape' must be > 0", use_tensor_inputs_options=[False]) - # The orignal error message does not contain back slashes. However, they + # The original error message does not contain back slashes. However, they # are added by either the assert op or the runtime. If this behavior # changes in the future, the match string will also needs to be changed. self._assertRaises( @@ -2680,6 +2688,102 @@ class ResizeImagesTest(test_util.TensorFlowTestCase): self._assertResizeCheckShape(x, x_shape, [3840, 2160], [3840, 2160, 3]) +class ResizeImageWithPadTest(test_util.TensorFlowTestCase): + + def _ResizeImageWithPad(self, x, target_height, target_width, + use_tensor_inputs): + if use_tensor_inputs: + target_height = ops.convert_to_tensor(target_height) + target_width = ops.convert_to_tensor(target_width) + x_tensor = array_ops.placeholder(x.dtype, shape=[None] * x.ndim) + feed_dict = {x_tensor: x} + else: + x_tensor = x + feed_dict = {} + + y = image_ops.resize_image_with_pad(x_tensor, target_height, + target_width) + if not use_tensor_inputs: + self.assertTrue(y.get_shape().is_fully_defined()) + + with self.test_session(use_gpu=True): + return y.eval(feed_dict=feed_dict) + + def _assertReturns(self, + x, + x_shape, + y, + y_shape, + use_tensor_inputs_options=None): + use_tensor_inputs_options = use_tensor_inputs_options or [False, True] + target_height, target_width, _ = y_shape + x = np.array(x).reshape(x_shape) + y = np.array(y).reshape(y_shape) + + for use_tensor_inputs in use_tensor_inputs_options: + y_tf = self._ResizeImageWithPad(x, target_height, target_width, + use_tensor_inputs) + self.assertAllClose(y, y_tf) + + def _assertRaises(self, + x, + x_shape, + target_height, + target_width, + err_msg, + use_tensor_inputs_options=None): + use_tensor_inputs_options = use_tensor_inputs_options or [False, True] + x = np.array(x).reshape(x_shape) + + for use_tensor_inputs in use_tensor_inputs_options: + try: + self._ResizeImageWithPad(x, target_height, target_width, + use_tensor_inputs) + except Exception as e: # pylint: disable=broad-except + if err_msg not in str(e): + raise + else: + raise AssertionError("Exception not raised: %s" % err_msg) + + def _assertShapeInference(self, pre_shape, height, width, post_shape): + image = array_ops.placeholder(dtypes.float32, shape=pre_shape) + y = image_ops.resize_image_with_pad(image, height, width) + self.assertEqual(y.get_shape().as_list(), post_shape) + + def testNoOp(self): + x_shape = [10, 10, 10] + x = np.random.uniform(size=x_shape) + + self._assertReturns(x, x_shape, x, x_shape) + + def testPad(self): + # Reduce vertical dimension + x = [1, 2, 3, 4, 5, 6, 7, 8] + x_shape = [2, 4, 1] + + y = [0, 1, 3, 0] + y_shape = [1, 4, 1] + + self._assertReturns(x, x_shape, y, y_shape) + + # Reduce horizontal dimension + x = [1, 2, 3, 4, 5, 6, 7, 8] + x_shape = [2, 4, 1] + + y = [1, 3, 0, 0] + y_shape = [2, 2, 1] + + self._assertReturns(x, x_shape, y, y_shape) + + x = [1, 2, 3, 4, 5, 6, 7, 8] + x_shape = [2, 4, 1] + + y = [1, 3] + y_shape = [1, 2, 1] + + self._assertReturns(x, x_shape, y, y_shape) + + class ResizeImageWithCropOrPadTest(test_util.TensorFlowTestCase): def _ResizeImageWithCropOrPad(self, x, target_height, target_width, @@ -2889,7 +2993,7 @@ class ResizeImageWithCropOrPadTest(test_util.TensorFlowTestCase): "all dims of 'image.shape' must be > 0", use_tensor_inputs_options=[False]) - # The orignal error message does not contain back slashes. However, they + # The original error message does not contain back slashes. However, they # are added by either the assert op or the runtime. If this behavior # changes in the future, the match string will also needs to be changed. self._assertRaises( @@ -3105,7 +3209,8 @@ class PngTest(test_util.TensorFlowTestCase): def testExisting(self): # Read some real PNGs, converting to different channel numbers prefix = "tensorflow/core/lib/png/testdata/" - inputs = (1, "lena_gray.png"), (4, "lena_rgba.png") + inputs = ((1, "lena_gray.png"), (4, "lena_rgba.png"), + (3, "lena_palette.png"), (4, "lena_palette_trns.png")) for channels_in, filename in inputs: for channels in 0, 1, 3, 4: with self.test_session(use_gpu=True) as sess: @@ -3553,6 +3658,41 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase): image_ops.non_max_suppression(boxes, scores, 3, [[0.5]]) +class NonMaxSuppressionPaddedTest(test_util.TensorFlowTestCase): + + def testSelectFromThreeClusters(self): + boxes_np = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + scores_np = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + max_output_size_np = 5 + iou_threshold_np = 0.5 + boxes = constant_op.constant(boxes_np) + scores = constant_op.constant(scores_np) + max_output_size = constant_op.constant(max_output_size_np) + iou_threshold = constant_op.constant(iou_threshold_np) + selected_indices_padded, num_valid_padded = \ + image_ops.non_max_suppression_padded( + boxes, + scores, + max_output_size, + iou_threshold, + pad_to_max_output_size=True) + selected_indices, num_valid = image_ops.non_max_suppression_padded( + boxes, + scores, + max_output_size, + iou_threshold, + pad_to_max_output_size=False) + # The output shape of the padded operation must be fully defined. + self.assertEqual(selected_indices_padded.shape.is_fully_defined(), True) + self.assertEqual(selected_indices.shape.is_fully_defined(), False) + with self.test_session(): + self.assertAllClose(selected_indices_padded.eval(), [3, 0, 5, 0, 0]) + self.assertEqual(num_valid_padded.eval(), 3) + self.assertAllClose(selected_indices.eval(), [3, 0, 5]) + self.assertEqual(num_valid.eval(), 3) + + class VerifyCompatibleImageShapesTest(test_util.TensorFlowTestCase): """Tests utility function used by ssim() and psnr().""" diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py index 724fcc39cddbdd2a8acd9c0bbaa7b968c6d1510d..4d75ee3974807a3ec00f6813edea7072d2f0bd8d 100644 --- a/tensorflow/python/ops/init_ops.py +++ b/tensorflow/python/ops/init_ops.py @@ -43,7 +43,8 @@ from tensorflow.python.ops import linalg_ops_impl from tensorflow.python.ops import gen_linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.util.deprecation import deprecated +from tensorflow.python.util.deprecation import ( + deprecated, deprecated_arg_values) from tensorflow.python.util.tf_export import tf_export @@ -237,7 +238,7 @@ class RandomUniform(Initializer): maxval: A python scalar or a scalar tensor. Upper bound of the range of random values to generate. Defaults to 1 for float types. seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. dtype: The data type. """ @@ -275,7 +276,7 @@ class RandomNormal(Initializer): stddev: a python scalar or a scalar tensor. Standard deviation of the random values to generate. seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. dtype: The data type. Only floating point types are supported. """ @@ -318,7 +319,7 @@ class TruncatedNormal(Initializer): stddev: a python scalar or a scalar tensor. Standard deviation of the random values to generate. seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. dtype: The data type. Only floating point types are supported. """ @@ -368,7 +369,7 @@ class UniformUnitScaling(Initializer): Args: factor: Float. A multiplicative factor by which the values will be scaled. seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. dtype: The data type. Only floating point types are supported. """ @@ -409,8 +410,10 @@ class UniformUnitScaling(Initializer): class VarianceScaling(Initializer): """Initializer capable of adapting its scale to the shape of weights tensors. - With `distribution="normal"`, samples are drawn from a truncated normal - distribution centered on zero, with `stddev = sqrt(scale / n)` + With `distribution="truncated_normal" or "untruncated_normal"`, + samples are drawn from a truncated/untruncated normal + distribution with a mean of zero and a standard deviation (after truncation, + if used) `stddev = sqrt(scale / n)` where n is: - number of input units in the weight tensor, if mode = "fan_in" - number of output units, if mode = "fan_out" @@ -424,7 +427,7 @@ class VarianceScaling(Initializer): mode: One of "fan_in", "fan_out", "fan_avg". distribution: Random distribution to use. One of "normal", "uniform". seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. dtype: The data type. Only floating point types are supported. @@ -433,10 +436,14 @@ class VarianceScaling(Initializer): "distribution" arguments. """ + @deprecated_arg_values( + None, + "`normal` is a deprecated alias for `truncated_normal`", + distribution="normal") def __init__(self, scale=1.0, mode="fan_in", - distribution="normal", + distribution="truncated_normal", seed=None, dtype=dtypes.float32): if scale <= 0.: @@ -444,7 +451,8 @@ class VarianceScaling(Initializer): if mode not in {"fan_in", "fan_out", "fan_avg"}: raise ValueError("Invalid `mode` argument:", mode) distribution = distribution.lower() - if distribution not in {"normal", "uniform"}: + if distribution not in {"normal", "uniform", + "truncated_normal", "untruncated_normal"}: raise ValueError("Invalid `distribution` argument:", distribution) self.scale = scale self.mode = mode @@ -466,11 +474,15 @@ class VarianceScaling(Initializer): scale /= max(1., fan_out) else: scale /= max(1., (fan_in + fan_out) / 2.) - if self.distribution == "normal": + if self.distribution == "normal" or self.distribution == "truncated_normal": # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) stddev = math.sqrt(scale) / .87962566103423978 return random_ops.truncated_normal( shape, 0.0, stddev, dtype, seed=self.seed) + elif self.distribution == "untruncated_normal": + stddev = math.sqrt(scale) + return random_ops.random_normal( + shape, 0.0, stddev, dtype, seed=self.seed) else: limit = math.sqrt(3.0 * scale) return random_ops.random_uniform( @@ -505,7 +517,7 @@ class Orthogonal(Initializer): Args: gain: multiplicative factor to apply to the orthogonal matrix seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. dtype: The data type. """ @@ -551,14 +563,16 @@ class ConvolutionDeltaOrthogonal(Initializer): The shape of the tensor must have length 3, 4 or 5. The number of input filters must not exceed the number of output filters. The center pixels of the - tensor form an orthogonal matrix. Other pixels are set to be zero. + tensor form an orthogonal matrix. Other pixels are set to be zero. See + algorithm 2 in [Xiao et al., 2018]: https://arxiv.org/abs/1806.05393 + Args: gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. The 2-norm of an input is multiplied by a factor of 'sqrt(gain)' after applying this convolution. seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} for behavior. + `tf.set_random_seed` for behavior. dtype: The data type. """ @@ -614,7 +628,7 @@ class ConvolutionOrthogonal(Initializer): The 2-norm of an input is multiplied by a factor of 'sqrt(gain)' after applying this convolution. seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} for behavior. + `tf.set_random_seed` for behavior. dtype: The data type. """ @@ -672,13 +686,14 @@ class ConvolutionOrthogonal2D(ConvolutionOrthogonal): filters must not exceed the number of output filters. The orthogonality(==isometry) is exact when the inputs are circular padded. There are finite-width effects with non-circular padding (e.g. zero padding). + See algorithm 1 in [Xiao et al., 2018]: https://arxiv.org/abs/1806.05393 Args: gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. This has the effect of scaling the output 2-norm by a factor of `sqrt(gain)`. seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} for behavior. + `tf.set_random_seed` for behavior. dtype: The data type. """ @@ -807,13 +822,14 @@ class ConvolutionOrthogonal1D(ConvolutionOrthogonal): filters must not exceed the number of output filters. The orthogonality(==isometry) is exact when the inputs are circular padded. There are finite-width effects with non-circular padding (e.g. zero padding). + See algorithm 1 in [Xiao et al., 2018]: https://arxiv.org/abs/1806.05393 Args: gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. The 2-norm of an input is multiplied by a factor of 'sqrt(gain)' after applying this convolution. seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. dtype: The data type. """ @@ -923,13 +939,14 @@ class ConvolutionOrthogonal3D(ConvolutionOrthogonal): filters must not exceed the number of output filters. The orthogonality(==isometry) is exact when the inputs are circular padded. There are finite-width effects with non-circular padding (e.g. zero padding). + See algorithm 1 [Xiao et al., 2018] in: https://arxiv.org/abs/1806.05393 Args: gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. The 2-norm of an input is multiplied by a factor of 'sqrt(gain)' after applying this convolution. seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} for behavior. + `tf.set_random_seed` for behavior. dtype: The data type. """ @@ -1119,7 +1136,8 @@ convolutional_orthogonal_3d = ConvolutionOrthogonal3D # pylint: enable=invalid-name -@tf_export("glorot_uniform_initializer") +@tf_export("glorot_uniform_initializer", "keras.initializers.glorot_uniform", + "initializers.glorot_uniform") def glorot_uniform_initializer(seed=None, dtype=dtypes.float32): """The Glorot uniform initializer, also called Xavier uniform initializer. @@ -1132,7 +1150,7 @@ def glorot_uniform_initializer(seed=None, dtype=dtypes.float32): Args: seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. dtype: The data type. Only floating point types are supported. @@ -1143,7 +1161,8 @@ def glorot_uniform_initializer(seed=None, dtype=dtypes.float32): scale=1.0, mode="fan_avg", distribution="uniform", seed=seed, dtype=dtype) -@tf_export("glorot_normal_initializer") +@tf_export("glorot_normal_initializer", "keras.initializers.glorot_normal", + "initializers.glorot_normal") def glorot_normal_initializer(seed=None, dtype=dtypes.float32): """The Glorot normal initializer, also called Xavier normal initializer. @@ -1156,7 +1175,7 @@ def glorot_normal_initializer(seed=None, dtype=dtypes.float32): Args: seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. dtype: The data type. Only floating point types are supported. @@ -1164,7 +1183,98 @@ def glorot_normal_initializer(seed=None, dtype=dtypes.float32): An initializer. """ return variance_scaling_initializer( - scale=1.0, mode="fan_avg", distribution="normal", seed=seed, dtype=dtype) + scale=1.0, + mode="fan_avg", + distribution="truncated_normal", + seed=seed, + dtype=dtype) + + +@tf_export("keras.initializers.lecun_normal", "initializers.lecun_normal") +def lecun_normal(seed=None): + """LeCun normal initializer. + + It draws samples from a truncated normal distribution centered on 0 + with `stddev = sqrt(1 / fan_in)` + where `fan_in` is the number of input units in the weight tensor. + + Arguments: + seed: A Python integer. Used to seed the random generator. + + Returns: + An initializer. + + References: + - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) + - [Efficient + Backprop](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf) + """ + return VarianceScaling( + scale=1., mode="fan_in", distribution="truncated_normal", seed=seed) + + +@tf_export("keras.initializers.lecun_uniform", "initializers.lecun_uniform") +def lecun_uniform(seed=None): + """LeCun uniform initializer. + + It draws samples from a uniform distribution within [-limit, limit] + where `limit` is `sqrt(3 / fan_in)` + where `fan_in` is the number of input units in the weight tensor. + + Arguments: + seed: A Python integer. Used to seed the random generator. + + Returns: + An initializer. + + References: + LeCun 98, Efficient Backprop, + http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf + """ + return VarianceScaling( + scale=1., mode="fan_in", distribution="uniform", seed=seed) + + +@tf_export("keras.initializers.he_normal", "initializers.he_normal") +def he_normal(seed=None): + """He normal initializer. + + It draws samples from a truncated normal distribution centered on 0 + with `stddev = sqrt(2 / fan_in)` + where `fan_in` is the number of input units in the weight tensor. + + Arguments: + seed: A Python integer. Used to seed the random generator. + + Returns: + An initializer. + + References: + He et al., http://arxiv.org/abs/1502.01852 + """ + return VarianceScaling( + scale=2., mode="fan_in", distribution="truncated_normal", seed=seed) + + +@tf_export("keras.initializers.he_uniform", "initializers.he_uniform") +def he_uniform(seed=None): + """He uniform variance scaling initializer. + + It draws samples from a uniform distribution within [-limit, limit] + where `limit` is `sqrt(6 / fan_in)` + where `fan_in` is the number of input units in the weight tensor. + + Arguments: + seed: A Python integer. Used to seed the random generator. + + Returns: + An initializer. + + References: + He et al., http://arxiv.org/abs/1502.01852 + """ + return VarianceScaling( + scale=2., mode="fan_in", distribution="uniform", seed=seed) # Utility functions. diff --git a/tensorflow/python/ops/init_ops_test.py b/tensorflow/python/ops/init_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f6fffa907951e5c09a4c1d59e2bdc7f28d86806b --- /dev/null +++ b/tensorflow/python/ops/init_ops_test.py @@ -0,0 +1,196 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for initializers in init_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test + + +class InitializersTest(test.TestCase): + + def _runner(self, + init, + shape, + target_mean=None, + target_std=None, + target_max=None, + target_min=None): + variable = resource_variable_ops.ResourceVariable(init(shape)) + if context.executing_eagerly(): + output = variable.numpy() + else: + sess = ops.get_default_session() + sess.run(variable.initializer) + output = sess.run(variable) + lim = 3e-2 + if target_std is not None: + self.assertGreater(lim, abs(output.std() - target_std)) + if target_mean is not None: + self.assertGreater(lim, abs(output.mean() - target_mean)) + if target_max is not None: + self.assertGreater(lim, abs(output.max() - target_max)) + if target_min is not None: + self.assertGreater(lim, abs(output.min() - target_min)) + + def test_uniform(self): + tensor_shape = (9, 6, 7) + with self.test_session(): + self._runner( + init_ops.RandomUniform(minval=-1, maxval=1, seed=124), + tensor_shape, + target_mean=0., + target_max=1, + target_min=-1) + + def test_normal(self): + tensor_shape = (8, 12, 99) + with self.test_session(): + self._runner( + init_ops.RandomNormal(mean=0, stddev=1, seed=153), + tensor_shape, + target_mean=0., + target_std=1) + + def test_truncated_normal(self): + tensor_shape = (12, 99, 7) + with self.test_session(): + self._runner( + init_ops.TruncatedNormal(mean=0, stddev=1, seed=126), + tensor_shape, + target_mean=0., + target_max=2, + target_min=-2) + + def test_constant(self): + tensor_shape = (5, 6, 4) + with self.test_session(): + self._runner( + init_ops.Constant(2), + tensor_shape, + target_mean=2, + target_max=2, + target_min=2) + + def test_lecun_uniform(self): + tensor_shape = (5, 6, 4, 2) + with self.test_session(): + fan_in, _ = init_ops._compute_fans(tensor_shape) + std = np.sqrt(1. / fan_in) + self._runner( + init_ops.lecun_uniform(seed=123), + tensor_shape, + target_mean=0., + target_std=std) + + def test_glorot_uniform_initializer(self): + tensor_shape = (5, 6, 4, 2) + with self.test_session(): + fan_in, fan_out = init_ops._compute_fans(tensor_shape) + std = np.sqrt(2. / (fan_in + fan_out)) + self._runner( + init_ops.glorot_uniform_initializer(seed=123), + tensor_shape, + target_mean=0., + target_std=std) + + def test_he_uniform(self): + tensor_shape = (5, 6, 4, 2) + with self.test_session(): + fan_in, _ = init_ops._compute_fans(tensor_shape) + std = np.sqrt(2. / fan_in) + self._runner( + init_ops.he_uniform(seed=123), + tensor_shape, + target_mean=0., + target_std=std) + + def test_lecun_normal(self): + tensor_shape = (5, 6, 4, 2) + with self.test_session(): + fan_in, _ = init_ops._compute_fans(tensor_shape) + std = np.sqrt(1. / fan_in) + self._runner( + init_ops.lecun_normal(seed=123), + tensor_shape, + target_mean=0., + target_std=std) + + def test_glorot_normal_initializer(self): + tensor_shape = (5, 6, 4, 2) + with self.test_session(): + fan_in, fan_out = init_ops._compute_fans(tensor_shape) + std = np.sqrt(2. / (fan_in + fan_out)) + self._runner( + init_ops.glorot_normal_initializer(seed=123), + tensor_shape, + target_mean=0., + target_std=std) + + def test_he_normal(self): + tensor_shape = (5, 6, 4, 2) + with self.test_session(): + fan_in, _ = init_ops._compute_fans(tensor_shape) + std = np.sqrt(2. / fan_in) + self._runner( + init_ops.he_normal(seed=123), + tensor_shape, + target_mean=0., + target_std=std) + + def test_Orthogonal(self): + tensor_shape = (20, 20) + with self.test_session(): + self._runner(init_ops.Orthogonal(seed=123), tensor_shape, target_mean=0.) + + def test_Identity(self): + with self.test_session(): + tensor_shape = (3, 4, 5) + with self.assertRaises(ValueError): + self._runner( + init_ops.Identity(), + tensor_shape, + target_mean=1. / tensor_shape[0], + target_max=1.) + + tensor_shape = (3, 3) + self._runner( + init_ops.Identity(), + tensor_shape, + target_mean=1. / tensor_shape[0], + target_max=1.) + + def test_Zeros(self): + tensor_shape = (4, 5) + with self.test_session(): + self._runner( + init_ops.Zeros(), tensor_shape, target_mean=0., target_max=0.) + + def test_Ones(self): + tensor_shape = (4, 5) + with self.test_session(): + self._runner(init_ops.Ones(), tensor_shape, target_mean=1., target_max=1.) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/ops/linalg/BUILD b/tensorflow/python/ops/linalg/BUILD index 07659ef44c443ad15876781d6c6254ae3bc38660..c7314d77749130e4696d58896249b73cc2de4a12 100644 --- a/tensorflow/python/ops/linalg/BUILD +++ b/tensorflow/python/ops/linalg/BUILD @@ -29,6 +29,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:special_math_ops", diff --git a/tensorflow/python/ops/linalg/linalg.py b/tensorflow/python/ops/linalg/linalg.py index a7ba0bbe9cbc4be9daea79cc97eaac4c21523c04..c29b5033bb137e8376e1c19985755b4fc72e8834 100644 --- a/tensorflow/python/ops/linalg/linalg.py +++ b/tensorflow/python/ops/linalg/linalg.py @@ -31,6 +31,7 @@ from tensorflow.python.ops.linalg.linear_operator_identity import * from tensorflow.python.ops.linalg.linear_operator_kronecker import * from tensorflow.python.ops.linalg.linear_operator_low_rank_update import * from tensorflow.python.ops.linalg.linear_operator_lower_triangular import * +from tensorflow.python.ops.linalg.linear_operator_zeros import * # pylint: enable=wildcard-import # Seal API. diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py index 8343c62816c6aeadc77dae701ae9917a86e68954..1e3d81798060548d98487f3426184df2df72f123 100644 --- a/tensorflow/python/ops/linalg/linalg_impl.py +++ b/tensorflow/python/ops/linalg/linalg_impl.py @@ -18,8 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_linalg_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops @@ -38,8 +41,6 @@ diag_part = array_ops.matrix_diag_part eigh = linalg_ops.self_adjoint_eig eigvalsh = linalg_ops.self_adjoint_eigvals einsum = special_math_ops.einsum -expm = gen_linalg_ops.matrix_exponential -tf_export('linalg.expm')(expm) eye = linalg_ops.eye inv = linalg_ops.matrix_inverse logm = gen_linalg_ops.matrix_logarithm @@ -114,3 +115,214 @@ def adjoint(matrix, name=None): with ops.name_scope(name, 'adjoint', [matrix]): matrix = ops.convert_to_tensor(matrix, name='matrix') return array_ops.matrix_transpose(matrix, conjugate=True) + + +# This section is ported nearly verbatim from Eigen's implementation: +# https://eigen.tuxfamily.org/dox/unsupported/MatrixExponential_8h_source.html +def _matrix_exp_pade3(matrix): + """3rd-order Pade approximant for matrix exponential.""" + b = [120.0, 60.0, 12.0] + b = [constant_op.constant(x, matrix.dtype) for x in b] + ident = linalg_ops.eye(array_ops.shape(matrix)[-2], + batch_shape=array_ops.shape(matrix)[:-2], + dtype=matrix.dtype) + matrix_2 = math_ops.matmul(matrix, matrix) + tmp = matrix_2 + b[1] * ident + matrix_u = math_ops.matmul(matrix, tmp) + matrix_v = b[2] * matrix_2 + b[0] * ident + return matrix_u, matrix_v + + +def _matrix_exp_pade5(matrix): + """5th-order Pade approximant for matrix exponential.""" + b = [30240.0, 15120.0, 3360.0, 420.0, 30.0] + b = [constant_op.constant(x, matrix.dtype) for x in b] + ident = linalg_ops.eye(array_ops.shape(matrix)[-2], + batch_shape=array_ops.shape(matrix)[:-2], + dtype=matrix.dtype) + matrix_2 = math_ops.matmul(matrix, matrix) + matrix_4 = math_ops.matmul(matrix_2, matrix_2) + tmp = matrix_4 + b[3] * matrix_2 + b[1] * ident + matrix_u = math_ops.matmul(matrix, tmp) + matrix_v = b[4] * matrix_4 + b[2] * matrix_2 + b[0] * ident + return matrix_u, matrix_v + + +def _matrix_exp_pade7(matrix): + """7th-order Pade approximant for matrix exponential.""" + b = [17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0] + b = [constant_op.constant(x, matrix.dtype) for x in b] + ident = linalg_ops.eye(array_ops.shape(matrix)[-2], + batch_shape=array_ops.shape(matrix)[:-2], + dtype=matrix.dtype) + matrix_2 = math_ops.matmul(matrix, matrix) + matrix_4 = math_ops.matmul(matrix_2, matrix_2) + matrix_6 = math_ops.matmul(matrix_4, matrix_2) + tmp = matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + b[1] * ident + matrix_u = math_ops.matmul(matrix, tmp) + matrix_v = b[6] * matrix_6 + b[4] * matrix_4 + b[2] * matrix_2 + b[0] * ident + return matrix_u, matrix_v + + +def _matrix_exp_pade9(matrix): + """9th-order Pade approximant for matrix exponential.""" + b = [ + 17643225600.0, 8821612800.0, 2075673600.0, 302702400.0, 30270240.0, + 2162160.0, 110880.0, 3960.0, 90.0 + ] + b = [constant_op.constant(x, matrix.dtype) for x in b] + ident = linalg_ops.eye(array_ops.shape(matrix)[-2], + batch_shape=array_ops.shape(matrix)[:-2], + dtype=matrix.dtype) + matrix_2 = math_ops.matmul(matrix, matrix) + matrix_4 = math_ops.matmul(matrix_2, matrix_2) + matrix_6 = math_ops.matmul(matrix_4, matrix_2) + matrix_8 = math_ops.matmul(matrix_6, matrix_2) + tmp = ( + matrix_8 + b[7] * matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + + b[1] * ident) + matrix_u = math_ops.matmul(matrix, tmp) + matrix_v = ( + b[8] * matrix_8 + b[6] * matrix_6 + b[4] * matrix_4 + b[2] * matrix_2 + + b[0] * ident) + return matrix_u, matrix_v + + +def _matrix_exp_pade13(matrix): + """13th-order Pade approximant for matrix exponential.""" + b = [ + 64764752532480000.0, 32382376266240000.0, 7771770303897600.0, + 1187353796428800.0, 129060195264000.0, 10559470521600.0, 670442572800.0, + 33522128640.0, 1323241920.0, 40840800.0, 960960.0, 16380.0, 182.0 + ] + b = [constant_op.constant(x, matrix.dtype) for x in b] + ident = linalg_ops.eye(array_ops.shape(matrix)[-2], + batch_shape=array_ops.shape(matrix)[:-2], + dtype=matrix.dtype) + matrix_2 = math_ops.matmul(matrix, matrix) + matrix_4 = math_ops.matmul(matrix_2, matrix_2) + matrix_6 = math_ops.matmul(matrix_4, matrix_2) + tmp_u = ( + math_ops.matmul(matrix_6, + matrix_6 + b[11] * matrix_4 + b[9] * matrix_2) + + b[7] * matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + b[1] * ident) + matrix_u = math_ops.matmul(matrix, tmp_u) + tmp_v = b[12] * matrix_6 + b[10] * matrix_4 + b[8] * matrix_2 + matrix_v = ( + math_ops.matmul(matrix_6, tmp_v) + b[6] * matrix_6 + b[4] * matrix_4 + + b[2] * matrix_2 + b[0] * ident) + return matrix_u, matrix_v + + +@tf_export('linalg.expm') +def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin + r"""Computes the matrix exponential of one or more square matrices. + + exp(A) = \sum_{n=0}^\infty A^n/n! + + The exponential is computed using a combination of the scaling and squaring + method and the Pade approximation. Details can be found in: + Nicholas J. Higham, "The scaling and squaring method for the matrix + exponential revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005. + + The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions + form square matrices. The output is a tensor of the same shape as the input + containing the exponential for all input submatrices `[..., :, :]`. + + Args: + input: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, + or `complex128` with shape `[..., M, M]`. + name: A name to give this `Op` (optional). + + Returns: + the matrix exponential of the input. + + Raises: + ValueError: An unsupported type is provided as input. + + @compatibility(scipy) + Equivalent to scipy.linalg.expm + @end_compatibility + """ + with ops.name_scope(name, 'matrix_exponential', [input]): + matrix = ops.convert_to_tensor(input, name='input') + if matrix.shape[-2:] == [0, 0]: + return matrix + batch_shape = matrix.shape[:-2] + if not batch_shape.is_fully_defined(): + batch_shape = array_ops.shape(matrix)[:-2] + + # reshaping the batch makes the where statements work better + matrix = array_ops.reshape( + matrix, array_ops.concat(([-1], array_ops.shape(matrix)[-2:]), axis=0)) + l1_norm = math_ops.reduce_max( + math_ops.reduce_sum(math_ops.abs(matrix), + axis=array_ops.size(array_ops.shape(matrix)) - 2), + axis=-1) + const = lambda x: constant_op.constant(x, l1_norm.dtype) + def _nest_where(vals, cases): + assert len(vals) == len(cases) - 1 + if len(vals) == 1: + return array_ops.where( + math_ops.less(l1_norm, const(vals[0])), cases[0], cases[1]) + else: + return array_ops.where( + math_ops.less(l1_norm, const(vals[0])), cases[0], + _nest_where(vals[1:], cases[1:])) + + if matrix.dtype in [dtypes.float16, dtypes.float32, dtypes.complex64]: + maxnorm = const(3.925724783138660) + squarings = math_ops.maximum( + math_ops.floor( + math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0) + u3, v3 = _matrix_exp_pade3(matrix) + u5, v5 = _matrix_exp_pade5(matrix) + u7, v7 = _matrix_exp_pade7( + matrix / math_ops.pow( + constant_op.constant(2.0, dtype=matrix.dtype), + math_ops.cast(squarings, matrix.dtype))[..., + array_ops.newaxis, + array_ops.newaxis]) + conds = (4.258730016922831e-001, 1.880152677804762e+000) + u = _nest_where(conds, (u3, u5, u7)) + v = _nest_where(conds, (v3, v5, v7)) + elif matrix.dtype in [dtypes.float64, dtypes.complex128]: + maxnorm = const(5.371920351148152) + squarings = math_ops.maximum( + math_ops.floor( + math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0) + u3, v3 = _matrix_exp_pade3(matrix) + u5, v5 = _matrix_exp_pade5(matrix) + u7, v7 = _matrix_exp_pade7(matrix) + u9, v9 = _matrix_exp_pade9(matrix) + u13, v13 = _matrix_exp_pade13( + matrix / math_ops.pow( + constant_op.constant(2.0, dtype=matrix.dtype), + math_ops.cast(squarings, matrix.dtype))[..., + array_ops.newaxis, + array_ops.newaxis]) + conds = (1.495585217958292e-002, + 2.539398330063230e-001, + 9.504178996162932e-001, + 2.097847961257068e+000) + u = _nest_where(conds, (u3, u5, u7, u9, u13)) + v = _nest_where(conds, (v3, v5, v7, v9, v13)) + else: + raise ValueError( + 'tf.linalg.expm does not support matrices of type %s' % matrix.dtype) + numer = u + v + denom = -u + v + result = linalg_ops.matrix_solve(denom, numer) + max_squarings = math_ops.reduce_max(squarings) + + i = const(0.0) + c = lambda i, r: math_ops.less(i, max_squarings) + def b(i, r): + return i+1, array_ops.where(math_ops.less(i, squarings), + math_ops.matmul(r, r), r) + _, result = control_flow_ops.while_loop(c, b, [i, result]) + if not matrix.shape.is_fully_defined(): + return array_ops.reshape( + result, + array_ops.concat((batch_shape, array_ops.shape(result)[-2:]), axis=0)) + return array_ops.reshape(result, batch_shape.concatenate(result.shape[-2:])) diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py index 8cfe964b1c0a572f43a14c66885e74ea105b0916..20c46fbb82b0671c6cc586eafdd7fa346d8b4e6d 100644 --- a/tensorflow/python/ops/linalg/linear_operator.py +++ b/tensorflow/python/ops/linalg/linear_operator.py @@ -42,7 +42,7 @@ __all__ = ["LinearOperator"] class LinearOperator(object): """Base class defining a [batch of] linear operator[s]. - Subclasses of `LinearOperator` provide a access to common methods on a + Subclasses of `LinearOperator` provide access to common methods on a (batch) matrix, without the need to materialize the matrix. This allows: * Matrix free computations @@ -69,11 +69,11 @@ class LinearOperator(object): #### Shape compatibility - `LinearOperator` sub classes should operate on a [batch] matrix with + `LinearOperator` subclasses should operate on a [batch] matrix with compatible shape. Class docstrings should define what is meant by compatible - shape. Some sub-classes may not support batching. + shape. Some subclasses may not support batching. - An example is: + Examples: `x` is a batch matrix with compatible shape for `matmul` if diff --git a/tensorflow/python/ops/linalg/linear_operator_diag.py b/tensorflow/python/ops/linalg/linear_operator_diag.py index 5beaea65a5171ad7e92042a2afa81c0507e51d0e..ed53decc00dc90df5c6c97d9fd9d5cb124ddf660 100644 --- a/tensorflow/python/ops/linalg/linear_operator_diag.py +++ b/tensorflow/python/ops/linalg/linear_operator_diag.py @@ -231,8 +231,11 @@ class LinearOperatorDiag(linear_operator.LinearOperator): return math_ops.reduce_prod(self._diag, reduction_indices=[-1]) def _log_abs_determinant(self): - return math_ops.reduce_sum( + log_det = math_ops.reduce_sum( math_ops.log(math_ops.abs(self._diag)), reduction_indices=[-1]) + if self.dtype.is_complex: + log_det = math_ops.cast(log_det, dtype=self.dtype) + return log_det def _solve(self, rhs, adjoint=False, adjoint_arg=False): diag_term = math_ops.conj(self._diag) if adjoint else self._diag diff --git a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py index 08e5896e1034fb1782beacfb18fef16da083bded..2b2bf80f276a62d20aae717ac9fa08f9769f455e 100644 --- a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py +++ b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py @@ -18,16 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.linalg import linear_operator from tensorflow.python.ops.linalg import linear_operator_diag from tensorflow.python.ops.linalg import linear_operator_identity from tensorflow.python.ops.linalg import linear_operator_util +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -153,8 +152,7 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): `is_X` matrix property hints, which will trigger the appropriate code path. Args: - base_operator: Shape `[B1,...,Bb, M, N]` real `float16`, `float32` or - `float64` `LinearOperator`. This is `L` above. + base_operator: Shape `[B1,...,Bb, M, N]`. u: Shape `[B1,...,Bb, M, K]` `Tensor` of same `dtype` as `base_operator`. This is `U` above. diag_update: Optional shape `[B1,...,Bb, K]` `Tensor` with same `dtype` @@ -183,23 +181,12 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): Raises: ValueError: If `is_X` flags are set in an inconsistent way. """ - # TODO(langmore) support complex types. - # Complex types are not allowed due to tf.cholesky() requiring float. - # If complex dtypes are allowed, we update the following - # 1. is_diag_update_positive should still imply that `diag > 0`, but we need - # to remind the user that this implies diag is real. This is needed - # because if diag has non-zero imaginary part, it will not be - # self-adjoint positive definite. dtype = base_operator.dtype - allowed_dtypes = [ - dtypes.float16, - dtypes.float32, - dtypes.float64, - ] - if dtype not in allowed_dtypes: - raise TypeError( - "Argument matrix must have dtype in %s. Found: %s" - % (allowed_dtypes, dtype)) + + if diag_update is not None: + if is_diag_update_positive and dtype.is_complex: + logging.warn("Note: setting is_diag_update_positive with a complex " + "dtype means that diagonal is real and positive.") if diag_update is None: if is_diag_update_positive is False: @@ -271,8 +258,6 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): self._set_diag_operators(diag_update, is_diag_update_positive) self._is_diag_update_positive = is_diag_update_positive - check_ops.assert_same_float_dtype((base_operator, self.u, self.v, - self._diag_update)) self._check_shapes() # Pre-compute the so-called "capacitance" matrix @@ -407,6 +392,8 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): else: det_c = linalg_ops.matrix_determinant(self._capacitance) log_abs_det_c = math_ops.log(math_ops.abs(det_c)) + if self.dtype.is_complex: + log_abs_det_c = math_ops.cast(log_abs_det_c, dtype=self.dtype) return log_abs_det_c + log_abs_det_d + log_abs_det_l diff --git a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py index fb1eb2fedba5b47ce38f9635527b91e18d894a8f..ca6d3f54051d7bf0ff748804d3cd314b144c2f88 100644 --- a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py +++ b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py @@ -119,8 +119,7 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator): Args: tril: Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`. The lower triangular part of `tril` defines this operator. The strictly - upper triangle is ignored. Allowed dtypes: `float16`, `float32`, - `float64`. + upper triangle is ignored. is_non_singular: Expect that this operator is non-singular. This operator is non-singular if and only if its diagonal elements are all non-zero. @@ -137,7 +136,6 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator): name: A name for this `LinearOperator`. Raises: - TypeError: If `diag.dtype` is not an allowed type. ValueError: If `is_square` is `False`. """ @@ -163,12 +161,12 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator): def _check_tril(self, tril): """Static check of the `tril` argument.""" - # TODO(langmore) Add complex types once matrix_triangular_solve works for - # them. allowed_dtypes = [ dtypes.float16, dtypes.float32, dtypes.float64, + dtypes.complex64, + dtypes.complex128, ] dtype = tril.dtype if dtype not in allowed_dtypes: diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py index 1b5bb9470c4406ad075f2f6d5c38661311472727..78c85db557047ebcc3dd655deae62acbcef929c7 100644 --- a/tensorflow/python/ops/linalg/linear_operator_test_util.py +++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py @@ -102,7 +102,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): raise NotImplementedError("operator_build_infos has not been implemented.") @abc.abstractmethod - def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + def _operator_and_matrix(self, build_info, dtype, use_placeholder): """Build a batch matrix and an Operator that should have similar behavior. Every operator acts like a (batch) matrix. This method returns both @@ -118,9 +118,6 @@ class LinearOperatorDerivedClassTest(test.TestCase): Returns: operator: `LinearOperator` subclass instance. mat: `Tensor` representing operator. - feed_dict: Dictionary. - If placholder is True, this must contains everything needed to be fed - to sess.run calls at runtime to make the operator work. """ # Create a matrix as a numpy array with desired shape/dtype. # Create a LinearOperator that should have the same behavior as the matrix. @@ -189,12 +186,12 @@ class LinearOperatorDerivedClassTest(test.TestCase): for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) op_dense = operator.to_dense() if not use_placeholder: self.assertAllEqual(build_info.shape, op_dense.get_shape()) - op_dense_v, mat_v = sess.run([op_dense, mat], feed_dict=feed_dict) + op_dense_v, mat_v = sess.run([op_dense, mat]) self.assertAC(op_dense_v, mat_v) def test_det(self): @@ -204,14 +201,13 @@ class LinearOperatorDerivedClassTest(test.TestCase): for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) op_det = operator.determinant() if not use_placeholder: self.assertAllEqual(build_info.shape[:-2], op_det.get_shape()) op_det_v, mat_det_v = sess.run( - [op_det, linalg_ops.matrix_determinant(mat)], - feed_dict=feed_dict) + [op_det, linalg_ops.matrix_determinant(mat)]) self.assertAC(op_det_v, mat_det_v) def test_log_abs_det(self): @@ -221,7 +217,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) op_log_abs_det = operator.log_abs_determinant() _, mat_log_abs_det = linalg.slogdet(mat) @@ -229,7 +225,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): self.assertAllEqual( build_info.shape[:-2], op_log_abs_det.get_shape()) op_log_abs_det_v, mat_log_abs_det_v = sess.run( - [op_log_abs_det, mat_log_abs_det], feed_dict=feed_dict) + [op_log_abs_det, mat_log_abs_det]) self.assertAC(op_log_abs_det_v, mat_log_abs_det_v) def _test_matmul(self, with_batch): @@ -246,7 +242,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): for adjoint_arg in self._adjoint_arg_options: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) x = self._make_x( operator, adjoint=adjoint, with_batch=with_batch) @@ -264,7 +260,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): self.assertAllEqual(op_matmul.get_shape(), mat_matmul.get_shape()) op_matmul_v, mat_matmul_v = sess.run( - [op_matmul, mat_matmul], feed_dict=feed_dict) + [op_matmul, mat_matmul]) self.assertAC(op_matmul_v, mat_matmul_v) def test_matmul(self): @@ -289,7 +285,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): for adjoint_arg in self._adjoint_arg_options: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) rhs = self._make_rhs( operator, adjoint=adjoint, with_batch=with_batch) @@ -307,8 +303,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): if not use_placeholder: self.assertAllEqual(op_solve.get_shape(), mat_solve.get_shape()) - op_solve_v, mat_solve_v = sess.run( - [op_solve, mat_solve], feed_dict=feed_dict) + op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve]) self.assertAC(op_solve_v, mat_solve_v) def test_solve(self): @@ -326,14 +321,13 @@ class LinearOperatorDerivedClassTest(test.TestCase): for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) op_trace = operator.trace() mat_trace = math_ops.trace(mat) if not use_placeholder: self.assertAllEqual(op_trace.get_shape(), mat_trace.get_shape()) - op_trace_v, mat_trace_v = sess.run( - [op_trace, mat_trace], feed_dict=feed_dict) + op_trace_v, mat_trace_v = sess.run([op_trace, mat_trace]) self.assertAC(op_trace_v, mat_trace_v) def test_add_to_tensor(self): @@ -343,15 +337,14 @@ class LinearOperatorDerivedClassTest(test.TestCase): for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) op_plus_2mat = operator.add_to_tensor(2 * mat) if not use_placeholder: self.assertAllEqual(build_info.shape, op_plus_2mat.get_shape()) - op_plus_2mat_v, mat_v = sess.run( - [op_plus_2mat, mat], feed_dict=feed_dict) + op_plus_2mat_v, mat_v = sess.run([op_plus_2mat, mat]) self.assertAC(op_plus_2mat_v, 3 * mat_v) @@ -362,7 +355,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED - operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + operator, mat = self._operator_and_matrix( build_info, dtype, use_placeholder=use_placeholder) op_diag_part = operator.diag_part() mat_diag_part = array_ops.matrix_diag_part(mat) @@ -372,7 +365,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): op_diag_part.get_shape()) op_diag_part_, mat_diag_part_ = sess.run( - [op_diag_part, mat_diag_part], feed_dict=feed_dict) + [op_diag_part, mat_diag_part]) self.assertAC(op_diag_part_, mat_diag_part_) diff --git a/tensorflow/python/ops/linalg/linear_operator_zeros.py b/tensorflow/python/ops/linalg/linear_operator_zeros.py new file mode 100644 index 0000000000000000000000000000000000000000..b8a79c065b32f452cfbb49c6bbd485556cc79445 --- /dev/null +++ b/tensorflow/python/ops/linalg/linear_operator_zeros.py @@ -0,0 +1,452 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""`LinearOperator` acting like a zero matrix.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.linalg import linalg_impl as linalg +from tensorflow.python.ops.linalg import linear_operator +from tensorflow.python.ops.linalg import linear_operator_util +from tensorflow.python.util.tf_export import tf_export + +__all__ = [ + "LinearOperatorZeros", +] + + +@tf_export("linalg.LinearOperatorZeros") +class LinearOperatorZeros(linear_operator.LinearOperator): + """`LinearOperator` acting like a [batch] zero matrix. + + This operator acts like a [batch] zero matrix `A` with shape + `[B1,...,Bb, N, M]` for some `b >= 0`. The first `b` indices index a + batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is + an `N x M` matrix. This matrix `A` is not materialized, but for + purposes of broadcasting this shape will be relevant. + + `LinearOperatorZeros` is initialized with `num_rows`, and optionally + `num_columns, `batch_shape`, and `dtype` arguments. If `num_columns` is + `None`, then this operator will be initialized as a square matrix. If + `batch_shape` is `None`, this operator efficiently passes through all + arguments. If `batch_shape` is provided, broadcasting may occur, which will + require making copies. + + ```python + # Create a 2 x 2 zero matrix. + operator = LinearOperatorZero(num_rows=2, dtype=tf.float32) + + operator.to_dense() + ==> [[0., 0.] + [0., 0.]] + + operator.shape + ==> [2, 2] + + operator.determinant() + ==> 0. + + x = ... Shape [2, 4] Tensor + operator.matmul(x) + ==> Shape [2, 4] Tensor, same as x. + + # Create a 2-batch of 2x2 zero matrices + operator = LinearOperatorZeros(num_rows=2, batch_shape=[2]) + operator.to_dense() + ==> [[[0., 0.] + [0., 0.]], + [[0., 0.] + [0., 0.]]] + + # Here, even though the operator has a batch shape, the input is the same as + # the output, so x can be passed through without a copy. The operator is able + # to detect that no broadcast is necessary because both x and the operator + # have statically defined shape. + x = ... Shape [2, 2, 3] + operator.matmul(x) + ==> Shape [2, 2, 3] Tensor, same as tf.zeros_like(x) + + # Here the operator and x have different batch_shape, and are broadcast. + # This requires a copy, since the output is different size than the input. + x = ... Shape [1, 2, 3] + operator.matmul(x) + ==> Shape [2, 2, 3] Tensor, equal to tf.zeros_like([x, x]) + ``` + + ### Shape compatibility + + This operator acts on [batch] matrix with compatible shape. + `x` is a batch matrix with compatible shape for `matmul` and `solve` if + + ``` + operator.shape = [B1,...,Bb] + [N, M], with b >= 0 + x.shape = [C1,...,Cc] + [M, R], + and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd] + ``` + + #### Matrix property hints + + This `LinearOperator` is initialized with boolean flags of the form `is_X`, + for `X = non_singular, self_adjoint, positive_definite, square`. + These have the following meaning: + + * If `is_X == True`, callers should expect the operator to have the + property `X`. This is a promise that should be fulfilled, but is *not* a + runtime assert. For example, finite floating point precision may result + in these promises being violated. + * If `is_X == False`, callers should expect the operator to not have `X`. + * If `is_X == None` (the default), callers should have no expectation either + way. + """ + + def __init__(self, + num_rows, + num_columns=None, + batch_shape=None, + dtype=None, + is_non_singular=False, + is_self_adjoint=True, + is_positive_definite=False, + is_square=True, + assert_proper_shapes=False, + name="LinearOperatorZeros"): + r"""Initialize a `LinearOperatorZeros`. + + The `LinearOperatorZeros` is initialized with arguments defining `dtype` + and shape. + + This operator is able to broadcast the leading (batch) dimensions, which + sometimes requires copying data. If `batch_shape` is `None`, the operator + can take arguments of any batch shape without copying. See examples. + + Args: + num_rows: Scalar non-negative integer `Tensor`. Number of rows in the + corresponding zero matrix. + num_columns: Scalar non-negative integer `Tensor`. Number of columns in + the corresponding zero matrix. If `None`, defaults to the value of + `num_rows`. + batch_shape: Optional `1-D` integer `Tensor`. The shape of the leading + dimensions. If `None`, this operator has no leading dimensions. + dtype: Data type of the matrix that this operator represents. + is_non_singular: Expect that this operator is non-singular. + is_self_adjoint: Expect that this operator is equal to its hermitian + transpose. + is_positive_definite: Expect that this operator is positive definite, + meaning the quadratic form `x^H A x` has positive real part for all + nonzero `x`. Note that we do not require the operator to be + self-adjoint to be positive-definite. See: + https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices + is_square: Expect that this operator acts like square [batch] matrices. + assert_proper_shapes: Python `bool`. If `False`, only perform static + checks that initialization and method arguments have proper shape. + If `True`, and static checks are inconclusive, add asserts to the graph. + name: A name for this `LinearOperator` + + Raises: + ValueError: If `num_rows` is determined statically to be non-scalar, or + negative. + ValueError: If `num_columns` is determined statically to be non-scalar, + or negative. + ValueError: If `batch_shape` is determined statically to not be 1-D, or + negative. + ValueError: If any of the following is not `True`: + `{is_self_adjoint, is_non_singular, is_positive_definite}`. + """ + dtype = dtype or dtypes.float32 + self._assert_proper_shapes = assert_proper_shapes + + with ops.name_scope(name): + dtype = dtypes.as_dtype(dtype) + if not is_self_adjoint and is_square: + raise ValueError("A zero operator is always self adjoint.") + if is_non_singular: + raise ValueError("A zero operator is always singular.") + if is_positive_definite: + raise ValueError("A zero operator is always not positive-definite.") + + super(LinearOperatorZeros, self).__init__( + dtype=dtype, + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + is_square=is_square, + name=name) + + self._num_rows = linear_operator_util.shape_tensor( + num_rows, name="num_rows") + self._num_rows_static = tensor_util.constant_value(self._num_rows) + + if num_columns is None: + num_columns = num_rows + + self._num_columns = linear_operator_util.shape_tensor( + num_columns, name="num_columns") + self._num_columns_static = tensor_util.constant_value(self._num_columns) + + self._check_domain_range_possibly_add_asserts() + + if (self._num_rows_static is not None and + self._num_columns_static is not None): + if is_square and self._num_rows_static != self._num_columns_static: + raise ValueError( + "LinearOperatorZeros initialized as is_square=True, but got " + "num_rows({}) != num_columns({})".format( + self._num_rows_static, + self._num_columns_static)) + + if batch_shape is None: + self._batch_shape_arg = None + else: + self._batch_shape_arg = linear_operator_util.shape_tensor( + batch_shape, name="batch_shape_arg") + self._batch_shape_static = tensor_util.constant_value( + self._batch_shape_arg) + self._check_batch_shape_possibly_add_asserts() + + def _shape(self): + matrix_shape = tensor_shape.TensorShape((self._num_rows_static, + self._num_columns_static)) + if self._batch_shape_arg is None: + return matrix_shape + + batch_shape = tensor_shape.TensorShape(self._batch_shape_static) + return batch_shape.concatenate(matrix_shape) + + def _shape_tensor(self): + matrix_shape = array_ops.stack((self._num_rows, self._num_columns), axis=0) + if self._batch_shape_arg is None: + return matrix_shape + + return array_ops.concat((self._batch_shape_arg, matrix_shape), 0) + + def _assert_non_singular(self): + raise errors.InvalidArgumentError( + node_def=None, op=None, message="Zero operators are always " + "non-invertible.") + + def _assert_positive_definite(self): + raise errors.InvalidArgumentError( + node_def=None, op=None, message="Zero operators are always " + "non-positive definite.") + + def _assert_self_adjoint(self): + return control_flow_ops.no_op("assert_self_adjoint") + + def _possibly_broadcast_batch_shape(self, x): + """Return 'x', possibly after broadcasting the leading dimensions.""" + # If we have no batch shape, our batch shape broadcasts with everything! + if self._batch_shape_arg is None: + return x + + # Static attempt: + # If we determine that no broadcast is necessary, pass x through + # If we need a broadcast, add to an array of zeros. + # + # special_shape is the shape that, when broadcast with x's shape, will give + # the correct broadcast_shape. Note that + # We have already verified the second to last dimension of self.shape + # matches x's shape in assert_compatible_matrix_dimensions. + # Also, the final dimension of 'x' can have any shape. + # Therefore, the final two dimensions of special_shape are 1's. + special_shape = self.batch_shape.concatenate([1, 1]) + bshape = array_ops.broadcast_static_shape(x.get_shape(), special_shape) + if special_shape.is_fully_defined(): + # bshape.is_fully_defined iff special_shape.is_fully_defined. + if bshape == x.get_shape(): + return x + # Use the built in broadcasting of addition. + zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype) + return x + zeros + + # Dynamic broadcast: + # Always add to an array of zeros, rather than using a "cond", since a + # cond would require copying data from GPU --> CPU. + special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]), 0) + zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype) + return x + zeros + + def _matmul(self, x, adjoint=False, adjoint_arg=False): + if self._assert_proper_shapes: + x = linalg.adjoint(x) if adjoint_arg else x + aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x) + x = control_flow_ops.with_dependencies([aps], x) + if self.is_square: + # Note that adjoint has no effect since this matrix is self-adjoint. + if adjoint_arg: + output_shape = array_ops.concat([ + array_ops.shape(x)[:-2], + [array_ops.shape(x)[-1], array_ops.shape(x)[-2]]], axis=0) + else: + output_shape = array_ops.shape(x) + + return self._possibly_broadcast_batch_shape( + array_ops.zeros(shape=output_shape, dtype=x.dtype)) + + x_shape = array_ops.shape(x) + n = self._num_columns if adjoint else self._num_rows + m = x_shape[-2] if adjoint_arg else x_shape[-1] + + output_shape = array_ops.concat([x_shape[:-2], [n, m]], axis=0) + + zeros = array_ops.zeros(shape=output_shape, dtype=x.dtype) + return self._possibly_broadcast_batch_shape(zeros) + + def _determinant(self): + if self.batch_shape.is_fully_defined(): + return array_ops.zeros(shape=self.batch_shape, dtype=self.dtype) + else: + return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype) + + def _trace(self): + # Get Tensor of all zeros of same shape as self.batch_shape. + if self.batch_shape.is_fully_defined(): + return array_ops.zeros(shape=self.batch_shape, dtype=self.dtype) + else: + return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype) + + def _diag_part(self): + return self._zeros_diag() + + def add_to_tensor(self, mat, name="add_to_tensor"): + """Add matrix represented by this operator to `mat`. Equiv to `I + mat`. + + Args: + mat: `Tensor` with same `dtype` and shape broadcastable to `self`. + name: A name to give this `Op`. + + Returns: + A `Tensor` with broadcast shape and same `dtype` as `self`. + """ + return self._possibly_broadcast_batch_shape(mat) + + def _check_domain_range_possibly_add_asserts(self): + """Static check of init arg `num_rows`, possibly add asserts.""" + # Possibly add asserts. + if self._assert_proper_shapes: + self._num_rows = control_flow_ops.with_dependencies([ + check_ops.assert_rank( + self._num_rows, + 0, + message="Argument num_rows must be a 0-D Tensor."), + check_ops.assert_non_negative( + self._num_rows, + message="Argument num_rows must be non-negative."), + ], self._num_rows) + self._num_columns = control_flow_ops.with_dependencies([ + check_ops.assert_rank( + self._num_columns, + 0, + message="Argument num_columns must be a 0-D Tensor."), + check_ops.assert_non_negative( + self._num_columns, + message="Argument num_columns must be non-negative."), + ], self._num_columns) + + # Static checks. + if not self._num_rows.dtype.is_integer: + raise TypeError("Argument num_rows must be integer type. Found:" + " %s" % self._num_rows) + + if not self._num_columns.dtype.is_integer: + raise TypeError("Argument num_columns must be integer type. Found:" + " %s" % self._num_columns) + + num_rows_static = self._num_rows_static + num_columns_static = self._num_columns_static + + if num_rows_static is not None: + if num_rows_static.ndim != 0: + raise ValueError("Argument num_rows must be a 0-D Tensor. Found:" + " %s" % num_rows_static) + + if num_rows_static < 0: + raise ValueError("Argument num_rows must be non-negative. Found:" + " %s" % num_rows_static) + if num_columns_static is not None: + if num_columns_static.ndim != 0: + raise ValueError("Argument num_columns must be a 0-D Tensor. Found:" + " %s" % num_columns_static) + + if num_columns_static < 0: + raise ValueError("Argument num_columns must be non-negative. Found:" + " %s" % num_columns_static) + + def _check_batch_shape_possibly_add_asserts(self): + """Static check of init arg `batch_shape`, possibly add asserts.""" + if self._batch_shape_arg is None: + return + + # Possibly add asserts + if self._assert_proper_shapes: + self._batch_shape_arg = control_flow_ops.with_dependencies([ + check_ops.assert_rank( + self._batch_shape_arg, + 1, + message="Argument batch_shape must be a 1-D Tensor."), + check_ops.assert_non_negative( + self._batch_shape_arg, + message="Argument batch_shape must be non-negative."), + ], self._batch_shape_arg) + + # Static checks + if not self._batch_shape_arg.dtype.is_integer: + raise TypeError("Argument batch_shape must be integer type. Found:" + " %s" % self._batch_shape_arg) + + if self._batch_shape_static is None: + return # Cannot do any other static checks. + + if self._batch_shape_static.ndim != 1: + raise ValueError("Argument batch_shape must be a 1-D Tensor. Found:" + " %s" % self._batch_shape_static) + + if np.any(self._batch_shape_static < 0): + raise ValueError("Argument batch_shape must be non-negative. Found:" + "%s" % self._batch_shape_static) + + def _min_matrix_dim(self): + """Minimum of domain/range dimension, if statically available, else None.""" + domain_dim = self.domain_dimension.value + range_dim = self.range_dimension.value + if domain_dim is None or range_dim is None: + return None + return min(domain_dim, range_dim) + + def _min_matrix_dim_tensor(self): + """Minimum of domain/range dimension, as a tensor.""" + return math_ops.reduce_min(self.shape_tensor()[-2:]) + + def _zeros_diag(self): + """Returns the diagonal of this operator as all zeros.""" + if self.shape.is_fully_defined(): + d_shape = self.batch_shape.concatenate([self._min_matrix_dim()]) + else: + d_shape = array_ops.concat( + [self.batch_shape_tensor(), + [self._min_matrix_dim_tensor()]], axis=0) + + return array_ops.zeros(shape=d_shape, dtype=self.dtype) diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py index a0dfa543f9b3aee15f11b073dc683b1d2d14388f..f4a93560bee558512f33214148ddec22590b9dd6 100644 --- a/tensorflow/python/ops/linalg_ops.py +++ b/tensorflow/python/ops/linalg_ops.py @@ -401,7 +401,7 @@ def svd(tensor, full_matrices=False, compute_uv=True, name=None): import tensorflow as tf import numpy as np s, u, v = tf.linalg.svd(a) - tf_a_approx = tf.matmul(u, tf.matmul(tf.linalg.diag(s), v, adjoint_v=True)) + tf_a_approx = tf.matmul(u, tf.matmul(tf.linalg.diag(s), v, adjoint_b=True)) u, s, v_adj = np.linalg.svd(a, full_matrices=False) np_a_approx = np.dot(u, np.dot(np.diag(s), v_adj)) # tf_a_approx and np_a_approx should be numerically close. diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py index 8276047cb678f3d340701718156f8a1cfd6831cb..df41933f8a864be3ada72dbf101420c886dfb36b 100644 --- a/tensorflow/python/ops/logging_ops.py +++ b/tensorflow/python/ops/logging_ops.py @@ -35,9 +35,12 @@ from tensorflow.python.util.tf_export import tf_export # Assert and Print are special symbols in python, so we must -# have an upper-case version of them. For users with Python 3 or Python 2.7 -# with `from __future__ import print_function`, we also allow lowercase. -@tf_export("Print", "print") +# have an upper-case version of them. +# +# For users with Python 3 or Python 2.7 +# with `from __future__ import print_function`, we could also allow lowercase. +# See https://github.com/tensorflow/tensorflow/issues/18053 +@tf_export("Print") def Print(input_, data, message=None, first_n=None, summarize=None, name=None): """Prints a list of tensors. diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py index de9b3c6909ddd9c22ac4bced5ec48e4de354bd19..806539747e5e74cf1c5f40ab47aa84dcbb364344 100644 --- a/tensorflow/python/ops/losses/losses_impl.py +++ b/tensorflow/python/ops/losses/losses_impl.py @@ -190,8 +190,13 @@ def compute_weighted_loss( When calculating the gradient of a weighted loss contributions from both `losses` and `weights` are considered. If your `weights` depend on some model parameters but you do not want this to affect the loss - gradient, you need to apply @{tf.stop_gradient} to `weights` before + gradient, you need to apply `tf.stop_gradient` to `weights` before passing them to `compute_weighted_loss`. + + @compatibility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ Reduction.validate(reduction) with ops.name_scope(scope, "weighted_loss", (losses, weights)): @@ -260,6 +265,11 @@ def absolute_difference( ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weights` is invalid or if `labels` or `predictions` is None. + + @compatibility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") @@ -306,6 +316,11 @@ def cosine_distance( Raises: ValueError: If `predictions` shape doesn't match `labels` shape, or `axis`, `labels`, `predictions` or `weights` is `None`. + + @compatibility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ axis = deprecated_argument_lookup("axis", axis, "dim", dim) if axis is None: @@ -353,6 +368,11 @@ def hinge_loss(labels, logits, weights=1.0, scope=None, Raises: ValueError: If the shapes of `logits` and `labels` don't match or if `labels` or `logits` is None. + + @compatibility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") @@ -416,6 +436,11 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None, ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weights` is invalid. Also if `labels` or `predictions` is None. + + @compatibility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") @@ -477,6 +502,11 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None, ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weights` is invalid. Also if `labels` or `predictions` is None. + + @compatibility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") @@ -540,6 +570,11 @@ def mean_pairwise_squared_error( ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weights` is invalid. Also if `labels` or `predictions` is None. + + @compatibility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") @@ -618,6 +653,11 @@ def mean_squared_error( ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weights` is invalid. Also if `labels` or `predictions` is None. + + @compatibility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") @@ -670,6 +710,11 @@ def sigmoid_cross_entropy( ValueError: If the shape of `logits` doesn't match that of `multi_class_labels` or if the shape of `weights` is invalid, or if `weights` is None. Also if `multi_class_labels` or `logits` is None. + + @compatibility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if multi_class_labels is None: raise ValueError("multi_class_labels must not be None.") @@ -731,6 +776,11 @@ def softmax_cross_entropy( ValueError: If the shape of `logits` doesn't match that of `onehot_labels` or if the shape of `weights` is invalid or if `weights` is None. Also if `onehot_labels` or `logits` is None. + + @compatibility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if onehot_labels is None: raise ValueError("onehot_labels must not be None.") @@ -828,7 +878,8 @@ def sparse_softmax_cross_entropy( exception when this op is run on CPU, and return `NaN` for corresponding loss and gradient rows on GPU. logits: Unscaled log probabilities of shape - `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`. + `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32` or + `float64`. weights: Coefficients for the loss. This must be scalar or broadcastable to `labels` (i.e. same rank and each dimension is either 1 or the same). scope: the scope for the operations performed in computing the loss. @@ -842,6 +893,11 @@ def sparse_softmax_cross_entropy( Raises: ValueError: If the shapes of `logits`, `labels`, and `weights` are incompatible, or if any of them are None. + + @compatibility(eager) + The `loss_collection` argument is ignored when executing eagerly. Consider + holding on to the return value or collecting losses via a `tf.keras.Model`. + @end_compatibility """ if labels is None: raise ValueError("labels must not be None.") diff --git a/tensorflow/python/ops/losses/util.py b/tensorflow/python/ops/losses/util.py index 10646af8a983f149cf0620bf355cf0bc1fa697fb..97bba46661d056fd336c68988e3bc17ef4232487 100644 --- a/tensorflow/python/ops/losses/util.py +++ b/tensorflow/python/ops/losses/util.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops @@ -32,7 +33,10 @@ def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES): loss: A loss `Tensor`. loss_collection: Optional collection to add the loss to. """ - if loss_collection: + # Since we have no way of figuring out when a training iteration starts or + # ends, holding on to a loss when executing eagerly is indistingishable from + # leaking memory. We instead leave the collection empty. + if loss_collection and not context.executing_eagerly(): ops.add_to_collection(loss_collection, loss) diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index a48b3c9395f48ecaf3879b95199f69c84f7f095a..2a7a2fd51f3ea5dd596472f47ea64a77c66bc948 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -651,27 +651,28 @@ def _BesselI1eGrad(op, grad): @ops.RegisterGradient("Igamma") def _IgammaGrad(op, grad): - """Returns gradient of igamma(a, x) with respect to x.""" - # TODO(ebrevdo): Perhaps add the derivative w.r.t. a + """Returns gradient of igamma(a, x) with respect to a and x.""" a = op.inputs[0] x = op.inputs[1] sa = array_ops.shape(a) sx = array_ops.shape(x) - unused_ra, rx = gen_array_ops.broadcast_gradient_args(sa, sx) + ra, rx = gen_array_ops.broadcast_gradient_args(sa, sx) - # Perform operations in log space before summing, because Gamma(a) - # and Gamma'(a) can grow large. - partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) - math_ops.lgamma(a)) - # TODO(b/36815900): Mark None return values as NotImplemented - return (None, array_ops.reshape( - math_ops.reduce_sum(partial_x * grad, rx), sx)) + with ops.control_dependencies([grad]): + partial_a = gen_math_ops.igamma_grad_a(a, x) + # Perform operations in log space before summing, because Gamma(a) + # and Gamma'(a) can grow large. + partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) + - math_ops.lgamma(a)) + return (array_ops.reshape(math_ops.reduce_sum(partial_a * grad, ra), sa), + array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx)) @ops.RegisterGradient("Igammac") def _IgammacGrad(op, grad): - """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. x.""" - _, igamma_grad_x = _IgammaGrad(op, grad) - return None, -igamma_grad_x + """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. a and x.""" + igamma_grad_a, igamma_grad_x = _IgammaGrad(op, grad) + return (-igamma_grad_a, -igamma_grad_x) @ops.RegisterGradient("Betainc") @@ -971,6 +972,24 @@ def _RealDivGrad(op, grad): grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy)) +@ops.RegisterGradient("UnsafeDiv") +def _UnsafeDivGrad(op, grad): + """UnsafeDiv op gradient.""" + x = op.inputs[0] + y = op.inputs[1] + sx = array_ops.shape(x) + sy = array_ops.shape(y) + rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) + x = math_ops.conj(x) + y = math_ops.conj(y) + return (array_ops.reshape( + math_ops.reduce_sum(math_ops.unsafe_div(grad, y), rx), sx), + array_ops.reshape( + math_ops.reduce_sum( + grad * math_ops.unsafe_div(math_ops.unsafe_div(-x, y), y), + ry), sy)) + + @ops.RegisterGradient("Pow") def _PowGrad(op, grad): """Returns grad * (y*x^(y-1), z*log(x)).""" diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py index fa47b8f9b8a0e72c5ecf814e6a80e04fb559990c..f9bb60e7fedb029e846628bddaaf15980a8ec625 100644 --- a/tensorflow/python/ops/math_grad_test.py +++ b/tensorflow/python/ops/math_grad_test.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import gradients from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -230,5 +231,27 @@ class FloorModGradientTest(test.TestCase): self.assertLess(error, 1e-4) +class UnsafeDivGradientTest(test.TestCase): + + def testBasicGradient(self): + inputs = constant_op.constant(np.arange(-3, 3), dtype=dtypes.float32) + outputs = math_ops.unsafe_div(inputs, 1 + math_ops.abs(inputs)) + with self.test_session(): + error = gradient_checker.compute_gradient_error( + inputs, + inputs.get_shape().as_list(), outputs, + outputs.get_shape().as_list()) + self.assertLess(error, 1e-4) + + def testGradientWithDenominatorIsZero(self): + x = constant_op.constant(np.arange(-3, 3), dtype=dtypes.float32) + y = array_ops.zeros_like(x, dtype=dtypes.float32) + outputs = math_ops.unsafe_div(x, y) + with self.test_session(): + dx, dy = gradients.gradients(outputs, [x, y]) + self.assertAllClose(dx.eval(), np.zeros(x.shape.as_list())) + self.assertAllClose(dy.eval(), np.zeros(y.shape.as_list())) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 466d0dadc8a430861ab27b6a522ca6acd2db7855..81499bee56b12eba627e23828aceba6f1f5dc578 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -37,11 +37,11 @@ from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gen_sparse_ops from tensorflow.python.ops import gen_spectral_ops -from tensorflow.python.platform import tf_logging as logging # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_math_ops import * # pylint: enable=wildcard-import +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import deprecation from tensorflow.python.util import nest @@ -628,16 +628,17 @@ def cast(x, dtype, name=None): ``` The operation supports data types (for `x` and `dtype`) of - `uint8`, `int8`, `uint16`, `int16`, `int32`, `int64`, `float16`, `float32`, - `float64`, `complex64`, `complex128`, `bfloat16`. In case of casting from - complex types (`complex64`, `complex128`) to real types, only the real part - of `x` is returned. In case of casting from real types to complex types - (`complex64`, `complex128`), the imaginary part of the returned value is set - to `0`. The handling of complex types here matches the behavior of numpy. + `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`, + `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`. + In case of casting from complex types (`complex64`, `complex128`) to real + types, only the real part of `x` is returned. In case of casting from real + types to complex types (`complex64`, `complex128`), the imaginary part of the + returned value is set to `0`. The handling of complex types here matches the + behavior of numpy. Args: x: A `Tensor` or `SparseTensor` of numeric type. It could be - `uint8`, `int8`, `uint16`, `int16`, `int32`, `int64`, + `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`, `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`. dtype: The destination type. The list of supported dtypes is the same as `x`. @@ -651,6 +652,9 @@ def cast(x, dtype, name=None): TypeError: If `x` cannot be cast to the `dtype`. """ base_type = dtypes.as_dtype(dtype).base_dtype + if isinstance(x, + (ops.Tensor, _resource_variable_type)) and base_type == x.dtype: + return x with ops.name_scope(name, "Cast", [x]) as name: if isinstance(x, sparse_tensor.SparseTensor): values_cast = cast(x.values, base_type, name=name) @@ -1034,6 +1038,31 @@ def div(x, y, name=None): return _div_python2(x, y, name) +def unsafe_div(x, y, name=None): + """Computes an unsafe divide which returns 0 if the y is zero. + + Note that the function uses Python 3 division operator semantics. + + Args: + x: A `Tensor`. Must be one of the following types: + `float32`, `float64`, `int16`, `int32`, `int64`. + y: A `Tensor` whose dtype is compatible with `x`. + name: A name for the operation (optional). + Returns: + The element-wise value of the x divided by y. + """ + + with ops.name_scope(name, "unsafe_div", [x, y]) as name: + x = ops.convert_to_tensor(x, name="x") + y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype) + x_dtype = x.dtype.base_dtype + y_dtype = y.dtype.base_dtype + if x_dtype != y_dtype: + raise TypeError( + "x and y must have the same dtype, got %r != %r" % (x_dtype, y_dtype)) + return gen_math_ops.unsafe_div(x, y, name=name) + + # TODO(aselle): This should be removed mod = gen_math_ops.floor_mod @@ -1222,8 +1251,9 @@ def _ReductionDims(x, axis, reduction_indices): return axis else: # Fast path: avoid creating Rank and Range ops if ndims is known. - if isinstance(x, ops.Tensor) and x._rank() is not None: # pylint: disable=protected-access - return constant_op.constant(np.arange(x._rank()), dtype=dtypes.int32) # pylint: disable=protected-access + rank = common_shapes.rank(x) + if rank is not None: + return constant_op.constant(np.arange(rank), dtype=dtypes.int32) if (isinstance(x, sparse_tensor.SparseTensor) and x.dense_shape.get_shape().is_fully_defined()): rank = x.dense_shape.get_shape()[0].value # sparse.dense_shape is 1-D. @@ -1234,8 +1264,8 @@ def _ReductionDims(x, axis, reduction_indices): def _may_reduce_to_scalar(keepdims, axis, reduction_indices, output): - """Set a reduction's output's shape to be a scalar if we are certain.""" - if (not output.shape.is_fully_defined()) and (not keepdims) and ( + """Set a reduction's output shape to be a scalar if we are certain.""" + if not common_shapes.has_fully_defined_shape(output) and (not keepdims) and ( axis is None) and (reduction_indices is None): output.set_shape(()) return output @@ -1990,7 +2020,7 @@ def matmul(a, sparse_matmul_types = [dtypes.bfloat16, dtypes.float32] use_sparse_matmul = ( a.dtype in sparse_matmul_types and b.dtype in sparse_matmul_types) - if (a.dtype == dtypes.bfloat16 or b.dtype == dtypes.bfloat16 and + if ((a.dtype == dtypes.bfloat16 or b.dtype == dtypes.bfloat16) and a.dtype != b.dtype): # matmul currently doesn't handle mixed-precision inputs. use_sparse_matmul = True diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index c807c8bc2efbf867700f0df37783b02fefa0ca82..5fe7bbca112c7dcb7894b79cdcd952210f636dde 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -37,14 +37,14 @@ log = np.log class ReduceTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testReduceAllDims(self): x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) with test_util.device(use_gpu=True): y_tf = self.evaluate(math_ops.reduce_sum(x)) self.assertEqual(y_tf, 21) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testReduceExplicitAxes(self): x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) with test_util.device(use_gpu=True): @@ -57,7 +57,7 @@ class ReduceTest(test_util.TensorFlowTestCase): for axis in (None, (0, 1), (-1, -2), (-2, -1, 0, 1)): self.assertEqual(self.evaluate(math_ops.reduce_sum(x, axis=axis)), 21) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testReduceInvalidAxis(self): if context.executing_eagerly(): # The shape check is in run a graph construction time. In eager mode, @@ -150,7 +150,7 @@ class LogSumExpTest(test_util.TensorFlowTestCase): class RoundTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRounding(self): x = np.arange(-5.0, 5.0, .25) for dtype in [np.float32, np.double, np.int32]: @@ -194,7 +194,7 @@ class ModTest(test_util.TensorFlowTestCase): class SquaredDifferenceTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSquaredDifference(self): for dtype in [np.int32, np.float16]: x = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype) @@ -207,7 +207,7 @@ class SquaredDifferenceTest(test_util.TensorFlowTestCase): class ApproximateEqualTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testApproximateEqual(self): for dtype in [np.float32, np.double]: x = dtype(1) @@ -237,8 +237,8 @@ class ApproximateEqualTest(test_util.TensorFlowTestCase): def testApproximateEqualShape(self): for dtype in [np.float32, np.double]: - x = np.array([1, 2], dtype=np.float32) - y = np.array([[1, 2]], dtype=np.float32) + x = np.array([1, 2], dtype=dtype) + y = np.array([[1, 2]], dtype=dtype) # The inputs 'x' and 'y' must have the same shape. with self.assertRaisesRegexp( ValueError, "Shapes must be equal rank, but are 1 and 2"): @@ -247,7 +247,7 @@ class ApproximateEqualTest(test_util.TensorFlowTestCase): class ScalarMulTest(test_util.TensorFlowTestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAcceptsRefs(self): if context.executing_eagerly(): var = resource_variable_ops.ResourceVariable(10, name="var") @@ -259,14 +259,14 @@ class ScalarMulTest(test_util.TensorFlowTestCase): self.evaluate(init) self.assertEqual(30, self.evaluate(result)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAcceptsConstant(self): const = constant_op.constant(10) result = math_ops.scalar_mul(3, const) with test_util.device(use_gpu=True): self.assertEqual(30, self.evaluate(result)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAcceptsTensor(self): tensor = array_ops.ones([10, 10]) result = math_ops.scalar_mul(3, tensor) @@ -275,7 +275,7 @@ class ScalarMulTest(test_util.TensorFlowTestCase): with test_util.device(use_gpu=True): self.assertAllEqual(self.evaluate(expected), self.evaluate(result)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAcceptsIndexedSlices(self): values = constant_op.constant([2, 3, 5, 7, 0, -1], shape=[3, 2]) indices = constant_op.constant([0, 2, 5]) @@ -473,5 +473,19 @@ class DivAndModTest(test_util.TensorFlowTestCase): self.assertAllEqual(tf_result, expanded_nums) +class UnsafeDivTest(test_util.TensorFlowTestCase): + + def testBasic(self): + nums = np.arange(-10, 10, .25).reshape(80, 1) + divs = np.arange(-3, 3, .25).reshape(1, 24) + + np_result = np.true_divide(nums, divs) + np_result[:, divs[0] == 0] = 0 + + with self.test_session(): + tf_result = math_ops.unsafe_div(nums, divs).eval() + self.assertAllEqual(tf_result, np_result) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 5eab12c41d5f781b496a3a1bcfa9ce35fca4fc54..9461a01515a7a7aee62f3bf18f21c3ee9c352924 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -34,7 +34,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export @@ -57,7 +57,8 @@ def metric_variable(shape, dtype, validate_shape=True, name=None): Furthermore, the final answer should be computed once instead of in every replica/tower. Both of these are accomplished by running the computation of the final result value inside - `tf.contrib.distribute.get_tower_context().merge_call(fn)`. + `tf.contrib.distribution_strategy_context.get_tower_context( + ).merge_call(fn)`. Inside the `merge_call()`, ops are only added to the graph once and access to a tower-local variable in a computation returns the sum across all replicas/towers. @@ -73,15 +74,16 @@ def metric_variable(shape, dtype, validate_shape=True, name=None): A (non-trainable) variable initialized to zero, or if inside a `DistributionStrategy` scope a tower-local variable container. """ - with distribute_lib.get_tower_context().tower_local_var_scope('sum'): - # Note that "tower local" implies trainable=False. - return variable_scope.variable( - lambda: array_ops.zeros(shape, dtype), - collections=[ - ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES - ], - validate_shape=validate_shape, - name=name) + # Note that synchronization "ON_READ" implies trainable=False. + return variable_scope.variable( + lambda: array_ops.zeros(shape, dtype), + collections=[ + ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES + ], + validate_shape=validate_shape, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM, + name=name) def _remove_squeezable_dimensions(predictions, labels, weights): @@ -372,7 +374,7 @@ def mean(values, ops.add_to_collections(metrics_collections, mean_t) return mean_t - mean_t = distribute_lib.get_tower_context().merge_call( + mean_t = distribution_strategy_context.get_tower_context().merge_call( aggregate_across_towers, total, count) update_op = _safe_div(update_total_op, update_count_op, 'update_op') @@ -617,7 +619,7 @@ def _aggregate_variable(v, collections): ops.add_to_collections(collections, value) return value - return distribute_lib.get_tower_context().merge_call(f, v) + return distribution_strategy_context.get_tower_context().merge_call(f, v) @tf_export('metrics.auc') @@ -812,7 +814,7 @@ def auc(labels, ops.add_to_collections(metrics_collections, auc_value) return auc_value - auc_value = distribute_lib.get_tower_context().merge_call( + auc_value = distribution_strategy_context.get_tower_context().merge_call( aggregate_auc, values) update_op = compute_auc(update_ops['tp'], update_ops['fn'], update_ops['tn'], update_ops['fp'], 'update_op') @@ -1052,8 +1054,8 @@ def mean_per_class_accuracy(labels, ops.add_to_collections(metrics_collections, mean_accuracy_v) return mean_accuracy_v - mean_accuracy_v = distribute_lib.get_tower_context().merge_call( - aggregate_mean_accuracy, count, total) + mean_accuracy_v = distribution_strategy_context.get_tower_context( + ).merge_call(aggregate_mean_accuracy, count, total) update_op = _safe_div(update_count_op, update_total_op, name='update_op') if updates_collections: @@ -1159,7 +1161,7 @@ def mean_iou(labels, ops.add_to_collections(metrics_collections, mean_iou_v) return mean_iou_v - mean_iou_v = distribute_lib.get_tower_context().merge_call( + mean_iou_v = distribution_strategy_context.get_tower_context().merge_call( mean_iou_across_towers, total_cm) if updates_collections: @@ -1375,7 +1377,7 @@ def mean_tensor(values, ops.add_to_collections(metrics_collections, mean_t) return mean_t - mean_t = distribute_lib.get_tower_context().merge_call( + mean_t = distribution_strategy_context.get_tower_context().merge_call( aggregate_across_towers, total, count) update_op = _safe_div(update_total_op, update_count_op, 'update_op') @@ -2007,7 +2009,7 @@ def precision(labels, ops.add_to_collections(metrics_collections, p) return p - p = distribute_lib.get_tower_context().merge_call( + p = distribution_strategy_context.get_tower_context().merge_call( once_across_towers, true_p, false_p) update_op = compute_precision(true_positives_update_op, @@ -2091,7 +2093,7 @@ def precision_at_thresholds(labels, ops.add_to_collections(metrics_collections, prec) return prec - prec = distribute_lib.get_tower_context().merge_call( + prec = distribution_strategy_context.get_tower_context().merge_call( precision_across_towers, values) update_op = compute_precision(update_ops['tp'], update_ops['fp'], @@ -2187,7 +2189,7 @@ def recall(labels, ops.add_to_collections(metrics_collections, rec) return rec - rec = distribute_lib.get_tower_context().merge_call( + rec = distribution_strategy_context.get_tower_context().merge_call( once_across_towers, true_p, false_n) update_op = compute_recall(true_positives_update_op, @@ -2626,7 +2628,7 @@ def recall_at_top_k(labels, ops.add_to_collections(metrics_collections, metric) return metric - metric = distribute_lib.get_tower_context().merge_call( + metric = distribution_strategy_context.get_tower_context().merge_call( aggregate_across_towers, tp, fn) update = math_ops.div( @@ -2707,7 +2709,7 @@ def recall_at_thresholds(labels, ops.add_to_collections(metrics_collections, rec) return rec - rec = distribute_lib.get_tower_context().merge_call( + rec = distribution_strategy_context.get_tower_context().merge_call( recall_across_towers, values) update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op') @@ -2782,7 +2784,7 @@ def root_mean_squared_error(labels, ops.add_to_collections(metrics_collections, rmse) return rmse - rmse = distribute_lib.get_tower_context().merge_call( + rmse = distribution_strategy_context.get_tower_context().merge_call( once_across_towers, mse) update_rmse_op = math_ops.sqrt(update_mse_op) @@ -2885,7 +2887,7 @@ def sensitivity_at_specificity(labels, ops.add_to_collections(metrics_collections, sensitivity) return sensitivity - sensitivity = distribute_lib.get_tower_context().merge_call( + sensitivity = distribution_strategy_context.get_tower_context().merge_call( aggregate_across_towers, values) update_op = compute_sensitivity_at_specificity( @@ -3161,8 +3163,8 @@ def _streaming_sparse_average_precision_at_top_k(labels, ops.add_to_collections(metrics_collections, mean_average_precision) return mean_average_precision - mean_average_precision = distribute_lib.get_tower_context().merge_call( - aggregate_across_towers, total_var, max_var) + mean_average_precision = distribution_strategy_context.get_tower_context( + ).merge_call(aggregate_across_towers, total_var, max_var) update = _safe_scalar_div(total_update, max_update, name=scope) if updates_collections: @@ -3447,7 +3449,7 @@ def precision_at_top_k(labels, ops.add_to_collections(metrics_collections, metric) return metric - metric = distribute_lib.get_tower_context().merge_call( + metric = distribution_strategy_context.get_tower_context().merge_call( aggregate_across_towers, tp, fp) update = math_ops.div( @@ -3686,7 +3688,7 @@ def specificity_at_sensitivity(labels, ops.add_to_collections(metrics_collections, specificity) return specificity - specificity = distribute_lib.get_tower_context().merge_call( + specificity = distribution_strategy_context.get_tower_context().merge_call( aggregate_across_towers, values) update_op = compute_specificity_at_sensitivity( diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index 3a41391340edbe25bd97cfadc58587d91bef9de2..df23ac55ceea7ab268bf435541613e62a15c00de 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -240,13 +240,9 @@ def _SoftmaxGrad(op, grad_softmax): gradient w.r.t the input to the softmax """ - # TODO(ilyasu): assert that the tensor has two dimensions at - # graph-construction time? Alternatively: do different things - # depending on the dimensionality of the input tensors. softmax = op.outputs[0] - grad_x = ((grad_softmax - array_ops.reshape( - math_ops.reduce_sum(grad_softmax * softmax, [1]), [-1, 1])) * softmax) - return grad_x + sum_channels = math_ops.reduce_sum(grad_softmax * softmax, -1, keepdims=True) + return (grad_softmax - sum_channels) * softmax @ops.RegisterGradient("LogSoftmax") @@ -264,7 +260,7 @@ def _LogSoftmaxGrad(op, grad): The gradients w.r.t. the input. """ softmax = math_ops.exp(op.outputs[0]) - return grad - math_ops.reduce_sum(grad, 1, keepdims=True) * softmax + return grad - math_ops.reduce_sum(grad, -1, keepdims=True) * softmax @ops.RegisterGradient("BiasAdd") diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index f47f38e29e328ea92bfc494d60673c70a58274d3..51f812b395381defbd15b59f5661fa6603966f8e 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -425,7 +425,7 @@ def depthwise_conv2d(input, strides: 1-D of size 4. The stride of the sliding window for each dimension of `input`. padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. - See the @{tf.nn.convolution$comment here} + See the "returns" section of `tf.nn.convolution` for details. rate: 1-D of size 2. The dilation rate in which we sample input values across the `height` and `width` dimensions in atrous convolution. If it is greater than 1, then all values of strides must be 1. @@ -507,7 +507,7 @@ def separable_conv2d(input, strides: 1-D of size 4. The strides for the depthwise convolution for each dimension of `input`. padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. - See the @{tf.nn.convolution$comment here} + See the "returns" section of `tf.nn.convolution` for details. rate: 1-D of size 2. The dilation rate in which we sample input values across the `height` and `width` dimensions in atrous convolution. If it is greater than 1, then all values of strides must be 1. @@ -1189,7 +1189,7 @@ def nce_loss(weights, Note: By default this uses a log-uniform (Zipfian) distribution for sampling, so your labels must be sorted in order of decreasing frequency to achieve good results. For more details, see - @{tf.nn.log_uniform_candidate_sampler}. + `tf.nn.log_uniform_candidate_sampler`. Note: In the case where `num_true` > 1, we assign to each target class the target probability 1 / `num_true` so that the target probabilities diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 0c2f5b06c497e8ca7db20ac09938c86b425d66a0..6fd1273687eddd24e68e16efb19617ab152ab959 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -22,6 +22,7 @@ import numbers import numpy as np +from tensorflow.python.compat import compat from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_util @@ -897,8 +898,8 @@ def pool( ``` where the reduction function REDUCE depends on the value of `pooling_type`, - and pad_before is defined based on the value of `padding` as described in the - @{tf.nn.convolution$comment here}. + and pad_before is defined based on the value of `padding` as described in + the "returns" section of `tf.nn.convolution` for details. The reduction never includes out-of-bounds positions. In the case that `data_format` starts with `"NC"`, the `input` and output are @@ -920,7 +921,7 @@ def pool( window_shape: Sequence of N ints >= 1. pooling_type: Specifies pooling operation, must be "AVG" or "MAX". padding: The padding algorithm, must be "SAME" or "VALID". - See the @{tf.nn.convolution$comment here} + See the "returns" section of `tf.nn.convolution` for details. dilation_rate: Optional. Dilation rate. List of N ints >= 1. Defaults to [1]*N. If any value of dilation_rate is > 1, then all values of strides must be 1. @@ -1044,8 +1045,8 @@ def atrous_conv2d(value, filters, rate, padding, name=None): """Atrous convolution (a.k.a. convolution with holes or dilated convolution). This function is a simpler wrapper around the more general - @{tf.nn.convolution}, and exists only for backwards compatibility. You can - use @{tf.nn.convolution} to perform 1-D, 2-D, or 3-D atrous convolution. + `tf.nn.convolution`, and exists only for backwards compatibility. You can + use `tf.nn.convolution` to perform 1-D, 2-D, or 3-D atrous convolution. Computes a 2-D atrous convolution, also known as convolution with holes or @@ -1204,7 +1205,7 @@ def conv2d_transpose( strides: A list of ints. The stride of the sliding window for each dimension of the input tensor. padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. - See the @{tf.nn.convolution$comment here} + See the "returns" section of `tf.nn.convolution` for details. data_format: A string. 'NHWC' and 'NCHW' are supported. name: Optional name for the returned tensor. @@ -1429,7 +1430,7 @@ def conv3d_transpose( strides: A list of ints. The stride of the sliding window for each dimension of the input tensor. padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. - See the @{tf.nn.convolution$comment here} + See the "returns" section of `tf.nn.convolution` for details. data_format: A string, either `'NDHWC'` or `'NCDHW`' specifying the layout of the input and output tensors. Defaults to `'NDHWC'`. name: Optional name for the returned tensor. @@ -1669,17 +1670,19 @@ def _softmax(logits, compute_op, dim=-1, name=None): shape = logits.get_shape() is_last_dim = (dim is -1) or (dim == shape.ndims - 1) - if shape.ndims is 2 and is_last_dim: - return compute_op(logits, name=name) - - # If dim is the last dimension, simply reshape the logits to a matrix and - # apply the internal softmax. + # TODO(phawkins): remove after 2018/8/27 and simplify this code. + softmax_accepts_r1_or_greater = compat.forward_compatible(2018, 8, 27) + reshape_required = (not softmax_accepts_r1_or_greater) and shape.ndims != 2 if is_last_dim: - input_shape = array_ops.shape(logits) - logits = _flatten_outer_dims(logits) - output = compute_op(logits) - output = array_ops.reshape(output, input_shape, name=name) - return output + if reshape_required: + # If dim is the last dimension, simply reshape the logits to a matrix and + # apply the internal softmax. + input_shape = array_ops.shape(logits) + logits = _flatten_outer_dims(logits) + output = compute_op(logits) + output = array_ops.reshape(output, input_shape, name=name) + return output + return compute_op(logits, name=name) # If dim is not the last dimension, we have to do a reshape and transpose so # that we can still perform softmax on its last dimension. @@ -1690,14 +1693,19 @@ def _softmax(logits, compute_op, dim=-1, name=None): logits = _swap_axis(logits, dim_axis, math_ops.subtract(input_rank, 1)) shape_after_swap = array_ops.shape(logits) - # Reshape logits into a matrix. - logits = _flatten_outer_dims(logits) + if reshape_required: + # Reshape logits into a matrix. + logits = _flatten_outer_dims(logits) + + # Do the actual softmax on its last dimension. + output = compute_op(logits) - # Do the actual softmax on its last dimension. - output = compute_op(logits) + # Transform back the output tensor. + output = array_ops.reshape(output, shape_after_swap) + else: + # Do the actual softmax on its last dimension. + output = compute_op(logits) - # Transform back the output tensor. - output = array_ops.reshape(output, shape_after_swap) output = _swap_axis( output, dim_axis, math_ops.subtract(input_rank, 1), name=name) @@ -1811,7 +1819,7 @@ def softmax_cross_entropy_with_logits_v2( or `float64`). Backpropagation will happen into both `logits` and `labels`. To disallow - backpropagation into `labels`, pass label tensors through @{tf.stop_gradient} + backpropagation into `labels`, pass label tensors through `tf.stop_gradient` before feeding it to this function. **Note that to avoid confusion, it is required to pass only named arguments to @@ -1901,7 +1909,7 @@ _XENT_DEPRECATION = """ Future major versions of TensorFlow will allow gradients to flow into the labels input on backprop by default. -See @{tf.nn.softmax_cross_entropy_with_logits_v2}. +See `tf.nn.softmax_cross_entropy_with_logits_v2`. """ @@ -1938,7 +1946,7 @@ def softmax_cross_entropy_with_logits( Backpropagation will happen only into `logits`. To calculate a cross entropy loss that allows backpropagation into both `logits` and `labels`, see - @{tf.nn.softmax_cross_entropy_with_logits_v2}. + `tf.nn.softmax_cross_entropy_with_logits_v2`. **Note that to avoid confusion, it is required to pass only named arguments to this function.** @@ -1995,8 +2003,8 @@ def sparse_softmax_cross_entropy_with_logits( A common use case is to have logits and labels of shape `[batch_size, num_classes]`, but higher dimensions are supported, in which case the `dim`-th dimension is assumed to be of size `num_classes`. - `logits` and `labels` must have the same dtype (either `float16`, `float32`, - or `float64`). + `logits` must have the dtype of `float16`, `float32`, or `float64`, and + `labels` must have the dtype of `int32` or `int64`. **Note that to avoid confusion, it is required to pass only named arguments to this function.** @@ -2009,7 +2017,8 @@ def sparse_softmax_cross_entropy_with_logits( exception when this op is run on CPU, and return `NaN` for corresponding loss and gradient rows on GPU. logits: Unscaled log probabilities of shape - `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`. + `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32`, or + `float64`. name: A name for the operation (optional). Returns: @@ -2105,7 +2114,7 @@ def avg_pool(value, ksize, strides, padding, data_format="NHWC", name=None): strides: A list or tuple of 4 ints. The stride of the sliding window for each dimension of the input tensor. padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. - See the @{tf.nn.convolution$comment here} + See the "returns" section of `tf.nn.convolution` for details. data_format: A string. 'NHWC' and 'NCHW' are supported. name: Optional name for the operation. @@ -2134,7 +2143,7 @@ def max_pool(value, ksize, strides, padding, data_format="NHWC", name=None): strides: A list or tuple of 4 ints. The stride of the sliding window for each dimension of the input tensor. padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. - See the @{tf.nn.convolution$comment here} + See the "returns" section of `tf.nn.convolution` for details. data_format: A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported. name: Optional name for the operation. @@ -2166,7 +2175,7 @@ def _calc_conv_flops(graph, node): filter_height = int(filter_shape[0]) filter_width = int(filter_shape[1]) filter_in_depth = int(filter_shape[2]) - output_count = np.prod(output_shape.as_list()) + output_count = np.prod(output_shape.as_list(), dtype=np.int64) return ops.OpStats( "flops", (output_count * filter_in_depth * filter_height * filter_width * 2)) @@ -2184,7 +2193,7 @@ def _calc_depthwise_conv_flops(graph, node): output_shape.assert_is_fully_defined() filter_height = int(filter_shape[0]) filter_width = int(filter_shape[1]) - output_count = np.prod(output_shape.as_list()) + output_count = np.prod(output_shape.as_list(), dtype=np.int64) return ops.OpStats("flops", (output_count * filter_height * filter_width * 2)) @@ -2292,7 +2301,7 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: di noise_shape: A 1-D `Tensor` of type `int32`, representing the shape for randomly generated keep/drop flags. seed: A Python integer. Used to create random seeds. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. name: A name for this operation (optional). @@ -2512,7 +2521,7 @@ def conv1d_transpose( stride: An `integer`. The number of entries by which the filter is moved right at each step. padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. - See the @{tf.nn.convolution$comment here} + See the "returns" section of `tf.nn.convolution` for details. data_format: A string. 'NHWC' and 'NCHW' are supported. name: Optional name for the returned tensor. @@ -2594,7 +2603,7 @@ def _calc_dilation2d_flops(graph, node): output_shape.assert_is_fully_defined() filter_height = int(filter_shape[0]) filter_width = int(filter_shape[1]) - output_count = np.prod(output_shape.as_list()) + output_count = np.prod(output_shape.as_list(), dtype=np.int64) return ops.OpStats("flops", (output_count * filter_height * filter_width * 2)) diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 035b4735affbd37f9de94057eed6f7b5d9aadd6e..ce0db6b264a55d0565b4ffcc6e6a783a5bf555a1 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import math +from absl.testing import parameterized import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin @@ -67,7 +68,7 @@ class ZeroFractionTest(test_lib.TestCase): self.assertTrue(np.isnan(y)) -class SoftmaxTest(test_lib.TestCase): +class SoftmaxTest(test_lib.TestCase, parameterized.TestCase): def _softmax(self, x): assert len(x.shape) == 2 @@ -76,7 +77,7 @@ class SoftmaxTest(test_lib.TestCase): z = u.sum(1)[:, np.newaxis] return u / z - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSoftmax(self): x_shape = [5, 10] x_np = np.random.randn(*x_shape).astype(np.float32) @@ -102,15 +103,15 @@ class SoftmaxTest(test_lib.TestCase): self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps) self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps) - def testGradient(self): - x_shape = [5, 10] + @parameterized.parameters(((5, 10),), ((2, 3, 4),)) + def testGradient(self, x_shape): x_np = np.random.randn(*x_shape).astype(np.float64) with self.test_session(): x_tf = constant_op.constant(x_np) y_tf = nn_ops.softmax(x_tf) err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, x_shape) - eps = 1e-8 + eps = 2e-8 self.assertLess(err, eps) @@ -123,7 +124,7 @@ class LogPoissonLossTest(test_lib.TestCase): lpl += np.ma.masked_array(stirling_approx, mask=(z <= 1)).filled(0.) return lpl - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLogPoissonLoss(self): x_shape = [5, 10] x_np = np.random.randn(*x_shape).astype(np.float32) @@ -156,7 +157,7 @@ class LogPoissonLossTest(test_lib.TestCase): self.assertLess(err_stirling, eps) -class LogSoftmaxTest(test_lib.TestCase): +class LogSoftmaxTest(test_lib.TestCase, parameterized.TestCase): def _log_softmax(self, x): assert len(x.shape) == 2 @@ -164,7 +165,7 @@ class LogSoftmaxTest(test_lib.TestCase): u = x - m return u - np.log(np.sum(np.exp(u), 1, keepdims=True)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLogSoftmax(self): x_shape = [5, 10] x_np = np.random.randn(*x_shape).astype(np.float32) @@ -187,8 +188,8 @@ class LogSoftmaxTest(test_lib.TestCase): self.assertAllClose(x_neg_axis_tf, y_pos_axis_tf, eps) self.assertAllClose(y_pos_axis_tf, z_gt_axis_tf, eps) - def testGradient(self): - x_shape = [5, 10] + @parameterized.parameters(((5, 10),), ((2, 3, 4),)) + def testGradient(self, x_shape): x_np = np.random.randn(*x_shape).astype(np.float64) with self.test_session(): x_tf = constant_op.constant(x_np) @@ -201,7 +202,7 @@ class LogSoftmaxTest(test_lib.TestCase): class L2LossTest(test_lib.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testL2Loss(self): for dtype in [dtypes.float32, dtypes.float64]: x = constant_op.constant( @@ -219,7 +220,7 @@ class L2LossTest(test_lib.TestCase): output = nn_ops.l2_loss(x) err = gradient_checker.compute_gradient_error(x, x_shape, output, [1]) print("L2Loss gradient err = %g " % err) - err_tolerance = 1e-11 + err_tolerance = 1e-10 self.assertLess(err, err_tolerance) @@ -235,7 +236,7 @@ class L2NormalizeTest(test_lib.TestCase): norm = np.apply_along_axis(np.linalg.norm, dim, x) return x / np.expand_dims(norm, dim) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testL2Normalize(self): x_shape = [20, 7, 3] np.random.seed(1) @@ -246,7 +247,7 @@ class L2NormalizeTest(test_lib.TestCase): y_tf = nn_impl.l2_normalize(x_tf, dim) self.assertAllClose(y_np, self.evaluate(y_tf)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testL2NormalizeDimArray(self): x_shape = [20, 7, 3] np.random.seed(1) diff --git a/tensorflow/python/ops/numerics.py b/tensorflow/python/ops/numerics.py index d348e47f57b703138aabfc3463e750b795113335..8fcbd7d83407ac1972f5165175dc498f06615cc2 100644 --- a/tensorflow/python/ops/numerics.py +++ b/tensorflow/python/ops/numerics.py @@ -56,8 +56,8 @@ def add_check_numerics_ops(): `check_numerics` op for all of its (`half`, `float`, or `double`) inputs is guaranteed to run before the `check_numerics` op on any of its outputs. - Note: This API is not compatible with the use of @{tf.cond} or - @{tf.while_loop}, and will raise a `ValueError` if you attempt to call it + Note: This API is not compatible with the use of `tf.cond` or + `tf.while_loop`, and will raise a `ValueError` if you attempt to call it in such a graph. Returns: diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..015181af47b310cd6aec52b4a383f8868dddc493 --- /dev/null +++ b/tensorflow/python/ops/parallel_for/BUILD @@ -0,0 +1,129 @@ +package( + default_visibility = [ + "//tensorflow:internal", + ], +) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +licenses(["notice"]) # Apache 2.0 + +py_library( + name = "parallel_for", + srcs = [ + "__init__.py", + "control_flow_ops.py", + "gradients.py", + "pfor.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":control_flow_ops", + ":gradients", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:functional_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_array_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", + "//tensorflow/python:util", + "@absl_py//absl/flags", + ], +) + +py_library( + name = "pfor_lib", + srcs = ["pfor.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:functional_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_array_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", + "@absl_py//absl/flags", + ], +) + +py_library( + name = "control_flow_ops", + srcs = ["control_flow_ops.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":pfor_lib", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:tensor_array_ops", + "//tensorflow/python:util", + ], +) + +cuda_py_test( + name = "control_flow_ops_test", + size = "large", + srcs = ["control_flow_ops_test.py"], + additional_deps = [ + ":control_flow_ops", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:gradients", + "//tensorflow/python:logging_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:session", + "//tensorflow/python:tensor_array_grad", + "//tensorflow/python:random_ops", + "//tensorflow/python:util", + ], +) + +py_library( + name = "gradients", + srcs = ["gradients.py"], + srcs_version = "PY2AND3", + deps = [ + ":control_flow_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:util", + ], +) + +cuda_py_test( + name = "gradients_test", + size = "large", + srcs = ["gradients_test.py"], + additional_deps = [ + ":control_flow_ops", + ":gradients", + "//third_party/py/numpy", + "//tensorflow/python:layers", + "//tensorflow/python:client_testlib", + "//tensorflow/python:random_ops", + "//tensorflow/python/ops/losses", + ], +) diff --git a/tensorflow/python/ops/parallel_for/__init__.py b/tensorflow/python/ops/parallel_for/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dd8bc6d487f625c9ab442c91da417dce00074a2a --- /dev/null +++ b/tensorflow/python/ops/parallel_for/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops for pfor, for_loop, jacobian.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops.parallel_for import * # pylint: disable=wildcard-import +from tensorflow.python.ops.parallel_for.control_flow_ops import for_loop +from tensorflow.python.ops.parallel_for.control_flow_ops import pfor +from tensorflow.python.ops.parallel_for.gradients import batch_jacobian +from tensorflow.python.ops.parallel_for.gradients import jacobian diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ccf2eb82146969532c84b7d56d40974e94337507 --- /dev/null +++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py @@ -0,0 +1,123 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""for_loop and pfor ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops.parallel_for.pfor import PFor +from tensorflow.python.util import nest + + +def for_loop(loop_fn, loop_fn_dtypes, iters): + """Runs `loop_fn` `iters` times and stacks the outputs. + + + Runs `loop_fn` `iters` times, with input values from 0 to `iters - 1`, and + stacks corresponding outputs of the different runs. + + Args: + loop_fn: A function that takes an int32 scalar tf.Tensor object representing + the iteration number, and returns a possibly nested structure of tensor + objects. The shape of these outputs should not depend on the input. + loop_fn_dtypes: dtypes for the outputs of loop_fn. + iters: Number of iterations for which to run loop_fn. + + Returns: + Returns a nested structure of stacked output tensor objects with the same + nested structure as the output of `loop_fn`. + """ + + flat_loop_fn_dtypes = nest.flatten(loop_fn_dtypes) + + def while_body(i, *ta_list): + """Body of while loop.""" + fn_output = nest.flatten(loop_fn(i)) + if len(fn_output) != len(flat_loop_fn_dtypes): + raise ValueError( + "Number of expected outputs, %d, does not match the number of " + "actual outputs, %d, from loop_fn" % (len(flat_loop_fn_dtypes), + len(fn_output))) + outputs = [] + for out, ta in zip(fn_output, ta_list): + # TODO(agarwal): support returning Operation objects from loop_fn. + assert isinstance(out, ops.Tensor) + outputs.append(ta.write(i, array_ops.expand_dims(out, 0))) + return tuple([i + 1] + outputs) + + ta_list = control_flow_ops.while_loop( + lambda i, *ta: i < iters, while_body, [0] + [ + tensor_array_ops.TensorArray(dtype, iters) + for dtype in flat_loop_fn_dtypes + ])[1:] + + # TODO(rachelim): enable this for sparse tensors + return nest.pack_sequence_as(loop_fn_dtypes, [ta.concat() for ta in ta_list]) + + +def pfor(loop_fn, iters): + """Equivalent to running `loop_fn` `iters` times and stacking the outputs. + + `pfor` has functionality similar to `for_loop`, i.e. running `loop_fn` `iters` + times, with input from 0 to `iters - 1`, and stacking corresponding output of + each iteration. However the implementation does not use a tf.while_loop. + Instead it adds new operations to the graph that collectively compute the same + value as what running `loop_fn` in a loop would compute. + + + This is an experimental feature and currently has a lot of limitations: + - There should be no data depenendency between the different iterations. For + example, a future iteration should not depend on a value or side-effect of + a previous iteration. + - Stateful kernels may mostly not be supported since these often imply a + data dependency or ordering of the iterations. We do support a limited set + of such stateful kernels though (like RandomFoo, Variable operations like + reads, etc). + - Conversion works only on a limited set of kernels for which a converter + has been registered. + - loop_fn cannot currently contain control flow operations like + tf.while_loop or tf.cond. + - `loop_fn` should return nested structure of Tensors or Operations. However + if an Operation is returned, it should have zero outputs. + - The shape and dtype of `loop_fn` outputs should not depend on the input + to loop_fn. + + Args: + loop_fn: A function that takes an int32 scalar tf.Tensor object representing + the iteration number, and returns a possibly nested structure of Tensor or + Operation objects. + iters: Number of iterations for which to run loop_fn. + + Returns: + Returns a nested structure of stacked tensor objects with the same nested + structure as the output of `loop_fn`. + """ + existing_ops = set(ops.get_default_graph().get_operations()) + with ops.name_scope("loop_body"): + loop_var = array_ops.placeholder(dtypes.int32, shape=[]) + loop_fn_outputs = loop_fn(loop_var) + new_ops = set(ops.get_default_graph().get_operations()) - existing_ops + iters = ops.convert_to_tensor(iters) + with ops.name_scope("pfor"): + converter = PFor(loop_var, iters, new_ops) + outputs = [] + for loop_fn_output in nest.flatten(loop_fn_outputs): + outputs.append(converter.convert(loop_fn_output)) + return nest.pack_sequence_as(loop_fn_outputs, outputs) diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c0e66cb0b874b183d53cc34dbb3aa3d182e255a4 --- /dev/null +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -0,0 +1,1404 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for pfor and for_loop.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from absl import flags +import numpy as np + +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import gradients as gradient_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import variables +from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops +from tensorflow.python.platform import test +from tensorflow.python.util import nest + + +class PForTest(test.TestCase): + + def _run_targets(self, targets1, targets2=None, run_init=True): + targets1 = nest.flatten(targets1) + targets2 = ([] if targets2 is None else nest.flatten(targets2)) + assert len(targets1) == len(targets2) or not targets2 + if run_init: + init = variables.global_variables_initializer() + self.evaluate(init) + return self.evaluate(targets1 + targets2) + + def run_and_assert_equal(self, targets1, targets2): + outputs = self._run_targets(targets1, targets2) + outputs = nest.flatten(outputs) # flatten SparseTensorValues + n = len(outputs) // 2 + for i in range(n): + if outputs[i + n].dtype != np.object: + self.assertAllClose(outputs[i + n], outputs[i], rtol=1e-4, atol=1e-5) + else: + self.assertAllEqual(outputs[i + n], outputs[i]) + + def _test_loop_fn(self, loop_fn, iters, loop_fn_dtypes=dtypes.float32): + t1 = pfor_control_flow_ops.pfor(loop_fn, iters=iters) + t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters) + self.run_and_assert_equal(t1, t2) + + def test_op_conversion_fallback_to_while_loop(self): + # Note that we used top_k op for this test. If a converter gets defined for + # it, we will need to find another op for which a converter has not been + # defined. + x = random_ops.random_uniform([3, 2, 4]) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + return nn.top_k(x_i) + + with self.assertRaisesRegexp(ValueError, "No converter defined"): + self._test_loop_fn( + loop_fn, 3, loop_fn_dtypes=[dtypes.float32, dtypes.int32]) + flags.FLAGS.op_conversion_fallback_to_while_loop = True + self._test_loop_fn( + loop_fn, 3, loop_fn_dtypes=[dtypes.float32, dtypes.int32]) + flags.FLAGS.op_conversion_fallback_to_while_loop = False + + +class ArrayTest(PForTest): + + def test_gather(self): + x = random_ops.random_uniform([3, 3, 3]) + + def loop_fn(i): + outputs = [] + x_i = array_ops.gather(x, i) + for y in [x, x_i]: + axes = [0, 2, -1] if y == x else [0] + for axis in axes: + outputs.append(array_ops.gather(y, 2, axis=axis)) + outputs.append(array_ops.gather(y, i, axis=axis)) + outputs.append(array_ops.gather(y, [i], axis=axis)) + outputs.append(array_ops.gather(y, [i, 2], axis=axis)) + outputs.append(array_ops.gather(y, [[2, i], [i, 1]], axis=axis)) + return outputs + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 20) + + def test_shape(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + return array_ops.shape(x_i), array_ops.shape(x_i, out_type=dtypes.int64) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32, dtypes.int64]) + + def test_size(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + return array_ops.size(x_i), array_ops.size(x_i, out_type=dtypes.int64) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32, dtypes.int64]) + + def test_rank(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + return array_ops.rank(x_i) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32]) + + def test_shape_n(self): + x = random_ops.random_uniform([3, 2, 3]) + y = random_ops.random_uniform([3]) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + y_i = array_ops.gather(y, i) + return array_ops.shape_n([x_i, x, y, y_i]), array_ops.shape_n( + [x_i, x, y, y_i], out_type=dtypes.int64) + + self._test_loop_fn( + loop_fn, 3, loop_fn_dtypes=[dtypes.int32] * 4 + [dtypes.int64] * 4) + + def test_reshape(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.reshape(x1, [-1]), array_ops.reshape(x1, [1, 3, 1, -1]) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + + def test_expand_dims(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.expand_dims( + x1, axis=-1), array_ops.expand_dims( + x1, axis=1) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + + def test_slice(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.slice(x1, begin=(0, 1), size=(2, 1)) + + self._test_loop_fn(loop_fn, 3) + + def test_tile(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.tile(x1, [2, 1]) + + self._test_loop_fn(loop_fn, 3) + + def test_tile_loop_dependent(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.tile(x1, [i, 1]) + + with self.assertRaisesRegexp(ValueError, "expected to be loop invariant"): + pfor_control_flow_ops.pfor(loop_fn, 2) + + def test_pack(self): + x = random_ops.random_uniform([3, 2, 3]) + y = random_ops.random_uniform([2, 3]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.stack([x1, y], axis=-1) + + self._test_loop_fn(loop_fn, 1) + + def test_unpack(self): + x = random_ops.random_uniform([3, 2, 3, 4]) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + return array_ops.unstack( + x_i, 4, axis=-1), array_ops.unstack( + x_i, 3, axis=1) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 7) + + def test_pad(self): + x = random_ops.random_uniform([3, 2, 3]) + padding = constant_op.constant([[1, 2], [3, 4]]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.pad(x1, padding, mode="CONSTANT") + + self._test_loop_fn(loop_fn, 3) + + def test_split(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.split(x1, 2, axis=0), array_ops.split(x1, 3, axis=-1) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 5) + + def test_transpose(self): + x = random_ops.random_uniform([3, 2, 3, 4]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.transpose(x1, [2, 1, 0]) + + self._test_loop_fn(loop_fn, 3) + + def test_zeros_like(self): + x = random_ops.random_uniform([3, 2, 3]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + z = array_ops.zeros_like(x1), + return z, z + x1 + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + + def test_concat_v2(self): + x = random_ops.random_uniform([3, 2, 3]) + y = random_ops.random_uniform([2, 3]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.concat( + [x1, x1, y], axis=0), array_ops.concat( + [x1, x1, y], axis=-1) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + + def test_unary_cwise_ops(self): + for op in [array_ops.identity, array_ops.stop_gradient]: + x = random_ops.random_uniform([3, 5]) + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + x1 = array_ops.gather(x, i) + y = op(x1) + x1 + loss = nn.l2_loss(y) + return op(x), y, gradient_ops.gradients(loss, x1) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 3) + + def test_strided_slice(self): + x = random_ops.random_uniform([3, 3, 4, 4, 2, 2, 2]) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + y = x_i[:2, ::2, 1::3, ..., array_ops.newaxis, 1] + loss = nn.l2_loss(y) + return y, gradient_ops.gradients(loss, x_i) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + + +class MathTest(PForTest): + + def test_unary_cwise_ops(self): + for op in [ + math_ops.tanh, nn.relu, math_ops.sigmoid, math_ops.negative, + math_ops.square + ]: + x = random_ops.random_uniform([3, 5]) + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + x1 = array_ops.gather(x, i) + y = op(x1) + loss = math_ops.reduce_sum(y * y) + return op(x), y, gradient_ops.gradients(loss, x1) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 3) + + def test_unary_cwise_no_grad(self): + for op in [math_ops.ceil, math_ops.floor, math_ops.logical_not]: + x = random_ops.random_uniform([3, 5]) + if op == math_ops.logical_not: + x = x > 0 + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + return op(array_ops.gather(x, i)) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=x.dtype) + + def test_binary_cwise_ops(self): + logical_ops = [ + math_ops.logical_and, math_ops.logical_or, math_ops.logical_xor + ] + bool_ops = [ + math_ops.less, math_ops.less_equal, math_ops.greater, + math_ops.greater_equal, math_ops.equal, math_ops.not_equal + ] + float_ops = [ + math_ops.add, math_ops.subtract, math_ops.multiply, math_ops.divide, + math_ops.maximum, math_ops.minimum + ] + for op in logical_ops + bool_ops + float_ops: + x = random_ops.random_uniform([7, 3, 5]) + y = random_ops.random_uniform([3, 5]) + if op in logical_ops: + x = x > 0 + y = y > 0 + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + x1 = array_ops.gather(x, i) + y1 = array_ops.gather(y, i) + return op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1) + + # pylint: enable=cell-var-from-loop + + dtype = dtypes.float32 if op in float_ops else dtypes.bool + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtype] * 5) + + def test_addn(self): + x = random_ops.random_uniform([2, 3, 5]) + y = random_ops.random_uniform([3, 5]) + z = random_ops.random_uniform([3, 5]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return math_ops.add_n([x1, y, z]) + + self._test_loop_fn(loop_fn, 2) + + def test_matmul(self): + for tr_a in (True, False): + for tr_b in (True, False): + for stack_a in (True, False): + for stack_b in (True, False): + shape_a = (5, 3) if tr_a else (3, 5) + if stack_a: + shape_a = (2,) + shape_a + shape_b = (7, 5) if tr_b else (5, 7) + if stack_b: + shape_b = (2,) + shape_b + + x = random_ops.random_uniform(shape_a) + y = random_ops.random_uniform(shape_b) + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + a = array_ops.gather(x, i) if stack_a else x + b = array_ops.gather(y, i) if stack_b else y + return math_ops.matmul(a, b, transpose_a=tr_a, transpose_b=tr_b) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 2) + + def test_batch_matmul(self): + for tr_a in (True, False): + for tr_b in (True, False): + for stack_a in (True, False): + for stack_b in (True, False): + shape_a = (4, 5, 3) if tr_a else (4, 3, 5) + if stack_a: + shape_a = (2,) + shape_a + shape_b = (4, 7, 5) if tr_b else (4, 5, 7) + if stack_b: + shape_b = (2,) + shape_b + + x = random_ops.random_uniform(shape_a) + y = random_ops.random_uniform(shape_b) + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + a = array_ops.gather(x, i) if stack_a else x + b = array_ops.gather(y, i) if stack_b else y + return math_ops.matmul(a, b, transpose_a=tr_a, transpose_b=tr_b) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 2) + + def test_reduction(self): + x = random_ops.random_uniform([2, 3, 4, 5]) + for op in [ + math_ops.reduce_sum, math_ops.reduce_prod, math_ops.reduce_max, + math_ops.reduce_min + ]: + for axis in ([1], None, [0, 2]): + for keepdims in (True, False): + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + a = array_ops.gather(x, i) + return op(a, axis=axis, keepdims=keepdims) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 2) + + def test_cum_sum(self): + x = random_ops.random_uniform([2, 3, 4, 5]) + for axis in (1, -2): + for exclusive in (True, False): + for reverse in (True, False): + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + a = array_ops.gather(x, i) + return math_ops.cumsum( + a, axis=axis, exclusive=exclusive, reverse=reverse) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 2) + + def test_cum_prod(self): + x = random_ops.random_uniform([2, 3, 4, 5]) + for axis in (1, -2): + for exclusive in (True, False): + for reverse in (True, False): + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + a = array_ops.gather(x, i) + return math_ops.cumprod( + a, axis=axis, exclusive=exclusive, reverse=reverse) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 2) + + def test_bias_add(self): + x_shape = [2, 3, 4, 5, 6] + x = random_ops.random_uniform(x_shape) + for data_format in ("NCHW", "NHWC"): + bias_dim = 2 if data_format == "NCHW" else -1 + bias_shape = x_shape[bias_dim] + bias = random_ops.random_uniform([bias_shape]) + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + a = array_ops.gather(x, i) + y = nn.bias_add(a, bias, data_format=data_format) + loss = math_ops.reduce_sum(y * y) + return y, gradient_ops.gradients(loss, bias) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn( + loop_fn, 2, loop_fn_dtypes=[dtypes.float32, dtypes.float32]) + + def test_unsorted_segment_sum(self): + t = random_ops.random_uniform([3, 3, 2]) + segment_ids = constant_op.constant([[0, 0, 2], [0, 1, 2], [2, 2, 2]]) + num_segments = 3 + + def loop_fn(i): + data = array_ops.gather(t, i) + data_0 = array_ops.gather(t, 0) + seg_ids = array_ops.gather(segment_ids, i) + return (math_ops.unsorted_segment_sum(data, seg_ids, num_segments), + math_ops.unsorted_segment_sum(data_0, seg_ids, num_segments)) + + self._test_loop_fn(loop_fn, 3, [dtypes.float32] * 2) + + def test_cast(self): + x = constant_op.constant([[1], [2]]) + y = constant_op.constant([[1.0], [2.0]]) + + def loop_fn(i): + return (math_ops.cast(array_ops.gather(x, i), dtypes.float32), + math_ops.cast(array_ops.gather(y, i), dtypes.int32)) + + self._test_loop_fn( + loop_fn, 2, loop_fn_dtypes=[dtypes.float32, dtypes.int32]) + + def test_tanh_axpy(self): + a = constant_op.constant(3.) + x = random_ops.random_uniform([4, 5]) + y = random_ops.random_uniform([6, 5]) + n = x.shape[0] + + def loop_fn(i): + return math_ops.tanh(a * array_ops.gather(x, i) + array_ops.gather(y, i)) + + self._test_loop_fn(loop_fn, n) + + def test_select(self): + cond = constant_op.constant([True, False]) + a = random_ops.random_uniform([2, 3, 5]) + b = random_ops.random_uniform([2, 3, 5]) + for cond_shape in [2], [2, 3], [2, 3, 5]: + cond = random_ops.random_uniform(cond_shape) > 0.5 + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + a_i = array_ops.gather(a, i) + b_i = array_ops.gather(b, i) + cond_i = array_ops.gather(cond, i) + return array_ops.where(cond_i, a_i, b_i) + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 2) + + +class NNTest(PForTest): + + def test_conv2d(self): + x = random_ops.random_uniform([3, 2, 12, 12, 3]) + filt = random_ops.random_uniform([3, 3, 3, 7]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return nn.conv2d( + x1, filt, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC") + + self._test_loop_fn(loop_fn, 3) + + def test_conv2d_backprop_input(self): + x_shape = [2, 12, 12, 3] + filt = random_ops.random_uniform([3, 3, 3, 7]) + grad = random_ops.random_uniform([3, 2, 5, 5, 7]) + + def loop_fn(i): + grad1 = array_ops.gather(grad, i) + return nn.conv2d_backprop_input( + x_shape, + filt, + grad1, + strides=[1, 2, 2, 1], + padding="VALID", + data_format="NHWC") + + self._test_loop_fn(loop_fn, 3) + + def test_conv2d_backprop_filter(self): + x = random_ops.random_uniform([3, 2, 12, 12, 3]) + x_0 = array_ops.gather(x, 0) + filter_sizes = [3, 3, 3, 7] + grad = random_ops.random_uniform([3, 2, 5, 5, 7]) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + grad_i = array_ops.gather(grad, i) + return [ + nn.conv2d_backprop_filter( + inp, + filter_sizes, + grad_i, + strides=[1, 2, 2, 1], + padding="VALID", + data_format="NHWC") for inp in [x_i, x_0] + ] + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + + def test_avg_pool(self): + x = random_ops.random_uniform([3, 2, 12, 12, 3]) + ksize = [1, 3, 3, 1] + + def loop_fn(i): + x1 = array_ops.gather(x, i) + output = nn.avg_pool( + x1, ksize, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC") + loss = nn.l2_loss(output) + return output, gradient_ops.gradients(loss, x1) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + + def test_max_pool(self): + x = random_ops.random_uniform([3, 2, 12, 12, 3]) + ksize = [1, 3, 3, 1] + + def loop_fn(i): + x1 = array_ops.gather(x, i) + output = nn.max_pool( + x1, ksize, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC") + loss = nn.l2_loss(output) + return output, gradient_ops.gradients(loss, x1) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + + def test_fused_batch_norm(self): + data_formats = ["NHWC"] + if test.is_gpu_available(): + data_formats.append("NCHW") + for is_training in (True, False): + for data_format in data_formats: + if data_format == "NCHW": + x = random_ops.random_uniform([3, 1, 2, 5, 5]) + else: + x = random_ops.random_uniform([3, 1, 5, 5, 2]) + scale = random_ops.random_uniform([2]) + offset = random_ops.random_uniform([2]) + mean = None if is_training else random_ops.random_uniform([2]) + variance = None if is_training else random_ops.random_uniform([2]) + + # pylint: disable=cell-var-from-loop + def loop_fn(i): + x1 = array_ops.gather(x, i) + outputs = nn.fused_batch_norm( + x1, + scale, + offset, + mean=mean, + variance=variance, + epsilon=0.01, + data_format=data_format, + is_training=is_training) + outputs = list(outputs) + # We only test the first value of outputs when is_training is False. + # It looks like CPU and GPU have different outputs for batch_mean and + # batch_variance for this case. + if not is_training: + outputs[1] = constant_op.constant(0.) + outputs[2] = constant_op.constant(0.) + loss = nn.l2_loss(outputs[0]) + gradients = gradient_ops.gradients(loss, [x1, scale, offset]) + return outputs + gradients + + # pylint: enable=cell-var-from-loop + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 6) + + def test_softmax_cross_entropy_with_logits(self): + logits = random_ops.random_uniform([3, 2, 4]) + labels = random_ops.random_uniform([3, 2, 4]) + labels /= math_ops.reduce_sum(labels, axis=[2], keepdims=True) + + def loop_fn(i): + logits_i = array_ops.gather(logits, i) + labels_i = array_ops.gather(labels, i) + loss = nn.softmax_cross_entropy_with_logits( + labels=labels_i, logits=logits_i) + return loss, gradient_ops.gradients(math_ops.reduce_sum(loss), logits_i) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + + +class RandomTest(PForTest): + + # The random values generated in the two implementations are not guaranteed to + # match. So we only check the returned shapes. + def run_and_assert_equal(self, targets1, targets2): + outputs = self._run_targets(targets1, targets2) + n = len(outputs) // 2 + for i in range(n): + self.assertAllEqual(outputs[i].shape, outputs[i + n].shape) + + def test_random_uniform(self): + + def loop_fn(_): + return random_ops.random_uniform([3]) + + self._test_loop_fn(loop_fn, 5) + + def test_random_uniform_int(self): + + def loop_fn(_): + return random_ops.random_uniform([3], maxval=1, dtype=dtypes.int32) + + self._test_loop_fn(loop_fn, 5, loop_fn_dtypes=dtypes.int32) + + def test_random_standard_normal(self): + + def loop_fn(_): + return random_ops.random_normal([3]) + + self._test_loop_fn(loop_fn, 5) + + def test_truncated_normal(self): + + def loop_fn(_): + return random_ops.truncated_normal([3]) + + self._test_loop_fn(loop_fn, 5) + + def test_random_gamma(self): + + def loop_fn(_): + return random_ops.random_gamma([3], alpha=[0.5]) + + self._test_loop_fn(loop_fn, 5) + + def test_random_poisson_v2(self): + + def loop_fn(_): + return random_ops.random_poisson(lam=[1.3], shape=[3]) + + self._test_loop_fn(loop_fn, 5) + + +class LoggingTest(PForTest): + + def test_print(self): + x = random_ops.random_uniform([3, 5]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return logging_ops.Print( + x1, [x1, "x1", array_ops.shape(x1)], summarize=10) + + self._test_loop_fn(loop_fn, 3) + + def test_assert(self): + + def loop_fn(i): + return control_flow_ops.Assert(i < 10, [i, [10], [i + 1]]) + + # TODO(agarwal): make this work with for_loop. + with session.Session() as sess: + sess.run(pfor_control_flow_ops.pfor(loop_fn, 3)) + + +class TensorArrayTest(PForTest): + + def test_create_outside_and_read(self): + + ta = tensor_array_ops.TensorArray( + dtypes.int32, 2, clear_after_read=False).write(0, 0).write(1, 1) + + def loop_fn(i): + return ta.read(i), ta.read(0) + + self._test_loop_fn(loop_fn, 2, [dtypes.int32] * 2) + + def test_create_outside_and_gather(self): + + ta = tensor_array_ops.TensorArray( + dtypes.int32, 2, clear_after_read=False).write(0, 0).write(1, 1) + + def loop_fn(i): + return ta.gather([i]), ta.gather([0, 1]) + + self._test_loop_fn(loop_fn, 2, [dtypes.int32] * 2) + + def test_create_outside_and_write_and_scatter(self): + + t = tensor_array_ops.TensorArray(dtypes.int32, 10, clear_after_read=False) + handle = t.handle + + def loop_fn(i): + ta = t.write(i + 2, 2 * i).write(i, 5) + ta = ta.scatter([4 + i], [4]).scatter([6 + i, 8 + i], [6 + i, 8 + i]) + return ta.flow + + t1 = pfor_control_flow_ops.pfor(loop_fn, iters=2) + out1 = tensor_array_ops.TensorArray( + dtypes.int32, handle=handle, flow=t1[-1]).stack() + output1 = self._run_targets(out1) + + t2 = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, iters=2) + out2 = tensor_array_ops.TensorArray( + dtypes.int32, handle=handle, flow=t2[-1]).stack() + output2 = self._run_targets(out2) + self.assertAllClose(output2, output1) + + def test_create_inside_and_write(self): + + def loop_fn(i): + # TODO(agarwal): switching the order of writes to ta1 does not work. + ta1 = tensor_array_ops.TensorArray(dtypes.int32, 2).write(0, i).write( + 1, 1) + ta2 = tensor_array_ops.TensorArray(dtypes.int32, 1).write(0, 1) + return ta1.stack(), ta2.stack() + + self._test_loop_fn(loop_fn, 3, [dtypes.int32] * 2) + + def test_create_inside_and_scatter(self): + + def loop_fn(i): + # TODO(agarwal): switching the order of scatter to ta1 does not work. + ta1 = tensor_array_ops.TensorArray(dtypes.int32, 2).scatter( + [0], [[i, 2]]).scatter([1], [[1, 2]]) + ta2 = tensor_array_ops.TensorArray(dtypes.int32, + 2).scatter([0], [3]).scatter([1], [4]) + return ta1.stack(), ta2.stack() + + self._test_loop_fn(loop_fn, 3, [dtypes.int32] * 2) + + def test_create_inside_and_read(self): + + def loop_fn(i): + ta1 = tensor_array_ops.TensorArray( + dtypes.int32, 2, clear_after_read=False).write(0, i).write(1, 1) + ta2 = tensor_array_ops.TensorArray( + dtypes.int32, 2, clear_after_read=False).write(0, 1).write(1, 2) + # TODO(agarwal): ta1.read(i) currently is not supported. + return ta1.read(0), ta2.read(0), ta2.read(i) + + self._test_loop_fn(loop_fn, 2, [dtypes.int32] * 3) + + def test_create_inside_and_gather(self): + + def loop_fn(i): + ta1 = tensor_array_ops.TensorArray( + dtypes.int32, 2, clear_after_read=False).write(0, i).write(1, 1) + ta2 = tensor_array_ops.TensorArray( + dtypes.int32, 2, clear_after_read=False).write(0, 1).write(1, 2) + # TODO(agarwal): ta1.read(i) currently is not supported. + return ta1.gather([0, 1]), ta2.gather([0, 1]), ta2.gather([i]) + + self._test_loop_fn(loop_fn, 2, [dtypes.int32] * 3) + + def test_grad(self): + x = random_ops.random_uniform([3, 2]) + ta = tensor_array_ops.TensorArray( + dtypes.float32, 3, clear_after_read=False).unstack(x) + y = math_ops.square(ta.stack()) + + def loop_fn(i): + y_i = array_ops.gather(y, i) + grad = gradient_ops.gradients(y_i, x)[0] + return array_ops.gather(grad, i) + + t1 = pfor_control_flow_ops.pfor(loop_fn, iters=3) + # y = x * x. Hence dy/dx = 2 * x. + actual_grad = 2.0 * x + with session.Session() as sess: + actual_grad, computed_grad = sess.run([t1, actual_grad]) + self.assertAllClose(actual_grad, computed_grad) + + +class StackTest(PForTest): + + def test_stack_inside_loop_invariant(self): + + def loop_fn(_): + s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32) + op1 = data_flow_ops.stack_push_v2(s, 1) + with ops.control_dependencies([op1]): + op2 = data_flow_ops.stack_push_v2(s, 2) + with ops.control_dependencies([op2]): + e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) + with ops.control_dependencies([e2]): + e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) + return e1, e2 + + self._test_loop_fn(loop_fn, 2, [dtypes.int32] * 2) + + def test_stack_inside_push_loop_dependent(self): + + def loop_fn(i): + s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32) + op1 = data_flow_ops.stack_push_v2(s, i) + with ops.control_dependencies([op1]): + op2 = data_flow_ops.stack_push_v2(s, 2) + with ops.control_dependencies([op2]): + e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) + with ops.control_dependencies([e2]): + e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) + return e1, e2 + + self._test_loop_fn(loop_fn, 2, [dtypes.int32] * 2) + + def test_stack_outside_pop(self): + s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32) + op = data_flow_ops.stack_push_v2(s, 5) + with ops.control_dependencies([op]): + op = data_flow_ops.stack_push_v2(s, 6) + with ops.control_dependencies([op]): + op = data_flow_ops.stack_push_v2(s, 7) + + def loop_fn(_): + e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) + with ops.control_dependencies([e1]): + e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) + return e1, e2 + + with ops.control_dependencies([op]): + e1, e2 = pfor_control_flow_ops.pfor(loop_fn, iters=2) + with ops.control_dependencies([e1, e2]): + e3 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32) + v1, v2, v3 = self._run_targets([e1, e2, e3], run_init=False) + self.assertAllEqual([7, 7], v1) + self.assertAllEqual([6, 6], v2) + self.assertAllEqual(5, v3) + + def test_stack_outside_push(self): + s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32) + + def loop_fn(_): + return data_flow_ops.stack_push_v2(s, 7) + + with self.assertRaisesRegexp(ValueError, "StackPushV2 not allowed.*"): + pfor_control_flow_ops.pfor(loop_fn, iters=2) + + +# TODO(agarwal): test nested while_loops. This currently requires converting a +# tf.cond. +class ControlFlowTest(PForTest): + + def test_while_outside_loop(self): + + x = control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0]) + + def loop_fn(i): + return x + i + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32]) + + def test_invariant_while(self): + + def loop_fn(_): + return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0]) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32]) + + def test_invariant_while_with_control_dependency(self): + + def loop_fn(i): + with ops.control_dependencies([i]): + return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, + [0]) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32]) + + def test_while_with_stateful_ops(self): + + def loop_fn(_): + return control_flow_ops.while_loop( + lambda j, x: j < 4, + lambda j, x: (j + 1, x + random_ops.random_uniform([])), [0, 0.])[0] + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32]) + + def test_while_unstacked_condition(self): + + def loop_fn(i): + return control_flow_ops.while_loop(lambda j, x: j < 4, + lambda j, x: (j + 1, x + i), [0, 0]) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32, dtypes.int32]) + + def test_while(self): + x = random_ops.random_uniform([3, 5]) + lengths = constant_op.constant([4, 0, 2]) + + def loop_fn(i): + x_i = array_ops.gather(x, i) + lengths_i = array_ops.gather(lengths, i) + + _, total = control_flow_ops.while_loop( + lambda j, _: j < lengths_i, + lambda j, t: (j + 1, t + array_ops.gather(x_i, j)), [0, 0.]) + return total + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32]) + + def test_while_jacobian(self): + x = random_ops.random_uniform([1, 3]) + y = random_ops.random_uniform([3, 3]) + + # out = x @ y @ y @ y @ y, where @ is matmul operator. + _, out = control_flow_ops.while_loop( + lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)), + [0, x]) + + def loop_fn(i): + out_i = array_ops.gather(out, i, axis=1) + return array_ops.reshape(gradient_ops.gradients(out_i, x)[0], [-1]) + + out = pfor_control_flow_ops.pfor(loop_fn, iters=3) + + # The above code does not work with tf.while_loop instead of pfor. So we + # manually compute the expected output here. + # Note that gradient of output w.r.t is (y @ y @ y @ y)^T. + expected_output = y + for _ in range(3): + expected_output = math_ops.matmul(expected_output, y) + expected_output = array_ops.transpose(expected_output, [1, 0]) + + with session.Session() as sess: + out, expected = sess.run([out, expected_output]) + self.assertAllClose(expected, out) + + def test_tensor_array_as_loop_variable(self): + + def loop_fn(i): + + def body(j, ta): + ta = ta.write(j, i + j * j) + return j + 1, ta + + _, ta = control_flow_ops.while_loop( + lambda j, _: j < 4, body, + (0, tensor_array_ops.TensorArray(dtypes.int32, size=4))) + return ta.stack() + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32]) + + def test_read_tensor_array_partitioned_indices(self): + # Note that tensor array values are pfor loop dependent, and the while loop + # termination condition is also dependent on pfor iteration. + def loop_fn(i): + ta = tensor_array_ops.TensorArray(dtypes.int32, size=6) + ta = ta.unstack(i + list(range(5))) + + def body(j, s): + return j + 1, s + ta.read(j) + + _, s = control_flow_ops.while_loop(lambda j, _: j < i, + body, + (0, 0)) + return s + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32]) + + def test_external_while_loop_grad(self): + # Here we test that external while_loops that are extended from inside pfor + # (due to gradient calls) are not actually converted. If the below was + # converted all pfor iterations would write to the same tensor array + # indices. + x = constant_op.constant(1.) + + def body(j, ta): + ta = ta.write(j, x) + return j + 1, ta + + _, ta = control_flow_ops.while_loop( + lambda j, _: j < 4, body, + (0, tensor_array_ops.TensorArray(dtypes.float32, size=4))) + out = ta.stack() + + def loop_fn(i): + out_i = array_ops.gather(out, i) + return gradient_ops.gradients(out_i, x)[0] + + with session.Session() as sess: + # out is [x, x, x]. Hence the gradients should be [1, 1, 1]. + self.assertAllEqual([1, 1, 1], + sess.run(pfor_control_flow_ops.pfor(loop_fn, 3))) + + def test_tensor_array_grad(self): + inp = constant_op.constant(np.random.rand(3, 4, 2), dtype=dtypes.float32) + ta = tensor_array_ops.TensorArray(dtypes.float32, size=3) + ta = ta.unstack(inp) + + def loop_fn(i): + + def body(j, x): + value = ta.gather([j]) + value = array_ops.gather(array_ops.reshape(value, [4, 2]), i) + return j + 1, x + value + + _, out = control_flow_ops.while_loop(lambda j, _: j < 3, body, + (0, array_ops.zeros([2]))) + out = math_ops.reduce_prod(out) + return out, gradient_ops.gradients(out, inp)[0] + + pfor_out, pfor_out_grad = pfor_control_flow_ops.pfor(loop_fn, 4) + # Note that tf.while_loop does not work in the setup above. So we manually + # construct the equivalent computation of the above loops here. + real_out = math_ops.reduce_sum(inp, reduction_indices=[0]) + real_out = math_ops.reduce_prod(real_out, reduction_indices=[1]) + # Note that gradients of real_out will accumulate the gradients across the + # output value. Hence we do the same aggregation on pfor_out_grad. + real_out_grad = gradient_ops.gradients(real_out, inp)[0] + sum_pfor_out_grad = math_ops.reduce_sum( + pfor_out_grad, reduction_indices=[0]) + + with session.Session() as sess: + v1, v2, v1_grad, v2_grad = sess.run( + [pfor_out, real_out, sum_pfor_out_grad, real_out_grad]) + self.assertAllClose(v1, v2) + self.assertAllClose(v1_grad, v2_grad) + + +def dynamic_lstm_input_fn(batch_size, state_size, max_steps): + # We make inputs and sequence_length constant so that multiple session.run + # calls produce the same result. + inputs = constant_op.constant( + np.random.rand(batch_size, max_steps, state_size), dtype=dtypes.float32) + sequence_length = np.random.randint(0, size=[batch_size], high=max_steps + 1) + sequence_length = constant_op.constant(sequence_length, dtype=dtypes.int32) + return inputs, sequence_length + + +def create_dynamic_lstm(cell_fn, batch_size, state_size, max_steps): + cell = cell_fn(state_size) + inputs, sequence_length = dynamic_lstm_input_fn(batch_size, + state_size, + max_steps) + inputs_ta = tensor_array_ops.TensorArray( + dtypes.float32, size=max_steps, element_shape=[batch_size, state_size]) + inputs_time_major = array_ops.transpose(inputs, [1, 0, 2]) + inputs_ta = inputs_ta.unstack(inputs_time_major) + zeros = array_ops.zeros([state_size]) + + def loop_fn(i): + sequence_length_i = array_ops.gather(sequence_length, i) + + def body_fn(t, state, ta): + inputs_t = array_ops.expand_dims( + array_ops.gather(inputs_ta.read(t), i), 0) + output, new_state = cell(inputs_t, state) + output = array_ops.reshape(output, [-1]) + # TODO(agarwal): one optimization that dynamic_rnn uses is to avoid the + # array_ops.where when t < min(sequence_length). Doing that requires + # supporting tf.cond pfor conversion. + done = t >= sequence_length_i + output = array_ops.where(done, zeros, output) + ta = ta.write(t, output) + new_state = [array_ops.where(done, s, ns) for s, ns in + zip(nest.flatten(state), nest.flatten(new_state))] + new_state = nest.pack_sequence_as(state, new_state) + return t + 1, new_state, ta + + def condition_fn(t, _, unused): + del unused + return t < max_steps + + initial_state = cell.zero_state(1, dtypes.float32) + _, state, ta = control_flow_ops.while_loop(condition_fn, body_fn, [ + 0, initial_state, + tensor_array_ops.TensorArray(dtypes.float32, max_steps) + ]) + + new_state = [array_ops.reshape(x, [-1]) for x in nest.flatten(state)] + new_state = nest.pack_sequence_as(initial_state, new_state) + return ta.stack(), new_state + + pfor_output = pfor_control_flow_ops.pfor(loop_fn, batch_size) + tf_output = rnn.dynamic_rnn( + cell, + inputs, + sequence_length=sequence_length, + initial_state=cell.zero_state(batch_size, dtypes.float32)) + return pfor_output, tf_output + + +class RNNTest(PForTest): + + def test_dynamic_rnn(self): + pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicRNNCell, + 3, 5, 7) + self.run_and_assert_equal(pfor_outputs, tf_outputs) + + def test_dynamic_lstm(self): + pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicLSTMCell, + 3, 5, 7) + self.run_and_assert_equal(pfor_outputs, tf_outputs) + + +# TODO(agarwal): benchmark numbers on GPU for graphs based on while_loop +# conversion don't look good. Some of it seems like lot of copies between host +# and device. Optimize that. +class Benchmarks(test.Benchmark): + + def _run(self, targets, iters, name=None): + + def _done(t): + # Note that we don't use tf.control_dependencies since that will not make + # sure that the computation on GPU has actually finished. So we fetch the + # first element of the output, and assume that this will not be called on + # empty tensors. + return array_ops.gather(array_ops.reshape(t, [-1]), 0) + + targets = [_done(x) for x in nest.flatten(targets)] + sess = session.Session() + with sess: + init = variables.global_variables_initializer() + sess.run(init) + sess.run(targets) + begin = time.time() + for _ in range(iters): + sess.run(targets) + end = time.time() + avg_time_ms = 1000 * (end - begin) / iters + self.report_benchmark(iters=iters, wall_time=avg_time_ms, name=name) + return avg_time_ms + + def benchmark_basic_while(self): + with ops.Graph().as_default(): + + def loop_fn(i): + _, s = control_flow_ops.while_loop( + lambda t, x: t < i, + lambda t, x: (t + 1, x + i), + [0, 0]) + return s + + iters = 50 + pfor_output = pfor_control_flow_ops.pfor(loop_fn, iters) + for_loop_output = pfor_control_flow_ops.for_loop(loop_fn, dtypes.int32, + iters) + self._run(pfor_output, 100, name="pfor_basic") + self._run(for_loop_output, 100, name="for_loop_basic") + + def benchmark_dynamic_rnn(self): + with ops.Graph().as_default(): + pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicRNNCell, + 128, 512, 16) + self._run(pfor_outputs, 100, name="pfor_rnn") + self._run(tf_outputs, 100, name="tf_rnn") + + def benchmark_dynamic_lstm(self): + with ops.Graph().as_default(): + pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicLSTMCell, + 128, 512, 16) + self._run(pfor_outputs, 100, name="pfor_lstm") + self._run(tf_outputs, 100, name="tf_lstm") + + +class SparseTest(PForTest): + + def test_var_loop_len(self): + num_iters = array_ops.placeholder(dtypes.int32) + + def loop_fn(_): + return sparse_tensor.SparseTensor([[0], [1], [2]], [4, 5, 6], + [3]) # [0, 2, 0] + + pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters) + with self.test_session() as sess: + sess.run(pfor, feed_dict={num_iters: 3}) + + def test_sparse_result_none_stacked(self): + num_iters = 10 + + def loop_fn(_): + return sparse_tensor.SparseTensor([[0], [1], [2]], [4, 5, 6], + [3]) # [0, 2, 0] + + pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters) + + indices = [[i, j] for i in range(num_iters) for j in range(3)] + values = [4, 5, 6] * num_iters + dense_shapes = [num_iters, 3] + # Expected result: [[4, 5, 6], [4, 5, 6], [4, 5, 6], ...] + manual = sparse_tensor.SparseTensor(indices, values, dense_shapes) + self.run_and_assert_equal(pfor, manual) + + def test_sparse_result_all_stacked(self): + num_iters = 10 + + def loop_fn(i): + i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0) + indices = array_ops.expand_dims(i, 0) + return sparse_tensor.SparseTensor(indices, i, i + 1) # [0, ..., 0, i] + + # Expected result: [[0], [0, 1], [0, 0, 2], [0, 0, 0, 3], ...] + pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters) + manual = sparse_tensor.SparseTensor([[i, i] for i in range(num_iters)], + list(range(num_iters)), + (num_iters, num_iters)) + self.run_and_assert_equal(pfor, manual) + + def test_sparse_result_indices_stacked(self): + num_iters = 10 + + def loop_fn(i): + i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0) + indices = array_ops.expand_dims(i, 0) + return sparse_tensor.SparseTensor(indices, [1], [num_iters]) + + # Expected result: identity matrix size num_iters * num_iters + pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters) + manual = sparse_tensor.SparseTensor([[i, i] for i in range(num_iters)], + [1] * num_iters, (num_iters, num_iters)) + self.run_and_assert_equal(pfor, manual) + + def test_sparse_result_values_stacked(self): + num_iters = 10 + + def loop_fn(i): + i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0) + return sparse_tensor.SparseTensor([[0]], i, [num_iters]) # [i, 0, ..., 0] + + # Expected result: [[1, 0, ...], [2, 0, ...], [3, 0, ...], ...] + pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters) + manual = sparse_tensor.SparseTensor([[i, 0] for i in range(num_iters)], + list(range(num_iters)), + (num_iters, num_iters)) + self.run_and_assert_equal(pfor, manual) + + def test_sparse_result_shapes_stacked(self): + num_iters = 10 + + def loop_fn(i): + i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0) + return sparse_tensor.SparseTensor([[0]], [1], i + 1) # [1, 0, ..., 0] + + # Expected result: [[1, 0, 0, ...], [1, 0, 0, ...], ...] + pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters) + manual = sparse_tensor.SparseTensor([[i, 0] for i in range(num_iters)], + [1] * num_iters, (num_iters, num_iters)) + self.run_and_assert_equal(pfor, manual) + + def test_sparse_result_shapes_stacked_2D(self): + num_iters = 10 + + def loop_fn(i): + i = array_ops.expand_dims(math_ops.cast(i + 1, dtypes.int64), 0) + shape = array_ops.concat([i, i], 0) + return sparse_tensor.SparseTensor([[0, 0]], [1], shape) # [1, 0, ..., 0] + + # Expected result: [[[1, 0, ...], [0, ..., 0], [0, ..., 0], ...], ...] + pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters) + manual = sparse_tensor.SparseTensor([[i, 0, 0] for i in range(num_iters)], + [1] * num_iters, + (num_iters, num_iters, num_iters)) + self.run_and_assert_equal(pfor, manual) + + +class ParsingTest(PForTest): + + def test_decode_csv(self): + csv_tensor = constant_op.constant([["1:2:3"], ["::"], ["7:8:9"]]) + kwargs = {"record_defaults": [[10], [20], [30]], "field_delim": ":"} + + def loop_fn(i): + line = array_ops.gather(csv_tensor, i) + return parsing_ops.decode_csv(line, **kwargs) + + self._test_loop_fn(loop_fn, iters=3, loop_fn_dtypes=[dtypes.int32] * 3) + + def test_parse_single_example(self): + + def _int64_feature(*values): + return feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=values)) + + def _bytes_feature(*values): + return feature_pb2.Feature( + bytes_list=feature_pb2.BytesList( + value=[v.encode("utf-8") for v in values])) + + examples = constant_op.constant([ + example_pb2.Example( + features=feature_pb2.Features( + feature={ + "dense_int": _int64_feature(i), + "dense_str": _bytes_feature(str(i)), + "sparse_int": _int64_feature(i, i * 2, i * 4, i * 8), + "sparse_str": _bytes_feature(*["abc"] * i) + })).SerializeToString() for i in range(10) + ]) + + features = { + "dense_int": parsing_ops.FixedLenFeature((), dtypes.int64, 0), + "dense_str": parsing_ops.FixedLenFeature((), dtypes.string, ""), + "sparse_int": parsing_ops.VarLenFeature(dtypes.int64), + "sparse_str": parsing_ops.VarLenFeature(dtypes.string), + } + + def loop_fn(i): + example_proto = array_ops.gather(examples, i) + f = parsing_ops.parse_single_example(example_proto, features) + return f + + pfor = pfor_control_flow_ops.pfor(loop_fn, iters=10) + manual = parsing_ops.parse_example(examples, features) + self.run_and_assert_equal(pfor, manual) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/ops/parallel_for/gradients.py b/tensorflow/python/ops/parallel_for/gradients.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3d5c9b86ed186f76e113351646b3dda153e72b --- /dev/null +++ b/tensorflow/python/ops/parallel_for/gradients.py @@ -0,0 +1,126 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Jacobian ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import gradients as gradient_ops +from tensorflow.python.ops.parallel_for import control_flow_ops +from tensorflow.python.util import nest + + +def jacobian(output, inputs, use_pfor=True): + """Computes jacobian of `output` w.r.t. `inputs`. + + Args: + output: A tensor. + inputs: A tensor or a nested structure of tensor objects. + use_pfor: If true, uses pfor for computing the jacobian. Else uses + tf.while_loop. + + Returns: + A tensor or a nested strucutre of tensors with the same structure as + `inputs`. Each entry is the jacobian of `output` w.rt. to the corresponding + value in `inputs`. If output has shape [y_1, ..., y_n] and inputs_i has + shape [x_1, ..., x_m], the corresponding jacobian has shape + [y_1, ..., y_n, x_1, ..., x_m]. + """ + flat_inputs = nest.flatten(inputs) + output_shape = array_ops.shape(output) + output = array_ops.reshape(output, [-1]) + + def loop_fn(i): + y = array_ops.gather(output, i) + return gradient_ops.gradients(y, flat_inputs) + + try: + output_size = int(output.shape[0]) + except TypeError: + output_size = array_ops.shape(output)[0] + + if use_pfor: + pfor_outputs = control_flow_ops.pfor(loop_fn, output_size) + else: + pfor_outputs = control_flow_ops.for_loop( + loop_fn, [output.dtype] * len(flat_inputs), output_size) + + for i, out in enumerate(pfor_outputs): + new_shape = array_ops.concat( + [output_shape, array_ops.shape(out)[1:]], axis=0) + out = array_ops.reshape(out, new_shape) + pfor_outputs[i] = out + + return nest.pack_sequence_as(inputs, pfor_outputs) + + +def batch_jacobian(output, inp, use_pfor=True): + """Computes and stacks jacobians of `output[i,...]` w.r.t. `input[i,...]`. + + e.g. + x = tf.constant([[1, 2], [3, 4]], dtype=tf.float32) + y = x * x + jacobian = batch_jacobian(y, x) + # => [[[2, 0], [0, 4]], [[6, 0], [0, 8]]] + + Args: + output: A tensor with shape [b, y1, ..., y_n]. `output[i,...]` should + only depend on `inp[i,...]`. + inp: A tensor with shape [b, x1, ..., x_m] + use_pfor: If true, uses pfor for computing the Jacobian. Else uses a + tf.while_loop. + + Returns: + A tensor `t` with shape [b, y_1, ..., y_n, x1, ..., x_m] where `t[i, ...]` + is the jacobian of `output[i, ...]` w.r.t. `inp[i, ...]`, i.e. stacked + per-example jacobians. + + Raises: + ValueError: if first dimension of `output` and `inp` do not match. + """ + output_shape = output.shape + if not output_shape[0].is_compatible_with(inp.shape[0]): + raise ValueError("Need first dimension of output shape (%s) and inp shape " + "(%s) to match." % (output.shape, inp.shape)) + if output_shape.is_fully_defined(): + batch_size = int(output_shape[0]) + output_row_size = output_shape.num_elements() // batch_size + else: + output_shape = array_ops.shape(output) + batch_size = output_shape[0] + output_row_size = array_ops.size(output) // batch_size + inp_shape = array_ops.shape(inp) + # Flatten output to 2-D. + with ops.control_dependencies( + [check_ops.assert_equal(batch_size, inp_shape[0])]): + output = array_ops.reshape(output, [batch_size, output_row_size]) + + def loop_fn(i): + y = array_ops.gather(output, i, axis=1) + return gradient_ops.gradients(y, inp)[0] + + if use_pfor: + pfor_output = control_flow_ops.pfor(loop_fn, output_row_size) + else: + pfor_output = control_flow_ops.for_loop(loop_fn, output.dtype, + output_row_size) + pfor_output = array_ops.reshape(pfor_output, + [output_row_size, batch_size, -1]) + output = array_ops.transpose(pfor_output, [1, 0, 2]) + new_shape = array_ops.concat([output_shape, inp_shape[1:]], axis=0) + return array_ops.reshape(output, new_shape) diff --git a/tensorflow/python/ops/parallel_for/gradients_test.py b/tensorflow/python/ops/parallel_for/gradients_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3a6d9149ad80e5087c8ecc755e6b81b67d4a5ed2 --- /dev/null +++ b/tensorflow/python/ops/parallel_for/gradients_test.py @@ -0,0 +1,579 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for jacobian and batch_jacobian ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os +import time + +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.keras.engine import training as keras_training +from tensorflow.python.layers import layers as tf_layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients as gradient_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses +from tensorflow.python.ops.parallel_for import control_flow_ops +from tensorflow.python.ops.parallel_for import gradients +from tensorflow.python.platform import test +from tensorflow.python.util import nest + + +class FullyConnectedModel(object): + + def __init__(self, activation_size, num_layers): + self._layers = [ + tf_layers.Dense(activation_size, activation=nn.relu) + for _ in range(num_layers) + ] + + def __call__(self, inp): + activation = inp + for layer in self._layers: + activation = layer(activation) + return activation + + +def fully_connected_model_fn(batch_size, activation_size, num_layers): + model = FullyConnectedModel(activation_size, num_layers) + inp = random_ops.random_normal([batch_size, activation_size]) + return inp, model(inp) + + +def lstm_model_fn(batch_size, state_size, steps): + inputs = [ + random_ops.random_normal([batch_size, state_size]) for _ in range(steps) + ] + cell = rnn_cell.BasicLSTMCell(state_size) + init_state = cell.zero_state(batch_size, dtypes.float32) + state = init_state + for inp in inputs: + _, state = cell(inp, state) + return init_state.c, state.c + + +def dynamic_lstm_model_fn(batch_size, state_size, max_steps): + # We make inputs and sequence_length constant so that multiple session.run + # calls produce the same result. + inputs = constant_op.constant( + np.random.rand(batch_size, max_steps, state_size), dtype=dtypes.float32) + sequence_length = constant_op.constant( + np.random.randint(0, size=[batch_size], high=max_steps + 1), + dtype=dtypes.int32) + + cell = rnn_cell.BasicLSTMCell(state_size) + initial_state = cell.zero_state(batch_size, dtypes.float32) + return inputs, rnn.dynamic_rnn( + cell, + inputs, + sequence_length=sequence_length, + initial_state=initial_state) + + +def create_fc_batch_jacobian(batch_size, activation_size, num_layers): + inp, output = fully_connected_model_fn(batch_size, activation_size, + num_layers) + pfor_jacobian = gradients.batch_jacobian(output, inp, use_pfor=True) + while_jacobian = gradients.batch_jacobian(output, inp, use_pfor=False) + return pfor_jacobian, while_jacobian + + +def create_lstm_batch_jacobian(batch_size, state_size, steps): + inp, output = lstm_model_fn(batch_size, state_size, steps) + pfor_jacobian = gradients.batch_jacobian(output, inp, use_pfor=True) + while_jacobian = gradients.batch_jacobian(output, inp, use_pfor=False) + return pfor_jacobian, while_jacobian + + +def create_dynamic_lstm_batch_jacobian(batch_size, state_size, max_steps): + inp, (_, final_state) = dynamic_lstm_model_fn(batch_size, state_size, + max_steps) + pfor_jacobian = gradients.batch_jacobian(final_state.c, inp, use_pfor=True) + # Note that use_pfor=False does not work above given the current limitations + # on implementation of while_loop. So we statically unroll the looping in the + # jacobian computation. + while_gradients = [ + gradient_ops.gradients(array_ops.gather(final_state.c, i, axis=1), inp)[0] + for i in range(state_size) + ] + return pfor_jacobian, while_gradients + + +def create_lstm_batch_hessian(batch_size, state_size, steps): + inp, output = lstm_model_fn(batch_size, state_size, steps) + pfor_jacobian = gradients.batch_jacobian(output, inp, use_pfor=True) + pfor_jacobian = array_ops.reshape(pfor_jacobian, [batch_size, -1]) + pfor_hessian = gradients.batch_jacobian(pfor_jacobian, inp, use_pfor=True) + # TODO(agarwal): using two nested while_loop doesn't seem to work here. + # Hence we use pfor_jacobian for computing while_hessian. + while_jacobian = pfor_jacobian + while_hessian = gradients.batch_jacobian(while_jacobian, inp, use_pfor=False) + return pfor_hessian, while_hessian + + +def create_lstm_hessian(batch_size, state_size, steps): + _, output = lstm_model_fn(batch_size, state_size, steps) + weights = variables.trainable_variables() + pfor_jacobians = gradients.jacobian(output, weights, use_pfor=True) + pfor_hessians = [ + gradients.jacobian(x, weights, use_pfor=True) for x in pfor_jacobians + ] + # TODO(agarwal): using two nested while_loop doesn't seem to work here. + # Hence we use pfor_jacobians for computing while_hessians. + while_jacobians = pfor_jacobians + while_hessians = [ + gradients.jacobian(x, weights, use_pfor=False) for x in while_jacobians + ] + return pfor_hessians, while_hessians + + +def create_fc_per_eg_grad(batch_size, activation_size, num_layers): + inp = random_ops.random_normal([batch_size, activation_size]) + layers = [ + tf_layers.Dense(activation_size, activation=nn.relu) + for _ in range(num_layers) + ] + projection = tf_layers.Dense(1) + + def model_fn(activation): + for layer in layers: + activation = layer(activation) + activation = projection(activation) + activation = nn.l2_loss(activation) + return gradient_ops.gradients(activation, variables.trainable_variables()) + + def loop_fn(i): + return model_fn(array_ops.expand_dims(array_ops.gather(inp, i), 0)) + + pfor_outputs = control_flow_ops.pfor(loop_fn, batch_size) + loop_fn_dtypes = [x.dtype for x in variables.trainable_variables()] + while_outputs = control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, batch_size) + return pfor_outputs, while_outputs + + +def create_lstm_per_eg_grad(batch_size, state_size, steps): + inputs = [ + random_ops.random_normal([batch_size, state_size]) for _ in range(steps) + ] + cell = rnn_cell.BasicLSTMCell(state_size) + init_state = cell.zero_state(batch_size, dtypes.float32) + + def model_fn(inps, init_state): + state = init_state + for inp in inps: + _, state = cell(inp, state) + output = nn.l2_loss(state.c) + return gradient_ops.gradients(output, variables.trainable_variables()) + + def loop_fn(i): + loop_inputs = [ + array_ops.expand_dims(array_ops.gather(x, i), 0) for x in inputs + ] + loop_init_state = rnn_cell.LSTMStateTuple( + *[array_ops.expand_dims(array_ops.gather(x, i), 0) for x in init_state]) + return model_fn(loop_inputs, loop_init_state) + + pfor_outputs = control_flow_ops.pfor(loop_fn, batch_size) + loop_fn_dtypes = [x.dtype for x in variables.trainable_variables()] + while_outputs = control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, batch_size) + return pfor_outputs, while_outputs + + +# Importing the code from tensorflow_models seems to cause errors. Hence we +# duplicate the model definition here. +# TODO(agarwal): Use the version in tensorflow_models/official instead. +class Mnist(keras_training.Model): + + def __init__(self, data_format): + """Creates a model for classifying a hand-written digit. + + Args: + data_format: Either 'channels_first' or 'channels_last'. + """ + super(Mnist, self).__init__() + if data_format == "channels_first": + self._input_shape = [-1, 1, 28, 28] + else: + assert data_format == "channels_last" + self._input_shape = [-1, 28, 28, 1] + + self.conv1 = tf_layers.Conv2D( + 32, 5, padding="same", data_format=data_format, activation=nn.relu) + self.conv2 = tf_layers.Conv2D( + 64, 5, padding="same", data_format=data_format, activation=nn.relu) + self.fc1 = tf_layers.Dense(1024, activation=nn.relu) + self.fc2 = tf_layers.Dense(10) + self.dropout = tf_layers.Dropout(0.4) + self.max_pool2d = tf_layers.MaxPooling2D( + (2, 2), (2, 2), padding="same", data_format=data_format) + + def __call__(self, inputs, training): + """Add operations to classify a batch of input images. + + Args: + inputs: A Tensor representing a batch of input images. + training: A boolean. Set to True to add operations required only when + training the classifier. + + Returns: + A logits Tensor with shape [, 10]. + """ + y = array_ops.reshape(inputs, self._input_shape) + y = self.conv1(y) + y = self.max_pool2d(y) + y = self.conv2(y) + y = self.max_pool2d(y) + y = tf_layers.flatten(y) + y = self.fc1(y) + y = self.dropout(y, training=training) + return self.fc2(y) + + +def create_mnist_per_eg_grad(batch_size, data_format, training): + images = random_ops.random_uniform([batch_size, 28, 28]) + sparse_labels = np.random.randint( + low=0, high=10, size=[batch_size]).astype(np.int32) + labels = np.zeros((batch_size, 10)).astype(np.float32) + labels[np.arange(batch_size), sparse_labels] = 1. + model = Mnist(data_format) + + def loop_fn(i): + image = array_ops.gather(images, i) + label = array_ops.gather(labels, i) + logits = array_ops.reshape(model(image, training=training), [-1]) + loss = losses.softmax_cross_entropy( + logits=logits, onehot_labels=label, reduction=losses.Reduction.NONE) + return gradient_ops.gradients(loss, variables.trainable_variables()) + + pfor_outputs = control_flow_ops.pfor(loop_fn, batch_size) + while_outputs = control_flow_ops.for_loop( + loop_fn, [dtypes.float32] * len(variables.trainable_variables()), + batch_size) + return pfor_outputs, while_outputs + + +def create_mnist_per_eg_jacobian(batch_size, data_format, training): + images = random_ops.random_uniform([batch_size, 28, 28]) + model = Mnist(data_format) + + def loop_fn(i, use_pfor): + image = array_ops.gather(images, i) + logits = array_ops.reshape(model(image, training=training), [-1]) + return gradients.jacobian( + logits, variables.trainable_variables(), use_pfor=use_pfor) + + pfor_outputs = control_flow_ops.pfor( + functools.partial(loop_fn, use_pfor=True), + batch_size) + while_outputs = control_flow_ops.for_loop( + functools.partial(loop_fn, use_pfor=False), + [dtypes.float32] * len(variables.trainable_variables()), batch_size) + return pfor_outputs, while_outputs + + +def create_fc_per_eg_jacobians(batch_size, activation_size, num_layers): + model = FullyConnectedModel(activation_size=activation_size, + num_layers=num_layers) + inp = random_ops.random_normal([batch_size, activation_size]) + output = model(inp) + jacobians = gradients.jacobian(output, variables.trainable_variables()) + + def loop_fn(i, use_pfor): + inp_i = array_ops.expand_dims(array_ops.gather(inp, i), 0) + output = array_ops.reshape(model(inp_i), [-1]) + return gradients.jacobian( + output, variables.trainable_variables(), use_pfor=use_pfor) + + per_eg_jacobians_pfor = control_flow_ops.pfor( + functools.partial(loop_fn, use_pfor=True), + batch_size) + per_eg_jacobians_while = control_flow_ops.for_loop( + functools.partial(loop_fn, use_pfor=False), + [dtypes.float32] * len(variables.trainable_variables()), batch_size) + return jacobians, per_eg_jacobians_pfor, per_eg_jacobians_while + + +class GradientsTest(test.TestCase): + + def run_and_assert_equal(self, targets1, targets2, atol=1e-4, rtol=1e-4): + targets1 = nest.flatten(targets1) + targets2 = nest.flatten(targets2) + assert len(targets1) == len(targets2) + init = variables.global_variables_initializer() + self.evaluate(init) + outputs = self.evaluate(targets1 + targets2) + n = len(outputs) // 2 + for i in range(n): + self.assertAllClose(outputs[i], outputs[i + n], rtol=rtol, atol=atol) + + def test_jacobian_fixed_shape(self): + x = random_ops.random_uniform([2, 2]) + y = math_ops.matmul(x, x, transpose_a=True) + jacobian_pfor = gradients.jacobian(y, x, use_pfor=True) + jacobian_while = gradients.jacobian(y, x, use_pfor=False) + answer = ops.convert_to_tensor([[ + gradient_ops.gradients(y[0][0], x)[0], + gradient_ops.gradients(y[0][1], x)[0] + ], [ + gradient_ops.gradients(y[1][0], x)[0], + gradient_ops.gradients(y[1][1], x)[0] + ]]) + self.run_and_assert_equal(answer, jacobian_pfor) + self.run_and_assert_equal(answer, jacobian_while) + + def test_jacobian_unknown_shape(self): + with self.test_session() as sess: + x = array_ops.placeholder(dtypes.float32, shape=[None, None]) + y = math_ops.matmul(x, x, transpose_a=True) + jacobian_pfor = gradients.jacobian(y, x, use_pfor=True) + jacobian_while = gradients.jacobian(y, x, use_pfor=False) + answer = ops.convert_to_tensor([[ + gradient_ops.gradients(y[0][0], x)[0], + gradient_ops.gradients(y[0][1], x)[0] + ], [ + gradient_ops.gradients(y[1][0], x)[0], + gradient_ops.gradients(y[1][1], x)[0] + ]]) + ans, pfor_value, while_value = sess.run( + [answer, jacobian_pfor, jacobian_while], + feed_dict={x: [[1, 2], [3, 4]]}) + self.assertAllClose(ans, pfor_value) + self.assertAllClose(ans, while_value) + + def test_batch_jacobian_bad_shapes(self): + x = random_ops.random_uniform([2, 2]) + y = random_ops.random_uniform([3, 2]) + with self.assertRaisesRegexp(ValueError, "Need first dimension of output"): + gradients.batch_jacobian(y, x, use_pfor=True) + + def test_batch_jacobian_bad_unknown_shapes(self): + with self.test_session() as sess: + x = array_ops.placeholder(dtypes.float32) + y = array_ops.concat([x, x], axis=0) + jacobian = gradients.batch_jacobian(y, x) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "assertion failed"): + sess.run(jacobian, feed_dict={x: [[1, 2], [3, 4]]}) + + def test_batch_jacobian_fixed_shape(self): + x = random_ops.random_uniform([2, 3, 5]) + y = x * x + batch_jacobian_pfor = gradients.batch_jacobian(y, x, use_pfor=True) + batch_jacobian_while = gradients.batch_jacobian(y, x, use_pfor=False) + two_x = 2 * x + answer = array_ops.stack( + [array_ops.diag(two_x[0]), + array_ops.diag(two_x[1])]) + self.run_and_assert_equal(answer, batch_jacobian_pfor) + self.run_and_assert_equal(answer, batch_jacobian_while) + + def test_batch_jacobian_unknown_shape(self): + with self.test_session() as sess: + x = array_ops.placeholder(dtypes.float32) + y = x * x + batch_jacobian_pfor = gradients.batch_jacobian(y, x, use_pfor=True) + batch_jacobian_while = gradients.batch_jacobian(y, x, use_pfor=False) + two_x = 2 * x + answer = array_ops.stack( + [array_ops.diag(two_x[0]), + array_ops.diag(two_x[1])]) + ans, pfor_value, while_value = sess.run( + [answer, batch_jacobian_pfor, batch_jacobian_while], + feed_dict={x: [[1, 2], [3, 4]]}) + self.assertAllClose(ans, pfor_value) + self.assertAllClose(ans, while_value) + + def test_fc_batch_jacobian(self): + pfor_jacobian, while_jacobian = create_fc_batch_jacobian(8, 4, 2) + self.run_and_assert_equal(pfor_jacobian, while_jacobian) + + def test_lstm_batch_jacobian(self): + pfor_jacobian, while_jacobian = create_lstm_batch_jacobian(8, 4, 2) + self.run_and_assert_equal(pfor_jacobian, while_jacobian) + + def test_dynamic_lstm_batch_jacobian(self): + pfor_jacobian, while_gradients = create_dynamic_lstm_batch_jacobian(8, 4, 3) + with session.Session() as sess: + init = variables.global_variables_initializer() + sess.run(init) + pfor = sess.run(pfor_jacobian) + for i in range(4): + while_i = sess.run(while_gradients[i]) + self.assertAllClose(while_i, pfor[:, i, ...]) + + def test_lstm_hessian(self): + pfor_hessian, while_hessian = create_lstm_hessian(2, 2, 2) + self.run_and_assert_equal(pfor_hessian, while_hessian) + + def test_lstm_batch_hessian(self): + pfor_hessian, while_hessian = create_lstm_batch_hessian(2, 2, 2) + self.run_and_assert_equal(pfor_hessian, while_hessian) + + def test_fc_per_eg_grad(self): + pfor_outputs, while_outputs = create_fc_per_eg_grad(8, 4, 2) + self.run_and_assert_equal(pfor_outputs, while_outputs) + + def test_lstm_per_eg_grad(self): + pfor_outputs, while_outputs = create_lstm_per_eg_grad(8, 4, 2) + self.run_and_assert_equal(pfor_outputs, while_outputs) + + def test_mnist_per_eg_grad(self): + # It looks like CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED + # configuration of Winograd can cause low precision output resulting in + # tests failing. So we disable that here. + os.environ["TF_ENABLE_WINOGRAD_NONFUSED"] = "0" + data_format = ("channels_first" + if test.is_gpu_available() else "channels_last") + # Note that we we are setting training=False here so that dropout produces + # the same result with pfor and with while_loop. + pfor_outputs, while_outputs = create_mnist_per_eg_grad( + 4, data_format, training=False) + self.run_and_assert_equal(pfor_outputs, while_outputs, rtol=1e-3) + os.environ.pop("TF_ENABLE_WINOGRAD_NONFUSED", None) + + def test_mnist_per_eg_jacobian(self): + # It looks like CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED + # configuration of Winograd can cause low precision output resulting in + # tests failing. So we disable that here. + os.environ["TF_ENABLE_WINOGRAD_NONFUSED"] = "0" + data_format = ("channels_first" + if test.is_gpu_available() else "channels_last") + # Note that we we are setting training=False here so that dropout produces + # the same result with pfor and with while_loop. + pfor_outputs, while_outputs = create_mnist_per_eg_jacobian( + 2, data_format, training=False) + self.run_and_assert_equal(pfor_outputs, while_outputs, rtol=1e-3) + os.environ.pop("TF_ENABLE_WINOGRAD_NONFUSED", None) + + def test_fc_jacobian(self): + jacobians, per_eg_jacobians_pfor, per_eg_jacobians_while = ( + create_fc_per_eg_jacobians(batch_size=8, + activation_size=4, + num_layers=2)) + self.run_and_assert_equal(jacobians, per_eg_jacobians_pfor, + rtol=2e-3, atol=1e-3) + self.run_and_assert_equal(jacobians, per_eg_jacobians_while, + rtol=2e-3, atol=1e-3) + + +class GradientsBenchmarks(test.Benchmark): + + def _run(self, targets, iters, name=None): + + def _done(t): + # Note that we don't use tf.control_dependencies since that will not make + # sure that the computation on GPU has actually finished. So we fetch the + # first element of the output, and assume that this will not be called on + # empty tensors. + return array_ops.gather(array_ops.reshape(t, [-1]), 0) + + targets = [_done(x) for x in nest.flatten(targets)] + sess = session.Session() + with sess: + init = variables.global_variables_initializer() + sess.run(init) + sess.run(targets) + begin = time.time() + for _ in range(iters): + sess.run(targets) + end = time.time() + avg_time_ms = 1000 * (end - begin) / iters + self.report_benchmark(iters=iters, wall_time=avg_time_ms, name=name) + return avg_time_ms + + def benchmark_fc_batch_jacobian(self): + with ops.Graph().as_default(): + pfor_jacobian, while_jacobian = create_fc_batch_jacobian(100, 32, 20) + self._run(pfor_jacobian, 100, name="fc_batch_jacobian_pfor") + self._run(while_jacobian, 20, name="fc_batch_jacobian_while") + + def benchmark_lstm_batch_jacobian(self): + with ops.Graph().as_default(): + pfor_jacobian, while_jacobian = create_lstm_batch_jacobian(100, 32, 8) + self._run(pfor_jacobian, 100, name="lstm_batch_jacobian_pfor") + self._run(while_jacobian, 20, name="lstm_batch_jacobian_while") + + def benchmark_lstm_hessian(self): + with ops.Graph().as_default(): + pfor_hessian, while_hessian = create_lstm_hessian(2, 2, 10) + self._run(pfor_hessian, 20, name="lstm_hessian_pfor") + self._run(while_hessian, 3, name="lstm_hessian_while_pfor") + + def benchmark_lstm_batch_hessian(self): + with ops.Graph().as_default(): + pfor_hessian, while_hessian = create_lstm_batch_hessian(4, 4, 10) + self._run(pfor_hessian, 100, name="lstm_batch_hessian_pfor") + self._run(while_hessian, 20, name="lstm_batch_hessian_while_pfor") + + def benchmark_fc_per_eg_grad(self): + with ops.Graph().as_default(): + pfor_outputs, while_outputs = create_fc_per_eg_grad(100, 32, 3) + self._run(pfor_outputs, 100, name="fc_per_eg_grad_pfor") + self._run(while_outputs, 20, name="fc_per_eg_grad_while") + + def benchmark_lstm_per_eg_grad(self): + with ops.Graph().as_default(): + pfor_outputs, while_outputs = create_lstm_per_eg_grad(100, 32, 8) + self._run(pfor_outputs, 100, name="lstm_per_eg_grad_pfor") + self._run(while_outputs, 20, name="lstm_per_eg_grad_while") + + def benchmark_mnist_per_eg_grad(self): + with ops.Graph().as_default(): + data_format = ("channels_first" + if test.is_gpu_available() else "channels_last") + pfor_outputs, while_outputs = create_mnist_per_eg_grad( + 128, data_format, training=True) + self._run(pfor_outputs, 20, name="mnist_per_eg_grad_pfor") + self._run(while_outputs, 20, name="mnist_per_eg_grad_while") + + def benchmark_mnist_per_eg_jacobian(self): + with ops.Graph().as_default(): + data_format = ("channels_first" + if test.is_gpu_available() else "channels_last") + pfor_outputs, while_outputs = create_mnist_per_eg_jacobian( + 16, data_format, training=True) + self._run(pfor_outputs, 20, name="mnist_per_eg_jacobian_pfor") + self._run(while_outputs, 20, name="mnist_per_eg_jacobian_while") + + def benchmark_fc_per_eg_jacobian(self): + with ops.Graph().as_default(): + jacobians, per_eg_jacobians_pfor, per_eg_jacobians_while = ( + create_fc_per_eg_jacobians(batch_size=128, + activation_size=32, + num_layers=3)) + self._run(jacobians, 30, name="fc_jacobians_pfor") + self._run(per_eg_jacobians_pfor, 100, + name="fc_per_eg_jacobians_pfor") + self._run(per_eg_jacobians_while, 10, + name="fc_per_eg_jacobians_while") + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py new file mode 100644 index 0000000000000000000000000000000000000000..2e4b2fd64eb9db5df283ca5918675b22b6835e83 --- /dev/null +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -0,0 +1,2552 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Compiled parallel-for loop.""" +# pylint: disable=missing-docstring + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from absl import flags + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import gen_parsing_ops +from tensorflow.python.ops import gen_sparse_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import nest + +flags.DEFINE_bool( + "op_conversion_fallback_to_while_loop", False, + "If true, falls back to using a while loop for ops for " + "which a converter is not defined.") + + +def _stack(t, length): + """stacks `t` `length` times.""" + ones = array_ops.ones_like(array_ops.shape(t)) + multiples = array_ops.concat([length, ones], 0) + t = array_ops.tile(array_ops.expand_dims(t, 0), multiples) + return wrap(t, True) + + +# The following stateful ops can be safely called once, and with the same +# signature as the unconverted version, if their inputs are loop invariant. +# TODO(agarwal): implement a strategy for converting Variable reads/writes. The +# plan is to map each read/write in the loop_fn to a corresponding merged +# read/write in the converted graph. Writes need to be mergeable (e.g. +# AssignAdd) to be used in `pfor`. Given a certain read/write order in the +# loop_fn, doing a one-to-one conversion will simulate executing such +# instructions in lock-step across all iterations. +passthrough_stateful_ops = set([ + "VariableV2", + "VarHandleOp", + "ReadVariableOp", + "StackV2", + "TensorArrayWriteV3", + "TensorArrayReadV3", + "TensorArraySizeV3", +]) + + +def _is_stateful_pfor_op(op): + if isinstance(op, WhileOp): + return op.is_stateful + if op.type == "Const": + # Const didn't have an op_def. + return False + if op.type in passthrough_stateful_ops: + return False + assert hasattr(op, "op_def") and op.op_def is not None, op + return op.op_def.is_stateful + + +# pylint: disable=protected-access +class WhileOp(object): + """Object for storing state for converting the outputs of a while_loop.""" + + def __init__(self, exit_node, pfor_ops): + """Initializer. + + Args: + exit_node: A tensor output from the while_loop. + pfor_ops: list of ops inside the current pfor loop. + """ + self._pfor_ops = set(pfor_ops) + self._pfor_op_ids = set([x._id for x in pfor_ops]) + assert isinstance(exit_node, ops.Tensor) + self._while_context = exit_node.op._get_control_flow_context() + assert isinstance(self._while_context, control_flow_ops.WhileContext) + self._context_name = self._while_context.name + self._condition = self._while_context.pivot.op.inputs[0] + # Parts of an external while_loop could be created inside a pfor loop. + # However for the purpose here, we declare such loops to be external. Also + # note that we check if the condition was created inside or outside to + # determine if the while_loop was first created inside or outside. + # TODO(agarwal): check that the Enter and Exit of this loop are unstacked. + self._is_inside_loop = self.op_is_inside_loop(self._condition.op) + if self._is_inside_loop: + for e in self._while_context.loop_exits: + assert self.op_is_inside_loop(e.op) + + # Note the code below tries to reverse engineer an existing while_loop graph + # by assuming the following pattern of nodes. + # + # NextIteration <---- Body <--- Enter + # | ^ + # V ___| Y + # Enter -> Merge -> Switch___ + # ^ | N + # | V + # LoopCond Exit + + # Node that elements in the list below correspond one-to-one with each + # other. i.e. these lists are the same size, and the i_th entry corresponds + # to different Operations/Tensors of a single cycle as illustrated above. + # List of Switch ops (ops.Operation) that feed into an Exit Node. + self._exit_switches = [] + # List of inputs (ops.Tensor) to NextIteration. + self._body_outputs = [] + # List of list of control inputs of the NextIteration nodes. + self._next_iter_control_inputs = [] + # List of Merge ops (ops.Operation). + self._enter_merges = [] + # List of output (ops.Tensor) of Exit nodes. + self._outputs = [] + + # List of Enter Tensors. + # There are two types of Enter nodes: + # - The Enter nodes that are used in the `loop_vars` argument to + # `while_loop` (see + # https://www.tensorflow.org/api_docs/python/tf/while_loop). We collect + # these Enter nodes immediately below by tracing backwards from the Exit + # nodes via Exit <- Switch <- Merge <- Enter. You can see this chain in the + # diagram above. This allows us to have a 1:1 correspondence between the + # self._outputs and the first elements in self._enters. + # - The Enter nodes that are used only by the body. They don't appear in the + # `loop_vars` and are not returned from the `while_loop`. In Python code, + # they are usually captured by the body lambda. We collect them below by + # iterating over all the ops in the graph. They are appended to the end of + # self._enters or self._direct_enters, and don't correspond to any outputs + # in self._outputs. Note that we keep the resource/variant Enter nodes in + # self._direct_enters and the constructed while_loop's body uses them + # directly as opposed to passing them as loop variables. This is done + # because the while_body cannot partition the resource/variant Tensors, so + # it has to leave them unchanged. + self._enters = [] + self._direct_enters = [] + + for e in self._while_context.loop_exits: + self._outputs.append(e.op.outputs[0]) + switch = e.op.inputs[0].op + assert switch.type == "Switch", switch + self._exit_switches.append(switch) + merge = switch.inputs[0].op + assert merge.type == "Merge", merge + self._enter_merges.append(merge) + enter = merge.inputs[0].op + assert enter.type == "Enter", enter + self._enters.append(enter.outputs[0]) + next_iter = merge.inputs[1].op + assert next_iter.type == "NextIteration", next_iter + self._body_outputs.append(next_iter.inputs[0]) + self._next_iter_control_inputs.append(next_iter.control_inputs) + + # Collect all the Enter nodes that are not part of `loop_vars`, the second + # category described above. + # Also track whether the loop body has any stateful ops. + self._is_stateful = False + for op in ops.get_default_graph().get_operations(): + # TODO(agarwal): make sure this works with nested case. + control_flow_context = op._get_control_flow_context() + if control_flow_context is None: + continue + if control_flow_context.name == self._context_name: + self._is_stateful |= _is_stateful_pfor_op(op) + if op.type == "Enter": + output = op.outputs[0] + if output not in self._enters: + if output.dtype in (dtypes.resource, dtypes.variant): + if output not in self._direct_enters: + self._direct_enters.append(output) + else: + self._enters.append(output) + + def __str__(self): + """String representation.""" + return "while_loop(%s)" % self.name + + @property + def inputs(self): + """Input to all the Enter nodes.""" + return [x.op.inputs[0] for x in self._enters + self._direct_enters] + + @property + def control_inputs(self): + """Control input to all the Enter nodes.""" + control_inputs = [] + for x in self._enters + self._direct_enters: + control_inputs.extend(x.op.control_inputs) + return control_inputs + + @property + def outputs(self): + """Outputs of all the Exit nodes.""" + return self._outputs + + @property + def name(self): + """Context name for the while loop.""" + return self._context_name + + @property + def is_inside_loop(self): + """Returns true if the while_loop was created inside the pfor.""" + return self._is_inside_loop + + def op_is_inside_loop(self, op): + """True if op was created inside the pfor loop body.""" + assert isinstance(op, ops.Operation) + # Note that we use self._pfor_op_ids for the check and not self._pfor_ops + # since it appears there tensorflow API could return different python + # objects representing the same Operation node. + return op._id in self._pfor_op_ids + + @property + def is_stateful(self): + return self._is_stateful + + @property + def pfor_converter(self): + """Return a converter for the while loop.""" + return self + + def _init_pfor(self, parent_pfor, indices, cond_stacked, inputs, + inputs_stacked): + """Create a PFor object for converting parts of the while_loop. + + Args: + parent_pfor: PFor object being used for converting the while_loop. + indices: int32 Tensor of ids for the iterations that are still active + (i.e. did not exit the while_loop). + cond_stacked: True if the while_loop condition is stacked. + inputs: list of input Tensors corresponding 1-to-1 with self._enters. Note + that these Tensors are a subset of the loop variables for the generated + while_loop. + inputs_stacked: List of booleans corresponding 1-to-1 with `inputs`, + indicating if the value is stacked or not. + + Returns: + A PFor instance. The instance is initialized by adding conversion mappings + of nodes that will be external to the conversion that the returned + instance will be used for. e.g. Enter nodes as well as Merge and Switch + outputs are mapped to converted values. + """ + num_outputs = len(self._outputs) + assert len(inputs) == len(self._enters) + assert len(inputs_stacked) == len(self._enters) + loop_var = parent_pfor.loop_var + loop_len = array_ops.size(indices) + pfor = PFor( + loop_var, + loop_len, + pfor_ops=self._pfor_ops, + all_indices=indices, + all_indices_partitioned=cond_stacked) + # Map all inputs of Enter nodes in self._direct_enters to their converted + # values. + for enter in self._direct_enters: + enter_input = enter.op.inputs[0] + converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper( + enter_input) + # Since these are resources / variants, they should be unstacked. + assert not stacked and not is_sparse_stacked, (enter, converted_enter) + pfor._add_conversion(enter, wrap(converted_enter, False)) + + # Map all Enter nodes to the inputs. + for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked): + pfor._add_conversion(enter, wrap(inp, stacked)) + # Map outputs of Switch and Merge. + for i in range(num_outputs): + wrapped_inp = wrap(inputs[i], inputs_stacked[i]) + merge = self._enter_merges[i] + pfor._add_conversion(merge.outputs[0], wrapped_inp) + # Note that second output of Merge is typically not used, except possibly + # as a control dependency. To avoid trying to output the correct value, we + # employ a hack here. We output a dummy invalid value with an incorrect + # dtype. This will allow control dependency to work but if using it as an + # input, it should typically lead to errors during graph construction due + # to dtype mismatch. + # TODO(agarwal): Check in the original graph to see if there are any + # consumers of this Tensor that use it as an input. + pfor._add_conversion(merge.outputs[1], + wrap(constant_op.constant(-1.0), False)) + switch = self._exit_switches[i] + # Don't need to worry about switch.output[0] which will feed to Exit node. + pfor._add_conversion(switch.outputs[1], wrapped_inp) + return pfor + + def _convert_enter(self, parent_pfor, enter): + """Converts an Enter node.""" + inp, stacked, _ = parent_pfor._convert_helper(enter.op.inputs[0]) + control_inputs = [ + parent_pfor._convert_helper(x).t for x in enter.op.control_inputs + ] + if control_inputs: + with ops.control_dependencies(control_inputs): + inp = array_ops.identity(inp) + return inp, stacked + + def _maybe_stacked(self, cache, inp): + """Heuristic to figue out if the coverting inp leads to a stacked value. + + + Args: + cache: map from Tensor to boolean indicating stacked/unstacked. + inp: input Tensor. + + Returns: + True if `inp` could get stacked. If the function returns False, the + converted value should be guaranteed to be unstacked. If returning True, + it may or may not be stacked. + """ + if inp in cache: + return cache[inp] + if not self.op_is_inside_loop(inp.op): + return False + op = inp.op + output = False + if op.type in [ + "Shape", + "Rank" + "ShapeN", + "ZerosLike", + "TensorArrayV3", + "TensorArraySizeV3", + ]: + output = False + elif _is_stateful_pfor_op(op): + # This may be fairly aggressive. + output = True + elif op.type == "Exit": + # This may be fairly aggressive. + output = True + else: + for t in op.inputs: + if self._maybe_stacked(cache, t): + output = True + break + cache[inp] = output + return output + + def _create_init_values(self, pfor_input): + """Create arguments passed to converted while_loop.""" + with ops.name_scope("while_init"): + loop_len_vector = pfor_input.pfor.loop_len_vector + loop_len = loop_len_vector[0] + num_outputs = len(self._outputs) + + inputs = [] + maybe_stacked_cache = {} + # Convert all the Enters. Need to do this before checking for stacking + # below. + for i, enter in enumerate(self._enters): + inp, stacked = self._convert_enter(pfor_input.pfor, enter) + inputs.append(inp) + maybe_stacked_cache[enter] = stacked + # Since this enter node is part of the `loop_vars`, it corresponds to an + # output and its preceding switch. We mark this switch's output the same + # stackness, to act at the base case for the logic below. Below, we will + # be going through the body figuring out which inputs might need to be + # stacked and which inputs can safely remain unstacked. + if i < num_outputs: + maybe_stacked_cache[self._exit_switches[i].outputs[1]] = stacked + + # Shape invariants for init_values corresponding to self._enters. + input_shape_invariants = [] + # TensorArrays for outputs of converted while loop + output_tas = [] + # Shape invariants for output TensorArrays. + ta_shape_invariants = [] + # List of booleans indicating stackness of inputs, i.e. tensors + # corresponding to self._enters. + inputs_stacked = [] + for i, inp in enumerate(inputs): + enter = self._enters[i] + inp_stacked = self._maybe_stacked(maybe_stacked_cache, enter) + # Note that even when an input is unstacked, the body could make it + # stacked. we use a heuristic below to figure out if body may be making + # it stacked. + if i < num_outputs: + body_output = self._body_outputs[i] + if enter.op in self._pfor_ops: + body_output_stacked = self._maybe_stacked(maybe_stacked_cache, + body_output) + else: + # If constructed outside of pfor loop, then the output would not be + # stacked. + body_output_stacked = False + if body_output_stacked and not inp_stacked: + inp = _stack(inp, loop_len_vector).t + inputs[i] = inp + inp_stacked = True + # TODO(agarwal): other attributes for the TensorArray ? + output_tas.append(tensor_array_ops.TensorArray(inp.dtype, loop_len)) + ta_shape_invariants.append(tensor_shape.TensorShape(None)) + + inputs_stacked.append(inp_stacked) + input_shape_invariants.append(tensor_shape.TensorShape(None)) + + # See documentation for __call__ for the structure of init_values. + init_values = [True, pfor_input.pfor.all_indices] + inputs + output_tas + # TODO(agarwal): try stricter shape invariants + shape_invariants = ( + [tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None) + ] + input_shape_invariants + ta_shape_invariants) + + return init_values, inputs_stacked, shape_invariants + + def _process_cond_unstacked(self, conditions, indices, inputs, output_tas): + """Handles case when condition is unstacked. + + Note that all iterations end together. So we don't need to partition the + inputs. When all iterations are done, we write the inputs to the + TensorArrays. Note that we only write to index 0 of output_tas. Since all + iterations end together, they can all be output together. + """ + not_all_done = array_ops.reshape(conditions, []) + new_output_tas = [] + # pylint: disable=cell-var-from-loop + for i, out_ta in enumerate(output_tas): + inp = inputs[i] + new_output_tas.append( + control_flow_ops.cond(not_all_done, + lambda: out_ta, + lambda: out_ta.write(0, inp))) + # pylint: enable=cell-var-from-loop + return not_all_done, indices, inputs, new_output_tas + + def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked, + output_tas): + num_outputs = len(self._outputs) + # Compute if all iterations are done. + not_all_done = math_ops.reduce_any(conditions) + conditions_int = math_ops.cast(conditions, dtypes.int32) + # Partition the indices. + done_indices, new_indices = data_flow_ops.dynamic_partition( + indices, conditions_int, 2) + + new_inputs = [] + new_output_tas = [] + for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)): + # Partition the inputs. + if stacked: + done_inp, new_inp = data_flow_ops.dynamic_partition( + inp, conditions_int, 2) + else: + # TODO(agarwal): avoid this stacking. See TODO earlier in + # _process_cond_unstacked. + done_inp = _stack(inp, [array_ops.size(done_indices)]).t + new_inp = inp + new_inputs.append(new_inp) + # For iterations that are done, write them to TensorArrays. + if i < num_outputs: + out_ta = output_tas[i] + # Note that done_indices can be empty. done_inp should also be empty in + # that case. + new_output_tas.append(out_ta.scatter(done_indices, done_inp)) + return not_all_done, new_indices, new_inputs, new_output_tas + + def _process_body(self, pfor_input, inputs_stacked, + new_indices, cond_stacked, new_inputs, + not_all_done): + """Convert the body function.""" + + def true_fn(control_inputs, body_pfor, body_output, stacked): + """Converts the body function for all but last iteration. + + This essentially converts body_output. Additionally, it needs to handle + any control dependencies on the NextIteration node. So it creates another + Identity node with the converted dependencies. + """ + converted_control_inp = [] + for x in control_inputs: + for t in x.outputs: + converted_control_inp.append(body_pfor._convert_helper(t).t) + if stacked: + # Note convert always does the stacking. + output = body_pfor.convert(body_output) + else: + output, convert_stacked, _ = body_pfor._convert_helper(body_output) + assert convert_stacked == stacked, body_output + with ops.control_dependencies(converted_control_inp): + return array_ops.identity(output) + + body_pfor = self._init_pfor(pfor_input.pfor, new_indices, + cond_stacked, new_inputs, + inputs_stacked) + new_outputs = [] + + for i, (body_output, stacked) in enumerate( + zip(self._body_outputs, inputs_stacked)): + control_inp = self._next_iter_control_inputs[i] + out_dtype = body_output.dtype + # Note that we want to run the body only if not all pfor iterations are + # done. If all are done, we return empty tensors since these values will + # not be used. Notice that the value returned by the loop is based on + # TensorArrays and not directly on these returned values. + # pylint: disable=cell-var-from-loop + new_output = control_flow_ops.cond( + not_all_done, + lambda: true_fn(control_inp, body_pfor, body_output, stacked), + lambda: constant_op.constant([], dtype=out_dtype)) + # pylint: enable=cell-var-from-loop + new_outputs.append(new_output) + return new_outputs + + def __call__(self, pfor_input): + """Converter for the while_loop. + + The conversion of a while_loop is another while_loop. + + The arguments to this converted while_loop are as follows: + not_all_done: Boolean scalar Tensor indicating if all the pfor iterations + are done. + indices: int32 1-D Tensor storing the id of the iterations that are not + done. + args: Remaining arguments. These can be divided into 3 categories: + - First set of arguments are the tensors that correspond to the initial + elements of self._enters. The elements that appear in original while + loop's `loop_vars`. + - The second set of arguments are the tensors that correspond to the + remaining elements of self._enters. These are the tensors that directly + enter the original while loop body. + - Finally, the last set of arguments are TensorArrays. These TensorArrays + correspond to the outputs of the original while_loop, i.e. to the + elements in self._outputs. Each TensorArray has `PFor.loop_len` + elements, i.e. the number of pfor iterations. At the end, the i'th + element of each TensorArray will contain the output computed by the + i'th iteration of pfor. Note that elements can be written into these + tensors arrays in any order, depending on when the corresponding pfor + iteration is done. + If the original while_loop had `k` tensors in its `loop_vars` and its body + directly captured `m` tensors, the `args` will contain `2 * k + m` values. + + In each iteration, the while_loop body recomputes the condition for all + active pfor iterations to see which of them are now done. It then partitions + all the inputs and passes them along to the converted body. Values for all + the iterations that are done are written to TensorArrays indexed by the pfor + iteration number. When all iterations are done, the TensorArrays are stacked + to get the final value. + + Args: + pfor_input: A PForInput object corresponding to the output of any Exit + node from this while loop. + + Returns: + List of converted outputs. + """ + # Create init_values that will be passed to the while_loop. + init_values, inputs_stacked, shape_invariants = self._create_init_values( + pfor_input) + # Note that we use a list as a hack since we need the nested function body + # to set the value of cond_is_stacked. python2.x doesn't support nonlocal + # variables. + cond_is_stacked = [None] + + def cond(not_all_done, *_): + return not_all_done + + def body(not_all_done, indices, *args): + # See documentatin for __call__ for the structure of *args. + num_enters = len(self._enters) + inputs = args[:num_enters] + output_tas = args[num_enters:] + # TODO(agarwal): see which outputs have consumers and only populate the + # TensorArrays corresponding to those. Or do those paths get trimmed out + # from inside the while_loop body? + assert len(inputs) >= len(output_tas) + assert len(inputs) == len(inputs_stacked) + + # Convert condition + with ops.name_scope("while_cond"): + # Note that we set cond_stacked to True here. At this point we don't + # know if it could be loop invariant, hence the conservative value is + # to assume stacked. + cond_pfor = self._init_pfor(pfor_input.pfor, indices, + cond_stacked=True, + inputs=inputs, + inputs_stacked=inputs_stacked) + conditions, cond_stacked, _ = cond_pfor._convert_helper(self._condition) + cond_is_stacked[0] = cond_stacked + + # Recompute the new condition, write outputs of done iterations, and + # partition the inputs if needed. + if not cond_stacked: + (not_all_done, new_indices, + new_inputs, new_output_tas) = self._process_cond_unstacked( + conditions, indices, inputs, output_tas) + else: + (not_all_done, new_indices, + new_inputs, new_output_tas) = self._process_cond_stacked( + conditions, indices, inputs, inputs_stacked, output_tas) + + # Convert body + with ops.name_scope("while_body"): + # Compute the outputs from the body. + new_outputs = self._process_body(pfor_input, inputs_stacked, + new_indices, cond_stacked, new_inputs, + not_all_done) + + # Note that the first num_outputs new values of inputs are computed using + # the body. Rest of them were direct Enters into the condition/body and + # the partitioning done earlier is sufficient to give the new value. + num_outputs = len(self._outputs) + new_args = ([not_all_done, new_indices] + new_outputs + list( + new_inputs[num_outputs:]) + new_output_tas) + return tuple(new_args) + + while_outputs = control_flow_ops.while_loop( + cond, body, init_values, shape_invariants=shape_invariants) + output_tas = while_outputs[-len(self._outputs):] + outputs = [] + assert cond_is_stacked[0] is not None + for inp_stacked, ta in zip(inputs_stacked, output_tas): + if cond_is_stacked[0]: + outputs.append(wrap(ta.stack(), True)) + else: + # Note that if while_loop condition is unstacked, all iterations exit at + # the same time and we wrote those outputs in index 0 of the tensor + # array. + outputs.append(wrap(ta.read(0), inp_stacked)) + return outputs + + +class _PforInput(object): + """Input object passed to registered pfor converters.""" + + def __init__(self, pfor, op, inputs): + """Creates a _PforInput object. + + Args: + pfor: PFor converter object. + op: the Operation object that is being converted. + inputs: list of WrappedTensor objects representing converted values of the + inputs of `op`. + """ + self.pfor = pfor + self._op = op + self._inputs = inputs + + def stack_inputs(self, stack_indices=None): + """Stacks unstacked inputs at `stack_indices`. + + Args: + stack_indices: indices of inputs at which stacking is done. If None, + stacking is done at all indices. + """ + if stack_indices is None: + stack_indices = range(len(self._inputs)) + length = self.pfor.loop_len_vector + for i in stack_indices: + inp = self._inputs[i] + if not inp.is_stacked: + self._inputs[i] = _stack(inp.t, length) + + def expanddim_inputs_for_broadcast(self): + """Reshapes stacked inputs to prepare them for broadcast. + + Since stacked inputs have an extra leading dimension, automatic broadcasting + rules could incorrectly try to expand dimensions before that leading + dimension. To avoid that, we reshape these stacked inputs to the maximum + rank they will need to be broadcasted to. + """ + if not self._inputs: + return + + # Find max rank + def _get_rank(x): + rank = array_ops.rank(x.t) + if not x.is_stacked: + rank += 1 + return rank + + ranks = [_get_rank(x) for x in self._inputs] + max_rank = ranks[0] + for rank in ranks[1:]: + max_rank = math_ops.maximum(rank, max_rank) + + for i, inp in enumerate(self._inputs): + if inp.is_stacked: + shape = array_ops.shape(inp.t) + rank_diff = array_ops.reshape(max_rank - ranks[i], [1]) + ones = array_ops.tile([1], rank_diff) + new_shape = array_ops.concat([shape[:1], ones, shape[1:]], axis=0) + self._inputs[i] = wrap(array_ops.reshape(inp.t, new_shape), True) + + @property + def inputs(self): + return self._inputs + + @property + def num_inputs(self): + return len(self._inputs) + + def input(self, index): + assert len(self._inputs) > index, (index, self._inputs) + return self._inputs[index] + + def stacked_input(self, index): + t, is_stacked, _ = self.input(index) + if not is_stacked: + op_type = self.op_type + op_def = getattr(self._op, "op_def", None) + if op_def is None: + input_name = "at index %d" % index + else: + input_name = "\"%s\"" % op_def.input_arg[index].name + raise ValueError("Input %s of op \"%s\" expected to be not loop invariant" + ".\nError while converting op %s" + "with converted inputs\n%s" % (input_name, op_type, + self._op, self.inputs)) + return t + + def unstacked_input(self, index): + t, is_stacked, _ = self.input(index) + if is_stacked: + op_type = self.op_type + op_def = getattr(self._op, "op_def", None) + if op_def is None: + input_name = "at index %d" % index + else: + input_name = "\"%s\"" % op_def.input_arg[index].name + raise ValueError("Input %s of op \"%s\" expected to be loop invariant" + ".\nError while converting op %s" + "with converted inputs\n%s" % (input_name, op_type, + self._op, self.inputs)) + return t + + @property + def op(self): + return self._op + + @property + def op_type(self): + return self._op.type + + def get_attr(self, attr): + return self._op.get_attr(attr) + + @property + def outputs(self): + return self._op.outputs + + def output(self, index): + assert index < len(self._op.outputs) + return self._op.outputs[index] + + +_pfor_converter_registry = {} + + +class RegisterPFor(object): + """Utility to register converters for pfor. + + Usage: + @RegisterPFor(foo_op_type) + def _foo_converter(pfor_input): + ... + + The above will register conversion function `_foo_converter` for handling + conversion of `foo_op_type`. During conversion, the registered functin will be + called with a single argument of type `PForInput` which will contain state + needed for the conversion. This registered function should output a list of + WrappedTensor object with the same length as the number of outputs of op being + converted. If the op had zero outputs, then it should return a ops.Operation + object. + """ + + def __init__(self, op_type): + """Creates an object to register a converter for op with type `op_type`.""" + self.op_type = op_type + + def __call__(self, converter): + name = self.op_type + assert name not in _pfor_converter_registry, "Re-registering %s " % name + _pfor_converter_registry[name] = converter + return converter + + +class RegisterPForWithArgs(RegisterPFor): + """Utility to register converters for pfor. + + Usage: + @RegisteRPFor(foo_op_type, foo=value, ....) + def _foo_converter(pfor_input, foo=None, ....): + ... + + See RegisterPFor for details on the conversion function. + `RegisterPForWithArgs` allows binding extra arguments to the + conversion function at registration time. + """ + + def __init__(self, op_type, *args, **kw_args): + super(RegisterPForWithArgs, self).__init__(op_type) + self._args = args + self._kw_args = kw_args + + def __call__(self, converter): + + def _f(pfor_input): + return converter(pfor_input, self.op_type, *self._args, **self._kw_args) + + super(RegisterPForWithArgs, self).__call__(_f) + return converter + + +def _create_op(op_type, inputs, op_dtypes, attrs=None): + """Utility to create an op.""" + return ops.get_default_graph().create_op( + op_type, inputs, op_dtypes, attrs=attrs, compute_device=True) + + +WrappedTensor = collections.namedtuple("WrappedTensor", + ["t", "is_stacked", "is_sparse_stacked"]) +"""Wrapper around the result of a Tensor conversion. + +The additional fields are useful for keeping track of the conversion state as +data flows through the ops in the loop body. For every op whose output is a +Tensor, its converter should return either a WrappedTensor or a list of +WrappedTensors. + +Args: + t: The converted tensor + is_stacked: True if the tensor is stacked, i.e. represents the results of all + the iterations of the loop, where each row i of the tensor corresponds to + that op's output on iteration i of the loop. False if the tensor is not + stacked, i.e. represents the result of the op on of a single iteration of + the loop, where the result does not vary between iterations. + is_sparse_stacked: True if the tensor corresponds to a component tensor + (indices, values, or dense_shape) of a sparse tensor, and has been logically + stacked via a sparse conversion. +""" + + +def wrap(tensor, is_stacked=True, is_sparse_stacked=False): + """Helper to create a WrappedTensor object.""" + assert isinstance(is_stacked, bool) + assert isinstance(is_sparse_stacked, bool) + assert isinstance(tensor, ops.Tensor) + assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is " + "stacked via a sparse " + "conversion, it must also be " + "stacked.") + return WrappedTensor(tensor, is_stacked, is_sparse_stacked) + + +def _fallback_converter(pfor_input): + logging.warn("Using a while_loop for converting %s", pfor_input.op_type) + output_dtypes = [x.dtype for x in pfor_input.outputs] + iters = pfor_input.pfor.loop_len_vector[0] + + def while_body(i, *ta_list): + """Body of while loop.""" + inputs = [ + x[i, ...] if stacked else x for x, stacked, _ in pfor_input.inputs + ] + op_outputs = _create_op( + pfor_input.op_type, + inputs, + output_dtypes, + attrs=pfor_input.op.node_def.attr).outputs + + outputs = [] + for out, ta in zip(op_outputs, ta_list): + assert isinstance(out, ops.Tensor) + outputs.append(ta.write(i, array_ops.expand_dims(out, 0))) + return tuple([i + 1] + outputs) + + ta_list = control_flow_ops.while_loop( + lambda i, *ta: i < iters, while_body, [0] + [ + tensor_array_ops.TensorArray(dtype, iters) for dtype in output_dtypes + ])[1:] + return tuple([wrap(ta.concat(), True) for ta in ta_list]) + + +class PFor(object): + """Implementation of rewrite of parallel-for loops. + + This class takes a DAG or a set of DAGs representing the body of a + parallel-for loop, and adds new operations to the graph that implements + functionality equivalent to running that loop body for a specified number of + iterations. This new set of nodes may or may not use a tensorflow loop + construct. + + The process of conversion does not delete or change any existing operations. + It only adds operations that efficiently implement the equivalent + functionality. We refer to the added ops as "converted ops". + + The conversion process uses a simple greedy heuristic. It walks the loop body + and tries to express the functionality of running each node in a loop with a + new set of nodes. When converting an op several cases are possible: + - The op is not inside the loop body. Hence it can be used as is. + - The op does not depend on the iteration number and is stateless. In this + case, it can be used as is. + - The op is not stateful, and depends on iteration number only through control + dependencies. In this case, we can create a single op with same inputs and + attributes, but with "converted" control dependencies. + - The op is not stateful, and all its inputs are loop invariant. In this + case, similar to above, we can create a single op with same inputs and + attributes, but with "converted" control dependencies. + - The op is stateful or at least one of the inputs is not loop invariant. In + this case, we run the registered converter for that op to create a set of + converted ops. All nodes in the set will have converted control dependencies + corresponding to control dependencies of the original op. If the op returned + multiple outputs, "converted outputs" could be produced by different ops in + this set. + """ + + def __init__(self, + loop_var, + loop_len, + pfor_ops, + all_indices=None, + all_indices_partitioned=False): + """Creates an object to rewrite a parallel-for loop. + + Args: + loop_var: ops.Tensor output of a Placeholder operation. The value should + be an int32 scalar representing the loop iteration number. + loop_len: A scalar or scalar Tensor representing the number of iterations + the loop is run for. + pfor_ops: List of all ops inside the loop body. + all_indices: If not None, an int32 vector with size `loop_len` + representing the iteration ids that are still active. These values + should be unique and sorted. However they may not be contiguous. This is + typically the case when inside a control flow construct which has + partitioned the indices of the iterations that are being converted. + all_indices_partitioned: If True, this object is being constructed from a + control flow construct where not all the pfor iterations are guaranteed + to be active. + """ + assert isinstance(loop_var, ops.Tensor) + assert loop_var.op.type == "Placeholder" + self._loop_var = loop_var + loop_len_value = tensor_util.constant_value(loop_len) + if loop_len_value is not None: + loop_len = loop_len_value + self._loop_len_vector = array_ops.reshape(loop_len, [1]) + self._all_indices_partitioned = all_indices_partitioned + if all_indices_partitioned: + assert all_indices is not None + self.all_indices = ( + math_ops.range(loop_len) if all_indices is None else all_indices) + + self._conversion_map = {} + self._conversion_map[loop_var] = wrap(self.all_indices, True) + self._pfor_ops = set(pfor_ops) + self._pfor_op_ids = set([x._id for x in pfor_ops]) + + def op_is_inside_loop(self, op): + """True if op was created inside the pfor loop body.""" + assert isinstance(op, ops.Operation) + # Note that we use self._pfor_op_ids for the check and not self._pfor_ops + # since it appears there tensorflow API could return different python + # objects representing the same Operation node. + return op._id in self._pfor_op_ids + + def _convert_sparse(self, y): + """Returns the converted value corresponding to SparseTensor y. + + For SparseTensors, instead of stacking the component tensors separately, + resulting in component tensors with shapes (N, m, rank), (N, m), and (N, + rank) respectively for indices, values, and dense_shape (where N is the loop + length and m is the number of sparse tensor values per loop iter), we want + to logically stack the SparseTensors, to create a SparseTensor whose + components are size (N * m, rank + 1), (N * m, ), and (rank + 1,) + respectively. + + Here, we try to get the conversion of each component tensor. + If the tensors are stacked via a sparse conversion, return the resulting + SparseTensor composed of the converted components. Otherwise, the component + tensors are either unstacked or stacked naively. In the latter case, we + unstack the component tensors to reform loop_len SparseTensor elements, + then correctly batch them. + + The unstacked tensors must have the same rank. Each dimension of each + SparseTensor will expand to be the largest among all SparseTensor elements + for that dimension. For example, if there are N SparseTensors of rank 3 + being stacked, with N dense shapes, where the i_th shape is (x_i, y_i, z_i), + the new dense shape will be (N, max_i(x_i), max_i(y_i), max_i(z_i)). + + Args: + y: A tf.SparseTensor. + + Returns: + A tf.SparseTensor that is the converted value corresponding to y. + """ + outputs = [ + self._convert_helper(t) for t in (y.indices, y.values, y.dense_shape) + ] + assert all(isinstance(o, WrappedTensor) for o in outputs) + + if all(w.is_sparse_stacked for w in outputs): + return sparse_tensor.SparseTensor(*[w.t for w in outputs]) + + assert not any(w.is_sparse_stacked for w in outputs), ( + "Error converting SparseTensor. All components should be logically " + "stacked, or none.") + + # If component tensors were not sparsely stacked, they are either unstacked + # or stacked without knowledge that they are components of sparse tensors. + # In this case, we have to restack them. + return self._restack_sparse_tensor_logically( + *[self._unwrap_or_tile(w) for w in outputs]) + + def _restack_sparse_tensor_logically(self, indices, values, shape): + sparse_tensor_rank = indices.get_shape()[-1].value + if sparse_tensor_rank is not None: + sparse_tensor_rank += 1 + + def map_fn(args): + res = gen_sparse_ops.serialize_sparse( + args[0], args[1], args[2], out_type=dtypes.variant) + return res + + # Applies a map function to the component tensors to serialize each + # sparse tensor element and batch them all, then deserializes the batch. + # TODO(rachelim): Try to do this without map_fn -- add the right offsets + # to shape and indices tensors instead. + result = functional_ops.map_fn( + map_fn, [indices, values, shape], dtype=dtypes.variant) + return sparse_ops.deserialize_sparse( + result, dtype=values.dtype, rank=sparse_tensor_rank) + + def _unwrap_or_tile(self, wrapped_tensor): + """Given a wrapped tensor, unwrap if stacked. Otherwise, tiles it.""" + output, is_stacked = wrapped_tensor.t, wrapped_tensor.is_stacked + if is_stacked: + return output + else: + return _stack(output, self._loop_len_vector).t + + def convert(self, y): + """Returns the converted value corresponding to y. + + Args: + y: A ops.Tensor or a ops.Operation object. If latter, y should not have + any outputs. + + Returns: + If y does not need to be converted, it returns y as is. Else it returns + the "converted value" corresponding to y. + """ + if isinstance(y, sparse_tensor.SparseTensor): + return self._convert_sparse(y) + output = self._convert_helper(y) + if isinstance(output, WrappedTensor): + assert isinstance(y, ops.Tensor) + return self._unwrap_or_tile(output) + else: + assert isinstance(y, ops.Operation) + assert not y.outputs + assert isinstance(output, ops.Operation) + return output + + def _was_converted(self, t): + """True if t is not a conversion of itself.""" + converted_t = self._conversion_map[t] + return converted_t.t is not t + + def _add_conversion(self, old_output, new_output): + self._conversion_map[old_output] = new_output + + def _convert_helper(self, op_or_tensor): + stack = [op_or_tensor] + while stack: + y = stack[0] + if y in self._conversion_map: + assert isinstance(self._conversion_map[y], + (WrappedTensor, ops.Operation)) + stack.pop(0) + continue + if isinstance(y, ops.Operation): + assert not y.outputs, ( + "We only support converting Operation objects with no outputs. " + "Got %s", y) + y_op = y + else: + assert isinstance(y, ops.Tensor), y + y_op = y.op + + is_while_loop = y_op.type == "Exit" + if is_while_loop: + while_op = WhileOp(y, pfor_ops=self._pfor_ops) + is_inside_loop = while_op.is_inside_loop + # If all nodes in the while_loop graph were created inside the pfor, we + # treat the whole loop subgraph as a single op (y_op) and try to convert + # it. For while_loops that are created completely or partially outside, + # we treat them as external and should be able to simply return the Exit + # node output as is without needing any conversion. Note that for + # while_loops that are partially constructed inside, we assume they will + # be loop invariant. If that is not the case, it will create runtime + # errors since the converted graph would depend on the self._loop_var + # placeholder. + if is_inside_loop: + y_op = while_op + else: + is_inside_loop = self.op_is_inside_loop(y_op) + + # If this op was not created inside the loop body, we will return as is. + # 1. Convert inputs and control inputs. + + def _add_to_stack(x): + if x not in self._conversion_map: + stack.insert(0, x) + return True + else: + return False + + if is_inside_loop: + added_to_stack = False + for inp in y_op.inputs: + added_to_stack |= _add_to_stack(inp) + for cinp in y_op.control_inputs: + if cinp.outputs: + for t in cinp.outputs: + added_to_stack |= _add_to_stack(t) + else: + added_to_stack |= _add_to_stack(cinp) + if added_to_stack: + continue + + converted_inputs = [self._conversion_map[inp] for inp in y_op.inputs] + some_input_converted = any( + [self._was_converted(x) for x in y_op.inputs]) + some_input_stacked = any([x.is_stacked for x in converted_inputs]) + + converted_control_ops = set() + some_control_input_converted = False + for cinp in y_op.control_inputs: + if cinp.outputs: + for t in cinp.outputs: + converted_t = self._conversion_map[t] + if self._was_converted(t): + some_control_input_converted = True + converted_control_ops.add(converted_t.t.op) + else: + converted_cinp = self._conversion_map[cinp] + assert isinstance(converted_cinp, ops.Operation) + if converted_cinp != cinp: + some_control_input_converted = True + converted_control_ops.add(converted_cinp) + converted_control_ops = list(converted_control_ops) + is_stateful = _is_stateful_pfor_op(y_op) + else: + converted_inputs = [] + converted_control_ops = [] + logging.vlog(3, "converting op:%s\ninputs:%s\ncontrol_inputs:%s", y_op, + converted_inputs, converted_control_ops) + + # 2. Convert y_op + # If converting a while_loop, we let the while_loop convertor deal with + # putting the control dependencies appropriately. + control_dependencies = [] if is_while_loop else converted_control_ops + with ops.control_dependencies(control_dependencies), ops.name_scope( + y_op.name + "/pfor/"): + # None of the inputs and control inputs were converted. + if (not is_inside_loop or + (not is_stateful and not some_input_converted and + not some_control_input_converted)): + if y == y_op: + assert not isinstance(y_op, WhileOp) + new_outputs = y_op + else: + new_outputs = [wrap(x, False) for x in y_op.outputs] + elif not (is_stateful or is_while_loop or some_input_stacked): + # All inputs are unstacked or uncoverted but some control inputs are + # converted. + # TODO(rachelim): Handle the case where some inputs are sparsely + # stacked (i.e. any([x.is_sparse_stacked for x in converted_inputs])) + new_op = _create_op(y_op.type, [x.t for x in converted_inputs], + [x.dtype for x in y_op.outputs], + y_op.node_def.attr) + if y == y_op: + new_outputs = new_op + else: + new_outputs = [wrap(x, False) for x in new_op.outputs] + else: + # Either some inputs are not loop invariant or op is stateful. + if hasattr(y_op, "pfor_converter"): + converter = y_op.pfor_converter + else: + converter = _pfor_converter_registry.get(y_op.type, None) + if converter is None: + if flags.FLAGS.op_conversion_fallback_to_while_loop: + converter = _fallback_converter + else: + raise ValueError( + "No converter defined for %s\n%s\ninputs: %s. " + "\nEither add a converter or set " + "--op_conversion_fallback_to_while_loop=True, " + "which may run slower" % (y_op.type, y_op, converted_inputs)) + # TODO(rachelim): Handle the case where some inputs are sparsely + # stacked. We should only call the converter if it supports handling + # those inputs. + new_outputs = converter(_PforInput(self, y_op, converted_inputs)) + if isinstance(new_outputs, WrappedTensor): + new_outputs = [new_outputs] + assert isinstance(new_outputs, + (list, tuple, ops.Operation)), new_outputs + logging.vlog(2, "converted %s %s", y_op, new_outputs) + + # Insert into self._conversion_map + if y == y_op: + assert isinstance(new_outputs, ops.Operation) + self._add_conversion(y_op, new_outputs) + else: + for old_output, new_output in zip(y_op.outputs, new_outputs): + assert isinstance(new_output, WrappedTensor), (new_output, y, y_op) + self._add_conversion(old_output, new_output) + stack.pop(0) + + return self._conversion_map[op_or_tensor] + + @property + def loop_len_vector(self): + """Returns a single element vector whose value is number of iterations.""" + return self._loop_len_vector + + @property + def loop_var(self): + """Returns placeholder loop variable.""" + return self._loop_var + + @property + def pfor_ops(self): + return self._pfor_ops + + @property + def all_indices_partitioned(self): + """all_indices_partitioned property. + + Returns: + True if we are inside a control flow construct and not all pfor iterations + may be active. + """ + return self._all_indices_partitioned + +# nn_ops + + +def _flatten_first_two_dims(x): + """Merges first two dimensions.""" + old_shape = array_ops.shape(x) + new_shape = array_ops.concat([[-1], old_shape[2:]], axis=0) + return array_ops.reshape(x, new_shape) + + +def _unflatten_first_dim(x, first_dim): + """Splits first dimension into [first_dim, -1].""" + old_shape = array_ops.shape(x) + new_shape = array_ops.concat([first_dim, [-1], old_shape[1:]], axis=0) + return array_ops.reshape(x, new_shape) + + +def _inputs_with_flattening(pfor_input, input_indices): + """Stacks and flattens first dim of inputs at indices `input_indices`.""" + if input_indices is None: + input_indices = [] + pfor_input.stack_inputs(stack_indices=input_indices) + inputs = [] + for i in range(pfor_input.num_inputs): + if i in input_indices: + inp = pfor_input.stacked_input(i) + inp = _flatten_first_two_dims(inp) + else: + inp = pfor_input.unstacked_input(i) + inputs.append(inp) + return inputs + + +@RegisterPForWithArgs("Conv2D", dims=[0]) +@RegisterPForWithArgs("AvgPool", dims=[0]) +@RegisterPForWithArgs("MaxPool", dims=[0]) +@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2]) +@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1]) +def _convert_flatten_batch(pfor_input, op_type, dims): + del op_type + inputs = _inputs_with_flattening(pfor_input, dims) + outputs = _create_op( + pfor_input.op_type, + inputs, [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + n = pfor_input.pfor.loop_len_vector + outputs = [_unflatten_first_dim(x, n) for x in outputs] + return [wrap(x, True) for x in outputs] + + +_channel_flatten_input_cache = {} + + +def _channel_flatten_input(x, data_format): + """Merge the stack dimension with the channel dimension. + + If S is pfor's stacking dimension, then, + - for SNCHW, we transpose to NSCHW. If N dimension has size 1, the transpose + should be cheap. + - for SNHWC, we transpose to NHWCS. + We then merge the S and C dimension. + + Args: + x: ops.Tensor to transform. + data_format: "NCHW" or "NHWC". + + Returns: + A 3-element tuple with the transformed value, along with the shape for + reshape and order for transpose required to transform back. + """ + + graph = ops.get_default_graph() + cache_key = (graph, x, data_format) + if cache_key not in _channel_flatten_input_cache: + x_shape = array_ops.shape(x) + if data_format == b"NCHW": + order = [1, 0, 2, 3, 4] + shape = array_ops.concat([x_shape[1:2], [-1], x_shape[3:]], axis=0) + reverse_order = order + else: + order = [1, 2, 3, 0, 4] + shape = array_ops.concat([x_shape[1:4], [-1]], axis=0) + reverse_order = [3, 0, 1, 2, 4] + # Move S dimension next to C dimension. + x = array_ops.transpose(x, order) + reverse_shape = array_ops.shape(x) + # Reshape to merge the S and C dimension. + x = array_ops.reshape(x, shape) + outputs = x, reverse_order, reverse_shape + _channel_flatten_input_cache[cache_key] = outputs + else: + outputs = _channel_flatten_input_cache[cache_key] + return outputs + + +# Note that with training=True, running FusedBatchNorm on individual examples +# is very different from running FusedBatchNorm on a batch of those examples. +# This is because, for the latter case, the operation can be considered as first +# computing the mean and variance over all the examples and then using these +# to scale all those examples. This creates a data dependency between these +# different "iterations" since the inputs to the scaling step depends on the +# statistics coming from all these inputs. +# As with other kernels, the conversion here effectively runs the kernel +# independently for each iteration, and returns outputs by stacking outputs from +# each of those iterations. +@RegisterPFor("FusedBatchNorm") +def _convert_fused_batch_norm(pfor_input): + is_training = pfor_input.get_attr("is_training") + # When BatchNorm is used with training=False, mean and variance are provided + # externally and used as is by the op. Thus, we can merge the S and N + # dimensions as we do for regular operations. + # When BatchNorm is used with training=True, mean and variance are computed + # for each channel across the batch dimension (first one). If we merge S and N + # dimensions, mean and variances will be computed over a larger set. So, we + # merge the S and C dimensions instead. + if not is_training: + # We return zeros for batch_mean and batch_variance output. Note that CPU + # and GPU seem to have different behavior for those two outputs. CPU outputs + # zero because these values are not used during inference. GPU outputs + # something, probably real means and variances. + inputs = _inputs_with_flattening(pfor_input, [0]) + outputs = _create_op( + pfor_input.op_type, + inputs, [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + y = outputs[0] + n = pfor_input.pfor.loop_len_vector + y = _unflatten_first_dim(y, n) + mean = pfor_input.unstacked_input(3) + zeros = array_ops.zeros_like(mean) + return [wrap(y, True), wrap(zeros, False), wrap(zeros, False)] + + pfor_input.stack_inputs() + data_format = pfor_input.get_attr("data_format") + # We merge the first dimension with the "C" dimension, run FusedBatchNorm, and + # then transpose back. + x = pfor_input.stacked_input(0) + x, reverse_order, reverse_shape = _channel_flatten_input(x, data_format) + # Note that we stack all the other inputs as well so that they are the same + # size as the new size of the channel dimension. + inputs = [x] + [ + array_ops.reshape(pfor_input.stacked_input(i), [-1]) + for i in range(1, pfor_input.num_inputs) + ] + outputs = _create_op( + pfor_input.op_type, + inputs, [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + y = outputs[0] + y = array_ops.reshape(y, reverse_shape) + y = array_ops.transpose(y, reverse_order) + n = pfor_input.pfor.loop_len_vector + outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]] + outputs = [y] + outputs + return [wrap(x, True) for x in outputs] + + +@RegisterPFor("FusedBatchNormGrad") +def _convert_fused_batch_norm_grad(pfor_input): + pfor_input.stack_inputs() + data_format = pfor_input.get_attr("data_format") + y_backprop = pfor_input.stacked_input(0) + y_backprop, _, _ = _channel_flatten_input(y_backprop, data_format) + x = pfor_input.stacked_input(1) + x, x_reverse_order, x_reverse_shape = _channel_flatten_input(x, data_format) + inputs = [y_backprop, x] + [ + array_ops.reshape(pfor_input.stacked_input(i), [-1]) + for i in range(2, pfor_input.num_inputs) + ] + outputs = _create_op( + pfor_input.op_type, + inputs, [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + x_backprop = outputs[0] + x_backprop = array_ops.reshape(x_backprop, x_reverse_shape) + x_backprop = array_ops.transpose(x_backprop, x_reverse_order) + n = pfor_input.pfor.loop_len_vector + outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]] + outputs = [x_backprop] + outputs + return [wrap(output, True) for output in outputs] + + +@RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0) +@RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0) +def _convert_flatten_batch_shape_input(pfor_input, op_type, flatten_dims, + shape_dim): + del op_type + inputs = _inputs_with_flattening(pfor_input, flatten_dims) + n = pfor_input.pfor.loop_len_vector + # Adjust the `input_sizes` input. + ones = array_ops.ones( + [array_ops.shape(inputs[shape_dim])[0] - 1], dtype=n.dtype) + inputs[shape_dim] *= array_ops.concat([n, ones], axis=0) + outputs = _create_op( + pfor_input.op_type, + inputs, [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + outputs = [_unflatten_first_dim(x, n) for x in outputs] + return [wrap(x, True) for x in outputs] + + +@RegisterPFor("Conv2DBackpropFilter") +def _convert_conv2d_backprop_filter(pfor_input): + pfor_input.stack_inputs(stack_indices=[2]) + inputs, inputs_stacked, _ = pfor_input.input(0) + filter_sizes = pfor_input.unstacked_input(1) + grads = pfor_input.stacked_input(2) + strides = pfor_input.get_attr("strides") + padding = pfor_input.get_attr("padding") + use_cudnn_on_gpu = pfor_input.get_attr("use_cudnn_on_gpu") + data_format = pfor_input.get_attr("data_format") + dilations = pfor_input.get_attr("dilations") + if inputs_stacked: + # TODO(agarwal): Implement this efficiently. + logging.warn("Conv2DBackpropFilter uses a while_loop. Fix that!") + + def while_body(i, ta): + inp_i = inputs[i, ...] + grad_i = grads[i, ...] + output = nn_ops.conv2d_backprop_filter( + inp_i, + filter_sizes, + grad_i, + strides=strides, + padding=padding, + use_cudnn_on_gpu=use_cudnn_on_gpu, + data_format=data_format, + dilations=dilations) + return i + 1, ta.write(i, array_ops.expand_dims(output, 0)) + + n = array_ops.reshape(pfor_input.pfor.loop_len_vector, []) + _, ta = control_flow_ops.while_loop( + lambda i, ta: i < n, while_body, + (0, tensor_array_ops.TensorArray(inputs.dtype, n))) + output = ta.concat() + return wrap(output, True) + else: + # We merge the stack dimension with the channel dimension of the gradients + # and pretend we had a larger filter (see change to filter_sizes below). + # Once the filter backprop is computed, we reshape and transpose back + # appropriately. + grads, _, _ = _channel_flatten_input(grads, data_format) + n = pfor_input.pfor.loop_len_vector + old_filter_sizes = filter_sizes + filter_sizes *= array_ops.concat([[1, 1, 1], n], axis=0) + output = nn_ops.conv2d_backprop_filter( + inputs, + filter_sizes, + grads, + strides=strides, + padding=padding, + use_cudnn_on_gpu=use_cudnn_on_gpu, + data_format=data_format, + dilations=dilations) + new_filter_shape = array_ops.concat([old_filter_sizes[:3], n, [-1]], axis=0) + output = array_ops.reshape(output, new_filter_shape) + output = array_ops.transpose(output, [3, 0, 1, 2, 4]) + return wrap(output, True) + + +# array_ops + + +@RegisterPForWithArgs("Identity", array_ops.identity) +@RegisterPForWithArgs("StopGradient", array_ops.stop_gradient) +def _convert_identity(pfor_input, op_type, op_func): + del op_type + return wrap(op_func(*[x.t for x in pfor_input.inputs]), True) + + +@RegisterPFor("Reshape") +def _convert_reshape(pfor_input): + t = pfor_input.stacked_input(0) + shape = pfor_input.unstacked_input(1) + new_dim = array_ops.shape(t)[:1] + new_shape = array_ops.concat([new_dim, shape], axis=0) + return wrap(array_ops.reshape(t, new_shape), True) + + +@RegisterPFor("ExpandDims") +def _convert_expanddims(pfor_input): + t = pfor_input.stacked_input(0) + dim = pfor_input.unstacked_input(1) + dim += math_ops.cast(dim >= 0, dtypes.int32) + return wrap(array_ops.expand_dims(t, axis=dim), True) + + +@RegisterPFor("Slice") +def _convert_slice(pfor_input): + t = pfor_input.stacked_input(0) + begin = pfor_input.unstacked_input(1) + size = pfor_input.unstacked_input(2) + begin = array_ops.concat([[0], begin], axis=0) + size = array_ops.concat([[-1], size], axis=0) + return wrap(array_ops.slice(t, begin, size), True) + + +@RegisterPFor("Tile") +def _convert_tile(pfor_input): + t = pfor_input.stacked_input(0) + multiples = pfor_input.unstacked_input(1) + multiples = array_ops.concat([[1], multiples], 0) + return wrap(array_ops.tile(t, multiples), True) + + +@RegisterPFor("Pack") +def _convert_pack(pfor_input): + pfor_input.stack_inputs() + axis = pfor_input.get_attr("axis") + if axis >= 0: + axis += 1 + return wrap( + array_ops.stack([x.t for x in pfor_input.inputs], axis=axis), True) + + +@RegisterPFor("Unpack") +def _convert_unpack(pfor_input): + value = pfor_input.stacked_input(0) + axis = pfor_input.get_attr("axis") + if axis >= 0: + axis += 1 + num = pfor_input.get_attr("num") + return [wrap(x, True) for x in array_ops.unstack(value, axis=axis, num=num)] + + +@RegisterPFor("Pad") +def _convert_pad(pfor_input): + t = pfor_input.stacked_input(0) + paddings = pfor_input.unstacked_input(1) + paddings = array_ops.concat([[[0, 0]], paddings], 0) + return wrap(array_ops.pad(t, paddings, mode="CONSTANT"), True) + + +@RegisterPFor("Split") +def _convert_split(pfor_input): + split_dim = pfor_input.unstacked_input(0) + t = pfor_input.stacked_input(1) + num_split = pfor_input.get_attr("num_split") + split_dim += math_ops.cast(split_dim >= 0, dtypes.int32) + return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)] + + +@RegisterPFor("Transpose") +def _convert_transpose(pfor_input): + t = pfor_input.stacked_input(0) + perm = pfor_input.unstacked_input(1) + new_perm = array_ops.concat([[0], perm + 1], axis=0) + return wrap(array_ops.transpose(t, new_perm), True) + + +@RegisterPFor("ZerosLike") +def _convert_zeroslike(pfor_input): + t = pfor_input.stacked_input(0) + shape = array_ops.shape(t)[1:] + return wrap(array_ops.zeros(shape, dtype=t.dtype), False) + + +@RegisterPFor("Gather") +@RegisterPFor("GatherV2") +def _convert_gather(pfor_input): + param, param_stacked, _ = pfor_input.input(0) + indices, indices_stacked, _ = pfor_input.input(1) + op_type = pfor_input.op_type + if op_type == "Gather": + validate_indices = pfor_input.get_attr("validate_indices") + axis = 0 + else: + validate_indices = None + axis = pfor_input.unstacked_input(2) + axis_value = tensor_util.constant_value(axis) + if axis_value is not None: + axis = axis_value + if indices_stacked and not param_stacked: + if indices == pfor_input.pfor.all_indices and axis == 0: + param_shape0 = param.shape[0].value + indices_shape0 = indices.shape[0].value + if param_shape0 is not None and indices_shape0 == param_shape0: + # Note that with loops and conditionals, indices may not be contiguous. + # However they will be sorted and unique. So if the shape matches, then + # it must be picking up all the rows of param. + return wrap(param, True) + # TODO(agarwal): use array_ops.slice here. + output = array_ops.gather( + param, indices, validate_indices=validate_indices, axis=axis) + if axis != 0: + axis = control_flow_ops.cond( + axis < 0, lambda: axis + array_ops.rank(param), lambda: axis) + order = array_ops.concat( + [[axis], + math_ops.range(axis), + math_ops.range(axis + 1, array_ops.rank(output))], + axis=0) + output = control_flow_ops.cond( + math_ops.equal(axis, 0), lambda: output, + lambda: array_ops.transpose(output, order)) + return wrap(output, True) + if param_stacked: + loop_len_vector = pfor_input.pfor.loop_len_vector + pfor_input.stack_inputs(stack_indices=[1]) + indices = pfor_input.stacked_input(1) + param_flat = _flatten_first_two_dims(param) + + # Recompute indices to handle stacked param. + indices_offset = math_ops.range( + loop_len_vector[0]) * array_ops.shape(param)[1] + # Reshape indices_offset to allow broadcast addition + ones = array_ops.ones([array_ops.rank(indices) - 1], dtype=dtypes.int32) + new_shape = array_ops.concat([loop_len_vector, ones], axis=0) + indices_offset = array_ops.reshape(indices_offset, new_shape) + indices += indices_offset + + # TODO(agarwal): handle axis != 0. May need to transpose param or + # array_ops.gather_nd. + if isinstance(axis, ops.Tensor): + axis_value = tensor_util.constant_value(axis) + else: + try: + axis_value = int(axis) + except TypeError: + axis_value = None + msg = ("Gather, where indices and param are both loop dependent, currently " + "requires axis=0") + if axis_value is not None and axis_value != 0: + raise ValueError("Error while converting %s. %s. Got axis=%d" % + (pfor_input.op, msg, axis)) + with ops.control_dependencies( + [check_ops.assert_equal(axis, 0, message=msg)]): + output = array_ops.gather(param_flat, indices) + return wrap(output, True) + + +@RegisterPFor("ConcatV2") +def _convert_concatv2(pfor_input): + n = pfor_input.num_inputs + pfor_input.stack_inputs(stack_indices=range(n - 1)) + axis = pfor_input.unstacked_input(n - 1) + axis += math_ops.cast(axis >= 0, axis.dtype) + return wrap( + array_ops.concat([x.t for x in pfor_input.inputs[:n - 1]], axis=axis), + True) + + +@RegisterPFor("StridedSlice") +def _convert_strided_slice(pfor_input): + inp = pfor_input.stacked_input(0) + begin = pfor_input.unstacked_input(1) + end = pfor_input.unstacked_input(2) + strides = pfor_input.unstacked_input(3) + begin_mask = pfor_input.get_attr("begin_mask") + end_mask = pfor_input.get_attr("end_mask") + ellipsis_mask = pfor_input.get_attr("ellipsis_mask") + new_axis_mask = pfor_input.get_attr("new_axis_mask") + shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask") + + begin = array_ops.concat([[0], begin], axis=0) + end = array_ops.concat([[0], end], axis=0) + strides = array_ops.concat([[1], strides], axis=0) + begin_mask = begin_mask << 1 | 1 + end_mask = end_mask << 1 | 1 + ellipsis_mask <<= 1 + new_axis_mask <<= 1 + shrink_axis_mask <<= 1 + return wrap( + array_ops.strided_slice( + inp, + begin, + end, + strides, + begin_mask=begin_mask, + end_mask=end_mask, + ellipsis_mask=ellipsis_mask, + new_axis_mask=new_axis_mask, + shrink_axis_mask=shrink_axis_mask), True) + + +@RegisterPFor("StridedSliceGrad") +def _convert_strided_slice_grad(pfor_input): + shape = pfor_input.unstacked_input(0) + begin = pfor_input.unstacked_input(1) + end = pfor_input.unstacked_input(2) + strides = pfor_input.unstacked_input(3) + dy = pfor_input.stacked_input(4) + begin_mask = pfor_input.get_attr("begin_mask") + end_mask = pfor_input.get_attr("end_mask") + ellipsis_mask = pfor_input.get_attr("ellipsis_mask") + new_axis_mask = pfor_input.get_attr("new_axis_mask") + shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask") + + shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) + begin = array_ops.concat([[0], begin], axis=0) + end = array_ops.concat([[0], end], axis=0) + strides = array_ops.concat([[1], strides], axis=0) + begin_mask = begin_mask << 1 | 1 + end_mask = end_mask << 1 | 1 + ellipsis_mask <<= 1 + new_axis_mask <<= 1 + shrink_axis_mask <<= 1 + return wrap( + array_ops.strided_slice_grad( + shape, + begin, + end, + strides, + dy, + begin_mask=begin_mask, + end_mask=end_mask, + ellipsis_mask=ellipsis_mask, + new_axis_mask=new_axis_mask, + shrink_axis_mask=shrink_axis_mask), True) + + +# math_ops + + +@RegisterPFor("MatMul") +def _convert_matmul(pfor_input): + # TODO(agarwal): Check if tiling is faster than two transposes. + a, a_stacked, _ = pfor_input.input(0) + b, b_stacked, _ = pfor_input.input(1) + tr_a = pfor_input.get_attr("transpose_a") + tr_b = pfor_input.get_attr("transpose_b") + if a_stacked and b_stacked: + output = wrap(math_ops.matmul(a, b, adjoint_a=tr_a, adjoint_b=tr_b), True) + return output + elif a_stacked: + if tr_a: + a = array_ops.transpose(a, [0, 2, 1]) + if a.shape.is_fully_defined(): + x, y, z = a.shape + else: + x, y, z = [ + array_ops.reshape(i, []) + for i in array_ops.split(array_ops.shape(a), 3) + ] + a = array_ops.reshape(a, [x * y, z]) + prod = math_ops.matmul(a, b, transpose_b=tr_b) + return wrap(array_ops.reshape(prod, [x, y, -1]), True) + else: + assert b_stacked + if tr_b: + perm = [2, 0, 1] + b = array_ops.transpose(b, perm) + else: + # As an optimization, if one of the first two dimensions is 1, then we can + # reshape instead of transpose. + # TODO(agarwal): This check can be done inside Transpose kernel. + b_shape = array_ops.shape(b) + min_dim = math_ops.minimum(b_shape[0], b_shape[1]) + perm = control_flow_ops.cond( + math_ops.equal(min_dim, 1), lambda: [0, 1, 2], lambda: [1, 0, 2]) + new_shape = array_ops.stack([b_shape[1], b_shape[0], b_shape[2]]) + b = array_ops.transpose(b, perm) + b = array_ops.reshape(b, new_shape) + + if b.shape.is_fully_defined(): + x, y, z = b.shape + else: + x, y, z = [ + array_ops.reshape(i, []) + for i in array_ops.split(array_ops.shape(b), 3) + ] + b = array_ops.reshape(b, [x, y * z]) + prod = math_ops.matmul(a, b, transpose_a=tr_a) + prod = array_ops.reshape(prod, [-1, y, z]) + prod = array_ops.transpose(prod, [1, 0, 2]) + return wrap(prod, True) + + +@RegisterPFor("BatchMatMul") +def _convert_batch_mat_mul(pfor_input): + # TODO(agarwal): There may be a more efficient way to do this instead of + # stacking the inputs. + pfor_input.stack_inputs() + x = pfor_input.stacked_input(0) + y = pfor_input.stacked_input(1) + adj_x = pfor_input.get_attr("adj_x") + adj_y = pfor_input.get_attr("adj_y") + + x = _flatten_first_two_dims(x) + y = _flatten_first_two_dims(y) + output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y) + output = _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector) + return wrap(output, True) + + +@RegisterPForWithArgs("Sum", math_ops.reduce_sum) +@RegisterPForWithArgs("Prod", math_ops.reduce_prod) +@RegisterPForWithArgs("Max", math_ops.reduce_max) +@RegisterPForWithArgs("Min", math_ops.reduce_min) +def _convert_reduction(pfor_input, _, op_func): + t = pfor_input.stacked_input(0) + indices = pfor_input.unstacked_input(1) + # Shift positive indices by one to account for the extra dimension. + indices += math_ops.cast(indices >= 0, dtypes.int32) + keep_dims = pfor_input.get_attr("keep_dims") + return wrap(op_func(t, indices, keepdims=keep_dims), True) + + +@RegisterPForWithArgs("Cumsum", math_ops.cumsum) +@RegisterPForWithArgs("Cumprod", math_ops.cumprod) +def _convert_cumfoo(pfor_input, _, op_func): + t = pfor_input.stacked_input(0) + axis = pfor_input.unstacked_input(1) + # Shift positive indices by one to account for the extra dimension. + axis += math_ops.cast(axis >= 0, dtypes.int32) + exclusive = pfor_input.get_attr("exclusive") + reverse = pfor_input.get_attr("reverse") + return wrap(op_func(t, axis, exclusive=exclusive, reverse=reverse), True) + + +@RegisterPFor("BiasAdd") +def _convert_biasadd(pfor_input): + t = pfor_input.stacked_input(0) + bias = pfor_input.unstacked_input(1) + data_format = pfor_input.get_attr("data_format") + if data_format != b"NCHW": + return wrap(nn_ops.bias_add(t, bias, data_format=data_format), True) + shape = array_ops.shape(t) + flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0) + t = array_ops.reshape(t, flattened_shape) + t = nn_ops.bias_add(t, bias, data_format=b"NCHW") + t = array_ops.reshape(t, shape) + return wrap(t, True) + + +@RegisterPFor("UnsortedSegmentSum") +def _convert_unsortedsegmentsum(pfor_input): + data, data_stacked, _ = pfor_input.input(0) + # TODO(agarwal): handle unstacked? + segment_ids = pfor_input.stacked_input(1) + # TODO(agarwal): handle stacked? + num_segments = pfor_input.unstacked_input(2) + if not data_stacked: + data = _stack(data, pfor_input.pfor.loop_len_vector).t + segment_shape = array_ops.shape(segment_ids) + n = segment_shape[0] + ones = array_ops.ones_like(segment_shape)[1:] + segment_offset = num_segments * math_ops.range(n) + segment_offset = array_ops.reshape(segment_offset, + array_ops.concat([[n], ones], axis=0)) + segment_ids += segment_offset + num_segments *= n + output = math_ops.unsorted_segment_sum(data, segment_ids, num_segments) + new_output_shape = array_ops.concat( + [[n, -1], array_ops.shape(output)[1:]], axis=0) + output = array_ops.reshape(output, new_output_shape) + return wrap(output, True) + + +@RegisterPFor("Cast") +def _convert_cast(pfor_input): + inp = pfor_input.stacked_input(0) + dtype = pfor_input.get_attr("DstT") + return wrap(math_ops.cast(inp, dtype), True) + + +# Note that ops handled here do not have attributes except "T", and hence don't +# need extra arguments passed to the cwise_op call below. +@RegisterPForWithArgs("Add", math_ops.add) +@RegisterPForWithArgs("Ceil", math_ops.ceil) +@RegisterPForWithArgs("Equal", math_ops.equal) +@RegisterPForWithArgs("NotEqual", math_ops.not_equal) +@RegisterPForWithArgs("Floor", math_ops.floor) +@RegisterPForWithArgs("Greater", math_ops.greater) +@RegisterPForWithArgs("GreaterEqual", math_ops.greater_equal) +@RegisterPForWithArgs("Less", math_ops.less) +@RegisterPForWithArgs("LessEqual", math_ops.less_equal) +@RegisterPForWithArgs("LogicalOr", math_ops.logical_or) +@RegisterPForWithArgs("LogicalAnd", math_ops.logical_and) +@RegisterPForWithArgs("LogicalNot", math_ops.logical_not) +@RegisterPForWithArgs("LogicalXor", math_ops.logical_xor) +@RegisterPForWithArgs("Maximum", math_ops.maximum) +@RegisterPForWithArgs("Minimum", math_ops.minimum) +@RegisterPForWithArgs("Mul", math_ops.multiply) +@RegisterPForWithArgs("Neg", math_ops.negative) +@RegisterPForWithArgs("RealDiv", math_ops.divide) +@RegisterPForWithArgs("Relu", nn_ops.relu) +@RegisterPForWithArgs("Sigmoid", math_ops.sigmoid) +@RegisterPForWithArgs("Square", math_ops.square) +@RegisterPForWithArgs("Sub", math_ops.subtract) +@RegisterPForWithArgs("Tanh", math_ops.tanh) +def _convert_cwise(pfor_input, op_type, op_func): + del op_type + pfor_input.expanddim_inputs_for_broadcast() + return wrap(op_func(*[x.t for x in pfor_input.inputs]), True) + + +@RegisterPFor("Shape") +def _convert_shape(pfor_input): + out_type = pfor_input.get_attr("out_type") + return wrap( + array_ops.shape(pfor_input.stacked_input(0), out_type=out_type)[1:], + False) + + +@RegisterPFor("ShapeN") +def _convert_shape_n(pfor_input): + out_type = pfor_input.get_attr("out_type") + shapes = [ + array_ops.shape(x, out_type=out_type)[1:] + if stacked else array_ops.shape(x) for x, stacked, _ in pfor_input.inputs + ] + return [wrap(x, False) for x in shapes] + + +@RegisterPFor("Size") +def _convert_size(pfor_input): + out_type = pfor_input.get_attr("out_type") + n = math_ops.cast(pfor_input.pfor.loop_len_vector[0], out_type) + return wrap( + array_ops.size(pfor_input.stacked_input(0), out_type=out_type) // n, + False) + + +@RegisterPFor("Rank") +def _convert_rank(pfor_input): + return wrap(array_ops.rank(pfor_input.stacked_input(0)) - 1, False) + + +@RegisterPFor("AddN") +def _convert_addn(pfor_input): + # AddN does not support broadcasting. + pfor_input.stack_inputs() + return wrap(math_ops.add_n([x.t for x in pfor_input.inputs]), True) + + +@RegisterPFor("BiasAddGrad") +def _convert_biasaddgrad(pfor_input): + grad = pfor_input.stacked_input(0) + fmt = pfor_input.get_attr("data_format") + if fmt == b"NCHW": + output = math_ops.reduce_sum(grad, axis=[1, 3, 4], keepdims=False) + else: + grad_shape = array_ops.shape(grad) + last_dim_shape = grad_shape[-1] + first_dim_shape = grad_shape[0] + output = array_ops.reshape(grad, [first_dim_shape, -1, last_dim_shape]) + output = math_ops.reduce_sum(output, axis=[1], keepdims=False) + return wrap(output, True) + + +# Some required ops are not exposed under the tf namespace. Hence relying on +# _create_op to create them. +@RegisterPForWithArgs("ReluGrad") +@RegisterPForWithArgs("TanhGrad") +@RegisterPForWithArgs("SigmoidGrad") +def _convert_grads(pfor_input, op_type, *args, **kw_args): + del args + del kw_args + # TODO(agarwal): Looks like these ops don't support broadcasting. Hence we + # have to use tiling here. + pfor_input.stack_inputs() + outputs = _create_op( + op_type, [x.t for x in pfor_input.inputs], + [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + return [wrap(x, True) for x in outputs] + + +@RegisterPFor("Select") +def _convert_select(pfor_input): + pfor_input.stack_inputs() + cond = pfor_input.stacked_input(0) + t = pfor_input.stacked_input(1) + e = pfor_input.stacked_input(2) + cond_rank = array_ops.rank(cond) + cond, t, e = control_flow_ops.cond( + cond_rank > 1, lambda: _inputs_with_flattening(pfor_input, [0, 1, 2]), + lambda: [cond, t, e]) + outputs = _create_op( + pfor_input.op_type, [cond, t, e], [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + n = pfor_input.pfor.loop_len_vector + out = control_flow_ops.cond(cond_rank > 1, + lambda: _unflatten_first_dim(outputs[0], n), + lambda: outputs[0]) + return [wrap(out, True) for x in outputs] + + +# random_ops + + +@RegisterPForWithArgs("RandomUniform") +@RegisterPForWithArgs("RandomUniformInt") +@RegisterPForWithArgs("RandomStandardNormal") +@RegisterPForWithArgs("TruncatedNormal") +@RegisterPForWithArgs("RandomGamma") +@RegisterPForWithArgs("RandomPoissonV2") +def _convert_random(pfor_input, op_type, *args, **kw_args): + del args + del kw_args + inputs = [pfor_input.unstacked_input(i) for i in range(pfor_input.num_inputs)] + # inputs[0] is "shape" + inputs[0] = array_ops.concat( + [pfor_input.pfor.loop_len_vector, inputs[0]], axis=0) + logging.warning( + "Note that %s inside pfor op may not give same output as " + "inside a sequential loop.", op_type) + outputs = _create_op( + op_type, + inputs, [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + return [wrap(x, True) for x in outputs] + + +# logging_ops + + +@RegisterPFor("Assert") +def _convert_assert(pfor_input): + cond, cond_stacked, _ = pfor_input.input(0) + if cond_stacked: + cond = math_ops.reduce_all(cond) + + data_list = [x.t for x in pfor_input.inputs][1:] + return _create_op("Assert", [cond] + data_list, [], + attrs=pfor_input.op.node_def.attr) + + +@RegisterPFor("Print") +def _convert_print(pfor_input): + # Note that we don't stack all the inputs. Hence unstacked values are printed + # once here vs multiple times in a while_loop. + pfor_input.stack_inputs([0]) + outputs = _create_op( + "Print", [x.t for x in pfor_input.inputs], + [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + return [wrap(x, True) for x in outputs] + + +# data_flow_ops + +# TensorArray conversion is tricky since we don't support arrays of +# TensorArrays. For converting them, we consider two distinct cases: +# +# 1. The array is constructed outside the pfor call, and read/written inside the +# loop. +# This is an easier case since we don't need to make an array of TensorArrays. +# A correctness requirement is that these parallel iterations shouldn't attempt +# to write to the same location. Hence at conversion time we disallow indices to +# be loop-invariant as that would guarantee a collision. Even if the indices are +# not loop-invariant, they could conflict and that shall trigger runtime errors. +# +# 2. The array is constructed and used entirely inside each pfor iteration. +# For simplicity, here we require that the indices used for write/scatter are +# "unstacked". Otherwise it becomes hard to merge the TensorArrays created in +# different pfor iterations. We consider two sub_cases: +# +# 2a Elements written to the array are "stacked" +# To simulate multiple TensorArrays, we may increase the dimension of each +# element of the array. i.e. the i_th row of the j_th entry of the converted +# TensorArray corresponds to the j_th entry of the TensorArray in the i_th +# pfor iteration. +# +# 2b Elements written to the array are "unstacked" +# In this case we don't increase the dimensions to avoid redundant tiling. Each +# iteration is trying to write the same value. So we convert that to a single +# write. +# +# Here are some tricks used to implement the above: +# - TensorArrayV3 constructor encodes the element shape as an attr. Instead of +# trying to trace whether future writes are stacked or unstacked in order to set +# this attr, we set it to correspond to unknown shape. +# - We use the "flow" output of the different ops to track whether the array +# elements are stacked or unstacked. If a stacked write/scatter is done, we make +# the flow stacked as well. +# - We use some heuristic traversal of the graph to track whether the +# TensorArray handle was created inside or outside the pfor loop. + + +@RegisterPFor("TensorArrayV3") +def _convert_tensor_array_v3(pfor_input): + size = pfor_input.unstacked_input(0) + dtype = pfor_input.get_attr("dtype") + dynamic_size = pfor_input.get_attr("dynamic_size") + clear_after_read = pfor_input.get_attr("clear_after_read") + identical_element_shapes = pfor_input.get_attr("identical_element_shapes") + tensor_array_name = pfor_input.get_attr("tensor_array_name") + handle, flow = data_flow_ops.tensor_array_v3( + size, + dtype=dtype, + # We don't set element shape since we don't know if writes are stacked or + # not yet. + element_shape=None, + dynamic_size=dynamic_size, + clear_after_read=clear_after_read, + identical_element_shapes=identical_element_shapes, + tensor_array_name=tensor_array_name) + # Note we keep flow unstacked for now since we don't know if writes will be + # stacked or not. + return wrap(handle, False), wrap(flow, False) + + +@RegisterPFor("TensorArraySizeV3") +def _convert_tensor_array_size_v3(pfor_input): + handle = pfor_input.unstacked_input(0) + flow, flow_stacked, _ = pfor_input.input(1) + if flow_stacked: + flow = _unstack_flow(flow) + size = data_flow_ops.tensor_array_size_v3(handle, flow) + return wrap(size, False) + + +def _handle_inside_pfor(pfor_input, handle): + """Returns True if handle was created inside the pfor loop.""" + # We use some heuristic to find the original TensorArray creation op. + # The logic should handle the common cases (except cond based subgraphs). + # In theory the user could perform different operations on the handle (like + # Reshape, stack multiple handles, etc) which could break this logic. + # TODO(agarwal): handle Switch/Merge. + while handle.op.type in ("Enter", "Identity"): + handle = handle.op.inputs[0] + if handle.op.type not in [ + "TensorArrayV3", "TensorArrayGradV3", "TensorArrayGradWithShape"]: + raise ValueError("Unable to find source for handle %s" % handle) + else: + return pfor_input.pfor.op_is_inside_loop(handle.op) + + +def _unstack_flow(value): + # TODO(agarwal): consider looking if this is a Tile op then get its input. + # This may avoid running the Tile operations. + return array_ops.gather(value, 0) + + +@RegisterPFor("TensorArrayReadV3") +def _convert_tensor_array_read_v3(pfor_input): + handle = pfor_input.unstacked_input(0) + index, index_stacked, _ = pfor_input.input(1) + dtype = pfor_input.get_attr("dtype") + flow, flow_stacked, _ = pfor_input.input(2) + if flow_stacked: + flow = _unstack_flow(flow) + + is_inside_pfor = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) + if is_inside_pfor: + # Note that if we are inside a control flow construct inside the pfor, and + # only some of the iterations are doing the read (i.e. + # `all_indices_partitioned` is True), then the read operation should only + # return values for the currently active pfor iterations (`all_indices` + # below). Hence, whenever the returned value is stacked (i.e. `flow` is + # stacked), we may need to do an extra gather after reading the values. Also + # note that if `is_inside` is false, then values in the tensor array are + # unstacked. So the check is only needed in this branch. + all_indices = pfor_input.pfor.all_indices + all_indices_partitioned = pfor_input.pfor.all_indices_partitioned + # Note: flow_stacked indicates if values in the TensorArray are stacked or + # not. + if index_stacked: + if flow_stacked: + raise ValueError( + "It looks like TensorArrayReadV3 was called on a TensorArray whose" + " values are not loop-invariant, and the read indices were also" + " not loop invariant. This is currently unsupported.") + value = data_flow_ops.tensor_array_gather_v3( + handle, index, flow, dtype=dtype) + return wrap(value, True) + value = data_flow_ops.tensor_array_read_v3( + handle, index, flow, dtype=dtype) + if flow_stacked and all_indices_partitioned: + value = array_ops.gather(value, all_indices) + return wrap(value, flow_stacked) + # Values in the TensorArray should be unstacked (since different iterations + # couldn't write to the same location). So whether output is stacked or not + # depends on index_stacked. + if index_stacked: + value = data_flow_ops.tensor_array_gather_v3( + handle, index, flow, dtype=dtype) + else: + value = data_flow_ops.tensor_array_read_v3( + handle, index, flow, dtype=dtype) + return wrap(value, index_stacked) + + +@RegisterPFor("TensorArrayWriteV3") +def _convert_tensor_array_write_v3(pfor_input): + handle = pfor_input.unstacked_input(0) + index, index_stacked, _ = pfor_input.input(1) + value, value_stacked, _ = pfor_input.input(2) + flow, flow_stacked, _ = pfor_input.input(3) + if value_stacked and pfor_input.pfor.all_indices_partitioned: + # Looks like we are in a control flow in a pfor where not all iterations are + # active now. We don't allow that since that could lead to different indices + # having different shapes which will be hard to merge later. + raise ValueError("Writing non loop invariant values to TensorArray from " + "inside a while_loop/cond not supported.") + if flow_stacked: + flow = _unstack_flow(flow) + is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) + if is_inside: + if index_stacked: + raise ValueError("Need indices for %s to be loop invariant" % handle) + if not flow_stacked and not value_stacked: + flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow) + return wrap(flow_out, False) + else: + if not value_stacked: + value = _stack(value, pfor_input.pfor.loop_len_vector).t + # TODO(agarwal): Note that if flow is unstacked and value is stacked, then + # this may or may not be a safe situation. flow is unstacked both for a + # freshly created TensorArray, as well as after unstacked values are + # written to it. If it is the latter, then we cannot write a stacked value + # now since that may cause runtime errors due to different shapes in the + # array. At the moment we are not able to handle this gracefully and + # distinguish between the two cases. That would require some heuristic + # traversal of the graph to figure out whether all the writes are + # unstacked or not. + flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow) + return _stack(flow_out, pfor_input.pfor.loop_len_vector) + else: + if not index_stacked: + raise ValueError("Need indices for %s to be not loop invariant" % handle) + # Note that even when index_stacked is true, actual values in index may + # still not be unique. However that will cause runtime error when executing + # the scatter operation below. + if not value_stacked: + value = _stack(value, pfor_input.pfor.loop_len_vector).t + flow_out = data_flow_ops.tensor_array_scatter_v3(handle, index, value, flow) + return _stack(flow_out, pfor_input.pfor.loop_len_vector) + + +def _transpose_first_two_dims(value): + # TODO(agarwal): optimize if one of the dims == 1. + value_shape = array_ops.shape(value) + v0 = value_shape[0] + v1 = value_shape[1] + value = array_ops.reshape(value, [v0, v1, -1]) + value = array_ops.transpose(value, [1, 0, 2]) + new_shape = array_ops.concat([[v1, v0], value_shape[2:]], axis=0) + return array_ops.reshape(value, new_shape) + + +@RegisterPFor("TensorArrayGatherV3") +def _convert_tensor_array_gather_v3(pfor_input): + handle = pfor_input.unstacked_input(0) + indices, indices_stacked, _ = pfor_input.input(1) + indices = array_ops.reshape(indices, [-1]) + flow, flow_stacked, _ = pfor_input.input(2) + if flow_stacked: + flow = _unstack_flow(flow) + dtype = pfor_input.get_attr("dtype") + # TODO(agarwal): support element_shape attr? + + n = pfor_input.pfor.loop_len_vector + value = data_flow_ops.tensor_array_gather_v3( + handle, indices, flow, dtype=dtype) + is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) + if is_inside: + # flow_stacked indicates if values in the TensorArray are stacked or not. + if indices_stacked: + if flow_stacked: + raise ValueError( + "It looks like TensorArrayGatherV3 was called on a TensorArray " + "whose values are not loop-invariant, and the indices were also " + "not loop invariant. This is currently unsupported.") + else: + value = _unflatten_first_dim(value, n) + return wrap(value, True) + else: + if flow_stacked: + # Since elements in this array are stacked and `value` was produced by + # gather, its first two dims are "gathered elements" and "stack + # dimension". Our semantics require these two to be flipped. + value = _transpose_first_two_dims(value) + return wrap(value, flow_stacked) + else: + # Values in the TensorArray should be unstacked (since different iterations + # couldn't write to the same location). So whether output is stacked or not + # depends on indices_stacked. + if indices_stacked: + value = _unflatten_first_dim(value, n) + return wrap(value, indices_stacked) + + +@RegisterPFor("TensorArrayScatterV3") +def _convert_tensor_array_scatter_v3(pfor_input): + handle = pfor_input.unstacked_input(0) + indices, indices_stacked, _ = pfor_input.input(1) + indices = array_ops.reshape(indices, [-1]) + value, value_stacked, _ = pfor_input.input(2) + flow, flow_stacked, _ = pfor_input.input(3) + + if flow_stacked: + flow = _unstack_flow(flow) + + is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) + if is_inside: + if indices_stacked: + raise ValueError("Need indices for %s to be loop invariant" % handle) + # Note that flow_stacked indicates if existing values in the array are + # stacked or not. + if not flow_stacked and not value_stacked: + flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, + flow) + return wrap(flow_out, False) + if not value_stacked: + # TODO(agarwal): tile in the second dimension directly instead of + # transposing below. + value = _stack(value, pfor_input.pfor.loop_len_vector).t + + value = _transpose_first_two_dims(value) + # TODO(agarwal): Note that if a previous write was unstacked, flow will be + # unstacked, and a stacked value may be written here which may cause + # runtime error due to different elements having different shape. We do + # not try to prevent that. + flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, + flow) + return _stack(flow_out, pfor_input.pfor.loop_len_vector) + if not indices_stacked: + raise ValueError("Need indices for %s to be not loop invariant" % handle) + if not value_stacked: + value = _stack(value, pfor_input.pfor.loop_len_vector).t + value = _flatten_first_two_dims(value) + flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, + flow) + return _stack(flow_out, pfor_input.pfor.loop_len_vector) + + +@RegisterPFor("TensorArrayGradV3") +def _convert_tensor_array_grad_v3(pfor_input): + handle = pfor_input.unstacked_input(0) + flow, flow_stacked, _ = pfor_input.input(1) + if flow_stacked: + flow = _unstack_flow(flow) + source = pfor_input.get_attr("source") + # TODO(agarwal): For now, we assume that gradients are stacked if the + # TensorArrayGradV3 call is being done inside the pfor. Getting that wrong + # will give runtime error due to incorrect shape being written to the + # accumulator. It is difficult to know in advance if gradients written will be + # stacked or not. Note that flow being stacked is not indicative of the + # gradient being stacked or not. Revisit this later. + shape_to_prepend = pfor_input.pfor.loop_len_vector + grad_handle, flow_out = data_flow_ops.tensor_array_grad_with_shape( + handle=handle, + flow_in=flow, + shape_to_prepend=shape_to_prepend, + source=source) + flow_out = _stack(flow_out, pfor_input.pfor.loop_len_vector).t + return [wrap(grad_handle, False), wrap(flow_out, True)] + + +# StackV2 conversion is tricky since we don't have arrays of StackV2. So similar +# to TensorArrays, we convert them by changing the dimension of the elements +# inside the stack. +# +# We consider two cases: +# +# 1. StackV2 is constructed and used entirely inside the pfor loop. +# We keep a single Stack and perform the push/pop operations of all the +# iterations in lock-step. We also assume that all the iterations perform these +# operations. In case of dynamic control flow, if only some of the iterations +# try to perform a push/pop, then the conversion may not work correctly and may +# cause undefined behavior. +# TODO(agarwal): test StackV2 with dynamic control flow. +# +# 2. StackV2 is constructed outside the pfor loop. +# Performing stack push/pop in a parallel fashion is ill-defined. However given +# that reading stacks created externally is a common operation when computing +# jacobians, we provide some special semantics here as follows. +# - disallow push operations to the stack +# - pop operations are performed in lock step by all iterations, similar to the +# case when the stack is created inside. A single value is popped during the +# lock-step operation and broadcast to all the iterations. Values in the stack +# are assumed to be loop-invariant. +# +# Some other implementation details: +# We use an ugly logic to find whether values in Stack data structure are +# loop invariant or not. When converting push/pop operations, we keep track of +# whether the last conversion used a stacked value or not (see _stack_cache +# below). As a result if an unstacked value is written first, subsequent stacked +# writes are disallowed when they could have been allowed in theory. + +# Map from cache key based on StackV2 handle to a bool indicating whether values +# are stacked or not. +# TODO(agarwal): move _stack_cache inside pfor? +_stack_cache = {} + + +def _stack_cache_key(pfor_input): + """Create cache key corresponding to a stack handle.""" + op_type = pfor_input.op_type + assert op_type in ["StackPushV2", "StackPopV2"], op_type + orig_handle = pfor_input.op.inputs[0] + while orig_handle.op.type in ["Identity", "Enter"]: + orig_handle = orig_handle.op.inputs[0] + assert orig_handle.op.type == "StackV2", orig_handle.op + return ops.get_default_graph(), pfor_input.pfor, orig_handle + + +def _stack_handle_inside_pfor(handle, pfor_input): + while handle.op.type in ["Identity", "Enter"]: + handle = handle.op.inputs[0] + assert handle.op.type == "StackV2", ( + "Unable to find StackV2 op. Got %s" % handle.op) + return pfor_input.pfor.op_is_inside_loop(handle.op) + + +@RegisterPFor("StackPushV2") +def _convert_stack_push_v2(pfor_input): + handle = pfor_input.unstacked_input(0) + elem, elem_stacked, _ = pfor_input.input(1) + swap_memory = pfor_input.get_attr("swap_memory") + + if not _stack_handle_inside_pfor(pfor_input.op.inputs[0], pfor_input): + raise ValueError("StackPushV2 not allowed on stacks created outside pfor") + stack_cache_key = _stack_cache_key(pfor_input) + stacked = _stack_cache.get(stack_cache_key, None) + if stacked is None: + stacked = elem_stacked + _stack_cache[stack_cache_key] = stacked + else: + # If we previously made it unstacked then we can't revert to being stacked. + if not stacked and elem_stacked: + raise ValueError( + "It looks like the stack was previously determined to be loop" + " invariant, but we are now trying to push a loop dependent value" + " to it. This is currently unsupported.") + if stacked and not elem_stacked: + elem = _stack(elem, pfor_input.pfor.loop_len_vector).t + out = data_flow_ops.stack_push_v2(handle, elem, swap_memory=swap_memory) + return wrap(out, stacked) + + +# Note that inputs to this convertor will be unstacked. However it should get +# called since it is a stateful op. +@RegisterPFor("StackPopV2") +def _convert_stack_pop_v2(pfor_input): + handle = pfor_input.unstacked_input(0) + stack_cache_key = _stack_cache_key(pfor_input) + stacked = _stack_cache.get(stack_cache_key, None) + # If a StackPushV2 has not been converted yet, we default to unstacked since + # the push could be outside of pfor, or the covertor may not be called if the + # inputs are unconverted. + if stacked is None: + stacked = False + _stack_cache[stack_cache_key] = False + elem_type = pfor_input.get_attr("elem_type") + out = data_flow_ops.stack_pop_v2(handle, elem_type) + return wrap(out, stacked) + + +# parsing_ops + + +@RegisterPFor("DecodeCSV") +def _convert_decode_csv(pfor_input): + lines = pfor_input.stacked_input(0) + record_defaults = [ + pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs) + ] + field_delim = pfor_input.get_attr("field_delim") + use_quote_delim = pfor_input.get_attr("use_quote_delim") + select_cols = pfor_input.get_attr("select_cols") + if not select_cols: + select_cols = None + return [ + wrap(t, True) for t in parsing_ops.decode_csv( + lines, + record_defaults, + field_delim=field_delim, + use_quote_delim=use_quote_delim, + select_cols=select_cols) + ] + + +@RegisterPFor("ParseSingleExample") +def _convert_parse_single_example(pfor_input): + serialized = pfor_input.stacked_input(0) + dense_defaults = [ + pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs) + ] + sparse_keys = pfor_input.get_attr("sparse_keys") + dense_keys = pfor_input.get_attr("dense_keys") + sparse_types = pfor_input.get_attr("sparse_types") + dense_shapes = pfor_input.get_attr("dense_shapes") + output = gen_parsing_ops.parse_example( + serialized=serialized, + names=[], + dense_defaults=dense_defaults, + sparse_keys=sparse_keys, + dense_keys=dense_keys, + sparse_types=sparse_types, + dense_shapes=dense_shapes) + return [wrap(t, True, True) for t in nest.flatten(output)] diff --git a/tensorflow/python/ops/random_grad.py b/tensorflow/python/ops/random_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..baa8e2e2cd33d37312b5b14bea3c248c06ff2e50 --- /dev/null +++ b/tensorflow/python/ops/random_grad.py @@ -0,0 +1,65 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gradients for operators defined in random_ops.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_random_ops +from tensorflow.python.ops import math_ops + + +def add_leading_unit_dimensions(x, num_dimensions): + new_shape = array_ops.concat( + [array_ops.ones([num_dimensions], dtype=dtypes.int32), + array_ops.shape(x)], axis=0) + return array_ops.reshape(x, new_shape) + + +@ops.RegisterGradient("RandomGamma") +def _RandomGammaGrad(op, grad): # pylint: disable=invalid-name + """Returns the gradient of a Gamma sample w.r.t. alpha. + + The gradient is computed using implicit differentiation, see + "Implicit Reparameterization Gradients" (https://arxiv.org/abs/1805.08498). + + Args: + op: A `RandomGamma` operation. We assume that the inputs to the operation + are `shape` and `alpha` tensors, and the output is the `sample` tensor. + grad: The incoming gradient `dloss / dsample` of the same shape as + `op.outputs[0]`. + + Returns: + A `Tensor` with derivatives `dloss / dalpha` + """ + shape = op.inputs[0] + alpha = op.inputs[1] + sample = op.outputs[0] + + with ops.control_dependencies([grad]): + # Make the parameters alpha broadcastable with samples by appending + # unit dimensions. + num_sample_dimensions = array_ops.shape(shape)[0] + alpha_broadcastable = add_leading_unit_dimensions( + alpha, num_sample_dimensions) + partial_a = gen_random_ops.random_gamma_grad(alpha_broadcastable, sample) + + # The first input is shape; the second input is alpha. + return (None, math_ops.reduce_sum( + grad * partial_a, axis=math_ops.range(num_sample_dimensions))) diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py index 6a2dd3f1cd55eea1d3b652a31cd2784c411c2ce0..4baf50638504527b474fc335ef1d57bb1a84611e 100644 --- a/tensorflow/python/ops/random_ops.py +++ b/tensorflow/python/ops/random_ops.py @@ -61,7 +61,7 @@ def random_normal(shape, dtype: The type of the output. seed: A Python integer. Used to create a random seed for the distribution. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. name: A name for the operation (optional). @@ -110,7 +110,7 @@ def parameterized_truncated_normal(shape, dtype: The type of the output. seed: A Python integer. Used to create a random seed for the distribution. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. name: A name for the operation (optional). @@ -158,7 +158,7 @@ def truncated_normal(shape, dtype: The type of the output. seed: A Python integer. Used to create a random seed for the distribution. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. name: A name for the operation (optional). @@ -212,7 +212,7 @@ def random_uniform(shape, dtype: The type of the output: `float16`, `float32`, `float64`, `int32`, or `int64`. seed: A Python integer. Used to create a random seed for the distribution. - See @{tf.set_random_seed} + See `tf.set_random_seed` for behavior. name: A name for the operation (optional). @@ -264,7 +264,7 @@ def random_shuffle(value, seed=None, name=None): value: A Tensor to be shuffled. seed: A Python integer. Used to create a random seed for the distribution. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. name: A name for the operation (optional). @@ -292,7 +292,7 @@ def random_crop(value, size, seed=None, name=None): value: Input tensor to crop. size: 1-D tensor with size the rank of `value`. seed: Python integer. Used to create a random seed. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. name: A name for this operation (optional). @@ -338,7 +338,7 @@ def multinomial(logits, num_samples, seed=None, name=None, output_dtype=None): num_samples: 0-D. Number of independent samples to draw for each row slice. seed: A Python integer. Used to create a random seed for the distribution. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. name: Optional name for the operation. output_dtype: integer type to use for the output. Defaults to int64. @@ -368,25 +368,41 @@ def random_gamma(shape, `alpha` is the shape parameter describing the distribution(s), and `beta` is the inverse scale parameter(s). - Example: + Note: Because internal calculations are done using `float64` and casting has + `floor` semantics, we must manually map zero outcomes to the smallest + possible positive floating-point value, i.e., `np.finfo(dtype).tiny`. This + means that `np.finfo(dtype).tiny` occurs more frequently than it otherwise + should. This bias can only happen for small values of `alpha`, i.e., + `alpha << 1` or large values of `beta`, i.e., `beta >> 1`. - samples = tf.random_gamma([10], [0.5, 1.5]) - # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents - # the samples drawn from each distribution + The samples are differentiable w.r.t. alpha and beta. + The derivatives are computed using the approach described in the paper - samples = tf.random_gamma([7, 5], [0.5, 1.5]) - # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1] - # represents the 7x5 samples drawn from each of the two distributions + [Michael Figurnov, Shakir Mohamed, Andriy Mnih. + Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498) - samples = tf.random_gamma([30], [[1.],[3.],[5.]], beta=[[3., 4.]]) - # samples has shape [30, 3, 2], with 30 samples each of 3x2 distributions. + Example: - Note: Because internal calculations are done using `float64` and casting has - `floor` semantics, we must manually map zero outcomes to the smallest - possible positive floating-point value, i.e., `np.finfo(dtype).tiny`. This - means that `np.finfo(dtype).tiny` occurs more frequently than it otherwise - should. This bias can only happen for small values of `alpha`, i.e., - `alpha << 1` or large values of `beta`, i.e., `beta >> 1`. + ```python + samples = tf.random_gamma([10], [0.5, 1.5]) + # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents + # the samples drawn from each distribution + + samples = tf.random_gamma([7, 5], [0.5, 1.5]) + # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1] + # represents the 7x5 samples drawn from each of the two distributions + + alpha = tf.constant([[1.],[3.],[5.]]) + beta = tf.constant([[3., 4.]]) + samples = tf.random_gamma([30], alpha=alpha, beta=beta) + # samples has shape [30, 3, 2], with 30 samples each of 3x2 distributions. + + loss = tf.reduce_mean(tf.square(samples)) + dloss_dalpha, dloss_dbeta = tf.gradients(loss, [alpha, beta]) + # unbiased stochastic derivatives of the loss function + alpha.shape == dloss_dalpha.shape # True + beta.shape == dloss_dbeta.shape # True + ``` Args: shape: A 1-D integer Tensor or Python array. The shape of the output samples @@ -401,13 +417,14 @@ def random_gamma(shape, `float64`. seed: A Python integer. Used to create a random seed for the distributions. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. name: Optional name for the operation. Returns: - samples: a `Tensor` of shape `tf.concat(shape, tf.shape(alpha + beta))` - with values of type `dtype`. + samples: a `Tensor` of shape + `tf.concat([shape, tf.shape(alpha + beta)], axis=0)` with values of type + `dtype`. """ with ops.name_scope(name, "random_gamma", [shape, alpha, beta]): shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int32) @@ -421,8 +438,6 @@ def random_gamma(shape, gen_random_ops.random_gamma( shape, alpha_broadcast, seed=seed1, seed2=seed2) / beta) -ops.NotDifferentiable("RandomGamma") - @tf_export("random_poisson") def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None): @@ -432,13 +447,15 @@ def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None): Example: - samples = tf.random_poisson([0.5, 1.5], [10]) - # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents - # the samples drawn from each distribution + ```python + samples = tf.random_poisson([0.5, 1.5], [10]) + # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents + # the samples drawn from each distribution - samples = tf.random_poisson([12.2, 3.3], [7, 5]) - # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1] - # represents the 7x5 samples drawn from each of the two distributions + samples = tf.random_poisson([12.2, 3.3], [7, 5]) + # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1] + # represents the 7x5 samples drawn from each of the two distributions + ``` Args: lam: A Tensor or Python value or N-D array of type `dtype`. @@ -450,13 +467,13 @@ def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None): `int64`. seed: A Python integer. Used to create a random seed for the distributions. See - @{tf.set_random_seed} + `tf.set_random_seed` for behavior. name: Optional name for the operation. Returns: - samples: a `Tensor` of shape `tf.concat(shape, tf.shape(lam))` with - values of type `dtype`. + samples: a `Tensor` of shape `tf.concat([shape, tf.shape(lam)], axis=0)` + with values of type `dtype`. """ with ops.name_scope(name, "random_poisson", [lam, shape]): shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int32) diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index de44a3e848586d92b3b3155edfbfcadc47755089..d533731c07288e512e50f3028f2dcb740900bcd1 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -181,7 +181,8 @@ def shape_safe_assign_variable_handle(handle, shape, value, name=None): name=name) -class ResourceVariable(variables.Variable): +# TODO(apassos) make this be variables.Variable +class ResourceVariable(variables.RefVariable): """Variable based on resource handles. See the @{$variables$Variables How To} for a high level overview. @@ -195,15 +196,16 @@ class ResourceVariable(variables.Variable): the variable are fixed. The value can be changed using one of the assign methods. - Just like any `Tensor`, variables created with `ResourceVariable()` can be - used as inputs for other Ops in the graph. Additionally, all the operators - overloaded for the `Tensor` class are carried over to variables, so you can - also add nodes to the graph by just doing arithmetic on variables. + Just like any `Tensor`, variables created with + `tf.Variable(use_resource=True)` can be used as inputs for other Ops in the + graph. Additionally, all the operators overloaded for the `Tensor` class are + carried over to variables, so you can also add nodes to the graph by just + doing arithmetic on variables. - Unlike tf.Variable, a tf.ResourceVariable has well-defined semantics. Each + Unlike ref-based variable, a ResourceVariable has well-defined semantics. Each usage of a ResourceVariable in a TensorFlow graph adds a read_value operation - to the graph. The Tensors returned by a read_value operation are guaranteed - to see all modifications to the value of the variable which happen in any + to the graph. The Tensors returned by a read_value operation are guaranteed to + see all modifications to the value of the variable which happen in any operation on which the read_value depends on (either directly, indirectly, or via a control dependency) and guaranteed to not see any modification to the value of the variable from operations that depend on the read_value operation. @@ -217,7 +219,7 @@ class ResourceVariable(variables.Variable): can cause tf.Variable and tf.ResourceVariable to behave differently: ```python - a = tf.ResourceVariable(1.0) + a = tf.Variable(1.0, use_resource=True) a.initializer.run() assign = a.assign(2.0) @@ -741,8 +743,14 @@ class ResourceVariable(variables.Variable): def _read_variable_op(self): if self.trainable: tape.watch_variable(self) - return gen_resource_variable_ops.read_variable_op(self._handle, - self._dtype) + result = gen_resource_variable_ops.read_variable_op(self._handle, + self._dtype) + if not context.executing_eagerly(): + # Note that if a control flow context is active the input of the read op + # might not actually be the handle. This line bypasses it. + tape.record_operation( + "ReadVariableOp", [result], [self._handle], lambda x: [x]) + return result def read_value(self): """Constructs an op which reads the value of this variable. @@ -851,14 +859,15 @@ class ResourceVariable(variables.Variable): operator: string. The operator name. """ + tensor_oper = getattr(ops.Tensor, operator) def _run_op(a, *args): # pylint: disable=protected-access value = a._AsTensor() - return getattr(ops.Tensor, operator)(value, *args) + return tensor_oper(value, *args) # Propagate __doc__ to wrapper try: - _run_op.__doc__ = getattr(ops.Tensor, operator).__doc__ + _run_op.__doc__ = tensor_oper.__doc__ except AttributeError: pass @@ -866,6 +875,19 @@ class ResourceVariable(variables.Variable): __array_priority__ = 100 + def is_initialized(self, name=None): + """Checks whether a resource variable has been initialized. + + Outputs boolean scalar indicating whether the tensor has been initialized. + + Args: + name: A name for the operation (optional). + + Returns: + A `Tensor` of type `bool`. + """ + return gen_resource_variable_ops.var_is_initialized_op(self.handle, name) + def assign_sub(self, delta, use_locking=None, name=None, read_value=True): """Subtracts a value from this variable. @@ -921,9 +943,10 @@ class ResourceVariable(variables.Variable): if self.trainable: tape.watch_variable(self) return _UnreadVariable( - self._handle, self.dtype, self._shape, self._in_graph_mode, - self._handle_deleter if not self._in_graph_mode else None, op, - self._unique_id) + handle=self._handle, dtype=self.dtype, shape=self._shape, + in_graph_mode=self._in_graph_mode, + deleter=self._handle_deleter if not self._in_graph_mode else None, + parent_op=op, parent_name=self._handle_name, unique_id=self._unique_id) def assign(self, value, use_locking=None, name=None, read_value=True): """Assigns a new value to this variable. @@ -998,32 +1021,28 @@ class ResourceVariable(variables.Variable): def __imul__(self, unused_other): raise RuntimeError("Variable *= value not supported. Use " - "variable.assign_mul(value) to modify the variable " - "value and variable = variable * value to get a new " - "Tensor object.") + "`var.assign(var * value)` to modify the variable or " + "`var = var * value` to get a new Tensor object.") def __idiv__(self, unused_other): raise RuntimeError("Variable /= value not supported. Use " - "variable.assign_div(value) to modify the variable " - "value and variable = variable / value to get a new " - "Tensor object.") + "`var.assign(var / value)` to modify the variable or " + "`var = var / value` to get a new Tensor object.") def __itruediv__(self, unused_other): raise RuntimeError("Variable /= value not supported. Use " - "variable.assign_div(value) to modify the variable " - "value and variable = variable / value to get a new " - "Tensor object.") + "`var.assign(var / value)` to modify the variable or " + "`var = var / value` to get a new Tensor object.") def __irealdiv__(self, unused_other): raise RuntimeError("Variable /= value not supported. Use " - "variable.assign_div(value) to modify the variable " - "value and variable = variable / value to get a new " - "Tensor object.") + "`var.assign(var / value)` to modify the variable or " + "`var = var / value` to get a new Tensor object.") def __ipow__(self, unused_other): raise RuntimeError("Variable **= value not supported. Use " - "value and variable = variable ** value to get a new " - "Tensor object.") + "`var.assign(var ** value)` to modify the variable or " + "`var = var ** value` to get a new Tensor object.") pywrap_tensorflow.TFE_Py_RegisterResourceVariableType(ResourceVariable) @@ -1041,7 +1060,8 @@ class _UnreadVariable(ResourceVariable): """ def __init__(self, handle, dtype, # pylint: disable=super-init-not-called - shape, in_graph_mode, deleter, parent_op, unique_id): + shape, in_graph_mode, deleter, parent_op, parent_name, + unique_id): # We do not call super init on purpose. self._trainable = False self._save_slice_info = None @@ -1069,7 +1089,10 @@ class _UnreadVariable(ResourceVariable): @property def name(self): - return self._parent_op.name + if self._in_graph_mode: + return self._parent_op.name + else: + return "UnreadVariable" def value(self): return self._read_variable_op() @@ -1094,6 +1117,113 @@ class _UnreadVariable(ResourceVariable): ops.register_tensor_conversion_function(_UnreadVariable, _dense_var_to_tensor) ops.register_dense_tensor_like_type(_UnreadVariable) + +class _MixedPrecisionVariable(ResourceVariable): + """Represents a variable that can return in desired dtype when read. + + In mixed precision training, it is usually desirable to use different dtypes + for variables and computation. This class will be used to wrap created + ResourceVariable when mixed precision training is enabled. It allows layers to + perform computation in a different dtype than their variable dtypes, in order + to achieve higher performance without causing quality loss. + """ + + def __init__(self, var, read_dtype): + """Creates a MixedPrecisionVariable. + + Args: + var: A ResourceVariable instance. + read_dtype: A tf.DType, the returned dtype when read, default to None. + Casting is performed if read_dtype is not None and differs from + var.dtype. + Returns: + An MixedPrecisionVariable instance. + Raises: + ValueError: if var is not a ResourceVariable instance, or read_dtype is + not a tf.DType instance. + """ + # pylint: disable=super-init-not-called + # We do not call super init on purpose. + if not isinstance(var, ResourceVariable): + raise ValueError("InvalidArgument: var must be a ResourceVariable type.") + if not isinstance(read_dtype, dtypes.DType): + raise ValueError("InvalidArgument: read_dtype must be a tf.DType type.") + + self._var = var + self._trainable = var.trainable + self._save_slice_info = None + self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access + self._in_graph_mode = var._in_graph_mode # pylint: disable=protected-access + self._handle = var.handle + self._shape = var.shape + self._initial_value = None + if isinstance(self.handle, ops.EagerTensor): + self._handle_name = "" + else: + self._handle_name = self.handle.name + self._unique_id = var._unique_id # pylint: disable=protected-access + self._dtype = var.dtype + self._constraint = None + self._cached_value = None + self._is_initialized_op = var._is_initialized_op # pylint: disable=protected-access + self._initializer_op = var._initializer_op # pylint: disable=protected-access + # This needs to be set before read_value() is called. + self._read_dtype = read_dtype + if context.executing_eagerly(): + self._graph_element = None + else: + self._graph_element = self.read_value() + self._handle_deleter = ( + var._handle_deleter if not self._in_graph_mode # pylint: disable=protected-access + else None) + # pylint: enable=super-init-not-called + + @property + def name(self): + return self._var.name + + def value(self): + return self._read_variable_op() + + def read_value(self): + return self._read_variable_op() + + def _read_variable_op(self): + with ops.colocate_with(self._handle): + res = gen_resource_variable_ops.read_variable_op(self._handle, + self._dtype) + if self._read_dtype != self._dtype: + return math_ops.cast(res, self._read_dtype) + else: + return res + + def set_shape(self, shape): + self._shape = shape + self._cached_shape_as_list = None + + @property + def op(self): + """The op for this variable.""" + return self._var.op + + @property + def read_dtype(self): + """The dtype of the returned tensor when reading the var.""" + return self._read_dtype + + def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): + del name + dtype = dtype or self.read_dtype + if dtype != self.read_dtype or as_ref: + return NotImplemented + else: + res = self.value() + return res + + def _should_act_as_resource_variable(self): + """To pass resource_variable_ops.is_resource_variable check.""" + pass + # Register a conversion function which reads the value of the variable, # allowing instances of the class to be used as tensors. diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 10d576c95bc4fd3147da44ee1522dc829bcab83d..7b6ab20975114da5599aabdd0c6e664191f5fe48 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import tensor_array_ops @@ -131,6 +132,18 @@ def _maybe_tensor_shape_from_tensor(shape): return shape +def _should_cache(): + """Returns True if a default caching device should be set, otherwise False.""" + if context.executing_eagerly(): + return False + # Don't set a caching device when running in a loop, since it is possible that + # train steps could be wrapped in a tf.while_loop. In that scenario caching + # prevents forward computations in loop iterations from re-reading the + # updated weights. + ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access + return control_flow_util.GetContainingWhileContext(ctxt) is None + + # pylint: disable=unused-argument def _rnn_step( time, sequence_length, min_sequence_length, max_sequence_length, @@ -404,24 +417,30 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, # Backward direction if not time_major: - time_dim = 1 - batch_dim = 0 + time_axis = 1 + batch_axis = 0 else: - time_dim = 0 - batch_dim = 1 + time_axis = 0 + batch_axis = 1 - def _reverse(input_, seq_lengths, seq_dim, batch_dim): + def _reverse(input_, seq_lengths, seq_axis, batch_axis): if seq_lengths is not None: return array_ops.reverse_sequence( input=input_, seq_lengths=seq_lengths, - seq_dim=seq_dim, batch_dim=batch_dim) + seq_axis=seq_axis, batch_axis=batch_axis) else: - return array_ops.reverse(input_, axis=[seq_dim]) + return array_ops.reverse(input_, axis=[seq_axis]) with vs.variable_scope("bw") as bw_scope: - inputs_reverse = _reverse( - inputs, seq_lengths=sequence_length, - seq_dim=time_dim, batch_dim=batch_dim) + + def _map_reverse(inp): + return _reverse( + inp, + seq_lengths=sequence_length, + seq_axis=time_axis, + batch_axis=batch_axis) + + inputs_reverse = nest.map_structure(_map_reverse, inputs) tmp, output_state_bw = dynamic_rnn( cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length, initial_state=initial_state_bw, dtype=dtype, @@ -430,7 +449,7 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, output_bw = _reverse( tmp, seq_lengths=sequence_length, - seq_dim=time_dim, batch_dim=batch_dim) + seq_axis=time_axis, batch_axis=batch_axis) outputs = (output_fw, output_bw) output_states = (output_state_fw, output_state_bw) @@ -558,7 +577,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, # Create a new scope in which the caching device is either # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. - if not context.executing_eagerly(): + if _should_cache(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) @@ -828,7 +847,8 @@ def _dynamic_rnn_loop(cell, final_outputs = nest.pack_sequence_as( structure=cell.output_size, flat_sequence=final_outputs) if not in_graph_mode: - final_outputs = array_ops.stack(final_outputs, axis=0) + final_outputs = nest.map_structure_up_to( + cell.output_size, lambda x: array_ops.stack(x, axis=0), final_outputs) return (final_outputs, final_state) @@ -1014,7 +1034,7 @@ def raw_rnn(cell, loop_fn, # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: - if not context.executing_eagerly(): + if _should_cache(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) @@ -1227,7 +1247,7 @@ def static_rnn(cell, # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: - if not context.executing_eagerly(): + if _should_cache(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index 05723c6960af3772d9576756ee94bd19f562edd1..85a6a2233ccb98dbde1416081aa0e954cb1ad96a 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -34,6 +34,9 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +from tensorflow.python.keras import activations +from tensorflow.python.keras import initializers +from tensorflow.python.keras.utils import tf_utils from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops @@ -48,22 +51,13 @@ from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util import nest +from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export _BIAS_VARIABLE_NAME = "bias" _WEIGHTS_VARIABLE_NAME = "kernel" - -# TODO(jblespiau): Remove this function when we are sure there are no longer -# any usage (even if protected, it is being used). Prefer assert_like_rnncell. -def _like_rnncell(cell): - """Checks that a given object is an RNNCell by using duck typing.""" - conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"), - hasattr(cell, "zero_state"), callable(cell)] - return all(conditions) - - # This can be used with self.assertRaisesRegexp for assert_like_rnncell. ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell" @@ -199,6 +193,13 @@ class RNNCell(base_layer.Layer): for each `s` in `self.batch_size`. """ + def __init__(self, trainable=True, name=None, dtype=None, **kwargs): + super(RNNCell, self).__init__( + trainable=trainable, name=name, dtype=dtype, **kwargs) + # Attribute that indicates whether the cell is a TF RNN cell, due the slight + # difference between TF and Keras RNN cell. + self._is_tf_rnn_cell = True + def __call__(self, inputs, state, scope=None): """Run this RNN cell on inputs, starting from the given state. @@ -345,7 +346,8 @@ class BasicRNNCell(LayerRNNCell): Args: num_units: int, The number of units in the RNN cell. - activation: Nonlinearity to use. Default: `tanh`. + activation: Nonlinearity to use. Default: `tanh`. It could also be string + that is within Keras activation function names. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. @@ -354,6 +356,8 @@ class BasicRNNCell(LayerRNNCell): cases. dtype: Default dtype of the layer (default of `None` means use the type of the first input). Required when `build` is called before `call`. + **kwargs: Dict, keyword named properties for common layer attributes, like + `trainable` etc when constructing the cell from configs of get_config(). """ def __init__(self, @@ -361,14 +365,19 @@ class BasicRNNCell(LayerRNNCell): activation=None, reuse=None, name=None, - dtype=None): - super(BasicRNNCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) + dtype=None, + **kwargs): + super(BasicRNNCell, self).__init__( + _reuse=reuse, name=name, dtype=dtype, **kwargs) # Inputs must be 2-dimensional. self.input_spec = base_layer.InputSpec(ndim=2) self._num_units = num_units - self._activation = activation or math_ops.tanh + if activation: + self._activation = activations.get(activation) + else: + self._activation = math_ops.tanh @property def state_size(self): @@ -378,12 +387,13 @@ class BasicRNNCell(LayerRNNCell): def output_size(self): return self._num_units + @tf_utils.shape_type_conversion def build(self, inputs_shape): - if inputs_shape[1].value is None: + if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) - input_depth = inputs_shape[1].value + input_depth = inputs_shape[-1] self._kernel = self.add_variable( _WEIGHTS_VARIABLE_NAME, shape=[input_depth + self._num_units, self._num_units]) @@ -403,6 +413,15 @@ class BasicRNNCell(LayerRNNCell): output = self._activation(gate_inputs) return output, output + def get_config(self): + config = { + "num_units": self._num_units, + "activation": activations.serialize(self._activation), + "reuse": self._reuse, + } + base_config = super(BasicRNNCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + @tf_export("nn.rnn_cell.GRUCell") class GRUCell(LayerRNNCell): @@ -422,6 +441,8 @@ class GRUCell(LayerRNNCell): cases. dtype: Default dtype of the layer (default of `None` means use the type of the first input). Required when `build` is called before `call`. + **kwargs: Dict, keyword named properties for common layer attributes, like + `trainable` etc when constructing the cell from configs of get_config(). """ def __init__(self, @@ -431,16 +452,21 @@ class GRUCell(LayerRNNCell): kernel_initializer=None, bias_initializer=None, name=None, - dtype=None): - super(GRUCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) + dtype=None, + **kwargs): + super(GRUCell, self).__init__( + _reuse=reuse, name=name, dtype=dtype, **kwargs) # Inputs must be 2-dimensional. self.input_spec = base_layer.InputSpec(ndim=2) self._num_units = num_units - self._activation = activation or math_ops.tanh - self._kernel_initializer = kernel_initializer - self._bias_initializer = bias_initializer + if activation: + self._activation = activations.get(activation) + else: + self._activation = math_ops.tanh + self._kernel_initializer = initializers.get(kernel_initializer) + self._bias_initializer = initializers.get(bias_initializer) @property def state_size(self): @@ -450,12 +476,13 @@ class GRUCell(LayerRNNCell): def output_size(self): return self._num_units + @tf_utils.shape_type_conversion def build(self, inputs_shape): - if inputs_shape[1].value is None: + if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) - input_depth = inputs_shape[1].value + input_depth = inputs_shape[-1] self._gate_kernel = self.add_variable( "gates/%s" % _WEIGHTS_VARIABLE_NAME, shape=[input_depth + self._num_units, 2 * self._num_units], @@ -501,6 +528,17 @@ class GRUCell(LayerRNNCell): new_h = u * state + (1 - u) * c return new_h, new_h + def get_config(self): + config = { + "num_units": self._num_units, + "kernel_initializer": initializers.serialize(self._kernel_initializer), + "bias_initializer": initializers.serialize(self._bias_initializer), + "activation": activations.serialize(self._activation), + "reuse": self._reuse, + } + base_config = super(GRUCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + _LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h")) @@ -525,9 +563,12 @@ class LSTMStateTuple(_LSTMStateTuple): return c.dtype +# TODO(scottzhu): Stop exporting this class in TF 2.0. @tf_export("nn.rnn_cell.BasicLSTMCell") class BasicLSTMCell(LayerRNNCell): - """Basic LSTM recurrent network cell. + """DEPRECATED: Please use @{tf.nn.rnn_cell.LSTMCell} instead. + + Basic LSTM recurrent network cell. The implementation is based on: http://arxiv.org/abs/1409.2329. @@ -537,10 +578,14 @@ class BasicLSTMCell(LayerRNNCell): It does not allow cell clipping, a projection layer, and does not use peep-hole connections: it is the basic baseline. - For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell} + For advanced models, please use the full `tf.nn.rnn_cell.LSTMCell` that follows. """ + @deprecated(None, "This class is deprecated, please use " + "tf.nn.rnn_cell.LSTMCell, which supports all the feature " + "this cell currently has. Please replace the existing code " + "with tf.nn.rnn_cell.LSTMCell(name='basic_lstm_cell').") def __init__(self, num_units, forget_bias=1.0, @@ -548,7 +593,8 @@ class BasicLSTMCell(LayerRNNCell): activation=None, reuse=None, name=None, - dtype=None): + dtype=None, + **kwargs): """Initialize the basic LSTM cell. Args: @@ -559,7 +605,8 @@ class BasicLSTMCell(LayerRNNCell): state_is_tuple: If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`. If False, they are concatenated along the column axis. The latter behavior will soon be deprecated. - activation: Activation function of the inner states. Default: `tanh`. + activation: Activation function of the inner states. Default: `tanh`. It + could also be string that is within Keras activation function names. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. @@ -568,11 +615,14 @@ class BasicLSTMCell(LayerRNNCell): cases. dtype: Default dtype of the layer (default of `None` means use the type of the first input). Required when `build` is called before `call`. + **kwargs: Dict, keyword named properties for common layer attributes, like + `trainable` etc when constructing the cell from configs of get_config(). When restoring from CudnnLSTM-trained checkpoints, must use `CudnnCompatibleLSTMCell` instead. """ - super(BasicLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) + super(BasicLSTMCell, self).__init__( + _reuse=reuse, name=name, dtype=dtype, **kwargs) if not state_is_tuple: logging.warn("%s: Using a concatenated state is slower and will soon be " "deprecated. Use state_is_tuple=True.", self) @@ -583,7 +633,10 @@ class BasicLSTMCell(LayerRNNCell): self._num_units = num_units self._forget_bias = forget_bias self._state_is_tuple = state_is_tuple - self._activation = activation or math_ops.tanh + if activation: + self._activation = activations.get(activation) + else: + self._activation = math_ops.tanh @property def state_size(self): @@ -594,12 +647,13 @@ class BasicLSTMCell(LayerRNNCell): def output_size(self): return self._num_units + @tf_utils.shape_type_conversion def build(self, inputs_shape): - if inputs_shape[1].value is None: + if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) - input_depth = inputs_shape[1].value + input_depth = inputs_shape[-1] h_depth = self._num_units self._kernel = self.add_variable( _WEIGHTS_VARIABLE_NAME, @@ -657,6 +711,17 @@ class BasicLSTMCell(LayerRNNCell): new_state = array_ops.concat([new_c, new_h], 1) return new_h, new_state + def get_config(self): + config = { + "num_units": self._num_units, + "forget_bias": self._forget_bias, + "state_is_tuple": self._state_is_tuple, + "activation": activations.serialize(self._activation), + "reuse": self._reuse, + } + base_config = super(BasicLSTMCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + @tf_export("nn.rnn_cell.LSTMCell") class LSTMCell(LayerRNNCell): @@ -686,7 +751,7 @@ class LSTMCell(LayerRNNCell): initializer=None, num_proj=None, proj_clip=None, num_unit_shards=None, num_proj_shards=None, forget_bias=1.0, state_is_tuple=True, - activation=None, reuse=None, name=None, dtype=None): + activation=None, reuse=None, name=None, dtype=None, **kwargs): """Initialize the parameters for an LSTM cell. Args: @@ -712,7 +777,8 @@ class LSTMCell(LayerRNNCell): state_is_tuple: If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`. If False, they are concatenated along the column axis. This latter behavior will soon be deprecated. - activation: Activation function of the inner states. Default: `tanh`. + activation: Activation function of the inner states. Default: `tanh`. It + could also be string that is within Keras activation function names. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. @@ -721,11 +787,14 @@ class LSTMCell(LayerRNNCell): cases. dtype: Default dtype of the layer (default of `None` means use the type of the first input). Required when `build` is called before `call`. + **kwargs: Dict, keyword named properties for common layer attributes, like + `trainable` etc when constructing the cell from configs of get_config(). When restoring from CudnnLSTM-trained checkpoints, use `CudnnCompatibleLSTMCell` instead. """ - super(LSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) + super(LSTMCell, self).__init__( + _reuse=reuse, name=name, dtype=dtype, **kwargs) if not state_is_tuple: logging.warn("%s: Using a concatenated state is slower and will soon be " "deprecated. Use state_is_tuple=True.", self) @@ -741,14 +810,17 @@ class LSTMCell(LayerRNNCell): self._num_units = num_units self._use_peepholes = use_peepholes self._cell_clip = cell_clip - self._initializer = initializer + self._initializer = initializers.get(initializer) self._num_proj = num_proj self._proj_clip = proj_clip self._num_unit_shards = num_unit_shards self._num_proj_shards = num_proj_shards self._forget_bias = forget_bias self._state_is_tuple = state_is_tuple - self._activation = activation or math_ops.tanh + if activation: + self._activation = activations.get(activation) + else: + self._activation = math_ops.tanh if num_proj: self._state_size = ( @@ -769,12 +841,13 @@ class LSTMCell(LayerRNNCell): def output_size(self): return self._output_size + @tf_utils.shape_type_conversion def build(self, inputs_shape): - if inputs_shape[1].value is None: + if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) - input_depth = inputs_shape[1].value + input_depth = inputs_shape[-1] h_depth = self._num_units if self._num_proj is None else self._num_proj maybe_partitioner = ( partitioned_variables.fixed_size_partitioner(self._num_unit_shards) @@ -888,6 +961,24 @@ class LSTMCell(LayerRNNCell): array_ops.concat([c, m], 1)) return m, new_state + def get_config(self): + config = { + "num_units": self._num_units, + "use_peepholes": self._use_peepholes, + "cell_clip": self._cell_clip, + "initializer": initializers.serialize(self._initializer), + "num_proj": self._num_proj, + "proj_clip": self._proj_clip, + "num_unit_shards": self._num_unit_shards, + "num_proj_shards": self._num_proj_shards, + "forget_bias": self._forget_bias, + "state_is_tuple": self._state_is_tuple, + "activation": activations.serialize(self._activation), + "reuse": self._reuse, + } + base_config = super(LSTMCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs): ix = [0] @@ -1271,6 +1362,11 @@ class MultiRNNCell(RNNCell): raise TypeError( "cells must be a list or tuple, but saw: %s." % cells) + if len(set([id(cell) for cell in cells])) < len(cells): + logging.log_first_n(logging.WARN, + "At least two cells provided to MultiRNNCell " + "are the same object and will share weights.", 1) + self._cells = cells for cell_number, cell in enumerate(self._cells): # Add Checkpointable dependencies on these cells so their variables get @@ -1329,48 +1425,3 @@ class MultiRNNCell(RNNCell): array_ops.concat(new_states, 1)) return cur_inp, new_states - - -class _SlimRNNCell(RNNCell, checkpointable.NotCheckpointable): - """A simple wrapper for slim.rnn_cells.""" - - def __init__(self, cell_fn): - """Create a SlimRNNCell from a cell_fn. - - Args: - cell_fn: a function which takes (inputs, state, scope) and produces the - outputs and the new_state. Additionally when called with inputs=None and - state=None it should return (initial_outputs, initial_state). - - Raises: - TypeError: if cell_fn is not callable - ValueError: if cell_fn cannot produce a valid initial state. - """ - if not callable(cell_fn): - raise TypeError("cell_fn %s needs to be callable", cell_fn) - self._cell_fn = cell_fn - self._cell_name = cell_fn.func.__name__ - init_output, init_state = self._cell_fn(None, None) - output_shape = init_output.get_shape() - state_shape = init_state.get_shape() - self._output_size = output_shape.with_rank(2)[1].value - self._state_size = state_shape.with_rank(2)[1].value - if self._output_size is None: - raise ValueError("Initial output created by %s has invalid shape %s" % - (self._cell_name, output_shape)) - if self._state_size is None: - raise ValueError("Initial state created by %s has invalid shape %s" % - (self._cell_name, state_shape)) - - @property - def state_size(self): - return self._state_size - - @property - def output_size(self): - return self._output_size - - def __call__(self, inputs, state, scope=None): - scope = scope or self._cell_name - output, state = self._cell_fn(inputs, state, scope=scope) - return output, state diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index cc23d0d133ecdb1415e0effcbc2ce52a962fb41e..d11e446dbfdc88b2ad992dca937d13e681b555be 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -34,6 +34,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_script_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.util import compat @@ -96,28 +97,27 @@ class EagerFunc(object): return constant_op.constant(0.0, dtype=dtype) return ops.convert_to_tensor(value, dtype=dtype) - def __call__(self, on_gpu, token, args): + def __call__(self, device, token, args): """Passes `args` to `self._func`, which is executed eagerly.""" - with context.eager_mode(): - with backprop.GradientTape() as tape: - for tensor in args: - tape.watch(tensor) - ret = self._func(*args) - # NB: The tape needs to watch copies across devices. - maybe_copy_to_gpu = lambda x: x if not on_gpu else x.gpu() + with context.eager_mode(), backprop.GradientTape() as tape: + for tensor in args: + tape.watch(tensor) + ret = self._func(*args) + # Use tf.identity to copy the returned tensors to device if neccesary. + with ops.device(device): if isinstance(ret, (tuple, list)): outputs = [ - maybe_copy_to_gpu(self._convert(x, dtype=dtype)) + array_ops.identity(self._convert(x, dtype=dtype)) for (x, dtype) in zip(ret, self._out_dtypes) ] elif ret is None: outputs = None else: - outputs = maybe_copy_to_gpu( + outputs = array_ops.identity( self._convert(ret, dtype=self._out_dtypes[0])) - tape_cache[compat.as_bytes(token)] = (tape, args, outputs) - return outputs + tape_cache[compat.as_bytes(token)] = (tape, args, outputs) + return outputs class FuncRegistry(object): @@ -130,7 +130,7 @@ class FuncRegistry(object): def __init__(self): self._lock = threading.Lock() self._unique_id = 0 # GUARDED_BY(self._lock) - # Only store weakrefs to the funtions. The strong reference is stored in + # Only store weakrefs to the functions. The strong reference is stored in # the graph. self._funcs = weakref.WeakValueDictionary() @@ -174,14 +174,14 @@ class FuncRegistry(object): else: return result - def __call__(self, token, on_gpu, args): + def __call__(self, token, device, args): """Calls the registered function for `token` with args. Args: token: A key into this `FuncRegistry` identifying which function to call. - on_gpu: A boolean indicating whether or not `token`'s corresponding - operation was placed on GPU; only used if the function registered for - `token` is an `EagerPyFunc`. + device: Name of the device on which outputs of `token`'s corresponding + operation should be placed. Used iff the function registered for `token` + is an EagerPyFunc. args: The arguments to pass to the function registered for `token`. Returns: @@ -201,7 +201,7 @@ class FuncRegistry(object): # or if the graph is being driven by concurrent session.run() calls. # # TODO(akshayka): Key the tape cache in a thread-safe way. - return func(on_gpu, token, args) + return func(device, token, args) else: ret = func(*args) # Strings seem to lead to a memory leak here if they're not wrapped in a @@ -232,8 +232,13 @@ _py_funcs = FuncRegistry() pywrap_tensorflow.InitializePyTrampoline(_py_funcs) -def _internal_py_func(func, inp, Tout, stateful=None, eager=False, - is_grad_func=False, name=None): +def _internal_py_func(func, + inp, + Tout, + stateful=None, + eager=False, + is_grad_func=False, + name=None): """See documentation for py_func and eager_py_func.""" is_list_or_tuple = False @@ -296,7 +301,8 @@ def _EagerPyFuncGrad(op, dy): func=eagerly_executed_grad, inp=[dy] if isinstance(dy, ops.Tensor) else dy, Tout=[tensor.dtype for tensor in op.inputs], - eager=True, is_grad_func=True) + eager=True, + is_grad_func=True) def eager_py_func(func, inp, Tout, name=None): @@ -307,8 +313,8 @@ def eager_py_func(func, inp, Tout, name=None): in a once-differentiable TensorFlow operation that executes it with eager exeuction enabled. As a consequence, `tf.contrib.eager.py_func` makes it possible to express control flow using Python constructs (`if`, `while`, - `for`, etc.), instead of TensorFlow control flow constructs (@{tf.cond}, - @{tf.while_loop}). For example, you might use `tf.contrib.eager.py_func` to + `for`, etc.), instead of TensorFlow control flow constructs (`tf.cond`, + `tf.while_loop`). For example, you might use `tf.contrib.eager.py_func` to implement the log huber function: ```python @@ -337,17 +343,17 @@ def eager_py_func(func, inp, Tout, name=None): or print statements as desired, and wrap those functions in `tf.contrib.eager.py_func`. - For more information on eager execution, see @{$programmers_guide/eager}. + For more information on eager execution, see @{$guide/eager}. - `tf.contrib.eager.py_func` is similar in spirit to @{tf.py_func}, but unlike + `tf.contrib.eager.py_func` is similar in spirit to `tf.py_func`, but unlike the latter, the former lets you use TensorFlow operations in the wrapped - Python function. In particular, while @{tf.py_func} only runs on CPUs and + Python function. In particular, while `tf.py_func` only runs on CPUs and wraps functions that take NumPy arrays as inputs and return NumPy arrays as outputs, `tf.contrib.eager.py_func` can be placed on GPUs and wraps functions that take Tensors as inputs, execute TensorFlow operations in their bodies, and return Tensors as outputs. - Like @{tf.py_func}, `tf.contrib.eager.py_func` has the following limitations + Like `tf.py_func`, `tf.contrib.eager.py_func` has the following limitations with respect to serialization and distribution: * The body of the function (i.e. `func`) will not be serialized in a diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py index 97353d6c747cb7e4d3c1fa92ad61af24fb17de91..1223b290ff6cfcfba27f40c05556c85b59e77148 100644 --- a/tensorflow/python/ops/sparse_grad.py +++ b/tensorflow/python/ops/sparse_grad.py @@ -116,6 +116,35 @@ def _SparseReduceSumGrad(op, out_grad): None, None) +@ops.RegisterGradient("SparseSlice") +def _SparseSliceGrad(op, *grads): + """The backward operator for the SparseSlice op. + + This op takes in the upstream gradient w.r.t. non-empty values of + the sliced `SparseTensor`, and outputs the gradients w.r.t. + the non-empty values of input `SparseTensor`. + + Args: + op: the SparseSlice op + *grads: the incoming gradients, one element per output of `op` + + Returns: + Gradient for each of the 5 input tensors of SparseSlice: + (indices, values, shape, start, size) + The gradients for the indices, shape, start and the size are None. + """ + backprop_val_grad = grads[1] + input_indices = op.inputs[0] + input_start = op.inputs[3] + output_indices = op.outputs[0] + + val_grad = gen_sparse_ops.sparse_slice_grad( + backprop_val_grad, input_indices, input_start, output_indices) + val_grad.set_shape(op.inputs[1].get_shape()) + # (indices, values, shape, start, size) + return (None, val_grad, None, None, None) + + @ops.RegisterGradient("SparseTensorDenseMatMul") def _SparseTensorDenseMatMulGrad(op, grad): """Gradients for the dense tensor in the SparseTensorDenseMatMul op. diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index c3b16a7bd5387e006aaea60b8814b1209ce01414..fd547dcb19bdda9c12bab1eafcfb3e392528f77a 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -777,8 +777,10 @@ def sparse_to_dense(sparse_indices, @tf_export("sparse_reduce_max") -def sparse_reduce_max(sp_input, axis=None, keep_dims=False, - reduction_axes=None): +@deprecation.deprecated_args( + None, "keep_dims is deprecated, use keepdims instead", "keep_dims") +def sparse_reduce_max(sp_input, axis=None, keepdims=None, + reduction_axes=None, keep_dims=None): """Computes the max of elements across dimensions of a SparseTensor. This Op takes a SparseTensor and is the sparse counterpart to @@ -786,8 +788,8 @@ def sparse_reduce_max(sp_input, axis=None, keep_dims=False, instead of a sparse one. Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless - `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in - `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained + `keepdims` is true, the rank of the tensor is reduced by 1 for each entry in + `reduction_axes`. If `keepdims` is true, the reduced dimensions are retained with length 1. If `reduction_axes` has no entries, all dimensions are reduced, and a tensor @@ -803,7 +805,7 @@ def sparse_reduce_max(sp_input, axis=None, keep_dims=False, tf.sparse_reduce_max(x) ==> 3 tf.sparse_reduce_max(x, 0) ==> [1, 3, 2] tf.sparse_reduce_max(x, 1) ==> [2, 3] # Can also use -1 as the axis. - tf.sparse_reduce_max(x, 1, keep_dims=True) ==> [[2], [3]] + tf.sparse_reduce_max(x, 1, keepdims=True) ==> [[2], [3]] tf.sparse_reduce_max(x, [0, 1]) ==> 3 ``` @@ -811,22 +813,31 @@ def sparse_reduce_max(sp_input, axis=None, keep_dims=False, sp_input: The SparseTensor to reduce. Should have numeric type. axis: The dimensions to reduce; list or scalar. If `None` (the default), reduces all dimensions. - keep_dims: If true, retain reduced dimensions with length 1. + keepdims: If true, retain reduced dimensions with length 1. reduction_axes: Deprecated name of axis. + keep_dims: Deprecated alias for `keepdims`. Returns: The reduced Tensor. """ + keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims, + "keep_dims", keep_dims) + if keepdims is None: + keepdims = False + return gen_sparse_ops.sparse_reduce_max( sp_input.indices, sp_input.values, sp_input.dense_shape, - math_ops._ReductionDims(sp_input, axis, reduction_axes), keep_dims) + math_ops._ReductionDims(sp_input, axis, reduction_axes), keepdims) @tf_export("sparse_reduce_max_sparse") +@deprecation.deprecated_args( + None, "keep_dims is deprecated, use keepdims instead", "keep_dims") def sparse_reduce_max_sparse(sp_input, axis=None, - keep_dims=False, - reduction_axes=None): + keepdims=None, + reduction_axes=None, + keep_dims=None): """Computes the max of elements across dimensions of a SparseTensor. This Op takes a SparseTensor and is the sparse counterpart to @@ -834,8 +845,8 @@ def sparse_reduce_max_sparse(sp_input, SparseTensor. Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless - `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in - `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained + `keepdims` is true, the rank of the tensor is reduced by 1 for each entry in + `reduction_axes`. If `keepdims` is true, the reduced dimensions are retained with length 1. If `reduction_axes` has no entries, all dimensions are reduced, and a tensor @@ -846,23 +857,31 @@ def sparse_reduce_max_sparse(sp_input, sp_input: The SparseTensor to reduce. Should have numeric type. axis: The dimensions to reduce; list or scalar. If `None` (the default), reduces all dimensions. - keep_dims: If true, retain reduced dimensions with length 1. - reduction_axes: Deprecated name of axis + keepdims: If true, retain reduced dimensions with length 1. + reduction_axes: Deprecated name of axis. + keep_dims: Deprecated alias for `keepdims`. Returns: The reduced SparseTensor. """ + keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims, + "keep_dims", keep_dims) + if keepdims is None: + keepdims = False + output_ind, output_val, output_shape = ( gen_sparse_ops.sparse_reduce_max_sparse( sp_input.indices, sp_input.values, sp_input.dense_shape, - math_ops._ReductionDims(sp_input, axis, reduction_axes), keep_dims)) + math_ops._ReductionDims(sp_input, axis, reduction_axes), keepdims)) return sparse_tensor.SparseTensor(output_ind, output_val, output_shape) @tf_export("sparse_reduce_sum") -def sparse_reduce_sum(sp_input, axis=None, keep_dims=False, - reduction_axes=None): +@deprecation.deprecated_args( + None, "keep_dims is deprecated, use keepdims instead", "keep_dims") +def sparse_reduce_sum(sp_input, axis=None, keepdims=None, + reduction_axes=None, keep_dims=None): """Computes the sum of elements across dimensions of a SparseTensor. This Op takes a SparseTensor and is the sparse counterpart to @@ -870,8 +889,8 @@ def sparse_reduce_sum(sp_input, axis=None, keep_dims=False, instead of a sparse one. Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless - `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in - `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained + `keepdims` is true, the rank of the tensor is reduced by 1 for each entry in + `reduction_axes`. If `keepdims` is true, the reduced dimensions are retained with length 1. If `reduction_axes` has no entries, all dimensions are reduced, and a tensor @@ -887,7 +906,7 @@ def sparse_reduce_sum(sp_input, axis=None, keep_dims=False, tf.sparse_reduce_sum(x) ==> 3 tf.sparse_reduce_sum(x, 0) ==> [1, 1, 1] tf.sparse_reduce_sum(x, 1) ==> [2, 1] # Can also use -1 as the axis. - tf.sparse_reduce_sum(x, 1, keep_dims=True) ==> [[2], [1]] + tf.sparse_reduce_sum(x, 1, keepdims=True) ==> [[2], [1]] tf.sparse_reduce_sum(x, [0, 1]) ==> 3 ``` @@ -895,22 +914,31 @@ def sparse_reduce_sum(sp_input, axis=None, keep_dims=False, sp_input: The SparseTensor to reduce. Should have numeric type. axis: The dimensions to reduce; list or scalar. If `None` (the default), reduces all dimensions. - keep_dims: If true, retain reduced dimensions with length 1. + keepdims: If true, retain reduced dimensions with length 1. reduction_axes: Deprecated name of axis. + keep_dims: Deprecated alias for `keepdims`. Returns: The reduced Tensor. """ + keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims, + "keep_dims", keep_dims) + if keepdims is None: + keepdims = False + return gen_sparse_ops.sparse_reduce_sum( sp_input.indices, sp_input.values, sp_input.dense_shape, - math_ops._ReductionDims(sp_input, axis, reduction_axes), keep_dims) + math_ops._ReductionDims(sp_input, axis, reduction_axes), keepdims) @tf_export("sparse_reduce_sum_sparse") +@deprecation.deprecated_args( + None, "keep_dims is deprecated, use keepdims instead", "keep_dims") def sparse_reduce_sum_sparse(sp_input, axis=None, - keep_dims=False, - reduction_axes=None): + keepdims=None, + reduction_axes=None, + keep_dims=None): """Computes the sum of elements across dimensions of a SparseTensor. This Op takes a SparseTensor and is the sparse counterpart to @@ -918,8 +946,8 @@ def sparse_reduce_sum_sparse(sp_input, SparseTensor. Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless - `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in - `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained + `keepdims` is true, the rank of the tensor is reduced by 1 for each entry in + `reduction_axes`. If `keepdims` is true, the reduced dimensions are retained with length 1. If `reduction_axes` has no entries, all dimensions are reduced, and a tensor @@ -930,16 +958,22 @@ def sparse_reduce_sum_sparse(sp_input, sp_input: The SparseTensor to reduce. Should have numeric type. axis: The dimensions to reduce; list or scalar. If `None` (the default), reduces all dimensions. - keep_dims: If true, retain reduced dimensions with length 1. - reduction_axes: Deprecated name of axis + keepdims: If true, retain reduced dimensions with length 1. + reduction_axes: Deprecated name of axis. + keep_dims: Deprecated alias for `keepdims`. Returns: The reduced SparseTensor. """ + keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims, + "keep_dims", keep_dims) + if keepdims is None: + keepdims = False + output_ind, output_val, output_shape = ( gen_sparse_ops.sparse_reduce_sum_sparse( sp_input.indices, sp_input.values, sp_input.dense_shape, - math_ops._ReductionDims(sp_input, axis, reduction_axes), keep_dims)) + math_ops._ReductionDims(sp_input, axis, reduction_axes), keepdims)) return sparse_tensor.SparseTensor(output_ind, output_val, output_shape) diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py index 1508873b751c4fa42d3488ff2d18b5795fda9652..9a10abfcf736be783bfcd7907ec6f357912828ab 100644 --- a/tensorflow/python/ops/special_math_ops.py +++ b/tensorflow/python/ops/special_math_ops.py @@ -34,7 +34,7 @@ from tensorflow.python.util.tf_export import tf_export # TODO(b/27419586) Change docstring for required dtype of x once int allowed @tf_export('lbeta') -def lbeta(x, name='lbeta'): +def lbeta(x, name=None): r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension. Given one-dimensional `z = [z_0,...,z_{K-1}]`, we define @@ -64,7 +64,7 @@ def lbeta(x, name='lbeta'): # This is consistent with a convention that the sum over the empty set 0, and # the product is 1. # This is standard. See https://en.wikipedia.org/wiki/Empty_set. - with ops.name_scope(name, values=[x]): + with ops.name_scope(name, 'lbeta', [x]): x = ops.convert_to_tensor(x, name='x') # Note reduce_sum([]) = 0. @@ -83,7 +83,7 @@ def lbeta(x, name='lbeta'): @tf_export('math.bessel_i0') -def bessel_i0(x, name='bessel_i0'): +def bessel_i0(x, name=None): """Computes the Bessel i0 function of `x` element-wise. Modified Bessel function of order 0. @@ -102,12 +102,12 @@ def bessel_i0(x, name='bessel_i0'): Equivalent to scipy.special.i0 @end_compatibility """ - with ops.name_scope(name, [x]): + with ops.name_scope(name, 'bessel_i0', [x]): return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i0e(x) @tf_export('math.bessel_i1') -def bessel_i1(x, name='bessel_i1'): +def bessel_i1(x, name=None): """Computes the Bessel i1 function of `x` element-wise. Modified Bessel function of order 1. @@ -126,7 +126,7 @@ def bessel_i1(x, name='bessel_i1'): Equivalent to scipy.special.i1 @end_compatibility """ - with ops.name_scope(name, [x]): + with ops.name_scope(name, 'bessel_i1', [x]): return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i1e(x) @@ -201,8 +201,8 @@ def einsum(equation, *inputs, **kwargs): indices in its subscript, or - the input shapes are inconsistent along a particular axis. """ - equation = equation.replace(" ", "") - + equation = equation.replace(' ', '') + name = kwargs.pop('name', None) if kwargs: raise TypeError('invalid keyword arguments for this function: ' + ', '.join( diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py index b7e164f149a9cca336fee061ae2cc3a464ca6132..9bc4098d5b63c3e8ee4f9c14332e65b3d2875d8b 100644 --- a/tensorflow/python/ops/special_math_ops_test.py +++ b/tensorflow/python/ops/special_math_ops_test.py @@ -25,24 +25,25 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import special_math_ops from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging - class LBetaTest(test.TestCase): + @test_util.run_in_graph_and_eager_modes def test_one_dimensional_arg(self): # Should evaluate to 1 and 1/2. x_one = [1, 1.] x_one_half = [2, 1.] with self.test_session(use_gpu=True): - self.assertAllClose(1, math_ops.exp(special_math_ops.lbeta(x_one)).eval()) - self.assertAllClose(0.5, - math_ops.exp( - special_math_ops.lbeta(x_one_half)).eval()) + self.assertAllClose( + 1, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one)))) + self.assertAllClose( + 0.5, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half)))) self.assertEqual([], special_math_ops.lbeta(x_one).get_shape()) def test_one_dimensional_arg_dynamic(self): @@ -53,7 +54,8 @@ class LBetaTest(test.TestCase): ph = array_ops.placeholder(dtypes.float32) beta_ph = math_ops.exp(special_math_ops.lbeta(ph)) self.assertAllClose(1, beta_ph.eval(feed_dict={ph: x_one})) - self.assertAllClose(0.5, beta_ph.eval(feed_dict={ph: x_one_half})) + self.assertAllClose(0.5, + beta_ph.eval(feed_dict={ph: x_one_half})) def test_four_dimensional_arg_with_partial_shape_dynamic(self): x_ = np.ones((3, 2, 3, 4)) @@ -66,15 +68,17 @@ class LBetaTest(test.TestCase): with self.test_session(use_gpu=True): x_ph = array_ops.placeholder(dtypes.float32, [3, 2, 3, None]) beta_ph = math_ops.exp(special_math_ops.lbeta(x_ph)) - self.assertAllClose(expected_beta_x, beta_ph.eval(feed_dict={x_ph: x_})) + self.assertAllClose(expected_beta_x, + beta_ph.eval(feed_dict={x_ph: x_})) + @test_util.run_in_graph_and_eager_modes def test_two_dimensional_arg(self): # Should evaluate to 1/2. x_one_half = [[2, 1.], [2, 1.]] with self.test_session(use_gpu=True): - self.assertAllClose([0.5, 0.5], - math_ops.exp( - special_math_ops.lbeta(x_one_half)).eval()) + self.assertAllClose( + [0.5, 0.5], + self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half)))) self.assertEqual((2,), special_math_ops.lbeta(x_one_half).get_shape()) def test_two_dimensional_arg_dynamic(self): @@ -83,50 +87,59 @@ class LBetaTest(test.TestCase): with self.test_session(use_gpu=True): ph = array_ops.placeholder(dtypes.float32) beta_ph = math_ops.exp(special_math_ops.lbeta(ph)) - self.assertAllClose([0.5, 0.5], beta_ph.eval(feed_dict={ph: x_one_half})) + self.assertAllClose([0.5, 0.5], + beta_ph.eval(feed_dict={ph: x_one_half})) + @test_util.run_in_graph_and_eager_modes def test_two_dimensional_proper_shape(self): # Should evaluate to 1/2. x_one_half = [[2, 1.], [2, 1.]] with self.test_session(use_gpu=True): - self.assertAllClose([0.5, 0.5], - math_ops.exp( - special_math_ops.lbeta(x_one_half)).eval()) + self.assertAllClose( + [0.5, 0.5], + self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half)))) self.assertEqual( (2,), - array_ops.shape(special_math_ops.lbeta(x_one_half)).eval()) + self.evaluate(array_ops.shape(special_math_ops.lbeta(x_one_half)))) self.assertEqual( tensor_shape.TensorShape([2]), special_math_ops.lbeta(x_one_half).get_shape()) + @test_util.run_in_graph_and_eager_modes def test_complicated_shape(self): with self.test_session(use_gpu=True): x = ops.convert_to_tensor(np.random.rand(3, 2, 2)) - self.assertAllEqual((3, 2), - array_ops.shape(special_math_ops.lbeta(x)).eval()) + self.assertAllEqual( + (3, 2), self.evaluate(array_ops.shape(special_math_ops.lbeta(x)))) self.assertEqual( tensor_shape.TensorShape([3, 2]), special_math_ops.lbeta(x).get_shape()) + @test_util.run_in_graph_and_eager_modes def test_length_1_last_dimension_results_in_one(self): # If there is only one coefficient, the formula still works, and we get one # as the answer, always. x_a = [5.5] x_b = [0.1] with self.test_session(use_gpu=True): - self.assertAllClose(1, math_ops.exp(special_math_ops.lbeta(x_a)).eval()) - self.assertAllClose(1, math_ops.exp(special_math_ops.lbeta(x_b)).eval()) + self.assertAllClose( + 1, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_a)))) + self.assertAllClose( + 1, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_b)))) self.assertEqual((), special_math_ops.lbeta(x_a).get_shape()) + @test_util.run_in_graph_and_eager_modes def test_empty_rank1_returns_negative_infinity(self): with self.test_session(use_gpu=True): x = constant_op.constant([], shape=[0]) lbeta_x = special_math_ops.lbeta(x) expected_result = constant_op.constant(-np.inf, shape=()) - self.assertAllEqual(expected_result.eval(), lbeta_x.eval()) + self.assertAllEqual(self.evaluate(expected_result), + self.evaluate(lbeta_x)) self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape()) + @test_util.run_in_graph_and_eager_modes def test_empty_rank2_with_zero_last_dim_returns_negative_infinity(self): with self.test_session(use_gpu=True): event_size = 0 @@ -135,9 +148,11 @@ class LBetaTest(test.TestCase): lbeta_x = special_math_ops.lbeta(x) expected_result = constant_op.constant(-np.inf, shape=[batch_size]) - self.assertAllEqual(expected_result.eval(), lbeta_x.eval()) + self.assertAllEqual(self.evaluate(expected_result), + self.evaluate(lbeta_x)) self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape()) + @test_util.run_in_graph_and_eager_modes def test_empty_rank2_with_zero_batch_dim_returns_empty(self): with self.test_session(use_gpu=True): batch_size = 0 @@ -147,12 +162,14 @@ class LBetaTest(test.TestCase): expected_result = constant_op.constant([], shape=[batch_size]) - self.assertAllEqual(expected_result.eval(), lbeta_x.eval()) + self.assertAllEqual(self.evaluate(expected_result), + self.evaluate(lbeta_x)) self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape()) class BesselTest(test.TestCase): + @test_util.run_in_graph_and_eager_modes def test_bessel_i0(self): x_single = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32) x_double = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64) @@ -165,6 +182,7 @@ class BesselTest(test.TestCase): except ImportError as e: tf_logging.warn('Cannot test special functions: %s' % str(e)) + @test_util.run_in_graph_and_eager_modes def test_bessel_i1(self): x_single = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32) x_double = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64) @@ -316,7 +334,7 @@ class EinsumTest(test.TestCase): output_tensor = special_math_ops.einsum(axes, *input_tensors) with self.test_session(use_gpu=True): - output_value = output_tensor.eval() + output_value = self.evaluate(output_tensor) correct_value = np.einsum(axes, *input_vals) diff --git a/tensorflow/python/ops/spectral_ops.py b/tensorflow/python/ops/spectral_ops.py index 28054f50ef3b1227f12376b4b3700a7618270d65..da5884e74626b493fb71c50ff040ce4fc4a97ce3 100644 --- a/tensorflow/python/ops/spectral_ops.py +++ b/tensorflow/python/ops/spectral_ops.py @@ -167,8 +167,8 @@ def _validate_dct_arguments(dct_type, n, axis, norm): raise NotImplementedError("The DCT length argument is not implemented.") if axis != -1: raise NotImplementedError("axis must be -1. Got: %s" % axis) - if dct_type != 2: - raise ValueError("Only the Type II DCT is supported.") + if dct_type not in (2, 3): + raise ValueError("Only Types II and III (I)DCT are supported.") if norm not in (None, "ortho"): raise ValueError( "Unknown normalization. Expected None or 'ortho', got: %s" % norm) @@ -179,18 +179,20 @@ def _validate_dct_arguments(dct_type, n, axis, norm): def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin """Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`. - Currently only Type II is supported. Implemented using a length `2N` padded - @{tf.spectral.rfft}, as described here: https://dsp.stackexchange.com/a/10606 + Currently only Types II and III are supported. Type II is implemented using a + length `2N` padded `tf.spectral.rfft`, as described here: + https://dsp.stackexchange.com/a/10606. Type III is a fairly straightforward + inverse of Type II (i.e. using a length `2N` padded `tf.spectral.irfft`). @compatibility(scipy) - Equivalent to scipy.fftpack.dct for the Type-II DCT. + Equivalent to scipy.fftpack.dct for Type-II and Type-III DCT. https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html @end_compatibility Args: input: A `[..., samples]` `float32` `Tensor` containing the signals to take the DCT of. - type: The DCT type to perform. Must be 2. + type: The DCT type to perform. Must be 2 or 3. n: For future expansion. The length of the transform. Must be `None`. axis: For future expansion. The axis to compute the DCT along. Must be `-1`. norm: The normalization to apply. `None` for no normalization or `'ortho'` @@ -201,8 +203,8 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl A `[..., samples]` `float32` `Tensor` containing the DCT of `input`. Raises: - ValueError: If `type` is not `2`, `n` is not `None, `axis` is not `-1`, or - `norm` is not `None` or `'ortho'`. + ValueError: If `type` is not `2` or `3`, `n` is not `None, `axis` is not + `-1`, or `norm` is not `None` or `'ortho'`. [dct]: https://en.wikipedia.org/wiki/Discrete_cosine_transform """ @@ -214,22 +216,91 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl axis_dim = input.shape[-1].value or _array_ops.shape(input)[-1] axis_dim_float = _math_ops.to_float(axis_dim) - scale = 2.0 * _math_ops.exp(_math_ops.complex( - 0.0, -_math.pi * _math_ops.range(axis_dim_float) / - (2.0 * axis_dim_float))) - - # TODO(rjryan): Benchmark performance and memory usage of the various - # approaches to computing a DCT via the RFFT. - dct2 = _math_ops.real( - rfft(input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale) - - if norm == "ortho": - n1 = 0.5 * _math_ops.rsqrt(axis_dim_float) - n2 = n1 * _math_ops.sqrt(2.0) - # Use tf.pad to make a vector of [n1, n2, n2, n2, ...]. - weights = _array_ops.pad( - _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]], - constant_values=n2) - dct2 *= weights - - return dct2 + if type == 2: + scale = 2.0 * _math_ops.exp( + _math_ops.complex( + 0.0, -_math_ops.range(axis_dim_float) * _math.pi * 0.5 / + axis_dim_float)) + + # TODO(rjryan): Benchmark performance and memory usage of the various + # approaches to computing a DCT via the RFFT. + dct2 = _math_ops.real( + rfft(input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale) + + if norm == "ortho": + n1 = 0.5 * _math_ops.rsqrt(axis_dim_float) + n2 = n1 * _math_ops.sqrt(2.0) + # Use tf.pad to make a vector of [n1, n2, n2, n2, ...]. + weights = _array_ops.pad( + _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]], + constant_values=n2) + dct2 *= weights + + return dct2 + + elif type == 3: + if norm == "ortho": + n1 = _math_ops.sqrt(axis_dim_float) + n2 = n1 * _math_ops.sqrt(0.5) + # Use tf.pad to make a vector of [n1, n2, n2, n2, ...]. + weights = _array_ops.pad( + _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]], + constant_values=n2) + input *= weights + else: + input *= axis_dim_float + scale = 2.0 * _math_ops.exp( + _math_ops.complex( + 0.0, + _math_ops.range(axis_dim_float) * _math.pi * 0.5 / + axis_dim_float)) + dct3 = _math_ops.real( + irfft( + scale * _math_ops.complex(input, 0.0), + fft_length=[2 * axis_dim]))[..., :axis_dim] + + return dct3 + + +# TODO(rjryan): Implement `type`, `n` and `axis` parameters. +@tf_export("spectral.idct") +def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin + """Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`. + + Currently only Types II and III are supported. Type III is the inverse of + Type II, and vice versa. + + Note that you must re-normalize by 1/(2n) to obtain an inverse if `norm` is + not `'ortho'`. That is: + `signal == idct(dct(signal)) * 0.5 / signal.shape[-1]`. + When `norm='ortho'`, we have: + `signal == idct(dct(signal, norm='ortho'), norm='ortho')`. + + @compatibility(scipy) + Equivalent to scipy.fftpack.idct for Type-II and Type-III DCT. + https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.idct.html + @end_compatibility + + Args: + input: A `[..., samples]` `float32` `Tensor` containing the signals to take + the DCT of. + type: The IDCT type to perform. Must be 2 or 3. + n: For future expansion. The length of the transform. Must be `None`. + axis: For future expansion. The axis to compute the DCT along. Must be `-1`. + norm: The normalization to apply. `None` for no normalization or `'ortho'` + for orthonormal normalization. + name: An optional name for the operation. + + Returns: + A `[..., samples]` `float32` `Tensor` containing the IDCT of `input`. + + Raises: + ValueError: If `type` is not `2` or `3`, `n` is not `None, `axis` is not + `-1`, or `norm` is not `None` or `'ortho'`. + + [idct]: + https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms + """ + _validate_dct_arguments(type, n, axis, norm) + inverse_type = {2: 3, 3: 2}[type] + return dct(input, type=inverse_type, n=n, axis=axis, norm=norm, name=name) diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index a2d24711e2291bafcf5736c6206ceb09ac210453..d0e5f700254fa5273cb707e59ac0d141fdc13627 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import cudnn_rnn_grad from tensorflow.python.ops import data_flow_grad from tensorflow.python.ops import manip_grad from tensorflow.python.ops import math_grad +from tensorflow.python.ops import random_grad from tensorflow.python.ops import sparse_grad from tensorflow.python.ops import spectral_grad from tensorflow.python.ops import state_grad diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index 08b7cda73bdc739912ec58f161ec7113aeffd9e8..d556d11a1b26ff3159588ca7f83c2a7e54d2d711 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -19,7 +19,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_resource_variable_ops @@ -124,9 +123,7 @@ def is_variable_initialized(ref, name=None): if ref.dtype._is_ref_dtype: return gen_state_ops.is_variable_initialized(ref=ref, name=name) # Handle resource variables. - if context.executing_eagerly() or ref.op.type == "VarHandleOp": - return gen_resource_variable_ops.var_is_initialized_op(ref.handle, - name=name) + return ref.is_initialized(name=name) @tf_export("assign_sub") @@ -332,7 +329,7 @@ def scatter_nd_update(ref, indices, updates, use_locking=True, name=None): [1, 11, 3, 10, 9, 6, 7, 12] - See @{tf.scatter_nd} for more details about how to make updates to + See `tf.scatter_nd` for more details about how to make updates to slices. Args: @@ -394,7 +391,7 @@ def scatter_add(ref, indices, updates, use_locking=False, name=None): A tensor of indices into the first dimension of `ref`. updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated values to store in `ref`. - use_locking: An optional `bool`. Defaults to `True`. + use_locking: An optional `bool`. Defaults to `False`. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. name: A name for the operation (optional). @@ -446,7 +443,7 @@ def scatter_nd_add(ref, indices, updates, use_locking=False, name=None): [1, 13, 3, 14, 14, 6, 7, 20] - See @{tf.scatter_nd} for more details about how to make updates to + See `tf.scatter_nd` for more details about how to make updates to slices. Args: @@ -458,7 +455,7 @@ def scatter_nd_add(ref, indices, updates, use_locking=False, name=None): A tensor of indices into ref. updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated values to add to ref. - use_locking: An optional `bool`. Defaults to `True`. + use_locking: An optional `bool`. Defaults to `False`. An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. @@ -473,3 +470,57 @@ def scatter_nd_add(ref, indices, updates, use_locking=False, name=None): return ref._lazy_read(gen_state_ops.resource_scatter_nd_add( # pylint: disable=protected-access ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), name=name)) + + +@tf_export("scatter_sub") +def scatter_sub(ref, indices, updates, use_locking=False, name=None): + r"""Subtracts sparse updates to a variable reference. + + ```python + # Scalar indices + ref[indices, ...] -= updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] -= updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...] + ``` + + This operation outputs `ref` after the update is done. + This makes it easier to chain operations that need to use the reset value. + + Duplicate entries are handled correctly: if multiple `indices` reference + the same location, their (negated) contributions add. + + Requires `updates.shape = indices.shape + ref.shape[1:]` or + `updates.shape = []`. + +
+ +
+ + Args: + ref: A mutable `Tensor`. Must be one of the following types: `float32`, + `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, + `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, + `uint32`, `uint64`. Should be from a `Variable` node. + indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. + A tensor of indices into the first dimension of `ref`. + updates: A `Tensor`. Must have the same type as `ref`. + A tensor of updated values to subtract from `ref`. + use_locking: An optional `bool`. Defaults to `False`. + If True, the subtraction will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. + name: A name for the operation (optional). + + Returns: + A mutable `Tensor`. Has the same type as `ref`. + """ + if ref.dtype._is_ref_dtype: + return gen_state_ops.scatter_sub(ref, indices, updates, + use_locking=use_locking, name=name) + return ref._lazy_read(gen_resource_variable_ops.resource_scatter_sub( # pylint: disable=protected-access + ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), + name=name)) diff --git a/tensorflow/python/ops/summary_op_util.py b/tensorflow/python/ops/summary_op_util.py index a793f634bda06ad43991fb978f865a2c5fe25437..b382c3b7ce57e3b07d7a6e598ef86948f3abe3a6 100644 --- a/tensorflow/python/ops/summary_op_util.py +++ b/tensorflow/python/ops/summary_op_util.py @@ -23,7 +23,7 @@ import re from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging -from tensorflow.python.training import distribute +from tensorflow.python.training import distribution_strategy_context def collect(val, collections, default_collections): @@ -49,7 +49,7 @@ def skip_summary(): # TODO(priyag): Add a new optional argument that will provide multiple # alternatives to override default behavior. (e.g. run on last tower, # compute sum or mean across towers). - tower_context = distribute.get_tower_context() + tower_context = distribution_strategy_context.get_tower_context() return tower_context and tower_context.tower_id > 0 diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py index b80f84eb7cde264c5a7c83eafacc344adb50b80a..94c7d88b5c9f13de4769a450a07031206a4020fd 100644 --- a/tensorflow/python/ops/summary_ops_v2.py +++ b/tensorflow/python/ops/summary_ops_v2.py @@ -110,8 +110,8 @@ class SummaryWriter(object): """Encapsulates a stateful summary writer resource. See also: - - @{tf.contrib.summary.create_file_writer} - - @{tf.contrib.summary.create_db_writer} + - `tf.contrib.summary.create_file_writer` + - `tf.contrib.summary.create_db_writer` """ def __init__(self, resource, init_op_fn): @@ -174,22 +174,22 @@ def initialize( """Initializes summary writing for graph execution mode. This helper method provides a higher-level alternative to using - @{tf.contrib.summary.summary_writer_initializer_op} and - @{tf.contrib.summary.graph}. + `tf.contrib.summary.summary_writer_initializer_op` and + `tf.contrib.summary.graph`. - Most users will also want to call @{tf.train.create_global_step} + Most users will also want to call `tf.train.create_global_step` which can happen before or after this function is called. Args: - graph: A @{tf.Graph} or @{tf.GraphDef} to output to the writer. + graph: A `tf.Graph` or `tf.GraphDef` to output to the writer. This function will not write the default graph by default. When writing to an event log file, the associated step will be zero. - session: So this method can call @{tf.Session.run}. This defaults - to @{tf.get_default_session}. + session: So this method can call `tf.Session.run`. This defaults + to `tf.get_default_session`. Raises: RuntimeError: If the current thread has no default - @{tf.contrib.summary.SummaryWriter}. + `tf.contrib.summary.SummaryWriter`. ValueError: If session wasn't passed and no default session. """ if context.executing_eagerly(): @@ -278,10 +278,10 @@ def create_db_writer(db_uri, Experiment will not be associated with a User. Must be valid as both a DNS label and Linux username. name: Shared name for this SummaryWriter resource stored to default - @{tf.Graph}. + `tf.Graph`. Returns: - A @{tf.contrib.summary.SummaryWriter} instance. + A `tf.contrib.summary.SummaryWriter` instance. """ with ops.device("cpu:0"): if experiment_name is None: @@ -306,10 +306,11 @@ def create_db_writer(db_uri, def _make_summary_writer(name, factory, **kwargs): resource = gen_summary_ops.summary_writer(shared_name=name) init_op_fn = lambda: factory(resource, **kwargs) - # TODO(apassos): Consider doing this instead. - # if not context.executing_eagerly(): - # ops.get_default_session().run(init_op) - ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, init_op_fn()) + init_op = init_op_fn() + if not context.executing_eagerly(): + # TODO(apassos): Consider doing this instead. + # ops.get_default_session().run(init_op) + ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, init_op) return SummaryWriter(resource, init_op_fn) @@ -327,7 +328,7 @@ def _nothing(): def all_summary_ops(): """Graph-mode only. Returns all summary ops. - Please note this excludes @{tf.contrib.summary.graph} ops. + Please note this excludes `tf.contrib.summary.graph` ops. Returns: The summary ops. @@ -380,7 +381,8 @@ def summary_writer_function(name, tensor, function, family=None): with ops.device("cpu:0"): op = smart_cond.smart_cond( should_record_summaries(), record, _nothing, name="") - ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op) # pylint: disable=protected-access + if not context.executing_eagerly(): + ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op) # pylint: disable=protected-access return op @@ -408,20 +410,20 @@ def generic(name, tensor, metadata=None, family=None, step=None): def scalar(name, tensor, family=None, step=None): """Writes a scalar summary if possible. - Unlike @{tf.contrib.summary.generic} this op may change the dtype + Unlike `tf.contrib.summary.generic` this op may change the dtype depending on the writer, for both practical and efficiency concerns. Args: name: An arbitrary name for this summary. - tensor: A @{tf.Tensor} Must be one of the following types: + tensor: A `tf.Tensor` Must be one of the following types: `float32`, `float64`, `int32`, `int64`, `uint8`, `int16`, `int8`, `uint16`, `half`, `uint32`, `uint64`. family: Optional, the summary's family. step: The `int64` monotonic step variable, which defaults - to @{tf.train.get_global_step}. + to `tf.train.get_global_step`. Returns: - The created @{tf.Operation} or a @{tf.no_op} if summary writing has + The created `tf.Operation` or a `tf.no_op` if summary writing has not been enabled for this context. """ @@ -492,31 +494,31 @@ def graph(param, step=None, name=None): """Writes a TensorFlow graph to the summary interface. The graph summary is, strictly speaking, not a summary. Conditions - like @{tf.contrib.summary.never_record_summaries} do not apply. Only + like `tf.contrib.summary.never_record_summaries` do not apply. Only a single graph can be associated with a particular run. If multiple graphs are written, then only the last one will be considered by TensorBoard. When not using eager execution mode, the user should consider passing - the `graph` parameter to @{tf.contrib.summary.initialize} instead of + the `graph` parameter to `tf.contrib.summary.initialize` instead of calling this function. Otherwise special care needs to be taken when using the graph to record the graph. Args: - param: A @{tf.Tensor} containing a serialized graph proto. When + param: A `tf.Tensor` containing a serialized graph proto. When eager execution is enabled, this function will automatically - coerce @{tf.Graph}, @{tf.GraphDef}, and string types. + coerce `tf.Graph`, `tf.GraphDef`, and string types. step: The global step variable. This doesn't have useful semantics for graph summaries, but is used anyway, due to the structure of event log files. This defaults to the global step. name: A name for the operation (optional). Returns: - The created @{tf.Operation} or a @{tf.no_op} if summary writing has + The created `tf.Operation` or a `tf.no_op` if summary writing has not been enabled for this context. Raises: - TypeError: If `param` isn't already a @{tf.Tensor} in graph mode. + TypeError: If `param` isn't already a `tf.Tensor` in graph mode. """ if not context.executing_eagerly() and not isinstance(param, ops.Tensor): raise TypeError("graph() needs a tf.Tensor (e.g. tf.placeholder) in graph " @@ -537,21 +539,21 @@ _graph = graph # for functions with a graph parameter def import_event(tensor, name=None): - """Writes a @{tf.Event} binary proto. + """Writes a `tf.Event` binary proto. When using create_db_writer(), this can be used alongside - @{tf.TFRecordReader} to load event logs into the database. Please + `tf.TFRecordReader` to load event logs into the database. Please note that this is lower level than the other summary functions and will ignore any conditions set by methods like - @{tf.contrib.summary.should_record_summaries}. + `tf.contrib.summary.should_record_summaries`. Args: - tensor: A @{tf.Tensor} of type `string` containing a serialized - @{tf.Event} proto. + tensor: A `tf.Tensor` of type `string` containing a serialized + `tf.Event` proto. name: A name for the operation (optional). Returns: - The created @{tf.Operation}. + The created `tf.Operation`. """ return gen_summary_ops.import_event( context.context().summary_writer_resource, tensor, name=name) @@ -563,13 +565,13 @@ def flush(writer=None, name=None): This operation blocks until that finishes. Args: - writer: The @{tf.contrib.summary.SummaryWriter} resource to flush. + writer: The `tf.contrib.summary.SummaryWriter` resource to flush. The thread default will be used if this parameter is None. - Otherwise a @{tf.no_op} is returned. + Otherwise a `tf.no_op` is returned. name: A name for the operation (optional). Returns: - The created @{tf.Operation}. + The created `tf.Operation`. """ if writer is None: writer = context.context().summary_writer_resource @@ -591,7 +593,7 @@ def eval_dir(model_dir, name=None): def create_summary_file_writer(*args, **kwargs): - """Please use @{tf.contrib.summary.create_file_writer}.""" + """Please use `tf.contrib.summary.create_file_writer`.""" logging.warning("Deprecation Warning: create_summary_file_writer was renamed " "to create_file_writer") return create_file_writer(*args, **kwargs) diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py index 355b0d961e2105bf19105dbc6f8a9ddfc41c0d30..e7ad261615f57c1e0ff967d0f7cd498571d21bc7 100644 --- a/tensorflow/python/ops/template.py +++ b/tensorflow/python/ops/template.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import util as checkpointable_util from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_decorator from tensorflow.python.util.deprecation import deprecated @@ -127,7 +128,7 @@ def make_template(name_, func_, create_scope_now_=False, unique_name_=None, template of the same scope/unique_name already exists and reuse is false, an error is raised. Defaults to None. custom_getter_: Optional custom getter for variables used in `func_`. See - the @{tf.get_variable} `custom_getter` documentation for + the `tf.get_variable` `custom_getter` documentation for more information. **kwargs: Keyword arguments to apply to `func_`. @@ -175,7 +176,7 @@ def make_template_internal(name_, template of the same scope/unique_name already exists and reuse is false, an error is raised. Defaults to None. If executing eagerly, must be None. custom_getter_: Optional custom getter for variables used in `func_`. See - the @{tf.get_variable} `custom_getter` documentation for + the `tf.get_variable` `custom_getter` documentation for more information. create_graph_function_: When True, `func_` will be executed as a graph function. This implies that `func_` must satisfy the properties that @@ -295,85 +296,25 @@ class Template(checkpointable.CheckpointableBase): # which is not the same as whether the scope has been created. self._variables_created = False - def _checkpointable_custom_creator(self, next_creator, name, initial_value, - checkpointable_parent=None, **kwargs): - """A variable creation hook which adds Checkpointable dependencies. - - Set during the `Template`'s first wrapped function execution. Ensures that - (a) `Template` objects depend on `Template`s created inside them which - create variables, and (b) that any variables not in a more deeply nested - `Template` are added as dependencies directly. - - The `checkpointable_parent` argument is passed between `Template` custom - creators but ignored when the variable object itself is created. This - argument indicates (if not `None`) that a more deeply nested `Template` has - already added the variable as a dependency, and that parent `Template`s - should add a dependency on that `Template` rather than on the variable - directly. - - Args: - next_creator: See `variable_scope.variable_creator_scope`; the next - creator in the chain. - name: The (full, scope-influenced) name of the variable. The scope name - for the Template itself is stripped for the purposes of object-based - dependency tracking, but scopes within Templates are respected. - initial_value: See `variable_scope.variable_creator_scope`. Taken - explicitly so the argument can be re-named and used with - `Checkpointable._add_variable_with_custom_getter`. - checkpointable_parent: If not None, a more deeply nested Template object - to add a dependency on (rather than depending on the variable directly). - **kwargs: Passed through to the next creator. - Returns: - The output of `next_creator`: the fetched/created variable object. - """ - def _call_next_creator_renaming_initializer(initializer, **inner_kwargs): - inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which - # we don't want to propagate. - return next_creator( - initial_value=initializer, - name=name, - **inner_kwargs) - if name.startswith(self._variable_scope.name): - scope_stripped_name = name[len(self._variable_scope.name) + 1:] - if not checkpointable_parent: - return self._add_variable_with_custom_getter( - initializer=initial_value, - name=scope_stripped_name, - getter=_call_next_creator_renaming_initializer, - # Disable error checking for Checkpointable. Exceptions are instead - # raised if necessary when the object-based saver tries to - # save/restore the object. - overwrite=True, - checkpointable_parent=self, - **kwargs) - else: - self._track_checkpointable( - checkpointable_parent, - name=checkpointable_parent._variable_scope.name[ # pylint: disable=protected-access - len(self._variable_scope.name) + 1:], - overwrite=True) - return next_creator(name=name, initial_value=initial_value, - checkpointable_parent=self, **kwargs) - def _call_func(self, args, kwargs): try: - vars_at_start = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) + vars_at_start = len( + ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)) trainable_at_start = len( - ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) + ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)) if self._variables_created: result = self._func(*args, **kwargs) else: # The first time we run, restore variables if necessary (via # Checkpointable). - with variable_scope.variable_creator_scope( - self._checkpointable_custom_creator): + with checkpointable_util.capture_dependencies(template=self): result = self._func(*args, **kwargs) if self._variables_created: # Variables were previously created, implying this is not the first # time the template has been called. Check to make sure that no new # trainable variables were created this time around. - trainable_variables = ops.get_collection( + trainable_variables = ops.get_collection_ref( ops.GraphKeys.TRAINABLE_VARIABLES) # If a variable that we intend to train is created as a side effect # of creating a template, then that is almost certainly an error. @@ -386,7 +327,7 @@ class Template(checkpointable.CheckpointableBase): # Non-trainable tracking variables are a legitimate reason why a new # variable would be created, but it is a relatively advanced use-case, # so log it. - variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + variables = ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES) if vars_at_start != len(variables): logging.info("New variables created when calling a template after " "the first time, perhaps you used tf.Variable when you " @@ -634,8 +575,7 @@ class EagerTemplate(Template): else: # The first time we run, restore variables if necessary (via # Checkpointable). - with variable_scope.variable_creator_scope( - self._checkpointable_custom_creator): + with checkpointable_util.capture_dependencies(template=self): result = self._func(*args, **kwargs) if self._variables_created: diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index cc92da4fd7afd49d0dd80bd859d7393f2761303f..f86dfb35276f608c5cb323fe5deceb58733be007 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -554,7 +554,7 @@ class _EagerTensorArray(object): self._tensor_array.extend([None for _ in range(index - size + 1)]) if not isinstance(value, ops.EagerTensor): - value = constant_op.constant(value) + value = ops.convert_to_tensor(value) if self._infer_shape: if self._element_shape is None: @@ -633,8 +633,8 @@ class _EagerTensorArray(object): def split(self, value, lengths, name=None): """See TensorArray.""" # error checking to match graph-mode errors - value = constant_op.constant(value) - lengths = constant_op.constant(lengths) + value = ops.convert_to_tensor(value) + lengths = ops.convert_to_tensor(lengths) sum_lengths = math_ops.reduce_sum(lengths) if lengths.shape.ndims != 1: raise errors_impl.InvalidArgumentError( diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 9a711edaa44f711aa4122ca282ca99ca23d17bdc..66122be8a330996104de31e3b6e996c20c64c942 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -1,4 +1,4 @@ - # Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -44,9 +44,11 @@ from tensorflow.python.util import function_utils from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export -__all__ = ["AUTO_REUSE", "VariableScope", "get_variable_scope", - "get_variable", "get_local_variable", "variable_scope", - "variable_op_scope", "no_regularizer"] +__all__ = [ + "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable", + "get_local_variable", "variable_scope", "variable_op_scope", + "no_regularizer", "VariableSynchronization", "VariableAggregation" +] class _PartitionInfo(object): @@ -188,6 +190,11 @@ class _ReuseMode(enum.Enum): # REUSE_FALSE = 2 # REUSE_TRUE = 3 + +# TODO(apassos) remove these forwarding symbols. +VariableSynchronization = variables.VariableSynchronization # pylint: disable=invalid-name +VariableAggregation = variables.VariableAggregation # pylint: disable=invalid-name + AUTO_REUSE = _ReuseMode.AUTO_REUSE tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE") AUTO_REUSE.__doc__ = """ @@ -214,11 +221,23 @@ class _VariableStore(object): self._partitioned_vars = {} # A dict of the stored PartitionedVariables. self._store_eager_variables = False - def get_variable(self, name, shape=None, dtype=dtypes.float32, - initializer=None, regularizer=None, reuse=None, - trainable=True, collections=None, caching_device=None, - partitioner=None, validate_shape=True, use_resource=None, - custom_getter=None, constraint=None): + def get_variable(self, + name, + shape=None, + dtype=dtypes.float32, + initializer=None, + regularizer=None, + reuse=None, + trainable=None, + collections=None, + caching_device=None, + partitioner=None, + validate_shape=True, + use_resource=None, + custom_getter=None, + constraint=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE): """Gets an existing variable with these parameters or create a new one. If a variable with the given name is already stored, we return the stored @@ -254,6 +273,8 @@ class _VariableStore(object): forced to be False. trainable: If `True` also add the variable to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + `trainable` defaults to `True` unless `synchronization` is + set to `ON_READ`. collections: List of graph collections keys to add the `Variable` to. Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`). caching_device: Optional device string or function describing where the @@ -291,6 +312,15 @@ class _VariableStore(object): variable and return the Tensor for the projected value (which must have the same shape). Constraints are not safe to use when doing asynchronous distributed training. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + `tf.VariableSynchronization`. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. If `synchronization` is set to `ON_READ`, + `trainable` must not be set to `True`. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + `tf.VariableAggregation`. Returns: The created or existing `Variable` (or `PartitionedVariable`, if a @@ -343,11 +373,22 @@ class _VariableStore(object): # it to custom_getter. # Note: the parameters of _true_getter, and their documentation, match # *exactly* item-for-item with the docstring of this method. - def _true_getter(name, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring - initializer=None, regularizer=None, reuse=None, - trainable=True, collections=None, caching_device=None, - partitioner=None, validate_shape=True, use_resource=None, - constraint=None): + def _true_getter( # pylint: disable=missing-docstring + name, + shape=None, + dtype=dtypes.float32, + initializer=None, + regularizer=None, + reuse=None, + trainable=None, + collections=None, + caching_device=None, + partitioner=None, + validate_shape=True, + use_resource=None, + constraint=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE): is_scalar = (shape is not None and isinstance(shape, collections_lib.Sequence) and not shape) @@ -397,11 +438,24 @@ class _VariableStore(object): "name was already created with partitioning?" % name) return self._get_single_variable( - name=name, shape=shape, dtype=dtype, - initializer=initializer, regularizer=regularizer, reuse=reuse, - trainable=trainable, collections=collections, - caching_device=caching_device, validate_shape=validate_shape, - use_resource=use_resource, constraint=constraint) + name=name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + reuse=reuse, + trainable=trainable, + collections=collections, + caching_device=caching_device, + validate_shape=validate_shape, + use_resource=use_resource, + constraint=constraint, + synchronization=synchronization, + aggregation=aggregation) + + # Set trainable value based on synchronization value. + trainable = _get_trainable_value( + synchronization=synchronization, trainable=trainable) if custom_getter is not None: # Handle backwards compatibility with getter arguments that were added @@ -420,6 +474,8 @@ class _VariableStore(object): "partitioner": partitioner, "validate_shape": validate_shape, "use_resource": use_resource, + "synchronization": synchronization, + "aggregation": aggregation, } # `fn_args` can handle functions, `functools.partial`, `lambda`. if "constraint" in function_utils.fn_args(custom_getter): @@ -427,18 +483,36 @@ class _VariableStore(object): return custom_getter(**custom_getter_kwargs) else: return _true_getter( - name, shape=shape, dtype=dtype, - initializer=initializer, regularizer=regularizer, - reuse=reuse, trainable=trainable, collections=collections, - caching_device=caching_device, partitioner=partitioner, - validate_shape=validate_shape, use_resource=use_resource, - constraint=constraint) - - def _get_partitioned_variable( - self, name, partitioner, shape=None, dtype=dtypes.float32, - initializer=None, regularizer=None, reuse=None, - trainable=True, collections=None, caching_device=None, - validate_shape=True, use_resource=None, constraint=None): + name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + reuse=reuse, + trainable=trainable, + collections=collections, + caching_device=caching_device, + partitioner=partitioner, + validate_shape=validate_shape, + use_resource=use_resource, + constraint=constraint, + synchronization=synchronization, + aggregation=aggregation) + + def _get_partitioned_variable(self, + name, + partitioner, + shape=None, + dtype=dtypes.float32, + initializer=None, + regularizer=None, + reuse=None, + trainable=None, + collections=None, + caching_device=None, + validate_shape=True, + use_resource=None, + constraint=None): """Gets or creates a sharded variable list with these parameters. The `partitioner` must be a callable that accepts a fully defined @@ -688,12 +762,14 @@ class _VariableStore(object): regularizer=None, partition_info=None, reuse=None, - trainable=True, + trainable=None, collections=None, caching_device=None, validate_shape=True, use_resource=None, - constraint=None): + constraint=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE): """Get or create a single Variable (e.g. a shard or entire variable). See the documentation of get_variable above (ignore partitioning components) @@ -713,6 +789,8 @@ class _VariableStore(object): validate_shape: see get_variable. use_resource: see get_variable. constraint: see get_variable. + synchronization: see get_variable. + aggregation: see get_variable. Returns: A Variable. See documentation of get_variable above. @@ -724,6 +802,7 @@ class _VariableStore(object): initializing_from_value = False if initializer is not None and not callable(initializer): initializing_from_value = True + if shape is not None and initializing_from_value: raise ValueError("If initializer is a constant, do not specify shape.") @@ -759,9 +838,6 @@ class _VariableStore(object): raise ValueError("Variable %s does not exist, or was not created with " "tf.get_variable(). Did you mean to set " "reuse=tf.AUTO_REUSE in VarScope?" % name) - if not shape.is_fully_defined() and not initializing_from_value: - raise ValueError("Shape of a new variable (%s) must be fully defined, " - "but instead was %s." % (name, shape)) # Create the tensor to initialize the variable with default value. if initializer is None: @@ -776,8 +852,11 @@ class _VariableStore(object): # Instantiate initializer if provided initializer is a type object. if isinstance(initializer, type(init_ops.Initializer)): initializer = initializer(dtype=dtype) - init_val = lambda: initializer( # pylint: disable=g-long-lambda - shape.as_list(), dtype=dtype, partition_info=partition_info) + if validate_shape: + init_val = lambda: initializer( # pylint: disable=g-long-lambda + shape.as_list(), dtype=dtype, partition_info=partition_info) + else: + init_val = initializer variable_dtype = dtype.base_dtype # Create the variable. @@ -793,7 +872,9 @@ class _VariableStore(object): dtype=variable_dtype, validate_shape=validate_shape, constraint=constraint, - use_resource=use_resource) + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation) if context.executing_eagerly() and self._store_eager_variables: if collections: ops.add_to_collections(collections, v) @@ -822,7 +903,7 @@ class _VariableStore(object): v_name = v.name loss_name = loss.name logging.vlog(1, "Applied regularizer to %s and added the result %s " - "to REGULARIZATION_LOSSES.", v_name, loss_name) + "to REGULARIZATION_LOSSES.", v_name, loss_name) ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, loss) return v @@ -1045,14 +1126,16 @@ class VariableScope(object): initializer=None, regularizer=None, reuse=None, - trainable=True, + trainable=None, collections=None, caching_device=None, partitioner=None, validate_shape=True, use_resource=None, custom_getter=None, - constraint=None): + constraint=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE): """Gets an existing variable with this name or create a new one.""" if regularizer is None: regularizer = self._regularizer @@ -1090,12 +1173,22 @@ class VariableScope(object): if dtype is None: dtype = self._dtype return var_store.get_variable( - full_name, shape=shape, dtype=dtype, initializer=initializer, - regularizer=regularizer, reuse=reuse, trainable=trainable, - collections=collections, caching_device=caching_device, - partitioner=partitioner, validate_shape=validate_shape, - use_resource=use_resource, custom_getter=custom_getter, - constraint=constraint) + full_name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + reuse=reuse, + trainable=trainable, + collections=collections, + caching_device=caching_device, + partitioner=partitioner, + validate_shape=validate_shape, + use_resource=use_resource, + custom_getter=custom_getter, + constraint=constraint, + synchronization=synchronization, + aggregation=aggregation) def _get_partitioned_variable(self, var_store, @@ -1104,7 +1197,7 @@ class VariableScope(object): dtype=None, initializer=None, regularizer=None, - trainable=True, + trainable=None, collections=None, caching_device=None, partitioner=None, @@ -1319,21 +1412,35 @@ def get_variable(name, dtype=None, initializer=None, regularizer=None, - trainable=True, + trainable=None, collections=None, caching_device=None, partitioner=None, validate_shape=True, use_resource=None, custom_getter=None, - constraint=None): + constraint=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE): return get_variable_scope().get_variable( - _get_default_variable_store(), name, shape=shape, dtype=dtype, - initializer=initializer, regularizer=regularizer, trainable=trainable, - collections=collections, caching_device=caching_device, - partitioner=partitioner, validate_shape=validate_shape, - use_resource=use_resource, custom_getter=custom_getter, - constraint=constraint) + _get_default_variable_store(), + name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + trainable=trainable, + collections=collections, + caching_device=caching_device, + partitioner=partitioner, + validate_shape=validate_shape, + use_resource=use_resource, + custom_getter=custom_getter, + constraint=constraint, + synchronization=synchronization, + aggregation=aggregation) + + get_variable_or_local_docstring = ( """%s @@ -1378,7 +1485,7 @@ Args: unless validate_shape is False. regularizer: A (Tensor -> Tensor or None) function; the result of applying it on a newly created variable will be added to the collection - @{tf.GraphKeys.REGULARIZATION_LOSSES} and can be used for regularization. + `tf.GraphKeys.REGULARIZATION_LOSSES` and can be used for regularization. %scollections: List of graph collections keys to add the Variable to. Defaults to `[%s]` (see `tf.Variable`). caching_device: Optional device string or function describing where the @@ -1430,29 +1537,44 @@ get_variable.__doc__ = get_variable_or_local_docstring % ( # The argument list for get_local_variable must match arguments to get_variable. # So, if you are updating the arguments, also update arguments to get_variable. @tf_export("get_local_variable") -def get_local_variable(name, - shape=None, - dtype=None, - initializer=None, - regularizer=None, - trainable=False, # pylint: disable=unused-argument - collections=None, - caching_device=None, - partitioner=None, - validate_shape=True, - use_resource=None, - custom_getter=None, - constraint=None): +def get_local_variable( # pylint: disable=missing-docstring + name, + shape=None, + dtype=None, + initializer=None, + regularizer=None, + trainable=False, # pylint: disable=unused-argument + collections=None, + caching_device=None, + partitioner=None, + validate_shape=True, + use_resource=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE, + custom_getter=None, + constraint=None): if collections: collections += [ops.GraphKeys.LOCAL_VARIABLES] else: collections = [ops.GraphKeys.LOCAL_VARIABLES] return get_variable( - name, shape=shape, dtype=dtype, initializer=initializer, - regularizer=regularizer, trainable=False, collections=collections, - caching_device=caching_device, partitioner=partitioner, - validate_shape=validate_shape, use_resource=use_resource, - custom_getter=custom_getter, constraint=constraint) + name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + trainable=False, + collections=collections, + caching_device=caching_device, + partitioner=partitioner, + validate_shape=validate_shape, + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation, + custom_getter=custom_getter, + constraint=constraint) + + get_local_variable.__doc__ = get_variable_or_local_docstring % ( "Gets an existing *local* variable or creates a new one.", "Behavior is the same as in `get_variable`, except that variables are\n" @@ -1925,7 +2047,8 @@ class variable_scope(object): for this scope as well as all sub-scopes; if tf.AUTO_REUSE, we create variables if they do not exist, and return them otherwise; if None, we inherit the parent scope's reuse flag. When eager execution is enabled, - this argument is always forced to be tf.AUTO_REUSE. + new variables are always created unless an EagerVariableStore or + template is currently active. dtype: type of variables created in this scope (defaults to the type in the passed scope, or inherited from parent scope). use_resource: If False, all variables will be regular Variables. If True, @@ -2201,37 +2324,64 @@ def _compute_slice_dim_and_shape(full_shape, slicing): return slice_dim, slice_shape +def _get_trainable_value(synchronization, trainable): + """Computes the trainable value based on the given arguments.""" + if synchronization == VariableSynchronization.ON_READ: + if trainable: + raise ValueError( + "Synchronization value can be set to " + "VariableSynchronization.ON_READ only for non-trainable variables. " + "You have specified trainable=True and " + "synchronization=VariableSynchronization.ON_READ.") + else: + # Set trainable to be false when variable is to be synced on read. + trainable = False + elif trainable is None: + trainable = True + return trainable + + def default_variable_creator(next_creator=None, **kwargs): """Default variable creator.""" assert next_creator is None initial_value = kwargs.get("initial_value", None) - trainable = kwargs.get("trainable", True) + trainable = kwargs.get("trainable", None) collections = kwargs.get("collections", None) validate_shape = kwargs.get("validate_shape", True) caching_device = kwargs.get("caching_device", None) name = kwargs.get("name", None) + variable_def = kwargs.get("variable_def", None) dtype = kwargs.get("dtype", None) + expected_shape = kwargs.get("expected_shape", None) + import_scope = kwargs.get("import_scope", None) constraint = kwargs.get("constraint", None) use_resource = kwargs.get("use_resource", None) + + # Set trainable value based on synchronization value. + synchronization = kwargs.get("synchronization", VariableSynchronization.AUTO) + trainable = _get_trainable_value( + synchronization=synchronization, trainable=trainable) + if use_resource is None: use_resource = get_variable_scope().use_resource - if use_resource or (use_resource is None and context.executing_eagerly()): + use_resource = use_resource or context.executing_eagerly() + if use_resource: return resource_variable_ops.ResourceVariable( initial_value=initial_value, trainable=trainable, collections=collections, validate_shape=validate_shape, caching_device=caching_device, name=name, dtype=dtype, - constraint=constraint) - elif not use_resource and context.executing_eagerly(): - raise RuntimeError( - "VariableScope should use resource variable when eager execution is" - " enabled, but use_resource is False." - ) + constraint=constraint, variable_def=variable_def, + import_scope=import_scope) else: - return variables.Variable( + return variables.RefVariable( initial_value=initial_value, trainable=trainable, collections=collections, validate_shape=validate_shape, caching_device=caching_device, name=name, dtype=dtype, - constraint=constraint) + constraint=constraint, variable_def=variable_def, + expected_shape=expected_shape, import_scope=import_scope) + + +variables.default_variable_creator = default_variable_creator def _make_getter(captured_getter, captured_previous): @@ -2239,26 +2389,8 @@ def _make_getter(captured_getter, captured_previous): return lambda **kwargs: captured_getter(captured_previous, **kwargs) -def variable(initial_value=None, - trainable=True, - collections=None, - validate_shape=True, - caching_device=None, - name=None, - dtype=None, - constraint=None, - use_resource=None): - previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs) - for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access - previous_getter = _make_getter(getter, previous_getter) - return previous_getter(initial_value=initial_value, - trainable=trainable, - collections=collections, - validate_shape=validate_shape, - caching_device=caching_device, - name=name, dtype=dtype, - constraint=constraint, - use_resource=use_resource) +# TODO(apassos) remove forwarding symbol +variable = variables.Variable @tf_contextlib.contextmanager @@ -2292,6 +2424,8 @@ def variable_creator_scope(variable_creator): trainable: If `True`, the default, also adds the variable to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default list of variables to use by the `Optimizer` classes. + `trainable` defaults to `True` unless `synchronization` is + set to `ON_READ`. collections: List of graph collections keys. The new variable is added to these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. validate_shape: If `False`, allows the variable to be initialized with a @@ -2310,6 +2444,15 @@ def variable_creator_scope(variable_creator): constraint: A constraint function to be applied to the variable after updates by some algorithms. use_resource: if True, a ResourceVariable is always created. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + `tf.VariableSynchronization`. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. If `synchronization` is set to `ON_READ`, + `trainable` must not be set to `True`. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + `tf.VariableAggregation`. This set may grow over time, so it's important the signature of creators is as mentioned above. diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 4be9f5eb6864015cd9c3f6f3526285ebbdc180f9..402ab2dd9d7f8b9caba3dc7ab58b56e41572bb55 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -17,6 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import enum # pylint: disable=g-bad-import-order + +import six + from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import variable_pb2 from tensorflow.python.eager import context @@ -36,8 +40,101 @@ from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export +def default_variable_creator(_, **kwds): + del kwds + raise NotImplementedError("variable_scope needs to be imported") + + +def _make_getter(captured_getter, captured_previous): + """To avoid capturing loop variables.""" + def getter(**kwargs): + return captured_getter(captured_previous, **kwargs) + return getter + + +@tf_export("VariableSynchronization") +class VariableSynchronization(enum.Enum): + """Indicates when a distributed variable will be synced.""" + + # Indicates that the synchronization will be determined by the current + # `DistributionStrategy` (eg. With `MirroredStrategy` this would be + # `ON_WRITE`). + AUTO = 0 + + # Indicates that there will only be one copy of the variable, so there is no + # need to sync. + NONE = 1 + + # Indicates that the variable will be aggregated across devices + # every time it is updated. + ON_WRITE = 2 + + # Indicates that the variable will be aggregated across devices + # when it is read (eg. when checkpointing or when evaluating an op that uses + # the variable). + ON_READ = 3 + + +@tf_export("VariableAggregation") +class VariableAggregation(enum.Enum): + """Indicates how a distributed variable will be aggregated.""" + NONE = 0 + SUM = 1 + MEAN = 2 + + +class VariableMetaclass(type): + """Metaclass to allow construction of tf.Variable to be overridden.""" + + def _variable_call(cls, + initial_value=None, + trainable=None, + collections=None, + validate_shape=True, + caching_device=None, + name=None, + variable_def=None, + dtype=None, + expected_shape=None, + import_scope=None, + constraint=None, + use_resource=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE): + """Call on Variable class. Useful to force the signature.""" + previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs) + for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access + previous_getter = _make_getter(getter, previous_getter) + + # Reset `aggregation` that is explicitly set as `None` to the enum NONE. + if aggregation is None: + aggregation = VariableAggregation.NONE + return previous_getter( + initial_value=initial_value, + trainable=trainable, + collections=collections, + validate_shape=validate_shape, + caching_device=caching_device, + name=name, + variable_def=variable_def, + dtype=dtype, + expected_shape=expected_shape, + import_scope=import_scope, + constraint=constraint, + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation) + + def __call__(cls, *args, **kwargs): + if cls is Variable: + return cls._variable_call(*args, **kwargs) + else: + return super(VariableMetaclass, cls).__call__(*args, **kwargs) + + @tf_export("Variable") -class Variable(checkpointable.CheckpointableBase): +class Variable(six.with_metaclass(VariableMetaclass, + checkpointable.CheckpointableBase)): """See the @{$variables$Variables How To} for a high level overview. A variable maintains state in the graph across calls to `run()`. You add a @@ -49,112 +146,663 @@ class Variable(checkpointable.CheckpointableBase): the variable are fixed. The value can be changed using one of the assign methods. - If you want to change the shape of a variable later you have to use an - `assign` Op with `validate_shape=False`. + If you want to change the shape of a variable later you have to use an + `assign` Op with `validate_shape=False`. + + Just like any `Tensor`, variables created with `Variable()` can be used as + inputs for other Ops in the graph. Additionally, all the operators + overloaded for the `Tensor` class are carried over to variables, so you can + also add nodes to the graph by just doing arithmetic on variables. + + ```python + import tensorflow as tf + + # Create a variable. + w = tf.Variable(, name=) + + # Use the variable in the graph like any Tensor. + y = tf.matmul(w, ...another variable or tensor...) + + # The overloaded operators are available too. + z = tf.sigmoid(w + y) + + # Assign a new value to the variable with `assign()` or a related method. + w.assign(w + 1.0) + w.assign_add(1.0) + ``` + + When you launch the graph, variables have to be explicitly initialized before + you can run Ops that use their value. You can initialize a variable by + running its *initializer op*, restoring the variable from a save file, or + simply running an `assign` Op that assigns a value to the variable. In fact, + the variable *initializer op* is just an `assign` Op that assigns the + variable's initial value to the variable itself. + + ```python + # Launch the graph in a session. + with tf.Session() as sess: + # Run the variable initializer. + sess.run(w.initializer) + # ...you now can run ops that use the value of 'w'... + ``` + + The most common initialization pattern is to use the convenience function + `global_variables_initializer()` to add an Op to the graph that initializes + all the variables. You then run that Op after launching the graph. + + ```python + # Add an Op to initialize global variables. + init_op = tf.global_variables_initializer() + + # Launch the graph in a session. + with tf.Session() as sess: + # Run the Op that initializes global variables. + sess.run(init_op) + # ...you can now run any Op that uses variable values... + ``` + + If you need to create a variable with an initial value dependent on another + variable, use the other variable's `initialized_value()`. This ensures that + variables are initialized in the right order. + + All variables are automatically collected in the graph where they are + created. By default, the constructor adds the new variable to the graph + collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function + `global_variables()` returns the contents of that collection. + + When building a machine learning model it is often convenient to distinguish + between variables holding the trainable model parameters and other variables + such as a `global step` variable used to count training steps. To make this + easier, the variable constructor supports a `trainable=` parameter. If + `True`, the new variable is also added to the graph collection + `GraphKeys.TRAINABLE_VARIABLES`. The convenience function + `trainable_variables()` returns the contents of this collection. The + various `Optimizer` classes use this collection as the default list of + variables to optimize. + + WARNING: tf.Variable objects by default have a non-intuitive memory model. A + Variable is represented internally as a mutable Tensor which can + non-deterministically alias other Tensors in a graph. The set of operations + which consume a Variable and can lead to aliasing is undetermined and can + change across TensorFlow versions. Avoid writing code which relies on the + value of a Variable either changing or not changing as other operations + happen. For example, using Variable objects or simple functions thereof as + predicates in a `tf.cond` is dangerous and error-prone: + + ``` + v = tf.Variable(True) + tf.cond(v, lambda: v.assign(False), my_false_fn) # Note: this is broken. + ``` + + Here replacing adding `use_resource=True` when constructing the variable will + fix any nondeterminism issues: + ``` + v = tf.Variable(True, use_resource=True) + tf.cond(v, lambda: v.assign(False), my_false_fn) + ``` + + To use the replacement for variables which does + not have these issues: + + * Add `use_resource=True` when constructing `tf.Variable`; + * Call `tf.get_variable_scope().set_use_resource(True)` inside a + `tf.variable_scope` before the `tf.get_variable()` call. + """ + + def __init__(self, + initial_value=None, + trainable=True, + collections=None, + validate_shape=True, + caching_device=None, + name=None, + variable_def=None, + dtype=None, + expected_shape=None, + import_scope=None, + constraint=None, + use_resource=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE): + """Creates a new variable with value `initial_value`. + + The new variable is added to the graph collections listed in `collections`, + which defaults to `[GraphKeys.GLOBAL_VARIABLES]`. + + If `trainable` is `True` the variable is also added to the graph collection + `GraphKeys.TRAINABLE_VARIABLES`. + + This constructor creates both a `variable` Op and an `assign` Op to set the + variable to its initial value. + + Args: + initial_value: A `Tensor`, or Python object convertible to a `Tensor`, + which is the initial value for the Variable. The initial value must have + a shape specified unless `validate_shape` is set to False. Can also be a + callable with no argument that returns the initial value when called. In + that case, `dtype` must be specified. (Note that initializer functions + from init_ops.py must first be bound to a shape before being used here.) + trainable: If `True`, the default, also adds the variable to the graph + collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as + the default list of variables to use by the `Optimizer` classes. + collections: List of graph collections keys. The new variable is added to + these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. + validate_shape: If `False`, allows the variable to be initialized with a + value of unknown shape. If `True`, the default, the shape of + `initial_value` must be known. + caching_device: Optional device string describing where the Variable + should be cached for reading. Defaults to the Variable's device. + If not `None`, caches on another device. Typical use is to cache + on the device where the Ops using the Variable reside, to deduplicate + copying through `Switch` and other conditional statements. + name: Optional name for the variable. Defaults to `'Variable'` and gets + uniquified automatically. + variable_def: `VariableDef` protocol buffer. If not `None`, recreates + the Variable object with its contents, referencing the variable's nodes + in the graph, which must already exist. The graph is not changed. + `variable_def` and the other arguments are mutually exclusive. + dtype: If set, initial_value will be converted to the given type. + If `None`, either the datatype will be kept (if `initial_value` is + a Tensor), or `convert_to_tensor` will decide. + expected_shape: A TensorShape. If set, initial_value is expected + to have this shape. + import_scope: Optional `string`. Name scope to add to the + `Variable.` Only used when initializing from protocol buffer. + constraint: An optional projection function to be applied to the variable + after being updated by an `Optimizer` (e.g. used to implement norm + constraints or value constraints for layer weights). The function must + take as input the unprojected Tensor representing the value of the + variable and return the Tensor for the projected value + (which must have the same shape). Constraints are not safe to + use when doing asynchronous distributed training. + use_resource: if True, a ResourceVariable is created; otherwise an + old-style ref-based variable is created. When eager execution is enabled + a resource variable is always created. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + `tf.VariableSynchronization`. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. If `synchronization` is set to `ON_READ`, + `trainable` must not be set to `True`. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + `tf.VariableAggregation`. + + Raises: + ValueError: If both `variable_def` and initial_value are specified. + ValueError: If the initial value is not specified, or does not have a + shape and `validate_shape` is `True`. + RuntimeError: If eager execution is enabled. + """ + raise NotImplementedError + + def __repr__(self): + raise NotImplementedError + + def value(self): + """Returns the last snapshot of this variable. + + You usually do not need to call this method as all ops that need the value + of the variable call it automatically through a `convert_to_tensor()` call. + + Returns a `Tensor` which holds the value of the variable. You can not + assign a new value to this tensor as it is not a reference to the variable. + + To avoid copies, if the consumer of the returned value is on the same device + as the variable, this actually returns the live value of the variable, not + a copy. Updates to the variable are seen by the consumer. If the consumer + is on a different device it will get a copy of the variable. + + Returns: + A `Tensor` containing the value of the variable. + """ + raise NotImplementedError + + def read_value(self): + """Returns the value of this variable, read in the current context. + + Can be different from value() if it's on another device, with control + dependencies, etc. + + Returns: + A `Tensor` containing the value of the variable. + """ + raise NotImplementedError + + def set_shape(self, shape): + """Overrides the shape for this variable. + + Args: + shape: the `TensorShape` representing the overridden shape. + """ + raise NotImplementedError + + @property + def trainable(self): + raise NotImplementedError + + def eval(self, session=None): + """In a session, computes and returns the value of this variable. + + This is not a graph construction method, it does not add ops to the graph. + + This convenience method requires a session where the graph + containing this variable has been launched. If no session is + passed, the default session is used. See `tf.Session` for more + information on launching a graph and on sessions. + + ```python + v = tf.Variable([1, 2]) + init = tf.global_variables_initializer() + + with tf.Session() as sess: + sess.run(init) + # Usage passing the session explicitly. + print(v.eval(sess)) + # Usage with the default session. The 'with' block + # above makes 'sess' the default session. + print(v.eval()) + ``` + + Args: + session: The session to use to evaluate this variable. If + none, the default session is used. + + Returns: + A numpy `ndarray` with a copy of the value of this variable. + """ + raise NotImplementedError + + def initialized_value(self): + """Returns the value of the initialized variable. + + You should use this instead of the variable itself to initialize another + variable with a value that depends on the value of this variable. + + ```python + # Initialize 'v' with a random tensor. + v = tf.Variable(tf.truncated_normal([10, 40])) + # Use `initialized_value` to guarantee that `v` has been + # initialized before its value is used to initialize `w`. + # The random values are picked only once. + w = tf.Variable(v.initialized_value() * 2.0) + ``` + + Returns: + A `Tensor` holding the value of this variable after its initializer + has run. + """ + raise NotImplementedError + + @property + def initial_value(self): + """Returns the Tensor used as the initial value for the variable. + + Note that this is different from `initialized_value()` which runs + the op that initializes the variable before returning its value. + This method returns the tensor that is used by the op that initializes + the variable. + + Returns: + A `Tensor`. + """ + raise NotImplementedError + + @property + def constraint(self): + """Returns the constraint function associated with this variable. + + Returns: + The constraint function that was passed to the variable constructor. + Can be `None` if no constraint was passed. + """ + raise NotImplementedError + + def assign(self, value, use_locking=False): + """Assigns a new value to the variable. + + This is essentially a shortcut for `assign(self, value)`. + + Args: + value: A `Tensor`. The new value for this variable. + use_locking: If `True`, use locking during the assignment. + + Returns: + A `Tensor` that will hold the new value of this variable after + the assignment has completed. + """ + raise NotImplementedError + + def assign_add(self, delta, use_locking=False): + """Adds a value to this variable. + + This is essentially a shortcut for `assign_add(self, delta)`. + + Args: + delta: A `Tensor`. The value to add to this variable. + use_locking: If `True`, use locking during the operation. + + Returns: + A `Tensor` that will hold the new value of this variable after + the addition has completed. + """ + raise NotImplementedError + + def assign_sub(self, delta, use_locking=False): + """Subtracts a value from this variable. + + This is essentially a shortcut for `assign_sub(self, delta)`. + + Args: + delta: A `Tensor`. The value to subtract from this variable. + use_locking: If `True`, use locking during the operation. + + Returns: + A `Tensor` that will hold the new value of this variable after + the subtraction has completed. + """ + raise NotImplementedError + + def scatter_sub(self, sparse_delta, use_locking=False): + """Subtracts `IndexedSlices` from this variable. + + This is essentially a shortcut for `scatter_sub(self, sparse_delta.indices, + sparse_delta.values)`. + + Args: + sparse_delta: `IndexedSlices` to be subtracted from this variable. + use_locking: If `True`, use locking during the operation. + + Returns: + A `Tensor` that will hold the new value of this variable after + the scattered subtraction has completed. + + Raises: + ValueError: if `sparse_delta` is not an `IndexedSlices`. + """ + raise NotImplementedError + + def count_up_to(self, limit): + """Increments this variable until it reaches `limit`. + + When that Op is run it tries to increment the variable by `1`. If + incrementing the variable would bring it above `limit` then the Op raises + the exception `OutOfRangeError`. + + If no error is raised, the Op outputs the value of the variable before + the increment. + + This is essentially a shortcut for `count_up_to(self, limit)`. + + Args: + limit: value at which incrementing the variable raises an error. + + Returns: + A `Tensor` that will hold the variable value before the increment. If no + other Op modifies this variable, the values produced will all be + distinct. + """ + raise NotImplementedError + + def load(self, value, session=None): + """Load new value into this variable. + + Writes new value to variable's memory. Doesn't add ops to the graph. + + This convenience method requires a session where the graph + containing this variable has been launched. If no session is + passed, the default session is used. See `tf.Session` for more + information on launching a graph and on sessions. + + ```python + v = tf.Variable([1, 2]) + init = tf.global_variables_initializer() + + with tf.Session() as sess: + sess.run(init) + # Usage passing the session explicitly. + v.load([2, 3], sess) + print(v.eval(sess)) # prints [2 3] + # Usage with the default session. The 'with' block + # above makes 'sess' the default session. + v.load([3, 4], sess) + print(v.eval()) # prints [3 4] + ``` + + Args: + value: New variable value + session: The session to use to evaluate this variable. If + none, the default session is used. + + Raises: + ValueError: Session is not passed and no default session + """ + raise NotImplementedError + + # Conversion to tensor. + @staticmethod + def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): # pylint: disable=invalid-name + """Utility function for converting a Variable to a Tensor.""" + _ = name + if dtype and not dtype.is_compatible_with(v.dtype): + raise ValueError( + "Incompatible type conversion requested to type '%s' for variable " + "of type '%s'" % (dtype.name, v.dtype.name)) + if as_ref: + return v._ref() # pylint: disable=protected-access + else: + return v.value() + + @staticmethod + def _OverloadAllOperators(): # pylint: disable=invalid-name + """Register overloads for all operators.""" + for operator in ops.Tensor.OVERLOADABLE_OPERATORS: + Variable._OverloadOperator(operator) + # For slicing, bind getitem differently than a tensor (use SliceHelperVar + # instead) + # pylint: disable=protected-access + setattr(Variable, "__getitem__", array_ops._SliceHelperVar) + + @staticmethod + def _OverloadOperator(operator): # pylint: disable=invalid-name + """Defer an operator overload to `ops.Tensor`. + + We pull the operator out of ops.Tensor dynamically to avoid ordering issues. + + Args: + operator: string. The operator name. + """ + + def _run_op(a, *args): + # pylint: disable=protected-access + return getattr(ops.Tensor, operator)(a._AsTensor(), *args) + # Propagate __doc__ to wrapper + try: + _run_op.__doc__ = getattr(ops.Tensor, operator).__doc__ + except AttributeError: + pass + + setattr(Variable, operator, _run_op) + + # NOTE(mrry): This enables the Variable's overloaded "right" binary + # operators to run when the left operand is an ndarray, because it + # accords the Variable class higher priority than an ndarray, or a + # numpy matrix. + # TODO(mrry): Convert this to using numpy's __numpy_ufunc__ + # mechanism, which allows more control over how Variables interact + # with ndarrays. + __array_priority__ = 100 + + @property + def name(self): + """The name of this variable.""" + raise NotImplementedError + + @property + def initializer(self): + """The initializer operation for this variable.""" + raise NotImplementedError + + @property + def device(self): + """The device of this variable.""" + raise NotImplementedError + + @property + def dtype(self): + """The `DType` of this variable.""" + raise NotImplementedError + + @property + def op(self): + """The `Operation` of this variable.""" + raise NotImplementedError + + @property + def graph(self): + """The `Graph` of this variable.""" + raise NotImplementedError + + @property + def shape(self): + """The `TensorShape` of this variable. + + Returns: + A `TensorShape`. + """ + raise NotImplementedError + + def get_shape(self): + """Alias of Variable.shape.""" + raise NotImplementedError + + def to_proto(self, export_scope=None): + """Converts a `Variable` to a `VariableDef` protocol buffer. + + Args: + export_scope: Optional `string`. Name scope to remove. + + Returns: + A `VariableDef` protocol buffer, or `None` if the `Variable` is not + in the specified name scope. + """ + raise NotImplementedError - Just like any `Tensor`, variables created with `Variable()` can be used as - inputs for other Ops in the graph. Additionally, all the operators - overloaded for the `Tensor` class are carried over to variables, so you can - also add nodes to the graph by just doing arithmetic on variables. + @staticmethod + def from_proto(variable_def, import_scope=None): + """Returns a `Variable` object created from `variable_def`.""" + return RefVariable(variable_def=variable_def, + import_scope=import_scope) - ```python - import tensorflow as tf + class SaveSliceInfo(object): + """Information on how to save this Variable as a slice. - # Create a variable. - w = tf.Variable(, name=) + Provides internal support for saving variables as slices of a larger + variable. This API is not public and is subject to change. - # Use the variable in the graph like any Tensor. - y = tf.matmul(w, ...another variable or tensor...) + Available properties: - # The overloaded operators are available too. - z = tf.sigmoid(w + y) + * full_name + * full_shape + * var_offset + * var_shape + """ - # Assign a new value to the variable with `assign()` or a related method. - w.assign(w + 1.0) - w.assign_add(1.0) - ``` + def __init__(self, + full_name=None, + full_shape=None, + var_offset=None, + var_shape=None, + save_slice_info_def=None, + import_scope=None): + """Create a `SaveSliceInfo`. - When you launch the graph, variables have to be explicitly initialized before - you can run Ops that use their value. You can initialize a variable by - running its *initializer op*, restoring the variable from a save file, or - simply running an `assign` Op that assigns a value to the variable. In fact, - the variable *initializer op* is just an `assign` Op that assigns the - variable's initial value to the variable itself. + Args: + full_name: Name of the full variable of which this `Variable` is a + slice. + full_shape: Shape of the full variable, as a list of int. + var_offset: Offset of this `Variable` into the full variable, as a + list of int. + var_shape: Shape of this `Variable`, as a list of int. + save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If not `None`, + recreates the SaveSliceInfo object its contents. + `save_slice_info_def` and other arguments are mutually + exclusive. + import_scope: Optional `string`. Name scope to add. Only used + when initializing from protocol buffer. + """ + if save_slice_info_def: + assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef) + self.full_name = ops.prepend_name_scope( + save_slice_info_def.full_name, import_scope=import_scope) + self.full_shape = [i for i in save_slice_info_def.full_shape] + self.var_offset = [i for i in save_slice_info_def.var_offset] + self.var_shape = [i for i in save_slice_info_def.var_shape] + else: + self.full_name = full_name + self.full_shape = full_shape + self.var_offset = var_offset + self.var_shape = var_shape - ```python - # Launch the graph in a session. - with tf.Session() as sess: - # Run the variable initializer. - sess.run(w.initializer) - # ...you now can run ops that use the value of 'w'... - ``` + @property + def spec(self): + """Computes the spec string used for saving.""" + full_shape_str = " ".join(["%d" % d for d in self.full_shape]) + " " + sl_spec = ":".join([ + "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape) + ]) + return full_shape_str + sl_spec - The most common initialization pattern is to use the convenience function - `global_variables_initializer()` to add an Op to the graph that initializes - all the variables. You then run that Op after launching the graph. + def to_proto(self, export_scope=None): + """Returns a SaveSliceInfoDef() proto. - ```python - # Add an Op to initialize global variables. - init_op = tf.global_variables_initializer() + Args: + export_scope: Optional `string`. Name scope to remove. - # Launch the graph in a session. - with tf.Session() as sess: - # Run the Op that initializes global variables. - sess.run(init_op) - # ...you can now run any Op that uses variable values... - ``` + Returns: + A `SaveSliceInfoDef` protocol buffer, or None if the `Variable` is not + in the specified name scope. + """ + if (export_scope is None or + self.full_name.startswith(export_scope)): + save_slice_info_def = variable_pb2.SaveSliceInfoDef() + save_slice_info_def.full_name = ops.strip_name_scope( + self.full_name, export_scope) + for i in self.full_shape: + save_slice_info_def.full_shape.append(i) + for i in self.var_offset: + save_slice_info_def.var_offset.append(i) + for i in self.var_shape: + save_slice_info_def.var_shape.append(i) + return save_slice_info_def + else: + return None - If you need to create a variable with an initial value dependent on another - variable, use the other variable's `initialized_value()`. This ensures that - variables are initialized in the right order. + def __iadd__(self, other): + raise NotImplementedError - All variables are automatically collected in the graph where they are - created. By default, the constructor adds the new variable to the graph - collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function - `global_variables()` returns the contents of that collection. + def __isub__(self, other): + raise NotImplementedError - When building a machine learning model it is often convenient to distinguish - between variables holding the trainable model parameters and other variables - such as a `global step` variable used to count training steps. To make this - easier, the variable constructor supports a `trainable=` parameter. If - `True`, the new variable is also added to the graph collection - `GraphKeys.TRAINABLE_VARIABLES`. The convenience function - `trainable_variables()` returns the contents of this collection. The - various `Optimizer` classes use this collection as the default list of - variables to optimize. + def __imul__(self, other): + raise NotImplementedError - WARNING: tf.Variable objects have a non-intuitive memory model. A Variable is - represented internally as a mutable Tensor which can non-deterministically - alias other Tensors in a graph. The set of operations which consume a Variable - and can lead to aliasing is undetermined and can change across TensorFlow - versions. Avoid writing code which relies on the value of a Variable either - changing or not changing as other operations happen. For example, using - Variable objects or simple functions thereof as predicates in a `tf.cond` is - dangerous and error-prone: + def __idiv__(self, other): + raise NotImplementedError - ``` - v = tf.Variable(True) - tf.cond(v, lambda: v.assign(False), my_false_fn) # Note: this is broken. - ``` + def __itruediv__(self, other): + raise NotImplementedError - Here replacing tf.Variable with tf.contrib.eager.Variable will fix any - nondeterminism issues. + def __irealdiv__(self, other): + raise NotImplementedError - To use the replacement for variables which does - not have these issues: + def __ipow__(self, other): + raise NotImplementedError - * Replace `tf.Variable` with `tf.contrib.eager.Variable`; - * Call `tf.get_variable_scope().set_use_resource(True)` inside a - `tf.variable_scope` before the `tf.get_variable()` call. - @compatibility(eager) - `tf.Variable` is not compatible with eager execution. Use - `tf.contrib.eager.Variable` instead which is compatible with both eager - execution and graph construction. See [the TensorFlow Eager Execution - guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers) - for details on how variables work in eager execution. - @end_compatibility - """ +# TODO(apassos): do not repeat all comments here +class RefVariable(Variable): + """Ref-based implementation of variables.""" def __init__(self, initial_value=None, @@ -225,19 +873,7 @@ class Variable(checkpointable.CheckpointableBase): ValueError: If the initial value is not specified, or does not have a shape and `validate_shape` is `True`. RuntimeError: If eager execution is enabled. - - @compatibility(eager) - `tf.Variable` is not compatible with eager execution. Use - `tfe.Variable` instead which is compatible with both eager execution - and graph construction. See [the TensorFlow Eager Execution - guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers) - for details on how variables work in eager execution. - @end_compatibility """ - if context.executing_eagerly(): - raise RuntimeError( - "tf.Variable not supported when eager execution is enabled. " - "Please use tf.contrib.eager.Variable instead") self._in_graph_mode = True if variable_def: # If variable_def is provided, recreates the variable from its fields. @@ -348,8 +984,7 @@ class Variable(checkpointable.CheckpointableBase): # Ensure that we weren't lifted into the eager context. if context.executing_eagerly(): raise RuntimeError( - "tf.Variable not supported when eager execution is enabled. " - "Please use tf.contrib.eager.Variable instead") + "RefVariable not supported when eager execution is enabled. ") with ops.name_scope(name, "Variable", [] if init_from_fn else [initial_value]) as name: @@ -471,7 +1106,7 @@ class Variable(checkpointable.CheckpointableBase): def _AsTensor(self): # pylint: disable=invalid-name """Converts this variable to a Tensor. - See @{tf.Variable.value}. + See `tf.Variable.value`. Returns: A `Tensor` containing the value of the variable. @@ -528,7 +1163,7 @@ class Variable(checkpointable.CheckpointableBase): Returns is a `Tensor` which holds a reference to the variable. You can assign a new value to the variable by passing the tensor to an assign op. - See @{tf.Variable.value} if you want to get the value of the + See `tf.Variable.value` if you want to get the value of the variable. Returns: @@ -556,7 +1191,7 @@ class Variable(checkpointable.CheckpointableBase): This convenience method requires a session where the graph containing this variable has been launched. If no session is - passed, the default session is used. See @{tf.Session} for more + passed, the default session is used. See `tf.Session` for more information on launching a graph and on sessions. ```python @@ -751,7 +1386,7 @@ class Variable(checkpointable.CheckpointableBase): This convenience method requires a session where the graph containing this variable has been launched. If no session is - passed, the default session is used. See @{tf.Session} for more + passed, the default session is used. See `tf.Session` for more information on launching a graph and on sessions. ```python @@ -1068,12 +1703,6 @@ class Variable(checkpointable.CheckpointableBase): else: return None - @staticmethod - def from_proto(variable_def, import_scope=None): - """Returns a `Variable` object created from `variable_def`.""" - return Variable(variable_def=variable_def, - import_scope=import_scope) - def __iadd__(self, other): logging.log_first_n( logging.WARN, @@ -1093,126 +1722,43 @@ class Variable(checkpointable.CheckpointableBase): def __imul__(self, other): logging.log_first_n( logging.WARN, - "Variable *= will be deprecated. Use variable.assign_mul" - " if you want assignment to the variable value or 'x = x * y'" + "Variable *= will be deprecated. Use `var.assign(var * other)`" + " if you want assignment to the variable value or `x = x * y`" " if you want a new python Tensor object.", 1) return self * other def __idiv__(self, other): logging.log_first_n( logging.WARN, - "Variable /= will be deprecated. Use variable.assign_div" - " if you want assignment to the variable value or 'x = x / y'" + "Variable /= will be deprecated. Use `var.assign(var / other)`" + " if you want assignment to the variable value or `x = x / y`" " if you want a new python Tensor object.", 1) return self / other def __itruediv__(self, other): logging.log_first_n( logging.WARN, - "Variable /= will be deprecated. Use variable.assign_div" - " if you want assignment to the variable value or 'x = x / y'" + "Variable /= will be deprecated. Use `var.assign(var / other)`" + " if you want assignment to the variable value or `x = x / y`" " if you want a new python Tensor object.", 1) return self / other def __irealdiv__(self, other): logging.log_first_n( logging.WARN, - "Variable /= will be deprecated. Use variable.assign_div" - " if you want assignment to the variable value or 'x = x / y'" + "Variable /= will be deprecated. Use `var.assign(var / other)`" + " if you want assignment to the variable value or `x = x / y`" " if you want a new python Tensor object.", 1) return self / other def __ipow__(self, other): logging.log_first_n( logging.WARN, - "Variable **= will be deprecated. Use 'x = x ** y'" + "Variable **= will be deprecated. Use `var.assign(var ** other)`" + " if you want assignment to the variable value or `x = x ** y`" " if you want a new python Tensor object.", 1) return self ** other - class SaveSliceInfo(object): - """Information on how to save this Variable as a slice. - - Provides internal support for saving variables as slices of a larger - variable. This API is not public and is subject to change. - - Available properties: - - * full_name - * full_shape - * var_offset - * var_shape - """ - - def __init__(self, - full_name=None, - full_shape=None, - var_offset=None, - var_shape=None, - save_slice_info_def=None, - import_scope=None): - """Create a `SaveSliceInfo`. - - Args: - full_name: Name of the full variable of which this `Variable` is a - slice. - full_shape: Shape of the full variable, as a list of int. - var_offset: Offset of this `Variable` into the full variable, as a - list of int. - var_shape: Shape of this `Variable`, as a list of int. - save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If not `None`, - recreates the SaveSliceInfo object its contents. - `save_slice_info_def` and other arguments are mutually - exclusive. - import_scope: Optional `string`. Name scope to add. Only used - when initializing from protocol buffer. - """ - if save_slice_info_def: - assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef) - self.full_name = ops.prepend_name_scope( - save_slice_info_def.full_name, import_scope=import_scope) - self.full_shape = [i for i in save_slice_info_def.full_shape] - self.var_offset = [i for i in save_slice_info_def.var_offset] - self.var_shape = [i for i in save_slice_info_def.var_shape] - else: - self.full_name = full_name - self.full_shape = full_shape - self.var_offset = var_offset - self.var_shape = var_shape - - @property - def spec(self): - """Computes the spec string used for saving.""" - full_shape_str = " ".join(["%d" % d for d in self.full_shape]) + " " - sl_spec = ":".join([ - "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape) - ]) - return full_shape_str + sl_spec - - def to_proto(self, export_scope=None): - """Returns a SaveSliceInfoDef() proto. - - Args: - export_scope: Optional `string`. Name scope to remove. - - Returns: - A `SaveSliceInfoDef` protocol buffer, or None if the `Variable` is not - in the specified name scope. - """ - if (export_scope is None or - self.full_name.startswith(export_scope)): - save_slice_info_def = variable_pb2.SaveSliceInfoDef() - save_slice_info_def.full_name = ops.strip_name_scope( - self.full_name, export_scope) - for i in self.full_shape: - save_slice_info_def.full_shape.append(i) - for i in self.var_offset: - save_slice_info_def.var_offset.append(i) - for i in self.var_shape: - save_slice_info_def.var_shape.append(i) - return save_slice_info_def - else: - return None - def _set_save_slice_info(self, save_slice_info): """Sets the slice info for this `Variable`. @@ -1229,7 +1775,7 @@ class PartitionedVariable(object): """A container for partitioned `Variable` objects. @compatibility(eager) `tf.PartitionedVariable` is not compatible with - eager execution. Use `tfe.Variable` instead which is compatible + eager execution. Use `tf.Variable` instead which is compatible with both eager execution and graph construction. See [the TensorFlow Eager Execution guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers) @@ -1371,15 +1917,10 @@ class PartitionedVariable(object): def as_tensor(self): """Returns the overall concatenated value as a `Tensor`. - The returned tensor will not inherit the control dependencies from the scope - where the value is used, which is similar to getting the value of - `Variable`. - Returns: `Tensor` containing the concatenated value. """ - with ops.control_dependencies(None): - return self._concat() + return self._concat() @staticmethod def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): @@ -1403,6 +1944,10 @@ class PartitionedVariable(object): def dtype(self): return self._dtype + @property + def shape(self): + return self.get_shape() + def get_shape(self): return self._shape @@ -1429,7 +1974,7 @@ def global_variables(scope=None): This convenience function returns the contents of that collection. An alternative to global variables are local variables. See - @{tf.local_variables} + `tf.local_variables` Args: scope: (Optional.) A string. If supplied, the resulting list is filtered @@ -1482,7 +2027,7 @@ def local_variables(scope=None): This convenience function returns the contents of that collection. An alternative to local variables are global variables. See - @{tf.global_variables} + `tf.global_variables` Args: scope: (Optional.) A string. If supplied, the resulting list is filtered @@ -1722,6 +2267,8 @@ def report_uninitialized_variables(var_list=None, var_list.append(op.outputs[0]) with ops.name_scope(name): # Run all operations on CPU + if var_list: + init_vars = [state_ops.is_variable_initialized(v) for v in var_list] with ops.device("/cpu:0"): if not var_list: # Return an empty tensor so we only need to check for returned tensor @@ -1729,9 +2276,7 @@ def report_uninitialized_variables(var_list=None, return array_ops.constant([], dtype=dtypes.string) else: # Get a 1-D boolean tensor listing whether each variable is initialized. - variables_mask = math_ops.logical_not( - array_ops.stack( - [state_ops.is_variable_initialized(v) for v in var_list])) + variables_mask = math_ops.logical_not(array_ops.stack(init_vars)) # Get a 1-D string tensor containing all the variable names. variable_names_tensor = array_ops.constant( [s.op.name for s in var_list]) diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py index eba2baaf6f836c872c8315e558c51733fc013ec2..fa17b17d104221990ed7847b725c4b741cb4aca7 100644 --- a/tensorflow/python/platform/benchmark.py +++ b/tensorflow/python/platform/benchmark.py @@ -66,11 +66,11 @@ def _global_report_benchmark( if not isinstance(extras, dict): raise TypeError("extras must be a dict") - logging.info("Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g," - "throughput: %g %s", name, iters if iters is not None else -1, - wall_time if wall_time is not None else -1, cpu_time if - cpu_time is not None else -1, throughput if - throughput is not None else -1, str(extras) if extras else "") + logging.info("Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g," + "throughput: %g %s", name, iters if iters is not None else -1, + wall_time if wall_time is not None else -1, cpu_time if + cpu_time is not None else -1, throughput if + throughput is not None else -1, str(extras) if extras else "") entries = test_log_pb2.BenchmarkEntries() entry = entries.entry.add() diff --git a/tensorflow/python/platform/gfile.py b/tensorflow/python/platform/gfile.py index fd697d70bf200f1f661b410a9636d7b60e87f430..45de047894dddc8a82eb50bb2a38cd6d4ffcabcb 100644 --- a/tensorflow/python/platform/gfile.py +++ b/tensorflow/python/platform/gfile.py @@ -38,7 +38,14 @@ from tensorflow.python.util.tf_export import tf_export @tf_export('gfile.GFile', 'gfile.Open') class GFile(_FileIO): - """File I/O wrappers without thread locking.""" + """File I/O wrappers without thread locking. + + Note, that this is somewhat like builtin Python file I/O, but + there are semantic differences to make it more efficient for + some backing filesystems. For example, a write mode file will + not be opened until the first write call (to minimize RPC + invocations in network filesystems). + """ def __init__(self, name, mode='r'): super(GFile, self).__init__(name=name, mode=mode) @@ -46,7 +53,14 @@ class GFile(_FileIO): @tf_export('gfile.FastGFile') class FastGFile(_FileIO): - """File I/O wrappers without thread locking.""" + """File I/O wrappers without thread locking. + + Note, that this is somewhat like builtin Python file I/O, but + there are semantic differences to make it more efficient for + some backing filesystems. For example, a write mode file will + not be opened until the first write call (to minimize RPC + invocations in network filesystems). + """ def __init__(self, name, mode='r'): super(FastGFile, self).__init__(name=name, mode=mode) diff --git a/tensorflow/python/platform/self_check.py b/tensorflow/python/platform/self_check.py index 966a094e55e09d51c2d5edd36eb3ca29e71935f8..844ae999186f6eed89b113469782840f08502a85 100644 --- a/tensorflow/python/platform/self_check.py +++ b/tensorflow/python/platform/self_check.py @@ -78,7 +78,7 @@ def preload_check(): "Could not find %r. TensorFlow requires that this DLL be " "installed in a directory that is named in your %%PATH%% " "environment variable. Download and install CUDA %s from " - "this URL: https://developer.nvidia.com/cuda-toolkit" + "this URL: https://developer.nvidia.com/cuda-90-download-archive" % (build_info.cudart_dll_name, build_info.cuda_version_number)) if hasattr(build_info, "cudnn_dll_name") and hasattr( diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py index 9e49188c1ef353d345c97ea0295aa1a68283605e..c0e16ca536e5ff2b3fdbd17088f3b1eebe0b50ec 100644 --- a/tensorflow/python/profiler/model_analyzer_test.py +++ b/tensorflow/python/profiler/model_analyzer_test.py @@ -106,7 +106,7 @@ class PrintModelAnalysisTest(test.TestCase): # Make sure time is profiled. gap = 1 if test.is_gpu_available() else 2 for i in range(3, 6, gap): - mat = re.search('(.*)[um]s/(.*)[um]s', metrics[i]) + mat = re.search('(.*)(?:us|ms|sec)/(.*)(?:us|ms|sec)', metrics[i]) self.assertGreater(float(mat.group(1)), 0.0) self.assertGreater(float(mat.group(2)), 0.0) # Make sure device is profiled. @@ -707,8 +707,10 @@ class PrintModelAnalysisTest(test.TestCase): a = array_ops.constant(np.ones((100, 100))) b = array_ops.constant(np.ones((100, 100))) c = a * b + config = config_pb2.ConfigProto() + config.graph_options.rewrite_options.min_graph_nodes = -1 - with session.Session() as sess: + with session.Session(config=config) as sess: run_options = config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE) run_metadata = config_pb2.RunMetadata() diff --git a/tensorflow/python/profiler/profile_context.py b/tensorflow/python/profiler/profile_context.py index 18eb66ef988c9f49eb04264545d417d8a986e16e..fa4260a7120d72eacff32a7b4960b34545eb32e5 100644 --- a/tensorflow/python/profiler/profile_context.py +++ b/tensorflow/python/profiler/profile_context.py @@ -88,16 +88,19 @@ def _profiled_run(self, to_profiles = self.profile_context._profile_candidates() for to_prof in to_profiles: cmd, opts, _ = to_prof + saved_views = self.profile_context._views.setdefault(cmd, {}) if self.profile_context._debug: sys.stderr.write('debug: profiling %s step: %d\n' % (cmd, step)) if cmd == 'graph': - self.profile_context.profiler.profile_graph(opts) + saved_views[step] = self.profile_context.profiler.profile_graph(opts) elif cmd == 'scope': - self.profile_context.profiler.profile_name_scope(opts) + saved_views[step] = self.profile_context.profiler.profile_name_scope( + opts) elif cmd == 'op': - self.profile_context.profiler.profile_operations(opts) + saved_views[step] = self.profile_context.profiler.profile_operations( + opts) elif cmd == 'code': - self.profile_context.profiler.profile_python(opts) + saved_views[step] = self.profile_context.profiler.profile_python(opts) else: raise ValueError('Unknown cmd: %s\n' % cmd) return ret @@ -185,8 +188,30 @@ class ProfileContext(object): self._traced_steps = 0 self._auto_profiles = [] self._profiler = None + self._views = {} self._lock = threading.Lock() + def get_profiles(self, cmd): + """Returns profiling results for each step at which `cmd` was run. + + Args: + cmd: string, profiling command used in an `add_auto_profiling` call. + + Returns: + dict[int: (MultiGraphNodeProto | GraphNodeProto)]. Keys are steps at which + the profiling command was run. Values are the outputs of profiling. + For "code" and "op" commands this will be a `MultiGraphNodeProto`, for + "scope" and "graph" commands this will be a `GraphNodeProto. + + Raises: + ValueError: if `cmd` was never run (either because no session.run call was + made or because there was no `add_auto_profiling` call with the specified + `cmd`. + """ + if cmd not in self._views: + raise ValueError('No autoprofiler for command: {}, was run'.format(cmd)) + return self._views[cmd] + def add_auto_profiling(self, cmd, options, profile_steps): """Traces and profiles at some session run steps. diff --git a/tensorflow/python/profiler/profile_context_test.py b/tensorflow/python/profiler/profile_context_test.py index a623beee23ebf98cf96bd0f334f813db5ae04040..107ad443c32e20ab69f3c2fb71c652d97a9c0cc6 100644 --- a/tensorflow/python/profiler/profile_context_test.py +++ b/tensorflow/python/profiler/profile_context_test.py @@ -61,6 +61,8 @@ class ProfilerContextTest(test.TestCase): profile_str = f.read() gfile.Remove(outfile) + self.assertEqual(set([15, 50, 100]), set(pctx.get_profiles("op").keys())) + with lib.ProfilerFromFile( os.path.join(test.get_temp_dir(), "profile_100")) as profiler: profiler.profile_operations(options=opts) diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 500dc30cc30f757965791e504bc79718bb7f7bd7..157f2341e0846764860af07f0656377b84b16ac6 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -29,6 +29,7 @@ limitations under the License. %rename("%s") TFE_ContextGetDevicePlacementPolicy; %rename("%s") TFE_ContextSetThreadLocalDevicePlacementPolicy; %rename("%s") TFE_ContextSetAsyncForThread; +%rename("%s") TFE_ContextSetServerDef; %rename("%s") TFE_ContextAsyncWait; %rename("%s") TFE_ContextAsyncClearError; %rename("%s") TFE_OpNameGetAttrType; @@ -62,6 +63,8 @@ limitations under the License. %rename("%s") TFE_DeleteContextOptions; %rename("%s") TFE_Py_TensorShapeSlice; %rename("%s") TFE_Py_TensorShapeOnDevice; +%rename("%s") TFE_ContextStartStep; +%rename("%s") TFE_ContextEndStep; %{ #include "tensorflow/python/eager/pywrap_tfe.h" diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 81786fbf435ffebba6217c0a03f06494195afc3c..7a37eda5eadbd0e133ec662e2a77240538d28782 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -62,6 +62,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":constants", + ":utils", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:lib", @@ -81,12 +82,37 @@ py_library( srcs_version = "PY2AND3", deps = [ ":constants", + ":utils", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:lib", "//tensorflow/python:platform", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python:variables", + ], +) + +py_test( + name = "loader_test", + size = "small", + srcs = ["loader_test.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:private"], + deps = [ + ":builder", + ":loader", + ":signature_def_utils", + ":utils", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:lib", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", ], ) @@ -163,8 +189,10 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":constants", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:lib", "//tensorflow/python:sparse_tensor", "//tensorflow/python:util", ], diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index e58be804c2738dbad0e2f90c21d6eff3832a8148..8e7f123a85aae7d714b162096e1a40ab498c3312 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -32,8 +32,10 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import utils_impl as saved_model_utils from tensorflow.python.training import saver as tf_saver from tensorflow.python.util import compat +from tensorflow.python.util.deprecation import deprecated_args from tensorflow.python.util.tf_export import tf_export @@ -111,12 +113,8 @@ class SavedModelBuilder(object): tf_logging.info("No assets to write.") return - assets_destination_dir = os.path.join( - compat.as_bytes(self._export_dir), - compat.as_bytes(constants.ASSETS_DIRECTORY)) - - if not file_io.file_exists(assets_destination_dir): - file_io.recursive_create_dir(assets_destination_dir) + assets_destination_dir = saved_model_utils.get_or_create_assets_dir( + self._export_dir) # Copy each asset from source path to destination path. for asset_basename, asset_source_filepath in asset_filename_map.items(): @@ -133,39 +131,32 @@ class SavedModelBuilder(object): tf_logging.info("Assets written to: %s", compat.as_text(assets_destination_dir)) - def _maybe_add_legacy_init_op(self, legacy_init_op=None): - """Add legacy init op to the SavedModel. + def _maybe_add_main_op(self, main_op): + """Adds main op to the SavedModel. Args: - legacy_init_op: Optional legacy init op to support backward compatibility. + main_op: Main op to run as part of graph initialization. If None, no + main op will be added to the graph. Raises: - TypeError if legacy init op is not of type `Operation`. - AssertionError if the graph already contains one or more legacy init ops. + TypeError: if main op is provided but is not of type `Operation`. + ValueError: if the Graph already contains an init op. """ - if legacy_init_op is not None: - if not isinstance(legacy_init_op, ops.Operation): - raise TypeError("legacy_init_op needs to be an Operation: %r" % - legacy_init_op) - if ops.get_collection(constants.LEGACY_INIT_OP_KEY): - raise AssertionError( - "graph already contains one or more legacy init ops under the " - "collection {}.".format(constants.LEGACY_INIT_OP_KEY)) - ops.add_to_collection(constants.LEGACY_INIT_OP_KEY, legacy_init_op) - - def _add_main_op(self, main_op): - """Add main op to the SavedModel. + if main_op is None: + return - Args: - main_op: Main op to run as part of graph initialization. + if not isinstance(main_op, ops.Operation): + raise TypeError("main_op needs to be an Operation: %r" % main_op) - Raises: - TypeError if main op is not of type `Operation`. - """ - if main_op is not None: - if not isinstance(main_op, ops.Operation): - raise TypeError("main_op needs to be an Operation: %r" % main_op) - ops.add_to_collection(constants.MAIN_OP_KEY, main_op) + # Validate that no other init ops have been added to this graph already. + # We check main_op and legacy_init_op for thoroughness and explicitness. + for init_op_key in (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY): + if ops.get_collection(init_op_key): + raise ValueError( + "Graph already contains one or more main ops under the " + "collection {}.".format(init_op_key)) + + ops.add_to_collection(constants.MAIN_OP_KEY, main_op) def _add_train_op(self, train_op): """Add train op to the SavedModel. @@ -257,16 +248,12 @@ class SavedModelBuilder(object): self._validate_tensor_info(outputs[outputs_key]) def _add_collections( - self, assets_collection, legacy_init_op, main_op, train_op): + self, assets_collection, main_op, train_op): """Add asset and op collections to be saved.""" # Save asset files and write them to disk, if any. self._save_and_write_assets(assets_collection) - if main_op is None: - # Add legacy init op to the SavedModel. - self._maybe_add_legacy_init_op(legacy_init_op) - else: - self._add_main_op(main_op) + self._maybe_add_main_op(main_op) self._add_train_op(train_op) @@ -282,6 +269,9 @@ class SavedModelBuilder(object): allow_empty=True) return saver + @deprecated_args(None, + "Pass your op to the equivalent parameter main_op instead.", + "legacy_init_op") def add_meta_graph(self, tags, signature_def_map=None, @@ -306,7 +296,7 @@ class SavedModelBuilder(object): that this collection should be a subset of the assets saved as part of the first meta graph in the SavedModel. legacy_init_op: Legacy support for op or group of ops to execute after the - restore op upon a load. + restore op upon a load. Deprecated; please use main_op instead. clear_devices: Set to true if the device info on the default graph should be cleared. main_op: Op or group of ops to execute when the graph is loaded. Note @@ -333,8 +323,12 @@ class SavedModelBuilder(object): # properly populated. self._validate_signature_def_map(signature_def_map) + # legacy_init_op is deprecated, and going away in TF 2.0. + # Re-mapping to main_op, as treatment is identical regardless. + main_op = main_op or legacy_init_op + # Add assets and ops - self._add_collections(assets_collection, legacy_init_op, main_op, None) + self._add_collections(assets_collection, main_op, None) saver = self._maybe_create_saver(saver) @@ -351,6 +345,9 @@ class SavedModelBuilder(object): # Tag the meta graph def and add it to the SavedModel. self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map) + @deprecated_args(None, + "Pass your op to the equivalent parameter main_op instead.", + "legacy_init_op") def add_meta_graph_and_variables(self, sess, tags, @@ -378,7 +375,7 @@ class SavedModelBuilder(object): def. assets_collection: Assets collection to be saved with SavedModel. legacy_init_op: Legacy support for op or group of ops to execute after the - restore op upon a load. + restore op upon a load. Deprecated; please use main_op instead. clear_devices: Set to true if the device info on the default graph should be cleared. main_op: Op or group of ops to execute when the graph is loaded. Note @@ -402,19 +399,15 @@ class SavedModelBuilder(object): # properly populated. self._validate_signature_def_map(signature_def_map) + # legacy_init_op is deprecated, and going away in TF 2.0. + # Re-mapping to main_op, as treatment is identical regardless. + main_op = main_op or legacy_init_op + # Add assets and ops - self._add_collections(assets_collection, legacy_init_op, main_op, None) - - # Create the variables sub-directory, if it does not exist. - variables_dir = os.path.join( - compat.as_text(self._export_dir), - compat.as_text(constants.VARIABLES_DIRECTORY)) - if not file_io.file_exists(variables_dir): - file_io.recursive_create_dir(variables_dir) - - variables_path = os.path.join( - compat.as_text(variables_dir), - compat.as_text(constants.VARIABLES_FILENAME)) + self._add_collections(assets_collection, main_op, None) + + saved_model_utils.get_or_create_variables_dir(self._export_dir) + variables_path = saved_model_utils.get_variables_path(self._export_dir) saver = self._maybe_create_saver(saver) diff --git a/tensorflow/python/saved_model/constants.py b/tensorflow/python/saved_model/constants.py index 61c6ffbd0d11ef48c6dfb8d14a4328df7f7c5df5..cb251f08bb56fd5496ea4f3aaedfd2822ae1565c 100644 --- a/tensorflow/python/saved_model/constants.py +++ b/tensorflow/python/saved_model/constants.py @@ -60,6 +60,10 @@ SAVED_MODEL_FILENAME_PBTXT = "saved_model.pbtxt" tf_export("saved_model.constants.SAVED_MODEL_FILENAME_PBTXT").export_constant( __name__, "SAVED_MODEL_FILENAME_PBTXT") +# File name for json format of SavedModel. +# Not exported while keras_saved_model is in contrib. +SAVED_MODEL_FILENAME_JSON = "saved_model.json" + # Subdirectory name containing the variables/checkpoint files. VARIABLES_DIRECTORY = "variables" tf_export("saved_model.constants.VARIABLES_DIRECTORY").export_constant( @@ -69,5 +73,3 @@ tf_export("saved_model.constants.VARIABLES_DIRECTORY").export_constant( VARIABLES_FILENAME = "variables" tf_export("saved_model.constants.VARIABLES_FILENAME").export_constant( __name__, "VARIABLES_FILENAME") - - diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py index d1bd8d47aee94fd913b807860eff7fa94bb469e5..e8536108e8711f903f1db74775f76e6836642396 100644 --- a/tensorflow/python/saved_model/loader_impl.py +++ b/tensorflow/python/saved_model/loader_impl.py @@ -28,8 +28,10 @@ from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import saved_model_pb2 from tensorflow.python.framework import ops from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import utils_impl as saved_model_utils from tensorflow.python.training import saver as tf_saver from tensorflow.python.util import compat from tensorflow.python.util.tf_export import tf_export @@ -115,11 +117,14 @@ def _get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None): return asset_tensor_dict -def _get_main_op_tensor(meta_graph_def_to_load): +def _get_main_op_tensor( + meta_graph_def_to_load, init_op_key=constants.MAIN_OP_KEY): """Gets the main op tensor, if one exists. Args: meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded. + init_op_key: name of collection to check; should be one of MAIN_OP_KEY + or the deprecated LEGACY_INIT_OP_KEY Returns: The main op tensor, if it exists and `None` otherwise. @@ -130,38 +135,15 @@ def _get_main_op_tensor(meta_graph_def_to_load): """ collection_def = meta_graph_def_to_load.collection_def main_op_tensor = None - if constants.MAIN_OP_KEY in collection_def: - main_ops = collection_def[constants.MAIN_OP_KEY].node_list.value + if init_op_key in collection_def: + main_ops = collection_def[init_op_key].node_list.value if len(main_ops) != 1: - raise RuntimeError("Expected exactly one SavedModel main op.") - main_op_tensor = ops.get_collection(constants.MAIN_OP_KEY)[0] + raise RuntimeError("Expected exactly one SavedModel main op. " + "Found: {}".format(main_ops)) + main_op_tensor = ops.get_collection(init_op_key)[0] return main_op_tensor -def _get_legacy_init_op_tensor(meta_graph_def_to_load): - """Gets the legacy init op tensor, if one exists. - - Args: - meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded. - - Returns: - The legacy init op tensor, if it exists and `None` otherwise. - - Raises: - RuntimeError: If the collection def corresponding to the legacy init op key - has other than exactly one tensor. - """ - collection_def = meta_graph_def_to_load.collection_def - legacy_init_op_tensor = None - if constants.LEGACY_INIT_OP_KEY in collection_def: - legacy_init_ops = collection_def[ - constants.LEGACY_INIT_OP_KEY].node_list.value - if len(legacy_init_ops) != 1: - raise RuntimeError("Expected exactly one legacy serving init op.") - legacy_init_op_tensor = ops.get_collection(constants.LEGACY_INIT_OP_KEY)[0] - return legacy_init_op_tensor - - @tf_export("saved_model.loader.maybe_saved_model_directory") def maybe_saved_model_directory(export_dir): """Checks whether the provided export directory could contain a SavedModel. @@ -207,11 +189,53 @@ def load(sess, tags, export_dir, import_scope=None, **saver_kwargs): Raises: RuntimeError: MetaGraphDef associated with the tags cannot be found. """ - with sess.graph.as_default(): - # Build the SavedModel protocol buffer and find requested meta graph def. - saved_model = _parse_saved_model(export_dir) + loader = SavedModelLoader(export_dir) + return loader.load(sess, tags, import_scope, **saver_kwargs) + + +class SavedModelLoader(object): + """Load graphs and restore variable values from a `SavedModel`.""" + + def __init__(self, export_dir): + """Creates a `SavedModelLoader`. + + Args: + export_dir: Directory in which the SavedModel protocol buffer and + variables to be loaded are located. + """ + self._export_dir = export_dir + self._variables_path = saved_model_utils.get_variables_path(export_dir) + self._saved_model = _parse_saved_model(export_dir) + + @property + def export_dir(self): + """Directory containing the SavedModel.""" + return self._export_dir + + @property + def variables_path(self): + """Path to variable checkpoint files.""" + return self._variables_path + + @property + def saved_model(self): + """SavedModel object parsed from the export directory.""" + return self._saved_model + + def get_meta_graph_def_from_tags(self, tags): + """Return MetaGraphDef with the exact specified tags. + + Args: + tags: A list or set of string tags that identify the MetaGraphDef. + + Returns: + MetaGraphDef with the same tags. + + Raises: + RuntimeError: if no metagraphs were found with the associated tags. + """ found_match = False - for meta_graph_def in saved_model.meta_graphs: + for meta_graph_def in self._saved_model.meta_graphs: if set(meta_graph_def.meta_info_def.tags) == set(tags): meta_graph_def_to_load = meta_graph_def found_match = True @@ -223,32 +247,103 @@ def load(sess, tags, export_dir, import_scope=None, **saver_kwargs): " could not be found in SavedModel. To inspect available tag-sets in" " the SavedModel, please use the SavedModel CLI: `saved_model_cli`" ) + return meta_graph_def_to_load - # Build a saver by importing the meta graph def to load. - saver = tf_saver.import_meta_graph( - meta_graph_def_to_load, import_scope=import_scope, **saver_kwargs) - - if saver: - # Build the checkpoint path where the variables are located. - variables_path = os.path.join( - compat.as_bytes(export_dir), - compat.as_bytes(constants.VARIABLES_DIRECTORY), - compat.as_bytes(constants.VARIABLES_FILENAME)) - - # Restore the variables using the built saver in the provided session. - saver.restore(sess, variables_path) - else: - tf_logging.info("The specified SavedModel has no variables; no " - "checkpoints were restored.") - - # Get asset tensors, if any. - asset_tensors_dictionary = _get_asset_tensors( - export_dir, meta_graph_def_to_load, import_scope=import_scope) - - main_op_tensor = ( - _get_main_op_tensor(meta_graph_def_to_load) or - (_get_legacy_init_op_tensor(meta_graph_def_to_load))) - if main_op_tensor is not None: - sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary) + def load_graph(self, graph, tags, import_scope=None, **saver_kwargs): + """Load ops and nodes from SavedModel MetaGraph into graph. - return meta_graph_def_to_load + Args: + graph: tf.Graph object. + tags: a set of string tags identifying a MetaGraphDef. + import_scope: Optional `string` -- if specified, prepend this string + followed by '/' to all loaded tensor names. This scope is applied to + tensor instances loaded into the passed session, but it is *not* written + through to the static `MetaGraphDef` protocol buffer that is returned. + **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph. + + Returns: + A tuple of + * Saver defined by the MetaGraph, which can be used to restore the + variable values. + * List of `Operation`/`Tensor` objects returned from + `tf.import_graph_def` (may be `None`). + """ + meta_graph_def = self.get_meta_graph_def_from_tags(tags) + with graph.as_default(): + return tf_saver._import_meta_graph_with_return_elements( # pylint: disable=protected-access + meta_graph_def, import_scope=import_scope, **saver_kwargs) + + def restore_variables(self, sess, saver, import_scope=None): + """Restore SavedModel variable values into the session. + + Args: + sess: tf.Session to restore variable values. + saver: a tf.train.Saver object. Can be None if there are no variables in + graph. This may be the saver returned by the load_graph() function, or a + default `tf.train.Saver()`. + import_scope: Optional `string` -- if specified, prepend this string + followed by '/' to all loaded tensor names. This scope is applied to + tensor instances loaded into the passed session, but it is *not* written + through to the static `MetaGraphDef` protocol buffer that is returned. + + Raises: + ValueError: if no saver was passed to the saver argument, and there are + variables in the graph. + """ + with sess.graph.as_default(): + if (saver is None and + not variables._all_saveable_objects(scope=import_scope)): # pylint: disable=protected-access + tf_logging.info("The specified SavedModel has no variables; no " + "checkpoints were restored.") + elif isinstance(saver, tf_saver.Saver): + saver.restore(sess, self._variables_path) + else: + raise ValueError( + "No tf.train.Saver object was passed to the function " + "SavedModelLoader.restore_variables. Since there are variables in " + "the graph, a saver is required.") + + def run_init_ops(self, sess, tags, import_scope=None): + """Run initialization ops defined in the `MetaGraphDef`. + + Args: + sess: tf.Session to restore variable values. + tags: a set of string tags identifying a MetaGraphDef. + import_scope: Optional `string` -- if specified, prepend this string + followed by '/' to all loaded tensor names. This scope is applied to + tensor instances loaded into the passed session, but it is *not* written + through to the static `MetaGraphDef` protocol buffer that is returned. + """ + meta_graph_def = self.get_meta_graph_def_from_tags(tags) + with sess.graph.as_default(): + # Get asset tensors, if any. + asset_tensors_dictionary = _get_asset_tensors( + self._export_dir, meta_graph_def, import_scope=import_scope) + + main_op_tensor = ( + _get_main_op_tensor(meta_graph_def, constants.MAIN_OP_KEY) or + _get_main_op_tensor(meta_graph_def, constants.LEGACY_INIT_OP_KEY)) + if main_op_tensor is not None: + sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary) + + def load(self, sess, tags, import_scope=None, **saver_kwargs): + """Load the MetaGraphDef graph and restore variable values into the session. + + Args: + sess: tf.Session to restore variable values. + tags: a set of string tags identifying a MetaGraphDef. + import_scope: Optional `string` -- if specified, prepend this string + followed by '/' to all loaded tensor names. This scope is applied to + tensor instances loaded into the passed session, but it is *not* written + through to the static `MetaGraphDef` protocol buffer that is returned. + **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph. + + Returns: + `MetagraphDef` proto of the graph that was loaded. + """ + with sess.graph.as_default(): + saver, _ = self.load_graph(sess.graph, tags, import_scope, + **saver_kwargs) + self.restore_variables(sess, saver, import_scope) + self.run_init_ops(sess, tags, import_scope) + return self.get_meta_graph_def_from_tags(tags) diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9a0b276a4b20390ad6bae012e4d61e39c57ac4fc --- /dev/null +++ b/tensorflow/python/saved_model/loader_test.py @@ -0,0 +1,230 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SavedModelLoader class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.client import session +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.saved_model import builder as saved_model_builder +from tensorflow.python.saved_model import loader_impl +from tensorflow.python.saved_model import signature_def_utils +from tensorflow.python.saved_model import utils +from tensorflow.python.training import saver as tf_saver + + +def _get_export_dir(label): + return os.path.join(test.get_temp_dir(), label) + +SIMPLE_ADD_SAVED_MODEL = _get_export_dir("simple_add_saved_model") +SAVED_MODEL_WITH_MAIN_OP = _get_export_dir("saved_model_with_main_op") + + +class SavedModelLoaderTest(test.TestCase): + + def setUp(self): + """Write test SavedModels to a temp directory.""" + with session.Session(graph=ops.Graph()) as sess: + x = variables.Variable(5, name="x") + y = variables.Variable(11, name="y") + z = x + y + sess.run(variables.global_variables_initializer()) + + foo_sig_def = signature_def_utils.build_signature_def( + {"foo_input": utils.build_tensor_info(x)}, + {"foo_output": utils.build_tensor_info(z)}) + bar_sig_def = signature_def_utils.build_signature_def( + {"bar_x": utils.build_tensor_info(x), + "bar_y": utils.build_tensor_info(y)}, + {"bar_z": utils.build_tensor_info(z)}) + + builder = saved_model_builder.SavedModelBuilder(SIMPLE_ADD_SAVED_MODEL) + builder.add_meta_graph_and_variables( + sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def}) + builder.save() + + # Write SavedModel with a main_op + assign_op = control_flow_ops.group(state_ops.assign(y, 7)) + + builder = saved_model_builder.SavedModelBuilder(SAVED_MODEL_WITH_MAIN_OP) + builder.add_meta_graph_and_variables( + sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def}, + main_op=assign_op) + builder.save() + + def tearDown(self): + file_io.delete_recursively(test.get_temp_dir()) + + def test_load_function(self): + loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo_graph"]) + self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) + self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval()) + + loader2 = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) + with self.test_session(graph=ops.Graph()) as sess: + loader2.load(sess, ["foo_graph"]) + self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) + self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval()) + + def test_load_graph(self): + loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) + graph = ops.Graph() + loader.load_graph(graph, ["foo_graph"]) + + x = graph.get_tensor_by_name("x:0") + y = graph.get_tensor_by_name("y:0") + + with self.assertRaises(KeyError): + graph.get_tensor_by_name("z:0") + + with self.test_session(graph=graph) as sess: + # Check that x and y are not initialized + with self.assertRaises(errors.FailedPreconditionError): + sess.run(x) + with self.assertRaises(errors.FailedPreconditionError): + sess.run(y) + + def test_load_with_import_scope(self): + loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) + with self.test_session(graph=ops.Graph()) as sess: + saver, _ = loader.load_graph( + sess.graph, ["foo_graph"], import_scope="baz") + + # The default saver should not work when the import scope is set. + with self.assertRaises(errors.NotFoundError): + loader.restore_variables(sess, tf_saver.Saver()) + + loader.restore_variables(sess, saver) + loader.run_init_ops(sess, ["foo_graph"]) + + self.assertEqual(5, sess.graph.get_tensor_by_name("baz/x:0").eval()) + self.assertEqual(7, sess.graph.get_tensor_by_name("baz/y:0").eval()) + + # Test combined load function. + loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo_graph"], import_scope="baa") + self.assertEqual(5, sess.graph.get_tensor_by_name("baa/x:0").eval()) + self.assertEqual(7, sess.graph.get_tensor_by_name("baa/y:0").eval()) + + def test_restore_variables(self): + loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) + with self.test_session(graph=ops.Graph()) as sess: + x = variables.Variable(0, name="x") + y = variables.Variable(0, name="y") + z = x * y + + sess.run(variables.global_variables_initializer()) + + # There are variables to restore, so a saver must be created. + with self.assertRaises(ValueError): + loader.restore_variables(sess, None) + + loader.restore_variables(sess, tf_saver.Saver()) + self.assertEqual(55, z.eval()) + + def test_run_init_op(self): + loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) + graph = ops.Graph() + saver, _ = loader.load_graph(graph, ["foo_graph"]) + with self.test_session(graph=graph) as sess: + loader.restore_variables(sess, saver) + self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) + self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval()) + + loader.run_init_ops(sess, ["foo_graph"]) + self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) + self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval()) + + def test_parse_saved_model(self): + loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) + meta_graph = loader.get_meta_graph_def_from_tags(["foo_graph"]) + self.assertIsNotNone(meta_graph) + self.assertIn("foo", meta_graph.signature_def) + self.assertIn("bar", meta_graph.signature_def) + + def test_load_invalid_meta_graph(self): + loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) + with self.assertRaises(RuntimeError): + loader.get_meta_graph_def_from_tags([]) + with self.assertRaises(RuntimeError): + loader.get_meta_graph_def_from_tags([""]) + with self.assertRaises(RuntimeError): + loader.get_meta_graph_def_from_tags(["not_a_graph"]) + + def test_load_saved_model_with_no_variables(self): + """Test that SavedModel runs saver when there appear to be no variables. + + When no variables are detected, this may mean that the variables were saved + to different collections, or the collections weren't saved to the + SavedModel. If the SavedModel MetaGraphDef contains a saver, it should still + run in either of these cases. + """ + path = _get_export_dir("no_variable_saved_model") + with session.Session(graph=ops.Graph()) as sess: + x = variables.Variable(5, name="x", collections=["not_global_variable"]) + y = variables.Variable(11, name="y", collections=["not_global_variable"]) + self.assertFalse(variables._all_saveable_objects()) + z = x + y + sess.run(variables.variables_initializer([x, y])) + + foo_sig_def = signature_def_utils.build_signature_def( + {"foo_input": utils.build_tensor_info(x)}, + {"foo_output": utils.build_tensor_info(z)}) + + builder = saved_model_builder.SavedModelBuilder(path) + builder.add_meta_graph_and_variables( + sess, ["foo_graph"], {"foo": foo_sig_def}, + saver=tf_saver.Saver([x, y])) + builder.save() + + loader = loader_impl.SavedModelLoader(path) + with self.test_session(graph=ops.Graph()) as sess: + saver, _ = loader.load_graph(sess.graph, ["foo_graph"]) + self.assertFalse(variables._all_saveable_objects()) + self.assertIsNotNone(saver) + + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo_graph"]) + self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) + self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval()) + + def test_load_saved_model_graph_with_return_elements(self): + """Ensure that the correct elements are returned.""" + loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) + graph = ops.Graph() + _, ret = loader.load_graph(graph, ["foo_graph"], + return_elements=["y:0", "x:0"]) + + self.assertEqual(graph.get_tensor_by_name("y:0"), ret[0]) + self.assertEqual(graph.get_tensor_by_name("x:0"), ret[1]) + + with self.assertRaisesRegexp(ValueError, "not found in graph"): + loader.load_graph(graph, ["foo_graph"], return_elements=["z:0"]) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index fb4732aca21d4661aaea21a472475690687a42be..00b669fc97950d25a6e29f728c649f4dba482162 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -846,9 +846,19 @@ class SavedModelTest(test.TestCase): def testLegacyInitOpWithNonEmptyCollection(self): export_dir = self._get_export_dir( "test_legacy_init_op_with_non_empty_collection") + self._testInitOpsWithNonEmptyCollection( + export_dir, constants.LEGACY_INIT_OP_KEY) + + def testMainOpWithNonEmptyCollection(self): + export_dir = self._get_export_dir( + "test_main_op_with_non_empty_collection") + self._testInitOpsWithNonEmptyCollection(export_dir, constants.MAIN_OP_KEY) + + def _testInitOpsWithNonEmptyCollection(self, export_dir, key): builder = saved_model_builder.SavedModelBuilder(export_dir) - with self.test_session(graph=ops.Graph()) as sess: + g = ops.Graph() + with self.test_session(graph=g) as sess: # Initialize variable `v1` to 1. v1 = variables.Variable(1, name="v1") ops.add_to_collection("v", v1) @@ -857,19 +867,21 @@ class SavedModelTest(test.TestCase): v2 = variables.Variable(42, name="v2", trainable=False, collections=[]) ops.add_to_collection("v", v2) - # Set up an assignment op to be run as part of the legacy_init_op. + # Set up an assignment op to be run as part of the init op. assign_v2 = state_ops.assign(v2, v1) - legacy_init_op = control_flow_ops.group(assign_v2, name="legacy_init_op") + init_op = control_flow_ops.group(assign_v2, name="init_op") sess.run(variables.global_variables_initializer()) - ops.add_to_collection(constants.LEGACY_INIT_OP_KEY, - control_flow_ops.no_op()) - # AssertionError should be raised since the LEGACY_INIT_OP_KEY collection + ops.add_to_collection(key, control_flow_ops.no_op()) + # ValueError should be raised since the LEGACY_INIT_OP_KEY collection # is not empty and we don't support multiple init ops. - with self.assertRaises(AssertionError): + with self.assertRaisesRegexp(ValueError, "Graph already contains"): builder.add_meta_graph_and_variables( - sess, ["foo"], legacy_init_op=legacy_init_op) + sess, ["foo"], legacy_init_op=init_op) + # We shouldn't be able to add as MAIN_OP, either. + with self.assertRaisesRegexp(ValueError, "Graph already contains"): + builder.add_meta_graph_and_variables(sess, ["foo"], main_op=init_op) def testTrainOp(self): export_dir = self._get_export_dir("test_train_op") diff --git a/tensorflow/python/saved_model/utils_impl.py b/tensorflow/python/saved_model/utils_impl.py index cddce29a08a6c4c79a4c7c5dbfb48a86131530b2..20ff34fd8e8a9d8aebee8757cff44d1bf929405e 100644 --- a/tensorflow/python/saved_model/utils_impl.py +++ b/tensorflow/python/saved_model/utils_impl.py @@ -18,10 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.lib.io import file_io +from tensorflow.python.saved_model import constants +from tensorflow.python.util import compat from tensorflow.python.util.tf_export import tf_export @@ -84,3 +89,45 @@ def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None): _get_tensor(tensor_info.coo_sparse.dense_shape_tensor_name)) else: raise ValueError("Invalid TensorInfo.encoding: %s" % encoding) + + +# Path helpers. + + +def get_or_create_variables_dir(export_dir): + """Return variables sub-directory, or create one if it doesn't exist.""" + variables_dir = get_variables_dir(export_dir) + if not file_io.file_exists(variables_dir): + file_io.recursive_create_dir(variables_dir) + return variables_dir + + +def get_variables_dir(export_dir): + """Return variables sub-directory in the SavedModel.""" + return os.path.join( + compat.as_text(export_dir), + compat.as_text(constants.VARIABLES_DIRECTORY)) + + +def get_variables_path(export_dir): + """Return the variables path, used as the prefix for checkpoint files.""" + return os.path.join( + compat.as_text(get_variables_dir(export_dir)), + compat.as_text(constants.VARIABLES_FILENAME)) + + +def get_or_create_assets_dir(export_dir): + """Return assets sub-directory, or create one if it doesn't exist.""" + assets_destination_dir = get_assets_dir(export_dir) + + if not file_io.file_exists(assets_destination_dir): + file_io.recursive_create_dir(assets_destination_dir) + + return assets_destination_dir + + +def get_assets_dir(export_dir): + """Return path to asset directory in the SavedModel.""" + return os.path.join( + compat.as_text(export_dir), + compat.as_text(constants.ASSETS_DIRECTORY)) diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py index 1421d2772fe140dd5f207f159db0ab462231420d..980320cc6687e981085b594ac29c8a546b89465b 100644 --- a/tensorflow/python/summary/summary.py +++ b/tensorflow/python/summary/summary.py @@ -268,7 +268,7 @@ def merge(inputs, collections=None, name=None): @compatibility(eager) Not compatible with eager execution. To write TensorBoard summaries under eager execution, use `tf.contrib.summary` instead. - @end_compatbility + @end_compatibility """ # pylint: enable=line-too-long if _context.executing_eagerly(): @@ -285,7 +285,7 @@ def merge(inputs, collections=None, name=None): @tf_export('summary.merge_all') -def merge_all(key=_ops.GraphKeys.SUMMARIES, scope=None): +def merge_all(key=_ops.GraphKeys.SUMMARIES, scope=None, name=None): """Merges all summaries collected in the default graph. Args: @@ -304,7 +304,7 @@ def merge_all(key=_ops.GraphKeys.SUMMARIES, scope=None): @compatibility(eager) Not compatible with eager execution. To write TensorBoard summaries under eager execution, use `tf.contrib.summary` instead. - @end_compatbility + @end_compatibility """ if _context.executing_eagerly(): raise RuntimeError( @@ -314,7 +314,7 @@ def merge_all(key=_ops.GraphKeys.SUMMARIES, scope=None): if not summary_ops: return None else: - return merge(summary_ops) + return merge(summary_ops, name=name) @tf_export('summary.get_summary_description') @@ -336,7 +336,7 @@ def get_summary_description(node_def): @compatibility(eager) Not compatible with eager execution. To write TensorBoard summaries under eager execution, use `tf.contrib.summary` instead. - @end_compatbility + @end_compatibility """ if node_def.op != 'TensorSummary': diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py index aca084fc9168e710316e4c988594cff69e54ebab..16b8626476eb1d43a800c9f41704971ecf5992ae 100644 --- a/tensorflow/python/summary/writer/writer.py +++ b/tensorflow/python/summary/writer/writer.py @@ -104,8 +104,8 @@ class SummaryToEventTransformer(object): and adds it to the event file. You can pass the result of evaluating any summary op, using - @{tf.Session.run} or - @{tf.Tensor.eval}, to this + `tf.Session.run` or + `tf.Tensor.eval`, to this function. Alternatively, you can pass a `tf.Summary` protocol buffer that you populate with your own data. The latter is commonly done to report evaluation results in event files. @@ -325,7 +325,7 @@ class FileWriter(SummaryToEventTransformer): ``` The `session` argument to the constructor makes the returned `FileWriter` a - a compatibility layer over new graph-based summaries (`tf.contrib.summary`). + compatibility layer over new graph-based summaries (`tf.contrib.summary`). Crucially, this means the underlying writer resource and events file will be shared with any other `FileWriter` using the same `session` and `logdir`, and with any `tf.contrib.summary.SummaryWriter` in this session using the @@ -352,7 +352,7 @@ class FileWriter(SummaryToEventTransformer): @compatibility(eager) `FileWriter` is not compatible with eager execution. To write TensorBoard summaries under eager execution, use `tf.contrib.summary` instead. - @end_compatbility + @end_compatibility """ if context.executing_eagerly(): raise RuntimeError( diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD index 6c34b6aaf310c7b576e6ae259af90ef4c23a013a..222f8565118f5b46c3fe7c653fff6968db64ab40 100644 --- a/tensorflow/python/tools/BUILD +++ b/tensorflow/python/tools/BUILD @@ -64,6 +64,7 @@ py_binary( srcs_version = "PY2AND3", deps = [ "//tensorflow/core:protos_all_py", + "//tensorflow/python", "//tensorflow/python:client", "//tensorflow/python:framework", "//tensorflow/python:framework_ops", diff --git a/tensorflow/python/tools/api/generator/BUILD b/tensorflow/python/tools/api/generator/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..f87fdb2d88068e013285d35ddb057ccf36397c35 --- /dev/null +++ b/tensorflow/python/tools/api/generator/BUILD @@ -0,0 +1,100 @@ +# Description: +# Scripts used to generate TensorFlow Python API. + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow/python/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES") +load("//tensorflow/python/tools/api/generator:api_init_files.bzl", "TENSORFLOW_API_INIT_FILES") + +exports_files( + [ + "LICENSE", + "create_python_api.py", + ], +) + +py_binary( + name = "create_python_api", + srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"], + main = "//tensorflow/python/tools/api/generator:create_python_api.py", + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/python:no_contrib", + "//tensorflow/python/tools/api/generator:doc_srcs", + ], +) + +py_library( + name = "doc_srcs", + srcs = ["doc_srcs.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/python:util", + ], +) + +py_test( + name = "create_python_api_test", + srcs = [ + "create_python_api.py", + "create_python_api_test.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":doc_srcs", + "//tensorflow/python:client_testlib", + "//tensorflow/python:no_contrib", + ], +) + +py_test( + name = "tensorflow_doc_srcs_test", + srcs = ["doc_srcs_test.py"], + args = [ + "--package=tensorflow.python", + "--api_name=tensorflow", + ] + TENSORFLOW_API_INIT_FILES, + main = "doc_srcs_test.py", + srcs_version = "PY2AND3", + deps = [ + ":doc_srcs", + "//tensorflow/python:client_testlib", + "//tensorflow/python:no_contrib", + ], +) + +py_test( + name = "estimator_doc_srcs_test", + srcs = ["doc_srcs_test.py"], + args = [ + "--package=tensorflow.python.estimator", + "--api_name=estimator", + ] + ESTIMATOR_API_INIT_FILES, + main = "doc_srcs_test.py", + srcs_version = "PY2AND3", + deps = [ + ":doc_srcs", + "//tensorflow/python:client_testlib", + "//tensorflow/python:no_contrib", + "//tensorflow/python/estimator:estimator_py", + ], +) + +py_test( + name = "output_init_files_test", + srcs = ["output_init_files_test.py"], + data = [ + "api_init_files.bzl", + "api_init_files_v1.bzl", + ], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:no_contrib", + "//tensorflow/python/tools/api/generator:create_python_api", + ], +) diff --git a/tensorflow/python/tools/api/generator/api_gen.bzl b/tensorflow/python/tools/api/generator/api_gen.bzl new file mode 100644 index 0000000000000000000000000000000000000000..2810d83bd2428a05cf5736a17fcbba95c8f71344 --- /dev/null +++ b/tensorflow/python/tools/api/generator/api_gen.bzl @@ -0,0 +1,98 @@ +"""Targets for generating TensorFlow Python API __init__.py files.""" + +load("//tensorflow/python/tools/api/generator:api_init_files.bzl", "TENSORFLOW_API_INIT_FILES") + +# keep sorted +ESTIMATOR_API_INIT_FILES = [ + # BEGIN GENERATED ESTIMATOR FILES + "__init__.py", + "estimator/__init__.py", + "estimator/export/__init__.py", + "estimator/inputs/__init__.py", + # END GENERATED ESTIMATOR FILES +] + +def gen_api_init_files( + name, + output_files = TENSORFLOW_API_INIT_FILES, + compat_output_files = {}, + root_init_template = None, + srcs = [], + api_name = "tensorflow", + api_version = 2, + compat_api_versions = [], + package = "tensorflow.python", + package_dep = "//tensorflow/python:no_contrib", + output_package = "tensorflow"): + """Creates API directory structure and __init__.py files. + + Creates a genrule that generates a directory structure with __init__.py + files that import all exported modules (i.e. modules with tf_export + decorators). + + Args: + name: name of genrule to create. + output_files: List of __init__.py files that should be generated. + This list should include file name for every module exported using + tf_export. For e.g. if an op is decorated with + @tf_export('module1.module2', 'module3'). Then, output_files should + include module1/module2/__init__.py and module3/__init__.py. + compat_output_files: Dictionary mapping each compat_api_version to the + set of __init__.py file paths that should be generated for that version. + root_init_template: Python init file that should be used as template for + root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this + template will be replaced with root imports collected by this genrule. + srcs: genrule sources. If passing root_init_template, the template file + must be included in sources. + api_name: Name of the project that you want to generate API files for + (e.g. "tensorflow" or "estimator"). + api_version: TensorFlow API version to generate. Must be either 1 or 2. + compat_api_versions: Older TensorFlow API versions to generate under + compat/ directory. + package: Python package containing the @tf_export decorators you want to + process + package_dep: Python library target containing your package. + output_package: Package where generated API will be added to. + """ + root_init_template_flag = "" + if root_init_template: + root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")" + + api_gen_binary_target = "create_" + package + "_api" + native.py_binary( + name = "create_" + package + "_api", + srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"], + main = "//tensorflow/python/tools/api/generator:create_python_api.py", + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + package_dep, + "//tensorflow/python:util", + "//tensorflow/python/tools/api/generator:doc_srcs", + ], + ) + + all_output_files = list(output_files) + compat_api_version_flags = "" + for compat_api_version in compat_api_versions: + compat_files = compat_output_files.get(compat_api_version, []) + all_output_files.extend([ + "compat/v%d/%s" % (compat_api_version, f) + for f in compat_files + ]) + compat_api_version_flags += " --compat_apiversion=%d" % compat_api_version + + native.genrule( + name = name, + outs = all_output_files, + cmd = ( + "$(location :" + api_gen_binary_target + ") " + + root_init_template_flag + " --apidir=$(@D) --apiname=" + + api_name + " --apiversion=" + str(api_version) + + compat_api_version_flags + " --package=" + package + + " --output_package=" + output_package + " $(OUTS)" + ), + srcs = srcs, + tools = [":" + api_gen_binary_target], + visibility = ["//tensorflow:__pkg__"], + ) diff --git a/tensorflow/tools/api/generator/api_gen.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl similarity index 58% rename from tensorflow/tools/api/generator/api_gen.bzl rename to tensorflow/python/tools/api/generator/api_init_files.bzl index 41713a94ecbde567340a9e1571a06aca3bbda97a..7001e566ced2af433535e64b0c6ea01c0d2d59a9 100644 --- a/tensorflow/tools/api/generator/api_gen.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files.bzl @@ -1,4 +1,4 @@ -"""Targets for generating TensorFlow Python API __init__.py files.""" +"""TensorFlow V2 API __init__.py files.""" # keep sorted TENSORFLOW_API_INIT_FILES = [ @@ -8,13 +8,15 @@ TENSORFLOW_API_INIT_FILES = [ "bitwise/__init__.py", "compat/__init__.py", "data/__init__.py", + "debugging/__init__.py", "distributions/__init__.py", - "distributions/bijectors/__init__.py", + "dtypes/__init__.py", "errors/__init__.py", "feature_column/__init__.py", "gfile/__init__.py", "graph_util/__init__.py", "image/__init__.py", + "io/__init__.py", "initializers/__init__.py", "keras/__init__.py", "keras/activations/__init__.py", @@ -65,6 +67,7 @@ TENSORFLOW_API_INIT_FILES = [ "nn/rnn_cell/__init__.py", "profiler/__init__.py", "python_io/__init__.py", + "quantization/__init__.py", "resource_loader/__init__.py", "strings/__init__.py", "saved_model/__init__.py", @@ -87,49 +90,3 @@ TENSORFLOW_API_INIT_FILES = [ "user_ops/__init__.py", # END GENERATED FILES ] - -# keep sorted -ESTIMATOR_API_INIT_FILES = [ - # BEGIN GENERATED ESTIMATOR FILES - "__init__.py", - "estimator/__init__.py", - "estimator/export/__init__.py", - "estimator/inputs/__init__.py", - # END GENERATED ESTIMATOR FILES -] - -# Creates a genrule that generates a directory structure with __init__.py -# files that import all exported modules (i.e. modules with tf_export -# decorators). -# -# Args: -# name: name of genrule to create. -# output_files: List of __init__.py files that should be generated. -# This list should include file name for every module exported using -# tf_export. For e.g. if an op is decorated with -# @tf_export('module1.module2', 'module3'). Then, output_files should -# include module1/module2/__init__.py and module3/__init__.py. -# root_init_template: Python init file that should be used as template for -# root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this -# template will be replaced with root imports collected by this genrule. -# srcs: genrule sources. If passing root_init_template, the template file -# must be included in sources. -def gen_api_init_files(name, - output_files=TENSORFLOW_API_INIT_FILES, - root_init_template=None, - srcs=[], - api_name="tensorflow", - package="tensorflow.python"): - root_init_template_flag = "" - if root_init_template: - root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")" - native.genrule( - name = name, - outs = output_files, - cmd = ( - "$(location //tensorflow/tools/api/generator:create_python_api) " + - root_init_template_flag + " --apidir=$(@D) --apiname=" + api_name + " --package=" + package + " $(OUTS)"), - srcs = srcs, - tools = ["//tensorflow/tools/api/generator:create_python_api"], - visibility = ["//tensorflow:__pkg__"], - ) diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl new file mode 100644 index 0000000000000000000000000000000000000000..73d11199d9cb4cc95f4597e2f6b8d58f692b2e30 --- /dev/null +++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl @@ -0,0 +1,92 @@ +"""TensorFlow V1 API __init__.py files.""" + +# keep sorted +TENSORFLOW_API_INIT_FILES_V1 = [ + # BEGIN GENERATED FILES + "__init__.py", + "app/__init__.py", + "bitwise/__init__.py", + "compat/__init__.py", + "data/__init__.py", + "debugging/__init__.py", + "distributions/__init__.py", + "dtypes/__init__.py", + "errors/__init__.py", + "feature_column/__init__.py", + "gfile/__init__.py", + "graph_util/__init__.py", + "image/__init__.py", + "io/__init__.py", + "initializers/__init__.py", + "keras/__init__.py", + "keras/activations/__init__.py", + "keras/applications/__init__.py", + "keras/applications/densenet/__init__.py", + "keras/applications/inception_resnet_v2/__init__.py", + "keras/applications/inception_v3/__init__.py", + "keras/applications/mobilenet/__init__.py", + "keras/applications/nasnet/__init__.py", + "keras/applications/resnet50/__init__.py", + "keras/applications/vgg16/__init__.py", + "keras/applications/vgg19/__init__.py", + "keras/applications/xception/__init__.py", + "keras/backend/__init__.py", + "keras/callbacks/__init__.py", + "keras/constraints/__init__.py", + "keras/datasets/__init__.py", + "keras/datasets/boston_housing/__init__.py", + "keras/datasets/cifar10/__init__.py", + "keras/datasets/cifar100/__init__.py", + "keras/datasets/fashion_mnist/__init__.py", + "keras/datasets/imdb/__init__.py", + "keras/datasets/mnist/__init__.py", + "keras/datasets/reuters/__init__.py", + "keras/estimator/__init__.py", + "keras/initializers/__init__.py", + "keras/layers/__init__.py", + "keras/losses/__init__.py", + "keras/metrics/__init__.py", + "keras/models/__init__.py", + "keras/optimizers/__init__.py", + "keras/preprocessing/__init__.py", + "keras/preprocessing/image/__init__.py", + "keras/preprocessing/sequence/__init__.py", + "keras/preprocessing/text/__init__.py", + "keras/regularizers/__init__.py", + "keras/utils/__init__.py", + "keras/wrappers/__init__.py", + "keras/wrappers/scikit_learn/__init__.py", + "layers/__init__.py", + "linalg/__init__.py", + "logging/__init__.py", + "losses/__init__.py", + "manip/__init__.py", + "math/__init__.py", + "metrics/__init__.py", + "nn/__init__.py", + "nn/rnn_cell/__init__.py", + "profiler/__init__.py", + "python_io/__init__.py", + "quantization/__init__.py", + "resource_loader/__init__.py", + "strings/__init__.py", + "saved_model/__init__.py", + "saved_model/builder/__init__.py", + "saved_model/constants/__init__.py", + "saved_model/loader/__init__.py", + "saved_model/main_op/__init__.py", + "saved_model/signature_constants/__init__.py", + "saved_model/signature_def_utils/__init__.py", + "saved_model/tag_constants/__init__.py", + "saved_model/utils/__init__.py", + "sets/__init__.py", + "sparse/__init__.py", + "spectral/__init__.py", + "summary/__init__.py", + "sysconfig/__init__.py", + "test/__init__.py", + "train/__init__.py", + "train/queue_runner/__init__.py", + "user_ops/__init__.py", + # END GENERATED FILES +] diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/python/tools/api/generator/create_python_api.py similarity index 62% rename from tensorflow/tools/api/generator/create_python_api.py rename to tensorflow/python/tools/api/generator/create_python_api.py index e375cd48d8ce0cf1b132e851f4133ad255f14aca..67cfd799fffae77d7d79a487672bbece50462da8 100644 --- a/tensorflow/tools/api/generator/create_python_api.py +++ b/tensorflow/python/tools/api/generator/create_python_api.py @@ -24,12 +24,15 @@ import importlib import os import sys +from tensorflow.python.tools.api.generator import doc_srcs from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_export -from tensorflow.tools.api.generator import doc_srcs API_ATTRS = tf_export.API_ATTRS +API_ATTRS_V1 = tf_export.API_ATTRS_V1 +_API_VERSIONS = [1, 2] +_COMPAT_MODULE_TEMPLATE = 'compat.v%d' _DEFAULT_PACKAGE = 'tensorflow.python' _GENFILES_DIR_SUFFIX = 'genfiles/' _SYMBOLS_TO_SKIP_EXPLICITLY = { @@ -38,14 +41,14 @@ _SYMBOLS_TO_SKIP_EXPLICITLY = { 'tensorflow.python.platform.flags.FLAGS' } _GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit. -# Generated by: tensorflow/tools/api/generator/create_python_api.py script. +# Generated by: tensorflow/python/tools/api/generator/create_python_api.py script. \"\"\"%s \"\"\" from __future__ import print_function """ -_GENERATED_FILE_FOOTER = "\n\ndel print_function\n" +_GENERATED_FILE_FOOTER = '\n\ndel print_function\n' class SymbolExposedTwiceError(Exception): @@ -80,8 +83,9 @@ def format_import(source_module_name, source_name, dest_name): class _ModuleInitCodeBuilder(object): """Builds a map from module name to imports included in that module.""" - def __init__(self): - self.module_imports = collections.defaultdict( + def __init__(self, output_package): + self._output_package = output_package + self._module_imports = collections.defaultdict( lambda: collections.defaultdict(set)) self._dest_import_to_id = collections.defaultdict(int) # Names that start with underscore in the root module. @@ -123,7 +127,30 @@ class _ModuleInitCodeBuilder(object): # The same symbol can be available in multiple modules. # We store all possible ways of importing this symbol and later pick just # one. - self.module_imports[dest_module_name][full_api_name].add(import_str) + self._module_imports[dest_module_name][full_api_name].add(import_str) + + def _import_submodules(self): + """Add imports for all destination modules in self._module_imports.""" + # Import all required modules in their parent modules. + # For e.g. if we import 'foo.bar.Value'. Then, we also + # import 'bar' in 'foo'. + imported_modules = set(self._module_imports.keys()) + for module in imported_modules: + if not module: + continue + module_split = module.split('.') + parent_module = '' # we import submodules in their parent_module + + for submodule_index in range(len(module_split)): + if submodule_index > 0: + submodule = module_split[submodule_index-1] + parent_module += '.' + submodule if parent_module else submodule + import_from = self._output_package + if submodule_index > 0: + import_from += '.' + '.'.join(module_split[:submodule_index]) + self.add_import( + -1, parent_module, import_from, + module_split[submodule_index], module_split[submodule_index]) def build(self): """Get a map from destination module to __init__.py code for that module. @@ -134,8 +161,9 @@ class _ModuleInitCodeBuilder(object): value: (string) text that should be in __init__.py files for corresponding modules. """ + self._import_submodules() module_text_map = {} - for dest_module, dest_name_to_imports in self.module_imports.items(): + for dest_module, dest_name_to_imports in self._module_imports.items(): # Sort all possible imports for a symbol and pick the first one. imports_list = [ sorted(imports)[0] @@ -159,13 +187,94 @@ __all__.remove('print_function') return module_text_map -def get_api_init_text(package, api_name): +def _get_name_and_module(full_name): + """Split full_name into module and short name. + + Args: + full_name: Full name of symbol that includes module. + + Returns: + Full module name and short symbol name. + """ + name_segments = full_name.split('.') + return '.'.join(name_segments[:-1]), name_segments[-1] + + +def _join_modules(module1, module2): + """Concatenate 2 module components. + + Args: + module1: First module to join. + module2: Second module to join. + + Returns: + Given two modules aaa.bbb and ccc.ddd, returns a joined + module aaa.bbb.ccc.ddd. + """ + if not module1: + return module2 + if not module2: + return module1 + return '%s.%s' % (module1, module2) + + +def add_imports_for_symbol( + module_code_builder, + symbol, + source_module_name, + source_name, + api_name, + api_version, + output_module_prefix=''): + """Add imports for the given symbol to `module_code_builder`. + + Args: + module_code_builder: `_ModuleInitCodeBuilder` instance. + symbol: A symbol. + source_module_name: Module that we can import the symbol from. + source_name: Name we can import the symbol with. + api_name: API name. Currently, must be either `tensorflow` or `estimator`. + api_version: API version. + output_module_prefix: Prefix to prepend to destination module. + """ + if api_version == 1: + names_attr = API_ATTRS_V1[api_name].names + constants_attr = API_ATTRS_V1[api_name].constants + else: + names_attr = API_ATTRS[api_name].names + constants_attr = API_ATTRS[api_name].constants + + # If symbol is _tf_api_constants attribute, then add the constants. + if source_name == constants_attr: + for exports, name in symbol: + for export in exports: + dest_module, dest_name = _get_name_and_module(export) + dest_module = _join_modules(output_module_prefix, dest_module) + module_code_builder.add_import( + -1, dest_module, source_module_name, name, dest_name) + + # If symbol has _tf_api_names attribute, then add import for it. + if (hasattr(symbol, '__dict__') and names_attr in symbol.__dict__): + for export in getattr(symbol, names_attr): # pylint: disable=protected-access + dest_module, dest_name = _get_name_and_module(export) + dest_module = _join_modules(output_module_prefix, dest_module) + module_code_builder.add_import( + id(symbol), dest_module, source_module_name, source_name, dest_name) + + +def get_api_init_text( + package, output_package, api_name, api_version, compat_api_versions=None): """Get a map from destination module to __init__.py code for that module. Args: package: Base python package containing python with target tf_export decorators. + output_package: Base output python package where generated API will + be added. api_name: API you want to generate (e.g. `tensorflow` or `estimator`). + api_version: API version you want to generate (1 or 2). + compat_api_versions: Additional API versions to generate under compat/ + directory. Returns: A dictionary where @@ -173,14 +282,15 @@ def get_api_init_text(package, api_name): value: (string) text that should be in __init__.py files for corresponding modules. """ - module_code_builder = _ModuleInitCodeBuilder() - + if compat_api_versions is None: + compat_api_versions = [] + module_code_builder = _ModuleInitCodeBuilder(output_package) # Traverse over everything imported above. Specifically, # we want to traverse over TensorFlow Python modules. for module in list(sys.modules.values()): # Only look at tensorflow modules. if (not module or not hasattr(module, '__name__') or - package not in module.__name__): + module.__name__ is None or package not in module.__name__): continue # Do not generate __init__.py files for contrib modules for now. if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'): @@ -191,47 +301,16 @@ def get_api_init_text(package, api_name): in _SYMBOLS_TO_SKIP_EXPLICITLY): continue attr = getattr(module, module_contents_name) - - # If attr is _tf_api_constants attribute, then add the constants. - if module_contents_name == API_ATTRS[api_name].constants: - for exports, value in attr: - for export in exports: - names = export.split('.') - dest_module = '.'.join(names[:-1]) - module_code_builder.add_import( - -1, dest_module, module.__name__, value, names[-1]) - continue - _, attr = tf_decorator.unwrap(attr) - # If attr is a symbol with _tf_api_names attribute, then - # add import for it. - if (hasattr(attr, '__dict__') and - API_ATTRS[api_name].names in attr.__dict__): - for export in getattr(attr, API_ATTRS[api_name].names): # pylint: disable=protected-access - names = export.split('.') - dest_module = '.'.join(names[:-1]) - module_code_builder.add_import( - id(attr), dest_module, module.__name__, module_contents_name, - names[-1]) - - # Import all required modules in their parent modules. - # For e.g. if we import 'foo.bar.Value'. Then, we also - # import 'bar' in 'foo'. - imported_modules = set(module_code_builder.module_imports.keys()) - import_from = '.' - for module in imported_modules: - if not module: - continue - module_split = module.split('.') - parent_module = '' # we import submodules in their parent_module - for submodule_index in range(len(module_split)): - if submodule_index > 0: - parent_module += ('.' + module_split[submodule_index-1] if parent_module - else module_split[submodule_index-1]) - module_code_builder.add_import( - -1, parent_module, import_from, - module_split[submodule_index], module_split[submodule_index]) + add_imports_for_symbol( + module_code_builder, attr, module.__name__, module_contents_name, + api_name, api_version) + for compat_api_version in compat_api_versions: + add_imports_for_symbol( + module_code_builder, attr, module.__name__, module_contents_name, + api_name, compat_api_version, + _COMPAT_MODULE_TEMPLATE % compat_api_version) return module_code_builder.build() @@ -252,7 +331,7 @@ def get_module(dir_path, relative_to_dir): return dir_path.replace('/', '.').strip('.') -def get_module_docstring(module_name, package): +def get_module_docstring(module_name, package, api_name): """Get docstring for the given module. This method looks for docstring in the following order: @@ -268,15 +347,25 @@ def get_module_docstring(module_name, package): (excluding 'tensorflow.' prefix) to get a docstring for. package: Base python package containing python with target tf_export decorators. + api_name: API you want to generate (e.g. `tensorflow` or `estimator`). Returns: One-line docstring to describe the module. """ + # Get the same module doc strings for any version. That is, for module + # 'compat.v1.foo' we can get docstring from module 'foo'. + for version in _API_VERSIONS: + compat_prefix = _COMPAT_MODULE_TEMPLATE % version + if module_name.startswith(compat_prefix): + module_name = module_name[len(compat_prefix):].strip('.') + # Module under base package to get a docstring from. docstring_module_name = module_name - if module_name in doc_srcs.TENSORFLOW_DOC_SOURCES: - docsrc = doc_srcs.TENSORFLOW_DOC_SOURCES[module_name] + doc_sources = doc_srcs.get_doc_sources(api_name) + + if module_name in doc_sources: + docsrc = doc_sources[module_name] if docsrc.docstring: return docsrc.docstring if docsrc.docstring_module_name: @@ -291,23 +380,32 @@ def get_module_docstring(module_name, package): def create_api_files( - output_files, package, root_init_template, output_dir, api_name): + output_files, + package, + root_init_template, + output_dir, + output_package, + api_name, + api_version, + compat_api_versions): """Creates __init__.py files for the Python API. Args: output_files: List of __init__.py file paths to create. - Each file must be under api/ directory. package: Base python package containing python with target tf_export decorators. root_init_template: Template for top-level __init__.py file. - "#API IMPORTS PLACEHOLDER" comment in the template file will be replaced + "# API IMPORTS PLACEHOLDER" comment in the template file will be replaced with imports. output_dir: output API root directory. + output_package: Base output package where generated API will be added. api_name: API you want to generate (e.g. `tensorflow` or `estimator`). + api_version: API version to generate (`v1` or `v2`). + compat_api_versions: Additional API versions to generate in compat/ + subdirectory. Raises: - ValueError: if an output file is not under api/ directory, - or output_files list is missing a required file. + ValueError: if output_files list is missing a required file. """ module_name_to_file_path = {} for output_file in output_files: @@ -320,10 +418,14 @@ def create_api_files( os.makedirs(os.path.dirname(file_path)) open(file_path, 'a').close() - module_text_map = get_api_init_text(package, api_name) + module_text_map = get_api_init_text( + package, output_package, api_name, api_version, compat_api_versions) # Add imports to output files. missing_output_files = [] + # Root modules are "" and "compat.v*". + root_modules = set(_COMPAT_MODULE_TEMPLATE % v for v in compat_api_versions) + root_modules.add('') for module, text in module_text_map.items(): # Make sure genrule output file list is in sync with API exports. if module not in module_name_to_file_path: @@ -331,11 +433,13 @@ def create_api_files( module.replace('.', '/')) missing_output_files.append(module_file_path) continue + contents = '' - if module or not root_init_template: + if module not in root_modules or not root_init_template: contents = ( _GENERATED_FILE_HEADER % - get_module_docstring(module, package) + text + _GENERATED_FILE_FOOTER) + get_module_docstring(module, package, api_name) + + text + _GENERATED_FILE_FOOTER) else: # Read base init file with open(root_init_template, 'r') as root_init_template_file: @@ -346,9 +450,7 @@ def create_api_files( if missing_output_files: raise ValueError( - 'Missing outputs for python_api_gen genrule:\n%s.' - 'Make sure all required outputs are in the ' - 'tensorflow/tools/api/generator/api_gen.bzl file.' % + 'Missing outputs for genrule:\n%s.' % ',\n'.join(sorted(missing_output_files))) @@ -377,7 +479,17 @@ def main(): '--apiname', required=True, type=str, choices=API_ATTRS.keys(), help='The API you want to generate.') - + parser.add_argument( + '--apiversion', default=2, type=int, + choices=_API_VERSIONS, + help='The API version you want to generate.') + parser.add_argument( + '--compat_apiversions', default=[], type=int, action='append', + help='Additional versions to generate in compat/ subdirectory. ' + 'If set to 0, then no additional version would be generated.') + parser.add_argument( + '--output_package', default='tensorflow', type=str, + help='Root output package.') args = parser.parse_args() if len(args.outputs) == 1: @@ -391,7 +503,8 @@ def main(): # Populate `sys.modules` with modules containing tf_export(). importlib.import_module(args.package) create_api_files(outputs, args.package, args.root_init_template, - args.apidir, args.apiname) + args.apidir, args.output_package, args.apiname, + args.apiversion, args.compat_apiversions) if __name__ == '__main__': diff --git a/tensorflow/tools/api/generator/create_python_api_test.py b/tensorflow/python/tools/api/generator/create_python_api_test.py similarity index 74% rename from tensorflow/tools/api/generator/create_python_api_test.py rename to tensorflow/python/tools/api/generator/create_python_api_test.py index 651ec9d040302a4343ae6e0053cf6a4b37a971d4..95ef8bbb0f6aa83e99e0b702f4a70a909f05d741 100644 --- a/tensorflow/tools/api/generator/create_python_api_test.py +++ b/tensorflow/python/tools/api/generator/create_python_api_test.py @@ -22,11 +22,11 @@ import imp import sys from tensorflow.python.platform import test +from tensorflow.python.tools.api.generator import create_python_api from tensorflow.python.util.tf_export import tf_export -from tensorflow.tools.api.generator import create_python_api -@tf_export('test_op', 'test_op1') +@tf_export('test_op', 'test_op1', 'test.test_op2') def test_op(): pass @@ -58,7 +58,8 @@ class CreatePythonApiTest(test.TestCase): def testFunctionImportIsAdded(self): imports = create_python_api.get_api_init_text( package=create_python_api._DEFAULT_PACKAGE, - api_name='tensorflow') + output_package='tensorflow', + api_name='tensorflow', api_version=1) expected_import = ( 'from tensorflow.python.test_module ' 'import test_op as test_op1') @@ -71,11 +72,15 @@ class CreatePythonApiTest(test.TestCase): self.assertTrue( expected_import in str(imports), msg='%s not in %s' % (expected_import, str(imports))) + # Also check that compat.v1 is not added to imports. + self.assertFalse('compat.v1' in imports, + msg='compat.v1 in %s' % str(imports.keys())) def testClassImportIsAdded(self): imports = create_python_api.get_api_init_text( package=create_python_api._DEFAULT_PACKAGE, - api_name='tensorflow') + output_package='tensorflow', + api_name='tensorflow', api_version=2) expected_import = ('from tensorflow.python.test_module ' 'import TestClass') self.assertTrue( @@ -85,12 +90,25 @@ class CreatePythonApiTest(test.TestCase): def testConstantIsAdded(self): imports = create_python_api.get_api_init_text( package=create_python_api._DEFAULT_PACKAGE, - api_name='tensorflow') + output_package='tensorflow', + api_name='tensorflow', api_version=1) expected = ('from tensorflow.python.test_module ' 'import _TEST_CONSTANT') self.assertTrue(expected in str(imports), msg='%s not in %s' % (expected, str(imports))) + def testCompatModuleIsAdded(self): + imports = create_python_api.get_api_init_text( + package=create_python_api._DEFAULT_PACKAGE, + output_package='tensorflow', + api_name='tensorflow', + api_version=2, + compat_api_versions=[1]) + self.assertTrue('compat.v1' in imports, + msg='compat.v1 not in %s' % str(imports.keys())) + self.assertTrue('compat.v1.test' in imports, + msg='compat.v1.test not in %s' % str(imports.keys())) + if __name__ == '__main__': test.main() diff --git a/tensorflow/tools/api/generator/doc_srcs.py b/tensorflow/python/tools/api/generator/doc_srcs.py similarity index 79% rename from tensorflow/tools/api/generator/doc_srcs.py rename to tensorflow/python/tools/api/generator/doc_srcs.py index 74f6db98fd58adf17d062a7aff6776fe1db0445b..ad1988494dae4a9d3ee96af5af76f02c52c0dff4 100644 --- a/tensorflow/tools/api/generator/doc_srcs.py +++ b/tensorflow/python/tools/api/generator/doc_srcs.py @@ -19,6 +19,8 @@ from __future__ import print_function import collections +from tensorflow.python.util import tf_export + # Specifies docstring source for a module. # Only one of docstring or docstring_module_name should be set. @@ -31,7 +33,7 @@ DocSource = collections.namedtuple( # Each attribute of DocSource is optional. DocSource.__new__.__defaults__ = (None,) * len(DocSource._fields) -TENSORFLOW_DOC_SOURCES = { +_TENSORFLOW_DOC_SOURCES = { 'app': DocSource(docstring_module_name='platform.app'), 'compat': DocSource(docstring_module_name='util.compat'), 'distributions': DocSource( @@ -41,7 +43,7 @@ TENSORFLOW_DOC_SOURCES = { 'gfile': DocSource(docstring_module_name='platform.gfile'), 'graph_util': DocSource(docstring_module_name='framework.graph_util'), 'image': DocSource(docstring_module_name='ops.image_ops'), - 'keras.estimator': DocSource(docstring_module_name='estimator.keras'), + 'keras.estimator': DocSource(docstring_module_name='keras.estimator'), 'linalg': DocSource(docstring_module_name='ops.linalg_ops'), 'logging': DocSource(docstring_module_name='ops.logging_ops'), 'losses': DocSource(docstring_module_name='ops.losses.losses'), @@ -63,3 +65,28 @@ TENSORFLOW_DOC_SOURCES = { 'train.queue_runner': DocSource( docstring_module_name='training.queue_runner'), } + +_ESTIMATOR_DOC_SOURCES = { + 'estimator': DocSource( + docstring_module_name='estimator_lib'), + 'estimator.export': DocSource( + docstring_module_name='export.export_lib'), + 'estimator.inputs': DocSource( + docstring_module_name='inputs.inputs'), +} + + +def get_doc_sources(api_name): + """Get a map from module to a DocSource object. + + Args: + api_name: API you want to generate (e.g. `tensorflow` or `estimator`). + + Returns: + Map from module name to DocSource object. + """ + if api_name == tf_export.TENSORFLOW_API_NAME: + return _TENSORFLOW_DOC_SOURCES + if api_name == tf_export.ESTIMATOR_API_NAME: + return _ESTIMATOR_DOC_SOURCES + return {} diff --git a/tensorflow/tools/api/generator/doc_srcs_test.py b/tensorflow/python/tools/api/generator/doc_srcs_test.py similarity index 65% rename from tensorflow/tools/api/generator/doc_srcs_test.py rename to tensorflow/python/tools/api/generator/doc_srcs_test.py index 9ba95a34391a8011477ec6afd630ec5c37586496..481d9874a4bcdcdadcdcb16b5b5c1b10b765dc48 100644 --- a/tensorflow/tools/api/generator/doc_srcs_test.py +++ b/tensorflow/python/tools/api/generator/doc_srcs_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= -"""Tests for tensorflow.tools.api.generator.doc_srcs.""" +"""Tests for tensorflow.python.tools.api.generator.doc_srcs.""" from __future__ import absolute_import from __future__ import division @@ -23,7 +23,7 @@ import importlib import sys from tensorflow.python.platform import test -from tensorflow.tools.api.generator import doc_srcs +from tensorflow.python.tools.api.generator import doc_srcs FLAGS = None @@ -32,34 +32,34 @@ FLAGS = None class DocSrcsTest(test.TestCase): def testModulesAreValidAPIModules(self): - for module_name in doc_srcs.TENSORFLOW_DOC_SOURCES: + for module_name in doc_srcs.get_doc_sources(FLAGS.api_name): # Convert module_name to corresponding __init__.py file path. file_path = module_name.replace('.', '/') if file_path: file_path += '/' file_path += '__init__.py' - if file_path not in FLAGS.outputs: - self.assertFalse('%s is not a valid API module' % module_name) + self.assertIn( + file_path, FLAGS.outputs, + msg='%s is not a valid API module' % module_name) def testHaveDocstringOrDocstringModule(self): - for module_name, docsrc in doc_srcs.TENSORFLOW_DOC_SOURCES.items(): - if docsrc.docstring and docsrc.docstring_module_name: - self.assertFalse( - '%s contains DocSource has both a docstring and a ' - 'docstring_module_name. ' - 'Only one of "docstring" or "docstring_module_name" should be set.' - % (module_name)) + for module_name, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items(): + self.assertFalse( + docsrc.docstring and docsrc.docstring_module_name, + msg=('%s contains DocSource has both a docstring and a ' + 'docstring_module_name. Only one of "docstring" or ' + '"docstring_module_name" should be set.') % (module_name)) def testDocstringModulesAreValidModules(self): - for _, docsrc in doc_srcs.TENSORFLOW_DOC_SOURCES.items(): + for _, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items(): if docsrc.docstring_module_name: doc_module_name = '.'.join([ FLAGS.package, docsrc.docstring_module_name]) - if doc_module_name not in sys.modules: - sys.assertFalse( - 'docsources_module %s is not a valid module under %s.' % - (docsrc.docstring_module_name, FLAGS.package)) + self.assertIn( + doc_module_name, sys.modules, + msg=('docsources_module %s is not a valid module under %s.' % + (docsrc.docstring_module_name, FLAGS.package))) if __name__ == '__main__': @@ -71,6 +71,9 @@ if __name__ == '__main__': '--package', type=str, help='Base package that imports modules containing the target tf_export ' 'decorators.') + parser.add_argument( + '--api_name', type=str, + help='API name: tensorflow or estimator') FLAGS, unparsed = parser.parse_known_args() importlib.import_module(FLAGS.package) diff --git a/tensorflow/python/tools/api/generator/output_init_files_test.py b/tensorflow/python/tools/api/generator/output_init_files_test.py new file mode 100644 index 0000000000000000000000000000000000000000..602ad165c0c1a0e39cce5e7f4eac8cabad2c2e7b --- /dev/null +++ b/tensorflow/python/tools/api/generator/output_init_files_test.py @@ -0,0 +1,179 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for api_init_files.bzl and api_init_files_v1.bzl.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from tensorflow.python.platform import test +from tensorflow.python.util import tf_decorator + + +def _get_module_from_symbol(symbol): + if '.' not in symbol: + return '' + return '.'.join(symbol.split('.')[:-1]) + + +def _get_modules(package, attr_name, constants_attr_name): + """Get list of TF API modules. + + Args: + package: We only look at modules that contain package in the name. + attr_name: Attribute set on TF symbols that contains API names. + constants_attr_name: Attribute set on TF modules that contains + API constant names. + + Returns: + Set of TensorFow API modules. + """ + modules = set() + # TODO(annarev): split up the logic in create_python_api.py so that + # it can be reused in this test. + for module in list(sys.modules.values()): + if (not module or not hasattr(module, '__name__') or + package not in module.__name__): + continue + + for module_contents_name in dir(module): + attr = getattr(module, module_contents_name) + _, attr = tf_decorator.unwrap(attr) + + # Add modules to _tf_api_constants attribute. + if module_contents_name == constants_attr_name: + for exports, _ in attr: + modules.update( + [_get_module_from_symbol(export) for export in exports]) + continue + + # Add modules for _tf_api_names attribute. + if (hasattr(attr, '__dict__') and attr_name in attr.__dict__): + modules.update([ + _get_module_from_symbol(export) + for export in getattr(attr, attr_name)]) + return modules + + +def _get_files_set(path, start_tag, end_tag): + """Get set of file paths from the given file. + + Args: + path: Path to file. File at `path` is expected to contain a list of paths + where entire list starts with `start_tag` and ends with `end_tag`. List + must be comma-separated and each path entry must be surrounded by double + quotes. + start_tag: String that indicates start of path list. + end_tag: String that indicates end of path list. + + Returns: + List of string paths. + """ + with open(path, 'r') as f: + contents = f.read() + start = contents.find(start_tag) + len(start_tag) + 1 + end = contents.find(end_tag) + contents = contents[start:end] + file_paths = [ + file_path.strip().strip('"') for file_path in contents.split(',')] + return set(file_path for file_path in file_paths if file_path) + + +def _module_to_paths(module): + """Get all API __init__.py file paths for the given module. + + Args: + module: Module to get file paths for. + + Returns: + List of paths for the given module. For e.g. module foo.bar + requires 'foo/__init__.py' and 'foo/bar/__init__.py'. + """ + submodules = [] + module_segments = module.split('.') + for i in range(len(module_segments)): + submodules.append('.'.join(module_segments[:i+1])) + paths = [] + for submodule in submodules: + if not submodule: + paths.append('__init__.py') + continue + paths.append('%s/__init__.py' % (submodule.replace('.', '/'))) + return paths + + +class OutputInitFilesTest(test.TestCase): + """Test that verifies files that list paths for TensorFlow API.""" + + def _validate_paths_for_modules( + self, actual_paths, expected_paths, file_to_update_on_error): + """Validates that actual_paths match expected_paths. + + Args: + actual_paths: */__init__.py file paths listed in file_to_update_on_error. + expected_paths: */__init__.py file paths that we need to create for + TensorFlow API. + file_to_update_on_error: File that contains list of */__init__.py files. + We include it in error message printed if the file list needs to be + updated. + """ + self.assertTrue(actual_paths) + self.assertTrue(expected_paths) + missing_paths = expected_paths - actual_paths + extra_paths = actual_paths - expected_paths + + # Surround paths with quotes so that they can be copy-pasted + # from error messages as strings. + missing_paths = ['\'%s\'' % path for path in missing_paths] + extra_paths = ['\'%s\'' % path for path in extra_paths] + + self.assertFalse( + missing_paths, + 'Please add %s to %s.' % ( + ',\n'.join(sorted(missing_paths)), file_to_update_on_error)) + self.assertFalse( + extra_paths, + 'Redundant paths, please remove %s in %s.' % ( + ',\n'.join(sorted(extra_paths)), file_to_update_on_error)) + + def test_V2_init_files(self): + modules = _get_modules( + 'tensorflow', '_tf_api_names', '_tf_api_constants') + file_path = ( + 'tensorflow/python/tools/api/generator/api_init_files.bzl') + paths = _get_files_set( + file_path, '# BEGIN GENERATED FILES', '# END GENERATED FILES') + module_paths = set( + f for module in modules for f in _module_to_paths(module)) + self._validate_paths_for_modules( + paths, module_paths, file_to_update_on_error=file_path) + + def test_V1_init_files(self): + modules = _get_modules( + 'tensorflow', '_tf_api_names_v1', '_tf_api_constants_v1') + file_path = ( + 'tensorflow/python/tools/api/generator/' + 'api_init_files_v1.bzl') + paths = _get_files_set( + file_path, '# BEGIN GENERATED FILES', '# END GENERATED FILES') + module_paths = set( + f for module in modules for f in _module_to_paths(module)) + self._validate_paths_for_modules( + paths, module_paths, file_to_update_on_error=file_path) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py index e9f1def48c462dcd8a5acf0e3d29d562cd1b3d58..acf070075e3feafe053bcc9109d4137ad1acf44b 100644 --- a/tensorflow/python/tools/freeze_graph.py +++ b/tensorflow/python/tools/freeze_graph.py @@ -38,6 +38,7 @@ from __future__ import division from __future__ import print_function import argparse +import re import sys from google.protobuf import text_format @@ -54,9 +55,25 @@ from tensorflow.python.platform import gfile from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import tag_constants from tensorflow.python.tools import saved_model_utils +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib +def _has_variables(sess): + """Determines if the graph has any variables. + + Args: + sess: TensorFlow Session. + + Returns: + Bool. + """ + for op in sess.graph.get_operations(): + if op.type.startswith("Variable") or op.type.endswith("VariableOp"): + return False + return True + + def freeze_graph_with_def_protos(input_graph_def, input_saver_def, input_checkpoint, @@ -77,7 +94,7 @@ def freeze_graph_with_def_protos(input_graph_def, # 'input_checkpoint' may be a prefix if we're using Saver V2 format if (not input_saved_model_dir and - not saver_lib.checkpoint_exists(input_checkpoint)): + not checkpoint_management.checkpoint_exists(input_checkpoint)): print("Input checkpoint '" + input_checkpoint + "' doesn't exist!") return -1 @@ -116,16 +133,48 @@ def freeze_graph_with_def_protos(input_graph_def, var_list = {} reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint) var_to_shape_map = reader.get_variable_to_shape_map() + + # List of all partition variables. Because the condition is heuristic + # based, the list could include false positives. + all_parition_variable_names = [ + tensor.name.split(":")[0] + for op in sess.graph.get_operations() + for tensor in op.values() + if re.search(r"/part_\d+/", tensor.name) + ] + has_partition_var = False + for key in var_to_shape_map: try: tensor = sess.graph.get_tensor_by_name(key + ":0") + if any(key in name for name in all_parition_variable_names): + has_partition_var = True except KeyError: # This tensor doesn't exist in the graph (for example it's # 'global_step' or a similar housekeeping element) so skip it. continue var_list[key] = tensor - saver = saver_lib.Saver( - var_list=var_list, write_version=checkpoint_version) + + try: + saver = saver_lib.Saver( + var_list=var_list, write_version=checkpoint_version) + except TypeError as e: + # `var_list` is required to be a map of variable names to Variable + # tensors. Partition variables are Identity tensors that cannot be + # handled by Saver. + if has_partition_var: + print("Models containing partition variables cannot be converted " + "from checkpoint files. Please pass in a SavedModel using " + "the flag --input_saved_model_dir.") + return -1 + # Models that have been frozen previously do not contain Variables. + elif _has_variables(sess): + print("No variables were found in this model. It is likely the model " + "was frozen previously. You cannot freeze a graph twice.") + return 0 + else: + raise e + saver.restore(sess, input_checkpoint) if initializer_nodes: sess.run(initializer_nodes.replace(" ", "").split(",")) diff --git a/tensorflow/python/tools/freeze_graph_test.py b/tensorflow/python/tools/freeze_graph_test.py index 91f0061ebccaebbdbb09f283d9d52d813459f493..e38945fabccfb6a49643cb9d49cff385631e628f 100644 --- a/tensorflow/python/tools/freeze_graph_test.py +++ b/tensorflow/python/tools/freeze_graph_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os +import re from tensorflow.core.example import example_pb2 from tensorflow.core.framework import graph_pb2 @@ -31,7 +32,10 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.saved_model import builder as saved_model_builder @@ -262,6 +266,69 @@ class FreezeGraphTest(test_util.TensorFlowTestCase): output = sess.run(output_node, feed_dict={input_node: [example]}) self.assertNear(feature_value, output, 0.00001) + def testSinglePartitionedVariable(self): + """Ensures partitioned variables fail cleanly with freeze graph.""" + checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint") + checkpoint_state_name = "checkpoint_state" + input_graph_name = "input_graph.pb" + output_graph_name = "output_graph.pb" + + # Create a graph with partition variables. When weights are partitioned into + # a single partition, the weights variable is followed by a identity -> + # identity (an additional identity node). + partitioner = partitioned_variables.fixed_size_partitioner(1) + with ops.Graph().as_default(): + with variable_scope.variable_scope("part", partitioner=partitioner): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros( + (batch_size, height, width, depth), name="input1") + input2 = array_ops.zeros( + (batch_size, height, width, depth), name="input2") + + num_nodes = depth + filter1 = variable_scope.get_variable("filter", [num_nodes, num_nodes]) + filter2 = array_ops.reshape(filter1, [1, 1, num_nodes, num_nodes]) + conv = nn.conv2d( + input=input1, filter=filter2, strides=[1, 1, 1, 1], padding="SAME") + node = math_ops.add(conv, input2, name="test/add") + node = nn.relu6(node, name="test/relu6") + + # Save graph and checkpoints. + sess = session.Session() + sess.run(variables.global_variables_initializer()) + + saver = saver_lib.Saver() + checkpoint_path = saver.save( + sess, + checkpoint_prefix, + global_step=0, + latest_filename=checkpoint_state_name) + graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name) + + # Ensure this graph has partition variables. + self.assertTrue([ + tensor.name.split(":")[0] + for op in sess.graph.get_operations() + for tensor in op.values() + if re.search(r"/part_\d+/", tensor.name) + ]) + + # Test freezing graph doesn't make it crash. + output_node_names = "save/restore_all" + output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name) + + return_value = freeze_graph.freeze_graph_with_def_protos( + input_graph_def=sess.graph_def, + input_saver_def=None, + input_checkpoint=checkpoint_path, + output_node_names=output_node_names, + restore_op_name="save/restore_all", # default value + filename_tensor_name="save/Const:0", # default value + output_graph=output_graph_path, + clear_devices=False, + initializer_nodes="") + self.assertTrue(return_value, -1) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/tools/import_pb_to_tensorboard.py b/tensorflow/python/tools/import_pb_to_tensorboard.py index 00de044505f7f18e6af8237be57c4d8b346caa42..6d2fec3ad6ea193dd72bb29a5f5450f5356d4f1a 100644 --- a/tensorflow/python/tools/import_pb_to_tensorboard.py +++ b/tensorflow/python/tools/import_pb_to_tensorboard.py @@ -29,6 +29,16 @@ from tensorflow.python.platform import app from tensorflow.python.platform import gfile from tensorflow.python.summary import summary +# Try importing TensorRT ops if available +# TODO(aaroey): ideally we should import everything from contrib, but currently +# tensorrt module would cause build errors when being imported in +# tensorflow/contrib/__init__.py. Fix it. +# pylint: disable=unused-import,g-import-not-at-top,wildcard-import +try: + from tensorflow.contrib.tensorrt.ops.gen_trt_engine_op import * +except ImportError: + pass +# pylint: enable=unused-import,g-import-not-at-top,wildcard-import def import_to_tensorboard(model_dir, log_dir): """View an imported protobuf model (`.pb` file) as a graph in Tensorboard. diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 5b9d25d449d43d8420e0f30fa8b907d41171d5e5..38fed5335ef39e9832c8b47e3c872ada453aa645 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -15,7 +15,7 @@ """Command-line interface to inspect and execute a graph in a SavedModel. For detailed usages and examples, please refer to: -https://www.tensorflow.org/programmers_guide/saved_model_cli +https://www.tensorflow.org/guide/saved_model_cli """ @@ -720,7 +720,7 @@ def create_parser(): '\'input4_key=[{"id":[26],"weights":[0.5, 0.5]}]\' \\\n' ' --outdir=/out\n\n' 'For more information about input file format, please see:\n' - 'https://www.tensorflow.org/programmers_guide/saved_model_cli\n') + 'https://www.tensorflow.org/guide/saved_model_cli\n') parser_run = subparsers.add_parser( 'run', description=run_msg, formatter_class=argparse.RawTextHelpFormatter) parser_run.add_argument( diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py index 6778f3c735a70fc32ed299bc9d800b270f06cc66..07051d08b958100748f15357ae1e20a349cef727 100644 --- a/tensorflow/python/training/adagrad.py +++ b/tensorflow/python/training/adagrad.py @@ -70,17 +70,17 @@ class AdagradOptimizer(optimizer.Optimizer): def _create_slots(self, var_list): for v in var_list: - with ops.colocate_with(v): - dtype = v.dtype.base_dtype - if v.get_shape().is_fully_defined(): - init = init_ops.constant_initializer(self._initial_accumulator_value, - dtype=dtype) - else: + dtype = v.dtype.base_dtype + if v.get_shape().is_fully_defined(): + init = init_ops.constant_initializer(self._initial_accumulator_value, + dtype=dtype) + else: + def init(v=v, dtype=dtype): # Use a Tensor instead of initializer if variable does not have static # shape. init_constant = gen_array_ops.fill(array_ops.shape(v), self._initial_accumulator_value) - init = math_ops.cast(init_constant, dtype) + return math_ops.cast(init_constant, dtype) self._get_or_make_slot_with_initializer(v, init, v.get_shape(), dtype, "accumulator", self._name) diff --git a/tensorflow/python/training/adagrad_test.py b/tensorflow/python/training/adagrad_test.py index c9aec33d0916781e3d1a41b996083da92a4ae839..4e634fff84921f628ba864317028aca9b5f2113e 100644 --- a/tensorflow/python/training/adagrad_test.py +++ b/tensorflow/python/training/adagrad_test.py @@ -302,6 +302,39 @@ class AdagradOptimizerTest(test.TestCase): # Creating optimizer should cause no exception. adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1) + def testDynamicShapeVariableWithCallableInit(self): + var0 = variable_scope.get_variable("var0", + initializer=constant_op.constant(1.), + validate_shape=False) + self.assertFalse(var0.shape.is_fully_defined()) + + grads0 = constant_op.constant(0.1, dtype=dtypes.float32) + learning_rate = lambda: 3.0 + + ada_opt = adagrad.AdagradOptimizer( + learning_rate, initial_accumulator_value=0.1, use_locking=True) + + if not context.executing_eagerly(): + ada_update = ada_opt.apply_gradients( + zip([grads0], [var0])) + self.evaluate(variables.global_variables_initializer()) + + # Fetch params to validate initial values + v0_val = self.evaluate([var0]) + self.assertAllClose([1.0], v0_val) + + # Run 3 steps of adagrad + for _ in range(3): + if not context.executing_eagerly(): + self.evaluate(ada_update) + else: + ada_opt.apply_gradients(zip([grads0], [var0])) + + # Validate updated params + v0_val = self.evaluate([var0]) + self.assertAllCloseAccordingToType( + np.array([-1.6026098728179932]), v0_val) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py index b65c88e972454da14dc5161a19cd26280d51d28f..bcbe5907d6370e0c0a268c2ea6a2f10bdf30683e 100644 --- a/tensorflow/python/training/adam.py +++ b/tensorflow/python/training/adam.py @@ -109,12 +109,13 @@ class AdamOptimizer(optimizer.Optimizer): self._updated_lr = None def _get_beta_accumulators(self): - if context.executing_eagerly(): - graph = None - else: - graph = ops.get_default_graph() - return (self._get_non_slot_variable("beta1_power", graph=graph), - self._get_non_slot_variable("beta2_power", graph=graph)) + with ops.init_scope(): + if context.executing_eagerly(): + graph = None + else: + graph = ops.get_default_graph() + return (self._get_non_slot_variable("beta1_power", graph=graph), + self._get_non_slot_variable("beta2_power", graph=graph)) def _create_slots(self, var_list): # Create the beta1 and beta2 accumulators on the same device as the first diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py index ccdc7e384da2ae792a681298c7076fc582d362df..8f844276540bf5b5cc9a61a3cb072b3cfa9cfa7a 100644 --- a/tensorflow/python/training/adam_test.py +++ b/tensorflow/python/training/adam_test.py @@ -315,6 +315,12 @@ class AdamOptimizerTest(test.TestCase): def testTwoSessions(self): optimizer = adam.AdamOptimizer() + + with context.eager_mode(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + g = ops.Graph() with g.as_default(): with session.Session(): diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index b0dd188db14a46aae44f8150095cf9ed337ee8a7..76625624e40c04b58b376a98bce9e243a52ae80d 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -28,9 +28,12 @@ from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.core.protobuf import config_pb2 from tensorflow.core.util.event_pb2 import SessionLog from tensorflow.python.client import timeline +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import session_run_hook @@ -40,6 +43,10 @@ from tensorflow.python.training.summary_io import SummaryWriterCache from tensorflow.python.util.tf_export import tf_export +_HOOKS = "hooks" +_STEPS_PER_RUN_VAR = "steps_per_run" + + class _HookTimer(object): """Base timer for determining when Hooks should trigger. @@ -255,6 +262,116 @@ class LoggingTensorHook(session_run_hook.SessionRunHook): self._log_tensors(values) +def get_or_create_steps_per_run_variable(): + """Gets or creates the steps_per_run variable. + + In Estimator, the user provided computation, the model_fn, is wrapped + inside a tf.while_loop for peak performance. The iterations of the loop are + specified by this variable, which adjusts its value on the CPU after each + device program execution and before the next execution. + + The purpose of using a variable, rather than a constant, is to allow + Estimator adapt the device training iterations according to the final steps + specified by users. For example, if the user sets the steps_per_run as + 4 and steps as 10 in Estimator.train(), the steps_per_run + variable will have the following value before each training run. + + - 1-st execution: steps_per_run = 4 + - 2-nd execution: steps_per_run = 4 + - 3-rd execution: steps_per_run = 2 + + As model_fn increases the global step once per train_op invocation, the global + step is 10 after all executions, matching the steps=10 inputs passed in by + users. + + Returns: + A TF non-trainable resource variable. + + Raises: + RuntimeError: If multi steps_per_run variables were found. + """ + graph = ops.get_default_graph() + collection_name = "{}_{}".format(_HOOKS, _STEPS_PER_RUN_VAR) + steps_per_run_vars = graph.get_collection(collection_name) + if len(steps_per_run_vars) == 1: + return steps_per_run_vars[0] + elif len(steps_per_run_vars) > 1: + raise RuntimeError("Multiple steps_per_run_var in collection.") + + with variable_scope.variable_scope(_HOOKS, reuse=variable_scope.AUTO_REUSE): + return variable_scope.get_variable( + _STEPS_PER_RUN_VAR, + initializer=init_ops.ones_initializer(), + shape=[], + dtype=dtypes.int32, + trainable=False, + collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES], + use_resource=True) + + +class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook): + """Hook that requests stop at a specified step.""" + + def __init__(self, num_steps=None, last_step=None, steps_per_run=1): + """Initializes a `MultiStepStopAtStepHook`. + + This hook requests stop after either a number of steps have been + executed or a last step has been reached. Only one of the two options can be + specified. + + if `num_steps` is specified, it indicates the number of steps to execute + after `begin()` is called. If instead `last_step` is specified, it + indicates the last step we want to execute, as passed to the `after_run()` + call. + + In Estimator, the user provided computation, the model_fn, is wrapped + inside a tf.while_loop for peak performance. The steps_per_run variable + determines the number of iterations of the loop before returning to the CPU. + + Args: + num_steps: Number of steps to execute. + last_step: Step after which to stop. + steps_per_run: Number of steps executed per run call. + + Raises: + ValueError: If one of the arguments is invalid. + """ + if num_steps is None and last_step is None: + raise ValueError("One of num_steps or last_step must be specified.") + if num_steps is not None and last_step is not None: + raise ValueError("Only one of num_steps or last_step can be specified.") + if steps_per_run is None or steps_per_run < 1: + raise ValueError("steps_per_run should be greater than 0") + self._num_steps = num_steps + self._last_step = last_step + self._steps_per_run = steps_per_run + + def begin(self): + self._global_step_tensor = training_util.get_global_step() + if self._global_step_tensor is None: + raise RuntimeError("Global step should be created to use StopAtStepHook.") + self._steps_per_run_variable = get_or_create_steps_per_run_variable() + + def _update_steps_per_run_variable(self, global_step, session): + steps = min(self._last_step - global_step, self._steps_per_run) + self._steps_per_run_variable.load(steps, session=session) + + def after_create_session(self, session, coord): + global_step = session.run(self._global_step_tensor) + if self._last_step is None: + self._last_step = global_step + self._num_steps + self._update_steps_per_run_variable(global_step, session) + + def after_run(self, run_context, run_values): + # Global step cannot be retrieved via SessionRunArgs and before_run due to + # race condition in hook execution. + global_step = run_context.session.run(self._global_step_tensor) + if global_step >= self._last_step: + run_context.request_stop() + else: + self._update_steps_per_run_variable(global_step, run_context.session) + + @tf_export("train.StopAtStepHook") class StopAtStepHook(session_run_hook.SessionRunHook): """Hook that requests stop at a specified step.""" @@ -404,7 +521,7 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): Raises: ValueError: One of `save_steps` or `save_secs` should be set. - ValueError: At most one of saver or scaffold should be set. + ValueError: At most one of `saver` or `scaffold` should be set. """ logging.info("Create CheckpointSaverHook.") if saver is not None and scaffold is not None: diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py new file mode 100644 index 0000000000000000000000000000000000000000..85f29043188bdd8928be6965dd50183c2da874f9 --- /dev/null +++ b/tensorflow/python/training/checkpoint_management.py @@ -0,0 +1,681 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# pylint: disable=invalid-name +"""Save and restore variables.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import os.path +import re +import time + +from google.protobuf import text_format + +from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.eager import context +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import training_util +from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState +from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export + + +def _GetCheckpointFilename(save_dir, latest_filename): + """Returns a filename for storing the CheckpointState. + + Args: + save_dir: The directory for saving and restoring checkpoints. + latest_filename: Name of the file in 'save_dir' that is used + to store the CheckpointState. + + Returns: + The path of the file that contains the CheckpointState proto. + """ + if latest_filename is None: + latest_filename = "checkpoint" + return os.path.join(save_dir, latest_filename) + + +@tf_export("train.generate_checkpoint_state_proto") +def generate_checkpoint_state_proto(save_dir, + model_checkpoint_path, + all_model_checkpoint_paths=None, + all_model_checkpoint_timestamps=None, + last_preserved_timestamp=None): + """Generates a checkpoint state proto. + + Args: + save_dir: Directory where the model was saved. + model_checkpoint_path: The checkpoint file. + all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted + checkpoints, sorted from oldest to newest. If this is a non-empty list, + the last element must be equal to model_checkpoint_path. These paths + are also saved in the CheckpointState proto. + all_model_checkpoint_timestamps: A list of floats, indicating the number of + seconds since the Epoch when each checkpoint was generated. + last_preserved_timestamp: A float, indicating the number of seconds since + the Epoch when the last preserved checkpoint was written, e.g. due to a + `keep_checkpoint_every_n_hours` parameter (see + `tf.contrib.checkpoint.CheckpointManager` for an implementation). + Returns: + CheckpointState proto with model_checkpoint_path and + all_model_checkpoint_paths updated to either absolute paths or + relative paths to the current save_dir. + + Raises: + ValueError: If `all_model_checkpoint_timestamps` was provided but its length + does not match `all_model_checkpoint_paths`. + """ + if all_model_checkpoint_paths is None: + all_model_checkpoint_paths = [] + + if (not all_model_checkpoint_paths or + all_model_checkpoint_paths[-1] != model_checkpoint_path): + logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.", + model_checkpoint_path) + all_model_checkpoint_paths.append(model_checkpoint_path) + + if (all_model_checkpoint_timestamps + and (len(all_model_checkpoint_timestamps) + != len(all_model_checkpoint_paths))): + raise ValueError( + ("Checkpoint timestamps, if provided, must match checkpoint paths (got " + "paths %s and timestamps %s)") + % (all_model_checkpoint_paths, all_model_checkpoint_timestamps)) + + # Relative paths need to be rewritten to be relative to the "save_dir" + # if model_checkpoint_path already contains "save_dir". + if not os.path.isabs(save_dir): + if not os.path.isabs(model_checkpoint_path): + model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir) + for i in range(len(all_model_checkpoint_paths)): + p = all_model_checkpoint_paths[i] + if not os.path.isabs(p): + all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir) + + coord_checkpoint_proto = CheckpointState( + model_checkpoint_path=model_checkpoint_path, + all_model_checkpoint_paths=all_model_checkpoint_paths, + all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, + last_preserved_timestamp=last_preserved_timestamp) + + return coord_checkpoint_proto + + +@tf_export("train.update_checkpoint_state") +def update_checkpoint_state(save_dir, + model_checkpoint_path, + all_model_checkpoint_paths=None, + latest_filename=None, + all_model_checkpoint_timestamps=None, + last_preserved_timestamp=None): + """Updates the content of the 'checkpoint' file. + + This updates the checkpoint file containing a CheckpointState + proto. + + Args: + save_dir: Directory where the model was saved. + model_checkpoint_path: The checkpoint file. + all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted + checkpoints, sorted from oldest to newest. If this is a non-empty list, + the last element must be equal to model_checkpoint_path. These paths + are also saved in the CheckpointState proto. + latest_filename: Optional name of the checkpoint file. Default to + 'checkpoint'. + all_model_checkpoint_timestamps: Optional list of timestamps (floats, + seconds since the Epoch) indicating when the checkpoints in + `all_model_checkpoint_paths` were created. + last_preserved_timestamp: A float, indicating the number of seconds since + the Epoch when the last preserved checkpoint was written, e.g. due to a + `keep_checkpoint_every_n_hours` parameter (see + `tf.contrib.checkpoint.CheckpointManager` for an implementation). + Raises: + RuntimeError: If any of the model checkpoint paths conflict with the file + containing CheckpointSate. + """ + update_checkpoint_state_internal( + save_dir=save_dir, + model_checkpoint_path=model_checkpoint_path, + all_model_checkpoint_paths=all_model_checkpoint_paths, + latest_filename=latest_filename, + save_relative_paths=False, + all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, + last_preserved_timestamp=last_preserved_timestamp) + + +def update_checkpoint_state_internal(save_dir, + model_checkpoint_path, + all_model_checkpoint_paths=None, + latest_filename=None, + save_relative_paths=False, + all_model_checkpoint_timestamps=None, + last_preserved_timestamp=None): + """Updates the content of the 'checkpoint' file. + + This updates the checkpoint file containing a CheckpointState + proto. + + Args: + save_dir: Directory where the model was saved. + model_checkpoint_path: The checkpoint file. + all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted + checkpoints, sorted from oldest to newest. If this is a non-empty list, + the last element must be equal to model_checkpoint_path. These paths + are also saved in the CheckpointState proto. + latest_filename: Optional name of the checkpoint file. Default to + 'checkpoint'. + save_relative_paths: If `True`, will write relative paths to the checkpoint + state file. + all_model_checkpoint_timestamps: Optional list of timestamps (floats, + seconds since the Epoch) indicating when the checkpoints in + `all_model_checkpoint_paths` were created. + last_preserved_timestamp: A float, indicating the number of seconds since + the Epoch when the last preserved checkpoint was written, e.g. due to a + `keep_checkpoint_every_n_hours` parameter (see + `tf.contrib.checkpoint.CheckpointManager` for an implementation). + + Raises: + RuntimeError: If any of the model checkpoint paths conflict with the file + containing CheckpointSate. + """ + # Writes the "checkpoint" file for the coordinator for later restoration. + coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename) + if save_relative_paths: + if os.path.isabs(model_checkpoint_path): + rel_model_checkpoint_path = os.path.relpath( + model_checkpoint_path, save_dir) + else: + rel_model_checkpoint_path = model_checkpoint_path + rel_all_model_checkpoint_paths = [] + for p in all_model_checkpoint_paths: + if os.path.isabs(p): + rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir)) + else: + rel_all_model_checkpoint_paths.append(p) + ckpt = generate_checkpoint_state_proto( + save_dir, + rel_model_checkpoint_path, + all_model_checkpoint_paths=rel_all_model_checkpoint_paths, + all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, + last_preserved_timestamp=last_preserved_timestamp) + else: + ckpt = generate_checkpoint_state_proto( + save_dir, + model_checkpoint_path, + all_model_checkpoint_paths=all_model_checkpoint_paths, + all_model_checkpoint_timestamps=all_model_checkpoint_timestamps, + last_preserved_timestamp=last_preserved_timestamp) + + if coord_checkpoint_filename == ckpt.model_checkpoint_path: + raise RuntimeError("Save path '%s' conflicts with path used for " + "checkpoint state. Please use a different save path." % + model_checkpoint_path) + + # Preventing potential read/write race condition by *atomically* writing to a + # file. + file_io.atomic_write_string_to_file(coord_checkpoint_filename, + text_format.MessageToString(ckpt)) + + +@tf_export("train.get_checkpoint_state") +def get_checkpoint_state(checkpoint_dir, latest_filename=None): + """Returns CheckpointState proto from the "checkpoint" file. + + If the "checkpoint" file contains a valid CheckpointState + proto, returns it. + + Args: + checkpoint_dir: The directory of checkpoints. + latest_filename: Optional name of the checkpoint file. Default to + 'checkpoint'. + + Returns: + A CheckpointState if the state was available, None + otherwise. + + Raises: + ValueError: if the checkpoint read doesn't have model_checkpoint_path set. + """ + ckpt = None + coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir, + latest_filename) + f = None + try: + # Check that the file exists before opening it to avoid + # many lines of errors from colossus in the logs. + if file_io.file_exists(coord_checkpoint_filename): + file_content = file_io.read_file_to_string( + coord_checkpoint_filename) + ckpt = CheckpointState() + text_format.Merge(file_content, ckpt) + if not ckpt.model_checkpoint_path: + raise ValueError("Invalid checkpoint state loaded from " + + checkpoint_dir) + # For relative model_checkpoint_path and all_model_checkpoint_paths, + # prepend checkpoint_dir. + if not os.path.isabs(ckpt.model_checkpoint_path): + ckpt.model_checkpoint_path = os.path.join(checkpoint_dir, + ckpt.model_checkpoint_path) + for i in range(len(ckpt.all_model_checkpoint_paths)): + p = ckpt.all_model_checkpoint_paths[i] + if not os.path.isabs(p): + ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p) + except errors.OpError as e: + # It's ok if the file cannot be read + logging.warning("%s: %s", type(e).__name__, e) + logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) + return None + except text_format.ParseError as e: + logging.warning("%s: %s", type(e).__name__, e) + logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) + return None + finally: + if f: + f.close() + return ckpt + + +def _prefix_to_checkpoint_path(prefix, format_version): + """Returns the pathname of a checkpoint file, given the checkpoint prefix. + + For V1 checkpoint, simply returns the prefix itself (the data file). For V2, + returns the pathname to the index file. + + Args: + prefix: a string, the prefix of a checkpoint. + format_version: the checkpoint format version that corresponds to the + prefix. + Returns: + The pathname of a checkpoint file, taking into account the checkpoint + format version. + """ + if format_version == saver_pb2.SaverDef.V2: + return prefix + ".index" # The index file identifies a checkpoint. + return prefix # Just the data file. + + +@tf_export("train.latest_checkpoint") +def latest_checkpoint(checkpoint_dir, latest_filename=None): + """Finds the filename of latest saved checkpoint file. + + Args: + checkpoint_dir: Directory where the variables were saved. + latest_filename: Optional name for the protocol buffer file that + contains the list of most recent checkpoint filenames. + See the corresponding argument to `Saver.save()`. + + Returns: + The full path to the latest checkpoint or `None` if no checkpoint was found. + """ + # Pick the latest checkpoint based on checkpoint state. + ckpt = get_checkpoint_state(checkpoint_dir, latest_filename) + if ckpt and ckpt.model_checkpoint_path: + # Look for either a V2 path or a V1 path, with priority for V2. + v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path, + saver_pb2.SaverDef.V2) + v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path, + saver_pb2.SaverDef.V1) + if file_io.get_matching_files(v2_path) or file_io.get_matching_files( + v1_path): + return ckpt.model_checkpoint_path + else: + logging.error("Couldn't match files for checkpoint %s", + ckpt.model_checkpoint_path) + return None + + +@tf_export("train.checkpoint_exists") +def checkpoint_exists(checkpoint_prefix): + """Checks whether a V1 or V2 checkpoint exists with the specified prefix. + + This is the recommended way to check if a checkpoint exists, since it takes + into account the naming difference between V1 and V2 formats. + + Args: + checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking + priority. Typically the result of `Saver.save()` or that of + `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or + V1/V2. + Returns: + A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists. + """ + pathname = _prefix_to_checkpoint_path(checkpoint_prefix, + saver_pb2.SaverDef.V2) + if file_io.get_matching_files(pathname): + return True + elif file_io.get_matching_files(checkpoint_prefix): + return True + else: + return False + + +@tf_export("train.get_checkpoint_mtimes") +def get_checkpoint_mtimes(checkpoint_prefixes): + """Returns the mtimes (modification timestamps) of the checkpoints. + + Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files + exist, collect their mtime. Both V2 and V1 checkpoints are considered, in + that priority. + + This is the recommended way to get the mtimes, since it takes into account + the naming difference between V1 and V2 formats. + + Args: + checkpoint_prefixes: a list of checkpoint paths, typically the results of + `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of + sharded/non-sharded or V1/V2. + Returns: + A list of mtimes (in microseconds) of the found checkpoints. + """ + mtimes = [] + + def match_maybe_append(pathname): + fnames = file_io.get_matching_files(pathname) + if fnames: + mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9) + return True + return False + + for checkpoint_prefix in checkpoint_prefixes: + # Tries V2's metadata file first. + pathname = _prefix_to_checkpoint_path(checkpoint_prefix, + saver_pb2.SaverDef.V2) + if match_maybe_append(pathname): + continue + # Otherwise, tries V1, where the prefix is the complete pathname. + match_maybe_append(checkpoint_prefix) + + return mtimes + + +@tf_export("train.remove_checkpoint") +def remove_checkpoint(checkpoint_prefix, + checkpoint_format_version=saver_pb2.SaverDef.V2, + meta_graph_suffix="meta"): + """Removes a checkpoint given by `checkpoint_prefix`. + + Args: + checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result + of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of + sharded/non-sharded or V1/V2. + checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to + `SaverDef.V2`. + meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. + """ + _delete_file_if_exists( + meta_graph_filename(checkpoint_prefix, meta_graph_suffix)) + if checkpoint_format_version == saver_pb2.SaverDef.V2: + # V2 has a metadata file and some data files. + _delete_file_if_exists(checkpoint_prefix + ".index") + _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????") + else: + # V1, Legacy. Exact match on the data file. + _delete_file_if_exists(checkpoint_prefix) + + +def _delete_file_if_exists(filespec): + """Deletes files matching `filespec`.""" + for pathname in file_io.get_matching_files(filespec): + file_io.delete_file(pathname) + + +def meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"): + """Returns the meta graph filename. + + Args: + checkpoint_filename: Name of the checkpoint file. + meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. + + Returns: + MetaGraph file name. + """ + # If the checkpoint_filename is sharded, the checkpoint_filename could + # be of format model.ckpt-step#-?????-of-shard#. For example, + # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002. + basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename) + suffixed_filename = ".".join([basename, meta_graph_suffix]) + return suffixed_filename + + +# TODO(allenl): Allow tf.keras.Model instances in the constructor directly? +class CheckpointManager(object): + """Deletes old checkpoints. + + Example usage: + ```python + import tensorflow as tf + checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) + manager = tf.contrib.checkpoint.CheckpointManager( + checkpoint, directory="/tmp/model", max_to_keep=5) + status = checkpoint.restore(manager.latest_checkpoint) + while True: + # train + manager.save() + ``` + + `CheckpointManager` preserves its own state across instantiations (see the + `__init__` documentation for details). Only one should be active in a + particular directory at a time. + """ + + def __init__(self, checkpoint, directory, + max_to_keep, keep_checkpoint_every_n_hours=None): + """Configure a `CheckpointManager` for use in `directory`. + + If a `CheckpointManager` was previously used in `directory`, its + state will be restored. This includes the list of managed checkpoints and + the timestamp bookkeeping necessary to support + `keep_checkpoint_every_n_hours`. The behavior of the new `CheckpointManager` + will be the same as the previous `CheckpointManager`, including cleaning up + existing checkpoints if appropriate. + + Checkpoints are only considered for deletion just after a new checkpoint has + been added. At that point, `max_to_keep` checkpoints will remain in an + "active set". Once a checkpoint is preserved by + `keep_checkpoint_every_n_hours` it will not be deleted by this + `CheckpointManager` or any future `CheckpointManager` instantiated in + `directory` (regardless of the new setting of + `keep_checkpoint_every_n_hours`). The `max_to_keep` checkpoints in the + active set may be deleted by this `CheckpointManager` or a future + `CheckpointManager` instantiated in `directory` (subject to its + `max_to_keep` and `keep_checkpoint_every_n_hours` settings). + + Args: + checkpoint: The `tf.train.Checkpoint` instance to save and manage + checkpoints for. + directory: The path to a directory in which to write checkpoints. A + special file named "checkpoint" is also written to this directory (in a + human-readable text format) which contains the state of the + `CheckpointManager`. + max_to_keep: An integer, the number of checkpoints to keep. Unless + preserved by `keep_checkpoint_every_n_hours`, checkpoints will be + deleted from the active set, oldest first, until only `max_to_keep` + checkpoints remain. + keep_checkpoint_every_n_hours: Upon removal from the active set, a + checkpoint will be preserved if it has been at least + `keep_checkpoint_every_n_hours` since the last preserved checkpoint. The + default setting of `None` does not preserve any checkpoints in this way. + + Raises: + ValueError: If `max_to_keep` is not a positive integer. + """ + self._checkpoint = checkpoint + self._save_counter_assign = None + if not max_to_keep or max_to_keep < 0: + raise ValueError( + "Expected a positive integer for `max_to_max_to_keep`, got %d." + % (max_to_keep,)) + self._max_to_keep = max_to_keep + self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours + self._directory = directory + self._checkpoint_prefix = os.path.join(directory, "ckpt") + recovered_state = get_checkpoint_state(directory) + current_clock = time.time() + self._maybe_delete = collections.OrderedDict() + if recovered_state is None: + self._latest_checkpoint = None + self._last_preserved_timestamp = current_clock + else: + self._latest_checkpoint = recovered_state.model_checkpoint_path + self._last_preserved_timestamp = recovered_state.last_preserved_timestamp + if current_clock < self._last_preserved_timestamp: + # Time seems to have reversed itself. In addition to this warning, we'll + # min() saved checkpoint timestamps with the current time to ensure that + # old checkpoints don't get deleted accidentally. + logging.warning( + ("time.time() returned a value %f seconds behind the last " + "preserved checkpoint timestamp.") + % (self._last_preserved_timestamp - current_clock,)) + self._last_preserved_timestamp = current_clock + all_timestamps = recovered_state.all_model_checkpoint_timestamps + all_paths = recovered_state.all_model_checkpoint_paths + del recovered_state # Uses modified values from now on + if not all_timestamps: + all_timestamps = [self._last_preserved_timestamp] * len(all_paths) + + for filename, timestamp in zip(all_paths, all_timestamps): + timestamp = min(timestamp, current_clock) + if timestamp > self._last_preserved_timestamp: + self._maybe_delete[filename] = timestamp + + @property + def latest_checkpoint(self): + """The prefix of the most recent checkpoint in `directory`. + + Equivalent to `tf.train.latest_checkpoint(directory)` where `directory` is + the constructor argument to `CheckpointManager`. + + Suitable for passing to `tf.train.Checkpoint.restore` to resume training. + + Returns: + The checkpoint prefix. If there are no checkpoints, returns `None`. + """ + return self._latest_checkpoint + + @property + def checkpoints(self): + """A list of managed checkpoints. + + Note that checkpoints saved due to `keep_checkpoint_every_n_hours` will not + show up in this list (to avoid ever-growing filename lists). + + Returns: + A list of filenames, sorted from oldest to newest. + """ + return list(self._maybe_delete.keys()) + + def _sweep(self): + """Deletes or preserves managed checkpoints.""" + while len(self._maybe_delete) > self._max_to_keep: + filename, timestamp = self._maybe_delete.popitem(last=False) + # Even if we're keeping this checkpoint due to + # keep_checkpoint_every_n_hours, we won't reference it to avoid + # infinitely-growing CheckpointState protos. + if (self._keep_checkpoint_every_n_hours + and (timestamp - self._keep_checkpoint_every_n_hours * 3600. + >= self._last_preserved_timestamp)): + self._last_preserved_timestamp = timestamp + continue + remove_checkpoint(filename) + + def _record_state(self): + """Saves the `CheckpointManager`'s state in `directory`.""" + filenames, timestamps = zip(*self._maybe_delete.items()) + update_checkpoint_state_internal( + self._directory, + model_checkpoint_path=self.latest_checkpoint, + all_model_checkpoint_paths=filenames, + all_model_checkpoint_timestamps=timestamps, + last_preserved_timestamp=self._last_preserved_timestamp, + save_relative_paths=True) + + @property + def _prefix(self): + """A common prefix for all checkpoints saved with this manager. + + For example, if `directory` (a constructor argument) were `"/tmp/tf-model"`, + `prefix` would be `"/tmp/tf-model/ckpt"` and checkpoints would generally be + numbered `"/tmp/tf-model/ckpt-1"`, `"/tmp/tf-model/ckpt-2"`, and so on. Each + checkpoint has several associated files + (e.g. `"/tmp/tf-model/ckpt-2.index"`). + + Returns: + A string prefix. + """ + return self._checkpoint_prefix + + def save(self, session=None, checkpoint_number=None): + """Creates a new checkpoint and manages it. + + Args: + session: The session to evaluate variables in. Ignored when executing + eagerly. If not provided when graph building, the default session is + used. + checkpoint_number: An optional integer, or an integer-dtype `Variable` or + `Tensor`, used to number the checkpoint. If `None` (default), + checkpoints are numbered using `checkpoint.save_counter`. Even if + `checkpoint_number` is provided, `save_counter` is still incremented. A + user-provided `checkpoint_number` is not incremented even if it is a + `Variable`. + + Returns: + The path to the new checkpoint. It is also recorded in the `checkpoints` + and `latest_checkpoint` properies. + """ + # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge + # slightly with a custom numbering option. + if context.executing_eagerly(): + save_counter = self._checkpoint.save_counter + save_counter.assign_add(1) + else: + if session is None: + session = ops.get_default_session() + + def _initializing_creator(next_creator, **kwargs): + """Initialize the save counter if it has been newly created.""" + v = next_creator(**kwargs) + session.run(v.initializer) + return v + + with variable_scope.variable_creator_scope(_initializing_creator): + save_counter = self._checkpoint.save_counter + if self._save_counter_assign is None: + self._save_counter_assign = save_counter.assign_add(1, read_value=False) + session.run(self._save_counter_assign) + if checkpoint_number is None: + checkpoint_number = save_counter + if not isinstance(checkpoint_number, compat.integral_types): + checkpoint_number = training_util.global_step( + sess=session, global_step_tensor=checkpoint_number) + prefix = "%s-%d" % (self._prefix, checkpoint_number) + save_path = self._checkpoint.write(prefix) + timestamp = time.time() + # If this is an overwritten checkpoint we were previously tracking, delete + # and reinsert it to make sure it goes to the end of the queue. + if save_path in self._maybe_delete: + del self._maybe_delete[save_path] + self._maybe_delete[save_path] = timestamp + self._latest_checkpoint = save_path + self._sweep() + self._record_state() + return save_path diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1e2827d0a449571067e9d698c73740cc31dd1f77 --- /dev/null +++ b/tensorflow/python/training/checkpoint_management_test.py @@ -0,0 +1,517 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for tensorflow.python.training.saver.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import os +import shutil +import tempfile + +from google.protobuf import text_format + +from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.framework import test_util +from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpoint_management +from tensorflow.python.training import saver as saver_module +from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState +from tensorflow.python.training.checkpointable import util + + +class LatestCheckpointWithRelativePaths(test.TestCase): + + @staticmethod + @contextlib.contextmanager + def tempWorkingDir(temppath): + cwd = os.getcwd() + os.chdir(temppath) + try: + yield + finally: + os.chdir(cwd) + + @staticmethod + @contextlib.contextmanager + def tempDir(): + tempdir = tempfile.mkdtemp() + try: + yield tempdir + finally: + shutil.rmtree(tempdir) + + def testNameCollision(self): + # Make sure we have a clean directory to work in. + with self.tempDir() as tempdir: + # Jump to that directory until this test is done. + with self.tempWorkingDir(tempdir): + # Save training snapshots to a relative path. + traindir = "train/" + os.mkdir(traindir) + # Collides with the default name of the checkpoint state file. + filepath = os.path.join(traindir, "checkpoint") + + with self.test_session() as sess: + unused_a = variables.Variable(0.0) # So that Saver saves something. + variables.global_variables_initializer().run() + + # Should fail. + saver = saver_module.Saver(sharded=False) + with self.assertRaisesRegexp(ValueError, "collides with"): + saver.save(sess, filepath) + + # Succeeds: the file will be named "checkpoint-". + saver.save(sess, filepath, global_step=1) + self.assertIsNotNone( + checkpoint_management.latest_checkpoint(traindir)) + + # Succeeds: the file will be named "checkpoint--of-". + saver = saver_module.Saver(sharded=True) + saver.save(sess, filepath) + self.assertIsNotNone( + checkpoint_management.latest_checkpoint(traindir)) + + # Succeeds: the file will be named "checkpoint---of-". + saver = saver_module.Saver(sharded=True) + saver.save(sess, filepath, global_step=1) + self.assertIsNotNone( + checkpoint_management.latest_checkpoint(traindir)) + + def testRelativePath(self): + # Make sure we have a clean directory to work in. + with self.tempDir() as tempdir: + + # Jump to that directory until this test is done. + with self.tempWorkingDir(tempdir): + + # Save training snapshots to a relative path. + traindir = "train/" + os.mkdir(traindir) + + filename = "snapshot" + filepath = os.path.join(traindir, filename) + + with self.test_session() as sess: + # Build a simple graph. + v0 = variables.Variable(0.0) + inc = v0.assign_add(1.0) + + save = saver_module.Saver({"v0": v0}) + + # Record a short training history. + variables.global_variables_initializer().run() + save.save(sess, filepath, global_step=0) + inc.eval() + save.save(sess, filepath, global_step=1) + inc.eval() + save.save(sess, filepath, global_step=2) + + with self.test_session() as sess: + # Build a new graph with different initialization. + v0 = variables.Variable(-1.0) + + # Create a new saver. + save = saver_module.Saver({"v0": v0}) + variables.global_variables_initializer().run() + + # Get the most recent checkpoint name from the training history file. + name = checkpoint_management.latest_checkpoint(traindir) + self.assertIsNotNone(name) + + # Restore "v0" from that checkpoint. + save.restore(sess, name) + self.assertEqual(v0.eval(), 2.0) + + +class CheckpointStateTest(test.TestCase): + + def _get_test_dir(self, dirname): + test_dir = os.path.join(self.get_temp_dir(), dirname) + gfile.MakeDirs(test_dir) + return test_dir + + def testAbsPath(self): + save_dir = self._get_test_dir("abs_paths") + abs_path = os.path.join(save_dir, "model-0") + ckpt = checkpoint_management.generate_checkpoint_state_proto( + save_dir, abs_path) + self.assertEqual(ckpt.model_checkpoint_path, abs_path) + self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path)) + self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1) + self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path) + + def testRelPath(self): + train_dir = "train" + model = os.path.join(train_dir, "model-0") + # model_checkpoint_path should have no "train" directory part. + new_rel_path = "model-0" + ckpt = checkpoint_management.generate_checkpoint_state_proto( + train_dir, model) + self.assertEqual(ckpt.model_checkpoint_path, new_rel_path) + self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1) + self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path) + + def testAllModelCheckpointPaths(self): + save_dir = self._get_test_dir("all_models_test") + abs_path = os.path.join(save_dir, "model-0") + for paths in [None, [], ["model-2"]]: + ckpt = checkpoint_management.generate_checkpoint_state_proto( + save_dir, abs_path, all_model_checkpoint_paths=paths) + self.assertEqual(ckpt.model_checkpoint_path, abs_path) + self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path)) + self.assertEqual( + len(ckpt.all_model_checkpoint_paths), len(paths) if paths else 1) + self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path) + + def testUpdateCheckpointState(self): + save_dir = self._get_test_dir("update_checkpoint_state") + os.chdir(save_dir) + # Make a temporary train directory. + train_dir = "train" + os.mkdir(train_dir) + abs_path = os.path.join(save_dir, "model-0") + rel_path = os.path.join("train", "model-2") + checkpoint_management.update_checkpoint_state( + train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path]) + ckpt = checkpoint_management.get_checkpoint_state(train_dir) + self.assertEqual(ckpt.model_checkpoint_path, rel_path) + self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2) + self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path) + self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path) + + def testUpdateCheckpointStateSaveRelativePaths(self): + save_dir = self._get_test_dir("update_checkpoint_state") + os.chdir(save_dir) + abs_path2 = os.path.join(save_dir, "model-2") + rel_path2 = "model-2" + abs_path0 = os.path.join(save_dir, "model-0") + rel_path0 = "model-0" + checkpoint_management.update_checkpoint_state_internal( + save_dir=save_dir, + model_checkpoint_path=abs_path2, + all_model_checkpoint_paths=[rel_path0, abs_path2], + save_relative_paths=True) + + # File should contain relative paths. + file_content = file_io.read_file_to_string( + os.path.join(save_dir, "checkpoint")) + ckpt = CheckpointState() + text_format.Merge(file_content, ckpt) + self.assertEqual(ckpt.model_checkpoint_path, rel_path2) + self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2) + self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2) + self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0) + + # get_checkpoint_state should return absolute paths. + ckpt = checkpoint_management.get_checkpoint_state(save_dir) + self.assertEqual(ckpt.model_checkpoint_path, abs_path2) + self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2) + self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2) + self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0) + + def testCheckPointStateFailsWhenIncomplete(self): + save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete") + os.chdir(save_dir) + ckpt_path = os.path.join(save_dir, "checkpoint") + ckpt_file = open(ckpt_path, "w") + ckpt_file.write("") + ckpt_file.close() + with self.assertRaises(ValueError): + checkpoint_management.get_checkpoint_state(save_dir) + + def testCheckPointCompletesRelativePaths(self): + save_dir = self._get_test_dir("checkpoint_completes_relative_paths") + os.chdir(save_dir) + ckpt_path = os.path.join(save_dir, "checkpoint") + ckpt_file = open(ckpt_path, "w") + ckpt_file.write(""" + model_checkpoint_path: "./model.ckpt-687529" + all_model_checkpoint_paths: "./model.ckpt-687500" + all_model_checkpoint_paths: "./model.ckpt-687529" + """) + ckpt_file.close() + ckpt = checkpoint_management.get_checkpoint_state(save_dir) + self.assertEqual(ckpt.model_checkpoint_path, + os.path.join(save_dir, "./model.ckpt-687529")) + self.assertEqual(ckpt.all_model_checkpoint_paths[0], + os.path.join(save_dir, "./model.ckpt-687500")) + self.assertEqual(ckpt.all_model_checkpoint_paths[1], + os.path.join(save_dir, "./model.ckpt-687529")) + + +class SaverUtilsTest(test.TestCase): + + def setUp(self): + self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test") + gfile.MakeDirs(self._base_dir) + + def tearDown(self): + gfile.DeleteRecursively(self._base_dir) + + def testCheckpointExists(self): + for sharded in (False, True): + for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): + with self.test_session(graph=ops_lib.Graph()) as sess: + unused_v = variables.Variable(1.0, name="v") + variables.global_variables_initializer().run() + saver = saver_module.Saver(sharded=sharded, write_version=version) + + path = os.path.join(self._base_dir, "%s-%s" % (sharded, version)) + self.assertFalse( + checkpoint_management.checkpoint_exists(path)) # Not saved yet. + + ckpt_prefix = saver.save(sess, path) + self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix)) + + ckpt_prefix = checkpoint_management.latest_checkpoint(self._base_dir) + self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix)) + + def testGetCheckpointMtimes(self): + prefixes = [] + for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): + with self.test_session(graph=ops_lib.Graph()) as sess: + unused_v = variables.Variable(1.0, name="v") + variables.global_variables_initializer().run() + saver = saver_module.Saver(write_version=version) + prefixes.append( + saver.save(sess, os.path.join(self._base_dir, str(version)))) + + mtimes = checkpoint_management.get_checkpoint_mtimes(prefixes) + self.assertEqual(2, len(mtimes)) + self.assertTrue(mtimes[1] >= mtimes[0]) + + def testRemoveCheckpoint(self): + for sharded in (False, True): + for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): + with self.test_session(graph=ops_lib.Graph()) as sess: + unused_v = variables.Variable(1.0, name="v") + variables.global_variables_initializer().run() + saver = saver_module.Saver(sharded=sharded, write_version=version) + + path = os.path.join(self._base_dir, "%s-%s" % (sharded, version)) + ckpt_prefix = saver.save(sess, path) + self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix)) + checkpoint_management.remove_checkpoint(ckpt_prefix, version) + self.assertFalse(checkpoint_management.checkpoint_exists(ckpt_prefix)) + + +class CheckpointManagerTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes + def testDeletion(self): + checkpoint = util.Checkpoint() + manager = checkpoint_management.CheckpointManager( + checkpoint, self.get_temp_dir(), max_to_keep=3) + first_path = manager.save() + second_path = manager.save() + third_path = manager.save() + fourth_path = manager.save() + self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(third_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(second_path)) + self.assertFalse(checkpoint_management.checkpoint_exists(first_path)) + + @test_util.run_in_graph_and_eager_modes + @test.mock.patch.object(checkpoint_management, "time") + def testSaveRestoreState(self, mock_time): + directory = self.get_temp_dir() + mock_time.time.return_value = 3. + checkpoint = util.Checkpoint() + first_manager = checkpoint_management.CheckpointManager( + checkpoint, directory, max_to_keep=2) + first_time = 10000. + first_name = os.path.join(directory, "ckpt-1") + mock_time.time.return_value = first_time + first_manager.save() + state = checkpoint_management.get_checkpoint_state(directory) + self.assertEqual([first_time], state.all_model_checkpoint_timestamps) + self.assertEqual(3., state.last_preserved_timestamp) + second_time = first_time + 3610. + second_name = os.path.join(directory, "ckpt-2") + mock_time.time.return_value = second_time + first_manager.save() + state = checkpoint_management.get_checkpoint_state(directory) + self.assertEqual([first_time, second_time], + state.all_model_checkpoint_timestamps) + self.assertEqual(3., state.last_preserved_timestamp) + self.assertEqual([first_name, second_name], first_manager.checkpoints) + self.assertEqual(second_name, first_manager.latest_checkpoint) + del first_manager + + second_manager = checkpoint_management.CheckpointManager( + checkpoint, directory, + max_to_keep=2, keep_checkpoint_every_n_hours=1.5) + self.assertEqual([first_name, second_name], second_manager.checkpoints) + self.assertEqual(second_name, second_manager.latest_checkpoint) + third_name = os.path.join(directory, "ckpt-3") + third_time = second_time + 3600. * 0.2 + mock_time.time.return_value = third_time + second_manager.save() + self.assertTrue(checkpoint_management.checkpoint_exists(first_name)) + self.assertTrue(checkpoint_management.checkpoint_exists(second_name)) + self.assertEqual([second_name, third_name], + second_manager.checkpoints) + state = checkpoint_management.get_checkpoint_state(directory) + self.assertEqual(first_time, state.last_preserved_timestamp) + fourth_time = third_time + 3600. * 0.5 + mock_time.time.return_value = fourth_time + fourth_name = os.path.join(directory, "ckpt-4") + second_manager.save() + self.assertTrue(checkpoint_management.checkpoint_exists(first_name)) + self.assertFalse(checkpoint_management.checkpoint_exists(second_name)) + self.assertEqual([third_name, fourth_name], + second_manager.checkpoints) + fifth_time = fourth_time + 3600. * 0.5 + mock_time.time.return_value = fifth_time + fifth_name = os.path.join(directory, "ckpt-5") + second_manager.save() + self.assertEqual([fourth_name, fifth_name], + second_manager.checkpoints) + state = checkpoint_management.get_checkpoint_state(directory) + self.assertEqual(first_time, state.last_preserved_timestamp) + del second_manager + third_manager = checkpoint_management.CheckpointManager( + checkpoint, directory, + max_to_keep=2, keep_checkpoint_every_n_hours=1.5) + self.assertEqual(fifth_name, third_manager.latest_checkpoint) + mock_time.time.return_value += 10. + third_manager.save() + sixth_name = os.path.join(directory, "ckpt-6") + state = checkpoint_management.get_checkpoint_state(directory) + self.assertEqual(fourth_time, state.last_preserved_timestamp) + self.assertTrue(checkpoint_management.checkpoint_exists(first_name)) + self.assertTrue(checkpoint_management.checkpoint_exists(fourth_name)) + self.assertTrue(checkpoint_management.checkpoint_exists(fifth_name)) + self.assertTrue(checkpoint_management.checkpoint_exists(sixth_name)) + self.assertFalse(checkpoint_management.checkpoint_exists(second_name)) + self.assertFalse(checkpoint_management.checkpoint_exists(third_name)) + self.assertEqual([fifth_name, sixth_name], + third_manager.checkpoints) + + @test_util.run_in_graph_and_eager_modes + def testContinueFromUnmanaged(self): + directory = self.get_temp_dir() + prefix = os.path.join(directory, "unusual_prefix") + checkpoint = util.Checkpoint() + first_path = checkpoint.save(prefix) + second_path = checkpoint.save(prefix) + del checkpoint + checkpoint = util.Checkpoint() + manager = checkpoint_management.CheckpointManager( + checkpoint, directory, max_to_keep=2) + checkpoint.restore(manager.latest_checkpoint).run_restore_ops() + self.assertEqual(2, self.evaluate(checkpoint.save_counter)) + third_path = manager.save() + self.assertEqual([third_path], manager.checkpoints) + fourth_path = manager.save() + self.assertEqual([third_path, fourth_path], + manager.checkpoints) + fifth_path = manager.save() + self.assertEqual([fourth_path, fifth_path], + manager.checkpoints) + self.assertTrue(checkpoint_management.checkpoint_exists(first_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(second_path)) + self.assertFalse(checkpoint_management.checkpoint_exists(third_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path)) + + @test_util.run_in_graph_and_eager_modes + @test.mock.patch.object(checkpoint_management, "time") + def testClockReset(self, mock_time): + directory = self.get_temp_dir() + mock_time.time.return_value = 10000. + checkpoint = util.Checkpoint() + first_manager = checkpoint_management.CheckpointManager( + checkpoint, directory, max_to_keep=1, keep_checkpoint_every_n_hours=1.) + first_path = first_manager.save() + mock_time.time.return_value += 3600. + second_path = first_manager.save() + mock_time.time.return_value += 3600. + third_path = first_manager.save() + self.assertFalse(checkpoint_management.checkpoint_exists(first_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(second_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(third_path)) + self.assertEqual([third_path], first_manager.checkpoints) + state = checkpoint_management.get_checkpoint_state(directory) + self.assertEqual(13600., state.last_preserved_timestamp) + # Set the clock back in time + mock_time.time.return_value = 5000. + del first_manager + with test.mock.patch.object(logging, "warning") as mock_log: + second_manager = checkpoint_management.CheckpointManager( + checkpoint, directory, max_to_keep=1) + self.assertRegexpMatches( + str(mock_log.call_args), + "behind the last preserved checkpoint timestamp") + # We should err on the side of keeping checkpoints around when we're not + # sure whether they were preserved or not due to clock funkiness. + self.assertTrue(checkpoint_management.checkpoint_exists(second_path)) + # We know about the existing checkpoints, but they'll never be deleted and + # so won't go in the CheckpointState proto on save. + self.assertEqual(third_path, second_manager.latest_checkpoint) + self.assertEqual([], second_manager.checkpoints) + mock_time.time.return_value += 10. + fourth_path = second_manager.save() + self.assertTrue(checkpoint_management.checkpoint_exists(second_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(third_path)) + self.assertEqual(fourth_path, second_manager.latest_checkpoint) + self.assertEqual([fourth_path], second_manager.checkpoints) + mock_time.time.return_value += 10. + fifth_path = second_manager.save() + self.assertTrue(checkpoint_management.checkpoint_exists(second_path)) + self.assertTrue(checkpoint_management.checkpoint_exists(third_path)) + self.assertEqual([fifth_path], second_manager.checkpoints) + state = checkpoint_management.get_checkpoint_state(directory) + self.assertEqual(5000., state.last_preserved_timestamp) + self.assertEqual([5020.], + state.all_model_checkpoint_timestamps) + + @test_util.run_in_graph_and_eager_modes + def testCustomNumbering(self): + directory = self.get_temp_dir() + step = variables.Variable(0, dtype=dtypes.int64) + checkpoint = util.Checkpoint(step=step) + manager = checkpoint_management.CheckpointManager( + checkpoint, directory, max_to_keep=2) + self.evaluate(step.initializer) + for i in range(5): + path = manager.save(checkpoint_number=step) + expected_suffix = "-%d" % (2 * i,) + if not path.endswith(expected_suffix): + self.fail("%s should have suffix %s" % (path, expected_suffix)) + self.evaluate(step.assign_add(2)) + self.assertEqual(5, self.evaluate(checkpoint.save_counter)) + # Test regular integers + last_path = manager.save(checkpoint_number=32) + self.assertIn("-32", last_path) + self.assertEqual(last_path, manager.latest_checkpoint) + self.assertEqual( + last_path, checkpoint_management.latest_checkpoint(directory)) + state = checkpoint_management.get_checkpoint_state(directory) + # Only the most recent two checkpoints are saved + self.assertEqual([path, last_path], state.all_model_checkpoint_paths) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/training/checkpoint_state.proto b/tensorflow/python/training/checkpoint_state.proto index 9172a5c33142568f478ab203f9736516eadf250f..704f7fdc88da850f8cb0c45f3b5f7e5acbaf4138 100644 --- a/tensorflow/python/training/checkpoint_state.proto +++ b/tensorflow/python/training/checkpoint_state.proto @@ -4,8 +4,6 @@ package tensorflow; option cc_enable_arenas = true; // Protocol buffer representing the checkpoint state. -// -// TODO(touts): Add other attributes as needed. message CheckpointState { // Path to the most-recent model checkpoint. string model_checkpoint_path = 1; @@ -15,4 +13,10 @@ message CheckpointState { // Note that the value of model_checkpoint_path should be the last item in // this list. repeated string all_model_checkpoint_paths = 2; + // Unix timestamps corresponding to all_model_checkpoint_paths, indicating + // when each checkpoint was created. + repeated double all_model_checkpoint_timestamps = 3; + // Unix timestamp indicating the creation time for the last preserved + // checkpoint. + double last_preserved_timestamp = 4; } diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py index e7f88de1d2290a49f3b7bdf47417016d7e7c9cea..e6118177fd1004b0f6f807666302289de6b7d2f6 100644 --- a/tensorflow/python/training/checkpoint_utils.py +++ b/tensorflow/python/training/checkpoint_utils.py @@ -24,11 +24,12 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpoint_management +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import saver from tensorflow.python.util.tf_export import tf_export @@ -147,7 +148,7 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map): partitioner=lambda shape, dtype: [5, 1]) # Initialize all variables in `new_scope_1` from `old_scope_1`. - init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/', 'new_scope_1'}) + init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1'}) # Use names to specify which variables to initialize from checkpoint. init_from_checkpoint('/tmp/model.ckpt', @@ -179,6 +180,16 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map): tf.errors.OpError: If missing checkpoints or tensors in checkpoints. ValueError: If missing variables in current graph. """ + if distribution_strategy_context.get_cross_tower_context(): + _init_from_checkpoint(None, ckpt_dir_or_file, assignment_map) + else: + distribution_strategy_context.get_tower_context().merge_call( + _init_from_checkpoint, ckpt_dir_or_file, assignment_map) + + +def _init_from_checkpoint(_, ckpt_dir_or_file, assignment_map): + """See `init_from_checkpoint` for documentation.""" + ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file) reader = load_checkpoint(ckpt_dir_or_file) variable_map = reader.get_variable_to_shape_map() @@ -187,10 +198,9 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map): var = None # Check if this is Variable object or list of Variable objects (in case of # partitioned variables). - is_var = lambda x: isinstance(x, variables.Variable) - if is_var(current_var_or_name) or ( + if _is_variable(current_var_or_name) or ( isinstance(current_var_or_name, list) - and all(is_var(v) for v in current_var_or_name)): + and all(_is_variable(v) for v in current_var_or_name)): var = current_var_or_name else: store_vars = vs._get_default_variable_store()._vars # pylint:disable=protected-access @@ -205,7 +215,7 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map): raise ValueError("Tensor %s is not found in %s checkpoint %s" % ( tensor_name_in_ckpt, ckpt_dir_or_file, variable_map )) - if is_var(var): + if _is_variable(var): # Additional at-call-time checks. if not var.get_shape().is_compatible_with( variable_map[tensor_name_in_ckpt]): @@ -219,8 +229,8 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map): else: var_name = ",".join([v.name for v in var]) _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt) - logging.info("Initialize variable %s from checkpoint %s with %s", - var_name, ckpt_dir_or_file, tensor_name_in_ckpt) + logging.debug("Initialize variable %s from checkpoint %s with %s", + var_name, ckpt_dir_or_file, tensor_name_in_ckpt) else: scopes = "" # TODO(vihanjain): Support list of 'current_var_or_name' here. @@ -261,14 +271,14 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map): if var is None: var = _collect_partitioned_variable(var_name, store_vars) _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name) - logging.info("Initialize variable %s from checkpoint %s with %s", - var_name, ckpt_dir_or_file, full_tensor_name) + logging.debug("Initialize variable %s from checkpoint %s with %s", + var_name, ckpt_dir_or_file, full_tensor_name) def _get_checkpoint_filename(ckpt_dir_or_file): """Returns checkpoint filename given directory or specific checkpoint file.""" if gfile.IsDirectory(ckpt_dir_or_file): - return saver.latest_checkpoint(ckpt_dir_or_file) + return checkpoint_management.latest_checkpoint(ckpt_dir_or_file) return ckpt_dir_or_file @@ -297,13 +307,21 @@ def _set_checkpoint_initializer(variable, with ops.device(variable.device), ops.device("/cpu:0"): restore_op = io_ops.restore_v2( ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0] - if isinstance(variable, resource_variable_ops.ResourceVariable): - init_op = variable.assign(restore_op, read_value=False) - else: - init_op = state_ops.assign(variable, restore_op) - variable._initializer_op = init_op # pylint:disable=protected-access + + names_to_saveables = saver.BaseSaverBuilder.OpListToDict([variable]) + saveable_objects = [] + for name, op in names_to_saveables.items(): + for s in saver.BaseSaverBuilder.SaveableObjectsForOp(op, name): + saveable_objects.append(s) + + assert len(saveable_objects) == 1 # Should be only one variable. + init_op = saveable_objects[0].restore([restore_op], restored_shapes=None) + + # pylint:disable=protected-access + variable._initializer_op = init_op restore_op.set_shape(variable.shape) - variable._initial_value = restore_op # pylint:disable=protected-access + variable._initial_value = restore_op + # pylint:enable=protected-access def _set_variable_or_list_initializer(variable_or_list, ckpt_file, @@ -337,6 +355,11 @@ def _set_variable_or_list_initializer(variable_or_list, ckpt_file, _set_checkpoint_initializer(variable_or_list, ckpt_file, tensor_name, "") +def _is_variable(x): + return (isinstance(x, variables.Variable) or + resource_variable_ops.is_resource_variable(x)) + + def _collect_partitioned_variable(name, all_vars): """Returns list of `tf.Variable` that comprise the partitioned variable.""" if name + "/part_0" in all_vars: diff --git a/tensorflow/python/training/checkpoint_utils_test.py b/tensorflow/python/training/checkpoint_utils_test.py index 4e08a1c859fbaac75e7cd09ad498d9fea14c6338..1c1f126ce94614ae02d701a4a15c8956e3dd4f2f 100644 --- a/tensorflow/python/training/checkpoint_utils_test.py +++ b/tensorflow/python/training/checkpoint_utils_test.py @@ -386,7 +386,9 @@ class CheckpointsTest(test.TestCase): op for op in g.get_operations() if (op.name.startswith("init_from_checkpoint/") and not op.name.startswith("init_from_checkpoint/checkpoint_initializer" - ) and op.type != "AssignVariableOp") + ) and + op.type != "AssignVariableOp" and + op.type != "Identity") ] self.assertEqual(ops_in_init_from_checkpoint_scope, []) diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD index 87ba4dc91c89e03ac5f2a93bedca81878f5254a6..6e9b8ff9050647386c57b122df60ec68a325a6f3 100644 --- a/tensorflow/python/training/checkpointable/BUILD +++ b/tensorflow/python/training/checkpointable/BUILD @@ -42,21 +42,39 @@ py_test( ) py_library( - name = "data_structures_base", - srcs = ["data_structures_base.py"], + name = "tracking", + srcs = ["tracking.py"], srcs_version = "PY2AND3", deps = [ ":base", + ":data_structures", + ], +) + +py_test( + name = "tracking_test", + srcs = ["tracking_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":base", + ":tracking", + "//tensorflow/python:client_testlib", ], ) +py_library( + name = "layer_utils", + srcs = ["layer_utils.py"], + srcs_version = "PY2AND3", +) + py_library( name = "data_structures", srcs = ["data_structures.py"], srcs_version = "PY2AND3", deps = [ ":base", - ":data_structures_base", + ":layer_utils", ], ) @@ -83,6 +101,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":base", + ":tracking", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", @@ -99,20 +118,21 @@ py_test( name = "util_test", srcs = ["util_test.py"], srcs_version = "PY2AND3", - tags = [ - "no_windows", # TODO: needs investigation on Windows - "notsan", # b/74395663 - ], + tags = ["notsan"], # b/74395663 deps = [ ":base", + ":tracking", ":util", + "//tensorflow/python:checkpoint_management", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:init_ops", + "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:saver", "//tensorflow/python:session", "//tensorflow/python:state_ops", "//tensorflow/python:template", diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py index cfe7259e1b6d9932fff9e78049fa85554f022076..390434c0a2511964e0187eff34fcc506f35530e6 100644 --- a/tensorflow/python/training/checkpointable/base.py +++ b/tensorflow/python/training/checkpointable/base.py @@ -33,6 +33,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saveable_object from tensorflow.python.util import nest from tensorflow.python.util import serialization +from tensorflow.python.util import tf_decorator # Key where the object graph proto is saved in a TensorBundle @@ -78,10 +79,6 @@ class CheckpointInitialValue(ops.Tensor): self.wrapped_value.set_shape(shape) self._checkpoint_position = checkpoint_position - @property - def __class__(self): - return (self.wrapped_value.__class__, CheckpointInitialValue) - def __getattr__(self, attr): try: return getattr(self.wrapped_value, attr) @@ -143,7 +140,7 @@ class _CheckpointPosition(object): # process deferred restorations for it and its dependencies. restore_ops = checkpointable._restore_from_checkpoint_position(self) # pylint: disable=protected-access if restore_ops: - self._checkpoint.restore_ops.extend(restore_ops) + self._checkpoint.new_restore_ops(restore_ops) def bind_object(self, checkpointable): """Set a checkpoint<->object correspondence and process slot variables. @@ -340,6 +337,34 @@ _SlotVariableRestoration = collections.namedtuple( ]) +def no_automatic_dependency_tracking(method): + """Disables automatic dependency tracking on attribute assignment. + + Use to decorate any method of a Checkpointable object. Attribute assignment in + that method will not add dependencies (also respected in Model). Harmless if + used in a class which does not do automatic dependency tracking (which means + it's safe to use in base classes which may have subclasses which also inherit + from Checkpointable). + + Args: + method: The method to decorate. + Returns: + A decorated method which sets and un-sets automatic dependency tracking for + the object the method is called on (not thread safe). + """ + + def _method_wrapper(self, *args, **kwargs): + previous_value = getattr(self, "_setattr_tracking", True) + self._setattr_tracking = False # pylint: disable=protected-access + try: + method(self, *args, **kwargs) + finally: + self._setattr_tracking = previous_value # pylint: disable=protected-access + + return tf_decorator.make_decorator( + target=method, decorator_func=_method_wrapper) + + class CheckpointableBase(object): """Base class for `Checkpointable` objects without automatic dependencies. @@ -349,6 +374,11 @@ class CheckpointableBase(object): checks. """ + # CheckpointableBase does not do automatic dependency tracking, but uses the + # no_automatic_dependency_tracking decorator so it can avoid adding + # dependencies if a subclass is Checkpointable / inherits from Model (both of + # which have __setattr__ overrides). + @no_automatic_dependency_tracking def _maybe_initialize_checkpointable(self): """Initialize dependency management. @@ -386,6 +416,10 @@ class CheckpointableBase(object): # building. self._name_based_restores = set() + def _no_dependency(self, value): + """If automatic dependency tracking is enabled, ignores `value`.""" + return value + def _name_based_attribute_restore(self, checkpoint): """Restore the object's attributes from a name-based checkpoint.""" self._name_based_restores.add(checkpoint) @@ -463,12 +497,6 @@ class CheckpointableBase(object): ValueError: If the variable name is not unique. """ self._maybe_initialize_checkpointable() - if not overwrite and self._lookup_dependency(name) is not None: - raise ValueError( - ("A variable named '%s' already exists in this Checkpointable, but " - "Checkpointable._add_variable called to create another with " - "that name. Variable names must be unique within a Checkpointable " - "object.") % (name,)) with ops.init_scope(): if context.executing_eagerly(): # If this is a variable with a single Tensor stored in the checkpoint, @@ -593,9 +621,9 @@ class CheckpointableBase(object): self._unconditional_checkpoint_dependencies[index] = new_reference elif current_object is None: self._unconditional_checkpoint_dependencies.append(new_reference) - self._unconditional_dependency_names[name] = checkpointable self._handle_deferred_dependencies( name=name, checkpointable=checkpointable) + self._unconditional_dependency_names[name] = checkpointable return checkpointable def _handle_deferred_dependencies(self, name, checkpointable): @@ -733,86 +761,3 @@ class CheckpointableBase(object): return {OBJECT_CONFIG_JSON_KEY: functools.partial( PythonStringStateSaveable, state_callback=_state_callback)} - - -class NoDependency(object): - """Allows attribute assignment to `Checkpointable` objects with no dependency. - - Example usage: - ```python - obj = Checkpointable() - obj.has_dependency = tf.Variable(0., name="dep") - obj.no_dependency = NoDependency(tf.Variable(1., name="nodep")) - assert obj.no_dependency.name == "nodep:0" - ``` - - `obj` in this example has a dependency on the variable "dep", and both - attributes contain un-wrapped `Variable` objects. - - `NoDependency` also works with `tf.keras.Model`, but only for checkpoint - dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped) - `Layer` to the attribute without a checkpoint dependency, but the `Model` will - still track the `Layer` (so it will appear in `Model.layers`, and its - variables will appear in `Model.variables`). - """ - - def __init__(self, value): - self.value = value - - -class NotCheckpointable(object): - """Marks instances of child classes as unsaveable using an object-based API. - - Useful for marking objects which would otherwise look checkpointable because - of inheritance (e.g. through `Layer`) as not checkpointable. Inheriting from - `NotCheckpointable` does not prevent an object from being assigned to any - attributes, but will throw an error on save/restore. - """ - pass - - -class Checkpointable(CheckpointableBase): - """Manages dependencies on other objects. - - `Checkpointable` objects may have dependencies: other `Checkpointable` objects - which should be saved if the object declaring the dependency is saved. A - correctly saveable program has a dependency graph such that if changing a - global variable affects an object (e.g. changes the behavior of any of its - methods) then there is a chain of dependencies from the influenced object to - the variable. - - Dependency edges have names, and are created implicitly when a - `Checkpointable` object is assigned to an attribute of another - `Checkpointable` object. For example: - - ``` - obj = Checkpointable() - obj.v = ResourceVariable(0.) - ``` - - The `Checkpointable` object `obj` now has a dependency named "v" on a - variable. - - `Checkpointable` objects may specify `Tensor`s to be saved and restored - directly (e.g. a `Variable` indicating how to save itself) rather than through - dependencies on other objects. See - `Checkpointable._gather_saveables_for_checkpoint` for details. - """ - - def __setattr__(self, name, value): - """Support self.foo = checkpointable syntax.""" - # Perform the attribute assignment, and potentially call other __setattr__ - # overrides such as that for tf.keras.Model. - no_dependency = isinstance(value, NoDependency) - if no_dependency: - value = value.value - super(Checkpointable, self).__setattr__(name, value) - if not no_dependency and isinstance(value, CheckpointableBase): - self._track_checkpointable( - value, name=name, - # Allow the user to switch the Checkpointable which is tracked by this - # name, since assigning a new variable to an attribute has - # historically been fine (e.g. Adam did this). - # TODO(allenl): Should this be a warning once Checkpointable save/load - # is usable? - overwrite=True) diff --git a/tensorflow/python/training/checkpointable/base_test.py b/tensorflow/python/training/checkpointable/base_test.py index 0a274cdfed5af83a69513e9b26bf427f284a4df7..fd935ac559ed7cd607145e7b2433a00c1f8431ea 100644 --- a/tensorflow/python/training/checkpointable/base_test.py +++ b/tensorflow/python/training/checkpointable/base_test.py @@ -16,34 +16,46 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import util class InterfaceTests(test.TestCase): - def testMultipleAssignment(self): - root = checkpointable.Checkpointable() - root.leaf = checkpointable.Checkpointable() - root.leaf = root.leaf - duplicate_name_dep = checkpointable.Checkpointable() + def testOverwrite(self): + root = base.CheckpointableBase() + leaf = base.CheckpointableBase() + root._track_checkpointable(leaf, name="leaf") + (current_name, current_dependency), = root._checkpoint_dependencies + self.assertIs(leaf, current_dependency) + self.assertEqual("leaf", current_name) + duplicate_name_dep = base.CheckpointableBase() with self.assertRaises(ValueError): root._track_checkpointable(duplicate_name_dep, name="leaf") - # No error; we're overriding __setattr__, so we can't really stop people - # from doing this while maintaining backward compatibility. - root.leaf = duplicate_name_dep root._track_checkpointable(duplicate_name_dep, name="leaf", overwrite=True) + (current_name, current_dependency), = root._checkpoint_dependencies + self.assertIs(duplicate_name_dep, current_dependency) + self.assertEqual("leaf", current_name) - def testNoDependency(self): - root = checkpointable.Checkpointable() - hasdep = checkpointable.Checkpointable() - root.hasdep = hasdep - nodep = checkpointable.Checkpointable() - root.nodep = checkpointable.NoDependency(nodep) - self.assertEqual(1, len(root._checkpoint_dependencies)) - self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep) - self.assertIs(root.hasdep, hasdep) - self.assertIs(root.nodep, nodep) + def testAddVariableOverwrite(self): + root = base.CheckpointableBase() + a = root._add_variable_with_custom_getter( + name="v", shape=[], getter=variable_scope.get_variable) + self.assertEqual([root, a], util.list_objects(root)) + with ops.Graph().as_default(): + b = root._add_variable_with_custom_getter( + name="v", shape=[], overwrite=True, + getter=variable_scope.get_variable) + self.assertEqual([root, b], util.list_objects(root)) + with ops.Graph().as_default(): + with self.assertRaisesRegexp( + ValueError, "already declared as a dependency"): + root._add_variable_with_custom_getter( + name="v", shape=[], overwrite=False, + getter=variable_scope.get_variable) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py index 69ed253fb2d874954ee7563cd8bb21add59a7318..507cda87349cda25012a0170f230637d0a9758bc 100644 --- a/tensorflow/python/training/checkpointable/data_structures.py +++ b/tensorflow/python/training/checkpointable/data_structures.py @@ -21,54 +21,129 @@ import collections import six -from tensorflow.python.keras.engine import base_layer -from tensorflow.python.keras.utils import layer_utils from tensorflow.python.ops import variables -from tensorflow.python.training.checkpointable import base as checkpointable_lib -from tensorflow.python.training.checkpointable import data_structures_base - - -# TODO(allenl): We could track regular Python data structures which get assigned -# to Checkpointable objects. Making this work with restore-on-create would be -# tricky; we'd need to re-create nested structures with our own wrapped objects -# on assignment to an attribute, and track the user's original structure to make -# sure they don't modify it except through the wrappers (since we could save the -# user's updated structure, but would have no way to support restore-on-create -# for those modifications). -# TODO(allenl): A dictionary data structure would be good too. -class CheckpointableDataStructure( - data_structures_base.CheckpointableDataStructureBase): +from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import layer_utils + + +class NoDependency(object): + """Allows attribute assignment to `Checkpointable` objects with no dependency. + + Example usage: + ```python + obj = Checkpointable() + obj.has_dependency = tf.Variable(0., name="dep") + obj.no_dependency = NoDependency(tf.Variable(1., name="nodep")) + assert obj.no_dependency.name == "nodep:0" + ``` + + `obj` in this example has a dependency on the variable "dep", and both + attributes contain un-wrapped `Variable` objects. + + `NoDependency` also works with `tf.keras.Model`, but only for checkpoint + dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped) + `Layer` to the attribute without a checkpoint dependency, but the `Model` will + still track the `Layer` (so it will appear in `Model.layers`, and its + variables will appear in `Model.variables`). + """ + + def __init__(self, value): + self.value = value + + +def _wrap_or_unwrap(value): + """Wraps basic data structures, unwraps NoDependency objects.""" + if isinstance(value, NoDependency): + return value.value + if isinstance(value, base.CheckpointableBase): + return value # Skip conversion for already checkpointable objects. + elif isinstance(value, dict): + return _DictWrapper(value) + elif isinstance(value, list): + return _ListWrapper(value) + else: + return value + # TODO(allenl): Handle other common data structures. Tuples will require + # special casing (tuple subclasses are not weak referenceable, so replacement + # with a wrapper that subclasses tuple on attribute assignment works poorly, + # and replacement with a wrapper that isn't a tuple is also problematic), + # probably a tree traversal where the leaves are non-tuples(/namedtuples) to + # come up with names. Dictionaries should look like lists. + + +def sticky_attribute_assignment(checkpointable, name, value): + """Adds dependencies, generally called from __setattr__. + + This behavior is shared between Checkpointable and Model. + + Respects NoDependency indicators, but otherwise makes checkpointable objects + out of common data structures and tracks objects by their attribute names. + + Args: + checkpointable: The object to add dependencies to (generally the one having + an attribute assigned). + name: The attribute name being assigned. + value: The value being assigned. Not necessarily a checkpointable object. + + Returns: + The value which should be stored in the attribute (unwrapped from a + NoDependency object if necessary). + """ + if isinstance(value, NoDependency): + add_dependency = False + else: + add_dependency = True + value = _wrap_or_unwrap(value) + if not add_dependency: + return value + if isinstance(value, base.CheckpointableBase): + checkpointable._track_checkpointable( # pylint: disable=protected-access + value, name=name, + # Allow the user to switch the Checkpointable which is tracked by this + # name, since assigning a new variable to an attribute has + # historically been fine (e.g. Adam did this). + overwrite=True) + return value + + +class CheckpointableDataStructure(base.CheckpointableBase): """Base class for data structures which contain checkpointable objects.""" def __init__(self): + # An append-only ordered set self._layers = [] + self.trainable = True self._extra_variables = [] def _track_value(self, value, name): """Add a dependency on `value`.""" - if isinstance(value, checkpointable_lib.CheckpointableBase): - self._track_checkpointable(value, name=name) - if isinstance(value, variables.Variable): - self._extra_variables.append(value) - else: + value = sticky_attribute_assignment( + checkpointable=self, value=value, name=name) + if isinstance(value, variables.Variable): + self._extra_variables.append(value) + if not isinstance(value, base.CheckpointableBase): raise ValueError( ("Only checkpointable objects (such as Layers or Optimizers) may be " "stored in a List object. Got %s, which does not inherit from " "CheckpointableBase.") % (value,)) - if isinstance(value, ( - base_layer.Layer, - data_structures_base.CheckpointableDataStructureBase)): - if value not in self._layers: + if (isinstance(value, CheckpointableDataStructure) + or layer_utils.is_layer(value)): + # Check for object-identity rather than with __eq__ to avoid + # de-duplicating empty container types. Automatically generated list + # wrappers keep things like "[] == []" true, which means "[] in [[]]" is + # also true. This becomes not true once one of the lists is mutated. + if not any((layer is value for layer in self._layers)): self._layers.append(value) if hasattr(value, "_use_resource_variables"): # In subclassed models, legacy layers (tf.layers) must always use # resource variables. value._use_resource_variables = True # pylint: disable=protected-access + return value @property def layers(self): - return self._layers + return layer_utils.filter_empty_layer_containers(self._layers) @property def trainable_weights(self): @@ -168,24 +243,28 @@ class List(CheckpointableDataStructure, collections.Sequence): def __init__(self, *args, **kwargs): """Construct a new sequence. Arguments are passed to `list()`.""" super(List, self).__init__() - self._storage = list(*args, **kwargs) + self._storage = self._make_storage(*args, **kwargs) for index, element in enumerate(self._storage): - self._track_value(element, name=self._name_element(index)) + self._storage[index] = self._track_value( + element, name=self._name_element(index)) + + def _make_storage(self, *args, **kwargs): + """Determines the backing storage (overridden in subclasses).""" + return list(*args, **kwargs) def _name_element(self, index): return "%d" % (index,) def append(self, value): """Add a new checkpointable value.""" - self._track_value(value, self._name_element(len(self._storage))) + value = self._track_value(value, self._name_element(len(self._storage))) self._storage.append(value) def extend(self, values): """Add a sequence of checkpointable values.""" - for index_offset, value in enumerate(values): - self._track_value( - value, name=self._name_element(len(self._storage) + index_offset)) - self._storage.extend(values) + for value in values: + self._storage.append(self._track_value( + value, name=self._name_element(len(self._storage)))) def __iadd__(self, values): self.extend(values) @@ -193,9 +272,12 @@ class List(CheckpointableDataStructure, collections.Sequence): def __add__(self, other): if isinstance(other, List): - return List(self._storage + other._storage) # pylint: disable=protected-access + return self.__class__(self._storage + other._storage) # pylint: disable=protected-access else: - return List(self._storage + other) + return self.__class__(self._storage + other) + + def __radd__(self, other): + return self + other def __getitem__(self, key): return self._storage[key] @@ -207,6 +289,144 @@ class List(CheckpointableDataStructure, collections.Sequence): return "List(%s)" % (repr(self._storage),) +class _ListWrapper(List, collections.MutableSequence, + # Shadowed, but there for isinstance checks. + list): + """Wraps the built-in `list` to support restore-on-create for variables. + + Unlike `List`, this sequence type is mutable in the same ways built-in lists + are. Instead of throwing an error immediately like `List`, it records + problematic mutations (e.g. assigning a new element to a position already + occupied, meaning both elements get the same names at different times) and + refuses to save. + + On assignment to an attribute of a Model or Checkpointable object, Python + lists are replaced with _ListWrapper. Wrapping a list in a + `tf.contrib.checkpoint.NoDependency` object prevents this. + """ + + def __init__(self, wrapped_list): + """Construct a new list wrapper. + + Args: + wrapped_list: The initial value of the data structure. A shallow copy may + be maintained for error checking. `wrapped_list` itself should not be + modified directly after constructing the `_ListWrapper`, and if changes + are detected the `_ListWrapper` will throw an exception on save. + """ + # Monotonic flags which indicate this object would not be restored properly, + # and therefore should throw an error on save to avoid giving the impression + # that restoring it will work. + self._non_append_mutation = False + self._external_modification = False + super(_ListWrapper, self).__init__(wrapped_list) + self._last_wrapped_list_snapshot = list(self._storage) + + def _make_storage(self, wrapped_list): + """Use the user's original list for storage.""" + return wrapped_list + + def _check_external_modification(self): + """Checks for any changes to the wrapped list not through the wrapper.""" + if self._external_modification or self._non_append_mutation: + return + if self._storage != self._last_wrapped_list_snapshot: + self._external_modification = True + self._last_wrapped_list_snapshot = None + + def _update_snapshot(self): + """Acknowledges tracked changes to the wrapped list.""" + if self._external_modification or self._non_append_mutation: + return + self._last_wrapped_list_snapshot = list(self._storage) + + @property + def _checkpoint_dependencies(self): + self._check_external_modification() + if self._non_append_mutation: + raise ValueError( + ("Unable to save the object %s (a list wrapper constructed to track " + "checkpointable TensorFlow objects). A list element was replaced " + "(__setitem__), deleted, or inserted. In order to support " + "restoration on object creation, tracking is exclusively for " + "append-only data structures.\n\nIf you don't need this list " + "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency " + "object; it will be automatically un-wrapped and subsequently " + "ignored." % (self,))) + if self._external_modification: + raise ValueError( + ("Unable to save the object %s (a list wrapper constructed to track " + "checkpointable TensorFlow objects). The wrapped list was modified " + "outside the wrapper (its final value was %s, its value when a " + "checkpoint dependency was added was %s), which breaks restoration " + "on object creation.\n\nIf you don't need this list checkpointed, " + "wrap it in a tf.contrib.checkpoint.NoDependency object; it will be " + "automatically un-wrapped and subsequently ignored." % ( + self, self._storage, self._last_wrapped_list_snapshot))) + return super(_ListWrapper, self)._checkpoint_dependencies + + def __delitem__(self, key): + self._non_append_mutation = True + del self._storage[key] + + def __setitem__(self, key, value): + self._non_append_mutation = True + self._storage[key] = value + + def append(self, value): + """Add a new checkpointable value.""" + self._check_external_modification() + super(_ListWrapper, self).append(value) + self._update_snapshot() + + def extend(self, values): + """Add a sequence of checkpointable values.""" + self._check_external_modification() + super(_ListWrapper, self).extend(values) + self._update_snapshot() + + def __eq__(self, other): + return self._storage == getattr(other, "_storage", other) + + def __ne__(self, other): + return self._storage != getattr(other, "_storage", other) + + def __lt__(self, other): + return self._storage < getattr(other, "_storage", other) + + def __le__(self, other): + return self._storage <= getattr(other, "_storage", other) + + def __gt__(self, other): + return self._storage > getattr(other, "_storage", other) + + def __ge__(self, other): + return self._storage >= getattr(other, "_storage", other) + + def __hash__(self): + # List wrappers need to compare like regular lists, and so like regular + # lists they don't belong in hash tables. + raise TypeError("unhashable type: 'ListWrapper'") + + def insert(self, index, obj): + self._non_append_mutation = True + self._storage.insert(index, obj) + + def _track_value(self, value, name): + """Allows storage of non-checkpointable objects.""" + try: + value = super(_ListWrapper, self)._track_value(value=value, name=name) + except ValueError: + # Even if this value isn't checkpointable, we need to make sure + # NoDependency objects get unwrapped. + value = sticky_attribute_assignment( + checkpointable=self, value=value, name=name) + return value + + def __repr__(self): + return "ListWrapper(%s)" % (repr(self._storage),) + + class Mapping(CheckpointableDataStructure, collections.Mapping): """An append-only checkpointable mapping data structure with string keys. @@ -220,9 +440,14 @@ class Mapping(CheckpointableDataStructure, collections.Mapping): def __init__(self, *args, **kwargs): """Construct a new sequence. Arguments are passed to `dict()`.""" super(Mapping, self).__init__() - self._storage = dict(*args, **kwargs) - for key, value in self._storage.items(): - self._track_value(value, name=self._name_element(key)) + self._storage = self._make_storage(*args, **kwargs) + self._storage.update( + {key: self._track_value( + value, name=self._name_element(key)) + for key, value in self._storage.items()}) + + def _make_storage(self, *args, **kwargs): + return dict(*args, **kwargs) def _name_element(self, key): if not isinstance(key, six.string_types): @@ -232,13 +457,14 @@ class Mapping(CheckpointableDataStructure, collections.Mapping): return str(key) def __setitem__(self, key, value): + name = self._name_element(key) + value = self._track_value(value, name=name) current_value = self._storage.setdefault(key, value) if current_value is not value: raise ValueError( ("Mappings are an append-only data structure. Tried to overwrite the " "key '%s' with value %s, but it already contains %s") % (key, value, current_value)) - self._track_value(value, name=self._name_element(key)) def update(self, *args, **kwargs): for key, value in dict(*args, **kwargs).items(): @@ -255,3 +481,185 @@ class Mapping(CheckpointableDataStructure, collections.Mapping): def __iter__(self): return iter(self._storage) + + +# Unlike _ListWrapper, having _DictWrapper inherit from dict and pass isinstance +# checks seems infeasible. CPython will not call Python methods/properties on +# dictionary subclasses when running e.g. {}.update(dict_subclass), and instead +# collects elements directly from dict_subclass's C structs. So subclassing dict +# implies that the storage has to be "self" (i.e. the C structs for the object +# must be updated correctly), but we also need that storage to be the wrapped +# dictionary to avoid synchronization bugs (un-tracked external modifications +# should still show up when the dict is accessed through the wrapper). Monkey +# patching all of the "wrapped" dict's methods instead of creating a wrapper +# object is an option, but not a very attractive one (replacing methods without +# creating reference cycles is difficult, and then dicts would need to be +# special cased everywhere as being checkpointable). +class _DictWrapper(Mapping, collections.MutableMapping): + """Wraps built-in dicts to support restore-on-create for variables. + + _DictWrapper is to Mapping as _ListWrapper is to List. Unlike Mapping, + _DictWrapper allows non-string keys and values and arbitrary mutations (delete + keys, reassign values). Like _ListWrapper, these mutations mean that + _DictWrapper will raise an exception on save. + """ + + def __new__(cls, *args): + if len(args) == 1 and isinstance(args[0], dict): + return super(_DictWrapper, cls).__new__(cls) + else: + # Allow construction from a sequence, e.g. for nest.pack_sequence_as. In + # this case there's nothing to wrap, so we make a normal dictionary. Also + # allows constructing empty instances of the _DictWrapper type, as Session + # is wont to do (and again there's nothing to wrap, so a normal dictionary + # makes more sense). + return dict(*args) + + def __init__(self, wrapped_dict): + self._non_string_key = False + self._non_append_mutation = False + self._external_modification = False + super(_DictWrapper, self).__init__(wrapped_dict) + self._update_snapshot() + + def _make_storage(self, wrapped_dict): + """Re-use the wrapped dict for storage (to force them to be in sync).""" + return wrapped_dict + + @property + def _checkpoint_dependencies(self): + """Check that the object is saveable before listing its dependencies.""" + self._check_external_modification() + if self._non_string_key: + raise ValueError( + "Unable to save the object %s (a dictionary wrapper constructed " + "automatically on attribute assignment). The wrapped dictionary " + "contains a non-string key which maps to a checkpointable object or " + "mutable data structure.\n\nIf you don't need this dictionary " + "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency " + "object; it will be automatically un-wrapped and subsequently " + "ignored." % (self,)) + if self._non_append_mutation: + raise ValueError( + "Unable to save the object %s (a dictionary wrapper constructed " + "automatically on attribute assignment). A key mapping to a " + "checkpointable object was overwritten or deleted, which would " + "cause problems for restoration.\n\nIf you don't need this " + "dictionary checkpointed, wrap it in a " + "tf.contrib.checkpoint.NoDependency object; it will be automatically " + "un-wrapped and subsequently ignored." % (self,)) + if self._external_modification: + raise ValueError( + "Unable to save the object %s (a dictionary wrapper constructed " + "automatically on attribute assignment). The wrapped dictionary was " + "modified outside the wrapper (its final value was %s, its value " + "when a checkpoint dependency was added was %s), which breaks " + "restoration on object creation.\n\nIf you don't need this " + "dictionary checkpointed, wrap it in a " + "tf.contrib.checkpoint.NoDependency object; it will be automatically " + "un-wrapped and subsequently ignored." % ( + self, self, self._last_wrapped_dict_snapshot)) + assert not self._dirty # Any reason for dirtiness should have an exception. + return super(_DictWrapper, self)._checkpoint_dependencies + + @property + def _dirty(self): + """Check if there has already been a mutation which prevents saving.""" + return (self._external_modification + or self._non_append_mutation + or self._non_string_key) + + def _check_external_modification(self): + """Checks for any changes to the wrapped dict not through the wrapper.""" + if self._dirty: + return + if self != self._last_wrapped_dict_snapshot: + self._external_modification = True + self._last_wrapped_dict_snapshot = None + + def _update_snapshot(self): + """Acknowledges tracked changes to the wrapped dict.""" + if self._dirty: + return + self._last_wrapped_dict_snapshot = dict(self) + + def _track_value(self, value, name): + """Allows storage of non-checkpointable objects.""" + if isinstance(name, six.string_types): + string_key = True + else: + name = "-non_string_key" + string_key = False + try: + no_dependency = isinstance(value, NoDependency) + value = super(_DictWrapper, self)._track_value(value=value, name=name) + if not (string_key or no_dependency): + # A non-string key maps to a checkpointable value. This data structure + # is not saveable. + self._non_string_key = True + return value + except ValueError: + # Even if this value isn't checkpointable, we need to make sure + # NoDependency objects get unwrapped. + return sticky_attribute_assignment( + checkpointable=self, value=value, name=name) + + def _name_element(self, key): + """Don't throw errors for non-string keys.""" + if isinstance(key, six.string_types): + return super(_DictWrapper, self)._name_element(key) + else: + return key + + def __setitem__(self, key, value): + """Allow any modifications, but possibly mark the wrapper as unsaveable.""" + self._check_external_modification() + no_dep = isinstance(value, NoDependency) + if isinstance(key, six.string_types): + existing_dependency = self._lookup_dependency(key) + value = self._track_value(value, name=key) + else: + value = _wrap_or_unwrap(value) + existing_dependency = None + if not no_dep and isinstance(value, base.CheckpointableBase): + # Non-string keys are OK as long as we have no reason to add a + # dependency on the value (either because the value is not + # checkpointable, or because it was wrapped in a NoDependency object). + self._non_string_key = True + current_value = self._storage.setdefault(key, value) + if current_value is not value: + if ((not no_dep and isinstance(value, base.CheckpointableBase)) + # We don't want to just check that the existing object is + # checkpointable, since it may have been wrapped in a NoDependency + # object. + or existing_dependency is not None): + # A checkpointable object was replaced under the same key; this means + # that restoring would be error-prone, so we'll throw an exception on + # save. + self._non_append_mutation = True + self._storage[key] = value + + self._update_snapshot() + + def __delitem__(self, key): + self._check_external_modification() + existing_value = self[key] + if isinstance(existing_value, base.CheckpointableBase): + # Deleting tracked checkpointable values means restoring is problematic, + # so we'll throw an exception on save. + self._non_append_mutation = True + del self._storage[key] + self._update_snapshot() + + def __repr__(self): + return "DictWrapper(%s)" % (repr(self._storage),) + + def __hash__(self): + raise TypeError("unhashable type: 'DictWrapper'") + + def __eq__(self, other): + return self._storage == getattr(other, "_storage", other) + + def update(self, *args, **kwargs): + for key, value in dict(*args, **kwargs).items(): + self[key] = value diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py index b05b3a88002e31560ed6c2005fdd29f56c5227a3..472b7c32b48ac02d6719a59a6020e97ff9c46cc2 100644 --- a/tensorflow/python/training/checkpointable/data_structures_test.py +++ b/tensorflow/python/training/checkpointable/data_structures_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import os import numpy +import six from tensorflow.python.eager import context from tensorflow.python.eager import test @@ -31,6 +32,8 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.checkpointable import util class HasList(training.Model): @@ -66,16 +69,19 @@ class HasList(training.Model): class ListTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTracking(self): model = HasList() output = model(array_ops.ones([32, 2])) self.assertAllEqual([32, 12], output.shape) - self.assertEqual(2, len(model.layers)) - self.assertIs(model.layer_list, model.layers[0]) - self.assertEqual(10, len(model.layers[0].layers)) + self.assertEqual(11, len(model.layers)) + self.assertEqual(10, len(model.layer_list.layers)) + six.assertCountEqual( + self, + model.layers, + model.layer_list.layers + model.layers_with_updates) for index in range(10): - self.assertEqual(3 + index, model.layers[0].layers[index].units) + self.assertEqual(3 + index, model.layer_list.layers[index].units) self.assertEqual(2, len(model._checkpoint_dependencies)) self.assertIs(model.layer_list, model._checkpoint_dependencies[0].ref) self.assertIs(model.layers_with_updates, @@ -106,13 +112,28 @@ class ListTests(test.TestCase): model(model_input) self.assertEqual(0, len(model.updates)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLossesForwarded(self): model = HasList() model_input = array_ops.ones([32, 2]) model(model_input) self.assertEqual(2, len(model.losses)) + def testModelContainersCompareEqual(self): + class HasEqualContainers(training.Model): + + def __init__(self): + super(HasEqualContainers, self).__init__() + self.l1 = [] + self.l2 = [] + + model = HasEqualContainers() + first_layer = HasEqualContainers() + model.l1.append(first_layer) + second_layer = HasEqualContainers() + model.l2.append(second_layer) + self.assertEqual([first_layer, second_layer], model.layers) + def testNotCheckpointable(self): class NotCheckpointable(object): pass @@ -158,11 +179,62 @@ class ListTests(test.TestCase): self.assertEqual([v], l.trainable_weights) self.assertEqual([v2], l.non_trainable_weights) + def testListWrapperBasic(self): + # _ListWrapper, unlike List, compares like the built-in list type (since it + # is used to automatically replace lists). + a = tracking.Checkpointable() + b = tracking.Checkpointable() + self.assertEqual([a, a], + [a, a]) + self.assertEqual(data_structures._ListWrapper([a, a]), + data_structures._ListWrapper([a, a])) + self.assertEqual([a, a], + data_structures._ListWrapper([a, a])) + self.assertEqual(data_structures._ListWrapper([a, a]), + [a, a]) + self.assertNotEqual([a, a], + [b, a]) + self.assertNotEqual(data_structures._ListWrapper([a, a]), + data_structures._ListWrapper([b, a])) + self.assertNotEqual([a, a], + data_structures._ListWrapper([b, a])) + self.assertLess([a], [a, b]) + self.assertLess(data_structures._ListWrapper([a]), + data_structures._ListWrapper([a, b])) + self.assertLessEqual([a], [a, b]) + self.assertLessEqual(data_structures._ListWrapper([a]), + data_structures._ListWrapper([a, b])) + self.assertGreater([a, b], [a]) + self.assertGreater(data_structures._ListWrapper([a, b]), + data_structures._ListWrapper([a])) + self.assertGreaterEqual([a, b], [a]) + self.assertGreaterEqual(data_structures._ListWrapper([a, b]), + data_structures._ListWrapper([a])) + self.assertEqual([a], data_structures._ListWrapper([a])) + self.assertEqual([a], list(data_structures.List([a]))) + self.assertEqual([a, a], data_structures._ListWrapper([a]) + [a]) + self.assertEqual([a, a], [a] + data_structures._ListWrapper([a])) + self.assertIsInstance(data_structures._ListWrapper([a]), list) + + def testWrapperChangesList(self): + l = [] + l_wrapper = data_structures._ListWrapper(l) + l_wrapper.append(1) + self.assertEqual([1], l) + + def testListChangesWrapper(self): + l = [] + l_wrapper = data_structures._ListWrapper(l) + l.append(1) + self.assertEqual([1], l_wrapper) + def testHashing(self): has_sequences = set([data_structures.List(), data_structures.List()]) self.assertEqual(2, len(has_sequences)) self.assertNotIn(data_structures.List(), has_sequences) + with self.assertRaises(TypeError): + has_sequences.add(data_structures._ListWrapper([])) class HasMapping(training.Model): @@ -190,14 +262,13 @@ class HasMapping(training.Model): class MappingTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testTracking(self): model = HasMapping() output = model(array_ops.ones([32, 2])) self.assertAllEqual([32, 7], output.shape) - self.assertEqual(1, len(model.layers)) - self.assertIs(model.layer_dict, model.layers[0]) - self.assertEqual(3, len(model.layers[0].layers)) + self.assertEqual(5, len(model.layers)) + six.assertCountEqual(self, model.layers, model.layer_dict.layers) self.assertEqual(1, len(model._checkpoint_dependencies)) self.assertIs(model.layer_dict, model._checkpoint_dependencies[0].ref) self.evaluate([v.initializer for v in model.variables]) @@ -233,6 +304,124 @@ class MappingTests(test.TestCase): data_structures.Mapping()]) self.assertEqual(2, len(has_mappings)) self.assertNotIn(data_structures.Mapping(), has_mappings) + # In contrast to Mapping, dict wrappers are not hashable + a = tracking.Checkpointable() + a.d = {} + self.assertEqual({}, a.d) + self.assertFalse({} != a.d) # pylint: disable=g-explicit-bool-comparison + self.assertNotEqual({1: 2}, a.d) + with self.assertRaisesRegexp(TypeError, "unhashable"): + set([a.d]) + + def testDictWrapperBadKeys(self): + a = tracking.Checkpointable() + a.d = {} + a.d[1] = data_structures.List() + model = training.Model() + model.sub = a + save_path = os.path.join(self.get_temp_dir(), "ckpt") + with self.assertRaisesRegexp(ValueError, "non-string key"): + model.save_weights(save_path) + + def testDictWrapperNoDependency(self): + a = tracking.Checkpointable() + a.d = data_structures.NoDependency({}) + a.d[1] = [3] + self.assertEqual([a], util.list_objects(a)) + model = training.Model() + model.sub = a + save_path = os.path.join(self.get_temp_dir(), "ckpt") + model.save_weights(save_path) + model.load_weights(save_path) + + def testNonStringKeyNotCheckpointableValue(self): + a = tracking.Checkpointable() + a.d = {} + a.d["a"] = [3] + a.d[1] = data_structures.NoDependency([3]) + self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a)) + model = training.Model() + model.sub = a + save_path = os.path.join(self.get_temp_dir(), "ckpt") + model.save_weights(save_path) + model.load_weights(save_path) + + def testNonAppendNotCheckpointable(self): + # Non-append mutations (deleting or overwriting values) are OK when the + # values aren't tracked. + a = tracking.Checkpointable() + a.d = {} + a.d["a"] = [3] + a.d[1] = 3 + a.d[1] = 2 + self.assertEqual(2, a.d[1]) + del a.d[1] + a.d[2] = data_structures.NoDependency(tracking.Checkpointable()) + second = tracking.Checkpointable() + a.d[2] = data_structures.NoDependency(second) + self.assertIs(second, a.d[2]) + self.assertEqual([a, a.d, a.d["a"]], util.list_objects(a)) + model = training.Model() + model.sub = a + save_path = os.path.join(self.get_temp_dir(), "ckpt") + model.save_weights(save_path) + model.load_weights(save_path) + + def testDelNoSave(self): + model = training.Model() + model.d = {} + model.d["a"] = [] + del model.d["a"] + save_path = os.path.join(self.get_temp_dir(), "ckpt") + with self.assertRaisesRegexp(ValueError, "overwritten or deleted"): + model.save_weights(save_path) + + def testPopNoSave(self): + model = training.Model() + model.d = {} + model.d["a"] = [] + model.d.pop("a") + save_path = os.path.join(self.get_temp_dir(), "ckpt") + with self.assertRaisesRegexp(ValueError, "overwritten or deleted"): + model.save_weights(save_path) + + def testExternalModificationNoSave(self): + model = training.Model() + external_reference = {} + model.d = external_reference + external_reference["a"] = [] + save_path = os.path.join(self.get_temp_dir(), "ckpt") + with self.assertRaisesRegexp(ValueError, "modified outside the wrapper"): + model.save_weights(save_path) + + def testOverwriteNoSave(self): + model = training.Model() + model.d = {} + model.d["a"] = {} + model.d["a"] = {} + save_path = os.path.join(self.get_temp_dir(), "ckpt") + with self.assertRaisesRegexp(ValueError, "overwritten or deleted"): + model.save_weights(save_path) + + def testIter(self): + model = training.Model() + model.d = {1: 3} + model.d[1] = 3 + self.assertEqual([1], list(model.d)) + new_dict = {} + # This update() is super tricky. If the dict wrapper subclasses dict, + # CPython will access its storage directly instead of calling any + # methods/properties on the object. So the options are either not to + # subclass dict (in which case update will call normal iter methods, but the + # object won't pass isinstance checks) or to subclass dict and keep that + # storage updated (no shadowing all its methods like _ListWrapper). + new_dict.update(model.d) + self.assertEqual({1: 3}, new_dict) + + def testConstructableFromSequence(self): + result = data_structures._DictWrapper([(1, 2), (3, 4)]) + self.assertIsInstance(result, dict) + self.assertEqual({1: 2, 3: 4}, result) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/training/checkpointable/layer_utils.py b/tensorflow/python/training/checkpointable/layer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d65b631fe9b855af20329b35cd3f725004a89822 --- /dev/null +++ b/tensorflow/python/training/checkpointable/layer_utils.py @@ -0,0 +1,98 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities related to layer/model functionality.""" + +# TODO(b/110718070): Move these functions back to tensorflow/python/keras/utils +# once __init__ files no longer require all of tf.keras to be imported together. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +def is_layer(obj): + """Implicit check for Layer-like objects.""" + # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer). + return (hasattr(obj, "call") + and hasattr(obj, "build") + and hasattr(obj, "variables")) + + +def filter_empty_layer_containers(layer_list): + """Filter out empty Layer-like containers.""" + filtered = [] + for obj in layer_list: + if is_layer(obj): + filtered.append(obj) + else: + # Checkpointable data structures will not show up in ".layers" lists, but + # the layers they contain will. + filtered.extend(obj.layers) + return filtered + + +def gather_trainable_weights(trainable, sub_layers, extra_variables): + """Lists the trainable weights for an object with sub-layers. + + Args: + trainable: Whether the object collecting the variables is trainable. + sub_layers: A flat list of Layer objects owned by this object, to collect + variables from. + extra_variables: Any extra variables to include. Their `.trainable` property + is used to categorize them. + + Returns: + A list of collected trainable weights/variables. + """ + if not trainable: + return [] + weights = [] + for layer in sub_layers: + weights += layer.trainable_weights + trainable_extra_variables = [ + v for v in extra_variables if v.trainable] + return weights + trainable_extra_variables + + +def gather_non_trainable_weights(trainable, sub_layers, extra_variables): + """Lists the non-trainable weights for an object with sub-layers. + + Args: + trainable: Whether the object collecting the variables is trainable. + sub_layers: A flat list of Layer objects owned by this object, to collect + variables from. + extra_variables: Any extra variables to include. Their `.trainable` property + is used to categorize them. + + Returns: + A list of collected non-trainable weights/variables. + """ + trainable_extra_variables = [] + non_trainable_extra_variables = [] + for v in extra_variables: + if v.trainable: + trainable_extra_variables.append(v) + else: + non_trainable_extra_variables.append(v) + weights = [] + for layer in sub_layers: + weights += layer.non_trainable_weights + if not trainable: + trainable_weights = [] + for layer in sub_layers: + trainable_weights += layer.trainable_weights + return (trainable_weights + trainable_extra_variables + + weights + non_trainable_extra_variables) + return weights + non_trainable_extra_variables diff --git a/tensorflow/python/training/checkpointable/tracking.py b/tensorflow/python/training/checkpointable/tracking.py new file mode 100644 index 0000000000000000000000000000000000000000..bd0bed9d46f2e75633e3bf1230eded3708ec1c8b --- /dev/null +++ b/tensorflow/python/training/checkpointable/tracking.py @@ -0,0 +1,72 @@ +"""Dependency tracking for checkpointable objects.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import data_structures + + +class NotCheckpointable(object): + """Marks instances of child classes as unsaveable using an object-based API. + + Useful for marking objects which would otherwise look checkpointable because + of inheritance (e.g. through `Layer`) as not checkpointable. Inheriting from + `NotCheckpointable` does not prevent an object from being assigned to any + attributes, but will throw an error on save/restore. + """ + pass + + +class Checkpointable(base.CheckpointableBase): + """Manages dependencies on other objects. + + `Checkpointable` objects may have dependencies: other `Checkpointable` objects + which should be saved if the object declaring the dependency is saved. A + correctly saveable program has a dependency graph such that if changing a + global variable affects an object (e.g. changes the behavior of any of its + methods) then there is a chain of dependencies from the influenced object to + the variable. + + Dependency edges have names, and are created implicitly when a + `Checkpointable` object is assigned to an attribute of another + `Checkpointable` object. For example: + + ``` + obj = Checkpointable() + obj.v = ResourceVariable(0.) + ``` + + The `Checkpointable` object `obj` now has a dependency named "v" on a + variable. + + `Checkpointable` objects may specify `Tensor`s to be saved and restored + directly (e.g. a `Variable` indicating how to save itself) rather than through + dependencies on other objects. See + `Checkpointable._gather_saveables_for_checkpoint` for details. + """ + + def __setattr__(self, name, value): + """Support self.foo = checkpointable syntax.""" + if getattr(self, "_setattr_tracking", True): + value = data_structures.sticky_attribute_assignment( + checkpointable=self, value=value, name=name) + super(Checkpointable, self).__setattr__(name, value) + + def _no_dependency(self, value): + """Override to allow CheckpointableBase to disable dependency tracking.""" + return data_structures.NoDependency(value) diff --git a/tensorflow/python/training/checkpointable/tracking_test.py b/tensorflow/python/training/checkpointable/tracking_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e85f812ce211b0db6080cf3ecbaee11b8420bb6e --- /dev/null +++ b/tensorflow/python/training/checkpointable/tracking_test.py @@ -0,0 +1,197 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy +import six + +from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import training +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.training.checkpointable import tracking +from tensorflow.python.training.checkpointable import util +from tensorflow.python.util import nest + + +class InterfaceTests(test.TestCase): + + def testMultipleAssignment(self): + root = tracking.Checkpointable() + root.leaf = tracking.Checkpointable() + root.leaf = root.leaf + duplicate_name_dep = tracking.Checkpointable() + with self.assertRaisesRegexp(ValueError, "already declared"): + root._track_checkpointable(duplicate_name_dep, name="leaf") + # No error; we're overriding __setattr__, so we can't really stop people + # from doing this while maintaining backward compatibility. + root.leaf = duplicate_name_dep + root._track_checkpointable(duplicate_name_dep, name="leaf", overwrite=True) + self.assertIs(duplicate_name_dep, root._lookup_dependency("leaf")) + (_, dep_object), = root._checkpoint_dependencies + self.assertIs(duplicate_name_dep, dep_object) + + def testNoDependency(self): + root = tracking.Checkpointable() + hasdep = tracking.Checkpointable() + root.hasdep = hasdep + nodep = tracking.Checkpointable() + root.nodep = data_structures.NoDependency(nodep) + self.assertEqual(1, len(root._checkpoint_dependencies)) + self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep) + self.assertIs(root.hasdep, hasdep) + self.assertIs(root.nodep, nodep) + + class NoDependencyModel(training.Model): + + @base.no_automatic_dependency_tracking + def __init__(self): + super(NoDependencyModel, self).__init__() + self.a = [] + self.b = tracking.Checkpointable() + + nodeps = NoDependencyModel() + self.assertEqual([nodeps], util.list_objects(nodeps)) + + def testListBasic(self): + a = tracking.Checkpointable() + b = tracking.Checkpointable() + a.l = [b] + c = tracking.Checkpointable() + a.l.append(c) + a_deps = util.list_objects(a) + self.assertIn(b, a_deps) + self.assertIn(c, a_deps) + direct_a_dep, = a._checkpoint_dependencies + self.assertEqual("l", direct_a_dep.name) + self.assertIn(b, direct_a_dep.ref) + self.assertIn(c, direct_a_dep.ref) + + @test_util.run_in_graph_and_eager_modes + def testMutationDirtiesList(self): + a = tracking.Checkpointable() + b = tracking.Checkpointable() + a.l = [b] + c = tracking.Checkpointable() + a.l.insert(0, c) + checkpoint = util.Checkpoint(a=a) + with self.assertRaisesRegexp(ValueError, "A list element was replaced"): + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + + @test_util.run_in_graph_and_eager_modes + def testOutOfBandEditDirtiesList(self): + a = tracking.Checkpointable() + b = tracking.Checkpointable() + held_reference = [b] + a.l = held_reference + c = tracking.Checkpointable() + held_reference.append(c) + checkpoint = util.Checkpoint(a=a) + with self.assertRaisesRegexp(ValueError, "The wrapped list was modified"): + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + + @test_util.run_in_graph_and_eager_modes + def testNestedLists(self): + a = tracking.Checkpointable() + a.l = [] + b = tracking.Checkpointable() + a.l.append([b]) + c = tracking.Checkpointable() + a.l[0].append(c) + a_deps = util.list_objects(a) + self.assertIn(b, a_deps) + self.assertIn(c, a_deps) + a.l[0].append(1) + d = tracking.Checkpointable() + a.l[0].append(d) + a_deps = util.list_objects(a) + self.assertIn(d, a_deps) + self.assertIn(b, a_deps) + self.assertIn(c, a_deps) + self.assertNotIn(1, a_deps) + e = tracking.Checkpointable() + f = tracking.Checkpointable() + a.l1 = [[], [e]] + a.l1[0].append(f) + a_deps = util.list_objects(a) + self.assertIn(e, a_deps) + self.assertIn(f, a_deps) + checkpoint = util.Checkpoint(a=a) + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + a.l[0].append(data_structures.NoDependency([])) + a.l[0][-1].append(5) + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + # Dirtying the inner list means the root object is unsaveable. + a.l[0][1] = 2 + with self.assertRaisesRegexp(ValueError, "A list element was replaced"): + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + + @test_util.run_in_graph_and_eager_modes + def testDictionariesBasic(self): + a = training.Model() + b = training.Model() + a.attribute = {"b": b} + c = training.Model() + a.attribute["c"] = [] + a.attribute["c"].append(c) + a_deps = util.list_objects(a) + self.assertIn(b, a_deps) + self.assertIn(c, a_deps) + self.assertIs(b, a.attribute["b"]) + six.assertCountEqual( + self, + ["b", "c"], + [dep.name for dep in a.attribute._checkpoint_dependencies]) + self.assertEqual([b, c], a.layers) + self.assertEqual([b, c], a.attribute.layers) + self.assertEqual([c], a.attribute["c"].layers) + checkpoint = util.Checkpoint(a=a) + save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + with self.test_session(): + checkpoint.restore(save_path).assert_consumed().initialize_or_restore() + + @test_util.run_in_graph_and_eager_modes + def testNoDepList(self): + a = training.Model() + a.l1 = data_structures.NoDependency([]) + a.l1.insert(1, 0) + self.assertTrue(isinstance(a.l1, list)) + checkpoint = util.Checkpoint(a=a) + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + a.l2 = [] + a.l2.insert(1, 0) + with self.assertRaisesRegexp(ValueError, "A list element was replaced"): + checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) + + @test_util.run_in_graph_and_eager_modes + def testAssertions(self): + a = tracking.Checkpointable() + a.l = {"k": [numpy.zeros([2, 2])]} + self.assertAllEqual(nest.flatten({"k": [numpy.zeros([2, 2])]}), + nest.flatten(a.l)) + self.assertAllClose({"k": [numpy.zeros([2, 2])]}, a.l) + nest.map_structure(self.assertAllClose, a.l, {"k": [numpy.zeros([2, 2])]}) + a.tensors = {"k": [array_ops.ones([2, 2]), array_ops.zeros([3, 3])]} + self.assertAllClose({"k": [numpy.ones([2, 2]), numpy.zeros([3, 3])]}, + self.evaluate(a.tensors)) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py index 96e6d10791f396ad7f9f73cce9356dd4cbe3ce9d..e42f9894697456efdd963290df404802d0839694 100644 --- a/tensorflow/python/training/checkpointable/util.py +++ b/tensorflow/python/training/checkpointable/util.py @@ -19,6 +19,7 @@ from __future__ import print_function import abc import collections +import os import weakref from tensorflow.core.protobuf import checkpointable_object_graph_pb2 @@ -34,13 +35,17 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_io_ops as io_ops from tensorflow.python.ops import init_ops -from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import saveable_object as saveable_object_lib from tensorflow.python.training import saver as saver_lib -from tensorflow.python.training.checkpointable import base as checkpointable_lib +from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import data_structures +from tensorflow.python.training.checkpointable import tracking from tensorflow.python.util import deprecation +from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export @@ -91,13 +96,14 @@ class _CheckpointRestoreCoordinator(object): # use them (for example because of inconsistent references when # loading). Used to make status assertions fail when loading checkpoints # that don't quite match. - self.all_python_objects = weakref.WeakSet() + self.all_python_objects = _ObjectIdentityWeakSet() self.save_path = save_path self.dtype_map = dtype_map # When graph building, contains a list of ops to run to restore objects from # this checkpoint. self.restore_ops = [] self.restore_ops_by_name = {} + self.new_restore_ops_callback = None # A mapping from optimizer proto ids to lists of slot variables to be # restored when the optimizer is tracked. Only includes slot variables whose # regular variables have already been created, and only for optimizer @@ -113,11 +119,16 @@ class _CheckpointRestoreCoordinator(object): # `node` refers to an `Optimizer`, since only these have slot variables. self.slot_restorations.setdefault( slot_reference.original_variable_node_id, []).append( - checkpointable_lib._SlotVariableRestoration( # pylint: disable=protected-access + base._SlotVariableRestoration( # pylint: disable=protected-access optimizer_id=node_index, slot_variable_id=slot_reference.slot_variable_node_id, slot_name=slot_reference.slot_name)) + def new_restore_ops(self, new_ops): + self.restore_ops.extend(new_ops) + if self.new_restore_ops_callback: + self.new_restore_ops_callback(new_ops) # pylint: disable=not-callable + class _NameBasedRestoreCoordinator(object): """Keeps the status of a name-based checkpoint restore.""" @@ -216,10 +227,11 @@ def _default_getter(name, shape, dtype, initializer=None, def initial_value(): return initializer( shape_object.as_list(), dtype=dtype, partition_info=partition_info) - return resource_variable_ops.ResourceVariable( + return variables.Variable( initial_value=initial_value, name=name, dtype=variable_dtype, + use_resource=True, **kwargs ) @@ -257,27 +269,163 @@ def object_metadata(save_path): reader = pywrap_tensorflow.NewCheckpointReader(save_path) try: object_graph_string = reader.get_tensor( - checkpointable_lib.OBJECT_GRAPH_PROTO_KEY) + base.OBJECT_GRAPH_PROTO_KEY) except errors_impl.NotFoundError: raise ValueError( ('The specified checkpoint "%s" does not appear to be object-based (it ' 'is missing the key "%s"). Likely it was created with a name-based ' 'saver and does not contain an object dependency graph.') % ( - save_path, checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)) + save_path, base.OBJECT_GRAPH_PROTO_KEY)) object_graph_proto = ( checkpointable_object_graph_pb2.CheckpointableObjectGraph()) object_graph_proto.ParseFromString(object_graph_string) return object_graph_proto +class _ObjectIdentityWrapper(object): + """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped. + + Since __eq__ is based on object identity, it's safe to also define __hash__ + based on object ids. This lets us add unhashable types like checkpointable + _ListWrapper objects to object-identity collections. + """ + + def __init__(self, wrapped): + self._wrapped = wrapped + + @property + def unwrapped(self): + return self._wrapped + + def __eq__(self, other): + if isinstance(other, _ObjectIdentityWrapper): + return self._wrapped is other._wrapped # pylint: disable=protected-access + return self._wrapped is other + + def __hash__(self): + # Wrapper id() is also fine for weakrefs. In fact, we rely on + # id(weakref.ref(a)) == id(weakref.ref(a)) and weakref.ref(a) is + # weakref.ref(a) in _WeakObjectIdentityWrapper. + return id(self._wrapped) + + +class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper): + + def __init__(self, wrapped): + super(_WeakObjectIdentityWrapper, self).__init__(weakref.ref(wrapped)) + + @property + def unwrapped(self): + return self._wrapped() + + +class _ObjectIdentityDictionary(collections.MutableMapping): + """A mutable mapping data structure which compares using "is". + + This is necessary because we have checkpointable objects (_ListWrapper) which + have behavior identical to built-in Python lists (including being unhashable + and comparing based on the equality of their contents by default). + """ + + def __init__(self): + self._storage = {} + + def _wrap_key(self, key): + return _ObjectIdentityWrapper(key) + + def __getitem__(self, key): + return self._storage[self._wrap_key(key)] + + def __setitem__(self, key, value): + self._storage[self._wrap_key(key)] = value + + def __delitem__(self, key): + del self._storage[self._wrap_key(key)] + + def __len__(self): + return len(self._storage) + + def __iter__(self): + for key in self._storage: + yield key.unwrapped + + +class _ObjectIdentityWeakKeyDictionary(_ObjectIdentityDictionary): + """Like weakref.WeakKeyDictionary, but compares objects with "is".""" + + def _wrap_key(self, key): + return _WeakObjectIdentityWrapper(key) + + def __len__(self): + # Iterate, discarding old weak refs + return len(list(self._storage)) + + def __iter__(self): + keys = self._storage.keys() + for key in keys: + unwrapped = key.unwrapped + if unwrapped is None: + del self[key] + else: + yield unwrapped + + +class _ObjectIdentitySet(collections.MutableSet): + """Like the built-in set, but compares objects with "is".""" + + def __init__(self, *args): + self._storage = set([self._wrap_key(obj) for obj in list(*args)]) + + def _wrap_key(self, key): + return _ObjectIdentityWrapper(key) + + def __contains__(self, key): + return self._wrap_key(key) in self._storage + + def discard(self, key): + self._storage.discard(self._wrap_key(key)) + + def add(self, key): + self._storage.add(self._wrap_key(key)) + + def __len__(self): + return len(self._storage) + + def __iter__(self): + keys = list(self._storage) + for key in keys: + yield key.unwrapped + + +class _ObjectIdentityWeakSet(_ObjectIdentitySet): + """Like weakref.WeakSet, but compares objects with "is".""" + + def _wrap_key(self, key): + return _WeakObjectIdentityWrapper(key) + + def __len__(self): + # Iterate, discarding old weak refs + return len([_ for _ in self]) + + def __iter__(self): + keys = list(self._storage) + for key in keys: + unwrapped = key.unwrapped + if unwrapped is None: + self.discard(key) + else: + yield unwrapped + + def _breadth_first_checkpointable_traversal(root_checkpointable): """Find shortest paths to all variables owned by dependencies of root.""" bfs_sorted = [] to_visit = collections.deque([root_checkpointable]) - path_to_root = {root_checkpointable: ()} + path_to_root = _ObjectIdentityDictionary() + path_to_root[root_checkpointable] = () while to_visit: current_checkpointable = to_visit.popleft() - if isinstance(current_checkpointable, checkpointable_lib.NotCheckpointable): + if isinstance(current_checkpointable, tracking.NotCheckpointable): raise NotImplementedError( ("The object %s does not support object-based saving. File a feature " "request if this limitation bothers you. In the meantime, you can " @@ -335,7 +483,7 @@ def _slot_variable_naming_for_optimizer(optimizer_path): def _serialize_slot_variables(checkpointable_objects, node_ids, object_names): """Gather and name slot variables.""" non_slot_objects = list(checkpointable_objects) - slot_variables = {} + slot_variables = _ObjectIdentityDictionary() for checkpointable in non_slot_objects: if isinstance(checkpointable, optimizer_lib.Optimizer): naming_scheme = _slot_variable_naming_for_optimizer( @@ -498,11 +646,12 @@ def _serialize_object_graph(root_checkpointable, saveables_cache): """ checkpointable_objects, path_to_root = ( _breadth_first_checkpointable_traversal(root_checkpointable)) - object_names = { - obj: _object_prefix_from_path(path) - for obj, path in path_to_root.items()} - node_ids = {node: node_id for node_id, node - in enumerate(checkpointable_objects)} + object_names = _ObjectIdentityDictionary() + for obj, path in path_to_root.items(): + object_names[obj] = _object_prefix_from_path(path) + node_ids = _ObjectIdentityDictionary() + for node_id, node in enumerate(checkpointable_objects): + node_ids[node] = node_id slot_variables = _serialize_slot_variables( checkpointable_objects=checkpointable_objects, node_ids=node_ids, @@ -533,11 +682,12 @@ def list_objects(root_checkpointable): # to run. checkpointable_objects, path_to_root = ( _breadth_first_checkpointable_traversal(root_checkpointable)) - object_names = { - obj: _object_prefix_from_path(path) - for obj, path in path_to_root.items()} - node_ids = {node: node_id for node_id, node - in enumerate(checkpointable_objects)} + object_names = _ObjectIdentityDictionary() + for obj, path in path_to_root.items(): + object_names[obj] = _object_prefix_from_path(path) + node_ids = _ObjectIdentityDictionary() + for node_id, node in enumerate(checkpointable_objects): + node_ids[node] = node_id _serialize_slot_variables( checkpointable_objects=checkpointable_objects, node_ids=node_ids, @@ -564,6 +714,93 @@ def gather_initializers(root_checkpointable): if hasattr(c, "initializer") and c.initializer is not None] +@tf_contextlib.contextmanager +def capture_dependencies(template): + """Capture variables created within this scope as `Template` dependencies. + + Requires that `template.variable_scope` is active. + + This scope is intended as a compatibility measure, allowing a checkpointable + object to add dependencies on variables created in a block of code which is + not aware of object-based saving (and instead uses variable names + heavily). This is how `Template` objects add dependencies on variables and + sub-`Template`s. Where possible, use `tf.make_template` directly. + + Args: + template: The `Template` object to register dependencies with. + + Yields: + None (when used as a context manager). + """ + name_prefix = template.variable_scope.name + + def _checkpointable_custom_creator(next_creator, name, initial_value, + checkpointable_parent=None, **kwargs): + """A variable creation hook which adds Checkpointable dependencies. + + Set for example during a `Template`'s first wrapped function + execution. Ensures that (a) `template` depends on any checkpointable + objects using their own `capture_dependencies` scope inside this scope which + create variables, and (b) that any variables not in a more deeply nested + scope are added as dependencies directly. + + The `checkpointable_parent` argument is passed between custom creators but + ignored when the variable object itself is created. This argument indicates + (if not `None`) that a more deeply nested scope has already added the + variable as a dependency, and that parent scopes should add a dependency on + that object rather than on the variable directly. + + Args: + next_creator: See `variable_scope.variable_creator_scope`; the next + creator in the chain. + name: The (full, scope-influenced) name of the variable. The `name_prefix` + itself is stripped for the purposes of object-based dependency tracking, + but scopes opened within this scope are respected. + initial_value: See `variable_scope.variable_creator_scope`. Taken + explicitly so the argument can be re-named and used with + `Checkpointable._add_variable_with_custom_getter`. + checkpointable_parent: If not None, a more deeply nested checkpointable + object and its name prefix which were passed to `capture_dependencies` + to add a dependency on (rather than depending on the variable directly). + **kwargs: Passed through to the next creator. + + Returns: + The output of `next_creator`: the fetched/created variable object. + """ + def _call_next_creator_renaming_initializer(initializer, **inner_kwargs): + inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which + # we don't want to propagate. + return next_creator( + initial_value=initializer, + name=name, + **inner_kwargs) + if name is not None and name.startswith(name_prefix): + scope_stripped_name = name[len(name_prefix) + 1:] + if not checkpointable_parent: + return template._add_variable_with_custom_getter( # pylint: disable=protected-access + initializer=initial_value, + name=scope_stripped_name, + getter=_call_next_creator_renaming_initializer, + # Disable error checking for Checkpointable. Exceptions are instead + # raised if necessary when the object-based saver tries to + # save/restore the object. + overwrite=True, + checkpointable_parent=(template, name_prefix), + **kwargs) + else: + parent_object, parent_name_prefix = checkpointable_parent + template._track_checkpointable( # pylint: disable=protected-access + parent_object, + name=parent_name_prefix[len(name_prefix) + 1:], + overwrite=True) + return next_creator( + name=name, initial_value=initial_value, + checkpointable_parent=(template, name_prefix), **kwargs) + + with variable_scope.variable_creator_scope(_checkpointable_custom_creator): + yield + + class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject): def __init__(self, tensor, name): @@ -593,6 +830,31 @@ class _LoadStatus(object): pass +def streaming_restore(status, session=None): + """When graph building, runs restore ops as soon as they come in. + + Args: + status: A _LoadStatus objects from an object-based saver's + restore(). Streaming restore from name-based checkpoints is not currently + supported. + session: A session to run new restore ops in. + """ + if context.executing_eagerly(): + # Streaming restore is the default/only behavior when executing eagerly. + return + if session is None: + session = ops.get_default_session() + if isinstance(status, NameBasedSaverStatus): + raise NotImplementedError( + "Streaming restore not supported from name-based checkpoints. File a " + "feature request if this limitation bothers you.") + status.run_restore_ops(session=session) + # pylint: disable=protected-access + status._checkpoint.new_restore_ops_callback = ( + lambda ops: session.run(ops, feed_dict=status._feed_dict)) + # pylint: enable=protected-access + + class CheckpointLoadStatus(_LoadStatus): """Checks the status of checkpoint loading and manages restore ops. @@ -647,8 +909,8 @@ class CheckpointLoadStatus(_LoadStatus): for checkpointable_object in list_objects(self._root_checkpointable): self._checkpoint.all_python_objects.add(checkpointable_object) unused_python_objects = ( - set(self._checkpoint.all_python_objects) - - set(self._checkpoint.object_by_proto_id.values())) + _ObjectIdentitySet(self._checkpoint.all_python_objects) + - _ObjectIdentitySet(self._checkpoint.object_by_proto_id.values())) if unused_python_objects: raise AssertionError( ("Some Python objects were not bound to checkpointed values, likely " @@ -684,7 +946,7 @@ class CheckpointLoadStatus(_LoadStatus): if session is None: session = ops.get_default_session() all_objects = list_objects(self._root_checkpointable) - already_initialized_objects = set( + already_initialized_objects = _ObjectIdentitySet( self._checkpoint.object_by_proto_id.values()) initializers_for_non_restored_variables = [ c.initializer for c in all_objects @@ -764,11 +1026,13 @@ _DEPRECATED_RESTORE_INSTRUCTIONS = ( "one this message is coming from) and use that checkpoint in the future.") -@deprecation.deprecated( - date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS) class NameBasedSaverStatus(_LoadStatus): """Status for loading a name-based training checkpoint.""" + # Ideally this deprecation decorator would be on the class, but that + # interferes with isinstance checks. + @deprecation.deprecated( + date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS) def __init__(self, checkpoint, root_checkpointable): self._checkpoint = checkpoint self._root_checkpointable = root_checkpointable @@ -839,7 +1103,7 @@ class _SessionWithFeedDictAdditions(session_lib.SessionInterface): def _copy_saver_with_new_var_list(old_saver, new_var_list): """Copy a `tf.train.Saver`'s state to a new Saver with different variables.""" - new_saver = saver_lib.Saver(var_list=new_var_list) + new_saver = saver_lib.Saver(var_list=new_var_list, max_to_keep=None) # TODO(allenl): Move to copying functionality to Saver? # pylint: disable=protected-access new_saver._last_checkpoints = old_saver._last_checkpoints @@ -899,7 +1163,7 @@ class CheckpointableSaver(object): else: # Maps Checkpointable objects -> attribute names -> SaveableObjects, to # avoid re-creating SaveableObjects when graph building. - self._saveable_object_cache = weakref.WeakKeyDictionary() + self._saveable_object_cache = _ObjectIdentityWeakKeyDictionary() @property def _root_checkpointable(self): @@ -950,11 +1214,11 @@ class CheckpointableSaver(object): with ops.device("/cpu:0"): object_graph_tensor = constant_op.constant( graph_proto.SerializeToString(), dtype=dtypes.string) - assert checkpointable_lib.OBJECT_GRAPH_PROTO_KEY not in named_variables + assert base.OBJECT_GRAPH_PROTO_KEY not in named_variables named_variables.append( _NoRestoreSaveable( tensor=object_graph_tensor, - name=checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)) + name=base.OBJECT_GRAPH_PROTO_KEY)) if (self._last_save_object_graph != graph_proto # When executing eagerly, we need to re-create SaveableObjects each time # save() is called so they pick up new Tensors passed to their @@ -965,7 +1229,8 @@ class CheckpointableSaver(object): self._last_save_saver = _copy_saver_with_new_var_list( old_saver=self._last_save_saver, new_var_list=named_variables) else: - self._last_save_saver = saver_lib.Saver(var_list=named_variables) + self._last_save_saver = saver_lib.Saver( + var_list=named_variables, max_to_keep=None) self._last_save_object_graph = graph_proto with ops.device("/cpu:0"): save_path = self._last_save_saver.save( @@ -973,6 +1238,7 @@ class CheckpointableSaver(object): session=session, feed_additions=feed_additions), save_path=file_prefix, write_meta_graph=False, + write_state=False, global_step=checkpoint_number) return save_path @@ -1044,7 +1310,7 @@ class CheckpointableSaver(object): dtype_map = reader.get_variable_to_dtype_map() try: object_graph_string = reader.get_tensor( - checkpointable_lib.OBJECT_GRAPH_PROTO_KEY) + base.OBJECT_GRAPH_PROTO_KEY) except errors_impl.NotFoundError: # The object graph proto does not exist in this checkpoint. Try the # name-based compatibility mode. @@ -1090,7 +1356,7 @@ class CheckpointableSaver(object): "file a feature request if this limitation bothers you.") self._last_restore_checkpoint = checkpoint self._last_restore_object_graph = object_graph_proto - checkpointable_lib._CheckpointPosition( # pylint: disable=protected-access + base._CheckpointPosition( # pylint: disable=protected-access checkpoint=checkpoint, proto_id=0).restore(self._root_checkpointable) load_status = CheckpointLoadStatus( checkpoint, @@ -1100,7 +1366,7 @@ class CheckpointableSaver(object): @tf_export("train.Checkpoint") -class Checkpoint(checkpointable_lib.Checkpointable): +class Checkpoint(tracking.Checkpointable): """Groups checkpointable objects, saving and restoring them. `Checkpoint`'s constructor accepts keyword arguments whose values are types @@ -1202,7 +1468,7 @@ class Checkpoint(checkpointable_lib.Checkpointable): """ super(Checkpoint, self).__init__() for k, v in sorted(kwargs.items(), key=lambda item: item[0]): - if not isinstance(v, checkpointable_lib.CheckpointableBase): + if not isinstance(v, base.CheckpointableBase): raise ValueError( ("`Checkpoint` was expecting a checkpointable object (an object " "derived from `CheckpointableBase`), got %s. If you believe this " @@ -1221,10 +1487,36 @@ class Checkpoint(checkpointable_lib.Checkpointable): with ops.device("/cpu:0"): # add_variable creates a dependency named "save_counter"; NoDependency # prevents creating a second dependency named "_save_counter". - self._save_counter = checkpointable_lib.NoDependency( + self._save_counter = data_structures.NoDependency( add_variable(self, name="save_counter", initializer=0, dtype=dtypes.int64)) + def write(self, file_prefix, session=None): + """Writes a training checkpoint. + + The checkpoint includes variables created by this object and any + checkpointable objects it depends on at the time `Checkpoint.write()` is + called. + + `write` does not number checkpoints, increment `save_counter`, or update the + metadata used by `tf.train.latest_checkpoint`. It is primarily intended for + use by higher level checkpoint management utilities. `save` provides a very + basic implementation of these features. + + Args: + file_prefix: A prefix to use for the checkpoint filenames + (/path/to/directory/and_a_prefix). + session: The session to evaluate variables in. Ignored when executing + eagerly. If not provided when graph building, the default session is + used. + + Returns: + The full path to the checkpoint (i.e. `file_prefix`). + """ + return self._saver.save( + file_prefix=file_prefix, + session=session) + @property def save_counter(self): """An integer variable which starts at zero and is incremented on save. @@ -1238,12 +1530,19 @@ class Checkpoint(checkpointable_lib.Checkpointable): return self._save_counter def save(self, file_prefix, session=None): - """Save a training checkpoint. + """Saves a training checkpoint and provides basic checkpoint management. The saved checkpoint includes variables created by this object and any checkpointable objects it depends on at the time `Checkpoint.save()` is called. + `save` is a basic convenience wrapper around the `write` method, + sequentially numbering checkpoints using `save_counter` and updating the + metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint + management, for example garbage collection and custom numbering, may be + provided by other utilities which also wrap `write` + (`tf.contrib.checkpoint.CheckpointManager` for example). + Args: file_prefix: A prefix to use for the checkpoint filenames (/path/to/directory/and_a_prefix). Names are generated based on this @@ -1266,15 +1565,20 @@ class Checkpoint(checkpointable_lib.Checkpointable): session.run(self.save_counter.initializer) if not graph_building or self._save_assign_op is None: with ops.colocate_with(self.save_counter): - assign_op = self.save_counter.assign_add(1, read_value=False) + assign_op = self.save_counter.assign_add(1, read_value=True) if graph_building: - self._save_assign_op = assign_op + self._save_assign_op = data_structures.NoDependency(assign_op) if graph_building: - session.run(self._save_assign_op) - return self._saver.save( - file_prefix=file_prefix, - checkpoint_number=self.save_counter, - session=session) + checkpoint_number = session.run(self._save_assign_op) + else: + checkpoint_number = assign_op.numpy() + file_path = self.write("%s-%d" % (file_prefix, checkpoint_number), + session=session) + checkpoint_management.update_checkpoint_state( + save_dir=os.path.dirname(file_prefix), + model_checkpoint_path=file_path, + all_model_checkpoint_paths=[file_path]) + return file_path def restore(self, save_path): """Restore a training checkpoint. diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py index 8cdf5d78554b01874115d438e7f0fadaf5b6b91c..a0a87b6b793b4a560d4088dc2ea8ed7ff696787b 100644 --- a/tensorflow/python/training/checkpointable/util_test.py +++ b/tensorflow/python/training/checkpointable/util_test.py @@ -42,13 +42,15 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import template from tensorflow.python.ops import variable_scope from tensorflow.python.training import adam +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import base +from tensorflow.python.training.checkpointable import tracking from tensorflow.python.training.checkpointable import util as checkpointable_utils -class NonLayerCheckpointable(checkpointable.Checkpointable): +class NonLayerCheckpointable(tracking.Checkpointable): def __init__(self): super(NonLayerCheckpointable, self).__init__() @@ -101,7 +103,7 @@ class InterfaceTests(test.TestCase): name="duplicate", initial_value=1.) duplicate = checkpointable_utils.add_variable( obj, name="duplicate", shape=[]) - with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"): + with self.assertRaisesRegexp(ValueError, "'duplicate'.*already declared"): checkpointable_utils.add_variable(obj, name="duplicate", shape=[]) self.evaluate(checkpointable_utils.gather_initializers(obj)) @@ -136,7 +138,7 @@ class InterfaceTests(test.TestCase): def testInitNotCalled(self): - class NoInit(checkpointable.Checkpointable): + class NoInit(tracking.Checkpointable): def __init__(self): pass @@ -145,7 +147,7 @@ class InterfaceTests(test.TestCase): checkpointable_utils.add_variable(NoInit(), "var", shape=[]) def testShapeDtype(self): - root = checkpointable.Checkpointable() + root = tracking.Checkpointable() v1 = checkpointable_utils.add_variable( root, name="v1", initializer=3., dtype=dtypes.float64) self.assertEqual(dtypes.float64, v1.dtype) @@ -177,7 +179,7 @@ class InterfaceTests(test.TestCase): def testNotCheckpointable(self): class CallsFunctionalStuff( - checkpointable.NotCheckpointable, checkpointable.Checkpointable): + tracking.NotCheckpointable, tracking.Checkpointable): pass test_dir = self.get_temp_dir() @@ -187,7 +189,7 @@ class InterfaceTests(test.TestCase): checkpoint.save(prefix) class CallsFunctionalStuffOtherMRO( - checkpointable.Checkpointable, checkpointable.NotCheckpointable): + tracking.Checkpointable, tracking.NotCheckpointable): pass checkpoint_reversed = checkpointable_utils.Checkpoint( @@ -217,7 +219,7 @@ class _MirroringSaveable(saver_lib.BaseSaverBuilder.SaveableObject): self._mirrored_variable.assign(tensor)) -class _OwnsMirroredVariables(checkpointable.CheckpointableBase): +class _OwnsMirroredVariables(base.CheckpointableBase): """A Checkpointable object which returns a more complex SaveableObject.""" def __init__(self): @@ -232,7 +234,7 @@ class _OwnsMirroredVariables(checkpointable.CheckpointableBase): primary_variable=self.non_dep_variable, mirrored_variable=self.mirrored, name=name) - return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + return {base.VARIABLE_VALUE_KEY: _saveable_factory} # The Saver sorts by name before parsing, so we need a name property. @property @@ -355,7 +357,7 @@ class CheckpointingTests(test.TestCase): optimizer_node.slot_variables[0] .slot_variable_node_id].attributes[0].checkpoint_key) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMoreComplexSaveableReturned(self): v = _OwnsMirroredVariables() checkpoint = checkpointable_utils.Checkpoint(v=v) @@ -375,7 +377,7 @@ class CheckpointingTests(test.TestCase): self.assertEqual(44., self.evaluate(v.non_dep_variable)) self.assertEqual(44., self.evaluate(v.mirrored)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMoreComplexSaveableReturnedWithGlobalName(self): # The same object can also be saved using the name-based saver. v = _OwnsMirroredVariables() @@ -391,7 +393,7 @@ class CheckpointingTests(test.TestCase): self.assertEqual(42., self.evaluate(v.non_dep_variable)) self.assertEqual(42., self.evaluate(v.mirrored)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveRestore(self): model = MyModel() optimizer = adam.AdamOptimizer(0.001) @@ -466,7 +468,8 @@ class CheckpointingTests(test.TestCase): root = checkpointable_utils.Checkpoint( optimizer=optimizer, model=model, optimizer_step=training_util.get_or_create_global_step()) - root.restore(saver_lib.latest_checkpoint(checkpoint_directory)) + root.restore(checkpoint_management.latest_checkpoint( + checkpoint_directory)) for _ in range(num_training_steps): # TODO(allenl): Use a Dataset and serialize/checkpoint it. input_value = constant_op.constant([[3.]]) @@ -494,7 +497,8 @@ class CheckpointingTests(test.TestCase): train_op = optimizer.minimize( model(input_value), global_step=root.global_step) - checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) with self.test_session(graph=ops.get_default_graph()) as session: status = root.restore(save_path=checkpoint_path) status.initialize_or_restore(session=session) @@ -512,13 +516,12 @@ class CheckpointingTests(test.TestCase): self.assertEqual(training_continuation + 1, session.run(root.save_counter)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAgnosticUsage(self): """Graph/eager agnostic usage.""" # Does create garbage when executing eagerly due to ops.Graph() creation. num_training_steps = 10 checkpoint_directory = self.get_temp_dir() - checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") for training_continuation in range(3): with ops.Graph().as_default(), self.test_session( graph=ops.get_default_graph()), test_util.device(use_gpu=True): @@ -527,8 +530,9 @@ class CheckpointingTests(test.TestCase): root = checkpointable_utils.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) - checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory) - status = root.restore(save_path=checkpoint_path) + manager = checkpoint_management.CheckpointManager( + root, checkpoint_directory, max_to_keep=1) + status = root.restore(save_path=manager.latest_checkpoint) input_value = constant_op.constant([[3.]]) train_fn = functools.partial( optimizer.minimize, @@ -539,14 +543,28 @@ class CheckpointingTests(test.TestCase): status.initialize_or_restore() for _ in range(num_training_steps): train_fn() - root.save(file_prefix=checkpoint_prefix) + manager.save() self.assertEqual((training_continuation + 1) * num_training_steps, self.evaluate(root.global_step)) self.assertEqual(training_continuation + 1, self.evaluate(root.save_counter)) + @test_util.run_in_graph_and_eager_modes + def testCustomNumbering(self): + directory = self.get_temp_dir() + prefix = os.path.join(directory, "ckpt") + step = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64) + checkpoint = checkpointable_utils.Checkpoint(step=step) + self.evaluate(step.initializer) + for i in range(5): + path = checkpoint.write("%s-%d" % (prefix, self.evaluate(step))) + expected_suffix = "-%d" % (2 * i,) + if not path.endswith(expected_suffix): + self.fail("%s should have suffix %s" % (path, expected_suffix)) + self.evaluate(step.assign_add(2)) + # pylint: disable=cell-var-from-loop - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testWithDefun(self): num_training_steps = 2 checkpoint_directory = self.get_temp_dir() @@ -560,7 +578,8 @@ class CheckpointingTests(test.TestCase): root = checkpointable_utils.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) - checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) status = root.restore(save_path=checkpoint_path) def train_fn(): @function.defun @@ -590,7 +609,7 @@ class CheckpointingTests(test.TestCase): # pylint: enable=cell-var-from-loop def _get_checkpoint_name(self, name): - root = checkpointable.Checkpointable() + root = tracking.Checkpointable() checkpointable_utils.add_variable( root, name=name, shape=[1, 2], dtype=dtypes.float64) (named_variable,), _, _ = checkpointable_utils._serialize_object_graph( @@ -611,18 +630,18 @@ class CheckpointingTests(test.TestCase): @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNumberedPath(self): - root = checkpointable.Checkpointable() - leaf = checkpointable.Checkpointable() + root = tracking.Checkpointable() + leaf = tracking.Checkpointable() root.leaf = leaf checkpointable_utils.add_variable(leaf, name="v", shape=[]) (named_variable,), _, _ = checkpointable_utils._serialize_object_graph( root, saveables_cache=None) self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", named_variable.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLocalNameValidation(self): - root = checkpointable.Checkpointable() - leaf = checkpointable.Checkpointable() + root = tracking.Checkpointable() + leaf = tracking.Checkpointable() # Dots are escaped, which avoids conflicts with reserved names. root._track_checkpointable(leaf, name=".ATTRIBUTES") checkpointable_utils.add_variable(checkpointable=leaf, name="a", shape=[]) @@ -660,16 +679,16 @@ class CheckpointingTests(test.TestCase): optimizer.apply_gradients( [(g, v) for g, v in zip(grad, model.vars)]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLateDependencyTracking(self): - class Dependency(checkpointable.Checkpointable): + class Dependency(tracking.Checkpointable): def build(self): self.var = checkpointable_utils.add_variable( self, "var", initializer=0.) - class LateDependencies(checkpointable.Checkpointable): + class LateDependencies(tracking.Checkpointable): def add_dep(self): self.dep = Dependency() @@ -692,16 +711,16 @@ class CheckpointingTests(test.TestCase): status.run_restore_ops() self.assertEqual(123., self.evaluate(load_into.dep.var)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDepAfterVar(self): - class Dependency(checkpointable.Checkpointable): + class Dependency(tracking.Checkpointable): def build(self): self.var = checkpointable_utils.add_variable( self, "var", initializer=0.) - class DepAfterVar(checkpointable.Checkpointable): + class DepAfterVar(tracking.Checkpointable): def add_dep(self): dep = Dependency() @@ -724,11 +743,11 @@ class CheckpointingTests(test.TestCase): status.run_restore_ops() self.assertEqual(-14., self.evaluate(loaded_dep_after_var.dep.var)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDeferredSlotRestoration(self): checkpoint_directory = self.get_temp_dir() - root = checkpointable.Checkpointable() + root = tracking.Checkpointable() root.var = checkpointable_utils.add_variable( root, name="var", initializer=0.) optimizer = adam.AdamOptimizer(0.1) @@ -751,7 +770,7 @@ class CheckpointingTests(test.TestCase): 14.)) slots_path = checkpointable_utils.CheckpointableSaver(root).save( os.path.join(checkpoint_directory, "with_slots")) - new_root = checkpointable.Checkpointable() + new_root = tracking.Checkpointable() # Load the slot-containing checkpoint (deferred), then immediately overwrite # the non-slot variable (also deferred). slot_status = checkpointable_utils.CheckpointableSaver( @@ -789,11 +808,11 @@ class CheckpointingTests(test.TestCase): self.evaluate(train_op) slot_status.assert_consumed() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testOverlappingRestores(self): checkpoint_directory = self.get_temp_dir() - save_root = checkpointable.Checkpointable() - save_root.dep = checkpointable.Checkpointable() + save_root = tracking.Checkpointable() + save_root.dep = tracking.Checkpointable() save_root.dep.var = checkpointable_utils.add_variable( save_root.dep, name="var", initializer=0.) self.evaluate(state_ops.assign(save_root.dep.var, 12.)) @@ -802,13 +821,13 @@ class CheckpointingTests(test.TestCase): self.evaluate(state_ops.assign(save_root.dep.var, 13.)) second_path = saver.save(os.path.join(checkpoint_directory, "second")) - first_root = checkpointable.Checkpointable() - second_root = checkpointable.Checkpointable() + first_root = tracking.Checkpointable() + second_root = tracking.Checkpointable() first_status = checkpointable_utils.CheckpointableSaver( first_root).restore(first_path) second_status = checkpointable_utils.CheckpointableSaver( second_root).restore(second_path) - load_dep = checkpointable.Checkpointable() + load_dep = tracking.Checkpointable() load_dep.var = checkpointable_utils.add_variable( load_dep, name="var", shape=[]) first_root.dep = load_dep @@ -822,13 +841,13 @@ class CheckpointingTests(test.TestCase): # Try again with the order of the restore() reversed. The last restore # determines the final value. - first_root = checkpointable.Checkpointable() - second_root = checkpointable.Checkpointable() + first_root = tracking.Checkpointable() + second_root = tracking.Checkpointable() second_status = checkpointable_utils.CheckpointableSaver( second_root).restore(second_path) first_status = checkpointable_utils.CheckpointableSaver( first_root).restore(first_path) - load_dep = checkpointable.Checkpointable() + load_dep = tracking.Checkpointable() load_dep.var = checkpointable_utils.add_variable( load_dep, name="var", shape=[]) first_root.dep = load_dep @@ -840,39 +859,39 @@ class CheckpointingTests(test.TestCase): second_status.run_restore_ops() self.assertEqual(12., self.evaluate(load_dep.var)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testAmbiguousLoad(self): # Not OK to split one checkpoint object into two checkpoint_directory = self.get_temp_dir() - save_root = checkpointable.Checkpointable() - save_root.dep_one = checkpointable.Checkpointable() - save_root.dep_two = checkpointable.Checkpointable() - dep_three = checkpointable.Checkpointable() + save_root = tracking.Checkpointable() + save_root.dep_one = tracking.Checkpointable() + save_root.dep_two = tracking.Checkpointable() + dep_three = tracking.Checkpointable() save_root.dep_one.dep_three = dep_three save_root.dep_two.dep_three = dep_three checkpointable_utils.add_variable(dep_three, name="var", initializer=0.) self.evaluate(checkpointable_utils.gather_initializers(save_root)) save_path = checkpointable_utils.CheckpointableSaver(save_root).save( os.path.join(checkpoint_directory, "ckpt")) - load_root = checkpointable.Checkpointable() + load_root = tracking.Checkpointable() status = checkpointable_utils.CheckpointableSaver(load_root).restore( save_path) - load_root.dep_one = checkpointable.Checkpointable() - load_root.dep_two = checkpointable.Checkpointable() - load_root.dep_one.dep_three = checkpointable.Checkpointable() - load_root.dep_two.dep_three = checkpointable.Checkpointable() + load_root.dep_one = tracking.Checkpointable() + load_root.dep_two = tracking.Checkpointable() + load_root.dep_one.dep_three = tracking.Checkpointable() + load_root.dep_two.dep_three = tracking.Checkpointable() checkpointable_utils.add_variable( load_root.dep_one.dep_three, name="var", initializer=0.) with self.assertRaises(AssertionError): status.assert_consumed() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testObjectsCombined(self): # Currently fine to load two checkpoint objects into one Python object checkpoint_directory = self.get_temp_dir() - save_root = checkpointable.Checkpointable() - save_root.dep_one = checkpointable.Checkpointable() - save_root.dep_two = checkpointable.Checkpointable() + save_root = tracking.Checkpointable() + save_root.dep_one = tracking.Checkpointable() + save_root.dep_two = tracking.Checkpointable() checkpointable_utils.add_variable( save_root.dep_one, name="var1", initializer=32., dtype=dtypes.float64) checkpointable_utils.add_variable( @@ -880,8 +899,8 @@ class CheckpointingTests(test.TestCase): self.evaluate(checkpointable_utils.gather_initializers(save_root)) save_path = checkpointable_utils.CheckpointableSaver(save_root).save( os.path.join(checkpoint_directory, "ckpt")) - load_root = checkpointable.Checkpointable() - load_root.dep_one = checkpointable.Checkpointable() + load_root = tracking.Checkpointable() + load_root.dep_one = tracking.Checkpointable() load_root.dep_two = load_root.dep_one v1 = checkpointable_utils.add_variable( load_root.dep_one, name="var1", shape=[], dtype=dtypes.float64) @@ -893,12 +912,12 @@ class CheckpointingTests(test.TestCase): self.assertEqual(32., self.evaluate(v1)) self.assertEqual(64., self.evaluate(v2)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testDependencyLoop(self): # Note: this test creates garbage during eager execution because it # purposefully creates a reference cycle. - first = checkpointable.Checkpointable() - second = checkpointable.Checkpointable() + first = tracking.Checkpointable() + second = tracking.Checkpointable() first.second = second second.first = first first.v = checkpointable_utils.add_variable( @@ -911,10 +930,10 @@ class CheckpointingTests(test.TestCase): os.path.join(checkpoint_directory, "ckpt")) # Test deferred loading - first_load = checkpointable.Checkpointable() + first_load = tracking.Checkpointable() status = checkpointable_utils.CheckpointableSaver( first_load).restore(save_path) - second_load = checkpointable.Checkpointable() + second_load = tracking.Checkpointable() first_load.second = second_load second_load.first = first_load with self.assertRaises(AssertionError): @@ -939,13 +958,13 @@ class CheckpointingTests(test.TestCase): self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v)) self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testRestoreOnAssign(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") save_graph = ops.Graph() with save_graph.as_default(), self.test_session(save_graph): - first = checkpointable.Checkpointable() + first = tracking.Checkpointable() first.var1 = variable_scope.get_variable( name="outside_var", initializer=0.) first.var2 = variable_scope.get_variable( @@ -956,7 +975,7 @@ class CheckpointingTests(test.TestCase): checkpoint_prefix) restore_graph = ops.Graph() with restore_graph.as_default(), self.test_session(restore_graph): - second = checkpointable.Checkpointable() + second = tracking.Checkpointable() second.var2 = variable_scope.get_variable( name="blah", initializer=0.) status = checkpointable_utils.CheckpointableSaver( @@ -978,7 +997,7 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.test_session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() + obj = tracking.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) @@ -989,18 +1008,19 @@ class CheckpointingTests(test.TestCase): saver.save(checkpoint_prefix) self.assertEqual(before_ops, graph.get_operations()) - @test_util.run_in_graph_and_eager_modes() - def testCheckpointCleanup(self): + @test_util.run_in_graph_and_eager_modes + def testCheckpointState(self): + # No checkpoints are deleted by default checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() + obj = tracking.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) self.evaluate(checkpointable_utils.gather_initializers(obj)) saver = checkpointable_utils.Checkpoint(obj=obj) for _ in range(10): saver.save(checkpoint_prefix) expected_filenames = ["checkpoint"] - for checkpoint_number in range(6, 11): + for checkpoint_number in range(1, 11): expected_filenames.append("ckpt-%d.index" % (checkpoint_number,)) expected_filenames.append( "ckpt-%d.data-00000-of-00001" % (checkpoint_number,)) @@ -1009,11 +1029,11 @@ class CheckpointingTests(test.TestCase): expected_filenames, os.listdir(checkpoint_directory)) - @test_util.run_in_graph_and_eager_modes() - def testCheckpointCleanupChangingVarList(self): + @test_util.run_in_graph_and_eager_modes + def testCheckpointStateChangingVarList(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() + obj = tracking.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) self.evaluate(checkpointable_utils.gather_initializers(obj)) checkpoint = checkpointable_utils.Checkpoint(obj=obj) @@ -1026,8 +1046,8 @@ class CheckpointingTests(test.TestCase): looped_variables.append(new_variable) expected_filenames = ["checkpoint"] # We've copied the saver each time, but checkpoint management should still - # be consistent. - for checkpoint_number in range(6, 11): + # be consistent. Nothing gets deleted. + for checkpoint_number in range(1, 11): expected_filenames.append("ckpt-%d.index" % (checkpoint_number,)) expected_filenames.append( "ckpt-%d.data-00000-of-00001" % (checkpoint_number,)) @@ -1035,6 +1055,15 @@ class CheckpointingTests(test.TestCase): self, expected_filenames, os.listdir(checkpoint_directory)) + self.assertEqual( + checkpoint_prefix + "-10", + checkpoint_management.latest_checkpoint(checkpoint_directory)) + # The checkpoint list only contains the most recent checkpoint, but they're + # all on disk. This means we won't eventually run into proto size limits. + self.assertEqual( + [checkpoint_prefix + "-10"], + (checkpoint_management.get_checkpoint_state(checkpoint_directory) + .all_model_checkpoint_paths)) for v in looped_variables: self.evaluate(v.assign(314)) checkpoint.restore(checkpoint_prefix + "-6").run_restore_ops() @@ -1062,7 +1091,7 @@ class CheckpointingTests(test.TestCase): with graph.as_default(), self.test_session(graph): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - obj = checkpointable.Checkpointable() + obj = tracking.Checkpointable() obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = adam.AdamOptimizer(0.1) obj.opt.minimize(obj.var.read_value()) @@ -1132,7 +1161,7 @@ class CheckpointingTests(test.TestCase): beta1_power, _ = optimizer._get_beta_accumulators() self.assertAllEqual(3., self.evaluate(beta1_power)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_sequential(self): model = sequential.Sequential() checkpoint = checkpointable_utils.Checkpoint(model=model) @@ -1164,7 +1193,7 @@ class CheckpointingTests(test.TestCase): self.assertAllEqual([1., 2., 3., 4., 5.], self.evaluate(deferred_second_dense.bias)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_initialize_if_not_restoring(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") @@ -1179,7 +1208,8 @@ class CheckpointingTests(test.TestCase): optimizer_checkpoint = checkpointable_utils.Checkpoint( optimizer=optimizer) - checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) status = root.restore(save_path=checkpoint_path) input_value = constant_op.constant([[3.]]) train_fn = functools.partial( @@ -1243,9 +1273,21 @@ class CheckpointingTests(test.TestCase): self.assertEqual(42., self.evaluate(optimizer.variables()[0])) +class _ManualScope(tracking.Checkpointable): + + def __call__(self): + with variable_scope.variable_scope("ManualScope") as vs: + self.variable_scope = vs + with checkpointable_utils.capture_dependencies(template=self): + return self._build() + + def _build(self): + return variable_scope.get_variable(name="in_manual_scope", shape=[]) + + class TemplateTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_checkpointable_save_restore(self): def _templated(): @@ -1255,14 +1297,23 @@ class TemplateTests(test.TestCase): v2 = variable_scope.get_variable( "v2", shape=[1], initializer=init_ops.zeros_initializer(), use_resource=True) - return v, v + 1., v2 + manual = _ManualScope() + return v, v + 1., v2, manual, manual() save_template = template.make_template("s1", _templated) - v1_save, _, v2_save = save_template() + v1_save, _, v2_save, manual_scope, manual_scope_v = save_template() + six.assertCountEqual( + self, + [v1_save, v2_save, manual_scope, manual_scope_v, save_template], + checkpointable_utils.list_objects(save_template)) + manual_dep, = manual_scope._checkpoint_dependencies + self.assertEqual("in_manual_scope", manual_dep.name) + self.assertIs(manual_scope_v, manual_dep.ref) optimizer = adam.AdamOptimizer(0.0) save_root = checkpointable_utils.Checkpoint( my_template=save_template, optimizer=optimizer) optimizer.minimize(v1_save.read_value) + self.evaluate([v.initializer for v in save_template.variables]) self.evaluate([v.initializer for v in optimizer.variables()]) self.evaluate(v1_save.assign([12.])) self.evaluate(v2_save.assign([14.])) @@ -1275,17 +1326,19 @@ class TemplateTests(test.TestCase): load_root = checkpointable_utils.Checkpoint( my_template=load_template, optimizer=load_optimizer) status = load_root.restore(save_path) - var, var_plus_one, var2 = load_template() + var, var_plus_one, var2, _, _ = load_template() load_optimizer.minimize(var.read_value) - self.assertEqual(2, len(load_template._checkpoint_dependencies)) + self.assertEqual(3, len(load_template._checkpoint_dependencies)) self.assertEqual("v", load_template._checkpoint_dependencies[0].name) self.assertEqual("v2", load_template._checkpoint_dependencies[1].name) + self.assertEqual("ManualScope", + load_template._checkpoint_dependencies[2].name) status.assert_consumed().run_restore_ops() self.assertAllEqual([12.], self.evaluate(var)) self.assertAllEqual([13.], self.evaluate(var_plus_one)) self.assertAllEqual([14.], self.evaluate(var2)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_checkpointable_save_restore_nested(self): def _inner_template(): @@ -1386,7 +1439,7 @@ class CheckpointCompatibilityTests(test.TestCase): sess=session, save_path=checkpoint_prefix, global_step=root.optimizer_step) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testLoadFromNameBasedSaver(self): """Save a name-based checkpoint, load it using the object-based API.""" with test_util.device(use_gpu=True): @@ -1448,7 +1501,7 @@ class CheckpointCompatibilityTests(test.TestCase): class PythonMetadataTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveLoad(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index caffd042a0917209c87cab8993169dc4bc956039..0d8d74a0969c9b5e393a06a6141e396b4df96670 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -19,9 +19,9 @@ from __future__ import division from __future__ import print_function import threading -import six from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context as eager_context from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -31,70 +31,10 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses_impl from tensorflow.python.platform import tf_logging from tensorflow.python.training import device_util +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.util import nest -# ------------------------------------------------------------------------------ -# Internal API for setting the current thread mode as being either in a -# tower or cross-tower context for a particular distribution strategy. - - -class _ThreadMode(object): - - def __init__(self, dist, cross, tower): - self.distribution_strategy = dist - self.cross_tower_context = cross - self.tower_context = tower - - -class _CrossTowerThreadMode(_ThreadMode): - - def __init__(self, distribution_strategy): - _ThreadMode.__init__( - self, distribution_strategy, distribution_strategy, None) - - -class _InTowerThreadMode(_ThreadMode): - - def __init__(self, tower_ctx): - _ThreadMode.__init__( - self, tower_ctx.distribution_strategy, None, tower_ctx) - - -_per_thread_mode = threading.local() - - -def _push_per_thread_mode(context): - if not hasattr(_per_thread_mode, "stack"): - _per_thread_mode.stack = [] - _per_thread_mode.stack.append(context) - - -def _pop_per_thread_mode(): - _per_thread_mode.stack.pop(-1) - - -class _DefaultTowerThreadMode(_ThreadMode): - """Type of default value returned by `_get_per_thread_mode()`. - - Used when the thread-local stack is empty. - """ - - def __init__(self): - # _default_distribution_strategy and _default_tower_context are - # defined at the bottom of this file. - _ThreadMode.__init__( - self, _default_distribution_strategy, None, _default_tower_context) - - -def _get_per_thread_mode(): - try: - return _per_thread_mode.stack[-1] - except (AttributeError, IndexError): - # _default_tower_mode is defined at the bottom of this file. - return _default_tower_mode - - # ------------------------------------------------------------------------------ # Context tracking whether in a distribution.update() or .update_non_slot() # call. @@ -127,106 +67,16 @@ class UpdateContext(object): _update_device.current = self._old_device -# ------------------------------------------------------------------------------ -# Public API for accessing the current thread mode - - -def get_tower_context(): - """Returns the current TowerContext or None if in a cross-tower context. - - Note that execution: - 1. starts in the default (single-tower) tower context (this function - will return the default TowerContext object); - 2. switches to cross-tower context (in which case this will return - None) when entering a `with DistributionStrategy.scope():` block; - 3. switches to a (non-default) tower context inside - `call_for_each_tower(fn, ...)`; - 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then - inside `merge_fn` you are back in the cross-tower context (and again - this function will return None). - - Note that you can also go directly from step 1 to 4 to switch to a - cross-tower context for the default `DistributionStrategy`. You may - also switch from the cross-tower context of 4 to a tower context by - calling `call_for_each_tower()`, jumping back to step 3. - - Most `DistributionStrategy` methods may only be executed in - a cross-tower context, in a tower context you should use the - `TowerContext` API instead. - - Returns: - The current `TowerContext` object when in a tower context scope, else None. - - Exactly one of `get_tower_context()` and `get_cross_tower_context()` - will return None in a particular block. - """ - return _get_per_thread_mode().tower_context - - -def get_cross_tower_context(): - """Returns the current DistributionStrategy if in a cross-tower context. - - Note that execution: - 1. starts in the default (single-tower) tower context; - 2. switches to cross-tower context when entering a - `with DistributionStrategy.scope():` block; - 3. switches to a (non-default) tower context inside - `call_for_each_tower(fn, ...)`; - 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then - inside `merge_fn` you are back in the cross-tower context. - - Note that you can also go directly from step 1 to 4 to switch to a - cross-tower context for the default `DistributionStrategy`. You may - also switch from the cross-tower context of 4 to a tower context by - calling `call_for_each_tower()`, jumping back to step 3. - - Most `DistributionStrategy` methods may only be executed in - a cross-tower context. - - Returns: - Returns the current `DistributionStrategy` object in a cross-tower - context, or None. - - Exactly one of `get_tower_context()` and `get_cross_tower_context()` - will return None in a particular block. - """ - return _get_per_thread_mode().cross_tower_context - - -def get_distribution_strategy(): - """Returns the current `DistributionStrategy` object. - - Prefer to use `get_tower_context()` or `get_cross_tower_context()` - instead when possible. - - Returns: - A `DistributionStrategy` object. Inside a - `with distribution_strategy.scope()` block, it returns - `distribution_strategy`, otherwise it returns the default - (single-tower) `DistributionStrategy` object. - """ - return _get_per_thread_mode().distribution_strategy - - -def has_distribution_strategy(): - """Return if there is a current non-default `DistributionStrategy`. - - Returns: - True if inside a `with distribution_strategy.scope():`. - """ - return get_distribution_strategy() is not _default_distribution_strategy - - # ------------------------------------------------------------------------------ # Public utility functions. def get_loss_reduction(): - """Reduce `method_string` corresponding to the last loss reduction.""" + """Reduce `aggregation` corresponding to the last loss reduction.""" loss_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access if loss_reduction == losses_impl.Reduction.SUM: - return "sum" - return "mean" + return variable_scope.VariableAggregation.SUM + return variable_scope.VariableAggregation.MEAN # ------------------------------------------------------------------------------ @@ -239,7 +89,8 @@ def _require_cross_tower_context(distribution_strategy): if context.cross_tower_context is distribution_strategy: return # We have an error to report, figure out the right message. if context.distribution_strategy is not distribution_strategy: - if context.distribution_strategy is _default_distribution_strategy: + if (context.distribution_strategy is + distribution_strategy_context._get_default_distribution_strategy()): # pylint: disable=protected-access raise RuntimeError( 'Need to be inside "with distribution_strategy.scope()" for %s' % (distribution_strategy,)) @@ -272,7 +123,8 @@ def _require_distribution_strategy_scope(distribution_strategy): context = _get_per_thread_mode() if context.distribution_strategy is distribution_strategy: return # We have an error to report, figure out the right message. - if context.distribution_strategy is _default_distribution_strategy: + if (context.distribution_strategy is + distribution_strategy_context._get_default_distribution_strategy()): # pylint: disable=protected-access raise RuntimeError( 'Need to be inside "with distribution_strategy.scope()" for %s' % (distribution_strategy,)) @@ -295,7 +147,8 @@ class _CurrentDistributionContext(object): var_creator_scope, var_scope=None, default_device=None): - self._context = _CrossTowerThreadMode(distribution_strategy) + self._context = distribution_strategy_context._CrossTowerThreadMode( # pylint: disable=protected-access + distribution_strategy) self._var_creator_scope = var_creator_scope self._var_scope = var_scope if default_device: @@ -527,17 +380,21 @@ class DistributionStrategy(object): V(`v`), output will have locality V(`v`) as well. * `d.update_non_slot(d.non_slot_devices(), fn)`: in cross-tower context, like `d.update()` except with locality N. - * `d.fetch(t)`: Copy `t` with any locality to the client's CPU device. - TODO(josh11b): Deprecate `fetch`, switch to `read_var` for - reading tower-local variables. + * `d.read_var(v)`: Gets the (read-only) value of the variable `v` (on + the device determined by the current device scope), aggregating + across towers for tower-local variables. Frequently, this will be + done automatically when using `v` in an expression or fetching it in + a cross-tower context, but this function can be used to force that + conversion happens at a particular point in time (for example, to + add the result of the conversion to a graph collection). The standard pattern for updating variables is to: 1. Wrap your input dataset in `d.distribute_dataset()` and create an iterator. 2. Define each tower `d.call_for_each_tower()` up to the point of getting a list of gradient, variable pairs. - 3. Call `d.reduce("sum", t, v)` or `d.batch_reduce()` to sum the - gradients (with locality T) into values with locality V(`v`). + 3. Call `d.reduce(VariableAggregation.SUM, t, v)` or `d.batch_reduce()` to sum + the gradients (with locality T) into values with locality V(`v`). 4. Call `d.update(v)` for each variable to update its value. Steps 3 and 4 are done automatically by class `Optimizer` if you call @@ -584,7 +441,7 @@ class DistributionStrategy(object): Returns: A context manager. """ - if has_distribution_strategy(): + if distribution_strategy_context.has_distribution_strategy(): _require_cross_tower_context(self) return _SameScopeAgainContext(self) @@ -611,43 +468,6 @@ class DistributionStrategy(object): # Note: should support "colocate_with" argument. raise NotImplementedError("must be implemented in descendants") - def tower_local_var_scope(self, reduce_method): - """Inside this scope, new variables will not be mirrored. - - There will still be one component variable per tower, but there is - no requirement that they stay in sync. Instead, when saving them - or calling `fetch()/read_var()`, we use the value that - results when calling `reduce()` on all the towers' variables. - - Note: tower-local implies not trainable. Instead, it is expected - that each tower will directly update (using `assign_add()` or - whatever) its local variable instance but only the aggregated - value (accessible using `fetch()`) will be exported from the - model. When it is acceptable to only aggregate on export, we - greatly reduce communication overhead by using tower-local - variables. - - Note: All component variables will be initialized to the same - value, using the initialization expression from the first tower. - The values will match even if the initialization expression uses - random numbers. - - Args: - reduce_method: String used as a `method_string` to `reduce()` - to get the value to save when checkpointing. - - Returns: - A context manager. - """ - def create_tower_local_variable(next_creator, *args, **kwargs): - _require_distribution_strategy_scope(self) - kwargs["use_resource"] = True - kwargs["tower_local_reduce_method"] = reduce_method - return next_creator(*args, **kwargs) - - _require_distribution_strategy_scope(self) - return variable_scope.variable_creator_scope(create_tower_local_variable) - def read_var(self, v): """Reads the value of a variable. @@ -761,6 +581,85 @@ class DistributionStrategy(object): def _broadcast(self, tensor, destinations): raise NotImplementedError("must be implemented in descendants") + def initialize(self): + """Any initialization to be done before running any computations. + + In eager mode, it executes any initialization as a side effect. + In graph mode, it creates the initialization ops and returns them. + + For example, TPU initialize_system ops. + + Returns: + In eager mode, returns `None`. + In graph mode, a list of ops to execute. Empty list if nothing to be done. + """ + if eager_context.executing_eagerly(): + return + else: + return [] + + def finalize(self): + """Any final actions to be done at the end of all computations. + + In eager mode, it executes any finalize actions as a side effect. + In graph mode, it creates the finalize ops and returns them. + + For example, TPU shutdown ops. + + Returns: + In eager mode, returns `None`. + In graph mode, a list of ops to execute. Empty list if nothing to be done. + """ + if eager_context.executing_eagerly(): + return + else: + return [] + + def run_steps_on_dataset(self, fn, iterator, iterations=1, + initial_loop_values=None): + """Run `fn` with input from `iterator` for `iterations` times. + + This method can be used to run a step function for training a number of + times using input from a dataset. + + Args: + fn: function to run using this distribution strategy. The function must + have the following signature: def fn(context, inputs). + `context` is an instance of `MultiStepContext` that will be passed when + `fn` is run. `context` can be used to specify the outputs to be returned + from `fn` by calling `context.set_last_step_output`. It can also be used + to capture non tensor outputs by `context.set_non_tensor_output`. + See `MultiStepContext` documentation for more information. + `inputs` will have same type/structure as `iterator.get_next()`. + Typically, `fn` will use `call_for_each_tower` method of the strategy + to distribute the computation over multiple towers. + iterator: Iterator of a dataset that represents the input for `fn`. The + caller is responsible for initializing the iterator as needed. + iterations: (Optional) Number of iterations that `fn` should be run. + Defaults to 1. + initial_loop_values: (Optional) Initial values to be passed into the + loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove + initial_loop_values argument when we have a mechanism to infer the + outputs of `fn`. + + Returns: + Returns the `MultiStepContext` object which has the following properties, + among other things: + - run_op: An op that runs `fn` `iterations` times. + - last_step_outputs: A dictionary containing tensors set using + `context.set_last_step_output`. Evaluating this returns the value of + the tensors after the last iteration. + - non_tensor_outputs: A dictionatry containing anything that was set by + `fn` by calling `context.set_non_tensor_output`. + """ + _require_cross_tower_context(self) + return self._run_steps_on_dataset(fn, iterator, iterations, + initial_loop_values) + + def _run_steps_on_dataset(self, fn, iterator, iterations, + initial_loop_values): + raise NotImplementedError("must be implemented in descendants") + def call_for_each_tower(self, fn, *args, **kwargs): """Run `fn` once per tower. @@ -813,12 +712,12 @@ class DistributionStrategy(object): def _call_for_each_tower(self, fn, *args, **kwargs): raise NotImplementedError("must be implemented in descendants") - def reduce(self, method_string, value, destinations=None): + def reduce(self, aggregation, value, destinations=None): """Combine (via e.g. sum or mean) values across towers. Args: - method_string: A string indicating how to combine values, either - "sum" or "mean". + aggregation: Indicates how a variable will be aggregated. Accepted values + are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. value: A per-device value with one value per tower. destinations: An optional mirrored variable, a device string, list of device strings. The return value will be copied to all @@ -833,18 +732,21 @@ class DistributionStrategy(object): # TODO(josh11b): Return an unwrapped value if colocate_with is a # single device. _require_cross_tower_context(self) - assert method_string in ("sum", "mean") - return self._reduce(method_string, value, destinations) + assert aggregation in [ + variable_scope.VariableAggregation.SUM, + variable_scope.VariableAggregation.MEAN + ] + return self._reduce(aggregation, value, destinations) - def _reduce(self, method_string, value, destinations): + def _reduce(self, aggregation, value, destinations): raise NotImplementedError("must be implemented in descendants") - def batch_reduce(self, method_string, value_destination_pairs): + def batch_reduce(self, aggregation, value_destination_pairs): """Combine multiple `reduce` calls into one for faster execution. Args: - method_string: A string indicating how to combine values, either - "sum" or "mean". + aggregation: Indicates how a variable will be aggregated. Accepted values + are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. value_destination_pairs: A sequence of (value, destinations) pairs. See `reduce()` for a description. @@ -853,12 +755,17 @@ class DistributionStrategy(object): """ # TODO(josh11b): More docstring _require_cross_tower_context(self) - assert method_string in ("sum", "mean") - return self._batch_reduce(method_string, value_destination_pairs) - - def _batch_reduce(self, method_string, value_destination_pairs): - return [self.reduce(method_string, t, destinations=v) - for t, v in value_destination_pairs] + assert aggregation in [ + variable_scope.VariableAggregation.SUM, + variable_scope.VariableAggregation.MEAN + ] + return self._batch_reduce(aggregation, value_destination_pairs) + + def _batch_reduce(self, aggregation, value_destination_pairs): + return [ + self.reduce(aggregation, t, destinations=v) + for t, v in value_destination_pairs + ] def update(self, var, fn, *args, **kwargs): """Run `fn` to update `var` using inputs mirrored to the same devices. @@ -914,45 +821,33 @@ class DistributionStrategy(object): def _update_non_slot(self, colocate_with, fn, *args, **kwargs): raise NotImplementedError("must be implemented in descendants") - def fetch(self, val, destination="/device:CPU:0", fn=lambda x: x): - """Return a copy of `val` or `fn(val)` on `destination`. - - This is useful for getting a mirrored value onto a device. It - will attempt to avoid a copy by checking if the value is already - on the destination device. - - TODO(josh11b): Switch to `read_var`. + def unwrap(self, value): + """Returns the list of all per-device values contained in `value`. Args: - val: Value (which may be mirrored) to copy. - destination: A device string to copy the value to. - fn: An optional function to apply to the value on the source - device, before copying. + value: A value returned by `call_for_each_tower()` or a variable + created in `scope()`. Returns: - A `Tensor` on `destination`. + A list of values contained in `value`. If `value` represents a single + value, this returns `[value].` """ - _require_cross_tower_context(self) - assert isinstance(destination, six.string_types) - destination = device_util.resolve(destination) - return self._fetch(val, destination, fn) - - def _fetch(self, val, destination, fn): - raise NotImplementedError("must be implemented in descendants") + return self._unwrap(value) - def unwrap(self, value): - """Returns the list of all per-device values contained in `value`. + def value_container(self, value): + """Returns the container that this per-device `value` belongs to. Args: value: A value returned by `call_for_each_tower()` or a variable created in `scope()`. Returns: - A list of values contained in `value`. If `value` represents a single - value, this returns `[value].` + A container that `value` belongs to. + If value does not belong to any container (including the case of + container having been destroyed), returns the value itself. + `value in unwrap(value_container(value))` will always be true. """ - _require_cross_tower_context(self) - return self._unwrap(value) + raise NotImplementedError("must be implemented in descendants") def _unwrap(self, distributed_value): raise NotImplementedError("must be implemented in descendants") @@ -1064,7 +959,8 @@ class TowerContext(object): def __init__(self, distribution_strategy, tower_id): self._distribution_strategy = distribution_strategy - self._thread_context = _InTowerThreadMode(self) + self._thread_context = distribution_strategy_context._InTowerThreadMode( # pylint: disable=protected-access + self) self._tower_id = tower_id def __enter__(self): @@ -1107,16 +1003,13 @@ class TowerContext(object): def _merge_call(self, merge_fn, *args, **kwargs): """Default implementation for single tower.""" _push_per_thread_mode( # thread-local, so not needed with multiple threads - _CrossTowerThreadMode(self._distribution_strategy)) + distribution_strategy_context._CrossTowerThreadMode( # pylint: disable=protected-access + self._distribution_strategy)) try: return merge_fn(self._distribution_strategy, *args, **kwargs) finally: _pop_per_thread_mode() - def tower_local_var_scope(self, reduce_method): - """Alias for distribution_strategy.tower_local_var_scope().""" - return self._distribution_strategy.tower_local_var_scope(reduce_method) - @property def is_single_tower(self): """Returns whether there is a single tower or multiple.""" @@ -1158,27 +1051,16 @@ class _DefaultDistributionStrategy(DistributionStrategy): def scope(self): """Context manager setting a variable creator and `self` as current.""" - if has_distribution_strategy(): + if distribution_strategy_context.has_distribution_strategy(): raise RuntimeError("Must not nest DistributionStrategy scopes.") def creator(next_creator, *args, **kwargs): _require_distribution_strategy_scope(self) - kwargs.pop("tower_local_reduce_method", None) return next_creator(*args, **kwargs) return _CurrentDistributionContext( self, variable_scope.variable_creator_scope(creator)) - def tower_local_var_scope(self, reduce_method): - """Does not set to resource variables.""" - def create_tower_local_variable(next_creator, *args, **kwargs): - _require_distribution_strategy_scope(self) - kwargs["trainable"] = False - return next_creator(*args, **kwargs) - - _require_distribution_strategy_scope(self) - return variable_scope.variable_creator_scope(create_tower_local_variable) - def colocate_vars_with(self, colocate_with_variable): """Does not require `self.scope`.""" _require_distribution_strategy_scope(self) @@ -1199,9 +1081,9 @@ class _DefaultDistributionStrategy(DistributionStrategy): with TowerContext(self, tower_id=0): return fn(*args, **kwargs) - def _reduce(self, method_string, value, destinations): + def _reduce(self, aggregation, value, destinations): # TODO(josh11b): Use destinations? - del method_string, destinations + del aggregation, destinations return value def _update(self, var, fn, *args, **kwargs): @@ -1219,15 +1101,12 @@ class _DefaultDistributionStrategy(DistributionStrategy): def read_var(self, tower_local_var): return array_ops.identity(tower_local_var) - def _fetch(self, var, destination, fn): - with ops.colocate_with(var): - var = fn(var) - with ops.device(destination): - return array_ops.identity(var) - def _unwrap(self, distributed_value): return [distributed_value] + def value_container(self, value): + return value + @property def is_single_tower(self): return True @@ -1253,6 +1132,7 @@ class _DefaultDistributionStrategy(DistributionStrategy): raise RuntimeError("worker_device_index() method unsupported by " "_DefaultDistributionStrategy.") + # ------------------------------------------------------------------------------ # Common operations @@ -1268,19 +1148,10 @@ def increment_var(v, amount=1): def merge_fn(dist, vm): return dist.group(dist.update(vm, update)) - tower_context = get_tower_context() + tower_context = distribution_strategy_context.get_tower_context() return tower_context.merge_call(merge_fn, v) -# ------------------------------------------------------------------------------ -# Singletons - -_default_distribution_strategy = _DefaultDistributionStrategy() -_default_tower_context = TowerContext( - _default_distribution_strategy, tower_id=0) -_default_tower_mode = _DefaultTowerThreadMode() - - # ------------------------------------------------------------------------------ # We haven't yet implemented deserialization for DistributedVariables. # So here we catch any attempts to deserialize variables @@ -1290,7 +1161,7 @@ _original_from_proto = resource_variable_ops._from_proto_fn def _from_proto_fn(v, import_scope=None): - if has_distribution_strategy(): + if distribution_strategy_context.has_distribution_strategy(): raise NotImplementedError( "Deserialization of variables is not yet supported when using" "distributed strategies.") @@ -1299,3 +1170,10 @@ def _from_proto_fn(v, import_scope=None): resource_variable_ops._from_proto_fn = _from_proto_fn # pylint: enable=protected-access + + +#------------------------------------------------------------------------------- +# Shorthand for some methods from distribution_strategy_context. +_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode # pylint: disable=protected-access +_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode # pylint: disable=protected-access +_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode # pylint: disable=protected-access diff --git a/tensorflow/python/training/distribute_test.py b/tensorflow/python/training/distribute_test.py index 0a4f19c31f6714e1211f9deed9703c02192cc2c0..f03bd3910055d3022e5dc4d22ebb5ffc1a19cef8 100644 --- a/tensorflow/python/training/distribute_test.py +++ b/tensorflow/python/training/distribute_test.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.training import distribute +from tensorflow.python.training import distribution_strategy_context class _TestTowerContext(distribute.TowerContext): @@ -29,6 +30,14 @@ class _TestTowerContext(distribute.TowerContext): return kwargs["test_arg"] +def _get_test_variable(name, synchronization, aggregation): + return { + "name": name, + "synchronization": synchronization, + "aggregation": aggregation + } + + class _TestStrategy(distribute.DistributionStrategy): def _call_for_each_tower(self, fn, *args, **kwargs): @@ -36,16 +45,17 @@ class _TestStrategy(distribute.DistributionStrategy): return fn(*args, **kwargs) def _create_variable(self, next_creator, *args, **kwargs): - return kwargs["name"] + return _get_test_variable(kwargs["name"], kwargs["synchronization"], + kwargs["aggregation"]) def _assert_in_default_state(t): - t.assertIs(distribute._default_tower_context, - distribute.get_tower_context()) - t.assertIs(None, distribute.get_cross_tower_context()) - t.assertIs(distribute._default_distribution_strategy, - distribute.get_distribution_strategy()) - t.assertFalse(distribute.has_distribution_strategy()) + t.assertIs(distribution_strategy_context._get_default_tower_context(), + distribution_strategy_context.get_tower_context()) + t.assertIs(None, distribution_strategy_context.get_cross_tower_context()) + t.assertIs(distribution_strategy_context._get_default_distribution_strategy(), + distribution_strategy_context.get_distribution_strategy()) + t.assertFalse(distribution_strategy_context.has_distribution_strategy()) class TestStrategyTest(test.TestCase): @@ -55,13 +65,19 @@ class TestStrategyTest(test.TestCase): dist = _TestStrategy() def run_fn(): - tower_context = distribute.get_tower_context() + tower_context = distribution_strategy_context.get_tower_context() self.assertTrue(tower_context is not None) - self.assertIs(None, distribute.get_cross_tower_context()) - self.assertTrue(distribute.has_distribution_strategy()) - self.assertIs(dist, distribute.get_distribution_strategy()) + self.assertIs(None, + distribution_strategy_context.get_cross_tower_context()) + self.assertTrue(distribution_strategy_context.has_distribution_strategy()) + self.assertIs(dist, + distribution_strategy_context.get_distribution_strategy()) self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo")) - self.assertEqual("bar", variable_scope.variable(1.0, name="bar")) + expected_value = _get_test_variable( + "bar", variable_scope.VariableSynchronization.AUTO, + variable_scope.VariableAggregation.NONE) + self.assertDictEqual(expected_value, + variable_scope.variable(1.0, name="bar")) with self.assertRaises(RuntimeError): dist.call_for_each_tower(run_fn) @@ -73,11 +89,33 @@ class TestStrategyTest(test.TestCase): _assert_in_default_state(self) dist = _TestStrategy() with dist.scope(): - self.assertIs(None, distribute.get_tower_context()) - self.assertIs(dist, distribute.get_cross_tower_context()) - self.assertTrue(distribute.has_distribution_strategy()) - self.assertIs(dist, distribute.get_distribution_strategy()) - self.assertEqual("baz", variable_scope.variable(1.0, name="baz")) + self.assertIs(None, distribution_strategy_context.get_tower_context()) + self.assertIs(dist, + distribution_strategy_context.get_cross_tower_context()) + self.assertTrue(distribution_strategy_context.has_distribution_strategy()) + self.assertIs(dist, + distribution_strategy_context.get_distribution_strategy()) + expected_value = _get_test_variable( + "baz", variable_scope.VariableSynchronization.AUTO, + variable_scope.VariableAggregation.NONE) + self.assertDictEqual(expected_value, + variable_scope.variable(1.0, name="baz")) + _assert_in_default_state(self) + + def testSettingSynchronizationAndAggregation(self): + _assert_in_default_state(self) + dist = _TestStrategy() + with dist.scope(): + expected_value = _get_test_variable( + "baz", variable_scope.VariableSynchronization.ON_WRITE, + variable_scope.VariableAggregation.MEAN) + self.assertDictEqual( + expected_value, + variable_scope.variable( + 1.0, + name="baz", + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation=variable_scope.VariableAggregation.MEAN)) _assert_in_default_state(self) @@ -87,15 +125,21 @@ class DefaultDistributionStrategyTest(test.TestCase): _assert_in_default_state(self) def merge_fn(dist, s): - self.assertIs(distribute._default_distribution_strategy, dist) - self.assertIs(None, distribute.get_tower_context()) - self.assertIs(dist, distribute.get_cross_tower_context()) - self.assertIs(dist, distribute.get_distribution_strategy()) - self.assertFalse(distribute.has_distribution_strategy()) + self.assertIs( + distribution_strategy_context._get_default_distribution_strategy(), + dist) + self.assertIs(None, distribution_strategy_context.get_tower_context()) + self.assertIs(dist, + distribution_strategy_context.get_cross_tower_context()) + self.assertIs(dist, + distribution_strategy_context.get_distribution_strategy()) + self.assertFalse( + distribution_strategy_context.has_distribution_strategy()) return "foo_" + s - tower_ctx = distribute.get_tower_context() - self.assertIs(distribute._default_tower_context, tower_ctx) + tower_ctx = distribution_strategy_context.get_tower_context() + self.assertIs(distribution_strategy_context._get_default_tower_context(), + tower_ctx) self.assertEqual("foo_bar", tower_ctx.merge_call(merge_fn, "bar")) _assert_in_default_state(self) diff --git a/tensorflow/python/training/distribution_strategy_context.py b/tensorflow/python/training/distribution_strategy_context.py new file mode 100644 index 0000000000000000000000000000000000000000..998b5c35ceeee4a0db6114fc54995605862d79d1 --- /dev/null +++ b/tensorflow/python/training/distribution_strategy_context.py @@ -0,0 +1,203 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility to get distribution strategy related contexts.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.util.lazy_loader import LazyLoader + + +# There is a circular dependency between this and `distribute` module. So we +# load it lazily to workaround this. +distribute_lib = LazyLoader( + "distribute_lib", globals(), + "tensorflow.python.training.distribute") + +# ------------------------------------------------------------------------------ +# Internal API for setting the current thread mode as being either in a +# tower or cross-tower context for a particular distribution strategy. + + +class _ThreadMode(object): + + def __init__(self, dist, cross, tower): + self.distribution_strategy = dist + self.cross_tower_context = cross + self.tower_context = tower + + +class _CrossTowerThreadMode(_ThreadMode): + + def __init__(self, distribution_strategy): + _ThreadMode.__init__( + self, distribution_strategy, distribution_strategy, None) + + +class _InTowerThreadMode(_ThreadMode): + + def __init__(self, tower_ctx): + _ThreadMode.__init__( + self, tower_ctx.distribution_strategy, None, tower_ctx) + + +def _push_per_thread_mode(context): + ops.get_default_graph()._distribution_strategy_stack.append(context) # pylint: disable=protected-access + + +def _pop_per_thread_mode(): + ops.get_default_graph()._distribution_strategy_stack.pop(-1) # pylint: disable=protected-access + + +class _DefaultTowerThreadMode(_ThreadMode): + """Type of default value returned by `_get_per_thread_mode()`. + + Used when the thread-local stack is empty. + """ + + def __init__(self): + _ThreadMode.__init__(self, _get_default_distribution_strategy(), None, + _get_default_tower_context()) + + +def _get_per_thread_mode(): + try: + return ops.get_default_graph()._distribution_strategy_stack[-1] # pylint: disable=protected-access + except (AttributeError, IndexError): + return _get_default_tower_mode() + + +# ------------------------------------------------------------------------------ +# Public API for accessing the current thread mode + + +def get_tower_context(): + """Returns the current TowerContext or None if in a cross-tower context. + + Note that execution: + 1. starts in the default (single-tower) tower context (this function + will return the default TowerContext object); + 2. switches to cross-tower context (in which case this will return + None) when entering a `with DistributionStrategy.scope():` block; + 3. switches to a (non-default) tower context inside + `call_for_each_tower(fn, ...)`; + 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then + inside `merge_fn` you are back in the cross-tower context (and again + this function will return None). + + Note that you can also go directly from step 1 to 4 to switch to a + cross-tower context for the default `DistributionStrategy`. You may + also switch from the cross-tower context of 4 to a tower context by + calling `call_for_each_tower()`, jumping back to step 3. + + Most `DistributionStrategy` methods may only be executed in + a cross-tower context, in a tower context you should use the + `TowerContext` API instead. + + Returns: + The current `TowerContext` object when in a tower context scope, else None. + + Exactly one of `get_tower_context()` and `get_cross_tower_context()` + will return None in a particular block. + """ + return _get_per_thread_mode().tower_context + + +def get_cross_tower_context(): + """Returns the current DistributionStrategy if in a cross-tower context. + + Note that execution: + 1. starts in the default (single-tower) tower context; + 2. switches to cross-tower context when entering a + `with DistributionStrategy.scope():` block; + 3. switches to a (non-default) tower context inside + `call_for_each_tower(fn, ...)`; + 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then + inside `merge_fn` you are back in the cross-tower context. + + Note that you can also go directly from step 1 to 4 to switch to a + cross-tower context for the default `DistributionStrategy`. You may + also switch from the cross-tower context of 4 to a tower context by + calling `call_for_each_tower()`, jumping back to step 3. + + Most `DistributionStrategy` methods may only be executed in + a cross-tower context. + + Returns: + Returns the current `DistributionStrategy` object in a cross-tower + context, or None. + + Exactly one of `get_tower_context()` and `get_cross_tower_context()` + will return None in a particular block. + """ + return _get_per_thread_mode().cross_tower_context + + +def get_distribution_strategy(): + """Returns the current `DistributionStrategy` object. + + Prefer to use `get_tower_context()` or `get_cross_tower_context()` + instead when possible. + + Returns: + A `DistributionStrategy` object. Inside a + `with distribution_strategy.scope()` block, it returns + `distribution_strategy`, otherwise it returns the default + (single-tower) `DistributionStrategy` object. + """ + return _get_per_thread_mode().distribution_strategy + + +def has_distribution_strategy(): + """Return if there is a current non-default `DistributionStrategy`. + + Returns: + True if inside a `with distribution_strategy.scope():`. + """ + return get_distribution_strategy() is not _get_default_distribution_strategy() + + +# ------------------------------------------------------------------------------ +# Defaults that are used when no distribution strategy is explicitly created. +# We create them lazily in a function so that we can workaround the circular +# dependency on distribute_lib. See lazy loader at the top of this file. + +_defaults = { + "distribution_strategy": None, + "tower_context": None, + "tower_mode": None +} + + +def _get_default_distribution_strategy(): + if _defaults["distribution_strategy"] is None: + _defaults["distribution_strategy"] = ( + distribute_lib._DefaultDistributionStrategy()) # pylint: disable=protected-access + return _defaults["distribution_strategy"] + + +def _get_default_tower_context(): + if _defaults["tower_context"] is None: + _defaults["tower_context"] = distribute_lib.TowerContext( + _get_default_distribution_strategy(), tower_id=0) + return _defaults["tower_context"] + + +def _get_default_tower_mode(): + if _defaults["tower_mode"] is None: + _defaults["tower_mode"] = _DefaultTowerThreadMode() + return _defaults["tower_mode"] diff --git a/tensorflow/python/training/ftrl.py b/tensorflow/python/training/ftrl.py index 4fa081fab72df62107cf4957d4ff68240ced9ee0..832c10d454e6083be9715ef0af4642ad3e936f97 100644 --- a/tensorflow/python/training/ftrl.py +++ b/tensorflow/python/training/ftrl.py @@ -86,7 +86,7 @@ class FtrlOptimizer(optimizer.Optimizer): if initial_accumulator_value < 0.0: raise ValueError( - "initial_accumulator_value %f needs to be be positive or zero" % + "initial_accumulator_value %f needs to be positive or zero" % initial_accumulator_value) if learning_rate_power > 0.0: raise ValueError("learning_rate_power %f needs to be negative or zero" % diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py index 10ab4c1137ff226d88902143d4f2281ad77de531..fd195a7965ab7512728e4e9e9e0c51a00b6ad79d 100644 --- a/tensorflow/python/training/learning_rate_decay.py +++ b/tensorflow/python/training/learning_rate_decay.py @@ -19,6 +19,7 @@ from __future__ import print_function import math +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -87,6 +88,12 @@ def exponential_decay(learning_rate, Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("global_step is required for exponential_decay.") @@ -95,14 +102,22 @@ def exponential_decay(learning_rate, [learning_rate, global_step, decay_steps, decay_rate]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) decay_steps = math_ops.cast(decay_steps, dtype) decay_rate = math_ops.cast(decay_rate, dtype) - p = global_step / decay_steps - if staircase: - p = math_ops.floor(p) - return math_ops.multiply( - learning_rate, math_ops.pow(decay_rate, p), name=name) + + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + p = global_step_recomp / decay_steps + if staircase: + p = math_ops.floor(p) + return math_ops.multiply( + learning_rate, math_ops.pow(decay_rate, p), name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.piecewise_constant") @@ -141,48 +156,62 @@ def piecewise_constant(x, boundaries, values, name=None): ValueError: if types of `x` and `boundaries` do not match, or types of all `values` do not match or the number of elements in the lists does not match. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if len(boundaries) != len(values) - 1: raise ValueError( "The length of boundaries should be 1 less than the length of values") with ops.name_scope(name, "PiecewiseConstant", [x, boundaries, values, name]) as name: - x = ops.convert_to_tensor(x) - # Avoid explicit conversion to x's dtype. This could result in faulty - # comparisons, for example if floats are converted to integers. boundaries = ops.convert_n_to_tensor(boundaries) - for i, b in enumerate(boundaries): - if b.dtype.base_dtype != x.dtype.base_dtype: - # We can promote int32 boundaries to int64 without loss of precision. - # This covers the most common case where the user passes in boundaries - # as an array of Python integers. - if (b.dtype.base_dtype == dtypes.int32 and - x.dtype.base_dtype == dtypes.int64): - b = math_ops.cast(b, x.dtype.base_dtype) - boundaries[i] = b - else: - raise ValueError( - "Boundaries (%s) must have the same dtype as x (%s)." % - (b.dtype.base_dtype, x.dtype.base_dtype)) - # TODO(rdipietro): Ensure that boundaries' elements are strictly increasing. values = ops.convert_n_to_tensor(values) - for v in values[1:]: - if v.dtype.base_dtype != values[0].dtype.base_dtype: - raise ValueError( - "Values must have elements all with the same dtype (%s vs %s)." % - (values[0].dtype.base_dtype, v.dtype.base_dtype)) - pred_fn_pairs = [] - pred_fn_pairs.append((x <= boundaries[0], lambda: values[0])) - pred_fn_pairs.append((x > boundaries[-1], lambda: values[-1])) - for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]): - # Need to bind v here; can do this with lambda v=v: ... - pred = (x > low) & (x <= high) - pred_fn_pairs.append((pred, lambda v=v: v)) - - # The default isn't needed here because our conditions are mutually - # exclusive and exhaustive, but tf.case requires it. - default = lambda: values[0] - return control_flow_ops.case(pred_fn_pairs, default, exclusive=True) + + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + x_recomp = ops.convert_to_tensor(x) + # Avoid explicit conversion to x's dtype. This could result in faulty + # comparisons, for example if floats are converted to integers. + for i, b in enumerate(boundaries): + if b.dtype.base_dtype != x_recomp.dtype.base_dtype: + # We can promote int32 boundaries to int64 without loss of precision. + # This covers the most common case where the user passes in boundaries + # as an array of Python integers. + if (b.dtype.base_dtype == dtypes.int32 and + x_recomp.dtype.base_dtype == dtypes.int64): + b = math_ops.cast(b, x_recomp.dtype.base_dtype) + boundaries[i] = b + else: + raise ValueError( + "Boundaries (%s) must have the same dtype as x (%s)." % + (b.dtype.base_dtype, x_recomp.dtype.base_dtype)) + # TODO(rdipietro): Ensure that boundaries' elements strictly increases. + for v in values[1:]: + if v.dtype.base_dtype != values[0].dtype.base_dtype: + raise ValueError( + "Values must have elements all with the same dtype (%s vs %s)." % + (values[0].dtype.base_dtype, v.dtype.base_dtype)) + pred_fn_pairs = [] + pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0])) + pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1])) + for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]): + # Need to bind v here; can do this with lambda v=v: ... + pred = (x_recomp > low) & (x_recomp <= high) + pred_fn_pairs.append((pred, lambda v=v: v)) + + # The default isn't needed here because our conditions are mutually + # exclusive and exhaustive, but tf.case requires it. + default = lambda: values[0] + return control_flow_ops.case(pred_fn_pairs, default, exclusive=True) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.polynomial_decay") @@ -263,6 +292,12 @@ def polynomial_decay(learning_rate, Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("global_step is required for polynomial_decay.") @@ -272,27 +307,35 @@ def polynomial_decay(learning_rate, ]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) - decay_steps = math_ops.cast(decay_steps, dtype) end_learning_rate = math_ops.cast(end_learning_rate, dtype) power = math_ops.cast(power, dtype) - if cycle: - # Find the first multiple of decay_steps that is bigger than global_step. - # If global_step is zero set the multiplier to 1 - multiplier = control_flow_ops.cond( - math_ops.equal(global_step, 0), lambda: 1.0, - lambda: math_ops.ceil(global_step / decay_steps)) - decay_steps = math_ops.multiply(decay_steps, multiplier) - else: - # Make sure that the global_step used is not bigger than decay_steps. - global_step = math_ops.minimum(global_step, decay_steps) - - p = math_ops.div(global_step, decay_steps) - return math_ops.add( - math_ops.multiply(learning_rate - end_learning_rate, - math_ops.pow(1 - p, power)), - end_learning_rate, - name=name) + + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + decay_steps_recomp = math_ops.cast(decay_steps, dtype) + if cycle: + # Find the first multiple of decay_steps that is bigger than + # global_step. If global_step is zero set the multiplier to 1 + multiplier = control_flow_ops.cond( + math_ops.equal(global_step_recomp, 0), lambda: 1.0, + lambda: math_ops.ceil(global_step_recomp / decay_steps)) + decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier) + else: + # Make sure that the global_step used is not bigger than decay_steps. + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + + p = math_ops.div(global_step_recomp, decay_steps_recomp) + return math_ops.add( + math_ops.multiply(learning_rate - end_learning_rate, + math_ops.pow(1 - p, power)), + end_learning_rate, + name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.natural_exp_decay") @@ -313,7 +356,15 @@ def natural_exp_decay(learning_rate, The function returns the decayed learning rate. It is computed as: ```python - decayed_learning_rate = learning_rate * exp(-decay_rate * global_step) + decayed_learning_rate = learning_rate * exp(-decay_rate * global_step / + decay_step) + ``` + + or, if `staircase` is `True`, as: + + ```python + decayed_learning_rate = learning_rate * exp(-decay_rate * floor(global_step / + decay_step)) ``` Example: decay exponentially with a base of 0.96: @@ -322,8 +373,10 @@ def natural_exp_decay(learning_rate, ... global_step = tf.Variable(0, trainable=False) learning_rate = 0.1 + decay_steps = 5 k = 0.5 - learning_rate = tf.train.exponential_time_decay(learning_rate, global_step, k) + learning_rate = tf.train.natural_exp_decay(learning_rate, global_step, + decay_steps, k) # Passing global_step to minimize() will increment it at each step. learning_step = ( @@ -350,6 +403,12 @@ def natural_exp_decay(learning_rate, Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("global_step is required for natural_exp_decay.") @@ -357,14 +416,23 @@ def natural_exp_decay(learning_rate, [learning_rate, global_step, decay_rate]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) decay_steps = math_ops.cast(decay_steps, dtype) decay_rate = math_ops.cast(decay_rate, dtype) - p = global_step / decay_steps - if staircase: - p = math_ops.floor(p) - exponent = math_ops.exp(math_ops.multiply(math_ops.negative(decay_rate), p)) - return math_ops.multiply(learning_rate, exponent, name=name) + + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + p = global_step_recomp / decay_steps + if staircase: + p = math_ops.floor(p) + exponent = math_ops.exp( + math_ops.multiply(math_ops.negative(decay_rate), p)) + return math_ops.multiply(learning_rate, exponent, name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.inverse_time_decay") @@ -432,6 +500,12 @@ def inverse_time_decay(learning_rate, Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("global_step is required for inverse_time_decay.") @@ -439,15 +513,23 @@ def inverse_time_decay(learning_rate, [learning_rate, global_step, decay_rate]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) decay_steps = math_ops.cast(decay_steps, dtype) decay_rate = math_ops.cast(decay_rate, dtype) - p = global_step / decay_steps - if staircase: - p = math_ops.floor(p) - const = math_ops.cast(constant_op.constant(1), learning_rate.dtype) - denom = math_ops.add(const, math_ops.multiply(decay_rate, p)) - return math_ops.div(learning_rate, denom, name=name) + + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + p = global_step_recomp / decay_steps + if staircase: + p = math_ops.floor(p) + const = math_ops.cast(constant_op.constant(1), dtype) + denom = math_ops.add(const, math_ops.multiply(decay_rate, p)) + return math_ops.div(learning_rate, denom, name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.cosine_decay") @@ -492,6 +574,12 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None): learning rate. Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("cosine decay requires global_step") @@ -499,15 +587,23 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None): [learning_rate, global_step]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) decay_steps = math_ops.cast(decay_steps, dtype) - global_step = math_ops.minimum(global_step, decay_steps) - completed_fraction = global_step / decay_steps - cosine_decayed = 0.5 * ( - 1.0 + math_ops.cos(constant_op.constant(math.pi) * completed_fraction)) - decayed = (1 - alpha) * cosine_decayed + alpha - return math_ops.multiply(learning_rate, decayed) + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + completed_fraction = global_step_recomp / decay_steps + cosine_decayed = 0.5 * (1.0 + math_ops.cos( + constant_op.constant(math.pi) * completed_fraction)) + + decayed = (1 - alpha) * cosine_decayed + alpha + return math_ops.multiply(learning_rate, decayed) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.cosine_decay_restarts") @@ -561,6 +657,12 @@ def cosine_decay_restarts(learning_rate, learning rate. Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("cosine decay restarts requires global_step") @@ -568,40 +670,48 @@ def cosine_decay_restarts(learning_rate, learning_rate = ops.convert_to_tensor( learning_rate, name="initial_learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) first_decay_steps = math_ops.cast(first_decay_steps, dtype) alpha = math_ops.cast(alpha, dtype) t_mul = math_ops.cast(t_mul, dtype) m_mul = math_ops.cast(m_mul, dtype) - completed_fraction = global_step / first_decay_steps + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + completed_fraction = global_step_recomp / first_decay_steps - def compute_step(completed_fraction, geometric=False): - if geometric: - i_restart = math_ops.floor( - math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) / - math_ops.log(t_mul)) + def compute_step(completed_fraction, geometric=False): + """Helper for `cond` operation.""" + if geometric: + i_restart = math_ops.floor( + math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) / + math_ops.log(t_mul)) - sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul) - completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart + sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul) + completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart - else: - i_restart = math_ops.floor(completed_fraction) - completed_fraction = completed_fraction - i_restart + else: + i_restart = math_ops.floor(completed_fraction) + completed_fraction -= i_restart + + return i_restart, completed_fraction - return i_restart, completed_fraction + i_restart, completed_fraction = control_flow_ops.cond( + math_ops.equal(t_mul, 1.0), + lambda: compute_step(completed_fraction, geometric=False), + lambda: compute_step(completed_fraction, geometric=True)) - i_restart, completed_fraction = control_flow_ops.cond( - math_ops.equal(t_mul, 1.0), - lambda: compute_step(completed_fraction, geometric=False), - lambda: compute_step(completed_fraction, geometric=True)) + m_fac = m_mul**i_restart + cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos( + constant_op.constant(math.pi) * completed_fraction)) + decayed = (1 - alpha) * cosine_decayed + alpha - m_fac = m_mul**i_restart - cosine_decayed = 0.5 * m_fac * ( - 1.0 + math_ops.cos(constant_op.constant(math.pi) * completed_fraction)) - decayed = (1 - alpha) * cosine_decayed + alpha + return math_ops.multiply(learning_rate, decayed, name=name) - return math_ops.multiply(learning_rate, decayed, name=name) + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.linear_cosine_decay") @@ -664,6 +774,12 @@ def linear_cosine_decay(learning_rate, learning rate. Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("linear cosine decay requires global_step") @@ -671,21 +787,28 @@ def linear_cosine_decay(learning_rate, [learning_rate, global_step]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) decay_steps = math_ops.cast(decay_steps, dtype) num_periods = math_ops.cast(num_periods, dtype) - global_step = math_ops.minimum(global_step, decay_steps) alpha = math_ops.cast(alpha, dtype) beta = math_ops.cast(beta, dtype) - linear_decayed = (decay_steps - global_step) / decay_steps - completed_fraction = global_step / decay_steps - fraction = 2.0 * num_periods * completed_fraction - cosine_decayed = 0.5 * ( - 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + linear_decayed = (decay_steps - global_step_recomp) / decay_steps + completed_fraction = global_step_recomp / decay_steps + fraction = 2.0 * num_periods * completed_fraction + cosine_decayed = 0.5 * ( + 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) + + linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta + return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name) - linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta - return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name) + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr @tf_export("train.noisy_linear_cosine_decay") @@ -756,6 +879,12 @@ def noisy_linear_cosine_decay(learning_rate, learning rate. Raises: ValueError: if `global_step` is not supplied. + + @compatibility(eager) + When eager execution is enabled, this function returns a function which in + turn returns the decayed learning rate Tensor. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + @end_compatibility """ if global_step is None: raise ValueError("noisy linear cosine decay requires global_step") @@ -763,29 +892,36 @@ def noisy_linear_cosine_decay(learning_rate, [learning_rate, global_step]) as name: learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") dtype = learning_rate.dtype - global_step = math_ops.cast(global_step, dtype) decay_steps = math_ops.cast(decay_steps, dtype) - global_step = math_ops.minimum(global_step, decay_steps) initial_variance = math_ops.cast(initial_variance, dtype) variance_decay = math_ops.cast(variance_decay, dtype) num_periods = math_ops.cast(num_periods, dtype) alpha = math_ops.cast(alpha, dtype) beta = math_ops.cast(beta, dtype) - linear_decayed = (decay_steps - global_step) / decay_steps - variance = initial_variance / ( - math_ops.pow(1.0 + global_step, variance_decay)) - std = math_ops.sqrt(variance) - noisy_linear_decayed = ( - linear_decayed + - random_ops.random_normal(linear_decayed.shape, stddev=std)) - - completed_fraction = global_step / decay_steps - fraction = 2.0 * num_periods * completed_fraction - cosine_decayed = 0.5 * ( - 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) - noisy_linear_cosine_decayed = ( - (alpha + noisy_linear_decayed) * cosine_decayed + beta) - - return math_ops.multiply( - learning_rate, noisy_linear_cosine_decayed, name=name) + def decayed_lr(): + """Helper to recompute learning rate; most helpful in eager-mode.""" + global_step_recomp = math_ops.cast(global_step, dtype) + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + linear_decayed = (decay_steps - global_step_recomp) / decay_steps + variance = initial_variance / ( + math_ops.pow(1.0 + global_step_recomp, variance_decay)) + std = math_ops.sqrt(variance) + noisy_linear_decayed = ( + linear_decayed + random_ops.random_normal( + linear_decayed.shape, stddev=std)) + + completed_fraction = global_step_recomp / decay_steps + fraction = 2.0 * num_periods * completed_fraction + cosine_decayed = 0.5 * ( + 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) + noisy_linear_cosine_decayed = ( + (alpha + noisy_linear_decayed) * cosine_decayed + beta) + + return math_ops.multiply( + learning_rate, noisy_linear_cosine_decayed, name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py index 60306e4f1239a759ea1f68492a1211d5f0858997..4f3cf01822c5b56c8fd05f859c3a1db302a57625 100644 --- a/tensorflow/python/training/learning_rate_decay_test.py +++ b/tensorflow/python/training/learning_rate_decay_test.py @@ -21,12 +21,9 @@ from __future__ import print_function import math from tensorflow.python.eager import context -from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util -from tensorflow.python.ops import gen_state_ops # Import resource_variable_ops for the variables-to-tensor implicit conversion. from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import -from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest from tensorflow.python.training import learning_rate_decay @@ -34,31 +31,35 @@ from tensorflow.python.training import learning_rate_decay class LRDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testContinuous(self): - with self.test_session(): - step = 5 - decayed_lr = learning_rate_decay.exponential_decay(0.05, step, 10, 0.96) - expected = .05 * 0.96 ** (5.0 / 10.0) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + self.evaluate(variables.global_variables_initializer()) + step = 5 + decayed_lr = learning_rate_decay.exponential_decay(0.05, step, 10, 0.96) + expected = .05 * 0.96**(5.0 / 10.0) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + @test_util.run_in_graph_and_eager_modes def testStaircase(self): - with self.test_session(): - step = gen_state_ops.variable(shape=[], dtype=dtypes.int32, - name="step", container="", shared_name="") - assign_100 = state_ops.assign(step, 100) - assign_1 = state_ops.assign(step, 1) - assign_2 = state_ops.assign(step, 2) - decayed_lr = learning_rate_decay.exponential_decay(.1, step, 3, 0.96, - staircase=True) - # No change to learning rate - assign_1.op.run() - self.assertAllClose(decayed_lr.eval(), .1, 1e-6) - assign_2.op.run() - self.assertAllClose(decayed_lr.eval(), .1, 1e-6) + if context.executing_eagerly(): + step = resource_variable_ops.ResourceVariable(0) + self.evaluate(variables.global_variables_initializer()) + decayed_lr = learning_rate_decay.exponential_decay( + .1, step, 3, 0.96, staircase=True) + + # No change to learning rate due to staircase + expected = .1 + self.evaluate(step.assign(1)) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + expected = .1 + self.evaluate(step.assign(2)) + self.assertAllClose(self.evaluate(decayed_lr), .1, 1e-6) + # Decayed learning rate - assign_100.op.run() expected = .1 * 0.96 ** (100 // 3) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + self.evaluate(step.assign(100)) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) def testVariables(self): with self.test_session(): @@ -79,38 +80,44 @@ class LRDecayTest(test_util.TensorFlowTestCase): expected = .1 * 0.96 ** (100 // 3) self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPiecewiseConstant(self): x = resource_variable_ops.ResourceVariable(-999) - def pc(): - return learning_rate_decay.piecewise_constant(x, [100, 110, 120], - [1.0, 0.1, 0.01, 0.001]) + decayed_lr = learning_rate_decay.piecewise_constant( + x, [100, 110, 120], [1.0, 0.1, 0.01, 0.001]) self.evaluate(variables.global_variables_initializer()) - self.assertAllClose(self.evaluate(pc()), 1.0, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 1.0, 1e-6) self.evaluate(x.assign(100)) - self.assertAllClose(self.evaluate(pc()), 1.0, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 1.0, 1e-6) self.evaluate(x.assign(105)) - self.assertAllClose(self.evaluate(pc()), 0.1, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.1, 1e-6) self.evaluate(x.assign(110)) - self.assertAllClose(self.evaluate(pc()), 0.1, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.1, 1e-6) self.evaluate(x.assign(120)) - self.assertAllClose(self.evaluate(pc()), 0.01, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.01, 1e-6) self.evaluate(x.assign(999)) - self.assertAllClose(self.evaluate(pc()), 0.001, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.001, 1e-6) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testPiecewiseConstantEdgeCases(self): x_int = resource_variable_ops.ResourceVariable( 0, dtype=variables.dtypes.int32) boundaries, values = [-1.0, 1.0], [1, 2, 3] with self.assertRaises(ValueError): - learning_rate_decay.piecewise_constant(x_int, boundaries, values) + decayed_lr = learning_rate_decay.piecewise_constant( + x_int, boundaries, values) + if context.executing_eagerly(): + decayed_lr() + x = resource_variable_ops.ResourceVariable(0.0) boundaries, values = [-1.0, 1.0], [1.0, 2, 3] with self.assertRaises(ValueError): - learning_rate_decay.piecewise_constant(x, boundaries, values) + decayed_lr = learning_rate_decay.piecewise_constant( + x, boundaries, values) + if context.executing_eagerly(): + decayed_lr() # Test that ref types are valid. if not context.executing_eagerly(): @@ -123,221 +130,205 @@ class LRDecayTest(test_util.TensorFlowTestCase): x_int64 = resource_variable_ops.ResourceVariable( 0, dtype=variables.dtypes.int64) boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7] - def pc(): - return learning_rate_decay.piecewise_constant(x_int64, boundaries, values) + decayed_lr = learning_rate_decay.piecewise_constant( + x_int64, boundaries, values) self.evaluate(variables.global_variables_initializer()) - self.assertAllClose(self.evaluate(pc()), 0.4, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.4, 1e-6) self.evaluate(x_int64.assign(1)) - self.assertAllClose(self.evaluate(pc()), 0.4, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.4, 1e-6) self.evaluate(x_int64.assign(2)) - self.assertAllClose(self.evaluate(pc()), 0.5, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.5, 1e-6) self.evaluate(x_int64.assign(3)) - self.assertAllClose(self.evaluate(pc()), 0.6, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.6, 1e-6) self.evaluate(x_int64.assign(4)) - self.assertAllClose(self.evaluate(pc()), 0.7, 1e-6) + self.assertAllClose(self.evaluate(decayed_lr), 0.7, 1e-6) class LinearDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testHalfWay(self): - with self.test_session(): - step = 5 - lr = 0.05 - end_lr = 0.0 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) - expected = lr * 0.5 - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 5 + lr = 0.05 + end_lr = 0.0 + decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) + expected = lr * 0.5 + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testEnd(self): - with self.test_session(): - step = 10 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) - expected = end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 10 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testHalfWayWithEnd(self): - with self.test_session(): - step = 5 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) - expected = (lr + end_lr) * 0.5 - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 5 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) + expected = (lr + end_lr) * 0.5 + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testBeyondEnd(self): - with self.test_session(): - step = 15 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) - expected = end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 15 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testBeyondEndWithCycle(self): - with self.test_session(): - step = 15 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr, - cycle=True) - expected = (lr - end_lr) * 0.25 + end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + step = 15 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, 10, end_lr, cycle=True) + expected = (lr - end_lr) * 0.25 + end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) class SqrtDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testHalfWay(self): - with self.test_session(): - step = 5 - lr = 0.05 - end_lr = 0.0 - power = 0.5 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr, - power=power) - expected = lr * 0.5 ** power - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 5 + lr = 0.05 + end_lr = 0.0 + power = 0.5 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, 10, end_lr, power=power) + expected = lr * 0.5**power + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testEnd(self): - with self.test_session(): - step = 10 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr, - power=power) - expected = end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 10 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, 10, end_lr, power=power) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testHalfWayWithEnd(self): - with self.test_session(): - step = 5 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr, - power=power) - expected = (lr - end_lr) * 0.5 ** power + end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 5 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, 10, end_lr, power=power) + expected = (lr - end_lr) * 0.5**power + end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testBeyondEnd(self): - with self.test_session(): - step = 15 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr, - power=power) - expected = end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - + step = 15 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, 10, end_lr, power=power) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes def testBeyondEndWithCycle(self): - with self.test_session(): - step = 15 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr, - power=power, cycle=True) - expected = (lr - end_lr) * 0.25 ** power + end_lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + step = 15 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, 10, end_lr, power=power, cycle=True) + expected = (lr - end_lr) * 0.25**power + end_lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) class PolynomialDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testBeginWithCycle(self): - with self.test_session(): - lr = 0.001 - decay_steps = 10 - step = 0 - decayed_lr = learning_rate_decay.polynomial_decay(lr, step, - decay_steps, cycle=True) - expected = lr - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + lr = 0.001 + decay_steps = 10 + step = 0 + decayed_lr = learning_rate_decay.polynomial_decay( + lr, step, decay_steps, cycle=True) + expected = lr + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) class ExponentialDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testDecay(self): initial_lr = 0.1 k = 10 decay_rate = 0.96 - step = gen_state_ops.variable( - shape=[], dtype=dtypes.int32, name="step", container="", shared_name="") - assign_step = state_ops.assign(step, 0) - increment_step = state_ops.assign_add(step, 1) - decayed_lr = learning_rate_decay.natural_exp_decay(initial_lr, step, - k, decay_rate) - with self.test_session(): - assign_step.op.run() - for i in range(k+1): - expected = initial_lr * math.exp(-i / k * decay_rate) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - increment_step.op.run() + step = resource_variable_ops.ResourceVariable(0) + decayed_lr = learning_rate_decay.natural_exp_decay(initial_lr, step, k, + decay_rate) + + self.evaluate(variables.global_variables_initializer()) + for i in range(k + 1): + expected = initial_lr * math.exp(-i / k * decay_rate) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + self.evaluate(step.assign_add(1)) + @test_util.run_in_graph_and_eager_modes def testStaircase(self): initial_lr = 0.1 k = 10 decay_rate = 0.96 - step = gen_state_ops.variable( - shape=[], dtype=dtypes.int32, name="step", container="", shared_name="") - assign_step = state_ops.assign(step, 0) - increment_step = state_ops.assign_add(step, 1) - decayed_lr = learning_rate_decay.natural_exp_decay(initial_lr, - step, - k, - decay_rate, - staircase=True) - with self.test_session(): - assign_step.op.run() - for i in range(k+1): - expected = initial_lr * math.exp(-decay_rate * (i // k)) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - increment_step.op.run() + step = resource_variable_ops.ResourceVariable(0) + decayed_lr = learning_rate_decay.natural_exp_decay( + initial_lr, step, k, decay_rate, staircase=True) + + self.evaluate(variables.global_variables_initializer()) + for i in range(k + 1): + expected = initial_lr * math.exp(-decay_rate * (i // k)) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + self.evaluate(step.assign_add(1)) class InverseDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testDecay(self): initial_lr = 0.1 k = 10 decay_rate = 0.96 - step = gen_state_ops.variable( - shape=[], dtype=dtypes.int32, name="step", container="", shared_name="") - assign_step = state_ops.assign(step, 0) - increment_step = state_ops.assign_add(step, 1) - decayed_lr = learning_rate_decay.inverse_time_decay(initial_lr, - step, - k, + step = resource_variable_ops.ResourceVariable(0) + decayed_lr = learning_rate_decay.inverse_time_decay(initial_lr, step, k, decay_rate) - with self.test_session(): - assign_step.op.run() - for i in range(k+1): - expected = initial_lr / (1 + i / k * decay_rate) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - increment_step.op.run() + self.evaluate(variables.global_variables_initializer()) + for i in range(k + 1): + expected = initial_lr / (1 + i / k * decay_rate) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + self.evaluate(step.assign_add(1)) + + @test_util.run_in_graph_and_eager_modes def testStaircase(self): initial_lr = 0.1 k = 10 decay_rate = 0.96 - step = gen_state_ops.variable( - shape=[], dtype=dtypes.int32, name="step", container="", shared_name="") - assign_step = state_ops.assign(step, 0) - increment_step = state_ops.assign_add(step, 1) - decayed_lr = learning_rate_decay.inverse_time_decay(initial_lr, - step, - k, - decay_rate, - staircase=True) - with self.test_session(): - assign_step.op.run() - for i in range(k+1): - expected = initial_lr / (1 + decay_rate * (i // k)) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) - increment_step.op.run() + step = resource_variable_ops.ResourceVariable(0) + decayed_lr = learning_rate_decay.inverse_time_decay( + initial_lr, step, k, decay_rate, staircase=True) + + self.evaluate(variables.global_variables_initializer()) + for i in range(k + 1): + expected = initial_lr / (1 + decay_rate * (i // k)) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + self.evaluate(step.assign_add(1)) class CosineDecayTest(test_util.TensorFlowTestCase): @@ -348,34 +339,35 @@ class CosineDecayTest(test_util.TensorFlowTestCase): decay = 0.5 * (1.0 + math.cos(math.pi * completed_fraction)) return (1.0 - alpha) * decay + alpha + @test_util.run_in_graph_and_eager_modes def testDecay(self): num_training_steps = 1000 initial_lr = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.cosine_decay( - initial_lr, step, num_training_steps) - expected = self.np_cosine_decay(step, num_training_steps) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.cosine_decay(initial_lr, step, + num_training_steps) + expected = self.np_cosine_decay(step, num_training_steps) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + @test_util.run_in_graph_and_eager_modes def testAlpha(self): num_training_steps = 1000 initial_lr = 1.0 alpha = 0.1 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.cosine_decay( - initial_lr, step, num_training_steps, alpha) - expected = self.np_cosine_decay(step, num_training_steps, alpha) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.cosine_decay(initial_lr, step, + num_training_steps, alpha) + expected = self.np_cosine_decay(step, num_training_steps, alpha) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) class CosineDecayRestartsTest(test_util.TensorFlowTestCase): + def np_cosine_decay_restarts(self, step, decay_steps, t_mul=2.0, m_mul=1.0, alpha=0.0): fac = 1.0 while step >= decay_steps: - step = step - decay_steps + step -= decay_steps decay_steps *= t_mul fac *= m_mul @@ -383,51 +375,51 @@ class CosineDecayRestartsTest(test_util.TensorFlowTestCase): decay = fac * 0.5 * (1.0 + math.cos(math.pi * completed_fraction)) return (1.0 - alpha) * decay + alpha + @test_util.run_in_graph_and_eager_modes def testDecay(self): num_training_steps = 1000 initial_lr = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.cosine_decay_restarts( - initial_lr, step, num_training_steps) - expected = self.np_cosine_decay_restarts(step, num_training_steps) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.cosine_decay_restarts( + initial_lr, step, num_training_steps) + expected = self.np_cosine_decay_restarts(step, num_training_steps) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + @test_util.run_in_graph_and_eager_modes def testAlpha(self): num_training_steps = 1000 initial_lr = 1.0 alpha = 0.1 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.cosine_decay_restarts( - initial_lr, step, num_training_steps, alpha=alpha) - expected = self.np_cosine_decay_restarts(step, num_training_steps, - alpha=alpha) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.cosine_decay_restarts( + initial_lr, step, num_training_steps, alpha=alpha) + expected = self.np_cosine_decay_restarts( + step, num_training_steps, alpha=alpha) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + @test_util.run_in_graph_and_eager_modes def testMMul(self): num_training_steps = 1000 initial_lr = 1.0 m_mul = 0.9 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.cosine_decay_restarts( - initial_lr, step, num_training_steps, m_mul=m_mul) - expected = self.np_cosine_decay_restarts(step, num_training_steps, - m_mul=m_mul) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.cosine_decay_restarts( + initial_lr, step, num_training_steps, m_mul=m_mul) + expected = self.np_cosine_decay_restarts( + step, num_training_steps, m_mul=m_mul) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + @test_util.run_in_graph_and_eager_modes def testTMul(self): num_training_steps = 1000 initial_lr = 1.0 t_mul = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.cosine_decay_restarts( - initial_lr, step, num_training_steps, t_mul=t_mul) - expected = self.np_cosine_decay_restarts(step, num_training_steps, - t_mul=t_mul) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.cosine_decay_restarts( + initial_lr, step, num_training_steps, t_mul=t_mul) + expected = self.np_cosine_decay_restarts( + step, num_training_steps, t_mul=t_mul) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) class LinearCosineDecayTest(test_util.TensorFlowTestCase): @@ -444,65 +436,63 @@ class LinearCosineDecayTest(test_util.TensorFlowTestCase): cosine_decayed = 0.5 * (1.0 + math.cos(math.pi * fraction)) return (alpha + linear_decayed) * cosine_decayed + beta + @test_util.run_in_graph_and_eager_modes def testDefaultDecay(self): num_training_steps = 1000 initial_lr = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.linear_cosine_decay( - initial_lr, step, num_training_steps) - expected = self.np_linear_cosine_decay(step, num_training_steps) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.linear_cosine_decay( + initial_lr, step, num_training_steps) + expected = self.np_linear_cosine_decay(step, num_training_steps) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) + @test_util.run_in_graph_and_eager_modes def testNonDefaultDecay(self): num_training_steps = 1000 initial_lr = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - decayed_lr = learning_rate_decay.linear_cosine_decay( - initial_lr, - step, - num_training_steps, - alpha=0.1, - beta=1e-4, - num_periods=5) - expected = self.np_linear_cosine_decay( - step, - num_training_steps, - alpha=0.1, - beta=1e-4, - num_periods=5) - self.assertAllClose(decayed_lr.eval(), expected, 1e-6) + decayed_lr = learning_rate_decay.linear_cosine_decay( + initial_lr, + step, + num_training_steps, + alpha=0.1, + beta=1e-4, + num_periods=5) + expected = self.np_linear_cosine_decay( + step, num_training_steps, alpha=0.1, beta=1e-4, num_periods=5) + self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6) class NoisyLinearCosineDecayTest(test_util.TensorFlowTestCase): + @test_util.run_in_graph_and_eager_modes def testDefaultNoisyLinearCosine(self): num_training_steps = 1000 initial_lr = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - # No numerical check because of noise - decayed_lr = learning_rate_decay.noisy_linear_cosine_decay( - initial_lr, step, num_training_steps) - decayed_lr.eval() + # No numerical check because of noise + decayed_lr = learning_rate_decay.noisy_linear_cosine_decay( + initial_lr, step, num_training_steps) + # Cannot be deterministically tested + self.evaluate(decayed_lr) + @test_util.run_in_graph_and_eager_modes def testNonDefaultNoisyLinearCosine(self): num_training_steps = 1000 initial_lr = 1.0 for step in range(0, 1500, 250): - with self.test_session(): - # No numerical check because of noise - decayed_lr = learning_rate_decay.noisy_linear_cosine_decay( - initial_lr, - step, - num_training_steps, - initial_variance=0.5, - variance_decay=0.1, - alpha=0.1, - beta=1e-4, - num_periods=5) - decayed_lr.eval() + # No numerical check because of noise + decayed_lr = learning_rate_decay.noisy_linear_cosine_decay( + initial_lr, + step, + num_training_steps, + initial_variance=0.5, + variance_decay=0.1, + alpha=0.1, + beta=1e-4, + num_periods=5) + # Cannot be deterministically tested + self.evaluate(decayed_lr) if __name__ == "__main__": diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index fece3370f343173de46bc447c478264864708dca..7b06bffa4b29b92dd8d3df5d8eaa6ebec1ea44b1 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -298,7 +298,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name stop_grace_period_secs=120, log_step_count_steps=100, max_wait_secs=7200, - save_checkpoint_steps=USE_DEFAULT): + save_checkpoint_steps=USE_DEFAULT, + summary_dir=None): """Creates a `MonitoredSession` for training. For a chief, this utility sets proper session initializer/restorer. It also @@ -348,6 +349,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name `save_checkpoint_steps` and `save_checkpoint_secs` are set to `None`, then the default checkpoint saver isn't used. If both are provided, then only `save_checkpoint_secs` is used. Default not enabled. + summary_dir: A string. Optional path to a directory where to + save summaries. If None, checkpoint_dir is used instead. Returns: A `MonitoredSession` object. @@ -388,11 +391,12 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name master=master, config=config) - if checkpoint_dir: + summary_dir = summary_dir or checkpoint_dir + if summary_dir: if log_step_count_steps and log_step_count_steps > 0: all_hooks.append( basic_session_run_hooks.StepCounterHook( - output_dir=checkpoint_dir, every_n_steps=log_step_count_steps)) + output_dir=summary_dir, every_n_steps=log_step_count_steps)) if (save_summaries_steps and save_summaries_steps > 0) or ( save_summaries_secs and save_summaries_secs > 0): @@ -400,7 +404,9 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name scaffold=scaffold, save_steps=save_summaries_steps, save_secs=save_summaries_secs, - output_dir=checkpoint_dir)) + output_dir=summary_dir)) + + if checkpoint_dir: if (save_checkpoint_secs and save_checkpoint_secs > 0) or ( save_checkpoint_steps and save_checkpoint_steps > 0): all_hooks.append(basic_session_run_hooks.CheckpointSaverHook( diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py index 3806056f01a73d21faf3de4539c0dd1ada5f96f8..92533ca4f3b0080914af55e28368febd8ed1f90f 100644 --- a/tensorflow/python/training/monitored_session_test.py +++ b/tensorflow/python/training/monitored_session_test.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.summary import summary from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import coordinator from tensorflow.python.training import monitored_session from tensorflow.python.training import saver as saver_lib @@ -1364,8 +1365,8 @@ class MonitoredSessionTest(test.TestCase): with monitored_session.MonitoredSession( session_creator=monitored_session.ChiefSessionCreator( scaffold, - checkpoint_filename_with_path=saver_lib.latest_checkpoint( - logdir))) as session: + checkpoint_filename_with_path= + checkpoint_management.latest_checkpoint(logdir))) as session: self.assertEqual(2, session.run(gstep)) def test_retry_initialization_on_aborted_error(self): diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index 60cc54c2645a0f44195bbb86013e0306387aa8aa..4b91d1e963a234951a7b1254eb07038935257136 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -300,7 +300,7 @@ class ExponentialMovingAverage(object): for a given variable. * Build a model normally but load the checkpoint files to evaluate by using the shadow variable names. For this use the `average_name()` method. See - the @{tf.train.Saver} for more + the `tf.train.Saver` for more information on restoring saved variables. Example of restoring the shadow variable values: diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index cae29eea933aec1554beea8d0413fd9febcf2d94..1b6bce28652388dfceaa0b291519e85e869ada83 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import slot_creator from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util import nest @@ -51,8 +52,8 @@ def get_filtered_grad_fn(grad_fn): # those variables are accessed in another thread during the gradient # computation. To get a consistent set of variables, we filter out # those with `None` gradients. - def filtered_grad_fn(x=None): - return [(g, v) for g, v in grad_fn(x) if g is not None] + def filtered_grad_fn(*args, **kwargs): + return [(g, v) for g, v in grad_fn(*args, **kwargs) if g is not None] return filtered_grad_fn @@ -77,9 +78,10 @@ def _deduplicate_indexed_slices(values, indices): def _var_key(var): - if context.executing_eagerly(): - return var._unique_id # pylint: disable=protected-access - return (var.op.graph, var.op.name) + # TODO(ashankar): Consolidate handling for eager and graph + if hasattr(var, "op"): + return (var.op.graph, var.op.name) + return var._unique_id # pylint: disable=protected-access class _OptimizableVariable(object): @@ -461,8 +463,10 @@ class Optimizer( # Have to be careful to call distribute_lib.get_loss_reduction() # *after* loss() is evaluated, so we know what loss reduction it uses. # TODO(josh11b): Test that we handle weight decay in a reasonable way. - if distribute_lib.get_loss_reduction() == "mean": - num_towers = distribute_lib.get_distribution_strategy().num_towers + if (distribute_lib.get_loss_reduction() == + variable_scope.VariableAggregation.MEAN): + num_towers = distribution_strategy_context.get_distribution_strategy( + ).num_towers if num_towers > 1: loss_value *= (1. / num_towers) @@ -478,8 +482,10 @@ class Optimizer( "be a function when eager execution is enabled.") # Scale loss if using a "mean" loss reduction and multiple towers. - if distribute_lib.get_loss_reduction() == "mean": - num_towers = distribute_lib.get_distribution_strategy().num_towers + if (distribute_lib.get_loss_reduction() == + variable_scope.VariableAggregation.MEAN): + num_towers = distribution_strategy_context.get_distribution_strategy( + ).num_towers if num_towers > 1: loss *= (1. / num_towers) @@ -545,15 +551,15 @@ class Optimizer( # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse(). # Handle DistributionStrategy case. - if distribute_lib.get_cross_tower_context(): + if distribution_strategy_context.get_cross_tower_context(): raise RuntimeError("Use `_distributed_apply()` instead of " "`apply_gradients()` in a cross-tower context.") # TODO(isaprykin): Get rid of `has_distribution_strategy()` check by # always calling _distributed_apply(), using the default distribution # as needed. - if distribute_lib.has_distribution_strategy(): - grads_and_vars = get_filtered_grad_fn(lambda _: grads_and_vars)() - return distribute_lib.get_tower_context().merge_call( + if distribution_strategy_context.has_distribution_strategy(): + grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)() + return distribution_strategy_context.get_tower_context().merge_call( self._distributed_apply, grads_and_vars, global_step, name) # No DistributionStrategy case. @@ -649,7 +655,8 @@ class Optimizer( towers. If `global_step` was not None, that operation also increments `global_step`. """ - reduced_grads = distribution.batch_reduce("sum", grads_and_vars) + reduced_grads = distribution.batch_reduce( + variable_scope.VariableAggregation.SUM, grads_and_vars) var_list = [v for _, v in grads_and_vars] grads_and_vars = zip(reduced_grads, var_list) # Note that this is called in a cross-tower context. @@ -730,15 +737,15 @@ class Optimizer( if not named_slots: return None - if hasattr(var, "_mirrored_container"): + if hasattr(var, "_distributed_container"): # NOTE: If this isn't patched, then there is no `handle` in # `_resource_apply_dense`. - mirrored_container = var._mirrored_container() - assert mirrored_container is not None + distributed_container = var._distributed_container() + assert distributed_container is not None if context.executing_eagerly(): - key = mirrored_container._unique_id + key = distributed_container._unique_id else: - key = (mirrored_container.graph, mirrored_container._shared_name) + key = (distributed_container.graph, distributed_container._shared_name) # pylint: enable=protected-access mirrored_slot = named_slots.get(key, None) if mirrored_slot is None: return None @@ -795,7 +802,8 @@ class Optimizer( v = self._non_slot_dict.get(key, None) if v is None: self._maybe_initialize_checkpointable() - distribution_strategy = distribute_lib.get_distribution_strategy() + distribution_strategy = ( + distribution_strategy_context.get_distribution_strategy()) with distribution_strategy.colocate_vars_with(colocate_with): if eager: restored_initial_value = self._preload_simple_restoration( @@ -839,7 +847,7 @@ class Optimizer( def _get_non_slot_variable(self, name, graph=None): non_slot = self._non_slot_dict.get((name, graph), None) - if hasattr(non_slot, "_mirrored_container"): + if hasattr(non_slot, "_distributed_container"): # This is a mirrored non-slot. In order to enable code like `_finish` # to assign to a non-slot, return the current context replica. return non_slot.get() diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py index 0cab6410e83ca1880a0a4a80d2cfa5c17517af95..dfe9176beaf27f3cfa945eee8693ba7c5e9551fa 100644 --- a/tensorflow/python/training/optimizer_test.py +++ b/tensorflow/python/training/optimizer_test.py @@ -34,7 +34,7 @@ from tensorflow.python.training import gradient_descent class OptimizerTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testBasic(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -112,7 +112,7 @@ class OptimizerTest(test.TestCase): self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)], var1.eval()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoVariables(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: # pylint: disable=cell-var-from-loop @@ -127,7 +127,7 @@ class OptimizerTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'No.*variables'): sgd_op.minimize(loss) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoGradients(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -145,7 +145,7 @@ class OptimizerTest(test.TestCase): # var1 has no gradient sgd_op.minimize(loss, var_list=[var1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoGradientsForAnyVariables_Minimize(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -161,7 +161,7 @@ class OptimizerTest(test.TestCase): 'No gradients provided for any variable'): sgd_op.minimize(loss, var_list=[var0, var1]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNoGradientsForAnyVariables_ApplyGradients(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -175,7 +175,7 @@ class OptimizerTest(test.TestCase): 'No gradients provided for any variable'): sgd_op.apply_gradients([(None, var0), (None, var1)]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testGradientsAsVariables(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): # Note that we name the variables uniquely here since the variables don't @@ -215,7 +215,7 @@ class OptimizerTest(test.TestCase): self.assertAllClose([-14., -13.], self.evaluate(var0)) self.assertAllClose([-6., -5.], self.evaluate(var1)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testComputeGradientsWithTensors(self): x = ops.convert_to_tensor(1.0) def f(): diff --git a/tensorflow/python/training/quantize_training.i b/tensorflow/python/training/quantize_training.i index fb5e47efa0259d02df3ccf2e9b1430e027f8fcfb..41e62e02521bf9ad39d09bb8ad7d3c108916e34a 100644 --- a/tensorflow/python/training/quantize_training.i +++ b/tensorflow/python/training/quantize_training.i @@ -56,7 +56,7 @@ PyObject* DoQuantizeTrainingOnGraphDefHelper( %insert("python") %{ def do_quantize_training_on_graphdef(input_graph, num_bits): - """A general quantization scheme is being developed in @{tf.contrib.quantize}. + """A general quantization scheme is being developed in `tf.contrib.quantize`. Consider using that instead, though since it is in the tf.contrib namespace, it is not subject to backward compatibility guarantees. @@ -73,6 +73,8 @@ def do_quantize_training_on_graphdef(input_graph, num_bits): do_quantize_training_on_graphdef._tf_api_names = [ 'train.do_quantize_training_on_graphdef'] +do_quantize_training_on_graphdef._tf_api_names_v1 = [ + 'train.do_quantize_training_on_graphdef'] %} %unignoreall diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index b8f58a288c73cde881a60a8e467c8224d2547c5d..04fce496bd9b90f36027939a98f2d33130b55250 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -21,16 +21,12 @@ from __future__ import print_function import collections import os.path -import re -import sys import time import uuid import numpy as np import six -from google.protobuf import text_format - from tensorflow.core.protobuf import checkpointable_object_graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import saver_pb2 @@ -42,7 +38,6 @@ from tensorflow.python.framework import device as pydev from tensorflow.python.framework import errors from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops -from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_io_ops @@ -53,14 +48,25 @@ from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saveable_object from tensorflow.python.training import training_util -from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util import compat from tensorflow.python.util.tf_export import tf_export +# TODO(allenl): Remove these aliases once all users are migrated off. +get_checkpoint_state = checkpoint_management.get_checkpoint_state +update_checkpoint_state = checkpoint_management.update_checkpoint_state +generate_checkpoint_state_proto = ( + checkpoint_management.generate_checkpoint_state_proto) +latest_checkpoint = checkpoint_management.latest_checkpoint +checkpoint_exists = checkpoint_management.checkpoint_exists +get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes +remove_checkpoint = checkpoint_management.remove_checkpoint + + # Op names which identify variable reads which should be saved. _VARIABLE_OPS = set(["Variable", "VariableV2", @@ -127,8 +133,10 @@ class BaseSaverBuilder(object): def f(): with ops.device(v.device): x = v.read_value() - with ops.device("/device:CPU:0"): - return array_ops.identity(x) + # To allow variables placed on non-CPU devices to be checkpointed, + # we copy them to CPU on the same machine first. + with ops.device("/device:CPU:0"): + return array_ops.identity(x) return f self.handle_op = var.handle @@ -206,21 +214,19 @@ class BaseSaverBuilder(object): filename_tensor: String Tensor. saveables: List of BaseSaverBuilder.SaveableObject objects. preferred_shard: Int. Shard to open first when loading a sharded file. - restore_sequentially: Bool. If true, each restore is sequential. + restore_sequentially: Unused. Bool. If true, each restore is sequential. Returns: A list of Tensors resulting from reading 'saveable' from 'filename'. """ + del restore_sequentially all_tensors = [] - assign_ops = [] for saveable in saveables: - restore_control_inputs = assign_ops[-1:] if restore_sequentially else [] with ops.device(_set_cpu0(saveable.device) if saveable.device else None): - with ops.control_dependencies(restore_control_inputs): - all_tensors.extend( - self.restore_op(filename_tensor, saveable, preferred_shard)) + all_tensors.extend( + self.restore_op(filename_tensor, saveable, preferred_shard)) return all_tensors # pylint: disable=unused-argument @@ -859,218 +865,6 @@ def _get_saver_or_default(): return saver -def _GetCheckpointFilename(save_dir, latest_filename): - """Returns a filename for storing the CheckpointState. - - Args: - save_dir: The directory for saving and restoring checkpoints. - latest_filename: Name of the file in 'save_dir' that is used - to store the CheckpointState. - - Returns: - The path of the file that contains the CheckpointState proto. - """ - if latest_filename is None: - latest_filename = "checkpoint" - return os.path.join(save_dir, latest_filename) - - -@tf_export("train.generate_checkpoint_state_proto") -def generate_checkpoint_state_proto(save_dir, - model_checkpoint_path, - all_model_checkpoint_paths=None): - """Generates a checkpoint state proto. - - Args: - save_dir: Directory where the model was saved. - model_checkpoint_path: The checkpoint file. - all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted - checkpoints, sorted from oldest to newest. If this is a non-empty list, - the last element must be equal to model_checkpoint_path. These paths - are also saved in the CheckpointState proto. - - Returns: - CheckpointState proto with model_checkpoint_path and - all_model_checkpoint_paths updated to either absolute paths or - relative paths to the current save_dir. - """ - if all_model_checkpoint_paths is None: - all_model_checkpoint_paths = [] - - if (not all_model_checkpoint_paths or - all_model_checkpoint_paths[-1] != model_checkpoint_path): - logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.", - model_checkpoint_path) - all_model_checkpoint_paths.append(model_checkpoint_path) - - # Relative paths need to be rewritten to be relative to the "save_dir" - # if model_checkpoint_path already contains "save_dir". - if not os.path.isabs(save_dir): - if not os.path.isabs(model_checkpoint_path): - model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir) - for i in range(len(all_model_checkpoint_paths)): - p = all_model_checkpoint_paths[i] - if not os.path.isabs(p): - all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir) - - coord_checkpoint_proto = CheckpointState( - model_checkpoint_path=model_checkpoint_path, - all_model_checkpoint_paths=all_model_checkpoint_paths) - - return coord_checkpoint_proto - - -@tf_export("train.update_checkpoint_state") -def update_checkpoint_state(save_dir, - model_checkpoint_path, - all_model_checkpoint_paths=None, - latest_filename=None): - """Updates the content of the 'checkpoint' file. - - This updates the checkpoint file containing a CheckpointState - proto. - - Args: - save_dir: Directory where the model was saved. - model_checkpoint_path: The checkpoint file. - all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted - checkpoints, sorted from oldest to newest. If this is a non-empty list, - the last element must be equal to model_checkpoint_path. These paths - are also saved in the CheckpointState proto. - latest_filename: Optional name of the checkpoint file. Default to - 'checkpoint'. - - Raises: - RuntimeError: If any of the model checkpoint paths conflict with the file - containing CheckpointSate. - """ - _update_checkpoint_state( - save_dir=save_dir, - model_checkpoint_path=model_checkpoint_path, - all_model_checkpoint_paths=all_model_checkpoint_paths, - latest_filename=latest_filename, - save_relative_paths=False) - - -def _update_checkpoint_state(save_dir, - model_checkpoint_path, - all_model_checkpoint_paths=None, - latest_filename=None, - save_relative_paths=False): - """Updates the content of the 'checkpoint' file. - - This updates the checkpoint file containing a CheckpointState - proto. - - Args: - save_dir: Directory where the model was saved. - model_checkpoint_path: The checkpoint file. - all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted - checkpoints, sorted from oldest to newest. If this is a non-empty list, - the last element must be equal to model_checkpoint_path. These paths - are also saved in the CheckpointState proto. - latest_filename: Optional name of the checkpoint file. Default to - 'checkpoint'. - save_relative_paths: If `True`, will write relative paths to the checkpoint - state file. - - Raises: - RuntimeError: If any of the model checkpoint paths conflict with the file - containing CheckpointSate. - """ - # Writes the "checkpoint" file for the coordinator for later restoration. - coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename) - if save_relative_paths: - if os.path.isabs(model_checkpoint_path): - rel_model_checkpoint_path = os.path.relpath( - model_checkpoint_path, save_dir) - else: - rel_model_checkpoint_path = model_checkpoint_path - rel_all_model_checkpoint_paths = [] - for p in all_model_checkpoint_paths: - if os.path.isabs(p): - rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir)) - else: - rel_all_model_checkpoint_paths.append(p) - ckpt = generate_checkpoint_state_proto( - save_dir, - rel_model_checkpoint_path, - all_model_checkpoint_paths=rel_all_model_checkpoint_paths) - else: - ckpt = generate_checkpoint_state_proto( - save_dir, - model_checkpoint_path, - all_model_checkpoint_paths=all_model_checkpoint_paths) - - if coord_checkpoint_filename == ckpt.model_checkpoint_path: - raise RuntimeError("Save path '%s' conflicts with path used for " - "checkpoint state. Please use a different save path." % - model_checkpoint_path) - - # Preventing potential read/write race condition by *atomically* writing to a - # file. - file_io.atomic_write_string_to_file(coord_checkpoint_filename, - text_format.MessageToString(ckpt)) - - -@tf_export("train.get_checkpoint_state") -def get_checkpoint_state(checkpoint_dir, latest_filename=None): - """Returns CheckpointState proto from the "checkpoint" file. - - If the "checkpoint" file contains a valid CheckpointState - proto, returns it. - - Args: - checkpoint_dir: The directory of checkpoints. - latest_filename: Optional name of the checkpoint file. Default to - 'checkpoint'. - - Returns: - A CheckpointState if the state was available, None - otherwise. - - Raises: - ValueError: if the checkpoint read doesn't have model_checkpoint_path set. - """ - ckpt = None - coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir, - latest_filename) - f = None - try: - # Check that the file exists before opening it to avoid - # many lines of errors from colossus in the logs. - if file_io.file_exists(coord_checkpoint_filename): - file_content = file_io.read_file_to_string( - coord_checkpoint_filename) - ckpt = CheckpointState() - text_format.Merge(file_content, ckpt) - if not ckpt.model_checkpoint_path: - raise ValueError("Invalid checkpoint state loaded from %s", - checkpoint_dir) - # For relative model_checkpoint_path and all_model_checkpoint_paths, - # prepend checkpoint_dir. - if not os.path.isabs(ckpt.model_checkpoint_path): - ckpt.model_checkpoint_path = os.path.join(checkpoint_dir, - ckpt.model_checkpoint_path) - for i in range(len(ckpt.all_model_checkpoint_paths)): - p = ckpt.all_model_checkpoint_paths[i] - if not os.path.isabs(p): - ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p) - except errors.OpError as e: - # It's ok if the file cannot be read - logging.warning("%s: %s", type(e).__name__, e) - logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) - return None - except text_format.ParseError as e: - logging.warning("%s: %s", type(e).__name__, e) - logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) - return None - finally: - if f: - f.close() - return ckpt - - @tf_export("train.Saver") class Saver(object): """Saves and restores variables. @@ -1413,7 +1207,7 @@ class Saver(object): # Otherwise delete the files. try: - remove_checkpoint( + checkpoint_management.remove_checkpoint( self._CheckpointFilename(p), self.saver_def.version, meta_graph_suffix) except Exception as e: # pylint: disable=broad-except @@ -1519,7 +1313,7 @@ class Saver(object): Args: checkpoint_paths: a list of checkpoint paths. """ - mtimes = get_checkpoint_mtimes(checkpoint_paths) + mtimes = checkpoint_management.get_checkpoint_mtimes(checkpoint_paths) self.set_last_checkpoints_with_time(list(zip(checkpoint_paths, mtimes))) def save(self, @@ -1625,7 +1419,7 @@ class Saver(object): model_checkpoint_path = compat.as_str(model_checkpoint_path) if write_state: self._RecordLastCheckpoint(model_checkpoint_path) - _update_checkpoint_state( + checkpoint_management.update_checkpoint_state_internal( save_dir=save_path_parent, model_checkpoint_path=model_checkpoint_path, all_model_checkpoint_paths=self.last_checkpoints, @@ -1640,7 +1434,7 @@ class Saver(object): raise exc if write_meta_graph: - meta_graph_filename = _meta_graph_filename( + meta_graph_filename = checkpoint_management.meta_graph_filename( checkpoint_file, meta_graph_suffix=meta_graph_suffix) if not context.executing_eagerly(): with sess.graph.as_default(): @@ -1708,12 +1502,17 @@ class Saver(object): save_path: Path where parameters were previously saved. Raises: - ValueError: If save_path is None. + ValueError: If save_path is None or not a valid checkpoint. """ if self._is_empty: return if save_path is None: raise ValueError("Can't load save_path when it is None.") + + if not checkpoint_management.checkpoint_exists(compat.as_text(save_path)): + raise ValueError("The passed save_path is not a valid checkpoint: " + + compat.as_text(save_path)) + logging.info("Restoring parameters from %s", compat.as_text(save_path)) try: if context.executing_eagerly(): @@ -1721,23 +1520,22 @@ class Saver(object): else: sess.run(self.saver_def.restore_op_name, {self.saver_def.filename_tensor_name: save_path}) - except errors.NotFoundError: - exception_type, exception_value, exception_traceback = sys.exc_info() - # The checkpoint would not be loaded successfully as is. Try to parse it - # as an object-based checkpoint. - should_reraise = False + except errors.NotFoundError as err: + # There are three common conditions that might cause this error: + # 0. The file is missing. We ignore here, as this is checked above. + # 1. This is an object-based checkpoint trying name-based loading. + # 2. The graph has been altered and a variable or other name is missing. + + # 1. The checkpoint would not be loaded successfully as is. Try to parse + # it as an object-based checkpoint. try: - reader = pywrap_tensorflow.NewCheckpointReader(save_path) - object_graph_string = reader.get_tensor( - checkpointable.OBJECT_GRAPH_PROTO_KEY) + names_to_keys = object_graph_key_mapping(save_path) except errors.NotFoundError: - # This is not an object-based checkpoint, or the checkpoint doesn't - # exist. Re-raise the original exception, but do it outside the except - # block so the object graph lookup isn't included in the stack trace. - should_reraise = True - if should_reraise: - six.reraise(exception_type, exception_value, exception_traceback) - del exception_traceback # avoid reference cycles + # 2. This is not an object-based checkpoint, which likely means there + # is a graph mismatch. Re-raise the original error with + # a helpful message (b/110263146) + raise _wrap_restore_error_with_msg( + err, "a Variable name or other graph key that is missing") # This is an object-based checkpoint. We'll print a warning and then do # the restore. @@ -1746,36 +1544,18 @@ class Saver(object): "may be somewhat fragile, and will re-build the Saver. Instead, " "consider loading object-based checkpoints using " "tf.train.Checkpoint().") - self._restore_from_object_based_checkpoint( - sess=sess, save_path=save_path, - object_graph_string=object_graph_string) - - def _restore_from_object_based_checkpoint(self, sess, save_path, - object_graph_string): - """A compatibility mode for reading object-based checkpoints.""" - object_graph_proto = ( - checkpointable_object_graph_pb2.CheckpointableObjectGraph()) - object_graph_proto.ParseFromString(object_graph_string) - names_to_keys = {} - for node in object_graph_proto.nodes: - for attribute in node.attributes: - names_to_keys[attribute.full_name] = attribute.checkpoint_key - saveables = self._builder._ValidateAndSliceInputs(self._var_list) # pylint: disable=protected-access - for saveable in saveables: - for spec in saveable.specs: - if spec.name not in names_to_keys: - raise errors.NotFoundError( - None, None, - message=("Attempting to load an object-based checkpoint using " - "variable names, but could not find %s in the " - "checkpoint.") % spec.name) - spec.name = names_to_keys[spec.name] - if self._object_restore_saver is None: - # Cache the Saver so multiple restore() calls don't pollute the graph when - # graph building. This assumes keys are consistent (i.e. this is the same - # type of object-based checkpoint we saw previously). - self._object_restore_saver = Saver(saveables) - self._object_restore_saver.restore(sess=sess, save_path=save_path) + self._object_restore_saver = saver_from_object_based_checkpoint( + checkpoint_path=save_path, + var_list=self._var_list, + builder=self._builder, + names_to_keys=names_to_keys, + cached_saver=self._object_restore_saver) + self._object_restore_saver.restore(sess=sess, save_path=save_path) + except errors.InvalidArgumentError as err: + # There is a mismatch between the graph and the checkpoint being loaded. + # We add a more reasonable error message here to help users (b/110263146) + raise _wrap_restore_error_with_msg( + err, "a mismatch between the current graph and the graph") @staticmethod def _add_collection_def(meta_graph_def, key, export_scope=None): @@ -1790,55 +1570,6 @@ class Saver(object): export_scope=export_scope) -def _prefix_to_checkpoint_path(prefix, format_version): - """Returns the pathname of a checkpoint file, given the checkpoint prefix. - - For V1 checkpoint, simply returns the prefix itself (the data file). For V2, - returns the pathname to the index file. - - Args: - prefix: a string, the prefix of a checkpoint. - format_version: the checkpoint format version that corresponds to the - prefix. - Returns: - The pathname of a checkpoint file, taking into account the checkpoint - format version. - """ - if format_version == saver_pb2.SaverDef.V2: - return prefix + ".index" # The index file identifies a checkpoint. - return prefix # Just the data file. - - -@tf_export("train.latest_checkpoint") -def latest_checkpoint(checkpoint_dir, latest_filename=None): - """Finds the filename of latest saved checkpoint file. - - Args: - checkpoint_dir: Directory where the variables were saved. - latest_filename: Optional name for the protocol buffer file that - contains the list of most recent checkpoint filenames. - See the corresponding argument to `Saver.save()`. - - Returns: - The full path to the latest checkpoint or `None` if no checkpoint was found. - """ - # Pick the latest checkpoint based on checkpoint state. - ckpt = get_checkpoint_state(checkpoint_dir, latest_filename) - if ckpt and ckpt.model_checkpoint_path: - # Look for either a V2 path or a V1 path, with priority for V2. - v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path, - saver_pb2.SaverDef.V2) - v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path, - saver_pb2.SaverDef.V1) - if file_io.get_matching_files(v2_path) or file_io.get_matching_files( - v1_path): - return ckpt.model_checkpoint_path - else: - logging.error("Couldn't match files for checkpoint %s", - ckpt.model_checkpoint_path) - return None - - @tf_export("train.import_meta_graph") def import_meta_graph(meta_graph_or_file, clear_devices=False, import_scope=None, **kwargs): @@ -1915,6 +1646,14 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False, execution is enabled. @end_compatibility """ # pylint: disable=g-doc-exception + return _import_meta_graph_with_return_elements( + meta_graph_or_file, clear_devices, import_scope, **kwargs)[0] + + +def _import_meta_graph_with_return_elements( + meta_graph_or_file, clear_devices=False, import_scope=None, + return_elements=None, **kwargs): + """Import MetaGraph, and return both a saver and returned elements.""" if context.executing_eagerly(): raise RuntimeError("Exporting/importing meta graphs is not supported when " "eager execution is enabled. No graph exists when eager " @@ -1924,12 +1663,22 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False, else: meta_graph_def = meta_graph_or_file - imported_vars = meta_graph.import_scoped_meta_graph( - meta_graph_def, - clear_devices=clear_devices, - import_scope=import_scope, - **kwargs) + imported_vars, imported_return_elements = ( + meta_graph.import_scoped_meta_graph_with_return_elements( + meta_graph_def, + clear_devices=clear_devices, + import_scope=import_scope, + return_elements=return_elements, + **kwargs)) + + saver = _create_saver_from_imported_meta_graph( + meta_graph_def, import_scope, imported_vars) + return saver, imported_return_elements + +def _create_saver_from_imported_meta_graph( + meta_graph_def, import_scope, imported_vars): + """Return a saver for restoring variable values to an imported MetaGraph.""" if meta_graph_def.HasField("saver_def"): # Infer the scope that is prepended by `import_scoped_meta_graph`. scope = import_scope @@ -2028,121 +1777,105 @@ def export_meta_graph(filename=None, return meta_graph_def -@tf_export("train.checkpoint_exists") -def checkpoint_exists(checkpoint_prefix): - """Checks whether a V1 or V2 checkpoint exists with the specified prefix. - - This is the recommended way to check if a checkpoint exists, since it takes - into account the naming difference between V1 and V2 formats. - - Args: - checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking - priority. Typically the result of `Saver.save()` or that of - `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or - V1/V2. - Returns: - A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists. - """ - pathname = _prefix_to_checkpoint_path(checkpoint_prefix, - saver_pb2.SaverDef.V2) - if file_io.get_matching_files(pathname): - return True - elif file_io.get_matching_files(checkpoint_prefix): - return True - else: - return False +def _wrap_restore_error_with_msg(err, extra_verbiage): + err_msg = ("Restoring from checkpoint failed. This is most likely " + "due to {} from the checkpoint. Please ensure that you " + "have not altered the graph expected based on the checkpoint. " + "Original error:\n\n{}").format(extra_verbiage, err.message) + return err.__class__(err.node_def, err.op, err_msg) -@tf_export("train.get_checkpoint_mtimes") -def get_checkpoint_mtimes(checkpoint_prefixes): - """Returns the mtimes (modification timestamps) of the checkpoints. +ops.register_proto_function( + ops.GraphKeys.SAVERS, + proto_type=saver_pb2.SaverDef, + to_proto=Saver.to_proto, + from_proto=Saver.from_proto) - Globs for the checkpoints pointed to by `checkpoint_prefixes`. If the files - exist, collect their mtime. Both V2 and V1 checkpoints are considered, in - that priority. - This is the recommended way to get the mtimes, since it takes into account - the naming difference between V1 and V2 formats. +def object_graph_key_mapping(checkpoint_path): + """Return name to key mappings from the checkpoint. Args: - checkpoint_prefixes: a list of checkpoint paths, typically the results of - `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of - sharded/non-sharded or V1/V2. + checkpoint_path: string, path to object-based checkpoint + Returns: - A list of mtimes (in microseconds) of the found checkpoints. + Dictionary mapping tensor names to checkpoint keys. """ - mtimes = [] - - def match_maybe_append(pathname): - fnames = file_io.get_matching_files(pathname) - if fnames: - mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9) - return True - return False - - for checkpoint_prefix in checkpoint_prefixes: - # Tries V2's metadata file first. - pathname = _prefix_to_checkpoint_path(checkpoint_prefix, - saver_pb2.SaverDef.V2) - if match_maybe_append(pathname): - continue - # Otherwise, tries V1, where the prefix is the complete pathname. - match_maybe_append(checkpoint_prefix) - - return mtimes - - -@tf_export("train.remove_checkpoint") -def remove_checkpoint(checkpoint_prefix, - checkpoint_format_version=saver_pb2.SaverDef.V2, - meta_graph_suffix="meta"): - """Removes a checkpoint given by `checkpoint_prefix`. + reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) + object_graph_string = reader.get_tensor( + checkpointable.OBJECT_GRAPH_PROTO_KEY) + object_graph_proto = ( + checkpointable_object_graph_pb2.CheckpointableObjectGraph()) + object_graph_proto.ParseFromString(object_graph_string) + names_to_keys = {} + for node in object_graph_proto.nodes: + for attribute in node.attributes: + names_to_keys[attribute.full_name] = attribute.checkpoint_key + return names_to_keys + + +def saver_from_object_based_checkpoint( + checkpoint_path, var_list=None, builder=None, names_to_keys=None, + cached_saver=None): + """Return a `Saver` which reads from an object-based checkpoint. + + This function validates that all variables in the variables list are remapped + in the object-based checkpoint (or `names_to_keys` dict if provided). A + saver will be created with the list of remapped variables. + + The `cached_saver` argument allows the user to pass in a previously created + saver, so multiple `saver.restore()` calls don't pollute the graph when graph + building. This assumes that keys are consistent, meaning that the + 1) `checkpoint_path` checkpoint, and + 2) checkpoint used to create the `cached_saver` + are the same type of object-based checkpoint. If this argument is set, this + function will simply validate that all variables have been remapped by the + checkpoint at `checkpoint_path`. + + Note that in general, `tf.train.Checkpoint` should be used to restore/save an + object-based checkpoint. Args: - checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result - of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of - sharded/non-sharded or V1/V2. - checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to - `SaverDef.V2`. - meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. - """ - _delete_file_if_exists( - _meta_graph_filename(checkpoint_prefix, meta_graph_suffix)) - if checkpoint_format_version == saver_pb2.SaverDef.V2: - # V2 has a metadata file and some data files. - _delete_file_if_exists(checkpoint_prefix + ".index") - _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????") - else: - # V1, Legacy. Exact match on the data file. - _delete_file_if_exists(checkpoint_prefix) - - -def _delete_file_if_exists(filespec): - """Deletes files matching `filespec`.""" - for pathname in file_io.get_matching_files(filespec): - file_io.delete_file(pathname) - - -def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"): - """Returns the meta graph filename. - - Args: - checkpoint_filename: Name of the checkpoint file. - meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. + checkpoint_path: string, path to object-based checkpoint + var_list: list of `Variables` that appear in the checkpoint. If `None`, + `var_list` will be set to all saveable objects. + builder: a `BaseSaverBuilder` instance. If `None`, a new `BulkSaverBuilder` + will be created. + names_to_keys: dict mapping string tensor names to checkpooint keys. If + `None`, this dict will be generated from the checkpoint file. + cached_saver: Cached `Saver` object with remapped variables. Returns: - MetaGraph file name. - """ - # If the checkpoint_filename is sharded, the checkpoint_filename could - # be of format model.ckpt-step#-?????-of-shard#. For example, - # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002. - basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename) - meta_graph_filename = ".".join([basename, meta_graph_suffix]) - return meta_graph_filename + `Saver` with remapped variables for reading from an object-based checkpoint. - -ops.register_proto_function( - ops.GraphKeys.SAVERS, - proto_type=saver_pb2.SaverDef, - to_proto=Saver.to_proto, - from_proto=Saver.from_proto) + Raises: + ValueError if the checkpoint provided is not an object-based checkpoint. + NotFoundError: If one of the variables in `var_list` can not be found in the + checkpoint. This could mean the checkpoint or `names_to_keys` mapping is + missing the variable. + """ + if names_to_keys is None: + try: + names_to_keys = object_graph_key_mapping(checkpoint_path) + except errors.NotFoundError: + raise ValueError("Checkpoint in %s not an object-based checkpoint." + % checkpoint_path) + if var_list is None: + var_list = variables._all_saveable_objects() # pylint: disable=protected-access + if builder is None: + builder = BulkSaverBuilder() + + saveables = builder._ValidateAndSliceInputs(var_list) # pylint: disable=protected-access + for saveable in saveables: + for spec in saveable.specs: + if spec.name not in names_to_keys: + raise errors.NotFoundError( + None, None, + message=("Attempting to load an object-based checkpoint using " + "variable names, but could not find %s in the " + "checkpoint.") % spec.name) + spec.name = names_to_keys[spec.name] + + if cached_saver is None: + return Saver(saveables) + return cached_saver diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index e3be7d868e5805cf89c111cf29ea8412521da1be..b55e64122a607766d22e7f8017a39779005ac93e 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -18,22 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib import functools import math import os import random -import shutil -import sys -import tempfile import time -import traceback import numpy as np import six from google.protobuf.any_pb2 import Any -from google.protobuf import text_format from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import meta_graph_pb2 @@ -73,13 +67,14 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary import summary from tensorflow.python.training import adam +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import gradient_descent from tensorflow.python.training import queue_runner_impl from tensorflow.python.training import saver as saver_module from tensorflow.python.training import saver_test_utils from tensorflow.python.training import training_util -from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import base as checkpointable_base +from tensorflow.python.training.checkpointable import tracking as checkpointable_tracking from tensorflow.python.training.checkpointable import util as checkpointable_utils from tensorflow.python.util import compat @@ -171,10 +166,28 @@ class SaverTest(test.TestCase): def testBasic(self): self.basicSaveRestore(variables.Variable) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testResourceBasic(self): self.basicSaveRestore(resource_variable_ops.ResourceVariable) + def testResourceColocation(self): + partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2) + with ops_lib.device("/job:ps/device:GPU:0"): + v = variable_scope.get_variable("v0", + shape=[10, 2], + partitioner=partitioner, + use_resource=True) + saver_module.Saver({"v0": v}).build() + save_op = None + for op in ops_lib.get_default_graph().get_operations(): + if op.type == "SaveV2": + save_op = op + break + assert save_op is not None + for save_inp in save_op.inputs[3:]: + # Input to SaveV2 op is placed on CPU of the same device as the Variable. + self.assertEqual("/job:ps/device:CPU:0", save_inp.device) + def testResourceVariableReadOpsAddedDeterministically(self): graph_defs = [] num_graphs = 10 @@ -252,7 +265,7 @@ class SaverTest(test.TestCase): self.assertAllEqual(w3.eval(), 3.0) self.assertAllEqual(w4.eval(), 4.0) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testResourceSaveRestoreCachingDevice(self): save_path = os.path.join(self.get_temp_dir(), "resource_cache") with self.test_session(graph=ops_lib.Graph()) as sess: @@ -326,11 +339,13 @@ class SaverTest(test.TestCase): self.assertTrue(isinstance(val, six.string_types)) self.assertEqual(save_path1, val) - self.assertEqual(saver_module.latest_checkpoint(save_dir1), save_path1) + self.assertEqual( + checkpoint_management.latest_checkpoint(save_dir1), save_path1) save_dir2 = os.path.join(self.get_temp_dir(), "save_dir2") os.renames(save_dir1, save_dir2) save_path2 = os.path.join(save_dir2, "save_copy_restore") - self.assertEqual(saver_module.latest_checkpoint(save_dir2), save_path2) + self.assertEqual( + checkpoint_management.latest_checkpoint(save_dir2), save_path2) # Start a second session. In that session the parameter nodes # have not been initialized either. @@ -368,8 +383,8 @@ class SaverTest(test.TestCase): for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2): with self.test_session() as sess: save = saver_module.Saver({"v0": v0}, write_version=ver) - with self.assertRaisesRegexp(errors.NotFoundError, - "Failed to find any matching files for"): + with self.assertRaisesRegexp( + ValueError, "The passed save_path is not a valid checkpoint:"): save.restore(sess, "invalid path") def testInt64(self): @@ -671,7 +686,7 @@ class SaverTest(test.TestCase): save.restore(sess, save_path) self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], var.eval()) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testSaveWithGlobalStep(self, pad_step_number=False): save_path = os.path.join(self.get_temp_dir(), "ckpt_with_global_step") global_step_int = 5 @@ -769,6 +784,63 @@ class SaverTest(test.TestCase): self.assertEqual(20.0, v1.eval()) save.save(sess, save_path) + def testSaveRestoreAndValidateVariableDtype(self): + for variable_op in [ + variables.Variable, resource_variable_ops.ResourceVariable + ]: + save_path = os.path.join(self.get_temp_dir(), "basic_save_restore") + + # Build the first session. + with self.test_session(graph=ops_lib.Graph()) as sess: + v0 = variable_op(10.0, name="v0", dtype=dtypes.float32) + + if not context.executing_eagerly(): + self.evaluate([variables.global_variables_initializer()]) + + save = saver_module.Saver({"v0": v0}) + save.save(sess, save_path) + + # Start a second session. + with self.test_session(graph=ops_lib.Graph()) as sess: + v0_wrong_dtype = variable_op(1, name="v0", dtype=dtypes.int32) + # Restore the saved value with different dtype + # in the parameter nodes. + save = saver_module.Saver({"v0": v0_wrong_dtype}) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "original dtype"): + save.restore(sess, save_path) + + # Test restoring large tensors (triggers a thread pool) + def testRestoreLargeTensors(self): + save_dir = self.get_temp_dir() + def _model(): + small_v = [variable_scope.get_variable( + "small%d" % i, shape=[10, 2], use_resource=True) for i in range(5)] + large_v = [variable_scope.get_variable( + "large%d" % i, shape=[32000, 1000], use_resource=True) + for i in range(3)] + return small_v + large_v + + save_graph = ops_lib.Graph() + with save_graph.as_default(), self.test_session(graph=save_graph) as sess: + orig_vars = _model() + sess.run(variables.global_variables_initializer()) + save = saver_module.Saver(max_to_keep=1) + variables.global_variables_initializer().run() + save.save(sess, save_dir) + orig_vals = sess.run(orig_vars) + + restore_graph = ops_lib.Graph() + with restore_graph.as_default(), self.test_session( + graph=restore_graph) as sess: + restored_vars = _model() + save = saver_module.Saver(max_to_keep=1) + save.restore(sess, save_dir) + restored_vals = sess.run(restored_vars) + + for orig, restored in zip(orig_vals, restored_vals): + self.assertAllEqual(orig, restored) + class SaveRestoreShardedTest(test.TestCase): @@ -809,7 +881,7 @@ class SaveRestoreShardedTest(test.TestCase): self.assertEqual(save_path + "-?????-of-00002", val) else: self.assertEqual(save_path, val) - meta_graph_filename = saver_module._meta_graph_filename(val) + meta_graph_filename = checkpoint_management.meta_graph_filename(val) self.assertEqual(save_path + ".meta", meta_graph_filename) if save._write_version is saver_pb2.SaverDef.V1: @@ -903,11 +975,11 @@ class SaveRestoreShardedTest(test.TestCase): if save._write_version is saver_pb2.SaverDef.V1: self.assertEqual( - saver_module.latest_checkpoint(self.get_temp_dir()), + checkpoint_management.latest_checkpoint(self.get_temp_dir()), os.path.join(self.get_temp_dir(), "sharded_basics-?????-of-00002")) else: self.assertEqual( - saver_module.latest_checkpoint(self.get_temp_dir()), + checkpoint_management.latest_checkpoint(self.get_temp_dir()), os.path.join(self.get_temp_dir(), "sharded_basics")) def testSaverDef(self): @@ -1057,7 +1129,7 @@ class MaxToKeepTest(test.TestCase): def assertCheckpointState(self, model_checkpoint_path, all_model_checkpoint_paths, save_dir): - checkpoint_state = saver_module.get_checkpoint_state(save_dir) + checkpoint_state = checkpoint_management.get_checkpoint_state(save_dir) self.assertEqual(checkpoint_state.model_checkpoint_path, model_checkpoint_path) self.assertEqual(checkpoint_state.all_model_checkpoint_paths, @@ -1065,7 +1137,7 @@ class MaxToKeepTest(test.TestCase): def testMaxToKeepEager(self): with context.eager_mode(): - save_dir = self._get_test_dir("max_to_keep_non_sharded") + save_dir = self._get_test_dir("max_to_keep_eager") v = variable_scope.variable(10.0, name="v") save = saver_module.Saver({"v": v}, max_to_keep=2) @@ -1075,7 +1147,7 @@ class MaxToKeepTest(test.TestCase): s1 = save.save(None, os.path.join(save_dir, "s1")) self.assertEqual([s1], save.last_checkpoints) - self.assertTrue(saver_module.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) self.assertCheckpointState( model_checkpoint_path=s1, all_model_checkpoint_paths=[s1], @@ -1083,8 +1155,8 @@ class MaxToKeepTest(test.TestCase): s2 = save.save(None, os.path.join(save_dir, "s2")) self.assertEqual([s1, s2], save.last_checkpoints) - self.assertTrue(saver_module.checkpoint_exists(s1)) - self.assertTrue(saver_module.checkpoint_exists(s2)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertCheckpointState( model_checkpoint_path=s2, all_model_checkpoint_paths=[s1, s2], @@ -1092,9 +1164,9 @@ class MaxToKeepTest(test.TestCase): s3 = save.save(None, os.path.join(save_dir, "s3")) self.assertEqual([s2, s3], save.last_checkpoints) - self.assertFalse(saver_module.checkpoint_exists(s1)) - self.assertTrue(saver_module.checkpoint_exists(s2)) - self.assertTrue(saver_module.checkpoint_exists(s3)) + self.assertFalse(checkpoint_management.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) + self.assertTrue(checkpoint_management.checkpoint_exists(s3)) self.assertCheckpointState( model_checkpoint_path=s3, all_model_checkpoint_paths=[s2, s3], @@ -1109,9 +1181,9 @@ class MaxToKeepTest(test.TestCase): # Adding s2 again (old s2 is removed first, then new s2 appended) s2 = save.save(None, os.path.join(save_dir, "s2")) self.assertEqual([s3, s2], save.last_checkpoints) - self.assertFalse(saver_module.checkpoint_exists(s1)) - self.assertTrue(saver_module.checkpoint_exists(s3)) - self.assertTrue(saver_module.checkpoint_exists(s2)) + self.assertFalse(checkpoint_management.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s3)) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertCheckpointState( model_checkpoint_path=s2, all_model_checkpoint_paths=[s3, s2], @@ -1120,8 +1192,8 @@ class MaxToKeepTest(test.TestCase): # Adding s1 (s3 should now be deleted as oldest in list) s1 = save.save(None, os.path.join(save_dir, "s1")) self.assertEqual([s2, s1], save.last_checkpoints) - self.assertFalse(saver_module.checkpoint_exists(s3)) - self.assertTrue(saver_module.checkpoint_exists(s2)) + self.assertFalse(checkpoint_management.checkpoint_exists(s3)) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertCheckpointState( model_checkpoint_path=s1, all_model_checkpoint_paths=[s2, s1], @@ -1130,9 +1202,9 @@ class MaxToKeepTest(test.TestCase): s2 = save2.save(None, os.path.join(save_dir, "s2")) self.assertEqual([s3, s2], save2.last_checkpoints) # Created by the first helper. - self.assertTrue(saver_module.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) # Deleted by the first helper. - self.assertFalse(saver_module.checkpoint_exists(s3)) + self.assertFalse(checkpoint_management.checkpoint_exists(s3)) def testNonSharded(self): save_dir = self._get_test_dir("max_to_keep_non_sharded") @@ -1145,7 +1217,7 @@ class MaxToKeepTest(test.TestCase): s1 = save.save(sess, os.path.join(save_dir, "s1")) self.assertEqual([s1], save.last_checkpoints) - self.assertTrue(saver_module.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) self.assertCheckpointState( model_checkpoint_path=s1, all_model_checkpoint_paths=[s1], @@ -1153,8 +1225,8 @@ class MaxToKeepTest(test.TestCase): s2 = save.save(sess, os.path.join(save_dir, "s2")) self.assertEqual([s1, s2], save.last_checkpoints) - self.assertTrue(saver_module.checkpoint_exists(s1)) - self.assertTrue(saver_module.checkpoint_exists(s2)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertCheckpointState( model_checkpoint_path=s2, all_model_checkpoint_paths=[s1, s2], @@ -1162,9 +1234,9 @@ class MaxToKeepTest(test.TestCase): s3 = save.save(sess, os.path.join(save_dir, "s3")) self.assertEqual([s2, s3], save.last_checkpoints) - self.assertFalse(saver_module.checkpoint_exists(s1)) - self.assertTrue(saver_module.checkpoint_exists(s2)) - self.assertTrue(saver_module.checkpoint_exists(s3)) + self.assertFalse(checkpoint_management.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) + self.assertTrue(checkpoint_management.checkpoint_exists(s3)) self.assertCheckpointState( model_checkpoint_path=s3, all_model_checkpoint_paths=[s2, s3], @@ -1183,15 +1255,18 @@ class MaxToKeepTest(test.TestCase): # Adding s2 again (old s2 is removed first, then new s2 appended) s2 = save.save(sess, os.path.join(save_dir, "s2")) self.assertEqual([s3, s2], save.last_checkpoints) - self.assertFalse(saver_module.checkpoint_exists(s1)) + self.assertFalse(checkpoint_management.checkpoint_exists(s1)) self.assertFalse( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) - self.assertTrue(saver_module.checkpoint_exists(s3)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s1))) + self.assertTrue(checkpoint_management.checkpoint_exists(s3)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) - self.assertTrue(saver_module.checkpoint_exists(s2)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s3))) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s2))) self.assertCheckpointState( model_checkpoint_path=s2, all_model_checkpoint_paths=[s3, s2], @@ -1200,15 +1275,18 @@ class MaxToKeepTest(test.TestCase): # Adding s1 (s3 should now be deleted as oldest in list) s1 = save.save(sess, os.path.join(save_dir, "s1")) self.assertEqual([s2, s1], save.last_checkpoints) - self.assertFalse(saver_module.checkpoint_exists(s3)) + self.assertFalse(checkpoint_management.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) - self.assertTrue(saver_module.checkpoint_exists(s2)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s3))) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) - self.assertTrue(saver_module.checkpoint_exists(s1)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s2))) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s1))) self.assertCheckpointState( model_checkpoint_path=s1, all_model_checkpoint_paths=[s2, s1], @@ -1220,16 +1298,19 @@ class MaxToKeepTest(test.TestCase): s2 = save2.save(sess, os.path.join(save_dir, "s2")) self.assertEqual([s3, s2], save2.last_checkpoints) # Created by the first helper. - self.assertTrue(saver_module.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s1))) # Deleted by the first helper. - self.assertFalse(saver_module.checkpoint_exists(s3)) + self.assertFalse(checkpoint_management.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) - self.assertTrue(saver_module.checkpoint_exists(s2)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s3))) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s2))) self.assertCheckpointState( model_checkpoint_path=s2, all_model_checkpoint_paths=[s3, s2], @@ -1238,15 +1319,18 @@ class MaxToKeepTest(test.TestCase): # Adding s1 (s3 should now be deleted as oldest in list) s1 = save2.save(sess, os.path.join(save_dir, "s1")) self.assertEqual([s2, s1], save2.last_checkpoints) - self.assertFalse(saver_module.checkpoint_exists(s3)) + self.assertFalse(checkpoint_management.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) - self.assertTrue(saver_module.checkpoint_exists(s2)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s3))) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) - self.assertTrue(saver_module.checkpoint_exists(s1)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s2))) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s1))) self.assertCheckpointState( model_checkpoint_path=s1, all_model_checkpoint_paths=[s2, s1], @@ -1258,16 +1342,19 @@ class MaxToKeepTest(test.TestCase): s2 = save3.save(sess, os.path.join(save_dir, "s2")) self.assertEqual([s2], save3.last_checkpoints) # Created by the first helper. - self.assertTrue(saver_module.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s1))) # Deleted by the first helper. - self.assertFalse(saver_module.checkpoint_exists(s3)) + self.assertFalse(checkpoint_management.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) - self.assertTrue(saver_module.checkpoint_exists(s2)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s3))) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s2))) # Even though the file for s1 exists, this saver isn't aware of it, which # is why it doesn't end up in the checkpoint state. self.assertCheckpointState( @@ -1278,15 +1365,18 @@ class MaxToKeepTest(test.TestCase): # Adding s1 (s3 should not be deleted because helper is unaware of it) s1 = save3.save(sess, os.path.join(save_dir, "s1")) self.assertEqual([s2, s1], save3.last_checkpoints) - self.assertFalse(saver_module.checkpoint_exists(s3)) + self.assertFalse(checkpoint_management.checkpoint_exists(s3)) self.assertFalse( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3))) - self.assertTrue(saver_module.checkpoint_exists(s2)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s3))) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2))) - self.assertTrue(saver_module.checkpoint_exists(s1)) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s2))) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) self.assertTrue( - saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1))) + checkpoint_management.checkpoint_exists( + checkpoint_management.meta_graph_filename(s1))) self.assertCheckpointState( model_checkpoint_path=s1, all_model_checkpoint_paths=[s2, s1], @@ -1317,7 +1407,8 @@ class MaxToKeepTest(test.TestCase): else: self.assertEqual(4, len(gfile.Glob(s1 + "*"))) - self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1))) + self.assertTrue( + gfile.Exists(checkpoint_management.meta_graph_filename(s1))) s2 = save.save(sess, os.path.join(save_dir, "s2")) self.assertEqual([s1, s2], save.last_checkpoints) @@ -1325,27 +1416,32 @@ class MaxToKeepTest(test.TestCase): self.assertEqual(2, len(gfile.Glob(s1))) else: self.assertEqual(4, len(gfile.Glob(s1 + "*"))) - self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1))) + self.assertTrue( + gfile.Exists(checkpoint_management.meta_graph_filename(s1))) if save._write_version is saver_pb2.SaverDef.V1: self.assertEqual(2, len(gfile.Glob(s2))) else: self.assertEqual(4, len(gfile.Glob(s2 + "*"))) - self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2))) + self.assertTrue( + gfile.Exists(checkpoint_management.meta_graph_filename(s2))) s3 = save.save(sess, os.path.join(save_dir, "s3")) self.assertEqual([s2, s3], save.last_checkpoints) self.assertEqual(0, len(gfile.Glob(s1 + "*"))) - self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1))) + self.assertFalse( + gfile.Exists(checkpoint_management.meta_graph_filename(s1))) if save._write_version is saver_pb2.SaverDef.V1: self.assertEqual(2, len(gfile.Glob(s2))) else: self.assertEqual(4, len(gfile.Glob(s2 + "*"))) - self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2))) + self.assertTrue( + gfile.Exists(checkpoint_management.meta_graph_filename(s2))) if save._write_version is saver_pb2.SaverDef.V1: self.assertEqual(2, len(gfile.Glob(s3))) else: self.assertEqual(4, len(gfile.Glob(s3 + "*"))) - self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s3))) + self.assertTrue( + gfile.Exists(checkpoint_management.meta_graph_filename(s3))) def testNoMaxToKeep(self): save_dir = self._get_test_dir("no_max_to_keep") @@ -1360,20 +1456,20 @@ class MaxToKeepTest(test.TestCase): self.assertEqual([], save.last_checkpoints) s1 = save.save(sess, os.path.join(save_dir, "s1")) self.assertEqual([], save.last_checkpoints) - self.assertTrue(saver_module.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) s2 = save.save(sess, os.path.join(save_dir, "s2")) self.assertEqual([], save.last_checkpoints) - self.assertTrue(saver_module.checkpoint_exists(s2)) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) # Test max_to_keep being 0. save2 = saver_module.Saver({"v": v}, max_to_keep=0) self.assertEqual([], save2.last_checkpoints) s1 = save2.save(sess, os.path.join(save_dir2, "s1")) self.assertEqual([], save2.last_checkpoints) - self.assertTrue(saver_module.checkpoint_exists(s1)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) s2 = save2.save(sess, os.path.join(save_dir2, "s2")) self.assertEqual([], save2.last_checkpoints) - self.assertTrue(saver_module.checkpoint_exists(s2)) + self.assertTrue(checkpoint_management.checkpoint_exists(s2)) def testNoMetaGraph(self): save_dir = self._get_test_dir("no_meta_graph") @@ -1384,8 +1480,9 @@ class MaxToKeepTest(test.TestCase): variables.global_variables_initializer().run() s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False) - self.assertTrue(saver_module.checkpoint_exists(s1)) - self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1))) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) + self.assertFalse( + gfile.Exists(checkpoint_management.meta_graph_filename(s1))) class KeepCheckpointEveryNHoursTest(test.TestCase): @@ -1395,7 +1492,7 @@ class KeepCheckpointEveryNHoursTest(test.TestCase): gfile.MakeDirs(test_dir) return test_dir - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes @test.mock.patch.object(saver_module, "time") def testNonSharded(self, mock_time): save_dir = self._get_test_dir("keep_checkpoint_every_n_hours") @@ -1441,10 +1538,10 @@ class KeepCheckpointEveryNHoursTest(test.TestCase): self.assertEqual([s3, s4], save.last_checkpoints) # Check that s1 is still here, but s2 is gone. - self.assertTrue(saver_module.checkpoint_exists(s1)) - self.assertFalse(saver_module.checkpoint_exists(s2)) - self.assertTrue(saver_module.checkpoint_exists(s3)) - self.assertTrue(saver_module.checkpoint_exists(s4)) + self.assertTrue(checkpoint_management.checkpoint_exists(s1)) + self.assertFalse(checkpoint_management.checkpoint_exists(s2)) + self.assertTrue(checkpoint_management.checkpoint_exists(s3)) + self.assertTrue(checkpoint_management.checkpoint_exists(s4)) class SaveRestoreWithVariableNameMap(test.TestCase): @@ -1515,7 +1612,7 @@ class SaveRestoreWithVariableNameMap(test.TestCase): self.assertEqual(10.0, self.evaluate(v0)) self.assertEqual(20.0, self.evaluate(v1)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNonReshapeResourceVariable(self): self._testNonReshape(resource_variable_ops.ResourceVariable) @@ -1523,221 +1620,6 @@ class SaveRestoreWithVariableNameMap(test.TestCase): self._testNonReshape(variables.Variable) -class LatestCheckpointWithRelativePaths(test.TestCase): - - @staticmethod - @contextlib.contextmanager - def tempWorkingDir(temppath): - cwd = os.getcwd() - os.chdir(temppath) - try: - yield - finally: - os.chdir(cwd) - - @staticmethod - @contextlib.contextmanager - def tempDir(): - tempdir = tempfile.mkdtemp() - try: - yield tempdir - finally: - shutil.rmtree(tempdir) - - def testNameCollision(self): - # Make sure we have a clean directory to work in. - with self.tempDir() as tempdir: - # Jump to that directory until this test is done. - with self.tempWorkingDir(tempdir): - # Save training snapshots to a relative path. - traindir = "train/" - os.mkdir(traindir) - # Collides with the default name of the checkpoint state file. - filepath = os.path.join(traindir, "checkpoint") - - with self.test_session() as sess: - unused_a = variables.Variable(0.0) # So that Saver saves something. - variables.global_variables_initializer().run() - - # Should fail. - saver = saver_module.Saver(sharded=False) - with self.assertRaisesRegexp(ValueError, "collides with"): - saver.save(sess, filepath) - - # Succeeds: the file will be named "checkpoint-". - saver.save(sess, filepath, global_step=1) - self.assertIsNotNone(saver_module.latest_checkpoint(traindir)) - - # Succeeds: the file will be named "checkpoint--of-". - saver = saver_module.Saver(sharded=True) - saver.save(sess, filepath) - self.assertIsNotNone(saver_module.latest_checkpoint(traindir)) - - # Succeeds: the file will be named "checkpoint---of-". - saver = saver_module.Saver(sharded=True) - saver.save(sess, filepath, global_step=1) - self.assertIsNotNone(saver_module.latest_checkpoint(traindir)) - - def testRelativePath(self): - # Make sure we have a clean directory to work in. - with self.tempDir() as tempdir: - - # Jump to that directory until this test is done. - with self.tempWorkingDir(tempdir): - - # Save training snapshots to a relative path. - traindir = "train/" - os.mkdir(traindir) - - filename = "snapshot" - filepath = os.path.join(traindir, filename) - - with self.test_session() as sess: - # Build a simple graph. - v0 = variables.Variable(0.0) - inc = v0.assign_add(1.0) - - save = saver_module.Saver({"v0": v0}) - - # Record a short training history. - variables.global_variables_initializer().run() - save.save(sess, filepath, global_step=0) - inc.eval() - save.save(sess, filepath, global_step=1) - inc.eval() - save.save(sess, filepath, global_step=2) - - with self.test_session() as sess: - # Build a new graph with different initialization. - v0 = variables.Variable(-1.0) - - # Create a new saver. - save = saver_module.Saver({"v0": v0}) - variables.global_variables_initializer().run() - - # Get the most recent checkpoint name from the training history file. - name = saver_module.latest_checkpoint(traindir) - self.assertIsNotNone(name) - - # Restore "v0" from that checkpoint. - save.restore(sess, name) - self.assertEqual(v0.eval(), 2.0) - - -class CheckpointStateTest(test.TestCase): - - def _get_test_dir(self, dirname): - test_dir = os.path.join(self.get_temp_dir(), dirname) - gfile.MakeDirs(test_dir) - return test_dir - - def testAbsPath(self): - save_dir = self._get_test_dir("abs_paths") - abs_path = os.path.join(save_dir, "model-0") - ckpt = saver_module.generate_checkpoint_state_proto(save_dir, abs_path) - self.assertEqual(ckpt.model_checkpoint_path, abs_path) - self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path)) - self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1) - self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path) - - def testRelPath(self): - train_dir = "train" - model = os.path.join(train_dir, "model-0") - # model_checkpoint_path should have no "train" directory part. - new_rel_path = "model-0" - ckpt = saver_module.generate_checkpoint_state_proto(train_dir, model) - self.assertEqual(ckpt.model_checkpoint_path, new_rel_path) - self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1) - self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path) - - def testAllModelCheckpointPaths(self): - save_dir = self._get_test_dir("all_models_test") - abs_path = os.path.join(save_dir, "model-0") - for paths in [None, [], ["model-2"]]: - ckpt = saver_module.generate_checkpoint_state_proto( - save_dir, abs_path, all_model_checkpoint_paths=paths) - self.assertEqual(ckpt.model_checkpoint_path, abs_path) - self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path)) - self.assertEqual( - len(ckpt.all_model_checkpoint_paths), len(paths) if paths else 1) - self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path) - - def testUpdateCheckpointState(self): - save_dir = self._get_test_dir("update_checkpoint_state") - os.chdir(save_dir) - # Make a temporary train directory. - train_dir = "train" - os.mkdir(train_dir) - abs_path = os.path.join(save_dir, "model-0") - rel_path = os.path.join("train", "model-2") - saver_module.update_checkpoint_state( - train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path]) - ckpt = saver_module.get_checkpoint_state(train_dir) - self.assertEqual(ckpt.model_checkpoint_path, rel_path) - self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2) - self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path) - self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path) - - def testUpdateCheckpointStateSaveRelativePaths(self): - save_dir = self._get_test_dir("update_checkpoint_state") - os.chdir(save_dir) - abs_path2 = os.path.join(save_dir, "model-2") - rel_path2 = "model-2" - abs_path0 = os.path.join(save_dir, "model-0") - rel_path0 = "model-0" - saver_module._update_checkpoint_state( # pylint: disable=protected-access - save_dir=save_dir, - model_checkpoint_path=abs_path2, - all_model_checkpoint_paths=[rel_path0, abs_path2], - save_relative_paths=True) - - # File should contain relative paths. - file_content = file_io.read_file_to_string( - os.path.join(save_dir, "checkpoint")) - ckpt = CheckpointState() - text_format.Merge(file_content, ckpt) - self.assertEqual(ckpt.model_checkpoint_path, rel_path2) - self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2) - self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2) - self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0) - - # get_checkpoint_state should return absolute paths. - ckpt = saver_module.get_checkpoint_state(save_dir) - self.assertEqual(ckpt.model_checkpoint_path, abs_path2) - self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2) - self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2) - self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0) - - def testCheckPointStateFailsWhenIncomplete(self): - save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete") - os.chdir(save_dir) - ckpt_path = os.path.join(save_dir, "checkpoint") - ckpt_file = open(ckpt_path, "w") - ckpt_file.write("") - ckpt_file.close() - with self.assertRaises(ValueError): - saver_module.get_checkpoint_state(save_dir) - - def testCheckPointCompletesRelativePaths(self): - save_dir = self._get_test_dir("checkpoint_completes_relative_paths") - os.chdir(save_dir) - ckpt_path = os.path.join(save_dir, "checkpoint") - ckpt_file = open(ckpt_path, "w") - ckpt_file.write(""" - model_checkpoint_path: "./model.ckpt-687529" - all_model_checkpoint_paths: "./model.ckpt-687500" - all_model_checkpoint_paths: "./model.ckpt-687529" - """) - ckpt_file.close() - ckpt = saver_module.get_checkpoint_state(save_dir) - self.assertEqual(ckpt.model_checkpoint_path, - os.path.join(save_dir, "./model.ckpt-687529")) - self.assertEqual(ckpt.all_model_checkpoint_paths[0], - os.path.join(save_dir, "./model.ckpt-687500")) - self.assertEqual(ckpt.all_model_checkpoint_paths[1], - os.path.join(save_dir, "./model.ckpt-687529")) - - class MetaGraphTest(test.TestCase): def _get_test_dir(self, dirname): @@ -2580,62 +2462,6 @@ class WriteGraphTest(test.TestCase): self.assertTrue(os.path.exists(path)) -class SaverUtilsTest(test.TestCase): - - def setUp(self): - self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test") - gfile.MakeDirs(self._base_dir) - - def tearDown(self): - gfile.DeleteRecursively(self._base_dir) - - def testCheckpointExists(self): - for sharded in (False, True): - for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): - with self.test_session(graph=ops_lib.Graph()) as sess: - unused_v = variables.Variable(1.0, name="v") - variables.global_variables_initializer().run() - saver = saver_module.Saver(sharded=sharded, write_version=version) - - path = os.path.join(self._base_dir, "%s-%s" % (sharded, version)) - self.assertFalse( - saver_module.checkpoint_exists(path)) # Not saved yet. - - ckpt_prefix = saver.save(sess, path) - self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix)) - - ckpt_prefix = saver_module.latest_checkpoint(self._base_dir) - self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix)) - - def testGetCheckpointMtimes(self): - prefixes = [] - for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): - with self.test_session(graph=ops_lib.Graph()) as sess: - unused_v = variables.Variable(1.0, name="v") - variables.global_variables_initializer().run() - saver = saver_module.Saver(write_version=version) - prefixes.append( - saver.save(sess, os.path.join(self._base_dir, str(version)))) - - mtimes = saver_module.get_checkpoint_mtimes(prefixes) - self.assertEqual(2, len(mtimes)) - self.assertTrue(mtimes[1] >= mtimes[0]) - - def testRemoveCheckpoint(self): - for sharded in (False, True): - for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): - with self.test_session(graph=ops_lib.Graph()) as sess: - unused_v = variables.Variable(1.0, name="v") - variables.global_variables_initializer().run() - saver = saver_module.Saver(sharded=sharded, write_version=version) - - path = os.path.join(self._base_dir, "%s-%s" % (sharded, version)) - ckpt_prefix = saver.save(sess, path) - self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix)) - saver_module.remove_checkpoint(ckpt_prefix, version) - self.assertFalse(saver_module.checkpoint_exists(ckpt_prefix)) - - class ScopedGraphTest(test.TestCase): def _get_test_dir(self, dirname): @@ -2939,7 +2765,7 @@ class ScopedGraphTest(test.TestCase): self.assertEqual(2.0, var_dict2["variable2:0"].eval()) -class _OwnsAVariableSimple(checkpointable.CheckpointableBase): +class _OwnsAVariableSimple(checkpointable_base.CheckpointableBase): """A Checkpointable object which can be saved using a tf.train.Saver.""" def __init__(self): @@ -2947,7 +2773,7 @@ class _OwnsAVariableSimple(checkpointable.CheckpointableBase): name="non_dep_variable", initializer=6., use_resource=True) def _gather_saveables_for_checkpoint(self): - return {checkpointable.VARIABLE_VALUE_KEY: self.non_dep_variable} + return {checkpointable_base.VARIABLE_VALUE_KEY: self.non_dep_variable} # The Saver sorts by name before parsing, so we need a name property. @property @@ -2972,7 +2798,7 @@ class _MirroringSaveable( self._mirrored_variable.assign(tensor)) -class _OwnsMirroredVariables(checkpointable.CheckpointableBase): +class _OwnsMirroredVariables(checkpointable_base.CheckpointableBase): """A Checkpointable object which returns a more complex SaveableObject.""" def __init__(self): @@ -2987,7 +2813,7 @@ class _OwnsMirroredVariables(checkpointable.CheckpointableBase): primary_variable=self.non_dep_variable, mirrored_variable=self.mirrored, name=name) - return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + return {checkpointable_base.VARIABLE_VALUE_KEY: _saveable_factory} # The Saver sorts by name before parsing, so we need a name property. @property @@ -2995,7 +2821,7 @@ class _OwnsMirroredVariables(checkpointable.CheckpointableBase): return self.non_dep_variable.name -class NonLayerCheckpointable(checkpointable.Checkpointable): +class NonLayerCheckpointable(checkpointable_tracking.Checkpointable): def __init__(self): super(NonLayerCheckpointable, self).__init__() @@ -3021,7 +2847,7 @@ class MyModel(training.Model): class CheckpointableCompatibilityTests(test.TestCase): # TODO(allenl): Track down python3 reference cycles in these tests. - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testNotSaveableButIsCheckpointable(self): v = _OwnsAVariableSimple() saver = saver_module.Saver(var_list=[v]) @@ -3034,7 +2860,7 @@ class CheckpointableCompatibilityTests(test.TestCase): saver.restore(sess, save_path) self.assertEqual(42., self.evaluate(v.non_dep_variable)) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def testMoreComplexSaveableReturned(self): v = _OwnsMirroredVariables() saver = saver_module.Saver(var_list=[v]) @@ -3138,27 +2964,33 @@ class CheckpointableCompatibilityTests(test.TestCase): errors.NotFoundError, "Key b not found in checkpoint"): b_saver.restore(sess=sess, save_path=save_path) - def testCheckpointNotFoundErrorRaised(self): - # Restore does some tricky exception handling to figure out if it should - # load an object-based checkpoint. Tests that the exception handling isn't - # too broad. - a = resource_variable_ops.ResourceVariable(1., name="a") - saver = saver_module.Saver([a]) - with self.test_session() as sess: - with self.assertRaisesRegexp( - errors.NotFoundError, - "Failed to find any matching files for path_which_does_not_exist"): - saver.restore(sess=sess, save_path="path_which_does_not_exist") - try: - saver.restore(sess=sess, save_path="path_which_does_not_exist") - except errors.NotFoundError: - # Make sure we don't have a confusing "During handling of the above - # exception" block in Python 3. - # pylint: disable=no-value-for-parameter - exception_string = "\n".join( - traceback.format_exception(*sys.exc_info())) - # pylint: enable=no-value-for-parameter - self.assertNotIn("NewCheckpointReader", exception_string) + with self.assertRaises(errors.NotFoundError) as cs: + b_saver.restore(sess=sess, save_path=save_path) + + # Make sure we don't have a confusing "During handling of the above + # exception" block in Python 3. + self.assertNotIn("NewCheckpointReader", cs.exception.message) + + def testGraphChangedForRestoreErrorRaised(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + + with ops_lib.Graph().as_default() as g: + a = variables.Variable(1., name="a") + a_saver = saver_module.Saver([a]) + + with self.test_session(graph=g) as sess: + sess.run(a.initializer) + save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix) + + with ops_lib.Graph().as_default() as g: + a = variables.Variable([1.], name="a") + a_saver = saver_module.Saver([a]) + with self.test_session(graph=g) as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "a mismatch between the current graph and the graph"): + a_saver.restore(sess=sess, save_path=save_path) def testLoadFromObjectBasedGraph(self): checkpoint_directory = self.get_temp_dir() diff --git a/tensorflow/python/training/server_lib.py b/tensorflow/python/training/server_lib.py index 2f421d1cc0a0190670082fabf4e25470c6a1723b..46543413e40a5a212b180b0cdeb2280148d606c5 100644 --- a/tensorflow/python/training/server_lib.py +++ b/tensorflow/python/training/server_lib.py @@ -42,8 +42,8 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol, Defaults to the value in `server_or_cluster_def`, if specified. Otherwise defaults to 0 if the server's job has only one task. protocol: (Optional.) Specifies the protocol to be used by the server. - Acceptable values include `"grpc"`. Defaults to the value in - `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`. + Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value + in `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`. config: (Options.) A `tf.ConfigProto` that specifies default configuration options for all sessions that run on this server. @@ -98,9 +98,9 @@ class Server(object): """An in-process TensorFlow server, for use in distributed training. A `tf.train.Server` instance encapsulates a set of devices and a - @{tf.Session} target that + `tf.Session` target that can participate in distributed training. A server belongs to a - cluster (specified by a @{tf.train.ClusterSpec}), and + cluster (specified by a `tf.train.ClusterSpec`), and corresponds to a particular task in a named job. The server can communicate with any other server in the same cluster. """ @@ -129,8 +129,9 @@ class Server(object): job. Defaults to the value in `server_or_cluster_def`, if specified. Otherwise defaults to 0 if the server's job has only one task. protocol: (Optional.) Specifies the protocol to be used by the server. - Acceptable values include `"grpc"`. Defaults to the value in - `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`. + Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the + value in `server_or_cluster_def`, if specified. Otherwise defaults to + `"grpc"`. config: (Options.) A `tf.ConfigProto` that specifies default configuration options for all sessions that run on this server. start: (Optional.) Boolean, indicating whether to start the server @@ -185,7 +186,7 @@ class Server(object): """Returns the target for a `tf.Session` to connect to this server. To create a - @{tf.Session} that + `tf.Session` that connects to this server, use the following snippet: ```python @@ -229,7 +230,7 @@ class ClusterSpec(object): A `tf.train.ClusterSpec` represents the set of processes that participate in a distributed TensorFlow computation. Every - @{tf.train.Server} is constructed in a particular cluster. + `tf.train.Server` is constructed in a particular cluster. To create a cluster with two jobs and five tasks, you specify the mapping from job names to lists of network addresses (typically @@ -420,7 +421,7 @@ class ClusterSpec(object): NOTE: For backwards compatibility, this method returns a list. If the given job was defined with a sparse set of task indices, the length of this list may not reflect the number of tasks defined in - this job. Use the @{tf.train.ClusterSpec.num_tasks} method + this job. Use the `tf.train.ClusterSpec.num_tasks` method to find the number of tasks defined in a particular job. Args: diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py index 974f75777f43ab4ef3be2edea564d1ad902e4fd5..a2e0645ba894cf1d5b4acea0aadc2abbd77eb29e 100644 --- a/tensorflow/python/training/session_manager.py +++ b/tensorflow/python/training/session_manager.py @@ -24,7 +24,7 @@ from tensorflow.python.client import session from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import saver as saver_mod +from tensorflow.python.training import checkpoint_management from tensorflow.python.util.tf_export import tf_export @@ -197,13 +197,13 @@ class SessionManager(object): # Waits up until max_wait_secs for checkpoint to become available. wait_time = 0 - ckpt = saver_mod.get_checkpoint_state(checkpoint_dir) + ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir) while not ckpt or not ckpt.model_checkpoint_path: if wait_for_checkpoint and wait_time < max_wait_secs: logging.info("Waiting for checkpoint to be available.") time.sleep(self._recovery_wait_secs) wait_time += self._recovery_wait_secs - ckpt = saver_mod.get_checkpoint_state(checkpoint_dir) + ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir) else: return sess, False diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py index 6670d9365f2994a70b7228170179f97d314041c9..d7e6dac95b1afe35c4dd93848d3b1cda872266e8 100644 --- a/tensorflow/python/training/session_manager_test.py +++ b/tensorflow/python/training/session_manager_test.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import server_lib from tensorflow.python.training import session_manager @@ -174,13 +175,13 @@ class SessionManagerTest(test.TestCase): os.path.join(checkpoint_dir, "recover_session_checkpoint")) self._test_recovered_variable(checkpoint_dir=checkpoint_dir) self._test_recovered_variable( - checkpoint_filename_with_path=saver_lib.latest_checkpoint( + checkpoint_filename_with_path=checkpoint_management.latest_checkpoint( checkpoint_dir)) # Cannot set both checkpoint_dir and checkpoint_filename_with_path. with self.assertRaises(ValueError): self._test_recovered_variable( checkpoint_dir=checkpoint_dir, - checkpoint_filename_with_path=saver_lib.latest_checkpoint( + checkpoint_filename_with_path=checkpoint_management.latest_checkpoint( checkpoint_dir)) def testWaitForSessionReturnsNoneAfterTimeout(self): diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py index 258a6f045d7c1b491ce00bdf8dd0ae6ad500ba68..d76b22acd86956e9b7bbd768299e3db7f630a4d5 100644 --- a/tensorflow/python/training/slot_creator.py +++ b/tensorflow/python/training/slot_creator.py @@ -45,7 +45,7 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import distribution_strategy_context def _create_slot_var(primary, val, scope, validate_shape, shape, dtype): @@ -112,7 +112,8 @@ def create_slot(primary, val, name, colocate_with_primary=True): prefix = primary.op.name with variable_scope.variable_scope(None, prefix + "/" + name): if colocate_with_primary: - distribution_strategy = distribute_lib.get_distribution_strategy() + distribution_strategy = ( + distribution_strategy_context.get_distribution_strategy()) with distribution_strategy.colocate_vars_with(primary): return _create_slot_var(primary, val, "", validate_shape, None, None) else: @@ -149,7 +150,8 @@ def create_slot_with_initializer(primary, initializer, shape, dtype, name, prefix = primary.op.name with variable_scope.variable_scope(None, prefix + "/" + name): if colocate_with_primary: - distribution_strategy = distribute_lib.get_distribution_strategy() + distribution_strategy = ( + distribution_strategy_context.get_distribution_strategy()) with distribution_strategy.colocate_vars_with(primary): return _create_slot_var(primary, initializer, "", validate_shape, shape, dtype) diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py index 372ea415df0ee299ebb51b2369c1027eb2db4865..0755364bbe291d951c3641c44aa2e9995e1efbfb 100644 --- a/tensorflow/python/training/supervisor.py +++ b/tensorflow/python/training/supervisor.py @@ -45,7 +45,7 @@ class Supervisor(object): """A training helper that checkpoints models and computes summaries. This class is deprecated. Please use - @{tf.train.MonitoredTrainingSession} instead. + `tf.train.MonitoredTrainingSession` instead. The Supervisor is a small wrapper around a `Coordinator`, a `Saver`, and a `SessionManager` that takes care of common needs of TensorFlow @@ -134,7 +134,7 @@ class Supervisor(object): * Specifying `'local'` requests a session that uses the RPC-based "Master interface" to run TensorFlow programs. See - @{tf.train.Server.create_local_server} for + `tf.train.Server.create_local_server` for details. * Specifying `'grpc://hostname:port'` requests a session that uses diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py index 4abce85852c4a3a4b319aea919df57e5cb67b9e3..71ed88093aa72c951366a1c7ccd2078380fa4ee3 100644 --- a/tensorflow/python/training/supervisor_test.py +++ b/tensorflow/python/training/supervisor_test.py @@ -44,6 +44,7 @@ from tensorflow.python.platform import test from tensorflow.python.summary import summary from tensorflow.python.summary import summary_iterator from tensorflow.python.summary.writer import writer +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import input as input_lib from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import server_lib @@ -83,7 +84,7 @@ class SupervisorTest(test.TestCase): end_time = time.time() + timeout_secs while time.time() < end_time: if for_checkpoint: - if saver_lib.checkpoint_exists(pattern): + if checkpoint_management.checkpoint_exists(pattern): return else: if len(gfile.Glob(pattern)) >= 1: diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py index 3f2dc6797623b4973543b674c3069a3110c59465..6f6305a50576b9e4ef8d14e76e60d08e025939c3 100644 --- a/tensorflow/python/training/training.py +++ b/tensorflow/python/training/training.py @@ -53,6 +53,7 @@ from tensorflow.python.training import input as _input from tensorflow.python.training.input import * # pylint: disable=redefined-builtin # pylint: enable=wildcard-import +from tensorflow.python.training.basic_session_run_hooks import get_or_create_steps_per_run_variable from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer from tensorflow.python.training.basic_session_run_hooks import LoggingTensorHook from tensorflow.python.training.basic_session_run_hooks import StopAtStepHook @@ -82,12 +83,12 @@ from tensorflow.python.training.monitored_session import WorkerSessionCreator from tensorflow.python.training.monitored_session import MonitoredSession from tensorflow.python.training.monitored_session import SingularMonitoredSession from tensorflow.python.training.saver import Saver -from tensorflow.python.training.saver import checkpoint_exists -from tensorflow.python.training.saver import generate_checkpoint_state_proto -from tensorflow.python.training.saver import get_checkpoint_mtimes -from tensorflow.python.training.saver import get_checkpoint_state -from tensorflow.python.training.saver import latest_checkpoint -from tensorflow.python.training.saver import update_checkpoint_state +from tensorflow.python.training.checkpoint_management import checkpoint_exists +from tensorflow.python.training.checkpoint_management import generate_checkpoint_state_proto +from tensorflow.python.training.checkpoint_management import get_checkpoint_mtimes +from tensorflow.python.training.checkpoint_management import get_checkpoint_state +from tensorflow.python.training.checkpoint_management import latest_checkpoint +from tensorflow.python.training.checkpoint_management import update_checkpoint_state from tensorflow.python.training.saver import export_meta_graph from tensorflow.python.training.saver import import_meta_graph from tensorflow.python.training.session_run_hook import SessionRunHook diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py index 0877b2a8a2fc7d59c4075c7d37c52ab691ec0361..2ff3eeb1535eb5cdb143cdc1307c2e4d7b00da8b 100644 --- a/tensorflow/python/training/training_util.py +++ b/tensorflow/python/training/training_util.py @@ -44,11 +44,13 @@ def global_step(sess, global_step_tensor): """Small helper to get the global step. ```python - # Creates a variable to hold the global_step. + # Create a variable to hold the global_step. global_step_tensor = tf.Variable(10, trainable=False, name='global_step') - # Creates a session. + # Create a session. sess = tf.Session() - # Initializes the variable. + # Initialize the variable + sess.run(global_step_tensor.initializer) + # Get the variable value. print('global_step: %s' % tf.train.global_step(sess, global_step_tensor)) global_step: 10 diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py index ec740abdd15ae2904f79246429deaa5fc831dad5..0ba7ba983d131adf318ea4aba6e2479679226120 100644 --- a/tensorflow/python/training/warm_starting_util.py +++ b/tensorflow/python/training/warm_starting_util.py @@ -22,7 +22,6 @@ import collections import six from tensorflow.python.framework import ops -from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib @@ -45,7 +44,7 @@ class VocabInfo( ])): """Vocabulary information for warm-starting. - See @{tf.estimator.WarmStartSettings$WarmStartSettings} for examples of using + See `tf.estimator.WarmStartSettings` for examples of using VocabInfo to warm-start. Attributes: @@ -83,11 +82,6 @@ class VocabInfo( ) -def _is_variable(x): - return (isinstance(x, variables_lib.Variable) or - isinstance(x, resource_variable_ops.ResourceVariable)) - - def _infer_var_name(var): """Returns name of the `var`. @@ -126,9 +120,10 @@ def _warm_start_var(var, prev_ckpt, prev_tensor_name=None): prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If None, we lookup tensor with same name as given `var`. """ - if _is_variable(var): + if checkpoint_utils._is_variable(var): # pylint: disable=protected-access current_var_name = _infer_var_name([var]) - elif isinstance(var, list) and all(_is_variable(v) for v in var): + elif (isinstance(var, list) and + all(checkpoint_utils._is_variable(v) for v in var)): # pylint: disable=protected-access current_var_name = _infer_var_name(var) elif isinstance(var, variables_lib.PartitionedVariable): current_var_name = _infer_var_name([var]) @@ -193,9 +188,10 @@ def _warm_start_var_with_vocab(var, prev_vocab_path): raise ValueError("Invalid args: Must provide all of [current_vocab_path, " "current_vocab_size, prev_ckpt, prev_vocab_path}.") - if _is_variable(var): + if checkpoint_utils._is_variable(var): var = [var] - elif isinstance(var, list) and all(_is_variable(v) for v in var): + elif (isinstance(var, list) and + all(checkpoint_utils._is_variable(v) for v in var)): var = var elif isinstance(var, variables_lib.PartitionedVariable): var = var._get_variable_list() @@ -271,7 +267,7 @@ def _get_grouped_variables(vars_to_warm_start): for v in vars_to_warm_start: list_of_vars += ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope=v) - elif all([_is_variable(v) for v in vars_to_warm_start]): + elif all([checkpoint_utils._is_variable(v) for v in vars_to_warm_start]): # pylint: disable=protected-access list_of_vars = vars_to_warm_start else: raise ValueError("If `vars_to_warm_start` is a list, it must be all " diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py index 376be39978fb11463ae8a870492a359c89a9f2ce..c43589f5c4555180442a1962e25f82e51d677d1b 100644 --- a/tensorflow/python/util/deprecation.py +++ b/tensorflow/python/util/deprecation.py @@ -37,6 +37,11 @@ _PRINT_DEPRECATION_WARNINGS = True _PRINTED_WARNING = {} +class DeprecatedNamesAlreadySet(Exception): + """Raised when setting deprecated names multiple times for the same symbol.""" + pass + + def _add_deprecated_function_notice_to_docstring(doc, date, instructions): """Adds a deprecation notice to a docstring for deprecated functions.""" main_text = ['THIS FUNCTION IS DEPRECATED. It will be removed %s.' % @@ -87,6 +92,27 @@ def _call_location(outer=False): return '%s:%d' % (entry[1], entry[2]) +def _wrap_decorator(wrapped_function): + """Indicate that one function wraps another. + + This decorator wraps a function using `tf_decorator.make_decorator` + so that doc generation scripts can pick up original function + signature. + It would be better to use @functools.wrap decorator, but it would + not update function signature to match wrapped function in Python 2. + + Args: + wrapped_function: The function that decorated function wraps. + + Returns: + Function that accepts wrapper function as an argument and returns + `TFDecorator` instance. + """ + def wrapper(wrapper_func): + return tf_decorator.make_decorator(wrapped_function, wrapper_func) + return wrapper + + def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True): """Deprecate a symbol in favor of a new name with identical semantics. @@ -144,7 +170,7 @@ def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True): if tf_inspect.isclass(func_or_class): # Make a new class with __init__ wrapped in a warning. - class NewClass(func_or_class): # pylint: disable=missing-docstring + class _NewClass(func_or_class): # pylint: disable=missing-docstring __doc__ = decorator_utils.add_notice_to_docstring( func_or_class.__doc__, 'Please use %s instead.' % name, 'DEPRECATED CLASS', @@ -153,27 +179,28 @@ def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True): __name__ = func_or_class.__name__ __module__ = _call_location(outer=True) + @_wrap_decorator(func_or_class.__init__) def __init__(self, *args, **kwargs): - if hasattr(NewClass.__init__, '__func__'): + if hasattr(_NewClass.__init__, '__func__'): # Python 2 - NewClass.__init__.__func__.__doc__ = func_or_class.__init__.__doc__ + _NewClass.__init__.__func__.__doc__ = func_or_class.__init__.__doc__ else: # Python 3 - NewClass.__init__.__doc__ = func_or_class.__init__.__doc__ + _NewClass.__init__.__doc__ = func_or_class.__init__.__doc__ if _PRINT_DEPRECATION_WARNINGS: # We're making the alias as we speak. The original may have other # aliases, so we cannot use it to check for whether it's already been # warned about. - if NewClass.__init__ not in _PRINTED_WARNING: + if _NewClass.__init__ not in _PRINTED_WARNING: if warn_once: - _PRINTED_WARNING[NewClass.__init__] = True + _PRINTED_WARNING[_NewClass.__init__] = True logging.warning( 'From %s: The name %s is deprecated. Please use %s instead.\n', _call_location(), deprecated_name, name) - super(NewClass, self).__init__(*args, **kwargs) + super(_NewClass, self).__init__(*args, **kwargs) - return NewClass + return _NewClass else: decorator_utils.validate_callable(func_or_class, 'deprecated') @@ -197,6 +224,35 @@ def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True): func_or_class.__doc__, None, 'Please use %s instead.' % name)) +def deprecated_endpoints(*args): + """Decorator for marking endpoints deprecated. + + This decorator does not print deprecation messages. + TODO(annarev): eventually start printing deprecation warnings when + @deprecation_endpoints decorator is added. + + Args: + *args: Deprecated endpoint names. + + Returns: + A function that takes symbol as an argument and adds + _tf_deprecated_api_names to that symbol. + _tf_deprecated_api_names would be set to a list of deprecated + endpoint names for the symbol. + """ + def deprecated_wrapper(func): + # pylint: disable=protected-access + if '_tf_deprecated_api_names' in func.__dict__: + raise DeprecatedNamesAlreadySet( + 'Cannot set deprecated names for %s to %s. ' + 'Deprecated names are already set to %s.' % ( + func.__name__, str(args), str(func._tf_deprecated_api_names))) + func._tf_deprecated_api_names = args + # pylint: disable=protected-access + return func + return deprecated_wrapper + + def deprecated(date, instructions, warn_once=True): """Decorator for marking functions or methods deprecated. @@ -332,13 +388,13 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples, Args: names_to_ok_vals: dict from string arg_name to a list of values, possibly empty, which should not elicit a warning. - arg_spec: Output from tf_inspect.getargspec on the called function. + arg_spec: Output from tf_inspect.getfullargspec on the called function. Returns: Dictionary from arg_name to DeprecatedArgSpec. """ - arg_name_to_pos = dict( - (name, pos) for (pos, name) in enumerate(arg_spec.args)) + arg_name_to_pos = { + name: pos for pos, name in enumerate(arg_spec.args)} deprecated_positional_args = {} for arg_name, spec in iter(names_to_ok_vals.items()): if arg_name in arg_name_to_pos: @@ -352,16 +408,16 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples, decorator_utils.validate_callable(func, 'deprecated_args') deprecated_arg_names = _get_arg_names_to_ok_vals() - arg_spec = tf_inspect.getargspec(func) + arg_spec = tf_inspect.getfullargspec(func) deprecated_positions = _get_deprecated_positional_arguments( deprecated_arg_names, arg_spec) is_varargs_deprecated = arg_spec.varargs in deprecated_arg_names - is_kwargs_deprecated = arg_spec.keywords in deprecated_arg_names + is_kwargs_deprecated = arg_spec.varkw in deprecated_arg_names if (len(deprecated_positions) + is_varargs_deprecated + is_kwargs_deprecated != len(deprecated_arg_names_or_tuples)): - known_args = arg_spec.args + [arg_spec.varargs, arg_spec.keywords] + known_args = arg_spec.args + [arg_spec.varargs, arg_spec.varkw] missing_args = [arg_name for arg_name in deprecated_arg_names if arg_name not in known_args] raise ValueError('The following deprecated arguments are not present ' @@ -411,7 +467,7 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples, if is_varargs_deprecated and len(args) > len(arg_spec.args): invalid_args.append(arg_spec.varargs) if is_kwargs_deprecated and kwargs: - invalid_args.append(arg_spec.keywords) + invalid_args.append(arg_spec.varkw) for arg_name in deprecated_arg_names: if (arg_name in kwargs and not (deprecated_positions[arg_name].has_ok_value and diff --git a/tensorflow/python/util/deprecation_test.py b/tensorflow/python/util/deprecation_test.py index bdd0bc48d29319914e184ea4331a5e9d4a1c3328..90c73a0a58d129af44cc051874acda37d5c78394 100644 --- a/tensorflow/python/util/deprecation_test.py +++ b/tensorflow/python/util/deprecation_test.py @@ -22,6 +22,7 @@ from __future__ import print_function from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import deprecation +from tensorflow.python.util import tf_inspect class DeprecatedAliasTest(test.TestCase): @@ -73,6 +74,11 @@ class DeprecatedAliasTest(test.TestCase): self.assertEqual(["test", "deprecated", "deprecated again"], MyClass.init_args) + # Check __init__ signature matches for doc generation. + self.assertEqual( + tf_inspect.getfullargspec(MyClass.__init__), + tf_inspect.getfullargspec(deprecated_cls.__init__)) + class DeprecationTest(test.TestCase): @@ -929,5 +935,27 @@ class DeprecationArgumentsTest(test.TestCase): self.assertEqual(new_docs, new_docs_ref) +class DeprecatedEndpointsTest(test.TestCase): + + def testSingleDeprecatedEndpoint(self): + @deprecation.deprecated_endpoints("foo1") + def foo(): + pass + self.assertEqual(("foo1",), foo._tf_deprecated_api_names) + + def testMultipleDeprecatedEndpoint(self): + @deprecation.deprecated_endpoints("foo1", "foo2") + def foo(): + pass + self.assertEqual(("foo1", "foo2"), foo._tf_deprecated_api_names) + + def testCannotSetDeprecatedEndpointsTwice(self): + with self.assertRaises(deprecation.DeprecatedNamesAlreadySet): + @deprecation.deprecated_endpoints("foo1") + @deprecation.deprecated_endpoints("foo2") + def foo(): # pylint: disable=unused-variable + pass + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/util/function_utils.py b/tensorflow/python/util/function_utils.py index 7bbbde3cd288a7373c1ac845977a4d92d2a1b7c0..4e9b07e20ac7ef176316d3532958c84754628e56 100644 --- a/tensorflow/python/util/function_utils.py +++ b/tensorflow/python/util/function_utils.py @@ -20,6 +20,8 @@ from __future__ import print_function import functools +import six + from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect @@ -55,3 +57,36 @@ def fn_args(fn): if _is_bounded_method(fn): args.remove('self') return tuple(args) + + +def get_func_name(func): + """Returns name of passed callable.""" + _, func = tf_decorator.unwrap(func) + if callable(func): + if tf_inspect.isfunction(func): + return func.__name__ + elif tf_inspect.ismethod(func): + return '%s.%s' % (six.get_method_self(func).__class__.__name__, + six.get_method_function(func).__name__) + else: # Probably a class instance with __call__ + return str(type(func)) + else: + raise ValueError('Argument must be callable') + + +def get_func_code(func): + """Returns func_code of passed callable, or None if not available.""" + _, func = tf_decorator.unwrap(func) + if callable(func): + if tf_inspect.isfunction(func) or tf_inspect.ismethod(func): + return six.get_function_code(func) + # Since the object is not a function or method, but is a callable, we will + # try to access the __call__method as a function. This works with callable + # classes but fails with functool.partial objects despite their __call__ + # attribute. + try: + return six.get_function_code(func.__call__) + except AttributeError: + return None + else: + raise ValueError('Argument must be callable') diff --git a/tensorflow/python/util/function_utils_test.py b/tensorflow/python/util/function_utils_test.py index e78cf6a5b02af317b08ff3a833f7b73b062f106e..1588328c262982e5b71446e499c8d0217c28c0a5 100644 --- a/tensorflow/python/util/function_utils_test.py +++ b/tensorflow/python/util/function_utils_test.py @@ -24,6 +24,16 @@ from tensorflow.python.platform import test from tensorflow.python.util import function_utils +def silly_example_function(): + pass + + +class SillyCallableClass(object): + + def __call__(self): + pass + + class FnArgsTest(test.TestCase): def test_simple_function(self): @@ -124,5 +134,73 @@ class FnArgsTest(test.TestCase): self.assertEqual(3, double_wrapped_fn(3)) self.assertEqual(3, double_wrapped_fn(a=3)) + +class GetFuncNameTest(test.TestCase): + + def testWithSimpleFunction(self): + self.assertEqual( + 'silly_example_function', + function_utils.get_func_name(silly_example_function)) + + def testWithClassMethod(self): + self.assertEqual( + 'GetFuncNameTest.testWithClassMethod', + function_utils.get_func_name(self.testWithClassMethod)) + + def testWithCallableClass(self): + callable_instance = SillyCallableClass() + self.assertRegexpMatches( + function_utils.get_func_name(callable_instance), + '<.*SillyCallableClass.*>') + + def testWithFunctoolsPartial(self): + partial = functools.partial(silly_example_function) + self.assertRegexpMatches( + function_utils.get_func_name(partial), + '<.*functools.partial.*>') + + def testWithLambda(self): + anon_fn = lambda x: x + self.assertEqual('', function_utils.get_func_name(anon_fn)) + + def testRaisesWithNonCallableObject(self): + with self.assertRaises(ValueError): + function_utils.get_func_name(None) + + +class GetFuncCodeTest(test.TestCase): + + def testWithSimpleFunction(self): + code = function_utils.get_func_code(silly_example_function) + self.assertIsNotNone(code) + self.assertRegexpMatches(code.co_filename, 'function_utils_test.py') + + def testWithClassMethod(self): + code = function_utils.get_func_code(self.testWithClassMethod) + self.assertIsNotNone(code) + self.assertRegexpMatches(code.co_filename, 'function_utils_test.py') + + def testWithCallableClass(self): + callable_instance = SillyCallableClass() + code = function_utils.get_func_code(callable_instance) + self.assertIsNotNone(code) + self.assertRegexpMatches(code.co_filename, 'function_utils_test.py') + + def testWithLambda(self): + anon_fn = lambda x: x + code = function_utils.get_func_code(anon_fn) + self.assertIsNotNone(code) + self.assertRegexpMatches(code.co_filename, 'function_utils_test.py') + + def testWithFunctoolsPartial(self): + partial = functools.partial(silly_example_function) + code = function_utils.get_func_code(partial) + self.assertIsNone(code) + + def testRaisesWithNonCallableObject(self): + with self.assertRaises(ValueError): + function_utils.get_func_code(None) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/util/lock_util.py b/tensorflow/python/util/lock_util.py new file mode 100644 index 0000000000000000000000000000000000000000..0424960666323870fb1db83804857dd838cfe9ae --- /dev/null +++ b/tensorflow/python/util/lock_util.py @@ -0,0 +1,128 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Locking related utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + + +class GroupLock(object): + """A lock to allow many members of a group to access a resource exclusively. + + This lock provides a way to allow access to a resource by multiple threads + belonging to a logical group at the same time, while restricting access to + threads from all other groups. You can think of this as an extension of a + reader-writer lock, where you allow multiple writers at the same time. We + made it generic to support multiple groups instead of just two - readers and + writers. + + Simple usage example with two groups accessing the same resource: + + ```python + lock = GroupLock(num_groups=2) + + # In a member of group 0: + with lock.group(0): + # do stuff, access the resource + # ... + + # In a member of group 1: + with lock.group(1): + # do stuff, access the resource + # ... + ``` + + Using as a context manager with `.group(group_id)` is the easiest way. You + can also use the `acquire` and `release` method directly. + """ + + def __init__(self, num_groups=2): + """Initialize a group lock. + + Args: + num_groups: The number of groups that will be accessing the resource under + consideration. Should be a positive number. + + Returns: + A group lock that can then be used to synchronize code. + + Raises: + ValueError: If num_groups is less than 1. + """ + if num_groups < 1: + raise ValueError("num_groups must be a positive integer, got {}".format( + num_groups)) + self._ready = threading.Condition(threading.Lock()) + self._num_groups = num_groups + self._group_member_counts = [0] * self._num_groups + + def group(self, group_id): + """Enter a context where the lock is with group `group_id`. + + Args: + group_id: The group for which to acquire and release the lock. + + Returns: + A context manager which will acquire the lock for `group_id`. + """ + self._validate_group_id(group_id) + return self._Context(self, group_id) + + def acquire(self, group_id): + """Acquire the group lock for a specific group `group_id`.""" + self._validate_group_id(group_id) + + self._ready.acquire() + while self._another_group_active(group_id): + self._ready.wait() + self._group_member_counts[group_id] += 1 + self._ready.release() + + def release(self, group_id): + """Release the group lock for a specific group `group_id`.""" + self._validate_group_id(group_id) + + self._ready.acquire() + self._group_member_counts[group_id] -= 1 + if self._group_member_counts[group_id] == 0: + self._ready.notifyAll() + self._ready.release() + + def _another_group_active(self, group_id): + return any( + c > 0 for g, c in enumerate(self._group_member_counts) if g != group_id) + + def _validate_group_id(self, group_id): + if group_id < 0 or group_id >= self._num_groups: + raise ValueError( + "group_id={} should be between 0 and num_groups={}".format( + group_id, self._num_groups)) + + class _Context(object): + """Context manager helper for `GroupLock`.""" + + def __init__(self, lock, group_id): + self._lock = lock + self._group_id = group_id + + def __enter__(self): + self._lock.acquire(self._group_id) + + def __exit__(self, type_arg, value_arg, traceback_arg): + del type_arg, value_arg, traceback_arg + self._lock.release(self._group_id) diff --git a/tensorflow/python/util/lock_util_test.py b/tensorflow/python/util/lock_util_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cda8f952259c9e117e0bd7ff3cac35e764856f43 --- /dev/null +++ b/tensorflow/python/util/lock_util_test.py @@ -0,0 +1,63 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for lock_util.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import time + +from absl.testing import parameterized + +from tensorflow.python.platform import test +from tensorflow.python.util import lock_util + + +class GroupLockTest(test.TestCase, parameterized.TestCase): + + @parameterized.parameters(1, 2, 3, 5, 10) + def testGroups(self, num_groups): + lock = lock_util.GroupLock(num_groups) + num_threads = 10 + finished = set() + + def thread_fn(thread_id): + time.sleep(random.random() * 0.1) + group_id = thread_id % num_groups + with lock.group(group_id): + time.sleep(random.random() * 0.1) + self.assertGreater(lock._group_member_counts[group_id], 0) + for g, c in enumerate(lock._group_member_counts): + if g != group_id: + self.assertEqual(0, c) + finished.add(thread_id) + + threads = [ + self.checkedThread(target=thread_fn, args=(i,)) + for i in range(num_threads) + ] + + for i in range(num_threads): + threads[i].start() + for i in range(num_threads): + threads[i].join() + + self.assertEqual(set(range(num_threads)), finished) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index 1104768ae8f69598f686eb2ffee8b69e43051011..faae0d89c397073dfd6f022775e72bdb106b17e8 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -73,7 +73,7 @@ def _sequence_like(instance, args): Returns: `args` with the type of `instance`. """ - if isinstance(instance, dict): + if isinstance(instance, (dict, _collections.Mapping)): # Pack dictionaries in a deterministic order by sorting the keys. # Notice this means that we ignore the original order of `OrderedDict` # instances. This is intentional, to avoid potential bugs caused by mixing @@ -89,7 +89,7 @@ def _sequence_like(instance, args): def _yield_value(iterable): - if isinstance(iterable, dict): + if isinstance(iterable, (dict, _collections.Mapping)): # Iterate through dictionaries in a deterministic order by sorting the # keys. Notice this means that we ignore the original order of `OrderedDict` # instances. This is intentional, to avoid potential bugs caused by mixing @@ -167,11 +167,14 @@ def assert_same_structure(nest1, nest2, check_types=True): Args: nest1: an arbitrarily nested structure. nest2: an arbitrarily nested structure. - check_types: if `True` (default) types of sequences are checked as - well, including the keys of dictionaries. If set to `False`, for example - a list and a tuple of objects will look the same if they have the same + check_types: if `True` (default) types of sequences are checked as well, + including the keys of dictionaries. If set to `False`, for example a + list and a tuple of objects will look the same if they have the same size. Note that namedtuples with identical name and fields are always - considered to have the same shallow structure. + considered to have the same shallow structure. Two types will also be + considered the same if they are both list subtypes (which allows "list" + and "_ListWrapper" from checkpointable dependency tracking to compare + equal). Raises: ValueError: If the two structures do not have the same number of elements or @@ -212,7 +215,7 @@ def flatten_dict_items(dictionary): ValueError: If any key and value have not the same structure, or if keys are not unique. """ - if not isinstance(dictionary, dict): + if not isinstance(dictionary, (dict, _collections.Mapping)): raise TypeError("input must be a dictionary") flat_dictionary = {} for i, v in _six.iteritems(dictionary): @@ -374,6 +377,62 @@ def map_structure(func, *structure, **check_types_dict): structure[0], [func(*x) for x in entries]) +def map_structure_with_paths(func, *structure, **kwargs): + """Applies `func` to each entry in `structure` and returns a new structure. + + Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in + `structure[i]` and `path` is the common path to x[i] in the structures. All + structures in `structure` must have the same arity, and the return value will + contain the results in the same structure. Special kwarg `check_types` + determines whether the types of iterables within the structure must be the + same-- see **kwargs definition below. + + Args: + func: A callable with the signature func(path, *values, **kwargs) that is + evaluated on the leaves of the structure. + *structure: A variable number of compatible structures to process. + **kwargs: Optional kwargs to be passed through to func. Special kwarg + `check_types` is not passed to func, but instead determines whether the + types of iterables within the structures have to be same (e.g., + `map_structure(func, [1], (1,))` raises a `TypeError` exception). By + default, the types must match. To allow iteration over structures of + different types (but common arity), set this kwarg to `False`. + + Returns: + A structure of the same form as the input structures whose leaves are the + result of evaluating func on corresponding leaves of the input structures. + + Raises: + TypeError: If `func` is not callable or if the structures do not match + each other by depth tree. + TypeError: If `check_types` is not `False` and the two structures differ in + the type of sequence in any of their substructures. + ValueError: If no structures are provided. + """ + if not callable(func): + raise TypeError("func must be callable, got: %s" % func) + if not structure: + raise ValueError("Must provide at least one structure") + + check_types = kwargs.pop("check_types", True) + for other in structure[1:]: + assert_same_structure(structure[0], other, check_types=check_types) + + # First set paths_and_values to: + # [[(p11, v11), ... (p1n, v1n)], ... [(pm1, vm1), ... (pmn, vmn)]] + paths_and_values = [flatten_with_joined_string_paths(s) for s in structure] + + # Now zip(*paths_and_values) would be: + # [((p11, v11), ... (pm1, vm1)), ... ((p1n, v1n), ... (pmn, vmn))] + # so grouped_by_path is set to: + # [[(p11, ... pm1), (v11, ... vm1)], ... [(p1n, ... pmn), (v1n, ... vmn)]] + # Note that p1i, ... pmi must all be equal since the structures are the same. + grouped_by_path = [zip(*p_v) for p_v in zip(*paths_and_values)] + + return pack_sequence_as(structure[0], [ + func(paths[0], *values, **kwargs) for paths, values in grouped_by_path]) + + def _yield_flat_up_to(shallow_tree, input_tree): """Yields elements `input_tree` partially flattened up to `shallow_tree`.""" if is_sequence(shallow_tree): @@ -452,7 +511,7 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): "structure has length %s, while shallow structure has length %s." % (len(input_tree), len(shallow_tree))) - if check_types and isinstance(shallow_tree, dict): + if check_types and isinstance(shallow_tree, (dict, _collections.Mapping)): if set(input_tree) != set(shallow_tree): raise ValueError( "The two structures don't have the same keys. Input " @@ -713,7 +772,7 @@ def yield_flat_paths(nest): # The _maybe_add_final_path_element function is used below in order to avoid # adding trailing slashes when the sub-element recursed into is a leaf. - if isinstance(nest, dict): + if isinstance(nest, (dict, _collections.Mapping)): for key in _sorted(nest): value = nest[key] for sub_path in yield_flat_paths(value): @@ -757,3 +816,4 @@ def flatten_with_joined_string_paths(structure, separator="/"): _pywrap_tensorflow.RegisterSequenceClass(_collections.Sequence) +_pywrap_tensorflow.RegisterMappingClass(_collections.Mapping) diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index 2f12b25354a905b2aafa870c28f1e9c0b693e888..2369eb610e2a56e84fed54d129fc4b36cd96886f 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import collections import time +from absl.testing import parameterized import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin @@ -33,7 +34,22 @@ from tensorflow.python.platform import test from tensorflow.python.util import nest -class NestTest(test.TestCase): +class _CustomMapping(collections.Mapping): + + def __init__(self, *args, **kwargs): + self._wrapped = dict(*args, **kwargs) + + def __getitem__(self, key): + return self._wrapped[key] + + def __iter__(self): + return iter(self._wrapped) + + def __len__(self): + return len(self._wrapped) + + +class NestTest(parameterized.TestCase, test.TestCase): PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name @@ -72,26 +88,32 @@ class NestTest(test.TestCase): with self.assertRaises(ValueError): nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) + @parameterized.parameters({"mapping_type": collections.OrderedDict}, + {"mapping_type": _CustomMapping}) @test_util.assert_no_new_pyobjects_executing_eagerly - def testFlattenDictOrder(self): + def testFlattenDictOrder(self, mapping_type): """`flatten` orders dicts by key, including OrderedDicts.""" - ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) + ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) plain = {"d": 3, "b": 1, "a": 0, "c": 2} ordered_flat = nest.flatten(ordered) plain_flat = nest.flatten(plain) self.assertEqual([0, 1, 2, 3], ordered_flat) self.assertEqual([0, 1, 2, 3], plain_flat) - def testPackDictOrder(self): + @parameterized.parameters({"mapping_type": collections.OrderedDict}, + {"mapping_type": _CustomMapping}) + def testPackDictOrder(self, mapping_type): """Packing orders dicts by key, including OrderedDicts.""" - ordered = collections.OrderedDict([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) + custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) plain = {"d": 0, "b": 0, "a": 0, "c": 0} seq = [0, 1, 2, 3] - ordered_reconstruction = nest.pack_sequence_as(ordered, seq) + custom_reconstruction = nest.pack_sequence_as(custom, seq) plain_reconstruction = nest.pack_sequence_as(plain, seq) + self.assertIsInstance(custom_reconstruction, mapping_type) + self.assertIsInstance(plain_reconstruction, dict) self.assertEqual( - collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), - ordered_reconstruction) + mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), + custom_reconstruction) self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name @@ -101,8 +123,10 @@ class NestTest(test.TestCase): # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. mess = [ "z", - NestTest.Abc(3, 4), - { + NestTest.Abc(3, 4), { + "d": _CustomMapping({ + 41: 4 + }), "c": [ 1, collections.OrderedDict([ @@ -111,17 +135,19 @@ class NestTest(test.TestCase): ]), ], "b": 5 - }, - 17 + }, 17 ] flattened = nest.flatten(mess) - self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17]) + self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17]) structure_of_mess = [ 14, NestTest.Abc("a", True), { + "d": _CustomMapping({ + 41: 42 + }), "c": [ 0, collections.OrderedDict([ @@ -142,6 +168,10 @@ class NestTest(test.TestCase): self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) + unflattened_custom_mapping = unflattened[2]["d"] + self.assertIsInstance(unflattened_custom_mapping, _CustomMapping) + self.assertEqual(list(unflattened_custom_mapping.keys()), [41]) + def testFlatten_numpyIsNotFlattened(self): structure = np.array([1, 2, 3]) flattened = nest.flatten(structure) @@ -179,19 +209,23 @@ class NestTest(test.TestCase): self.assertFalse(nest.is_sequence(math_ops.tanh(ones))) self.assertFalse(nest.is_sequence(np.ones((4, 5)))) - def testFlattenDictItems(self): - dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))} + @parameterized.parameters({"mapping_type": _CustomMapping}, + {"mapping_type": dict}) + def testFlattenDictItems(self, mapping_type): + dictionary = mapping_type({(4, 5, (6, 8)): ("a", "b", ("c", "d"))}) flat = {4: "a", 5: "b", 6: "c", 8: "d"} self.assertEqual(nest.flatten_dict_items(dictionary), flat) with self.assertRaises(TypeError): nest.flatten_dict_items(4) - bad_dictionary = {(4, 5, (4, 8)): ("a", "b", ("c", "d"))} + bad_dictionary = mapping_type({(4, 5, (4, 8)): ("a", "b", ("c", "d"))}) with self.assertRaisesRegexp(ValueError, "not unique"): nest.flatten_dict_items(bad_dictionary) - another_bad_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))} + another_bad_dictionary = mapping_type({ + (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e"))) + }) with self.assertRaisesRegexp( ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): nest.flatten_dict_items(another_bad_dictionary) @@ -320,6 +354,10 @@ class NestTest(test.TestCase): EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name + def testHeterogeneousComparison(self): + nest.assert_same_structure({"a": 4}, _CustomMapping(a=3)) + nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) + @test_util.assert_no_new_pyobjects_executing_eagerly def testMapStructure(self): structure1 = (((1, 2), 3), 4, (5, 6)) @@ -712,6 +750,35 @@ class NestTest(test.TestCase): self.assertEqual( list(nest.flatten_with_joined_string_paths(inputs)), expected) + @parameterized.named_parameters( + ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))), + ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True, + {"a": ("a", 4), "b": ("b", 6)}), + ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))), + ("nested", + {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True, + {"a": [("a/0", 10), ("a/1", 12)], + "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]})) + def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected): + def format_sum(path, *values): + return (path, sum(values)) + result = nest.map_structure_with_paths(format_sum, s1, s2, + check_types=check_types) + self.assertEqual(expected, result) + + @parameterized.named_parameters( + ("tuples", (1, 2), (3, 4, 5), ValueError), + ("dicts", {"a": 1}, {"b": 2}, ValueError), + ("mixed", (1, 2), [3, 4], TypeError), + ("nested", + {"a": [2, 3], "b": [1, 3]}, + {"b": [5, 6, 7], "a": [8, 9]}, + ValueError + )) + def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): + with self.assertRaises(error_type): + nest.map_structure_with_paths(lambda path, *s: 0, s1, s2) + class NestBenchmark(test.Benchmark): diff --git a/tensorflow/python/util/py_checkpoint_reader.i b/tensorflow/python/util/py_checkpoint_reader.i index 8004898cbcbce7ce593ce35efdc6493e052468bd..1c73f7f06f1937a8db0bd858421c2e884892e25b 100644 --- a/tensorflow/python/util/py_checkpoint_reader.i +++ b/tensorflow/python/util/py_checkpoint_reader.i @@ -166,6 +166,7 @@ def NewCheckpointReader(filepattern): return CheckpointReader(compat.as_bytes(filepattern), status) NewCheckpointReader._tf_api_names = ['train.NewCheckpointReader'] +NewCheckpointReader._tf_api_names_v1 = ['train.NewCheckpointReader'] %} %include "tensorflow/c/checkpoint_reader.h" diff --git a/tensorflow/python/util/serialization_test.py b/tensorflow/python/util/serialization_test.py index 5000bcfad05900e63bc72c1bd0e31e30434b74ae..6df7533831bf7bacf8bb2833dac83276de30612a 100644 --- a/tensorflow/python/util/serialization_test.py +++ b/tensorflow/python/util/serialization_test.py @@ -47,7 +47,7 @@ class SerializationTests(test.TestCase): self.assertIs(round_trip[0], None) self.assertEqual(round_trip[1], 2) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_serialize_sequential(self): model = sequential.Sequential() model.add(core.Dense(4)) @@ -55,13 +55,10 @@ class SerializationTests(test.TestCase): model(constant_op.constant([[1.]])) sequential_round_trip = json.loads( json.dumps(model, default=serialization.get_json_type)) - self.assertEqual(5, sequential_round_trip["config"][1]["config"]["units"]) - input_round_trip = json.loads( - json.dumps(model._input_layers, default=serialization.get_json_type)) - self.assertAllEqual([1, 1], - input_round_trip[0]["config"]["batch_input_shape"]) + self.assertEqual( + 5, sequential_round_trip["config"]["layers"][1]["config"]["units"]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes def test_serialize_model(self): x = input_layer.Input(shape=[3]) y = core.Dense(10)(x) diff --git a/tensorflow/python/util/stat_summarizer.i b/tensorflow/python/util/stat_summarizer.i index 73fa85494b72d920d00577c826b76c3381d963a4..a5a7984d914f24964c377149f8125ceb3126c009 100644 --- a/tensorflow/python/util/stat_summarizer.i +++ b/tensorflow/python/util/stat_summarizer.i @@ -27,8 +27,8 @@ limitations under the License. %ignoreall -%unignore _NewStatSummarizer; -%unignore _DeleteStatSummarizer; +%unignore NewStatSummarizer; +%unignore DeleteStatSummarizer; %unignore tensorflow; %unignore tensorflow::StatSummarizer; %unignore tensorflow::StatSummarizer::StatSummarizer; @@ -43,20 +43,20 @@ limitations under the License. // TODO(ashankar): Remove the unused argument from the API. %{ -tensorflow::StatSummarizer* _NewStatSummarizer( +tensorflow::StatSummarizer* NewStatSummarizer( const string& unused) { return new tensorflow::StatSummarizer(tensorflow::StatSummarizerOptions()); } %} %{ -void _DeleteStatSummarizer(tensorflow::StatSummarizer* ss) { +void DeleteStatSummarizer(tensorflow::StatSummarizer* ss) { delete ss; } %} -tensorflow::StatSummarizer* _NewStatSummarizer(const string& unused); -void _DeleteStatSummarizer(tensorflow::StatSummarizer* ss); +tensorflow::StatSummarizer* NewStatSummarizer(const string& unused); +void DeleteStatSummarizer(tensorflow::StatSummarizer* ss); %extend tensorflow::StatSummarizer { void ProcessStepStatsStr(const string& step_stats_str) { @@ -76,16 +76,3 @@ void _DeleteStatSummarizer(tensorflow::StatSummarizer* ss); %include "tensorflow/core/util/stat_summarizer_options.h" %include "tensorflow/core/util/stat_summarizer.h" %unignoreall - -%insert("python") %{ - -# Wrapping NewStatSummarizer and DeletStatSummarizer because -# SWIG-generated functions are built-in functions and do not support -# setting _tf_api_names attribute. - -def NewStatSummarizer(unused): - return _NewStatSummarizer(unused) - -def DeleteStatSummarizer(stat_summarizer): - _DeleteStatSummarizer(stat_summarizer) -%} diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py index e154ffb68a4f0ccdebf5320cad7d3da056117197..274f32c21f77483464a12a1beb25043a208b4b35 100644 --- a/tensorflow/python/util/tf_export.py +++ b/tensorflow/python/util/tf_export.py @@ -63,12 +63,63 @@ API_ATTRS = { '_estimator_api_constants') } +API_ATTRS_V1 = { + TENSORFLOW_API_NAME: _Attributes( + '_tf_api_names_v1', + '_tf_api_constants_v1'), + ESTIMATOR_API_NAME: _Attributes( + '_estimator_api_names_v1', + '_estimator_api_constants_v1') +} + class SymbolAlreadyExposedError(Exception): """Raised when adding API names to symbol that already has API names.""" pass +def get_canonical_name_for_symbol(symbol, api_name=TENSORFLOW_API_NAME): + """Get canonical name for the API symbol. + + Canonical name is the first non-deprecated endpoint name. + + Args: + symbol: API function or class. + api_name: API name (tensorflow or estimator). + + Returns: + Canonical name for the API symbol (for e.g. initializers.zeros) if + canonical name could be determined. Otherwise, returns None. + """ + if not hasattr(symbol, '__dict__'): + return None + api_names_attr = API_ATTRS[api_name].names + _, undecorated_symbol = tf_decorator.unwrap(symbol) + if api_names_attr not in undecorated_symbol.__dict__: + return None + api_names = getattr(undecorated_symbol, api_names_attr) + # TODO(annarev): may be add a separate deprecated attribute + # for estimator names. + deprecated_api_names = undecorated_symbol.__dict__.get( + '_tf_deprecated_api_names', []) + return get_canonical_name(api_names, deprecated_api_names) + + +def get_canonical_name(api_names, deprecated_api_names): + """Get first non-deprecated endpoint name. + + Args: + api_names: API names iterable. + deprecated_api_names: Deprecated API names iterable. + Returns: + Canonical name if there is at least one non-deprecated endpoint. + Otherwise returns None. + """ + return next( + (name for name in api_names if name not in deprecated_api_names), + None) + + class api_export(object): # pylint: disable=invalid-name """Provides ways to export symbols to the TensorFlow API.""" @@ -78,13 +129,16 @@ class api_export(object): # pylint: disable=invalid-name Args: *args: API names in dot delimited format. **kwargs: Optional keyed arguments. - overrides: List of symbols that this is overriding + v1: Names for the TensorFlow V1 API. If not set, we will use V2 API + names both for TensorFlow V1 and V2 APIs. + overrides: List of symbols that this is overriding (those overrided api exports will be removed). Note: passing overrides has no effect on exporting a constant. - api_name: Name of the API you want to generate (e.g. `tensorflow` or + api_name: Name of the API you want to generate (e.g. `tensorflow` or `estimator`). Default is `tensorflow`. """ self._names = args + self._names_v1 = kwargs.get('v1', args) self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME) self._overrides = kwargs.get('overrides', []) @@ -102,24 +156,27 @@ class api_export(object): # pylint: disable=invalid-name and kwarg `allow_multiple_exports` not set. """ api_names_attr = API_ATTRS[self._api_name].names - + api_names_attr_v1 = API_ATTRS_V1[self._api_name].names # Undecorate overridden names for f in self._overrides: _, undecorated_f = tf_decorator.unwrap(f) delattr(undecorated_f, api_names_attr) + delattr(undecorated_f, api_names_attr_v1) _, undecorated_func = tf_decorator.unwrap(func) + self.set_attr(undecorated_func, api_names_attr, self._names) + self.set_attr(undecorated_func, api_names_attr_v1, self._names_v1) + return func + def set_attr(self, func, api_names_attr, names): # Check for an existing api. We check if attribute name is in # __dict__ instead of using hasattr to verify that subclasses have # their own _tf_api_names as opposed to just inheriting it. - if api_names_attr in undecorated_func.__dict__: + if api_names_attr in func.__dict__: raise SymbolAlreadyExposedError( 'Symbol %s is already exposed as %s.' % - (undecorated_func.__name__, getattr( - undecorated_func, api_names_attr))) # pylint: disable=protected-access - setattr(undecorated_func, api_names_attr, self._names) - return func + (func.__name__, getattr(func, api_names_attr))) # pylint: disable=protected-access + setattr(func, api_names_attr, names) def export_constant(self, module_name, name): """Store export information for constants/string literals. @@ -140,12 +197,20 @@ class api_export(object): # pylint: disable=invalid-name name: (string) Current constant name. """ module = sys.modules[module_name] - if not hasattr(module, API_ATTRS[self._api_name].constants): - setattr(module, API_ATTRS[self._api_name].constants, []) + api_constants_attr = API_ATTRS[self._api_name].constants + api_constants_attr_v1 = API_ATTRS_V1[self._api_name].constants + + if not hasattr(module, api_constants_attr): + setattr(module, api_constants_attr, []) # pylint: disable=protected-access - getattr(module, API_ATTRS[self._api_name].constants).append( + getattr(module, api_constants_attr).append( (self._names, name)) + if not hasattr(module, api_constants_attr_v1): + setattr(module, api_constants_attr_v1, []) + getattr(module, api_constants_attr_v1).append( + (self._names_v1, name)) + tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME) estimator_export = functools.partial(tf_export, api_name=ESTIMATOR_API_NAME) diff --git a/tensorflow/python/util/tf_export_test.py b/tensorflow/python/util/tf_export_test.py index b9e26ecb33383f5aa936a6bc92acea6d91eb996e..4ae1dc55e06b434aeb4a95e2ca9aa68e4eef56de 100644 --- a/tensorflow/python/util/tf_export_test.py +++ b/tensorflow/python/util/tf_export_test.py @@ -60,6 +60,8 @@ class ValidateExportTest(test.TestCase): for symbol in [_test_function, _test_function, TestClassA, TestClassB]: if hasattr(symbol, '_tf_api_names'): del symbol._tf_api_names + if hasattr(symbol, '_tf_api_names_v1'): + del symbol._tf_api_names_v1 def _CreateMockModule(self, name): mock_module = self.MockModule(name) diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py index fbd65617670b15bfc69506bab1e83369081502af..778121e15bde6fb61a73cdf7ff8a3e2f34dd5266 100644 --- a/tensorflow/python/util/tf_inspect.py +++ b/tensorflow/python/util/tf_inspect.py @@ -184,7 +184,7 @@ else: Returns: A FullArgSpec with empty kwonlyargs, kwonlydefaults and annotations. """ - argspecs = _inspect.getargspec(target) + argspecs = getargspec(target) fullargspecs = FullArgSpec( args=argspecs.args, varargs=argspecs.varargs, @@ -300,6 +300,16 @@ def getsource(object): # pylint: disable=redefined-builtin return _inspect.getsource(tf_decorator.unwrap(object)[1]) +def getsourcefile(object): # pylint: disable=redefined-builtin + """TFDecorator-aware replacement for inspect.getsourcefile.""" + return _inspect.getsourcefile(tf_decorator.unwrap(object)[1]) + + +def getsourcelines(object): # pylint: disable=redefined-builtin + """TFDecorator-aware replacement for inspect.getsourcelines.""" + return _inspect.getsourcelines(tf_decorator.unwrap(object)[1]) + + def isbuiltin(object): # pylint: disable=redefined-builtin """TFDecorator-aware replacement for inspect.isbuiltin.""" return _inspect.isbuiltin(tf_decorator.unwrap(object)[1]) diff --git a/tensorflow/python/util/tf_inspect_test.py b/tensorflow/python/util/tf_inspect_test.py index beaf350de1e469a7675a4b55ff341419262b79b2..d3b7e4b969bd9f3c1984c664a72fcf517252a5f6 100644 --- a/tensorflow/python/util/tf_inspect_test.py +++ b/tensorflow/python/util/tf_inspect_test.py @@ -122,6 +122,18 @@ class TfInspectTest(test.TestCase): self.assertEqual(argspec, tf_inspect.getargspec(partial_func)) + def testGetFullArgsSpecForPartial(self): + + def func(a, b): + del a, b + + partial_function = functools.partial(func, 1) + argspec = tf_inspect.FullArgSpec( + args=['b'], varargs=None, varkw=None, defaults=None, + kwonlyargs=[], kwonlydefaults=None, annotations={}) + + self.assertEqual(argspec, tf_inspect.getfullargspec(partial_function)) + def testGetArgSpecOnPartialInvalidArgspec(self): """Tests getargspec on partial function that doesn't have valid argspec.""" @@ -326,6 +338,18 @@ def test_decorated_function_with_defaults(a, b=2, c='Hello'): self.assertEqual( expected, tf_inspect.getsource(test_decorated_function_with_defaults)) + def testGetSourceFile(self): + self.assertEqual( + __file__, + tf_inspect.getsourcefile(test_decorated_function_with_defaults)) + + def testGetSourceLines(self): + expected = inspect.getsourcelines( + test_decorated_function_with_defaults.decorated_target) + self.assertEqual( + expected, + tf_inspect.getsourcelines(test_decorated_function_with_defaults)) + def testIsBuiltin(self): self.assertEqual( tf_inspect.isbuiltin(TestDecoratedClass), diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py index 28e49afa023904abed076373685bb38f2537b7d4..ca6710bcf2178db0fcf63c9bdfdf27531651f7ed 100644 --- a/tensorflow/python/util/tf_should_use.py +++ b/tensorflow/python/util/tf_should_use.py @@ -17,23 +17,124 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools -import types +import copy +import sys +import traceback import six # pylint: disable=unused-import -from tensorflow.python.eager import context +from tensorflow.python.platform import tf_logging from tensorflow.python.util import tf_decorator # pylint: enable=g-bad-import-order,g-import-not-at-top -# TODO(b/65412899): Re-implement to avoid leaking python objects. -# This function / class remains since the API is public (mark_used()). +class _TFShouldUseHelper(object): + """Object stored in TFShouldUse-wrapped objects. + + When it is deleted it will emit a warning or error if its `sate` method + has not been called by time of deletion. + """ + + def __init__(self, type_, repr_, stack_frame, fatal_error_if_unsated): + self._type = type_ + self._repr = repr_ + self._stack_frame = stack_frame + self._fatal_error_if_unsated = fatal_error_if_unsated + self._sated = False + + def sate(self): + self._sated = True + self._type = None + self._repr = None + self._stack_frame = None + self._logging_module = None + + def __del__(self): + if self._sated: + return + if self._fatal_error_if_unsated: + logger = tf_logging.fatal + else: + logger = tf_logging.error + creation_stack = ''.join( + [line.rstrip() for line in traceback.format_stack(self._stack_frame)]) + logger( + '==================================\n' + 'Object was never used (type %s):\n%s\nIf you want to mark it as ' + 'used call its "mark_used()" method.\nIt was originally created ' + 'here:\n%s\n' + '==================================' % + (self._type, self._repr, creation_stack)) + + +def _new__init__(self, true_value, tf_should_use_helper): + # pylint: disable=protected-access + self._tf_should_use_helper = tf_should_use_helper + self._true_value = true_value + + +def _new__setattr__(self, key, value): + if key in ('_tf_should_use_helper', '_true_value'): + return object.__setattr__(self, key, value) + return setattr( + object.__getattribute__(self, '_true_value'), + key, value) + + +def _new__getattribute__(self, key): + if key not in ('_tf_should_use_helper', '_true_value'): + object.__getattribute__(self, '_tf_should_use_helper').sate() + if key in ('_tf_should_use_helper', 'mark_used', '__setatt__'): + return object.__getattribute__(self, key) + return getattr(object.__getattribute__(self, '_true_value'), key) + + +def _new_mark_used(self, *args, **kwargs): + object.__getattribute__(self, '_tf_should_use_helper').sate() + try: + mu = object.__getattribute__( + object.__getattribute__(self, '_true_value'), + 'mark_used') + return mu(*args, **kwargs) + except AttributeError: + pass + + +_WRAPPERS = dict() + + +def _get_wrapper(x, tf_should_use_helper): + """Create a wrapper for object x, whose class subclasses type(x). + + The wrapper will emit a warning if it is deleted without any of its + properties being accessed or methods being called. + + Args: + x: The instance to wrap. + tf_should_use_helper: The object that tracks usage. + + Returns: + An object wrapping `x`, of type `type(x)`. + """ + type_x = type(x) + memoized = _WRAPPERS.get(type_x, None) + if memoized: + return memoized(x, tf_should_use_helper) + + tx = copy.deepcopy(type_x) + copy_tx = type(tx.__name__, tx.__bases__, dict(tx.__dict__)) + copy_tx.__init__ = _new__init__ + copy_tx.__getattribute__ = _new__getattribute__ + copy_tx.mark_used = _new_mark_used + copy_tx.__setattr__ = _new__setattr__ + _WRAPPERS[type_x] = copy_tx + + return copy_tx(x, tf_should_use_helper) + + def _add_should_use_warning(x, fatal_error=False): """Wraps object x so that if it is never used, a warning is logged. - Does nothing when executing eagerly. - Args: x: Python object. fatal_error: Python bool. If `True`, tf.logging.fatal is raised @@ -43,50 +144,22 @@ def _add_should_use_warning(x, fatal_error=False): An instance of `TFShouldUseWarningWrapper` which subclasses `type(x)` and is a very shallow wrapper for `x` which logs access into `x`. """ - del fatal_error if x is None or x == []: # pylint: disable=g-explicit-bool-comparison return x - if context.executing_eagerly(): - # Typically not needed when executing eagerly (the main use case is for ops - # which need to be incorporated into the graph), and even the no-op wrapper - # creates reference cycles which require garbage collection. - return x - - def override_method(method): - def fn(self, *args, **kwargs): - return method(self, *args, **kwargs) - return fn - - class TFShouldUseWarningWrapper(type(x)): - """Wrapper for objects that keeps track of their use.""" - - def __init__(self, true_self): - self.__dict__ = true_self.__dict__ + # Extract the current frame for later use by traceback printing. + try: + raise ValueError() + except ValueError: + stack_frame = sys.exc_info()[2].tb_frame.f_back - # Not sure why this pylint warning is being used; this is not an - # old class form. - # pylint: disable=super-on-old-class - def __getattribute__(self, name): - return super(TFShouldUseWarningWrapper, self).__getattribute__(name) - - def mark_used(self, *args, **kwargs): - return + tf_should_use_helper = _TFShouldUseHelper( + type_=type(x), + repr_=repr(x), + stack_frame=stack_frame, + fatal_error_if_unsated=fatal_error) - # pylint: enable=super-on-old-class - - for name in dir(TFShouldUseWarningWrapper): - method = getattr(TFShouldUseWarningWrapper, name) - if not isinstance(method, types.FunctionType): - continue - if name in ('__init__', '__getattribute__', '__del__', 'mark_used'): - continue - setattr(TFShouldUseWarningWrapper, name, - functools.wraps(method)(override_method(method))) - - wrapped = TFShouldUseWarningWrapper(x) - wrapped.__doc__ = x.__doc__ # functools.wraps fails on some objects. - return wrapped + return _get_wrapper(x, tf_should_use_helper) def should_use_result(fn): @@ -106,8 +179,6 @@ def should_use_result(fn): - `t != 0`. In this case, comparison is done on types / ids. - `isinstance(t, tf.Tensor)`. Similar to above. - Does nothing when executing eagerly. - Args: fn: The function to wrap. @@ -142,8 +213,6 @@ def must_use_result_or_fatal(fn): - `t != 0`. In this case, comparison is done on types / ids. - `isinstance(t, tf.Tensor)`. Similar to above. - Does nothing when executing eagerly. - Args: fn: The function to wrap. diff --git a/tensorflow/python/util/tf_should_use_test.py b/tensorflow/python/util/tf_should_use_test.py index 4c6e48b11c1d013d1e4c6cdfc376973baa7bb9a2..16fa1f547d4c6b9d2c4da6994d380ba2b671b886 100644 --- a/tensorflow/python/util/tf_should_use_test.py +++ b/tensorflow/python/util/tf_should_use_test.py @@ -30,48 +30,51 @@ from tensorflow.python.util import tf_should_use @contextlib.contextmanager -def reroute_error(captured): +def reroute_error(): """Temporarily reroute errors written to tf_logging.error into `captured`.""" - del captured[:] - true_logger = tf_logging.error - def capture_errors(*args, **unused_kwargs): - captured.extend(args) - tf_logging.error = capture_errors - try: - yield - finally: - tf_logging.error = true_logger + with test.mock.patch.object(tf_should_use.tf_logging, 'error') as error: + with test.mock.patch.object(tf_should_use.tf_logging, 'fatal') as fatal: + yield error, fatal class TfShouldUseTest(test.TestCase): def testAddShouldUseWarningWhenNotUsed(self): - self.skipTest('b/65412899') c = constant_op.constant(0, name='blah0') - captured = [] - with reroute_error(captured): - def in_this_function(): - h = tf_should_use._add_should_use_warning(c) - del h + def in_this_function(): + h = tf_should_use._add_should_use_warning(c) + del h + with reroute_error() as (error, _): in_this_function() - self.assertIn('Object was never used', '\n'.join(captured)) - self.assertIn('blah0:0', '\n'.join(captured)) - self.assertIn('in_this_function', '\n'.join(captured)) - gc.collect() + msg = '\n'.join(error.call_args[0]) + self.assertIn('Object was never used', msg) + self.assertIn('blah0:0', msg) + self.assertIn('in_this_function', msg) + self.assertFalse(gc.garbage) + + def testAddShouldUseFatalWhenNotUsed(self): + c = constant_op.constant(0, name='blah0') + def in_this_function(): + h = tf_should_use._add_should_use_warning(c, fatal_error=True) + del h + with reroute_error() as (_, fatal): + in_this_function() + msg = '\n'.join(fatal.call_args[0]) + self.assertIn('Object was never used', msg) + self.assertIn('blah0:0', msg) + self.assertIn('in_this_function', msg) self.assertFalse(gc.garbage) def _testAddShouldUseWarningWhenUsed(self, fn, name): c = constant_op.constant(0, name=name) - captured = [] - with reroute_error(captured): + with reroute_error() as (error, fatal): h = tf_should_use._add_should_use_warning(c) fn(h) del h - self.assertNotIn('Object was never used', '\n'.join(captured)) - self.assertNotIn('%s:0' % name, '\n'.join(captured)) + error.assert_not_called() + fatal.assert_not_called() def testAddShouldUseWarningWhenUsedWithAdd(self): - self.skipTest('b/65412899') def add(h): _ = h + 1 self._testAddShouldUseWarningWhenUsed(add, name='blah_add') @@ -79,7 +82,6 @@ class TfShouldUseTest(test.TestCase): self.assertFalse(gc.garbage) def testAddShouldUseWarningWhenUsedWithGetName(self): - self.skipTest('b/65412899') def get_name(h): _ = h.name self._testAddShouldUseWarningWhenUsed(get_name, name='blah_get_name') @@ -87,35 +89,33 @@ class TfShouldUseTest(test.TestCase): self.assertFalse(gc.garbage) def testShouldUseResult(self): - self.skipTest('b/65412899') @tf_should_use.should_use_result def return_const(value): return constant_op.constant(value, name='blah2') - captured = [] - with reroute_error(captured): + with reroute_error() as (error, _): return_const(0.0) - self.assertIn('Object was never used', '\n'.join(captured)) - self.assertIn('blah2:0', '\n'.join(captured)) - self.assertIn('return_const', '\n'.join(captured)) + msg = '\n'.join(error.call_args[0]) + self.assertIn('Object was never used', msg) + self.assertIn('blah2:0', msg) + self.assertIn('return_const', msg) gc.collect() self.assertFalse(gc.garbage) def testShouldUseResultWhenNotReallyUsed(self): - self.skipTest('b/65412899') @tf_should_use.should_use_result def return_const(value): return constant_op.constant(value, name='blah3') - captured = [] - with reroute_error(captured): + with reroute_error() as (error, _): with self.test_session(): return_const(0.0) # Creating another op and executing it does not mark the # unused op as being "used". v = constant_op.constant(1.0, name='meh') v.eval() - self.assertIn('Object was never used', '\n'.join(captured)) - self.assertIn('blah3:0', '\n'.join(captured)) - self.assertIn('return_const', '\n'.join(captured)) + msg = '\n'.join(error.call_args[0]) + self.assertIn('Object was never used', msg) + self.assertIn('blah3:0', msg) + self.assertIn('return_const', msg) gc.collect() self.assertFalse(gc.garbage) diff --git a/tensorflow/python/util/tf_stack.py b/tensorflow/python/util/tf_stack.py new file mode 100644 index 0000000000000000000000000000000000000000..fe4f4a63eb52d4b9549f42ddeb00f7d95f15d5d2 --- /dev/null +++ b/tensorflow/python/util/tf_stack.py @@ -0,0 +1,103 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functions used to extract and analyze stacks. Faster than Python libs.""" +# pylint: disable=g-bad-name +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import linecache +import sys + +# Names for indices into TF traceback tuples. +TB_FILENAME = 0 +TB_LINENO = 1 +TB_FUNCNAME = 2 +TB_CODEDICT = 3 # Dictionary of Python interpreter state. + + +def extract_stack(extract_frame_info_fn=None): + """A lightweight, extensible re-implementation of traceback.extract_stack. + + NOTE(mrry): traceback.extract_stack eagerly retrieves the line of code for + each stack frame using linecache, which results in an abundance of stat() + calls. This implementation does not retrieve the code, and any consumer + should apply _convert_stack to the result to obtain a traceback that can + be formatted etc. using traceback methods. + + Args: + extract_frame_info_fn: Optional callable fn(stack_frame) applied to each + stack frame. This callable's return value is stored as the sixth (last) + element of the returned tuples. If not provided, the returned tuples + will have None as their sixth value. + + Returns: + A list of 6-tuples + (filename, lineno, name, frame_globals, func_start_lineno, custom_info) + corresponding to the call stack of the current thread. The returned tuples + have the innermost stack frame at the end, unlike the Python inspect + module's stack() function. + """ + default_fn = lambda f: None + extract_frame_info_fn = extract_frame_info_fn or default_fn + try: + raise ZeroDivisionError + except ZeroDivisionError: + f = sys.exc_info()[2].tb_frame.f_back + ret = [] + while f is not None: + lineno = f.f_lineno + co = f.f_code + filename = co.co_filename + name = co.co_name + frame_globals = f.f_globals + func_start_lineno = co.co_firstlineno + frame_info = extract_frame_info_fn(f) + ret.append((filename, lineno, name, frame_globals, func_start_lineno, + frame_info)) + f = f.f_back + ret.reverse() + return ret + + +def convert_stack(stack, include_func_start_lineno=False): + """Converts a stack extracted using extract_stack() to a traceback stack. + + Args: + stack: A list of n 5-tuples, + (filename, lineno, name, frame_globals, func_start_lineno). + include_func_start_lineno: True if function start line number should be + included as the 5th entry in return tuples. + + Returns: + A list of n 4-tuples or 5-tuples + (filename, lineno, name, code, [optional: func_start_lineno]), where the + code tuple element is calculated from the corresponding elements of the + input tuple. + """ + ret = [] + for (filename, lineno, name, frame_globals, func_start_lineno, + unused_frame_info) in stack: + linecache.checkcache(filename) + line = linecache.getline(filename, lineno, frame_globals) + if line: + line = line.strip() + else: + line = None + if include_func_start_lineno: + ret.append((filename, lineno, name, line, func_start_lineno)) + else: + ret.append((filename, lineno, name, line)) + return ret diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index c79d8a84458800937e3e51a8dae26605bd834233..ebb72079efdf4ab84527bf67a0b3e0dd97d9c800 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -31,6 +31,8 @@ namespace { // Type object for collections.Sequence. This is set by RegisterSequenceClass. PyObject* CollectionsSequenceType = nullptr; +// Type object for collections.Mapping, set by RegisterMappingClass. +PyObject* CollectionsMappingType = nullptr; PyTypeObject* SparseTensorValueType = nullptr; const int kMaxItemsInCache = 1024; @@ -45,6 +47,28 @@ bool IsString(PyObject* o) { PyUnicode_Check(o); } +// Work around a writable-strings warning with Python 2's PyMapping_Keys macro, +// and while we're at it give them consistent behavior by making sure the +// returned value is a list. +// +// As with PyMapping_Keys, returns a new reference. +// +// On failure, returns nullptr. +PyObject* MappingKeys(PyObject* o) { +#if PY_MAJOR_VERSION >= 3 + return PyMapping_Keys(o); +#else + static char key_method_name[] = "keys"; + Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr)); + if (PyErr_Occurred() || raw_result.get() == nullptr) { + return nullptr; + } + return PySequence_Fast( + raw_result.get(), + "The '.keys()' method of a custom mapping returned a non-sequence."); +#endif +} + // Equivalent to Python's 'o.__class__.__name__' // Note that '__class__' attribute is set only in new-style classes. // A lot of tensorflow code uses __class__ without checks, so it seems like @@ -85,6 +109,119 @@ string PyObjectToString(PyObject* o) { } } +class CachedTypeCheck { + public: + explicit CachedTypeCheck(std::function ternary_predicate) + : ternary_predicate_(std::move(ternary_predicate)) {} + + ~CachedTypeCheck() { + mutex_lock l(type_to_sequence_map_mu_); + for (const auto& pair : type_to_sequence_map_) { + Py_DECREF(pair.first); + } + } + + // Caches successful executions of the one-argument (PyObject*) callable + // "ternary_predicate" based on the type of "o". -1 from the callable + // indicates an unsuccessful check (not cached), 0 indicates that "o"'s type + // does not match the predicate, and 1 indicates that it does. Used to avoid + // calling back into Python for expensive isinstance checks. + int CachedLookup(PyObject* o) { + // Try not to return to Python - see if the type has already been seen + // before. + + auto* type = Py_TYPE(o); + + { + mutex_lock l(type_to_sequence_map_mu_); + auto it = type_to_sequence_map_.find(type); + if (it != type_to_sequence_map_.end()) { + return it->second; + } + } + + int check_result = ternary_predicate_(o); + + if (check_result == -1) { + return -1; // Type check error, not cached. + } + + // NOTE: This is never decref'd as long as the object lives, which is likely + // forever, but we don't want the type to get deleted as long as it is in + // the map. This should not be too much of a leak, as there should only be a + // relatively small number of types in the map, and an even smaller number + // that are eligible for decref. As a precaution, we limit the size of the + // map to 1024. + { + mutex_lock l(type_to_sequence_map_mu_); + if (type_to_sequence_map_.size() < kMaxItemsInCache) { + Py_INCREF(type); + type_to_sequence_map_.insert({type, check_result}); + } + } + + return check_result; + } + + private: + std::function ternary_predicate_; + mutex type_to_sequence_map_mu_; + std::unordered_map type_to_sequence_map_ + GUARDED_BY(type_to_sequence_map_mu_); +}; + +// Returns 1 if `o` is considered a mapping for the purposes of Flatten(). +// Returns 0 otherwise. +// Returns -1 if an error occurred. +int IsMappingHelper(PyObject* o) { + static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { + return PyObject_IsInstance(to_check, CollectionsMappingType); + }); + if (PyDict_Check(o)) return true; + if (TF_PREDICT_FALSE(CollectionsMappingType == nullptr)) { + PyErr_SetString( + PyExc_RuntimeError, + tensorflow::strings::StrCat( + "collections.Mapping type has not been set. " + "Please call RegisterMappingClass before using this module") + .c_str()); + return -1; + } + return check_cache->CachedLookup(o); +} + +// Returns 1 if `o` is considered a sequence for the purposes of Flatten(). +// Returns 0 otherwise. +// Returns -1 if an error occurred. +int IsSequenceHelper(PyObject* o) { + static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { + int is_instance = PyObject_IsInstance(to_check, CollectionsSequenceType); + + // Don't cache a failed is_instance check. + if (is_instance == -1) return -1; + + return static_cast(is_instance != 0 && !IsString(to_check)); + }); + // We treat dicts and other mappings as special cases of sequences. + if (IsMappingHelper(o)) return true; + if (PySet_Check(o) && !WarnedThatSetIsNotSequence) { + LOG(WARNING) << "Sets are not currently considered sequences, " + "but this may change in the future, " + "so consider avoiding using them."; + WarnedThatSetIsNotSequence = true; + } + if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) { + PyErr_SetString( + PyExc_RuntimeError, + tensorflow::strings::StrCat( + "collections.Sequence type has not been set. " + "Please call RegisterSequenceClass before using this module") + .c_str()); + return -1; + } + return check_cache->CachedLookup(o); +} + // Implements the same idea as tensorflow.util.nest._yield_value // During construction we check if the iterable is a dictionary. // If so, we construct a sequence from its sorted keys that will be used @@ -96,7 +233,12 @@ string PyObjectToString(PyObject* o) { // 'iterable' must not be modified while ValIterator is used. class ValIterator { public: - explicit ValIterator(PyObject* iterable) : dict_(nullptr), index_(0) { + explicit ValIterator(PyObject* iterable) + : dict_(nullptr), + mapping_(nullptr), + last_mapping_element_(nullptr), + seq_(nullptr), + index_(0) { if (PyDict_Check(iterable)) { dict_ = iterable; // PyDict_Keys returns a list, which can be used with @@ -108,6 +250,10 @@ class ValIterator { // bugs caused by mixing ordered and plain dicts (e.g., flattening // a dict but using a corresponding `OrderedDict` to pack it back). PyList_Sort(seq_); + } else if (IsMappingHelper(iterable)) { + mapping_ = iterable; + seq_ = MappingKeys(iterable); + PyList_Sort(seq_); } else { seq_ = PySequence_Fast(iterable, ""); } @@ -119,10 +265,15 @@ class ValIterator { // Return a borrowed reference to the next element from iterable. // Return nullptr when iteration is over. PyObject* next() { + if (TF_PREDICT_FALSE(seq_ == nullptr)) { + return nullptr; + } PyObject* element = nullptr; if (index_ < size_) { // Both PySequence_Fast_GET_ITEM and PyDict_GetItem return borrowed - // references. + // references. For general mappings, ValIterator keeps a reference to the + // last retrieved element (and decrefs it before producing the next + // element) to abstract away the borrowed/new difference. element = PySequence_Fast_GET_ITEM(seq_, index_); ++index_; if (dict_ != nullptr) { @@ -132,85 +283,32 @@ class ValIterator { "Dictionary was modified during iteration over it"); return nullptr; } + } else if (mapping_ != nullptr) { + element = PyObject_GetItem(mapping_, element); + if (element == nullptr) { + PyErr_SetString(PyExc_RuntimeError, + "Mapping was modified during iteration over it"); + return nullptr; + } + last_mapping_element_.reset(element); } } return element; } private: - PyObject* seq_; + // Special casing for things that pass PyDict_Check (faster, no Python calls) PyObject* dict_; + + // General mappings which have custom Python logic + PyObject* mapping_; + Safe_PyObjectPtr last_mapping_element_; + + PyObject* seq_; Py_ssize_t size_; Py_ssize_t index_; }; -mutex g_type_to_sequence_map(LINKER_INITIALIZED); -std::unordered_map* IsTypeSequenceMap() { - static auto* const m = new std::unordered_map; - return m; -} - -// Returns 1 if `o` is considered a sequence for the purposes of Flatten(). -// Returns 0 otherwise. -// Returns -1 if an error occurred. -int IsSequenceHelper(PyObject* o) { - if (PyDict_Check(o)) return true; - if (PySet_Check(o) && !WarnedThatSetIsNotSequence) { - LOG(WARNING) << "Sets are not currently considered sequences, " - "but this may change in the future, " - "so consider avoiding using them."; - WarnedThatSetIsNotSequence = true; - } - if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) { - PyErr_SetString( - PyExc_RuntimeError, - tensorflow::strings::StrCat( - "collections.Sequence type has not been set. " - "Please call RegisterSequenceClass before using this module") - .c_str()); - return -1; - } - - // Try not to return to Python - see if the type has already been seen - // before. - - auto* type_to_sequence_map = IsTypeSequenceMap(); - auto* type = Py_TYPE(o); - - { - mutex_lock l(g_type_to_sequence_map); - auto it = type_to_sequence_map->find(type); - if (it != type_to_sequence_map->end()) { - return it->second; - } - } - - // NOTE: We explicitly release the g_type_to_sequence_map mutex, - // because PyObject_IsInstance() may release the GIL, allowing another thread - // concurrent entry to this function. - int is_instance = PyObject_IsInstance(o, CollectionsSequenceType); - - // Don't cache a failed is_instance check. - if (is_instance == -1) return -1; - - bool is_sequence = static_cast(is_instance != 0 && !IsString(o)); - - // NOTE: This is never decref'd, but we don't want the type to get deleted - // as long as it is in the map. This should not be too much of a - // leak, as there should only be a relatively small number of types in the - // map, and an even smaller number that are eligible for decref. As a - // precaution, we limit the size of the map to 1024. - { - mutex_lock l(g_type_to_sequence_map); - if (type_to_sequence_map->size() < kMaxItemsInCache) { - Py_INCREF(type); - type_to_sequence_map->insert({type, is_sequence}); - } - } - - return is_sequence; -} - bool IsSparseTensorValueType(PyObject* o) { if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) { return false; @@ -226,21 +324,35 @@ int IsSequenceForDataHelper(PyObject* o) { bool GetNextValuesForDict(PyObject* nested, std::vector* next_values) { - std::vector result; - - PyObject* keys = PyDict_Keys(nested); - if (PyList_Sort(keys) == -1) return false; - Py_ssize_t size = PyList_Size(keys); + Safe_PyObjectPtr keys(PyDict_Keys(nested)); + if (PyList_Sort(keys.get()) == -1) return false; + Py_ssize_t size = PyList_Size(keys.get()); for (Py_ssize_t i = 0; i < size; ++i) { // We know that key and item will not be deleted because nested owns // a reference to them and callers of flatten must not modify nested // while the method is running. - PyObject* key = PyList_GET_ITEM(keys, i); + PyObject* key = PyList_GET_ITEM(keys.get(), i); PyObject* item = PyDict_GetItem(nested, key); Py_INCREF(item); next_values->emplace_back(item); } - Py_DECREF(keys); + return true; +} + +bool GetNextValuesForMapping(PyObject* nested, + std::vector* next_values) { + Safe_PyObjectPtr keys(MappingKeys(nested)); + if (keys.get() == nullptr) { + return false; + } + if (PyList_Sort(keys.get()) == -1) return false; + Py_ssize_t size = PyList_Size(keys.get()); + for (Py_ssize_t i = 0; i < size; ++i) { + PyObject* key = PyList_GET_ITEM(keys.get(), i); + // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference. + PyObject* item = PyObject_GetItem(nested, key); + next_values->emplace_back(item); + } return true; } @@ -265,6 +377,9 @@ bool GetNextValues(PyObject* nested, if (PyDict_Check(nested)) { // if nested is dictionary, sort it by key and recurse on each value return GetNextValuesForDict(nested, next_values); + } else if (IsMappingHelper(nested)) { + // same treatment as dictionaries, but for custom mapping types + return GetNextValuesForMapping(nested, next_values); } // iterate and recurse return GetNextValuesForIterable(nested, next_values); @@ -276,6 +391,9 @@ bool GetNextValuesForData(PyObject* nested, if (PyDict_Check(nested)) { // if nested is dictionary, sort it by key and recurse on each value return GetNextValuesForDict(nested, next_values); + } else if (IsMappingHelper(nested)) { + // same treatment as dictionaries, but for custom mapping types + return GetNextValuesForMapping(nested, next_values); } else if (IsSparseTensorValueType(nested)) { // if nested is a SparseTensorValue, just return itself as a single item Py_INCREF(nested); @@ -320,16 +438,26 @@ bool FlattenHelper( // 'dict1' and 'dict2' are assumed to be Python dictionaries. void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg, bool* is_type_error) { - PyObject* k1 = PyDict_Keys(dict1); - PyObject* k2 = PyDict_Keys(dict2); + Safe_PyObjectPtr k1(MappingKeys(dict1)); + if (PyErr_Occurred() || k1.get() == nullptr) { + *error_msg = + ("The two dictionaries don't have the same set of keys. Failed to " + "fetch keys."); + return; + } + Safe_PyObjectPtr k2(MappingKeys(dict2)); + if (PyErr_Occurred() || k2.get() == nullptr) { + *error_msg = + ("The two dictionaries don't have the same set of keys. Failed to " + "fetch keys."); + return; + } *is_type_error = false; *error_msg = tensorflow::strings::StrCat( "The two dictionaries don't have the same set of keys. " "First structure has keys ", - PyObjectToString(k1), ", while second structure has keys ", - PyObjectToString(k2)); - Py_DECREF(k1); - Py_DECREF(k2); + PyObjectToString(k1.get()), ", while second structure has keys ", + PyObjectToString(k2.get())); } // Returns true iff there were no "internal" errors. In other words, @@ -394,7 +522,14 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types, type2->tp_name); return true; } - } else if (type1 != type2) { + } else if (type1 != type2 + /* If both sequences are list types, don't complain. This allows + one to be a list subclass (e.g. _ListWrapper used for + automatic dependency tracking.) */ + && !(PyList_Check(o1) && PyList_Check(o2)) + /* Two mapping types will also compare equal, making _DictWrapper + and dict compare equal. */ + && !(IsMappingHelper(o1) && IsMappingHelper(o2))) { *is_type_error = true; *error_msg = tensorflow::strings::StrCat( "The two namedtuples don't have the same sequence type. " @@ -405,7 +540,7 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types, return true; } - if (PyDict_Check(o1)) { + if (PyDict_Check(o1) && PyDict_Check(o2)) { if (PyDict_Size(o1) != PyDict_Size(o2)) { SetDifferentKeysError(o1, o2, error_msg, is_type_error); return true; @@ -419,6 +554,24 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types, return true; } } + } else if (IsMappingHelper(o1)) { + // Fallback for custom mapping types. Instead of using PyDict methods + // which stay in C, we call iter(o1). + if (PyMapping_Size(o1) != PyMapping_Size(o2)) { + SetDifferentKeysError(o1, o2, error_msg, is_type_error); + return true; + } + + Safe_PyObjectPtr iter(PyObject_GetIter(o1)); + PyObject* key; + while ((key = PyIter_Next(iter.get())) != nullptr) { + if (!PyMapping_HasKey(o2, key)) { + SetDifferentKeysError(o1, o2, error_msg, is_type_error); + Py_DECREF(key); + return true; + } + Py_DECREF(key); + } } } @@ -466,6 +619,19 @@ void RegisterSequenceClass(PyObject* sequence_class) { CollectionsSequenceType = sequence_class; } +void RegisterMappingClass(PyObject* mapping_class) { + if (!PyType_Check(mapping_class)) { + PyErr_SetString( + PyExc_TypeError, + tensorflow::strings::StrCat( + "Expecting a class definition for `collections.Mapping`. Got ", + Py_TYPE(mapping_class)->tp_name) + .c_str()); + return; + } + CollectionsMappingType = mapping_class; +} + void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) { if (!PyType_Check(sparse_tensor_value_class)) { PyErr_SetString( @@ -593,6 +759,11 @@ PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types) { string error_msg; bool is_type_error = false; AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error); + if (PyErr_Occurred()) { + // Don't hide Python exceptions while checking (e.g. errors fetching keys + // from custom mappings). + return nullptr; + } if (!error_msg.empty()) { PyErr_SetString( is_type_error ? PyExc_TypeError : PyExc_ValueError, diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h index 70efc10c9abe7c57da61311bb2eb7ae362a48e3d..41dcc969f88f8bb81b70b734cbf41fe10135a9c0 100644 --- a/tensorflow/python/util/util.h +++ b/tensorflow/python/util/util.h @@ -118,7 +118,9 @@ PyObject* Flatten(PyObject* nested); // the type from the module. This approach also requires some trigger from // Python so that we know that Python interpreter had been initialzied. void RegisterSequenceClass(PyObject* sequence_class); -// Similar to the above function, except for the +// Like RegisterSequenceClass, but for collections.Mapping. +void RegisterMappingClass(PyObject* mapping_class); +// Similar to the above functions, except for the // sparse_tensor.SparseTensorValue class. void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class); diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i index 9f3b11b982bb0d52f903b09975cc7029fa8cb013..6ad148429541a6855a4f1da5a5c89d6479f53f39 100644 --- a/tensorflow/python/util/util.i +++ b/tensorflow/python/util/util.i @@ -31,6 +31,9 @@ limitations under the License. %unignore tensorflow::swig::RegisterSequenceClass; %noexception tensorflow::swig::RegisterSequenceClass; +%unignore tensorflow::swig::RegisterMappingClass; +%noexception tensorflow::swig::RegisterMappingClass; + %unignore tensorflow::swig::RegisterSparseTensorValueClass; %noexception tensorflow::swig::RegisterSparseTensorValueClass; diff --git a/tensorflow/security/advisory/tfsa-2018-001.md b/tensorflow/security/advisory/tfsa-2018-001.md index bb97543a21988b4370ddac912102add6a10e2b35..1966789c8467539ef7f19e281b3a4acfbaace6ae 100644 --- a/tensorflow/security/advisory/tfsa-2018-001.md +++ b/tensorflow/security/advisory/tfsa-2018-001.md @@ -22,7 +22,7 @@ TensorFlow 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0 ### Mitigation We have patched the vulnerability in GitHub commit -[49f73c55](https://github.com/tensorflow/tensorflow/commit/49f73c55d56edffebde4bca4a407ad69c1cae4333c55). +[49f73c55](https://github.com/tensorflow/tensorflow/commit/49f73c55d56edffebde4bca4a407ad69c1cae433). If users are running TensorFlow in production or on untrusted data, they are encouraged to apply this patch. diff --git a/tensorflow/security/index.md b/tensorflow/security/index.md index ea39e17ab2bb417bba1ebe4a589833309fc2c626..0f176151c2c4527d60c0cb451d33c9206a50bd81 100644 --- a/tensorflow/security/index.md +++ b/tensorflow/security/index.md @@ -4,7 +4,7 @@ We regularly publish security advisories about using TensorFlow. *Note*: In conjunction with these security advisories, we strongly encourage TensorFlow users to read and understand TensorFlow's security model as outlined -in (https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md)[SECURITY.md]. +in [SECURITY.md](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md). | Advisory Number | Type | Versions affected | Reported by | Additional Information | |-----------------|--------------------|:-----------------:|-----------------------|-----------------------------| diff --git a/tensorflow/stream_executor/BUILD b/tensorflow/stream_executor/BUILD index c68cda01002b1c5bbc2facb95b1eba214fbad7cb..d4d97087ba48087acf2313ca16fa2144bca649be 100644 --- a/tensorflow/stream_executor/BUILD +++ b/tensorflow/stream_executor/BUILD @@ -2,6 +2,7 @@ licenses(["restricted"]) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") +load("//tensorflow:tensorflow.bzl", "cc_header_only_library") STREAM_EXECUTOR_HEADERS = glob([ "*.h", @@ -29,11 +30,11 @@ cc_library( hdrs = STREAM_EXECUTOR_HEADERS, linkopts = select({ "//tensorflow:freebsd": [], + "//tensorflow:windows": [], "//conditions:default": ["-ldl"], }), visibility = ["//visibility:public"], deps = [ - "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "@local_config_cuda//cuda:cuda_headers", @@ -48,11 +49,18 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core:ptr_util", - "//tensorflow/compiler/xla:statusor", "@local_config_cuda//cuda:cuda_headers", ] + if_static([":stream_executor_impl"]), ) +cc_header_only_library( + name = "stream_executor_headers_lib", + visibility = ["//visibility:public"], + deps = [ + ":stream_executor", + ], +) + cc_library( name = "cuda_platform", srcs = if_cuda_is_configured( @@ -72,6 +80,7 @@ cc_library( }), linkopts = select({ "//tensorflow:freebsd": [], + "//tensorflow:windows": [], "//conditions:default": ["-ldl"], }), visibility = ["//visibility:public"], diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index ea87744b225215ceb24b926f1ef7bace017cb2b8..7f851e36462eab5e7c58246bbab551f6834ae87a 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -1121,6 +1121,40 @@ class BlasSupport { const port::ArraySlice> *> &c, int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0; + // Batched gemm with strides instead of pointer arrays. + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory &a, + int lda, int64 stride_a, const DeviceMemory &b, int ldb, + int64 stride_b, float beta, DeviceMemory *c, int ldc, + int64 stride_c, int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory &a, int lda, + int64 stride_a, const DeviceMemory &b, int ldb, int64 stride_b, + float beta, DeviceMemory *c, int ldc, int64 stride_c, + int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, double alpha, const DeviceMemory &a, int lda, + int64 stride_a, const DeviceMemory &b, int ldb, int64 stride_b, + double beta, DeviceMemory *c, int ldc, int64 stride_c, + int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex alpha, + const DeviceMemory> &a, int lda, int64 stride_a, + const DeviceMemory> &b, int ldb, int64 stride_b, + std::complex beta, DeviceMemory> *c, int ldc, + int64 stride_c, int batch_count) = 0; + virtual bool DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex alpha, + const DeviceMemory> &a, int lda, int64 stride_a, + const DeviceMemory> &b, int ldb, int64 stride_b, + std::complex beta, DeviceMemory> *c, int ldc, + int64 stride_c, int batch_count) = 0; + // Computes a matrix-matrix product where one input matrix is Hermitian: // // c <- alpha * a * b + beta * c, @@ -1990,6 +2024,38 @@ class BlasSupport { int ldb, std::complex beta, \ const port::ArraySlice> *> &c, \ int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, float alpha, \ + const DeviceMemory &a, int lda, int64 stride_a, \ + const DeviceMemory &b, int ldb, int64 stride_b, float beta, \ + DeviceMemory *c, int ldc, int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory &a, \ + int lda, int64 stride_a, const DeviceMemory &b, int ldb, \ + int64 stride_b, float beta, DeviceMemory *c, int ldc, \ + int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, double alpha, \ + const DeviceMemory &a, int lda, int64 stride_a, \ + const DeviceMemory &b, int ldb, int64 stride_b, double beta, \ + DeviceMemory *c, int ldc, int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, std::complex alpha, \ + const DeviceMemory> &a, int lda, int64 stride_a, \ + const DeviceMemory> &b, int ldb, int64 stride_b, \ + std::complex beta, DeviceMemory> *c, int ldc, \ + int64 stride_c, int batch_count); \ + bool DoBlasGemmStridedBatched( \ + Stream *stream, blas::Transpose transa, blas::Transpose transb, \ + uint64 m, uint64 n, uint64 k, std::complex alpha, \ + const DeviceMemory> &a, int lda, int64 stride_a, \ + const DeviceMemory> &b, int ldb, int64 stride_b, \ + std::complex beta, DeviceMemory> *c, \ + int ldc, int64 stride_c, int batch_count); \ bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ uint64 m, uint64 n, std::complex alpha, \ const DeviceMemory> &a, int lda, \ diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 92c1a5fc07190e77a6ecc814145ec06da303e27e..ab7091b3f54727874097f3887cfb63376ed34c9a 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -279,6 +279,10 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmEx) #if CUDA_VERSION >= 8000 STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmEx) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmStridedBatched) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasDgemmStridedBatched) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasCgemmStridedBatched) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasZgemmStridedBatched) #endif #if CUDA_VERSION >= 9000 @@ -288,6 +292,7 @@ STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode) #if CUDA_VERSION >= 9010 STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmBatchedEx) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmStridedBatchedEx) #endif } // namespace wrap @@ -643,7 +648,7 @@ bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream, } #endif cublasStatus_t ret = cublas_func(parent_, blas_, args...); - if (err_on_failure && ret != CUBLAS_STATUS_SUCCESS) { + if ((err_on_failure || VLOG_IS_ON(3)) && ret != CUBLAS_STATUS_SUCCESS) { LOG(ERROR) << "failed to run cuBLAS routine " << cublas_func.kName << ": " << ToString(ret); } @@ -1865,7 +1870,7 @@ bool CUDABlas::DoBlasGemm( stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor); - // GPUs < sm_70 don't support Volta hardware. + // GPUs < sm_70 don't support tensor ops. if (cc_major >= 7 && TensorOpMathEnabled()) { use_tensor_ops = true; } @@ -2139,6 +2144,10 @@ static bool UsesTensorOps(blas::AlgorithmType algo) { template static bool TensorOpsAvailable(int cc_major) { #if CUDA_VERSION >= 9000 + // cublas *does* allow tensor ops on inputs that are not fp16, so this is not + // strictly correct. We can't simply enable it, though, as that would change + // clients' behavior significantly: Using tensor ops on fp32 inputs cause them + // to be rounded to fp16. if (cc_major >= 7 && TensorOpMathEnabled() && std::is_same::value) { return true; @@ -2160,16 +2169,30 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( if (stream->parent()->GetDeviceDescription().cuda_compute_capability( &cc_major, &cc_minor) && cc_major < 5) { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false because sm" << cc_major + << cc_minor << " devices don't support explicit gemm algorithms."; return false; } if (UsesTensorOps(algorithm) && !TensorOpsAvailable(cc_major)) { + if (std::is_same::value) { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm " + << algorithm + << " uses tensor ops, but tensor ops are not available in sm" + << cc_major << "X devices."; + } else { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false because algorithm " + << algorithm + << " uses tensor ops, but the input data type is not fp16."; + } return false; } // Either both 'alpha' and 'beta' need to be pointers to device memory, or // they need to be both host scalars. if (alpha.is_pointer() != beta.is_pointer()) { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false because one of `alpha` " + "and `beta` is a pointer, but the other is not."; return false; } @@ -2177,16 +2200,23 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( if (output_profile_result != nullptr) { timer.reset(new CUDATimer(parent_)); if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false because " + "output_profile_result was given, but we were unable to " + "create a CUDATimer."; return false; } } // Return false if we might be hitting a cuBLAS bug that produces the wrong // result. See nvbugs/2156201, b/79126339. - if (CUDA_VERSION < 9020 && algorithm != CUBLAS_GEMM_ALGO12 && +#if CUDA_VERSION >= 9000 && CUDA_VERSION < 9020 + if ((algorithm == CUBLAS_GEMM_DEFAULT || algorithm >= CUBLAS_GEMM_ALGO13) && std::max({m, n, k}) >= 2097153 && cc_major < 7) { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false to work around cudnn " + "<9.2 bug with m, n, or k >= 2097153. See b/79126339."; return false; } +#endif cudaDataType_t cuda_in_type = CUDADataType::type; // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast, @@ -2209,6 +2239,8 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( // CUDATimer will CHECK-fail if we Stop() it while the stream is in an error // state. if (!timer->Stop(AsCUDAStream(stream))) { + VLOG(2) << "DoBlasGemmWithAlgorithm returning false; unable to stop " + "CUDATimer."; return false; } output_profile_result->set_is_valid(true); @@ -2221,26 +2253,60 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl( bool CUDABlas::GetBlasGemmAlgorithms( std::vector *out_algorithms) { -// cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx) -// were first introduced in CUDA 8. -// Note that when CUDA version and compute capability is not sufficient, we -// still return the out_algorithms. Caller needs to make sure that in this case, -// the returned vector is empty. - for (cublasGemmAlgo_t algo : { - CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1, - CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4, - CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7, + // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx) + // were first introduced in CUDA 8. + // + // Note that when CUDA version and compute capability is not sufficient, we + // still return the out_algorithms. Caller needs to make sure that in this + // case, the returned vector is empty. + *out_algorithms = { + CUBLAS_GEMM_DFALT, + CUBLAS_GEMM_ALGO0, + CUBLAS_GEMM_ALGO1, + CUBLAS_GEMM_ALGO2, + CUBLAS_GEMM_ALGO3, + CUBLAS_GEMM_ALGO4, + CUBLAS_GEMM_ALGO5, + CUBLAS_GEMM_ALGO6, + CUBLAS_GEMM_ALGO7, #if CUDA_VERSION >= 9000 - CUBLAS_GEMM_ALGO8, CUBLAS_GEMM_ALGO9, CUBLAS_GEMM_ALGO10, - CUBLAS_GEMM_ALGO11, CUBLAS_GEMM_ALGO12, CUBLAS_GEMM_ALGO13, - CUBLAS_GEMM_ALGO14, CUBLAS_GEMM_ALGO15, CUBLAS_GEMM_ALGO16, - CUBLAS_GEMM_ALGO17, CUBLAS_GEMM_DFALT_TENSOR_OP, - CUBLAS_GEMM_ALGO0_TENSOR_OP, CUBLAS_GEMM_ALGO1_TENSOR_OP, - CUBLAS_GEMM_ALGO2_TENSOR_OP + CUBLAS_GEMM_ALGO8, + CUBLAS_GEMM_ALGO9, + CUBLAS_GEMM_ALGO10, + CUBLAS_GEMM_ALGO11, + CUBLAS_GEMM_ALGO12, + CUBLAS_GEMM_ALGO13, + CUBLAS_GEMM_ALGO14, + CUBLAS_GEMM_ALGO15, + CUBLAS_GEMM_ALGO16, + CUBLAS_GEMM_ALGO17, + CUBLAS_GEMM_DFALT_TENSOR_OP, + CUBLAS_GEMM_ALGO0_TENSOR_OP, + CUBLAS_GEMM_ALGO1_TENSOR_OP, + CUBLAS_GEMM_ALGO2_TENSOR_OP, + CUBLAS_GEMM_ALGO3_TENSOR_OP, + CUBLAS_GEMM_ALGO4_TENSOR_OP, #endif - }) { - out_algorithms->push_back(algo); - } +#if CUDA_VERSION >= 9200 + CUBLAS_GEMM_ALGO18, + CUBLAS_GEMM_ALGO19, + CUBLAS_GEMM_ALGO20, + CUBLAS_GEMM_ALGO21, + CUBLAS_GEMM_ALGO22, + CUBLAS_GEMM_ALGO23, + CUBLAS_GEMM_ALGO5_TENSOR_OP, + CUBLAS_GEMM_ALGO6_TENSOR_OP, + CUBLAS_GEMM_ALGO7_TENSOR_OP, + CUBLAS_GEMM_ALGO8_TENSOR_OP, + CUBLAS_GEMM_ALGO9_TENSOR_OP, + CUBLAS_GEMM_ALGO10_TENSOR_OP, + CUBLAS_GEMM_ALGO11_TENSOR_OP, + CUBLAS_GEMM_ALGO12_TENSOR_OP, + CUBLAS_GEMM_ALGO13_TENSOR_OP, + CUBLAS_GEMM_ALGO14_TENSOR_OP, + CUBLAS_GEMM_ALGO15_TENSOR_OP, +#endif + }; return true; } @@ -2562,6 +2628,119 @@ bool CUDABlas::DoBlasGemmBatched( return status.ok(); } +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory &a, + int lda, int64 stride_a, const DeviceMemory &b, int ldb, + int64 stride_b, float beta, DeviceMemory *c, int ldc, + int64 stride_c, int batch_count) { + bool use_tensor_ops = false; +#if CUDA_VERSION >= 9000 + int cc_major, cc_minor; + if (stream->parent()->GetDeviceDescription().cuda_compute_capability( + &cc_major, &cc_minor)) { + // GPUs < sm_70 don't support tensor ops. + if (cc_major >= 7 && TensorOpMathEnabled()) { + use_tensor_ops = true; + } +#if CUDA_VERSION >= 9010 + if (cc_major >= 5) { + cublasGemmAlgo_t algo = + (use_tensor_ops ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT); + bool ok = DoBlasInternalImpl( + wrap::cublasGemmStridedBatchedEx, stream, + true /* = pointer_mode_host */, true /* = err_on_failure */, + use_tensor_ops, CUDABlasTranspose(transa), CUDABlasTranspose(transb), + m, n, k, &alpha, CUDAMemory(a), CUDA_R_16F, lda, stride_a, + CUDAMemory(b), CUDA_R_16F, ldb, stride_b, &beta, CUDAMemoryMutable(c), + CUDA_R_16F, ldc, stride_c, batch_count, CUDA_R_32F, algo); + if (ok) { + return true; + } + LOG(ERROR) << "failed BLAS call, see log for details"; + return false; + } +#endif + } +#endif + // Either CUDA_VERSION < 9.1 or SM < 5.0. Fall back to a loop. + for (int batch = 0; batch < batch_count; ++batch) { + const auto *a_matrix = + reinterpret_cast(CUDAMemory(a) + batch * stride_a); + const auto *b_matrix = + reinterpret_cast(CUDAMemory(b) + batch * stride_b); + auto *c_matrix = + reinterpret_cast<__half *>(CUDAMemoryMutable(c) + batch * stride_c); + bool ok = DoBlasInternalImpl( + wrap::cublasSgemmEx, stream, true /* = pointer_mode_host */, + true /* = err_on_failure= */, use_tensor_ops, CUDABlasTranspose(transa), + CUDABlasTranspose(transb), m, n, k, &alpha, a_matrix, SE_CUDA_DATA_HALF, + lda, b_matrix, SE_CUDA_DATA_HALF, ldb, &beta, c_matrix, + SE_CUDA_DATA_HALF, ldc); + if (!ok) { + LOG(ERROR) << "failed BLAS call, see log for details"; + return false; + } + } + return true; +} + +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, float alpha, const DeviceMemory &a, int lda, + int64 stride_a, const DeviceMemory &b, int ldb, int64 stride_b, + float beta, DeviceMemory *c, int ldc, int64 stride_c, + int batch_count) { + return DoBlasInternal( + wrap::cublasSgemmStridedBatched, stream, true /* = pointer_mode_host */, + CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha, + CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta, + CUDAMemoryMutable(c), ldc, stride_c, batch_count); +} + +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, double alpha, const DeviceMemory &a, int lda, + int64 stride_a, const DeviceMemory &b, int ldb, int64 stride_b, + double beta, DeviceMemory *c, int ldc, int64 stride_c, + int batch_count) { + return DoBlasInternal( + wrap::cublasDgemmStridedBatched, stream, true /* = pointer_mode_host */, + CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha, + CUDAMemory(a), lda, stride_a, CUDAMemory(b), ldb, stride_b, &beta, + CUDAMemoryMutable(c), ldc, stride_c, batch_count); +} + +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex alpha, + const DeviceMemory> &a, int lda, int64 stride_a, + const DeviceMemory> &b, int ldb, int64 stride_b, + std::complex beta, DeviceMemory> *c, int ldc, + int64 stride_c, int batch_count) { + return DoBlasInternal( + wrap::cublasCgemmStridedBatched, stream, true /* = pointer_mode_host */, + CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, + CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a, + CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta), + CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count); +} + +bool CUDABlas::DoBlasGemmStridedBatched( + Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, + uint64 n, uint64 k, std::complex alpha, + const DeviceMemory> &a, int lda, int64 stride_a, + const DeviceMemory> &b, int ldb, int64 stride_b, + std::complex beta, DeviceMemory> *c, int ldc, + int64 stride_c, int batch_count) { + return DoBlasInternal( + wrap::cublasZgemmStridedBatched, stream, true /* = pointer_mode_host */, + CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, + CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, stride_a, + CUDAComplex(CUDAMemory(b)), ldb, stride_b, CUDAComplex(&beta), + CUDAComplex(CUDAMemoryMutable(c)), ldc, stride_c, batch_count); +} + bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, uint64 m, uint64 n, std::complex alpha, diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc index 10f6d21d54a15b46ebfc6f8ad32e3e908fab9a96..124d5905b91cbf839437e763728cc76ad0d671dc 100644 --- a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc +++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc @@ -24,12 +24,17 @@ limitations under the License. #include #include #include +#ifdef __APPLE__ +#include +#include +#else #if !defined(PLATFORM_WINDOWS) #include #include #include #endif #include +#endif #include #include #include @@ -49,7 +54,9 @@ limitations under the License. namespace stream_executor { namespace cuda { -#if !defined(PLATFORM_WINDOWS) +#ifdef __APPLE__ +static const CFStringRef kDriverKextIdentifier = CFSTR("com.nvidia.CUDA"); +#elif !defined(PLATFORM_WINDOWS) static const char *kDriverVersionPath = "/proc/driver/nvidia/version"; #endif @@ -114,7 +121,31 @@ string Diagnostician::GetDevNodePath(int dev_node_ordinal) { } void Diagnostician::LogDiagnosticInformation() { -#if !defined(PLATFORM_WINDOWS) +#ifdef __APPLE__ + CFStringRef kext_ids[1]; + kext_ids[0] = kDriverKextIdentifier; + CFArrayRef kext_id_query = CFArrayCreate(nullptr, (const void **)kext_ids, 1, + &kCFTypeArrayCallBacks); + CFDictionaryRef kext_infos = + KextManagerCopyLoadedKextInfo(kext_id_query, nullptr); + CFRelease(kext_id_query); + + CFDictionaryRef cuda_driver_info = nullptr; + if (CFDictionaryGetValueIfPresent(kext_infos, kDriverKextIdentifier, + (const void **)&cuda_driver_info)) { + bool started = CFBooleanGetValue((CFBooleanRef)CFDictionaryGetValue( + cuda_driver_info, CFSTR("OSBundleStarted"))); + if (!started) { + LOG(INFO) << "kernel driver is installed, but does not appear to be " + "running on this host " + << "(" << port::Hostname() << ")"; + } + } else { + LOG(INFO) << "kernel driver does not appear to be installed on this host " + << "(" << port::Hostname() << ")"; + } + CFRelease(kext_infos); +#elif !defined(PLATFORM_WINDOWS) if (access(kDriverVersionPath, F_OK) != 0) { LOG(INFO) << "kernel driver does not appear to be running on this host " << "(" << port::Hostname() << "): " @@ -168,7 +199,8 @@ void Diagnostician::LogDiagnosticInformation() { << DriverVersionStatusToString(kernel_version); #endif -#if !defined(PLATFORM_WINDOWS) + // OS X kernel driver does not report version accurately +#if !defined(__APPLE__) && !defined(PLATFORM_WINDOWS) if (kernel_version.ok() && dso_version.ok()) { WarnOnDsoKernelMismatch(dso_version, kernel_version); } @@ -182,6 +214,29 @@ port::StatusOr Diagnostician::FindDsoVersion() { port::error::NOT_FOUND, "was unable to find libcuda.so DSO loaded into this program")); +#if defined(__APPLE__) + // OSX CUDA libraries have names like: libcuda_310.41.15_mercury.dylib + const string prefix("libcuda_"); + const string suffix("_mercury.dylib"); + for (uint32_t image_index = 0; image_index < _dyld_image_count(); + ++image_index) { + const string path(_dyld_get_image_name(image_index)); + const size_t suffix_pos = path.rfind(suffix); + const size_t prefix_pos = path.rfind(prefix, suffix_pos); + if (prefix_pos == string::npos || suffix_pos == string::npos) { + // no match + continue; + } + const size_t start = prefix_pos + prefix.size(); + if (start >= suffix_pos) { + // version not included + continue; + } + const size_t length = suffix_pos - start; + const string version = path.substr(start, length); + result = StringToDriverVersion(version); + } +#else #if !defined(PLATFORM_WINDOWS) && !defined(ANDROID_TEGRA) // Callback used when iterating through DSOs. Looks for the driver-interfacing // DSO and yields its version number into the callback data, when found. @@ -214,6 +269,7 @@ port::StatusOr Diagnostician::FindDsoVersion() { }; dl_iterate_phdr(iterate_phdr, &result); +#endif #endif return result; @@ -259,7 +315,41 @@ void Diagnostician::WarnOnDsoKernelMismatch( port::StatusOr Diagnostician::FindKernelDriverVersion() { -#if defined(PLATFORM_WINDOWS) +#if defined(__APPLE__) + CFStringRef kext_ids[1]; + kext_ids[0] = kDriverKextIdentifier; + CFArrayRef kext_id_query = CFArrayCreate(nullptr, (const void **)kext_ids, 1, + &kCFTypeArrayCallBacks); + CFDictionaryRef kext_infos = + KextManagerCopyLoadedKextInfo(kext_id_query, nullptr); + CFRelease(kext_id_query); + + CFDictionaryRef cuda_driver_info = nullptr; + if (CFDictionaryGetValueIfPresent(kext_infos, kDriverKextIdentifier, + (const void **)&cuda_driver_info)) { + // NOTE: OSX CUDA driver does not currently store the same driver version + // in kCFBundleVersionKey as is returned by cuDriverGetVersion + CFRelease(kext_infos); + const CFStringRef str = (CFStringRef)CFDictionaryGetValue( + cuda_driver_info, kCFBundleVersionKey); + const char *version = CFStringGetCStringPtr(str, kCFStringEncodingUTF8); + + // version can be NULL in which case treat it as empty string + // see + // https://developer.apple.com/library/mac/documentation/CoreFoundation/Conceptual/CFStrings/Articles/AccessingContents.html#//apple_ref/doc/uid/20001184-100980-TPXREF112 + if (version == NULL) { + return StringToDriverVersion(""); + } + return StringToDriverVersion(version); + } + CFRelease(kext_infos); + auto status = port::Status( + port::error::INTERNAL, + port::StrCat( + "failed to read driver bundle version: ", + CFStringGetCStringPtr(kDriverKextIdentifier, kCFStringEncodingUTF8))); + return status; +#elif defined(PLATFORM_WINDOWS) auto status = port::Status(port::error::UNIMPLEMENTED, "kernel reported driver version not implemented on Windows"); diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index d4f2fd262544a6a2771e17a712ae911c7249a7d5..55408ab9ab77fdca4a4eaa7b3ffd1223afc573e4 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -322,6 +322,7 @@ port::Status GetLoadedCudnnVersion(CudnnVersion* version) { CudnnSupport::CudnnSupport(CUDAExecutor* parent) : parent_(parent) {} port::Status CudnnSupport::Init() { + ScopedActivateExecutorContext context(parent_); cudnnHandle_t cudnn_handle = nullptr; auto status = cudnnCreate(&cudnn_handle); if (status == CUDNN_STATUS_SUCCESS) { @@ -791,6 +792,11 @@ class CudnnActivationDescriptor { double relu_ceiling = 0.0; cudnnActivationMode_t mode; switch (activation_mode) { +#if CUDNN_VERSION >= 7100 + case dnn::ActivationMode::kNone: + mode = CUDNN_ACTIVATION_IDENTITY; + break; +#endif case dnn::ActivationMode::kRelu6: relu_ceiling = 6.0; mode = CUDNN_ACTIVATION_CLIPPED_RELU; @@ -1980,15 +1986,14 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn, port::StatusOr> AllocateCudnnConvolutionForwardWorkspace( Stream* stream, const CudnnHandle& cudnn, - const dnn::AlgorithmDesc& algorithm_desc, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv, - const CudnnTensorDescriptor& output_nd, + const CudnnTensorDescriptor& output_nd, dnn::AlgorithmDesc* algorithm_desc, ScratchAllocator* scratch_allocator) { // TODO(csigg): This has side effects on the convolution descriptor. It is // functionally correct because the convolution is run with the algorithm of // the last call to this function, but should be fixed anyway. - conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); + conv.set_use_tensor_op_math(algorithm_desc->tensor_ops_enabled()); // Query the size of the workspace and allocate it. size_t size_in_bytes; @@ -1996,8 +2001,14 @@ port::StatusOr> AllocateCudnnConvolutionForwardWorkspace( cudnn.handle(), /*xDesc=*/input_nd.handle(), /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(), - /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc), + /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(*algorithm_desc), /*sizeInBytes=*/&size_in_bytes)); + + if (TF_PREDICT_FALSE(!algorithm_desc)) { + return port::Status(port::error::INVALID_ARGUMENT, + "No AlgorithmDesc provided"); + } + algorithm_desc->set_scratch_size(size_in_bytes); int64 size_in_bytes_int64 = size_in_bytes; if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { @@ -2022,15 +2033,14 @@ port::StatusOr> AllocateCudnnConvolutionForwardWorkspace( port::StatusOr> AllocateCudnnConvolutionBackwardDataWorkspace( Stream* stream, const CudnnHandle& cudnn, - const dnn::AlgorithmDesc& algorithm_desc, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv, - const CudnnTensorDescriptor& output_nd, + const CudnnTensorDescriptor& output_nd, dnn::AlgorithmDesc* algorithm_desc, ScratchAllocator* scratch_allocator) { // TODO(csigg): This has side effects on the convolution descriptor. It is // functionally correct because the convolution is run with the algorithm of // the last call to this function, but should be fixed anyway. - conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); + conv.set_use_tensor_op_math(algorithm_desc->tensor_ops_enabled()); // Query the size of the workspace and allocate it. size_t size_in_bytes; @@ -2040,8 +2050,14 @@ AllocateCudnnConvolutionBackwardDataWorkspace( /*dyDesc=*/output_nd.handle(), /*convDesc=*/conv.handle(), /*dxDesc=*/input_nd.handle(), - /*algo=*/ToConvBackwardDataAlgo(algorithm_desc), + /*algo=*/ToConvBackwardDataAlgo(*algorithm_desc), /*sizeInBytes=*/&size_in_bytes)); + + if (TF_PREDICT_FALSE(!algorithm_desc)) { + return port::Status(port::error::INVALID_ARGUMENT, + "No AlgorithmDesc provided"); + } + algorithm_desc->set_scratch_size(size_in_bytes); int64 size_in_bytes_int64 = size_in_bytes; if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { @@ -2066,15 +2082,14 @@ AllocateCudnnConvolutionBackwardDataWorkspace( port::StatusOr> AllocateCudnnConvolutionBackwardFilterWorkspace( Stream* stream, const CudnnHandle& cudnn, - const dnn::AlgorithmDesc& algorithm_desc, const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv, - const CudnnTensorDescriptor& output_nd, + const CudnnTensorDescriptor& output_nd, dnn::AlgorithmDesc* algorithm_desc, ScratchAllocator* scratch_allocator) { // TODO(csigg): This has side effects on the convolution descriptor. It is // functionally correct because the convolution is run with the algorithm of // the last call to this function, but should be fixed anyway. - conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); + conv.set_use_tensor_op_math(algorithm_desc->tensor_ops_enabled()); // Query the size of the workspace and allocate it. size_t size_in_bytes; @@ -2084,8 +2099,14 @@ AllocateCudnnConvolutionBackwardFilterWorkspace( /*dyDesc=*/output_nd.handle(), /*convDesc=*/conv.handle(), /*gradDesc=*/filter.handle(), - /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc), + /*algo=*/ToConvBackwardFilterAlgo(*algorithm_desc), /*sizeInBytes=*/&size_in_bytes)); + + if (TF_PREDICT_FALSE(!algorithm_desc)) { + return port::Status(port::error::INVALID_ARGUMENT, + "No AlgorithmDesc provided"); + } + algorithm_desc->set_scratch_size(size_in_bytes); int64 size_in_bytes_int64 = size_in_bytes; if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { @@ -2132,7 +2153,7 @@ port::StatusOr GetCudnnConvolutionForwardAlgorithm( } auto scratch_or = AllocateCudnnConvolutionForwardWorkspace( - stream, cudnn, algo_desc, input_nd, filter, conv, output_nd, + stream, cudnn, input_nd, filter, conv, output_nd, &algo_desc, scratch_allocator); if (scratch_or.ok()) { @@ -2149,11 +2170,11 @@ port::StatusOr GetCudnnConvolutionForwardAlgorithm( "while a secondary algorithm is not provided."); } - SE_ASSIGN_OR_RETURN( - *scratch, AllocateCudnnConvolutionForwardWorkspace( - stream, cudnn, algorithm_config.algorithm_no_scratch(), - input_nd, filter, conv, output_nd, scratch_allocator)); - return algorithm_config.algorithm_no_scratch(); + algo_desc = algorithm_config.algorithm_no_scratch(); + SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace( + stream, cudnn, input_nd, filter, conv, + output_nd, &algo_desc, scratch_allocator)); + return algo_desc; } port::StatusOr GetCudnnConvolutionBackwardDataAlgorithm( @@ -2181,7 +2202,7 @@ port::StatusOr GetCudnnConvolutionBackwardDataAlgorithm( } auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace( - stream, cudnn, algo_desc, input_nd, filter, conv, output_nd, + stream, cudnn, input_nd, filter, conv, output_nd, &algo_desc, scratch_allocator); if (scratch_or.ok()) { @@ -2198,11 +2219,11 @@ port::StatusOr GetCudnnConvolutionBackwardDataAlgorithm( "while a secondary algorithm is not provided."); } - SE_ASSIGN_OR_RETURN( - *scratch, AllocateCudnnConvolutionBackwardDataWorkspace( - stream, cudnn, algorithm_config.algorithm_no_scratch(), - input_nd, filter, conv, output_nd, scratch_allocator)); - return algorithm_config.algorithm_no_scratch(); + algo_desc = algorithm_config.algorithm_no_scratch(); + SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace( + stream, cudnn, input_nd, filter, conv, + output_nd, &algo_desc, scratch_allocator)); + return algo_desc; } port::StatusOr GetCudnnConvolutionBackwardFilterAlgorithm( @@ -2230,7 +2251,7 @@ port::StatusOr GetCudnnConvolutionBackwardFilterAlgorithm( } auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace( - stream, cudnn, algo_desc, input_nd, filter, conv, output_nd, + stream, cudnn, input_nd, filter, conv, output_nd, &algo_desc, scratch_allocator); if (scratch_or.ok()) { @@ -2247,11 +2268,11 @@ port::StatusOr GetCudnnConvolutionBackwardFilterAlgorithm( "while a secondary algorithm is not provided."); } - SE_ASSIGN_OR_RETURN(*scratch, - AllocateCudnnConvolutionBackwardFilterWorkspace( - stream, cudnn, algorithm_config.algorithm(), input_nd, - filter, conv, output_nd, scratch_allocator)); - return algorithm_config.algorithm_no_scratch(); + algo_desc = algorithm_config.algorithm_no_scratch(); + SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace( + stream, cudnn, input_nd, filter, conv, + output_nd, &algo_desc, scratch_allocator)); + return algo_desc; } // A helper class to set env-vars and choose options for cudnn-related @@ -2480,10 +2501,11 @@ port::Status CudnnSupport::DoFusedConvolveImpl( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - if (activation_mode != dnn::ActivationMode::kRelu) { + if (activation_mode != dnn::ActivationMode::kRelu && + activation_mode != dnn::ActivationMode::kNone) { return port::Status(port::error::INVALID_ARGUMENT, "cudnnConvolutionBiasActivationForward() only supports " - "Relu activation."); + "Relu or None activation."); } CudnnTensorDescriptor conv_input_nd( @@ -3074,6 +3096,21 @@ port::Status CudnnSupport::DoConvolveBackwardDataImpl( } } + // Cudnn 7.1.4 has a bug if the workspace of the following convolution is not + // zero-initialized, nvbugs/2254619. + if (CUDNN_VERSION >= 7000 && + algorithm_config.algorithm().algo_id() == + CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 && + cudnn_type == CUDNN_DATA_HALF && + algorithm_config.algorithm().tensor_ops_enabled() && + input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth && + filter_descriptor.layout() == dnn::FilterLayout::kOutputInputYX && + output_descriptor.layout() == dnn::DataLayout::kBatchDepthYX && + (convolution_descriptor.vertical_filter_stride() > 1 || + convolution_descriptor.horizontal_filter_stride() > 1)) { + stream->ThenMemZero(&scratch, scratch.size()); + } + RETURN_IF_CUDNN_ERROR( cudnnConvolutionBackwardData(cudnn.handle(), /*alpha=*/alpha, @@ -3587,7 +3624,7 @@ bool CudnnSupport::DoPoolForward( const dnn::BatchDescriptor& input_dimensions, const DeviceMemory& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) { + DeviceMemory* output_data, ScratchAllocator* workspace_allocator) { // Alpha is the scaling factor for input. double alpha = 1.0; // Beta is the scaling factor for output. @@ -3612,7 +3649,7 @@ bool CudnnSupport::DoPoolForward( const dnn::BatchDescriptor& input_dimensions, const DeviceMemory& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) { + DeviceMemory* output_data, ScratchAllocator* workspace_allocator) { // Alpha is the scaling factor for input. float alpha = 1.0; // Beta is the scaling factor for output. @@ -3637,7 +3674,8 @@ bool CudnnSupport::DoPoolForward( const dnn::BatchDescriptor& input_dimensions, const DeviceMemory& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) { + DeviceMemory* output_data, + ScratchAllocator* workspace_allocator) { // Alpha is the scaling factor for input. float alpha = 1.0; // Beta is the scaling factor for output. @@ -3663,7 +3701,8 @@ bool CudnnSupport::DoPoolBackward( const dnn::BatchDescriptor& output_dimensions, const DeviceMemory& output_data, const DeviceMemory& input_diff_data, - DeviceMemory* output_diff_data) { + DeviceMemory* output_diff_data, + ScratchAllocator* workspace_allocator) { // Alpha is the scaling factor for input. double alpha = 1.0; // Beta is the scaling factor for output. @@ -3692,7 +3731,8 @@ bool CudnnSupport::DoPoolBackward( const dnn::BatchDescriptor& output_dimensions, const DeviceMemory& output_data, const DeviceMemory& input_diff_data, - DeviceMemory* output_diff_data) { + DeviceMemory* output_diff_data, + ScratchAllocator* workspace_allocator) { // Alpha is the scaling factor for input. float alpha = 1.0; // Beta is the scaling factor for output. @@ -3721,7 +3761,8 @@ bool CudnnSupport::DoPoolBackward( const dnn::BatchDescriptor& output_dimensions, const DeviceMemory& output_data, const DeviceMemory& input_diff_data, - DeviceMemory* output_diff_data) { + DeviceMemory* output_diff_data, + ScratchAllocator* workspace_allocator) { // Alpha is the scaling factor for input. float alpha = 1.0; // Beta is the scaling factor for output. @@ -3790,7 +3831,8 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions( const dnn::BatchDescriptor& dimensions, const DeviceMemory& raw_data, const DeviceMemory& normalized_data, const DeviceMemory& normalized_variable_gradient, - DeviceMemory* raw_variable_gradient) { + DeviceMemory* raw_variable_gradient, + ScratchAllocator* workspace_allocator) { // Check for unsupported modes. if (normalize_descriptor.wrap_around()) { LOG(ERROR) << "CUDA LRN does not support cudnn-around mode"; diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index c924d41cb5239d704e658f0b5452e04087caeba2..9d88f971bb17510099978a03b673f39576c32587 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -515,21 +515,24 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& input_dimensions, const DeviceMemory& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) override; + DeviceMemory* output_data, + ScratchAllocator* workspace_allocator) override; bool DoPoolForward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const dnn::BatchDescriptor& input_dimensions, const DeviceMemory& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) override; + DeviceMemory* output_data, + ScratchAllocator* workspace_allocator) override; bool DoPoolForward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const dnn::BatchDescriptor& input_dimensions, const DeviceMemory& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) override; + DeviceMemory* output_data, + ScratchAllocator* workspace_allocator) override; bool DoPoolBackward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, @@ -538,7 +541,8 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_dimensions, const DeviceMemory& output_data, const DeviceMemory& input_diff_data, - DeviceMemory* output_diff_data) override; + DeviceMemory* output_diff_data, + ScratchAllocator* workspace_allocator) override; bool DoPoolBackward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, @@ -547,7 +551,8 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_dimensions, const DeviceMemory& output_data, const DeviceMemory& input_diff_data, - DeviceMemory* output_diff_data) override; + DeviceMemory* output_diff_data, + ScratchAllocator* workspace_allocator) override; bool DoPoolBackward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, @@ -556,7 +561,8 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_dimensions, const DeviceMemory& output_data, const DeviceMemory& input_diff_data, - DeviceMemory* output_diff_data) override; + DeviceMemory* output_diff_data, + ScratchAllocator* workspace_allocator) override; bool DoNormalize(Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor, @@ -575,7 +581,8 @@ class CudnnSupport : public dnn::DnnSupport { const DeviceMemory& raw_data, const DeviceMemory& normalized_data, const DeviceMemory& normalized_variable_gradient, - DeviceMemory* raw_variable_gradient) override; + DeviceMemory* raw_variable_gradient, + ScratchAllocator* workspace_allocator) override; bool DoDepthConcatenate( Stream* stream, port::ArraySlice input_dimensions, diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc index d508f6594a9f9ac3c924b0b952620b6a4ac727ea..f982f34b98eca60dbf50dbf7c970b079283d0b42 100644 --- a/tensorflow/stream_executor/cuda/cuda_driver.cc +++ b/tensorflow/stream_executor/cuda/cuda_driver.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/human_readable.h" #include "tensorflow/stream_executor/lib/inlined_vector.h" #include "tensorflow/stream_executor/lib/notification.h" +#include "tensorflow/stream_executor/lib/ptr_util.h" #include "tensorflow/stream_executor/lib/stacktrace.h" #include "tensorflow/stream_executor/lib/static_threadlocal.h" #include "tensorflow/stream_executor/lib/strcat.h" @@ -66,14 +67,17 @@ class CreatedContexts { return Live()->find(context) != Live()->end(); } - // Adds context to the live set. + // Adds context to the live set, or returns it if it's already present. static CudaContext* Add(CUcontext context) { CHECK(context != nullptr); mutex_lock lock(mu_); - auto cuda_context = new CudaContext(context, next_id_++); - Live()->insert( - std::make_pair(context, std::unique_ptr(cuda_context))); - return cuda_context; + auto insert_result = Live()->insert(std::make_pair(context, nullptr)); + auto it = insert_result.first; + if (insert_result.second) { + // context was not present in the map. Add it. + it->second = MakeUnique(context, next_id_++); + } + return it->second.get(); } // Removes context from the live set. @@ -102,117 +106,16 @@ class CreatedContexts { /* static */ int64 CreatedContexts::next_id_ = 1; // 0 means "no context" // Formats CUresult to output prettified values into a log stream. -// Error summaries taken from: -// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1gc6c391505e117393cc2558fff6bfc2e9 -// -// TODO(leary) switch to cuGetErrorName when updated cuda.h is available. string ToString(CUresult result) { -#define OSTREAM_CUDA_ERROR(__name) \ - case CUDA_ERROR_##__name: \ - return "CUDA_ERROR_" #__name; - -/////////////// -// NOTE: here we specify return code values outside of the enum explicitly -// because our in-tree cuda.h is from the CUDA 5.5 SDK, but CUDA 6.0+ driver -// libraries are deployed in the fleet these error codes are backwards -// compatible, but if we see a "new" one, we want to be able to identify it in -// the logs. -// -// Once we get a cuda.h that has cuGetErrorName (TODO is above) we can -// eliminate this function and just rely on the driver to provide us these -// strings. -// -// NOTE: "Must reboot all context" below is shorthand for, "must -// destroy/recreate the offending context and any allocation which come from -// it if you are to continue using CUDA." -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wswitch" - switch (result) { - OSTREAM_CUDA_ERROR(INVALID_VALUE) - OSTREAM_CUDA_ERROR(OUT_OF_MEMORY) - OSTREAM_CUDA_ERROR(NOT_INITIALIZED) - OSTREAM_CUDA_ERROR(DEINITIALIZED) - OSTREAM_CUDA_ERROR(NO_DEVICE) - OSTREAM_CUDA_ERROR(INVALID_DEVICE) - OSTREAM_CUDA_ERROR(INVALID_IMAGE) - OSTREAM_CUDA_ERROR(INVALID_CONTEXT) - OSTREAM_CUDA_ERROR(INVALID_HANDLE) - OSTREAM_CUDA_ERROR(NOT_FOUND) - OSTREAM_CUDA_ERROR(NOT_READY) - OSTREAM_CUDA_ERROR(NO_BINARY_FOR_GPU) - - // Encountered an uncorrectable ECC error during execution. - OSTREAM_CUDA_ERROR(ECC_UNCORRECTABLE) - - // Load/store on an invalid address. Must reboot all context. - case 700: - return "CUDA_ERROR_ILLEGAL_ADDRESS"; - // Passed too many / wrong arguments, too many threads for register count. - case 701: - return "CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES"; - // Kernel took too long to execute. - case 702: - return "CUDA_ERROR_LAUNCH_TIMEOUT"; - // Kernel launch uses an incompatible texturing mode. - case 703: - return "CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING"; - // Trying to re-enable peer access that already has it enabled. - case 704: - return "CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED"; - // Trying to disable peer access that has not yet been enabled. - case 705: - return "CUDA_ERROR_PEER_ACCESS_NOT_ENABLED"; - // Primary context for the specified device has already been initialized. - case 708: - return "CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE"; - // Context current to calling thread has been destroyed or is a primary - // context that has not yet been initialized. - case 709: - return "CUDA_ERROR_CONTEXT_IS_DESTROYED"; - // Device-side assert triggered during kernel execution. Must reboot all - // context. - case 710: - return "CUDA_ERROR_ASSERT"; - // Hardware resources to enable peer access have been exhausted. - case 711: - return "CUDA_ERROR_TOO_MANY_PEERS"; - // Memory range has already been registered. - case 712: - return "CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED"; - // Pointer does not correspond to any currently registered memory region. - case 713: - return "CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED"; - // Due to stack corruption or exceeding stack size limit. Must reboot all - // context. - case 714: - return "CUDA_ERROR_HARDWARE_STACK_ERROR"; - case 715: - return "CUDA_ERROR_ILLEGAL_INSTRUCTION"; - // Load/store on an unaligned memory address. Must reboot all context. - case 716: - return "CUDA_ERROR_MISALIGNED_ADDRESS"; - // Device instruction with specific address space given address not - // belonging to allowed address space. Must reboot all context. - case 717: - return "CUDA_ERROR_INVALID_ADDRESS_SPACE"; - // Device program counter wrapped its address space. Must reboot all - // context. - case 718: - return "CUDA_ERROR_INVALID_PC"; - // Exception on device while executing a kernel; e.g. deref invalid device - // pointer, accessing OOB shared memory. Must reboot all context. - case 719: - return "CUDA_ERROR_LAUNCH_FAILED"; - - OSTREAM_CUDA_ERROR(CONTEXT_ALREADY_IN_USE) - OSTREAM_CUDA_ERROR(PEER_ACCESS_UNSUPPORTED) - OSTREAM_CUDA_ERROR(NOT_PERMITTED) - OSTREAM_CUDA_ERROR(NOT_SUPPORTED) - OSTREAM_CUDA_ERROR(UNKNOWN) // Unknown internal error to CUDA. - default: - return port::StrCat("CUresult(", static_cast(result), ")"); + const char *error_name; + if (cuGetErrorName(result, &error_name)) { + return port::StrCat("UNKNOWN ERROR (", static_cast(result), ")"); + } + const char *error_string; + if (cuGetErrorString(result, &error_string)) { + return error_name; } -#pragma GCC diagnostic pop + return port::StrCat(error_name, ": ", error_string); } // Returns the current context and checks that it is in the set of CUDA contexts @@ -528,7 +431,7 @@ bool DeviceOptionsToContextFlags(const DeviceOptions &device_options, *context = CreatedContexts::Add(new_context); CHECK(*context != nullptr) << "success in this call must entail non-null result"; - VLOG(2) << "created context " << context << " for this thread"; + VLOG(2) << "created or reused context " << context << " for this thread"; return port::Status::OK(); } diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index edf217875ff30b26407e35a39d1ee603b980b480..73f05b94db8f022825ccc72c2222a78634423ddf 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h" +#if defined(__APPLE__) +#include +#endif #if defined(PLATFORM_WINDOWS) #include #define PATH_MAX MAX_PATH @@ -176,11 +179,19 @@ bool CUDAExecutor::FindOnDiskForComputeCapability( // would return /usr/bin. static string GetBinaryDir(bool strip_exe) { char exe_path[PATH_MAX] = {0}; +#if defined(__APPLE__) + uint32_t buffer_size = 0U; + _NSGetExecutablePath(nullptr, &buffer_size); + char unresolved_path[buffer_size]; + _NSGetExecutablePath(unresolved_path, &buffer_size); + CHECK_ERR(realpath(unresolved_path, exe_path) ? 1 : -1); +#else #if defined(PLATFORM_WINDOWS) HMODULE hModule = GetModuleHandle(NULL); GetModuleFileName(hModule, exe_path, MAX_PATH); #else CHECK_ERR(readlink("/proc/self/exe", exe_path, sizeof(exe_path) - 1)); +#endif #endif // Make sure it's null-terminated: exe_path[sizeof(exe_path) - 1] = 0; @@ -195,6 +206,48 @@ static string GetBinaryDir(bool strip_exe) { return exe_path; } +bool CUDAExecutor::LoadModuleFromCuBin(const char *cubin, CUmodule *module) { + uint64_t module_refcount; + std::tie(*module, module_refcount) = gpu_binary_to_module_[cubin]; + + if (*module == nullptr) { + auto load_status = CUDADriver::LoadCubin(context_, cubin, module); + if (!load_status.ok()) { + LOG(ERROR) << "failed to load CUBIN: " << load_status; + return false; + } + module_refcount = 1; + VLOG(3) << "Loaded CUBIN " << static_cast(cubin) + << " as module " << *module; + } else { + ++module_refcount; + VLOG(3) << "CUBIN " << static_cast(cubin) + << " is already loaded as module " << *module; + } + gpu_binary_to_module_[cubin] = {*module, module_refcount}; + return true; +} + +bool CUDAExecutor::LoadModuleFromPtx(const char *ptx, CUmodule *module) { + uint64_t module_refcount; + std::tie(*module, module_refcount) = gpu_binary_to_module_[ptx]; + + if (*module == nullptr) { + if (!CUDADriver::LoadPtx(context_, ptx, module)) { + return false; + } + VLOG(3) << "Loaded PTX " << static_cast(ptx) << " as module " + << *module; + module_refcount = 1; + } else { + ++module_refcount; + VLOG(3) << "PTX " << static_cast(ptx) + << " is already loaded as module " << module; + } + gpu_binary_to_module_[ptx] = {*module, module_refcount}; + return true; +} + bool CUDAExecutor::GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel) { CUDAKernel *cuda_kernel = AsCUDAKernel(kernel); @@ -204,28 +257,13 @@ bool CUDAExecutor::GetKernel(const MultiKernelLoaderSpec &spec, VLOG(3) << "GetKernel on kernel " << kernel << " : " << kernel->name(); if (spec.has_cuda_cubin_in_memory()) { + mutex_lock lock{in_memory_modules_mu_}; kernelname = &spec.cuda_cubin_in_memory().kernelname(); const char *cubin = spec.cuda_cubin_in_memory().bytes(); - mutex_lock lock{in_memory_modules_mu_}; - uint64_t module_refcount; - std::tie(module, module_refcount) = gpu_binary_to_module_[cubin]; - - if (module == nullptr) { - auto load_status = CUDADriver::LoadCubin(context_, cubin, &module); - if (!load_status.ok()) { - LOG(ERROR) << "failed to load CUBIN: " << load_status; - return false; - } - module_refcount = 1; - VLOG(3) << "Loaded CUBIN " << static_cast(cubin) - << " as module " << module; - } else { - ++module_refcount; - VLOG(3) << "CUBIN " << static_cast(cubin) - << " is already loaded as module " << module; + if (!LoadModuleFromCuBin(cubin, &module)) { + return false; } kernel_to_gpu_binary_[kernel] = cubin; - gpu_binary_to_module_[cubin] = {module, module_refcount}; } else if (spec.has_cuda_ptx_in_memory()) { kernelname = &spec.cuda_ptx_in_memory().kernelname(); @@ -243,24 +281,10 @@ bool CUDAExecutor::GetKernel(const MultiKernelLoaderSpec &spec, } mutex_lock lock{in_memory_modules_mu_}; - uint64_t module_refcount; - std::tie(module, module_refcount) = gpu_binary_to_module_[ptx]; - - if (module == nullptr) { - if (!CUDADriver::LoadPtx(context_, ptx, &module)) { - LOG(ERROR) << "failed to load PTX for kernel " << *kernelname; - return false; - } - VLOG(3) << "Loaded PTX " << static_cast(ptx) - << " as module " << module; - module_refcount = 1; - } else { - ++module_refcount; - VLOG(3) << "PTX " << static_cast(ptx) - << " is already loaded as module " << module; + if (!LoadModuleFromPtx(ptx, &module)) { + return false; } kernel_to_gpu_binary_[kernel] = ptx; - gpu_binary_to_module_[ptx] = {module, module_refcount}; } else { LOG(WARNING) << "no method of loading CUDA kernel provided"; return false; @@ -284,6 +308,23 @@ bool CUDAExecutor::GetKernel(const MultiKernelLoaderSpec &spec, return true; } +bool CUDAExecutor::UnloadGpuBinary(const void *gpu_binary) { + auto module_it = gpu_binary_to_module_.find(gpu_binary); + if (gpu_binary_to_module_.end() == module_it) { + VLOG(3) << "No loaded CUDA module for " << gpu_binary; + return false; + } + auto &module = module_it->second.first; + auto &refcount = module_it->second.second; + VLOG(3) << "Found CUDA module " << module << " with refcount " << refcount; + if (--refcount == 0) { + VLOG(3) << "Unloading CUDA module " << module; + CUDADriver::UnloadModule(context_, module); + gpu_binary_to_module_.erase(module_it); + } + return true; +} + void CUDAExecutor::UnloadKernel(const KernelBase *kernel) { VLOG(3) << "Unloading kernel " << kernel << " : " << kernel->name(); @@ -296,25 +337,52 @@ void CUDAExecutor::UnloadKernel(const KernelBase *kernel) { } VLOG(3) << "Kernel " << kernel << " : " << kernel->name() << " has loaded GPU code " << gpu_binary_it->second; - auto module_it = gpu_binary_to_module_.find(gpu_binary_it->second); - if (gpu_binary_to_module_.end() == module_it) { - VLOG(3) << "Kernel " << kernel << " : " << kernel->name() - << " has no loaded CUDA module."; - return; // This kernel never loaded any modules - } - auto &module = module_it->second.first; - auto &refcount = module_it->second.second; - VLOG(3) << "Kernel " << kernel << " : " << kernel->name() - << " has loaded GPU code " << gpu_binary_it->second - << " into CUDA module " << module << " with refcount " << refcount; - if (--refcount == 0) { - VLOG(3) << "Unloading CUDA module " << module; - CUDADriver::UnloadModule(context_, module); - gpu_binary_to_module_.erase(module_it); - } + UnloadGpuBinary(gpu_binary_it->second); kernel_to_gpu_binary_.erase(gpu_binary_it); } +bool CUDAExecutor::LoadModule(const MultiModuleLoaderSpec &spec, + ModuleHandle *module_handle) { + // In CUDAExecutor we store the pointer to the GPU binary (PTX or CUBIN) as + // ModuleHandle::id(). + CUmodule cu_module; + if (spec.has_cuda_cubin_in_memory()) { + mutex_lock lock{in_memory_modules_mu_}; + if (!LoadModuleFromCuBin( + reinterpret_cast(spec.cuda_cubin_in_memory().data()), + &cu_module)) { + return false; + } + *module_handle = ModuleHandle(const_cast( + static_cast(spec.cuda_cubin_in_memory().data()))); + return true; + } else if (spec.has_cuda_ptx_in_memory()) { + if (cc_major_ == 0 && cc_minor_ == 0) { + return false; + } + + if (!spec.cuda_ptx_in_memory()) { + return false; + } + + mutex_lock lock{in_memory_modules_mu_}; + if (!LoadModuleFromPtx(spec.cuda_ptx_in_memory(), &cu_module)) { + return false; + } + *module_handle = ModuleHandle(const_cast( + static_cast(spec.cuda_ptx_in_memory()))); + return true; + } + LOG(WARNING) << "no method of loading CUDA module provided"; + return false; +} + +bool CUDAExecutor::UnloadModule(ModuleHandle module_handle) { + const char *gpu_binary = reinterpret_cast(module_handle.id()); + mutex_lock lock{in_memory_modules_mu_}; + return UnloadGpuBinary(gpu_binary); +} + bool CUDAExecutor::GetKernelMetadata(CUDAKernel *cuda_kernel, KernelMetadata *kernel_metadata) { int value; @@ -772,16 +840,26 @@ bool CUDAExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const { return CUDADriver::GetDeviceMemoryInfo(context_, free, total); } -bool CUDAExecutor::GetSymbol(const string& symbol_name, void **mem, +bool CUDAExecutor::GetSymbol(const string &symbol_name, + ModuleHandle module_handle, void **mem, size_t *bytes) { + auto lookup_in_module = [&](CUmodule module) { + CHECK(module != nullptr); + return CUDADriver::GetModuleSymbol(context_, module, symbol_name.c_str(), + reinterpret_cast(mem), + bytes); + }; + { // give limited scope to mutex_lock mutex_lock lock{in_memory_modules_mu_}; + if (static_cast(module_handle)) { + auto it = gpu_binary_to_module_.find(module_handle.id()); + CHECK(it != gpu_binary_to_module_.end()); + return lookup_in_module(it->second.first); + } + for (auto &it : gpu_binary_to_module_) { - CUmodule module = it.second.first; - CHECK(module != nullptr); - if (CUDADriver::GetModuleSymbol(context_, module, symbol_name.c_str(), - reinterpret_cast(mem), - bytes)) { + if (lookup_in_module(it.second.first)) { return true; } } @@ -833,7 +911,7 @@ CUDAExecutor::GetTimerImplementation() { return std::unique_ptr(new CUDATimer(this)); } -void *CUDAExecutor::CudaContextHack() { return context_; } +void *CUDAExecutor::GpuContextHack() { return context_; } CudaContext* CUDAExecutor::cuda_context() { return context_; } @@ -843,7 +921,10 @@ CudaContext* CUDAExecutor::cuda_context() { return context_; } // For anything more complicated/prod-focused than this, you'll likely want to // turn to gsys' topology modeling. static int TryToReadNumaNode(const string &pci_bus_id, int device_ordinal) { -#if defined(PLATFORM_WINDOWS) +#if defined(__APPLE__) + LOG(INFO) << "OS X does not support NUMA - returning NUMA node zero"; + return 0; +#elif defined(PLATFORM_WINDOWS) // Windows support for NUMA is not currently implemented. Return node 0. return 0; #elif defined(__aarch64__) diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h index 773cbfb8a17a416d18ae599bf4f72e1550538dee..8a954d5461c60749019c87971cee22089bbd22e5 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h @@ -62,6 +62,9 @@ class CUDAExecutor : public internal::StreamExecutorInterface { bool GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel) override; void UnloadKernel(const KernelBase *kernel) override; + bool LoadModule(const MultiModuleLoaderSpec &spec, + ModuleHandle *module_handle) override; + bool UnloadModule(ModuleHandle module_handle) override; bool Launch(Stream *stream, const ThreadDim &thread_dims, const BlockDim &block_dims, const KernelBase &k, @@ -175,7 +178,8 @@ class CUDAExecutor : public internal::StreamExecutorInterface { // Search for the symbol and returns a device pointer and size. // Returns false if symbol does not exist. - bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes) override; + bool GetSymbol(const string &symbol_name, ModuleHandle module_handle, + void **mem, size_t *bytes) override; DeviceDescription *PopulateDeviceDescription() const override; @@ -210,7 +214,7 @@ class CUDAExecutor : public internal::StreamExecutorInterface { std::unique_ptr GetTimerImplementation() override; - void *CudaContextHack() override; + void *GpuContextHack() override; CudaContext* cuda_context(); @@ -239,6 +243,16 @@ class CUDAExecutor : public internal::StreamExecutorInterface { void VlogOccupancyInfo(const KernelBase &kernel, const ThreadDim &thread_dims, const BlockDim &block_dims); + bool LoadModuleFromCuBin(const char *cubin, CUmodule *module) + EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); + + // Loads the PTX text `ptx` as a CUDA module. `ptx` must be null terminated. + bool LoadModuleFromPtx(const char *ptx, CUmodule *module) + EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); + + bool UnloadGpuBinary(const void *gpu_binary) + EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); + // Guards the in-memory-module mapping. mutex in_memory_modules_mu_; diff --git a/tensorflow/stream_executor/cuda/cuda_stream.h b/tensorflow/stream_executor/cuda/cuda_stream.h index 02edff643117fc2e3c6e6f74d2932f3f4c00c66d..bb8bda4755344d859668425f89614cc87d7e2d3e 100644 --- a/tensorflow/stream_executor/cuda/cuda_stream.h +++ b/tensorflow/stream_executor/cuda/cuda_stream.h @@ -40,8 +40,8 @@ class CUDAStream : public internal::StreamInterface { // Note: teardown is handled by a parent's call to DeallocateStream. ~CUDAStream() override {} - void *CudaStreamHack() override { return cuda_stream_; } - void **CudaStreamMemberHack() override { + void *GpuStreamHack() override { return cuda_stream_; } + void **GpuStreamMemberHack() override { return reinterpret_cast(&cuda_stream_); } diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc index 82aa8ceb3298a30a4c117882dc96c504d9d10226..2a30f922bca4d1dc7d8a9d4ee6e26f7bdf41251c 100644 --- a/tensorflow/stream_executor/dnn.cc +++ b/tensorflow/stream_executor/dnn.cc @@ -117,6 +117,8 @@ string FilterLayoutString(FilterLayout layout) { switch (layout) { case FilterLayout::kOutputInputYX: return "OutputInputYX"; + case FilterLayout::kOutputYXInput: + return "OutputYXInput"; case FilterLayout::kOutputInputYX4: return "OutputInputYX4"; case FilterLayout::kInputYXOutput: diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 9eca5abe1ae7265ebca0a1ea653823816deaa8f5..9abfa1db6ab60351557ff6243ec354cfada6bb6d 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -713,15 +713,23 @@ class PoolingDescriptor { class AlgorithmDesc { public: typedef int64 Index; - AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true) {} + AlgorithmDesc() + : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true), scratch_size_(0) {} AlgorithmDesc(Index a, bool use_tensor_ops) - : algo_(a), tensor_ops_enabled_(use_tensor_ops) {} + : algo_(a), tensor_ops_enabled_(use_tensor_ops), scratch_size_(0) {} + AlgorithmDesc(Index a, bool use_tensor_ops, size_t scratch_size) + : algo_(a), + tensor_ops_enabled_(use_tensor_ops), + scratch_size_(scratch_size) {} bool is_default() const { return algo_ == kDefaultAlgorithm; } bool tensor_ops_enabled() const { return tensor_ops_enabled_; } Index algo_id() const { return algo_; } + size_t scratch_size() const { return scratch_size_; } + void set_scratch_size(size_t val) { scratch_size_ = val; } bool operator==(const AlgorithmDesc& other) const { return this->algo_ == other.algo_ && - this->tensor_ops_enabled_ == other.tensor_ops_enabled_; + this->tensor_ops_enabled_ == other.tensor_ops_enabled_ && + this->scratch_size_ == other.scratch_size_; } uint64 hash() const; @@ -729,6 +737,7 @@ class AlgorithmDesc { enum { kDefaultAlgorithm = -1 }; Index algo_; bool tensor_ops_enabled_; + size_t scratch_size_; }; // Describes the result from a perf experiment. @@ -1552,14 +1561,16 @@ class DnnSupport { const dnn::BatchDescriptor& input_dimensions, const DeviceMemory& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) = 0; + DeviceMemory* output_data, + ScratchAllocator* workspace_allocator) = 0; virtual bool DoPoolForward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const dnn::BatchDescriptor& input_dimensions, const DeviceMemory& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) { + DeviceMemory* output_data, + ScratchAllocator* workspace_allocator) { LOG(FATAL) << "DoPoolForward not implemented for double."; return false; } @@ -1569,7 +1580,8 @@ class DnnSupport { const dnn::BatchDescriptor& input_dimensions, const DeviceMemory& input_data, const dnn::BatchDescriptor& output_dimensions, - DeviceMemory* output_data) { + DeviceMemory* output_data, + ScratchAllocator* workspace_allocator) { LOG(FATAL) << "DoPoolForward not implemented for float16."; return false; } @@ -1582,7 +1594,8 @@ class DnnSupport { const dnn::BatchDescriptor& output_dimensions, const DeviceMemory& output_data, const DeviceMemory& input_diff_data, - DeviceMemory* output_diff_data) { + DeviceMemory* output_diff_data, + ScratchAllocator* workspace_allocator) { LOG(FATAL) << "DoPoolBackward not implemented."; return false; } @@ -1594,7 +1607,8 @@ class DnnSupport { const dnn::BatchDescriptor& output_dimensions, const DeviceMemory& output_data, const DeviceMemory& input_diff_data, - DeviceMemory* output_diff_data) { + DeviceMemory* output_diff_data, + ScratchAllocator* workspace_allocator) { LOG(FATAL) << "DoPoolBackward not implemented."; return false; } @@ -1606,7 +1620,8 @@ class DnnSupport { const dnn::BatchDescriptor& output_dimensions, const DeviceMemory& output_data, const DeviceMemory& input_diff_data, - DeviceMemory* output_diff_data) { + DeviceMemory* output_diff_data, + ScratchAllocator* workspace_allocator) { LOG(FATAL) << "DoPoolBackward not implemented."; return false; } @@ -1653,7 +1668,8 @@ class DnnSupport { const DeviceMemory& raw_data, const DeviceMemory& normalized_data, const DeviceMemory& normalized_variable_gradient, - DeviceMemory* raw_variable_gradient) { + DeviceMemory* raw_variable_gradient, + ScratchAllocator* workspace_allocator) { return false; } diff --git a/tensorflow/stream_executor/event.cc b/tensorflow/stream_executor/event.cc index 50a6edd80bd39004e32f09bcde36fbc8a8b59ad9..52efe771bc3c43e65b4539f811196e2d8785eb77 100644 --- a/tensorflow/stream_executor/event.cc +++ b/tensorflow/stream_executor/event.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/stream_executor/event.h" +#include "tensorflow/stream_executor/stream.h" #include "tensorflow/stream_executor/stream_executor_internal.h" #include "tensorflow/stream_executor/stream_executor_pimpl.h" -#include "tensorflow/stream_executor/stream.h" namespace stream_executor { @@ -27,9 +27,12 @@ Event::Event(StreamExecutor* stream_exec) stream_exec_->implementation()->CreateEventImplementation()) {} Event::~Event() { - auto status = stream_exec_->DeallocateEvent(this); - if (!status.ok()) { - LOG(ERROR) << status.error_message(); + // Deal with nullptr implementation_, as this event may have been std::moved. + if (stream_exec_ && implementation_) { + auto status = stream_exec_->DeallocateEvent(this); + if (!status.ok()) { + LOG(ERROR) << status.error_message(); + } } } diff --git a/tensorflow/stream_executor/event.h b/tensorflow/stream_executor/event.h index 1f37262c78d82f72f8818f35db273e87a47bdc1c..9cc87a7c129962820ed0c84d02faada4ba460d51 100644 --- a/tensorflow/stream_executor/event.h +++ b/tensorflow/stream_executor/event.h @@ -61,6 +61,9 @@ class Event { // Returns a pointer to the underlying platform-specific implementation. internal::EventInterface* implementation() { return implementation_.get(); } + Event(Event&&) = default; + Event& operator=(Event&&) = default; + private: friend class Stream; diff --git a/tensorflow/stream_executor/host/host_gpu_executor.cc b/tensorflow/stream_executor/host/host_gpu_executor.cc index 2c4819651acaa2c6ee99c720b2c3d80e5c2ea1a9..8adf739b170c42e5aeda5ccf3ea469f2c3cea07c 100644 --- a/tensorflow/stream_executor/host/host_gpu_executor.cc +++ b/tensorflow/stream_executor/host/host_gpu_executor.cc @@ -26,8 +26,6 @@ limitations under the License. #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/plugin_registry.h" -bool FLAGS_stream_executor_cpu_real_clock_rate = false; - namespace stream_executor { namespace host { @@ -95,7 +93,7 @@ bool HostExecutor::MemcpyDeviceToDevice(Stream *stream, // the nature of the HostExecutor) memcpy on the stream (HostStream) // associated with the HostExecutor. AsHostStream(stream)->EnqueueTask( - [src_mem, dst_mem, size]() { memcpy(src_mem, dst_mem, size); }); + [src_mem, dst_mem, size]() { memcpy(dst_mem, src_mem, size); }); return true; } @@ -190,11 +188,8 @@ DeviceDescription *HostExecutor::PopulateDeviceDescription() const { // doesn't result in thrashing or other badness? 4GiB chosen arbitrarily. builder.set_device_memory_size(static_cast(4) * 1024 * 1024 * 1024); - float cycle_counter_frequency = 1e9; - if (FLAGS_stream_executor_cpu_real_clock_rate) { - cycle_counter_frequency = static_cast( - tensorflow::profile_utils::CpuUtils::GetCycleCounterFrequency()); - } + float cycle_counter_frequency = static_cast( + tensorflow::profile_utils::CpuUtils::GetCycleCounterFrequency()); builder.set_clock_rate_ghz(cycle_counter_frequency / 1e9); auto built = builder.Build(); diff --git a/tensorflow/stream_executor/host/host_gpu_executor.h b/tensorflow/stream_executor/host/host_gpu_executor.h index e82f57569f35eb286ecc81caec30a77f148bd675..7ba1f181015e057b66e7e7287a592d5f2af1ead2 100644 --- a/tensorflow/stream_executor/host/host_gpu_executor.h +++ b/tensorflow/stream_executor/host/host_gpu_executor.h @@ -88,7 +88,7 @@ class HostExecutor : public internal::StreamExecutorInterface { uint64 size) override; // No "synchronize all activity" implemented for this platform at the moment. - bool SynchronizeAllActivity() override { return false; } + bool SynchronizeAllActivity() override { return true; } bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override; bool SynchronousMemSet(DeviceMemoryBase *location, int value, @@ -202,7 +202,7 @@ class HostExecutor : public internal::StreamExecutorInterface { return std::unique_ptr(new HostTimer()); } - void *CudaContextHack() override { return nullptr; } + void *GpuContextHack() override { return nullptr; } private: const PluginConfig plugin_config_; diff --git a/tensorflow/stream_executor/host/host_stream.cc b/tensorflow/stream_executor/host/host_stream.cc index 5a7d3b3dd49275edd5242c30b38bb4f505042816..bfbfb56cd7955196a295f263f1e62eedfa06d98d 100644 --- a/tensorflow/stream_executor/host/host_stream.cc +++ b/tensorflow/stream_executor/host/host_stream.cc @@ -28,18 +28,28 @@ HostStream::HostStream() HostStream::~HostStream() {} bool HostStream::EnqueueTask(std::function task) { + struct NotifiedTask { + HostStream* stream; + std::function task; + + void operator()() { + task(); + // Destroy the task before unblocking its waiters, as BlockHostUntilDone() + // should guarantee that all tasks are destroyed. + task = std::function(); + { + mutex_lock lock(stream->mu_); + --stream->pending_tasks_; + } + stream->completion_condition_.notify_all(); + } + }; + { mutex_lock lock(mu_); ++pending_tasks_; } - host_executor_->Schedule([this, task]() { - task(); - { - mutex_lock lock(mu_); - --pending_tasks_; - } - completion_condition_.notify_all(); - }); + host_executor_->Schedule(NotifiedTask{this, std::move(task)}); return true; } diff --git a/tensorflow/stream_executor/host/host_stream.h b/tensorflow/stream_executor/host/host_stream.h index 5d7b8a378268c3226a61fa43e738f209e84b30e9..be88f074cf6ece7bf925bf4dea546bb8aa2b4661 100644 --- a/tensorflow/stream_executor/host/host_stream.h +++ b/tensorflow/stream_executor/host/host_stream.h @@ -34,8 +34,8 @@ class HostStream : public internal::StreamInterface { bool EnqueueTask(std::function task); - void *CudaStreamHack() override { return nullptr; } - void **CudaStreamMemberHack() override { return nullptr; } + void *GpuStreamHack() override { return nullptr; } + void **GpuStreamMemberHack() override { return nullptr; } void BlockUntilDone(); diff --git a/tensorflow/compiler/xla/statusor.cc b/tensorflow/stream_executor/lib/statusor.cc similarity index 89% rename from tensorflow/compiler/xla/statusor.cc rename to tensorflow/stream_executor/lib/statusor.cc index 72ab67ff810e0ec384a22da092363cc7446435bb..e0e851f96ef6fe18ec32ff7d3fd1d1aed18b0343 100644 --- a/tensorflow/compiler/xla/statusor.cc +++ b/tensorflow/stream_executor/lib/statusor.cc @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" -namespace xla { +namespace stream_executor { +namespace port { namespace internal_statusor { void Helper::HandleInvalidStatusCtorArg(Status* status) { @@ -35,4 +36,5 @@ void Helper::Crash(const Status& status) { } } // namespace internal_statusor -} // namespace xla +} // namespace port +} // namespace stream_executor diff --git a/tensorflow/stream_executor/lib/statusor.h b/tensorflow/stream_executor/lib/statusor.h index dab59096740102b94c0ff63c089b83ce052ea264..3c716acb462f1ca25e1d86408386d9eca37265b7 100644 --- a/tensorflow/stream_executor/lib/statusor.h +++ b/tensorflow/stream_executor/lib/statusor.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,19 +13,297 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h" - +// StatusOr is the union of a Status object and a T object. StatusOr models +// the concept of an object that is either a value, or an error Status +// explaining why such a value is not present. To this end, StatusOr does not +// allow its Status value to be Status::OK. +// +// The primary use-case for StatusOr is as the return value of a +// function which may fail. +// +// Example client usage for a StatusOr, where T is not a pointer: +// +// StatusOr result = DoBigCalculationThatCouldFail(); +// if (result.ok()) { +// float answer = result.ValueOrDie(); +// printf("Big calculation yielded: %f", answer); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example client usage for a StatusOr: +// +// StatusOr result = FooFactory::MakeNewFoo(arg); +// if (result.ok()) { +// std::unique_ptr foo(result.ValueOrDie()); +// foo->DoSomethingCool(); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example client usage for a StatusOr>: +// +// StatusOr> result = FooFactory::MakeNewFoo(arg); +// if (result.ok()) { +// std::unique_ptr foo = std::move(result.ValueOrDie()); +// foo->DoSomethingCool(); +// } else { +// LOG(ERROR) << result.status(); +// } +// +// Example factory implementation returning StatusOr: +// +// StatusOr FooFactory::MakeNewFoo(int arg) { +// if (arg <= 0) { +// return tensorflow::InvalidArgument("Arg must be positive"); +// } else { +// return new Foo(arg); +// } +// } +// +// Note that the assignment operators require that destroying the currently +// stored value cannot invalidate the argument; in other words, the argument +// cannot be an alias for the current value, or anything owned by the current +// value. #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_ #define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_ -#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/lib/statusor_internals.h" namespace stream_executor { namespace port { -// Use XLA's StatusOr so we don't duplicate code. +#if defined(__clang__) +// Only clang supports warn_unused_result as a type annotation. +template +class TF_MUST_USE_RESULT StatusOr; +#endif + +template +class StatusOr : private internal_statusor::StatusOrData, + private internal_statusor::TraitsBase< + std::is_copy_constructible::value, + std::is_move_constructible::value> { + template + friend class StatusOr; + + typedef internal_statusor::StatusOrData Base; + + public: + typedef T element_type; + + // Constructs a new StatusOr with Status::UNKNOWN status. This is marked + // 'explicit' to try to catch cases like 'return {};', where people think + // StatusOr> will be initialized with an empty vector, + // instead of a Status::UNKNOWN status. + explicit StatusOr(); + + // StatusOr will be copy constructible/assignable if T is copy + // constructible. + StatusOr(const StatusOr&) = default; + StatusOr& operator=(const StatusOr&) = default; + + // StatusOr will be move constructible/assignable if T is move + // constructible. + StatusOr(StatusOr&&) = default; + StatusOr& operator=(StatusOr&&) = default; + + // Conversion copy/move constructor, T must be convertible from U. + template ::value>::type* = nullptr> + StatusOr(const StatusOr& other); + template ::value>::type* = nullptr> + StatusOr(StatusOr&& other); + + // Conversion copy/move assignment operator, T must be convertible from U. + template ::value>::type* = nullptr> + StatusOr& operator=(const StatusOr& other); + template ::value>::type* = nullptr> + StatusOr& operator=(StatusOr&& other); + + // Constructs a new StatusOr with the given value. After calling this + // constructor, calls to ValueOrDie() will succeed, and calls to status() will + // return OK. + // + // NOTE: Not explicit - we want to use StatusOr as a return type + // so it is convenient and sensible to be able to do 'return T()' + // when the return type is StatusOr. + // + // REQUIRES: T is copy constructible. + StatusOr(const T& value); + + // Constructs a new StatusOr with the given non-ok status. After calling + // this constructor, calls to ValueOrDie() will CHECK-fail. + // + // NOTE: Not explicit - we want to use StatusOr as a return + // value, so it is convenient and sensible to be able to do 'return + // Status()' when the return type is StatusOr. + // + // REQUIRES: !status.ok(). This requirement is DCHECKed. + // In optimized builds, passing Status::OK() here will have the effect + // of passing tensorflow::error::INTERNAL as a fallback. + StatusOr(const Status& status); + StatusOr& operator=(const Status& status); + + // TODO(b/62186997): Add operator=(T) overloads. + + // Similar to the `const T&` overload. + // + // REQUIRES: T is move constructible. + StatusOr(T&& value); + + // RValue versions of the operations declared above. + StatusOr(Status&& status); + StatusOr& operator=(Status&& status); + + // Returns this->status().ok() + bool ok() const { return this->status_.ok(); } + + // Returns a reference to our status. If this contains a T, then + // returns Status::OK(). + const Status& status() const &; + Status status() &&; + + // Returns a reference to our current value, or CHECK-fails if !this->ok(). + // + // Note: for value types that are cheap to copy, prefer simple code: + // + // T value = statusor.ValueOrDie(); + // + // Otherwise, if the value type is expensive to copy, but can be left + // in the StatusOr, simply assign to a reference: + // + // T& value = statusor.ValueOrDie(); // or `const T&` + // + // Otherwise, if the value type supports an efficient move, it can be + // used as follows: + // + // T value = std::move(statusor).ValueOrDie(); + // + // The std::move on statusor instead of on the whole expression enables + // warnings about possible uses of the statusor object after the move. + // C++ style guide waiver for ref-qualified overloads granted in cl/143176389 + // See go/ref-qualifiers for more details on such overloads. + const T& ValueOrDie() const &; + T& ValueOrDie() &; + const T&& ValueOrDie() const &&; + T&& ValueOrDie() &&; + + T ConsumeValueOrDie() { return std::move(ValueOrDie()); } + + // Ignores any errors. This method does nothing except potentially suppress + // complaints from any tools that are checking that errors are not dropped on + // the floor. + void IgnoreError() const; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Implementation details for StatusOr + +template +StatusOr::StatusOr() : Base(Status(tensorflow::error::UNKNOWN, "")) {} + +template +StatusOr::StatusOr(const T& value) : Base(value) {} + +template +StatusOr::StatusOr(const Status& status) : Base(status) {} + +template +StatusOr& StatusOr::operator=(const Status& status) { + this->Assign(status); + return *this; +} + +template +StatusOr::StatusOr(T&& value) : Base(std::move(value)) {} + +template +StatusOr::StatusOr(Status&& status) : Base(std::move(status)) {} + +template +StatusOr& StatusOr::operator=(Status&& status) { + this->Assign(std::move(status)); + return *this; +} + +template +template ::value>::type*> +inline StatusOr::StatusOr(const StatusOr& other) + : Base(static_cast::Base&>(other)) {} + +template +template ::value>::type*> +inline StatusOr& StatusOr::operator=(const StatusOr& other) { + if (other.ok()) + this->Assign(other.ValueOrDie()); + else + this->Assign(other.status()); + return *this; +} + +template +template ::value>::type*> +inline StatusOr::StatusOr(StatusOr&& other) + : Base(static_cast::Base&&>(other)) {} + +template +template ::value>::type*> +inline StatusOr& StatusOr::operator=(StatusOr&& other) { + if (other.ok()) { + this->Assign(std::move(other).ValueOrDie()); + } else { + this->Assign(std::move(other).status()); + } + return *this; +} + +template +const Status& StatusOr::status() const & { + return this->status_; +} +template +Status StatusOr::status() && { + return ok() ? Status::OK() : std::move(this->status_); +} + +template +const T& StatusOr::ValueOrDie() const & { + this->EnsureOk(); + return this->data_; +} + +template +T& StatusOr::ValueOrDie() & { + this->EnsureOk(); + return this->data_; +} + +template +const T&& StatusOr::ValueOrDie() const && { + this->EnsureOk(); + return std::move(this->data_); +} + +template +T&& StatusOr::ValueOrDie() && { + this->EnsureOk(); + return std::move(this->data_); +} + template -using StatusOr = ::xla::StatusOr; +void StatusOr::IgnoreError() const { + // no-op +} } // namespace port } // namespace stream_executor diff --git a/tensorflow/compiler/xla/statusor_internals.h b/tensorflow/stream_executor/lib/statusor_internals.h similarity index 94% rename from tensorflow/compiler/xla/statusor_internals.h rename to tensorflow/stream_executor/lib/statusor_internals.h index 14636bd144bc0a155fc96c5a350c658fd2dadfe6..09f88f5825f57c8e654bd079616a074e84de4f30 100644 --- a/tensorflow/compiler/xla/statusor_internals.h +++ b/tensorflow/stream_executor/lib/statusor_internals.h @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ -#define TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ +#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_INTERNALS_H_ +#define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_INTERNALS_H_ + -#include "tensorflow/compiler/xla/status.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/stream_executor/lib/status.h" -namespace xla { +namespace stream_executor { +namespace port { namespace internal_statusor { class Helper { @@ -240,6 +242,7 @@ struct TraitsBase { }; } // namespace internal_statusor -} // namespace xla +} // namespace port +} // namespace stream_executor -#endif // TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ +#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_INTERNALS_H_ diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/stream_executor/lib/statusor_test.cc similarity index 99% rename from tensorflow/compiler/xla/statusor_test.cc rename to tensorflow/stream_executor/lib/statusor_test.cc index 377a618ffbd99316d409130df8a39f352664dee0..56584e189208b2576f10650fd56bca6d04ecc6c1 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/stream_executor/lib/statusor_test.cc @@ -15,18 +15,18 @@ limitations under the License. // Unit tests for StatusOr -#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/stream_executor/lib/statusor.h" #include #include -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/test.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test_benchmark.h" -namespace xla { +namespace stream_executor { +namespace port { namespace { class Base1 { @@ -672,4 +672,5 @@ void BM_StatusOrFactoryFailLongMsg(int iters) { BENCHMARK(BM_StatusOrFactoryFailLongMsg); } // namespace -} // namespace xla +} // namespace port +} // namespace stream_executor diff --git a/tensorflow/stream_executor/module_spec.h b/tensorflow/stream_executor/module_spec.h new file mode 100644 index 0000000000000000000000000000000000000000..75bdfed2d70364da4b191804d1e0973ee2658b70 --- /dev/null +++ b/tensorflow/stream_executor/module_spec.h @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_MODULE_SPEC_H_ +#define TENSORFLOW_STREAM_EXECUTOR_MODULE_SPEC_H_ + +#include "tensorflow/stream_executor/lib/array_slice.h" +#include "tensorflow/stream_executor/lib/stringpiece.h" +#include "tensorflow/stream_executor/platform/logging.h" +#include "tensorflow/stream_executor/platform/port.h" + +namespace stream_executor { + +// Describes how to load a module on a target platform. +// +// The exact meaning of a "module" may differ from platform to platform but +// loosely speaking a module a collection of kernels and global variables. It +// corresponds to CUmodule when running on CUDA. +class MultiModuleLoaderSpec { + public: + bool has_cuda_cubin_in_memory() const { return has_cuda_cubin_in_memory_; } + port::ArraySlice cuda_cubin_in_memory() const { + CHECK(has_cuda_cubin_in_memory()); + return {cuda_cubin_in_memory_.data(), cuda_cubin_in_memory_.size()}; + } + + bool has_cuda_ptx_in_memory() const { return has_cuda_ptx_in_memory_; } + const char* cuda_ptx_in_memory() const { + CHECK(has_cuda_ptx_in_memory()); + return cuda_ptx_in_memory_; + } + + void AddCudaCubinInMemory(port::ArraySlice cubin_bytes) { + CHECK(!cubin_bytes.empty()); + has_cuda_cubin_in_memory_ = true; + cuda_cubin_in_memory_ = cubin_bytes; + } + + void AddCudaPtxInMemory(const char* ptx) { + has_cuda_ptx_in_memory_ = true; + // The CUDA driver does not like getting an empty string as PTX. + cuda_ptx_in_memory_ = *ptx ? ptx : nullptr; + } + + private: + port::ArraySlice cuda_cubin_in_memory_; + bool has_cuda_cubin_in_memory_ = false; + const char* cuda_ptx_in_memory_; + bool has_cuda_ptx_in_memory_ = false; +}; + +} // namespace stream_executor + +#endif // TENSORFLOW_STREAM_EXECUTOR_MODULE_SPEC_H_ diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 4a98cfe16460ff860b6b73fedc21e98b5a3ed9fd..9efd34de24e8581993eefe7a18646e21b25007c2 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -115,7 +115,7 @@ string ToVlogString(const DeviceMemoryBase &memory) { } string ToVlogString(const DeviceMemoryBase *memory) { - return ToVlogString(*memory); + return memory == nullptr ? "null" : ToVlogString(*memory); } string ToVlogString(const Eigen::half &h) { @@ -192,6 +192,7 @@ string ToVlogString(dnn::DataType data_type) { case dnn::DataType::kInt8: return "dnn::DataType::kInt8"; } + return "unknown DataType"; } // Used together with PARAM to VLOG calls made to the stream. Intended @@ -210,13 +211,14 @@ string CallStr(const char *function_name, Stream *stream, // constructing all the strings in params is expensive. CHECK(VLOG_IS_ON(1)); - string str = port::StrCat("Called Stream::", function_name, "("); + string str = port::StrCat(stream->DebugStreamPointers(), + " Called Stream::", function_name, "("); const char *separator = ""; for (const auto ¶m : params) { port::StrAppend(&str, separator, param.first, "=", param.second); separator = ", "; } - port::StrAppend(&str, ") stream=", ToVlogString(stream)); + port::StrAppend(&str, ")"); if (VLOG_IS_ON(10)) { port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n"); } @@ -266,6 +268,12 @@ Stream::Stream(StreamExecutor *parent, Stream::~Stream() { VLOG_CALL(); + // Ensure the stream is completed. + auto status = BlockHostUntilDone(); + if (!status.ok()) { + LOG(WARNING) << "Error blocking host until done in stream destructor: " + << status; + } temporary_memory_manager_.ForceDeallocateAll(); if (allocated_) { @@ -1376,15 +1384,16 @@ Stream &Stream::ThenPoolForward( const dnn::BatchDescriptor &input_dimensions, const DeviceMemory &input_data, const dnn::BatchDescriptor &output_dimensions, - DeviceMemory *output_data) { + DeviceMemory *output_data, ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), - PARAM(input_data), PARAM(output_dimensions), PARAM(output_data)); + PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), + PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions, - input_data, output_dimensions, - output_data)); + input_data, output_dimensions, output_data, + workspace_allocator)); } else { SetError(); LOG(WARNING) @@ -1400,15 +1409,16 @@ Stream &Stream::ThenPoolForward( const dnn::BatchDescriptor &input_dimensions, const DeviceMemory &input_data, const dnn::BatchDescriptor &output_dimensions, - DeviceMemory *output_data) { + DeviceMemory *output_data, ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), - PARAM(input_data), PARAM(output_dimensions), PARAM(output_data)); + PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), + PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions, - input_data, output_dimensions, - output_data)); + input_data, output_dimensions, output_data, + workspace_allocator)); } else { SetErrorAndLogNoDnnSupport(); } @@ -1421,15 +1431,17 @@ Stream &Stream::ThenPoolForward( const dnn::BatchDescriptor &input_dimensions, const DeviceMemory &input_data, const dnn::BatchDescriptor &output_dimensions, - DeviceMemory *output_data) { + DeviceMemory *output_data, + ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), - PARAM(input_data), PARAM(output_dimensions), PARAM(output_data)); + PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), + PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions, - input_data, output_dimensions, - output_data)); + input_data, output_dimensions, output_data, + workspace_allocator)); } else { SetErrorAndLogNoDnnSupport(); } @@ -1444,16 +1456,19 @@ Stream &Stream::ThenPoolBackward( const dnn::BatchDescriptor &output_dimensions, const DeviceMemory &output_data, const DeviceMemory &input_diff_data, - DeviceMemory *output_diff_data) { + DeviceMemory *output_diff_data, + ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), - PARAM(input_diff_data), PARAM(output_diff_data)); + PARAM(input_diff_data), PARAM(output_diff_data), + PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions, input_data, output_dimensions, output_data, - input_diff_data, output_diff_data)); + input_diff_data, output_diff_data, + workspace_allocator)); } else { SetError(); LOG(WARNING) @@ -1471,16 +1486,19 @@ Stream &Stream::ThenPoolBackward( const dnn::BatchDescriptor &output_dimensions, const DeviceMemory &output_data, const DeviceMemory &input_diff_data, - DeviceMemory *output_diff_data) { + DeviceMemory *output_diff_data, + ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), - PARAM(input_diff_data), PARAM(output_diff_data)); + PARAM(input_diff_data), PARAM(output_diff_data), + PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions, input_data, output_dimensions, output_data, - input_diff_data, output_diff_data)); + input_diff_data, output_diff_data, + workspace_allocator)); } else { SetErrorAndLogNoDnnSupport(); } @@ -1495,16 +1513,19 @@ Stream &Stream::ThenPoolBackward( const dnn::BatchDescriptor &output_dimensions, const DeviceMemory &output_data, const DeviceMemory &input_diff_data, - DeviceMemory *output_diff_data) { + DeviceMemory *output_diff_data, + ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), - PARAM(input_diff_data), PARAM(output_diff_data)); + PARAM(input_diff_data), PARAM(output_diff_data), + PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions, input_data, output_dimensions, output_data, - input_diff_data, output_diff_data)); + input_diff_data, output_diff_data, + workspace_allocator)); } else { SetErrorAndLogNoDnnSupport(); } @@ -1551,16 +1572,18 @@ Stream &Stream::ThenNormalizeBackwardWithDimensions( const dnn::BatchDescriptor &dimensions, const DeviceMemory &raw_data, const DeviceMemory &normalized_data, const DeviceMemory &normalized_variable_gradient, - DeviceMemory *raw_variable_gradient) { + DeviceMemory *raw_variable_gradient, + ScratchAllocator *workspace_allocator) { VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data), PARAM(normalized_data), PARAM(normalized_variable_gradient), - PARAM(raw_variable_gradient)); + PARAM(raw_variable_gradient), PARAM(workspace_allocator)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoNormalizeBackwardWithDimensions( this, normalize_descriptor, dimensions, raw_data, normalized_data, - normalized_variable_gradient, raw_variable_gradient)); + normalized_variable_gradient, raw_variable_gradient, + workspace_allocator)); } else { SetErrorAndLogNoDnnSupport(); } @@ -1900,30 +1923,82 @@ Stream &Stream::ThenCopyDevice2HostBuffer( Stream *Stream::GetOrCreateSubStream() { mutex_lock lock(mu_); - for (auto &stream : sub_streams_) { - if (stream.second) { - stream.second = false; - return stream.first.get(); + + // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams + // we encounter along the way. + for (int64 index = 0; index < sub_streams_.size();) { + std::pair, bool> &pair = sub_streams_[index]; + if (pair.second) { + // The sub_stream is reusable. + Stream *sub_stream = pair.first.get(); + if (sub_stream->ok()) { + VLOG(1) << DebugStreamPointers() << " reusing sub_stream " + << sub_stream->DebugStreamPointers(); + pair.second = false; + return sub_stream; + } + + // The stream is reusable and not ok. Streams have a monotonic state + // machine; the stream will remain in !ok forever. Swap it with the last + // stream and pop it off. + const int64 last = sub_streams_.size() - 1; + if (index != last) { + std::swap(pair, sub_streams_[last]); + } + sub_streams_.pop_back(); + VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream " + << sub_stream->DebugStreamPointers(); + } else { + // The sub_stream is not reusable, move on to the next one. + ++index; } } + + // No streams are reusable; create a new stream. sub_streams_.emplace_back(std::unique_ptr{new Stream{parent_}}, false); Stream *sub_stream = sub_streams_.back().first.get(); sub_stream->Init(); CHECK(ok_) << "sub-stream failed to be initialized"; + VLOG(1) << DebugStreamPointers() << " created new sub_stream " + << sub_stream->DebugStreamPointers(); return sub_stream; } void Stream::ReturnSubStream(Stream *sub_stream) { mutex_lock lock(mu_); - for (auto &stream : sub_streams_) { - if (stream.first.get() == sub_stream) { - stream.second = true; - return; + + // Look for the sub-stream. + for (int64 index = 0; index < sub_streams_.size(); ++index) { + std::pair, bool> &pair = sub_streams_[index]; + if (pair.first.get() != sub_stream) { + continue; + } + + // Found the sub_stream. + if (sub_stream->ok()) { + VLOG(1) << DebugStreamPointers() << " returned ok sub_stream " + << sub_stream->DebugStreamPointers(); + pair.second = true; + } else { + // The returned stream is not ok. Streams have a monotonic state + // machine; the stream will remain in !ok forever. Swap it with the last + // stream and pop it off. + VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream " + << sub_stream->DebugStreamPointers(); + const int64 last = sub_streams_.size() - 1; + if (index != last) { + std::swap(pair, sub_streams_[last]); + } + sub_streams_.pop_back(); } + return; } - LOG(FATAL) << "the sub-stream to be returned is not created by this stream"; + + LOG(FATAL) << DebugStreamPointers() + << " did not create the returned sub-stream " + << sub_stream->DebugStreamPointers(); } Stream &Stream::ThenStartTimer(Timer *t) { @@ -1932,7 +2007,8 @@ Stream &Stream::ThenStartTimer(Timer *t) { if (ok()) { CheckError(parent_->StartTimer(this, t)); } else { - LOG(INFO) << "stream " << this << " did not enqueue 'start timer': " << t; + LOG(INFO) << DebugStreamPointers() + << " did not enqueue 'start timer': " << t; } return *this; } @@ -1943,7 +2019,8 @@ Stream &Stream::ThenStopTimer(Timer *t) { if (ok()) { CheckError(parent_->StopTimer(this, t)); } else { - LOG(INFO) << "stream " << this << " did not enqueue 'stop timer': " << t; + LOG(INFO) << DebugStreamPointers() + << " did not enqueue 'stop timer': " << t; } return *this; } @@ -1956,7 +2033,8 @@ Stream &Stream::ThenWaitFor(Stream *other) { CheckError(parent_->CreateStreamDependency(this, other)); } else { SetError(); - LOG(INFO) << "stream " << this << " did not wait for stream: " << other; + LOG(INFO) << DebugStreamPointers() << " did not wait for " + << other->DebugStreamPointers(); } return *this; } @@ -1973,7 +2051,7 @@ Stream &Stream::ThenWaitFor(Event *event) { << "at fault. Monitor for further errors."; } } else { - LOG(INFO) << "stream " << this << " did not wait for an event."; + LOG(INFO) << DebugStreamPointers() << " did not wait for an event."; } return *this; } @@ -4656,6 +4734,115 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( scratch_allocator); } +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory &a, int lda, + int64 stride_a, const DeviceMemory &b, int ldb, int64 stride_b, + float beta, DeviceMemory *c, int ldc, int64 stride_c, + int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl &, int, int64, + const DeviceMemory &, int, int64, float, + DeviceMemory *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory &a, int lda, + int64 stride_a, const DeviceMemory &b, int ldb, int64 stride_b, + float beta, DeviceMemory *c, int ldc, int64 stride_c, + int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl &, int, int64, + const DeviceMemory &, int, int64, float, + DeviceMemory *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, double alpha, const DeviceMemory &a, int lda, + int64 stride_a, const DeviceMemory &b, int ldb, int64 stride_b, + double beta, DeviceMemory *c, int ldc, int64 stride_c, + int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl &, int, int64, + const DeviceMemory &, int, int64, double, + DeviceMemory *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex alpha, + const DeviceMemory> &a, int lda, int64 stride_a, + const DeviceMemory> &b, int ldb, int64 stride_b, + std::complex beta, DeviceMemory> *c, int ldc, + int64 stride_c, int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl, const DeviceMemory> &, + int, int64, const DeviceMemory> &, int, + int64, std::complex, DeviceMemory> *, + int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + +Stream &Stream::ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex alpha, + const DeviceMemory> &a, int lda, int64 stride_a, + const DeviceMemory> &b, int ldb, int64 stride_b, + std::complex beta, DeviceMemory> *c, int ldc, + int64 stride_c, int batch_count) { + VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), + PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b), + PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc), + PARAM(stride_c), PARAM(batch_count)); + + ThenBlasImpl, const DeviceMemory> &, + int, int64, const DeviceMemory> &, int, + int64, std::complex, + DeviceMemory> *, int, int64, int> + impl; + return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa, + transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_count); +} + Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) { VLOG_CALL(PARAM(seed), PARAM(seed_bytes)); @@ -4664,10 +4851,10 @@ Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) { CheckError(rng->SetSeed(this, seed, seed_bytes)); } else { SetError(); - LOG(INFO) << "stream " << this << " unable to initialize RNG"; + LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG"; } } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not set RNG seed: " << static_cast(seed) << "; bytes: " << seed_bytes; } @@ -4682,8 +4869,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory *values) { CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4698,8 +4886,9 @@ Stream &Stream::ThenPopulateRandGaussian(float mean, float sd, CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4714,8 +4903,9 @@ Stream &Stream::ThenPopulateRandGaussian(double mean, double sd, CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4729,8 +4919,9 @@ Stream &Stream::ThenPopulateRandUniform(DeviceMemory *values) { CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4745,8 +4936,9 @@ Stream &Stream::ThenPopulateRandUniform( CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4761,9 +4953,9 @@ Stream &Stream::ThenPopulateRandUniform( CheckError(rng->DoPopulateRandUniform(this, values)); } else { SetError(); - LOG(INFO) << "stream " << this - << " attempting to perform RNG operation using StreamExecutor " - "without RNG support."; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform RNG operation using StreamExecutor" + " without RNG support."; } } return *this; @@ -4776,7 +4968,7 @@ Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src, if (ok()) { CheckError(parent_->Memcpy(this, host_dst, gpu_src, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memcpy device-to-host; source: " << gpu_src.opaque(); } return *this; @@ -4789,7 +4981,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src, if (ok()) { CheckError(parent_->Memcpy(this, gpu_dst, host_src, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memcpy host-to-device; source: " << host_src; } return *this; @@ -4802,7 +4994,7 @@ Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, if (ok()) { CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memcpy gpu-to-gpu; source: " << &gpu_src; } return *this; @@ -4814,7 +5006,7 @@ Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) { if (ok()) { CheckError(parent_->MemZero(this, location, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memzero GPU location; source: " << location; } return *this; @@ -4827,7 +5019,7 @@ Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern, if (ok()) { CheckError(parent_->Memset32(this, location, pattern, size)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " did not memset GPU location; source: " << location << "; size: " << size << "; pattern: " << std::hex << pattern; } @@ -5096,12 +5288,25 @@ Stream &Stream::ThenDoHostCallback(std::function callback) { if (ok()) { CheckError(parent_->HostCallback(this, callback)); } else { - LOG(INFO) << "stream " << this + LOG(INFO) << DebugStreamPointers() << " was in error state before adding host callback"; } return *this; } +Stream &Stream::ThenDoHostCallbackWithStatus( + std::function callback) { + VLOG_CALL(PARAM(callback)); + + if (ok()) { + CheckError(parent_->HostCallback(this, std::move(callback))); + } else { + LOG(WARNING) << "stream " << DebugStreamPointers() + << " was in error state before adding host callback"; + } + return *this; +} + Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory> &input, DeviceMemory> *output) { @@ -5112,8 +5317,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5129,8 +5335,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5145,8 +5352,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory &input, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5161,8 +5369,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory &input, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5178,8 +5387,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5195,8 +5405,9 @@ Stream &Stream::ThenFft(fft::Plan *plan, CheckError(fft->DoFft(this, plan, input, output)); } else { SetError(); - LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " - "without FFT support"; + LOG(INFO) << DebugStreamPointers() + << " attempting to perform FFT operation using StreamExecutor" + " without FFT support"; } } return *this; @@ -5223,28 +5434,21 @@ port::Status Stream::BlockHostUntilDone() { port::Status status = port::Status( port::error::INTERNAL, "stream did not block host until done; was already in an error state"); - LOG(INFO) << status << " " << this; + LOG(INFO) << DebugStreamPointers() << " " << status; return status; } - port::Status first_error; - { - // Wait until all active sub-streams have done their tasks. - mutex_lock lock(mu_); - for (auto &stream : sub_streams_) { - if (!stream.second) { - first_error.Update(stream.first->BlockHostUntilDone()); - // Set this sub-stream as available. - stream.second = true; - } - } - } - temporary_memory_manager_.DeallocateFinalizedTemporaries(); - first_error.Update(parent_->BlockHostUntilDone(this)); - CheckError(first_error.ok()); - return first_error; + port::Status error = parent_->BlockHostUntilDone(this); + CheckError(error.ok()); + return error; +} + +string Stream::DebugStreamPointers() const { + // Relies on the ToVlogString(const void*) overload above. + return port::StrCat("[stream=", ToVlogString(this), + ",impl=", ToVlogString(implementation_.get()), "]"); } } // namespace stream_executor diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index 3da1b856d6a41fa0c8d5a77feac33932da392422..e1629b5b3084e6641bcdf80d1de00f33f1c81940 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "tensorflow/core/platform/macros.h" #include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/dnn.h" @@ -121,10 +122,14 @@ class Stream { // Get or create a sub-stream from this stream. If there is any sub-stream in // the pool that can be reused then just return this sub-stream. Otherwise // create a new sub-stream. + // + // TODO(b/112196569): The semantics of failed sub-streams is error-prone. Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_); // Return the sub-stream back to the host stream so that it can be reused - // later. + // later. Sub-streams that are !ok() will not be reused. + // + // TODO(b/112196569): The semantics of failed sub-streams is error-prone. void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_); // Allocate temporary memories. The stream will deallocate them when blocked @@ -156,14 +161,13 @@ class Stream { const TypedKernel &kernel, Args... args); // Record a "start" event for the interval timer at this point in the - // stream's - // execution (relative to the previously and subsequently enqueued items in - // the stream's execution). Streams may be started/stopped multiple times. + // stream's execution (relative to the previously and subsequently enqueued + // items in the stream's execution). Streams may be started/stopped multiple + // times. Stream &ThenStartTimer(Timer *t); // Record a "stop" event for the interval timer at this point in the - // stream's - // execution. See also Stream::ThenStartTimer. + // stream's execution. See also Stream::ThenStartTimer. Stream &ThenStopTimer(Timer *t); // TODO(leary) If work is added to the stream that is being depended upon, @@ -179,8 +183,7 @@ class Stream { // // Checks that a stream does not wait for itself, and it is up to the // user to guarantee that a stream does not come to wait on itself in a - // cyclic - // manner; in that case, behavior is undefined. + // cyclic manner; in that case, behavior is undefined. // // N.B. Base recursion case for the variadic ThenWaitFor. Stream &ThenWaitFor(Stream *other); @@ -630,19 +633,22 @@ class Stream { const dnn::BatchDescriptor &input_dimensions, const DeviceMemory &input_data, const dnn::BatchDescriptor &output_dimensions, - DeviceMemory *output_data); + DeviceMemory *output_data, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, const DeviceMemory &input_data, const dnn::BatchDescriptor &output_dimensions, - DeviceMemory *output_data); + DeviceMemory *output_data, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, const DeviceMemory &input_data, const dnn::BatchDescriptor &output_dimensions, - DeviceMemory *output_data); + DeviceMemory *output_data, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, @@ -650,7 +656,8 @@ class Stream { const dnn::BatchDescriptor &output_dimensions, const DeviceMemory &output_data, const DeviceMemory &input_diff_data, - DeviceMemory *output_diff_data); + DeviceMemory *output_diff_data, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, @@ -658,7 +665,8 @@ class Stream { const dnn::BatchDescriptor &output_dimensions, const DeviceMemory &output_data, const DeviceMemory &input_diff_data, - DeviceMemory *output_diff_data); + DeviceMemory *output_diff_data, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, @@ -666,7 +674,8 @@ class Stream { const dnn::BatchDescriptor &output_dimensions, const DeviceMemory &output_data, const DeviceMemory &input_diff_data, - DeviceMemory *output_diff_data); + DeviceMemory *output_diff_data, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenNormalize(const dnn::NormalizeDescriptor &normalize_descriptor, const DeviceMemory &input_data, @@ -685,7 +694,8 @@ class Stream { const DeviceMemory &raw_data, const DeviceMemory &normalized_data, const DeviceMemory &normalized_variable_gradient, - DeviceMemory *raw_variable_gradient); + DeviceMemory *raw_variable_gradient, + ScratchAllocator *workspace_allocator = nullptr); Stream &ThenActivate(dnn::ActivationMode activation_mode, const dnn::BatchDescriptor &dimensions, @@ -1351,33 +1361,39 @@ class Stream { DeviceMemory> *x, int incx); // See BlasSupport::DoBlasGemm. - Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, - uint64 n, uint64 k, float alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, float beta, - DeviceMemory *c, int ldc); - Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, - uint64 n, uint64 k, float alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, float beta, - DeviceMemory *c, int ldc); - Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, - uint64 n, uint64 k, double alpha, - const DeviceMemory &a, int lda, - const DeviceMemory &b, int ldb, double beta, - DeviceMemory *c, int ldc); - Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, - uint64 n, uint64 k, std::complex alpha, - const DeviceMemory> &a, int lda, - const DeviceMemory> &b, int ldb, - std::complex beta, - DeviceMemory> *c, int ldc); - Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, - uint64 n, uint64 k, std::complex alpha, - const DeviceMemory> &a, int lda, - const DeviceMemory> &b, int ldb, - std::complex beta, - DeviceMemory> *c, int ldc); + TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64 m, uint64 n, uint64 k, float alpha, + const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, + float beta, DeviceMemory *c, + int ldc); + TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64 m, uint64 n, uint64 k, float alpha, + const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, + float beta, DeviceMemory *c, int ldc); + TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64 m, uint64 n, uint64 k, double alpha, + const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, + double beta, DeviceMemory *c, int ldc); + TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64 m, uint64 n, uint64 k, + std::complex alpha, + const DeviceMemory> &a, + int lda, + const DeviceMemory> &b, + int ldb, std::complex beta, + DeviceMemory> *c, int ldc); + TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, + uint64 m, uint64 n, uint64 k, + std::complex alpha, + const DeviceMemory> &a, + int lda, + const DeviceMemory> &b, + int ldb, std::complex beta, + DeviceMemory> *c, + int ldc); Stream &ThenBlasGemmWithProfiling(blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, @@ -1545,6 +1561,38 @@ class Stream { std::complex beta, const port::ArraySlice> *> &c, int ldc, int batch_count, ScratchAllocator *scratch_allocator); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory &a, int lda, + int64 stride_a, const DeviceMemory &b, int ldb, + int64 stride_b, float beta, DeviceMemory *c, int ldc, + int64 stride_c, int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, float alpha, const DeviceMemory &a, int lda, + int64 stride_a, const DeviceMemory &b, int ldb, int64 stride_b, + float beta, DeviceMemory *c, int ldc, int64 stride_c, + int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, double alpha, const DeviceMemory &a, int lda, + int64 stride_a, const DeviceMemory &b, int ldb, int64 stride_b, + double beta, DeviceMemory *c, int ldc, int64 stride_c, + int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex alpha, + const DeviceMemory> &a, int lda, int64 stride_a, + const DeviceMemory> &b, int ldb, int64 stride_b, + std::complex beta, DeviceMemory> *c, int ldc, + int64 stride_c, int batch_count); + Stream &ThenBlasGemmStridedBatched( + blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, + uint64 k, std::complex alpha, + const DeviceMemory> &a, int lda, int64 stride_a, + const DeviceMemory> &b, int ldb, int64 stride_b, + std::complex beta, DeviceMemory> *c, int ldc, + int64 stride_c, int batch_count); // See BlasSupport::DoBlasHemm. Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m, @@ -1997,6 +2045,11 @@ class Stream { // negative effects on performance. Stream &ThenDoHostCallback(std::function callback); + // Entrains onto the stream a callback to the host (from the device). + // Behaves as ThenDoHostCallback above, but returns a Status instead of void. + // This overload should be preferred if the callback could fail. + Stream &ThenDoHostCallbackWithStatus(std::function callback); + // Returns the StreamExecutor (parent object) associated with this stream. StreamExecutor *parent() const { CHECK(parent_ != nullptr); @@ -2007,6 +2060,9 @@ class Stream { // with this stream. internal::TemporaryMemoryManager *temporary_memory_manager(); + // Returns a debugging string "[stream=0x...,impl=0x...]". + string DebugStreamPointers() const; + private: friend class host::HostBlas; // for parent_. friend class host::HostFft; // for parent_. diff --git a/tensorflow/stream_executor/stream_executor_internal.cc b/tensorflow/stream_executor/stream_executor_internal.cc index 8297228e6fecddffa8fc68a1a028456dc8e75a65..7df6a361c6810b9a15c97f15704435d145dccb8e 100644 --- a/tensorflow/stream_executor/stream_executor_internal.cc +++ b/tensorflow/stream_executor/stream_executor_internal.cc @@ -36,5 +36,17 @@ StreamExecutorFactory* MakeOpenCLExecutorImplementation() { StreamExecutorFactory MakeHostExecutorImplementation; +// TODO(b/112125301): Consolodate this down to one implementation of +// HostCallback, taking a callback that returns a Status. +bool StreamExecutorInterface::HostCallback( + Stream* stream, std::function callback) { + return HostCallback(stream, [callback]() { + port::Status s = callback(); + if (!s.ok()) { + LOG(WARNING) << "HostCallback failed: " << s; + } + }); +} + } // namespace internal } // namespace stream_executor diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h index 9c989b971dcee6dd99aa155cd2230ba849d204fe..59a477b5c9c37f10d8f12645deb3cdb832a8d544 100644 --- a/tensorflow/stream_executor/stream_executor_internal.h +++ b/tensorflow/stream_executor/stream_executor_internal.h @@ -36,20 +36,38 @@ limitations under the License. #include "tensorflow/stream_executor/kernel_cache_config.h" #include "tensorflow/stream_executor/kernel_spec.h" #include "tensorflow/stream_executor/launch_dim.h" +#include "tensorflow/stream_executor/lib/inlined_vector.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/statusor.h" +#include "tensorflow/stream_executor/module_spec.h" #include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/plugin_registry.h" #include "tensorflow/stream_executor/shared_memory_config.h" #include "tensorflow/stream_executor/trace_listener.h" -#include "tensorflow/stream_executor/lib/inlined_vector.h" namespace stream_executor { class Stream; class Timer; +// An opaque handle to a loaded module. +// +// An instance of this is returned from StreamExecutor::GetModule. +class ModuleHandle { + public: + /*implicit*/ ModuleHandle(void *id = nullptr) : id_(id) {} + + // A ModuleHandle with id() == nullptr is an invalid module handle, akin to a + // null pointer. + void *id() const { return id_; } + + explicit operator bool() const { return id() != nullptr; } + + private: + void *id_; +}; + namespace internal { // Platform-dependent interface class for the generic Events interface, in @@ -100,19 +118,20 @@ class StreamInterface { // Default destructor for the abstract interface. virtual ~StreamInterface() {} - // Returns the CUDA stream associated with this platform's stream + // Returns the GPU stream associated with this platform's stream // implementation. // - // WARNING: checks that the underlying platform is, in fact, CUDA, causing a - // fatal error if it is not. This hack is made available solely for use from - // distbelief code, which temporarily has strong ties to CUDA as a platform. - virtual void *CudaStreamHack() { return nullptr; } - - // See the above comment on CudaStreamHack -- this further breaks abstraction - // for Eigen within distbelief, which has strong ties to CUDA as a platform, - // and a historical attachment to a programming model which takes a + // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm, + // causing a fatal error if it is not. This hack is made available solely for + // use from distbelief code, which temporarily has strong ties to CUDA or + // ROCm as a platform. + virtual void *GpuStreamHack() { return nullptr; } + + // See the above comment on GpuStreamHack -- this further breaks abstraction + // for Eigen within distbelief, which has strong ties to CUDA or ROCm as a + // platform, and a historical attachment to a programming model which takes a // stream-slot rather than a stream-value. - virtual void **CudaStreamMemberHack() { return nullptr; } + virtual void **GpuStreamMemberHack() { return nullptr; } private: SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface); @@ -163,6 +182,11 @@ class StreamExecutorInterface { KernelBase *kernel) { return false; } + virtual bool LoadModule(const MultiModuleLoaderSpec &spec, + ModuleHandle *module_handle) { + return false; + } + virtual bool UnloadModule(ModuleHandle module_handle) { return false; } virtual bool Launch(Stream *stream, const ThreadDim &thread_dims, const BlockDim &block_dims, const KernelBase &k, const KernelArgsArrayBase &args) { @@ -212,9 +236,11 @@ class StreamExecutorInterface { virtual bool Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst, const void *host_src, uint64 size) = 0; virtual bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *gpu_dst, - const DeviceMemoryBase &host_src, + const DeviceMemoryBase &gpu_src, uint64 size) = 0; virtual bool HostCallback(Stream *stream, std::function callback) = 0; + virtual bool HostCallback(Stream *stream, + std::function callback); virtual port::Status AllocateEvent(Event *event) = 0; virtual port::Status DeallocateEvent(Event *event) = 0; virtual port::Status RecordEvent(Stream *stream, Event *event) = 0; @@ -246,7 +272,12 @@ class StreamExecutorInterface { // null, however, both of them cannot be null at the same time. To use // constant memory in CUDA, GetSymbol has to be used. Returns true if symbol // is found. - virtual bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes) { + // + // If ModuleHandle is set then we search for `symbol_name` only within the + // module corresponding to `module_handle`. Otherwise all loaded modules are + // searched. + virtual bool GetSymbol(const string &symbol_name, ModuleHandle module_handle, + void **mem, size_t *bytes) { return false; } @@ -324,13 +355,14 @@ class StreamExecutorInterface { virtual std::unique_ptr GetStreamImplementation() = 0; virtual std::unique_ptr GetTimerImplementation() = 0; - // Returns the CUDA context associated with this StreamExecutor platform - // implementation. + // Returns the CUDA or ROCm context associated with this StreamExecutor + // platform implementation. // - // WARNING: checks that the underlying platform is, in fact, CUDA, causing a - // fatal error if it is not. This hack is made available solely for use from - // distbelief code, which temporarily has strong ties to CUDA as a platform. - virtual void *CudaContextHack() { return nullptr; } + // WARNING: checks that the underlying platform is, in fact, CUDA or ROCm, + // causing a fatal error if it is not. This hack is made available solely for + // use from distbelief code, which temporarily has strong ties to CUDA or ROCm + // as a platform. + virtual void *GpuContextHack() { return nullptr; } private: SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface); diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index 000795ff0048dddb0eb4a08956e6de6f5e336f28..9515d8e62a8ed809d88182bdf3fdb3ba536dd68c 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -220,6 +220,15 @@ void StreamExecutor::UnloadKernel(const KernelBase *kernel) { implementation_->UnloadKernel(kernel); } +bool StreamExecutor::LoadModule(const MultiModuleLoaderSpec &spec, + ModuleHandle *module_handle) { + return implementation_->LoadModule(spec, module_handle); +} + +bool StreamExecutor::UnloadModule(ModuleHandle module_handle) { + return implementation_->UnloadModule(module_handle); +} + void StreamExecutor::Deallocate(DeviceMemoryBase *mem) { VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque() << ") mem->size()=" << mem->size() << StackTraceIfVLOG10(); @@ -459,9 +468,34 @@ void *StreamExecutor::Allocate(uint64 size) { return buf; } -bool StreamExecutor::GetSymbol(const string &symbol_name, void **mem, +port::StatusOr StreamExecutor::GetUntypedSymbol( + const string &symbol_name, ModuleHandle module_handle) { + // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to + // be nullptr/0 for consistency with DeviceMemory semantics. + void *opaque = nullptr; + size_t bytes = 0; + if (GetSymbol(symbol_name, module_handle, &opaque, &bytes)) { + return DeviceMemoryBase(opaque, bytes); + } + + if (static_cast(module_handle)) { + return port::Status( + port::error::NOT_FOUND, + port::StrCat("Check if module containing symbol ", symbol_name, + " is loaded (module_handle = ", + reinterpret_cast(module_handle.id()), ")")); + } else { + return port::Status( + port::error::NOT_FOUND, + port::StrCat("Check if kernel using the symbol is loaded: ", + symbol_name)); + } +} + +bool StreamExecutor::GetSymbol(const string &symbol_name, + ModuleHandle module_handle, void **mem, size_t *bytes) { - return implementation_->GetSymbol(symbol_name, mem, bytes); + return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes); } void *StreamExecutor::UnifiedMemoryAllocate(uint64 bytes) { @@ -665,6 +699,11 @@ bool StreamExecutor::HostCallback(Stream *stream, return implementation_->HostCallback(stream, std::move(callback)); } +bool StreamExecutor::HostCallback(Stream *stream, + std::function callback) { + return implementation_->HostCallback(stream, std::move(callback)); +} + port::Status StreamExecutor::AllocateEvent(Event *event) { return implementation_->AllocateEvent(event); } diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index ad80a1ba259ce0c6e2785373cc986b8bf34f6460..437f29861670309424940f39f325a6aee2bbf897 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -106,6 +106,16 @@ class StreamExecutor { // Releases any state associated with the previously loaded kernel. void UnloadKernel(const KernelBase *kernel); + // Loads a module for the platform this StreamExecutor is acting upon. + // + // `spec` describes the module to be loaded. On success writes the handle for + // the loaded module to `module_handle` and returns true. Else returns false. + bool LoadModule(const MultiModuleLoaderSpec &spec, + ModuleHandle *module_handle); + + // Unloads the module with handle `module_handle`. + bool UnloadModule(ModuleHandle module_handle); + // Synchronously allocates an array on the device of type T with element_count // elements. template @@ -169,8 +179,16 @@ class StreamExecutor { // type of symbol and T match. // - Note: symbol_name should include its namespace as well. For example, // pass "nms0::symbol" if referring to nms0::symbol. + // + // If `module_handle` is set then searches only within the module + // corresponding to `module_handle`. template - port::StatusOr> GetSymbol(const string &symbol_name); + port::StatusOr> GetSymbol(const string &symbol_name, + ModuleHandle module_handle = {}); + + // An untyped version of GetSymbol. + port::StatusOr GetUntypedSymbol( + const string &symbol_name, ModuleHandle module_handle = {}); // Deallocate the DeviceMemory previously allocated via this interface. // Deallocation of a nullptr-representative value is permitted. @@ -507,7 +525,8 @@ class StreamExecutor { // Finds and retrieves device memory for the symbol on the underlying // platform. - bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes); + bool GetSymbol(const string &symbol_name, ModuleHandle module_handle, + void **mem, size_t *bytes); // Entrains a memcpy operation onto stream, with a host destination location // host_dst and a device memory source, with target size size. @@ -530,6 +549,11 @@ class StreamExecutor { // See Stream::ThenDoHostCallback for full details. bool HostCallback(Stream *stream, std::function callback); + // Entrains on a stream a user-specified function to be run on the host. + // See Stream::ThenDoHostCallback for full details. + // This is the preferred form for a callback that may return an error. + bool HostCallback(Stream *stream, std::function callback); + // Performs platform-specific allocation and initialization of an event. port::Status AllocateEvent(Event *event); @@ -678,6 +702,41 @@ class StreamExecutor { SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor); }; +// A wrapper around ModuleHandle that uses RAII to manage its lifetime. +class ScopedModuleHandle { + public: + explicit ScopedModuleHandle(StreamExecutor *executor, + ModuleHandle module_handle) + : executor_(executor), module_handle_(module_handle) {} + + ScopedModuleHandle(ScopedModuleHandle &&other) { + executor_ = other.executor_; + module_handle_ = other.module_handle_; + other.executor_ = nullptr; + other.module_handle_ = ModuleHandle(); + } + + ScopedModuleHandle &operator=(ScopedModuleHandle &&other) { + executor_ = other.executor_; + module_handle_ = other.module_handle_; + other.executor_ = nullptr; + other.module_handle_ = ModuleHandle(); + return *this; + } + + ~ScopedModuleHandle() { + if (static_cast(module_handle_)) { + CHECK(executor_->UnloadModule(module_handle_)); + } + } + + private: + StreamExecutor *executor_; + ModuleHandle module_handle_; + + TF_DISALLOW_COPY_AND_ASSIGN(ScopedModuleHandle); +}; + //////////// // Inlines @@ -690,19 +749,13 @@ inline DeviceMemory StreamExecutor::AllocateArray(uint64 element_count) { template inline port::StatusOr> StreamExecutor::GetSymbol( - const string &symbol_name) { - // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to - // be nullptr/0 for consistency with DeviceMemory semantics. - void *opaque = nullptr; - size_t bytes = 0; - if (GetSymbol(symbol_name, &opaque, &bytes)) { - CHECK_EQ(bytes % sizeof(T), 0); - return DeviceMemory::MakeFromByteSize(opaque, bytes); + const string &symbol_name, ModuleHandle module_handle) { + port::StatusOr untyped_symbol = + GetUntypedSymbol(symbol_name, module_handle); + if (!untyped_symbol.ok()) { + return untyped_symbol.status(); } - return port::Status( - port::error::NOT_FOUND, - port::StrCat("Check if kernel using the symbol is loaded: ", - symbol_name)); + return DeviceMemory(untyped_symbol.ValueOrDie()); } template diff --git a/tensorflow/stream_executor/stream_test.cc b/tensorflow/stream_executor/stream_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cfc051fd0950dc38034785397ce49340a11c89a9 --- /dev/null +++ b/tensorflow/stream_executor/stream_test.cc @@ -0,0 +1,203 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/stream_executor/stream_executor.h" + +#include "tensorflow/core/platform/test.h" + +namespace stream_executor { +namespace { + +class StreamTest : public ::testing::Test { + protected: + std::unique_ptr NewStreamExecutor() { + Platform* platform = + MultiPlatformManager::PlatformWithName("Host").ConsumeValueOrDie(); + StreamExecutorConfig config(/*ordinal=*/0); + return platform->GetUncachedExecutor(config).ConsumeValueOrDie(); + } +}; + +TEST_F(StreamTest, NoInitNotOk) { + std::unique_ptr executor = NewStreamExecutor(); + Stream stream(executor.get()); + EXPECT_FALSE(stream.ok()); +} + +TEST_F(StreamTest, InitOk) { + std::unique_ptr executor = NewStreamExecutor(); + Stream stream(executor.get()); + stream.Init(); + EXPECT_TRUE(stream.ok()); +} + +TEST_F(StreamTest, OneSubStream) { + std::unique_ptr executor = NewStreamExecutor(); + Stream stream(executor.get()); + stream.Init(); + EXPECT_TRUE(stream.ok()); + + // Get and return a sub-stream. Sub-streams are always initialized. + Stream* sub_stream1 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream1->ok()); + stream.ReturnSubStream(sub_stream1); + + // Get and return another sub-stream. + Stream* sub_stream2 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream2->ok()); + stream.ReturnSubStream(sub_stream1); + + // The underlying sub-streams should be the same, since sub_stream1 + // was returned before we tried to get sub_stream2. + EXPECT_EQ(sub_stream1, sub_stream2); +} + +TEST_F(StreamTest, TwoSubStreams) { + std::unique_ptr executor = NewStreamExecutor(); + Stream stream(executor.get()); + stream.Init(); + EXPECT_TRUE(stream.ok()); + + // Get two sub-streams. + Stream* sub_stream1 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream1->ok()); + Stream* sub_stream2 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream2->ok()); + + // The underlying sub-streams should be different, since neither + // sub-stream has been returned. + EXPECT_NE(sub_stream1, sub_stream2); + + // Return sub_stream1 and get sub_stream3, which should be the same. + stream.ReturnSubStream(sub_stream1); + Stream* sub_stream3 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream3->ok()); + EXPECT_EQ(sub_stream1, sub_stream3); + EXPECT_NE(sub_stream2, sub_stream3); + + // Return sub_stream2 and get sub_stream4, which should be the same. + stream.ReturnSubStream(sub_stream2); + Stream* sub_stream4 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream4->ok()); + EXPECT_EQ(sub_stream2, sub_stream4); + EXPECT_NE(sub_stream3, sub_stream4); +} + +TEST_F(StreamTest, FailedSubStreamBeforeReturnNotReused) { + std::unique_ptr executor = NewStreamExecutor(); + Stream stream(executor.get()); + stream.Init(); + EXPECT_TRUE(stream.ok()); + + // Get sub_stream1. + Stream* sub_stream1 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream1->ok()); + + // Force an error on sub_stream1; here we call a method that requires DNN + // support, which we know the Host platform doesn't support. + sub_stream1->ThenDepthConcatenate({}, {}, nullptr); + EXPECT_FALSE(sub_stream1->ok()); + + // Return sub_stream1 and get sub_stream2. + stream.ReturnSubStream(sub_stream1); + Stream* sub_stream2 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream2->ok()); + + // The underlying sub_streams should be different. They would have been the + // same, but since we forced an error on sub_stream1, it will not be + // re-used. Sadly we can't just check: + // EXPECT_NE(sub_stream1, sub_stream2); + // + // The above should hold logically, but it may fail if the new Stream instance + // allocated for sub_stream2 happens to reside in the same memory address as + // sub_stream1. + // + // The check that sub_stream2->ok() serves as a good-enough check. + + // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1 + // has no effect on these streams, and they are the same. + stream.ReturnSubStream(sub_stream2); + Stream* sub_stream3 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream3->ok()); + EXPECT_EQ(sub_stream2, sub_stream3); +} + +TEST_F(StreamTest, FailedSubStreamAfterReturnNotReused) { + std::unique_ptr executor = NewStreamExecutor(); + Stream stream(executor.get()); + stream.Init(); + EXPECT_TRUE(stream.ok()); + + // Get and return sub_stream1. + Stream* sub_stream1 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream1->ok()); + stream.ReturnSubStream(sub_stream1); + + // Force an error on sub_stream1; here we call a method that requires DNN + // support, which we know the Host platform doesn't support. + // + // It is a bit weird to use sub_stream1 after it has already been returned. By + // doing this, we're simulating an asynchronous error that occurs during + // execution of the sub_stream, that occurs after the sub_stream is returned. + // + // E.g. the following is a common pattern of usage, where the execution of the + // operations enqueued onto the sub streams may occur after the streams have + // already been returned. + // + // void EnqueueOnSubStreams(Stream* stream) { + // Stream* sub_stream1 = stream.GetOrCreateSubStream(); + // Stream* sub_stream2 = stream.GetOrCreateSubStream(); + // // ... enqueue some operations on the sub streams ... + // stream.ThenWaitFor(sub_stream1).ThenWaitFor(sub_stream2); + // stream.ReturnSubStream(sub_stream1); + // stream.ReturnSubStream(sub_stream2); + // } + // + // Stream* main_stream = ...; + // EnqueueOnSubStreams(main_stream); + // main_stream.BlockHostUntilDone(); + // + // TODO(b/112196569): The semantics of failed sub-streams is error-prone; + // GetOrCreateSubStream can still return a sub-stream that has not encountered + // an error yet, but will encounter one in the future, based on previously + // enqueued operations. + sub_stream1->ThenDepthConcatenate({}, {}, nullptr); + EXPECT_FALSE(sub_stream1->ok()); + + // Get and return sub_stream2. + Stream* sub_stream2 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream2->ok()); + + // The underlying streams should be different. They would have been the same, + // but since we forced an error on sub_stream1, it will not be re-used. Sadly + // we can't just check: + // EXPECT_NE(sub_stream1, sub_stream2); + // + // The above should hold logically, but it may fail if the new stream instance + // allocated for sub_stream2 happens to reside in the same memory address as + // sub_stream1. + // + // The check that sub_stream2->ok() serves as a good-enough check. + + // Return sub_stream2 and get sub_stream3. The previous error on sub_stream1 + // has no effect on these streams, and they are the same. + stream.ReturnSubStream(sub_stream2); + Stream* sub_stream3 = stream.GetOrCreateSubStream(); + EXPECT_TRUE(sub_stream3->ok()); + EXPECT_EQ(sub_stream2, sub_stream3); +} + +} // namespace +} // namespace stream_executor diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index b59f8e1f987567727ef3d4051618edd377d06f89..0b26c0d4451f611dba8d528905ea5670524cee53 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -9,6 +9,7 @@ load( "tf_additional_grpc_deps_py", "tf_additional_xla_deps_py", "if_static", + "if_dynamic_kernels", ) load( "@local_config_tensorrt//:build_defs.bzl", @@ -22,9 +23,14 @@ load( load( "//third_party/mkl:build_defs.bzl", "if_mkl", - "if_mkl_lnx_x64" + "if_mkl_lnx_x64", + "if_mkl_ml", + "mkl_deps", +) +load( + "//third_party/mkl_dnn:build_defs.bzl", + "if_mkl_open_source_only" ) - def register_extension_info(**kwargs): pass @@ -134,20 +140,32 @@ def if_not_mobile(a): "//conditions:default": a, }) +# Config setting selector used when building for products +# which requires restricted licenses to be avoided. +def if_not_lgpl_restricted(a): + _ = (a,) + return select({ + "//conditions:default": [], + }) + def if_not_windows(a): return select({ clean_dep("//tensorflow:windows"): [], - clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": a, }) def if_windows(a): return select({ clean_dep("//tensorflow:windows"): a, - clean_dep("//tensorflow:windows_msvc"): a, "//conditions:default": [], }) +def if_not_windows_cuda(a): + return select({ + clean_dep("//tensorflow:with_cuda_support_windows_override"): [], + "//conditions:default": a, + }) + def if_linux_x86_64(a): return select({ clean_dep("//tensorflow:linux_x86_64"): a, @@ -174,9 +192,13 @@ def get_win_copts(is_external=False): "/DEIGEN_AVOID_STL_ARRAY", "/Iexternal/gemmlowp", "/wd4018", # -Wno-sign-compare - "/U_HAS_EXCEPTIONS", - "/D_HAS_EXCEPTIONS=1", - "/EHsc", # -fno-exceptions + # Bazel's CROSSTOOL currently pass /EHsc to enable exception by + # default. We can't pass /EHs-c- to disable exception, otherwise + # we will get a waterfall of flag conflict warnings. Wait for + # Bazel to fix this. + # "/D_HAS_EXCEPTIONS=0", + # "/EHs-c-", + "/wd4577", "/DNOGDI", ] if is_external: @@ -208,6 +230,7 @@ def tf_copts(android_optimization_level_override="-O2", is_external=False): + if_cuda(["-DGOOGLE_CUDA=1"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"]) + if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML"]) + + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + if_mkl_lnx_x64(["-fopenmp"]) + if_android_arm(["-mfpu=neon"]) + if_linux_x86_64(["-msse3"]) @@ -220,8 +243,8 @@ def tf_copts(android_optimization_level_override="-O2", is_external=False): clean_dep("//tensorflow:android"): android_copts, clean_dep("//tensorflow:darwin"): [], clean_dep("//tensorflow:windows"): get_win_copts(is_external), - clean_dep("//tensorflow:windows_msvc"): get_win_copts(is_external), clean_dep("//tensorflow:ios"): ["-std=c++11"], + clean_dep("//tensorflow:no_lgpl_deps"): ["-D__TENSORFLOW_NO_LGPL_DEPS__", "-pthread"], "//conditions:default": ["-pthread"] })) @@ -241,6 +264,9 @@ def tf_opts_nortti_if_android(): # LINT.ThenChange(//tensorflow/contrib/android/cmake/CMakeLists.txt) +def tf_features_nomodules_if_android(): + return if_android(["-use_header_modules"]) + # Given a list of "op_lib_names" (a list of files in the ops directory # without their .cc extensions), generate a library for that file. def tf_gen_op_libs(op_lib_names, deps=None, is_external=True): @@ -277,7 +303,6 @@ def _rpath_linkopts(name): "-Wl,%s" % (_make_search_paths("@loader_path", levels_to_root),), ], clean_dep("//tensorflow:windows"): [], - clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": [ "-Wl,%s" % (_make_search_paths("$$ORIGIN", levels_to_root),), ], @@ -292,18 +317,36 @@ def tf_binary_additional_srcs(): clean_dep("//tensorflow:libtensorflow_framework.so"), ]) + +# Helper functions to add kernel dependencies to tf binaries when using dynamic +# kernel linking. +def tf_binary_dynamic_kernel_dsos(kernels): + return if_dynamic_kernels( + extra_deps=["libtfkernel_%s.so" % clean_dep(k) for k in kernels], + otherwise=[]) + +# Helper functions to add kernel dependencies to tf binaries when using static +# kernel linking. +def tf_binary_dynamic_kernel_deps(kernels): + return if_dynamic_kernels( + extra_deps=[], + otherwise=kernels) + def tf_cc_shared_object( name, srcs=[], deps=[], + data=[], linkopts=[], framework_so=tf_binary_additional_srcs(), + kernels=[], **kwargs): native.cc_binary( name=name, srcs=srcs + framework_so, - deps=deps, + deps=deps + tf_binary_dynamic_kernel_deps(kernels), linkshared = 1, + data = data + tf_binary_dynamic_kernel_dsos(kernels), linkopts=linkopts + _rpath_linkopts(name) + select({ clean_dep("//tensorflow:darwin"): [ "-Wl,-install_name,@rpath/" + name.split("/")[-1], @@ -327,18 +370,21 @@ register_extension_info( def tf_cc_binary(name, srcs=[], deps=[], + data=[], linkopts=[], copts=tf_copts(), + kernels=[], **kwargs): native.cc_binary( name=name, copts=copts, srcs=srcs + tf_binary_additional_srcs(), - deps=deps + if_mkl( + deps=deps + tf_binary_dynamic_kernel_deps(kernels) + if_mkl_ml( [ - "//third_party/mkl:intel_binary_blob", + "//third_party/intel_mkl_ml", ], ), + data=data + tf_binary_dynamic_kernel_dsos(kernels), linkopts=linkopts + _rpath_linkopts(name), **kwargs) @@ -523,9 +569,6 @@ def tf_gen_op_wrappers_cc(name, # is invalid to specify both "hidden" and "op_whitelist". # cc_linkopts: Optional linkopts to be added to tf_cc_binary that contains the # specified ops. -# gen_locally: if True, the genrule to generate the Python library will be run -# without sandboxing. This would help when the genrule depends on symlinks -# which may not be supported in the sandbox. def tf_gen_op_wrapper_py(name, out=None, hidden=None, @@ -536,8 +579,7 @@ def tf_gen_op_wrapper_py(name, generated_target_name=None, op_whitelist=[], cc_linkopts=[], - api_def_srcs=[], - gen_locally=False): + api_def_srcs=[]): if (hidden or hidden_file) and op_whitelist: fail('Cannot pass specify both hidden and op_whitelist.') @@ -592,7 +634,6 @@ def tf_gen_op_wrapper_py(name, outs=[out], srcs=api_def_srcs + [hidden_file], tools=[tool_name] + tf_binary_additional_srcs(), - local = (1 if gen_locally else 0), cmd=("$(location " + tool_name + ") " + api_def_args_str + " @$(location " + hidden_file + ") " + ("1" if require_shape_functions else "0") + " > $@")) @@ -602,7 +643,6 @@ def tf_gen_op_wrapper_py(name, outs=[out], srcs=api_def_srcs, tools=[tool_name] + tf_binary_additional_srcs(), - local = (1 if gen_locally else 0), cmd=("$(location " + tool_name + ") " + api_def_args_str + " " + op_list_arg + " " + ("1" if require_shape_functions else "0") + " " + @@ -632,11 +672,13 @@ def tf_gen_op_wrapper_py(name, def tf_cc_test(name, srcs, deps, + data=[], linkstatic=0, extra_copts=[], suffix="", linkopts=[], nocopts=None, + kernels=[], **kwargs): native.cc_test( name="%s%s" % (name, suffix), @@ -647,7 +689,6 @@ def tf_cc_test(name, "-pie", ], clean_dep("//tensorflow:windows"): [], - clean_dep("//tensorflow:windows_msvc"): [], clean_dep("//tensorflow:darwin"): [ "-lm", ], @@ -656,11 +697,12 @@ def tf_cc_test(name, "-lm" ], }) + linkopts + _rpath_linkopts(name), - deps=deps + if_mkl( + deps=deps + tf_binary_dynamic_kernel_deps(kernels) + if_mkl_ml( [ - "//third_party/mkl:intel_binary_blob", + "//third_party/intel_mkl_ml", ], ), + data=data + tf_binary_dynamic_kernel_dsos(kernels), # Nested select() statements seem not to be supported when passed to # linkstatic, and we already have a cuda select() passed in to this # function. @@ -761,6 +803,7 @@ def tf_cuda_only_cc_test(name, size="medium", linkstatic=0, args=[], + kernels=[], linkopts=[]): native.cc_test( name="%s%s" % (name, "_gpu"), @@ -768,8 +811,8 @@ def tf_cuda_only_cc_test(name, size=size, args=args, copts= _cuda_copts() + tf_copts(), - data=data, - deps=deps + if_cuda([ + data=data + tf_binary_dynamic_kernel_dsos(kernels), + deps=deps + tf_binary_dynamic_kernel_deps(kernels) + if_cuda([ clean_dep("//tensorflow/core:cuda"), clean_dep("//tensorflow/core:gpu_lib")]), linkopts=if_not_windows(["-lpthread", "-lm"]) + linkopts + _rpath_linkopts(name), @@ -812,10 +855,15 @@ def tf_cc_tests(srcs, def tf_cc_test_mkl(srcs, deps, name="", + data=[], linkstatic=0, tags=[], size="medium", + kernels=[], args=None): + # -fno-exceptions in nocopts breaks compilation if header modules are enabled. + disable_header_modules = ["-use_header_modules"] + for src in srcs: native.cc_test( name=src_to_test_name(src), @@ -826,21 +874,18 @@ def tf_cc_test_mkl(srcs, "-pie", ], clean_dep("//tensorflow:windows"): [], - clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": [ "-lpthread", "-lm" ], }) + _rpath_linkopts(src_to_test_name(src)), - deps=deps + if_mkl( - [ - "//third_party/mkl:intel_binary_blob", - ], - ), + deps=deps + tf_binary_dynamic_kernel_deps(kernels) + mkl_deps(), + data=data + tf_binary_dynamic_kernel_dsos(kernels), linkstatic=linkstatic, tags=tags, size=size, args=args, + features=disable_header_modules, nocopts="-fno-exceptions") @@ -875,12 +920,13 @@ def tf_cuda_cc_tests(srcs, def tf_java_test(name, srcs=[], deps=[], + kernels=[], *args, **kwargs): native.java_test( name=name, srcs=srcs, - deps=deps + tf_binary_additional_srcs(), + deps=deps + tf_binary_additional_srcs() + tf_binary_dynamic_kernel_dsos(kernels) + tf_binary_dynamic_kernel_deps(kernels), *args, **kwargs) @@ -919,6 +965,7 @@ def tf_gpu_kernel_library(srcs, hdrs=[], **kwargs): copts = copts + _cuda_copts() + if_cuda(cuda_copts) + tf_copts() + kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"] native.cc_library( srcs=srcs, @@ -959,12 +1006,14 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs): if not cuda_deps: cuda_deps = [] + kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"] native.cc_library( deps=deps + if_cuda(cuda_deps + [ clean_dep("//tensorflow/core:cuda"), "@local_config_cuda//cuda:cuda_headers" ]), copts=(copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]) + + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])), **kwargs) @@ -973,16 +1022,17 @@ register_extension_info( label_regex_for_dep = "{extension_name}", ) -def tf_kernel_library(name, - prefix=None, - srcs=None, - gpu_srcs=None, - hdrs=None, - deps=None, - alwayslink=1, - copts=None, - is_external=False, - **kwargs): +def tf_kernel_library( + name, + prefix = None, + srcs = None, + gpu_srcs = None, + hdrs = None, + deps = None, + alwayslink = 1, + copts = None, + is_external = False, + **kwargs): """A rule to build a TensorFlow OpKernel. May either specify srcs/hdrs or prefix. Similar to tf_cuda_library, @@ -1012,6 +1062,7 @@ def tf_kernel_library(name, deps = [] if not copts: copts = [] + textual_hdrs = [] copts = copts + tf_copts(is_external=is_external) if prefix: if native.glob([prefix + "*.cu.cc"], exclude=["*test*"]): @@ -1022,8 +1073,13 @@ def tf_kernel_library(name, srcs = srcs + native.glob( [prefix + "*.cc"], exclude=[prefix + "*test*", prefix + "*.cu.cc"]) hdrs = hdrs + native.glob( - [prefix + "*.h"], exclude=[prefix + "*test*", prefix + "*.cu.h"]) - + [prefix + "*.h"], + exclude = [prefix + "*test*", prefix + "*.cu.h", prefix + "*impl.h"], + ) + textual_hdrs = native.glob( + [prefix + "*impl.h"], + exclude = [prefix + "*test*", prefix + "*.cu.h"], + ) cuda_deps = [clean_dep("//tensorflow/core:gpu_lib")] if gpu_srcs: for gpu_src in gpu_srcs: @@ -1033,10 +1089,15 @@ def tf_kernel_library(name, tf_gpu_kernel_library( name=name + "_gpu", srcs=gpu_srcs, deps=deps, **kwargs) cuda_deps.extend([":" + name + "_gpu"]) + kwargs["tags"] = kwargs.get("tags", []) + [ + "req_dep=%s" % clean_dep("//tensorflow/core:gpu_lib"), + "req_dep=@local_config_cuda//cuda:cuda_headers", + ] tf_cuda_library( name=name, srcs=srcs, hdrs=hdrs, + textual_hdrs = textual_hdrs, copts=copts, cuda_deps=cuda_deps, linkstatic=1, # Needed since alwayslink is broken in bazel b/27630669 @@ -1044,6 +1105,15 @@ def tf_kernel_library(name, deps=deps, **kwargs) + # TODO(gunan): CUDA dependency not clear here. Fix it. + tf_cc_shared_object( + name="libtfkernel_%s.so" % name, + srcs=srcs + hdrs, + copts=copts, + deps=deps, + tags=["manual", "notap"]) + + register_extension_info( extension_name = "tf_kernel_library", label_regex_for_dep = "{extension_name}(_gpu)?", @@ -1070,6 +1140,9 @@ def tf_mkl_kernel_library(name, hdrs = hdrs + native.glob( [prefix + "*.h"]) + # -fno-exceptions in nocopts breaks compilation if header modules are enabled. + disable_header_modules = ["-use_header_modules"] + native.cc_library( name=name, srcs=if_mkl(srcs), @@ -1077,7 +1150,8 @@ def tf_mkl_kernel_library(name, deps=deps, alwayslink=alwayslink, copts=copts, - nocopts=nocopts + nocopts=nocopts, + features = disable_header_modules ) register_extension_info( @@ -1124,7 +1198,6 @@ _py_wrap_cc = rule( allow_files = True, ), "swig_includes": attr.label_list( - cfg = "data", allow_files = True, ), "deps": attr.label_list( @@ -1301,6 +1374,7 @@ def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[], linkopts=[]): name=basename + "_gpu", srcs=gpu_srcs, copts=_cuda_copts() + if_tensorrt(["-DGOOGLE_TENSORRT=1"]), + features = if_cuda(["-use_header_modules"]), deps=deps + if_cuda(cuda_deps)) cuda_deps.extend([":" + basename + "_gpu"]) @@ -1315,7 +1389,7 @@ def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[], linkopts=[]): name=name, srcs=srcs, deps=deps + if_cuda(cuda_deps), - data=[name + "_check_deps"], + data=if_static([name + "_check_deps"]), copts=tf_copts(is_external=True), features = ["windows_export_all_symbols"], linkopts=linkopts + select({ @@ -1323,7 +1397,6 @@ def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[], linkopts=[]): "-lm", ], clean_dep("//tensorflow:windows"): [], - clean_dep("//tensorflow:windows_msvc"): [], clean_dep("//tensorflow:darwin"): [], }),) @@ -1411,7 +1484,7 @@ def tf_py_wrap_cc(name, srcs=srcs, swig_includes=swig_includes, deps=deps + extra_deps, - toolchain_deps=["//tools/defaults:crosstool"], + toolchain_deps=["@bazel_tools//tools/cpp:current_cc_toolchain"], module_name=module_name, py_module_name=name) vscriptname=name+"_versionscript" @@ -1433,7 +1506,6 @@ def tf_py_wrap_cc(name, "$(location %s.lds)"%vscriptname, ], clean_dep("//tensorflow:windows"): [], - clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": [ "-Wl,--version-script", "$(location %s.lds)"%vscriptname, @@ -1444,7 +1516,6 @@ def tf_py_wrap_cc(name, "%s.lds"%vscriptname, ], clean_dep("//tensorflow:windows"): [], - clean_dep("//tensorflow:windows_msvc"): [], "//conditions:default": [ "%s.lds"%vscriptname, ] diff --git a/tensorflow/tf_framework_version_script.lds b/tensorflow/tf_framework_version_script.lds new file mode 100644 index 0000000000000000000000000000000000000000..d4977f88c0c340fa236b746efcefd607f4752359 --- /dev/null +++ b/tensorflow/tf_framework_version_script.lds @@ -0,0 +1,11 @@ +VERS_1.0 { + # Hide libjpeg symbols to avoid symbol conflict with OpenCV + local: + jpeg_*; + jinit_*; + jdiv_round_up; + jround_up; + jzero_far; + jcopy_*; + jsimd_*; +}; diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD deleted file mode 100644 index 3a28153e522a25b02858c7aaa75f2ade620479a1..0000000000000000000000000000000000000000 --- a/tensorflow/tools/api/generator/BUILD +++ /dev/null @@ -1,50 +0,0 @@ -# Description: -# Scripts used to generate TensorFlow Python API. - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -load("//tensorflow/tools/api/generator:api_gen.bzl", "TENSORFLOW_API_INIT_FILES") - -py_library( - name = "doc_srcs", - srcs = ["doc_srcs.py"], - srcs_version = "PY2AND3", -) - -py_binary( - name = "create_python_api", - srcs = ["create_python_api.py"], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [ - ":doc_srcs", - "//tensorflow/python:no_contrib", - ], -) - -py_test( - name = "create_python_api_test", - srcs = ["create_python_api_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":create_python_api", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "tensorflow_doc_srcs_test", - srcs = ["doc_srcs_test.py"], - args = [ - "--package=tensorflow.python", - ] + TENSORFLOW_API_INIT_FILES, - main = "doc_srcs_test.py", - srcs_version = "PY2AND3", - deps = [ - ":doc_srcs", - "//tensorflow/python:client_testlib", - "//tensorflow/python:no_contrib", - ], -) diff --git a/tensorflow/tools/api/golden/BUILD b/tensorflow/tools/api/golden/BUILD index ebdf42df2c01a60b1cadd0368647adc4121db7ef..1f041ef19362c427fe327658f36f9f15eb5ce17d 100644 --- a/tensorflow/tools/api/golden/BUILD +++ b/tensorflow/tools/api/golden/BUILD @@ -7,6 +7,11 @@ package( licenses(["notice"]) # Apache 2.0 filegroup( - name = "api_golden", - srcs = glob(["*.pbtxt"]), + name = "api_golden_v1", + srcs = glob(["v1/*.pbtxt"]), +) + +filegroup( + name = "api_golden_v2", + srcs = glob(["v1/*.pbtxt"]), ) diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt index 9e09a8d48ec7a501cb25a30163b5dae84b7c8655..eb41deee13de99d6e9534c32141096edc018ed1c 100644 --- a/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt @@ -8,5 +8,17 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_STRING } + field { + name: "client_handles_error_formatting" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "executor_type" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + } } } diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt index 4af4ed70ef0698e996905bcb3b2222380b8694d8..e565b903d22c3921743becbdd34f33a8850e84d5 100644 --- a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt @@ -131,6 +131,18 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_STRING } + field { + name: "client_handles_error_formatting" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "executor_type" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + } } } } diff --git a/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt index 1f9aeb6ad62e1030c6e78f731fb5e05b876899e6..4f0147a52381c748eccbfee29df0d3537ba5d14a 100644 --- a/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.data.Iterator" tf_class { is_instance: "" + is_instance: "" is_instance: "" member { name: "initializer" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt index 099838fa65f6a532a594c08e8a44ead8ce008185..c23b04b4ef85a290f055d35d0c7f0d4d8a18a2de 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\'], " + argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], " } member_method { name: "eval_dir" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt index 87bd19a23a3db727b5c1f13de04e3c11fd91de9b..6878d28fffabc895433f97415ee71cfe8f6232c1 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\'], " + argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], " } member_method { name: "eval_dir" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt index c8da55d8021b7659446d0771a089b7b605d86c4f..bf1f94b6aedfd02c15c4750bc00beb057fa8694a 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt @@ -10,6 +10,10 @@ tf_class { name: "device_fn" mtype: "" } + member { + name: "eval_distribute" + mtype: "" + } member { name: "evaluation_master" mtype: "" @@ -50,6 +54,10 @@ tf_class { name: "num_worker_replicas" mtype: "" } + member { + name: "protocol" + mtype: "" + } member { name: "save_checkpoints_secs" mtype: "" @@ -88,7 +96,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'\', \'\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'\', \'\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "replace" diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt index 10171b3d60f300962e05f14d6aa7e13f364d2bde..5c46dc5ee7dc04f57591d4883ec8eb034a34d2d0 100644 --- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt @@ -84,6 +84,10 @@ tf_module { name: "extract_glimpse" argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'None\'], " } + member_method { + name: "extract_image_patches" + argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "extract_jpeg_shape" argspec: "args=[\'contents\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " @@ -116,6 +120,14 @@ tf_module { name: "non_max_suppression" argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], " } + member_method { + name: "non_max_suppression_overlaps" + argspec: "args=[\'overlaps\', \'scores\', \'max_output_size\', \'overlap_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], " + } + member_method { + name: "non_max_suppression_padded" + argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'pad_to_max_output_size\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], " + } member_method { name: "pad_to_bounding_box" argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None" @@ -172,6 +184,10 @@ tf_module { name: "resize_image_with_crop_or_pad" argspec: "args=[\'image\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "resize_image_with_pad" + argspec: "args=[\'image\', \'target_height\', \'target_width\', \'method\'], varargs=None, keywords=None, defaults=[\'0\'], " + } member_method { name: "resize_images" argspec: "args=[\'images\', \'size\', \'method\', \'align_corners\', \'preserve_aspect_ratio\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt index 11cdd6f0b5e48f5835385fdd4e3e5144fb7d5166..e579fe6a1aeca296ac8ceb7b8ba951f250331eee 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt @@ -119,7 +119,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" @@ -135,7 +135,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt index 4afad3e4df308d412a1c18dea3b4e99aa1d2c84f..6f05cdd093d9f3061f1fd5dc74605ff476fd4040 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt @@ -124,7 +124,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" @@ -140,7 +140,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" @@ -266,6 +266,10 @@ tf_class { name: "summary" argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " } + member_method { + name: "symbolic_set_inputs" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "test_on_batch" argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt index 2cd83baf65cf4114e58f52cdc40de7e4b6df7554..2e9de9ebb21021ab82ed4409243e13db49d7327c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt @@ -22,7 +22,7 @@ tf_module { } member_method { name: "relu" - argspec: "args=[\'x\', \'alpha\', \'max_value\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\'], " + argspec: "args=[\'x\', \'alpha\', \'max_value\', \'threshold\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\', \'0\'], " } member_method { name: "selu" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt index 62aa929d32b57518abbe924c036062eb7ccd3acf..56914e1746b0429adc2570c6cb31ddc8f9a6535a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt @@ -119,7 +119,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" @@ -135,7 +135,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt index 93ecbbce9b17b9ca6157e65bbabd6c36008c3992..4c1c54001d5f29ee77889c1d54f3983c5fb05161 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt @@ -124,7 +124,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" @@ -140,7 +140,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" @@ -266,6 +266,10 @@ tf_class { name: "summary" argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " } + member_method { + name: "symbolic_set_inputs" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "test_on_batch" argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.manip.pbtxt b/tensorflow/tools/api/golden/tensorflow.manip.pbtxt deleted file mode 100644 index 0b84165285102daf0a8e3dd6542bfc391e50f77b..0000000000000000000000000000000000000000 --- a/tensorflow/tools/api/golden/tensorflow.manip.pbtxt +++ /dev/null @@ -1,7 +0,0 @@ -path: "tensorflow.manip" -tf_module { - member_method { - name: "roll" - argspec: "args=[\'input\', \'shift\', \'axis\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/tensorflow.math.pbtxt deleted file mode 100644 index 03fbf6266d24f6a31ef76e2668c15987e55d2651..0000000000000000000000000000000000000000 --- a/tensorflow/tools/api/golden/tensorflow.math.pbtxt +++ /dev/null @@ -1,23 +0,0 @@ -path: "tensorflow.math" -tf_module { - member_method { - name: "bessel_i0" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'bessel_i0\'], " - } - member_method { - name: "bessel_i0e" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "bessel_i1" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'bessel_i1\'], " - } - member_method { - name: "bessel_i1e" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "polyval" - argspec: "args=[\'coeffs\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } -} diff --git a/tensorflow/tools/api/golden/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/tensorflow.strings.pbtxt deleted file mode 100644 index b641c39feb6bcc4b5b73ba81ce0f0d4a499007ea..0000000000000000000000000000000000000000 --- a/tensorflow/tools/api/golden/tensorflow.strings.pbtxt +++ /dev/null @@ -1,11 +0,0 @@ -path: "tensorflow.strings" -tf_module { - member_method { - name: "regex_full_match" - argspec: "args=[\'input\', \'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "split" - argspec: "args=[\'source\', \'sep\', \'maxsplit\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\'], " - } -} diff --git a/tensorflow/tools/api/golden/tensorflow.-aggregation-method.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-aggregation-method.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-aggregation-method.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-aggregation-method.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-attr-value.-list-value.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-attr-value.-list-value.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-attr-value.-list-value.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-attr-value.-list-value.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-attr-value.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-attr-value.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-attr-value.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-attr-value.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-conditional-accumulator-base.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator-base.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-conditional-accumulator-base.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator-base.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-conditional-accumulator.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.-device-count-entry.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-device-count-entry.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-config-proto.-device-count-entry.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-device-count-entry.pbtxt diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..eb41deee13de99d6e9534c32141096edc018ed1c --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt @@ -0,0 +1,24 @@ +path: "tensorflow.ConfigProto.Experimental" +tf_proto { + descriptor { + name: "Experimental" + field { + name: "collective_group_leader" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "client_handles_error_formatting" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "executor_type" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e565b903d22c3921743becbdd34f33a8850e84d5 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt @@ -0,0 +1,148 @@ +path: "tensorflow.ConfigProto" +tf_proto { + descriptor { + name: "ConfigProto" + field { + name: "device_count" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.ConfigProto.DeviceCountEntry" + } + field { + name: "intra_op_parallelism_threads" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "inter_op_parallelism_threads" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "use_per_session_threads" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "session_inter_op_thread_pool" + number: 12 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.ThreadPoolOptionProto" + } + field { + name: "placement_period" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "device_filters" + number: 4 + label: LABEL_REPEATED + type: TYPE_STRING + } + field { + name: "gpu_options" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.GPUOptions" + } + field { + name: "allow_soft_placement" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "log_device_placement" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "graph_options" + number: 10 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.GraphOptions" + } + field { + name: "operation_timeout_in_ms" + number: 11 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "rpc_options" + number: 13 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.RPCOptions" + } + field { + name: "cluster_def" + number: 14 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.ClusterDef" + } + field { + name: "isolate_session_state" + number: 15 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "experimental" + number: 16 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.ConfigProto.Experimental" + } + nested_type { + name: "DeviceCountEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + options { + map_entry: true + } + } + nested_type { + name: "Experimental" + field { + name: "collective_group_leader" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "client_handles_error_formatting" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "executor_type" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-d-type.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-d-type.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-d-type.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-d-type.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-device-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-device-spec.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-device-spec.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-device-spec.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-dimension.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-dimension.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-dimension.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-dimension.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-event.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-event.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-event.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-event.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-f-i-f-o-queue.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-f-i-f-o-queue.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-f-i-f-o-queue.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-f-i-f-o-queue.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-fixed-len-feature.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-fixed-len-feature.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-fixed-len-feature.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-fixed-len-feature.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-fixed-len-sequence-feature.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-fixed-len-sequence-feature.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-fixed-len-sequence-feature.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-fixed-len-sequence-feature.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-fixed-length-record-reader.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-fixed-length-record-reader.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-g-p-u-options.pbtxt similarity index 92% rename from tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-g-p-u-options.pbtxt index f819b174c0b701153af4709fade9313efa7f7fb6..353e63127de174a79c209a05327da2de20bf0dd7 100644 --- a/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-g-p-u-options.pbtxt @@ -72,6 +72,12 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_BOOL } + field { + name: "num_dev_to_dev_copy_streams" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } nested_type { name: "VirtualDevices" field { diff --git a/tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-graph-def.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-graph-def.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-graph-def.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-graph-def.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-graph-keys.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-graph-keys.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-graph-keys.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-graph-keys.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-graph-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-graph-options.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-graph-options.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-graph-options.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-graph.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-graph.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-graph.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-graph.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-histogram-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-histogram-proto.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-histogram-proto.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-histogram-proto.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-identity-reader.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-identity-reader.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-identity-reader.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-identity-reader.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-indexed-slices.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-indexed-slices.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-indexed-slices.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-indexed-slices.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-interactive-session.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-interactive-session.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-interactive-session.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-interactive-session.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-l-m-d-b-reader.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-l-m-d-b-reader.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-l-m-d-b-reader.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-l-m-d-b-reader.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-log-message.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-log-message.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-log-message.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-log-message.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-meta-graph-def.-meta-info-def.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-meta-graph-def.-meta-info-def.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-meta-graph-def.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-meta-graph-def.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-meta-graph-def.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-name-attr-list.-attr-entry.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-name-attr-list.-attr-entry.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-name-attr-list.-attr-entry.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-name-attr-list.-attr-entry.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-name-attr-list.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-name-attr-list.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-name-attr-list.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-name-attr-list.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-node-def.-attr-entry.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-node-def.-attr-entry.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-node-def.-attr-entry.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-node-def.-attr-entry.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-node-def.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-node-def.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-node-def.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-node-def.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-op-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-op-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-op-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-op-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-operation.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-operation.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-operation.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-operation.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-optimizer-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-optimizer-options.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-optimizer-options.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-optimizer-options.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-padding-f-i-f-o-queue.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-padding-f-i-f-o-queue.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-padding-f-i-f-o-queue.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-padding-f-i-f-o-queue.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-priority-queue.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-priority-queue.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-priority-queue.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-priority-queue.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-queue-base.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-queue-base.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-queue-base.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-queue-base.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-random-shuffle-queue.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-random-shuffle-queue.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-random-shuffle-queue.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-random-shuffle-queue.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-reader-base.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-reader-base.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-reader-base.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-reader-base.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-register-gradient.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-register-gradient.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-register-gradient.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-register-gradient.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-run-metadata.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-run-metadata.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-run-metadata.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-run-metadata.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-run-options.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-run-options.-experimental.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-session-log.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-session-log.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-session-log.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-session-log.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-session.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-session.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-session.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-session.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-sparse-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-sparse-conditional-accumulator.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-sparse-feature.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-feature.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-sparse-feature.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-sparse-feature.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-sparse-tensor-value.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor-value.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-sparse-tensor-value.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor-value.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-sparse-tensor.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-sparse-tensor.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-summary-metadata.-plugin-data.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-summary-metadata.-plugin-data.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-summary-metadata.-plugin-data.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-summary-metadata.-plugin-data.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-summary-metadata.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-summary-metadata.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-summary-metadata.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-summary-metadata.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-summary.-audio.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-summary.-audio.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-summary.-audio.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-summary.-audio.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-summary.-image.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-summary.-image.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-summary.-image.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-summary.-image.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-summary.-value.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-summary.-value.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-summary.-value.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-summary.-value.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-summary.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-summary.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-summary.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-summary.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-t-f-record-reader.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-t-f-record-reader.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-t-f-record-reader.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-t-f-record-reader.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-tensor-array.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-array.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-tensor-array.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-tensor-array.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-tensor-info.-coo-sparse.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.-coo-sparse.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-tensor-info.-coo-sparse.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.-coo-sparse.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-tensor-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-tensor-info.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-tensor-info.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-tensor-shape.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-shape.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-tensor-shape.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-tensor-shape.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-tensor.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-tensor.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-tensor.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-text-line-reader.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-text-line-reader.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-text-line-reader.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-text-line-reader.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-var-len-feature.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-var-len-feature.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-var-len-feature.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-var-len-feature.pbtxt diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-variable-aggregation.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-variable-aggregation.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..36b534af360835e3c1cbd1f0fb12a38c42232abf --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.-variable-aggregation.pbtxt @@ -0,0 +1,16 @@ +path: "tensorflow.VariableAggregation" +tf_class { + is_instance: "" + member { + name: "MEAN" + mtype: "" + } + member { + name: "NONE" + mtype: "" + } + member { + name: "SUM" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-variable-scope.pbtxt similarity index 92% rename from tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-variable-scope.pbtxt index 8e539069da05fbb192c383d3f5acff78ab9bfeff..c13eb7b8bb9474f3534582c8af8c3ee4b6c7e076 100644 --- a/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-variable-scope.pbtxt @@ -56,7 +56,7 @@ tf_class { } member_method { name: "get_variable" - argspec: "args=[\'self\', \'var_store\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'reuse\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'var_store\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'reuse\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " } member_method { name: "global_variables" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-variable-synchronization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-variable-synchronization.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7589bb28888774839a3011e1e5581f004313f81d --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.-variable-synchronization.pbtxt @@ -0,0 +1,20 @@ +path: "tensorflow.VariableSynchronization" +tf_class { + is_instance: "" + member { + name: "AUTO" + mtype: "" + } + member { + name: "NONE" + mtype: "" + } + member { + name: "ON_READ" + mtype: "" + } + member { + name: "ON_WRITE" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.-variable.-save-slice-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-variable.-save-slice-info.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-variable.-save-slice-info.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-variable.-save-slice-info.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt similarity index 91% rename from tensorflow/tools/api/golden/tensorflow.-variable.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt index 23b552cc38488bdc15d7deed20f563379dba24c3..e841c4ad8904ae1ae49881b47648b901a4abf778 100644 --- a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt @@ -49,7 +49,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " } member_method { name: "assign" diff --git a/tensorflow/tools/api/golden/tensorflow.-whole-file-reader.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-whole-file-reader.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.-whole-file-reader.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.-whole-file-reader.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.app.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.app.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.app.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.app.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.bitwise.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.bitwise.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.bitwise.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.bitwise.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.compat.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.compat.pbtxt similarity index 76% rename from tensorflow/tools/api/golden/tensorflow.compat.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.compat.pbtxt index bab480ff9b105546790aadb72f3eb88a795ebbff..f1d760603e981a0b9a72fdc379dc81932ac71d67 100644 --- a/tensorflow/tools/api/golden/tensorflow.compat.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.compat.pbtxt @@ -32,6 +32,14 @@ tf_module { name: "as_text" argspec: "args=[\'bytes_or_text\', \'encoding\'], varargs=None, keywords=None, defaults=[\'utf-8\'], " } + member_method { + name: "forward_compatibility_horizon" + argspec: "args=[\'year\', \'month\', \'day\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "forward_compatible" + argspec: "args=[\'year\', \'month\', \'day\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "path_to_str" argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/tensorflow.constant_initializer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.constant_initializer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.constant_initializer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.constant_initializer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.data.-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.data.-dataset.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-iterator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-iterator.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4f0147a52381c748eccbfee29df0d3537ba5d14a --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-iterator.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.data.Iterator" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "initializer" + mtype: "" + } + member { + name: "output_classes" + mtype: "" + } + member { + name: "output_shapes" + mtype: "" + } + member { + name: "output_types" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'iterator_resource\', \'initializer\', \'output_types\', \'output_shapes\', \'output_classes\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_string_handle" + argspec: "args=[\'string_handle\', \'output_types\', \'output_shapes\', \'output_classes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "from_structure" + argspec: "args=[\'output_types\', \'output_shapes\', \'shared_name\', \'output_classes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "get_next" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "make_initializer" + argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "string_handle" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.data.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.data.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.data.pbtxt diff --git a/tensorflow/tools/api/golden/v1/tensorflow.debugging.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.debugging.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d9efe97821904f5891148b72a0c31e02c9562bd7 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.debugging.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.debugging" +tf_module { + member_method { + name: "check_numerics" + argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_finite" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_inf" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_nan" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-bernoulli.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distributions.-bernoulli.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.distributions.-bernoulli.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.distributions.-bernoulli.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-beta.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distributions.-beta.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.distributions.-beta.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.distributions.-beta.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-categorical.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distributions.-categorical.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.distributions.-categorical.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.distributions.-categorical.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-dirichlet-multinomial.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distributions.-dirichlet-multinomial.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.distributions.-dirichlet-multinomial.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.distributions.-dirichlet-multinomial.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-dirichlet.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distributions.-dirichlet.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.distributions.-dirichlet.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.distributions.-dirichlet.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-distribution.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distributions.-distribution.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.distributions.-distribution.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.distributions.-distribution.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-exponential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distributions.-exponential.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.distributions.-exponential.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.distributions.-exponential.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-gamma.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distributions.-gamma.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.distributions.-gamma.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.distributions.-gamma.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-laplace.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distributions.-laplace.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.distributions.-laplace.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.distributions.-laplace.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-multinomial.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distributions.-multinomial.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.distributions.-multinomial.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.distributions.-multinomial.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distributions.-normal.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.distributions.-normal.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.distributions.-normal.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-register-k-l.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distributions.-register-k-l.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.distributions.-register-k-l.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.distributions.-register-k-l.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-reparameterization-type.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distributions.-reparameterization-type.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.distributions.-reparameterization-type.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.distributions.-reparameterization-type.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-student-t.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distributions.-student-t.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.distributions.-student-t.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.distributions.-student-t.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.-uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distributions.-uniform.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.distributions.-uniform.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.distributions.-uniform.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.distributions.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distributions.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.distributions.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.distributions.pbtxt diff --git a/tensorflow/tools/api/golden/v1/tensorflow.dtypes.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.dtypes.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..98e1feed002ceb4f455aa5ec361d26a159fdad1a --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.dtypes.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.dtypes" +tf_module { + member_method { + name: "as_string" + argspec: "args=[\'input\', \'precision\', \'scientific\', \'shortest\', \'width\', \'fill\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'False\', \'False\', \'-1\', \'\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-aborted-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-aborted-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.-aborted-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.-aborted-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-already-exists-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-already-exists-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.-already-exists-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.-already-exists-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-cancelled-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-cancelled-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.-cancelled-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.-cancelled-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-data-loss-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-data-loss-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.-data-loss-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.-data-loss-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-deadline-exceeded-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-deadline-exceeded-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.-deadline-exceeded-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.-deadline-exceeded-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-failed-precondition-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-failed-precondition-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.-failed-precondition-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.-failed-precondition-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-internal-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-internal-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.-internal-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.-internal-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-invalid-argument-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-invalid-argument-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.-invalid-argument-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.-invalid-argument-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-not-found-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-not-found-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.-not-found-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.-not-found-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-op-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-op-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.-op-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.-op-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-out-of-range-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-out-of-range-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.-out-of-range-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.-out-of-range-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-permission-denied-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-permission-denied-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.-permission-denied-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.-permission-denied-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-resource-exhausted-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-resource-exhausted-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.-resource-exhausted-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.-resource-exhausted-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-unauthenticated-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-unauthenticated-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.-unauthenticated-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.-unauthenticated-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-unavailable-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-unavailable-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.-unavailable-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.-unavailable-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-unimplemented-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-unimplemented-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.-unimplemented-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.-unimplemented-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.errors.-unknown-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.-unknown-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.-unknown-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.-unknown-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.errors.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.errors.raise_exception_on_not_ok_status.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.errors.raise_exception_on_not_ok_status.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.errors.raise_exception_on_not_ok_status.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.errors.raise_exception_on_not_ok_status.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-classifier.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-classifier.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-regressor.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-regressor.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-best-exporter.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-best-exporter.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.-best-exporter.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-best-exporter.pbtxt diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c23b04b4ef85a290f055d35d0c7f0d4d8a18a2de --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.estimator.BoostedTreesClassifier" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], " + } + member_method { + name: "eval_dir" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6878d28fffabc895433f97415ee71cfe8f6232c1 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.estimator.BoostedTreesRegressor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], " + } + member_method { + name: "eval_dir" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-classifier.pbtxt similarity index 91% rename from tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-classifier.pbtxt index 111914f643a3b192d496c5b0857b4429da12b1d6..0c6b7e4a821ad47c20b6f6074b575bf83c403653 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-classifier.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], " + argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\', \'batch_norm\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\', \'False\'], " } member_method { name: "eval_dir" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt similarity index 88% rename from tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt index 67e4ee02d0581207e7dd316196aeb782930e7602..9c1c072124083006a1dd8e04526755dd980ba85a 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'\', \'None\', \'2\', \'None\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], " + argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\', \'batch_norm\', \'linear_sparse_combiner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'\', \'None\', \'2\', \'None\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\', \'False\', \'sum\'], " } member_method { name: "eval_dir" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt similarity index 89% rename from tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt index e1289b975e721e94f4a63889f3e0b76b0db23d81..7391d4b07a7e79541091b94fe4a9f38f42d6f68a 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'label_dimension\', \'weight_column\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'\', \'None\', \'1\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], " + argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'label_dimension\', \'weight_column\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\', \'batch_norm\', \'linear_sparse_combiner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'\', \'None\', \'1\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\', \'False\', \'sum\'], " } member_method { name: "eval_dir" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-regressor.pbtxt similarity index 89% rename from tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-regressor.pbtxt index d030b2f51f019ecc179a09b76c4484e60ada9dd0..f50e375f7cd392567f5c87536c95eb1f6809bc97 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-regressor.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], " + argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\', \'batch_norm\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\', \'False\'], " } member_method { name: "eval_dir" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-estimator-spec.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.-estimator-spec.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-estimator-spec.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-estimator.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-estimator.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-eval-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-eval-spec.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.-eval-spec.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-eval-spec.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-exporter.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-exporter.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.-exporter.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-exporter.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-final-exporter.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-final-exporter.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.-final-exporter.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-final-exporter.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-latest-exporter.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-latest-exporter.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.-latest-exporter.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-latest-exporter.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-classifier.pbtxt similarity index 92% rename from tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-classifier.pbtxt index cb578759eee2ed43465195a8c4e8760443a60b71..154f171e89571a43a3f905094a1dbd41cbb000d3 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-classifier.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\'], " + argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\', \'sparse_combiner\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\', \'sum\'], " } member_method { name: "eval_dir" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-regressor.pbtxt similarity index 90% rename from tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-regressor.pbtxt index fcd01bb663c7af22791c3855e6da22d93c667f84..4d46d1e6b68758bf634f9b0f82c279fdfa91a0b8 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-regressor.pbtxt @@ -21,7 +21,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\'], " + argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\', \'sparse_combiner\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\', \'sum\'], " } member_method { name: "eval_dir" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-mode-keys.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-mode-keys.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.-mode-keys.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-mode-keys.pbtxt diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..bf1f94b6aedfd02c15c4750bc00beb057fa8694a --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-run-config.pbtxt @@ -0,0 +1,105 @@ +path: "tensorflow.estimator.RunConfig" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "cluster_spec" + mtype: "" + } + member { + name: "device_fn" + mtype: "" + } + member { + name: "eval_distribute" + mtype: "" + } + member { + name: "evaluation_master" + mtype: "" + } + member { + name: "global_id_in_cluster" + mtype: "" + } + member { + name: "is_chief" + mtype: "" + } + member { + name: "keep_checkpoint_every_n_hours" + mtype: "" + } + member { + name: "keep_checkpoint_max" + mtype: "" + } + member { + name: "log_step_count_steps" + mtype: "" + } + member { + name: "master" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "num_ps_replicas" + mtype: "" + } + member { + name: "num_worker_replicas" + mtype: "" + } + member { + name: "protocol" + mtype: "" + } + member { + name: "save_checkpoints_secs" + mtype: "" + } + member { + name: "save_checkpoints_steps" + mtype: "" + } + member { + name: "save_summary_steps" + mtype: "" + } + member { + name: "service" + mtype: "" + } + member { + name: "session_config" + mtype: "" + } + member { + name: "task_id" + mtype: "" + } + member { + name: "task_type" + mtype: "" + } + member { + name: "tf_random_seed" + mtype: "" + } + member { + name: "train_distribute" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'\', \'\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "replace" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-train-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-train-spec.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.-train-spec.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-train-spec.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.-vocab-info.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-warm-start-settings.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-warm-start-settings.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.-warm-start-settings.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.-warm-start-settings.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-classification-output.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-classification-output.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.export.-classification-output.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-classification-output.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-classification-output.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-classification-output.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.export.-classification-output.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-classification-output.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-export-output.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-export-output.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.export.-export-output.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-export-output.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-export-output.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-export-output.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.export.-export-output.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-export-output.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-predict-output.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-predict-output.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.export.-predict-output.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-predict-output.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-predict-output.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-predict-output.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.export.-predict-output.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-predict-output.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-regression-output.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-regression-output.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.export.-regression-output.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-regression-output.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-regression-output.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-regression-output.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.export.-regression-output.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-regression-output.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-serving-input-receiver.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-serving-input-receiver.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.export.-serving-input-receiver.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-serving-input-receiver.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.-tensor-serving-input-receiver.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-tensor-serving-input-receiver.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.export.-tensor-serving-input-receiver.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.export.-tensor-serving-input-receiver.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.export.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.export.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.export.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.export.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.inputs.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.inputs.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.inputs.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.inputs.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.estimator.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.estimator.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.feature_column.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.gfile.-fast-g-file.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.gfile.-fast-g-file.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.gfile.-fast-g-file.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.gfile.-fast-g-file.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.gfile.-g-file.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.gfile.-g-file.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.gfile.-g-file.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.gfile.-g-file.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.gfile.-open.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.gfile.-open.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.gfile.-open.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.gfile.-open.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.gfile.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.gfile.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.gfile.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.gfile.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.graph_util.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.graph_util.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.graph_util.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.graph_util.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.image.-resize-method.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.image.-resize-method.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.image.-resize-method.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.image.-resize-method.pbtxt diff --git a/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5c46dc5ee7dc04f57591d4883ec8eb034a34d2d0 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt @@ -0,0 +1,251 @@ +path: "tensorflow.image" +tf_module { + member { + name: "ResizeMethod" + mtype: "" + } + member_method { + name: "adjust_brightness" + argspec: "args=[\'image\', \'delta\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "adjust_contrast" + argspec: "args=[\'images\', \'contrast_factor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "adjust_gamma" + argspec: "args=[\'image\', \'gamma\', \'gain\'], varargs=None, keywords=None, defaults=[\'1\', \'1\'], " + } + member_method { + name: "adjust_hue" + argspec: "args=[\'image\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "adjust_jpeg_quality" + argspec: "args=[\'image\', \'jpeg_quality\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "adjust_saturation" + argspec: "args=[\'image\', \'saturation_factor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "central_crop" + argspec: "args=[\'image\', \'central_fraction\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "convert_image_dtype" + argspec: "args=[\'image\', \'dtype\', \'saturate\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "crop_and_resize" + argspec: "args=[\'image\', \'boxes\', \'box_ind\', \'crop_size\', \'method\', \'extrapolation_value\', \'name\'], varargs=None, keywords=None, defaults=[\'bilinear\', \'0\', \'None\'], " + } + member_method { + name: "crop_to_bounding_box" + argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "decode_and_crop_jpeg" + argspec: "args=[\'contents\', \'crop_window\', \'channels\', \'ratio\', \'fancy_upscaling\', \'try_recover_truncated\', \'acceptable_fraction\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \'True\', \'False\', \'1\', \'\', \'None\'], " + } + member_method { + name: "decode_bmp" + argspec: "args=[\'contents\', \'channels\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], " + } + member_method { + name: "decode_gif" + argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "decode_image" + argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"\", \'None\'], " + } + member_method { + name: "decode_jpeg" + argspec: "args=[\'contents\', \'channels\', \'ratio\', \'fancy_upscaling\', \'try_recover_truncated\', \'acceptable_fraction\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \'True\', \'False\', \'1\', \'\', \'None\'], " + } + member_method { + name: "decode_png" + argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"\", \'None\'], " + } + member_method { + name: "draw_bounding_boxes" + argspec: "args=[\'images\', \'boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "encode_jpeg" + argspec: "args=[\'image\', \'format\', \'quality\', \'progressive\', \'optimize_size\', \'chroma_downsampling\', \'density_unit\', \'x_density\', \'y_density\', \'xmp_metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'95\', \'False\', \'False\', \'True\', \'in\', \'300\', \'300\', \'\', \'None\'], " + } + member_method { + name: "encode_png" + argspec: "args=[\'image\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " + } + member_method { + name: "extract_glimpse" + argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'None\'], " + } + member_method { + name: "extract_image_patches" + argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "extract_jpeg_shape" + argspec: "args=[\'contents\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } + member_method { + name: "flip_left_right" + argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "flip_up_down" + argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "grayscale_to_rgb" + argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "hsv_to_rgb" + argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "image_gradients" + argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_jpeg" + argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "non_max_suppression" + argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], " + } + member_method { + name: "non_max_suppression_overlaps" + argspec: "args=[\'overlaps\', \'scores\', \'max_output_size\', \'overlap_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], " + } + member_method { + name: "non_max_suppression_padded" + argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'pad_to_max_output_size\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], " + } + member_method { + name: "pad_to_bounding_box" + argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "per_image_standardization" + argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "psnr" + argspec: "args=[\'a\', \'b\', \'max_val\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_brightness" + argspec: "args=[\'image\', \'max_delta\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_contrast" + argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_flip_left_right" + argspec: "args=[\'image\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_flip_up_down" + argspec: "args=[\'image\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_hue" + argspec: "args=[\'image\', \'max_delta\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_jpeg_quality" + argspec: "args=[\'image\', \'min_jpeg_quality\', \'max_jpeg_quality\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_saturation" + argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "resize_area" + argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "resize_bicubic" + argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "resize_bilinear" + argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "resize_image_with_crop_or_pad" + argspec: "args=[\'image\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "resize_image_with_pad" + argspec: "args=[\'image\', \'target_height\', \'target_width\', \'method\'], varargs=None, keywords=None, defaults=[\'0\'], " + } + member_method { + name: "resize_images" + argspec: "args=[\'images\', \'size\', \'method\', \'align_corners\', \'preserve_aspect_ratio\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\'], " + } + member_method { + name: "resize_nearest_neighbor" + argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "rgb_to_grayscale" + argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "rgb_to_hsv" + argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "rgb_to_yiq" + argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "rgb_to_yuv" + argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "rot90" + argspec: "args=[\'image\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " + } + member_method { + name: "sample_distorted_bounding_box" + argspec: "args=[\'image_size\', \'bounding_boxes\', \'seed\', \'seed2\', \'min_object_covered\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.1\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "sobel_edges" + argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "ssim" + argspec: "args=[\'img1\', \'img2\', \'max_val\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "ssim_multiscale" + argspec: "args=[\'img1\', \'img2\', \'max_val\', \'power_factors\'], varargs=None, keywords=None, defaults=[\'(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)\'], " + } + member_method { + name: "total_variation" + argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "transpose_image" + argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "yiq_to_rgb" + argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "yuv_to_rgb" + argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.initializers.constant.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.initializers.constant.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.initializers.constant.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.initializers.constant.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.initializers.identity.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.initializers.identity.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.initializers.identity.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.initializers.identity.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.initializers.ones.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.initializers.ones.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.initializers.ones.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.initializers.ones.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.initializers.orthogonal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.initializers.orthogonal.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.initializers.orthogonal.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.initializers.orthogonal.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.initializers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.initializers.pbtxt similarity index 56% rename from tensorflow/tools/api/golden/tensorflow.initializers.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.initializers.pbtxt index eaf0036cacfadce335a84bcf61f47f9d360be7e2..bc0426f2f145763552dcb46fb6c2efc7c42b974f 100644 --- a/tensorflow/tools/api/golden/tensorflow.initializers.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.initializers.pbtxt @@ -44,6 +44,30 @@ tf_module { name: "global_variables" argspec: "args=[], varargs=None, keywords=None, defaults=None" } + member_method { + name: "glorot_normal" + argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " + } + member_method { + name: "glorot_uniform" + argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " + } + member_method { + name: "he_normal" + argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "he_uniform" + argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "lecun_normal" + argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "lecun_uniform" + argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "local_variables" argspec: "args=[], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/tensorflow.initializers.random_normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.initializers.random_normal.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.initializers.random_normal.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.initializers.random_normal.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.initializers.random_uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.initializers.random_uniform.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.initializers.random_uniform.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.initializers.random_uniform.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.initializers.truncated_normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.initializers.truncated_normal.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.initializers.truncated_normal.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.initializers.truncated_normal.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.initializers.uniform_unit_scaling.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.initializers.uniform_unit_scaling.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.initializers.uniform_unit_scaling.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.initializers.uniform_unit_scaling.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.initializers.variance_scaling.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.initializers.variance_scaling.pbtxt similarity index 88% rename from tensorflow/tools/api/golden/tensorflow.initializers.variance_scaling.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.initializers.variance_scaling.pbtxt index a6b6e5eceb62654c9ad567a361f7558a2865e57a..86340913e2506c96499aae05a3ed0d5273c93bba 100644 --- a/tensorflow/tools/api/golden/tensorflow.initializers.variance_scaling.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.initializers.variance_scaling.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'normal\', \'None\', \"\"], " + argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'truncated_normal\', \'None\', \"\"], " } member_method { name: "from_config" diff --git a/tensorflow/tools/api/golden/tensorflow.initializers.zeros.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.initializers.zeros.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.initializers.zeros.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.initializers.zeros.pbtxt diff --git a/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3a36c168aa703721421b662185fc852fa3d6a3ec --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt @@ -0,0 +1,39 @@ +path: "tensorflow.io" +tf_module { + member_method { + name: "decode_base64" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "decode_compressed" + argspec: "args=[\'bytes\', \'compression_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], " + } + member_method { + name: "decode_json_example" + argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "decode_raw" + argspec: "args=[\'bytes\', \'out_type\', \'little_endian\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "encode_base64" + argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "matching_files" + argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "parse_tensor" + argspec: "args=[\'serialized\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_file" + argspec: "args=[\'filename\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "write_file" + argspec: "args=[\'filename\', \'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e579fe6a1aeca296ac8ceb7b8ba951f250331eee --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt @@ -0,0 +1,268 @@ +path: "tensorflow.keras.Model" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "layers" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "state_updates" + mtype: "" + } + member { + name: "stateful" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "uses_learning_phase" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compile" + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " + } + member_method { + name: "evaluate_generator" + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + } + member_method { + name: "fit" + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " + } + member_method { + name: "fit_generator" + argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_layer" + argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_weights" + argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + } + member_method { + name: "predict_generator" + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + } + member_method { + name: "predict_on_batch" + argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_states" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "save" + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], " + } + member_method { + name: "save_weights" + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "summary" + argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "test_on_batch" + argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "to_json" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "to_yaml" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "train_on_batch" + argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..97688fcb0f2357b38b97053ade1ff88ca8c1bfd0 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt @@ -0,0 +1,285 @@ +path: "tensorflow.keras.Sequential" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "layers" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "state_updates" + mtype: "" + } + member { + name: "stateful" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "uses_learning_phase" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'layers\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "add" + argspec: "args=[\'self\', \'layer\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compile" + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " + } + member_method { + name: "evaluate_generator" + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + } + member_method { + name: "fit" + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " + } + member_method { + name: "fit_generator" + argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_layer" + argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_weights" + argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "pop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + } + member_method { + name: "predict_classes" + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], " + } + member_method { + name: "predict_generator" + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + } + member_method { + name: "predict_on_batch" + argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict_proba" + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], " + } + member_method { + name: "reset_states" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "save" + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], " + } + member_method { + name: "save_weights" + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "summary" + argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "test_on_batch" + argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "to_json" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "to_yaml" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "train_on_batch" + argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2e9de9ebb21021ab82ed4409243e13db49d7327c --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt @@ -0,0 +1,55 @@ +path: "tensorflow.keras.activations" +tf_module { + member_method { + name: "deserialize" + argspec: "args=[\'name\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "elu" + argspec: "args=[\'x\', \'alpha\'], varargs=None, keywords=None, defaults=[\'1.0\'], " + } + member_method { + name: "get" + argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "hard_sigmoid" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "linear" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "relu" + argspec: "args=[\'x\', \'alpha\', \'max_value\', \'threshold\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\', \'0\'], " + } + member_method { + name: "selu" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "serialize" + argspec: "args=[\'activation\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "sigmoid" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "softmax" + argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\'], " + } + member_method { + name: "softplus" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "softsign" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "tanh" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.backend.name_scope.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.name_scope.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.backend.name_scope.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.backend.name_scope.pbtxt diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..126ce8db6a73e2c486dbf34512812e630b3e9a32 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt @@ -0,0 +1,555 @@ +path: "tensorflow.keras.backend" +tf_module { + member { + name: "name_scope" + mtype: "" + } + member_method { + name: "abs" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "all" + argspec: "args=[\'x\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "any" + argspec: "args=[\'x\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "arange" + argspec: "args=[\'start\', \'stop\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'int32\'], " + } + member_method { + name: "argmax" + argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\'], " + } + member_method { + name: "argmin" + argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\'], " + } + member_method { + name: "backend" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "batch_dot" + argspec: "args=[\'x\', \'y\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "batch_flatten" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "batch_get_value" + argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "batch_normalization" + argspec: "args=[\'x\', \'mean\', \'var\', \'beta\', \'gamma\', \'epsilon\'], varargs=None, keywords=None, defaults=[\'0.001\'], " + } + member_method { + name: "batch_set_value" + argspec: "args=[\'tuples\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "bias_add" + argspec: "args=[\'x\', \'bias\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "binary_crossentropy" + argspec: "args=[\'target\', \'output\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "cast" + argspec: "args=[\'x\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "cast_to_floatx" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "categorical_crossentropy" + argspec: "args=[\'target\', \'output\', \'from_logits\', \'axis\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\'], " + } + member_method { + name: "clear_session" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "clip" + argspec: "args=[\'x\', \'min_value\', \'max_value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "concatenate" + argspec: "args=[\'tensors\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\'], " + } + member_method { + name: "constant" + argspec: "args=[\'value\', \'dtype\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "conv1d" + argspec: "args=[\'x\', \'kernel\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\'], varargs=None, keywords=None, defaults=[\'1\', \'valid\', \'None\', \'1\'], " + } + member_method { + name: "conv2d" + argspec: "args=[\'x\', \'kernel\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\'], " + } + member_method { + name: "conv2d_transpose" + argspec: "args=[\'x\', \'kernel\', \'output_shape\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'None\'], " + } + member_method { + name: "conv3d" + argspec: "args=[\'x\', \'kernel\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\'], varargs=None, keywords=None, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'(1, 1, 1)\'], " + } + member_method { + name: "cos" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "ctc_batch_cost" + argspec: "args=[\'y_true\', \'y_pred\', \'input_length\', \'label_length\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "ctc_decode" + argspec: "args=[\'y_pred\', \'input_length\', \'greedy\', \'beam_width\', \'top_paths\'], varargs=None, keywords=None, defaults=[\'True\', \'100\', \'1\'], " + } + member_method { + name: "ctc_label_dense_to_sparse" + argspec: "args=[\'labels\', \'label_lengths\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "dot" + argspec: "args=[\'x\', \'y\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "dropout" + argspec: "args=[\'x\', \'level\', \'noise_shape\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "dtype" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "elu" + argspec: "args=[\'x\', \'alpha\'], varargs=None, keywords=None, defaults=[\'1.0\'], " + } + member_method { + name: "epsilon" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "equal" + argspec: "args=[\'x\', \'y\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "eval" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "exp" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "expand_dims" + argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\'], " + } + member_method { + name: "eye" + argspec: "args=[\'size\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "flatten" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "floatx" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "foldl" + argspec: "args=[\'fn\', \'elems\', \'initializer\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "foldr" + argspec: "args=[\'fn\', \'elems\', \'initializer\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "function" + argspec: "args=[\'inputs\', \'outputs\', \'updates\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "gather" + argspec: "args=[\'reference\', \'indices\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_session" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_uid" + argspec: "args=[\'prefix\'], varargs=None, keywords=None, defaults=[\'\'], " + } + member_method { + name: "get_value" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "gradients" + argspec: "args=[\'loss\', \'variables\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "greater" + argspec: "args=[\'x\', \'y\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "greater_equal" + argspec: "args=[\'x\', \'y\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "hard_sigmoid" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "image_data_format" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "in_test_phase" + argspec: "args=[\'x\', \'alt\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "in_top_k" + argspec: "args=[\'predictions\', \'targets\', \'k\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "in_train_phase" + argspec: "args=[\'x\', \'alt\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "int_shape" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_sparse" + argspec: "args=[\'tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "l2_normalize" + argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "learning_phase" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "less" + argspec: "args=[\'x\', \'y\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "less_equal" + argspec: "args=[\'x\', \'y\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "log" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "manual_variable_initialization" + argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "map_fn" + argspec: "args=[\'fn\', \'elems\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "max" + argspec: "args=[\'x\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "maximum" + argspec: "args=[\'x\', \'y\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mean" + argspec: "args=[\'x\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "min" + argspec: "args=[\'x\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "minimum" + argspec: "args=[\'x\', \'y\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "moving_average_update" + argspec: "args=[\'x\', \'value\', \'momentum\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "ndim" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "normalize_batch_in_training" + argspec: "args=[\'x\', \'gamma\', \'beta\', \'reduction_axes\', \'epsilon\'], varargs=None, keywords=None, defaults=[\'0.001\'], " + } + member_method { + name: "not_equal" + argspec: "args=[\'x\', \'y\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "one_hot" + argspec: "args=[\'indices\', \'num_classes\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "ones" + argspec: "args=[\'shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "ones_like" + argspec: "args=[\'x\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "permute_dimensions" + argspec: "args=[\'x\', \'pattern\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "placeholder" + argspec: "args=[\'shape\', \'ndim\', \'dtype\', \'sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "pool2d" + argspec: "args=[\'x\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'pool_mode\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'None\', \'max\'], " + } + member_method { + name: "pool3d" + argspec: "args=[\'x\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'pool_mode\'], varargs=None, keywords=None, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'max\'], " + } + member_method { + name: "pow" + argspec: "args=[\'x\', \'a\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "print_tensor" + argspec: "args=[\'x\', \'message\'], varargs=None, keywords=None, defaults=[\'\'], " + } + member_method { + name: "prod" + argspec: "args=[\'x\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "random_binomial" + argspec: "args=[\'shape\', \'p\', \'dtype\', \'seed\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\', \'None\'], " + } + member_method { + name: "random_normal" + argspec: "args=[\'shape\', \'mean\', \'stddev\', \'dtype\', \'seed\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \'None\'], " + } + member_method { + name: "random_normal_variable" + argspec: "args=[\'shape\', \'mean\', \'scale\', \'dtype\', \'name\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "random_uniform" + argspec: "args=[\'shape\', \'minval\', \'maxval\', \'dtype\', \'seed\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \'None\'], " + } + member_method { + name: "random_uniform_variable" + argspec: "args=[\'shape\', \'low\', \'high\', \'dtype\', \'name\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "relu" + argspec: "args=[\'x\', \'alpha\', \'max_value\', \'threshold\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\', \'0\'], " + } + member_method { + name: "repeat" + argspec: "args=[\'x\', \'n\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "repeat_elements" + argspec: "args=[\'x\', \'rep\', \'axis\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_uids" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reshape" + argspec: "args=[\'x\', \'shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "resize_images" + argspec: "args=[\'x\', \'height_factor\', \'width_factor\', \'data_format\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "resize_volumes" + argspec: "args=[\'x\', \'depth_factor\', \'height_factor\', \'width_factor\', \'data_format\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reverse" + argspec: "args=[\'x\', \'axes\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "rnn" + argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\', \'input_length\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "round" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "separable_conv2d" + argspec: "args=[\'x\', \'depthwise_kernel\', \'pointwise_kernel\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\'], " + } + member_method { + name: "set_epsilon" + argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_floatx" + argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_image_data_format" + argspec: "args=[\'data_format\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_learning_phase" + argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_session" + argspec: "args=[\'session\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_value" + argspec: "args=[\'x\', \'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "shape" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "sigmoid" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "sign" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "sin" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "softmax" + argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\'], " + } + member_method { + name: "softplus" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "softsign" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "sparse_categorical_crossentropy" + argspec: "args=[\'target\', \'output\', \'from_logits\', \'axis\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\'], " + } + member_method { + name: "spatial_2d_padding" + argspec: "args=[\'x\', \'padding\', \'data_format\'], varargs=None, keywords=None, defaults=[\'((1, 1), (1, 1))\', \'None\'], " + } + member_method { + name: "spatial_3d_padding" + argspec: "args=[\'x\', \'padding\', \'data_format\'], varargs=None, keywords=None, defaults=[\'((1, 1), (1, 1), (1, 1))\', \'None\'], " + } + member_method { + name: "sqrt" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "square" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "squeeze" + argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "stack" + argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'0\'], " + } + member_method { + name: "std" + argspec: "args=[\'x\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "stop_gradient" + argspec: "args=[\'variables\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "sum" + argspec: "args=[\'x\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "switch" + argspec: "args=[\'condition\', \'then_expression\', \'else_expression\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "tanh" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "temporal_padding" + argspec: "args=[\'x\', \'padding\'], varargs=None, keywords=None, defaults=[\'(1, 1)\'], " + } + member_method { + name: "to_dense" + argspec: "args=[\'tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "transpose" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "truncated_normal" + argspec: "args=[\'shape\', \'mean\', \'stddev\', \'dtype\', \'seed\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \'None\'], " + } + member_method { + name: "update" + argspec: "args=[\'x\', \'new_x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update_add" + argspec: "args=[\'x\', \'increment\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update_sub" + argspec: "args=[\'x\', \'decrement\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "var" + argspec: "args=[\'x\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "variable" + argspec: "args=[\'value\', \'dtype\', \'name\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "zeros" + argspec: "args=[\'shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "zeros_like" + argspec: "args=[\'x\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-base-logger.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-base-logger.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.callbacks.-base-logger.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-base-logger.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-c-s-v-logger.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-c-s-v-logger.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.callbacks.-c-s-v-logger.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-c-s-v-logger.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-callback.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-callback.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.callbacks.-callback.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-callback.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-early-stopping.pbtxt similarity index 91% rename from tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-early-stopping.pbtxt index 7b0ad85eaac5b83835a9e1c4b152e38e7051a2f6..f71292856cd29b2e52194bec8a586686fbfad667 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-early-stopping.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'monitor\', \'min_delta\', \'patience\', \'verbose\', \'mode\'], varargs=None, keywords=None, defaults=[\'val_loss\', \'0\', \'0\', \'0\', \'auto\'], " + argspec: "args=[\'self\', \'monitor\', \'min_delta\', \'patience\', \'verbose\', \'mode\', \'baseline\'], varargs=None, keywords=None, defaults=[\'val_loss\', \'0\', \'0\', \'0\', \'auto\', \'None\'], " } member_method { name: "on_batch_begin" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-history.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-history.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.callbacks.-history.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-history.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-lambda-callback.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-lambda-callback.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.callbacks.-lambda-callback.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-lambda-callback.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-model-checkpoint.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-model-checkpoint.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.callbacks.-model-checkpoint.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-model-checkpoint.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-progbar-logger.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-progbar-logger.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.callbacks.-progbar-logger.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-progbar-logger.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-remote-monitor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-remote-monitor.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.callbacks.-remote-monitor.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-remote-monitor.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-tensor-board.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt similarity index 86% rename from tensorflow/tools/api/golden/tensorflow.keras.callbacks.-tensor-board.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt index 2f52464315d8c1b526792c92f5cf8e83ce3ce087..e58ba18c1c0d06df3a53d93ae18f5bf0931df329 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-tensor-board.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'log_dir\', \'histogram_freq\', \'batch_size\', \'write_graph\', \'write_grads\', \'write_images\'], varargs=None, keywords=None, defaults=[\'./logs\', \'0\', \'32\', \'True\', \'False\', \'False\'], " + argspec: "args=[\'self\', \'log_dir\', \'histogram_freq\', \'batch_size\', \'write_graph\', \'write_grads\', \'write_images\', \'embeddings_freq\', \'embeddings_layer_names\', \'embeddings_metadata\', \'embeddings_data\'], varargs=None, keywords=None, defaults=[\'./logs\', \'0\', \'32\', \'True\', \'False\', \'False\', \'0\', \'None\', \'None\', \'None\'], " } member_method { name: "on_batch_begin" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-terminate-on-na-n.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-terminate-on-na-n.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.callbacks.-terminate-on-na-n.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-terminate-on-na-n.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.callbacks.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-constraint.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.-constraint.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.constraints.-constraint.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.-constraint.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-max-norm.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.-max-norm.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.constraints.-max-norm.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.-max-norm.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-min-max-norm.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.-min-max-norm.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.constraints.-min-max-norm.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.-min-max-norm.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-non-neg.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.-non-neg.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.constraints.-non-neg.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.-non-neg.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-unit-norm.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.-unit-norm.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.constraints.-unit-norm.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.-unit-norm.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.max_norm.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.max_norm.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.constraints.max_norm.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.max_norm.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.min_max_norm.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.min_max_norm.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.constraints.min_max_norm.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.min_max_norm.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.non_neg.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.non_neg.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.constraints.non_neg.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.non_neg.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.constraints.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.unit_norm.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.unit_norm.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.constraints.unit_norm.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.constraints.unit_norm.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.datasets.boston_housing.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.datasets.boston_housing.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.cifar10.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.datasets.cifar10.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.datasets.cifar10.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.datasets.cifar10.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.cifar100.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.datasets.cifar100.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.datasets.cifar100.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.datasets.cifar100.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.fashion_mnist.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.datasets.fashion_mnist.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.datasets.fashion_mnist.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.datasets.fashion_mnist.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.datasets.imdb.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.datasets.imdb.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.mnist.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.datasets.mnist.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.datasets.mnist.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.datasets.mnist.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.datasets.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.datasets.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.datasets.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.datasets.reuters.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.datasets.reuters.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.estimator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.estimator.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.estimator.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.estimator.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-constant.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-constant.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.-constant.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-constant.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-identity.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-identity.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.-identity.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-identity.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-initializer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-initializer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.-initializer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-initializer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-ones.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-ones.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.-ones.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-ones.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-orthogonal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-orthogonal.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.-orthogonal.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-orthogonal.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-random-normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-normal.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.-random-normal.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-normal.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-random-uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-uniform.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.-random-uniform.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-random-uniform.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-truncated-normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-truncated-normal.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.-truncated-normal.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-truncated-normal.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-variance-scaling.pbtxt similarity index 88% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-variance-scaling.pbtxt index 32a6f6ee88815b3dc70e9cca855f73099554953b..03f4064b9ef5093044a9cbb897043d643cf7f83e 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-variance-scaling.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'normal\', \'None\', \"\"], " + argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'truncated_normal\', \'None\', \"\"], " } member_method { name: "from_config" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-zeros.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-zeros.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.-zeros.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.-zeros.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.constant.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.constant.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.constant.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.constant.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.identity.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.identity.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.identity.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.identity.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.normal.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.normal.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.normal.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.ones.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.ones.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.ones.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.ones.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.orthogonal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.orthogonal.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.orthogonal.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.orthogonal.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.pbtxt similarity index 90% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.pbtxt index 14a667870d3118e48bfac03eee9accb3d48a72ce..8645e5430295dff0a5b7c715b03860fb7734e7f1 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.pbtxt @@ -90,11 +90,11 @@ tf_module { } member_method { name: "glorot_normal" - argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " } member_method { name: "glorot_uniform" - argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " } member_method { name: "he_normal" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.random_normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_normal.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.random_normal.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_normal.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.random_uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_uniform.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.random_uniform.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.random_uniform.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.truncated_normal.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.truncated_normal.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.truncated_normal.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.truncated_normal.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.uniform.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.uniform.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.uniform.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.uniform.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.zeros.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.zeros.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.initializers.zeros.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.initializers.zeros.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt index 2bf973debb175d27bb80e627d7ccbb41b567020d..86e328888e596852caf9ad1020dfdedb71864969 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt index 03f20e72c2a325cec000cf4a5cfc0f1bbf255c8f..b0ed54578109c6ae8d5bc2c9f5c978b562a9cc84 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt index 4b46b8d15afb0a2f636962b762e1808312c2f7c3..42f98ed03d426d60cabeb0b533311d41eb378285 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt index d8a1c76fd07634ef413152020a397897f2d5b97c..000898a4be928e4e64b4072ef3170b6fbc930bdf 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt index 622926bc4b8b2430ee1ab936665acb5744155e0d..380b49f99ce6e62770a9516ba81db99f194c5b37 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt index 82100d8e09c8e95730993527293d2b72ce69f1d4..82db5e6137639e516f6df6f0e130e73be516c9b8 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt index 408061077cdeab2f8fd08c7e972744e5ee383f52..b6ff688ec36f8c47b2ac9694fb84350818be25c5 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt index a3c80311043eeb95b06855f662a5e3d344803ba3..b41290f8b067397bf6678d9e98ac53f28a05a3fc 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt index e2dfaca29f86bd9d91d524ec337afad81e7f2da3..88a033e61f42e2fb02b08968ff001ea21195972a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt index 4f068d2066a450bab77becc85a33662b78ad03e2..c1b9b96044ed2e057b8e86dda59ee7f7166cfd43 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt index b8c261a74364e9bb6bf8f6c7463993fbff5e9552..f59f7727a3eaeb4fa5631cb1b42901ea6d39b06b 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt index 4ccd6cace650e2efd1583c75f6639c8598bb8f20..7d3744ed92636a972bae2f9b62a6b2da8f91d106 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt index 2790e5fd850c24bd3e94cd15a6e079e1c9f79868..3fd4ccdab2573964c2f3192d503e9fb15f442dc5 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt @@ -107,7 +107,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt index b1326bd0e6054b2a3fd36e7ad42cd3d4a0cad8dc..ba21b50be41f3adc735b3350bdf9dbeae3c2e358 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt similarity index 96% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt index e3ac3dbf28da731e14640d5f464547d62391a28f..46f9fa2bbbbe3cfff3aade33c5ebdec92bc70ef0 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt @@ -188,7 +188,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt index 1117a695a395f495d988464bbf59d4b8e01877e6..c3ad326589d2822bc5dd381d78216b25f5fb6f95 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt index b9de1421428dcf61b988df343a22996cfb8fecef..fd9eb43066be580a7df57aeb717b59569c9bba61 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt @@ -100,7 +100,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt index deb535e06e06008a17b80c8e13d8f01ad1535059..40d61688f29a81e873a26c8a5eb823d679320ed6 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt index 9a9a223fbad11cafd8620110d80b27d5382dd29c..b8c227d7257311578e41abe0a384ed93e6a2866c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt @@ -100,7 +100,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt index 1c59b0bdf624b09a7454f2d51698951a790f393a..095d35e5749d0113956b04f971e6a8ca1fa277b8 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt index 30cf5489f4fcd4af3d0bd957fc9c576c57ee2bbd..8f999611982bbfe3c613ef26d93782e299275f19 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt index 0ec69508d5a1992b46d1a7c65255cfb5408ab439..96d522a016aedba01032a1c05a69511cb03d19af 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt @@ -100,7 +100,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt index 4cd8928403c98abad85bc1349a29148c73003c9d..de2824dab4526d90eebf9cef16710cadf82f4850 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt index 4b4912496deac2a79a5b0ea3d1ca0f8fa625301a..1d563241d8f0d93bcd19a319eb8383f4bcdf4388 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt @@ -100,7 +100,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt index d0ad9cf56702e585e31a79de0f93d9efd48ed484..c87e52c53796f0743365a9d8780decf237bba070 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt index 98cff95a7fe9d4e58cf883502df08c58c651cd76..dccf5523e3870b6c1ce0de70c648ab47968a105f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt index 2357498b46376ef13de102944b69931a9e7d3584..7ac4116d922eea51e5a7e7fe3d02ad919300c459 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt index 3324cbff304c5106360f3f3d3d608a528fa5fc31..024f72705de1e76866a8132246884dffb0c4e72a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt index 6c81823654b78a936cded4a1d5a6f54e02dc7fc9..4e0233331bd47e86e8a4df2f84b5392517fbf884 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt @@ -108,7 +108,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt index 487e04fd0790cb39ef6aee8d0498b3aae6726084..32d46ce8f3deff6077eaf5a1a8cf7ba64478d9f4 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt @@ -108,7 +108,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt index 137e7cced4e8113dd6a54a837e08cfd5af35c94d..858486c725c3be5ecae2a02d0d3134ebeb113ce1 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt index 7161665d2550c1cc3aff1c28f9d7676276b62303..f65d7509262bfeb148588e069c08961058a3fa74 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt @@ -100,7 +100,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt index 24affa248121bcb1e1a947417a95ad4f5ba55ab2..2e71ef503d54927edbb3e1ef6c701ac845883e46 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt index 7ba19a42695da37b4ad43cdde2c0d4978fd0a1eb..42533bcd21b28a0acf183db195a6b5c1848a5d91 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt index 503aa9162c3a78e9bb42ce16af98451441adbbb7..b5df16941792a29d72f2ee709993b007d342d2d0 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt index 1737e590a29c5777b5eca2b4cb23081aa8ece738..0ea17919a9bb13ffdedd60ce618bca23dd52712f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt index 021d024dc2150a75532ea7597d85f36efd2a3cf2..a33248bc005a73d0be679cd62150d6019b475305 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt index 65387008bf3f78e404d8d8bbd7bb8cd3789bf256..4ba21a25cda83122fbced7fed76d4b1ae28cb4c8 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt similarity index 96% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt index 4f791acf0585c95d6c0f1d5ea48e607f9a05188d..a7a570418e0a78873237c1c8cefe36a212e4c9af 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt @@ -171,7 +171,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt index abc30e54e0630a2d7b4de6074445e155e0ac2782..763bc231136908d469b7f942aec94f6248d2e2d4 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt index 20791bb448d17788ea4aebe4900169a70a9703d6..3c50a3d7f28809b2b810b52951207e48f9f50e34 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt index 449a91d8735c59f563360307cdb35c5a30344d82..ac78bdafada8c157efd4ab8746be15726eb0bc24 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt index bb361e129728ddd42c21144937efbc617d98ba30..275282d9d2b1753cf0189b605f921bb039ef5f3c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt index e564bf3216104a902fb6cfbe65b1e2b6dafc2524..0e31e6058bd6036a5fb4422335917718f4f82851 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt index 4cb9cc3ec84d679b78465e43caa5a257466d5676..aacd0b1791dda5babb6eef5d87a1335c8d519b08 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt index 5ed52b88ae3e2dd25b560206db404952034a04cd..c23654866341818aeb804cfb71dae052049e3f25 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt index f4559d29d75ef7cd8fcbdeac0a1a2c9e633246bc..6b9c0290aac35d80c7f87acfc44479c57623a645 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt index 64e2d061e26997365c461113d3ea15140fef64dd..0d7b2211e6cd35ca331b4a1068f237e7ca07f70c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt index 3372ad645388beb54f7ed9e3715449facba07f87..d080ad6aedbd5183da890cd63f5f18453d5d476a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt index 08a6860bcd7d9a260e44af87c51796a9cc2af379..fcb0a109da208ff5bd20447ddced9816a42af311 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt index 22c9eab64fde41e1199ecbb1b8b03939653ecd00..1d0e22abd0d8732182881c43ee79400642cef24b 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt index 74c405ba9b1b465f89c4fef43020181a1a7f3d31..653c9f547bc888a8fec87137f7d495141d4f8599 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt index 39f6f981931296eb6d31eb6580f93b479ff64ce6..cdbaf82cf6746e878619647439d2256f6e2c4aa3 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt index 7b25e80b6b7653c5e76bf176b54110b1aabaf5ea..230c5e903438b0a75edf80f0f5c8706987c66a78 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-spec.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-input-spec.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-spec.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt index 3619b8bfc44373ba6b8e306b020ac63d4b498573..511456e740837455818ff3f9be270daed03f334f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt similarity index 96% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt index 8ef3d71dd82efc79e333770d4a7a7c8aee1a4202..4a3492ebd652e5ab8f0faf8a1583480abc80fba7 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt @@ -171,7 +171,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2dff7a6de4231711dc9154ea8ba036e1e4a1ed11 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.Lambda" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'function\', \'output_shape\', \'mask\', \'arguments\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt index 9b90db1e5e56d1e5749669bba8dba1cdbd45bb55..7efa29be77c075a29784d8cd3ebfcd871bc9aa0c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt @@ -97,7 +97,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt index 3c60eaab7f1df15331004685676d74943d5d538f..0ca8e0b52c4a81c4ff3b756aa6c24b47a664f999 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt index 3dac1ff342ac1b7f984e9af5a6028ef71da701df..f754fa1da85692c28f31a76bbfa987b3c4c30731 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt index 7f1b5db4d34f706f2107ef43ab9c5acf67dac9f6..c9516b8f07d0b6a818bf99d45499d161c2a5cffd 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt index b3e31000f3bca0821377d70b1d88a20aa8f8e4ef..850ecff9743b5f5048bb81c5a15b0a4be6b4d0ce 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt index bbd9d1b0dc075bb9241f240b423933db20b38b75..7c69e31f9af9bbd221882d160fa4206997ec3b08 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt index fe72beea802d12b996948b00436b274ee7e83177..fba42642d7c701688c2bd274cf97e077e7ff571c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt index e9bf57b2b0e60376a28c0abfc16fba393df3e73c..9c277411ea5ce26df9c033ada773ad2e45292cb1 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt index 0eecc58a2b6a2846a2c92502cc23bd328f8b5193..7c2f6ccc8a98017aba014ab6a7896e0a4bf40324 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt index 96785a7d8559611a19b7f36216dbf0f8a3e39e61..802178dba63d66cca1629bcb7bef0f578c9a6659 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt index 42c46cccb37b1ab7ece7760e6858b2180ea833b9..e870dfe9ade75da367f87a4b54d38ba4274bab2e 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt index ac816f68d492cbfc5503c057a869e3e981de9190..c1337ce0cbac2d1e0e011f5309bfb2722960d3b2 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-minimum.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-minimum.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt index 56e32e9d3690a92c3f6e41bf2b5164c6bf62f443..ed27a62765d5670802d4593b3e648e3f65eaf926 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-minimum.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt index 9ae99563e9a1b3b0700116ed88c13f94fafe1658..b9f05cb3e56f89cb02e1a74c3ec0d362ea27f2bf 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt index 815f3bc2d142069adb4e418a4dc6ef82d683373f..336d9f76fb1e6215b763b5064cd6be68d4d0d5a0 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt index e704992b4a18f6bdbd9474af2ee59ea81534d80a..46282217e01e8a137d9fc564f0e3544602d93de4 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt index b3a58fa11eda61baa5c932bcc04fdca7459a215f..42cd7e87eebdd969f002d8bcd0dca101168c58e0 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt @@ -102,7 +102,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4d3de58bd188e301ef516ac5eeae6cf0709d66da --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.ReLU" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'max_value\', \'negative_slope\', \'threshold\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'0\', \'0\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt index 78f464583b4e8083f4cdd1a8c6b9f377645cd562..9f094a877a3a47ff89a022db563803f5f391ff2a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt index 222344fd0497afe9a32d1d05ec37aa160479d88a..2f519a24385ac4e147798ed3e96101cff23e19aa 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt index 55fddf576cac6afabe984cd51e2ddbf112a55d25..6b93116ba02c2b7e9c5bdf79ddfa1f93050062a4 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt @@ -100,7 +100,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt index 96314ce49849a50ccc6b968b50c98ddae74c6c70..fd17115e2733d561bff1d53d62d32458b03dc65b 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt @@ -100,7 +100,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt index 88bdf9956603c590940e3ef857765586df7e91d7..4b37a94478857ac8550ea0c4f464058c68770047 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt @@ -100,7 +100,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt index 6eeea7a8d1312ada423206378b4c6ee079ffdd73..5bdadca74aeb963adef4999b7e758add1aec4681 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt @@ -100,7 +100,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt index 3050d46249003716eb0778104b729ee9cb52b34f..9dfda96fc81572d70d76ba767b69ee2e41f017ee 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt similarity index 96% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt index dda4c9358ba5faa084ad2e6cf75ff83b6a7b2b20..7b7684ccd27a1d4c3fabf56c2669f77095f501ef 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt @@ -159,7 +159,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt index cc6275158b67e94c3c39802cc7c0f9e169c8b144..3b15407fca2cf65f7fa31f29b84db52b5c5d1a7a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt index 5eb7e750477b17571ef861305806894dd2b9ac38..6d04415267c9ce21268b9d86a5b078d8f92db93f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt index 500cb8c14ead3eeff28d11b72e2300cc471756d2..04950654d55f30bf095167d176b5b2717e72f2cd 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt index 1113a7634fa98b499175d90ae7da2d3fb9fb1a13..c424e6dcc869f977100e77fdb543983c3ab7e63c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6718e36dc6057d70e9101b6fa26a53f3fb3f3569 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt @@ -0,0 +1,183 @@ +path: "tensorflow.keras.layers.StackedRNNCells" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "output_size" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cells\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'states\', \'constants\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-subtract.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-subtract.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt index 35ad87ad5d91f1cc5d413b0adc8e9e5d1403726a..740a03367bd69edf797d3ea8616fdde72f6726b7 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-subtract.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt @@ -99,7 +99,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt index 282c98d79a6e1da46e4d7ea2e5c7228754792f09..a08c583adb4175ff5ee77869c80c6c0204018166 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt index acab93706b29fedc1bf7b48da2f5b6636dea48e5..c1294fed0fcfca9c8607bf3e5d41efd240fd4d45 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt @@ -103,7 +103,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt index a5ec228a074721775d4ec0369345b5439d84e186..dc401d3ed0fee5b6fb4bb5563941c3461eb592f4 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt index d8d8e0bfe95a6cf2ef61cdb344b963df3f21aabb..4b5165ae9793f900fb474affe52b9abaeb64adbd 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt index 97d6dc06fb2e883b20540e4496efa5b39a538263..789af15fea8c0d41dd3f0c00e7be3afd6afafecf 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt index ea9bb41b9979de9049397892372f37aafc719a68..0536a7cee7e6dd5878f532854753cebeaa043c21 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt @@ -102,7 +102,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt index e6d1d2e089b01c4eb212d01c456f6fa6b850f7de..8915353ec334f28c4ed058b20a506ff102ca1f61 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt index f62017305f26519181b1ef86bdd0946d44d16b88..6efb5ef15a133877666decfd1f2b40fad4463469 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt index 07a1fde5bdc35535ca5d8443a97cb85adc54b14a..4c33c5d0bf800239e2bff4cc874e594b515a8071 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt similarity index 99% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt index 475e9dade3a6f5ba8a5020afbd4668be5b5ed9d7..9d7e5bb8c7808689bedd8abb835e61c1f38fdb1d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt @@ -300,6 +300,10 @@ tf_module { name: "RNN" mtype: "" } + member { + name: "ReLU" + mtype: "" + } member { name: "RepeatVector" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.losses.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.losses.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.losses.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt similarity index 97% rename from tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt index a97a9b57587070ec4841b627920ac91737a67997..73b577da373b1381a7e8d5841d6e002452a21f9e 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt @@ -22,7 +22,7 @@ tf_module { } member_method { name: "binary_accuracy" - argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'y_true\', \'y_pred\', \'threshold\'], varargs=None, keywords=None, defaults=[\'0.5\'], " } member_method { name: "binary_crossentropy" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..56914e1746b0429adc2570c6cb31ddc8f9a6535a --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt @@ -0,0 +1,268 @@ +path: "tensorflow.keras.models.Model" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "layers" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "state_updates" + mtype: "" + } + member { + name: "stateful" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "uses_learning_phase" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compile" + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " + } + member_method { + name: "evaluate_generator" + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + } + member_method { + name: "fit" + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " + } + member_method { + name: "fit_generator" + argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_layer" + argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_weights" + argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + } + member_method { + name: "predict_generator" + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + } + member_method { + name: "predict_on_batch" + argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_states" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "save" + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], " + } + member_method { + name: "save_weights" + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "summary" + argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "test_on_batch" + argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "to_json" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "to_yaml" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "train_on_batch" + argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..acfb3521c0f08a2d560d67231ca24d8ffe3e81b4 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt @@ -0,0 +1,285 @@ +path: "tensorflow.keras.models.Sequential" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "layers" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "state_updates" + mtype: "" + } + member { + name: "stateful" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "uses_learning_phase" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'layers\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "add" + argspec: "args=[\'self\', \'layer\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compile" + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " + } + member_method { + name: "evaluate_generator" + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + } + member_method { + name: "fit" + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " + } + member_method { + name: "fit_generator" + argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_layer" + argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_weights" + argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "pop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + } + member_method { + name: "predict_classes" + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], " + } + member_method { + name: "predict_generator" + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + } + member_method { + name: "predict_on_batch" + argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict_proba" + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], " + } + member_method { + name: "reset_states" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "save" + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], " + } + member_method { + name: "save_weights" + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "summary" + argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "test_on_batch" + argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "to_json" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "to_yaml" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "train_on_batch" + argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.models.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.models.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adadelta.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adadelta.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adagrad.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adagrad.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adam.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adam.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adamax.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-adamax.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-nadam.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-nadam.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-optimizer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.optimizers.-optimizer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-optimizer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-r-m-sprop.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-r-m-sprop.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-s-g-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-s-g-d.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.optimizers.-s-g-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.-s-g-d.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.optimizers.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.regularizers.-l1-l2.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.regularizers.-l1-l2.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.regularizers.-l1-l2.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.regularizers.-l1-l2.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.regularizers.-regularizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.regularizers.-regularizer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.regularizers.-regularizer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.regularizers.-regularizer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.regularizers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.regularizers.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.regularizers.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.regularizers.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-custom-object-scope.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-custom-object-scope.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.utils.-custom-object-scope.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-custom-object-scope.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-generator-enqueuer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-generator-enqueuer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.utils.-generator-enqueuer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-generator-enqueuer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-h-d-f5-matrix.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-h-d-f5-matrix.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.utils.-h-d-f5-matrix.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-h-d-f5-matrix.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-progbar.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-progbar.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.utils.-progbar.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-progbar.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-sequence-enqueuer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-sequence-enqueuer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.utils.-sequence-enqueuer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-sequence-enqueuer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-sequence.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-sequence.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.utils.-sequence.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-sequence.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.utils.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.wrappers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.wrappers.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.wrappers.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.wrappers.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.-keras-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.wrappers.scikit_learn.-keras-classifier.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.-keras-classifier.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.wrappers.scikit_learn.-keras-classifier.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.-keras-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.wrappers.scikit_learn.-keras-regressor.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.-keras-regressor.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.wrappers.scikit_learn.-keras-regressor.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.wrappers.scikit_learn.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.keras.wrappers.scikit_learn.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt index 11067058d5852669e1672bf3eb8b7c680d0e5dc9..c82e67526b21696a7d56517dc2cb6998882dc7a5 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling1-d.pbtxt @@ -109,7 +109,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt index 3259e706d7f7ea4d0348c1ee586c50f5a2c82b39..1d031cb5f8461145127b0f13d77e6b8774f5a0b3 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling2-d.pbtxt @@ -109,7 +109,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt index e561f2f415018840420232a97f0ece3f3c60d0d7..a8dda6655df1d06ca77b74f0a992c8fd7e7a357d 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-average-pooling3-d.pbtxt @@ -109,7 +109,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt similarity index 96% rename from tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt index 3124a35c7852a97e79a3cfe575017484f2f5731f..97f65ed89436bd0b4027bb0cbeb80b6f1419269c 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt @@ -108,7 +108,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt index b5ec61255ace78c1fa13370727eb5f5084522f4a..ccd9578f0d62bd70ea252ddeac587d59c926b018 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv1-d.pbtxt @@ -109,7 +109,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt index b2c89ae66f53299289508eef174b5c44a6be2606..9cbb58d721bb49bde562a57728a9ee46968e611e 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d-transpose.pbtxt @@ -110,7 +110,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt index 9e4f4969dc6e1b6a39cf1d25c5e5e6175fa87c7c..c75ea3911e17bc879d140068ef54521effd2824e 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv2-d.pbtxt @@ -109,7 +109,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt index 9850e6d7659d311c93dabad73d35f2fcd028dd52..5dc834e5141e58d255357e02d7446a06e6e2aa45 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d-transpose.pbtxt @@ -110,7 +110,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt index be113826cc2b9589e1f8bbde896fbcbe183d4d1b..96ab209874ac14d6acf2e8115e7f04fc35c4b2bd 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-conv3-d.pbtxt @@ -109,7 +109,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt index 0d951bf6336ac7b65be57535c1065e5f87a77a0b..7e9656b3525c1d53940b869607616ff414a466cf 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dense.pbtxt @@ -108,7 +108,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt index f1beeed9ef0cb54318249e42b1279680ea117ba8..e9a2269a6e8de1f9a12f1b54d2e6dced3d4f8902 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-dropout.pbtxt @@ -108,7 +108,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt index b75a012811ff10f055382ea1315eaba506c24ed8..7d2eaaab2a8cb9159214a16ba65473d0b6870ac4 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-flatten.pbtxt @@ -108,7 +108,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-input-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-input-spec.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.layers.-input-spec.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-input-spec.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt index 80e0fb228b034727854ab1a4df97e25c6bc2cd97..8bc3eb26e9ca0bf0f129db336b7ca23466fd036f 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-layer.pbtxt @@ -106,7 +106,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt index 50ff484d733633e20e9923dbbf1344af7b51ba9a..6a0dcce56ac0184ffe995662fd62b89e16257a29 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling1-d.pbtxt @@ -109,7 +109,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt index cea809744cd07cc6ed0d1655f217cb5821e503e4..b6c84edf2a2f86240369b4053cd7351d0b59442d 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling2-d.pbtxt @@ -109,7 +109,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt index ab9e89554c81decf5ee7e42dc963da9ab35e65c7..062a02fa590537b9efbf540a874eeaa6d36697f3 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-max-pooling3-d.pbtxt @@ -109,7 +109,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt similarity index 96% rename from tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt index 4362568445e892d6127759c925d47426d49d9927..eaad0fb23ef7501c8c5b7acee6a9677665b7057f 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv1-d.pbtxt @@ -110,7 +110,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt similarity index 96% rename from tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt index 3cad824cd3b197b91a749347c860ff926610c081..ece28a8ce962d8fafb3f7a397a814b903e915d48 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-separable-conv2-d.pbtxt @@ -110,7 +110,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.layers.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.layers.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-block-diag.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-block-diag.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-block-diag.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-block-diag.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-block-diag.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-block-diag.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-block-diag.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-block-diag.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant2-d.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant2-d.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant2-d.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant2-d.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant3-d.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant3-d.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant3-d.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant3-d.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-composition.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-composition.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-composition.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-composition.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-composition.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-composition.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-composition.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-composition.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-diag.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-diag.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-diag.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-diag.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-diag.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-diag.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-diag.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-diag.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-full-matrix.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-full-matrix.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-full-matrix.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-full-matrix.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-full-matrix.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-full-matrix.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-full-matrix.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-full-matrix.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-identity.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-identity.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-identity.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-identity.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-identity.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-identity.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-identity.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-identity.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-kronecker.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-kronecker.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-kronecker.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-kronecker.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-kronecker.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-kronecker.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-kronecker.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-kronecker.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-low-rank-update.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-low-rank-update.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-low-rank-update.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-low-rank-update.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-low-rank-update.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-low-rank-update.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-low-rank-update.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-low-rank-update.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-lower-triangular.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-lower-triangular.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-lower-triangular.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-lower-triangular.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-lower-triangular.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-lower-triangular.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-lower-triangular.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-lower-triangular.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-scaled-identity.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-scaled-identity.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-scaled-identity.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-scaled-identity.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-scaled-identity.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-scaled-identity.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-scaled-identity.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-scaled-identity.pbtxt diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-zeros.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-zeros.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..49ff85728ffab559ec706691356ce071aab89083 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-zeros.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.linalg.LinearOperatorZeros.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-zeros.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-zeros.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a1b0e06b4753488bc9fcbe9aeb0d260092745f9c --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-zeros.pbtxt @@ -0,0 +1,130 @@ +path: "tensorflow.linalg.LinearOperatorZeros" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "batch_shape" + mtype: "" + } + member { + name: "domain_dimension" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph_parents" + mtype: "" + } + member { + name: "is_non_singular" + mtype: "" + } + member { + name: "is_positive_definite" + mtype: "" + } + member { + name: "is_self_adjoint" + mtype: "" + } + member { + name: "is_square" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "range_dimension" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "tensor_rank" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_rows\', \'num_columns\', \'batch_shape\', \'dtype\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'assert_proper_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'True\', \'False\', \'True\', \'False\', \'LinearOperatorZeros\'], " + } + member_method { + name: "add_to_tensor" + argspec: "args=[\'self\', \'mat\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], " + } + member_method { + name: "assert_non_singular" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], " + } + member_method { + name: "assert_positive_definite" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], " + } + member_method { + name: "assert_self_adjoint" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], " + } + member_method { + name: "diag_part" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], " + } + member_method { + name: "domain_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], " + } + member_method { + name: "log_abs_determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], " + } + member_method { + name: "matmul" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], " + } + member_method { + name: "matvec" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], " + } + member_method { + name: "range_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], " + } + member_method { + name: "shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], " + } + member_method { + name: "solve" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], " + } + member_method { + name: "solvevec" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], " + } + member_method { + name: "tensor_rank_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], " + } + member_method { + name: "to_dense" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], " + } + member_method { + name: "trace" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator.__metaclass__.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator.__metaclass__.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator.__metaclass__.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt similarity index 91% rename from tensorflow/tools/api/golden/tensorflow.linalg.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt index 00b9238543367546cff96b736f73440214e99e22..d979116887a739d2d372687fac0e5ea3b39a4b69 100644 --- a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt @@ -52,6 +52,10 @@ tf_module { name: "LinearOperatorScaledIdentity" mtype: "" } + member { + name: "LinearOperatorZeros" + mtype: "" + } member_method { name: "adjoint" argspec: "args=[\'matrix\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -68,6 +72,10 @@ tf_module { name: "cholesky_solve" argspec: "args=[\'chol\', \'rhs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "cross" + argspec: "args=[\'a\', \'b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "det" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -140,6 +148,14 @@ tf_module { name: "svd" argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], " } + member_method { + name: "tensor_diag" + argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tensor_diag_part" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "tensordot" argspec: "args=[\'a\', \'b\', \'axes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.logging.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.logging.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.logging.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.logging.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.losses.-reduction.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.losses.-reduction.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.losses.-reduction.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.losses.-reduction.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.losses.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.losses.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.losses.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.losses.pbtxt diff --git a/tensorflow/tools/api/golden/v1/tensorflow.manip.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.manip.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9add462396ea526ae94678e969c9acf5bce86df1 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.manip.pbtxt @@ -0,0 +1,35 @@ +path: "tensorflow.manip" +tf_module { + member_method { + name: "batch_to_space_nd" + argspec: "args=[\'input\', \'block_shape\', \'crops\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "gather_nd" + argspec: "args=[\'params\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reshape" + argspec: "args=[\'tensor\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reverse" + argspec: "args=[\'tensor\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "roll" + argspec: "args=[\'input\', \'shift\', \'axis\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "scatter_nd" + argspec: "args=[\'indices\', \'updates\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "space_to_batch_nd" + argspec: "args=[\'input\', \'block_shape\', \'paddings\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tile" + argspec: "args=[\'input\', \'multiples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a308c76ebc08df06c0c360579451ea70e60695d4 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt @@ -0,0 +1,239 @@ +path: "tensorflow.math" +tf_module { + member_method { + name: "acos" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "acosh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "asin" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "asinh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "atan" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "atan2" + argspec: "args=[\'y\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "atanh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bessel_i0" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bessel_i0e" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bessel_i1" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bessel_i1e" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "betainc" + argspec: "args=[\'a\', \'b\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ceil" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cos" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cosh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "digamma" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "erfc" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "exp" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "expm1" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "floor" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "greater" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "greater_equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "igamma" + argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "igammac" + argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "invert_permutation" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "less" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "less_equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "lgamma" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "log" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "log1p" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_and" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_not" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_or" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "maximum" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "minimum" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "not_equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "polygamma" + argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "polyval" + argspec: "args=[\'coeffs\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reciprocal" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "rint" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "rsqrt" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_max" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_mean" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_min" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_prod" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_sum" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sin" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sinh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "softplus" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "softsign" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "squared_difference" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tan" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_max" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_min" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_prod" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_sum" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "zeta" + argspec: "args=[\'x\', \'q\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.metrics.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.metrics.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.metrics.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.metrics.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.name_scope.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.name_scope.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.name_scope.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.name_scope.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt similarity index 98% rename from tensorflow/tools/api/golden/tensorflow.nn.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt index 455590d866a4c1ebea65ccff51e34f2e0b0479d7..d9e5b0d0fca8bbcf82feb34304f2a1e4f43f48dd 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt @@ -260,6 +260,10 @@ tf_module { name: "relu_layer" argspec: "args=[\'x\', \'weights\', \'biases\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "safe_embedding_lookup_sparse" + argspec: "args=[\'embedding_weights\', \'sparse_ids\', \'sparse_weights\', \'combiner\', \'default_id\', \'name\', \'partition_strategy\', \'max_norm\'], varargs=None, keywords=None, defaults=[\'None\', \'mean\', \'None\', \'None\', \'div\', \'None\'], " + } member_method { name: "sampled_softmax_loss" argspec: "args=[\'weights\', \'biases\', \'labels\', \'inputs\', \'num_sampled\', \'num_classes\', \'num_true\', \'sampled_values\', \'remove_accidental_hits\', \'partition_strategy\', \'name\', \'seed\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'True\', \'mod\', \'sampled_softmax_loss\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e606eab919602a90f2c5bdd7044632a2c751ccca --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt @@ -0,0 +1,198 @@ +path: "tensorflow.nn.rnn_cell.BasicLSTMCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "output_size" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_units\', \'forget_bias\', \'state_is_tuple\', \'activation\', \'reuse\', \'name\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'1.0\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zero_state" + argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5deb02d56933c471eba09aac7ec3c5cce787f32b --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt @@ -0,0 +1,198 @@ +path: "tensorflow.nn.rnn_cell.BasicRNNCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "output_size" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\', \'name\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zero_state" + argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt index 62c393de34475a8806015bed187572f79cf2a196..8a63b4918008674041c9c216a5e5547ed7152fce 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt @@ -116,7 +116,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt index f121ba7939acb14681aa6b04b333668dded37aad..db1aae275792dad94c4cf823d0d30f934e397601 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt @@ -120,7 +120,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..32fa151a8e6531d719ccaa7683ef6efb7701c5f7 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt @@ -0,0 +1,198 @@ +path: "tensorflow.nn.rnn_cell.GRUCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "output_size" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\', \'kernel_initializer\', \'bias_initializer\', \'name\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zero_state" + argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..30c6c2ce3b67cee9a909633539c0b0af5e4978ca --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt @@ -0,0 +1,198 @@ +path: "tensorflow.nn.rnn_cell.LSTMCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "output_size" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_units\', \'use_peepholes\', \'cell_clip\', \'initializer\', \'num_proj\', \'proj_clip\', \'num_unit_shards\', \'num_proj_shards\', \'forget_bias\', \'state_is_tuple\', \'activation\', \'reuse\', \'name\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'1.0\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zero_state" + argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt index 59623fc983a63c2966882aa5113423c0a9e23b72..72b40cc9f7a720888a1399a60aa216013e0b9918 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt @@ -116,7 +116,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt index e2ab5aaee9456ffbe42894f2384d7bc9c7ad6a6f..a5c2b4aefd6a1b96cbe63271ca27de06616f1deb 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt @@ -115,7 +115,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt index bd2a6d61f8578a2a3c8d94d3a8d5eb49679df2f7..61d5f04b22a4b4e3801643958b73a35403b79139 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt @@ -116,7 +116,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.ones_initializer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.ones_initializer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.ones_initializer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.ones_initializer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.orthogonal_initializer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.orthogonal_initializer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.orthogonal_initializer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.orthogonal_initializer.pbtxt diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4f1962769150c4aae5fdd9a7424bd22316f2cb45 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -0,0 +1,2215 @@ +path: "tensorflow" +tf_module { + member { + name: "AUTO_REUSE" + mtype: "" + } + member { + name: "AggregationMethod" + mtype: "" + } + member { + name: "AttrValue" + mtype: "" + } + member { + name: "COMPILER_VERSION" + mtype: "" + } + member { + name: "CXX11_ABI_FLAG" + mtype: "" + } + member { + name: "ConditionalAccumulator" + mtype: "" + } + member { + name: "ConditionalAccumulatorBase" + mtype: "" + } + member { + name: "ConfigProto" + mtype: "" + } + member { + name: "DType" + mtype: "" + } + member { + name: "DeviceSpec" + mtype: "" + } + member { + name: "Dimension" + mtype: "" + } + member { + name: "Event" + mtype: "" + } + member { + name: "FIFOQueue" + mtype: "" + } + member { + name: "FixedLenFeature" + mtype: "" + } + member { + name: "FixedLenSequenceFeature" + mtype: "" + } + member { + name: "FixedLengthRecordReader" + mtype: "" + } + member { + name: "GIT_VERSION" + mtype: "" + } + member { + name: "GPUOptions" + mtype: "" + } + member { + name: "GRAPH_DEF_VERSION" + mtype: "" + } + member { + name: "GRAPH_DEF_VERSION_MIN_CONSUMER" + mtype: "" + } + member { + name: "GRAPH_DEF_VERSION_MIN_PRODUCER" + mtype: "" + } + member { + name: "GradientTape" + mtype: "" + } + member { + name: "Graph" + mtype: "" + } + member { + name: "GraphDef" + mtype: "" + } + member { + name: "GraphKeys" + mtype: "" + } + member { + name: "GraphOptions" + mtype: "" + } + member { + name: "HistogramProto" + mtype: "" + } + member { + name: "IdentityReader" + mtype: "" + } + member { + name: "IndexedSlices" + mtype: "" + } + member { + name: "InteractiveSession" + mtype: "" + } + member { + name: "LMDBReader" + mtype: "" + } + member { + name: "LogMessage" + mtype: "" + } + member { + name: "MONOLITHIC_BUILD" + mtype: "" + } + member { + name: "MetaGraphDef" + mtype: "" + } + member { + name: "NameAttrList" + mtype: "" + } + member { + name: "NodeDef" + mtype: "" + } + member { + name: "OpError" + mtype: "" + } + member { + name: "Operation" + mtype: "" + } + member { + name: "OptimizerOptions" + mtype: "" + } + member { + name: "PaddingFIFOQueue" + mtype: "" + } + member { + name: "PriorityQueue" + mtype: "" + } + member { + name: "QUANTIZED_DTYPES" + mtype: "" + } + member { + name: "QueueBase" + mtype: "" + } + member { + name: "RandomShuffleQueue" + mtype: "" + } + member { + name: "ReaderBase" + mtype: "" + } + member { + name: "RegisterGradient" + mtype: "" + } + member { + name: "RunMetadata" + mtype: "" + } + member { + name: "RunOptions" + mtype: "" + } + member { + name: "Session" + mtype: "" + } + member { + name: "SessionLog" + mtype: "" + } + member { + name: "SparseConditionalAccumulator" + mtype: "" + } + member { + name: "SparseFeature" + mtype: "" + } + member { + name: "SparseTensor" + mtype: "" + } + member { + name: "SparseTensorValue" + mtype: "" + } + member { + name: "Summary" + mtype: "" + } + member { + name: "SummaryMetadata" + mtype: "" + } + member { + name: "TFRecordReader" + mtype: "" + } + member { + name: "Tensor" + mtype: "" + } + member { + name: "TensorArray" + mtype: "" + } + member { + name: "TensorInfo" + mtype: "" + } + member { + name: "TensorShape" + mtype: "" + } + member { + name: "TextLineReader" + mtype: "" + } + member { + name: "VERSION" + mtype: "" + } + member { + name: "VarLenFeature" + mtype: "" + } + member { + name: "Variable" + mtype: "" + } + member { + name: "VariableAggregation" + mtype: "" + } + member { + name: "VariableScope" + mtype: "" + } + member { + name: "VariableSynchronization" + mtype: "" + } + member { + name: "WholeFileReader" + mtype: "" + } + member { + name: "app" + mtype: "" + } + member { + name: "bfloat16" + mtype: "" + } + member { + name: "bitwise" + mtype: "" + } + member { + name: "bool" + mtype: "" + } + member { + name: "compat" + mtype: "" + } + member { + name: "complex128" + mtype: "" + } + member { + name: "complex64" + mtype: "" + } + member { + name: "constant_initializer" + mtype: "" + } + member { + name: "contrib" + mtype: "" + } + member { + name: "data" + mtype: "" + } + member { + name: "debugging" + mtype: "" + } + member { + name: "distributions" + mtype: "" + } + member { + name: "double" + mtype: "" + } + member { + name: "dtypes" + mtype: "" + } + member { + name: "errors" + mtype: "" + } + member { + name: "estimator" + mtype: "" + } + member { + name: "feature_column" + mtype: "" + } + member { + name: "flags" + mtype: "" + } + member { + name: "float16" + mtype: "" + } + member { + name: "float32" + mtype: "" + } + member { + name: "float64" + mtype: "" + } + member { + name: "gfile" + mtype: "" + } + member { + name: "graph_util" + mtype: "" + } + member { + name: "half" + mtype: "" + } + member { + name: "image" + mtype: "" + } + member { + name: "initializers" + mtype: "" + } + member { + name: "int16" + mtype: "" + } + member { + name: "int32" + mtype: "" + } + member { + name: "int64" + mtype: "" + } + member { + name: "int8" + mtype: "" + } + member { + name: "io" + mtype: "" + } + member { + name: "keras" + mtype: "" + } + member { + name: "layers" + mtype: "" + } + member { + name: "linalg" + mtype: "" + } + member { + name: "logging" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "manip" + mtype: "" + } + member { + name: "math" + mtype: "" + } + member { + name: "metrics" + mtype: "" + } + member { + name: "name_scope" + mtype: "" + } + member { + name: "newaxis" + mtype: "" + } + member { + name: "nn" + mtype: "" + } + member { + name: "ones_initializer" + mtype: "" + } + member { + name: "orthogonal_initializer" + mtype: "" + } + member { + name: "profiler" + mtype: "" + } + member { + name: "python_io" + mtype: "" + } + member { + name: "pywrap_tensorflow" + mtype: "" + } + member { + name: "qint16" + mtype: "" + } + member { + name: "qint32" + mtype: "" + } + member { + name: "qint8" + mtype: "" + } + member { + name: "quantization" + mtype: "" + } + member { + name: "quint16" + mtype: "" + } + member { + name: "quint8" + mtype: "" + } + member { + name: "random_normal_initializer" + mtype: "" + } + member { + name: "random_uniform_initializer" + mtype: "" + } + member { + name: "resource" + mtype: "" + } + member { + name: "resource_loader" + mtype: "" + } + member { + name: "saved_model" + mtype: "" + } + member { + name: "sets" + mtype: "" + } + member { + name: "sparse" + mtype: "" + } + member { + name: "spectral" + mtype: "" + } + member { + name: "string" + mtype: "" + } + member { + name: "strings" + mtype: "" + } + member { + name: "summary" + mtype: "" + } + member { + name: "sysconfig" + mtype: "" + } + member { + name: "test" + mtype: "" + } + member { + name: "train" + mtype: "" + } + member { + name: "truncated_normal_initializer" + mtype: "" + } + member { + name: "uint16" + mtype: "" + } + member { + name: "uint32" + mtype: "" + } + member { + name: "uint64" + mtype: "" + } + member { + name: "uint8" + mtype: "" + } + member { + name: "uniform_unit_scaling_initializer" + mtype: "" + } + member { + name: "user_ops" + mtype: "" + } + member { + name: "variable_scope" + mtype: "" + } + member { + name: "variance_scaling_initializer" + mtype: "" + } + member { + name: "variant" + mtype: "" + } + member { + name: "zeros_initializer" + mtype: "" + } + member_method { + name: "Assert" + argspec: "args=[\'condition\', \'data\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "NoGradient" + argspec: "args=[\'op_type\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "NotDifferentiable" + argspec: "args=[\'op_type\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "Print" + argspec: "args=[\'input_\', \'data\', \'message\', \'first_n\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "abs" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "accumulate_n" + argspec: "args=[\'inputs\', \'shape\', \'tensor_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "acos" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "acosh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_check_numerics_ops" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_n" + argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_to_collection" + argspec: "args=[\'name\', \'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_to_collections" + argspec: "args=[\'names\', \'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "all_variables" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "angle" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "arg_max" + argspec: "args=[\'input\', \'dimension\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } + member_method { + name: "arg_min" + argspec: "args=[\'input\', \'dimension\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } + member_method { + name: "argmax" + argspec: "args=[\'input\', \'axis\', \'name\', \'dimension\', \'output_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"\"], " + } + member_method { + name: "argmin" + argspec: "args=[\'input\', \'axis\', \'name\', \'dimension\', \'output_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"\"], " + } + member_method { + name: "as_dtype" + argspec: "args=[\'type_value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_string" + argspec: "args=[\'input\', \'precision\', \'scientific\', \'shortest\', \'width\', \'fill\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'False\', \'False\', \'-1\', \'\', \'None\'], " + } + member_method { + name: "asin" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "asinh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "assert_equal" + argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_greater" + argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_greater_equal" + argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_integer" + argspec: "args=[\'x\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "assert_less" + argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_less_equal" + argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_near" + argspec: "args=[\'x\', \'y\', \'rtol\', \'atol\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_negative" + argspec: "args=[\'x\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_non_negative" + argspec: "args=[\'x\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_non_positive" + argspec: "args=[\'x\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_none_equal" + argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_positive" + argspec: "args=[\'x\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_proper_iterable" + argspec: "args=[\'values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "assert_rank" + argspec: "args=[\'x\', \'rank\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_rank_at_least" + argspec: "args=[\'x\', \'rank\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_rank_in" + argspec: "args=[\'x\', \'ranks\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "assert_same_float_dtype" + argspec: "args=[\'tensors\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "assert_scalar" + argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "assert_type" + argspec: "args=[\'tensor\', \'tf_type\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "assert_variables_initialized" + argspec: "args=[\'var_list\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "assign" + argspec: "args=[\'ref\', \'value\', \'validate_shape\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "assign_add" + argspec: "args=[\'ref\', \'value\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "assign_sub" + argspec: "args=[\'ref\', \'value\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "atan" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "atan2" + argspec: "args=[\'y\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "atanh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "batch_to_space" + argspec: "args=[\'input\', \'crops\', \'block_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "batch_to_space_nd" + argspec: "args=[\'input\', \'block_shape\', \'crops\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "betainc" + argspec: "args=[\'a\', \'b\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bincount" + argspec: "args=[\'arr\', \'weights\', \'minlength\', \'maxlength\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"\"], " + } + member_method { + name: "bitcast" + argspec: "args=[\'input\', \'type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "boolean_mask" + argspec: "args=[\'tensor\', \'mask\', \'name\', \'axis\'], varargs=None, keywords=None, defaults=[\'boolean_mask\', \'None\'], " + } + member_method { + name: "broadcast_dynamic_shape" + argspec: "args=[\'shape_x\', \'shape_y\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "broadcast_static_shape" + argspec: "args=[\'shape_x\', \'shape_y\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "broadcast_to" + argspec: "args=[\'input\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "case" + argspec: "args=[\'pred_fn_pairs\', \'default\', \'exclusive\', \'strict\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'False\', \'case\'], " + } + member_method { + name: "cast" + argspec: "args=[\'x\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ceil" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "check_numerics" + argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cholesky" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cholesky_solve" + argspec: "args=[\'chol\', \'rhs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "clip_by_average_norm" + argspec: "args=[\'t\', \'clip_norm\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "clip_by_global_norm" + argspec: "args=[\'t_list\', \'clip_norm\', \'use_norm\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "clip_by_norm" + argspec: "args=[\'t\', \'clip_norm\', \'axes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "clip_by_value" + argspec: "args=[\'t\', \'clip_value_min\', \'clip_value_max\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "colocate_with" + argspec: "args=[\'op\', \'ignore_existing\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "complex" + argspec: "args=[\'real\', \'imag\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "concat" + argspec: "args=[\'values\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'concat\'], " + } + member_method { + name: "cond" + argspec: "args=[\'pred\', \'true_fn\', \'false_fn\', \'strict\', \'name\', \'fn1\', \'fn2\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "confusion_matrix" + argspec: "args=[\'labels\', \'predictions\', \'num_classes\', \'dtype\', \'name\', \'weights\'], varargs=None, keywords=None, defaults=[\'None\', \"\", \'None\', \'None\'], " + } + member_method { + name: "conj" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "constant" + argspec: "args=[\'value\', \'dtype\', \'shape\', \'name\', \'verify_shape\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Const\', \'False\'], " + } + member_method { + name: "container" + argspec: "args=[\'container_name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "control_dependencies" + argspec: "args=[\'control_inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "convert_to_tensor" + argspec: "args=[\'value\', \'dtype\', \'name\', \'preferred_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "convert_to_tensor_or_indexed_slices" + argspec: "args=[\'value\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "convert_to_tensor_or_sparse_tensor" + argspec: "args=[\'value\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "cos" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cosh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "count_nonzero" + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'dtype\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"\", \'None\', \'None\', \'None\'], " + } + member_method { + name: "count_up_to" + argspec: "args=[\'ref\', \'limit\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "create_partitioned_variables" + argspec: "args=[\'shape\', \'slicing\', \'initializer\', \'dtype\', \'trainable\', \'collections\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\"\", \'True\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "cross" + argspec: "args=[\'a\', \'b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cumprod" + argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "cumsum" + argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "custom_gradient" + argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "decode_base64" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "decode_compressed" + argspec: "args=[\'bytes\', \'compression_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], " + } + member_method { + name: "decode_csv" + argspec: "args=[\'records\', \'record_defaults\', \'field_delim\', \'use_quote_delim\', \'name\', \'na_value\', \'select_cols\'], varargs=None, keywords=None, defaults=[\',\', \'True\', \'None\', \'\', \'None\'], " + } + member_method { + name: "decode_json_example" + argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "decode_raw" + argspec: "args=[\'bytes\', \'out_type\', \'little_endian\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "delete_session_tensor" + argspec: "args=[\'handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "depth_to_space" + argspec: "args=[\'input\', \'block_size\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'NHWC\'], " + } + member_method { + name: "dequantize" + argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\'], " + } + member_method { + name: "deserialize_many_sparse" + argspec: "args=[\'serialized_sparse\', \'dtype\', \'rank\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "device" + argspec: "args=[\'device_name_or_function\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "diag" + argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "diag_part" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "digamma" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "div" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "divide" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dynamic_partition" + argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dynamic_stitch" + argspec: "args=[\'indices\', \'data\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "edit_distance" + argspec: "args=[\'hypothesis\', \'truth\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'edit_distance\'], " + } + member_method { + name: "einsum" + argspec: "args=[\'equation\'], varargs=inputs, keywords=kwargs, defaults=None" + } + member_method { + name: "enable_eager_execution" + argspec: "args=[\'config\', \'device_policy\', \'execution_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "encode_base64" + argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "erf" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "erfc" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "executing_eagerly" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "exp" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "expand_dims" + argspec: "args=[\'input\', \'axis\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "expm1" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "extract_image_patches" + argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "eye" + argspec: "args=[\'num_rows\', \'num_columns\', \'batch_shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"\", \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_args" + argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'-6\', \'6\', \'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_args_gradient" + argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'-6\', \'6\', \'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars" + argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars_gradient" + argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars_per_channel" + argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars_per_channel_gradient" + argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "fft" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "fft2d" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "fft3d" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "fill" + argspec: "args=[\'dims\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "fixed_size_partitioner" + argspec: "args=[\'num_shards\', \'axis\'], varargs=None, keywords=None, defaults=[\'0\'], " + } + member_method { + name: "floor" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "floor_div" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "floordiv" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "floormod" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "foldl" + argspec: "args=[\'fn\', \'elems\', \'initializer\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\'], " + } + member_method { + name: "foldr" + argspec: "args=[\'fn\', \'elems\', \'initializer\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\'], " + } + member_method { + name: "gather" + argspec: "args=[\'params\', \'indices\', \'validate_indices\', \'name\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0\'], " + } + member_method { + name: "gather_nd" + argspec: "args=[\'params\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_collection" + argspec: "args=[\'key\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_collection_ref" + argspec: "args=[\'key\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_default_graph" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_default_session" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_local_variable" + argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'synchronization\', \'aggregation\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\', \'None\'], " + } + member_method { + name: "get_seed" + argspec: "args=[\'op_seed\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_session_handle" + argspec: "args=[\'data\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_session_tensor" + argspec: "args=[\'handle\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_variable" + argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "get_variable_scope" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "global_norm" + argspec: "args=[\'t_list\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "global_variables" + argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "global_variables_initializer" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "glorot_normal_initializer" + argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " + } + member_method { + name: "glorot_uniform_initializer" + argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " + } + member_method { + name: "gradients" + argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "greater" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "greater_equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "group" + argspec: "args=[], varargs=inputs, keywords=kwargs, defaults=None" + } + member_method { + name: "guarantee_const" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "hessians" + argspec: "args=[\'ys\', \'xs\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\'], varargs=None, keywords=None, defaults=[\'hessians\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "histogram_fixed_width" + argspec: "args=[\'values\', \'value_range\', \'nbins\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'100\', \"\", \'None\'], " + } + member_method { + name: "histogram_fixed_width_bins" + argspec: "args=[\'values\', \'value_range\', \'nbins\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'100\', \"\", \'None\'], " + } + member_method { + name: "identity" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "identity_n" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ifft" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ifft2d" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ifft3d" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "igamma" + argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "igammac" + argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "imag" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "import_graph_def" + argspec: "args=[\'graph_def\', \'input_map\', \'return_elements\', \'name\', \'op_dict\', \'producer_op_list\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "initialize_all_tables" + argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'init_all_tables\'], " + } + member_method { + name: "initialize_all_variables" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "initialize_local_variables" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "initialize_variables" + argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'init\'], " + } + member_method { + name: "invert_permutation" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_finite" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_inf" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_nan" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_non_decreasing" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_numeric_tensor" + argspec: "args=[\'tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_strictly_increasing" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_variable_initialized" + argspec: "args=[\'variable\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "lbeta" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "less" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "less_equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "lgamma" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "lin_space" + argspec: "args=[\'start\', \'stop\', \'num\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "linspace" + argspec: "args=[\'start\', \'stop\', \'num\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "load_file_system_library" + argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_op_library" + argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "local_variables" + argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "local_variables_initializer" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "log" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "log1p" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "log_sigmoid" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_and" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_not" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_or" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_xor" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'LogicalXor\'], " + } + member_method { + name: "make_ndarray" + argspec: "args=[\'tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "make_template" + argspec: "args=[\'name_\', \'func_\', \'create_scope_now_\', \'unique_name_\', \'custom_getter_\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'None\', \'None\'], " + } + member_method { + name: "make_tensor_proto" + argspec: "args=[\'values\', \'dtype\', \'shape\', \'verify_shape\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " + } + member_method { + name: "map_fn" + argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'True\', \'None\'], " + } + member_method { + name: "matching_files" + argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "matmul" + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "matrix_band_part" + argspec: "args=[\'input\', \'num_lower\', \'num_upper\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "matrix_determinant" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "matrix_diag" + argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "matrix_diag_part" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "matrix_inverse" + argspec: "args=[\'input\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "matrix_set_diag" + argspec: "args=[\'input\', \'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "matrix_solve" + argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "matrix_solve_ls" + argspec: "args=[\'matrix\', \'rhs\', \'l2_regularizer\', \'fast\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'True\', \'None\'], " + } + member_method { + name: "matrix_transpose" + argspec: "args=[\'a\', \'name\', \'conjugate\'], varargs=None, keywords=None, defaults=[\'matrix_transpose\', \'False\'], " + } + member_method { + name: "matrix_triangular_solve" + argspec: "args=[\'matrix\', \'rhs\', \'lower\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'None\'], " + } + member_method { + name: "maximum" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "meshgrid" + argspec: "args=[], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "min_max_variable_partitioner" + argspec: "args=[\'max_partitions\', \'axis\', \'min_slice_size\', \'bytes_per_string_element\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'262144\', \'16\'], " + } + member_method { + name: "minimum" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "mod" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "model_variables" + argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "moving_average_variables" + argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "multinomial" + argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'name\', \'output_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "multiply" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "negative" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "no_op" + argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "no_regularizer" + argspec: "args=[\'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "norm" + argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "not_equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "one_hot" + argspec: "args=[\'indices\', \'depth\', \'on_value\', \'off_value\', \'axis\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "ones" + argspec: "args=[\'shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } + member_method { + name: "ones_like" + argspec: "args=[\'tensor\', \'dtype\', \'name\', \'optimize\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], " + } + member_method { + name: "op_scope" + argspec: "args=[\'values\', \'name\', \'default_name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "pad" + argspec: "args=[\'tensor\', \'paddings\', \'mode\', \'name\', \'constant_values\'], varargs=None, keywords=None, defaults=[\'CONSTANT\', \'None\', \'0\'], " + } + member_method { + name: "parallel_stack" + argspec: "args=[\'values\', \'name\'], varargs=None, keywords=None, defaults=[\'parallel_stack\'], " + } + member_method { + name: "parse_example" + argspec: "args=[\'serialized\', \'features\', \'name\', \'example_names\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "parse_single_example" + argspec: "args=[\'serialized\', \'features\', \'name\', \'example_names\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "parse_single_sequence_example" + argspec: "args=[\'serialized\', \'context_features\', \'sequence_features\', \'example_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "parse_tensor" + argspec: "args=[\'serialized\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "placeholder" + argspec: "args=[\'dtype\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "placeholder_with_default" + argspec: "args=[\'input\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "polygamma" + argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "pow" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "py_func" + argspec: "args=[\'func\', \'inp\', \'Tout\', \'stateful\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "qr" + argspec: "args=[\'input\', \'full_matrices\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "quantize" + argspec: "args=[\'input\', \'min_range\', \'max_range\', \'T\', \'mode\', \'round_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'HALF_AWAY_FROM_ZERO\', \'None\'], " + } + member_method { + name: "quantize_v2" + argspec: "args=[\'input\', \'min_range\', \'max_range\', \'T\', \'mode\', \'name\', \'round_mode\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'HALF_AWAY_FROM_ZERO\'], " + } + member_method { + name: "quantized_concat" + argspec: "args=[\'concat_dim\', \'values\', \'input_mins\', \'input_maxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_crop" + argspec: "args=[\'value\', \'size\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "random_gamma" + argspec: "args=[\'shape\', \'alpha\', \'beta\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"\", \'None\', \'None\'], " + } + member_method { + name: "random_normal" + argspec: "args=[\'shape\', \'mean\', \'stddev\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"\", \'None\', \'None\'], " + } + member_method { + name: "random_poisson" + argspec: "args=[\'lam\', \'shape\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\', \'None\'], " + } + member_method { + name: "random_shuffle" + argspec: "args=[\'value\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "random_uniform" + argspec: "args=[\'shape\', \'minval\', \'maxval\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \"\", \'None\', \'None\'], " + } + member_method { + name: "range" + argspec: "args=[\'start\', \'limit\', \'delta\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'range\'], " + } + member_method { + name: "rank" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_file" + argspec: "args=[\'filename\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "real" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "realdiv" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reciprocal" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reduce_all" + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "reduce_any" + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "reduce_join" + argspec: "args=[\'inputs\', \'axis\', \'keep_dims\', \'separator\', \'name\', \'reduction_indices\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'\', \'None\', \'None\'], " + } + member_method { + name: "reduce_logsumexp" + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "reduce_max" + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "reduce_mean" + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "reduce_min" + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "reduce_prod" + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "reduce_sum" + argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "regex_replace" + argspec: "args=[\'input\', \'pattern\', \'rewrite\', \'replace_global\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "register_tensor_conversion_function" + argspec: "args=[\'base_type\', \'conversion_func\', \'priority\'], varargs=None, keywords=None, defaults=[\'100\'], " + } + member_method { + name: "report_uninitialized_variables" + argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'report_uninitialized_variables\'], " + } + member_method { + name: "required_space_to_batch_paddings" + argspec: "args=[\'input_shape\', \'block_shape\', \'base_paddings\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "reset_default_graph" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reshape" + argspec: "args=[\'tensor\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reverse" + argspec: "args=[\'tensor\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reverse_sequence" + argspec: "args=[\'input\', \'seq_lengths\', \'seq_axis\', \'batch_axis\', \'name\', \'seq_dim\', \'batch_dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "reverse_v2" + argspec: "args=[\'tensor\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "rint" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "round" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "rsqrt" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "saturate_cast" + argspec: "args=[\'value\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "scalar_mul" + argspec: "args=[\'scalar\', \'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "scan" + argspec: "args=[\'fn\', \'elems\', \'initializer\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'True\', \'False\', \'None\'], " + } + member_method { + name: "scatter_add" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "scatter_div" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "scatter_max" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "scatter_min" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "scatter_mul" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "scatter_nd" + argspec: "args=[\'indices\', \'updates\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "scatter_nd_add" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "scatter_nd_sub" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "scatter_nd_update" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "scatter_sub" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "scatter_update" + argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "segment_max" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_mean" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_min" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_prod" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_sum" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "self_adjoint_eig" + argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "self_adjoint_eigvals" + argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sequence_mask" + argspec: "args=[\'lengths\', \'maxlen\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"\", \'None\'], " + } + member_method { + name: "serialize_many_sparse" + argspec: "args=[\'sp_input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " + } + member_method { + name: "serialize_sparse" + argspec: "args=[\'sp_input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " + } + member_method { + name: "serialize_tensor" + argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_random_seed" + argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "setdiff1d" + argspec: "args=[\'x\', \'y\', \'index_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } + member_method { + name: "shape" + argspec: "args=[\'input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " + } + member_method { + name: "shape_n" + argspec: "args=[\'input\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } + member_method { + name: "sigmoid" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sign" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sin" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sinh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "size" + argspec: "args=[\'input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " + } + member_method { + name: "slice" + argspec: "args=[\'input_\', \'begin\', \'size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "space_to_batch" + argspec: "args=[\'input\', \'paddings\', \'block_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "space_to_batch_nd" + argspec: "args=[\'input\', \'block_shape\', \'paddings\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "space_to_depth" + argspec: "args=[\'input\', \'block_size\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'NHWC\'], " + } + member_method { + name: "sparse_add" + argspec: "args=[\'a\', \'b\', \'thresh\'], varargs=None, keywords=None, defaults=[\'0\'], " + } + member_method { + name: "sparse_concat" + argspec: "args=[\'axis\', \'sp_inputs\', \'name\', \'expand_nonconcat_dim\', \'concat_dim\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "sparse_fill_empty_rows" + argspec: "args=[\'sp_input\', \'default_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_mask" + argspec: "args=[\'a\', \'mask_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_matmul" + argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "sparse_maximum" + argspec: "args=[\'sp_a\', \'sp_b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_merge" + argspec: "args=[\'sp_ids\', \'sp_values\', \'vocab_size\', \'name\', \'already_sorted\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "sparse_minimum" + argspec: "args=[\'sp_a\', \'sp_b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_placeholder" + argspec: "args=[\'dtype\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "sparse_reduce_max" + argspec: "args=[\'sp_input\', \'axis\', \'keepdims\', \'reduction_axes\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "sparse_reduce_max_sparse" + argspec: "args=[\'sp_input\', \'axis\', \'keepdims\', \'reduction_axes\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "sparse_reduce_sum" + argspec: "args=[\'sp_input\', \'axis\', \'keepdims\', \'reduction_axes\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "sparse_reduce_sum_sparse" + argspec: "args=[\'sp_input\', \'axis\', \'keepdims\', \'reduction_axes\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "sparse_reorder" + argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_reset_shape" + argspec: "args=[\'sp_input\', \'new_shape\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_reshape" + argspec: "args=[\'sp_input\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_retain" + argspec: "args=[\'sp_input\', \'to_retain\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "sparse_segment_mean" + argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "sparse_segment_sqrt_n" + argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "sparse_segment_sum" + argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "sparse_slice" + argspec: "args=[\'sp_input\', \'start\', \'size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_softmax" + argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_split" + argspec: "args=[\'keyword_required\', \'sp_input\', \'num_split\', \'axis\', \'name\', \'split_dim\'], varargs=None, keywords=None, defaults=[\'KeywordRequired()\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "sparse_tensor_dense_matmul" + argspec: "args=[\'sp_a\', \'b\', \'adjoint_a\', \'adjoint_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " + } + member_method { + name: "sparse_tensor_to_dense" + argspec: "args=[\'sp_input\', \'default_value\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'True\', \'None\'], " + } + member_method { + name: "sparse_to_dense" + argspec: "args=[\'sparse_indices\', \'output_shape\', \'sparse_values\', \'default_value\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'True\', \'None\'], " + } + member_method { + name: "sparse_to_indicator" + argspec: "args=[\'sp_input\', \'vocab_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_transpose" + argspec: "args=[\'sp_input\', \'perm\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "split" + argspec: "args=[\'value\', \'num_or_size_splits\', \'axis\', \'num\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'split\'], " + } + member_method { + name: "sqrt" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "square" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "squared_difference" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "squeeze" + argspec: "args=[\'input\', \'axis\', \'name\', \'squeeze_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "stack" + argspec: "args=[\'values\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'stack\'], " + } + member_method { + name: "stop_gradient" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "strided_slice" + argspec: "args=[\'input_\', \'begin\', \'end\', \'strides\', \'begin_mask\', \'end_mask\', \'ellipsis_mask\', \'new_axis_mask\', \'shrink_axis_mask\', \'var\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'0\', \'0\', \'0\', \'0\', \'None\', \'None\'], " + } + member_method { + name: "string_join" + argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], " + } + member_method { + name: "string_split" + argspec: "args=[\'source\', \'delimiter\', \'skip_empty\'], varargs=None, keywords=None, defaults=[\' \', \'True\'], " + } + member_method { + name: "string_strip" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "string_to_hash_bucket" + argspec: "args=[\'string_tensor\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "string_to_hash_bucket_fast" + argspec: "args=[\'input\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "string_to_hash_bucket_strong" + argspec: "args=[\'input\', \'num_buckets\', \'key\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "string_to_number" + argspec: "args=[\'string_tensor\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } + member_method { + name: "substr" + argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "subtract" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "svd" + argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], " + } + member_method { + name: "tables_initializer" + argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'init_all_tables\'], " + } + member_method { + name: "tan" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tanh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tensordot" + argspec: "args=[\'a\', \'b\', \'axes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tile" + argspec: "args=[\'input\', \'multiples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "timestamp" + argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_bfloat16" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'ToBFloat16\'], " + } + member_method { + name: "to_complex128" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'ToComplex128\'], " + } + member_method { + name: "to_complex64" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'ToComplex64\'], " + } + member_method { + name: "to_double" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'ToDouble\'], " + } + member_method { + name: "to_float" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'ToFloat\'], " + } + member_method { + name: "to_int32" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'ToInt32\'], " + } + member_method { + name: "to_int64" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'ToInt64\'], " + } + member_method { + name: "trace" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "trainable_variables" + argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "transpose" + argspec: "args=[\'a\', \'perm\', \'name\', \'conjugate\'], varargs=None, keywords=None, defaults=[\'None\', \'transpose\', \'False\'], " + } + member_method { + name: "truediv" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "truncated_normal" + argspec: "args=[\'shape\', \'mean\', \'stddev\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"\", \'None\', \'None\'], " + } + member_method { + name: "truncatediv" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "truncatemod" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tuple" + argspec: "args=[\'tensors\', \'name\', \'control_inputs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "unique" + argspec: "args=[\'x\', \'out_idx\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } + member_method { + name: "unique_with_counts" + argspec: "args=[\'x\', \'out_idx\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } + member_method { + name: "unravel_index" + argspec: "args=[\'indices\', \'dims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_max" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_mean" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_min" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_prod" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_sqrt_n" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_sum" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unstack" + argspec: "args=[\'value\', \'num\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'unstack\'], " + } + member_method { + name: "variable_axis_size_partitioner" + argspec: "args=[\'max_shard_bytes\', \'axis\', \'bytes_per_string_element\', \'max_shards\'], varargs=None, keywords=None, defaults=[\'0\', \'16\', \'None\'], " + } + member_method { + name: "variable_op_scope" + argspec: "args=[\'values\', \'name_or_scope\', \'default_name\', \'initializer\', \'regularizer\', \'caching_device\', \'partitioner\', \'custom_getter\', \'reuse\', \'dtype\', \'use_resource\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "variables_initializer" + argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'init\'], " + } + member_method { + name: "verify_tensor_all_finite" + argspec: "args=[\'t\', \'msg\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "where" + argspec: "args=[\'condition\', \'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "while_loop" + argspec: "args=[\'cond\', \'body\', \'loop_vars\', \'shape_invariants\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\', \'maximum_iterations\', \'return_same_structure\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\', \'None\', \'False\'], " + } + member_method { + name: "write_file" + argspec: "args=[\'filename\', \'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "zeros" + argspec: "args=[\'shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } + member_method { + name: "zeros_like" + argspec: "args=[\'tensor\', \'dtype\', \'name\', \'optimize\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], " + } + member_method { + name: "zeta" + argspec: "args=[\'x\', \'q\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checker.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.profiler.-advice-proto.-checker.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checker.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.profiler.-advice-proto.-checker.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.profiler.-advice-proto.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.profiler.-advice-proto.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.profiler.-advice-proto.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.profiler.-graph-node-proto.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.profiler.-graph-node-proto.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.profiler.-graph-node-proto.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-multi-graph-node-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.profiler.-multi-graph-node-proto.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.profiler.-multi-graph-node-proto.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.profiler.-multi-graph-node-proto.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.-id-to-string-entry.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.profiler.-op-log-proto.-id-to-string-entry.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.-id-to-string-entry.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.profiler.-op-log-proto.-id-to-string-entry.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.profiler.-op-log-proto.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.profiler.-op-log-proto.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.profiler.-op-log-proto.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-profile-option-builder.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.profiler.-profile-option-builder.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.profiler.-profile-option-builder.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.profiler.-profile-option-builder.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.-profiler.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.profiler.-profiler.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.profiler.-profiler.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.profiler.-profiler.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.profiler.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.profiler.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.profiler.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.profiler.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-compression-type.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-compression-type.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-compression-type.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-compression-type.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-options.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-writer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-writer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.python_io.-t-f-record-writer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-writer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.python_io.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.python_io.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.python_io.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.python_io.pbtxt diff --git a/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6d865efed0bfdada8dde64e86ddb5d2b2b364c79 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt @@ -0,0 +1,35 @@ +path: "tensorflow.quantization" +tf_module { + member_method { + name: "dequantize" + argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_args" + argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'-6\', \'6\', \'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_args_gradient" + argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'-6\', \'6\', \'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars" + argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars_gradient" + argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars_per_channel" + argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars_per_channel_gradient" + argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "quantized_concat" + argspec: "args=[\'concat_dim\', \'values\', \'input_mins\', \'input_maxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.random_normal_initializer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.random_normal_initializer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.random_normal_initializer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.random_normal_initializer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.random_uniform_initializer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.random_uniform_initializer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.random_uniform_initializer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.random_uniform_initializer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.resource_loader.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.resource_loader.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.resource_loader.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.resource_loader.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.builder.-saved-model-builder.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.saved_model.builder.-saved-model-builder.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.builder.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.builder.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.saved_model.builder.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.saved_model.builder.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.constants.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.constants.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.saved_model.constants.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.saved_model.constants.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.loader.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.loader.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.saved_model.loader.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.saved_model.loader.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.main_op.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.main_op.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.saved_model.main_op.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.saved_model.main_op.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.saved_model.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.saved_model.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.signature_constants.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.signature_constants.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.saved_model.signature_constants.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.saved_model.signature_constants.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.signature_def_utils.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.signature_def_utils.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.saved_model.signature_def_utils.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.saved_model.signature_def_utils.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.tag_constants.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.tag_constants.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.saved_model.tag_constants.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.saved_model.tag_constants.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.utils.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.utils.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.saved_model.utils.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.saved_model.utils.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.sets.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.sets.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.sets.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.sets.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.sparse.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.spectral.pbtxt similarity index 90% rename from tensorflow/tools/api/golden/tensorflow.spectral.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.spectral.pbtxt index 4f306540ccfdeac8ce59a394ec77b24284f13ceb..6a421ef12d58dc047905ec916cbe777b4ce19b9a 100644 --- a/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.spectral.pbtxt @@ -16,6 +16,10 @@ tf_module { name: "fft3d" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "idct" + argspec: "args=[\'input\', \'type\', \'n\', \'axis\', \'norm\', \'name\'], varargs=None, keywords=None, defaults=[\'2\', \'None\', \'-1\', \'None\', \'None\'], " + } member_method { name: "ifft" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9a831fed2692b30db6ce991c86f46a42908c0789 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt @@ -0,0 +1,43 @@ +path: "tensorflow.strings" +tf_module { + member_method { + name: "join" + argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], " + } + member_method { + name: "regex_full_match" + argspec: "args=[\'input\', \'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "regex_replace" + argspec: "args=[\'input\', \'pattern\', \'rewrite\', \'replace_global\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "split" + argspec: "args=[\'source\', \'sep\', \'maxsplit\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\'], " + } + member_method { + name: "strip" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "substr" + argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_hash_bucket" + argspec: "args=[\'string_tensor\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_hash_bucket_fast" + argspec: "args=[\'input\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_hash_bucket_strong" + argspec: "args=[\'input\', \'num_buckets\', \'key\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_number" + argspec: "args=[\'string_tensor\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-event.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.summary.-event.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.summary.-event.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.summary.-event.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-file-writer-cache.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.summary.-file-writer-cache.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.summary.-file-writer-cache.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.summary.-file-writer-cache.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-file-writer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.summary.-file-writer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.summary.-file-writer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.summary.-file-writer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-session-log.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.summary.-session-log.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.summary.-session-log.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.summary.-session-log.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary-description.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.summary.-summary-description.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.summary.-summary-description.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.summary.-summary-description.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary.-audio.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.summary.-summary.-audio.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.summary.-summary.-audio.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.summary.-summary.-audio.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary.-image.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.summary.-summary.-image.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.summary.-summary.-image.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.summary.-summary.-image.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary.-value.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.summary.-summary.-value.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.summary.-summary.-value.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.summary.-summary.-value.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-summary.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.summary.-summary.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.summary.-summary.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.summary.-summary.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.summary.-tagged-run-metadata.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.summary.-tagged-run-metadata.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.summary.-tagged-run-metadata.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.summary.-tagged-run-metadata.pbtxt diff --git a/tensorflow/tools/api/golden/v1/tensorflow.summary.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.summary.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7ed9cd77a01c2eadb5ea43a02306d60d505127a0 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.summary.pbtxt @@ -0,0 +1,67 @@ +path: "tensorflow.summary" +tf_module { + member { + name: "Event" + mtype: "" + } + member { + name: "FileWriter" + mtype: "" + } + member { + name: "FileWriterCache" + mtype: "" + } + member { + name: "SessionLog" + mtype: "" + } + member { + name: "Summary" + mtype: "" + } + member { + name: "SummaryDescription" + mtype: "" + } + member { + name: "TaggedRunMetadata" + mtype: "" + } + member_method { + name: "audio" + argspec: "args=[\'name\', \'tensor\', \'sample_rate\', \'max_outputs\', \'collections\', \'family\'], varargs=None, keywords=None, defaults=[\'3\', \'None\', \'None\'], " + } + member_method { + name: "get_summary_description" + argspec: "args=[\'node_def\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "histogram" + argspec: "args=[\'name\', \'values\', \'collections\', \'family\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "image" + argspec: "args=[\'name\', \'tensor\', \'max_outputs\', \'collections\', \'family\'], varargs=None, keywords=None, defaults=[\'3\', \'None\', \'None\'], " + } + member_method { + name: "merge" + argspec: "args=[\'inputs\', \'collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "merge_all" + argspec: "args=[\'key\', \'scope\', \'name\'], varargs=None, keywords=None, defaults=[\'summaries\', \'None\', \'None\'], " + } + member_method { + name: "scalar" + argspec: "args=[\'name\', \'tensor\', \'collections\', \'family\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "tensor_summary" + argspec: "args=[\'name\', \'tensor\', \'summary_description\', \'collections\', \'summary_metadata\', \'family\', \'display_name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "text" + argspec: "args=[\'name\', \'tensor\', \'collections\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.sysconfig.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.sysconfig.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.sysconfig.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.sysconfig.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.test.-benchmark.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.test.-benchmark.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.test.-benchmark.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.test.-benchmark.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.test.-stub-out-for-testing.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.test.-stub-out-for-testing.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.test.-stub-out-for-testing.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.test.-stub-out-for-testing.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.test.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.test.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.test.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.test.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adadelta-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-adadelta-optimizer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-adadelta-optimizer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-adadelta-optimizer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adagrad-d-a-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-adagrad-d-a-optimizer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-adagrad-d-a-optimizer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-adagrad-d-a-optimizer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adagrad-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-adagrad-optimizer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-adagrad-optimizer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-adagrad-optimizer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adam-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-adam-optimizer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-adam-optimizer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-adam-optimizer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-bytes-list.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-bytes-list.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-bytes-list.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-bytes-list.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-checkpoint-saver-hook.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-saver-hook.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-checkpoint-saver-hook.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-saver-hook.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-checkpoint-saver-listener.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-saver-listener.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-checkpoint-saver-listener.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint-saver-listener.pbtxt diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5be37200f368b1823093c67ad7042db534b0df93 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-checkpoint.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.train.Checkpoint" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "save_counter" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "restore" + argspec: "args=[\'self\', \'save_path\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "save" + argspec: "args=[\'self\', \'file_prefix\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "write" + argspec: "args=[\'self\', \'file_prefix\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-chief-session-creator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-chief-session-creator.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-chief-session-creator.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-chief-session-creator.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-cluster-def.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-cluster-def.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-cluster-def.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-cluster-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-cluster-spec.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-cluster-spec.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-cluster-spec.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-coordinator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-coordinator.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-coordinator.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-coordinator.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-example.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-example.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-example.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-example.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-exponential-moving-average.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-exponential-moving-average.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-feature-list.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-feature-list.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-feature-list.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-feature-list.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-feature-lists.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-feature-lists.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-feature-lists.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-feature-lists.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-feature.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-feature.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-feature.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-feature.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-features.-feature-entry.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-features.-feature-entry.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-features.-feature-entry.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-features.-feature-entry.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-features.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-features.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-features.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-features.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-feed-fn-hook.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-feed-fn-hook.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-feed-fn-hook.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-feed-fn-hook.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-final-ops-hook.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-final-ops-hook.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-final-ops-hook.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-final-ops-hook.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-float-list.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-float-list.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-float-list.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-float-list.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-ftrl-optimizer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-ftrl-optimizer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-global-step-waiter-hook.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-global-step-waiter-hook.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-global-step-waiter-hook.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-global-step-waiter-hook.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-gradient-descent-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-gradient-descent-optimizer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-gradient-descent-optimizer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-gradient-descent-optimizer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-int64-list.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-int64-list.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-int64-list.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-int64-list.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-job-def.-tasks-entry.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-job-def.-tasks-entry.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-job-def.-tasks-entry.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-job-def.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-job-def.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-job-def.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-logging-tensor-hook.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-logging-tensor-hook.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-logging-tensor-hook.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-logging-tensor-hook.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-looper-thread.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-looper-thread.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-looper-thread.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-looper-thread.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-momentum-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-momentum-optimizer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-momentum-optimizer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-momentum-optimizer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-monitored-session.-step-context.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-monitored-session.-step-context.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-monitored-session.-step-context.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-monitored-session.-step-context.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-monitored-session.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-monitored-session.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-monitored-session.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-monitored-session.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-nan-loss-during-training-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-nan-loss-during-training-error.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-nan-loss-during-training-error.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-nan-loss-during-training-error.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-nan-tensor-hook.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-nan-tensor-hook.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-nan-tensor-hook.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-nan-tensor-hook.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-optimizer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-optimizer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-optimizer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-profiler-hook.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-profiler-hook.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-profiler-hook.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-profiler-hook.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-proximal-adagrad-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-proximal-adagrad-optimizer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-proximal-adagrad-optimizer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-proximal-adagrad-optimizer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-queue-runner.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-queue-runner.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-queue-runner.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-queue-runner.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-r-m-s-prop-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-r-m-s-prop-optimizer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-r-m-s-prop-optimizer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-r-m-s-prop-optimizer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-saver-def.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-saver-def.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-saver-def.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-saver-def.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-saver.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-saver.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-saver.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-saver.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-scaffold.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-scaffold.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-scaffold.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-scaffold.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-second-or-step-timer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-second-or-step-timer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-second-or-step-timer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-second-or-step-timer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-sequence-example.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-sequence-example.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-sequence-example.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-sequence-example.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-server-def.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-server-def.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-server-def.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-server-def.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-server.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-server.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-server.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-server.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-session-creator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-session-creator.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-session-creator.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-session-creator.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-session-manager.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-session-manager.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-session-run-args.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-session-run-args.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-session-run-args.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-session-run-args.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-session-run-context.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-session-run-context.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-session-run-context.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-session-run-context.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-session-run-hook.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-session-run-hook.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-session-run-hook.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-session-run-hook.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-session-run-values.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-session-run-values.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-session-run-values.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-session-run-values.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-singular-monitored-session.-step-context.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-singular-monitored-session.-step-context.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-singular-monitored-session.-step-context.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-singular-monitored-session.-step-context.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-singular-monitored-session.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-singular-monitored-session.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-singular-monitored-session.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-singular-monitored-session.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-step-counter-hook.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-step-counter-hook.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-step-counter-hook.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-step-counter-hook.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-stop-at-step-hook.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-stop-at-step-hook.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-stop-at-step-hook.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-stop-at-step-hook.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-summary-saver-hook.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-summary-saver-hook.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-summary-saver-hook.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-summary-saver-hook.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-supervisor.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-supervisor.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-sync-replicas-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-sync-replicas-optimizer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-sync-replicas-optimizer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-sync-replicas-optimizer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-vocab-info.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.-worker-session-creator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-worker-session-creator.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.-worker-session-creator.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.-worker-session-creator.pbtxt diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9f3539528435f0487492deb10fa2cfb63f8f58ae --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt @@ -0,0 +1,459 @@ +path: "tensorflow.train" +tf_module { + member { + name: "AdadeltaOptimizer" + mtype: "" + } + member { + name: "AdagradDAOptimizer" + mtype: "" + } + member { + name: "AdagradOptimizer" + mtype: "" + } + member { + name: "AdamOptimizer" + mtype: "" + } + member { + name: "BytesList" + mtype: "" + } + member { + name: "Checkpoint" + mtype: "" + } + member { + name: "CheckpointSaverHook" + mtype: "" + } + member { + name: "CheckpointSaverListener" + mtype: "" + } + member { + name: "ChiefSessionCreator" + mtype: "" + } + member { + name: "ClusterDef" + mtype: "" + } + member { + name: "ClusterSpec" + mtype: "" + } + member { + name: "Coordinator" + mtype: "" + } + member { + name: "Example" + mtype: "" + } + member { + name: "ExponentialMovingAverage" + mtype: "" + } + member { + name: "Feature" + mtype: "" + } + member { + name: "FeatureList" + mtype: "" + } + member { + name: "FeatureLists" + mtype: "" + } + member { + name: "Features" + mtype: "" + } + member { + name: "FeedFnHook" + mtype: "" + } + member { + name: "FinalOpsHook" + mtype: "" + } + member { + name: "FloatList" + mtype: "" + } + member { + name: "FtrlOptimizer" + mtype: "" + } + member { + name: "GlobalStepWaiterHook" + mtype: "" + } + member { + name: "GradientDescentOptimizer" + mtype: "" + } + member { + name: "Int64List" + mtype: "" + } + member { + name: "JobDef" + mtype: "" + } + member { + name: "LoggingTensorHook" + mtype: "" + } + member { + name: "LooperThread" + mtype: "" + } + member { + name: "MomentumOptimizer" + mtype: "" + } + member { + name: "MonitoredSession" + mtype: "" + } + member { + name: "NanLossDuringTrainingError" + mtype: "" + } + member { + name: "NanTensorHook" + mtype: "" + } + member { + name: "Optimizer" + mtype: "" + } + member { + name: "ProfilerHook" + mtype: "" + } + member { + name: "ProximalAdagradOptimizer" + mtype: "" + } + member { + name: "ProximalGradientDescentOptimizer" + mtype: "" + } + member { + name: "QueueRunner" + mtype: "" + } + member { + name: "RMSPropOptimizer" + mtype: "" + } + member { + name: "Saver" + mtype: "" + } + member { + name: "SaverDef" + mtype: "" + } + member { + name: "Scaffold" + mtype: "" + } + member { + name: "SecondOrStepTimer" + mtype: "" + } + member { + name: "SequenceExample" + mtype: "" + } + member { + name: "Server" + mtype: "" + } + member { + name: "ServerDef" + mtype: "" + } + member { + name: "SessionCreator" + mtype: "" + } + member { + name: "SessionManager" + mtype: "" + } + member { + name: "SessionRunArgs" + mtype: "" + } + member { + name: "SessionRunContext" + mtype: "" + } + member { + name: "SessionRunHook" + mtype: "" + } + member { + name: "SessionRunValues" + mtype: "" + } + member { + name: "SingularMonitoredSession" + mtype: "" + } + member { + name: "StepCounterHook" + mtype: "" + } + member { + name: "StopAtStepHook" + mtype: "" + } + member { + name: "SummarySaverHook" + mtype: "" + } + member { + name: "Supervisor" + mtype: "" + } + member { + name: "SyncReplicasOptimizer" + mtype: "" + } + member { + name: "VocabInfo" + mtype: "" + } + member { + name: "WorkerSessionCreator" + mtype: "" + } + member { + name: "queue_runner" + mtype: "" + } + member_method { + name: "MonitoredTrainingSession" + argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\', \'save_checkpoint_steps\', \'summary_dir\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'\', \'\', \'\', \'None\', \'120\', \'100\', \'7200\', \'\', \'None\'], " + } + member_method { + name: "NewCheckpointReader" + argspec: "args=[\'filepattern\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_queue_runner" + argspec: "args=[\'qr\', \'collection\'], varargs=None, keywords=None, defaults=[\'queue_runners\'], " + } + member_method { + name: "assert_global_step" + argspec: "args=[\'global_step_tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "basic_train_loop" + argspec: "args=[\'supervisor\', \'train_step_fn\', \'args\', \'kwargs\', \'master\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'\'], " + } + member_method { + name: "batch" + argspec: "args=[\'tensors\', \'batch_size\', \'num_threads\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "batch_join" + argspec: "args=[\'tensors_list\', \'batch_size\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "checkpoint_exists" + argspec: "args=[\'checkpoint_prefix\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "cosine_decay" + argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\'], " + } + member_method { + name: "cosine_decay_restarts" + argspec: "args=[\'learning_rate\', \'global_step\', \'first_decay_steps\', \'t_mul\', \'m_mul\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'2.0\', \'1.0\', \'0.0\', \'None\'], " + } + member_method { + name: "create_global_step" + argspec: "args=[\'graph\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "do_quantize_training_on_graphdef" + argspec: "args=[\'input_graph\', \'num_bits\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "exponential_decay" + argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "export_meta_graph" + argspec: "args=[\'filename\', \'meta_info_def\', \'graph_def\', \'saver_def\', \'collection_list\', \'as_text\', \'graph\', \'export_scope\', \'clear_devices\', \'clear_extraneous_savers\', \'strip_default_attrs\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\', \'False\', \'False\'], " + } + member_method { + name: "generate_checkpoint_state_proto" + argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\', \'all_model_checkpoint_timestamps\', \'last_preserved_timestamp\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "get_checkpoint_mtimes" + argspec: "args=[\'checkpoint_prefixes\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_checkpoint_state" + argspec: "args=[\'checkpoint_dir\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_global_step" + argspec: "args=[\'graph\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_or_create_global_step" + argspec: "args=[\'graph\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "global_step" + argspec: "args=[\'sess\', \'global_step_tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "import_meta_graph" + argspec: "args=[\'meta_graph_or_file\', \'clear_devices\', \'import_scope\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'None\'], " + } + member_method { + name: "init_from_checkpoint" + argspec: "args=[\'ckpt_dir_or_file\', \'assignment_map\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "input_producer" + argspec: "args=[\'input_tensor\', \'element_shape\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'summary_name\', \'name\', \'cancel_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'32\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "inverse_time_decay" + argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'checkpoint_dir\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "limit_epochs" + argspec: "args=[\'tensor\', \'num_epochs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "linear_cosine_decay" + argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'num_periods\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'0.0\', \'0.001\', \'None\'], " + } + member_method { + name: "list_variables" + argspec: "args=[\'ckpt_dir_or_file\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_checkpoint" + argspec: "args=[\'ckpt_dir_or_file\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_variable" + argspec: "args=[\'ckpt_dir_or_file\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "match_filenames_once" + argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "maybe_batch" + argspec: "args=[\'tensors\', \'keep_input\', \'batch_size\', \'num_threads\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "maybe_batch_join" + argspec: "args=[\'tensors_list\', \'keep_input\', \'batch_size\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "maybe_shuffle_batch" + argspec: "args=[\'tensors\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'keep_input\', \'num_threads\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "maybe_shuffle_batch_join" + argspec: "args=[\'tensors_list\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'keep_input\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "natural_exp_decay" + argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "noisy_linear_cosine_decay" + argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'initial_variance\', \'variance_decay\', \'num_periods\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'1.0\', \'0.55\', \'0.5\', \'0.0\', \'0.001\', \'None\'], " + } + member_method { + name: "piecewise_constant" + argspec: "args=[\'x\', \'boundaries\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "polynomial_decay" + argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'end_learning_rate\', \'power\', \'cycle\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1.0\', \'False\', \'None\'], " + } + member_method { + name: "range_input_producer" + argspec: "args=[\'limit\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\'], " + } + member_method { + name: "remove_checkpoint" + argspec: "args=[\'checkpoint_prefix\', \'checkpoint_format_version\', \'meta_graph_suffix\'], varargs=None, keywords=None, defaults=[\'2\', \'meta\'], " + } + member_method { + name: "replica_device_setter" + argspec: "args=[\'ps_tasks\', \'ps_device\', \'worker_device\', \'merge_devices\', \'cluster\', \'ps_ops\', \'ps_strategy\'], varargs=None, keywords=None, defaults=[\'0\', \'/job:ps\', \'/job:worker\', \'True\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "sdca_fprint" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sdca_optimizer" + argspec: "args=[\'sparse_example_indices\', \'sparse_feature_indices\', \'sparse_feature_values\', \'dense_features\', \'example_weights\', \'example_labels\', \'sparse_indices\', \'sparse_weights\', \'dense_weights\', \'example_state_data\', \'loss_type\', \'l1\', \'l2\', \'num_loss_partitions\', \'num_inner_iterations\', \'adaptative\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "sdca_shrink_l1" + argspec: "args=[\'weights\', \'l1\', \'l2\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "shuffle_batch" + argspec: "args=[\'tensors\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'num_threads\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "shuffle_batch_join" + argspec: "args=[\'tensors_list\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "slice_input_producer" + argspec: "args=[\'tensor_list\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\'], " + } + member_method { + name: "start_queue_runners" + argspec: "args=[\'sess\', \'coord\', \'daemon\', \'start\', \'collection\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'True\', \'queue_runners\'], " + } + member_method { + name: "string_input_producer" + argspec: "args=[\'string_tensor\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\', \'cancel_op\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "summary_iterator" + argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update_checkpoint_state" + argspec: "args=[\'save_dir\', \'model_checkpoint_path\', \'all_model_checkpoint_paths\', \'latest_filename\', \'all_model_checkpoint_timestamps\', \'last_preserved_timestamp\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "warm_start" + argspec: "args=[\'ckpt_to_initialize_from\', \'vars_to_warm_start\', \'var_name_to_vocab_info\', \'var_name_to_prev_var_name\'], varargs=None, keywords=None, defaults=[\'.*\', \'None\', \'None\'], " + } + member_method { + name: "write_graph" + argspec: "args=[\'graph_or_graph_def\', \'logdir\', \'name\', \'as_text\'], varargs=None, keywords=None, defaults=[\'True\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.queue_runner.-queue-runner.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.queue_runner.-queue-runner.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.queue_runner.-queue-runner.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.queue_runner.-queue-runner.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.train.queue_runner.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.queue_runner.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.train.queue_runner.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.train.queue_runner.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.truncated_normal_initializer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.truncated_normal_initializer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.truncated_normal_initializer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.truncated_normal_initializer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.uniform_unit_scaling_initializer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.uniform_unit_scaling_initializer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.uniform_unit_scaling_initializer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.uniform_unit_scaling_initializer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.variable_scope.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.variable_scope.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.variable_scope.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.variable_scope.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.variance_scaling_initializer.pbtxt similarity index 88% rename from tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.variance_scaling_initializer.pbtxt index a58398d645e8397dc8e61a6e0241710c3e34218f..09d7bc03b4f238923db6778ec32ce78ae76eed61 100644 --- a/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.variance_scaling_initializer.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'normal\', \'None\', \"\"], " + argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'truncated_normal\', \'None\', \"\"], " } member_method { name: "from_config" diff --git a/tensorflow/tools/api/golden/tensorflow.zeros_initializer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.zeros_initializer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.zeros_initializer.pbtxt rename to tensorflow/tools/api/golden/v1/tensorflow.zeros_initializer.pbtxt diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-aggregation-method.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-aggregation-method.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f79029d3fe0b88a454b11456b3785c3ae28a253c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-aggregation-method.pbtxt @@ -0,0 +1,24 @@ +path: "tensorflow.AggregationMethod" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "ADD_N" + mtype: "" + } + member { + name: "DEFAULT" + mtype: "" + } + member { + name: "EXPERIMENTAL_ACCUMULATE_N" + mtype: "" + } + member { + name: "EXPERIMENTAL_TREE" + mtype: "" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-attr-value.-list-value.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-attr-value.-list-value.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f1dffd595285098afaeb0ff04e5db35d594f7fac --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-attr-value.-list-value.pbtxt @@ -0,0 +1,70 @@ +path: "tensorflow.AttrValue.ListValue" +tf_proto { + descriptor { + name: "ListValue" + field { + name: "s" + number: 2 + label: LABEL_REPEATED + type: TYPE_BYTES + } + field { + name: "i" + number: 3 + label: LABEL_REPEATED + type: TYPE_INT64 + options { + packed: true + } + } + field { + name: "f" + number: 4 + label: LABEL_REPEATED + type: TYPE_FLOAT + options { + packed: true + } + } + field { + name: "b" + number: 5 + label: LABEL_REPEATED + type: TYPE_BOOL + options { + packed: true + } + } + field { + name: "type" + number: 6 + label: LABEL_REPEATED + type: TYPE_ENUM + type_name: ".tensorflow.DataType" + options { + packed: true + } + } + field { + name: "shape" + number: 7 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorShapeProto" + } + field { + name: "tensor" + number: 8 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorProto" + } + field { + name: "func" + number: 9 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.NameAttrList" + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-attr-value.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-attr-value.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6ccd64f428c3b87c807d0af82f67a884187f738c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-attr-value.pbtxt @@ -0,0 +1,151 @@ +path: "tensorflow.AttrValue" +tf_proto { + descriptor { + name: "AttrValue" + field { + name: "s" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BYTES + oneof_index: 0 + } + field { + name: "i" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT64 + oneof_index: 0 + } + field { + name: "f" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + oneof_index: 0 + } + field { + name: "b" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_BOOL + oneof_index: 0 + } + field { + name: "type" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.DataType" + oneof_index: 0 + } + field { + name: "shape" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorShapeProto" + oneof_index: 0 + } + field { + name: "tensor" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorProto" + oneof_index: 0 + } + field { + name: "list" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.AttrValue.ListValue" + oneof_index: 0 + } + field { + name: "func" + number: 10 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.NameAttrList" + oneof_index: 0 + } + field { + name: "placeholder" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_STRING + oneof_index: 0 + } + nested_type { + name: "ListValue" + field { + name: "s" + number: 2 + label: LABEL_REPEATED + type: TYPE_BYTES + } + field { + name: "i" + number: 3 + label: LABEL_REPEATED + type: TYPE_INT64 + options { + packed: true + } + } + field { + name: "f" + number: 4 + label: LABEL_REPEATED + type: TYPE_FLOAT + options { + packed: true + } + } + field { + name: "b" + number: 5 + label: LABEL_REPEATED + type: TYPE_BOOL + options { + packed: true + } + } + field { + name: "type" + number: 6 + label: LABEL_REPEATED + type: TYPE_ENUM + type_name: ".tensorflow.DataType" + options { + packed: true + } + } + field { + name: "shape" + number: 7 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorShapeProto" + } + field { + name: "tensor" + number: 8 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorProto" + } + field { + name: "func" + number: 9 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.NameAttrList" + } + } + oneof_decl { + name: "value" + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator-base.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator-base.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c9a32c16b34a78bd5a182b7c0635a559bddc611d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator-base.pbtxt @@ -0,0 +1,29 @@ +path: "tensorflow.ConditionalAccumulatorBase" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "accumulator_ref" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\', \'shape\', \'accumulator_ref\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "num_accumulated" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_global_step" + argspec: "args=[\'self\', \'new_global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d23b3bd0cae1f9ab1c2896244a17d4d93e2427e9 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.ConditionalAccumulator" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "accumulator_ref" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], " + } + member_method { + name: "apply_grad" + argspec: "args=[\'self\', \'grad\', \'local_step\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], " + } + member_method { + name: "num_accumulated" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_global_step" + argspec: "args=[\'self\', \'new_global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "take_grad" + argspec: "args=[\'self\', \'num_required\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-device-count-entry.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-device-count-entry.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d9b142682899bf5d9fd5d942437359adf8962466 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-device-count-entry.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.ConfigProto.DeviceCountEntry" +tf_proto { + descriptor { + name: "DeviceCountEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + options { + map_entry: true + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ef9fe096a11a0a75576b9b3d2bc083a82e9818d4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.ConfigProto.Experimental" +tf_proto { + descriptor { + name: "Experimental" + field { + name: "collective_group_leader" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "client_handles_error_formatting" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..eeef15515d73cf45581533fb8d3b02e4cbc4c208 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt @@ -0,0 +1,142 @@ +path: "tensorflow.ConfigProto" +tf_proto { + descriptor { + name: "ConfigProto" + field { + name: "device_count" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.ConfigProto.DeviceCountEntry" + } + field { + name: "intra_op_parallelism_threads" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "inter_op_parallelism_threads" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "use_per_session_threads" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "session_inter_op_thread_pool" + number: 12 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.ThreadPoolOptionProto" + } + field { + name: "placement_period" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "device_filters" + number: 4 + label: LABEL_REPEATED + type: TYPE_STRING + } + field { + name: "gpu_options" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.GPUOptions" + } + field { + name: "allow_soft_placement" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "log_device_placement" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "graph_options" + number: 10 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.GraphOptions" + } + field { + name: "operation_timeout_in_ms" + number: 11 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "rpc_options" + number: 13 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.RPCOptions" + } + field { + name: "cluster_def" + number: 14 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.ClusterDef" + } + field { + name: "isolate_session_state" + number: 15 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "experimental" + number: 16 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.ConfigProto.Experimental" + } + nested_type { + name: "DeviceCountEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + options { + map_entry: true + } + } + nested_type { + name: "Experimental" + field { + name: "collective_group_leader" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "client_handles_error_formatting" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-d-type.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-d-type.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0b5b88bba80e6bf7b9d4917c73e3876e00ef956b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-d-type.pbtxt @@ -0,0 +1,77 @@ +path: "tensorflow.DType" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "as_datatype_enum" + mtype: "" + } + member { + name: "as_numpy_dtype" + mtype: "" + } + member { + name: "base_dtype" + mtype: "" + } + member { + name: "is_bool" + mtype: "" + } + member { + name: "is_complex" + mtype: "" + } + member { + name: "is_floating" + mtype: "" + } + member { + name: "is_integer" + mtype: "" + } + member { + name: "is_numpy_compatible" + mtype: "" + } + member { + name: "is_quantized" + mtype: "" + } + member { + name: "is_unsigned" + mtype: "" + } + member { + name: "limits" + mtype: "" + } + member { + name: "max" + mtype: "" + } + member { + name: "min" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "real_dtype" + mtype: "" + } + member { + name: "size" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'type_enum\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_compatible_with" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-device-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-device-spec.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..92e535c341447628a50d8941998a4065e78d12a5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-device-spec.pbtxt @@ -0,0 +1,37 @@ +path: "tensorflow.DeviceSpec" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "job" + mtype: "" + } + member { + name: "replica" + mtype: "" + } + member { + name: "task" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'job\', \'replica\', \'task\', \'device_type\', \'device_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "from_string" + argspec: "args=[\'spec\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "merge_from" + argspec: "args=[\'self\', \'dev\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "parse_from_string" + argspec: "args=[\'self\', \'spec\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "to_string" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-dimension.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-dimension.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a9ab27719b4d71f3d7ed10963ad896ccafa82f15 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-dimension.pbtxt @@ -0,0 +1,25 @@ +path: "tensorflow.Dimension" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "value" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "assert_is_compatible_with" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_compatible_with" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "merge_with" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-event.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-event.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3b75a1735be76fe77689736e492c42c54ab795c1 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-event.pbtxt @@ -0,0 +1,74 @@ +path: "tensorflow.Event" +tf_proto { + descriptor { + name: "Event" + field { + name: "wall_time" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_DOUBLE + } + field { + name: "step" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "file_version" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + oneof_index: 0 + } + field { + name: "graph_def" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + oneof_index: 0 + } + field { + name: "summary" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary" + oneof_index: 0 + } + field { + name: "log_message" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.LogMessage" + oneof_index: 0 + } + field { + name: "session_log" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SessionLog" + oneof_index: 0 + } + field { + name: "tagged_run_metadata" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TaggedRunMetadata" + oneof_index: 0 + } + field { + name: "meta_graph_def" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_BYTES + oneof_index: 0 + } + oneof_decl { + name: "what" + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-f-i-f-o-queue.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-f-i-f-o-queue.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a095616c00cfe8fb64413e2078ae1589a423d2f4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-f-i-f-o-queue.pbtxt @@ -0,0 +1,66 @@ +path: "tensorflow.FIFOQueue" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "dtypes" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "names" + mtype: "" + } + member { + name: "queue_ref" + mtype: "" + } + member { + name: "shapes" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'capacity\', \'dtypes\', \'shapes\', \'names\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'fifo_queue\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "dequeue" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_many" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_up_to" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue_many" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_list" + argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_closed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "size" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-fixed-len-feature.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-fixed-len-feature.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6933814a7b68f775e694fe940a7c65a8e31b9398 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-fixed-len-feature.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.FixedLenFeature" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "default_value" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-fixed-len-sequence-feature.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-fixed-len-sequence-feature.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c53878795190924e205a1e7efe1672f216869c41 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-fixed-len-sequence-feature.pbtxt @@ -0,0 +1,31 @@ +path: "tensorflow.FixedLenSequenceFeature" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_missing" + mtype: "" + } + member { + name: "default_value" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-fixed-length-record-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-fixed-length-record-reader.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..260c796fd65b90020eb2b8191645ffdb2402a4a4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-fixed-length-record-reader.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.FixedLengthRecordReader" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "reader_ref" + mtype: "" + } + member { + name: "supports_serialize" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'hop_bytes\', \'name\', \'encoding\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "num_records_produced" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_work_units_completed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read" + argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_up_to" + argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "restore_state" + argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize_state" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-g-p-u-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-g-p-u-options.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..353e63127de174a79c209a05327da2de20bf0dd7 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-g-p-u-options.pbtxt @@ -0,0 +1,92 @@ +path: "tensorflow.GPUOptions" +tf_proto { + descriptor { + name: "GPUOptions" + field { + name: "per_process_gpu_memory_fraction" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_DOUBLE + } + field { + name: "allow_growth" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "allocator_type" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "deferred_deletion_bytes" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "visible_device_list" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "polling_active_delay_usecs" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "polling_inactive_delay_msecs" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "force_gpu_compatible" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "experimental" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.GPUOptions.Experimental" + } + nested_type { + name: "Experimental" + field { + name: "virtual_devices" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.GPUOptions.Experimental.VirtualDevices" + } + field { + name: "use_unified_memory" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "num_dev_to_dev_copy_streams" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + nested_type { + name: "VirtualDevices" + field { + name: "memory_limit_mb" + number: 1 + label: LABEL_REPEATED + type: TYPE_FLOAT + } + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..cbf655498c02a6521ef45f722f30acd7c13de9cc --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt @@ -0,0 +1,29 @@ +path: "tensorflow.GradientTape" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'persistent\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "gradient" + argspec: "args=[\'self\', \'target\', \'sources\', \'output_gradients\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "stop_recording" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "watch" + argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "watched_variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-graph-def.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-graph-def.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..19eccff03d24719d95ea84ccdad4014aa777ccd5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-graph-def.pbtxt @@ -0,0 +1,36 @@ +path: "tensorflow.GraphDef" +tf_proto { + descriptor { + name: "GraphDef" + field { + name: "node" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.NodeDef" + } + field { + name: "versions" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.VersionDef" + } + field { + name: "version" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + options { + deprecated: true + } + } + field { + name: "library" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.FunctionDefLibrary" + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-graph-keys.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-graph-keys.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ffe479093397a9bf98d10aa4e054c643e64d5f5d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-graph-keys.pbtxt @@ -0,0 +1,140 @@ +path: "tensorflow.GraphKeys" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "ACTIVATIONS" + mtype: "" + } + member { + name: "ASSET_FILEPATHS" + mtype: "" + } + member { + name: "BIASES" + mtype: "" + } + member { + name: "CONCATENATED_VARIABLES" + mtype: "" + } + member { + name: "COND_CONTEXT" + mtype: "" + } + member { + name: "EVAL_STEP" + mtype: "" + } + member { + name: "GLOBAL_STEP" + mtype: "" + } + member { + name: "GLOBAL_VARIABLES" + mtype: "" + } + member { + name: "INIT_OP" + mtype: "" + } + member { + name: "LOCAL_INIT_OP" + mtype: "" + } + member { + name: "LOCAL_RESOURCES" + mtype: "" + } + member { + name: "LOCAL_VARIABLES" + mtype: "" + } + member { + name: "LOSSES" + mtype: "" + } + member { + name: "METRIC_VARIABLES" + mtype: "" + } + member { + name: "MODEL_VARIABLES" + mtype: "" + } + member { + name: "MOVING_AVERAGE_VARIABLES" + mtype: "" + } + member { + name: "QUEUE_RUNNERS" + mtype: "" + } + member { + name: "READY_FOR_LOCAL_INIT_OP" + mtype: "" + } + member { + name: "READY_OP" + mtype: "" + } + member { + name: "REGULARIZATION_LOSSES" + mtype: "" + } + member { + name: "RESOURCES" + mtype: "" + } + member { + name: "SAVEABLE_OBJECTS" + mtype: "" + } + member { + name: "SAVERS" + mtype: "" + } + member { + name: "SUMMARIES" + mtype: "" + } + member { + name: "SUMMARY_OP" + mtype: "" + } + member { + name: "TABLE_INITIALIZERS" + mtype: "" + } + member { + name: "TRAINABLE_RESOURCE_VARIABLES" + mtype: "" + } + member { + name: "TRAINABLE_VARIABLES" + mtype: "" + } + member { + name: "TRAIN_OP" + mtype: "" + } + member { + name: "UPDATE_OPS" + mtype: "" + } + member { + name: "VARIABLES" + mtype: "" + } + member { + name: "WEIGHTS" + mtype: "" + } + member { + name: "WHILE_CONTEXT" + mtype: "" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-graph-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-graph-options.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a9f99bc171cc3661031981f467f583b122e43476 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-graph-options.pbtxt @@ -0,0 +1,67 @@ +path: "tensorflow.GraphOptions" +tf_proto { + descriptor { + name: "GraphOptions" + field { + name: "enable_recv_scheduling" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "optimizer_options" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.OptimizerOptions" + } + field { + name: "build_cost_model" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "build_cost_model_after" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "infer_shapes" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "place_pruned_graph" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "enable_bfloat16_sendrecv" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "timeline_step" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "rewrite_options" + number: 10 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.RewriterConfig" + } + reserved_range { + start: 1 + end: 2 + } + reserved_name: "skip_common_subexpression_elimination" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-graph.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-graph.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..cdaeb55e30865e082054085f47d6a071ebf3affd --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-graph.pbtxt @@ -0,0 +1,141 @@ +path: "tensorflow.Graph" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "building_function" + mtype: "" + } + member { + name: "collections" + mtype: "" + } + member { + name: "finalized" + mtype: "" + } + member { + name: "graph_def_versions" + mtype: "" + } + member { + name: "seed" + mtype: "" + } + member { + name: "version" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_to_collection" + argspec: "args=[\'self\', \'name\', \'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_to_collections" + argspec: "args=[\'self\', \'names\', \'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_default" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_graph_def" + argspec: "args=[\'self\', \'from_version\', \'add_shapes\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "as_graph_element" + argspec: "args=[\'self\', \'obj\', \'allow_tensor\', \'allow_operation\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], " + } + member_method { + name: "clear_collection" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "colocate_with" + argspec: "args=[\'self\', \'op\', \'ignore_existing\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "container" + argspec: "args=[\'self\', \'container_name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "control_dependencies" + argspec: "args=[\'self\', \'control_inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "create_op" + argspec: "args=[\'self\', \'op_type\', \'inputs\', \'dtypes\', \'input_types\', \'name\', \'attrs\', \'op_def\', \'compute_shapes\', \'compute_device\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'True\'], " + } + member_method { + name: "device" + argspec: "args=[\'self\', \'device_name_or_function\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "finalize" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_all_collection_keys" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_collection" + argspec: "args=[\'self\', \'name\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_collection_ref" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_name_scope" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_operation_by_name" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_operations" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_tensor_by_name" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "gradient_override_map" + argspec: "args=[\'self\', \'op_type_map\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_feedable" + argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_fetchable" + argspec: "args=[\'self\', \'tensor_or_op\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "name_scope" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prevent_feeding" + argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prevent_fetching" + argspec: "args=[\'self\', \'op\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "switch_to_thread_local" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "unique_name" + argspec: "args=[\'self\', \'name\', \'mark_as_used\'], varargs=None, keywords=None, defaults=[\'True\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-histogram-proto.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-histogram-proto.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d4402f330b8a28eaa61eb2b74c9ca412dce06b62 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-histogram-proto.pbtxt @@ -0,0 +1,54 @@ +path: "tensorflow.HistogramProto" +tf_proto { + descriptor { + name: "HistogramProto" + field { + name: "min" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_DOUBLE + } + field { + name: "max" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_DOUBLE + } + field { + name: "num" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_DOUBLE + } + field { + name: "sum" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_DOUBLE + } + field { + name: "sum_squares" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_DOUBLE + } + field { + name: "bucket_limit" + number: 6 + label: LABEL_REPEATED + type: TYPE_DOUBLE + options { + packed: true + } + } + field { + name: "bucket" + number: 7 + label: LABEL_REPEATED + type: TYPE_DOUBLE + options { + packed: true + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-identity-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-identity-reader.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2eda320d6368324f4caea64767fe55aae28494f4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-identity-reader.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.IdentityReader" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "reader_ref" + mtype: "" + } + member { + name: "supports_serialize" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_records_produced" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_work_units_completed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read" + argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_up_to" + argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "restore_state" + argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize_state" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-indexed-slices.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-indexed-slices.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..fee84d85307dffb675b507a31c4f1fda60de869d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-indexed-slices.pbtxt @@ -0,0 +1,42 @@ +path: "tensorflow.IndexedSlices" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "dense_shape" + mtype: "" + } + member { + name: "device" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "indices" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member { + name: "values" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'values\', \'indices\', \'dense_shape\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-interactive-session.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-interactive-session.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0a3b81bf829f48e88e9c48ce26cdbb4207101a16 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-interactive-session.pbtxt @@ -0,0 +1,51 @@ +path: "tensorflow.InteractiveSession" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "graph" + mtype: "" + } + member { + name: "graph_def" + mtype: "" + } + member { + name: "sess_str" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'target\', \'graph\', \'config\'], varargs=None, keywords=None, defaults=[\'\', \'None\', \'None\'], " + } + member_method { + name: "as_default" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "close" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "list_devices" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "make_callable" + argspec: "args=[\'self\', \'fetches\', \'feed_list\', \'accept_options\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "partial_run" + argspec: "args=[\'self\', \'handle\', \'fetches\', \'feed_dict\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "partial_run_setup" + argspec: "args=[\'self\', \'fetches\', \'feeds\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "run" + argspec: "args=[\'self\', \'fetches\', \'feed_dict\', \'options\', \'run_metadata\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-l-m-d-b-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-l-m-d-b-reader.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f9b7e9bbca82858ca99e67d70cf93583ca75972f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-l-m-d-b-reader.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.LMDBReader" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "reader_ref" + mtype: "" + } + member { + name: "supports_serialize" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "num_records_produced" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_work_units_completed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read" + argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_up_to" + argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "restore_state" + argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize_state" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-log-message.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-log-message.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5023aa96bf3b4f3f550421db5f41872d9f62b70d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-log-message.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.LogMessage" +tf_proto { + descriptor { + name: "LogMessage" + field { + name: "level" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.LogMessage.Level" + } + field { + name: "message" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + enum_type { + name: "Level" + value { + name: "UNKNOWN" + number: 0 + } + value { + name: "DEBUGGING" + number: 10 + } + value { + name: "INFO" + number: 20 + } + value { + name: "WARN" + number: 30 + } + value { + name: "ERROR" + number: 40 + } + value { + name: "FATAL" + number: 50 + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0ba09bec4b3fa6e9eaf59978beaa958ebc038b4c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-meta-graph-def.-collection-def-entry.pbtxt @@ -0,0 +1,22 @@ +path: "tensorflow.MetaGraphDef.CollectionDefEntry" +tf_proto { + descriptor { + name: "CollectionDefEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.CollectionDef" + } + options { + map_entry: true + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-meta-graph-def.-meta-info-def.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-meta-graph-def.-meta-info-def.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..41c62a407b8577288016f2376c35ba6ec1c3c1ca --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-meta-graph-def.-meta-info-def.pbtxt @@ -0,0 +1,50 @@ +path: "tensorflow.MetaGraphDef.MetaInfoDef" +tf_proto { + descriptor { + name: "MetaInfoDef" + field { + name: "meta_graph_version" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "stripped_op_list" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.OpList" + } + field { + name: "any_info" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".google.protobuf.Any" + } + field { + name: "tags" + number: 4 + label: LABEL_REPEATED + type: TYPE_STRING + } + field { + name: "tensorflow_version" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "tensorflow_git_version" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "stripped_default_attrs" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..73dc414a779ded3d1f896e743b7f1f1a443352f0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-meta-graph-def.-signature-def-entry.pbtxt @@ -0,0 +1,22 @@ +path: "tensorflow.MetaGraphDef.SignatureDefEntry" +tf_proto { + descriptor { + name: "SignatureDefEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SignatureDef" + } + options { + map_entry: true + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-meta-graph-def.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-meta-graph-def.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d71c2358c93e9597726665fdf8f92e648b2ea772 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-meta-graph-def.pbtxt @@ -0,0 +1,133 @@ +path: "tensorflow.MetaGraphDef" +tf_proto { + descriptor { + name: "MetaGraphDef" + field { + name: "meta_info_def" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.MetaGraphDef.MetaInfoDef" + } + field { + name: "graph_def" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.GraphDef" + } + field { + name: "saver_def" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SaverDef" + } + field { + name: "collection_def" + number: 4 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.MetaGraphDef.CollectionDefEntry" + } + field { + name: "signature_def" + number: 5 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.MetaGraphDef.SignatureDefEntry" + } + field { + name: "asset_file_def" + number: 6 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.AssetFileDef" + } + nested_type { + name: "MetaInfoDef" + field { + name: "meta_graph_version" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "stripped_op_list" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.OpList" + } + field { + name: "any_info" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".google.protobuf.Any" + } + field { + name: "tags" + number: 4 + label: LABEL_REPEATED + type: TYPE_STRING + } + field { + name: "tensorflow_version" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "tensorflow_git_version" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "stripped_default_attrs" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + } + nested_type { + name: "CollectionDefEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.CollectionDef" + } + options { + map_entry: true + } + } + nested_type { + name: "SignatureDefEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SignatureDef" + } + options { + map_entry: true + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-name-attr-list.-attr-entry.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-name-attr-list.-attr-entry.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b119b208772199e5c3596be142f3e0f62d3ed50e --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-name-attr-list.-attr-entry.pbtxt @@ -0,0 +1,22 @@ +path: "tensorflow.NameAttrList.AttrEntry" +tf_proto { + descriptor { + name: "AttrEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.AttrValue" + } + options { + map_entry: true + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-name-attr-list.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-name-attr-list.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..fcdb411ffce9b68ac28696f86ca11a47f9e64e8f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-name-attr-list.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.NameAttrList" +tf_proto { + descriptor { + name: "NameAttrList" + field { + name: "name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "attr" + number: 2 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.NameAttrList.AttrEntry" + } + nested_type { + name: "AttrEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.AttrValue" + } + options { + map_entry: true + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-node-def.-attr-entry.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-node-def.-attr-entry.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..622e4c3d0f60ce4842a6fd4cc421551aa795fcbf --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-node-def.-attr-entry.pbtxt @@ -0,0 +1,22 @@ +path: "tensorflow.NodeDef.AttrEntry" +tf_proto { + descriptor { + name: "AttrEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.AttrValue" + } + options { + map_entry: true + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-node-def.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-node-def.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..646fa8abb9b22dbd908ff821cbe66a33ad02ba64 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-node-def.pbtxt @@ -0,0 +1,56 @@ +path: "tensorflow.NodeDef" +tf_proto { + descriptor { + name: "NodeDef" + field { + name: "name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "op" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "input" + number: 3 + label: LABEL_REPEATED + type: TYPE_STRING + } + field { + name: "device" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "attr" + number: 5 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.NodeDef.AttrEntry" + } + nested_type { + name: "AttrEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.AttrValue" + } + options { + map_entry: true + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-op-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-op-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7e59615534fc2b3ed4fb128caf8ea092ebfd25f4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-op-error.pbtxt @@ -0,0 +1,29 @@ +path: "tensorflow.OpError" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\', \'error_code\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-operation.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-operation.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..64240f706983bb2ced63e49937800d2db4e627f2 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-operation.pbtxt @@ -0,0 +1,69 @@ +path: "tensorflow.Operation" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "control_inputs" + mtype: "" + } + member { + name: "device" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inputs" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op_def" + mtype: "" + } + member { + name: "outputs" + mtype: "" + } + member { + name: "traceback" + mtype: "" + } + member { + name: "traceback_with_start_lines" + mtype: "" + } + member { + name: "type" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'g\', \'inputs\', \'output_types\', \'control_inputs\', \'input_types\', \'original_op\', \'op_def\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "colocation_groups" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_attr" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "run" + argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "values" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-optimizer-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-optimizer-options.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3ccf9d459b133b48e5456f02e4780ade8d3042c8 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-optimizer-options.pbtxt @@ -0,0 +1,74 @@ +path: "tensorflow.OptimizerOptions" +tf_proto { + descriptor { + name: "OptimizerOptions" + field { + name: "do_common_subexpression_elimination" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "do_constant_folding" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "max_folded_constant_in_bytes" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "do_function_inlining" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "opt_level" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.OptimizerOptions.Level" + } + field { + name: "global_jit_level" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.OptimizerOptions.GlobalJitLevel" + } + enum_type { + name: "Level" + value { + name: "L1" + number: 0 + } + value { + name: "L0" + number: -1 + } + } + enum_type { + name: "GlobalJitLevel" + value { + name: "DEFAULT" + number: 0 + } + value { + name: "OFF" + number: -1 + } + value { + name: "ON_1" + number: 1 + } + value { + name: "ON_2" + number: 2 + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-padding-f-i-f-o-queue.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-padding-f-i-f-o-queue.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8fed133561544b91abfc64577e63a7088b43a007 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-padding-f-i-f-o-queue.pbtxt @@ -0,0 +1,66 @@ +path: "tensorflow.PaddingFIFOQueue" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "dtypes" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "names" + mtype: "" + } + member { + name: "queue_ref" + mtype: "" + } + member { + name: "shapes" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'capacity\', \'dtypes\', \'shapes\', \'names\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'padding_fifo_queue\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "dequeue" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_many" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_up_to" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue_many" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_list" + argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_closed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "size" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-priority-queue.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-priority-queue.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ebb017e81bc29e062d804fbe9f50c62f7b615dab --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-priority-queue.pbtxt @@ -0,0 +1,66 @@ +path: "tensorflow.PriorityQueue" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "dtypes" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "names" + mtype: "" + } + member { + name: "queue_ref" + mtype: "" + } + member { + name: "shapes" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'capacity\', \'types\', \'shapes\', \'names\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'priority_queue\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "dequeue" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_many" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_up_to" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue_many" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_list" + argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_closed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "size" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-queue-base.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-queue-base.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..761f90989f316611d42580ee911e24bb3d0d2fec --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-queue-base.pbtxt @@ -0,0 +1,65 @@ +path: "tensorflow.QueueBase" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "dtypes" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "names" + mtype: "" + } + member { + name: "queue_ref" + mtype: "" + } + member { + name: "shapes" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtypes\', \'shapes\', \'names\', \'queue_ref\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "close" + argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "dequeue" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_many" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_up_to" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue_many" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_list" + argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_closed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "size" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-random-shuffle-queue.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-random-shuffle-queue.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f3ca84139311bc05478e3dce876b53f7b9dec883 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-random-shuffle-queue.pbtxt @@ -0,0 +1,66 @@ +path: "tensorflow.RandomShuffleQueue" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "dtypes" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "names" + mtype: "" + } + member { + name: "queue_ref" + mtype: "" + } + member { + name: "shapes" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'capacity\', \'min_after_dequeue\', \'dtypes\', \'shapes\', \'names\', \'seed\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'random_shuffle_queue\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "dequeue" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_many" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue_up_to" + argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue_many" + argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_list" + argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_closed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "size" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-reader-base.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-reader-base.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f6a3ce76a157686becd92e2c7f873bfbc7572116 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-reader-base.pbtxt @@ -0,0 +1,45 @@ +path: "tensorflow.ReaderBase" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "reader_ref" + mtype: "" + } + member { + name: "supports_serialize" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'reader_ref\', \'supports_serialize\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "num_records_produced" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_work_units_completed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read" + argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_up_to" + argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "restore_state" + argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize_state" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-register-gradient.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-register-gradient.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4d6e4137d12d4a1ff283a114d4f0cc5602b0b734 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-register-gradient.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.RegisterGradient" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'op_type\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-run-metadata.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-run-metadata.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1287940326c0196e76fff2cf6363622226092504 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-run-metadata.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.RunMetadata" +tf_proto { + descriptor { + name: "RunMetadata" + field { + name: "step_stats" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.StepStats" + } + field { + name: "cost_graph" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.CostGraphDef" + } + field { + name: "partition_graphs" + number: 3 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.GraphDef" + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..537e73aa8969905c108a59688cfd99793ce211f0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.RunOptions.Experimental" +tf_proto { + descriptor { + name: "Experimental" + field { + name: "collective_graph_key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..cec04a2bf0962455495340da001214914cc8bb36 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt @@ -0,0 +1,83 @@ +path: "tensorflow.RunOptions" +tf_proto { + descriptor { + name: "RunOptions" + field { + name: "trace_level" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.RunOptions.TraceLevel" + } + field { + name: "timeout_in_ms" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "inter_op_thread_pool" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "output_partition_graphs" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "debug_options" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.DebugOptions" + } + field { + name: "report_tensor_allocations_upon_oom" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "experimental" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.RunOptions.Experimental" + } + nested_type { + name: "Experimental" + field { + name: "collective_graph_key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + } + enum_type { + name: "TraceLevel" + value { + name: "NO_TRACE" + number: 0 + } + value { + name: "SOFTWARE_TRACE" + number: 1 + } + value { + name: "HARDWARE_TRACE" + number: 2 + } + value { + name: "FULL_TRACE" + number: 3 + } + } + reserved_range { + start: 4 + end: 5 + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-session-log.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-session-log.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..259f2418740cbfe47cdb4bd871d4f5c6306d25f5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-session-log.pbtxt @@ -0,0 +1,44 @@ +path: "tensorflow.SessionLog" +tf_proto { + descriptor { + name: "SessionLog" + field { + name: "status" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.SessionLog.SessionStatus" + } + field { + name: "checkpoint_path" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "msg" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + enum_type { + name: "SessionStatus" + value { + name: "STATUS_UNSPECIFIED" + number: 0 + } + value { + name: "START" + number: 1 + } + value { + name: "STOP" + number: 2 + } + value { + name: "CHECKPOINT" + number: 3 + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-session.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-session.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1d6b037f9c3540653a8fb18b6508f74b01da66ab --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-session.pbtxt @@ -0,0 +1,55 @@ +path: "tensorflow.Session" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "graph" + mtype: "" + } + member { + name: "graph_def" + mtype: "" + } + member { + name: "sess_str" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'target\', \'graph\', \'config\'], varargs=None, keywords=None, defaults=[\'\', \'None\', \'None\'], " + } + member_method { + name: "as_default" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "close" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "list_devices" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "make_callable" + argspec: "args=[\'self\', \'fetches\', \'feed_list\', \'accept_options\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "partial_run" + argspec: "args=[\'self\', \'handle\', \'fetches\', \'feed_dict\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "partial_run_setup" + argspec: "args=[\'self\', \'fetches\', \'feeds\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + argspec: "args=[\'target\', \'containers\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "run" + argspec: "args=[\'self\', \'fetches\', \'feed_dict\', \'options\', \'run_metadata\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2260279ad2bcfc246f42b225adc05f7c19f1aac1 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.SparseConditionalAccumulator" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "accumulator_ref" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\'], " + } + member_method { + name: "apply_grad" + argspec: "args=[\'self\', \'grad_indices\', \'grad_values\', \'grad_shape\', \'local_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + } + member_method { + name: "apply_indexed_slices_grad" + argspec: "args=[\'self\', \'grad\', \'local_step\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], " + } + member_method { + name: "num_accumulated" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_global_step" + argspec: "args=[\'self\', \'new_global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "take_grad" + argspec: "args=[\'self\', \'num_required\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "take_indexed_slices_grad" + argspec: "args=[\'self\', \'num_required\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-feature.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-feature.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d875394fb5de73f67629b77c902a2ed2a03dd982 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-feature.pbtxt @@ -0,0 +1,35 @@ +path: "tensorflow.SparseFeature" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "already_sorted" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "index_key" + mtype: "" + } + member { + name: "size" + mtype: "" + } + member { + name: "value_key" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor-value.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor-value.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d33fd4d5d7b6b3e2eb7454b5326d993c139f0490 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor-value.pbtxt @@ -0,0 +1,26 @@ +path: "tensorflow.SparseTensorValue" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "dense_shape" + mtype: "" + } + member { + name: "indices" + mtype: "" + } + member { + name: "values" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..eac236d4982b809a0478665096c2b18d69c54184 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-tensor.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.SparseTensor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "dense_shape" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "indices" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member { + name: "values" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'indices\', \'values\', \'dense_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "eval" + argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "from_value" + argspec: "args=[\'cls\', \'sparse_tensor_value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_shape" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-summary-metadata.-plugin-data.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-summary-metadata.-plugin-data.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a66b74b315c6132e8f884bd52e7a3b5bd7f52ccd --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-summary-metadata.-plugin-data.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.SummaryMetadata.PluginData" +tf_proto { + descriptor { + name: "PluginData" + field { + name: "plugin_name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "content" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-summary-metadata.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-summary-metadata.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c02575b9626c848e9b871d2cc6febb26a5142f08 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-summary-metadata.pbtxt @@ -0,0 +1,40 @@ +path: "tensorflow.SummaryMetadata" +tf_proto { + descriptor { + name: "SummaryMetadata" + field { + name: "plugin_data" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SummaryMetadata.PluginData" + } + field { + name: "display_name" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "summary_description" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + nested_type { + name: "PluginData" + field { + name: "plugin_name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "content" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-summary.-audio.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-summary.-audio.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..94f712073e0d0dda201fcf7adba849dd45a1229b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-summary.-audio.pbtxt @@ -0,0 +1,36 @@ +path: "tensorflow.Summary.Audio" +tf_proto { + descriptor { + name: "Audio" + field { + name: "sample_rate" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + } + field { + name: "num_channels" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "length_frames" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "encoded_audio_string" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } + field { + name: "content_type" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-summary.-image.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-summary.-image.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..fc1acb483b3051cba01f5d9bc8501a61965bbc37 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-summary.-image.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.Summary.Image" +tf_proto { + descriptor { + name: "Image" + field { + name: "height" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "width" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "colorspace" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "encoded_image_string" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-summary.-value.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-summary.-value.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..feb84b6ee996549ac58aa0e8a4ac560f947b6339 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-summary.-value.pbtxt @@ -0,0 +1,74 @@ +path: "tensorflow.Summary.Value" +tf_proto { + descriptor { + name: "Value" + field { + name: "node_name" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "tag" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "metadata" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SummaryMetadata" + } + field { + name: "simple_value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + oneof_index: 0 + } + field { + name: "obsolete_old_style_histogram" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_BYTES + oneof_index: 0 + } + field { + name: "image" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Image" + oneof_index: 0 + } + field { + name: "histo" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.HistogramProto" + oneof_index: 0 + } + field { + name: "audio" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Audio" + oneof_index: 0 + } + field { + name: "tensor" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorProto" + oneof_index: 0 + } + oneof_decl { + name: "value" + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-summary.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-summary.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b2bdff7171804aae114d1e3631e3074b1e4006ba --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-summary.pbtxt @@ -0,0 +1,144 @@ +path: "tensorflow.Summary" +tf_proto { + descriptor { + name: "Summary" + field { + name: "value" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Value" + } + nested_type { + name: "Image" + field { + name: "height" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "width" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "colorspace" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "encoded_image_string" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } + } + nested_type { + name: "Audio" + field { + name: "sample_rate" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + } + field { + name: "num_channels" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "length_frames" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "encoded_audio_string" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } + field { + name: "content_type" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } + nested_type { + name: "Value" + field { + name: "node_name" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "tag" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "metadata" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SummaryMetadata" + } + field { + name: "simple_value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + oneof_index: 0 + } + field { + name: "obsolete_old_style_histogram" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_BYTES + oneof_index: 0 + } + field { + name: "image" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Image" + oneof_index: 0 + } + field { + name: "histo" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.HistogramProto" + oneof_index: 0 + } + field { + name: "audio" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Audio" + oneof_index: 0 + } + field { + name: "tensor" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorProto" + oneof_index: 0 + } + oneof_decl { + name: "value" + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-t-f-record-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-t-f-record-reader.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..cdf79373919b6c5f26c68996d8f1cf30e8992203 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-t-f-record-reader.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.TFRecordReader" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "reader_ref" + mtype: "" + } + member { + name: "supports_serialize" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "num_records_produced" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_work_units_completed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read" + argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_up_to" + argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "restore_state" + argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize_state" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-tensor-array.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-tensor-array.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ed088c41ed3fc444fb9e45919769950f1984e3e8 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-tensor-array.pbtxt @@ -0,0 +1,69 @@ +path: "tensorflow.TensorArray" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "dtype" + mtype: "" + } + member { + name: "flow" + mtype: "" + } + member { + name: "handle" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\', \'size\', \'dynamic_size\', \'clear_after_read\', \'tensor_array_name\', \'handle\', \'flow\', \'infer_shape\', \'element_shape\', \'colocate_with_first_write_call\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "concat" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "gather" + argspec: "args=[\'self\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "grad" + argspec: "args=[\'self\', \'source\', \'flow\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "identity" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "read" + argspec: "args=[\'self\', \'index\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "scatter" + argspec: "args=[\'self\', \'indices\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "size" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "split" + argspec: "args=[\'self\', \'value\', \'lengths\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "stack" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unstack" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "write" + argspec: "args=[\'self\', \'index\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-tensor-info.-coo-sparse.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-tensor-info.-coo-sparse.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0064c8460cb374f1e3f108085a2efed4131dd205 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-tensor-info.-coo-sparse.pbtxt @@ -0,0 +1,24 @@ +path: "tensorflow.TensorInfo.CooSparse" +tf_proto { + descriptor { + name: "CooSparse" + field { + name: "values_tensor_name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "indices_tensor_name" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "dense_shape_tensor_name" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-tensor-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-tensor-info.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..63566c808e55cb4d3b630f0a017fa3a2c8a30de3 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-tensor-info.pbtxt @@ -0,0 +1,59 @@ +path: "tensorflow.TensorInfo" +tf_proto { + descriptor { + name: "TensorInfo" + field { + name: "name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + oneof_index: 0 + } + field { + name: "coo_sparse" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorInfo.CooSparse" + oneof_index: 0 + } + field { + name: "dtype" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.DataType" + } + field { + name: "tensor_shape" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorShapeProto" + } + nested_type { + name: "CooSparse" + field { + name: "values_tensor_name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "indices_tensor_name" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "dense_shape_tensor_name" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } + oneof_decl { + name: "encoding" + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-tensor-shape.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-tensor-shape.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8e3598fb2470b327e6e3601969f055d4907f614a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-tensor-shape.pbtxt @@ -0,0 +1,77 @@ +path: "tensorflow.TensorShape" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "dims" + mtype: "" + } + member { + name: "ndims" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dims\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_list" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_proto" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "assert_has_rank" + argspec: "args=[\'self\', \'rank\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "assert_is_compatible_with" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "assert_is_fully_defined" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "assert_same_rank" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "concatenate" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_compatible_with" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_fully_defined" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "merge_with" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "most_specific_compatible_shape" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "num_elements" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_rank" + argspec: "args=[\'self\', \'rank\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_rank_at_least" + argspec: "args=[\'self\', \'rank\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_rank_at_most" + argspec: "args=[\'self\', \'rank\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..38d19bb5374037981c01b29053ab8d05b551eb84 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-tensor.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.Tensor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "OVERLOADABLE_OPERATORS" + mtype: "" + } + member { + name: "device" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "value_index" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'op\', \'value_index\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "consumers" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "eval" + argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "get_shape" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_shape" + argspec: "args=[\'self\', \'shape\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-text-line-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-text-line-reader.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e9779f07620d2cc1ef3b0ff1b2d32796fc10834a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-text-line-reader.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.TextLineReader" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "reader_ref" + mtype: "" + } + member { + name: "supports_serialize" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'skip_header_lines\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "num_records_produced" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_work_units_completed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read" + argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_up_to" + argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "restore_state" + argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize_state" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-var-len-feature.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-var-len-feature.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..54b66f43f8e7d714e82ae9d68b37ac348c476c97 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-var-len-feature.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.VarLenFeature" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "dtype" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-variable-aggregation.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-variable-aggregation.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..36b534af360835e3c1cbd1f0fb12a38c42232abf --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-variable-aggregation.pbtxt @@ -0,0 +1,16 @@ +path: "tensorflow.VariableAggregation" +tf_class { + is_instance: "" + member { + name: "MEAN" + mtype: "" + } + member { + name: "NONE" + mtype: "" + } + member { + name: "SUM" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-variable-scope.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-variable-scope.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c13eb7b8bb9474f3534582c8af8c3ee4b6c7e076 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-variable-scope.pbtxt @@ -0,0 +1,105 @@ +path: "tensorflow.VariableScope" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "caching_device" + mtype: "" + } + member { + name: "constraint" + mtype: "" + } + member { + name: "custom_getter" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "initializer" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "original_name_scope" + mtype: "" + } + member { + name: "partitioner" + mtype: "" + } + member { + name: "regularizer" + mtype: "" + } + member { + name: "reuse" + mtype: "" + } + member { + name: "use_resource" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'reuse\', \'name\', \'initializer\', \'regularizer\', \'caching_device\', \'partitioner\', \'custom_getter\', \'name_scope\', \'dtype\', \'use_resource\', \'constraint\'], varargs=None, keywords=None, defaults=[\'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'\', \"\", \'None\', \'None\'], " + } + member_method { + name: "get_collection" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable" + argspec: "args=[\'self\', \'var_store\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'reuse\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "global_variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "local_variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reuse_variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_caching_device" + argspec: "args=[\'self\', \'caching_device\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_custom_getter" + argspec: "args=[\'self\', \'custom_getter\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_dtype" + argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_initializer" + argspec: "args=[\'self\', \'initializer\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_partitioner" + argspec: "args=[\'self\', \'partitioner\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_regularizer" + argspec: "args=[\'self\', \'regularizer\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_use_resource" + argspec: "args=[\'self\', \'use_resource\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "trainable_variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-variable-synchronization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-variable-synchronization.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7589bb28888774839a3011e1e5581f004313f81d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-variable-synchronization.pbtxt @@ -0,0 +1,20 @@ +path: "tensorflow.VariableSynchronization" +tf_class { + is_instance: "" + member { + name: "AUTO" + mtype: "" + } + member { + name: "NONE" + mtype: "" + } + member { + name: "ON_READ" + mtype: "" + } + member { + name: "ON_WRITE" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-variable.-save-slice-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-variable.-save-slice-info.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ac3ccd468b216ab817c9ed05dcb292eaf1f44398 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-variable.-save-slice-info.pbtxt @@ -0,0 +1,17 @@ +path: "tensorflow.Variable.SaveSliceInfo" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "spec" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'full_name\', \'full_shape\', \'var_offset\', \'var_shape\', \'save_slice_info_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "to_proto" + argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e841c4ad8904ae1ae49881b47648b901a4abf778 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt @@ -0,0 +1,110 @@ +path: "tensorflow.Variable" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "SaveSliceInfo" + mtype: "" + } + member { + name: "constraint" + mtype: "" + } + member { + name: "device" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "initial_value" + mtype: "" + } + member { + name: "initializer" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "trainable" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "assign" + argspec: "args=[\'self\', \'value\', \'use_locking\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "assign_add" + argspec: "args=[\'self\', \'delta\', \'use_locking\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "assign_sub" + argspec: "args=[\'self\', \'delta\', \'use_locking\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "count_up_to" + argspec: "args=[\'self\', \'limit\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "eval" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_proto" + argspec: "args=[\'variable_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_shape" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "initialized_value" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load" + argspec: "args=[\'self\', \'value\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_value" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "scatter_sub" + argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "set_shape" + argspec: "args=[\'self\', \'shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "to_proto" + argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "value" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-whole-file-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-whole-file-reader.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4ac759891c62ae44bf8f8c365da75664f2e65ce2 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.-whole-file-reader.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.WholeFileReader" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "reader_ref" + mtype: "" + } + member { + name: "supports_serialize" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_records_produced" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_work_units_completed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read" + argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_up_to" + argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "restore_state" + argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize_state" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.app.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.app.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..85044a8987963126ae12aaa0e5eb5d1ecc134539 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.app.pbtxt @@ -0,0 +1,11 @@ +path: "tensorflow.app" +tf_module { + member { + name: "flags" + mtype: "" + } + member_method { + name: "run" + argspec: "args=[\'main\', \'argv\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.bitwise.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.bitwise.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..01cbd55c5d2e1b6fa3148af956217c3664864eaa --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.bitwise.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.bitwise" +tf_module { + member_method { + name: "bitwise_and" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bitwise_or" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bitwise_xor" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "invert" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "left_shift" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "right_shift" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.compat.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.compat.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f1d760603e981a0b9a72fdc379dc81932ac71d67 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.compat.pbtxt @@ -0,0 +1,47 @@ +path: "tensorflow.compat" +tf_module { + member { + name: "bytes_or_text_types" + mtype: "" + } + member { + name: "complex_types" + mtype: "" + } + member { + name: "integral_types" + mtype: "" + } + member { + name: "real_types" + mtype: "" + } + member_method { + name: "as_bytes" + argspec: "args=[\'bytes_or_text\', \'encoding\'], varargs=None, keywords=None, defaults=[\'utf-8\'], " + } + member_method { + name: "as_str" + argspec: "args=[\'bytes_or_text\', \'encoding\'], varargs=None, keywords=None, defaults=[\'utf-8\'], " + } + member_method { + name: "as_str_any" + argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_text" + argspec: "args=[\'bytes_or_text\', \'encoding\'], varargs=None, keywords=None, defaults=[\'utf-8\'], " + } + member_method { + name: "forward_compatibility_horizon" + argspec: "args=[\'year\', \'month\', \'day\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "forward_compatible" + argspec: "args=[\'year\', \'month\', \'day\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "path_to_str" + argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.constant_initializer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.constant_initializer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..00ec669b1685f3cbdacd676bac61755bebb9f6da --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.constant_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.constant_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'value\', \'dtype\', \'verify_shape\'], varargs=None, keywords=None, defaults=[\'0\', \"\", \'False\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..af08c88d3333fa897c38cc2f6530a9c5cda15342 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.data.Dataset.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..834f0954d5bba655a8eb923672d89bac6bb80808 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt @@ -0,0 +1,117 @@ +path: "tensorflow.data.Dataset" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "output_classes" + mtype: "" + } + member { + name: "output_shapes" + mtype: "" + } + member { + name: "output_types" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "cache" + argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], " + } + member_method { + name: "concatenate" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "filter" + argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "flat_map" + argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_generator" + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "from_sparse_tensor_slices" + argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_tensor_slices" + argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_tensors" + argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "interleave" + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + } + member_method { + name: "list_files" + argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "make_initializable_iterator" + argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "make_one_shot_iterator" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "map" + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "padded_batch" + argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "prefetch" + argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "range" + argspec: "args=[], varargs=args, keywords=None, defaults=None" + } + member_method { + name: "repeat" + argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "shard" + argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "shuffle" + argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "skip" + argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "take" + argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zip" + argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f384323fc89bb7d21309e86ddaab2e6e1f9f212b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.data.FixedLengthRecordDataset.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4d854a4ceea3907d7d795d0a19d081f4069c9ba9 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -0,0 +1,118 @@ +path: "tensorflow.data.FixedLengthRecordDataset" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "output_classes" + mtype: "" + } + member { + name: "output_shapes" + mtype: "" + } + member { + name: "output_types" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filenames\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'buffer_size\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "cache" + argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], " + } + member_method { + name: "concatenate" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "filter" + argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "flat_map" + argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_generator" + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "from_sparse_tensor_slices" + argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_tensor_slices" + argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_tensors" + argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "interleave" + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + } + member_method { + name: "list_files" + argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "make_initializable_iterator" + argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "make_one_shot_iterator" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "map" + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "padded_batch" + argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "prefetch" + argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "range" + argspec: "args=[], varargs=args, keywords=None, defaults=None" + } + member_method { + name: "repeat" + argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "shard" + argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "shuffle" + argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "skip" + argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "take" + argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zip" + argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-iterator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-iterator.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1f9aeb6ad62e1030c6e78f731fb5e05b876899e6 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-iterator.pbtxt @@ -0,0 +1,45 @@ +path: "tensorflow.data.Iterator" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "initializer" + mtype: "" + } + member { + name: "output_classes" + mtype: "" + } + member { + name: "output_shapes" + mtype: "" + } + member { + name: "output_types" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'iterator_resource\', \'initializer\', \'output_types\', \'output_shapes\', \'output_classes\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_string_handle" + argspec: "args=[\'string_handle\', \'output_types\', \'output_shapes\', \'output_classes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "from_structure" + argspec: "args=[\'output_types\', \'output_shapes\', \'shared_name\', \'output_classes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "get_next" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "make_initializer" + argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "string_handle" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b12dec8a70be5e0cd8346785b48f56b15155dd02 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.data.TFRecordDataset.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..601f095a60ae481b895a535efa37341611499499 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt @@ -0,0 +1,118 @@ +path: "tensorflow.data.TFRecordDataset" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "output_classes" + mtype: "" + } + member { + name: "output_shapes" + mtype: "" + } + member { + name: "output_types" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filenames\', \'compression_type\', \'buffer_size\', \'num_parallel_reads\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "cache" + argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], " + } + member_method { + name: "concatenate" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "filter" + argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "flat_map" + argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_generator" + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "from_sparse_tensor_slices" + argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_tensor_slices" + argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_tensors" + argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "interleave" + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + } + member_method { + name: "list_files" + argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "make_initializable_iterator" + argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "make_one_shot_iterator" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "map" + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "padded_batch" + argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "prefetch" + argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "range" + argspec: "args=[], varargs=args, keywords=None, defaults=None" + } + member_method { + name: "repeat" + argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "shard" + argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "shuffle" + argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "skip" + argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "take" + argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zip" + argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7ddcdce2663ca0ef6409fb3ab3c29555948d7302 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.data.TextLineDataset.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..587829a4c078e8ab945f66c64f5adad21223dfb1 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt @@ -0,0 +1,118 @@ +path: "tensorflow.data.TextLineDataset" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "output_classes" + mtype: "" + } + member { + name: "output_shapes" + mtype: "" + } + member { + name: "output_types" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filenames\', \'compression_type\', \'buffer_size\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "batch" + argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "cache" + argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], " + } + member_method { + name: "concatenate" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "filter" + argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "flat_map" + argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_generator" + argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "from_sparse_tensor_slices" + argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_tensor_slices" + argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_tensors" + argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "interleave" + argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], " + } + member_method { + name: "list_files" + argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "make_initializable_iterator" + argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "make_one_shot_iterator" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "map" + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "padded_batch" + argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "prefetch" + argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "range" + argspec: "args=[], varargs=args, keywords=None, defaults=None" + } + member_method { + name: "repeat" + argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "shard" + argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "shuffle" + argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "skip" + argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "take" + argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zip" + argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..56fb270a49943a916012ccfcaf816a9156f4fed8 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.pbtxt @@ -0,0 +1,23 @@ +path: "tensorflow.data" +tf_module { + member { + name: "Dataset" + mtype: "" + } + member { + name: "FixedLengthRecordDataset" + mtype: "" + } + member { + name: "Iterator" + mtype: "" + } + member { + name: "TFRecordDataset" + mtype: "" + } + member { + name: "TextLineDataset" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.debugging.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.debugging.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d9efe97821904f5891148b72a0c31e02c9562bd7 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.debugging.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.debugging" +tf_module { + member_method { + name: "check_numerics" + argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_finite" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_inf" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "is_nan" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distributions.-bernoulli.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-bernoulli.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ca96f4eaece0020235d24901f51306a65676c1c9 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-bernoulli.pbtxt @@ -0,0 +1,143 @@ +path: "tensorflow.distributions.Bernoulli" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "logits" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "probs" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'logits\', \'probs\', \'dtype\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"\", \'False\', \'True\', \'Bernoulli\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "cross_entropy" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'cross_entropy\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "kl_divergence" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'kl_divergence\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distributions.-beta.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-beta.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d0508acd9f4f6c190b205301223599cf5b027955 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-beta.pbtxt @@ -0,0 +1,147 @@ +path: "tensorflow.distributions.Beta" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "concentration0" + mtype: "" + } + member { + name: "concentration1" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "total_concentration" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'concentration1\', \'concentration0\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'True\', \'Beta\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "cross_entropy" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'cross_entropy\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "kl_divergence" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'kl_divergence\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distributions.-categorical.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-categorical.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ff0fbb56cd4b9e4c288a168a7c3d9e83c552b0e2 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-categorical.pbtxt @@ -0,0 +1,147 @@ +path: "tensorflow.distributions.Categorical" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "event_size" + mtype: "" + } + member { + name: "logits" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "probs" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'logits\', \'probs\', \'dtype\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"\", \'False\', \'True\', \'Categorical\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "cross_entropy" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'cross_entropy\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "kl_divergence" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'kl_divergence\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distributions.-dirichlet-multinomial.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-dirichlet-multinomial.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d75e4a2f88b29ff7f638d72f98876a230b191dce --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-dirichlet-multinomial.pbtxt @@ -0,0 +1,147 @@ +path: "tensorflow.distributions.DirichletMultinomial" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "concentration" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "total_concentration" + mtype: "" + } + member { + name: "total_count" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'total_count\', \'concentration\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'DirichletMultinomial\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "cross_entropy" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'cross_entropy\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "kl_divergence" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'kl_divergence\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distributions.-dirichlet.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-dirichlet.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b838b9ae21decba0323211f08d09fe373ababf23 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-dirichlet.pbtxt @@ -0,0 +1,143 @@ +path: "tensorflow.distributions.Dirichlet" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "concentration" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "total_concentration" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'concentration\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'Dirichlet\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "cross_entropy" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'cross_entropy\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "kl_divergence" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'kl_divergence\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distributions.-distribution.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-distribution.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6f06b7d50dd9f5f405673d572503ff549f148f33 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-distribution.pbtxt @@ -0,0 +1,134 @@ +path: "tensorflow.distributions.Distribution" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\', \'reparameterization_type\', \'validate_args\', \'allow_nan_stats\', \'parameters\', \'graph_parents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "cross_entropy" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'cross_entropy\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "kl_divergence" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'kl_divergence\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distributions.-exponential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-exponential.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d34f9cde5d4d4161883f6d1b4646f22f054d16ad --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-exponential.pbtxt @@ -0,0 +1,144 @@ +path: "tensorflow.distributions.Exponential" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "concentration" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "rate" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'rate\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'Exponential\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "cross_entropy" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'cross_entropy\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "kl_divergence" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'kl_divergence\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distributions.-gamma.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-gamma.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..df268b8d99eb6bf22264ddb63231074413686efa --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-gamma.pbtxt @@ -0,0 +1,143 @@ +path: "tensorflow.distributions.Gamma" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "concentration" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "rate" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'concentration\', \'rate\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'Gamma\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "cross_entropy" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'cross_entropy\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "kl_divergence" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'kl_divergence\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distributions.-laplace.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-laplace.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..303dcb4ed3bf8416b822bb010c2e87e8ef03b7c9 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-laplace.pbtxt @@ -0,0 +1,143 @@ +path: "tensorflow.distributions.Laplace" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "loc" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "scale" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'loc\', \'scale\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'Laplace\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "cross_entropy" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'cross_entropy\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "kl_divergence" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'kl_divergence\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distributions.-multinomial.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-multinomial.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ecda8acb15c49c390eaae203a0082e78e53499bd --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-multinomial.pbtxt @@ -0,0 +1,147 @@ +path: "tensorflow.distributions.Multinomial" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "logits" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "probs" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "total_count" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'total_count\', \'logits\', \'probs\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'True\', \'Multinomial\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "cross_entropy" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'cross_entropy\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "kl_divergence" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'kl_divergence\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distributions.-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-normal.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..92b9eeea223b488cda1ebcabd31ec808e78fcf70 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-normal.pbtxt @@ -0,0 +1,143 @@ +path: "tensorflow.distributions.Normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "loc" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "scale" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'loc\', \'scale\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'Normal\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "cross_entropy" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'cross_entropy\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "kl_divergence" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'kl_divergence\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distributions.-register-k-l.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-register-k-l.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e3db443c2bdaa70f7651126a30caf2062a3c6f67 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-register-k-l.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.distributions.RegisterKL" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dist_cls_a\', \'dist_cls_b\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distributions.-reparameterization-type.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-reparameterization-type.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..02e8d576ddd00aa21005fa39cd323a92392bf75a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-reparameterization-type.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.distributions.ReparameterizationType" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'rep_type\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distributions.-student-t.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-student-t.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9aa7f9a63465c78f79ae4a8a11bc63d92d027dab --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-student-t.pbtxt @@ -0,0 +1,147 @@ +path: "tensorflow.distributions.StudentT" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "df" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "loc" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "scale" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'df\', \'loc\', \'scale\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'StudentT\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "cross_entropy" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'cross_entropy\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "kl_divergence" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'kl_divergence\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distributions.-uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-uniform.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d1b9d3069629c552d6c6048642934f422a13dce7 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distributions.-uniform.pbtxt @@ -0,0 +1,147 @@ +path: "tensorflow.distributions.Uniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "allow_nan_stats" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "event_shape" + mtype: "" + } + member { + name: "high" + mtype: "" + } + member { + name: "low" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "parameters" + mtype: "" + } + member { + name: "reparameterization_type" + mtype: "" + } + member { + name: "validate_args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'low\', \'high\', \'validate_args\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'False\', \'True\', \'Uniform\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'cdf\'], " + } + member_method { + name: "copy" + argspec: "args=[\'self\'], varargs=None, keywords=override_parameters_kwargs, defaults=None" + } + member_method { + name: "covariance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'covariance\'], " + } + member_method { + name: "cross_entropy" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'cross_entropy\'], " + } + member_method { + name: "entropy" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'entropy\'], " + } + member_method { + name: "event_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'event_shape_tensor\'], " + } + member_method { + name: "is_scalar_batch" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_batch\'], " + } + member_method { + name: "is_scalar_event" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'is_scalar_event\'], " + } + member_method { + name: "kl_divergence" + argspec: "args=[\'self\', \'other\', \'name\'], varargs=None, keywords=None, defaults=[\'kl_divergence\'], " + } + member_method { + name: "log_cdf" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_cdf\'], " + } + member_method { + name: "log_prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_prob\'], " + } + member_method { + name: "log_survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'log_survival_function\'], " + } + member_method { + name: "mean" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mean\'], " + } + member_method { + name: "mode" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'mode\'], " + } + member_method { + name: "param_shapes" + argspec: "args=[\'cls\', \'sample_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'DistributionParamShapes\'], " + } + member_method { + name: "param_static_shapes" + argspec: "args=[\'cls\', \'sample_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "prob" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'prob\'], " + } + member_method { + name: "quantile" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'quantile\'], " + } + member_method { + name: "range" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range\'], " + } + member_method { + name: "sample" + argspec: "args=[\'self\', \'sample_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'sample\'], " + } + member_method { + name: "stddev" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'stddev\'], " + } + member_method { + name: "survival_function" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'survival_function\'], " + } + member_method { + name: "variance" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'variance\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distributions.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distributions.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..90b60ef074dd2eaf911291e6c725b98e2891e728 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distributions.pbtxt @@ -0,0 +1,75 @@ +path: "tensorflow.distributions" +tf_module { + member { + name: "Bernoulli" + mtype: "" + } + member { + name: "Beta" + mtype: "" + } + member { + name: "Categorical" + mtype: "" + } + member { + name: "Dirichlet" + mtype: "" + } + member { + name: "DirichletMultinomial" + mtype: "" + } + member { + name: "Distribution" + mtype: "" + } + member { + name: "Exponential" + mtype: "" + } + member { + name: "FULLY_REPARAMETERIZED" + mtype: "" + } + member { + name: "Gamma" + mtype: "" + } + member { + name: "Laplace" + mtype: "" + } + member { + name: "Multinomial" + mtype: "" + } + member { + name: "NOT_REPARAMETERIZED" + mtype: "" + } + member { + name: "Normal" + mtype: "" + } + member { + name: "RegisterKL" + mtype: "" + } + member { + name: "ReparameterizationType" + mtype: "" + } + member { + name: "StudentT" + mtype: "" + } + member { + name: "Uniform" + mtype: "" + } + member_method { + name: "kl_divergence" + argspec: "args=[\'distribution_a\', \'distribution_b\', \'allow_nan_stats\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.dtypes.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.dtypes.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..98e1feed002ceb4f455aa5ec361d26a159fdad1a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.dtypes.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.dtypes" +tf_module { + member_method { + name: "as_string" + argspec: "args=[\'input\', \'precision\', \'scientific\', \'shortest\', \'width\', \'fill\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'False\', \'False\', \'-1\', \'\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-aborted-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-aborted-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ea9186b0b9d5fecff35b43d2ef5dc0f2c99f3412 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-aborted-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.AbortedError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-already-exists-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-already-exists-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4e155081dd28a8a859e940338f70e9db24dff0d2 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-already-exists-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.AlreadyExistsError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-cancelled-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-cancelled-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b02a0e023aaecb5930c45aa35dbb1f0d97432cea --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-cancelled-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.CancelledError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-data-loss-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-data-loss-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c1fa66342a7022031faec68f65de9cb0ae28bcba --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-data-loss-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.DataLossError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-deadline-exceeded-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-deadline-exceeded-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8e037936191b5d52c2422f2587e7196614104d6b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-deadline-exceeded-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.DeadlineExceededError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-failed-precondition-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-failed-precondition-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..384d4b534c6ea05f9ce0fdbad32dcaf02db0ac58 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-failed-precondition-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.FailedPreconditionError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-internal-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-internal-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ac5c4d7879bbe5b040209abee088b78b15ae6f5f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-internal-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.InternalError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-invalid-argument-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-invalid-argument-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..161edd4a7c5763fe6fd96d80024065a3e3138de3 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-invalid-argument-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.InvalidArgumentError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-not-found-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-not-found-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1e64730ac6d7c0d3517a8a072b9622691a7e77d7 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-not-found-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.NotFoundError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-op-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-op-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b1f14c0457d95fd09fe485ae241ba9a9852879db --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-op-error.pbtxt @@ -0,0 +1,29 @@ +path: "tensorflow.errors.OpError" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\', \'error_code\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-out-of-range-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-out-of-range-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6365e472868607d1ca4056859d56d16d022b3128 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-out-of-range-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.OutOfRangeError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-permission-denied-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-permission-denied-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..dc8a66f9eadf3985b6805afa3adf729e7c24f3d8 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-permission-denied-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.PermissionDeniedError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-resource-exhausted-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-resource-exhausted-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..85bb384b46992c4565b14b3c13c8115fb1998abd --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-resource-exhausted-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.ResourceExhaustedError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-unauthenticated-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-unauthenticated-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d57d7ac2f20b98f464c5a67abdd926cd20de5e32 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-unauthenticated-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.UnauthenticatedError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-unavailable-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-unavailable-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..cc33e6ed8d1a9b7160b321c18735690b7b52a7d4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-unavailable-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.UnavailableError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-unimplemented-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-unimplemented-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b8c2e22dbd7e66909f4ba613ba7f19b6abbaa4b9 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-unimplemented-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.UnimplementedError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.-unknown-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.-unknown-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8ffcfae95b8c7ccea29dd5b7b75e8c74fa245f7e --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.-unknown-error.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.errors.UnknownError" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "error_code" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member { + name: "node_def" + mtype: "" + } + member { + name: "op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'node_def\', \'op\', \'message\', \'error_code\'], varargs=None, keywords=None, defaults=[\'2\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c5fe49baab7da5936184aa4b823de7d0a6dc33c5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.pbtxt @@ -0,0 +1,151 @@ +path: "tensorflow.errors" +tf_module { + member { + name: "ABORTED" + mtype: "" + } + member { + name: "ALREADY_EXISTS" + mtype: "" + } + member { + name: "AbortedError" + mtype: "" + } + member { + name: "AlreadyExistsError" + mtype: "" + } + member { + name: "CANCELLED" + mtype: "" + } + member { + name: "CancelledError" + mtype: "" + } + member { + name: "DATA_LOSS" + mtype: "" + } + member { + name: "DEADLINE_EXCEEDED" + mtype: "" + } + member { + name: "DataLossError" + mtype: "" + } + member { + name: "DeadlineExceededError" + mtype: "" + } + member { + name: "FAILED_PRECONDITION" + mtype: "" + } + member { + name: "FailedPreconditionError" + mtype: "" + } + member { + name: "INTERNAL" + mtype: "" + } + member { + name: "INVALID_ARGUMENT" + mtype: "" + } + member { + name: "InternalError" + mtype: "" + } + member { + name: "InvalidArgumentError" + mtype: "" + } + member { + name: "NOT_FOUND" + mtype: "" + } + member { + name: "NotFoundError" + mtype: "" + } + member { + name: "OK" + mtype: "" + } + member { + name: "OUT_OF_RANGE" + mtype: "" + } + member { + name: "OpError" + mtype: "" + } + member { + name: "OutOfRangeError" + mtype: "" + } + member { + name: "PERMISSION_DENIED" + mtype: "" + } + member { + name: "PermissionDeniedError" + mtype: "" + } + member { + name: "RESOURCE_EXHAUSTED" + mtype: "" + } + member { + name: "ResourceExhaustedError" + mtype: "" + } + member { + name: "UNAUTHENTICATED" + mtype: "" + } + member { + name: "UNAVAILABLE" + mtype: "" + } + member { + name: "UNIMPLEMENTED" + mtype: "" + } + member { + name: "UNKNOWN" + mtype: "" + } + member { + name: "UnauthenticatedError" + mtype: "" + } + member { + name: "UnavailableError" + mtype: "" + } + member { + name: "UnimplementedError" + mtype: "" + } + member { + name: "UnknownError" + mtype: "" + } + member { + name: "raise_exception_on_not_ok_status" + mtype: "" + } + member_method { + name: "error_code_from_exception_type" + argspec: "args=[\'cls\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "exception_type_from_error_code" + argspec: "args=[\'error_code\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.errors.raise_exception_on_not_ok_status.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.errors.raise_exception_on_not_ok_status.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5d25ec769ad7b086ec05f11f5676766380476012 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.errors.raise_exception_on_not_ok_status.pbtxt @@ -0,0 +1,8 @@ +path: "tensorflow.errors.raise_exception_on_not_ok_status" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-classifier.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..cf22e39d4c8ab915ea9507960bf28ebc09e4e5aa --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-classifier.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.estimator.BaselineClassifier" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\', \'weighted_sum\'], " + } + member_method { + name: "eval_dir" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-regressor.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a363bceae3b57d879b4b8e5a8205a21c92e8835a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-regressor.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.estimator.BaselineRegressor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\', \'weighted_sum\'], " + } + member_method { + name: "eval_dir" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-best-exporter.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-best-exporter.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9694268199a29c51f37bc73a2f92715c78854a2f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-best-exporter.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.estimator.BestExporter" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'serving_input_receiver_fn\', \'event_file_pattern\', \'compare_fn\', \'assets_extra\', \'as_text\', \'exports_to_keep\'], varargs=None, keywords=None, defaults=[\'best_exporter\', \'None\', \'eval/*.tfevents.*\', \'\', \'None\', \'False\', \'5\'], " + } + member_method { + name: "export" + argspec: "args=[\'self\', \'estimator\', \'export_path\', \'checkpoint_path\', \'eval_result\', \'is_the_final_export\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9dbb5d16a4e903a755c86bd0a6241180d1999f4d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.estimator.BoostedTreesClassifier" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\'], " + } + member_method { + name: "eval_dir" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..34a30c2874b90285706c9df6bec8cbbdc3451fe4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.estimator.BoostedTreesRegressor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\'], " + } + member_method { + name: "eval_dir" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-classifier.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0c6b7e4a821ad47c20b6f6074b575bf83c403653 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-classifier.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.estimator.DNNClassifier" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\', \'batch_norm\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\', \'False\'], " + } + member_method { + name: "eval_dir" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9c1c072124083006a1dd8e04526755dd980ba85a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.estimator.DNNLinearCombinedClassifier" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\', \'batch_norm\', \'linear_sparse_combiner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'\', \'None\', \'2\', \'None\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\', \'False\', \'sum\'], " + } + member_method { + name: "eval_dir" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7391d4b07a7e79541091b94fe4a9f38f42d6f68a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.estimator.DNNLinearCombinedRegressor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'label_dimension\', \'weight_column\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\', \'batch_norm\', \'linear_sparse_combiner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'\', \'None\', \'1\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\', \'False\', \'sum\'], " + } + member_method { + name: "eval_dir" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-regressor.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f50e375f7cd392567f5c87536c95eb1f6809bc97 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-regressor.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.estimator.DNNRegressor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\', \'batch_norm\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Adagrad\', \'\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\', \'False\'], " + } + member_method { + name: "eval_dir" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-estimator-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-estimator-spec.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..aa6ac46613fbead7457b19e1aae5f2532afddef1 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-estimator-spec.pbtxt @@ -0,0 +1,59 @@ +path: "tensorflow.estimator.EstimatorSpec" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "eval_metric_ops" + mtype: "" + } + member { + name: "evaluation_hooks" + mtype: "" + } + member { + name: "export_outputs" + mtype: "" + } + member { + name: "loss" + mtype: "" + } + member { + name: "mode" + mtype: "" + } + member { + name: "prediction_hooks" + mtype: "" + } + member { + name: "predictions" + mtype: "" + } + member { + name: "scaffold" + mtype: "" + } + member { + name: "train_op" + mtype: "" + } + member { + name: "training_chief_hooks" + mtype: "" + } + member { + name: "training_hooks" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-estimator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-estimator.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d72b5769778d2ee8e5da34c531878a6d53ef44f5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-estimator.pbtxt @@ -0,0 +1,57 @@ +path: "tensorflow.estimator.Estimator" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'model_fn\', \'model_dir\', \'config\', \'params\', \'warm_start_from\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "eval_dir" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-eval-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-eval-spec.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..db83ba1bd8f0bd13c9048d62d74790ed2b729589 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-eval-spec.pbtxt @@ -0,0 +1,43 @@ +path: "tensorflow.estimator.EvalSpec" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "exporters" + mtype: "" + } + member { + name: "hooks" + mtype: "" + } + member { + name: "input_fn" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "start_delay_secs" + mtype: "" + } + member { + name: "steps" + mtype: "" + } + member { + name: "throttle_secs" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-exporter.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-exporter.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..035af70e52024f8d16e1cd12951af10aad355eda --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-exporter.pbtxt @@ -0,0 +1,16 @@ +path: "tensorflow.estimator.Exporter" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "export" + argspec: "args=[\'self\', \'estimator\', \'export_path\', \'checkpoint_path\', \'eval_result\', \'is_the_final_export\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-final-exporter.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-final-exporter.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ee37b1fa210ea816ef762590cfd1725c71262ed8 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-final-exporter.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.estimator.FinalExporter" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "export" + argspec: "args=[\'self\', \'estimator\', \'export_path\', \'checkpoint_path\', \'eval_result\', \'is_the_final_export\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-latest-exporter.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-latest-exporter.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2a9d0290295114daa006d39f17a295a01e40da6b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-latest-exporter.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.estimator.LatestExporter" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'exports_to_keep\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'5\'], " + } + member_method { + name: "export" + argspec: "args=[\'self\', \'estimator\', \'export_path\', \'checkpoint_path\', \'eval_result\', \'is_the_final_export\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-classifier.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..154f171e89571a43a3f905094a1dbd41cbb000d3 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-classifier.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.estimator.LinearClassifier" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\', \'sparse_combiner\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\', \'sum\'], " + } + member_method { + name: "eval_dir" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-regressor.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4d46d1e6b68758bf634f9b0f82c279fdfa91a0b8 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-regressor.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.estimator.LinearRegressor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\', \'sparse_combiner\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\', \'sum\'], " + } + member_method { + name: "eval_dir" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-mode-keys.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-mode-keys.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6a1c24fa63fc074c2b4ae9b3225a6abb47958b68 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-mode-keys.pbtxt @@ -0,0 +1,20 @@ +path: "tensorflow.estimator.ModeKeys" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "EVAL" + mtype: "" + } + member { + name: "PREDICT" + mtype: "" + } + member { + name: "TRAIN" + mtype: "" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5aa4b3d4fb269785841e74c51f2121ce64377691 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-run-config.pbtxt @@ -0,0 +1,101 @@ +path: "tensorflow.estimator.RunConfig" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "cluster_spec" + mtype: "" + } + member { + name: "device_fn" + mtype: "" + } + member { + name: "evaluation_master" + mtype: "" + } + member { + name: "global_id_in_cluster" + mtype: "" + } + member { + name: "is_chief" + mtype: "" + } + member { + name: "keep_checkpoint_every_n_hours" + mtype: "" + } + member { + name: "keep_checkpoint_max" + mtype: "" + } + member { + name: "log_step_count_steps" + mtype: "" + } + member { + name: "master" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "num_ps_replicas" + mtype: "" + } + member { + name: "num_worker_replicas" + mtype: "" + } + member { + name: "protocol" + mtype: "" + } + member { + name: "save_checkpoints_secs" + mtype: "" + } + member { + name: "save_checkpoints_steps" + mtype: "" + } + member { + name: "save_summary_steps" + mtype: "" + } + member { + name: "service" + mtype: "" + } + member { + name: "session_config" + mtype: "" + } + member { + name: "task_id" + mtype: "" + } + member { + name: "task_type" + mtype: "" + } + member { + name: "tf_random_seed" + mtype: "" + } + member { + name: "train_distribute" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'\', \'\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "replace" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-train-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-train-spec.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7d2f77438afa41f2d8391524470f82a22076313b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-train-spec.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.estimator.TrainSpec" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "hooks" + mtype: "" + } + member { + name: "input_fn" + mtype: "" + } + member { + name: "max_steps" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5301b94eb361251a1cb4d02a5d8168f7c8191045 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt @@ -0,0 +1,39 @@ +path: "tensorflow.estimator.VocabInfo" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "backup_initializer" + mtype: "" + } + member { + name: "new_vocab" + mtype: "" + } + member { + name: "new_vocab_size" + mtype: "" + } + member { + name: "num_oov_buckets" + mtype: "" + } + member { + name: "old_vocab" + mtype: "" + } + member { + name: "old_vocab_size" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-warm-start-settings.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-warm-start-settings.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..43f5343359aff3b856a2b3708e4cda7cec29e146 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-warm-start-settings.pbtxt @@ -0,0 +1,31 @@ +path: "tensorflow.estimator.WarmStartSettings" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "ckpt_to_initialize_from" + mtype: "" + } + member { + name: "var_name_to_prev_var_name" + mtype: "" + } + member { + name: "var_name_to_vocab_info" + mtype: "" + } + member { + name: "vars_to_warm_start" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-classification-output.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-classification-output.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3cf7af8da95479cf49469b2f328db0919fd5ce95 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-classification-output.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.estimator.export.ClassificationOutput.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-classification-output.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-classification-output.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2df1840c4a4f03fc08ba535b4f6557d49608fa5f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-classification-output.pbtxt @@ -0,0 +1,22 @@ +path: "tensorflow.estimator.export.ClassificationOutput" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "classes" + mtype: "" + } + member { + name: "scores" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'scores\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "as_signature_def" + argspec: "args=[\'self\', \'receiver_tensors\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-export-output.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-export-output.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5d165ccbf91865e48f40f88ff817bff03881a03b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-export-output.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.estimator.export.ExportOutput.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-export-output.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-export-output.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..fa62e8ced801d66951ef5a62ec4fdd9795226ebd --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-export-output.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.estimator.export.ExportOutput" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "as_signature_def" + argspec: "args=[\'self\', \'receiver_tensors\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-predict-output.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-predict-output.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..743495ba98cf4db0abeba86e26b812d9e3c8695b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-predict-output.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.estimator.export.PredictOutput.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-predict-output.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-predict-output.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e0160b10ce13a0b3499143d151ee7e58ad858fb2 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-predict-output.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.estimator.export.PredictOutput" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "outputs" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'outputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_signature_def" + argspec: "args=[\'self\', \'receiver_tensors\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-regression-output.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-regression-output.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..dbf4e3dec85d7d00045bfe4e7086ba23edf61a84 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-regression-output.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.estimator.export.RegressionOutput.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-regression-output.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-regression-output.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..905f0e055350fe9a7d5790e531fb2b089332f279 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-regression-output.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.estimator.export.RegressionOutput" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "value" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_signature_def" + argspec: "args=[\'self\', \'receiver_tensors\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-serving-input-receiver.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-serving-input-receiver.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d71b2a430065740c376f8e90e3244d105ac2101f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-serving-input-receiver.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.estimator.export.ServingInputReceiver" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "features" + mtype: "" + } + member { + name: "receiver_tensors" + mtype: "" + } + member { + name: "receiver_tensors_alternatives" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-tensor-serving-input-receiver.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-tensor-serving-input-receiver.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4fe92643bf9867765499d7bf475b9cdd1686aec5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.-tensor-serving-input-receiver.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.estimator.export.TensorServingInputReceiver" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "features" + mtype: "" + } + member { + name: "receiver_tensors" + mtype: "" + } + member { + name: "receiver_tensors_alternatives" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..bd72f6cd79f7dffb9f0a7f8ae43751c4ecba939d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.export.pbtxt @@ -0,0 +1,35 @@ +path: "tensorflow.estimator.export" +tf_module { + member { + name: "ClassificationOutput" + mtype: "" + } + member { + name: "ExportOutput" + mtype: "" + } + member { + name: "PredictOutput" + mtype: "" + } + member { + name: "RegressionOutput" + mtype: "" + } + member { + name: "ServingInputReceiver" + mtype: "" + } + member { + name: "TensorServingInputReceiver" + mtype: "" + } + member_method { + name: "build_parsing_serving_input_receiver_fn" + argspec: "args=[\'feature_spec\', \'default_batch_size\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "build_raw_serving_input_receiver_fn" + argspec: "args=[\'features\', \'default_batch_size\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.inputs.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.inputs.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b318fea1f82077c3924a843dd6b3857a3fdc0e8e --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.inputs.pbtxt @@ -0,0 +1,11 @@ +path: "tensorflow.estimator.inputs" +tf_module { + member_method { + name: "numpy_input_fn" + argspec: "args=[\'x\', \'y\', \'batch_size\', \'num_epochs\', \'shuffle\', \'queue_capacity\', \'num_threads\'], varargs=None, keywords=None, defaults=[\'None\', \'128\', \'1\', \'None\', \'1000\', \'1\'], " + } + member_method { + name: "pandas_input_fn" + argspec: "args=[\'x\', \'y\', \'batch_size\', \'num_epochs\', \'shuffle\', \'queue_capacity\', \'num_threads\', \'target_column\'], varargs=None, keywords=None, defaults=[\'None\', \'128\', \'1\', \'None\', \'1000\', \'1\', \'target\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f1d204a3ef96f35e31f642bcb0a61351b263d273 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.pbtxt @@ -0,0 +1,111 @@ +path: "tensorflow.estimator" +tf_module { + member { + name: "BaselineClassifier" + mtype: "" + } + member { + name: "BaselineRegressor" + mtype: "" + } + member { + name: "BestExporter" + mtype: "" + } + member { + name: "BoostedTreesClassifier" + mtype: "" + } + member { + name: "BoostedTreesRegressor" + mtype: "" + } + member { + name: "DNNClassifier" + mtype: "" + } + member { + name: "DNNLinearCombinedClassifier" + mtype: "" + } + member { + name: "DNNLinearCombinedRegressor" + mtype: "" + } + member { + name: "DNNRegressor" + mtype: "" + } + member { + name: "Estimator" + mtype: "" + } + member { + name: "EstimatorSpec" + mtype: "" + } + member { + name: "EvalSpec" + mtype: "" + } + member { + name: "Exporter" + mtype: "" + } + member { + name: "FinalExporter" + mtype: "" + } + member { + name: "LatestExporter" + mtype: "" + } + member { + name: "LinearClassifier" + mtype: "" + } + member { + name: "LinearRegressor" + mtype: "" + } + member { + name: "ModeKeys" + mtype: "" + } + member { + name: "RunConfig" + mtype: "" + } + member { + name: "TrainSpec" + mtype: "" + } + member { + name: "VocabInfo" + mtype: "" + } + member { + name: "WarmStartSettings" + mtype: "" + } + member { + name: "export" + mtype: "" + } + member { + name: "inputs" + mtype: "" + } + member_method { + name: "classifier_parse_example_spec" + argspec: "args=[\'feature_columns\', \'label_key\', \'label_dtype\', \'label_default\', \'weight_column\'], varargs=None, keywords=None, defaults=[\"\", \'None\', \'None\'], " + } + member_method { + name: "regressor_parse_example_spec" + argspec: "args=[\'feature_columns\', \'label_key\', \'label_dtype\', \'label_default\', \'label_dimension\', \'weight_column\'], varargs=None, keywords=None, defaults=[\"\", \'None\', \'1\', \'None\'], " + } + member_method { + name: "train_and_evaluate" + argspec: "args=[\'estimator\', \'train_spec\', \'eval_spec\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..24a58fb118bf52e650e1df71e9374099745ade52 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt @@ -0,0 +1,59 @@ +path: "tensorflow.feature_column" +tf_module { + member_method { + name: "bucketized_column" + argspec: "args=[\'source_column\', \'boundaries\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "categorical_column_with_hash_bucket" + argspec: "args=[\'key\', \'hash_bucket_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " + } + member_method { + name: "categorical_column_with_identity" + argspec: "args=[\'key\', \'num_buckets\', \'default_value\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "categorical_column_with_vocabulary_file" + argspec: "args=[\'key\', \'vocabulary_file\', \'vocabulary_size\', \'num_oov_buckets\', \'default_value\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \"\"], " + } + member_method { + name: "categorical_column_with_vocabulary_list" + argspec: "args=[\'key\', \'vocabulary_list\', \'dtype\', \'default_value\', \'num_oov_buckets\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\', \'0\'], " + } + member_method { + name: "crossed_column" + argspec: "args=[\'keys\', \'hash_bucket_size\', \'hash_key\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "embedding_column" + argspec: "args=[\'categorical_column\', \'dimension\', \'combiner\', \'initializer\', \'ckpt_to_load_from\', \'tensor_name_in_ckpt\', \'max_norm\', \'trainable\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\', \'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "indicator_column" + argspec: "args=[\'categorical_column\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "input_layer" + argspec: "args=[\'features\', \'feature_columns\', \'weight_collections\', \'trainable\', \'cols_to_vars\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], " + } + member_method { + name: "linear_model" + argspec: "args=[\'features\', \'feature_columns\', \'units\', \'sparse_combiner\', \'weight_collections\', \'trainable\', \'cols_to_vars\'], varargs=None, keywords=None, defaults=[\'1\', \'sum\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "make_parse_example_spec" + argspec: "args=[\'feature_columns\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "numeric_column" + argspec: "args=[\'key\', \'shape\', \'default_value\', \'dtype\', \'normalizer_fn\'], varargs=None, keywords=None, defaults=[\'(1,)\', \'None\', \"\", \'None\'], " + } + member_method { + name: "shared_embedding_columns" + argspec: "args=[\'categorical_columns\', \'dimension\', \'combiner\', \'initializer\', \'shared_embedding_collection_name\', \'ckpt_to_load_from\', \'tensor_name_in_ckpt\', \'max_norm\', \'trainable\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "weighted_categorical_column" + argspec: "args=[\'categorical_column\', \'weight_feature_key\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.gfile.-fast-g-file.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.gfile.-fast-g-file.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..eecfaffd0a6f6e611eba8bf3f5bb709bc9e0157f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.gfile.-fast-g-file.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.gfile.FastGFile" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "mode" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'mode\'], varargs=None, keywords=None, defaults=[\'r\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "flush" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "next" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "read" + argspec: "args=[\'self\', \'n\'], varargs=None, keywords=None, defaults=[\'-1\'], " + } + member_method { + name: "readline" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "readlines" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "seek" + argspec: "args=[\'self\', \'offset\', \'whence\', \'position\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + } + member_method { + name: "size" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "tell" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "write" + argspec: "args=[\'self\', \'file_content\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.gfile.-g-file.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.gfile.-g-file.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..305251059d90b52aa2e76e99a4ec65e68b73fb79 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.gfile.-g-file.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.gfile.GFile" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "mode" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'mode\'], varargs=None, keywords=None, defaults=[\'r\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "flush" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "next" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "read" + argspec: "args=[\'self\', \'n\'], varargs=None, keywords=None, defaults=[\'-1\'], " + } + member_method { + name: "readline" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "readlines" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "seek" + argspec: "args=[\'self\', \'offset\', \'whence\', \'position\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + } + member_method { + name: "size" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "tell" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "write" + argspec: "args=[\'self\', \'file_content\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.gfile.-open.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.gfile.-open.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6e8894180a4a685d5a35ba02df53c6e054db01b9 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.gfile.-open.pbtxt @@ -0,0 +1,58 @@ +path: "tensorflow.gfile.Open" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "mode" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'mode\'], varargs=None, keywords=None, defaults=[\'r\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "flush" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "next" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "read" + argspec: "args=[\'self\', \'n\'], varargs=None, keywords=None, defaults=[\'-1\'], " + } + member_method { + name: "readline" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "readlines" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "seek" + argspec: "args=[\'self\', \'offset\', \'whence\', \'position\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + } + member_method { + name: "size" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "tell" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "write" + argspec: "args=[\'self\', \'file_content\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.gfile.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.gfile.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..65b55a8b7c4e30e349c1ea256664002b19191c82 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.gfile.pbtxt @@ -0,0 +1,63 @@ +path: "tensorflow.gfile" +tf_module { + member { + name: "FastGFile" + mtype: "" + } + member { + name: "GFile" + mtype: "" + } + member { + name: "Open" + mtype: "" + } + member_method { + name: "Copy" + argspec: "args=[\'oldpath\', \'newpath\', \'overwrite\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "DeleteRecursively" + argspec: "args=[\'dirname\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "Exists" + argspec: "args=[\'filename\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "Glob" + argspec: "args=[\'filename\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "IsDirectory" + argspec: "args=[\'dirname\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "ListDirectory" + argspec: "args=[\'dirname\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MakeDirs" + argspec: "args=[\'dirname\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MkDir" + argspec: "args=[\'dirname\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "Remove" + argspec: "args=[\'filename\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "Rename" + argspec: "args=[\'oldname\', \'newname\', \'overwrite\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "Stat" + argspec: "args=[\'filename\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "Walk" + argspec: "args=[\'top\', \'in_order\'], varargs=None, keywords=None, defaults=[\'True\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.graph_util.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.graph_util.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..eeabf845dca94eea3ab4e54ee6ba3ba33c8995a5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.graph_util.pbtxt @@ -0,0 +1,23 @@ +path: "tensorflow.graph_util" +tf_module { + member_method { + name: "convert_variables_to_constants" + argspec: "args=[\'sess\', \'input_graph_def\', \'output_node_names\', \'variable_names_whitelist\', \'variable_names_blacklist\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "extract_sub_graph" + argspec: "args=[\'graph_def\', \'dest_nodes\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "must_run_on_cpu" + argspec: "args=[\'node\', \'pin_variables_on_cpu\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "remove_training_nodes" + argspec: "args=[\'input_graph\', \'protected_nodes\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tensor_shape_from_node_def_name" + argspec: "args=[\'graph\', \'input_name\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.image.-resize-method.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.image.-resize-method.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..dbc360b13ee7dc8228f5fb4fe0cd6fc21504d0d0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.image.-resize-method.pbtxt @@ -0,0 +1,24 @@ +path: "tensorflow.image.ResizeMethod" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "AREA" + mtype: "" + } + member { + name: "BICUBIC" + mtype: "" + } + member { + name: "BILINEAR" + mtype: "" + } + member { + name: "NEAREST_NEIGHBOR" + mtype: "" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6ec3aba77586a9ffffd1e4375bf58394a118ea82 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt @@ -0,0 +1,247 @@ +path: "tensorflow.image" +tf_module { + member { + name: "ResizeMethod" + mtype: "" + } + member_method { + name: "adjust_brightness" + argspec: "args=[\'image\', \'delta\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "adjust_contrast" + argspec: "args=[\'images\', \'contrast_factor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "adjust_gamma" + argspec: "args=[\'image\', \'gamma\', \'gain\'], varargs=None, keywords=None, defaults=[\'1\', \'1\'], " + } + member_method { + name: "adjust_hue" + argspec: "args=[\'image\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "adjust_jpeg_quality" + argspec: "args=[\'image\', \'jpeg_quality\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "adjust_saturation" + argspec: "args=[\'image\', \'saturation_factor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "central_crop" + argspec: "args=[\'image\', \'central_fraction\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "convert_image_dtype" + argspec: "args=[\'image\', \'dtype\', \'saturate\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "crop_and_resize" + argspec: "args=[\'image\', \'boxes\', \'box_ind\', \'crop_size\', \'method\', \'extrapolation_value\', \'name\'], varargs=None, keywords=None, defaults=[\'bilinear\', \'0\', \'None\'], " + } + member_method { + name: "crop_to_bounding_box" + argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "decode_and_crop_jpeg" + argspec: "args=[\'contents\', \'crop_window\', \'channels\', \'ratio\', \'fancy_upscaling\', \'try_recover_truncated\', \'acceptable_fraction\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \'True\', \'False\', \'1\', \'\', \'None\'], " + } + member_method { + name: "decode_bmp" + argspec: "args=[\'contents\', \'channels\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], " + } + member_method { + name: "decode_gif" + argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "decode_image" + argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"\", \'None\'], " + } + member_method { + name: "decode_jpeg" + argspec: "args=[\'contents\', \'channels\', \'ratio\', \'fancy_upscaling\', \'try_recover_truncated\', \'acceptable_fraction\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \'True\', \'False\', \'1\', \'\', \'None\'], " + } + member_method { + name: "decode_png" + argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"\", \'None\'], " + } + member_method { + name: "draw_bounding_boxes" + argspec: "args=[\'images\', \'boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "encode_jpeg" + argspec: "args=[\'image\', \'format\', \'quality\', \'progressive\', \'optimize_size\', \'chroma_downsampling\', \'density_unit\', \'x_density\', \'y_density\', \'xmp_metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'95\', \'False\', \'False\', \'True\', \'in\', \'300\', \'300\', \'\', \'None\'], " + } + member_method { + name: "encode_png" + argspec: "args=[\'image\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " + } + member_method { + name: "extract_glimpse" + argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'None\'], " + } + member_method { + name: "extract_image_patches" + argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "extract_jpeg_shape" + argspec: "args=[\'contents\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } + member_method { + name: "flip_left_right" + argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "flip_up_down" + argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "grayscale_to_rgb" + argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "hsv_to_rgb" + argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "image_gradients" + argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_jpeg" + argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "non_max_suppression" + argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], " + } + member_method { + name: "non_max_suppression_overlaps" + argspec: "args=[\'overlaps\', \'scores\', \'max_output_size\', \'overlap_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], " + } + member_method { + name: "pad_to_bounding_box" + argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "per_image_standardization" + argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "psnr" + argspec: "args=[\'a\', \'b\', \'max_val\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_brightness" + argspec: "args=[\'image\', \'max_delta\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_contrast" + argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_flip_left_right" + argspec: "args=[\'image\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_flip_up_down" + argspec: "args=[\'image\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_hue" + argspec: "args=[\'image\', \'max_delta\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_jpeg_quality" + argspec: "args=[\'image\', \'min_jpeg_quality\', \'max_jpeg_quality\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "random_saturation" + argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "resize_area" + argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "resize_bicubic" + argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "resize_bilinear" + argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "resize_image_with_crop_or_pad" + argspec: "args=[\'image\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "resize_image_with_pad" + argspec: "args=[\'image\', \'target_height\', \'target_width\', \'method\'], varargs=None, keywords=None, defaults=[\'0\'], " + } + member_method { + name: "resize_images" + argspec: "args=[\'images\', \'size\', \'method\', \'align_corners\', \'preserve_aspect_ratio\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\'], " + } + member_method { + name: "resize_nearest_neighbor" + argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "rgb_to_grayscale" + argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "rgb_to_hsv" + argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "rgb_to_yiq" + argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "rgb_to_yuv" + argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "rot90" + argspec: "args=[\'image\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " + } + member_method { + name: "sample_distorted_bounding_box" + argspec: "args=[\'image_size\', \'bounding_boxes\', \'seed\', \'seed2\', \'min_object_covered\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.1\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "sobel_edges" + argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "ssim" + argspec: "args=[\'img1\', \'img2\', \'max_val\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "ssim_multiscale" + argspec: "args=[\'img1\', \'img2\', \'max_val\', \'power_factors\'], varargs=None, keywords=None, defaults=[\'(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)\'], " + } + member_method { + name: "total_variation" + argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "transpose_image" + argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "yiq_to_rgb" + argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "yuv_to_rgb" + argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.constant.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.constant.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..607a5aae21ff7299fc96aee3b932c10d622f1127 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.constant.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.initializers.constant" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'value\', \'dtype\', \'verify_shape\'], varargs=None, keywords=None, defaults=[\'0\', \"\", \'False\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.identity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.identity.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..37fcab95997bb7299675a387d08184fc1387eee1 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.identity.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.initializers.identity" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'gain\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.ones.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.ones.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..18481d48150d2dcf7d6908ab1914ab217da93c10 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.ones.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.initializers.ones" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.orthogonal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.orthogonal.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ff64efd60cf1197bb9032912eb5cba48a63609a0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.orthogonal.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.initializers.orthogonal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'gain\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..bc0426f2f145763552dcb46fb6c2efc7c42b974f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt @@ -0,0 +1,79 @@ +path: "tensorflow.initializers" +tf_module { + member { + name: "constant" + mtype: "" + } + member { + name: "identity" + mtype: "" + } + member { + name: "ones" + mtype: "" + } + member { + name: "orthogonal" + mtype: "" + } + member { + name: "random_normal" + mtype: "" + } + member { + name: "random_uniform" + mtype: "" + } + member { + name: "truncated_normal" + mtype: "" + } + member { + name: "uniform_unit_scaling" + mtype: "" + } + member { + name: "variance_scaling" + mtype: "" + } + member { + name: "zeros" + mtype: "" + } + member_method { + name: "global_variables" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "glorot_normal" + argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " + } + member_method { + name: "glorot_uniform" + argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " + } + member_method { + name: "he_normal" + argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "he_uniform" + argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "lecun_normal" + argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "lecun_uniform" + argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "local_variables" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "variables" + argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'init\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.random_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.random_normal.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..133e61c1d9869bdd00948df3877be990b30b7cc3 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.random_normal.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.initializers.random_normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.random_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.random_uniform.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0cfa0080f5a936bc80f69c2b5c15f671096ba350 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.random_uniform.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.initializers.random_uniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.truncated_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.truncated_normal.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..730390fba274f9dc25eea7a53bb8145a2ade8613 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.truncated_normal.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.initializers.truncated_normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.uniform_unit_scaling.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.uniform_unit_scaling.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..13295ef375a4002f8fece5ebb5d2a5d5d26c68eb --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.uniform_unit_scaling.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.initializers.uniform_unit_scaling" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'factor\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.variance_scaling.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.variance_scaling.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..86340913e2506c96499aae05a3ed0d5273c93bba --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.variance_scaling.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.initializers.variance_scaling" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'truncated_normal\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.zeros.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.zeros.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7df4237bb6537b39f42f7b3894beb1bec6641f6f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.zeros.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.initializers.zeros" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3a36c168aa703721421b662185fc852fa3d6a3ec --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt @@ -0,0 +1,39 @@ +path: "tensorflow.io" +tf_module { + member_method { + name: "decode_base64" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "decode_compressed" + argspec: "args=[\'bytes\', \'compression_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], " + } + member_method { + name: "decode_json_example" + argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "decode_raw" + argspec: "args=[\'bytes\', \'out_type\', \'little_endian\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "encode_base64" + argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "matching_files" + argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "parse_tensor" + argspec: "args=[\'serialized\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_file" + argspec: "args=[\'filename\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "write_file" + argspec: "args=[\'filename\', \'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..40e82b18b68f9e8353dcb04f76ebb36446d3ab3f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt @@ -0,0 +1,268 @@ +path: "tensorflow.keras.Model" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "layers" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "state_updates" + mtype: "" + } + member { + name: "stateful" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "uses_learning_phase" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compile" + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " + } + member_method { + name: "evaluate_generator" + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + } + member_method { + name: "fit" + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " + } + member_method { + name: "fit_generator" + argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_layer" + argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_weights" + argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + } + member_method { + name: "predict_generator" + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + } + member_method { + name: "predict_on_batch" + argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_states" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "save" + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], " + } + member_method { + name: "save_weights" + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "summary" + argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "test_on_batch" + argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "to_json" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "to_yaml" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "train_on_batch" + argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..65cfad77d1f3cdf682b6681fbebc950e6c1ca8a8 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt @@ -0,0 +1,289 @@ +path: "tensorflow.keras.Sequential" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "layers" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "state_updates" + mtype: "" + } + member { + name: "stateful" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "uses_learning_phase" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'layers\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "add" + argspec: "args=[\'self\', \'layer\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compile" + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " + } + member_method { + name: "evaluate_generator" + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + } + member_method { + name: "fit" + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " + } + member_method { + name: "fit_generator" + argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_layer" + argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_weights" + argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "pop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + } + member_method { + name: "predict_classes" + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], " + } + member_method { + name: "predict_generator" + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + } + member_method { + name: "predict_on_batch" + argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict_proba" + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], " + } + member_method { + name: "reset_states" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "save" + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], " + } + member_method { + name: "save_weights" + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "summary" + argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "symbolic_set_inputs" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "test_on_batch" + argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "to_json" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "to_yaml" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "train_on_batch" + argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2cd83baf65cf4114e58f52cdc40de7e4b6df7554 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt @@ -0,0 +1,55 @@ +path: "tensorflow.keras.activations" +tf_module { + member_method { + name: "deserialize" + argspec: "args=[\'name\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "elu" + argspec: "args=[\'x\', \'alpha\'], varargs=None, keywords=None, defaults=[\'1.0\'], " + } + member_method { + name: "get" + argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "hard_sigmoid" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "linear" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "relu" + argspec: "args=[\'x\', \'alpha\', \'max_value\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\'], " + } + member_method { + name: "selu" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "serialize" + argspec: "args=[\'activation\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "sigmoid" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "softmax" + argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\'], " + } + member_method { + name: "softplus" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "softsign" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "tanh" + argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.densenet.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.densenet.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.applications.densenet.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.applications.densenet.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.inception_resnet_v2.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.inception_resnet_v2.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.applications.inception_resnet_v2.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.applications.inception_resnet_v2.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.inception_v3.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.inception_v3.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.applications.inception_v3.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.applications.inception_v3.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.mobilenet.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.mobilenet.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.applications.mobilenet.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.applications.mobilenet.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.nasnet.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.nasnet.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.applications.nasnet.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.applications.nasnet.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.applications.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.resnet50.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.resnet50.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.applications.resnet50.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.applications.resnet50.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.vgg16.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.vgg16.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.applications.vgg16.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.applications.vgg16.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.vgg19.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.vgg19.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.applications.vgg19.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.applications.vgg19.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.xception.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.applications.xception.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.applications.xception.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.applications.xception.pbtxt diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.name_scope.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.name_scope.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a2b98b1c27c2268326af2653177b38e25f838c8d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.name_scope.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.keras.backend.name_scope" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'default_name\', \'values\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt similarity index 98% rename from tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt index c6149e8aa7e3650e628e37b0e00a54348012475b..fddac63b7817102cfc7e46d132d2871d8726c358 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt @@ -70,7 +70,7 @@ tf_module { } member_method { name: "categorical_crossentropy" - argspec: "args=[\'target\', \'output\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'target\', \'output\', \'from_logits\', \'axis\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\'], " } member_method { name: "clear_session" @@ -462,7 +462,7 @@ tf_module { } member_method { name: "sparse_categorical_crossentropy" - argspec: "args=[\'target\', \'output\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'False\'], " + argspec: "args=[\'target\', \'output\', \'from_logits\', \'axis\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\'], " } member_method { name: "spatial_2d_padding" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-base-logger.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-base-logger.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9eee9b378964a9947b067b7ec495ef6556ab6d0c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-base-logger.pbtxt @@ -0,0 +1,42 @@ +path: "tensorflow.keras.callbacks.BaseLogger" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'stateful_metrics\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_begin" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_model" + argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_params" + argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-c-s-v-logger.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-c-s-v-logger.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5bb949c5bb650acee91b14a4d6bf95b36029edf7 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-c-s-v-logger.pbtxt @@ -0,0 +1,42 @@ +path: "tensorflow.keras.callbacks.CSVLogger" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filename\', \'separator\', \'append\'], varargs=None, keywords=None, defaults=[\',\', \'False\'], " + } + member_method { + name: "on_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_begin" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_model" + argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_params" + argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-callback.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-callback.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a5340d52c1af6d69da30fd710bcee9d832917574 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-callback.pbtxt @@ -0,0 +1,41 @@ +path: "tensorflow.keras.callbacks.Callback" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "on_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_begin" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_model" + argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_params" + argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-early-stopping.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-early-stopping.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f71292856cd29b2e52194bec8a586686fbfad667 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-early-stopping.pbtxt @@ -0,0 +1,42 @@ +path: "tensorflow.keras.callbacks.EarlyStopping" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'monitor\', \'min_delta\', \'patience\', \'verbose\', \'mode\', \'baseline\'], varargs=None, keywords=None, defaults=[\'val_loss\', \'0\', \'0\', \'0\', \'auto\', \'None\'], " + } + member_method { + name: "on_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_begin" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_model" + argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_params" + argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-history.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-history.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ee400b31c43829efba156298d5ee807cdafc8a98 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-history.pbtxt @@ -0,0 +1,42 @@ +path: "tensorflow.keras.callbacks.History" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "on_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_begin" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_model" + argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_params" + argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-lambda-callback.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-lambda-callback.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..df8d7b0ef7afca17338a26388c38827b5b306f95 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-lambda-callback.pbtxt @@ -0,0 +1,42 @@ +path: "tensorflow.keras.callbacks.LambdaCallback" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'on_epoch_begin\', \'on_epoch_end\', \'on_batch_begin\', \'on_batch_end\', \'on_train_begin\', \'on_train_end\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "on_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_begin" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_model" + argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_params" + argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ce1a9b694d8708720e0eb677afd25607c6262e9c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt @@ -0,0 +1,42 @@ +path: "tensorflow.keras.callbacks.LearningRateScheduler" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'schedule\', \'verbose\'], varargs=None, keywords=None, defaults=[\'0\'], " + } + member_method { + name: "on_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_begin" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_model" + argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_params" + argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-model-checkpoint.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-model-checkpoint.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..48bb24a05274addca03f11acef99607f78b92e51 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-model-checkpoint.pbtxt @@ -0,0 +1,42 @@ +path: "tensorflow.keras.callbacks.ModelCheckpoint" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'period\'], varargs=None, keywords=None, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'1\'], " + } + member_method { + name: "on_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_begin" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_model" + argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_params" + argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-progbar-logger.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-progbar-logger.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d8bb8b2a7d0f491c7ec2b30096a1acaf04681a56 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-progbar-logger.pbtxt @@ -0,0 +1,42 @@ +path: "tensorflow.keras.callbacks.ProgbarLogger" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'count_mode\', \'stateful_metrics\'], varargs=None, keywords=None, defaults=[\'samples\', \'None\'], " + } + member_method { + name: "on_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_begin" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_model" + argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_params" + argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..dc27af9552a88650261b4f0694ea0265e6bda05c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.keras.callbacks.ReduceLROnPlateau" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'monitor\', \'factor\', \'patience\', \'verbose\', \'mode\', \'min_delta\', \'cooldown\', \'min_lr\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0.1\', \'10\', \'0\', \'auto\', \'0.0001\', \'0\', \'0\'], " + } + member_method { + name: "in_cooldown" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "on_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_begin" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_model" + argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_params" + argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-remote-monitor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-remote-monitor.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5a3b791c0adc0d61129d38b2995ee9077cf0988b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-remote-monitor.pbtxt @@ -0,0 +1,42 @@ +path: "tensorflow.keras.callbacks.RemoteMonitor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'root\', \'path\', \'field\', \'headers\', \'send_as_json\'], varargs=None, keywords=None, defaults=[\'http://localhost:9000\', \'/publish/epoch/end/\', \'data\', \'None\', \'False\'], " + } + member_method { + name: "on_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_begin" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_model" + argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_params" + argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e58ba18c1c0d06df3a53d93ae18f5bf0931df329 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt @@ -0,0 +1,42 @@ +path: "tensorflow.keras.callbacks.TensorBoard" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'log_dir\', \'histogram_freq\', \'batch_size\', \'write_graph\', \'write_grads\', \'write_images\', \'embeddings_freq\', \'embeddings_layer_names\', \'embeddings_metadata\', \'embeddings_data\'], varargs=None, keywords=None, defaults=[\'./logs\', \'0\', \'32\', \'True\', \'False\', \'False\', \'0\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "on_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_begin" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_model" + argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_params" + argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-terminate-on-na-n.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-terminate-on-na-n.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5c2d336353aee7fc98b45620adac4f4bcda05ea0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-terminate-on-na-n.pbtxt @@ -0,0 +1,42 @@ +path: "tensorflow.keras.callbacks.TerminateOnNaN" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "on_batch_begin" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_batch_end" + argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_begin" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_begin" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "on_train_end" + argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_model" + argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_params" + argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1e9085e034ccf22fda7be7565aabb86992a8b0b7 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.pbtxt @@ -0,0 +1,55 @@ +path: "tensorflow.keras.callbacks" +tf_module { + member { + name: "BaseLogger" + mtype: "" + } + member { + name: "CSVLogger" + mtype: "" + } + member { + name: "Callback" + mtype: "" + } + member { + name: "EarlyStopping" + mtype: "" + } + member { + name: "History" + mtype: "" + } + member { + name: "LambdaCallback" + mtype: "" + } + member { + name: "LearningRateScheduler" + mtype: "" + } + member { + name: "ModelCheckpoint" + mtype: "" + } + member { + name: "ProgbarLogger" + mtype: "" + } + member { + name: "ReduceLROnPlateau" + mtype: "" + } + member { + name: "RemoteMonitor" + mtype: "" + } + member { + name: "TensorBoard" + mtype: "" + } + member { + name: "TerminateOnNaN" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-constraint.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-constraint.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8e07b7d98e1d832628f65bed19eddca76bfbd51a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-constraint.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.keras.constraints.Constraint" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-max-norm.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-max-norm.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2b81174b6cd4d57d8d6e20da7f6961442045d908 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-max-norm.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.keras.constraints.MaxNorm" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'max_value\', \'axis\'], varargs=None, keywords=None, defaults=[\'2\', \'0\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-min-max-norm.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-min-max-norm.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a41eda86ac2583b1adfe745f713ac8f8647f7a31 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-min-max-norm.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.keras.constraints.MinMaxNorm" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'min_value\', \'max_value\', \'rate\', \'axis\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'1.0\', \'0\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-non-neg.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-non-neg.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..572e3eea4d985999f513a066b348d088ab01fe54 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-non-neg.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.keras.constraints.NonNeg" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-unit-norm.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-unit-norm.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..fe16c38cc83fb9979ecf0d08ab2cba7a2c38f9b6 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.-unit-norm.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.keras.constraints.UnitNorm" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'axis\'], varargs=None, keywords=None, defaults=[\'0\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.max_norm.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.max_norm.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6650bae07a0d32448e748598af3426f85ca8e199 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.max_norm.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.keras.constraints.max_norm" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'max_value\', \'axis\'], varargs=None, keywords=None, defaults=[\'2\', \'0\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.min_max_norm.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.min_max_norm.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9dd3bc92fc4fadee863f30b300ddb60fe0b3d340 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.min_max_norm.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.keras.constraints.min_max_norm" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'min_value\', \'max_value\', \'rate\', \'axis\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'1.0\', \'0\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.non_neg.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.non_neg.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a565840939f99080b784e4e95302071600a1fa7c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.non_neg.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.keras.constraints.non_neg" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..655685956f0e42e2d92dca0ac36f4cca075f474a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.pbtxt @@ -0,0 +1,51 @@ +path: "tensorflow.keras.constraints" +tf_module { + member { + name: "Constraint" + mtype: "" + } + member { + name: "MaxNorm" + mtype: "" + } + member { + name: "MinMaxNorm" + mtype: "" + } + member { + name: "NonNeg" + mtype: "" + } + member { + name: "UnitNorm" + mtype: "" + } + member { + name: "max_norm" + mtype: "" + } + member { + name: "min_max_norm" + mtype: "" + } + member { + name: "non_neg" + mtype: "" + } + member { + name: "unit_norm" + mtype: "" + } + member_method { + name: "deserialize" + argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get" + argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "serialize" + argspec: "args=[\'constraint\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.unit_norm.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.unit_norm.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5cbe0da4c1d1ff97fe836f76402cfca92e1cc511 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.constraints.unit_norm.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.keras.constraints.unit_norm" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'axis\'], varargs=None, keywords=None, defaults=[\'0\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.boston_housing.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.boston_housing.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..bda31751d429ca0d0544402e5c496a0597e1849e --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.boston_housing.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.keras.datasets.boston_housing" +tf_module { + member_method { + name: "load_data" + argspec: "args=[\'path\', \'test_split\', \'seed\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.cifar10.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.cifar10.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8a5142f793d67b3a923f3033c0da14442c4f680f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.cifar10.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.keras.datasets.cifar10" +tf_module { + member_method { + name: "load_data" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.cifar100.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.cifar100.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..16f184eeb5e8ee4f126b943c8988ec28ceab89a4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.cifar100.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.keras.datasets.cifar100" +tf_module { + member_method { + name: "load_data" + argspec: "args=[\'label_mode\'], varargs=None, keywords=None, defaults=[\'fine\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.fashion_mnist.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.fashion_mnist.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a0e14356fa5e91bc81bd89f6eb8c07087956c392 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.fashion_mnist.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.keras.datasets.fashion_mnist" +tf_module { + member_method { + name: "load_data" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.imdb.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.imdb.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ff962876b66cae013de5d711dc7eac5d5c80d8c3 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.imdb.pbtxt @@ -0,0 +1,11 @@ +path: "tensorflow.keras.datasets.imdb" +tf_module { + member_method { + name: "get_word_index" + argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=[\'imdb_word_index.json\'], " + } + member_method { + name: "load_data" + argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.mnist.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.mnist.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..530bb0755060f243281523c68b9c554dcbdbc634 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.mnist.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.keras.datasets.mnist" +tf_module { + member_method { + name: "load_data" + argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=[\'mnist.npz\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..36e3aafbe4dbc22fade073b45b2d7495f8f7ec52 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.pbtxt @@ -0,0 +1,31 @@ +path: "tensorflow.keras.datasets" +tf_module { + member { + name: "boston_housing" + mtype: "" + } + member { + name: "cifar10" + mtype: "" + } + member { + name: "cifar100" + mtype: "" + } + member { + name: "fashion_mnist" + mtype: "" + } + member { + name: "imdb" + mtype: "" + } + member { + name: "mnist" + mtype: "" + } + member { + name: "reuters" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.reuters.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.reuters.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2da4a13067f2b39eb06304864ea626002300a862 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.datasets.reuters.pbtxt @@ -0,0 +1,11 @@ +path: "tensorflow.keras.datasets.reuters" +tf_module { + member_method { + name: "get_word_index" + argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=[\'reuters_word_index.json\'], " + } + member_method { + name: "load_data" + argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.estimator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.estimator.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7a3fb39f774d24d3e6e5c87233f055f50cfc08bb --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.estimator.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.keras.estimator" +tf_module { + member_method { + name: "model_to_estimator" + argspec: "args=[\'keras_model\', \'keras_model_path\', \'custom_objects\', \'model_dir\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-constant.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-constant.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..cbaba78ed5a851c3d6e29ab67c89fdfd5db01754 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-constant.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.Constant" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'value\', \'dtype\', \'verify_shape\'], varargs=None, keywords=None, defaults=[\'0\', \"\", \'False\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-identity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-identity.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a5f7f348de9d9899d962e7647d7943ddb6a60604 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-identity.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.Identity" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'gain\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-initializer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-initializer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8f10d1698e7b7b2afa9c2664c7dca38045eda85b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-initializer.pbtxt @@ -0,0 +1,16 @@ +path: "tensorflow.keras.initializers.Initializer" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-ones.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-ones.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2fbfa774f8ed020164e32bb3cfb69b8a235609ba --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-ones.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.Ones" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-orthogonal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-orthogonal.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..874d320d73d1f1cdbd817db587ea9dcfea4d352b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-orthogonal.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.Orthogonal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'gain\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-normal.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..23cd02c0b069d3cb2d7b9e7ebc754db288e4637a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-normal.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.RandomNormal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-uniform.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d98628f42253603178cdff2624f639afa846a66a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-random-uniform.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.RandomUniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..86d48257c1ffb95fc217de475efba41002f8e7a5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-truncated-normal.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.TruncatedNormal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-variance-scaling.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-variance-scaling.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..03f4064b9ef5093044a9cbb897043d643cf7f83e --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-variance-scaling.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.VarianceScaling" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'truncated_normal\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-zeros.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-zeros.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b6ab68e5beb47c9bcfbc52f9808255bbb03d2dc0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.-zeros.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.Zeros" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.constant.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.constant.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..bddc37b907e7573c9fff27a0c3a5f7e199b88a9a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.constant.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.constant" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'value\', \'dtype\', \'verify_shape\'], varargs=None, keywords=None, defaults=[\'0\', \"\", \'False\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.identity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.identity.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a4c5a6149047ffdaadde1243e4c80feae05cd77b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.identity.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.identity" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'gain\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.normal.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7485772784d40b7bf552efe9bbe8b22fadee3b86 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.normal.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.ones.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.ones.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a89f78d1e1a47c7cd5a252cfd0a7b2fa23979e90 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.ones.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.ones" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.orthogonal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.orthogonal.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ee1e9bbae2b7130db5b96309e2d87719169d788a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.orthogonal.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.orthogonal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'gain\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8645e5430295dff0a5b7c715b03860fb7734e7f1 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt @@ -0,0 +1,119 @@ +path: "tensorflow.keras.initializers" +tf_module { + member { + name: "Constant" + mtype: "" + } + member { + name: "Identity" + mtype: "" + } + member { + name: "Initializer" + mtype: "" + } + member { + name: "Ones" + mtype: "" + } + member { + name: "Orthogonal" + mtype: "" + } + member { + name: "RandomNormal" + mtype: "" + } + member { + name: "RandomUniform" + mtype: "" + } + member { + name: "TruncatedNormal" + mtype: "" + } + member { + name: "VarianceScaling" + mtype: "" + } + member { + name: "Zeros" + mtype: "" + } + member { + name: "constant" + mtype: "" + } + member { + name: "identity" + mtype: "" + } + member { + name: "normal" + mtype: "" + } + member { + name: "ones" + mtype: "" + } + member { + name: "orthogonal" + mtype: "" + } + member { + name: "random_normal" + mtype: "" + } + member { + name: "random_uniform" + mtype: "" + } + member { + name: "truncated_normal" + mtype: "" + } + member { + name: "uniform" + mtype: "" + } + member { + name: "zeros" + mtype: "" + } + member_method { + name: "deserialize" + argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get" + argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "glorot_normal" + argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " + } + member_method { + name: "glorot_uniform" + argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"\"], " + } + member_method { + name: "he_normal" + argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "he_uniform" + argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "lecun_normal" + argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "lecun_uniform" + argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize" + argspec: "args=[\'initializer\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_normal.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a6df1e87a3f68fb16e32dce1ba4ee29f6d86e74e --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_normal.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.random_normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_uniform.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..37a0fa0d5508de0026472ff1a3aa532bb8f343cd --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.random_uniform.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.random_uniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f97e93f0b72d5e959722d15fa9dc35869c550710 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.truncated_normal.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.truncated_normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.uniform.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..58186b1383d8997165bb457e1cb54df86cd02d11 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.uniform.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.uniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.zeros.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.zeros.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a262390687f31a5fb79822e69273306b9e1897b5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.zeros.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.zeros" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..86e328888e596852caf9ad1020dfdedb71864969 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.Activation" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'activation\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b0ed54578109c6ae8d5bc2c9f5c978b562a9cc84 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.ActivityRegularization" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'l1\', \'l2\'], varargs=None, keywords=kwargs, defaults=[\'0.0\', \'0.0\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..42f98ed03d426d60cabeb0b533311d41eb378285 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.Add" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..000898a4be928e4e64b4072ef3170b6fbc930bdf --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.AlphaDropout" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'rate\', \'noise_shape\', \'seed\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..380b49f99ce6e62770a9516ba81db99f194c5b37 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.AveragePooling1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..82db5e6137639e516f6df6f0e130e73be516c9b8 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.AveragePooling2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2)\', \'None\', \'valid\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b6ff688ec36f8c47b2ac9694fb84350818be25c5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.AveragePooling3D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2, 2)\', \'None\', \'valid\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b41290f8b067397bf6678d9e98ac53f28a05a3fc --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.Average" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..88a033e61f42e2fb02b08968ff001ea21195972a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.AvgPool1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c1b9b96044ed2e057b8e86dda59ee7f7166cfd43 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.AvgPool2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2)\', \'None\', \'valid\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f59f7727a3eaeb4fa5631cb1b42901ea6d39b06b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.AvgPool3D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2, 2)\', \'None\', \'valid\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7d3744ed92636a972bae2f9b62a6b2da8f91d106 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.BatchNormalization" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'beta_constraint\', \'gamma_constraint\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'fused\', \'trainable\', \'virtual_batch_size\', \'adjustment\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'zeros\', \'ones\', \'zeros\', \'ones\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3fd4ccdab2573964c2f3192d503e9fb15f442dc5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt @@ -0,0 +1,188 @@ +path: "tensorflow.keras.layers.Bidirectional" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "constraints" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'layer\', \'merge_mode\', \'weights\'], varargs=None, keywords=kwargs, defaults=[\'concat\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\', \'initial_state\', \'constants\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_states" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ba21b50be41f3adc735b3350bdf9dbeae3c2e358 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.Concatenate" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'axis\'], varargs=None, keywords=kwargs, defaults=[\'-1\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..46f9fa2bbbbe3cfff3aade33c5ebdec92bc70ef0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt @@ -0,0 +1,273 @@ +path: "tensorflow.keras.layers.ConvLSTM2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activation" + mtype: "" + } + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "bias_constraint" + mtype: "" + } + member { + name: "bias_initializer" + mtype: "" + } + member { + name: "bias_regularizer" + mtype: "" + } + member { + name: "data_format" + mtype: "" + } + member { + name: "dilation_rate" + mtype: "" + } + member { + name: "dropout" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "filters" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "kernel_constraint" + mtype: "" + } + member { + name: "kernel_initializer" + mtype: "" + } + member { + name: "kernel_regularizer" + mtype: "" + } + member { + name: "kernel_size" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "padding" + mtype: "" + } + member { + name: "recurrent_activation" + mtype: "" + } + member { + name: "recurrent_constraint" + mtype: "" + } + member { + name: "recurrent_dropout" + mtype: "" + } + member { + name: "recurrent_initializer" + mtype: "" + } + member { + name: "recurrent_regularizer" + mtype: "" + } + member { + name: "states" + mtype: "" + } + member { + name: "strides" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "unit_forget_bias" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "use_bias" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'go_backwards\', \'stateful\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'0.0\', \'0.0\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'initial_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_states" + argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c3ad326589d2822bc5dd381d78216b25f5fb6f95 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.Conv1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..fd9eb43066be580a7df57aeb717b59569c9bba61 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt @@ -0,0 +1,177 @@ +path: "tensorflow.keras.layers.Conv2DTranspose" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..40d61688f29a81e873a26c8a5eb823d679320ed6 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.Conv2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b8c227d7257311578e41abe0a384ed93e6a2866c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt @@ -0,0 +1,177 @@ +path: "tensorflow.keras.layers.Conv3DTranspose" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..095d35e5749d0113956b04f971e6a8ca1fa277b8 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.Conv3D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'(1, 1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8f999611982bbfe3c613ef26d93782e299275f19 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.Convolution1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..96d522a016aedba01032a1c05a69511cb03d19af --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt @@ -0,0 +1,177 @@ +path: "tensorflow.keras.layers.Convolution2DTranspose" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..de2824dab4526d90eebf9cef16710cadf82f4850 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.Convolution2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1d563241d8f0d93bcd19a319eb8383f4bcdf4388 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt @@ -0,0 +1,177 @@ +path: "tensorflow.keras.layers.Convolution3DTranspose" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c87e52c53796f0743365a9d8780decf237bba070 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.Convolution3D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'(1, 1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..dccf5523e3870b6c1ce0de70c648ab47968a105f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.Cropping1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cropping\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7ac4116d922eea51e5a7e7fe3d02ad919300c459 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.Cropping2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cropping\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'((0, 0), (0, 0))\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..024f72705de1e76866a8132246884dffb0c4e72a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.Cropping3D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cropping\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'((1, 1), (1, 1), (1, 1))\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4e0233331bd47e86e8a4df2f84b5392517fbf884 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt @@ -0,0 +1,193 @@ +path: "tensorflow.keras.layers.CuDNNGRU" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "cell" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "states" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'units\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\'], varargs=None, keywords=kwargs, defaults=[\'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'False\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'initial_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_states" + argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..32d46ce8f3deff6077eaf5a1a8cf7ba64478d9f4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt @@ -0,0 +1,193 @@ +path: "tensorflow.keras.layers.CuDNNLSTM" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "cell" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "states" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'units\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\'], varargs=None, keywords=kwargs, defaults=[\'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'False\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'initial_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_states" + argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..858486c725c3be5ecae2a02d0d3134ebeb113ce1 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.Dense" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f65d7509262bfeb148588e069c08961058a3fa74 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt @@ -0,0 +1,177 @@ +path: "tensorflow.keras.layers.DepthwiseConv2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'kernel_size\', \'strides\', \'padding\', \'depth_multiplier\', \'data_format\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'1\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2e71ef503d54927edbb3e1ef6c701ac845883e46 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.Dot" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'axes\', \'normalize\'], varargs=None, keywords=kwargs, defaults=[\'False\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..42533bcd21b28a0acf183db195a6b5c1848a5d91 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.Dropout" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'rate\', \'noise_shape\', \'seed\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b5df16941792a29d72f2ee709993b007d342d2d0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.ELU" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'alpha\'], varargs=None, keywords=kwargs, defaults=[\'1.0\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0ea17919a9bb13ffdedd60ce618bca23dd52712f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.Embedding" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'input_dim\', \'output_dim\', \'embeddings_initializer\', \'embeddings_regularizer\', \'activity_regularizer\', \'embeddings_constraint\', \'mask_zero\', \'input_length\'], varargs=None, keywords=kwargs, defaults=[\'uniform\', \'None\', \'None\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a33248bc005a73d0be679cd62150d6019b475305 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.Flatten" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4ba21a25cda83122fbced7fed76d4b1ae28cb4c8 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.GRUCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'reset_after\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'states\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a7a570418e0a78873237c1c8cefe36a212e4c9af --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt @@ -0,0 +1,256 @@ +path: "tensorflow.keras.layers.GRU" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activation" + mtype: "" + } + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "bias_constraint" + mtype: "" + } + member { + name: "bias_initializer" + mtype: "" + } + member { + name: "bias_regularizer" + mtype: "" + } + member { + name: "dropout" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "implementation" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "kernel_constraint" + mtype: "" + } + member { + name: "kernel_initializer" + mtype: "" + } + member { + name: "kernel_regularizer" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "recurrent_activation" + mtype: "" + } + member { + name: "recurrent_constraint" + mtype: "" + } + member { + name: "recurrent_dropout" + mtype: "" + } + member { + name: "recurrent_initializer" + mtype: "" + } + member { + name: "recurrent_regularizer" + mtype: "" + } + member { + name: "reset_after" + mtype: "" + } + member { + name: "states" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "units" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "use_bias" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\', \'reset_after\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\', \'False\', \'False\', \'False\', \'False\', \'False\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'initial_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_states" + argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..763bc231136908d469b7f942aec94f6248d2e2d4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.GaussianDropout" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'rate\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3c50a3d7f28809b2b810b52951207e48f9f50e34 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.GaussianNoise" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'stddev\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ac78bdafada8c157efd4ab8746be15726eb0bc24 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.GlobalAveragePooling1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..275282d9d2b1753cf0189b605f921bb039ef5f3c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.GlobalAveragePooling2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0e31e6058bd6036a5fb4422335917718f4f82851 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.GlobalAveragePooling3D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..aacd0b1791dda5babb6eef5d87a1335c8d519b08 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.GlobalAvgPool1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c23654866341818aeb804cfb71dae052049e3f25 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.GlobalAvgPool2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6b9c0290aac35d80c7f87acfc44479c57623a645 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.GlobalAvgPool3D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0d7b2211e6cd35ca331b4a1068f237e7ca07f70c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.GlobalMaxPool1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d080ad6aedbd5183da890cd63f5f18453d5d476a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.GlobalMaxPool2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..fcb0a109da208ff5bd20447ddced9816a42af311 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.GlobalMaxPool3D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1d0e22abd0d8732182881c43ee79400642cef24b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.GlobalMaxPooling1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..653c9f547bc888a8fec87137f7d495141d4f8599 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.GlobalMaxPooling2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..cdbaf82cf6746e878619647439d2256f6e2c4aa3 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.GlobalMaxPooling3D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..230c5e903438b0a75edf80f0f5c8706987c66a78 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.InputLayer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-spec.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5fd0a47a68c0d4ad218c4c64cc6be8f603d9673a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-spec.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.keras.layers.InputSpec" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..511456e740837455818ff3f9be270daed03f334f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.LSTMCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'states\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4a3492ebd652e5ab8f0faf8a1583480abc80fba7 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt @@ -0,0 +1,256 @@ +path: "tensorflow.keras.layers.LSTM" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activation" + mtype: "" + } + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "bias_constraint" + mtype: "" + } + member { + name: "bias_initializer" + mtype: "" + } + member { + name: "bias_regularizer" + mtype: "" + } + member { + name: "dropout" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "implementation" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "kernel_constraint" + mtype: "" + } + member { + name: "kernel_initializer" + mtype: "" + } + member { + name: "kernel_regularizer" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "recurrent_activation" + mtype: "" + } + member { + name: "recurrent_constraint" + mtype: "" + } + member { + name: "recurrent_dropout" + mtype: "" + } + member { + name: "recurrent_initializer" + mtype: "" + } + member { + name: "recurrent_regularizer" + mtype: "" + } + member { + name: "states" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "unit_forget_bias" + mtype: "" + } + member { + name: "units" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "use_bias" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\', \'False\', \'False\', \'False\', \'False\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'initial_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_states" + argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt index ecbaa9ce2c76bf3d2964a6c79c96c4d67cc3b80e..5d05cf689fb399d6630f68b09fd123d2d968786b 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt @@ -98,7 +98,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7efa29be77c075a29784d8cd3ebfcd871bc9aa0c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt @@ -0,0 +1,174 @@ +path: "tensorflow.keras.layers.Layer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'trainable\', \'name\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'True\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0ca8e0b52c4a81c4ff3b756aa6c24b47a664f999 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.LeakyReLU" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'alpha\'], varargs=None, keywords=kwargs, defaults=[\'0.3\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f754fa1da85692c28f31a76bbfa987b3c4c30731 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.LocallyConnected1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c9516b8f07d0b6a818bf99d45499d161c2a5cffd --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.LocallyConnected2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..850ecff9743b5f5048bb81c5a15b0a4be6b4d0ce --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.Masking" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'mask_value\'], varargs=None, keywords=kwargs, defaults=[\'0.0\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7c69e31f9af9bbd221882d160fa4206997ec3b08 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.MaxPool1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..fba42642d7c701688c2bd274cf97e077e7ff571c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.MaxPool2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2)\', \'None\', \'valid\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9c277411ea5ce26df9c033ada773ad2e45292cb1 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.MaxPool3D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2, 2)\', \'None\', \'valid\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7c2f6ccc8a98017aba014ab6a7896e0a4bf40324 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.MaxPooling1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..802178dba63d66cca1629bcb7bef0f578c9a6659 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.MaxPooling2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2)\', \'None\', \'valid\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e870dfe9ade75da367f87a4b54d38ba4274bab2e --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.MaxPooling3D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2, 2)\', \'None\', \'valid\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c1337ce0cbac2d1e0e011f5309bfb2722960d3b2 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.Maximum" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ed27a62765d5670802d4593b3e648e3f65eaf926 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.Minimum" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b9f05cb3e56f89cb02e1a74c3ec0d362ea27f2bf --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.Multiply" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..336d9f76fb1e6215b763b5064cd6be68d4d0d5a0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.PReLU" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'alpha_initializer\', \'alpha_regularizer\', \'alpha_constraint\', \'shared_axes\'], varargs=None, keywords=kwargs, defaults=[\'zeros\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..46282217e01e8a137d9fc564f0e3544602d93de4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.Permute" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dims\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..42cd7e87eebdd969f002d8bcd0dca101168c58e0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt @@ -0,0 +1,187 @@ +path: "tensorflow.keras.layers.RNN" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "states" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'initial_state\', \'constants\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_states" + argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c00fa79adfbe5b986b481f6c9567bafbf3abc1ae --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.ReLU" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'max_value\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9f094a877a3a47ff89a022db563803f5f391ff2a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.RepeatVector" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'n\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2f519a24385ac4e147798ed3e96101cff23e19aa --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.Reshape" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'target_shape\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6b93116ba02c2b7e9c5bdf79ddfa1f93050062a4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt @@ -0,0 +1,177 @@ +path: "tensorflow.keras.layers.SeparableConv1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'None\', \'1\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..fd17115e2733d561bff1d53d62d32458b03dc65b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt @@ -0,0 +1,177 @@ +path: "tensorflow.keras.layers.SeparableConv2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4b37a94478857ac8550ea0c4f464058c68770047 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt @@ -0,0 +1,177 @@ +path: "tensorflow.keras.layers.SeparableConvolution1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'None\', \'1\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5bdadca74aeb963adef4999b7e758add1aec4681 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt @@ -0,0 +1,177 @@ +path: "tensorflow.keras.layers.SeparableConvolution2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9dfda96fc81572d70d76ba767b69ee2e41f017ee --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.SimpleRNNCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'states\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7b7684ccd27a1d4c3fabf56c2669f77095f501ef --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt @@ -0,0 +1,244 @@ +path: "tensorflow.keras.layers.SimpleRNN" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activation" + mtype: "" + } + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "bias_constraint" + mtype: "" + } + member { + name: "bias_initializer" + mtype: "" + } + member { + name: "bias_regularizer" + mtype: "" + } + member { + name: "dropout" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "kernel_constraint" + mtype: "" + } + member { + name: "kernel_initializer" + mtype: "" + } + member { + name: "kernel_regularizer" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "recurrent_constraint" + mtype: "" + } + member { + name: "recurrent_dropout" + mtype: "" + } + member { + name: "recurrent_initializer" + mtype: "" + } + member { + name: "recurrent_regularizer" + mtype: "" + } + member { + name: "states" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "units" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "use_bias" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'False\', \'False\', \'False\', \'False\', \'False\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'initial_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_states" + argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3b15407fca2cf65f7fa31f29b84db52b5c5d1a7a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.Softmax" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'axis\'], varargs=None, keywords=kwargs, defaults=[\'-1\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6d04415267c9ce21268b9d86a5b078d8f92db93f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.SpatialDropout1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'rate\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..04950654d55f30bf095167d176b5b2717e72f2cd --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.SpatialDropout2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'rate\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c424e6dcc869f977100e77fdb543983c3ab7e63c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.SpatialDropout3D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'rate\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt similarity index 94% rename from tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt index c4b9f93561de6a5d8ecc19bbae17831466b51fe6..1160d2840f5ddd2937db53406af9d4d2132a6515 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt @@ -102,7 +102,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..740a03367bd69edf797d3ea8616fdde72f6726b7 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt @@ -0,0 +1,176 @@ +path: "tensorflow.keras.layers.Subtract" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a08c583adb4175ff5ee77869c80c6c0204018166 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.ThresholdedReLU" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'theta\'], varargs=None, keywords=kwargs, defaults=[\'1.0\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c1294fed0fcfca9c8607bf3e5d41efd240fd4d45 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt @@ -0,0 +1,180 @@ +path: "tensorflow.keras.layers.TimeDistributed" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'layer\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..dc401d3ed0fee5b6fb4bb5563941c3461eb592f4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.UpSampling1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'size\'], varargs=None, keywords=kwargs, defaults=[\'2\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4b5165ae9793f900fb474affe52b9abaeb64adbd --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.UpSampling2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'size\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2)\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..789af15fea8c0d41dd3f0c00e7be3afd6afafecf --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.UpSampling3D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'size\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2, 2)\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0536a7cee7e6dd5878f532854753cebeaa043c21 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt @@ -0,0 +1,179 @@ +path: "tensorflow.keras.layers.Wrapper" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'layer\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8915353ec334f28c4ed058b20a506ff102ca1f61 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.ZeroPadding1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'padding\'], varargs=None, keywords=kwargs, defaults=[\'1\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6efb5ef15a133877666decfd1f2b40fad4463469 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.ZeroPadding2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4c33c5d0bf800239e2bff4cc874e594b515a8071 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.keras.layers.ZeroPadding3D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9d7e5bb8c7808689bedd8abb835e61c1f38fdb1d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.pbtxt @@ -0,0 +1,435 @@ +path: "tensorflow.keras.layers" +tf_module { + member { + name: "Activation" + mtype: "" + } + member { + name: "ActivityRegularization" + mtype: "" + } + member { + name: "Add" + mtype: "" + } + member { + name: "AlphaDropout" + mtype: "" + } + member { + name: "Average" + mtype: "" + } + member { + name: "AveragePooling1D" + mtype: "" + } + member { + name: "AveragePooling2D" + mtype: "" + } + member { + name: "AveragePooling3D" + mtype: "" + } + member { + name: "AvgPool1D" + mtype: "" + } + member { + name: "AvgPool2D" + mtype: "" + } + member { + name: "AvgPool3D" + mtype: "" + } + member { + name: "BatchNormalization" + mtype: "" + } + member { + name: "Bidirectional" + mtype: "" + } + member { + name: "Concatenate" + mtype: "" + } + member { + name: "Conv1D" + mtype: "" + } + member { + name: "Conv2D" + mtype: "" + } + member { + name: "Conv2DTranspose" + mtype: "" + } + member { + name: "Conv3D" + mtype: "" + } + member { + name: "Conv3DTranspose" + mtype: "" + } + member { + name: "ConvLSTM2D" + mtype: "" + } + member { + name: "Convolution1D" + mtype: "" + } + member { + name: "Convolution2D" + mtype: "" + } + member { + name: "Convolution2DTranspose" + mtype: "" + } + member { + name: "Convolution3D" + mtype: "" + } + member { + name: "Convolution3DTranspose" + mtype: "" + } + member { + name: "Cropping1D" + mtype: "" + } + member { + name: "Cropping2D" + mtype: "" + } + member { + name: "Cropping3D" + mtype: "" + } + member { + name: "CuDNNGRU" + mtype: "" + } + member { + name: "CuDNNLSTM" + mtype: "" + } + member { + name: "Dense" + mtype: "" + } + member { + name: "DepthwiseConv2D" + mtype: "" + } + member { + name: "Dot" + mtype: "" + } + member { + name: "Dropout" + mtype: "" + } + member { + name: "ELU" + mtype: "" + } + member { + name: "Embedding" + mtype: "" + } + member { + name: "Flatten" + mtype: "" + } + member { + name: "GRU" + mtype: "" + } + member { + name: "GRUCell" + mtype: "" + } + member { + name: "GaussianDropout" + mtype: "" + } + member { + name: "GaussianNoise" + mtype: "" + } + member { + name: "GlobalAveragePooling1D" + mtype: "" + } + member { + name: "GlobalAveragePooling2D" + mtype: "" + } + member { + name: "GlobalAveragePooling3D" + mtype: "" + } + member { + name: "GlobalAvgPool1D" + mtype: "" + } + member { + name: "GlobalAvgPool2D" + mtype: "" + } + member { + name: "GlobalAvgPool3D" + mtype: "" + } + member { + name: "GlobalMaxPool1D" + mtype: "" + } + member { + name: "GlobalMaxPool2D" + mtype: "" + } + member { + name: "GlobalMaxPool3D" + mtype: "" + } + member { + name: "GlobalMaxPooling1D" + mtype: "" + } + member { + name: "GlobalMaxPooling2D" + mtype: "" + } + member { + name: "GlobalMaxPooling3D" + mtype: "" + } + member { + name: "InputLayer" + mtype: "" + } + member { + name: "InputSpec" + mtype: "" + } + member { + name: "LSTM" + mtype: "" + } + member { + name: "LSTMCell" + mtype: "" + } + member { + name: "Lambda" + mtype: "" + } + member { + name: "Layer" + mtype: "" + } + member { + name: "LeakyReLU" + mtype: "" + } + member { + name: "LocallyConnected1D" + mtype: "" + } + member { + name: "LocallyConnected2D" + mtype: "" + } + member { + name: "Masking" + mtype: "" + } + member { + name: "MaxPool1D" + mtype: "" + } + member { + name: "MaxPool2D" + mtype: "" + } + member { + name: "MaxPool3D" + mtype: "" + } + member { + name: "MaxPooling1D" + mtype: "" + } + member { + name: "MaxPooling2D" + mtype: "" + } + member { + name: "MaxPooling3D" + mtype: "" + } + member { + name: "Maximum" + mtype: "" + } + member { + name: "Minimum" + mtype: "" + } + member { + name: "Multiply" + mtype: "" + } + member { + name: "PReLU" + mtype: "" + } + member { + name: "Permute" + mtype: "" + } + member { + name: "RNN" + mtype: "" + } + member { + name: "ReLU" + mtype: "" + } + member { + name: "RepeatVector" + mtype: "" + } + member { + name: "Reshape" + mtype: "" + } + member { + name: "SeparableConv1D" + mtype: "" + } + member { + name: "SeparableConv2D" + mtype: "" + } + member { + name: "SeparableConvolution1D" + mtype: "" + } + member { + name: "SeparableConvolution2D" + mtype: "" + } + member { + name: "SimpleRNN" + mtype: "" + } + member { + name: "SimpleRNNCell" + mtype: "" + } + member { + name: "Softmax" + mtype: "" + } + member { + name: "SpatialDropout1D" + mtype: "" + } + member { + name: "SpatialDropout2D" + mtype: "" + } + member { + name: "SpatialDropout3D" + mtype: "" + } + member { + name: "StackedRNNCells" + mtype: "" + } + member { + name: "Subtract" + mtype: "" + } + member { + name: "ThresholdedReLU" + mtype: "" + } + member { + name: "TimeDistributed" + mtype: "" + } + member { + name: "UpSampling1D" + mtype: "" + } + member { + name: "UpSampling2D" + mtype: "" + } + member { + name: "UpSampling3D" + mtype: "" + } + member { + name: "Wrapper" + mtype: "" + } + member { + name: "ZeroPadding1D" + mtype: "" + } + member { + name: "ZeroPadding2D" + mtype: "" + } + member { + name: "ZeroPadding3D" + mtype: "" + } + member_method { + name: "Input" + argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "add" + argspec: "args=[\'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "average" + argspec: "args=[\'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "concatenate" + argspec: "args=[\'inputs\', \'axis\'], varargs=None, keywords=kwargs, defaults=[\'-1\'], " + } + member_method { + name: "dot" + argspec: "args=[\'inputs\', \'axes\', \'normalize\'], varargs=None, keywords=kwargs, defaults=[\'False\'], " + } + member_method { + name: "maximum" + argspec: "args=[\'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "minimum" + argspec: "args=[\'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "multiply" + argspec: "args=[\'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "subtract" + argspec: "args=[\'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..eca6b915388ebff0103f7ad16f43c6be0df60b7d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.pbtxt @@ -0,0 +1,115 @@ +path: "tensorflow.keras.losses" +tf_module { + member_method { + name: "KLD" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MAE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MAPE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MSE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MSLE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "binary_crossentropy" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "categorical_crossentropy" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "categorical_hinge" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "cosine" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "cosine_proximity" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "deserialize" + argspec: "args=[\'name\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get" + argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "hinge" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "kld" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "kullback_leibler_divergence" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "logcosh" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mae" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mape" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mean_absolute_error" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mean_absolute_percentage_error" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mean_squared_error" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mean_squared_logarithmic_error" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mse" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "msle" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "poisson" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "serialize" + argspec: "args=[\'loss\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "sparse_categorical_crossentropy" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "squared_hinge" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..73b577da373b1381a7e8d5841d6e002452a21f9e --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt @@ -0,0 +1,123 @@ +path: "tensorflow.keras.metrics" +tf_module { + member_method { + name: "KLD" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MAE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MAPE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MSE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MSLE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "binary_accuracy" + argspec: "args=[\'y_true\', \'y_pred\', \'threshold\'], varargs=None, keywords=None, defaults=[\'0.5\'], " + } + member_method { + name: "binary_crossentropy" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "categorical_accuracy" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "categorical_crossentropy" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "cosine" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "cosine_proximity" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "deserialize" + argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get" + argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "hinge" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "kld" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "kullback_leibler_divergence" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mae" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mape" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mean_absolute_error" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mean_absolute_percentage_error" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mean_squared_error" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mean_squared_logarithmic_error" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mse" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "msle" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "poisson" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "serialize" + argspec: "args=[\'metric\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "sparse_categorical_crossentropy" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "sparse_top_k_categorical_accuracy" + argspec: "args=[\'y_true\', \'y_pred\', \'k\'], varargs=None, keywords=None, defaults=[\'5\'], " + } + member_method { + name: "squared_hinge" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "top_k_categorical_accuracy" + argspec: "args=[\'y_true\', \'y_pred\', \'k\'], varargs=None, keywords=None, defaults=[\'5\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..85f7c2bfedb936d3b21624448cf8875775de918b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt @@ -0,0 +1,268 @@ +path: "tensorflow.keras.models.Model" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "layers" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "state_updates" + mtype: "" + } + member { + name: "stateful" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "uses_learning_phase" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compile" + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " + } + member_method { + name: "evaluate_generator" + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + } + member_method { + name: "fit" + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " + } + member_method { + name: "fit_generator" + argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_layer" + argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_weights" + argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + } + member_method { + name: "predict_generator" + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + } + member_method { + name: "predict_on_batch" + argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_states" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "save" + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], " + } + member_method { + name: "save_weights" + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "summary" + argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "test_on_batch" + argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "to_json" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "to_yaml" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "train_on_batch" + argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6a83129f7df46a63c8fa1080a6a35dc3f558c549 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt @@ -0,0 +1,289 @@ +path: "tensorflow.keras.models.Sequential" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "layers" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "state_updates" + mtype: "" + } + member { + name: "stateful" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "uses_learning_phase" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'layers\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "add" + argspec: "args=[\'self\', \'layer\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compile" + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " + } + member_method { + name: "evaluate_generator" + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + } + member_method { + name: "fit" + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " + } + member_method { + name: "fit_generator" + argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_layer" + argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_weights" + argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "pop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " + } + member_method { + name: "predict_classes" + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], " + } + member_method { + name: "predict_generator" + argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], " + } + member_method { + name: "predict_on_batch" + argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict_proba" + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], " + } + member_method { + name: "reset_states" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "save" + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], " + } + member_method { + name: "save_weights" + argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "summary" + argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "symbolic_set_inputs" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "test_on_batch" + argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "to_json" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "to_yaml" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "train_on_batch" + argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8ba0e7480bf5100e4bb10ceaf220cfaac0f43f52 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.pbtxt @@ -0,0 +1,31 @@ +path: "tensorflow.keras.models" +tf_module { + member { + name: "Model" + mtype: "" + } + member { + name: "Sequential" + mtype: "" + } + member_method { + name: "load_model" + argspec: "args=[\'filepath\', \'custom_objects\', \'compile\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], " + } + member_method { + name: "model_from_config" + argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "model_from_json" + argspec: "args=[\'json_string\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "model_from_yaml" + argspec: "args=[\'yaml_string\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "save_model" + argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adadelta.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adadelta.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b9ce154bddef609e0aaf6627d6f59de551e51e3b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adadelta.pbtxt @@ -0,0 +1,34 @@ +path: "tensorflow.keras.optimizers.Adadelta" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'1.0\', \'0.95\', \'None\', \'0.0\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_gradients" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adagrad.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adagrad.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d0dc9e37a386a26143365eb443d5ba5fce8a87d9 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adagrad.pbtxt @@ -0,0 +1,34 @@ +path: "tensorflow.keras.optimizers.Adagrad" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'lr\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.01\', \'None\', \'0.0\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_gradients" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adam.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..06815fa99a4a474ec131c29d0cbc78bb2b9cb72d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adam.pbtxt @@ -0,0 +1,34 @@ +path: "tensorflow.keras.optimizers.Adam" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\', \'amsgrad\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'None\', \'0.0\', \'False\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_gradients" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adamax.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adamax.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..47b55fdb44e79e976b6de13d760a7cf175323c6c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-adamax.pbtxt @@ -0,0 +1,34 @@ +path: "tensorflow.keras.optimizers.Adamax" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.002\', \'0.9\', \'0.999\', \'None\', \'0.0\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_gradients" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-nadam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-nadam.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8c63a7dda98568b24ea1b3cda15d4c840fbfd804 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-nadam.pbtxt @@ -0,0 +1,34 @@ +path: "tensorflow.keras.optimizers.Nadam" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'schedule_decay\'], varargs=None, keywords=kwargs, defaults=[\'0.002\', \'0.9\', \'0.999\', \'None\', \'0.004\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_gradients" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-optimizer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..53d64dae932e250b9d81b2767a833de3bac8c403 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-optimizer.pbtxt @@ -0,0 +1,33 @@ +path: "tensorflow.keras.optimizers.Optimizer" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_gradients" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-r-m-sprop.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-r-m-sprop.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a1e9b8cceb95e8f25ac5f414fadacf237be33cd9 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-r-m-sprop.pbtxt @@ -0,0 +1,34 @@ +path: "tensorflow.keras.optimizers.RMSprop" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'None\', \'0.0\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_gradients" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-s-g-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-s-g-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a67fefb1bafebd62db9f6108f0fe1847b5d2e0cb --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.-s-g-d.pbtxt @@ -0,0 +1,34 @@ +path: "tensorflow.keras.optimizers.SGD" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'lr\', \'momentum\', \'decay\', \'nesterov\'], varargs=None, keywords=kwargs, defaults=[\'0.01\', \'0.0\', \'0.0\', \'False\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_gradients" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7257b02087e237eaa47ed6a042559aa1332fc87b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.pbtxt @@ -0,0 +1,47 @@ +path: "tensorflow.keras.optimizers" +tf_module { + member { + name: "Adadelta" + mtype: "" + } + member { + name: "Adagrad" + mtype: "" + } + member { + name: "Adam" + mtype: "" + } + member { + name: "Adamax" + mtype: "" + } + member { + name: "Nadam" + mtype: "" + } + member { + name: "Optimizer" + mtype: "" + } + member { + name: "RMSprop" + mtype: "" + } + member { + name: "SGD" + mtype: "" + } + member_method { + name: "deserialize" + argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get" + argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "serialize" + argspec: "args=[\'optimizer\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..754b3b84b08b08c7d12eba4ddad0a483440055a9 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.pbtxt @@ -0,0 +1,83 @@ +path: "tensorflow.keras" +tf_module { + member { + name: "Model" + mtype: "" + } + member { + name: "Sequential" + mtype: "" + } + member { + name: "activations" + mtype: "" + } + member { + name: "applications" + mtype: "" + } + member { + name: "backend" + mtype: "" + } + member { + name: "callbacks" + mtype: "" + } + member { + name: "constraints" + mtype: "" + } + member { + name: "datasets" + mtype: "" + } + member { + name: "estimator" + mtype: "" + } + member { + name: "initializers" + mtype: "" + } + member { + name: "layers" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "metrics" + mtype: "" + } + member { + name: "models" + mtype: "" + } + member { + name: "optimizers" + mtype: "" + } + member { + name: "preprocessing" + mtype: "" + } + member { + name: "regularizers" + mtype: "" + } + member { + name: "utils" + mtype: "" + } + member { + name: "wrappers" + mtype: "" + } + member_method { + name: "Input" + argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-iterator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-iterator.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-iterator.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-iterator.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.image.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.preprocessing.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.sequence.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.sequence.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.text.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.keras.preprocessing.text.pbtxt diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.regularizers.-l1-l2.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.regularizers.-l1-l2.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a45fb7b55e58a5679427752af22dce49203dc1cc --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.regularizers.-l1-l2.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.regularizers.L1L2" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'l1\', \'l2\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.0\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.regularizers.-regularizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.regularizers.-regularizer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..641001a646564d0a466739ee6d2bdd31a27beab7 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.regularizers.-regularizer.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.keras.regularizers.Regularizer" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.regularizers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.regularizers.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..bb10d41d704ca456fbf5b8bd19324ee71f17ba8d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.regularizers.pbtxt @@ -0,0 +1,35 @@ +path: "tensorflow.keras.regularizers" +tf_module { + member { + name: "L1L2" + mtype: "" + } + member { + name: "Regularizer" + mtype: "" + } + member_method { + name: "deserialize" + argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get" + argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "l1" + argspec: "args=[\'l\'], varargs=None, keywords=None, defaults=[\'0.01\'], " + } + member_method { + name: "l1_l2" + argspec: "args=[\'l1\', \'l2\'], varargs=None, keywords=None, defaults=[\'0.01\', \'0.01\'], " + } + member_method { + name: "l2" + argspec: "args=[\'l\'], varargs=None, keywords=None, defaults=[\'0.01\'], " + } + member_method { + name: "serialize" + argspec: "args=[\'regularizer\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-custom-object-scope.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-custom-object-scope.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..109682046b990107915d65be3cad86ead3e22688 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-custom-object-scope.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.keras.utils.CustomObjectScope" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=args, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-generator-enqueuer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-generator-enqueuer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..939fd547d06bbd03b7e1a1db1404263ff01fd07c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-generator-enqueuer.pbtxt @@ -0,0 +1,26 @@ +path: "tensorflow.keras.utils.GeneratorEnqueuer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'generator\', \'use_multiprocessing\', \'wait_time\', \'seed\'], varargs=None, keywords=None, defaults=[\'False\', \'0.05\', \'None\'], " + } + member_method { + name: "get" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_running" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "start" + argspec: "args=[\'self\', \'workers\', \'max_queue_size\'], varargs=None, keywords=None, defaults=[\'1\', \'10\'], " + } + member_method { + name: "stop" + argspec: "args=[\'self\', \'timeout\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-h-d-f5-matrix.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-h-d-f5-matrix.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6b832051a975b61ba05874c3dda558c63aeaa055 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-h-d-f5-matrix.pbtxt @@ -0,0 +1,29 @@ +path: "tensorflow.keras.utils.HDF5Matrix" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "dtype" + mtype: "" + } + member { + name: "ndim" + mtype: "" + } + member { + name: "refs" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "size" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'datapath\', \'dataset\', \'start\', \'end\', \'normalizer\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-progbar.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-progbar.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..be4496e753f8bdcd76a4761f9bd1804a77380359 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-progbar.pbtxt @@ -0,0 +1,17 @@ +path: "tensorflow.keras.utils.Progbar" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'target\', \'width\', \'verbose\', \'interval\', \'stateful_metrics\'], varargs=None, keywords=None, defaults=[\'30\', \'1\', \'0.05\', \'None\'], " + } + member_method { + name: "add" + argspec: "args=[\'self\', \'n\', \'values\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "update" + argspec: "args=[\'self\', \'current\', \'values\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-sequence-enqueuer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-sequence-enqueuer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a9e499d1009b5a7458080db6c10a948af21c7b6c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-sequence-enqueuer.pbtxt @@ -0,0 +1,24 @@ +path: "tensorflow.keras.utils.SequenceEnqueuer" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "get" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_running" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "start" + argspec: "args=[\'self\', \'workers\', \'max_queue_size\'], varargs=None, keywords=None, defaults=[\'1\', \'10\'], " + } + member_method { + name: "stop" + argspec: "args=[\'self\', \'timeout\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-sequence.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-sequence.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e2dc932dc86dbba49d186e1dbc4bc026a52f6ef5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-sequence.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.keras.utils.Sequence" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4d7a1519ce59b6f0a7f0bbfb3292842a6f21dffd --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt @@ -0,0 +1,67 @@ +path: "tensorflow.keras.utils" +tf_module { + member { + name: "CustomObjectScope" + mtype: "" + } + member { + name: "GeneratorEnqueuer" + mtype: "" + } + member { + name: "HDF5Matrix" + mtype: "" + } + member { + name: "Progbar" + mtype: "" + } + member { + name: "Sequence" + mtype: "" + } + member { + name: "SequenceEnqueuer" + mtype: "" + } + member_method { + name: "convert_all_kernels_in_model" + argspec: "args=[\'model\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "custom_object_scope" + argspec: "args=[], varargs=args, keywords=None, defaults=None" + } + member_method { + name: "deserialize_keras_object" + argspec: "args=[\'identifier\', \'module_objects\', \'custom_objects\', \'printable_module_name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'object\'], " + } + member_method { + name: "get_custom_objects" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_file" + argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], " + } + member_method { + name: "multi_gpu_model" + argspec: "args=[\'model\', \'gpus\', \'cpu_merge\', \'cpu_relocation\'], varargs=None, keywords=None, defaults=[\'True\', \'False\'], " + } + member_method { + name: "normalize" + argspec: "args=[\'x\', \'axis\', \'order\'], varargs=None, keywords=None, defaults=[\'-1\', \'2\'], " + } + member_method { + name: "plot_model" + argspec: "args=[\'model\', \'to_file\', \'show_shapes\', \'show_layer_names\', \'rankdir\'], varargs=None, keywords=None, defaults=[\'model.png\', \'False\', \'True\', \'TB\'], " + } + member_method { + name: "serialize_keras_object" + argspec: "args=[\'instance\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "to_categorical" + argspec: "args=[\'y\', \'num_classes\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.wrappers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.wrappers.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0b2fac9b7d998312d1bc080d7464d17b2b5543f5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.wrappers.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.keras.wrappers" +tf_module { + member { + name: "scikit_learn" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.wrappers.scikit_learn.-keras-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.wrappers.scikit_learn.-keras-classifier.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..67cca3af41dbf68b963fb2315b65f9f843c9a42d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.wrappers.scikit_learn.-keras-classifier.pbtxt @@ -0,0 +1,42 @@ +path: "tensorflow.keras.wrappers.scikit_learn.KerasClassifier" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'build_fn\'], varargs=None, keywords=sk_params, defaults=[\'None\'], " + } + member_method { + name: "check_params" + argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "filter_sk_params" + argspec: "args=[\'self\', \'fn\', \'override\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "fit" + argspec: "args=[\'self\', \'x\', \'y\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "get_params" + argspec: "args=[\'self\'], varargs=None, keywords=params, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'x\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "predict_proba" + argspec: "args=[\'self\', \'x\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "score" + argspec: "args=[\'self\', \'x\', \'y\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "set_params" + argspec: "args=[\'self\'], varargs=None, keywords=params, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.wrappers.scikit_learn.-keras-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.wrappers.scikit_learn.-keras-regressor.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f4b9b7e277ecdb155327d83c57ec2a997c043555 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.wrappers.scikit_learn.-keras-regressor.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.keras.wrappers.scikit_learn.KerasRegressor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'build_fn\'], varargs=None, keywords=sk_params, defaults=[\'None\'], " + } + member_method { + name: "check_params" + argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "filter_sk_params" + argspec: "args=[\'self\', \'fn\', \'override\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "fit" + argspec: "args=[\'self\', \'x\', \'y\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "get_params" + argspec: "args=[\'self\'], varargs=None, keywords=params, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'x\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "score" + argspec: "args=[\'self\', \'x\', \'y\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "set_params" + argspec: "args=[\'self\'], varargs=None, keywords=params, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.wrappers.scikit_learn.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.wrappers.scikit_learn.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..fbd4d13387a931c3c947d8d0babcbfa978070de9 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.wrappers.scikit_learn.pbtxt @@ -0,0 +1,11 @@ +path: "tensorflow.keras.wrappers.scikit_learn" +tf_module { + member { + name: "KerasClassifier" + mtype: "" + } + member { + name: "KerasRegressor" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-average-pooling1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c82e67526b21696a7d56517dc2cb6998882dc7a5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-average-pooling1-d.pbtxt @@ -0,0 +1,186 @@ +path: "tensorflow.layers.AveragePooling1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'valid\', \'channels_last\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-average-pooling2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1d031cb5f8461145127b0f13d77e6b8774f5a0b3 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-average-pooling2-d.pbtxt @@ -0,0 +1,186 @@ +path: "tensorflow.layers.AveragePooling2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'valid\', \'channels_last\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-average-pooling3-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a8dda6655df1d06ca77b74f0a992c8fd7e7a357d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-average-pooling3-d.pbtxt @@ -0,0 +1,186 @@ +path: "tensorflow.layers.AveragePooling3D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'valid\', \'channels_last\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-batch-normalization.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..97f65ed89436bd0b4027bb0cbeb80b6f1419269c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-batch-normalization.pbtxt @@ -0,0 +1,185 @@ +path: "tensorflow.layers.BatchNormalization" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'beta_constraint\', \'gamma_constraint\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'fused\', \'trainable\', \'virtual_batch_size\', \'adjustment\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'\', \'\', \'\', \'\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-conv1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ccd9578f0d62bd70ea252ddeac587d59c926b018 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-conv1-d.pbtxt @@ -0,0 +1,186 @@ +path: "tensorflow.layers.Conv1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'trainable\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'None\', \'True\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-conv2-d-transpose.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9cbb58d721bb49bde562a57728a9ee46968e611e --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-conv2-d-transpose.pbtxt @@ -0,0 +1,187 @@ +path: "tensorflow.layers.Conv2DTranspose" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'trainable\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'channels_last\', \'None\', \'True\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-conv2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c75ea3911e17bc879d140068ef54521effd2824e --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-conv2-d.pbtxt @@ -0,0 +1,186 @@ +path: "tensorflow.layers.Conv2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'trainable\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'channels_last\', \'(1, 1)\', \'None\', \'True\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-conv3-d-transpose.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5dc834e5141e58d255357e02d7446a06e6e2aa45 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-conv3-d-transpose.pbtxt @@ -0,0 +1,187 @@ +path: "tensorflow.layers.Conv3DTranspose" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'trainable\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'channels_last\', \'None\', \'True\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-conv3-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..96ab209874ac14d6acf2e8115e7f04fc35c4b2bd --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-conv3-d.pbtxt @@ -0,0 +1,186 @@ +path: "tensorflow.layers.Conv3D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'trainable\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'channels_last\', \'(1, 1, 1)\', \'None\', \'True\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-dense.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-dense.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7e9656b3525c1d53940b869607616ff414a466cf --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-dense.pbtxt @@ -0,0 +1,185 @@ +path: "tensorflow.layers.Dense" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'trainable\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'True\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-dropout.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e9a2269a6e8de1f9a12f1b54d2e6dced3d4f8902 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-dropout.pbtxt @@ -0,0 +1,185 @@ +path: "tensorflow.layers.Dropout" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'rate\', \'noise_shape\', \'seed\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.5\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-flatten.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7d2eaaab2a8cb9159214a16ba65473d0b6870ac4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-flatten.pbtxt @@ -0,0 +1,185 @@ +path: "tensorflow.layers.Flatten" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-input-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-input-spec.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..fd02c919aeb5a536bd052324618983af699e7c47 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-input-spec.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.layers.InputSpec" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-layer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-layer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8bc3eb26e9ca0bf0f129db336b7ca23466fd036f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-layer.pbtxt @@ -0,0 +1,183 @@ +path: "tensorflow.layers.Layer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'trainable\', \'name\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'True\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-max-pooling1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6a0dcce56ac0184ffe995662fd62b89e16257a29 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-max-pooling1-d.pbtxt @@ -0,0 +1,186 @@ +path: "tensorflow.layers.MaxPooling1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'valid\', \'channels_last\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-max-pooling2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b6c84edf2a2f86240369b4053cd7351d0b59442d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-max-pooling2-d.pbtxt @@ -0,0 +1,186 @@ +path: "tensorflow.layers.MaxPooling2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'valid\', \'channels_last\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-max-pooling3-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..062a02fa590537b9efbf540a874eeaa6d36697f3 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-max-pooling3-d.pbtxt @@ -0,0 +1,186 @@ +path: "tensorflow.layers.MaxPooling3D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'valid\', \'channels_last\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-separable-conv1-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..eaad0fb23ef7501c8c5b7acee6a9677665b7057f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-separable-conv1-d.pbtxt @@ -0,0 +1,187 @@ +path: "tensorflow.layers.SeparableConv1D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\', \'trainable\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'1\', \'None\', \'True\', \'None\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.-separable-conv2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ece28a8ce962d8fafb3f7a397a814b903e915d48 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.-separable-conv2-d.pbtxt @@ -0,0 +1,187 @@ +path: "tensorflow.layers.SeparableConv2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\', \'trainable\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'channels_last\', \'(1, 1)\', \'1\', \'None\', \'True\', \'None\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..df74c32e1f10cc7540ef105adef6be681e93d089 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.pbtxt @@ -0,0 +1,147 @@ +path: "tensorflow.layers" +tf_module { + member { + name: "AveragePooling1D" + mtype: "" + } + member { + name: "AveragePooling2D" + mtype: "" + } + member { + name: "AveragePooling3D" + mtype: "" + } + member { + name: "BatchNormalization" + mtype: "" + } + member { + name: "Conv1D" + mtype: "" + } + member { + name: "Conv2D" + mtype: "" + } + member { + name: "Conv2DTranspose" + mtype: "" + } + member { + name: "Conv3D" + mtype: "" + } + member { + name: "Conv3DTranspose" + mtype: "" + } + member { + name: "Dense" + mtype: "" + } + member { + name: "Dropout" + mtype: "" + } + member { + name: "Flatten" + mtype: "" + } + member { + name: "InputSpec" + mtype: "" + } + member { + name: "Layer" + mtype: "" + } + member { + name: "MaxPooling1D" + mtype: "" + } + member { + name: "MaxPooling2D" + mtype: "" + } + member { + name: "MaxPooling3D" + mtype: "" + } + member { + name: "SeparableConv1D" + mtype: "" + } + member { + name: "SeparableConv2D" + mtype: "" + } + member_method { + name: "average_pooling1d" + argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], " + } + member_method { + name: "average_pooling2d" + argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], " + } + member_method { + name: "average_pooling3d" + argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], " + } + member_method { + name: "batch_normalization" + argspec: "args=[\'inputs\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'beta_constraint\', \'gamma_constraint\', \'training\', \'trainable\', \'name\', \'reuse\', \'renorm\', \'renorm_clipping\', \'renorm_momentum\', \'fused\', \'virtual_batch_size\', \'adjustment\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'\', \'\', \'\', \'\', \'None\', \'None\', \'None\', \'None\', \'False\', \'True\', \'None\', \'None\', \'False\', \'None\', \'0.99\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "conv1d" + argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'None\', \'True\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "conv2d" + argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'channels_last\', \'(1, 1)\', \'None\', \'True\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "conv2d_transpose" + argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'channels_last\', \'None\', \'True\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "conv3d" + argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'(1, 1, 1)\', \'valid\', \'channels_last\', \'(1, 1, 1)\', \'None\', \'True\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "conv3d_transpose" + argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'(1, 1, 1)\', \'valid\', \'channels_last\', \'None\', \'True\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "dense" + argspec: "args=[\'inputs\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "dropout" + argspec: "args=[\'inputs\', \'rate\', \'noise_shape\', \'seed\', \'training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'None\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "flatten" + argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "max_pooling1d" + argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], " + } + member_method { + name: "max_pooling2d" + argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], " + } + member_method { + name: "max_pooling3d" + argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], " + } + member_method { + name: "separable_conv1d" + argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'1\', \'None\', \'True\', \'None\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "separable_conv2d" + argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'channels_last\', \'(1, 1)\', \'1\', \'None\', \'True\', \'None\', \'None\', \'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-block-diag.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-block-diag.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b6dee6317604363275a128fe8d83aaa9473a257a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-block-diag.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.linalg.LinearOperatorBlockDiag.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-block-diag.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-block-diag.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..973705dae2fabbef0eafb38ad12e96c747aeee27 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-block-diag.pbtxt @@ -0,0 +1,134 @@ +path: "tensorflow.linalg.LinearOperatorBlockDiag" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "batch_shape" + mtype: "" + } + member { + name: "domain_dimension" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph_parents" + mtype: "" + } + member { + name: "is_non_singular" + mtype: "" + } + member { + name: "is_positive_definite" + mtype: "" + } + member { + name: "is_self_adjoint" + mtype: "" + } + member { + name: "is_square" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "operators" + mtype: "" + } + member { + name: "range_dimension" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "tensor_rank" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'operators\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "add_to_tensor" + argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], " + } + member_method { + name: "assert_non_singular" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], " + } + member_method { + name: "assert_positive_definite" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], " + } + member_method { + name: "assert_self_adjoint" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], " + } + member_method { + name: "diag_part" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], " + } + member_method { + name: "domain_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], " + } + member_method { + name: "log_abs_determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], " + } + member_method { + name: "matmul" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], " + } + member_method { + name: "matvec" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], " + } + member_method { + name: "range_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], " + } + member_method { + name: "shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], " + } + member_method { + name: "solve" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], " + } + member_method { + name: "solvevec" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], " + } + member_method { + name: "tensor_rank_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], " + } + member_method { + name: "to_dense" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], " + } + member_method { + name: "trace" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3b33f3da97ec2ecb3f94e8bc309be2519fc79c62 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.linalg.LinearOperatorCirculant.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..de917706d55214cc59f3205f0778d600a356a5b1 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant.pbtxt @@ -0,0 +1,155 @@ +path: "tensorflow.linalg.LinearOperatorCirculant" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "batch_shape" + mtype: "" + } + member { + name: "block_depth" + mtype: "" + } + member { + name: "block_shape" + mtype: "" + } + member { + name: "domain_dimension" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph_parents" + mtype: "" + } + member { + name: "is_non_singular" + mtype: "" + } + member { + name: "is_positive_definite" + mtype: "" + } + member { + name: "is_self_adjoint" + mtype: "" + } + member { + name: "is_square" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "range_dimension" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "spectrum" + mtype: "" + } + member { + name: "tensor_rank" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'spectrum\', \'input_output_dtype\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\', \'None\', \'None\', \'True\', \'LinearOperatorCirculant\'], " + } + member_method { + name: "add_to_tensor" + argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], " + } + member_method { + name: "assert_hermitian_spectrum" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_hermitian_spectrum\'], " + } + member_method { + name: "assert_non_singular" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], " + } + member_method { + name: "assert_positive_definite" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], " + } + member_method { + name: "assert_self_adjoint" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "block_shape_tensor" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "convolution_kernel" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'convolution_kernel\'], " + } + member_method { + name: "determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], " + } + member_method { + name: "diag_part" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], " + } + member_method { + name: "domain_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], " + } + member_method { + name: "log_abs_determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], " + } + member_method { + name: "matmul" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], " + } + member_method { + name: "matvec" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], " + } + member_method { + name: "range_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], " + } + member_method { + name: "shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], " + } + member_method { + name: "solve" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], " + } + member_method { + name: "solvevec" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], " + } + member_method { + name: "tensor_rank_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], " + } + member_method { + name: "to_dense" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], " + } + member_method { + name: "trace" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant2-d.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant2-d.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..591bc9631a1d8ecbbd6e133b99c67e432399d73f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant2-d.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.linalg.LinearOperatorCirculant2D.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c4e6a21c3ac9324f5dd445dc65415c2abb4c6e9f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt @@ -0,0 +1,155 @@ +path: "tensorflow.linalg.LinearOperatorCirculant2D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "batch_shape" + mtype: "" + } + member { + name: "block_depth" + mtype: "" + } + member { + name: "block_shape" + mtype: "" + } + member { + name: "domain_dimension" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph_parents" + mtype: "" + } + member { + name: "is_non_singular" + mtype: "" + } + member { + name: "is_positive_definite" + mtype: "" + } + member { + name: "is_self_adjoint" + mtype: "" + } + member { + name: "is_square" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "range_dimension" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "spectrum" + mtype: "" + } + member { + name: "tensor_rank" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'spectrum\', \'input_output_dtype\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\', \'None\', \'None\', \'True\', \'LinearOperatorCirculant2D\'], " + } + member_method { + name: "add_to_tensor" + argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], " + } + member_method { + name: "assert_hermitian_spectrum" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_hermitian_spectrum\'], " + } + member_method { + name: "assert_non_singular" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], " + } + member_method { + name: "assert_positive_definite" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], " + } + member_method { + name: "assert_self_adjoint" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "block_shape_tensor" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "convolution_kernel" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'convolution_kernel\'], " + } + member_method { + name: "determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], " + } + member_method { + name: "diag_part" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], " + } + member_method { + name: "domain_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], " + } + member_method { + name: "log_abs_determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], " + } + member_method { + name: "matmul" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], " + } + member_method { + name: "matvec" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], " + } + member_method { + name: "range_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], " + } + member_method { + name: "shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], " + } + member_method { + name: "solve" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], " + } + member_method { + name: "solvevec" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], " + } + member_method { + name: "tensor_rank_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], " + } + member_method { + name: "to_dense" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], " + } + member_method { + name: "trace" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant3-d.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant3-d.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d643139a53fc501fe2997a2b9f2d11c57b96f2e4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant3-d.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.linalg.LinearOperatorCirculant3D.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2e085a8e289e21173789041efb9254e992bd723b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt @@ -0,0 +1,155 @@ +path: "tensorflow.linalg.LinearOperatorCirculant3D" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "batch_shape" + mtype: "" + } + member { + name: "block_depth" + mtype: "" + } + member { + name: "block_shape" + mtype: "" + } + member { + name: "domain_dimension" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph_parents" + mtype: "" + } + member { + name: "is_non_singular" + mtype: "" + } + member { + name: "is_positive_definite" + mtype: "" + } + member { + name: "is_self_adjoint" + mtype: "" + } + member { + name: "is_square" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "range_dimension" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "spectrum" + mtype: "" + } + member { + name: "tensor_rank" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'spectrum\', \'input_output_dtype\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\', \'None\', \'None\', \'True\', \'LinearOperatorCirculant3D\'], " + } + member_method { + name: "add_to_tensor" + argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], " + } + member_method { + name: "assert_hermitian_spectrum" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_hermitian_spectrum\'], " + } + member_method { + name: "assert_non_singular" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], " + } + member_method { + name: "assert_positive_definite" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], " + } + member_method { + name: "assert_self_adjoint" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "block_shape_tensor" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "convolution_kernel" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'convolution_kernel\'], " + } + member_method { + name: "determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], " + } + member_method { + name: "diag_part" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], " + } + member_method { + name: "domain_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], " + } + member_method { + name: "log_abs_determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], " + } + member_method { + name: "matmul" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], " + } + member_method { + name: "matvec" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], " + } + member_method { + name: "range_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], " + } + member_method { + name: "shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], " + } + member_method { + name: "solve" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], " + } + member_method { + name: "solvevec" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], " + } + member_method { + name: "tensor_rank_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], " + } + member_method { + name: "to_dense" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], " + } + member_method { + name: "trace" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-composition.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-composition.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1adbcb41adfac33acfdb415662ced7992e21385e --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-composition.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.linalg.LinearOperatorComposition.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-composition.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-composition.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..42d22bce42d8850a784afae3f67771ef1cfe5403 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-composition.pbtxt @@ -0,0 +1,134 @@ +path: "tensorflow.linalg.LinearOperatorComposition" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "batch_shape" + mtype: "" + } + member { + name: "domain_dimension" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph_parents" + mtype: "" + } + member { + name: "is_non_singular" + mtype: "" + } + member { + name: "is_positive_definite" + mtype: "" + } + member { + name: "is_self_adjoint" + mtype: "" + } + member { + name: "is_square" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "operators" + mtype: "" + } + member { + name: "range_dimension" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "tensor_rank" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'operators\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_to_tensor" + argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], " + } + member_method { + name: "assert_non_singular" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], " + } + member_method { + name: "assert_positive_definite" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], " + } + member_method { + name: "assert_self_adjoint" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], " + } + member_method { + name: "diag_part" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], " + } + member_method { + name: "domain_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], " + } + member_method { + name: "log_abs_determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], " + } + member_method { + name: "matmul" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], " + } + member_method { + name: "matvec" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], " + } + member_method { + name: "range_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], " + } + member_method { + name: "shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], " + } + member_method { + name: "solve" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], " + } + member_method { + name: "solvevec" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], " + } + member_method { + name: "tensor_rank_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], " + } + member_method { + name: "to_dense" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], " + } + member_method { + name: "trace" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-diag.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-diag.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..023d90ccdba8a8739a11f4691d33b7087bedcc0b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-diag.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.linalg.LinearOperatorDiag.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-diag.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-diag.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d6749fdcec69425e83a044409ec695d2661f782e --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-diag.pbtxt @@ -0,0 +1,134 @@ +path: "tensorflow.linalg.LinearOperatorDiag" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "batch_shape" + mtype: "" + } + member { + name: "diag" + mtype: "" + } + member { + name: "domain_dimension" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph_parents" + mtype: "" + } + member { + name: "is_non_singular" + mtype: "" + } + member { + name: "is_positive_definite" + mtype: "" + } + member { + name: "is_self_adjoint" + mtype: "" + } + member { + name: "is_square" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "range_dimension" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "tensor_rank" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'diag\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'LinearOperatorDiag\'], " + } + member_method { + name: "add_to_tensor" + argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], " + } + member_method { + name: "assert_non_singular" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], " + } + member_method { + name: "assert_positive_definite" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], " + } + member_method { + name: "assert_self_adjoint" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], " + } + member_method { + name: "diag_part" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], " + } + member_method { + name: "domain_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], " + } + member_method { + name: "log_abs_determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], " + } + member_method { + name: "matmul" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], " + } + member_method { + name: "matvec" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], " + } + member_method { + name: "range_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], " + } + member_method { + name: "shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], " + } + member_method { + name: "solve" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], " + } + member_method { + name: "solvevec" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], " + } + member_method { + name: "tensor_rank_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], " + } + member_method { + name: "to_dense" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], " + } + member_method { + name: "trace" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-full-matrix.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-full-matrix.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..381072e76c4d069ebf51fec44079b30f17cafc06 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-full-matrix.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.linalg.LinearOperatorFullMatrix.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-full-matrix.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-full-matrix.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d9f363d1336210623536e8293a6290d9ebfc2fe1 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-full-matrix.pbtxt @@ -0,0 +1,130 @@ +path: "tensorflow.linalg.LinearOperatorFullMatrix" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "batch_shape" + mtype: "" + } + member { + name: "domain_dimension" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph_parents" + mtype: "" + } + member { + name: "is_non_singular" + mtype: "" + } + member { + name: "is_positive_definite" + mtype: "" + } + member { + name: "is_self_adjoint" + mtype: "" + } + member { + name: "is_square" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "range_dimension" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "tensor_rank" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'matrix\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'LinearOperatorFullMatrix\'], " + } + member_method { + name: "add_to_tensor" + argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], " + } + member_method { + name: "assert_non_singular" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], " + } + member_method { + name: "assert_positive_definite" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], " + } + member_method { + name: "assert_self_adjoint" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], " + } + member_method { + name: "diag_part" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], " + } + member_method { + name: "domain_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], " + } + member_method { + name: "log_abs_determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], " + } + member_method { + name: "matmul" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], " + } + member_method { + name: "matvec" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], " + } + member_method { + name: "range_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], " + } + member_method { + name: "shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], " + } + member_method { + name: "solve" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], " + } + member_method { + name: "solvevec" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], " + } + member_method { + name: "tensor_rank_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], " + } + member_method { + name: "to_dense" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], " + } + member_method { + name: "trace" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-identity.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-identity.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5d115b35fb79cbc176a9e8a9bf1ec0f0edcc79e6 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-identity.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.linalg.LinearOperatorIdentity.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-identity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-identity.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..aac7ee31ed62c22b2e86d287d48c68c7e905fd00 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-identity.pbtxt @@ -0,0 +1,131 @@ +path: "tensorflow.linalg.LinearOperatorIdentity" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "batch_shape" + mtype: "" + } + member { + name: "domain_dimension" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph_parents" + mtype: "" + } + member { + name: "is_non_singular" + mtype: "" + } + member { + name: "is_positive_definite" + mtype: "" + } + member { + name: "is_self_adjoint" + mtype: "" + } + member { + name: "is_square" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "range_dimension" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "tensor_rank" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_rows\', \'batch_shape\', \'dtype\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'assert_proper_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'True\', \'True\', \'True\', \'False\', \'LinearOperatorIdentity\'], " + } + member_method { + name: "add_to_tensor" + argspec: "args=[\'self\', \'mat\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], " + } + member_method { + name: "assert_non_singular" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], " + } + member_method { + name: "assert_positive_definite" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], " + } + member_method { + name: "assert_self_adjoint" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], " + } + member_method { + name: "diag_part" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], " + } + member_method { + name: "domain_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], " + } + member_method { + name: "log_abs_determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], " + } + member_method { + name: "matmul" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], " + } + member_method { + name: "matvec" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], " + } + member_method { + name: "range_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], " + } + member_method { + name: "shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], " + } + member_method { + name: "solve" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], " + } + member_method { + name: "solvevec" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], " + } + member_method { + name: "tensor_rank_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], " + } + member_method { + name: "to_dense" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], " + } + member_method { + name: "trace" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-kronecker.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-kronecker.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5c6784dd02104129a9ac38fe171d87c115efbbf0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-kronecker.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.linalg.LinearOperatorKronecker.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-kronecker.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-kronecker.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c11d39082939eda4520b3955b767022bd485b5be --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-kronecker.pbtxt @@ -0,0 +1,134 @@ +path: "tensorflow.linalg.LinearOperatorKronecker" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "batch_shape" + mtype: "" + } + member { + name: "domain_dimension" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph_parents" + mtype: "" + } + member { + name: "is_non_singular" + mtype: "" + } + member { + name: "is_positive_definite" + mtype: "" + } + member { + name: "is_self_adjoint" + mtype: "" + } + member { + name: "is_square" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "operators" + mtype: "" + } + member { + name: "range_dimension" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "tensor_rank" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'operators\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_to_tensor" + argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], " + } + member_method { + name: "assert_non_singular" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], " + } + member_method { + name: "assert_positive_definite" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], " + } + member_method { + name: "assert_self_adjoint" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], " + } + member_method { + name: "diag_part" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], " + } + member_method { + name: "domain_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], " + } + member_method { + name: "log_abs_determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], " + } + member_method { + name: "matmul" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], " + } + member_method { + name: "matvec" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], " + } + member_method { + name: "range_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], " + } + member_method { + name: "shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], " + } + member_method { + name: "solve" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], " + } + member_method { + name: "solvevec" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], " + } + member_method { + name: "tensor_rank_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], " + } + member_method { + name: "to_dense" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], " + } + member_method { + name: "trace" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-low-rank-update.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-low-rank-update.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1f0d33298a252a8b3da6eea9fd4bc096e8dd6745 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-low-rank-update.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.linalg.LinearOperatorLowRankUpdate.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-low-rank-update.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-low-rank-update.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3ee800269e617390c25248a2c847cbe259b18e79 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-low-rank-update.pbtxt @@ -0,0 +1,154 @@ +path: "tensorflow.linalg.LinearOperatorLowRankUpdate" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "base_operator" + mtype: "" + } + member { + name: "batch_shape" + mtype: "" + } + member { + name: "diag_operator" + mtype: "" + } + member { + name: "diag_update" + mtype: "" + } + member { + name: "domain_dimension" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph_parents" + mtype: "" + } + member { + name: "is_diag_update_positive" + mtype: "" + } + member { + name: "is_non_singular" + mtype: "" + } + member { + name: "is_positive_definite" + mtype: "" + } + member { + name: "is_self_adjoint" + mtype: "" + } + member { + name: "is_square" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "range_dimension" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "tensor_rank" + mtype: "" + } + member { + name: "u" + mtype: "" + } + member { + name: "v" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'base_operator\', \'u\', \'diag_update\', \'v\', \'is_diag_update_positive\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'LinearOperatorLowRankUpdate\'], " + } + member_method { + name: "add_to_tensor" + argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], " + } + member_method { + name: "assert_non_singular" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], " + } + member_method { + name: "assert_positive_definite" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], " + } + member_method { + name: "assert_self_adjoint" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], " + } + member_method { + name: "diag_part" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], " + } + member_method { + name: "domain_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], " + } + member_method { + name: "log_abs_determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], " + } + member_method { + name: "matmul" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], " + } + member_method { + name: "matvec" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], " + } + member_method { + name: "range_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], " + } + member_method { + name: "shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], " + } + member_method { + name: "solve" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], " + } + member_method { + name: "solvevec" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], " + } + member_method { + name: "tensor_rank_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], " + } + member_method { + name: "to_dense" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], " + } + member_method { + name: "trace" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-lower-triangular.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-lower-triangular.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2683430f4fc5d96d63c5b6fdb4035d6e5e8ba609 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-lower-triangular.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.linalg.LinearOperatorLowerTriangular.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-lower-triangular.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-lower-triangular.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..63a1bc2321e35645700778c5906d1b8659eb4a32 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-lower-triangular.pbtxt @@ -0,0 +1,130 @@ +path: "tensorflow.linalg.LinearOperatorLowerTriangular" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "batch_shape" + mtype: "" + } + member { + name: "domain_dimension" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph_parents" + mtype: "" + } + member { + name: "is_non_singular" + mtype: "" + } + member { + name: "is_positive_definite" + mtype: "" + } + member { + name: "is_self_adjoint" + mtype: "" + } + member { + name: "is_square" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "range_dimension" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "tensor_rank" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'tril\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'LinearOperatorLowerTriangular\'], " + } + member_method { + name: "add_to_tensor" + argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], " + } + member_method { + name: "assert_non_singular" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], " + } + member_method { + name: "assert_positive_definite" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], " + } + member_method { + name: "assert_self_adjoint" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], " + } + member_method { + name: "diag_part" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], " + } + member_method { + name: "domain_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], " + } + member_method { + name: "log_abs_determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], " + } + member_method { + name: "matmul" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], " + } + member_method { + name: "matvec" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], " + } + member_method { + name: "range_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], " + } + member_method { + name: "shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], " + } + member_method { + name: "solve" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], " + } + member_method { + name: "solvevec" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], " + } + member_method { + name: "tensor_rank_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], " + } + member_method { + name: "to_dense" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], " + } + member_method { + name: "trace" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-scaled-identity.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-scaled-identity.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..38bf7ad586a063046f260aca9b1c517a343c4c05 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-scaled-identity.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.linalg.LinearOperatorScaledIdentity.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-scaled-identity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-scaled-identity.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e2c5a505a7d2f9abbee5b3bb4f92ee8843198c51 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-scaled-identity.pbtxt @@ -0,0 +1,135 @@ +path: "tensorflow.linalg.LinearOperatorScaledIdentity" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "batch_shape" + mtype: "" + } + member { + name: "domain_dimension" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph_parents" + mtype: "" + } + member { + name: "is_non_singular" + mtype: "" + } + member { + name: "is_positive_definite" + mtype: "" + } + member { + name: "is_self_adjoint" + mtype: "" + } + member { + name: "is_square" + mtype: "" + } + member { + name: "multiplier" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "range_dimension" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "tensor_rank" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_rows\', \'multiplier\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'assert_proper_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'False\', \'LinearOperatorScaledIdentity\'], " + } + member_method { + name: "add_to_tensor" + argspec: "args=[\'self\', \'mat\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], " + } + member_method { + name: "assert_non_singular" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], " + } + member_method { + name: "assert_positive_definite" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], " + } + member_method { + name: "assert_self_adjoint" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], " + } + member_method { + name: "diag_part" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], " + } + member_method { + name: "domain_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], " + } + member_method { + name: "log_abs_determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], " + } + member_method { + name: "matmul" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], " + } + member_method { + name: "matvec" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], " + } + member_method { + name: "range_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], " + } + member_method { + name: "shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], " + } + member_method { + name: "solve" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], " + } + member_method { + name: "solvevec" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], " + } + member_method { + name: "tensor_rank_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], " + } + member_method { + name: "to_dense" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], " + } + member_method { + name: "trace" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-zeros.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-zeros.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..49ff85728ffab559ec706691356ce071aab89083 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-zeros.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.linalg.LinearOperatorZeros.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-zeros.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-zeros.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a1b0e06b4753488bc9fcbe9aeb0d260092745f9c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-zeros.pbtxt @@ -0,0 +1,130 @@ +path: "tensorflow.linalg.LinearOperatorZeros" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "batch_shape" + mtype: "" + } + member { + name: "domain_dimension" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph_parents" + mtype: "" + } + member { + name: "is_non_singular" + mtype: "" + } + member { + name: "is_positive_definite" + mtype: "" + } + member { + name: "is_self_adjoint" + mtype: "" + } + member { + name: "is_square" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "range_dimension" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "tensor_rank" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_rows\', \'num_columns\', \'batch_shape\', \'dtype\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'assert_proper_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'True\', \'False\', \'True\', \'False\', \'LinearOperatorZeros\'], " + } + member_method { + name: "add_to_tensor" + argspec: "args=[\'self\', \'mat\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], " + } + member_method { + name: "assert_non_singular" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], " + } + member_method { + name: "assert_positive_definite" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], " + } + member_method { + name: "assert_self_adjoint" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], " + } + member_method { + name: "diag_part" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], " + } + member_method { + name: "domain_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], " + } + member_method { + name: "log_abs_determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], " + } + member_method { + name: "matmul" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], " + } + member_method { + name: "matvec" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], " + } + member_method { + name: "range_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], " + } + member_method { + name: "shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], " + } + member_method { + name: "solve" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], " + } + member_method { + name: "solvevec" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], " + } + member_method { + name: "tensor_rank_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], " + } + member_method { + name: "to_dense" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], " + } + member_method { + name: "trace" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator.__metaclass__.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..38da809b360e5ea69b4324a859ed69da679bc436 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.linalg.LinearOperator.__metaclass__" +tf_class { + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6d849dc040f61b498b100820bf7be3d4bc264bb4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator.pbtxt @@ -0,0 +1,129 @@ +path: "tensorflow.linalg.LinearOperator" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "batch_shape" + mtype: "" + } + member { + name: "domain_dimension" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph_parents" + mtype: "" + } + member { + name: "is_non_singular" + mtype: "" + } + member { + name: "is_positive_definite" + mtype: "" + } + member { + name: "is_self_adjoint" + mtype: "" + } + member { + name: "is_square" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "range_dimension" + mtype: "" + } + member { + name: "shape" + mtype: "" + } + member { + name: "tensor_rank" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\', \'graph_parents\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_to_tensor" + argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], " + } + member_method { + name: "assert_non_singular" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], " + } + member_method { + name: "assert_positive_definite" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], " + } + member_method { + name: "assert_self_adjoint" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], " + } + member_method { + name: "batch_shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], " + } + member_method { + name: "determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], " + } + member_method { + name: "diag_part" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], " + } + member_method { + name: "domain_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], " + } + member_method { + name: "log_abs_determinant" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], " + } + member_method { + name: "matmul" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], " + } + member_method { + name: "matvec" + argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], " + } + member_method { + name: "range_dimension_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], " + } + member_method { + name: "shape_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], " + } + member_method { + name: "solve" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], " + } + member_method { + name: "solvevec" + argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], " + } + member_method { + name: "tensor_rank_tensor" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], " + } + member_method { + name: "to_dense" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], " + } + member_method { + name: "trace" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d979116887a739d2d372687fac0e5ea3b39a4b69 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt @@ -0,0 +1,175 @@ +path: "tensorflow.linalg" +tf_module { + member { + name: "LinearOperator" + mtype: "" + } + member { + name: "LinearOperatorBlockDiag" + mtype: "" + } + member { + name: "LinearOperatorCirculant" + mtype: "" + } + member { + name: "LinearOperatorCirculant2D" + mtype: "" + } + member { + name: "LinearOperatorCirculant3D" + mtype: "" + } + member { + name: "LinearOperatorComposition" + mtype: "" + } + member { + name: "LinearOperatorDiag" + mtype: "" + } + member { + name: "LinearOperatorFullMatrix" + mtype: "" + } + member { + name: "LinearOperatorIdentity" + mtype: "" + } + member { + name: "LinearOperatorKronecker" + mtype: "" + } + member { + name: "LinearOperatorLowRankUpdate" + mtype: "" + } + member { + name: "LinearOperatorLowerTriangular" + mtype: "" + } + member { + name: "LinearOperatorScaledIdentity" + mtype: "" + } + member { + name: "LinearOperatorZeros" + mtype: "" + } + member_method { + name: "adjoint" + argspec: "args=[\'matrix\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "band_part" + argspec: "args=[\'input\', \'num_lower\', \'num_upper\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cholesky" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cholesky_solve" + argspec: "args=[\'chol\', \'rhs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cross" + argspec: "args=[\'a\', \'b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "det" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "diag" + argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "diag_part" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "eigh" + argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "eigvalsh" + argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "einsum" + argspec: "args=[\'equation\'], varargs=inputs, keywords=kwargs, defaults=None" + } + member_method { + name: "expm" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "eye" + argspec: "args=[\'num_rows\', \'num_columns\', \'batch_shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"\", \'None\'], " + } + member_method { + name: "inv" + argspec: "args=[\'input\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "logdet" + argspec: "args=[\'matrix\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logm" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "lstsq" + argspec: "args=[\'matrix\', \'rhs\', \'l2_regularizer\', \'fast\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'True\', \'None\'], " + } + member_method { + name: "norm" + argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "qr" + argspec: "args=[\'input\', \'full_matrices\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "set_diag" + argspec: "args=[\'input\', \'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "slogdet" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "solve" + argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "svd" + argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], " + } + member_method { + name: "tensor_diag" + argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tensor_diag_part" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tensordot" + argspec: "args=[\'a\', \'b\', \'axes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "trace" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "transpose" + argspec: "args=[\'a\', \'name\', \'conjugate\'], varargs=None, keywords=None, defaults=[\'matrix_transpose\', \'False\'], " + } + member_method { + name: "triangular_solve" + argspec: "args=[\'matrix\', \'rhs\', \'lower\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.logging.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.logging.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..85bb15455da624962744a0cc856e79e0a6d57d7c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.logging.pbtxt @@ -0,0 +1,83 @@ +path: "tensorflow.logging" +tf_module { + member { + name: "DEBUG" + mtype: "" + } + member { + name: "ERROR" + mtype: "" + } + member { + name: "FATAL" + mtype: "" + } + member { + name: "INFO" + mtype: "" + } + member { + name: "WARN" + mtype: "" + } + member_method { + name: "TaskLevelStatusMessage" + argspec: "args=[\'msg\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "debug" + argspec: "args=[\'msg\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "error" + argspec: "args=[\'msg\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "fatal" + argspec: "args=[\'msg\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "flush" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_verbosity" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "info" + argspec: "args=[\'msg\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "log" + argspec: "args=[\'level\', \'msg\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "log_every_n" + argspec: "args=[\'level\', \'msg\', \'n\'], varargs=args, keywords=None, defaults=None" + } + member_method { + name: "log_first_n" + argspec: "args=[\'level\', \'msg\', \'n\'], varargs=args, keywords=None, defaults=None" + } + member_method { + name: "log_if" + argspec: "args=[\'level\', \'msg\', \'condition\'], varargs=args, keywords=None, defaults=None" + } + member_method { + name: "set_verbosity" + argspec: "args=[\'v\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "vlog" + argspec: "args=[\'level\', \'msg\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "warn" + argspec: "args=[\'msg\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "warning" + argspec: "args=[\'msg\'], varargs=args, keywords=kwargs, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.losses.-reduction.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.losses.-reduction.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..258ad5047eb6e82eeb9c0941b0acf0573e5ca61d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.losses.-reduction.pbtxt @@ -0,0 +1,40 @@ +path: "tensorflow.losses.Reduction" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "MEAN" + mtype: "" + } + member { + name: "NONE" + mtype: "" + } + member { + name: "SUM" + mtype: "" + } + member { + name: "SUM_BY_NONZERO_WEIGHTS" + mtype: "" + } + member { + name: "SUM_OVER_BATCH_SIZE" + mtype: "" + } + member { + name: "SUM_OVER_NONZERO_WEIGHTS" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "all" + argspec: "args=[\'cls\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "validate" + argspec: "args=[\'cls\', \'key\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.losses.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.losses.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c1d190ae116e94ec8f837237e54b6fcff7358254 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.losses.pbtxt @@ -0,0 +1,71 @@ +path: "tensorflow.losses" +tf_module { + member { + name: "Reduction" + mtype: "" + } + member_method { + name: "absolute_difference" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'loss\', \'loss_collection\'], varargs=None, keywords=None, defaults=[\'losses\'], " + } + member_method { + name: "compute_weighted_loss" + argspec: "args=[\'losses\', \'weights\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } + member_method { + name: "cosine_distance" + argspec: "args=[\'labels\', \'predictions\', \'axis\', \'weights\', \'scope\', \'loss_collection\', \'reduction\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\', \'None\'], " + } + member_method { + name: "get_losses" + argspec: "args=[\'scope\', \'loss_collection\'], varargs=None, keywords=None, defaults=[\'None\', \'losses\'], " + } + member_method { + name: "get_regularization_loss" + argspec: "args=[\'scope\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'total_regularization_loss\'], " + } + member_method { + name: "get_regularization_losses" + argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_total_loss" + argspec: "args=[\'add_regularization_losses\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'total_loss\'], " + } + member_method { + name: "hinge_loss" + argspec: "args=[\'labels\', \'logits\', \'weights\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } + member_method { + name: "huber_loss" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'delta\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } + member_method { + name: "log_loss" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'epsilon\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'1e-07\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } + member_method { + name: "mean_pairwise_squared_error" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'scope\', \'loss_collection\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \'losses\'], " + } + member_method { + name: "mean_squared_error" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } + member_method { + name: "sigmoid_cross_entropy" + argspec: "args=[\'multi_class_labels\', \'logits\', \'weights\', \'label_smoothing\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } + member_method { + name: "softmax_cross_entropy" + argspec: "args=[\'onehot_labels\', \'logits\', \'weights\', \'label_smoothing\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } + member_method { + name: "sparse_softmax_cross_entropy" + argspec: "args=[\'labels\', \'logits\', \'weights\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.manip.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.manip.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9add462396ea526ae94678e969c9acf5bce86df1 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.manip.pbtxt @@ -0,0 +1,35 @@ +path: "tensorflow.manip" +tf_module { + member_method { + name: "batch_to_space_nd" + argspec: "args=[\'input\', \'block_shape\', \'crops\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "gather_nd" + argspec: "args=[\'params\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reshape" + argspec: "args=[\'tensor\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reverse" + argspec: "args=[\'tensor\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "roll" + argspec: "args=[\'input\', \'shift\', \'axis\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "scatter_nd" + argspec: "args=[\'indices\', \'updates\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "space_to_batch_nd" + argspec: "args=[\'input\', \'block_shape\', \'paddings\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tile" + argspec: "args=[\'input\', \'multiples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a308c76ebc08df06c0c360579451ea70e60695d4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt @@ -0,0 +1,239 @@ +path: "tensorflow.math" +tf_module { + member_method { + name: "acos" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "acosh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "asin" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "asinh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "atan" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "atan2" + argspec: "args=[\'y\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "atanh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bessel_i0" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bessel_i0e" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bessel_i1" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bessel_i1e" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "betainc" + argspec: "args=[\'a\', \'b\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ceil" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cos" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cosh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "digamma" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "erfc" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "exp" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "expm1" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "floor" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "greater" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "greater_equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "igamma" + argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "igammac" + argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "invert_permutation" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "less" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "less_equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "lgamma" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "log" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "log1p" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_and" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_not" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "logical_or" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "maximum" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "minimum" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "not_equal" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "polygamma" + argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "polyval" + argspec: "args=[\'coeffs\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reciprocal" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "rint" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "rsqrt" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_max" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_mean" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_min" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_prod" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "segment_sum" + argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sin" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sinh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "softplus" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "softsign" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "squared_difference" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "tan" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_max" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_min" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_prod" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "unsorted_segment_sum" + argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "zeta" + argspec: "args=[\'x\', \'q\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e9b996c9f53e9062dcdd39ef22f99eef5175eb35 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.pbtxt @@ -0,0 +1,135 @@ +path: "tensorflow.metrics" +tf_module { + member_method { + name: "accuracy" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "auc" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'num_thresholds\', \'metrics_collections\', \'updates_collections\', \'curve\', \'name\', \'summation_method\'], varargs=None, keywords=None, defaults=[\'None\', \'200\', \'None\', \'None\', \'ROC\', \'None\', \'trapezoidal\'], " + } + member_method { + name: "average_precision_at_k" + argspec: "args=[\'labels\', \'predictions\', \'k\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "false_negatives" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "false_negatives_at_thresholds" + argspec: "args=[\'labels\', \'predictions\', \'thresholds\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "false_positives" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "false_positives_at_thresholds" + argspec: "args=[\'labels\', \'predictions\', \'thresholds\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "mean" + argspec: "args=[\'values\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "mean_absolute_error" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "mean_cosine_distance" + argspec: "args=[\'labels\', \'predictions\', \'dim\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "mean_iou" + argspec: "args=[\'labels\', \'predictions\', \'num_classes\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "mean_per_class_accuracy" + argspec: "args=[\'labels\', \'predictions\', \'num_classes\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "mean_relative_error" + argspec: "args=[\'labels\', \'predictions\', \'normalizer\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "mean_squared_error" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "mean_tensor" + argspec: "args=[\'values\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "percentage_below" + argspec: "args=[\'values\', \'threshold\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "precision" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "precision_at_k" + argspec: "args=[\'labels\', \'predictions\', \'k\', \'class_id\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "precision_at_thresholds" + argspec: "args=[\'labels\', \'predictions\', \'thresholds\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "precision_at_top_k" + argspec: "args=[\'labels\', \'predictions_idx\', \'k\', \'class_id\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "recall" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "recall_at_k" + argspec: "args=[\'labels\', \'predictions\', \'k\', \'class_id\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "recall_at_thresholds" + argspec: "args=[\'labels\', \'predictions\', \'thresholds\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "recall_at_top_k" + argspec: "args=[\'labels\', \'predictions_idx\', \'k\', \'class_id\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "root_mean_squared_error" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "sensitivity_at_specificity" + argspec: "args=[\'labels\', \'predictions\', \'specificity\', \'weights\', \'num_thresholds\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'200\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "sparse_average_precision_at_k" + argspec: "args=[\'labels\', \'predictions\', \'k\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "sparse_precision_at_k" + argspec: "args=[\'labels\', \'predictions\', \'k\', \'class_id\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "specificity_at_sensitivity" + argspec: "args=[\'labels\', \'predictions\', \'sensitivity\', \'weights\', \'num_thresholds\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'200\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "true_negatives" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "true_negatives_at_thresholds" + argspec: "args=[\'labels\', \'predictions\', \'thresholds\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "true_positives" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "true_positives_at_thresholds" + argspec: "args=[\'labels\', \'predictions\', \'thresholds\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.name_scope.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.name_scope.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..80418970132377a5d578e4f11fa4091a19202cf3 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.name_scope.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.name_scope" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'default_name\', \'values\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d9e5b0d0fca8bbcf82feb34304f2a1e4f43f48dd --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.pbtxt @@ -0,0 +1,359 @@ +path: "tensorflow.nn" +tf_module { + member { + name: "rnn_cell" + mtype: "" + } + member { + name: "swish" + mtype: "" + } + member_method { + name: "all_candidate_sampler" + argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "atrous_conv2d" + argspec: "args=[\'value\', \'filters\', \'rate\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "atrous_conv2d_transpose" + argspec: "args=[\'value\', \'filters\', \'output_shape\', \'rate\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "avg_pool" + argspec: "args=[\'value\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], " + } + member_method { + name: "avg_pool3d" + argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NDHWC\', \'None\'], " + } + member_method { + name: "batch_norm_with_global_normalization" + argspec: "args=[\'t\', \'m\', \'v\', \'beta\', \'gamma\', \'variance_epsilon\', \'scale_after_normalization\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "batch_normalization" + argspec: "args=[\'x\', \'mean\', \'variance\', \'offset\', \'scale\', \'variance_epsilon\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bias_add" + argspec: "args=[\'value\', \'bias\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "bidirectional_dynamic_rnn" + argspec: "args=[\'cell_fw\', \'cell_bw\', \'inputs\', \'sequence_length\', \'initial_state_fw\', \'initial_state_bw\', \'dtype\', \'parallel_iterations\', \'swap_memory\', \'time_major\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "compute_accidental_hits" + argspec: "args=[\'true_classes\', \'sampled_candidates\', \'num_true\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "conv1d" + argspec: "args=[\'value\', \'filters\', \'stride\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "conv2d" + argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\'], " + } + member_method { + name: "conv2d_backprop_filter" + argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\'], " + } + member_method { + name: "conv2d_backprop_input" + argspec: "args=[\'input_sizes\', \'filter\', \'out_backprop\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\'], " + } + member_method { + name: "conv2d_transpose" + argspec: "args=[\'value\', \'filter\', \'output_shape\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'SAME\', \'NHWC\', \'None\'], " + } + member_method { + name: "conv3d" + argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NDHWC\', \'[1, 1, 1, 1, 1]\', \'None\'], " + } + member_method { + name: "conv3d_backprop_filter_v2" + argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NDHWC\', \'[1, 1, 1, 1, 1]\', \'None\'], " + } + member_method { + name: "conv3d_transpose" + argspec: "args=[\'value\', \'filter\', \'output_shape\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'SAME\', \'NDHWC\', \'None\'], " + } + member_method { + name: "convolution" + argspec: "args=[\'input\', \'filter\', \'padding\', \'strides\', \'dilation_rate\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "crelu" + argspec: "args=[\'features\', \'name\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\'], " + } + member_method { + name: "ctc_beam_search_decoder" + argspec: "args=[\'inputs\', \'sequence_length\', \'beam_width\', \'top_paths\', \'merge_repeated\'], varargs=None, keywords=None, defaults=[\'100\', \'1\', \'True\'], " + } + member_method { + name: "ctc_greedy_decoder" + argspec: "args=[\'inputs\', \'sequence_length\', \'merge_repeated\'], varargs=None, keywords=None, defaults=[\'True\'], " + } + member_method { + name: "ctc_loss" + argspec: "args=[\'labels\', \'inputs\', \'sequence_length\', \'preprocess_collapse_repeated\', \'ctc_merge_repeated\', \'ignore_longer_outputs_than_inputs\', \'time_major\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'False\', \'True\'], " + } + member_method { + name: "depthwise_conv2d" + argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'rate\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "depthwise_conv2d_native" + argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'[1, 1, 1, 1]\', \'None\'], " + } + member_method { + name: "depthwise_conv2d_native_backprop_filter" + argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'[1, 1, 1, 1]\', \'None\'], " + } + member_method { + name: "depthwise_conv2d_native_backprop_input" + argspec: "args=[\'input_sizes\', \'filter\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'[1, 1, 1, 1]\', \'None\'], " + } + member_method { + name: "dilation2d" + argspec: "args=[\'input\', \'filter\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dropout" + argspec: "args=[\'x\', \'keep_prob\', \'noise_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "dynamic_rnn" + argspec: "args=[\'cell\', \'inputs\', \'sequence_length\', \'initial_state\', \'dtype\', \'parallel_iterations\', \'swap_memory\', \'time_major\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "elu" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "embedding_lookup" + argspec: "args=[\'params\', \'ids\', \'partition_strategy\', \'name\', \'validate_indices\', \'max_norm\'], varargs=None, keywords=None, defaults=[\'mod\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "embedding_lookup_sparse" + argspec: "args=[\'params\', \'sp_ids\', \'sp_weights\', \'partition_strategy\', \'name\', \'combiner\', \'max_norm\'], varargs=None, keywords=None, defaults=[\'mod\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "erosion2d" + argspec: "args=[\'value\', \'kernel\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "fixed_unigram_candidate_sampler" + argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'range_max\', \'vocab_file\', \'distortion\', \'num_reserved_ids\', \'num_shards\', \'shard\', \'unigrams\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'1.0\', \'0\', \'1\', \'0\', \'()\', \'None\', \'None\'], " + } + member_method { + name: "fractional_avg_pool" + argspec: "args=[\'value\', \'pooling_ratio\', \'pseudo_random\', \'overlapping\', \'deterministic\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'0\', \'0\', \'None\'], " + } + member_method { + name: "fractional_max_pool" + argspec: "args=[\'value\', \'pooling_ratio\', \'pseudo_random\', \'overlapping\', \'deterministic\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'0\', \'0\', \'None\'], " + } + member_method { + name: "fused_batch_norm" + argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.001\', \'NHWC\', \'True\', \'None\'], " + } + member_method { + name: "in_top_k" + argspec: "args=[\'predictions\', \'targets\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "l2_loss" + argspec: "args=[\'t\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "l2_normalize" + argspec: "args=[\'x\', \'axis\', \'epsilon\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'1e-12\', \'None\', \'None\'], " + } + member_method { + name: "leaky_relu" + argspec: "args=[\'features\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'0.2\', \'None\'], " + } + member_method { + name: "learned_unigram_candidate_sampler" + argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'range_max\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "local_response_normalization" + argspec: "args=[\'input\', \'depth_radius\', \'bias\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'5\', \'1\', \'1\', \'0.5\', \'None\'], " + } + member_method { + name: "log_poisson_loss" + argspec: "args=[\'targets\', \'log_input\', \'compute_full_loss\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "log_softmax" + argspec: "args=[\'logits\', \'axis\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "log_uniform_candidate_sampler" + argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'range_max\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "lrn" + argspec: "args=[\'input\', \'depth_radius\', \'bias\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'5\', \'1\', \'1\', \'0.5\', \'None\'], " + } + member_method { + name: "max_pool" + argspec: "args=[\'value\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], " + } + member_method { + name: "max_pool3d" + argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NDHWC\', \'None\'], " + } + member_method { + name: "max_pool_with_argmax" + argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'Targmax\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } + member_method { + name: "moments" + argspec: "args=[\'x\', \'axes\', \'shift\', \'name\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " + } + member_method { + name: "nce_loss" + argspec: "args=[\'weights\', \'biases\', \'labels\', \'inputs\', \'num_sampled\', \'num_classes\', \'num_true\', \'sampled_values\', \'remove_accidental_hits\', \'partition_strategy\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'False\', \'mod\', \'nce_loss\'], " + } + member_method { + name: "normalize_moments" + argspec: "args=[\'counts\', \'mean_ss\', \'variance_ss\', \'shift\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "pool" + argspec: "args=[\'input\', \'window_shape\', \'pooling_type\', \'padding\', \'dilation_rate\', \'strides\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "quantized_avg_pool" + argspec: "args=[\'input\', \'min_input\', \'max_input\', \'ksize\', \'strides\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "quantized_conv2d" + argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'strides\', \'padding\', \'out_type\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'[1, 1, 1, 1]\', \'None\'], " + } + member_method { + name: "quantized_max_pool" + argspec: "args=[\'input\', \'min_input\', \'max_input\', \'ksize\', \'strides\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "quantized_relu_x" + argspec: "args=[\'features\', \'max_value\', \'min_features\', \'max_features\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } + member_method { + name: "raw_rnn" + argspec: "args=[\'cell\', \'loop_fn\', \'parallel_iterations\', \'swap_memory\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "relu" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "relu6" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "relu_layer" + argspec: "args=[\'x\', \'weights\', \'biases\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "safe_embedding_lookup_sparse" + argspec: "args=[\'embedding_weights\', \'sparse_ids\', \'sparse_weights\', \'combiner\', \'default_id\', \'name\', \'partition_strategy\', \'max_norm\'], varargs=None, keywords=None, defaults=[\'None\', \'mean\', \'None\', \'None\', \'div\', \'None\'], " + } + member_method { + name: "sampled_softmax_loss" + argspec: "args=[\'weights\', \'biases\', \'labels\', \'inputs\', \'num_sampled\', \'num_classes\', \'num_true\', \'sampled_values\', \'remove_accidental_hits\', \'partition_strategy\', \'name\', \'seed\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'True\', \'mod\', \'sampled_softmax_loss\', \'None\'], " + } + member_method { + name: "selu" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "separable_conv2d" + argspec: "args=[\'input\', \'depthwise_filter\', \'pointwise_filter\', \'strides\', \'padding\', \'rate\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "sigmoid" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sigmoid_cross_entropy_with_logits" + argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "softmax" + argspec: "args=[\'logits\', \'axis\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "softmax_cross_entropy_with_logits" + argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'dim\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'-1\', \'None\'], " + } + member_method { + name: "softmax_cross_entropy_with_logits_v2" + argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'dim\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'-1\', \'None\'], " + } + member_method { + name: "softplus" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "softsign" + argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "sparse_softmax_cross_entropy_with_logits" + argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "static_bidirectional_rnn" + argspec: "args=[\'cell_fw\', \'cell_bw\', \'inputs\', \'initial_state_fw\', \'initial_state_bw\', \'dtype\', \'sequence_length\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "static_rnn" + argspec: "args=[\'cell\', \'inputs\', \'initial_state\', \'dtype\', \'sequence_length\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "static_state_saving_rnn" + argspec: "args=[\'cell\', \'inputs\', \'state_saver\', \'state_name\', \'sequence_length\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "sufficient_statistics" + argspec: "args=[\'x\', \'axes\', \'shift\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "tanh" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "top_k" + argspec: "args=[\'input\', \'k\', \'sorted\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'True\', \'None\'], " + } + member_method { + name: "uniform_candidate_sampler" + argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'range_max\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "weighted_cross_entropy_with_logits" + argspec: "args=[\'targets\', \'logits\', \'pos_weight\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "weighted_moments" + argspec: "args=[\'x\', \'axes\', \'frequency_weights\', \'name\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + } + member_method { + name: "with_space_to_batch" + argspec: "args=[\'input\', \'dilation_rate\', \'padding\', \'op\', \'filter_shape\', \'spatial_dims\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "xw_plus_b" + argspec: "args=[\'x\', \'weights\', \'biases\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "zero_fraction" + argspec: "args=[\'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt index a8d9e120cb4aa965c1d85df59de1fbabc196bf54..c74773000aa31b0c51677b49eed6e83cc1f073ae 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt @@ -117,7 +117,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt index c039890e1f4c1d57e7b795f1f09cff71921f6554..d251f548069b430de0fe9af83b6e9c641ea9237c 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt @@ -117,7 +117,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8a63b4918008674041c9c216a5e5547ed7152fce --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt @@ -0,0 +1,197 @@ +path: "tensorflow.nn.rnn_cell.DeviceWrapper" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "output_size" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cell\', \'device\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zero_state" + argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..db1aae275792dad94c4cf823d0d30f934e397601 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt @@ -0,0 +1,201 @@ +path: "tensorflow.nn.rnn_cell.DropoutWrapper" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "output_size" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member { + name: "wrapped_cell" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cell\', \'input_keep_prob\', \'output_keep_prob\', \'state_keep_prob\', \'variational_recurrent\', \'input_size\', \'dtype\', \'seed\', \'dropout_state_filter_visitor\'], varargs=None, keywords=None, defaults=[\'1.0\', \'1.0\', \'1.0\', \'False\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zero_state" + argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt index 4583dc32b2e98d4a9912378fe0e3d841882772fd..d76eab7eb874c981ac111cf6f96f28363f5e4375 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt @@ -117,7 +117,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt similarity index 95% rename from tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt index 5016b6ac3010e2e184674db4837173c57c44b97e..944db6ac937acb0d6a134aa2f17dfaa0d3d618ff 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt @@ -117,7 +117,7 @@ tf_class { } member_method { name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " } member_method { name: "apply" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1de8a55dccac10ee9af08eb1efc0cb6d22f7163b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-l-s-t-m-state-tuple.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.nn.rnn_cell.LSTMStateTuple" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "c" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "h" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..72b40cc9f7a720888a1399a60aa216013e0b9918 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt @@ -0,0 +1,197 @@ +path: "tensorflow.nn.rnn_cell.MultiRNNCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "output_size" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zero_state" + argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a5c2b4aefd6a1b96cbe63271ca27de06616f1deb --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt @@ -0,0 +1,196 @@ +path: "tensorflow.nn.rnn_cell.RNNCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "output_size" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'trainable\', \'name\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'True\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zero_state" + argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..61d5f04b22a4b4e3801643958b73a35403b79139 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt @@ -0,0 +1,197 @@ +path: "tensorflow.nn.rnn_cell.ResidualWrapper" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "output_size" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cell\', \'residual_fn\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "zero_state" + argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..64697e8a02b90bdace731a414570b7dc9da11015 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt @@ -0,0 +1,43 @@ +path: "tensorflow.nn.rnn_cell" +tf_module { + member { + name: "BasicLSTMCell" + mtype: "" + } + member { + name: "BasicRNNCell" + mtype: "" + } + member { + name: "DeviceWrapper" + mtype: "" + } + member { + name: "DropoutWrapper" + mtype: "" + } + member { + name: "GRUCell" + mtype: "" + } + member { + name: "LSTMCell" + mtype: "" + } + member { + name: "LSTMStateTuple" + mtype: "" + } + member { + name: "MultiRNNCell" + mtype: "" + } + member { + name: "RNNCell" + mtype: "" + } + member { + name: "ResidualWrapper" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.ones_initializer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.ones_initializer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..210b56242b27fe4a832cfe50a53626d716d8877e --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.ones_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.ones_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.orthogonal_initializer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.orthogonal_initializer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..13ec7454f41eac2b23e07ba62068bb48dddac90b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.orthogonal_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.orthogonal_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'gain\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt similarity index 98% rename from tensorflow/tools/api/golden/tensorflow.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 01b80581188e65d228aaa669254d9951546ecfa0..5eb42b4db3c95a3bd139596665384a116f467b6c 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -258,12 +258,20 @@ tf_module { } member { name: "Variable" - mtype: "" + mtype: "" + } + member { + name: "VariableAggregation" + mtype: "" } member { name: "VariableScope" mtype: "" } + member { + name: "VariableSynchronization" + mtype: "" + } member { name: "WholeFileReader" mtype: "" @@ -308,6 +316,10 @@ tf_module { name: "data" mtype: "" } + member { + name: "debugging" + mtype: "" + } member { name: "distributions" mtype: "" @@ -316,6 +328,10 @@ tf_module { name: "double" mtype: "" } + member { + name: "dtypes" + mtype: "" + } member { name: "errors" mtype: "" @@ -380,6 +396,10 @@ tf_module { name: "int8" mtype: "" } + member { + name: "io" + mtype: "" + } member { name: "keras" mtype: "" @@ -456,6 +476,10 @@ tf_module { name: "qint8" mtype: "" } + member { + name: "quantization" + mtype: "" + } member { name: "quint16" mtype: "" @@ -1134,7 +1158,7 @@ tf_module { } member_method { name: "get_local_variable" - argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'synchronization\', \'aggregation\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\', \'None\'], " } member_method { name: "get_seed" @@ -1150,7 +1174,7 @@ tf_module { } member_method { name: "get_variable" - argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " } member_method { name: "get_variable_scope" @@ -1294,7 +1318,7 @@ tf_module { } member_method { name: "lbeta" - argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'lbeta\'], " + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "less" @@ -1536,10 +1560,6 @@ tf_module { name: "pow" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "print" - argspec: "args=[\'input_\', \'data\', \'message\', \'first_n\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " - } member_method { name: "py_func" argspec: "args=[\'func\', \'inp\', \'Tout\', \'stateful\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " @@ -2174,7 +2194,7 @@ tf_module { } member_method { name: "while_loop" - argspec: "args=[\'cond\', \'body\', \'loop_vars\', \'shape_invariants\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\', \'maximum_iterations\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\', \'None\'], " + argspec: "args=[\'cond\', \'body\', \'loop_vars\', \'shape_invariants\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\', \'maximum_iterations\', \'return_same_structure\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\', \'None\', \'False\'], " } member_method { name: "write_file" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.profiler.-advice-proto.-checker.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-advice-proto.-checker.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e09c44cc9ce71305692740ba2d63b0940b2e0573 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-advice-proto.-checker.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.profiler.AdviceProto.Checker" +tf_proto { + descriptor { + name: "Checker" + field { + name: "reports" + number: 2 + label: LABEL_REPEATED + type: TYPE_STRING + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..87462435496fd2eedeb0bc8d92e8a833671b6531 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-advice-proto.-checkers-entry.pbtxt @@ -0,0 +1,22 @@ +path: "tensorflow.profiler.AdviceProto.CheckersEntry" +tf_proto { + descriptor { + name: "CheckersEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.AdviceProto.Checker" + } + options { + map_entry: true + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.profiler.-advice-proto.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-advice-proto.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a8a8858ccd5af3fb3dac612eef44e5cb450df914 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-advice-proto.pbtxt @@ -0,0 +1,41 @@ +path: "tensorflow.profiler.AdviceProto" +tf_proto { + descriptor { + name: "AdviceProto" + field { + name: "checkers" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.AdviceProto.CheckersEntry" + } + nested_type { + name: "CheckersEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.AdviceProto.Checker" + } + options { + map_entry: true + } + } + nested_type { + name: "Checker" + field { + name: "reports" + number: 2 + label: LABEL_REPEATED + type: TYPE_STRING + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..afec73f537aadd5d1a274db8d57e37b8c6fa3e74 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-graph-node-proto.-input-shapes-entry.pbtxt @@ -0,0 +1,22 @@ +path: "tensorflow.profiler.GraphNodeProto.InputShapesEntry" +tf_proto { + descriptor { + name: "InputShapesEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorShapeProto" + } + options { + map_entry: true + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.profiler.-graph-node-proto.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-graph-node-proto.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3c83177005323a277f929d8c769cd7b1eeff4d51 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-graph-node-proto.pbtxt @@ -0,0 +1,191 @@ +path: "tensorflow.profiler.GraphNodeProto" +tf_proto { + descriptor { + name: "GraphNodeProto" + field { + name: "name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "tensor_value" + number: 15 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.TFProfTensorProto" + } + field { + name: "run_count" + number: 21 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "exec_micros" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "accelerator_exec_micros" + number: 17 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "cpu_exec_micros" + number: 18 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "requested_bytes" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "peak_bytes" + number: 24 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "residual_bytes" + number: 25 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "output_bytes" + number: 26 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "parameters" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "float_ops" + number: 13 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "devices" + number: 10 + label: LABEL_REPEATED + type: TYPE_STRING + } + field { + name: "total_definition_count" + number: 23 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_run_count" + number: 22 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_exec_micros" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_accelerator_exec_micros" + number: 19 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_cpu_exec_micros" + number: 20 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_requested_bytes" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_peak_bytes" + number: 27 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_residual_bytes" + number: 28 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_output_bytes" + number: 29 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_parameters" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_float_ops" + number: 14 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "shapes" + number: 11 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorShapeProto" + } + field { + name: "input_shapes" + number: 16 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.GraphNodeProto.InputShapesEntry" + } + field { + name: "children" + number: 12 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.GraphNodeProto" + } + nested_type { + name: "InputShapesEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorShapeProto" + } + options { + map_entry: true + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.profiler.-multi-graph-node-proto.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-multi-graph-node-proto.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2b08a05437f90b91160fc08e670b2466ae163149 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-multi-graph-node-proto.pbtxt @@ -0,0 +1,134 @@ +path: "tensorflow.profiler.MultiGraphNodeProto" +tf_proto { + descriptor { + name: "MultiGraphNodeProto" + field { + name: "name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "exec_micros" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "accelerator_exec_micros" + number: 12 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "cpu_exec_micros" + number: 13 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "requested_bytes" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "peak_bytes" + number: 16 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "residual_bytes" + number: 17 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "output_bytes" + number: 18 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "parameters" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "float_ops" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_exec_micros" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_accelerator_exec_micros" + number: 14 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_cpu_exec_micros" + number: 15 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_requested_bytes" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_peak_bytes" + number: 19 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_residual_bytes" + number: 20 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_output_bytes" + number: 21 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_parameters" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "total_float_ops" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "graph_nodes" + number: 10 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.GraphNodeProto" + } + field { + name: "children" + number: 11 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.MultiGraphNodeProto" + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.profiler.-op-log-proto.-id-to-string-entry.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-op-log-proto.-id-to-string-entry.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b3adc50c7e14152a81a148df9deccc5272189aad --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-op-log-proto.-id-to-string-entry.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.profiler.OpLogProto.IdToStringEntry" +tf_proto { + descriptor { + name: "IdToStringEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + options { + map_entry: true + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.profiler.-op-log-proto.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-op-log-proto.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7510c566ba574e9370f5e54c29023ef4fb5ee804 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-op-log-proto.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.profiler.OpLogProto" +tf_proto { + descriptor { + name: "OpLogProto" + field { + name: "log_entries" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.OpLogEntry" + } + field { + name: "id_to_string" + number: 2 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.tfprof.OpLogProto.IdToStringEntry" + } + nested_type { + name: "IdToStringEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + options { + map_entry: true + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.profiler.-profile-option-builder.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-profile-option-builder.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..19ff38a3900c2d358faaa40e7316cc3a9da73040 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-profile-option-builder.pbtxt @@ -0,0 +1,93 @@ +path: "tensorflow.profiler.ProfileOptionBuilder" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "account_displayed_op_only" + argspec: "args=[\'self\', \'is_true\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "float_operation" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "order_by" + argspec: "args=[\'self\', \'attribute\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "select" + argspec: "args=[\'self\', \'attributes\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "time_and_memory" + argspec: "args=[\'min_micros\', \'min_bytes\', \'min_accelerator_micros\', \'min_cpu_micros\', \'min_peak_bytes\', \'min_residual_bytes\', \'min_output_bytes\'], varargs=None, keywords=None, defaults=[\'1\', \'1\', \'0\', \'0\', \'0\', \'0\', \'0\'], " + } + member_method { + name: "trainable_variables_parameter" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_accounted_types" + argspec: "args=[\'self\', \'account_type_regexes\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_empty_output" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_file_output" + argspec: "args=[\'self\', \'outfile\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_max_depth" + argspec: "args=[\'self\', \'max_depth\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_min_execution_time" + argspec: "args=[\'self\', \'min_micros\', \'min_accelerator_micros\', \'min_cpu_micros\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\'], " + } + member_method { + name: "with_min_float_operations" + argspec: "args=[\'self\', \'min_float_ops\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_min_memory" + argspec: "args=[\'self\', \'min_bytes\', \'min_peak_bytes\', \'min_residual_bytes\', \'min_output_bytes\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'0\'], " + } + member_method { + name: "with_min_occurrence" + argspec: "args=[\'self\', \'min_occurrence\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_min_parameters" + argspec: "args=[\'self\', \'min_params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_node_names" + argspec: "args=[\'self\', \'start_name_regexes\', \'show_name_regexes\', \'hide_name_regexes\', \'trim_name_regexes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "with_pprof_output" + argspec: "args=[\'self\', \'pprof_file\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_stdout_output" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_step" + argspec: "args=[\'self\', \'step\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_timeline_output" + argspec: "args=[\'self\', \'timeline_file\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.profiler.-profiler.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-profiler.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..acb61dae9f0d184ba998aa820ec40de5bc38c3eb --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.profiler.-profiler.pbtxt @@ -0,0 +1,37 @@ +path: "tensorflow.profiler.Profiler" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'graph\', \'op_log\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "add_step" + argspec: "args=[\'self\', \'step\', \'run_meta\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "advise" + argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "profile_graph" + argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "profile_name_scope" + argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "profile_operations" + argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "profile_python" + argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "serialize_to_string" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.profiler.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.profiler.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7b4d3ac522abc4229c5623da25c4ec818d86f829 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.profiler.pbtxt @@ -0,0 +1,39 @@ +path: "tensorflow.profiler" +tf_module { + member { + name: "AdviceProto" + mtype: "" + } + member { + name: "GraphNodeProto" + mtype: "" + } + member { + name: "MultiGraphNodeProto" + mtype: "" + } + member { + name: "OpLogProto" + mtype: "" + } + member { + name: "ProfileOptionBuilder" + mtype: "" + } + member { + name: "Profiler" + mtype: "" + } + member_method { + name: "advise" + argspec: "args=[\'graph\', \'run_meta\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0\'], " + } + member_method { + name: "profile" + argspec: "args=[\'graph\', \'run_meta\', \'op_log\', \'cmd\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'scope\', \'0\'], " + } + member_method { + name: "write_op_log" + argspec: "args=[\'graph\', \'log_dir\', \'op_log\', \'run_meta\', \'add_trace\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-compression-type.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-compression-type.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4941dda50e4964f8400a4cb5033c8e918aeaea5d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-compression-type.pbtxt @@ -0,0 +1,20 @@ +path: "tensorflow.python_io.TFRecordCompressionType" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "GZIP" + mtype: "" + } + member { + name: "NONE" + mtype: "" + } + member { + name: "ZLIB" + mtype: "" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0853716023ae5271fba6e8024e719eebb22ec56d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt @@ -0,0 +1,17 @@ +path: "tensorflow.python_io.TFRecordOptions" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "compression_type_map" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'compression_type\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_compression_type_string" + argspec: "args=[\'cls\', \'options\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-writer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-writer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..31775de2d12bcd2f214f5a04be7a92f49c594fde --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-writer.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.python_io.TFRecordWriter" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "flush" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "write" + argspec: "args=[\'self\', \'record\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.python_io.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.python_io.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7c9953e5fe3c883fd5e6e19ae011cc464f4107af --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.python_io.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.python_io" +tf_module { + member { + name: "TFRecordCompressionType" + mtype: "" + } + member { + name: "TFRecordOptions" + mtype: "" + } + member { + name: "TFRecordWriter" + mtype: "" + } + member_method { + name: "tf_record_iterator" + argspec: "args=[\'path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6d865efed0bfdada8dde64e86ddb5d2b2b364c79 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt @@ -0,0 +1,35 @@ +path: "tensorflow.quantization" +tf_module { + member_method { + name: "dequantize" + argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_args" + argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'-6\', \'6\', \'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_args_gradient" + argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'-6\', \'6\', \'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars" + argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars_gradient" + argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars_per_channel" + argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "fake_quant_with_min_max_vars_per_channel_gradient" + argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], " + } + member_method { + name: "quantized_concat" + argspec: "args=[\'concat_dim\', \'values\', \'input_mins\', \'input_maxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.random_normal_initializer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.random_normal_initializer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5993fdeb9c232ebc4090d9fffd8857da8ca6ada4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.random_normal_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.random_normal_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.random_uniform_initializer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.random_uniform_initializer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a434ed1599ef8b99b6e0496be388aa0e44755249 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.random_uniform_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.random_uniform_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.resource_loader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.resource_loader.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..288b78b4cd0ad3f5d5bc1f9c773977d50a6db086 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.resource_loader.pbtxt @@ -0,0 +1,23 @@ +path: "tensorflow.resource_loader" +tf_module { + member_method { + name: "get_data_files_path" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_path_to_datafile" + argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_root_dir_with_all_resources" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_resource" + argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "readahead_file_path" + argspec: "args=[\'path\', \'readahead\'], varargs=None, keywords=None, defaults=[\'128M\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.builder.-saved-model-builder.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.builder.-saved-model-builder.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..83bd7035409534abf036c7e2b0d66fcc060ada3a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.builder.-saved-model-builder.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.saved_model.builder.SavedModelBuilder" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'export_dir\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_meta_graph" + argspec: "args=[\'self\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\', \'saver\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "add_meta_graph_and_variables" + argspec: "args=[\'self\', \'sess\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\', \'saver\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "save" + argspec: "args=[\'self\', \'as_text\'], varargs=None, keywords=None, defaults=[\'False\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.builder.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.builder.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..adc697ad1c0bdd0c9b52be736fca3a19a2a82ef3 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.builder.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.saved_model.builder" +tf_module { + member { + name: "SavedModelBuilder" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.constants.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.constants.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..20e10aa094f704f2168de37abb73f6edf6765f93 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.constants.pbtxt @@ -0,0 +1,39 @@ +path: "tensorflow.saved_model.constants" +tf_module { + member { + name: "ASSETS_DIRECTORY" + mtype: "" + } + member { + name: "ASSETS_KEY" + mtype: "" + } + member { + name: "LEGACY_INIT_OP_KEY" + mtype: "" + } + member { + name: "MAIN_OP_KEY" + mtype: "" + } + member { + name: "SAVED_MODEL_FILENAME_PB" + mtype: "" + } + member { + name: "SAVED_MODEL_FILENAME_PBTXT" + mtype: "" + } + member { + name: "SAVED_MODEL_SCHEMA_VERSION" + mtype: "" + } + member { + name: "VARIABLES_DIRECTORY" + mtype: "" + } + member { + name: "VARIABLES_FILENAME" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.loader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.loader.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..511e6b4712d3c55746a39fe9098fa3b649bc75dc --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.loader.pbtxt @@ -0,0 +1,11 @@ +path: "tensorflow.saved_model.loader" +tf_module { + member_method { + name: "load" + argspec: "args=[\'sess\', \'tags\', \'export_dir\', \'import_scope\'], varargs=None, keywords=saver_kwargs, defaults=[\'None\'], " + } + member_method { + name: "maybe_saved_model_directory" + argspec: "args=[\'export_dir\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.main_op.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.main_op.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..176cb788c249e68f1221713e96c7e808c39c8f6d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.main_op.pbtxt @@ -0,0 +1,11 @@ +path: "tensorflow.saved_model.main_op" +tf_module { + member_method { + name: "main_op" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "main_op_with_restore" + argspec: "args=[\'restore_op_name\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e1a0385092c1384bcb5958fce2e24693ee731ae5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.pbtxt @@ -0,0 +1,39 @@ +path: "tensorflow.saved_model" +tf_module { + member { + name: "builder" + mtype: "" + } + member { + name: "constants" + mtype: "" + } + member { + name: "loader" + mtype: "" + } + member { + name: "main_op" + mtype: "" + } + member { + name: "signature_constants" + mtype: "" + } + member { + name: "signature_def_utils" + mtype: "" + } + member { + name: "tag_constants" + mtype: "" + } + member { + name: "utils" + mtype: "" + } + member_method { + name: "simple_save" + argspec: "args=[\'session\', \'export_dir\', \'inputs\', \'outputs\', \'legacy_init_op\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.signature_constants.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.signature_constants.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..478d410e066b1ce3a17bb3ef9cc6e4503991ad0b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.signature_constants.pbtxt @@ -0,0 +1,47 @@ +path: "tensorflow.saved_model.signature_constants" +tf_module { + member { + name: "CLASSIFY_INPUTS" + mtype: "" + } + member { + name: "CLASSIFY_METHOD_NAME" + mtype: "" + } + member { + name: "CLASSIFY_OUTPUT_CLASSES" + mtype: "" + } + member { + name: "CLASSIFY_OUTPUT_SCORES" + mtype: "" + } + member { + name: "DEFAULT_SERVING_SIGNATURE_DEF_KEY" + mtype: "" + } + member { + name: "PREDICT_INPUTS" + mtype: "" + } + member { + name: "PREDICT_METHOD_NAME" + mtype: "" + } + member { + name: "PREDICT_OUTPUTS" + mtype: "" + } + member { + name: "REGRESS_INPUTS" + mtype: "" + } + member { + name: "REGRESS_METHOD_NAME" + mtype: "" + } + member { + name: "REGRESS_OUTPUTS" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.signature_def_utils.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.signature_def_utils.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a5602464eeb09a290076ef102ed5502ea61b4ac3 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.signature_def_utils.pbtxt @@ -0,0 +1,23 @@ +path: "tensorflow.saved_model.signature_def_utils" +tf_module { + member_method { + name: "build_signature_def" + argspec: "args=[\'inputs\', \'outputs\', \'method_name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "classification_signature_def" + argspec: "args=[\'examples\', \'classes\', \'scores\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_valid_signature" + argspec: "args=[\'signature_def\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict_signature_def" + argspec: "args=[\'inputs\', \'outputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "regression_signature_def" + argspec: "args=[\'examples\', \'predictions\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.tag_constants.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.tag_constants.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6af72498d74d4bbc12e7ca68ad1e0a6f0c237e0a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.tag_constants.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.saved_model.tag_constants" +tf_module { + member { + name: "GPU" + mtype: "" + } + member { + name: "SERVING" + mtype: "" + } + member { + name: "TPU" + mtype: "" + } + member { + name: "TRAINING" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.utils.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.utils.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d95c94668250e1de236462ccdcb134245eebf092 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.utils.pbtxt @@ -0,0 +1,11 @@ +path: "tensorflow.saved_model.utils" +tf_module { + member_method { + name: "build_tensor_info" + argspec: "args=[\'tensor\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_tensor_from_tensor_info" + argspec: "args=[\'tensor_info\', \'graph\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.sets.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.sets.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8a196b1a556e283671cc75af28df3eaa62532975 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.sets.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.sets" +tf_module { + member_method { + name: "set_difference" + argspec: "args=[\'a\', \'b\', \'aminusb\', \'validate_indices\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], " + } + member_method { + name: "set_intersection" + argspec: "args=[\'a\', \'b\', \'validate_indices\'], varargs=None, keywords=None, defaults=[\'True\'], " + } + member_method { + name: "set_size" + argspec: "args=[\'a\', \'validate_indices\'], varargs=None, keywords=None, defaults=[\'True\'], " + } + member_method { + name: "set_union" + argspec: "args=[\'a\', \'b\', \'validate_indices\'], varargs=None, keywords=None, defaults=[\'True\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..bbfe395031aece42363ca7d6577fee856df6bde8 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt @@ -0,0 +1,11 @@ +path: "tensorflow.sparse" +tf_module { + member_method { + name: "cross" + argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "cross_hashed" + argspec: "args=[\'inputs\', \'num_buckets\', \'hash_key\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.spectral.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.spectral.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6a421ef12d58dc047905ec916cbe777b4ce19b9a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.spectral.pbtxt @@ -0,0 +1,59 @@ +path: "tensorflow.spectral" +tf_module { + member_method { + name: "dct" + argspec: "args=[\'input\', \'type\', \'n\', \'axis\', \'norm\', \'name\'], varargs=None, keywords=None, defaults=[\'2\', \'None\', \'-1\', \'None\', \'None\'], " + } + member_method { + name: "fft" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "fft2d" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "fft3d" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "idct" + argspec: "args=[\'input\', \'type\', \'n\', \'axis\', \'norm\', \'name\'], varargs=None, keywords=None, defaults=[\'2\', \'None\', \'-1\', \'None\', \'None\'], " + } + member_method { + name: "ifft" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ifft2d" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ifft3d" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "irfft" + argspec: "args=[\'input_tensor\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "irfft2d" + argspec: "args=[\'input_tensor\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "irfft3d" + argspec: "args=[\'input_tensor\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "rfft" + argspec: "args=[\'input_tensor\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "rfft2d" + argspec: "args=[\'input_tensor\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "rfft3d" + argspec: "args=[\'input_tensor\', \'fft_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9a831fed2692b30db6ce991c86f46a42908c0789 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt @@ -0,0 +1,43 @@ +path: "tensorflow.strings" +tf_module { + member_method { + name: "join" + argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], " + } + member_method { + name: "regex_full_match" + argspec: "args=[\'input\', \'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "regex_replace" + argspec: "args=[\'input\', \'pattern\', \'rewrite\', \'replace_global\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + } + member_method { + name: "split" + argspec: "args=[\'source\', \'sep\', \'maxsplit\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\'], " + } + member_method { + name: "strip" + argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "substr" + argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_hash_bucket" + argspec: "args=[\'string_tensor\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_hash_bucket_fast" + argspec: "args=[\'input\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_hash_bucket_strong" + argspec: "args=[\'input\', \'num_buckets\', \'key\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_number" + argspec: "args=[\'string_tensor\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.-event.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.-event.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..eb99d0f5334457aa654fed0553af143839328dba --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.-event.pbtxt @@ -0,0 +1,74 @@ +path: "tensorflow.summary.Event" +tf_proto { + descriptor { + name: "Event" + field { + name: "wall_time" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_DOUBLE + } + field { + name: "step" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "file_version" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + oneof_index: 0 + } + field { + name: "graph_def" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + oneof_index: 0 + } + field { + name: "summary" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary" + oneof_index: 0 + } + field { + name: "log_message" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.LogMessage" + oneof_index: 0 + } + field { + name: "session_log" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SessionLog" + oneof_index: 0 + } + field { + name: "tagged_run_metadata" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TaggedRunMetadata" + oneof_index: 0 + } + field { + name: "meta_graph_def" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_BYTES + oneof_index: 0 + } + oneof_decl { + name: "what" + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.-file-writer-cache.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.-file-writer-cache.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2a5b63dceae3c0ac27b34c2e896ee3b90bbd7f75 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.-file-writer-cache.pbtxt @@ -0,0 +1,16 @@ +path: "tensorflow.summary.FileWriterCache" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "clear" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get" + argspec: "args=[\'logdir\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.-file-writer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.-file-writer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6b65b0ace3cf7740ab03390841c941592000d127 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.-file-writer.pbtxt @@ -0,0 +1,50 @@ +path: "tensorflow.summary.FileWriter" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'logdir\', \'graph\', \'max_queue\', \'flush_secs\', \'graph_def\', \'filename_suffix\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'120\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_event" + argspec: "args=[\'self\', \'event\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_graph" + argspec: "args=[\'self\', \'graph\', \'global_step\', \'graph_def\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "add_meta_graph" + argspec: "args=[\'self\', \'meta_graph_def\', \'global_step\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_run_metadata" + argspec: "args=[\'self\', \'run_metadata\', \'tag\', \'global_step\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_session_log" + argspec: "args=[\'self\', \'session_log\', \'global_step\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_summary" + argspec: "args=[\'self\', \'summary\', \'global_step\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "flush" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_logdir" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reopen" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.-session-log.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.-session-log.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..73de73869c8d1a6808b16fe8853fd21cc8891879 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.-session-log.pbtxt @@ -0,0 +1,44 @@ +path: "tensorflow.summary.SessionLog" +tf_proto { + descriptor { + name: "SessionLog" + field { + name: "status" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.SessionLog.SessionStatus" + } + field { + name: "checkpoint_path" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "msg" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + enum_type { + name: "SessionStatus" + value { + name: "STATUS_UNSPECIFIED" + number: 0 + } + value { + name: "START" + number: 1 + } + value { + name: "STOP" + number: 2 + } + value { + name: "CHECKPOINT" + number: 3 + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.-summary-description.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.-summary-description.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4a8b59cf02ed46ef70f22564f3134214840600fe --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.-summary-description.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.summary.SummaryDescription" +tf_proto { + descriptor { + name: "SummaryDescription" + field { + name: "type_hint" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.-summary.-audio.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.-summary.-audio.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8b271cf58fc11c8666abd456021afeedc0b14c7a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.-summary.-audio.pbtxt @@ -0,0 +1,36 @@ +path: "tensorflow.summary.Summary.Audio" +tf_proto { + descriptor { + name: "Audio" + field { + name: "sample_rate" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + } + field { + name: "num_channels" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "length_frames" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "encoded_audio_string" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } + field { + name: "content_type" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.-summary.-image.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.-summary.-image.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..dbbc02dd0506dbcebd1690602b5786b02c3ed4a0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.-summary.-image.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.summary.Summary.Image" +tf_proto { + descriptor { + name: "Image" + field { + name: "height" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "width" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "colorspace" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "encoded_image_string" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.-summary.-value.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.-summary.-value.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4176171cd938e383fe5366153364d8e8e8c1a1ee --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.-summary.-value.pbtxt @@ -0,0 +1,74 @@ +path: "tensorflow.summary.Summary.Value" +tf_proto { + descriptor { + name: "Value" + field { + name: "node_name" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "tag" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "metadata" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SummaryMetadata" + } + field { + name: "simple_value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + oneof_index: 0 + } + field { + name: "obsolete_old_style_histogram" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_BYTES + oneof_index: 0 + } + field { + name: "image" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Image" + oneof_index: 0 + } + field { + name: "histo" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.HistogramProto" + oneof_index: 0 + } + field { + name: "audio" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Audio" + oneof_index: 0 + } + field { + name: "tensor" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorProto" + oneof_index: 0 + } + oneof_decl { + name: "value" + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.-summary.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.-summary.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d6c5e3a87a115b9bdcfd044abe93177eda2af275 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.-summary.pbtxt @@ -0,0 +1,144 @@ +path: "tensorflow.summary.Summary" +tf_proto { + descriptor { + name: "Summary" + field { + name: "value" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Value" + } + nested_type { + name: "Image" + field { + name: "height" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "width" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "colorspace" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "encoded_image_string" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } + } + nested_type { + name: "Audio" + field { + name: "sample_rate" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + } + field { + name: "num_channels" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "length_frames" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } + field { + name: "encoded_audio_string" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } + field { + name: "content_type" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } + nested_type { + name: "Value" + field { + name: "node_name" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "tag" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "metadata" + number: 9 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.SummaryMetadata" + } + field { + name: "simple_value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + oneof_index: 0 + } + field { + name: "obsolete_old_style_histogram" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_BYTES + oneof_index: 0 + } + field { + name: "image" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Image" + oneof_index: 0 + } + field { + name: "histo" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.HistogramProto" + oneof_index: 0 + } + field { + name: "audio" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Summary.Audio" + oneof_index: 0 + } + field { + name: "tensor" + number: 8 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.TensorProto" + oneof_index: 0 + } + oneof_decl { + name: "value" + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.-tagged-run-metadata.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.-tagged-run-metadata.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..27c8873320403cb2e7402ef9f1bb0e7134d5f96b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.-tagged-run-metadata.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.summary.TaggedRunMetadata" +tf_proto { + descriptor { + name: "TaggedRunMetadata" + field { + name: "tag" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "run_metadata" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_BYTES + } + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.summary.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt similarity index 100% rename from tensorflow/tools/api/golden/tensorflow.summary.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt diff --git a/tensorflow/tools/api/golden/v2/tensorflow.sysconfig.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.sysconfig.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2f00aeac25f691d9767080251798248281e5edf5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.sysconfig.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.sysconfig" +tf_module { + member_method { + name: "get_compile_flags" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_include" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_lib" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_link_flags" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.test.-benchmark.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.test.-benchmark.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..df528e26b60f8d8ddcc1eaf0ed292cc7ff0ebd94 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.test.-benchmark.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.test.Benchmark" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "is_abstract" + argspec: "args=[\'cls\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "report_benchmark" + argspec: "args=[\'self\', \'iters\', \'cpu_time\', \'wall_time\', \'throughput\', \'extras\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "run_op_benchmark" + argspec: "args=[\'self\', \'sess\', \'op_or_tensor\', \'feed_dict\', \'burn_iters\', \'min_iters\', \'store_trace\', \'store_memory_usage\', \'name\', \'extras\', \'mbs\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'10\', \'False\', \'True\', \'None\', \'None\', \'0\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.test.-stub-out-for-testing.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.test.-stub-out-for-testing.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e02a0c6097c5ea4dae905b25cd0e381f5e257105 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.test.-stub-out-for-testing.pbtxt @@ -0,0 +1,28 @@ +path: "tensorflow.test.StubOutForTesting" +tf_class { + is_instance: "" + member_method { + name: "CleanUp" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "Set" + argspec: "args=[\'self\', \'parent\', \'child_name\', \'new_child\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "SmartSet" + argspec: "args=[\'self\', \'obj\', \'attr_name\', \'new_attr\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "SmartUnsetAll" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "UnsetAll" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..abe9b068ae95c08a2b72c9a5e164a097e6162dff --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt @@ -0,0 +1,59 @@ +path: "tensorflow.test" +tf_module { + member { + name: "Benchmark" + mtype: "" + } + member { + name: "StubOutForTesting" + mtype: "" + } + member { + name: "TestCase" + mtype: "" + } + member { + name: "mock" + mtype: "" + } + member_method { + name: "assert_equal_graph_def" + argspec: "args=[\'actual\', \'expected\', \'checkpoint_v2\'], varargs=None, keywords=None, defaults=[\'False\'], " + } + member_method { + name: "compute_gradient" + argspec: "args=[\'x\', \'x_shape\', \'y\', \'y_shape\', \'x_init_value\', \'delta\', \'init_targets\', \'extra_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'0.001\', \'None\', \'None\'], " + } + member_method { + name: "compute_gradient_error" + argspec: "args=[\'x\', \'x_shape\', \'y\', \'y_shape\', \'x_init_value\', \'delta\', \'init_targets\', \'extra_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'0.001\', \'None\', \'None\'], " + } + member_method { + name: "create_local_cluster" + argspec: "args=[\'num_workers\', \'num_ps\', \'protocol\', \'worker_config\', \'ps_config\'], varargs=None, keywords=None, defaults=[\'grpc\', \'None\', \'None\'], " + } + member_method { + name: "get_temp_dir" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "gpu_device_name" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_built_with_cuda" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_gpu_available" + argspec: "args=[\'cuda_only\', \'min_cuda_compute_capability\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "main" + argspec: "args=[\'argv\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "test_src_dir_path" + argspec: "args=[\'relative_path\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-adadelta-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-adadelta-optimizer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1f1d8b6f9e2cde4800cdef9c417191b1a0ce07b5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-adadelta-optimizer.pbtxt @@ -0,0 +1,51 @@ +path: "tensorflow.train.AdadeltaOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'rho\', \'epsilon\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'0.001\', \'0.95\', \'1e-08\', \'False\', \'Adadelta\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-adagrad-d-a-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-adagrad-d-a-optimizer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..a7c05d484905a0af26c80a52d92623ef4a3eb6c4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-adagrad-d-a-optimizer.pbtxt @@ -0,0 +1,51 @@ +path: "tensorflow.train.AdagradDAOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'global_step\', \'initial_gradient_squared_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'0.1\', \'0.0\', \'0.0\', \'False\', \'AdagradDA\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-adagrad-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-adagrad-optimizer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..bc8b92389c6ed7dcb0fa23ff3abd86bb0d1c488a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-adagrad-optimizer.pbtxt @@ -0,0 +1,51 @@ +path: "tensorflow.train.AdagradOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator_value\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'0.1\', \'False\', \'Adagrad\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-adam-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-adam-optimizer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5d17be9378fd130b89e199544f85e03a23a71d3c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-adam-optimizer.pbtxt @@ -0,0 +1,51 @@ +path: "tensorflow.train.AdamOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'beta1\', \'beta2\', \'epsilon\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-08\', \'False\', \'Adam\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-bytes-list.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-bytes-list.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..87e4f160e5bd5950dfc338649fb531c92cc81b60 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-bytes-list.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.train.BytesList" +tf_proto { + descriptor { + name: "BytesList" + field { + name: "value" + number: 1 + label: LABEL_REPEATED + type: TYPE_BYTES + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-saver-hook.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-saver-hook.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c3037baa8c951ecd9b60267ee7cc8674ead88dbe --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-saver-hook.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.train.CheckpointSaverHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\'], " + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-saver-listener.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-saver-listener.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9d3688e565761758e765d00086de8b59dcc3801b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint-saver-listener.pbtxt @@ -0,0 +1,24 @@ +path: "tensorflow.train.CheckpointSaverListener" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "after_save" + argspec: "args=[\'self\', \'session\', \'global_step_value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_save" + argspec: "args=[\'self\', \'session\', \'global_step_value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\', \'global_step_value\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.-checkpoint.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt similarity index 96% rename from tensorflow/tools/api/golden/tensorflow.train.-checkpoint.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt index ddc553d7c984b24fe33c03bb90e00e7e81f55d26..2d067e4eff13208cb03ca01b7b8a8018a1e99097 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-checkpoint.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-checkpoint.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.train.Checkpoint" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-chief-session-creator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-chief-session-creator.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..abbe273be32c6fd20b1a6464f3e99966bd3c8953 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-chief-session-creator.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.train.ChiefSessionCreator" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'scaffold\', \'master\', \'config\', \'checkpoint_dir\', \'checkpoint_filename_with_path\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "create_session" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-cluster-def.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-cluster-def.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f9de26839f5f6dc1591bfc909ca8e6c02271b5c7 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-cluster-def.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.train.ClusterDef" +tf_proto { + descriptor { + name: "ClusterDef" + field { + name: "job" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.JobDef" + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-cluster-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-cluster-spec.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1658b15a5f82167f9167338145b479c9e9197ea5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-cluster-spec.pbtxt @@ -0,0 +1,37 @@ +path: "tensorflow.train.ClusterSpec" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "jobs" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cluster\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_cluster_def" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "as_dict" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "job_tasks" + argspec: "args=[\'self\', \'job_name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "num_tasks" + argspec: "args=[\'self\', \'job_name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "task_address" + argspec: "args=[\'self\', \'job_name\', \'task_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "task_indices" + argspec: "args=[\'self\', \'job_name\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-coordinator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-coordinator.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..11277f077eef830aec3be61ddd981bfd3a55d149 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-coordinator.pbtxt @@ -0,0 +1,45 @@ +path: "tensorflow.train.Coordinator" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "joined" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'clean_stop_exception_types\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "clear_stop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "join" + argspec: "args=[\'self\', \'threads\', \'stop_grace_period_secs\', \'ignore_live_threads\'], varargs=None, keywords=None, defaults=[\'None\', \'120\', \'False\'], " + } + member_method { + name: "raise_requested_exception" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "register_thread" + argspec: "args=[\'self\', \'thread\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "request_stop" + argspec: "args=[\'self\', \'ex\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "should_stop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "stop_on_exception" + argspec: "args=[], varargs=args, keywords=kwds, defaults=None" + } + member_method { + name: "wait_for_stop" + argspec: "args=[\'self\', \'timeout\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-example.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-example.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..23c30f1ef4fe2dd93e8714655dbb1ef3b8e05c65 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-example.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.train.Example" +tf_proto { + descriptor { + name: "Example" + field { + name: "features" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Features" + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-exponential-moving-average.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-exponential-moving-average.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c9fe136e68b5f3cadaff6d4fd0638b7f10d18365 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-exponential-moving-average.pbtxt @@ -0,0 +1,29 @@ +path: "tensorflow.train.ExponentialMovingAverage" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'decay\', \'num_updates\', \'zero_debias\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'ExponentialMovingAverage\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "average" + argspec: "args=[\'self\', \'var\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "average_name" + argspec: "args=[\'self\', \'var\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "variables_to_restore" + argspec: "args=[\'self\', \'moving_avg_variables\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-feature-list.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-feature-list.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2a8b3714fc0c4f5e979bc02550a8e08835d53cb4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-feature-list.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.train.FeatureList" +tf_proto { + descriptor { + name: "FeatureList" + field { + name: "feature" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.Feature" + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..cd1d56e606c96b62346b936001a5a0f07a8a8ad8 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-feature-lists.-feature-list-entry.pbtxt @@ -0,0 +1,22 @@ +path: "tensorflow.train.FeatureLists.FeatureListEntry" +tf_proto { + descriptor { + name: "FeatureListEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.FeatureList" + } + options { + map_entry: true + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-feature-lists.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-feature-lists.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3c183a64769b59b104c52b6840e8f351f4b0cef5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-feature-lists.pbtxt @@ -0,0 +1,32 @@ +path: "tensorflow.train.FeatureLists" +tf_proto { + descriptor { + name: "FeatureLists" + field { + name: "feature_list" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.FeatureLists.FeatureListEntry" + } + nested_type { + name: "FeatureListEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.FeatureList" + } + options { + map_entry: true + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-feature.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-feature.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5d0eb871c2f4aeb13d6b8518486f11b1f80d0620 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-feature.pbtxt @@ -0,0 +1,33 @@ +path: "tensorflow.train.Feature" +tf_proto { + descriptor { + name: "Feature" + field { + name: "bytes_list" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.BytesList" + oneof_index: 0 + } + field { + name: "float_list" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.FloatList" + oneof_index: 0 + } + field { + name: "int64_list" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Int64List" + oneof_index: 0 + } + oneof_decl { + name: "kind" + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-features.-feature-entry.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-features.-feature-entry.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f912005f1cc35f12ce6eba5313b0c67adebe70f7 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-features.-feature-entry.pbtxt @@ -0,0 +1,22 @@ +path: "tensorflow.train.Features.FeatureEntry" +tf_proto { + descriptor { + name: "FeatureEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Feature" + } + options { + map_entry: true + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-features.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-features.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b788ca1d57e1d679a1b809d85c6aa9bcef01f252 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-features.pbtxt @@ -0,0 +1,32 @@ +path: "tensorflow.train.Features" +tf_proto { + descriptor { + name: "Features" + field { + name: "feature" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.Features.FeatureEntry" + } + nested_type { + name: "FeatureEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Feature" + } + options { + map_entry: true + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-feed-fn-hook.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-feed-fn-hook.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7bec4d032cedc0711ca07049d5d04490e8bc3f30 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-feed-fn-hook.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.train.FeedFnHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'feed_fn\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-final-ops-hook.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-final-ops-hook.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..31cf9aaeb2c640f8db205c0753f20acc75338fe0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-final-ops-hook.pbtxt @@ -0,0 +1,34 @@ +path: "tensorflow.train.FinalOpsHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "final_ops_values" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'final_ops\', \'final_ops_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-float-list.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-float-list.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..55d3b46f20e17ec4e6fbac5672e1b0a8ef98552d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-float-list.pbtxt @@ -0,0 +1,15 @@ +path: "tensorflow.train.FloatList" +tf_proto { + descriptor { + name: "FloatList" + field { + name: "value" + number: 1 + label: LABEL_REPEATED + type: TYPE_FLOAT + options { + packed: true + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-ftrl-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-ftrl-optimizer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d265fdeb01c38d8a1347e630d7f7bff111999634 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-ftrl-optimizer.pbtxt @@ -0,0 +1,51 @@ +path: "tensorflow.train.FtrlOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\', \'accum_name\', \'linear_name\', \'l2_shrinkage_regularization_strength\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'False\', \'Ftrl\', \'None\', \'None\', \'0.0\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-global-step-waiter-hook.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-global-step-waiter-hook.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..147448618e2df9f71ac794e369b108629e10ce0a --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-global-step-waiter-hook.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.train.GlobalStepWaiterHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'wait_until_step\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-gradient-descent-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-gradient-descent-optimizer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c673e29cd4dd6cd3c01582abfbc306c092818892 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-gradient-descent-optimizer.pbtxt @@ -0,0 +1,51 @@ +path: "tensorflow.train.GradientDescentOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'GradientDescent\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-int64-list.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-int64-list.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..1de92b3ab7b5e0ff873a7e8092c7e6c2edcbd2ce --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-int64-list.pbtxt @@ -0,0 +1,15 @@ +path: "tensorflow.train.Int64List" +tf_proto { + descriptor { + name: "Int64List" + field { + name: "value" + number: 1 + label: LABEL_REPEATED + type: TYPE_INT64 + options { + packed: true + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-job-def.-tasks-entry.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-job-def.-tasks-entry.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..58115590a5eebd742afac4b31b5f585e8077e049 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-job-def.-tasks-entry.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.train.JobDef.TasksEntry" +tf_proto { + descriptor { + name: "TasksEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + options { + map_entry: true + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-job-def.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-job-def.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d7eb505e27930d6411a589909584f237a7e8b8f5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-job-def.pbtxt @@ -0,0 +1,37 @@ +path: "tensorflow.train.JobDef" +tf_proto { + descriptor { + name: "JobDef" + field { + name: "name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "tasks" + number: 2 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.JobDef.TasksEntry" + } + nested_type { + name: "TasksEntry" + field { + name: "key" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "value" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + options { + map_entry: true + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-logging-tensor-hook.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-logging-tensor-hook.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9801c05df181ee65cc8ce0ad2e886566c0145fd5 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-logging-tensor-hook.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.train.LoggingTensorHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'tensors\', \'every_n_iter\', \'every_n_secs\', \'at_end\', \'formatter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-looper-thread.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-looper-thread.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c61859004e897a14b580dc0b55957edfa6ae6860 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-looper-thread.pbtxt @@ -0,0 +1,73 @@ +path: "tensorflow.train.LooperThread" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "daemon" + mtype: "" + } + member { + name: "ident" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'coord\', \'timer_interval_secs\', \'target\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "getName" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "isAlive" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "isDaemon" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_alive" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "join" + argspec: "args=[\'self\', \'timeout\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "loop" + argspec: "args=[\'coord\', \'timer_interval_secs\', \'target\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "run" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "run_loop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "setDaemon" + argspec: "args=[\'self\', \'daemonic\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "setName" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "start" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "start_loop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "stop_loop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-momentum-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-momentum-optimizer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8199f63b9b8c64c73a3d62294277838cdc240280 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-momentum-optimizer.pbtxt @@ -0,0 +1,51 @@ +path: "tensorflow.train.MomentumOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'momentum\', \'use_locking\', \'name\', \'use_nesterov\'], varargs=None, keywords=None, defaults=[\'False\', \'Momentum\', \'False\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-monitored-session.-step-context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-monitored-session.-step-context.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..03efe6639e0e3d2c6c280bd30d2b59b5d654f995 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-monitored-session.-step-context.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.train.MonitoredSession.StepContext" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "session" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'session\', \'run_with_hooks_fn\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "request_stop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "run_with_hooks" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-monitored-session.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-monitored-session.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..09b7b3fb538fb8d87dcfd622089818081a1fb79b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-monitored-session.pbtxt @@ -0,0 +1,34 @@ +path: "tensorflow.train.MonitoredSession" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "StepContext" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'session_creator\', \'hooks\', \'stop_grace_period_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'120\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "run" + argspec: "args=[\'self\', \'fetches\', \'feed_dict\', \'options\', \'run_metadata\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "run_step_fn" + argspec: "args=[\'self\', \'step_fn\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "should_stop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-nan-loss-during-training-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-nan-loss-during-training-error.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..25fd5e75a79f6e4fe2cf77ebc7aa0d1fef759e7f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-nan-loss-during-training-error.pbtxt @@ -0,0 +1,16 @@ +path: "tensorflow.train.NanLossDuringTrainingError" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member { + name: "message" + mtype: "" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-nan-tensor-hook.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-nan-tensor-hook.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7d1c89f9b37b5e63ecf2cf766986cb8faa5872c4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-nan-tensor-hook.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.train.NanTensorHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'loss_tensor\', \'fail_on_nan_loss\'], varargs=None, keywords=None, defaults=[\'True\'], " + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-optimizer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..876bb35e391885e751066a415967af848280c714 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-optimizer.pbtxt @@ -0,0 +1,50 @@ +path: "tensorflow.train.Optimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-profiler-hook.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-profiler-hook.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4df6c4156a8bfe6d3bc0fb6746512cb3025c2604 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-profiler-hook.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.train.ProfilerHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'save_steps\', \'save_secs\', \'output_dir\', \'show_dataflow\', \'show_memory\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'\', \'True\', \'False\'], " + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-proximal-adagrad-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-proximal-adagrad-optimizer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..14349a74efb61124fc7b5568d5ec023f08b1b62f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-proximal-adagrad-optimizer.pbtxt @@ -0,0 +1,51 @@ +path: "tensorflow.train.ProximalAdagradOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'0.1\', \'0.0\', \'0.0\', \'False\', \'ProximalAdagrad\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7d982dc51f6edce1cf691671e31ddd07664f0dc1 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt @@ -0,0 +1,51 @@ +path: "tensorflow.train.ProximalGradientDescentOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.0\', \'False\', \'ProximalGradientDescent\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-queue-runner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-queue-runner.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d84d0058eea34d2d4413c8b1a09bd7d5720c07f7 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-queue-runner.pbtxt @@ -0,0 +1,49 @@ +path: "tensorflow.train.QueueRunner" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "cancel_op" + mtype: "" + } + member { + name: "close_op" + mtype: "" + } + member { + name: "enqueue_ops" + mtype: "" + } + member { + name: "exceptions_raised" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "queue" + mtype: "" + } + member { + name: "queue_closed_exception_types" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'queue\', \'enqueue_ops\', \'close_op\', \'cancel_op\', \'queue_closed_exception_types\', \'queue_runner_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "create_threads" + argspec: "args=[\'self\', \'sess\', \'coord\', \'daemon\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'False\'], " + } + member_method { + name: "from_proto" + argspec: "args=[\'queue_runner_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_proto" + argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-r-m-s-prop-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-r-m-s-prop-optimizer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..906384a2875bf7b05ac26fc43207f4ef9b5a7472 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-r-m-s-prop-optimizer.pbtxt @@ -0,0 +1,51 @@ +path: "tensorflow.train.RMSPropOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'decay\', \'momentum\', \'epsilon\', \'use_locking\', \'centered\', \'name\'], varargs=None, keywords=None, defaults=[\'0.9\', \'0.0\', \'1e-10\', \'False\', \'False\', \'RMSProp\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-saver-def.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-saver-def.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4ec99469e4025603e7ab340b190cbebf7e33eed7 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-saver-def.pbtxt @@ -0,0 +1,64 @@ +path: "tensorflow.train.SaverDef" +tf_proto { + descriptor { + name: "SaverDef" + field { + name: "filename_tensor_name" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "save_tensor_name" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "restore_op_name" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "max_to_keep" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "sharded" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_BOOL + } + field { + name: "keep_checkpoint_every_n_hours" + number: 6 + label: LABEL_OPTIONAL + type: TYPE_FLOAT + } + field { + name: "version" + number: 7 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.SaverDef.CheckpointFormatVersion" + } + enum_type { + name: "CheckpointFormatVersion" + value { + name: "LEGACY" + number: 0 + } + value { + name: "V1" + number: 1 + } + value { + name: "V2" + number: 2 + } + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-saver.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-saver.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2cda458f468b2d748b43954b14b670df7145243f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-saver.pbtxt @@ -0,0 +1,53 @@ +path: "tensorflow.train.Saver" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "last_checkpoints" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'var_list\', \'reshape\', \'sharded\', \'max_to_keep\', \'keep_checkpoint_every_n_hours\', \'name\', \'restore_sequentially\', \'saver_def\', \'builder\', \'defer_build\', \'allow_empty\', \'write_version\', \'pad_step_number\', \'save_relative_paths\', \'filename\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'False\', \'5\', \'10000.0\', \'None\', \'False\', \'None\', \'None\', \'False\', \'False\', \'2\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "as_saver_def" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "export_meta_graph" + argspec: "args=[\'self\', \'filename\', \'collection_list\', \'as_text\', \'export_scope\', \'clear_devices\', \'clear_extraneous_savers\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\', \'False\', \'False\', \'False\'], " + } + member_method { + name: "from_proto" + argspec: "args=[\'saver_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "recover_last_checkpoints" + argspec: "args=[\'self\', \'checkpoint_paths\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "restore" + argspec: "args=[\'self\', \'sess\', \'save_path\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "save" + argspec: "args=[\'self\', \'sess\', \'save_path\', \'global_step\', \'latest_filename\', \'meta_graph_suffix\', \'write_meta_graph\', \'write_state\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'meta\', \'True\', \'True\', \'False\'], " + } + member_method { + name: "set_last_checkpoints" + argspec: "args=[\'self\', \'last_checkpoints\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_last_checkpoints_with_time" + argspec: "args=[\'self\', \'last_checkpoints_with_time\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "to_proto" + argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-scaffold.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-scaffold.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..38cc98b48e78aa93f7614a9baff236f7b119f99d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-scaffold.pbtxt @@ -0,0 +1,53 @@ +path: "tensorflow.train.Scaffold" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "init_feed_dict" + mtype: "" + } + member { + name: "init_fn" + mtype: "" + } + member { + name: "init_op" + mtype: "" + } + member { + name: "local_init_op" + mtype: "" + } + member { + name: "ready_for_local_init_op" + mtype: "" + } + member { + name: "ready_op" + mtype: "" + } + member { + name: "saver" + mtype: "" + } + member { + name: "summary_op" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'init_op\', \'init_feed_dict\', \'init_fn\', \'ready_op\', \'ready_for_local_init_op\', \'local_init_op\', \'summary_op\', \'saver\', \'copy_from_scaffold\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "default_local_init_op" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "finalize" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_or_default" + argspec: "args=[\'arg_name\', \'collection_key\', \'default_constructor\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-second-or-step-timer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-second-or-step-timer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..3c5a6ac13cc2d8a4d464ab48da6edaa0a9ccc14b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-second-or-step-timer.pbtxt @@ -0,0 +1,26 @@ +path: "tensorflow.train.SecondOrStepTimer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'every_secs\', \'every_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "last_triggered_step" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "should_trigger_for_step" + argspec: "args=[\'self\', \'step\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update_last_triggered_step" + argspec: "args=[\'self\', \'step\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-sequence-example.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-sequence-example.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6a4553bbc157960696ef17959f532fecdfd54ae8 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-sequence-example.pbtxt @@ -0,0 +1,20 @@ +path: "tensorflow.train.SequenceExample" +tf_proto { + descriptor { + name: "SequenceExample" + field { + name: "context" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.Features" + } + field { + name: "feature_lists" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.FeatureLists" + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-server-def.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-server-def.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..83ee7b3eb91a558765abcde630fe6e0480b9818f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-server-def.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.train.ServerDef" +tf_proto { + descriptor { + name: "ServerDef" + field { + name: "cluster" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.ClusterDef" + } + field { + name: "job_name" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + field { + name: "task_index" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_INT32 + } + field { + name: "default_session_config" + number: 4 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.ConfigProto" + } + field { + name: "protocol" + number: 5 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-server.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-server.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9b8f185f5b699e860c6fbb50b8d2912984908982 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-server.pbtxt @@ -0,0 +1,29 @@ +path: "tensorflow.train.Server" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "server_def" + mtype: "" + } + member { + name: "target" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'server_or_cluster_def\', \'job_name\', \'task_index\', \'protocol\', \'config\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\'], " + } + member_method { + name: "create_local_server" + argspec: "args=[\'config\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], " + } + member_method { + name: "join" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "start" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-session-creator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-session-creator.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..beb232715f725047dd8c03054b899a90fa81eec2 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-session-creator.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.train.SessionCreator" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "create_session" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-session-manager.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-session-manager.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..448764fe081b250e1e22633f118268ad638cb9dd --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-session-manager.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.train.SessionManager" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\', \'local_init_run_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\', \'None\'], " + } + member_method { + name: "prepare_session" + argspec: "args=[\'self\', \'master\', \'init_op\', \'saver\', \'checkpoint_dir\', \'checkpoint_filename_with_path\', \'wait_for_checkpoint\', \'max_wait_secs\', \'config\', \'init_feed_dict\', \'init_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'7200\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "recover_session" + argspec: "args=[\'self\', \'master\', \'saver\', \'checkpoint_dir\', \'checkpoint_filename_with_path\', \'wait_for_checkpoint\', \'max_wait_secs\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'7200\', \'None\'], " + } + member_method { + name: "wait_for_session" + argspec: "args=[\'self\', \'master\', \'config\', \'max_wait_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'inf\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-session-run-args.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-session-run-args.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..442990893e33c92bd05a72b198a6584bc979b2fe --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-session-run-args.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.train.SessionRunArgs" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "feed_dict" + mtype: "" + } + member { + name: "fetches" + mtype: "" + } + member { + name: "options" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-session-run-context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-session-run-context.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d5adb15c95f8a6ebde4ca0e0c535dfebc5edfbf2 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-session-run-context.pbtxt @@ -0,0 +1,25 @@ +path: "tensorflow.train.SessionRunContext" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "original_args" + mtype: "" + } + member { + name: "session" + mtype: "" + } + member { + name: "stop_requested" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'original_args\', \'session\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "request_stop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-session-run-hook.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-session-run-hook.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..db1aa24acf0e295b4b787eef68250401dd6a6e27 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-session-run-hook.pbtxt @@ -0,0 +1,28 @@ +path: "tensorflow.train.SessionRunHook" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-session-run-values.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-session-run-values.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0b401d59c400f1d08f47daa2d264a9a5bfc91538 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-session-run-values.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.train.SessionRunValues" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "options" + mtype: "" + } + member { + name: "results" + mtype: "" + } + member { + name: "run_metadata" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-singular-monitored-session.-step-context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-singular-monitored-session.-step-context.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..36d8ce7ff82e02300b59705400be40d7cc3f65ae --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-singular-monitored-session.-step-context.pbtxt @@ -0,0 +1,21 @@ +path: "tensorflow.train.SingularMonitoredSession.StepContext" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "session" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'session\', \'run_with_hooks_fn\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "request_stop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "run_with_hooks" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-singular-monitored-session.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-singular-monitored-session.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..de0f2c1c1a2497ef4e541ee6583d416e31f48826 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-singular-monitored-session.pbtxt @@ -0,0 +1,38 @@ +path: "tensorflow.train.SingularMonitoredSession" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "StepContext" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'hooks\', \'scaffold\', \'master\', \'config\', \'checkpoint_dir\', \'stop_grace_period_secs\', \'checkpoint_filename_with_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'\', \'None\', \'None\', \'120\', \'None\'], " + } + member_method { + name: "close" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "raw_session" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "run" + argspec: "args=[\'self\', \'fetches\', \'feed_dict\', \'options\', \'run_metadata\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "run_step_fn" + argspec: "args=[\'self\', \'step_fn\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "should_stop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-step-counter-hook.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-step-counter-hook.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..13261f6dde1cf8e6fd228950600303370947b7ea --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-step-counter-hook.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.train.StepCounterHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'every_n_steps\', \'every_n_secs\', \'output_dir\', \'summary_writer\'], varargs=None, keywords=None, defaults=[\'100\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-stop-at-step-hook.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-stop-at-step-hook.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e388599b0bf63379fa95a3276e3f4859eab86d6d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-stop-at-step-hook.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.train.StopAtStepHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_steps\', \'last_step\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-summary-saver-hook.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-summary-saver-hook.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..697c3667b09f42f208dec38938f5a1ce0cc09029 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-summary-saver-hook.pbtxt @@ -0,0 +1,30 @@ +path: "tensorflow.train.SummarySaverHook" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'save_steps\', \'save_secs\', \'output_dir\', \'summary_writer\', \'scaffold\', \'summary_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "after_create_session" + argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "after_run" + argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "before_run" + argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "begin" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "end" + argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-supervisor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-supervisor.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..9677e5a98e4a8308093f51a84d8b1edae405cd2b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-supervisor.pbtxt @@ -0,0 +1,153 @@ +path: "tensorflow.train.Supervisor" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "USE_DEFAULT" + mtype: "" + } + member { + name: "coord" + mtype: "" + } + member { + name: "global_step" + mtype: "" + } + member { + name: "init_feed_dict" + mtype: "" + } + member { + name: "init_op" + mtype: "" + } + member { + name: "is_chief" + mtype: "" + } + member { + name: "ready_for_local_init_op" + mtype: "" + } + member { + name: "ready_op" + mtype: "" + } + member { + name: "save_model_secs" + mtype: "" + } + member { + name: "save_path" + mtype: "" + } + member { + name: "save_summaries_secs" + mtype: "" + } + member { + name: "saver" + mtype: "" + } + member { + name: "session_manager" + mtype: "" + } + member { + name: "summary_op" + mtype: "" + } + member { + name: "summary_writer" + mtype: "" + } + member_method { + name: "Loop" + argspec: "args=[\'self\', \'timer_interval_secs\', \'target\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "PrepareSession" + argspec: "args=[\'self\', \'master\', \'config\', \'wait_for_checkpoint\', \'max_wait_secs\', \'start_standard_services\'], varargs=None, keywords=None, defaults=[\'\', \'None\', \'False\', \'7200\', \'True\'], " + } + member_method { + name: "RequestStop" + argspec: "args=[\'self\', \'ex\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "ShouldStop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "StartQueueRunners" + argspec: "args=[\'self\', \'sess\', \'queue_runners\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "StartStandardServices" + argspec: "args=[\'self\', \'sess\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "Stop" + argspec: "args=[\'self\', \'threads\', \'close_summary_writer\', \'ignore_live_threads\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'False\'], " + } + member_method { + name: "StopOnException" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "SummaryComputed" + argspec: "args=[\'self\', \'sess\', \'summary\', \'global_step\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "WaitForStop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'graph\', \'ready_op\', \'ready_for_local_init_op\', \'is_chief\', \'init_op\', \'init_feed_dict\', \'local_init_op\', \'logdir\', \'summary_op\', \'saver\', \'global_step\', \'save_summaries_secs\', \'save_model_secs\', \'recovery_wait_secs\', \'stop_grace_secs\', \'checkpoint_basename\', \'session_manager\', \'summary_writer\', \'init_fn\', \'local_init_run_options\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'0\', \'True\', \'0\', \'None\', \'0\', \'None\', \'0\', \'0\', \'0\', \'120\', \'600\', \'30\', \'120\', \'model.ckpt\', \'None\', \'0\', \'None\', \'None\'], " + } + member_method { + name: "loop" + argspec: "args=[\'self\', \'timer_interval_secs\', \'target\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "managed_session" + argspec: "args=[], varargs=args, keywords=kwds, defaults=None" + } + member_method { + name: "prepare_or_wait_for_session" + argspec: "args=[\'self\', \'master\', \'config\', \'wait_for_checkpoint\', \'max_wait_secs\', \'start_standard_services\'], varargs=None, keywords=None, defaults=[\'\', \'None\', \'False\', \'7200\', \'True\'], " + } + member_method { + name: "request_stop" + argspec: "args=[\'self\', \'ex\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "should_stop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "start_queue_runners" + argspec: "args=[\'self\', \'sess\', \'queue_runners\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "start_standard_services" + argspec: "args=[\'self\', \'sess\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "stop" + argspec: "args=[\'self\', \'threads\', \'close_summary_writer\', \'ignore_live_threads\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'False\'], " + } + member_method { + name: "stop_on_exception" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "summary_computed" + argspec: "args=[\'self\', \'sess\', \'summary\', \'global_step\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "wait_for_stop" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-sync-replicas-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-sync-replicas-optimizer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2c0fda3c72b7e1f02265827b9dc1929500935cd1 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-sync-replicas-optimizer.pbtxt @@ -0,0 +1,63 @@ +path: "tensorflow.train.SyncReplicasOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'opt\', \'replicas_to_aggregate\', \'total_num_replicas\', \'variable_averages\', \'variables_to_average\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'sync_replicas\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "get_chief_queue_runner" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_init_tokens_op" + argspec: "args=[\'self\', \'num_tokens\'], varargs=None, keywords=None, defaults=[\'-1\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "make_session_run_hook" + argspec: "args=[\'self\', \'is_chief\', \'num_tokens\'], varargs=None, keywords=None, defaults=[\'-1\'], " + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4ce7cb111163e103a1cebe30d5c6f3eeb4234693 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt @@ -0,0 +1,39 @@ +path: "tensorflow.train.VocabInfo" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "backup_initializer" + mtype: "" + } + member { + name: "new_vocab" + mtype: "" + } + member { + name: "new_vocab_size" + mtype: "" + } + member { + name: "num_oov_buckets" + mtype: "" + } + member { + name: "old_vocab" + mtype: "" + } + member { + name: "old_vocab_size" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-worker-session-creator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-worker-session-creator.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..ac263580687e53bb3fcffd5268f73f8b67aa43a1 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-worker-session-creator.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.train.WorkerSessionCreator" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'scaffold\', \'master\', \'config\', \'max_wait_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\', \'1800\'], " + } + member_method { + name: "create_session" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt similarity index 98% rename from tensorflow/tools/api/golden/tensorflow.train.pbtxt rename to tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt index 5f45b3b1ad9fc04cd9ab306b9a05917ecd54a837..b0fb04d7d4d71e8cb2630ca79284e0ade1db8571 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt @@ -242,7 +242,7 @@ tf_module { } member_method { name: "MonitoredTrainingSession" - argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\', \'save_checkpoint_steps\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'\', \'\', \'\', \'None\', \'120\', \'100\', \'7200\', \'\'], " + argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\', \'save_checkpoint_steps\', \'summary_dir\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'\', \'\', \'\', \'None\', \'120\', \'100\', \'7200\', \'\', \'None\'], " } member_method { name: "NewCheckpointReader" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.-queue-runner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.-queue-runner.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..23d402de30888c1c503a3971cefa1167af3bc8c6 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.-queue-runner.pbtxt @@ -0,0 +1,49 @@ +path: "tensorflow.train.queue_runner.QueueRunner" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "cancel_op" + mtype: "" + } + member { + name: "close_op" + mtype: "" + } + member { + name: "enqueue_ops" + mtype: "" + } + member { + name: "exceptions_raised" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "queue" + mtype: "" + } + member { + name: "queue_closed_exception_types" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'queue\', \'enqueue_ops\', \'close_op\', \'cancel_op\', \'queue_closed_exception_types\', \'queue_runner_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "create_threads" + argspec: "args=[\'self\', \'sess\', \'coord\', \'daemon\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'False\'], " + } + member_method { + name: "from_proto" + argspec: "args=[\'queue_runner_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "to_proto" + argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6e2d04304967dd08d2c389c209dd43c731c5f956 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.queue_runner.pbtxt @@ -0,0 +1,15 @@ +path: "tensorflow.train.queue_runner" +tf_module { + member { + name: "QueueRunner" + mtype: "" + } + member_method { + name: "add_queue_runner" + argspec: "args=[\'qr\', \'collection\'], varargs=None, keywords=None, defaults=[\'queue_runners\'], " + } + member_method { + name: "start_queue_runners" + argspec: "args=[\'sess\', \'coord\', \'daemon\', \'start\', \'collection\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'True\', \'queue_runners\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.truncated_normal_initializer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.truncated_normal_initializer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c1e1c230a9f79e87294eb6038f870726a0ba85a4 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.truncated_normal_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.truncated_normal_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.uniform_unit_scaling_initializer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.uniform_unit_scaling_initializer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e1b18dc92fbee9565dba81e8c09534bea6734f23 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.uniform_unit_scaling_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.uniform_unit_scaling_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'factor\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.variable_scope.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.variable_scope.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e62dec93e6f06a10f48d72b0cda74426887806fb --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.variable_scope.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.variable_scope" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name_or_scope\', \'default_name\', \'values\', \'initializer\', \'regularizer\', \'caching_device\', \'partitioner\', \'custom_getter\', \'reuse\', \'dtype\', \'use_resource\', \'constraint\', \'auxiliary_name_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.variance_scaling_initializer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.variance_scaling_initializer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..09d7bc03b4f238923db6778ec32ce78ae76eed61 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.variance_scaling_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.variance_scaling_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'truncated_normal\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.zeros_initializer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.zeros_initializer.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..e229b02ceec6739974d3b4ae2bb02ef273398c45 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.zeros_initializer.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.zeros_initializer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py index 1cf330e70247260cd9e50b18903bdfecad6260e4..3a48cf683c908021a6a87849601227283a8e2034 100644 --- a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py +++ b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py @@ -88,6 +88,9 @@ def _SanitizedMRO(obj): """ return_list = [] for cls in tf_inspect.getmro(obj): + if cls.__name__ == '_NewClass': + # Ignore class created by @deprecated_alias decorator. + continue str_repr = str(cls) return_list.append(str_repr) if 'tensorflow' not in str_repr: diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD index 724b12cd4799eb76fe602c737c850e96e92faa58..8764409e4d1af4ea7f6092b9df64f59511bca43d 100644 --- a/tensorflow/tools/api/tests/BUILD +++ b/tensorflow/tools/api/tests/BUILD @@ -17,7 +17,8 @@ py_test( name = "api_compatibility_test", srcs = ["api_compatibility_test.py"], data = [ - "//tensorflow/tools/api/golden:api_golden", + "//tensorflow/tools/api/golden:api_golden_v1", + "//tensorflow/tools/api/golden:api_golden_v2", "//tensorflow/tools/api/tests:API_UPDATE_WARNING.txt", "//tensorflow/tools/api/tests:README.txt", ], diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py index 90375a794f64a9edd2bab2671f5870ae02e84e3c..b65dbc4b7dfeca0d0e65a4b68776814180aee557 100644 --- a/tensorflow/tools/api/tests/api_compatibility_test.py +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -47,7 +47,6 @@ from tensorflow.tools.api.lib import python_object_to_proto_visitor from tensorflow.tools.common import public_api from tensorflow.tools.common import traverse - # FLAGS defined at the bottom: FLAGS = None # DEFINE_boolean, update_goldens, default False: @@ -62,19 +61,25 @@ _VERBOSE_DIFFS_HELP = """ false, only print which libraries have differences. """ -_API_GOLDEN_FOLDER = 'tensorflow/tools/api/golden' +_API_GOLDEN_FOLDER_V1 = 'tensorflow/tools/api/golden/v1' +_API_GOLDEN_FOLDER_V2 = 'tensorflow/tools/api/golden/v2' _TEST_README_FILE = 'tensorflow/tools/api/tests/README.txt' _UPDATE_WARNING_FILE = 'tensorflow/tools/api/tests/API_UPDATE_WARNING.txt' -def _KeyToFilePath(key): - """From a given key, construct a filepath.""" +def _KeyToFilePath(key, api_version): + """From a given key, construct a filepath. + + Filepath will be inside golden folder for api_version. + """ def _ReplaceCapsWithDash(matchobj): match = matchobj.group(0) return '-%s' % (match.lower()) case_insensitive_key = re.sub('([A-Z]{1})', _ReplaceCapsWithDash, key) - return os.path.join(_API_GOLDEN_FOLDER, '%s.pbtxt' % case_insensitive_key) + api_folder = ( + _API_GOLDEN_FOLDER_V2 if api_version == 2 else _API_GOLDEN_FOLDER_V1) + return os.path.join(_API_GOLDEN_FOLDER_V1, '%s.pbtxt' % case_insensitive_key) def _FileNameToKey(filename): @@ -90,6 +95,21 @@ def _FileNameToKey(filename): return api_object_key +def _VerifyNoSubclassOfMessageVisitor(path, parent, unused_children): + """A Visitor that crashes on subclasses of generated proto classes.""" + # If the traversed object is a proto Message class + if not (isinstance(parent, type) and + issubclass(parent, message.Message)): + return + if parent is message.Message: + return + # Check that it is a direct subclass of Message. + if message.Message not in parent.__bases__: + raise NotImplementedError( + 'Object tf.%s is a subclass of a generated proto Message. ' + 'They are not yet supported by the API tools.' % path) + + class ApiCompatibilityTest(test.TestCase): def __init__(self, *args, **kwargs): @@ -112,7 +132,8 @@ class ApiCompatibilityTest(test.TestCase): actual_dict, verbose=False, update_goldens=False, - additional_missing_object_message=''): + additional_missing_object_message='', + api_version=2): """Diff given dicts of protobufs and report differences a readable way. Args: @@ -125,6 +146,7 @@ class ApiCompatibilityTest(test.TestCase): update_goldens: Whether to update goldens when there are diffs found. additional_missing_object_message: Message to print when a symbol is missing. + api_version: TensorFlow API version to test. """ diffs = [] verbose_diffs = [] @@ -150,6 +172,8 @@ class ApiCompatibilityTest(test.TestCase): diff_message = 'New object %s found (added).' % key verbose_diff_message = diff_message else: + # Do not truncate diff + self.maxDiffs = None # pylint: disable=invalid-name # Now we can run an actual proto diff. try: self.assertProtoEquals(expected_dict[key], actual_dict[key]) @@ -180,13 +204,13 @@ class ApiCompatibilityTest(test.TestCase): # If the keys are only in expected, some objects are deleted. # Remove files. for key in only_in_expected: - filepath = _KeyToFilePath(key) + filepath = _KeyToFilePath(key, api_version) file_io.delete_file(filepath) # If the files are only in actual (current library), these are new # modules. Write them to files. Also record all updates in files. for key in only_in_actual | set(updated_keys): - filepath = _KeyToFilePath(key) + filepath = _KeyToFilePath(key, api_version) file_io.write_string_to_file( filepath, text_format.MessageToString(actual_dict[key])) else: @@ -197,43 +221,44 @@ class ApiCompatibilityTest(test.TestCase): logging.info('No differences found between API and golden.') def testNoSubclassOfMessage(self): - - def Visit(path, parent, unused_children): - """A Visitor that crashes on subclasses of generated proto classes.""" - # If the traversed object is a proto Message class - if not (isinstance(parent, type) and - issubclass(parent, message.Message)): - return - if parent is message.Message: - return - # Check that it is a direct subclass of Message. - if message.Message not in parent.__bases__: - raise NotImplementedError( - 'Object tf.%s is a subclass of a generated proto Message. ' - 'They are not yet supported by the API tools.' % path) - visitor = public_api.PublicAPIVisitor(Visit) + visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor) visitor.do_not_descend_map['tf'].append('contrib') + # Skip compat.v1 and compat.v2 since they are validated in separate tests. + visitor.private_map['tf.compat'] = ['v1', 'v2'] traverse.traverse(tf, visitor) - @unittest.skipUnless( - sys.version_info.major == 2, - 'API compabitility test goldens are generated using python2.') - def testAPIBackwardsCompatibility(self): + def testNoSubclassOfMessageV1(self): + if not hasattr(tf.compat, 'v1'): + return + visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor) + visitor.do_not_descend_map['tf'].append('contrib') + traverse.traverse(tf.compat.v1, visitor) + + def testNoSubclassOfMessageV2(self): + if not hasattr(tf.compat, 'v2'): + return + visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor) + visitor.do_not_descend_map['tf'].append('contrib') + traverse.traverse(tf.compat.v2, visitor) + + def _checkBackwardsCompatibility( + self, root, golden_file_pattern, api_version, + additional_private_map=None): # Extract all API stuff. visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor() public_api_visitor = public_api.PublicAPIVisitor(visitor) public_api_visitor.do_not_descend_map['tf'].append('contrib') - public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental'] - traverse.traverse(tf, public_api_visitor) + public_api_visitor.do_not_descend_map['tf.GPUOptions'] = [ + 'Experimental'] + if additional_private_map: + public_api_visitor.private_map.update(additional_private_map) + traverse.traverse(root, public_api_visitor) proto_dict = visitor.GetProtos() # Read all golden files. - expression = os.path.join( - resource_loader.get_root_dir_with_all_resources(), - _KeyToFilePath('*')) - golden_file_list = file_io.get_matching_files(expression) + golden_file_list = file_io.get_matching_files(golden_file_pattern) def _ReadFileToProto(filename): """Read a filename, create a protobuf from its contents.""" @@ -252,7 +277,50 @@ class ApiCompatibilityTest(test.TestCase): golden_proto_dict, proto_dict, verbose=FLAGS.verbose_diffs, - update_goldens=FLAGS.update_goldens) + update_goldens=FLAGS.update_goldens, + api_version=api_version) + + @unittest.skipUnless( + sys.version_info.major == 2, + 'API compabitility test goldens are generated using python2.') + def testAPIBackwardsCompatibility(self): + api_version = 1 + golden_file_pattern = os.path.join( + resource_loader.get_root_dir_with_all_resources(), + _KeyToFilePath('*', api_version)) + self._checkBackwardsCompatibility( + tf, + golden_file_pattern, + api_version, + # Skip compat.v1 and compat.v2 since they are validated + # in separate tests. + additional_private_map={'tf.compat': ['v1', 'v2']}) + + @unittest.skipUnless( + sys.version_info.major == 2, + 'API compabitility test goldens are generated using python2.') + def testAPIBackwardsCompatibilityV1(self): + if not hasattr(tf.compat, 'v1'): + return + api_version = 1 + golden_file_pattern = os.path.join( + resource_loader.get_root_dir_with_all_resources(), + _KeyToFilePath('*', api_version)) + self._checkBackwardsCompatibility( + tf.compat.v1, golden_file_pattern, api_version) + + @unittest.skipUnless( + sys.version_info.major == 2, + 'API compabitility test goldens are generated using python2.') + def testAPIBackwardsCompatibilityV2(self): + if not hasattr(tf.compat, 'v2'): + return + api_version = 1 + golden_file_pattern = os.path.join( + resource_loader.get_root_dir_with_all_resources(), + _KeyToFilePath('*', api_version)) + self._checkBackwardsCompatibility( + tf.compat.v2, golden_file_pattern, api_version) if __name__ == '__main__': diff --git a/tensorflow/tools/ci_build/Dockerfile.cmake b/tensorflow/tools/ci_build/Dockerfile.cmake index d5dea4f3e41841aed5aeac02fcca850dbfdfaeb3..4587bcf89103c48e39bb8cc3188391ea99941b3e 100644 --- a/tensorflow/tools/ci_build/Dockerfile.cmake +++ b/tensorflow/tools/ci_build/Dockerfile.cmake @@ -28,6 +28,8 @@ RUN pip install --upgrade astor RUN pip install --upgrade gast RUN pip install --upgrade numpy RUN pip install --upgrade termcolor +RUN pip install keras_applications==1.0.4 +RUN pip install keras_preprocessing==1.0.2 # Install golang RUN apt-get install -t xenial-backports -y golang-1.9 diff --git a/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le b/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le new file mode 100644 index 0000000000000000000000000000000000000000..ada2c63880972b3fb9cf525becdf8aae2c248e5f --- /dev/null +++ b/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le @@ -0,0 +1,20 @@ +FROM ubuntu:16.04 + +LABEL maintainer="William Irons " + +# Copy and run the install scripts. +COPY install/*.sh /install/ +RUN /install/install_bootstrap_deb_packages.sh +RUN add-apt-repository -y ppa:openjdk-r/ppa +RUN /install/install_deb_packages.sh +RUN /install/install_openblas_ppc64le.sh +RUN /install/install_hdf5_ppc64le.sh +RUN /install/install_pip_packages.sh +RUN /install/install_bazel_from_source.sh +RUN /install/install_proto3.sh +RUN /install/install_buildifier_from_source.sh +RUN /install/install_auditwheel.sh +RUN /install/install_golang_ppc64le.sh + +# Set up the master bazelrc configuration file. +COPY install/.bazelrc /etc/bazel.bazelrc diff --git a/tensorflow/tools/ci_build/Dockerfile.gpu b/tensorflow/tools/ci_build/Dockerfile.gpu index 7591ecc04efa887ec1d35ba92881386f5a25241d..383f9545c9fc47c6f2c0213a2c07af48085461a3 100644 --- a/tensorflow/tools/ci_build/Dockerfile.gpu +++ b/tensorflow/tools/ci_build/Dockerfile.gpu @@ -14,6 +14,7 @@ RUN /install/install_bootstrap_deb_packages.sh RUN add-apt-repository -y ppa:openjdk-r/ppa && \ add-apt-repository -y ppa:george-edison55/cmake-3.x RUN /install/install_deb_packages.sh + RUN /install/install_pip_packages.sh RUN /install/install_bazel.sh RUN /install/install_golang.sh @@ -22,6 +23,11 @@ RUN /install/install_golang.sh COPY install/.bazelrc /etc/bazel.bazelrc ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH +# Link NCCL libray and header where the build script expects them. +RUN mkdir /usr/local/cuda-9.0/lib && \ + ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/local/cuda/lib/libnccl.so.2 && \ + ln -s /usr/include/nccl.h /usr/local/cuda/include/nccl.h + # Configure the build for our CUDA configuration. ENV TF_NEED_CUDA 1 ENV TF_CUDA_COMPUTE_CAPABILITIES 3.0 diff --git a/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le b/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le new file mode 100644 index 0000000000000000000000000000000000000000..e026edb6bb7c946dfd318053b034c796f815b671 --- /dev/null +++ b/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le @@ -0,0 +1,31 @@ +FROM nvidia/cuda-ppc64le:9.0-cudnn7-devel-ubuntu16.04 + +LABEL maintainer="William Irons " + +# In the Ubuntu 16.04 images, cudnn is placed in system paths. Move them to +# /usr/local/cuda +RUN cp -P /usr/include/cudnn.h /usr/local/cuda/include +RUN cp -P /usr/lib/powerpc64le-linux-gnu/libcudnn* /usr/local/cuda/lib64 + +# Copy and run the install scripts. +COPY install/*.sh /install/ +ARG DEBIAN_FRONTEND=noninteractive +RUN /install/install_bootstrap_deb_packages.sh +RUN add-apt-repository -y ppa:openjdk-r/ppa +RUN /install/install_deb_packages.sh +RUN /install/install_openblas_ppc64le.sh +RUN /install/install_hdf5_ppc64le.sh +RUN /install/install_pip_packages.sh +RUN /install/install_bazel_from_source.sh +RUN /install/install_golang_ppc64le.sh + +# Set up the master bazelrc configuration file. +COPY install/.bazelrc /etc/bazel.bazelrc +ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH + +# Configure the build for our CUDA configuration. +ENV TF_NEED_CUDA 1 +ENV TF_CUDA_COMPUTE_CAPABILITIES 3.0 + +# TODO get NCCL 2 in the docker image +ENV TF_NCCL_VERSION 1 diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cpu b/tensorflow/tools/ci_build/Dockerfile.rbe.cpu index 3bc52b9ed611a0f0a4a269a2864d5b349ee9232c..7e5860aeec186d908e5d2884bd690b2e5e43cffa 100644 --- a/tensorflow/tools/ci_build/Dockerfile.rbe.cpu +++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cpu @@ -1,4 +1,4 @@ -FROM launcher.gcr.io/google/rbe-debian8:r327695 +FROM launcher.gcr.io/google/rbe-ubuntu16-04:r327695 LABEL maintainer="Yu Yi " # Copy install scripts @@ -9,6 +9,6 @@ ENV CC /usr/local/bin/clang ENV CXX /usr/local/bin/clang++ ENV AR /usr/bin/ar -# Run pip install script for RBE Debian8 container. +# Run pip install script for RBE Ubuntu 16-04 container. RUN /install/install_pip_packages_remote.sh RUN /install/install_pip_packages.sh diff --git a/tensorflow/tools/ci_build/builds/android.sh b/tensorflow/tools/ci_build/builds/android.sh index d81793efe08f151c1b448a9da3cc971ca3137829..7c3e30822952f10ec9ff12b61faa783510508adf 100755 --- a/tensorflow/tools/ci_build/builds/android.sh +++ b/tensorflow/tools/ci_build/builds/android.sh @@ -26,13 +26,19 @@ configure_android_workspace # android_full.sh echo "========== TensorFlow Demo Build Test ==========" +TARGETS= +TARGETS+=" //tensorflow/examples/android:tensorflow_demo" +# Also build the Eager Runtime so it remains compatible with Android for the +# benefits of clients like TensorFlow Lite. For now it is enough to build only +# :execute, which what TF Lite needs. +TARGETS+=" //tensorflow/core/common_runtime/eager:execute" # Enable sandboxing so that zip archives don't get incorrectly packaged # in assets/ dir (see https://github.com/bazelbuild/bazel/issues/2334) # TODO(gunan): remove extra flags once sandboxing is enabled for all builds. bazel --bazelrc=/dev/null build \ --compilation_mode=opt --cxxopt=-std=c++11 --fat_apk_cpu=x86_64 \ --spawn_strategy=sandboxed --genrule_strategy=sandboxed \ - //tensorflow/examples/android:tensorflow_demo + ${TARGETS} echo "========== Makefile Build Test ==========" # Test Makefile build just to make sure it still works. diff --git a/tensorflow/tools/ci_build/builds/pip.sh b/tensorflow/tools/ci_build/builds/pip.sh index 883bb9364742e74b4a5a7c8b0d41253352d6c2e7..fef121ab5aaea461b4cbf88ce8c6d4f81718e377 100755 --- a/tensorflow/tools/ci_build/builds/pip.sh +++ b/tensorflow/tools/ci_build/builds/pip.sh @@ -314,7 +314,10 @@ create_activate_virtualenv_and_install_tensorflow() { # Upgrade pip so it supports tags such as cp27mu, manylinux1 etc. echo "Upgrade pip in virtualenv" - pip install --upgrade pip==9.0.1 + + # NOTE: pip install --upgrade pip leads to a documented TLS issue for + # some versions in python + curl https://bootstrap.pypa.io/get-pip.py | python # Force tensorflow reinstallation. Otherwise it may not get installed from # last build if it had the same version number as previous build. diff --git a/tensorflow/tools/ci_build/builds/run_pip_tests.sh b/tensorflow/tools/ci_build/builds/run_pip_tests.sh index 29680e6882371d7917b446d01f0640dbdfa1b56f..bbaf59c69aa358ce6dd4696049a16f11d0c61c2f 100755 --- a/tensorflow/tools/ci_build/builds/run_pip_tests.sh +++ b/tensorflow/tools/ci_build/builds/run_pip_tests.sh @@ -97,7 +97,8 @@ fi # TF_BUILD_APPEND_ARGUMENTS any user supplied args. BAZEL_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py \ --build_tests_only -k --test_tag_filters=${PIP_TEST_FILTER_TAG} \ - --test_timeout 300,450,1200,3600 ${TF_BUILD_APPEND_ARGUMENTS}" + --test_timeout 300,450,1200,3600 ${TF_BUILD_APPEND_ARGUMENTS} \ + --test_output=errors" BAZEL_TEST_TARGETS="//${PIP_TEST_PREFIX}/tensorflow/contrib/... \ //${PIP_TEST_PREFIX}/tensorflow/python/... \ diff --git a/tensorflow/tools/ci_build/ci_build.sh b/tensorflow/tools/ci_build/ci_build.sh index 1f0fd0387af28bf15e5c42fa14f5c1a1ee5a8cfb..77265e0f50bb2c17c9fac76c710ba8bb8559bd7e 100755 --- a/tensorflow/tools/ci_build/ci_build.sh +++ b/tensorflow/tools/ci_build/ci_build.sh @@ -79,7 +79,7 @@ if [[ "${CONTAINER_TYPE}" == "cmake" ]]; then fi # Use nvidia-docker if the container is GPU. -if [[ "${CONTAINER_TYPE}" == "gpu" ]]; then +if [[ "${CONTAINER_TYPE}" == gpu* ]]; then DOCKER_BINARY="nvidia-docker" else DOCKER_BINARY="docker" @@ -99,7 +99,7 @@ BUILD_TAG="${BUILD_TAG:-tf_ci}" # Add extra params for cuda devices and libraries for GPU container. # And clear them if we are not building for GPU. -if [[ "${CONTAINER_TYPE}" != "gpu" ]]; then +if [[ "${CONTAINER_TYPE}" != gpu* ]]; then GPU_EXTRA_PARAMS="" fi @@ -115,6 +115,7 @@ DOCKER_IMG_NAME=$(echo "${DOCKER_IMG_NAME}" | tr '[:upper:]' '[:lower:]') # Print arguments. echo "WORKSPACE: ${WORKSPACE}" +echo "CI_DOCKER_BUILD_EXTRA_PARAMS: ${CI_DOCKER_BUILD_EXTRA_PARAMS[*]}" echo "CI_DOCKER_EXTRA_PARAMS: ${CI_DOCKER_EXTRA_PARAMS[*]}" echo "COMMAND: ${COMMAND[*]}" echo "CI_COMMAND_PREFIX: ${CI_COMMAND_PREFIX[*]}" @@ -126,7 +127,7 @@ echo "" # Build the docker container. echo "Building container (${DOCKER_IMG_NAME})..." -docker build -t ${DOCKER_IMG_NAME} \ +docker build -t ${DOCKER_IMG_NAME} ${CI_DOCKER_BUILD_EXTRA_PARAMS[@]} \ -f "${DOCKERFILE_PATH}" "${DOCKER_CONTEXT_PATH}" # Check docker build status diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh index 90bd8bc3d0c349b5df51535cadc3b1d85d76b7d0..993894d658e63706174105211fc433ecd53d5fe7 100755 --- a/tensorflow/tools/ci_build/ci_parameterized_build.sh +++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh @@ -59,6 +59,9 @@ # TF_BUILD_BAZEL_CLEAN: # Will perform "bazel clean", if and only if this variable # is set to any non-empty and non-0 value +# TF_BAZEL_BUILD_ONLY: +# If it is set to any non-empty value that is not "0", Bazel +# will only build specified targets # TF_GPU_COUNT: # Run this many parallel tests for serial builds. # For now, only can be edited for PIP builds. @@ -128,7 +131,7 @@ BAZEL_CMD="bazel test" BAZEL_BUILD_ONLY_CMD="bazel build" BAZEL_CLEAN_CMD="bazel clean" -DEFAULT_BAZEL_CONFIGS="--config=gcp --config=hdfs" +DEFAULT_BAZEL_CONFIGS="" PIP_CMD="${CI_BUILD_DIR}/builds/pip.sh" PIP_TEST_TUTORIALS_FLAG="--test_tutorials" @@ -147,36 +150,7 @@ BAZEL_TARGET="//tensorflow/... -//tensorflow/compiler/..." if [[ -n "$TF_SKIP_CONTRIB_TESTS" ]]; then BAZEL_TARGET="$BAZEL_TARGET -//tensorflow/contrib/..." else - BAZEL_TARGET="${BAZEL_TARGET} -//tensorflow/contrib/lite/..." - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:context_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:framework" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:interpreter_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:model_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/toco:toco" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:simple_memory_arena_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:string_util_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:activations_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:add_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:basic_rnn_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:concatenation_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:conv_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:depthwise_conv_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:embedding_lookup_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:embedding_lookup_sparse_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:fully_connected_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:hashtable_lookup_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:local_response_norm_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:lsh_projection_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:lstm_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:l2norm_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:mul_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:pooling_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:reshape_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:resize_bilinear_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:skip_gram_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:softmax_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:space_to_depth_test" - BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:svdf_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/..." fi TUT_TEST_DATA_DIR="/tmp/tf_tutorial_test_data" @@ -258,9 +232,9 @@ function set_script_variable() { # Process container type -if [[ ${CTYPE} == "cpu" ]] || [[ ${CTYPE} == "debian.jessie.cpu" ]]; then +if [[ ${CTYPE} == cpu* ]] || [[ ${CTYPE} == "debian.jessie.cpu" ]]; then : -elif [[ ${CTYPE} == "gpu" ]]; then +elif [[ ${CTYPE} == gpu* ]]; then set_script_variable TF_NEED_CUDA 1 if [[ $TF_CUDA_CLANG == "1" ]]; then @@ -410,6 +384,11 @@ fi # this flag, and it only affects a few tests. EXTRA_ARGS="${EXTRA_ARGS} --distinct_host_configuration=false" +if [[ ! -z "${TF_BAZEL_BUILD_ONLY}" ]] && + [[ "${TF_BAZEL_BUILD_ONLY}" != "0" ]];then + BAZEL_CMD=${BAZEL_BUILD_ONLY_CMD} +fi + # Process PIP install-test option if [[ ${TF_BUILD_IS_PIP} == "no_pip" ]] || [[ ${TF_BUILD_IS_PIP} == "both" ]]; then @@ -418,12 +397,12 @@ if [[ ${TF_BUILD_IS_PIP} == "no_pip" ]] || BAZEL_TARGET=${TF_BUILD_BAZEL_TARGET} fi - if [[ ${CTYPE} == "cpu" ]] || \ + if [[ ${CTYPE} == cpu* ]] || \ [[ ${CTYPE} == "debian.jessie.cpu" ]]; then # CPU only command, fully parallel. NO_PIP_MAIN_CMD="${MAIN_CMD} ${BAZEL_CMD} ${OPT_FLAG} ${EXTRA_ARGS} -- "\ "${BAZEL_TARGET}" - elif [[ ${CTYPE} == "gpu" ]]; then + elif [[ ${CTYPE} == gpu* ]]; then # GPU only command, run as many jobs as the GPU count only. NO_PIP_MAIN_CMD="${BAZEL_CMD} ${OPT_FLAG} "\ "--local_test_jobs=${TF_GPU_COUNT} "\ @@ -562,33 +541,35 @@ echo "" TMP_DIR="" DOCKERFILE_FLAG="" -if [[ "${TF_BUILD_PYTHON_VERSION}" == "python3.5" ]] || - [[ "${TF_BUILD_PYTHON_VERSION}" == "python3.6" ]]; then - # Modify Dockerfile for Python3.5 | Python3.6 build - TMP_DIR=$(mktemp -d) - echo "Docker build will occur in temporary directory: ${TMP_DIR}" - - # Copy the files required for the docker build - SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" - cp -r "${SCRIPT_DIR}/install" "${TMP_DIR}/install" || \ - die "ERROR: Failed to copy directory ${SCRIPT_DIR}/install" - - DOCKERFILE="${SCRIPT_DIR}/Dockerfile.${TF_BUILD_CONTAINER_TYPE}" - cp "${DOCKERFILE}" "${TMP_DIR}/" || \ - die "ERROR: Failed to copy Dockerfile at ${DOCKERFILE}" - DOCKERFILE="${TMP_DIR}/Dockerfile.${TF_BUILD_CONTAINER_TYPE}" - - # Replace a line in the Dockerfile - if sed -i \ - "s/RUN \/install\/install_pip_packages.sh/RUN \/install\/install_${TF_BUILD_PYTHON_VERSION}_pip_packages.sh/g" \ - "${DOCKERFILE}" - then - echo "Copied and modified Dockerfile for ${TF_BUILD_PYTHON_VERSION} build: ${DOCKERFILE}" - else - die "ERROR: Faild to copy and modify Dockerfile: ${DOCKERFILE}" - fi +if [[ "${DO_DOCKER}" == "1" ]]; then + if [[ "${TF_BUILD_PYTHON_VERSION}" == "python3.5" ]] || + [[ "${TF_BUILD_PYTHON_VERSION}" == "python3.6" ]]; then + # Modify Dockerfile for Python3.5 | Python3.6 build + TMP_DIR=$(mktemp -d) + echo "Docker build will occur in temporary directory: ${TMP_DIR}" + + # Copy the files required for the docker build + SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + cp -r "${SCRIPT_DIR}/install" "${TMP_DIR}/install" || \ + die "ERROR: Failed to copy directory ${SCRIPT_DIR}/install" + + DOCKERFILE="${SCRIPT_DIR}/Dockerfile.${TF_BUILD_CONTAINER_TYPE}" + cp "${DOCKERFILE}" "${TMP_DIR}/" || \ + die "ERROR: Failed to copy Dockerfile at ${DOCKERFILE}" + DOCKERFILE="${TMP_DIR}/Dockerfile.${TF_BUILD_CONTAINER_TYPE}" + + # Replace a line in the Dockerfile + if sed -i \ + "s/RUN \/install\/install_pip_packages.sh/RUN \/install\/install_${TF_BUILD_PYTHON_VERSION}_pip_packages.sh/g" \ + "${DOCKERFILE}" + then + echo "Copied and modified Dockerfile for ${TF_BUILD_PYTHON_VERSION} build: ${DOCKERFILE}" + else + die "ERROR: Faild to copy and modify Dockerfile: ${DOCKERFILE}" + fi - DOCKERFILE_FLAG="--dockerfile ${DOCKERFILE}" + DOCKERFILE_FLAG="--dockerfile ${DOCKERFILE}" + fi fi chmod +x ${TMP_SCRIPT} diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index 05676f9551d4a1e0cb55d0693f99e458381887df..866fe95d2b4b358b63b14b8744eb631a58e18b49 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -349,12 +349,12 @@ do_external_licenses_check(){ # Blacklist echo ${MISSING_LICENSES_FILE} - grep -e "@bazel_tools//third_party/" -e "@com_google_absl//absl" -e "@org_tensorflow//" -v ${MISSING_LICENSES_FILE} > temp.txt + grep -e "@bazel_tools//third_party/" -e "@com_google_absl//absl" -e "@org_tensorflow//" -e "@com_github_googlecloudplatform_google_cloud_cpp//google" -v ${MISSING_LICENSES_FILE} > temp.txt mv temp.txt ${MISSING_LICENSES_FILE} # Whitelist echo ${EXTRA_LICENSE_FILE} - grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -v ${EXTRA_LICENSES_FILE} > temp.txt + grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -e "@com_github_googlecloudplatform_google_cloud_cpp//" -e "@embedded_jdk//" -v ${EXTRA_LICENSES_FILE} > temp.txt mv temp.txt ${EXTRA_LICENSES_FILE} @@ -543,7 +543,7 @@ SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_check_futures_test" "d SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "Check that python files have certain __future__ imports" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package" "Check load py_test: Check that BUILD files with py_test target properly load py_test" "Code Link Check: Check there are no broken links" "Test entries in /tensorflow/contrib/cmake/python_{modules|protos|protos_cc}.txt for validity and consistency" "Check file names for cases") INCREMENTAL_FLAG="" -DEFAULT_BAZEL_CONFIGS="--config=hdfs --config=gcp" +DEFAULT_BAZEL_CONFIGS="" # Parse command-line arguments BAZEL_FLAGS=${DEFAULT_BAZEL_CONFIGS} diff --git a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh index d0816c92b7308a1079579e605ee9af491a0533fb..75da9bb8356db08c7b9570db673a30ae850e129e 100755 --- a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh +++ b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh @@ -35,6 +35,30 @@ elif [[ ${BASH_VER_MAJOR} -eq 4 ]] && [[ ${BASH_VER_MINOR} -lt 2 ]]; then exit 1 fi +function is_absolute { + [[ "$1" = /* ]] || [[ "$1" =~ ^[a-zA-Z]:[/\\].* ]] +} + +RUNFILES_MANIFEST_FILE="${TEST_SRCDIR}/MANIFEST" +function rlocation() { + if is_absolute "$1" ; then + # If the file path is already fully specified, simply return it. + echo "$1" + elif [[ -e "$TEST_SRCDIR/$1" ]]; then + # If the file exists in the $TEST_SRCDIR then just use it. + echo "$TEST_SRCDIR/$1" + elif [[ -e "$RUNFILES_MANIFEST_FILE" ]]; then + # If a runfiles manifest file exists then use it. + echo "$(grep "^$1 " "$RUNFILES_MANIFEST_FILE" | sed 's/[^ ]* //')" + fi +} + +TEST_BINARY="$(rlocation $TEST_WORKSPACE/${1#./})" +shift + +# Make sure /var/lock exists, this may not be true under MSYS +mkdir -p /var/lock + TF_GPU_COUNT=${TF_GPU_COUNT:-8} for i in `seq 0 $((TF_GPU_COUNT-1))`; do @@ -45,8 +69,8 @@ for i in `seq 0 $((TF_GPU_COUNT-1))`; do # This export only works within the brackets, so it is isolated to one # single command. export CUDA_VISIBLE_DEVICES=$i - echo "Running test $* on GPU $CUDA_VISIBLE_DEVICES" - $@ + echo "Running test $TEST_BINARY $* on GPU $CUDA_VISIBLE_DEVICES" + "$TEST_BINARY" $@ ) return_code=$? flock -u "$lock_fd" diff --git a/tensorflow/tools/ci_build/install/install_bazel.sh b/tensorflow/tools/ci_build/install/install_bazel.sh index 3e27a94cf2bf3110ac181d6ef5a57366be17255f..e284401b8aa469ebcbed856cd09dd597be242d7a 100755 --- a/tensorflow/tools/ci_build/install/install_bazel.sh +++ b/tensorflow/tools/ci_build/install/install_bazel.sh @@ -15,7 +15,7 @@ # ============================================================================== # Select bazel version. -BAZEL_VERSION="0.11.0" +BAZEL_VERSION="0.15.0" set +e local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}') diff --git a/tensorflow/tools/ci_build/install/install_bazel_from_source.sh b/tensorflow/tools/ci_build/install/install_bazel_from_source.sh new file mode 100755 index 0000000000000000000000000000000000000000..87be81577d0efb395a12afc85109f10ad4178c27 --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_bazel_from_source.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# This script is to be used to install bzel on non x86_64 systems +# It will compile bazel from source and install it in /usr/local/bin + +# Select bazel version. +BAZEL_VERSION="0.15.0" + +set +e +local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}') + +if [[ "$local_bazel_ver" == "$BAZEL_VERSION" ]]; then + exit 0 +fi + +set -e + +# Compile bazel from source +mkdir -p /bazel +cd /bazel + +curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-dist.zip +unzip bazel-$BAZEL_VERSION-dist.zip +bash ./compile.sh +cp output/bazel /usr/local/bin/ +rm -rf /bazel diff --git a/tensorflow/tools/ci_build/install/install_buildifier_from_source.sh b/tensorflow/tools/ci_build/install/install_buildifier_from_source.sh new file mode 100755 index 0000000000000000000000000000000000000000..a93c258fad1ca62b0c95f22560110ba231aa0053 --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_buildifier_from_source.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e +BUILDTOOLS_VERSION="0.11.1" + +# Clone buildtools +git clone -b $BUILDTOOLS_VERSION https://github.com/bazelbuild/buildtools +cd buildtools + +# Build buildifier +bazel build //buildifier +sudo mv bazel-bin/buildifier/linux*stripped/buildifier /usr/local/bin + +# Build buildozer +bazel build //buildozer +sudo mv bazel-bin/buildozer/linux*stripped/buildozer /usr/local/bin diff --git a/tensorflow/tools/ci_build/install/install_golang_ppc64le.sh b/tensorflow/tools/ci_build/install/install_golang_ppc64le.sh new file mode 100755 index 0000000000000000000000000000000000000000..47d23a59b3ee9152ef9812fbe939e20ee7c2b40a --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_golang_ppc64le.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -ex + +GOLANG_URL="https://storage.googleapis.com/golang/go1.10.linux-ppc64le.tar.gz" + +sudo mkdir -p /usr/local +wget -q -O - "${GOLANG_URL}" | sudo tar -C /usr/local -xz diff --git a/tensorflow/tools/ci_build/install/install_hdf5_ppc64le.sh b/tensorflow/tools/ci_build/install/install_hdf5_ppc64le.sh new file mode 100755 index 0000000000000000000000000000000000000000..4989d986b8eb0690f63ecff41f7107371724bc3a --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_hdf5_ppc64le.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +#This is required because pypi doesn't have a pre-built h5py binary for ppc64le +#It has to be compiled from source during the install +apt-get update +apt-get install -y libhdf5-dev + +#h5py is not expecting the shared libraries to have _serial in the name. +ln -s /usr/lib/powerpc64le-linux-gnu/libhdf5_serial.so /usr/lib/powerpc64le-linux-gnu/libhdf5.so +ln -s /usr/lib/powerpc64le-linux-gnu/libhdf5_serial_hl.so /usr/lib/powerpc64le-linux-gnu/libhdf5_hl.so + +#pip is not installed yet, so use easy_install +#CPATH is the location of hdf5.h +CPATH=/usr/include/hdf5/serial/ easy_install -U h5py +CPATH=/usr/include/hdf5/serial/ easy_install3 -U h5py diff --git a/tensorflow/tools/ci_build/install/install_openblas_ppc64le.sh b/tensorflow/tools/ci_build/install/install_openblas_ppc64le.sh new file mode 100755 index 0000000000000000000000000000000000000000..107cc61ff5aba222dfd49ae8935b7234df4da169 --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_openblas_ppc64le.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +OPENBLAS_SRC_PATH=/tmp/openblas_src/ +POWER="POWER8" +USE_OPENMP="USE_OPENMP=1" +OPENBLAS_INSTALL_PATH="/usr" +apt-get update +apt-get install -y gfortran gfortran-5 +rm -rf ${OPENBLAS_SRC_PATH} +git clone -b release-0.3.0 https://github.com/xianyi/OpenBLAS ${OPENBLAS_SRC_PATH} +cd ${OPENBLAS_SRC_PATH} +# Pick up fix for OpenBLAS issue 1571 +git cherry-pick -X theirs 961d25e9c7e4a1758adb1dbeaa15187de69dd052 +make TARGET=${POWER} ${USE_OPENMP} FC=gfortran +make PREFIX=${OPENBLAS_INSTALL_PATH} install diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh index b3d3f23ec8fd16c78b525159abeeb22776e3eac5..bb316ecfc92e41a28a3fbfaf5a12f234c0483d5c 100755 --- a/tensorflow/tools/ci_build/install/install_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh @@ -51,8 +51,8 @@ pip2 install --upgrade markdown==2.6.8 pip3 install --upgrade markdown==2.6.8 # Install protobuf. -pip2 install --upgrade protobuf==3.3.0 -pip3 install --upgrade protobuf==3.3.0 +pip2 install --upgrade protobuf==3.6.0 +pip3 install --upgrade protobuf==3.6.0 # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* @@ -61,11 +61,11 @@ rm -rf /usr/lib/python3/dist-packages/six* # https://github.com/tensorflow/tensorflow/issues/6968 # This workaround isn't needed for Ubuntu 16.04 or later. if $(cat /etc/*-release | grep -q 14.04); then - pip2 install --no-binary=:all: --upgrade numpy==1.12.0 - pip3 install --no-binary=:all: --upgrade numpy==1.12.0 + pip2 install --no-binary=:all: --upgrade numpy==1.14.5 + pip3 install --no-binary=:all: --upgrade numpy==1.14.5 else - pip2 install --upgrade numpy==1.12.0 - pip3 install --upgrade numpy==1.12.0 + pip2 install --upgrade numpy==1.14.5 + pip3 install --upgrade numpy==1.14.5 fi pip2 install scipy==0.18.1 @@ -113,3 +113,13 @@ pip3 install --upgrade termcolor # Install last working version of setuptools. pip2 install --upgrade setuptools==39.1.0 pip3 install --upgrade setuptools==39.1.0 + +# Keras +pip2 install keras_applications==1.0.4 --no-deps +pip3 install keras_applications==1.0.4 --no-deps +pip2 install keras_preprocessing==1.0.2 --no-deps +pip3 install keras_preprocessing==1.0.2 --no-deps + +# Install last working version of setuptools. +pip2 install --upgrade setuptools==39.1.0 +pip3 install --upgrade setuptools==39.1.0 diff --git a/tensorflow/tools/ci_build/install/install_proto3.sh b/tensorflow/tools/ci_build/install/install_proto3.sh index 7934002b2c982cd10216016f8614b70b77b58e29..821d50baff325106fceca368d46042401d13c336 100755 --- a/tensorflow/tools/ci_build/install/install_proto3.sh +++ b/tensorflow/tools/ci_build/install/install_proto3.sh @@ -17,7 +17,7 @@ # Install protobuf3. # Select protobuf version. -PROTOBUF_VERSION="3.3.0" +PROTOBUF_VERSION="3.6.0" protobuf_ver_flat=$(echo $PROTOBUF_VERSION | sed 's/\.//g' | sed 's/^0*//g') local_protobuf_ver=$(protoc --version) local_protobuf_ver_flat=$(echo $local_protobuf_ver | sed 's/\.//g' | sed 's/^0*//g') diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh index 61d34c7304708be3290b116b7e528560ea907031..15e4396ce3c0ee55ccfc18d939dc4e1883fc132d 100755 --- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh @@ -48,7 +48,7 @@ pip3.5 install --upgrade absl-py pip3.5 install --upgrade six==1.10.0 # Install protobuf. -pip3.5 install --upgrade protobuf==3.3.0 +pip3.5 install --upgrade protobuf==3.6.0 # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* @@ -58,7 +58,7 @@ rm -rf /usr/lib/python3/dist-packages/six* # numpy needs to be installed from source to fix segfaults. See: # https://github.com/tensorflow/tensorflow/issues/6968 # This workaround isn't needed for Ubuntu 16.04 or later. -pip3.5 install --no-binary=:all: --upgrade numpy==1.12.0 +pip3.5 install --no-binary=:all: --upgrade numpy==1.14.5 pip3.5 install scipy==0.18.1 @@ -84,4 +84,11 @@ pip3.5 install --upgrade termcolor # Install last working version of setuptools. pip3.5 install --upgrade setuptools==39.1.0 +# Keras +pip3.5 install keras_applications==1.0.4 +pip3.5 install keras_preprocessing==1.0.2 + +# Install last working version of setuptools. +pip3.5 install --upgrade setuptools==39.1.0 + # LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh) diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh index fe2d2cf11c3a172cb795f91184e474a0d7d15167..0fc3eee71ceb5cd331625f26c904f04f844dccc6 100755 --- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh @@ -60,7 +60,7 @@ pip3 install --upgrade absl-py pip3 install --upgrade six==1.10.0 # Install protobuf. -pip3 install --upgrade protobuf==3.3.0 +pip3 install --upgrade protobuf==3.6.0 # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* @@ -70,7 +70,7 @@ rm -rf /usr/lib/python3/dist-packages/six* # numpy needs to be installed from source to fix segfaults. See: # https://github.com/tensorflow/tensorflow/issues/6968 # This workaround isn't needed for Ubuntu 16.04 or later. -pip3 install --no-binary=:all: --upgrade numpy==1.12.0 +pip3 install --no-binary=:all: --upgrade numpy==1.14.5 pip3 install scipy==0.18.1 @@ -100,4 +100,8 @@ pip3 install --upgrade termcolor # Install last working version of setuptools. pip3 install --upgrade setuptools==39.1.0 +# Keras +pip3 install keras_applications==1.0.4 +pip3 install keras_preprocessing==1.0.2 + # LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh) diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh index 2b68de3c5b9bbb0c09ddead7466049827fac4147..f6fa9251d43074e119ea0eacb721727cec953c0c 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh @@ -34,35 +34,4 @@ yes "" | $PYTHON_BIN_PATH configure.py bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --config=opt \ --test_size_filters=small,medium --test_output=errors -- \ - //tensorflow/contrib/... \ - -//tensorflow/contrib/lite/... \ - //tensorflow/contrib/lite:context_test \ - //tensorflow/contrib/lite:framework \ - //tensorflow/contrib/lite:interpreter_test \ - //tensorflow/contrib/lite:model_test \ - //tensorflow/contrib/lite/toco:toco \ - //tensorflow/contrib/lite:simple_memory_arena_test \ - //tensorflow/contrib/lite:string_util_test \ - //tensorflow/contrib/lite/kernels:activations_test \ - //tensorflow/contrib/lite/kernels:add_test \ - //tensorflow/contrib/lite/kernels:basic_rnn_test \ - //tensorflow/contrib/lite/kernels:concatenation_test \ - //tensorflow/contrib/lite/kernels:conv_test \ - //tensorflow/contrib/lite/kernels:depthwise_conv_test \ - //tensorflow/contrib/lite/kernels:embedding_lookup_test \ - //tensorflow/contrib/lite/kernels:embedding_lookup_sparse_test \ - //tensorflow/contrib/lite/kernels:fully_connected_test \ - //tensorflow/contrib/lite/testing:generated_zip_tests \ - //tensorflow/contrib/lite/kernels:hashtable_lookup_test \ - //tensorflow/contrib/lite/kernels:local_response_norm_test \ - //tensorflow/contrib/lite/kernels:lsh_projection_test \ - //tensorflow/contrib/lite/kernels:lstm_test \ - //tensorflow/contrib/lite/kernels:l2norm_test \ - //tensorflow/contrib/lite/kernels:mul_test \ - //tensorflow/contrib/lite/kernels:pooling_test \ - //tensorflow/contrib/lite/kernels:reshape_test \ - //tensorflow/contrib/lite/kernels:resize_bilinear_test \ - //tensorflow/contrib/lite/kernels:skip_gram_test \ - //tensorflow/contrib/lite/kernels:softmax_test \ - //tensorflow/contrib/lite/kernels:space_to_depth_test \ - //tensorflow/contrib/lite/kernels:svdf_test + //tensorflow/contrib/... diff --git a/tensorflow/tools/ci_build/linux/gpu/run_mkl.sh b/tensorflow/tools/ci_build/linux/gpu/run_mkl.sh new file mode 100755 index 0000000000000000000000000000000000000000..50ee07e727b309c1370edc993928d7165e9eb6cc --- /dev/null +++ b/tensorflow/tools/ci_build/linux/gpu/run_mkl.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +set -e +set -x + +N_JOBS=$(grep -c ^processor /proc/cpuinfo) + +echo "" +echo "Bazel will use ${N_JOBS} concurrent job(s)." +echo "" + +# Run configure. +export PYTHON_BIN_PATH=`which python2` + +export TF_NEED_CUDA=1 +export TF_CUDA_VERSION=9.0 +export TF_CUDNN_VERSION=7 +export TF_CUDA_COMPUTE_CAPABILITIES=3.7 + +yes "" | $PYTHON_BIN_PATH configure.py + +# Run bazel test command. Double test timeouts to avoid flakes. +# Setting KMP_BLOCKTIME to 0 lets OpenMP threads to sleep right after parallel execution +# in an MKL primitive. This reduces the effects of an oversubscription of OpenMP threads +# caused by executing multiple tests concurrently. +bazel test --config=cuda --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test \ + --test_lang_filters=cc,py -k --jobs="${N_JOBS}" \ + --test_timeout 300,450,1200,3600 --build_tests_only --test_env=KMP_BLOCKTIME=0\ + --config=mkl --config=opt --test_output=errors --local_test_jobs=8 \ + --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \ + //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... + diff --git a/tensorflow/tools/ci_build/linux/mkl/basic-mkl-gpu-test.sh b/tensorflow/tools/ci_build/linux/mkl/basic-mkl-gpu-test.sh new file mode 100755 index 0000000000000000000000000000000000000000..68354bf7c1cd6717bd0e27dc872703bb723925c4 --- /dev/null +++ b/tensorflow/tools/ci_build/linux/mkl/basic-mkl-gpu-test.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Usage: basic_mkl_test.sh + +# Helper function to traverse directories up until given file is found. +function upsearch () { + test / == "$PWD" && return || \ + test -e "$1" && echo "$PWD" && return || \ + cd .. && upsearch "$1" +} + +# Set up WORKSPACE. +WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}" + +BUILD_TAG=mkl-gpu-ci-test CI_BUILD_USER_FORCE_BADNAME=yes ${WORKSPACE}/tensorflow/tools/ci_build/ci_build.sh gpu tensorflow/tools/ci_build/linux/gpu/run_mkl.sh diff --git a/tensorflow/tools/ci_build/linux/mkl/build-dev-container.sh b/tensorflow/tools/ci_build/linux/mkl/build-dev-container.sh new file mode 100755 index 0000000000000000000000000000000000000000..b497326d98de21556f7821b2132b5e8255dfee54 --- /dev/null +++ b/tensorflow/tools/ci_build/linux/mkl/build-dev-container.sh @@ -0,0 +1,103 @@ +#!/usr/bin/env bash +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Build a whl and container with Intel(R) MKL support +# Usage: build-dev-container.sh + +# Helper function to traverse directories up until given file is found. +function upsearch () { + test / == "$PWD" && return || \ + test -e "$1" && echo "$PWD" && return || \ + cd .. && upsearch "$1" +} + +# Set up WORKSPACE. +WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}" + +TF_DOCKER_BUILD_DEVEL_BRANCH=${TF_DOCKER_BUILD_DEVEL_BRANCH:-master} +TF_DOCKER_BUILD_IMAGE_NAME=${TF_DOCKER_BUILD_IMAGE_NAME:-intel-mkl/tensorflow} +TF_DOCKER_BUILD_VERSION=${TF_DOCKER_BUILD_VERSION:-nightly} + +echo "TF_DOCKER_BUILD_DEVEL_BRANCH=${TF_DOCKER_BUILD_DEVEL_BRANCH}" +echo "TF_DOCKER_BUILD_IMAGE_NAME=${TF_DOCKER_BUILD_IMAGE_NAME}" +echo "TF_DOCKER_BUILD_VERSION=${TF_DOCKER_BUILD_VERSION}" + +# Build containers for AVX +# Include the instructions for sandybridge and later, but tune for ivybridge +TF_BAZEL_BUILD_OPTIONS="--config=mkl --copt=-march=sandybridge --copt=-mtune=ivybridge --copt=-O3 --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0" + +# build the python 2 container and whl +TF_DOCKER_BUILD_TYPE="MKL" \ + TF_DOCKER_BUILD_IS_DEVEL="YES" \ + TF_DOCKER_BUILD_DEVEL_BRANCH="${TF_DOCKER_BUILD_DEVEL_BRANCH}" \ + TF_DOCKER_BUILD_IMAGE_NAME="${TF_DOCKER_BUILD_IMAGE_NAME}" \ + TF_DOCKER_BUILD_VERSION="${TF_DOCKER_BUILD_VERSION}" \ + TF_BAZEL_BUILD_OPTIONS="${TF_BAZEL_BUILD_OPTIONS}" \ + ${WORKSPACE}/tensorflow/tools/docker/parameterized_docker_build.sh + +# build the python 3 container and whl +TF_DOCKER_BUILD_TYPE="MKL" \ + TF_DOCKER_BUILD_IS_DEVEL="YES" \ + TF_DOCKER_BUILD_DEVEL_BRANCH="${TF_DOCKER_BUILD_DEVEL_BRANCH}" \ + TF_DOCKER_BUILD_IMAGE_NAME="${TF_DOCKER_BUILD_IMAGE_NAME}" \ + TF_DOCKER_BUILD_VERSION="${TF_DOCKER_BUILD_VERSION}" \ + TF_DOCKER_BUILD_PYTHON_VERSION="PYTHON3" \ + TF_BAZEL_BUILD_OPTIONS="${TF_BAZEL_BUILD_OPTIONS}" \ + ${WORKSPACE}/tensorflow/tools/docker/parameterized_docker_build.sh + +# build the python3.6 container and whl +TF_DOCKER_BUILD_TYPE="MKL" \ + TF_DOCKER_BUILD_IS_DEVEL="YES" \ + TF_DOCKER_BUILD_DEVEL_BRANCH="${TF_DOCKER_BUILD_DEVEL_BRANCH}" \ + TF_DOCKER_BUILD_IMAGE_NAME="${TF_DOCKER_BUILD_IMAGE_NAME}" \ + TF_DOCKER_BUILD_VERSION="${TF_DOCKER_BUILD_VERSION}" \ + TF_DOCKER_BUILD_PYTHON_VERSION="PYTHON3.6" \ + TF_BAZEL_BUILD_OPTIONS="${TF_BAZEL_BUILD_OPTIONS}" \ + ${WORKSPACE}/tensorflow/tools/docker/parameterized_docker_build.sh + + +# Build containers for AVX2 +# Include the instructions for haswell and later, but tune for broadwell +TF_BAZEL_BUILD_OPTIONS="--config=mkl --copt=-march=haswell --copt=-mtune=broadwell --copt=-O3 --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0" + +# build the python 2 container and whl +TF_DOCKER_BUILD_TYPE="MKL" \ + TF_DOCKER_BUILD_IS_DEVEL="YES" \ + TF_DOCKER_BUILD_DEVEL_BRANCH="${TF_DOCKER_BUILD_DEVEL_BRANCH}" \ + TF_DOCKER_BUILD_IMAGE_NAME="${TF_DOCKER_BUILD_IMAGE_NAME}" \ + TF_DOCKER_BUILD_VERSION="${TF_DOCKER_BUILD_VERSION}-avx2" \ + TF_BAZEL_BUILD_OPTIONS="${TF_BAZEL_BUILD_OPTIONS}" \ + ${WORKSPACE}/tensorflow/tools/docker/parameterized_docker_build.sh + +# build the python 3 container and whl +TF_DOCKER_BUILD_TYPE="MKL" \ + TF_DOCKER_BUILD_IS_DEVEL="YES" \ + TF_DOCKER_BUILD_DEVEL_BRANCH="${TF_DOCKER_BUILD_DEVEL_BRANCH}" \ + TF_DOCKER_BUILD_IMAGE_NAME="${TF_DOCKER_BUILD_IMAGE_NAME}" \ + TF_DOCKER_BUILD_VERSION="${TF_DOCKER_BUILD_VERSION}-avx2" \ + TF_DOCKER_BUILD_PYTHON_VERSION="PYTHON3" \ + TF_BAZEL_BUILD_OPTIONS="${TF_BAZEL_BUILD_OPTIONS}" \ + ${WORKSPACE}/tensorflow/tools/docker/parameterized_docker_build.sh + +# build the python3.6 container and whl +TF_DOCKER_BUILD_TYPE="MKL" \ + TF_DOCKER_BUILD_IS_DEVEL="YES" \ + TF_DOCKER_BUILD_DEVEL_BRANCH="${TF_DOCKER_BUILD_DEVEL_BRANCH}" \ + TF_DOCKER_BUILD_IMAGE_NAME="${TF_DOCKER_BUILD_IMAGE_NAME}" \ + TF_DOCKER_BUILD_VERSION="${TF_DOCKER_BUILD_VERSION}-avx2" \ + TF_DOCKER_BUILD_PYTHON_VERSION="PYTHON3.6" \ + TF_BAZEL_BUILD_OPTIONS="${TF_BAZEL_BUILD_OPTIONS}" \ + ${WORKSPACE}/tensorflow/tools/docker/parameterized_docker_build.sh + diff --git a/tensorflow/tools/ci_build/linux/ppc64le/cpu/run_py2.sh b/tensorflow/tools/ci_build/linux/ppc64le/cpu/run_py2.sh new file mode 100755 index 0000000000000000000000000000000000000000..e13de35061731d956ffdfd44c056e589cd5aae69 --- /dev/null +++ b/tensorflow/tools/ci_build/linux/ppc64le/cpu/run_py2.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +set -e +set -x + +N_JOBS=$(grep -c ^processor /proc/cpuinfo) + +echo "" +echo "Bazel will use ${N_JOBS} concurrent job(s)." +echo "" + +# Run configure. +export TF_NEED_CUDA=0 +export CC_OPT_FLAGS='-mcpu=power8 -mtune=power8' +export PYTHON_BIN_PATH=`which python2` +yes "" | $PYTHON_BIN_PATH configure.py + +# Run bazel test command. Double test timeouts to avoid flakes. +bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test -k \ + --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only --config=opt \ + --test_output=errors --test_size_filters=small,medium -- \ + //tensorflow/... -//tensorflow/compiler/... diff --git a/tensorflow/tools/ci_build/linux/ppc64le/cpu/run_py3.sh b/tensorflow/tools/ci_build/linux/ppc64le/cpu/run_py3.sh new file mode 100755 index 0000000000000000000000000000000000000000..a04ac158f5f2b0064d38cf36fb92c2946914ab00 --- /dev/null +++ b/tensorflow/tools/ci_build/linux/ppc64le/cpu/run_py3.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +set -e +set -x + +N_JOBS=$(grep -c ^processor /proc/cpuinfo) + +echo "" +echo "Bazel will use ${N_JOBS} concurrent job(s)." +echo "" + +# Run configure. +export TF_NEED_CUDA=0 +export CC_OPT_FLAGS='-mcpu=power8 -mtune=power8' +export PYTHON_BIN_PATH=`which python3` +yes "" | $PYTHON_BIN_PATH configure.py + +# Run bazel test command. Double test timeouts to avoid flakes. +bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test -k \ + --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only --config=opt \ + --test_output=errors --test_size_filters=small,medium -- \ + //tensorflow/... -//tensorflow/compiler/... diff --git a/tensorflow/tools/ci_build/linux/ppc64le/gpu/run_py2.sh b/tensorflow/tools/ci_build/linux/ppc64le/gpu/run_py2.sh new file mode 100755 index 0000000000000000000000000000000000000000..77286e8448a1954522a67ca794175b397c05f082 --- /dev/null +++ b/tensorflow/tools/ci_build/linux/ppc64le/gpu/run_py2.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +set -e +set -x + +N_JOBS=$(grep -c ^processor /proc/cpuinfo) +LT_JOBS=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader | wc -l) + +echo "" +echo "Bazel will use ${N_JOBS} concurrent job(s)." +echo "Bazel will use ${LT_JOBS} local test job(s)." +echo "" + +# Run configure. +export PYTHON_BIN_PATH=`which python2` +export CC_OPT_FLAGS='-mcpu=power8 -mtune=power8' + +export TF_NEED_CUDA=1 +export TF_CUDA_COMPUTE_CAPABILITIES=3.7 + +yes "" | $PYTHON_BIN_PATH configure.py + +# Run bazel test command. Double test timeouts to avoid flakes. +bazel test --config=cuda --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \ + --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ + --test_output=errors --local_test_jobs=${LT_JOBS} --build_tests_only --config=opt \ + --test_size_filters=small,medium \ + --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \ + //tensorflow/... -//tensorflow/compiler/... diff --git a/tensorflow/tools/ci_build/linux/ppc64le/gpu/run_py3.sh b/tensorflow/tools/ci_build/linux/ppc64le/gpu/run_py3.sh new file mode 100755 index 0000000000000000000000000000000000000000..17aa52ee6b0e61a26f6553834acdab41f64ea409 --- /dev/null +++ b/tensorflow/tools/ci_build/linux/ppc64le/gpu/run_py3.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== + +set -e +set -x + +N_JOBS=$(grep -c ^processor /proc/cpuinfo) +LT_JOBS=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader | wc -l) + +echo "" +echo "Bazel will use ${N_JOBS} concurrent job(s)." +echo "Bazel will use ${LT_JOBS} local test job(s)." +echo "" + +# Run configure. +export PYTHON_BIN_PATH=`which python3` +export CC_OPT_FLAGS='-mcpu=power8 -mtune=power8' + +export TF_NEED_CUDA=1 +export TF_CUDA_COMPUTE_CAPABILITIES=3.7 + +yes "" | $PYTHON_BIN_PATH configure.py + +# Run bazel test command. Double test timeouts to avoid flakes. +bazel test --config=cuda --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \ + --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ + --test_output=errors --local_test_jobs=${LT_JOBS} --build_tests_only --config=opt \ + --test_size_filters=small,medium \ + --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \ + //tensorflow/... -//tensorflow/compiler/... diff --git a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh index b8bce57c87ab39ab2f51288163187f2e87c9135d..3d27e84b81c586729aff21d0859383c24f436a11 100755 --- a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh +++ b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh @@ -65,6 +65,10 @@ OPENBLAS_SRC_PATH=/tmp/openblas_src/ sudo rm -rf ${OPENBLAS_SRC_PATH} git clone https://github.com/xianyi/OpenBLAS ${OPENBLAS_SRC_PATH} cd ${OPENBLAS_SRC_PATH} +# The commit after this introduced Fortran compile issues. In theory they should +# be solvable using NOFORTRAN=1 on the make command, but my initial tries didn't +# work, so pinning to the last know good version. +git checkout 5a6a2bed9aff0ba8a18651d5514d029c8cae336a # If this path is changed, you'll also need to update # cxx_builtin_include_directory in third_party/toolchains/cpus/arm/CROSSTOOL.tpl OPENBLAS_INSTALL_PATH=/tmp/openblas_install/ diff --git a/tensorflow/tools/ci_build/update_version.py b/tensorflow/tools/ci_build/update_version.py index 00bfcfd49bd1d90dccf094de21173ca9e4307319..30c318a58fae4c84033ea5e906f3ec88818c4b65 100755 --- a/tensorflow/tools/ci_build/update_version.py +++ b/tensorflow/tools/ci_build/update_version.py @@ -37,7 +37,7 @@ SETUP_PY = "%s/tools/pip_package/setup.py" % TF_SRC_DIR README_MD = "./README.md" DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel" % TF_SRC_DIR GPU_DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel-gpu" % TF_SRC_DIR -CPU_MKL_DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel-cpu-mkl" % TF_SRC_DIR +CPU_MKL_DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel-mkl" % TF_SRC_DIR RELEVANT_FILES = [TF_SRC_DIR, VERSION_H, SETUP_PY, @@ -248,16 +248,6 @@ def update_md_files(old_version, new_version): replace_string_in_line(r"%s<\/version>" % old_version, "%s" % new_version, filepath) - # Update any links to colab notebooks. - def colab_url(version): - version_string = "%s.%s.%s" % (version.major, version.minor, version.patch) - prefix = "https://colab.research.google.com/github/tensorflow/models/blob/r" - return prefix + version_string + "/" - - replace_string_in_line( - colab_url(old_version), colab_url(new_version), - "%s/docs_src/get_started/eager.md" % TF_SRC_DIR) - def major_minor_change(old_version, new_version): """Check if a major or minor change occurred.""" diff --git a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh index a3e07737a4fa79de80cf667d058517772db9f103..0482cf619a831ebb87e76cd18efbdac83a0d2f11 100644 --- a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh +++ b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh @@ -23,17 +23,20 @@ function run_configure_for_gpu_build { # Enable CUDA support export TF_NEED_CUDA=1 - # TODO(pcloudy): Remove this after TensorFlow uses its own CRSOOTOOL - # for GPU build on Windows - export USE_MSVC_WRAPPER=1 - yes "" | ./configure } -function set_gcs_remote_cache_options { - echo "build --experimental_remote_spawn_cache" >> "${TMP_BAZELRC}" +function set_remote_cache_options { + echo "build --remote_instance_name=projects/tensorflow-testing-cpu" >> "${TMP_BAZELRC}" echo "build --experimental_remote_platform_override='properties:{name:\"build\" value:\"windows-x64\"}'" >> "${TMP_BAZELRC}" - echo "build --remote_http_cache=https://storage.googleapis.com/$GCS_BUCKET_NAME" >> "${TMP_BAZELRC}" + echo "build --remote_cache=remotebuildexecution.googleapis.com" >> "${TMP_BAZELRC}" + echo "build --tls_enabled=true" >> "${TMP_BAZELRC}" + echo "build --remote_timeout=3600" >> "${TMP_BAZELRC}" + echo "build --auth_enabled=true" >> "${TMP_BAZELRC}" + echo "build --spawn_strategy=standalone" >> "${TMP_BAZELRC}" + echo "build --strategy=Javac=standalone" >> "${TMP_BAZELRC}" + echo "build --strategy=Closure=standalone" >> "${TMP_BAZELRC}" + echo "build --genrule_strategy=standalone" >> "${TMP_BAZELRC}" echo "build --google_credentials=$GOOGLE_CLOUD_CREDENTIAL" >> "${TMP_BAZELRC}" } diff --git a/tensorflow/tools/ci_build/windows/bazel/common_env.sh b/tensorflow/tools/ci_build/windows/bazel/common_env.sh index eefa8ee2d504945991c91e1574b6a74330ba3a8d..333a89d3f5e43edeb440c2a0ac69bd50a1663732 100644 --- a/tensorflow/tools/ci_build/windows/bazel/common_env.sh +++ b/tensorflow/tools/ci_build/windows/bazel/common_env.sh @@ -26,7 +26,8 @@ # * Bazel windows executable copied as "bazel.exe" and included in PATH. # Use a temporary directory with a short name. -export TMPDIR="C:/tmp" +export TMPDIR=${TMPDIR:-"C:/tmp"} +export TMPDIR=$(cygpath -m "$TMPDIR") mkdir -p "$TMPDIR" # Set bash path @@ -49,3 +50,15 @@ export PATH="/c/Program Files/Git/cmd:$PATH" # Make sure we have pip in PATH export PATH="/c/${PYTHON_BASE_PATH}/Scripts:$PATH" + +# Setting default values to CUDA related environment variables +export TF_CUDA_VERSION=${TF_CUDA_VERSION:-9.0} +export TF_CUDNN_VERSION=${TF_CUDNN_VERSION:-7.0} +export TF_CUDA_COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES:-3.7} +export CUDA_TOOLKIT_PATH=${CUDA_TOOLKIT_PATH:-"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${TF_CUDA_VERSION}"} +export CUDNN_INSTALL_PATH=${CUDNN_INSTALL_PATH:-"C:/tools/cuda"} + +# Add Cuda and Cudnn dll directories into PATH +export PATH="$(cygpath -u "${CUDA_TOOLKIT_PATH}")/bin:$PATH" +export PATH="$(cygpath -u "${CUDA_TOOLKIT_PATH}")/extras/CUPTI/libx64:$PATH" +export PATH="$(cygpath -u "${CUDNN_INSTALL_PATH}")/bin:$PATH" diff --git a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh index 0b13b97209fa6cd6c629a64fdd54a0423535a9a3..47e0e5dd59af76c733e7f2271294ad0e5c7e6b26 100644 --- a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh +++ b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh @@ -53,42 +53,55 @@ function cleanup { } trap cleanup EXIT -skip_test=0 -release_build=0 +PY_TEST_DIR="py_test_dir" +SKIP_TEST=0 +RELEASE_BUILD=0 +TEST_TARGET="//${PY_TEST_DIR}/tensorflow/python/... \ + //${PY_TEST_DIR}/tensorflow/contrib/... " + +# --skip_test Skip running tests +# --enable_remote_cache Add options to enable remote cache for build and test +# --release_build Build for release, compilation time will be longer to +# ensure performance +# --test_core_only Use tensorflow/python/... as test target +# --test_contrib_only Use tensorflow/contrib/... as test target for ARG in "$@"; do - if [[ "$ARG" == --skip_test ]]; then - skip_test=1 - elif [[ "$ARG" == --enable_gcs_remote_cache ]]; then - set_gcs_remote_cache_options - elif [[ "$ARG" == --release_build ]]; then - release_build=1 - fi + case "$ARG" in + --skip_test) SKIP_TEST=1 ;; + --enable_remote_cache) set_remote_cache_options ;; + --release_build) RELEASE_BUILD=1 ;; + --test_core_only) TEST_TARGET="//${PY_TEST_DIR}/tensorflow/python/..." ;; + --test_contrib_only) TEST_TARGET="//${PY_TEST_DIR}/tensorflow/contrib/..." ;; + *) + esac done -if [[ "$release_build" != 1 ]]; then - # --define=override_eigen_strong_inline=true speeds up the compiling of conv_grad_ops_3d.cc and conv_ops_3d.cc +if [[ "$RELEASE_BUILD" == 1 ]]; then + # Overriding eigen strong inline speeds up the compiling of conv_grad_ops_3d.cc and conv_ops_3d.cc # by 20 minutes. See https://github.com/tensorflow/tensorflow/issues/10521 - # Because this hurts the performance of TF, we don't enable it in release build. - echo "build --define=override_eigen_strong_inline=true" >> "${TMP_BAZELRC}" + # Because this hurts the performance of TF, we don't override it in release build. + export TF_OVERRIDE_EIGEN_STRONG_INLINE=0 +else + export TF_OVERRIDE_EIGEN_STRONG_INLINE=1 fi -# The host and target platforms are the same in Windows build. So we don't have -# to distinct them. This helps avoid building the same targets twice. -echo "build --distinct_host_configuration=false" >> "${TMP_BAZELRC}" +# Enable short object file path to avoid long path issue on Windows. +echo "startup --output_user_root=${TMPDIR}" >> "${TMP_BAZELRC}" -echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc +if ! grep -q "import %workspace%/${TMP_BAZELRC}" .bazelrc; then + echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc +fi run_configure_for_cpu_build bazel build --announce_rc --config=opt tensorflow/tools/pip_package:build_pip_package || exit $? -if [[ "$skip_test" == 1 ]]; then +if [[ "$SKIP_TEST" == 1 ]]; then exit 0 fi # Create a python test directory to avoid package name conflict -PY_TEST_DIR="py_test_dir" create_python_test_dir "${PY_TEST_DIR}" ./bazel-bin/tensorflow/tools/pip_package/build_pip_package "$PWD/${PY_TEST_DIR}" @@ -106,7 +119,7 @@ bazel test --announce_rc --config=opt -k --test_output=errors \ --define=no_tensorflow_py_deps=true --test_lang_filters=py \ --test_tag_filters=-no_pip,-no_windows,-no_oss \ --build_tag_filters=-no_pip,-no_windows,-no_oss --build_tests_only \ + --test_size_filters=small,medium \ --jobs="${N_JOBS}" --test_timeout="300,450,1200,3600" \ --flaky_test_attempts=3 \ - //${PY_TEST_DIR}/tensorflow/python/... \ - //${PY_TEST_DIR}/tensorflow/contrib/... + ${TEST_TARGET} diff --git a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh index 922bb67bbf6ce34f55acad6d3399bd810032abd0..e3eee110808ce6cf1905a44b9108ad6de49f10cb 100644 --- a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh +++ b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh @@ -42,12 +42,69 @@ source "tensorflow/tools/ci_build/windows/bazel/common_env.sh" \ source "tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh" \ || { echo "Failed to source bazel_test_lib.sh" >&2; exit 1; } +# Recreate an empty bazelrc file under source root +export TMP_BAZELRC=.tmp.bazelrc +rm -f "${TMP_BAZELRC}" +touch "${TMP_BAZELRC}" + +function cleanup { + # Remove all options in .tmp.bazelrc + echo "" > "${TMP_BAZELRC}" +} +trap cleanup EXIT + +PY_TEST_DIR="py_test_dir" + +SKIP_TEST=0 +RELEASE_BUILD=0 +TEST_TARGET="//${PY_TEST_DIR}/tensorflow/python/... \ + //${PY_TEST_DIR}/tensorflow/contrib/... " + +# --skip_test Skip running tests +# --enable_remote_cache Add options to enable remote cache for build and test +# --release_build Build for release, compilation time will be longer to +# ensure performance +# --test_core_only Use tensorflow/python/... as test target +# --test_contrib_only Use tensorflow/contrib/... as test target +for ARG in "$@"; do + case "$ARG" in + --skip_test) SKIP_TEST=1 ;; + --enable_remote_cache) set_remote_cache_options ;; + --release_build) RELEASE_BUILD=1 ;; + --test_core_only) TEST_TARGET="//${PY_TEST_DIR}/tensorflow/python/..." ;; + --test_contrib_only) TEST_TARGET="//${PY_TEST_DIR}/tensorflow/contrib/..." ;; + *) + esac +done + +if [[ "$RELEASE_BUILD" == 1 ]]; then + # Overriding eigen strong inline speeds up the compiling of conv_grad_ops_3d.cc and conv_ops_3d.cc + # by 20 minutes. See https://github.com/tensorflow/tensorflow/issues/10521 + # Because this hurts the performance of TF, we don't override it in release build. + export TF_OVERRIDE_EIGEN_STRONG_INLINE=0 +else + export TF_OVERRIDE_EIGEN_STRONG_INLINE=1 +fi + +# Enable short object file path to avoid long path issue on Windows. +echo "startup --output_user_root=${TMPDIR}" >> "${TMP_BAZELRC}" + +# Disable nvcc warnings to reduce log file size. +echo "build --copt=-nvcc_options=disable-warnings" >> "${TMP_BAZELRC}" + +if ! grep -q "import %workspace%/${TMP_BAZELRC}" .bazelrc; then + echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc +fi + run_configure_for_gpu_build -bazel build -c opt tensorflow/tools/pip_package:build_pip_package || exit $? +bazel build --announce_rc --config=opt tensorflow/tools/pip_package:build_pip_package || exit $? + +if [[ "$SKIP_TEST" == 1 ]]; then + exit 0 +fi # Create a python test directory to avoid package name conflict -PY_TEST_DIR="py_test_dir" create_python_test_dir "${PY_TEST_DIR}" ./bazel-bin/tensorflow/tools/pip_package/build_pip_package "$PWD/${PY_TEST_DIR}" @@ -56,11 +113,18 @@ create_python_test_dir "${PY_TEST_DIR}" PIP_NAME=$(ls ${PY_TEST_DIR}/tensorflow-*.whl) reinstall_tensorflow_pip ${PIP_NAME} +TF_GPU_COUNT=${TF_GPU_COUNT:-8} + # Define no_tensorflow_py_deps=true so that every py_test has no deps anymore, # which will result testing system installed tensorflow # GPU tests are very flaky when running concurrently, so set local_test_jobs=1 -bazel test -c opt -k --test_output=errors \ +bazel test --announce_rc --config=opt -k --test_output=errors \ + --test_env=TF_GPU_COUNT \ + --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute \ --define=no_tensorflow_py_deps=true --test_lang_filters=py \ - --test_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,no_oss \ - --build_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,no_oss \ - --local_test_jobs=1 --build_tests_only //${PY_TEST_DIR}/tensorflow/python/... + --test_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,-no_oss \ + --build_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,-no_oss --build_tests_only \ + --test_size_filters=small,medium \ + --local_test_jobs=$TF_GPU_COUNT --test_timeout="300,450,1200,3600" \ + --flaky_test_attempts=3 \ + ${TEST_TARGET} diff --git a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh index 583d1d5f09527861015458c636af2259b34d45f8..fdbd1120b20ea4461a4ec5f84c666d8b62309905 100755 --- a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh +++ b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh @@ -41,7 +41,7 @@ run_configure_for_cpu_build # build_libtensorflow_tarball in ../builds/libtensorflow.sh # cannot be used on Windows since it relies on pkg_tar rules. # So we do something special here -bazel build -c opt --copt=/arch:AVX \ +bazel --output_user_root=${TMPDIR} build -c opt --copt=/arch:AVX \ tensorflow:libtensorflow.so \ tensorflow/tools/lib_package:clicenses_generate \ tensorflow/java:libtensorflow_jni.so \ diff --git a/tensorflow/tools/common/public_api.py b/tensorflow/tools/common/public_api.py index e0acead9195933c97e3ca8eb3aed5a1b40c19aa6..09933d266b843641a79ec71aa40081d8c8675e23 100644 --- a/tensorflow/tools/common/public_api.py +++ b/tensorflow/tools/common/public_api.py @@ -50,6 +50,7 @@ class PublicAPIVisitor(object): # Each entry maps a module path to a name to ignore in traversal. self._do_not_descend_map = { 'tf': [ + 'compiler', 'core', 'examples', 'flags', # Don't add flags @@ -69,6 +70,8 @@ class PublicAPIVisitor(object): 'tf.app': ['flags'], # Imported for compatibility between py2/3. 'tf.test': ['mock'], + # Externalized modules of the Keras API. + 'tf.keras': ['applications', 'preprocessing'] } @property diff --git a/tensorflow/tools/compatibility/BUILD b/tensorflow/tools/compatibility/BUILD index b7bfb29aae4fcaa55e01ba924f72cf79d2b09ad1..55792c51fe87f0ded92730c13409169f6c67d035 100644 --- a/tensorflow/tools/compatibility/BUILD +++ b/tensorflow/tools/compatibility/BUILD @@ -8,10 +8,17 @@ load( "tf_cc_test", # @unused ) +py_library( + name = "ast_edits", + srcs = ["ast_edits.py"], + srcs_version = "PY2AND3", +) + py_binary( name = "tf_upgrade", srcs = ["tf_upgrade.py"], srcs_version = "PY2AND3", + deps = [":ast_edits"], ) py_test( @@ -26,6 +33,28 @@ py_test( ], ) +py_binary( + name = "tf_upgrade_v2", + srcs = [ + "renames_v2.py", + "tf_upgrade_v2.py", + ], + srcs_version = "PY2AND3", + deps = [":ast_edits"], +) + +py_test( + name = "tf_upgrade_v2_test", + srcs = ["tf_upgrade_v2_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":tf_upgrade_v2", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "@six_archive//:six", + ], +) + # Keep for reference, this test will succeed in 0.11 but fail in 1.0 # py_test( # name = "test_file_v0_11", @@ -62,9 +91,37 @@ py_test( ], ) +genrule( + name = "generate_upgraded_file_v2", + testonly = 1, + srcs = ["testdata/test_file_v1_10.py"], + outs = [ + "test_file_v2_0.py", + "report_v2.txt", + ], + cmd = ("$(location :tf_upgrade_v2)" + + " --infile $(location testdata/test_file_v1_10.py)" + + " --outfile $(location test_file_v2_0.py)" + + " --reportfile $(location report_v2.txt)"), + tools = [":tf_upgrade_v2"], +) + +py_test( + name = "test_file_v2_0", + size = "small", + srcs = ["test_file_v2_0.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + exports_files( [ + "ast_edits.py", "tf_upgrade.py", + "renames_v2.py", "testdata/test_file_v0_11.py", + "testdata/test_file_v1_10.py", ], ) diff --git a/tensorflow/tools/compatibility/ast_edits.py b/tensorflow/tools/compatibility/ast_edits.py new file mode 100644 index 0000000000000000000000000000000000000000..23cc4a21a9e6f81c8dc5016bc2cb6a2f151c7924 --- /dev/null +++ b/tensorflow/tools/compatibility/ast_edits.py @@ -0,0 +1,502 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Upgrader for Python scripts according to an API change specification.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ast +import collections +import os +import shutil +import sys +import tempfile +import traceback + + +class APIChangeSpec(object): + """This class defines the transformations that need to happen. + + This class must provide the following fields: + + * `function_keyword_renames`: maps function names to a map of old -> new + argument names + * `function_renames`: maps function names to new function names + * `change_to_function`: a set of function names that have changed (for + notifications) + * `function_reorders`: maps functions whose argument order has changed to the + list of arguments in the new order + * `function_handle`: maps function names to custom handlers for the function + + For an example, see `TFAPIChangeSpec`. + """ + + +class _FileEditTuple( + collections.namedtuple("_FileEditTuple", + ["comment", "line", "start", "old", "new"])): + """Each edit that is recorded by a _FileEditRecorder. + + Fields: + comment: A description of the edit and why it was made. + line: The line number in the file where the edit occurs (1-indexed). + start: The line number in the file where the edit occurs (0-indexed). + old: text string to remove (this must match what was in file). + new: text string to add in place of `old`. + """ + + __slots__ = () + + +class _FileEditRecorder(object): + """Record changes that need to be done to the file.""" + + def __init__(self, filename): + # all edits are lists of chars + self._filename = filename + + self._line_to_edit = collections.defaultdict(list) + self._errors = [] + + def process(self, text): + """Process a list of strings, each corresponding to the recorded changes. + + Args: + text: A list of lines of text (assumed to contain newlines) + Returns: + A tuple of the modified text and a textual description of what is done. + Raises: + ValueError: if substitution source location does not have expected text. + """ + + change_report = "" + + # Iterate of each line + for line, edits in self._line_to_edit.items(): + offset = 0 + # sort by column so that edits are processed in order in order to make + # indexing adjustments cumulative for changes that change the string + # length + edits.sort(key=lambda x: x.start) + + # Extract each line to a list of characters, because mutable lists + # are editable, unlike immutable strings. + char_array = list(text[line - 1]) + + # Record a description of the change + change_report += "%r Line %d\n" % (self._filename, line) + change_report += "-" * 80 + "\n\n" + for e in edits: + change_report += "%s\n" % e.comment + change_report += "\n Old: %s" % (text[line - 1]) + + # Make underscore buffers for underlining where in the line the edit was + change_list = [" "] * len(text[line - 1]) + change_list_new = [" "] * len(text[line - 1]) + + # Iterate for each edit + for e in edits: + # Create effective start, end by accounting for change in length due + # to previous edits + start_eff = e.start + offset + end_eff = start_eff + len(e.old) + + # Make sure the edit is changing what it should be changing + old_actual = "".join(char_array[start_eff:end_eff]) + if old_actual != e.old: + raise ValueError("Expected text %r but got %r" % + ("".join(e.old), "".join(old_actual))) + # Make the edit + char_array[start_eff:end_eff] = list(e.new) + + # Create the underline highlighting of the before and after + change_list[e.start:e.start + len(e.old)] = "~" * len(e.old) + change_list_new[start_eff:end_eff] = "~" * len(e.new) + + # Keep track of how to generate effective ranges + offset += len(e.new) - len(e.old) + + # Finish the report comment + change_report += " %s\n" % "".join(change_list) + text[line - 1] = "".join(char_array) + change_report += " New: %s" % (text[line - 1]) + change_report += " %s\n\n" % "".join(change_list_new) + return "".join(text), change_report, self._errors + + def add(self, comment, line, start, old, new, error=None): + """Add a new change that is needed. + + Args: + comment: A description of what was changed + line: Line number (1 indexed) + start: Column offset (0 indexed) + old: old text + new: new text + error: this "edit" is something that cannot be fixed automatically + Returns: + None + """ + + self._line_to_edit[line].append( + _FileEditTuple(comment, line, start, old, new)) + if error: + self._errors.append("%s:%d: %s" % (self._filename, line, error)) + + +class _ASTCallVisitor(ast.NodeVisitor): + """AST Visitor that processes function calls. + + Updates function calls from old API version to new API version using a given + change spec. + """ + + def __init__(self, filename, lines, api_change_spec): + self._filename = filename + self._file_edit = _FileEditRecorder(filename) + self._lines = lines + self._api_change_spec = api_change_spec + + def process(self, lines): + return self._file_edit.process(lines) + + def generic_visit(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def _rename_functions(self, node, full_name): + function_renames = self._api_change_spec.function_renames + try: + new_name = function_renames[full_name] + self._file_edit.add("Renamed function %r to %r" % (full_name, new_name), + node.lineno, node.col_offset, full_name, new_name) + except KeyError: + pass + + def _get_attribute_full_path(self, node): + """Traverse an attribute to generate a full name e.g. tf.foo.bar. + + Args: + node: A Node of type Attribute. + + Returns: + a '.'-delimited full-name or None if the tree was not a simple form. + i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c". + """ + curr = node + items = [] + while not isinstance(curr, ast.Name): + if not isinstance(curr, ast.Attribute): + return None + items.append(curr.attr) + curr = curr.value + items.append(curr.id) + return ".".join(reversed(items)) + + def _find_true_position(self, node): + """Return correct line number and column offset for a given node. + + This is necessary mainly because ListComp's location reporting reports + the next token after the list comprehension list opening. + + Args: + node: Node for which we wish to know the lineno and col_offset + """ + import re + find_open = re.compile("^\s*(\\[).*$") + find_string_chars = re.compile("['\"]") + + if isinstance(node, ast.ListComp): + # Strangely, ast.ListComp returns the col_offset of the first token + # after the '[' token which appears to be a bug. Workaround by + # explicitly finding the real start of the list comprehension. + line = node.lineno + col = node.col_offset + # loop over lines + while 1: + # Reverse the text to and regular expression search for whitespace + text = self._lines[line - 1] + reversed_preceding_text = text[:col][::-1] + # First find if a [ can be found with only whitespace between it and + # col. + m = find_open.match(reversed_preceding_text) + if m: + new_col_offset = col - m.start(1) - 1 + return line, new_col_offset + else: + if (reversed_preceding_text == "" or + reversed_preceding_text.isspace()): + line = line - 1 + prev_line = self._lines[line - 1] + # TODO(aselle): + # this is poor comment detection, but it is good enough for + # cases where the comment does not contain string literal starting/ + # ending characters. If ast gave us start and end locations of the + # ast nodes rather than just start, we could use string literal + # node ranges to filter out spurious #'s that appear in string + # literals. + comment_start = prev_line.find("#") + if comment_start == -1: + col = len(prev_line) - 1 + elif find_string_chars.search(prev_line[comment_start:]) is None: + col = comment_start + else: + return None, None + else: + return None, None + # Most other nodes return proper locations (with notably does not), but + # it is not possible to use that in an argument. + return node.lineno, node.col_offset + + def visit_Call(self, node): # pylint: disable=invalid-name + """Handle visiting a call node in the AST. + + Args: + node: Current Node + """ + + # Find a simple attribute name path e.g. "tf.foo.bar" + full_name = self._get_attribute_full_path(node.func) + + # Make sure the func is marked as being part of a call + node.func.is_function_for_call = True + + if full_name: + # Call special handlers + function_handles = self._api_change_spec.function_handle + if full_name in function_handles: + function_handles[full_name](self._file_edit, node) + + # Examine any non-keyword argument and make it into a keyword argument + # if reordering required. + function_reorders = self._api_change_spec.function_reorders + function_keyword_renames = ( + self._api_change_spec.function_keyword_renames) + + if full_name in function_reorders: + reordered = function_reorders[full_name] + for idx, arg in enumerate(node.args): + lineno, col_offset = self._find_true_position(arg) + if lineno is None or col_offset is None: + self._file_edit.add( + "Failed to add keyword %r to reordered function %r" % + (reordered[idx], full_name), + arg.lineno, + arg.col_offset, + "", + "", + error="A necessary keyword argument failed to be inserted.") + else: + keyword_arg = reordered[idx] + if (full_name in function_keyword_renames and + keyword_arg in function_keyword_renames[full_name]): + keyword_arg = function_keyword_renames[full_name][keyword_arg] + self._file_edit.add("Added keyword %r to reordered function %r" % + (reordered[idx], full_name), lineno, col_offset, + "", keyword_arg + "=") + + # Examine each keyword argument and convert it to the final renamed form + renamed_keywords = ({} if full_name not in function_keyword_renames else + function_keyword_renames[full_name]) + for keyword in node.keywords: + argkey = keyword.arg + argval = keyword.value + + if argkey in renamed_keywords: + argval_lineno, argval_col_offset = self._find_true_position(argval) + if argval_lineno is not None and argval_col_offset is not None: + # TODO(aselle): We should scan backward to find the start of the + # keyword key. Unfortunately ast does not give you the location of + # keyword keys, so we are forced to infer it from the keyword arg + # value. + key_start = argval_col_offset - len(argkey) - 1 + key_end = key_start + len(argkey) + 1 + if (self._lines[argval_lineno - 1][key_start:key_end] == argkey + + "="): + self._file_edit.add("Renamed keyword argument from %r to %r" % + (argkey, + renamed_keywords[argkey]), argval_lineno, + argval_col_offset - len(argkey) - 1, + argkey + "=", renamed_keywords[argkey] + "=") + continue + self._file_edit.add( + "Failed to rename keyword argument from %r to %r" % + (argkey, renamed_keywords[argkey]), + argval.lineno, + argval.col_offset - len(argkey) - 1, + "", + "", + error="Failed to find keyword lexographically. Fix manually.") + + ast.NodeVisitor.generic_visit(self, node) + + def visit_Attribute(self, node): # pylint: disable=invalid-name + """Handle bare Attributes i.e. [tf.foo, tf.bar]. + + Args: + node: Node that is of type ast.Attribute + """ + full_name = self._get_attribute_full_path(node) + if full_name: + self._rename_functions(node, full_name) + if full_name in self._api_change_spec.change_to_function: + if not hasattr(node, "is_function_for_call"): + new_text = full_name + "()" + self._file_edit.add("Changed %r to %r" % (full_name, new_text), + node.lineno, node.col_offset, full_name, new_text) + + ast.NodeVisitor.generic_visit(self, node) + + +class ASTCodeUpgrader(object): + """Handles upgrading a set of Python files using a given API change spec.""" + + def __init__(self, api_change_spec): + if not isinstance(api_change_spec, APIChangeSpec): + raise TypeError("Must pass APIChangeSpec to ASTCodeUpgrader, got %s" % + type(api_change_spec)) + self._api_change_spec = api_change_spec + + def process_file(self, in_filename, out_filename): + """Process the given python file for incompatible changes. + + Args: + in_filename: filename to parse + out_filename: output file to write to + Returns: + A tuple representing number of files processed, log of actions, errors + """ + + # Write to a temporary file, just in case we are doing an implace modify. + with open(in_filename, "r") as in_file, \ + tempfile.NamedTemporaryFile("w", delete=False) as temp_file: + ret = self.process_opened_file(in_filename, in_file, out_filename, + temp_file) + + shutil.move(temp_file.name, out_filename) + return ret + + # Broad exceptions are required here because ast throws whatever it wants. + # pylint: disable=broad-except + def process_opened_file(self, in_filename, in_file, out_filename, out_file): + """Process the given python file for incompatible changes. + + This function is split out to facilitate StringIO testing from + tf_upgrade_test.py. + + Args: + in_filename: filename to parse + in_file: opened file (or StringIO) + out_filename: output file to write to + out_file: opened file (or StringIO) + Returns: + A tuple representing number of files processed, log of actions, errors + """ + process_errors = [] + text = "-" * 80 + "\n" + text += "Processing file %r\n outputting to %r\n" % (in_filename, + out_filename) + text += "-" * 80 + "\n\n" + + parsed_ast = None + lines = in_file.readlines() + try: + parsed_ast = ast.parse("".join(lines)) + except Exception: + text += "Failed to parse %r\n\n" % in_filename + text += traceback.format_exc() + if parsed_ast: + visitor = _ASTCallVisitor(in_filename, lines, self._api_change_spec) + visitor.visit(parsed_ast) + out_text, new_text, process_errors = visitor.process(lines) + text += new_text + if out_file: + out_file.write(out_text) + text += "\n" + return 1, text, process_errors + + # pylint: enable=broad-except + + def process_tree(self, root_directory, output_root_directory, + copy_other_files): + """Processes upgrades on an entire tree of python files in place. + + Note that only Python files. If you have custom code in other languages, + you will need to manually upgrade those. + + Args: + root_directory: Directory to walk and process. + output_root_directory: Directory to use as base. + copy_other_files: Copy files that are not touched by this converter. + + Returns: + A tuple of files processed, the report string ofr all files, and errors + """ + + # make sure output directory doesn't exist + if output_root_directory and os.path.exists(output_root_directory): + print("Output directory %r must not already exist." % + (output_root_directory)) + sys.exit(1) + + # make sure output directory does not overlap with root_directory + norm_root = os.path.split(os.path.normpath(root_directory)) + norm_output = os.path.split(os.path.normpath(output_root_directory)) + if norm_root == norm_output: + print("Output directory %r same as input directory %r" % + (root_directory, output_root_directory)) + sys.exit(1) + + # Collect list of files to process (we do this to correctly handle if the + # user puts the output directory in some sub directory of the input dir) + files_to_process = [] + files_to_copy = [] + for dir_name, _, file_list in os.walk(root_directory): + py_files = [f for f in file_list if f.endswith(".py")] + copy_files = [f for f in file_list if not f.endswith(".py")] + for filename in py_files: + fullpath = os.path.join(dir_name, filename) + fullpath_output = os.path.join(output_root_directory, + os.path.relpath(fullpath, + root_directory)) + files_to_process.append((fullpath, fullpath_output)) + if copy_other_files: + for filename in copy_files: + fullpath = os.path.join(dir_name, filename) + fullpath_output = os.path.join(output_root_directory, + os.path.relpath( + fullpath, root_directory)) + files_to_copy.append((fullpath, fullpath_output)) + + file_count = 0 + tree_errors = [] + report = "" + report += ("=" * 80) + "\n" + report += "Input tree: %r\n" % root_directory + report += ("=" * 80) + "\n" + + for input_path, output_path in files_to_process: + output_directory = os.path.dirname(output_path) + if not os.path.isdir(output_directory): + os.makedirs(output_directory) + file_count += 1 + _, l_report, l_errors = self.process_file(input_path, output_path) + tree_errors += l_errors + report += l_report + for input_path, output_path in files_to_copy: + output_directory = os.path.dirname(output_path) + if not os.path.isdir(output_directory): + os.makedirs(output_directory) + shutil.copy(input_path, output_path) + return file_count, report, tree_errors diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..216aa41b60eb566db37244b72cbeef024546607f --- /dev/null +++ b/tensorflow/tools/compatibility/renames_v2.py @@ -0,0 +1,134 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=line-too-long +"""List of renames to apply when converting from TF 1.0 to TF 2.0. + +THIS FILE IS AUTOGENERATED: To update, please run: + bazel build tensorflow/tools/compatibility/update:generate_v2_renames_map + bazel-bin/tensorflow/tools/compatibility/update/generate_v2_renames_map +This file should be updated whenever endpoints are deprecated. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +renames = { + 'tf.acos': 'tf.math.acos', + 'tf.acosh': 'tf.math.acosh', + 'tf.add': 'tf.math.add', + 'tf.as_string': 'tf.dtypes.as_string', + 'tf.asin': 'tf.math.asin', + 'tf.asinh': 'tf.math.asinh', + 'tf.atan': 'tf.math.atan', + 'tf.atan2': 'tf.math.atan2', + 'tf.atanh': 'tf.math.atanh', + 'tf.batch_to_space_nd': 'tf.manip.batch_to_space_nd', + 'tf.betainc': 'tf.math.betainc', + 'tf.ceil': 'tf.math.ceil', + 'tf.check_numerics': 'tf.debugging.check_numerics', + 'tf.cholesky': 'tf.linalg.cholesky', + 'tf.cos': 'tf.math.cos', + 'tf.cosh': 'tf.math.cosh', + 'tf.cross': 'tf.linalg.cross', + 'tf.decode_base64': 'tf.io.decode_base64', + 'tf.decode_compressed': 'tf.io.decode_compressed', + 'tf.decode_json_example': 'tf.io.decode_json_example', + 'tf.decode_raw': 'tf.io.decode_raw', + 'tf.dequantize': 'tf.quantization.dequantize', + 'tf.diag': 'tf.linalg.tensor_diag', + 'tf.diag_part': 'tf.linalg.tensor_diag_part', + 'tf.digamma': 'tf.math.digamma', + 'tf.encode_base64': 'tf.io.encode_base64', + 'tf.equal': 'tf.math.equal', + 'tf.erfc': 'tf.math.erfc', + 'tf.exp': 'tf.math.exp', + 'tf.expm1': 'tf.math.expm1', + 'tf.extract_image_patches': 'tf.image.extract_image_patches', + 'tf.fake_quant_with_min_max_args': 'tf.quantization.fake_quant_with_min_max_args', + 'tf.fake_quant_with_min_max_args_gradient': 'tf.quantization.fake_quant_with_min_max_args_gradient', + 'tf.fake_quant_with_min_max_vars': 'tf.quantization.fake_quant_with_min_max_vars', + 'tf.fake_quant_with_min_max_vars_gradient': 'tf.quantization.fake_quant_with_min_max_vars_gradient', + 'tf.fake_quant_with_min_max_vars_per_channel': 'tf.quantization.fake_quant_with_min_max_vars_per_channel', + 'tf.fake_quant_with_min_max_vars_per_channel_gradient': 'tf.quantization.fake_quant_with_min_max_vars_per_channel_gradient', + 'tf.fft': 'tf.spectral.fft', + 'tf.floor': 'tf.math.floor', + 'tf.gather_nd': 'tf.manip.gather_nd', + 'tf.greater': 'tf.math.greater', + 'tf.greater_equal': 'tf.math.greater_equal', + 'tf.ifft': 'tf.spectral.ifft', + 'tf.igamma': 'tf.math.igamma', + 'tf.igammac': 'tf.math.igammac', + 'tf.invert_permutation': 'tf.math.invert_permutation', + 'tf.is_finite': 'tf.debugging.is_finite', + 'tf.is_inf': 'tf.debugging.is_inf', + 'tf.is_nan': 'tf.debugging.is_nan', + 'tf.less': 'tf.math.less', + 'tf.less_equal': 'tf.math.less_equal', + 'tf.lgamma': 'tf.math.lgamma', + 'tf.log': 'tf.math.log', + 'tf.log1p': 'tf.math.log1p', + 'tf.logical_and': 'tf.math.logical_and', + 'tf.logical_not': 'tf.math.logical_not', + 'tf.logical_or': 'tf.math.logical_or', + 'tf.matching_files': 'tf.io.matching_files', + 'tf.matrix_band_part': 'tf.linalg.band_part', + 'tf.matrix_determinant': 'tf.linalg.det', + 'tf.matrix_diag': 'tf.linalg.diag', + 'tf.matrix_diag_part': 'tf.linalg.diag_part', + 'tf.matrix_inverse': 'tf.linalg.inv', + 'tf.matrix_set_diag': 'tf.linalg.set_diag', + 'tf.matrix_solve': 'tf.linalg.solve', + 'tf.matrix_triangular_solve': 'tf.linalg.triangular_solve', + 'tf.maximum': 'tf.math.maximum', + 'tf.minimum': 'tf.math.minimum', + 'tf.not_equal': 'tf.math.not_equal', + 'tf.parse_tensor': 'tf.io.parse_tensor', + 'tf.polygamma': 'tf.math.polygamma', + 'tf.qr': 'tf.linalg.qr', + 'tf.quantized_concat': 'tf.quantization.quantized_concat', + 'tf.read_file': 'tf.io.read_file', + 'tf.reciprocal': 'tf.math.reciprocal', + 'tf.regex_replace': 'tf.strings.regex_replace', + 'tf.reshape': 'tf.manip.reshape', + 'tf.reverse': 'tf.manip.reverse', + 'tf.reverse_v2': 'tf.manip.reverse', + 'tf.rint': 'tf.math.rint', + 'tf.rsqrt': 'tf.math.rsqrt', + 'tf.scatter_nd': 'tf.manip.scatter_nd', + 'tf.segment_max': 'tf.math.segment_max', + 'tf.segment_mean': 'tf.math.segment_mean', + 'tf.segment_min': 'tf.math.segment_min', + 'tf.segment_prod': 'tf.math.segment_prod', + 'tf.segment_sum': 'tf.math.segment_sum', + 'tf.sin': 'tf.math.sin', + 'tf.sinh': 'tf.math.sinh', + 'tf.space_to_batch_nd': 'tf.manip.space_to_batch_nd', + 'tf.squared_difference': 'tf.math.squared_difference', + 'tf.string_join': 'tf.strings.join', + 'tf.string_strip': 'tf.strings.strip', + 'tf.string_to_hash_bucket': 'tf.strings.to_hash_bucket', + 'tf.string_to_hash_bucket_fast': 'tf.strings.to_hash_bucket_fast', + 'tf.string_to_hash_bucket_strong': 'tf.strings.to_hash_bucket_strong', + 'tf.string_to_number': 'tf.strings.to_number', + 'tf.substr': 'tf.strings.substr', + 'tf.tan': 'tf.math.tan', + 'tf.tile': 'tf.manip.tile', + 'tf.unsorted_segment_max': 'tf.math.unsorted_segment_max', + 'tf.unsorted_segment_min': 'tf.math.unsorted_segment_min', + 'tf.unsorted_segment_prod': 'tf.math.unsorted_segment_prod', + 'tf.unsorted_segment_sum': 'tf.math.unsorted_segment_sum', + 'tf.write_file': 'tf.io.write_file', + 'tf.zeta': 'tf.math.zeta' +} diff --git a/tensorflow/tools/compatibility/testdata/test_file_v1_10.py b/tensorflow/tools/compatibility/testdata/test_file_v1_10.py new file mode 100644 index 0000000000000000000000000000000000000000..a49035a1a09bb6b6ea33a375766c9c414f871df1 --- /dev/null +++ b/tensorflow/tools/compatibility/testdata/test_file_v1_10.py @@ -0,0 +1,34 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf upgrader.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tensorflow as tf +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test as test_lib + + +class TestUpgrade(test_util.TensorFlowTestCase): + """Test various APIs that have been changed in 2.0.""" + + def testRenames(self): + with self.test_session(): + self.assertAllClose(1.04719755, tf.acos(0.5).eval()) + self.assertAllClose(0.5, tf.rsqrt(4.0).eval()) + +if __name__ == "__main__": + test_lib.main() diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py index 1f8833582af4c922115e637117e775e619439786..96705b1a4c27e72ba1d50f16dad10c35705b1782 100644 --- a/tensorflow/tools/compatibility/tf_upgrade.py +++ b/tensorflow/tools/compatibility/tf_upgrade.py @@ -19,491 +19,11 @@ from __future__ import division from __future__ import print_function import argparse -import ast -import collections -import os -import shutil -import sys -import tempfile -import traceback +from tensorflow.tools.compatibility import ast_edits -class APIChangeSpec(object): - """This class defines the transformations that need to happen. - This class must provide the following fields: - - * `function_keyword_renames`: maps function names to a map of old -> new - argument names - * `function_renames`: maps function names to new function names - * `change_to_function`: a set of function names that have changed (for - notifications) - * `function_reorders`: maps functions whose argument order has changed to the - list of arguments in the new order - * `function_handle`: maps function names to custom handlers for the function - - For an example, see `TFAPIChangeSpec`. - """ - - -class _FileEditTuple( - collections.namedtuple("_FileEditTuple", - ["comment", "line", "start", "old", "new"])): - """Each edit that is recorded by a _FileEditRecorder. - - Fields: - comment: A description of the edit and why it was made. - line: The line number in the file where the edit occurs (1-indexed). - start: The line number in the file where the edit occurs (0-indexed). - old: text string to remove (this must match what was in file). - new: text string to add in place of `old`. - """ - - __slots__ = () - - -class _FileEditRecorder(object): - """Record changes that need to be done to the file.""" - - def __init__(self, filename): - # all edits are lists of chars - self._filename = filename - - self._line_to_edit = collections.defaultdict(list) - self._errors = [] - - def process(self, text): - """Process a list of strings, each corresponding to the recorded changes. - - Args: - text: A list of lines of text (assumed to contain newlines) - Returns: - A tuple of the modified text and a textual description of what is done. - Raises: - ValueError: if substitution source location does not have expected text. - """ - - change_report = "" - - # Iterate of each line - for line, edits in self._line_to_edit.items(): - offset = 0 - # sort by column so that edits are processed in order in order to make - # indexing adjustments cumulative for changes that change the string - # length - edits.sort(key=lambda x: x.start) - - # Extract each line to a list of characters, because mutable lists - # are editable, unlike immutable strings. - char_array = list(text[line - 1]) - - # Record a description of the change - change_report += "%r Line %d\n" % (self._filename, line) - change_report += "-" * 80 + "\n\n" - for e in edits: - change_report += "%s\n" % e.comment - change_report += "\n Old: %s" % (text[line - 1]) - - # Make underscore buffers for underlining where in the line the edit was - change_list = [" "] * len(text[line - 1]) - change_list_new = [" "] * len(text[line - 1]) - - # Iterate for each edit - for e in edits: - # Create effective start, end by accounting for change in length due - # to previous edits - start_eff = e.start + offset - end_eff = start_eff + len(e.old) - - # Make sure the edit is changing what it should be changing - old_actual = "".join(char_array[start_eff:end_eff]) - if old_actual != e.old: - raise ValueError("Expected text %r but got %r" % - ("".join(e.old), "".join(old_actual))) - # Make the edit - char_array[start_eff:end_eff] = list(e.new) - - # Create the underline highlighting of the before and after - change_list[e.start:e.start + len(e.old)] = "~" * len(e.old) - change_list_new[start_eff:end_eff] = "~" * len(e.new) - - # Keep track of how to generate effective ranges - offset += len(e.new) - len(e.old) - - # Finish the report comment - change_report += " %s\n" % "".join(change_list) - text[line - 1] = "".join(char_array) - change_report += " New: %s" % (text[line - 1]) - change_report += " %s\n\n" % "".join(change_list_new) - return "".join(text), change_report, self._errors - - def add(self, comment, line, start, old, new, error=None): - """Add a new change that is needed. - - Args: - comment: A description of what was changed - line: Line number (1 indexed) - start: Column offset (0 indexed) - old: old text - new: new text - error: this "edit" is something that cannot be fixed automatically - Returns: - None - """ - - self._line_to_edit[line].append( - _FileEditTuple(comment, line, start, old, new)) - if error: - self._errors.append("%s:%d: %s" % (self._filename, line, error)) - - -class _ASTCallVisitor(ast.NodeVisitor): - """AST Visitor that processes function calls. - - Updates function calls from old API version to new API version using a given - change spec. - """ - - def __init__(self, filename, lines, api_change_spec): - self._filename = filename - self._file_edit = _FileEditRecorder(filename) - self._lines = lines - self._api_change_spec = api_change_spec - - def process(self, lines): - return self._file_edit.process(lines) - - def generic_visit(self, node): - ast.NodeVisitor.generic_visit(self, node) - - def _rename_functions(self, node, full_name): - function_renames = self._api_change_spec.function_renames - try: - new_name = function_renames[full_name] - self._file_edit.add("Renamed function %r to %r" % (full_name, new_name), - node.lineno, node.col_offset, full_name, new_name) - except KeyError: - pass - - def _get_attribute_full_path(self, node): - """Traverse an attribute to generate a full name e.g. tf.foo.bar. - - Args: - node: A Node of type Attribute. - - Returns: - a '.'-delimited full-name or None if the tree was not a simple form. - i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c". - """ - curr = node - items = [] - while not isinstance(curr, ast.Name): - if not isinstance(curr, ast.Attribute): - return None - items.append(curr.attr) - curr = curr.value - items.append(curr.id) - return ".".join(reversed(items)) - - def _find_true_position(self, node): - """Return correct line number and column offset for a given node. - - This is necessary mainly because ListComp's location reporting reports - the next token after the list comprehension list opening. - - Args: - node: Node for which we wish to know the lineno and col_offset - """ - import re - find_open = re.compile("^\s*(\\[).*$") - find_string_chars = re.compile("['\"]") - - if isinstance(node, ast.ListComp): - # Strangely, ast.ListComp returns the col_offset of the first token - # after the '[' token which appears to be a bug. Workaround by - # explicitly finding the real start of the list comprehension. - line = node.lineno - col = node.col_offset - # loop over lines - while 1: - # Reverse the text to and regular expression search for whitespace - text = self._lines[line - 1] - reversed_preceding_text = text[:col][::-1] - # First find if a [ can be found with only whitespace between it and - # col. - m = find_open.match(reversed_preceding_text) - if m: - new_col_offset = col - m.start(1) - 1 - return line, new_col_offset - else: - if (reversed_preceding_text == "" or - reversed_preceding_text.isspace()): - line = line - 1 - prev_line = self._lines[line - 1] - # TODO(aselle): - # this is poor comment detection, but it is good enough for - # cases where the comment does not contain string literal starting/ - # ending characters. If ast gave us start and end locations of the - # ast nodes rather than just start, we could use string literal - # node ranges to filter out spurious #'s that appear in string - # literals. - comment_start = prev_line.find("#") - if comment_start == -1: - col = len(prev_line) - 1 - elif find_string_chars.search(prev_line[comment_start:]) is None: - col = comment_start - else: - return None, None - else: - return None, None - # Most other nodes return proper locations (with notably does not), but - # it is not possible to use that in an argument. - return node.lineno, node.col_offset - - def visit_Call(self, node): # pylint: disable=invalid-name - """Handle visiting a call node in the AST. - - Args: - node: Current Node - """ - - # Find a simple attribute name path e.g. "tf.foo.bar" - full_name = self._get_attribute_full_path(node.func) - - # Make sure the func is marked as being part of a call - node.func.is_function_for_call = True - - if full_name: - # Call special handlers - function_handles = self._api_change_spec.function_handle - if full_name in function_handles: - function_handles[full_name](self._file_edit, node) - - # Examine any non-keyword argument and make it into a keyword argument - # if reordering required. - function_reorders = self._api_change_spec.function_reorders - function_keyword_renames = ( - self._api_change_spec.function_keyword_renames) - - if full_name in function_reorders: - reordered = function_reorders[full_name] - for idx, arg in enumerate(node.args): - lineno, col_offset = self._find_true_position(arg) - if lineno is None or col_offset is None: - self._file_edit.add( - "Failed to add keyword %r to reordered function %r" % - (reordered[idx], full_name), - arg.lineno, - arg.col_offset, - "", - "", - error="A necessary keyword argument failed to be inserted.") - else: - keyword_arg = reordered[idx] - if (full_name in function_keyword_renames and - keyword_arg in function_keyword_renames[full_name]): - keyword_arg = function_keyword_renames[full_name][keyword_arg] - self._file_edit.add("Added keyword %r to reordered function %r" % - (reordered[idx], full_name), lineno, col_offset, - "", keyword_arg + "=") - - # Examine each keyword argument and convert it to the final renamed form - renamed_keywords = ({} if full_name not in function_keyword_renames else - function_keyword_renames[full_name]) - for keyword in node.keywords: - argkey = keyword.arg - argval = keyword.value - - if argkey in renamed_keywords: - argval_lineno, argval_col_offset = self._find_true_position(argval) - if argval_lineno is not None and argval_col_offset is not None: - # TODO(aselle): We should scan backward to find the start of the - # keyword key. Unfortunately ast does not give you the location of - # keyword keys, so we are forced to infer it from the keyword arg - # value. - key_start = argval_col_offset - len(argkey) - 1 - key_end = key_start + len(argkey) + 1 - if (self._lines[argval_lineno - 1][key_start:key_end] == argkey + - "="): - self._file_edit.add("Renamed keyword argument from %r to %r" % - (argkey, - renamed_keywords[argkey]), argval_lineno, - argval_col_offset - len(argkey) - 1, - argkey + "=", renamed_keywords[argkey] + "=") - continue - self._file_edit.add( - "Failed to rename keyword argument from %r to %r" % - (argkey, renamed_keywords[argkey]), - argval.lineno, - argval.col_offset - len(argkey) - 1, - "", - "", - error="Failed to find keyword lexographically. Fix manually.") - - ast.NodeVisitor.generic_visit(self, node) - - def visit_Attribute(self, node): # pylint: disable=invalid-name - """Handle bare Attributes i.e. [tf.foo, tf.bar]. - - Args: - node: Node that is of type ast.Attribute - """ - full_name = self._get_attribute_full_path(node) - if full_name: - self._rename_functions(node, full_name) - if full_name in self._api_change_spec.change_to_function: - if not hasattr(node, "is_function_for_call"): - new_text = full_name + "()" - self._file_edit.add("Changed %r to %r" % (full_name, new_text), - node.lineno, node.col_offset, full_name, new_text) - - ast.NodeVisitor.generic_visit(self, node) - - -class ASTCodeUpgrader(object): - """Handles upgrading a set of Python files using a given API change spec.""" - - def __init__(self, api_change_spec): - if not isinstance(api_change_spec, APIChangeSpec): - raise TypeError("Must pass APIChangeSpec to ASTCodeUpgrader, got %s" % - type(api_change_spec)) - self._api_change_spec = api_change_spec - - def process_file(self, in_filename, out_filename): - """Process the given python file for incompatible changes. - - Args: - in_filename: filename to parse - out_filename: output file to write to - Returns: - A tuple representing number of files processed, log of actions, errors - """ - - # Write to a temporary file, just in case we are doing an implace modify. - with open(in_filename, "r") as in_file, \ - tempfile.NamedTemporaryFile("w", delete=False) as temp_file: - ret = self.process_opened_file(in_filename, in_file, out_filename, - temp_file) - - shutil.move(temp_file.name, out_filename) - return ret - - # Broad exceptions are required here because ast throws whatever it wants. - # pylint: disable=broad-except - def process_opened_file(self, in_filename, in_file, out_filename, out_file): - """Process the given python file for incompatible changes. - - This function is split out to facilitate StringIO testing from - tf_upgrade_test.py. - - Args: - in_filename: filename to parse - in_file: opened file (or StringIO) - out_filename: output file to write to - out_file: opened file (or StringIO) - Returns: - A tuple representing number of files processed, log of actions, errors - """ - process_errors = [] - text = "-" * 80 + "\n" - text += "Processing file %r\n outputting to %r\n" % (in_filename, - out_filename) - text += "-" * 80 + "\n\n" - - parsed_ast = None - lines = in_file.readlines() - try: - parsed_ast = ast.parse("".join(lines)) - except Exception: - text += "Failed to parse %r\n\n" % in_filename - text += traceback.format_exc() - if parsed_ast: - visitor = _ASTCallVisitor(in_filename, lines, self._api_change_spec) - visitor.visit(parsed_ast) - out_text, new_text, process_errors = visitor.process(lines) - text += new_text - if out_file: - out_file.write(out_text) - text += "\n" - return 1, text, process_errors - - # pylint: enable=broad-except - - def process_tree(self, root_directory, output_root_directory, - copy_other_files): - """Processes upgrades on an entire tree of python files in place. - - Note that only Python files. If you have custom code in other languages, - you will need to manually upgrade those. - - Args: - root_directory: Directory to walk and process. - output_root_directory: Directory to use as base. - copy_other_files: Copy files that are not touched by this converter. - - Returns: - A tuple of files processed, the report string ofr all files, and errors - """ - - # make sure output directory doesn't exist - if output_root_directory and os.path.exists(output_root_directory): - print("Output directory %r must not already exist." % - (output_root_directory)) - sys.exit(1) - - # make sure output directory does not overlap with root_directory - norm_root = os.path.split(os.path.normpath(root_directory)) - norm_output = os.path.split(os.path.normpath(output_root_directory)) - if norm_root == norm_output: - print("Output directory %r same as input directory %r" % - (root_directory, output_root_directory)) - sys.exit(1) - - # Collect list of files to process (we do this to correctly handle if the - # user puts the output directory in some sub directory of the input dir) - files_to_process = [] - files_to_copy = [] - for dir_name, _, file_list in os.walk(root_directory): - py_files = [f for f in file_list if f.endswith(".py")] - copy_files = [f for f in file_list if not f.endswith(".py")] - for filename in py_files: - fullpath = os.path.join(dir_name, filename) - fullpath_output = os.path.join(output_root_directory, - os.path.relpath(fullpath, - root_directory)) - files_to_process.append((fullpath, fullpath_output)) - if copy_other_files: - for filename in copy_files: - fullpath = os.path.join(dir_name, filename) - fullpath_output = os.path.join(output_root_directory, - os.path.relpath( - fullpath, root_directory)) - files_to_copy.append((fullpath, fullpath_output)) - - file_count = 0 - tree_errors = [] - report = "" - report += ("=" * 80) + "\n" - report += "Input tree: %r\n" % root_directory - report += ("=" * 80) + "\n" - - for input_path, output_path in files_to_process: - output_directory = os.path.dirname(output_path) - if not os.path.isdir(output_directory): - os.makedirs(output_directory) - file_count += 1 - _, l_report, l_errors = self.process_file(input_path, output_path) - tree_errors += l_errors - report += l_report - for input_path, output_path in files_to_copy: - output_directory = os.path.dirname(output_path) - if not os.path.isdir(output_directory): - os.makedirs(output_directory) - shutil.copy(input_path, output_path) - return file_count, report, tree_errors - - -class TFAPIChangeSpec(APIChangeSpec): +class TFAPIChangeSpec(ast_edits.APIChangeSpec): """List of maps that describe what changed in the API.""" def __init__(self): @@ -718,7 +238,7 @@ Simple usage: default="report.txt") args = parser.parse_args() - upgrade = ASTCodeUpgrader(TFAPIChangeSpec()) + upgrade = ast_edits.ASTCodeUpgrader(TFAPIChangeSpec()) report_text = None report_filename = args.report_filename files_processed = 0 diff --git a/tensorflow/tools/compatibility/tf_upgrade_test.py b/tensorflow/tools/compatibility/tf_upgrade_test.py index 3d02eacba6e7a91e6d3c88e8297306de9782f4bf..66325ea2ad36265c6c3779b414774abab8213a84 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_test.py @@ -22,6 +22,7 @@ import tempfile import six from tensorflow.python.framework import test_util from tensorflow.python.platform import test as test_lib +from tensorflow.tools.compatibility import ast_edits from tensorflow.tools.compatibility import tf_upgrade @@ -36,7 +37,7 @@ class TestUpgrade(test_util.TensorFlowTestCase): def _upgrade(self, old_file_text): in_file = six.StringIO(old_file_text) out_file = six.StringIO() - upgrader = tf_upgrade.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec()) + upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec()) count, report, errors = ( upgrader.process_opened_file("test.py", in_file, "test_out.py", out_file)) @@ -139,7 +140,7 @@ class TestUpgradeFiles(test_util.TensorFlowTestCase): upgraded = "tf.multiply(a, b)\n" temp_file.write(original) temp_file.close() - upgrader = tf_upgrade.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec()) + upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec()) upgrader.process_file(temp_file.name, temp_file.name) self.assertAllEqual(open(temp_file.name).read(), upgraded) os.unlink(temp_file.name) diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..9702430a1219c33e6d68875e1366ee7ebb2ce308 --- /dev/null +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -0,0 +1,115 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Upgrader for Python scripts from 1.* TensorFlow to 2.0 TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse + +from tensorflow.tools.compatibility import ast_edits +from tensorflow.tools.compatibility import renames_v2 + + +class TFAPIChangeSpec(ast_edits.APIChangeSpec): + """List of maps that describe what changed in the API.""" + + def __init__(self): + # Maps from a function name to a dictionary that describes how to + # map from an old argument keyword to the new argument keyword. + self.function_keyword_renames = {} + + # Mapping from function to the new name of the function + self.function_renames = renames_v2.renames + + # Variables that should be changed to functions. + self.change_to_function = {} + + # Functions that were reordered should be changed to the new keyword args + # for safety, if positional arguments are used. If you have reversed the + # positional arguments yourself, this could do the wrong thing. + self.function_reorders = {} + + # Specially handled functions. + self.function_handle = {} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description="""Convert a TensorFlow Python file to 2.0 + +Simple usage: + tf_convert_v2.py --infile foo.py --outfile bar.py + tf_convert_v2.py --intree ~/code/old --outtree ~/code/new +""") + parser.add_argument( + "--infile", + dest="input_file", + help="If converting a single file, the name of the file " + "to convert") + parser.add_argument( + "--outfile", + dest="output_file", + help="If converting a single file, the output filename.") + parser.add_argument( + "--intree", + dest="input_tree", + help="If converting a whole tree of files, the directory " + "to read from (relative or absolute).") + parser.add_argument( + "--outtree", + dest="output_tree", + help="If converting a whole tree of files, the output " + "directory (relative or absolute).") + parser.add_argument( + "--copyotherfiles", + dest="copy_other_files", + help=("If converting a whole tree of files, whether to " + "copy the other files."), + type=bool, + default=False) + parser.add_argument( + "--reportfile", + dest="report_filename", + help=("The name of the file where the report log is " + "stored." + "(default: %(default)s)"), + default="report.txt") + args = parser.parse_args() + + upgrade = ast_edits.ASTCodeUpgrader(TFAPIChangeSpec()) + report_text = None + report_filename = args.report_filename + files_processed = 0 + if args.input_file: + files_processed, report_text, errors = upgrade.process_file( + args.input_file, args.output_file) + files_processed = 1 + elif args.input_tree: + files_processed, report_text, errors = upgrade.process_tree( + args.input_tree, args.output_tree, args.copy_other_files) + else: + parser.print_help() + if report_text: + open(report_filename, "w").write(report_text) + print("TensorFlow 2.0 Upgrade Script") + print("-----------------------------") + print("Converted %d files\n" % files_processed) + print("Detected %d errors that require attention" % len(errors)) + print("-" * 80) + print("\n".join(errors)) + print("\nMake sure to read the detailed log %r\n" % report_filename) diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..57ac04de0667b83b66853b7cee7b4a34bc9f2f2f --- /dev/null +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py @@ -0,0 +1,83 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf 2.0 upgrader.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import tempfile +import six +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test as test_lib +from tensorflow.tools.compatibility import ast_edits +from tensorflow.tools.compatibility import tf_upgrade_v2 + + +class TestUpgrade(test_util.TensorFlowTestCase): + """Test various APIs that have been changed in 2.0. + + We also test whether a converted file is executable. test_file_v1_10.py + aims to exhaustively test that API changes are convertible and actually + work when run with current TensorFlow. + """ + + def _upgrade(self, old_file_text): + in_file = six.StringIO(old_file_text) + out_file = six.StringIO() + upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec()) + count, report, errors = ( + upgrader.process_opened_file("test.py", in_file, + "test_out.py", out_file)) + return count, report, errors, out_file.getvalue() + + def testParseError(self): + _, report, unused_errors, unused_new_text = self._upgrade( + "import tensorflow as tf\na + \n") + self.assertTrue(report.find("Failed to parse") != -1) + + def testReport(self): + text = "tf.acos(a)\n" + _, report, unused_errors, unused_new_text = self._upgrade(text) + # This is not a complete test, but it is a sanity test that a report + # is generating information. + self.assertTrue(report.find("Renamed function `tf.acos` to `tf.math.acos`")) + + def testRename(self): + text = "tf.acos(a)\n" + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual(new_text, "tf.math.acos(a)\n") + text = "tf.rsqrt(tf.log(3.8))\n" + _, unused_report, unused_errors, new_text = self._upgrade(text) + self.assertEqual(new_text, "tf.math.rsqrt(tf.math.log(3.8))\n") + + +class TestUpgradeFiles(test_util.TensorFlowTestCase): + + def testInplace(self): + """Check to make sure we don't have a file system race.""" + temp_file = tempfile.NamedTemporaryFile("w", delete=False) + original = "tf.acos(a, b)\n" + upgraded = "tf.math.acos(a, b)\n" + temp_file.write(original) + temp_file.close() + upgrader = ast_edits.ASTCodeUpgrader(tf_upgrade_v2.TFAPIChangeSpec()) + upgrader.process_file(temp_file.name, temp_file.name) + self.assertAllEqual(open(temp_file.name).read(), upgraded) + os.unlink(temp_file.name) + + +if __name__ == "__main__": + test_lib.main() diff --git a/tensorflow/tools/compatibility/update/BUILD b/tensorflow/tools/compatibility/update/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..feb37c902ec3359e6221937f4334ab2504394fa3 --- /dev/null +++ b/tensorflow/tools/compatibility/update/BUILD @@ -0,0 +1,15 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +py_binary( + name = "generate_v2_renames_map", + srcs = ["generate_v2_renames_map.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python:lib", + "//tensorflow/tools/common:public_api", + "//tensorflow/tools/common:traverse", + ], +) diff --git a/tensorflow/tools/compatibility/update/generate_v2_renames_map.py b/tensorflow/tools/compatibility/update/generate_v2_renames_map.py new file mode 100644 index 0000000000000000000000000000000000000000..567eceb0b6595ceac624fe8211f22885a6490d85 --- /dev/null +++ b/tensorflow/tools/compatibility/update/generate_v2_renames_map.py @@ -0,0 +1,103 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=line-too-long +"""Script for updating tensorflow/tools/compatibility/renames_v2.py. + +To update renames_v2.py, run: + bazel build tensorflow/tools/compatibility/update:generate_v2_renames_map + bazel-bin/tensorflow/tools/compatibility/update/generate_v2_renames_map +""" +# pylint: enable=line-too-long + +import tensorflow as tf + +from tensorflow.python.lib.io import file_io +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_export +from tensorflow.tools.common import public_api +from tensorflow.tools.common import traverse + + +_OUTPUT_FILE_PATH = 'third_party/tensorflow/tools/compatibility/renames_v2.py' +_FILE_HEADER = """# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=line-too-long +\"\"\"List of renames to apply when converting from TF 1.0 to TF 2.0. + +THIS FILE IS AUTOGENERATED: To update, please run: + bazel build tensorflow/tools/compatibility/update:generate_v2_renames_map + bazel-bin/tensorflow/tools/compatibility/update/generate_v2_renames_map +This file should be updated whenever endpoints are deprecated. +\"\"\" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +""" + + +def update_renames_v2(output_file_path): + """Writes a Python dictionary mapping deprecated to canonical API names. + + Args: + output_file_path: File path to write output to. Any existing contents + would be replaced. + """ + # Set of rename lines to write to output file in the form: + # 'tf.deprecated_name': 'tf.canonical_name' + rename_line_set = set() + # _tf_api_names attribute name + tensorflow_api_attr = tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names + + def visit(unused_path, unused_parent, children): + """Visitor that collects rename strings to add to rename_line_set.""" + for child in children: + _, attr = tf_decorator.unwrap(child[1]) + if not hasattr(attr, '__dict__'): + continue + api_names = attr.__dict__.get(tensorflow_api_attr, []) + deprecated_api_names = attr.__dict__.get('_tf_deprecated_api_names', []) + canonical_name = tf_export.get_canonical_name( + api_names, deprecated_api_names) + for name in deprecated_api_names: + rename_line_set.add(' \'tf.%s\': \'tf.%s\'' % (name, canonical_name)) + + visitor = public_api.PublicAPIVisitor(visit) + visitor.do_not_descend_map['tf'].append('contrib') + traverse.traverse(tf, visitor) + + renames_file_text = '%srenames = {\n%s\n}\n' % ( + _FILE_HEADER, ',\n'.join(sorted(rename_line_set))) + file_io.write_string_to_file(output_file_path, renames_file_text) + + +def main(unused_argv): + update_renames_v2(_OUTPUT_FILE_PATH) + + +if __name__ == '__main__': + tf.app.run(main=main) diff --git a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl index 8bdc03eb0f19fd6daae826727f429bc1255f0eca..4bfcc2570cce9c8dac369b7c9cf882356c428df5 100644 --- a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl +++ b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl @@ -48,6 +48,7 @@ EXCLUDE_RE = re.compile(r"RTTI|deleting destructor|::internal::") INCLUDEPRE_RE = re.compile(r"google::protobuf::internal::ExplicitlyConstructed|" r"google::protobuf::internal::ArenaImpl::AllocateAligned|" # for contrib/data/_prefetching_ops r"google::protobuf::internal::ArenaImpl::AddCleanup|" # for contrib/data/_prefetching_ops + r"google::protobuf::internal::LogMessage|" # for contrib/data/_prefetching_ops r"google::protobuf::Arena::OnArenaAllocation|" # for contrib/data/_prefetching_ops r"tensorflow::internal::LogMessage|" r"tensorflow::internal::LogString|" diff --git a/tensorflow/tools/def_file_filter/def_file_filter_configure.bzl b/tensorflow/tools/def_file_filter/def_file_filter_configure.bzl index f8f63e276cab61900cba9de599a11efc7718d078..df0fd053194e7b5da2cd656309467ca0f90e4092 100644 --- a/tensorflow/tools/def_file_filter/def_file_filter_configure.bzl +++ b/tensorflow/tools/def_file_filter/def_file_filter_configure.bzl @@ -24,27 +24,27 @@ load("@bazel_tools//tools/cpp:windows_cc_configure.bzl", "find_msvc_tool") load("@bazel_tools//tools/cpp:lib_cc_configure.bzl", "auto_configure_fail") def _def_file_filter_configure_impl(repository_ctx): - if repository_ctx.os.name.lower().find("windows") == -1: + if repository_ctx.os.name.lower().find("windows") == -1: + repository_ctx.symlink(Label("//tensorflow/tools/def_file_filter:BUILD.tpl"), "BUILD") + repository_ctx.file("def_file_filter.py", "") + return + vc_path = find_vc_path(repository_ctx) + if vc_path == None: + auto_configure_fail("Visual C++ build tools not found on your machine") + + undname = find_msvc_tool(repository_ctx, vc_path, "undname.exe") + if undname == None: + auto_configure_fail("Couldn't find undname.exe under %s, please check your VC installation and set BAZEL_VC environment variable correctly." % vc_path) + undname_bin_path = undname.replace("\\", "\\\\") + + repository_ctx.template( + "def_file_filter.py", + Label("//tensorflow/tools/def_file_filter:def_file_filter.py.tpl"), + { + "%{undname_bin_path}": undname_bin_path, + }, + ) repository_ctx.symlink(Label("//tensorflow/tools/def_file_filter:BUILD.tpl"), "BUILD") - repository_ctx.file("def_file_filter.py", "") - return - vc_path = find_vc_path(repository_ctx) - if vc_path == "visual-studio-not-found": - auto_configure_fail("Visual C++ build tools not found on your machine") - - undname = find_msvc_tool(repository_ctx, vc_path, "undname.exe") - if undname == None: - auto_configure_fail("Couldn't find undname.exe under %s, please check your VC installation and set BAZEL_VC environment variable correctly." % vc_path) - undname_bin_path = undname.replace("\\", "\\\\") - - repository_ctx.template( - "def_file_filter.py", - Label("//tensorflow/tools/def_file_filter:def_file_filter.py.tpl"), - { - "%{undname_bin_path}": undname_bin_path, - }) - repository_ctx.symlink(Label("//tensorflow/tools/def_file_filter:BUILD.tpl"), "BUILD") - def_file_filter_configure = repository_rule( implementation = _def_file_filter_configure_impl, @@ -55,6 +55,6 @@ def_file_filter_configure = repository_rule( "VS100COMNTOOLS", "VS110COMNTOOLS", "VS120COMNTOOLS", - "VS140COMNTOOLS" + "VS140COMNTOOLS", ], ) diff --git a/tensorflow/tools/docker/Dockerfile b/tensorflow/tools/docker/Dockerfile index a3ff8211e3e81925722566863c5ad910295a94ba..2c31d784e5ada2b9fef3a6b6a17ed7988bedf749 100644 --- a/tensorflow/tools/docker/Dockerfile +++ b/tensorflow/tools/docker/Dockerfile @@ -29,8 +29,10 @@ RUN pip --no-cache-dir install \ h5py \ ipykernel \ jupyter \ + keras_applications==1.0.4 \ + keras_preprocessing==1.0.2 \ matplotlib \ - numpy \ + numpy==1.14.5 \ pandas \ scipy \ sklearn \ diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel index 57a491255ea968b08e6e9cbaf9dd0178e8d2c3bf..bacdea72ce66f27da496d90d8cdd8b0d75cff336 100644 --- a/tensorflow/tools/docker/Dockerfile.devel +++ b/tensorflow/tools/docker/Dockerfile.devel @@ -33,9 +33,11 @@ RUN pip --no-cache-dir install \ h5py \ ipykernel \ jupyter \ + keras_applications==1.0.4 \ + keras_preprocessing==1.0.2 \ matplotlib \ mock \ - numpy \ + numpy==1.14.5 \ scipy \ sklearn \ pandas \ @@ -63,7 +65,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \ >>/etc/bazel.bazelrc # Install the most recent bazel release. -ENV BAZEL_VERSION 0.11.0 +ENV BAZEL_VERSION 0.15.0 WORKDIR / RUN mkdir /bazel && \ cd /bazel && \ @@ -76,7 +78,7 @@ RUN mkdir /bazel && \ # Download and build TensorFlow. WORKDIR /tensorflow -RUN git clone --branch=r1.9 --depth=1 https://github.com/tensorflow/tensorflow.git . +RUN git clone --branch=r1.10 --depth=1 https://github.com/tensorflow/tensorflow.git . # TODO(craigcitro): Don't install the pip package, since it makes it # more difficult to experiment with local changes. Instead, just add diff --git a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl b/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl deleted file mode 100644 index 6796ad70e5d22ca683343680b142081d8d58a9e4..0000000000000000000000000000000000000000 --- a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl +++ /dev/null @@ -1,83 +0,0 @@ -FROM tensorflow/tensorflow:latest-devel - -LABEL maintainer="Clayne Robison" - -# These arguments are parameterized. Use --build-args to override. -ARG TF_BRANCH=r1.9 -ARG WHL_DIR=/whl - -RUN apt-get update && apt-get install -y --no-install-recommends \ - golang \ - vim \ - emacs \ - && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* - -RUN pip --no-cache-dir install --upgrade \ - pip setuptools - -RUN pip --no-cache-dir install wheel - -# Download and build TensorFlow. -WORKDIR / -RUN rm -rf tensorflow && \ - git clone https://github.com/tensorflow/tensorflow.git && \ - cd tensorflow && \ - git checkout ${TF_BRANCH} -WORKDIR /tensorflow - -# Configure the build for CPU with MKL by accepting default build options and -# setting library locations -ENV CI_BUILD_PYTHON=python \ - LD_LIBRARY_PATH=${LD_LIBRARY_PATH} \ - PYTHON_BIN_PATH=/usr/bin/python \ - PYTHON_LIB_PATH=/usr/local/lib/python2.7/dist-packages \ - CC_OPT_FLAGS='-march=native' \ - TF_NEED_JEMALLOC=0 \ - TF_NEED_GCP=1 \ - TF_NEED_CUDA=0 \ - TF_NEED_HDFS=0 \ - TF_NEED_S3=1 \ - TF_NEED_OPENCL=0 \ - TF_NEED_GDR=0 \ - TF_ENABLE_XLA=0 \ - TF_NEED_VERBS=0 \ - TF_NEED_MPI=0 -RUN ./configure - -# Build and Install TensorFlow. -# The 'mkl' option builds with Intel(R) Math Kernel Library (MKL), which detects -# the platform it is currently running on and takes appropriately optimized -# paths. The -march=native option is for code that is not in MKL, and assumes -# this container will be run on the same architecture on which it is built. -RUN LD_LIBRARY_PATH=${LD_LIBRARY_PATH} \ - bazel build --config=mkl \ - --config="opt" \ - --copt="-march=broadwell" \ - --copt="-O3" \ - //tensorflow/tools/pip_package:build_pip_package && \ - mkdir ${WHL_DIR} && \ - bazel-bin/tensorflow/tools/pip_package/build_pip_package ${WHL_DIR} - -# Clean up Bazel cache when done, but leave the whl. -# This will upgrade the default Tensorflow version with the Intel MKL version -RUN pip --no-cache-dir install --upgrade ${WHL_DIR}/tensorflow-*.whl && \ - rm -rf /root/.cache - -WORKDIR /root - -#add welcome message with instructions - -RUN echo '[ ! -z "$TERM" -a -r /etc/motd ] && cat /etc/issue && cat /etc/motd' \ - >> /etc/bash.bashrc \ - ; echo "\ -||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||\n\ -| \n\ -| Docker container running Ubuntu \n\ -| with TensorFlow ${TF_BRANCH} optimized for CPU \n\ -| with Intel(R) MKL \n\ -| \n\ -||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||\n\ -\n "\ - > /etc/motd diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu index 204b5b4dba1b607fb709b7f45d145ceafc33f3e7..4f89e3f70161a769837165502ddd3fe78b38e2a3 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-gpu +++ b/tensorflow/tools/docker/Dockerfile.devel-gpu @@ -15,6 +15,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ git \ libcudnn7=7.1.4.18-1+cuda9.0 \ libcudnn7-dev=7.1.4.18-1+cuda9.0 \ + libnccl2=2.2.13-1+cuda9.0 \ + libnccl-dev=2.2.13-1+cuda9.0 \ libcurl3-dev \ libfreetype6-dev \ libhdf5-serial-dev \ @@ -33,6 +35,11 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ find /usr/local/cuda-9.0/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \ rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a +# Link NCCL libray and header where the build script expects them. +RUN mkdir /usr/local/cuda-9.0/lib && \ + ln -s /usr/lib/x86_64-linux-gnu/libnccl.so.2 /usr/local/cuda/lib/libnccl.so.2 && \ + ln -s /usr/include/nccl.h /usr/local/cuda/include/nccl.h + RUN curl -fSsL -O https://bootstrap.pypa.io/get-pip.py && \ python get-pip.py && \ rm get-pip.py @@ -42,9 +49,11 @@ RUN pip --no-cache-dir install \ h5py \ ipykernel \ jupyter \ + keras_applications==1.0.4 \ + keras_preprocessing==1.0.2 \ matplotlib \ mock \ - numpy \ + numpy==1.14.5 \ scipy \ sklearn \ pandas \ @@ -72,7 +81,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \ >>/etc/bazel.bazelrc # Install the most recent bazel release. -ENV BAZEL_VERSION 0.11.0 +ENV BAZEL_VERSION 0.15.0 WORKDIR / RUN mkdir /bazel && \ cd /bazel && \ @@ -85,16 +94,19 @@ RUN mkdir /bazel && \ # Download and build TensorFlow. WORKDIR /tensorflow -RUN git clone --branch=r1.9 --depth=1 https://github.com/tensorflow/tensorflow.git . +RUN git clone --branch=r1.10 --depth=1 https://github.com/tensorflow/tensorflow.git . # Configure the build for our CUDA configuration. ENV CI_BUILD_PYTHON python ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH ENV TF_NEED_CUDA 1 -ENV TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2,6.0,6.1 +ENV TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0 ENV TF_CUDA_VERSION=9.0 ENV TF_CUDNN_VERSION=7 +# NCCL 2.x +ENV TF_NCCL_VERSION=2 + RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 && \ LD_LIBRARY_PATH=/usr/local/cuda/lib64/stubs:${LD_LIBRARY_PATH} \ tensorflow/tools/ci_build/builds/configured GPU \ diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7 b/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7 new file mode 100644 index 0000000000000000000000000000000000000000..056b4755f4d8efbf13d124b22bc40a8a738e2755 --- /dev/null +++ b/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7 @@ -0,0 +1,117 @@ +FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04 + +LABEL maintainer="Gunhan Gulsoy " + +# It is possible to override these for releases. +ARG TF_BRANCH=master +ARG BAZEL_VERSION=0.15.0 +ARG TF_AVAILABLE_CPUS=32 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + git \ + golang \ + libcurl3-dev \ + libfreetype6-dev \ + libpng12-dev \ + libzmq3-dev \ + pkg-config \ + python-dev \ + python-pip \ + rsync \ + software-properties-common \ + unzip \ + zip \ + zlib1g-dev \ + openjdk-8-jdk \ + openjdk-8-jre-headless \ + wget \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +RUN pip --no-cache-dir install --upgrade \ + pip setuptools + +RUN pip --no-cache-dir install \ + ipykernel \ + jupyter \ + keras_applications==1.0.4 \ + keras_preprocessing==1.0.2 \ + matplotlib \ + numpy \ + scipy \ + sklearn \ + pandas \ + wheel \ + && \ + python -m ipykernel.kernelspec + +# Set up our notebook config. +COPY jupyter_notebook_config.py /root/.jupyter/ + +# Jupyter has issues with being run directly: +# https://github.com/ipython/ipython/issues/7062 +# We just add a little wrapper script. +COPY run_jupyter.sh / + +# Set up Bazel. + +# Running bazel inside a `docker build` command causes trouble, cf: +# https://github.com/bazelbuild/bazel/issues/134 +# The easiest solution is to set up a bazelrc file forcing --batch. +RUN echo "startup --batch" >>/etc/bazel.bazelrc +# Similarly, we need to workaround sandboxing issues: +# https://github.com/bazelbuild/bazel/issues/418 +RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \ + >>/etc/bazel.bazelrc +WORKDIR / +RUN mkdir /bazel && \ + cd /bazel && \ + wget --quiet https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ + wget --quiet https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \ + chmod +x bazel-*.sh && \ + ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ + rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh + +# Download and build TensorFlow. +WORKDIR / +RUN git clone https://github.com/tensorflow/tensorflow.git && \ + cd tensorflow && \ + git checkout ${TF_BRANCH} +WORKDIR /tensorflow + +# Configure the build for our CUDA configuration. +ENV CI_BUILD_PYTHON=python \ + LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:${LD_LIBRARY_PATH} \ + CUDNN_INSTALL_PATH=/usr/lib/x86_64-linux-gnu \ + PYTHON_BIN_PATH=/usr/bin/python \ + PYTHON_LIB_PATH=/usr/local/lib/python2.7/dist-packages \ + TF_NEED_CUDA=1 \ + TF_CUDA_VERSION=9.0 \ + TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2,6.0,6.1,7.0 \ + TF_CUDNN_VERSION=7 +RUN ./configure + +# Build and Install TensorFlow. +RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 && \ + LD_LIBRARY_PATH=/usr/local/cuda/lib64/stubs:${LD_LIBRARY_PATH} \ + bazel build -c opt \ + --config=cuda \ + --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" \ + --jobs=${TF_AVAILABLE_CPUS} \ + tensorflow/tools/pip_package:build_pip_package && \ + mkdir /pip_pkg && \ + bazel-bin/tensorflow/tools/pip_package/build_pip_package /pip_pkg && \ + pip --no-cache-dir install --upgrade /pip_pkg/tensorflow-*.whl && \ + rm -rf /pip_pkg && \ + rm -rf /root/.cache +# Clean up pip wheel and Bazel cache when done. + +WORKDIR /root + +# TensorBoard +EXPOSE 6006 +# IPython +EXPOSE 8888 diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl b/tensorflow/tools/docker/Dockerfile.devel-mkl new file mode 100755 index 0000000000000000000000000000000000000000..2df770e52555bf15625257e1082bb55b9bf6fe1f --- /dev/null +++ b/tensorflow/tools/docker/Dockerfile.devel-mkl @@ -0,0 +1,143 @@ +FROM ubuntu:16.04 + +LABEL maintainer="Clayne Robison " + +# These parameters can be overridden by parameterized_docker_build.sh +ARG TF_BUILD_VERSION=r1.10 +ARG PYTHON="python" +ARG PYTHON3_DEV="" +ARG WHL_DIR="/tmp/pip" +ARG PIP="pip" + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + git \ + libcurl3-dev \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng12-dev \ + libzmq3-dev \ + libssl-dev \ + pkg-config \ + rsync \ + software-properties-common \ + unzip \ + zip \ + zlib1g-dev \ + openjdk-8-jdk \ + openjdk-8-jre-headless + +#install Python 3 +RUN if [ ${PYTHON} = "python3.6" ]; then \ + curl https://www.python.org/ftp/python/3.6.5/Python-3.6.5.tar.xz -o /opt/python.tar.xz && \ + cd /opt && tar xvf python.tar.xz && \ + cd /opt/*/ && ./configure && \ + make && make install; \ + else \ + apt-get install -y --no-install-recommends \ + python-dev \ + ${PYTHON3_DEV}; \ + fi + +RUN apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +RUN curl -fSsL -O https://bootstrap.pypa.io/get-pip.py && \ + ${PYTHON} get-pip.py && \ + rm get-pip.py + +RUN ${PIP} --no-cache-dir install \ + Pillow \ + h5py \ + ipykernel \ + jupyter \ + keras_applications==1.0.4 \ + keras_preprocessing==1.0.2 \ + matplotlib \ + mock \ + numpy \ + scipy \ + sklearn \ + pandas \ + && \ + ${PYTHON} -m ipykernel.kernelspec + +RUN if [ "${PYTHON}" = "python3" ]; then \ + ln -s -f /usr/bin/python3 /usr/bin/python; \ + elif [ "${PYTHON}" = "python3.6" ]; then \ + ln -s -f /usr/local/bin/python3.6 /usr/bin/python; \ + fi + +# Set up our notebook config. +COPY jupyter_notebook_config.py /root/.jupyter/ + +# Jupyter has issues with being run directly: +# https://github.com/ipython/ipython/issues/7062 +# We just add a little wrapper script. +COPY run_jupyter.sh / + +# Set up Bazel. + +# Running bazel inside a `docker build` command causes trouble, cf: +# https://github.com/bazelbuild/bazel/issues/134 +# The easiest solution is to set up a bazelrc file forcing --batch. +RUN echo "startup --batch" >>/etc/bazel.bazelrc +# Similarly, we need to workaround sandboxing issues: +# https://github.com/bazelbuild/bazel/issues/418 +RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \ + >>/etc/bazel.bazelrc +# Install the most recent bazel release. +ENV BAZEL_VERSION 0.15.0 +WORKDIR / +RUN mkdir /bazel && \ + cd /bazel && \ + curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ + curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \ + chmod +x bazel-*.sh && \ + ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ + cd / && \ + rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh + +# Download and build TensorFlow. +WORKDIR /tensorflow + +# Download and build TensorFlow. +# Enable checking out both tags and branches +RUN export TAG_PREFIX="v" && \ + echo ${TF_BUILD_VERSION} | grep -q ^${TAG_PREFIX}; \ + if [ $? -eq 0 ]; then \ + git clone --depth=1 https://github.com/tensorflow/tensorflow.git . && \ + git fetch --tags && \ + git checkout ${TF_BUILD_VERSION}; \ + else \ + git clone --depth=1 --branch=${TF_BUILD_VERSION} https://github.com/tensorflow/tensorflow.git . ; \ + fi + +RUN yes "" | ${PYTHON} configure.py + +ENV CI_BUILD_PYTHON ${PYTHON} + +# Set bazel build parameters in .bazelrc in parameterized_docker_build.sh +# Use --copt=-march values to get optimized builds appropriate for the hardware +# platform of your choice. +# For ivy-bridge or sandy-bridge +# --copt=-march="avx" \ +# For haswell, broadwell, or skylake +# --copt=-march="avx2" \ +COPY .bazelrc /root/.bazelrc + +RUN tensorflow/tools/ci_build/builds/configured CPU \ + bazel --bazelrc=/root/.bazelrc build -c opt \ + tensorflow/tools/pip_package:build_pip_package && \ + bazel-bin/tensorflow/tools/pip_package/build_pip_package "${WHL_DIR}" && \ + ${PIP} --no-cache-dir install --upgrade "${WHL_DIR}"/tensorflow-*.whl && \ + rm -rf /root/.cache +# Clean up Bazel cache when done. + +# TensorBoard +EXPOSE 6006 +# IPython +EXPOSE 8888 + +WORKDIR /root diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod b/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod new file mode 100755 index 0000000000000000000000000000000000000000..ab2eec172892f314f52a22dda08587abeb144a39 --- /dev/null +++ b/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod @@ -0,0 +1,168 @@ +FROM ubuntu:16.04 + +LABEL maintainer="Cong Xu " + +# These parameters can be overridden by parameterized_docker_build.sh +ARG TF_BUILD_VERSION=r1.9 +ARG PYTHON="python" +ARG PYTHON3_DEV="" +ARG WHL_DIR="/tmp/pip" +ARG PIP="pip" + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + git \ + libcurl3-dev \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng12-dev \ + libzmq3-dev \ + pkg-config \ + python-dev \ + ${PYTHON3_DEV} \ + rsync \ + software-properties-common \ + unzip \ + zip \ + zlib1g-dev \ + openjdk-8-jdk \ + openjdk-8-jre-headless \ + wget \ + numactl \ + openssh-client \ + openssh-server \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +RUN curl -fSsL -O https://bootstrap.pypa.io/get-pip.py && \ + ${PYTHON} get-pip.py && \ + rm get-pip.py + +RUN ${PIP} --no-cache-dir install \ + Pillow \ + h5py \ + ipykernel \ + jupyter \ + keras_applications==1.0.4 \ + keras_preprocessing==1.0.2 \ + matplotlib \ + mock \ + numpy \ + scipy \ + sklearn \ + pandas \ + && \ + ${PYTHON} -m ipykernel.kernelspec + +RUN if [ "${PYTHON}" = "python3" ]; then \ + ln -s -f /usr/bin/python3 /usr/bin/python; \ + fi + +# Set up our notebook config. +COPY jupyter_notebook_config.py /root/.jupyter/ + +# Jupyter has issues with being run directly: +# https://github.com/ipython/ipython/issues/7062 +# We just add a little wrapper script. +COPY run_jupyter.sh / + +# Set up Bazel. + +# Running bazel inside a `docker build` command causes trouble, cf: +# https://github.com/bazelbuild/bazel/issues/134 +# The easiest solution is to set up a bazelrc file forcing --batch. +RUN echo "startup --batch" >>/etc/bazel.bazelrc +# Similarly, we need to workaround sandboxing issues: +# https://github.com/bazelbuild/bazel/issues/418 +RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \ + >>/etc/bazel.bazelrc +# Install the most recent bazel release. +ENV BAZEL_VERSION 0.15.0 +WORKDIR / +RUN mkdir /bazel && \ + cd /bazel && \ + curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ + curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \ + chmod +x bazel-*.sh && \ + ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \ + cd / && \ + rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh + +# Download and build TensorFlow. +WORKDIR /tensorflow + +# Download and build TensorFlow. +# Enable checking out both tags and branches +RUN export TAG_PREFIX="v" && \ + echo ${TF_BUILD_VERSION} | grep -q ^${TAG_PREFIX}; \ + if [ $? -eq 0 ]; then \ + git clone --depth=1 https://github.com/tensorflow/tensorflow.git . && \ + git fetch --tags && \ + git checkout ${TF_BUILD_VERSION}; \ + else \ + git clone --depth=1 --branch=${TF_BUILD_VERSION} https://github.com/tensorflow/tensorflow.git . ; \ + fi + +RUN yes "" | ${PYTHON} configure.py + +ENV CI_BUILD_PYTHON ${PYTHON} + +# Set bazel build parameters in .bazelrc in parameterized_docker_build.sh +# Use --copt=-march values to get optimized builds appropriate for the hardware +# platform of your choice. +# For ivy-bridge or sandy-bridge +# --copt=-march="avx" \ +# For haswell, broadwell, or skylake +# --copt=-march="avx2" \ +COPY .bazelrc /root/.bazelrc + +RUN tensorflow/tools/ci_build/builds/configured CPU \ + bazel --bazelrc=/root/.bazelrc build -c opt \ + tensorflow/tools/pip_package:build_pip_package && \ + bazel-bin/tensorflow/tools/pip_package/build_pip_package "${WHL_DIR}" && \ + ${PIP} --no-cache-dir install --upgrade "${WHL_DIR}"/tensorflow-*.whl && \ + rm -rf /root/.cache +# Clean up Bazel cache when done. + +WORKDIR /root + +# Install Open MPI +RUN mkdir /tmp/openmpi && \ + cd /tmp/openmpi && \ + wget https://www.open-mpi.org/software/ompi/v3.0/downloads/openmpi-3.0.0.tar.gz && \ + tar zxf openmpi-3.0.0.tar.gz && \ + cd openmpi-3.0.0 && \ + ./configure --enable-orterun-prefix-by-default && \ + make -j $(nproc) all && \ + make install && \ + ldconfig && \ + rm -rf /tmp/openmpi + +# Create a wrapper for OpenMPI to allow running as root by default +RUN mv /usr/local/bin/mpirun /usr/local/bin/mpirun.real && \ + echo '#!/bin/bash' > /usr/local/bin/mpirun && \ + echo 'mpirun.real --allow-run-as-root "$@"' >> /usr/local/bin/mpirun && \ + chmod a+x /usr/local/bin/mpirun + +# Configure OpenMPI to run good defaults: +RUN echo "btl_tcp_if_exclude = lo,docker0" >> /usr/local/etc/openmpi-mca-params.conf + +# Install Horovod +RUN ${PIP} install --no-cache-dir horovod + +# Install OpenSSH for MPI to communicate between containers +RUN mkdir -p /var/run/sshd + +# Allow OpenSSH to talk to containers without asking for confirmation +RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ + echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ + mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config + +# TensorBoard +EXPOSE 6006 +# IPython +EXPOSE 8888 + +WORKDIR /root diff --git a/tensorflow/tools/docker/Dockerfile.gpu b/tensorflow/tools/docker/Dockerfile.gpu index 9197651ff4326e9b40264183a94b82e936746010..aa0e0face171b97b9cc8d2672bc751db735695b7 100644 --- a/tensorflow/tools/docker/Dockerfile.gpu +++ b/tensorflow/tools/docker/Dockerfile.gpu @@ -13,6 +13,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ cuda-cusparse-9-0 \ curl \ libcudnn7=7.1.4.18-1+cuda9.0 \ + libnccl2=2.2.13-1+cuda9.0 \ libfreetype6-dev \ libhdf5-serial-dev \ libpng12-dev \ @@ -36,8 +37,10 @@ RUN pip --no-cache-dir install \ h5py \ ipykernel \ jupyter \ + keras_applications==1.0.4 \ + keras_preprocessing==1.0.2 \ matplotlib \ - numpy \ + numpy==1.14.5 \ pandas \ scipy \ sklearn \ diff --git a/tensorflow/tools/docker/Dockerfile.mkl b/tensorflow/tools/docker/Dockerfile.mkl new file mode 100755 index 0000000000000000000000000000000000000000..69553302d879a74c1cb1a67b46ded17f2bab553b --- /dev/null +++ b/tensorflow/tools/docker/Dockerfile.mkl @@ -0,0 +1,77 @@ +FROM ubuntu:16.04 + +LABEL maintainer="Clayne Robison " + +# This parameter MUST be set by parameterized_docker_build.sh +ARG TF_WHL_URL + +# Optional parameters +ARG TF_BUILD_VERSION=r1.9 +ARG PYTHON="python" +ARG PYTHON_DEV="python-dev" +ARG PIP="pip" + +# Pick up some TF dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng12-dev \ + libzmq3-dev \ + pkg-config \ + ${PYTHON} \ + ${PYTHON_DEV} \ + rsync \ + software-properties-common \ + unzip \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ + ${PYTHON} get-pip.py && \ + rm get-pip.py + +RUN ${PIP} --no-cache-dir install \ + Pillow \ + h5py \ + ipykernel \ + jupyter \ + keras_applications==1.0.4 \ + keras_preprocessing==1.0.2 \ + matplotlib \ + numpy \ + pandas \ + scipy \ + sklearn \ + && \ + ${PYTHON} -m ipykernel.kernelspec + +COPY ${TF_WHL_URL} / +RUN ${PIP} install --no-cache-dir --force-reinstall /${TF_WHL_URL} && \ + rm -rf /${TF_WHL_URL} + +RUN if [ "${PYTHON}" = "python3" ]; then \ + ln -s -f /usr/bin/python3 /usr/bin/python; \ + fi + +# Set up our notebook config. +COPY jupyter_notebook_config.py /root/.jupyter/ + +# Copy sample notebooks. +COPY notebooks /notebooks + +# Jupyter has issues with being run directly: +# https://github.com/ipython/ipython/issues/7062 +# We just add a little wrapper script. +COPY run_jupyter.sh / + +# TensorBoard +EXPOSE 6006 +# IPython +EXPOSE 8888 + +WORKDIR "/notebooks" + +CMD ["/run_jupyter.sh", "--allow-root"] diff --git a/tensorflow/tools/docker/Dockerfile.mkl-horovod b/tensorflow/tools/docker/Dockerfile.mkl-horovod new file mode 100755 index 0000000000000000000000000000000000000000..756716ee0e0578fcc57112fc30a3ef99049e7047 --- /dev/null +++ b/tensorflow/tools/docker/Dockerfile.mkl-horovod @@ -0,0 +1,111 @@ +FROM ubuntu:16.04 + +LABEL maintainer="Cong Xu " + +# This parameter MUST be set by parameterized_docker_build.sh +ARG TF_WHL_URL + +# Optional parameters +ARG TF_BUILD_VERSION=r1.9 +ARG PYTHON="python" +ARG PYTHON_DEV="python-dev" +ARG PIP="pip" + +# Pick up some TF dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + libfreetype6-dev \ + libhdf5-serial-dev \ + libpng12-dev \ + libzmq3-dev \ + pkg-config \ + python \ + ${PYTHON_DEV} \ + rsync \ + software-properties-common \ + unzip \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ + python get-pip.py && \ + rm get-pip.py + +RUN ${PIP} --no-cache-dir install \ + Pillow \ + h5py \ + ipykernel \ + jupyter \ + keras_applications==1.0.4 \ + keras_preprocessing==1.0.2 \ + matplotlib \ + numpy \ + pandas \ + scipy \ + sklearn \ + && \ + python -m ipykernel.kernelspec + +COPY ${TF_WHL_URL} / +RUN ${PIP} install --no-cache-dir --force-reinstall /${TF_WHL_URL} && \ + rm -rf /${TF_WHL_URL} + +RUN if [ "${PYTHON}" = "python3" ]; then \ + ln -s -f /usr/bin/python3 /usr/bin/python; \ + fi + +# Set up our notebook config. +COPY jupyter_notebook_config.py /root/.jupyter/ + +# Copy sample notebooks. +COPY notebooks /notebooks + +# Jupyter has issues with being run directly: +# https://github.com/ipython/ipython/issues/7062 +# We just add a little wrapper script. +COPY run_jupyter.sh / + +WORKDIR /root + +# Install Open MPI +RUN mkdir /tmp/openmpi && \ + cd /tmp/openmpi && \ + wget https://www.open-mpi.org/software/ompi/v3.0/downloads/openmpi-3.0.0.tar.gz && \ + tar zxf openmpi-3.0.0.tar.gz && \ + cd openmpi-3.0.0 && \ + ./configure --enable-orterun-prefix-by-default && \ + make -j $(nproc) all && \ + make install && \ + ldconfig && \ + rm -rf /tmp/openmpi + +# Create a wrapper for OpenMPI to allow running as root by default +RUN mv /usr/local/bin/mpirun /usr/local/bin/mpirun.real && \ + echo '#!/bin/bash' > /usr/local/bin/mpirun && \ + echo 'mpirun.real --allow-run-as-root "$@"' >> /usr/local/bin/mpirun && \ + chmod a+x /usr/local/bin/mpirun + +# Configure OpenMPI to run good defaults: +RUN echo "btl_tcp_if_exclude = lo,docker0" >> /usr/local/etc/openmpi-mca-params.conf + +# Install Horovod +RUN ${PIP} install --no-cache-dir horovod + +# Install OpenSSH for MPI to communicate between containers +RUN mkdir -p /var/run/sshd + +# Allow OpenSSH to talk to containers without asking for confirmation +RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ + echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ + mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config + +# TensorBoard +EXPOSE 6006 +# IPython +EXPOSE 8888 + +WORKDIR "/notebooks" + +CMD ["/run_jupyter.sh", "--allow-root"] diff --git a/tensorflow/tools/docker/README.md b/tensorflow/tools/docker/README.md index 525f2995ceecd48ee7463fc207406c5f9b25f61e..a286e8a21202162b1665112f5211809bd661f926 100644 --- a/tensorflow/tools/docker/README.md +++ b/tensorflow/tools/docker/README.md @@ -87,8 +87,10 @@ export TF_DOCKER_BUILD_IS_DEVEL=NO export TF_DOCKER_BUILD_TYPE=CPU export TF_DOCKER_BUILD_PYTHON_VERSION=PYTHON2 -export NIGHTLY_VERSION="1.head" -export TF_DOCKER_BUILD_CENTRAL_PIP=$(echo ${TF_DOCKER_BUILD_PYTHON_VERSION} | sed s^PYTHON2^http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=${TF_DOCKER_BUILD_PYTHON_VERSION},label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-${NIGHTLY_VERSION}-cp27-cp27mu-manylinux1_x86_64.whl^ | sed s^PYTHON3^http://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-${NIGHTLY_VERSION}-cp35-cp35m-manylinux1_x86_64.whl^) +pip download --no-deps tf-nightly + +export TF_DOCKER_BUILD_CENTRAL_PIP=$(ls tf_nightly*.whl) +export TF_DOCKER_BUILD_CENTRAL_PIP_IS_LOCAL=1 tensorflow/tools/docker/parameterized_docker_build.sh ``` diff --git a/tensorflow/tools/docker/notebooks/1_hello_tensorflow.ipynb b/tensorflow/tools/docker/notebooks/1_hello_tensorflow.ipynb index 0633b03259a06363d0d069eb479971f8b87f983e..8fa871ef7729a9194de282b84cdd9539c80f8555 100644 --- a/tensorflow/tools/docker/notebooks/1_hello_tensorflow.ipynb +++ b/tensorflow/tools/docker/notebooks/1_hello_tensorflow.ipynb @@ -665,7 +665,7 @@ "source": [ "## What's next?\n", "\n", - "This has been a gentle introduction to TensorFlow, focused on what TensorFlow is and the very basics of doing anything in TensorFlow. If you'd like more, the next tutorial in the series is Getting Started with TensorFlow, also available in the [notebooks directory](..)." + "This has been a gentle introduction to TensorFlow, focused on what TensorFlow is and the very basics of doing anything in TensorFlow. If you'd like more, the next tutorial in the series is Getting Started with TensorFlow, also available in the [notebooks directory](../notebooks)." ] } ], diff --git a/tensorflow/tools/docker/parameterized_docker_build.sh b/tensorflow/tools/docker/parameterized_docker_build.sh index 05de25f2cb11d76f223a31bc12329e6ab7368e8a..c9f17a8242d19d5db0d73c7f67fc028c5bc016e8 100755 --- a/tensorflow/tools/docker/parameterized_docker_build.sh +++ b/tensorflow/tools/docker/parameterized_docker_build.sh @@ -19,8 +19,8 @@ # parameterized_docker_build.sh # # The script obeys the following environment variables: -# TF_DOCKER_BUILD_TYPE: (CPU | GPU) -# CPU or GPU image +# TF_DOCKER_BUILD_TYPE: (CPU | GPU | MKL | MKL-HOROVOD) +# CPU, GPU, MKL or MKL-HOROVOD image # # TF_DOCKER_BUILD_IS_DEVEL: (NO | YES) # Is this developer image @@ -87,6 +87,15 @@ # TF_DOCKER_BUILD_OPTIONS # (Optional) # Specifies the desired build options. Defaults to OPT. +# +# TF_DOCKER_BUILD_ARGS +# (Optional) +# A list (array) of docker build args. Will be passed to docker build +# command as list of --build-arg parameters. +# +# TF_BAZEL_BUILD_OPTIONS +# (Optional) +# Bazel compiler flags to be passed to the bazelrc file # Script directory SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" @@ -116,6 +125,8 @@ echo " TF_DOCKER_BUILD_IMAGE_NAME=${TF_DOCKER_BUILD_IMAGE_NAME}" echo " TF_DOCKER_BUILD_VERSION=${TF_DOCKER_BUILD_VERSION}" echo " TF_DOCKER_BUILD_PORT=${TF_DOCKER_BUILD_PORT}" echo " TF_DOCKER_BUILD_PUSH_CMD=${TF_DOCKER_BUILD_PUSH_CMD}" +echo " TF_DOCKER_BUILD_ARGS=${TF_DOCKER_BUILD_ARGS[@]:-()}" +echo " TF_BAZEL_BUILD_OPTIONS=${TF_BAZEL_BUILD_OPTIONS}" CONTAINER_PORT=${TF_DOCKER_BUILD_PORT:-8888} @@ -149,6 +160,24 @@ fi if [[ ${TF_DOCKER_BUILD_TYPE} == "cpu" ]]; then DOCKER_BINARY="docker" +elif [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then + DOCKER_BINARY="docker" + FINAL_TAG="${FINAL_TAG}-mkl" + if [[ ${ORIG_DOCKERFILE} == *"."* ]]; then + # There is already a dot in the tag, use "-" + ORIG_DOCKERFILE="${ORIG_DOCKERFILE}-mkl" + else + ORIG_DOCKERFILE="${ORIG_DOCKERFILE}.mkl" + fi +elif [[ ${TF_DOCKER_BUILD_TYPE} == "mkl-horovod" ]]; then + DOCKER_BINARY="docker" + FINAL_TAG="${FINAL_TAG}-mkl-horovod" + if [[ ${ORIG_DOCKERFILE} == *"."* ]]; then + # There is already a dot in the tag, use "-" + ORIG_DOCKERFILE="${ORIG_DOCKERFILE}-mkl-horovod" + else + ORIG_DOCKERFILE="${ORIG_DOCKERFILE}.mkl-horovod" + fi elif [[ ${TF_DOCKER_BUILD_TYPE} == "gpu" ]]; then DOCKER_BINARY="nvidia-docker" @@ -168,6 +197,8 @@ if [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python2" ]]; then : elif [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3" ]]; then FINAL_TAG="${FINAL_TAG}-py3" +elif [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3.6" ]]; then + FINAL_TAG="${FINAL_TAG}-py3.6" else die "Unrecognized value in TF_DOCKER_BUILD_PYTHON_VERSION: "\ "${TF_DOCKER_BUILD_PYTHON_VERSION}" @@ -203,6 +234,14 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then export TF_BUILD_OPTIONS=${TF_DOCKER_BUILD_OPTIONS} export TF_BUILD_IS_PIP="PIP" + if [[ "${TF_DOCKER_BUILD_TYPE}" == "mkl" ]]; then + die "FAIL: Non-development MKL builds require a pre-built pip whl." + fi + + if [[ "${TF_DOCKER_BUILD_TYPE}" == "mkl-horovod" ]]; then + die "FAIL: Non-development MKL-HOROVOD builds require a pre-built pip whl." + fi + if [[ "${TF_DOCKER_BUILD_TYPE}" == "gpu" ]]; then export TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS=\ "${TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS} -e TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2" @@ -255,25 +294,41 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then # Use string replacement to put the correct file name into the Dockerfile PIP_WHL=$(basename "${PIP_WHL}") - # Modify the non-devel Dockerfile to point to the correct pip whl file - # location - sed -e "/# --- DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/,"\ + if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]] || \ + [[ ${TF_DOCKER_BUILD_TYPE} == "mkl-horovod" ]]; then + TF_DOCKER_BUILD_ARGS+=("--build-arg TF_WHL_URL=${PIP_WHL}" ) + cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}" + else + # Modify the non-devel Dockerfile to point to the correct pip whl file + # location + sed -e "/# --- DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/,"\ "/# --- ~ DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/c"\ "COPY ${PIP_WHL} /\n"\ "RUN pip --no-cache-dir install /${PIP_WHL}" "${ORIG_DOCKERFILE}" \ - > "${DOCKERFILE}" + > "${DOCKERFILE}" + fi echo "Using local pip wheel from: ${TF_DOCKER_BUILD_CENTRAL_PIP}" echo - else echo "Downloading pip wheel from: ${TF_DOCKER_BUILD_CENTRAL_PIP}" - echo - - # Modify the non-devel Dockerfile to point to the correct pip whl URL. - sed -e "/# --- DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/,"\ + if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]] || \ + [[ ${TF_DOCKER_BUILD_TYPE} == "mkl-horovod" ]]; then + pushd "${TMP_DIR}/" + curl -O ${TF_DOCKER_BUILD_CENTRAL_PIP} + popd + PIP_WHL_PATH=`find ${TMP_DIR} -name "*.whl"` + PIP_WHL=$(basename "${PIP_WHL_PATH}") + echo "PIP_WHL= ${PIP_WHL}" + echo + TF_DOCKER_BUILD_ARGS+=("--build-arg TF_WHL_URL=${PIP_WHL}") + cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}" + else + # Modify the non-devel Dockerfile to point to the correct pip whl URL. + sed -e "/# --- DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/,"\ "/# --- ~ DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/c"\ "RUN pip --no-cache-dir install ${TF_DOCKER_BUILD_CENTRAL_PIP}" "${ORIG_DOCKERFILE}" \ - > "${DOCKERFILE}" + > "${DOCKERFILE}" + fi fi echo "Modified Dockerfile at: ${DOCKERFILE}" @@ -281,36 +336,71 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then # Modify python/pip version if necessary. if [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3" ]]; then - if sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \ - sed -i -e 's/python-dev/python3-dev/g' "${DOCKERFILE}" && \ - sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \ - sed -i -e 's^# RUN ln -s -f /usr/bin/python3 /usr/bin/python#^RUN ln -s -f /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}" - then - echo "Modified Dockerfile for python version "\ -"${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}" + if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]] || \ + [[ ${TF_DOCKER_BUILD_TYPE} == "mkl-horovod" ]]; then + TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON=${TF_DOCKER_BUILD_PYTHON_VERSION}") + TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON_DEV=python3-dev") + TF_DOCKER_BUILD_ARGS+=("--build-arg PIP=pip3") + cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}" else - die "FAILED to modify ${DOCKERFILE} for python3" + if sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \ + sed -i -e 's/python-dev/python3-dev/g' "${DOCKERFILE}" && \ + sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \ + sed -i -e 's^# RUN ln -s -f /usr/bin/python3 /usr/bin/python#^RUN ln -s -f /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}" + then + echo "Modified Dockerfile for python version "\ + "${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}" + else + die "FAILED to modify ${DOCKERFILE} for python3" + fi fi fi -else +else # TF_DOCKER_BUILD_IS_DEVEL == 'yes' DOCKERFILE="${TMP_DIR}/Dockerfile" - # Modify the devel Dockerfile to specify the git branch - sed "s/^RUN git clone --branch=.* --depth=1/RUN git clone --branch=${TF_DOCKER_BUILD_DEVEL_BRANCH} --depth=1/" \ - "${ORIG_DOCKERFILE}" > "${DOCKERFILE}" + # Set up Dockerfile ARGS for mkl and mkl-horovod build + if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]] || \ + [[ ${TF_DOCKER_BUILD_TYPE} == "mkl-horovod" ]]; then + if [[ -z "${TF_BAZEL_BUILD_OPTIONS// }" ]]; then + TF_BAZEL_BUILD_OPTIONS=("--config=mkl --copt=-mavx --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0") + else + TF_BAZEL_BUILD_OPTIONS="${TF_BAZEL_BUILD_OPTIONS}" + fi + TF_DOCKER_BUILD_ARGS+=("--build-arg TF_BUILD_VERSION=${TF_DOCKER_BUILD_DEVEL_BRANCH}") + echo "TF_DOCKER_BUILD_ARGS=${TF_DOCKER_BUILD_ARGS[@]}" + + # Pass the build options to bazel using the user-specific .bazelrc file + echo "build ${TF_BAZEL_BUILD_OPTIONS}" >> ${TMP_DIR}/.bazelrc + cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}" + else + # Modify the devel Dockerfile to specify the git branch + sed "s/^RUN git clone --branch=.* --depth=1/RUN git clone --branch=${TF_DOCKER_BUILD_DEVEL_BRANCH} --depth=1/" \ + "${ORIG_DOCKERFILE}" > "${DOCKERFILE}" + fi # Modify python/pip version if necessary. - if [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3" ]]; then - if sed -i -e 's/python-dev/python-dev python3-dev/g' "${DOCKERFILE}" && \ - sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \ - sed -i -e 's^/tmp/pip^/tmp/pip3^g' "${DOCKERFILE}" && \ - sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \ - sed -i -e 's/ENV CI_BUILD_PYTHON python/ENV CI_BUILD_PYTHON python3/g' "${DOCKERFILE}" && \ - sed -i -e 's^# RUN ln -s -f /usr/bin/python3 /usr/bin/python#^RUN ln -s -f /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}" - then - echo "Modified Dockerfile further for python version ${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}" + if [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3" ]] || [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3.6" ]]; then + if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]] || [[ ${TF_DOCKER_BUILD_TYPE} == "mkl-horovod" ]]; then + TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON=${TF_DOCKER_BUILD_PYTHON_VERSION}") + TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON3_DEV=python3-dev") + TF_DOCKER_BUILD_ARGS+=("--build-arg WHL_DIR=/tmp/pip3") + TF_DOCKER_BUILD_ARGS+=("--build-arg PIP=pip3") + cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}" else - die "FAILED to modify ${DOCKERFILE} for python3" + if [[ "${TF_DOCKER_BUILD_TYPE}" != "mkl" ]]; then + die "Python 3.6 build only supported for MKL builds." + fi + if sed -i -e 's/python-dev/python-dev python3-dev/g' "${DOCKERFILE}" && \ + sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \ + sed -i -e 's^/tmp/pip^/tmp/pip3^g' "${DOCKERFILE}" && \ + sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \ + sed -i -e 's/ENV CI_BUILD_PYTHON python/ENV CI_BUILD_PYTHON python3/g' "${DOCKERFILE}" && \ + sed -i -e 's^# RUN ln -s -f /usr/bin/python3 /usr/bin/python#^RUN ln -s -f /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}" + then + echo "Modified Dockerfile further for python version ${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}" + else + die "FAILED to modify ${DOCKERFILE} for python3" + fi fi fi fi @@ -319,8 +409,11 @@ fi # Intermediate image name with tag IMG="${USER}/tensorflow:${FINAL_TAG}" echo "Building docker image with image name and tag: ${IMG}" +echo "TF_DOCKER_BUILD_ARGS=${TF_DOCKER_BUILD_ARGS[@]}" +CMD="${DOCKER_BINARY} build ${TF_DOCKER_BUILD_ARGS[@]} --no-cache --pull -t ${IMG} -f ${DOCKERFILE} ${TMP_DIR}" +echo "CMD=${CMD}" +${CMD} -"${DOCKER_BINARY}" build --no-cache --pull -t "${IMG}" -f "${DOCKERFILE}" "${TMP_DIR}" if [[ $? == "0" ]]; then echo "${DOCKER_BINARY} build of ${IMG} succeeded" else @@ -340,7 +433,7 @@ fi DOCKER_RUN_LOG="${TMP_DIR}/docker_run.log" echo "" echo "Running docker container from image ${IMG}..." -echo " (Log file is at: ${DOCKER_RUN_LOG}" +echo " Log file is at: ${DOCKER_RUN_LOG}" echo "" if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then @@ -386,7 +479,6 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then # Stop the running docker container sleep 1 "${DOCKER_BINARY}" stop --time=0 ${CONTAINER_ID} - fi diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index 58b5ef8345c9de83e2d50cd01fe11e11f51fe298..4f7efe193f13c5f8c7d85907186ffdb7052da5b7 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -28,6 +28,24 @@ py_test( srcs_version = "PY2AND3", deps = [ ":doc_generator_visitor", + ":generate_lib", + "//tensorflow/python:platform_test", + ], +) + +py_library( + name = "doc_controls", + srcs = ["doc_controls.py"], + srcs_version = "PY2AND3", +) + +py_test( + name = "doc_controls_test", + size = "small", + srcs = ["doc_controls_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":doc_controls", "//tensorflow/python:platform_test", ], ) @@ -37,7 +55,12 @@ py_library( srcs = ["parser.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = ["@astor_archive//:astor"], + deps = [ + ":doc_controls", + "//tensorflow/python:platform", + "//tensorflow/python:util", + "@astor_archive//:astor", + ], ) py_test( @@ -63,6 +86,7 @@ py_binary( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ + ":doc_controls", ":doc_generator_visitor", ":parser", ":pretty_docs", @@ -92,6 +116,7 @@ py_binary( deps = [ ":generate_lib", "//tensorflow:tensorflow_py", + "//tensorflow/python:util", "//tensorflow/python/debug:debug_py", ], ) @@ -100,7 +125,7 @@ py_test( name = "build_docs_test", size = "small", srcs = ["build_docs_test.py"], - data = ["//tensorflow:docs_src"], + data = ["//tensorflow/docs_src"], srcs_version = "PY2AND3", tags = [ # No reason to run sanitizers or fastbuild for this test. diff --git a/tensorflow/tools/docs/doc_controls.py b/tensorflow/tools/docs/doc_controls.py new file mode 100644 index 0000000000000000000000000000000000000000..5e526443cc977ed6f2b68b43b385f57992560d4d --- /dev/null +++ b/tensorflow/tools/docs/doc_controls.py @@ -0,0 +1,319 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Documentation control decorators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +_DO_NOT_DOC = "_tf_docs_do_not_document" + + +def do_not_generate_docs(obj): + """A decorator: Do not generate docs for this object. + + For example the following classes: + + ``` + class Parent(object): + def method1(self): + pass + def method2(self): + pass + + class Child(Parent): + def method1(self): + pass + def method2(self): + pass + ``` + + Produce the following api_docs: + + ``` + /Parent.md + # method1 + # method2 + /Child.md + # method1 + # method2 + ``` + + This decorator allows you to skip classes or methods: + + ``` + @do_not_generate_docs + class Parent(object): + def method1(self): + pass + def method2(self): + pass + + class Child(Parent): + @do_not_generate_docs + def method1(self): + pass + def method2(self): + pass + ``` + + This will only produce the following docs: + + ``` + /Child.md + # method2 + ``` + + Note: This is implemented by adding a hidden attribute on the object, so it + cannot be used on objects which do not allow new attributes to be added. So + this decorator must go *below* `@property`, `@classmethod`, + or `@staticmethod`: + + ``` + class Example(object): + @property + @do_not_generate_docs + def x(self): + return self._x + ``` + + Args: + obj: The object to hide from the generated docs. + + Returns: + obj + """ + setattr(obj, _DO_NOT_DOC, None) + return obj + + +_DO_NOT_DOC_INHERITABLE = "_tf_docs_do_not_doc_inheritable" + + +def do_not_doc_inheritable(obj): + """A decorator: Do not generate docs for this method. + + This version of the decorator is "inherited" by subclasses. No docs will be + generated for the decorated method in any subclass. Even if the sub-class + overrides the method. + + For example, to ensure that `method1` is **never documented** use this + decorator on the base-class: + + ``` + class Parent(object): + @do_not_doc_inheritable + def method1(self): + pass + def method2(self): + pass + + class Child(Parent): + def method1(self): + pass + def method2(self): + pass + ``` + This will produce the following docs: + + ``` + /Parent.md + # method2 + /Child.md + # method2 + ``` + + When generating docs for a class's arributes, the `__mro__` is searched and + the attribute will be skipped if this decorator is detected on the attribute + on any class in the `__mro__`. + + Note: This is implemented by adding a hidden attribute on the object, so it + cannot be used on objects which do not allow new attributes to be added. So + this decorator must go *below* `@property`, `@classmethod`, + or `@staticmethod`: + + ``` + class Example(object): + @property + @do_not_doc_inheritable + def x(self): + return self._x + ``` + + Args: + obj: The class-attribute to hide from the generated docs. + + Returns: + obj + """ + setattr(obj, _DO_NOT_DOC_INHERITABLE, None) + return obj + + +_FOR_SUBCLASS_IMPLEMENTERS = "_tf_docs_tools_for_subclass_implementers" + + +def for_subclass_implementers(obj): + """A decorator: Only generate docs for this method in the defining class. + + Also group this method's docs with and `@abstractmethod` in the class's docs. + + No docs will generated for this class attribute in sub-classes. + + The canonical use case for this is `tf.keras.layers.Layer.call`: It's a + public method, essential for anyone implementing a subclass, but it should + never be called directly. + + Works on method, or other class-attributes. + + When generating docs for a class's arributes, the `__mro__` is searched and + the attribute will be skipped if this decorator is detected on the attribute + on any **parent** class in the `__mro__`. + + For example: + + ``` + class Parent(object): + @for_subclass_implementers + def method1(self): + pass + def method2(self): + pass + + class Child1(Parent): + def method1(self): + pass + def method2(self): + pass + + class Child2(Parent): + def method1(self): + pass + def method2(self): + pass + ``` + + This will produce the following docs: + + ``` + /Parent.md + # method1 + # method2 + /Child1.md + # method2 + /Child2.md + # method2 + ``` + + Note: This is implemented by adding a hidden attribute on the object, so it + cannot be used on objects which do not allow new attributes to be added. So + this decorator must go *below* `@property`, `@classmethod`, + or `@staticmethod`: + + ``` + class Example(object): + @property + @for_subclass_implementers + def x(self): + return self._x + ``` + + Args: + obj: The class-attribute to hide from the generated docs. + + Returns: + obj + """ + setattr(obj, _FOR_SUBCLASS_IMPLEMENTERS, None) + return obj + + +def should_skip(obj): + """Returns true if docs generation should be skipped for this object. + + checks for the `do_not_generate_docs` or `do_not_doc_inheritable` decorators. + + Args: + obj: The object to document, or skip. + + Returns: + True if the object should be skipped + """ + # Unwrap fget if the object is a property + if isinstance(obj, property): + obj = obj.fget + + return hasattr(obj, _DO_NOT_DOC) or hasattr(obj, _DO_NOT_DOC_INHERITABLE) + + +def should_skip_class_attr(cls, name): + """Returns true if docs should be skipped for this class attribute. + + Args: + cls: The class the attribute belongs to. + name: The name of the attribute. + + Returns: + True if the attribute should be skipped. + """ + # Get the object with standard lookup, from the nearest + # defining parent. + try: + obj = getattr(cls, name) + except AttributeError: + # Avoid error caused by enum metaclasses in python3 + if name in ("name", "value"): + return True + raise + + # Unwrap fget if the object is a property + if isinstance(obj, property): + obj = obj.fget + + # Skip if the object is decorated with `do_not_generate_docs` or + # `do_not_doc_inheritable` + if should_skip(obj): + return True + + # Use __dict__ lookup to get the version defined in *this* class. + obj = cls.__dict__.get(name, None) + if isinstance(obj, property): + obj = obj.fget + if obj is not None: + # If not none, the object is defined in *this* class. + # Do not skip if decorated with `for_subclass_implementers`. + if hasattr(obj, _FOR_SUBCLASS_IMPLEMENTERS): + return False + + # for each parent class + for parent in cls.__mro__[1:]: + obj = getattr(parent, name, None) + + if obj is None: + continue + + if isinstance(obj, property): + obj = obj.fget + + # Skip if the parent's definition is decorated with `do_not_doc_inheritable` + # or `for_subclass_implementers` + if hasattr(obj, _DO_NOT_DOC_INHERITABLE): + return True + + if hasattr(obj, _FOR_SUBCLASS_IMPLEMENTERS): + return True + + # No blockng decorators --> don't skip + return False diff --git a/tensorflow/tools/docs/doc_controls_test.py b/tensorflow/tools/docs/doc_controls_test.py new file mode 100644 index 0000000000000000000000000000000000000000..410342fb69fde518178ca8708b88e8ce5ae71a10 --- /dev/null +++ b/tensorflow/tools/docs/doc_controls_test.py @@ -0,0 +1,183 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for documentation control decorators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.platform import googletest +from tensorflow.tools.docs import doc_controls + + +class DocControlsTest(googletest.TestCase): + + def test_do_not_generate_docs(self): + + @doc_controls.do_not_generate_docs + def dummy_function(): + pass + + self.assertTrue(doc_controls.should_skip(dummy_function)) + + def test_do_not_doc_on_method(self): + """The simple decorator is not aware of inheritance.""" + + class Parent(object): + + @doc_controls.do_not_generate_docs + def my_method(self): + pass + + class Child(Parent): + + def my_method(self): + pass + + class GrandChild(Child): + pass + + self.assertTrue(doc_controls.should_skip(Parent.my_method)) + self.assertFalse(doc_controls.should_skip(Child.my_method)) + self.assertFalse(doc_controls.should_skip(GrandChild.my_method)) + + self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method')) + self.assertFalse(doc_controls.should_skip_class_attr(Child, 'my_method')) + self.assertFalse( + doc_controls.should_skip_class_attr(GrandChild, 'my_method')) + + def test_do_not_doc_inheritable(self): + + class Parent(object): + + @doc_controls.do_not_doc_inheritable + def my_method(self): + pass + + class Child(Parent): + + def my_method(self): + pass + + class GrandChild(Child): + pass + + self.assertTrue(doc_controls.should_skip(Parent.my_method)) + self.assertFalse(doc_controls.should_skip(Child.my_method)) + self.assertFalse(doc_controls.should_skip(GrandChild.my_method)) + + self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method')) + self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method')) + self.assertTrue( + doc_controls.should_skip_class_attr(GrandChild, 'my_method')) + + def test_do_not_doc_inheritable_property(self): + + class Parent(object): + + @property + @doc_controls.do_not_doc_inheritable + def my_method(self): + pass + + class Child(Parent): + + @property + def my_method(self): + pass + + class GrandChild(Child): + pass + + self.assertTrue(doc_controls.should_skip(Parent.my_method)) + self.assertFalse(doc_controls.should_skip(Child.my_method)) + self.assertFalse(doc_controls.should_skip(GrandChild.my_method)) + + self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method')) + self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method')) + self.assertTrue( + doc_controls.should_skip_class_attr(GrandChild, 'my_method')) + + def test_do_not_doc_inheritable_staticmethod(self): + + class GrandParent(object): + + def my_method(self): + pass + + class Parent(GrandParent): + + @staticmethod + @doc_controls.do_not_doc_inheritable + def my_method(): + pass + + class Child(Parent): + + @staticmethod + def my_method(): + pass + + class GrandChild(Child): + pass + + self.assertFalse(doc_controls.should_skip(GrandParent.my_method)) + self.assertTrue(doc_controls.should_skip(Parent.my_method)) + self.assertFalse(doc_controls.should_skip(Child.my_method)) + self.assertFalse(doc_controls.should_skip(GrandChild.my_method)) + + self.assertFalse( + doc_controls.should_skip_class_attr(GrandParent, 'my_method')) + self.assertTrue(doc_controls.should_skip_class_attr(Parent, 'my_method')) + self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method')) + self.assertTrue( + doc_controls.should_skip_class_attr(GrandChild, 'my_method')) + + def testfor_subclass_implementers(self): + + class GrandParent(object): + + def my_method(self): + pass + + class Parent(GrandParent): + + @doc_controls.for_subclass_implementers + def my_method(self): + pass + + class Child(Parent): + pass + + class GrandChild(Child): + + def my_method(self): + pass + + class Grand2Child(Child): + pass + + self.assertFalse( + doc_controls.should_skip_class_attr(GrandParent, 'my_method')) + self.assertFalse(doc_controls.should_skip_class_attr(Parent, 'my_method')) + self.assertTrue(doc_controls.should_skip_class_attr(Child, 'my_method')) + self.assertTrue( + doc_controls.should_skip_class_attr(GrandChild, 'my_method')) + self.assertTrue( + doc_controls.should_skip_class_attr(Grand2Child, 'my_method')) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/tools/docs/doc_generator_visitor.py b/tensorflow/tools/docs/doc_generator_visitor.py index 259a4694fdcc0048a25d9facf2d45eaa86d6daaa..a66f3e449377fef3d4c7bf4e0b8810cd6111eb85 100644 --- a/tensorflow/tools/docs/doc_generator_visitor.py +++ b/tensorflow/tools/docs/doc_generator_visitor.py @@ -20,6 +20,7 @@ from __future__ import print_function import six +from tensorflow.python.util import tf_export from tensorflow.python.util import tf_inspect @@ -158,6 +159,55 @@ class DocGeneratorVisitor(object): self._index[full_name] = child self._tree[parent_name].append(name) + def _score_name(self, name): + """Return a tuple of scores indicating how to sort for the best name. + + This function is meant to be used as the `key` to the `sorted` function. + + This sorting in order: + Prefers names refering to the defining class, over a subclass. + Prefers names that are not in "contrib". + prefers submodules to the root namespace. + Prefers short names `tf.thing` over `tf.a.b.c.thing` + Sorts lexicographically on name parts. + + Args: + name: the full name to score, for example `tf.estimator.Estimator` + + Returns: + A tuple of scores. When sorted the preferred name will have the lowest + value. + """ + parts = name.split('.') + short_name = parts[-1] + + container = self._index['.'.join(parts[:-1])] + + defining_class_score = 1 + if tf_inspect.isclass(container): + if short_name in container.__dict__: + # prefer the defining class + defining_class_score = -1 + + contrib_score = -1 + if 'contrib' in parts: + contrib_score = 1 + + while parts: + parts.pop() + container = self._index['.'.join(parts)] + if tf_inspect.ismodule(container): + break + module_length = len(parts) + if len(parts) == 2: + # `tf.submodule.thing` is better than `tf.thing` + module_length_score = -1 + else: + # shorter is better + module_length_score = module_length + + return (defining_class_score, contrib_score, module_length_score, name) + def _maybe_find_duplicates(self): """Compute data structures containing information about duplicates. @@ -191,7 +241,7 @@ class DocGeneratorVisitor(object): if (py_object is not None and not isinstance(py_object, six.integer_types + six.string_types + (six.binary_type, six.text_type, float, complex, bool)) - and py_object is not ()): + and py_object is not ()): # pylint: disable=literal-comparison object_id = id(py_object) if object_id in reverse_index: master_name = reverse_index[object_id] @@ -201,7 +251,6 @@ class DocGeneratorVisitor(object): raw_duplicates[master_name] = [master_name, full_name] else: reverse_index[object_id] = full_name - # Decide on master names, rewire duplicates and make a duplicate_of map # mapping all non-master duplicates to the master name. The master symbol # does not have an entry in this map. @@ -211,10 +260,15 @@ class DocGeneratorVisitor(object): duplicates = {} for names in raw_duplicates.values(): names = sorted(names) - - # Choose the lexicographically first name with the minimum number of - # submodules. This will prefer highest level namespace for any symbol. - master_name = min(names, key=lambda name: name.count('.')) + master_name = ( + tf_export.get_canonical_name_for_symbol(self._index[names[0]]) + if names else None) + if master_name: + master_name = 'tf.%s' % master_name + else: + # Choose the master name with a lexical sort on the tuples returned by + # by _score_name. + master_name = min(names, key=self._score_name) duplicates[master_name] = names for name in names: diff --git a/tensorflow/tools/docs/doc_generator_visitor_test.py b/tensorflow/tools/docs/doc_generator_visitor_test.py index cf5be45f40e3a2f727c3961c2896754cf8f269f2..1c2635d4a8c0acbe25502e3b9870420a38b7e22e 100644 --- a/tensorflow/tools/docs/doc_generator_visitor_test.py +++ b/tensorflow/tools/docs/doc_generator_visitor_test.py @@ -18,8 +18,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import types + from tensorflow.python.platform import googletest from tensorflow.tools.docs import doc_generator_visitor +from tensorflow.tools.docs import generate_lib + + +class NoDunderVisitor(doc_generator_visitor.DocGeneratorVisitor): + + def __call__(self, parent_name, parent, children): + """Drop all the dunder methods to make testing easier.""" + children = [ + (name, obj) for (name, obj) in children if not name.startswith('_') + ] + super(NoDunderVisitor, self).__call__(parent_name, parent, children) class DocGeneratorVisitorTest(googletest.TestCase): @@ -57,52 +70,184 @@ class DocGeneratorVisitorTest(googletest.TestCase): with self.assertRaises(RuntimeError): visitor('non_class_or_module', 'non_class_or_module_object', []) - def test_duplicates(self): - visitor = doc_generator_visitor.DocGeneratorVisitor() - visitor( - 'submodule.DocGeneratorVisitor', - doc_generator_visitor.DocGeneratorVisitor, - [('index', doc_generator_visitor.DocGeneratorVisitor.index), - ('index2', doc_generator_visitor.DocGeneratorVisitor.index)]) - visitor( - 'submodule2.DocGeneratorVisitor', - doc_generator_visitor.DocGeneratorVisitor, - [('index', doc_generator_visitor.DocGeneratorVisitor.index), - ('index2', doc_generator_visitor.DocGeneratorVisitor.index)]) - visitor( - 'DocGeneratorVisitor2', - doc_generator_visitor.DocGeneratorVisitor, - [('index', doc_generator_visitor.DocGeneratorVisitor.index), - ('index2', doc_generator_visitor.DocGeneratorVisitor.index)]) - - # The shorter path should be master, or if equal, the lexicographically - # first will be. - self.assertEqual( - {'DocGeneratorVisitor2': sorted(['submodule.DocGeneratorVisitor', - 'submodule2.DocGeneratorVisitor', - 'DocGeneratorVisitor2']), - 'DocGeneratorVisitor2.index': sorted([ - 'submodule.DocGeneratorVisitor.index', - 'submodule.DocGeneratorVisitor.index2', - 'submodule2.DocGeneratorVisitor.index', - 'submodule2.DocGeneratorVisitor.index2', - 'DocGeneratorVisitor2.index', - 'DocGeneratorVisitor2.index2' - ]), - }, visitor.duplicates) - self.assertEqual({ - 'submodule.DocGeneratorVisitor': 'DocGeneratorVisitor2', - 'submodule.DocGeneratorVisitor.index': 'DocGeneratorVisitor2.index', - 'submodule.DocGeneratorVisitor.index2': 'DocGeneratorVisitor2.index', - 'submodule2.DocGeneratorVisitor': 'DocGeneratorVisitor2', - 'submodule2.DocGeneratorVisitor.index': 'DocGeneratorVisitor2.index', - 'submodule2.DocGeneratorVisitor.index2': 'DocGeneratorVisitor2.index', - 'DocGeneratorVisitor2.index2': 'DocGeneratorVisitor2.index' + def test_duplicates_module_class_depth(self): + + class Parent(object): + + class Nested(object): + pass + + tf = types.ModuleType('tf') + tf.Parent = Parent + tf.submodule = types.ModuleType('submodule') + tf.submodule.Parent = Parent + + visitor = generate_lib.extract( + [('tf', tf)], + private_map={}, + do_not_descend_map={}, + visitor_cls=NoDunderVisitor) + + self.assertEqual({ + 'tf.submodule.Parent': + sorted([ + 'tf.Parent', + 'tf.submodule.Parent', + ]), + 'tf.submodule.Parent.Nested': + sorted([ + 'tf.Parent.Nested', + 'tf.submodule.Parent.Nested', + ]), + }, visitor.duplicates) + + self.assertEqual({ + 'tf.Parent.Nested': 'tf.submodule.Parent.Nested', + 'tf.Parent': 'tf.submodule.Parent', + }, visitor.duplicate_of) + + self.assertEqual({ + id(Parent): 'tf.submodule.Parent', + id(Parent.Nested): 'tf.submodule.Parent.Nested', + id(tf): 'tf', + id(tf.submodule): 'tf.submodule', + }, visitor.reverse_index) + + def test_duplicates_contrib(self): + + class Parent(object): + pass + + tf = types.ModuleType('tf') + tf.contrib = types.ModuleType('contrib') + tf.submodule = types.ModuleType('submodule') + tf.contrib.Parent = Parent + tf.submodule.Parent = Parent + + visitor = generate_lib.extract( + [('tf', tf)], + private_map={}, + do_not_descend_map={}, + visitor_cls=NoDunderVisitor) + + self.assertEqual({ + 'tf.submodule.Parent': + sorted(['tf.contrib.Parent', 'tf.submodule.Parent']), + }, visitor.duplicates) + + self.assertEqual({ + 'tf.contrib.Parent': 'tf.submodule.Parent', + }, visitor.duplicate_of) + + self.assertEqual({ + id(tf): 'tf', + id(tf.submodule): 'tf.submodule', + id(Parent): 'tf.submodule.Parent', + id(tf.contrib): 'tf.contrib', + }, visitor.reverse_index) + + def test_duplicates_defining_class(self): + + class Parent(object): + obj1 = object() + + class Child(Parent): + pass + + tf = types.ModuleType('tf') + tf.Parent = Parent + tf.Child = Child + + visitor = generate_lib.extract( + [('tf', tf)], + private_map={}, + do_not_descend_map={}, + visitor_cls=NoDunderVisitor) + + self.assertEqual({ + 'tf.Parent.obj1': sorted([ + 'tf.Parent.obj1', + 'tf.Child.obj1', + ]), + }, visitor.duplicates) + + self.assertEqual({ + 'tf.Child.obj1': 'tf.Parent.obj1', }, visitor.duplicate_of) + + self.assertEqual({ + id(tf): 'tf', + id(Parent): 'tf.Parent', + id(Child): 'tf.Child', + id(Parent.obj1): 'tf.Parent.obj1', + }, visitor.reverse_index) + + def test_duplicates_module_depth(self): + + class Parent(object): + pass + + tf = types.ModuleType('tf') + tf.submodule = types.ModuleType('submodule') + tf.submodule.submodule2 = types.ModuleType('submodule2') + tf.Parent = Parent + tf.submodule.submodule2.Parent = Parent + + visitor = generate_lib.extract( + [('tf', tf)], + private_map={}, + do_not_descend_map={}, + visitor_cls=NoDunderVisitor) + + self.assertEqual({ + 'tf.Parent': sorted(['tf.Parent', 'tf.submodule.submodule2.Parent']), + }, visitor.duplicates) + + self.assertEqual({ + 'tf.submodule.submodule2.Parent': 'tf.Parent' + }, visitor.duplicate_of) + + self.assertEqual({ + id(tf): 'tf', + id(tf.submodule): 'tf.submodule', + id(tf.submodule.submodule2): 'tf.submodule.submodule2', + id(Parent): 'tf.Parent', + }, visitor.reverse_index) + + def test_duplicates_name(self): + + class Parent(object): + obj1 = object() + + Parent.obj2 = Parent.obj1 + + tf = types.ModuleType('tf') + tf.submodule = types.ModuleType('submodule') + tf.submodule.Parent = Parent + + visitor = generate_lib.extract( + [('tf', tf)], + private_map={}, + do_not_descend_map={}, + visitor_cls=NoDunderVisitor) + + self.assertEqual({ + 'tf.submodule.Parent.obj1': + sorted([ + 'tf.submodule.Parent.obj1', + 'tf.submodule.Parent.obj2', + ]), + }, visitor.duplicates) + + self.assertEqual({ + 'tf.submodule.Parent.obj2': 'tf.submodule.Parent.obj1', + }, visitor.duplicate_of) + self.assertEqual({ - id(doc_generator_visitor.DocGeneratorVisitor): 'DocGeneratorVisitor2', - id(doc_generator_visitor.DocGeneratorVisitor.index): - 'DocGeneratorVisitor2.index', + id(tf): 'tf', + id(tf.submodule): 'tf.submodule', + id(Parent): 'tf.submodule.Parent', + id(Parent.obj1): 'tf.submodule.Parent.obj1', }, visitor.reverse_index) if __name__ == '__main__': diff --git a/tensorflow/tools/docs/generate.py b/tensorflow/tools/docs/generate.py index fc93085e3e0316cf274f4d9b325d6af0ea3a2f83..f96887e4c70b0580fd8a799c8f1d602491a66ef2 100644 --- a/tensorflow/tools/docs/generate.py +++ b/tensorflow/tools/docs/generate.py @@ -31,6 +31,11 @@ if __name__ == '__main__': doc_generator = generate_lib.DocGenerator() doc_generator.add_output_dir_argument() doc_generator.add_src_dir_argument() + doc_generator.argument_parser.add_argument( + '--site_api_path', + type=str, default='api_docs/python', + help='The path from the site-root to api_docs' + 'directory for this project') # This doc generator works on the TensorFlow codebase. Since this script lives # at tensorflow/tools/docs, and all code is defined somewhere inside diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py index 853ec6194f8327f13b3eb6ac7792511c9c4494cd..b9c418db1e821009023181c21fab048ec163608f 100644 --- a/tensorflow/tools/docs/generate_lib.py +++ b/tensorflow/tools/docs/generate_lib.py @@ -21,12 +21,14 @@ from __future__ import print_function import argparse import fnmatch import os +import shutil import six from tensorflow.python.util import tf_inspect from tensorflow.tools.common import public_api from tensorflow.tools.common import traverse +from tensorflow.tools.docs import doc_controls from tensorflow.tools.docs import doc_generator_visitor from tensorflow.tools.docs import parser from tensorflow.tools.docs import pretty_docs @@ -54,7 +56,8 @@ def write_docs(output_dir, parser_config, yaml_toc, root_title='TensorFlow', - search_hints=True): + search_hints=True, + site_api_path=None): """Write previously extracted docs to disk. Write a docs page for each symbol included in the indices of parser_config to @@ -72,6 +75,8 @@ def write_docs(output_dir, root_title: The title name for the root level index.md. search_hints: (bool) include meta-data search hints at the top of each output file. + site_api_path: Used to write the api-duplicates _redirects.yaml file. if + None (the default) the file is not generated. Raises: ValueError: if `output_dir` is not an absolute path @@ -81,12 +86,8 @@ def write_docs(output_dir, raise ValueError("'output_dir' must be an absolute path.\n" " output_dir='%s'" % output_dir) - try: - if not os.path.exists(output_dir): - os.makedirs(output_dir) - except OSError as e: - print('Creating output dir "%s" failed: %s' % (output_dir, e)) - raise + if not os.path.exists(output_dir): + os.makedirs(output_dir) # These dictionaries are used for table-of-contents generation below # They will contain, after the for-loop below:: @@ -95,6 +96,9 @@ def write_docs(output_dir, # - symbol name(string):pathname (string) symbol_to_file = {} + # Collect redirects for an api _redirects.yaml file. + redirects = ['redirects:\n'] + # Parse and write Markdown pages, resolving cross-links (@{symbol}). for full_name, py_object in six.iteritems(parser_config.index): parser_config.reference_resolver.current_doc_full_name = full_name @@ -107,6 +111,9 @@ def write_docs(output_dir, _is_free_function(py_object, full_name, parser_config.index)): continue + if doc_controls.should_skip(py_object): + continue + sitepath = os.path.join('api_docs/python', parser.documentation_path(full_name)[:-3]) @@ -129,8 +136,6 @@ def write_docs(output_dir, module_children.setdefault(subname, []).append(full_name) break - print('Writing docs for %s (%r).' % (full_name, py_object)) - # Generate docs for `py_object`, resolving references. page_info = parser.docs_for_object(full_name, py_object, parser_config) @@ -151,10 +156,28 @@ def write_docs(output_dir, text = text.encode('utf-8') with open(path, 'wb') as f: f.write(text) - except OSError as e: - print('Cannot write documentation for %s to %s: %s' % (full_name, - directory, e)) - raise + except OSError: + raise OSError( + 'Cannot write documentation for %s to %s' % (full_name, directory)) + + if site_api_path: + duplicates = parser_config.duplicates.get(full_name, []) + if not duplicates: + continue + + duplicates = [item for item in duplicates if item != full_name] + template = ('- from: /{}\n' + ' to: /{}\n') + for dup in duplicates: + from_path = os.path.join(site_api_path, dup.replace('.', '/')) + to_path = os.path.join(site_api_path, full_name.replace('.', '/')) + redirects.append( + template.format(from_path, to_path)) + + if site_api_path: + api_redirects_path = os.path.join(output_dir, '_redirects.yaml') + with open(api_redirects_path, 'w') as redirect_file: + redirect_file.write(''.join(redirects)) if yaml_toc: # Generate table of contents @@ -216,12 +239,16 @@ def add_dict_to_dict(add_from, add_to): # Exclude some libraries in contrib from the documentation altogether. def _get_default_private_map(): - return {'tf.test': ['mock']} + return { + 'tf.contrib.autograph': ['utils', 'operators'], + 'tf.test': ['mock'], + 'tf.compat': ['v1', 'v2'], + } # Exclude members of some libraries. def _get_default_do_not_descend_map(): - # TODO(wicke): Shrink this list once the modules get sealed. + # TODO(markdaoust): Use docs_controls decorators, locally, instead. return { 'tf': ['cli', 'lib', 'wrappers'], 'tf.contrib': [ @@ -265,10 +292,13 @@ def _get_default_do_not_descend_map(): } -def extract(py_modules, private_map, do_not_descend_map): +def extract(py_modules, + private_map, + do_not_descend_map, + visitor_cls=doc_generator_visitor.DocGeneratorVisitor): """Extract docs from tf namespace and write them to disk.""" # Traverse the first module. - visitor = doc_generator_visitor.DocGeneratorVisitor(py_modules[0][0]) + visitor = visitor_cls(py_modules[0][0]) api_visitor = public_api.PublicAPIVisitor(visitor) api_visitor.set_root_name(py_modules[0][0]) add_dict_to_dict(private_map, api_visitor.private_map) @@ -394,16 +424,40 @@ def _build_guide_index(guide_src_dir): class _UpdateTags(py_guide_parser.PyGuideParser): - """Rewrites a Python guide so that each section has an explicit tag.""" + """Rewrites a Python guide so that each section has an explicit id tag. + + "section" here refers to blocks delimited by second level headings. + """ def process_section(self, line_number, section_title, tag): self.replace_line(line_number, '

%s

' % (tag, section_title)) +def update_id_tags_inplace(src_dir): + """Set explicit ids on all second-level headings to ensure back-links work. + + Args: + src_dir: The directory of md-files to convert (inplace). + """ + tag_updater = _UpdateTags() + + for dirpath, _, filenames in os.walk(src_dir): + for base_name in filenames: + if not base_name.endswith('.md'): + continue + full_path = os.path.join(src_dir, dirpath, base_name) + + # Tag updater loads the file, makes the replacements, and returns the + # modified file contents + content = tag_updater.process(full_path) + with open(full_path, 'w') as f: + f.write(content) + + EXCLUDED = set(['__init__.py', 'OWNERS', 'README.txt']) -def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'): +def replace_refs(src_dir, output_dir, reference_resolver, file_pattern='*.md'): """Fix @{} references in all files under `src_dir` matching `file_pattern`. A matching directory structure, with the modified files is @@ -424,7 +478,6 @@ def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'): using fnmatch. Non-matching files are copied unchanged. """ # Iterate through all the source files and process them. - tag_updater = _UpdateTags() for dirpath, _, filenames in os.walk(src_dir): # How to get from `dirpath` to api_docs/python/ relative_path_to_root = os.path.relpath( @@ -433,41 +486,32 @@ def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'): # Make the directory under output_dir. new_dir = os.path.join(output_dir, os.path.relpath(path=dirpath, start=src_dir)) - try: - if not os.path.exists(new_dir): - os.makedirs(new_dir) - except OSError as e: - print('Creating output dir "%s" failed: %s' % (new_dir, e)) - raise + if not os.path.exists(new_dir): + os.makedirs(new_dir) for base_name in filenames: if base_name in EXCLUDED: - print('Skipping excluded file %s...' % base_name) continue full_in_path = os.path.join(dirpath, base_name) + # Set the `current_doc_full_name` so bad files can be reported on errors. reference_resolver.current_doc_full_name = full_in_path suffix = os.path.relpath(path=full_in_path, start=src_dir) full_out_path = os.path.join(output_dir, suffix) + # Copy files that do not match the file_pattern, unmodified. if not fnmatch.fnmatch(base_name, file_pattern): - print('Copying un-matched file %s...' % suffix) - open(full_out_path, 'wb').write(open(full_in_path, 'rb').read()) + shutil.copyfile(full_in_path, full_out_path) continue - if dirpath.endswith('/api_guides/python'): - print('Processing Python guide %s...' % base_name) - content = tag_updater.process(full_in_path) - else: - print('Processing doc %s...' % suffix) - content = open(full_in_path, 'rb').read().decode('utf-8') + + with open(full_in_path, 'rb') as f: + content = f.read().decode('utf-8') content = reference_resolver.replace_references(content, relative_path_to_root) with open(full_out_path, 'wb') as f: f.write(content.encode('utf-8')) - print('Done.') - class DocGenerator(object): """Main entry point for generating docs.""" @@ -554,15 +598,43 @@ class DocGenerator(object): self._do_not_descend_map) def build(self, flags): - """Actually build the docs.""" + """Build all the docs. + + This produces two outputs + + python api docs: + + * generated from modules set with `set_py_modules`. + * written to '{FLAGS.output_dir}/api_docs/python/' + + non-api docs: + + * Everything in '{FLAGS.src_dir}' is copied to '{FLAGS.output_dir}'. + * '@{}' references in '.md' files are replaced with links. + * '.md' files under 'api_guides/python' have explicit ids set for their + second level headings. + + Args: + flags: + * src_dir: Where to fetch the non-api-docs. + * base_dir: Base of the docs directory (Used to build correct + relative links). + * output_dir: Where to write the resulting docs. + + Returns: + The number of errors encountered while processing. + """ + # Extract the python api from the _py_modules doc_index = build_doc_index(flags.src_dir) visitor = self.run_extraction() reference_resolver = self.make_reference_resolver(visitor, doc_index) + # Build the guide_index for the api_docs back links. root_title = getattr(flags, 'root_title', 'TensorFlow') guide_index = _build_guide_index( os.path.join(flags.src_dir, 'api_guides/python')) + # Write the api docs. parser_config = self.make_parser_config(visitor, reference_resolver, guide_index, flags.base_dir) output_dir = os.path.join(flags.output_dir, 'api_docs/python') @@ -572,9 +644,18 @@ class DocGenerator(object): parser_config, yaml_toc=self.yaml_toc, root_title=root_title, - search_hints=getattr(flags, 'search_hints', True)) - _other_docs(flags.src_dir, flags.output_dir, reference_resolver) - + search_hints=getattr(flags, 'search_hints', True), + site_api_path=getattr(flags, 'site_api_path', None)) + + # Replace all the @{} references in files under `FLAGS.src_dir` + replace_refs(flags.src_dir, flags.output_dir, reference_resolver, '*.md') + # Fix the tags in the guide dir. + guide_dir = os.path.join(flags.output_dir, 'api_guides/python') + if os.path.exists(guide_dir): + update_id_tags_inplace(guide_dir) + + # Report all errors found by the reference resolver, and return the error + # code. parser_config.reference_resolver.log_errors() return parser_config.reference_resolver.num_errors() diff --git a/tensorflow/tools/docs/generate_lib_test.py b/tensorflow/tools/docs/generate_lib_test.py index ea6d28a02b1f3c07fe8783fd59e345dade1fc804..de18b1325454ce4c1c02bb943f7443c3e1876d5f 100644 --- a/tensorflow/tools/docs/generate_lib_test.py +++ b/tensorflow/tools/docs/generate_lib_test.py @@ -51,7 +51,9 @@ class DummyVisitor(object): class GenerateTest(googletest.TestCase): - def test_write(self): + def get_test_objects(self): + # These are all mutable objects, so rebuild them for each test. + # Don't cache the objects. module = sys.modules[__name__] index = { @@ -98,9 +100,25 @@ class GenerateTest(googletest.TestCase): guide_index={}, base_dir=base_dir) + return reference_resolver, parser_config + + def test_write(self): + _, parser_config = self.get_test_objects() + output_dir = googletest.GetTempDir() - generate_lib.write_docs(output_dir, parser_config, yaml_toc=True) + generate_lib.write_docs(output_dir, parser_config, yaml_toc=True, + site_api_path='api_docs/python') + + # Check redirects + redirects_file = os.path.join(output_dir, '_redirects.yaml') + self.assertTrue(os.path.exists(redirects_file)) + with open(redirects_file) as f: + redirects = f.read() + self.assertEqual(redirects.split(), [ + 'redirects:', '-', 'from:', '/api_docs/python/tf/test_function', 'to:', + '/api_docs/python/tf/TestModule/test_function' + ]) # Make sure that the right files are written to disk. self.assertTrue(os.path.exists(os.path.join(output_dir, 'index.md'))) @@ -127,6 +145,107 @@ class GenerateTest(googletest.TestCase): os.path.exists( os.path.join(output_dir, 'tf/TestModule/test_function.md'))) + def test_update_id_tags_inplace(self): + test_dir = googletest.GetTempDir() + test_sub_dir = os.path.join(test_dir, 'a/b') + os.makedirs(test_sub_dir) + + test_path1 = os.path.join(test_dir, 'file1.md') + test_path2 = os.path.join(test_sub_dir, 'file2.md') + test_path3 = os.path.join(test_sub_dir, 'file3.notmd') + + with open(test_path1, 'w') as f: + f.write('## abc&123') + + with open(test_path2, 'w') as f: + f.write('# A Level 1 Heading\n') + f.write('## A Level 2 Heading') + + with open(test_path3, 'w') as f: + f.write("## don\'t change this") + + generate_lib.update_id_tags_inplace(test_dir) + + with open(test_path1) as f: + content = f.read() + + self.assertEqual(content, '

abc&123

') + + with open(test_path2) as f: + content = f.read() + + self.assertEqual( + content, '# A Level 1 Heading\n' + '

A Level 2 Heading

') + + with open(test_path3) as f: + content = f.read() + + self.assertEqual(content, "## don\'t change this") + + def test_replace_refes(self): + test_dir = googletest.GetTempDir() + test_in_dir = os.path.join(test_dir, 'in') + test_in_dir_a = os.path.join(test_dir, 'in/a') + test_in_dir_b = os.path.join(test_dir, 'in/b') + os.makedirs(test_in_dir) + os.makedirs(test_in_dir_a) + os.makedirs(test_in_dir_b) + + test_out_dir = os.path.join(test_dir, 'out') + os.makedirs(test_out_dir) + + test_path1 = os.path.join(test_in_dir_a, 'file1.md') + test_path2 = os.path.join(test_in_dir_b, 'file2.md') + test_path3 = os.path.join(test_in_dir_b, 'file3.notmd') + test_path4 = os.path.join(test_in_dir_b, 'OWNERS') + + with open(test_path1, 'w') as f: + f.write('Use `tf.test_function` to test things.') + + with open(test_path2, 'w') as f: + f.write('Use @{tf.TestModule.TestClass.ChildClass} to test things.\n' + "`tf.whatever` doesn't exist") + + with open(test_path3, 'w') as f: + file3_content = ( + 'Not a .md file. Should be copied unchanged:' + '@{tf.TestModule.TestClass.ChildClass}, `tf.test_function`') + f.write(file3_content) + + with open(test_path4, 'w') as f: + f.write('') + + reference_resolver, _ = self.get_test_objects() + generate_lib.replace_refs(test_in_dir, test_out_dir, reference_resolver, + '*.md') + + with open(os.path.join(test_out_dir, 'a/file1.md')) as f: + content = f.read() + self.assertEqual( + content, + 'Use ' + 'tf.test_function to test things.') + + with open(os.path.join(test_out_dir, 'b/file2.md')) as f: + content = f.read() + self.assertEqual( + content, + 'Use ' + '' + 'tf.TestModule.TestClass.ChildClass ' + 'to test things.\n' + '`tf.whatever` doesn\'t exist') + + with open(os.path.join(test_out_dir, 'b/file3.notmd')) as f: + content = f.read() + self.assertEqual(content, file3_content) + + with self.assertRaises(IOError): + # This should fail. The OWNERS file should not be copied + with open(os.path.join(test_out_dir, 'b/OWNERS')) as f: + content = f.read() + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py index 50c90527413d0904c78dab199a68678f6cc91845..801c8bcb4aa97925fdf17426a40b6be99832ba1a 100644 --- a/tensorflow/tools/docs/parser.py +++ b/tensorflow/tools/docs/parser.py @@ -25,13 +25,14 @@ import itertools import json import os import re -import sys import astor import six from google.protobuf.message import Message as ProtoMessage +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_inspect +from tensorflow.tools.docs import doc_controls # A regular expression capturing a python identifier. @@ -53,7 +54,7 @@ class _Errors(object): template = 'ERROR:\n output file name: %s\n %s\n\n' for full_name, message in self._errors: - print(template % (full_name, message), file=sys.stderr) + logging.warn(template, full_name, message) def append(self, full_name, message): """Add an error to the collection. @@ -761,8 +762,9 @@ def _generate_signature(func, reverse_index): lookup_text = public_name + default_text[len(internal_name):] break if default_text is lookup_text: - print('WARNING: Using default arg, failed lookup: %s, repr: %r' % - (default_text, default)) + logging.warn( + 'WARNING: Using default arg, failed lookup: %s, repr: %r', + default_text, default) else: default_text = lookup_text else: @@ -1165,7 +1167,7 @@ class _ClassPageInfo(object): if short_name in [ '__class__', '__base__', '__weakref__', '__doc__', '__module__', '__dict__', '__abstractmethods__', '__slots__', '__getnewargs__', - '__str__', '__repr__', '__hash__' + '__str__', '__repr__', '__hash__', '__reduce__' ]: continue @@ -1174,15 +1176,18 @@ class _ClassPageInfo(object): # Don't document anything that is defined in object or by protobuf. defining_class = _get_defining_class(py_class, short_name) - if (defining_class is object or - defining_class is type or defining_class is tuple or - defining_class is BaseException or defining_class is Exception or - # The following condition excludes most protobuf-defined symbols. - defining_class and defining_class.__name__ in ['CMessage', 'Message', - 'MessageMeta']): + if defining_class in [object, type, tuple, BaseException, Exception]: + continue + + # The following condition excludes most protobuf-defined symbols. + if (defining_class and + defining_class.__name__ in ['CMessage', 'Message', 'MessageMeta']): continue # TODO(markdaoust): Add a note in child docs showing the defining class. + if doc_controls.should_skip_class_attr(py_class, short_name): + continue + child_doc = _parse_md_docstring(child, relative_path, parser_config.reference_resolver) @@ -1213,8 +1218,6 @@ class _ClassPageInfo(object): if not child_doc.brief.strip() and short_name in [ '__del__', '__copy__' ]: - print('Skipping %s, defined in %s, no docstring.' % (child_name, - defining_class)) continue try: @@ -1371,7 +1374,8 @@ class _ModulePageInfo(object): for name in member_names: if name in ['__builtins__', '__doc__', '__file__', - '__name__', '__path__', '__package__']: + '__name__', '__path__', '__package__', + '__cached__', '__loader__', '__spec__']: continue member_full_name = self.full_name + '.' + name if self.full_name else name diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py index 274d48ef66071a4e6a5ebea65087f18382fea6a2..9f6b185e812ded5e690682b1515a1bf0d7add7e0 100644 --- a/tensorflow/tools/docs/parser_test.py +++ b/tensorflow/tools/docs/parser_test.py @@ -24,6 +24,7 @@ import sys from tensorflow.python.platform import googletest from tensorflow.python.util import tf_inspect +from tensorflow.tools.docs import doc_controls from tensorflow.tools.docs import parser @@ -37,13 +38,27 @@ def test_function_with_args_kwargs(unused_arg, *unused_args, **unused_kwargs): pass -class TestClass(object): +class ParentClass(object): + + @doc_controls.do_not_doc_inheritable + def hidden_method(self): + pass + + +class TestClass(ParentClass): """Docstring for TestClass itself.""" def a_method(self, arg='default'): """Docstring for a method.""" pass + def hidden_method(self): + pass + + @doc_controls.do_not_generate_docs + def hidden_method2(self): + pass + class ChildClass(object): """Docstring for a child class.""" pass @@ -175,6 +190,104 @@ class ParserTest(googletest.TestCase): # Make sure this file is contained as the definition location. self.assertEqual(os.path.relpath(__file__, '/'), page_info.defined_in.path) + def test_docs_for_class_should_skip(self): + + class Parent(object): + + @doc_controls.do_not_doc_inheritable + def a_method(self, arg='default'): + pass + + class Child(Parent): + + def a_method(self, arg='default'): + pass + + index = { + 'Child': Child, + 'Child.a_method': Child.a_method, + } + + visitor = DummyVisitor(index=index, duplicate_of={}) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) + + tree = { + 'Child': ['a_method'], + } + + parser_config = parser.ParserConfig( + reference_resolver=reference_resolver, + duplicates={}, + duplicate_of={}, + tree=tree, + index=index, + reverse_index={}, + guide_index={}, + base_dir='/') + + page_info = parser.docs_for_object( + full_name='Child', py_object=Child, parser_config=parser_config) + + # Make sure the `a_method` is not present + self.assertEqual(0, len(page_info.methods)) + + def test_docs_for_message_class(self): + + class CMessage(object): + + def hidden(self): + pass + + class Message(object): + + def hidden2(self): + pass + + class MessageMeta(object): + + def hidden3(self): + pass + + class ChildMessage(CMessage, Message, MessageMeta): + + def my_method(self): + pass + + index = { + 'ChildMessage': ChildMessage, + 'ChildMessage.hidden': ChildMessage.hidden, + 'ChildMessage.hidden2': ChildMessage.hidden2, + 'ChildMessage.hidden3': ChildMessage.hidden3, + 'ChildMessage.my_method': ChildMessage.my_method, + } + + visitor = DummyVisitor(index=index, duplicate_of={}) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) + + tree = {'ChildMessage': ['hidden', 'hidden2', 'hidden3', 'my_method']} + + parser_config = parser.ParserConfig( + reference_resolver=reference_resolver, + duplicates={}, + duplicate_of={}, + tree=tree, + index=index, + reverse_index={}, + guide_index={}, + base_dir='/') + + page_info = parser.docs_for_object( + full_name='ChildMessage', + py_object=ChildMessage, + parser_config=parser_config) + + self.assertEqual(1, len(page_info.methods)) + self.assertEqual('my_method', page_info.methods[0].short_name) + def test_docs_for_module(self): # Get the current module. module = sys.modules[__name__] diff --git a/tensorflow/tools/docs/py_guide_parser.py b/tensorflow/tools/docs/py_guide_parser.py index 328f42d18f1efb0fd82725a4683abad2df0d5a19..b00694dc40322161f180410630bb4dcfd8c2fb18 100644 --- a/tensorflow/tools/docs/py_guide_parser.py +++ b/tensorflow/tools/docs/py_guide_parser.py @@ -44,7 +44,8 @@ class PyGuideParser(object): def process(self, full_path): """Read and process the file at `full_path`.""" - md_string = open(full_path, 'rb').read().decode('utf-8') + with open(full_path, 'rb') as f: + md_string = f.read().decode('utf-8') self._lines = md_string.split('\n') seen = set() diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc index f1d361e07d8f00aa37a4e063a7d17bf85de74fde..156636ab8215d9abdc9e0ed461df550f1c7ed09c 100644 --- a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc +++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc @@ -159,7 +159,7 @@ Status FuseScaleOffsetToConvWeights(const std::vector& scale_values, NodeDef bias_add_node; bias_add_node.set_op("BiasAdd"); bias_add_node.set_name(conv_output_name); - if (!conv_node.attr().count("data_format")) { + if (conv_node.attr().count("data_format") > 0) { CopyNodeAttr(conv_node, "data_format", "data_format", &bias_add_node); } CopyNodeAttr(conv_node, "T", "T", &bias_add_node); diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc index af17fd75bc1ccac61538c17658d59ee2efd6254a..cb084e49b7c797acd85d77c65ce2c69fd05be4ce 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.cc +++ b/tensorflow/tools/graph_transforms/transform_utils.cc @@ -247,9 +247,16 @@ Status SortByExecutionOrder(const GraphDef& input_graph_def, } } - if (processed < input_graph_def.node_size()) { - return errors::InvalidArgument(input_graph_def.node_size() - processed, - " nodes in a cycle"); + if (processed < num_nodes) { + LOG(WARNING) << "IN " << __func__ << (num_nodes - processed) + << " NODES IN A CYCLE"; + for (int64 i = 0; i < num_nodes; i++) { + if (pending_count[i] != 0) { + LOG(WARNING) << "PENDING: " << SummarizeNodeDef(input_graph_def.node(i)) + << "WITH PENDING COUNT = " << pending_count[i]; + } + } + return errors::InvalidArgument(num_nodes - processed, " nodes in a cycle"); } return Status::OK(); } diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 77f83b77a0214110e520c85d15ffa38bce65955f..b450bc42c541cf51249c462d12255d79edf353c1 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -4,7 +4,9 @@ package(default_visibility = ["//visibility:private"]) load("@bazel_tools//tools/build_defs/pkg:pkg.bzl", "pkg_tar") +load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib") load("//tensorflow:tensorflow.bzl", "tf_binary_additional_srcs") +load("//tensorflow:tensorflow.bzl", "if_cuda") load("//third_party/mkl:build_defs.bzl", "if_mkl") genrule( @@ -113,10 +115,8 @@ genrule( "//third_party/hadoop:LICENSE.txt", "//third_party/eigen3:LICENSE", "//third_party/fft2d:LICENSE", - "@aws//:LICENSE", "@boringssl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", - "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", "@double_conversion//:LICENSE", "@eigen_archive//:COPYING.MPL2", @@ -124,13 +124,8 @@ genrule( "@fft2d//:fft/readme.txt", "@gemmlowp//:LICENSE", "@gif_archive//:COPYING", - "@grpc//:LICENSE", - "@grpc//third_party/address_sorting:LICENSE", - "@grpc//third_party/nanopb:LICENSE.txt", "@highwayhash//:LICENSE", - "@jemalloc//:COPYING", "@jpeg//:LICENSE.md", - "@libxsmm_archive//:LICENSE", "@llvm//:LICENSE.TXT", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", @@ -140,9 +135,42 @@ genrule( "@protobuf_archive//:LICENSE", "@snappy//:COPYING", "@zlib_archive//:zlib.h", - ] + if_mkl([ + ] + select({ + "//tensorflow:with_aws_support": [ + "@aws//:LICENSE", + ], + "//conditions:default": [], + }) + select({ + "//tensorflow:with_gcp_support": [ + "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE", + ], + "//conditions:default": [], + }) + select({ + "//tensorflow:with_jemalloc_linux_x86_64": [ + "@jemalloc//:COPYING", + ], + "//tensorflow:with_jemalloc_linux_ppc64le": [ + "@jemalloc//:COPYING", + ], + "//conditions:default": [], + }) + select({ + "//tensorflow/core/kernels:xsmm": [ + "@libxsmm_archive//:LICENSE.md", + ], + "//conditions:default": [], + }) + if_cuda([ + "@cub_archive//:LICENSE.TXT", + ]) + if_mkl([ "//third_party/mkl:LICENSE", - ]), + "//third_party/mkl_dnn:LICENSE", + ]) + if_not_system_lib( + "grpc", + [ + "@grpc//:LICENSE", + "@grpc//third_party/nanopb:LICENSE.txt", + "@grpc//third_party/address_sorting:LICENSE", + ], + ), outs = ["include/tensorflow/c/LICENSE"], cmd = "$(location :concat_licenses.sh) $(SRCS) >$@", tools = [":concat_licenses.sh"], @@ -154,10 +182,8 @@ genrule( "//third_party/hadoop:LICENSE.txt", "//third_party/eigen3:LICENSE", "//third_party/fft2d:LICENSE", - "@aws//:LICENSE", "@boringssl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", - "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", "@double_conversion//:LICENSE", "@eigen_archive//:COPYING.MPL2", @@ -166,9 +192,7 @@ genrule( "@gemmlowp//:LICENSE", "@gif_archive//:COPYING", "@highwayhash//:LICENSE", - "@jemalloc//:COPYING", "@jpeg//:LICENSE.md", - "@libxsmm_archive//:LICENSE", "@llvm//:LICENSE.TXT", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", @@ -178,8 +202,34 @@ genrule( "@protobuf_archive//:LICENSE", "@snappy//:COPYING", "@zlib_archive//:zlib.h", - ] + if_mkl([ + ] + select({ + "//tensorflow:with_aws_support": [ + "@aws//:LICENSE", + ], + "//conditions:default": [], + }) + select({ + "//tensorflow:with_gcp_support": [ + "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE", + ], + "//conditions:default": [], + }) + select({ + "//tensorflow:with_jemalloc_linux_x86_64": [ + "@jemalloc//:COPYING", + ], + "//tensorflow:with_jemalloc_linux_ppc64le": [ + "@jemalloc//:COPYING", + ], + "//conditions:default": [], + }) + select({ + "//tensorflow/core/kernels:xsmm": [ + "@libxsmm_archive//:LICENSE.md", + ], + "//conditions:default": [], + }) + if_cuda([ + "@cub_archive//:LICENSE.TXT", + ]) + if_mkl([ "//third_party/mkl:LICENSE", + "//third_party/mkl_dnn:LICENSE", ]), outs = ["include/tensorflow/jni/LICENSE"], cmd = "$(location :concat_licenses.sh) $(SRCS) >$@", diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 9d4148c07f37fb511836425e1b6ffceb7c259777..6bba139b4d4fc3abb57397f4bf0fdf6efc316284 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -9,9 +9,9 @@ load( "if_windows", "transitive_hdrs", ) -load("//third_party/mkl:build_defs.bzl", "if_mkl") +load("//third_party/mkl:build_defs.bzl", "if_mkl", "if_mkl_ml") load("//tensorflow:tensorflow.bzl", "if_cuda") -load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") +load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib") load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps") # This returns a list of headers of all public header libraries (e.g., @@ -57,15 +57,20 @@ COMMON_PIP_DEPS = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/autograph:autograph", "//tensorflow/contrib/autograph/converters:converters", - "//tensorflow/contrib/autograph/converters:test_lib", + "//tensorflow/contrib/autograph/core:core", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/impl:impl", + "//tensorflow/contrib/autograph/lang:lang", "//tensorflow/contrib/autograph/operators:operators", "//tensorflow/contrib/autograph/pyct:pyct", + "//tensorflow/contrib/autograph/pyct/testing:testing", "//tensorflow/contrib/autograph/pyct/static_analysis:static_analysis", + "//tensorflow/contrib/autograph/pyct/common_transformers:common_transformers", "//tensorflow/contrib/boosted_trees:boosted_trees_pip", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/constrained_optimization:constrained_optimization_pip", - "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test", + "//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base", "//tensorflow/contrib/data/python/ops:contrib_op_loader", "//tensorflow/contrib/eager/python/examples:examples_pip", "//tensorflow/contrib/eager/python:evaluator", @@ -75,7 +80,7 @@ COMMON_PIP_DEPS = [ "//tensorflow/contrib/labeled_tensor:labeled_tensor_pip", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/predictor:predictor_pip", - "//tensorflow/contrib/proto:proto_pip", + "//tensorflow/contrib/proto:proto", "//tensorflow/contrib/receptive_field:receptive_field_pip", "//tensorflow/contrib/rpc:rpc_pip", "//tensorflow/contrib/session_bundle:session_bundle_pip", @@ -91,6 +96,7 @@ COMMON_PIP_DEPS = [ "//tensorflow/contrib/timeseries:timeseries_pip", "//tensorflow/contrib/tpu", "//tensorflow/examples/tutorials/mnist:package", + "//tensorflow/python:cond_v2", "//tensorflow/python:distributed_framework_test_lib", "//tensorflow/python:meta_graph_testdata", "//tensorflow/python:spectral_ops_test_util", @@ -100,6 +106,7 @@ COMMON_PIP_DEPS = [ "//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files", "//tensorflow/python/saved_model:saved_model", "//tensorflow/python/tools:tools_pip", + "//tensorflow/python/tools/api/generator:create_python_api", "//tensorflow/python:test_ops", "//tensorflow/tools/dist_test/server:grpc_tensorflow_server", ] @@ -124,11 +131,9 @@ filegroup( "@absl_py//absl/flags:LICENSE", "@arm_neon_2_x86_sse//:LICENSE", "@astor_archive//:LICENSE", - "@aws//:LICENSE", "@boringssl//:LICENSE", "@com_google_absl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", - "@cub_archive//:LICENSE.TXT", "@curl//:COPYING", "@double_conversion//:LICENSE", "@eigen_archive//:COPYING.MPL2", @@ -138,17 +143,10 @@ filegroup( "@gast_archive//:PKG-INFO", "@gemmlowp//:LICENSE", "@gif_archive//:COPYING", - "@grpc//:LICENSE", "@highwayhash//:LICENSE", - "@jemalloc//:COPYING", "@jpeg//:LICENSE.md", - "@kafka//:LICENSE", - "@libxsmm_archive//:LICENSE", "@lmdb//:LICENSE", - "@local_config_nccl//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", - "@grpc//third_party/nanopb:LICENSE.txt", - "@grpc//third_party/address_sorting:LICENSE", "@nasm//:LICENSE", "@nsync//:LICENSE", "@pcre//:LICENCE", @@ -160,9 +158,49 @@ filegroup( "@termcolor_archive//:COPYING.txt", "@zlib_archive//:zlib.h", "@org_python_pypi_backports_weakref//:LICENSE", - ] + if_mkl([ + ] + select({ + "//tensorflow:with_aws_support": [ + "@aws//:LICENSE", + ], + "//conditions:default": [], + }) + select({ + "//tensorflow:with_gcp_support": [ + "@com_github_googleapis_googleapis//:LICENSE", + "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE", + ], + "//conditions:default": [], + }) + select({ + "//tensorflow:with_jemalloc_linux_x86_64": [ + "@jemalloc//:COPYING", + ], + "//tensorflow:with_jemalloc_linux_ppc64le": [ + "@jemalloc//:COPYING", + ], + "//conditions:default": [], + }) + select({ + "//tensorflow:with_kafka_support": [ + "@kafka//:LICENSE", + ], + "//conditions:default": [], + }) + select({ + "//tensorflow/core/kernels:xsmm": [ + "@libxsmm_archive//:LICENSE.md", + ], + "//conditions:default": [], + }) + if_cuda([ + "@cub_archive//:LICENSE.TXT", + "@local_config_nccl//:LICENSE", + ]) + if_mkl([ "//third_party/mkl:LICENSE", - ]) + tf_additional_license_deps(), + "//third_party/mkl_dnn:LICENSE", + ]) + if_not_system_lib( + "grpc", + [ + "@grpc//:LICENSE", + "@grpc//third_party/nanopb:LICENSE.txt", + "@grpc//third_party/address_sorting:LICENSE", + ], + ) + tf_additional_license_deps(), ) sh_binary( @@ -170,16 +208,13 @@ sh_binary( srcs = ["build_pip_package.sh"], data = select({ "//tensorflow:windows": [":simple_console_for_windows"], - "//tensorflow:windows_msvc": [":simple_console_for_windows"], "//conditions:default": COMMON_PIP_DEPS + [ ":simple_console", "//tensorflow/contrib/lite/python:interpreter_test_data", "//tensorflow/contrib/lite/python:tflite_convert", "//tensorflow/contrib/lite/toco/python:toco_from_protos", ], - }) + if_mkl(["//third_party/mkl:intel_binary_blob"]) + if_tensorrt([ - "//tensorflow/contrib/tensorrt:init_py", - ]), + }) + if_mkl_ml(["//third_party/intel_mkl_ml"]), ) # A genrule for generating a marker file for the pip package on Windows diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh index f7e42ce5362163a23ae1060c47e3e02869372556..666ea75d4640774b23c5e0c9ea83ab84e99de6b2 100755 --- a/tensorflow/tools/pip_package/build_pip_package.sh +++ b/tensorflow/tools/pip_package/build_pip_package.sh @@ -17,16 +17,26 @@ set -e +function is_absolute { + [[ "$1" = /* ]] || [[ "$1" =~ ^[a-zA-Z]:[/\\].* ]] +} + function real_path() { - [[ $1 = /* ]] && echo "$1" || echo "$PWD/${1#./}" + is_absolute "$1" && echo "$1" || echo "$PWD/${1#./}" } function cp_external() { local src_dir=$1 local dest_dir=$2 - for f in `find "$src_dir" -maxdepth 1 -mindepth 1 ! -name '*local_config_cuda*' ! -name '*local_config_tensorrt*' ! -name '*org_tensorflow*'`; do - cp -R "$f" "$dest_dir" + + pushd . + cd "$src_dir" + for f in `find . ! -type d ! -name '*.py' ! -path '*local_config_cuda*' ! -path '*local_config_tensorrt*' ! -path '*local_config_syslibs*' ! -path '*org_tensorflow*'`; do + mkdir -p "${dest_dir}/$(dirname ${f})" + cp "${f}" "${dest_dir}/$(dirname ${f})/" done + popd + mkdir -p "${dest_dir}/local_config_cuda/cuda/cuda/" cp "${src_dir}/local_config_cuda/cuda/cuda/cuda_config.h" "${dest_dir}/local_config_cuda/cuda/cuda/" } @@ -34,7 +44,7 @@ function cp_external() { PLATFORM="$(uname -s | tr 'A-Z' 'a-z')" function is_windows() { # On windows, the shell script is actually running in msys - if [[ "${PLATFORM}" =~ msys_nt* ]]; then + if [[ "${PLATFORM}" =~ (mingw64|msys)_nt* ]]; then true else false @@ -49,6 +59,8 @@ function prepare_src() { TMPDIR="$1" mkdir -p "$TMPDIR" + EXTERNAL_INCLUDES="${TMPDIR}/tensorflow/include/external" + echo $(date) : "=== Preparing sources in dir: ${TMPDIR}" if [ ! -d bazel-bin/tensorflow ]; then @@ -66,10 +78,9 @@ function prepare_src() { cp -R \ bazel-bin/tensorflow/tools/pip_package/simple_console_for_window_unzip/runfiles/org_tensorflow/tensorflow \ "${TMPDIR}" - mkdir "${TMPDIR}/external" cp_external \ bazel-bin/tensorflow/tools/pip_package/simple_console_for_window_unzip/runfiles \ - "${TMPDIR}/external" + "${EXTERNAL_INCLUDES}/" RUNFILES=bazel-bin/tensorflow/tools/pip_package/simple_console_for_window_unzip/runfiles/org_tensorflow else RUNFILES=bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow @@ -78,10 +89,9 @@ function prepare_src() { cp -R \ bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/tensorflow \ "${TMPDIR}" - mkdir "${TMPDIR}/external" cp_external \ bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/external \ - "${TMPDIR}/external" + "${EXTERNAL_INCLUDES}" # Copy MKL libs over so they can be loaded at runtime so_lib_dir=$(ls $RUNFILES | grep solib) || true if [ -n "${so_lib_dir}" ]; then @@ -96,10 +106,9 @@ function prepare_src() { cp -R \ bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/tensorflow \ "${TMPDIR}" - mkdir "${TMPDIR}/external" cp_external \ bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles \ - "${TMPDIR}/external" + "${EXTERNAL_INCLUDES}" # Copy MKL libs over so they can be loaded at runtime so_lib_dir=$(ls $RUNFILES | grep solib) || true if [ -n "${so_lib_dir}" ]; then diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index 401f833dbd6ae404af000714219cae482a31129b..bfc007bc391fc3964a087b305bdb3684cc614631 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -90,6 +90,7 @@ BLACKLIST = [ "//tensorflow/contrib/lite/python:interpreter.py", "//tensorflow/contrib/lite/python:interpreter_test.py", "//tensorflow/contrib/ffmpeg:test_data", + "//tensorflow/contrib/hadoop:test_data", "//tensorflow/contrib/factorization/examples:mnist", "//tensorflow/contrib/factorization/examples:mnist.py", "//tensorflow/contrib/factorization:factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", # pylint:disable=line-too-long diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 97f625e7e9cf5bc5064596cc76737f5cb7a591f2..5e179079c576ca23db87038442b9be9990fbc5ab 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -45,17 +45,19 @@ DOCLINES = __doc__.split('\n') # This version string is semver compatible, but incompatible with pip. # For pip, we will remove all '-' characters from this string, and use the # result for pip. -_VERSION = '1.9.0-rc0' +_VERSION = '1.10.0' REQUIRED_PACKAGES = [ 'absl-py >= 0.1.6', 'astor >= 0.6.0', 'gast >= 0.2.0', - 'numpy >= 1.13.3', + 'keras_applications == 1.0.4', + 'keras_preprocessing == 1.0.2', + 'numpy >= 1.13.3, <= 1.14.5', 'six >= 1.10.0', - 'protobuf >= 3.4.0', + 'protobuf >= 3.6.0', 'setuptools <= 39.1.0', - 'tensorboard >= 1.8.0, < 1.9.0', + 'tensorboard >= 1.10.0, < 1.11.0', 'termcolor >= 1.1.0', ] @@ -84,7 +86,7 @@ else: if 'tf_nightly' in project_name: for i, pkg in enumerate(REQUIRED_PACKAGES): if 'tensorboard' in pkg: - REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.9.0a0, < 1.10.0a0' + REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.11.0a0, < 1.12.0a0' break # weakref.finalize and enum were introduced in Python 3.4 @@ -170,8 +172,9 @@ class InstallHeaders(Command): # symlink within the directory hierarchy. # NOTE(keveman): Figure out how to customize bdist_wheel package so # we can do the symlink. - if 'external/eigen_archive/' in install_dir: - extra_dir = install_dir.replace('external/eigen_archive', '') + if 'tensorflow/include/external/eigen_archive/' in install_dir: + extra_dir = install_dir.replace( + 'tensorflow/include/external/eigen_archive', '') if not os.path.exists(extra_dir): self.mkpath(extra_dir) self.copy_file(header, extra_dir) @@ -204,13 +207,12 @@ def find_files(pattern, root): yield os.path.join(dirpath, filename) -matches = ['../' + x for x in find_files('*', 'external') if '.py' not in x] - so_lib_paths = [ i for i in os.listdir('.') if os.path.isdir(i) and fnmatch.fnmatch(i, '_solib_*') ] +matches = [] for path in so_lib_paths: matches.extend( ['../' + x for x in find_files('*', path) if '.py' not in x] @@ -225,7 +227,7 @@ headers = (list(find_files('*.h', 'tensorflow/core')) + list(find_files('*.h', 'tensorflow/stream_executor')) + list(find_files('*.h', 'google/protobuf_archive/src')) + list(find_files('*', 'third_party/eigen3')) + - list(find_files('*', 'external/eigen_archive'))) + list(find_files('*', 'tensorflow/include/external/eigen_archive'))) setup( name=project_name, diff --git a/tensorflow/tools/proto_text/BUILD b/tensorflow/tools/proto_text/BUILD index 31e8fb9120c3b6280911f836eb0b68b883f2ac9d..fc2c041b6c14b7946bbdcea7ae890f34d8e0ea79 100644 --- a/tensorflow/tools/proto_text/BUILD +++ b/tensorflow/tools/proto_text/BUILD @@ -49,7 +49,6 @@ cc_library( copts = if_ios(["-DGOOGLE_LOGGING"]), linkopts = select({ "//tensorflow:windows": [], - "//tensorflow:windows_msvc": [], "//tensorflow:darwin": [ "-lm", "-lpthread", diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index ef5cb60ceed9d2b67dd96c6c1f36e2e6ca927a9c..217910c04f09987deff7dfac9e535ab493d7421b 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -8,826 +8,902 @@ load("//third_party/git:git_configure.bzl", "git_configure") load("//third_party/py:python_configure.bzl", "python_configure") load("//third_party/sycl:sycl_configure.bzl", "sycl_configure") +load("//third_party/systemlibs:syslibs_configure.bzl", "syslibs_configure") load("//third_party/toolchains/clang6:repo.bzl", "clang6_configure") load("//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", "arm_compiler_configure") load("//third_party:repo.bzl", "tf_http_archive") load("//third_party/clang_toolchain:cc_configure_clang.bzl", "cc_download_clang_toolchain") load("@io_bazel_rules_closure//closure/private:java_import_external.bzl", "java_import_external") load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") -load("//tensorflow/tools/def_file_filter:def_file_filter_configure.bzl", - "def_file_filter_configure") - +load( + "//tensorflow/tools/def_file_filter:def_file_filter_configure.bzl", + "def_file_filter_configure", +) # Sanitize a dependency so that it works correctly from code that includes # TensorFlow as a submodule. def clean_dep(dep): - return str(Label(dep)) + return str(Label(dep)) # If TensorFlow is linked as a submodule. # path_prefix is no longer used. # tf_repo_name is thought to be under consideration. -def tf_workspace(path_prefix="", tf_repo_name=""): - # Note that we check the minimum bazel version in WORKSPACE. - clang6_configure(name="local_config_clang6") - cc_download_clang_toolchain(name="local_config_download_clang") - cuda_configure(name="local_config_cuda") - tensorrt_configure(name="local_config_tensorrt") - nccl_configure(name="local_config_nccl") - git_configure(name="local_config_git") - sycl_configure(name="local_config_sycl") - python_configure(name="local_config_python") - - # For windows bazel build - # TODO: Remove def file filter when TensorFlow can export symbols properly on Windows. - def_file_filter_configure(name = "local_config_def_file_filter") - - # Point //external/local_config_arm_compiler to //external/arm_compiler - arm_compiler_configure( - name="local_config_arm_compiler", - remote_config_repo="../arm_compiler", - build_file = clean_dep("//third_party/toolchains/cpus/arm:BUILD")) - - mkl_repository( - name = "mkl_linux", - urls = [ - "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.14/mklml_lnx_2018.0.3.20180406.tgz", - "https://github.com/intel/mkl-dnn/releases/download/v0.14/mklml_lnx_2018.0.3.20180406.tgz" - ], - sha256 = "d2305244fdc9b87db7426ed4496e87a4b3977ad3374d73b8000e8b7a5b7aa725", - strip_prefix = "mklml_lnx_2018.0.3.20180406", - build_file = clean_dep("//third_party/mkl:mkl.BUILD") - ) - mkl_repository( - name = "mkl_windows", - urls = [ - "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.14/mklml_win_2018.0.3.20180406.zip", - "https://github.com/intel/mkl-dnn/releases/download/v0.14/mklml_win_2018.0.3.20180406.zip" - ], - sha256 = "a584a5bf1c8d2ad70b90d12b52652030e9a338217719064fdb84b7ad0d693694", - strip_prefix = "mklml_win_2018.0.3.20180406", - build_file = clean_dep("//third_party/mkl:mkl.BUILD") - ) - mkl_repository( - name = "mkl_darwin", - urls = [ - "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.14/mklml_mac_2018.0.3.20180406.tgz", - "https://github.com/intel/mkl-dnn/releases/download/v0.14/mklml_mac_2018.0.3.20180406.tgz" - ], - sha256 = "094e3dfd61c816136dc8d12a45cc611ce26c5f4828176a3644cd0b0efa15a25b", - strip_prefix = "mklml_mac_2018.0.3.20180406", - build_file = clean_dep("//third_party/mkl:mkl.BUILD") - ) - - if path_prefix: - print("path_prefix was specified to tf_workspace but is no longer used " + - "and will be removed in the future.") - - tf_http_archive( - name = "mkl_dnn", - urls = [ - "https://mirror.bazel.build/github.com/intel/mkl-dnn/archive/v0.14.tar.gz", - "https://github.com/intel/mkl-dnn/archive/v0.14.tar.gz", - ], - sha256 = "efebc53882856afec86457a2da644693f5d59c68772d41d640d6b60a8efc4eb0", - strip_prefix = "mkl-dnn-0.14", - build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"), - ) - - tf_http_archive( - name = "com_google_absl", - urls = [ - "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/9613678332c976568272c8f4a78631a29159271d.tar.gz", - "https://github.com/abseil/abseil-cpp/archive/9613678332c976568272c8f4a78631a29159271d.tar.gz", - ], - sha256 = "1273a1434ced93bc3e703a48c5dced058c95e995c8c009e9bdcb24a69e2180e9", - strip_prefix = "abseil-cpp-9613678332c976568272c8f4a78631a29159271d", - build_file = clean_dep("//third_party:com_google_absl.BUILD"), - ) - - tf_http_archive( - name = "eigen_archive", - urls = [ - "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/267806ed9b4f.tar.gz", - "https://bitbucket.org/eigen/eigen/get/267806ed9b4f.tar.gz", - ], - sha256 = "ade57357093463cab9e4e51cd5749c81483a75451b1471a3ebc73f9c1d14043b", - strip_prefix = "eigen-eigen-267806ed9b4f", - build_file = clean_dep("//third_party:eigen.BUILD"), - ) - - tf_http_archive( - name = "arm_compiler", - sha256 = "970285762565c7890c6c087d262b0a18286e7d0384f13a37786d8521773bc969", - strip_prefix = "tools-0e906ebc527eab1cdbf7adabff5b474da9562e9f/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf", - urls = [ - "https://mirror.bazel.build/github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz", - # Please uncomment me, when the next upgrade happens. Then - # remove the whitelist entry in third_party/repo.bzl. - # "https://github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz", - ], - build_file = clean_dep("//:arm_compiler.BUILD"), - ) - - tf_http_archive( - name = "libxsmm_archive", - urls = [ - "https://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.8.1.tar.gz", - "https://github.com/hfp/libxsmm/archive/1.8.1.tar.gz", - ], - sha256 = "2ade869c3f42f23b5263c7d594aa3c7e5e61ac6a3afcaf5d6e42899d2a7986ce", - strip_prefix = "libxsmm-1.8.1", - build_file = clean_dep("//third_party:libxsmm.BUILD"), - ) - - tf_http_archive( - name = "ortools_archive", - urls = [ - "https://mirror.bazel.build/github.com/google/or-tools/archive/253f7955c6a1fd805408fba2e42ac6d45b312d15.tar.gz", - # Please uncomment me, when the next upgrade happens. Then - # remove the whitelist entry in third_party/repo.bzl. - # "https://github.com/google/or-tools/archive/253f7955c6a1fd805408fba2e42ac6d45b312d15.tar.gz", - ], - sha256 = "932075525642b04ac6f1b50589f1df5cd72ec2f448b721fd32234cf183f0e755", - strip_prefix = "or-tools-253f7955c6a1fd805408fba2e42ac6d45b312d15/src", - build_file = clean_dep("//third_party:ortools.BUILD"), - ) - - tf_http_archive( - name = "com_googlesource_code_re2", - urls = [ - "https://mirror.bazel.build/github.com/google/re2/archive/26cd968b735e227361c9703683266f01e5df7857.tar.gz", - "https://github.com/google/re2/archive/26cd968b735e227361c9703683266f01e5df7857.tar.gz", - - ], - sha256 = "e57eeb837ac40b5be37b2c6197438766e73343ffb32368efea793dfd8b28653b", - strip_prefix = "re2-26cd968b735e227361c9703683266f01e5df7857", - ) - - tf_http_archive( - name = "gemmlowp", - urls = [ - "https://mirror.bazel.build/github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip", - "https://github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip", - ], - sha256 = "b87faa7294dfcc5d678f22a59d2c01ca94ea1e2a3b488c38a95a67889ed0a658", - strip_prefix = "gemmlowp-38ebac7b059e84692f53e5938f97a9943c120d98", - ) - - tf_http_archive( - name = "farmhash_archive", - urls = [ - "https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz", - "https://github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz", - ], - sha256 = "6560547c63e4af82b0f202cb710ceabb3f21347a4b996db565a411da5b17aba0", - strip_prefix = "farmhash-816a4ae622e964763ca0862d9dbd19324a1eaf45", - build_file = clean_dep("//third_party:farmhash.BUILD"), - ) - - tf_http_archive( - name = "highwayhash", - urls = [ - "http://mirror.bazel.build/github.com/google/highwayhash/archive/fd3d9af80465e4383162e4a7c5e2f406e82dd968.tar.gz", - "https://github.com/google/highwayhash/archive/fd3d9af80465e4383162e4a7c5e2f406e82dd968.tar.gz", - ], - sha256 = "9c3e0e87d581feeb0c18d814d98f170ff23e62967a2bd6855847f0b2fe598a37", - strip_prefix = "highwayhash-fd3d9af80465e4383162e4a7c5e2f406e82dd968", - build_file = clean_dep("//third_party:highwayhash.BUILD"), - ) - - tf_http_archive( - name = "nasm", - urls = [ - "https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.12.02/nasm-2.12.02.tar.bz2", - "http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.12.02.tar.bz2/d15843c3fb7db39af80571ee27ec6fad/nasm-2.12.02.tar.bz2", - ], - sha256 = "00b0891c678c065446ca59bcee64719d0096d54d6886e6e472aeee2e170ae324", - strip_prefix = "nasm-2.12.02", - build_file = clean_dep("//third_party:nasm.BUILD"), - ) - - tf_http_archive( - name = "jpeg", - urls = [ - "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.3.tar.gz", - "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.3.tar.gz", - ], - sha256 = "1a17020f859cb12711175a67eab5c71fc1904e04b587046218e36106e07eabde", - strip_prefix = "libjpeg-turbo-1.5.3", - build_file = clean_dep("//third_party/jpeg:jpeg.BUILD"), - ) - - tf_http_archive( - name = "png_archive", - urls = [ - "https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz", - "https://github.com/glennrp/libpng/archive/v1.6.34.tar.gz", - ], - sha256 = "e45ce5f68b1d80e2cb9a2b601605b374bdf51e1798ef1c2c2bd62131dfcf9eef", - strip_prefix = "libpng-1.6.34", - build_file = clean_dep("//third_party:png.BUILD"), - patch_file = clean_dep("//third_party:png_fix_rpi.patch"), - ) - - tf_http_archive( - name = "org_sqlite", - urls = [ - "https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3230100.zip", - "https://www.sqlite.org/2018/sqlite-amalgamation-3230100.zip", - ], - sha256 = "4239a1f69e5721d07d9a374eb84d594225229e54be4ee628da2995f4315d8dfc", - strip_prefix = "sqlite-amalgamation-3230100", - build_file = clean_dep("//third_party:sqlite.BUILD"), - ) - - tf_http_archive( - name = "gif_archive", - urls = [ - "https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz", - "http://pilotfiber.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz", - ], - sha256 = "34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1", - strip_prefix = "giflib-5.1.4", - build_file = clean_dep("//third_party:gif.BUILD"), - ) - - tf_http_archive( - name = "six_archive", - urls = [ - "https://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", - "https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", - ], - sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a", - strip_prefix = "six-1.10.0", - build_file = clean_dep("//third_party:six.BUILD"), - ) - - tf_http_archive( - name = "astor_archive", - urls = [ - "https://mirror.bazel.build/pypi.python.org/packages/d8/be/c4276b3199ec3feee2a88bc64810fbea8f26d961e0a4cd9c68387a9f35de/astor-0.6.2.tar.gz", - "https://pypi.python.org/packages/d8/be/c4276b3199ec3feee2a88bc64810fbea8f26d961e0a4cd9c68387a9f35de/astor-0.6.2.tar.gz", - ], - sha256 = "ff6d2e2962d834acb125cc4dcc80c54a8c17c253f4cc9d9c43b5102a560bb75d", - strip_prefix = "astor-0.6.2", - build_file = clean_dep("//third_party:astor.BUILD"), - ) - - tf_http_archive( - name = "gast_archive", - urls = [ - "https://mirror.bazel.build/pypi.python.org/packages/5c/78/ff794fcae2ce8aa6323e789d1f8b3b7765f601e7702726f430e814822b96/gast-0.2.0.tar.gz", - "https://pypi.python.org/packages/5c/78/ff794fcae2ce8aa6323e789d1f8b3b7765f601e7702726f430e814822b96/gast-0.2.0.tar.gz", - ], - sha256 = "7068908321ecd2774f145193c4b34a11305bd104b4551b09273dfd1d6a374930", - strip_prefix = "gast-0.2.0", - build_file = clean_dep("//third_party:gast.BUILD"), - ) - - tf_http_archive( - name = "termcolor_archive", - urls = [ - "https://mirror.bazel.build/pypi.python.org/packages/8a/48/a76be51647d0eb9f10e2a4511bf3ffb8cc1e6b14e9e4fab46173aa79f981/termcolor-1.1.0.tar.gz", - "https://pypi.python.org/packages/8a/48/a76be51647d0eb9f10e2a4511bf3ffb8cc1e6b14e9e4fab46173aa79f981/termcolor-1.1.0.tar.gz", - ], - sha256 = "1d6d69ce66211143803fbc56652b41d73b4a400a2891d7bf7a1cdf4c02de613b", - strip_prefix = "termcolor-1.1.0", - build_file = clean_dep("//third_party:termcolor.BUILD"), - ) - - tf_http_archive( - name = "absl_py", - urls = [ - "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/ea8c4d2ddbf3fba610c4d613260561699b776db8.tar.gz", - "https://github.com/abseil/abseil-py/archive/ea8c4d2ddbf3fba610c4d613260561699b776db8.tar.gz", - ], - sha256 = "c30b48e0d2580ef1412e55c5c0e1dab8db2ee4ab56e2075eccff29c90c7c7059", - strip_prefix = "abseil-py-ea8c4d2ddbf3fba610c4d613260561699b776db8", - ) - - tf_http_archive( - name = "org_python_pypi_backports_weakref", - urls = [ - "https://mirror.bazel.build/pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz", - "https://pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz", - ], - sha256 = "8813bf712a66b3d8b85dc289e1104ed220f1878cf981e2fe756dfaabe9a82892", - strip_prefix = "backports.weakref-1.0rc1/src", - build_file = clean_dep("//third_party:backports_weakref.BUILD"), - ) - - filegroup_external( - name = "org_python_license", - licenses = ["notice"], # Python 2.0 - sha256_urls = { - "b5556e921715ddb9242c076cae3963f483aa47266c5e37ea4c187f77cc79501c": [ - "https://mirror.bazel.build/docs.python.org/2.7/_sources/license.txt", - "https://docs.python.org/2.7/_sources/license.txt", - ], - }, - ) - - tf_http_archive( - name = "protobuf_archive", - urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", - "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", - ], - sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3", - strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a", - ) - - # We need to import the protobuf library under the names com_google_protobuf - # and com_google_protobuf_cc to enable proto_library support in bazel. - # Unfortunately there is no way to alias http_archives at the moment. - tf_http_archive( - name = "com_google_protobuf", - urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", - "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", - ], - sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3", - strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a", - ) - - tf_http_archive( - name = "com_google_protobuf_cc", - urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", - "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz", - ], - sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3", - strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a", - ) - - tf_http_archive( - name = "nsync", - urls = [ - "https://mirror.bazel.build/github.com/google/nsync/archive/0559ce013feac8db639ee1bf776aca0325d28777.tar.gz", - "https://github.com/google/nsync/archive/0559ce013feac8db639ee1bf776aca0325d28777.tar.gz", - ], - sha256 = "6284454c5cd8b1dae2eeb8cf5eb63004de930b5427ed5f6b1aa793513df6b361", - strip_prefix = "nsync-0559ce013feac8db639ee1bf776aca0325d28777", - ) - - tf_http_archive( - name = "com_google_googletest", - urls = [ - "https://mirror.bazel.build/github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip", - "https://github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip", - ], - sha256 = "9cbca84c4256bed17df2c8f4d00c912c19d247c11c9ba6647cd6dd5b5c996b8d", - strip_prefix = "googletest-9816b96a6ddc0430671693df90192bbee57108b6", - ) - - tf_http_archive( - name = "com_github_gflags_gflags", - urls = [ - "https://mirror.bazel.build/github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz", - "https://github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz", - ], - sha256 = "4d222fab8f1ede4709cdff417d15a1336f862d7334a81abf76d09c15ecf9acd1", - strip_prefix = "gflags-f8a0efe03aa69b3336d8e228b37d4ccb17324b88", - ) - - tf_http_archive( - name = "pcre", - sha256 = "ccdf7e788769838f8285b3ee672ed573358202305ee361cfec7a4a4fb005bbc7", - urls = [ - "https://mirror.bazel.build/ftp.exim.org/pub/pcre/pcre-8.39.tar.gz", - "http://ftp.exim.org/pub/pcre/pcre-8.39.tar.gz", - ], - strip_prefix = "pcre-8.39", - build_file = clean_dep("//third_party:pcre.BUILD"), - ) - - tf_http_archive( - name = "swig", - sha256 = "58a475dbbd4a4d7075e5fe86d4e54c9edde39847cdb96a3053d87cb64a23a453", - urls = [ - "https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz", - "http://ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz", - "http://pilotfiber.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz", - ], - strip_prefix = "swig-3.0.8", - build_file = clean_dep("//third_party:swig.BUILD"), - ) - - tf_http_archive( - name = "curl", - sha256 = "ff3e80c1ca6a068428726cd7dd19037a47cc538ce58ef61c59587191039b2ca6", - urls = [ - "https://mirror.bazel.build/curl.haxx.se/download/curl-7.49.1.tar.gz", - "https://curl.haxx.se/download/curl-7.49.1.tar.gz", - ], - strip_prefix = "curl-7.49.1", - build_file = clean_dep("//third_party:curl.BUILD"), - ) - - tf_http_archive( - name = "grpc", - urls = [ - "https://mirror.bazel.build/github.com/grpc/grpc/archive/d184fa229d75d336aedea0041bd59cb93e7e267f.tar.gz", - "https://github.com/grpc/grpc/archive/d184fa229d75d336aedea0041bd59cb93e7e267f.tar.gz", - ], - sha256 = "895b31310e718a61f7335759a778c068a6edde1c089883598a0830cbb7075673", - strip_prefix = "grpc-d184fa229d75d336aedea0041bd59cb93e7e267f", - ) - - - tf_http_archive( - name = "linenoise", - sha256 = "7f51f45887a3d31b4ce4fa5965210a5e64637ceac12720cfce7954d6a2e812f7", - urls = [ - "https://mirror.bazel.build/github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz", - "https://github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz", - ], - strip_prefix = "linenoise-c894b9e59f02203dbe4e2be657572cf88c4230c3", - build_file = clean_dep("//third_party:linenoise.BUILD"), - ) - - # TODO(phawkins): currently, this rule uses an unofficial LLVM mirror. - # Switch to an official source of snapshots if/when possible. - tf_http_archive( - name = "llvm", - urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/582e5dd5553e3089fef97f9ab5a3f063e0160fa9.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/582e5dd5553e3089fef97f9ab5a3f063e0160fa9.tar.gz", - ], - sha256 = "9a0e63469ae5a546e0c84b778955f0febabfc8497d312324546ec7d0db68430e", - strip_prefix = "llvm-582e5dd5553e3089fef97f9ab5a3f063e0160fa9", - build_file = clean_dep("//third_party/llvm:llvm.BUILD"), - ) - - tf_http_archive( - name = "lmdb", - urls = [ - "https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", - "https://github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", - ], - sha256 = "108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d05326", - strip_prefix = "lmdb-LMDB_0.9.19/libraries/liblmdb", - build_file = clean_dep("//third_party:lmdb.BUILD"), - ) - - tf_http_archive( - name = "jsoncpp_git", - urls = [ - "https://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", - "https://github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", - ], - sha256 = "07d34db40593d257324ec5fb9debc4dc33f29f8fb44e33a2eeb35503e61d0fe2", - strip_prefix = "jsoncpp-11086dd6a7eba04289944367ca82cea71299ed70", - build_file = clean_dep("//third_party:jsoncpp.BUILD"), - ) - - tf_http_archive( - name = "boringssl", - urls = [ - "https://mirror.bazel.build/github.com/google/boringssl/archive/a0fb951d2a26a8ee746b52f3ba81ab011a0af778.tar.gz", - "https://github.com/google/boringssl/archive/a0fb951d2a26a8ee746b52f3ba81ab011a0af778.tar.gz", - ], - sha256 = "524ba98a56300149696481b4cb9ddebd0c7b7ac9b9f6edee81da2d2d7e5d2bb3", - strip_prefix = "boringssl-a0fb951d2a26a8ee746b52f3ba81ab011a0af778", - ) - - tf_http_archive( - name = "zlib_archive", - urls = [ - "https://mirror.bazel.build/zlib.net/zlib-1.2.11.tar.gz", - "https://zlib.net/zlib-1.2.11.tar.gz", - ], - sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", - strip_prefix = "zlib-1.2.11", - build_file = clean_dep("//third_party:zlib.BUILD"), - ) - - tf_http_archive( - name = "fft2d", - urls = [ - "https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz", - "http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz", - ], - sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296", - build_file = clean_dep("//third_party/fft2d:fft2d.BUILD"), - ) - - tf_http_archive( - name = "snappy", - urls = [ - "https://mirror.bazel.build/github.com/google/snappy/archive/1.1.7.tar.gz", - "https://github.com/google/snappy/archive/1.1.7.tar.gz", - ], - sha256 = "3dfa02e873ff51a11ee02b9ca391807f0c8ea0529a4924afa645fbf97163f9d4", - strip_prefix = "snappy-1.1.7", - build_file = clean_dep("//third_party:snappy.BUILD"), - ) - - tf_http_archive( - name = "nccl_archive", - urls = [ - "https://mirror.bazel.build/github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz", - "https://github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz", - ], - sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176", - strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7", - build_file = clean_dep("//third_party:nccl/nccl_archive.BUILD"), - ) - - tf_http_archive( - name = "kafka", - urls = [ - "https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.1.tar.gz", - "https://github.com/edenhill/librdkafka/archive/v0.11.1.tar.gz", - ], - sha256 = "dd035d57c8f19b0b612dd6eefe6e5eebad76f506e302cccb7c2066f25a83585e", - strip_prefix = "librdkafka-0.11.1", - build_file = clean_dep("//third_party:kafka/BUILD"), - patch_file = clean_dep("//third_party/kafka:config.patch"), - ) - - tf_http_archive( - name = "aws", - urls = [ - "https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.3.15.tar.gz", - "https://github.com/aws/aws-sdk-cpp/archive/1.3.15.tar.gz", - ], - sha256 = "b888d8ce5fc10254c3dd6c9020c7764dd53cf39cf011249d0b4deda895de1b7c", - strip_prefix = "aws-sdk-cpp-1.3.15", - build_file = clean_dep("//third_party:aws.BUILD"), - ) - - java_import_external( - name = "junit", - jar_sha256 = "59721f0805e223d84b90677887d9ff567dc534d7c502ca903c0c2b17f05c116a", - jar_urls = [ - "https://mirror.bazel.build/repo1.maven.org/maven2/junit/junit/4.12/junit-4.12.jar", - "http://repo1.maven.org/maven2/junit/junit/4.12/junit-4.12.jar", - "http://maven.ibiblio.org/maven2/junit/junit/4.12/junit-4.12.jar", - ], - licenses = ["reciprocal"], # Common Public License Version 1.0 - testonly_ = True, - deps = ["@org_hamcrest_core"], - ) - - java_import_external( - name = "org_hamcrest_core", - jar_sha256 = "66fdef91e9739348df7a096aa384a5685f4e875584cce89386a7a47251c4d8e9", - jar_urls = [ - "https://mirror.bazel.build/repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar", - "http://repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar", - "http://maven.ibiblio.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar", - ], - licenses = ["notice"], # New BSD License - testonly_ = True, - ) - - tf_http_archive( - name = "jemalloc", - urls = [ - "https://mirror.bazel.build/github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz", - "https://github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz", - ], - sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8", - strip_prefix = "jemalloc-4.4.0", - build_file = clean_dep("//third_party:jemalloc.BUILD"), - ) - - java_import_external( - name = "com_google_testing_compile", - jar_sha256 = "edc180fdcd9f740240da1a7a45673f46f59c5578d8cd3fbc912161f74b5aebb8", - jar_urls = [ - "http://mirror.bazel.build/repo1.maven.org/maven2/com/google/testing/compile/compile-testing/0.11/compile-testing-0.11.jar", - "http://repo1.maven.org/maven2/com/google/testing/compile/compile-testing/0.11/compile-testing-0.11.jar", - ], - licenses = ["notice"], # New BSD License - testonly_ = True, - deps = ["@com_google_guava", "@com_google_truth"], - ) - - java_import_external( - name = "com_google_truth", - jar_sha256 = "032eddc69652b0a1f8d458f999b4a9534965c646b8b5de0eba48ee69407051df", - jar_urls = [ - "http://mirror.bazel.build/repo1.maven.org/maven2/com/google/truth/truth/0.32/truth-0.32.jar", - "http://repo1.maven.org/maven2/com/google/truth/truth/0.32/truth-0.32.jar", - ], - licenses = ["notice"], # Apache 2.0 - testonly_ = True, - deps = ["@com_google_guava"], - ) - - java_import_external( - name = "org_checkerframework_qual", - jar_sha256 = "a17501717ef7c8dda4dba73ded50c0d7cde440fd721acfeacbf19786ceac1ed6", - jar_urls = [ - "http://mirror.bazel.build/repo1.maven.org/maven2/org/checkerframework/checker-qual/2.4.0/checker-qual-2.4.0.jar", - "http://repo1.maven.org/maven2/org/checkerframework/checker-qual/2.4.0/checker-qual-2.4.0.jar", - ], - licenses = ["notice"], # Apache 2.0 - ) - - tf_http_archive( - name = "com_google_pprof", - urls = [ - "https://mirror.bazel.build/github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz", - "https://github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz", - ], - sha256 = "e0928ca4aa10ea1e0551e2d7ce4d1d7ea2d84b2abbdef082b0da84268791d0c4", - strip_prefix = "pprof-c0fb62ec88c411cc91194465e54db2632845b650", - build_file = clean_dep("//third_party:pprof.BUILD"), - ) - - tf_http_archive( - name = "cub_archive", - urls = [ - "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.8.0.zip", - "https://github.com/NVlabs/cub/archive/1.8.0.zip", - ], - sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3", - strip_prefix = "cub-1.8.0", - build_file = clean_dep("//third_party:cub.BUILD"), - ) - - tf_http_archive( - name = "cython", - sha256 = "6dcd30b5ceb887b2b965ee7ceb82ea3acb5f0642fe2206c7636b45acea4798e5", - urls = [ - "https://mirror.bazel.build/github.com/cython/cython/archive/3732784c45cfb040a5b0936951d196f83a12ea17.tar.gz", - "https://github.com/cython/cython/archive/3732784c45cfb040a5b0936951d196f83a12ea17.tar.gz", - ], - strip_prefix = "cython-3732784c45cfb040a5b0936951d196f83a12ea17", - build_file = clean_dep("//third_party:cython.BUILD"), - delete = ["BUILD.bazel"], - ) - - tf_http_archive( - name = "bazel_toolchains", - urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/44200e0c026d86c53470d107b3697a3e46469c43.tar.gz", - "https://github.com/bazelbuild/bazel-toolchains/archive/44200e0c026d86c53470d107b3697a3e46469c43.tar.gz", - ], - strip_prefix = "bazel-toolchains-44200e0c026d86c53470d107b3697a3e46469c43", - sha256 = "699b55a6916c687f4b7dc092dbbf5f64672cde0dc965f79717735ec4e5416556", - ) - - tf_http_archive( - name = "arm_neon_2_x86_sse", - sha256 = "c8d90aa4357f8079d427e87a6f4c493da1fa4140aee926c05902d7ec1533d9a5", - strip_prefix = "ARM_NEON_2_x86_SSE-0f77d9d182265259b135dad949230ecbf1a2633d", - urls = [ - "https://mirror.bazel.build/github.com/intel/ARM_NEON_2_x86_SSE/archive/0f77d9d182265259b135dad949230ecbf1a2633d.tar.gz", - "https://github.com/intel/ARM_NEON_2_x86_SSE/archive/0f77d9d182265259b135dad949230ecbf1a2633d.tar.gz", - ], - build_file = clean_dep("//third_party:arm_neon_2_x86_sse.BUILD"), - ) - - tf_http_archive( - name = "flatbuffers", - strip_prefix = "flatbuffers-971a68110e4fc1bace10fcb6deeb189e7e1a34ce", - sha256 = "874088d2ee0d9f8524191f77209556415f03dd44e156276edf19e5b90ceb5f55", - urls = [ - "https://mirror.bazel.build/github.com/google/flatbuffers/archive/971a68110e4fc1bace10fcb6deeb189e7e1a34ce.tar.gz", - "https://github.com/google/flatbuffers/archive/971a68110e4fc1bace10fcb6deeb189e7e1a34ce.tar.gz", - ], - build_file = clean_dep("//third_party/flatbuffers:flatbuffers.BUILD"), - ) - - native.new_http_archive( - name = "double_conversion", - urls = [ - "https://github.com/google/double-conversion/archive/3992066a95b823efc8ccc1baf82a1cfc73f6e9b8.zip", - ], - sha256 = "2f7fbffac0d98d201ad0586f686034371a6d152ca67508ab611adc2386ad30de", - strip_prefix = "double-conversion-3992066a95b823efc8ccc1baf82a1cfc73f6e9b8", - build_file = clean_dep("//third_party:double_conversion.BUILD") - ) - - tf_http_archive( - name = "tflite_mobilenet", - sha256 = "23f814d1c076bdf03715dfb6cab3713aa4fbdf040fd5448c43196bd2e97a4c1b", - urls = [ - "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip", - "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip", - ], - build_file = clean_dep("//third_party:tflite_mobilenet.BUILD"), - ) - - tf_http_archive( - name = "tflite_mobilenet_ssd", - sha256 = "767057f2837a46d97882734b03428e8dd640b93236052b312b2f0e45613c1cf0", - urls = [ - "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip", - "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip", - ], - build_file = str(Label("//third_party:tflite_mobilenet.BUILD")), - ) - - tf_http_archive( - name = "tflite_conv_actions_frozen", - sha256 = "d947b38cba389b5e2d0bfc3ea6cc49c784e187b41a071387b3742d1acac7691e", - urls = [ - "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/conv_actions_tflite.zip", - "https://storage.googleapis.com/download.tensorflow.org/models/tflite/conv_actions_tflite.zip", - ], - build_file = str(Label("//third_party:tflite_mobilenet.BUILD")), - ) - - tf_http_archive( - name = "tflite_smartreply", - sha256 = "8980151b85a87a9c1a3bb1ed4748119e4a85abd3cb5744d83da4d4bd0fbeef7c", - urls = [ - "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip", - "https://storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip" - ], - build_file = clean_dep("//third_party:tflite_smartreply.BUILD"), - ) - - tf_http_archive( - name = "tflite_ovic_testdata", - sha256 = "a9a705d8d519220178e2e65d383fdb21da37fdb31d1e909b0a1acdac46479e9c", - urls = [ - "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/data/ovic.zip", - "https://storage.googleapis.com/download.tensorflow.org/data/ovic.zip", - ], - build_file = clean_dep("//third_party:tflite_ovic_testdata.BUILD"), - strip_prefix = "ovic", - ) - - ############################################################################## - # BIND DEFINITIONS - # - # Please do not add bind() definitions unless we have no other choice. - # If that ends up being the case, please leave a comment explaining - # why we can't depend on the canonical build target. - - # gRPC wants a cares dependency but its contents is not actually - # important since we have set GRPC_ARES=0 in tools/bazel.rc - native.bind( - name = "cares", - actual = "@grpc//third_party/nanopb:nanopb", - ) - - # Needed by Protobuf - native.bind( - name = "grpc_cpp_plugin", - actual = "@grpc//:grpc_cpp_plugin", - ) - native.bind( - name = "grpc_python_plugin", - actual = "@grpc//:grpc_python_plugin", - ) - - native.bind( - name = "grpc_lib", - actual = "@grpc//:grpc++", - ) - - # Needed by gRPC - native.bind( - name = "libssl", - actual = "@boringssl//:ssl", - ) - - # Needed by gRPC - native.bind( - name = "nanopb", - actual = "@grpc//third_party/nanopb:nanopb", - ) - - # Needed by gRPC - native.bind( - name = "protobuf", - actual = "@protobuf_archive//:protobuf", - ) - - # gRPC expects //external:protobuf_clib and //external:protobuf_compiler - # to point to Protobuf's compiler library. - native.bind( - name = "protobuf_clib", - actual = "@protobuf_archive//:protoc_lib", - ) - - # Needed by gRPC - native.bind( - name = "protobuf_headers", - actual = "@protobuf_archive//:protobuf_headers", - ) - - # Needed by Protobuf - native.bind( - name = "python_headers", - actual = clean_dep("//third_party/python_runtime:headers"), - ) - - # Needed by Protobuf - native.bind( - name = "six", - actual = "@six_archive//:six", - ) - - # Needed by gRPC - native.bind( - name = "zlib", - actual = "@zlib_archive//:zlib", - ) +def tf_workspace(path_prefix = "", tf_repo_name = ""): + # Note that we check the minimum bazel version in WORKSPACE. + clang6_configure(name = "local_config_clang6") + cc_download_clang_toolchain(name = "local_config_download_clang") + cuda_configure(name = "local_config_cuda") + tensorrt_configure(name = "local_config_tensorrt") + nccl_configure(name = "local_config_nccl") + git_configure(name = "local_config_git") + sycl_configure(name = "local_config_sycl") + syslibs_configure(name = "local_config_syslibs") + python_configure(name = "local_config_python") + + # For windows bazel build + # TODO: Remove def file filter when TensorFlow can export symbols properly on Windows. + def_file_filter_configure(name = "local_config_def_file_filter") + + # Point //external/local_config_arm_compiler to //external/arm_compiler + arm_compiler_configure( + name = "local_config_arm_compiler", + remote_config_repo = "../arm_compiler", + build_file = clean_dep("//third_party/toolchains/cpus/arm:BUILD"), + ) + + mkl_repository( + name = "mkl_linux", + urls = [ + "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.15/mklml_lnx_2018.0.3.20180406.tgz", + "https://github.com/intel/mkl-dnn/releases/download/v0.15/mklml_lnx_2018.0.3.20180406.tgz", + ], + sha256 = "d2305244fdc9b87db7426ed4496e87a4b3977ad3374d73b8000e8b7a5b7aa725", + strip_prefix = "mklml_lnx_2018.0.3.20180406", + build_file = clean_dep("//third_party/mkl:mkl.BUILD"), + ) + mkl_repository( + name = "mkl_windows", + urls = [ + "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.15/mklml_win_2018.0.3.20180406.zip", + "https://github.com/intel/mkl-dnn/releases/download/v0.15/mklml_win_2018.0.3.20180406.zip", + ], + sha256 = "a584a5bf1c8d2ad70b90d12b52652030e9a338217719064fdb84b7ad0d693694", + strip_prefix = "mklml_win_2018.0.3.20180406", + build_file = clean_dep("//third_party/mkl:mkl.BUILD"), + ) + mkl_repository( + name = "mkl_darwin", + urls = [ + "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.15/mklml_mac_2018.0.3.20180406.tgz", + "https://github.com/intel/mkl-dnn/releases/download/v0.15/mklml_mac_2018.0.3.20180406.tgz", + ], + sha256 = "094e3dfd61c816136dc8d12a45cc611ce26c5f4828176a3644cd0b0efa15a25b", + strip_prefix = "mklml_mac_2018.0.3.20180406", + build_file = clean_dep("//third_party/mkl:mkl.BUILD"), + ) + + if path_prefix: + print("path_prefix was specified to tf_workspace but is no longer used " + + "and will be removed in the future.") + + tf_http_archive( + name = "mkl_dnn", + urls = [ + "https://mirror.bazel.build/github.com/intel/mkl-dnn/archive/0c1cf54b63732e5a723c5670f66f6dfb19b64d20.tar.gz", + "https://github.com/intel/mkl-dnn/archive/0c1cf54b63732e5a723c5670f66f6dfb19b64d20.tar.gz", + ], + sha256 = "da1f27f92453a65331197dd8e4992e810fb7b1c4e0b902a1da5611592df2b633", + strip_prefix = "mkl-dnn-0c1cf54b63732e5a723c5670f66f6dfb19b64d20", + build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"), + ) + + tf_http_archive( + name = "com_google_absl", + urls = [ + "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/9613678332c976568272c8f4a78631a29159271d.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/9613678332c976568272c8f4a78631a29159271d.tar.gz", + ], + sha256 = "1273a1434ced93bc3e703a48c5dced058c95e995c8c009e9bdcb24a69e2180e9", + strip_prefix = "abseil-cpp-9613678332c976568272c8f4a78631a29159271d", + build_file = clean_dep("//third_party:com_google_absl.BUILD"), + ) + + tf_http_archive( + name = "eigen_archive", + urls = [ + "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz", + "https://bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz", + ], + sha256 = "d956415d784fa4e42b6a2a45c32556d6aec9d0a3d8ef48baee2522ab762556a9", + strip_prefix = "eigen-eigen-fd6845384b86", + build_file = clean_dep("//third_party:eigen.BUILD"), + ) + + tf_http_archive( + name = "arm_compiler", + sha256 = "970285762565c7890c6c087d262b0a18286e7d0384f13a37786d8521773bc969", + strip_prefix = "tools-0e906ebc527eab1cdbf7adabff5b474da9562e9f/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf", + urls = [ + "https://mirror.bazel.build/github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz", + # Please uncomment me, when the next upgrade happens. Then + # remove the whitelist entry in third_party/repo.bzl. + # "https://github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz", + ], + build_file = clean_dep("//:arm_compiler.BUILD"), + ) + + tf_http_archive( + name = "libxsmm_archive", + urls = [ + "https://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.9.tar.gz", + "https://github.com/hfp/libxsmm/archive/1.9.tar.gz", + ], + sha256 = "cd8532021352b4a0290d209f7f9bfd7c2411e08286a893af3577a43457287bfa", + strip_prefix = "libxsmm-1.9", + build_file = clean_dep("//third_party:libxsmm.BUILD"), + ) + + tf_http_archive( + name = "ortools_archive", + urls = [ + "https://mirror.bazel.build/github.com/google/or-tools/archive/v6.7.2.tar.gz", + "https://github.com/google/or-tools/archive/v6.7.2.tar.gz", + ], + sha256 = "d025a95f78b5fc5eaa4da5f395f23d11c23cf7dbd5069f1f627f002de87b86b9", + strip_prefix = "or-tools-6.7.2/src", + build_file = clean_dep("//third_party:ortools.BUILD"), + ) + + tf_http_archive( + name = "com_googlesource_code_re2", + urls = [ + "https://mirror.bazel.build/github.com/google/re2/archive/2018-07-01.tar.gz", + "https://github.com/google/re2/archive/2018-07-01.tar.gz", + ], + sha256 = "803c7811146edeef8f91064de37c6f19136ff01a2a8cdb3230e940b2fd9f07fe", + strip_prefix = "re2-2018-07-01", + system_build_file = clean_dep("//third_party/systemlibs:re2.BUILD"), + ) + + tf_http_archive( + name = "com_github_googlecloudplatform_google_cloud_cpp", + urls = [ + "https://mirror.bazel.build/github.com/GoogleCloudPlatform/google-cloud-cpp/archive/14760a86c4ffab9943b476305c4fe927ad95db1c.tar.gz", + "https://github.com/GoogleCloudPlatform/google-cloud-cpp/archive/14760a86c4ffab9943b476305c4fe927ad95db1c.tar.gz", + ], + sha256 = "fdd3b3aecce60987e5525e55bf3a21d68a8695320bd5b980775af6507eec3944", + strip_prefix = "google-cloud-cpp-14760a86c4ffab9943b476305c4fe927ad95db1c", + ) + + tf_http_archive( + name = "com_github_googleapis_googleapis", + urls = [ + "https://mirror.bazel.build/github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip", + "https://github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip", + ], + sha256 = "824870d87a176f26bcef663e92051f532fac756d1a06b404055dc078425f4378", + strip_prefix = "googleapis-f81082ea1e2f85c43649bee26e0d9871d4b41cdb", + build_file = clean_dep("//third_party:googleapis.BUILD"), + ) + + tf_http_archive( + name = "gemmlowp", + urls = [ + "https://mirror.bazel.build/github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip", + "https://github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip", + ], + sha256 = "b87faa7294dfcc5d678f22a59d2c01ca94ea1e2a3b488c38a95a67889ed0a658", + strip_prefix = "gemmlowp-38ebac7b059e84692f53e5938f97a9943c120d98", + ) + + tf_http_archive( + name = "farmhash_archive", + urls = [ + "https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz", + "https://github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz", + ], + sha256 = "6560547c63e4af82b0f202cb710ceabb3f21347a4b996db565a411da5b17aba0", + strip_prefix = "farmhash-816a4ae622e964763ca0862d9dbd19324a1eaf45", + build_file = clean_dep("//third_party:farmhash.BUILD"), + ) + + tf_http_archive( + name = "highwayhash", + urls = [ + "http://mirror.bazel.build/github.com/google/highwayhash/archive/fd3d9af80465e4383162e4a7c5e2f406e82dd968.tar.gz", + "https://github.com/google/highwayhash/archive/fd3d9af80465e4383162e4a7c5e2f406e82dd968.tar.gz", + ], + sha256 = "9c3e0e87d581feeb0c18d814d98f170ff23e62967a2bd6855847f0b2fe598a37", + strip_prefix = "highwayhash-fd3d9af80465e4383162e4a7c5e2f406e82dd968", + build_file = clean_dep("//third_party:highwayhash.BUILD"), + ) + + tf_http_archive( + name = "nasm", + urls = [ + "https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2", + "http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.13.03.tar.bz2/sha512/d7a6b4cee8dfd603d8d4c976e5287b5cc542fa0b466ff989b743276a6e28114e64289bf02a7819eca63142a5278aa6eed57773007e5f589e15768e6456a8919d/nasm-2.13.03.tar.bz2", + "http://www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2", + ], + sha256 = "63ec86477ad3f0f6292325fd89e1d93aea2e2fd490070863f17d48f7cd387011", + strip_prefix = "nasm-2.13.03", + build_file = clean_dep("//third_party:nasm.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:nasm.BUILD"), + ) + + tf_http_archive( + name = "jpeg", + urls = [ + "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.3.tar.gz", + "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.3.tar.gz", + ], + sha256 = "1a17020f859cb12711175a67eab5c71fc1904e04b587046218e36106e07eabde", + strip_prefix = "libjpeg-turbo-1.5.3", + build_file = clean_dep("//third_party/jpeg:jpeg.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:jpeg.BUILD"), + ) + + tf_http_archive( + name = "png_archive", + urls = [ + "https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz", + "https://github.com/glennrp/libpng/archive/v1.6.34.tar.gz", + ], + sha256 = "e45ce5f68b1d80e2cb9a2b601605b374bdf51e1798ef1c2c2bd62131dfcf9eef", + strip_prefix = "libpng-1.6.34", + build_file = clean_dep("//third_party:png.BUILD"), + patch_file = clean_dep("//third_party:png_fix_rpi.patch"), + system_build_file = clean_dep("//third_party/systemlibs:png.BUILD"), + ) + + tf_http_archive( + name = "org_sqlite", + urls = [ + "https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3240000.zip", + "https://www.sqlite.org/2018/sqlite-amalgamation-3240000.zip", + ], + sha256 = "ad68c1216c3a474cf360c7581a4001e952515b3649342100f2d7ca7c8e313da6", + strip_prefix = "sqlite-amalgamation-3240000", + build_file = clean_dep("//third_party:sqlite.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:sqlite.BUILD"), + ) + + tf_http_archive( + name = "gif_archive", + urls = [ + "https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz", + "http://pilotfiber.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz", + ], + sha256 = "34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1", + strip_prefix = "giflib-5.1.4", + build_file = clean_dep("//third_party:gif.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:gif.BUILD"), + ) + + tf_http_archive( + name = "six_archive", + urls = [ + "https://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", + "https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", + ], + sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a", + strip_prefix = "six-1.10.0", + build_file = clean_dep("//third_party:six.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:six.BUILD"), + ) + + tf_http_archive( + name = "astor_archive", + urls = [ + "https://mirror.bazel.build/pypi.python.org/packages/d8/be/c4276b3199ec3feee2a88bc64810fbea8f26d961e0a4cd9c68387a9f35de/astor-0.6.2.tar.gz", + "https://pypi.python.org/packages/d8/be/c4276b3199ec3feee2a88bc64810fbea8f26d961e0a4cd9c68387a9f35de/astor-0.6.2.tar.gz", + ], + sha256 = "ff6d2e2962d834acb125cc4dcc80c54a8c17c253f4cc9d9c43b5102a560bb75d", + strip_prefix = "astor-0.6.2", + build_file = clean_dep("//third_party:astor.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:astor.BUILD"), + ) + + tf_http_archive( + name = "gast_archive", + urls = [ + "https://mirror.bazel.build/pypi.python.org/packages/5c/78/ff794fcae2ce8aa6323e789d1f8b3b7765f601e7702726f430e814822b96/gast-0.2.0.tar.gz", + "https://pypi.python.org/packages/5c/78/ff794fcae2ce8aa6323e789d1f8b3b7765f601e7702726f430e814822b96/gast-0.2.0.tar.gz", + ], + sha256 = "7068908321ecd2774f145193c4b34a11305bd104b4551b09273dfd1d6a374930", + strip_prefix = "gast-0.2.0", + build_file = clean_dep("//third_party:gast.BUILD"), + ) + + tf_http_archive( + name = "termcolor_archive", + urls = [ + "https://mirror.bazel.build/pypi.python.org/packages/8a/48/a76be51647d0eb9f10e2a4511bf3ffb8cc1e6b14e9e4fab46173aa79f981/termcolor-1.1.0.tar.gz", + "https://pypi.python.org/packages/8a/48/a76be51647d0eb9f10e2a4511bf3ffb8cc1e6b14e9e4fab46173aa79f981/termcolor-1.1.0.tar.gz", + ], + sha256 = "1d6d69ce66211143803fbc56652b41d73b4a400a2891d7bf7a1cdf4c02de613b", + strip_prefix = "termcolor-1.1.0", + build_file = clean_dep("//third_party:termcolor.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:termcolor.BUILD"), + ) + + tf_http_archive( + name = "absl_py", + urls = [ + "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz", + "https://github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz", + ], + sha256 = "95160f778a62c7a60ddeadc7bf2d83f85a23a27359814aca12cf949e896fa82c", + strip_prefix = "abseil-py-pypi-v0.2.2", + ) + + tf_http_archive( + name = "org_python_pypi_backports_weakref", + urls = [ + "https://mirror.bazel.build/pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz", + "https://pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz", + ], + sha256 = "8813bf712a66b3d8b85dc289e1104ed220f1878cf981e2fe756dfaabe9a82892", + strip_prefix = "backports.weakref-1.0rc1/src", + build_file = clean_dep("//third_party:backports_weakref.BUILD"), + ) + + filegroup_external( + name = "org_python_license", + licenses = ["notice"], # Python 2.0 + sha256_urls = { + "b5556e921715ddb9242c076cae3963f483aa47266c5e37ea4c187f77cc79501c": [ + "https://mirror.bazel.build/docs.python.org/2.7/_sources/license.txt", + "https://docs.python.org/2.7/_sources/license.txt", + ], + }, + ) + + tf_http_archive( + name = "protobuf_archive", + urls = [ + "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz", + "https://github.com/google/protobuf/archive/v3.6.0.tar.gz", + ], + sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4", + strip_prefix = "protobuf-3.6.0", + ) + + # We need to import the protobuf library under the names com_google_protobuf + # and com_google_protobuf_cc to enable proto_library support in bazel. + # Unfortunately there is no way to alias http_archives at the moment. + tf_http_archive( + name = "com_google_protobuf", + urls = [ + "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz", + "https://github.com/google/protobuf/archive/v3.6.0.tar.gz", + ], + sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4", + strip_prefix = "protobuf-3.6.0", + ) + + tf_http_archive( + name = "com_google_protobuf_cc", + urls = [ + "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz", + "https://github.com/google/protobuf/archive/v3.6.0.tar.gz", + ], + sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4", + strip_prefix = "protobuf-3.6.0", + ) + + tf_http_archive( + name = "nsync", + urls = [ + "https://mirror.bazel.build/github.com/google/nsync/archive/1.20.0.tar.gz", + "https://github.com/google/nsync/archive/1.20.0.tar.gz", + ], + sha256 = "0c1b03962b2f8450f21e74a5a46116bf2d6009a807c57eb4207e974a8c4bb7dd", + strip_prefix = "nsync-1.20.0", + ) + + tf_http_archive( + name = "com_google_googletest", + urls = [ + "https://mirror.bazel.build/github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip", + "https://github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip", + ], + sha256 = "9cbca84c4256bed17df2c8f4d00c912c19d247c11c9ba6647cd6dd5b5c996b8d", + strip_prefix = "googletest-9816b96a6ddc0430671693df90192bbee57108b6", + ) + + tf_http_archive( + name = "com_github_gflags_gflags", + urls = [ + "https://mirror.bazel.build/github.com/gflags/gflags/archive/v2.2.1.tar.gz", + "https://github.com/gflags/gflags/archive/v2.2.1.tar.gz", + ], + sha256 = "ae27cdbcd6a2f935baa78e4f21f675649271634c092b1be01469440495609d0e", + strip_prefix = "gflags-2.2.1", + ) + + tf_http_archive( + name = "pcre", + sha256 = "69acbc2fbdefb955d42a4c606dfde800c2885711d2979e356c0636efde9ec3b5", + urls = [ + "https://mirror.bazel.build/ftp.exim.org/pub/pcre/pcre-8.42.tar.gz", + "http://ftp.exim.org/pub/pcre/pcre-8.42.tar.gz", + ], + strip_prefix = "pcre-8.42", + build_file = clean_dep("//third_party:pcre.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:pcre.BUILD"), + ) + + tf_http_archive( + name = "swig", + sha256 = "58a475dbbd4a4d7075e5fe86d4e54c9edde39847cdb96a3053d87cb64a23a453", + urls = [ + "https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz", + "http://ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz", + "http://pilotfiber.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz", + ], + strip_prefix = "swig-3.0.8", + build_file = clean_dep("//third_party:swig.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:swig.BUILD"), + ) + + tf_http_archive( + name = "curl", + sha256 = "e9c37986337743f37fd14fe8737f246e97aec94b39d1b71e8a5973f72a9fc4f5", + urls = [ + "https://mirror.bazel.build/curl.haxx.se/download/curl-7.60.0.tar.gz", + "https://curl.haxx.se/download/curl-7.60.0.tar.gz", + ], + strip_prefix = "curl-7.60.0", + build_file = clean_dep("//third_party:curl.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:curl.BUILD"), + ) + + tf_http_archive( + name = "grpc", + urls = [ + "https://mirror.bazel.build/github.com/grpc/grpc/archive/v1.13.0.tar.gz", + "https://github.com/grpc/grpc/archive/v1.13.0.tar.gz", + ], + sha256 = "50db9cf2221354485eb7c3bd55a4c27190caef7048a2a1a15fbe60a498f98b44", + strip_prefix = "grpc-1.13.0", + system_build_file = clean_dep("//third_party/systemlibs:grpc.BUILD"), + ) + + tf_http_archive( + name = "linenoise", + sha256 = "7f51f45887a3d31b4ce4fa5965210a5e64637ceac12720cfce7954d6a2e812f7", + urls = [ + "https://mirror.bazel.build/github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz", + "https://github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz", + ], + strip_prefix = "linenoise-c894b9e59f02203dbe4e2be657572cf88c4230c3", + build_file = clean_dep("//third_party:linenoise.BUILD"), + ) + + # TODO(phawkins): currently, this rule uses an unofficial LLVM mirror. + # Switch to an official source of snapshots if/when possible. + tf_http_archive( + name = "llvm", + urls = [ + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/36f54002c931a026f490f9fb074c11d91e3487a2.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/36f54002c931a026f490f9fb074c11d91e3487a2.tar.gz", + ], + sha256 = "e360a9e9b0d4f1adedcdb89fc1efc171f68e250c115ddfaeb82d71edef7a10c8", + strip_prefix = "llvm-36f54002c931a026f490f9fb074c11d91e3487a2", + build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"), + ) + + tf_http_archive( + name = "lmdb", + urls = [ + "https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.22.tar.gz", + "https://github.com/LMDB/lmdb/archive/LMDB_0.9.22.tar.gz", + ], + sha256 = "f3927859882eb608868c8c31586bb7eb84562a40a6bf5cc3e13b6b564641ea28", + strip_prefix = "lmdb-LMDB_0.9.22/libraries/liblmdb", + build_file = clean_dep("//third_party:lmdb.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:lmdb.BUILD"), + ) + + tf_http_archive( + name = "jsoncpp_git", + urls = [ + "https://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz", + "https://github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz", + ], + sha256 = "c49deac9e0933bcb7044f08516861a2d560988540b23de2ac1ad443b219afdb6", + strip_prefix = "jsoncpp-1.8.4", + build_file = clean_dep("//third_party:jsoncpp.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:jsoncpp.BUILD"), + ) + + tf_http_archive( + name = "boringssl", + urls = [ + "https://mirror.bazel.build/github.com/google/boringssl/archive/45c4a87ae97eb95a8fc2906c035d6a8d0e02e1b8.tar.gz", + "https://github.com/google/boringssl/archive/45c4a87ae97eb95a8fc2906c035d6a8d0e02e1b8.tar.gz", + ], + sha256 = "972e8d8a9d1daf9892fff7155312b1af46b4754446575a7b285e62f917424c78", + strip_prefix = "boringssl-45c4a87ae97eb95a8fc2906c035d6a8d0e02e1b8", + ) + + tf_http_archive( + name = "zlib_archive", + urls = [ + "https://mirror.bazel.build/zlib.net/zlib-1.2.11.tar.gz", + "https://zlib.net/zlib-1.2.11.tar.gz", + ], + sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", + strip_prefix = "zlib-1.2.11", + build_file = clean_dep("//third_party:zlib.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:zlib.BUILD"), + ) + + tf_http_archive( + name = "fft2d", + urls = [ + "https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz", + "http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz", + ], + sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296", + build_file = clean_dep("//third_party/fft2d:fft2d.BUILD"), + ) + + tf_http_archive( + name = "snappy", + urls = [ + "https://mirror.bazel.build/github.com/google/snappy/archive/1.1.7.tar.gz", + "https://github.com/google/snappy/archive/1.1.7.tar.gz", + ], + sha256 = "3dfa02e873ff51a11ee02b9ca391807f0c8ea0529a4924afa645fbf97163f9d4", + strip_prefix = "snappy-1.1.7", + build_file = clean_dep("//third_party:snappy.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:snappy.BUILD"), + ) + + tf_http_archive( + name = "nccl_archive", + urls = [ + "https://mirror.bazel.build/github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz", + "https://github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz", + ], + sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176", + strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7", + build_file = clean_dep("//third_party:nccl/nccl_archive.BUILD"), + ) + + tf_http_archive( + name = "kafka", + urls = [ + "https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.4.tar.gz", + "https://github.com/edenhill/librdkafka/archive/v0.11.4.tar.gz", + ], + sha256 = "9d8f1eb7b0e29e9ab1168347c939cb7ae5dff00a39cef99e7ef033fd8f92737c", + strip_prefix = "librdkafka-0.11.4", + build_file = clean_dep("//third_party:kafka/BUILD"), + patch_file = clean_dep("//third_party/kafka:config.patch"), + ) + + tf_http_archive( + name = "aws", + urls = [ + "https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.3.15.tar.gz", + "https://github.com/aws/aws-sdk-cpp/archive/1.3.15.tar.gz", + ], + sha256 = "b888d8ce5fc10254c3dd6c9020c7764dd53cf39cf011249d0b4deda895de1b7c", + strip_prefix = "aws-sdk-cpp-1.3.15", + build_file = clean_dep("//third_party:aws.BUILD"), + ) + + java_import_external( + name = "junit", + jar_sha256 = "59721f0805e223d84b90677887d9ff567dc534d7c502ca903c0c2b17f05c116a", + jar_urls = [ + "https://mirror.bazel.build/repo1.maven.org/maven2/junit/junit/4.12/junit-4.12.jar", + "http://repo1.maven.org/maven2/junit/junit/4.12/junit-4.12.jar", + "http://maven.ibiblio.org/maven2/junit/junit/4.12/junit-4.12.jar", + ], + licenses = ["reciprocal"], # Common Public License Version 1.0 + testonly_ = True, + deps = ["@org_hamcrest_core"], + ) + + java_import_external( + name = "org_hamcrest_core", + jar_sha256 = "66fdef91e9739348df7a096aa384a5685f4e875584cce89386a7a47251c4d8e9", + jar_urls = [ + "https://mirror.bazel.build/repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar", + "http://repo1.maven.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar", + "http://maven.ibiblio.org/maven2/org/hamcrest/hamcrest-core/1.3/hamcrest-core-1.3.jar", + ], + licenses = ["notice"], # New BSD License + testonly_ = True, + ) + + tf_http_archive( + name = "jemalloc", + urls = [ + "https://mirror.bazel.build/github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz", + "https://github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz", + ], + sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8", + strip_prefix = "jemalloc-4.4.0", + build_file = clean_dep("//third_party:jemalloc.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:jemalloc.BUILD"), + ) + + java_import_external( + name = "com_google_testing_compile", + jar_sha256 = "edc180fdcd9f740240da1a7a45673f46f59c5578d8cd3fbc912161f74b5aebb8", + jar_urls = [ + "http://mirror.bazel.build/repo1.maven.org/maven2/com/google/testing/compile/compile-testing/0.11/compile-testing-0.11.jar", + "http://repo1.maven.org/maven2/com/google/testing/compile/compile-testing/0.11/compile-testing-0.11.jar", + ], + licenses = ["notice"], # New BSD License + testonly_ = True, + deps = ["@com_google_guava", "@com_google_truth"], + ) + + java_import_external( + name = "com_google_truth", + jar_sha256 = "032eddc69652b0a1f8d458f999b4a9534965c646b8b5de0eba48ee69407051df", + jar_urls = [ + "http://mirror.bazel.build/repo1.maven.org/maven2/com/google/truth/truth/0.32/truth-0.32.jar", + "http://repo1.maven.org/maven2/com/google/truth/truth/0.32/truth-0.32.jar", + ], + licenses = ["notice"], # Apache 2.0 + testonly_ = True, + deps = ["@com_google_guava"], + ) + + java_import_external( + name = "org_checkerframework_qual", + jar_sha256 = "a17501717ef7c8dda4dba73ded50c0d7cde440fd721acfeacbf19786ceac1ed6", + jar_urls = [ + "http://mirror.bazel.build/repo1.maven.org/maven2/org/checkerframework/checker-qual/2.4.0/checker-qual-2.4.0.jar", + "http://repo1.maven.org/maven2/org/checkerframework/checker-qual/2.4.0/checker-qual-2.4.0.jar", + ], + licenses = ["notice"], # Apache 2.0 + ) + + java_import_external( + name = "com_squareup_javapoet", + jar_sha256 = "5bb5abdfe4366c15c0da3332c57d484e238bd48260d6f9d6acf2b08fdde1efea", + jar_urls = [ + "http://mirror.bazel.build/repo1.maven.org/maven2/com/squareup/javapoet/1.9.0/javapoet-1.9.0.jar", + "http://repo1.maven.org/maven2/com/squareup/javapoet/1.9.0/javapoet-1.9.0.jar", + ], + licenses = ["notice"], # Apache 2.0 + ) + + tf_http_archive( + name = "com_google_pprof", + urls = [ + "https://mirror.bazel.build/github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz", + "https://github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz", + ], + sha256 = "e0928ca4aa10ea1e0551e2d7ce4d1d7ea2d84b2abbdef082b0da84268791d0c4", + strip_prefix = "pprof-c0fb62ec88c411cc91194465e54db2632845b650", + build_file = clean_dep("//third_party:pprof.BUILD"), + ) + + tf_http_archive( + name = "cub_archive", + urls = [ + "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.8.0.zip", + "https://github.com/NVlabs/cub/archive/1.8.0.zip", + ], + sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3", + strip_prefix = "cub-1.8.0", + build_file = clean_dep("//third_party:cub.BUILD"), + ) + + tf_http_archive( + name = "cython", + sha256 = "bccc9aa050ea02595b2440188813b936eaf345e85fb9692790cecfe095cf91aa", + urls = [ + "https://mirror.bazel.build/github.com/cython/cython/archive/0.28.4.tar.gz", + "https://github.com/cython/cython/archive/0.28.4.tar.gz", + ], + strip_prefix = "cython-0.28.4", + build_file = clean_dep("//third_party:cython.BUILD"), + delete = ["BUILD.bazel"], + system_build_file = clean_dep("//third_party/systemlibs:cython.BUILD"), + ) + + tf_http_archive( + name = "bazel_toolchains", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/37acf1841ab1475c98a152cb9e446460c8ae29e1.tar.gz", + "https://github.com/bazelbuild/bazel-toolchains/archive/37acf1841ab1475c98a152cb9e446460c8ae29e1.tar.gz", + ], + strip_prefix = "bazel-toolchains-37acf1841ab1475c98a152cb9e446460c8ae29e1", + sha256 = "3b604699685c5c65dd3f6f17425570a4b2f00ddba2f750db15acc72e55bb098b", + ) + + tf_http_archive( + name = "arm_neon_2_x86_sse", + sha256 = "c8d90aa4357f8079d427e87a6f4c493da1fa4140aee926c05902d7ec1533d9a5", + strip_prefix = "ARM_NEON_2_x86_SSE-0f77d9d182265259b135dad949230ecbf1a2633d", + urls = [ + "https://mirror.bazel.build/github.com/intel/ARM_NEON_2_x86_SSE/archive/0f77d9d182265259b135dad949230ecbf1a2633d.tar.gz", + "https://github.com/intel/ARM_NEON_2_x86_SSE/archive/0f77d9d182265259b135dad949230ecbf1a2633d.tar.gz", + ], + build_file = clean_dep("//third_party:arm_neon_2_x86_sse.BUILD"), + ) + + tf_http_archive( + name = "flatbuffers", + strip_prefix = "flatbuffers-1.9.0", + sha256 = "5ca5491e4260cacae30f1a5786d109230db3f3a6e5a0eb45d0d0608293d247e3", + urls = [ + "https://mirror.bazel.build/github.com/google/flatbuffers/archive/v1.9.0.tar.gz", + "https://github.com/google/flatbuffers/archive/v1.9.0.tar.gz", + ], + build_file = clean_dep("//third_party/flatbuffers:flatbuffers.BUILD"), + system_build_file = clean_dep("//third_party/systemlibs:flatbuffers.BUILD"), + ) + + native.new_http_archive( + name = "double_conversion", + urls = [ + "https://github.com/google/double-conversion/archive/3992066a95b823efc8ccc1baf82a1cfc73f6e9b8.zip", + ], + sha256 = "2f7fbffac0d98d201ad0586f686034371a6d152ca67508ab611adc2386ad30de", + strip_prefix = "double-conversion-3992066a95b823efc8ccc1baf82a1cfc73f6e9b8", + build_file = clean_dep("//third_party:double_conversion.BUILD"), + ) + + tf_http_archive( + name = "tflite_mobilenet", + sha256 = "23f814d1c076bdf03715dfb6cab3713aa4fbdf040fd5448c43196bd2e97a4c1b", + urls = [ + "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip", + "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip", + ], + build_file = clean_dep("//third_party:tflite_mobilenet.BUILD"), + ) + + tf_http_archive( + name = "tflite_mobilenet_ssd", + sha256 = "767057f2837a46d97882734b03428e8dd640b93236052b312b2f0e45613c1cf0", + urls = [ + "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip", + "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip", + ], + build_file = str(Label("//third_party:tflite_mobilenet.BUILD")), + ) + tf_http_archive( + name = "tflite_mobilenet_ssd_quant", + sha256 = "a809cd290b4d6a2e8a9d5dad076e0bd695b8091974e0eed1052b480b2f21b6dc", + urls = [ + "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip", + "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip", + ], + build_file = str(Label("//third_party:tflite_mobilenet.BUILD")), + ) + + tf_http_archive( + name = "tflite_conv_actions_frozen", + sha256 = "d947b38cba389b5e2d0bfc3ea6cc49c784e187b41a071387b3742d1acac7691e", + urls = [ + "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/conv_actions_tflite.zip", + "https://storage.googleapis.com/download.tensorflow.org/models/tflite/conv_actions_tflite.zip", + ], + build_file = str(Label("//third_party:tflite_mobilenet.BUILD")), + ) + + tf_http_archive( + name = "tflite_smartreply", + sha256 = "8980151b85a87a9c1a3bb1ed4748119e4a85abd3cb5744d83da4d4bd0fbeef7c", + urls = [ + "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip", + "https://storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip", + ], + build_file = clean_dep("//third_party:tflite_smartreply.BUILD"), + ) + + tf_http_archive( + name = "tflite_ovic_testdata", + sha256 = "a9a705d8d519220178e2e65d383fdb21da37fdb31d1e909b0a1acdac46479e9c", + urls = [ + "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/data/ovic.zip", + "https://storage.googleapis.com/download.tensorflow.org/data/ovic.zip", + ], + build_file = clean_dep("//third_party:tflite_ovic_testdata.BUILD"), + strip_prefix = "ovic", + ) + + tf_http_archive( + name = "build_bazel_rules_android", + sha256 = "cd06d15dd8bb59926e4d65f9003bfc20f9da4b2519985c27e190cddc8b7a7806", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/rules_android/archive/v0.1.1.zip", + "https://github.com/bazelbuild/rules_android/archive/v0.1.1.zip", + ], + strip_prefix = "rules_android-0.1.1", + ) + + ############################################################################## + # BIND DEFINITIONS + # + # Please do not add bind() definitions unless we have no other choice. + # If that ends up being the case, please leave a comment explaining + # why we can't depend on the canonical build target. + + # gRPC wants a cares dependency but its contents is not actually + # important since we have set GRPC_ARES=0 in tools/bazel.rc + native.bind( + name = "cares", + actual = "@grpc//third_party/nanopb:nanopb", + ) + + # Needed by Protobuf + native.bind( + name = "grpc_cpp_plugin", + actual = "@grpc//:grpc_cpp_plugin", + ) + native.bind( + name = "grpc_python_plugin", + actual = "@grpc//:grpc_python_plugin", + ) + + native.bind( + name = "grpc_lib", + actual = "@grpc//:grpc++", + ) + + native.bind( + name = "grpc_lib_unsecure", + actual = "@grpc//:grpc++_unsecure", + ) + + # Needed by gRPC + native.bind( + name = "libssl", + actual = "@boringssl//:ssl", + ) + + # Needed by gRPC + native.bind( + name = "nanopb", + actual = "@grpc//third_party/nanopb:nanopb", + ) + + # Needed by gRPC + native.bind( + name = "protobuf", + actual = "@protobuf_archive//:protobuf", + ) + + # gRPC expects //external:protobuf_clib and //external:protobuf_compiler + # to point to Protobuf's compiler library. + native.bind( + name = "protobuf_clib", + actual = "@protobuf_archive//:protoc_lib", + ) + + # Needed by gRPC + native.bind( + name = "protobuf_headers", + actual = "@protobuf_archive//:protobuf_headers", + ) + + # Needed by Protobuf + native.bind( + name = "python_headers", + actual = clean_dep("//third_party/python_runtime:headers"), + ) + + # Needed by Protobuf + native.bind( + name = "six", + actual = "@six_archive//:six", + ) + + # Needed by gRPC + native.bind( + name = "zlib", + actual = "@zlib_archive//:zlib", + ) diff --git a/tools/bazel.rc b/tools/bazel.rc index 1c1e6afb65ab8da5b689d58ecaec6ac6c8a69bb8..660e3d328038b618fefdf96d60863941d3a46edd 100644 --- a/tools/bazel.rc +++ b/tools/bazel.rc @@ -27,6 +27,10 @@ build --define framework_shared_object=true build:mkl --define=using_mkl=true build:mkl -c opt +# This config option is used to enable MKL-DNN open source library only, +# without depending on MKL binary version. +build:mkl_open_source_only --define=using_mkl_dnn_only=true + build:download_clang --crosstool_top=@local_config_download_clang//:toolchain build:download_clang --define=using_clang=true @@ -36,10 +40,6 @@ build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true build:cuda_clang --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda_clang --define=using_cuda=true --define=using_cuda_clang=true --define=using_clang=true -build:win-cuda --define=using_cuda=true --define=using_cuda_nvcc=true - -build:mkl --define=using_mkl=true - build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain build:sycl --define=using_sycl=true --define=using_trisycl=false @@ -59,3 +59,6 @@ build --define=grpc_no_ares=true build --spawn_strategy=standalone build --genrule_strategy=standalone build -c opt + +# Modular TF build options +build:dynamic_kernels --define=dynamic_loaded_kernels=true